├── .github └── workflows │ ├── lint.yml │ ├── test.yml │ └── upload-to-pypi.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.rst ├── LICENSE ├── README.rst ├── ci └── matplotlibrc ├── doc ├── Makefile ├── api.rst ├── changelog.rst ├── conf.py ├── formats.ipynb ├── index.rst ├── make.bat └── requirements.txt ├── examples ├── README.txt ├── plot_customize_after_plot.py ├── plot_diabetes.py ├── plot_discrete.py ├── plot_generated.py ├── plot_hide.py ├── plot_highlight.py ├── plot_highlight_categories.py ├── plot_missingness.py ├── plot_sizing.py ├── plot_theming.py └── plot_vertical.py ├── pyproject.toml ├── setup.cfg ├── setup.py └── upsetplot ├── __init__.py ├── data.py ├── plotting.py ├── reformat.py ├── tests ├── __init__.py ├── test_data.py ├── test_examples.py ├── test_reformat.py └── test_upsetplot.py └── util.py /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: [push] 3 | jobs: 4 | lint: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v4 8 | - name: Set up Python 9 | uses: actions/setup-python@v4 10 | with: 11 | python-version: '3.x' 12 | - name: Install dependencies 13 | run: | 14 | python -m pip install --upgrade pip 15 | pip install . 16 | - name: Lint with Ruff 17 | run: | 18 | pip install ruff 19 | ruff --output-format=github . 
20 | - name: Check format with ruff 21 | run: ruff format --check 22 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | on: [push] 3 | jobs: 4 | test: 5 | runs-on: ubuntu-latest 6 | strategy: 7 | matrix: 8 | conda-deps: 9 | - python=3.8 pandas=1.0 matplotlib=3.1.2 numpy=1.17 10 | - pandas matplotlib numpy 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: conda-incubator/setup-miniconda@v3 14 | with: 15 | auto-update-conda: true 16 | - name: conda debug info 17 | shell: bash -el {0} 18 | run: conda info -a 19 | - name: install 20 | shell: bash -el {0} 21 | run: | 22 | conda install pytest pytest-cov coveralls ${{ matrix.conda-deps }} 23 | python setup.py install 24 | cp ci/matplotlibrc matplotlibrc 25 | - name: test 26 | shell: bash -el {0} 27 | run: pytest 28 | - name: Coveralls 29 | uses: coverallsapp/github-action@v2 30 | -------------------------------------------------------------------------------- /.github/workflows/upload-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install build 20 | - name: Build package 21 | run: python -m build 22 | - name: Publish package 23 | uses: pypa/gh-action-pypi-publish@v1 24 | with: 25 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ######################################### 2 | # 
Editor temporary/working/backup files # 3 | .#* 4 | *\#*\# 5 | [#]*# 6 | *~ 7 | *$ 8 | *.bak 9 | *flymake* 10 | *.kdev4 11 | *.log 12 | *.swp 13 | *.pdb 14 | .project 15 | .pydevproject 16 | .settings 17 | .idea 18 | .vagrant 19 | .noseids 20 | .ipynb_checkpoints 21 | .tags 22 | .cache/ 23 | .direnv/ 24 | 25 | # Compiled source # 26 | ################### 27 | *.a 28 | *.com 29 | *.class 30 | *.dll 31 | *.exe 32 | *.pxi 33 | *.o 34 | *.py[ocd] 35 | *.so 36 | .build_cache_dir 37 | MANIFEST 38 | 39 | # Python files # 40 | ################ 41 | # setup.py working directory 42 | build 43 | # sphinx build directory 44 | doc/_* 45 | # setup.py dist directory 46 | dist 47 | # Egg metadata 48 | *.egg-info 49 | .eggs 50 | .pypirc 51 | 52 | # tox testing tool 53 | .tox 54 | # rope 55 | .ropeproject 56 | # wheel files 57 | *.whl 58 | **/wheelhouse/* 59 | # coverage 60 | .coverage 61 | coverage.xml 62 | coverage_html_report 63 | 64 | # OS generated files # 65 | ###################### 66 | .directory 67 | .gdb_history 68 | .DS_Store 69 | ehthumbs.db 70 | Icon? 
71 | Thumbs.db 72 | 73 | # Data files # 74 | ############## 75 | *.dta 76 | *.xpt 77 | *.h5 78 | 79 | # Generated Sources # 80 | ##################### 81 | !skts.c 82 | !np_datetime.c 83 | !np_datetime_strings.c 84 | *.c 85 | *.cpp 86 | 87 | .pytest_cache 88 | .envrc 89 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/astral-sh/ruff-pre-commit 9 | rev: v0.1.9 10 | hooks: 11 | - id: ruff 12 | args: ["--fix", "--show-source"] 13 | - id: ruff-format 14 | types_or: [ python, pyi, jupyter ] 15 | - repo: https://github.com/pre-commit/mirrors-mypy 16 | rev: v1.3.0 17 | hooks: 18 | - id: mypy 19 | files: upsetplot/ 20 | additional_dependencies: [pytest==6.2.4] 21 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.9" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: doc/conf.py 17 | 18 | # Optionally declare the Python requirements required to build your docs 19 | python: 20 | install: 21 | - requirements: doc/requirements.txt 22 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | What's new in version 1.0 2 | 
------------------------- 3 | 4 | In development 5 | 6 | What's new in version 0.9 7 | ------------------------- 8 | 9 | - Fixes a bug where ``show_percentages`` used the incorrect denominator if 10 | filtering (e.g. ``min_subset_size``) was applied. This bug was a regression 11 | introduced in version 0.7. (:issue:`248`) 12 | - Align ylabels of subplots added using ``add_catplot``. (:issue:`266`) 13 | - Add a ``style_categories`` method to customize category plot styles, including 14 | shading of rows in the intersection matrix, and bars in the totals plot. 15 | (:issue:`261` with thanks to :user:`Marcel Albus `). 16 | - Ability to disable totals plot with `totals_plot_elements=0`. (:issue:`246`) 17 | - Ability to set totals y axis label (:issue:`243`) 18 | - Added ``max_subset_rank`` to get only n most populous subsets. (:issue:`253`) 19 | - Added support for ``min_subset_size`` and ``max_subset_size`` specified as 20 | percentage. (:issue:`264`) 21 | 22 | What's new in version 0.8 23 | ------------------------- 24 | 25 | - Allowed ``show_percentages`` to be provided with a custom formatting string, 26 | for example to show more decimal places. (:issue:`194`) 27 | - Added `include_empty_subsets` to `UpSet` and `query` to allow the display of 28 | all possible subsets. (:issue:`185`) 29 | - `sort_by` and `sort_categories_by` now accept '-' prefix to their values 30 | to sort in reverse. 'input' and '-input' are also supported. (:issue:`180`) 31 | - Added `subsets` attribute to QueryResult. (:issue:`198`) 32 | - Fixed a bug where more than 64 categories could result in an error. (:issue:`193`) 33 | 34 | Patch release 0.8.2 handles deprecations in dependencies. 35 | 36 | What's new in version 0.7 37 | ------------------------- 38 | 39 | - Added `query` function to support analysing set-based data. 40 | - Fixed support for matplotlib >3.5.2 (:issue:`191`. 
Thanks :user:`GuyTeichman`) 41 | 42 | What's new in version 0.6 43 | ------------------------- 44 | 45 | - Added `add_stacked_bars`, similar to `add_catplot` but to add stacked bar 46 | charts to show discrete variable distributions within each subset. 47 | (:issue:`137`) 48 | - Improved ability to control colors, and added a new example of same. 49 | Parameters ``other_dots_color`` and ``shading_color`` were added. 50 | ``facecolor`` will now default to white if 51 | ``matplotlib.rcParams['axes.facecolor']`` is dark. (:issue:`138`) 52 | - Added `style_subsets` to colour intersection size bars and matrix 53 | dots in the plot according to a specified query. (:issue:`152`) 54 | - Added `from_indicators` to allow yet another data input format. This 55 | allows category membership to be easily derived from a DataFrame, such as 56 | when plotting missing values in the columns of a DataFrame. (:issue:`143`) 57 | 58 | What's new in version 0.5 59 | ------------------------- 60 | 61 | - Support using input intersection order with ``sort_by=None`` (:issue:`133` 62 | with thanks to :user:`Brandon B `). 63 | - Add parameters for filtering by subset size (with thanks to 64 | :user:`Sichong Peng `) and degree. (:issue:`134`) 65 | - Fixed an issue where tick labels were not given enough space and overlapped 66 | category totals. (:issue:`132`) 67 | - Fixed an issue where our implementation of ``sort_by='degree'`` apparently 68 | gave incorrect results for some inputs and versions of Pandas. (:issue:`134`) 69 | 70 | What's new in version 0.4.4 71 | --------------------------- 72 | 73 | - Fixed a regresion which caused the first column to be hidden 74 | (:issue:`125`) 75 | 76 | What's new in version 0.4.3 77 | --------------------------- 78 | 79 | - Fixed issue with the order of catplots being reversed for vertical plots 80 | (:issue:`122` with thanks to :user:`Enrique Fernandez-Blanco `) 81 | - Fixed issue with the x limits of vertical plots (:issue:`121`). 
82 | 83 | What's new in version 0.4.2 84 | --------------------------- 85 | 86 | - Fixed large x-axis plot margins with high number of unique intersections 87 | (:issue:`106` with thanks to :user:`Yidi Huang `) 88 | 89 | What's new in version 0.4.1 90 | --------------------------- 91 | 92 | - Fixed the calculation of percentage which was broken in 0.4.0. (:issue:`101`) 93 | 94 | What's new in version 0.4 95 | ------------------------- 96 | 97 | - Added option to display both the absolute frequency and the percentage of 98 | the total for each intersection and category. (:issue:`89` with thanks to 99 | :user:`Carlos Melus ` and :user:`Aaron Rosenfeld `) 100 | - Improved efficiency where there are many categories, but valid combinations 101 | are sparse, if `sort_by='degree'`. (:issue:`82`) 102 | - Permit truthy (not necessarily bool) values in index. 103 | (:issue:`74` with thanks to :user:`ZaxR`) 104 | - `intersection_plot_elements` can now be set to 0 to hide the intersection 105 | size plot when `add_catplot` is used. (:issue:`80`) 106 | 107 | What's new in version 0.3 108 | ------------------------- 109 | 110 | - Added `from_contents` to provide an alternative, intuitive way of specifying 111 | category membership of elements. 112 | - To improve code legibility and intuitiveness, `sum_over=False` was deprecated 113 | and a `subset_size` parameter was added. It will have better default 114 | handling of DataFrames after a short deprecation period. 115 | - `generate_data` has been replaced with `generate_counts` and 116 | `generate_samples`. 117 | - Fixed the display of the "intersection size" label on plots, which had been 118 | missing. 119 | - Trying to improve nomenclature, upsetplot now avoids "set" to refer to the 120 | top-level sets, which are now to be known as "categories". This matches the 121 | intuition that categories are named, logical groupings, as opposed to 122 | "subsets". 
To this end: 123 | 124 | - `generate_counts` (formerly `generate_data`) now names its categories 125 | "cat1", "cat2" etc. rather than "set1", "set2", etc. 126 | - the `sort_sets_by` parameter has been renamed to `sort_categories_by` and 127 | will be removed in version 0.4. 128 | 129 | What's new in version 0.2.1 130 | --------------------------- 131 | 132 | - Return a Series (not a DataFrame) from `from_memberships` if data is 133 | 1-dimensional. 134 | 135 | What's new in version 0.2 136 | ------------------------- 137 | 138 | - Added `from_memberships` to allow a more convenient data input format. 139 | - `plot` and `UpSet` now accept a `pandas.DataFrame` as input, if the 140 | `sum_over` parameter is also given. 141 | - Added an `add_catplot` method to `UpSet` which adds Seaborn plots of set 142 | intersection data to show more than just set size or total. 143 | - Shading of subset matrix is continued through to totals. 144 | - Added a `show_counts` option to show counts at the ends of bar plots. 145 | (:issue:`5`) 146 | - Defined `_repr_html_` so that an `UpSet` object will render in Jupyter 147 | notebooks. 148 | (:issue:`36`) 149 | - Fix a bug where an error was raised if an input set was empty. 150 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | New BSD License 2 | 3 | Copyright (c) 2018-2024 Joel Nothman. 4 | All rights reserved. 5 | 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | a. Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | b. 
Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | c. The names of the contributors may not be used to endorse or promote 16 | products derived from this software without specific prior written 17 | permission. 18 | 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 23 | ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR 24 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 28 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 29 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 30 | DAMAGE. 31 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | UpSetPlot documentation 2 | ============================ 3 | 4 | |version| |licence| |py-versions| 5 | 6 | |issues| |build| |docs| |coverage| 7 | 8 | This is another Python implementation of UpSet plots by Lex et al. [Lex2014]_. 9 | UpSet plots are used to visualise set overlaps; like Venn diagrams but 10 | more readable. Documentation is at https://upsetplot.readthedocs.io. 11 | 12 | This ``upsetplot`` library tries to provide a simple interface backed by an 13 | extensible, object-oriented design. 14 | 15 | There are many ways to represent the categorisation of data, as covered in 16 | our `Data Format Guide `_. 
17 | 18 | Our internal input format uses a `pandas.Series` containing counts 19 | corresponding to subset sizes, where each subset is an intersection of named 20 | categories. The index of the Series indicates which rows pertain to which 21 | categories, by having multiple boolean indices, like ``example`` in the 22 | following:: 23 | 24 | >>> from upsetplot import generate_counts 25 | >>> example = generate_counts() 26 | >>> example 27 | cat0 cat1 cat2 28 | False False False 56 29 | True 283 30 | True False 1279 31 | True 5882 32 | True False False 24 33 | True 90 34 | True False 429 35 | True 1957 36 | Name: value, dtype: int64 37 | 38 | Then:: 39 | 40 | >>> from upsetplot import plot 41 | >>> plot(example) # doctest: +SKIP 42 | >>> from matplotlib import pyplot 43 | >>> pyplot.show() # doctest: +SKIP 44 | 45 | makes: 46 | 47 | .. image:: http://upsetplot.readthedocs.io/en/latest/_images/sphx_glr_plot_generated_001.png 48 | :target: ../auto_examples/plot_generated.html 49 | 50 | And you can save the image in various formats:: 51 | 52 | >>> pyplot.savefig("/path/to/myplot.pdf") # doctest: +SKIP 53 | >>> pyplot.savefig("/path/to/myplot.png") # doctest: +SKIP 54 | 55 | This plot shows the cardinality of every category combination seen in our data. 56 | The leftmost column counts items absent from any category. The next three 57 | columns count items only in ``cat1``, ``cat2`` and ``cat3`` respectively, with 58 | following columns showing cardinalities for items in each combination of 59 | exactly two named sets. The rightmost column counts items in all three sets. 60 | 61 | Rotation 62 | ........ 63 | 64 | We call the above plot style "horizontal" because the category intersections 65 | are presented from left to right. `Vertical plots 66 | `__ 67 | are also supported! 68 | 69 | .. 
image:: http://upsetplot.readthedocs.io/en/latest/_images/sphx_glr_plot_vertical_001.png 70 | :target: http://upsetplot.readthedocs.io/en/latest/auto_examples/plot_vertical.html 71 | 72 | Distributions 73 | ............. 74 | 75 | Providing a DataFrame rather than a Series as input allows us to expressively 76 | `plot the distribution of variables 77 | `__ 78 | in each subset. 79 | 80 | .. image:: http://upsetplot.readthedocs.io/en/latest/_images/sphx_glr_plot_diabetes_001.png 81 | :target: http://upsetplot.readthedocs.io/en/latest/auto_examples/plot_diabetes.html 82 | 83 | Loading datasets 84 | ................ 85 | 86 | While the dataset above is randomly generated, you can prepare your own dataset 87 | for input to upsetplot. A helpful tool is `from_memberships`, which allows 88 | us to reconstruct the example above by indicating each data point's category 89 | membership:: 90 | 91 | >>> from upsetplot import from_memberships 92 | >>> example = from_memberships( 93 | ... [[], 94 | ... ['cat2'], 95 | ... ['cat1'], 96 | ... ['cat1', 'cat2'], 97 | ... ['cat0'], 98 | ... ['cat0', 'cat2'], 99 | ... ['cat0', 'cat1'], 100 | ... ['cat0', 'cat1', 'cat2'], 101 | ... ], 102 | ... data=[56, 283, 1279, 5882, 24, 90, 429, 1957] 103 | ... ) 104 | >>> example 105 | cat0 cat1 cat2 106 | False False False 56 107 | True 283 108 | True False 1279 109 | True 5882 110 | True False False 24 111 | True 90 112 | True False 429 113 | True 1957 114 | dtype: int64 115 | 116 | See also `from_contents`, another way to describe categorised data, and 117 | `from_indicators` which allows each category to be indicated by a column in 118 | the data frame (or a function of the column's data such as whether it is a 119 | missing value). 
120 | 121 | Installation 122 | ------------ 123 | 124 | To install the library, you can use `pip`:: 125 | 126 | $ pip install upsetplot 127 | 128 | Installation requires: 129 | 130 | * pandas 131 | * matplotlib >= 2.0 132 | * seaborn to use `UpSet.add_catplot` 133 | 134 | It should then be possible to:: 135 | 136 | >>> import upsetplot 137 | 138 | in Python. 139 | 140 | Why an alternative to py-upset? 141 | ------------------------------- 142 | 143 | Probably for petty reasons. It appeared `py-upset 144 | `_ was not being maintained. Its 145 | input format was undocumented, inefficient and, IMO, inappropriate. It did not 146 | facilitate showing plots of each subset's distribution as in Lex et al's work 147 | introducing UpSet plots. Nor did it include the horizontal bar plots 148 | illustrated there. It did not support Python 2. I decided it would be easier to 149 | construct a cleaner version than to fix it. 150 | 151 | References 152 | ---------- 153 | 154 | .. [Lex2014] Alexander Lex, Nils Gehlenborg, Hendrik Strobelt, Romain Vuillemot, Hanspeter Pfister, 155 | *UpSet: Visualization of Intersecting Sets*, 156 | IEEE Transactions on Visualization and Computer Graphics (InfoVis '14), vol. 20, no. 12, pp. 1983–1992, 2014. 157 | doi: `doi.org/10.1109/TVCG.2014.2346248 `_ 158 | 159 | 160 | .. |py-versions| image:: https://img.shields.io/pypi/pyversions/upsetplot.svg 161 | :alt: Python versions supported 162 | 163 | .. |version| image:: https://badge.fury.io/py/UpSetPlot.svg 164 | :alt: Latest version on PyPi 165 | :target: https://badge.fury.io/py/UpSetPlot 166 | 167 | .. |build| image:: https://github.com/jnothman/upsetplot/actions/workflows/test.yml/badge.svg 168 | :alt: Github Workflows CI build status 169 | :scale: 100% 170 | :target: https://github.com/jnothman/UpSetPlot/actions/workflows/test.yml 171 | 172 | .. 
|issues| image:: https://img.shields.io/github/issues/jnothman/UpSetPlot.svg 173 | :alt: Issue tracker 174 | :target: https://github.com/jnothman/UpSetPlot 175 | 176 | .. |coverage| image:: https://coveralls.io/repos/github/jnothman/UpSetPlot/badge.svg 177 | :alt: Test coverage 178 | :target: https://coveralls.io/github/jnothman/UpSetPlot 179 | 180 | .. |docs| image:: https://readthedocs.org/projects/upsetplot/badge/?version=latest 181 | :alt: Documentation Status 182 | :scale: 100% 183 | :target: https://upsetplot.readthedocs.io/en/latest/?badge=latest 184 | 185 | .. |licence| image:: https://img.shields.io/badge/Licence-BSD-blue.svg 186 | :target: https://opensource.org/licenses/BSD-3-Clause 187 | -------------------------------------------------------------------------------- /ci/matplotlibrc: -------------------------------------------------------------------------------- 1 | backend : Agg 2 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | -rm -rf $(BUILDDIR)/* 51 | -rm -rf auto_examples/ 52 | -rm -rf _modules/* 53 | 54 | html: 55 | # These two lines make the build a bit more lengthy, and the 56 | # the embedding of images more robust 57 | rm -rf $(BUILDDIR)/html/_images 58 | #rm -rf _build/doctrees/ 59 | $(SPHINXBUILD) -b html 
$(ALLSPHINXOPTS) $(BUILDDIR)/html 60 | @echo 61 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 62 | 63 | dirhtml: 64 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 65 | @echo 66 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 67 | 68 | singlehtml: 69 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 70 | @echo 71 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 72 | 73 | pickle: 74 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 75 | @echo 76 | @echo "Build finished; now you can process the pickle files." 77 | 78 | json: 79 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 80 | @echo 81 | @echo "Build finished; now you can process the JSON files." 82 | 83 | htmlhelp: 84 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 85 | @echo 86 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 87 | ".hhp project file in $(BUILDDIR)/htmlhelp." 88 | 89 | qthelp: 90 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 91 | @echo 92 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 93 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 94 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/project-template.qhcp" 95 | @echo "To view the help file:" 96 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/project-template.qhc" 97 | 98 | devhelp: 99 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 100 | @echo 101 | @echo "Build finished." 102 | @echo "To view the help file:" 103 | @echo "# mkdir -p $$HOME/.local/share/devhelp/project-template" 104 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/project-template" 105 | @echo "# devhelp" 106 | 107 | epub: 108 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 109 | @echo 110 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 
111 | 112 | latex: 113 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 114 | @echo 115 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 116 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 117 | "(use \`make latexpdf' here to do that automatically)." 118 | 119 | latexpdf: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through pdflatex..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | latexpdfja: 126 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 127 | @echo "Running LaTeX files through platex and dvipdfmx..." 128 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 129 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 130 | 131 | text: 132 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 133 | @echo 134 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 135 | 136 | man: 137 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 138 | @echo 139 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 140 | 141 | texinfo: 142 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 143 | @echo 144 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 145 | @echo "Run \`make' in that directory to run these through makeinfo" \ 146 | "(use \`make info' here to do that automatically)." 147 | 148 | info: 149 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 150 | @echo "Running Texinfo files through makeinfo..." 151 | make -C $(BUILDDIR)/texinfo info 152 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 153 | 154 | gettext: 155 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 156 | @echo 157 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 
158 | 159 | changes: 160 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 161 | @echo 162 | @echo "The overview file is in $(BUILDDIR)/changes." 163 | 164 | linkcheck: 165 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 166 | @echo 167 | @echo "Link check complete; look for any errors in the above output " \ 168 | "or in $(BUILDDIR)/linkcheck/output.txt." 169 | 170 | doctest: 171 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 172 | @echo "Testing of doctests in the sources finished, look at the " \ 173 | "results in $(BUILDDIR)/doctest/output.txt." 174 | 175 | xml: 176 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 177 | @echo 178 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 179 | 180 | pseudoxml: 181 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 182 | @echo 183 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 184 | -------------------------------------------------------------------------------- /doc/api.rst: -------------------------------------------------------------------------------- 1 | 2 | 3 | API Reference 4 | ............. 5 | 6 | .. currentmodule:: upsetplot 7 | 8 | Plotting 9 | -------- 10 | 11 | .. autofunction:: plot 12 | 13 | .. autoclass:: UpSet 14 | :members: 15 | 16 | 17 | Dataset loading and generation 18 | ------------------------------ 19 | 20 | .. autofunction:: from_contents 21 | 22 | .. autofunction:: from_indicators 23 | 24 | .. autofunction:: from_memberships 25 | 26 | .. autofunction:: generate_counts 27 | 28 | .. autofunction:: generate_samples 29 | 30 | Data querying and transformation 31 | -------------------------------- 32 | 33 | .. autofunction:: query 34 | -------------------------------------------------------------------------------- /doc/changelog.rst: -------------------------------------------------------------------------------- 1 | 2 | Changelog 3 | ......... 4 | 5 | .. 
include:: ../CHANGELOG.rst 6 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # project-template documentation build configuration file, created by 2 | # sphinx-quickstart on Mon Jan 18 14:44:12 2016. 3 | # 4 | # This file is execfile()d with the current directory set to its 5 | # containing dir. 6 | # 7 | # Note that not all possible configuration values are present in this 8 | # autogenerated file. 9 | # 10 | # All configuration values have a default; values that are commented out 11 | # serve to show the default. 12 | 13 | import os 14 | import re 15 | import sys 16 | import warnings 17 | 18 | # project root 19 | sys.path.insert(0, os.path.abspath("..")) 20 | 21 | import matplotlib # noqa 22 | 23 | matplotlib.use("agg") 24 | warnings.filterwarnings( 25 | "ignore", 26 | category=UserWarning, 27 | message="Matplotlib is currently using agg, which is a" 28 | " non-GUI backend, so cannot show the figure." 29 | "|(\n|.)*is non-interactive, and thus cannot be shown", 30 | ) 31 | 32 | import sphinx_rtd_theme # noqa 33 | from sphinx_gallery.sorting import ExampleTitleSortKey # noqa 34 | from upsetplot import __version__ as release # noqa 35 | 36 | 37 | # If extensions (or modules to document with autodoc) are in another directory, 38 | # add these directories to sys.path here. If the directory is relative to the 39 | # documentation root, use os.path.abspath to make it absolute, like shown here. 40 | 41 | # -- General configuration --------------------------------------------------- 42 | 43 | # If your documentation needs a minimal Sphinx version, state it here. 44 | # needs_sphinx = '1.0' 45 | 46 | # Add any Sphinx extension module names here, as strings. They can be 47 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 48 | # ones. 
49 | extensions = [ 50 | "sphinx_gallery.gen_gallery", 51 | "sphinx.ext.autodoc", 52 | "sphinx.ext.autosummary", 53 | "sphinx.ext.doctest", 54 | "sphinx.ext.intersphinx", 55 | "sphinx.ext.todo", 56 | "numpydoc", 57 | "sphinx.ext.ifconfig", 58 | "sphinx.ext.viewcode", 59 | "sphinx_issues", 60 | "nbsphinx", 61 | ] 62 | 63 | # Add any paths that contain templates here, relative to this directory. 64 | templates_path = ["_templates"] 65 | 66 | # The suffix of source filenames. 67 | source_suffix = ".rst" 68 | 69 | # The encoding of source files. 70 | # source_encoding = 'utf-8-sig' 71 | 72 | # The master toctree document. 73 | master_doc = "index" 74 | 75 | # General information about the project. 76 | project = "upsetplot" 77 | copyright = "2018-2024, Joel Nothman" 78 | 79 | # The version info for the project you're documenting, acts as replacement for 80 | # |version| and |release|, also used in various other places throughout the 81 | # built documents. 82 | # 83 | # The short X.Y version. 84 | 85 | version = re.match(r"^\d+(\.\d+)*", release).group() 86 | 87 | # version = upsetplot.__version__ 88 | # The full version, including alpha/beta/rc tags. 89 | # release = version 90 | 91 | # The language for content autogenerated by Sphinx. Refer to documentation 92 | # for a list of supported languages. 93 | # language = None 94 | 95 | # There are two options for replacing |today|: either, you set today to some 96 | # non-false value, then it is used: 97 | # today = '' 98 | # Else, today_fmt is used as the format for a strftime call. 99 | # today_fmt = '%B %d, %Y' 100 | 101 | # List of patterns, relative to source directory, that match files and 102 | # directories to ignore when looking for source files. 103 | exclude_patterns = ["_build"] 104 | 105 | # The reST default role (used for this markup: `text`) to use for all 106 | # documents. 107 | default_role = "any" 108 | 109 | # If true, '()' will be appended to :func: etc. cross-reference text. 
110 | # add_function_parentheses = True 111 | 112 | # If true, the current module name will be prepended to all description 113 | # unit titles (such as .. function::). 114 | # add_module_names = True 115 | 116 | # If true, sectionauthor and moduleauthor directives will be shown in the 117 | # output. They are ignored by default. 118 | # show_authors = False 119 | 120 | # The name of the Pygments (syntax highlighting) style to use. 121 | pygments_style = "sphinx" 122 | 123 | # A list of ignored prefixes for module index sorting. 124 | # modindex_common_prefix = [] 125 | 126 | # If true, keep warnings as "system message" paragraphs in the built documents. 127 | # keep_warnings = False 128 | 129 | 130 | # -- Options for HTML output ---------------------------------------------- 131 | 132 | # The theme to use for HTML and HTML Help pages. See the documentation for 133 | # a list of builtin themes. 134 | html_theme = "sphinx_rtd_theme" 135 | 136 | # Theme options are theme-specific and customize the look and feel of a theme 137 | # further. For a list of options available for each theme, see the 138 | # documentation. 139 | # html_theme_options = {} 140 | 141 | # Add any paths that contain custom themes here, relative to this directory. 142 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 143 | 144 | # The name for this set of Sphinx documents. If None, it defaults to 145 | # " v documentation". 146 | # html_title = None 147 | 148 | # A shorter title for the navigation bar. Default is the same as html_title. 149 | # html_short_title = None 150 | 151 | # The name of an image file (relative to this directory) to place at the top 152 | # of the sidebar. 153 | # html_logo = None 154 | 155 | # The name of an image file (within the static path) to use as favicon of the 156 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 157 | # pixels large. 
158 | # html_favicon = None 159 | 160 | # Add any paths that contain custom static files (such as style sheets) here, 161 | # relative to this directory. They are copied after the builtin static files, 162 | # so a file named "default.css" will overwrite the builtin "default.css". 163 | html_static_path = ["_static"] 164 | 165 | # Add any extra paths that contain custom files (such as robots.txt or 166 | # .htaccess) here, relative to this directory. These files are copied 167 | # directly to the root of the documentation. 168 | # html_extra_path = [] 169 | 170 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 171 | # using the given strftime format. 172 | # html_last_updated_fmt = '%b %d, %Y' 173 | 174 | # If true, SmartyPants will be used to convert quotes and dashes to 175 | # typographically correct entities. 176 | # html_use_smartypants = True 177 | 178 | # Custom sidebar templates, maps document names to template names. 179 | # html_sidebars = {} 180 | 181 | # Additional templates that should be rendered to pages, maps page names to 182 | # template names. 183 | # html_additional_pages = {} 184 | 185 | # If false, no module index is generated. 186 | # html_domain_indices = True 187 | 188 | # If false, no index is generated. 189 | # html_use_index = True 190 | 191 | # If true, the index is split into individual pages for each letter. 192 | # html_split_index = False 193 | 194 | # If true, links to the reST sources are added to the pages. 195 | # html_show_sourcelink = True 196 | 197 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 198 | # html_show_sphinx = True 199 | 200 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 201 | # html_show_copyright = True 202 | 203 | # If true, an OpenSearch description file will be output, and all pages will 204 | # contain a tag referring to it. 
The value of this option must be the 205 | # base URL from which the finished HTML is served. 206 | # html_use_opensearch = '' 207 | 208 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 209 | # html_file_suffix = None 210 | 211 | # Output file base name for HTML help builder. 212 | htmlhelp_basename = "project-templatedoc" 213 | 214 | 215 | # -- Options for LaTeX output --------------------------------------------- 216 | 217 | latex_elements = { 218 | # The paper size ('letterpaper' or 'a4paper'). 219 | # 'papersize': 'letterpaper', 220 | # The font size ('10pt', '11pt' or '12pt'). 221 | # 'pointsize': '10pt', 222 | # Additional stuff for the LaTeX preamble. 223 | # 'preamble': '', 224 | } 225 | 226 | # Grouping the document tree into LaTeX files. List of tuples 227 | # (source start file, target name, title, 228 | # author, documentclass [howto, manual, or own class]). 229 | latex_documents = [ 230 | ("index", "upsetplot.tex", "upsetplot Documentation", "Joel Nothman", "manual"), 231 | ] 232 | 233 | # The name of an image file (relative to this directory) to place at the top of 234 | # the title page. 235 | # latex_logo = None 236 | 237 | # For "manual" documents, if this is true, then toplevel headings are parts, 238 | # not chapters. 239 | # latex_use_parts = False 240 | 241 | # If true, show page references after internal links. 242 | # latex_show_pagerefs = False 243 | 244 | # If true, show URL addresses after external links. 245 | # latex_show_urls = False 246 | 247 | # Documents to append as an appendix to all manuals. 248 | # latex_appendices = [] 249 | 250 | # If false, no module index is generated. 251 | # latex_domain_indices = True 252 | 253 | # Documents to append as an appendix to all manuals. 254 | # texinfo_appendices = [] 255 | 256 | # If false, no module index is generated. 257 | # texinfo_domain_indices = True 258 | 259 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 
260 | # texinfo_show_urls = 'footnote' 261 | 262 | # If true, do not generate a @detailmenu in the "Top" node's menu. 263 | # texinfo_no_detailmenu = False 264 | 265 | 266 | # Example configuration for intersphinx: refer to the Python standard library. 267 | intersphinx_mapping = { 268 | "python": ("http://docs.python.org/", None), 269 | "numpy": ("https://docs.scipy.org/doc/numpy/", None), 270 | "matplotlib": ("https://matplotlib.org/", None), 271 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), 272 | } 273 | 274 | 275 | # Config for sphinx_issues 276 | 277 | issues_uri = "https://github.com/jnothman/upsetplot/issues/{issue}" 278 | issues_github_path = "jnothman/upsetplot" 279 | issues_user_uri = "https://github.com/{user}" 280 | 281 | 282 | sphinx_gallery_conf = { 283 | # path to your examples scripts 284 | "examples_dirs": "../examples", 285 | # path where to save gallery generated examples 286 | "gallery_dirs": "auto_examples", 287 | "backreferences_dir": "_modules", 288 | "within_subsection_order": ExampleTitleSortKey, 289 | } 290 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | 3 | .. toctree:: 4 | 5 | auto_examples/index 6 | formats.ipynb 7 | api 8 | changelog 9 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 
11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | goto end 41 | ) 42 | 43 | if "%1" == "clean" ( 44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 45 | del /q /s %BUILDDIR%\* 46 | goto end 47 | ) 48 | 49 | 50 | %SPHINXBUILD% 2> nul 51 | if errorlevel 9009 ( 52 | echo. 53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 54 | echo.installed, then set the SPHINXBUILD environment variable to point 55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 56 | echo.may add the Sphinx directory to PATH. 57 | echo. 
58 | echo.If you don't have Sphinx installed, grab it from 59 | echo.http://sphinx-doc.org/ 60 | exit /b 1 61 | ) 62 | 63 | if "%1" == "html" ( 64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 68 | goto end 69 | ) 70 | 71 | if "%1" == "dirhtml" ( 72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 76 | goto end 77 | ) 78 | 79 | if "%1" == "singlehtml" ( 80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 84 | goto end 85 | ) 86 | 87 | if "%1" == "pickle" ( 88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can process the pickle files. 92 | goto end 93 | ) 94 | 95 | if "%1" == "json" ( 96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 97 | if errorlevel 1 exit /b 1 98 | echo. 99 | echo.Build finished; now you can process the JSON files. 100 | goto end 101 | ) 102 | 103 | if "%1" == "htmlhelp" ( 104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 105 | if errorlevel 1 exit /b 1 106 | echo. 107 | echo.Build finished; now you can run HTML Help Workshop with the ^ 108 | .hhp project file in %BUILDDIR%/htmlhelp. 109 | goto end 110 | ) 111 | 112 | if "%1" == "qthelp" ( 113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 114 | if errorlevel 1 exit /b 1 115 | echo. 
116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 117 | .qhcp project file in %BUILDDIR%/qthelp, like this: 118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\project-template.qhcp 119 | echo.To view the help file: 120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\project-template.ghc 121 | goto end 122 | ) 123 | 124 | if "%1" == "devhelp" ( 125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished. 129 | goto end 130 | ) 131 | 132 | if "%1" == "epub" ( 133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 137 | goto end 138 | ) 139 | 140 | if "%1" == "latex" ( 141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 145 | goto end 146 | ) 147 | 148 | if "%1" == "latexpdf" ( 149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 150 | cd %BUILDDIR%/latex 151 | make all-pdf 152 | cd %BUILDDIR%/.. 153 | echo. 154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 155 | goto end 156 | ) 157 | 158 | if "%1" == "latexpdfja" ( 159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 160 | cd %BUILDDIR%/latex 161 | make all-pdf-ja 162 | cd %BUILDDIR%/.. 163 | echo. 164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 165 | goto end 166 | ) 167 | 168 | if "%1" == "text" ( 169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 170 | if errorlevel 1 exit /b 1 171 | echo. 172 | echo.Build finished. The text files are in %BUILDDIR%/text. 173 | goto end 174 | ) 175 | 176 | if "%1" == "man" ( 177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 178 | if errorlevel 1 exit /b 1 179 | echo. 180 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 
181 | goto end 182 | ) 183 | 184 | if "%1" == "texinfo" ( 185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 186 | if errorlevel 1 exit /b 1 187 | echo. 188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 189 | goto end 190 | ) 191 | 192 | if "%1" == "gettext" ( 193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 194 | if errorlevel 1 exit /b 1 195 | echo. 196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 197 | goto end 198 | ) 199 | 200 | if "%1" == "changes" ( 201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 202 | if errorlevel 1 exit /b 1 203 | echo. 204 | echo.The overview file is in %BUILDDIR%/changes. 205 | goto end 206 | ) 207 | 208 | if "%1" == "linkcheck" ( 209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 210 | if errorlevel 1 exit /b 1 211 | echo. 212 | echo.Link check complete; look for any errors in the above output ^ 213 | or in %BUILDDIR%/linkcheck/output.txt. 214 | goto end 215 | ) 216 | 217 | if "%1" == "doctest" ( 218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 219 | if errorlevel 1 exit /b 1 220 | echo. 221 | echo.Testing of doctests in the sources finished, look at the ^ 222 | results in %BUILDDIR%/doctest/output.txt. 223 | goto end 224 | ) 225 | 226 | if "%1" == "xml" ( 227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 228 | if errorlevel 1 exit /b 1 229 | echo. 230 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 231 | goto end 232 | ) 233 | 234 | if "%1" == "pseudoxml" ( 235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 236 | if errorlevel 1 exit /b 1 237 | echo. 238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 
239 | goto end 240 | ) 241 | 242 | :end 243 | -------------------------------------------------------------------------------- /doc/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pandas 4 | matplotlib 5 | numpydoc 6 | sphinx-gallery 7 | sphinx-issues 8 | seaborn 9 | scikit-learn 10 | nbsphinx 11 | sphinx<2 12 | sphinx-rtd-theme 13 | -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | .. _general_examples: 2 | 3 | Examples 4 | ======== 5 | 6 | Introductory examples for upsetplot. 7 | -------------------------------------------------------------------------------- /examples/plot_customize_after_plot.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================== 3 | Design: Customizing axis labels 4 | =============================== 5 | 6 | This example illustrates how the return value of the plot method can be used 7 | to customize aspects of the plot, such as axis labels, legend position, etc. 
8 | """ 9 | 10 | from matplotlib import pyplot as plt 11 | 12 | from upsetplot import generate_counts, plot 13 | 14 | example = generate_counts() 15 | print(example) 16 | 17 | ########################################################################## 18 | 19 | plot_result = plot(example) 20 | plot_result["intersections"].set_ylabel("Subset size") 21 | plot_result["totals"].set_xlabel("Category size") 22 | plt.show() 23 | -------------------------------------------------------------------------------- /examples/plot_diabetes.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================== 3 | Data Vis: Feature distribution in Diabetes 4 | ========================================== 5 | 6 | Explore above-average attributes in the Diabetes dataset (Efron et al, 2004). 7 | 8 | Here we take some features correlated with disease progression, and look at the 9 | distribution of that disease progression value when each of these features is 10 | above average. 11 | 12 | The most correlated features are: 13 | 14 | - bmi body mass index 15 | - bp average blood pressure 16 | - s4 tch, total cholesterol / HDL 17 | - s5 ltg, possibly log of serum triglycerides level 18 | - s6 glu, blood sugar level 19 | 20 | This kind of dataset analysis may not be a practical use of UpSet, but helps 21 | to illustrate the :meth:`UpSet.add_catplot` feature. 
22 | """ 23 | 24 | import pandas as pd 25 | from matplotlib import pyplot as plt 26 | from sklearn.datasets import load_diabetes 27 | 28 | from upsetplot import UpSet 29 | 30 | # Load the dataset into a DataFrame 31 | diabetes = load_diabetes() 32 | diabetes_df = pd.DataFrame(diabetes.data, columns=diabetes.feature_names) 33 | 34 | # Get five features most correlated with median house value 35 | correls = diabetes_df.corrwith( 36 | pd.Series(diabetes.target), method="spearman" 37 | ).sort_values() 38 | top_features = correls.index[-5:] 39 | 40 | # Get a binary indicator of whether each top feature is above average 41 | diabetes_above_avg = diabetes_df > diabetes_df.median(axis=0) 42 | diabetes_above_avg = diabetes_above_avg[top_features] 43 | diabetes_above_avg = diabetes_above_avg.rename(columns=lambda x: x + ">") 44 | 45 | # Make this indicator mask an index of diabetes_df 46 | diabetes_df = pd.concat([diabetes_df, diabetes_above_avg], axis=1) 47 | diabetes_df = diabetes_df.set_index(list(diabetes_above_avg.columns)) 48 | 49 | # Also give us access to the target (median house value) 50 | diabetes_df = diabetes_df.assign(progression=diabetes.target) 51 | 52 | ########################################################################## 53 | 54 | # UpSet plot it! 
55 | upset = UpSet(diabetes_df, subset_size="count", intersection_plot_elements=3) 56 | upset.add_catplot(value="progression", kind="strip", color="blue") 57 | print(diabetes_df) 58 | upset.add_catplot(value="bmi", kind="strip", color="black") 59 | upset.plot() 60 | plt.title("UpSet with catplots, for orientation='horizontal'") 61 | plt.show() 62 | 63 | ########################################################################## 64 | 65 | # And again in vertical orientation 66 | 67 | upset = UpSet( 68 | diabetes_df, 69 | subset_size="count", 70 | intersection_plot_elements=3, 71 | orientation="vertical", 72 | ) 73 | upset.add_catplot(value="progression", kind="strip", color="blue") 74 | upset.add_catplot(value="bmi", kind="strip", color="black") 75 | upset.plot() 76 | plt.suptitle("UpSet with catplots, for orientation='vertical'") 77 | plt.show() 78 | -------------------------------------------------------------------------------- /examples/plot_discrete.py: -------------------------------------------------------------------------------- 1 | """ 2 | =========================================================== 3 | Data Vis: Plotting discrete variables as stacked bar charts 4 | =========================================================== 5 | 6 | Currently, a somewhat contrived example of `add_stacked_bars`. 
7 | """ 8 | 9 | import pandas as pd 10 | from matplotlib import cm 11 | from matplotlib import pyplot as plt 12 | 13 | from upsetplot import UpSet 14 | 15 | TITANIC_URL = ( 16 | "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv" # noqa 17 | ) 18 | df = pd.read_csv(TITANIC_URL) 19 | # Show UpSet on survival and first classs 20 | df = df.set_index(df.Survived == 1).set_index(df.Pclass == 1, append=True) 21 | 22 | upset = UpSet(df, intersection_plot_elements=0) # disable the default bar chart 23 | upset.add_stacked_bars( 24 | by="Sex", colors=cm.Pastel1, title="Count by gender", elements=10 25 | ) 26 | upset.plot() 27 | plt.suptitle("Gender for first class and survival on Titanic") 28 | plt.show() 29 | 30 | 31 | upset = UpSet( 32 | df, show_counts=True, orientation="vertical", intersection_plot_elements=0 33 | ) 34 | upset.add_stacked_bars( 35 | by="Sex", colors=cm.Pastel1, title="Count by gender", elements=10 36 | ) 37 | upset.plot() 38 | plt.suptitle("Same, but vertical, with counts shown") 39 | plt.show() 40 | -------------------------------------------------------------------------------- /examples/plot_generated.py: -------------------------------------------------------------------------------- 1 | """ 2 | =================================== 3 | Basic: Examples with generated data 4 | =================================== 5 | 6 | This example illustrates basic plotting functionality using generated data. 
7 | """ 8 | 9 | import matplotlib 10 | from matplotlib import pyplot as plt 11 | 12 | from upsetplot import generate_counts, plot 13 | 14 | example = generate_counts() 15 | print(example) 16 | 17 | ########################################################################## 18 | 19 | plot(example) 20 | plt.suptitle("Ordered by degree") 21 | plt.show() 22 | 23 | ########################################################################## 24 | 25 | plot(example, sort_by="cardinality") 26 | plt.suptitle("Ordered by cardinality") 27 | plt.show() 28 | 29 | ########################################################################## 30 | 31 | plot(example, show_counts="{:,}") 32 | plt.suptitle("With counts shown, using a thousands separator") 33 | plt.show() 34 | 35 | ########################################################################## 36 | 37 | plot(example, show_counts="%d", show_percentages=True) 38 | plt.suptitle("With counts and % shown") 39 | plt.show() 40 | 41 | ########################################################################## 42 | 43 | plot(example, show_percentages="{:.2%}") 44 | plt.suptitle("With fraction shown in custom format") 45 | plt.show() 46 | 47 | ########################################################################## 48 | 49 | matplotlib.rcParams["font.size"] = 6 50 | plot(example, show_percentages="{:.2%}") 51 | plt.suptitle("With a smaller font size") 52 | plt.show() 53 | -------------------------------------------------------------------------------- /examples/plot_hide.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================================= 3 | Basic: Hiding subsets based on size or degree 4 | ============================================= 5 | 6 | This illustrates the use of ``min_subset_size``, ``max_subset_size``, 7 | ``min_degree`` or ``max_degree``. 
8 | """ 9 | 10 | from matplotlib import pyplot as plt 11 | 12 | from upsetplot import generate_counts, plot 13 | 14 | example = generate_counts() 15 | 16 | plot(example, show_counts=True) 17 | plt.suptitle("Nothing hidden") 18 | plt.show() 19 | 20 | ########################################################################## 21 | 22 | plot(example, show_counts=True, min_subset_size=100) 23 | plt.suptitle("Small subsets hidden") 24 | plt.show() 25 | 26 | ########################################################################## 27 | 28 | plot(example, show_counts=True, max_subset_size=500) 29 | plt.suptitle("Large subsets hidden") 30 | plt.show() 31 | 32 | ########################################################################## 33 | 34 | plot(example, show_counts=True, min_degree=2) 35 | plt.suptitle("Degree <2 hidden") 36 | plt.show() 37 | 38 | ########################################################################## 39 | 40 | plot(example, show_counts=True, max_degree=2) 41 | plt.suptitle("Degree >2 hidden") 42 | plt.show() 43 | -------------------------------------------------------------------------------- /examples/plot_highlight.py: -------------------------------------------------------------------------------- 1 | """ 2 | ======================================= 3 | Data Vis: Highlighting selected subsets 4 | ======================================= 5 | 6 | Demonstrates use of the `style_subsets` method to mark some subsets as 7 | different. 8 | 9 | """ 10 | 11 | from matplotlib import pyplot as plt 12 | 13 | from upsetplot import UpSet, generate_counts 14 | 15 | example = generate_counts() 16 | 17 | ########################################################################## 18 | # Subsets can be styled by the categories present in them, and a legend 19 | # can be optionally generated. 
20 | 21 | upset = UpSet(example) 22 | upset.style_subsets(present=["cat1", "cat2"], facecolor="blue", label="special") 23 | upset.plot() 24 | plt.suptitle("Paint blue subsets including both cat1 and cat2; show a legend") 25 | plt.show() 26 | 27 | ########################################################################## 28 | # ... or styling can be applied by the categories absent in a subset. 29 | 30 | upset = UpSet(example, orientation="vertical") 31 | upset.style_subsets(present="cat2", absent="cat1", edgecolor="red", linewidth=2) 32 | upset.plot() 33 | plt.suptitle("Border for subsets including cat2 but not cat1") 34 | plt.show() 35 | 36 | ########################################################################## 37 | # ... or their size. 38 | 39 | upset = UpSet(example) 40 | upset.style_subsets( 41 | min_subset_size=1000, facecolor="lightblue", hatch="xx", label="big" 42 | ) 43 | upset.plot() 44 | plt.suptitle("Hatch subsets with size >1000") 45 | plt.show() 46 | 47 | ########################################################################## 48 | # ... or degree. 49 | 50 | upset = UpSet(example) 51 | upset.style_subsets(min_degree=1, facecolor="blue") 52 | upset.style_subsets(min_degree=2, facecolor="purple") 53 | upset.style_subsets(min_degree=3, facecolor="red") 54 | upset.plot() 55 | plt.suptitle("Coloring by degree") 56 | plt.show() 57 | 58 | ########################################################################## 59 | # Multiple stylings can be applied with different criteria in the same 60 | # plot. 
61 | 62 | 63 | upset = UpSet(example, facecolor="gray") 64 | upset.style_subsets(present="cat0", label="Contains cat0", facecolor="blue") 65 | upset.style_subsets( 66 | present="cat1", label="Contains cat1", hatch="xx", edgecolor="black" 67 | ) 68 | upset.style_subsets(present="cat2", label="Contains cat2", edgecolor="red") 69 | 70 | # reduce legend size: 71 | params = {"legend.fontsize": 8} 72 | with plt.rc_context(params): 73 | upset.plot() 74 | plt.suptitle("Styles for every category!") 75 | plt.show() 76 | -------------------------------------------------------------------------------- /examples/plot_highlight_categories.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================== 3 | Data Vis: Highlighting selected categories 4 | ========================================== 5 | 6 | Demonstrates use of the `style_categories` method to mark some 7 | categories differently. 8 | """ 9 | 10 | from matplotlib import pyplot as plt 11 | 12 | from upsetplot import UpSet, generate_counts 13 | 14 | example = generate_counts() 15 | 16 | 17 | ########################################################################## 18 | # Categories can be shaded by name with the ``shading_`` parameters. 19 | 20 | upset = UpSet(example) 21 | upset.style_categories("cat2", shading_edgecolor="darkgreen", shading_linewidth=1) 22 | upset.style_categories( 23 | "cat1", 24 | shading_facecolor="lavender", 25 | ) 26 | upset.plot() 27 | plt.suptitle("Shade or edge a category with color") 28 | plt.show() 29 | 30 | 31 | ########################################################################## 32 | # Category total bars can be styled with the ``bar_`` parameters. 33 | # You can also specify categories using a list of names. 
34 | 35 | upset = UpSet(example) 36 | upset.style_categories( 37 | ["cat2", "cat1"], bar_facecolor="aqua", bar_hatch="xx", bar_edgecolor="black" 38 | ) 39 | upset.plot() 40 | plt.suptitle("") 41 | plt.show() 42 | -------------------------------------------------------------------------------- /examples/plot_missingness.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================================== 3 | Basic: Plotting the distribution of missing values 4 | ================================================== 5 | 6 | UpSet plots are often used to show which variables are missing together. 7 | 8 | Passing a callable ``indicators=pd.isna`` to :func:`from_indicators` is 9 | an easy way to categorise a record by the variables that are missing in it. 10 | """ 11 | 12 | import pandas as pd 13 | from matplotlib import pyplot as plt 14 | 15 | from upsetplot import from_indicators, plot 16 | 17 | TITANIC_URL = ( 18 | "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv" # noqa 19 | ) 20 | data = pd.read_csv(TITANIC_URL) 21 | 22 | plot(from_indicators(indicators=pd.isna, data=data), show_counts=True) 23 | plt.show() 24 | -------------------------------------------------------------------------------- /examples/plot_sizing.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================================ 3 | Design: Customizing element size and figure size 4 | ================================================ 5 | 6 | This example illustrates controlling sizing within an UpSet plot. 
7 | """ 8 | 9 | from matplotlib import pyplot as plt 10 | 11 | from upsetplot import generate_counts, plot 12 | 13 | example = generate_counts() 14 | print(example) 15 | 16 | plot(example) 17 | plt.suptitle("Defaults") 18 | plt.show() 19 | 20 | ########################################################################## 21 | # upsetplot uses a grid of square "elements" to display. Controlling the 22 | # size of these elements affects all components of the plot. 23 | 24 | plot(example, element_size=40) 25 | plt.suptitle("Increased element_size") 26 | plt.show() 27 | 28 | ########################################################################## 29 | # When setting ``figsize`` explicitly, you then need to pass the figure to 30 | # ``plot``, and use ``element_size=None`` for optimal sizing. 31 | 32 | fig = plt.figure(figsize=(10, 3)) 33 | plot(example, fig=fig, element_size=None) 34 | plt.suptitle("Setting figsize explicitly") 35 | plt.show() 36 | 37 | ########################################################################## 38 | # Components in the plot can be resized by indicating how many elements 39 | # they should equate to. 40 | 41 | plot(example, intersection_plot_elements=3) 42 | plt.suptitle("Decreased intersection_plot_elements") 43 | plt.show() 44 | 45 | ########################################################################## 46 | 47 | plot(example, totals_plot_elements=5) 48 | plt.suptitle("Increased totals_plot_elements") 49 | plt.show() 50 | -------------------------------------------------------------------------------- /examples/plot_theming.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================ 3 | Design: Changing Plot Colors 4 | ============================ 5 | 6 | This example illustrates use of matplotlib and upsetplot color settings, aside 7 | from matplotlib style sheets, which can control colors as well as grid lines, 8 | fonts and tick display. 
9 | 10 | Upsetplot provides some color settings: 11 | 12 | * ``facecolor``: sets the color for intersection size bars, and for active 13 | matrix dots. Defaults to white on a dark background, otherwise black. 14 | * ``other_dots_color``: sets the color for other (inactive) dots. Specify as a 15 | color, or a float specifying opacity relative to facecolor. 16 | * ``shading_color``: sets the color odd rows. Specify as a color, or a float 17 | specifying opacity relative to facecolor. 18 | 19 | For an introduction to matplotlib theming see: 20 | 21 | * `Tutorial 22 | `__ 23 | * `Reference 24 | `__ 25 | """ 26 | 27 | from matplotlib import pyplot as plt 28 | 29 | from upsetplot import generate_counts, plot 30 | 31 | example = generate_counts() 32 | 33 | plot(example, facecolor="darkblue") 34 | plt.suptitle('facecolor="darkblue"') 35 | plt.show() 36 | 37 | ########################################################################## 38 | 39 | plot(example, facecolor="darkblue", shading_color="lightgray") 40 | plt.suptitle('facecolor="darkblue", shading_color="lightgray"') 41 | plt.show() 42 | 43 | ########################################################################## 44 | 45 | with plt.style.context("Solarize_Light2"): 46 | plot(example) 47 | plt.suptitle("matplotlib classic stylesheet") 48 | plt.show() 49 | 50 | ########################################################################## 51 | 52 | with plt.style.context("dark_background"): 53 | plot(example, show_counts=True) 54 | plt.suptitle("matplotlib dark_background stylesheet") 55 | plt.show() 56 | 57 | ########################################################################## 58 | 59 | with plt.style.context("dark_background"): 60 | plot(example, show_counts=True, shading_color=0.15) 61 | plt.suptitle("matplotlib dark_background stylesheet, shading_color=.15") 62 | plt.show() 63 | 64 | ########################################################################## 65 | 66 | with plt.style.context("dark_background"): 
67 | plot(example, show_counts=True, facecolor="red") 68 | plt.suptitle('matplotlib dark_background, facecolor="red"') 69 | plt.show() 70 | 71 | ########################################################################## 72 | 73 | with plt.style.context("dark_background"): 74 | plot( 75 | example, 76 | show_counts=True, 77 | facecolor="red", 78 | other_dots_color=0.4, 79 | shading_color=0.2, 80 | ) 81 | plt.suptitle("dark_background, red face, stronger other colors") 82 | plt.show() 83 | -------------------------------------------------------------------------------- /examples/plot_vertical.py: -------------------------------------------------------------------------------- 1 | """ 2 | =========================== 3 | Basic: Vertical orientation 4 | =========================== 5 | 6 | This illustrates the effect of orientation='vertical'. 7 | """ 8 | 9 | from matplotlib import pyplot as plt 10 | 11 | from upsetplot import generate_counts, plot 12 | 13 | example = generate_counts() 14 | plot(example, orientation="vertical") 15 | plt.suptitle("A vertical plot") 16 | plt.show() 17 | 18 | ########################################################################## 19 | 20 | plot(example, orientation="vertical", show_counts="{:d}") 21 | plt.suptitle("A vertical plot with counts shown") 22 | plt.show() 23 | 24 | ########################################################################## 25 | 26 | plot(example, orientation="vertical", show_counts="{:d}", show_percentages=True) 27 | plt.suptitle("With counts and percentages shown") 28 | plt.show() 29 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | extend-include = ["*.ipynb"] 3 | target-version = "py38" 4 | 5 | [tool.ruff.lint] 6 | # see https://docs.astral.sh/ruff/rules/ 7 | select = [ 8 | # Pyflakes 9 | "F", 10 | # Pycodestyle 11 | "E", 12 | "UP", 13 | "W", 14 | # isort 15 | 
"I", 16 | "B", # bugbear 17 | "C4", # comprehensions 18 | "PT", # pytest 19 | "SIM", # simplify 20 | ] 21 | ignore = [ 22 | "B007", # breaks for pandas.query 23 | "PT011", 24 | ] 25 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description = Draw Lex et al.'s UpSet plots with Pandas and Matplotlib 3 | long_description = file: README.rst 4 | author = Joel Nothman 5 | author_email = joel.nothman@gmail.com 6 | url = https://upsetplot.readthedocs.io 7 | license = BSD 3-Clause License 8 | classifiers = 9 | License :: OSI Approved :: BSD License 10 | Programming Language :: Python :: 3 11 | Programming Language :: Python :: 3.6 12 | Programming Language :: Python :: 3.10 13 | Topic :: Scientific/Engineering :: Visualization 14 | Intended Audience :: Science/Research 15 | 16 | [aliases] 17 | test = pytest 18 | 19 | [tool:pytest] 20 | addopts = --doctest-modules --verbose --cov=upsetplot --showlocals 21 | # --cov=upsetplot 22 | testpaths = upsetplot README.rst 23 | doctest_optionflags = ALLOW_UNICODE NORMALIZE_WHITESPACE ELLIPSIS 24 | 25 | [flake8] 26 | ignore = W503,W504 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | 6 | from setuptools import setup 7 | 8 | 9 | def setup_package(): 10 | src_path = os.path.dirname(os.path.abspath(sys.argv[0])) 11 | old_path = os.getcwd() 12 | os.chdir(src_path) 13 | sys.path.insert(0, src_path) 14 | 15 | try: 16 | os.environ["__IN-SETUP"] = "1" # ensures only version is imported 17 | from upsetplot import __version__ as version 18 | 19 | # See also setup.cfg 20 | setup( 21 | name="UpSetPlot", 22 | version=version, 23 | packages=["upsetplot"], 24 | license="BSD-3-Clause", 25 | extras_require={"testing": 
["pytest>=2.7", "pytest-cov<2.6"]}, 26 | # TODO: check versions 27 | install_requires=["pandas>=0.23", "matplotlib>=2.0"], 28 | ) 29 | finally: 30 | del sys.path[0] 31 | os.chdir(old_path) 32 | return 33 | 34 | 35 | if __name__ == "__main__": 36 | setup_package() 37 | -------------------------------------------------------------------------------- /upsetplot/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.10dev1" 2 | 3 | import os 4 | 5 | if os.environ.get("__IN-SETUP", None) != "1": 6 | from .data import ( 7 | from_contents, 8 | from_indicators, 9 | from_memberships, 10 | generate_counts, 11 | generate_data, 12 | generate_samples, 13 | ) 14 | from .plotting import UpSet, plot 15 | from .reformat import query 16 | 17 | __all__ = [ 18 | "UpSet", 19 | "generate_data", 20 | "generate_counts", 21 | "generate_samples", 22 | "plot", 23 | "from_memberships", 24 | "from_contents", 25 | "from_indicators", 26 | "query", 27 | ] 28 | -------------------------------------------------------------------------------- /upsetplot/data.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from numbers import Number 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def generate_samples(seed=0, n_samples=10000, n_categories=3): 9 | """Generate artificial samples assigned to set intersections 10 | 11 | Parameters 12 | ---------- 13 | seed : int 14 | A seed for randomisation 15 | n_samples : int 16 | Number of samples to generate 17 | n_categories : int 18 | Number of categories (named "cat0", "cat1", ...) to generate 19 | 20 | Returns 21 | ------- 22 | DataFrame 23 | Field 'value' is a weight or score for each element. 24 | Field 'index' is a unique id for each element. 25 | Index includes a boolean indicator mask for each category. 26 | 27 | Note: Further fields may be added in future versions. 
28 | 29 | See Also 30 | -------- 31 | generate_counts : Generates the counts for each subset of categories 32 | corresponding to these samples. 33 | """ 34 | rng = np.random.RandomState(seed) 35 | df = pd.DataFrame({"value": np.zeros(n_samples)}) 36 | for i in range(n_categories): 37 | r = rng.rand(n_samples) 38 | df["cat%d" % i] = r > rng.rand() 39 | df["value"] += r 40 | 41 | df.reset_index(inplace=True) 42 | df.set_index(["cat%d" % i for i in range(n_categories)], inplace=True) 43 | return df 44 | 45 | 46 | def generate_counts(seed=0, n_samples=10000, n_categories=3): 47 | """Generate artificial counts corresponding to set intersections 48 | 49 | Parameters 50 | ---------- 51 | seed : int 52 | A seed for randomisation 53 | n_samples : int 54 | Number of samples to generate statistics over 55 | n_categories : int 56 | Number of categories (named "cat0", "cat1", ...) to generate 57 | 58 | Returns 59 | ------- 60 | Series 61 | Counts indexed by boolean indicator mask for each category. 62 | 63 | See Also 64 | -------- 65 | generate_samples : Generates a DataFrame of samples that these counts are 66 | derived from. 
    """
    df = generate_samples(seed=seed, n_samples=n_samples, n_categories=n_categories)
    return df.value.groupby(level=list(range(n_categories))).count()


def generate_data(seed=0, n_samples=10000, n_sets=3, aggregated=False):
    """Deprecated since 0.3: use generate_counts or generate_samples instead.

    Dispatches to :func:`generate_counts` when ``aggregated`` is True;
    otherwise returns the per-sample ``value`` Series from
    :func:`generate_samples`.
    """
    warnings.warn(
        "generate_data was replaced by generate_counts in version "
        "0.3 and will be removed in version 0.4.",
        DeprecationWarning,
        stacklevel=2,
    )
    if aggregated:
        return generate_counts(seed=seed, n_samples=n_samples, n_categories=n_sets)
    else:
        # Unaggregated form: only the per-sample 'value' column is returned.
        return generate_samples(seed=seed, n_samples=n_samples, n_categories=n_sets)[
            "value"
        ]


def from_indicators(indicators, data=None):
    """Load category membership indicated by a boolean indicator matrix

    This loader also supports the case where the indicator columns can be
    derived from `data`.

    .. versionadded:: 0.6

    Parameters
    ----------
    indicators : DataFrame-like of booleans, Sequence of str, or callable
        Specifies the category indicators (boolean mask arrays) within
        ``data``, i.e. which records in ``data`` belong to which categories.

        If a list of strings, these should be column names found in ``data``
        whose values are boolean mask arrays.

        If a DataFrame, its columns should correspond to categories, and its
        index should be a subset of those in ``data``, values should be True
        where a data record is in that category, and False or NA otherwise.

        If callable, it will be applied to ``data`` after the latter is
        converted to a Series or DataFrame.

    data : Series-like or DataFrame-like, optional
        If given, the index of category membership is attached to this data.
        It must have the same length as `indicators`.
        If not given, the series will contain the value 1.
115 | 116 | Returns 117 | ------- 118 | DataFrame or Series 119 | `data` is returned with its index indicating category membership. 120 | It will be a Series if `data` is a Series or 1d numeric array or None. 121 | 122 | Notes 123 | ----- 124 | Categories with indicators that are all False will be removed. 125 | 126 | Examples 127 | -------- 128 | >>> import pandas as pd 129 | >>> from upsetplot import from_indicators 130 | >>> 131 | >>> # Just indicators: 132 | >>> indicators = {"cat1": [True, False, True, False], 133 | ... "cat2": [False, True, False, False], 134 | ... "cat3": [True, True, False, False]} 135 | >>> from_indicators(indicators) 136 | cat1 cat2 cat3 137 | True False True 1.0 138 | False True True 1.0 139 | True False False 1.0 140 | False False False 1.0 141 | Name: ones, dtype: float64 142 | >>> 143 | >>> # Where indicators are included within data, specifying 144 | >>> # columns by name: 145 | >>> data = pd.DataFrame({"value": [5, 4, 6, 4], **indicators}) 146 | >>> from_indicators(["cat1", "cat3"], data=data) 147 | value cat1 cat2 cat3 148 | cat1 cat3 149 | True True 5 True False True 150 | False True 4 False True True 151 | True False 6 True False False 152 | False False 4 False False False 153 | >>> 154 | >>> # Making indicators out of all boolean columns: 155 | >>> from_indicators(lambda data: data.select_dtypes(bool), data=data) 156 | value cat1 cat2 cat3 157 | cat1 cat2 cat3 158 | True False True 5 True False True 159 | False True True 4 False True True 160 | True False False 6 True False False 161 | False False False 4 False False False 162 | >>> 163 | >>> # Using a dataset with missing data, we can use missingness as 164 | >>> # an indicator: 165 | >>> data = pd.DataFrame({"val1": [pd.NA, .7, pd.NA, .9], 166 | ... "val2": ["male", pd.NA, "female", "female"], 167 | ... 
"val3": [pd.NA, pd.NA, 23000, 78000]}) 168 | >>> from_indicators(pd.isna, data=data) 169 | val1 val2 val3 170 | val1 val2 val3 171 | True False True male 172 | False True True 0.7 173 | True False False female 23000 174 | False False False 0.9 female 78000 175 | """ 176 | if data is not None: 177 | data = _convert_to_pandas(data) 178 | 179 | if callable(indicators): 180 | if data is None: 181 | raise ValueError("data must be provided when indicators is " "callable") 182 | indicators = indicators(data) 183 | 184 | try: 185 | indicators[0] 186 | except Exception: 187 | pass 188 | else: 189 | if isinstance(indicators[0], (str, int)): 190 | if data is None: 191 | raise ValueError( 192 | "data must be provided when indicators are " 193 | "specified as a list of columns" 194 | ) 195 | if isinstance(indicators, tuple): 196 | raise ValueError("indicators as tuple is not supported") 197 | # column array 198 | indicators = data[indicators] 199 | 200 | indicators = pd.DataFrame(indicators).fillna(False).infer_objects() 201 | # drop all-False (should we be dropping all-True also? making an option?) 202 | indicators = indicators.loc[:, indicators.any(axis=0)] 203 | 204 | if not all(dtype.kind == "b" for dtype in indicators.dtypes): 205 | raise ValueError("The indicators must all be boolean") 206 | 207 | if data is not None: 208 | if not ( 209 | isinstance(indicators.index, pd.RangeIndex) 210 | and indicators.index[0] == 0 211 | and indicators.index[-1] == len(data) - 1 212 | ): 213 | # index is specified on indicators. 
Need to align it to data 214 | if not indicators.index.isin(data.index).all(): 215 | raise ValueError( 216 | "If indicators.index is not the default, " 217 | "all its values must be present in " 218 | "data.index" 219 | ) 220 | indicators = indicators.reindex(index=data.index, fill_value=False) 221 | else: 222 | data = pd.Series(np.ones(len(indicators)), name="ones") 223 | 224 | indicators.set_index(list(indicators.columns), inplace=True) 225 | data.index = indicators.index 226 | 227 | return data 228 | 229 | 230 | def _convert_to_pandas(data, copy=True): 231 | is_series = False 232 | if hasattr(data, "loc"): 233 | if copy: 234 | data = data.copy(deep=False) 235 | is_series = data.ndim == 1 236 | elif len(data): 237 | try: 238 | is_series = isinstance(data[0], Number) 239 | except KeyError: 240 | is_series = False 241 | return pd.Series(data) if is_series else pd.DataFrame(data) 242 | 243 | 244 | def from_memberships(memberships, data=None): 245 | """Load data where each sample has a collection of category names 246 | 247 | The output should be suitable for passing to `UpSet` or `plot`. 248 | 249 | Parameters 250 | ---------- 251 | memberships : sequence of collections of strings 252 | Each element corresponds to a data point, indicating the sets it is a 253 | member of. Each category is named by a string. 254 | data : Series-like or DataFrame-like, optional 255 | If given, the index of category memberships is attached to this data. 256 | It must have the same length as `memberships`. 257 | If not given, the series will contain the value 1. 258 | 259 | Returns 260 | ------- 261 | DataFrame or Series 262 | `data` is returned with its index indicating category membership. 263 | It will be a Series if `data` is a Series or 1d numeric array. 264 | The index will have levels ordered by category names. 265 | 266 | Examples 267 | -------- 268 | >>> from upsetplot import from_memberships 269 | >>> from_memberships([ 270 | ... ['cat1', 'cat3'], 271 | ... 
['cat2', 'cat3'], 272 | ... ['cat1'], 273 | ... [] 274 | ... ]) 275 | cat1 cat2 cat3 276 | True False True 1 277 | False True True 1 278 | True False False 1 279 | False False False 1 280 | Name: ones, dtype: ... 281 | >>> # now with data: 282 | >>> import numpy as np 283 | >>> from_memberships([ 284 | ... ['cat1', 'cat3'], 285 | ... ['cat2', 'cat3'], 286 | ... ['cat1'], 287 | ... [] 288 | ... ], data=np.arange(12).reshape(4, 3)) 289 | 0 1 2 290 | cat1 cat2 cat3 291 | True False True 0 1 2 292 | False True True 3 4 5 293 | True False False 6 7 8 294 | False False False 9 10 11 295 | """ 296 | df = pd.DataFrame([{name: True for name in names} for names in memberships]) 297 | for set_name in df.columns: 298 | if not hasattr(set_name, "lower"): 299 | raise ValueError("Category names should be strings") 300 | if df.shape[1] == 0: 301 | raise ValueError("Require at least one category. None were found.") 302 | df.sort_index(axis=1, inplace=True) 303 | df.fillna(False, inplace=True) 304 | df = df.astype(bool) 305 | df.set_index(list(df.columns), inplace=True) 306 | if data is None: 307 | return df.assign(ones=1)["ones"] 308 | 309 | data = _convert_to_pandas(data) 310 | if len(data) != len(df): 311 | raise ValueError( 312 | "memberships and data must have the same length. " 313 | "Got len(memberships) == %d, len(data) == %d" 314 | % (len(memberships), len(data)) 315 | ) 316 | data.index = df.index 317 | return data 318 | 319 | 320 | def from_contents(contents, data=None, id_column="id"): 321 | """Build data from category listings 322 | 323 | Parameters 324 | ---------- 325 | contents : Mapping (or iterable over pairs) of strings to sets 326 | Keys are category names, values are sets of identifiers (int or 327 | string). 328 | data : DataFrame, optional 329 | If provided, this should be indexed by the identifiers used in 330 | `contents`. 331 | id_column : str, default='id' 332 | The column name to use for the identifiers in the output. 
333 | 334 | Returns 335 | ------- 336 | DataFrame 337 | `data` is returned with its index indicating category membership, 338 | including a column named according to id_column. 339 | If data is not given, the order of rows is not assured. 340 | 341 | Notes 342 | ----- 343 | The order of categories in the output DataFrame is determined from 344 | `contents`, which may have non-deterministic iteration order. 345 | 346 | Examples 347 | -------- 348 | >>> from upsetplot import from_contents 349 | >>> contents = {'cat1': ['a', 'b', 'c'], 350 | ... 'cat2': ['b', 'd'], 351 | ... 'cat3': ['e']} 352 | >>> from_contents(contents) 353 | id 354 | cat1 cat2 cat3 355 | True False False a 356 | True False b 357 | False False c 358 | False True False d 359 | False True e 360 | >>> import pandas as pd 361 | >>> contents = {'cat1': [0, 1, 2], 362 | ... 'cat2': [1, 3], 363 | ... 'cat3': [4]} 364 | >>> data = pd.DataFrame({'favourite': ['green', 'red', 'red', 365 | ... 'yellow', 'blue']}) 366 | >>> from_contents(contents, data=data) 367 | id favourite 368 | cat1 cat2 cat3 369 | True False False 0 green 370 | True False 1 red 371 | False False 2 red 372 | False True False 3 yellow 373 | False True 4 blue 374 | """ 375 | cat_series = [ 376 | pd.Series(True, index=list(elements), name=name) 377 | for name, elements in contents.items() 378 | ] 379 | if not all(s.index.is_unique for s in cat_series): 380 | raise ValueError("Got duplicate ids in a category") 381 | 382 | df = pd.concat(cat_series, axis=1, sort=False) 383 | if id_column in df.columns: 384 | raise ValueError("A category cannot be named %r" % id_column) 385 | df.fillna(False, inplace=True) 386 | cat_names = list(df.columns) 387 | 388 | if data is not None: 389 | if set(df.columns).intersection(data.columns): 390 | raise ValueError("Data columns overlap with category names") 391 | if id_column in data.columns: 392 | raise ValueError("data cannot contain a column named %r" % id_column) 393 | not_in_data = df.drop(data.index, 
axis=0, errors="ignore") 394 | if len(not_in_data): 395 | raise ValueError( 396 | "Found identifiers in contents that are not in " 397 | "data: %r" % not_in_data.index.values 398 | ) 399 | df = df.reindex(index=data.index).fillna(False) 400 | df = pd.concat([data, df], axis=1, sort=False) 401 | df.index.name = id_column 402 | return df.reset_index().set_index(cat_names) 403 | -------------------------------------------------------------------------------- /upsetplot/plotting.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import warnings 3 | 4 | import matplotlib 5 | import numpy as np 6 | import pandas as pd 7 | from matplotlib import colors, patches 8 | from matplotlib import pyplot as plt 9 | 10 | from . import util 11 | from .reformat import _get_subset_mask, query 12 | 13 | # prevents ImportError on matplotlib versions >3.5.2 14 | try: 15 | from matplotlib.tight_layout import get_renderer 16 | 17 | RENDERER_IMPORTED = True 18 | except ImportError: 19 | RENDERER_IMPORTED = False 20 | 21 | 22 | def _process_data( 23 | df, 24 | *, 25 | sort_by, 26 | sort_categories_by, 27 | subset_size, 28 | sum_over, 29 | min_subset_size=None, 30 | max_subset_size=None, 31 | max_subset_rank=None, 32 | min_degree=None, 33 | max_degree=None, 34 | reverse=False, 35 | include_empty_subsets=False, 36 | ): 37 | results = query( 38 | df, 39 | sort_by=sort_by, 40 | sort_categories_by=sort_categories_by, 41 | subset_size=subset_size, 42 | sum_over=sum_over, 43 | min_subset_size=min_subset_size, 44 | max_subset_size=max_subset_size, 45 | max_subset_rank=max_subset_rank, 46 | min_degree=min_degree, 47 | max_degree=max_degree, 48 | include_empty_subsets=include_empty_subsets, 49 | ) 50 | 51 | df = results.data 52 | agg = results.subset_sizes 53 | 54 | # add '_bin' to df indicating index in agg 55 | # XXX: ugly! 
    def _pack_binary(X):
        # Encode each row of booleans as a single integer key: earlier
        # columns become more significant bits (out is doubled before each
        # column is added).
        X = pd.DataFrame(X)
        # use objects if arbitrary precision integers are needed
        dtype = np.object_ if X.shape[1] > 62 else np.uint64
        out = pd.Series(0, index=X.index, dtype=dtype)
        for _, col in X.items():
            out *= 2
            out += col
        return out

    # Pack both the per-row index of df and the subset index of agg to a
    # common integer code, then map each df row to the position of its
    # subset within agg (optionally in reversed order).
    df_packed = _pack_binary(df.index.to_frame())
    data_packed = _pack_binary(agg.index.to_frame())
    df["_bin"] = pd.Series(df_packed).map(
        pd.Series(
            np.arange(len(data_packed))[:: -1 if reverse else 1], index=data_packed
        )
    )
    if reverse:
        # Positions above were assigned in reversed order; reverse agg to
        # keep the two consistent.
        agg = agg[::-1]

    return results.total, df, agg, results.category_totals


def _multiply_alpha(c, mult):
    # Scale the alpha channel of color `c` by `mult` and return a hex color
    # string that retains the alpha component.
    r, g, b, a = colors.to_rgba(c)
    a *= mult
    return colors.to_hex((r, g, b, a), keep_alpha=True)


class _Transposed:
    """Wrap an object in order to transpose some plotting operations

    Attributes of obj will be mapped.
    Keyword arguments when calling obj will be mapped.

    The mapping is not recursive: callable attributes need to be _Transposed
    again.
93 | """ 94 | 95 | def __init__(self, obj): 96 | self.__obj = obj 97 | 98 | def __getattr__(self, key): 99 | return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key)) 100 | 101 | def __call__(self, *args, **kwargs): 102 | return self.__obj( 103 | *args, **{self._NAME_TRANSPOSE.get(k, k): v for k, v in kwargs.items()} 104 | ) 105 | 106 | _NAME_TRANSPOSE = { 107 | "align_xlabels": "align_ylabels", 108 | "align_ylabels": "align_xlabels", 109 | "bar": "barh", 110 | "barh": "bar", 111 | "bottom": "left", 112 | "get_figheight": "get_figwidth", 113 | "get_figwidth": "get_figheight", 114 | "get_xlim": "get_ylim", 115 | "get_ylim": "get_xlim", 116 | "height": "width", 117 | "hlines": "vlines", 118 | "hspace": "wspace", 119 | "left": "bottom", 120 | "right": "top", 121 | "set_autoscalex_on": "set_autoscaley_on", 122 | "set_autoscaley_on": "set_autoscalex_on", 123 | "set_figheight": "set_figwidth", 124 | "set_figwidth": "set_figheight", 125 | "set_xlabel": "set_ylabel", 126 | "set_xlim": "set_ylim", 127 | "set_ylabel": "set_xlabel", 128 | "set_ylim": "set_xlim", 129 | "sharex": "sharey", 130 | "sharey": "sharex", 131 | "top": "right", 132 | "vlines": "hlines", 133 | "width": "height", 134 | "wspace": "hspace", 135 | "xaxis": "yaxis", 136 | "yaxis": "xaxis", 137 | } 138 | 139 | 140 | def _transpose(obj): 141 | if isinstance(obj, str): 142 | return _Transposed._NAME_TRANSPOSE.get(obj, obj) 143 | return _Transposed(obj) 144 | 145 | 146 | def _identity(obj): 147 | return obj 148 | 149 | 150 | class UpSet: 151 | """Manage the data and drawing for a basic UpSet plot 152 | 153 | Primary public method is :meth:`plot`. 154 | 155 | Parameters 156 | ---------- 157 | data : pandas.Series or pandas.DataFrame 158 | Elements associated with categories (a DataFrame), or the size of each 159 | subset of categories (a Series). 160 | Should have MultiIndex where each level is binary, 161 | corresponding to category membership. 162 | If a DataFrame, `sum_over` must be a string or False. 
163 | orientation : {'horizontal' (default), 'vertical'} 164 | If horizontal, intersections are listed from left to right. 165 | sort_by : {'cardinality', 'degree', '-cardinality', '-degree', 166 | 'input', '-input'} 167 | If 'cardinality', subset are listed from largest to smallest. 168 | If 'degree', they are listed in order of the number of categories 169 | intersected. If 'input', the order they appear in the data input is 170 | used. 171 | Prefix with '-' to reverse the ordering. 172 | 173 | Note this affects ``subset_sizes`` but not ``data``. 174 | sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'} 175 | Whether to sort the categories by total cardinality, or leave them 176 | in the input data's provided order (order of index levels). 177 | Prefix with '-' to reverse the ordering. 178 | subset_size : {'auto', 'count', 'sum'} 179 | Configures how to calculate the size of a subset. Choices are: 180 | 181 | 'auto' (default) 182 | If `data` is a DataFrame, count the number of rows in each group, 183 | unless `sum_over` is specified. 184 | If `data` is a Series with at most one row for each group, use 185 | the value of the Series. If `data` is a Series with more than one 186 | row per group, raise a ValueError. 187 | 'count' 188 | Count the number of rows in each group. 189 | 'sum' 190 | Sum the value of the `data` Series, or the DataFrame field 191 | specified by `sum_over`. 192 | sum_over : str or None 193 | If `subset_size='sum'` or `'auto'`, then the intersection size is the 194 | sum of the specified field in the `data` DataFrame. If a Series, only 195 | None is supported and its value is summed. 196 | min_subset_size : int or "number%", optional 197 | Minimum size of a subset to be shown in the plot. All subsets with 198 | a size smaller than this threshold will be omitted from plotting. 199 | This may be specified as a percentage 200 | using a string, like "50%". 201 | Size may be a sum of values, see `subset_size`. 202 | 203 | .. 
versionadded:: 0.5 204 | 205 | .. versionchanged:: 0.9 206 | Support percentages 207 | max_subset_size : int or "number%", optional 208 | Maximum size of a subset to be shown in the plot. All subsets with 209 | a size greater than this threshold will be omitted from plotting. 210 | This may be specified as a percentage 211 | using a string, like "50%". 212 | 213 | .. versionadded:: 0.5 214 | 215 | .. versionchanged:: 0.9 216 | Support percentages 217 | max_subset_rank : int, optional 218 | Limit to the top N ranked subsets in descending order of size. 219 | All tied subsets are included. 220 | 221 | .. versionadded:: 0.9 222 | min_degree : int, optional 223 | Minimum degree of a subset to be shown in the plot. 224 | 225 | .. versionadded:: 0.5 226 | max_degree : int, optional 227 | Maximum degree of a subset to be shown in the plot. 228 | 229 | .. versionadded:: 0.5 230 | facecolor : 'auto' or matplotlib color or float 231 | Color for bar charts and active dots. Defaults to black if 232 | axes.facecolor is a light color, otherwise white. 233 | 234 | .. versionchanged:: 0.6 235 | Before 0.6, the default was 'black' 236 | other_dots_color : matplotlib color or float 237 | Color for shading of inactive dots, or opacity (between 0 and 1) 238 | applied to facecolor. 239 | 240 | .. versionadded:: 0.6 241 | shading_color : matplotlib color or float 242 | Color for shading of odd rows in matrix and totals, or opacity (between 243 | 0 and 1) applied to facecolor. 244 | 245 | .. versionadded:: 0.6 246 | with_lines : bool 247 | Whether to show lines joining dots in the matrix, to mark multiple 248 | categories being intersected. 249 | element_size : float or None 250 | Side length in pt. If None, size is estimated to fit figure 251 | intersection_plot_elements : int 252 | The intersections plot should be large enough to fit this many matrix 253 | elements. Set to 0 to disable intersection size bars. 254 | 255 | .. versionchanged:: 0.4 256 | Setting to 0 is handled. 
257 | totals_plot_elements : int 258 | The totals plot should be large enough to fit this many matrix 259 | elements. Set to 0 to disable the totals plot. 260 | 261 | .. versionchanged:: 0.9 262 | Setting to 0 is handled. 263 | show_counts : bool or str, default=False 264 | Whether to label the intersection size bars with the cardinality 265 | of the intersection. When a string, this formats the number. 266 | For example, '{:d}' is equivalent to True. 267 | Note that, for legacy reasons, if the string does not contain '{', 268 | it will be interpreted as a C-style format string, such as '%d'. 269 | show_percentages : bool or str, default=False 270 | Whether to label the intersection size bars with the percentage 271 | of the intersection relative to the total dataset. 272 | When a string, this formats the number representing a fraction of 273 | samples. 274 | For example, '{:.1%}' is the default, formatting .123 as 12.3%. 275 | This may be applied with or without show_counts. 276 | 277 | .. versionadded:: 0.4 278 | include_empty_subsets : bool (default=False) 279 | If True, all possible category combinations will be shown as subsets, 280 | even when some are not present in data. 
    """

    # Default figure size in inches when plot() creates its own figure.
    _default_figsize = (10, 6)
    DPI = 100  # standard matplotlib value

    def __init__(
        self,
        data,
        orientation="horizontal",
        sort_by="degree",
        sort_categories_by="cardinality",
        subset_size="auto",
        sum_over=None,
        min_subset_size=None,
        max_subset_size=None,
        max_subset_rank=None,
        min_degree=None,
        max_degree=None,
        facecolor="auto",
        other_dots_color=0.18,
        shading_color=0.05,
        with_lines=True,
        element_size=32,
        intersection_plot_elements=6,
        totals_plot_elements=2,
        show_counts="",
        show_percentages=False,
        include_empty_subsets=False,
    ):
        # _reorient maps axes/artists so the same drawing code serves both
        # horizontal and vertical orientations.
        self._horizontal = orientation == "horizontal"
        self._reorient = _identity if self._horizontal else _transpose
        if facecolor == "auto":
            # Pick black on light backgrounds and white on dark ones, judged
            # by the HSV value (brightness) of axes.facecolor scaled by alpha.
            bgcolor = matplotlib.rcParams.get("axes.facecolor", "white")
            r, g, b, a = colors.to_rgba(bgcolor)
            lightness = colors.rgb_to_hsv((r, g, b))[-1] * a
            facecolor = "black" if lightness >= 0.5 else "white"
        self._facecolor = facecolor
        # Floats for shading/other-dots colors are treated as an opacity
        # applied to facecolor rather than as a color themselves.
        self._shading_color = (
            _multiply_alpha(facecolor, shading_color)
            if isinstance(shading_color, float)
            else shading_color
        )
        self._other_dots_color = (
            _multiply_alpha(facecolor, other_dots_color)
            if isinstance(other_dots_color, float)
            else other_dots_color
        )
        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements
        # The intersection-size bar chart is the first subset plot; further
        # plots may be appended by add_catplot()/add_stacked_bars().
        self._subset_plots = [
            {
                "type": "default",
                "id": "intersections",
                "elements": intersection_plot_elements,
            }
        ]
        if not intersection_plot_elements:
            # intersection_plot_elements=0 disables the intersection bars
            self._subset_plots.pop()
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        (self.total, self._df, self.intersections, self.totals) = _process_data(
            data,
            sort_by=sort_by,
            sort_categories_by=sort_categories_by,
            subset_size=subset_size,
            sum_over=sum_over,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            max_subset_rank=max_subset_rank,
            min_degree=min_degree,
            max_degree=max_degree,
            reverse=not self._horizontal,
            include_empty_subsets=include_empty_subsets,
        )
        # Per-category and per-subset style overrides, mutated by
        # style_categories() / style_subsets().
        self.category_styles = {}
        self.subset_styles = [
            {"facecolor": facecolor} for i in range(len(self.intersections))
        ]
        self.subset_legend = []  # pairs of (style, label)

    def _swapaxes(self, x, y):
        # Return (x, y) unchanged for horizontal plots, swapped for vertical.
        if self._horizontal:
            return x, y
        return y, x

    def style_subsets(
        self,
        present=None,
        absent=None,
        min_subset_size=None,
        max_subset_size=None,
        max_subset_rank=None,
        min_degree=None,
        max_degree=None,
        facecolor=None,
        edgecolor=None,
        hatch=None,
        linewidth=None,
        linestyle=None,
        label=None,
    ):
        """Updates the style of selected subsets' bars and matrix dots

        Parameters are either used to select subsets, or to style them with
        attributes of :class:`matplotlib.patches.Patch`, apart from label,
        which adds a legend entry.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories that must be present in subsets for styling.
        absent : str or list of str, optional
            Category or categories that must not be present in subsets for
            styling.
        min_subset_size : int or "number%", optional
            Minimum size of a subset to be styled.
            This may be specified as a percentage using a string, like "50%".

            .. versionchanged:: 0.9
                Support percentages
        max_subset_size : int or "number%", optional
            Maximum size of a subset to be styled.
            This may be specified as a percentage using a string, like "50%".

            .. versionchanged:: 0.9
                Support percentages
        max_subset_rank : int, optional
            Limit to the top N ranked subsets in descending order of size.
            All tied subsets are included.

            .. versionadded:: 0.9
        min_degree : int, optional
            Minimum degree of a subset to be styled.
        max_degree : int, optional
            Maximum degree of a subset to be styled.

        facecolor : str or matplotlib color, optional
            Override the default UpSet facecolor for selected subsets.
        edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        hatch : str, optional
            Set the hatch. This will apply to intersection size bars, but not
            to matrix dots.
        linewidth : int, optional
            Line width in points for edges.
        linestyle : str, optional
            Line style for edges.

        label : str, optional
            If provided, a legend will be added
        """
        style = {
            "facecolor": facecolor,
            "edgecolor": edgecolor,
            "hatch": hatch,
            "linewidth": linewidth,
            "linestyle": linestyle,
        }
        # Only keep explicitly requested attributes so unset ones do not
        # clobber earlier styling.
        style = {k: v for k, v in style.items() if v is not None}
        mask = _get_subset_mask(
            self.intersections,
            present=present,
            absent=absent,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            max_subset_rank=max_subset_rank,
            min_degree=min_degree,
            max_degree=max_degree,
        )
        for idx in np.flatnonzero(mask):
            self.subset_styles[idx].update(style)

        if label is not None:
            if "facecolor" not in style:
                style["facecolor"] = self._facecolor
            # Merge with an existing legend entry that has identical style,
            # concatenating labels; otherwise append a new entry.
            for i, (other_style, other_label) in enumerate(self.subset_legend):
                if other_style == style:
                    if other_label != label:
                        self.subset_legend[i] = (style, other_label + "; " + label)
                    break
            else:
                self.subset_legend.append((style, label))

    def _plot_bars(self, ax, data, title,
colors=None, use_labels=False): 467 | ax = self._reorient(ax) 468 | ax.set_autoscalex_on(False) 469 | data_df = pd.DataFrame(data) 470 | if self._horizontal: 471 | data_df = data_df.loc[:, ::-1] # reverse: top row is top of stack 472 | 473 | # TODO: colors should be broadcastable to data_df shape 474 | if callable(colors): 475 | colors = colors(range(data_df.shape[1])) 476 | elif isinstance(colors, (str, type(None))): 477 | colors = [colors] * len(data_df) 478 | 479 | if self._horizontal: 480 | colors = reversed(colors) 481 | 482 | x = np.arange(len(data_df)) 483 | cum_y = None 484 | all_rects = [] 485 | for (name, y), color in zip(data_df.items(), colors): 486 | rects = ax.bar( 487 | x, 488 | y, 489 | 0.5, 490 | cum_y, 491 | color=color, 492 | zorder=10, 493 | label=name if use_labels else None, 494 | align="center", 495 | ) 496 | cum_y = y if cum_y is None else cum_y + y 497 | all_rects.extend(rects) 498 | 499 | self._label_sizes(ax, rects, "top" if self._horizontal else "right") 500 | 501 | ax.xaxis.set_visible(False) 502 | for x in ["top", "bottom", "right"]: 503 | ax.spines[self._reorient(x)].set_visible(False) 504 | 505 | tick_axis = ax.yaxis 506 | tick_axis.grid(True) 507 | ax.set_ylabel(title) 508 | return all_rects 509 | 510 | def _plot_stacked_bars(self, ax, by, sum_over, colors, title): 511 | df = self._df.set_index("_bin").set_index(by, append=True, drop=False) 512 | gb = df.groupby(level=list(range(df.index.nlevels)), sort=True) 513 | if sum_over is None and "_value" in df.columns: 514 | data = gb["_value"].sum() 515 | elif sum_over is None: 516 | data = gb.size() 517 | else: 518 | data = gb[sum_over].sum() 519 | data = data.unstack(by).fillna(0) 520 | if isinstance(colors, str): 521 | colors = matplotlib.cm.get_cmap(colors) 522 | elif isinstance(colors, typing.Mapping): 523 | colors = data.columns.map(colors).values 524 | if pd.isna(colors).any(): 525 | raise KeyError( 526 | "Some labels mapped by colors: %r" 527 | % 
data.columns[pd.isna(colors)].tolist() 528 | ) 529 | 530 | self._plot_bars(ax, data=data, colors=colors, title=title, use_labels=True) 531 | 532 | handles, labels = ax.get_legend_handles_labels() 533 | if self._horizontal: 534 | # Make legend order match visual stack order 535 | ax.legend(reversed(handles), reversed(labels)) 536 | else: 537 | ax.legend() 538 | 539 | def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3, title=None): 540 | """Add a stacked bar chart over subsets when :func:`plot` is called. 541 | 542 | Used to plot categorical variable distributions within each subset. 543 | 544 | .. versionadded:: 0.6 545 | 546 | Parameters 547 | ---------- 548 | by : str 549 | Column name within the dataframe for color coding the stacked bars, 550 | containing discrete or categorical values. 551 | sum_over : str, optional 552 | Ordinarily the bars will chart the size of each group. sum_over 553 | may specify a column which will be summed to determine the size 554 | of each bar. 555 | colors : Mapping, list-like, str or callable, optional 556 | The facecolors to use for bars corresponding to each discrete 557 | label, specified as one of: 558 | 559 | Mapping 560 | Maps from label to matplotlib-compatible color specification. 561 | list-like 562 | A list of matplotlib colors to apply to labels in order. 563 | str 564 | The name of a matplotlib colormap name. 565 | callable 566 | When called with the number of labels, this should return a 567 | list-like of that many colors. Matplotlib colormaps satisfy 568 | this callable API. 569 | None 570 | Uses the matplotlib default colormap. 571 | elements : int, default=3 572 | Size of the axes counted in number of matrix elements. 573 | title : str, optional 574 | The axis title labelling bar length. 

        Returns
        -------
        None
        """
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
        # list of labels}
        self._subset_plots.append(
            {
                "type": "stacked_bars",
                "by": by,
                "sum_over": sum_over,
                "colors": colors,
                "title": title,
                # Sequential id; consumed by make_grid()/plot() for layout.
                "id": "extra%d" % len(self._subset_plots),
                "elements": elements,
            }
        )

    def add_catplot(self, kind, value=None, elements=3, **kw):
        """Add a seaborn catplot over subsets when :func:`plot` is called.

        Parameters
        ----------
        kind : str
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
        value : str, optional
            Column name for the value to plot (i.e. y if
            orientation='horizontal'), required if `data` is a DataFrame.
        elements : int, default=3
            Size of the axes counted in number of matrix elements.
        **kw : dict
            Additional keywords to pass to :func:`seaborn.catplot`.

            Our implementation automatically determines 'ax', 'data', 'x', 'y'
            and 'orient', so these are prohibited keys in `kw`.

        Returns
        -------
        None
        """
        # NOTE(review): assert is stripped under python -O, so these reserved
        # keys would pass through silently; consider raising ValueError.
        assert not set(kw.keys()) & {"ax", "data", "x", "y", "orient"}
        if value is None:
            # "_value" exists only when the input data was a Series.
            if "_value" not in self._df.columns:
                raise ValueError(
                    "value cannot be set if data is a Series. " "Got %r" % value
                )
        else:
            if value not in self._df.columns:
                raise ValueError("value %r is not a column in data" % value)
        self._subset_plots.append(
            {
                "type": "catplot",
                "value": value,
                "kind": kind,
                "id": "extra%d" % len(self._subset_plots),
                "elements": elements,
                "kw": kw,
            }
        )

    def _check_value(self, value):
        # Resolve the default value column for Series-derived data; a value
        # of None is only meaningful when "_value" exists.
        if value is None and "_value" in self._df.columns:
            value = "_value"
        elif value is None:
            raise ValueError("value can only be None when data is a Series")
        return value

    def _plot_catplot(self, ax, value, kind, kw):
        """Draw seaborn categorical plot `kind` of `value` over subsets."""
        df = self._df
        value = self._check_value(value)
        kw = kw.copy()
        # Orient the catplot to match the matrix: subsets along the shared
        # "_bin" axis, values along the other.
        if self._horizontal:
            kw["orient"] = "v"
            kw["x"] = "_bin"
            kw["y"] = value
        else:
            kw["orient"] = "h"
            kw["x"] = value
            kw["y"] = "_bin"
        # Deferred import: seaborn is an optional dependency, only needed here.
        import seaborn

        kw["ax"] = ax
        getattr(seaborn, kind + "plot")(data=df, **kw)

        ax = self._reorient(ax)
        if value == "_value":
            # hide the synthetic column name from the axis label
            ax.set_ylabel("")

        ax.xaxis.set_visible(False)
        for x in ["top", "bottom", "right"]:
            ax.spines[self._reorient(x)].set_visible(False)

        tick_axis = ax.yaxis
        tick_axis.grid(True)

    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width"""
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        # Determine text size to determine figure size / spacing
        text_kw = {"size": matplotlib.rcParams["xtick.labelsize"]}
        # adding "x" ensures a margin
        t = fig.text(
            0,
            0,
            "\n".join(str(label) + "x" for label in self.totals.index.values),
            **text_kw,
        )
        window_extent_args = {}
        if RENDERER_IMPORTED:
            # get_renderer is deprecated in newer matplotlib; silence it.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                window_extent_args["renderer"] = get_renderer(fig)
        textw = t.get_window_extent(**window_extent_args).width
        t.remove()

        window_extent_args = {}
        if RENDERER_IMPORTED:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                window_extent_args["renderer"] = get_renderer(fig)
        figw = self._reorient(fig.get_window_extent(**window_extent_args)).width

        sizes = np.asarray([p["elements"] for p in self._subset_plots])
        fig = self._reorient(fig)

        # Grid columns holding matrix dots and the totals bars (text excluded).
        non_text_nelems = len(self.intersections) + self._totals_plot_elements
        if self._element_size is None:
            colw = (figw - textw) / non_text_nelems
        else:
            # Fixed element size: resize the figure to fit everything.
            render_ratio = figw / fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) / render_ratio)

        # Columns reserved for category label text.
        text_nelems = int(np.ceil(figw / colw - non_text_nelems))

        GS = self._reorient(matplotlib.gridspec.GridSpec)
        gridspec = GS(
            *self._swapaxes(
                n_cats + (sizes.sum() or 0),
                n_inters + text_nelems + self._totals_plot_elements,
            ),
            hspace=1,
        )
        if self._horizontal:
            out = {
                "matrix": gridspec[-n_cats:, -n_inters:],
                "shading": gridspec[-n_cats:, :],
                "totals": None
                if self._totals_plot_elements == 0
                else gridspec[-n_cats:, : self._totals_plot_elements],
                "gs": gridspec,
            }
            # Stack subset plots above the matrix, last-added on top.
            cumsizes = np.cumsum(sizes[::-1])
            for start, stop, plot in zip(
                np.hstack([[0], cumsizes]), cumsizes, self._subset_plots[::-1]
            ):
                out[plot["id"]] = gridspec[start:stop, -n_inters:]
        else:
            out = {
                "matrix": gridspec[-n_inters:, :n_cats],
                "shading": gridspec[:, :n_cats],
                "totals": None
                if self._totals_plot_elements == 0
                else
gridspec[: self._totals_plot_elements, :n_cats], 747 | "gs": gridspec, 748 | } 749 | cumsizes = np.cumsum(sizes) 750 | for start, stop, plot in zip( 751 | np.hstack([[0], cumsizes]), cumsizes, self._subset_plots 752 | ): 753 | out[plot["id"]] = gridspec[-n_inters:, start + n_cats : stop + n_cats] 754 | return out 755 | 756 | def plot_matrix(self, ax): 757 | """Plot the matrix of intersection indicators onto ax""" 758 | ax = self._reorient(ax) 759 | data = self.intersections 760 | n_cats = data.index.nlevels 761 | 762 | inclusion = data.index.to_frame().values 763 | 764 | # Prepare styling 765 | styles = [ 766 | [ 767 | self.subset_styles[i] 768 | if inclusion[i, j] 769 | else {"facecolor": self._other_dots_color, "linewidth": 0} 770 | for j in range(n_cats) 771 | ] 772 | for i in range(len(data)) 773 | ] 774 | styles = sum(styles, []) # flatten nested list 775 | style_columns = { 776 | "facecolor": "facecolors", 777 | "edgecolor": "edgecolors", 778 | "linewidth": "linewidths", 779 | "linestyle": "linestyles", 780 | "hatch": "hatch", 781 | } 782 | styles = ( 783 | pd.DataFrame(styles) 784 | .reindex(columns=style_columns.keys()) 785 | .astype( 786 | { 787 | "facecolor": "O", 788 | "edgecolor": "O", 789 | "linewidth": float, 790 | "linestyle": "O", 791 | "hatch": "O", 792 | } 793 | ) 794 | ) 795 | styles["linewidth"].fillna(1, inplace=True) 796 | styles["facecolor"].fillna(self._facecolor, inplace=True) 797 | styles["edgecolor"].fillna(styles["facecolor"], inplace=True) 798 | styles["linestyle"].fillna("solid", inplace=True) 799 | del styles["hatch"] # not supported in matrix (currently) 800 | 801 | x = np.repeat(np.arange(len(data)), n_cats) 802 | y = np.tile(np.arange(n_cats), len(data)) 803 | 804 | # Plot dots 805 | if self._element_size is not None: # noqa 806 | s = (self._element_size * 0.35) ** 2 807 | else: 808 | # TODO: make s relative to colw 809 | s = 200 810 | ax.scatter( 811 | *self._swapaxes(x, y), 812 | s=s, 813 | zorder=10, 814 | 
**styles.rename(columns=style_columns), 815 | ) 816 | 817 | # Plot lines 818 | if self._with_lines: 819 | idx = np.flatnonzero(inclusion) 820 | line_data = ( 821 | pd.Series(y[idx], index=x[idx]) 822 | .groupby(level=0) 823 | .aggregate(["min", "max"]) 824 | ) 825 | colors = pd.Series( 826 | [ 827 | style.get("edgecolor", style.get("facecolor", self._facecolor)) 828 | for style in self.subset_styles 829 | ], 830 | name="color", 831 | ) 832 | line_data = line_data.join(colors) 833 | ax.vlines( 834 | line_data.index.values, 835 | line_data["min"], 836 | line_data["max"], 837 | lw=2, 838 | colors=line_data["color"], 839 | zorder=5, 840 | ) 841 | 842 | # Ticks and axes 843 | tick_axis = ax.yaxis 844 | tick_axis.set_ticks(np.arange(n_cats)) 845 | tick_axis.set_ticklabels( 846 | data.index.names, rotation=0 if self._horizontal else -90 847 | ) 848 | ax.xaxis.set_visible(False) 849 | ax.tick_params(axis="both", which="both", length=0) 850 | if not self._horizontal: 851 | ax.yaxis.set_ticks_position("top") 852 | ax.set_frame_on(False) 853 | ax.set_xlim(-0.5, x[-1] + 0.5, auto=False) 854 | ax.grid(False) 855 | 856 | def plot_intersections(self, ax): 857 | """Plot bars indicating intersection size""" 858 | rects = self._plot_bars( 859 | ax, self.intersections, title="Intersection size", colors=self._facecolor 860 | ) 861 | for style, rect in zip(self.subset_styles, rects): 862 | style = style.copy() 863 | style.setdefault("edgecolor", style.get("facecolor", self._facecolor)) 864 | for attr, val in style.items(): 865 | getattr(rect, "set_" + attr)(val) 866 | 867 | if self.subset_legend: 868 | styles, labels = zip(*self.subset_legend) 869 | styles = [patches.Patch(**patch_style) for patch_style in styles] 870 | ax.legend(styles, labels) 871 | 872 | def _label_sizes(self, ax, rects, where): 873 | if not self._show_counts and not self._show_percentages: 874 | return 875 | if self._show_counts is True: 876 | count_fmt = "{:.0f}" 877 | else: 878 | count_fmt = self._show_counts 879 
            if "{" not in count_fmt:
                # Legacy: a C-style format string ('%d') is converted to
                # str.format style.
                count_fmt = util.to_new_pos_format(count_fmt)

        pct_fmt = "{:.1%}" if self._show_percentages is True else self._show_percentages

        # Build the combined label format and a make_args(val) helper that
        # supplies the right arguments (count, fraction, or both).
        if count_fmt and pct_fmt:
            if where == "top":
                fmt = f"{count_fmt}\n({pct_fmt})"
            else:
                fmt = f"{count_fmt} ({pct_fmt})"

            def make_args(val):
                return val, val / self.total
        elif count_fmt:
            fmt = count_fmt

            def make_args(val):
                return (val,)
        else:
            fmt = pct_fmt

            def make_args(val):
                return (val / self.total,)

        if where == "right":
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
            for rect in rects:
                width = rect.get_width() + rect.get_x()
                ax.text(
                    width + margin,
                    rect.get_y() + rect.get_height() * 0.5,
                    fmt.format(*make_args(width)),
                    ha="left",
                    va="center",
                )
        elif where == "left":
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
            for rect in rects:
                width = rect.get_width() + rect.get_x()
                ax.text(
                    width + margin,
                    rect.get_y() + rect.get_height() * 0.5,
                    fmt.format(*make_args(width)),
                    ha="right",
                    va="center",
                )
        elif where == "top":
            margin = 0.01 * abs(np.diff(ax.get_ylim()))
            for rect in rects:
                height = rect.get_height() + rect.get_y()
                ax.text(
                    rect.get_x() + rect.get_width() * 0.5,
                    height + margin,
                    fmt.format(*make_args(height)),
                    ha="center",
                    va="bottom",
                )
        else:
            raise NotImplementedError("unhandled where: %r" % where)

    def plot_totals(self, ax):
        """Plot bars indicating total set size"""
        orig_ax = ax
        ax = self._reorient(ax)
        rects = ax.barh(
            np.arange(len(self.totals.index.values)),
            self.totals,
            0.5,
            color=self._facecolor,
            align="center",
        )
        self._label_sizes(ax, rects, "left" if self._horizontal else "top")

        # Apply per-category "bar_*" style overrides, stripping the prefix.
        for category, rect in zip(self.totals.index.values, rects):
            style = {
                k[len("bar_") :]: v
                for k, v in self.category_styles.get(category, {}).items()
                if k.startswith("bar_")
            }
            style.setdefault("edgecolor", style.get("facecolor", self._facecolor))
            for attr, val in style.items():
                getattr(rect, "set_" + attr)(val)

        max_total = self.totals.max()
        if self._horizontal:
            # Reverse the x axis so bars grow towards the matrix.
            orig_ax.set_xlim(max_total, 0)
        for x in ["top", "left", "right"]:
            ax.spines[self._reorient(x)].set_visible(False)
        ax.yaxis.set_visible(False)
        ax.xaxis.grid(True)
        ax.yaxis.grid(False)
        ax.patch.set_visible(False)

    def plot_shading(self, ax):
        # shade all rows, set every second row to zero visibility
        for i, category in enumerate(self.totals.index):
            default_shading = (
                self._shading_color if i % 2 == 0 else (0.0, 0.0, 0.0, 0.0)
            )
            # Per-category "shading_*" overrides, stripping the prefix.
            shading_style = {
                k[len("shading_") :]: v
                for k, v in self.category_styles.get(category, {}).items()
                if k.startswith("shading_")
            }

            lw = shading_style.get(
                "linewidth", 1 if shading_style.get("edgecolor") else 0
            )
            # Inset the rectangle so a stroked edge stays inside the axes.
            lw_padding = lw / (self._default_figsize[0] * self.DPI)
            start_x = lw_padding
            end_x = 1 - lw_padding * 3

            rect = plt.Rectangle(
                self._swapaxes(start_x, i - 0.4),
                *self._swapaxes(end_x, 0.8),
                facecolor=shading_style.get("facecolor", default_shading),
                edgecolor=shading_style.get("edgecolor", None),
                ls=shading_style.get("linestyle", "-"),
                lw=lw,
                zorder=0,
            )

            ax.add_patch(rect)
        # Hide all axis decoration: the shading is a pure background layer.
        ax.set_frame_on(False)
        ax.tick_params(
            axis="both",
            which="both",
            left=False,
            right=False,
            bottom=False,
            top=False,
            labelbottom=False,
            labelleft=False,
        )
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    def style_categories(
        self,
        categories,
        *,
        bar_facecolor=None,
        bar_hatch=None,
        bar_edgecolor=None,
        bar_linewidth=None,
        bar_linestyle=None,
        shading_facecolor=None,
        shading_edgecolor=None,
        shading_linewidth=None,
        shading_linestyle=None,
    ):
        """Updates the style of the categories.

        Select a category by name, and style either its total bar or its shading.

        .. versionadded:: 0.9

        Parameters
        ----------
        categories : str or list[str]
            Category names where the changed style applies.
        bar_facecolor : str or RGBA matplotlib color tuple, optional.
            Override the default facecolor in the totals plot.
        bar_hatch : str, optional
            Set a hatch for the totals plot.
        bar_edgecolor : str or matplotlib color, optional
            Set the edgecolor for total bars.
        bar_linewidth : int, optional
            Line width in points for total bar edges.
        bar_linestyle : str, optional
            Line style for edges.
        shading_facecolor : str or RGBA matplotlib color tuple, optional.
            Override the default alternating shading for specified categories.
        shading_edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        shading_linewidth : int, optional
            Line width in points for edges.
        shading_linestyle : str, optional
            Line style for edges.
        """
        if isinstance(categories, str):
            categories = [categories]
        style = {
            "bar_facecolor": bar_facecolor,
            "bar_hatch": bar_hatch,
            "bar_edgecolor": bar_edgecolor,
            "bar_linewidth": bar_linewidth,
            "bar_linestyle": bar_linestyle,
            "shading_facecolor": shading_facecolor,
            "shading_edgecolor": shading_edgecolor,
            "shading_linewidth": shading_linewidth,
            "shading_linestyle": shading_linestyle,
        }
        # Drop unset attributes so they do not clobber earlier styling.
        style = {k: v for k, v in style.items() if v is not None}
        for category_name in categories:
            self.category_styles.setdefault(category_name, {}).update(style)

    def plot(self, fig=None):
        """Draw all parts of the plot onto fig or a new figure

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Defaults to a new figure.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Keys are 'matrix', 'intersections', 'totals', 'shading'
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        specs = self.make_grid(fig)
        shading_ax = fig.add_subplot(specs["shading"])
        self.plot_shading(shading_ax)
        # Share the category axis with the shading so rows line up.
        matrix_ax = self._reorient(fig.add_subplot)(specs["matrix"], sharey=shading_ax)
        self.plot_matrix(matrix_ax)
        if specs["totals"] is None:
            totals_ax = None
        else:
            totals_ax = self._reorient(fig.add_subplot)(
                specs["totals"], sharey=matrix_ax
            )
            self.plot_totals(totals_ax)
        out = {"matrix": matrix_ax, "shading": shading_ax, "totals": totals_ax}

        for plot in self._subset_plots:
            # Subset plots share the intersection axis with the matrix.
            ax = self._reorient(fig.add_subplot)(specs[plot["id"]], sharex=matrix_ax)
            if plot["type"] == "default":
                self.plot_intersections(ax)
            elif plot["type"] in self.PLOT_TYPES:
                # Remaining keys of the spec dict are plotter kwargs.
                kw = plot.copy()
                del kw["type"]
                del kw["elements"]
                del kw["id"]
                self.PLOT_TYPES[plot["type"]](self, ax, **kw)
            else:
                raise ValueError("Unknown subset plot type: %r" % plot["type"])
            out[plot["id"]] = ax

        self._reorient(fig).align_ylabels(
            [out[plot["id"]] for plot in self._subset_plots]
        )
        return out

    # Dispatch table mapping a subset-plot spec "type" to its plotter method.
    PLOT_TYPES = {
        "catplot": _plot_catplot,
        "stacked_bars": _plot_stacked_bars,
    }

    def _repr_html_(self):
        # Jupyter rich display: render the plot as inline HTML.
        fig = plt.figure(figsize=self._default_figsize)
        self.plot(fig=fig)
        return fig._repr_html_()


def plot(data, fig=None, **kwargs):
    """Make an UpSet plot of data on fig

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Keys are 'matrix', 'intersections', 'totals', 'shading'
    """
    return UpSet(data, **kwargs).plot(fig)
-------------------------------------------------------------------------------- /upsetplot/reformat.py: --------------------------------------------------------------------------------
import typing

import numpy as np
import pandas as pd


def _aggregate_data(df, subset_size, sum_over):
    """
    Returns
    -------
    df : DataFrame
        full data frame
    aggregated : Series
        aggregates
    """
    _SUBSET_SIZE_VALUES = ["auto", "count", "sum"]
    if subset_size not in _SUBSET_SIZE_VALUES:
        raise ValueError(
            f"subset_size should be one of {_SUBSET_SIZE_VALUES}."
20 | f" Got {repr(subset_size)}" 21 | ) 22 | if df.ndim == 1: 23 | # Series 24 | input_name = df.name 25 | df = pd.DataFrame({"_value": df}) 26 | 27 | if subset_size == "auto" and not df.index.is_unique: 28 | raise ValueError( 29 | 'subset_size="auto" cannot be used for a ' 30 | "Series with non-unique groups." 31 | ) 32 | if sum_over is not None: 33 | raise ValueError("sum_over is not applicable when the input is a " "Series") 34 | sum_over = False if subset_size == "count" else "_value" 35 | else: 36 | # DataFrame 37 | if sum_over is False: 38 | raise ValueError("Unsupported value for sum_over: False") 39 | elif subset_size == "auto" and sum_over is None: 40 | sum_over = False 41 | elif subset_size == "count": 42 | if sum_over is not None: 43 | raise ValueError( 44 | "sum_over cannot be set if subset_size=%r" % subset_size 45 | ) 46 | sum_over = False 47 | elif subset_size == "sum" and sum_over is None: 48 | raise ValueError( 49 | "sum_over should be a field name if " 50 | 'subset_size="sum" and a DataFrame is ' 51 | "provided." 
52 | ) 53 | 54 | gb = df.groupby(level=list(range(df.index.nlevels)), sort=False) 55 | if sum_over is False: 56 | aggregated = gb.size() 57 | aggregated.name = "size" 58 | elif hasattr(sum_over, "lower"): 59 | aggregated = gb[sum_over].sum() 60 | else: 61 | raise ValueError("Unsupported value for sum_over: %r" % sum_over) 62 | 63 | if aggregated.name == "_value": 64 | aggregated.name = input_name 65 | 66 | return df, aggregated 67 | 68 | 69 | def _check_index(df): 70 | # check all indices are boolean 71 | if not all({True, False} >= set(level) for level in df.index.levels): 72 | raise ValueError( 73 | "The DataFrame has values in its index that are not " "boolean" 74 | ) 75 | df = df.copy(deep=False) 76 | # XXX: this may break if input is not MultiIndex 77 | kw = { 78 | "levels": [x.astype(bool) for x in df.index.levels], 79 | "names": df.index.names, 80 | } 81 | if hasattr(df.index, "codes"): 82 | # compat for pandas <= 0.20 83 | kw["codes"] = df.index.codes 84 | else: 85 | kw["labels"] = df.index.labels 86 | df.index = pd.MultiIndex(**kw) 87 | return df 88 | 89 | 90 | def _scalar_to_list(val): 91 | if not isinstance(val, (typing.Sequence, set)) or isinstance(val, str): 92 | val = [val] 93 | return val 94 | 95 | 96 | def _check_percent(value, agg): 97 | if not isinstance(value, str): 98 | return value 99 | try: 100 | if value.endswith("%") and 0 <= float(value[:-1]) <= 100: 101 | return float(value[:-1]) / 100 * agg.sum() 102 | except ValueError: 103 | pass 104 | raise ValueError( 105 | f"String value must be formatted as percentage between 0 and 100. 
Got {value}" 106 | ) 107 | 108 | 109 | def _get_subset_mask( 110 | agg, 111 | min_subset_size, 112 | max_subset_size, 113 | max_subset_rank, 114 | min_degree, 115 | max_degree, 116 | present, 117 | absent, 118 | ): 119 | """Get a mask over subsets based on size, degree or category presence""" 120 | min_subset_size = _check_percent(min_subset_size, agg) 121 | max_subset_size = _check_percent(max_subset_size, agg) 122 | subset_mask = True 123 | if min_subset_size is not None: 124 | subset_mask = np.logical_and(subset_mask, agg >= min_subset_size) 125 | if max_subset_size is not None: 126 | subset_mask = np.logical_and(subset_mask, agg <= max_subset_size) 127 | if max_subset_rank is not None: 128 | subset_mask = np.logical_and( 129 | subset_mask, agg.rank(method="min", ascending=False) <= max_subset_rank 130 | ) 131 | if (min_degree is not None and min_degree >= 0) or max_degree is not None: 132 | degree = agg.index.to_frame().sum(axis=1) 133 | if min_degree is not None: 134 | subset_mask = np.logical_and(subset_mask, degree >= min_degree) 135 | if max_degree is not None: 136 | subset_mask = np.logical_and(subset_mask, degree <= max_degree) 137 | if present is not None: 138 | for col in _scalar_to_list(present): 139 | subset_mask = np.logical_and( 140 | subset_mask, agg.index.get_level_values(col).values 141 | ) 142 | if absent is not None: 143 | for col in _scalar_to_list(absent): 144 | exclude_mask = np.logical_not(agg.index.get_level_values(col).values) 145 | subset_mask = np.logical_and(subset_mask, exclude_mask) 146 | return subset_mask 147 | 148 | 149 | def _filter_subsets( 150 | df, 151 | agg, 152 | min_subset_size, 153 | max_subset_size, 154 | max_subset_rank, 155 | min_degree, 156 | max_degree, 157 | present, 158 | absent, 159 | ): 160 | subset_mask = _get_subset_mask( 161 | agg, 162 | min_subset_size=min_subset_size, 163 | max_subset_size=max_subset_size, 164 | max_subset_rank=max_subset_rank, 165 | min_degree=min_degree, 166 | max_degree=max_degree, 167 | 
present=present, 168 | absent=absent, 169 | ) 170 | 171 | if subset_mask is True: 172 | return df, agg 173 | 174 | agg = agg[subset_mask] 175 | df = df[df.index.isin(agg.index)] 176 | return df, agg 177 | 178 | 179 | class QueryResult: 180 | """Container for reformatted data and aggregates 181 | 182 | Attributes 183 | ---------- 184 | data : DataFrame 185 | Selected samples. The index is a MultiIndex with one boolean level for 186 | each category. 187 | subsets : dict[frozenset, DataFrame] 188 | Dataframes for each intersection of categories. 189 | subset_sizes : Series 190 | Total size of each selected subset as a series. The index is as 191 | for `data`. 192 | category_totals : Series 193 | Total size of each category, regardless of selection. 194 | total : number 195 | Total number of samples, or sum of sum_over value. 196 | """ 197 | 198 | def __init__(self, data, subset_sizes, category_totals, total): 199 | self.data = data 200 | self.subset_sizes = subset_sizes 201 | self.category_totals = category_totals 202 | self.total = total 203 | 204 | def __repr__(self): 205 | return ( 206 | "QueryResult(data={data}, subset_sizes={subset_sizes}, " 207 | "category_totals={category_totals}, total={total}".format(**vars(self)) 208 | ) 209 | 210 | @property 211 | def subsets(self): 212 | categories = np.asarray(self.data.index.names) 213 | return { 214 | frozenset(categories.take(mask)): subset_data 215 | for mask, subset_data in self.data.groupby( 216 | level=list(range(len(categories))), sort=False 217 | ) 218 | } 219 | 220 | 221 | def query( 222 | data, 223 | present=None, 224 | absent=None, 225 | min_subset_size=None, 226 | max_subset_size=None, 227 | max_subset_rank=None, 228 | min_degree=None, 229 | max_degree=None, 230 | sort_by="degree", 231 | sort_categories_by="cardinality", 232 | subset_size="auto", 233 | sum_over=None, 234 | include_empty_subsets=False, 235 | ): 236 | """Transform and filter a categorised dataset 237 | 238 | Retrieve the set of items and totals 
corresponding to subsets of interest. 239 | 240 | Parameters 241 | ---------- 242 | data : pandas.Series or pandas.DataFrame 243 | Elements associated with categories (a DataFrame), or the size of each 244 | subset of categories (a Series). 245 | Should have MultiIndex where each level is binary, 246 | corresponding to category membership. 247 | If a DataFrame, `sum_over` must be a string or False. 248 | present : str or list of str, optional 249 | Category or categories that must be present in subsets for styling. 250 | absent : str or list of str, optional 251 | Category or categories that must not be present in subsets for 252 | styling. 253 | min_subset_size : int or "number%", optional 254 | Minimum size of a subset to be reported. All subsets with 255 | a size smaller than this threshold will be omitted from 256 | category_totals and data. This may be specified as a percentage 257 | using a string, like "50%". 258 | Size may be a sum of values, see `subset_size`. 259 | 260 | .. versionchanged:: 0.9 261 | Support percentages 262 | max_subset_size : int or "number%", optional 263 | Maximum size of a subset to be reported. 264 | 265 | .. versionchanged:: 0.9 266 | Support percentages 267 | max_subset_rank : int, optional 268 | Limit to the top N ranked subsets in descending order of size. 269 | All tied subsets are included. 270 | 271 | .. versionadded:: 0.9 272 | min_degree : int, optional 273 | Minimum degree of a subset to be reported. 274 | max_degree : int, optional 275 | Maximum degree of a subset to be reported. 276 | sort_by : {'cardinality', 'degree', '-cardinality', '-degree', 277 | 'input', '-input'} 278 | If 'cardinality', subset are listed from largest to smallest. 279 | If 'degree', they are listed in order of the number of categories 280 | intersected. If 'input', the order they appear in the data input is 281 | used. 282 | Prefix with '-' to reverse the ordering. 283 | 284 | Note this affects ``subset_sizes`` but not ``data``. 
285 | sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'} 286 | Whether to sort the categories by total cardinality, or leave them 287 | in the input data's provided order (order of index levels). 288 | Prefix with '-' to reverse the ordering. 289 | subset_size : {'auto', 'count', 'sum'} 290 | Configures how to calculate the size of a subset. Choices are: 291 | 292 | 'auto' (default) 293 | If `data` is a DataFrame, count the number of rows in each group, 294 | unless `sum_over` is specified. 295 | If `data` is a Series with at most one row for each group, use 296 | the value of the Series. If `data` is a Series with more than one 297 | row per group, raise a ValueError. 298 | 'count' 299 | Count the number of rows in each group. 300 | 'sum' 301 | Sum the value of the `data` Series, or the DataFrame field 302 | specified by `sum_over`. 303 | sum_over : str or None 304 | If `subset_size='sum'` or `'auto'`, then the intersection size is the 305 | sum of the specified field in the `data` DataFrame. If a Series, only 306 | None is supported and its value is summed. 307 | include_empty_subsets : bool (default=False) 308 | If True, all possible category combinations will be returned in 309 | subset_sizes, even when some are not present in data. 310 | 311 | Returns 312 | ------- 313 | QueryResult 314 | Including filtered ``data``, filtered and sorted ``subset_sizes`` and 315 | overall ``category_totals`` and ``total``. 316 | 317 | Examples 318 | -------- 319 | >>> from upsetplot import query, generate_samples 320 | >>> data = generate_samples(n_samples=20) 321 | >>> result = query(data, present="cat1", max_subset_size=4) 322 | >>> result.category_totals 323 | cat1 14 324 | cat2 4 325 | cat0 0 326 | dtype: int64 327 | >>> result.subset_sizes 328 | cat1 cat2 cat0 329 | True True False 3 330 | Name: size, dtype: int64 331 | >>> result.data 332 | index value 333 | cat1 cat2 cat0 334 | True True False 0 2.04... 335 | False 2 2.05... 336 | False 10 2.55... 
337 | >>> 338 | >>> # Sorting: 339 | >>> query(data, min_degree=1, sort_by="degree").subset_sizes 340 | cat1 cat2 cat0 341 | True False False 11 342 | False True False 1 343 | True True False 3 344 | Name: size, dtype: int64 345 | >>> query(data, min_degree=1, sort_by="cardinality").subset_sizes 346 | cat1 cat2 cat0 347 | True False False 11 348 | True False 3 349 | False True False 1 350 | Name: size, dtype: int64 351 | >>> 352 | >>> # Getting each subset's data 353 | >>> result = query(data) 354 | >>> result.subsets[frozenset({"cat1", "cat2"})] 355 | index value 356 | cat1 cat2 cat0 357 | False True False 3 1.333795 358 | >>> result.subsets[frozenset({"cat1"})] 359 | index value 360 | cat1 cat2 cat0 361 | False False False 5 0.918174 362 | False 8 1.948521 363 | False 9 1.086599 364 | False 13 1.105696 365 | False 19 1.339895 366 | """ 367 | 368 | data, agg = _aggregate_data(data, subset_size, sum_over) 369 | data = _check_index(data) 370 | grand_total = agg.sum() 371 | category_totals = [ 372 | agg[agg.index.get_level_values(name).values.astype(bool)].sum() 373 | for name in agg.index.names 374 | ] 375 | category_totals = pd.Series(category_totals, index=agg.index.names) 376 | 377 | if include_empty_subsets: 378 | nlevels = len(agg.index.levels) 379 | if nlevels > 10: 380 | raise ValueError( 381 | "include_empty_subsets is supported for at most 10 categories" 382 | ) 383 | new_agg = pd.Series( 384 | 0, 385 | index=pd.MultiIndex.from_product( 386 | [[False, True]] * nlevels, names=agg.index.names 387 | ), 388 | dtype=agg.dtype, 389 | name=agg.name, 390 | ) 391 | new_agg.update(agg) 392 | agg = new_agg 393 | 394 | data, agg = _filter_subsets( 395 | data, 396 | agg, 397 | min_subset_size=min_subset_size, 398 | max_subset_size=max_subset_size, 399 | max_subset_rank=max_subset_rank, 400 | min_degree=min_degree, 401 | max_degree=max_degree, 402 | present=present, 403 | absent=absent, 404 | ) 405 | 406 | # sort: 407 | if sort_categories_by in ("cardinality", 
"-cardinality"): 408 | category_totals.sort_values( 409 | ascending=sort_categories_by[:1] == "-", inplace=True 410 | ) 411 | elif sort_categories_by == "-input": 412 | category_totals = category_totals[::-1] 413 | elif sort_categories_by in (None, "input"): 414 | pass 415 | else: 416 | raise ValueError("Unknown sort_categories_by: %r" % sort_categories_by) 417 | data = data.reorder_levels(category_totals.index.values) 418 | agg = agg.reorder_levels(category_totals.index.values) 419 | 420 | if sort_by in ("cardinality", "-cardinality"): 421 | agg = agg.sort_values(ascending=sort_by[:1] == "-") 422 | elif sort_by in ("degree", "-degree"): 423 | index_tuples = sorted( 424 | agg.index, 425 | key=lambda x: (sum(x),) + tuple(reversed(x)), 426 | reverse=sort_by[:1] == "-", 427 | ) 428 | agg = agg.reindex( 429 | pd.MultiIndex.from_tuples(index_tuples, names=agg.index.names) 430 | ) 431 | elif sort_by == "-input": 432 | agg = agg[::-1] 433 | elif sort_by in (None, "input"): 434 | pass 435 | else: 436 | raise ValueError("Unknown sort_by: %r" % sort_by) 437 | 438 | return QueryResult( 439 | data=data, subset_sizes=agg, category_totals=category_totals, total=grand_total 440 | ) 441 | -------------------------------------------------------------------------------- /upsetplot/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jnothman/UpSetPlot/dcfadae04edfffe321fc0b002bf1865019f75f9b/upsetplot/tests/__init__.py -------------------------------------------------------------------------------- /upsetplot/tests/test_data.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal 7 | 8 | from upsetplot import from_contents, from_indicators, from_memberships, generate_data 9 | 10 
# --- upsetplot/tests/test_data.py (continued after imports) ---


@pytest.mark.parametrize("typ", [set, list, tuple, iter])
def test_from_memberships_no_data(typ):
    """Error cases and basic output of from_memberships without a data arg."""
    with pytest.raises(ValueError, match="at least one category"):
        from_memberships([])
    with pytest.raises(ValueError, match="at least one category"):
        from_memberships([[], []])
    with pytest.raises(ValueError, match="strings"):
        from_memberships([[1]])
    with pytest.raises(ValueError, match="strings"):
        from_memberships([[1, "str"]])
    with pytest.raises(TypeError):
        from_memberships([1])

    out = from_memberships(
        [
            typ([]),
            typ(["hello"]),
            typ(["world"]),
            typ(["hello", "world"]),
        ]
    )
    exp = pd.DataFrame(
        [[False, False, 1], [True, False, 1], [False, True, 1], [True, True, 1]],
        columns=["hello", "world", "ones"],
    ).set_index(["hello", "world"])["ones"]
    assert isinstance(exp.index, pd.MultiIndex)
    assert_series_equal(exp, out)

    # test sorting by name
    out = from_memberships([typ(["hello"]), typ(["world"])])
    exp = pd.DataFrame(
        [[True, False, 1], [False, True, 1]], columns=["hello", "world", "ones"]
    ).set_index(["hello", "world"])["ones"]
    assert_series_equal(exp, out)
    out = from_memberships([typ(["world"]), typ(["hello"])])
    exp = pd.DataFrame(
        [[False, True, 1], [True, False, 1]], columns=["hello", "world", "ones"]
    ).set_index(["hello", "world"])["ones"]
    assert_series_equal(exp, out)


@pytest.mark.parametrize(
    ("data", "ndim"),
    [
        ([1, 2, 3, 4], 1),
        (np.array([1, 2, 3, 4]), 1),
        (pd.Series([1, 2, 3, 4], name="foo"), 1),
        ([[1, "a"], [2, "b"], [3, "c"], [4, "d"]], 2),
        (
            pd.DataFrame(
                [[1, "a"], [2, "b"], [3, "c"], [4, "d"]],
                columns=["foo", "bar"],
                index=["q", "r", "s", "t"],
            ),
            2,
        ),
    ],
)
def test_from_memberships_with_data(data, ndim):
    """from_memberships with data: copies shallowly, preserves values/index."""
    memberships = [[], ["hello"], ["world"], ["hello", "world"]]
    out = from_memberships(memberships, data=data)
    assert out is not data  # make sure frame is copied
    if hasattr(data, "loc") and np.asarray(data).dtype.kind in "ifb":
        # but not deepcopied when possible
        assert out.values.base is np.asarray(data).base
    if ndim == 1:
        assert isinstance(out, pd.Series)
    else:
        assert isinstance(out, pd.DataFrame)
    assert_frame_equal(
        pd.DataFrame(out).reset_index(drop=True),
        pd.DataFrame(data).reset_index(drop=True),
    )
    no_data = from_memberships(memberships=memberships)
    assert_index_equal(out.index, no_data.index)

    with pytest.raises(ValueError, match="length"):
        from_memberships(memberships[:-1], data=data)


@pytest.mark.parametrize(
    "data", [None, {"attr1": [3, 4, 5, 6, 7, 8], "attr2": list("qrstuv")}]
)
@pytest.mark.parametrize("typ", [set, list, tuple, iter])
@pytest.mark.parametrize("id_column", ["id", "blah"])
def test_from_contents_vs_memberships(data, typ, id_column):
    """from_contents and from_memberships must agree on the same input."""
    contents = OrderedDict(
        [
            ("cat1", typ(["aa", "bb", "cc"])),
            ("cat2", typ(["cc", "dd"])),
            ("cat3", typ(["ee"])),
        ]
    )
    # Note that ff is not present in contents
    data_df = pd.DataFrame(data, index=["aa", "bb", "cc", "dd", "ee", "ff"])
    baseline = from_contents(contents, data=data_df, id_column=id_column)
    # compare from_contents to from_memberships
    expected = from_memberships(
        memberships=[{"cat1"}, {"cat1"}, {"cat1", "cat2"}, {"cat2"}, {"cat3"}, []],
        data=data_df,
    )
    assert_series_equal(
        baseline[id_column].reset_index(drop=True),
        pd.Series(["aa", "bb", "cc", "dd", "ee", "ff"], name=id_column),
    )
    baseline_without_id = baseline.drop([id_column], axis=1)
    assert_frame_equal(
        baseline_without_id,
        expected,
        check_column_type=baseline_without_id.shape[1] > 0,
    )


def test_from_contents(typ=set, id_column="id"):
    """from_contents: data=None, unordered dicts and empty categories."""
    contents = OrderedDict(
        [("cat1", {"aa", "bb", "cc"}), ("cat2", {"cc", "dd"}), ("cat3", {"ee"})]
    )
    empty_data = pd.DataFrame(index=["aa", "bb", "cc", "dd", "ee"])
    baseline = from_contents(contents, data=empty_data, id_column=id_column)
    # data=None
    out = from_contents(contents, id_column=id_column)
    assert_frame_equal(out.sort_values(id_column), baseline)

    # unordered contents dict
    out = from_contents(
        {"cat3": contents["cat3"], "cat2": contents["cat2"], "cat1": contents["cat1"]},
        data=empty_data,
        id_column=id_column,
    )
    assert_frame_equal(out.reorder_levels(["cat1", "cat2", "cat3"]), baseline)

    # empty category
    out = from_contents(
        {
            "cat1": contents["cat1"],
            "cat2": contents["cat2"],
            "cat3": contents["cat3"],
            "cat4": [],
        },
        data=empty_data,
        id_column=id_column,
    )
    assert not out.index.to_frame()["cat4"].any()  # cat4 should be all-false
    assert len(out.index.names) == 4
    out.index = out.index.to_frame().set_index(["cat1", "cat2", "cat3"]).index
    assert_frame_equal(out, baseline)


@pytest.mark.parametrize("id_column", ["id", "blah"])
def test_from_contents_invalid(id_column):
    """Error cases for from_contents: overlaps, duplicates, reserved names."""
    contents = OrderedDict(
        [("cat1", {"aa", "bb", "cc"}), ("cat2", {"cc", "dd"}), ("cat3", {"ee"})]
    )
    with pytest.raises(ValueError, match="columns overlap"):
        from_contents(
            contents, data=pd.DataFrame({"cat1": [1, 2, 3, 4, 5]}), id_column=id_column
        )
    with pytest.raises(ValueError, match="duplicate ids"):
        from_contents({"cat1": ["aa", "bb"], "cat2": ["dd", "dd"]}, id_column=id_column)
    # category named id
    with pytest.raises(ValueError, match="cannot be named"):
        from_contents(
            {
                id_column: {"aa", "bb", "cc"},
                "cat2": {"cc", "dd"},
            },
            id_column=id_column,
        )
    # data containing a column with the reserved id name
    with pytest.raises(ValueError, match="cannot contain"):
        from_contents(
            contents,
            data=pd.DataFrame(
                {id_column: [1, 2, 3, 4, 5]}, index=["aa", "bb", "cc", "dd", "ee"]
            ),
            id_column=id_column,
        )
    with pytest.raises(ValueError, match="identifiers in contents"):
        from_contents({"cat1": ["aa"]}, data=pd.DataFrame([[1]]), id_column=id_column)


@pytest.mark.parametrize(
    ("indicators", "data", "exc_type", "match"),
    [
        (["a", "b"], None, ValueError, "data must be provided"),
        (lambda df: [True, False, True], None, ValueError, "data must be provided"),
        (["a", "unknown_col"], {"a": [1, 2, 3]}, KeyError, "unknown_col"),
        (("a",), {"a": [1, 2, 3]}, ValueError, "tuple"),
        ({"cat1": [0, 1, 1]}, {"a": [1, 2, 3]}, ValueError, "must all be boolean"),
        (
            pd.DataFrame({"cat1": [True, False, True]}, index=["a", "b", "c"]),
            {"A": [1, 2, 3]},
            ValueError,
            "all its values must be present",
        ),
    ],
)
def test_from_indicators_invalid(indicators, data, exc_type, match):
    """Each invalid indicators/data combination raises the documented error."""
    with pytest.raises(exc_type, match=match):
        from_indicators(indicators=indicators, data=data)


@pytest.mark.parametrize(
    "indicators",
    [
        pd.DataFrame({"cat1": [False, True, False]}),
        pd.DataFrame({"cat1": [False, True, False]}, dtype="O"),
        {"cat1": [False, True, False]},
        lambda data: {"cat1": {pd.DataFrame(data).index.values[1]: True}},
    ],
)
@pytest.mark.parametrize(
    "data",
    [
        pd.DataFrame({"val1": [3, 4, 5]}),
        pd.DataFrame({"val1": [3, 4, 5]}, index=["a", "b", "c"]),
        {"val1": [3, 4, 5]},
    ],
)
def test_from_indicators_equivalence(indicators, data):
    """All indicator forms are equivalent to the from_memberships result."""
    assert_frame_equal(
        from_indicators(indicators, data), from_memberships([[], ["cat1"], []], data)
    )


def test_generate_data_warning():
    """generate_data is deprecated and must warn."""
    with pytest.warns(DeprecationWarning):
        generate_data()


# --- upsetplot/tests/test_examples.py ---

import glob
import os
import subprocess
import sys

import pytest

# glob pattern locating the gallery example scripts relative to this file
exa_glob = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "..", "..", "examples", "*.py"
)


@pytest.mark.parametrize("path", glob.glob(exa_glob))
def test_example(path):
    """Smoke-test each example script in a subprocess (needs sklearn/seaborn)."""
    pytest.importorskip("sklearn")
    pytest.importorskip("seaborn")
    env = os.environ.copy()
    env["PYTHONPATH"] = os.getcwd() + ":" + env.get("PYTHONPATH", "")
    subprocess.check_output([sys.executable, path], env=env)


# --- upsetplot/tests/test_reformat.py (header) ---

import pandas as pd
import pytest
from pandas.testing import assert_frame_equal, assert_series_equal

from upsetplot import generate_counts, generate_samples, query

# `query` is mostly tested through plotting tests, especially tests of
# `_process_data` which cover sort_by, sort_categories_by, subset_size
# and sum_over.
@pytest.mark.parametrize(
    "data",
    [
        generate_counts(),
        generate_samples(),
    ],
)
@pytest.mark.parametrize(
    "param_set",
    [
        [{"present": "cat1"}, {"absent": "cat1"}],
        [{"max_degree": 0}, {"min_degree": 1, "max_degree": 2}, {"min_degree": 3}],
        [{"max_subset_size": 30}, {"min_subset_size": 31}],
        [
            {"present": "cat1", "max_subset_size": 30},
            {"absent": "cat1", "max_subset_size": 30},
            {"present": "cat1", "min_subset_size": 31},
            {"absent": "cat1", "min_subset_size": 31},
        ],
    ],
)
def test_mece_queries(data, param_set):
    """Each param_set is a mutually-exclusive, collectively-exhaustive
    partition of the input, so recombining the filtered query results
    must reproduce the unfiltered query exactly."""
    baseline = query(data)
    partition_results = [query(data, **params) for params in param_set]

    # category_totals is computed before filtering, hence identical everywhere
    for partial in partition_results:
        assert_series_equal(baseline.category_totals, partial.category_totals)

    recombined_data = pd.concat(
        [partial.data for partial in partition_results]
    ).sort_index()
    assert_frame_equal(baseline.data.sort_index(), recombined_data)

    recombined_sizes = pd.concat(
        [partial.subset_sizes for partial in partition_results]
    ).sort_index()
    assert_series_equal(baseline.subset_sizes.sort_index(), recombined_sizes)
generate_counts, generate_samples, plot 16 | from upsetplot.plotting import _process_data 17 | 18 | # TODO: warnings should raise errors 19 | 20 | 21 | def is_ascending(seq): 22 | # return np.all(np.diff(seq) >= 0) 23 | return sorted(seq) == list(seq) 24 | 25 | 26 | def get_all_texts(mpl_artist): 27 | out = [text.get_text() for text in mpl_artist.findobj(Text)] 28 | return [text for text in out if text] 29 | 30 | 31 | @pytest.mark.parametrize( 32 | "x", 33 | [ 34 | generate_counts(), 35 | generate_counts().iloc[1:-2], 36 | ], 37 | ) 38 | @pytest.mark.parametrize( 39 | "sort_by", 40 | ["cardinality", "degree", "-cardinality", "-degree", None, "input", "-input"], 41 | ) 42 | @pytest.mark.parametrize( 43 | "sort_categories_by", [None, "input", "-input", "cardinality", "-cardinality"] 44 | ) 45 | def test_process_data_series(x, sort_by, sort_categories_by): 46 | assert x.name == "value" 47 | for subset_size in ["auto", "sum", "count"]: 48 | for sum_over in ["abc", False]: 49 | with pytest.raises(ValueError, match="sum_over is not applicable"): 50 | _process_data( 51 | x, 52 | sort_by=sort_by, 53 | sort_categories_by=sort_categories_by, 54 | subset_size=subset_size, 55 | sum_over=sum_over, 56 | ) 57 | 58 | # shuffle input to test sorting 59 | x = x.sample(frac=1.0, replace=False, random_state=0) 60 | 61 | total, df, intersections, totals = _process_data( 62 | x, 63 | subset_size="auto", 64 | sort_by=sort_by, 65 | sort_categories_by=sort_categories_by, 66 | sum_over=None, 67 | ) 68 | 69 | assert total == x.sum() 70 | 71 | assert intersections.name == "value" 72 | x_reordered_levels = x.reorder_levels(intersections.index.names) 73 | x_reordered = x_reordered_levels.reindex(index=intersections.index) 74 | assert len(x) == len(x_reordered) 75 | assert x_reordered.index.is_unique 76 | assert_series_equal(x_reordered, intersections, check_dtype=False) 77 | 78 | if sort_by == "cardinality": 79 | assert is_ascending(intersections.values[::-1]) 80 | elif sort_by == 
"-cardinality": 81 | assert is_ascending(intersections.values) 82 | elif sort_by == "degree": 83 | # check degree order 84 | assert is_ascending(intersections.index.to_frame().sum(axis=1)) 85 | # TODO: within a same-degree group, the tuple of active names should 86 | # be in sort-order 87 | elif sort_by == "-degree": 88 | # check degree order 89 | assert is_ascending(intersections.index.to_frame().sum(axis=1)[::-1]) 90 | else: 91 | find_first_in_orig = x_reordered_levels.index.tolist().index 92 | orig_order = [find_first_in_orig(key) for key in intersections.index.tolist()] 93 | assert orig_order == sorted( 94 | orig_order, reverse=sort_by is not None and sort_by.startswith("-") 95 | ) 96 | 97 | if sort_categories_by == "cardinality": 98 | assert is_ascending(totals.values[::-1]) 99 | elif sort_categories_by == "-cardinality": 100 | assert is_ascending(totals.values) 101 | 102 | assert np.all(totals.index.values == intersections.index.names) 103 | 104 | assert np.all(df.index.names == intersections.index.names) 105 | assert set(df.columns) == {"_value", "_bin"} 106 | assert_index_equal(df["_value"].reorder_levels(x.index.names).index, x.index) 107 | assert_array_equal(df["_value"], x) 108 | assert_index_equal(intersections.iloc[df["_bin"]].index, df.index) 109 | assert len(df) == len(x) 110 | 111 | 112 | @pytest.mark.parametrize( 113 | "x", 114 | [ 115 | generate_samples()["value"], 116 | generate_counts(), 117 | ], 118 | ) 119 | def test_subset_size_series(x): 120 | kw = { 121 | "sort_by": "cardinality", 122 | "sort_categories_by": "cardinality", 123 | "sum_over": None, 124 | } 125 | total, df_sum, intersections_sum, totals_sum = _process_data( 126 | x, subset_size="sum", **kw 127 | ) 128 | assert total == intersections_sum.sum() 129 | 130 | if x.index.is_unique: 131 | total, df, intersections, totals = _process_data(x, subset_size="auto", **kw) 132 | assert total == intersections.sum() 133 | assert_frame_equal(df, df_sum) 134 | assert_series_equal(intersections, 
intersections_sum) 135 | assert_series_equal(totals, totals_sum) 136 | else: 137 | with pytest.raises(ValueError): 138 | _process_data(x, subset_size="auto", **kw) 139 | 140 | total, df_count, intersections_count, totals_count = _process_data( 141 | x, subset_size="count", **kw 142 | ) 143 | assert total == intersections_count.sum() 144 | total, df, intersections, totals = _process_data( 145 | x.groupby(level=list(range(len(x.index.levels)))).count(), 146 | subset_size="sum", 147 | **kw, 148 | ) 149 | assert total == intersections.sum() 150 | assert_series_equal(intersections, intersections_count, check_names=False) 151 | assert_series_equal(totals, totals_count) 152 | 153 | 154 | @pytest.mark.parametrize( 155 | "x", 156 | [ 157 | generate_samples()["value"], 158 | ], 159 | ) 160 | @pytest.mark.parametrize("sort_by", ["cardinality", "degree", None]) 161 | @pytest.mark.parametrize("sort_categories_by", [None, "cardinality"]) 162 | def test_process_data_frame(x, sort_by, sort_categories_by): 163 | # shuffle input to test sorting 164 | x = x.sample(frac=1.0, replace=False, random_state=0) 165 | 166 | X = pd.DataFrame({"a": x}) 167 | 168 | with pytest.warns(None): 169 | total, df, intersections, totals = _process_data( 170 | X, 171 | sort_by=sort_by, 172 | sort_categories_by=sort_categories_by, 173 | sum_over="a", 174 | subset_size="auto", 175 | ) 176 | assert df is not X 177 | assert total == pytest.approx(intersections.sum()) 178 | 179 | # check equivalence to Series 180 | total1, df1, intersections1, totals1 = _process_data( 181 | x, 182 | sort_by=sort_by, 183 | sort_categories_by=sort_categories_by, 184 | subset_size="sum", 185 | sum_over=None, 186 | ) 187 | 188 | assert intersections.name == "a" 189 | assert_frame_equal(df, df1.rename(columns={"_value": "a"})) 190 | assert_series_equal(intersections, intersections1, check_names=False) 191 | assert_series_equal(totals, totals1) 192 | 193 | # check effect of extra column 194 | X = pd.DataFrame({"a": x, "b": 
np.arange(len(x))}) 195 | total2, df2, intersections2, totals2 = _process_data( 196 | X, 197 | sort_by=sort_by, 198 | sort_categories_by=sort_categories_by, 199 | sum_over="a", 200 | subset_size="auto", 201 | ) 202 | assert total2 == pytest.approx(intersections2.sum()) 203 | assert_series_equal(intersections, intersections2) 204 | assert_series_equal(totals, totals2) 205 | assert_frame_equal(df, df2.drop("b", axis=1)) 206 | assert_array_equal(df2["b"], X["b"]) # disregard levels, tested above 207 | 208 | # check effect not dependent on order/name 209 | X = pd.DataFrame({"b": np.arange(len(x)), "c": x}) 210 | total3, df3, intersections3, totals3 = _process_data( 211 | X, 212 | sort_by=sort_by, 213 | sort_categories_by=sort_categories_by, 214 | sum_over="c", 215 | subset_size="auto", 216 | ) 217 | assert total3 == pytest.approx(intersections3.sum()) 218 | assert_series_equal(intersections, intersections3, check_names=False) 219 | assert intersections.name == "a" 220 | assert intersections3.name == "c" 221 | assert_series_equal(totals, totals3) 222 | assert_frame_equal(df.rename(columns={"a": "c"}), df3.drop("b", axis=1)) 223 | assert_array_equal(df3["b"], X["b"]) 224 | 225 | # check subset_size='count' 226 | X = pd.DataFrame({"b": np.ones(len(x), dtype="int64"), "c": x}) 227 | 228 | total4, df4, intersections4, totals4 = _process_data( 229 | X, 230 | sort_by=sort_by, 231 | sort_categories_by=sort_categories_by, 232 | sum_over="b", 233 | subset_size="auto", 234 | ) 235 | total5, df5, intersections5, totals5 = _process_data( 236 | X, 237 | sort_by=sort_by, 238 | sort_categories_by=sort_categories_by, 239 | subset_size="count", 240 | sum_over=None, 241 | ) 242 | assert total5 == pytest.approx(intersections5.sum()) 243 | assert_series_equal(intersections4, intersections5, check_names=False) 244 | assert intersections4.name == "b" 245 | assert intersections5.name == "size" 246 | assert_series_equal(totals4, totals5) 247 | assert_frame_equal(df4, df5) 248 | 249 | 250 | 
@pytest.mark.parametrize( 251 | "x", 252 | [ 253 | generate_samples()["value"], 254 | generate_counts(), 255 | ], 256 | ) 257 | def test_subset_size_frame(x): 258 | kw = {"sort_by": "cardinality", "sort_categories_by": "cardinality"} 259 | X = pd.DataFrame({"x": x}) 260 | total_sum, df_sum, intersections_sum, totals_sum = _process_data( 261 | X, subset_size="sum", sum_over="x", **kw 262 | ) 263 | total_count, df_count, intersections_count, totals_count = _process_data( 264 | X, subset_size="count", sum_over=None, **kw 265 | ) 266 | 267 | # error cases: sum_over=False 268 | for subset_size in ["auto", "sum", "count"]: 269 | with pytest.raises(ValueError, match="sum_over"): 270 | _process_data(X, subset_size=subset_size, sum_over=False, **kw) 271 | 272 | with pytest.raises(ValueError, match="sum_over"): 273 | _process_data(X, subset_size=subset_size, sum_over=False, **kw) 274 | 275 | # error cases: sum_over incompatible with subset_size 276 | with pytest.raises(ValueError, match="sum_over should be a field"): 277 | _process_data(X, subset_size="sum", sum_over=None, **kw) 278 | with pytest.raises(ValueError, match="sum_over cannot be set"): 279 | _process_data(X, subset_size="count", sum_over="x", **kw) 280 | 281 | # check subset_size='auto' with sum_over=str => sum 282 | total, df, intersections, totals = _process_data( 283 | X, subset_size="auto", sum_over="x", **kw 284 | ) 285 | assert total == intersections.sum() 286 | assert_frame_equal(df, df_sum) 287 | assert_series_equal(intersections, intersections_sum) 288 | assert_series_equal(totals, totals_sum) 289 | 290 | # check subset_size='auto' with sum_over=None => count 291 | total, df, intersections, totals = _process_data( 292 | X, subset_size="auto", sum_over=None, **kw 293 | ) 294 | assert total == intersections.sum() 295 | assert_frame_equal(df, df_count) 296 | assert_series_equal(intersections, intersections_count) 297 | assert_series_equal(totals, totals_count) 298 | 299 | 300 | 
# Both parametrize axes apply: the aggregated/unaggregated comparison is
# checked under every supported sort configuration.
@pytest.mark.parametrize("sort_by", ["cardinality", "degree"])
@pytest.mark.parametrize("sort_categories_by", [None, "cardinality"])
def test_not_unique(sort_by, sort_categories_by):
    """_process_data should aggregate duplicate index entries.

    An aggregated counts Series and an unaggregated all-ones sample Series
    describing the same data must produce identical intersections and totals.
    """
    kw = {
        "sort_by": sort_by,
        "sort_categories_by": sort_categories_by,
        "subset_size": "sum",
        "sum_over": None,
    }
    Xagg = generate_counts()
    total1, df1, intersections1, totals1 = _process_data(Xagg, **kw)
    Xunagg = generate_samples()["value"]
    # Give every sample weight 1 so summing reproduces the counts above
    Xunagg.loc[:] = 1
    total2, df2, intersections2, totals2 = _process_data(Xunagg, **kw)
    assert_series_equal(intersections1, intersections2, check_dtype=False)
    assert total2 == intersections2.sum()
    assert_series_equal(totals1, totals2, check_dtype=False)
    assert set(df1.columns) == {"_value", "_bin"}
    assert set(df2.columns) == {"_value", "_bin"}
    # The unaggregated frame keeps one row per sample
    assert len(df2) == len(Xunagg)
    assert df2["_bin"].nunique() == len(intersections2)


def test_include_empty_subsets():
    """include_empty_subsets=True should enumerate all 2**n_categories rows."""
    X = generate_counts(n_samples=2, n_categories=3)

    no_empty_upset = UpSet(X, include_empty_subsets=False)
    # With only 2 samples, at most 2 subsets can be non-empty
    assert len(no_empty_upset.intersections) <= 2

    include_empty_upset = UpSet(X, include_empty_subsets=True)
    assert len(include_empty_upset.intersections) == 2**3
    # The non-empty subsets must be unchanged by including empty ones
    common_intersections = include_empty_upset.intersections.loc[
        no_empty_upset.intersections.index
    ]
    assert_series_equal(no_empty_upset.intersections, common_intersections)
    include_empty_upset.plot()  # smoke test


@pytest.mark.parametrize(
    "kw",
    [
        {"sort_by": "blah"},
        {"sort_by": True},
        {"sort_categories_by": "blah"},
        {"sort_categories_by": True},
    ],
)
def test_param_validation(kw):
    """Invalid sort_by / sort_categories_by values should raise ValueError."""
    X = generate_counts(n_samples=100)
    with pytest.raises(ValueError):
        UpSet(X, **kw)


@pytest.mark.parametrize(
    "kw",
    [
        {},
        {"element_size": None},
        {"orientation": "vertical"},
        {"intersection_plot_elements": 0},
        {"facecolor": "red"},
        {"shading_color": "lightgrey", "other_dots_color": "pink"},
        {"totals_plot_elements": 0},
    ],
)
def test_plot_smoke_test(kw):
    """plot() should render and save without error for assorted options."""
    fig = matplotlib.figure.Figure()
    X = generate_counts(n_samples=100)
    axes = plot(X, fig, **kw)
    fig.savefig(io.BytesIO(), format="png")

    # Each matrix column should occupy one unit along the category axis
    attr = (
        "get_xlim"
        if kw.get("orientation", "horizontal") == "horizontal"
        else "get_ylim"
    )
    lim = getattr(axes["matrix"], attr)()
    expected_width = len(X)
    assert expected_width == lim[1] - lim[0]

    # Also check fig is optional
    n_nums = len(plt.get_fignums())
    plot(X, **kw)
    assert len(plt.get_fignums()) - n_nums == 1
    assert plt.gcf().axes


# set1/set2 are 2-tuples of membership flags, one per sample row
@pytest.mark.parametrize("set1", itertools.product([False, True], repeat=2))
@pytest.mark.parametrize("set2", itertools.product([False, True], repeat=2))
def test_two_sets(set1, set2):
    """Plotting should work with two categories in any membership pattern."""
    # we had a bug where processing failed if no items were in some set
    fig = matplotlib.figure.Figure()
    plot(
        pd.DataFrame({"val": [5, 7], "set1": set1, "set2": set2}).set_index(
            ["set1", "set2"]
        )["val"],
        fig,
        subset_size="sum",
    )


def test_vertical():
    """Vertical orientation should produce a narrower (taller) figure."""
    X = generate_counts(n_samples=100)

    fig = matplotlib.figure.Figure()
    UpSet(X, orientation="horizontal").make_grid(fig)
    horz_height = fig.get_figheight()
    horz_width = fig.get_figwidth()
    assert horz_height < horz_width

    fig = matplotlib.figure.Figure()
    UpSet(X, orientation="vertical").make_grid(fig)
    vert_height = fig.get_figheight()
    vert_width = fig.get_figwidth()
    # Aspect ratio (w/h) should shrink when flipping to vertical
    assert horz_width / horz_height > vert_width / vert_height

    # TODO: test axes positions, plot order, bar orientation
    pass
| def test_element_size(): 421 | X = generate_counts(n_samples=100) 422 | figsizes = [] 423 | for element_size in range(10, 50, 5): 424 | fig = matplotlib.figure.Figure() 425 | UpSet(X, element_size=element_size).make_grid(fig) 426 | figsizes.append((fig.get_figwidth(), fig.get_figheight())) 427 | 428 | figwidths, figheights = zip(*figsizes) 429 | # Absolute width increases 430 | assert np.all(np.diff(figwidths) > 0) 431 | aspect = np.divide(figwidths, figheights) 432 | # Font size stays constant, so aspect ratio decreases 433 | assert np.all(np.diff(aspect) <= 1e-8) # allow for near-equality 434 | assert np.any(np.diff(aspect) < 1e-4) # require some significant decrease 435 | # But doesn't decrease by much 436 | assert np.all(aspect[:-1] / aspect[1:] < 1.1) 437 | 438 | fig = matplotlib.figure.Figure() 439 | figsize_before = fig.get_figwidth(), fig.get_figheight() 440 | UpSet(X, element_size=None).make_grid(fig) 441 | figsize_after = fig.get_figwidth(), fig.get_figheight() 442 | assert figsize_before == figsize_after 443 | 444 | # TODO: make sure axes are all within figure 445 | # TODO: make sure text does not overlap axes, even with element_size=None 446 | 447 | 448 | def _walk_artists(el): 449 | children = el.get_children() 450 | yield el, children 451 | for ch in children: 452 | yield from _walk_artists(ch) 453 | 454 | 455 | def _count_descendants(el): 456 | return sum(len(children) for x, children in _walk_artists(el)) 457 | 458 | 459 | @pytest.mark.parametrize("orientation", ["horizontal", "vertical"]) 460 | def test_show_counts(orientation): 461 | fig = matplotlib.figure.Figure() 462 | X = generate_counts(n_samples=10000) 463 | plot(X, fig, orientation=orientation) 464 | n_artists_no_sizes = _count_descendants(fig) 465 | 466 | fig = matplotlib.figure.Figure() 467 | plot(X, fig, orientation=orientation, show_counts=True) 468 | n_artists_yes_sizes = _count_descendants(fig) 469 | assert n_artists_yes_sizes - n_artists_no_sizes > 6 470 | assert "9547" in 
get_all_texts(fig) # set size 471 | assert "283" in get_all_texts(fig) # intersection size 472 | 473 | fig = matplotlib.figure.Figure() 474 | plot(X, fig, orientation=orientation, show_counts="%0.2g") 475 | assert n_artists_yes_sizes == _count_descendants(fig) 476 | assert "9.5e+03" in get_all_texts(fig) 477 | assert "2.8e+02" in get_all_texts(fig) 478 | 479 | fig = matplotlib.figure.Figure() 480 | plot(X, fig, orientation=orientation, show_counts="{:0.2g}") 481 | assert n_artists_yes_sizes == _count_descendants(fig) 482 | assert "9.5e+03" in get_all_texts(fig) 483 | assert "2.8e+02" in get_all_texts(fig) 484 | 485 | fig = matplotlib.figure.Figure() 486 | plot(X, fig, orientation=orientation, show_percentages=True) 487 | assert n_artists_yes_sizes == _count_descendants(fig) 488 | assert "95.5%" in get_all_texts(fig) 489 | assert "2.8%" in get_all_texts(fig) 490 | 491 | fig = matplotlib.figure.Figure() 492 | plot(X, fig, orientation=orientation, show_percentages="!{:0.2f}!") 493 | assert n_artists_yes_sizes == _count_descendants(fig) 494 | assert "!0.95!" in get_all_texts(fig) 495 | assert "!0.03!" 
in get_all_texts(fig) 496 | 497 | fig = matplotlib.figure.Figure() 498 | plot(X, fig, orientation=orientation, show_counts=True, show_percentages=True) 499 | assert n_artists_yes_sizes == _count_descendants(fig) 500 | if orientation == "vertical": 501 | assert "9547\n(95.5%)" in get_all_texts(fig) 502 | assert "283 (2.8%)" in get_all_texts(fig) 503 | else: 504 | assert "9547 (95.5%)" in get_all_texts(fig) 505 | assert "283\n(2.8%)" in get_all_texts(fig) 506 | 507 | fig = matplotlib.figure.Figure() 508 | with pytest.raises(ValueError): 509 | plot(X, fig, orientation=orientation, show_counts="%0.2h") 510 | 511 | 512 | def test_add_catplot(): 513 | pytest.importorskip("seaborn") 514 | X = generate_counts(n_samples=100) 515 | upset = UpSet(X) 516 | # smoke test 517 | upset.add_catplot("violin") 518 | fig = matplotlib.figure.Figure() 519 | upset.plot(fig) 520 | 521 | # can't provide value with Series 522 | with pytest.raises(ValueError): 523 | upset.add_catplot("violin", value="foo") 524 | 525 | # check the above add_catplot did not break the state 526 | upset.plot(fig) 527 | 528 | X = generate_counts(n_samples=100) 529 | X.name = "foo" 530 | X = X.to_frame() 531 | upset = UpSet(X, subset_size="count") 532 | # must provide value with DataFrame 533 | with pytest.raises(ValueError): 534 | upset.add_catplot("violin") 535 | upset.add_catplot("violin", value="foo") 536 | with pytest.raises(ValueError): 537 | # not a known column 538 | upset.add_catplot("violin", value="bar") 539 | upset.plot(fig) 540 | 541 | # invalid plot kind raises error when plotting 542 | upset.add_catplot("foobar", value="foo") 543 | with pytest.raises(AttributeError): 544 | upset.plot(fig) 545 | 546 | 547 | def _get_patch_data(axes, is_vertical): 548 | out = [ 549 | { 550 | "y": patch.get_y(), 551 | "x": patch.get_x(), 552 | "h": patch.get_height(), 553 | "w": patch.get_width(), 554 | "fc": patch.get_facecolor(), 555 | "ec": patch.get_edgecolor(), 556 | "lw": patch.get_linewidth(), 557 | "ls": 
patch.get_linestyle(), 558 | "hatch": patch.get_hatch(), 559 | } 560 | for patch in axes.patches 561 | ] 562 | if is_vertical: 563 | out = [ 564 | { 565 | "y": patch["x"], 566 | "x": 6.5 - patch["y"], 567 | "h": patch["w"], 568 | "w": patch["h"], 569 | "fc": patch["fc"], 570 | "ec": patch["ec"], 571 | "lw": patch["lw"], 572 | "ls": patch["ls"], 573 | "hatch": patch["hatch"], 574 | } 575 | for patch in out 576 | ] 577 | return pd.DataFrame(out).sort_values("x").reset_index(drop=True) 578 | 579 | 580 | def _get_color_to_label_from_legend(ax): 581 | handles, labels = ax.get_legend_handles_labels() 582 | color_to_label = { 583 | patches[0].get_facecolor(): label for patches, label in zip(handles, labels) 584 | } 585 | return color_to_label 586 | 587 | 588 | @pytest.mark.parametrize("orientation", ["horizontal", "vertical"]) 589 | @pytest.mark.parametrize("show_counts", [False, True]) 590 | def test_add_stacked_bars(orientation, show_counts): 591 | df = generate_samples() 592 | df["label"] = pd.cut( 593 | generate_samples().value + np.random.rand() / 2, 3 594 | ).cat.codes.map({0: "foo", 1: "bar", 2: "baz"}) 595 | 596 | upset = UpSet(df, show_counts=show_counts, orientation=orientation) 597 | upset.add_stacked_bars(by="label") 598 | upset_axes = upset.plot() 599 | 600 | int_axes = upset_axes["intersections"] 601 | stacked_axes = upset_axes["extra1"] 602 | 603 | is_vertical = orientation == "vertical" 604 | int_rects = _get_patch_data(int_axes, is_vertical) 605 | stacked_rects = _get_patch_data(stacked_axes, is_vertical) 606 | 607 | # check bar heights match between int_rects and stacked_rects 608 | assert_series_equal( 609 | int_rects.groupby("x")["h"].sum(), 610 | stacked_rects.groupby("x")["h"].sum(), 611 | check_dtype=False, 612 | ) 613 | # check count labels match (TODO: check coordinate) 614 | assert [elem.get_text() for elem in int_axes.texts] == [ 615 | elem.get_text() for elem in stacked_axes.texts 616 | ] 617 | 618 | color_to_label = 
_get_color_to_label_from_legend(stacked_axes) 619 | stacked_rects["label"] = stacked_rects["fc"].map(color_to_label) 620 | # check totals for each label 621 | assert_series_equal( 622 | stacked_rects.groupby("label")["h"].sum(), 623 | df.groupby("label").size(), 624 | check_dtype=False, 625 | check_names=False, 626 | ) 627 | 628 | label_order = [ 629 | text_obj.get_text() for text_obj in stacked_axes.get_legend().get_texts() 630 | ] 631 | # label order should be lexicographic 632 | assert label_order == sorted(label_order) 633 | 634 | if orientation == "horizontal": 635 | # order of labels in legend should match stack, top to bottom 636 | for prev, curr in zip(label_order, label_order[1:]): 637 | assert ( 638 | stacked_rects.query("label == @prev").sort_values("x")["y"].values 639 | >= stacked_rects.query("label == @curr").sort_values("x")["y"].values 640 | ).all() 641 | else: 642 | # order of labels in legend should match stack, left to right 643 | for prev, curr in zip(label_order, label_order[1:]): 644 | assert ( 645 | stacked_rects.query("label == @prev").sort_values("x")["y"].values 646 | <= stacked_rects.query("label == @curr").sort_values("x")["y"].values 647 | ).all() 648 | 649 | 650 | @pytest.mark.parametrize( 651 | ("colors", "expected"), 652 | [ 653 | (["blue", "red", "green"], ["blue", "red", "green"]), 654 | ({"bar": "blue", "baz": "red", "foo": "green"}, ["blue", "red", "green"]), 655 | ("Pastel1", ["#fbb4ae", "#b3cde3", "#ccebc5"]), 656 | (cm.viridis, ["#440154", "#440256", "#450457"]), 657 | (lambda x: cm.Pastel1(x), ["#fbb4ae", "#b3cde3", "#ccebc5"]), 658 | ], 659 | ) 660 | def test_add_stacked_bars_colors(colors, expected): 661 | df = generate_samples() 662 | df["label"] = pd.cut( 663 | generate_samples().value + np.random.rand() / 2, 3 664 | ).cat.codes.map({0: "foo", 1: "bar", 2: "baz"}) 665 | 666 | upset = UpSet(df) 667 | upset.add_stacked_bars(by="label", colors=colors, title="Count by gender") 668 | upset_axes = upset.plot() 669 | 
stacked_axes = upset_axes["extra1"] 670 | color_to_label = _get_color_to_label_from_legend(stacked_axes) 671 | label_to_color = {v: k for k, v in color_to_label.items()} 672 | actual = [to_hex(label_to_color[label]) for label in ["bar", "baz", "foo"]] 673 | expected = [to_hex(color) for color in expected] 674 | assert actual == expected 675 | 676 | 677 | @pytest.mark.parametrize("int_sum_over", [False, True]) 678 | @pytest.mark.parametrize("stack_sum_over", [False, True]) 679 | @pytest.mark.parametrize("show_counts", [False, True]) 680 | def test_add_stacked_bars_sum_over(int_sum_over, stack_sum_over, show_counts): 681 | # A rough test of sum_over 682 | df = generate_samples() 683 | df["label"] = pd.cut( 684 | generate_samples().value + np.random.rand() / 2, 3 685 | ).cat.codes.map({0: "foo", 1: "bar", 2: "baz"}) 686 | 687 | upset = UpSet( 688 | df, sum_over="value" if int_sum_over else None, show_counts=show_counts 689 | ) 690 | upset.add_stacked_bars( 691 | by="label", sum_over="value" if stack_sum_over else None, colors="Pastel1" 692 | ) 693 | upset_axes = upset.plot() 694 | 695 | int_axes = upset_axes["intersections"] 696 | stacked_axes = upset_axes["extra1"] 697 | 698 | int_rects = _get_patch_data(int_axes, is_vertical=False) 699 | stacked_rects = _get_patch_data(stacked_axes, is_vertical=False) 700 | 701 | if int_sum_over == stack_sum_over: 702 | # check bar heights match between int_rects and stacked_rects 703 | assert_series_equal( 704 | int_rects.groupby("x")["h"].sum(), 705 | stacked_rects.groupby("x")["h"].sum(), 706 | check_dtype=False, 707 | ) 708 | # and check labels match with show_counts 709 | assert [elem.get_text() for elem in int_axes.texts] == [ 710 | elem.get_text() for elem in stacked_axes.texts 711 | ] 712 | else: 713 | assert ( 714 | int_rects.groupby("x")["h"].sum() != stacked_rects.groupby("x")["h"].sum() 715 | ).all() 716 | if show_counts: 717 | assert [elem.get_text() for elem in int_axes.texts] != [ 718 | elem.get_text() for elem in 
stacked_axes.texts 719 | ] 720 | 721 | 722 | @pytest.mark.parametrize( 723 | "x", 724 | [ 725 | generate_counts(), 726 | ], 727 | ) 728 | def test_index_must_be_bool(x): 729 | # Truthy ints are okay 730 | x = x.reset_index() 731 | x[["cat0", "cat2", "cat2"]] = x[["cat0", "cat1", "cat2"]].astype(int) 732 | x = x.set_index(["cat0", "cat1", "cat2"]).iloc[:, 0] 733 | 734 | UpSet(x) 735 | 736 | # other ints are not 737 | x = x.reset_index() 738 | x[["cat0", "cat2", "cat2"]] = x[["cat0", "cat1", "cat2"]] + 1 739 | x = x.set_index(["cat0", "cat1", "cat2"]).iloc[:, 0] 740 | with pytest.raises(ValueError, match="not boolean"): 741 | UpSet(x) 742 | 743 | 744 | @pytest.mark.parametrize( 745 | ("filter_params", "expected"), 746 | [ 747 | ( 748 | {"min_subset_size": 623}, 749 | { 750 | (True, False, False): 884, 751 | (True, True, False): 1547, 752 | (True, False, True): 623, 753 | (True, True, True): 990, 754 | }, 755 | ), 756 | ( 757 | {"max_subset_rank": 3}, 758 | { 759 | (True, False, False): 884, 760 | (True, True, False): 1547, 761 | (True, True, True): 990, 762 | }, 763 | ), 764 | ( 765 | {"min_subset_size": 800, "max_subset_size": 990}, 766 | { 767 | (True, False, False): 884, 768 | (True, True, True): 990, 769 | }, 770 | ), 771 | ( 772 | {"min_subset_size": "15%", "max_subset_size": "30.1%"}, 773 | { 774 | (True, False, False): 884, 775 | (True, True, True): 990, 776 | }, 777 | ), 778 | ( 779 | {"min_degree": 2}, 780 | { 781 | (True, True, False): 1547, 782 | (True, False, True): 623, 783 | (False, True, True): 258, 784 | (True, True, True): 990, 785 | }, 786 | ), 787 | ( 788 | {"min_degree": 2, "max_degree": 2}, 789 | { 790 | (True, True, False): 1547, 791 | (True, False, True): 623, 792 | (False, True, True): 258, 793 | }, 794 | ), 795 | ( 796 | {"max_subset_size": 500, "max_degree": 2}, 797 | { 798 | (False, False, False): 220, 799 | (False, True, False): 335, 800 | (False, False, True): 143, 801 | (False, True, True): 258, 802 | }, 803 | ), 804 | ], 805 | ) 806 | 
@pytest.mark.parametrize("sort_by", ["cardinality", "degree"]) 807 | def test_filter_subsets(filter_params, expected, sort_by): 808 | data = generate_samples(seed=0, n_samples=5000, n_categories=3) 809 | # data = 810 | # cat1 cat0 cat2 811 | # False False False 220 812 | # True False False 884 813 | # False True False 335 814 | # False True 143 815 | # True True False 1547 816 | # False True 623 817 | # False True True 258 818 | # True True True 990 819 | upset_full = UpSet(data, subset_size="auto", sort_by=sort_by) 820 | upset_filtered = UpSet(data, subset_size="auto", sort_by=sort_by, **filter_params) 821 | intersections = upset_full.intersections 822 | df = upset_full._df 823 | # check integrity of expected, just to be sure 824 | for key, value in expected.items(): 825 | assert intersections.loc[key] == value 826 | subset_intersections = intersections[ 827 | intersections.index.isin(list(expected.keys())) 828 | ] 829 | subset_df = df[df.index.isin(list(expected.keys()))] 830 | assert len(subset_intersections) < len(intersections) 831 | assert_series_equal(upset_filtered.intersections, subset_intersections) 832 | assert_frame_equal( 833 | upset_filtered._df.drop("_bin", axis=1), subset_df.drop("_bin", axis=1) 834 | ) 835 | # category totals should not be affected 836 | assert_series_equal(upset_full.totals, upset_filtered.totals) 837 | assert upset_full.total == pytest.approx(upset_filtered.total) 838 | 839 | 840 | def test_filter_subsets_max_subset_rank_tie(): 841 | data = generate_samples(seed=0, n_samples=5, n_categories=3) 842 | tested_non_tie = False 843 | tested_tie = True 844 | full = UpSet(data, subset_size="count").intersections 845 | prev = None 846 | for max_rank in range(1, 5): 847 | cur = UpSet(data, subset_size="count", max_subset_rank=max_rank).intersections 848 | if prev is not None: 849 | if cur.shape[0] > prev.shape[0]: 850 | # check we add rows only when they are new 851 | assert cur.min() < prev.min() 852 | tested_non_tie = True 853 | elif 
cur.shape[0] != full.shape[0]: 854 | assert (cur == cur.min()).sum() > 1 855 | tested_tie = True 856 | 857 | prev = cur 858 | assert tested_non_tie 859 | assert tested_tie 860 | assert cur.shape[0] == full.shape[0] 861 | 862 | 863 | @pytest.mark.parametrize( 864 | "value", 865 | [ 866 | "1", 867 | "-1%", 868 | "1%%", 869 | "%1", 870 | "hello", 871 | ], 872 | ) 873 | def test_bad_percentages(value): 874 | data = generate_samples(seed=0, n_samples=5, n_categories=3) 875 | with pytest.raises(ValueError, match="percentage"): 876 | UpSet(data, min_subset_size=value) 877 | 878 | 879 | @pytest.mark.parametrize( 880 | "x", 881 | [ 882 | generate_counts(n_categories=3), 883 | generate_counts(n_categories=8), 884 | generate_counts(n_categories=15), 885 | ], 886 | ) 887 | @pytest.mark.parametrize( 888 | "orientation", 889 | [ 890 | "horizontal", 891 | "vertical", 892 | ], 893 | ) 894 | def test_matrix_plot_margins(x, orientation): 895 | """Non-regression test addressing a bug where there is are large whitespace 896 | margins around the matrix when the number of intersections is large""" 897 | axes = plot(x, orientation=orientation) 898 | 899 | # Expected behavior is that each matrix column takes up one unit on x-axis 900 | expected_width = len(x) 901 | attr = "get_xlim" if orientation == "horizontal" else "get_ylim" 902 | lim = getattr(axes["matrix"], attr)() 903 | assert expected_width == lim[1] - lim[0] 904 | 905 | 906 | def _make_facecolor_list(colors): 907 | return [{"facecolor": c} for c in colors] 908 | 909 | 910 | CAT1_2_RED_STYLES = _make_facecolor_list( 911 | ["blue", "blue", "blue", "blue", "red", "blue", "blue", "red"] 912 | ) 913 | CAT1_RED_STYLES = _make_facecolor_list( 914 | ["blue", "red", "blue", "blue", "red", "red", "blue", "red"] 915 | ) 916 | CAT_NOT1_RED_STYLES = _make_facecolor_list( 917 | ["red", "blue", "red", "red", "blue", "blue", "red", "blue"] 918 | ) 919 | CAT1_NOT2_RED_STYLES = _make_facecolor_list( 920 | ["blue", "red", "blue", "blue", "blue", 
"red", "blue", "blue"] 921 | ) 922 | CAT_NOT1_2_RED_STYLES = _make_facecolor_list( 923 | ["red", "blue", "blue", "red", "blue", "blue", "blue", "blue"] 924 | ) 925 | 926 | 927 | @pytest.mark.parametrize( 928 | ("kwarg_list", "expected_subset_styles", "expected_legend"), 929 | [ 930 | # Different forms of including two categories 931 | ([{"present": ["cat1", "cat2"], "facecolor": "red"}], CAT1_2_RED_STYLES, []), 932 | ([{"present": {"cat1", "cat2"}, "facecolor": "red"}], CAT1_2_RED_STYLES, []), 933 | ([{"present": ("cat1", "cat2"), "facecolor": "red"}], CAT1_2_RED_STYLES, []), 934 | # with legend 935 | ( 936 | [{"present": ("cat1", "cat2"), "facecolor": "red", "label": "foo"}], 937 | CAT1_2_RED_STYLES, 938 | [({"facecolor": "red"}, "foo")], 939 | ), 940 | # present only cat1 941 | ([{"present": ("cat1",), "facecolor": "red"}], CAT1_RED_STYLES, []), 942 | ([{"present": "cat1", "facecolor": "red"}], CAT1_RED_STYLES, []), 943 | # Some uses of absent 944 | ([{"absent": "cat1", "facecolor": "red"}], CAT_NOT1_RED_STYLES, []), 945 | ( 946 | [{"present": "cat1", "absent": ["cat2"], "facecolor": "red"}], 947 | CAT1_NOT2_RED_STYLES, 948 | [], 949 | ), 950 | ([{"absent": ["cat2", "cat1"], "facecolor": "red"}], CAT_NOT1_2_RED_STYLES, []), 951 | # min/max args 952 | ( 953 | [{"present": ["cat1", "cat2"], "min_degree": 3, "facecolor": "red"}], 954 | _make_facecolor_list(["blue"] * 7 + ["red"]), 955 | [], 956 | ), 957 | ( 958 | [ 959 | { 960 | "present": ["cat1", "cat2"], 961 | "max_subset_size": 3000, 962 | "facecolor": "red", 963 | } 964 | ], 965 | _make_facecolor_list(["blue"] * 7 + ["red"]), 966 | [], 967 | ), 968 | ( 969 | [{"present": ["cat1", "cat2"], "max_degree": 2, "facecolor": "red"}], 970 | _make_facecolor_list(["blue"] * 4 + ["red"] + ["blue"] * 3), 971 | [], 972 | ), 973 | ( 974 | [ 975 | { 976 | "present": ["cat1", "cat2"], 977 | "min_subset_size": 3000, 978 | "facecolor": "red", 979 | } 980 | ], 981 | _make_facecolor_list(["blue"] * 4 + ["red"] + ["blue"] * 3), 982 
| [], 983 | ), 984 | # cat1 _or_ cat2 985 | ( 986 | [ 987 | {"present": "cat1", "facecolor": "red"}, 988 | {"present": "cat2", "facecolor": "red"}, 989 | ], 990 | _make_facecolor_list( 991 | ["blue", "red", "red", "blue", "red", "red", "red", "red"] 992 | ), 993 | [], 994 | ), 995 | # With multiple uses of label 996 | ( 997 | [ 998 | {"present": "cat1", "facecolor": "red", "label": "foo"}, 999 | {"present": "cat2", "facecolor": "red", "label": "bar"}, 1000 | ], 1001 | _make_facecolor_list( 1002 | ["blue", "red", "red", "blue", "red", "red", "red", "red"] 1003 | ), 1004 | [({"facecolor": "red"}, "foo; bar")], 1005 | ), 1006 | ( 1007 | [ 1008 | {"present": "cat1", "facecolor": "red", "label": "foo"}, 1009 | {"present": "cat2", "facecolor": "red", "label": "foo"}, 1010 | ], 1011 | _make_facecolor_list( 1012 | ["blue", "red", "red", "blue", "red", "red", "red", "red"] 1013 | ), 1014 | [({"facecolor": "red"}, "foo")], 1015 | ), 1016 | # With multiple colours, the latest overrides 1017 | ( 1018 | [ 1019 | {"present": "cat1", "facecolor": "red", "label": "foo"}, 1020 | {"present": "cat2", "facecolor": "green", "label": "bar"}, 1021 | ], 1022 | _make_facecolor_list( 1023 | ["blue", "red", "green", "blue", "green", "red", "green", "green"] 1024 | ), 1025 | [({"facecolor": "red"}, "foo"), ({"facecolor": "green"}, "bar")], 1026 | ), 1027 | # Combining multiple style properties 1028 | ( 1029 | [ 1030 | {"present": "cat1", "facecolor": "red", "hatch": "//"}, 1031 | {"present": "cat2", "edgecolor": "green", "linestyle": "dotted"}, 1032 | ], 1033 | [ 1034 | {"facecolor": "blue"}, 1035 | {"facecolor": "red", "hatch": "//"}, 1036 | {"facecolor": "blue", "edgecolor": "green", "linestyle": "dotted"}, 1037 | {"facecolor": "blue"}, 1038 | { 1039 | "facecolor": "red", 1040 | "hatch": "//", 1041 | "edgecolor": "green", 1042 | "linestyle": "dotted", 1043 | }, 1044 | {"facecolor": "red", "hatch": "//"}, 1045 | {"facecolor": "blue", "edgecolor": "green", "linestyle": "dotted"}, 1046 | { 
1047 | "facecolor": "red", 1048 | "hatch": "//", 1049 | "edgecolor": "green", 1050 | "linestyle": "dotted", 1051 | }, 1052 | ], 1053 | [], 1054 | ), 1055 | ], 1056 | ) 1057 | def test_style_subsets(kwarg_list, expected_subset_styles, expected_legend): 1058 | data = generate_counts() 1059 | upset = UpSet(data, facecolor="blue") 1060 | for kw in kwarg_list: 1061 | upset.style_subsets(**kw) 1062 | actual_subset_styles = upset.subset_styles 1063 | assert actual_subset_styles == expected_subset_styles 1064 | assert upset.subset_legend == expected_legend 1065 | 1066 | 1067 | def _dots_to_dataframe(ax, is_vertical): 1068 | matrix_path_collection = ax.collections[0] 1069 | matrix_dots = ( 1070 | pd.DataFrame(matrix_path_collection.get_offsets(), columns=["x", "y"]) 1071 | .join( 1072 | pd.DataFrame( 1073 | matrix_path_collection.get_facecolors(), 1074 | columns=["fc_r", "fc_g", "fc_b", "fc_a"], 1075 | ), 1076 | ) 1077 | .join( 1078 | pd.DataFrame( 1079 | matrix_path_collection.get_edgecolors(), 1080 | columns=["ec_r", "ec_g", "ec_b", "ec_a"], 1081 | ), 1082 | ) 1083 | .assign( 1084 | lw=matrix_path_collection.get_linewidths(), 1085 | ls=matrix_path_collection.get_linestyles(), 1086 | hatch=matrix_path_collection.get_hatch(), 1087 | ) 1088 | ) 1089 | 1090 | matrix_dots["ls_offset"] = matrix_dots["ls"].map(lambda tup: tup[0]).astype(float) 1091 | matrix_dots["ls_seq"] = matrix_dots["ls"].map( 1092 | lambda tup: None if tup[1] is None else tuple(tup[1]) 1093 | ) 1094 | del matrix_dots["ls"] 1095 | 1096 | if is_vertical: 1097 | matrix_dots[["x", "y"]] = matrix_dots[["y", "x"]] 1098 | matrix_dots["x"] = 7 - matrix_dots["x"] 1099 | return matrix_dots 1100 | 1101 | 1102 | @pytest.mark.parametrize("orientation", ["horizontal", "vertical"]) 1103 | def test_style_subsets_artists(orientation): 1104 | # Check that subset_styles are all appropriately reflected in matplotlib 1105 | # artists. 1106 | # This may be a bit overkill, and too coupled with implementation details. 
1107 | is_vertical = orientation == "vertical" 1108 | data = generate_counts() 1109 | upset = UpSet(data, orientation=orientation) 1110 | subset_styles = [ 1111 | {"facecolor": "black"}, 1112 | {"facecolor": "red"}, 1113 | {"edgecolor": "red"}, 1114 | {"edgecolor": "red", "linewidth": 4}, 1115 | {"linestyle": "dotted"}, 1116 | {"edgecolor": "red", "facecolor": "blue", "hatch": "//"}, 1117 | {"facecolor": "blue"}, 1118 | {}, 1119 | ] 1120 | 1121 | if is_vertical: 1122 | upset.subset_styles = subset_styles[::-1] 1123 | else: 1124 | upset.subset_styles = subset_styles 1125 | 1126 | upset_axes = upset.plot() 1127 | 1128 | int_rects = _get_patch_data(upset_axes["intersections"], is_vertical) 1129 | int_rects[["fc_r", "fc_g", "fc_b", "fc_a"]] = int_rects.pop("fc").apply( 1130 | lambda x: pd.Series(x) 1131 | ) 1132 | int_rects[["ec_r", "ec_g", "ec_b", "ec_a"]] = int_rects.pop("ec").apply( 1133 | lambda x: pd.Series(x) 1134 | ) 1135 | int_rects["ls_is_solid"] = int_rects.pop("ls").map( 1136 | lambda x: x == "solid" or pd.isna(x) 1137 | ) 1138 | expected = pd.DataFrame( 1139 | { 1140 | "fc_r": [0, 1, 0, 0, 0, 0, 0, 0], 1141 | "fc_g": [0, 0, 0, 0, 0, 0, 0, 0], 1142 | "fc_b": [0, 0, 0, 0, 0, 1, 1, 0], 1143 | "ec_r": [0, 1, 1, 1, 0, 1, 0, 0], 1144 | "ec_g": [0, 0, 0, 0, 0, 0, 0, 0], 1145 | "ec_b": [0, 0, 0, 0, 0, 0, 1, 0], 1146 | "lw": [1, 1, 1, 4, 1, 1, 1, 1], 1147 | "ls_is_solid": [True, True, True, True, False, True, True, True], 1148 | } 1149 | ) 1150 | 1151 | assert_frame_equal(expected, int_rects[expected.columns], check_dtype=False) 1152 | 1153 | styled_dots = _dots_to_dataframe(upset_axes["matrix"], is_vertical) 1154 | baseline_dots = _dots_to_dataframe( 1155 | UpSet(data, orientation=orientation).plot()["matrix"], is_vertical 1156 | ) 1157 | inactive_dot_mask = (baseline_dots[["fc_a"]] < 1).values.ravel() 1158 | assert_frame_equal( 1159 | baseline_dots.loc[inactive_dot_mask], styled_dots.loc[inactive_dot_mask] 1160 | ) 1161 | 1162 | styled_dots = 
styled_dots.loc[~inactive_dot_mask] 1163 | 1164 | styled_dots = ( 1165 | styled_dots.drop(columns="y") 1166 | .groupby("x") 1167 | .apply(lambda df: df.drop_duplicates()) 1168 | ) 1169 | styled_dots["ls_is_solid"] = styled_dots.pop("ls_seq").isna() 1170 | assert_frame_equal( 1171 | expected.iloc[1:].reset_index(drop=True), 1172 | styled_dots[expected.columns].reset_index(drop=True), 1173 | check_dtype=False, 1174 | ) 1175 | 1176 | # TODO: check lines between dots 1177 | # matrix_line_collection = upset_axes["matrix"].collections[1] 1178 | 1179 | 1180 | @pytest.mark.parametrize( 1181 | ( 1182 | "kwarg_list", 1183 | "expected_category_styles", 1184 | ), 1185 | [ 1186 | # Different forms of including two categories 1187 | ( 1188 | [{"categories": ["cat1", "cat2"], "shading_facecolor": "red"}], 1189 | { 1190 | "cat1": {"shading_facecolor": "red"}, 1191 | "cat2": {"shading_facecolor": "red"}, 1192 | }, 1193 | ), 1194 | ( 1195 | [ 1196 | {"categories": ["cat1", "cat2"], "shading_facecolor": "red"}, 1197 | {"categories": "cat1", "shading_facecolor": "green"}, 1198 | ], 1199 | { 1200 | "cat1": {"shading_facecolor": "green"}, 1201 | "cat2": {"shading_facecolor": "red"}, 1202 | }, 1203 | ), 1204 | ( 1205 | [ 1206 | {"categories": ["cat1", "cat2"], "shading_facecolor": "red"}, 1207 | {"categories": "cat1", "shading_edgecolor": "green"}, 1208 | ], 1209 | { 1210 | "cat1": {"shading_facecolor": "red", "shading_edgecolor": "green"}, 1211 | "cat2": {"shading_facecolor": "red"}, 1212 | }, 1213 | ), 1214 | ], 1215 | ) 1216 | def test_categories(kwarg_list, expected_category_styles): 1217 | data = generate_counts() 1218 | upset = UpSet(data, facecolor="blue") 1219 | for kw in kwarg_list: 1220 | upset.style_categories(**kw) 1221 | actual_category_styles = upset.category_styles 1222 | assert actual_category_styles == expected_category_styles 1223 | 1224 | 1225 | def test_many_categories(): 1226 | # Tests regressions against GH#193 1227 | n_cats = 250 1228 | index1 = [True, False] + 
[False] * (n_cats - 2) 1229 | index2 = [False, True] + [False] * (n_cats - 2) 1230 | columns = [chr(i + 33) for i in range(n_cats)] 1231 | data = pd.DataFrame([index1, index2], columns=columns) 1232 | data["value"] = 1 1233 | data = data.set_index(columns)["value"] 1234 | UpSet(data) 1235 | -------------------------------------------------------------------------------- /upsetplot/util.py: -------------------------------------------------------------------------------- 1 | """Generic utilities""" 2 | import re 3 | 4 | # The below is adapted from an answer to 5 | # https://stackoverflow.com/questions/66822945 6 | # by Andrius at https://stackoverflow.com/a/66869159/1017546 7 | # Reproduced under the CC-BY-SA 3.0 licence. 8 | 9 | ODD_REPEAT_PATTERN = r"((? str: 52 | """Convert old style named formatting to new style formatting. 53 | For example: '%(x)s - %%%(y)s' -> '{x} - %{y}' 54 | Args: 55 | fmt: old style formatting to convert. 56 | Returns: 57 | new style formatting. 58 | """ 59 | return __to_new_format(fmt, named=True) 60 | 61 | 62 | def to_new_pos_format(fmt: str) -> str: 63 | """Convert old style positional formatting to new style formatting. 64 | For example: '%s - %%%s' -> '{} - %{}' 65 | Args: 66 | fmt: old style formatting to convert. 67 | Returns: 68 | new style formatting. 69 | """ 70 | return __to_new_format(fmt, named=False) 71 | --------------------------------------------------------------------------------