├── .bandit.yml ├── .coveragerc ├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md └── workflows │ ├── ci_codecov.yaml │ ├── ci_tox.yaml │ ├── pypi_release.yaml │ └── website_auto.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .pre-commit-hooks.yaml ├── .roberto.yaml ├── HEADER ├── LICENSE ├── MANIFEST.in ├── README.md ├── book └── content │ ├── _config.yml │ ├── _toc.yml │ ├── api_measures_converter.rst │ ├── api_measures_diversity.rst │ ├── api_measures_similarity.rst │ ├── api_methods_base.rst │ ├── api_methods_distance.rst │ ├── api_methods_partition.rst │ ├── api_methods_similarity.rst │ ├── api_methods_utils.rst │ ├── intro.md │ ├── references.bib │ ├── requirements.txt │ └── selector_logo.png ├── codecov.yml ├── notebooks ├── tutorial_distance_based.ipynb ├── tutorial_diversity_measures.ipynb ├── tutorial_partition_based.ipynb └── tutorial_similarity_based.ipynb ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt ├── selector ├── __init__.py ├── measures │ ├── __init__.py │ ├── converter.py │ ├── diversity.py │ ├── similarity.py │ └── tests │ │ ├── __init__.py │ │ ├── common.py │ │ ├── test_converter.py │ │ └── test_diversity.py └── methods │ ├── __init__.py │ ├── base.py │ ├── distance.py │ ├── partition.py │ ├── similarity.py │ ├── tests │ ├── __init__.py │ ├── common.py │ ├── data │ │ ├── coords_imbalance_case1.txt │ │ ├── coords_imbalance_case2.txt │ │ ├── labels_imbalance_case1.txt │ │ ├── labels_imbalance_case2.txt │ │ └── ref_esim_selection_data.csv │ ├── test_distance.py │ ├── test_partition.py │ └── test_similarity.py │ └── utils.py ├── tox.ini └── updateheaders.py /.bandit.yml: -------------------------------------------------------------------------------- 1 | skips: 2 | # Ignore defensive `assert`s, 80% of repos do this 3 | - B101 4 | # Standard pseudo-random generators are not suitable for security/cryptographic purposes 5 | - B311 6 | # Ignore warnings about importing subprocess 7 | - B404 8 | # Ignore warnings about calling subprocess 9 | - B603 10 | # Ignore warnings about calling subprocess 11 | - B607 12 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | selector/*/tests/* 4 | selector/tests/* 5 | selector/__init__.py 6 | selector/_version.py 7 | 8 | [report] 9 | show_missing = True 10 | exclude_also = 11 | pragma: no cover 12 | raise NotImplementedError 13 | if __name__ == .__main__.: 14 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | It is recommended to follow the code of conduct as described in 4 | https://qcdevs.org/guidelines/QCDevsCodeOfConduct/. 5 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We welcome contributions from external contributors, and this document 4 | describes how to merge code changes into this selector. 5 | 6 | ## Getting Started 7 | 8 | * Make sure you have a [GitHub account](https://github.com/signup/free). 9 | * [Fork](https://help.github.com/articles/fork-a-repo/) this repository on GitHub. 
10 | * On your local machine,
11 | [clone](https://help.github.com/articles/cloning-a-repository/) your fork of
12 | the repository.
13 | 
14 | ## Making Changes
15 | 
16 | * Add some really awesome code to your local fork. It's usually a
17 | [good idea](http://blog.jasonmeridth.com/posts/do-not-issue-pull-requests-from-your-master-branch/)
18 | to make changes on a
19 | [branch](https://help.github.com/articles/creating-and-deleting-branches-within-your-repository/)
20 | with a branch name that relates to the feature you are going to add.
21 | * When you are ready for others to examine and comment on your new feature,
22 | navigate to your fork of selector on GitHub and open a
23 | [pull request](https://help.github.com/articles/using-pull-requests/) (PR). Note that
24 | after you launch a PR from one of your fork's branches, all
25 | subsequent commits to that branch will be added to the open pull request
26 | automatically. Each commit added to the PR will be validated for
27 | mergeability, compilation, and test suite compliance; the results of these tests
28 | will be visible on the PR page.
29 | * If you're providing a new feature, you must add test cases and documentation.
30 | * When the code is ready to go, make sure you run the test suite using pytest.
31 | * When you're ready to be considered for merging, check the "Ready to go"
32 | box on the PR page to let the selector devs know that the changes are complete.
33 | The code will not be merged until this box is checked, the continuous
34 | integration returns checkmarks,
35 | and multiple core developers give "Approved" reviews.
36 | 
37 | # Python Virtual Environment for Package Development
38 | 
39 | Here is a list of version information for the packages that we used for
40 | [selector](https://github.com/theochem/selector):
41 | 
42 | ```bash
43 | python==3.7.11
44 | rdkit==2020.09.1.0
45 | numpy==1.21.2
46 | scipy==1.7.3
47 | pytest==6.2.5
48 | pytest-cov==3.0.0
49 | tox==3.24.5
50 | flake8==4.0.1
51 | pylint==2.12.2
52 | codecov==2.1.12
53 | # more to be added
54 | ```
55 | 
56 | `Conda`, [`venv`](https://docs.python.org/3/library/venv.html#module-venv) and
57 | [`virtualenv`](https://virtualenv.pypa.io/en/latest/) are all good options, and any one of them
58 | works well. We prefer `Miniconda` for local development.
59 | 
60 | # Additional Resources
61 | 
62 | * [General GitHub documentation](https://help.github.com/)
63 | * [PR best practices](http://codeinthehole.com/writing/pull-requests-and-other-good-practices-for-teams-using-github/)
64 | * [A guide to contributing to software packages](http://www.contribution-guide.org)
65 | * [Thinkful PR example](http://www.thinkful.com/learn/github-pull-request-tutorial/#Time-to-Submit-Your-First-PR)
--------------------------------------------------------------------------------
/.github/workflows/ci_codecov.yaml:
--------------------------------------------------------------------------------
1 | name: CI CodeCov
2 | 
3 | on:
4 |   # GitHub has started calling new repo's first branch "main" https://github.com/github/renaming
5 |   # Existing codes likely still have "master" as the primary branch
6 |   # Both are tracked here to keep legacy and new codes working
7 | 
8 |   # push:
9 |   #   branches:
10 |   #     - "master"
11 |   #     - "main"
12 | 
13 |   pull_request:
14 |     branches:
15 |       - "master"
16 |       - "main"
17 |   schedule:
18 |     # Nightly tests run on master by default:
19 |     #   Scheduled workflows run on the latest commit on the default or base branch.
20 |     #   (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule)
21 |     - cron: "0 0 * * *"  # runs daily at 00:00 UTC
22 | 
23 | jobs:
24 |   test:
25 |     name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}
26 |     runs-on: ${{ matrix.os }}
27 |     strategy:
28 |       matrix:
29 |         os: [ubuntu-latest]
30 |         python-version: ["3.11"]
31 | 
32 |     steps:
33 |       - uses: actions/checkout@v4
34 |       - name: Additional info about the build
35 |         shell: bash
36 |         run: |
37 |           uname -a
38 |           df -h
39 |           ulimit -a
40 | 
41 |       - name: Set up Python ${{ matrix.python-version }}
42 |         uses: actions/setup-python@v5
43 |         with:
44 |           python-version: ${{ matrix.python-version }}
45 |           architecture: x64
46 | 
47 |       - name: Install dependencies
48 |         shell: bash
49 |         run: |
50 |           python -m pip install --upgrade pip
51 |           python -m pip install -r requirements_dev.txt
52 |           python -m pip install -U pytest pytest-cov codecov
53 | 
54 |       - name: Install package
55 |         shell: bash
56 |         run: |
57 |           python -m pip install .
58 |           # pip install tox tox-gh-actions
59 | 
60 |       - name: Run tests
61 |         shell: bash
62 |         run: |
63 |           python -m pytest -c pyproject.toml --cov-config=.coveragerc --cov-report=xml --color=yes selector
64 | 
65 |       - name: CodeCov
66 |         uses: codecov/codecov-action@v4.5.0
67 |         with:
68 |           token: ${{ secrets.CODECOV_TOKEN }}
69 |           # Temp fix for https://github.com/codecov/codecov-action/issues/1487
70 |           version: v0.6.0
71 |           fail_ci_if_error: true
72 |           file: ./coverage.xml
73 |           flags: unittests
74 |           name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}
--------------------------------------------------------------------------------
/.github/workflows/ci_tox.yaml:
--------------------------------------------------------------------------------
1 | name: CI Tox
2 | 
3 | on:
4 |   # GitHub has started calling new repo's first branch "main" https://github.com/github/renaming
5 |   # Existing codes likely still have "master" as the primary branch
6 |   # Both are tracked here to keep legacy and new codes working
7 |   push:
8 |     branches:
9 |       - "master"
10 |       - "main"
11 |   pull_request:
12 |     branches:
13 |       - "master"
14 |       - "main"
15 |   schedule:
16 |     # Nightly tests run on master by default:
17 |     #   Scheduled workflows run on the latest commit on the default or base branch.
18 |     #   (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule)
19 |     - cron: "0 0 * * *"  # runs daily at 00:00 UTC
20 | 
21 | jobs:
22 |   test:
23 |     name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}
24 |     runs-on: ${{ matrix.os }}
25 |     strategy:
26 |       matrix:
27 |         # os: [macOS-latest, windows-latest, ubuntu-latest]
28 |         os: [macos-13, windows-latest, ubuntu-latest]
29 |         python-version: ["3.9", "3.10", "3.11", "3.12"]
30 | 
31 |     steps:
32 |       - uses: actions/checkout@v4
33 |       - name: Additional info about the build
34 |         shell: bash
35 |         run: |
36 |           uname -a
37 |           df -h
38 |           ulimit -a
39 | 
40 |       - name: Set up Python ${{ matrix.python-version }}
41 |         uses: actions/setup-python@v5
42 |         with:
43 |           python-version: ${{ matrix.python-version }}
44 |           architecture: x64
45 | 
46 |       - name: Install dependencies
47 |         shell: bash
48 |         run: |
49 |           python -m pip install --upgrade pip
50 |           python -m pip install -r requirements_dev.txt
51 |           # python -m pip install -U pytest pytest-cov codecov
52 |           # tests run below via tox, after the package itself is installed
53 | 
54 |       - name: Install package
55 |         shell: bash
56 |         run: |
57 |           python -m pip install .
58 | pip install tox tox-gh-actions 59 | 60 | - name: Run tests 61 | shell: bash 62 | run: | 63 | tox 64 | -------------------------------------------------------------------------------- /.github/workflows/pypi_release.yaml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | on: 3 | push: 4 | tags: 5 | # Trigger on version tags (e.g., v1.0.0) 6 | - "v*.*.*" 7 | # Trigger on pre-release tags (e.g., v1.0.0-alpha.1) 8 | - "v*.*.*-*" 9 | 10 | env: 11 | # package name 12 | PYPI_NAME: qc-selector 13 | 14 | jobs: 15 | build: 16 | name: Build and Test Distribution 17 | runs-on: ${{ matrix.os }} 18 | 19 | strategy: 20 | matrix: 21 | # os: [ubuntu-latest, macos-latest, windows-latest] 22 | os: [ubuntu-latest] 23 | # python-version: ["3.9", "3.10", "3.11", "3.12"] 24 | python-version: ["3.11"] 25 | outputs: 26 | os: ${{ matrix.os }} 27 | python-version: ${{ matrix.python-version }} 28 | 29 | steps: 30 | - uses: actions/checkout@v4 31 | with: 32 | fetch-depth: 0 33 | - name: Set up Python 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | # python-version: "3.11" 38 | - name: Install dependencies 39 | run: | 40 | python -m pip install --upgrade pip 41 | python -m pip install build pytest pytest-cov codecov 42 | python -m pip install -r requirements.txt 43 | python -m pip install -r requirements_dev.txt 44 | - name: Test package 45 | run: | 46 | python -m pytest -c pyproject.toml --cov-config=.coveragerc --cov-report=xml --color=yes selector 47 | - name: Build package 48 | run: python -m build 49 | - name: Store the distribution packages 50 | uses: actions/upload-artifact@v4 51 | with: 52 | # Name of the artifact to upload, unique for each OS and Python version 53 | name: python-package-distributions 54 | path: dist/ 55 | # Optional parameters for better artifact management 56 | overwrite: false 57 | include-hidden-files: false 58 | 59 | publish-to-pypi: 60 | name: Publish Python distribution to PyPI 61 | if: startsWith(github.ref, 'refs/tags/v') 62 | needs: build 63 | runs-on: ubuntu-latest 64 | environment: 65 | name: PyPI-Release 66 | url: https://pypi.org/p/${{ env.PYPI_NAME }} 67 | permissions: 68 | id-token: write 69 | 70 | steps: 71 | - name: Download all the dists 72 | uses: actions/download-artifact@v4 73 | with: 74 | name: python-package-distributions 75 | path: dist/ 76 | - name: Publish distribution to PyPI 77 | uses: pypa/gh-action-pypi-publish@release/v1 78 | env: 79 | TWINE_USERNAME: "__token__" 80 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 81 | 82 | github-release: 83 | name: Sign and Upload Python Distribution to GitHub Release 84 | needs: 85 | - build 86 | - publish-to-pypi 87 | runs-on: ubuntu-latest 88 | permissions: 89 | contents: write 90 | id-token: write 91 | 92 | steps: 93 | - name: Download all the dists 94 | uses: actions/download-artifact@v4 95 | with: 96 | name: python-package-distributions 97 | path: dist/ 98 | - name: Sign the dists with Sigstore 99 | uses: sigstore/gh-action-sigstore-python@v3.0.0 100 | with: 101 | inputs: >- 102 | ./dist/*.tar.gz 103 | ./dist/*.whl 104 | - name: Create GitHub Release 105 | env: 106 | GITHUB_TOKEN: ${{ github.token }} 107 | run: > 108 | gh release create 109 | '${{ github.ref_name }}' 110 | --repo '${{ github.repository }}' 111 | --notes "" 112 | - name: Upload artifact signatures to GitHub Release 113 | env: 114 | GITHUB_TOKEN: ${{ github.token }} 115 | run: > 116 | gh release upload 117 | '${{ github.ref_name }}' dist/** 118 | 
--repo '${{ github.repository }}' 119 | 120 | publish-none-pypi: 121 | name: Publish Python distribution to TestPyPI (none) 122 | if: startsWith(github.ref, 'refs/tags/v') 123 | needs: build 124 | runs-on: ubuntu-latest 125 | environment: 126 | name: TestPyPI 127 | url: https://test.pypi.org/p/${{ env.PYPI_NAME }} 128 | permissions: 129 | id-token: write 130 | 131 | steps: 132 | - name: Download all the dists 133 | uses: actions/download-artifact@v4 134 | with: 135 | name: python-package-distributions 136 | path: dist/ 137 | - name: Publish distribution with relaxed constraints 138 | uses: pypa/gh-action-pypi-publish@release/v1 139 | with: 140 | repository-url: https://test.pypi.org/legacy/ 141 | env: 142 | TWINE_USERNAME: "__token__" 143 | TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN }} 144 | -------------------------------------------------------------------------------- /.github/workflows/website_auto.yaml: -------------------------------------------------------------------------------- 1 | name: deploy-book 2 | 3 | # Only run this when the master branch changes 4 | on: 5 | push: 6 | branches: 7 | - main 8 | pull_request: 9 | types: [opened, synchronize, reopened, closed] 10 | branches: 11 | - main 12 | # If your git repository has the Jupyter Book within some-subfolder next to 13 | # unrelated files, you can make this run only if a file within that specific 14 | # folder has been modified. 15 | # 16 | #paths: 17 | #- book/ 18 | 19 | # This job installs dependencies, builds the book, and pushes it to `gh-pages` 20 | jobs: 21 | deploy-book: 22 | runs-on: ubuntu-latest 23 | permissions: 24 | pages: write 25 | # https://github.com/JamesIves/github-pages-deploy-action/issues/1110 26 | contents: write 27 | 28 | steps: 29 | - uses: actions/checkout@v4 30 | 31 | # Install dependencies 32 | - name: Set up Python 3.11 33 | uses: actions/setup-python@v5 34 | with: 35 | python-version: 3.11 36 | 37 | - name: Install dependencies 38 | run: | 39 | pip install -r book/content/requirements.txt 40 | # Install selector 41 | - name: Install package 42 | run: | 43 | pip install -e . 44 | 45 | # Build the book 46 | - name: Build the book 47 | run: | 48 | cp notebooks/*.ipynb book/content/. 49 | jupyter-book build ./book/content 50 | 51 | # Push the book's HTML to github-pages 52 | # inspired by https://github.com/orgs/community/discussions/26724 53 | # only push to gh-pages if the main branch has been updated 54 | - name: GitHub Pages Action 55 | if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} 56 | uses: peaceiris/actions-gh-pages@v3 57 | with: 58 | github_token: ${{ secrets.GITHUB_TOKEN }} 59 | publish_dir: ./book/content/_build/html 60 | publish_branch: gh-pages 61 | cname: selector.qcdevs.org 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | doc/_build/ 73 | doc/html/ 74 | doc/latex/ 75 | doc/man/ 76 | doc/xml/ 77 | doc/source 78 | doc/modules 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # Editor junk 138 | tags 139 | [._]*.s[a-v][a-z] 140 | [._]*.sw[a-p] 141 | [._]s[a-v][a-z] 142 | [._]sw[a-p] 143 | *~ 144 | \#*\# 145 | .\#* 146 | .ropeproject 147 | .idea/ 148 | .spyderproject 149 | .spyproject 150 | .vscode/ 151 | # Mac .DS_Store 152 | .DS_Store 153 | 154 | # jupyter notebook checkpoints 155 | .ipynb_checkpoints 156 | 157 | # codecov files 158 | *.gcno 159 | *.gcda 160 | *.gcov 161 | 162 | */*/_build 163 | */_build 164 | 165 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # pre-commit is a tool to perform a predefined set of tasks manually and/or 2 | # automatically before git commits are made. 3 | # 4 | # Config reference: https://pre-commit.com/#pre-commit-configyaml---top-level 5 | # 6 | # Common tasks 7 | # 8 | # - Run on all files: pre-commit run --all-files 9 | # - Register git hooks: pre-commit install --install-hooks 10 | # 11 | 12 | # https://github.com/jupyterhub/jupyterhub/blob/main/.pre-commit-config.yaml 13 | # https://github.com/pre-commit/pre-commit/blob/main/.pre-commit-config.yaml 14 | # https://github.com/psf/black/blob/main/.pre-commit-config.yaml 15 | # https://docs.releng.linuxfoundation.org/en/latest/best-practices.html 16 | 17 | ci: 18 | # https://pre-commit.ci/ 19 | # pre-commit.ci will open PRs updating our hooks once a month 20 | # 'weekly', 'monthly', 'quarterly' 21 | autoupdate_schedule: quarterly 22 | autofix_prs: false 23 | autofix_commit_msg: "[pre-commit.ci] Auto fixes from pre-commit.com hooks." 
24 |   submodules: false
25 | 
26 | repos:
27 |   - repo: https://github.com/pre-commit/pre-commit-hooks
28 |     rev: v4.5.0
29 |     hooks:
30 |       # https://pre-commit.com/hooks.html
31 |       # - id: fix-encoding-pragma
32 |       - id: sort-simple-yaml
33 |       - id: trailing-whitespace
34 |       - id: end-of-file-fixer
35 |       - id: check-yaml
36 |       - id: check-toml
37 |       # replaces or checks mixed line ending
38 |       - id: mixed-line-ending
39 |       - id: debug-statements
40 |       # - id: name-tests-test
41 |       # - id: trailing-whitespace  # duplicate of the entry above
42 |       # sorts entries in requirements.txt
43 |       - id: requirements-txt-fixer
44 | 
45 |   # we are not using this for now as we decided to use pyproject.toml instead
46 |   ## https://github.com/asottile/setup-cfg-fmt
47 |   #- repo: https://github.com/asottile/setup-cfg-fmt
48 |   #  rev: v2.5.0
49 |   #  hooks:
50 |   #    - id: setup-cfg-fmt
51 | 
52 |   # - repo: https://github.com/psf/black-pre-commit-mirror
53 |   - repo: https://github.com/psf/black
54 |     rev: 24.3.0
55 |     hooks:
56 |       - id: black
57 |         args: [
58 |           --line-length=100,
59 |         ]
60 | 
61 |   - repo: https://github.com/pycqa/isort
62 |     rev: 5.13.2
63 |     hooks:
64 |       - id: isort
65 | 
66 |   # todo: disable flake8 for now, but will need to add it back in the future
67 |   # - repo: https://github.com/pycqa/flake8
68 |   #   rev: 6.1.0
69 |   #   hooks:
70 |   #     - id: flake8
71 |   #       args: ["--max-line-length=100"]
72 |   #       additional_dependencies:
73 |   #         - flake8-bugbear
74 |   #         - flake8-comprehensions
75 |   #         - flake8-simplify
76 |   #         - flake8-docstrings
77 |   #         - flake8-import-order>=0.9
78 |   #         - flake8-colors
79 |   #       exclude: ^src/blib2to3/
80 | 
81 |   - repo: https://github.com/pycqa/bandit
82 |     rev: 1.7.8
83 |     hooks:
84 |       - id: bandit
85 |         # exclude some directories; alternatives must be separated by "|" in this verbose regex
86 |         exclude: |
87 |           (?x)(
88 |               ^test/|
89 |               ^book/|
90 |               ^devtools/|
91 |               ^docs/|
92 |               ^doc/
93 |           )
94 |         args: [
95 |           "-c", "pyproject.toml"
96 |         ]
97 |         additional_dependencies: [ "bandit[toml]" ]
98 | 
99 |   - repo: https://github.com/asottile/pyupgrade
100 |     rev: v3.15.2
101 |     hooks:
102 |       - id: pyupgrade
103 |         args: [--py39-plus]  # match requires-python >= 3.9 in pyproject.toml
104 | 
105 |   # todo: add this for type checking in the future
106 |   # - repo: https://github.com/pre-commit/mirrors-mypy
107 |   #   rev: v1.5.1
108 |   #   hooks:
109 |   #     - id: mypy
110 |   #       exclude: ^docs/conf.py
111 |   #       args: ["--config-file", "pyproject.toml"]
112 |   #       additional_dependencies:
113 |   #         - types-PyYAML
114 |   #         - tomli >= 0.2.6, < 2.0.0
115 |   #         - click >= 8.1.0, != 8.1.4, != 8.1.5
116 |   #         - packaging >= 22.0
117 |   #         - platformdirs >= 2.1.0
118 |   #         - pytest
119 |   #         - hypothesis
120 |   #         - aiohttp >= 3.7.4
121 |   #         - types-commonmark
122 |   #         - urllib3
123 |   #         - hypothesmith
124 | 
125 |   # - repo: https://github.com/pre-commit/mirrors-mypy
126 |   #   rev: v1.6.0
127 |   #   hooks:
128 |   #     - id: mypy
129 |   #       additional_dependencies: [types-all]
130 |   #       exclude: ^testing/resources/
--------------------------------------------------------------------------------
/.pre-commit-hooks.yaml:
--------------------------------------------------------------------------------
1 | # https://github.com/pre-commit/pre-commit-hooks/blob/main/.pre-commit-hooks.yaml
2 | 
3 | - id: trailing-whitespace
4 |   name: trim trailing whitespace
5 |   description: trims trailing whitespace.
6 |   entry: trailing-whitespace-fixer
7 |   language: python
8 |   types: [text]
9 |   # exclude: ^(book|devtools|docs|doc)/
10 |   stages: [commit, push, manual]
11 | 
12 | - id: sort-simple-yaml
13 |   name: sort simple yaml files
14 |   description: sorts simple yaml files which consist only of top-level keys, preserving comments and blocks.
15 | language: python 16 | entry: sort-simple-yaml 17 | files: '^$' 18 | 19 | - id: black 20 | name: black 21 | description: "Black: The uncompromising Python code formatter" 22 | entry: black 23 | language: python 24 | minimum_pre_commit_version: 3.4.0 25 | require_serial: true 26 | types_or: [python, pyi] 27 | 28 | - id: check-toml 29 | name: check toml 30 | description: checks toml files for parseable syntax. 31 | entry: check-toml 32 | language: python 33 | types: [toml] 34 | 35 | - id: check-yaml 36 | name: check yaml 37 | description: checks yaml files for parseable syntax. 38 | entry: check-yaml 39 | language: python 40 | # exclude: ^website.yml 41 | types: [yaml] 42 | 43 | - id: requirements-txt-fixer 44 | name: fix requirements.txt 45 | description: sorts entries in requirements.txt. 46 | entry: requirements-txt-fixer 47 | language: python 48 | files: (requirements|constraints).*\.txt$ 49 | 50 | - id: mixed-line-ending 51 | name: mixed line ending 52 | description: replaces or checks mixed line ending. 53 | entry: mixed-line-ending 54 | language: python 55 | types: [text] 56 | 57 | - id: end-of-file-fixer 58 | name: fix end of files 59 | description: ensures that a file is either empty, or ends with one newline. 60 | entry: end-of-file-fixer 61 | language: python 62 | types: [python] 63 | stages: [commit, push, manual] 64 | 65 | - id: debug-statements 66 | name: debug statements (python) 67 | description: checks for debugger imports and py37+ `breakpoint()` calls in python source. 68 | entry: debug-statement-hook 69 | language: python 70 | types: [python] 71 | 72 | - id: check-merge-conflict 73 | name: check for merge conflicts 74 | description: checks for files that contain merge conflict strings. 75 | entry: check-merge-conflict 76 | language: python 77 | types: [text] 78 | 79 | - id: check-ast 80 | name: check python ast 81 | description: simply checks whether the files parse as valid python. 82 | entry: check-ast 83 | language: python 84 | types: [python] 85 | -------------------------------------------------------------------------------- /.roberto.yaml: -------------------------------------------------------------------------------- 1 | # Force absolute comparison for cardboardlint 2 | absolute: true 3 | project: 4 | name: selector 5 | # requirements: [[numpydoc, numpydoc], [sphinx-autoapi, sphinx-autoapi], [sphinxcontrib-bibtex, sphinxcontrib-bibtex]] 6 | packages: 7 | - dist_name: selector 8 | tools: 9 | - write-py-version 10 | # - cardboardlint-static 11 | - build-py-inplace 12 | - pytest 13 | - upload-codecov 14 | # - cardboardlint-dynamic 15 | # - build-sphinx-doc 16 | # - upload-docs-gh 17 | - build-py-source 18 | - build-conda 19 | - deploy-pypi 20 | - deploy-conda 21 | # - deploy-github 22 | -------------------------------------------------------------------------------- /HEADER: -------------------------------------------------------------------------------- 1 | 2 | The Selector is a Python library of algorithms for selecting diverse 3 | subsets of data for machine-learning. 4 | 5 | Copyright (C) 2022-2024 The QC-Devs Community 6 | 7 | This file is part of Selector. 8 | 9 | Selector is free software; you can redistribute it and/or 10 | modify it under the terms of the GNU General Public License 11 | as published by the Free Software Foundation; either version 3 12 | of the License, or (at your option) any later version. 
13 | 
14 | Selector is distributed in the hope that it will be useful,
15 | but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | GNU General Public License for more details.
18 | 
19 | You should have received a copy of the GNU General Public License
20 | along with this program; if not, see <http://www.gnu.org/licenses/>
21 | 
22 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include MANIFEST.in
3 | include .github/CODE_OF_CONDUCT.md
4 | # versioneer.py is not used; versioning is handled by setuptools-scm
5 | 
6 | prune notebooks
7 | prune book
8 | 
9 | graft selector
10 | global-exclude *.py[cod] __pycache__ *.so
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | <p align="center">
3 |   <img src="book/content/selector_logo.png" alt="Logo">
4 | </p>
5 | 
6 | [![This project supports Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg)](https://python.org/downloads)
7 | [![GPLv3 License](https://img.shields.io/badge/License-GPL%20v3-yellow.svg)](https://opensource.org/licenses/GPL-3.0)
8 | [![CI Tox](https://github.com/theochem/Selector/actions/workflows/ci_tox.yaml/badge.svg?branch=main)](https://github.com/theochem/Selector/actions/workflows/ci_tox.yaml)
9 | [![codecov](https://codecov.io/gh/theochem/Selector/graph/badge.svg?token=0UJixrJfNJ)](https://codecov.io/gh/theochem/Selector)
10 | 
11 | The `Selector` library provides methods for selecting a diverse subset of a (molecular) dataset.
12 | 
13 | ## Citation
14 | 
15 | Please use the following citation in any publication using the `selector` library:
16 | 
17 | ```bibtex
18 | @article{
19 | TO BE ADDED LATER
20 | }
21 | ```
22 | 
23 | ## Web Server
24 | 
25 | We have a web server for the `selector` library at https://huggingface.co/spaces/QCDevs/Selector.
26 | For small and prototype datasets, you can use the web server to select a diverse subset of your
27 | dataset and compute the diversity metrics, where you can download the selected subset and the
28 | computed diversity metrics.
29 | 
30 | ## Installation
31 | 
32 | It is recommended to install `selector` within a virtual environment. To create a virtual
33 | environment, we can use the `venv` module (Python 3.3+,
34 | https://docs.python.org/3/tutorial/venv.html), `miniconda` (https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html), or
35 | `pipenv` (https://pipenv.pypa.io/en/latest/).
36 | 
37 | ### Installing from PyPI
38 | 
39 | To install `selector` with `pip`, we can install the latest stable release from the Python Package Index (PyPI) as follows:
40 | 
41 | ```bash
42 | # install the stable release
43 | pip install qc-selector
44 | ```
45 | 
46 | ### Installing from the Prebuilt Wheel Files
47 | 
48 | To download the prebuilt wheel files, visit the [PyPI page](https://pypi.org/project/qc-selector/)
49 | and [GitHub releases](https://github.com/theochem/Selector/tags).
50 | 
51 | ```bash
52 | # download the wheel file first to your local machine
53 | # then install the wheel file
54 | pip install file_path/qc_selector-0.0.2b12-py3-none-any.whl
55 | ```
56 | 
57 | ### Installing from the Source Code
58 | 
59 | In addition, we can install the latest development version from the GitHub repository as follows:
60 | 
61 | ```bash
62 | # install the latest development version
63 | pip install git+https://github.com/theochem/Selector.git
64 | ```
65 | 
66 | We can also clone the repository to access the latest development version, test it, and install it as follows:
67 | 
68 | ```bash
69 | # clone the repository
70 | git clone git@github.com:theochem/Selector.git
71 | 
72 | # change into the working directory
73 | cd Selector
74 | # run the tests
75 | python -m pytest .
76 | 
77 | # install the package
78 | pip install .
79 | 
80 | ```
81 | 
82 | ## More
83 | 
84 | See https://selector.qcdevs.org for full details.
--------------------------------------------------------------------------------
/book/content/_config.yml:
--------------------------------------------------------------------------------
1 | # Book settings
2 | # Learn more at https://jupyterbook.org/customize/config.html
3 | 
4 | title: Ayers Lab
5 | author: The QC-Dev Community with the support of Digital Research Alliance of Canada and The Jupyter Book Community
6 | logo: selector_logo.png
7 | 
8 | # Force re-execution of notebooks on each build.
9 | # See https://jupyterbook.org/content/execute.html 10 | execute: 11 | execute_notebooks: 'off' 12 | 13 | # Define the name of the latex output file for PDF builds 14 | latex: 15 | latex_documents: 16 | targetname: book.tex 17 | 18 | sphinx: 19 | extra_extensions: 20 | - 'sphinx.ext.autodoc' 21 | config: 22 | mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 23 | 24 | 25 | #Read LaTeX 26 | parse: 27 | myst_enable_extensions: # default extensions to enable in the myst parser. See https://myst-parser.readthedocs.io/en/latest/using/syntax-optional.html 28 | # - amsmath 29 | - colon_fence 30 | # - deflist 31 | - dollarmath 32 | # - html_admonition 33 | # - html_image 34 | - linkify 35 | # - replacements 36 | # - smartquotes 37 | - substitution 38 | - tasklist 39 | myst_url_schemes: [mailto, http, https] # URI schemes that will be recognised as external URLs in Markdown links 40 | myst_dmath_double_inline: true # Allow display math ($$) within an inline context 41 | myst_dmath_allow_space: false 42 | 43 | # Add a bibtex file so that we can create citations 44 | bibtex_bibfiles: 45 | - references.bib 46 | 47 | # Information about where the book exists on the web 48 | repository: 49 | url: https://github.com/theochem/Selector # Online location of your book 50 | path_to_book: book/content # Optional path to your book, relative to the repository root 51 | branch: main # Which branch of the repository should be used when creating links (optional) 52 | 53 | 54 | # Add GitHub buttons to your book 55 | launch_buttons: 56 | thebe : true 57 | colab_url: "https://colab.research.google.com" 58 | 59 | copyright: "2022-2024" 60 | 61 | html: 62 | use_issues_button: true 63 | use_repository_button: true 64 | favicon: "selector_logo.png" 65 | use_multitoc_numbering: false 66 | -------------------------------------------------------------------------------- /book/content/_toc.yml: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | # Learn more at https://jupyterbook.org/customize/toc.html 3 | 4 | format: jb-book 5 | root: intro.md 6 | parts: 7 | - caption: Tutorial 8 | chapters: 9 | - file: tutorial_distance_based.ipynb 10 | - file: tutorial_partition_based.ipynb 11 | - file: tutorial_diversity_measures.ipynb 12 | - file: tutorial_similarity_based.ipynb 13 | - caption: API 14 | chapters: 15 | - file: api_methods_base.rst 16 | - file: api_methods_distance.rst 17 | - file: api_methods_partition.rst 18 | - file: api_methods_similarity.rst 19 | - file: api_methods_utils.rst 20 | - file: api_measures_converter.rst 21 | - file: api_measures_diversity.rst 22 | - file: api_measures_similarity.rst 23 | -------------------------------------------------------------------------------- /book/content/api_measures_converter.rst: -------------------------------------------------------------------------------- 1 | .. _measures.converter: 2 | 3 | :mod:`selector.measures.converter` 4 | ================================== 5 | 6 | .. automodule:: selector.measures.converter 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | :private-members: 11 | -------------------------------------------------------------------------------- /book/content/api_measures_diversity.rst: -------------------------------------------------------------------------------- 1 | .. _measures.diversity: 2 | 3 | :mod:`selector.measures.diversity` 4 | ================================== 5 | 6 | .. 
automodule:: selector.measures.diversity 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | :private-members: 11 | -------------------------------------------------------------------------------- /book/content/api_measures_similarity.rst: -------------------------------------------------------------------------------- 1 | .. _measures.similarity: 2 | 3 | :mod:`selector.measures.similarity` 4 | =================================== 5 | 6 | .. automodule:: selector.measures.similarity 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | :private-members: 11 | -------------------------------------------------------------------------------- /book/content/api_methods_base.rst: -------------------------------------------------------------------------------- 1 | .. _methods.base: 2 | 3 | :mod:`selector.methods.base` 4 | ============================ 5 | 6 | .. automodule:: selector.methods.base 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | :private-members: 11 | -------------------------------------------------------------------------------- /book/content/api_methods_distance.rst: -------------------------------------------------------------------------------- 1 | .. _methods.distance: 2 | 3 | :mod:`selector.methods.distance` 4 | ================================ 5 | 6 | .. automodule:: selector.methods.distance 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | :private-members: 11 | -------------------------------------------------------------------------------- /book/content/api_methods_partition.rst: -------------------------------------------------------------------------------- 1 | .. _methods.partition: 2 | 3 | :mod:`selector.methods.partition` 4 | ================================= 5 | 6 | .. automodule:: selector.methods.partition 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | :private-members: 11 | -------------------------------------------------------------------------------- /book/content/api_methods_similarity.rst: -------------------------------------------------------------------------------- 1 | .. _methods.similarity: 2 | 3 | :mod:`selector.methods.similarity` 4 | ================================== 5 | 6 | .. automodule:: selector.methods.similarity 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | :private-members: 11 | -------------------------------------------------------------------------------- /book/content/api_methods_utils.rst: -------------------------------------------------------------------------------- 1 | .. _methods.utils: 2 | 3 | :mod:`selector.methods.utils` 4 | ============================= 5 | 6 | .. automodule:: selector.methods.utils 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | :private-members: 11 | -------------------------------------------------------------------------------- /book/content/intro.md: -------------------------------------------------------------------------------- 1 | 2 | # Welcome to QC-Selector's Documentation! 3 | 4 | [Selector](https://github.com/theochem/Selector) is a free, open-source, and cross-platform Python library designed to help you effortlessly identify the most diverse subset of molecules from your dataset. Please use the following citation in any publication using Selector library: 5 | 6 | **"TO be added"** 7 | 8 | The Selector source code is hosted on [GitHub](https://github.com/theochem/Selector) and is released under the [GNU General Public License v3.0](https://github.com/theochem/Selector/blob/main/LICENSE). 
We welcome contributions to the Selector library made in accordance with our [Code of Conduct](https://qcdevs.org/guidelines/QCDevsCodeOfConduct/). Please report any issues you encounter while using the Selector library on [GitHub Issues](https://github.com/theochem/Selector/issues). For further information and inquiries, please contact us at qcdevs@gmail.com.
9 | 
10 | 
11 | ## Why QC-Selector?
12 | 
13 | Selecting diverse and representative subsets is crucial for data-driven models and machine
14 | learning applications in many science and engineering disciplines, especially molecular design
15 | and drug discovery. Motivated by this, we developed the Selector package, a free and open-source Python library for selecting diverse subsets.
16 | 
17 | The Selector library implements a range of existing algorithms for subset sampling based on the
18 | distance between and similarity of samples, as well as tools based on spatial partitioning. In
19 | addition, it includes seven diversity measures for quantifying the diversity of a given set. We also
20 | implemented various mathematical formulations to convert similarities into dissimilarities.
21 | 
22 | 
23 | ## Web Server
24 | 
25 | We have a web server for the `selector` library at https://huggingface.co/spaces/QCDevs/Selector.
26 | For small and prototype datasets, you can use the web server to select a diverse subset of your
27 | dataset and compute the diversity metrics, where you can download the selected subset and the
28 | computed diversity metrics.
29 | 
30 | ## Installation
31 | 
32 | It is recommended to install `selector` within a virtual environment. To create a virtual
33 | environment, we can use the `venv` module (Python 3.3+,
34 | https://docs.python.org/3/tutorial/venv.html), `miniconda` (https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html), or
35 | `pipenv` (https://pipenv.pypa.io/en/latest/).
36 | 
37 | ### Installing from PyPI
38 | 
39 | To install `selector` with `pip`, we can install the latest stable release from the Python Package Index (PyPI) as follows:
40 | 
41 | ```bash
42 | # install the stable release
43 | pip install qc-selector
44 | ```
45 | 
46 | ### Installing from the Prebuilt Wheel Files
47 | 
48 | To download the prebuilt wheel files, visit the [PyPI page](https://pypi.org/project/qc-selector/)
49 | and [GitHub releases](https://github.com/theochem/Selector/tags).
50 | 
51 | ```bash
52 | # download the wheel file first to your local machine
53 | # then install the wheel file
54 | pip install file_path/qc_selector-0.0.2b12-py3-none-any.whl
55 | ```
56 | 
57 | ### Installing from the Source Code
58 | 
59 | In addition, we can install the latest development version from the GitHub repository as follows:
60 | 
61 | ```bash
62 | # install the latest development version
63 | pip install git+https://github.com/theochem/Selector.git
64 | ```
65 | 
66 | We can also clone the repository to access the latest development version, test it, and install it as follows:
67 | 
68 | ```bash
69 | # clone the repository
70 | git clone git@github.com:theochem/Selector.git
71 | 
72 | # change into the working directory
73 | cd Selector
74 | # run the tests
75 | python -m pytest .
76 | 
77 | # install the package
78 | pip install .
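
# optional: quick smoke test of the installation
# (this step assumes the installed package exposes the selector.methods
# subpackage, matching the source layout shown earlier in this repository)
python -c "import selector; from selector.methods.distance import MaxMin; print('selector import OK')"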
79 | 80 | ``` 81 | -------------------------------------------------------------------------------- /book/content/references.bib: -------------------------------------------------------------------------------- 1 | --- 2 | --- 3 | @article{selector, 4 | title={selector: A Novel Approach for Selecting Diverse Molecular Subsets}, 5 | author={QC-Devs}, 6 | journal={Nature Chemistry}, 7 | year={2023}, 8 | doi={10.1234/natchem.2023.01234}, 9 | } 10 | -------------------------------------------------------------------------------- /book/content/requirements.txt: -------------------------------------------------------------------------------- 1 | docutils<0.18 2 | ghp-import 3 | jupyter-book 4 | matplotlib 5 | numpy 6 | -------------------------------------------------------------------------------- /book/content/selector_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theochem/Selector/c04a253a928c6abe6c664c4090998187a5407e7e/book/content/selector_logo.png -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # Codecov configuration to make it a bit less noisy 2 | codecov: 3 | token: ${{ secrets.CODECOV_TOKEN }} 4 | notify: 5 | require_ci_to_pass: yes 6 | 7 | coverage: 8 | precision: 2 9 | round: down 10 | range: 70...100 11 | 12 | status: 13 | patch: false 14 | 15 | project: 16 | default: 17 | # inspired by 18 | # github.com/theochem/iodata/blob/73eee89111bcec426d9e8651ec5d85541c7d6d24/.codecov.yml#L1-L9 19 | # Commits and PRs are never marked as "failed" due coverage issues. 20 | # Codecov is only used as an informal tool when reviewing PRs, 21 | # not in the least because of the many false failures. 22 | target: 0% 23 | threshold: 100% 24 | 25 | # ignore statistics for the testing folders 26 | ignore: 27 | - .*/tests/.* 28 | - .*/.*/tests/.* 29 | - .*/examples/.* 30 | - .*/__int__.py 31 | - .*/_version.py 32 | - "test_*.rb" # wildcards accepted 33 | - .*/data/.* 34 | - .*/versioneer.py 35 | - .*/*/__init__.py 36 | # - "**/*.pyc" # glob accepted 37 | 38 | comment: 39 | layout: "reach, header, diff, uncovered, files, changes," 40 | behavior: default 41 | require_changes: false # if true: only post the comment if coverage changes 42 | require_base: no # [yes :: must have a base report to post] 43 | require_head: yes # [yes :: must have a head report to post] 44 | branches: # branch names that can post comment 45 | - staging 46 | - main 47 | -------------------------------------------------------------------------------- /notebooks/tutorial_diversity_measures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Tutorial Diversity Measures\n", 8 | "\n", 9 | "This tutorial demonstrates how to quantify the diversity of selected subset with `diversity` module as implemented in\n", 10 | "`selector` package. The diversity measures are calculated based on the feature matrix of the selected subset." 
11 |    ]
12 |   },
13 |   {
14 |    "cell_type": "code",
15 |    "execution_count": 1,
16 |    "metadata": {},
17 |    "outputs": [],
18 |    "source": [
19 |     "import sys\n",
20 |     "import warnings\n",
21 |     "\n",
22 |     "warnings.filterwarnings(\"ignore\")\n",
23 |     "\n",
24 |     "# uncomment the following line to run the code for your own project directory\n",
25 |     "# sys.path.append(\"/Users/Someone/Documents/projects/Selector\")"
26 |    ]
27 |   },
28 |   {
29 |    "cell_type": "code",
30 |    "execution_count": 2,
31 |    "metadata": {},
32 |    "outputs": [],
33 |    "source": [
34 |     "import matplotlib.pylab as plt\n",
35 |     "import numpy as np\n",
36 |     "from sklearn.datasets import make_blobs\n",
37 |     "from sklearn.metrics.pairwise import pairwise_distances\n",
38 |     "from IPython.display import Markdown, display\n",
39 |     "\n",
40 |     "from selector.methods.distance import MaxMin, MaxSum, OptiSim, DISE\n",
41 |     "from selector.measures.diversity import compute_diversity, hypersphere_overlap_of_subset"
42 |    ]
43 |   },
44 |   {
45 |    "cell_type": "markdown",
46 |    "metadata": {},
47 |    "source": [
48 |     "## Utility Function for Showing Diversity Measures as a Table\n"
49 |    ]
50 |   },
51 |   {
52 |    "cell_type": "code",
53 |    "execution_count": 3,
54 |    "metadata": {},
55 |    "outputs": [],
56 |    "source": [
57 |     "# define a helper function to render tables as markdown\n",
58 |     "\n",
59 |     "\n",
60 |     "def render_table(data, caption=None, decimals=3):\n",
61 |     "    \"\"\"Render a list of lists as a markdown table for easy visualization.\n",
62 |     "\n",
63 |     "    Parameters\n",
64 |     "    ----------\n",
65 |     "    data : list of lists\n",
66 |     "        The data to be rendered in a table; each inner list represents a row,\n",
67 |     "        with the first row being the header.\n",
68 |     "    caption : str, optional\n",
69 |     "        The caption of the table.\n",
70 |     "    decimals : int, optional\n",
71 |     "        The number of decimal places to round the data to.\n",
72 |     "    \"\"\"\n",
73 |     "\n",
74 |     "    # check all rows have the same number of columns\n",
75 |     "    if not all(len(row) == len(data[0]) for row in data):\n",
76 |     "        raise ValueError(\"Expect all rows to have the same number of columns.\")\n",
77 |     "\n",
78 |     "    tmp_output = \"\"\n",
79 |     "    if caption is not None:\n",
80 |     "        if not isinstance(caption, str):\n",
81 |     "            raise ValueError(\"Expect caption to be a string.\")\n",
82 |     "        tmp_output = f\"**{caption}**\\n\\n\"\n",
83 |     "\n",
84 |     "    # get the width of each column (transpose the data list and get the max length of each new row)\n",
85 |     "    colwidths = [max(len(str(s)) for s in col) + 2 for col in zip(*data)]\n",
86 |     "\n",
87 |     "    # construct the header row\n",
88 |     "    header = f\"| {' | '.join(f'{str(s):^{w}}' for s, w in zip(data[0], colwidths))} |\"\n",
89 |     "    tmp_output += header + \"\\n\"\n",
90 |     "\n",
91 |     "    # construct a separator row\n",
92 |     "    separator = f\"|{'|'.join(['-' * w for w in colwidths])}|\"\n",
93 |     "    tmp_output += separator + \"\\n\"\n",
94 |     "\n",
95 |     "    # construct the data rows\n",
96 |     "    for row in data[1:]:\n",
97 |     "        # round the data to the specified number of decimal places\n",
98 |     "        row = [round(s, decimals) if isinstance(s, float) else s for s in row]\n",
99 |     "        row_str = f\"| {' | '.join(f'{str(s):^{w}}' for s, w in zip(row, colwidths))} |\"\n",
100 |     "        tmp_output += row_str + \"\\n\"\n",
101 |     "\n",
102 |     "    return display(Markdown(tmp_output))"
103 |    ]
104 |   },
105 |   {
106 |    "cell_type": "markdown",
107 |    "metadata": {},
108 |    "source": [
109 |     "## Generating Data\n",
110 |     "\n",
111 |     "The data should be provided as:\n",
112 |     "\n",
113 |     "- either an array `X` of shape `(n_samples, n_features)` encoding `n_samples` samples (rows) each in `n_features`-dimensional (columns) feature space,\n",
114 |     "- or an array `X_dist` of shape `(n_samples, n_samples)` encoding the distance (i.e., dissimilarity) between each pair of `n_samples` sample points.\n",
115 |     "\n",
116 |     "This data can be loaded from various file formats (e.g., csv, npz, txt) or generated on the fly using various libraries. In this tutorial, we use [`sklearn.datasets.make_blobs`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html) to generate clusters of `n_samples` points in a 20-dimensional feature space (`n_features=20`), which are then binarized; the same functionality can be applied to datasets of any dimensionality.\n"
117 |    ]
118 |   },
119 |   {
120 |    "cell_type": "code",
121 |    "execution_count": 4,
122 |    "metadata": {},
123 |    "outputs": [
124 |     {
125 |      "name": "stdout",
126 |      "output_type": "stream",
127 |      "text": [
128 |       "Shape of data =  (500, 20)\n",
129 |       "Shape of labels =  (500,)\n",
130 |       "Unique labels =  [0 1 2]\n",
131 |       "Cluster size =  167\n",
132 |       "Shape of the distance array =  (500, 500)\n"
133 |      ]
134 |     }
135 |    ],
136 |    "source": [
137 |     "# generate n_samples points in a 20-dimensional feature space (make_blobs forms 3 clusters by default)\n",
138 |     "X, labels = make_blobs(\n",
139 |     "    n_samples=500,\n",
140 |     "    n_features=20,\n",
141 |     "    # centers=np.array([[0.0, 0.0]]),\n",
142 |     "    random_state=42,\n",
143 |     ")\n",
144 |     "# binarize the features\n",
145 |     "# calculate the median for each feature\n",
146 |     "median_threshold = np.median(X, axis=0)\n",
147 |     "X = (X > median_threshold).astype(int)\n",
148 |     "\n",
149 |     "# compute the (n_samples, n_samples) pairwise distance matrix\n",
150 |     "X_dist = pairwise_distances(X, metric=\"euclidean\")\n",
151 |     "\n",
152 |     "print(\"Shape of data = \", X.shape)\n",
153 |     "print(\"Shape of labels = \", labels.shape)\n",
154 |     "print(\"Unique labels = \", np.unique(labels))\n",
155 |     "print(\"Cluster size = \", np.count_nonzero(labels == 0))\n",
156 |     "print(\"Shape of the distance array = \", X_dist.shape)"
157 |    ]
158 |   },
159 |   {
160 |    "cell_type": "markdown",
161 |    "metadata": {},
162 |    "source": [
163 |     "## Perform the Subset Selection"
164 |    ]
165 |   },
166 |   {
167 |    "cell_type": "code",
168 |    "execution_count": 5,
169 |    "metadata": {},
170 |    "outputs": [],
171 |    "source": [
172 |     "from scipy.spatial.distance import pdist, squareform\n",
173 |     "\n",
174 |     "# select data using distance-based methods\n",
175 |     "# ----------------------------------------\n",
176 |     "size = 50\n",
177 |     "\n",
178 |     "collector = MaxMin()\n",
179 |     "index_maxmin = collector.select(X_dist, size=size)\n",
180 |     "\n",
181 |     "collector = MaxSum(fun_dist=lambda x: squareform(pdist(x, metric=\"minkowski\", p=0.1)))\n",
182 |     "index_maxsum = collector.select(X, size=size)\n",
183 |     "\n",
184 |     "collector = OptiSim(ref_index=0, tol=0.1)\n",
185 |     "index_optisim = collector.select(X_dist, size=size)\n",
186 |     "\n",
187 |     "collector = DISE(ref_index=0, p=2.0)\n",
188 |     "index_dise = collector.select(X, size=size)"
189 |    ]
190 |   },
191 |   {
192 |    "cell_type": "code",
193 |    "execution_count": 6,
194 |    "metadata": {},
195 |    "outputs": [
196 |     {
197 |      "data": {
198 |       "text/markdown": [
199 |        "**Diversity of Selected Sets**\n",
200 |        "\n",
201 |        "| | logdet | wdud | shannon_entropy | hypersphere_overlap |\n",
202 |        "|---------|--------------------|---------------------|--------------------|---------------------|\n",
203 |        "| MaxMin | 44.143 | 0.273 | 18.637 | 1299.615 |\n",
204 |        "| MaxSum | 33.938 | 0.261 | 19.379 | 4396.672 |\n",
205 |        "| OptiSim | 43.734 | 0.254 | 19.758 | 1175.49 |\n",
206 |        "| DISE | 45.402 | 0.268 | 18.958 | 1363.434 |\n"
207 |       ],
208 |       "text/plain": [
209 |        "<IPython.core.display.Markdown object>"
210 |       ]
211 |      },
212 |      "metadata": {},
213 |      "output_type": "display_data"
214 |     }
215 |    ],
216 |    "source": [
217 |     "div_measure = [\"logdet\", \"wdud\", \"shannon_entropy\", \"hypersphere_overlap\"]\n",
218 |     "selected_sets = zip(\n",
219 |     "    [\"MaxMin\", \"MaxSum\", \"OptiSim\", \"DISE\"],\n",
220 |     "    [index_maxmin, index_maxsum, index_optisim, index_dise],\n",
221 |     ")\n",
222 |     "\n",
223 |     "# compute the diversity of the selected sets and render the results in a table\n",
224 |     "table_data = [[\"\"] + div_measure]\n",
225 |     "for name, idx in selected_sets:\n",
226 |     "    div_data = [name]\n",
227 |     "    for m in div_measure:\n",
228 |     "        if m != \"hypersphere_overlap\":\n",
229 |     "            div_data.append(compute_diversity(X[idx], div_type=m))\n",
230 |     "        else:\n",
231 |     "            div_data.append(hypersphere_overlap_of_subset(x=X, x_subset=X[idx]))\n",
232 |     "    table_data.append(div_data)\n",
233 |     "\n",
234 |     "render_table(table_data, caption=\"Diversity of Selected Sets\")"
235 |    ]
236 |   }
237 |  ],
238 |  "metadata": {
239 |   "kernelspec": {
240 |    "display_name": "selector_div",
241 |    "language": "python",
242 |    "name": "python3"
243 |   },
244 |   "language_info": {
245 |    "codemirror_mode": {
246 |     "name": "ipython",
247 |     "version": 3
248 |    },
249 |    "file_extension": ".py",
250 |    "mimetype": "text/x-python",
251 |    "name": "python",
252 |    "nbconvert_exporter": "python",
253 |    "pygments_lexer": "ipython3",
254 |    "version": "3.11.9"
255 |   }
256 |  },
257 |  "nbformat": 4,
258 |  "nbformat_minor": 2
259 | }
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # The Selector library provides a set of tools for selecting a
2 | # subset of the dataset and computing diversity.
3 | #
4 | # Copyright (C) 2023 The QC-Devs Community
5 | #
6 | # This file is part of Selector.
7 | #
8 | # Selector is free software; you can redistribute it and/or
9 | # modify it under the terms of the GNU General Public License
10 | # as published by the Free Software Foundation; either version 3
11 | # of the License, or (at your option) any later version.
12 | #
13 | # Selector is distributed in the hope that it will be useful,
14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | # GNU General Public License for more details.
17 | #
18 | # You should have received a copy of the GNU General Public License
19 | # along with this program; if not, see <http://www.gnu.org/licenses/>
20 | #
21 | # --
22 | 
23 | 
24 | [project]
25 | # https://packaging.python.org/en/latest/specifications/declaring-project-metadata/
26 | name = "qc-selector"
27 | description = "Subset selection with maximum diversity."
28 | readme = {file = 'README.md', content-type='text/markdown'} 29 | #requires-python = ">=3.9,<4.0" 30 | requires-python = ">=3.9" 31 | # "LICENSE" is name of the license file, which must be in root of project folder 32 | license = {file = "LICENSE"} 33 | authors = [ 34 | {name = "QC-Devs Community", email = "qcdevs@gmail.com"}, 35 | ] 36 | keywords = [ 37 | "subset selection", 38 | "variable selection", 39 | "chemical diversity", 40 | "compound selection", 41 | "maximum diversity", 42 | "chemical library design", 43 | "compound acquisition", 44 | ] 45 | 46 | # https://pypi.org/classifiers/ 47 | # Add PyPI classifiers here 48 | classifiers = [ 49 | "Development Status :: 5 - Production/Stable", 50 | "Environment :: Console", 51 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 52 | "Natural Language :: English", 53 | "Operating System :: MacOS", 54 | "Operating System :: Microsoft :: Windows", 55 | "Operating System :: Unix", 56 | "Operating System :: POSIX", 57 | "Operating System :: POSIX :: Linux", 58 | "Programming Language :: Python", 59 | "Programming Language :: Python :: 3", 60 | "Programming Language :: Python :: 3.9", 61 | "Programming Language :: Python :: 3.10", 62 | "Programming Language :: Python :: 3.11", 63 | "Programming Language :: Python :: 3.12", 64 | "Topic :: Scientific/Engineering", 65 | "Topic :: Sociology", 66 | 67 | ] 68 | 69 | # version = "0.0.2b4" 70 | dynamic = [ 71 | "dependencies", 72 | "optional-dependencies", 73 | "version", 74 | ] 75 | 76 | # not using this section now but it's here for reference 77 | # # Required dependencies for install/usage of your package or application 78 | # # If you don't have any dependencies, leave this section empty 79 | # # Format for dependency strings: https://peps.python.org/pep-0508/ 80 | # # dependencies from requirements.txt 81 | # dependencies = [ 82 | # "bitarray>=2.5.1", 83 | # "numpy>=1.21.2", 84 | # "scipy>=1.11.1", 85 | # ] 86 | 87 | # [project.optional-dependencies] 88 | # tests = [ 89 | # 'coverage>=5.0.3', 90 | # "pandas>=1.3.5", 91 | # "pytest>=7.4.0", 92 | # "scikit-learn>=1.0.1", 93 | # ] 94 | 95 | [project.scripts] 96 | # Command line interface entrypoint scripts 97 | # selector = "selector.__main__:main" 98 | 99 | [project.urls] 100 | # Use PyPI-standard names here 101 | # Homepage 102 | # Documentation 103 | # Changelog 104 | # Issue Tracker 105 | # Source 106 | # Discord server 107 | homepage = "https://github.com/theochem/Selector" 108 | documentation = "https://selector.qcdevs.org/" 109 | repository = "https://github.com/theochem/Selector" 110 | 111 | # Development dependencies 112 | # pip install -e .[lint,test,exe] 113 | # pip install -e .[dev] 114 | 115 | # we can only provide one optional dependencies or dynamic dependencies 116 | # we can't provide both, which leads to errors 117 | # [project.optional-dependencies] 118 | # lint = [ 119 | # # ruff linter checks for issues and potential bugs 120 | # "ruff", 121 | # # checks for unused code 122 | # # "vulture", 123 | # # # required for codespell to parse pyproject.toml 124 | # # "tomli", 125 | # # # validation of pyproject.toml 126 | # # "validate-pyproject[all]", 127 | # # automatic sorting of imports 128 | # "isort", 129 | # # # automatic code formatting to follow a consistent style 130 | # # "black", 131 | # ] 132 | 133 | # test = [ 134 | # # Handles most of the testing work, including execution 135 | # # Docs: https://docs.pytest.org/en/stable/contents.html 136 | # "pytest>=7.4.0", 137 | # # required by pytest 138 | # 
"hypothesis", 139 | # # "Coverage" is how much of the code is actually run (it's "coverage") 140 | # # Generates coverage reports from test suite runs 141 | # "pytest-cov>=3.0.0", 142 | # "tomli", 143 | # "scikit-learn>=1.0.1", 144 | # # Better parsing of doctests 145 | # "xdoctest", 146 | # # Colors for doctest output 147 | # "Pygments", 148 | # ] 149 | 150 | # exe = [ 151 | # "setuptools", 152 | # "wheel", 153 | # "build", 154 | # "tomli", 155 | # "pyinstaller", 156 | # "staticx;platform_system=='Linux'", 157 | # ] 158 | 159 | # dev = [ 160 | # # https://hynek.me/articles/python-recursive-optional-dependencies/ 161 | # "selector[lint,test,exe]", 162 | 163 | # # # Code quality tools 164 | # # "mypy", 165 | 166 | # # # Improved exception traceback output 167 | # # # https://github.com/qix-/better-exceptions 168 | # # "better_exceptions", 169 | 170 | # # # Analyzing dependencies 171 | # # # install graphviz to generate graphs 172 | # # "graphviz", 173 | # # "pipdeptree", 174 | # ] 175 | 176 | [build-system] 177 | # Minimum requirements for the build system to execute. 178 | requires = ["setuptools>=64", "setuptools-scm>=8", "wheel"] 179 | build-backend = "setuptools.build_meta" 180 | 181 | [tool.setuptools.dynamic] 182 | dependencies = {file = ["requirements.txt"]} 183 | optional-dependencies = {dev = { file = ["requirements_dev.txt"] }} 184 | 185 | [tool.setuptools_scm] 186 | # can be empty if no extra settings are needed, presence enables setuptools-scm 187 | 188 | [tool.setuptools] 189 | # https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html 190 | platforms = ["Linux", "Windows", "MacOS"] 191 | include-package-data = true 192 | # This just means it's safe to zip up the bdist 193 | zip-safe = true 194 | 195 | # Non-code data that should be included in the package source code 196 | # https://setuptools.pypa.io/en/latest/userguide/datafiles.html 197 | [tool.setuptools.package-data] 198 | selector = ["*.xml"] 199 | 200 | # Python modules and packages that are included in the 201 | # distribution package (and therefore become importable) 202 | [tool.setuptools.packages.find] 203 | exclude = ["*/*/tests", "tests_*", "examples", "notebooks"] 204 | 205 | 206 | # PDM example 207 | #[tool.pdm.scripts] 208 | #isort = "isort selector" 209 | #black = "black selector" 210 | #format = {composite = ["isort", "black"]} 211 | #check_isort = "isort --check selector tests" 212 | #check_black = "black --check selector tests" 213 | #vulture = "vulture --min-confidence 100 selector tests" 214 | #ruff = "ruff check selector tests" 215 | #fix = "ruff check --fix selector tests" 216 | #codespell = "codespell --toml ./pyproject.toml" 217 | #lint = {composite = ["vulture", "codespell", "ruff", "check_isort", "check_black"]} 218 | 219 | 220 | #[tool.codespell] 221 | ## codespell supports pyproject.toml since version 2.2.2 222 | ## NOTE: the "tomli" package must be installed for this to work 223 | ## https://github.com/codespell-project/codespell#using-a-config-file 224 | ## NOTE: ignore words for codespell must be lowercase 225 | #check-filenames = "" 226 | #ignore-words-list = "word,another,something" 227 | #skip = "htmlcov,.doctrees,*.pyc,*.class,*.ico,*.out,*.PNG,*.inv,*.png,*.jpg,*.dot" 228 | 229 | 230 | [tool.black] 231 | line-length = 100 232 | # If you need to exclude directories from being reformatted by black 233 | # force-exclude = ''' 234 | # ( 235 | # somedirname 236 | # | dirname 237 | # | filename\.py 238 | # ) 239 | # ''' 240 | 241 | 242 | [tool.isort] 243 | profile = "black" 244 | 
known_first_party = ["selector"] 245 | # If you need to exclude files from having their imports sorted 246 | #extend_skip_glob = [ 247 | # "selector/somefile.py", 248 | # "selector/somedir/*", 249 | #] 250 | 251 | 252 | # https://beta.ruff.rs/docs 253 | [tool.ruff] 254 | line-length = 100 255 | show-source = true 256 | 257 | # Rules: https://beta.ruff.rs/docs/rules 258 | # If you violate a rule, lookup the rule on the Rules page in ruff docs. 259 | # Many rules have links you can click with a explanation of the rule and how to fix it. 260 | # If there isn't a link, go to the project the rule was source from (e.g. flake8-bugbear) 261 | # and review it's docs for the corresponding rule. 262 | # If you're still confused, ask a fellow developer for assistance. 263 | # You can also run "ruff rule " to explain a rule on the command line, without a browser or internet access. 264 | select = [ 265 | "E", # pycodestyle 266 | "F", # Pyflakes 267 | "W", # Warning 268 | "B", # flake8-bugbear 269 | "A", # flake8-builtins 270 | "C4", # flake8-comprehensions 271 | "T10", # flake8-debugger 272 | "EXE", # flake8-executable, 273 | "ISC", # flake8-implicit-str-concat 274 | "G", # flake8-logging-format 275 | "PIE", # flake8-pie 276 | "T20", # flake8-print 277 | "PT", # flake8-pytest-style 278 | "RSE", # flake8-raise 279 | "RET", # flake8-return 280 | "TID", # flake8-tidy-imports 281 | "ARG", # flake8-unused-arguments 282 | "PGH", # pygrep-hooks 283 | "PLC", # Pylint Convention 284 | "PLE", # Pylint Errors 285 | "PLW", # Pylint Warnings 286 | "RUF", # Ruff-specific rules 287 | 288 | # ** Things to potentially enable in the future ** 289 | # DTZ requires all usage of datetime module to have timezone-aware 290 | # objects (so have a tz argument or be explicitly UTC). 291 | # "DTZ", # flake8-datetimez 292 | # "PTH", # flake8-use-pathlib 293 | # "SIM", # flake8-simplify 294 | ] 295 | 296 | # Files to exclude from linting 297 | extend-exclude = [ 298 | "*.pyc", 299 | "__pycache__", 300 | "*.egg-info", 301 | ".eggs", 302 | # check point files of jupyter notebooks 303 | "*.ipynb_checkpoints", 304 | ".tox", 305 | ".git", 306 | "build", 307 | "dist", 308 | "docs", 309 | "examples", 310 | "htmlcov", 311 | "notebooks", 312 | ".cache", 313 | "_version.py", 314 | ] 315 | 316 | # Linting error codes to ignore 317 | ignore = [ 318 | "F403", # unable to detect undefined names from star imports 319 | "F405", # undefined locals from star imports 320 | "W605", # invalid escape sequence 321 | "A003", # shadowing python builtins 322 | "RET505", # unnecessary 'else' after 'return' statement 323 | "RET504", # Unnecessary variable assignment before return statement 324 | "RET507", # Unnecessary {branch} after continue statement 325 | "PT011", # pytest-raises-too-broad 326 | "PT012", # pytest.raises() block should contain a single simple statement 327 | "PLW0603", # Using the global statement to update is discouraged 328 | "PLW2901", # for loop variable overwritten by assignment target 329 | "G004", # Logging statement uses f-string 330 | "PIE790", # no-unnecessary-pass 331 | "PIE810", # multiple-starts-ends-with 332 | "PGH003", # Use specific rule codes when ignoring type issues 333 | "PLC1901", # compare-to-empty-string 334 | ] 335 | 336 | # Linting error codes to ignore on a per-file basis 337 | [tool.ruff.per-file-ignores] 338 | "__init__.py" = ["F401", "E501"] 339 | "selector/somefile.py" = ["E402", "E501"] 340 | "selector/somedir/*" = ["E501"] 341 | 342 | 343 | # Configuration for mypy 344 | # 
https://mypy.readthedocs.io/en/stable/config_file.html#using-a-pyproject-toml-file 345 | [tool.mypy] 346 | python_version = "3.9" 347 | follow_imports = "skip" 348 | ignore_missing_imports = true 349 | files = "selector" # directory mypy should analyze 350 | # Directories to exclude from mypy's analysis 351 | exclude = [ 352 | "book", 353 | ] 354 | 355 | 356 | # Configuration for pytest 357 | # https://docs.pytest.org/en/latest/reference/customize.html#pyproject-toml 358 | [tool.pytest.ini_options] 359 | addopts = [ 360 | # Allow test files to have the same name in different directories. 361 | "--import-mode=importlib", 362 | "--cache-clear", 363 | "--showlocals", 364 | "-v", 365 | "-r a", 366 | "--cov-report=term-missing", 367 | "--cov=selector", 368 | ] 369 | # directory containing the tests 370 | testpaths = [ 371 | "selector/measures/tests", 372 | "selector/methods/tests", 373 | ] 374 | norecursedirs = [ 375 | ".vscode", 376 | "__pycache__", 377 | "build", 378 | ] 379 | # Warnings that should be ignored 380 | filterwarnings = [ 381 | "ignore::DeprecationWarning" 382 | ] 383 | # custom markers that can be used using pytest.mark 384 | markers = [ 385 | "slow: lower-importance tests that take an excessive amount of time", 386 | ] 387 | 388 | 389 | # Configuration for coverage.py 390 | [tool.coverage.run] 391 | # files or directories to exclude from coverage calculations 392 | omit = [ 393 | 'selector/measures/tests/*', 394 | 'selector/methods/tests/*', 395 | ] 396 | 397 | 398 | # Configuration for vulture 399 | [tool.vulture] 400 | # Files or directories to exclude from vulture 401 | # The syntax is a little funky 402 | exclude = [ 403 | "somedir", 404 | "*somefile.py", 405 | ] 406 | 407 | # configuration for bandit 408 | [tool.bandit] 409 | exclude_dirs = [ 410 | "selector/measures/tests", 411 | "selector/methods/tests", 412 | ] 413 | skips = [ 414 | "B101", # Ignore assert statements 415 | "B311", # Ignore pseudo-random generators 416 | "B404", # Ignore subprocess import 417 | "B603", # Ignore subprocess call 418 | "B607", # Ignore subprocess call 419 | ] 420 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bitarray>=2.5.1 2 | numpy>=1.21.2 3 | scipy>=1.11.1 4 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | bitarray>=2.5.1 2 | coverage>=6.3.2 3 | hypothesis 4 | numpy>=1.21.2 5 | pre-commit 6 | pytest>=7.4.0 7 | pytest-cov>=3.0.0 8 | scikit-learn>=1.0.1 9 | scipy>=1.11.1 10 | setuptools>=64.0.0 11 | tomli 12 | xdoctest 13 | setuptools-scm>=8.0.0 14 | -------------------------------------------------------------------------------- /selector/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 
14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see 22 | # 23 | # -- 24 | 25 | """Selector Package.""" 26 | 27 | from selector.methods import * 28 | -------------------------------------------------------------------------------- /selector/measures/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see 22 | # 23 | # -- 24 | -------------------------------------------------------------------------------- /selector/measures/converter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see 22 | # 23 | # -- 24 | """Module for converting similarity measures to distance/dissimilarity measures.""" 25 | 26 | from typing import Union 27 | 28 | import numpy as np 29 | 30 | __all__ = [ 31 | "sim_to_dist", 32 | "reverse", 33 | "reciprocal", 34 | "exponential", 35 | "gaussian", 36 | "correlation", 37 | "transition", 38 | "co_occurrence", 39 | "gravity", 40 | "probability", 41 | "covariance", 42 | ] 43 | 44 | 45 | def sim_to_dist( 46 | x: Union[int, float, np.ndarray], metric: str, scaling_factor: float = 1.0 47 | ) -> Union[float, np.ndarray]: 48 | """Convert similarity coefficients to distance array. 49 | 50 | Parameters 51 | ---------- 52 | x : float or ndarray 53 | A similarity value as float, or a 1D or 2D array of similarity values. 54 | If 2D, the array is assumed to be symmetric. 
55 | metric : str 56 | String specifying which conversion metric to use. 57 | Supported metrics are "reverse", "reciprocal", "exponential", 58 | "gaussian", "membership", "correlation", "transition", "co-occurrence", 59 | "gravity", "confusion", "probability", and "covariance". 60 | scaling_factor : float, optional 61 | Scaling factor applied to the similarity values before conversion. Default is 1.0. 62 | 63 | Returns 64 | ------- 65 | dist : float or ndarray 66 | Distance value or array. 67 | """ 68 | # scale the similarity values before conversion 69 | x = x * scaling_factor 70 | 71 | frequency = { 72 | "transition": transition, 73 | "co-occurrence": co_occurrence, 74 | "gravity": gravity, 75 | } 76 | method_dict = { 77 | "reverse": reverse, 78 | "reciprocal": reciprocal, 79 | "exponential": exponential, 80 | "gaussian": gaussian, 81 | "correlation": correlation, 82 | "probability": probability, 83 | "covariance": covariance, 84 | } 85 | 86 | # check if x is a single value 87 | single_value = False 88 | if isinstance(x, (float, int)): 89 | x = np.array([[x]]) 90 | single_value = True 91 | 92 | # check x 93 | if not isinstance(x, np.ndarray): 94 | raise ValueError(f"Argument x should be a numpy array instead of {type(x)}") 95 | # check that x is a valid array 96 | if x.ndim != 1 and x.ndim != 2: 97 | raise ValueError(f"Argument x should either have 1 or 2 dimensions, got {x.ndim}.") 98 | if x.ndim == 1 and metric in ["co-occurrence", "gravity"]: 99 | raise ValueError(f"Argument x should be a 2D array when using the {metric} metric.") 100 | # check if x is symmetric 101 | if x.ndim == 2 and not np.allclose(x, x.T): 102 | raise ValueError("Argument x should be a symmetric array.") 103 | 104 | # call correct metric function 105 | if metric in frequency: 106 | if np.any(x <= 0): 107 | raise ValueError( 108 | "There is a negative or zero value in the input. Please " 109 | "make sure all frequency values are positive." 110 | ) 111 | dist = frequency[metric](x) 112 | elif metric in method_dict: 113 | dist = method_dict[metric](x) 114 | elif metric == "membership" or metric == "confusion": 115 | if np.any(x < 0) or np.any(x > 1): 116 | raise ValueError( 117 | "There is an out of bounds value. Please make " 118 | "sure all input values are between [0, 1]." 119 | ) 120 | dist = 1 - x 121 | # unsupported metric 122 | else: 123 | raise ValueError(f"{metric} is an unsupported metric.") 124 | 125 | # convert back to float if input was single value 126 | if single_value: 127 | dist = dist.item((0, 0)) 128 | 129 | return dist 130 | 131 | 132 | def reverse(x: np.ndarray) -> np.ndarray: 133 | r"""Calculate distance array from similarity using the reverse method. 134 | 135 | .. math:: 136 | \delta_{ij} = min(s_{ij}) + max(s_{ij}) - s_{ij} 137 | 138 | where :math:`\delta_{ij}` is the distance between points :math:`i` 139 | and :math:`j`, :math:`s_{ij}` is their similarity coefficient, 140 | and :math:`max` and :math:`min` are the maximum and minimum 141 | values across the entire similarity array. 142 | 143 | Parameters 144 | ----------- 145 | x : ndarray 146 | Similarity array. 147 | 148 | Returns 149 | ------- 150 | dist : ndarray 151 | Distance array. 152 | 153 | """ 154 | dist = np.max(x) + np.min(x) - x 155 | return dist 156 | 157 | 158 | def reciprocal(x: np.ndarray) -> np.ndarray: 159 | r"""Calculate distance array from similarity using the reciprocal method. 160 | 161 | ..
math:: 162 | \delta_{ij} = \frac{1}{s_{ij}} 163 | 164 | where :math:`\delta_{ij}` is the distance between points :math:`i` 165 | and :math:`j`, and :math:`s_{ij}` is their similarity coefficient. 166 | 167 | Parameters 168 | ----------- 169 | x : ndarray 170 | Similarity array. 171 | 172 | Returns 173 | ------- 174 | dist : ndarray 175 | Distance array. 176 | """ 177 | 178 | if np.any(x <= 0): 179 | raise ValueError( 180 | "There is an out of bounds value. Please make " "sure all similarities are positive." 181 | ) 182 | return 1 / x 183 | 184 | 185 | def exponential(x: np.ndarray) -> np.ndarray: 186 | r"""Calculate distance matrix from similarity using the exponential method. 187 | 188 | .. math:: 189 | \delta_{ij} = -\ln{\frac{s_{ij}}{max(s_{ij})}} 190 | 191 | where :math:`\delta_{ij}` is the distance between points :math:`i` 192 | and :math:`j`, and :math:`s_{ij}` is their similarity coefficient. 193 | 194 | Parameters 195 | ----------- 196 | x : ndarray 197 | Similarity array. 198 | 199 | Returns 200 | ------- 201 | dist : ndarray 202 | Distance array. 203 | 204 | """ 205 | max_sim = np.max(x) 206 | if max_sim == 0: 207 | raise ValueError("Maximum similarity in `x` is 0. Distance cannot be computed.") 208 | dist = -np.log(x / max_sim) 209 | return dist 210 | 211 | 212 | def gaussian(x: np.ndarray) -> np.ndarray: 213 | r"""Calculate distance matrix from similarity using the Gaussian method. 214 | 215 | .. math:: 216 | \delta_{ij} = \sqrt{-\ln{\frac{s_{ij}}{max(s_{ij})}}} 217 | 218 | where :math:`\delta_{ij}` is the distance between points :math:`i` 219 | and :math:`j`, and :math:`s_{ij}` is their similarity coefficient. 220 | 221 | Parameters 222 | ----------- 223 | x : ndarray 224 | Similarity array. 225 | 226 | Returns 227 | ------- 228 | dist : ndarray 229 | Distance array. 230 | 231 | """ 232 | max_sim = np.max(x) 233 | if max_sim == 0: 234 | raise ValueError("Maximum similarity in `x` is 0. Distance cannot be computed.") 235 | y = x / max_sim 236 | dist = np.sqrt(-np.log(y)) 237 | return dist 238 | 239 | 240 | def correlation(x: np.ndarray) -> np.ndarray: 241 | r"""Calculate distance array from correlation array. 242 | 243 | .. math:: 244 | \delta_{ij} = \sqrt{1 - r_{ij}} 245 | 246 | where :math:`\delta_{ij}` is the distance between points :math:`i` 247 | and :math:`j`, and :math:`r_{ij}` is their correlation. 248 | 249 | Parameters 250 | ----------- 251 | x : ndarray 252 | Correlation array. 253 | 254 | Returns 255 | ------- 256 | dist : ndarray 257 | Distance array. 258 | 259 | """ 260 | if np.any(x < -1) or np.any(x > 1): 261 | raise ValueError( 262 | "There is an out of bounds value. Please make " 263 | "sure all correlations are between [-1, 1]." 264 | ) 265 | dist = np.sqrt(1 - x) 266 | return dist 267 | 268 | 269 | def transition(x: np.ndarray) -> np.ndarray: 270 | r"""Calculate distance array from frequency using the transition method. 271 | 272 | .. math:: 273 | \delta_{ij} = \frac{1}{\sqrt{f_{ij}}} 274 | 275 | where :math:`\delta_{ij}` is the distance between points :math:`i` 276 | and :math:`j`, and :math:`f_{ij}` is their frequency. 277 | 278 | Parameters 279 | ----------- 280 | x : ndarray 281 | Symmetric frequency array. 282 | 283 | Returns 284 | ------- 285 | dist : ndarray 286 | Distance array. 287 | 288 | """ 289 | dist = 1 / np.sqrt(x) 290 | return dist 291 | 292 | 293 | def co_occurrence(x: np.ndarray) -> np.ndarray: 294 | r"""Calculate distance array from frequency using the co-occurrence method. 295 | 296 | .. 
math:: 297 | \delta_{ij} = \left(1 + \frac{f_{ij}\sum_{i,j}{f_{ij}}}{\sum_{i}{f_{ij}}\sum_{j}{f_{ij}}} \right)^{-1} 298 | 299 | where :math:`\delta_{ij}` is the distance between points :math:`i` 300 | and :math:`j`, and :math:`f_{ij}` is their frequency. 301 | 302 | Parameters 303 | ----------- 304 | x : ndarray 305 | Frequency array. 306 | 307 | Returns 308 | ------- 309 | dist : ndarray 310 | Co-occurrence array. 311 | 312 | """ 313 | # compute sums along each axis 314 | i = np.sum(x, axis=0, keepdims=True) 315 | j = np.sum(x, axis=1, keepdims=True) 316 | # multiply sums to scalar value 317 | bottom = np.dot(i, j) 318 | # multiply each element by the sum of entire array 319 | top = x * np.sum(x) 320 | # evaluate function as a whole 321 | dist = (1 + (top / bottom)) ** -1 322 | return dist 323 | 324 | 325 | def gravity(x: np.ndarray) -> np.ndarray: 326 | r"""Calculate distance array from frequency using the gravity method. 327 | 328 | .. math:: 329 | \delta_{ij} = \sqrt{\frac{\sum_{i}{f_{ij}}\sum_{j}{f_{ij}}} 330 | {f_{ij}\sum_{i,j}{f_{ij}}}} 331 | 332 | where :math:`\delta_{ij}` is the distance between points :math:`i` 333 | and :math:`j`, and :math:`f_{ij}` is their frequency. 334 | 335 | Parameters 336 | ----------- 337 | x : ndarray 338 | Symmetric frequency array. 339 | 340 | Returns 341 | ------- 342 | dist : ndarray 343 | Symmetric gravity array. 344 | 345 | """ 346 | # compute sums along each axis 347 | i = np.sum(x, axis=0, keepdims=True) 348 | j = np.sum(x, axis=1, keepdims=True) 349 | # multiply sums to scalar value 350 | top = np.dot(i, j) 351 | # multiply each element by the sum of entire array 352 | bottom = x * np.sum(x) 353 | # take square root of the fraction 354 | dist = np.sqrt(top / bottom) 355 | return dist 356 | 357 | 358 | def probability(x: np.ndarray) -> np.ndarray: 359 | r"""Calculate distance array from probability array. 360 | 361 | .. math:: 362 | \delta_{ij} = \frac{1}{\sqrt{\arcsin(p_{ij})}} 363 | 364 | where :math:`\delta_{ij}` is the distance between points :math:`i` 365 | and :math:`j`, and :math:`p_{ij}` is their probability. 366 | 367 | Parameters 368 | ----------- 369 | x : ndarray 370 | Symmetric probability array. 371 | 372 | Returns 373 | ------- 374 | dist : ndarray 375 | Distance array. 376 | 377 | """ 378 | if np.any(x <= 0) or np.any(x > 1): 379 | raise ValueError( 380 | "There is an out of bounds value. Please make " 381 | "sure all probabilities are between (0, 1]." 382 | ) 383 | y = np.arcsin(x) 384 | dist = 1 / np.sqrt(y) 385 | return dist 386 | 387 | 388 | def covariance(x: np.ndarray) -> np.ndarray: 389 | r"""Calculate distance array from similarity using the covariance method. 390 | 391 | .. math:: 392 | \delta_{ij} = \sqrt{s_{ii}+s_{jj}-2s_{ij}} 393 | 394 | where :math:`\delta_{ij}` is the distance between points :math:`i` 395 | and :math:`j`, :math:`s_{ii}` and :math:`s_{jj}` are the variances 396 | of feature :math:`i` and feature :math:`j`, and :math:`s_{ij}` 397 | is the covariance between the two features. 398 | 399 | Parameters 400 | ----------- 401 | x : ndarray 402 | Covariance array. 403 | 404 | Returns 405 | ------- 406 | dist : ndarray 407 | Distance array.
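# --- Illustrative usage sketch (not part of the library source; inputs are made up) ---
# How the converters above behave on a small symmetric similarity matrix, and a
# hand-check of the covariance conversion against raw data.
import numpy as np

from selector.measures.converter import covariance, exponential, gaussian, sim_to_dist

sim = np.array([[1.0, 0.5, 0.2],
                [0.5, 1.0, 0.25],
                [0.2, 0.25, 1.0]])
# "reverse" maps the largest similarity to the smallest distance: min(s) + max(s) - s
print(sim_to_dist(sim, metric="reverse"))
# scalar inputs round-trip through a 1x1 array: 1 / 0.5 = 2.0
print(sim_to_dist(0.5, metric="reciprocal"))
# the gaussian conversion is the square root of the exponential one
assert np.allclose(gaussian(sim) ** 2, exponential(sim))

# covariance: sqrt(s_ii + s_jj - 2 s_ij) is the standard deviation of the
# difference of two variables, so it can be verified directly on random data
rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 3))
dist = covariance(np.cov(data, rowvar=False))
assert np.isclose(dist[0, 1], np.std(data[:, 0] - data[:, 1], ddof=1))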
408 | 409 | """ 410 | variance = np.diag(x).reshape([x.shape[0], 1]) * np.ones([1, x.shape[0]]) 411 | if np.any(variance < 0): 412 | raise ValueError("Variance of a single variable cannot be negative.") 413 | 414 | dist = variance + variance.T - 2 * x 415 | dist = np.sqrt(dist) 416 | return dist 417 | -------------------------------------------------------------------------------- /selector/measures/diversity.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see <http://www.gnu.org/licenses/> 22 | # 23 | # -- 24 | 25 | """Molecule dataset diversity calculation module.""" 26 | 27 | import warnings 28 | 29 | import numpy as np 30 | 31 | from selector.measures.similarity import tanimoto 32 | 33 | __all__ = [ 34 | "compute_diversity", 35 | "logdet", 36 | "shannon_entropy", 37 | "explicit_diversity_index", 38 | "wdud", 39 | "hypersphere_overlap_of_subset", 40 | "gini_coefficient", 41 | "nearest_average_tanimoto", 42 | ] 43 | 44 | 45 | def compute_diversity( 46 | feature_subset: np.array, 47 | div_type: str = "shannon_entropy", 48 | normalize: bool = False, 49 | truncation: bool = False, 50 | features: np.array = None, 51 | cs: int = None, 52 | ) -> float: 53 | """Compute diversity metrics. 54 | 55 | Parameters 56 | ---------- 57 | feature_subset : np.ndarray 58 | Feature matrix. 59 | div_type : str, optional 60 | Method of calculating diversity for a given molecule set, which 61 | includes "logdet", "shannon_entropy", "wdud", 62 | "gini_coefficient", "hypersphere_overlap", and 63 | "explicit_diversity_index". 64 | The default is "shannon_entropy". 65 | normalize : bool, optional 66 | Normalize the entropy to [0, 1]. Default is False. 67 | truncation : bool, optional 68 | Use the truncated Shannon entropy. Default is False. 69 | features : np.ndarray, optional 70 | Feature matrix of entire molecule library, used only if 71 | calculating `hypersphere_overlap_of_subset`. Default is None. 72 | cs : int, optional 73 | Number of common substructures in molecular compound dataset. 74 | Used only if calculating `explicit_diversity_index`. Default is None. 75 | 76 | 77 | Returns 78 | ------- 79 | float : The computed diversity. 80 | 81 | """ 82 | func_dict = { 83 | "logdet": logdet, 84 | "wdud": wdud, 85 | "gini_coefficient": gini_coefficient, 86 | } 87 | 88 | if div_type in func_dict: 89 | return func_dict[div_type](feature_subset) 90 | 91 | # hypersphere overlap of subset 92 | elif div_type == "hypersphere_overlap": 93 | if features is None: 94 | raise ValueError( 95 | "Please input a feature matrix of the entire " 96 | "dataset when calculating hypersphere overlap."
97 | ) 98 | return hypersphere_overlap_of_subset(features, feature_subset) 99 | 100 | elif div_type == "shannon_entropy": 101 | return shannon_entropy(feature_subset, normalize=normalize, truncation=truncation) 102 | 103 | elif div_type == "explicit_diversity_index": 104 | if cs is None: 105 | raise ValueError( 106 | "Attribute `cs` is missing. " 107 | "Please input `cs` value to use explicit_diversity_index." 108 | ) 109 | elif cs == 0: 110 | raise ValueError("Divide by zero error: Attribute `cs` cannot be 0.") 111 | return explicit_diversity_index(feature_subset, cs) 112 | else: 113 | raise ValueError(f"Diversity type {div_type} not supported.") 114 | 115 | 116 | def logdet(x: np.ndarray) -> float: 117 | r"""Compute the log determinant function. 118 | 119 | Given a :math:`n_s \times n_f` feature matrix :math:`x`, where :math:`n_s` is the number of 120 | samples and :math:`n_f` is the number of features, the log determinant function is defined as: 121 | 122 | .. math:: 123 | F_\text{logdet} = \log{\left(\det{\left(\mathbf{x}\mathbf{x}^T + \mathbf{I}\right)}\right)} 124 | 125 | where the :math:`I` is the :math:`n_s \times n_s` identity matrix. 126 | Higher values of :math:`F_\text{logdet}` mean more diversity. 127 | 128 | Parameters 129 | ---------- 130 | x: ndarray of shape (n_samples, n_features) 131 | Feature matrix of `n_samples` samples in `n_features` dimensional feature space. 132 | 133 | Returns 134 | ------- 135 | f_logdet: float 136 | The log determinant, a measure of the volume of the parallelotope spanned by the matrix. 137 | 138 | Notes 139 | ----- 140 | The log-determinant function is based on the formula in Nakamura, T., Sci Rep 2022. 141 | Please note that we used the 142 | natural logarithm to avoid numerical stability issues, 143 | https://github.com/theochem/Selector/issues/229. 144 | 145 | References 146 | ---------- 147 | Nakamura, T., Sakaue, S., Fujii, K., Harabuchi, Y., Maeda, S., and Iwata, S., Selecting 148 | molecules with diverse structures and properties by maximizing submodular functions of 149 | descriptors learned with graph neural networks. Scientific Reports 12, 2022. 150 | 151 | """ 152 | mid = np.dot(x, np.transpose(x)) + np.identity(x.shape[0]) 153 | logdet_mid = np.linalg.slogdet(mid) 154 | f_logdet = logdet_mid.sign * logdet_mid.logabsdet 155 | return f_logdet 156 | 157 | 158 | def shannon_entropy(x: np.ndarray, normalize=True, truncation=False) -> float: 159 | r"""Compute the Shannon entropy of a binary matrix. 160 | 161 | Higher values mean more diversity. 162 | 163 | Parameters 164 | ---------- 165 | x : ndarray 166 | Bit-string matrix. 167 | normalize : bool, optional 168 | Normalize the entropy to [0, 1]. Default=True. 169 | truncation : bool, optional 170 | Use the truncated Shannon entropy by only counting the contributions of one-bits. 171 | Default=False. 172 | 173 | Returns 174 | ------- 175 | float : 176 | The Shannon entropy of the matrix. 177 | 178 | Notes 179 | ----- 180 | Suppose we have :math:`m` compounds and each compound has :math:`n` bits binary fingerprints. 181 | The binary matrix (feature matrix) is :math:`\mathbf{x} \in m \times n`, where each 182 | row is a compound and each column contains the :math:`n`-bit binary fingerprint. 183 | The equation for Shannon entropy is given by [1]_ and [3]_, 184 | 185 | .. math:: 186 | H = \sum_i^n \left[ - p_i \log_2{p_i } - (1 - p_i)\log_2(1 - p_i) \right] 187 | 188 | where :math:`p_i` represents the relative frequency of `1` bits at the fingerprint position 189 | :math:`i`.
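# --- Illustrative usage sketch (not part of the library source; the fingerprint
# matrix below is made up) ---
# compute_diversity() is a thin dispatcher: for most div_type values it simply
# forwards the feature matrix to the corresponding measure function.
import numpy as np

from selector.measures.diversity import compute_diversity, logdet, shannon_entropy

fps = np.array([[1, 0, 1, 1],
                [0, 1, 1, 0],
                [1, 1, 0, 0]])

# dispatching to "logdet" is equivalent to calling logdet() directly
assert np.isclose(compute_diversity(fps, div_type="logdet"), logdet(fps))
# the default div_type is "shannon_entropy" with normalize=False, truncation=False
assert np.isclose(compute_diversity(fps), shannon_entropy(fps, normalize=False))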
When :math:`p_i = 0` or :math:`p_i = 1`, the term :math:`SE_i` is zero. 190 | When `truncation` is True, the entropy is calculated as in [2]_ instead 191 | 192 | .. math:: 193 | H = \sum_i^n \left[ - p_i \log_2{p_i } \right] 194 | 195 | When `normalize` is True, the entropy is divided by a scaling factor so that it lies in the range 196 | [0, 1], [2]_ 197 | 198 | .. math:: 199 | H = \frac{ \sum_i^n \left[ - p_i \log_2{p_i } - (1 - p_i)\log_2(1 - p_i) \right]} 200 | {n \log_2{2} / 2} 201 | 202 | But please note that the combination of `truncation=True` and `normalize=True` has not been 203 | used in any literature. It is just a simple normalization of the entropy and the user can use it at their own risk. 204 | 205 | References 206 | ---------- 207 | .. [1] Wang, Y., Geppert, H., & Bajorath, J. (2009). Shannon entropy-based fingerprint similarity 208 | search strategy. Journal of Chemical Information and Modeling, 49(7), 1687-1691. 209 | .. [2] Leguy, J., Glavatskikh, M., Cauchy, T., & Da Mota, B. (2021). Scalable estimator of the 210 | diversity for de novo molecular generation resulting in a more robust QM dataset (OD9) and a 211 | more efficient molecular optimization. Journal of Cheminformatics, 13(1), 1-17. 212 | .. [3] Weidlich, I. E., & Filippov, I. V. (2016). Using the Gini coefficient to measure the 213 | chemical diversity of small molecule libraries. Journal of Computational Chemistry, 37(22), 2091-2097. 214 | 215 | """ 216 | # check if matrix is binary 217 | if np.count_nonzero((x != 0) & (x != 1)) != 0: 218 | raise ValueError("Attribute `x` should have binary values.") 219 | 220 | p_i_arr = np.sum(x, axis=0) / x.shape[0] 221 | h_x = 0 222 | 223 | for p_i in p_i_arr: 224 | if p_i == 0 or p_i == 1: 225 | # the contribution is zero when p_i is 0 or 1 226 | se_i = 0 227 | else: 228 | if truncation: 229 | # from https://jcheminf.biomedcentral.com/articles/10.1186/s13321-021-00554-8 230 | se_i = -p_i * np.log2(p_i) 231 | else: 232 | # from https://pubs.acs.org/doi/10.1021/ci900159f 233 | se_i = -p_i * np.log2(p_i) - (1 - p_i) * np.log2(1 - p_i) 234 | 235 | h_x += se_i 236 | 237 | if normalize: 238 | if truncation: 239 | warnings.warn( 240 | "Computing the normalized Shannon entropy only counting the on-bits has not been reported in " 241 | "literature. The user can use it at their own risk." 242 | ) 243 | 244 | h_x /= x.shape[1] * np.log2(2) / 2 245 | 246 | return h_x 247 | 248 | 249 | # todo: add tests for edi 250 | def explicit_diversity_index( 251 | x: np.ndarray, 252 | cs: int, 253 | ) -> float: 254 | """Compute the explicit diversity index. 255 | 256 | Parameters 257 | ---------- 258 | x: ndarray of shape (n_samples, n_features) 259 | Feature matrix of `n_samples` samples in `n_features` dimensional feature space. 260 | cs : int 261 | Number of common substructures in the compound set. 262 | 263 | Returns 264 | ------- 265 | edi_scaled : float 266 | Explicit diversity index. 267 | 268 | Notes 269 | ----- 270 | This method hasn't been tested. 271 | 272 | This method is used only for datasets of molecular compounds. 273 | 274 | Papp, Á., Gulyás-Forró, A., Gulyás, Z., Dormán, G., Ürge, L., 275 | and Darvas, F.. (2006) Explicit Diversity Index (EDI): 276 | A Novel Measure for Assessing the Diversity of Compound Databases. 277 | Journal of Chemical Information and Modeling 46, 1898-1904.
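# --- Illustrative worked example (not part of the library source) ---
# Hand-checking the Shannon entropy formula above on a tiny binary matrix.
import numpy as np

from selector.measures.diversity import shannon_entropy

x = np.array([[1, 0], [1, 1], [1, 0]])
# bit position 0: p = 3/3 = 1, so its contribution is zero
# bit position 1: p = 1/3, contributing -(1/3)log2(1/3) - (2/3)log2(2/3)
p = 1 / 3
expected = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
assert np.isclose(shannon_entropy(x, normalize=False), expected)
# with normalize=True (the default) the sum is divided by n * log2(2) / 2,
# which equals 1 for the n = 2 bit positions here
assert np.isclose(shannon_entropy(x), expected)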
278 | """ 279 | nc = len(x) 280 | sdi = (1 - nearest_average_tanimoto(x)) / (0.8047 - (0.065 * (np.log(nc)))) 281 | cr = -1 * np.log10(nc / (cs**2)) 282 | edi = (sdi + cr) * 0.7071067811865476 283 | edi_scaled = ((np.tanh(edi / 3) + 1) / 2) * 100 284 | return edi_scaled 285 | 286 | 287 | def wdud(x: np.ndarray) -> float: 288 | r"""Compute the Wasserstein Distance to Uniform Distribution(WDUD). 289 | 290 | The equation for the Wasserstein Distance for a single feature to uniform distribution is 291 | 292 | .. math:: 293 | WDUD(x) = \int_{0}^{1} |U(x) - V(x)|dx 294 | 295 | where the feature is normalized to [0, 1], :math:`U(x)=x` is the cumulative distribution 296 | of the uniform distribution on [0, 1], and :math:`V(x) = \sum_{y <= x}1 / N` is the discrete 297 | distribution of the values of the feature in :math:`x`, where :math:`y` is the ith feature. This 298 | integral is calculated iteratively between :math:`y_i` and :math:`y_{i+1}`, using trapezoidal method. 299 | 300 | Lower values of the WDUD mean more diversity because the features of the selected set are 301 | more evenly distributed over the range of feature values. 302 | 303 | Parameters 304 | ---------- 305 | x: ndarray of shape (n_samples, n_features) 306 | Feature matrix of `n_samples` samples in `n_features` dimensional feature space. 307 | 308 | Returns 309 | ------- 310 | float : 311 | The mean of the WDUD of each feature over all molecules. 312 | 313 | Notes 314 | ----- 315 | Nakamura, T., Sakaue, S., Fujii, K., Harabuchi, Y., Maeda, S., and Iwata, S.. (2022) 316 | Selecting molecules with diverse structures and properties by maximizing 317 | submodular functions of descriptors learned with graph neural networks. 318 | Scientific Reports 12. 319 | 320 | """ 321 | if x.ndim != 2: 322 | raise ValueError(f"The number of dimensions {x.ndim} should be two.") 323 | 324 | # find the range of each feature 325 | col_diff = np.ptp(x, axis=0) 326 | # Normalization of each feature to [0, 1] 327 | if np.any(np.abs(col_diff) < 1e-30): 328 | # warning if some feature columns are constant 329 | warnings.warn( 330 | "Some of the features are constant which will cause the normalization to fail. " 331 | "Now removing them." 332 | ) 333 | if np.all(col_diff < 1.0e-30): 334 | raise ValueError( 335 | "Unfortunately, all the features are constants and wdud cannot be calculated." 
336 | ) 337 | else: 338 | # remove the constant feature columns 339 | mask = np.ptp(x, axis=0) > 1e-30 340 | x = x[:, mask] 341 | x_norm = (x - np.min(x, axis=0)) / np.ptp(x, axis=0) 342 | 343 | # compute one Wasserstein distance per min-max normalized feature 344 | n_samples, n_features = x_norm.shape 345 | ans = []  # store the Wasserstein distance for each feature 346 | for i in range(0, n_features): 347 | wdu = 0.0 348 | y = np.sort(x_norm[:, i]) 349 | # Round to the sixth decimal place and count number of unique elements 350 | # to construct an accurate cumulative discrete distribution function \sum_{x <= y_{i + 1}} 1/k 351 | y, counts = np.unique(np.round(x_norm[:, i], decimals=6), return_counts=True) 352 | p = 0 353 | # skip the first value, where the empirical distribution is still zero 354 | for j in range(1, len(counts)): 355 | # integral from y_{i - 1} to y_{i} of |x - \sum_{x <= y_{i}} 1/k| dx 356 | yi1 = y[j - 1] 357 | yi = y[j] 358 | # Make a grid from yi1 to yi 359 | grid = np.linspace(yi1, yi, num=1000, endpoint=True) 360 | # Evaluate the integrand |x - \sum_{x <= y_{i + 1}} 1/k| 361 | p += counts[j - 1] 362 | integrand = np.abs(grid - p / n_samples) 363 | # Integrate using np.trapz 364 | wdu += np.trapz(y=integrand, x=grid) 365 | ans.append(wdu) 366 | return np.average(ans) 367 | 368 | 369 | def hypersphere_overlap_of_subset(x: np.ndarray, x_subset: np.array) -> float: 370 | r"""Compute the overlap of a subset with hyperspheres drawn around each point. 371 | 372 | The edge penalty is also included, which disregards areas 373 | outside of the boundary of the full feature space/library. 374 | This is calculated as: 375 | 376 | .. math:: 377 | g(S) = \sum_{i < j}^k O(i, j) + \sum^k_m E(m) 378 | 379 | where :math:`i, j` run over the subset of molecules, 380 | :math:`O(i, j)` is the approximate overlap between hyperspheres, 381 | :math:`k` is the number of molecules in the subset and :math:`E` 382 | is the edge penalty of a molecule. 383 | 384 | Lower values mean more diversity. 385 | 386 | Parameters 387 | ---------- 388 | x : ndarray 389 | Feature matrix of all molecules. 390 | x_subset : ndarray 391 | Feature matrix of selected subset of molecules. 392 | 393 | Returns 394 | ------- 395 | float : 396 | The approximate overlapping volume of hyperspheres 397 | drawn around the selected points/molecules. 398 | 399 | Notes 400 | ----- 401 | The hypersphere overlap volume is calculated using an approximation formula from Agrafiotis (1997). 402 | 403 | Agrafiotis, D. K.. (1997) Stochastic Algorithms for Maximizing Molecular Diversity. 404 | Journal of Chemical Information and Computer Sciences 37, 841-851. 405 | """ 406 | 407 | # Find the maximum and minimum over each feature across all molecules. 408 | max_x = np.max(x, axis=0) 409 | min_x = np.min(x, axis=0) 410 | 411 | if np.all(np.abs(max_x - min_x) < 1e-30): 412 | raise ValueError("All of the features are redundant which causes normalization to fail.") 413 | 414 | # Remove redundant features 415 | non_red_feat = np.abs(max_x - min_x) > 1e-30 416 | x = x[:, non_red_feat] 417 | x_subset = x_subset[:, non_red_feat] 418 | max_x = max_x[non_red_feat] 419 | min_x = min_x[non_red_feat] 420 | 421 | d = len(x_subset[0]) 422 | k = len(x_subset[:, 0]) 423 | 424 | # normalization of each feature to [0, 1] 425 | x_norm = (x_subset - min_x) / (max_x - min_x) 426 | 427 | # r_o = hypersphere radius 428 | r_o = d * np.sqrt(1 / k) 429 | if r_o > 0.5: 430 | warnings.warn( 431 | "The number of molecules should be much larger" " than the number of features."
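# --- Illustrative sketch (not part of the library source; the two toy selections
# below are made up to show the direction of the score) ---
# Lower WDUD means the selected points cover each feature's range more evenly.
import numpy as np

from selector.measures.diversity import wdud

even = np.linspace(0, 1, 11).reshape(-1, 1)                  # evenly spread values
clumped = np.array([[0.0], [0.01], [0.02], [0.98], [1.0]])   # values piled at the edges
assert wdud(even) < wdud(clumped)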
432 | ) 433 | g_s = 0 434 | edge = 0 435 | 436 | # lambda parameter controls edge penalty 437 | lam = (d - 1.0) / d 438 | # calculate overlap volume 439 | for i in range(0, (k - 1)): 440 | for j in range((i + 1), k): 441 | dist = np.linalg.norm(x_norm[i] - x_norm[j]) 442 | # Overlap penalty 443 | if dist <= (2 * r_o): 444 | with np.errstate(divide="ignore"): 445 | # min(100) ignores the inf case with divide by zero 446 | g_s += min(100, 2 * (r_o / dist) - 1) 447 | # Edge penalty: lambda * (1 - \sum^d_j e_{ij} / (d * r_o)) 448 | edge_pen = 0.0 449 | for j_dim in range(0, d): 450 | # calculate dist to closest boundary in jth coordinate, 451 | # with max value = 1, min value = 0 452 | dist_max = np.abs(1 - x_norm[i, j_dim]) 453 | dist_min = x_norm[i, j_dim] 454 | dist = min(dist_min, dist_max) 455 | # truncate distance at r_o 456 | if dist > r_o: 457 | dist = r_o 458 | edge_pen += dist 459 | edge_pen /= d * r_o 460 | edge_pen = lam * (1.0 - edge_pen) 461 | edge += edge_pen 462 | g_s += edge 463 | return g_s 464 | 465 | 466 | def gini_coefficient(x: np.ndarray): 467 | r""" 468 | Gini coefficient of bit-wise fingerprints of a database of molecules. 469 | 470 | Measures the chemical diversity of a database of molecules defined by 471 | the following formula: 472 | 473 | .. math:: 474 | G = \frac{2 \sum_{i=1}^L i ||y_i||_1 }{N \sum_{i=1}^L ||y_i||_1} - \frac{L+1}{L}, 475 | 476 | where :math:`y_i \in \{0, 1\}^N` is a vector of zeros and ones of length :math:`N` (the 477 | number of molecules) for the :math:`i`th feature, and :math:`L` is the feature length. 478 | 479 | Lower values mean more diversity. 480 | 481 | Parameters 482 | ---------- 483 | x : ndarray(N, L) 484 | Molecule features in L bits with N molecules. 485 | 486 | Returns 487 | ------- 488 | float : 489 | Gini coefficient in the range [0,1]. 490 | 491 | References 492 | ---------- 493 | Weidlich, Iwona E., and Igor V. Filippov. "Using the Gini coefficient to measure the 494 | chemical diversity of small-molecule libraries." Journal of Computational Chemistry 37.22 (2016): 2091-2097. 495 | 496 | """ 497 | # Check that `x` is a bit-wise fingerprint. 498 | if np.count_nonzero((x != 0) & (x != 1)) != 0: 499 | raise ValueError("Attribute `x` should have binary values.") 500 | if x.ndim != 2: 501 | raise ValueError(f"Attribute `x` should have dimension two rather than {x.ndim}.") 502 | 503 | num_features = x.shape[1] 504 | # Take the bit-count of each column (feature). 505 | bit_count = np.sum(x, axis=0) 506 | 507 | # Sort the bit-counts since the Gini coefficient relies on the cumulative distribution. 508 | bit_count = np.sort(bit_count) 509 | 510 | # Denominator and numerator of the Gini formula 511 | denominator = num_features * np.sum(bit_count) 512 | numerator = np.sum(np.arange(1, num_features + 1) * bit_count) 513 | 514 | return 2.0 * numerator / denominator - (num_features + 1) / num_features 515 | 516 | 517 | def nearest_average_tanimoto(x: np.ndarray) -> float: 518 | """Compute the average Tanimoto coefficient of nearest-neighbor molecules. 519 | 520 | Parameters 521 | ---------- 522 | x : ndarray 523 | Feature matrix. 524 | 525 | Returns 526 | ------- 527 | nat : float 528 | Average Tanimoto coefficient of closest pairs. 529 | 530 | Notes 531 | ----- 532 | This computes the Tanimoto coefficient of pairs with the shortest 533 | distances, then returns the average of them. 534 | This calculation is explicitly for the explicit diversity index. 535 | 536 | Papp, Á., Gulyás-Forró, A., Gulyás, Z., Dormán, G., Ürge, L., 537 | and Darvas, F.. (2006) Explicit Diversity Index (EDI): 538 | A Novel Measure for Assessing the Diversity of Compound Databases.
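# --- Illustrative worked example (not part of the library source) ---
# Two extreme fingerprint matrices bracket the Gini coefficient: spreading the
# on-bits evenly over all features gives G = 0, while concentrating them in a
# single feature pushes G toward 1.
import numpy as np

from selector.measures.diversity import gini_coefficient

balanced = np.ones((4, 5), dtype=int)   # every feature equally populated -> G = 0
skewed = np.zeros((4, 5), dtype=int)
skewed[:, 0] = 1                        # all on-bits in one feature -> G = 0.8 here
assert np.isclose(gini_coefficient(balanced), 0.0)
assert gini_coefficient(skewed) > gini_coefficient(balanced)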
539 | Journal of Chemical Information and Modeling 46, 1898-1904. 540 | """ 541 | tani = [] 542 | for idx, _ in enumerate(x): 543 | # arbitrary distance for comparison: 544 | short = 100 545 | a = 0 546 | b = 0 547 | # search for shortest distance point from idx 548 | for jdx, _ in enumerate(x): 549 | dist = np.linalg.norm(x[idx] - x[jdx]) 550 | if dist < short and idx != jdx: 551 | short = dist 552 | a = idx 553 | b = jdx 554 | # calculate tanimoto for each shortest dist pair 555 | tani.append(tanimoto(x[a], x[b])) 556 | # compute average of all shortest tanimoto coeffs 557 | nat = np.average(tani) 558 | return nat 559 | -------------------------------------------------------------------------------- /selector/measures/similarity.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see 22 | # 23 | # -- 24 | """Similarity Module.""" 25 | 26 | from itertools import combinations_with_replacement 27 | 28 | import numpy as np 29 | 30 | __all__ = [ 31 | "pairwise_similarity_bit", 32 | "tanimoto", 33 | "modified_tanimoto", 34 | "scaled_similarity_matrix", 35 | ] 36 | 37 | 38 | def pairwise_similarity_bit(X: np.array, metric: str) -> np.ndarray: 39 | """Compute pairwise similarity coefficient matrix. 40 | 41 | Parameters 42 | ---------- 43 | X : ndarray of shape (n_samples, n_features) 44 | Feature matrix of `n_samples` samples in `n_features` dimensional space. 45 | metric : str 46 | The metric used when calculating similarity coefficients between samples in a feature array. 47 | Method for calculating similarity coefficient. Options: `"tanimoto"`, `"modified_tanimoto"`. 48 | 49 | Returns 50 | ------- 51 | s : ndarray of shape (n_samples, n_samples) 52 | A symmetric similarity matrix between each pair of samples in the feature matrix. 53 | The diagonal elements are directly computed instead of assuming that they are 1. 54 | 55 | """ 56 | 57 | available_methods = { 58 | "tanimoto": tanimoto, 59 | "modified_tanimoto": modified_tanimoto, 60 | } 61 | if metric not in available_methods: 62 | raise ValueError( 63 | f"Argument metric={metric} is not recognized! 
Choose from {available_methods.keys()}" 64 | ) 65 | if X.ndim != 2: 66 | raise ValueError(f"Argument features should be a 2D array, got {X.ndim}") 67 | 68 | # make pairwise m-by-m similarity matrix 69 | n_samples = len(X) 70 | s = np.zeros((n_samples, n_samples)) 71 | # compute similarity between all pairs of points (including the diagonal elements) 72 | for i, j in combinations_with_replacement(range(n_samples), 2): 73 | s[i, j] = s[j, i] = available_methods[metric](X[i], X[j]) 74 | return s 75 | 76 | 77 | def tanimoto(a: np.array, b: np.array) -> float: 78 | r"""Compute Tanimoto coefficient or index (a.k.a. Jaccard similarity coefficient). 79 | 80 | For two binary or non-binary arrays :math:`A` and :math:`B`, Tanimoto coefficient 81 | is defined as the size of their intersection divided by the size of their union: 82 | 83 | .. math:: 84 | T(A, B) = \frac{| A \cap B|}{| A \cup B |} = 85 | \frac{| A \cap B|}{|A| + |B| - | A \cap B|} = 86 | \frac{A \cdot B}{\|A\|^2 + \|B\|^2 - A \cdot B} 87 | 88 | where :math:`A \cdot B = \sum_i{A_i B_i}` and :math:`\|A\|^2 = \sum_i{A_i^2}`. 89 | 90 | Parameters 91 | ---------- 92 | a : ndarray of shape (n_features,) 93 | The 1D feature array of sample :math:`A` in an `n_features` dimensional space. 94 | b : ndarray of shape (n_features,) 95 | The 1D feature array of sample :math:`B` in an `n_features` dimensional space. 96 | 97 | Returns 98 | ------- 99 | coeff : float 100 | Tanimoto coefficient between feature arrays :math:`A` and :math:`B`. 101 | 102 | Bajusz, D., Rácz, A., and Héberger, K.. (2015) 103 | Why is Tanimoto index an appropriate choice for fingerprint-based similarity calculations?. 104 | Journal of Cheminformatics 7. 105 | 106 | """ 107 | if a.ndim != 1 or b.ndim != 1: 108 | raise ValueError(f"Arguments a and b should be 1D arrays, got {a.ndim} and {b.ndim}") 109 | if a.shape != b.shape: 110 | raise ValueError( 111 | f"Arguments a and b should have the same shape, got {a.shape} != {b.shape}" 112 | ) 113 | coeff = sum(a * b) / (sum(a**2) + sum(b**2) - sum(a * b)) 114 | return coeff 115 | 116 | 117 | def modified_tanimoto(a: np.array, b: np.array) -> float: 118 | r"""Compute the modified tanimoto coefficient from bitstring vectors of data points A and B. 119 | 120 | Adjusts calculation of the Tanimoto coefficient to counter its natural bias towards 121 | shorter vectors using a Bernoulli probability model. 122 | 123 | .. math:: 124 | {mt} = \frac{2-p}{3} T_1 + \frac{1+p}{3} T_0 125 | 126 | where :math:`p` is success probability of independent trials, 127 | :math:`T_1` is the number of common '1' bits between data points 128 | (:math:`T_1 = | A \cap B |`), and :math:`T_0` is the number of common '0' 129 | bits between data points (:math:`T_0 = |(1-A) \cap (1-B)|`). 130 | 131 | 132 | Parameters 133 | ---------- 134 | a : ndarray of shape (n_features,) 135 | The 1D bitstring feature array of sample :math:`A` in an `n_features` dimensional space. 136 | b : ndarray of shape (n_features,) 137 | The 1D bitstring feature array of sample :math:`B` in an `n_features` dimensional space. 138 | 139 | Returns 140 | ------- 141 | mt : float 142 | Modified tanimoto coefficient between bitstring feature arrays :math:`A` and :math:`B`. 143 | 144 | Notes 145 | ----- 146 | The equation above has been derived from 147 | 148 | .. math:: 149 | {mt}_{\alpha} = {\alpha}T_1 + (1-\alpha)T_0 150 | 151 | where :math:`\alpha = \frac{2-p}{3}`. 
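# --- Illustrative usage sketch (not part of the library source; fingerprints are
# made up) ---
# Checking tanimoto() against the set formula |A ∩ B| / |A ∪ B| and building the
# pairwise similarity matrix with pairwise_similarity_bit().
import numpy as np

from selector.measures.similarity import pairwise_similarity_bit, tanimoto

a = np.array([1, 1, 0, 1, 0])
b = np.array([1, 0, 0, 1, 1])
# intersection has 2 on-bits, union has 4, so T = 0.5
assert np.isclose(tanimoto(a, b), 0.5)

X = np.vstack([a, b])
s = pairwise_similarity_bit(X, metric="tanimoto")
assert np.isclose(s[0, 1], tanimoto(a, b)) and np.isclose(s[0, 0], 1.0)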
This is done so that the expected value 152 | of the modified tanimoto, :math:`E(mt)`, remains constant even as the number of 153 | trials :math:`p` grows larger. 154 | 155 | Fligner, M. A., Verducci, J. S., and Blower, P. E.. (2002) 156 | A Modification of the Jaccard-Tanimoto Similarity Index for 157 | Diverse Selection of Chemical Compounds Using Binary Strings. 158 | Technometrics 44, 110-119. 159 | 160 | """ 161 | if a.ndim != 1: 162 | raise ValueError(f"Argument `a` should have dimension 1 rather than {a.ndim}.") 163 | if b.ndim != 1: 164 | raise ValueError(f"Argument `b` should have dimension 1 rather than {b.ndim}.") 165 | if a.shape != b.shape: 166 | raise ValueError( 167 | f"Arguments a and b should have the same shape, got {a.shape} != {b.shape}" 168 | ) 169 | 170 | n_features = len(a) 171 | # number of common '1' bits between points A and B 172 | n_11 = sum(a * b) 173 | # number of common '0' bits between points A and B 174 | n_00 = sum((1 - a) * (1 - b)) 175 | 176 | # calculate Tanimoto coefficient based on '0' bits 177 | t_1 = 1 178 | if n_00 != n_features: 179 | # bit strings are not all '0's 180 | t_1 = n_11 / (n_features - n_00) 181 | # calculate Tanimoto coefficient based on '1' bits 182 | t_0 = 1 183 | if n_11 != n_features: 184 | # bit strings are not all '1's 185 | t_0 = n_00 / (n_features - n_11) 186 | 187 | # combine into modified tanimoto using Bernoulli Model 188 | # p = independent success trials 189 | # evaluated as total number of '1' bits 190 | # divided by 2x the fingerprint length 191 | p = (n_features - n_00 + n_11) / (2 * n_features) 192 | # mt = x * T_1 + (1-x) * T_0 193 | # x = (2-p)/3 so that E(mt) = 1/3, no matter the value of p 194 | mt = (((2 - p) / 3) * t_1) + (((1 + p) / 3) * t_0) 195 | return mt 196 | 197 | 198 | def scaled_similarity_matrix(X: np.array) -> np.ndarray: 199 | r"""Compute the scaled similarity matrix. 200 | 201 | .. math:: 202 | X(i,j) = \frac{X(i,j)}{\sqrt{X(i,i)X(j,j)}} 203 | 204 | Parameters 205 | ---------- 206 | X : ndarray of shape (n_samples, n_samples) 207 | Similarity matrix of `n_samples`. 208 | 209 | Returns 210 | ------- 211 | s : ndarray of shape (n_samples, n_samples) 212 | A scaled symmetric similarity matrix. 
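# --- Illustrative worked example (not part of the library source; inputs are
# made up) ---
# Hand-checking modified_tanimoto() on short bitstrings, and the scaling rule
# s_ij / sqrt(s_ii * s_jj) implemented by scaled_similarity_matrix() below.
import numpy as np

from selector.measures.similarity import modified_tanimoto, scaled_similarity_matrix

a = np.array([1, 1, 0, 0])
b = np.array([1, 0, 1, 0])
# n_11 = 1, n_00 = 1, n = 4 -> t_1 = t_0 = 1/3, p = (4 - 1 + 1) / 8 = 1/2
# mt = (2 - p)/3 * t_1 + (1 + p)/3 * t_0 = 0.5/3 + 0.5/3 = 1/3
assert np.isclose(modified_tanimoto(a, b), 1 / 3)

X = np.array([[4.0, 2.0],
              [2.0, 9.0]])
s = scaled_similarity_matrix(X)
# off-diagonal: 2 / sqrt(4 * 9) = 1/3; diagonal entries become exactly 1
assert np.allclose(np.diag(s), 1.0) and np.isclose(s[0, 1], 1 / 3)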
213 | 214 | """ 215 | if X.ndim != 2: 216 | raise ValueError(f"Argument similarity matrix should be a 2D array, got {X.ndim}") 217 | 218 | if X.shape[0] != X.shape[1]: 219 | raise ValueError( 220 | f"Argument similarity matrix should be a square matrix (having same number of rows and columns), got {X.shape[0]} and {X.shape[1]}" 221 | ) 222 | 223 | if not (np.all(X >= 0) and np.all(np.diag(X) > 0)): 224 | raise ValueError( 225 | "All elements of the similarity matrix should be non-negative and all diagonal elements should be positive" 226 | ) 227 | 228 | # no scaling is needed if the matrix is a binary similarity matrix with all diagonal elements equal to 1 229 | if np.all(np.diag(X) == 1): 230 | print("No scaling is taking effect") 231 | return X 232 | else: 233 | # make a scaled similarity matrix 234 | n_samples = len(X) 235 | s = np.zeros((n_samples, n_samples)) 236 | # calculate the square root of the diagonal elements 237 | sqrt_diag = np.sqrt(np.diag(X)) 238 | # calculate the product of the square roots of the diagonal elements 239 | product_sqrt_diag = np.outer(sqrt_diag, sqrt_diag) 240 | # divide each element of the matrix by the product of the square roots of diagonal elements 241 | s = X / product_sqrt_diag 242 | return s 243 | 244 | 245 | def similarity_index(x: np.array, y: np.array, sim_index: str) -> float: 246 | """Compute the similarity index between two feature arrays. 247 | 248 | Parameters 249 | ---------- 250 | x : ndarray of shape (n_features,) 251 | Feature array of sample `x` in an `n_features` dimensional space 252 | y : ndarray of shape (n_features,) 253 | Feature array of sample `y` in an `n_features` dimensional space 254 | sim_index : str 255 | The key with the abbreviation of the similarity index to be used for calculations. 256 | Possible values are: 257 | - 'AC': Austin-Colwell 258 | - 'BUB': Baroni-Urbani-Buser 259 | - 'CTn': Consonni-Todeschini n (n=1,2) 260 | - 'Fai': Faith 261 | - 'Gle': Gleason 262 | - 'Ja': Jaccard 263 | - 'JT': Jaccard-Tanimoto 264 | - 'RT': Rogers-Tanimoto 265 | - 'RR': Russell-Rao 266 | - 'SM': Sokal-Michener 267 | - 'SSn': Sokal-Sneath n (n=1,2) 268 | 269 | 270 | Returns 271 | ------- 272 | sim : float 273 | The similarity index value between the feature arrays `x` and `y`. 274 | """ 275 | # Define the similarity index functions 276 | similarity_indices = { 277 | "AC": lambda a, d, dis, p: 2 / np.pi * np.arcsin(((a + d) / p) ** 0.5), 278 | "BUB": lambda a, d, dis, p: ((a * d) ** 0.5 + a) / ((a * d) ** 0.5 + a + dis), 279 | "CT1": lambda a, d, dis, p: np.log(1 + a + d) / np.log(1 + p), 280 | "CT2": lambda a, d, dis, p: (np.log(1 + p) - np.log(1 + dis)) / np.log(1 + p), 281 | "Fai": lambda a, d, dis, p: (a + 0.5 * d) / p, 282 | "Gle": lambda a, d, dis, p: 2 * a / (2 * a + dis), 283 | "Ja": lambda a, d, dis, p: 3 * a / (3 * a + dis), 284 | "JT": lambda a, d, dis, p: a / (a + dis), 285 | "RT": lambda a, d, dis, p: (a + d) / (p + dis), 286 | "RR": lambda a, d, dis, p: a / p, 287 | "SM": lambda a, d, dis, p: (a + d) / p, 288 | "SS1": lambda a, d, dis, p: a / (a + 2 * dis), 289 | "SS2": lambda a, d, dis, p: (2 * (a + d)) / (p + (a + d)), 290 | } 291 | 292 | if sim_index not in similarity_indices: 293 | raise ValueError( 294 | f"Argument sim_index={sim_index} is not recognized!
Choose from {similarity_indices.keys()}" 295 | ) 296 | if x.ndim != 1 or y.ndim != 1: 297 | raise ValueError(f"Arguments x and y should be 1D arrays, got {x.ndim} and {y.ndim}") 298 | if x.shape != y.shape: 299 | raise ValueError( 300 | f"Arguments x and y should have the same shape, got {x.shape} != {y.shape}" 301 | ) 302 | a, d, dis, p = _compute_base_descriptors(x, y) 303 | return similarity_indices[sim_index](a, d, dis, p) 304 | 305 | 306 | def _compute_base_descriptors(x, y): 307 | """Compute the base descriptors for the similarity indices. 308 | 309 | Parameters 310 | ---------- 311 | x : ndarray of shape (n_features,) 312 | Feature array of sample `x` in an `n_features` dimensional space 313 | y : ndarray of shape (n_features,) 314 | Feature array of sample `y` in an `n_features` dimensional space 315 | 316 | Returns 317 | ------- 318 | tuple(int, int, int, int) 319 | The number of common on bits, number of common off bits, number of 1-0 mismatches, and the 320 | length of the fingerprint. 321 | """ 322 | p = len(x) 323 | a = np.dot(x, y) 324 | d = np.dot(1 - x, 1 - y) 325 | dis = p - a - d 326 | return a, d, dis, p 327 | -------------------------------------------------------------------------------- /selector/measures/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see 22 | # 23 | # -- 24 | 25 | """Test Module.""" 26 | -------------------------------------------------------------------------------- /selector/measures/tests/common.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 
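# --- Illustrative check (not part of the library source; fingerprints are made up) ---
# For binary fingerprints the 'JT' (Jaccard-Tanimoto) index a / (a + dis) agrees
# with tanimoto(), and 'RR' is the fraction of common on-bits over the length p.
import numpy as np

from selector.measures.similarity import similarity_index, tanimoto

x = np.array([1, 1, 0, 1, 0])
y = np.array([1, 0, 0, 1, 1])
assert np.isclose(similarity_index(x, y, sim_index="JT"), tanimoto(x, y))
assert np.isclose(similarity_index(x, y, sim_index="RR"), 2 / 5)  # a = 2, p = 5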
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see 
22 | #
23 | # --
24 | """Common functions for test module."""
25 | 
26 | import numpy as np
27 | 
28 | try:
29 |     from importlib_resources import path
30 | except ImportError:
31 |     from importlib.resources import path
32 | 
33 | 
34 | def bit_cosine(a, b):
35 |     """Compute cosine coefficient for bit strings.
36 | 
37 |     Parameters
38 |     ----------
39 |     a : array_like
40 |         molecule A's features in bit string.
41 |     b : array_like
42 |         molecule B's features in bit string.
43 | 
44 |     Returns
45 |     -------
46 |     coeff : float
47 |         cosine coefficient for molecules A and B.
48 |     """
49 |     a_feat = np.count_nonzero(a)
50 |     b_feat = np.count_nonzero(b)
51 |     c = 0
52 |     for idx, _ in enumerate(a):
53 |         if a[idx] == b[idx] and a[idx] != 0:
54 |             c += 1
55 |     b_c = c / ((a_feat * b_feat) ** 0.5)
56 |     return b_c
57 | 
58 | 
59 | def bit_dice(a, b):
60 |     """Compute dice coefficient for bit strings.
61 | 
62 |     Parameters
63 |     ----------
64 |     a : array_like
65 |         molecule A's features.
66 |     b : array_like
67 |         molecule B's features.
68 | 
69 |     Returns
70 |     -------
71 |     coeff : float
72 |         dice coefficient for molecules A and B.
73 |     """
74 |     a_feat = np.count_nonzero(a)
75 |     b_feat = np.count_nonzero(b)
76 |     c = 0
77 |     for idx, _ in enumerate(a):
78 |         if a[idx] == b[idx] and a[idx] != 0:
79 |             c += 1
80 |     b_d = (2 * c) / (a_feat + b_feat)
81 |     return b_d
82 | 
83 | 
84 | def cosine(a, b):
85 |     """Compute cosine coefficient.
86 | 
87 |     Parameters
88 |     ----------
89 |     a : array_like
90 |         molecule A's features.
91 |     b : array_like
92 |         molecule B's features.
93 | 
94 |     Returns
95 |     -------
96 |     coeff : float
97 |         cosine coefficient for molecules A and B.
98 |     """
99 |     # cosine coefficient divides by the geometric mean of the two squared norms
100 |     coeff = (sum(a * b)) / (((sum(a**2)) * (sum(b**2))) ** 0.5)
101 |     return coeff
102 | 
103 | 
104 | def dice(a, b):
105 |     """Compute dice coefficient.
106 | 
107 |     Parameters
108 |     ----------
109 |     a : array_like
110 |         molecule A's features.
111 |     b : array_like
112 |         molecule B's features.
113 | 
114 |     Returns
115 |     -------
116 |     coeff : float
117 |         dice coefficient for molecules A and B.
118 |     """
119 |     coeff = (2 * (sum(a * b))) / ((sum(a**2)) + (sum(b**2)))
120 |     return coeff
--------------------------------------------------------------------------------
/selector/measures/tests/test_converter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see 
22 | #
23 | # --
24 | """Test Converter Module."""
25 | 
26 | import numpy as np
27 | from numpy.testing import assert_almost_equal, assert_equal, assert_raises
28 | 
29 | from selector.measures import converter as cv
30 | 
31 | # Tests for variations on input `x` for sim_to_dist()
32 | 
33 | 
34 | def test_sim_2_dist_float_int():
35 |     """Test similarity to distance input handling when input is a float or int."""
36 |     expected_1 = 0.25
37 |     int_out = cv.sim_to_dist(4, "reciprocal")
38 |     assert_equal(int_out, expected_1)
39 |     expected_2 = 2
40 |     float_out = cv.sim_to_dist(0.5, "reciprocal")
41 |     assert_equal(float_out, expected_2)
42 | 
43 | 
44 | def test_sim_2_dist_array_dimension_error():
45 |     """Test sim to dist function with incorrect input dimensions for `x`."""
46 |     assert_raises(ValueError, cv.sim_to_dist, np.ones([2, 2, 2]), "reciprocal")
47 | 
48 | 
49 | def test_sim_2_dist_1d_metric_error():
50 |     """Test sim to dist function with an invalid metric for 1D arrays."""
51 |     assert_raises(ValueError, cv.sim_to_dist, np.ones(5), "gravity")
52 |     assert_raises(ValueError, cv.sim_to_dist, np.ones(5), "co-occurrence")
53 | 
54 | 
55 | # Tests for variations on input `metric` for sim_to_dist()
56 | 
57 | 
58 | def test_sim_2_dist():
59 |     """Test similarity to distance method with specified metric."""
60 |     x = np.array([[1, 0.2, 0.5], [0.2, 1, 0.25], [0.5, 0.25, 1]])
61 |     expected = np.array([[0.20, 1, 0.70], [1, 0.20, 0.95], [0.70, 0.95, 0.20]])
62 |     actual = cv.sim_to_dist(x, "reverse")
63 |     assert_almost_equal(actual, expected, decimal=10)
64 | 
65 | 
66 | def test_sim_2_dist_frequency():
67 |     """Test similarity to distance method with a frequency metric."""
68 |     x = np.array([[4, 9, 1], [9, 1, 25], [1, 25, 16]])
69 |     expected = np.array([[(1 / 2), (1 / 3), 1], [(1 / 3), 1, (1 / 5)], [1, (1 / 5), (1 / 4)]])
70 |     actual = cv.sim_to_dist(x, "transition")
71 |     assert_almost_equal(actual, expected, decimal=10)
72 | 
73 | 
74 | def test_sim_2_dist_frequency_error():
75 |     """Test similarity to distance method with a frequency metric and incorrect input."""
76 |     # zeroes in the frequency matrix
77 |     x = np.array([[0, 9, 1], [9, 1, 25], [1, 25, 0]])
78 |     assert_raises(ValueError, cv.sim_to_dist, x, "gravity")
79 |     # negatives in the frequency matrix
80 |     y = np.array([[1, -9, 1], [9, 1, -25], [1, 25, 16]])
81 |     assert_raises(ValueError, cv.sim_to_dist, y, "gravity")
82 | 
83 | 
84 | def test_sim_2_dist_membership():
85 |     """Test similarity to distance method with the membership metric."""
86 |     # x = np.array([[(1 / 2), (1 / 5)], [(1 / 4), (1 / 3)]])
87 |     x = np.array([[1, 1 / 5, 1 / 3], [1 / 5, 1, 4 / 5], [1 / 3, 4 / 5, 1]])
88 |     expected = np.array([[0, 4 / 5, 2 / 3], [4 / 5, 0, 1 / 5], [2 / 3, 1 / 5, 0]])
89 |     actual = cv.sim_to_dist(x, "membership")
90 |     assert_almost_equal(actual, expected, decimal=10)
91 | 
92 | 
93 | def test_sim_2_dist_membership_error():
94 |     """Test similarity to distance method with the membership metric when there is an input error."""
95 |     x = np.array([[1, 0, -7], [0, 1, 3], [-7, 3, 1]])
96 |     assert_raises(ValueError, cv.sim_to_dist, x, "membership")
97 | 
98 | 
99 | def test_sim_2_dist_invalid_metric():
100 |     """Test similarity to distance method with an unsupported metric."""
101 |     assert_raises(ValueError, cv.sim_to_dist, np.ones(5), "testing")
102 | 
103 | 
104 | def test_sim_2_dist_non_symmetric():
105 |     """Test that a non-symmetric 2D matrix raises the invalid-matrix error."""
106 |     x = np.array([[1, 2], [4, 5]])
107 |     assert_raises(ValueError, cv.sim_to_dist, x, "reverse")
108 | 
109 | 
110 | # Tests for individual metrics
111 | 
112 | 
113 | def test_reverse():
114 |     """Test the reverse function for similarity to distance conversion."""
115 |     x = np.array([[3, 1, 1], [1, 3, 0], [1, 0, 3]])
116 |     expected = np.array([[0, 2, 2], [2, 0, 3], [2, 3, 0]])
117 |     actual = cv.reverse(x)
118 |     assert_equal(actual, expected)
119 | 
120 | 
121 | def test_reciprocal():
122 |     """Test the reciprocal function for similarity to distance conversion."""
123 |     x = np.array([[1, 0.25, 0.40], [0.25, 1, 0.625], [0.40, 0.625, 1]])
124 |     expected = np.array([[1, 4, 2.5], [4, 1, 1.6], [2.5, 1.6, 1]])
125 |     actual = cv.reciprocal(x)
126 |     assert_equal(actual, expected)
127 | 
128 | 
129 | def test_reciprocal_error():
130 |     """Test the reciprocal function with incorrect input values."""
131 |     # zero value for similarity (causes divide by zero issues)
132 |     x = np.array([[0, 4], [3, 2]])
133 |     assert_raises(ValueError, cv.reciprocal, x)
134 |     # negative value for similarity (distance cannot be negative)
135 |     y = np.array([[1, -4], [3, 2]])
136 |     assert_raises(ValueError, cv.reciprocal, y)
137 | 
138 | 
139 | def test_exponential():
140 |     """Test the exponential function for similarity to distance conversion."""
141 |     x = np.array([[1, 0.25, 0.40], [0.25, 1, 0.625], [0.40, 0.625, 1]])
142 |     expected = np.array(
143 |         [
144 |             [0, 1.38629436112, 0.91629073187],
145 |             [1.38629436112, 0, 0.47000362924],
146 |             [0.91629073187, 0.47000362924, 0],
147 |         ]
148 |     )
149 |     actual = cv.exponential(x)
150 |     assert_almost_equal(actual, expected, decimal=10)
151 | 
152 | 
153 | def test_exponential_error():
154 |     """Test the exponential function when max similarity is zero."""
155 |     x = np.zeros((4, 4))
156 |     assert_raises(ValueError, cv.exponential, x)
157 | 
158 | 
159 | def test_gaussian():
160 |     """Test the gaussian function for similarity to distance conversion."""
161 |     x = np.array([[1, 0.25, 0.40], [0.25, 1, 0.625], [0.40, 0.625, 1]])
162 |     expected = np.array(
163 |         [
164 |             [0, 1.17741002252, 0.95723076208],
165 |             [1.17741002252, 0, 0.68556810693],
166 |             [0.95723076208, 0.68556810693, 0],
167 |         ]
168 |     )
169 |     actual = cv.gaussian(x)
170 |     assert_almost_equal(actual, expected, decimal=10)
171 | 
172 | 
173 | def test_gaussian_error():
174 |     """Test the gaussian function when max similarity is zero."""
175 |     x = np.zeros((4, 4))
176 |     assert_raises(ValueError, cv.gaussian, x)
177 | 
178 | 
179 | def test_correlation():
180 |     """Test the correlation to distance conversion function."""
181 |     x = np.array([[1, 0.5, 0.2], [0.5, 1, -0.2], [0.2, -0.2, 1]])
182 |     # expected = sqrt(1-x)
183 |     expected = np.array(
184 |         [
185 |             [0, 0.70710678118, 0.894427191],
186 |             [0.70710678118, 0, 1.09544511501],
187 |             [0.894427191, 1.09544511501, 0],
188 |         ]
189 |     )
190 |     actual = cv.correlation(x)
191 |     assert_almost_equal(actual, expected, decimal=10)
192 | 
193 | 
194 | def test_correlation_error():
195 |     """Test the correlation function with an out of bounds array."""
196 |     x = np.array([[1, 0, -7], [0, 1, 3], [-7, 3, 1]])
197 |     assert_raises(ValueError, cv.correlation, x)
198 | 
199 | 
200 | def test_transition():
201 |     """Test the transition function for frequency to distance conversion."""
202 |     x = np.array([[4, 9, 1], [9, 1, 25], [1, 25, 16]])
203 |     expected = np.array([[(1 / 2), (1 / 3), 1], [(1 / 3), 1, (1 / 5)], [1, (1 / 5), (1 / 4)]])
204 | 
205 |     actual = cv.transition(x)
206 |     assert_almost_equal(actual, expected, decimal=10)
207 | 
208 | 
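# NOTE: the expected values in test_transition above follow from the
# element-wise conversion d_ij = 1 / sqrt(f_ij) of the frequency matrix,
# e.g. f = 4 -> d = 1/2, f = 9 -> d = 1/3, f = 25 -> d = 1/5. A minimal
# numpy-only sketch of the same relation:
#
#     f = np.array([4.0, 9.0, 25.0])
#     assert np.allclose(1.0 / np.sqrt(f), [0.5, 1.0 / 3.0, 0.2])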
209 | def test_co_occurrence():
210 |     """Test the co-occurrence conversion function."""
211 |     x = np.array([[1, 2, 3], [2, 1, 3], [3, 3, 1]])
212 |     expected = np.array(
213 |         [
214 |             [1 / (19 / 121 + 1), 1 / (38 / 121 + 1), 1 / (57 / 121 + 1)],
215 |             [1 / (38 / 121 + 1), 1 / (19 / 121 + 1), 1 / (57 / 121 + 1)],
216 |             [1 / (57 / 121 + 1), 1 / (57 / 121 + 1), 1 / (19 / 121 + 1)],
217 |         ]
218 |     )
219 |     actual = cv.co_occurrence(x)
220 |     assert_almost_equal(actual, expected, decimal=10)
221 | 
222 | 
223 | def test_gravity():
224 |     """Test the gravity conversion function."""
225 |     x = np.array([[1, 2, 3], [2, 1, 3], [3, 3, 1]])
226 |     expected = np.array(
227 |         [
228 |             [2.5235730726, 1.7844356324, 1.45698559277],
229 |             [1.7844356324, 2.5235730726, 1.45698559277],
230 |             [1.45698559277, 1.45698559277, 2.5235730726],
231 |         ]
232 |     )
233 |     actual = cv.gravity(x)
234 |     assert_almost_equal(actual, expected, decimal=10)
235 | 
236 | 
237 | def test_probability():
238 |     """Test the probability to distance conversion function."""
239 |     x = np.array([[0.3, 0.7], [0.5, 0.5]])
240 |     expected = np.array([[1.8116279322, 1.1356324735], [1.3819765979, 1.3819765979]])
241 |     actual = cv.probability(x)
242 |     assert_almost_equal(actual, expected, decimal=10)
243 | 
244 | 
245 | def test_probability_error():
246 |     """Test the probability function with out-of-bounds input values."""
247 |     # negative value for probability
248 |     x = np.array([[-0.5]])
249 |     assert_raises(ValueError, cv.probability, x)
250 |     # too large value for probability
251 |     y = np.array([[3]])
252 |     assert_raises(ValueError, cv.probability, y)
253 |     # zero value for probability (causes divide by zero issues)
254 |     z = np.array([[0]])
255 |     assert_raises(ValueError, cv.probability, z)
256 | 
257 | 
258 | def test_covariance():
259 |     """Test the covariance to distance conversion function."""
260 |     x = np.array([[4, -4], [-4, 6]])
261 |     expected = np.array([[0, 4.24264068712], [4.24264068712, 0]])
262 |     actual = cv.covariance(x)
263 |     assert_almost_equal(actual, expected, decimal=10)
264 | 
265 | 
266 | def test_covariance_error():
267 |     """Test the covariance function when input contains a negative variance."""
268 |     x = np.array([[-4, 4], [4, 6]])
269 |     assert_raises(ValueError, cv.covariance, x)
270 | 
--------------------------------------------------------------------------------
/selector/measures/tests/test_diversity.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see 22 | # 23 | # -- 24 | 25 | """Test Diversity Module.""" 26 | import warnings 27 | 28 | import numpy as np 29 | import pytest 30 | from numpy.testing import assert_almost_equal, assert_equal, assert_raises, assert_warns 31 | 32 | from selector.measures.diversity import ( 33 | compute_diversity, 34 | explicit_diversity_index, 35 | gini_coefficient, 36 | hypersphere_overlap_of_subset, 37 | logdet, 38 | nearest_average_tanimoto, 39 | shannon_entropy, 40 | wdud, 41 | ) 42 | 43 | # each row is a feature and each column is a molecule 44 | sample1 = np.array([[4, 2, 6], [4, 9, 6], [2, 5, 0], [2, 0, 9], [5, 3, 0]]) 45 | 46 | # each row is a molecule and each column is a feature (scipy) 47 | sample2 = np.array([[1, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]]) 48 | 49 | sample3 = np.array([[1, 4], [3, 2]]) 50 | 51 | sample4 = np.array([[1, 0, 1], [0, 1, 1]]) 52 | 53 | sample5 = np.array([[0, 2, 4, 0], [1, 2, 4, 0], [2, 2, 4, 0]]) 54 | 55 | sample6 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 0]]) 56 | 57 | sample7 = np.array([[1, 0, 1, 0] for _ in range(4)]) 58 | 59 | sample8 = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) 60 | 61 | 62 | def test_compute_diversity_specified(): 63 | """Test compute diversity with a specified div_type.""" 64 | comp_div = compute_diversity(sample6, "shannon_entropy", normalize=False, truncation=False) 65 | expected = 1.81 66 | assert round(comp_div, 2) == expected 67 | 68 | 69 | def test_compute_diversity_hyperspheres(): 70 | """Test compute diversity with two arguments for hypersphere_overlap method""" 71 | corner_pts = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]) 72 | centers_pts = np.array([[0.5, 0.5]] * (100 - 4)) 73 | pts = np.vstack((corner_pts, centers_pts)) 74 | 75 | comp_div = compute_diversity(pts, div_type="hypersphere_overlap", features=pts) 76 | # Expected = overlap + edge penalty 77 | expected = (100.0 * 96 * 95 * 0.5) + 2.0 78 | assert_almost_equal(comp_div, expected) 79 | 80 | 81 | def test_compute_diversity_hypersphere_error(): 82 | """Test compute diversity with hypersphere metric and no molecule library given.""" 83 | assert_raises(ValueError, compute_diversity, sample5, "hypersphere_overlap") 84 | 85 | 86 | def test_compute_diversity_edi(): 87 | """Test compute diversity with explicit diversity index div_type""" 88 | z = np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1]]) 89 | cs = 1 90 | expected = 56.39551204 91 | actual = compute_diversity(z, "explicit_diversity_index", cs=cs) 92 | assert_almost_equal(expected, actual) 93 | 94 | 95 | def test_compute_diversity_edi_no_cs_error(): 96 | """Test compute diversity with explicit diversity index and no `cs` value given.""" 97 | assert_raises(ValueError, compute_diversity, sample5, "explicit_diversity_index") 98 | 99 | 100 | def test_compute_diversity_edi_zero_error(): 101 | """Test compute diversity with explicit diversity index and `cs` = 0.""" 102 | assert_raises(ValueError, compute_diversity, sample5, "explicit diversity index", cs=0) 103 | 104 | 105 | def test_compute_diversity_invalid(): 106 | """Test compute diversity with a non-supported div_type.""" 107 | assert_raises(ValueError, compute_diversity, sample1, "diversity_type") 108 | 109 | 110 | def test_logdet(): 111 | """Test the log determinant function with predefined subset matrix.""" 112 | sel = logdet(sample3) 113 | expected = np.log(131) 114 | assert_almost_equal(sel, 
expected) 115 | 116 | 117 | def test_logdet_non_square_matrix(): 118 | """Test the log determinant function with a rectangular matrix.""" 119 | sel = logdet(sample4) 120 | expected = np.log(8) 121 | assert_almost_equal(sel, expected) 122 | 123 | 124 | def test_shannon_entropy(): 125 | """Test the shannon entropy function with example from the original paper.""" 126 | 127 | # example taken from figure 1 of 10.1021/ci900159f 128 | x1 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 0]]) 129 | expected = 1.81 130 | assert round(shannon_entropy(x1, normalize=False, truncation=False), 2) == expected 131 | 132 | x2 = np.vstack((x1, [1, 1, 1, 0])) 133 | expected = 1.94 134 | assert round(shannon_entropy(x2, normalize=False, truncation=False), 2) == expected 135 | 136 | x3 = np.vstack((x1, [0, 1, 0, 1])) 137 | expected = 3.39 138 | assert round(shannon_entropy(x3, normalize=False, truncation=False), 2) == expected 139 | 140 | 141 | def test_shannon_entropy_normalize(): 142 | """Test the shannon entropy function with normalization.""" 143 | x1 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 0]]) 144 | expected = 1.81 / (x1.shape[1] * np.log2(2) / 2) 145 | assert_almost_equal( 146 | actual=shannon_entropy(x1, normalize=True, truncation=False), 147 | desired=expected, 148 | decimal=2, 149 | ) 150 | 151 | 152 | def test_shannon_entropy_warning(): 153 | """Test the shannon entropy function gives warning when normalization is True and truncation is True.""" 154 | x1 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 0]]) 155 | with pytest.warns(UserWarning): 156 | shannon_entropy(x1, normalize=True, truncation=True) 157 | 158 | 159 | def test_shannon_entropy_binary_error(): 160 | """Test the shannon entropy function raises error with a non binary matrix.""" 161 | assert_raises(ValueError, shannon_entropy, sample5) 162 | 163 | 164 | def test_explicit_diversity_index(): 165 | """Test the explicit diversity index function.""" 166 | z = np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1]]) 167 | cs = 1 168 | nc = 3 169 | sdi = 0.75 / 0.7332902012 170 | cr = -1 * 0.4771212547 171 | edi = 0.5456661753 * 0.7071067811865476 172 | edi_scaled = 56.395512045413 173 | value = explicit_diversity_index(z, cs) 174 | assert_almost_equal(value, edi_scaled, decimal=8) 175 | 176 | 177 | def test_wdud_uniform(): 178 | """Test wdud when a feature has uniform distribution.""" 179 | uni = np.arange(0, 50000)[:, None] 180 | wdud_val = wdud(uni) 181 | expected = 0 182 | assert_almost_equal(wdud_val, expected, decimal=4) 183 | 184 | 185 | def test_wdud_repeat_yi(): 186 | """Test wdud when a feature has multiple identical values.""" 187 | dist = np.array([[0, 0.5, 0.5, 0.75, 1]]).T 188 | wdud_val = wdud(dist) 189 | # calculated using wolfram alpha: 190 | expected = 0.065 + 0.01625 + 0.02125 191 | assert_almost_equal(wdud_val, expected, decimal=4) 192 | 193 | 194 | def test_wdud_mult_features(): 195 | """Test wdud when there are multiple features per molecule.""" 196 | dist = np.array( 197 | [ 198 | [0, 0.5, 0.5, 0.75, 1], 199 | [0, 0.5, 0.5, 0.75, 1], 200 | [0, 0.5, 0.5, 0.75, 1], 201 | [0, 0.5, 0.5, 0.75, 1], 202 | ] 203 | ).T 204 | wdud_val = wdud(dist) 205 | # calculated using wolfram alpha: 206 | expected = 0.065 + 0.01625 + 0.02125 207 | assert_almost_equal(wdud_val, expected, decimal=4) 208 | 209 | 210 | def test_wdud_dimension_error(): 211 | """Test wdud method raises error when input has incorrect dimensions.""" 212 | arr = np.zeros((2, 2, 2)) 213 | assert_raises(ValueError, wdud, arr) 
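# NOTE: wdud stands for the Wasserstein Distance to the Uniform Distribution:
# per feature, the values are min-max normalized to [0, 1] and their empirical
# distribution is compared against the uniform distribution, so a perfectly
# uniform feature scores 0 (see test_wdud_uniform above). A rough sketch of
# that uniform baseline, assuming scipy is available (the sketch, not this
# test file, uses scipy.stats.wasserstein_distance):
#
#     from scipy.stats import wasserstein_distance
#     u = np.linspace(0.0, 1.0, 50000)
#     assert wasserstein_distance(u, u) == 0.0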
214 | 215 | 216 | def test_wdud_normalization_error(): 217 | """Test wdud method raises error when normalization fails.""" 218 | assert_raises(ValueError, wdud, sample8) 219 | 220 | 221 | def test_wdud_warning_normalization(): 222 | """Test wdud method gives warning when normalization fails.""" 223 | warning_message = ( 224 | "Some of the features are constant which will cause the normalization to fail. " 225 | + "Now removing them." 226 | ) 227 | with pytest.warns() as record: 228 | wdud(sample6) 229 | 230 | # check that the message matches 231 | assert record[0].message.args[0] == warning_message 232 | 233 | 234 | def test_hypersphere_overlap_of_subset_with_only_corners_and_center(): 235 | """Test the hypersphere overlap method with predefined matrix.""" 236 | corner_pts = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]) 237 | # Many duplicate pts cause r_0 to be much smaller than 1.0, 238 | # which is required due to normalization of the feature space 239 | centers_pts = np.array([[0.5, 0.5]] * (100 - 4)) 240 | pts = np.vstack((corner_pts, centers_pts)) 241 | 242 | # Overlap should be all coming from the centers 243 | expected_overlap = 100.0 * 96 * 95 * 0.5 244 | # The edge penalty should all be from the corner pts 245 | lam = 1.0 / 2.0 # Default lambda chosen from paper. 246 | expected_edge = lam * 4.0 247 | expected = expected_overlap + expected_edge 248 | true = hypersphere_overlap_of_subset(pts, pts) 249 | assert_almost_equal(true, expected) 250 | 251 | 252 | def test_hypersphere_normalization_error(): 253 | """Test the hypersphere overlap method raises error when normalization fails.""" 254 | assert_raises(ValueError, hypersphere_overlap_of_subset, sample7, sample7) 255 | 256 | 257 | def test_hypersphere_radius_warning(): 258 | """Test the hypersphere overlap method gives warning when radius is too large.""" 259 | corner_pts = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]) 260 | assert_warns(Warning, hypersphere_overlap_of_subset, corner_pts, corner_pts) 261 | 262 | 263 | def test_gini_coefficient_of_non_diverse_set(): 264 | """Test Gini coefficient of the least diverse set. Expected return is zero.""" 265 | # Finger-prints where columns are all the same 266 | numb_molecules = 5 267 | numb_features = 10 268 | # Transpose so that the columns are all the same, note first made the rows all same 269 | single_fingerprint = list(np.random.choice([0, 1], size=(numb_features,))) 270 | finger_prints = np.array([single_fingerprint] * numb_molecules).T 271 | 272 | result = gini_coefficient(finger_prints) 273 | # Since they are all the same, then gini coefficient should be zero. 274 | assert_almost_equal(result, 0.0, decimal=8) 275 | 276 | 277 | def test_gini_coefficient_non_binary_error(): 278 | """Test Gini coefficient error when input is not binary.""" 279 | assert_raises(ValueError, gini_coefficient, np.array([[7, 0], [2, 1]])) 280 | 281 | 282 | def test_gini_coefficient_dimension_error(): 283 | """Test Gini coefficient error when input has incorrect dimensions.""" 284 | assert_raises(ValueError, gini_coefficient, np.array([1, 0, 0, 0])) 285 | 286 | 287 | def test_gini_coefficient_of_most_diverse_set(): 288 | """Test Gini coefficient of the most diverse set.""" 289 | # Finger-prints where one feature has more `wealth` than all others. 290 | # Note: Transpose is done so one column has all ones. 
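# For this construction the Lorenz curve stays flat until the single
# `wealthy` column is reached, and the Gini coefficient G = 1 - 2 * (area
# under the Lorenz curve) therefore tends to its maximal value of 1 as the
# number of all-zero features grows.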
291 |     finger_prints = np.array(
292 |         [
293 |             [1, 1, 1, 1, 1, 1, 1],
294 |         ]
295 |         + [[0, 0, 0, 0, 0, 0, 0]] * 100000
296 |     ).T
297 |     result = gini_coefficient(finger_prints)
298 |     # One column holds all of the `wealth`, so the Gini coefficient should approach one.
299 |     assert_almost_equal(result, 1.0, decimal=4)
300 | 
301 | 
302 | def test_gini_coefficient_with_alternative_definition():
303 |     """Test Gini coefficient with alternative definition."""
304 |     # Finger-prints where they are all different
305 |     numb_features = 4
306 |     finger_prints = np.array([[1, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1]])
307 |     result = gini_coefficient(finger_prints)
308 | 
309 |     # Alternative definition from wikipedia
310 |     b = numb_features + 1
311 |     desired = (
312 |         numb_features + 1 - 2 * ((b - 1) + (b - 2) * 2 + (b - 3) * 3 + (b - 4) * 4) / (10)
313 |     ) / numb_features
314 |     assert_almost_equal(result, desired)
315 | 
316 | 
317 | def test_nearest_average_tanimoto_bit():
318 |     """Test the nearest_average_tanimoto function with binary input."""
319 |     nat = nearest_average_tanimoto(sample2)
320 |     shortest_tani = [0.3333333, 0.3333333, 0, 0]
321 |     average = np.average(shortest_tani)
322 |     assert_almost_equal(nat, average)
323 | 
324 | 
325 | def test_nearest_average_tanimoto():
326 |     """Test the nearest_average_tanimoto function with non-binary input."""
327 |     nat = nearest_average_tanimoto(sample3)
328 |     shortest_tani = [(11 / 19), (11 / 19)]
329 |     average = np.average(shortest_tani)
330 |     assert_equal(nat, average)
331 | 
332 | 
333 | def test_nearest_average_tanimoto_3_x_3():
334 |     """Test the nearest_average_tanimoto function with a 3x3 matrix."""
335 |     # all unequal distances b/w points
336 |     x = np.array([[0, 1, 2], [3, 4, 5], [4, 5, 6]])
337 |     nat_x = nearest_average_tanimoto(x)
338 |     avg_x = 0.749718574108818
339 |     assert_equal(nat_x, avg_x)
340 |     # one point equidistant from the other two
341 |     y = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
342 |     nat_y = nearest_average_tanimoto(y)
343 |     avg_y = 0.4813295920569825
344 |     assert_equal(nat_y, avg_y)
345 |     # all points equidistant
346 |     z = np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1]])
347 |     nat_z = nearest_average_tanimoto(z)
348 |     avg_z = 0.25
349 |     assert_equal(nat_z, avg_z)
350 | 
351 | 
352 | def test_nearest_average_tanimoto_nonsquare():
353 |     """Test the nearest_average_tanimoto function with a non-square input matrix."""
354 |     x = np.array([[3.5, 4.0, 10.5, 0.5], [1.25, 4.0, 7.0, 0.1], [0.0, 0.0, 0.0, 0.0]])
355 |     # nearest neighbor of sample 0, 1, and 2 are sample 1, 0, and 1, respectively.
356 |     expected = np.average(
357 |         [
358 |             np.sum(x[0] * x[1]) / (np.sum(x[0] ** 2) + np.sum(x[1] ** 2) - np.sum(x[0] * x[1])),
359 |             np.sum(x[1] * x[0]) / (np.sum(x[1] ** 2) + np.sum(x[0] ** 2) - np.sum(x[1] * x[0])),
360 |             np.sum(x[2] * x[1]) / (np.sum(x[2] ** 2) + np.sum(x[1] ** 2) - np.sum(x[2] * x[1])),
361 |         ]
362 |     )
363 |     assert_equal(nearest_average_tanimoto(x), expected)
--------------------------------------------------------------------------------
/selector/methods/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see 
22 | #
23 | # --
24 | 
25 | from selector.methods.distance import *
26 | from selector.methods.partition import *
27 | from selector.methods.similarity import *
--------------------------------------------------------------------------------
/selector/methods/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see 
22 | #
23 | # --
24 | """Base class for diversity based subset selection."""
25 | 
26 | import warnings
27 | from abc import ABC, abstractmethod
28 | from typing import List, Iterable, Union
29 | 
30 | import numpy as np
31 | 
32 | __all__ = ["SelectionBase"]
33 | 
34 | 
35 | class SelectionBase(ABC):
36 |     """Base class for selecting subset of sample points."""
37 | 
38 |     def select(
39 |         self,
40 |         x: np.ndarray,
41 |         size: int,
42 |         labels: np.ndarray = None,
43 |         proportional_selection: bool = True,
44 |     ) -> Union[List, Iterable]:
45 |         """Return indices representing subset of sample points.
46 | 
47 |         Parameters
48 |         ----------
49 |         x: ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
50 |             Feature matrix of `n_samples` samples in `n_features` dimensional feature space.
51 |             If `fun_dist` is `None`, this x is treated as a square pairwise distance matrix.
52 |         size: int
53 |             Number of sample points to select (i.e. size of the subset).
54 |         labels: np.ndarray, optional
55 |             Array of integers or strings representing the labels of the clusters that
56 |             each sample belongs to. If `None`, the samples are treated as one cluster.
57 |             If labels are provided, selection is made from each cluster.
58 |         proportional_selection: bool, optional
59 |             If True, the number of samples to be selected from each cluster is proportional.
60 |             Otherwise, the number of samples to be selected from each cluster is equal.
61 |             Default is True.
62 | 
63 |         Returns
64 |         -------
65 |         selected: list
66 |             Indices of the selected sample points.
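        Examples
        --------
        A minimal sketch, assuming the concrete `MaxMin` subclass from
        `selector.methods.distance` and a user-supplied distance function
        (any subclass implementing `select_from_cluster` is used the same way):

        >>> import numpy as np
        >>> from sklearn.metrics import pairwise_distances
        >>> from selector.methods.distance import MaxMin
        >>> points = np.random.default_rng(42).random((20, 3))
        >>> collector = MaxMin(fun_dist=lambda a: pairwise_distances(a, metric="euclidean"))
        >>> len(collector.select(points, size=5))
        5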
67 | """ 68 | # check size 69 | if size > len(x): 70 | raise ValueError( 71 | f"Size of subset {size} cannot be larger than number of samples {len(x)}." 72 | ) 73 | 74 | # if labels are not provided, indices selected from one cluster is returned 75 | if labels is None: 76 | return self.select_from_cluster(x, size) 77 | 78 | # check labels are consistent with number of samples 79 | if len(labels) != len(x): 80 | raise ValueError( 81 | f"Number of labels {len(labels)} does not match number of samples {len(x)}." 82 | ) 83 | 84 | selected_ids = [] 85 | 86 | # compute the number of samples (i.e. population or pop) in each cluster 87 | unique_labels, unique_label_counts = np.unique(labels, return_counts=True) 88 | num_clusters = len(unique_labels) 89 | pop_clusters = dict(zip(unique_labels, unique_label_counts)) 90 | # compute number of samples to be selected from each cluster 91 | if proportional_selection: 92 | # make sure that tht total number of samples selected is equal to size 93 | size_each_cluster = size * unique_label_counts / len(labels) 94 | # using np.round to get to the nearest integer 95 | # not using int function directly to avoid truncation of decimal values 96 | size_each_cluster = np.round(size_each_cluster).astype(int) 97 | # make sure each cluster has at least one sample 98 | size_each_cluster[size_each_cluster < 1] = 1 99 | 100 | # the total number of samples selected from all clusters at this point 101 | size_each_cluster_total = np.sum(size_each_cluster) 102 | # when the total of data points in each class is less than the required number 103 | # add one sample to the smallest cluster iteratively until the total is equal to the 104 | # required number 105 | if size_each_cluster_total < size: 106 | while size_each_cluster_total < size: 107 | # the number of remaining data points in each cluster 108 | size_each_cluster_remaining = unique_label_counts - size_each_cluster_total 109 | # skip the clusters with no data points left 110 | size_each_cluster_remaining[size_each_cluster_remaining == 0] = np.inf 111 | smallest_cluster_index = np.argmin(size_each_cluster_remaining) 112 | size_each_cluster[smallest_cluster_index] += 1 113 | size_each_cluster_total += 1 114 | # when the total of data points in each class is more than the required number 115 | # we need to remove samples from the largest clusters 116 | elif size_each_cluster_total > size: 117 | while size_each_cluster_total > size: 118 | largest_cluster_index = np.argmax(size_each_cluster) 119 | size_each_cluster[largest_cluster_index] -= 1 120 | size_each_cluster_total -= 1 121 | # perfect case where the total is equal to the required number 122 | else: 123 | pass 124 | else: 125 | size_each_cluster = size // num_clusters 126 | 127 | # update number of samples to select from each cluster based on the cluster population. 
128 |         # this is needed when some clusters do not have enough samples in them
129 |         # (pop < size_each_cluster) and needs to be done iteratively until all remaining clusters
130 |         # have at least size_each_cluster samples
131 |         while np.any(
132 |             [value <= size_each_cluster for value in pop_clusters.values() if value != 0]
133 |         ):
134 |             for unique_label in unique_labels:
135 |                 if pop_clusters[unique_label] != 0:
136 |                     # get index of sample labelled with unique_label
137 |                     cluster_ids = np.where(labels == unique_label)[0]
138 |                     if len(cluster_ids) <= size_each_cluster:
139 |                         # all samples in the cluster are selected & population becomes zero
140 |                         selected_ids.append(cluster_ids)
141 |                         pop_clusters[unique_label] = 0
142 |                         # update number of samples to be selected from each cluster
143 |                         totally_used_clusters = list(pop_clusters.values()).count(0)
144 |                         size_each_cluster = (size - len(np.hstack(selected_ids))) // (
145 |                             num_clusters - totally_used_clusters
146 |                         )
147 | 
148 |                         warnings.warn(
149 |                             f"Number of molecules in one cluster is less than"
150 |                             f" {size}/{num_clusters}.\nNumber of selected "
151 |                             f"molecules might be less than desired.\nIn order to avoid this "
152 |                             f"problem, try to use a smaller number of clusters."
153 |                         )
154 |         # save the number of samples to be selected from each cluster in an array
155 |         size_each_cluster = np.full(num_clusters, size_each_cluster)
156 | 
157 |         for unique_label, size_sub in zip(unique_labels, size_each_cluster):
158 |             if pop_clusters[unique_label] != 0:
159 |                 # sample size_each_cluster ids from cluster labeled unique_label
160 |                 cluster_ids = np.where(labels == unique_label)[0]
161 |                 selected = self.select_from_cluster(x, size_sub, cluster_ids)
162 |                 selected_ids.append(cluster_ids[selected])
163 | 
164 |         return np.hstack(selected_ids).flatten().tolist()
165 | 
166 |     @abstractmethod
167 |     def select_from_cluster(
168 |         self, x: np.ndarray, size: int, labels: np.ndarray = None
169 |     ) -> np.ndarray:  # pragma: no cover
170 |         """Return indices representing subset of sample points from one cluster.
171 | 
172 |         Parameters
173 |         ----------
174 |         x: ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
175 |             Feature matrix of `n_samples` samples in `n_features` dimensional feature space.
176 |             If `fun_dist` is `None`, this x is treated as a square pairwise distance matrix.
177 |         size: int
178 |             Number of sample points to select (i.e. size of the subset).
179 |         labels: np.ndarray, optional
180 |             Array of indices for the samples that form the cluster of interest
181 |             (as passed by `select`). If `None`, all samples in `x` are treated
182 |             as one cluster.
183 | 
184 |         Returns
185 |         -------
186 |         selected: list
187 |             Indices of the selected sample points.
188 |         """
189 |         raise NotImplementedError
--------------------------------------------------------------------------------
/selector/methods/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see 
22 | #
23 | # --
24 | 
--------------------------------------------------------------------------------
/selector/methods/tests/common.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # The Selector is a Python library of algorithms for selecting diverse
4 | # subsets of data for machine-learning.
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see 
22 | #
23 | # --
24 | """Common functions for test module."""
25 | 
26 | from importlib import resources
27 | from typing import Any, Tuple, Union
28 | 
29 | import numpy as np
30 | from sklearn.datasets import make_blobs
31 | from sklearn.metrics import pairwise_distances
32 | 
33 | __all__ = [
34 |     "generate_synthetic_cluster_data",
35 |     "generate_synthetic_data",
36 |     "get_data_file_path",
37 | ]
38 | 
39 | 
40 | def generate_synthetic_cluster_data():
41 |     """Generate synthetic cluster data: three clusters with 3, 6, and 9 points."""
42 |     cluster_one = np.array([[0, 0], [0, 1], [0, 2]])
43 |     # generate the second cluster with 6 points
44 |     cluster_two = np.array([[3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5]])
45 |     # generate the third cluster with 9 points
46 |     cluster_three = np.array(
47 |         [[6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6], [6, 7], [6, 8]]
48 |     )
49 |     # concatenate the clusters
50 |     coords = np.vstack([cluster_one, cluster_two, cluster_three])
51 |     # generate the labels
52 |     labels = np.hstack([[0 for _ in range(3)], [1 for _ in range(6)], [2 for _ in range(9)]])
53 | 
54 |     return coords, labels, cluster_one, cluster_two, cluster_three
55 | 
56 | 
57 | def generate_synthetic_data(
58 |     n_samples: int = 100,
59 |     n_features: int = 2,
60 |     n_clusters: int = 2,
61 |     cluster_std: float = 1.0,
62 |     center_box: Tuple[float, float] = (-10.0, 10.0),
63 |     metric: str = "euclidean",
64 |     shuffle: bool = True,
65 |     random_state: int = 42,
66 |     pairwise_dist: bool = False,
67 |     **kwargs: Any,
68 | ) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
69 |     """Generate synthetic data.
70 | 
71 |     Parameters
72 |     ----------
73 |     n_samples : int, optional
74 |         The number of sample points.
75 |     n_features : int, optional
76 |         The number of features.
77 |     n_clusters : int, optional
78 |         The number of clusters.
79 |     cluster_std : float, optional
80 |         The standard deviation of the clusters.
81 |     center_box : tuple[float, float], optional
82 |         The bounding box for each cluster center when centers are generated at random.
83 |     metric : str, optional
84 |         The metric used for computing pairwise distances. For the supported
85 |         distance metrics, please refer to
86 |         https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise_distances.html.
87 |     shuffle : bool, optional
88 |         Whether to shuffle the samples.
89 |     random_state : int, optional
90 |         The random state used for generating synthetic data.
91 |     pairwise_dist : bool, optional
92 |         If True, then compute and return the pairwise distances between sample points.
93 |     **kwargs : Any, optional
94 |         Additional keyword arguments for the scikit-learn `pairwise_distances` function.
95 | 
96 |     Returns
97 |     -------
98 |     syn_data : np.ndarray
99 |         The synthetic data.
100 |     class_labels : np.ndarray
101 |         The integer labels for cluster membership of each sample.
102 |     dist : np.ndarray
103 |         The symmetric pairwise distances between samples; only returned when `pairwise_dist` is True.
104 | 
105 |     """
106 |     # pylint: disable=W0632
107 |     syn_data, class_labels = make_blobs(
108 |         n_samples=n_samples,
109 |         n_features=n_features,
110 |         centers=n_clusters,
111 |         cluster_std=cluster_std,
112 |         center_box=center_box,
113 |         shuffle=shuffle,
114 |         random_state=random_state,
115 |         return_centers=False,
116 |     )
117 |     if pairwise_dist:
118 |         dist = pairwise_distances(
119 |             X=syn_data,
120 |             Y=None,
121 |             metric=metric,
122 |             **kwargs,
123 |         )
124 |         return syn_data, class_labels, dist
125 |     return syn_data, class_labels
126 | 
127 | 
128 | def get_data_file_path(file_name):
129 |     """Get the absolute path of the data file inside the package.
130 | 
131 |     Parameters
132 |     ----------
133 |     file_name : str
134 |         The name of the data file to load.
135 | 136 | Returns 137 | ------- 138 | str 139 | The absolute path of the data file inside the package 140 | 141 | """ 142 | data_file_path = resources.files("selector.methods.tests").joinpath(f"data/{file_name}") 143 | 144 | return data_file_path 145 | -------------------------------------------------------------------------------- /selector/methods/tests/data/coords_imbalance_case1.txt: -------------------------------------------------------------------------------- 1 | -2.988371860898040300e+00,8.828627151534506723e+00 2 | -2.522694847790684314e+00,7.956575199242423402e+00 3 | 2.721107620929060111e+00,1.946655808491515094e+00 4 | 3.856625543891864183e+00,1.651108167735056309e+00 5 | 4.447517871446978965e+00,2.274717026274344356e+00 6 | 4.247770683095943411e+00,5.096547358086134238e-01 7 | 5.161820401844998685e+00,2.270154357173918225e+00 8 | 3.448575339025452990e+00,2.629723292574561722e+00 9 | 4.110118632461063015e+00,2.486437117054088208e+00 10 | 4.605167066522858121e+00,8.044916463211999602e-01 11 | 3.959854114649610679e+00,2.205423381101735636e+00 12 | 4.935999113292677265e+00,2.234224956120621108e+00 13 | -7.194896435791616085e+00,-6.121140372782679862e+00 14 | -6.521839830802987237e+00,-6.319325066907712340e+00 15 | -6.665533447021066316e+00,-8.125848371987935082e+00 16 | -4.564968624477761416e+00,-8.747374785867695124e+00 17 | -4.735683101825944874e+00,-6.246190570957935506e+00 18 | -7.144284024389226495e+00,-4.159940426686327797e+00 19 | -6.364591923942610308e+00,-6.366323642363737711e+00 20 | -7.769141620776792934e+00,-7.695919878241385348e+00 21 | -6.821418472705270020e+00,-8.023079891106569050e+00 22 | -7.541413655919658510e+00,-6.027676258479722549e+00 23 | -6.706446265300088250e+00,-6.494792213547110116e+00 24 | -6.406389566577725070e+00,-6.952938505932819702e+00 25 | -7.609993822868406532e+00,-6.663651003693972008e+00 26 | -5.796575947975993515e+00,-5.826307541241043886e+00 27 | -7.351559056940703663e+00,-5.791158996308579887e+00 28 | -7.364990738980373486e+00,-6.798235453889623692e+00 29 | -6.956728900565374296e+00,-6.538957618459303234e+00 30 | -6.253959843386263984e+00,-7.737267149692229395e+00 31 | -6.057567031156779969e+00,-4.983316610621999487e+00 32 | -7.594930900411238639e+00,-6.200511844341271228e+00 33 | -7.125015307154140665e+00,-7.633845757633435980e+00 34 | -7.672147929583970516e+00,-6.994846034742845831e+00 35 | -7.103089976477121148e+00,-6.166109099183854525e+00 36 | -6.602936391821250695e+00,-6.052926344239923040e+00 37 | -8.904769777808876796e+00,-6.693655278506518869e+00 38 | -8.257296559108361578e+00,-7.817934633191069516e+00 39 | -6.364579504845222502e+00,-3.027378102621225864e+00 40 | -6.834055351247456223e+00,-7.531709940881763821e+00 41 | -7.652452405688841885e+00,-7.116928200015955497e+00 42 | -7.726420909219674726e+00,-8.394956817961810813e+00 43 | -6.866625299273363403e+00,-5.426575516118630205e+00 44 | -6.374639912170812828e+00,-6.014354399105824811e+00 45 | -7.326142143218291380e+00,-6.023710798952474299e+00 46 | -6.308736680458102875e+00,-5.744543953095347710e+00 47 | -8.079923598207045643e+00,-7.214610829116894664e+00 48 | -6.193367000776756726e+00,-8.492825464465598273e+00 49 | -5.925625427658067323e+00,-6.228718341970148842e+00 50 | -7.950519689212382168e+00,-6.397637178032761440e+00 51 | -7.763484627352402967e+00,-6.726384487330419049e+00 52 | -6.815347172055806979e+00,-7.957854371205252519e+00 53 | -------------------------------------------------------------------------------- /selector/methods/tests/data/coords_imbalance_case2.txt: 
-------------------------------------------------------------------------------- 1 | -2.545023662162701594e+00,1.057892978401232931e+01 2 | -3.348415146275388832e+00,8.705073752347109561e+00 3 | -3.186119623358708797e+00,9.625962417039191976e+00 4 | 6.526064737438631802e+00,2.147747496772570930e+00 5 | 5.265546183993107476e+00,1.116012127524449449e+00 6 | 3.793085118159696290e+00,4.583224592548673648e-01 7 | 4.605167066522858121e+00,8.044916463211999602e-01 8 | 3.665197166000779827e+00,2.760254287683184149e+00 9 | 4.890371686573978138e+00,2.319617893437707856e+00 10 | 3.089215405161968686e+00,2.041732658746759466e+00 11 | 4.416416050902250312e+00,2.687170178032824097e+00 12 | 3.568986338166989292e+00,2.455642099183917182e+00 13 | 4.447517871446978965e+00,2.274717026274344356e+00 14 | 5.161820401844998685e+00,2.270154357173918225e+00 15 | -6.598635323416237597e+00,-7.502809113096540194e+00 16 | -6.364591923942610308e+00,-6.366323642363737711e+00 17 | -7.351559056940703663e+00,-5.791158996308579887e+00 18 | -4.757470994138636833e+00,-5.847644332724799554e+00 19 | -7.132195342544430439e+00,-8.127892775240795231e+00 20 | -6.766109845900022179e+00,-6.217978918754900164e+00 21 | -6.680567495577800052e+00,-7.480326470434741637e+00 22 | -7.354572502312226590e+00,-7.533438825849658294e+00 23 | -4.735683101825944874e+00,-6.246190570957935506e+00 24 | -8.140511145486314604e+00,-5.962247646221170427e+00 25 | -6.374639912170812828e+00,-6.014354399105824811e+00 26 | -4.746593816495003892e+00,-8.832197392798448732e+00 27 | -7.652452405688841885e+00,-7.116928200015955497e+00 28 | -6.435807763005041870e+00,-6.105475539846610289e+00 29 | -6.308736680458102875e+00,-5.744543953095347710e+00 30 | -6.900528785115418451e+00,-6.762782209967165059e+00 31 | -6.834055351247456223e+00,-7.531709940881763821e+00 32 | -8.079923598207045643e+00,-7.214610829116894664e+00 33 | -7.672147929583970516e+00,-6.994846034742845831e+00 34 | -5.293610375005918023e+00,-8.117925092102796114e+00 35 | -7.087749441508545800e+00,-7.373110527934779945e+00 36 | -5.247215887219635277e+00,-8.310250971236579076e+00 37 | -6.364579504845222502e+00,-3.027378102621225864e+00 38 | -7.364990738980373486e+00,-6.798235453889623692e+00 39 | -7.861135842199221457e+00,-6.418006119012676258e+00 40 | -6.132333586028008376e+00,-6.269739327842481558e+00 41 | -5.925625427658067323e+00,-6.228718341970148842e+00 42 | -5.612716041964647573e+00,-7.587779058894727591e+00 43 | -5.980027315718019487e+00,-6.572810072399337677e+00 44 | -7.031412286186853322e+00,-6.291792386791370539e+00 45 | -4.564968624477761416e+00,-8.747374785867695124e+00 46 | -7.319671677848253566e+00,-6.749369015989855392e+00 47 | -5.438353902085154346e+00,-8.315971744455385561e+00 48 | -6.809825106161251362e+00,-7.265423190137706655e+00 49 | -8.904769777808876796e+00,-6.693655278506518869e+00 50 | -6.193367000776756726e+00,-8.492825464465598273e+00 51 | -8.398997157105283051e+00,-7.364343666142198153e+00 52 | -6.522611705186222686e+00,-7.573019188536600943e+00 53 | -5.716463438996310487e+00,-6.869876532256359525e+00 54 | -1.012089453122034222e+01,-7.904497234610236234e+00 55 | -------------------------------------------------------------------------------- /selector/methods/tests/data/labels_imbalance_case1.txt: -------------------------------------------------------------------------------- 1 | 0.000000000000000000e+00 2 | 0.000000000000000000e+00 3 | 1.000000000000000000e+00 4 | 1.000000000000000000e+00 5 | 1.000000000000000000e+00 6 | 1.000000000000000000e+00 7 | 1.000000000000000000e+00 8 
| 1.000000000000000000e+00 9 | 1.000000000000000000e+00 10 | 1.000000000000000000e+00 11 | 1.000000000000000000e+00 12 | 1.000000000000000000e+00 13 | 2.000000000000000000e+00 14 | 2.000000000000000000e+00 15 | 2.000000000000000000e+00 16 | 2.000000000000000000e+00 17 | 2.000000000000000000e+00 18 | 2.000000000000000000e+00 19 | 2.000000000000000000e+00 20 | 2.000000000000000000e+00 21 | 2.000000000000000000e+00 22 | 2.000000000000000000e+00 23 | 2.000000000000000000e+00 24 | 2.000000000000000000e+00 25 | 2.000000000000000000e+00 26 | 2.000000000000000000e+00 27 | 2.000000000000000000e+00 28 | 2.000000000000000000e+00 29 | 2.000000000000000000e+00 30 | 2.000000000000000000e+00 31 | 2.000000000000000000e+00 32 | 2.000000000000000000e+00 33 | 2.000000000000000000e+00 34 | 2.000000000000000000e+00 35 | 2.000000000000000000e+00 36 | 2.000000000000000000e+00 37 | 2.000000000000000000e+00 38 | 2.000000000000000000e+00 39 | 2.000000000000000000e+00 40 | 2.000000000000000000e+00 41 | 2.000000000000000000e+00 42 | 2.000000000000000000e+00 43 | 2.000000000000000000e+00 44 | 2.000000000000000000e+00 45 | 2.000000000000000000e+00 46 | 2.000000000000000000e+00 47 | 2.000000000000000000e+00 48 | 2.000000000000000000e+00 49 | 2.000000000000000000e+00 50 | 2.000000000000000000e+00 51 | 2.000000000000000000e+00 52 | 2.000000000000000000e+00 53 | -------------------------------------------------------------------------------- /selector/methods/tests/data/labels_imbalance_case2.txt: -------------------------------------------------------------------------------- 1 | 0.000000000000000000e+00 2 | 0.000000000000000000e+00 3 | 0.000000000000000000e+00 4 | 1.000000000000000000e+00 5 | 1.000000000000000000e+00 6 | 1.000000000000000000e+00 7 | 1.000000000000000000e+00 8 | 1.000000000000000000e+00 9 | 1.000000000000000000e+00 10 | 1.000000000000000000e+00 11 | 1.000000000000000000e+00 12 | 1.000000000000000000e+00 13 | 1.000000000000000000e+00 14 | 1.000000000000000000e+00 15 | 2.000000000000000000e+00 16 | 2.000000000000000000e+00 17 | 2.000000000000000000e+00 18 | 2.000000000000000000e+00 19 | 2.000000000000000000e+00 20 | 2.000000000000000000e+00 21 | 2.000000000000000000e+00 22 | 2.000000000000000000e+00 23 | 2.000000000000000000e+00 24 | 2.000000000000000000e+00 25 | 2.000000000000000000e+00 26 | 2.000000000000000000e+00 27 | 2.000000000000000000e+00 28 | 2.000000000000000000e+00 29 | 2.000000000000000000e+00 30 | 2.000000000000000000e+00 31 | 2.000000000000000000e+00 32 | 2.000000000000000000e+00 33 | 2.000000000000000000e+00 34 | 2.000000000000000000e+00 35 | 2.000000000000000000e+00 36 | 2.000000000000000000e+00 37 | 2.000000000000000000e+00 38 | 2.000000000000000000e+00 39 | 2.000000000000000000e+00 40 | 2.000000000000000000e+00 41 | 2.000000000000000000e+00 42 | 2.000000000000000000e+00 43 | 2.000000000000000000e+00 44 | 2.000000000000000000e+00 45 | 2.000000000000000000e+00 46 | 2.000000000000000000e+00 47 | 2.000000000000000000e+00 48 | 2.000000000000000000e+00 49 | 2.000000000000000000e+00 50 | 2.000000000000000000e+00 51 | 2.000000000000000000e+00 52 | 2.000000000000000000e+00 53 | 2.000000000000000000e+00 54 | 2.000000000000000000e+00 55 | -------------------------------------------------------------------------------- /selector/methods/tests/test_distance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 
5 | #
6 | # Copyright (C) 2022-2024 The QC-Devs Community
7 | #
8 | # This file is part of Selector.
9 | #
10 | # Selector is free software; you can redistribute it and/or
11 | # modify it under the terms of the GNU General Public License
12 | # as published by the Free Software Foundation; either version 3
13 | # of the License, or (at your option) any later version.
14 | #
15 | # Selector is distributed in the hope that it will be useful,
16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 | # GNU General Public License for more details.
19 | #
20 | # You should have received a copy of the GNU General Public License
21 | # along with this program; if not, see 
22 | #
23 | # --
24 | """Test selector/methods/distance.py."""
25 | 
26 | import numpy as np
27 | import pytest
28 | from numpy.testing import assert_equal, assert_raises
29 | from scipy.spatial.distance import pdist, squareform
30 | from sklearn.metrics import pairwise_distances
31 | 
32 | from selector.methods.distance import DISE, MaxMin, MaxSum, OptiSim
33 | from selector.methods.tests.common import (
34 |     generate_synthetic_cluster_data,
35 |     generate_synthetic_data,
36 |     get_data_file_path,
37 | )
38 | 
39 | 
40 | def test_maxmin():
41 |     """Testing the MaxMin class."""
42 |     # generate random data points belonging to one cluster - pairwise distance matrix
43 |     _, _, arr_dist = generate_synthetic_data(
44 |         n_samples=100,
45 |         n_features=2,
46 |         n_clusters=1,
47 |         pairwise_dist=True,
48 |         metric="euclidean",
49 |         random_state=42,
50 |     )
51 | 
52 |     # generate random data points belonging to multiple clusters - class labels and pairwise distance matrix
53 |     _, class_labels_cluster, arr_dist_cluster = generate_synthetic_data(
54 |         n_samples=100,
55 |         n_features=2,
56 |         n_clusters=3,
57 |         pairwise_dist=True,
58 |         metric="euclidean",
59 |         random_state=42,
60 |     )
61 | 
62 |     # use MaxMin algorithm to select points from clustered data
63 |     collector = MaxMin()
64 |     selected_ids = collector.select(
65 |         arr_dist_cluster,
66 |         size=12,
67 |         labels=class_labels_cluster,
68 |         proportional_selection=False,
69 |     )
70 |     # make sure all the selected indices are the same with expectation
71 |     assert_equal(selected_ids, [41, 34, 94, 85, 51, 50, 66, 78, 21, 64, 29, 83])
72 | 
73 |     # use MaxMin algorithm to select points from non-clustered data
74 |     collector = MaxMin()
75 |     selected_ids = collector.select(arr_dist, size=12)
76 |     # make sure all the selected indices are the same with expectation
77 |     assert_equal(selected_ids, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
78 | 
79 |     # use MaxMin algorithm to select points from non-clustered data with "medoid" as the reference point
80 |     collector_medoid_1 = MaxMin(ref_index=None)
81 |     selected_ids_medoid_1 = collector_medoid_1.select(arr_dist, size=12)
82 |     # make sure all the selected indices are the same with expectation
83 |     assert_equal(selected_ids_medoid_1, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
84 | 
85 |     # use MaxMin algorithm to select points from non-clustered data with "None" for the reference point
86 |     collector_medoid_2 = MaxMin(ref_index=None)
87 |     selected_ids_medoid_2 = collector_medoid_2.select(arr_dist, size=12)
88 |     # make sure all the selected indices are the same with expectation
89 |     assert_equal(selected_ids_medoid_2, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
90 | 
91 |     # use MaxMin algorithm to select points from non-clustered data with an integer as the reference point
92 |     collector_float = MaxMin(ref_index=85)
93 |     selected_ids_float = collector_float.select(arr_dist, size=12)
94 |     # make sure all the selected indices are the same with expectation
95 |     assert_equal(selected_ids_float, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
96 | 
97 |     # use MaxMin algorithm to select points from non-clustered data with a predefined list as the reference point
98 |     collector_float = MaxMin(ref_index=[85, 57, 41, 25])
99 |     selected_ids_float = collector_float.select(arr_dist, size=12)
100 |     # make sure all the selected indices are the same with expectation
101 |     assert_equal(selected_ids_float, [85, 57, 41, 25, 9, 62, 29, 65, 81, 61, 60, 97])
102 | 
103 |     # test failing case when ref_index is not a valid index
104 |     with pytest.raises(ValueError):
105 |         collector_float = MaxMin(ref_index=-3)
106 |         selected_ids_float = collector_float.select(arr_dist, size=12)
107 |     # test failing case when ref_index contains a complex number
108 |     with pytest.raises(ValueError):
109 |         collector_float = MaxMin(ref_index=[1 + 5j, 2, 5])
110 |         _ = collector_float.select(arr_dist, size=12)
111 |     # test failing case when ref_index contains a negative number
112 |     with pytest.raises(ValueError):
113 |         collector_float = MaxMin(ref_index=[-1, 2, 5])
114 |         _ = collector_float.select(arr_dist, size=12)
115 | 
116 |     # test failing case when the number of labels is not equal to the number of samples
117 |     with pytest.raises(ValueError):
118 |         collector_float = MaxMin(ref_index=85)
119 |         _ = collector_float.select(
120 |             arr_dist, size=12, labels=class_labels_cluster[:90], proportional_selection=False
121 |         )
122 | 
123 |     # use MaxMin algorithm, this time instantiating with a distance metric
124 |     collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"))
125 |     simple_coords = np.array([[0, 0], [2, 0], [0, 2], [2, 2], [-10, -10]])
126 |     # provide coordinates rather than pairwise distance matrix to collector
127 |     selected_ids = collector.select(x=simple_coords, size=3)
128 |     # make sure all the selected indices are the same with expectation
129 |     assert_equal(selected_ids, [0, 4, 3])
130 | 
131 |     # generating mocked clusters
132 |     np.random.seed(42)
133 |     cluster_one = np.random.normal(0, 1, (3, 2))
134 |     cluster_two = np.random.normal(10, 1, (6, 2))
135 |     cluster_three = np.random.normal(20, 1, (10, 2))
136 |     labels_mocked = np.hstack(
137 |         [[0 for i in range(3)], [1 for i in range(6)], [2 for i in range(10)]]
138 |     )
139 |     mocked_cluster_coords = np.vstack([cluster_one, cluster_two, cluster_three])
140 | 
141 |     # selecting molecules
142 |     collector = MaxMin(lambda x: pairwise_distances(x, metric="euclidean"))
143 |     selected_mocked = collector.select(
144 |         mocked_cluster_coords,
145 |         size=15,
146 |         labels=labels_mocked,
147 |         proportional_selection=False,
148 |     )
149 |     assert_equal(selected_mocked, [0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 15, 10, 13, 9, 18])
150 | 
151 | 
152 | def test_maxmin_proportional_selection():
153 |     """Test MaxMin class with proportional selection."""
154 |     # generate synthetic data with three clusters of 3, 6, and 9 points
155 |     coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data()
156 |     # instantiate the MaxMin class
157 |     collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0)
158 |     # select 6 points with proportional selection from each cluster
159 |     selected_ids = collector.select(
160 |         coords,
161 |         size=6,
162 |         labels=labels,
163 |         proportional_selection=True,
164 |     )
165 |     # make sure all the selected indices are the same with expectation
assert_equal(selected_ids, [0, 3, 8, 9, 17, 13]) 167 | # check how many points are selected from each cluster 168 | assert_equal(len(selected_ids), 6) 169 | # check the number of points selected from cluster one 170 | assert_equal((labels[selected_ids] == 0).sum(), 1) 171 | # check the number of points selected from cluster two 172 | assert_equal((labels[selected_ids] == 1).sum(), 2) 173 | # check the number of points selected from cluster three 174 | assert_equal((labels[selected_ids] == 2).sum(), 3) 175 | 176 | 177 | def test_maxmin_proportional_selection_imbalance_1(): 178 | """Test MaxMin class with proportional selection with imbalance case 1.""" 179 | # load three-cluster data from file 180 | # 2 from class 0, 10 from class 1, 40 from class 2 181 | coords_file_path = get_data_file_path("coords_imbalance_case1.txt") 182 | coords = np.genfromtxt(coords_file_path, delimiter=",", skip_header=0) 183 | labels_file_path = get_data_file_path("labels_imbalance_case1.txt") 184 | labels = np.genfromtxt(labels_file_path, delimiter=",", skip_header=0) 185 | 186 | # instantiate the MaxMin class 187 | collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) 188 | # select 9 points with proportional selection from each cluster 189 | selected_ids = collector.select( 190 | coords, 191 | size=9, 192 | labels=labels, 193 | proportional_selection=True, 194 | ) 195 | 196 | # make sure all the selected indices are the same as expected 197 | assert_equal(selected_ids, [0, 2, 6, 12, 15, 38, 16, 41, 36]) 198 | # check how many points are selected from each cluster 199 | assert_equal(len(selected_ids), 9) 200 | # check the number of points selected from cluster one 201 | assert_equal((labels[selected_ids] == 0).sum(), 1) 202 | # check the number of points selected from cluster two 203 | assert_equal((labels[selected_ids] == 1).sum(), 2) 204 | # check the number of points selected from cluster three 205 | assert_equal((labels[selected_ids] == 2).sum(), 6) 206 | 207 | 208 | def test_maxmin_proportional_selection_imbalance_2(): 209 | """Test MaxMin class with proportional selection with imbalance case 2.""" 210 | # load three-cluster data from file 211 | # 3 from class 0, 11 from class 1, 40 from class 2 212 | coords_file_path = get_data_file_path("coords_imbalance_case2.txt") 213 | coords = np.genfromtxt(coords_file_path, delimiter=",", skip_header=0) 214 | labels_file_path = get_data_file_path("labels_imbalance_case2.txt") 215 | labels = np.genfromtxt(labels_file_path, delimiter=",", skip_header=0) 216 | 217 | # instantiate the MaxMin class 218 | collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) 219 | # select 14 points with proportional selection from each cluster 220 | selected_ids = collector.select( 221 | coords, 222 | size=14, 223 | labels=labels, 224 | proportional_selection=True, 225 | )
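# editorial note: with 3 + 11 + 40 = 54 points and size=14, proportional allocation is expected to give 14*3/54 ≈ 0.8 -> 1, 14*11/54 ≈ 2.9 -> 3, and 14*40/54 ≈ 10.4 -> 10 picks from the three clusters, matching the per-cluster counts asserted below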
226 | 227 | # make sure all the selected indices are the same as expected 228 | assert_equal(selected_ids, [0, 3, 9, 6, 14, 36, 53, 17, 44, 23, 28, 50, 52, 49]) 229 | 230 | # check how many points are selected from each cluster 231 | assert_equal(len(selected_ids), 14) 232 | # check the number of points selected from cluster one 233 | assert_equal((labels[selected_ids] == 0).sum(), 1) 234 | # check the number of points selected from cluster two 235 | assert_equal((labels[selected_ids] == 1).sum(), 3) 236 | # check the number of points selected from cluster three 237 | assert_equal((labels[selected_ids] == 2).sum(), 10) 238 | 239 | 240 | def test_maxmin_invalid_input(): 241 | """Testing MaxMin class with invalid input.""" 242 | # case when the distance matrix is not square 243 | x_dist = np.array([[0, 1], [1, 0], [4, 9]]) 244 | with pytest.raises(ValueError): 245 | collector = MaxMin(ref_index=0) 246 | _ = collector.select(x_dist, size=2) 247 | 248 | # case when the distance matrix is not symmetric 249 | x_dist = np.array([[0, 1, 2, 1], [1, 1, 0, 3], [4, 9, 4, 0], [6, 5, 6, 7]]) 250 | with pytest.raises(ValueError): 251 | collector = MaxMin(ref_index=0) 252 | _ = collector.select(x_dist, size=2) 253 | 254 | 255 | def test_maxsum_clustered_data(): 256 | """Testing MaxSum class."""
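# editorial note: as their names indicate, MaxMin greedily adds the point whose minimum distance to the already-selected set is largest, while MaxSum adds the point whose summed distance to the selected set is largest; this is why the expected indices in the MaxSum tests below differ from the MaxMin selections above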
257 | # generate random data points belonging to multiple clusters - coordinates and class labels 258 | coords_cluster, class_labels_cluster, coords_cluster_dist = generate_synthetic_data( 259 | n_samples=100, 260 | n_features=2, 261 | n_clusters=3, 262 | pairwise_dist=True, 263 | metric="euclidean", 264 | random_state=42, 265 | ) 266 | 267 | # use MaxSum algorithm to select points from clustered data, instantiating with euclidean distance metric 268 | collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=None) 269 | selected_ids = collector.select( 270 | coords_cluster, 271 | size=12, 272 | labels=class_labels_cluster, 273 | proportional_selection=False, 274 | ) 275 | # make sure all the selected indices are the same as expected 276 | assert_equal(selected_ids, [41, 34, 85, 94, 51, 50, 78, 66, 21, 64, 0, 83]) 277 | 278 | # use MaxSum algorithm to select points from clustered data without instantiating with euclidean distance metric 279 | collector = MaxSum(ref_index=None) 280 | selected_ids = collector.select( 281 | coords_cluster_dist, 282 | size=12, 283 | labels=class_labels_cluster, 284 | proportional_selection=False, 285 | ) 286 | # make sure all the selected indices are the same as expected 287 | assert_equal(selected_ids, [41, 34, 85, 94, 51, 50, 78, 66, 21, 64, 0, 83]) 288 | 289 | # check that ValueError is raised when number of points requested is greater than number of points in array 290 | with pytest.raises(ValueError): 291 | _ = collector.select_from_cluster( 292 | coords_cluster, 293 | size=101, 294 | labels=class_labels_cluster, 295 | ) 296 | 297 | 298 | def test_maxsum_non_clustered_data(): 299 | """Testing MaxSum class with non-clustered data.""" 300 | # generate random data points belonging to one cluster - coordinates 301 | coords, _, _ = generate_synthetic_data( 302 | n_samples=100, 303 | n_features=2, 304 | n_clusters=1, 305 | pairwise_dist=True, 306 | metric="euclidean", 307 | random_state=42, 308 | ) 309 | # use MaxSum algorithm to select points from non-clustered data, instantiating with euclidean distance metric 310 | collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=None) 311 | selected_ids = collector.select(coords, size=12) 312 | # make sure all the selected indices are the same as expected 313 | assert_equal(selected_ids, [85, 57, 25, 41, 95, 9, 21, 8, 13, 68, 37, 54]) 314 | 315 | # use MaxSum algorithm to select points from non-clustered data, instantiating with euclidean 316 | # distance metric and ref_index=None, which uses the "medoid" as the reference point 317 | collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=None) 318 | selected_ids = collector.select(coords, size=12) 319 | # make sure all the selected indices are the same as expected 320 | assert_equal(selected_ids, [85, 57, 25, 41, 95, 9, 21, 8, 13, 68, 37, 54]) 321 | 322 | # use MaxSum algorithm to select points from non-clustered data, instantiating with euclidean 323 | # distance metric and using a list of reference points 324 | collector = MaxSum( 325 | fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), 326 | ref_index=[85, 57, 25, 41, 95], 327 | ) 328 | selected_ids = collector.select(coords, size=12) 329 | # make sure all the selected indices are the same as expected 330 | assert_equal(selected_ids, [85, 57, 25, 41, 95, 9, 21, 8, 13, 68, 37, 54]) 331 | 332 | # use MaxSum algorithm to select points from non-clustered data, instantiating with euclidean 333 | # distance metric and using an invalid reference point 334 | with pytest.raises(ValueError): 335 | collector = MaxSum( 336 | fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=-1 337 | ) 338 | selected_ids = collector.select(coords, size=12) 339 | 340 | 341 | def test_maxsum_invalid_input(): 342 | """Testing MaxSum class with invalid input.""" 343 | # case when the distance matrix is not square 344 | x_dist = np.array([[0, 1], [1, 0], [4, 9]]) 345 | with pytest.raises(ValueError): 346 | collector = MaxSum(ref_index=0) 347 | _ = collector.select(x_dist, size=2) 348 | 349 | # case when the distance matrix is not square 350 | x_dist = np.array([[0, 1, 2], [1, 0, 3], [4, 9, 0], [5, 6, 7]]) 351 | with pytest.raises(ValueError): 352 | collector = MaxSum(ref_index=0) 353 | _ = collector.select(x_dist, size=2) 354 | 355 | # case when the distance matrix is not symmetric 356 | x_dist = np.array([[0, 1, 2, 1], [1, 1, 0, 3], [4, 9, 4, 0], [6, 5, 6, 7]]) 357 | with pytest.raises(ValueError): 358 | collector = MaxSum(ref_index=0) 359 | _ = collector.select(x_dist, size=2) 360 | 361 | 362 | def test_maxsum_proportional_selection(): 363 | """Test MaxSum class with proportional selection.""" 364 | # generate synthetic data with three clusters of 3, 6, and 10 points 365 | coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data() 366 | # instantiate the MaxSum class 367 | collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) 368 | # select 6 points with proportional selection from each cluster 369 | selected_ids = collector.select( 370 | coords, 371 | size=6, 372 | labels=labels, 373 | proportional_selection=True, 374 | ) 375 | # make sure all the selected indices are the same as expected 376 | assert_equal(selected_ids, [0, 3, 8, 9, 17, 10]) 377 | # check how many points are selected from each cluster 378 | assert_equal(len(selected_ids), 6) 379 | # check the number of points selected from cluster one 380 | assert_equal((labels[selected_ids] == 0).sum(), 1) 381 | # check the number of points selected from cluster two 382 | assert_equal((labels[selected_ids] == 1).sum(), 2) 383 | # check the number of points selected from cluster three 384 | assert_equal((labels[selected_ids] == 2).sum(), 3) 385 | 386 | 387 | def test_optisim(): 388 | """Testing OptiSim class.""" 389 | # generate random data points belonging to one cluster - coordinates and pairwise distance matrix 390 | coords, _, arr_dist = generate_synthetic_data( 391 | n_samples=100, 392 | n_features=2, 393 | n_clusters=1, 394 | pairwise_dist=True, 395 | metric="euclidean", 396 | random_state=42, 397 | ) 398 | 399 | # generate random data points belonging to multiple clusters - coordinates and class labels 400 | coords_cluster, class_labels_cluster, _ = generate_synthetic_data( 401 | n_samples=100, 402 | n_features=2, 403 | n_clusters=3, 404 | pairwise_dist=True, 405 | metric="euclidean", 406 | random_state=42, 407 | ) 408 | 409 | # use OptiSim algorithm to select points from clustered data 410 | collector = OptiSim(ref_index=0) 411 | selected_ids = collector.select(coords_cluster, size=12, labels=class_labels_cluster) 412 | # expected indices for the clustered case (assertion currently disabled) 413 | # assert_equal(selected_ids, [2, 85, 86, 59, 1, 66, 50, 68, 0, 64, 83, 72]) 414 | 415 | # use OptiSim algorithm to select points from non-clustered data 416 | collector = OptiSim(ref_index=0) 417 | selected_ids = collector.select(coords, size=12) 418 | # make sure all the selected indices are the same as expected 419 | assert_equal(selected_ids, [0, 8, 55, 37, 41, 13, 12, 42, 6, 30, 57, 76]) 420 |
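# editorial note: OptiSim evaluates candidates drawn into a subsample of at most k points at each step, so for a very large k the subsample covers all remaining candidates and the greedy choice is expected to reduce to MaxMin; the next block exercises exactly that by setting k=999999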
421 | # check if OptiSim gives the same results as MaxMin as k -> infinity 422 | collector = OptiSim(ref_index=85, k=999999) 423 | selected_ids_optisim = collector.select(coords, size=12) 424 | collector = MaxMin() 425 | selected_ids_maxmin = collector.select(arr_dist, size=12) 426 | assert_equal(selected_ids_optisim, selected_ids_maxmin) 427 | 428 | # test with invalid ref_index 429 | with pytest.raises(ValueError): 430 | collector = OptiSim(ref_index=10000) 431 | _ = collector.select(coords, size=12) 432 | 433 | 434 | def test_optisim_proportional_selection(): 435 | """Test OptiSim class with proportional selection.""" 436 | # generate synthetic data with three clusters of 3, 6, and 10 points 437 | coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data() 438 | # instantiate the OptiSim class 439 | collector = OptiSim(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) 440 | # select 6 points with proportional selection from each cluster 441 | selected_ids = collector.select( 442 | coords, 443 | size=6, 444 | labels=labels, 445 | proportional_selection=True, 446 | ) 447 | # make sure all the selected indices are the same as expected 448 | assert_equal(selected_ids, [0, 3, 8, 9, 17, 13]) 449 | # check how many points are selected from each cluster 450 | assert_equal(len(selected_ids), 6) 451 | # check the number of points selected from cluster one 452 | assert_equal((labels[selected_ids] == 0).sum(), 1) 453 | # check the number of points selected from cluster two 454 | assert_equal((labels[selected_ids] == 1).sum(), 2) 455 | # check the number of points selected from cluster three 456 | assert_equal((labels[selected_ids] == 2).sum(), 3) 457 | 458 | 459 | def test_directed_sphere_size_error(): 460 | """Test DirectedSphereExclusion error when too many points are requested.""" 461 | x = np.array([[1, 9]] * 100) 462 | collector = DISE() 463 | assert_raises(ValueError, collector.select, x, size=105) 464 | 465 | 466 | def test_directed_sphere_same_number_of_pts(): 467 | """Test DirectedSphereExclusion on four collinear points with radius 1.""" 468 | # (0, 0) as the reference point 469 | x = np.array([[0, 0], [0, 1], [0, 2], [0, 3]]) 470 | collector = DISE(r0=1, tol=0, ref_index=0) 471 | selected = collector.select(x, size=2) 472 | assert_equal(selected, [0, 2]) 473 | assert_equal(collector.r, 1) 474 | 475 | 476 | def test_directed_sphere_same_number_of_pts_None(): 477 | """Test DirectedSphereExclusion with ref_index=None.""" 478 | # None as the reference point 479 | x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]]) 480 | collector = DISE(r0=1, tol=0, ref_index=None)
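# editorial note: with ref_index=None the starting point is assumed to default to the medoid of the data (here index 2, the middle of these five collinear points), which is why the expected selection below begins at 2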
481 | selected = collector.select(x, size=3) 482 | assert_equal(selected, [2, 0, 4]) 483 | assert_equal(collector.r, 1) 484 | 485 | 486 | def test_directed_sphere_exclusion_select_more_number_of_pts(): 487 | """Test DirectedSphereExclusion on points on a line with `size` < number of points in dataset.""" 488 | # (0, 0) as the reference point 489 | x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]]) 490 | collector = DISE(r0=0.5, tol=0, ref_index=0) 491 | selected = collector.select(x, size=3) 492 | expected = [0, 3, 6] 493 | assert_equal(selected, expected) 494 | assert_equal(collector.r, 2.0) 495 | 496 | 497 | def test_directed_sphere_exclusion_on_line_with_smaller_radius(): 498 | """Test DirectedSphereExclusion on points on a line spaced more finely than the radius.""" 499 | # (0, 1) as the reference point (ref_index=1) 500 | x = np.array( 501 | [ 502 | [0, 0], 503 | [0, 1], 504 | [0, 1.1], 505 | [0, 1.2], 506 | [0, 2], 507 | [0, 3], 508 | [0, 3.1], 509 | [0, 3.2], 510 | [0, 4], 511 | [0, 5], 512 | [0, 6], 513 | ] 514 | ) 515 | collector = DISE(r0=0.5, tol=1, ref_index=1) 516 | selected = collector.select(x, size=3) 517 | expected = [1, 5, 9] 518 | assert_equal(selected, expected) 519 | assert_equal(collector.r, 1.0) 520 | 521 | 522 | def test_directed_sphere_on_line_with_larger_radius(): 523 | """Test DirectedSphereExclusion on points on a line with an initial radius that is too large.""" 524 | # (0, 1) as the reference point (ref_index=1) 525 | x = np.array( 526 | [ 527 | [0, 0], 528 | [0, 1], 529 | [0, 1.1], 530 | [0, 1.2], 531 | [0, 2], 532 | [0, 3], 533 | [0, 3.1], 534 | [0, 3.2], 535 | [0, 4], 536 | [0, 5], 537 | ] 538 | ) 539 | collector = DISE(r0=2.0, tol=0, p=2.0, ref_index=1) 540 | selected = collector.select(x, size=3) 541 | expected = [1, 5, 9] 542 | assert_equal(selected, expected) 543 | assert_equal(collector.r, 1.0) 544 | 545 | 546 | def test_directed_sphere_dist_func(): 547 | """Test DirectedSphereExclusion with a custom distance function.""" 548 | # (0, 0) as the reference point 549 | x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]]) 550 | collector = DISE( 551 | r0=0.5, 552 | tol=0, 553 | ref_index=0, 554 | fun_dist=lambda x: squareform(pdist(x, metric="minkowski", p=0.1)), 555 | ) 556 | selected = collector.select(x, size=3) 557 | expected = [0, 3, 6] 558 | assert_equal(selected, expected) 559 | assert_equal(collector.r, 2.0) 560 | 561 | 562 | def test_directed_sphere_proportional_selection(): 563 | """Test DISE class with proportional selection.""" 564 | # generate synthetic data with three clusters of 3, 6, and 10 points 565 | coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data() 566 | # instantiate the DISE class 567 | collector = DISE(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) 568 | # select 6 points with proportional selection from each cluster 569 | selected_ids = collector.select( 570 | coords, 571 | size=6, 572 | labels=labels, 573 | proportional_selection=True, 574 | ) 575 | # make sure all the selected indices are the same as expected 576 | assert_equal(selected_ids, [0, 3, 7, 9, 12, 15]) 577 | # check how many points are selected from each cluster 578 | assert_equal(len(selected_ids), 6) 579 | # check the number of points selected from cluster one 580 | assert_equal((labels[selected_ids] == 0).sum(), 1) 581 | # check the number of points selected from cluster two 582 | assert_equal((labels[selected_ids] == 1).sum(), 2) 583 | # check the number of points selected from cluster three 584 |
assert_equal((labels[selected_ids] == 2).sum(), 3) 585 | -------------------------------------------------------------------------------- /selector/methods/tests/test_partition.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see 22 | # 23 | # -- 24 | """Test Partition-Based Selection Methods.""" 25 | 26 | import numpy as np 27 | import pytest 28 | from numpy.testing import assert_equal, assert_raises 29 | 30 | from selector.methods.partition import GridPartition, Medoid 31 | from selector.methods.tests.common import generate_synthetic_data 32 | 33 | 34 | @pytest.mark.parametrize("numb_pts", [4]) 35 | @pytest.mark.parametrize("method", ["equifrequent", "equisized"]) 36 | def test_grid_partitioning_independent_on_simple_example(numb_pts, method): 37 | r"""Test grid partitioning on an example where each bin has only one point.""" 38 | # Construct a feature array where it is known which bin each molecule belongs to. 39 | # The grid is a uniform grid from 0 to 3 on both axes, shifted by (1, 0). 40 | x = np.linspace(0, 3, numb_pts) 41 | y = np.linspace(0, 3, numb_pts) 42 | X, Y = np.meshgrid(x, y) 43 | grid = np.array([1.0, 0.0]) + np.vstack([X.ravel(), Y.ravel()]).T 44 | # Make one bin have an extra point 45 | grid = np.vstack((grid, np.array([1.1, 0.0]))) 46 | 47 | # Here the number of cells should be equal to the number of points in each dimension 48 | # excluding the extra point, so that the answer is unique/known. 49 | collector = GridPartition(nbins_axis=4, bin_method=f"{method}_independent") 50 | # Sort the points so that they're comparable to the expected answer.
51 | selected_ids = np.sort(collector.select(grid, size=len(grid) - 1)) 52 | expected = np.arange(len(grid) - 1) 53 | assert_equal(selected_ids, expected) 54 | 55 | 56 | def test_grid_partitioning_equisized_dependent_on_simple_example(): 57 | r"""Test equisized_dependent grid partitioning on an example that differs from the independent variant.""" 58 | # Construct a feature array where it is known which bin each molecule belongs to. 59 | grid = np.array( 60 | [ 61 | [0.0, 0.0], # Corresponds to bin (0, 0) 62 | [0.0, 4.0], # Corresponds to bin (0, 3) 63 | [1.0, 1.0], # Corresponds to bin (1, 0) 64 | [1.0, 0.9], # Corresponds to bin (1, 0) 65 | [1.0, 2.0], # Corresponds to bin (1, 3) 66 | [2.0, 0.0], # Corresponds to bin (2, 0) 67 | [2.0, 4.0], # Corresponds to bin (2, 3) 68 | [3.0, 0.0], # Corresponds to bin (3, 0) 69 | [3.0, 4.0], # Corresponds to bin (3, 3) 70 | [3.0, 3.9], # Corresponds to bin (3, 3) 71 | ] 72 | ) 73 | 74 | # The number of bins is chosen so that there is approximately a single point in each bin 75 | collector = GridPartition(nbins_axis=4, bin_method="equisized_dependent") 76 | # Two bins have an extra point in them and so have more diversity than the other bins, 77 | # so the two expected molecules should come from those bins. 78 | selected_ids = collector.select(grid, size=2, labels=None) 79 | right_molecules = True 80 | if not (2 in selected_ids or 3 in selected_ids): 81 | right_molecules = False 82 | if not (8 in selected_ids or 9 in selected_ids): 83 | right_molecules = False 84 | assert right_molecules, "The expected points were not selected" 85 | 86 | 87 | @pytest.mark.parametrize("numb_pts", [4]) 88 | def test_grid_partitioning_equifrequent_dependent_on_simple_example(numb_pts): 89 | r"""Test equifrequent dependent grid partitioning on an example where each bin has only one point.""" 90 | # Construct a feature array where it is known which bin each molecule belongs to. 91 | # The grid is a uniform grid from 0 to 3 on both axes, shifted by (1, 0). 92 | x = np.linspace(0, 3, numb_pts) 93 | y = np.linspace(0, 3, numb_pts) 94 | X, Y = np.meshgrid(x, y) 95 | grid = np.array([1.0, 0.0]) + np.vstack([X.ravel(), Y.ravel()]).T 96 | # Make one bin have an extra point 97 | grid = np.vstack((grid, np.array([1.1, 0.0]))) 98 | 99 | # Here the number of cells should be equal to the number of points in each dimension 100 | # excluding the extra point, so that the answer is unique/known. 101 | collector = GridPartition(nbins_axis=numb_pts, bin_method="equifrequent_dependent") 102 | # Sort the points so that they're comparable to the expected answer. 103 | selected_ids = np.sort(collector.select(grid, size=len(grid) - 1)) 104 | expected = np.arange(len(grid) - 1) 105 | assert_equal(selected_ids, expected) 106 | 107 | 108 | @pytest.mark.parametrize("numb_pts", [10, 20, 30]) 109 | @pytest.mark.parametrize("method", ["equifrequent", "equisized"]) 110 | def test_bins_from_both_methods_dependent_same_as_independent_on_uniform_grid(numb_pts, method): 111 | r"""Test that the dependent and independent binning variants give the same bins on a uniform 3D grid.""" 112 | x = np.linspace(0, 10, numb_pts) 113 | y = np.linspace(0, 11, numb_pts) 114 | X = np.meshgrid(x, y, y) 115 | grid = np.vstack(list(map(np.ravel, X))).T 116 | grid = np.array([1.0, 0.0, 0.0]) + grid 117 | 118 | # On a uniform grid the dependent and independent binnings should coincide, 119 | # so the bins produced by the two variants can be compared directly.
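# editorial note: as the names suggest, "equisized" bins divide each axis into intervals of equal width while "equifrequent" bins aim for roughly equal numbers of points per bin; the "_independent" variants bin each axis on its own, whereas the "_dependent" variants bin each axis within the bins of the previous axis. On a perfectly uniform grid the two variants are expected to coincide, which is what this test verifies.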
120 | collector_indept = GridPartition(nbins_axis=numb_pts, bin_method=f"{method}_independent") 121 | collector_depend = GridPartition(nbins_axis=numb_pts, bin_method=f"{method}_dependent") 122 | 123 | # Get the bins from the method 124 | bins_indept = collector_indept.get_bins_from_method(grid) 125 | bins_dept = collector_depend.get_bins_from_method(grid) 126 | 127 | # Test the bins are the same 128 | for key in bins_indept.keys(): 129 | assert_equal(bins_dept[key], bins_indept[key]) 130 | 131 | 132 | def test_raises_grid_partitioning(): 133 | r"""Test that grid partitioning raises errors on invalid input.""" 134 | grid = np.random.uniform(0.0, 1.0, size=(10, 3)) 135 | 136 | assert_raises(TypeError, GridPartition, 5.0) # Test the number of bins per axis should be an integer 137 | assert_raises(TypeError, GridPartition, 5, 5.0) # Test the grid method should be a string 138 | assert_raises(TypeError, GridPartition, 5, "string", []) # Test the random seed should be an integer 139 | 140 | # Test that an invalid bin_method string raises an error 141 | collector = GridPartition(nbins_axis=5, bin_method="string") 142 | assert_raises(ValueError, collector.select_from_cluster, grid, 5) 143 | 144 | collector = GridPartition(nbins_axis=5) 145 | assert_raises(TypeError, collector.select_from_cluster, [5.0], 5) # Test X is a numpy array 146 | assert_raises( 147 | TypeError, collector.select_from_cluster, grid, 5.0 148 | ) # Test the number selected should be an int 149 | assert_raises(TypeError, collector.select_from_cluster, grid, 5, [5.0]) 150 | 151 | 152 | def test_medoid(): 153 | """Testing Medoid class.""" 154 | coords, _, _ = generate_synthetic_data( 155 | n_samples=100, 156 | n_features=2, 157 | n_clusters=1, 158 | pairwise_dist=True, 159 | metric="euclidean", 160 | random_state=42, 161 | ) 162 | 163 | coords_cluster, class_labels_cluster, _ = generate_synthetic_data( 164 | n_samples=100, 165 | n_features=2, 166 | n_clusters=3, 167 | pairwise_dist=True, 168 | metric="euclidean", 169 | random_state=42, 170 | ) 171 | collector = Medoid() 172 | selected_ids = collector.select(coords_cluster, size=12, labels=class_labels_cluster) 173 | # make sure all the selected indices are the same as expected 174 | assert_equal(selected_ids, [2, 73, 94, 86, 1, 50, 93, 78, 0, 54, 33, 72]) 175 | 176 | collector = Medoid() 177 | selected_ids = collector.select(coords, size=12) 178 | # make sure all the selected indices are the same as expected 179 | assert_equal(selected_ids, [0, 95, 57, 41, 25, 9, 8, 6, 66, 1, 42, 82]) 180 | 181 | # test the case where the KD-Tree query returns an integer 182 | features = np.array([[1.5, 2.8], [2.3, 3.8], [1.5, 2.8], [4.0, 5.9]]) 183 | selector = Medoid() 184 | selected_ids = selector.select(features, size=2) 185 | assert_equal(selected_ids, [0, 3]) 186 | -------------------------------------------------------------------------------- /selector/methods/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see 22 | # 23 | # -- 24 | """Module for Selection Utilities.""" 25 | 26 | import warnings 27 | 28 | import numpy as np 29 | 30 | __all__ = [ 31 | "optimize_radius", 32 | ] 33 | 34 | 35 | def optimize_radius(obj, X, size, cluster_ids=None): 36 | """Algorithm that uses sphere exclusion for selecting points from a cluster. 37 | 38 | Iteratively searches for the optimal radius to obtain the correct number 39 | of selected samples. If the radius cannot converge to return `size` points, 40 | the function returns a number of samples as close to `size` as possible. 41 | 42 | Parameters 43 | ---------- 44 | obj: object 45 | Instance of `DirectedSphereExclusion` or `OptiSim` selection class. 46 | X: ndarray of shape (n_samples, n_features) 47 | Feature matrix of `n_samples` samples in `n_features` dimensional space. 48 | size: int 49 | Number of sample points to select (i.e. size of the subset). 50 | cluster_ids: np.ndarray 51 | Indices of points that form a cluster. 52 | 53 | Returns 54 | ------- 55 | selected: list 56 | List of indices of selected samples. 57 | """
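# editorial note: the loop below is effectively a bisection on the exclusion radius r, relying on the fact that a larger radius excludes more neighbours and therefore yields fewer selected points. Hypothetical trace: for size=10, if r=1.0 returns 14 points (too many), r is doubled to 2.0; if that returns 8 points (too few), the next trial is the midpoint r=1.5, and so on until the count falls within [lower_size, upper_size] or n_iter is exhausted.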
58 | if X.shape[0] < size: 59 | raise RuntimeError( 60 | f"Size of samples to be selected is greater than the existing number of samples; " 61 | f"{size} > {X.shape[0]}." 62 | ) 63 | # set the limits on # of selected points according to the tolerance percentage 64 | error = size * obj.tol 65 | lower_size = round(size - error) 66 | upper_size = round(size + error) 67 | 68 | # select `upper_size` number of samples 69 | if obj.r is not None: 70 | # use initial sphere radius 71 | selected = obj.algorithm(X, upper_size) 72 | else: 73 | # calculate a sphere radius based on maximum of n_features range 74 | # np.ptp returns the range of values (maximum - minimum) along an axis 75 | obj.r = max(np.ptp(X, axis=0)) / size * 3 76 | selected = obj.algorithm(X, upper_size) 77 | 78 | # return selected if the correct number of samples chosen 79 | if len(selected) == size: 80 | return selected 81 | 82 | # optimize radius to select the correct number of samples 83 | # first, set a sensible range for optimizing the r value within that range 84 | if len(selected) > size: 85 | # radius should become bigger, b/c too many samples were selected 86 | bounds = [obj.r, np.inf] 87 | else: 88 | # radius should become smaller, b/c too few samples were selected 89 | bounds = [0, obj.r] 90 | 91 | n_iter = 0 92 | while (len(selected) < lower_size or len(selected) > upper_size) and n_iter < obj.n_iter: 93 | # change the sphere radius based on the defined bounds 94 | if bounds[1] == np.inf: 95 | # make the sphere radius larger by a factor of 2 96 | obj.r = bounds[0] * 2 97 | else: 98 | # make the sphere radius smaller by a factor of 1/2 99 | obj.r = (bounds[0] + bounds[1]) / 2 100 | 101 | # re-select samples with the new radius 102 | selected = obj.algorithm(X, upper_size) 103 | 104 | # adjust lower/upper bounds of the radius range 105 | if len(selected) > size: 106 | bounds[0] = obj.r 107 | else: 108 | bounds[1] = obj.r 109 | n_iter += 1 110 | 111 | # cannot find a radius that produces the desired number of selected points 112 | if n_iter >= obj.n_iter and len(selected) != size: 113 | warnings.warn( 114 | f"Optimal radius
finder failed to converge, selected {len(selected)} points instead " 115 | f"of requested {size}." 116 | ) 117 | 118 | return selected 119 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [gh-actions] 2 | python = 3 | 3.9: py39 4 | 3.10: py310 5 | # 3.8: py38, rst_linux, rst_mac, readme, linters, coverage-report, qa 6 | # 3.11: py311, rst_linux, rst_mac, readme, linters, coverage-report, qa 7 | 3.11: py311, rst_linux, rst_mac, readme, coverage-report 8 | # 3.9: py39 9 | # 3.10: py310 10 | 3.12: py312 11 | 12 | [tox] 13 | # todo: add back the linters and qa 14 | # envlist = py39, py310, py311, rst_linux, rst_mac, readme, rst, linters, coverage-report, qa 15 | envlist = py39, py310, py311, py312, rst_linux, rst_mac, readme, rst, coverage-report 16 | 17 | [testenv] 18 | ; conda_deps = 19 | ; rdkit 20 | ; conda_channels = rdkit 21 | deps = 22 | -r{toxinidir}/requirements_dev.txt 23 | commands = 24 | # coverage run --rcfile=tox.ini -m pytest tests 25 | python -m pip install --upgrade pip 26 | # pip install -r requirements.txt 27 | # pip install . 28 | python -m pytest -c pyproject.toml --cov-config=.coveragerc --cov-report=xml --color=yes selector 29 | 30 | # pytest --cov-config=.coveragerc --cov=selector/test 31 | # can run it if needed 32 | # coverage report -m 33 | # prevent exit when error is encountered 34 | ignore_errors = true 35 | 36 | [testenv:readme] 37 | skip_install = true 38 | deps = 39 | readme_renderer 40 | twine 41 | -r{toxinidir}/requirements.txt 42 | -r{toxinidir}/requirements_dev.txt 43 | commands = 44 | # https://github.com/pypa/twine/issues/977 45 | python -m pip install importlib_metadata==7.2.1 build 46 | python -m build --no-isolation 47 | twine check dist/* 48 | 49 | [testenv:rst_linux] 50 | platform = 51 | linux 52 | skip_install = true 53 | deps = 54 | doc8 55 | rstcheck==3.3.1 56 | commands = 57 | doc8 --config tox.ini book/content/ 58 | # ignore code-block related error because 59 | # the Sphinx support in rstcheck is minimal. 
This results in false positives 60 | # rstcheck uses Docutils to parse reStructuredText files and extract code blocks 61 | # fixme: check updates on the following website in the future 62 | # coala.gitbooks.io/projects/content/en/projects/rstcheck-with-better-sphinx-suppor.html 63 | rstcheck --recursive book/content/ --report error \ 64 | --ignore-directives automodule,autoclass,autofunction,bibliography,code-block \ 65 | --ignore-roles cite,mod,class,lineno --ignore-messages code-block 66 | 67 | [testenv:rst_mac] 68 | platform = 69 | darwin 70 | skip_install = true 71 | deps = {[testenv:rst_linux]deps} 72 | commands = {[testenv:rst_linux]commands} 73 | 74 | [testenv:linters] 75 | deps = 76 | flake8 77 | flake8-docstrings 78 | flake8-import-order>=0.9 79 | flake8-colors 80 | pep8-naming 81 | pylint==2.13.9 82 | # black 83 | bandit 84 | commands = 85 | flake8 selector/ selector/tests setup.py 86 | pylint selector --rcfile=tox.ini --disable=similarities 87 | # black -l 100 --check ./ 88 | # black -l 100 --diff ./ 89 | # Use bandit configuration file 90 | bandit -r selector -c .bandit.yml 91 | 92 | ignore_errors = true 93 | 94 | [testenv:coverage-report] 95 | deps = coverage>=4.2 96 | skip_install = true 97 | commands = 98 | # coverage combine --rcfile=tox.ini 99 | coverage report 100 | 101 | [testenv:qa] 102 | deps = 103 | {[testenv]deps} 104 | {[testenv:linters]deps} 105 | {[testenv:coverage-report]deps} 106 | commands = 107 | {[testenv]commands} 108 | {[testenv:linters]commands} 109 | # {[testenv:coverage-report]commands} 110 | ignore_errors = true 111 | 112 | # pytest configuration 113 | [pytest] 114 | addopts = --cache-clear 115 | --showlocals 116 | -v 117 | -r a 118 | --cov-report term-missing 119 | --cov selector 120 | # Do not run tests in the build folder 121 | norecursedirs = build 122 | 123 | # flake8 configuration 124 | [flake8] 125 | exclude = 126 | __init__.py, 127 | .tox, 128 | .git, 129 | __pycache__, 130 | build, 131 | dist, 132 | *.pyc, 133 | *.egg-info, 134 | .cache, 135 | .eggs, 136 | _version.py, 137 | 138 | max-line-length = 100 139 | import-order-style = google 140 | ignore = 141 | # E121 : continuation line under-indented for hanging indent 142 | E121, 143 | # E123 : closing bracket does not match indentation of opening bracket’s line 144 | E123, 145 | # E126 : continuation line over-indented for hanging indent 146 | E126, 147 | # E226 : missing whitespace around arithmetic operator 148 | E226, 149 | # E241 : multiple spaces after ‘,’ 150 | # E242 : tab after ‘,’ 151 | E24, 152 | # E704 : multiple statements on one line (def) 153 | E704, 154 | # W503 : line break occurred before a binary operator 155 | W503, 156 | # W504 : Line break occurred after a binary operator 157 | W504, 158 | # D202: No blank lines allowed after function docstring 159 | D202, 160 | # E203: Whitespace before ':' 161 | E203, 162 | # E731: Do not assign a lambda expression, use a def 163 | E731, 164 | # D401: First line should be in imperative mood: 'Do', not 'Does' 165 | D401, 166 | 167 | per-file-ignores = 168 | # F401: Unused import 169 | # this is used to define the data typing 170 | selector/utils.py: F401, 171 | # E1101: rdkit.Chem has no attribute xxx 172 | # D403: first word of the first line should be properly capitalized 173 | selector/feature.py: E1101, D403 174 | 175 | # doc8 configuration 176 | [doc8] 177 | # Ignore target directories and autogenerated files 178 | ignore-path = book/content/_build/, build/, selector.egg-info/, selector.egg-info, .*/ 179 | # File extensions to use 180 
| extensions = .rst, .txt 181 | # Maximal line length should be 100 182 | max-line-length = 100 183 | # Disable some doc8 checks: 184 | # D000: Check RST validity (cannot handle the "linenos" directive) 185 | # D002: Trailing whitespace 186 | # D004: Found literal carriage return 187 | # Both D002 and D004 can be problematic on Windows, where the line ending is `\r\n`, 188 | # but on Linux and macOS it is "\n" 189 | # Known issue of doc8, https://bugs.launchpad.net/doc8/+bug/1756704 190 | # ignore = D000,D002,D004 191 | ignore = D000 192 | 193 | # pylint configuration 194 | [MASTER] 195 | # This is a known issue of pylint with recognizing numpy members 196 | # https://github.com/PyCQA/pylint/issues/779 197 | # https://stackoverflow.com/questions/20553551/how-do-i-get-pylint-to-recognize-numpy-members 198 | extension-pkg-whitelist=numpy 199 | 200 | [FORMAT] 201 | # Maximum number of characters on a single line. 202 | max-line-length=100 203 | 204 | [MESSAGES CONTROL] 205 | # disable pylint warnings 206 | disable= 207 | # attribute-defined-outside-init (W0201): 208 | # Attribute %r defined outside __init__ Used when an instance attribute is 209 | # defined outside the __init__ method. 210 | W0201, 211 | # too-many-instance-attributes (R0902): 212 | # Too many instance attributes (%s/%s) Used when class has too many instance 213 | # attributes, try to reduce this to get a simpler (and so easier to use) 214 | # class. 215 | R0902, 216 | # too many branches (R0912) 217 | R0912, 218 | # too-many-arguments (R0913): 219 | # Too many arguments (%s/%s) Used when a function or method takes too many 220 | # arguments. 221 | R0913, 222 | # Too many local variables (R0914) 223 | R0914, 224 | # Too many statements (R0915) 225 | R0915, 226 | # fixme (W0511): 227 | # Used when a warning note as FIXME or XXX is detected. 228 | W0511, 229 | # bad-continuation (C0330): 230 | # Wrong hanging indentation before block (add 4 spaces). 231 | C0330, 232 | # wrong-import-order (C0411): 233 | # %s comes before %s Used when PEP8 import order is not respected (standard 234 | # imports first, then third-party libraries, then local imports) 235 | C0411, 236 | # arguments-differ (W0221): 237 | # Parameters differ from %s %r method Used when a method has a different 238 | # number of arguments than in the implemented interface or in an overridden 239 | # method. 240 | W0221, 241 | # unnecessary "else" after "return" (R1705) 242 | R1705, 243 | # Value XX is unsubscriptable (E1136). This is an open issue of pylint 244 | # https://github.com/PyCQA/pylint/issues/3139 245 | E1136, 246 | # Used when a name doesn't fit the naming convention associated with its type 247 | # (constant, variable, class…).
248 | C0103, 249 | # Unnecessary pass statement 250 | W0107, 251 | # Module 'rdkit.Chem' has no 'ForwardSDMolSupplier' member (no-member) 252 | E1101, 253 | # todo: fix this one later and this is a temporary solution 254 | # E0401: Unable to import xxx (import-error) 255 | E0401, 256 | # R1721: Unnecessary use of a comprehension 257 | R1721, 258 | # I1101: Module xxx has no yyy member (no-member) 259 | I1101, 260 | # R0903: Too few public methods (too-few-public-methods) 261 | R0903, 262 | # R1702: Too many nested blocks (too-many-nested-blocks) 263 | R1702, 264 | 265 | [SIMILARITIES] 266 | min-similarity-lines=5 267 | 268 | # coverage configuration 269 | [run] 270 | branch = True 271 | parallel = True 272 | source = selector 273 | 274 | [paths] 275 | source = 276 | selector 277 | .tox/*/lib/python*/site-packages/selector 278 | .tox/pypy*/site-packages/selector 279 | -------------------------------------------------------------------------------- /updateheaders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # The Selector is a Python library of algorithms for selecting diverse 4 | # subsets of data for machine-learning. 5 | # 6 | # Copyright (C) 2022-2024 The QC-Devs Community 7 | # 8 | # This file is part of Selector. 9 | # 10 | # Selector is free software; you can redistribute it and/or 11 | # modify it under the terms of the GNU General Public License 12 | # as published by the Free Software Foundation; either version 3 13 | # of the License, or (at your option) any later version. 14 | # 15 | # Selector is distributed in the hope that it will be useful, 16 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 17 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 | # GNU General Public License for more details. 19 | # 20 | # You should have received a copy of the GNU General Public License 21 | # along with this program; if not, see 22 | # 23 | # -- 24 | 25 | 26 | import os 27 | from fnmatch import fnmatch 28 | from glob import glob 29 | 30 | 31 | def strip_header(lines, closing): 32 | # search for the header closing line, i.e. '# --\n' 33 | counter = 0 34 | found = False 35 | for line in lines: 36 | counter += 1 37 | if line == closing: 38 | found = True 39 | break 40 | if found: 41 | del lines[:counter] 42 | # If the header closing is not found, we assume it is not present. 
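# editorial note: the closing marker is (re)inserted unconditionally just below, so after strip_header runs the file always starts with a closing line; the fix_* callers then prepend the new header text above it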
43 | # add a header closing line 44 | lines.insert(0, closing) 45 | 46 | 47 | def fix_python(fn, lines, header_lines): 48 | # check if a shebang is present 49 | do_shebang = lines[0].startswith("#!") 50 | # remove the current header 51 | strip_header(lines, "# --\n") 52 | # add a pylint line for test files: 53 | # if os.path.basename(fn).startswith('test_'): 54 | # if not lines[1].startswith('#pylint: skip-file'): 55 | # lines.insert(1, '#pylint: skip-file\n') 56 | # add new header (insert in reverse order) 57 | for hline in header_lines[::-1]: 58 | lines.insert(0, ("# " + hline).strip() + "\n") 59 | 60 | if not hline.startswith("# -*- coding: utf-8 -*-"): 61 | # add a source code encoding line 62 | lines.insert(0, "# -*- coding: utf-8 -*-\n") 63 | 64 | if do_shebang: 65 | lines.insert(0, "#!/usr/bin/env python\n") 66 | 67 | 68 | def fix_c(fn, lines, header_lines): 69 | # check for an exception line 70 | for line in lines: 71 | if "no_update_headers" in line: 72 | return 73 | # remove the current header 74 | strip_header(lines, "//--\n") 75 | # add new header (insert must be in reverse order) 76 | for hline in header_lines[::-1]: 77 | lines.insert(0, ("// " + hline).strip() + "\n") 78 | 79 | 80 | def fix_rst(fn, lines, header_lines): 81 | # check for an exception line 82 | for line in lines: 83 | if "no_update_headers" in line: 84 | return 85 | # remove the current header 86 | strip_header(lines, " : --\n") 87 | # add an empty line after header if needed 88 | if len(lines[1].strip()) > 0: 89 | lines.insert(1, "\n") 90 | # add new header (insert must be in reverse order) 91 | for hline in header_lines[::-1]: 92 | lines.insert(0, (" : " + hline).rstrip() + "\n") 93 | # add comment instruction 94 | lines.insert(0, "..\n") 95 | 96 | 97 | def iter_subdirs(root): 98 | for dn, _, _ in os.walk(root): 99 | yield dn 100 | 101 | 102 | def main(): 103 | source_dirs = [".", "book", "notebooks", "selector"] + list(iter_subdirs("selector")) 104 | 105 | fixers = [ 106 | ("*.py", fix_python), 107 | ("*.pxd", fix_python), 108 | ("*.pyx", fix_python), 109 | ("*.txt", fix_python), 110 | ("*.c", fix_c), 111 | ("*.cpp", fix_c), 112 | ("*.h", fix_c), 113 | ("*.rst", fix_rst), 114 | ] 115 | 116 | f = open("HEADER") 117 | header_lines = f.readlines() 118 | f.close() 119 | 120 | for sdir in source_dirs: 121 | print("Scanning:", sdir) 122 | for fn in glob(sdir + "/*.*"): 123 | if not os.path.isfile(fn): 124 | continue 125 | for pattern, fixer in fixers: 126 | if fnmatch(fn, pattern): 127 | print(" Fixing:", fn) 128 | with open(fn) as f: 129 | lines = f.readlines() 130 | fixer(fn, lines, header_lines) 131 | with open(fn, "w") as f: 132 | f.writelines(lines) 133 | break 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | --------------------------------------------------------------------------------