├── .circleci
│   └── config.yml
├── .gitignore
├── .readthedocs.yml
├── LICENSE
├── MANIFEST.in
├── README.rst
├── appveyor.yml
├── doc
│   ├── Makefile
│   ├── _static
│   │   ├── css
│   │   │   └── project-template.css
│   │   ├── img
│   │   │   ├── cover.png
│   │   │   ├── dag1.png
│   │   │   ├── dag2.png
│   │   │   ├── dag2a.png
│   │   │   ├── dag3.png
│   │   │   ├── dag3a.png
│   │   │   ├── skdag-banner.png
│   │   │   ├── skdag-dark.png
│   │   │   └── stack.png
│   │   └── js
│   │       └── copybutton.js
│   ├── api.rst
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── quick_start.rst
│   └── user_guide.rst
├── environment.yml
├── examples
│   └── README.txt
├── img
│   ├── skdag-banner.png
│   ├── skdag-dark-fill.png
│   ├── skdag-dark.kra
│   ├── skdag-dark.png
│   ├── skdag-fill.png
│   ├── skdag.kra
│   └── skdag.png
├── requirements.txt
├── requirements_doc.txt
├── requirements_full.txt
├── requirements_test.txt
├── setup.cfg
├── setup.py
└── skdag
    ├── __init__.py
    ├── _version.py
    ├── dag
    │   ├── __init__.py
    │   ├── _builder.py
    │   ├── _dag.py
    │   ├── _render.py
    │   ├── _utils.py
    │   └── tests
    │       ├── __init__.py
    │       ├── test_builder.py
    │       ├── test_dag.py
    │       └── utils.py
    └── exceptions.py
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | 
3 | jobs:
4 |   build:
5 |     docker:
6 |       - image: circleci/python:3.6.1
7 |     working_directory: ~/repo
8 |     steps:
9 |       - checkout
10 |       - run:
11 |           name: install dependencies
12 |           command: |
13 |             wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
14 |             chmod +x miniconda.sh && ./miniconda.sh -b -p ~/miniconda
15 |             export PATH="~/miniconda/bin:$PATH"
16 |             conda update --yes --quiet conda
17 |             conda create -n testenv --yes --quiet python=3
18 |             source activate testenv
19 |             conda install --yes pip numpy scipy scikit-learn matplotlib sphinx sphinx_rtd_theme numpydoc pillow
20 |             pip install sphinx-gallery
21 |             pip install .
22 |             cd doc
23 |             make html
24 |       - store_artifacts:
25 |           path: doc/_build/html/
26 |           destination: doc
27 |       - store_artifacts:
28 |           path: ~/log.txt
29 |       - run: ls -ltrh doc/_build/html
30 | 
31 | workflows:
32 |   version: 2
33 |   workflow:
34 |     jobs:
35 |       - build:
36 |           filters:
37 |             branches:
38 |               ignore: gh-pages
39 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # scikit-learn specific
10 | doc/_build/
11 | doc/auto_examples/
12 | doc/modules/generated/
13 | doc/datasets/generated/
14 |
15 | # Distribution / packaging
16 |
17 | .Python
18 | env/
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *,cover
53 | .hypothesis/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 |
62 | # Sphinx documentation
63 | doc/_build/
64 | doc/generated/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # Jupyter
70 | .ipynb_checkpoints
71 |
72 | # Misc
73 | *.swp
74 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | formats:
2 |   - none
3 | requirements_file: requirements.txt
4 | python:
5 |   pip_install: true
6 |   extra_requirements:
7 |     - test
8 |     - doc
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 big-o
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.*
2 | include requirements*.txt
3 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | .. -*- mode: rst -*-
2 |
3 | |AppVeyor|_ |Codecov|_ |ReadTheDocs|_
4 |
5 | .. |AppVeyor| image:: https://ci.appveyor.com/api/projects/status/github/big-o/skdag?branch=main&svg=true
6 | .. _AppVeyor: https://ci.appveyor.com/project/big-o/skdag
7 |
8 | .. |Codecov| image:: https://codecov.io/gh/big-o/skdag/branch/main/graph/badge.svg
9 | .. _Codecov: https://codecov.io/gh/big-o/skdag
10 |
11 | .. |ReadTheDocs| image:: https://readthedocs.org/projects/skdag/badge/?version=latest
12 | .. _ReadTheDocs: https://skdag.readthedocs.io/en/latest/?badge=latest
13 |
14 | skdag - A more flexible alternative to scikit-learn Pipelines
15 | =============================================================
16 |
17 | .. image:: img/skdag-banner.png
18 |
19 | scikit-dag (``skdag``) is an open-source, MIT-licensed library that provides advanced
20 | workflow management for any machine learning operations that follow
21 | scikit-learn_ conventions. Installation is simple:
22 |
23 | .. code-block:: bash
24 |
25 | pip install skdag
26 |
27 | It works by introducing directed acyclic graphs (DAGs) as a drop-in replacement for the
28 | traditional scikit-learn ``Pipeline``. This gives you a simple interface for a range of
29 | use cases including complex pre-processing, model stacking and benchmarking.
30 |
31 | .. code-block:: python
32 |
33 | from skdag import DAGBuilder
   | from sklearn.decomposition import PCA
   | from sklearn.ensemble import RandomForestRegressor
   | from sklearn.impute import SimpleImputer
   | from sklearn.linear_model import LinearRegression
   | from sklearn.neighbors import KNeighborsRegressor
   | from sklearn.svm import SVR
34 |
35 | dag = (
36 | DAGBuilder(infer_dataframe=True)
37 | .add_step("impute", SimpleImputer())
38 | .add_step("vitals", "passthrough", deps={"impute": ["age", "sex", "bmi", "bp"]})
39 | .add_step(
40 | "blood",
41 | PCA(n_components=2, random_state=0),
42 | deps={"impute": ["s1", "s2", "s3", "s4", "s5", "s6"]}
43 | )
44 | .add_step(
45 | "rf",
46 | RandomForestRegressor(max_depth=5, random_state=0),
47 | deps=["blood", "vitals"]
48 | )
49 | .add_step("svm", SVR(C=0.7), deps=["blood", "vitals"])
50 | .add_step(
51 | "knn",
52 | KNeighborsRegressor(n_neighbors=5),
53 | deps=["blood", "vitals"]
54 | )
55 | .add_step("meta", LinearRegression(), deps=["rf", "svm", "knn"])
56 | .make_dag()
57 | )
58 |
59 | dag.show(detailed=True)
60 |
61 | .. image:: doc/_static/img/cover.png
62 |
63 | The above DAG imputes missing values, runs PCA on the columns relating to blood test
64 | results and leaves the other columns as they are. The results are then passed to three
65 | different regressors before being passed on to a final meta-estimator. Because DAGs
66 | (unlike pipelines) allow predictors in the middle of a workflow, you can use them to
67 | implement model stacking. DAGs can also run independent steps in parallel wherever
68 | possible.
69 |
70 | After building our DAG, we can treat it as any other estimator:
71 |
72 | .. code-block:: python
73 |
74 | from sklearn import datasets
75 |
76 | X, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
77 | X_train, X_test, y_train, y_test = train_test_split(
78 | X, y, test_size=0.2, random_state=0
79 | )
80 |
81 | dag.fit(X_train, y_train)
82 | dag.predict(X_test)
83 |
84 | Just like a pipeline, you can optimise it with a grid search, pickle it, etc.
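   | 
   | For instance, step parameters can be tuned via scikit-learn's nested
   | ``<step>__<parameter>`` naming convention. A minimal sketch (the parameter grid
   | below is purely illustrative, and it assumes the DAG exposes nested parameters in
   | the usual scikit-learn way):
   | 
   | .. code-block:: python
   | 
   |     from sklearn.model_selection import GridSearchCV
   | 
   |     # Assumption: step parameters are addressable as "<step>__<param>",
   |     # as for any scikit-learn meta-estimator.
   |     search = GridSearchCV(
   |         dag,
   |         param_grid={"rf__max_depth": [3, 5, 7], "svm__C": [0.5, 0.7, 1.0]},
   |         cv=5,
   |     )
   |     search.fit(X_train, y_train)
   |     print(search.best_params_)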
85 |
86 | Note that this package does not deal with things like delayed dependencies and
87 | distributed architectures - consider an established workflow orchestration solution
88 | for such use cases. ``skdag`` is just for building and executing local ensembles
89 | from estimators.
90 |
91 | `Read on <https://skdag.readthedocs.io/>`_ to learn more about ``skdag``...
92 |
93 | .. _scikit-learn: https://scikit-learn.org
94 |
--------------------------------------------------------------------------------
/appveyor.yml:
--------------------------------------------------------------------------------
1 | build: false
2 | 
3 | environment:
4 |   matrix:
5 |     - APPVEYOR_BUILD_WORKER_IMAGE: Ubuntu
6 |       APPVEYOR_YML_DISABLE_PS_LINUX: true
7 | 
8 | stack: python 3.8
9 | 
10 | install: |
11 |   if [[ "${APPVEYOR_BUILD_WORKER_IMAGE}" == "Ubuntu" ]]; then
12 |     sudo apt update
13 |     sudo apt install -y graphviz libgraphviz-dev
14 |   elif [[ "${APPVEYOR_BUILD_WORKER_IMAGE}" == "macOS" ]]; then
15 |     brew update
16 |     brew install graphviz
17 |   fi
18 |   pip install --upgrade pip
19 |   for f in $(find . -maxdepth 1 -name 'requirements*.txt'); do
20 |     pip install -r ${f}
21 |   done
22 |   pip install .
23 | 
24 | test_script:
25 |   - mkdir for_test
26 |   - cd for_test
27 |   - pytest -v --cov=skdag --pyargs skdag
28 | 
29 | after_test:
30 |   - cp .coverage ${APPVEYOR_BUILD_FOLDER}
31 |   - cd ${APPVEYOR_BUILD_FOLDER}
32 |   - curl -Os https://uploader.codecov.io/latest/linux/codecov
33 |   - chmod +x codecov
34 |   - ./codecov
35 |
--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line.
5 | SPHINXOPTS  =
6 | SPHINXBUILD = sphinx-build
7 | PAPER       =
8 | BUILDDIR    = _build
9 | 
10 | # User-friendly check for sphinx-build
11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
13 | endif
14 | 
15 | # Internal variables.
16 | PAPEROPT_a4     = -D latex_paper_size=a4
17 | PAPEROPT_letter = -D latex_paper_size=letter
18 | ALLSPHINXOPTS   = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
19 | # the i18n builder cannot share the environment and doctrees with the others
20 | I18NSPHINXOPTS  = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
21 | 
22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
23 | 
24 | help:
25 | 	@echo "Please use \`make <target>' where <target> is one of"
26 | 	@echo "  html       to make standalone HTML files"
27 | 	@echo "  dirhtml    to make HTML files named index.html in directories"
28 | 	@echo "  singlehtml to make a single large HTML file"
29 | 	@echo "  pickle     to make pickle files"
30 | 	@echo "  json       to make JSON files"
31 | 	@echo "  htmlhelp   to make HTML files and a HTML help project"
32 | 	@echo "  qthelp     to make HTML files and a qthelp project"
33 | 	@echo "  devhelp    to make HTML files and a Devhelp project"
34 | 	@echo "  epub       to make an epub"
35 | 	@echo "  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
36 | 	@echo "  latexpdf   to make LaTeX files and run them through pdflatex"
37 | 	@echo "  latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
38 | 	@echo "  text       to make text files"
39 | 	@echo "  man        to make manual pages"
40 | 	@echo "  texinfo    to make Texinfo files"
41 | 	@echo "  info       to make Texinfo files and run them through makeinfo"
42 | 	@echo "  gettext    to make PO message catalogs"
43 | 	@echo "  changes    to make an overview of all changed/added/deprecated items"
44 | 	@echo "  xml        to make Docutils-native XML files"
45 | 	@echo "  pseudoxml  to make pseudoxml-XML files for display purposes"
46 | 	@echo "  linkcheck  to check all external links for integrity"
47 | 	@echo "  doctest    to run all doctests embedded in the documentation (if enabled)"
48 | 
49 | clean:
50 | 	-rm -rf $(BUILDDIR)/*
51 | 	-rm -rf auto_examples/
52 | 	-rm -rf generated/*
53 | 	-rm -rf modules/generated/*
54 | 
55 | html:
56 | 	# These two lines make the build a bit more lengthy, and the
57 | 	# embedding of images more robust
58 | 	rm -rf $(BUILDDIR)/html/_images
59 | 	#rm -rf _build/doctrees/
60 | 	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
61 | 	@echo
62 | 	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
63 | 
64 | dirhtml:
65 | 	$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
66 | 	@echo
67 | 	@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
68 | 
69 | singlehtml:
70 | 	$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
71 | 	@echo
72 | 	@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
73 | 
74 | pickle:
75 | 	$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
76 | 	@echo
77 | 	@echo "Build finished; now you can process the pickle files."
78 | 
79 | json:
80 | 	$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
81 | 	@echo
82 | 	@echo "Build finished; now you can process the JSON files."
83 | 
84 | htmlhelp:
85 | 	$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
86 | 	@echo
87 | 	@echo "Build finished; now you can run HTML Help Workshop with the" \
88 | 	      ".hhp project file in $(BUILDDIR)/htmlhelp."
89 | 
90 | qthelp:
91 | 	$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
92 | 	@echo
93 | 	@echo "Build finished; now you can run "qcollectiongenerator" with the" \
94 | 	      ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
95 | 	@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/project-template.qhcp"
96 | 	@echo "To view the help file:"
97 | 	@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/project-template.qhc"
98 | 
99 | devhelp:
100 | 	$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
101 | 	@echo
102 | 	@echo "Build finished."
103 | 	@echo "To view the help file:"
104 | 	@echo "# mkdir -p $$HOME/.local/share/devhelp/project-template"
105 | 	@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/project-template"
106 | 	@echo "# devhelp"
107 | 
108 | epub:
109 | 	$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
110 | 	@echo
111 | 	@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
112 | 
113 | latex:
114 | 	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
115 | 	@echo
116 | 	@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
117 | 	@echo "Run \`make' in that directory to run these through (pdf)latex" \
118 | 	      "(use \`make latexpdf' here to do that automatically)."
119 | 
120 | latexpdf:
121 | 	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
122 | 	@echo "Running LaTeX files through pdflatex..."
123 | 	$(MAKE) -C $(BUILDDIR)/latex all-pdf
124 | 	@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
125 | 
126 | latexpdfja:
127 | 	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
128 | 	@echo "Running LaTeX files through platex and dvipdfmx..."
129 | 	$(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
130 | 	@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
131 | 
132 | text:
133 | 	$(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
134 | 	@echo
135 | 	@echo "Build finished. The text files are in $(BUILDDIR)/text."
136 | 
137 | man:
138 | 	$(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
139 | 	@echo
140 | 	@echo "Build finished. The manual pages are in $(BUILDDIR)/man."
141 | 
142 | texinfo:
143 | 	$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
144 | 	@echo
145 | 	@echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
146 | 	@echo "Run \`make' in that directory to run these through makeinfo" \
147 | 	      "(use \`make info' here to do that automatically)."
148 | 
149 | info:
150 | 	$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
151 | 	@echo "Running Texinfo files through makeinfo..."
152 | 	make -C $(BUILDDIR)/texinfo info
153 | 	@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
154 | 
155 | gettext:
156 | 	$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
157 | 	@echo
158 | 	@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
159 | 
160 | changes:
161 | 	$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
162 | 	@echo
163 | 	@echo "The overview file is in $(BUILDDIR)/changes."
164 | 
165 | linkcheck:
166 | 	$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
167 | 	@echo
168 | 	@echo "Link check complete; look for any errors in the above output " \
169 | 	      "or in $(BUILDDIR)/linkcheck/output.txt."
170 | 
171 | doctest:
172 | 	$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
173 | 	@echo "Testing of doctests in the sources finished, look at the " \
174 | 	      "results in $(BUILDDIR)/doctest/output.txt."
175 | 
176 | xml:
177 | 	$(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
178 | 	@echo
179 | 	@echo "Build finished. The XML files are in $(BUILDDIR)/xml."
180 | 
181 | pseudoxml:
182 | 	$(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
183 | 	@echo
184 | 	@echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
185 |
--------------------------------------------------------------------------------
/doc/_static/css/project-template.css:
--------------------------------------------------------------------------------
1 | @import url("theme.css");
2 |
3 | .highlight a {
4 | text-decoration: underline;
5 | }
6 |
7 | .deprecated p {
8 | padding: 10px 7px 10px 10px;
9 | color: #b94a48;
10 | background-color: #F3E5E5;
11 | border: 1px solid #eed3d7;
12 | }
13 |
14 | .deprecated p span.versionmodified {
15 | font-weight: bold;
16 | }
17 |
18 | img.logo {
19 | max-height: 100px;
20 | }
21 |
--------------------------------------------------------------------------------
/doc/_static/img/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/cover.png
--------------------------------------------------------------------------------
/doc/_static/img/dag1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag1.png
--------------------------------------------------------------------------------
/doc/_static/img/dag2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag2.png
--------------------------------------------------------------------------------
/doc/_static/img/dag2a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag2a.png
--------------------------------------------------------------------------------
/doc/_static/img/dag3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag3.png
--------------------------------------------------------------------------------
/doc/_static/img/dag3a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/dag3a.png
--------------------------------------------------------------------------------
/doc/_static/img/skdag-banner.png:
--------------------------------------------------------------------------------
1 | ../../../img/skdag-banner.png
--------------------------------------------------------------------------------
/doc/_static/img/skdag-dark.png:
--------------------------------------------------------------------------------
1 | ../../../img/skdag-dark.png
--------------------------------------------------------------------------------
/doc/_static/img/stack.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/doc/_static/img/stack.png
--------------------------------------------------------------------------------
/doc/_static/js/copybutton.js:
--------------------------------------------------------------------------------
1 | $(document).ready(function() {
2 | /* Add a [>>>] button on the top-right corner of code samples to hide
3 | * the >>> and ... prompts and the output and thus make the code
4 | * copyable. */
5 | var div = $('.highlight-python .highlight,' +
6 | '.highlight-python3 .highlight,' +
7 | '.highlight-pycon .highlight,' +
8 | '.highlight-default .highlight')
9 | var pre = div.find('pre');
10 |
11 | // get the styles from the current theme
12 | pre.parent().parent().css('position', 'relative');
13 | var hide_text = 'Hide the prompts and output';
14 | var show_text = 'Show the prompts and output';
15 | var border_width = pre.css('border-top-width');
16 | var border_style = pre.css('border-top-style');
17 | var border_color = pre.css('border-top-color');
18 | var button_styles = {
19 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0',
20 | 'border-color': border_color, 'border-style': border_style,
21 | 'border-width': border_width, 'color': border_color, 'text-size': '75%',
22 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em',
23 | 'border-radius': '0 3px 0 0'
24 | }
25 |
26 | // create and add the button to all the code blocks that contain >>>
27 | div.each(function(index) {
28 | var jthis = $(this);
29 | if (jthis.find('.gp').length > 0) {
30 | var button = $('<span class="copybutton">&gt;&gt;&gt;</span>');
31 | button.css(button_styles)
32 | button.attr('title', hide_text);
33 | button.data('hidden', 'false');
34 | jthis.prepend(button);
35 | }
36 | // tracebacks (.gt) contain bare text elements that need to be
37 | // wrapped in a span to work with .nextUntil() (see later)
38 | jthis.find('pre:has(.gt)').contents().filter(function() {
39 | return ((this.nodeType == 3) && (this.data.trim().length > 0));
40 | }).wrap('<span>');
41 | });
42 |
43 | // define the behavior of the button when it's clicked
44 | $('.copybutton').click(function(e){
45 | e.preventDefault();
46 | var button = $(this);
47 | if (button.data('hidden') === 'false') {
48 | // hide the code output
49 | button.parent().find('.go, .gp, .gt').hide();
50 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden');
51 | button.css('text-decoration', 'line-through');
52 | button.attr('title', show_text);
53 | button.data('hidden', 'true');
54 | } else {
55 | // show the code output
56 | button.parent().find('.go, .gp, .gt').show();
57 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible');
58 | button.css('text-decoration', 'none');
59 | button.attr('title', hide_text);
60 | button.data('hidden', 'false');
61 | }
62 | });
63 | });
64 |
--------------------------------------------------------------------------------
/doc/api.rst:
--------------------------------------------------------------------------------
1 | #########
2 | skdag API
3 | #########
4 |
5 | See below for detailed descriptions of the ``skdag`` interface.
6 |
7 | .. currentmodule:: skdag
8 |
9 | Estimator
10 | =========
11 |
12 | .. autosummary::
13 | :toctree: generated/
14 | :template: class.rst
15 |
16 | DAG
17 |
18 | Exceptions
19 | ==========
20 |
21 | .. autosummary::
22 | :toctree: generated/
23 | :template: class.rst
24 |
25 | exceptions.DAGError
26 |
27 | Utilities
28 | =========
29 |
30 | .. autosummary::
31 | :toctree: generated/
32 | :template: class.rst
33 |
34 | DAGBuilder
--------------------------------------------------------------------------------
/doc/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # project-template documentation build configuration file, created by
4 | # sphinx-quickstart on Mon Jan 18 14:44:12 2016.
5 | #
6 | # This file is execfile()d with the current directory set to its
7 | # containing dir.
8 | #
9 | # Note that not all possible configuration values are present in this
10 | # autogenerated file.
11 | #
12 | # All configuration values have a default; values that are commented out
13 | # serve to show the default.
14 |
15 | import sys
16 | import os
17 |
18 | import sphinx_gallery
19 | import sphinx_rtd_theme
20 |
21 | # Add to sys.path the top-level directory where the package is located.
22 | sys.path.insert(0, os.path.abspath(".."))
23 |
24 | # If extensions (or modules to document with autodoc) are in another directory,
25 | # add these directories to sys.path here. If the directory is relative to the
26 | # documentation root, use os.path.abspath to make it absolute, like shown here.
27 | # sys.path.insert(0, os.path.abspath('.'))
28 |
29 | # -- General configuration ------------------------------------------------
30 |
31 | # If your documentation needs a minimal Sphinx version, state it here.
32 | # needs_sphinx = '1.0'
33 |
34 | # Add any Sphinx extension module names here, as strings. They can be
35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
36 | # ones.
37 | extensions = [
38 | "sphinx.ext.autodoc",
39 | "sphinx.ext.autosummary",
40 | "sphinx.ext.doctest",
41 | "sphinx.ext.intersphinx",
42 | "sphinx.ext.viewcode",
43 | "numpydoc",
44 | "sphinx_gallery.gen_gallery",
45 | ]
46 |
47 | # this is needed for some reason...
48 | # see https://github.com/numpy/numpydoc/issues/69
49 | numpydoc_show_class_members = False
50 |
51 | # pngmath / imgmath compatibility layer for different sphinx versions
52 | import sphinx
53 | from distutils.version import LooseVersion
54 |
55 | if LooseVersion(sphinx.__version__) < LooseVersion("1.4"):
56 |     extensions.append("sphinx.ext.pngmath")
57 | else:
58 |     extensions.append("sphinx.ext.imgmath")
59 |
60 | autodoc_default_flags = ["members", "inherited-members"]
61 |
62 | # Add any paths that contain templates here, relative to this directory.
63 | templates_path = ["_templates"]
64 |
65 | # generate autosummary even if no references
66 | autosummary_generate = True
67 |
68 | # The suffix of source filenames.
69 | source_suffix = ".rst"
70 |
71 | # The encoding of source files.
72 | # source_encoding = 'utf-8-sig'
73 |
74 | # Generate the plots for the gallery
75 | plot_gallery = False
76 |
77 | # The master toctree document.
78 | master_doc = "index"
79 |
80 | # General information about the project.
81 | project = "skdag"
82 | copyright = "2022, big-o (github)"
83 |
84 | # The version info for the project you're documenting, acts as replacement for
85 | # |version| and |release|, also used in various other places throughout the
86 | # built documents.
87 | #
88 | # The short X.Y version.
89 | from skdag import __version__
90 |
91 | version = __version__
92 | # The full version, including alpha/beta/rc tags.
93 | release = __version__
94 |
95 | # The language for content autogenerated by Sphinx. Refer to documentation
96 | # for a list of supported languages.
97 | # language = None
98 |
99 | # There are two options for replacing |today|: either, you set today to some
100 | # non-false value, then it is used:
101 | # today = ''
102 | # Else, today_fmt is used as the format for a strftime call.
103 | # today_fmt = '%B %d, %Y'
104 |
105 | # List of patterns, relative to source directory, that match files and
106 | # directories to ignore when looking for source files.
107 | exclude_patterns = ["_build", "_templates"]
108 |
109 | # The reST default role (used for this markup: `text`) to use for all
110 | # documents.
111 | # default_role = None
112 |
113 | # If true, '()' will be appended to :func: etc. cross-reference text.
114 | # add_function_parentheses = True
115 |
116 | # If true, the current module name will be prepended to all description
117 | # unit titles (such as .. function::).
118 | # add_module_names = True
119 |
120 | # If true, sectionauthor and moduleauthor directives will be shown in the
121 | # output. They are ignored by default.
122 | # show_authors = False
123 |
124 | # The name of the Pygments (syntax highlighting) style to use.
125 | pygments_style = "sphinx"
126 |
127 | # Custom style
128 | html_style = "css/project-template.css"
129 |
130 | # A list of ignored prefixes for module index sorting.
131 | # modindex_common_prefix = []
132 |
133 | # If true, keep warnings as "system message" paragraphs in the built documents.
134 | # keep_warnings = False
135 |
136 |
137 | # -- Options for HTML output ----------------------------------------------
138 |
139 | # The theme to use for HTML and HTML Help pages. See the documentation for
140 | # a list of builtin themes.
141 | html_theme = "sphinx_rtd_theme"
142 |
143 | # Theme options are theme-specific and customize the look and feel of a theme
144 | # further. For a list of options available for each theme, see the
145 | # documentation.
146 | html_theme_options = {
147 | "logo_only": True,
148 | }
149 |
150 | # Add any paths that contain custom themes here, relative to this directory.
151 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
152 |
153 | # The name for this set of Sphinx documents. If None, it defaults to
154 | # " v documentation".
155 | # html_title = None
156 |
157 | # A shorter title for the navigation bar. Default is the same as html_title.
158 | # html_short_title = None
159 |
160 | # The name of an image file (relative to this directory) to place at the top
161 | # of the sidebar.
162 | html_logo = "_static/img/skdag-dark.png"
163 |
164 | # The name of an image file (within the static path) to use as favicon of the
165 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
166 | # pixels large.
167 | # html_favicon = None
168 |
169 | # Add any paths that contain custom static files (such as style sheets) here,
170 | # relative to this directory. They are copied after the builtin static files,
171 | # so a file named "default.css" will overwrite the builtin "default.css".
172 | html_static_path = ["_static"]
173 |
174 | # Add any extra paths that contain custom files (such as robots.txt or
175 | # .htaccess) here, relative to this directory. These files are copied
176 | # directly to the root of the documentation.
177 | # html_extra_path = []
178 |
179 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
180 | # using the given strftime format.
181 | # html_last_updated_fmt = '%b %d, %Y'
182 |
183 | # If true, SmartyPants will be used to convert quotes and dashes to
184 | # typographically correct entities.
185 | # html_use_smartypants = True
186 |
187 | # Custom sidebar templates, maps document names to template names.
188 | # html_sidebars = {}
189 |
190 | # Additional templates that should be rendered to pages, maps page names to
191 | # template names.
192 | # html_additional_pages = {}
193 |
194 | # If false, no module index is generated.
195 | # html_domain_indices = True
196 |
197 | # If false, no index is generated.
198 | # html_use_index = True
199 |
200 | # If true, the index is split into individual pages for each letter.
201 | # html_split_index = False
202 |
203 | # If true, links to the reST sources are added to the pages.
204 | # html_show_sourcelink = True
205 |
206 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
207 | # html_show_sphinx = True
208 |
209 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
210 | # html_show_copyright = True
211 |
212 | # If true, an OpenSearch description file will be output, and all pages will
213 | # contain a tag referring to it. The value of this option must be the
214 | # base URL from which the finished HTML is served.
215 | # html_use_opensearch = ''
216 |
217 | # This is the file name suffix for HTML files (e.g. ".xhtml").
218 | # html_file_suffix = None
219 |
220 | # Output file base name for HTML help builder.
221 | htmlhelp_basename = "project-templatedoc"
222 |
223 |
224 | # -- Options for LaTeX output ---------------------------------------------
225 |
226 | latex_elements = {
227 | # The paper size ('letterpaper' or 'a4paper').
228 | #'papersize': 'letterpaper',
229 | # The font size ('10pt', '11pt' or '12pt').
230 | #'pointsize': '10pt',
231 | # Additional stuff for the LaTeX preamble.
232 | #'preamble': '',
233 | }
234 |
235 | # Grouping the document tree into LaTeX files. List of tuples
236 | # (source start file, target name, title,
237 | # author, documentclass [howto, manual, or own class]).
238 | latex_documents = [
239 | (
240 | "index",
241 | "project-template.tex",
242 | "project-template Documentation",
243 | "big-o",
244 | "manual",
245 | ),
246 | ]
247 |
248 | # The name of an image file (relative to this directory) to place at the top of
249 | # the title page.
250 | # latex_logo = None
251 |
252 | # For "manual" documents, if this is true, then toplevel headings are parts,
253 | # not chapters.
254 | # latex_use_parts = False
255 |
256 | # If true, show page references after internal links.
257 | # latex_show_pagerefs = False
258 |
259 | # If true, show URL addresses after external links.
260 | # latex_show_urls = False
261 |
262 | # Documents to append as an appendix to all manuals.
263 | # latex_appendices = []
264 |
265 | # If false, no module index is generated.
266 | # latex_domain_indices = True
267 |
268 |
269 | # -- Options for manual page output ---------------------------------------
270 |
271 | # One entry per manual page. List of tuples
272 | # (source start file, name, description, authors, manual section).
273 | man_pages = [
274 | (
275 | "index",
276 | "project-template",
277 | "project-template Documentation",
278 | ["big-o"],
279 | 1,
280 | )
281 | ]
282 |
283 | # If true, show URL addresses after external links.
284 | # man_show_urls = False
285 |
286 |
287 | # -- Options for Texinfo output -------------------------------------------
288 |
289 | # Grouping the document tree into Texinfo files. List of tuples
290 | # (source start file, target name, title, author,
291 | # dir menu entry, description, category)
292 | texinfo_documents = [
293 | (
294 | "index",
295 | "project-template",
296 | "project-template Documentation",
297 | "big-o",
298 | "project-template",
299 | "One line description of project.",
300 | "Miscellaneous",
301 | ),
302 | ]
303 |
304 | # Documents to append as an appendix to all manuals.
305 | # texinfo_appendices = []
306 |
307 | # If false, no module index is generated.
308 | # texinfo_domain_indices = True
309 |
310 | # How to display URL addresses: 'footnote', 'no', or 'inline'.
311 | # texinfo_show_urls = 'footnote'
312 |
313 | # If true, do not generate a @detailmenu in the "Top" node's menu.
314 | # texinfo_no_detailmenu = False
315 |
316 |
317 | # Example configuration for intersphinx: refer to the Python standard library.
318 | # intersphinx configuration
319 | intersphinx_mapping = {
320 | "python": ("https://docs.python.org/{.major}".format(sys.version_info), None),
321 | "numpy": ("https://numpy.org/doc/stable", None),
322 | "scipy": ("https://docs.scipy.org/doc/scipy", None),
323 | "matplotlib": ("https://matplotlib.org/stable", None),
324 | "sklearn": ("https://scikit-learn.org/stable", None),
325 | }
326 |
327 | # sphinx-gallery configuration
328 | sphinx_gallery_conf = {
329 | "doc_module": "skdag",
330 | "backreferences_dir": os.path.join("generated"),
331 | "reference_url": {"skdag": None},
332 | }
333 |
334 |
335 | def setup(app):
336 |     # a copy button to copy snippet of code from the documentation
337 |     app.add_js_file("js/copybutton.js")
338 |
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
1 | skdag - scikit-learn workflow management
2 | ============================================
3 |
4 | scikit-dag (``skdag``) is an open-source, MIT-licensed library that provides advanced
5 | workflow management to any machine learning operations that follow
6 | :mod:`sklearn` conventions. It does this by introducing directed acyclic
7 | graphs (:class:`skdag.dag.DAG`) as a drop-in replacement for the traditional
8 | scikit-learn :class:`~sklearn.pipeline.Pipeline`. This gives you a simple interface
9 | for a range of use cases including complex pre-processing, model stacking and benchmarking.
10 |
11 | .. code-block:: python
12 |
13 | from skdag import DAGBuilder
   | from sklearn.decomposition import PCA
   | from sklearn.ensemble import RandomForestRegressor
   | from sklearn.impute import SimpleImputer
   | from sklearn.linear_model import LinearRegression
   | from sklearn.neighbors import KNeighborsRegressor
   | from sklearn.svm import SVR
14 |
15 | dag = (
16 | DAGBuilder(infer_dataframe=True)
17 | .add_step("impute", SimpleImputer())
18 | .add_step(
19 | "vitals",
20 | "passthrough",
21 | deps={"impute": ["age", "sex", "bmi", "bp"]},
22 | )
23 | .add_step(
24 | "blood",
25 | PCA(n_components=2, random_state=0),
26 | deps={"impute": ["s1", "s2", "s3", "s4", "s5", "s6"]},
27 | )
28 | .add_step(
29 | "rf",
30 | RandomForestRegressor(max_depth=5, random_state=0),
31 | deps=["blood", "vitals"],
32 | )
33 | .add_step("svm", SVR(C=0.7), deps=["blood", "vitals"])
34 | .add_step(
35 | "knn",
36 | KNeighborsRegressor(n_neighbors=5),
37 | deps=["blood", "vitals"],
38 | )
39 | .add_step("meta", LinearRegression(), deps=["rf", "svm", "knn"])
40 | .make_dag(n_jobs=2, verbose=True)
41 | )
42 |
43 | dag.show(detailed=True)
44 |
45 | .. image:: _static/img/cover.png
46 |
47 | The above DAG imputes missing values, runs PCA on the columns relating to blood test
48 | results and leaves the other columns as they are. The results are then passed to three
49 | different regressors before being passed on to a final meta-estimator. Because DAGs
50 | (unlike pipelines) allow predictors in the middle of a workflow, you can use them to
51 | implement model stacking. We also chose to run the DAG steps in parallel wherever
52 | possible.
53 |
54 | After building our DAG, we can treat it as any other estimator:
55 |
56 | .. code-block:: python
57 |
58 | from sklearn import datasets
   | from sklearn.model_selection import train_test_split
59 |
60 | X, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
61 | X_train, X_test, y_train, y_test = train_test_split(
62 | X, y, test_size=0.2, random_state=0
63 | )
64 |
65 | dag.fit(X_train, y_train)
66 | dag.predict(X_test)
67 |
68 | Just like a pipeline, you can optimise it with a grid search, pickle it, etc.
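   | 
   | For example, a fitted DAG can be serialised and restored with the standard library
   | (a minimal sketch, assuming ``dag`` has been fitted as above):
   | 
   | .. code-block:: python
   | 
   |     import pickle
   | 
   |     # Serialise the fitted DAG just like any other scikit-learn estimator.
   |     blob = pickle.dumps(dag)
   |     restored = pickle.loads(blob)
   |     restored.predict(X_test)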
69 |
70 | Note that this package does not deal with things like delayed dependencies and
71 | distributed architectures - consider an established workflow orchestration solution
72 | for such use cases. ``skdag`` is just for building and executing local ensembles
73 | from estimators.
74 |
75 | :ref:`Read on <quickstart>` to learn more about ``skdag``...
76 |
77 | .. toctree::
78 | :maxdepth: 2
79 | :hidden:
80 | :caption: Getting Started
81 |
82 | quick_start
83 |
84 | .. toctree::
85 | :maxdepth: 2
86 | :hidden:
87 | :caption: Documentation
88 |
89 | user_guide
90 |
91 | .. toctree::
92 | :maxdepth: 2
93 | :hidden:
94 | :caption: API
95 |
96 | api
97 |
98 | .. toctree::
99 | :maxdepth: 2
100 | :hidden:
101 | :caption: Tutorial - Examples
102 |
103 | auto_examples/index
104 |
105 | `Getting started <quick_start.html>`_
106 | -------------------------------------
107 |
108 | A practical introduction to DAGs for scikit-learn.
109 |
110 | `User Guide <user_guide.html>`_
111 | -------------------------------
112 |
113 | Details of the full functionality provided by ``skdag``.
114 |
115 | `API Documentation <api.html>`_
116 | -------------------------------
117 |
118 | Detailed API documentation.
119 |
120 | `Examples <auto_examples/index.html>`_
121 | --------------------------------------
122 |
123 | Further examples that complement the `User Guide <user_guide.html>`_.
124 |
--------------------------------------------------------------------------------
/doc/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | REM Command file for Sphinx documentation
4 |
5 | if "%SPHINXBUILD%" == "" (
6 | set SPHINXBUILD=sphinx-build
7 | )
8 | set BUILDDIR=_build
9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
10 | set I18NSPHINXOPTS=%SPHINXOPTS% .
11 | if NOT "%PAPER%" == "" (
12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
14 | )
15 |
16 | if "%1" == "" goto help
17 |
18 | if "%1" == "help" (
19 | :help
20 | echo.Please use `make ^<target^>` where ^<target^> is one of
21 | echo. html to make standalone HTML files
22 | echo. dirhtml to make HTML files named index.html in directories
23 | echo. singlehtml to make a single large HTML file
24 | echo. pickle to make pickle files
25 | echo. json to make JSON files
26 | echo. htmlhelp to make HTML files and a HTML help project
27 | echo. qthelp to make HTML files and a qthelp project
28 | echo. devhelp to make HTML files and a Devhelp project
29 | echo. epub to make an epub
30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
31 | echo. text to make text files
32 | echo. man to make manual pages
33 | echo. texinfo to make Texinfo files
34 | echo. gettext to make PO message catalogs
35 | echo. changes to make an overview over all changed/added/deprecated items
36 | echo. xml to make Docutils-native XML files
37 | echo. pseudoxml to make pseudoxml-XML files for display purposes
38 | echo. linkcheck to check all external links for integrity
39 | echo. doctest to run all doctests embedded in the documentation if enabled
40 | goto end
41 | )
42 |
43 | if "%1" == "clean" (
44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
45 | del /q /s %BUILDDIR%\*
46 | goto end
47 | )
48 |
49 |
50 | %SPHINXBUILD% 2> nul
51 | if errorlevel 9009 (
52 | echo.
53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
54 | echo.installed, then set the SPHINXBUILD environment variable to point
55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
56 | echo.may add the Sphinx directory to PATH.
57 | echo.
58 | echo.If you don't have Sphinx installed, grab it from
59 | echo.http://sphinx-doc.org/
60 | exit /b 1
61 | )
62 |
63 | if "%1" == "html" (
64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
65 | if errorlevel 1 exit /b 1
66 | echo.
67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html.
68 | goto end
69 | )
70 |
71 | if "%1" == "dirhtml" (
72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
73 | if errorlevel 1 exit /b 1
74 | echo.
75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
76 | goto end
77 | )
78 |
79 | if "%1" == "singlehtml" (
80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
81 | if errorlevel 1 exit /b 1
82 | echo.
83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
84 | goto end
85 | )
86 |
87 | if "%1" == "pickle" (
88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
89 | if errorlevel 1 exit /b 1
90 | echo.
91 | echo.Build finished; now you can process the pickle files.
92 | goto end
93 | )
94 |
95 | if "%1" == "json" (
96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
97 | if errorlevel 1 exit /b 1
98 | echo.
99 | echo.Build finished; now you can process the JSON files.
100 | goto end
101 | )
102 |
103 | if "%1" == "htmlhelp" (
104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
105 | if errorlevel 1 exit /b 1
106 | echo.
107 | echo.Build finished; now you can run HTML Help Workshop with the ^
108 | .hhp project file in %BUILDDIR%/htmlhelp.
109 | goto end
110 | )
111 |
112 | if "%1" == "qthelp" (
113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
114 | if errorlevel 1 exit /b 1
115 | echo.
116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^
117 | .qhcp project file in %BUILDDIR%/qthelp, like this:
118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\project-template.qhcp
119 | echo.To view the help file:
120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\project-template.qhc
121 | goto end
122 | )
123 |
124 | if "%1" == "devhelp" (
125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
126 | if errorlevel 1 exit /b 1
127 | echo.
128 | echo.Build finished.
129 | goto end
130 | )
131 |
132 | if "%1" == "epub" (
133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
134 | if errorlevel 1 exit /b 1
135 | echo.
136 | echo.Build finished. The epub file is in %BUILDDIR%/epub.
137 | goto end
138 | )
139 |
140 | if "%1" == "latex" (
141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
142 | if errorlevel 1 exit /b 1
143 | echo.
144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
145 | goto end
146 | )
147 |
148 | if "%1" == "latexpdf" (
149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
150 | cd %BUILDDIR%/latex
151 | make all-pdf
152 | cd %BUILDDIR%/..
153 | echo.
154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex.
155 | goto end
156 | )
157 |
158 | if "%1" == "latexpdfja" (
159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
160 | cd %BUILDDIR%/latex
161 | make all-pdf-ja
162 | cd %BUILDDIR%/..
163 | echo.
164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex.
165 | goto end
166 | )
167 |
168 | if "%1" == "text" (
169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
170 | if errorlevel 1 exit /b 1
171 | echo.
172 | echo.Build finished. The text files are in %BUILDDIR%/text.
173 | goto end
174 | )
175 |
176 | if "%1" == "man" (
177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
178 | if errorlevel 1 exit /b 1
179 | echo.
180 | echo.Build finished. The manual pages are in %BUILDDIR%/man.
181 | goto end
182 | )
183 |
184 | if "%1" == "texinfo" (
185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
186 | if errorlevel 1 exit /b 1
187 | echo.
188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
189 | goto end
190 | )
191 |
192 | if "%1" == "gettext" (
193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
194 | if errorlevel 1 exit /b 1
195 | echo.
196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
197 | goto end
198 | )
199 |
200 | if "%1" == "changes" (
201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
202 | if errorlevel 1 exit /b 1
203 | echo.
204 | echo.The overview file is in %BUILDDIR%/changes.
205 | goto end
206 | )
207 |
208 | if "%1" == "linkcheck" (
209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
210 | if errorlevel 1 exit /b 1
211 | echo.
212 | echo.Link check complete; look for any errors in the above output ^
213 | or in %BUILDDIR%/linkcheck/output.txt.
214 | goto end
215 | )
216 |
217 | if "%1" == "doctest" (
218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
219 | if errorlevel 1 exit /b 1
220 | echo.
221 | echo.Testing of doctests in the sources finished, look at the ^
222 | results in %BUILDDIR%/doctest/output.txt.
223 | goto end
224 | )
225 |
226 | if "%1" == "xml" (
227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
228 | if errorlevel 1 exit /b 1
229 | echo.
230 | echo.Build finished. The XML files are in %BUILDDIR%/xml.
231 | goto end
232 | )
233 |
234 | if "%1" == "pseudoxml" (
235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
236 | if errorlevel 1 exit /b 1
237 | echo.
238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
239 | goto end
240 | )
241 |
242 | :end
243 |
--------------------------------------------------------------------------------
/doc/quick_start.rst:
--------------------------------------------------------------------------------
1 | .. _quickstart:
2 |
3 | ######################
4 | Quick Start with skdag
5 | ######################
6 |
7 | The following tutorial shows you how to write some simple directed acyclic graphs (DAGs)
8 | with ``skdag``.
9 |
10 | Installation
11 | ============
12 |
13 | Installing skdag is simple:
14 |
15 | .. code:: bash
16 |
17 | pip install skdag
18 |
19 | Note that to visualise graphs you need to install the graphviz libraries too. See the
20 | `pygraphviz documentation <https://pygraphviz.github.io/>`_ for installation guidance.
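   | 
   | For example, on a Debian-based system (this mirrors the project's own CI setup;
   | package names may differ on other platforms):
   | 
   | .. code:: bash
   | 
   |     sudo apt install graphviz libgraphviz-dev
   |     pip install pygraphviz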
21 |
22 | Creating a DAG
23 | ==============
24 |
25 | The simplest DAGs are just a chain of singular dependencies. These DAGs may be
26 | created with the :meth:`skdag.dag.DAG.from_pipeline` method in the same way as a
27 | regular :class:`~sklearn.pipeline.Pipeline`:
28 |
29 | .. code-block:: python
30 |
31 | >>> from skdag import DAGBuilder
32 | >>> from sklearn.decomposition import PCA
33 | >>> from sklearn.impute import SimpleImputer
34 | >>> from sklearn.linear_model import LogisticRegression
35 | >>> dag = DAGBuilder().from_pipeline(
36 | ... steps=[
37 | ... ("impute", SimpleImputer()),
38 | ... ("pca", PCA()),
39 | ... ("lr", LogisticRegression())
40 | ... ]
41 | ... ).make_dag()
42 | >>> dag.show()
43 | o impute
44 | |
45 | o pca
46 | |
47 | o lr
48 |
49 |
50 | .. image:: _static/img/dag1.png
51 |
52 | For more complex DAGs, it is recommended to use a :class:`skdag.dag.DAGBuilder`,
53 | which allows you to define the graph by specifying the dependencies of each new
54 | estimator:
55 |
56 | .. code-block:: python
57 |
58 | >>> dag = (
59 | ... DAGBuilder(infer_dataframe=True)
60 | ... .add_step("impute", SimpleImputer())
61 | ... .add_step("vitals", "passthrough", deps={"impute": ["age", "sex", "bmi", "bp"]})
62 | ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)})
63 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"])
64 | ... .make_dag()
65 | ... )
66 | >>> dag.show()
67 | o impute
68 | |\
69 | o o blood,vitals
70 | |/
71 | o lr
72 |
73 |
74 | .. image:: _static/img/dag2a.png
75 |
76 | In the above examples we pass the first four columns directly to the classifier, but
77 | the remaining columns have dimensionality reduction applied first before being
78 | passed to the same classifier as extra input columns.
79 |
80 | In this DAG, the ``deps`` option controls not only which estimators feed in to other
81 | estimators, but also which columns are used (and ignored) by each step. For more
82 | detail on how to control this behaviour, see the `User Guide <user_guide.html>`_.
83 |
84 | The DAG may now be used as an estimator in its own right:
85 |
86 | >>> from sklearn import datasets
87 | >>> X, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
88 | >>> type(dag.fit_predict(X, y))
89 | <class 'pandas.core.series.Series'>
90 |
91 | In an extension to the scikit-learn estimator interface, DAGs also support multiple
92 | inputs and multiple outputs. Let's say we want to compare two different classifiers:
93 |
94 | >>> from sklearn.ensemble import RandomForestClassifier
95 | >>> cal = DAGBuilder(infer_dataframe=True).from_pipeline(
96 | ... [("rf", RandomForestClassifier(random_state=0))]
97 | ... ).make_dag()
98 | >>> dag2 = dag.join(cal, edges=[("blood", "rf"), ("vitals", "rf")])
99 | >>> dag2.show()
100 | o impute
101 | |\
102 | o o blood,vitals
103 | |x|
104 | o o lr,rf
105 |
106 |
107 | .. image:: _static/img/dag3a.png
108 |
109 | Now our DAG will return two outputs: one from each classifier. Multiple outputs are
110 | returned as a :class:`sklearn.utils.Bunch`:
111 |
112 | >>> y_pred = dag2.fit_predict(X, y)
113 | >>> type(y_pred.lr)
114 | <class 'pandas.core.series.Series'>
115 | >>> type(y_pred.rf)
116 | <class 'pandas.core.series.Series'>
117 |
118 | Similarly, multiple inputs are also acceptable; they can be provided by
119 | specifying ``X`` and ``y`` as ``dict``-like objects, as sketched below.
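   | 
   | As a purely illustrative sketch (the step names here are hypothetical, not part of
   | the examples above), a DAG with two root steps could be fitted by keying each input
   | by the name of the step that should receive it:
   | 
   | .. code-block:: python
   | 
   |     # Hypothetical DAG with two entry points named "texts" and "numbers";
   |     # each root step receives its own input.
   |     X_multi = {"texts": X_text, "numbers": X_num}
   |     multi_dag.fit(X_multi, y)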
--------------------------------------------------------------------------------
/doc/user_guide.rst:
--------------------------------------------------------------------------------
1 | .. title:: User guide : contents
2 |
3 | .. _user_guide:
4 |
5 | ########################
6 | Composing Estimator DAGs
7 | ########################
8 |
9 | The following tutorial shows you how to write some simple directed acyclic graphs (DAGs)
10 | with ``skdag``.
11 |
12 | Creating your first DAG
13 | =======================
14 |
15 | The simplest DAGs are just a chain of singular dependencies, which is equivalent to a
16 | scikit-learn :class:`~sklearn.pipeline.Pipeline`. These DAGs may be created with the
17 | :meth:`~skdag.dag.DAG.from_pipeline` method in the same way as a pipeline:
18 |
19 | .. code-block:: python
20 |
21 | >>> from skdag import DAGBuilder
22 | >>> from sklearn.decomposition import PCA
23 | >>> from sklearn.impute import SimpleImputer
24 | >>> from sklearn.linear_model import LogisticRegression
25 | >>> dag = DAGBuilder(infer_dataframe=True).from_pipeline(
26 | ... steps=[
27 | ... ("impute", SimpleImputer()),
28 | ... ("pca", PCA()),
29 | ... ("lr", LogisticRegression())
30 | ... ]
31 | ... ).make_dag()
32 |
33 | You may view a diagram of the DAG with the :meth:`~skdag.dag.DAG.show` method. In a
34 | notebook environment this will display an image, whereas in a terminal it will generate
35 | ASCII text:
36 |
37 | .. code-block:: python
38 |
39 | >>> dag.show()
40 | o impute
41 | |
42 | o pca
43 | |
44 | o lr
45 |
46 | .. image:: _static/img/dag1.png
47 |
48 | Note that we also provided an extra option, ``infer_dataframe``. This is entirely
49 | optional, but if set the DAG will ensure that dataframe inputs have column and index
50 | information preserved (or inferred), and the output of the pipeline will also be a
51 | dataframe. This is useful if you wish to filter down the inputs for one particular step
52 | to only include certain columns; something we shall see in action later.
53 |
54 | For more complex DAGs, it is recommended to use a :class:`skdag.dag.DAGBuilder`,
55 | which allows you to define the graph by specifying the dependencies of each new
56 | estimator:
57 |
58 | .. code-block:: python
59 |
60 | >>> from skdag import DAGBuilder
61 | >>> from sklearn.compose import make_column_selector
62 | >>> dag = (
63 | ... DAGBuilder(infer_dataframe=True)
64 | ... .add_step("impute", SimpleImputer())
65 | ... .add_step("vitals", "passthrough", deps={"impute": ["age", "sex", "bmi", "bp"]})
66 | ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": make_column_selector("s[0-9]+")})
67 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"])
68 | ... .make_dag()
69 | ... )
70 | >>> dag.show()
71 | o impute
72 | |\
73 | o o blood,vitals
74 | |/
75 | o lr
76 |
77 | .. image:: _static/img/dag2.png
78 |
79 | In the above examples we pass the first four columns directly to the classifier, but
80 | the remaining columns have dimensionality reduction applied first before being
81 | passed to the same classifier. Note that we can define our graph edges in two
82 | different ways: as a dict (if we need to select only certain columns from the source
83 | node) or as a simple list (if we want to simply grab all columns from all input
84 | nodes). Columns may be specified as any kind of iterable (list, slice etc.) or a column
85 | selector function that conforms to :func:`sklearn.compose.make_column_selector`.
86 |
87 | If you wish to specify string column names for dependencies, ensure you provide the
88 | ``infer_dataframe=True`` option when you create a dag. This will ensure that all
89 | estimator outputs are coerced into dataframes. Where possible column names will be
90 | inferred, otherwise the column names will just be the name of the estimator step with an
91 | appended index number. If you do not specify ``infer_dataframe=True``, the dag will
92 | leave the outputs unmodified, which in most cases will mean numpy arrays that only
93 | support numeric column indices.
94 |
95 | The DAG may now be used as an estimator in its own right:
96 |
97 | .. code-block:: python
98 |
99 | >>> from sklearn import datasets
100 | >>> X, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
101 | >>> y_hat = dag.fit_predict(X, y)
102 | >>> type(y_hat)
103 | <class 'pandas.core.series.Series'>
104 |
105 | In an extension to the scikit-learn estimator interface, DAGs also support multiple
106 | inputs and multiple outputs. Let's say we want to compare two different classifiers:
107 |
108 | .. code-block:: python
109 |
110 | >>> from sklearn.ensemble import RandomForestClassifier
111 | >>> rf = DAGBuilder().from_pipeline(
112 | ... [("rf", RandomForestClassifier(random_state=0))]
113 | ... ).make_dag()
114 | >>> dag2 = dag.join(rf, edges=[("blood", "rf"), ("vitals", "rf")])
115 | >>> dag2.show()
116 | o impute
117 | |\
118 | o o blood,vitals
119 | |x|
120 | o o lr,rf
121 |
122 | .. image:: _static/img/dag3.png
123 |
124 | Now our DAG will return two outputs: one from each classifier. Multiple outputs are
125 | returned as a :class:`sklearn.utils.Bunch`:
126 |
127 | .. code-block:: python
128 |
129 | >>> y_pred = dag2.fit_predict(X, y)
130 | >>> type(y_pred.lr)
131 | <class 'pandas.core.series.Series'>
132 | >>> type(y_pred.rf)
133 | <class 'numpy.ndarray'>
134 |
135 | Note that we get different types of output here because our original DAG was created
136 | with ``infer_dataframe=True``, whereas our ``rf`` extension was not. We could fix this
137 | by specifying ``infer_dataframe=True`` when we create the ``rf`` DAG too.
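
A sketch of that fix (the same extension, rebuilt with dataframe inference enabled):

.. code-block:: python

    >>> rf = DAGBuilder(infer_dataframe=True).from_pipeline(
    ...     [("rf", RandomForestClassifier(random_state=0))]
    ... ).make_dag()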
138 |
139 | Similarly, multiple inputs are also acceptable; these may be provided by specifying
140 | ``X`` and ``y`` as ``dict``-like objects with one key per entry point.
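
As a minimal sketch (an illustrative variation on the earlier example, not part of
it), a DAG with two entry points takes one input per root, keyed by step name, and
with dict input the outputs are likewise returned as a :class:`sklearn.utils.Bunch`
keyed by leaf name:

.. code-block:: python

    from sklearn.datasets import load_diabetes
    from sklearn.decomposition import PCA
    from sklearn.impute import SimpleImputer
    from sklearn.linear_model import LinearRegression
    from skdag import DAGBuilder

    X, y = load_diabetes(return_X_y=True, as_frame=True)

    two_root_dag = (
        DAGBuilder(infer_dataframe=True)
        .add_step("vitals", SimpleImputer())                     # first entry point
        .add_step("blood", PCA(n_components=2, random_state=0))  # second entry point
        .add_step("reg", LinearRegression(), deps=["vitals", "blood"])
        .make_dag()
    )

    # One key per entry node; the keys must match the root step names.
    Xin = {
        "vitals": X[["age", "sex", "bmi", "bp"]],
        "blood": X[["s1", "s2", "s3", "s4", "s5", "s6"]],
    }
    y_hat = two_root_dag.fit_predict(Xin, y)
    print(type(y_hat.reg))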
141 |
142 | ########
143 | Stacking
144 | ########
145 |
146 | Unlike Pipelines, DAGs do not require that only the final step be a predictor; predictors
147 | may also appear in intermediate steps. This allows DAGs to be used for model stacking.
148 |
149 | Stacking is an ensemble method, like bagging or boosting, that allows multiple models
150 | to be combined into a single, more robust estimator. In stacking, predictions from
151 | multiple models are passed to a final `meta-estimator`: a simple model that combines the
152 | previous predictions into a final output. Like other ensemble methods, stacking can help
153 | to improve the performance and robustness of individual models.
154 |
155 | ``skdag`` implements stacking in a simple way. If an estimator without a ``transform()``
156 | method is placed in a non-leaf step of the DAG, then the output of
157 | :meth:`predict_proba`, :meth:`decision_function` or :meth:`predict` will be passed to
158 | the next step(s).
159 |
160 | .. code-block:: python
161 |
162 | >>> from sklearn import datasets
163 | >>> from sklearn.linear_model import LinearRegression
164 | >>> from sklearn.model_selection import train_test_split
165 | >>> from sklearn.neighbors import KNeighborsRegressor
166 | >>> from sklearn.svm import SVR
167 | >>> X, y = datasets.load_diabetes(return_X_y=True)
168 | >>> X_train, X_test, y_train, y_test = train_test_split(
169 | ... X, y, test_size=0.2, random_state=0
170 | ... )
171 | >>> knn = KNeighborsRegressor(3)
172 | >>> svr = SVR(C=1.0)
173 | >>> stack = (
174 | ... DAGBuilder()
175 | ... .add_step("pass", "passthrough")
176 | ... .add_step("knn", knn, deps=["pass"])
177 | ... .add_step("svr", svr, deps=["pass"])
178 | ... .add_step("meta", LinearRegression(), deps=["knn", "svr"])
179 | ... .make_dag()
180 | ... )
181 | >>> stack.fit(X_train, y_train)
182 | DAG(...
183 |
184 | .. image:: _static/img/stack.png
185 |
186 | Note that the passthrough is not strictly necessary but it is convenient as it ensures
187 | the stack has a single entry point, which makes it simpler to use.
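
To see why, here is a sketch of the alternative (reusing ``knn``, ``svr`` and the
training split from above): without the passthrough, both models become entry
points, so every call must supply a dict keyed by their step names.

.. code-block:: python

    >>> stack2 = (
    ...     DAGBuilder()
    ...     .add_step("knn", knn)
    ...     .add_step("svr", svr)
    ...     .add_step("meta", LinearRegression(), deps=["knn", "svr"])
    ...     .make_dag()
    ... )
    >>> stack2.fit({"knn": X_train, "svr": X_train}, y_train)
    DAG(...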
188 |
189 | The DAG infers that :meth:`predict` should be called for the two intermediate
190 | estimators. Our meta-estimator then simply takes in the predictions from each
191 | regressor as its input features.
192 |
193 | As we can now see, the stacking ensemble method gives us a boost in performance:
194 |
195 | .. code-block:: python
196 |
197 | >>> stack.score(X_test, y_test)
198 | 0.145...
199 | >>> knn.score(X_test, y_test)
200 | 0.138...
201 | >>> svr.score(X_test, y_test)
202 | 0.128...
203 |
204 | Note that for binary classifiers you probably need to specify that only the positive
205 | class probability is used as input by the meta-estimator. The DAG will automatically
206 | infer that :meth:`predict_proba` should be called, but you will need to manually tell
207 | the DAG which column to take. To do this, simply specify your step dependencies as a
208 | dictionary mapping step names to column indices instead:
209 |
210 | .. code-block:: python
211 |
212 | >>> from sklearn.ensemble import RandomForestClassifier
213 | >>> from sklearn.svm import SVC
214 | >>> clf_stack = (
215 | ... DAGBuilder(infer_dataframe=True)
216 | ... .add_step("pass", "passthrough")
217 | ... .add_step("rf", RandomForestClassifier(), deps=["pass"])
218 | ...     .add_step("svc", SVC(probability=True), deps=["pass"])
219 | ...     .add_step("meta", LogisticRegression(), deps={"rf": 1, "svc": 1})
220 | ... .make_dag()
221 | ... )
222 |
223 | Stacking works best when a diverse range of algorithms are used to provide predictions,
224 | which are then fed into a very simple meta-estimator. To minimize overfitting,
225 | cross-validation should be considered when using stacking.
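
For example, since a DAG implements the standard estimator interface, the whole
stack can be passed to scikit-learn's model selection utilities as a single unit
(a sketch reusing the regression stack from above):

.. code-block:: python

    >>> from sklearn.model_selection import cross_val_score
    >>> scores = cross_val_score(stack, X, y, cv=5)

Each fold refits the entire DAG, giving a less optimistic (and more trustworthy)
estimate of the stack's performance than a single train/test split.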
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: project-template
2 | dependencies:
3 | - networkx>2.8
4 | - numpy
5 | - scipy
6 | - scikit-learn
7 |
--------------------------------------------------------------------------------
/examples/README.txt:
--------------------------------------------------------------------------------
1 | .. _general_examples:
2 |
3 | General examples
4 | ================
5 |
6 | Introductory examples.
7 |
--------------------------------------------------------------------------------
/img/skdag-banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-banner.png
--------------------------------------------------------------------------------
/img/skdag-dark-fill.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-dark-fill.png
--------------------------------------------------------------------------------
/img/skdag-dark.kra:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-dark.kra
--------------------------------------------------------------------------------
/img/skdag-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-dark.png
--------------------------------------------------------------------------------
/img/skdag-fill.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag-fill.png
--------------------------------------------------------------------------------
/img/skdag.kra:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag.kra
--------------------------------------------------------------------------------
/img/skdag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/img/skdag.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | black
2 | joblib
3 | networkx>=2.6
4 | numpy
5 | scipy
6 | scikit-learn
7 | stackeddag
8 |
--------------------------------------------------------------------------------
/requirements_doc.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpydoc
3 | sphinx
4 | sphinx-gallery
5 | sphinx_rtd_theme
6 |
--------------------------------------------------------------------------------
/requirements_full.txt:
--------------------------------------------------------------------------------
1 | pygraphviz
2 |
--------------------------------------------------------------------------------
/requirements_test.txt:
--------------------------------------------------------------------------------
1 | pandas
2 | pytest
3 | pytest-cov
4 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.rst
3 |
4 | [aliases]
5 | test = pytest
6 |
7 | [tool:pytest]
8 | doctest_optionflags = NORMALIZE_WHITESPACE ELLIPSIS
9 | testpaths = .
10 | addopts =
11 | -s
12 | --doctest-modules
13 | --doctest-glob="*.rst"
14 | --cov=skdag
15 | --ignore setup.py
16 | --ignore doc/_build
17 | --ignore doc/_templates
18 | --no-cov-on-fail
19 |
20 | [coverage:run]
21 | branch = True
22 | source = skdag
23 | include = */skdag/*
24 | omit =
25 | */tests/*
26 | *_test.py
27 | test_*.py
28 | */setup.py
29 |
30 | [coverage:report]
31 | exclude_lines =
32 | pragma: no cover
33 | def __repr__
34 | if self.debug:
35 | if settings.DEBUG
36 | raise AssertionError
37 | raise NotImplementedError
38 | if 0:
39 | if __name__ == .__main__.:
40 | if self.verbose:
41 | show_missing = True
42 |
43 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | """A flexible alternative to scikit-learn Pipelines"""
3 |
4 | import codecs
5 | import os
6 |
7 | from setuptools import find_packages, setup
8 |
9 |
10 | def parse_requirements(filename):
11 | # Copy dependencies from requirements file
12 | with open(filename, encoding="utf-8") as f:
13 | requirements = [line.strip() for line in f.read().splitlines()]
14 | requirements = [
15 | line.split("#")[0].strip()
16 | for line in requirements
17 | if not line.startswith("#")
18 | ]
19 |
20 | return requirements
21 |
22 |
23 | # get __version__ from _version.py
24 | ver_file = os.path.join("skdag", "_version.py")
25 | with open(ver_file) as f:
26 | exec(f.read())
27 |
28 | DISTNAME = "skdag"
29 | DESCRIPTION = "A flexible alternative to scikit-learn Pipelines"
30 |
31 | with codecs.open("README.rst", encoding="utf-8") as f:
32 | LONG_DESCRIPTION = f.read()
33 |
34 | MAINTAINER = "big-o"
35 | MAINTAINER_EMAIL = "big-o@users.noreply.github.com"
36 | URL = "https://github.com/big-o/skdag"
37 | LICENSE = "new BSD"
38 | DOWNLOAD_URL = "https://github.com/scikit-learn-contrib/project-template"
39 | VERSION = __version__
40 | INSTALL_REQUIRES = parse_requirements("requirements.txt")
41 | CLASSIFIERS = [
42 | "Intended Audience :: Science/Research",
43 | "Intended Audience :: Developers",
44 | "License :: OSI Approved",
45 | "Programming Language :: Python",
46 | "Topic :: Software Development",
47 | "Topic :: Scientific/Engineering",
48 | "Operating System :: Microsoft :: Windows",
49 | "Operating System :: POSIX",
50 | "Operating System :: Unix",
51 | "Operating System :: MacOS",
52 | "Programming Language :: Python :: 3.7",
53 | "Programming Language :: Python :: 3.8",
54 | "Programming Language :: Python :: 3.9",
55 | ]
56 | EXTRAS_REQUIRE = {
57 | tgt: parse_requirements(f"requirements_{tgt}.txt")
58 | for tgt in ["test", "doc"]
59 | }
60 |
61 | setup(
62 | name=DISTNAME,
63 | maintainer=MAINTAINER,
64 | maintainer_email=MAINTAINER_EMAIL,
65 | description=DESCRIPTION,
66 | license=LICENSE,
67 | url=URL,
68 | version=VERSION,
69 | download_url=DOWNLOAD_URL,
70 | long_description=LONG_DESCRIPTION,
71 | zip_safe=False, # the package can run out of an .egg file
72 | classifiers=CLASSIFIERS,
73 | packages=find_packages(),
74 | install_requires=INSTALL_REQUIRES,
75 | extras_require=EXTRAS_REQUIRE,
76 | )
77 |
--------------------------------------------------------------------------------
/skdag/__init__.py:
--------------------------------------------------------------------------------
1 | from skdag.dag import *
2 |
3 | from skdag._version import __version__
4 |
5 | __all__ = [
6 | "DAG",
7 | "DAGBuilder",
8 | "DAGRenderer",
9 | "DAGStep",
10 | "__version__",
11 | ]
12 |
--------------------------------------------------------------------------------
/skdag/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.7"
2 |
--------------------------------------------------------------------------------
/skdag/dag/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`skdag.dag` module implements utilities to build a composite
3 | estimator, as a directed acyclic graph (DAG). A DAG may have one or more inputs (roots)
4 | and also one or more outputs (leaves), but may not contain any cyclic processing paths.
5 | """
6 | # Author: big-o (github)
7 | # License: BSD
8 |
9 | from skdag.dag._dag import *
10 | from skdag.dag._builder import *
11 | from skdag.dag._render import *
12 |
--------------------------------------------------------------------------------
/skdag/dag/_builder.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Mapping, Sequence
2 |
3 | import networkx as nx
4 | from skdag.dag._dag import DAG, DAGStep
5 | from skdag.exceptions import DAGError
6 |
7 | __all__ = ["DAGBuilder"]
8 |
9 |
10 | class DAGBuilder:
11 | """
12 | Helper utility for creating a :class:`skdag.DAG`.
13 |
14 | ``DAGBuilder`` allows a graph to be defined incrementally by specifying one node
15 | (step) at a time. Graph edges are defined by providing optional dependency lists
16 | that reference each step by name. Note that steps must be defined before they are
17 | used as dependencies.
18 |
19 | Parameters
20 | ----------
21 |
22 | infer_dataframe : bool, default = False
23 | If True, assume ``dataframe_columns="infer"`` every time :meth:`.add_step` is
24 | called, if ``dataframe_columns`` is set to ``None``. This effectively makes the
25 | resulting DAG always try to coerce output into pandas DataFrames wherever
26 | possible.
27 |
28 | See Also
29 | --------
30 | :class:`skdag.DAG` : The estimator DAG created by this utility.
31 |
32 | Examples
33 | --------
34 |
35 | >>> from skdag import DAGBuilder
36 | >>> from sklearn.decomposition import PCA
37 | >>> from sklearn.impute import SimpleImputer
38 | >>> from sklearn.linear_model import LogisticRegression
39 | >>> dag = (
40 | ... DAGBuilder()
41 | ... .add_step("impute", SimpleImputer())
42 | ... .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)})
43 | ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)})
44 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"])
45 | ... .make_dag()
46 | ... )
47 | >>> print(dag.draw().strip())
48 | o impute
49 | |\\
50 | o o blood,vitals
51 | |/
52 | o lr
53 | """
54 |
55 | def __init__(self, infer_dataframe=False):
56 | self.graph = nx.DiGraph()
57 | self.infer_dataframe = infer_dataframe
58 |
59 | def from_pipeline(self, steps, **kwargs):
60 | """
61 | Construct a DAG from a simple linear sequence of steps. The resulting DAG will
62 | be equivalent to a :class:`~sklearn.pipeline.Pipeline`.
63 |
64 | Parameters
65 | ----------
66 |
67 | steps : sequence of (str, estimator)
68 | An ordered sequence of pipeline steps. A step is simply a pair of
69 | ``(name, estimator)``, just like a scikit-learn Pipeline.
70 |
71 | kwargs : kwargs
72 | Any other hyperparameters that are accepted by :class:`~skdag.dag.DAG`'s
73 | constructor. Note that dataframe inference is controlled by the
74 | ``infer_dataframe`` flag passed to the :class:`DAGBuilder` constructor, not by
75 | this method.
80 | """
81 | if hasattr(steps, "steps"):
82 | pipe = steps
83 | steps = pipe.steps
84 | if hasattr(pipe, "get_params"):
85 | kwargs = {
86 | **{
87 | k: v
88 | for k, v in pipe.get_params().items()
89 | if k in ("memory", "verbose")
90 | },
91 | **kwargs,
92 | }
93 |
94 | dfcols = "infer" if self.infer_dataframe else None
95 |
96 | for i in range(len(steps)):
97 | name, estimator = steps[i]
98 | self._validate_name(name)
99 | deps = {}
100 | if i > 0:
101 | dep = steps[i - 1][0]
102 | deps[dep] = None
103 | self._validate_deps(deps)
104 |
105 | step = DAGStep(name, estimator, deps, dataframe_columns=dfcols)
106 | self.graph.add_node(name, step=step)
107 | if deps:
108 | self.graph.add_edge(dep, name)
109 |
110 | self._validate_graph()
111 | # Remember any DAG hyperparameters so make_dag() can apply them.
112 | self._dag_kwargs = kwargs
113 | 
114 | return self
113 |
114 | def add_step(self, name, est, deps=None, dataframe_columns=None):
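"""
Add a single named step to the DAG, optionally declaring dependencies on
previously-defined steps.

Parameters
----------

name : str
The reference name for this step. Must be unique within the graph.

est : estimator-like or "passthrough"
The estimator (transformer or predictor) to be executed by this step.

deps : sequence of str, or mapping of str to column selection, optional
The names of upstream steps that this step depends on. If a mapping, each
value selects the columns to take from that upstream output: an iterable of
columns, a slice, or a column selector callable.

dataframe_columns : list of str or "infer", optional
See :class:`skdag.dag.DAGStep`. Defaults to "infer" if the builder was
created with ``infer_dataframe=True``.

Returns
-------

self : :class:`DAGBuilder`
The builder itself, to allow call-chaining.
"""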
115 | self._validate_name(name)
116 | if isinstance(deps, Sequence):
117 | deps = {dep: None for dep in deps}
118 |
119 | if deps is not None:
120 | self._validate_deps(deps)
121 | else:
122 | deps = {}
123 |
124 | if dataframe_columns is None and self.infer_dataframe:
125 | dfcols = "infer"
126 | else:
127 | dfcols = dataframe_columns
128 |
129 | step = DAGStep(name, est, deps=deps, dataframe_columns=dfcols)
130 | self.graph.add_node(name, step=step)
131 |
132 | for dep in deps:
133 | # Since node is new, edges will never form a cycle.
134 | self.graph.add_edge(dep, name)
135 |
136 | self._validate_graph()
137 |
138 | return self
139 |
140 | def _validate_name(self, name):
141 | if not isinstance(name, str):
142 | raise KeyError(f"step names must be strings, got '{type(name)}'")
143 |
144 | if name in self.graph.nodes:
145 | raise KeyError(f"step with name '{name}' already exists")
146 |
147 | def _validate_deps(self, deps):
148 | if not isinstance(deps, Mapping) or not all(
149 | [isinstance(dep, str) for dep in deps]
150 | ):
151 | raise ValueError(
152 | "deps parameter must be a map of labels to indices, "
153 | f"got '{type(deps)}'."
154 | )
155 |
156 | missing = [dep for dep in deps if dep not in self.graph]
157 | if missing:
158 | raise ValueError(f"unresolvable dependencies: {', '.join(sorted(missing))}")
159 |
160 | def _validate_graph(self):
161 | if not nx.algorithms.dag.is_directed_acyclic_graph(self.graph):
162 | raise DAGError("Workflow is not a DAG.")
163 |
164 | def make_dag(self, **kwargs):
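"""
Build and return a :class:`skdag.dag.DAG` from the graph defined so far.

Parameters
----------

kwargs : kwargs
Hyperparameters accepted by :class:`~skdag.dag.DAG`'s constructor, such as
``memory``, ``n_jobs`` and ``verbose``.
"""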
165 | self._validate_graph()
166 | # Give the DAG a read-only view of the graph.
167 | kwargs = {**getattr(self, "_dag_kwargs", {}), **kwargs}
168 | return DAG(graph=self.graph.copy(as_view=True), **kwargs)
168 |
169 | def _repr_html_(self):
170 | return self.make_dag()._repr_html_()
171 |
--------------------------------------------------------------------------------
/skdag/dag/_dag.py:
--------------------------------------------------------------------------------
1 | """
2 | Directed Acyclic Graphs (DAGs) may be used to construct complex workflows for
3 | scikit-learn estimators. As the name suggests, data may only flow in one
4 | direction and can't go back on itself to a previously run step.
5 | """
6 | from collections import UserDict
7 | from copy import deepcopy
8 | from inspect import signature
9 | from itertools import chain
10 | from typing import Iterable
11 |
12 | import networkx as nx
13 | import numpy as np
14 | from joblib import Parallel, delayed
15 | from scipy.sparse import dok_matrix, issparse
16 | from skdag.dag._render import DAGRenderer
17 | from skdag.dag._utils import (
18 | _format_output,
19 | _in_notebook,
20 | _is_pandas,
21 | _is_passthrough,
22 | _is_predictor,
23 | _is_transformer,
24 | _stack,
25 | )
26 | from sklearn.base import clone
27 | from sklearn.exceptions import NotFittedError
28 | from sklearn.utils import Bunch, _safe_indexing, deprecated
29 | from sklearn.utils._tags import _safe_tags
30 | from sklearn.utils.metaestimators import _BaseComposition, available_if
31 | from sklearn.utils.validation import check_is_fitted, check_memory
32 |
33 | try:
34 | from sklearn.utils import _print_elapsed_time
35 | except ImportError:
36 | from sklearn.utils._user_interface import _print_elapsed_time
37 |
38 | __all__ = ["DAG", "DAGStep"]
39 |
40 |
41 | def _get_columns(X, dep, cols, is_root, dep_is_passthrough, axis=1):
42 | if callable(cols):
43 | # sklearn.compose.make_column_selector
44 | cols = cols(X)
45 |
46 | if not is_root and not dep_is_passthrough:
47 | # The DAG will prepend output columns with the step name, so add this in to any
48 | # dep columns if missing. This helps keep user-provided deps readable.
49 | if isinstance(cols, str):
50 | cols = cols if cols.startswith(f"{dep}__") else f"{dep}__{cols}"
51 | elif isinstance(cols, Iterable):
52 | orig = cols
53 | cols = []
54 | for col in orig:
55 | if isinstance(col, str):
56 | cols.append(col if col.startswith(f"{dep}__") else f"{dep}__{col}")
57 | else:
58 | cols.append(col)
59 |
60 | return _safe_indexing(X, cols, axis=axis)
61 |
62 |
63 | def _stack_inputs(dag, X, node):
64 | # For root nodes, the dependency is just the node name itself.
65 | deps = {node.name: None} if node.is_root else node.deps
66 |
67 | cols = [
68 | _get_columns(
69 | X[dep],
70 | dep,
71 | cols,
72 | node.is_root,
73 | _is_passthrough(dag.graph_.nodes[dep]["step"].estimator),
74 | axis=1,
75 | )
76 | for dep, cols in deps.items()
77 | ]
78 |
79 | to_stack = [
80 | # If we sliced a single column from an input, reshape it to a 2d array.
81 | col.reshape(-1, 1)
82 | if col is not None and deps[dep] is not None and col.ndim < 2
83 | else col
84 | for col, dep in zip(cols, deps)
85 | ]
86 |
87 | X_stacked = _stack(to_stack, axis=node.axis)
88 |
89 | return X_stacked
90 |
91 |
92 | def _leaf_estimators_have(attr, how="all"):
93 | """Check that leaves have `attr`.
94 | Used together with `available_if` in `DAG`."""
95 |
96 | def check_leaves(self):
97 | # raises `AttributeError` with all details if `attr` does not exist
98 | failed = []
99 | for leaf in self.leaves_:
100 | try:
101 | _is_passthrough(leaf.estimator) or getattr(leaf.estimator, attr)
102 | except AttributeError:
103 | failed.append(leaf.estimator)
104 |
105 | if (how == "all" and failed) or (
106 | how == "any" and len(failed) != len(self.leaves_)
107 | ):
108 | raise AttributeError(
109 | f"{', '.join([repr(type(est)) for est in failed])} "
110 | f"object(s) has no attribute '{attr}'"
111 | )
112 | return True
113 |
114 | return check_leaves
115 |
116 |
117 | def _transform_one(transformer, X, weight, allow_predictor=True, **fit_params):
118 | if _is_passthrough(transformer):
119 | res = X
120 | elif allow_predictor and not hasattr(transformer, "transform"):
121 | for fn in ["predict_proba", "decision_function", "predict"]:
122 | if hasattr(transformer, fn):
123 | res = getattr(transformer, fn)(X)
124 | if res.ndim < 2:
125 | res = res.reshape(-1, 1)
126 | break
127 | else:
128 | raise AttributeError(
129 | f"'{type(transformer).__name__}' object has no attribute 'transform'"
130 | )
131 | else:
132 | res = transformer.transform(X)
133 | # if we have a weight for this transformer, multiply output
134 | if weight is not None:
135 | res = res * weight
136 |
137 | return res
138 |
139 |
140 | def _fit_transform_one(
141 | transformer,
142 | X,
143 | y,
144 | weight,
145 | message_clsname="",
146 | message=None,
147 | allow_predictor=True,
148 | **fit_params,
149 | ):
150 | """
151 | Fits ``transformer`` to ``X`` and ``y``. The transformed result is returned
152 | with the fitted transformer. If ``weight`` is not ``None``, the result will
153 | be multiplied by ``weight``.
154 | """
155 | with _print_elapsed_time(message_clsname, message):
156 | failed = False
157 | if _is_passthrough(transformer):
158 | res = X
159 | elif hasattr(transformer, "fit_transform"):
160 | res = transformer.fit_transform(X, y, **fit_params)
161 | elif hasattr(transformer, "transform"):
162 | res = transformer.fit(X, y, **fit_params).transform(X)
163 | elif allow_predictor:
164 | for fn in ["predict_proba", "decision_function", "predict"]:
165 | if hasattr(transformer, fn):
166 | res = getattr(transformer.fit(X, y, **fit_params), fn)(X)
167 | if res.ndim < 2:
168 | res = res.reshape(-1, 1)
169 | break
170 | else:
171 | failed = True
172 | res = None
173 |
174 | if res is not None and res.ndim < 2:
175 | res = res.reshape(-1, 1)
176 | else:
177 | failed = True
178 |
179 | if failed:
180 | raise AttributeError(
181 | f"'{type(transformer).__name__}' object has no attribute 'transform'"
182 | )
183 |
184 | if weight is not None:
185 | res = res * weight
186 |
187 | return res, transformer
188 |
189 |
190 | def _parallel_fit(dag, step, Xin, Xs, y, fit_transform_fn, memory, **fit_params):
191 | transformer = step.estimator
192 |
193 | if step.deps:
194 | X = _stack_inputs(dag, Xs, step)
195 | else:
196 | # For root nodes, the destination rather than the source is
197 | # specified.
198 | # X = Xin[step.name]
199 | X = _stack_inputs(dag, Xin, step)
200 |
201 | clsname = type(dag).__name__
202 | with _print_elapsed_time(clsname, dag._log_message(step)):
203 | if transformer is None or transformer == "passthrough":
204 | Xt, fitted_transformer = X, transformer
205 | else:
206 | if hasattr(memory, "location") and memory.location is None:
207 | # we do not clone when caching is disabled to
208 | # preserve backward compatibility
209 | cloned_transformer = transformer
210 | else:
211 | cloned_transformer = clone(transformer)
212 |
213 | # Fit or load from cache the current transformer
214 | Xt, fitted_transformer = fit_transform_fn(
215 | cloned_transformer,
216 | X,
217 | y,
218 | None,
219 | message_clsname=clsname,
220 | message=dag._log_message(step),
221 | **fit_params,
222 | )
223 |
224 | Xt = _format_output(Xt, X, step)
225 |
226 | return Xt, fitted_transformer
227 |
228 |
229 | def _parallel_transform(dag, step, Xin, Xs, transform_fn, **fn_params):
230 | transformer = step.estimator
231 | if step.deps:
232 | X = _stack_inputs(dag, Xs, step)
233 | else:
234 | # For root nodes, the destination rather than the source is
235 | # specified.
236 | X = _stack_inputs(dag, Xin, step)
237 | # X = Xin[step.name]
238 |
239 | clsname = type(dag).__name__
240 | with _print_elapsed_time(clsname, dag._log_message(step)):
241 | if transformer is None or transformer == "passthrough":
242 | Xt = X
243 | else:
244 | # Fit or load from cache the current transformer
245 | Xt = transform_fn(
246 | transformer,
247 | X,
248 | None,
249 | message_clsname=clsname,
250 | message=dag._log_message(step),
251 | **fn_params,
252 | )
253 |
254 | Xt = _format_output(Xt, X, step)
255 |
256 | return Xt
257 |
258 |
259 | def _parallel_fit_leaf(dag, leaf, Xts, y, **fit_params):
260 | with _print_elapsed_time(type(dag).__name__, dag._log_message(leaf)):
261 | if leaf.estimator == "passthrough":
262 | fitted_estimator = leaf.estimator
263 | else:
264 | Xt = _stack_inputs(dag, Xts, leaf)
265 | fitted_estimator = leaf.estimator.fit(Xt, y, **fit_params)
266 |
267 | return fitted_estimator
268 |
269 |
270 | def _parallel_execute(
271 | dag, leaf, fn, Xts, y=None, fit_first=False, fit_params=None, fn_params=None
272 | ):
273 | with _print_elapsed_time("DAG", dag._log_message(leaf)):
274 | Xt = _stack_inputs(dag, Xts, leaf)
275 | fit_params = fit_params or {}
276 | fn_params = fn_params or {}
277 | if leaf.estimator == "passthrough":
278 | Xout = Xt
279 | elif fit_first and hasattr(leaf.estimator, f"fit_{fn}"):
280 | Xout = getattr(leaf.estimator, f"fit_{fn}")(Xt, y, **fit_params)
281 | else:
282 | if fit_first:
283 | leaf.estimator.fit(Xt, y, **fit_params)
284 |
285 | est_fn = getattr(leaf.estimator, fn)
286 | if "y" in signature(est_fn).parameters:
287 | Xout = est_fn(Xt, y=y, **fn_params)
288 | else:
289 | Xout = est_fn(Xt, **fn_params)
290 |
291 | Xout = _format_output(Xout, Xt, leaf)
292 |
293 | fitted_estimator = leaf.estimator
294 |
295 | return Xout, fitted_estimator
296 |
297 |
298 | class DAGStep:
299 | """
300 | A single estimator step in a DAG.
301 |
302 | Parameters
303 | ----------
304 | name : str
305 | The reference name for this step.
306 | estimator : estimator-like
307 | The estimator (transformer or predictor) that will be executed by this step.
308 | deps : dict
309 | A map of dependency names to columns. If columns is ``None``, then all input
310 | columns will be selected.
311 | dataframe_columns : list of str or "infer" (optional)
312 | Either a hard-coded list of column names to apply to any output data, or the
313 | string "infer", which means the column outputs will be assumed to match the
314 | column inputs if the output is 2d and not already a dataframe, the estimator is
315 | a transformer, and the final axis dimensions match the inputs. Otherwise the
316 | column names will be assumed to be the step name + index if the output is not
317 | already a dataframe. If set to ``None`` or inference is not possible, the
318 | outputs will be left unmodified.
319 | axis : int, default = 1
320 | The strategy for merging inputs if there is more than one upstream dependency.
321 | ``axis=0`` will assume all inputs have the same features and stack the rows
322 | together; ``axis=1`` will assume each input provides different features for the
323 | same samples.
324 | """
325 |
326 | def __init__(self, name, estimator, deps, dataframe_columns, axis=1):
327 | self.name = name
328 | self.estimator = estimator
329 | self.deps = deps
330 | self.dataframe_columns = dataframe_columns
331 | self.axis = axis
332 | self.index = None
333 | self.is_root = False
334 | self.is_leaf = False
335 | self.is_fitted = False
336 |
337 | def __repr__(self):
338 | return f"{type(self).__name__}({repr(self.name)}, {repr(self.estimator)})"
339 |
340 |
341 | class DAG(_BaseComposition):
342 | """
343 | A Directed Acyclic Graph (DAG) of estimators, that itself implements the estimator
344 | interface.
345 |
346 | A DAG may consist of a simple chain of estimators (being exactly equivalent to a
347 | :mod:`sklearn.pipeline.Pipeline`) or a more complex path of dependencies. But as the
348 | name suggests, it may not contain any cyclic dependencies and data may only flow
349 | from one or more start points (roots) to one or more endpoints (leaves).
350 |
351 | Parameters
352 | ----------
353 |
354 | graph : :class:`networkx.DiGraph`
355 | A directed graph with string node IDs indicating the step name. Each node must
356 | have a ``step`` attribute, which contains a :class:`skdag.dag.DAGStep`.
357 |
358 | memory : str or object with the joblib.Memory interface, default=None
359 | Used to cache the fitted transformers of the DAG. By default, no caching is
360 | performed. If a string is given, it is the path to the caching directory.
361 | Enabling caching triggers a clone of the transformers before fitting. Therefore,
362 | the transformer instance given to the DAG cannot be inspected directly. Use the
363 | attribute ``named_steps`` or ``steps`` to inspect estimators within the
364 | pipeline. Caching the transformers is advantageous when fitting is time
365 | consuming.
366 |
367 | n_jobs : int, default=None
368 | Number of jobs to run in parallel. ``None`` means 1 unless in a
369 | :obj:`joblib.parallel_backend` context.
370 |
371 | verbose : bool, default=False
372 | If True, the time elapsed while fitting each step will be printed as it is
373 | completed.
374 |
375 | Attributes
376 | ----------
377 |
378 | graph_ : :class:`networkx.DiGraph`
379 | A read-only view of the workflow.
380 |
381 | classes_ : ndarray of shape (n_classes,)
382 | The classes labels. Only exists if the last step of the pipeline is a
383 | classifier.
384 |
385 | n_features_in_ : int
386 | Number of features seen during :term:`fit`. Only defined if all of the
387 | underlying root estimators in `graph_` expose such an attribute when fit.
388 |
389 | feature_names_in_ : ndarray of shape (`n_features_in_`,)
390 | Names of features seen during :term:`fit`. Only defined if the underlying
391 | estimators expose such an attribute when fit.
392 |
393 | See Also
394 | --------
395 | :class:`skdag.DAGBuilder` : Convenience utility for simplified DAG construction.
396 |
397 | Examples
398 | --------
399 |
400 | The simplest DAGs are just a chain of singular dependencies. These DAGs may be
401 | created with the :meth:`skdag.dag.DAG.from_pipeline` method in the same way as a
402 | Pipeline:
403 |
404 | >>> from sklearn.decomposition import PCA
405 | >>> from sklearn.impute import SimpleImputer
406 | >>> from sklearn.linear_model import LogisticRegression
407 | >>> dag = DAG.from_pipeline(
408 | ... steps=[
409 | ... ("impute", SimpleImputer()),
410 | ... ("pca", PCA()),
411 | ... ("lr", LogisticRegression())
412 | ... ]
413 | ... )
414 | >>> print(dag.draw().strip())
415 | o impute
416 | |
417 | o pca
418 | |
419 | o lr
420 |
421 | For more complex DAGs, it is recommended to use a :class:`skdag.dag.DAGBuilder`,
422 | which allows you to define the graph by specifying the dependencies of each new
423 | estimator:
424 |
425 | >>> from skdag import DAGBuilder
426 | >>> dag = (
427 | ... DAGBuilder()
428 | ... .add_step("impute", SimpleImputer())
429 | ... .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)})
430 | ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)})
431 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"])
432 | ... .make_dag()
433 | ... )
434 | >>> print(dag.draw().strip())
435 | o impute
436 | |\\
437 | o o blood,vitals
438 | |/
439 | o lr
440 |
441 | In the above example we pass the first four columns directly to a classifier, but
442 | the remaining columns have dimensionality reduction applied first before being
443 | passed to the same classifier. Note that we can define our graph edges in two
444 | different ways: as a dict (if we need to select only certain columns from the source
445 | node) or as a simple list (if we want to simply grab all columns from all input
446 | nodes).
447 |
448 | The DAG may now be used as an estimator in its own right:
449 |
450 | >>> from sklearn import datasets
451 | >>> X, y = datasets.load_diabetes(return_X_y=True)
452 | >>> dag.fit_predict(X, y)
453 | array([...
454 |
455 | In an extension to the scikit-learn estimator interface, DAGs also support multiple
456 | inputs and multiple outputs. Let's say we want to compare two different classifiers:
457 |
458 | >>> from sklearn.ensemble import RandomForestClassifier
459 | >>> cal = DAG.from_pipeline(
460 | ... [("rf", RandomForestClassifier(random_state=0))]
461 | ... )
462 | >>> dag2 = dag.join(cal, edges=[("blood", "rf"), ("vitals", "rf")])
463 | >>> print(dag2.draw().strip())
464 | o impute
465 | |\\
466 | o o blood,vitals
467 | |x|
468 | o o lr,rf
469 |
470 | Now our DAG will return two outputs: one from each classifier. Multiple outputs are
471 | returned as a :class:`sklearn.utils.Bunch`:
472 |
473 | >>> y_pred = dag2.fit_predict(X, y)
474 | >>> y_pred.lr
475 | array([...
476 | >>> y_pred.rf
477 | array([...
478 |
479 | Similarly, multiple inputs are also acceptable and inputs can be provided by
480 | specifying ``X`` and ``y`` as a ``dict``-like object.
481 | """
482 |
483 | # BaseEstimator interface
484 | _required_parameters = ["graph"]
485 |
486 | @classmethod
487 | @deprecated(
488 | "DAG.from_pipeline is deprecated in 0.0.3 and will be removed in a future "
489 | "release. Please use DAGBuilder.from_pipeline instead."
490 | )
491 | def from_pipeline(cls, steps, **kwargs):
492 | from skdag.dag._builder import DAGBuilder
493 |
494 | return DAGBuilder().from_pipeline(steps, **kwargs).make_dag()
495 |
496 | def __init__(self, graph, *, memory=None, n_jobs=None, verbose=False):
497 | self.graph = graph
498 | self.memory = memory
499 | self.verbose = verbose
500 | self.n_jobs = n_jobs
501 |
502 | def get_params(self, deep=True):
503 | """
504 | Get parameters for this metaestimator.
505 |
506 | Returns the parameters given in the constructor as well as the
507 | estimators contained within the `steps_` of the `DAG`.
508 |
509 | Parameters
510 | ----------
511 | deep : bool, default=True
512 | If True, will return the parameters for this estimator and
513 | contained subobjects that are estimators.
514 |
515 | Returns
516 | -------
517 | params : mapping of string to any
518 | Parameter names mapped to their values.
519 | """
520 | return self._get_params("steps_", deep=deep)
521 |
522 | def set_params(self, **params):
523 | """
524 | Set the parameters of this metaestimator.
525 |
526 | Valid parameter keys can be listed with ``get_params()``. Note that
527 | you can directly set the parameters of the estimators contained in
528 | `steps_`.
529 |
530 | Parameters
531 | ----------
532 | **params : dict
533 | Parameters of this metaestimator or parameters of estimators contained
534 | in `steps`. Parameters of the steps may be set using its name and
535 | the parameter name separated by a '__'.
536 |
537 | Returns
538 | -------
539 | self : object
540 | DAG class instance.
541 | """
542 | step_names = set(self.step_names)
543 | for param in list(params.keys()):
544 | if "__" not in param and param in step_names:
545 | self.graph_.nodes[param]["step"].estimator = params.pop(param)
546 |
547 | super().set_params(**params)
548 | return self
549 |
550 | def _log_message(self, step):
551 | if not self.verbose:
552 | return None
553 |
554 | return f"(step {step.name}: {step.index} of {len(self.graph_)}) Processing {step.name}"
555 |
556 | def _iter(self, with_leaves=True, filter_passthrough=True):
557 | """
558 | Generate stage lists from self.graph_.
559 | When filter_passthrough is True, 'passthrough' and None transformers
560 | are filtered out.
561 | """
562 | for stage in nx.topological_generations(self.graph_):
563 | stage = [self.graph_.nodes[step]["step"] for step in stage]
564 | if not with_leaves:
565 | stage = [step for step in stage if not step.is_leaf]
566 |
567 | if filter_passthrough:
568 | stage = [
569 | step
570 | for step in stage
571 | if step.estimator is not None and step.estimator != "passthrough"
572 | ]
573 |
574 | if len(stage) == 0:
575 | continue
576 |
577 | yield stage
578 |
579 | def __len__(self):
580 | """
581 | Returns the size of the DAG
582 | """
583 | return len(self.graph_)
584 |
585 | def __getitem__(self, name):
586 | """
587 | Retrieve a named estimator.
588 | """
589 | return self.graph_.nodes[name]["step"].estimator
590 |
591 | def _fit(self, X, y=None, **fit_params_steps):
592 | # Setup the memory
593 | memory = check_memory(self.memory)
594 |
595 | fit_transform_one_cached = memory.cache(_fit_transform_one)
596 |
597 | root_names = set([root.name for root in self.roots_])
598 | Xin = self._resolve_inputs(X)
599 | Xs = {}
600 | with Parallel(n_jobs=self.n_jobs) as parallel:
601 | for stage in self._iter(with_leaves=False, filter_passthrough=False):
602 | stage_names = [step.name for step in stage]
603 | outputs, fitted_transformers = zip(
604 | *parallel(
605 | delayed(_parallel_fit)(
606 | self,
607 | step,
608 | Xin,
609 | Xs,
610 | y,
611 | fit_transform_one_cached,
612 | memory,
613 | **fit_params_steps[step.name],
614 | )
615 | for step in stage
616 | )
617 | )
618 |
619 | for step, fitted_transformer in zip(stage, fitted_transformers):
620 | # Replace the transformer of the step with the fitted
621 | # transformer. This is necessary when loading the transformer
622 | # from the cache.
623 | step.estimator = fitted_transformer
624 | step.is_fitted = True
625 |
626 | Xs.update(dict(zip(stage_names, outputs)))
627 |
628 | # If all of a dep's dependents are now complete, we can free up some
629 | # memory.
630 | root_names = root_names - set(stage_names)
631 | for dep in {dep for step in stage for dep in step.deps}:
632 | dependents = self.graph_.successors(dep)
633 | if all(d in Xs and d not in root_names for d in dependents):
634 | del Xs[dep]
635 |
636 | # If a root node is also a leaf, it hasn't been fit yet and we need to pass on
637 | # its input for later.
638 | Xs.update({name: Xin[name] for name in root_names})
639 | return Xs
640 |
641 | def _transform(self, X, **fn_params_steps):
642 | # Setup the memory
643 | memory = check_memory(self.memory)
644 |
645 | transform_one_cached = memory.cache(_transform_one)
646 |
647 | root_names = set([root.name for root in self.roots_])
648 | Xin = self._resolve_inputs(X)
649 | Xs = {}
650 | with Parallel(n_jobs=self.n_jobs) as parallel:
651 | for stage in self._iter(with_leaves=False, filter_passthrough=False):
652 | stage_names = [step.name for step in stage]
653 | outputs = parallel(
654 | delayed(_parallel_transform)(
655 | self,
656 | step,
657 | Xin,
658 | Xs,
659 | transform_one_cached,
660 | **fn_params_steps[step.name],
661 | )
662 | for step in stage
663 | )
664 |
665 | Xs.update(dict(zip(stage_names, outputs)))
666 |
667 | # If all of a dep's dependents are now complete, we can free up some
668 | # memory.
669 | root_names = root_names - set(stage_names)
670 | for dep in {dep for step in stage for dep in step.deps}:
671 | dependents = self.graph_.successors(dep)
672 | if all(d in Xs and d not in root_names for d in dependents):
673 | del Xs[dep]
674 |
675 | # If a root node is also a leaf, it hasn't been fit yet and we need to pass on
676 | # its input for later.
677 | Xs.update({name: Xin[name] for name in root_names})
678 | return Xs
679 |
680 | def _resolve_inputs(self, X):
681 | if isinstance(X, (dict, Bunch, UserDict)) and not isinstance(X, dok_matrix):
682 | inputs = sorted(X.keys())
683 | if inputs != sorted(root.name for root in self.roots_):
684 | raise ValueError(
685 | "Input dicts must contain one key per entry node. "
686 | f"Entry nodes are {self.roots_}, got {inputs}."
687 | )
688 | else:
689 | if len(self.roots_) != 1:
690 | raise ValueError(
691 | "Must provide a dictionary of inputs for a DAG with multiple entry "
692 | "points."
693 | )
694 | X = {self.roots_[0].name: X}
695 |
696 | X = {
697 | step: x if issparse(x) or _is_pandas(x) else np.asarray(x)
698 | for step, x in X.items()
699 | }
700 |
701 | return X
702 |
703 | def _match_input_format(self, Xin, Xout):
704 | if len(self.leaves_) == 1 and (
705 | not isinstance(Xin, (dict, Bunch, UserDict)) or isinstance(Xin, dok_matrix)
706 | ):
707 | return Xout[self.leaves_[0].name]
708 | return Bunch(**Xout)
709 |
710 | def fit(self, X, y=None, **fit_params):
711 | """
712 | Fit the model.
713 |
714 | Fit all the transformers one after the other and transform the
715 | data. Finally, fit the transformed data using the final estimators.
716 |
717 | Parameters
718 | ----------
719 | X : iterable
720 | Training data. Must fulfill input requirements of first step of the
721 | DAG.
722 | y : iterable, default=None
723 | Training targets. Must fulfill label requirements for all steps of
724 | the DAG.
725 | **fit_params : dict of string -> object
726 | Parameters passed to the ``fit`` method of each step, where
727 | each parameter name is prefixed such that parameter ``p`` for step
728 | ``s`` has key ``s__p``.
729 |
730 | Returns
731 | -------
732 | self : object
733 | DAG fitted steps.
734 | """
735 | self._validate_graph()
736 | fit_params_steps = self._check_fit_params(**fit_params)
737 | Xts = self._fit(X, y, **fit_params_steps)
738 | fitted_estimators = Parallel(n_jobs=self.n_jobs)(
739 | [
740 | delayed(_parallel_fit_leaf)(
741 | self, leaf, Xts, y, **fit_params_steps[leaf.name]
742 | )
743 | for leaf in self.leaves_
744 | ]
745 | )
746 | for est, leaf in zip(fitted_estimators, self.leaves_):
747 | leaf.estimator = est
748 | leaf.is_fitted = True
749 |
750 | # If we have a single root, mirror certain attributes in the DAG.
751 | if len(self.roots_) == 1:
752 | root = self.roots_[0].estimator
753 | for attr in ["n_features_in_", "feature_names_in_"]:
754 | if hasattr(root, attr):
755 | setattr(self, attr, getattr(root, attr))
756 |
757 | return self
758 |
759 | def _fit_execute(self, fn, X, y=None, **fit_params):
760 | self._validate_graph()
761 | fit_params_steps = self._check_fit_params(**fit_params)
762 | Xts = self._fit(X, y, **fit_params_steps)
763 | Xout = {}
764 |
765 | leaf_names = [leaf.name for leaf in self.leaves_]
766 | outputs, fitted_estimators = zip(
767 | *Parallel(n_jobs=self.n_jobs)(
768 | delayed(_parallel_execute)(
769 | self,
770 | leaf,
771 | fn,
772 | Xts,
773 | y,
774 | fit_first=True,
775 | fit_params=fit_params_steps[leaf.name],
776 | )
777 | for leaf in self.leaves_
778 | )
779 | )
780 |
781 | Xout = dict(zip(leaf_names, outputs))
782 | for step, fitted_estimator in zip(self.leaves_, fitted_estimators):
783 | step.estimator = fitted_estimator
784 | step.is_fitted = True
785 |
786 | return self._match_input_format(X, Xout)
787 |
788 | def _execute(self, fn, X, y=None, **fn_params):
789 | Xout = {}
790 | fn_params_steps = self._check_fn_params(**fn_params)
791 | Xts = self._transform(X, **fn_params_steps)
792 |
793 | leaf_names = [leaf.name for leaf in self.leaves_]
794 | outputs, _ = zip(
795 | *Parallel(n_jobs=self.n_jobs)(
796 | delayed(_parallel_execute)(
797 | self,
798 | leaf,
799 | fn,
800 | Xts,
801 | y,
802 | fit_first=False,
803 | fn_params=fn_params_steps[leaf.name],
804 | )
805 | for leaf in self.leaves_
806 | )
807 | )
808 |
809 | Xout = dict(zip(leaf_names, outputs))
810 |
811 | return self._match_input_format(X, Xout)
812 |
813 | @available_if(_leaf_estimators_have("transform"))
814 | def fit_transform(self, X, y=None, **fit_params):
815 | """
816 | Fit the model and transform with the final estimator.
817 |
818 | Fits all the transformers one after the other and transform the
819 | data. Then uses `fit_transform` on transformed data with the final
820 | estimator.
821 |
822 | Parameters
823 | ----------
824 | X : iterable
825 | Training data. Must fulfill input requirements of first step of the
826 | DAG.
827 | y : iterable, default=None
828 | Training targets. Must fulfill label requirements for all steps of
829 | the DAG.
830 | **fit_params : dict of string -> object
831 | Parameters passed to the ``fit`` method of each step, where
832 | each parameter name is prefixed such that parameter ``p`` for step
833 | ``s`` has key ``s__p``.
834 |
835 | Returns
836 | -------
837 | Xt : ndarray of shape (n_samples, n_transformed_features)
838 | Transformed samples.
839 | """
840 | return self._fit_execute("transform", X, y, **fit_params)
841 |
842 | @available_if(_leaf_estimators_have("transform"))
843 | def transform(self, X):
844 | """
845 | Transform the data, and apply `transform` with the final estimator.
846 |
847 | Call `transform` of each transformer in the DAG. The transformed
848 | data are finally passed to the final estimator that calls
849 | `transform` method. Only valid if the final estimator
850 | implements `transform`.
851 |
852 | This also works where final estimator is `None` in which case all prior
853 | transformations are applied.
854 |
855 | Parameters
856 | ----------
857 | X : iterable
858 | Data to transform. Must fulfill input requirements of first step
859 | of the DAG.
860 |
861 | Returns
862 | -------
863 | Xt : ndarray of shape (n_samples, n_transformed_features)
864 | Transformed data.
865 | """
866 | return self._execute("transform", X)
867 |
868 | @available_if(_leaf_estimators_have("predict"))
869 | def fit_predict(self, X, y=None, **fit_params):
870 | """
871 | Transform the data, and apply `fit_predict` with the final estimator.
872 |
873 | Call `fit_transform` of each transformer in the DAG. The transformed data are
874 | finally passed to the final estimator that calls `fit_predict` method. Only
875 | valid if the final estimators implement `fit_predict`.
876 |
877 | Parameters
878 | ----------
879 | X : iterable
880 | Training data. Must fulfill input requirements of first step of
881 | the DAG.
882 | y : iterable, default=None
883 | Training targets. Must fulfill label requirements for all steps
884 | of the DAG.
885 | **fit_params : dict of string -> object
886 | Parameters passed to the ``fit`` method of each step, where
887 | each parameter name is prefixed such that parameter ``p`` for step
888 | ``s`` has key ``s__p``.
889 |
890 | Returns
891 | -------
892 | y_pred : ndarray
893 | Result of calling `fit_predict` on the final estimator.
894 | """
895 | return self._fit_execute("predict", X, y, **fit_params)
896 |
897 | @available_if(_leaf_estimators_have("predict"))
898 | def predict(self, X, **predict_params):
899 | """
900 | Transform the data, and apply `predict` with the final estimator.
901 |
902 | Call `transform` of each transformer in the DAG. The transformed
903 | data are finally passed to the final estimator that calls `predict`
904 | method. Only valid if the final estimators implement `predict`.
905 |
906 | Parameters
907 | ----------
908 | X : iterable
909 | Data to predict on. Must fulfill input requirements of first step
910 | of the DAG.
911 | **predict_params : dict of string -> object
912 | Parameters to the ``predict`` called at the end of all
913 | transformations in the DAG. Note that while this may be
914 | used to return uncertainties from some models with return_std
915 | or return_cov, uncertainties that are generated by the
916 | transformations in the DAG are not propagated to the
917 | final estimator.
918 |
919 | Returns
920 | -------
921 | y_pred : ndarray
922 | Result of calling `predict` on the final estimator.
923 | """
924 | return self._execute("predict", X, **predict_params)
925 |
926 | @available_if(_leaf_estimators_have("predict_proba"))
927 | def predict_proba(self, X, **predict_proba_params):
928 | """
929 | Transform the data, and apply `predict_proba` with the final estimator.
930 |
931 | Call `transform` of each transformer in the DAG. The transformed
932 | data are finally passed to the final estimator that calls
933 | `predict_proba` method. Only valid if the final estimators implement
934 | `predict_proba`.
935 |
936 | Parameters
937 | ----------
938 | X : iterable
939 | Data to predict on. Must fulfill input requirements of first step
940 | of the DAG.
941 | **predict_proba_params : dict of string -> object
942 | Parameters to the `predict_proba` called at the end of all
943 | transformations in the DAG.
944 |
945 | Returns
946 | -------
947 | y_proba : ndarray of shape (n_samples, n_classes)
948 | Result of calling `predict_proba` on the final estimator.
949 | """
950 | return self._execute("predict_proba", X, **predict_proba_params)
951 |
952 | @available_if(_leaf_estimators_have("decision_function"))
953 | def decision_function(self, X):
954 | """
955 | Transform the data, and apply `decision_function` with the final estimator.
956 |
957 | Call `transform` of each transformer in the DAG. The transformed
958 | data are finally passed to the final estimator that calls
959 | `decision_function` method. Only valid if the final estimators
960 | implement `decision_function`.
961 |
962 | Parameters
963 | ----------
964 | X : iterable
965 | Data to predict on. Must fulfill input requirements of first step
966 | of the DAG.
967 |
968 | Returns
969 | -------
970 | y_score : ndarray of shape (n_samples, n_classes)
971 | Result of calling `decision_function` on the final estimator.
972 | """
973 | return self._execute("decision_function", X)
974 |
975 | @available_if(_leaf_estimators_have("score_samples"))
976 | def score_samples(self, X):
977 | """
978 | Transform the data, and apply `score_samples` with the final estimator.
979 |
980 | Call `transform` of each transformer in the DAG. The transformed
981 | data are finally passed to the final estimator that calls
982 | `score_samples` method. Only valid if the final estimators implement
983 | `score_samples`.
984 |
985 | Parameters
986 | ----------
987 | X : iterable
988 | Data to predict on. Must fulfill input requirements of first step
989 | of the DAG.
990 |
991 | Returns
992 | -------
993 | y_score : ndarray of shape (n_samples,)
994 | Result of calling `score_samples` on the final estimator.
995 | """
996 | return self._execute("score_samples", X)
997 |
998 | @available_if(_leaf_estimators_have("score"))
999 | def score(self, X, y=None, **score_params):
1000 | """
1001 | Transform the data, and apply `score` with the final estimator.
1002 |
1003 | Call `transform` of each transformer in the DAG. The transformed
1004 | data are finally passed to the final estimator that calls
1005 | `score` method. Only valid if the final estimators implement `score`.
1006 |
1007 | Parameters
1008 | ----------
1009 | X : iterable
1010 | Data to predict on. Must fulfill input requirements of first step
1011 | of the DAG.
1012 | y : iterable, default=None
1013 | Targets used for scoring. Must fulfill label requirements for all
1014 | steps of the DAG.
1015 | sample_weight : array-like, default=None
1016 | If not None, this argument is passed as ``sample_weight`` keyword
1017 | argument to the ``score`` method of the final estimator.
1018 |
1019 | Returns
1020 | -------
1021 | score : float
1022 | Result of calling `score` on the final estimator.
1023 | """
1024 | return self._execute("score", X, y, **score_params)
1025 |
1026 | @available_if(_leaf_estimators_have("predict_log_proba"))
1027 | def predict_log_proba(self, X, **predict_log_proba_params):
1028 | """
1029 | Transform the data, and apply `predict_log_proba` with the final estimator.
1030 |
1031 | Call `transform` of each transformer in the DAG. The transformed
1032 | data are finally passed to the final estimator that calls
1033 | `predict_log_proba` method. Only valid if the final estimator
1034 | implements `predict_log_proba`.
1035 |
1036 | Parameters
1037 | ----------
1038 | X : iterable
1039 | Data to predict on. Must fulfill input requirements of first step
1040 | of the DAG.
1041 | **predict_log_proba_params : dict of string -> object
1042 | Parameters to the ``predict_log_proba`` called at the end of all
1043 | transformations in the DAG.
1044 |
1045 | Returns
1046 | -------
1047 | y_log_proba : ndarray of shape (n_samples, n_classes)
1048 | Result of calling `predict_log_proba` on the final estimator.
1049 | """
1050 | return self._execute("predict_log_proba", X, **predict_log_proba_params)
1051 |
1052 | def _check_fit_params(self, **fit_params):
1053 | fit_params_steps = {
1054 | name: {} for (name, step) in self.steps_ if step is not None
1055 | }
1056 | for pname, pval in fit_params.items():
1057 | if pval is None:
1058 | continue
1059 |
1060 | if "__" not in pname:
1061 | raise ValueError(
1062 | f"DAG.fit does not accept the {pname} parameter. "
1063 | "You can pass parameters to specific steps of your "
1064 | "DAG using the stepname__parameter format, e.g. "
1065 | "`DAG.fit(X, y, logisticregression__sample_weight"
1066 | "=sample_weight)`."
1067 | )
1068 | step, param = pname.split("__", 1)
1069 | fit_params_steps[step][param] = pval
1070 | return fit_params_steps
1071 |
1072 | def _check_fn_params(self, **fn_params):
1073 | global_params = {}
1074 | fn_params_steps = {name: {} for (name, step) in self.steps_ if step is not None}
1075 | for pname, pval in fn_params.items():
1076 | if pval is None:
1077 | continue
1078 |
1079 | if "__" not in pname:
1080 | global_params[pname] = pval
1081 | else:
1082 | step, param = pname.split("__", 1)
1083 | fn_params_steps[step][param] = pval
1084 |
1085 | for step in fn_params_steps:
1086 | fn_params_steps[step].update(global_params)
1087 |
1088 | return fn_params_steps
1089 |
1090 | def _validate_graph(self):
1091 | if len(self.graph_) == 0:
1092 | raise ValueError("DAG has no nodes.")
1093 |
1094 | for i, (name, est) in enumerate(self.steps_):
1095 | step = self.graph_.nodes[name]["step"]
1096 | step.index = i
1097 |
1098 | # validate names
1099 | self._validate_names([name for (name, step) in self.steps_])
1100 |
1101 | # validate transformers
1102 | for step in self.roots_ + self.branches_:
1103 | if step in self.leaves_:
1104 | # This will get validated later
1105 | continue
1106 |
1107 | est = step.estimator
1108 | # Unlike pipelines we also allow predictors to be used as a transformer, to support
1109 | # model stacking.
1110 | if (
1111 | not _is_passthrough(est)
1112 | and not _is_transformer(est)
1113 | and not _is_predictor(est)
1114 | ):
1115 | raise TypeError(
1116 | "All intermediate steps should be "
1117 | "transformers and implement fit and transform "
1118 | "or be the string 'passthrough' "
1119 | f"'{est}' (type {type(est)}) doesn't"
1120 | )
1121 |
1122 | # Validate final estimator(s)
1123 | for step in self.leaves_:
1124 | est = step.estimator
1125 | if not _is_passthrough(est) and not hasattr(est, "fit"):
1126 | raise TypeError(
1127 | "Leaf nodes of a DAG should implement fit "
1128 | "or be the string 'passthrough'. "
1129 | f"'{est}' (type {type(est)}) doesn't"
1130 | )
1131 |
1132 | @property
1133 | def graph_(self):
1134 | if not hasattr(self, "_graph"):
1135 | # Read-only view of the graph. We should not modify
1136 | # the original graph.
1137 | self._graph = self.graph.copy(as_view=True)
1138 |
1139 | return self._graph
1140 |
1141 | @property
1142 | def leaves_(self):
1143 | if not hasattr(self, "_leaves"):
1144 | self._leaves = [node for node in self.nodes_ if node.is_leaf]
1145 |
1146 | return self._leaves
1147 |
1148 | @property
1149 | def branches_(self):
1150 | if not hasattr(self, "_branches"):
1151 | self._branches = [
1152 | node for node in self.nodes_ if not node.is_leaf and not node.is_root
1153 | ]
1154 |
1155 | return self._branches
1156 |
1157 | @property
1158 | def roots_(self):
1159 | if not hasattr(self, "_roots"):
1160 | self._roots = [node for node in self.nodes_ if node.is_root]
1161 |
1162 | return self._roots
1163 |
1164 | @property
1165 | def nodes_(self):
1166 | if not hasattr(self, "_nodes"):
1167 | self._nodes = []
1168 | for name, estimator in self.steps_:
1169 | step = self.graph_.nodes[name]["step"]
1170 | if self.graph_.out_degree(name) == 0:
1171 | step.is_leaf = True
1172 | if self.graph_.in_degree(name) == 0:
1173 | step.is_root = True
1174 | self._nodes.append(step)
1175 |
1176 | return self._nodes
1177 |
1178 | @property
1179 | def steps_(self):
1180 |         "Return a list of (name, estimator) tuples, to conform with the Pipeline interface."
1181 | if not hasattr(self, "_steps"):
1182 | self._steps = [
1183 | (node, self.graph_.nodes[node]["step"].estimator)
1184 | for node in nx.lexicographical_topological_sort(self.graph_)
1185 | ]
1186 |
1187 | return self._steps
1188 |
1189 | def join(self, other, edges, **kwargs):
1190 | """
1191 | Create a new DAG by joining this DAG to another one, according to the edges
1192 | specified.
1193 |
1194 | Parameters
1195 | ----------
1196 |
1197 | other : :class:`skdag.dag.DAG`
1198 | The other DAG to connect to.
1199 |         edges : iterable of (str, str) or (str, str, index-like)
1200 |             ``(u, v)`` edges that connect the two DAGs. ``u`` and ``v`` should be the
1201 |             names of steps in the first and second DAG respectively. Optionally a third
1202 |             element may be included to specify which columns to pass along the edge.
1203 | **kwargs : keyword params
1204 | Any other parameters to pass to the new DAG's constructor.
1205 |
1206 | Returns
1207 | -------
1208 | dag : :class:`skdag.DAG`
1209 | A new DAG, containing a copy of each of the input DAGs, joined by the
1210 | specified edges. Note that the original input dags are unmodified.
1211 |
1212 | Examples
1213 | --------
1214 |
1215 | >>> from sklearn.decomposition import PCA
1216 | >>> from sklearn.impute import SimpleImputer
1217 | >>> from sklearn.linear_model import LogisticRegression
1218 | >>> from sklearn.calibration import CalibratedClassifierCV
1219 | >>> from skdag.dag import DAGBuilder
1220 | >>> dag1 = (
1221 | ... DAGBuilder()
1222 | ... .add_step("impute", SimpleImputer())
1223 | ... .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)})
1224 | ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)})
1225 | ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"])
1226 | ... .make_dag()
1227 | ... )
1228 | >>> print(dag1.draw().strip())
1229 | o impute
1230 | |\\
1231 | o o blood,vitals
1232 | |/
1233 | o lr
1234 | >>> dag2 = (
1235 | ... DAGBuilder()
1236 | ... .add_step(
1237 | ... "calib",
1238 | ... CalibratedClassifierCV(LogisticRegression(random_state=0), cv=5),
1239 | ... )
1240 | ... .make_dag()
1241 | ... )
1242 | >>> print(dag2.draw().strip())
1243 | o calib
1244 | >>> dag3 = dag1.join(dag2, edges=[("blood", "calib"), ("vitals", "calib")])
1245 | >>> print(dag3.draw().strip())
1246 | o impute
1247 | |\\
1248 | o o blood,vitals
1249 | |x|
1250 | o o calib,lr
1251 | """
1252 | if set(self.step_names) & set(other.step_names):
1253 | raise ValueError("DAGs with overlapping step names cannot be combined.")
1254 |
1255 | newgraph = deepcopy(self.graph_).copy()
1256 | for edge in edges:
1257 | if len(edge) == 2:
1258 | u, v, idx = *edge, None
1259 | else:
1260 | u, v, idx = edge
1261 |
1262 | if u not in self.graph_:
1263 | raise KeyError(u)
1264 | if v not in other.graph_:
1265 | raise KeyError(v)
1266 |
1267 | # source node can no longer be a leaf
1268 | ustep = newgraph.nodes[u]["step"]
1269 | if ustep.is_leaf:
1270 | ustep.is_leaf = False
1271 |
1272 | vnode = other.graph_.nodes[v]
1273 | old_step = vnode["step"]
1274 | vstep = DAGStep(
1275 | name=old_step.name,
1276 | estimator=old_step.estimator,
1277 | deps=old_step.deps,
1278 | dataframe_columns=old_step.dataframe_columns,
1279 | axis=old_step.axis,
1280 | )
1281 |
1282 | if u not in vstep.deps:
1283 | vstep.deps[u] = idx
1284 |
1285 | vnode["step"] = vstep
1286 |
1287 | newgraph.add_node(v, **vnode)
1288 | newgraph.add_edge(u, v)
1289 |
1290 | return DAG(newgraph, **kwargs)
1291 |
1292 | def draw(
1293 | self, filename=None, style=None, detailed=False, format=None, layout="dot"
1294 | ):
1295 | """
1296 | Render a graphical view of the DAG.
1297 |
1298 | By default the rendered file will be returned as a string. However if an output
1299 | file is provided then the output will be saved to file.
1300 |
1301 | Parameters
1302 | ----------
1303 |
1304 |         filename : str
1305 |             The file to write the image to. If None, the rendered output is
1306 |             returned instead of being written to file.
1307 | style : str, optional, choice of ['light', 'dark']
1308 | Draw the image in light or dark mode.
1309 | detailed : bool, default = False
1310 | If True, show extra details in the node labels such as the estimator
1311 | signature.
1312 | format : str, choice of ['svg', 'png', 'jpg', 'txt']
1313 |             The rendering format to use. May be omitted if the format can be inferred
1314 | from the filename.
1315 | layout : str, default = 'dot'
1316 | The program to use for generating a graph layout.
1317 |
1318 | See Also
1319 | --------
1320 |
1321 | :meth:`skdag.dag.DAG.show`, for use in interactive notebooks.
1322 |
1323 | Returns
1324 | -------
1325 |
1326 | output : str, bytes or None
1327 | If a filename is provided the output is written to file and `None` is
1328 | returned. Otherwise, the output is returned as a string (for textual formats
1329 | like ascii or svg) or bytes.
1330 | """
1331 | if filename is None and format is None:
1332 | try:
1333 | from IPython import get_ipython
1334 |
1335 | rich = type(get_ipython()).__name__ == "ZMQInteractiveShell"
1336 | except (ModuleNotFoundError, NameError):
1337 | rich = False
1338 |
1339 | format = "svg" if rich else "txt"
1340 |
1341 | if format is None:
1342 | format = filename.split(".")[-1]
1343 |
1344 | if format not in ["svg", "png", "jpg", "txt"]:
1345 | raise ValueError(f"Unsupported file format '{format}'")
1346 |
1347 | render = DAGRenderer(self.graph_, detailed=detailed, style=style).draw(
1348 | format=format, layout=layout
1349 | )
1350 | if filename is None:
1351 | return render
1352 | else:
1353 | mode = "wb" if isinstance(render, bytes) else "w"
1354 | with open(filename, mode) as fp:
1355 | fp.write(render)
1356 |
1357 | def show(self, style=None, detailed=False, format=None, layout="dot"):
1358 | """
1359 | Display a graphical representation of the DAG in an interactive notebook
1360 | environment.
1361 |
1362 | DAGs will be shown when displayed in a notebook, but calling this method
1363 |         directly allows more options to be passed to customise the appearance.
1364 |
1365 |         Arguments are as for :meth:`.draw`.
1366 |
1367 | Returns
1368 | -------
1369 |
1370 | ``None``
1371 |
1372 | See Also
1373 | --------
1374 |
1375 | :meth:`skdag.DAG.draw`
1376 | """
1377 | if format is None:
1378 | format = "svg" if _in_notebook() else "txt"
1379 |
1380 | data = self.draw(style=style, detailed=detailed, format=format, layout=layout)
1381 | if format == "svg":
1382 | from IPython.display import SVG, display
1383 |
1384 | display(SVG(data))
1385 | elif format == "txt":
1386 | print(data)
1387 | elif format in ("jpg", "png"):
1388 | from IPython.display import Image, display
1389 |
1390 | display(Image(data))
1391 | else:
1392 | raise ValueError(f"'{format}' format not supported.")
1393 |
1394 | def _repr_svg_(self):
1395 | return self.draw(format="svg")
1396 |
1397 | def _repr_png_(self):
1398 | return self.draw(format="png")
1399 |
1400 | def _repr_jpeg_(self):
1401 | return self.draw(format="jpg")
1402 |
1403 | def _repr_html_(self):
1404 | return self.draw(format="svg")
1405 |
1406 | def _repr_pretty_(self, p, cycle):
1407 | if cycle:
1408 | p.text(repr(self))
1409 | else:
1410 | p.text(str(self))
1411 |
1412 | def _repr_mimebundle_(self, **kwargs):
1413 | # Don't render yet...
1414 | renderers = {
1415 | "image/svg+xml": self._repr_svg_,
1416 | "image/png": self._repr_png_,
1417 | "image/jpeg": self._repr_jpeg_,
1418 | "text/plain": self.__str__,
1419 | "text/html": self._repr_html_,
1420 | }
1421 |
1422 | include = kwargs.get("include")
1423 | if include:
1424 | renderers = {k: v for k, v in renderers.items() if k in include}
1425 |
1426 | exclude = kwargs.get("exclude")
1427 | if exclude:
1428 | renderers = {k: v for k, v in renderers.items() if k not in exclude}
1429 |
1430 | # Now render any remaining options.
1431 | return {k: v() for k, v in renderers.items()}
1432 |
1433 | @property
1434 | def named_steps(self):
1435 | """
1436 | Access the steps by name.
1437 |
1438 | Read-only attribute to access any step by given name.
1439 |         Keys are step names and values are the step objects.
1440 | """
1441 | # Use Bunch object to improve autocomplete
1442 | return Bunch(**dict(self.steps_))
1443 |
1444 | @property
1445 | def step_names(self):
1446 | return list(self.graph_.nodes)
1447 |
1448 | @property
1449 | def edges(self):
1450 | return self.graph_.edges
1451 |
1452 | def _get_leaf_attr(self, attr):
1453 | if len(self.leaves_) == 1:
1454 | return getattr(self.leaves_[0].estimator, attr)
1455 | else:
1456 | return Bunch(
1457 | **{leaf.name: getattr(leaf.estimator, attr) for leaf in self.leaves_}
1458 | )
1459 |
1460 | @property
1461 | def _estimator_type(self):
1462 | return self._get_leaf_attr("_estimator_type")
1463 |
1464 | @property
1465 | def classes_(self):
1466 |         """The class labels. Only exist if the leaf steps are classifiers."""
1467 | return self._get_leaf_attr("classes_")
1468 |
1469 | def __sklearn_is_fitted__(self):
1470 | """Indicate whether DAG has been fit."""
1471 | try:
1472 |             # check if the leaf steps of the DAG are fitted. We only check
1473 |             # the leaves since if they are fit, all upstream steps must
1474 |             # also be fit. This is faster than checking every step of the
1475 |             # DAG.
1476 |             for leaf in self.leaves_:
1477 | check_is_fitted(leaf.estimator)
1478 | return True
1479 | except NotFittedError:
1480 | return False
1481 |
1482 | def _more_tags(self):
1483 | tags = {}
1484 |
1485 | # We assume the DAG can handle NaN if *all* the steps can.
1486 | tags["allow_nan"] = all(
1487 | _safe_tags(node.estimator, "allow_nan") for node in self.nodes_
1488 | )
1489 |
1490 | # Check if all *root* nodes expect pairwise input.
1491 | tags["pairwise"] = all(
1492 | _safe_tags(root.estimator, "pairwise") for root in self.roots_
1493 | )
1494 |
1495 |         # Check if all *leaf* nodes support multioutput.
1496 | tags["multioutput"] = all(
1497 | _safe_tags(leaf.estimator, "multioutput") for leaf in self.leaves_
1498 | )
1499 |
1500 | return tags
1501 |
--------------------------------------------------------------------------------
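
The ``stepname__parameter`` routing implemented by ``_check_fit_params`` above follows
the same convention as scikit-learn's ``Pipeline``. A minimal usage sketch, assuming
only the public builder API shown in the docstrings (the step names and data here are
illustrative):

    import numpy as np
    from skdag import DAGBuilder
    from sklearn.impute import SimpleImputer
    from sklearn.linear_model import LogisticRegression

    rng = np.random.default_rng(0)
    X = rng.random((20, 4))
    y = np.arange(20) % 2

    dag = (
        DAGBuilder()
        .add_step("impute", SimpleImputer())
        .add_step("lr", LogisticRegression(), deps=["impute"])
        .make_dag()
    )

    # "lr__sample_weight" is split on the first "__" and routed to the fit
    # method of the "lr" step; an unprefixed keyword would raise a ValueError.
    dag.fit(X, y, lr__sample_weight=np.ones(len(y)))
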
/skdag/dag/_render.py:
--------------------------------------------------------------------------------
1 | import html
2 | from typing import Iterable
3 |
4 | import black
6 | import networkx as nx
7 | import stackeddag.core as sd
8 | from skdag.dag._utils import _is_passthrough
9 |
10 | __all__ = ["DAGRenderer"]
11 |
12 |
13 | class DAGRenderer:
14 | _EMPTY = "[empty]"
15 | _STYLES = {
16 | "light": {
17 | "node__color": "black",
18 | "node__fontcolor": "black",
19 | "edge__color": "black",
20 | "graph__bgcolor": "white",
21 | },
22 | "dark": {
23 | "node__color": "white",
24 | "node__fontcolor": "white",
25 | "edge__color": "white",
26 | "graph__bgcolor": "none",
27 | },
28 | }
29 |
30 | def __init__(self, dag, detailed=False, style=None):
31 | self.dag = dag
32 | if style is not None and style not in self._STYLES:
33 | raise ValueError(f"Unknown style: '{style}'.")
34 | self.style = style
35 | self.agraph = self.to_agraph(detailed)
36 |
37 | def _get_node_shape(self, estimator):
38 | if _is_passthrough(estimator):
39 | return "hexagon"
40 | if any(hasattr(estimator, attr) for attr in ["fit_predict", "predict"]):
41 | return "ellipse"
42 | return "box"
43 |
44 | def _is_empty(self):
45 | return len(self.dag) == 0
46 |
47 | def to_agraph(self, detailed):
48 | G = self.dag
49 | if self._is_empty():
50 | G = nx.DiGraph()
51 | G.add_node(
52 | self._EMPTY,
53 | shape="box",
54 | )
55 |
56 | try:
57 | A = nx.nx_agraph.to_agraph(G)
58 | except (ImportError, ModuleNotFoundError) as err: # pragma: no cover
59 | raise ImportError(
60 | "DAG visualisation requires pygraphviz to be installed. "
61 | "See http://pygraphviz.github.io/ for guidance."
62 | ) from err
63 |
64 | A.graph_attr["rankdir"] = "LR"
65 | if self.style is not None:
66 | A.graph_attr.update(
67 | {
68 | key.replace("graph__", ""): val
69 | for key, val in self._STYLES[self.style].items()
70 | if key.startswith("graph__")
71 | }
72 | )
73 |
74 | mode = black.FileMode()
75 | mode.line_length = 16
76 |
77 | for v in G.nodes:
78 | anode = A.get_node(v)
79 | gnode = G.nodes[v]
80 | anode.attr["fontname"] = gnode.get("fontname", "SANS")
81 | if self.style is not None:
82 | anode.attr.update(
83 | {
84 | key.replace("node__", ""): val
85 | for key, val in self._STYLES[self.style].items()
86 | if key.startswith("node__")
87 | }
88 | )
89 | if "step" in gnode:
90 | estimator = gnode["step"].estimator
91 | anode.attr["tooltip"] = repr(estimator)
92 |
93 | if detailed:
94 |                     estimator_str = html.escape(
95 |                         black.format_str(repr(estimator), mode=mode)
96 |                     ).replace("\n", '<br align="left"/>')
97 |
98 |                     anode.attr["label"] = (
99 |                         '<<table border="0">'
100 |                         f'<tr><td><b>{v}</b></td></tr>'
101 |                         f'<tr><td><font point-size="8.0">{estimator_str}</font></td></tr>'
102 |                         "</table>>"
103 |                     )
104 |
105 | anode.attr["shape"] = gnode.get(
106 | "shape", self._get_node_shape(estimator)
107 | )
108 | if gnode["step"].is_fitted:
109 | anode.attr["peripheries"] = 2
110 |
111 | for u, v in G.edges:
112 | aedge = A.get_edge(u, v)
113 | if self.style is not None:
114 | aedge.attr.update(
115 | {
116 | key.replace("edge__", ""): val
117 | for key, val in self._STYLES[self.style].items()
118 | if key.startswith("edge__")
119 | }
120 | )
121 | cols = G.nodes[v]["step"].deps[u]
122 | if cols:
123 | if isinstance(cols, Iterable):
124 | cols = list(cols)
125 |
126 | if len(cols) > 5:
127 | colrepr = f"[{repr(cols[0])}, ..., {repr(cols[-1])}]"
128 | else:
129 | colrepr = repr(cols)
130 | elif callable(cols):
131 | selector = cols
132 | cols = {}
133 | for attr in ["pattern", "dtype_include", "dtype_exclude"]:
134 | if hasattr(selector, attr):
135 | val = getattr(selector, attr)
136 | if val is not None:
137 | cols[attr] = val
138 | if cols:
139 | selrepr = ", ".join(
140 | f"{key}={repr(val)}" for key, val in cols.items()
141 | )
142 | colrepr = f"column_selector({selrepr})"
143 | else:
144 | colrepr = f"{selector.__name__}()"
145 | else:
146 | colrepr = repr(cols)
147 |
148 | aedge.attr.update(
149 | {"label": colrepr, "fontsize": "8pt", "fontname": "SANS"}
150 | )
151 |
152 | A.layout()
153 | return A
154 |
155 | def draw(self, format="svg", layout="dot"):
156 | A = self.agraph
157 |
158 | if format == "txt":
159 | if self._is_empty():
160 | return self._EMPTY
161 |
162 | # Edge case: single node, no edges.
163 | if len(A.edges()) == 0:
164 | return f"o {next(A.nodes_iter())}"
165 |
166 | la = []
167 | for node in A.nodes():
168 | la.append((node, node))
169 |
170 | ed = []
171 | for src, dst in A.edges():
172 | ed.append((src, [dst]))
173 |
174 | return sd.edgesToText(sd.mkLabels(la), sd.mkEdges(ed)).strip()
175 |
176 | data = A.draw(format=format, prog=layout)
177 |
178 | if format == "svg":
179 | return data.decode(self.agraph.encoding)
180 | return data
181 |
--------------------------------------------------------------------------------
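
``DAGRenderer`` is normally reached via ``DAG.draw``/``DAG.show`` rather than
instantiated directly. A short sketch of those entry points (pygraphviz is required
for the image formats; the output filename is illustrative):

    from skdag import DAGBuilder
    from sklearn.decomposition import PCA
    from sklearn.linear_model import LogisticRegression

    dag = (
        DAGBuilder()
        .add_step("pca", PCA())
        .add_step("lr", LogisticRegression(), deps=["pca"])
        .make_dag()
    )

    print(dag.draw(format="txt"))  # ascii rendering via stackeddag
    dag.draw(filename="dag.svg", style="dark", detailed=True)  # format inferred
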
/skdag/dag/_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import sparse
3 |
4 | try:
5 | import pandas as pd
6 | except ImportError:
7 | pd = None
8 |
9 |
10 | def _is_passthrough(estimator):
11 | return estimator is None or estimator == "passthrough"
12 |
13 |
14 | def _is_transformer(estimator):
15 | return (
16 | hasattr(estimator, "fit") or hasattr(estimator, "fit_transform")
17 | ) and hasattr(estimator, "transform")
18 |
19 |
20 | def _is_predictor(estimator):
21 | return (hasattr(estimator, "fit") or hasattr(estimator, "fit_predict")) and hasattr(
22 | estimator, "predict"
23 | )
24 |
25 |
26 | def _in_notebook():
27 | try:
28 | from IPython import get_ipython
29 |
30 | if "IPKernelApp" not in get_ipython().config: # pragma: no cover
31 | return False
32 | except ImportError:
33 | return False
34 | except AttributeError:
35 | return False
36 | return True
37 |
38 |
39 | def _stack(Xs, axis=0):
40 | """
41 |     Where an estimator has multiple upstream dependencies, this function defines
42 |     the strategy for merging the upstream outputs into a single input. ``axis=0`` is
43 | equivalent to a vstack (combination of samples) and ``axis=1`` is equivalent to a
44 | hstack (combination of features). Higher axes are only supported for non-sparse
45 | data sources.
46 | """
47 | if any(sparse.issparse(x) for x in Xs):
48 | if -2 <= axis < 2:
49 | axis = axis % 2
50 | else:
51 | raise NotImplementedError(
52 |                 "Stacking is not supported for sparse inputs on axis > 1."
53 | )
54 |
55 | if axis == 0:
56 | Xs = sparse.vstack(Xs).tocsr()
57 | elif axis == 1:
58 | Xs = sparse.hstack(Xs).tocsr()
59 | elif pd and all(_is_pandas(x) for x in Xs):
60 | Xs = pd.concat(Xs, axis=axis)
61 | else:
62 | if axis == 1:
63 | Xs = np.hstack(Xs)
64 | else:
65 | Xs = np.stack(Xs, axis=axis)
66 |
67 | return Xs
68 |
69 |
70 | def _is_pandas(X):
71 | "Check if X is a DataFrame or Series"
72 | return hasattr(X, "iloc")
73 |
74 |
75 | def _get_feature_names(estimator):
76 | try:
77 | feature_names = estimator.get_feature_names_out()
78 | except AttributeError:
79 | try:
80 | feature_names = estimator.get_feature_names()
81 | except AttributeError:
82 | feature_names = None
83 |
84 | return feature_names
85 |
86 |
87 | def _format_output(X, input, node):
88 | if node.dataframe_columns is None or pd is None or _is_pandas(X):
89 | return X
90 |
91 | outshape = np.asarray(X).shape
92 | outdim = len(outshape)
93 | if outdim > 2 or outdim < 1:
94 | return X
95 |
96 | if node.dataframe_columns == "infer":
97 | if outdim == 1:
98 | columns = [node.name]
99 | else:
100 | feature_names = _get_feature_names(node.estimator)
101 | if feature_names is None:
102 | feature_names = range(outshape[1])
103 |
104 | columns = [f"{node.name}__{f}" for f in feature_names]
105 | else:
106 | columns = node.dataframe_columns
107 |
108 | if hasattr(input, "index"):
109 | index = input.index
110 | else:
111 | index = None
112 |
113 | if outdim == 2:
114 | df = pd.DataFrame(X, columns=columns, index=index)
115 | else:
116 | df = pd.Series(X, name=columns[0], index=index)
117 |
118 | return df
119 |
--------------------------------------------------------------------------------
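
The merge strategy implemented by ``_stack`` above is what a step with multiple
upstream dependencies receives as its input. A small sketch of the dense and sparse
paths (these are private helpers, so subject to change):

    import numpy as np
    from scipy import sparse
    from skdag.dag._utils import _stack

    a = np.ones((3, 2))
    b = np.zeros((3, 4))
    assert _stack([a, b], axis=1).shape == (3, 6)  # hstack: combine features

    s = sparse.csr_matrix(a)
    v = _stack([s, s], axis=0)  # vstack: combine samples, stays sparse
    assert v.shape == (6, 2) and sparse.issparse(v)
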
/skdag/dag/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scikit-learn-contrib/skdag/a04f75f58b3126b98d304a5a03b49e3861108eff/skdag/dag/tests/__init__.py
--------------------------------------------------------------------------------
/skdag/dag/tests/test_builder.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from skdag import DAG, DAGBuilder
3 | from skdag.dag.tests.utils import Mult, Transf
4 | from skdag.exceptions import DAGError
5 | from sklearn.pipeline import Pipeline
6 |
7 |
8 | def test_builder_basics():
9 | builder = DAGBuilder()
10 | builder.add_step("tr", Transf())
11 | builder.add_step("ident", "passthrough", deps=["tr"])
12 | builder.add_step("est", Mult(), deps=["ident"])
13 |
14 | dag = builder.make_dag()
15 | assert isinstance(dag, DAG)
16 | assert dag._repr_html_() == builder._repr_html_()
17 |
18 | with pytest.raises(ValueError):
19 | builder.add_step("est2", Mult(), deps=["missing"])
20 |
21 | with pytest.raises(ValueError):
22 | builder.add_step("est2", Mult(), deps=[1])
23 |
24 | with pytest.raises(KeyError):
25 | # Step names *must* be strings.
26 | builder.add_step(2, Mult())
27 |
28 | with pytest.raises(KeyError):
29 | # Step names *must* be unique.
30 | builder.add_step("est", Mult())
31 |
32 | with pytest.raises(DAGError):
33 | # Cycles not allowed.
34 | builder.graph.add_edge("est", "ident")
35 | builder.make_dag()
36 |
37 |
38 | def test_pipeline():
39 | steps = [("tr", Transf()), ("ident", "passthrough"), ("est", Mult())]
40 |
41 | builder = DAGBuilder()
42 | prev = None
43 | for name, est in steps:
44 | builder.add_step(name, est, deps=[prev] if prev else None)
45 | prev = name
46 |
47 | dag = builder.make_dag()
48 |
49 | assert dag.steps_ == steps
50 |
--------------------------------------------------------------------------------
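
The step-by-step chain built in ``test_pipeline`` is equivalent to the
``from_pipeline`` shortcut exercised throughout ``test_dag.py`` below; a sketch of
the same DAG constructed both ways:

    from skdag import DAGBuilder
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import StandardScaler

    steps = [("scale", StandardScaler()), ("lr", LogisticRegression())]

    chained = DAGBuilder()
    prev = None
    for name, est in steps:
        chained.add_step(name, est, deps=[prev] if prev else None)
        prev = name

    shortcut = DAGBuilder().from_pipeline(steps).make_dag()
    assert chained.make_dag().steps_ == shortcut.steps_
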
/skdag/dag/tests/test_dag.py:
--------------------------------------------------------------------------------
1 | """
2 | Test the DAG module.
3 | """
4 | import re
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import pytest
9 | from skdag import DAGBuilder
10 | from skdag.dag.tests.utils import FitParamT, Mult, NoFit, NoTrans, Transf
11 | from sklearn import datasets
12 | from sklearn.base import clone
13 | from sklearn.compose import make_column_selector
14 | from sklearn.decomposition import PCA
15 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
16 | from sklearn.feature_selection import SelectKBest, f_classif
17 | from sklearn.impute import SimpleImputer
18 | from sklearn.linear_model import LinearRegression, LogisticRegression
19 | from sklearn.pipeline import Pipeline
20 | from sklearn.preprocessing import StandardScaler
21 | from sklearn.svm import SVC
22 | from sklearn.utils._testing import assert_array_almost_equal
23 | from sklearn.utils.estimator_checks import parametrize_with_checks
24 |
25 | iris = datasets.load_iris()
26 | cancer = datasets.load_breast_cancer()
27 |
28 | JUNK_FOOD_DOCS = (
29 | "the pizza pizza beer copyright",
30 | "the pizza burger beer copyright",
31 | "the the pizza beer beer copyright",
32 | "the burger beer beer copyright",
33 | "the coke burger coke copyright",
34 | "the coke burger burger",
35 | )
36 |
37 |
38 | def test_dag_invalid_parameters():
39 | # Test the various init parameters of the dag in fit
40 | # method
41 | with pytest.raises(KeyError):
42 | dag = DAGBuilder().from_pipeline([(1, 1)]).make_dag()
43 |
44 | # Check that we can't fit DAGs with objects without fit
45 | # method
46 | msg = (
47 | "Leaf nodes of a DAG should implement fit or be the string 'passthrough'"
48 | ".*NoFit.*"
49 | )
50 | dag = DAGBuilder().from_pipeline([("clf", NoFit())]).make_dag()
51 | with pytest.raises(TypeError, match=msg):
52 | dag.fit([[1]], [1])
53 |
54 | # Smoke test with only an estimator
55 | clf = NoTrans()
56 | dag = DAGBuilder().from_pipeline([("svc", clf)]).make_dag()
57 | assert dag.get_params(deep=True) == dict(
58 | svc__a=None, svc__b=None, svc=clf, **dag.get_params(deep=False)
59 | )
60 |
61 | # Check that params are set
62 | dag.set_params(svc__a=0.1)
63 | assert clf.a == 0.1
64 | assert clf.b is None
65 | # Smoke test the repr:
66 | repr(dag)
67 |
68 | # Test with two objects
69 | clf = SVC()
70 | filter1 = SelectKBest(f_classif)
71 | dag = DAGBuilder().from_pipeline([("anova", filter1), ("svc", clf)]).make_dag()
72 |
73 | # Check that estimators are not cloned on pipeline construction
74 | assert dag.named_steps["anova"] is filter1
75 | assert dag.named_steps["svc"] is clf
76 |
77 | # Check that we can't fit with non-transformers on the way
78 | # Note that NoTrans implements fit, but not transform
79 | msg = "All intermediate steps should be transformers.*\\bNoTrans\\b.*"
80 | dag2 = DAGBuilder().from_pipeline([("t", NoTrans()), ("svc", clf)]).make_dag()
81 | with pytest.raises(TypeError, match=msg):
82 | dag2.fit([[1]], [1])
83 |
84 | # Check that params are set
85 | dag.set_params(svc__C=0.1)
86 | assert clf.C == 0.1
87 | # Smoke test the repr:
88 | repr(dag)
89 |
90 | # Check that params are not set when naming them wrong
91 | msg = re.escape(
92 | "Invalid parameter 'C' for estimator SelectKBest(). Valid parameters are: ['k',"
93 | " 'score_func']."
94 | )
95 | with pytest.raises(ValueError, match=msg):
96 | dag.set_params(anova__C=0.1)
97 |
98 | # Test clone
99 | dag2 = clone(dag)
100 | assert not dag.named_steps["svc"] is dag2.named_steps["svc"]
101 |
102 | # Check that apart from estimators, the parameters are the same
103 | params = dag.get_params(deep=True)
104 | params2 = dag2.get_params(deep=True)
105 |
106 | for x in dag.get_params(deep=False):
107 | params.pop(x)
108 |
109 | for x in dag.get_params(deep=False):
110 | params2.pop(x)
111 |
112 |     # Remove estimators that were copied
113 | params.pop("svc")
114 | params.pop("anova")
115 | params2.pop("svc")
116 | params2.pop("anova")
117 | assert params == params2
118 |
119 |
120 | def test_dag_simple():
121 | # Build a simple DAG of one node
122 | rng = np.random.default_rng(1)
123 | X = rng.random(size=(20, 5))
124 |
125 | dag = DAGBuilder().add_step("clf", FitParamT()).make_dag()
126 | dag.fit(X)
127 | dag.predict(X)
128 |
129 |
130 | def test_dag_pipeline_init():
131 | # Build a dag from a pipeline
132 | X = np.array([[1, 2]])
133 | steps = (("transf", Transf()), ("clf", FitParamT()))
134 | pipe = Pipeline(steps, verbose=False)
135 | for inp in [pipe, steps]:
136 | dag = DAGBuilder().from_pipeline(inp).make_dag()
137 | dag.fit(X, y=None)
138 | dag.score(X)
139 |
140 | dag.set_params(transf="passthrough")
141 | dag.fit(X, y=None)
142 | dag.score(X)
143 |
144 |
145 | def test_dag_methods_anova():
146 | # Test the various methods of the dag (anova).
147 | X = iris.data
148 | y = iris.target
149 | # Test with Anova + LogisticRegression
150 | clf = LogisticRegression()
151 | filter1 = SelectKBest(f_classif, k=2)
152 | dag1 = (
153 | DAGBuilder().from_pipeline([("anova", filter1), ("logistic", clf)]).make_dag()
154 | )
155 | dag2 = (
156 | DAGBuilder()
157 | .add_step("anova", filter1)
158 | .add_step("logistic", clf, deps=["anova"])
159 | .make_dag()
160 | )
161 | dag1.fit(X, y)
162 | dag2.fit(X, y)
163 | assert_array_almost_equal(dag1.predict(X), dag2.predict(X))
164 | assert_array_almost_equal(dag1.predict_proba(X), dag2.predict_proba(X))
165 | assert_array_almost_equal(dag1.predict_log_proba(X), dag2.predict_log_proba(X))
166 | assert_array_almost_equal(dag1.score(X, y), dag2.score(X, y))
167 |
168 |
169 | def test_dag_fit_params():
170 | # Test that the pipeline can take fit parameters
171 | dag = (
172 | DAGBuilder()
173 | .from_pipeline([("transf", Transf()), ("clf", FitParamT())])
174 | .make_dag()
175 | )
176 | dag.fit(X=None, y=None, clf__should_succeed=True)
177 | # classifier should return True
178 | assert dag.predict(None)
179 | # and transformer params should not be changed
180 | assert dag.named_steps["transf"].a is None
181 | assert dag.named_steps["transf"].b is None
182 | # invalid parameters should raise an error message
183 |
184 | msg = re.escape("fit() got an unexpected keyword argument 'bad'")
185 | with pytest.raises(TypeError, match=msg):
186 | dag.fit(None, None, clf__bad=True)
187 |
188 |
189 | def test_dag_sample_weight_supported():
190 | # DAG should pass sample_weight
191 | X = np.array([[1, 2]])
192 | dag = (
193 | DAGBuilder()
194 | .from_pipeline([("transf", Transf()), ("clf", FitParamT())])
195 | .make_dag()
196 | )
197 | dag.fit(X, y=None)
198 | assert dag.score(X) == 3
199 | assert dag.score(X, y=None) == 3
200 | assert dag.score(X, y=None, sample_weight=None) == 3
201 | assert dag.score(X, sample_weight=np.array([2, 3])) == 8
202 |
203 |
204 | def test_dag_sample_weight_unsupported():
205 | # When sample_weight is None it shouldn't be passed
206 | X = np.array([[1, 2]])
207 | dag = DAGBuilder().from_pipeline([("transf", Transf()), ("clf", Mult())]).make_dag()
208 | dag.fit(X, y=None)
209 | assert dag.score(X) == 3
210 | assert dag.score(X, sample_weight=None) == 3
211 |
212 | msg = re.escape("score() got an unexpected keyword argument 'sample_weight'")
213 | with pytest.raises(TypeError, match=msg):
214 | dag.score(X, sample_weight=np.array([2, 3]))
215 |
216 |
217 | def test_dag_raise_set_params_error():
218 | # Test dag raises set params error message for nested models.
219 | dag = DAGBuilder().from_pipeline([("cls", LinearRegression())]).make_dag()
220 |
221 | # expected error message
222 | error_msg = (
223 |         r"Invalid parameter 'fake' for estimator DAG\(graph=<[^>]*>"
224 | r"\)\. Valid parameters are: \['graph', 'memory', 'n_jobs', 'verbose'\]."
225 | )
226 | with pytest.raises(ValueError, match=error_msg):
227 | dag.set_params(fake="nope")
228 |
229 | # invalid outer parameter name for compound parameter: the expected error message
230 | # is the same as above.
231 | with pytest.raises(ValueError, match=error_msg):
232 | dag.set_params(fake__estimator="nope")
233 |
234 | # expected error message for invalid inner parameter
235 | error_msg = (
236 | r"Invalid parameter 'invalid_param' for estimator LinearRegression\(\)\. Valid"
237 | r" parameters are: \['copy_X', 'fit_intercept', 'n_jobs', 'normalize',"
238 | r" 'positive'\]."
239 | )
240 | with pytest.raises(ValueError, match=error_msg):
241 | dag.set_params(cls__invalid_param="nope")
242 |
243 |
244 | @pytest.mark.parametrize("idx", [1, [1]])
245 | def test_dag_stacking_pca_svm_rf(idx):
246 | # Test the various methods of the pipeline (pca + svm).
247 | X = cancer.data
248 | y = cancer.target
249 | # Build a simple model stack with some preprocessing.
250 | pca = PCA(svd_solver="full", n_components="mle", whiten=True)
251 | svc = SVC(probability=True, random_state=0)
252 | rf = RandomForestClassifier(random_state=0)
253 | log = LogisticRegression()
254 |
255 | dag = (
256 | DAGBuilder()
257 | .add_step("pca", pca)
258 | .add_step("svc", svc, deps=["pca"])
259 | .add_step("rf", rf, deps=["pca"])
260 | .add_step("log", log, deps={"svc": idx, "rf": idx})
261 | .make_dag()
262 | )
263 | dag.fit(X, y)
264 |
265 | prob_shape = len(cancer.target), len(cancer.target_names)
266 | tgt_shape = cancer.target.shape
267 |
268 | assert dag.predict_proba(X).shape == prob_shape
269 | assert dag.predict(X).shape == tgt_shape
270 | assert dag.predict_log_proba(X).shape == prob_shape
271 | assert isinstance(dag.score(X, y), (float, np.floating))
272 |
273 |     leaf = dag["log"]
274 |     for attr in ["n_features_in_", "feature_names_in_"]:
275 |         if hasattr(leaf, attr):
276 | assert hasattr(dag, attr)
277 |
278 |
279 | def test_dag_draw():
280 | txt = DAGBuilder().make_dag().draw(format="txt")
281 | assert "[empty]" in txt
282 |
283 | svg = DAGBuilder().make_dag().draw(format="svg")
284 | assert "[empty]" in svg
285 |
286 | # Build a simple model stack with some preprocessing.
287 | pca = PCA(svd_solver="full", n_components="mle", whiten=True)
288 | svc = SVC(probability=True, random_state=0)
289 | rf = RandomForestClassifier(random_state=0)
290 | log = LogisticRegression()
291 |
292 | dag = (
293 | DAGBuilder()
294 | .add_step("pca", pca)
295 | .add_step("svc1", svc, deps={"pca": slice(4)})
296 | .add_step("svc2", svc, deps={"pca": [0, 1, 2]})
297 | .add_step(
298 | "svc3", svc, deps={"pca": lambda X: [c for c in X if c.startswith("foo")]}
299 | )
300 | .add_step("rf1", rf, deps={"pca": [0, 1, 2, 3, 4, 5, 6, 7, 8]})
301 | .add_step("rf2", rf, deps={"pca": make_column_selector(pattern="^pca.*")})
302 | .add_step("log", log, deps=["svc1", "svc2", "rf1", "rf2"])
303 | .make_dag()
304 | )
305 |
306 | for repr_method in [fn for fn in dir(dag) if fn.startswith("_repr_")]:
307 | if repr_method == "_repr_pretty_":
308 | try:
309 | from IPython.lib.pretty import PrettyPrinter
310 | except ImportError: # pragma: no cover
311 | continue
312 | from io import StringIO
313 |
314 | sio = StringIO()
315 | getattr(dag, repr_method)(PrettyPrinter(sio), False)
316 | sio.seek(0)
317 | out = sio.read()
318 | else:
319 | out = getattr(dag, repr_method)()
320 |
321 | if repr_method == "_repr_mimebundle_":
322 | for mimetype, data in out.items():
323 | if mimetype in ("image/png", "image/jpeg"):
324 | expected = bytes
325 | else:
326 | expected = str
327 | assert isinstance(
328 | data, expected
329 | ), f"{repr_method} {mimetype} returns unexpected type {data}"
330 | elif repr_method in ("_repr_jpeg_", "_repr_png_"):
331 | assert isinstance(
332 | out, bytes
333 | ), f"{repr_method} returns unexpected type {out}"
334 | else:
335 | assert isinstance(out, str), f"{repr_method} returns unexpected type {out}"
336 |
337 | txt = dag.draw(format="txt")
338 | for step in dag.step_names:
339 | assert step in txt
340 |
341 | svg = dag.draw(format="svg")
342 | for step in dag.step_names:
343 |         assert f"<title>{step}</title>" in svg
344 |
345 | svg = dag.draw(format="svg", style="dark")
346 | for step in dag.step_names:
347 |         assert f"<title>{step}</title>" in svg
348 |
349 | with pytest.raises(ValueError):
350 | badstyle = "foo"
351 | dag.draw(format="svg", style=badstyle)
352 |
353 | svg = dag.draw(format="svg", detailed=True)
354 | for step, est in dag.steps_:
355 |         assert f"<title>{step}</title>" in svg
356 | assert f"{type(est).__name__}" in svg
357 |
358 |
359 | def _dag_from_steplist(steps, **builder_opts):
360 | builder = DAGBuilder(**builder_opts)
361 | for step in steps:
362 | builder.add_step(**step)
363 | return builder.make_dag()
364 |
365 |
366 | @pytest.mark.parametrize(
367 | "steps",
368 | [
369 | [
370 | {
371 | "name": "pca",
372 | "est": PCA(n_components=1),
373 | },
374 | {
375 | "name": "svc",
376 | "est": SVC(probability=True, random_state=0),
377 | "deps": ["pca"],
378 | },
379 | {
380 | "name": "rf",
381 | "est": RandomForestClassifier(random_state=0),
382 | "deps": ["pca"],
383 | },
384 | {
385 | "name": "log",
386 | "est": LogisticRegression(),
387 | "deps": ["svc", "rf"],
388 | },
389 | ],
390 | ],
391 | )
392 | @pytest.mark.parametrize(
393 | "X,y",
394 | [datasets.make_blobs(n_samples=200, n_features=10, centers=3, random_state=0)],
395 | )
396 | def test_pandas(X, y, steps):
397 | dag_np = _dag_from_steplist(steps, infer_dataframe=False)
398 | dag_pd = _dag_from_steplist(steps, infer_dataframe=True)
399 |
400 | dag_np.fit(X, y)
401 | dag_pd.fit(X, y)
402 |
403 | y_pred_np = dag_np.predict_proba(X)
404 | y_pred_pd = dag_pd.predict_proba(X)
405 | assert isinstance(y_pred_np, np.ndarray)
406 | assert isinstance(y_pred_pd, pd.DataFrame)
407 | assert np.allclose(y_pred_np, y_pred_pd)
408 |
409 |
410 | @pytest.mark.parametrize("input_passthrough", [False, True])
411 | def test_pandas_indexing(input_passthrough):
412 | X, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
413 |
414 | passcols = ["age", "sex", "bmi", "bp"]
415 |
416 | builder = DAGBuilder(infer_dataframe=True)
417 | if input_passthrough:
418 | builder.add_step("inp", "passthrough")
419 |
420 | preprocessing = (
421 | builder.add_step(
422 | "imp", SimpleImputer(), deps=["inp"] if input_passthrough else None
423 | )
424 | .add_step("vitals", "passthrough", deps={"imp": passcols})
425 | .add_step(
426 | "blood",
427 | PCA(n_components=2, random_state=0),
428 | deps={"imp": make_column_selector("s[0-9]+")},
429 | )
430 | .add_step("out", "passthrough", deps=["vitals", "blood"])
431 | .make_dag()
432 | )
433 |
434 | X_tr = preprocessing.fit_transform(X, y)
435 | assert isinstance(X_tr, pd.DataFrame)
436 | assert (X_tr.index == X.index).all()
437 | assert X_tr.columns.tolist() == [f"imp__{col}" for col in passcols] + [
438 | "blood__pca0",
439 | "blood__pca1",
440 | ]
441 |
442 | predictor = (
443 | DAGBuilder(infer_dataframe=True)
444 | .add_step("rf", RandomForestRegressor(random_state=0))
445 | .make_dag()
446 | )
447 |
448 | dag = preprocessing.join(
449 | predictor,
450 | edges=[("out", "rf")],
451 | )
452 |
453 | y_pred = dag.fit_predict(X, y)
454 |
455 | assert isinstance(y_pred, pd.Series)
456 | assert (y_pred.index == y.index).all()
457 | assert y_pred.name == dag.leaves_[0].name
458 |
459 |
460 | @parametrize_with_checks(
461 | [
462 | DAGBuilder().from_pipeline([("ss", StandardScaler())]).make_dag(),
463 | DAGBuilder().from_pipeline([("lr", LinearRegression())]).make_dag(),
464 | (
465 | DAGBuilder()
466 | .add_step("pca", PCA(n_components=1))
467 | .add_step("svc", SVC(probability=True, random_state=0), deps=["pca"])
468 | .add_step("rf", RandomForestClassifier(random_state=0), deps=["pca"])
469 | .add_step("log", LogisticRegression(), deps=["svc", "rf"])
470 | .make_dag()
471 | ),
472 | ]
473 | )
474 | def test_dag_check_estimator(estimator, check):
475 | # Since some parameters are estimators, we expect them to be modified during fit(),
476 | # which is why we skip these checks (in line with checks on other metaestimators
477 | # like Pipeline)
478 | if check.func.__name__ in [
479 | "check_estimators_overwrite_params",
480 | "check_dont_overwrite_parameters",
481 | ]:
482 | # we don't clone in pipeline or feature union
483 | return
484 | check(estimator)
485 |
--------------------------------------------------------------------------------
/skdag/dag/tests/utils.py:
--------------------------------------------------------------------------------
1 | from time import time
2 |
3 | import numpy as np
4 | from sklearn.base import BaseEstimator
5 |
6 |
7 | class NoFit: # pragma: no cover
8 | """Small class to test parameter dispatching."""
9 |
10 | def __init__(self, a=None, b=None):
11 | self.a = a
12 | self.b = b
13 |
14 |
15 | class NoTrans(NoFit): # pragma: no cover
16 | def fit(self, X, y):
17 | return self
18 |
19 | def get_params(self, deep=False):
20 | return {"a": self.a, "b": self.b}
21 |
22 | def set_params(self, **params):
23 | self.a = params["a"]
24 | return self
25 |
26 |
27 | class NoInvTransf(NoTrans): # pragma: no cover
28 | def transform(self, X):
29 | return X
30 |
31 |
32 | class Transf(NoInvTransf): # pragma: no cover
33 | def inverse_transform(self, X):
34 | return X
35 |
36 |
37 | class TransfFitParams(Transf): # pragma: no cover
38 | def fit(self, X, y, **fit_params):
39 | self.fit_params = fit_params
40 | return self
41 |
42 |
43 | class Mult(BaseEstimator): # pragma: no cover
44 | def __init__(self, mult=1):
45 | self.mult = mult
46 |
47 | def fit(self, X, y):
48 | return self
49 |
50 | def transform(self, X):
51 | return np.asarray(X) * self.mult
52 |
53 | def inverse_transform(self, X):
54 | return np.asarray(X) / self.mult
55 |
56 | def predict(self, X):
57 | return (np.asarray(X) * self.mult).sum(axis=1)
58 |
59 | predict_proba = predict_log_proba = decision_function = predict
60 |
61 | def score(self, X, y=None):
62 | return np.sum(X)
63 |
64 |
65 | class FitParamT(BaseEstimator): # pragma: no cover
66 | """Mock classifier"""
67 |
68 | def __init__(self):
69 | self.successful = False
70 |
71 | def fit(self, X, y, should_succeed=False):
72 | self.successful = should_succeed
73 | return self
74 |
75 | def predict(self, X):
76 | return self.successful
77 |
78 | def fit_predict(self, X, y, should_succeed=False):
79 | self.fit(X, y, should_succeed=should_succeed)
80 | return self.predict(X)
81 |
82 | def score(self, X, y=None, sample_weight=None):
83 | if sample_weight is not None:
84 | X = X * sample_weight
85 | return np.sum(X)
86 |
87 |
88 | class DummyTransf(Transf): # pragma: no cover
89 |     """Transformer which stores the column means."""
90 |
91 | def fit(self, X, y):
92 | self.means_ = np.mean(X, axis=0)
93 | # store timestamp to figure out whether the result of 'fit' has been
94 | # cached or not
95 | self.timestamp_ = time()
96 | return self
97 |
98 |
99 | class DummyEstimatorParams(BaseEstimator): # pragma: no cover
100 | """Mock classifier that takes params on predict"""
101 |
102 | def fit(self, X, y):
103 | return self
104 |
105 | def predict(self, X, got_attribute=False):
106 | self.got_attribute = got_attribute
107 | return self
108 |
109 | def predict_proba(self, X, got_attribute=False):
110 | self.got_attribute = got_attribute
111 | return self
112 |
113 | def predict_log_proba(self, X, got_attribute=False):
114 | self.got_attribute = got_attribute
115 | return self
116 |
--------------------------------------------------------------------------------
/skdag/exceptions.py:
--------------------------------------------------------------------------------
1 | class DAGError(Exception):
2 | "An exception indicating an error in constructing the requested DAG."
--------------------------------------------------------------------------------
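
``DAGError`` is raised when ``make_dag()`` is asked to finalise a graph that is not
acyclic, as exercised in ``test_builder_basics`` above. A minimal sketch, introducing
a cycle deliberately via the builder's underlying graph:

    from skdag import DAGBuilder
    from skdag.exceptions import DAGError
    from sklearn.preprocessing import StandardScaler

    builder = DAGBuilder()
    builder.add_step("a", StandardScaler())
    builder.add_step("b", StandardScaler(), deps=["a"])
    builder.graph.add_edge("b", "a")  # a -> b -> a: no longer a DAG

    try:
        builder.make_dag()
    except DAGError:
        print("cycle rejected")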