├── .github ├── molpipeline.png ├── workflows │ ├── linting.yml │ └── python-publish.yml └── xai_example.png ├── .gitignore ├── .pre-commit-config.yaml ├── DEVELOPMENT.md ├── LICENSE ├── README.md ├── molpipeline ├── __init__.py ├── abstract_pipeline_elements │ ├── __init__.py │ ├── any2mol │ │ ├── __init__.py │ │ └── string2mol.py │ ├── core.py │ ├── mol2any │ │ ├── __init__.py │ │ ├── mol2bitvector.py │ │ ├── mol2floatvector.py │ │ └── mol2string.py │ └── mol2mol │ │ ├── __init__.py │ │ └── filter.py ├── any2mol │ ├── __init__.py │ ├── auto2mol.py │ ├── bin2mol.py │ ├── inchi2mol.py │ ├── sdf2mol.py │ └── smiles2mol.py ├── error_handling.py ├── estimators │ ├── __init__.py │ ├── algorithm │ │ ├── __init__.py │ │ ├── connected_component_clustering.py │ │ └── union_find.py │ ├── chemprop │ │ ├── __init__.py │ │ ├── abstract.py │ │ ├── component_wrapper.py │ │ ├── featurizer_wrapper │ │ │ ├── __init__.py │ │ │ └── graph_wrapper.py │ │ ├── lightning_wrapper.py │ │ ├── loss_wrapper.py │ │ ├── models.py │ │ └── neural_fingerprint.py │ ├── connected_component_clustering.py │ ├── leader_picker_clustering.py │ ├── murcko_scaffold_clustering.py │ ├── nearest_neighbor.py │ └── similarity_transformation.py ├── experimental │ ├── __init__.py │ ├── custom_filter.py │ ├── explainability │ │ ├── __init__.py │ │ ├── explainer.py │ │ ├── explanation.py │ │ ├── fingerprint_utils.py │ │ └── visualization │ │ │ ├── __init__.py │ │ │ ├── gauss.py │ │ │ ├── heatmaps.py │ │ │ ├── utils.py │ │ │ └── visualization.py │ └── model_selection │ │ ├── __init__.py │ │ └── splitter.py ├── metrics │ ├── __init__.py │ └── ignore_error_scorer.py ├── mol2any │ ├── __init__.py │ ├── mol2bin.py │ ├── mol2bool.py │ ├── mol2chemprop.py │ ├── mol2concatinated_vector.py │ ├── mol2inchi.py │ ├── mol2maccs_key_fingerprint.py │ ├── mol2morgan_fingerprint.py │ ├── mol2net_charge.py │ ├── mol2path_fingerprint.py │ ├── mol2rdkit_phys_chem.py │ └── mol2smiles.py ├── mol2mol │ ├── __init__.py │ ├── filter.py │ ├── reaction.py │ ├── scaffolds.py │ └── standardization.py ├── pipeline │ ├── __init__.py │ ├── _molpipeline.py │ └── _skl_pipeline.py ├── post_prediction.py ├── py.typed └── utils │ ├── __init__.py │ ├── comparison.py │ ├── json_operations.py │ ├── json_operations_torch.py │ ├── kernel.py │ ├── logging.py │ ├── matrices.py │ ├── molpipeline_types.py │ ├── multi_proc.py │ ├── subpipeline.py │ ├── substructure_handling.py │ ├── value_checks.py │ └── value_conversions.py ├── notebooks ├── 01_getting_started_with_molpipeline.ipynb ├── 02_scaffold_split_with_custom_estimators.ipynb ├── 03_error_handling.ipynb ├── 04_feature_calculation.ipynb ├── advanced_01_hyperopt_on_bbbp.ipynb ├── advanced_02_add_custom_pipeline_elements.ipynb ├── advanced_03_introduction_to_explainable_ai.ipynb └── example_data │ └── renin_harren.csv ├── pyproject.toml ├── ruff.toml ├── test_extras ├── __init__.py ├── test_chemprop │ ├── __init__.py │ ├── chemprop_test_utils │ │ ├── __init__.py │ │ ├── compare_models.py │ │ ├── constant_vars.py │ │ └── default_models.py │ ├── test_abstract.py │ ├── test_chemprop_pipeline.py │ ├── test_component_wrapper.py │ ├── test_lightning_wrapper.py │ ├── test_models.py │ └── test_neural_fingerprint.py └── test_notebooks │ ├── __init__.py │ └── test_notebooks.py └── tests ├── __init__.py ├── run_tests.py ├── test_data ├── P86_B_400.sdf.gz ├── mol_descriptors.tsv ├── molecule_net_bbbp.tsv.gz ├── molecule_net_logd.tsv.gz └── multiclass_mock.tsv ├── test_elements ├── __init__.py ├── test_any2mol │ ├── __init__.py │ ├── 
test_auto2mol.py │ ├── test_bin2mol.py │ ├── test_sdf2mol.py │ └── test_smiles2mol.py ├── test_error_handling.py ├── test_mol2any │ ├── __init__.py │ ├── test_mol2bin.py │ ├── test_mol2bool.py │ ├── test_mol2concatenated.py │ ├── test_mol2inchi.py │ ├── test_mol2maccs_key_fingerprint.py │ ├── test_mol2morgan_fingerprint.py │ ├── test_mol2net_charge.py │ ├── test_mol2path_fingerprint.py │ └── test_mol2rdkit_phys_chem.py ├── test_mol2mol │ ├── __init__.py │ ├── test_mol2mol_filter.py │ ├── test_mol2mol_standardization.py │ └── test_mol2scaffold.py └── test_post_prediction.py ├── test_estimators ├── __init__.py ├── test_algorithm │ ├── __init__.py │ ├── test_connected_component_clustering.py │ └── test_union_find.py ├── test_connected_component_clustering.py ├── test_leader_picker_clustering.py ├── test_murcko_scaffold_clustering.py ├── test_nearest_neighbors.py └── test_similarity_transformation.py ├── test_experimental ├── __init__.py ├── test_custom_filter.py ├── test_explainability │ ├── __init__.py │ ├── test_shap_explainers.py │ ├── test_visualization │ │ ├── __init__.py │ │ ├── test_gaussian_grid.py │ │ └── test_visualization.py │ └── utils.py └── test_model_selection │ ├── __init__.py │ └── test_splitter.py ├── test_init.py ├── test_metrics ├── __init__.py └── test_ignore_error_scorer.py ├── test_pipeline.py ├── test_utils ├── __init__.py ├── test_comparison.py ├── test_json_operations.py ├── test_logging.py └── test_subpipeline.py └── utils ├── __init__.py ├── default_models.py ├── execution_count.py ├── fingerprints.py ├── logging.py └── mock_element.py /.github/molpipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/MolPipeline/3ab8aa0ebd345b8b2b2b99dd608371f640211754/.github/molpipeline.png -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.github/xai_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/MolPipeline/3ab8aa0ebd345b8b2b2b99dd608371f640211754/.github/xai_example.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | __pycache__ 3 | molpipeline.egg-info/ 4 | lib/ 5 | build/ 6 | lightning_logs/ 7 | 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: | 2 | (?x)( 3 | ^notebooks/| 4 | ^$ 5 | ) 6 | fail_fast: false 7 | default_stages: [pre-commit, pre-push] 8 | repos: 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v5.0.0 11 | hooks: 12 | - id: check-case-conflict 13 | - id: check-merge-conflict 14 | - id: end-of-file-fixer 15 | - id: mixed-line-ending 16 | - id: trailing-whitespace 17 | exclude_types: [tsv] 18 | args: [--markdown-linebreak-ext=md] 19 | 20 | - repo: https://github.com/commitizen-tools/commitizen 21 | rev: v3.29.1 22 | hooks: 23 | - id: commitizen 24 | stages: [commit-msg] 25 | 26 | - repo: https://github.com/astral-sh/ruff-pre-commit 27 | rev: v0.11.2 28 | hooks: 29 | - id: ruff-format 30 | types_or: [python, pyi, jupyter] 31 | - id: ruff 32 | types_or: [python, pyi, jupyter] 33 | args: [ --fix, --exit-non-zero-on-fix, --config, ruff.toml] 34 | 35 | 36 | - repo: https://github.com/pre-commit/mirrors-prettier 37 | rev: v4.0.0-alpha.8 38 | hooks: 39 | - id: prettier 40 | types: 41 | - ts 42 | - javascript 43 | - yaml 44 | - markdown 45 | - json 46 | 47 | - repo: https://github.com/RobertCraigie/pyright-python 48 | rev: v1.1.398 49 | hooks: 50 | - id: pyright 51 | args: [-p, pyproject.toml] 52 | verbose: true 53 | 54 | - repo: https://github.com/PyCQA/flake8 55 | rev: 7.0.0 56 | hooks: 57 | - id: flake8 58 | additional_dependencies: 59 | - docsig==0.69.3 60 | args: 61 | - --extend-ignore=D203,E203,E501,F401,W5 62 | - "--sig-check-class-constructor" 63 | - "--sig-check-dunders" 64 | - "--sig-check-protected-class-methods" 65 | - "--sig-check-nested" 66 | - "--sig-check-overridden" 67 | - "--sig-check-protected" 68 | -------------------------------------------------------------------------------- /DEVELOPMENT.md: -------------------------------------------------------------------------------- 1 | # Development-Guidelines 2 | ## Pre-commit Hooks 3 | Before committing any changes, make sure to enable the pre-commit hooks. 4 | This will help you to automatically format your code and check for any linting issues. 
5 | You can enable the pre-commit hooks by running the following command: 6 | ```bash 7 | pre-commit install 8 | ``` 9 | In case you want to run the pre-commit hooks manually, you can do so by running: 10 | ```bash 11 | pre-commit run --all-files 12 | ``` 13 | > **_NOTE:_** Be aware that the code in its current state does not comply with the pre-commit hooks. 14 | > Hence, you might encounter errors in sections that are not related to your changes. 15 | > This is intended to slowly improve the code quality over time. 16 | > If fixing the errors would cause a massive overhead, you can ignore them via the `--no-verify` flag when committing. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 BASF 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /molpipeline/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize the molpipeline package.""" 2 | 3 | # pylint: disable=no-name-in-module 4 | from rdkit.Chem import PropertyPickleOptions, SetDefaultPickleProperties 5 | 6 | from molpipeline.error_handling import ErrorFilter, FilterReinserter 7 | from molpipeline.pipeline import Pipeline 8 | from molpipeline.post_prediction import PostPredictionWrapper 9 | 10 | # Keep all properties when pickling. Otherwise, we will lose properties set on RDKitMol when passed to 11 | # multiprocessing subprocesses. 
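A standalone sketch of the behavior this default enables — without `AllProps`, RDKit drops instance properties during pickling, which would break passing annotated molecules to multiprocessing workers:

```python
import pickle

from rdkit import Chem
from rdkit.Chem import PropertyPickleOptions, SetDefaultPickleProperties

# Keep all molecule properties when pickling (same call as in the module below).
SetDefaultPickleProperties(PropertyPickleOptions.AllProps)

mol = Chem.MolFromSmiles("CCO")
mol.SetProp("identifier", "CCO")

restored = pickle.loads(pickle.dumps(mol))
assert restored.GetProp("identifier") == "CCO"
```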
12 | SetDefaultPickleProperties(PropertyPickleOptions.AllProps) 13 | 14 | __all__ = [ 15 | "ErrorFilter", 16 | "FilterReinserter", 17 | "Pipeline", 18 | "PostPredictionWrapper", 19 | "__version__", 20 | ] 21 | 22 | __version__ = "0.10.2" 23 | -------------------------------------------------------------------------------- /molpipeline/abstract_pipeline_elements/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | -------------------------------------------------------------------------------- /molpipeline/abstract_pipeline_elements/any2mol/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | -------------------------------------------------------------------------------- /molpipeline/abstract_pipeline_elements/any2mol/string2mol.py: -------------------------------------------------------------------------------- 1 | """Abstract classes for creating rdkit molecules from string representations.""" 2 | 3 | from __future__ import annotations 4 | 5 | import abc 6 | 7 | from molpipeline.abstract_pipeline_elements.core import ( 8 | AnyToMolPipelineElement, 9 | InvalidInstance, 10 | ) 11 | from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol 12 | 13 | 14 | class StringToMolPipelineElement(AnyToMolPipelineElement, abc.ABC): 15 | """Abstract class for PipelineElements which transform string representations to molecules.""" 16 | 17 | _input_type = "str" 18 | _output_type = "RDKitMol" 19 | 20 | def transform(self, values: list[str]) -> list[OptionalMol]: 21 | """Transform a list of string representations to RDKit molecules. 22 | 23 | Parameters 24 | ---------- 25 | values: list[str] 26 | List of string representations of molecules which are transformed to RDKit molecules. 27 | 28 | Returns 29 | ------- 30 | list[OptionalMol] 31 | List of RDKit molecules. If a string representation could not be transformed to a molecule, an InvalidInstance is returned. 32 | """ 33 | return super().transform(values) 34 | 35 | @abc.abstractmethod 36 | def pretransform_single(self, value: str) -> OptionalMol: 37 | """Transform a string representation to a molecule. 38 | 39 | Parameters 40 | ---------- 41 | value: str 42 | String representation which is transformed to an RDKit molecule. 43 | 44 | Returns 45 | ------- 46 | OptionalMol 47 | RDKit molecule if representation was valid, else InvalidInstance. 48 | """ 49 | 50 | 51 | class SimpleStringToMolElement(StringToMolPipelineElement, abc.ABC): 52 | """Transforms string representation to RDKit Mol objects.""" 53 | 54 | def pretransform_single(self, value: str) -> OptionalMol: 55 | """Transform string to molecule. 56 | 57 | Parameters 58 | ---------- 59 | value: str 60 | string representation. 61 | 62 | Returns 63 | ------- 64 | OptionalMol 65 | Rdkit molecule if valid string representation, else InvalidInstance. 66 | """ 67 | if value is None: 68 | return InvalidInstance( 69 | self.uuid, 70 | f"Invalid representation: {value}", 71 | self.name, 72 | ) 73 | 74 | if not isinstance(value, str): 75 | return InvalidInstance( 76 | self.uuid, 77 | f"Not a string: {value}", 78 | self.name, 79 | ) 80 | 81 | mol: RDKitMol = self.string_to_mol(value) 82 | 83 | if not mol: 84 | return InvalidInstance( 85 | self.uuid, 86 | f"Invalid representation: {value}", 87 | self.name, 88 | ) 89 | mol.SetProp("identifier", value) 90 | return mol 91 | 92 | @abc.abstractmethod 93 | def string_to_mol(self, value: str) -> RDKitMol: 94 | """Transform string representation to molecule.
95 | 96 | Parameters 97 | ---------- 98 | value: str 99 | string representation 100 | 101 | Returns 102 | ------- 103 | RDKitMol 104 | Rdkit molecule if valid representation, else None. 105 | """ 106 | -------------------------------------------------------------------------------- /molpipeline/abstract_pipeline_elements/mol2any/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | 3 | from molpipeline.abstract_pipeline_elements.mol2any.mol2bitvector import ( 4 | MolToFingerprintPipelineElement, 5 | ) 6 | from molpipeline.abstract_pipeline_elements.mol2any.mol2floatvector import ( 7 | MolToDescriptorPipelineElement, 8 | ) 9 | from molpipeline.abstract_pipeline_elements.mol2any.mol2string import ( 10 | MolToStringPipelineElement, 11 | ) 12 | 13 | __all__ = ( 14 | "MolToDescriptorPipelineElement", 15 | "MolToFingerprintPipelineElement", 16 | "MolToStringPipelineElement", 17 | ) 18 | -------------------------------------------------------------------------------- /molpipeline/abstract_pipeline_elements/mol2any/mol2string.py: -------------------------------------------------------------------------------- 1 | """Class for transforming molecules to string representations.""" 2 | 3 | from __future__ import annotations 4 | 5 | import abc 6 | 7 | from rdkit import Chem 8 | 9 | from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement 10 | 11 | 12 | class MolToStringPipelineElement(MolToAnyPipelineElement, abc.ABC): 13 | """Abstract class for PipelineElements which transform molecules to string representations.""" 14 | 15 | _output_type = "str" 16 | 17 | def transform(self, values: list[Chem.Mol]) -> list[str]: 18 | """Transform a list of molecules to a list of string representations. 19 | 20 | Parameters 21 | ---------- 22 | values: list[Chem.Mol] 23 | List of RDKit molecules which are transformed to a string representation. 24 | 25 | Returns 26 | ------- 27 | list[str] 28 | List of string representations of the molecules. 29 | """ 30 | string_list: list[str] = super().transform(values) 31 | return string_list 32 | 33 | @abc.abstractmethod 34 | def pretransform_single(self, value: Chem.Mol) -> str: 35 | """Transform mol to a string. 36 | 37 | Parameters 38 | ---------- 39 | value: Chem.Mol 40 | Molecule to be transformed to a string representation. 41 | 42 | Returns 43 | ------- 44 | str 45 | String representation of the molecule.
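The `SimpleStringToMolElement` contract from string2mol.py above leaves only `string_to_mol` to implement; `None` and non-string inputs, parse failures, and the `identifier` property are all handled by the base class. A hypothetical sketch (`SmartsToMol` is illustrative, not part of the package):

```python
from rdkit import Chem

from molpipeline.abstract_pipeline_elements.any2mol.string2mol import (
    SimpleStringToMolElement,
)
from molpipeline.utils.molpipeline_types import RDKitMol


class SmartsToMol(SimpleStringToMolElement):
    """Transforms SMARTS strings to RDKit Mol objects."""

    def string_to_mol(self, value: str) -> RDKitMol:
        # MolFromSmarts returns None for invalid input; the base class
        # converts a falsy result to an InvalidInstance.
        return Chem.MolFromSmarts(value)
```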
46 | """ 47 | -------------------------------------------------------------------------------- /molpipeline/abstract_pipeline_elements/mol2mol/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize the module for abstract mol2mol elements.""" 2 | 3 | from molpipeline.abstract_pipeline_elements.mol2mol.filter import ( 4 | BaseKeepMatchesFilter, 5 | BasePatternsFilter, 6 | ) 7 | 8 | __all__ = ["BaseKeepMatchesFilter", "BasePatternsFilter"] 9 | -------------------------------------------------------------------------------- /molpipeline/any2mol/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | 3 | from molpipeline.any2mol.auto2mol import AutoToMol 4 | from molpipeline.any2mol.bin2mol import BinaryToMol 5 | from molpipeline.any2mol.inchi2mol import InchiToMol 6 | from molpipeline.any2mol.sdf2mol import SDFToMol 7 | from molpipeline.any2mol.smiles2mol import SmilesToMol 8 | 9 | __all__ = [ 10 | "AutoToMol", 11 | "BinaryToMol", 12 | "InchiToMol", 13 | "SDFToMol", 14 | "SmilesToMol", 15 | ] 16 | -------------------------------------------------------------------------------- /molpipeline/any2mol/auto2mol.py: -------------------------------------------------------------------------------- 1 | """Classes to transform given input to a RDKit molecule.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | from molpipeline.abstract_pipeline_elements.core import ( 8 | AnyToMolPipelineElement, 9 | InvalidInstance, 10 | ) 11 | from molpipeline.any2mol.bin2mol import BinaryToMol 12 | from molpipeline.any2mol.inchi2mol import InchiToMol 13 | from molpipeline.any2mol.sdf2mol import SDFToMol 14 | from molpipeline.any2mol.smiles2mol import SmilesToMol 15 | from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol 16 | 17 | 18 | class AutoToMol(AnyToMolPipelineElement): 19 | """Transforms various inputs to RDKit Mol objects. 20 | 21 | A cascade of if clauses is tried to transformer the given input to a molecule. 22 | """ 23 | 24 | elements: tuple[AnyToMolPipelineElement, ...] 25 | 26 | def __init__( 27 | self, 28 | name: str = "auto2mol", 29 | n_jobs: int = 1, 30 | uuid: str | None = None, 31 | elements: tuple[AnyToMolPipelineElement, ...] | None = None, 32 | ) -> None: 33 | """Initialize AutoToMol. 34 | 35 | Parameters 36 | ---------- 37 | name: str, default="auto2mol" 38 | Name of PipelineElement 39 | n_jobs: int, default=1 40 | Number of parallel jobs to use. 41 | uuid: str | None, optional 42 | Unique identifier of PipelineElement. 43 | elements: tuple[AnyToMol, ...] | None, optional 44 | If None, the default elements are used: 45 | - SmilesToMol 46 | - InchiToMol 47 | - BinaryToMol 48 | - SDFToMol 49 | """ 50 | super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) 51 | # pipeline elements for transforming the input to a molecule 52 | if elements is None: 53 | elements = ( 54 | SmilesToMol(), 55 | InchiToMol(), 56 | BinaryToMol(), 57 | SDFToMol(), 58 | ) 59 | self.elements = elements 60 | 61 | def pretransform_single(self, value: Any) -> OptionalMol: 62 | """Transform input value to molecule. 63 | 64 | Parameters 65 | ---------- 66 | value: Any 67 | Input value. 68 | 69 | Returns 70 | ------- 71 | OptionalMol 72 | Rdkit molecule if the input can be transformed, else None. 
73 | """ 74 | if value is None: 75 | return InvalidInstance( 76 | self.uuid, 77 | f"Invalid input molecule: {value}", 78 | self.name, 79 | ) 80 | 81 | if isinstance(value, RDKitMol): 82 | return value 83 | 84 | # sequentially try to transform the input to a molecule using predefined elements 85 | for element in self.elements: 86 | mol = element.pretransform_single(value) 87 | if not isinstance(mol, InvalidInstance): 88 | return mol 89 | 90 | return InvalidInstance( 91 | self.uuid, 92 | f"Not readable input molecule: {value}", 93 | self.name, 94 | ) 95 | -------------------------------------------------------------------------------- /molpipeline/any2mol/bin2mol.py: -------------------------------------------------------------------------------- 1 | """Classes ment to transform given input to a RDKit molecule.""" 2 | 3 | from __future__ import annotations 4 | 5 | from rdkit import Chem 6 | 7 | from molpipeline.abstract_pipeline_elements.core import ( 8 | AnyToMolPipelineElement, 9 | InvalidInstance, 10 | ) 11 | from molpipeline.utils.molpipeline_types import OptionalMol 12 | 13 | 14 | class BinaryToMol(AnyToMolPipelineElement): 15 | """Transforms binary string representation to RDKit Mol objects.""" 16 | 17 | def pretransform_single(self, value: str) -> OptionalMol: 18 | """Transform binary string to molecule. 19 | 20 | Parameters 21 | ---------- 22 | value: str 23 | Binary string. 24 | 25 | Returns 26 | ------- 27 | OptionalMol 28 | Rdkit molecule if valid binary representation, else None. 29 | """ 30 | if value is None: 31 | return InvalidInstance( 32 | self.uuid, 33 | f"Invalid binary string: {value}", 34 | self.name, 35 | ) 36 | 37 | if not isinstance(value, bytes): 38 | return InvalidInstance( 39 | self.uuid, 40 | f"Not bytes: {value}", 41 | self.name, 42 | ) 43 | 44 | mol: OptionalMol | None = None 45 | try: 46 | mol = Chem.Mol(value) 47 | except RuntimeError: 48 | pass 49 | 50 | if not mol: 51 | return InvalidInstance( 52 | self.uuid, 53 | f"Invalid binary string: {value}", 54 | self.name, 55 | ) 56 | return mol 57 | -------------------------------------------------------------------------------- /molpipeline/any2mol/inchi2mol.py: -------------------------------------------------------------------------------- 1 | """Classes ment to transform given inchi to a RDKit molecule.""" 2 | 3 | from rdkit import Chem 4 | 5 | from molpipeline.abstract_pipeline_elements.any2mol.string2mol import ( 6 | SimpleStringToMolElement, 7 | ) 8 | from molpipeline.utils.molpipeline_types import RDKitMol 9 | 10 | 11 | class InchiToMol(SimpleStringToMolElement): 12 | """Transforms Inchi to RDKit Mol objects.""" 13 | 14 | def string_to_mol(self, value: str) -> RDKitMol: 15 | """Transform Inchi string to molecule. 16 | 17 | Parameters 18 | ---------- 19 | value: str 20 | Inchi string. 21 | 22 | Returns 23 | ------- 24 | RDKitMol 25 | Rdkit molecule if valid Inchi, else None. 
26 | """ 27 | return Chem.MolFromInchi(value) 28 | -------------------------------------------------------------------------------- /molpipeline/any2mol/sdf2mol.py: -------------------------------------------------------------------------------- 1 | """Class for Transforming SDF-strings to rdkit molecules.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING, Any 6 | 7 | try: 8 | from typing import Self # type: ignore[attr-defined] 9 | except ImportError: 10 | from typing_extensions import Self 11 | 12 | import copy 13 | 14 | from rdkit import Chem 15 | 16 | from molpipeline.abstract_pipeline_elements.any2mol.string2mol import ( 17 | StringToMolPipelineElement as _StringToMolPipelineElement, 18 | ) 19 | from molpipeline.abstract_pipeline_elements.core import InvalidInstance 20 | 21 | if TYPE_CHECKING: 22 | from molpipeline.utils.molpipeline_types import OptionalMol 23 | 24 | 25 | class SDFToMol(_StringToMolPipelineElement): 26 | """PipelineElement transforming a list of SDF strings to mol_objects.""" 27 | 28 | identifier: str 29 | mol_counter: int 30 | 31 | def __init__( 32 | self, 33 | identifier: str = "enumerate", 34 | name: str = "SDF2Mol", 35 | n_jobs: int = 1, 36 | uuid: str | None = None, 37 | ) -> None: 38 | """Initialize SDFToMol. 39 | 40 | Parameters 41 | ---------- 42 | identifier: str, default='enumerate' 43 | Method of assigning identifiers to molecules. Per default, an increasing 44 | integer count is assigned to each molecule. If 'smiles' is chosen, the 45 | identifier is the SMILES representation of the molecule. 46 | name: str, default='SDF2Mol' 47 | Name of PipelineElement 48 | n_jobs: int, default=1 49 | Number of cores used for processing. 50 | uuid: str | None, optional 51 | uuid of PipelineElement, by default None 52 | 53 | """ 54 | super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) 55 | self.identifier = identifier 56 | self.mol_counter = 0 57 | 58 | def get_params(self, deep: bool = True) -> dict[str, Any]: 59 | """Return all parameters defining the object. 60 | 61 | Parameters 62 | ---------- 63 | deep: bool 64 | If True get a deep copy of the parameters. 65 | 66 | Returns 67 | ------- 68 | dict[str, Any] 69 | Dictionary containing all parameters defining the object. 70 | 71 | """ 72 | params = super().get_params(deep) 73 | if deep: 74 | params["identifier"] = copy.copy(self.identifier) 75 | else: 76 | params["identifier"] = self.identifier 77 | return params 78 | 79 | def set_params(self, **parameters: Any) -> Self: 80 | """Set parameters of the object. 81 | 82 | Parameters 83 | ---------- 84 | parameters: Any 85 | Dictionary containing all parameters defining the object. 86 | 87 | Returns 88 | ------- 89 | Self 90 | SDFToMol with updated parameters. 91 | 92 | """ 93 | super().set_params(**parameters) 94 | if "identifier" in parameters: 95 | self.identifier = parameters["identifier"] 96 | return self 97 | 98 | def finish(self) -> None: 99 | """Reset the mol counter which assigns identifiers.""" 100 | self.mol_counter = 0 101 | 102 | def pretransform_single(self, value: str) -> OptionalMol: 103 | """Transform an SDF-strings to a rdkit molecule. 104 | 105 | Parameters 106 | ---------- 107 | value: str 108 | SDF-string to transform to a molecule. 109 | 110 | Returns 111 | ------- 112 | OptionalMol 113 | Molecule if transformation was successful, else InvalidInstance. 
114 | 115 | """ 116 | if not isinstance(value, (str, bytes)): 117 | return InvalidInstance( 118 | self.uuid, 119 | "Invalid SDF string!", 120 | self.name, 121 | ) 122 | supplier = Chem.SDMolSupplier() 123 | supplier.SetData(value) 124 | mol = next(supplier, None) 125 | if mol is None: 126 | return InvalidInstance( 127 | self.uuid, 128 | "Invalid SDF string!", 129 | self.name, 130 | ) 131 | if self.identifier == "smiles": 132 | mol.SetProp("identifier", Chem.MolToSmiles(mol)) 133 | else: 134 | mol.SetProp("identifier", str(self.mol_counter)) 135 | self.mol_counter += 1 136 | return mol 137 | -------------------------------------------------------------------------------- /molpipeline/any2mol/smiles2mol.py: -------------------------------------------------------------------------------- 1 | """Classes meant to transform given input to a RDKit molecule.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | try: 8 | from typing import Self # type: ignore[attr-defined] 9 | except ImportError: 10 | from typing_extensions import Self 11 | 12 | from rdkit import Chem 13 | 14 | from molpipeline.abstract_pipeline_elements.any2mol.string2mol import ( 15 | SimpleStringToMolElement, 16 | ) 17 | from molpipeline.utils.molpipeline_types import RDKitMol 18 | 19 | 20 | class SmilesToMol(SimpleStringToMolElement): 21 | """Transforms Smiles to RDKit Mol objects.""" 22 | 23 | def __init__( 24 | self, 25 | remove_hydrogens: bool = True, 26 | name: str = "smiles2mol", 27 | n_jobs: int = 1, 28 | uuid: str | None = None, 29 | ) -> None: 30 | """Initialize SmilesToMol object. 31 | 32 | Parameters 33 | ---------- 34 | remove_hydrogens: bool, default=True 35 | Whether to remove hydrogens from the molecule. 36 | name: str, default='smiles2mol' 37 | Name of the object. 38 | n_jobs: int, default=1 39 | Number of jobs to run in parallel. 40 | uuid: str | None, optional 41 | UUID of the object. 42 | """ 43 | super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) 44 | self._remove_hydrogens = remove_hydrogens 45 | 46 | def _get_parser_config(self) -> Chem.SmilesParserParams: 47 | """Get parser configuration. 48 | 49 | Returns 50 | ------- 51 | Chem.SmilesParserParams 52 | RDKit SMILES parser configuration. 53 | """ 54 | # set up rdkit smiles parser parameters 55 | parser_params = Chem.SmilesParserParams() 56 | parser_params.removeHs = self._remove_hydrogens # type: ignore[assignment] 57 | return parser_params 58 | 59 | def string_to_mol(self, value: str) -> RDKitMol: 60 | """Transform Smiles string to molecule. 61 | 62 | Parameters 63 | ---------- 64 | value: str 65 | SMILES string. 66 | 67 | Returns 68 | ------- 69 | RDKitMol 70 | Rdkit molecule if valid SMILES, else None. 71 | """ 72 | return Chem.MolFromSmiles(value, self._get_parser_config()) 73 | 74 | def get_params(self, deep: bool = True) -> dict[str, Any]: 75 | """Get parameters for this object. 76 | 77 | Parameters 78 | ---------- 79 | deep: bool 80 | If True, return a deep copy of the parameters. 81 | 82 | Returns 83 | ------- 84 | dict[str, Any] 85 | Dictionary of parameters. 86 | """ 87 | parameters = super().get_params(deep) 88 | if deep: 89 | parameters["remove_hydrogens"] = bool(self._remove_hydrogens) 90 | 91 | else: 92 | parameters["remove_hydrogens"] = self._remove_hydrogens 93 | return parameters 94 | 95 | def set_params(self, **parameters: Any) -> Self: 96 | """Set parameters. 97 | 98 | Parameters 99 | ---------- 100 | parameters: Any 101 | Dictionary of parameter names and values. 102 | 103 | Returns 104 | ------- 105 | Self 106 | SmilesToMol pipeline element with updated parameters.
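A short sketch of the scikit-learn-style parameter round trip that `get_params`/`set_params` enable on `SmilesToMol`:

```python
from molpipeline.any2mol import SmilesToMol

element = SmilesToMol(remove_hydrogens=True)
assert element.get_params()["remove_hydrogens"] is True

element.set_params(remove_hydrogens=False)  # explicit hydrogens are now kept
```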
107 | """ 108 | parameter_copy = dict(parameters) 109 | remove_hydrogens = parameter_copy.pop("remove_hydrogens", None) 110 | if remove_hydrogens is not None: 111 | self._remove_hydrogens = remove_hydrogens 112 | super().set_params(**parameter_copy) 113 | return self 114 | -------------------------------------------------------------------------------- /molpipeline/estimators/__init__.py: -------------------------------------------------------------------------------- 1 | """Init file for estimators.""" 2 | 3 | from molpipeline.estimators.connected_component_clustering import ( 4 | ConnectedComponentClustering, 5 | ) 6 | from molpipeline.estimators.leader_picker_clustering import LeaderPickerClustering 7 | from molpipeline.estimators.murcko_scaffold_clustering import MurckoScaffoldClustering 8 | from molpipeline.estimators.nearest_neighbor import NamedNearestNeighbors 9 | from molpipeline.estimators.similarity_transformation import TanimotoToTraining 10 | 11 | __all__ = [ 12 | "ConnectedComponentClustering", 13 | "LeaderPickerClustering", 14 | "MurckoScaffoldClustering", 15 | "NamedNearestNeighbors", 16 | "TanimotoToTraining", 17 | ] 18 | -------------------------------------------------------------------------------- /molpipeline/estimators/algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | """Estimator algorithms.""" 2 | -------------------------------------------------------------------------------- /molpipeline/estimators/algorithm/connected_component_clustering.py: -------------------------------------------------------------------------------- 1 | """Connected component clustering algorithm.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | from scipy.sparse import csr_matrix 8 | 9 | from molpipeline.estimators.algorithm.union_find import UnionFindNode 10 | from molpipeline.utils.kernel import tanimoto_similarity_sparse 11 | 12 | 13 | def calc_chunk_size_from_memory_requirement( 14 | nof_rows: int, nof_cols: int, itemsize: int, memory_cutoff: float 15 | ) -> int: 16 | """Calculate the chunk size from the memory requirement. 17 | 18 | Parameters 19 | ---------- 20 | nof_rows: int 21 | Number of rows of the matrix. 22 | nof_cols: int 23 | Number of columns of the matrix. 24 | itemsize: int 25 | Itemsize of the matrix. 26 | memory_cutoff: float 27 | Memory cutoff in GB. 28 | 29 | Returns 30 | ------- 31 | int 32 | Chunk size in number of rows. 33 | """ 34 | memory_cutoff_byte: float = memory_cutoff * 1024**3 35 | # get memory requirement of dense matrix 36 | row_memory: int = nof_cols * itemsize 37 | # get allowed rows 38 | allowed_rows = max(int(memory_cutoff_byte / row_memory), 1) 39 | allowed_rows = min(allowed_rows, nof_rows) 40 | return allowed_rows 41 | 42 | 43 | def connected_components_iterative_algorithm( 44 | feature_mat: csr_matrix, similarity_threshold: float, chunk_size: int = 5000 45 | ) -> tuple[int, npt.NDArray[np.int32]]: 46 | """Compute connected component clustering iteratively. 47 | 48 | This algorithm is suited for large data sets since the complete similarity matrix is not stored in memory at once. 49 | 50 | Parameters 51 | ---------- 52 | feature_mat: csr_matrix 53 | Feature matrix from which to calculate row-wise similarities. 54 | similarity_threshold: float 55 | Similarity threshold used to determine edges in the graph representation. 56 | chunk_size: int 57 | Number of rows for which similarities are determined at one iteration of the algorithm. 
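A worked example of the chunk-size computation above (numbers are illustrative): with a 4 GB budget, 2048 columns, and 8-byte items, 4 * 1024**3 / (2048 * 8) = 262144 rows fit, which is then capped at the number of samples:

```python
from molpipeline.estimators.algorithm.connected_component_clustering import (
    calc_chunk_size_from_memory_requirement,
)

chunk_size = calc_chunk_size_from_memory_requirement(
    nof_rows=100_000, nof_cols=2_048, itemsize=8, memory_cutoff=4.0
)
assert chunk_size == 100_000  # budget allows 262144 rows, capped at 100000
```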
58 | 59 | Returns 60 | ------- 61 | tuple[int, np.ndarray[int]] 62 | Number of clusters and cluster labels. 63 | """ 64 | nof_samples = feature_mat.shape[0] 65 | uf_nodes = [UnionFindNode() for _ in range(nof_samples)] 66 | 67 | for i in range(0, nof_samples, chunk_size): 68 | mat_chunk = feature_mat[i : i + chunk_size, :] 69 | 70 | similarity_mat_chunk = tanimoto_similarity_sparse(mat_chunk, feature_mat) 71 | 72 | indices = np.transpose( 73 | np.asarray(similarity_mat_chunk >= similarity_threshold).nonzero() 74 | ) 75 | 76 | for i_idx, j_idx in indices: 77 | if i + i_idx >= j_idx: 78 | continue 79 | uf_nodes[j_idx].union(uf_nodes[i + i_idx]) 80 | 81 | return UnionFindNode.get_connected_components(uf_nodes) 82 | -------------------------------------------------------------------------------- /molpipeline/estimators/algorithm/union_find.py: -------------------------------------------------------------------------------- 1 | """Union find algorithm.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | 8 | 9 | class UnionFindNode: 10 | """Union find node. 11 | 12 | A UnionFindNode is a node in a union find data structure, also called disjoint-set data structure. 13 | It stores a collection of non-overlapping sets and provides operations for merging these sets. 14 | It can be used to determine connected components in a graph. 15 | """ 16 | 17 | def __init__(self) -> None: 18 | """Initialize union find node.""" 19 | # initially, each node is its own parent 20 | self.parent = self 21 | # connected component number for each node. Needed at the end when connected components are getting extracted. 22 | self.connected_component_number = -1 23 | 24 | def find(self) -> UnionFindNode: 25 | """Find the root node. 26 | 27 | Returns 28 | ------- 29 | Self 30 | Root node. 31 | """ 32 | if self.parent is not self: 33 | self.parent = self.parent.find() 34 | return self.parent 35 | 36 | def union(self, other: UnionFindNode) -> UnionFindNode: 37 | """Union two nodes. 38 | 39 | Parameters 40 | ---------- 41 | other: Self 42 | Other node. 43 | 44 | Returns 45 | ------- 46 | UnionFindNode 47 | Root node of set or self. 48 | """ 49 | # get the root nodes of the connected components of both nodes. 50 | # Additionally, compress the paths to the root nodes by setting others' parent to the root node. 51 | if self is self.parent: 52 | # this node is a parent. Meaning the root of the connected component. Let's overwrite it. 53 | self.parent = other.find() 54 | elif self is other: 55 | # this node is the other node. They are identical and therefore already in the same set. 56 | return self 57 | else: 58 | # add other node to this node's parent (and in addition to this connected component) 59 | self.parent = self.parent.union(other) 60 | return self.parent 61 | 62 | @staticmethod 63 | def get_connected_components( 64 | union_find_nodes: list[UnionFindNode], 65 | ) -> tuple[int, npt.NDArray[np.int32]]: 66 | """Get connected components from a union find node list. 67 | 68 | Parameters 69 | ---------- 70 | union_find_nodes: list[UnionFindNode] 71 | List of union find nodes. 72 | 73 | Returns 74 | ------- 75 | tuple[int, np.ndarray[int]] 76 | Number of connected components and connected component labels. 
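A tiny worked example of the union-find API defined above (labels follow first-visit order):

```python
from molpipeline.estimators.algorithm.union_find import UnionFindNode

nodes = [UnionFindNode() for _ in range(4)]
nodes[0].union(nodes[1])  # merge {0, 1}
nodes[2].union(nodes[3])  # merge {2, 3}

n_components, labels = UnionFindNode.get_connected_components(nodes)
assert n_components == 2
assert list(labels) == [0, 0, 1, 1]
```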
77 | """ 78 | # results 79 | connected_components_counter = 0 80 | connected_components_array = np.empty(len(union_find_nodes), dtype=np.int32) 81 | 82 | for i, node in enumerate(union_find_nodes): 83 | root_parent = node.find() 84 | if root_parent.connected_component_number == -1: 85 | # found root node of a connected component. Annotate it with a connected component number. 86 | root_parent.connected_component_number = connected_components_counter 87 | connected_components_counter += 1 88 | connected_components_array[i] = root_parent.connected_component_number 89 | return connected_components_counter, connected_components_array 90 | -------------------------------------------------------------------------------- /molpipeline/estimators/chemprop/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize Chemprop module.""" 2 | 3 | try: 4 | from molpipeline.estimators.chemprop.models import ( # noqa: F401 5 | ChempropClassifier, 6 | ChempropModel, 7 | ChempropNeuralFP, 8 | ChempropRegressor, 9 | ) 10 | 11 | __all__ = [ 12 | "ChempropClassifier", 13 | "ChempropModel", 14 | "ChempropNeuralFP", 15 | "ChempropRegressor", 16 | ] 17 | except ImportError: 18 | __all__ = [] 19 | -------------------------------------------------------------------------------- /molpipeline/estimators/chemprop/featurizer_wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | """Wrapper for Chemprop Featurizer.""" 2 | -------------------------------------------------------------------------------- /molpipeline/estimators/chemprop/featurizer_wrapper/graph_wrapper.py: -------------------------------------------------------------------------------- 1 | """Wrapper for Chemprop GraphFeaturizer.""" 2 | 3 | from dataclasses import InitVar 4 | from typing import Any 5 | 6 | try: 7 | from typing import Self # type: ignore[attr-defined] 8 | except ImportError: 9 | from typing_extensions import Self 10 | 11 | from chemprop.featurizers.molgraph import ( 12 | SimpleMoleculeMolGraphFeaturizer as _SimpleMoleculeMolGraphFeaturizer, 13 | ) 14 | 15 | 16 | class SimpleMoleculeMolGraphFeaturizer(_SimpleMoleculeMolGraphFeaturizer): 17 | """Wrapper for Chemprop SimpleMoleculeMolGraphFeaturizer.""" 18 | 19 | extra_atom_fdim: InitVar[int] 20 | extra_bond_fdim: InitVar[int] 21 | 22 | def get_params( 23 | self, 24 | deep: bool = True, # pylint: disable=unused-argument 25 | ) -> dict[str, InitVar[int]]: 26 | """Get parameters for the featurizer. 27 | 28 | Parameters 29 | ---------- 30 | deep: bool, optional (default=True) 31 | Used for compatibility with scikit-learn. 32 | 33 | Returns 34 | ------- 35 | dict[str, int] 36 | Parameters of the featurizer. 37 | """ 38 | return {} 39 | 40 | def set_params(self, **parameters: Any) -> Self: # pylint: disable=unused-argument 41 | """Set the parameters of the featurizer. 42 | 43 | Parameters 44 | ---------- 45 | parameters: Any 46 | Parameters to set. Only used for compatibility with scikit-learn. 47 | 48 | Returns 49 | ------- 50 | Self 51 | This featurizer with the parameters set. 
52 | """ 53 | return self 54 | -------------------------------------------------------------------------------- /molpipeline/estimators/chemprop/loss_wrapper.py: -------------------------------------------------------------------------------- 1 | """Wrapper for Chemprop loss functions.""" 2 | 3 | from typing import Any 4 | 5 | import torch 6 | from chemprop.nn.loss import BCELoss as _BCELoss 7 | from chemprop.nn.loss import BinaryDirichletLoss as _BinaryDirichletLoss 8 | from chemprop.nn.loss import CrossEntropyLoss as _CrossEntropyLoss 9 | from chemprop.nn.loss import EvidentialLoss as _EvidentialLoss 10 | from chemprop.nn.loss import LossFunction as _LossFunction 11 | from chemprop.nn.loss import MSELoss as _MSELoss 12 | from chemprop.nn.loss import MulticlassDirichletLoss as _MulticlassDirichletLoss 13 | from chemprop.nn.loss import MVELoss as _MVELoss 14 | from chemprop.nn.loss import SIDLoss as _SIDLoss 15 | from numpy.typing import ArrayLike 16 | 17 | 18 | class LossFunctionParamMixin: 19 | """Mixin for loss functions to get and set parameters.""" 20 | 21 | _original_task_weights: ArrayLike 22 | 23 | def __init__(self: _LossFunction, task_weights: ArrayLike) -> None: 24 | """Initialize the loss function. 25 | 26 | Parameters 27 | ---------- 28 | task_weights : ArrayLike 29 | The weights for each task. 30 | 31 | """ 32 | super().__init__(task_weights=task_weights) # type: ignore 33 | self._original_task_weights = task_weights 34 | 35 | # pylint: disable=unused-argument 36 | def get_params(self: _LossFunction, deep: bool = True) -> dict[str, Any]: 37 | """Get the parameters of the loss function. 38 | 39 | Parameters 40 | ---------- 41 | deep : bool, optional 42 | Not used, only present to match the sklearn API. 43 | 44 | Returns 45 | ------- 46 | dict[str, Any] 47 | The parameters of the loss function. 48 | """ 49 | return {"task_weights": self._original_task_weights} 50 | 51 | def set_params(self: _LossFunction, **params: Any) -> _LossFunction: 52 | """Set the parameters of the loss function. 53 | 54 | Parameters 55 | ---------- 56 | **params : Any 57 | The parameters to set. 58 | 59 | Returns 60 | ------- 61 | Self 62 | The loss function with the new parameters. 
63 | """ 64 | task_weights = params.pop("task_weights", None) 65 | if task_weights is not None: 66 | self._original_task_weights = task_weights 67 | state_dict = self.state_dict() 68 | state_dict["task_weights"] = torch.as_tensor( 69 | task_weights, dtype=torch.float 70 | ).view(1, -1) 71 | self.load_state_dict(state_dict) 72 | return self 73 | 74 | 75 | class BCELoss(LossFunctionParamMixin, _BCELoss): 76 | """Binary cross-entropy loss function.""" 77 | 78 | 79 | class BinaryDirichletLoss(LossFunctionParamMixin, _BinaryDirichletLoss): 80 | """Binary Dirichlet loss function.""" 81 | 82 | 83 | class CrossEntropyLoss(LossFunctionParamMixin, _CrossEntropyLoss): 84 | """Cross-entropy loss function.""" 85 | 86 | 87 | class EvidentialLoss(LossFunctionParamMixin, _EvidentialLoss): 88 | """Evidential loss function.""" 89 | 90 | 91 | class MSELoss(LossFunctionParamMixin, _MSELoss): 92 | """Mean squared error loss function.""" 93 | 94 | 95 | class MulticlassDirichletLoss(LossFunctionParamMixin, _MulticlassDirichletLoss): 96 | """Multiclass Dirichlet loss function.""" 97 | 98 | 99 | class MVELoss(LossFunctionParamMixin, _MVELoss): 100 | """Mean value entropy loss function.""" 101 | 102 | 103 | class SIDLoss(LossFunctionParamMixin, _SIDLoss): 104 | """SID loss function.""" 105 | -------------------------------------------------------------------------------- /molpipeline/estimators/chemprop/neural_fingerprint.py: -------------------------------------------------------------------------------- 1 | """Wrap Chemprop in a sklearn like transformer returning the neural fingerprint as a numpy array.""" 2 | 3 | from collections.abc import Sequence 4 | from typing import Any, Self 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | from chemprop.data import BatchMolGraph, MoleculeDataset 9 | from chemprop.models.model import MPNN 10 | from lightning import pytorch as pl 11 | 12 | from molpipeline.estimators.chemprop.abstract import ABCChemprop 13 | 14 | 15 | class ChempropNeuralFP(ABCChemprop): 16 | """Wrap Chemprop in a sklearn like transformer returning the neural fingerprint as a numpy array. 17 | 18 | This class is not a (grand-) child of MolToAnyPipelineElement, as it does not support the `pretransform_single` 19 | method. To maintain compatibility with the MolToAnyPipelineElement, the `output_type` property is implemented. 20 | It can be used as any other transformer in the pipeline, except in the `MolToConcatenatedVector`. 21 | """ 22 | 23 | @property 24 | def output_type(self) -> str: 25 | """Return the output type of the transformer.""" 26 | return "float" 27 | 28 | def __init__( 29 | self, 30 | model: MPNN, 31 | lightning_trainer: pl.Trainer | None = None, 32 | batch_size: int = 64, 33 | n_jobs: int = 1, 34 | disable_fitting: bool = False, 35 | **kwargs: Any, 36 | ) -> None: 37 | """Initialize the chemprop neural fingerprint model. 38 | 39 | Parameters 40 | ---------- 41 | model : MPNN 42 | The chemprop model to wrap. 43 | lightning_trainer : pl.Trainer, optional 44 | The lightning trainer to use, by default None 45 | batch_size : int, optional (default=64) 46 | The batch size to use. 47 | n_jobs : int, optional (default=1) 48 | The number of jobs to use. 49 | disable_fitting : bool, optional (default=False) 50 | Whether to allow fitting or set to fixed encoding. 51 | **kwargs: Any 52 | Parameters for components of the model. 
53 | """ 54 | # pylint: disable=duplicate-code 55 | self.disable_fitting = disable_fitting 56 | super().__init__( 57 | model=model, 58 | lightning_trainer=lightning_trainer, 59 | batch_size=batch_size, 60 | n_jobs=n_jobs, 61 | **kwargs, 62 | ) 63 | 64 | def fit( 65 | self, 66 | X: MoleculeDataset, 67 | y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64], 68 | ) -> Self: 69 | """Fit the model. 70 | 71 | Parameters 72 | ---------- 73 | X : MoleculeDataset 74 | The input data. 75 | y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64] 76 | The target data. 77 | 78 | Returns 79 | ------- 80 | Self 81 | The fitted model. 82 | """ 83 | if self.disable_fitting: 84 | return self 85 | return super().fit(X, y) 86 | 87 | def transform( 88 | self, 89 | X: MoleculeDataset, # pylint: disable=invalid-name 90 | ) -> npt.NDArray[np.float64]: 91 | """Transform the input. 92 | 93 | Parameters 94 | ---------- 95 | X : MoleculeDataset 96 | The input data. 97 | 98 | Returns 99 | ------- 100 | npt.NDArray[np.float64] 101 | The neural fingerprint of the input data. 102 | """ 103 | self.model.eval() 104 | mol_data = [X[i].mg for i in range(len(X))] 105 | return self.model.fingerprint(BatchMolGraph(mol_data)).detach().numpy() 106 | 107 | def fit_transform( 108 | self, 109 | X: MoleculeDataset, # pylint: disable=invalid-name 110 | y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64], 111 | ) -> npt.NDArray[np.float64]: 112 | """Fit the model and transform the input. 113 | 114 | Parameters 115 | ---------- 116 | X : MoleculeDataset 117 | The input data. 118 | y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64] 119 | The target data. 120 | 121 | Returns 122 | ------- 123 | npt.NDArray[np.float64] 124 | The neural fingerprint of the input data. 125 | """ 126 | self.fit(X, y) 127 | return self.transform(X) 128 | -------------------------------------------------------------------------------- /molpipeline/estimators/connected_component_clustering.py: -------------------------------------------------------------------------------- 1 | """Connected component clustering estimator.""" 2 | 3 | from __future__ import annotations 4 | 5 | from numbers import Real 6 | from typing import Any 7 | 8 | import numpy as np 9 | import numpy.typing as npt 10 | from scipy import sparse 11 | from scipy.sparse import csr_matrix 12 | from sklearn.base import BaseEstimator, ClusterMixin, _fit_context 13 | from sklearn.utils._param_validation import Interval 14 | from sklearn.utils.validation import validate_data 15 | 16 | try: 17 | from typing import Self 18 | except ImportError: 19 | from typing_extensions import Self 20 | 21 | from molpipeline.estimators.algorithm.connected_component_clustering import ( 22 | calc_chunk_size_from_memory_requirement, 23 | connected_components_iterative_algorithm, 24 | ) 25 | from molpipeline.utils.kernel import tanimoto_similarity_sparse 26 | 27 | 28 | class ConnectedComponentClustering(ClusterMixin, BaseEstimator): 29 | """Connected component clustering estimator.""" 30 | 31 | _parameter_constraints: dict[str, Any] = { 32 | "distance_threshold": [Interval(Real, 0, None, closed="left")], 33 | } 34 | 35 | def __init__( 36 | self, 37 | distance_threshold: float, 38 | *, 39 | max_memory_usage: float = 4.0, 40 | ) -> None: 41 | """Initialize connected component clustering estimator. 42 | 43 | Parameters 44 | ---------- 45 | distance_threshold : float 46 | Distance threshold for connected component clustering. 
47 | max_memory_usage : float, optional 48 | Maximum memory usage in GB, by default 4.0 GB 49 | """ 50 | self.distance_threshold: float = distance_threshold 51 | self.max_memory_usage: float = max_memory_usage 52 | self.n_clusters_: int | None = None 53 | self.labels_: npt.NDArray[np.int32] | None = None 54 | 55 | # pylint: disable=C0103,W0613 56 | @_fit_context(prefer_skip_nested_validation=True) 57 | def fit( 58 | self, 59 | X: npt.NDArray[np.float64] | csr_matrix, 60 | y: npt.NDArray[np.float64] | None = None, 61 | ) -> Self: 62 | """Fit connected component clustering estimator. 63 | 64 | Parameters 65 | ---------- 66 | X : array-like of shape (n_samples, n_features) 67 | Feature matrix. 68 | y : Ignored 69 | Not used, present for API consistency by convention. 70 | 71 | Returns 72 | ------- 73 | Self 74 | Fitted estimator. 75 | """ 76 | X = validate_data(self, X=X, ensure_min_samples=2, accept_sparse=True) 77 | return self._fit(X) 78 | 79 | # pylint: disable=C0103,W0613 80 | def _fit(self, X: npt.NDArray[np.float64] | csr_matrix) -> Self: 81 | """Fit connected component clustering estimator. 82 | 83 | Parameters 84 | ---------- 85 | X : array-like of shape (n_samples, n_features) 86 | Feature matrix. 87 | 88 | Returns 89 | ------- 90 | Self 91 | Fitted estimator. 92 | """ 93 | # convert tanimoto distance to similarity 94 | similarity_threshold: float = 1 - self.distance_threshold 95 | 96 | # get row chunk size based on 2D dense distance matrix that will be generated 97 | row_chunk_size = calc_chunk_size_from_memory_requirement( 98 | X.shape[0] * 2 99 | + 2, # the self_tanimoto_distance needs two matrices of X.shape and two additional rows. 100 | X.shape[0], 101 | np.dtype("float64").itemsize, 102 | self.max_memory_usage, 103 | ) 104 | 105 | if row_chunk_size >= X.shape[0]: 106 | similarity_matrix = tanimoto_similarity_sparse(X, X) 107 | adjacency_matrix = (similarity_matrix >= similarity_threshold).astype( 108 | np.int8 109 | ) 110 | self.n_clusters_, self.labels_ = sparse.csgraph.connected_components( 111 | adjacency_matrix, directed=False, return_labels=True 112 | ) 113 | else: 114 | self.n_clusters_, self.labels_ = connected_components_iterative_algorithm( 115 | X, similarity_threshold, row_chunk_size 116 | ) 117 | return self 118 | 119 | def fit_predict( 120 | self, 121 | X: npt.NDArray[np.float64] | csr_matrix, # pylint: disable=C0103 122 | y: npt.NDArray[np.float64] | None = None, 123 | **kwargs: Any, 124 | ) -> npt.NDArray[np.int32]: 125 | """Fit and predict connected component clustering estimator. 126 | 127 | Parameters 128 | ---------- 129 | X: npt.NDArray[np.float64] | csr_matrix 130 | Feature matrix of shape (n_samples, n_features). 131 | y: Ignored 132 | Not used, present for API consistency by convention. 133 | kwargs: Any 134 | Additional keyword arguments. 135 | 136 | Returns 137 | ------- 138 | np.ndarray[int] 139 | Cluster labels. 
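A small usage sketch with hand-checkable numbers — rows 0 and 1 share a Tanimoto similarity of 2/3, row 2 is disjoint:

```python
import numpy as np
from scipy.sparse import csr_matrix

from molpipeline.estimators import ConnectedComponentClustering

fingerprints = csr_matrix(
    np.array(
        [
            [1, 1, 0, 0],  # Tanimoto 2/3 with the next row
            [1, 1, 1, 0],
            [0, 0, 0, 1],  # no overlap with the rows above
        ],
        dtype=np.float64,
    )
)

ccc = ConnectedComponentClustering(distance_threshold=0.5)
labels = ccc.fit_predict(fingerprints)
# rows 0 and 1 form one cluster, row 2 its own -> two clusters
```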
140 | """ 141 | # pylint: disable=W0246 142 | return super().fit_predict(X, y, **kwargs) 143 | -------------------------------------------------------------------------------- /molpipeline/estimators/leader_picker_clustering.py: -------------------------------------------------------------------------------- 1 | """LeaderPicker-based clustering estimator.""" 2 | 3 | from __future__ import annotations 4 | 5 | from itertools import compress 6 | from numbers import Real 7 | 8 | import numpy as np 9 | import numpy.typing as npt 10 | from rdkit import DataStructs 11 | from rdkit.DataStructs import ExplicitBitVect 12 | from rdkit.SimDivFilters import rdSimDivPickers 13 | from sklearn.base import BaseEstimator, ClusterMixin, _fit_context 14 | from sklearn.utils._param_validation import Interval 15 | 16 | try: 17 | from collections.abc import Sequence 18 | from typing import Any, Self 19 | except ImportError: 20 | from typing_extensions import Self 21 | 22 | 23 | class LeaderPickerClustering(ClusterMixin, BaseEstimator): 24 | """LeaderPicker clustering estimator (a sphere exclusion clustering algorithm).""" 25 | 26 | # we use sklearn's input validation to check constraints 27 | _parameter_constraints: dict[str, Any] = { 28 | "distance_threshold": [Interval(Real, 0, 1.0, closed="left")], 29 | } 30 | 31 | def __init__( 32 | self, 33 | distance_threshold: float, 34 | ) -> None: 35 | """Initialize LeaderPicker clustering estimator. 36 | 37 | Parameters 38 | ---------- 39 | distance_threshold : float 40 | Minimum distance between cluster centroids. 41 | """ 42 | self.distance_threshold: float = distance_threshold 43 | self.n_clusters_: int | None = None 44 | self.labels_: npt.NDArray[np.int32] | None = None 45 | # centroid indices 46 | self.centroids_: npt.NDArray[np.int32] | None = None 47 | 48 | # pylint: disable=C0103,W0613 49 | @_fit_context(prefer_skip_nested_validation=True) 50 | def fit( 51 | self, 52 | X: list[ExplicitBitVect], 53 | y: npt.NDArray[np.float64] | None = None, 54 | ) -> Self: 55 | """Fit leader picker clustering estimator. 56 | 57 | Parameters 58 | ---------- 59 | X : array-like of shape (n_samples, n_features) 60 | Feature matrix. 61 | y : Ignored 62 | Not used, present for API consistency by convention. 63 | 64 | Returns 65 | ------- 66 | Self 67 | Fitted estimator. 68 | """ 69 | return self._fit(X) 70 | 71 | @staticmethod 72 | def _assign_points_to_clusters_based_on_centroid( 73 | picks: Sequence[int], fps: Sequence[ExplicitBitVect] 74 | ) -> tuple[int, npt.NDArray[np.int32]]: 75 | """Assign points to clusters based on centroid. 76 | 77 | Based on https://rdkit.blogspot.com/2020/11/sphere-exclusion-clustering-with-rdkit.html 78 | 79 | Parameters 80 | ---------- 81 | picks : Sequence[int] 82 | Indices of selected cluster centroids to which the remaining data will be assigned. 83 | fps : Sequence[ExplicitBitVect] 84 | Fingerprints of the whole data sets. 85 | 86 | Returns 87 | ------- 88 | tuple[int, np.ndarray[int]] 89 | Number of clusters and cluster labels. 
90 | """ 91 | labels: npt.NDArray[np.int32] = np.full(len(fps), -1, dtype=np.int32) 92 | max_similarities = np.full(len(fps), -np.inf, dtype=np.float64) 93 | 94 | for i, pick_idx in enumerate(picks): 95 | similarities = DataStructs.BulkTanimotoSimilarity(fps[pick_idx], fps) 96 | max_mask = similarities > max_similarities 97 | labels[max_mask] = i 98 | max_similarities[max_mask] = list(compress(similarities, max_mask)) 99 | 100 | return np.unique(labels).shape[0], labels 101 | 102 | # pylint: disable=C0103,W0613 103 | def _fit(self, X: list[ExplicitBitVect]) -> Self: 104 | """Fit leader picker clustering estimator. 105 | 106 | Parameters 107 | ---------- 108 | X : array-like of shape (n_samples, n_features) 109 | Feature matrix. 110 | 111 | Returns 112 | ------- 113 | Self 114 | Fitted estimator. 115 | """ 116 | lp = rdSimDivPickers.LeaderPicker() 117 | 118 | # Select centroids. This part is in C++ and fast 119 | picks = lp.LazyBitVectorPick( 120 | objects=X, 121 | poolSize=len(X), 122 | threshold=self.distance_threshold, 123 | numThreads=1, # according to rdkit docu this parameter is not used 124 | # seed=self.random_state if self.random_state is not None else -1, 125 | ) 126 | 127 | # Assign points to clusters based on centroid 128 | ( 129 | self.n_clusters_, 130 | self.labels_, 131 | ) = self._assign_points_to_clusters_based_on_centroid(picks, X) 132 | 133 | self.centroids_ = np.array(picks) 134 | return self 135 | 136 | def fit_predict( 137 | self, 138 | X: list[ExplicitBitVect], # pylint: disable=C0103 139 | y: npt.NDArray[np.float64] | None = None, 140 | **kwargs: Any, 141 | ) -> npt.NDArray[np.int32]: 142 | """Fit and predict leader picker clustering estimator. 143 | 144 | Parameters 145 | ---------- 146 | X: npt.NDArray[np.float64] | csr_matrix 147 | Feature matrix of shape (n_samples, n_features). 148 | y: Ignored 149 | Not used, present for API consistency by convention. 150 | kwargs: Any 151 | Additional keyword arguments. 152 | 153 | Returns 154 | ------- 155 | np.ndarray[int] 156 | Cluster labels. 157 | """ 158 | # pylint: disable=W0246 159 | return super().fit_predict(X, y, **kwargs) 160 | -------------------------------------------------------------------------------- /molpipeline/estimators/similarity_transformation.py: -------------------------------------------------------------------------------- 1 | """Sklearn estimators for computing similarity and distance matrices.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | try: 8 | from typing import Self 9 | except ImportError: 10 | from typing_extensions import Self 11 | 12 | import numpy as np 13 | import numpy.typing as npt 14 | from scipy.sparse import csr_matrix 15 | from sklearn.base import BaseEstimator, TransformerMixin 16 | 17 | from molpipeline.utils.kernel import tanimoto_similarity_sparse 18 | 19 | 20 | class TanimotoToTraining(BaseEstimator, TransformerMixin): 21 | """Transformer for computing tanimoto similarity matrices to data seen during training. 22 | 23 | Can also be used to compute distance matrices. 24 | 25 | Attributes 26 | ---------- 27 | training_matrix: npt.NDArray[np.float64] | csr_matrix | None 28 | Features seen during fit. 29 | """ 30 | 31 | training_matrix: npt.NDArray[np.float64] | csr_matrix | None 32 | 33 | def __init__(self, distance: bool = False) -> None: 34 | """Initialize TanimotoSimilarityToTraining. 
35 | 36 | Parameters 37 | ---------- 38 | distance: bool, optional 39 | If True, the distance matrix is computed, by default False 40 | The distance matrix is computed as 1 - similarity_matrix. 41 | """ 42 | self.training_matrix = None 43 | self.distance = distance 44 | 45 | def _sim( 46 | self, 47 | matrix_a: npt.NDArray[np.float64] | csr_matrix, 48 | matrix_b: npt.NDArray[np.float64] | csr_matrix, 49 | ) -> npt.NDArray[np.float64]: 50 | """Compute the similarity matrix. 51 | 52 | Parameters 53 | ---------- 54 | matrix_a : npt.NDArray[np.float64] | csr_matrix 55 | First matrix. 56 | matrix_b : npt.NDArray[np.float64] | csr_matrix 57 | Second matrix. 58 | 59 | Returns 60 | ------- 61 | npt.NDArray[np.float64] 62 | Similarity matrix. If distance is True, the distance matrix is computed instead. 63 | """ 64 | if not isinstance(matrix_a, csr_matrix): 65 | matrix_a = csr_matrix(matrix_a) 66 | if not isinstance(matrix_b, csr_matrix): 67 | matrix_b = csr_matrix(matrix_b) 68 | if self.distance: 69 | return 1 - tanimoto_similarity_sparse(matrix_a, matrix_b) # type: ignore 70 | return tanimoto_similarity_sparse(matrix_a, matrix_b) 71 | 72 | def fit( 73 | self, 74 | X: npt.NDArray[np.float64] | csr_matrix, # pylint: disable=invalid-name 75 | y: npt.NDArray[np.float64] | None = None, # pylint: disable=unused-argument 76 | ) -> Self: 77 | """Fit the model. 78 | 79 | Parameters 80 | ---------- 81 | X : npt.NDArray[np.float64] | csr_matrix 82 | Feature matrix to which the similarity matrix is computed. 83 | y : npt.NDArray[np.float64] | None, optional 84 | Labels, by default None and never used 85 | 86 | Returns 87 | ------- 88 | Self 89 | Fitted model. 90 | """ 91 | self.training_matrix = X 92 | return self 93 | 94 | def transform( 95 | self, 96 | X: npt.NDArray[np.float64] | csr_matrix, # pylint: disable=invalid-name 97 | ) -> npt.NDArray[np.float64]: 98 | """Transform the data. 99 | 100 | Parameters 101 | ---------- 102 | X : npt.NDArray[np.float64] | csr_matrix 103 | Feature matrix to which the similarity matrix is computed. 104 | 105 | Raises 106 | ------ 107 | ValueError 108 | If the transformer has not been fitted yet. 109 | 110 | Returns 111 | ------- 112 | npt.NDArray[np.float64] 113 | Similarity matrix of X to the training matrix. 114 | 115 | """ 116 | if self.training_matrix is None: 117 | raise ValueError("Please fit the transformer before transforming!") 118 | return self._sim(X, self.training_matrix) 119 | 120 | def fit_transform( 121 | self, 122 | X: npt.NDArray[np.float64] | csr_matrix, 123 | y: npt.NDArray[np.float64] | None = None, 124 | **fit_params: Any, 125 | ) -> npt.NDArray[np.float64]: 126 | """Fit the model and transform the data. 127 | 128 | Parameters 129 | ---------- 130 | X: npt.NDArray[np.float64] | csr_matrix 131 | Feature matrix to fit the model. Is returned as similarity matrix to itself. 132 | y: npt.NDArray[np.float64] | None, optional 133 | Labels, by default None and never used 134 | **fit_params: Any 135 | Additional fit parameters. Ignored. 136 | 137 | Returns 138 | ------- 139 | npt.NDArray[np.float64] 140 | Similarity matrix of X to itself. 
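        Examples
        --------
        Illustrative sketch (assuming two binary feature rows; the similarity
        of a matrix to itself is returned):

        >>> import numpy as np
        >>> X = np.array([[1, 1, 0], [0, 1, 1]])
        >>> TanimotoToTraining().fit_transform(X)
        array([[1.        , 0.33333333],
               [0.33333333, 1.        ]])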
141 | """ 142 | self.fit(X, y) 143 | return self.transform(X) 144 | -------------------------------------------------------------------------------- /molpipeline/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize module for experimental classes and functions.""" 2 | 3 | from molpipeline.experimental.custom_filter import CustomFilter 4 | 5 | __all__ = [ 6 | "CustomFilter", 7 | ] 8 | -------------------------------------------------------------------------------- /molpipeline/experimental/custom_filter.py: -------------------------------------------------------------------------------- 1 | """Module for custom filter functionality.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any, Callable 6 | 7 | try: 8 | from typing import Self # type: ignore[attr-defined] 9 | except ImportError: 10 | from typing_extensions import Self 11 | 12 | from molpipeline.abstract_pipeline_elements.core import ( 13 | InvalidInstance, 14 | ) 15 | from molpipeline.abstract_pipeline_elements.core import ( 16 | MolToMolPipelineElement as _MolToMolPipelineElement, 17 | ) 18 | from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol 19 | 20 | 21 | class CustomFilter(_MolToMolPipelineElement): 22 | """Filters molecules based on a custom boolean function. Elements not passing the filter will be set to InvalidInstances.""" 23 | 24 | def __init__( 25 | self, 26 | func: Callable[[RDKitMol], bool], 27 | name: str = "CustomFilter", 28 | n_jobs: int = 1, 29 | uuid: str | None = None, 30 | ) -> None: 31 | """Initialize CustomFilter. 32 | 33 | Parameters 34 | ---------- 35 | func : Callable[[RDKitMol], bool] 36 | custom function to filter molecules 37 | name : str, default="CustomFilter" 38 | name of the element, by default "CustomFilter" 39 | n_jobs : int, default=1 40 | number of jobs to use, by default 1 41 | uuid : str | None, optional 42 | uuid of the element, by default None 43 | """ 44 | super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) 45 | self.func = func 46 | 47 | def pretransform_single(self, value: RDKitMol) -> OptionalMol: 48 | """Pretransform single value. 49 | 50 | Applies the custom boolean function to the molecule. 51 | 52 | Parameters 53 | ---------- 54 | value : RDKitMol 55 | input value 56 | 57 | Returns 58 | ------- 59 | OptionalMol 60 | output value 61 | """ 62 | if self.func(value): 63 | return value 64 | return InvalidInstance( 65 | self.uuid, 66 | f"Molecule does not match filter from {self.name}", 67 | self.name, 68 | ) 69 | 70 | def get_params(self, deep: bool = True) -> dict[str, Any]: 71 | """Get parameters of CustomFilter. 72 | 73 | Parameters 74 | ---------- 75 | deep: bool, optional (default: True) 76 | If True, return the parameters of all subobjects that are PipelineElements. 77 | 78 | Returns 79 | ------- 80 | dict[str, Any] 81 | Parameters of CustomFilter. 82 | """ 83 | params = super().get_params(deep=deep) 84 | if deep: 85 | params["func"] = self.func 86 | else: 87 | params["func"] = self.func 88 | return params 89 | 90 | def set_params(self, **parameters: dict[str, Any]) -> Self: 91 | """Set parameters of CustomFilter. 92 | 93 | Parameters 94 | ---------- 95 | parameters: dict[str, Any] 96 | Parameters to set. 97 | 98 | Returns 99 | ------- 100 | Self 101 | Self. 
102 | """ 103 | parameter_copy = dict(parameters) 104 | if "func" in parameter_copy: 105 | self.func = parameter_copy.pop("func") # type: ignore 106 | super().set_params(**parameter_copy) 107 | return self 108 | -------------------------------------------------------------------------------- /molpipeline/experimental/explainability/__init__.py: -------------------------------------------------------------------------------- 1 | """Explainability module for the molpipeline package.""" 2 | 3 | from molpipeline.experimental.explainability.explainer import ( 4 | SHAPKernelExplainer, 5 | SHAPTreeExplainer, 6 | ) 7 | from molpipeline.experimental.explainability.explanation import ( 8 | SHAPFeatureAndAtomExplanation, 9 | SHAPFeatureExplanation, 10 | ) 11 | from molpipeline.experimental.explainability.visualization.visualization import ( 12 | structure_heatmap, 13 | structure_heatmap_shap, 14 | ) 15 | 16 | __all__ = [ 17 | "SHAPFeatureAndAtomExplanation", 18 | "SHAPFeatureExplanation", 19 | "SHAPKernelExplainer", 20 | "SHAPTreeExplainer", 21 | "structure_heatmap", 22 | "structure_heatmap_shap", 23 | ] 24 | -------------------------------------------------------------------------------- /molpipeline/experimental/explainability/explanation.py: -------------------------------------------------------------------------------- 1 | """Module for explanation class.""" 2 | 3 | from __future__ import annotations 4 | 5 | import abc 6 | import dataclasses 7 | 8 | import numpy as np 9 | import numpy.typing as npt 10 | 11 | from molpipeline.abstract_pipeline_elements.core import RDKitMol 12 | 13 | 14 | @dataclasses.dataclass(kw_only=True) 15 | class _AbstractMoleculeExplanation(abc.ABC): 16 | """Abstract class representing an explanation for a prediction for a molecule.""" 17 | 18 | molecule: RDKitMol | None = None 19 | prediction: npt.NDArray[np.float64] | None = None 20 | 21 | 22 | @dataclasses.dataclass(kw_only=True) 23 | class FeatureInfoMixin: 24 | """Mixin providing additional information about the features used in the explanation.""" 25 | 26 | feature_vector: npt.NDArray[np.float64] | None = None 27 | feature_names: list[str] | None = None 28 | 29 | 30 | @dataclasses.dataclass(kw_only=True) 31 | class FeatureExplanationMixin: 32 | """Explanation based on feature importance scores, e.g. 
Shapley Values.""" 33 | 34 | # explanation scores for individual features 35 | feature_weights: npt.NDArray[np.float64] | None = None 36 | 37 | 38 | @dataclasses.dataclass(kw_only=True) 39 | class AtomExplanationMixin: 40 | """Atom score based explanation.""" 41 | 42 | # explanation scores for individual atoms 43 | atom_weights: npt.NDArray[np.float64] | None = None 44 | 45 | 46 | @dataclasses.dataclass(kw_only=True) 47 | class BondExplanationMixin: 48 | """Bond score based explanation.""" 49 | 50 | # explanation scores for individual bonds 51 | bond_weights: npt.NDArray[np.float64] | None = None 52 | 53 | 54 | @dataclasses.dataclass(kw_only=True) 55 | class SHAPExplanationMixin: 56 | """Mixin providing additional information only present in SHAP explanations.""" 57 | 58 | expected_value: npt.NDArray[np.float64] | None = None 59 | 60 | 61 | @dataclasses.dataclass(kw_only=True) 62 | class SHAPFeatureExplanation( 63 | FeatureInfoMixin, 64 | FeatureExplanationMixin, 65 | SHAPExplanationMixin, 66 | _AbstractMoleculeExplanation, # base-class should be the last element https://www.ianlewis.org/en/mixins-and-python 67 | ): 68 | """Explanation using feature importance scores from SHAP.""" 69 | 70 | def is_valid(self) -> bool: 71 | """Check if the explanation is valid. 72 | 73 | Returns 74 | ------- 75 | bool 76 | True if the explanation is valid, False otherwise. 77 | """ 78 | return all( 79 | [ 80 | self.feature_vector is not None, 81 | self.feature_names is not None, 82 | self.molecule is not None, 83 | self.prediction is not None, 84 | self.feature_weights is not None, 85 | ] 86 | ) 87 | 88 | 89 | @dataclasses.dataclass(kw_only=True) 90 | class SHAPFeatureAndAtomExplanation( 91 | FeatureInfoMixin, 92 | FeatureExplanationMixin, 93 | SHAPExplanationMixin, 94 | AtomExplanationMixin, 95 | _AbstractMoleculeExplanation, 96 | ): 97 | """Explanation using feature and atom importance scores from SHAP.""" 98 | 99 | def is_valid(self) -> bool: 100 | """Check if the explanation is valid. 101 | 102 | Returns 103 | ------- 104 | bool 105 | True if the explanation is valid, False otherwise. 106 | """ 107 | return all( 108 | [ 109 | self.feature_vector is not None, 110 | self.feature_names is not None, 111 | self.molecule is not None, 112 | self.prediction is not None, 113 | self.feature_weights is not None, 114 | self.atom_weights is not None, 115 | ] 116 | ) 117 | -------------------------------------------------------------------------------- /molpipeline/experimental/explainability/fingerprint_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for explainability.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections import defaultdict 6 | from collections.abc import Sequence 7 | 8 | import numpy as np 9 | import numpy.typing as npt 10 | 11 | from molpipeline.abstract_pipeline_elements.core import RDKitMol 12 | from molpipeline.mol2any import MolToMorganFP 13 | from molpipeline.utils.substructure_handling import AtomEnvironment 14 | 15 | 16 | def assign_prediction_importance( 17 | bit_dict: dict[int, Sequence[AtomEnvironment]], weights: npt.NDArray[np.float64] 18 | ) -> dict[int, float]: 19 | """Assign the prediction importance. 20 | 21 | Originally from Christian W. Feldmann 22 | https://github.com/c-feldmann/compchemkit/blob/64e5543e2b8f72e93711186b2e0b42366820fb52/compchemkit/molecular_heatmaps.py#L28 23 | 24 | Parameters 25 | ---------- 26 | bit_dict : dict[int, Sequence[AtomEnvironment]] 27 | The bit dictionary. 
28 |     weights : npt.NDArray[np.float64]
29 |         The weights, e.g. SHAP values, indexed by fingerprint bit.
30 | 
31 |     Raises
32 |     ------
33 |     AssertionError
34 |         If the weights and atom contributions don't sum to the same value.
35 | 
36 |     Returns
37 |     -------
38 |     dict[int, float]
39 |         The contribution of each atom, indexed by atom index.
40 | 
41 |     """
42 |     atom_contribution: dict[int, float] = defaultdict(lambda: 0)
43 |     for bit, atom_env_list in bit_dict.items():  # type: int, Sequence[AtomEnvironment]
44 |         n_matches = len(atom_env_list)
45 |         for atom_set in atom_env_list:
46 |             for atom in atom_set.environment_atoms:
47 |                 atom_contribution[atom] += weights[bit] / (
48 |                     len(atom_set.environment_atoms) * n_matches
49 |                 )
50 |     if not np.isclose(sum(weights), sum(atom_contribution.values())).all():
51 |         raise AssertionError(
52 |             f"Weights and atom contributions don't sum to the same value:"
53 |             f" {weights.sum()} != {sum(atom_contribution.values())}"
54 |         )
55 |     return atom_contribution
56 | 
57 | 
58 | def fingerprint_shap_to_atomweights(
59 |     mol: RDKitMol, fingerprint_element: MolToMorganFP, shap_mat: npt.NDArray[np.float64]
60 | ) -> list[float]:
61 |     """Convert SHAP values to atom weights.
62 | 
63 |     Originally from Christian W. Feldmann
64 |     https://github.com/c-feldmann/compchemkit/blob/64e5543e2b8f72e93711186b2e0b42366820fb52/compchemkit/molecular_heatmaps.py#L15
65 | 
66 |     Parameters
67 |     ----------
68 |     mol : RDKitMol
69 |         The molecule.
70 |     fingerprint_element : MolToMorganFP
71 |         The fingerprint element.
72 |     shap_mat : npt.NDArray[np.float64]
73 |         The SHAP values.
74 | 
75 |     Returns
76 |     -------
77 |     list[float]
78 |         The atom weights.
79 |     """
80 |     bit_atom_env_dict: dict[int, Sequence[AtomEnvironment]]
81 |     bit_atom_env_dict = dict(
82 |         fingerprint_element.bit2atom_mapping(mol)
83 |     )  # MyPy invariants make me do this.
84 |     atom_weight_dict = assign_prediction_importance(bit_atom_env_dict, shap_mat)
85 |     atom_weight_list = [
86 |         atom_weight_dict.get(a_idx, 0)
87 |         for a_idx in range(mol.GetNumAtoms())
88 |     ]
89 |     return atom_weight_list
90 | 
-------------------------------------------------------------------------------- /molpipeline/experimental/explainability/visualization/__init__.py: --------------------------------------------------------------------------------
1 | """Visualization module for explainability."""
2 | 
-------------------------------------------------------------------------------- /molpipeline/experimental/explainability/visualization/gauss.py: --------------------------------------------------------------------------------
1 | """Gaussian functions for visualization.
2 | 
3 | Much of the visualization code in this file originates from projects of Christian W. Feldmann:
4 |     https://github.com/c-feldmann/rdkit_heatmaps
5 |     https://github.com/c-feldmann/compchemkit
6 | """
7 | 
8 | import numpy as np
9 | import numpy.typing as npt
10 | 
11 | 
12 | class GaussFunctor2D:  # pylint: disable=too-few-public-methods
13 |     """2D Gaussian functor."""
14 | 
15 |     def __init__(
16 |         self,
17 |         center: npt.NDArray[np.float64],
18 |         std1: float = 1,
19 |         std2: float = 1,
20 |         scale: float = 1,
21 |         rotation: float = 0,
22 |     ) -> None:
23 |         """Initialize 2D Gaussian functor.
24 | 
25 |         Parameters
26 |         ----------
27 |         center: npt.NDArray[np.float64]
28 |             Center of the Gaussian function.
29 |         std1: float
30 |             Standard deviation along the first axis.
31 |         std2: float
32 |             Standard deviation along the second axis.
33 |         scale: float
34 |             Scaling factor.
35 |         rotation: float
36 |             Rotation angle in radians.
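        Notes
        -----
        A sketch of the evaluated function, reconstructed from the code below:
        the functor computes a rotated 2D Gaussian

            f(x) = scale * exp(-(a*d1**2 + 2*b*d1*d2 + c*d2**2))

        with d1 = x1 - center1, d2 = x2 - center2 and the coefficients

            a = cos(rotation)**2 / (2*std1**2) + sin(rotation)**2 / (2*std2**2)
            b = -sin(2*rotation) / (4*std1**2) + sin(2*rotation) / (4*std2**2)
            c = sin(rotation)**2 / (2*std1**2) + cos(rotation)**2 / (2*std2**2)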
37 | """ 38 | self.center = center 39 | self.std = np.array([std1, std2]) ** 2 # scale stds to variance 40 | self.scale = scale 41 | self.rotation = rotation 42 | 43 | self._a = np.cos(self.rotation) ** 2 / (2 * self.std[0]) + np.sin( 44 | self.rotation 45 | ) ** 2 / (2 * self.std[1]) 46 | self._b = -np.sin(2 * self.rotation) / (4 * self.std[0]) + np.sin( 47 | 2 * self.rotation 48 | ) / (4 * self.std[1]) 49 | self._c = np.sin(self.rotation) ** 2 / (2 * self.std[0]) + np.cos( 50 | self.rotation 51 | ) ** 2 / (2 * self.std[1]) 52 | 53 | def __call__(self, pos: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]: 54 | """Evaluate the Gaussian function at the given positions. 55 | 56 | Parameters 57 | ---------- 58 | pos: npt.NDArray[np.float64] 59 | Array of positions to evaluate the Gaussian function at. 60 | 61 | Returns 62 | ------- 63 | npt.NDArray[np.float64] 64 | Array of function values at the given positions. 65 | """ 66 | exponent = self._a * (pos[:, 0] - self.center[0]) ** 2 67 | exponent += ( 68 | 2 * self._b * (pos[:, 0] - self.center[0]) * (pos[:, 1] - self.center[1]) 69 | ) 70 | exponent += self._c * (pos[:, 1] - self.center[1]) ** 2 71 | return self.scale * np.exp(-exponent) 72 | -------------------------------------------------------------------------------- /molpipeline/experimental/model_selection/__init__.py: -------------------------------------------------------------------------------- 1 | """Model selection module.""" 2 | 3 | from molpipeline.experimental.model_selection.splitter import ( 4 | GroupShuffleSplit, 5 | ) 6 | 7 | __all__ = ["GroupShuffleSplit"] 8 | -------------------------------------------------------------------------------- /molpipeline/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """Metrics for evaluating the performance of a model.""" 2 | 3 | from molpipeline.metrics.ignore_error_scorer import ignored_value_scorer 4 | 5 | __all__ = ["ignored_value_scorer"] 6 | -------------------------------------------------------------------------------- /molpipeline/metrics/ignore_error_scorer.py: -------------------------------------------------------------------------------- 1 | """Scorer that ignores a given value in the prediction array.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | import numpy as np 8 | import numpy.typing as npt 9 | import pandas as pd 10 | from loguru import logger 11 | from sklearn import metrics 12 | from sklearn.metrics._scorer import _BaseScorer 13 | 14 | 15 | def ignored_value_scorer( 16 | scorer: str | _BaseScorer, ignore_value: Any = None 17 | ) -> _BaseScorer: 18 | """Create a scorer that ignores a given value in the prediction array. 19 | 20 | This is relevant for pipline models which replace errors with a given value. 21 | The wrapped scorer will ignore that value and return the corresponding score. 22 | 23 | Parameters 24 | ---------- 25 | scorer : str or _BaseScorer 26 | The scorer to wrap. 27 | ignore_value : Any, optional 28 | The value to ignore in the prediction array. 29 | Default: None 30 | 31 | Returns 32 | ------- 33 | _BaseScorer 34 | The scorer that ignores the given value. 
35 | """ 36 | if isinstance(scorer, str): 37 | scorer = metrics.get_scorer(scorer) 38 | 39 | score_func = scorer._score_func # pylint: disable=protected-access 40 | response_method = scorer._response_method # pylint: disable=protected-access 41 | scorer_kwargs = scorer._kwargs # pylint: disable=protected-access 42 | if scorer._sign < 0: # pylint: disable=protected-access 43 | scorer_kwargs["greater_is_better"] = False 44 | 45 | def newscore( 46 | y_true: npt.NDArray[np.float64 | np.int_], 47 | y_pred: npt.NDArray[np.float64 | np.int_], 48 | **kwargs: Any, 49 | ) -> float: 50 | """Compute the score for the given prediction arrays. 51 | 52 | Parameters 53 | ---------- 54 | y_true : npt.NDArray[np.float64 | np.int_] 55 | The true values. 56 | y_pred : npt.NDArray[np.float64 | np.int_] 57 | The predicted values. 58 | **kwargs 59 | Additional keyword arguments. 60 | 61 | Returns 62 | ------- 63 | float 64 | The score for the given prediction arrays. 65 | """ 66 | retained_y_true: npt.NDArray[np.bool_] 67 | retained_y_pred: npt.NDArray[np.bool_] 68 | if pd.notna(ignore_value): 69 | retained_y_true = ~np.equal(y_true, ignore_value) 70 | retained_y_pred = ~np.equal(y_pred, ignore_value) 71 | else: 72 | retained_y_true = pd.notna(y_true) 73 | retained_y_pred = pd.notna(y_pred) 74 | 75 | all_retained = retained_y_pred & retained_y_true 76 | 77 | if not np.all(all_retained): 78 | logger.warning( 79 | f"Warning, prediction array contains NaN values, removing {sum(~all_retained)} elements" 80 | ) 81 | y_true_ = np.copy(np.array(y_true)[all_retained]) 82 | y_pred_ = np.array(np.array(y_pred)[all_retained].tolist()) 83 | kwargs_ = dict(kwargs) 84 | if "sample_weight" in kwargs_ and kwargs_["sample_weight"] is not None: 85 | kwargs_["sample_weight"] = kwargs_["sample_weight"][all_retained] 86 | return score_func(y_true_, y_pred_, **kwargs_) 87 | 88 | return metrics.make_scorer( 89 | newscore, response_method=response_method, **scorer_kwargs 90 | ) 91 | -------------------------------------------------------------------------------- /molpipeline/mol2any/__init__.py: -------------------------------------------------------------------------------- 1 | """Init the module for mol2any pipeline elements.""" 2 | 3 | from molpipeline.mol2any.mol2bin import MolToBinary 4 | from molpipeline.mol2any.mol2bool import MolToBool 5 | from molpipeline.mol2any.mol2concatinated_vector import MolToConcatenatedVector 6 | from molpipeline.mol2any.mol2inchi import MolToInchi, MolToInchiKey 7 | from molpipeline.mol2any.mol2maccs_key_fingerprint import MolToMACCSFP 8 | from molpipeline.mol2any.mol2morgan_fingerprint import MolToMorganFP 9 | from molpipeline.mol2any.mol2net_charge import MolToNetCharge 10 | from molpipeline.mol2any.mol2path_fingerprint import Mol2PathFP 11 | from molpipeline.mol2any.mol2rdkit_phys_chem import MolToRDKitPhysChem 12 | from molpipeline.mol2any.mol2smiles import MolToSmiles 13 | 14 | __all__ = [ 15 | "Mol2PathFP", 16 | "MolToBinary", 17 | "MolToBool", 18 | "MolToConcatenatedVector", 19 | "MolToInchi", 20 | "MolToInchiKey", 21 | "MolToMACCSFP", 22 | "MolToMorganFP", 23 | "MolToNetCharge", 24 | "MolToRDKitPhysChem", 25 | "MolToSmiles", 26 | ] 27 | 28 | try: 29 | from molpipeline.mol2any.mol2chemprop import MolToChemprop # noqa 30 | 31 | __all__.append("MolToChemprop") 32 | except ImportError: 33 | pass 34 | -------------------------------------------------------------------------------- /molpipeline/mol2any/mol2bin.py: -------------------------------------------------------------------------------- 1 | 
"""Converter element for molecules to binary string representation.""" 2 | 3 | from rdkit import Chem 4 | 5 | from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement 6 | 7 | 8 | class MolToBinary(MolToAnyPipelineElement): 9 | """PipelineElement to transform a molecule to a binary.""" 10 | 11 | def pretransform_single(self, value: Chem.Mol) -> str: 12 | """Transform a molecule to a binary string. 13 | 14 | Parameters 15 | ---------- 16 | value: Chem.Mol 17 | Molecule to be transformed to binary string representation. 18 | 19 | Returns 20 | ------- 21 | str 22 | Binary representation of molecule. 23 | """ 24 | return value.ToBinary() 25 | -------------------------------------------------------------------------------- /molpipeline/mol2any/mol2bool.py: -------------------------------------------------------------------------------- 1 | """Pipeline elements for converting instances to bool.""" 2 | 3 | from typing import Any 4 | 5 | from molpipeline.abstract_pipeline_elements.core import ( 6 | InvalidInstance, 7 | MolToAnyPipelineElement, 8 | ) 9 | 10 | 11 | class MolToBool(MolToAnyPipelineElement): 12 | """Element to generate a bool array from input. 13 | 14 | Valid molecules are passed as True, InvalidInstances are passed as False. 15 | """ 16 | 17 | def pretransform_single(self, value: Any) -> bool: 18 | """Transform a value to a bool representation. 19 | 20 | Parameters 21 | ---------- 22 | value: Any 23 | Value to be transformed to bool representation. 24 | 25 | Returns 26 | ------- 27 | str 28 | Binary representation of molecule. 29 | """ 30 | if isinstance(value, InvalidInstance): 31 | return False 32 | return True 33 | 34 | def transform_single(self, value: Any) -> Any: 35 | """Transform a single molecule to a bool representation. 36 | 37 | Valid molecules are passed as True, InvalidInstances are passed as False. 38 | RemovedMolecule objects are passed without change, as no transformations are applicable. 39 | 40 | Parameters 41 | ---------- 42 | value: Any 43 | Current representation of the molecule. (Eg. SMILES, RDKit Mol, ...) 44 | 45 | Returns 46 | ------- 47 | Any 48 | Bool representation of the molecule. 49 | """ 50 | pre_value = self.pretransform_single(value) 51 | return self.finalize_single(pre_value) 52 | -------------------------------------------------------------------------------- /molpipeline/mol2any/mol2inchi.py: -------------------------------------------------------------------------------- 1 | """Classes for transforming rdkit molecules to inchi.""" 2 | 3 | from __future__ import annotations 4 | 5 | from rdkit import Chem 6 | 7 | from molpipeline.abstract_pipeline_elements.mol2any.mol2string import ( 8 | MolToStringPipelineElement as _MolToStringPipelineElement, 9 | ) 10 | from molpipeline.utils.molpipeline_types import RDKitMol 11 | 12 | 13 | class MolToInchi(_MolToStringPipelineElement): 14 | """PipelineElement to transform a molecule to an INCHI string.""" 15 | 16 | def pretransform_single(self, value: RDKitMol) -> str: 17 | """Transform a molecule to a INCHI-key string. 
18 | 19 | Parameters 20 | ---------- 21 | value: RDKitMol 22 | molecule to transform 23 | 24 | Returns 25 | ------- 26 | str 27 | INCHI string 28 | """ 29 | return str(Chem.MolToInchi(value)) 30 | 31 | 32 | class MolToInchiKey(_MolToStringPipelineElement): 33 | """PipelineElement to transform a molecule to an INCHI-Key string.""" 34 | 35 | def __init__( 36 | self, 37 | name: str = "MolToInchiKey", 38 | n_jobs: int = 1, 39 | uuid: str | None = None, 40 | ): 41 | """Initialize MolToInchiKey. 42 | 43 | Parameters 44 | ---------- 45 | name: str, default="MolToInchiKey" 46 | name of PipelineElement 47 | n_jobs: int, default=1 48 | number of jobs to use for parallelization 49 | uuid: str | None, optional 50 | uuid of PipelineElement, by default None 51 | """ 52 | super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) 53 | 54 | def pretransform_single(self, value: RDKitMol) -> str: 55 | """Transform a molecule to an INCHI-key string. 56 | 57 | Parameters 58 | ---------- 59 | value: RDKitMol 60 | molecule to transform 61 | 62 | Returns 63 | ------- 64 | str 65 | INCHI-key of molecule. 66 | """ 67 | return str(Chem.MolToInchiKey(value)) 68 | -------------------------------------------------------------------------------- /molpipeline/mol2any/mol2maccs_key_fingerprint.py: -------------------------------------------------------------------------------- 1 | """Implementation of MACCS key fingerprint.""" 2 | 3 | import numpy as np 4 | from numpy import typing as npt 5 | from rdkit.Chem import MACCSkeys 6 | from rdkit.DataStructs import ExplicitBitVect 7 | 8 | from molpipeline.abstract_pipeline_elements.mol2any.mol2bitvector import ( 9 | MolToFingerprintPipelineElement, 10 | ) 11 | from molpipeline.utils.molpipeline_types import RDKitMol 12 | 13 | 14 | class MolToMACCSFP(MolToFingerprintPipelineElement): 15 | """MACCS key fingerprint. 16 | 17 | The MACCS keys are a set of 166 keys that encode the presence or absence of 18 | particular substructures in a molecule. The MACCS keys are a subset of the 19 | PubChem substructure keys. 20 | 21 | """ 22 | 23 | _n_bits = 167 # MACCS keys have 166 bits + 1 bit for an all-zero vector (bit 0) 24 | _feature_names = [f"maccs_{i}" for i in range(_n_bits)] 25 | 26 | def pretransform_single( 27 | self, value: RDKitMol 28 | ) -> dict[int, int] | npt.NDArray[np.int_] | ExplicitBitVect: 29 | """Transform a single molecule to MACCS key fingerprint. 30 | 31 | Parameters 32 | ---------- 33 | value : RDKitMol 34 | RDKit molecule. 35 | 36 | Raises 37 | ------ 38 | ValueError 39 | If the variable `return_as` is not one of the allowed values. 40 | 41 | Returns 42 | ------- 43 | dict[int, int] | npt.NDArray[np.int_] | ExplicitBitVect 44 | MACCS key fingerprint. 
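        Examples
        --------
        Illustrative sketch (the ``return_as`` keyword is assumed to be
        provided by the fingerprint base class):

        >>> from rdkit import Chem
        >>> element = MolToMACCSFP(return_as="dense")
        >>> element.pretransform_single(Chem.MolFromSmiles("CCO")).shape
        (167,)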
45 | 46 | """ 47 | fingerprint = MACCSkeys.GenMACCSKeys(value) # type: ignore[attr-defined] 48 | if self._return_as == "explicit_bit_vect": 49 | return fingerprint 50 | if self._return_as == "dense": 51 | return np.array(fingerprint) 52 | if self._return_as == "sparse": 53 | return dict.fromkeys(fingerprint.GetOnBits(), 1) 54 | raise ValueError(f"Unknown return_as value: {self._return_as}") 55 | -------------------------------------------------------------------------------- /molpipeline/mol2any/mol2smiles.py: -------------------------------------------------------------------------------- 1 | """Classes for transforming rdkit molecules to any type of output.""" 2 | 3 | from __future__ import annotations 4 | 5 | from rdkit import Chem 6 | 7 | from molpipeline.abstract_pipeline_elements.mol2any.mol2string import ( 8 | MolToStringPipelineElement as _MolToStringPipelineElement, 9 | ) 10 | 11 | 12 | class MolToSmiles(_MolToStringPipelineElement): 13 | """PipelineElement to transform a molecule to a SMILES string.""" 14 | 15 | def pretransform_single(self, value: Chem.Mol) -> str: 16 | """Transform a molecule to a SMILES string. 17 | 18 | Parameters 19 | ---------- 20 | value: Chem.Mol 21 | Molecule to be transformed to SMILES string. 22 | 23 | Returns 24 | ------- 25 | str 26 | SMILES string of molecule. 27 | """ 28 | return str(Chem.MolToSmiles(value)) 29 | -------------------------------------------------------------------------------- /molpipeline/mol2mol/__init__.py: -------------------------------------------------------------------------------- 1 | """Init the module for mol2mol pipeline elements.""" 2 | 3 | from molpipeline.mol2mol.filter import ( 4 | ComplexFilter, 5 | ElementFilter, 6 | EmptyMoleculeFilter, 7 | InorganicsFilter, 8 | MixtureFilter, 9 | RDKitDescriptorsFilter, 10 | SmartsFilter, 11 | SmilesFilter, 12 | ) 13 | from molpipeline.mol2mol.reaction import MolToMolReaction 14 | from molpipeline.mol2mol.scaffolds import MakeScaffoldGeneric, MurckoScaffold 15 | from molpipeline.mol2mol.standardization import ( 16 | ChargeParentExtractor, 17 | ExplicitHydrogenRemover, 18 | FragmentDeduplicator, 19 | IsotopeRemover, 20 | LargestFragmentChooser, 21 | MetalDisconnector, 22 | SaltRemover, 23 | SolventRemover, 24 | StereoRemover, 25 | TautomerCanonicalizer, 26 | Uncharger, 27 | ) 28 | 29 | __all__ = ( 30 | "ChargeParentExtractor", 31 | "ComplexFilter", 32 | "ElementFilter", 33 | "EmptyMoleculeFilter", 34 | "ExplicitHydrogenRemover", 35 | "FragmentDeduplicator", 36 | "InorganicsFilter", 37 | "IsotopeRemover", 38 | "LargestFragmentChooser", 39 | "MakeScaffoldGeneric", 40 | "MetalDisconnector", 41 | "MixtureFilter", 42 | "MolToMolReaction", 43 | "MurckoScaffold", 44 | "RDKitDescriptorsFilter", 45 | "SaltRemover", 46 | "SmartsFilter", 47 | "SmilesFilter", 48 | "SolventRemover", 49 | "StereoRemover", 50 | "TautomerCanonicalizer", 51 | "Uncharger", 52 | ) 53 | -------------------------------------------------------------------------------- /molpipeline/mol2mol/scaffolds.py: -------------------------------------------------------------------------------- 1 | """Classes for standardizing molecules.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | try: 8 | from typing import Self # pylint: disable=no-name-in-module 9 | except ImportError: 10 | from typing_extensions import Self 11 | 12 | from rdkit import Chem 13 | from rdkit.Chem.Scaffolds import MurckoScaffold as RDKIT_MurckoScaffold 14 | 15 | from molpipeline.abstract_pipeline_elements.core import ( 16 | 
    MolToMolPipelineElement as _MolToMolPipelineElement,
17 | )
18 | from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol
19 | 
20 | 
21 | class MurckoScaffold(_MolToMolPipelineElement):
22 |     """MolToMol-PipelineElement which yields the Murcko-scaffold of a molecule.
23 | 
24 |     The Murcko-scaffold is composed of all rings and the linker atoms between them.
25 |     """
26 | 
27 |     def pretransform_single(self, value: RDKitMol) -> OptionalMol:
28 |         """Extract the Murcko-scaffold of a molecule.
29 | 
30 |         Parameters
31 |         ----------
32 |         value: RDKitMol
33 |             RDKit molecule object which is transformed.
34 | 
35 |         Returns
36 |         -------
37 |         OptionalMol
38 |             Murcko-scaffold of molecule if possible, else InvalidInstance.
39 |         """
40 |         return RDKIT_MurckoScaffold.GetScaffoldForMol(value)
41 | 
42 | 
43 | class MakeScaffoldGeneric(_MolToMolPipelineElement):
44 |     """MolToMol-PipelineElement which sets all atoms to carbon and all bonds to single bond.
45 | 
46 |     This is done to make scaffolds less specific.
47 |     """
48 | 
49 |     def __init__(
50 |         self,
51 |         generic_atoms: bool = False,
52 |         generic_bonds: bool = False,
53 |         name: str = "MakeScaffoldGeneric",
54 |         n_jobs: int = 1,
55 |         uuid: str | None = None,
56 |     ) -> None:
57 |         """Initialize MakeScaffoldGeneric.
58 | 
59 |         Note
60 |         ----
61 |         Making atoms or bonds generic will generate SMARTS strings instead of SMILES.
62 |         This can be useful to search for scaffolds and substructures in data sets.
63 |         By default, the scaffold is returned as SMILES with all atoms set to carbon and
64 |         all bonds set to single bonds.
65 | 
66 |         Parameters
67 |         ----------
68 |         generic_atoms: bool, default=False
69 |             If True, all atoms in the molecule are set to generic atoms (*).
70 |         generic_bonds: bool, default=False
71 |             If True, all bonds in the molecule are set to any bonds.
72 |         name: str, default="MakeScaffoldGeneric"
73 |             Name of pipeline element.
74 |         n_jobs: int, default=1
75 |             Number of jobs to use for parallelization.
76 |         uuid: str | None
77 |             UUID of pipeline element.
78 | 
79 |         """
80 |         self.generic_atoms = generic_atoms
81 |         self.generic_bonds = generic_bonds
82 |         super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
83 | 
84 |     def pretransform_single(self, value: RDKitMol) -> OptionalMol:
85 |         """Set all atoms to carbon and all bonds to single bond and return mol object.
86 | 
87 |         Parameters
88 |         ----------
89 |         value: RDKitMol
90 |             RDKit molecule object which is transformed.
91 | 
92 |         Returns
93 |         -------
94 |         OptionalMol
95 |             Molecule where all atoms are carbon and all bonds are single bonds.
96 |             If transformation failed, it returns InvalidInstance.
97 |         """
98 |         scaffold = RDKIT_MurckoScaffold.MakeScaffoldGeneric(value)
99 |         if self.generic_atoms:
100 |             for atom in scaffold.GetAtoms():
101 |                 atom.SetAtomicNum(0)
102 |         if self.generic_bonds:
103 |             for bond in scaffold.GetBonds():
104 |                 bond.SetBondType(Chem.rdchem.BondType.UNSPECIFIED)
105 |         return scaffold
106 | 
107 |     def get_params(self, deep: bool = True) -> dict[str, Any]:
108 |         """Get parameters of pipeline element.
109 | 
110 |         Parameters
111 |         ----------
112 |         deep: bool
113 |             If True, return the parameters of all subobjects that are pipeline elements.
114 | 
115 |         Returns
116 |         -------
117 |         dict[str, Any]
118 |             Parameters of the pipeline element.
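        Examples
        --------
        Illustrative sketch:

        >>> MakeScaffoldGeneric(generic_atoms=True).get_params()["generic_atoms"]
        True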
119 | """ 120 | parent_params = super().get_params() 121 | if deep: 122 | parent_params.update( 123 | { 124 | "generic_atoms": bool(self.generic_atoms), 125 | "generic_bonds": bool(self.generic_bonds), 126 | } 127 | ) 128 | else: 129 | parent_params.update( 130 | { 131 | "generic_atoms": self.generic_atoms, 132 | "generic_bonds": self.generic_bonds, 133 | } 134 | ) 135 | return parent_params 136 | 137 | def set_params(self, **parameters: dict[str, Any]) -> Self: 138 | """Set parameters of pipeline element. 139 | 140 | Parameters 141 | ---------- 142 | parameters: dict[str, Any] 143 | Parameters to set. 144 | 145 | 146 | Raises 147 | ------ 148 | ValueError 149 | If parameters are not valid. 150 | 151 | Returns 152 | ------- 153 | Self 154 | Pipeline element with set parameters. 155 | 156 | """ 157 | param_copy = parameters.copy() 158 | generic_atoms = param_copy.pop("generic_atoms", None) 159 | generic_bonds = param_copy.pop("generic_bonds", None) 160 | if generic_atoms is not None: 161 | if not isinstance(generic_atoms, bool): 162 | raise ValueError("generic_atoms must be a boolean.") 163 | self.generic_atoms = generic_atoms 164 | if generic_bonds is not None: 165 | if not isinstance(generic_bonds, bool): 166 | raise ValueError("generic_bonds must be a boolean.") 167 | self.generic_bonds = generic_bonds 168 | return self 169 | -------------------------------------------------------------------------------- /molpipeline/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize pipeline module.""" 2 | 3 | from molpipeline.pipeline._skl_pipeline import Pipeline 4 | 5 | __all__ = ["Pipeline"] 6 | -------------------------------------------------------------------------------- /molpipeline/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/MolPipeline/3ab8aa0ebd345b8b2b2b99dd608371f640211754/molpipeline/py.typed -------------------------------------------------------------------------------- /molpipeline/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | -------------------------------------------------------------------------------- /molpipeline/utils/comparison.py: -------------------------------------------------------------------------------- 1 | """Functions for comparing pipelines.""" 2 | 3 | from typing import Any, TypeVar 4 | 5 | from molpipeline import Pipeline 6 | from molpipeline.utils.json_operations import recursive_to_json 7 | 8 | _T = TypeVar("_T", list[Any], tuple[Any, ...], set[Any], dict[Any, Any], Any) 9 | 10 | 11 | def remove_irrelevant_params(params: _T) -> _T: 12 | """Remove irrelevant parameters from a dictionary. 13 | 14 | Parameters 15 | ---------- 16 | params : TypeVar 17 | Parameters to remove irrelevant parameters from. 18 | 19 | Returns 20 | ------- 21 | TypeVar 22 | Parameters without irrelevant parameters. 
23 | 24 | """ 25 | if isinstance(params, list): 26 | return [remove_irrelevant_params(val) for val in params] 27 | if isinstance(params, tuple): 28 | return tuple(remove_irrelevant_params(val) for val in params) 29 | if isinstance(params, set): 30 | return {remove_irrelevant_params(val) for val in params} 31 | 32 | irrelevant_params = ["n_jobs", "uuid", "error_filter_id"] 33 | if isinstance(params, dict): 34 | params_new = {} 35 | for key, value in params.items(): 36 | if not isinstance(key, str): 37 | continue 38 | if key.split("__")[-1] in irrelevant_params: 39 | continue 40 | params_new[key] = remove_irrelevant_params(value) 41 | return params_new 42 | return params 43 | 44 | 45 | def compare_recursive( # pylint: disable=too-many-return-statements 46 | value_a: Any, value_b: Any 47 | ) -> bool: 48 | """Compare two values recursively. 49 | 50 | Parameters 51 | ---------- 52 | value_a : Any 53 | First value to compare. 54 | value_b : Any 55 | Second value to compare. 56 | 57 | Returns 58 | ------- 59 | bool 60 | True if the values are the same, False otherwise. 61 | 62 | """ 63 | if value_a.__class__ != value_b.__class__: 64 | return False 65 | 66 | if isinstance(value_a, dict): 67 | if set(value_a.keys()) != set(value_b.keys()): 68 | return False 69 | for key in value_a: 70 | if not compare_recursive(value_a[key], value_b[key]): 71 | return False 72 | return True 73 | 74 | if isinstance(value_a, (list, tuple)): 75 | if len(value_a) != len(value_b): 76 | return False 77 | for val_a, val_b in zip(value_a, value_b): 78 | if not compare_recursive(val_a, val_b): 79 | return False 80 | return True 81 | return value_a == value_b 82 | 83 | 84 | def check_pipelines_equivalent(pipeline_a: Pipeline, pipeline_b: Pipeline) -> bool: 85 | """Check if two pipelines are the same. 86 | 87 | Parameters 88 | ---------- 89 | pipeline_a : Pipeline 90 | Pipeline to compare. 91 | pipeline_b : Pipeline 92 | Pipeline to compare. 93 | 94 | Raises 95 | ------ 96 | ValueError 97 | If the pipelines are not of type Pipeline. 98 | 99 | Returns 100 | ------- 101 | bool 102 | True if the pipelines are the same, False otherwise. 103 | 104 | """ 105 | if not isinstance(pipeline_a, Pipeline) or not isinstance(pipeline_b, Pipeline): 106 | raise ValueError("Both inputs should be of type Pipeline.") 107 | pipeline_json_a = recursive_to_json(pipeline_a) 108 | pipeline_json_a = remove_irrelevant_params(pipeline_json_a) 109 | pipeline_json_b = recursive_to_json(pipeline_b) 110 | pipeline_json_b = remove_irrelevant_params(pipeline_json_b) 111 | return compare_recursive(pipeline_json_a, pipeline_json_b) 112 | -------------------------------------------------------------------------------- /molpipeline/utils/json_operations_torch.py: -------------------------------------------------------------------------------- 1 | """Functions for serializing and deserializing PyTorch models.""" 2 | 3 | from typing import TypeVar 4 | 5 | try: 6 | import torch 7 | 8 | TORCH_AVAILABLE = True 9 | except ImportError: 10 | TORCH_AVAILABLE = False 11 | from typing import Any, Literal 12 | 13 | _T = TypeVar("_T") 14 | 15 | if TORCH_AVAILABLE: 16 | 17 | def tensor_to_json( 18 | obj: _T, 19 | ) -> tuple[dict[str, Any], Literal[True]] | tuple[_T, Literal[False]]: 20 | """Recursively convert a PyTorch model to a JSON-serializable object. 21 | 22 | Parameters 23 | ---------- 24 | obj : object 25 | The object to convert. 26 | 27 | Returns 28 | ------- 29 | object 30 | The JSON-serializable object. 
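        Examples
        --------
        Illustrative sketch (requires torch to be installed):

        >>> import torch
        >>> object_dict, success = tensor_to_json(torch.tensor([1.0, 2.0]))
        >>> success
        True
        >>> object_dict["data"]
        [1.0, 2.0]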
31 | """ 32 | if isinstance(obj, torch.Tensor): 33 | object_dict: dict[str, Any] = { 34 | "__name__": obj.__class__.__name__, 35 | "__module__": obj.__class__.__module__, 36 | "__init__": True, 37 | } 38 | else: 39 | return obj, False 40 | object_dict["data"] = obj.tolist() 41 | return object_dict, True 42 | 43 | else: 44 | 45 | def tensor_to_json( 46 | obj: _T, 47 | ) -> tuple[dict[str, Any], Literal[True]] | tuple[_T, Literal[False]]: 48 | """Recursively convert a PyTorch model to a JSON-serializable object. 49 | 50 | Parameters 51 | ---------- 52 | obj : object 53 | The object to convert. 54 | 55 | Returns 56 | ------- 57 | object 58 | The JSON-serializable object. 59 | """ 60 | return obj, False 61 | -------------------------------------------------------------------------------- /molpipeline/utils/kernel.py: -------------------------------------------------------------------------------- 1 | """Contains functions for molecular similarity.""" 2 | 3 | from typing import Union 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | from scipy import sparse 8 | 9 | 10 | def tanimoto_similarity_sparse( 11 | matrix_a: sparse.csr_matrix, matrix_b: sparse.csr_matrix 12 | ) -> npt.NDArray[np.float64]: 13 | """Calculate a matrix of tanimoto similarities between feature matrix a and b. 14 | 15 | Parameters 16 | ---------- 17 | matrix_a: sparse.csr_matrix 18 | Feature matrix A. 19 | matrix_b: sparse.csr_matrix 20 | Feature matrix B. 21 | 22 | Returns 23 | ------- 24 | npt.NDArray[np.float64] 25 | Matrix of similarity values between instances of A (rows/first dim) , and instances of B (columns/second dim). 26 | 27 | """ 28 | intersection = matrix_a.dot(matrix_b.transpose()).toarray() 29 | norm_1 = np.array(matrix_a.multiply(matrix_a).sum(axis=1)) 30 | norm_2 = np.array(matrix_b.multiply(matrix_b).sum(axis=1)) 31 | union = norm_1 + norm_2.T - intersection 32 | # avoid division by zero https://stackoverflow.com/a/37977222 33 | return np.divide( 34 | intersection, 35 | union, 36 | out=np.zeros(intersection.shape, dtype=float), 37 | where=union != 0, 38 | ) 39 | 40 | 41 | def tanimoto_distance_sparse( 42 | matrix_a: sparse.csr_matrix, matrix_b: sparse.csr_matrix 43 | ) -> npt.NDArray[np.float64]: 44 | """Calculate a matrix of tanimoto distance between feature matrix a and b. 45 | 46 | Tanimoto distance is defined as 1-similarity. 47 | 48 | Parameters 49 | ---------- 50 | matrix_a: sparse.csr_matrix 51 | Feature matrix A. 52 | matrix_b: sparse.csr_matrix 53 | Feature matrix B. 54 | 55 | Returns 56 | ------- 57 | npt.NDArray[np.float64] 58 | Matrix of similarity values between instances of A (rows/first dim) , and instances of B (columns/second dim). 59 | 60 | """ 61 | return 1 - tanimoto_similarity_sparse(matrix_a, matrix_b) # type: ignore 62 | 63 | 64 | def self_tanimoto_similarity( 65 | matrix_a: Union[sparse.csr_matrix, npt.NDArray[np.int_]], 66 | ) -> npt.NDArray[np.float64]: 67 | """Calculate a matrix of tanimoto similarity between feature matrix a and itself. 68 | 69 | Parameters 70 | ---------- 71 | matrix_a: Union[sparse.csr_matrix, npt.NDArray[np.int_]] 72 | Feature matrix. 73 | 74 | Raises 75 | ------ 76 | TypeError 77 | If the matrix is not a sparse matrix or a numpy array. 78 | 79 | Returns 80 | ------- 81 | npt.NDArray[np.float64] 82 | Square matrix of similarity values between all instances in the matrix. 
83 | 84 | """ 85 | if isinstance(matrix_a, np.ndarray): 86 | sparse_matrix = sparse.csr_matrix(matrix_a) 87 | elif isinstance(matrix_a, sparse.csr_matrix): 88 | sparse_matrix = matrix_a 89 | else: 90 | raise TypeError(f"Unsupported type: {type(matrix_a)}") 91 | return tanimoto_similarity_sparse(sparse_matrix, sparse_matrix) 92 | 93 | 94 | def self_tanimoto_distance( 95 | matrix_a: Union[sparse.csr_matrix, npt.NDArray[np.int_]], 96 | ) -> npt.NDArray[np.float64]: 97 | """Calculate a matrix of tanimoto distance between feature matrix a and itself. 98 | 99 | Parameters 100 | ---------- 101 | matrix_a: Union[sparse.csr_matrix, npt.NDArray[np.int_]] 102 | Feature matrix. 103 | 104 | Returns 105 | ------- 106 | npt.NDArray[np.float64] 107 | Square matrix of similarity values between all instances in the matrix. 108 | 109 | """ 110 | return 1 - self_tanimoto_similarity(matrix_a) # type: ignore 111 | -------------------------------------------------------------------------------- /molpipeline/utils/logging.py: -------------------------------------------------------------------------------- 1 | """Logging helper functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import timeit 6 | from collections.abc import Generator 7 | from contextlib import contextmanager 8 | 9 | from loguru import logger 10 | 11 | 12 | def _message_with_time(source: str, message: str, time: float) -> str: 13 | """Create one line message for logging purposes. 14 | 15 | Adapted from sklearn's function to stay consistent with the logging style: 16 | https://github.com/scikit-learn/scikit-learn/blob/e16a6ddebd527e886fc22105710ee20ce255f9f0/sklearn/utils/_user_interface.py 17 | 18 | Parameters 19 | ---------- 20 | source : str 21 | String indicating the source or the reference of the message. 22 | message : str 23 | Short message. 24 | time : float 25 | Time in seconds. 26 | 27 | Returns 28 | ------- 29 | str 30 | Message with elapsed time. 31 | """ 32 | start_message = f"[{source}] " 33 | 34 | # adapted from joblib.logger.short_format_time without the Windows -.1s 35 | # adjustment 36 | if time > 60: 37 | time_str = f"{(time / 60):4.1f}min" 38 | else: 39 | time_str = f" {time:5.1f}s" 40 | 41 | end_message = f" {message}, total={time_str}" 42 | dots_len = 70 - len(start_message) - len(end_message) 43 | return f"{start_message}{dots_len * '.'}{end_message}" 44 | 45 | 46 | @contextmanager 47 | def print_elapsed_time( 48 | source: str, message: str | None = None, use_logger: bool = False 49 | ) -> Generator[None, None, None]: 50 | """Log elapsed time to stdout when the context is exited. 51 | 52 | Adapted from sklearn's function to stay consistent with the logging style: 53 | https://github.com/scikit-learn/scikit-learn/blob/e16a6ddebd527e886fc22105710ee20ce255f9f0/sklearn/utils/_user_interface.py 54 | 55 | Parameters 56 | ---------- 57 | source : str 58 | String indicating the source or the reference of the message. 59 | message : str, default=None 60 | Short message. If None, nothing will be printed. 61 | use_logger : bool, default=False 62 | If True, the message will be logged using the logger. 63 | 64 | Returns 65 | ------- 66 | context_manager 67 | Prints elapsed time upon exit if verbose. 
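    Examples
    --------
    Illustrative sketch; on exit this prints a line like
    ``[my_source] .............................. done, total=   0.0s``:

    >>> with print_elapsed_time("my_source", "done"):  # doctest: +SKIP
    ...     pass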
68 | 69 | """ 70 | if message is None: 71 | yield 72 | else: 73 | start = timeit.default_timer() 74 | yield 75 | message_to_print = _message_with_time( 76 | source, message, timeit.default_timer() - start 77 | ) 78 | 79 | if use_logger: 80 | logger.info(message_to_print) 81 | else: 82 | print(message_to_print) 83 | -------------------------------------------------------------------------------- /molpipeline/utils/matrices.py: -------------------------------------------------------------------------------- 1 | """Functions to handle sparse matrices.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Iterable 6 | 7 | from scipy import sparse 8 | 9 | 10 | def sparse_from_index_value_dicts( 11 | row_index_lists: Iterable[dict[int, int]], n_columns: int 12 | ) -> sparse.csr_matrix: 13 | """Create a sparse matrix from list of dicts. 14 | 15 | Each dict represents one row. 16 | Keys in dictionary correspond to colum index, values represent values of column. 17 | 18 | Parameters 19 | ---------- 20 | row_index_lists: Iterable[dict[int, int]] 21 | Iterable of dicts of which each holds column positions and values. 22 | n_columns: int 23 | Total number of columns 24 | 25 | Returns 26 | ------- 27 | sparse.csr_matrix 28 | Has shape (len(row_index_lists), n_columns). 29 | """ 30 | data: list[int] = [] 31 | row_positions: list[int] = [] 32 | col_positions: list[int] = [] 33 | row_idx = -1 34 | for row_idx, row_dict in enumerate(row_index_lists): 35 | data.extend(row_dict.values()) 36 | col_positions.extend(row_dict.keys()) 37 | row_positions.extend([row_idx] * len(row_dict)) 38 | if row_idx == -1: 39 | return sparse.csr_matrix((0, n_columns)) 40 | 41 | return sparse.csr_matrix( 42 | (data, (row_positions, col_positions)), shape=(row_idx + 1, n_columns) 43 | ) 44 | 45 | 46 | def are_equal(matrix_a: sparse.csr_matrix, matrix_b: sparse.csr_matrix) -> bool: 47 | """Compare if any element is not equal, as this is more efficient. 48 | 49 | Parameters 50 | ---------- 51 | matrix_a: sparse.csr_matrix 52 | Matrix A to compare. 53 | matrix_b: sparse.csr_matrix 54 | Matrix B to compare. 55 | 56 | Returns 57 | ------- 58 | bool 59 | Whether the matrices are equal or not. 
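    Examples
    --------
    Illustrative sketch:

    >>> from scipy import sparse
    >>> mat = sparse.csr_matrix([[1, 0], [0, 1]])
    >>> are_equal(mat, mat.copy())
    True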
60 | """ 61 | is_unequal_matrix = matrix_a != matrix_b 62 | number_unequal_elements = int(is_unequal_matrix.nnz) 63 | return number_unequal_elements == 0 64 | -------------------------------------------------------------------------------- /molpipeline/utils/molpipeline_types.py: -------------------------------------------------------------------------------- 1 | """Definition of types used in molpipeline.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Sequence 6 | from numbers import Number 7 | from typing import ( 8 | Any, 9 | Literal, 10 | Optional, 11 | Protocol, 12 | TypeAlias, 13 | TypeVar, 14 | Union, 15 | ) 16 | 17 | try: 18 | from typing import Self # type: ignore[attr-defined] 19 | except ImportError: 20 | from typing_extensions import Self 21 | 22 | import numpy as np 23 | import numpy.typing as npt 24 | 25 | from molpipeline.abstract_pipeline_elements.core import ( 26 | ABCPipelineElement, 27 | OptionalMol, 28 | RDKitMol, 29 | ) 30 | 31 | __all__ = [ 32 | "AnyElement", 33 | "AnyNumpyElement", 34 | "AnyPredictor", 35 | "AnySklearnEstimator", 36 | "AnySklearnEstimator", 37 | "AnyStep", 38 | "AnyTransformer", 39 | "Number", 40 | "OptionalMol", 41 | "RDKitMol", 42 | ] 43 | # One liner type definitions 44 | 45 | AnyNumpyElement = TypeVar("AnyNumpyElement", bound=np.generic) 46 | 47 | _T = TypeVar("_T") 48 | _NT = TypeVar("_NT", bound=np.generic) 49 | TypeFixedVarSeq = TypeVar("TypeFixedVarSeq", bound=Sequence[_T] | npt.NDArray[_NT]) # type: ignore 50 | AnyVarSeq = TypeVar("AnyVarSeq", bound=Sequence[Any] | npt.NDArray[Any]) 51 | 52 | FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] 53 | IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]] 54 | 55 | # IntOrIntCountRange for Typing of count ranges 56 | # - a single int for an exact value match 57 | # - a range given as a tuple with a lower and upper bound 58 | # - both limits are optional 59 | IntOrIntCountRange: TypeAlias = Union[int, IntCountRange] 60 | 61 | 62 | class AnySklearnEstimator(Protocol): 63 | """Protocol for sklearn estimators.""" 64 | 65 | def get_params(self, deep: bool = True) -> dict[str, Any]: 66 | """Get parameters for this estimator. 67 | 68 | Parameters 69 | ---------- 70 | deep: bool 71 | If True, will return the parameters for this estimator. 72 | 73 | Returns 74 | ------- 75 | dict[str, Any] 76 | Parameter names mapped to their values. 77 | """ 78 | 79 | def set_params(self, **params: Any) -> Self: 80 | """Set the parameters of this estimator. 81 | 82 | Parameters 83 | ---------- 84 | params: Any 85 | Estimator parameters. 86 | 87 | Returns 88 | ------- 89 | Self 90 | Estimator with updated parameters. 91 | """ 92 | 93 | def fit( 94 | self, 95 | X: npt.NDArray[Any], # pylint: disable=invalid-name 96 | y: npt.NDArray[Any] | None, 97 | **fit_params: Any, 98 | ) -> Self: 99 | """Fit the model with X. 100 | 101 | Parameters 102 | ---------- 103 | X: npt.NDArray[Any] 104 | Model input. 105 | y: npt.NDArray[Any] | None 106 | Target values. 107 | fit_params: Any 108 | Additional parameters for fitting. 109 | 110 | 111 | Returns 112 | ------- 113 | Self 114 | Fitted estimator. 115 | """ 116 | 117 | 118 | class AnyPredictor(AnySklearnEstimator, Protocol): 119 | """Protocol for predictors.""" 120 | 121 | def fit_predict( 122 | self, 123 | X: npt.NDArray[Any], # pylint: disable=invalid-name 124 | y: npt.NDArray[Any] | None, 125 | **fit_params: Any, 126 | ) -> npt.NDArray[Any]: 127 | """Fit the model with X and return predictions. 
128 | 
129 |         Parameters
130 |         ----------
131 |         X: npt.NDArray[Any]
132 |             Model input.
133 |         y: npt.NDArray[Any] | None
134 |             Target values.
135 |         fit_params: Any
136 |             Additional parameters for fitting.
137 | 
138 |         Returns
139 |         -------
140 |         npt.NDArray[Any]
141 |             Predictions.
142 |         """
143 | 
144 | 
145 | class AnyTransformer(AnySklearnEstimator, Protocol):
146 |     """Protocol for transformers."""
147 | 
148 |     def fit_transform(
149 |         self,
150 |         X: npt.NDArray[Any],  # pylint: disable=invalid-name
151 |         y: npt.NDArray[Any] | None,
152 |         **fit_params: Any,
153 |     ) -> npt.NDArray[Any]:
154 |         """Fit the model with X and return the transformed array.
155 | 
156 |         Parameters
157 |         ----------
158 |         X: npt.NDArray[Any]
159 |             Model input.
160 |         y: npt.NDArray[Any] | None
161 |             Target values.
162 |         fit_params: Any
163 |             Additional parameters for fitting.
164 | 
165 | 
166 |         Returns
167 |         -------
168 |         npt.NDArray[Any]
169 |             Transformed array.
170 |         """
171 | 
172 |     def transform(
173 |         self,
174 |         X: npt.NDArray[Any],  # pylint: disable=invalid-name
175 |         **params: Any,
176 |     ) -> npt.NDArray[Any]:
177 |         """Transform and return X according to object protocol.
178 | 
179 |         Parameters
180 |         ----------
181 |         X: npt.NDArray[Any]
182 |             Model input.
183 |         params: Any
184 |             Additional parameters for transforming.
185 | 
186 |         Returns
187 |         -------
188 |         npt.NDArray[Any]
189 |             Transformed array.
190 |         """
191 | 
192 | 
193 | AnyElement = Union[
194 |     AnyTransformer, AnyPredictor, ABCPipelineElement, Literal["passthrough"]
195 | ]
196 | AnyStep = tuple[str, AnyElement]
197 | 
-------------------------------------------------------------------------------- /molpipeline/utils/multi_proc.py: --------------------------------------------------------------------------------
1 | """Utility functions for multiprocessing."""
2 | 
3 | from __future__ import annotations
4 | 
5 | import multiprocessing
6 | import warnings
7 | 
8 | 
9 | def check_available_cores(n_requested_cores: int) -> int:
10 |     """Compare number of requested cores with available cores and return a (corrected) number.
11 | 
12 |     Parameters
13 |     ----------
14 |     n_requested_cores: int
15 |         Number of requested cores.
16 | 
17 |     Raises
18 |     ------
19 |     TypeError
20 |         If n_requested_cores is not an integer.
21 | 
22 |     Returns
23 |     -------
24 |     int
25 |         Number of used cores.
26 | 
27 |     """
28 |     if not isinstance(n_requested_cores, int):
29 |         raise TypeError(f"Not an integer: {n_requested_cores}")
30 |     try:
31 |         n_available_cores = multiprocessing.cpu_count()
32 |     except NotImplementedError:
33 |         warnings.warn(
34 |             "Cannot determine the number of available cores. Falling back to single core!",
35 |             stacklevel=2,
36 |         )
37 |         return 1
38 | 
39 |     if n_requested_cores > n_available_cores:
40 |         warnings.warn(
41 |             "Requested more cores than available. Using maximum number of cores!",
42 |             stacklevel=2,
43 |         )
44 |         return n_available_cores
45 |     if n_requested_cores < 0:
46 |         return n_available_cores
47 | 
48 |     return n_requested_cores
49 | 
-------------------------------------------------------------------------------- /molpipeline/utils/substructure_handling.py: --------------------------------------------------------------------------------
1 | """Classes for handling substructures.
2 | 
3 | This is only relevant for explainable AI, where atoms need to be mapped to features.
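Example (illustrative sketch): a radius-1 environment around atom 1 of ethanol
covers that atom and its direct neighbors.

>>> from rdkit import Chem
>>> mol = Chem.MolFromSmiles("CCO")
>>> env = CircularAtomEnvironment.from_mol(mol, central_atom_index=1, radius=1)
>>> sorted(env.environment_atoms)
[0, 1, 2]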
4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | from rdkit import Chem 9 | 10 | 11 | # pylint: disable=R0903 12 | class AtomEnvironment: 13 | """A Class to store environment-information for fingerprint features.""" 14 | 15 | def __init__(self, environment_atoms: set[int]): 16 | """Initialize AtomEnvironment. 17 | 18 | Parameters 19 | ---------- 20 | environment_atoms: set[int] 21 | Indices of atoms encoded by environment. 22 | """ 23 | self.environment_atoms = environment_atoms # set of all atoms within radius 24 | 25 | 26 | # pylint: disable=R0903 27 | class CircularAtomEnvironment(AtomEnvironment): 28 | """A Class to store environment-information for morgan-fingerprint features.""" 29 | 30 | def __init__(self, central_atom: int, radius: int, environment_atoms: set[int]): 31 | """Initialize CircularAtomEnvironment. 32 | 33 | Parameters 34 | ---------- 35 | central_atom: int 36 | Index of central atom in circular fingerprint. 37 | radius: int 38 | Radius of feature. 39 | environment_atoms: set[int] 40 | All indices of atoms within radius of central atom. 41 | """ 42 | super().__init__(environment_atoms) 43 | self.central_atom = central_atom 44 | self.radius = radius 45 | 46 | @classmethod 47 | def from_mol( 48 | cls, mol: Chem.Mol, central_atom_index: int, radius: int 49 | ) -> CircularAtomEnvironment: 50 | """Generate class from mol, using location (central_atom_index) and the radius. 51 | 52 | Parameters 53 | ---------- 54 | mol: Chem.Mol 55 | Molecule from which the environment is derived. 56 | central_atom_index: int 57 | Index of central atom in feature. 58 | radius: int 59 | Radius of feature. 60 | 61 | Returns 62 | ------- 63 | CircularAtomEnvironment 64 | Encoded the atoms which are within the radius of the central atom and are part of the feature. 65 | """ 66 | if radius == 0: 67 | return CircularAtomEnvironment( 68 | central_atom_index, radius, {central_atom_index} 69 | ) 70 | 71 | env = Chem.FindAtomEnvironmentOfRadiusN(mol, radius, central_atom_index) 72 | amap: dict[int, int] = {} 73 | _ = Chem.PathToSubmol(mol, env, atomMap=amap) 74 | env_atoms = amap.keys() 75 | return CircularAtomEnvironment(central_atom_index, radius, set(env_atoms)) 76 | -------------------------------------------------------------------------------- /molpipeline/utils/value_checks.py: -------------------------------------------------------------------------------- 1 | """Module for checking values.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | __all__ = ["get_length", "is_empty"] 8 | 9 | 10 | def is_empty(value: Any) -> bool: 11 | """Check if value is empty. 12 | 13 | Parameters 14 | ---------- 15 | value: Any 16 | Value to be checked. 17 | 18 | Returns 19 | ------- 20 | bool 21 | True if value is empty, False otherwise. 22 | """ 23 | if get_length(value) == 0: 24 | return True 25 | return False 26 | 27 | 28 | def get_length(values: Any) -> int: 29 | """Get the length of the values as given by the shape or len attribute. 30 | 31 | Parameters 32 | ---------- 33 | values: Any 34 | Values to be checked. 35 | 36 | Raises 37 | ------ 38 | TypeError 39 | If values does not have a shape or len attribute. 40 | 41 | Returns 42 | ------- 43 | int 44 | Length of the values. 
45 | """ 46 | if hasattr(values, "shape"): 47 | return values.shape[0] 48 | if hasattr(values, "__len__"): 49 | return len(values) 50 | raise TypeError("Values must have a shape or len attribute.") 51 | -------------------------------------------------------------------------------- /molpipeline/utils/value_conversions.py: -------------------------------------------------------------------------------- 1 | """Module for utilities converting values.""" 2 | 3 | from collections.abc import Sequence 4 | from typing import TypeVar 5 | 6 | VarNumber = TypeVar("VarNumber", float | None, int | None) 7 | 8 | 9 | def assure_range(value: VarNumber | Sequence[VarNumber]) -> tuple[VarNumber, VarNumber]: 10 | """Assure that the value is defining a range. 11 | 12 | Integers or floats are converted to a range with the same value for both 13 | 14 | Parameters 15 | ---------- 16 | value: VarNumber | Sequence[VarNumber] 17 | Count value. Can be a single int | float or a Sequence of two values. 18 | 19 | Raises 20 | ------ 21 | ValueError 22 | If the count is a sequence of length other than 2. 23 | TypeError 24 | If the count is not an int or a sequence. 25 | 26 | Returns 27 | ------- 28 | IntCountRange 29 | Tuple of count values. 30 | 31 | """ 32 | if isinstance(value, (float, int)): 33 | return value, value 34 | if isinstance(value, Sequence): 35 | range_tuple = tuple(value) 36 | if len(range_tuple) != 2: # noqa: PLR2004 37 | raise ValueError(f"Expected a sequence of length 2, got: {range_tuple}") 38 | return range_tuple 39 | raise TypeError(f"Got unexpected type: {type(value)}") 40 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | target-version = "py310" 2 | [lint] 3 | preview = true 4 | select = [ 5 | "A", # flake8-buildins 6 | "ANN", # flake8-annotations 7 | "ARG", # flake8-unused-arguments 8 | "B", # flake8-bugbear 9 | "C4", # flake8-comprehensions 10 | "COM", # flake8-commas 11 | "D", # pydocstyle 12 | "DOC", # pydoclint 13 | "E", # pycodestyle - errors 14 | "ERA", # eradicate 15 | "F", # pyflakes 16 | "FA", # flake8-funny-annotations 17 | "FIX", # flake8-fixme 18 | "FLY", # flake8-flynt 19 | "FURB", # refurb 20 | "I", # isort 21 | "ISC", # flake8-implicit-str-concat 22 | "ICN", # flake8-import-conventions 23 | "INP", # flake8-no-pep420 24 | "N", # pep8-naming 25 | "NPY", # numpy-specific rules 26 | "PD", # pandas-vet 27 | "PERF", # perflint 28 | "PGH", # pygrep-hook 29 | "PIE", # flake8-pie 30 | "PL", # pylint 31 | "PTH", # flake8-use-pathlib 32 | "PYI", # flake8-pyi 33 | "Q", # flake8-quotes 34 | "RET", # flake8-return 35 | "RUF", # ruff specific rules 36 | "S", # flake8-bandit 37 | "SIM", # flake8-simplify 38 | "SLF", # flake8-self 39 | "TC", # flake8-type-checking 40 | "TD", # flake8-todo 41 | "T10", # flake8-debugger 42 | "T20", # flake8-print 43 | "TID", # flake8-tidy-imports 44 | "UP", # pyupgrade 45 | "W", # pycodestyle - warnings 46 | ] 47 | ignore = [ 48 | "ANN401", # Allow typing.Any 49 | "ANN204", # Missing return type annotation for special method 50 | "D203", # 1 blank line required before class docstring 51 | "D213", # blank-line-before-class-docstring 52 | "PLR0913", # too-many-arguments 53 | "PGH003", # Blanket type ignore for types 54 | "PLW2901", # Redefined loop variable 55 | "S311", # suspicious-non-cryptographic-random-usage 56 | ] 57 | pylint = {max-positional-args=10 } 58 | [lint.per-file-ignores] 59 | "*.ipynb" = [ 60 | "PLE1142", 61 | "F704", 62 | 
"T201" 63 | ] 64 | "test_*" = [ 65 | "S101", 66 | "S404", 67 | "S603", 68 | "S607" 69 | ] 70 | -------------------------------------------------------------------------------- /test_extras/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for testing the extras packages.""" 2 | -------------------------------------------------------------------------------- /test_extras/test_chemprop/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize the unit tests for the chemprop wrappers.""" 2 | -------------------------------------------------------------------------------- /test_extras/test_chemprop/chemprop_test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Functions repeatedly used in tests for Chemprop models.""" 2 | -------------------------------------------------------------------------------- /test_extras/test_chemprop/chemprop_test_utils/compare_models.py: -------------------------------------------------------------------------------- 1 | """Functions for comparing chemprop models.""" 2 | 3 | from collections.abc import Sequence 4 | from unittest import TestCase 5 | 6 | import torch 7 | from chemprop.nn.loss import LossFunction 8 | from lightning.pytorch.accelerators import Accelerator 9 | from lightning.pytorch.profilers.base import PassThroughProfiler 10 | from sklearn.base import BaseEstimator 11 | from torch import nn 12 | 13 | 14 | def compare_params( 15 | test_case: TestCase, model_a: BaseEstimator, model_b: BaseEstimator 16 | ) -> None: 17 | """Compare the parameters of two models. 18 | 19 | Parameters 20 | ---------- 21 | test_case : TestCase 22 | The test case for which to raise the assertion. 23 | model_a : BaseEstimator 24 | The first model. 25 | model_b : BaseEstimator 26 | The second model. 
27 | """ 28 | model_a_params = model_a.get_params(deep=True) 29 | model_b_params = model_b.get_params(deep=True) 30 | test_case.assertSetEqual(set(model_a_params.keys()), set(model_b_params.keys())) 31 | for param_name, param_a in model_a_params.items(): 32 | param_b = model_b_params[param_name] 33 | test_case.assertEqual(param_a.__class__, param_b.__class__) 34 | if hasattr(param_a, "get_params"): 35 | test_case.assertTrue(hasattr(param_b, "get_params")) 36 | test_case.assertNotEqual(id(param_a), id(param_b)) 37 | elif isinstance(param_a, LossFunction): 38 | test_case.assertEqual( 39 | param_a.state_dict()["task_weights"], 40 | param_b.state_dict()["task_weights"], 41 | ) 42 | test_case.assertEqual(type(param_a), type(param_b)) 43 | elif isinstance(param_a, (nn.Identity, Accelerator, PassThroughProfiler)): 44 | test_case.assertEqual(type(param_a), type(param_b)) 45 | elif isinstance(param_a, torch.Tensor): 46 | test_case.assertTrue( 47 | torch.equal(param_a, param_b), f"Test failed for {param_name}" 48 | ) 49 | elif param_name == "lightning_trainer__callbacks": 50 | test_case.assertIsInstance(param_b, Sequence) 51 | for i, callback in enumerate(param_a): 52 | test_case.assertIsInstance(callback, type(param_b[i])) 53 | else: 54 | test_case.assertEqual(param_a, param_b, f"Test failed for {param_name}") 55 | -------------------------------------------------------------------------------- /test_extras/test_chemprop/chemprop_test_utils/default_models.py: -------------------------------------------------------------------------------- 1 | """Functions for creating default chemprop models.""" 2 | 3 | from typing import Any 4 | 5 | from molpipeline.estimators.chemprop import ChempropModel, ChempropNeuralFP 6 | from molpipeline.estimators.chemprop.component_wrapper import ( 7 | MPNN, 8 | BinaryClassificationFFN, 9 | BondMessagePassing, 10 | SumAggregation, 11 | ) 12 | 13 | 14 | def get_binary_classification_mpnn() -> MPNN: 15 | """Get a Chemprop model for binary classification. 16 | 17 | Returns 18 | ------- 19 | ChempropModel 20 | The Chemprop model. 21 | """ 22 | binary_clf_ffn = BinaryClassificationFFN() 23 | aggregate = SumAggregation() 24 | bond_message_passing = BondMessagePassing() 25 | mpnn = MPNN( 26 | message_passing=bond_message_passing, 27 | agg=aggregate, 28 | predictor=binary_clf_ffn, 29 | ) 30 | return mpnn 31 | 32 | 33 | def get_neural_fp_encoder( 34 | init_kwargs: dict[str, Any] | None = None, 35 | ) -> ChempropNeuralFP: 36 | """Get the Chemprop model. 37 | 38 | Parameters 39 | ---------- 40 | init_kwargs : dict[str, Any], optional 41 | Additional keyword arguments to pass to `ChempropNeuralFP` during initialization. 42 | 43 | Returns 44 | ------- 45 | ChempropNeuralFP 46 | The Chemprop model. 47 | """ 48 | mpnn = get_binary_classification_mpnn() 49 | init_kwargs = init_kwargs or {} 50 | chemprop_model = ChempropNeuralFP( 51 | model=mpnn, lightning_trainer__accelerator="cpu", **init_kwargs 52 | ) 53 | return chemprop_model 54 | 55 | 56 | def get_chemprop_model_binary_classification_mpnn() -> ChempropModel: 57 | """Get the Chemprop model. 58 | 59 | Returns 60 | ------- 61 | ChempropModel 62 | The Chemprop model. 
63 | """ 64 | mpnn = get_binary_classification_mpnn() 65 | chemprop_model = ChempropModel(model=mpnn, lightning_trainer__accelerator="cpu") 66 | return chemprop_model 67 | -------------------------------------------------------------------------------- /test_extras/test_chemprop/test_abstract.py: -------------------------------------------------------------------------------- 1 | """Tests for the abstract class ABCChemprop.""" 2 | 3 | import unittest 4 | 5 | from molpipeline.estimators.chemprop.abstract import ABCChemprop 6 | 7 | 8 | class TestABCChemprop(unittest.TestCase): 9 | """Test static methods of the Chemprop model.""" 10 | 11 | def test_filter_params_callback(self) -> None: 12 | """Test the filter_params_callback method.""" 13 | dummy_params = { 14 | "callback_modelckpt__monitor": "val_loss", 15 | "other__param": "value", 16 | } 17 | # pylint: disable=protected-access 18 | other_params, callback_params = ABCChemprop._filter_params( 19 | dummy_params, "callback_modelckpt" 20 | ) 21 | # pylint: enable=protected-access 22 | self.assertEqual(callback_params, {"monitor": "val_loss"}) 23 | self.assertEqual(other_params, {"other__param": "value"}) 24 | 25 | def test_filter_params_trainer(self) -> None: 26 | """Test the filter_params_trainer method.""" 27 | dummy_params = { 28 | "lightning_trainer__max_epochs": 50, 29 | "other__param": "value", 30 | } 31 | # pylint: disable=protected-access 32 | other_params, trainer_params = ABCChemprop._filter_params( 33 | dummy_params, "lightning_trainer" 34 | ) 35 | # pylint: enable=protected-access 36 | self.assertEqual(trainer_params, {"max_epochs": 50}) 37 | self.assertEqual(other_params, {"other__param": "value"}) 38 | -------------------------------------------------------------------------------- /test_extras/test_chemprop/test_component_wrapper.py: -------------------------------------------------------------------------------- 1 | """Test Chemprop component wrapper.""" 2 | 3 | import unittest 4 | 5 | from chemprop.nn.loss import LossFunction 6 | from sklearn.base import clone 7 | from torch import nn 8 | 9 | from molpipeline.estimators.chemprop.component_wrapper import ( 10 | MPNN, 11 | BinaryClassificationFFN, 12 | BondMessagePassing, 13 | MeanAggregation, 14 | SumAggregation, 15 | ) 16 | 17 | 18 | class BinaryClassificationFFNTest(unittest.TestCase): 19 | """Test the BinaryClassificationFFN class.""" 20 | 21 | def test_get_set_params(self) -> None: 22 | """Test the get_params and set_params methods.""" 23 | binary_clf_ffn = BinaryClassificationFFN() 24 | orig_params = binary_clf_ffn.get_params(deep=True) 25 | new_params = { 26 | "activation": "relu", 27 | "dropout": 0.5, 28 | "hidden_dim": 400, 29 | "input_dim": 300, 30 | "n_layers": 2, 31 | "n_tasks": 1, 32 | } 33 | # Check setting new parameters 34 | binary_clf_ffn.set_params(**new_params) 35 | model_params = binary_clf_ffn.get_params(deep=True) 36 | for param_name, param in new_params.items(): 37 | self.assertEqual(param, model_params[param_name]) 38 | 39 | # Check setting original parameters 40 | binary_clf_ffn.set_params(**orig_params) 41 | model_params = binary_clf_ffn.get_params(deep=True) 42 | for param_name, param in orig_params.items(): 43 | self.assertEqual(param, model_params[param_name]) 44 | 45 | 46 | class BondMessagePassingTest(unittest.TestCase): 47 | """Test the BondMessagePassing class.""" 48 | 49 | def test_get_set_params(self) -> None: 50 | """Test the get_params and set_params methods.""" 51 | bond_message_passing = BondMessagePassing() 52 | orig_params = 
bond_message_passing.get_params(deep=True)
53 | new_params = {
54 | "activation": "relu",
55 | "bias": True,
56 | "d_e": 14,
57 | "d_h": 300,
58 | "d_v": 133,
59 | "d_vd": None,
60 | "depth": 4,
61 | "dropout_rate": 0.5,
62 | "undirected": False,
63 | }
64 | # Check setting new parameters
65 | bond_message_passing.set_params(**new_params)
66 | model_params = bond_message_passing.get_params(deep=True)
67 | for param_name, param in new_params.items():
68 | self.assertEqual(param, model_params[param_name])
69 | 
70 | # Check setting original parameters
71 | bond_message_passing.set_params(**orig_params)
72 | model_params = bond_message_passing.get_params(deep=True)
73 | for param_name, param in orig_params.items():
74 | self.assertEqual(param, model_params[param_name])
75 | 
76 | 
77 | class MPNNTest(unittest.TestCase):
78 | """Test the MPNN class."""
79 | 
80 | def test_get_set_params(self) -> None:
81 | """Test the get_params and set_params methods."""
82 | mpnn1 = MPNN(
83 | message_passing=BondMessagePassing(depth=2),
84 | agg=SumAggregation(),
85 | predictor=BinaryClassificationFFN(n_layers=1),
86 | )
87 | params1 = mpnn1.get_params(deep=True)
88 | 
89 | mpnn2 = MPNN(
90 | message_passing=BondMessagePassing(depth=1),
91 | agg=MeanAggregation(),
92 | predictor=BinaryClassificationFFN(n_layers=4),
93 | )
94 | mpnn2.set_params(**params1)
95 | for param_name, param in mpnn1.get_params(deep=True).items():
96 | param2 = mpnn2.get_params(deep=True)[param_name]
97 | # Nested components are cloned, so they compare unequal by identity;
98 | # they must, however, be of the same class. Since all nested parameters
99 | # are listed flat in the params dicts, two objects with identical flat
100 | # param dicts are configured identically, which is checked below.
101 | if hasattr(param, "get_params"):
102 | self.assertEqual(param.__class__, param2.__class__)
103 | else:
104 | self.assertEqual(param, param2)
105 | 
106 | def test_clone(self) -> None:
107 | """Test the clone method."""
108 | mpnn = MPNN(
109 | message_passing=BondMessagePassing(),
110 | agg=SumAggregation(),
111 | predictor=BinaryClassificationFFN(),
112 | )
113 | mpnn_clone = clone(mpnn)
114 | for param_name, param in mpnn.get_params(deep=True).items():
115 | clone_param = mpnn_clone.get_params(deep=True)[param_name]
116 | if hasattr(param, "get_params"):
117 | self.assertEqual(param.__class__, clone_param.__class__)
118 | elif isinstance(param, LossFunction):
119 | self.assertEqual(
120 | param.state_dict()["task_weights"],
121 | clone_param.state_dict()["task_weights"],
122 | )
123 | self.assertEqual(type(param), type(clone_param))
124 | elif isinstance(param, nn.Identity):
125 | self.assertEqual(type(param), type(clone_param))
126 | else:
127 | self.assertEqual(param, clone_param)
128 | 
129 | 
130 | if __name__ == "__main__":
131 | unittest.main()
132 | 
-------------------------------------------------------------------------------- /test_extras/test_chemprop/test_lightning_wrapper.py: --------------------------------------------------------------------------------
1 | """Module for testing if the lightning wrapper functions work as intended."""
2 | 
3 | import unittest
4 | 
5 | import lightning as pl
6 | 
7 | from molpipeline.estimators.chemprop.lightning_wrapper import (
8 | get_non_default_params_trainer,
9 | get_params_trainer,
10 | )
11 | 
12 | 
13 | class TestLightningWrapper(unittest.TestCase):
14 | """Test the lightning wrapper functions.
15 | 
16 | Notes
17 | -----
18 | These tests are not exhaustive.
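They only check that the ``deterministic`` flag round-trips through the
trainer parameter getters; other ``pl.Trainer`` arguments are assumed to
behave analogously.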
19 | """ 20 | 21 | def test_setting_deterministic(self) -> None: 22 | """Test setting the deterministic parameter.""" 23 | trainer_params = get_params_trainer(pl.Trainer(deterministic=True)) 24 | self.assertTrue(trainer_params["deterministic"]) 25 | 26 | trainer_params = get_params_trainer(pl.Trainer(deterministic=False)) 27 | self.assertFalse(trainer_params["deterministic"]) 28 | 29 | trainer_params = get_non_default_params_trainer(pl.Trainer(deterministic=True)) 30 | self.assertIn("deterministic", trainer_params) 31 | self.assertTrue(trainer_params["deterministic"]) 32 | 33 | trainer_params = get_non_default_params_trainer(pl.Trainer(deterministic=False)) 34 | # deterministic is by default False and hence will not be listed in the parameters 35 | self.assertNotIn("deterministic", trainer_params) 36 | 37 | 38 | if __name__ == "__main__": 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /test_extras/test_chemprop/test_neural_fingerprint.py: -------------------------------------------------------------------------------- 1 | """Test Chemprop neural fingerprint.""" 2 | 3 | import logging 4 | import unittest 5 | 6 | from sklearn.base import clone 7 | 8 | from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP 9 | from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json 10 | from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params 11 | from test_extras.test_chemprop.chemprop_test_utils.default_models import ( 12 | get_neural_fp_encoder, 13 | ) 14 | 15 | logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING) 16 | 17 | 18 | class TestChempropNeuralFingerprint(unittest.TestCase): 19 | """Test the Chemprop model.""" 20 | 21 | def test_clone(self) -> None: 22 | """Test the clone method.""" 23 | chemprop_fp_encoder = get_neural_fp_encoder() 24 | cloned_encoder = clone(chemprop_fp_encoder) 25 | self.assertIsInstance(cloned_encoder, ChempropNeuralFP) 26 | compare_params(self, chemprop_fp_encoder, cloned_encoder) 27 | 28 | def test_json_serialization(self) -> None: 29 | """Test the to_json and from_json methods.""" 30 | chemprop_fp_encoder = get_neural_fp_encoder() 31 | chemprop_json = recursive_to_json(chemprop_fp_encoder) 32 | chemprop_encoder_copy = recursive_from_json(chemprop_json) 33 | compare_params(self, chemprop_fp_encoder, chemprop_encoder_copy) 34 | 35 | def test_output_type(self) -> None: 36 | """Test the output type.""" 37 | chemprop_fp_encoder = get_neural_fp_encoder() 38 | self.assertEqual(chemprop_fp_encoder.output_type, "float") 39 | 40 | def test_init_with_kwargs(self) -> None: 41 | """Test the __init__ method with kwargs.""" 42 | init_kwargs = {"model__message_passing__depth": 4} 43 | chemprop_fp_encoder = get_neural_fp_encoder(init_kwargs=init_kwargs) 44 | deep_params = chemprop_fp_encoder.get_params(deep=True) 45 | self.assertEqual(deep_params["model__message_passing__depth"], 4) 46 | -------------------------------------------------------------------------------- /test_extras/test_notebooks/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize the unit tests for the notebook test.""" 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize module for all tests.""" 2 | 3 | from pathlib import Path 4 | 5 | 
TEST_DATA_DIR = Path(__file__).parent / "test_data" 6 | -------------------------------------------------------------------------------- /tests/run_tests.py: -------------------------------------------------------------------------------- 1 | """Run all tests.""" 2 | 3 | import unittest 4 | 5 | if __name__ == "__main__": 6 | pipeline_test = unittest.TestLoader().discover(".") 7 | unittest.TextTestRunner(verbosity=2).run(pipeline_test) 8 | -------------------------------------------------------------------------------- /tests/test_data/P86_B_400.sdf.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/MolPipeline/3ab8aa0ebd345b8b2b2b99dd608371f640211754/tests/test_data/P86_B_400.sdf.gz -------------------------------------------------------------------------------- /tests/test_data/molecule_net_bbbp.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/MolPipeline/3ab8aa0ebd345b8b2b2b99dd608371f640211754/tests/test_data/molecule_net_bbbp.tsv.gz -------------------------------------------------------------------------------- /tests/test_data/molecule_net_logd.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/MolPipeline/3ab8aa0ebd345b8b2b2b99dd608371f640211754/tests/test_data/molecule_net_logd.tsv.gz -------------------------------------------------------------------------------- /tests/test_data/multiclass_mock.tsv: -------------------------------------------------------------------------------- 1 | Molecule Label 2 | "CCCCCC" 0 3 | "CCCCCCCO" 1 4 | "CCCC" 0 5 | "CCCN" 2 6 | "CCCCCC" 0 7 | "CCCO" 1 8 | "CCCCC" 0 9 | "CCCCCN" 2 10 | "CC(C)CCC" 0 11 | "CCCCCCO" 1 12 | "CCCCCl" 0 13 | "CCC#N" 2 14 | -------------------------------------------------------------------------------- /tests/test_elements/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | -------------------------------------------------------------------------------- /tests/test_elements/test_any2mol/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | -------------------------------------------------------------------------------- /tests/test_elements/test_any2mol/test_bin2mol.py: -------------------------------------------------------------------------------- 1 | """Unittests for testing conversion of binary string to RDKit molecules.""" 2 | 3 | import unittest 4 | 5 | from rdkit import Chem, rdBase 6 | 7 | from molpipeline.any2mol import BinaryToMol 8 | 9 | # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated 10 | SMILES_ANTIMONY = "[SbH6+3]" 11 | SMILES_BENZENE = "c1ccccc1" 12 | SMILES_CHLOROBENZENE = "Clc1ccccc1" 13 | SMILES_CL_BR = "NC(Cl)(Br)C(=O)O" 14 | SMILES_METAL_AU = "OC[C@H]1OC(S[Au])[C@H](O)[C@@H](O)[C@@H]1O" 15 | 16 | # RDKit mols 17 | MOL_ANTIMONY = Chem.MolFromSmiles(SMILES_ANTIMONY) 18 | MOL_BENZENE = Chem.MolFromSmiles(SMILES_BENZENE) 19 | MOL_CHLOROBENZENE = Chem.MolFromSmiles(SMILES_CHLOROBENZENE) 20 | MOL_CL_BR = Chem.MolFromSmiles(SMILES_CL_BR) 21 | MOL_METAL_AU = Chem.MolFromSmiles(SMILES_METAL_AU) 22 | 23 | 24 | class TestBin2Mol(unittest.TestCase): 25 | """Test case for testing conversion of binary string to molecules.""" 26 | 27 | def test_bin2mol(self) -> None: 28 | """Test molecules can be read from binary string.""" 29 | test_mols 
= [
30 | MOL_ANTIMONY,
31 | MOL_BENZENE,
32 | MOL_CHLOROBENZENE,
33 | MOL_CL_BR,
34 | MOL_METAL_AU,
35 | ]
36 | for mol in test_mols:
37 | bin2mol = BinaryToMol()
38 | transformed_mol = bin2mol.pretransform_single(mol.ToBinary())
39 | log_block = rdBase.BlockLogs()
40 | self.assertEqual(
41 | Chem.MolToInchi(transformed_mol), Chem.MolToInchi(mol)
42 | )
43 | del log_block
44 | 
-------------------------------------------------------------------------------- /tests/test_elements/test_any2mol/test_smiles2mol.py: --------------------------------------------------------------------------------
1 | """Test smiles to mol pipeline element."""
2 | 
3 | import unittest
4 | from typing import Any
5 | 
6 | from molpipeline import Pipeline
7 | from molpipeline.any2mol import SmilesToMol
8 | 
9 | 
10 | class TestSmiles2Mol(unittest.TestCase):
11 | """Test case for testing conversion of SMILES input to molecules."""
12 | 
13 | def test_smiles2mol_explicit_hydrogens(self) -> None:
14 | """Test SMILES reading with and without explicit hydrogens."""
15 | smiles = "C[H]"
16 | 
17 | # test: remove explicit Hs
18 | pipeline = Pipeline(
19 | [
20 | (
21 | "Smiles2Mol",
22 | SmilesToMol(remove_hydrogens=True),
23 | ),
24 | ]
25 | )
26 | mols = pipeline.fit_transform([smiles])
27 | self.assertEqual(len(mols), 1)
28 | self.assertIsNotNone(mols[0])
29 | self.assertEqual(mols[0].GetNumAtoms(), 1)
30 | 
31 | # test: keep explicit Hs
32 | pipeline2 = Pipeline(
33 | [
34 | (
35 | "Smiles2Mol",
36 | SmilesToMol(remove_hydrogens=False),
37 | ),
38 | ]
39 | )
40 | mols2 = pipeline2.fit_transform([smiles])
41 | self.assertEqual(len(mols2), 1)
42 | self.assertIsNotNone(mols2[0])
43 | self.assertEqual(mols2[0].GetNumAtoms(), 2)
44 | 
45 | def test_getter_setter(self) -> None:
46 | """Test getter and setter methods."""
47 | smiles2mol = SmilesToMol(remove_hydrogens=False)
48 | self.assertEqual(smiles2mol.get_params()["remove_hydrogens"], False)
49 | params: dict[str, Any] = {
50 | "remove_hydrogens": True,
51 | }
52 | smiles2mol.set_params(**params)
53 | self.assertEqual(smiles2mol.get_params()["remove_hydrogens"], True)
54 | 
-------------------------------------------------------------------------------- /tests/test_elements/test_mol2any/__init__.py: --------------------------------------------------------------------------------
1 | """Init."""
2 | 
-------------------------------------------------------------------------------- /tests/test_elements/test_mol2any/test_mol2bin.py: --------------------------------------------------------------------------------
1 | """Unittests for testing conversion of molecules to binary string."""
2 | 
3 | import unittest
4 | 
5 | from rdkit import Chem, rdBase
6 | 
7 | from molpipeline import Pipeline
8 | from molpipeline.any2mol import SmilesToMol
9 | from molpipeline.mol2any import MolToBinary
10 | 
11 | # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated
12 | SMILES_ANTIMONY = "[SbH6+3]"
13 | SMILES_BENZENE = "c1ccccc1"
14 | SMILES_CHLOROBENZENE = "Clc1ccccc1"
15 | SMILES_CL_BR = "NC(Cl)(Br)C(=O)O"
16 | SMILES_METAL_AU = "OC[C@H]1OC(S[Au])[C@H](O)[C@@H](O)[C@@H]1O"
17 | 
18 | MOL_ANTIMONY = Chem.MolFromSmiles(SMILES_ANTIMONY)
19 | MOL_BENZENE = Chem.MolFromSmiles(SMILES_BENZENE)
20 | MOL_CHLOROBENZENE = Chem.MolFromSmiles(SMILES_CHLOROBENZENE)
21 | MOL_CL_BR = Chem.MolFromSmiles(SMILES_CL_BR)
22 | MOL_METAL_AU = Chem.MolFromSmiles(SMILES_METAL_AU)
23 | 
24 | 
25 | class TestMol2Binary(unittest.TestCase):
26 | """Test case for testing conversion of molecules to
binary string representation.""" 27 | 28 | def test_mol_to_binary(self) -> None: 29 | """Test if smiles converted correctly to binary string.""" 30 | test_smiles = [ 31 | SMILES_ANTIMONY, 32 | SMILES_BENZENE, 33 | SMILES_CHLOROBENZENE, 34 | SMILES_CL_BR, 35 | SMILES_METAL_AU, 36 | ] 37 | expected_mols = [ 38 | MOL_ANTIMONY, 39 | MOL_BENZENE, 40 | MOL_CHLOROBENZENE, 41 | MOL_CL_BR, 42 | MOL_METAL_AU, 43 | ] 44 | 45 | pipeline = Pipeline( 46 | [ 47 | ("Smiles2Mol", SmilesToMol()), 48 | ("Mol2Binary", MolToBinary()), 49 | ] 50 | ) 51 | log_block = rdBase.BlockLogs() 52 | binary_mols = pipeline.fit_transform(test_smiles) 53 | self.assertEqual(len(test_smiles), len(binary_mols)) 54 | actual_mols = [Chem.Mol(mol) for mol in binary_mols] 55 | self.assertTrue( 56 | all( 57 | Chem.MolToInchi(smiles_mol) == Chem.MolToInchi(original_mol) 58 | for smiles_mol, original_mol in zip(actual_mols, expected_mols) 59 | ) 60 | ) 61 | del log_block 62 | 63 | def test_mol_to_binary_invalid_input(self) -> None: 64 | """Test how invalid input is handled.""" 65 | pipeline = Pipeline( 66 | [ 67 | ("Mol2Binary", MolToBinary()), 68 | ] 69 | ) 70 | 71 | # test empty molecule 72 | binary_mols = pipeline.fit_transform([Chem.MolFromSmiles("")]) 73 | self.assertEqual(len(binary_mols), 1) 74 | self.assertEqual(Chem.MolToSmiles(Chem.Mol(binary_mols[0])), "") 75 | 76 | # test None as input 77 | self.assertRaises(AttributeError, pipeline.fit_transform, [None]) 78 | -------------------------------------------------------------------------------- /tests/test_elements/test_mol2any/test_mol2bool.py: -------------------------------------------------------------------------------- 1 | """Test mol to bool conversion.""" 2 | 3 | import unittest 4 | 5 | from molpipeline import Pipeline 6 | from molpipeline.abstract_pipeline_elements.core import InvalidInstance 7 | from molpipeline.any2mol import AutoToMol 8 | from molpipeline.mol2any import MolToBool 9 | 10 | 11 | class TestMolToBool(unittest.TestCase): 12 | """Unittest for MolToBool.""" 13 | 14 | def test_bool_conversion(self) -> None: 15 | """Test if the invalid instances are converted to bool.""" 16 | mol2bool = MolToBool() 17 | result = mol2bool.transform( 18 | [ 19 | 1, 20 | 2, 21 | InvalidInstance(element_id="test", message="test", element_name="Test"), 22 | 4, 23 | ] 24 | ) 25 | self.assertEqual(result, [True, True, False, True]) 26 | 27 | def test_bool_conversion_pipeline(self) -> None: 28 | """Test if the invalid instances are converted to bool in pipeline.""" 29 | pipeline = Pipeline( 30 | [ 31 | ("auto_to_mol", AutoToMol()), 32 | ("mol2bool", MolToBool()), 33 | ] 34 | ) 35 | result = pipeline.transform(["CC", "CCC", "no%valid~smiles"]) 36 | self.assertEqual(result, [True, True, False]) 37 | -------------------------------------------------------------------------------- /tests/test_elements/test_mol2any/test_mol2inchi.py: -------------------------------------------------------------------------------- 1 | """Unittests for testing conversion of molecules to InChI and InChIKey.""" 2 | 3 | import unittest 4 | 5 | from molpipeline.any2mol import SmilesToMol 6 | from molpipeline.mol2any import MolToInchi, MolToInchiKey 7 | from molpipeline.pipeline import Pipeline 8 | 9 | # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated 10 | SMILES_ANTIMONY = "[SbH6+3]" 11 | SMILES_BENZENE = "c1ccccc1" 12 | SMILES_CHLOROBENZENE = "Clc1ccccc1" 13 | SMILES_CL_BR = "NC(Cl)(Br)C(=O)O" 14 | SMILES_METAL_AU = "OC[C@H]1OC(S[Au])[C@H](O)[C@@H](O)[C@@H]1O" 15 | 16 | 17 | 
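# Illustrative cross-check, not part of the tests below (it assumes the
# pipeline elements wrap RDKit's InChI functions): the pipeline output should
# match calling Chem.MolToInchi(Chem.MolFromSmiles(smiles)) directly for each
# valid input SMILES.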
class TestMol2Inchi(unittest.TestCase): 18 | """Test case for testing conversion of molecules to InChI and InChIKey.""" 19 | 20 | def test_to_inchi(self) -> None: 21 | """Test if smiles converted correctly to inchi string.""" 22 | input_smiles = ["CN(C)CCOC(C1=CC=CC=C1)C1=CC=CC=C1"] 23 | expected_inchis = [ 24 | "InChI=1S/C17H21NO/c1-18(2)13-14-19-17(15-9-5-3-6-10-15)16-11-7-4-8-12-16/h3-12,17H,13-14H2,1-2H3" 25 | ] 26 | pipeline = Pipeline( 27 | [ 28 | ("Smiles2Mol", SmilesToMol()), 29 | ("Mol2Inchi", MolToInchi()), 30 | ] 31 | ) 32 | actual_inchis = pipeline.fit_transform(input_smiles) 33 | self.assertEqual(expected_inchis, actual_inchis) 34 | 35 | def test_to_inchikey(self) -> None: 36 | """Test if smiles is converted correctly to inchikey string.""" 37 | input_smiles = ["CN(C)CCOC(C1=CC=CC=C1)C1=CC=CC=C1"] 38 | expected_inchikeys = ["ZZVUWRFHKOJYTH-UHFFFAOYSA-N"] 39 | 40 | pipeline = Pipeline( 41 | [ 42 | ("Smiles2Mol", SmilesToMol()), 43 | ("Mol2Inchi", MolToInchiKey()), 44 | ], 45 | ) 46 | actual_inchikeys = pipeline.fit_transform(input_smiles) 47 | self.assertEqual(expected_inchikeys, actual_inchikeys) 48 | 49 | 50 | if __name__ == "__main__": 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /tests/test_elements/test_mol2any/test_mol2maccs_key_fingerprint.py: -------------------------------------------------------------------------------- 1 | """Tests for the MolToMACCSFP pipeline element.""" 2 | 3 | from __future__ import annotations 4 | 5 | import unittest 6 | from typing import Any 7 | 8 | import numpy as np 9 | 10 | from molpipeline import Pipeline 11 | from molpipeline.any2mol import SmilesToMol 12 | from molpipeline.mol2any import MolToMACCSFP 13 | 14 | # pylint: disable=duplicate-code 15 | # Similar to test_mol2morgan_fingerprint.py and test_mol2path_fingerprint.py 16 | 17 | test_smiles = [ 18 | "c1ccccc1", 19 | "c1ccccc1C", 20 | "NCCOCCCC(=O)O", 21 | ] 22 | 23 | 24 | class TestMolToMACCSFP(unittest.TestCase): 25 | """Unittest for MolToMACCSFP, which calculates MACCS Key Fingerprints.""" 26 | 27 | def test_can_be_constructed(self) -> None: 28 | """Test if the MolToMACCSFP pipeline element can be constructed.""" 29 | mol_fp = MolToMACCSFP() 30 | mol_fp_copy = mol_fp.copy() 31 | self.assertTrue(mol_fp_copy is not mol_fp) 32 | for key, value in mol_fp.get_params().items(): 33 | self.assertEqual(value, mol_fp_copy.get_params()[key]) 34 | mol_fp_recreated = MolToMACCSFP(**mol_fp.get_params()) 35 | for key, value in mol_fp.get_params().items(): 36 | self.assertEqual(value, mol_fp_recreated.get_params()[key]) 37 | 38 | def test_output_types(self) -> None: 39 | """Test equality of different output_types.""" 40 | smi2mol = SmilesToMol() 41 | sparse_maccs = MolToMACCSFP(return_as="sparse") 42 | dense_maccs = MolToMACCSFP(return_as="dense") 43 | explicit_bit_vect_maccs = MolToMACCSFP(return_as="explicit_bit_vect") 44 | sparse_pipeline = Pipeline( 45 | [ 46 | ("smi2mol", smi2mol), 47 | ("sparse_maccs", sparse_maccs), 48 | ], 49 | ) 50 | dense_pipeline = Pipeline( 51 | [ 52 | ("smi2mol", smi2mol), 53 | ("dense_maccs", dense_maccs), 54 | ], 55 | ) 56 | explicit_bit_vect_pipeline = Pipeline( 57 | [ 58 | ("smi2mol", smi2mol), 59 | ("explicit_bit_vect_maccs", explicit_bit_vect_maccs), 60 | ], 61 | ) 62 | 63 | sparse_output = sparse_pipeline.fit_transform(test_smiles) 64 | dense_output = dense_pipeline.fit_transform(test_smiles) 65 | explicit_bit_vect_maccs_output = explicit_bit_vect_pipeline.fit_transform( 66 | test_smiles 67 | ) 68 | 
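        # The three return types should encode the same bits: "sparse" yields a
        # SciPy sparse matrix, "dense" a NumPy array, and "explicit_bit_vect"
        # RDKit ExplicitBitVect objects; the assertions below verify their
        # equivalence after converting everything to dense arrays.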
69 | self.assertTrue(np.all(sparse_output.toarray() == dense_output)) 70 | 71 | self.assertTrue( 72 | np.equal( 73 | dense_output, 74 | np.array(explicit_bit_vect_maccs_output), 75 | ).all() 76 | ) 77 | 78 | def test_setter_getter(self) -> None: 79 | """Test if the setters and getters work as expected.""" 80 | mol_fp = MolToMACCSFP() 81 | params: dict[str, Any] = { 82 | "return_as": "dense", 83 | } 84 | mol_fp.set_params(**params) 85 | self.assertEqual(mol_fp.get_params()["return_as"], "dense") 86 | 87 | def test_setter_getter_error_handling(self) -> None: 88 | """Test if the setters and getters work as expected when errors are encountered.""" 89 | mol_fp = MolToMACCSFP() 90 | params: dict[str, Any] = { 91 | "return_as": "invalid-option", 92 | } 93 | self.assertRaises(ValueError, mol_fp.set_params, **params) 94 | 95 | def test_feature_names(self) -> None: 96 | """Test if the feature names are correct.""" 97 | mol_fp = MolToMACCSFP() 98 | feature_names = mol_fp.feature_names 99 | self.assertEqual(len(feature_names), mol_fp.n_bits) 100 | # feature names should be unique 101 | self.assertEqual(len(feature_names), len(set(feature_names))) 102 | 103 | 104 | if __name__ == "__main__": 105 | unittest.main() 106 | -------------------------------------------------------------------------------- /tests/test_elements/test_mol2any/test_mol2net_charge.py: -------------------------------------------------------------------------------- 1 | """Test generation of net charge calculation.""" 2 | 3 | import unittest 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from molpipeline import ErrorFilter, FilterReinserter, Pipeline 9 | from molpipeline.any2mol import SmilesToMol 10 | from molpipeline.mol2any import MolToNetCharge 11 | 12 | DF_TEST_DATA = pd.DataFrame( 13 | { 14 | "smiles": [ 15 | "[Fe+2]", 16 | "c1cc(c(nc1)Cl)C(=O)Nc2c(c3c(s2)CCCCC3)C(=O)N", 17 | "Cc1ccc(cc1)S(=O)(=O)Nc2c(c3c(s2)C[C@@H](CC3)C)C(=O)N", 18 | "c1cc(oc1)CN=C2C=C(C(=CC2=C(O)[O-])S(=O)(=O)[NH-])Cl", 19 | "C[C@@H]1[C@@H](OP2(O1)(O[C@H]([C@H](O2)C)C)C[NH+]3CCCCC3)C", # this one fails gasteiger charge computation 20 | ], 21 | "expected_net_charges_formal_charge": [2, 0, 0, -2, 1], 22 | "expected_net_charges_gasteiger": [2, -1, -1, -2, np.nan], 23 | } 24 | ) 25 | 26 | 27 | class TestNetChargeCalculator(unittest.TestCase): 28 | """Unittest for MolToNetCharge, which calculates net charges of molecules.""" 29 | 30 | def test_net_charge_calculation_formal_charge(self) -> None: 31 | """Test if the net charge calculation works as expected for formal charges.""" 32 | # we need the error filter and reinserter to handle the case where the charge calculation fails 33 | error_filter = ErrorFilter(filter_everything=True) 34 | pipeline = Pipeline( 35 | [ 36 | ("smi2mol", SmilesToMol()), 37 | ( 38 | "net_charge_element", 39 | MolToNetCharge(charge_method="formal_charge", standardizer=None), 40 | ), 41 | ("error_filter", error_filter), 42 | ( 43 | "filter_reinserter", 44 | FilterReinserter.from_error_filter(error_filter, fill_value=np.nan), 45 | ), 46 | ], 47 | ) 48 | 49 | actual_net_charges = pipeline.fit_transform(DF_TEST_DATA["smiles"]) 50 | self.assertTrue( 51 | np.allclose( 52 | DF_TEST_DATA["expected_net_charges_formal_charge"] 53 | .to_numpy() 54 | .reshape(-1, 1), 55 | actual_net_charges, 56 | equal_nan=True, 57 | ) 58 | ) 59 | 60 | def test_net_charge_calculation_gasteiger(self) -> None: 61 | """Test if the net charge calculation works as expected for gasteiger charges.""" 62 | # we need the error filter and reinserter to handle the 
case where the charge calculation fails 63 | error_filter = ErrorFilter(filter_everything=True) 64 | pipeline = Pipeline( 65 | [ 66 | ("smi2mol", SmilesToMol()), 67 | ( 68 | "net_charge_element", 69 | MolToNetCharge(charge_method="gasteiger", standardizer=None), 70 | ), 71 | ("error_filter", error_filter), 72 | ( 73 | "filter_reinserter", 74 | FilterReinserter.from_error_filter(error_filter, fill_value=np.nan), 75 | ), 76 | ], 77 | ) 78 | 79 | actual_net_charges = pipeline.fit_transform(DF_TEST_DATA["smiles"]) 80 | self.assertTrue( 81 | np.allclose( 82 | DF_TEST_DATA["expected_net_charges_gasteiger"] 83 | .to_numpy() 84 | .reshape(-1, 1), 85 | actual_net_charges, 86 | equal_nan=True, 87 | ) 88 | ) 89 | 90 | 91 | if __name__ == "__main__": 92 | unittest.main() 93 | -------------------------------------------------------------------------------- /tests/test_elements/test_mol2any/test_mol2path_fingerprint.py: -------------------------------------------------------------------------------- 1 | """Tests for the MolToPathFingerprint pipeline element.""" 2 | 3 | from __future__ import annotations 4 | 5 | import unittest 6 | from typing import Any 7 | 8 | import numpy as np 9 | 10 | from molpipeline import Pipeline 11 | from molpipeline.any2mol import SmilesToMol 12 | from molpipeline.mol2any import Mol2PathFP 13 | 14 | # pylint: disable=duplicate-code 15 | 16 | test_smiles = [ 17 | "c1ccccc1", 18 | "c1ccccc1C", 19 | "NCCOCCCC(=O)O", 20 | ] 21 | 22 | 23 | class TestMol2PathFingerprint(unittest.TestCase): 24 | """Unittest for Mol2PathFP, which calculates the RDKit Path Fingerprint.""" 25 | 26 | def test_can_be_constructed(self) -> None: 27 | """Test if the Mol2PathFP pipeline element can be constructed.""" 28 | mol_fp = Mol2PathFP() 29 | mol_fp_copy = mol_fp.copy() 30 | self.assertTrue(mol_fp_copy is not mol_fp) 31 | for key, value in mol_fp.get_params().items(): 32 | self.assertEqual(value, mol_fp_copy.get_params()[key]) 33 | mol_fp_recreated = Mol2PathFP(**mol_fp.get_params()) 34 | for key, value in mol_fp.get_params().items(): 35 | self.assertEqual(value, mol_fp_recreated.get_params()[key]) 36 | 37 | def test_output_types(self) -> None: 38 | """Test equality of different output_types.""" 39 | smi2mol = SmilesToMol() 40 | sparse_path_fp = Mol2PathFP(n_bits=1024, return_as="sparse") 41 | dense_path_fp = Mol2PathFP(n_bits=1024, return_as="dense") 42 | explicit_bit_vect_path_fp = Mol2PathFP( 43 | n_bits=1024, return_as="explicit_bit_vect" 44 | ) 45 | sparse_pipeline = Pipeline( 46 | [ 47 | ("smi2mol", smi2mol), 48 | ("sparse_path_fp", sparse_path_fp), 49 | ], 50 | ) 51 | dense_pipeline = Pipeline( 52 | [ 53 | ("smi2mol", smi2mol), 54 | ("dense_path_fp", dense_path_fp), 55 | ], 56 | ) 57 | explicit_bit_vect_pipeline = Pipeline( 58 | [ 59 | ("smi2mol", smi2mol), 60 | ("explicit_bit_vect_path_fp", explicit_bit_vect_path_fp), 61 | ], 62 | ) 63 | 64 | sparse_output = sparse_pipeline.fit_transform(test_smiles) 65 | dense_output = dense_pipeline.fit_transform(test_smiles) 66 | explicit_bit_vect_path_fp_output = explicit_bit_vect_pipeline.fit_transform( 67 | test_smiles 68 | ) 69 | 70 | self.assertTrue(np.all(sparse_output.toarray() == dense_output)) 71 | 72 | self.assertTrue( 73 | np.equal( 74 | dense_output, 75 | np.array(explicit_bit_vect_path_fp_output), 76 | ).all() 77 | ) 78 | 79 | def test_counted_bits(self) -> None: 80 | """Test if the option counted bits works as expected.""" 81 | mol_fp = Mol2PathFP(n_bits=1024, return_as="dense") 82 | smi2mol = SmilesToMol() 83 | pipeline = Pipeline( 84 | [ 85 | 
("smi2mol", smi2mol), 86 | ("mol_fp", mol_fp), 87 | ], 88 | ) 89 | output_binary = pipeline.fit_transform(test_smiles) 90 | pipeline.set_params(mol_fp__counted=True) 91 | output_counted = pipeline.fit_transform(test_smiles) 92 | self.assertTrue( 93 | np.all(np.flatnonzero(output_counted) == np.flatnonzero(output_binary)) 94 | ) 95 | self.assertTrue(np.all(output_counted >= output_binary)) 96 | self.assertTrue(np.any(output_counted > output_binary)) 97 | 98 | def test_setter_getter(self) -> None: 99 | """Test if the setters and getters work as expected.""" 100 | mol_fp = Mol2PathFP() 101 | params: dict[str, Any] = { 102 | "min_path": 10, 103 | "max_path": 12, 104 | "use_hs": False, 105 | "branched_paths": False, 106 | "use_bond_order": False, 107 | "count_simulation": True, 108 | "num_bits_per_feature": 4, 109 | "counted": True, 110 | "n_bits": 1024, 111 | } 112 | mol_fp.set_params(**params) 113 | self.assertEqual(mol_fp.get_params()["min_path"], 10) 114 | self.assertEqual(mol_fp.get_params()["max_path"], 12) 115 | self.assertEqual(mol_fp.get_params()["use_hs"], False) 116 | self.assertEqual(mol_fp.get_params()["branched_paths"], False) 117 | self.assertEqual(mol_fp.get_params()["use_bond_order"], False) 118 | self.assertEqual(mol_fp.get_params()["count_simulation"], True) 119 | self.assertEqual(mol_fp.get_params()["num_bits_per_feature"], 4) 120 | self.assertEqual(mol_fp.get_params()["counted"], True) 121 | self.assertEqual(mol_fp.get_params()["n_bits"], 1024) 122 | 123 | def test_setter_getter_error_handling(self) -> None: 124 | """Test if the setters and getters work as expected when errors are encountered.""" 125 | mol_fp = Mol2PathFP() 126 | params: dict[str, Any] = { 127 | "min_path": 2, 128 | "n_bits": 1024, 129 | "return_as": "invalid-option", 130 | } 131 | self.assertRaises(ValueError, mol_fp.set_params, **params) 132 | 133 | def test_feature_names(self) -> None: 134 | """Test if the feature names are correct.""" 135 | mol_fp = Mol2PathFP(n_bits=1024) 136 | feature_names = mol_fp.feature_names 137 | self.assertEqual(len(feature_names), 1024) 138 | # feature names should be unique 139 | self.assertEqual(len(feature_names), len(set(feature_names))) 140 | 141 | 142 | if __name__ == "__main__": 143 | unittest.main() 144 | -------------------------------------------------------------------------------- /tests/test_elements/test_mol2mol/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | -------------------------------------------------------------------------------- /tests/test_elements/test_mol2mol/test_mol2scaffold.py: -------------------------------------------------------------------------------- 1 | """Test the mol2scaffold module.""" 2 | 3 | from typing import Any 4 | from unittest import TestCase 5 | 6 | from molpipeline import Pipeline 7 | from molpipeline.any2mol import AutoToMol 8 | from molpipeline.mol2any import MolToSmiles 9 | from molpipeline.mol2mol.scaffolds import MakeScaffoldGeneric, MurckoScaffold 10 | 11 | 12 | class TestMurckoScaffold(TestCase): 13 | """Test the MurckoScaffold class.""" 14 | 15 | def test_murcko_scaffold_generation_pipeline(self) -> None: 16 | """Test the scaffold generation.""" 17 | scaffold_pipeline = Pipeline( 18 | steps=[ 19 | ("smiles_to_mol", AutoToMol()), 20 | ("murcko_scaffold", MurckoScaffold()), 21 | ("scaffold_to_smiles", MolToSmiles()), 22 | ] 23 | ) 24 | smiles_list = ["Cc1ccc(=O)[nH]c1", "O=CC1CCC(c2ccccc2)CC1", "CCC"] 25 | expected_scaffold_list = ["O=c1cccc[nH]1", 
"c1ccc(C2CCCCC2)cc1", ""] 26 | 27 | scaffold_list = scaffold_pipeline.transform(smiles_list) 28 | self.assertListEqual(expected_scaffold_list, scaffold_list) 29 | 30 | 31 | class TestMakeScaffoldGeneric(TestCase): 32 | """Test the MakeScaffoldGeneric class.""" 33 | 34 | def setUp(self) -> None: 35 | """Set up the pipeline and common variables.""" 36 | self.generic_scaffold_pipeline = Pipeline( 37 | steps=[ 38 | ("smiles_to_mol", AutoToMol()), 39 | ("murcko_scaffold", MurckoScaffold()), 40 | ("make_scaffold_generic", MakeScaffoldGeneric()), 41 | ("scaffold_to_smiles", MolToSmiles()), 42 | ] 43 | ) 44 | self.smiles_list = ["Cc1ccc(=O)[nH]c1", "O=CC1CCC(c2ccccc2)CC1", "CCC"] 45 | 46 | def check_generic_scaffold( 47 | self, params: dict[str, Any], expected_scaffold_list: list[str] 48 | ) -> None: 49 | """Set parameters and check the results. 50 | 51 | Parameters 52 | ---------- 53 | params: dict[str, Any] 54 | Parameters to set for the pipeline. 55 | expected_scaffold_list: list[str] 56 | Expected output of the pipeline. 57 | 58 | """ 59 | self.generic_scaffold_pipeline.set_params(**params) 60 | generic_scaffold_list = self.generic_scaffold_pipeline.transform( 61 | self.smiles_list 62 | ) 63 | self.assertListEqual(expected_scaffold_list, generic_scaffold_list) 64 | 65 | def test_generic_scaffold_generation_pipeline(self) -> None: 66 | """Test the generic scaffold generation.""" 67 | self.check_generic_scaffold( 68 | params={}, expected_scaffold_list=["CC1CCCCC1", "C1CCC(C2CCCCC2)CC1", ""] 69 | ) 70 | 71 | # Test the generic scaffold generation with generic atoms 72 | self.check_generic_scaffold( 73 | params={"make_scaffold_generic__generic_atoms": True}, 74 | expected_scaffold_list=["**1*****1", "*1***(*2*****2)**1", ""], 75 | ) 76 | 77 | # Test the generic scaffold generation with generic bonds 78 | self.check_generic_scaffold( 79 | params={ 80 | "make_scaffold_generic__generic_atoms": False, 81 | "make_scaffold_generic__generic_bonds": True, 82 | }, 83 | expected_scaffold_list=[ 84 | "C~C1~C~C~C~C~C~1", 85 | "C1~C~C~C(~C2~C~C~C~C~C~2)~C~C~1", 86 | "", 87 | ], 88 | ) 89 | 90 | # Test the generic scaffold generation with generic atoms and bonds 91 | self.check_generic_scaffold( 92 | params={ 93 | "make_scaffold_generic__generic_atoms": True, 94 | "make_scaffold_generic__generic_bonds": True, 95 | }, 96 | expected_scaffold_list=[ 97 | "*~*1~*~*~*~*~*~1", 98 | "*1~*~*~*(~*2~*~*~*~*~*~2)~*~*~1", 99 | "", 100 | ], 101 | ) 102 | -------------------------------------------------------------------------------- /tests/test_elements/test_post_prediction.py: -------------------------------------------------------------------------------- 1 | """Test the module post_prediction.py.""" 2 | 3 | import unittest 4 | 5 | import numpy as np 6 | from sklearn.base import clone 7 | from sklearn.decomposition import PCA 8 | from sklearn.ensemble import RandomForestClassifier 9 | 10 | from molpipeline.post_prediction import PostPredictionWrapper 11 | 12 | 13 | class TestPostPredictionWrapper(unittest.TestCase): 14 | """Test the PostPredictionWrapper class.""" 15 | 16 | def test_get_params(self) -> None: 17 | """Test get_params method.""" 18 | rf = RandomForestClassifier() 19 | rf_params = rf.get_params(deep=True) 20 | 21 | ppw = PostPredictionWrapper(rf) 22 | ppw_params = ppw.get_params(deep=True) 23 | 24 | wrapped_params = {} 25 | for key, value in ppw_params.items(): 26 | first, _, rest = key.partition("__") 27 | if first == "wrapped_estimator": 28 | if rest == "": 29 | self.assertIs(rf, value) 30 | else: 31 | 
wrapped_params[rest] = value 32 | 33 | self.assertDictEqual(rf_params, wrapped_params) 34 | 35 | def test_set_params(self) -> None: 36 | """Test set_params method. 37 | 38 | Raises 39 | ------ 40 | TypeError 41 | If the wrapped estimator is not a RandomForestClassifier. 42 | 43 | """ 44 | rf = RandomForestClassifier() 45 | ppw = PostPredictionWrapper(rf) 46 | 47 | ppw.set_params(wrapped_estimator__n_estimators=10) 48 | self.assertIsInstance(ppw.wrapped_estimator, RandomForestClassifier) 49 | if not isinstance(ppw.wrapped_estimator, RandomForestClassifier): 50 | raise TypeError("Wrapped estimator is not a RandomForestClassifier.") 51 | self.assertEqual(ppw.wrapped_estimator.n_estimators, 10) 52 | 53 | ppw_params = ppw.get_params(deep=True) 54 | self.assertEqual(ppw_params["wrapped_estimator__n_estimators"], 10) 55 | 56 | def test_fit_transform(self) -> None: 57 | """Test fit method.""" 58 | rng = np.random.default_rng(20240918) 59 | features = rng.random((10, 5)) 60 | 61 | pca = PCA(n_components=3) 62 | pca.fit(features) 63 | pca_transformed = pca.transform(features) 64 | 65 | ppw = PostPredictionWrapper(clone(pca)) 66 | ppw.fit(features) 67 | ppw_transformed = ppw.transform(features) 68 | 69 | self.assertEqual(pca_transformed.shape, ppw_transformed.shape) 70 | self.assertTrue(np.allclose(pca_transformed, ppw_transformed)) 71 | 72 | def test_inverse_transform(self) -> None: 73 | """Test inverse_transform method.""" 74 | rng = np.random.default_rng(20240918) 75 | features = rng.random((10, 5)) 76 | 77 | pca = PCA(n_components=3) 78 | pca.fit(features) 79 | pca_transformed = pca.transform(features) 80 | pca_inverse = pca.inverse_transform(pca_transformed) 81 | 82 | ppw = PostPredictionWrapper(clone(pca)) 83 | ppw.fit(features) 84 | ppw_transformed = ppw.transform(features) 85 | ppw_inverse = ppw.inverse_transform(ppw_transformed) 86 | 87 | self.assertEqual(features.shape, ppw_inverse.shape) 88 | self.assertEqual(pca_inverse.shape, ppw_inverse.shape) 89 | 90 | self.assertTrue(np.allclose(pca_inverse, ppw_inverse)) 91 | -------------------------------------------------------------------------------- /tests/test_estimators/__init__.py: -------------------------------------------------------------------------------- 1 | """Test sklearn estimators.""" 2 | -------------------------------------------------------------------------------- /tests/test_estimators/test_algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | """Test sklearn estimator algorithms.""" 2 | -------------------------------------------------------------------------------- /tests/test_estimators/test_algorithm/test_union_find.py: -------------------------------------------------------------------------------- 1 | """Test union find algorithm.""" 2 | 3 | import unittest 4 | 5 | import numpy as np 6 | 7 | from molpipeline.estimators.algorithm.union_find import UnionFindNode 8 | 9 | 10 | class TestUnionFind(unittest.TestCase): 11 | """Test the UnionFindNode class.""" 12 | 13 | def test_union_find(self) -> None: 14 | """Test the union find algorithm.""" 15 | uf_nodes = [UnionFindNode() for _ in range(10)] 16 | nof_cc, cc_labels = UnionFindNode.get_connected_components(uf_nodes) 17 | 18 | self.assertTrue(np.equal(cc_labels, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).all()) 19 | self.assertEqual(nof_cc, 10) 20 | 21 | uf_nodes = [UnionFindNode() for _ in range(10)] 22 | # one cc 23 | uf_nodes[0].union(uf_nodes[1]) 24 | uf_nodes[0].union(uf_nodes[7]) 25 | # second cc 26 | 
uf_nodes[2].union(uf_nodes[3]) 27 | uf_nodes[2].union(uf_nodes[4]) 28 | uf_nodes[3].union(uf_nodes[9]) 29 | 30 | cc_num, cc_labels = UnionFindNode.get_connected_components(uf_nodes) 31 | self.assertTrue( 32 | np.equal( 33 | cc_labels, 34 | [ 35 | 0, 36 | 0, 37 | 1, 38 | 1, 39 | 1, 40 | 2, 41 | 3, 42 | 0, 43 | 4, 44 | 1, 45 | ], 46 | ).all() 47 | ) 48 | self.assertEqual(cc_num, 5) 49 | -------------------------------------------------------------------------------- /tests/test_estimators/test_connected_component_clustering.py: -------------------------------------------------------------------------------- 1 | """Test connected component clustering estimator.""" 2 | 3 | from __future__ import annotations 4 | 5 | import unittest 6 | 7 | import numpy as np 8 | from scipy.sparse import csr_matrix 9 | 10 | from molpipeline.estimators import ConnectedComponentClustering 11 | 12 | 13 | class TestConnectedComponentClusteringEstimator(unittest.TestCase): 14 | """Test connected component clustering estimator.""" 15 | 16 | def test_connected_component_clustering_estimator(self) -> None: 17 | """Test connected component clustering estimator.""" 18 | ccc = ConnectedComponentClustering(distance_threshold=0.5, max_memory_usage=0.1) 19 | 20 | # test no chunking needed 21 | self.assertEqual( 22 | ccc.fit_predict(csr_matrix([[1, 0, 1], [0, 1, 0], [1, 0, 1]])).tolist(), 23 | [0, 1, 0], 24 | ) 25 | self.assertEqual(ccc.n_clusters_, 2) 26 | 27 | # test chunking needed 28 | matrix = csr_matrix([[1, 0, 1], [0, 1, 0], [1, 0, 1]]) 29 | nof_bytes_per_row = np.dtype("float64").itemsize * matrix.shape[1] / (1 << 30) 30 | ccc = ConnectedComponentClustering( 31 | distance_threshold=0.5, 32 | max_memory_usage=nof_bytes_per_row, 33 | ) 34 | self.assertEqual( 35 | ccc.fit_predict(matrix).tolist(), 36 | [0, 1, 0], 37 | ) 38 | self.assertEqual(ccc.n_clusters_, 2) 39 | -------------------------------------------------------------------------------- /tests/test_estimators/test_leader_picker_clustering.py: -------------------------------------------------------------------------------- 1 | """Test leader picker clustering estimator.""" 2 | 3 | from __future__ import annotations 4 | 5 | import unittest 6 | from typing import Any 7 | 8 | import numpy as np 9 | from rdkit import DataStructs 10 | 11 | from molpipeline import Pipeline 12 | from molpipeline.any2mol import AutoToMol 13 | from molpipeline.estimators import LeaderPickerClustering 14 | from molpipeline.mol2any import MolToMorganFP 15 | 16 | 17 | class TestLeaderPickerEstimator(unittest.TestCase): 18 | """Test LeaderPicker clustering estimator.""" 19 | 20 | def test_leader_picker_clustering_estimator(self) -> None: 21 | """Test LeaderPicker clustering estimator.""" 22 | fingerprint_matrix = [ 23 | DataStructs.CreateFromBitString(x) 24 | for x in [ 25 | "000", # 0 26 | "100", # 1 27 | "110", # 2 28 | "101", # 3 29 | "010", # 4 30 | "011", # 5 31 | "001", # 6 32 | "111", # 7 33 | "000", 34 | "100", 35 | "110", 36 | "101", 37 | "010", 38 | "011", 39 | "001", 40 | "111", 41 | ] 42 | ] 43 | 44 | eps: float = 1e-10 45 | 46 | expected_clusterings: list[dict[str, Any]] = [ 47 | { 48 | "threshold": 0.0, 49 | "expected_clustering": [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7], 50 | }, 51 | { 52 | "threshold": 0.0 + eps, 53 | "expected_clustering": [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7], 54 | }, 55 | { 56 | "threshold": 1 / 3 - eps, 57 | "expected_clustering": [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7], 58 | }, 59 | { 60 | "threshold": 1 / 3, 61 | 
"expected_clustering": [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7], 62 | }, 63 | ] 64 | 65 | for test_case_dict in expected_clusterings: 66 | threshold = test_case_dict["threshold"] 67 | expected_clustering = test_case_dict["expected_clustering"] 68 | 69 | self.assertEqual(len(fingerprint_matrix), len(expected_clustering)) 70 | 71 | estimator = LeaderPickerClustering(distance_threshold=threshold) 72 | actual_labels = estimator.fit_predict(fingerprint_matrix) 73 | 74 | exp_nof_clusters = np.unique(expected_clustering).shape[0] 75 | self.assertEqual(exp_nof_clusters, estimator.n_clusters_) 76 | 77 | self.assertTrue(np.equal(actual_labels, expected_clustering).all()) 78 | 79 | def test_leader_picker_pipeline(self) -> None: 80 | """Test leader picker clustering in pipeline.""" 81 | test_smiles = ["C", "N", "c1ccccc1", "c1ccc(O)cc1", "CCCCCCO", "CCCCCCC"] 82 | 83 | distances = [0.05, 0.95] 84 | expected_labels = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 2, 3, 3]] 85 | expected_centroids = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 4]] 86 | 87 | for dist, exp_labels, exp_centroids in zip( 88 | distances, expected_labels, expected_centroids 89 | ): 90 | leader_picker = LeaderPickerClustering(distance_threshold=dist) 91 | pipeline = Pipeline( 92 | [ 93 | ("auto2mol", AutoToMol()), 94 | ( 95 | "morgan2", 96 | MolToMorganFP( 97 | return_as="explicit_bit_vect", n_bits=1024, radius=2 98 | ), 99 | ), 100 | ("leader_picker", leader_picker), 101 | ], 102 | ) 103 | 104 | actual_labels = pipeline.fit_predict(test_smiles) 105 | 106 | self.assertTrue(np.equal(actual_labels, exp_labels).all()) 107 | self.assertIsNotNone(leader_picker.centroids_) 108 | self.assertTrue(np.equal(leader_picker.centroids_, exp_centroids).all()) # type: ignore 109 | -------------------------------------------------------------------------------- /tests/test_estimators/test_murcko_scaffold_clustering.py: -------------------------------------------------------------------------------- 1 | """Test Murcko scaffold clustering estimator.""" 2 | 3 | from __future__ import annotations 4 | 5 | import unittest 6 | 7 | import numpy as np 8 | 9 | from molpipeline.estimators import MurckoScaffoldClustering 10 | 11 | SCAFFOLD_SMILES: list[str] = [ 12 | "Cc1ccccc1", 13 | "Cc1cc(Oc2nccc(CCC)c2)ccc1", 14 | "c1ccccc1", 15 | ] 16 | 17 | SCAFFOLD_SMILES_TEST_GENERIC: list[str] = [*SCAFFOLD_SMILES, "c1ncccc1"] 18 | 19 | LINEAR_SMILES: list[str] = ["CC", "CCC", "CCCN"] 20 | 21 | 22 | class TestMurckoScaffoldClusteringEstimator(unittest.TestCase): 23 | """Test Murcko scaffold clustering estimator.""" 24 | 25 | def test_murcko_scaffold_clustering_ignore(self) -> None: 26 | """Test Murcko scaffold clustering estimator.""" 27 | for make_generic in [False, True]: 28 | estimator_ignore_linear: MurckoScaffoldClustering = ( 29 | MurckoScaffoldClustering( 30 | make_generic=make_generic, 31 | n_jobs=1, 32 | linear_molecules_strategy="ignore", 33 | ) 34 | ) 35 | 36 | # test basic scaffold-based clustering works as intended 37 | scaffold_cluster_labels = estimator_ignore_linear.fit_predict( 38 | SCAFFOLD_SMILES 39 | ) 40 | expected_scaffold_labels = [1.0, 0.0, 1.0] 41 | 42 | self.assertEqual(estimator_ignore_linear.n_clusters_, 2) 43 | self.assertListEqual( 44 | list(scaffold_cluster_labels), expected_scaffold_labels 45 | ) 46 | 47 | # test linear molecule handling. We expect the linear molecules to be ignored. 
48 | input_smiles = SCAFFOLD_SMILES + LINEAR_SMILES 49 | cluster_labels = estimator_ignore_linear.fit_predict(input_smiles) 50 | nan_mask = np.isnan(cluster_labels) 51 | expected_nan_mask = [False, False, False, True, True, True] 52 | 53 | self.assertEqual(estimator_ignore_linear.n_clusters_, 2) 54 | self.assertListEqual(list(nan_mask), expected_nan_mask) 55 | self.assertListEqual(cluster_labels[~nan_mask].tolist(), [1.0, 0.0, 1.0]) 56 | 57 | def test_murcko_scaffold_clustering_own_cluster(self) -> None: 58 | """Test Murcko scaffold clustering estimator with the 'own_cluster' strategy.""" 59 | for make_generic in [False, True]: 60 | # create new estimator with "own_cluster" strategy 61 | estimator_cluster_linear: MurckoScaffoldClustering = ( 62 | MurckoScaffoldClustering( 63 | make_generic=make_generic, 64 | n_jobs=1, 65 | linear_molecules_strategy="own_cluster", 66 | ) 67 | ) 68 | 69 | # test linear molecule handling. We expect all linear molecules to end up in one shared cluster 70 | input_smiles = SCAFFOLD_SMILES + LINEAR_SMILES 71 | cluster_labels = estimator_cluster_linear.fit_predict(input_smiles) 72 | expected_cluster_labels = [1.0, 0.0, 1.0, 2.0, 2.0, 2.0] 73 | self.assertEqual(estimator_cluster_linear.n_clusters_, 3) 74 | self.assertListEqual(list(cluster_labels), expected_cluster_labels) 75 | 76 | def test_murcko_scaffold_clustering_generic(self) -> None: 77 | """Test Murcko scaffold clustering estimator with generic scaffold.""" 78 | # test generic clustering makes a difference 79 | estimator: MurckoScaffoldClustering = MurckoScaffoldClustering( 80 | make_generic=True, 81 | n_jobs=1, 82 | linear_molecules_strategy="ignore", 83 | ) 84 | 85 | scaffold_cluster_labels = estimator.fit_predict(SCAFFOLD_SMILES_TEST_GENERIC) 86 | expected_scaffold_labels = [1.0, 0.0, 1.0, 1.0] 87 | 88 | self.assertEqual(estimator.n_clusters_, 2) 89 | self.assertListEqual(list(scaffold_cluster_labels), expected_scaffold_labels) 90 | 91 | # test that without make_generic we get a different result 92 | estimator2: MurckoScaffoldClustering = MurckoScaffoldClustering( 93 | make_generic=False, 94 | n_jobs=1, 95 | linear_molecules_strategy="ignore", 96 | ) 97 | 98 | scaffold_cluster_labels2 = estimator2.fit_predict(SCAFFOLD_SMILES_TEST_GENERIC) 99 | expected_scaffold_labels2 = [1.0, 0.0, 1.0, 2.0] 100 | 101 | self.assertEqual(estimator2.n_clusters_, 3) 102 | self.assertListEqual(list(scaffold_cluster_labels2), expected_scaffold_labels2) 103 | 104 | def test_murcko_scaffold_clustering(self) -> None: 105 | """Test Murcko scaffold clustering estimator for purely linear molecules.""" 106 | test_smiles_failing = [ 107 | "CCCCCCCCC(=CCCCCCCCC(=O)O)[N+](=O)[O-]", 108 | "CN(C)C(=O)C(C)(C)NC(=O)OC(C)(C)C", 109 | ] 110 | 111 | pipe = MurckoScaffoldClustering( 112 | make_generic=False, 113 | n_jobs=1, 114 | linear_molecules_strategy="ignore", 115 | ) 116 | result = pipe.fit_predict(test_smiles_failing, None) 117 | self.assertTrue(np.isnan(result).all()) 118 | -------------------------------------------------------------------------------- /tests/test_experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize the test module for experimental classes and functions.""" 2 | -------------------------------------------------------------------------------- /tests/test_experimental/test_custom_filter.py: -------------------------------------------------------------------------------- 1 | """Test the custom filter element.""" 2 | 3 | import unittest 4 | 5 | from molpipeline import Pipeline 6
| from molpipeline.any2mol import AutoToMol 7 | from molpipeline.experimental import CustomFilter 8 | from molpipeline.mol2any import MolToBool 9 | 10 | 11 | class TestCustomFilter(unittest.TestCase): 12 | """Test the custom filter element.""" 13 | 14 | smiles_list = [ 15 | "CC", 16 | "CCC", 17 | "CCCC", 18 | "CO", 19 | ] 20 | 21 | def test_transform(self) -> None: 22 | """Test the custom filter.""" 23 | mol_list = AutoToMol().transform(self.smiles_list) 24 | res_filter = CustomFilter(lambda x: x.GetNumAtoms() == 2).transform(mol_list) 25 | res_bool = MolToBool().transform(res_filter) 26 | self.assertEqual(res_bool, [True, False, False, True]) 27 | 28 | def test_pipeline(self) -> None: 29 | """Test the custom filter in pipeline.""" 30 | pipeline = Pipeline( 31 | [ 32 | ("auto_to_mol", AutoToMol()), 33 | ("custom_filter", CustomFilter(lambda x: x.GetNumAtoms() == 2)), 34 | ("mol_to_bool", MolToBool()), 35 | ] 36 | ) 37 | self.assertEqual( 38 | pipeline.transform(self.smiles_list), [True, False, False, True] 39 | ) 40 | -------------------------------------------------------------------------------- /tests/test_experimental/test_explainability/__init__.py: -------------------------------------------------------------------------------- 1 | """Test explainability methods and utilities.""" 2 | -------------------------------------------------------------------------------- /tests/test_experimental/test_explainability/test_visualization/__init__.py: -------------------------------------------------------------------------------- 1 | """Test explainability visualization.""" 2 | -------------------------------------------------------------------------------- /tests/test_experimental/test_explainability/test_visualization/test_gaussian_grid.py: -------------------------------------------------------------------------------- 1 | """Test Gaussian grid visualization.""" 2 | 3 | import unittest 4 | from typing import ClassVar 5 | 6 | import numpy as np 7 | from rdkit import Chem 8 | from rdkit.Chem import Draw 9 | 10 | from molpipeline import Pipeline 11 | from molpipeline.experimental.explainability import ( 12 | SHAPFeatureAndAtomExplanation, 13 | SHAPFeatureExplanation, 14 | SHAPTreeExplainer, 15 | ) 16 | from molpipeline.experimental.explainability.visualization.visualization import ( 17 | make_sum_of_gaussians_grid, 18 | ) 19 | from tests.test_experimental.test_explainability.test_visualization.test_visualization import ( 20 | _get_test_morgan_rf_pipeline, 21 | ) 22 | 23 | TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"] 24 | CONTAINS_OX = [0, 1, 1, 0, 1, 0] 25 | 26 | 27 | class TestSumOfGaussiansGrid(unittest.TestCase): 28 | """Test sum of Gaussians grid.""" 29 | 30 | # pylint: disable=duplicate-code 31 | test_pipeline: ClassVar[Pipeline] 32 | test_explainer: ClassVar[SHAPTreeExplainer] 33 | test_explanations: ClassVar[ 34 | list[SHAPFeatureAndAtomExplanation | SHAPFeatureExplanation] 35 | ] 36 | 37 | @classmethod 38 | def setUpClass(cls) -> None: 39 | """Set up the tests.""" 40 | cls.test_pipeline = _get_test_morgan_rf_pipeline() 41 | cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX) 42 | cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline) 43 | cls.test_explanations = cls.test_explainer.explain(TEST_SMILES) 44 | 45 | def test_grid_with_shap_atom_weights(self) -> None: 46 | """Test grid with SHAP atom weights. 47 | 48 | Raises 49 | ------ 50 | ValueError 51 | If the molecule is not a Chem.Mol object.
52 | 53 | """ 54 | for explanation in self.test_explanations: 55 | self.assertTrue(explanation.is_valid()) 56 | self.assertIsInstance(explanation.atom_weights, np.ndarray) # type: ignore[union-attr] 57 | 58 | mol = explanation.molecule 59 | if not isinstance(mol, Chem.Mol): 60 | raise ValueError("Expected a Chem.Mol object.") 61 | mol_copy = Chem.Mol(mol) 62 | mol_copy = Draw.PrepareMolForDrawing(mol_copy) 63 | value_grid = make_sum_of_gaussians_grid( 64 | mol_copy, 65 | atom_weights=explanation.atom_weights, # type: ignore[union-attr] 66 | atom_width=np.inf, 67 | grid_resolution=[8, 8], 68 | padding=[0.4, 0.4], 69 | ) 70 | self.assertIsNotNone(value_grid) 71 | grid_values = value_grid.values # type: ignore[attr-defined] 72 | self.assertEqual(grid_values.size, 8 * 8) 73 | 74 | # test that the range of summed Gaussian values is as expected for SHAP 75 | self.assertTrue(grid_values.min() >= -1) 76 | self.assertTrue(grid_values.max() <= 1) 77 | -------------------------------------------------------------------------------- /tests/test_experimental/test_explainability/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for explainability tests.""" 2 | 3 | from typing import Any 4 | 5 | import scipy 6 | 7 | from molpipeline import Pipeline 8 | from molpipeline.utils.subpipeline import get_featurization_subpipeline 9 | 10 | 11 | def construct_kernel_shap_kwargs(pipeline: Pipeline, data: list[str]) -> dict[str, Any]: 12 | """Construct the kwargs for SHAPKernelExplainer. 13 | 14 | Convert the sparse feature matrix to a dense array because SHAPKernelExplainer 15 | does not support a sparse matrix as `data` while explaining dense matrices. 16 | We stick to dense matrices for simplicity. 17 | 18 | Parameters 19 | ---------- 20 | pipeline : Pipeline 21 | The pipeline used for featurization. 22 | data : list[str] 23 | The input data, e.g. SMILES strings.
24 | 25 | Returns 26 | ------- 27 | dict[str, Any] 28 | The kwargs for SHAPKernelExplainer. 29 | """ 30 | featurization_subpipeline = get_featurization_subpipeline( 31 | pipeline, raise_not_found=True 32 | ) 33 | data_transformed = featurization_subpipeline.transform(data) # type: ignore[union-attr] 34 | if scipy.sparse.issparse(data_transformed): 35 | data_transformed = data_transformed.toarray() 36 | return {"data": data_transformed} 37 | -------------------------------------------------------------------------------- /tests/test_experimental/test_model_selection/__init__.py: -------------------------------------------------------------------------------- 1 | """Model selection test module.""" 2 | -------------------------------------------------------------------------------- /tests/test_init.py: -------------------------------------------------------------------------------- 1 | """Test functionality set at package init.""" 2 | 3 | import unittest 4 | 5 | from molpipeline import __version__ 6 | 7 | 8 | class TestInit(unittest.TestCase): 9 | """Test functionality set at package init.""" 10 | 11 | def test_version(self) -> None: 12 | """Test that the package has a version.""" 13 | self.assertIsInstance(__version__, str) 14 | version_parts = __version__.split(".") 15 | self.assertEqual(len(version_parts), 3) 16 | major, minor, patch = version_parts 17 | self.assertTrue(major.isdigit()) 18 | self.assertTrue(minor.isdigit()) 19 | self.assertTrue(patch.isdigit()) 20 | -------------------------------------------------------------------------------- /tests/test_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """Init file for test_metrics.""" 2 | -------------------------------------------------------------------------------- /tests/test_metrics/test_ignore_error_scorer.py: -------------------------------------------------------------------------------- 1 | """Test the ignore error scorer wrapper.""" 2 | 3 | import unittest 4 | 5 | import numpy as np 6 | from sklearn import linear_model 7 | from sklearn.metrics import get_scorer 8 | 9 | from molpipeline.metrics import ignored_value_scorer 10 | 11 | 12 | class IgnoreErrorScorerTest(unittest.TestCase): 13 | """Test the ignore error scorer wrapper.""" 14 | 15 | def test_filter_nan(self) -> None: 16 | """Test that filtering np.nan works.""" 17 | y_true = np.array([1, 0, 0, 1, 0]) 18 | y_pred = np.array([1, 0, 0, 1, np.nan]) 19 | ba_score = ignored_value_scorer("balanced_accuracy", np.nan) 20 | value = ba_score._score_func(y_true, y_pred) # pylint: disable=protected-access 21 | self.assertAlmostEqual(value, 1.0) 22 | 23 | def test_filter_none(self) -> None: 24 | """Test that filtering None works.""" 25 | y_true = np.array([1, 0, 0, 1, 0]) 26 | y_pred = np.array([1, 0, 0, 1, None]) 27 | ba_score = ignored_value_scorer("balanced_accuracy", None) 28 | value = ba_score._score_func(y_true, y_pred) # pylint: disable=protected-access 29 | self.assertAlmostEqual(value, 1.0) 30 | 31 | def test_filter_nan_with_none(self) -> None: 32 | """Test that filtering NaN with None works.""" 33 | y_true = np.array([1, 0, 0, 1, 0]) 34 | y_pred = np.array([1, 0, 0, 1, None]) 35 | ba_score = ignored_value_scorer("balanced_accuracy", np.nan) 36 | self.assertAlmostEqual( 37 | ba_score._score_func(y_true, y_pred), # pylint: disable=protected-access 38 | 1.0, 39 | ) 40 | 41 | def test_filter_none_with_nan(self) -> None: 42 | """Test that filtering None with NaN works.""" 43 | y_true = np.array([1, 0, 0, 1, 0]) 44 | y_pred =
np.array([1, 0, 0, 1, np.nan]) 45 | ba_score = ignored_value_scorer("balanced_accuracy", None) 46 | self.assertAlmostEqual( 47 | ba_score._score_func(y_true, y_pred), # pylint: disable=protected-access 48 | 1.0, 49 | ) 50 | 51 | def test_correct_init_mse(self) -> None: 52 | """Test that initialization is correct, as we access via protected vars.""" 53 | x_train = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]).reshape( 54 | -1, 1 55 | ) 56 | y_train = np.array([0.1, 0.3, 0.3, 0.4, 0.5, 0.5, 0.7, 0.88, 0.9, 1]) 57 | regr = linear_model.LinearRegression() 58 | regr.fit(x_train, y_train) 59 | cix_scorer = ignored_value_scorer("neg_mean_squared_error", None) 60 | scikit_scorer = get_scorer("neg_mean_squared_error") 61 | self.assertEqual( 62 | cix_scorer(regr, x_train, y_train), scikit_scorer(regr, x_train, y_train) 63 | ) 64 | 65 | def test_correct_init_rmse(self) -> None: 66 | """Test that initialization is correct, as we access via protected vars.""" 67 | x_train = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]).reshape( 68 | -1, 1 69 | ) 70 | y_train = np.array([0.1, 0.3, 0.3, 0.4, 0.5, 0.5, 0.7, 0.88, 0.9, 1]) 71 | regr = linear_model.LinearRegression() 72 | regr.fit(x_train, y_train) 73 | cix_scorer = ignored_value_scorer("neg_root_mean_squared_error", None) 74 | scikit_scorer = get_scorer("neg_root_mean_squared_error") 75 | self.assertEqual( 76 | cix_scorer(regr, x_train, y_train), scikit_scorer(regr, x_train, y_train) 77 | ) 78 | 79 | def test_correct_init_inheritance(self) -> None: 80 | """Test that initialization is correct if we pass an initialized scorer.""" 81 | x_train = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]).reshape( 82 | -1, 1 83 | ) 84 | y_train = np.array([0.1, 0.3, 0.3, 0.4, 0.5, 0.5, 0.7, 0.88, 0.9, 1]) 85 | regr = linear_model.LinearRegression() 86 | regr.fit(x_train, y_train) 87 | scikit_scorer = get_scorer("neg_root_mean_squared_error") 88 | cix_scorer = ignored_value_scorer( 89 | get_scorer("neg_root_mean_squared_error"), None 90 | ) 91 | self.assertEqual( 92 | cix_scorer(regr, x_train, y_train), scikit_scorer(regr, x_train, y_train) 93 | ) 94 | -------------------------------------------------------------------------------- /tests/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | -------------------------------------------------------------------------------- /tests/test_utils/test_comparison.py: -------------------------------------------------------------------------------- 1 | """Test the comparison functions.""" 2 | 3 | from typing import Callable 4 | from unittest import TestCase 5 | 6 | from molpipeline import Pipeline 7 | from molpipeline.utils.comparison import check_pipelines_equivalent 8 | from tests.utils.default_models import ( 9 | get_morgan_physchem_rf_pipeline, 10 | get_standardization_pipeline, 11 | ) 12 | 13 | 14 | class TestComparison(TestCase): 15 | """Test if functionally equivalent pipelines are detected as such.""" 16 | 17 | def test_are_equal(self) -> None: 18 | """Test if two equivalent pipelines are detected as such.""" 19 | # Test standardization pipelines 20 | pipeline_method_list: list[Callable[[int], Pipeline]] = [ 21 | get_standardization_pipeline, 22 | get_morgan_physchem_rf_pipeline, 23 | ] 24 | for pipeline_method in pipeline_method_list: 25 | pipeline_a = pipeline_method(1) 26 | pipeline_b = pipeline_method(1) 27 | self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b)) 28 | 29 | def test_are_not_equal(self) ->
None: 30 | """Test if two different pipelines are detected as such.""" 31 | # Test changed parameters 32 | pipeline_a = get_morgan_physchem_rf_pipeline() 33 | pipeline_b = get_morgan_physchem_rf_pipeline() 34 | pipeline_b.set_params(mol2fp__morgan__n_bits=1024) 35 | self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b)) 36 | 37 | # Test changed steps 38 | pipeline_b = get_morgan_physchem_rf_pipeline() 39 | last_step = pipeline_b.steps[-1] 40 | pipeline_b.steps = pipeline_b.steps[:-1] 41 | self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b)) 42 | 43 | # Test if adding the step back makes the pipelines equivalent 44 | pipeline_b.steps.append(last_step) 45 | self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b)) 46 | -------------------------------------------------------------------------------- /tests/test_utils/test_json_operations.py: -------------------------------------------------------------------------------- 1 | """Tests conversion of sklearn models to json and back.""" 2 | 3 | import unittest 4 | 5 | from sklearn.ensemble import RandomForestClassifier 6 | from sklearn.svm import SVC 7 | 8 | from molpipeline import Pipeline 9 | from molpipeline.utils.json_operations import ( 10 | recursive_from_json, 11 | recursive_to_json, 12 | transform_functions2string, 13 | transform_string2function, 14 | ) 15 | from molpipeline.utils.multi_proc import check_available_cores 16 | 17 | 18 | class JsonConversionTest(unittest.TestCase): 19 | """Unittest for conversion of sklearn models to json and back.""" 20 | 21 | def test_rf_reconstruction(self) -> None: 22 | """Test if the sklearn-rf can be reconstructed from json.""" 23 | random_forest = RandomForestClassifier(n_estimators=200) 24 | recreated_rf = recursive_from_json(recursive_to_json(random_forest)) 25 | self.assertEqual(random_forest.get_params(), recreated_rf.get_params()) 26 | 27 | def test_svc_reconstruction(self) -> None: 28 | """Test if the sklearn-svc can be reconstructed from json.""" 29 | svc = SVC() 30 | recreated_svc = recursive_from_json(recursive_to_json(svc)) 31 | self.assertEqual(svc.get_params(), recreated_svc.get_params()) 32 | 33 | def test_pipeline_reconstruction(self) -> None: 34 | """Test if the sklearn-pipeline can be reconstructed from json.""" 35 | random_forest = RandomForestClassifier(n_estimators=200) 36 | svc = SVC() 37 | pipeline = Pipeline([("rf", random_forest), ("svc", svc)]) 38 | recreated_pipeline = recursive_from_json(recursive_to_json(pipeline)) 39 | 40 | original_params = pipeline.get_params() 41 | recreated_params = recreated_pipeline.get_params() 42 | original_steps = original_params.pop("steps") 43 | recreated_steps = recreated_params.pop("steps") 44 | 45 | # Separate comparison of the steps as models cannot be compared directly 46 | for (orig_name, orig_obj), (recreated_name, recreated_obj) in zip( 47 | original_steps, recreated_steps 48 | ): 49 | # Remove the model from the original params 50 | del original_params[orig_name] 51 | del recreated_params[recreated_name] 52 | self.assertEqual(orig_name, recreated_name) 53 | self.assertEqual(orig_obj.get_params(), recreated_obj.get_params()) 54 | self.assertEqual(type(orig_obj), type(recreated_obj)) 55 | self.assertEqual(original_params, recreated_params) 56 | 57 | def test_function_dict_json(self) -> None: 58 | """Test if a dict with objects can be reconstructed from json.""" 59 | function_dict = { 60 | "dummy1": {"check_available_cores": check_available_cores}, 61 | "dummy2": str, 62 | "dummy3": 1, 63 | "dummy4":
[check_available_cores, check_available_cores, "test"], 64 | } 65 | function_json = transform_functions2string(function_dict) 66 | recreated_function_dict = transform_string2function(function_json) 67 | self.assertEqual(function_dict, recreated_function_dict) 68 | 69 | def test_set_transformation(self) -> None: 70 | """Test if a set can be reconstructed from json.""" 71 | test_set = {1, "a", (1, "a")} 72 | test_set_json = recursive_to_json(test_set) 73 | recreated_set = recursive_from_json(test_set_json) 74 | self.assertEqual(test_set, recreated_set) 75 | 76 | 77 | if __name__ == "__main__": 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /tests/test_utils/test_logging.py: -------------------------------------------------------------------------------- 1 | """Test logging utils.""" 2 | 3 | import io 4 | import unittest 5 | from contextlib import redirect_stdout 6 | 7 | from molpipeline.utils.logging import print_elapsed_time 8 | 9 | 10 | class LoggingUtilsTest(unittest.TestCase): 11 | """Unittest for logging utilities.""" 12 | 13 | def test__print_elapsed_time(self) -> None: 14 | """Test message logging with timings works as expected.""" 15 | # when message is None nothing should be printed 16 | stream1 = io.StringIO() 17 | with redirect_stdout(stream1): 18 | with print_elapsed_time("source", message=None, use_logger=False): 19 | pass 20 | output1 = stream1.getvalue() 21 | self.assertEqual(output1, "") 22 | 23 | # message should be printed in the expected sklearn format 24 | stream2 = io.StringIO() 25 | with redirect_stdout(stream2): 26 | with print_elapsed_time("source", message="my message", use_logger=False): 27 | pass 28 | output2 = stream2.getvalue() 29 | self.assertTrue( 30 | output2.startswith( 31 | "[source] ................................... my message, total=" 32 | ) 33 | ) 34 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Init.""" 2 | -------------------------------------------------------------------------------- /tests/utils/default_models.py: -------------------------------------------------------------------------------- 1 | """Module for default models used for testing molpipeline functions and classes.""" 2 | 3 | from sklearn.ensemble import RandomForestClassifier 4 | 5 | from molpipeline import Pipeline 6 | from molpipeline.any2mol import SmilesToMol 7 | from molpipeline.error_handling import ErrorFilter, FilterReinserter 8 | from molpipeline.mol2any import ( 9 | MolToConcatenatedVector, 10 | MolToMorganFP, 11 | MolToRDKitPhysChem, 12 | MolToSmiles, 13 | ) 14 | from molpipeline.mol2mol import ( 15 | EmptyMoleculeFilter, 16 | FragmentDeduplicator, 17 | MetalDisconnector, 18 | MixtureFilter, 19 | SaltRemover, 20 | StereoRemover, 21 | TautomerCanonicalizer, 22 | Uncharger, 23 | ) 24 | from molpipeline.mol2mol.filter import ElementFilter 25 | from molpipeline.post_prediction import PostPredictionWrapper 26 | 27 | 28 | def get_morgan_physchem_rf_pipeline(n_jobs: int = 1) -> Pipeline: 29 | """Get a pipeline with Morgan FP, physicochemical properties, and a RandomForest. 30 | 31 | Parameters 32 | ---------- 33 | n_jobs: int, default=1 34 | Number of parallel jobs to use. 35 | 36 | Returns 37 | ------- 38 | Pipeline 39 | A pipeline combining Morgan fingerprints and physicochemical properties with a 40 | RandomForestClassifier.
41 | 42 | """ 43 | error_filter = ErrorFilter(filter_everything=True) 44 | pipeline = Pipeline( 45 | [ 46 | ("smi2mol", SmilesToMol()), 47 | ( 48 | "mol2fp", 49 | MolToConcatenatedVector( 50 | [ 51 | ("morgan", MolToMorganFP(n_bits=2048)), 52 | ("physchem", MolToRDKitPhysChem()), 53 | ] 54 | ), 55 | ), 56 | ("error_filter", error_filter), 57 | ("rf", RandomForestClassifier(n_jobs=n_jobs)), 58 | ( 59 | "filter_reinserter", 60 | PostPredictionWrapper( 61 | FilterReinserter.from_error_filter(error_filter, None) 62 | ), 63 | ), 64 | ], 65 | n_jobs=n_jobs, 66 | ) 67 | return pipeline 68 | 69 | 70 | def get_standardization_pipeline(n_jobs: int = 1) -> Pipeline: 71 | """Get the standardization pipeline. 72 | 73 | Parameters 74 | ---------- 75 | n_jobs: int, optional (default=1) 76 | The number of jobs to use for standardization. 77 | In case of -1, all available CPUs are used. 78 | 79 | Returns 80 | ------- 81 | Pipeline 82 | The standardization pipeline. 83 | 84 | """ 85 | error_filter = ErrorFilter(filter_everything=True) 86 | # Set up pipeline 87 | standardization_pipeline = Pipeline( 88 | [ 89 | ("smi2mol", SmilesToMol()), 90 | ("metal_disconnector", MetalDisconnector()), 91 | ("salt_remover", SaltRemover()), 92 | ("element_filter", ElementFilter()), 93 | ("uncharge1", Uncharger()), 94 | ("canonical_tautomer", TautomerCanonicalizer()), 95 | ("uncharge2", Uncharger()), 96 | ("stereo_remover", StereoRemover()), 97 | ("fragment_deduplicator", FragmentDeduplicator()), 98 | ("mixture_remover", MixtureFilter()), 99 | ("empty_molecule_remover", EmptyMoleculeFilter()), 100 | ("mol2smi", MolToSmiles()), 101 | ("error_filter", error_filter), 102 | ("error_replacer", FilterReinserter.from_error_filter(error_filter, None)), 103 | ], 104 | n_jobs=n_jobs, 105 | ) 106 | return standardization_pipeline 107 | -------------------------------------------------------------------------------- /tests/utils/execution_count.py: -------------------------------------------------------------------------------- 1 | """Functions for counting the number of times a function is executed.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | try: 8 | from typing import Self # type: ignore[attr-defined] 9 | except ImportError: 10 | from typing_extensions import Self 11 | 12 | from sklearn.base import BaseEstimator 13 | from sklearn.ensemble import RandomForestRegressor 14 | 15 | from molpipeline import Pipeline 16 | from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement 17 | from molpipeline.any2mol import SmilesToMol 18 | from molpipeline.mol2any import MolToMorganFP 19 | 20 | 21 | class CountingTransformerWrapper(BaseEstimator): 22 | """A transformer that counts the number of transformations.""" 23 | 24 | def __init__(self, element: ABCPipelineElement): 25 | """Initialize the wrapper. 26 | 27 | Parameters 28 | ---------- 29 | element : ABCPipelineElement 30 | The element to wrap. 31 | """ 32 | self.element = element 33 | self.n_transformations = 0 34 | 35 | def fit(self, X: Any, y: Any) -> Self: # pylint: disable=invalid-name 36 | """Fit the data. 37 | 38 | Parameters 39 | ---------- 40 | X : Any 41 | The input data. 42 | y : Any 43 | The target data. 44 | 45 | Returns 46 | ------- 47 | Self 48 | The fitted wrapper. 49 | """ 50 | self.element.fit(X, y) 51 | return self 52 | 53 | def transform(self, X: Any) -> Any: # pylint: disable=invalid-name 54 | """Transform the data. 55 | 56 | Transform is called during prediction, which is not cached.
57 | Hence, the counter is not increased here. 58 | 59 | Parameters 60 | ---------- 61 | X : Any 62 | The input data. 63 | 64 | Returns 65 | ------- 66 | Any 67 | The transformed data. 68 | """ 69 | return self.element.transform(X) 70 | 71 | def fit_transform(self, X: Any, y: Any) -> Any: # pylint: disable=invalid-name 72 | """Fit and transform the data. 73 | 74 | Parameters 75 | ---------- 76 | X : Any 77 | The input data. 78 | y : Any 79 | The target data. 80 | 81 | Returns 82 | ------- 83 | Any 84 | The transformed data. 85 | """ 86 | self.n_transformations += 1 87 | return self.element.fit_transform(X, y) 88 | 89 | def get_params(self, deep: bool = True) -> dict[str, Any]: 90 | """Get the parameters of the transformer. 91 | 92 | Parameters 93 | ---------- 94 | deep : bool 95 | If True, the parameters of the wrapped element are also returned. 96 | 97 | Returns 98 | ------- 99 | dict[str, Any] 100 | The parameters of the transformer. 101 | """ 102 | params = { 103 | "element": self.element, 104 | } 105 | if deep: 106 | params.update(self.element.get_params(deep)) 107 | return params 108 | 109 | def set_params(self, **params: Any) -> Self: 110 | """Set the parameters of the transformer. 111 | 112 | Parameters 113 | ---------- 114 | **params 115 | The parameters to set. 116 | 117 | Returns 118 | ------- 119 | Self 120 | The transformer with the set parameters. 121 | """ 122 | element = params.pop("element", None) 123 | if element is not None: 124 | self.element = element 125 | self.element.set_params(**params) 126 | return self 127 | 128 | 129 | def get_exec_counted_rf_regressor(random_state: int) -> Pipeline: 130 | """Get a Morgan + random forest pipeline, which counts the number of transformations. 131 | 132 | Parameters 133 | ---------- 134 | random_state : int 135 | The random state to use. 136 | 137 | Returns 138 | ------- 139 | Pipeline 140 | A pipeline with a Morgan fingerprint and a random forest regressor. 141 | """ 142 | smi2mol = SmilesToMol() 143 | 144 | mol2concat = CountingTransformerWrapper( 145 | MolToMorganFP(radius=2, n_bits=2048), 146 | ) 147 | rf = RandomForestRegressor(random_state=random_state, n_jobs=1) 148 | return Pipeline( 149 | [ 150 | ("smi2mol", smi2mol), 151 | ("mol2concat", mol2concat), 152 | ("rf", rf), 153 | ], 154 | n_jobs=1, 155 | ) 156 | -------------------------------------------------------------------------------- /tests/utils/fingerprints.py: -------------------------------------------------------------------------------- 1 | """Fingerprint functions for comparing output with molpipeline.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | from rdkit import Chem 8 | from rdkit.Chem import rdFingerprintGenerator as rdkit_fp 9 | from rdkit.DataStructs import ExplicitBitVect, UIntSparseIntVect 10 | from scipy import sparse 11 | 12 | 13 | def make_sparse_fp( 14 | smiles_list: list[str], radius: int, n_bits: int 15 | ) -> sparse.csr_matrix: 16 | """Create a sparse Morgan fingerprint matrix from a list of SMILES. 17 | 18 | Used in unit tests. 19 | 20 | Parameters 21 | ---------- 22 | smiles_list: list[str] 23 | SMILES representations of molecules which will be encoded as fingerprint. 24 | radius: int 25 | Radius of features. 26 | n_bits: int 27 | Obtained features will be mapped to a vector of size n_bits. 28 | 29 | Returns 30 | ------- 31 | sparse.csr_matrix 32 | Feature matrix.
33 | """ 34 | vector_list = [] 35 | morgan_fp = rdkit_fp.GetMorganGenerator(radius=radius, fpSize=n_bits) 36 | for smiles in smiles_list: 37 | mol = Chem.MolFromSmiles(smiles) 38 | vector = morgan_fp.GetFingerprintAsNumPy(mol) 39 | vector_list.append(sparse.csr_matrix(vector)) 40 | return sparse.vstack(vector_list) 41 | 42 | 43 | def fingerprints_to_numpy( 44 | fingerprints: list[ExplicitBitVect] | list[UIntSparseIntVect] | sparse.csr_matrix | npt.NDArray[np.int_], 45 | ) -> npt.NDArray[np.int_]: 46 | """Convert fingerprints of various types to numpy. 47 | 48 | Parameters 49 | ---------- 50 | fingerprints: list[ExplicitBitVect] | list[UIntSparseIntVect] | sparse.csr_matrix | npt.NDArray[np.int_] 51 | Fingerprint matrix. 52 | 53 | Raises 54 | ------ 55 | ValueError 56 | If the fingerprints are not in a supported format. 57 | 58 | Returns 59 | ------- 60 | npt.NDArray[np.int_] 61 | Numpy fingerprint matrix. 62 | 63 | """ 64 | if all(isinstance(fp, ExplicitBitVect) for fp in fingerprints): 65 | return np.array(fingerprints) 66 | if all(isinstance(fp, UIntSparseIntVect) for fp in fingerprints): 67 | return np.array([fp.ToList() for fp in fingerprints]) 68 | if isinstance(fingerprints, sparse.csr_matrix): 69 | return fingerprints.toarray() 70 | if isinstance(fingerprints, np.ndarray): 71 | return fingerprints 72 | raise ValueError("Unknown fingerprint type. Cannot convert to numpy.") 73 | -------------------------------------------------------------------------------- /tests/utils/logging.py: -------------------------------------------------------------------------------- 1 | """Test utils for logging.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Generator 6 | from contextlib import contextmanager 7 | 8 | import loguru 9 | from loguru import logger 10 | 11 | 12 | @contextmanager 13 | def capture_logs( 14 | level: str = "INFO", log_format: str = "{level}:{name}:{message}" 15 | ) -> Generator[list[loguru.Message], None, None]: 16 | """Capture loguru-based logs. 17 | 18 | Custom context manager to test loguru-based logs.
For details and usage examples, 19 | see https://loguru.readthedocs.io/en/latest/resources/migration.html#replacing-assertlogs-method-from-unittest-library 20 | 21 | Parameters 22 | ---------- 23 | level : str, optional 24 | Log level, by default "INFO" 25 | log_format : str, optional 26 | Log format, by default "{level}:{name}:{message}" 27 | 28 | Yields 29 | ------ 30 | list[loguru.Message] 31 | List of log messages 32 | 33 | """ 34 | output: list[loguru.Message] = [] 35 | handler_id = logger.add(output.append, level=level, format=log_format) 36 | try: 37 | yield output 38 | finally: 39 | # remove the handler even if the wrapped test body raises 40 | logger.remove(handler_id) 41 | -------------------------------------------------------------------------------- /tests/utils/mock_element.py: -------------------------------------------------------------------------------- 1 | """Mock PipelineElement for testing.""" 2 | 3 | from __future__ import annotations 4 | 5 | import copy 6 | from collections.abc import Iterable 7 | from typing import Any 8 | 9 | import numpy as np 10 | 11 | try: 12 | from typing import Self # type: ignore[attr-defined] 13 | except ImportError: 14 | from typing_extensions import Self 15 | 16 | from molpipeline.abstract_pipeline_elements.core import ( 17 | InvalidInstance, 18 | TransformingPipelineElement, 19 | ) 20 | 21 | 22 | class MockTransformingPipelineElement(TransformingPipelineElement): 23 | """Mock element for testing.""" 24 | 25 | def __init__( 26 | self, 27 | *, 28 | invalid_values: set[Any] | None = None, 29 | return_as_numpy_array: bool = False, 30 | name: str = "dummy", 31 | uuid: str | None = None, 32 | n_jobs: int = 1, 33 | ) -> None: 34 | """Initialize MockTransformingPipelineElement. 35 | 36 | Parameters 37 | ---------- 38 | invalid_values: set[Any] | None, optional 39 | Set of values to consider invalid. 40 | return_as_numpy_array: bool, default=False 41 | If True, return output as numpy array, otherwise as list. 42 | name: str, default="dummy" 43 | Name of PipelineElement. 44 | uuid: str | None, optional 45 | Unique identifier of PipelineElement. 46 | n_jobs: int, default=1 47 | Number of jobs to run in parallel. 48 | """ 49 | super().__init__(name=name, uuid=uuid, n_jobs=n_jobs) 50 | if invalid_values is None: 51 | invalid_values = set() 52 | self.invalid_values = invalid_values 53 | self.return_as_numpy_array: bool = return_as_numpy_array 54 | 55 | def get_params(self, deep: bool = True) -> dict[str, Any]: 56 | """Return all parameters defining the object. 57 | 58 | Parameters 59 | ---------- 60 | deep: bool 61 | If True, get a deep copy of the parameters. 62 | 63 | Returns 64 | ------- 65 | dict[str, Any] 66 | Dictionary containing all parameters defining the object. 67 | """ 68 | params = super().get_params(deep) 69 | if deep: 70 | params["invalid_values"] = copy.deepcopy(self.invalid_values) 71 | params["return_as_numpy_array"] = copy.deepcopy(self.return_as_numpy_array) 72 | else: 73 | params["invalid_values"] = self.invalid_values 74 | params["return_as_numpy_array"] = self.return_as_numpy_array 75 | return params 76 | 77 | def set_params(self, **parameters: Any) -> Self: 78 | """Set parameters of the object. 79 | 80 | Parameters 81 | ---------- 82 | parameters: Any 83 | Dictionary containing all parameters defining the object. 84 | 85 | Returns 86 | ------- 87 | Self 88 | MockTransformingPipelineElement with updated parameters.
89 | """ 90 | super().set_params(**parameters) 91 | if "invalid_values" in parameters: 92 | self.invalid_values = set(parameters["invalid_values"]) 93 | if "return_as_numpy_array" in parameters: 94 | self.return_as_numpy_array = bool(parameters["return_as_numpy_array"]) 95 | return self 96 | 97 | def pretransform_single(self, value: Any) -> Any: 98 | """Transform input value to other value. 99 | 100 | Parameters 101 | ---------- 102 | value: Any 103 | Input value. 104 | 105 | Returns 106 | ------- 107 | Any 108 | Other value. 109 | """ 110 | if value in self.invalid_values: 111 | return InvalidInstance( 112 | self.uuid, 113 | f"Invalid input value by mock: {value}", 114 | self.name, 115 | ) 116 | return value 117 | 118 | def assemble_output(self, value_list: Iterable[Any]) -> Any: 119 | """Aggregate rows, which in most cases just means returning the list. 120 | 121 | Some representations might be better represented as a single object. For example, a list of vectors 122 | can be transformed to a matrix. 123 | 124 | Parameters 125 | ---------- 126 | value_list: Iterable[Any] 127 | Iterable of transformed rows. 128 | 129 | Returns 130 | ------- 131 | Any 132 | Aggregated output. This can also be the original input. 133 | """ 134 | if self.return_as_numpy_array: 135 | return np.array(list(value_list)) 136 | return list(value_list) 137 | --------------------------------------------------------------------------------
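As a closing, hedged illustration (not part of the repository), the mock element above can stand in for a real pipeline element when exercising molpipeline's error handling. The sketch below combines it with the ErrorFilter/FilterReinserter pattern used in tests/utils/default_models.py; the commented output is an expectation derived from the docstrings, not a verified run.

from molpipeline import Pipeline
from molpipeline.error_handling import ErrorFilter, FilterReinserter
from tests.utils.mock_element import MockTransformingPipelineElement

# Values listed in invalid_values are turned into InvalidInstance objects by
# pretransform_single; ErrorFilter removes them from the stream and
# FilterReinserter re-inserts None at their original positions.
error_filter = ErrorFilter(filter_everything=True)
pipeline = Pipeline(
    [
        ("mock", MockTransformingPipelineElement(invalid_values={"bad"})),
        ("error_filter", error_filter),
        ("error_replacer", FilterReinserter.from_error_filter(error_filter, None)),
    ]
)
result = pipeline.transform(["ok", "bad", "also_ok"])
# expected (assumption): ["ok", None, "also_ok"]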