├── .bandit.yml
├── .coveragerc
├── .github
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   └── workflows
│       ├── ci_codecov.yaml
│       ├── ci_tox.yaml
│       ├── pypi_release.yaml
│       └── website_auto.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .pre-commit-hooks.yaml
├── .roberto.yaml
├── HEADER
├── LICENSE
├── MANIFEST.in
├── README.md
├── book
│   └── content
│       ├── _config.yml
│       ├── _toc.yml
│       ├── api_measures_converter.rst
│       ├── api_measures_diversity.rst
│       ├── api_measures_similarity.rst
│       ├── api_methods_base.rst
│       ├── api_methods_distance.rst
│       ├── api_methods_partition.rst
│       ├── api_methods_similarity.rst
│       ├── api_methods_utils.rst
│       ├── intro.md
│       ├── references.bib
│       ├── requirements.txt
│       └── selector_logo.png
├── codecov.yml
├── notebooks
│   ├── tutorial_distance_based.ipynb
│   ├── tutorial_diversity_measures.ipynb
│   ├── tutorial_partition_based.ipynb
│   └── tutorial_similarity_based.ipynb
├── pyproject.toml
├── requirements.txt
├── requirements_dev.txt
├── selector
│   ├── __init__.py
│   ├── measures
│   │   ├── __init__.py
│   │   ├── converter.py
│   │   ├── diversity.py
│   │   ├── similarity.py
│   │   └── tests
│   │       ├── __init__.py
│   │       ├── common.py
│   │       ├── test_converter.py
│   │       └── test_diversity.py
│   └── methods
│       ├── __init__.py
│       ├── base.py
│       ├── distance.py
│       ├── partition.py
│       ├── similarity.py
│       ├── tests
│       │   ├── __init__.py
│       │   ├── common.py
│       │   ├── data
│       │   │   ├── coords_imbalance_case1.txt
│       │   │   ├── coords_imbalance_case2.txt
│       │   │   ├── labels_imbalance_case1.txt
│       │   │   ├── labels_imbalance_case2.txt
│       │   │   └── ref_esim_selection_data.csv
│       │   ├── test_distance.py
│       │   ├── test_partition.py
│       │   └── test_similarity.py
│       └── utils.py
├── tox.ini
└── updateheaders.py
/.bandit.yml:
--------------------------------------------------------------------------------
1 | skips:
2 | # Ignore defensive `assert` statements
3 | - B101
4 | # Standard pseudo-random generators are not suitable for security/cryptographic purposes
5 | - B311
6 | # Ignore warnings about importing subprocess
7 | - B404
8 | # Ignore warnings about calling subprocess
9 | - B603
10 | # Ignore warnings about starting a process with a partial executable path
11 | - B607
12 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit =
3 | selector/*/tests/*
4 | selector/tests/*
5 | selector/__init__.py
6 | selector/_version.py
7 |
8 | [report]
9 | show_missing = True
10 | exclude_also =
11 | pragma: no cover
12 | raise NotImplementedError
13 | if __name__ == .__main__.:
14 |
--------------------------------------------------------------------------------
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | It is recommended to follow the code of conduct as described in
4 | https://qcdevs.org/guidelines/QCDevsCodeOfConduct/.
5 |
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | We welcome contributions from external contributors, and this document
4 | describes how to get code changes merged into the selector repository.
5 |
6 | ## Getting Started
7 |
8 | * Make sure you have a [GitHub account](https://github.com/signup/free).
9 | * [Fork](https://help.github.com/articles/fork-a-repo/) this repository on GitHub.
10 | * On your local machine,
11 | [clone](https://help.github.com/articles/cloning-a-repository/) your fork of
12 | the repository.
13 |
14 | ## Making Changes
15 |
16 | * Add some really awesome code to your local fork. It's usually a
17 | [good idea](http://blog.jasonmeridth.com/posts/do-not-issue-pull-requests-from-your-master-branch/)
18 | to make changes on a
19 | [branch](https://help.github.com/articles/creating-and-deleting-branches-within-your-repository/)
20 | with the branch name relating to the feature you are going to add.
21 | * When you are ready for others to examine and comment on your new feature,
22 |   navigate to your fork of selector on GitHub and open a
23 |   [pull request](https://help.github.com/articles/using-pull-requests/) (PR). Note that
24 |   after you launch a PR from one of your fork's branches, all
25 |   subsequent commits to that branch will be added to the open pull request
26 |   automatically. Each commit added to the PR will be validated for
27 |   mergeability, compilation, and test suite compliance; the results of these checks
28 |   will be visible on the PR page.
29 | * If you're providing a new feature, you must add test cases and documentation.
30 | * When the code is ready to go, make sure you run the test suite using pytest.
31 | * When you're ready to be considered for merging, check the "Ready to go"
32 | box on the PR page to let the selector devs know that the changes are complete.
33 | The code will not be merged until this box is checked, the continuous
34 | integration returns checkmarks,
35 | and multiple core developers give "Approved" reviews.
36 |
37 | # Python Virtual Environment for Package Development
38 |
39 | Here is a list of the package versions we used when developing
40 | [selector](https://github.com/theochem/selector):
41 |
42 | ```bash
43 | python==3.7.11
44 | rdkit==2020.09.1.0
45 | numpy==1.21.2
46 | scipy==1.7.3
47 | pytest==6.2.5
48 | pytest-cov==3.0.0
49 | tox==3.24.5
50 | flake8==4.0.1
51 | pylint==2.12.2
52 | codecov==2.1.12
53 | # more to be added
54 | ```
55 |
56 | `Conda`, [`venv`](https://docs.python.org/3/library/venv.html#module-venv) and
57 | [`virtualenv`](https://virtualenv.pypa.io/en/latest/) are all good options, and any one of them
58 | will serve you well. I prefer `Miniconda` on my local machine.
59 |
60 | # Additional Resources
61 |
62 | * [General GitHub documentation](https://help.github.com/)
63 | * [PR best practices](http://codeinthehole.com/writing/pull-requests-and-other-good-practices-for-teams-using-github/)
64 | * [A guide to contributing to software packages](http://www.contribution-guide.org)
65 | * [Thinkful PR example](http://www.thinkful.com/learn/github-pull-request-tutorial/#Time-to-Submit-Your-First-PR)
66 |
--------------------------------------------------------------------------------
/.github/workflows/ci_codecov.yaml:
--------------------------------------------------------------------------------
1 | name: CI CodeCov
2 |
3 | on:
4 |   # GitHub has started calling a new repo's first branch "main" https://github.com/github/renaming
5 |   # Existing repos likely still have "master" as the primary branch
6 |   # Both are tracked here to keep legacy and new repos working
7 |
8 | # push:
9 | # branches:
10 | # - "master"
11 | # - "main"
12 |
13 | pull_request:
14 | branches:
15 | - "master"
16 | - "main"
17 | schedule:
18 | # Nightly tests run on master by default:
19 | # Scheduled workflows run on the latest commit on the default or base branch.
20 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule)
21 | - cron: "0 0 * * *"
22 |
23 | jobs:
24 | test:
25 | name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}
26 | runs-on: ${{ matrix.os }}
27 | strategy:
28 | matrix:
29 | os: [ubuntu-latest]
30 | python-version: ["3.11"]
31 |
32 | steps:
33 | - uses: actions/checkout@v4
34 | - name: Additional info about the build
35 | shell: bash
36 | run: |
37 | uname -a
38 | df -h
39 | ulimit -a
40 |
41 | - name: Set up Python ${{ matrix.python-version }}
42 | uses: actions/setup-python@v5
43 | with:
44 | python-version: ${{ matrix.python-version }}
45 | architecture: x64
46 |
47 | - name: Install dependencies
48 | shell: bash
49 | run: |
50 | python -m pip install --upgrade pip
51 | python -m pip install -r requirements_dev.txt
52 | python -m pip install -U pytest pytest-cov codecov
53 |
54 | - name: Install package
55 | shell: bash
56 | run: |
57 | python -m pip install .
58 | # pip install tox tox-gh-actions
59 |
60 | - name: Run tests
61 | shell: bash
62 | run: |
63 | python -m pytest -c pyproject.toml --cov-config=.coveragerc --cov-report=xml --color=yes selector
64 |
65 | - name: CodeCov
66 | uses: codecov/codecov-action@v4.5.0
67 | with:
68 | token: ${{ secrets.CODECOV_TOKEN }}
69 | # Temp fix for https://github.com/codecov/codecov-action/issues/1487
70 | version: v0.6.0
71 | fail_ci_if_error: true
72 | file: ./coverage.xml
73 | flags: unittests
74 | name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}
75 |
--------------------------------------------------------------------------------
/.github/workflows/ci_tox.yaml:
--------------------------------------------------------------------------------
1 | name: CI Tox
2 |
3 | on:
4 |   # GitHub has started calling a new repo's first branch "main" https://github.com/github/renaming
5 |   # Existing repos likely still have "master" as the primary branch
6 |   # Both are tracked here to keep legacy and new repos working
7 | push:
8 | branches:
9 | - "master"
10 | - "main"
11 | pull_request:
12 | branches:
13 | - "master"
14 | - "main"
15 | schedule:
16 | # Nightly tests run on master by default:
17 | # Scheduled workflows run on the latest commit on the default or base branch.
18 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule)
19 | - cron: "0 0 * * *"
20 |
21 | jobs:
22 | test:
23 | name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}
24 | runs-on: ${{ matrix.os }}
25 | strategy:
26 | matrix:
27 | # os: [macOS-latest, windows-latest, ubuntu-latest]
28 | os: [macos-13, windows-latest, ubuntu-latest]
29 | python-version: ["3.9", "3.10", "3.11", "3.12"]
30 |
31 | steps:
32 | - uses: actions/checkout@v4
33 | - name: Additional info about the build
34 | shell: bash
35 | run: |
36 | uname -a
37 | df -h
38 | ulimit -a
39 |
40 | - name: Set up Python ${{ matrix.python-version }}
41 | uses: actions/setup-python@v5
42 | with:
43 | python-version: ${{ matrix.python-version }}
44 | architecture: x64
45 |
46 | - name: Install dependencies
47 | shell: bash
48 | run: |
49 | python -m pip install --upgrade pip
50 | python -m pip install -r requirements_dev.txt
51 |           # python -m pip install -U pytest pytest-cov codecov
52 |           # tests are executed by tox in the "Run tests" step below
53 |
54 | - name: Install package
55 | shell: bash
56 | run: |
57 | python -m pip install .
58 | pip install tox tox-gh-actions
59 |
60 | - name: Run tests
61 | shell: bash
62 | run: |
63 | tox
64 |
--------------------------------------------------------------------------------
/.github/workflows/pypi_release.yaml:
--------------------------------------------------------------------------------
1 | name: PyPI Release
2 | on:
3 | push:
4 | tags:
5 | # Trigger on version tags (e.g., v1.0.0)
6 | - "v*.*.*"
7 | # Trigger on pre-release tags (e.g., v1.0.0-alpha.1)
8 | - "v*.*.*-*"
9 |
10 | env:
11 | # package name
12 | PYPI_NAME: qc-selector
13 |
14 | jobs:
15 | build:
16 | name: Build and Test Distribution
17 | runs-on: ${{ matrix.os }}
18 |
19 | strategy:
20 | matrix:
21 | # os: [ubuntu-latest, macos-latest, windows-latest]
22 | os: [ubuntu-latest]
23 | # python-version: ["3.9", "3.10", "3.11", "3.12"]
24 | python-version: ["3.11"]
25 | outputs:
26 | os: ${{ matrix.os }}
27 | python-version: ${{ matrix.python-version }}
28 |
29 | steps:
30 | - uses: actions/checkout@v4
31 | with:
32 | fetch-depth: 0
33 | - name: Set up Python
34 | uses: actions/setup-python@v5
35 | with:
36 | python-version: ${{ matrix.python-version }}
37 | # python-version: "3.11"
38 | - name: Install dependencies
39 | run: |
40 | python -m pip install --upgrade pip
41 | python -m pip install build pytest pytest-cov codecov
42 | python -m pip install -r requirements.txt
43 | python -m pip install -r requirements_dev.txt
44 | - name: Test package
45 | run: |
46 | python -m pytest -c pyproject.toml --cov-config=.coveragerc --cov-report=xml --color=yes selector
47 | - name: Build package
48 | run: python -m build
49 | - name: Store the distribution packages
50 | uses: actions/upload-artifact@v4
51 | with:
52 |           # Name of the artifact to upload
53 | name: python-package-distributions
54 | path: dist/
55 | # Optional parameters for better artifact management
56 | overwrite: false
57 | include-hidden-files: false
58 |
59 | publish-to-pypi:
60 | name: Publish Python distribution to PyPI
61 | if: startsWith(github.ref, 'refs/tags/v')
62 | needs: build
63 | runs-on: ubuntu-latest
64 | environment:
65 | name: PyPI-Release
66 | url: https://pypi.org/p/${{ env.PYPI_NAME }}
67 | permissions:
68 | id-token: write
69 |
70 | steps:
71 | - name: Download all the dists
72 | uses: actions/download-artifact@v4
73 | with:
74 | name: python-package-distributions
75 | path: dist/
76 | - name: Publish distribution to PyPI
77 | uses: pypa/gh-action-pypi-publish@release/v1
78 | env:
79 | TWINE_USERNAME: "__token__"
80 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
81 |
82 | github-release:
83 | name: Sign and Upload Python Distribution to GitHub Release
84 | needs:
85 | - build
86 | - publish-to-pypi
87 | runs-on: ubuntu-latest
88 | permissions:
89 | contents: write
90 | id-token: write
91 |
92 | steps:
93 | - name: Download all the dists
94 | uses: actions/download-artifact@v4
95 | with:
96 | name: python-package-distributions
97 | path: dist/
98 | - name: Sign the dists with Sigstore
99 | uses: sigstore/gh-action-sigstore-python@v3.0.0
100 | with:
101 | inputs: >-
102 | ./dist/*.tar.gz
103 | ./dist/*.whl
104 | - name: Create GitHub Release
105 | env:
106 | GITHUB_TOKEN: ${{ github.token }}
107 | run: >
108 | gh release create
109 | '${{ github.ref_name }}'
110 | --repo '${{ github.repository }}'
111 | --notes ""
112 | - name: Upload artifact signatures to GitHub Release
113 | env:
114 | GITHUB_TOKEN: ${{ github.token }}
115 | run: >
116 | gh release upload
117 | '${{ github.ref_name }}' dist/**
118 | --repo '${{ github.repository }}'
119 |
120 |   publish-to-testpypi:
121 |     name: Publish Python distribution to TestPyPI
122 | if: startsWith(github.ref, 'refs/tags/v')
123 | needs: build
124 | runs-on: ubuntu-latest
125 | environment:
126 | name: TestPyPI
127 | url: https://test.pypi.org/p/${{ env.PYPI_NAME }}
128 | permissions:
129 | id-token: write
130 |
131 | steps:
132 | - name: Download all the dists
133 | uses: actions/download-artifact@v4
134 | with:
135 | name: python-package-distributions
136 | path: dist/
137 |       - name: Publish distribution to TestPyPI
138 | uses: pypa/gh-action-pypi-publish@release/v1
139 | with:
140 | repository-url: https://test.pypi.org/legacy/
141 | env:
142 | TWINE_USERNAME: "__token__"
143 | TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN }}
144 |
--------------------------------------------------------------------------------
/.github/workflows/website_auto.yaml:
--------------------------------------------------------------------------------
1 | name: deploy-book
2 |
3 | # Only run this when the main branch changes
4 | on:
5 | push:
6 | branches:
7 | - main
8 | pull_request:
9 | types: [opened, synchronize, reopened, closed]
10 | branches:
11 | - main
12 | # If your git repository has the Jupyter Book within some-subfolder next to
13 | # unrelated files, you can make this run only if a file within that specific
14 | # folder has been modified.
15 | #
16 | #paths:
17 | #- book/
18 |
19 | # This job installs dependencies, builds the book, and pushes it to `gh-pages`
20 | jobs:
21 | deploy-book:
22 | runs-on: ubuntu-latest
23 | permissions:
24 | pages: write
25 | # https://github.com/JamesIves/github-pages-deploy-action/issues/1110
26 | contents: write
27 |
28 | steps:
29 | - uses: actions/checkout@v4
30 |
31 | # Install dependencies
32 | - name: Set up Python 3.11
33 | uses: actions/setup-python@v5
34 | with:
35 | python-version: 3.11
36 |
37 | - name: Install dependencies
38 | run: |
39 | pip install -r book/content/requirements.txt
40 | # Install selector
41 | - name: Install package
42 | run: |
43 | pip install -e .
44 |
45 | # Build the book
46 | - name: Build the book
47 | run: |
48 | cp notebooks/*.ipynb book/content/.
49 | jupyter-book build ./book/content
50 |
51 | # Push the book's HTML to github-pages
52 | # inspired by https://github.com/orgs/community/discussions/26724
53 | # only push to gh-pages if the main branch has been updated
54 | - name: GitHub Pages Action
55 | if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
56 | uses: peaceiris/actions-gh-pages@v3
57 | with:
58 | github_token: ${{ secrets.GITHUB_TOKEN }}
59 | publish_dir: ./book/content/_build/html
60 | publish_branch: gh-pages
61 | cname: selector.qcdevs.org
62 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | doc/_build/
73 | doc/html/
74 | doc/latex/
75 | doc/man/
76 | doc/xml/
77 | doc/source
78 | doc/modules
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
137 | # Editor junk
138 | tags
139 | [._]*.s[a-v][a-z]
140 | [._]*.sw[a-p]
141 | [._]s[a-v][a-z]
142 | [._]sw[a-p]
143 | *~
144 | \#*\#
145 | .\#*
146 | .idea/
147 | .vscode/
148 | # Mac .DS_Store
149 | .DS_Store
150 |
151 | # codecov files
152 | *.gcno
153 | *.gcda
154 | *.gcov
155 |
156 | */*/_build
157 | */_build
158 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # pre-commit is a tool to perform a predefined set of tasks manually and/or
2 | # automatically before git commits are made.
3 | #
4 | # Config reference: https://pre-commit.com/#pre-commit-configyaml---top-level
5 | #
6 | # Common tasks
7 | #
8 | # - Run on all files: pre-commit run --all-files
9 | # - Register git hooks: pre-commit install --install-hooks
10 | #
11 |
12 | # https://github.com/jupyterhub/jupyterhub/blob/main/.pre-commit-config.yaml
13 | # https://github.com/pre-commit/pre-commit/blob/main/.pre-commit-config.yaml
14 | # https://github.com/psf/black/blob/main/.pre-commit-config.yaml
15 | # https://docs.releng.linuxfoundation.org/en/latest/best-practices.html
16 |
17 | ci:
18 | # https://pre-commit.ci/
19 | # pre-commit.ci will open PRs updating our hooks once a month
20 | # 'weekly', 'monthly', 'quarterly'
21 | autoupdate_schedule: quarterly
22 | autofix_prs: false
23 | autofix_commit_msg: "[pre-commit.ci] Auto fixes from pre-commit.com hooks."
24 | submodules: false
25 |
26 | repos:
27 | - repo: https://github.com/pre-commit/pre-commit-hooks
28 | rev: v4.5.0
29 | hooks:
30 | # https://pre-commit.com/hooks.html
31 | # - id: fix-encoding-pragma
32 | - id: sort-simple-yaml
33 | - id: trailing-whitespace
34 | - id: end-of-file-fixer
35 | - id: check-yaml
36 | - id: check-toml
37 | # replaces or checks mixed line ending
38 | - id: mixed-line-ending
39 | - id: debug-statements
40 | # - id: name-tests-test
41 |       # - id: trailing-whitespace
42 | # sorts entries in requirements.txt
43 | - id: requirements-txt-fixer
44 |
45 | # we are not using this for now as we decided to use pyproject.toml instead
46 | ## https://github.com/asottile/setup-cfg-fmt
47 | #- repo: https://github.com/asottile/setup-cfg-fmt
48 | # rev: v2.5.0
49 | # hooks:
50 | # - id: setup-cfg-fmt
51 |
52 | # - repo: https://github.com/psf/black-pre-commit-mirror
53 | - repo: https://github.com/psf/black
54 | rev: 24.3.0
55 | hooks:
56 | - id: black
57 | args: [
58 | --line-length=100,
59 | ]
60 |
61 | - repo: https://github.com/pycqa/isort
62 | rev: 5.13.2
63 | hooks:
64 | - id: isort
65 |
66 | # todo: disable flake8 for now, but will need to add it back in the future
67 | # - repo: https://github.com/pycqa/flake8
68 | # rev: 6.1.0
69 | # hooks:
70 | # - id: flake8
71 | # args: ["--max-line-length=100"]
72 | # additional_dependencies:
73 | # - flake8-bugbear
74 | # - flake8-comprehensions
75 | # - flake8-simplify
76 | # - flake8-docstrings
77 | # - flake8-import-order>=0.9
78 | # - flake8-colors
79 | # exclude: ^src/blib2to3/
80 |
81 | - repo: https://github.com/pycqa/bandit
82 | rev: 1.7.8
83 | hooks:
84 | - id: bandit
85 | # exclude some directories
86 | exclude: |
87 | (?x)(
88 | ^test/|
89 |           ^book/|
90 |           ^devtools/|
91 |           ^docs/|
92 | ^doc/
93 | )
94 | args: [
95 | "-c", "pyproject.toml"
96 | ]
97 | additional_dependencies: [ "bandit[toml]" ]
98 |
99 | - repo: https://github.com/asottile/pyupgrade
100 | rev: v3.15.2
101 | hooks:
102 | - id: pyupgrade
103 | args: [--py37-plus]
104 |
105 | # todo: add this for type checking in the future
106 | # - repo: https://github.com/pre-commit/mirrors-mypy
107 | # rev: v1.5.1
108 | # hooks:
109 | # - id: mypy
110 | # exclude: ^docs/conf.py
111 | # args: ["--config-file", "pyproject.toml"]
112 | # additional_dependencies:
113 | # - types-PyYAML
114 | # - tomli >= 0.2.6, < 2.0.0
115 | # - click >= 8.1.0, != 8.1.4, != 8.1.5
116 | # - packaging >= 22.0
117 | # - platformdirs >= 2.1.0
118 | # - pytest
119 | # - hypothesis
120 | # - aiohttp >= 3.7.4
121 | # - types-commonmark
122 | # - urllib3
123 | # - hypothesmith
124 |
125 | # - repo: https://github.com/pre-commit/mirrors-mypy
126 | # rev: v1.6.0
127 | # hooks:
128 | # - id: mypy
129 | # additional_dependencies: [types-all]
130 | # exclude: ^testing/resources/
131 |
--------------------------------------------------------------------------------
/.pre-commit-hooks.yaml:
--------------------------------------------------------------------------------
1 | # https://github.com/pre-commit/pre-commit-hooks/blob/main/.pre-commit-hooks.yaml
2 |
3 | - id: trailing-whitespace
4 | name: trim trailing whitespace
5 | description: trims trailing whitespace.
6 | entry: trailing-whitespace-fixer
7 | language: python
8 | types: [text]
9 | # exclude: ^(book|devtools|docs|doc)/
10 | stages: [commit, push, manual]
11 |
12 | - id: sort-simple-yaml
13 | name: sort simple yaml files
14 | description: sorts simple yaml files which consist only of top-level keys, preserving comments and blocks.
15 | language: python
16 | entry: sort-simple-yaml
17 | files: '^$'
18 |
19 | - id: black
20 | name: black
21 | description: "Black: The uncompromising Python code formatter"
22 | entry: black
23 | language: python
24 | minimum_pre_commit_version: 3.4.0
25 | require_serial: true
26 | types_or: [python, pyi]
27 |
28 | - id: check-toml
29 | name: check toml
30 | description: checks toml files for parseable syntax.
31 | entry: check-toml
32 | language: python
33 | types: [toml]
34 |
35 | - id: check-yaml
36 | name: check yaml
37 | description: checks yaml files for parseable syntax.
38 | entry: check-yaml
39 | language: python
40 | # exclude: ^website.yml
41 | types: [yaml]
42 |
43 | - id: requirements-txt-fixer
44 | name: fix requirements.txt
45 | description: sorts entries in requirements.txt.
46 | entry: requirements-txt-fixer
47 | language: python
48 | files: (requirements|constraints).*\.txt$
49 |
50 | - id: mixed-line-ending
51 | name: mixed line ending
52 | description: replaces or checks mixed line ending.
53 | entry: mixed-line-ending
54 | language: python
55 | types: [text]
56 |
57 | - id: end-of-file-fixer
58 | name: fix end of files
59 | description: ensures that a file is either empty, or ends with one newline.
60 | entry: end-of-file-fixer
61 | language: python
62 | types: [python]
63 | stages: [commit, push, manual]
64 |
65 | - id: debug-statements
66 | name: debug statements (python)
67 | description: checks for debugger imports and py37+ `breakpoint()` calls in python source.
68 | entry: debug-statement-hook
69 | language: python
70 | types: [python]
71 |
72 | - id: check-merge-conflict
73 | name: check for merge conflicts
74 | description: checks for files that contain merge conflict strings.
75 | entry: check-merge-conflict
76 | language: python
77 | types: [text]
78 |
79 | - id: check-ast
80 | name: check python ast
81 | description: simply checks whether the files parse as valid python.
82 | entry: check-ast
83 | language: python
84 | types: [python]
85 |
--------------------------------------------------------------------------------
/.roberto.yaml:
--------------------------------------------------------------------------------
1 | # Force absolute comparison for cardboardlint
2 | absolute: true
3 | project:
4 | name: selector
5 | # requirements: [[numpydoc, numpydoc], [sphinx-autoapi, sphinx-autoapi], [sphinxcontrib-bibtex, sphinxcontrib-bibtex]]
6 | packages:
7 | - dist_name: selector
8 | tools:
9 | - write-py-version
10 | # - cardboardlint-static
11 | - build-py-inplace
12 | - pytest
13 | - upload-codecov
14 | # - cardboardlint-dynamic
15 | # - build-sphinx-doc
16 | # - upload-docs-gh
17 | - build-py-source
18 | - build-conda
19 | - deploy-pypi
20 | - deploy-conda
21 | # - deploy-github
22 |
--------------------------------------------------------------------------------
/HEADER:
--------------------------------------------------------------------------------
1 |
2 | The Selector is a Python library of algorithms for selecting diverse
3 | subsets of data for machine learning.
4 |
5 | Copyright (C) 2022-2024 The QC-Devs Community
6 |
7 | This file is part of Selector.
8 |
9 | Selector is free software; you can redistribute it and/or
10 | modify it under the terms of the GNU General Public License
11 | as published by the Free Software Foundation; either version 3
12 | of the License, or (at your option) any later version.
13 |
14 | Selector is distributed in the hope that it will be useful,
15 | but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | GNU General Public License for more details.
18 |
19 | You should have received a copy of the GNU General Public License
20 | along with this program; if not, see <https://www.gnu.org/licenses/>.
21 |
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include MANIFEST.in
3 | include CODE_OF_CONDUCT.md
4 | include versioneer.py
5 |
6 | prune notebooks
7 | prune book
8 |
9 | graft selector
10 | global-exclude *.py[cod] __pycache__ *.so
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | [](https://python.org/downloads)
7 | [](https://opensource.org/licenses/)
8 | [](https://github.com/theochem/Selector/actions/workflows/ci_tox.yaml)
9 | [](https://codecov.io/gh/theochem/Selector)
10 |
11 | The `Selector` library provides methods for selecting a diverse subset of a (molecular) dataset.
12 |
13 | ## Citation
14 |
15 | Please use the following citation in any publication using the `selector` library:
16 |
17 | ```bibtex
18 | @article{
19 | TO BE ADDED LATER
20 | }
21 | ```
22 |
23 | ## Web Server
24 |
25 | A web server for the `selector` library is available at https://huggingface.co/spaces/QCDevs/Selector.
26 | For small and prototype datasets, you can use the web server to select a diverse subset of your
27 | dataset and compute its diversity metrics; both the selected subset and the computed metrics can
28 | then be downloaded.
29 |
30 | ## Installation
31 |
32 | It is recommended to install `selector` within a virtual environment. To create a virtual
33 | environment, we can use the `venv` module (Python 3.3+,
34 | https://docs.python.org/3/tutorial/venv.html), `miniconda` (https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html), or
35 | `pipenv` (https://pipenv.pypa.io/en/latest/).
36 |
37 | ### Installing from PyPI
38 |
39 | To install `selector` with `pip`, we can install the latest stable release from the Python Package Index (PyPI) as follows:
40 |
41 | ```bash
42 | # install the stable release.
43 | pip install qc-selector
44 | ```
45 |
46 | ### Installing from Prebuilt Wheel Files
47 |
48 | To download the prebuilt wheel files, visit the [PyPI page](https://pypi.org/project/qc-selector/)
49 | and [GitHub releases](https://github.com/theochem/Selector/tags).
50 |
51 | ```bash
52 | # download the wheel file first to your local machine
53 | # then install the wheel file
54 | pip install file_path/qc_selector-0.0.2b12-py3-none-any.whl
55 | ```
56 |
57 | ### Installing from the Source Code
58 |
59 | In addition, we can install the latest development version from the GitHub repository as follows:
60 |
61 | ```bash
62 | # install the latest development version
63 | pip install git+https://github.com/theochem/Selector.git
64 | ```
65 |
66 | We can also clone the repository to access the latest development version, test it and install it as follows:
67 |
68 | ```bash
69 | # clone the repository
70 | git clone git@github.com:theochem/Selector.git
71 |
72 | # change into the working directory
73 | cd Selector
74 | # run the tests
75 | python -m pytest .
76 |
77 | # install the package
78 | pip install .
79 |
80 | ```
81 |
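82 | ## Example
83 |
84 | As a minimal sketch (the API below mirrors the bundled tutorial notebooks; `MaxMin`
85 | expects a pairwise distance matrix), a diverse subset can be selected as follows:
86 |
87 | ```python
88 | import numpy as np
89 | from sklearn.metrics.pairwise import pairwise_distances
90 |
91 | from selector.methods.distance import MaxMin
92 |
93 | # toy feature matrix standing in for real molecular features
94 | X = np.random.default_rng(0).random((100, 10))
95 | # MaxMin operates on a pairwise distance matrix
96 | X_dist = pairwise_distances(X, metric="euclidean")
97 |
98 | # indices of 10 maximally diverse samples
99 | collector = MaxMin()
100 | selected = collector.select(X_dist, size=10)
101 | ```
102 |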
103 | ## More
104 |
105 | See https://selector.qcdevs.org for full details.
106 |
--------------------------------------------------------------------------------
/book/content/_config.yml:
--------------------------------------------------------------------------------
1 | # Book settings
2 | # Learn more at https://jupyterbook.org/customize/config.html
3 |
4 | title: Ayers Lab
5 | author: The QC-Devs Community with the support of Digital Research Alliance of Canada and The Jupyter Book Community
6 | logo: selector_logo.png
7 |
8 | # Force re-execution of notebooks on each build.
9 | # See https://jupyterbook.org/content/execute.html
10 | execute:
11 | execute_notebooks: 'off'
12 |
13 | # Define the name of the latex output file for PDF builds
14 | latex:
15 | latex_documents:
16 | targetname: book.tex
17 |
18 | sphinx:
19 | extra_extensions:
20 | - 'sphinx.ext.autodoc'
21 | config:
22 | mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
23 |
24 |
25 | #Read LaTeX
26 | parse:
27 | myst_enable_extensions: # default extensions to enable in the myst parser. See https://myst-parser.readthedocs.io/en/latest/using/syntax-optional.html
28 | # - amsmath
29 | - colon_fence
30 | # - deflist
31 | - dollarmath
32 | # - html_admonition
33 | # - html_image
34 | - linkify
35 | # - replacements
36 | # - smartquotes
37 | - substitution
38 | - tasklist
39 | myst_url_schemes: [mailto, http, https] # URI schemes that will be recognised as external URLs in Markdown links
40 | myst_dmath_double_inline: true # Allow display math ($$) within an inline context
41 | myst_dmath_allow_space: false
42 |
43 | # Add a bibtex file so that we can create citations
44 | bibtex_bibfiles:
45 | - references.bib
46 |
47 | # Information about where the book exists on the web
48 | repository:
49 | url: https://github.com/theochem/Selector # Online location of your book
50 | path_to_book: book/content # Optional path to your book, relative to the repository root
51 | branch: main # Which branch of the repository should be used when creating links (optional)
52 |
53 |
54 | # Add GitHub buttons to your book
55 | launch_buttons:
56 | thebe : true
57 | colab_url: "https://colab.research.google.com"
58 |
59 | copyright: "2022-2024"
60 |
61 | html:
62 | use_issues_button: true
63 | use_repository_button: true
64 | favicon: "selector_logo.png"
65 | use_multitoc_numbering: false
66 |
--------------------------------------------------------------------------------
/book/content/_toc.yml:
--------------------------------------------------------------------------------
1 | # Table of contents
2 | # Learn more at https://jupyterbook.org/customize/toc.html
3 |
4 | format: jb-book
5 | root: intro.md
6 | parts:
7 | - caption: Tutorial
8 | chapters:
9 | - file: tutorial_distance_based.ipynb
10 | - file: tutorial_partition_based.ipynb
11 | - file: tutorial_diversity_measures.ipynb
12 | - file: tutorial_similarity_based.ipynb
13 | - caption: API
14 | chapters:
15 | - file: api_methods_base.rst
16 | - file: api_methods_distance.rst
17 | - file: api_methods_partition.rst
18 | - file: api_methods_similarity.rst
19 | - file: api_methods_utils.rst
20 | - file: api_measures_converter.rst
21 | - file: api_measures_diversity.rst
22 | - file: api_measures_similarity.rst
23 |
--------------------------------------------------------------------------------
/book/content/api_measures_converter.rst:
--------------------------------------------------------------------------------
1 | .. _measures.converter:
2 |
3 | :mod:`selector.measures.converter`
4 | ==================================
5 |
6 | .. automodule:: selector.measures.converter
7 | :members:
8 | :undoc-members:
9 | :inherited-members:
10 | :private-members:
11 |
--------------------------------------------------------------------------------
/book/content/api_measures_diversity.rst:
--------------------------------------------------------------------------------
1 | .. _measures.diversity:
2 |
3 | :mod:`selector.measures.diversity`
4 | ==================================
5 |
6 | .. automodule:: selector.measures.diversity
7 | :members:
8 | :undoc-members:
9 | :inherited-members:
10 | :private-members:
11 |
--------------------------------------------------------------------------------
/book/content/api_measures_similarity.rst:
--------------------------------------------------------------------------------
1 | .. _measures.similarity:
2 |
3 | :mod:`selector.measures.similarity`
4 | ===================================
5 |
6 | .. automodule:: selector.measures.similarity
7 | :members:
8 | :undoc-members:
9 | :inherited-members:
10 | :private-members:
11 |
--------------------------------------------------------------------------------
/book/content/api_methods_base.rst:
--------------------------------------------------------------------------------
1 | .. _methods.base:
2 |
3 | :mod:`selector.methods.base`
4 | ============================
5 |
6 | .. automodule:: selector.methods.base
7 | :members:
8 | :undoc-members:
9 | :inherited-members:
10 | :private-members:
11 |
--------------------------------------------------------------------------------
/book/content/api_methods_distance.rst:
--------------------------------------------------------------------------------
1 | .. _methods.distance:
2 |
3 | :mod:`selector.methods.distance`
4 | ================================
5 |
6 | .. automodule:: selector.methods.distance
7 | :members:
8 | :undoc-members:
9 | :inherited-members:
10 | :private-members:
11 |
--------------------------------------------------------------------------------
/book/content/api_methods_partition.rst:
--------------------------------------------------------------------------------
1 | .. _methods.partition:
2 |
3 | :mod:`selector.methods.partition`
4 | =================================
5 |
6 | .. automodule:: selector.methods.partition
7 | :members:
8 | :undoc-members:
9 | :inherited-members:
10 | :private-members:
11 |
--------------------------------------------------------------------------------
/book/content/api_methods_similarity.rst:
--------------------------------------------------------------------------------
1 | .. _methods.similarity:
2 |
3 | :mod:`selector.methods.similarity`
4 | ==================================
5 |
6 | .. automodule:: selector.methods.similarity
7 | :members:
8 | :undoc-members:
9 | :inherited-members:
10 | :private-members:
11 |
--------------------------------------------------------------------------------
/book/content/api_methods_utils.rst:
--------------------------------------------------------------------------------
1 | .. _methods.utils:
2 |
3 | :mod:`selector.methods.utils`
4 | =============================
5 |
6 | .. automodule:: selector.methods.utils
7 | :members:
8 | :undoc-members:
9 | :inherited-members:
10 | :private-members:
11 |
--------------------------------------------------------------------------------
/book/content/intro.md:
--------------------------------------------------------------------------------
1 |
2 | # Welcome to QC-Selector's Documentation!
3 |
4 | [Selector](https://github.com/theochem/Selector) is a free, open-source, and cross-platform Python library designed to help you effortlessly identify the most diverse subset of molecules from your dataset. Please use the following citation in any publication using the Selector library:
5 |
6 | **"TO be added"**
7 |
8 | The Selector source code is hosted on [GitHub](https://github.com/theochem/Selector) and is released under the [GNU General Public License v3.0](https://github.com/theochem/Selector/blob/main/LICENSE). We welcome contributions to the Selector library in accordance with our Code of Conduct; please see our [Contributing Guidelines](https://qcdevs.org/guidelines/QCDevsCodeOfConduct/). Please report any issues you encounter while using the Selector library on [GitHub Issues](https://github.com/theochem/Selector/issues). For further information and inquiries, please contact us at qcdevs@gmail.com.
9 |
10 |
11 | ## Why QC-Selector?
12 |
13 | Selecting diverse and representative subsets is crucial for data-driven models and machine
14 | learning applications in many science and engineering disciplines, especially molecular design
15 | and drug discovery. Motivated by this, we developed the Selector package, a free and open-source Python library for selecting diverse subsets.
16 |
17 | The Selector library implements a range of existing algorithms for subset sampling based on the
18 | distances between and similarities of samples, as well as tools based on spatial partitioning. In
19 | addition, it includes seven diversity measures for quantifying the diversity of a given set, along
20 | with various mathematical formulations to convert similarities into dissimilarities (see the Quick Start below).
21 |
22 |
23 | ## Web Server
24 |
25 | A web server for the `selector` library is available at https://huggingface.co/spaces/QCDevs/Selector.
26 | For small and prototype datasets, you can use the web server to select a diverse subset of your
27 | dataset and compute its diversity metrics; both the selected subset and the computed metrics can
28 | then be downloaded.
29 |
30 | ## Installation
31 |
32 | It is recommended to install `selector` within a virtual environment. To create a virtual
33 | environment, we can use the `venv` module (Python 3.3+,
34 | https://docs.python.org/3/tutorial/venv.html), `miniconda` (https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html), or
35 | `pipenv` (https://pipenv.pypa.io/en/latest/).
36 |
37 | ### Installing from PyPI
38 |
39 | To install `selector` with `pip`, we can install the latest stable release from the Python Package Index (PyPI) as follows:
40 |
41 | ```bash
42 | # install the stable release.
43 | pip install qc-selector
44 | ```
45 |
46 | ### Installing from Prebuilt Wheel Files
47 |
48 | To download the prebuilt wheel files, visit the [PyPI page](https://pypi.org/project/qc-selector/)
49 | and [GitHub releases](https://github.com/theochem/Selector/tags).
50 |
51 | ```bash
52 | # download the wheel file first to your local machine
53 | # then install the wheel file
54 | pip install file_path/qc_selector-0.0.2b12-py3-none-any.whl
55 | ```
56 |
57 | ### Installing from the Source Code
58 |
59 | In addition, we can install the latest development version from the GitHub repository as follows:
60 |
61 | ```bash
62 | # install the latest development version
63 | pip install git+https://github.com/theochem/Selector.git
64 | ```
65 |
66 | We can also clone the repository to access the latest development version, test it and install it as follows:
67 |
68 | ```bash
69 | # clone the repository
70 | git clone git@github.com:theochem/Selector.git
71 |
72 | # change into the working directory
73 | cd Selector
74 | # run the tests
75 | python -m pytest .
76 |
77 | # install the package
78 | pip install .
79 |
80 | ```
81 |
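82 | ## Quick Start
83 |
84 | The snippet below is a minimal sketch of the typical workflow (select a subset, then quantify
85 | its diversity); it follows the API used in the tutorial notebooks, with module paths taken
86 | from the repository layout.
87 |
88 | ```python
89 | import numpy as np
90 | from sklearn.metrics.pairwise import pairwise_distances
91 |
92 | from selector.measures.diversity import compute_diversity
93 | from selector.methods.distance import MaxMin
94 |
95 | # toy feature matrix standing in for real molecular features
96 | X = np.random.default_rng(42).random((500, 20))
97 | # pairwise distance matrix consumed by MaxMin
98 | X_dist = pairwise_distances(X, metric="euclidean")
99 |
100 | # select the indices of 50 maximally diverse samples
101 | collector = MaxMin()
102 | selected = collector.select(X_dist, size=50)
103 |
104 | # quantify the diversity of the selected subset
105 | print(compute_diversity(X[selected], div_type="logdet"))
106 | ```
107 |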
--------------------------------------------------------------------------------
/book/content/references.bib:
--------------------------------------------------------------------------------
1 | @article{selector,
2 | title={selector: A Novel Approach for Selecting Diverse Molecular Subsets},
3 | author={QC-Devs},
4 | journal={Nature Chemistry},
5 | year={2023},
6 | doi={10.1234/natchem.2023.01234},
7 | }
8 |
--------------------------------------------------------------------------------
/book/content/requirements.txt:
--------------------------------------------------------------------------------
1 | docutils<0.18
2 | ghp-import
3 | jupyter-book
4 | matplotlib
5 | numpy
6 |
--------------------------------------------------------------------------------
/book/content/selector_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theochem/Selector/c04a253a928c6abe6c664c4090998187a5407e7e/book/content/selector_logo.png
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | # Codecov configuration to make it a bit less noisy
2 | codecov:
3 | token: ${{ secrets.CODECOV_TOKEN }}
4 | notify:
5 | require_ci_to_pass: yes
6 |
7 | coverage:
8 | precision: 2
9 | round: down
10 | range: 70...100
11 |
12 | status:
13 | patch: false
14 |
15 | project:
16 | default:
17 | # inspired by
18 | # github.com/theochem/iodata/blob/73eee89111bcec426d9e8651ec5d85541c7d6d24/.codecov.yml#L1-L9
19 |       # Commits and PRs are never marked as "failed" due to coverage issues.
20 | # Codecov is only used as an informal tool when reviewing PRs,
21 | # not in the least because of the many false failures.
22 | target: 0%
23 | threshold: 100%
24 |
25 | # ignore statistics for the testing folders
26 | ignore:
27 | - .*/tests/.*
28 | - .*/.*/tests/.*
29 | - .*/examples/.*
30 |   - .*/__init__.py
31 | - .*/_version.py
32 | - "test_*.rb" # wildcards accepted
33 | - .*/data/.*
34 | - .*/versioneer.py
35 | - .*/*/__init__.py
36 | # - "**/*.pyc" # glob accepted
37 |
38 | comment:
39 |   layout: "reach, header, diff, uncovered, files, changes"
40 | behavior: default
41 | require_changes: false # if true: only post the comment if coverage changes
42 | require_base: no # [yes :: must have a base report to post]
43 | require_head: yes # [yes :: must have a head report to post]
44 | branches: # branch names that can post comment
45 | - staging
46 | - main
47 |
--------------------------------------------------------------------------------
/notebooks/tutorial_diversity_measures.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Tutorial Diversity Measures\n",
8 | "\n",
9 |     "This tutorial demonstrates how to quantify the diversity of a selected subset with the `diversity` module implemented in the\n",
10 |     "`selector` package. The diversity measures are calculated from the feature matrix of the selected subset."
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import sys\n",
20 | "import warnings\n",
21 | "\n",
22 | "warnings.filterwarnings(\"ignore\")\n",
23 | "\n",
24 | "# uncomment the following line to run the code for your own project directory\n",
25 | "# sys.path.append(\"/Users/Someone/Documents/projects/Selector\")"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "import matplotlib.pylab as plt\n",
35 | "import numpy as np\n",
36 | "from sklearn.datasets import make_blobs\n",
37 | "from sklearn.metrics.pairwise import pairwise_distances\n",
38 | "from IPython.display import Markdown\n",
39 | "\n",
40 | "from selector.methods.distance import MaxMin, MaxSum, OptiSim, DISE\n",
41 |     "from selector.measures.diversity import compute_diversity, hypersphere_overlap_of_subset"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 |     "## Utility Function for Showing Diversity Measures as a Table\n"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 3,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "# define function to render tables easier\n",
58 | "\n",
59 | "\n",
60 | "def render_table(data, caption=None, decimals=3):\n",
61 |     "    \"\"\"Render a list of lists as a markdown table for easy visualization.\n",
62 | "\n",
63 | " Parameters\n",
64 | " ----------\n",
65 | " data : list of lists\n",
66 | " The data to be rendered in a table, each inner list represents a row with the first row\n",
67 | " being the header.\n",
68 | " caption : str, optional\n",
69 | " The caption of the table.\n",
70 | " decimals : int, optional\n",
71 | " The number of decimal places to round the data to.\n",
72 | " \"\"\"\n",
73 | "\n",
74 | " # check all rows have the same number of columns\n",
75 | " if not all(len(row) == len(data[0]) for row in data):\n",
76 | " raise ValueError(\"Expect all rows to have the same number of columns.\")\n",
77 | "\n",
78 |     "    tmp_output = \"\"  # initialize so the table renders even when no caption is given\n",
79 |     "    if caption is not None:\n",
80 |     "        if not isinstance(caption, str):\n",
81 |     "            raise ValueError(\"Expect caption to be a string.\")\n",
82 |     "        tmp_output = f\"**{caption}**\\n\\n\"\n",
83 | "\n",
84 | " # get the width of each column (transpose the data list and get the max length of each new row)\n",
85 | " colwidths = [max(len(str(s)) for s in col) + 2 for col in zip(*data)]\n",
86 | "\n",
87 | " # construct the header row\n",
88 | " header = f\"| {' | '.join(f'{str(s):^{w}}' for s, w in zip(data[0], colwidths))} |\"\n",
89 | " tmp_output += header + \"\\n\"\n",
90 | "\n",
91 | " # construct a separator row\n",
92 | " separator = f\"|{'|'.join(['-' * w for w in colwidths])}|\"\n",
93 | " tmp_output += separator + \"\\n\"\n",
94 | "\n",
95 | " # construct the data rows\n",
96 | " for row in data[1:]:\n",
97 | " # round the data to the specified number of decimal places\n",
98 | " row = [round(s, decimals) if isinstance(s, float) else s for s in row]\n",
99 | " row_str = f\"| {' | '.join(f'{str(s):^{w}}' for s, w in zip(row, colwidths))} |\"\n",
100 | " tmp_output += row_str + \"\\n\"\n",
101 | "\n",
102 | " return display(Markdown(tmp_output))"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "metadata": {},
108 | "source": [
109 | "## Generating Data\n",
110 | "\n",
111 | "The data should be provided as:\n",
112 | "\n",
113 | "- either an array `X` of shape `(n_samples, n_features)` encoding `n_samples` samples (rows) each in `n_features`-dimensional (columns) feature space,\n",
114 | "- or an array `X_dist` of shape `(n_samples, n_samples)` encoding the distance (i.e., dissimilarity) between each pair of `n_samples` sample points.\n",
115 | "\n",
116 |     "This data can be loaded from various file formats (e.g., csv, npz, txt, etc.) or generated on the fly using various libraries. In this tutorial, we use [`sklearn.datasets.make_blobs`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html) to generate clusters of `n_samples` points in a 20-dimensional feature space (`n_features=20`), which are then binarized using each feature's median. The same functionality applies to datasets of any dimensionality.\n"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 4,
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "name": "stdout",
126 | "output_type": "stream",
127 | "text": [
128 | "Shape of data = (500, 20)\n",
129 | "Shape of labels = (500,)\n",
130 | "Unique labels = [0 1 2]\n",
131 | "Cluster size = 167\n",
132 | "Shape of the distance array = (500, 500)\n"
133 | ]
134 | }
135 | ],
136 | "source": [
137 |     "# generate n_samples data points in a 20-dimensional feature space (3 clusters by default)\n",
138 | "X, labels = make_blobs(\n",
139 | " n_samples=500,\n",
140 | " n_features=20,\n",
141 | " # centers=np.array([[0.0, 0.0]]),\n",
142 | " random_state=42,\n",
143 | ")\n",
144 |     "# binarize the features\n",
145 | "# Calculate median for each feature\n",
146 | "median_threshold = np.median(X, axis=0)\n",
147 | "X = (X > median_threshold).astype(int)\n",
148 | "\n",
149 | "# compute the (n_sample, n_sample) pairwise distance matrix\n",
150 | "X_dist = pairwise_distances(X, metric=\"euclidean\")\n",
151 | "\n",
152 | "print(\"Shape of data = \", X.shape)\n",
153 | "print(\"Shape of labels = \", labels.shape)\n",
154 | "print(\"Unique labels = \", np.unique(labels))\n",
155 | "print(\"Cluster size = \", np.count_nonzero(labels == 0))\n",
156 | "print(\"Shape of the distance array = \", X_dist.shape)"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "## Perform the Subset Selection"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 5,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "from scipy.spatial.distance import pdist, squareform\n",
173 | "\n",
174 | "# select data using distance base methods\n",
175 | "# ---------------------------------------\n",
176 | "size = 50\n",
177 | "\n",
178 | "collector = MaxMin()\n",
179 | "index_maxmin = collector.select(X_dist, size=size)\n",
180 | "\n",
181 | "collector = MaxSum(fun_dist=lambda x: squareform(pdist(x, metric=\"minkowski\", p=0.1)))\n",
182 | "index_maxsum = collector.select(X, size=size)\n",
183 | "\n",
184 | "collector = OptiSim(ref_index=0, tol=0.1)\n",
185 | "index_optisim = collector.select(X_dist, size=size)\n",
186 | "\n",
187 | "collector = DISE(ref_index=0, p=2.0)\n",
188 | "index_dise = collector.select(X, size=size)"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": 6,
194 | "metadata": {},
195 | "outputs": [
196 | {
197 | "data": {
198 | "text/markdown": [
199 | "**Diversity of Selected Sets**\n",
200 | "\n",
201 | "| | logdet | wdud | shannon_entropy | hypersphere_overlap |\n",
202 | "|---------|--------------------|---------------------|--------------------|---------------------|\n",
203 | "| MaxMin | 44.143 | 0.273 | 18.637 | 1299.615 |\n",
204 | "| MaxSum | 33.938 | 0.261 | 19.379 | 4396.672 |\n",
205 | "| OptiSim | 43.734 | 0.254 | 19.758 | 1175.49 |\n",
206 | "| DISE | 45.402 | 0.268 | 18.958 | 1363.434 |\n"
207 | ],
208 | "text/plain": [
209 |        "<IPython.core.display.Markdown object>"
210 | ]
211 | },
212 | "metadata": {},
213 | "output_type": "display_data"
214 | }
215 | ],
216 | "source": [
217 | "div_measure = [\"logdet\", \"wdud\", \"shannon_entropy\", \"hypersphere_overlap\"]\n",
218 |     "selected_sets = zip(\n",
219 | " [\"MaxMin\", \"MaxSum\", \"OptiSim\", \"DISE\"],\n",
220 | " [index_maxmin, index_maxsum, index_optisim, index_dise],\n",
221 | ")\n",
222 | "\n",
223 | "# compute the diversity of the selected sets and render the results in a table\n",
224 | "table_data = [[\"\"] + div_measure]\n",
225 |     "for i in selected_sets:\n",
226 | " div_data = [i[0]]\n",
227 | " for m in div_measure:\n",
228 | " if m != \"hypersphere_overlap\":\n",
229 | " div_data.append(compute_diversity(X[i[1]], div_type=m))\n",
230 | " else:\n",
231 | " div_data.append(hypersphere_overlap_of_subset(x=X, x_subset=X[i[1]]))\n",
232 | " table_data.append(div_data)\n",
233 | "\n",
234 | "render_table(table_data, caption=\"Diversity of Selected Sets\")"
235 | ]
236 | }
237 | ],
238 | "metadata": {
239 | "kernelspec": {
240 | "display_name": "selector_div",
241 | "language": "python",
242 | "name": "python3"
243 | },
244 | "language_info": {
245 | "codemirror_mode": {
246 | "name": "ipython",
247 | "version": 3
248 | },
249 | "file_extension": ".py",
250 | "mimetype": "text/x-python",
251 | "name": "python",
252 | "nbconvert_exporter": "python",
253 | "pygments_lexer": "ipython3",
254 | "version": "3.11.9"
255 | }
256 | },
257 | "nbformat": 4,
258 | "nbformat_minor": 2
259 | }
260 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # The Selector library provides a set of tools for selecting a
2 | # subset of the dataset and computing diversity.
3 | #
4 | # Copyright (C) 2023 The QC-Devs Community
5 | #
6 | # This file is part of Selector.
7 | #
8 | # Selector is free software; you can redistribute it and/or
9 | # modify it under the terms of the GNU General Public License
10 | # as published by the Free Software Foundation; either version 3
11 | # of the License, or (at your option) any later version.
12 | #
13 | # Selector is distributed in the hope that it will be useful,
14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | # GNU General Public License for more details.
17 | #
18 | # You should have received a copy of the GNU General Public License
19 | # along with this program; if not, see <https://www.gnu.org/licenses/>.
20 | #
21 | # --
22 |
23 |
24 | [project]
25 | # https://packaging.python.org/en/latest/specifications/declaring-project-metadata/
26 | name = "qc-selector"
27 | description = "Subset selection with maximum diversity."
28 | readme = {file = 'README.md', content-type='text/markdown'}
29 | #requires-python = ">=3.9,<4.0"
30 | requires-python = ">=3.9"
31 | # "LICENSE" is name of the license file, which must be in root of project folder
32 | license = {file = "LICENSE"}
33 | authors = [
34 | {name = "QC-Devs Community", email = "qcdevs@gmail.com"},
35 | ]
36 | keywords = [
37 | "subset selection",
38 | "variable selection",
39 | "chemical diversity",
40 | "compound selection",
41 | "maximum diversity",
42 | "chemical library design",
43 | "compound acquisition",
44 | ]
45 |
46 | # https://pypi.org/classifiers/
47 | # Add PyPI classifiers here
48 | classifiers = [
49 | "Development Status :: 5 - Production/Stable",
50 | "Environment :: Console",
51 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
52 | "Natural Language :: English",
53 | "Operating System :: MacOS",
54 | "Operating System :: Microsoft :: Windows",
55 | "Operating System :: Unix",
56 | "Operating System :: POSIX",
57 | "Operating System :: POSIX :: Linux",
58 | "Programming Language :: Python",
59 | "Programming Language :: Python :: 3",
60 | "Programming Language :: Python :: 3.9",
61 | "Programming Language :: Python :: 3.10",
62 | "Programming Language :: Python :: 3.11",
63 | "Programming Language :: Python :: 3.12",
64 | "Topic :: Scientific/Engineering",
65 | "Topic :: Sociology",
66 |
67 | ]
68 |
69 | # version = "0.0.2b4"
70 | dynamic = [
71 | "dependencies",
72 | "optional-dependencies",
73 | "version",
74 | ]
75 |
76 | # not using this section now but it's here for reference
77 | # # Required dependencies for install/usage of your package or application
78 | # # If you don't have any dependencies, leave this section empty
79 | # # Format for dependency strings: https://peps.python.org/pep-0508/
80 | # # dependencies from requirements.txt
81 | # dependencies = [
82 | # "bitarray>=2.5.1",
83 | # "numpy>=1.21.2",
84 | # "scipy>=1.11.1",
85 | # ]
86 |
87 | # [project.optional-dependencies]
88 | # tests = [
89 | # 'coverage>=5.0.3',
90 | # "pandas>=1.3.5",
91 | # "pytest>=7.4.0",
92 | # "scikit-learn>=1.0.1",
93 | # ]
94 |
95 | [project.scripts]
96 | # Command line interface entrypoint scripts
97 | # selector = "selector.__main__:main"
98 |
99 | [project.urls]
100 | # Use PyPI-standard names here
101 | # Homepage
102 | # Documentation
103 | # Changelog
104 | # Issue Tracker
105 | # Source
106 | # Discord server
107 | homepage = "https://github.com/theochem/Selector"
108 | documentation = "https://selector.qcdevs.org/"
109 | repository = "https://github.com/theochem/Selector"
110 |
111 | # Development dependencies
112 | # pip install -e .[lint,test,exe]
113 | # pip install -e .[dev]
114 |
115 | # optional dependencies can be declared either statically here or via dynamic metadata below,
116 | # but not both; declaring both leads to errors
117 | # [project.optional-dependencies]
118 | # lint = [
119 | # # ruff linter checks for issues and potential bugs
120 | # "ruff",
121 | # # checks for unused code
122 | # # "vulture",
123 | # # # required for codespell to parse pyproject.toml
124 | # # "tomli",
125 | # # # validation of pyproject.toml
126 | # # "validate-pyproject[all]",
127 | # # automatic sorting of imports
128 | # "isort",
129 | # # # automatic code formatting to follow a consistent style
130 | # # "black",
131 | # ]
132 |
133 | # test = [
134 | # # Handles most of the testing work, including execution
135 | # # Docs: https://docs.pytest.org/en/stable/contents.html
136 | # "pytest>=7.4.0",
137 | # # property-based testing, used together with pytest
138 | # "hypothesis",
139 | # # Coverage measures how much of the code the test suite actually runs
140 | # # Generates coverage reports from test suite runs
141 | # "pytest-cov>=3.0.0",
142 | # "tomli",
143 | # "scikit-learn>=1.0.1",
144 | # # Better parsing of doctests
145 | # "xdoctest",
146 | # # Colors for doctest output
147 | # "Pygments",
148 | # ]
149 |
150 | # exe = [
151 | # "setuptools",
152 | # "wheel",
153 | # "build",
154 | # "tomli",
155 | # "pyinstaller",
156 | # "staticx;platform_system=='Linux'",
157 | # ]
158 |
159 | # dev = [
160 | # # https://hynek.me/articles/python-recursive-optional-dependencies/
161 | # "selector[lint,test,exe]",
162 |
163 | # # # Code quality tools
164 | # # "mypy",
165 |
166 | # # # Improved exception traceback output
167 | # # # https://github.com/qix-/better-exceptions
168 | # # "better_exceptions",
169 |
170 | # # # Analyzing dependencies
171 | # # # install graphviz to generate graphs
172 | # # "graphviz",
173 | # # "pipdeptree",
174 | # ]
175 |
176 | [build-system]
177 | # Minimum requirements for the build system to execute.
178 | requires = ["setuptools>=64", "setuptools-scm>=8", "wheel"]
179 | build-backend = "setuptools.build_meta"
180 |
181 | [tool.setuptools.dynamic]
182 | dependencies = {file = ["requirements.txt"]}
183 | optional-dependencies = {dev = { file = ["requirements_dev.txt"] }}
184 |
185 | [tool.setuptools_scm]
186 | # can be empty if no extra settings are needed, presence enables setuptools-scm
187 |
188 | [tool.setuptools]
189 | # https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html
190 | platforms = ["Linux", "Windows", "MacOS"]
191 | include-package-data = true
192 | # This just means it's safe to zip up the bdist
193 | zip-safe = true
194 |
195 | # Non-code data that should be included in the package source code
196 | # https://setuptools.pypa.io/en/latest/userguide/datafiles.html
197 | [tool.setuptools.package-data]
198 | selector = ["*.xml"]
199 |
200 | # Python modules and packages that are included in the
201 | # distribution package (and therefore become importable)
202 | [tool.setuptools.packages.find]
203 | exclude = ["*/*/tests", "tests_*", "examples", "notebooks"]
204 |
205 |
206 | # PDM example
207 | #[tool.pdm.scripts]
208 | #isort = "isort selector"
209 | #black = "black selector"
210 | #format = {composite = ["isort", "black"]}
211 | #check_isort = "isort --check selector tests"
212 | #check_black = "black --check selector tests"
213 | #vulture = "vulture --min-confidence 100 selector tests"
214 | #ruff = "ruff check selector tests"
215 | #fix = "ruff check --fix selector tests"
216 | #codespell = "codespell --toml ./pyproject.toml"
217 | #lint = {composite = ["vulture", "codespell", "ruff", "check_isort", "check_black"]}
218 |
219 |
220 | #[tool.codespell]
221 | ## codespell supports pyproject.toml since version 2.2.2
222 | ## NOTE: the "tomli" package must be installed for this to work
223 | ## https://github.com/codespell-project/codespell#using-a-config-file
224 | ## NOTE: ignore words for codespell must be lowercase
225 | #check-filenames = ""
226 | #ignore-words-list = "word,another,something"
227 | #skip = "htmlcov,.doctrees,*.pyc,*.class,*.ico,*.out,*.PNG,*.inv,*.png,*.jpg,*.dot"
228 |
229 |
230 | [tool.black]
231 | line-length = 100
232 | # If you need to exclude directories from being reformatted by black
233 | # force-exclude = '''
234 | # (
235 | # somedirname
236 | # | dirname
237 | # | filename\.py
238 | # )
239 | # '''
240 |
241 |
242 | [tool.isort]
243 | profile = "black"
244 | known_first_party = ["selector"]
245 | # If you need to exclude files from having their imports sorted
246 | #extend_skip_glob = [
247 | # "selector/somefile.py",
248 | # "selector/somedir/*",
249 | #]
250 |
251 |
252 | # https://beta.ruff.rs/docs
253 | [tool.ruff]
254 | line-length = 100
255 | show-source = true
256 |
257 | # Rules: https://beta.ruff.rs/docs/rules
258 | # If you violate a rule, look up the rule on the Rules page in the ruff docs.
259 | # Many rules have links you can click with an explanation of the rule and how to fix it.
260 | # If there isn't a link, go to the project the rule was sourced from (e.g. flake8-bugbear)
261 | # and review its docs for the corresponding rule.
262 | # If you're still confused, ask a fellow developer for assistance.
263 | # You can also run "ruff rule <code>" to explain a rule on the command line, without a browser or internet access.
264 | select = [
265 | "E", # pycodestyle
266 | "F", # Pyflakes
267 | "W", # Warning
268 | "B", # flake8-bugbear
269 | "A", # flake8-builtins
270 | "C4", # flake8-comprehensions
271 | "T10", # flake8-debugger
272 | "EXE", # flake8-executable,
273 | "ISC", # flake8-implicit-str-concat
274 | "G", # flake8-logging-format
275 | "PIE", # flake8-pie
276 | "T20", # flake8-print
277 | "PT", # flake8-pytest-style
278 | "RSE", # flake8-raise
279 | "RET", # flake8-return
280 | "TID", # flake8-tidy-imports
281 | "ARG", # flake8-unused-arguments
282 | "PGH", # pygrep-hooks
283 | "PLC", # Pylint Convention
284 | "PLE", # Pylint Errors
285 | "PLW", # Pylint Warnings
286 | "RUF", # Ruff-specific rules
287 |
288 | # ** Things to potentially enable in the future **
289 | # DTZ requires all usage of datetime module to have timezone-aware
290 | # objects (so have a tz argument or be explicitly UTC).
291 | # "DTZ", # flake8-datetimez
292 | # "PTH", # flake8-use-pathlib
293 | # "SIM", # flake8-simplify
294 | ]
295 |
296 | # Files to exclude from linting
297 | extend-exclude = [
298 | "*.pyc",
299 | "__pycache__",
300 | "*.egg-info",
301 | ".eggs",
302 | # check point files of jupyter notebooks
303 | "*.ipynb_checkpoints",
304 | ".tox",
305 | ".git",
306 | "build",
307 | "dist",
308 | "docs",
309 | "examples",
310 | "htmlcov",
311 | "notebooks",
312 | ".cache",
313 | "_version.py",
314 | ]
315 |
316 | # Linting error codes to ignore
317 | ignore = [
318 | "F403", # unable to detect undefined names from star imports
319 | "F405", # undefined locals from star imports
320 | "W605", # invalid escape sequence
321 | "A003", # shadowing python builtins
322 | "RET505", # unnecessary 'else' after 'return' statement
323 | "RET504", # Unnecessary variable assignment before return statement
324 | "RET507", # Unnecessary {branch} after continue statement
325 | "PT011", # pytest-raises-too-broad
326 | "PT012", # pytest.raises() block should contain a single simple statement
327 | "PLW0603", # Using the global statement to update is discouraged
328 | "PLW2901", # for loop variable overwritten by assignment target
329 | "G004", # Logging statement uses f-string
330 | "PIE790", # no-unnecessary-pass
331 | "PIE810", # multiple-starts-ends-with
332 | "PGH003", # Use specific rule codes when ignoring type issues
333 | "PLC1901", # compare-to-empty-string
334 | ]
335 |
336 | # Linting error codes to ignore on a per-file basis
337 | [tool.ruff.per-file-ignores]
338 | "__init__.py" = ["F401", "E501"]
339 | "selector/somefile.py" = ["E402", "E501"]
340 | "selector/somedir/*" = ["E501"]
341 |
342 |
343 | # Configuration for mypy
344 | # https://mypy.readthedocs.io/en/stable/config_file.html#using-a-pyproject-toml-file
345 | [tool.mypy]
346 | python_version = "3.9"
347 | follow_imports = "skip"
348 | ignore_missing_imports = true
349 | files = "selector" # directory mypy should analyze
350 | # Directories to exclude from mypy's analysis
351 | exclude = [
352 | "book",
353 | ]
354 |
355 |
356 | # Configuration for pytest
357 | # https://docs.pytest.org/en/latest/reference/customize.html#pyproject-toml
358 | [tool.pytest.ini_options]
359 | addopts = [
360 | # Allow test files to have the same name in different directories.
361 | "--import-mode=importlib",
362 | "--cache-clear",
363 | "--showlocals",
364 | "-v",
365 | "-r a",
366 | "--cov-report=term-missing",
367 | "--cov=selector",
368 | ]
369 | # directory containing the tests
370 | testpaths = [
371 | "selector/measures/tests",
372 | "selector/methods/tests",
373 | ]
374 | norecursedirs = [
375 | ".vscode",
376 | "__pycache__",
377 | "build",
378 | ]
379 | # Warnings that should be ignored
380 | filterwarnings = [
381 | "ignore::DeprecationWarning"
382 | ]
383 | # custom markers that can be used using pytest.mark
384 | markers = [
385 | "slow: lower-importance tests that take an excessive amount of time",
386 | ]
387 |
388 |
389 | # Configuration for coverage.py
390 | [tool.coverage.run]
391 | # files or directories to exclude from coverage calculations
392 | omit = [
393 | 'selector/measures/tests/*',
394 | 'selector/methods/tests/*',
395 | ]
396 |
397 |
398 | # Configuration for vulture
399 | [tool.vulture]
400 | # Files or directories to exclude from vulture
401 | # The syntax is a little funky
402 | exclude = [
403 | "somedir",
404 | "*somefile.py",
405 | ]
406 |
407 | # configuration for bandit
408 | [tool.bandit]
409 | exclude_dirs = [
410 | "selector/measures/tests",
411 | "selector/methods/tests",
412 | ]
413 | skips = [
414 | "B101", # Ignore assert statements
415 | "B311", # Ignore pseudo-random generators
416 | "B404", # Ignore subprocess import
417 | "B603", # Ignore subprocess call
418 | "B607", # Ignore subprocess call
419 | ]
420 |
--------------------------------------------------------------------------------
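
Because `version` and the dependency lists in the pyproject.toml above are declared `dynamic`, they are resolved at build time: setuptools-scm derives the version from git tags, and the dependencies are read from requirements.txt / requirements_dev.txt. A minimal sketch of inspecting that metadata at runtime, assuming the package has been installed under its PyPI name `qc-selector`:

    from importlib.metadata import requires, version

    print(version("qc-selector"))   # version derived from git tags by setuptools-scm
    print(requires("qc-selector"))  # runtime dependencies, read from requirements.txt at build time
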
/requirements.txt:
--------------------------------------------------------------------------------
1 | bitarray>=2.5.1
2 | numpy>=1.21.2
3 | scipy>=1.11.1
4 |
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | bitarray>=2.5.1
2 | coverage>=6.3.2
3 | hypothesis
4 | numpy>=1.21.2
5 | pre-commit
6 | pytest>=7.4.0
7 | pytest-cov>=3.0.0
8 | scikit-learn>=1.0.1
9 | scipy>=1.11.1
10 | setuptools>=64.0.0
11 | tomli
12 | xdoctest
13 | setuptools-scm>=8.0.0
14 |
--------------------------------------------------------------------------------
/selector/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <https://www.gnu.org/licenses/>
22 | #
23 | # --
24 |
25 | """Selector Package."""
26 |
27 | from selector.methods import *
28 |
--------------------------------------------------------------------------------
/selector/measures/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <https://www.gnu.org/licenses/>
22 | #
23 | # --
24 |
--------------------------------------------------------------------------------
/selector/measures/converter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <https://www.gnu.org/licenses/>
22 | #
23 | # --
24 | """Module for converting similarity measures to distance/dissimilarity measures."""
25 |
26 | from typing import Union
27 |
28 | import numpy as np
29 |
30 | __all__ = [
31 | "sim_to_dist",
32 | "reverse",
33 | "reciprocal",
34 | "exponential",
35 | "gaussian",
36 | "correlation",
37 | "transition",
38 | "co_occurrence",
39 | "gravity",
40 | "probability",
41 | "covariance",
42 | ]
43 |
44 |
45 | def sim_to_dist(
46 | x: Union[int, float, np.ndarray], metric: str, scaling_factor: float = 1.0
47 | ) -> Union[float, np.ndarray]:
48 | """Convert similarity coefficients to distance array.
49 |
50 | Parameters
51 | ----------
52 | x : float or ndarray
53 | A similarity value as float, or a 1D or 2D array of similarity values.
54 | If 2D, the array is assumed to be symmetric.
55 | metric : str
56 |         String specifying which conversion metric to use.
57 | Supported metrics are "reverse", "reciprocal", "exponential",
58 | "gaussian", "membership", "correlation", "transition", "co-occurrence",
59 | "gravity", "confusion", "probability", and "covariance".
60 | scaling_factor : float, optional
61 | Scaling factor for the distance array. Default is 1.0.
62 |
63 | Returns
64 | -------
65 | dist : float or ndarray
66 | Distance value or array.
67 | """
68 |     # scale the similarity values by the scaling factor
69 | x = x * scaling_factor
70 |
71 | frequency = {
72 | "transition": transition,
73 | "co-occurrence": co_occurrence,
74 | "gravity": gravity,
75 | }
76 | method_dict = {
77 | "reverse": reverse,
78 | "reciprocal": reciprocal,
79 | "exponential": exponential,
80 | "gaussian": gaussian,
81 | "correlation": correlation,
82 | "probability": probability,
83 | "covariance": covariance,
84 | }
85 |
86 | # check if x is a single value
87 | single_value = False
88 | if isinstance(x, (float, int)):
89 | x = np.array([[x]])
90 | single_value = True
91 |
92 | # check x
93 | if not isinstance(x, np.ndarray):
94 | raise ValueError(f"Argument x should be a numpy array instead of {type(x)}")
95 | # check that x is a valid array
96 | if x.ndim != 1 and x.ndim != 2:
97 | raise ValueError(f"Argument x should either have 1 or 2 dimensions, got {x.ndim}.")
98 | if x.ndim == 1 and metric in ["co-occurrence", "gravity"]:
99 | raise ValueError(f"Argument x should be a 2D array when using the {metric} metric.")
100 | # check if x is symmetric
101 | if x.ndim == 2 and not np.allclose(x, x.T):
102 | raise ValueError("Argument x should be a symmetric array.")
103 |
104 | # call correct metric function
105 | if metric in frequency:
106 | if np.any(x <= 0):
107 | raise ValueError(
108 | "There is a negative or zero value in the input. Please "
109 | "make sure all frequency values are positive."
110 | )
111 | dist = frequency[metric](x)
112 | elif metric in method_dict:
113 | dist = method_dict[metric](x)
114 | elif metric == "membership" or metric == "confusion":
115 | if np.any(x < 0) or np.any(x > 1):
116 | raise ValueError(
117 | "There is an out of bounds value. Please make "
118 | "sure all input values are between [0, 1]."
119 | )
120 | dist = 1 - x
121 | # unsupported metric
122 | else:
123 | raise ValueError(f"{metric} is an unsupported metric.")
124 |
125 | # convert back to float if input was single value
126 | if single_value:
127 | dist = dist.item((0, 0))
128 |
129 | return dist
130 |
131 |
132 | def reverse(x: np.ndarray) -> np.ndarray:
133 | r"""Calculate distance array from similarity using the reverse method.
134 |
135 | .. math::
136 | \delta_{ij} = min(s_{ij}) + max(s_{ij}) - s_{ij}
137 |
138 | where :math:`\delta_{ij}` is the distance between points :math:`i`
139 | and :math:`j`, :math:`s_{ij}` is their similarity coefficient,
140 | and :math:`max` and :math:`min` are the maximum and minimum
141 | values across the entire similarity array.
142 |
143 | Parameters
144 | -----------
145 | x : ndarray
146 | Similarity array.
147 |
148 | Returns
149 | -------
150 | dist : ndarray
151 | Distance array.
152 |
153 | """
154 | dist = np.max(x) + np.min(x) - x
155 | return dist
156 |
157 |
158 | def reciprocal(x: np.ndarray) -> np.ndarray:
159 | r"""Calculate distance array from similarity using the reciprocal method.
160 |
161 | .. math::
162 | \delta_{ij} = \frac{1}{s_{ij}}
163 |
164 | where :math:`\delta_{ij}` is the distance between points :math:`i`
165 | and :math:`j`, and :math:`s_{ij}` is their similarity coefficient.
166 |
167 | Parameters
168 | -----------
169 | x : ndarray
170 | Similarity array.
171 |
172 | Returns
173 | -------
174 | dist : ndarray
175 | Distance array.
176 | """
177 |
178 | if np.any(x <= 0):
179 | raise ValueError(
180 | "There is an out of bounds value. Please make " "sure all similarities are positive."
181 | )
182 | return 1 / x
183 |
184 |
185 | def exponential(x: np.ndarray) -> np.ndarray:
186 | r"""Calculate distance matrix from similarity using the exponential method.
187 |
188 | .. math::
189 | \delta_{ij} = -\ln{\frac{s_{ij}}{max(s_{ij})}}
190 |
191 | where :math:`\delta_{ij}` is the distance between points :math:`i`
192 | and :math:`j`, and :math:`s_{ij}` is their similarity coefficient.
193 |
194 | Parameters
195 | -----------
196 | x : ndarray
197 | Similarity array.
198 |
199 | Returns
200 | -------
201 | dist : ndarray
202 | Distance array.
203 |
204 | """
205 | max_sim = np.max(x)
206 | if max_sim == 0:
207 | raise ValueError("Maximum similarity in `x` is 0. Distance cannot be computed.")
208 | dist = -np.log(x / max_sim)
209 | return dist
210 |
211 |
212 | def gaussian(x: np.ndarray) -> np.ndarray:
213 | r"""Calculate distance matrix from similarity using the Gaussian method.
214 |
215 | .. math::
216 | \delta_{ij} = \sqrt{-\ln{\frac{s_{ij}}{max(s_{ij})}}}
217 |
218 | where :math:`\delta_{ij}` is the distance between points :math:`i`
219 | and :math:`j`, and :math:`s_{ij}` is their similarity coefficient.
220 |
221 | Parameters
222 | -----------
223 | x : ndarray
224 | Similarity array.
225 |
226 | Returns
227 | -------
228 | dist : ndarray
229 | Distance array.
230 |
231 | """
232 | max_sim = np.max(x)
233 | if max_sim == 0:
234 | raise ValueError("Maximum similarity in `x` is 0. Distance cannot be computed.")
235 | y = x / max_sim
236 | dist = np.sqrt(-np.log(y))
237 | return dist
238 |
239 |
240 | def correlation(x: np.ndarray) -> np.ndarray:
241 | r"""Calculate distance array from correlation array.
242 |
243 | .. math::
244 | \delta_{ij} = \sqrt{1 - r_{ij}}
245 |
246 | where :math:`\delta_{ij}` is the distance between points :math:`i`
247 | and :math:`j`, and :math:`r_{ij}` is their correlation.
248 |
249 | Parameters
250 | -----------
251 | x : ndarray
252 | Correlation array.
253 |
254 | Returns
255 | -------
256 | dist : ndarray
257 | Distance array.
258 |
259 | """
260 | if np.any(x < -1) or np.any(x > 1):
261 | raise ValueError(
262 | "There is an out of bounds value. Please make "
263 | "sure all correlations are between [-1, 1]."
264 | )
265 | dist = np.sqrt(1 - x)
266 | return dist
267 |
268 |
269 | def transition(x: np.ndarray) -> np.ndarray:
270 | r"""Calculate distance array from frequency using the transition method.
271 |
272 | .. math::
273 | \delta_{ij} = \frac{1}{\sqrt{f_{ij}}}
274 |
275 | where :math:`\delta_{ij}` is the distance between points :math:`i`
276 | and :math:`j`, and :math:`f_{ij}` is their frequency.
277 |
278 | Parameters
279 | -----------
280 | x : ndarray
281 | Symmetric frequency array.
282 |
283 | Returns
284 | -------
285 | dist : ndarray
286 | Distance array.
287 |
288 | """
289 | dist = 1 / np.sqrt(x)
290 | return dist
291 |
292 |
293 | def co_occurrence(x: np.ndarray) -> np.ndarray:
294 | r"""Calculate distance array from frequency using the co-occurrence method.
295 |
296 | .. math::
297 | \delta_{ij} = \left(1 + \frac{f_{ij}\sum_{i,j}{f_{ij}}}{\sum_{i}{f_{ij}}\sum_{j}{f_{ij}}} \right)^{-1}
298 |
299 | where :math:`\delta_{ij}` is the distance between points :math:`i`
300 | and :math:`j`, and :math:`f_{ij}` is their frequency.
301 |
302 | Parameters
303 | -----------
304 | x : ndarray
305 | Frequency array.
306 |
307 | Returns
308 | -------
309 | dist : ndarray
310 | Co-occurrence array.
311 |
312 | """
313 | # compute sums along each axis
314 | i = np.sum(x, axis=0, keepdims=True)
315 | j = np.sum(x, axis=1, keepdims=True)
316 | # multiply sums to scalar value
317 | bottom = np.dot(i, j)
318 | # multiply each element by the sum of entire array
319 | top = x * np.sum(x)
320 | # evaluate function as a whole
321 | dist = (1 + (top / bottom)) ** -1
322 | return dist
323 |
324 |
325 | def gravity(x: np.ndarray) -> np.ndarray:
326 | r"""Calculate distance array from frequency using the gravity method.
327 |
328 | .. math::
329 | \delta_{ij} = \sqrt{\frac{\sum_{i}{f_{ij}}\sum_{j}{f_{ij}}}
330 | {f_{ij}\sum_{i,j}{f_{ij}}}}
331 |
332 | where :math:`\delta_{ij}` is the distance between points :math:`i`
333 | and :math:`j`, and :math:`f_{ij}` is their frequency.
334 |
335 | Parameters
336 | -----------
337 | x : ndarray
338 | Symmetric frequency array.
339 |
340 | Returns
341 | -------
342 | dist : ndarray
343 |         Symmetric distance array.
344 |
345 | """
346 | # compute sums along each axis
347 | i = np.sum(x, axis=0, keepdims=True)
348 | j = np.sum(x, axis=1, keepdims=True)
349 | # multiply sums to scalar value
350 | top = np.dot(i, j)
351 | # multiply each element by the sum of entire array
352 | bottom = x * np.sum(x)
353 | # take square root of the fraction
354 | dist = np.sqrt(top / bottom)
355 | return dist
356 |
357 |
358 | def probability(x: np.ndarray) -> np.ndarray:
359 | r"""Calculate distance array from probability array.
360 |
361 | .. math::
362 |         \delta_{ij} = \frac{1}{\sqrt{\arcsin{\left(p_{ij}\right)}}}
363 |
364 |     where :math:`\delta_{ij}` is the distance between points :math:`i`
365 |     and :math:`j`, and :math:`p_{ij}` is their probability.
366 |
367 | Parameters
368 | -----------
369 | x : ndarray
370 | Symmetric probability array.
371 |
372 | Returns
373 | -------
374 | dist : ndarray
375 | Distance array.
376 |
377 | """
378 | if np.any(x <= 0) or np.any(x > 1):
379 | raise ValueError(
380 | "There is an out of bounds value. Please make "
381 | "sure all probabilities are between (0, 1]."
382 | )
383 | y = np.arcsin(x)
384 | dist = 1 / np.sqrt(y)
385 | return dist
386 |
387 |
388 | def covariance(x: np.ndarray) -> np.ndarray:
389 | r"""Calculate distance array from similarity using the covariance method.
390 |
391 | .. math::
392 | \delta_{ij} = \sqrt{s_{ii}+s_{jj}-2s_{ij}}
393 |
394 | where :math:`\delta_{ij}` is the distance between points :math:`i`
395 | and :math:`j`, :math:`s_{ii}` and :math:`s_{jj}` are the variances
396 | of feature :math:`i` and feature :math:`j`, and :math:`s_{ij}`
397 | is the covariance between the two features.
398 |
399 | Parameters
400 | -----------
401 | x : ndarray
402 | Covariance array.
403 |
404 | Returns
405 | -------
406 | dist : ndarray
407 | Distance array.
408 |
409 | """
410 | variance = np.diag(x).reshape([x.shape[0], 1]) * np.ones([1, x.shape[0]])
411 | if np.any(variance < 0):
412 | raise ValueError("Variance of a single variable cannot be negative.")
413 |
414 | dist = variance + variance.T - 2 * x
415 | dist = np.sqrt(dist)
416 | return dist
417 |
--------------------------------------------------------------------------------
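
A minimal usage sketch for the converter module above; the similarity values are illustrative, and any symmetric array (or a single float) works:

    import numpy as np
    from selector.measures.converter import sim_to_dist

    # a small symmetric similarity matrix (illustrative values)
    sim = np.array([[1.0, 0.8, 0.2],
                    [0.8, 1.0, 0.5],
                    [0.2, 0.5, 1.0]])

    # "reverse" maps high similarity to low distance: min(s) + max(s) - s
    dist = sim_to_dist(sim, metric="reverse")

    # scalar inputs are accepted and returned as plain floats
    d = sim_to_dist(0.25, metric="reciprocal")  # 1 / 0.25 == 4.0
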
/selector/measures/diversity.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <https://www.gnu.org/licenses/>
22 | #
23 | # --
24 |
25 | """Molecule dataset diversity calculation module."""
26 |
27 | import warnings
28 |
29 | import numpy as np
30 |
31 | from selector.measures.similarity import tanimoto
32 |
33 | __all__ = [
34 | "compute_diversity",
35 | "logdet",
36 | "shannon_entropy",
37 | "explicit_diversity_index",
38 | "wdud",
39 | "hypersphere_overlap_of_subset",
40 | "gini_coefficient",
41 | "nearest_average_tanimoto",
42 | ]
43 |
44 |
45 | def compute_diversity(
46 | feature_subset: np.array,
47 | div_type: str = "shannon_entropy",
48 | normalize: bool = False,
49 | truncation: bool = False,
50 | features: np.array = None,
51 | cs: int = None,
52 | ) -> float:
53 | """Compute diversity metrics.
54 |
55 | Parameters
56 | ----------
57 | feature_subset : np.ndarray
58 | Feature matrix.
59 |     div_type : str, optional
60 |         Method of calculating diversity for a given molecule set, which
61 |         includes "logdet", "shannon_entropy", "wdud",
62 |         "gini_coefficient", "hypersphere_overlap", and
63 |         "explicit_diversity_index".
64 |         The default is "shannon_entropy".
65 |     normalize : bool, optional
66 |         Normalize the entropy to [0, 1]. Default is False.
67 |     truncation : bool, optional
68 |         Use the truncated Shannon entropy. Default is False.
69 |     features : np.ndarray, optional
70 |         Feature matrix of the entire molecule library, used only when
71 |         calculating "hypersphere_overlap". Default is None.
72 |     cs : int, optional
73 |         Number of common substructures in the molecular compound dataset.
74 |         Used only when calculating "explicit_diversity_index". Default is None.
75 |
76 |
77 | Returns
78 | -------
79 | float, computed diversity.
80 |
81 | """
82 | func_dict = {
83 | "logdet": logdet,
84 | "wdud": wdud,
85 | "gini_coefficient": gini_coefficient,
86 | }
87 |
88 | if div_type in func_dict:
89 | return func_dict[div_type](feature_subset)
90 |
91 | # hypersphere overlap of subset
92 | elif div_type == "hypersphere_overlap":
93 | if features is None:
94 | raise ValueError(
95 | "Please input a feature matrix of the entire "
96 | "dataset when calculating hypersphere overlap."
97 | )
98 | return hypersphere_overlap_of_subset(features, feature_subset)
99 |
100 | elif div_type == "shannon_entropy":
101 | return shannon_entropy(feature_subset, normalize=normalize, truncation=truncation)
102 |
103 | elif div_type == "explicit_diversity_index":
104 | if cs is None:
105 | raise ValueError(
106 | "Attribute `cs` is missing. "
107 | "Please input `cs` value to use explicit_diversity_index."
108 | )
109 | elif cs == 0:
110 | raise ValueError("Divide by zero error: Attribute `cs` cannot be 0.")
111 | return explicit_diversity_index(feature_subset, cs)
112 | else:
113 | raise ValueError(f"Diversity type {div_type} not supported.")
114 |
115 |
116 | def logdet(x: np.ndarray) -> float:
117 | r"""Compute the log determinant function.
118 |
119 | Given a :math:`n_s \times n_f` feature matrix :math:`x`, where :math:`n_s` is the number of
120 | samples and :math:`n_f` is the number of features, the log determinant function is defined as:
121 |
122 | .. math:
123 | F_\text{logdet} = \log{\left(\det{\left(\mathbf{x}\mathbf{x}^T + \mathbf{I}\right)}\right)}
124 |
125 | where the :math:`I` is the :math:`n_s \times n_s` identity matrix.
126 | Higher values of :math:`F_\text{logdet}` mean more diversity.
127 |
128 | Parameters
129 | ----------
130 | x: ndarray of shape (n_samples, n_features)
131 | Feature matrix of `n_samples` samples in `n_features` dimensional feature space,
132 |
133 | Returns
134 | -------
135 | f_logdet: float
136 | The volume of parallelotope spanned by the matrix.
137 |
138 | Notes
139 | -----
140 | The log-determinant function is based on the formula in Nakamura, T., Sci Rep 2022.
141 |     Please note that we used the natural logarithm to avoid
142 |     numerical stability issues, see
143 | https://github.com/theochem/Selector/issues/229.
144 |
145 | References
146 | ----------
147 |     Nakamura, T., Sakaue, S., Fujii, K., Harabuchi, Y., Maeda, S., and Iwata, S., Selecting
148 | molecules with diverse structures and properties by maximizing submodular functions of
149 | descriptors learned with graph neural networks. Scientific Reports 12, 2022.
150 |
151 | """
152 | mid = np.dot(x, np.transpose(x)) + np.identity(x.shape[0])
153 | logdet_mid = np.linalg.slogdet(mid)
154 | f_logdet = logdet_mid.sign * logdet_mid.logabsdet
155 | return f_logdet
156 |
157 |
158 | def shannon_entropy(x: np.ndarray, normalize=True, truncation=False) -> float:
159 | r"""Compute the shannon entropy of a binary matrix.
160 |
161 | Higher values mean more diversity.
162 |
163 | Parameters
164 | ----------
165 | x : ndarray
166 | Bit-string matrix.
167 | normalize : bool, optional
168 | Normalize the entropy to [0, 1]. Default=True.
169 | truncation : bool, optional
170 | Use the truncated Shannon entropy by only counting the contributions of one-bits.
171 | Default=False.
172 |
173 | Returns
174 | -------
175 | float :
175 |         The Shannon entropy of the matrix.
177 |
178 | Notes
179 | -----
180 |     Suppose we have :math:`m` compounds and each compound has an :math:`n`-bit binary fingerprint.
181 |     The binary matrix (feature matrix) is :math:`\mathbf{x} \in m \times n`, where each
182 |     row is a compound's :math:`n`-bit fingerprint and each column is a bit position.
183 |     The equation for Shannon entropy is given by [1]_ and [3]_,
184 |
185 |     .. math::
186 |         H = \sum_i^n \left[ - p_i \log_2{p_i } - (1 - p_i)\log_2(1 - p_i) \right]
187 |
188 |     where :math:`p_i` represents the relative frequency of `1` bits at the fingerprint position
189 |     :math:`i`. When :math:`p_i = 0` or :math:`p_i = 1`, the contribution of position :math:`i` is zero.
190 |     When `truncation` is True, the entropy is calculated as in [2]_ instead,
191 |
192 |     .. math::
193 |         H = \sum_i^n \left[ - p_i \log_2{p_i } \right]
194 |
195 |     When `normalize` is True, the entropy is normalized by a scaling factor so that the entropy is in the range of
196 |     [0, 1], [2]_
197 |
198 |     .. math::
199 |         H = \frac{ \sum_i^n \left[ - p_i \log_2{p_i } - (1 - p_i)\log_2(1 - p_i) \right]}
200 |         {n \log_2{2} / 2}
201 |
202 |     Please note that the normalized truncated entropy (both `truncation` and `normalize` True) has
203 |     not been reported in the literature; it is a simple rescaling to be used at the user's own risk.
204 |
205 | References
206 | ----------
207 | .. [1] Wang, Y., Geppert, H., & Bajorath, J. (2009). Shannon entropy-based fingerprint similarity
208 | search strategy. Journal of Chemical Information and Modeling, 49(7), 1687-1691.
209 | .. [2] Leguy, J., Glavatskikh, M., Cauchy, T., & Da Mota, B. (2021). Scalable estimator of the
210 | diversity for de novo molecular generation resulting in a more robust QM dataset (OD9) and a
211 | more efficient molecular optimization. Journal of Cheminformatics, 13(1), 1-17.
212 | .. [3] Weidlich, I. E., & Filippov, I. V. (2016). Using the Gini coefficient to measure the
213 | chemical diversity of small molecule libraries. Journal of Computational Chemistry, 37(22), 2091-2097.
214 |
215 | """
216 | # check if matrix is binary
217 | if np.count_nonzero((x != 0) & (x != 1)) != 0:
218 | raise ValueError("Attribute `x` should have binary values.")
219 |
220 | p_i_arr = np.sum(x, axis=0) / x.shape[0]
221 | h_x = 0
222 |
223 | for p_i in p_i_arr:
224 | if p_i == 0 or p_i == 1:
225 |             # positions with p_i equal to 0 or 1 contribute zero entropy
226 | se_i = 0
227 | else:
228 | if truncation:
229 | # from https://jcheminf.biomedcentral.com/articles/10.1186/s13321-021-00554-8
230 | se_i = -p_i * np.log2(p_i)
231 | else:
232 | # from https://pubs.acs.org/doi/10.1021/ci900159f
233 | se_i = -p_i * np.log2(p_i) - (1 - p_i) * np.log2(1 - p_i)
234 |
235 | h_x += se_i
236 |
237 | if normalize:
238 | if truncation:
239 | warnings.warn(
240 | "Computing the normalized Shannon entropy only counting the on-bits has not been reported in "
241 | "literature. The user can use it at their own risk."
242 | )
243 |
244 | h_x /= x.shape[1] * np.log2(2) / 2
245 |
246 | return h_x
247 |
248 |
249 | # todo: add tests for edi
250 | def explicit_diversity_index(
251 | x: np.ndarray,
252 | cs: int,
253 | ) -> float:
254 | """Compute the explicit diversity index.
255 |
256 | Parameters
257 | ----------
258 | x: ndarray of shape (n_samples, n_features)
259 | Feature matrix of `n_samples` samples in `n_features` dimensional feature space.
260 | cs : int
261 | Number of common substructures in the compound set.
262 |
263 | Returns
264 | -------
265 | edi_scaled : float
266 | Explicit diversity index.
267 |
268 | Notes
269 | -----
270 | This method hasn't been tested.
271 |
272 | This method is used only for datasets of molecular compounds.
273 |
274 | Papp, Á., Gulyás-Forró, A., Gulyás, Z., Dormán, G., Ürge, L.,
275 |     and Darvas, F. (2006) Explicit Diversity Index (EDI):
276 | A Novel Measure for Assessing the Diversity of Compound Databases.
277 | Journal of Chemical Information and Modeling 46, 1898-1904.
278 | """
279 | nc = len(x)
280 | sdi = (1 - nearest_average_tanimoto(x)) / (0.8047 - (0.065 * (np.log(nc))))
281 | cr = -1 * np.log10(nc / (cs**2))
282 |     edi = (sdi + cr) * 0.7071067811865476  # multiply by 1 / sqrt(2)
283 | edi_scaled = ((np.tanh(edi / 3) + 1) / 2) * 100
284 | return edi_scaled
285 |
286 |
287 | def wdud(x: np.ndarray) -> float:
288 | r"""Compute the Wasserstein Distance to Uniform Distribution(WDUD).
289 |
290 | The equation for the Wasserstein Distance for a single feature to uniform distribution is
291 |
292 | .. math::
293 | WDUD(x) = \int_{0}^{1} |U(x) - V(x)|dx
294 |
295 | where the feature is normalized to [0, 1], :math:`U(x)=x` is the cumulative distribution
296 | of the uniform distribution on [0, 1], and :math:`V(x) = \sum_{y <= x}1 / N` is the discrete
297 |     distribution of the values of the feature in :math:`x`, where :math:`y` ranges over the sorted
298 |     values of that feature. The integral is calculated piecewise between :math:`y_i` and :math:`y_{i+1}` using the trapezoidal method.
299 |
300 | Lower values of the WDUD mean more diversity because the features of the selected set are
301 | more evenly distributed over the range of feature values.
302 |
303 | Parameters
304 | ----------
305 | x: ndarray of shape (n_samples, n_features)
306 | Feature matrix of `n_samples` samples in `n_features` dimensional feature space.
307 |
308 | Returns
309 | -------
310 | float :
311 | The mean of the WDUD of each feature over all molecules.
312 |
313 | Notes
314 | -----
315 |     Nakamura, T., Sakaue, S., Fujii, K., Harabuchi, Y., Maeda, S., and Iwata, S. (2022)
316 | Selecting molecules with diverse structures and properties by maximizing
317 | submodular functions of descriptors learned with graph neural networks.
318 | Scientific Reports 12.
319 |
320 | """
321 | if x.ndim != 2:
322 | raise ValueError(f"The number of dimensions {x.ndim} should be two.")
323 |
324 | # find the range of each feature
325 | col_diff = np.ptp(x, axis=0)
326 | # Normalization of each feature to [0, 1]
327 | if np.any(np.abs(col_diff) < 1e-30):
328 | # warning if some feature columns are constant
329 |         warnings.warn(
330 |             "Some of the features are constant, which would cause the normalization to fail. "
331 |             "Removing them."
332 |         )
333 |         if np.all(col_diff < 1.0e-30):
334 |             raise ValueError(
335 |                 "All the features are constant, so the WDUD cannot be calculated."
336 |             )
337 | else:
338 | # remove the constant feature columns
339 | mask = np.ptp(x, axis=0) > 1e-30
340 | x = x[:, mask]
341 | x_norm = (x - np.min(x, axis=0)) / np.ptp(x, axis=0)
342 |
343 |     # compute the Wasserstein distance of each feature to the uniform distribution:
344 | n_samples, n_features = x_norm.shape
345 | ans = [] # store the Wasserstein distance for each feature
346 | for i in range(0, n_features):
347 | wdu = 0.0
348 | y = np.sort(x_norm[:, i])
349 | # Round to the sixth decimal place and count number of unique elements
350 | # to construct an accurate cumulative discrete distribution func \sum_{x <= y_{i + 1}} 1/k
351 | y, counts = np.unique(np.round(x_norm[:, i], decimals=6), return_counts=True)
352 | p = 0
353 |         # skip index 0 because the distribution starts at v_min = 0
354 | for j in range(1, len(counts)):
355 | # integral from y_{i - 1} to y_{i} of |x - \sum_{x <= y_{i}} 1/k| dx
356 | yi1 = y[j - 1]
357 | yi = y[j]
358 | # Make a grid from yi1 to yi
359 | grid = np.linspace(yi1, yi, num=1000, endpoint=True)
360 | # Evaluate the integrand |x - \sum_{x <= y_{i + 1}} 1/k|
361 | p += counts[j - 1]
362 | integrand = np.abs(grid - p / n_samples)
363 | # Integrate using np.trapz
364 | wdu += np.trapz(y=integrand, x=grid)
365 | ans.append(wdu)
366 | return np.average(ans)
367 |
368 |
369 | def hypersphere_overlap_of_subset(x: np.ndarray, x_subset: np.array) -> float:
370 | r"""Compute the overlap of subset with hyper-spheres around each point
371 |
372 | The edge penalty is also included, which disregards areas
373 | outside of the boundary of the full feature space/library.
374 | This is calculated as:
375 |
376 | .. math::
377 | g(S) = \sum_{i < j}^k O(i, j) + \sum^k_m E(m)
378 |
379 |     where the sum over :math:`i, j` runs over the subset of molecules,
380 |     :math:`O(i, j)` is the approximate overlap between hyperspheres,
381 |     :math:`k` is the number of molecules in the subset, and :math:`E`
382 |     is the edge penalty of a molecule.
383 |
384 | Lower values mean more diversity.
385 |
386 | Parameters
387 | ----------
388 | x : ndarray
389 | Feature matrix of all molecules.
390 | x_subset : ndarray
391 | Feature matrix of selected subset of molecules.
392 |
393 | Returns
394 | -------
395 | float :
396 | The approximate overlapping volume of hyperspheres
397 | drawn around the selected points/molecules.
398 |
399 | Notes
400 | -----
401 | The hypersphere overlap volume is calculated using an approximation formula from Agrafiotis (1997).
402 |
403 |     Agrafiotis, D. K. (1997) Stochastic Algorithms for Maximizing Molecular Diversity.
404 | Journal of Chemical Information and Computer Sciences 37, 841-851.
405 | """
406 |
407 | # Find the maximum and minimum over each feature across all molecules.
408 | max_x = np.max(x, axis=0)
409 | min_x = np.min(x, axis=0)
410 |
411 | if np.all(np.abs(max_x - min_x) < 1e-30):
412 | raise ValueError("All of the features are redundant which causes normalization to fail.")
413 |
414 | # Remove redundant features
415 | non_red_feat = np.abs(max_x - min_x) > 1e-30
416 | x = x[:, non_red_feat]
417 | x_subset = x_subset[:, non_red_feat]
418 | max_x = max_x[non_red_feat]
419 | min_x = min_x[non_red_feat]
420 |
421 | d = len(x_subset[0])
422 | k = len(x_subset[:, 0])
423 |
424 | # normalization of each feature to [0, 1]
425 | x_norm = (x_subset - min_x) / (max_x - min_x)
426 |
427 | # r_o = hypersphere radius
428 | r_o = d * np.sqrt(1 / k)
429 | if r_o > 0.5:
430 | warnings.warn(
431 | "The number of molecules should be much larger" " than the number of features."
432 | )
433 | g_s = 0
434 | edge = 0
435 |
436 | # lambda parameter controls edge penalty
437 | lam = (d - 1.0) / d
438 | # calculate overlap volume
439 | for i in range(0, (k - 1)):
440 | for j in range((i + 1), k):
441 | dist = np.linalg.norm(x_norm[i] - x_norm[j])
442 | # Overlap penalty
443 | if dist <= (2 * r_o):
444 | with np.errstate(divide="ignore"):
445 | # min(100) ignores the inf case with divide by zero
446 | g_s += min(100, 2 * (r_o / dist) - 1)
447 |             # Edge penalty: lam * (1 - \sum^d_j e_{ij} / (d * r_o))
448 | edge_pen = 0.0
449 | for j_dim in range(0, d):
450 | # calculate dist to closest boundary in jth coordinate,
451 | # with max value = 1, min value = 0
452 | dist_max = np.abs(1 - x_norm[i, j_dim])
453 | dist_min = x_norm[i, j_dim]
454 | dist = min(dist_min, dist_max)
455 | # truncate distance at r_o
456 | if dist > r_o:
457 | dist = r_o
458 | edge_pen += dist
459 | edge_pen /= d * r_o
460 | edge_pen = lam * (1.0 - edge_pen)
461 | edge += edge_pen
462 | g_s += edge
463 | return g_s
464 |
465 |
466 | def gini_coefficient(x: np.ndarray):
467 | r"""
468 | Gini coefficient of bit-wise fingerprints of a database of molecules.
469 |
470 | Measures the chemical diversity of a database of molecules defined by
471 | the following formula:
472 |
473 | .. math::
474 |         G = \frac{2 \sum_{i=1}^L i ||y_i||_1 }{L \sum_{i=1}^L ||y_i||_1} - \frac{L+1}{L},
475 |
476 |     where :math:`y_i \in \{0, 1\}^N` is the vector of zeros and ones over the :math:`N`
477 |     molecules for the :math:`i`-th feature (with features sorted by bit-count), and :math:`L` is the number of features.
478 |
479 | Lower values mean more diversity.
480 |
481 | Parameters
482 | ----------
483 | x : ndarray(N, L)
484 | Molecule features in L bits with N molecules.
485 |
486 | Returns
487 | -------
488 | float :
489 | Gini coefficient in the range [0,1].
490 |
491 | References
492 | ----------
493 |     Weidlich, Iwona E., and Igor V. Filippov. "Using the Gini coefficient to measure the
494 |     chemical diversity of small-molecule libraries." Journal of Computational Chemistry 37 (2016): 2091-2097.
495 |
496 | """
497 | # Check that `x` is a bit-wise fingerprint.
498 | if np.count_nonzero((x != 0) & (x != 1)) != 0:
499 | raise ValueError("Attribute `x` should have binary values.")
500 | if x.ndim != 2:
501 | raise ValueError(f"Attribute `x` should have dimension two rather than {x.ndim}.")
502 |
503 | num_features = x.shape[1]
504 |     # Take the bit-count of each column/feature.
505 | bit_count = np.sum(x, axis=0)
506 |
507 | # Sort the bit-count since Gini coefficients relies on cumulative distribution.
508 | bit_count = np.sort(bit_count)
509 |
510 |     # Denominator: number of features times the total bit-count
511 | denominator = num_features * np.sum(bit_count)
512 | numerator = np.sum(np.arange(1, num_features + 1) * bit_count)
513 |
514 | return 2.0 * numerator / denominator - (num_features + 1) / num_features
515 |
516 |
517 | def nearest_average_tanimoto(x: np.ndarray) -> float:
518 |     """Compute the average Tanimoto coefficient of nearest-neighbor pairs.
519 |
520 | Parameters
521 | ----------
522 | x : ndarray
523 | Feature matrix.
524 |
525 | Returns
526 | -------
527 | nat : float
528 |         Average Tanimoto coefficient of the closest pairs.
529 |
530 | Notes
531 | -----
532 |     This computes the Tanimoto coefficient of the pairs with the shortest
533 |     distances, then returns their average.
534 |     This calculation is intended specifically for the explicit diversity index.
535 |
536 | Papp, Á., Gulyás-Forró, A., Gulyás, Z., Dormán, G., Ürge, L.,
537 |     and Darvas, F. (2006) Explicit Diversity Index (EDI):
538 | A Novel Measure for Assessing the Diversity of Compound Databases.
539 | Journal of Chemical Information and Modeling 46, 1898-1904.
540 | """
541 | tani = []
542 | for idx, _ in enumerate(x):
543 |         # initialize the shortest distance to infinity
544 |         short = np.inf
545 | a = 0
546 | b = 0
547 | # search for shortest distance point from idx
548 | for jdx, _ in enumerate(x):
549 | dist = np.linalg.norm(x[idx] - x[jdx])
550 | if dist < short and idx != jdx:
551 | short = dist
552 | a = idx
553 | b = jdx
554 | # calculate tanimoto for each shortest dist pair
555 | tani.append(tanimoto(x[a], x[b]))
556 | # compute average of all shortest tanimoto coeffs
557 | nat = np.average(tani)
558 | return nat
559 |
--------------------------------------------------------------------------------
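
A minimal usage sketch for the diversity measures above; the fingerprints and coordinates are randomly generated for illustration:

    import numpy as np
    from selector.measures.diversity import compute_diversity, hypersphere_overlap_of_subset

    rng = np.random.default_rng(0)

    # binary fingerprints: 10 compounds with 16 bits each (illustrative)
    fps = rng.integers(0, 2, size=(10, 16))
    entropy = compute_diversity(fps, div_type="shannon_entropy", normalize=True)

    # hypersphere overlap compares a selected subset against the full library
    library = rng.random((50, 2))
    overlap = hypersphere_overlap_of_subset(x=library, x_subset=library[:20])

Higher entropy values indicate more diverse fingerprints, while lower overlap values indicate a more evenly spread, more diverse subset.
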
/selector/measures/similarity.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <https://www.gnu.org/licenses/>
22 | #
23 | # --
24 | """Similarity Module."""
25 |
26 | from itertools import combinations_with_replacement
27 |
28 | import numpy as np
29 |
30 | __all__ = [
31 | "pairwise_similarity_bit",
32 | "tanimoto",
33 | "modified_tanimoto",
34 | "scaled_similarity_matrix",
35 | ]
36 |
37 |
38 | def pairwise_similarity_bit(X: np.array, metric: str) -> np.ndarray:
39 | """Compute pairwise similarity coefficient matrix.
40 |
41 | Parameters
42 | ----------
43 | X : ndarray of shape (n_samples, n_features)
44 | Feature matrix of `n_samples` samples in `n_features` dimensional space.
45 | metric : str
46 |         The metric used when calculating similarity coefficients between samples in a feature array.
47 |         Options: `"tanimoto"`, `"modified_tanimoto"`.
48 |
49 | Returns
50 | -------
51 | s : ndarray of shape (n_samples, n_samples)
52 | A symmetric similarity matrix between each pair of samples in the feature matrix.
53 | The diagonal elements are directly computed instead of assuming that they are 1.
54 |
55 | """
56 |
57 | available_methods = {
58 | "tanimoto": tanimoto,
59 | "modified_tanimoto": modified_tanimoto,
60 | }
61 | if metric not in available_methods:
62 | raise ValueError(
63 | f"Argument metric={metric} is not recognized! Choose from {available_methods.keys()}"
64 | )
65 | if X.ndim != 2:
66 | raise ValueError(f"Argument features should be a 2D array, got {X.ndim}")
67 |
68 | # make pairwise m-by-m similarity matrix
69 | n_samples = len(X)
70 | s = np.zeros((n_samples, n_samples))
71 | # compute similarity between all pairs of points (including the diagonal elements)
72 | for i, j in combinations_with_replacement(range(n_samples), 2):
73 | s[i, j] = s[j, i] = available_methods[metric](X[i], X[j])
74 | return s
75 |
76 |
77 | def tanimoto(a: np.array, b: np.array) -> float:
78 | r"""Compute Tanimoto coefficient or index (a.k.a. Jaccard similarity coefficient).
79 |
80 | For two binary or non-binary arrays :math:`A` and :math:`B`, Tanimoto coefficient
81 | is defined as the size of their intersection divided by the size of their union:
82 |
83 | .. math::
84 | T(A, B) = \frac{| A \cap B|}{| A \cup B |} =
85 | \frac{| A \cap B|}{|A| + |B| - | A \cap B|} =
86 | \frac{A \cdot B}{\|A\|^2 + \|B\|^2 - A \cdot B}
87 |
88 | where :math:`A \cdot B = \sum_i{A_i B_i}` and :math:`\|A\|^2 = \sum_i{A_i^2}`.
89 |
90 | Parameters
91 | ----------
92 | a : ndarray of shape (n_features,)
93 | The 1D feature array of sample :math:`A` in an `n_features` dimensional space.
94 | b : ndarray of shape (n_features,)
95 | The 1D feature array of sample :math:`B` in an `n_features` dimensional space.
96 |
97 | Returns
98 | -------
99 | coeff : float
100 | Tanimoto coefficient between feature arrays :math:`A` and :math:`B`.
101 |
102 |     Bajusz, D., Rácz, A., and Héberger, K. (2015)
103 | Why is Tanimoto index an appropriate choice for fingerprint-based similarity calculations?.
104 | Journal of Cheminformatics 7.
105 |
106 | """
107 | if a.ndim != 1 or b.ndim != 1:
108 | raise ValueError(f"Arguments a and b should be 1D arrays, got {a.ndim} and {b.ndim}")
109 | if a.shape != b.shape:
110 | raise ValueError(
111 | f"Arguments a and b should have the same shape, got {a.shape} != {b.shape}"
112 | )
113 | coeff = sum(a * b) / (sum(a**2) + sum(b**2) - sum(a * b))
114 | return coeff
115 |
116 |
117 | def modified_tanimoto(a: np.array, b: np.array) -> float:
118 | r"""Compute the modified tanimoto coefficient from bitstring vectors of data points A and B.
119 |
120 | Adjusts calculation of the Tanimoto coefficient to counter its natural bias towards
121 | shorter vectors using a Bernoulli probability model.
122 |
123 | .. math::
124 | {mt} = \frac{2-p}{3} T_1 + \frac{1+p}{3} T_0
125 |
126 |     where :math:`p` is the success probability of independent trials, and, with :math:`n_{11}`
127 |     and :math:`n_{00}` the numbers of common '1' and '0' bits and :math:`n` the fingerprint length,
128 |     :math:`T_1 = \frac{n_{11}}{n - n_{00}}` is the Tanimoto coefficient over the '1' bits and
129 |     :math:`T_0 = \frac{n_{00}}{n - n_{11}}` is its analogue over the '0' bits.
130 |
131 |
132 | Parameters
133 | ----------
134 | a : ndarray of shape (n_features,)
135 | The 1D bitstring feature array of sample :math:`A` in an `n_features` dimensional space.
136 | b : ndarray of shape (n_features,)
137 | The 1D bitstring feature array of sample :math:`B` in an `n_features` dimensional space.
138 |
139 | Returns
140 | -------
141 | mt : float
142 |         Modified Tanimoto coefficient between bitstring feature arrays :math:`A` and :math:`B`.
143 |
144 | Notes
145 | -----
146 | The equation above has been derived from
147 |
148 | .. math::
149 | {mt}_{\alpha} = {\alpha}T_1 + (1-\alpha)T_0
150 |
151 |     where :math:`\alpha = \frac{2-p}{3}`. This is done so that the expected value
152 |     of the modified Tanimoto coefficient, :math:`E(mt)`, remains constant (equal to :math:`1/3`)
153 |     regardless of the value of :math:`p`.
154 |
155 |     Fligner, M. A., Verducci, J. S., and Blower, P. E. (2002)
156 | A Modification of the Jaccard-Tanimoto Similarity Index for
157 | Diverse Selection of Chemical Compounds Using Binary Strings.
158 | Technometrics 44, 110-119.
159 |
160 | """
161 | if a.ndim != 1:
162 | raise ValueError(f"Argument `a` should have dimension 1 rather than {a.ndim}.")
163 | if b.ndim != 1:
164 | raise ValueError(f"Argument `b` should have dimension 1 rather than {b.ndim}.")
165 | if a.shape != b.shape:
166 | raise ValueError(
167 | f"Arguments a and b should have the same shape, got {a.shape} != {b.shape}"
168 | )
169 |
170 | n_features = len(a)
171 | # number of common '1' bits between points A and B
172 | n_11 = sum(a * b)
173 | # number of common '0' bits between points A and B
174 | n_00 = sum((1 - a) * (1 - b))
175 |
176 |     # calculate Tanimoto coefficient based on '1' bits
177 |     t_1 = 1
178 |     if n_00 != n_features:
179 |         # bit strings are not all '0's
180 |         t_1 = n_11 / (n_features - n_00)
181 |     # calculate Tanimoto coefficient based on '0' bits
182 |     t_0 = 1
183 |     if n_11 != n_features:
184 |         # bit strings are not all '1's
185 |         t_0 = n_00 / (n_features - n_11)
186 |
187 | # combine into modified tanimoto using Bernoulli Model
188 | # p = independent success trials
189 | # evaluated as total number of '1' bits
190 | # divided by 2x the fingerprint length
191 | p = (n_features - n_00 + n_11) / (2 * n_features)
192 | # mt = x * T_1 + (1-x) * T_0
193 | # x = (2-p)/3 so that E(mt) = 1/3, no matter the value of p
194 | mt = (((2 - p) / 3) * t_1) + (((1 + p) / 3) * t_0)
195 | return mt
196 |
197 |
198 | def scaled_similarity_matrix(X: np.array) -> np.ndarray:
199 | r"""Compute the scaled similarity matrix.
200 |
201 | .. math::
202 | X(i,j) = \frac{X(i,j)}{\sqrt{X(i,i)X(j,j)}}
203 |
204 | Parameters
205 | ----------
206 | X : ndarray of shape (n_samples, n_samples)
207 | Similarity matrix of `n_samples`.
208 |
209 | Returns
210 | -------
211 | s : ndarray of shape (n_samples, n_samples)
212 | A scaled symmetric similarity matrix.
213 |
214 | """
215 | if X.ndim != 2:
216 | raise ValueError(f"Argument similarity matrix should be a 2D array, got {X.ndim}")
217 |
218 | if X.shape[0] != X.shape[1]:
219 | raise ValueError(
220 | f"Argument similarity matrix should be a square matrix (having same number of rows and columns), got {X.shape[0]} and {X.shape[1]}"
221 | )
222 |
223 | if not (np.all(X >= 0) and np.all(np.diag(X) > 0)):
224 | raise ValueError(
225 | "All elements of similarity matrix should be greater than zero and diagonals should be non-zero"
226 | )
227 |
228 |     # no scaling is needed when all diagonal elements are already 1
229 |     if np.all(np.diag(X) == 1):
230 |         print("No scaling is applied because all diagonal elements are already 1")
231 | return X
232 | else:
233 | # make a scaled similarity matrix
234 | n_samples = len(X)
235 | s = np.zeros((n_samples, n_samples))
236 | # calculate the square root of the diagonal elements
237 | sqrt_diag = np.sqrt(np.diag(X))
238 | # calculate the product of the square roots of the diagonal elements
239 | product_sqrt_diag = np.outer(sqrt_diag, sqrt_diag)
240 | # divide each element of the matrix by the product of the square roots of diagonal elements
241 | s = X / product_sqrt_diag
242 | return s
243 |
244 |
245 | def similarity_index(x: np.array, y: np.array, sim_index: str) -> float:
246 | """Compute similarity index matrix.
247 |
248 | Parameters
249 | ----------
250 | x : ndarray of shape (n_features,)
251 | Feature array of sample `x` in an `n_features` dimensional space
252 | y : ndarray of shape (n_features,)
253 | Feature array of sample `y` in an `n_features` dimensional space
254 |     sim_index : str
255 | The key with the abbreviation of the similarity index to be used for calculations.
256 | Possible values are:
257 | - 'AC': Austin-Colwell
258 | - 'BUB': Baroni-Urbani-Buser
259 | - 'CTn': Consoni-Todschini n (n=1,2)
260 | - 'Fai': Faith
261 | - 'Gle': Gleason
262 | - 'Ja': Jaccard
263 | - 'JT': Jaccard-Tanimoto
264 | - 'RT': Rogers-Tanimoto
265 |         - 'RR': Russell-Rao
266 | - 'SM': Sokal-Michener
267 | - 'SSn': Sokal-Sneath n (n=1,2)
268 | Default is 'RR'.
269 |
270 | Returns
271 | -------
272 | sim : float
273 | The similarity index value between the feature arrays `x` and `y`.
274 | """
275 | # Define the similarity index functions
276 | similarity_indices = {
277 | "AC": lambda a, d, dis, p: 2 / np.pi * np.arcsin(((a + d) / p) ** 0.5),
278 | "BUB": lambda a, d, dis, p: ((a * d) ** 0.5 + a) / ((a * d) ** 0.5 + a + dis),
279 | "CT1": lambda a, d, dis, p: np.log(1 + a + d) / np.log(1 + p),
280 | "CT2": lambda a, d, dis, p: (np.log(1 + p) - np.log(1 + dis)) / np.log(1 + p),
281 | "Fai": lambda a, d, dis, p: (a + 0.5 * d) / p,
282 | "Gle": lambda a, d, dis, p: 2 * a / (2 * a + dis),
283 | "Ja": lambda a, d, dis, p: 3 * a / (3 * a + dis),
284 | "JT": lambda a, d, dis, p: a / (a + dis),
285 | "RT": lambda a, d, dis, p: (a + d) / (p + dis),
286 | "RR": lambda a, d, dis, p: a / p,
287 | "SM": lambda a, d, dis, p: (a + d) / p,
288 | "SS1": lambda a, d, dis, p: a / (a + 2 * dis),
289 | "SS2": lambda a, d, dis, p: (2 * (a + d)) / (p + (a + d)),
290 | }
291 |
292 | if sim_index not in similarity_indices:
293 | raise ValueError(
294 |             f"Argument sim_index={sim_index} is not recognized! Choose from {list(similarity_indices.keys())}"
295 | )
296 | if x.ndim != 1 or y.ndim != 1:
297 | raise ValueError(f"Arguments x and y should be 1D arrays, got {x.ndim} and {y.ndim}")
298 | if x.shape != y.shape:
299 | raise ValueError(
300 | f"Arguments x and y should have the same shape, got {x.shape} != {y.shape}"
301 | )
302 | a, d, dis, p = _compute_base_descriptors(x, y)
303 | return similarity_indices[sim_index](a, d, dis, p)
304 |
305 |
306 | def _compute_base_descriptors(x, y):
307 | """Compute the base descriptors for the similarity indices.
308 |
309 | Parameters
310 | ----------
311 | x : ndarray of shape (n_features,)
312 | Feature array of sample `x` in an `n_features` dimensional space
313 | y : ndarray of shape (n_features,)
314 | Feature array of sample `y` in an `n_features` dimensional space
315 |
316 | Returns
317 | -------
318 | tuple(int, int, int, int)
319 | The number of common on bits, number of common off bits, number of 1-0 mismatches, and the
320 | length of the fingerprint.
321 | """
322 | p = len(x)
323 | a = np.dot(x, y)
324 | d = np.dot(1 - x, 1 - y)
325 | dis = p - a - d
326 | return a, d, dis, p
327 |
--------------------------------------------------------------------------------
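
A minimal usage sketch (not part of the repository) tying together `similarity_index` and the base descriptors computed by `_compute_base_descriptors` above; the toy fingerprints and the printed value are illustrative only:

    import numpy as np

    from selector.measures.similarity import similarity_index

    x = np.array([1, 0, 1, 1, 0])
    y = np.array([1, 1, 1, 0, 0])
    # base descriptors for these fingerprints: a (common on-bits) = 2,
    # d (common off-bits) = 1, dis (mismatches) = 2, p (length) = 5
    # Jaccard-Tanimoto: a / (a + dis) = 2 / 4
    print(similarity_index(x, y, sim_index="JT"))  # 0.5
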
/selector/measures/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>
22 | #
23 | # --
24 |
25 | """Test Module."""
26 |
--------------------------------------------------------------------------------
/selector/measures/tests/common.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>
22 | #
23 | # --
24 | """Common functions for test module."""
25 |
26 | import numpy as np
27 |
28 | try:
29 | from importlib_resources import path
30 | except ImportError:
31 | from importlib.resources import path
32 |
33 |
34 | def bit_cosine(a, b):
35 |     """Compute cosine coefficient for bit strings.
36 |
37 |     Parameters
38 |     ----------
39 |     a : array_like
40 |         molecule A's features in bit string.
41 |     b : array_like
42 |         molecule B's features in bit string.
43 |
44 |     Returns
45 |     -------
46 |     b_c : float
47 |         cosine coefficient for molecules A and B.
48 | """
49 | a_feat = np.count_nonzero(a)
50 | b_feat = np.count_nonzero(b)
51 | c = 0
52 | for idx, _ in enumerate(a):
53 | if a[idx] == b[idx] and a[idx] != 0:
54 | c += 1
55 | b_c = c / ((a_feat * b_feat) ** 0.5)
56 | return b_c
57 |
58 |
59 | def bit_dice(a, b):
60 |     """Compute dice coefficient for bit strings.
61 |
62 |     Parameters
63 |     ----------
64 |     a : array_like
65 |         molecule A's features in bit string.
66 |     b : array_like
67 |         molecule B's features in bit string.
68 |
69 |     Returns
70 |     -------
71 |     b_d : float
72 |         dice coefficient for molecules A and B.
73 | """
74 | a_feat = np.count_nonzero(a)
75 | b_feat = np.count_nonzero(b)
76 | c = 0
77 | for idx, _ in enumerate(a):
78 | if a[idx] == b[idx] and a[idx] != 0:
79 | c += 1
80 | b_d = (2 * c) / (a_feat + b_feat)
81 | return b_d
82 |
83 |
84 | def cosine(a, b):
85 |     """Compute cosine coefficient.
86 |
87 |     Parameters
88 |     ----------
89 |     a : array_like
90 |         molecule A's features.
91 |     b : array_like
92 |         molecule B's features.
93 |
94 |     Returns
95 |     -------
96 |     coeff : float
97 |         cosine coefficient for molecules A and B.
98 |     """
99 |     coeff = (sum(a * b)) / (((sum(a**2)) * (sum(b**2))) ** 0.5)
100 | return coeff
101 |
102 |
103 | def dice(a, b):
104 | """Compute dice coefficient.
105 |
106 | Parameters
107 | ----------
108 | a : array_like
109 | molecule A's features.
110 |     b : array_like
111 |         molecule B's features.
112 |
113 |     Returns
114 |     -------
115 |     coeff : float
116 |         dice coefficient for molecules A and B.
117 | """
118 | coeff = (2 * (sum(a * b))) / ((sum(a**2)) + (sum(b**2)))
119 | return coeff
120 |
--------------------------------------------------------------------------------
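
These helpers act as independent reference implementations for the similarity tests. As a quick sanity check (a sketch, not repository code), on binary fingerprints the bit-string and continuous dice definitions agree, since `sum(a**2)` equals the number of on-bits and the matched on-bit count equals `sum(a * b)`:

    import numpy as np

    from selector.measures.tests.common import bit_dice, dice

    a = np.array([1, 0, 1, 1])
    b = np.array([1, 1, 1, 0])
    # both evaluate to 2 * 2 / (3 + 3) = 2 / 3 on these bit strings
    assert np.isclose(bit_dice(a, b), dice(a, b))
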
/selector/measures/tests/test_converter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>
22 | #
23 | # --
24 | """Test Converter Module."""
25 |
26 | import numpy as np
27 | from numpy.testing import assert_almost_equal, assert_equal, assert_raises
28 |
29 | from selector.measures import converter as cv
30 |
31 | # Tests for variations on input `x` for sim_to_dist()
32 |
33 |
34 | def test_sim_2_dist_float_int():
35 | """Test similarity to distance input handling when input is a float or int."""
36 | expected_1 = 0.25
37 | int_out = cv.sim_to_dist(4, "reciprocal")
38 | assert_equal(int_out, expected_1)
39 | expected_2 = 2
40 | float_out = cv.sim_to_dist(0.5, "reciprocal")
41 | assert_equal(float_out, expected_2)
42 |
43 |
44 | def test_sim_2_dist_array_dimension_error():
45 | """Test sim to dist function with incorrect input dimensions for `x`."""
46 | assert_raises(ValueError, cv.sim_to_dist, np.ones([2, 2, 2]), "reciprocal")
47 |
48 |
49 | def test_sim_2_dist_1d_metric_error():
50 | """Test sim to dist function with an invalid metric for 1D arrays."""
51 | assert_raises(ValueError, cv.sim_to_dist, np.ones(5), "gravity")
52 | assert_raises(ValueError, cv.sim_to_dist, np.ones(5), "co-occurrence")
53 |
54 |
55 | # Tests for variations on input `metric` for sim_to_dist()
56 |
57 |
58 | def test_sim_2_dist():
59 | """Test similarity to distance method with specified metric."""
60 | x = np.array([[1, 0.2, 0.5], [0.2, 1, 0.25], [0.5, 0.25, 1]])
61 | expected = np.array([[0.20, 1, 0.70], [1, 0.20, 0.95], [0.70, 0.95, 0.20]])
62 | actual = cv.sim_to_dist(x, "reverse")
63 | assert_almost_equal(actual, expected, decimal=10)
64 |
65 |
66 | def test_sim_2_dist_frequency():
67 | """Test similarity to distance method with a frequency metric."""
68 | x = np.array([[4, 9, 1], [9, 1, 25], [1, 25, 16]])
69 | expected = np.array([[(1 / 2), (1 / 3), 1], [(1 / 3), 1, (1 / 5)], [1, (1 / 5), (1 / 4)]])
70 | actual = cv.sim_to_dist(x, "transition")
71 | assert_almost_equal(actual, expected, decimal=10)
72 |
73 |
74 | def test_sim_2_dist_frequency_error():
75 | """Test similarity to distance method with a frequency metric and incorrect input."""
76 | # zeroes in the frequency matrix
77 | x = np.array([[0, 9, 1], [9, 1, 25], [1, 25, 0]])
78 | assert_raises(ValueError, cv.sim_to_dist, x, "gravity")
79 | # negatives in the frequency matrix
80 |     y = np.array([[1, -9, 1], [9, 1, -25], [1, 25, 16]])
81 |     assert_raises(ValueError, cv.sim_to_dist, y, "gravity")
82 |
83 |
84 | def test_sim_2_dist_membership():
85 | """Test similarity to distance method with the membership metric."""
86 | # x = np.array([[(1 / 2), (1 / 5)], [(1 / 4), (1 / 3)]])
87 | x = np.array([[1, 1 / 5, 1 / 3], [1 / 5, 1, 4 / 5], [1 / 3, 4 / 5, 1]])
88 | expected = np.array([[0, 4 / 5, 2 / 3], [4 / 5, 0, 1 / 5], [2 / 3, 1 / 5, 0]])
89 | actual = cv.sim_to_dist(x, "membership")
90 | assert_almost_equal(actual, expected, decimal=10)
91 |
92 |
93 | def test_sim_2_dist_membership_error():
94 | """Test similarity to distance method with the membership metric when there is an input error."""
95 | x = np.array([[1, 0, -7], [0, 1, 3], [-7, 3, 1]])
96 | assert_raises(ValueError, cv.sim_to_dist, x, "membership")
97 |
98 |
99 | def test_sim_2_dist_invalid_metric():
100 | """Test similarity to distance method with an unsupported metric."""
101 | assert_raises(ValueError, cv.sim_to_dist, np.ones(5), "testing")
102 |
103 |
104 | def test_sim_2_dist_non_symmetric():
105 | """Test the invalid 2D symmetric matrix error."""
106 | x = np.array([[1, 2], [4, 5]])
107 | assert_raises(ValueError, cv.sim_to_dist, x, "reverse")
108 |
109 |
110 | # Tests for individual metrics
111 |
112 |
113 | def test_reverse():
114 | """Test the reverse function for similarity to distance conversion."""
115 | x = np.array([[3, 1, 1], [1, 3, 0], [1, 0, 3]])
116 | expected = np.array([[0, 2, 2], [2, 0, 3], [2, 3, 0]])
117 | actual = cv.reverse(x)
118 | assert_equal(actual, expected)
119 |
120 |
121 | def test_reciprocal():
122 |     """Test the reciprocal function for similarity to distance conversion."""
123 | x = np.array([[1, 0.25, 0.40], [0.25, 1, 0.625], [0.40, 0.625, 1]])
124 | expected = np.array([[1, 4, 2.5], [4, 1, 1.6], [2.5, 1.6, 1]])
125 | actual = cv.reciprocal(x)
126 | assert_equal(actual, expected)
127 |
128 |
129 | def test_reciprocal_error():
130 |     """Test the reciprocal function with incorrect input values."""
131 | # zero value for similarity (causes divide by zero issues)
132 | x = np.array([[0, 4], [3, 2]])
133 | assert_raises(ValueError, cv.reciprocal, x)
134 | # negative value for similarity (distance cannot be negative)
135 | y = np.array([[1, -4], [3, 2]])
136 | assert_raises(ValueError, cv.reciprocal, y)
137 |
138 |
139 | def test_exponential():
140 | """Test the exponential function for similarity to distance conversion."""
141 | x = np.array([[1, 0.25, 0.40], [0.25, 1, 0.625], [0.40, 0.625, 1]])
142 | expected = np.array(
143 | [
144 | [0, 1.38629436112, 0.91629073187],
145 | [1.38629436112, 0, 0.47000362924],
146 | [0.91629073187, 0.47000362924, 0],
147 | ]
148 | )
149 | actual = cv.exponential(x)
150 | assert_almost_equal(actual, expected, decimal=10)
151 |
152 |
153 | def test_exponential_error():
154 | """Test the exponential function when max similarity is zero."""
155 | x = np.zeros((4, 4))
156 | assert_raises(ValueError, cv.exponential, x)
157 |
158 |
159 | def test_gaussian():
160 | """Test the gaussian function for similarity to distance conversion."""
161 | x = np.array([[1, 0.25, 0.40], [0.25, 1, 0.625], [0.40, 0.625, 1]])
162 | expected = np.array(
163 | [
164 | [0, 1.17741002252, 0.95723076208],
165 | [1.17741002252, 0, 0.68556810693],
166 | [0.95723076208, 0.68556810693, 0],
167 | ]
168 | )
169 | actual = cv.gaussian(x)
170 | assert_almost_equal(actual, expected, decimal=10)
171 |
172 |
173 | def test_gaussian_error():
174 | """Test the gaussian function when max similarity is zero."""
175 | x = np.zeros((4, 4))
176 | assert_raises(ValueError, cv.gaussian, x)
177 |
178 |
179 | def test_correlation():
180 | """Test the correlation to distance conversion function."""
181 | x = np.array([[1, 0.5, 0.2], [0.5, 1, -0.2], [0.2, -0.2, 1]])
182 | # expected = sqrt(1-x)
183 | expected = np.array(
184 | [
185 | [0, 0.70710678118, 0.894427191],
186 | [0.70710678118, 0, 1.09544511501],
187 | [0.894427191, 1.09544511501, 0],
188 | ]
189 | )
190 | actual = cv.correlation(x)
191 | assert_almost_equal(actual, expected, decimal=10)
192 |
193 |
194 | def test_correlation_error():
195 | """Test the correlation function with an out of bounds array."""
196 | x = np.array([[1, 0, -7], [0, 1, 3], [-7, 3, 1]])
197 | assert_raises(ValueError, cv.correlation, x)
198 |
199 |
200 | def test_transition():
201 | """Test the transition function for frequency to distance conversion."""
202 | x = np.array([[4, 9, 1], [9, 1, 25], [1, 25, 16]])
203 | expected = np.array([[(1 / 2), (1 / 3), 1], [(1 / 3), 1, (1 / 5)], [1, (1 / 5), (1 / 4)]])
204 |
205 | actual = cv.transition(x)
206 | assert_almost_equal(actual, expected, decimal=10)
207 |
208 |
209 | def test_co_occurrence():
210 | """Test the co-occurrence conversion function."""
211 | x = np.array([[1, 2, 3], [2, 1, 3], [3, 3, 1]])
212 | expected = np.array(
213 | [
214 | [1 / (19 / 121 + 1), 1 / (38 / 121 + 1), 1 / (57 / 121 + 1)],
215 | [1 / (38 / 121 + 1), 1 / (19 / 121 + 1), 1 / (57 / 121 + 1)],
216 | [1 / (57 / 121 + 1), 1 / (57 / 121 + 1), 1 / (19 / 121 + 1)],
217 | ]
218 | )
219 | actual = cv.co_occurrence(x)
220 | assert_almost_equal(actual, expected, decimal=10)
221 |
222 |
223 | def test_gravity():
224 | """Test the gravity conversion function."""
225 | x = np.array([[1, 2, 3], [2, 1, 3], [3, 3, 1]])
226 | expected = np.array(
227 | [
228 | [2.5235730726, 1.7844356324, 1.45698559277],
229 | [1.7844356324, 2.5235730726, 1.45698559277],
230 | [1.45698559277, 1.45698559277, 2.5235730726],
231 | ]
232 | )
233 | actual = cv.gravity(x)
234 | assert_almost_equal(actual, expected, decimal=10)
235 |
236 |
237 | def test_probability():
238 | """Test the probability to distance conversion function."""
239 | x = np.array([[0.3, 0.7], [0.5, 0.5]])
240 | expected = np.array([[1.8116279322, 1.1356324735], [1.3819765979, 1.3819765979]])
241 | actual = cv.probability(x)
242 | assert_almost_equal(actual, expected, decimal=10)
243 |
244 |
245 | def test_probability_error():
246 |     """Test the probability function with an out-of-bounds array."""
247 | # negative value for probability
248 | x = np.array([[-0.5]])
249 | assert_raises(ValueError, cv.probability, x)
250 | # too large value for probability
251 | y = np.array([[3]])
252 | assert_raises(ValueError, cv.probability, y)
253 | # zero value for probability (causes divide by zero issues)
254 | z = np.array([[0]])
255 | assert_raises(ValueError, cv.probability, z)
256 |
257 |
258 | def test_covariance():
259 | """Test the covariance to distance conversion function."""
260 | x = np.array([[4, -4], [-4, 6]])
261 | expected = np.array([[0, 4.24264068712], [4.24264068712, 0]])
262 | actual = cv.covariance(x)
263 | assert_almost_equal(actual, expected, decimal=10)
264 |
265 |
266 | def test_covariance_error():
267 | """Test the covariance function when input contains a negative variance."""
268 | x = np.array([[-4, 4], [4, 6]])
269 | assert_raises(ValueError, cv.covariance, x)
270 |
--------------------------------------------------------------------------------
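
The expected arrays in these tests pin down the conversion formulas. For instance, the values in `test_sim_2_dist` and `test_transition` above are consistent with the closed forms below (an illustrative sketch, not repository code):

    import numpy as np

    x = np.array([[1, 0.2, 0.5], [0.2, 1, 0.25], [0.5, 0.25, 1]])
    # reverse: d_ij = max(x) + min(x) - x_ij, so the diagonal maps to min(x)
    reverse = x.max() + x.min() - x

    f = np.array([[4, 9, 1], [9, 1, 25], [1, 25, 16]])
    # transition: d_ij = 1 / sqrt(f_ij), matching the expected 1/2, 1/3, ... entries
    transition = 1.0 / np.sqrt(f)
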
/selector/measures/tests/test_diversity.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>
22 | #
23 | # --
24 |
25 | """Test Diversity Module."""
26 | import warnings
27 |
28 | import numpy as np
29 | import pytest
30 | from numpy.testing import assert_almost_equal, assert_equal, assert_raises, assert_warns
31 |
32 | from selector.measures.diversity import (
33 | compute_diversity,
34 | explicit_diversity_index,
35 | gini_coefficient,
36 | hypersphere_overlap_of_subset,
37 | logdet,
38 | nearest_average_tanimoto,
39 | shannon_entropy,
40 | wdud,
41 | )
42 |
43 | # each row is a feature and each column is a molecule
44 | sample1 = np.array([[4, 2, 6], [4, 9, 6], [2, 5, 0], [2, 0, 9], [5, 3, 0]])
45 |
46 | # each row is a molecule and each column is a feature (scipy)
47 | sample2 = np.array([[1, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]])
48 |
49 | sample3 = np.array([[1, 4], [3, 2]])
50 |
51 | sample4 = np.array([[1, 0, 1], [0, 1, 1]])
52 |
53 | sample5 = np.array([[0, 2, 4, 0], [1, 2, 4, 0], [2, 2, 4, 0]])
54 |
55 | sample6 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 0]])
56 |
57 | sample7 = np.array([[1, 0, 1, 0] for _ in range(4)])
58 |
59 | sample8 = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
60 |
61 |
62 | def test_compute_diversity_specified():
63 | """Test compute diversity with a specified div_type."""
64 | comp_div = compute_diversity(sample6, "shannon_entropy", normalize=False, truncation=False)
65 | expected = 1.81
66 | assert round(comp_div, 2) == expected
67 |
68 |
69 | def test_compute_diversity_hyperspheres():
70 |     """Test compute diversity with two arguments for the hypersphere_overlap method."""
71 | corner_pts = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
72 | centers_pts = np.array([[0.5, 0.5]] * (100 - 4))
73 | pts = np.vstack((corner_pts, centers_pts))
74 |
75 | comp_div = compute_diversity(pts, div_type="hypersphere_overlap", features=pts)
76 | # Expected = overlap + edge penalty
77 | expected = (100.0 * 96 * 95 * 0.5) + 2.0
78 | assert_almost_equal(comp_div, expected)
79 |
80 |
81 | def test_compute_diversity_hypersphere_error():
82 | """Test compute diversity with hypersphere metric and no molecule library given."""
83 | assert_raises(ValueError, compute_diversity, sample5, "hypersphere_overlap")
84 |
85 |
86 | def test_compute_diversity_edi():
87 |     """Test compute diversity with the explicit diversity index div_type."""
88 | z = np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1]])
89 | cs = 1
90 | expected = 56.39551204
91 | actual = compute_diversity(z, "explicit_diversity_index", cs=cs)
92 | assert_almost_equal(expected, actual)
93 |
94 |
95 | def test_compute_diversity_edi_no_cs_error():
96 | """Test compute diversity with explicit diversity index and no `cs` value given."""
97 | assert_raises(ValueError, compute_diversity, sample5, "explicit_diversity_index")
98 |
99 |
100 | def test_compute_diversity_edi_zero_error():
101 | """Test compute diversity with explicit diversity index and `cs` = 0."""
102 |     assert_raises(ValueError, compute_diversity, sample5, "explicit_diversity_index", cs=0)
103 |
104 |
105 | def test_compute_diversity_invalid():
106 | """Test compute diversity with a non-supported div_type."""
107 | assert_raises(ValueError, compute_diversity, sample1, "diversity_type")
108 |
109 |
110 | def test_logdet():
111 | """Test the log determinant function with predefined subset matrix."""
112 | sel = logdet(sample3)
113 | expected = np.log(131)
114 | assert_almost_equal(sel, expected)
115 |
116 |
117 | def test_logdet_non_square_matrix():
118 | """Test the log determinant function with a rectangular matrix."""
119 | sel = logdet(sample4)
120 | expected = np.log(8)
121 | assert_almost_equal(sel, expected)
122 |
123 |
124 | def test_shannon_entropy():
125 | """Test the shannon entropy function with example from the original paper."""
126 |
127 | # example taken from figure 1 of 10.1021/ci900159f
128 | x1 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 0]])
129 | expected = 1.81
130 | assert round(shannon_entropy(x1, normalize=False, truncation=False), 2) == expected
131 |
132 | x2 = np.vstack((x1, [1, 1, 1, 0]))
133 | expected = 1.94
134 | assert round(shannon_entropy(x2, normalize=False, truncation=False), 2) == expected
135 |
136 | x3 = np.vstack((x1, [0, 1, 0, 1]))
137 | expected = 3.39
138 | assert round(shannon_entropy(x3, normalize=False, truncation=False), 2) == expected
139 |
140 |
141 | def test_shannon_entropy_normalize():
142 | """Test the shannon entropy function with normalization."""
143 | x1 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 0]])
144 | expected = 1.81 / (x1.shape[1] * np.log2(2) / 2)
145 | assert_almost_equal(
146 | actual=shannon_entropy(x1, normalize=True, truncation=False),
147 | desired=expected,
148 | decimal=2,
149 | )
150 |
151 |
152 | def test_shannon_entropy_warning():
153 | """Test the shannon entropy function gives warning when normalization is True and truncation is True."""
154 | x1 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 0]])
155 | with pytest.warns(UserWarning):
156 | shannon_entropy(x1, normalize=True, truncation=True)
157 |
158 |
159 | def test_shannon_entropy_binary_error():
160 | """Test the shannon entropy function raises error with a non binary matrix."""
161 | assert_raises(ValueError, shannon_entropy, sample5)
162 |
163 |
164 | def test_explicit_diversity_index():
165 | """Test the explicit diversity index function."""
166 | z = np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1]])
167 | cs = 1
168 |     # intermediate values kept for reference: nc = 3
169 |     # sdi = 0.75 / 0.7332902012
170 |     # cr = -1 * 0.4771212547
171 |     # edi = 0.5456661753 * 0.7071067811865476
172 | edi_scaled = 56.395512045413
173 | value = explicit_diversity_index(z, cs)
174 | assert_almost_equal(value, edi_scaled, decimal=8)
175 |
176 |
177 | def test_wdud_uniform():
178 | """Test wdud when a feature has uniform distribution."""
179 | uni = np.arange(0, 50000)[:, None]
180 | wdud_val = wdud(uni)
181 | expected = 0
182 | assert_almost_equal(wdud_val, expected, decimal=4)
183 |
184 |
185 | def test_wdud_repeat_yi():
186 | """Test wdud when a feature has multiple identical values."""
187 | dist = np.array([[0, 0.5, 0.5, 0.75, 1]]).T
188 | wdud_val = wdud(dist)
189 | # calculated using wolfram alpha:
190 | expected = 0.065 + 0.01625 + 0.02125
191 | assert_almost_equal(wdud_val, expected, decimal=4)
192 |
193 |
194 | def test_wdud_mult_features():
195 | """Test wdud when there are multiple features per molecule."""
196 | dist = np.array(
197 | [
198 | [0, 0.5, 0.5, 0.75, 1],
199 | [0, 0.5, 0.5, 0.75, 1],
200 | [0, 0.5, 0.5, 0.75, 1],
201 | [0, 0.5, 0.5, 0.75, 1],
202 | ]
203 | ).T
204 | wdud_val = wdud(dist)
205 | # calculated using wolfram alpha:
206 | expected = 0.065 + 0.01625 + 0.02125
207 | assert_almost_equal(wdud_val, expected, decimal=4)
208 |
209 |
210 | def test_wdud_dimension_error():
211 | """Test wdud method raises error when input has incorrect dimensions."""
212 | arr = np.zeros((2, 2, 2))
213 | assert_raises(ValueError, wdud, arr)
214 |
215 |
216 | def test_wdud_normalization_error():
217 | """Test wdud method raises error when normalization fails."""
218 | assert_raises(ValueError, wdud, sample8)
219 |
220 |
221 | def test_wdud_warning_normalization():
222 | """Test wdud method gives warning when normalization fails."""
223 | warning_message = (
224 | "Some of the features are constant which will cause the normalization to fail. "
225 | + "Now removing them."
226 | )
227 | with pytest.warns() as record:
228 | wdud(sample6)
229 |
230 | # check that the message matches
231 | assert record[0].message.args[0] == warning_message
232 |
233 |
234 | def test_hypersphere_overlap_of_subset_with_only_corners_and_center():
235 | """Test the hypersphere overlap method with predefined matrix."""
236 | corner_pts = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
237 | # Many duplicate pts cause r_0 to be much smaller than 1.0,
238 | # which is required due to normalization of the feature space
239 | centers_pts = np.array([[0.5, 0.5]] * (100 - 4))
240 | pts = np.vstack((corner_pts, centers_pts))
241 |
242 | # Overlap should be all coming from the centers
243 | expected_overlap = 100.0 * 96 * 95 * 0.5
244 | # The edge penalty should all be from the corner pts
245 | lam = 1.0 / 2.0 # Default lambda chosen from paper.
246 | expected_edge = lam * 4.0
247 | expected = expected_overlap + expected_edge
248 | true = hypersphere_overlap_of_subset(pts, pts)
249 | assert_almost_equal(true, expected)
250 |
251 |
252 | def test_hypersphere_normalization_error():
253 | """Test the hypersphere overlap method raises error when normalization fails."""
254 | assert_raises(ValueError, hypersphere_overlap_of_subset, sample7, sample7)
255 |
256 |
257 | def test_hypersphere_radius_warning():
258 | """Test the hypersphere overlap method gives warning when radius is too large."""
259 | corner_pts = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
260 | assert_warns(Warning, hypersphere_overlap_of_subset, corner_pts, corner_pts)
261 |
262 |
263 | def test_gini_coefficient_of_non_diverse_set():
264 | """Test Gini coefficient of the least diverse set. Expected return is zero."""
265 | # Finger-prints where columns are all the same
266 | numb_molecules = 5
267 | numb_features = 10
268 |     # Transpose so that the columns are all the same (the rows were made identical first)
269 | single_fingerprint = list(np.random.choice([0, 1], size=(numb_features,)))
270 | finger_prints = np.array([single_fingerprint] * numb_molecules).T
271 |
272 | result = gini_coefficient(finger_prints)
273 | # Since they are all the same, then gini coefficient should be zero.
274 | assert_almost_equal(result, 0.0, decimal=8)
275 |
276 |
277 | def test_gini_coefficient_non_binary_error():
278 | """Test Gini coefficient error when input is not binary."""
279 | assert_raises(ValueError, gini_coefficient, np.array([[7, 0], [2, 1]]))
280 |
281 |
282 | def test_gini_coefficient_dimension_error():
283 | """Test Gini coefficient error when input has incorrect dimensions."""
284 | assert_raises(ValueError, gini_coefficient, np.array([1, 0, 0, 0]))
285 |
286 |
287 | def test_gini_coefficient_of_most_diverse_set():
288 | """Test Gini coefficient of the most diverse set."""
289 | # Finger-prints where one feature has more `wealth` than all others.
290 | # Note: Transpose is done so one column has all ones.
291 | finger_prints = np.array(
292 | [
293 | [1, 1, 1, 1, 1, 1, 1],
294 | ]
295 | + [[0, 0, 0, 0, 0, 0, 0]] * 100000
296 | ).T
297 | result = gini_coefficient(finger_prints)
298 |     # One feature holds all the `wealth`, so the Gini coefficient should be one.
299 | assert_almost_equal(result, 1.0, decimal=4)
300 |
301 |
302 | def test_gini_coefficient_with_alternative_definition():
303 | """Test Gini coefficient with alternative definition."""
304 | # Finger-prints where they are all different
305 | numb_features = 4
306 | finger_prints = np.array([[1, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1]])
307 | result = gini_coefficient(finger_prints)
308 |
309 | # Alternative definition from wikipedia
310 | b = numb_features + 1
311 | desired = (
312 | numb_features + 1 - 2 * ((b - 1) + (b - 2) * 2 + (b - 3) * 3 + (b - 4) * 4) / (10)
313 | ) / numb_features
314 | assert_almost_equal(result, desired)
315 |
316 |
317 | def test_nearest_average_tanimoto_bit():
318 | """Test the nearest_average_tanimoto function with binary input."""
319 | nat = nearest_average_tanimoto(sample2)
320 | shortest_tani = [0.3333333, 0.3333333, 0, 0]
321 | average = np.average(shortest_tani)
322 | assert_almost_equal(nat, average)
323 |
324 |
325 | def test_nearest_average_tanimoto():
326 | """Test the nearest_average_tanimoto function with non-binary input."""
327 | nat = nearest_average_tanimoto(sample3)
328 | shortest_tani = [(11 / 19), (11 / 19)]
329 | average = np.average(shortest_tani)
330 | assert_equal(nat, average)
331 |
332 |
333 | def test_nearest_average_tanimoto_3_x_3():
334 |     """Test the nearest_average_tanimoto function with a 3x3 matrix."""
335 | # all unequal distances b/w points
336 | x = np.array([[0, 1, 2], [3, 4, 5], [4, 5, 6]])
337 | nat_x = nearest_average_tanimoto(x)
338 | avg_x = 0.749718574108818
339 | assert_equal(nat_x, avg_x)
340 | # one point equidistant from the other two
341 | y = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
342 | nat_y = nearest_average_tanimoto(y)
343 | avg_y = 0.4813295920569825
344 | assert_equal(nat_y, avg_y)
345 | # all points equidistant
346 | z = np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1]])
347 | nat_z = nearest_average_tanimoto(z)
348 | avg_z = 0.25
349 | assert_equal(nat_z, avg_z)
350 |
351 |
352 | def test_nearest_average_tanimoto_nonsquare():
353 |     """Test the nearest_average_tanimoto function with a non-square input matrix."""
354 | x = np.array([[3.5, 4.0, 10.5, 0.5], [1.25, 4.0, 7.0, 0.1], [0.0, 0.0, 0.0, 0.0]])
355 | # nearest neighbor of sample 0, 1, and 2 are sample 1, 0, and 1, respectively.
356 | expected = np.average(
357 | [
358 | np.sum(x[0] * x[1]) / (np.sum(x[0] ** 2) + np.sum(x[1] ** 2) - np.sum(x[0] * x[1])),
359 | np.sum(x[1] * x[0]) / (np.sum(x[1] ** 2) + np.sum(x[0] ** 2) - np.sum(x[1] * x[0])),
360 | np.sum(x[2] * x[1]) / (np.sum(x[2] ** 2) + np.sum(x[1] ** 2) - np.sum(x[2] * x[1])),
361 | ]
362 | )
363 | assert_equal(nearest_average_tanimoto(x), expected)
364 |
--------------------------------------------------------------------------------
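
The 1.81 value asserted in `test_shannon_entropy` can be reproduced by hand. A worked sketch, assuming the per-feature binary Shannon entropy of 10.1021/ci900159f (sum over feature columns of -p*log2(p) - (1-p)*log2(1-p), with 0*log2(0) taken as 0):

    import numpy as np

    x1 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 0]])
    # fraction of on-bits per feature column: [0.5, 0.25, 1.0, 0.0]
    p = x1.mean(axis=0)
    with np.errstate(divide="ignore", invalid="ignore"):
        h = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
    # constant columns contribute nothing: 1.0 + 0.811 + 0 + 0 = 1.81
    entropy = np.nansum(h)
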
/selector/methods/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>
22 | #
23 | # --
24 |
25 | from selector.methods.distance import *
26 | from selector.methods.partition import *
27 | from selector.methods.similarity import *
28 |
--------------------------------------------------------------------------------
/selector/methods/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>
22 | #
23 | # --
24 | """Base class for diversity based subset selection."""
25 |
26 | import warnings
27 | from abc import ABC, abstractmethod
28 | from typing import List, Iterable, Union
29 |
30 | import numpy as np
31 |
32 | __all__ = ["SelectionBase"]
33 |
34 |
35 | class SelectionBase(ABC):
36 | """Base class for selecting subset of sample points."""
37 |
38 | def select(
39 | self,
40 | x: np.ndarray,
41 | size: int,
42 | labels: np.ndarray = None,
43 | proportional_selection: bool = True,
44 | ) -> Union[List, Iterable]:
45 | """Return indices representing subset of sample points.
46 |
47 | Parameters
48 | ----------
49 | x: ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
50 | Feature matrix of `n_samples` samples in `n_features` dimensional feature space.
51 | If fun_distance is `None`, this x is treated as a square pairwise distance matrix.
52 | size: int
53 | Number of sample points to select (i.e. size of the subset).
54 | labels: np.ndarray, optional
55 | Array of integers or strings representing the labels of the clusters that
56 | each sample belongs to. If `None`, the samples are treated as one cluster.
57 | If labels are provided, selection is made from each cluster.
58 | proportional_selection: bool, optional
59 | If True, the number of samples to be selected from each cluster is proportional.
60 | Otherwise, the number of samples to be selected from each cluster is equal.
61 | Default is True.
62 |
63 | Returns
64 | -------
65 | selected: list
66 | Indices of the selected sample points.
67 | """
68 | # check size
69 | if size > len(x):
70 | raise ValueError(
71 | f"Size of subset {size} cannot be larger than number of samples {len(x)}."
72 | )
73 |
74 |         # if labels are not provided, indices selected from one cluster are returned
75 | if labels is None:
76 | return self.select_from_cluster(x, size)
77 |
78 | # check labels are consistent with number of samples
79 | if len(labels) != len(x):
80 | raise ValueError(
81 | f"Number of labels {len(labels)} does not match number of samples {len(x)}."
82 | )
83 |
84 | selected_ids = []
85 |
86 | # compute the number of samples (i.e. population or pop) in each cluster
87 | unique_labels, unique_label_counts = np.unique(labels, return_counts=True)
88 | num_clusters = len(unique_labels)
89 | pop_clusters = dict(zip(unique_labels, unique_label_counts))
90 | # compute number of samples to be selected from each cluster
91 | if proportional_selection:
92 |             # make sure that the total number of samples selected is equal to size
93 | size_each_cluster = size * unique_label_counts / len(labels)
94 | # using np.round to get to the nearest integer
95 | # not using int function directly to avoid truncation of decimal values
96 | size_each_cluster = np.round(size_each_cluster).astype(int)
97 | # make sure each cluster has at least one sample
98 | size_each_cluster[size_each_cluster < 1] = 1
99 |
100 | # the total number of samples selected from all clusters at this point
101 | size_each_cluster_total = np.sum(size_each_cluster)
102 | # when the total of data points in each class is less than the required number
103 | # add one sample to the smallest cluster iteratively until the total is equal to the
104 | # required number
105 | if size_each_cluster_total < size:
106 | while size_each_cluster_total < size:
107 | # the number of remaining data points in each cluster
108 |                     size_each_cluster_remaining = unique_label_counts - size_each_cluster
109 | # skip the clusters with no data points left
110 | size_each_cluster_remaining[size_each_cluster_remaining == 0] = np.inf
111 | smallest_cluster_index = np.argmin(size_each_cluster_remaining)
112 | size_each_cluster[smallest_cluster_index] += 1
113 | size_each_cluster_total += 1
114 | # when the total of data points in each class is more than the required number
115 | # we need to remove samples from the largest clusters
116 | elif size_each_cluster_total > size:
117 | while size_each_cluster_total > size:
118 | largest_cluster_index = np.argmax(size_each_cluster)
119 | size_each_cluster[largest_cluster_index] -= 1
120 | size_each_cluster_total -= 1
121 | # perfect case where the total is equal to the required number
122 | else:
123 | pass
124 | else:
125 | size_each_cluster = size // num_clusters
126 |
127 | # update number of samples to select from each cluster based on the cluster population.
128 | # this is needed when some clusters do not have enough samples in them
129 | # (pop < size_each_cluster) and needs to be done iteratively until all remaining clusters
130 | # have at least size_each_cluster samples
131 | while np.any(
132 | [value <= size_each_cluster for value in pop_clusters.values() if value != 0]
133 | ):
134 | for unique_label in unique_labels:
135 | if pop_clusters[unique_label] != 0:
136 | # get index of sample labelled with unique_label
137 | cluster_ids = np.where(labels == unique_label)[0]
138 | if len(cluster_ids) <= size_each_cluster:
139 | # all samples in the cluster are selected & population becomes zero
140 | selected_ids.append(cluster_ids)
141 | pop_clusters[unique_label] = 0
142 | # update number of samples to be selected from each cluster
143 | totally_used_clusters = list(pop_clusters.values()).count(0)
144 | size_each_cluster = (size - len(np.hstack(selected_ids))) // (
145 | num_clusters - totally_used_clusters
146 | )
147 |
148 |             warnings.warn(
149 |                 f"Number of molecules in one cluster is less than"
150 |                 f" {size}/{num_clusters}.\nNumber of selected "
151 |                 f"molecules might be less than desired.\nTo avoid this "
152 |                 f"problem, try using a smaller number of clusters."
153 |             )
154 | # save the number of samples to be selected from each cluster in an array
155 | size_each_cluster = np.full(num_clusters, size_each_cluster)
156 |
157 | for unique_label, size_sub in zip(unique_labels, size_each_cluster):
158 | if pop_clusters[unique_label] != 0:
159 | # sample size_each_cluster ids from cluster labeled unique_label
160 | cluster_ids = np.where(labels == unique_label)[0]
161 | selected = self.select_from_cluster(x, size_sub, cluster_ids)
162 | selected_ids.append(cluster_ids[selected])
163 |
164 | return np.hstack(selected_ids).flatten().tolist()
165 |
166 | @abstractmethod
167 | def select_from_cluster(
168 | self, x: np.ndarray, size: int, labels: np.ndarray = None
169 | ) -> np.ndarray: # pragma: no cover
170 | """Return indices representing subset of sample points from one cluster.
171 |
172 | Parameters
173 | ----------
174 | x: ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
175 | Feature matrix of `n_samples` samples in `n_features` dimensional feature space.
176 | If fun_distance is `None`, this x is treated as a square pairwise distance matrix.
177 | size: int
178 | Number of sample points to select (i.e. size of the subset).
179 | labels: np.ndarray, optional
180 | Array of integers or strings representing the labels of the clusters that
181 | each sample belongs to. If `None`, the samples are treated as one cluster.
182 | If labels are provided, selection is made from each cluster.
183 |
184 | Returns
185 | -------
186 | selected: list
187 | Indices of the selected sample points.
188 | """
189 | raise NotImplementedError
190 |
--------------------------------------------------------------------------------
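
A worked example (illustrative values, not repository code) of the proportional allocation that `SelectionBase.select` performs before drawing from each cluster: with cluster populations 3, 6, and 9 and `size=6`, the quotas are proportional to the populations and rounded to the nearest integer.

    import numpy as np

    size = 6
    unique_label_counts = np.array([3, 6, 9])  # cluster populations
    # proportional quota per cluster, as in the proportional_selection branch
    size_each_cluster = np.round(size * unique_label_counts / unique_label_counts.sum()).astype(int)
    size_each_cluster[size_each_cluster < 1] = 1  # every cluster keeps at least one sample
    # here the quotas [1, 2, 3] already sum to 6, so no rebalancing loop runs
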
/selector/methods/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>
22 | #
23 | # --
24 |
--------------------------------------------------------------------------------
/selector/methods/tests/common.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>
22 | #
23 | # --
24 | """Common functions for test module."""
25 |
26 | from importlib import resources
27 | from typing import Any, Tuple, Union
28 |
29 | import numpy as np
30 | from sklearn.datasets import make_blobs
31 | from sklearn.metrics import pairwise_distances
32 |
33 | __all__ = [
34 | "generate_synthetic_cluster_data",
35 | "generate_synthetic_data",
36 | "get_data_file_path",
37 | ]
38 |
39 |
40 | def generate_synthetic_cluster_data():
41 |     """Generate three synthetic 2D clusters with 3, 6, and 9 points and their labels."""
42 | cluster_one = np.array([[0, 0], [0, 1], [0, 2]])
43 | # generate the second cluster with 6 points
44 | cluster_two = np.array([[3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5]])
45 | # generate the third cluster with 9 points
46 | cluster_three = np.array(
47 | [[6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6], [6, 7], [6, 8]]
48 | )
49 | # concatenate the clusters
50 | coords = np.vstack([cluster_one, cluster_two, cluster_three])
51 | # generate the labels
52 | labels = np.hstack([[0 for _ in range(3)], [1 for _ in range(6)], [2 for _ in range(9)]])
53 |
54 | return coords, labels, cluster_one, cluster_two, cluster_three
55 |
56 |
57 | def generate_synthetic_data(
58 | n_samples: int = 100,
59 | n_features: int = 2,
60 | n_clusters: int = 2,
61 | cluster_std: float = 1.0,
62 | center_box: Tuple[float, float] = (-10.0, 10.0),
63 | metric: str = "euclidean",
64 | shuffle: bool = True,
65 | random_state: int = 42,
66 | pairwise_dist: bool = False,
67 | **kwargs: Any,
68 | ) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
69 | """Generate synthetic data.
70 |
71 | Parameters
72 | ----------
73 | n_samples : int, optional
74 | The number of sample points.
75 | n_features : int, optional
76 | The number of features.
77 | n_clusters : int, optional
78 | The number of clusters.
79 | cluster_std : float, optional
80 | The standard deviation of the clusters.
81 | center_box : tuple[float, float], optional
82 | The bounding box for each cluster center when centers are generated at random.
83 | metric : str, optional
84 | The metric used for computing pairwise distances. For the supported
85 | distance matrix, please refer to
86 | https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise_distances.html.
87 | shuffle : bool, optional
88 | Whether to shuffle the samples.
89 | random_state : int, optional
90 | The random state used for generating synthetic data.
91 | pairwise_dist : bool, optional
92 | If True, then compute and return the pairwise distances between sample points.
93 | **kwargs : Any, optional
94 | Additional keyword arguments for the scikit-learn `pairwise_distances` function.
95 |
96 | Returns
97 | -------
98 | syn_data : np.ndarray
99 | The synthetic data.
100 | class_labels : np.ndarray
101 | The integer labels for cluster membership of each sample.
102 | dist: np.ndarray
103 | The symmetric pairwise distances between samples.
104 |
105 | """
106 | # pylint: disable=W0632
107 | syn_data, class_labels = make_blobs(
108 | n_samples=n_samples,
109 | n_features=n_features,
110 | centers=n_clusters,
111 | cluster_std=cluster_std,
112 | center_box=center_box,
113 | shuffle=shuffle,
114 | random_state=random_state,
115 | return_centers=False,
116 | )
117 | if pairwise_dist:
118 | dist = pairwise_distances(
119 | X=syn_data,
120 | Y=None,
121 | metric=metric,
122 | **kwargs,
123 | )
124 | return syn_data, class_labels, dist
125 | return syn_data, class_labels
126 |
127 |
128 | def get_data_file_path(file_name):
129 | """Get the absolute path of the data file inside the package.
130 |
131 | Parameters
132 | ----------
133 | file_name : str
134 | The name of the data file to load.
135 |
136 | Returns
137 | -------
138 |     pathlib.Path
139 |         The path of the data file inside the package.
140 |
141 | """
142 | data_file_path = resources.files("selector.methods.tests").joinpath(f"data/{file_name}")
143 |
144 | return data_file_path
145 |
--------------------------------------------------------------------------------
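
A minimal usage sketch of the helper above (argument values are illustrative): requesting the pairwise distance matrix alongside the synthetic coordinates and labels.

    from selector.methods.tests.common import generate_synthetic_data

    coords, labels, dist = generate_synthetic_data(
        n_samples=50, n_features=2, n_clusters=3, pairwise_dist=True
    )
    assert dist.shape == (50, 50)  # symmetric Euclidean distances by default
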
/selector/methods/tests/data/coords_imbalance_case1.txt:
--------------------------------------------------------------------------------
1 | -2.988371860898040300e+00,8.828627151534506723e+00
2 | -2.522694847790684314e+00,7.956575199242423402e+00
3 | 2.721107620929060111e+00,1.946655808491515094e+00
4 | 3.856625543891864183e+00,1.651108167735056309e+00
5 | 4.447517871446978965e+00,2.274717026274344356e+00
6 | 4.247770683095943411e+00,5.096547358086134238e-01
7 | 5.161820401844998685e+00,2.270154357173918225e+00
8 | 3.448575339025452990e+00,2.629723292574561722e+00
9 | 4.110118632461063015e+00,2.486437117054088208e+00
10 | 4.605167066522858121e+00,8.044916463211999602e-01
11 | 3.959854114649610679e+00,2.205423381101735636e+00
12 | 4.935999113292677265e+00,2.234224956120621108e+00
13 | -7.194896435791616085e+00,-6.121140372782679862e+00
14 | -6.521839830802987237e+00,-6.319325066907712340e+00
15 | -6.665533447021066316e+00,-8.125848371987935082e+00
16 | -4.564968624477761416e+00,-8.747374785867695124e+00
17 | -4.735683101825944874e+00,-6.246190570957935506e+00
18 | -7.144284024389226495e+00,-4.159940426686327797e+00
19 | -6.364591923942610308e+00,-6.366323642363737711e+00
20 | -7.769141620776792934e+00,-7.695919878241385348e+00
21 | -6.821418472705270020e+00,-8.023079891106569050e+00
22 | -7.541413655919658510e+00,-6.027676258479722549e+00
23 | -6.706446265300088250e+00,-6.494792213547110116e+00
24 | -6.406389566577725070e+00,-6.952938505932819702e+00
25 | -7.609993822868406532e+00,-6.663651003693972008e+00
26 | -5.796575947975993515e+00,-5.826307541241043886e+00
27 | -7.351559056940703663e+00,-5.791158996308579887e+00
28 | -7.364990738980373486e+00,-6.798235453889623692e+00
29 | -6.956728900565374296e+00,-6.538957618459303234e+00
30 | -6.253959843386263984e+00,-7.737267149692229395e+00
31 | -6.057567031156779969e+00,-4.983316610621999487e+00
32 | -7.594930900411238639e+00,-6.200511844341271228e+00
33 | -7.125015307154140665e+00,-7.633845757633435980e+00
34 | -7.672147929583970516e+00,-6.994846034742845831e+00
35 | -7.103089976477121148e+00,-6.166109099183854525e+00
36 | -6.602936391821250695e+00,-6.052926344239923040e+00
37 | -8.904769777808876796e+00,-6.693655278506518869e+00
38 | -8.257296559108361578e+00,-7.817934633191069516e+00
39 | -6.364579504845222502e+00,-3.027378102621225864e+00
40 | -6.834055351247456223e+00,-7.531709940881763821e+00
41 | -7.652452405688841885e+00,-7.116928200015955497e+00
42 | -7.726420909219674726e+00,-8.394956817961810813e+00
43 | -6.866625299273363403e+00,-5.426575516118630205e+00
44 | -6.374639912170812828e+00,-6.014354399105824811e+00
45 | -7.326142143218291380e+00,-6.023710798952474299e+00
46 | -6.308736680458102875e+00,-5.744543953095347710e+00
47 | -8.079923598207045643e+00,-7.214610829116894664e+00
48 | -6.193367000776756726e+00,-8.492825464465598273e+00
49 | -5.925625427658067323e+00,-6.228718341970148842e+00
50 | -7.950519689212382168e+00,-6.397637178032761440e+00
51 | -7.763484627352402967e+00,-6.726384487330419049e+00
52 | -6.815347172055806979e+00,-7.957854371205252519e+00
53 |
--------------------------------------------------------------------------------
/selector/methods/tests/data/coords_imbalance_case2.txt:
--------------------------------------------------------------------------------
1 | -2.545023662162701594e+00,1.057892978401232931e+01
2 | -3.348415146275388832e+00,8.705073752347109561e+00
3 | -3.186119623358708797e+00,9.625962417039191976e+00
4 | 6.526064737438631802e+00,2.147747496772570930e+00
5 | 5.265546183993107476e+00,1.116012127524449449e+00
6 | 3.793085118159696290e+00,4.583224592548673648e-01
7 | 4.605167066522858121e+00,8.044916463211999602e-01
8 | 3.665197166000779827e+00,2.760254287683184149e+00
9 | 4.890371686573978138e+00,2.319617893437707856e+00
10 | 3.089215405161968686e+00,2.041732658746759466e+00
11 | 4.416416050902250312e+00,2.687170178032824097e+00
12 | 3.568986338166989292e+00,2.455642099183917182e+00
13 | 4.447517871446978965e+00,2.274717026274344356e+00
14 | 5.161820401844998685e+00,2.270154357173918225e+00
15 | -6.598635323416237597e+00,-7.502809113096540194e+00
16 | -6.364591923942610308e+00,-6.366323642363737711e+00
17 | -7.351559056940703663e+00,-5.791158996308579887e+00
18 | -4.757470994138636833e+00,-5.847644332724799554e+00
19 | -7.132195342544430439e+00,-8.127892775240795231e+00
20 | -6.766109845900022179e+00,-6.217978918754900164e+00
21 | -6.680567495577800052e+00,-7.480326470434741637e+00
22 | -7.354572502312226590e+00,-7.533438825849658294e+00
23 | -4.735683101825944874e+00,-6.246190570957935506e+00
24 | -8.140511145486314604e+00,-5.962247646221170427e+00
25 | -6.374639912170812828e+00,-6.014354399105824811e+00
26 | -4.746593816495003892e+00,-8.832197392798448732e+00
27 | -7.652452405688841885e+00,-7.116928200015955497e+00
28 | -6.435807763005041870e+00,-6.105475539846610289e+00
29 | -6.308736680458102875e+00,-5.744543953095347710e+00
30 | -6.900528785115418451e+00,-6.762782209967165059e+00
31 | -6.834055351247456223e+00,-7.531709940881763821e+00
32 | -8.079923598207045643e+00,-7.214610829116894664e+00
33 | -7.672147929583970516e+00,-6.994846034742845831e+00
34 | -5.293610375005918023e+00,-8.117925092102796114e+00
35 | -7.087749441508545800e+00,-7.373110527934779945e+00
36 | -5.247215887219635277e+00,-8.310250971236579076e+00
37 | -6.364579504845222502e+00,-3.027378102621225864e+00
38 | -7.364990738980373486e+00,-6.798235453889623692e+00
39 | -7.861135842199221457e+00,-6.418006119012676258e+00
40 | -6.132333586028008376e+00,-6.269739327842481558e+00
41 | -5.925625427658067323e+00,-6.228718341970148842e+00
42 | -5.612716041964647573e+00,-7.587779058894727591e+00
43 | -5.980027315718019487e+00,-6.572810072399337677e+00
44 | -7.031412286186853322e+00,-6.291792386791370539e+00
45 | -4.564968624477761416e+00,-8.747374785867695124e+00
46 | -7.319671677848253566e+00,-6.749369015989855392e+00
47 | -5.438353902085154346e+00,-8.315971744455385561e+00
48 | -6.809825106161251362e+00,-7.265423190137706655e+00
49 | -8.904769777808876796e+00,-6.693655278506518869e+00
50 | -6.193367000776756726e+00,-8.492825464465598273e+00
51 | -8.398997157105283051e+00,-7.364343666142198153e+00
52 | -6.522611705186222686e+00,-7.573019188536600943e+00
53 | -5.716463438996310487e+00,-6.869876532256359525e+00
54 | -1.012089453122034222e+01,-7.904497234610236234e+00
55 |
--------------------------------------------------------------------------------
/selector/methods/tests/data/labels_imbalance_case1.txt:
--------------------------------------------------------------------------------
1 | 0.000000000000000000e+00
2 | 0.000000000000000000e+00
3 | 1.000000000000000000e+00
4 | 1.000000000000000000e+00
5 | 1.000000000000000000e+00
6 | 1.000000000000000000e+00
7 | 1.000000000000000000e+00
8 | 1.000000000000000000e+00
9 | 1.000000000000000000e+00
10 | 1.000000000000000000e+00
11 | 1.000000000000000000e+00
12 | 1.000000000000000000e+00
13 | 2.000000000000000000e+00
14 | 2.000000000000000000e+00
15 | 2.000000000000000000e+00
16 | 2.000000000000000000e+00
17 | 2.000000000000000000e+00
18 | 2.000000000000000000e+00
19 | 2.000000000000000000e+00
20 | 2.000000000000000000e+00
21 | 2.000000000000000000e+00
22 | 2.000000000000000000e+00
23 | 2.000000000000000000e+00
24 | 2.000000000000000000e+00
25 | 2.000000000000000000e+00
26 | 2.000000000000000000e+00
27 | 2.000000000000000000e+00
28 | 2.000000000000000000e+00
29 | 2.000000000000000000e+00
30 | 2.000000000000000000e+00
31 | 2.000000000000000000e+00
32 | 2.000000000000000000e+00
33 | 2.000000000000000000e+00
34 | 2.000000000000000000e+00
35 | 2.000000000000000000e+00
36 | 2.000000000000000000e+00
37 | 2.000000000000000000e+00
38 | 2.000000000000000000e+00
39 | 2.000000000000000000e+00
40 | 2.000000000000000000e+00
41 | 2.000000000000000000e+00
42 | 2.000000000000000000e+00
43 | 2.000000000000000000e+00
44 | 2.000000000000000000e+00
45 | 2.000000000000000000e+00
46 | 2.000000000000000000e+00
47 | 2.000000000000000000e+00
48 | 2.000000000000000000e+00
49 | 2.000000000000000000e+00
50 | 2.000000000000000000e+00
51 | 2.000000000000000000e+00
52 | 2.000000000000000000e+00
53 |
--------------------------------------------------------------------------------
/selector/methods/tests/data/labels_imbalance_case2.txt:
--------------------------------------------------------------------------------
1 | 0.000000000000000000e+00
2 | 0.000000000000000000e+00
3 | 0.000000000000000000e+00
4 | 1.000000000000000000e+00
5 | 1.000000000000000000e+00
6 | 1.000000000000000000e+00
7 | 1.000000000000000000e+00
8 | 1.000000000000000000e+00
9 | 1.000000000000000000e+00
10 | 1.000000000000000000e+00
11 | 1.000000000000000000e+00
12 | 1.000000000000000000e+00
13 | 1.000000000000000000e+00
14 | 1.000000000000000000e+00
15 | 2.000000000000000000e+00
16 | 2.000000000000000000e+00
17 | 2.000000000000000000e+00
18 | 2.000000000000000000e+00
19 | 2.000000000000000000e+00
20 | 2.000000000000000000e+00
21 | 2.000000000000000000e+00
22 | 2.000000000000000000e+00
23 | 2.000000000000000000e+00
24 | 2.000000000000000000e+00
25 | 2.000000000000000000e+00
26 | 2.000000000000000000e+00
27 | 2.000000000000000000e+00
28 | 2.000000000000000000e+00
29 | 2.000000000000000000e+00
30 | 2.000000000000000000e+00
31 | 2.000000000000000000e+00
32 | 2.000000000000000000e+00
33 | 2.000000000000000000e+00
34 | 2.000000000000000000e+00
35 | 2.000000000000000000e+00
36 | 2.000000000000000000e+00
37 | 2.000000000000000000e+00
38 | 2.000000000000000000e+00
39 | 2.000000000000000000e+00
40 | 2.000000000000000000e+00
41 | 2.000000000000000000e+00
42 | 2.000000000000000000e+00
43 | 2.000000000000000000e+00
44 | 2.000000000000000000e+00
45 | 2.000000000000000000e+00
46 | 2.000000000000000000e+00
47 | 2.000000000000000000e+00
48 | 2.000000000000000000e+00
49 | 2.000000000000000000e+00
50 | 2.000000000000000000e+00
51 | 2.000000000000000000e+00
52 | 2.000000000000000000e+00
53 | 2.000000000000000000e+00
54 | 2.000000000000000000e+00
55 |
--------------------------------------------------------------------------------
/selector/methods/tests/test_distance.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>.
22 | #
23 | # --
24 | """Test selector/methods/distance.py."""
25 |
26 | import numpy as np
27 | import pytest
28 | from numpy.testing import assert_equal, assert_raises
29 | from scipy.spatial.distance import pdist, squareform
30 | from sklearn.metrics import pairwise_distances
31 |
32 | from selector.methods.distance import DISE, MaxMin, MaxSum, OptiSim
33 | from selector.methods.tests.common import (
34 | generate_synthetic_cluster_data,
35 | generate_synthetic_data,
36 | get_data_file_path,
37 | )
38 |
39 |
40 | def test_maxmin():
41 | """Testing the MaxMin class."""
42 | # generate random data points belonging to one cluster - pairwise distance matrix
43 | _, _, arr_dist = generate_synthetic_data(
44 | n_samples=100,
45 | n_features=2,
46 | n_clusters=1,
47 | pairwise_dist=True,
48 | metric="euclidean",
49 | random_state=42,
50 | )
51 |
52 | # generate random data points belonging to multiple clusters - class labels and pairwise distance matrix
53 | _, class_labels_cluster, arr_dist_cluster = generate_synthetic_data(
54 | n_samples=100,
55 | n_features=2,
56 | n_clusters=3,
57 | pairwise_dist=True,
58 | metric="euclidean",
59 | random_state=42,
60 | )
61 |
62 | # use MaxMin algorithm to select points from clustered data
63 | collector = MaxMin()
64 | selected_ids = collector.select(
65 | arr_dist_cluster,
66 | size=12,
67 | labels=class_labels_cluster,
68 | proportional_selection=False,
69 | )
70 | # make sure all the selected indices are the same with expectation
71 | assert_equal(selected_ids, [41, 34, 94, 85, 51, 50, 66, 78, 21, 64, 29, 83])
72 |
73 | # use MaxMin algorithm to select points from non-clustered data
74 | collector = MaxMin()
75 | selected_ids = collector.select(arr_dist, size=12)
76 | # make sure all the selected indices are the same with expectation
77 | assert_equal(selected_ids, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
78 |
79 | # use MaxMin algorithm to select points from non-clustered data with "medoid" as the reference point
80 | collector_medoid_1 = MaxMin(ref_index=None)
81 | selected_ids_medoid_1 = collector_medoid_1.select(arr_dist, size=12)
82 | # make sure all the selected indices are the same with expectation
83 | assert_equal(selected_ids_medoid_1, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
84 |
85 | # use MaxMin algorithm to select points from non-clustered data with "None" for the reference point
86 | collector_medoid_2 = MaxMin(ref_index=None)
87 | selected_ids_medoid_2 = collector_medoid_2.select(arr_dist, size=12)
88 | # make sure all the selected indices are the same with expectation
89 | assert_equal(selected_ids_medoid_2, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
90 |
91 |     # use MaxMin algorithm to select points from non-clustered data with an integer as the reference point
92 | collector_float = MaxMin(ref_index=85)
93 | selected_ids_float = collector_float.select(arr_dist, size=12)
94 | # make sure all the selected indices are the same with expectation
95 | assert_equal(selected_ids_float, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
96 |
97 |     # use MaxMin algorithm to select points from non-clustered data with a predefined list as the reference points
98 | collector_float = MaxMin(ref_index=[85, 57, 41, 25])
99 | selected_ids_float = collector_float.select(arr_dist, size=12)
100 | # make sure all the selected indices are the same with expectation
101 | assert_equal(selected_ids_float, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
102 |
103 | # test failing case when ref_index is not a valid index
104 | with pytest.raises(ValueError):
105 | collector_float = MaxMin(ref_index=-3)
106 | selected_ids_float = collector_float.select(arr_dist, size=12)
107 | # test failing case when ref_index contains a complex number
108 | with pytest.raises(ValueError):
109 | collector_float = MaxMin(ref_index=[1 + 5j, 2, 5])
110 | _ = collector_float.select(arr_dist, size=12)
111 | # test failing case when ref_index contains a negative number
112 | with pytest.raises(ValueError):
113 | collector_float = MaxMin(ref_index=[-1, 2, 5])
114 | _ = collector_float.select(arr_dist, size=12)
115 |
116 | # test failing case when the number of labels is not equal to the number of samples
117 | with pytest.raises(ValueError):
118 | collector_float = MaxMin(ref_index=85)
119 | _ = collector_float.select(
120 | arr_dist, size=12, labels=class_labels_cluster[:90], proportional_selection=False
121 | )
122 |
123 | # use MaxMin algorithm, this time instantiating with a distance metric
124 | collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"))
125 | simple_coords = np.array([[0, 0], [2, 0], [0, 2], [2, 2], [-10, -10]])
126 | # provide coordinates rather than pairwise distance matrix to collector
127 | selected_ids = collector.select(x=simple_coords, size=3)
128 | # make sure all the selected indices are the same with expectation
129 | assert_equal(selected_ids, [0, 4, 3])
130 |
131 | # generating mocked clusters
132 | np.random.seed(42)
133 | cluster_one = np.random.normal(0, 1, (3, 2))
134 | cluster_two = np.random.normal(10, 1, (6, 2))
135 | cluster_three = np.random.normal(20, 1, (10, 2))
136 | labels_mocked = np.hstack(
137 | [[0 for i in range(3)], [1 for i in range(6)], [2 for i in range(10)]]
138 | )
139 | mocked_cluster_coords = np.vstack([cluster_one, cluster_two, cluster_three])
140 |
141 | # selecting molecules
142 | collector = MaxMin(lambda x: pairwise_distances(x, metric="euclidean"))
143 | selected_mocked = collector.select(
144 | mocked_cluster_coords,
145 | size=15,
146 | labels=labels_mocked,
147 | proportional_selection=False,
148 | )
149 | assert_equal(selected_mocked, [0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 15, 10, 13, 9, 18])
150 |
151 |
152 | def test_maxmin_proportional_selection():
153 | """Test MaxMin class with proportional selection."""
154 |     # generate synthetic data with three clusters
155 |     coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data()
156 |     # instantiate the MaxMin class
157 |     collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0)
158 |     # select 6 points in total, proportionally across the clusters
159 | selected_ids = collector.select(
160 | coords,
161 | size=6,
162 | labels=labels,
163 | proportional_selection=True,
164 | )
165 | # make sure all the selected indices are the same with expectation
166 | assert_equal(selected_ids, [0, 3, 8, 9, 17, 13])
167 | # check how many points are selected from each cluster
168 | assert_equal(len(selected_ids), 6)
169 | # check the number of points selected from cluster one
170 | assert_equal((labels[selected_ids] == 0).sum(), 1)
171 | # check the number of points selected from cluster two
172 | assert_equal((labels[selected_ids] == 1).sum(), 2)
173 | # check the number of points selected from cluster three
174 | assert_equal((labels[selected_ids] == 2).sum(), 3)
175 |
176 |
177 | def test_maxmin_proportional_selection_imbalance_1():
178 | """Test MaxMin class with proportional selection with imbalance case 1."""
179 | # load three-cluster data from file
180 | # 2 from class 0, 10 from class 1, 40 from class 2
181 | coords_file_path = get_data_file_path("coords_imbalance_case1.txt")
182 | coords = np.genfromtxt(coords_file_path, delimiter=",", skip_header=0)
183 | labels_file_path = get_data_file_path("labels_imbalance_case1.txt")
184 | labels = np.genfromtxt(labels_file_path, delimiter=",", skip_header=0)
185 |
186 | # instantiate the MaxMin class
187 | collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0)
188 |     # select 9 points with proportional selection across the clusters
189 | selected_ids = collector.select(
190 | coords,
191 | size=9,
192 | labels=labels,
193 | proportional_selection=True,
194 | )
195 |
196 | # make sure all the selected indices are the same with expectation
197 | assert_equal(selected_ids, [0, 2, 6, 12, 15, 38, 16, 41, 36])
198 | # check how many points are selected from each cluster
199 | assert_equal(len(selected_ids), 9)
200 | # check the number of points selected from cluster one
201 | assert_equal((labels[selected_ids] == 0).sum(), 1)
202 | # check the number of points selected from cluster two
203 | assert_equal((labels[selected_ids] == 1).sum(), 2)
204 | # check the number of points selected from cluster three
205 | assert_equal((labels[selected_ids] == 2).sum(), 6)
206 |
207 |
208 | def test_maxmin_proportional_selection_imbalance_2():
209 | """Test MaxMin class with proportional selection with imbalance case 2."""
210 | # load three-cluster data from file
211 | # 3 from class 0, 11 from class 1, 40 from class 2
212 | coords_file_path = get_data_file_path("coords_imbalance_case2.txt")
213 | coords = np.genfromtxt(coords_file_path, delimiter=",", skip_header=0)
214 | labels_file_path = get_data_file_path("labels_imbalance_case2.txt")
215 | labels = np.genfromtxt(labels_file_path, delimiter=",", skip_header=0)
216 |
217 | # instantiate the MaxMin class
218 | collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0)
219 |     # select 14 points with proportional selection across the clusters
220 | selected_ids = collector.select(
221 | coords,
222 | size=14,
223 | labels=labels,
224 | proportional_selection=True,
225 | )
226 |
227 |     # make sure all the selected indices are the same with expectation
228 |     assert_equal(selected_ids, [0, 3, 9, 6, 14, 36, 53, 17, 44, 23, 28, 50, 52, 49])
229 |
230 | # check how many points are selected from each cluster
231 | assert_equal(len(selected_ids), 14)
232 | # check the number of points selected from cluster one
233 | assert_equal((labels[selected_ids] == 0).sum(), 1)
234 | # check the number of points selected from cluster two
235 | assert_equal((labels[selected_ids] == 1).sum(), 3)
236 | # check the number of points selected from cluster three
237 | assert_equal((labels[selected_ids] == 2).sum(), 10)
238 |
239 |
240 | def test_maxmin_invalid_input():
241 | """Testing MaxMin class with invalid input."""
242 | # case when the distance matrix is not square
243 | x_dist = np.array([[0, 1], [1, 0], [4, 9]])
244 | with pytest.raises(ValueError):
245 | collector = MaxMin(ref_index=0)
246 | _ = collector.select(x_dist, size=2)
247 |
248 | # case when the distance matrix is not symmetric
249 | x_dist = np.array([[0, 1, 2, 1], [1, 1, 0, 3], [4, 9, 4, 0], [6, 5, 6, 7]])
250 | with pytest.raises(ValueError):
251 | collector = MaxMin(ref_index=0)
252 | _ = collector.select(x_dist, size=2)
253 |
254 |
255 | def test_maxsum_clustered_data():
256 | """Testing MaxSum class."""
257 | # generate random data points belonging to multiple clusters - coordinates and class labels
258 | coords_cluster, class_labels_cluster, coords_cluster_dist = generate_synthetic_data(
259 | n_samples=100,
260 | n_features=2,
261 | n_clusters=3,
262 | pairwise_dist=True,
263 | metric="euclidean",
264 | random_state=42,
265 | )
266 |
267 | # use MaxSum algorithm to select points from clustered data, instantiating with euclidean distance metric
268 | collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=None)
269 | selected_ids = collector.select(
270 | coords_cluster,
271 | size=12,
272 | labels=class_labels_cluster,
273 | proportional_selection=False,
274 | )
275 | # make sure all the selected indices are the same with expectation
276 | assert_equal(selected_ids, [41, 34, 85, 94, 51, 50, 78, 66, 21, 64, 0, 83])
277 |
278 | # use MaxSum algorithm to select points from clustered data without instantiating with euclidean distance metric
279 | collector = MaxSum(ref_index=None)
280 | selected_ids = collector.select(
281 | coords_cluster_dist,
282 | size=12,
283 | labels=class_labels_cluster,
284 | proportional_selection=False,
285 | )
286 | # make sure all the selected indices are the same with expectation
287 | assert_equal(selected_ids, [41, 34, 85, 94, 51, 50, 78, 66, 21, 64, 0, 83])
288 |
289 | # check that ValueError is raised when number of points requested is greater than number of points in array
290 | with pytest.raises(ValueError):
291 | _ = collector.select_from_cluster(
292 | coords_cluster,
293 | size=101,
294 | labels=class_labels_cluster,
295 | )
296 |
297 |
298 | def test_maxsum_non_clustered_data():
299 | """Testing MaxSum class with non-clustered data."""
300 | # generate random data points belonging to one cluster - coordinates
301 | coords, _, _ = generate_synthetic_data(
302 | n_samples=100,
303 | n_features=2,
304 | n_clusters=1,
305 | pairwise_dist=True,
306 | metric="euclidean",
307 | random_state=42,
308 | )
309 | # use MaxSum algorithm to select points from non-clustered data, instantiating with euclidean distance metric
310 | collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=None)
311 | selected_ids = collector.select(coords, size=12)
312 | # make sure all the selected indices are the same with expectation
313 | assert_equal(selected_ids, [85, 57, 25, 41, 95, 9, 21, 8, 13, 68, 37, 54])
314 |
315 | # use MaxSum algorithm to select points from non-clustered data, instantiating with euclidean
316 | # distance metric and using "medoid" as the reference point
317 | collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=None)
318 | selected_ids = collector.select(coords, size=12)
319 | # make sure all the selected indices are the same with expectation
320 | assert_equal(selected_ids, [85, 57, 25, 41, 95, 9, 21, 8, 13, 68, 37, 54])
321 |
322 | # use MaxSum algorithm to select points from non-clustered data, instantiating with euclidean
323 | # distance metric and using a list as the reference points
324 | collector = MaxSum(
325 | fun_dist=lambda x: pairwise_distances(x, metric="euclidean"),
326 | ref_index=[85, 57, 25, 41, 95],
327 | )
328 | selected_ids = collector.select(coords, size=12)
329 | # make sure all the selected indices are the same with expectation
330 | assert_equal(selected_ids, [85, 57, 25, 41, 95, 9, 21, 8, 13, 68, 37, 54])
331 |
332 | # use MaxSum algorithm to select points from non-clustered data, instantiating with euclidean
333 | # distance metric and using an invalid reference point
334 | with pytest.raises(ValueError):
335 | collector = MaxSum(
336 | fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=-1
337 | )
338 | selected_ids = collector.select(coords, size=12)
339 |
340 |
341 | def test_maxsum_invalid_input():
342 | """Testing MaxSum class with invalid input."""
343 | # case when the distance matrix is not square
344 | x_dist = np.array([[0, 1], [1, 0], [4, 9]])
345 | with pytest.raises(ValueError):
346 | collector = MaxSum(ref_index=0)
347 | _ = collector.select(x_dist, size=2)
348 |
349 | # case when the distance matrix is not square
350 | x_dist = np.array([[0, 1, 2], [1, 0, 3], [4, 9, 0], [5, 6, 7]])
351 | with pytest.raises(ValueError):
352 | collector = MaxSum(ref_index=0)
353 | _ = collector.select(x_dist, size=2)
354 |
355 | # case when the distance matrix is not symmetric
356 | x_dist = np.array([[0, 1, 2, 1], [1, 1, 0, 3], [4, 9, 4, 0], [6, 5, 6, 7]])
357 | with pytest.raises(ValueError):
358 | collector = MaxSum(ref_index=0)
359 | _ = collector.select(x_dist, size=2)
360 |
361 |
362 | def test_maxsum_proportional_selection():
363 | """Test MaxSum class with proportional selection."""
364 |     # generate synthetic data with three clusters
365 |     coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data()
366 |     # instantiate the MaxSum class
367 |     collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0)
368 |     # select 6 points in total, proportionally across the clusters
369 | selected_ids = collector.select(
370 | coords,
371 | size=6,
372 | labels=labels,
373 | proportional_selection=True,
374 | )
375 | # make sure all the selected indices are the same with expectation
376 | assert_equal(selected_ids, [0, 3, 8, 9, 17, 10])
377 | # check how many points are selected from each cluster
378 | assert_equal(len(selected_ids), 6)
379 | # check the number of points selected from cluster one
380 | assert_equal((labels[selected_ids] == 0).sum(), 1)
381 | # check the number of points selected from cluster two
382 | assert_equal((labels[selected_ids] == 1).sum(), 2)
383 | # check the number of points selected from cluster three
384 | assert_equal((labels[selected_ids] == 2).sum(), 3)
385 |
386 |
387 | def test_optisim():
388 | """Testing OptiSim class."""
389 | # generate random data points belonging to one cluster - coordinates and pairwise distance matrix
390 | coords, _, arr_dist = generate_synthetic_data(
391 | n_samples=100,
392 | n_features=2,
393 | n_clusters=1,
394 | pairwise_dist=True,
395 | metric="euclidean",
396 | random_state=42,
397 | )
398 |
399 | # generate random data points belonging to multiple clusters - coordinates and class labels
400 | coords_cluster, class_labels_cluster, _ = generate_synthetic_data(
401 | n_samples=100,
402 | n_features=2,
403 | n_clusters=3,
404 | pairwise_dist=True,
405 | metric="euclidean",
406 | random_state=42,
407 | )
408 |
409 | # use OptiSim algorithm to select points from clustered data
410 | collector = OptiSim(ref_index=0)
411 | selected_ids = collector.select(coords_cluster, size=12, labels=class_labels_cluster)
412 |     # expected selection for the clustered data (assertion currently disabled)
413 |     # assert_equal(selected_ids, [2, 85, 86, 59, 1, 66, 50, 68, 0, 64, 83, 72])
414 |
415 | # use OptiSim algorithm to select points from non-clustered data
416 | collector = OptiSim(ref_index=0)
417 | selected_ids = collector.select(coords, size=12)
418 | # make sure all the selected indices are the same with expectation
419 | assert_equal(selected_ids, [0, 8, 55, 37, 41, 13, 12, 42, 6, 30, 57, 76])
420 |
421 |     # check that OptiSim gives the same results as MaxMin as k -> infinity
422 | collector = OptiSim(ref_index=85, k=999999)
423 | selected_ids_optisim = collector.select(coords, size=12)
424 | collector = MaxMin()
425 | selected_ids_maxmin = collector.select(arr_dist, size=12)
426 | assert_equal(selected_ids_optisim, selected_ids_maxmin)
427 |
428 | # test with invalid ref_index
429 | with pytest.raises(ValueError):
430 | collector = OptiSim(ref_index=10000)
431 | _ = collector.select(coords, size=12)
432 |
433 |
434 | def test_optisim_proportional_selection():
435 | """Test OptiSim class with proportional selection."""
436 |     # generate synthetic data with three clusters
437 |     coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data()
438 |     # instantiate the OptiSim class
439 |     collector = OptiSim(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0)
440 |     # select 6 points in total, proportionally across the clusters
441 | selected_ids = collector.select(
442 | coords,
443 | size=6,
444 | labels=labels,
445 | proportional_selection=True,
446 | )
447 | # make sure all the selected indices are the same with expectation
448 | assert_equal(selected_ids, [0, 3, 8, 9, 17, 13])
449 | # check how many points are selected from each cluster
450 | assert_equal(len(selected_ids), 6)
451 | # check the number of points selected from cluster one
452 | assert_equal((labels[selected_ids] == 0).sum(), 1)
453 | # check the number of points selected from cluster two
454 | assert_equal((labels[selected_ids] == 1).sum(), 2)
455 | # check the number of points selected from cluster three
456 | assert_equal((labels[selected_ids] == 2).sum(), 3)
457 |
458 |
459 | def test_directed_sphere_size_error():
460 | """Test DirectedSphereExclusion error when too many points requested."""
461 | x = np.array([[1, 9]] * 100)
462 | collector = DISE()
463 | assert_raises(ValueError, collector.select, x, size=105)
464 |
465 |
466 | def test_directed_sphere_same_number_of_pts():
467 | """Test DirectSphereExclusion with `size` = number of points in dataset."""
468 | # (0,0) as the reference point
469 | x = np.array([[0, 0], [0, 1], [0, 2], [0, 3]])
470 | collector = DISE(r0=1, tol=0, ref_index=0)
471 | selected = collector.select(x, size=2)
472 | assert_equal(selected, [0, 2])
473 | assert_equal(collector.r, 1)
474 |
475 |
476 | def test_directed_sphere_same_number_of_pts_None():
477 | """Test DirectSphereExclusion with `size` = number of points in dataset with the ref_index None."""
478 | # None as the reference point
479 | x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]])
480 | collector = DISE(r0=1, tol=0, ref_index=None)
481 | selected = collector.select(x, size=3)
482 | assert_equal(selected, [2, 0, 4])
483 | assert_equal(collector.r, 1)
484 |
485 |
486 | def test_directed_sphere_exclusion_select_more_number_of_pts():
487 | """Test DirectSphereExclusion on points on the line with `size` < number of points in dataset."""
488 | # (0,0) as the reference point
489 | x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]])
490 | collector = DISE(r0=0.5, tol=0, ref_index=0)
491 | selected = collector.select(x, size=3)
492 | expected = [0, 3, 6]
493 | assert_equal(selected, expected)
494 | assert_equal(collector.r, 2.0)
495 |
496 |
497 | def test_directed_sphere_exclusion_on_line_with_smaller_radius():
498 | """Test Direct Sphere Exclusion on points on line with smaller distribution than the radius."""
499 | # (0,0) as the reference point
500 | x = np.array(
501 | [
502 | [0, 0],
503 | [0, 1],
504 | [0, 1.1],
505 | [0, 1.2],
506 | [0, 2],
507 | [0, 3],
508 | [0, 3.1],
509 | [0, 3.2],
510 | [0, 4],
511 | [0, 5],
512 | [0, 6],
513 | ]
514 | )
515 | collector = DISE(r0=0.5, tol=1, ref_index=1)
516 | selected = collector.select(x, size=3)
517 | expected = [1, 5, 9]
518 | assert_equal(selected, expected)
519 | assert_equal(collector.r, 1.0)
520 |
521 |
522 | def test_directed_sphere_on_line_with_larger_radius():
523 | """Test Direct Sphere Exclusion on points on the line with a too large radius size."""
524 | # (0,0) as the reference point
525 | x = np.array(
526 | [
527 | [0, 0],
528 | [0, 1],
529 | [0, 1.1],
530 | [0, 1.2],
531 | [0, 2],
532 | [0, 3],
533 | [0, 3.1],
534 | [0, 3.2],
535 | [0, 4],
536 | [0, 5],
537 | ]
538 | )
539 | collector = DISE(r0=2.0, tol=0, p=2.0, ref_index=1)
540 | selected = collector.select(x, size=3)
541 | expected = [1, 5, 9]
542 | assert_equal(selected, expected)
543 | assert_equal(collector.r, 1.0)
544 |
545 |
546 | def test_directed_sphere_dist_func():
547 | """Test Direct Sphere Exclusion with a distance function."""
548 | # (0,0) as the reference point
549 | x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]])
550 | collector = DISE(
551 | r0=0.5,
552 | tol=0,
553 | ref_index=0,
554 | fun_dist=lambda x: squareform(pdist(x, metric="minkowski", p=0.1)),
555 | )
556 | selected = collector.select(x, size=3)
557 | expected = [0, 3, 6]
558 | assert_equal(selected, expected)
559 | assert_equal(collector.r, 2.0)
560 |
561 |
562 | def test_directed_sphere_proportional_selection():
563 | """Test DISE class with proportional selection."""
564 |     # generate synthetic data with three clusters
565 |     coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data()
566 |     # instantiate the DISE class
567 |     collector = DISE(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0)
568 |     # select 6 points in total, proportionally across the clusters
569 | selected_ids = collector.select(
570 | coords,
571 | size=6,
572 | labels=labels,
573 | proportional_selection=True,
574 | )
575 | # make sure all the selected indices are the same with expectation
576 | assert_equal(selected_ids, [0, 3, 7, 9, 12, 15])
577 | # check how many points are selected from each cluster
578 | assert_equal(len(selected_ids), 6)
579 | # check the number of points selected from cluster one
580 | assert_equal((labels[selected_ids] == 0).sum(), 1)
581 | # check the number of points selected from cluster two
582 | assert_equal((labels[selected_ids] == 1).sum(), 2)
583 | # check the number of points selected from cluster three
584 | assert_equal((labels[selected_ids] == 2).sum(), 3)
585 |
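586 | # Note: a minimal sketch of the equivalence exercised above. Instantiating a
587 | # selector with `fun_dist=lambda x: pairwise_distances(x, metric="euclidean")`
588 | # and passing coordinates should match passing a precomputed distance matrix:
589 | #
590 | #     coords, _, arr_dist = generate_synthetic_data(
591 | #         n_samples=100, n_features=2, n_clusters=1,
592 | #         pairwise_dist=True, metric="euclidean", random_state=42,
593 | #     )
594 | #     ids_precomputed = MaxMin().select(arr_dist, size=12)
595 | #     ids_computed = MaxMin(
596 | #         fun_dist=lambda x: pairwise_distances(x, metric="euclidean")
597 | #     ).select(coords, size=12)
598 | #     assert_equal(ids_precomputed, ids_computed)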
--------------------------------------------------------------------------------
/selector/methods/tests/test_partition.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>.
22 | #
23 | # --
24 | """Test Partition-Based Selection Methods."""
25 |
26 | import numpy as np
27 | import pytest
28 | from numpy.testing import assert_equal, assert_raises
29 |
30 | from selector.methods.partition import GridPartition, Medoid
31 | from selector.methods.tests.common import generate_synthetic_data
32 |
33 |
34 | @pytest.mark.parametrize("numb_pts", [4])
35 | @pytest.mark.parametrize("method", ["equifrequent", "equisized"])
36 | def test_grid_partitioning_independent_on_simple_example(numb_pts, method):
37 | r"""Test grid partitioning on an example each bin has only one point."""
38 | # Construct feature array where each molecule is known to be in which bin
39 | # The grid is a uniform grid from 0 to 10 in X-axis and 0 to 11 on y-axis.
40 | x = np.linspace(0, 3, numb_pts)
41 | y = np.linspace(0, 3, numb_pts)
42 | X, Y = np.meshgrid(x, y)
43 | grid = np.array([1.0, 0.0]) + np.vstack([X.ravel(), Y.ravel()]).T
44 | # Make one bin have an extra point
45 | grid = np.vstack((grid, np.array([1.1, 0.0])))
46 |
47 | # Here the number of cells should be equal to the number of points in each dimension
48 | # excluding the extra point, so that the answer is unique/known.
49 | collector = GridPartition(nbins_axis=4, bin_method=f"{method}_independent")
50 | # Sort the points so that they're comparable to the expected answer.
51 | selected_ids = np.sort(collector.select(grid, size=len(grid) - 1))
52 | expected = np.arange(len(grid) - 1)
53 | assert_equal(selected_ids, expected)
54 |
55 |
56 | def test_grid_partitioning_equisized_dependent_on_simple_example():
57 | r"""Test equisized_dependent grid partitioning on example that is different from independent."""
58 | # Construct feature array where each molecule is known to be in which bin
59 | grid = np.array(
60 | [
61 | [0.0, 0.0], # Corresponds to bin (0, 0)
62 | [0.0, 4.0], # Corresponds to bin (0, 3)
63 | [1.0, 1.0], # Corresponds to bin (1, 0)
64 | [1.0, 0.9], # Corresponds to bin (1, 0)
65 | [1.0, 2.0], # Corresponds to bin (1, 3)
66 | [2.0, 0.0], # Corresponds to bin (2, 0)
67 | [2.0, 4.0], # Corresponds to bin (2, 3)
68 | [3.0, 0.0], # Corresponds to bin (3, 0)
69 | [3.0, 4.0], # Corresponds to bin (3, 3)
70 | [3.0, 3.9], # Corresponds to bin (3, 3)
71 | ]
72 | )
73 |
74 |     # The number of bins is chosen so that there is approximately a single point in each bin
75 | collector = GridPartition(nbins_axis=4, bin_method="equisized_dependent")
76 |     # Two bins have an extra point in them and thus more diversity than the other bins,
77 |     # so the two selected molecules should come from those bins.
78 | selected_ids = collector.select(grid, size=2, labels=None)
79 | right_molecules = True
80 | if not (2 in selected_ids or 3 in selected_ids):
81 | right_molecules = False
82 | if not (8 in selected_ids or 9 in selected_ids):
83 | right_molecules = False
84 |     assert right_molecules, "The expected points were not selected"
85 |
86 |
87 | @pytest.mark.parametrize("numb_pts", [4])
88 | def test_grid_partitioning_equifrequent_dependent_on_simple_example(numb_pts):
89 | r"""Test equifrequent dependent grid partitioning on an example where each bin has only one point."""
90 |     # Construct a feature array where the bin of each molecule is known
91 |     # The grid is a uniform grid from 0 to 3 along each axis, shifted by [1, 0].
92 | x = np.linspace(0, 3, numb_pts)
93 | y = np.linspace(0, 3, numb_pts)
94 | X, Y = np.meshgrid(x, y)
95 | grid = np.array([1.0, 0.0]) + np.vstack([X.ravel(), Y.ravel()]).T
96 | # Make one bin have an extra point
97 | grid = np.vstack((grid, np.array([1.1, 0.0])))
98 |
99 | # Here the number of cells should be equal to the number of points in each dimension
100 | # excluding the extra point, so that the answer is unique/known.
101 | collector = GridPartition(nbins_axis=numb_pts, bin_method="equifrequent_dependent")
102 | # Sort the points so that they're comparable to the expected answer.
103 | selected_ids = np.sort(collector.select(grid, size=len(grid) - 1))
104 | expected = np.arange(len(grid) - 1)
105 | assert_equal(selected_ids, expected)
106 |
107 |
108 | @pytest.mark.parametrize("numb_pts", [10, 20, 30])
109 | @pytest.mark.parametrize("method", ["equifrequent", "equisized"])
110 | def test_bins_from_both_methods_dependent_same_as_independent_on_uniform_grid(numb_pts, method):
111 | r"""Test bins is the same between the two equisized methods on uniform grid in three-dimensions."""
112 | x = np.linspace(0, 10, numb_pts)
113 | y = np.linspace(0, 11, numb_pts)
114 | X = np.meshgrid(x, y, y)
115 | grid = np.vstack(list(map(np.ravel, X))).T
116 | grid = np.array([1.0, 0.0, 0.0]) + grid
117 |
118 |     # Use the number of points in each dimension as the number of bins per axis
119 |     # so that the bins from the two variants are directly comparable.
120 | collector_indept = GridPartition(nbins_axis=numb_pts, bin_method=f"{method}_independent")
121 | collector_depend = GridPartition(nbins_axis=numb_pts, bin_method=f"{method}_dependent")
122 |
123 | # Get the bins from the method
124 | bins_indept = collector_indept.get_bins_from_method(grid)
125 | bins_dept = collector_depend.get_bins_from_method(grid)
126 |
127 | # Test the bins are the same
128 | for key in bins_indept.keys():
129 | assert_equal(bins_dept[key], bins_indept[key])
130 |
131 |
132 | def test_raises_grid_partitioning():
133 | r"""Test raises error for grid partitioning."""
134 | grid = np.random.uniform(0.0, 1.0, size=(10, 3))
135 |
136 |     assert_raises(TypeError, GridPartition, 5.0)  # nbins_axis must be an integer
137 |     assert_raises(TypeError, GridPartition, 5, 5.0)  # bin_method must be a string
138 |     assert_raises(TypeError, GridPartition, 5, "string", [])  # random seed must be an integer
139 |
140 | # Test the collector grid method is not the correct string
141 | collector = GridPartition(nbins_axis=5, bin_method="string")
142 | assert_raises(ValueError, collector.select_from_cluster, grid, 5)
143 |
144 | collector = GridPartition(nbins_axis=5)
145 | assert_raises(TypeError, collector.select_from_cluster, [5.0], 5) # Test X is numpy array
146 | assert_raises(
147 | TypeError, collector.select_from_cluster, grid, 5.0
148 | ) # Test number selected should be int
149 | assert_raises(TypeError, collector.select_from_cluster, grid, 5, [5.0])
150 |
151 |
152 | def test_medoid():
153 | """Testing Medoid class."""
154 | coords, _, _ = generate_synthetic_data(
155 | n_samples=100,
156 | n_features=2,
157 | n_clusters=1,
158 | pairwise_dist=True,
159 | metric="euclidean",
160 | random_state=42,
161 | )
162 |
163 | coords_cluster, class_labels_cluster, _ = generate_synthetic_data(
164 | n_samples=100,
165 | n_features=2,
166 | n_clusters=3,
167 | pairwise_dist=True,
168 | metric="euclidean",
169 | random_state=42,
170 | )
171 | collector = Medoid()
172 | selected_ids = collector.select(coords_cluster, size=12, labels=class_labels_cluster)
173 | # make sure all the selected indices are the same with expectation
174 | assert_equal(selected_ids, [2, 73, 94, 86, 1, 50, 93, 78, 0, 54, 33, 72])
175 |
176 | collector = Medoid()
177 | selected_ids = collector.select(coords, size=12)
178 | # make sure all the selected indices are the same with expectation
179 | assert_equal(selected_ids, [0, 95, 57, 41, 25, 9, 8, 6, 66, 1, 42, 82])
180 |
181 |     # test the case where the KD-Tree query returns an integer
182 | features = np.array([[1.5, 2.8], [2.3, 3.8], [1.5, 2.8], [4.0, 5.9]])
183 | selector = Medoid()
184 | selected_ids = selector.select(features, size=2)
185 | assert_equal(selected_ids, [0, 3])
186 |
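187 | # Note: the `bin_method` strings exercised in this module are the four variants
188 | # "equisized_independent", "equisized_dependent", "equifrequent_independent",
189 | # and "equifrequent_dependent". A minimal usage sketch:
190 | #
191 | #     collector = GridPartition(nbins_axis=4, bin_method="equifrequent_independent")
192 | #     selected_ids = collector.select(grid, size=2)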
--------------------------------------------------------------------------------
/selector/methods/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>.
22 | #
23 | # --
24 | """Module for Selection Utilities."""
25 |
26 | import warnings
27 |
28 | import numpy as np
29 |
30 | __all__ = [
31 | "optimize_radius",
32 | ]
33 |
34 |
35 | def optimize_radius(obj, X, size, cluster_ids=None):
36 | """Algorithm that uses sphere exclusion for selecting points from cluster.
37 |
38 | Iteratively searches for the optimal radius to obtain the correct number
39 | of selected samples. If the radius cannot converge to return `size` points,
40 | the function returns the closest number of samples to `size` as possible.
41 |
42 | Parameters
43 | ----------
44 | obj: object
45 | Instance of `DirectedSphereExclusion` or `OptiSim` selection class.
46 | X: ndarray of shape (n_samples, n_features)
47 | Feature matrix of `n_samples` samples in `n_features` dimensional space.
48 | size: int
49 | Number of sample points to select (i.e. size of the subset).
50 | cluster_ids: np.ndarray
51 | Indices of points that form a cluster.
52 |
53 | Returns
54 | -------
55 | selected: list
56 | List of indices of selected samples.
57 | """
58 | if X.shape[0] < size:
59 | raise RuntimeError(
60 | f"Size of samples to be selected is greater than existing the number of samples; "
61 | f"{size} > {X.shape[0]}."
62 | )
63 | # set the limits on # of selected points according to the tolerance percentage
64 | error = size * obj.tol
65 | lower_size = round(size - error)
66 | upper_size = round(size + error)
67 |
68 | # select `upper_size` number of samples
69 | if obj.r is not None:
70 | # use initial sphere radius
71 | selected = obj.algorithm(X, upper_size)
72 | else:
73 | # calculate a sphere radius based on maximum of n_features range
74 | # np.ptp returns range of values (maximum - minimum) along an axis
75 | obj.r = max(np.ptp(X, axis=0)) / size * 3
76 | selected = obj.algorithm(X, upper_size)
77 |
78 | # return selected if the correct number of samples chosen
79 | if len(selected) == size:
80 | return selected
81 |
82 | # optimize radius to select the correct number of samples
83 | # first, set a sensible range for optimizing r value within that range
84 | if len(selected) > size:
85 | # radius should become bigger, b/c too many samples were selected
86 | bounds = [obj.r, np.inf]
87 | else:
88 | # radius should become smaller, b/c too few samples were selected
89 | bounds = [0, obj.r]
90 |
91 | n_iter = 0
92 | while (len(selected) < lower_size or len(selected) > upper_size) and n_iter < obj.n_iter:
93 | # change sphere radius based on the defined bound
94 | if bounds[1] == np.inf:
95 | # make sphere radius larger by a factor of 2
96 | obj.r = bounds[0] * 2
97 | else:
98 | # make sphere radius smaller by a factor of 1/2
99 | obj.r = (bounds[0] + bounds[1]) / 2
100 |
101 | # re-select samples with the new radius
102 | selected = obj.algorithm(X, upper_size)
103 |
104 | # adjust lower/upper bounds of radius range
105 | if len(selected) > size:
106 | bounds[0] = obj.r
107 | else:
108 | bounds[1] = obj.r
109 | n_iter += 1
110 |
111 | # cannot find radius that produces desired number of selected points
112 | if n_iter >= obj.n_iter and len(selected) != size:
113 | warnings.warn(
114 | f"Optimal radius finder failed to converge, selected {len(selected)} points instead "
115 | f"of requested {size}."
116 | )
117 |
118 | return selected
119 |
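120 |
121 | # Usage sketch (illustrative; assumes, per the docstring above, an `obj` such as
122 | # DISE that exposes the `algorithm`, `r`, `tol`, and `n_iter` attributes):
123 | #
124 | #     import numpy as np
125 | #     from selector.methods.distance import DISE
126 | #
127 | #     X = np.random.default_rng(42).random((50, 2))
128 | #     collector = DISE(r0=0.5, tol=0.1, ref_index=0)
129 | #     selected = optimize_radius(collector, X, size=10)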
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [gh-actions]
2 | python =
3 | 3.9: py39
4 | 3.10: py310
5 | # 3.8: py38, rst_linux, rst_mac, readme, linters, coverage-report, qa
6 | # 3.11: py311, rst_linux, rst_mac, readme, linters, coverage-report, qa
7 | 3.11: py311, rst_linux, rst_mac, readme, coverage-report
8 | # 3.9: py39
9 | # 3.10: py310
10 | 3.12: py312
11 |
12 | [tox]
13 | # todo: add back the linters and qa
14 | # envlist = py39, py310, py311, rst_linux, rst_mac, readme, rst, linters, coverage-report, qa
15 | envlist = py39, py310, py311, py312, rst_linux, rst_mac, readme, rst, coverage-report
16 |
17 | [testenv]
18 | ; conda_deps =
19 | ; rdkit
20 | ; conda_channels = rdkit
21 | deps =
22 | -r{toxinidir}/requirements_dev.txt
23 | commands =
24 | # coverage run --rcfile=tox.ini -m pytest tests
25 | python -m pip install --upgrade pip
26 | # pip install -r requirements.txt
27 | # pip install .
28 | python -m pytest -c pyproject.toml --cov-config=.coveragerc --cov-report=xml --color=yes selector
29 |
30 | # pytest --cov-config=.coveragerc --cov=selector/test
31 | # can run it if needed
32 | # coverage report -m
33 | # prevent exit when error is encountered
34 | ignore_errors = true
35 |
36 | [testenv:readme]
37 | skip_install = true
38 | deps =
39 | readme_renderer
40 | twine
41 | -r{toxinidir}/requirements.txt
42 | -r{toxinidir}/requirements_dev.txt
43 | commands =
44 | # https://github.com/pypa/twine/issues/977
45 | python -m pip install importlib_metadata==7.2.1 build
46 | python -m build --no-isolation
47 | twine check dist/*
48 |
49 | [testenv:rst_linux]
50 | platform =
51 | linux
52 | skip_install = true
53 | deps =
54 | doc8
55 | rstcheck==3.3.1
56 | commands =
57 | doc8 --config tox.ini book/content/
58 | # ignore code-block related error because
59 | # the Sphinx support in rstcheck is minimal. This results in false positives
60 | # rstcheck uses Docutils to parse reStructuredText files and extract code blocks
61 | # fixme: check updates on the following website in the future
62 | # coala.gitbooks.io/projects/content/en/projects/rstcheck-with-better-sphinx-suppor.html
63 | rstcheck --recursive book/content/ --report error \
64 | --ignore-directives automodule,autoclass,autofunction,bibliography,code-block \
65 | --ignore-roles cite,mod,class,lineno --ignore-messages code-block
66 |
67 | [testenv:rst_mac]
68 | platform =
69 | darwin
70 | skip_install = true
71 | deps = {[testenv:rst_linux]deps}
72 | commands = {[testenv:rst_linux]commands}
73 |
74 | [testenv:linters]
75 | deps =
76 | flake8
77 | flake8-docstrings
78 | flake8-import-order>=0.9
79 | flake8-colors
80 | pep8-naming
81 | pylint==2.13.9
82 | # black
83 | bandit
84 | commands =
85 | flake8 selector/ selector/tests setup.py
86 | pylint selector --rcfile=tox.ini --disable=similarities
87 | # black -l 100 --check ./
88 | # black -l 100 --diff ./
89 | # Use bandit configuration file
90 | bandit -r selector -c .bandit.yml
91 |
92 | ignore_errors = true
93 |
94 | [testenv:coverage-report]
95 | deps = coverage>=4.2
96 | skip_install = true
97 | commands =
98 | # coverage combine --rcfile=tox.ini
99 | coverage report
100 |
101 | [testenv:qa]
102 | deps =
103 | {[testenv]deps}
104 | {[testenv:linters]deps}
105 | {[testenv:coverage-report]deps}
106 | commands =
107 | {[testenv]commands}
108 | {[testenv:linters]commands}
109 | # {[testenv:coverage-report]commands}
110 | ignore_errors = true
111 |
112 | # pytest configuration
113 | [pytest]
114 | addopts = --cache-clear
115 | --showlocals
116 | -v
117 | -r a
118 | --cov-report term-missing
119 | --cov selector
120 | # Do not run tests in the build folder
121 | norecursedirs = build
122 |
123 | # flake8 configuration
124 | [flake8]
125 | exclude =
126 | __init__.py,
127 | .tox,
128 | .git,
129 | __pycache__,
130 | build,
131 | dist,
132 | *.pyc,
133 | *.egg-info,
134 | .cache,
135 | .eggs,
136 | _version.py,
137 |
138 | max-line-length = 100
139 | import-order-style = google
140 | ignore =
141 | # E121 : continuation line under-indented for hanging indent
142 | E121,
143 | # E123 : closing bracket does not match indentation of opening bracket’s line
144 | E123,
145 | # E126 : continuation line over-indented for hanging indent
146 | E126,
147 | # E226 : missing whitespace around arithmetic operator
148 | E226,
149 | # E241 : multiple spaces after ‘,’
150 | # E242 : tab after ‘,’
151 | E24,
152 | # E704 : multiple statements on one line (def)
153 | E704,
154 | # W503 : line break occurred before a binary operator
155 | W503,
156 | # W504 : Line break occurred after a binary operator
157 | W504,
158 | # D202: No blank lines allowed after function docstring
159 | D202,
160 | # E203: Whitespace before ':'
161 | E203,
162 | # E731: Do not assign a lambda expression, use a def
163 | E731,
164 | # D401: First line should be in imperative mood: 'Do', not 'Does'
165 | D401,
166 |
167 | per-file-ignores =
168 | # F401: Unused import
169 | # this is used to define the data typing
170 | selector/utils.py: F401,
171 | # E1101: rdkit.Chem has no attribute xxx
172 | # D403: first word of the first line should be properly capitalized
173 | selector/feature.py: E1101, D403
174 |
175 | # doc8 configuration
176 | [doc8]
177 | # Ignore target directories and autogenerated files
178 | ignore-path = book/content/_build/, build/, selector.egg-info/, selector.egg-info, .*/
179 | # File extensions to use
180 | extensions = .rst, .txt
181 | # Maximal line length should be 100
182 | max-line-length = 100
183 | # Disable some doc8 checks:
184 | # D000: Check RST validity (cannot handle the "linenos" directive)
185 | # D002: Trailing whitespace
186 | # D004: Found literal carriage return
187 | # Both D002 and D004 can be problematic on Windows, where the line ending is `\r\n`,
188 | # while on Linux and macOS it is `\n`
189 | # Known issue of doc8, https://bugs.launchpad.net/doc8/+bug/1756704
190 | # ignore = D000,D002,D004
191 | ignore = D000
192 |
193 | # pylint configuration
194 | [MASTER]
195 | # This is a known issue of pylint with recognizing numpy members
196 | # https://github.com/PyCQA/pylint/issues/779
197 | # https://stackoverflow.com/questions/20553551/how-do-i-get-pylint-to-recognize-numpy-members
198 | extension-pkg-whitelist=numpy
199 |
200 | [FORMAT]
201 | # Maximum number of characters on a single line.
202 | max-line-length=100
203 |
204 | [MESSAGES CONTROL]
205 | # disable pylint warnings
206 | disable=
207 | # attribute-defined-outside-init (W0201):
208 | # Attribute %r defined outside __init__ Used when an instance attribute is
209 | # defined outside the __init__ method.
210 | W0201,
211 | # too-many-instance-attributes (R0902):
212 | # Too many instance attributes (%s/%s) Used when class has too many instance
213 | # attributes, try to reduce this to get a simpler (and so easier to use)
214 | # class.
215 | R0902,
216 | # too many branches (R0912)
217 | R0912,
218 | # too-many-arguments (R0913):
219 | # Too many arguments (%s/%s) Used when a function or method takes too many
220 | # arguments.
221 | R0913,
222 | # Too many local variables (r0914)
223 | R0914,
224 | # Too many statements (R0915)
225 | R0915,
226 | # fixme (W0511):
227 | # Used when a warning note as FIXME or XXX is detected.
228 | W0511,
229 | # bad-continuation (C0330):
230 | # Wrong hanging indentation before block (add 4 spaces).
231 | C0330,
232 | # wrong-import-order (C0411):
233 | # %s comes before %s Used when PEP8 import order is not respected (standard
234 | # imports first, then third-party libraries, then local imports)
235 | C0411,
236 | # arguments-differ (W0221):
237 | # Parameters differ from %s %r method Used when a method has a different
238 | # number of arguments than in the implemented interface or in an overridden
239 | # method.
240 | W0221,
241 |     # unnecessary "else" after "return" (R1705)
242 | R1705,
243 |     # Value XX is unsubscriptable (E1136); this is an open issue of pylint
244 | # https://github.com/PyCQA/pylint/issues/3139
245 | E1136,
246 |     # Used when a name doesn't fit the naming convention associated with its type
247 | # (constant, variable, class…).
248 | C0103,
249 | # Unnecessary pass statement
250 | W0107,
251 | # Module 'rdkit.Chem' has no 'ForwardSDMolSupplier' member (no-member)
252 | E1101,
253 | # todo: fix this one later and this is a temporary solution
254 | # E0401: Unable to import xxx (import-error)
255 | E0401,
256 | # R1721: Unnecessary use of a comprehension
257 | R1721,
258 | # I1101: Module xxx has no yyy member (no-member)
259 | I1101,
260 | # R0903: Too few public methods (too-few-public-methods)
261 | R0903,
262 | # R1702: Too many nested blocks (too-many-nested-blocks)
263 | R1702,
264 |
265 | [SIMILARITIES]
266 | min-similarity-lines=5
267 |
268 | # coverage configuration
269 | [run]
270 | branch = True
271 | parallel = True
272 | source = selector
273 |
274 | [paths]
275 | source =
276 | selector
277 | .tox/*/lib/python*/site-packages/selector
278 | .tox/pypy*/site-packages/selector
279 |
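280 | # Typical local invocations (illustrative):
281 | #   tox -e py311           # run the test suite on Python 3.11
282 | #   tox -e readme          # build the distribution and twine-check it
283 | #   tox -e rst_linux       # lint the book/content reST sources on Linux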
--------------------------------------------------------------------------------
/updateheaders.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see <http://www.gnu.org/licenses/>.
22 | #
23 | # --
24 |
25 |
26 | import os
27 | from fnmatch import fnmatch
28 | from glob import glob
29 |
30 |
31 | def strip_header(lines, closing):
32 | # search for the header closing line, i.e. '# --\n'
33 | counter = 0
34 | found = False
35 | for line in lines:
36 | counter += 1
37 | if line == closing:
38 | found = True
39 | break
40 | if found:
41 | del lines[:counter]
42 |     # If the header closing was not found, we assume no header was present.
43 |     # (Re)insert a header closing line at the top.
44 |     lines.insert(0, closing)
45 |
46 |
47 | def fix_python(fn, lines, header_lines):
48 | # check if a shebang is present
49 | do_shebang = lines[0].startswith("#!")
50 | # remove the current header
51 | strip_header(lines, "# --\n")
52 | # add a pylint line for test files:
53 | # if os.path.basename(fn).startswith('test_'):
54 | # if not lines[1].startswith('#pylint: skip-file'):
55 | # lines.insert(1, '#pylint: skip-file\n')
56 | # add new header (insert in reverse order)
57 | for hline in header_lines[::-1]:
58 | lines.insert(0, ("# " + hline).strip() + "\n")
59 |
60 |     if not lines[0].startswith("# -*- coding: utf-8 -*-"):
61 |         # add a source code encoding line if one is not already at the top
62 |         lines.insert(0, "# -*- coding: utf-8 -*-\n")
63 |
64 | if do_shebang:
65 | lines.insert(0, "#!/usr/bin/env python\n")
66 |
67 |
68 | def fix_c(fn, lines, header_lines):
69 | # check for an exception line
70 | for line in lines:
71 | if "no_update_headers" in line:
72 | return
73 | # remove the current header
74 | strip_header(lines, "//--\n")
75 | # add new header (insert must be in reverse order)
76 | for hline in header_lines[::-1]:
77 | lines.insert(0, ("// " + hline).strip() + "\n")
78 |
79 |
80 | def fix_rst(fn, lines, header_lines):
81 | # check for an exception line
82 | for line in lines:
83 | if "no_update_headers" in line:
84 | return
85 | # remove the current header
86 | strip_header(lines, " : --\n")
87 | # add an empty line after header if needed
88 | if len(lines[1].strip()) > 0:
89 | lines.insert(1, "\n")
90 | # add new header (insert must be in reverse order)
91 | for hline in header_lines[::-1]:
92 | lines.insert(0, (" : " + hline).rstrip() + "\n")
93 | # add comment instruction
94 | lines.insert(0, "..\n")
95 |
96 |
97 | def iter_subdirs(root):
98 | for dn, _, _ in os.walk(root):
99 | yield dn
100 |
101 |
102 | def main():
103 | source_dirs = [".", "book", "notebooks", "selector"] + list(iter_subdirs("selector"))
104 |
105 | fixers = [
106 | ("*.py", fix_python),
107 | ("*.pxd", fix_python),
108 | ("*.pyx", fix_python),
109 | ("*.txt", fix_python),
110 | ("*.c", fix_c),
111 | ("*.cpp", fix_c),
112 | ("*.h", fix_c),
113 | ("*.rst", fix_rst),
114 | ]
115 |
116 |     # read the header template once
117 |     with open("HEADER") as f:
118 |         header_lines = f.readlines()
119 |
120 | for sdir in source_dirs:
121 | print("Scanning:", sdir)
122 | for fn in glob(sdir + "/*.*"):
123 | if not os.path.isfile(fn):
124 | continue
125 | for pattern, fixer in fixers:
126 | if fnmatch(fn, pattern):
127 | print(" Fixing:", fn)
128 | with open(fn) as f:
129 | lines = f.readlines()
130 | fixer(fn, lines, header_lines)
131 | with open(fn, "w") as f:
132 | f.writelines(lines)
133 | break
134 |
135 |
136 | if __name__ == "__main__":
137 | main()
138 |
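139 |
140 | # Behavior sketch of `strip_header` (illustrative):
141 | #
142 | #     lines = ["# old header line\n", "# --\n", "import os\n"]
143 | #     strip_header(lines, "# --\n")
144 | #     assert lines == ["# --\n", "import os\n"]
145 | #
146 | # Everything up to and including the closing line is removed, and a fresh
147 | # closing line is re-inserted at the top so that the fix_* functions can
148 | # prepend the new header above it.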
--------------------------------------------------------------------------------