├── .DS_Store ├── .gitignore ├── README.md ├── docs └── images │ ├── logo.svg │ └── protocol.png ├── env.yml ├── mood ├── __init__.py ├── __main__.py ├── baselines.py ├── chemistry.py ├── cli.py ├── constants.py ├── criteria.py ├── dataset.py ├── distance.py ├── experiment.py ├── metrics.py ├── model │ ├── __init__.py │ ├── base.py │ ├── coral.py │ ├── dann.py │ ├── erm.py │ ├── ib_erm.py │ ├── mixup.py │ ├── mtl.py │ ├── nn.py │ ├── utils.py │ └── vrex.py ├── model_space.py ├── preprocessing.py ├── rct.py ├── representations.py ├── splitter.py ├── train.py ├── transformer.py ├── utils.py └── visualize.py ├── notebooks ├── 01. Dataset properties.ipynb ├── 02. Model space vs input space grid.ipynb ├── 03. Correlation between model space and input space.ipynb ├── 04. Performance over distance grid.ipynb ├── 05. Gap between IID and OOD performance.ipynb ├── 06. Split protocol visualization grid.ipynb ├── 07. Prescribed splits summary.ipynb ├── 08. Check RCT proportions.ipynb ├── 09. Comparing tools and options.ipynb ├── 10. Comparing validation and test performance.ipynb ├── 11. MOOD specification flowchart.ipynb ├── 12. 
MOOD ToC graphic.ipynb ├── ToC_graphic.tiff └── assets │ ├── checkmark.png │ └── cross.png ├── pyproject.toml └── scripts ├── __init__.py ├── __main__.py ├── cli.py ├── compare_performance.py ├── compare_spaces.py ├── compare_splits.py ├── precompute_distances.py ├── precompute_graphormer.py ├── precompute_representations.py ├── visualize_perf_over_distance.py ├── visualize_shift.py └── visualize_splits.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/valence-labs/mood-experiments/4788e0c57f557916792247eadebbe61d2fa91714/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # PyCharm files 132 | .idea/ 133 | # Rever 134 | rever/ 135 | 136 | # Specifically for mood 137 | lightning_logs/ 138 | run.yaml 139 | outputs/ 140 | multirun/ 141 | src/ 142 | notebooks/*.csv 143 | profile/ 144 | argo/ 145 | 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

Molecular Out-Of-Distribution Generalization

4 |

5 | Close the testing-deployment gap in molecular scoring. 6 |

7 |
8 | 9 | --- 10 | 11 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 12 | [![DOI](https://img.shields.io/badge/DOI-10.1021%2Facs--jcim--3c01774-blue)](https://doi.org/10.1021/acs.jcim.3c01774) 13 | 14 | 15 | Python repository with all the code that was used for the [MOOD paper](https://doi.org/10.1021/acs.jcim.3c01774). 16 | 17 | ## Setup 18 | We recommend using `mamba` ([learn more](https://github.com/mamba-org/mamba)). 19 | 20 | ```shell 21 | mamba env create -n mood -f env.yml 22 | conda activate mood 23 | pip install -e . 24 | ``` 25 | 26 | ## Overview 27 | The repository is set up to make the results easy to reproduce. If you get stuck or would like to learn more, please feel free to open an issue. 28 | 29 | ### CLI 30 | After installation, the MOOD CLI can be used to reproduce the results. 31 | ```shell 32 | mood --help 33 | ``` 34 | 35 | ### Data 36 | All data has been made available in a public bucket. See [`https://storage.valencelabs.com/mood-data/`](https://storage.valencelabs.com/mood-data/). 37 | 38 | ### Code 39 | - `mood/`: This is the main part of the codebase. It contains Python implementations of several reusable components and defines the CLI. 40 | - `notebooks/`: Notebooks were used to visualize and otherwise explore the results. All plots in the paper can be reproduced through these notebooks. 41 | - `scripts/`: Generally messier pieces of code that were used to generate (intermediate) results. 42 | 43 | ## Use the MOOD splitting protocol 44 |
45 | 46 |
47 | 48 | One of the main results of the MOOD paper is the MOOD protocol. This protocol helps to close the testing-deployment gap in molecular scoring by finding the most representative splitting method. We made an effort to make the protocol easy for others to use and experiment with. 49 | 50 | ```python 51 | import datamol as dm 52 | import numpy as np 53 | from sklearn.model_selection import ShuffleSplit 54 | from mood.splitter import PerimeterSplit, MaxDissimilaritySplit, PredefinedGroupShuffleSplit, MOODSplitter 55 | 56 | # Load your data 57 | data = dm.data.freesolv() 58 | smiles = data["smiles"].values 59 | X = np.stack([dm.to_fp(dm.to_mol(smi)) for smi in smiles]) 60 | y = data["expt"].values 61 | 62 | # Load your deployment data 63 | X_deployment = np.random.random((100, 2048)).round() 64 | 65 | # Set-up your candidate splitting methods 66 | scaffolds = [dm.to_smiles(dm.to_scaffold_murcko(dm.to_mol(smi))) for smi in smiles] 67 | candidate_splitters = { 68 | "Random": ShuffleSplit(n_splits=5), # MOOD is Scikit-learn compatible! 69 | "Scaffold": PredefinedGroupShuffleSplit(groups=scaffolds, n_splits=5), 70 | "Perimeter": PerimeterSplit(n_splits=5), 71 | "Maximum Dissimilarity": MaxDissimilaritySplit(n_splits=5), 72 | } 73 | 74 | # Set-up the MOOD splitter 75 | mood_splitter = MOODSplitter(candidate_splitters, metric="jaccard") 76 | mood_splitter.fit(X, y, X_deployment=X_deployment) 77 | 78 | for train, test in mood_splitter.split(X, y): 79 | # Work your magic! 80 | ...
81 | ``` 82 | 83 | ## How to cite 84 | Please cite MOOD if you use it in your research: [![DOI](https://img.shields.io/badge/DOI-10.1021%2Facs--jcim--3c01774-blue)](https://doi.org/10.1021/acs.jcim.3c01774) 85 | 86 | ``` 87 | Real-World Molecular Out-Of-Distribution: Specification and Investigation 88 | Prudencio Tossou, Cas Wognum, Michael Craig, Hadrien Mary, and Emmanuel Noutahi 89 | Journal of Chemical Information and Modeling 2024 64 (3), 697-711 90 | DOI: 10.1021/acs.jcim.3c01774 91 | ``` 92 | 93 | ```bib 94 | @article{doi:10.1021/acs.jcim.3c01774, 95 | author = {Tossou, Prudencio and Wognum, Cas and Craig, Michael and Mary, Hadrien and Noutahi, Emmanuel}, 96 | title = {Real-World Molecular Out-Of-Distribution: Specification and Investigation}, 97 | journal = {Journal of Chemical Information and Modeling}, 98 | volume = {64}, 99 | number = {3}, 100 | pages = {697-711}, 101 | year = {2024}, 102 | doi = {10.1021/acs.jcim.3c01774}, 103 | note = {PMID: 38300258}, 104 | URL = {https://doi.org/10.1021/acs.jcim.3c01774}, 105 | eprint = {https://doi.org/10.1021/acs.jcim.3c01774} 106 | } 107 | ``` 108 | 109 | -------------------------------------------------------------------------------- /docs/images/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /docs/images/protocol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/valence-labs/mood-experiments/4788e0c57f557916792247eadebbe61d2fa91714/docs/images/protocol.png -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: mood_v2 2 | 3 | channels: 4 | - conda-forge 5 | 6 | dependencies: 7 | - pip 8 | - python =3.10 9 | - 
pandas 10 | - matplotlib 11 | - scikit-learn 12 | - torchmetrics 13 | - pytorch-lightning <2.0 14 | - pytorch >=1.10.2 15 | - numpy <1.24 16 | - tqdm 17 | - optuna 18 | - datamol 19 | - notebook 20 | - pytdc 21 | - typer 22 | - gcsfs 23 | - pyarrow 24 | - fastparquet 25 | - transformers 26 | 27 | # Dev 28 | - black >=20.8b1 29 | 30 | -------------------------------------------------------------------------------- /mood/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mood/__main__.py: -------------------------------------------------------------------------------- 1 | from mood.cli import app 2 | 3 | 4 | if __name__ == "__main__": 5 | app() 6 | -------------------------------------------------------------------------------- /mood/baselines.py: -------------------------------------------------------------------------------- 1 | import optuna 2 | import numpy as np 3 | import datamol as dm 4 | 5 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier, VotingRegressor, VotingClassifier 6 | from sklearn.neural_network import MLPRegressor, MLPClassifier 7 | from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier 8 | from sklearn.gaussian_process.kernels import PairwiseKernel, Sum, WhiteKernel 9 | from sklearn.calibration import CalibratedClassifierCV 10 | from sklearn.base import ClassifierMixin, clone 11 | 12 | 13 | """ 14 | These are the three baselines we consider in the MOOD study 15 | For benchmarking, we use a torch MLP rather than the one from scikit-learn. 
def uncertainty_wrapper(model, ensemble_size: int = 10, calibrate: bool = True):
    """Wraps the model so that it can be used for uncertainty estimation.

    This includes at most two steps:
      1. MLPs are turned into a voting ensemble of `ensemble_size` clones.
      2. If `calibrate` is set, RF classifiers and (ensembled) MLP classifiers
         are wrapped in a `CalibratedClassifierCV`.

    Args:
        model: The scikit-learn estimator to wrap.
        ensemble_size: The number of MLP clones to put in the ensemble.
        calibrate: Whether to calibrate RF and MLP classifiers.

    Returns:
        The wrapped estimator, or `model` itself if neither step applies.
    """
    # Remember the original type: once the MLP is replaced by a voting ensemble,
    # an isinstance check against MLPClassifier on `model` can no longer succeed.
    is_mlp_classifier = isinstance(model, MLPClassifier)

    if is_mlp_classifier or isinstance(model, MLPRegressor):
        models = []
        for idx in range(ensemble_size):
            model = clone(model)
            # NOTE: Each clone is derived from the previous one, so the seed offsets
            # accumulate (+0, +1, +3, +6, ...). This is kept as-is so that ensembles
            # built with a given seed stay reproducible; the seeds remain distinct.
            # An unseeded MLP (random_state=None) is left unseeded instead of failing.
            if model.random_state is not None:
                model.set_params(random_state=model.random_state + idx)
            models.append((f"mlp_{idx}", model))

        if is_mlp_classifier:
            model = VotingClassifier(models, voting="soft", n_jobs=-1)
        else:
            model = VotingRegressor(models, n_jobs=-1)

    # Explicit parentheses: `calibrate` must gate both classifier types.
    if calibrate and (is_mlp_classifier or isinstance(model, RandomForestClassifier)):
        model = CalibratedClassifierCV(model)
    return model
def construct_kernel(is_regression, params):
    """Constructs a scikit-learn kernel based on provided hyper-parameters.

    The kernel-specific entries ("kernel_metric", "kernel_gamma", "kernel_coef0" and,
    for classification, "kernel_noise_level") are popped from `params`. Note that
    `params` is therefore modified in place.

    Args:
        is_regression: Whether the kernel is for a regression task. For classification,
            a white-noise kernel is added to the pairwise kernel.
        params: The hyper-parameters, including the kernel-specific entries.

    Returns:
        A tuple of the kernel and the remaining hyper-parameters
        (the same dict object as `params`).
    """
    pairwise_kernel = PairwiseKernel(
        metric=params.pop("kernel_metric"),
        gamma=params.pop("kernel_gamma"),
        pairwise_kernels_kwargs={"coef0": params.pop("kernel_coef0")},
    )

    if is_regression:
        return pairwise_kernel, params

    white_kernel = WhiteKernel(noise_level=params.pop("kernel_noise_level"))
    return Sum(pairwise_kernel, white_kernel), params
def compute_generic_scaffold(smi: str):
    """Computes the generic scaffold (i.e. the domain) for the datapoint.

    The generic scaffold is the structural graph of the Murcko scaffold.

    Args:
        smi (str): The SMILES string of the molecule to find the generic scaffold for
    Returns:
        The SMILES of the generic scaffold of the input SMILES,
        or the empty string if making the Murcko scaffold generic failed.
    """
    murcko_smiles = compute_murcko_scaffold(smi)

    with dm.without_rdkit_log(mute_errors=False):
        try:
            murcko_mol = dm.to_mol(murcko_smiles)
            generic_mol = MurckoScaffold.MakeScaffoldGeneric(mol=murcko_mol)
            return dm.to_smiles(generic_mol)
        except Exception as exception:
            # Best effort: a failure here maps the molecule to the "empty" domain
            logger.debug(f"Failed to compute the GenericScaffold for {smi} due to {exception}")
            logger.debug(f"Returning the empty SMILES as the scaffold")
            return ""
""" 5 | Where results and data are saved to 6 | """ 7 | CACHE_DIR = dm.fs.get_cache_dir("MOOD") 8 | 9 | """ 10 | For the downstream applications (optimization and virtual screening) 11 | we save all related data to this directory 12 | """ 13 | DOWNSTREAM_APPS_DATA_DIR = "https://storage.valencelabs.com/mood-data/downstream_applications/" 14 | 15 | """ 16 | Where the results of MOOD are saved 17 | """ 18 | RESULTS_DIR = "https://storage.valencelabs.com/mood-data/results/" 19 | 20 | """ 21 | The two downstream applications we consider for MOOD as application areas of molecular scoring 22 | """ 23 | SUPPORTED_DOWNSTREAM_APPS = ["virtual_screening", "optimization"] 24 | 25 | """ 26 | Where data related to specific datasets is saved 27 | """ 28 | DATASET_DATA_DIR = "https://storage.valencelabs.com/mood-data/datasets/" 29 | 30 | 31 | """The number of epochs to train NNs for""" 32 | NUM_EPOCHS = 100 33 | -------------------------------------------------------------------------------- /mood/criteria.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Sequence 3 | 4 | import numpy as np 5 | 6 | from mood.dataset import SimpleMolecularDataset 7 | from mood.distance import compute_knn_distance, get_distance_metric 8 | from mood.metrics import Metric 9 | 10 | 11 | MOOD_CRITERIA = [ 12 | "Performance", 13 | "Domain Weighted Performance", 14 | "Distance Weighted Performance", 15 | "Calibration", 16 | "Calibration x Performance", 17 | ] 18 | 19 | 20 | def get_mood_criteria(performance_metric, calibration_metric): 21 | """Endpoint for easily creating a criterion by name""" 22 | 23 | return { 24 | "Performance": PerformanceCriterion(performance_metric), 25 | "Calibration": CalibrationCriterion(calibration_metric), 26 | "Domain Weighted Performance": DomainWeightedPerformanceCriterion(performance_metric), 27 | "Distance Weighted Performance": DistanceWeightedPerformanceCriterion(performance_metric), 28 | 
class ModelSelectionCriterion(abc.ABC):
    """
    In MOOD, we argue that one of the tools to improve _model selection_, is the criterion we use
    to select. Besides selecting for raw, validation performance we suspect there could be better alternatives.
    This class defines the interface for a criterion to implement.

    We distinguish multiple iterations within a hyper-parameter search trial.
    For example, you might train and evaluate a model on N different splits before scoring the hyper-parameters.
    """

    def __init__(self, mode, needs_uncertainty: bool):
        """
        Args:
            mode: The optimization mode of the criterion. Subclasses pass the mode
                of their metric(s) here.
            needs_uncertainty: Whether the criterion needs uncertainty estimates to be scored.
        """
        self.mode = mode
        self.needs_uncertainty = needs_uncertainty
        self.scores = []

    @abc.abstractmethod
    def score(
        self, predictions, uncertainties, train: "SimpleMolecularDataset", val: "SimpleMolecularDataset"
    ):
        """Scores a single iteration. Needs to be implemented by subclasses."""
        pass

    def compute_weights(self, train: "SimpleMolecularDataset", val: "SimpleMolecularDataset"):
        """Computes per-sample weights. Returns None by default; subclasses can override this."""
        return None

    def __call__(self, *args, **kwargs):
        return self.score(*args, **kwargs)

    def update(
        self, predictions, uncertainties, train: "SimpleMolecularDataset", val: "SimpleMolecularDataset"
    ):
        """Scores a single iteration with the hyper-parameter search."""
        self.scores.append(self.score(predictions, uncertainties, train, val))

    def critique(self):
        """Aggregates the scores of individual iterations and resets the criterion.

        Returns:
            The mean score over all iterations. If each iteration produced a sequence of
            scores, a list with the mean of each dimension is returned instead.

        Raises:
            RuntimeError: If no scores have been computed yet or if the scores
                do not all have the same number of dimensions.
        """
        if len(self.scores) == 0:
            raise RuntimeError("Cannot critique when no scores have been computed yet")

        if isinstance(self.scores[0], Sequence):
            lengths = set(len(s) for s in self.scores)
            if len(lengths) != 1:
                raise RuntimeError("All scores need to have the same number of dimensions")
            n = lengths.pop()
            score = [np.mean([s[i] for s in self.scores]) for i in range(n)]
        else:
            score = np.mean(self.scores)

        # Reset in both branches, so that scores of one trial never leak into the next
        self.reset()
        return score

    def reset(self):
        """Clears the scores of all iterations so far."""
        self.scores = []
class SimpleMolecularDataset(torch.utils.data.Dataset):
    """
    Thin wrapper that exposes a molecular dataset to PyTorch (Lightning).
    On top of the features and targets, every sample carries a molecular domain
    (its generic scaffold), which is needed for Domain Generalization.
    """

    def __init__(self, smiles, X, y):
        self.smiles = smiles
        self.X = X
        self.y = y
        self.random_state = None
        # The domain of a molecule is its generic scaffold
        self.domains = np.array([compute_generic_scaffold(smi) for smi in self.smiles])
        # Only available after calling `compute_domain_representations()`
        self.domain_representations = None

    def __getitem__(self, index):
        """Returns `(x, y)`, or `((x, domain), y)` if domains are set. If domain representations
        have been computed, `domain` is itself a `(domain, representation)` tuple."""
        features = self.X[index]
        if self.domains is None:
            return features, self.y[index]

        domain = self.domains[index]
        if self.domain_representations is not None:
            domain = (domain, self.domain_representations[domain])
        return (features, domain), self.y[index]

    def __len__(self):
        return len(self.X)

    def compute_domain_representations(self):
        """Represents each unique domain by the mean of the features of its members."""
        self.domain_representations = representations = {}
        unique_domains, inverse = np.unique(self.domains, return_inverse=True, axis=0)
        for idx in np.unique(inverse):
            members = np.flatnonzero(inverse == idx)
            representations[unique_domains[idx]] = self.X[members].mean(axis=0)

    def filter_by_indices(self, indices):
        """Returns a deep copy of the dataset, restricted to the given indices."""
        subset = deepcopy(self)
        for attr in ("smiles", "X", "y"):
            setattr(subset, attr, getattr(subset, attr)[indices])
        if subset.domains is not None:
            subset.domains = subset.domains[indices]
        return subset
def domain_based_collate(batch):
    """Custom collate function that groups the samples of a batch by domain
    and collates each group into its own mini-batch.

    Each sample is expected to look like `((X, domain), y)`. If `domain` is a tuple,
    its first element is used for grouping and its second element replaces the domain
    in the collated sample.

    Returns:
        A list with one collated mini-batch per unique domain in the batch.
    """
    domain_keys, samples = [], []
    for (X, domain), y in batch:
        if isinstance(domain, tuple):
            domain_keys.append(domain[0])
            samples.append(((X, domain[1]), y))
        else:
            domain_keys.append(domain)
            samples.append(((X, domain), y))

    _, inverse = np.unique(domain_keys, return_inverse=True, axis=0)

    return [
        default_collate([samples[i] for i in np.flatnonzero(inverse == idx)]) for idx in np.unique(inverse)
    ]
def dataset_iterator(
    disable_logs: bool = True,
    whitelist: Optional[List[str]] = None,
    blacklist: Optional[List[str]] = None,
):
    """Endpoint for iterating over several datasets from TDC

    Args:
        disable_logs: Forwarded to `load_data_from_tdc`.
        whitelist: If specified, only these datasets are loaded.
        blacklist: If specified, these datasets are skipped.

    Yields:
        Tuples of the dataset name and the result of `load_data_from_tdc` for that dataset.

    Raises:
        ValueError: If both a whitelist and a blacklist are specified. Since this is a
            generator, the error is only raised once iteration starts.
    """
    if not (whitelist is None or blacklist is None):
        msg = "You cannot use a blacklist and whitelist at the same time"
        raise ValueError(msg)

    # At most one of the two filters is active at this point
    selected = [
        name
        for name in MOOD_DATASETS
        if (whitelist is None or name in whitelist) and (blacklist is None or name not in blacklist)
    ]

    for name in selected:
        yield name, load_data_from_tdc(name, disable_logs)
def get_distance_metric(example):
    """Pick the distance metric that fits an exemplary datapoint.

    Returns "jaccard" when every value is either 0 or 1 and "euclidean" otherwise.
    """

    # DataFrames would require all().all(), so compare on the raw values instead
    if isinstance(example, pd.DataFrame):
        example = example.values

    is_binary = ((example == 0) | (example == 1)).all()
    return "jaccard" if is_binary else "euclidean"


def compute_knn_distance(
    X: np.ndarray,
    Y: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
    metric: Optional[str] = None,
    k: int = 5,
    n_jobs: Optional[int] = None,
    return_indices: bool = False,
):
    """
    Computes the mean k-Nearest Neighbors distance
    between a set of database embeddings and a set of query embeddings

    Args:
        X: The set of samples that form kNN candidates
        Y: The samples for which to find the kNN for, or a list of such sets.
            If None, it is passed on to `kneighbors` as-is.
        metric: The pairwise distance metric to define the neighborhood.
            If None, it is derived from the first sample of `X`.
        k: The number of neighbors to find
        n_jobs: Controls the parallelization
        return_indices: Whether to return the indices of the NNs as well

    Returns:
        The mean distance to the k neighbors for every query. A list with one
        entry per query set if `Y` is a list with several sets. If
        `return_indices` is set, a tuple of these distances and the indices.
    """

    if metric is None:
        metric = get_distance_metric(X[0])

    knn = NearestNeighbors(n_neighbors=k, metric=metric, n_jobs=n_jobs)
    knn.fit(X)

    query_sets = Y if isinstance(Y, list) else [Y]

    distances, indices = [], []
    for queries in query_sets:
        if np.array_equal(X, queries):
            # Use k + 1 and filter out the first
            # because the sample will always be its own neighbor
            dist, ind = knn.kneighbors(queries, n_neighbors=k + 1)
            dist = dist[:, 1:]
            ind = ind[:, 1:]
        else:
            dist, ind = knn.kneighbors(queries, n_neighbors=k)

        # The distance from the query molecule to its NNs is the mean of all pairwise distances
        distances.append(np.mean(dist, axis=1))
        indices.append(ind)

    # A single query set is returned unwrapped
    if len(distances) == 1:
        distances, indices = distances[0], indices[0]

    return (distances, indices) if return_indices else distances
def run_study(metric, algorithm, n_startup_trials, n_trials, trial_fn, seed):
    """Endpoint for running an Optuna study.

    Creates a study with a seeded TPE sampler, optimizes `trial_fn` for
    `n_trials` trials and returns the study. `metric.mode` is either a single
    mode or a list of modes; "max" maps to maximize and anything else to minimize.
    """

    def _to_direction(mode):
        return "maximize" if mode == "max" else "minimize"

    sampler = optuna.samplers.TPESampler(seed=seed, n_startup_trials=n_startup_trials)

    if isinstance(metric.mode, list):
        study = optuna.create_study(
            directions=[_to_direction(mode) for mode in metric.mode], sampler=sampler
        )
    else:
        study = optuna.create_study(direction=_to_direction(metric.mode), sampler=sampler)

    # Exceptions that fail a single trial instead of aborting the whole study
    tolerated = {
        # ValueError: array must not contain infs or NaNs
        # LinAlgError: N-th leading minor of the array is not positive definite
        # LinAlgError: The kernel is not returning a positive definite matrix
        "GP": (np.linalg.LinAlgError, ValueError),
        # RuntimeError: all elements of input should be between 0 and 1
        # NOTE: This is not robust (as other RunTimeErrors could be thrown)
        #  but is an easy, performant way to check for NaN values which often
        #  occurred for Mixup due to high losses on the first few batches
        "Mixup": (RuntimeError,),
    }
    catch = tolerated.get(algorithm, ())

    study.optimize(trial_fn, n_trials=n_trials, catch=catch)
    return study


def basic_tuning_loop(
    X_train,
    X_test,
    y_train,
    y_test,
    name: str,
    is_regression: bool,
    metric: Metric,
    global_seed: int,
    for_uncertainty_estimation: bool = False,
    ensemble_size: int = 10,
    n_trials: int = 50,
    n_startup_trials: int = 10,
):
    """
    Hyper-parameter search loop used to train the baseline models for the MOOD specification.

    Every trial trains a calibrated baseline on the train set with a seed of
    `global_seed + trial.number` and is scored with `metric` on the test set.

    NOTE: This could be merged with the more elaborate tuning loop we wrote later
      However, for the sake of reproducibility, I wanted to keep this code intact.
      This way, the exact code used to generate results is still easily accessible
      in the code base
    """

    def _objective(trial):
        trial_seed = global_seed + trial.number
        hparams = suggest_baseline_hparams(name, is_regression, trial)
        fitted = train_baseline_model(
            X_train,
            y_train,
            name,
            is_regression,
            hparams,
            trial_seed,
            for_uncertainty_estimation,
            ensemble_size,
            calibrate=True,
        )
        return metric(y_test, fitted.predict(X_test))

    return run_study(
        metric=metric,
        algorithm=name,
        n_startup_trials=n_startup_trials,
        n_trials=n_trials,
        trial_fn=_objective,
        seed=global_seed,
    )
def rct_dataset_setup(dataset, train_indices, val_indices, test_dataset, is_regression):
    """Sets up the dataset. Specifically, splits the dataset and standardizes the targets for regression tasks.

    The test dataset is deep-copied, so the object passed in is left untouched.
    """

    train_dataset = dataset.filter_by_indices(train_indices)
    val_dataset = dataset.filter_by_indices(val_indices)
    test_dataset = deepcopy(test_dataset)

    if not is_regression:
        return train_dataset, val_dataset, test_dataset

    # Z-standardization of the targets, fitted on the train targets only
    scaler = StandardScaler()
    train_dataset.y = scaler.fit_transform(train_dataset.y)
    val_dataset.y = scaler.transform(val_dataset.y)
    test_dataset.y = scaler.transform(test_dataset.y)
    return train_dataset, val_dataset, test_dataset


def rct_predict_step(model, dataset):
    """Get the predictions and uncertainty estimates from either a scikit-learn model or torch model.

    Raises:
        NotImplementedError: If the model is neither a `BaseEstimator` nor an `Ensemble`.
    """
    if isinstance(model, BaseEstimator):
        y_pred = model.predict(dataset.X).reshape(-1, 1)
        return y_pred, predict_baseline_uncertainty(model, dataset.X)

    if isinstance(model, Ensemble):
        # The first member decides on the collate function and the batch size
        first_member = model.models[0]
        collate_fn = domain_based_inference_collate if is_domain_generalization(first_member) else None
        dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=first_member.batch_size)
        y_pred = model.predict(dataloader)
        return y_pred, model.predict_uncertainty(dataloader)

    raise NotImplementedError


def rct_compute_metrics(
    y_true, y_pred, y_uncertainty, performance_metric, calibration_metric, is_regression, prefix, suffix
):
    """Compute the performance and calibration scores and return them as a flat dict.

    The keys follow the pattern `{prefix}_{kind}_{metric name}_{suffix}`.
    The calibration score is None when no uncertainty estimates are given.
    """
    prf_score = performance_metric(y_true, y_pred)

    # NOTE: Ideally we would always compute the calibration metric,
    #  but that was too computationally expensive due to the need of ensembles
    cal_score = None if y_uncertainty is None else calibration_metric(y_true, y_pred, y_uncertainty)

    ret = {
        f"{prefix}_calibration_{calibration_metric.name}_{suffix}": cal_score,
        f"{prefix}_performance_{performance_metric.name}_{suffix}": prf_score,
    }
    if not is_regression:
        return ret

    targets = Metric.preprocess_targets(y_true, is_regression)
    predictions = Metric.preprocess_predictions(y_pred, targets.device)

    # NOTE: Before starting the RCT, we were not sure what metric to use
    #  to compare models for regression tasks, that's why we compute some extra here
    ret[f"{prefix}_extra_r2_{suffix}"] = r2_score(preds=predictions, target=targets).item()
    ret[f"{prefix}_extra_spearman_{suffix}"] = spearman_corrcoef(preds=predictions, target=targets).item()
    ret[f"{prefix}_extra_pearson_{suffix}"] = pearson_corrcoef(preds=predictions, target=targets).item()
    ret[f"{prefix}_extra_mae_{suffix}"] = mean_absolute_error(preds=predictions, target=targets).item()
    return ret
def rct_evaluate_step(
    model,
    train_dataset,
    val_dataset,
    test_dataset,
    performance_metric,
    calibration_metric,
    is_regression,
    suffix,
    criterion: Optional = None,
):
    """Evaluate a model on the validation and test set and return all metrics in one dict.

    If a criterion is given, it is updated with the validation predictions.
    """

    def _evaluate(split_dataset, prefix):
        y_pred, uncertainty = rct_predict_step(model, split_dataset)
        scores = rct_compute_metrics(
            y_true=split_dataset.y,
            y_pred=y_pred,
            y_uncertainty=uncertainty,
            performance_metric=performance_metric,
            calibration_metric=calibration_metric,
            is_regression=is_regression,
            prefix=prefix,
            suffix=suffix,
        )
        return y_pred, uncertainty, scores

    val_y_pred, val_uncertainty, val_metrics = _evaluate(val_dataset, "val")
    _, _, test_metrics = _evaluate(test_dataset, "test")

    # Update the criterion used to select which model is best
    if criterion is not None:
        criterion.update(val_y_pred, val_uncertainty, train_dataset, val_dataset)

    val_metrics.update(test_metrics)
    return val_metrics


def rct_tuning_loop(
    train_val_dataset: SimpleMolecularDataset,
    test_dataset: SimpleMolecularDataset,
    algorithm: str,
    train_val_split: str,
    criterion_name: str,
    performance_metric: Metric,
    calibration_metric: Metric,
    is_regression: bool,
    global_seed: int,
    num_repeated_splits: int = 3,
    num_trials: int = 50,
    num_startup_trials: int = 10,
):
    """
    This hyper-parameter search loop is used to benchmark different tools to improve generalization in
    the MOOD investigation. It combines training scikit-learn and pytorch (lightning) models.

    Every trial draws its own seed, trains one model per repeated train-val split and
    returns the critique of the selection criterion. Returns the finished study.
    """

    criterion = get_mood_criteria(performance_metric, calibration_metric)[criterion_name]

    # One pre-drawn seed per trial, indexed by the trial number
    trial_seeds = np.random.default_rng(global_seed).integers(0, 2**16, num_trials)

    def run_trial(trial: optuna.Trial):
        trial_seed = trial_seeds[trial.number].item()
        trial.set_user_attr("trial_seed", trial_seed)

        splitter = get_mood_splitters(train_val_dataset.smiles, num_repeated_splits, trial_seed, n_jobs=-1)[
            train_val_split
        ]

        for split_idx, (train_ind, val_ind) in enumerate(splitter.split(train_val_dataset.X)):
            train_dataset, val_dataset, test_dataset_inner = rct_dataset_setup(
                train_val_dataset, train_ind, val_ind, test_dataset, is_regression
            )

            # NOTE: AUROC is not defined when there's just a single ground truth class.
            #  Since this only happens for the unbalanced and small HIA dataset, we just skip.
            single_class = len(np.unique(val_dataset.y)) == 1
            if performance_metric.name == "AUROC" and single_class:
                continue

            if algorithm in MOOD_DA_DG_ALGORITHMS:
                params = MOOD_DA_DG_ALGORITHMS[algorithm].suggest_params(trial)
            else:
                params = suggest_baseline_hparams(algorithm, is_regression, trial)

            # NOTE: If we do not select models based on uncertainty,
            #  we don't train an ensemble to reduce computational cost
            ensemble_size = 5 if criterion.needs_uncertainty else 1

            model = train(
                train_dataset=train_dataset,
                val_dataset=val_dataset,
                test_dataset=test_dataset_inner,
                algorithm=algorithm,
                is_regression=is_regression,
                params=params,
                seed=trial_seed,
                calibrate=False,
                ensemble_size=ensemble_size,
            )

            split_metrics = rct_evaluate_step(
                model=model,
                train_dataset=train_dataset,
                val_dataset=val_dataset,
                test_dataset=test_dataset_inner,
                performance_metric=performance_metric,
                calibration_metric=calibration_metric,
                is_regression=is_regression,
                suffix=str(split_idx),
                criterion=criterion,
            )

            # We save the val and test performance for each trial to analyze the success
            # of the model selection procedure (gap between best and selected model)
            for key, value in split_metrics.items():
                trial.set_user_attr(key, value)

        return criterion.critique()

    return run_study(
        metric=criterion,
        algorithm=algorithm,
        n_startup_trials=num_startup_trials,
        n_trials=num_trials,
        trial_fn=run_trial,
        seed=global_seed,
    )
def tune_cmd(
    dataset,
    algorithm,
    representation,
    train_val_split,
    criterion,
    seed: int = 0,
    use_cache: bool = False,
    base_save_dir: str = RESULTS_DIR,
    sub_save_dir: Optional[str] = None,
    overwrite: bool = False,
):
    """
    The MOOD tuning loop: Runs a hyper-parameter search.

    Prescribes a train-test split based on the MOOD specification and runs a hyper-parameter search
    for the training set. The selected configuration is then retrained as an ensemble of 5 and
    evaluated, after which the study is written to a CSV file and the selected model to a YAML file.

    Args:
        dataset: Name of the dataset, passed to `load_data_from_tdc`.
        algorithm: Name of the algorithm to tune.
        representation: Name of the representation used to featurize the SMILES.
        train_val_split: Key of the splitter used for the train-validation split.
        criterion: Name of the model selection criterion.
        seed: Seed for the train-test splitters and the hyper-parameter search.
        use_cache: The precomputed distances are loaded with `update_cache=not use_cache`.
        base_save_dir: Root directory below which the CSV and YAML files are saved.
        sub_save_dir: Subdirectory for the results. Defaults to the current date as YYYYMMDD.
        overwrite: If False and both output files already exist, the function returns early.
    """

    if sub_save_dir is None:
        sub_save_dir = datetime.now().strftime("%Y%m%d")

    csv_out_dir = dm.fs.join(base_save_dir, "dataframes", "RCT", sub_save_dir)
    csv_fname = f"rct_study_{dataset}_{algorithm}_{representation}_{train_val_split}_{criterion}_{seed}.csv"
    csv_path = dm.fs.join(csv_out_dir, csv_fname)
    dm.fs.mkdir(csv_out_dir, exist_ok=True)

    yaml_out_dir = dm.fs.join(base_save_dir, "YAML", "RCT", sub_save_dir)
    yaml_fname = (
        f"rct_selected_model_{dataset}_{algorithm}_{representation}_{train_val_split}_{criterion}_{seed}.yaml"
    )
    yaml_path = dm.fs.join(yaml_out_dir, yaml_fname)
    dm.fs.mkdir(yaml_out_dir, exist_ok=True)

    # Only skip when both outputs are present; a single missing file triggers a full rerun
    if not overwrite and dm.fs.exists(yaml_path) and dm.fs.exists(csv_path):
        logger.info(f"Both the files already exists and overwrite=False. Skipping!")
        return

    # Load and preprocess the data
    smiles, y = load_data_from_tdc(dataset, disable_logs=True)
    X, mask = featurize(smiles, representation, DEFAULT_PREPROCESSING[representation], disable_logs=True)
    X = X.astype(np.float32)
    # Keep the SMILES and targets aligned with the featurized rows
    smiles = smiles[mask]
    y = y[mask]

    is_regression = dataset in MOOD_REGR_DATASETS
    if is_regression:
        # Regression targets are turned into a column vector
        y = y.reshape(-1, 1)

    # Prescribe a train-test split
    distances_vs = load_distances_for_downstream_application(
        "virtual_screening", representation, dataset, update_cache=not use_cache
    )
    distances_op = load_distances_for_downstream_application(
        "optimization", representation, dataset, update_cache=not use_cache
    )

    distance_metric = get_distance_metric(X)
    splitters = get_mood_splitters(smiles, 5, seed, n_jobs=-1)
    train_test_splitter = MOODSplitter(
        splitters, np.concatenate((distances_vs, distances_op)), distance_metric, k=5
    )
    train_test_splitter.fit(X)

    # Split the data using the prescribed split
    trainval, test = next(train_test_splitter.split(X, y))
    train_val_dataset = SimpleMolecularDataset(smiles[trainval], X[trainval], y[trainval])
    test_dataset = SimpleMolecularDataset(smiles[test], X[test], y[test])

    if needs_domain_representation(algorithm):
        train_val_dataset.compute_domain_representations()
        test_dataset.compute_domain_representations()

    # Get metrics for this dataset
    performance_metric = Metric.get_default_performance_metric(dataset)
    calibration_metric = Metric.get_default_calibration_metric(dataset)

    # Run the hyper-parameter search
    study = rct_tuning_loop(
        train_val_dataset=train_val_dataset,
        test_dataset=test_dataset,
        algorithm=algorithm,
        train_val_split=train_val_split,
        criterion_name=criterion,
        performance_metric=performance_metric,
        calibration_metric=calibration_metric,
        is_regression=is_regression,
        global_seed=seed,
    )

    # Train the best model found again, but this time as an ensemble
    # to evaluate the test performance and calibration
    if len(study.directions) > 1:
        logger.info(f"There's {len(study.best_trials)} models on the Pareto front. Picking one randomly!")
        # Seeded, so the pick is reproducible for a given seed
        rng = np.random.default_rng(seed)
        best_trial = rng.choice(study.best_trials)
    else:
        best_trial = study.best_trial

    # NOTE: Some methods are really sensitive to hyper-parameters (e.g. GPs, Mixup)
    #  So with a different train-val split, these might no longer succeed to train.
    #  That is why the seed of the selected trial is reused for the split and the training.
    random_state = best_trial.user_attrs["trial_seed"]
    splitters = get_mood_splitters(train_val_dataset.smiles, 1, random_state, n_jobs=-1)
    train_val_splitter = splitters[train_val_split]
    train_ind, val_ind = next(train_val_splitter.split(train_val_dataset.X))

    train_dataset, val_dataset, test_dataset = rct_dataset_setup(
        train_val_dataset, train_ind, val_ind, test_dataset, is_regression
    )
    model = train(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        algorithm=algorithm,
        is_regression=is_regression,
        params=best_trial.params,
        seed=random_state,
        calibrate=False,
        ensemble_size=5,
    )

    # No criterion is passed here: the final evaluation only reports metrics
    metrics = rct_evaluate_step(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        performance_metric=performance_metric,
        calibration_metric=calibration_metric,
        is_regression=is_regression,
        suffix="final",
    )

    # Save the full trial results as a CSV
    logger.info(f"Saving the full study data to {csv_path}")
    df = study.trials_dataframe()
    df["dataset"] = dataset
    df["algorithm"] = algorithm
    df["representation"] = representation
    df["train-val split"] = train_val_split
    df["criterion"] = criterion
    df["seed"] = seed
    df.to_csv(csv_path)

    # Save the most important information as YAML (higher precision)
    # Later entries win on duplicate keys: the user attributes first, then the final metrics
    data = {
        "hparams": best_trial.params,
        "criterion_final": best_trial.values,
        "dataset": dataset,
        "algorithm": algorithm,
        "representation": representation,
        "train_val_split": train_val_split,
        "criterion": criterion,
        "seed": seed,
        **best_trial.user_attrs,
        **metrics,
    }

    logger.info(f"Saving the data of the best model to {yaml_path}")
    with fsspec.open(yaml_path, "w") as fd:
        yaml.dump(data, fd)
def rct_cmd(
    dataset: str,
    index: int,
    base_save_dir: str = RESULTS_DIR,
    sub_save_dir: Optional[str] = None,
    overwrite: bool = False,
):
    """
    Entrypoint for the benchmarking study in the MOOD Investigation.

    Looks up the experimental configuration at position `index` in the list of
    configurations for `dataset` and runs the tuning loop for it.

    Here an experimental configuration consists of an algorithm, representation, train-val split,
    model selection criterion and seed.
    """

    configs = get_experimental_configurations(dataset)
    logger.info(f"Sampled configuration #{index} / {len(configs)} for {dataset}: {configs[index]}")

    config_keys = ("algorithm", "representation", "train_val_split", "criterion", "seed")
    algorithm, representation, train_val_split, criterion, seed = configs[index]
    config = dict(zip(config_keys, (algorithm, representation, train_val_split, criterion, seed)))

    tune_cmd(
        dataset=dataset,
        base_save_dir=base_save_dir,
        sub_save_dir=sub_save_dir,
        overwrite=overwrite,
        **config,
    )
def weighted_pearson(preds, target, sample_weights=None):
    """
    The weighted Pearson correlation coefficient. Based on:
    https://stats.stackexchange.com/a/222107

    Without sample weights this defers to the TorchMetrics implementation.
    The weighted result is clamped to the range [-1, 1].
    """
    if sample_weights is None:
        # If no sample weights are provided, just rely on TorchMetrics' implementation
        return pearson_corrcoef(preds=preds, target=target)

    def _weighted_mean(x, w):
        return torch.sum(w * x) / torch.sum(w)

    # Copied over the implementation from TorchMetric and made three changes:
    # 1) Instead of computing the mean, compute the weighted mean
    # 2) Computing the weighted covariance
    # 3) Computing the weighted correlation

    preds_diff = preds - _weighted_mean(preds, sample_weights)
    target_diff = target - _weighted_mean(target, sample_weights)

    cov = _weighted_mean(preds_diff * target_diff, sample_weights)

    preds_std = torch.sqrt(_weighted_mean(preds_diff * preds_diff, sample_weights))
    target_std = torch.sqrt(_weighted_mean(target_diff * target_diff, sample_weights))

    # The small constant guards against a division by zero for constant inputs
    corrcoef = cov / (preds_std * target_std + 1e-6)
    return torch.clamp(corrcoef, -1.0, 1.0)


def weighted_pearson_calibration(preds, target, uncertainty, sample_weights=None):
    """The (weighted) Pearson correlation between the absolute error and the uncertainty."""
    error = torch.abs(preds - target)
    return weighted_pearson(error, uncertainty, sample_weights)


def weighted_mae(preds, target, sample_weights=None):
    """The (weighted) mean absolute error.

    Without sample weights this defers to the TorchMetrics implementation.
    """
    if sample_weights is None:
        return mean_absolute_error(preds=preds, target=target)
    summed_mae = torch.abs(sample_weights * (preds - target)).sum()
    return summed_mae / torch.sum(sample_weights)


def weighted_brier_score(target, uncertainty, sample_weights=None):
    """The (weighted) mean squared difference between the confidence and the target.

    The confidence of a sample is the entry of its `uncertainty` row at the
    index given by its target. Without sample weights this defers to the
    TorchMetrics mean squared error.

    With sample weights, the result is the weighted mean of the squared errors:
    `sum(w * err ** 2) / sum(w)`. The weights are deliberately kept outside of
    the square. Squaring them as well would make the score depend on the scale
    of the weights (e.g. uniform weights of 2 would double it).
    """
    # NOTE(review): for target == 0 this compares the entry at index 0 against 0.
    #  Confirm that this matches the intended definition of the Brier score.
    confidence = torch.stack([unc[tar] for unc, tar in zip(uncertainty, target)])
    if sample_weights is None:
        return mean_squared_error(preds=confidence, target=target)

    squared_error = torch.square(confidence - target)
    return torch.sum(sample_weights * squared_error) / torch.sum(sample_weights)


def weighted_auroc(preds, target, sample_weights=None):
    """The (weighted) area under the ROC curve for binary targets."""
    if sample_weights is None:
        return binary_auroc(preds=preds, target=target)

    # TorchMetrics does not actually support sample weights, so we rely on the sklearn implementation
    preds = preds.cpu().numpy()
    target = target.cpu().numpy()
    sample_weights = sample_weights.cpu().numpy()
    return roc_auc_score(y_true=target, y_score=preds, sample_weight=sample_weights)
class TargetType(enum.Enum):
    """The kind of target a metric is defined for."""

    REGRESSION = "regression"
    BINARY_CLASSIFICATION = "binary_classification"

    def is_regression(self):
        return self == TargetType.REGRESSION


class Metric:
    """Wraps a metric function together with the meta-data needed to use it.

    Calling an instance converts the inputs to tensors, passes those the
    function needs as keyword arguments and returns the result as a Python scalar.
    """

    def __init__(
        self,
        name: str,
        fn: Callable,
        mode: str,
        target_type: TargetType,
        needs_predictions: bool = True,
        needs_uncertainty: bool = False,
        range_min: Optional[float] = None,
        range_max: Optional[float] = None,
    ):
        self.fn_ = fn
        self.name = name
        self.mode = mode
        self.target_type = target_type
        self.needs_predictions = needs_predictions
        self.needs_uncertainty = needs_uncertainty
        self.range_min = range_min
        self.range_max = range_max

    @property
    def is_calibration(self):
        return self.needs_uncertainty

    @classmethod
    def get_default_calibration_metric(cls, dataset):
        """Pearson for the regression datasets, the Brier score otherwise."""
        is_regression = dataset in MOOD_REGR_DATASETS
        if is_regression:
            metric = cls.by_name("Pearson")
        else:
            metric = cls.by_name("Brier score")
        return metric

    @classmethod
    def get_default_performance_metric(cls, dataset):
        """MAE for the regression datasets, AUROC otherwise."""
        is_regression = dataset in MOOD_REGR_DATASETS
        if is_regression:
            metric = cls.by_name("MAE")
        else:
            metric = cls.by_name("AUROC")
        return metric

    @classmethod
    def by_name(cls, name):
        """Construct one of the supported metrics: MAE, Pearson, AUROC or Brier score.

        Raises:
            ValueError: If the name is not one of the supported metrics.
        """
        if name == "MAE":
            return cls(
                name="MAE",
                fn=weighted_mae,
                mode="min",
                target_type=TargetType.REGRESSION,
                needs_predictions=True,
                needs_uncertainty=False,
                range_min=0,
                range_max=None,
            )
        elif name == "Pearson":
            return cls(
                name="Pearson",
                fn=weighted_pearson_calibration,
                mode="max",
                target_type=TargetType.REGRESSION,
                needs_predictions=True,
                needs_uncertainty=True,
                range_min=-1,
                range_max=1,
            )
        elif name == "AUROC":
            return cls(
                name="AUROC",
                fn=weighted_auroc,
                mode="max",
                target_type=TargetType.BINARY_CLASSIFICATION,
                needs_predictions=True,
                needs_uncertainty=False,
                range_min=0,
                range_max=1,
            )
        elif name == "Brier score":
            return cls(
                name="Brier score",
                fn=weighted_brier_score,
                mode="min",
                target_type=TargetType.BINARY_CLASSIFICATION,
                needs_predictions=False,
                needs_uncertainty=True,
                range_min=0,
                range_max=1,
            )

        # Fail loudly instead of implicitly returning None, which would only
        # surface later as an AttributeError far away from the actual mistake
        supported = ["MAE", "Pearson", "AUROC", "Brier score"]
        raise ValueError(f"Unknown metric {name!r}. Choose from {supported}.")

    def __call__(
        self, y_true, y_pred: Optional = None, uncertainty: Optional = None, sample_weights: Optional = None
    ):
        """Compute the metric and return it as a Python scalar.

        Raises:
            ValueError: If the metric needs predictions or uncertainty
                estimates and these are not provided.
        """
        if self.needs_uncertainty and uncertainty is None:
            raise ValueError("Uncertainty estimates needed, but not provided.")
        if self.needs_predictions and y_pred is None:
            raise ValueError("Predictions needed, but not provided.")
        kwargs = self.to_kwargs(y_true, y_pred, uncertainty, sample_weights)
        return self.fn_(**kwargs).item()

    @staticmethod
    def preprocess_targets(y_true, is_regression: bool):
        """Convert the targets to a tensor: squeezed floats for regression, ints otherwise."""
        if not isinstance(y_true, torch.Tensor):
            y_true = torch.tensor(y_true)
        if is_regression:
            y_true = y_true.float().squeeze()
            # Squeezing a single sample leaves a scalar; restore the sample dimension
            if y_true.ndim == 0:
                y_true = y_true.unsqueeze(0)
        else:
            y_true = y_true.int()

        return y_true

    @staticmethod
    def preprocess_predictions(y_pred, device):
        """Convert the predictions to a squeezed float tensor with at least one dimension."""
        if not isinstance(y_pred, torch.Tensor):
            y_pred = torch.tensor(y_pred, device=device)
        y_pred = y_pred.float().squeeze()
        if y_pred.ndim == 0:
            y_pred = y_pred.unsqueeze(0)
        return y_pred

    @staticmethod
    def preprocess_uncertainties(uncertainty, device):
        """Uncertainties are preprocessed the same way as predictions."""
        return Metric.preprocess_predictions(uncertainty, device)

    @staticmethod
    def preprocess_sample_weights(sample_weights, device):
        """Convert the sample weights to a tensor; None and tensors pass through unchanged."""
        if sample_weights is not None and not isinstance(sample_weights, torch.Tensor):
            sample_weights = torch.tensor(sample_weights, device=device)
        return sample_weights

    @property
    def range(self):
        return self.range_min, self.range_max

    def to_kwargs(self, y_true, y_pred, uncertainty, sample_weights):
        """Build the keyword arguments for the metric function.

        The targets and sample weights are always included. The predictions and
        uncertainties are only included if the metric needs them.
        """
        kwargs = {"target": self.preprocess_targets(y_true, self.target_type.is_regression())}
        kwargs["sample_weights"] = self.preprocess_sample_weights(sample_weights, kwargs["target"].device)
        if self.needs_predictions:
            kwargs["preds"] = self.preprocess_predictions(y_pred, kwargs["target"].device)
        if self.needs_uncertainty:
            kwargs["uncertainty"] = self.preprocess_uncertainties(uncertainty, kwargs["target"].device)
        return kwargs
def compute_bootstrapped_metric(
    targets,
    metric: Metric,
    predictions: Optional = None,
    uncertainties: Optional = None,
    sample_weights: Optional = None,
    sampling_strategy: str = "poisson",
    n_bootstraps: int = 1000,
    n_jobs: Optional[int] = None,
):
    """
    Bootstrapping to compute confidence intervals for a metric.
    Inspired by https://stackoverflow.com/a/19132400 and
    https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/wrappers/bootstrapping.py

    Args:
        targets: The ground truth. Always required, so it also defines the
            number of samples to resample from.
        metric: The metric to compute on every bootstrap sample.
        predictions: The predictions, if the metric needs them.
        uncertainties: The uncertainty estimates, if the metric needs them.
        sample_weights: Optional weights, resampled along with the rest.
        sampling_strategy: Passed on to the TorchMetrics bootstrap sampler.
        n_bootstraps: The number of bootstrap samples to draw.
        n_jobs: Controls the parallelization.

    Returns:
        The mean and standard deviation of the bootstrapped scores.
    """

    # The predictions are optional (e.g. not needed for the Brier score),
    # so the sample count has to come from the targets rather than the predictions
    n_samples = len(targets)

    def fn(it):
        indices = _bootstrap_sampler(n_samples, sampling_strategy=sampling_strategy)
        _sample_weights = None if sample_weights is None else sample_weights[indices]
        _uncertainties = None if uncertainties is None else uncertainties[indices]
        _predictions = None if predictions is None else predictions[indices]
        score = metric(targets[indices], _predictions, _uncertainties, _sample_weights)
        return score

    bootstrapped_scores = dm.utils.parallelized(fn, range(n_bootstraps), n_jobs=n_jobs)
    bootstrapped_scores = [score for score in bootstrapped_scores if score is not None]
    return np.mean(bootstrapped_scores), np.std(bootstrapped_scores)
(model_type is None or issubclass(model_type, BaseEstimator) or issubclass(model_type, BaseModel)): 34 | raise TypeError(f"Can only test models from sklearn, good-learn or mood, not {model_type}") 35 | return model_type 36 | 37 | 38 | def is_domain_adaptation(model: Union[BaseEstimator, BaseModel, str]): 39 | model_type = _get_type(model) 40 | return model_type in [Mixup, DANN, CORAL] 41 | 42 | 43 | def is_domain_generalization(model: Union[BaseEstimator, BaseModel, str]): 44 | model_type = _get_type(model) 45 | return model_type in [MTL, InformationBottleneckERM, VREx] 46 | 47 | 48 | def needs_domain_representation(model: Union[BaseEstimator, BaseModel, str]): 49 | model_type = _get_type(model) 50 | return model_type == MTL 51 | -------------------------------------------------------------------------------- /mood/model/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from functools import partial 3 | 4 | import torch 5 | from torch import nn 6 | from itertools import tee 7 | from typing import Optional, Union 8 | from pytorch_lightning import LightningModule 9 | 10 | 11 | class BaseModel(LightningModule, abc.ABC): 12 | def __init__( 13 | self, 14 | base_network: nn.Module, 15 | prediction_head: nn.Module, 16 | loss_fn: nn.Module, 17 | lr: float, 18 | weight_decay: Union[float, str], 19 | batch_size: int, 20 | ): 21 | super().__init__() 22 | self.base_network = base_network 23 | self.prediction_head = prediction_head 24 | self.loss_fn = partial(self.loss_function_wrapper, loss_fn=loss_fn) 25 | self.l2 = weight_decay 26 | self.lr = lr 27 | self.batch_size = batch_size 28 | 29 | def forward(self, x, domains: Optional = None, return_embedding: bool = False): 30 | embedding = self.base_network(x) 31 | label = self.prediction_head(embedding) 32 | out = (label, embedding) if return_embedding else label 33 | return out 34 | 35 | def training_step( 36 | self, batch, batch_idx, optimizer_idx: Optional[int] = None, 
def predict(self, dataloader):
    """Run the model over a dataloader and return the concatenated predictions.

    Args:
        dataloader: Iterable yielding `(X, y)` pairs. `X` is unpacked into the
            positional arguments of `forward`; `y` is ignored.

    Returns:
        The outputs of `forward` for all batches, concatenated along dimension 0.
    """
    # Assigning `self.training = False` only flips the flag on this module. `eval()`
    # also switches all child modules (e.g. dropout, batch norm) to inference mode,
    # which is needed for deterministic predictions.
    self.eval()
    with torch.inference_mode():
        res = torch.cat([self.forward(*X) for X, y in dataloader], dim=0)
    return res
class Ensemble(LightningModule):
    """Combines several models by averaging their predictions.

    Every member is only required to expose `predict(dataloader)` returning a tensor;
    the per-model tensors are stacked along a new leading dimension.
    """

    def __init__(self, models, is_regression):
        """
        Args:
            models: The ensemble members. Needs to support iteration and `len()`.
            is_regression: Selects how `predict_uncertainty` derives its estimate.
        """
        super().__init__()
        self.models = models
        self.is_regression = is_regression

    def predict(self, dataloader):
        """Return the mean over the models of the per-model predictions."""
        return torch.stack([model.predict(dataloader) for model in self.models]).mean(dim=0)

    def predict_uncertainty(self, dataloader):
        """Return an uncertainty estimate for the samples in the dataloader.

        For regression this is the variance of the predictions across the models,
        or `None` for a single-model ensemble. Otherwise, a two-column tensor
        `[p, 1 - p]` is returned, with `p` the first column of the mean prediction.
        """
        if self.is_regression:
            if len(self.models) == 1:
                # The variance across a single model is not informative
                uncertainty = None
            else:
                uncertainty = torch.stack([model.predict(dataloader) for model in self.models]).var(dim=0)
        else:
            # NOTE(review): presumably column 0 holds the positive-class probability,
            # which requires 2D predictions - confirm against the prediction head.
            proba = self.predict(dataloader)[:, 0]
            uncertainty = torch.stack([proba, 1.0 - proba], dim=1)
        return uncertainty
18 | https://arxiv.org/abs/1607.01719 19 | """ 20 | 21 | def __init__( 22 | self, 23 | base_network: nn.Module, 24 | prediction_head: nn.Module, 25 | loss_fn: nn.Module, 26 | batch_size: int, 27 | penalty_weight: float, 28 | penalty_weight_schedule: List[int], 29 | lr: float = 1e-4, 30 | weight_decay: float = 0, 31 | ): 32 | """ 33 | Args: 34 | base_network: The neural network architecture endoing the features 35 | prediction_head: The neural network architecture that takes the concatenated 36 | representation of the domain and features and returns a task-specific prediction. 37 | loss_fn: The loss function to optimize for. 38 | penalty_weight: The weight to multiply the penalty by. 39 | penalty_weight_schedule: List of two integers as a very rudimental way of scheduling the 40 | penalty weight. The first integer is the last epoch at which the penalty weight is 0. 41 | The second integer is the first epoch at which the penalty weight is its max value. 42 | Linearly interpolates between the two. 
@staticmethod
def _coral_penalty(x, y):
    """The CORAL penalty aligns the mean and the covariance matrix of the features across domains.

    Args:
        x: 2D feature tensor of the first domain, one row per sample.
        y: 2D feature tensor of the second domain, one row per sample.

    Returns:
        The mean squared difference of the feature means plus the mean squared
        difference of the feature covariance matrices.
    """

    def _moments(features):
        # First moment (kept 2D so it broadcasts) and second central moment
        center = features.mean(0, keepdim=True)
        residual = features - center
        # The max() guards against a division by zero for single-sample batches
        covariance = (residual.T @ residual) / max(1, len(features) - 1)
        return center, covariance

    center_x, covariance_x = _moments(x)
    center_y, covariance_y = _moments(y)

    first_order_gap = (center_x - center_y).pow(2).mean()
    second_order_gap = (covariance_x - covariance_y).pow(2).mean()
    return first_order_gap + second_order_gap
24 | https://arxiv.org/abs/1505.07818 25 | """ 26 | 27 | def __init__( 28 | self, 29 | base_network: nn.Module, 30 | prediction_head: FCLayer, 31 | loss_fn: nn.Module, 32 | batch_size: int, 33 | penalty_weight: float, 34 | penalty_weight_schedule: List[int], 35 | discriminator_network: Optional[nn.Module] = None, 36 | lr: float = 1e-3, 37 | discr_lr: float = 1e-3, 38 | weight_decay: float = "auto", 39 | discr_weight_decay: float = "auto", 40 | lambda_reg: float = 0.1, 41 | n_discr_steps_per_predictor_step=3, 42 | ): 43 | """ 44 | Args: 45 | base_network: The neural network architecture endoing the features 46 | prediction_head: The neural network architecture that takes the concatenated 47 | representation of the domain and features and returns a task-specific prediction. 48 | loss_fn: The loss function to optimize for. 49 | penalty_weight: The weight to multiply the penalty by. 50 | penalty_weight_schedule: List of two integers as a very rudimental way of scheduling the 51 | penalty weight. The first integer is the last epoch at which the penalty weight is 0. 52 | The second integer is the first epoch at which the penalty weight is its max value. 53 | Linearly interpolates between the two. 54 | discriminator_network: The discriminator network that predicts the domain from the hidden embedding. 55 | lambda_reg: An additional weighing factor for the penalty. Following the implementation of DomainBed. 
@staticmethod
def get_optimizer(parameters, lr, weight_decay, monitor: str = "val_loss"):
    """Build an Adam optimizer with a plateau-based learning rate scheduler.

    Args:
        parameters: Iterable of the parameters to optimize.
        lr: The learning rate.
        weight_decay: The weight decay, or "auto" to use the reciprocal of the number
            of trainable parameter elements.
        monitor: Name of the logged value that the learning rate scheduler monitors.

    Returns:
        A dictionary with the optimizer and the learning rate scheduler configuration.
    """
    if weight_decay == "auto":
        # The parameters may be a one-shot iterator. Materialize it so that it can
        # be both counted here and passed on to the optimizer below.
        parameters = list(parameters)
        n_trainable = sum(param.numel() for param in parameters if param.requires_grad)
        weight_decay = 1.0 / n_trainable

    optimizer = torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
    scheduler_config = {
        "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        "interval": "epoch",
        "frequency": 1,
        "monitor": monitor,
        "strict": True,
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
def _step(self, batch, batch_idx=0, optimizer_idx=None):
    """Compute the loss for a single batch.

    Outside of training, the batch is `((x, domains), y_true)` and only the
    predictive loss is computed and logged as "loss".

    During training, the batch is a dictionary with a labeled "source" batch
    `((x, domains), y_true)` and an unlabeled "target" batch `(x, domains)`.
    With `optimizer_idx == 1`, the discriminator loss (domain classification plus
    the weighted gradient penalty) is returned. Otherwise the predictive loss is
    returned, from which the discriminator loss scaled by `lambda_reg` is subtracted.
    """
    if not self.training:
        (x, domains), y_true = batch
        y_pred = self.forward(x)
        loss = self.loss_fn(y_pred, y_true)
        self.log("loss", loss)
        return loss

    batch_src, batch_tgt = batch["source"], batch["target"]

    (x_src, domains_src), y_true = batch_src
    x_tgt, domains_tgt = batch_tgt

    y_pred, input_embeddings_src, discr_out_src = self.forward(
        x_src, return_embedding=True, return_discriminator=True
    )
    _, input_embeddings_tgt, discr_out_tgt = self.forward(
        x_tgt, return_embedding=True, return_discriminator=True
    )

    erm_loss = self.loss_fn(y_pred, y_true)

    # For losses w.r.t. the discriminator
    domain_pred = torch.cat([discr_out_src, discr_out_tgt], dim=0)

    # The domain labels need to be class indices (0 = source, 1 = target).
    # Float targets with the same shape as the logits are interpreted by the
    # cross-entropy loss as class probabilities, for which all-zero (source) and
    # all-one (target) rows are not valid distributions.
    domain_true = torch.cat(
        [
            torch.zeros(len(discr_out_src), dtype=torch.long, device=domain_pred.device),
            torch.ones(len(discr_out_tgt), dtype=torch.long, device=domain_pred.device),
        ],
        dim=0,
    )
    loss = self._discriminator_loss(domain_pred, domain_true)

    domain_pred_softmax = F.softmax(domain_pred, dim=1)
    penalty_loss = domain_pred_softmax[:, domain_true].sum()

    # When optimizing the discriminator, PTL automatically set requires_grad to False
    # for all parameters not optimized by the optimizer. However, we still need True to
    # be able to compute the gradient penalty
    if optimizer_idx == 1:
        for param in self.base_network.parameters():
            param.requires_grad = True

    penalty = self.gradient_reversal(penalty_loss, [input_embeddings_src, input_embeddings_tgt])

    if optimizer_idx == 1:
        for param in self.base_network.parameters():
            param.requires_grad = False

    penalty_weight = linear_interpolation(
        self.current_epoch, self.duration, self.penalty_weight, start=self.start
    )
    loss += penalty * penalty_weight

    if optimizer_idx == 1:
        # Train the discriminator
        self.log("discriminator_loss", loss)
        return loss
    else:
        # Train the predictor and add a penalty for making features domain-invariant
        loss = erm_loss + (self.lambda_reg * -loss)
        self.log("predictive_loss", loss)
        return loss
205 | ) 206 | params["discr_lr"] = trial.suggest_float("discr_lr", 1e-8, 1.0, log=True) 207 | params["discr_weight_decay"] = trial.suggest_categorical( 208 | "discr_weight_decay", ["auto", 0.0, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1.0] 209 | ) 210 | params["lambda_reg"] = trial.suggest_float("lambda_reg", 0.001, 10, log=True) 211 | params["n_discr_steps_per_predictor_step"] = trial.suggest_int( 212 | "n_discr_steps_per_predictor_step", 1, 5 213 | ) 214 | return params 215 | -------------------------------------------------------------------------------- /mood/model/erm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from mood.model.base import BaseModel 3 | 4 | 5 | class ERM(BaseModel): 6 | """Empirical Risk Minimization 7 | 8 | The "vanilla" neural network. Updates the weight to minimize the loss of the batch. 9 | 10 | References: 11 | Vapnik, V. N. (1998). Statistical Learning Theory. Wiley-Interscience. 12 | https://www.wiley.com/en-fr/Statistical+Learning+Theory-p-9780471030034 13 | """ 14 | 15 | def _step(self, batch, batch_idx, optimizer_idx: Optional[int] = None): 16 | (x, domain), y_true = batch 17 | y_pred = self.forward(x) 18 | loss = self._loss(y_pred, y_true) 19 | self.log("loss", loss, prog_bar=True) 20 | return loss 21 | 22 | def _loss(self, y_pred, y_true): 23 | return self.loss_fn(y_pred, y_true) 24 | -------------------------------------------------------------------------------- /mood/model/ib_erm.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from mood.model.base import BaseModel 7 | from mood.model.utils import linear_interpolation 8 | 9 | 10 | class InformationBottleneckERM(BaseModel): 11 | """Information Bottleneck ERM (IB-ERM) is a **domain generalization** method. 
12 | 13 | Similar to MTL, computes the loss per domain and computes the mean of these domain-specific losses. 14 | Additionally, adds a penalty based on the variance of the feature dimensions. 15 | 16 | References: 17 | Ahuja, K., Caballero, E., Zhang, D., Gagnon-Audet, J. C., Bengio, Y., Mitliagkas, I., & Rish, I. (2021). 18 | Invariance principle meets information bottleneck for out-of-distribution generalization. 19 | Advances in Neural Information Processing Systems, 34, 3438-3450. 20 | https://arxiv.org/abs/2106.06607 21 | """ 22 | 23 | def __init__( 24 | self, 25 | base_network: nn.Module, 26 | prediction_head: nn.Module, 27 | loss_fn: nn.Module, 28 | batch_size: int, 29 | penalty_weight: float, 30 | penalty_weight_schedule: List[int], 31 | lr=1e-3, 32 | weight_decay=0, 33 | ): 34 | """ 35 | Args: 36 | base_network: The neural network architecture endoing the features 37 | prediction_head: The neural network architecture that takes the concatenated 38 | representation of the domain and features and returns a task-specific prediction. 39 | loss_fn: The loss function to optimize for. 40 | penalty_weight: The weight to multiply the penalty by. 41 | penalty_weight_schedule: List of two integers as a very rudimental way of scheduling the 42 | penalty weight. The first integer is the last epoch at which the penalty weight is 0. 43 | The second integer is the first epoch at which the penalty weight is its max value. 44 | Linearly interpolates between the two. 
def _step(self, batch, batch_idx=0, optimizer_idx=None):
    """Compute the loss for a batch that consists of one mini-batch per domain.

    Args:
        batch: Iterable of mini-batches, each of the form `((xs, domains), y_true)`.

    Returns:
        The result of `_loss` for the mean of the per-domain losses and the
        per-domain embeddings. The value is also logged as "loss".
    """
    phis = []
    erm_loss = 0

    for (xs, _), y_true in batch:
        y_pred, phi = self.forward(xs, return_embedding=True)
        erm_loss += self.loss_fn(y_pred, y_true)
        phis.append(phi)

    erm_loss /= len(batch)

    # Keep the embeddings as a list with one tensor per domain. `_loss` iterates over
    # `phis` and averages the penalty over its entries. Concatenating the embeddings
    # first would make it iterate over the individual samples instead, which turns
    # the penalty into the variance across the feature dimensions of a single sample.
    loss = self._loss(erm_loss, phis)
    self.log("loss", loss)
    return loss
-------------------------------------------------------------------------------- /mood/model/mixup.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from mood.model.base import BaseModel 7 | from mood.model.utils import linear_interpolation 8 | 9 | 10 | class Mixup(BaseModel): 11 | """Mixup is a **domain adaptation** method. 12 | 13 | Mixup interpolates both the features and (pseudo-)targets inter- and intra-domain and trains on 14 | these interpolates samples instead. 15 | 16 | References: 17 | Yan, S., Song, H., Li, N., Zou, L., & Ren, L. (2020). Improve unsupervised domain adaptation with mixup training. 18 | https://arxiv.org/abs/2001.00677 19 | """ 20 | 21 | def __init__( 22 | self, 23 | base_network: nn.Module, 24 | prediction_head: nn.Module, 25 | loss_fn: nn.Module, 26 | batch_size: int, 27 | penalty_weight: float, 28 | penalty_weight_schedule: List[int], 29 | lr: float = 1e-3, 30 | weight_decay: float = "auto", 31 | augmentation_std: float = 0.1, 32 | no_augmentations: int = 10, 33 | alpha: float = 0.1, 34 | ): 35 | """ 36 | Args: 37 | base_network: The neural network architecture endoing the features 38 | prediction_head: The neural network architecture that takes the concatenated 39 | representation of the domain and features and returns a task-specific prediction. 40 | loss_fn: The loss function to optimize for. 41 | penalty_weight: The weight to multiply the penalty by. 42 | penalty_weight_schedule: List of two integers as a very rudimental way of scheduling the 43 | penalty weight. The first integer is the last epoch at which the penalty weight is 0. 44 | The second integer is the first epoch at which the penalty weight is its max value. 45 | Linearly interpolates between the two. 46 | augmentation_std: The standard deviation of the noise to multiply each sample by 47 | as augmentation. 
48 | no_augmentations: The number of augmentations to do to compute pseudo labels for the unlabeled samples 49 | of the target domain. 50 | alpha: The parameter of the Beta distribution used to compute the interpolation factor 51 | """ 52 | super().__init__( 53 | lr=lr, 54 | weight_decay=weight_decay, 55 | base_network=base_network, 56 | prediction_head=prediction_head, 57 | loss_fn=loss_fn, 58 | batch_size=batch_size, 59 | ) 60 | 61 | self._classification_loss = torch.nn.BCELoss() 62 | self._target_domain_loss = torch.nn.MSELoss() 63 | 64 | self._augmentation_std = augmentation_std 65 | self._no_augmentations = no_augmentations 66 | 67 | self.distribution_lambda = torch.distributions.Beta(alpha, alpha) 68 | 69 | self.penalty_weight = penalty_weight 70 | if len(penalty_weight_schedule) != 2: 71 | raise ValueError("The penalty weight schedule needs to define two values; The start and end step") 72 | self.start = penalty_weight_schedule[0] 73 | self.duration = penalty_weight_schedule[1] - penalty_weight_schedule[0] 74 | 75 | def _step(self, batch, batch_idx=0, optimizer_idx=None): 76 | if not self.training: 77 | (x, domains), y_true = batch 78 | y_pred = self.forward(x) 79 | loss = self.loss_fn(y_pred, y_true) 80 | self.log("loss", loss) 81 | return loss 82 | 83 | batch_src, batch_tgt = batch["source"], batch["target"] 84 | (x_src, _), y_src = batch_src 85 | x_tgt, _ = batch_tgt 86 | y_tgt = self._get_pseudo_labels(x_tgt) 87 | 88 | penalty_weight = linear_interpolation( 89 | step=self.current_epoch, 90 | duration=self.duration, 91 | max_value=self.penalty_weight, 92 | start=self.start, 93 | ) 94 | 95 | loss = 0 96 | 97 | # Inter-domain, from source to target 98 | loss += self._loss(x_src, x_tgt, y_src, y_tgt, True) 99 | # Intra-domain, from source to source 100 | loss += self._loss(x_src, x_src, y_src, y_src, False) 101 | # Intra-domain, from target to target 102 | # Quote from the paper: "We set a linear schedule for w_t in training, 103 | # from 0 to a predefined 
def _get_random_pairs(self, x_src, x_tgt, y_src, y_tgt, size):
    """Randomly draw `size` samples (with replacement) from both domains.

    The samples of the two domains are drawn independently of each other.
    Within a domain, the features and targets stay aligned.

    Args:
        x_src: The features of the first domain.
        x_tgt: The features of the second domain.
        y_src: The targets of the first domain. Needs to be as long as `x_src`.
        y_tgt: The targets of the second domain. Needs to be as long as `x_tgt`.
        size: The number of pairs to draw. At least one pair is always drawn.

    Returns:
        The tuple `(x_src, x_tgt, y_src, y_tgt)` with the sampled rows.
    """
    assert len(x_src) == len(y_src)
    assert len(x_tgt) == len(y_tgt)

    # Always pass a plain integer of at least 1 as the number of samples, rather
    # than a mix of 0-dim tensors and whatever numeric type the caller provided.
    size = max(1, int(size))

    src_indices = torch.multinomial(torch.ones(len(x_src), device=x_src.device), size, replacement=True)
    x_src, y_src = x_src[src_indices], y_src[src_indices]

    tgt_indices = torch.multinomial(torch.ones(len(x_tgt), device=x_tgt.device), size, replacement=True)
    x_tgt, y_tgt = x_tgt[tgt_indices], y_tgt[tgt_indices]

    return x_src, x_tgt, y_src, y_tgt
def forward(self, x, domains: Optional = None, return_embedding: bool = False):
    """Predict from the concatenation of the domain and the input embedding.

    Args:
        x: The input features, passed to the base network.
        domains: The domain representation, passed to the domain network.
        return_embedding: Whether to also return the concatenated embedding.

    Returns:
        The output of the prediction head, or the tuple `(label, embeddings)`
        if `return_embedding` is set.
    """
    input_embeddings = self.base_network(x)
    domain_embeddings = self.domain_network(domains)

    # Add noise to domains to avoid overfitting. As this is a regularizer, it is
    # only applied during training such that inference stays deterministic.
    if self.training:
        domain_embeddings = domain_embeddings + torch.randn_like(domain_embeddings)

    embeddings = torch.cat((domain_embeddings, input_embeddings), 1)
    label = self.prediction_head(embeddings)
    out = (label, embeddings) if return_embedding else label
    return out
class FCLayer(nn.Module):
    r"""
    A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module.
    The order in which transformations are applied is:

    #. Dense Layer
    #. Activation (if applicable)
    #. Dropout (if applicable)
    #. Batch Normalization (if applicable)

    Arguments
    ----------
        in_size: int
            Input dimension of the layer (the torch.nn.Linear)
        out_size: int
            Output dimension of the layer.
        activation: str or callable or None, optional
            Activation function to use. Should be one supported by :func:`get_activation`.
            (Default value = relu)
        dropout: float, optional
            The ratio of units to dropout. No dropout by default.
            (Default value = 0.)
        b_norm: bool, optional
            Whether to use batch normalization
            (Default value = False)
        bias: bool, optional
            Whether to enable bias in for the linear layer.
            (Default value = True)
        init_fn: callable, optional
            Initialization function, called with the linear layer as its only argument.
            If None, the default torch.nn.Linear initialization is kept.
            (Default value = None)

    Attributes
    ----------
        net: torch.nn.Sequential
            The linear layer followed by the optional activation, dropout and batch norm
        in_size: int
            Input dimension of the linear layer
        out_size: int
            Output dimension of the linear layer
    """

    def __init__(
        self, in_size, out_size, activation="relu", dropout=0.0, b_norm=False, bias=True, init_fn=None
    ):
        super(FCLayer, self).__init__()
        # Although I disagree with this it is simple enough and robust
        # if we trust the user base
        self._params = locals()
        self.in_size = in_size
        self.out_size = out_size

        linear = nn.Linear(in_size, out_size, bias=bias)
        if init_fn:
            init_fn(linear)

        layers = [linear]
        activation = get_activation(activation)
        if activation is not None:
            layers.append(activation)
        if dropout:
            layers.append(nn.Dropout(p=dropout))
        if b_norm:
            layers.append(nn.BatchNorm1d(out_size))
        self.net = nn.Sequential(*layers)

    @property
    def output_dim(self):
        return self.out_size

    @property
    def input_dim(self):
        return self.in_size

    def forward(self, x):
        return self.net(x)


class MLP(nn.Module):
    r"""
    Feature extractor using a Fully Connected Neural Network

    Arguments
    ----------
        input_size: int
            size of the input
        hidden_sizes: int list or int or None
            size of the hidden layers. A single int is treated as one hidden layer.
        out_size: int or None
            if None, uses the last hidden size as the output
        activation: str or callable
            activation function of the hidden layers. Should be supported by :func:`get_activation`
            (Default value = 'ReLU')
        out_activation: str or callable or None
            activation function of the output layer.
            (Default value = None)
        b_norm: bool, optional
            Whether batch norm is used on the output layer. Hidden layers never use it.
            (Default value = False)
        l_norm: bool, optional
            Whether to apply layer normalization to the input.
            (Default value = False)
        dropout: float, optional
            Dropout probability of the hidden layers. No dropout by default.
            The output layer never uses dropout.
            (Default value = .0)

    Attributes
    ----------
        extractor: torch.nn.Module
            The underlying feature extractor of the model.
    """

    def __init__(
        self,
        input_size,
        hidden_sizes=None,
        out_size=None,
        activation="ReLU",
        out_activation=None,
        b_norm=False,
        l_norm=False,
        dropout=0.0,
    ):
        super(MLP, self).__init__()
        self._params = locals()

        if out_size is None and hidden_sizes is None:
            raise ValueError("You need to specify either hidden_sizes or output_size")

        # The documentation allows a single int; normalize it to a one-element list.
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]

        if out_size is None:
            # The last hidden layer doubles as the output layer
            out_size = hidden_sizes[-1]
            hidden_sizes = hidden_sizes[:-1]

        if hidden_sizes is None:
            hidden_sizes = []

        layers = []
        if l_norm:
            layers.append(nn.LayerNorm(input_size))

        in_ = input_size
        for out_ in hidden_sizes:
            layers.append(FCLayer(in_, out_, activation=activation, b_norm=False, dropout=dropout))
            in_ = out_

        layers.append(FCLayer(in_, out_size, activation=out_activation, b_norm=b_norm, dropout=False))
        self.extractor = nn.Sequential(*layers)

    @property
    def output_dim(self):
        return self.extractor[-1].output_dim

    @property
    def input_dim(self):
        # With l_norm=True the first module is a LayerNorm, which has no
        # `input_dim`; the first FCLayer always consumes the input size.
        first_fc = next(layer for layer in self.extractor if isinstance(layer, FCLayer))
        return first_fc.input_dim

    def forward(self, x):
        x = Flatten()(x)
        return self.extractor(x)


def get_simple_mlp(
    input_size: int,
    width: int = 0,
    depth: int = 0,
    out_size: Optional[int] = 1,
    is_regression: Optional[bool] = None,
):
    """Build an :class:`MLP` with ``depth`` hidden layers of ``width`` units each.

    Args:
        input_size: Size of the input.
        width: Number of units of every hidden layer.
        depth: Number of hidden layers.
        out_size: Size of the output. If None, the last hidden layer is the
            output and a ReLU is applied to it.
        is_regression: Required when ``out_size`` is given. A regression head
            has no output activation, any other head uses a Sigmoid.

    Raises:
        TypeError: If ``out_size`` is given and ``is_regression`` is not a bool.
    """
    if out_size is not None and not isinstance(is_regression, bool):
        raise TypeError("Specify is_regression to be True or False")

    if out_size is None:
        out_activation = "ReLU"
    elif is_regression:
        out_activation = None
    else:
        out_activation = "Sigmoid"

    return MLP(
        input_size=input_size,
        hidden_sizes=[width] * depth,
        out_size=out_size,
        activation="ReLU",
        out_activation=out_activation,
    )


def linear_interpolation(step, duration, max_value, start: int = 0, min_value: float = 0):
    """Linearly ramp a value from ``min_value`` to ``max_value``.

    Args:
        step: The current step.
        duration: Number of steps the ramp takes.
        max_value: Value returned from step ``start + duration`` onwards.
        start: Last step for which ``min_value`` is returned before ramping.
        min_value: Value returned before ``start``.

    Returns:
        ``min_value`` before ``start``, ``max_value`` from ``start + duration``
        onwards and the linear interpolation between the two in between.

    Raises:
        ValueError: If ``max_value`` is smaller than ``min_value``.
    """
    if max_value < min_value:
        raise ValueError("max_value cannot be smaller than min value")
    if step < start:
        return min_value
    if step >= start + duration:
        return max_value
    step_size = (max_value - min_value) / duration
    # The ramp starts at min_value, not at zero.
    return min_value + step_size * (step - start)


def get_activation(activation_spec):
    """Resolve an activation specification to an activation module.

    Args:
        activation_spec: A callable (returned as is), None (returned as is) or the
            case-insensitive name of a class in ``torch.nn.modules.activation``.

    Returns:
        The callable, None, or a new instance of the named activation class.

    Raises:
        ValueError: If the name does not match any activation class.
    """
    if callable(activation_spec):
        return activation_spec
    if activation_spec is None:
        return None

    wanted = activation_spec.lower()
    for name, candidate in vars(torch.nn.modules.activation).items():
        # Only consider module classes; the namespace also holds imported
        # helpers (e.g. `Tensor`, `warnings`) and the `Module` base class.
        is_module_class = (
            isinstance(candidate, type)
            and issubclass(candidate, torch.nn.Module)
            and candidate is not torch.nn.Module
        )
        if is_module_class and name.lower() == wanted:
            return candidate()

    raise ValueError(f"{activation_spec} is not a valid activation function")
class VREx(BaseModel):
    """Variance Risk Extrapolation (VREx) is a **domain generalization** method.

    Similar to MTL, computes the loss per domain and computes the mean of these domain-specific losses.
    Additionally, following equation 8 of the paper, returns an additional penalty based on the variance
    of the domain specific losses.

    References:
        Krueger, D., Caballero, E., Jacobsen, J. H., Zhang, A., Binas, J., Zhang, D., ... & Courville, A. (2021, July).
        Out-of-distribution generalization via risk extrapolation (rex).
        In International Conference on Machine Learning (pp. 5815-5826). PMLR.
        https://arxiv.org/abs/2003.00688
    """

    def __init__(
        self,
        base_network: nn.Module,
        prediction_head: nn.Module,
        loss_fn: nn.Module,
        batch_size: int,
        penalty_weight: float,
        penalty_weight_schedule: List[int],
        lr=1e-3,
        weight_decay=0,
    ):
        """
        Args:
            base_network: The neural network architecture encoding the features
            prediction_head: The neural network architecture that returns a task-specific prediction.
            loss_fn: The loss function to optimize for.
            batch_size: Passed through unchanged to the parent constructor.
            penalty_weight: The weight to multiply the penalty by.
            penalty_weight_schedule: List of two integers as a very rudimentary way of scheduling the
                penalty weight. The first integer is the last epoch at which the penalty weight is 0.
                The second integer is the first epoch at which the penalty weight is its max value.
                Linearly interpolates between the two.
            lr: Passed through unchanged to the parent constructor.
            weight_decay: Passed through unchanged to the parent constructor.

        Raises:
            ValueError: If the schedule does not hold exactly two values.
        """
        super().__init__(
            lr=lr,
            weight_decay=weight_decay,
            base_network=base_network,
            prediction_head=prediction_head,
            loss_fn=loss_fn,
            batch_size=batch_size,
        )

        self.penalty_weight = penalty_weight
        if len(penalty_weight_schedule) != 2:
            raise ValueError("The penalty weight schedule needs to define two values; The start and end step")
        first_epoch, last_epoch = penalty_weight_schedule[0], penalty_weight_schedule[1]
        self.start = first_epoch
        self.duration = last_epoch - first_epoch

    def _step(self, batch, batch_idx=0, optimizer_idx=None):
        """Compute and log the mean per-domain loss plus the scheduled variance penalty."""
        # One empirical-risk term per domain (i.e. per mini-batch)
        per_domain = []
        for (xs, _), y_true in batch:
            y_pred = self.forward(xs, return_embedding=False)
            per_domain.append(self.loss_fn(y_pred, y_true))

        # The weight depends on the current epoch, not on the batch index
        weight = linear_interpolation(
            step=self.current_epoch,
            duration=self.duration,
            max_value=self.penalty_weight,
            start=self.start,
        )

        erm_losses = torch.stack(per_domain)
        loss = erm_losses.mean()

        # NOTE: A batch can have just a single domain, in which case
        # there is no variance across domains to penalize.
        if len(batch) != 1:
            loss = loss + weight * erm_losses.var()

        self.log("loss", loss)
        return loss

    @staticmethod
    def suggest_params(trial):
        """Extend the ``BaseModel`` suggestions with the VREx-specific hyper-parameters."""
        schedules = [[0, 25], [0, 50], [0, 0], [25, 50]]
        params = BaseModel.suggest_params(trial)
        params["penalty_weight"] = trial.suggest_float("penalty_weight", 0.0001, 100, log=True)
        params["penalty_weight_schedule"] = trial.suggest_categorical("penalty_weight_schedule", schedules)
        return params
_SKLEARN_MLP_TYPE = Union[MLPRegressor, MLPClassifier]
_SKLEARN_RF_TYPE = Union[RandomForestRegressor, RandomForestClassifier]
_SKLEARN_GP_TYPE = Union[GaussianProcessRegressor, GaussianProcessClassifier]


def is_linear_kernel(kernel):
    """Return True if ``kernel`` is a ``PairwiseKernel`` using the ``"linear"`` metric."""
    if not isinstance(kernel, PairwiseKernel):
        return False
    return kernel.metric == "linear"


class ModelSpaceTransformer:
    """Maps features to the "model space" of a fitted scikit-learn model.

    Random forests and Gaussian processes select the most important input
    features; multi-layer perceptrons select units of the last hidden layer.
    """

    SUPPORTED_TYPES = Union[
        _SKLEARN_MLP_TYPE,
        _SKLEARN_RF_TYPE,
        _SKLEARN_GP_TYPE,
    ]

    def __init__(self, model, embedding_size: int):
        """
        Args:
            model: A fitted model of one of the supported types.
            embedding_size: Number of features to keep in the embedding.

        Raises:
            TypeError: If the model type is not supported.
        """
        if not isinstance(model, self.SUPPORTED_TYPES):
            raise TypeError(f"{type(model)} is not supported")
        self._model = model
        self._embedding_size = embedding_size

    def __call__(self, X):
        """Alias of :meth:`transform`."""
        return self.transform(X)

    def transform(self, X):
        """Embeds ``X`` (indexed as ``X[:, columns]``) with the model-specific method."""
        model = self._model
        if isinstance(model, _SKLEARN_RF_TYPE):
            return self.get_rf_embedding(model, X)
        if isinstance(model, _SKLEARN_GP_TYPE):
            return self.get_gp_embedding(model, X)
        if isinstance(model, _SKLEARN_MLP_TYPE):
            return self.get_mlp_embedding(model, X)
        # This should never be reached given the
        # type check in the constructor
        raise NotImplementedError

    def _top_indices(self, importances):
        """Indices of the ``embedding_size`` largest importances, in ascending order."""
        return np.argsort(importances)[-self._embedding_size :]

    def get_rf_embedding(self, model, X):
        """
        For a random forest, the model space embedding is given by
        the subset of input features that have the highest importance
        """
        return X[:, self._top_indices(model.feature_importances_)]

    def get_gp_embedding(self, model, X):
        """
        In a Gaussian Process, the model space embedding is given
        by the subset of input features that have the highest importance.
        This importance is computed based on alpha and the train set.
        For now, this only supports a linear kernel.
        """
        # Check the target type
        is_regression = isinstance(model, GaussianProcessRegressor)
        if not is_regression and model.n_classes_ != 2:
            msg = "We only support regression and binary classification"
            raise ValueError(msg)

        # Check the kernel type: either linear itself,
        # or a sum of which one of the two terms is linear
        kernel = model.kernel_
        is_linear = is_linear_kernel(kernel)
        if not is_linear and isinstance(kernel, Sum):
            is_linear = is_linear_kernel(kernel.k1) or is_linear_kernel(kernel.k2)
        if not is_linear:
            msg = f"We only support the linear kernel, not {model.kernel_}"
            raise NotImplementedError(msg)

        if is_regression:
            alpha, X_train = model.alpha_, model.X_train_
        else:
            est = model.base_estimator_
            alpha = est.y_train_ - est.pi_
            X_train = est.X_train_

        importances = np.abs((alpha[:, None] * X_train).sum(axis=0))
        return X[:, self._top_indices(importances)]

    def get_mlp_embedding(self, model, X):
        """
        For a multi-layer perceptron, the model space embedding is given by a subset
        of the activations of the second-to-last layer. The units are ranked by the
        (signed) weight that connects them to the first output.
        """
        # Get the MLP architecture
        hidden_layer_sizes = model.hidden_layer_sizes
        if not hasattr(hidden_layer_sizes, "__iter__"):
            hidden_layer_sizes = [hidden_layer_sizes]
        layer_units = [X.shape[1]] + list(hidden_layer_sizes) + [model.n_outputs_]

        # Create empty arrays to save all activations in
        n_samples = X.shape[0]
        activations = [X] + [
            np.empty((n_samples, layer_units[i + 1])) for i in range(model.n_layers_ - 1)
        ]

        # Actually populate the empty arrays
        model._forward_pass(activations)

        # Select from the activations of the second-to-last layer
        hidden_rep = activations[-2]
        importances = model.coefs_[-1][:, 0]
        return hidden_rep[:, self._top_indices(importances)]


def standardize_smiles(smi, for_text_based_model: bool = False, disable_logs: bool = False):
    """A good default standardization function for fingerprints and GNNs

    Args:
        smi: The SMILES string to standardize.
        for_text_based_model: If True, only standardize the molecule. Otherwise, also
            disconnect metals and strip salts (without ever removing everything).
        disable_logs: Forwarded to ``dm.without_rdkit_log``.

    Returns:
        The SMILES of the standardized molecule.
    """
    # NOTE(review): the extent of the `with` block is reconstructed; here
    # it wraps the whole body.
    with dm.without_rdkit_log(enable=disable_logs):
        mol = dm.sanitize_mol(dm.to_mol(smi, ordered=True, sanitize=False))

        if for_text_based_model:
            mol = dm.standardize_mol(mol)
        else:
            mol = dm.standardize_mol(mol, disconnect_metals=True)
            mol = SaltRemover.SaltRemover().StripMol(mol, dontRemoveEverything=True)

        return dm.to_smiles(mol)


# Maps every representation to its default preprocessing function
DEFAULT_PREPROCESSING = {
    name: partial(standardize_smiles, for_text_based_model=is_textual)
    for name, is_textual in (
        ("MACCS", False),
        ("ECFP6", False),
        ("Desc2D", False),
        ("WHIM", False),
        ("ChemBERTa", True),
        ("Graphormer", False),
    )
}
# Base seed of the shuffle; the Adler-32 checksum of the dataset name is added to it.
RCT_SEED = 1234
# Number of seeds (0 .. NUM_SEEDS - 1) included per configuration.
NUM_SEEDS = 10


def get_experimental_configurations(dataset):
    """
    To randomly sample different configurations of the RCT experiment, we use a deterministic approach.
    This facilitates reproducibility, but also makes it easy to run the experiment on preemptible instances.
    Otherwise, it could happen that models that take longer to train have a higher chance of failing,
    biasing the experiment.

    Args:
        dataset: Name of the dataset, as a string. It selects the default metrics
            and is hashed into the seed of the shuffle.

    Returns:
        A list with every (algorithm, representation, splitter, criterion, seed)
        combination, shuffled with a dataset-dependent but deterministic seed.
    """

    # NOTE: We should not rely on the order of a dict for creating these configurations,
    #  as a dict is not ordered. We unfortunately only realized this halfway through generating the results.
    #  Luckily, it seems like for our use case this does result in consistent results.
    #  As updating the code would change the ordering of the RCT, we kept it like this for now.

    prf_metric = Metric.get_default_performance_metric(dataset)
    cal_metric = Metric.get_default_calibration_metric(dataset)

    mood_criteria = get_mood_criteria(prf_metric, cal_metric).keys()

    # Work on a copy; "MLP" is taken out of the baselines before the
    # DA / DG algorithms are appended.
    mood_baselines = MOOD_BASELINES.copy()
    mood_baselines.pop(mood_baselines.index("MLP"))
    mood_algorithms = mood_baselines + list(MOOD_DA_DG_ALGORITHMS.keys())

    # The order of the arguments determines the order of the product
    # and therefore the outcome of the shuffle below.
    all_options = list(
        itertools.product(
            mood_algorithms,
            MOOD_REPRESENTATIONS,
            MOOD_SPLITTERS,
            mood_criteria,
            list(range(NUM_SEEDS)),
        )
    )
    # NOTE: We add the hash of the dataset to make the sampled configurations dataset-dependent
    rng = np.random.default_rng(RCT_SEED + zlib.adler32(dataset.encode("utf-8")))
    rng.shuffle(all_options)
    return all_options
_CHEMBERTA_HF_ID = "seyonec/PubChem10M_SMILES_BPE_450k"


def representation_iterator(
    smiles,
    standardize_fn: Union[Callable, Dict[str, Callable]],
    n_jobs: Optional[int] = None,
    progress: bool = True,
    mask_nan: bool = True,
    return_mask: bool = True,
    disable_logs: bool = True,
    whitelist: Optional[List[str]] = None,
    blacklist: Optional[List[str]] = None,
    batch_size: int = 16,
):
    """Lazily featurize the SMILES with every selected representation.

    Args:
        smiles: The SMILES to featurize.
        standardize_fn: Either one preprocessing function for all representations
            or a dict that maps each representation name to its own function.
        n_jobs, progress, mask_nan, return_mask, disable_logs, batch_size:
            Forwarded to :func:`featurize`.
        whitelist: If given, only these representations are computed.
        blacklist: If given, these representations are skipped.

    Yields:
        Tuples of the representation name and the result of :func:`featurize`.

    Raises:
        ValueError: If both a whitelist and a blacklist are given. As this is a
            generator, the error surfaces when the iteration starts.
    """
    if whitelist is not None and blacklist is not None:
        msg = "You cannot use a blacklist and whitelist at the same time"
        raise ValueError(msg)

    selected = MOOD_REPRESENTATIONS
    if whitelist is not None:
        selected = [d for d in selected if d in whitelist]
    if blacklist is not None:
        selected = [d for d in selected if d not in blacklist]

    if isinstance(standardize_fn, dict):
        fn_per_repr = standardize_fn
    else:
        # Use the same preprocessing function for every representation
        fn_per_repr = {repr_: standardize_fn for repr_ in selected}

    for name in selected:
        yield name, featurize(
            smiles,
            name,
            fn_per_repr[name],
            n_jobs=n_jobs,
            mask_nan=mask_nan,
            return_mask=return_mask,
            progress=progress,
            disable_logs=disable_logs,
            batch_size=batch_size,
        )


def featurize(
    smiles,
    name,
    standardize_fn,
    n_jobs: Optional[int] = None,
    mask_nan: bool = True,
    return_mask: bool = True,
    progress: bool = True,
    disable_logs: bool = False,
    batch_size: int = 16,
):
    """Preprocess the SMILES and compute a single representation for them.

    Args:
        smiles: The SMILES to featurize.
        name: Name of the representation. Needs to be a key of ``_REPR_TO_FUNC``.
        standardize_fn: Preprocessing function applied to every SMILES first.
        n_jobs: Number of jobs for the featurization of non-batched representations.
            The preprocessing step does not use it.
        mask_nan: Whether to drop the entries that the mask marks as invalid.
        return_mask: Whether to return the mask alongside the representations.
        progress: Whether to show progress bars.
        disable_logs: Forwarded to the preprocessing and featurization functions.
        batch_size: Batch size for batched featurizers.

    Returns:
        The representations, or ``(representations, mask)`` if ``return_mask`` is set.

    Raises:
        NotImplementedError: If the representation is not supported.
    """
    if name not in _REPR_TO_FUNC:
        msg = f"{name} is not supported. Choose from {MOOD_REPRESENTATIONS}"
        raise NotImplementedError(msg)

    preprocess = partial(standardize_fn, disable_logs=disable_logs)
    preprocessed = dm.utils.parallelized(
        preprocess, smiles, progress=progress, tqdm_kwargs={"desc": f"Preprocess {name}"}
    )
    smiles = np.array(preprocessed)

    compute = partial(_REPR_TO_FUNC[name], disable_logs=disable_logs)
    if name in BATCHED_FEATURIZERS:
        # Batched featurizers consume all SMILES in a single call
        reprs = compute(smiles, batch_size=batch_size)
    else:
        reprs = np.array(
            dm.utils.parallelized(
                compute, smiles, n_jobs=n_jobs, progress=progress, tqdm_kwargs={"desc": name}
            )
        )

    # Mask out invalid features
    mask = get_mask_for_distances_or_representations(reprs)

    logger.info(f"Succesfully computed representations for {len(reprs[mask])}/{len(smiles)} compounds")

    if mask_nan:
        # If the array had any Nones, it would not be a proper
        # 2D array so we convert to one here.
        # NOTE(review): the stacking is assumed to belong to this branch;
        # confirm against the original indentation.
        reprs = np.stack(reprs[mask])

    if return_mask:
        return reprs, mask
    return reprs
(2022) 123 | """ 124 | 125 | smi = dm.to_smiles(dm.keep_largest_fragment(dm.to_mol(smi))) 126 | 127 | with dm.without_rdkit_log(enable=disable_logs): 128 | mol = dm.to_mol(smi) 129 | if mol is None: 130 | # Failed 131 | return 132 | 133 | mol = Chem.AddHs(mol) 134 | 135 | # Use distance geometry to obtain initial coordinates for a molecule 136 | ret = AllChem.EmbedMolecule( 137 | mol, useRandomCoords=True, useBasicKnowledge=True, randomSeed=0, clearConfs=True, maxAttempts=5 138 | ) 139 | if ret == -1: 140 | # Failed 141 | return 142 | 143 | AllChem.MMFFOptimizeMolecule(mol, maxIters=1000, mmffVariant="MMFF94") 144 | 145 | # calculate WHIM 3D descriptor 146 | whim = rdMolDescriptors.CalcWHIM(mol) 147 | whim = np.array(whim).astype(np.float32) 148 | return whim 149 | 150 | 151 | def _compute_extra_2d_features(mol): 152 | """Computes some additional descriptors besides the default ones RDKit offers""" 153 | mol = deepcopy(mol) 154 | FindMolChiralCenters(mol, force=True) 155 | p_obj = rdMolDescriptors.Properties() 156 | props = OrderedDict(zip(p_obj.GetPropertyNames(), p_obj.ComputeProperties(mol))) 157 | qed_props = properties(mol) 158 | props["Alerts"] = qed_props.ALERTS 159 | return props 160 | 161 | 162 | def _charge_descriptors_fix(mol: dm.Mol): 163 | """Recompute the RDKIT 2D Descriptors related to charge 164 | 165 | We change the procedure: 166 | 1. We disconnect the metal from the molecule 167 | 2. We add the hydrogen atoms 168 | 3. We make sure that gasteiger is recomputed. 
169 | 170 | This fixes an issue where these descriptors could be NaN or Inf, 171 | while also making sure we are closer to the proper interpretation 172 | """ 173 | descrs = {} 174 | mol = dm.standardize_mol(mol, disconnect_metals=True) 175 | mol = dm.add_hs(mol, explicit_only=False) 176 | rdPartialCharges.ComputeGasteigerCharges(mol) 177 | atomic_charges = [float(at.GetProp("_GasteigerCharge")) for at in mol.GetAtoms()] 178 | atomic_charges = np.clip(atomic_charges, a_min=-500, a_max=500) 179 | min_charge, max_charge = np.nanmin(atomic_charges), np.nanmax(atomic_charges) 180 | descrs["MaxPartialCharge"] = max_charge 181 | descrs["MinPartialCharge"] = min_charge 182 | descrs["MaxAbsPartialCharge"] = max(np.abs(min_charge), np.abs(max_charge)) 183 | descrs["MinAbsPartialCharge"] = min(np.abs(min_charge), np.abs(max_charge)) 184 | return descrs 185 | 186 | 187 | def compute_desc2d(smi, disable_logs: bool = False): 188 | descr_fns = {name: fn for (name, fn) in Descriptors.descList} 189 | 190 | all_features = [d[0] for d in Descriptors.descList] 191 | all_features += [ 192 | "NumAtomStereoCenters", 193 | "NumUnspecifiedAtomStereoCenters", 194 | "NumBridgeheadAtoms", 195 | "NumAmideBonds", 196 | "NumSpiroAtoms", 197 | "Alerts", 198 | ] 199 | 200 | with dm.without_rdkit_log(enable=disable_logs): 201 | mol = dm.to_mol(smi) 202 | descr_extra = _compute_extra_2d_features(mol) 203 | descr_charge = _charge_descriptors_fix(mol) 204 | 205 | descr = [] 206 | 207 | for name in all_features: 208 | val = float("nan") 209 | if name in descr_charge: 210 | val = descr_charge[name] 211 | elif name == "Ipc": 212 | # Fixes a bug for the RDKit IPC value. 
For context, see: 213 | # https://github.com/rdkit/rdkit/issues/1527 214 | val = descr_fns[name](mol, avg=True) 215 | elif name in descr_fns: 216 | val = descr_fns[name](mol) 217 | else: 218 | assert name in descr_extra 219 | val = descr_extra[name] 220 | descr.append(val) 221 | 222 | descr = np.asarray(descr) 223 | return descr 224 | 225 | 226 | def compute_ecfp6(smi, disable_logs: bool = False): 227 | with dm.without_rdkit_log(enable=disable_logs): 228 | return dm.to_fp(smi, fp_type="ecfp", radius=3) 229 | 230 | 231 | def compute_maccs(smi, disable_logs: bool = False): 232 | with dm.without_rdkit_log(enable=disable_logs): 233 | return dm.to_fp(smi, fp_type="maccs") 234 | 235 | 236 | def compute_chemberta(smis, disable_logs: bool = False, batch_size: int = 16): 237 | # Batch the input 238 | step_size = int(np.ceil(len(smis) / batch_size)) 239 | batched = np.array_split(smis, step_size) 240 | 241 | # Load the model 242 | tokenizer = AutoTokenizer.from_pretrained(_CHEMBERTA_HF_ID) 243 | model = AutoModelForMaskedLM.from_pretrained(_CHEMBERTA_HF_ID) 244 | 245 | # Use the GPU if it is available 246 | device = "cuda" if torch.cuda.is_available() else "cpu" 247 | model.to(device) 248 | model.eval() 249 | 250 | hidden_states = [] 251 | for batch in tqdm.tqdm(batched, desc="Batch"): 252 | model_input = tokenizer( 253 | batch.tolist(), 254 | return_tensors="pt", 255 | add_special_tokens=True, 256 | truncation=True, 257 | padding=True, 258 | max_length=512, 259 | ).to(device) 260 | 261 | with torch.no_grad(): 262 | model_output = model( 263 | model_input["input_ids"], 264 | attention_mask=model_input["attention_mask"], 265 | output_hidden_states=True, 266 | ) 267 | 268 | # We use mean aggregation of the different token embeddings 269 | h = model_output.hidden_states[-1] 270 | h = [h_[mask].mean(0) for h_, mask in zip(h, model_input["attention_mask"])] 271 | h = torch.stack(h) 272 | h = h.cpu().detach().numpy() 273 | 274 | hidden_states.append(h) 275 | 276 | hidden_states = 
np.concatenate(hidden_states) 277 | return hidden_states 278 | 279 | 280 | def load_graphormer(smis, disable_logs: bool = False, batch_size: int = 16): 281 | # NOTE: Since we needed to make some changes to the Graphormer repo 282 | # to actually extract the embeddings, including that code in MOOD 283 | # would pollute the repo quite a bit. Therefore, we precomputed the 284 | # representations and just load these here. 285 | 286 | pattern = dm.fs.join(DATASET_DATA_DIR, "representations", "**", "*.parquet") 287 | paths = dm.fs.glob(pattern) 288 | df_repr = pd.concat([pd.read_parquet(p) for p in paths]) 289 | 290 | unique_ids = dm.utils.parallelized(dm.unique_id, smis, progress=True) 291 | df = pd.DataFrame(index=unique_ids) 292 | 293 | df_repr = df_repr[df_repr["unique_id"].isin(unique_ids)] 294 | df_repr = df_repr.set_index("unique_id") 295 | 296 | df = df.join(df_repr) 297 | df = df[~df.index.duplicated(keep="first")] 298 | df = df.reindex(unique_ids) 299 | 300 | feats = df["representation"].to_numpy() 301 | return feats 302 | 303 | 304 | _REPR_TO_FUNC = { 305 | "MACCS": compute_maccs, 306 | "ECFP6": compute_ecfp6, 307 | "Desc2D": compute_desc2d, 308 | "WHIM": compute_whim, 309 | "ChemBERTa": compute_chemberta, 310 | "Graphormer": load_graphormer, 311 | } 312 | 313 | MOOD_REPRESENTATIONS = list(_REPR_TO_FUNC.keys()) 314 | BATCHED_FEATURIZERS = ["ChemBERTa", "Graphormer"] 315 | TEXTUAL_FEATURIZERS = ["ChemBERTa"] 316 | -------------------------------------------------------------------------------- /mood/splitter.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | 3 | import numpy as np 4 | import datamol as dm 5 | import seaborn as sns 6 | import pandas as pd 7 | 8 | from sklearn.model_selection import ShuffleSplit 9 | from dataclasses import dataclass 10 | from loguru import logger 11 | from typing import Union, List, Optional, Callable, Dict 12 | 13 | from scipy.stats import gaussian_kde 14 | from 
MOOD_SPLITTERS = ["Random", "Scaffold", "Perimeter", "Maximum Dissimilarity"]


def get_mood_splitters(smiles, n_splits: int = 5, random_state: int = 0, n_jobs: Optional[int] = None):
    """Instantiate the splitters considered in MOOD, keyed by their name.

    Args:
        smiles: The SMILES of the dataset. Their Murcko scaffolds are the
            groups of the scaffold split.
        n_splits: Number of splits every splitter generates.
        random_state: Seed shared by all splitters.
        n_jobs: Forwarded to the perimeter and maximum dissimilarity splitters.

    Returns:
        A dict that maps every name in ``MOOD_SPLITTERS`` to a splitter.
    """
    scaffolds = [dm.to_smiles(dm.to_scaffold_murcko(dm.to_mol(smi))) for smi in smiles]
    splitters = {
        "Random": ShuffleSplit(n_splits=n_splits, random_state=random_state),
        "Scaffold": PredefinedGroupShuffleSplit(
            groups=scaffolds, n_splits=n_splits, random_state=random_state
        ),
        "Perimeter": PerimeterSplit(
            n_clusters=25, n_splits=n_splits, random_state=random_state, n_jobs=n_jobs
        ),
        "Maximum Dissimilarity": MaxDissimilaritySplit(
            n_clusters=25, n_splits=n_splits, random_state=random_state, n_jobs=n_jobs
        ),
    }
    return splitters


@dataclass
class SplitCharacterization:
    """
    Within the context of MOOD, a split is characterized by
    a distribution of distances and an associated representativeness score
    """

    distances: np.ndarray
    representativeness: float
    label: str

    @classmethod
    def concat(cls, splits):
        """Merge equally labeled characterizations.

        The distances are concatenated and the representativeness scores averaged.

        Raises:
            RuntimeError: If the characterizations do not all share one label.
        """
        names = {obj.label for obj in splits}
        if len(names) != 1:
            raise RuntimeError("Can only concatenate equally labeled split characterizations")

        dist = np.concatenate([obj.distances for obj in splits])
        score = np.mean([obj.representativeness for obj in splits])
        return cls(dist, score, names.pop())

    @staticmethod
    def best(splits):
        """Return the characterization with the highest representativeness."""
        return max(splits, key=lambda spl: spl.representativeness)

    @staticmethod
    def as_dataframe(splits):
        """Summarize the characterizations in a dataframe with one row per split.

        The columns are "split", "representativeness", "best" (whether the row
        equals the best split) and "rank" (1 is the most representative).
        """
        best = SplitCharacterization.best(splits)
        # Build the frame in one go instead of growing it row by row with
        # pd.concat, which is quadratic and relies on deprecated handling
        # of empty frames.
        df = pd.DataFrame(
            {
                "split": [split.label for split in splits],
                "representativeness": [split.representativeness for split in splits],
                "best": [split == best for split in splits],
            }
        )
        df["rank"] = df["representativeness"].rank(ascending=False)
        return df

    def __eq__(self, other):
        # Equality deliberately ignores the distances: comparing arrays does
        # not yield a single boolean.
        if not isinstance(other, SplitCharacterization):
            return NotImplemented
        return self.label == other.label and self.representativeness == other.representativeness

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return f"{self.__class__.__name__}[{self.label}]"
122 | """ 123 | super().__init__() 124 | if not all(isinstance(obj, BaseShuffleSplit) for obj in splitters.values()): 125 | raise TypeError("All splitters should be BaseShuffleSplit objects") 126 | 127 | n_splits_per_splitter = [obj.get_n_splits() for obj in splitters.values()] 128 | if not len(set(n_splits_per_splitter)) == 1: 129 | raise TypeError("n_splits is inconsistent across the different splitters") 130 | self._n_splits = n_splits_per_splitter[0] 131 | 132 | self._p = p 133 | self._k = k 134 | self._metric = metric 135 | self._splitters = splitters 136 | self._downstream_distances = downstream_distances 137 | 138 | self._split_chars = None 139 | self._prescribed_splitter_label = None 140 | 141 | @staticmethod 142 | def visualize(downstream_distances: np.ndarray, splits: List[SplitCharacterization], ax: Optional = None): 143 | splits = sorted(splits, key=lambda spl: spl.representativeness) 144 | cmap = sns.color_palette("rocket", len(splits) + 1) 145 | 146 | distances = [spl.distances for spl in splits] 147 | colors = [cmap[rank + 1] for rank, spl in enumerate(splits)] 148 | labels = [spl.label for spl in splits] 149 | 150 | ax = plot_distance_distributions(distances, labels, colors, ax=ax) 151 | 152 | lower, upper = get_outlier_bounds(downstream_distances, factor=3.0) 153 | mask = (downstream_distances >= lower) & (downstream_distances <= upper) 154 | downstream_distances = downstream_distances[mask] 155 | 156 | sns.kdeplot(downstream_distances, color=cmap[0], linestyle="--", alpha=0.3, ax=ax) 157 | return ax 158 | 159 | @staticmethod 160 | def score_representativeness(downstream_distances, distances, num_samples: int = 100): 161 | """Scores a representativeness score between two distributions 162 | A higher score should be interpreted as _more_ representative""" 163 | pdf_split = gaussian_kde(distances) 164 | pdf_downstream = gaussian_kde(downstream_distances) 165 | 166 | vmin = np.min(np.concatenate((downstream_distances, distances))) 167 | vmax = 
np.max(np.concatenate((downstream_distances, distances))) 168 | positions = np.linspace(vmin, vmax, num=num_samples) 169 | 170 | samples_split = pdf_split(positions) 171 | samples_downstream = pdf_downstream(positions) 172 | 173 | return 1.0 - jensenshannon(samples_downstream, samples_split, base=2) 174 | 175 | @property 176 | def prescribed_splitter_label(self): 177 | if not self.fitted: 178 | raise RuntimeError("The splitter has not be fitted yet") 179 | return self._prescribed_splitter_label 180 | 181 | @property 182 | def fitted(self): 183 | return self._prescribed_splitter_label is not None 184 | 185 | def _compute_distance(self, X_from, X_to): 186 | """ 187 | Computes the k-NN distance from one set to another 188 | 189 | Args: 190 | X_from: The set to compute the distance for 191 | X_to: The set to compute the distance to (i.e. the neighbor candidates) 192 | """ 193 | knn = NearestNeighbors(n_neighbors=self._k, metric=self._metric, p=self._p).fit(X_to) 194 | distances, ind = knn.kneighbors(X_from) 195 | distances = np.mean(distances, axis=1) 196 | return distances 197 | 198 | def get_prescribed_splitter(self): 199 | return self._splitters[self.prescribed_splitter_label] 200 | 201 | def get_protocol_results(self): 202 | return SplitCharacterization.as_dataframe(self._split_chars) 203 | 204 | def fit(self, X, y=None, groups=None, X_deployment=None, plot: bool = False, progress: bool = False): 205 | """Follows the MOOD specification to prescribe a train-test split 206 | that is most representative of downstream applications. 207 | 208 | In MOOD, the k-NN distance in the representation space functions 209 | as a proxy of difficulty. The further a datapoint is from the training 210 | set, in general the lower a model's performance. Using that observation, 211 | we prescribe the train-test split that best replicates the distance 212 | distribution (i.e. "the difficulty") of a downstream application. 
213 | """ 214 | 215 | if self._downstream_distances is None: 216 | self._downstream_distances = self._compute_distance(X_deployment, X) 217 | 218 | # Precompute all splits. Since splitters are implemented as generators, 219 | # we store the resulting splits so we can replicate them later on. 220 | split_chars = list() 221 | 222 | it = self._splitters.items() 223 | if progress: 224 | it = tqdm.tqdm(it, desc="Splitter") 225 | 226 | for name, splitter in it: 227 | # We possibly repeat the split multiple times to 228 | # get a more reliable estimate 229 | chars = [] 230 | 231 | it_ = splitter.split(X, y, groups) 232 | if progress: 233 | it_ = tqdm.tqdm(it_, leave=False, desc="Split", total=self._n_splits) 234 | 235 | for split in it_: 236 | train, test = split 237 | distances = self._compute_distance(X[test], X[train]) 238 | distances = distances[np.isfinite(distances)] 239 | distances = distances[~np.isnan(distances)] 240 | 241 | score = self.score_representativeness(self._downstream_distances, distances) 242 | chars.append(SplitCharacterization(distances, score, name)) 243 | 244 | split_chars.append(SplitCharacterization.concat(chars)) 245 | 246 | # Rank different splitting methods by their ability to 247 | # replicate the downstream distance distribution. 
248 | chosen = SplitCharacterization.best(split_chars) 249 | 250 | self._split_chars = split_chars 251 | self._prescribed_splitter_label = chosen.label 252 | 253 | logger.info( 254 | f"Ranked all different splitting methods:\n{SplitCharacterization.as_dataframe(split_chars)}" 255 | ) 256 | logger.info(f"Selected {chosen.label} as the most representative splitting method") 257 | 258 | if plot: 259 | # Visualize the results 260 | return self.visualize(self._downstream_distances, split_chars) 261 | 262 | def _iter_indices(self, X=None, y=None, groups=None): 263 | """Generate (train, test) indices""" 264 | if not self.fitted: 265 | raise RuntimeError("The splitter has not be fitted yet") 266 | yield from self.get_prescribed_splitter()._iter_indices(X, y, groups) 267 | 268 | 269 | class PredefinedGroupShuffleSplit(GroupShuffleSplit): 270 | """Simple class that tackles the limitation of the MOODSplitter 271 | that all splitters need to use the same grouping.""" 272 | 273 | def __init__(self, groups, n_splits=5, *, test_size=None, train_size=None, random_state=None): 274 | super().__init__( 275 | n_splits=n_splits, 276 | test_size=test_size, 277 | train_size=train_size, 278 | random_state=random_state, 279 | ) 280 | self._groups = groups 281 | 282 | def _iter_indices(self, X=None, y=None, groups=None): 283 | """Generate (train, test) indices""" 284 | if groups is not None: 285 | logger.warning("Ignoring the groups parameter in favor of the predefined groups") 286 | yield from super()._iter_indices(X, y, self._groups) 287 | 288 | 289 | class KMeansSplit(GroupShuffleSplit): 290 | """Split based on the k-Mean clustering in input space""" 291 | 292 | def __init__( 293 | self, n_clusters: int = 10, n_splits: int = 5, *, test_size=None, train_size=None, random_state=None 294 | ): 295 | super().__init__( 296 | n_splits=n_splits, 297 | test_size=test_size, 298 | train_size=train_size, 299 | random_state=random_state, 300 | ) 301 | self._n_clusters = n_clusters 302 | 303 | def 
class PerimeterSplit(KMeansSplit):
    """
    Places the pairs of data points with maximal pairwise distance in the test set.
    This was originally called the extrapolation-oriented split, introduced in Szántai-Kis et al., 2003.

    The data is first clustered with k-Means; whole clusters are then assigned to
    the test set, starting from the pair of clusters whose centers are furthest apart.
    """

    def __init__(
        self,
        n_clusters: int = 10,
        n_splits: int = 5,
        n_jobs: Optional[int] = None,
        *,
        test_size=None,
        train_size=None,
        random_state=None,
    ):
        super().__init__(
            n_clusters=n_clusters,
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state,
        )
        # Number of parallel jobs for the pairwise distance computation
        self._n_jobs = n_jobs

    @staticmethod
    def _gather(members, group_ids):
        """Concatenate the sample indices of the given groups into one integer array."""
        if not group_ids:
            # Keep an integer dtype so that the result remains usable as an index
            return np.empty(0, dtype=int)
        return np.concatenate([members[group] for group in group_ids]).astype(int)

    def _iter_indices(self, X, y=None, groups=None):
        """Generate (train, test) indices"""
        if groups is not None:
            logger.warning("Ignoring the groups parameter in favor of the predefined groups")

        n_samples = _num_samples(X)
        _, n_test = _validate_shuffle_split(
            n_samples,
            self.test_size,
            self.train_size,
            default_test_size=self._default_test_size,
        )

        for split_idx in range(self.n_splits):
            # The clustering seed is offset per split, so each split can differ
            _, centers = self.compute_kmeans_clustering(
                X, random_state_offset=split_idx, return_centers=True
            )
            centers, group_indices = np.unique(centers, return_inverse=True, axis=0)
            n_groups = len(centers)
            members = [np.flatnonzero(group_indices == group) for group in range(n_groups)]

            # We always use the euclidean metric. For binary vectors we would have
            # used the jaccard metric normally, but because of the k-Means clustering this
            # data would be transformed using the Empirical Kernel Map.
            distance_matrix = pairwise_distances(centers, metric="euclidean", n_jobs=self._n_jobs)

            # Sort the distance matrix to find the groups that are the furthest away from one another
            rows, cols = np.tril_indices_from(distance_matrix, k=-1)
            furthest_first = np.argsort(distance_matrix[rows, cols])[::-1]

            test_groups = []
            n_selected = 0
            remaining = set(range(n_groups))

            for pos in furthest_first:
                if n_selected >= n_test:
                    break

                first, second = int(rows[pos]), int(cols[pos])

                # If one of the groups in this pair is already in the test set, skip to the next
                if first not in remaining or second not in remaining:
                    continue

                for group in (first, second):
                    remaining.remove(group)
                    test_groups.append(group)
                    n_selected += len(members[group])

            # Sort the remaining groups to make the order of the train indices explicit
            yield self._gather(members, sorted(remaining)), self._gather(members, test_groups)
get_distance_metric(X) 448 | 449 | n_samples = _num_samples(X) 450 | n_train, n_test = _validate_shuffle_split( 451 | n_samples, 452 | self.test_size, 453 | self.train_size, 454 | default_test_size=self._default_test_size, 455 | ) 456 | 457 | for i in range(self.n_splits): 458 | # We introduce some stochasticity through the k-Means clustering 459 | groups, centers = self.compute_kmeans_clustering(X, random_state_offset=i, return_centers=True) 460 | centers, group_indices, group_counts = np.unique( 461 | centers, return_inverse=True, return_counts=True, axis=0 462 | ) 463 | groups_set = np.unique(group_indices) 464 | 465 | # We always use the euclidean metric. For binary vectors we would have 466 | # used the jaccard metric normally, but because of the k-Means clustering this 467 | # data would be transformed using the Empirical Kernel Map. 468 | distance_matrix = pairwise_distances(centers, metric="euclidean", n_jobs=self._n_jobs) 469 | 470 | # The initial test cluster is the one with the 471 | # highest mean distance to all other clusters 472 | test_idx = np.argmax(distance_matrix.mean(axis=0)) 473 | 474 | # The initial train cluster is the one furthest from 475 | # the initial test cluster 476 | train_idx = np.argmax(distance_matrix[test_idx]) 477 | 478 | train_indices = np.flatnonzero(group_indices == groups_set[train_idx]) 479 | test_indices = np.flatnonzero(group_indices == groups_set[test_idx]) 480 | 481 | # Iteratively add the train cluster that is furthest 482 | # from the _initial_ test cluster. 
483 | sorted_groups = np.argsort(distance_matrix[train_idx]) 484 | for group_idx in sorted_groups: 485 | if len(train_indices) >= n_train: 486 | break 487 | 488 | if group_idx == train_idx or group_idx == test_idx: 489 | continue 490 | 491 | indices_to_add = np.flatnonzero(group_indices == groups_set[group_idx]) 492 | train_indices = np.concatenate([train_indices, indices_to_add]) 493 | 494 | # Construct test set 495 | remaining_groups = list(set(range(n_samples)) - set(train_indices) - set(test_indices)) 496 | test_indices = np.concatenate([test_indices, remaining_groups]).astype(int) 497 | 498 | yield train_indices, test_indices 499 | -------------------------------------------------------------------------------- /mood/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from typing import Optional 5 | from pytorch_lightning import Trainer 6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 7 | from pytorch_lightning.utilities.seed import seed_everything 8 | from torch.utils.data import DataLoader 9 | 10 | from mood.baselines import construct_kernel, get_baseline_model, MOOD_BASELINES 11 | from mood.constants import NUM_EPOCHS 12 | from mood.dataset import SimpleMolecularDataset, DAMolecularDataset, domain_based_collate 13 | from mood.model import MOOD_DA_DG_ALGORITHMS, is_domain_generalization, is_domain_adaptation 14 | from mood.model.base import Ensemble 15 | from mood.model.nn import get_simple_mlp 16 | 17 | 18 | def train_baseline_model( 19 | X, 20 | y, 21 | algorithm: str, 22 | is_regression: bool, 23 | params: Optional[dict] = None, 24 | seed: Optional[int] = None, 25 | for_uncertainty_estimation: bool = False, 26 | ensemble_size: int = 10, 27 | calibrate: bool = False, 28 | n_jobs: int = -1, 29 | ): 30 | if params is None: 31 | params = {} 32 | if seed is not None: 33 | params["random_state"] = seed 34 | 35 | if algorithm == "RF" or (algorithm == "GP" 
def train_torch_model(
    train_dataset: SimpleMolecularDataset,
    val_dataset: SimpleMolecularDataset,
    test_dataset: SimpleMolecularDataset,
    algorithm: str,
    is_regression: bool,
    params: Optional[dict] = None,
    seed: Optional[int] = None,
    ensemble_size: int = 5,
):
    """Train an ensemble of torch models with one of the DA / DG algorithms.

    Args:
        train_dataset: The data to train on.
        val_dataset: The data used for early stopping on the validation loss.
        test_dataset: Used as the target dataset for domain adaptation algorithms.
        algorithm: Key into `MOOD_DA_DG_ALGORITHMS`.
        is_regression: Selects the loss function and the type of prediction head.
        params: Hyper-parameters. Must contain `mlp_width`, `mlp_depth` and `batch_size`.
            Everything except the two MLP settings is forwarded to the algorithm.
            The dictionary is not modified.
        seed: Passed to `seed_everything`.
        ensemble_size: The number of models to train.

    Returns:
        An `Ensemble` wrapping the trained models.

    Raises:
        KeyError: If a required hyper-parameter is missing from `params`.
    """
    logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
    seed_everything(seed, workers=True)

    # Work on a copy: popping from the caller's dictionary would make it
    # impossible to reuse the same hyper-parameters for a second call.
    params = dict(params) if params is not None else {}

    missing = [key for key in ("mlp_width", "mlp_depth", "batch_size") if key not in params]
    if missing:
        raise KeyError(f"Missing required hyper-parameters: {missing}")

    width = params.pop("mlp_width")
    depth = params.pop("mlp_depth")
    batch_size = params["batch_size"]

    # NOTE: Since the datasets are all very small,
    #  setting up and syncing the threads takes longer than
    #  what we gain by using the threads
    no_workers = 0

    # NOTE: For smaller dataset, moving data between CPU and GPU will be the bottleneck
    use_gpu = torch.cuda.is_available() and len(train_dataset) > 2500

    models = []
    for _ in range(ensemble_size):
        base = get_simple_mlp(len(train_dataset.X[0]), width, depth, out_size=None)
        head = get_simple_mlp(
            input_size=width * 2 if algorithm == "MTL" else width, is_regression=is_regression
        )

        model = MOOD_DA_DG_ALGORITHMS[algorithm](
            base_network=base,
            prediction_head=head,
            loss_fn=torch.nn.MSELoss() if is_regression else torch.nn.BCELoss(),
            **params,
        )

        if is_domain_adaptation(model):
            train_dataset_da = DAMolecularDataset(source_dataset=train_dataset, target_dataset=test_dataset)
            train_dataloader = DataLoader(
                train_dataset_da, batch_size=batch_size, shuffle=True, num_workers=no_workers
            )
            val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=no_workers)
        elif is_domain_generalization(model):
            train_dataloader = DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True,
                collate_fn=domain_based_collate,
                num_workers=no_workers,
            )
            val_dataloader = DataLoader(
                val_dataset, batch_size=batch_size, collate_fn=domain_based_collate, num_workers=no_workers
            )
        else:
            train_dataloader = DataLoader(
                train_dataset, batch_size=batch_size, shuffle=True, num_workers=no_workers
            )
            val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=no_workers)

        callbacks = [EarlyStopping("val_loss", patience=10, mode="min")]
        trainer = Trainer(
            max_epochs=NUM_EPOCHS,
            deterministic="warn",
            callbacks=callbacks,
            enable_model_summary=False,
            enable_progress_bar=False,
            num_sanity_val_steps=0,
            logger=False,
            accelerator="gpu" if use_gpu else None,
            devices=1 if use_gpu else None,
            enable_checkpointing=False,
        )
        trainer.fit(model, train_dataloader, val_dataloader)
        models.append(model)

    return Ensemble(models, is_regression)
class EmpiricalKernelMapTransformer:
    """Represents each datapoint by its distances to a fixed set of anchor datapoints.

    The anchors are drawn from the first batch that is transformed and are
    reused for every subsequent call.
    """

    def __init__(self, n_samples: int, metric: str, random_state: Optional[int] = None):
        """
        Args:
            n_samples: The number of anchor datapoints to draw.
            metric: The distance metric, forwarded to `scipy.spatial.distance.cdist`.
            random_state: Seed for drawing the anchors.
        """
        self._metric = metric
        self._n_samples = n_samples
        self._random_state = random_state
        # Filled lazily on the first call to `transform`
        self._samples = None

    def __call__(self, X):
        """Shorthand for `transform(X)`"""
        return self.transform(X)

    def transform(self, X):
        """Transforms a batch of datapoints into their distances to the anchors.

        Returns an array with one row per datapoint in `X` and one column per anchor.
        """
        if self._samples is None:
            generator = np.random.default_rng(self._random_state)
            candidates = np.arange(len(X))
            # NOTE(review): `choice` is used with its default settings, which
            #  samples with replacement, so anchors can repeat — confirm this is intended.
            picked = generator.choice(candidates, self._n_samples)
            self._samples = X[picked]
        return cdist(X, self._samples, metric=self._metric)
def get_outlier_bounds(X, factor: float = 1.5):
    """Compute the range outside of which values are considered outliers.

    The range extends `factor` times the inter-quartile range below the first
    and above the third quartile, clipped to the minimum and maximum of the data.

    Args:
        X: The values to compute the bounds for.
        factor: Multiplier of the inter-quartile range.

    Returns:
        A `(lower, upper)` tuple.
    """
    first_quartile = np.quantile(X, 0.25)
    third_quartile = np.quantile(X, 0.75)
    whisker = factor * (third_quartile - first_quartile)

    # Never report bounds that lie beyond the observed data
    smallest, largest = np.min(X), np.max(X)
    return max(smallest, first_quartile - whisker), min(largest, third_quartile + whisker)
def get_mask_for_distances_or_representations(X, threshold: float = 1e4):
    """Find the entries of `X` that are safe to propagate to downstream functions.

    An entry is kept if it is not None, contains only finite values
    (i.e. no NaN or infinity) and has no value above `threshold`.
    Entries can be scalars (e.g. distances) or arrays (e.g. representations).

    Args:
        X: The sequence of distances or representations to check.
        threshold: Entries with any value above this threshold are dropped.

    Returns:
        The list of integer positions of the entries to keep.
    """
    # The 1e4 default is somewhat arbitrary, but manually chosen to
    # filter out compounds that don't make sense (and are outliers).
    # (e.g. WHIM for [C-]#N.[C-]#N.[C-]#N.[C-]#N.[C-]#N.[Fe+4].[N-]=O)
    # Propagating such high-values would cause issues in downstream
    # functions (e.g. KMeans)

    def _is_valid(entry) -> bool:
        if entry is None:
            return False
        # Converting first lets plain Python scalars and lists
        # go through the same array-based checks as NumPy input
        values = np.asarray(entry)
        return bool(np.isfinite(values).all() and not (values > threshold).any())

    return [i for i, a in enumerate(X) if _is_valid(a)]
def plot_performance_over_distance(
    performance_data: pd.DataFrame,
    calibration_data: pd.DataFrame,
    dataset_name: str,
    ax: Optional = None,
    show_legend: bool = True,
    show_title: bool = True,
    show_xlabel: bool = True,
    show_ylabel: bool = True,
):
    """Plot performance and calibration as a function of distance in a single figure.

    Performance is drawn on the left y-axis, calibration on a twin right y-axis.
    Both are drawn as a line for the mean with a shaded band between the lower and upper score.
    Axes of metrics for which lower is better are inverted.

    Args:
        performance_data: Needs the columns distance, score_lower, score_mu and score_upper.
        calibration_data: Needs the same columns as `performance_data`.
        dataset_name: Used as the title and to look up the default metrics.
        ax: The axes to draw on. A new figure is created if None.
        show_legend: Whether to add a legend.
        show_title: Whether to add the dataset name as title.
        show_xlabel: Whether to label the x-axis.
        show_ylabel: Whether to label both y-axes.

    Returns:
        A tuple of the performance axes and the calibration axes.

    Raises:
        ValueError: If one of the dataframes lacks an expected column.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 6))

    expected_columns = ["distance", "score_lower", "score_mu", "score_upper"]

    def _require_columns(frame, name):
        missing = [col for col in expected_columns if col not in frame.columns]
        if missing:
            raise ValueError(f"For {name}, expecting {expected_columns}, found {frame.columns}")

    _require_columns(performance_data, "performance_data")
    _require_columns(calibration_data, "calibration_data")

    def _draw(frame, color, target):
        # Column order follows `expected_columns`: distance, lower, mean, upper
        values = frame[expected_columns].to_numpy()
        distance, lower, mean, upper = values[:, 0], values[:, 1], values[:, 2], values[:, 3]
        sns.lineplot(x=distance, y=mean, color=color, ax=target, lw=4)
        target.fill_between(distance, lower, upper, color=color, alpha=0.2)
        return target

    ax = _draw(performance_data, "tab:blue", ax)
    ax_calibration = _draw(calibration_data, "tab:orange", ax.twinx())

    perf_metric = Metric.get_default_performance_metric(dataset_name)
    cali_metric = Metric.get_default_calibration_metric(dataset_name)

    # Flip the axes of metrics that are to be minimized
    for metric, axis in ((perf_metric, ax), (cali_metric, ax_calibration)):
        if metric.mode == "min":
            axis.invert_yaxis()

    if show_ylabel:
        ax_calibration.set_ylabel(
            f"Calibration ({cali_metric.name})", rotation=-90, labelpad=18, fontsize=12
        )
        ax.set_ylabel(f"Performance ({perf_metric.name})", fontsize=12)

    if show_xlabel:
        ax.set_xlabel("Distance")

    if show_title:
        ax.set_title(dataset_name, fontsize=18)

    if show_legend:
        labels = ["Performance", "Calibration"]
        legend_lines = [
            plt.Line2D([0], [0], color=color, lw=4) for color in ("tab:blue", "tab:orange")
        ]
        ax.legend(legend_lines, labels, fontsize=12, loc="lower center", ncol=len(labels), fancybox=True)

    return ax, ax_calibration
def axes_grid_iterator(
    col_labels: List[str],
    row_labels: List[str],
    col_size: int = 5,
    row_size: int = 5,
    fontsize: int = 24,
    margin: float = 0.25,
):
    """Create a labeled grid of subplots and iterate over its axes in row-major order.

    The row labels are written vertically to the left of the first column,
    the column labels above the first row.

    Args:
        col_labels: One label per column of the grid.
        row_labels: One label per row of the grid.
        col_size: The width of a single subplot.
        row_size: The height of a single subplot.
        fontsize: The font size of the labels.
        margin: The distance between a label and its subplot, in axes coordinates.

    Yields:
        A tuple of the axes, its row index and its column index.
    """
    ncols = len(col_labels)
    nrows = len(row_labels)

    # With `squeeze=False` the axes are always a (nrows, ncols) array. Squeezing and
    # restoring the dimensions afterwards would turn a single-column grid into a
    # (1, nrows) array, which breaks the row-based indexing below.
    fig, axs = plt.subplots(
        ncols=ncols, nrows=nrows, figsize=(col_size * ncols, row_size * nrows), squeeze=False
    )

    for ri, row in enumerate(row_labels):
        for ci, col in enumerate(col_labels):
            ax = axs[ri, ci]
            if ci == 0:
                ax.text(
                    -margin,
                    0.5,
                    row,
                    rotation="vertical",
                    va="center",
                    ha="center",
                    transform=ax.transAxes,
                    fontsize=fontsize,
                )
            if ri == 0:
                ax.text(
                    0.5, 1 + margin, col, transform=ax.transAxes, va="center", ha="center", fontsize=fontsize
                )
            yield ax, ri, ci
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "mood" 7 | description = "A python library to accompany the MOOD paper from Tossou et al. (2023)" 8 | authors = [{ name = "Cas Wognum", email = "cas@valencediscovery.com" }] 9 | readme = "README.md" 10 | dynamic = ["version"] 11 | requires-python = ">=3.10,<3.11" 12 | license = { text = "Apache" } 13 | classifiers = [ 14 | "Development Status :: 2 - Pre-Alpha", 15 | "Intended Audience :: Developers", 16 | "Intended Audience :: Healthcare Industry", 17 | "Intended Audience :: Science/Research", 18 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 19 | "Topic :: Scientific/Engineering :: Bio-Informatics", 20 | "Topic :: Scientific/Engineering :: Information Analysis", 21 | "Topic :: Scientific/Engineering :: Medical Science Apps.", 22 | "Natural Language :: English", 23 | "Operating System :: OS Independent", 24 | "Programming Language :: Python", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.10", 27 | ] 28 | 29 | dependencies = [ 30 | "pandas", 31 | "matplotlib", 32 | "scikit-learn", 33 | "torchmetrics", 34 | "pytorch-lightning <2.0", 35 | "torch >=1.10.2", 36 | "numpy <1.24", 37 | "tqdm", 38 | "optuna", 39 | "datamol", 40 | "notebook", 41 | "pytdc", 42 | "typer", 43 | "gcsfs", 44 | "pyarrow", 45 | "fastparquet", 46 | "transformers", 47 | ] 48 | 49 | [project.scripts] 50 | mood = "mood.__main__:app" 51 | 52 | [tool.black] 53 | line-length = 110 54 | target-version = ['py310'] 55 | include = '\.pyi?$' 56 | 57 | [project.urls] 58 | "Source Code" = "https://github.com/cwognum/mood-experiments/tree/main" 59 | "Paper" = "TODO" 60 | 61 | [tool.setuptools] 62 | include-package-data = true 63 | 64 | 
def _register_commands(group, commands):
    """Register every ``(name, help, callback)`` triple on the Typer app *group*, in order."""
    for command_name, help_text, callback in commands:
        group.command(name=command_name, help=help_text)(callback)


# Sub-commands of `compare`
compare_app = typer.Typer(help="Various CLIs that involve comparing two things")
_register_commands(
    compare_app,
    [
        (
            "spaces",
            "Compare how distances in the Model and various Input spaces correlate",
            model_vs_input_space_cmd,
        ),
        (
            "performance",
            "Compare how the model performs on compounds in the IID and OOD range",
            iid_ood_gap_cmd,
        ),
        (
            "splits",
            "Compare how different splits replicate the shift between train and downstream applications",
            compare_splits_cmd,
        ),
    ],
)


# Sub-commands of `precompute`
precompute_app = typer.Typer(help="Various CLIs that precompute data used later on")
_register_commands(
    precompute_app,
    [
        (
            "representation",
            "Precompute representations and save these as .parquet files",
            precompute_representation_cmd,
        ),
        (
            "distances",
            "Precompute distances from downstream applications to the different train sets",
            precompute_distances_cmd,
        ),
    ],
)


# Sub-commands of `visualize`
visualize_app = typer.Typer(help="Various CLIs to visualize results")
_register_commands(
    visualize_app,
    [
        (
            "shift",
            "Visualize the shift from train to downstream applications",
            visualize_shift_cmd,
        ),
        (
            "splits",
            "Visualize how representative different splits are of downstream applications",
            visualize_splits_cmd,
        ),
        (
            "performance_over_distance",
            "Visualize how performance and calibration evolve over distance",
            perf_over_distance_cmd,
        ),
    ],
)


# Root application that groups the three command families
app = typer.Typer(help="CLI for the various stand-alone scripts of MOOD")
for _group_name, _group in (
    ("compare", compare_app),
    ("precompute", precompute_app),
    ("visualize", visualize_app),
):
    app.add_typer(_group, name=_group_name)
def cli(
    baseline_algorithm: str,
    representation: str,
    dataset: str,
    n_seeds: int = 5,
    n_trials: int = 50,
    n_startup_trials: int = 10,
    base_save_dir: str = RESULTS_DIR,
    sub_save_dir: Optional[str] = None,
    overwrite: bool = False,
):
    """Compare IID and OOD performance of a tuned baseline and save the results as CSV files.

    For every seed the dataset is split randomly into train/val/test, hyper-parameters are
    tuned on the validation split (or loaded from a YAML file written by a previous run) and a
    model is trained to predict the test split together with an uncertainty estimate.
    Two tables are written to ``<base_save_dir>/dataframes/compare_performance/<sub_save_dir>``:
    the IID/OOD gap (``gap_*.csv``) and the performance over distance (``perf_over_distance_*.csv``).

    Args:
        baseline_algorithm: Name of the baseline, forwarded to the tuning and training helpers.
        representation: Name of the molecular representation used to featurize the dataset.
        dataset: Name of the dataset, forwarded to ``load_data_from_tdc``.
        n_seeds: Number of random splits to train and evaluate on.
        n_trials: Number of hyper-parameter trials per seed.
        n_startup_trials: Number of start-up trials per seed.
        base_save_dir: Root directory for the results.
        sub_save_dir: Sub-directory for this run. Defaults to today's date (``YYYYMMDD``).
        overwrite: Whether existing result CSV files may be replaced.

    Raises:
        RuntimeError: If one of the result CSV files already exists and ``overwrite`` is False.
    """
    if sub_save_dir is None:
        sub_save_dir = datetime.now().strftime("%Y%m%d")
    out_dir = dm.fs.join(base_save_dir, "dataframes", "compare_performance", sub_save_dir)
    dm.fs.mkdir(out_dir, exist_ok=True)

    suffix = f"{dataset}_{baseline_algorithm}_{representation}"
    gap_path = dm.fs.join(out_dir, f"gap_{suffix}.csv")
    perf_path = dm.fs.join(out_dir, f"perf_over_distance_{suffix}.csv")

    # Fail fast: checking this only after tuning and training wastes the entire computation
    if not overwrite:
        for path in (gap_path, perf_path):
            if dm.fs.exists(path):
                raise RuntimeError(f"{path} already exists!")

    # Load the dataset
    smiles, y = load_data_from_tdc(dataset)
    X, mask = featurize(
        smiles,
        representation,
        standardize_fn=DEFAULT_PREPROCESSING[representation],
        disable_logs=True,
    )
    y = y[mask]
    is_regression = dataset in MOOD_REGR_DATASETS

    # Get the metrics
    perf_metric = Metric.get_default_performance_metric(dataset)
    cali_metric = Metric.get_default_calibration_metric(dataset)

    # Generate all data needed for these plots
    dist_train = []
    dist_test = []
    y_pred = []
    y_true = []
    y_uncertainty = []

    for seed in range(n_seeds):
        # Randomly split the dataset
        # This ensures that the distribution of distances from val to train is relatively uniform
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=seed)

        hparams_path = dm.fs.join(out_dir, f"best_hparams_{suffix}_{seed}.yaml")

        if dm.fs.exists(hparams_path):
            # Load the data of the completed hyper-param study if it already exists
            logger.info(f"Loading the best hyper-params from {hparams_path}")
            with fsspec.open(hparams_path) as fd:
                params = yaml.safe_load(fd)

        else:
            # Run a hyper-parameter search
            study = basic_tuning_loop(
                X_train=X_train,
                X_test=X_val,
                y_train=y_train,
                y_test=y_val,
                name=baseline_algorithm,
                is_regression=is_regression,
                metric=perf_metric,
                global_seed=seed,
                n_trials=n_trials,
                n_startup_trials=n_startup_trials,
            )

            # The random state is stored with the hyper-params so a later run can reuse it
            params = study.best_params
            params["random_state"] = seed + study.best_trial.number

            logger.info(f"Saving the best hyper-params to {hparams_path}")
            with fsspec.open(hparams_path, "w") as fd:
                yaml.dump(params, fd)

            trials_path = dm.fs.join(out_dir, f"trials_{suffix}_{seed}.csv")
            logger.info(f"Saving the trials dataframe to {trials_path}")
            study.trials_dataframe().to_csv(trials_path)

        random_state = params.pop("random_state")
        model = train_baseline_model(
            X_train,
            y_train,
            baseline_algorithm,
            is_regression,
            params,
            random_state,
            for_uncertainty_estimation=True,
            ensemble_size=10,
        )

        y_pred.append(model.predict(X_test))
        y_true.append(y_test)
        y_uncertainty.append(predict_baseline_uncertainty(model, X_test))

        dist_train_, dist_test_ = compute_knn_distance(X_train, [X_train, X_test])
        dist_train.append(dist_train_)
        dist_test.append(dist_test_)

    dist_test = np.concatenate(dist_test)
    dist_train = np.concatenate(dist_train)
    y_pred = np.concatenate(y_pred)
    y_true = np.concatenate(y_true)
    y_uncertainty = np.concatenate(y_uncertainty)

    # Collect the distances of the downstream applications
    dist_scr = load_distances_for_downstream_application(
        "virtual_screening", representation, dataset, update_cache=True
    )
    dist_opt = load_distances_for_downstream_application(
        "optimization", representation, dataset, update_cache=True
    )
    dist_app = np.concatenate((dist_opt, dist_scr))

    # Compute the difference in IID and OOD performance and calibration.
    # IID: test samples within the central 95% of the train-to-train distances.
    lower, upper = np.quantile(dist_train, 0.025), np.quantile(dist_train, 0.975)
    mask = np.logical_and(dist_test >= lower, dist_test <= upper)
    score_iid = perf_metric(y_true[mask], y_pred[mask])
    calibration_iid = cali_metric(y_true[mask], y_pred[mask], y_uncertainty[mask])
    logger.info(f"Found an IID {perf_metric.name} score of {score_iid:.3f}")
    logger.info(f"Found an IID {cali_metric.name} calibration score of {calibration_iid:.3f}")

    # OOD: test samples within the central 95% of the downstream-application distances.
    lower, upper = np.quantile(dist_app, 0.025), np.quantile(dist_app, 0.975)
    mask = np.logical_and(dist_test >= lower, dist_test <= upper)
    score_ood = perf_metric(y_true[mask], y_pred[mask])
    calibration_ood = cali_metric(y_true[mask], y_pred[mask], y_uncertainty[mask])
    logger.info(f"Found an OOD {perf_metric.name} score of {score_ood:.3f}")
    logger.info(f"Found an OOD {cali_metric.name} calibration score of {calibration_ood:.3f}")

    # Saving this as a CSV might be a bit wasteful,
    # but it's convenient
    logger.info(f"Saving the IID/OOD gap data to {gap_path}")

    pd.DataFrame(
        {
            "dataset": dataset,
            "algorithm": baseline_algorithm,
            "representation": representation,
            "iid_score": [score_iid, calibration_iid],
            "ood_score": [score_ood, calibration_ood],
            "metric": [perf_metric.name, cali_metric.name],
            "type": ["performance", "calibration"],
        }
    ).to_csv(gap_path, index=False)

    # Compute the performance over distance
    frames = []
    for distance, mask in tqdm(bin_with_overlap(dist_test)):
        target = y_true[mask]
        preds = y_pred[mask]
        uncertainty = y_uncertainty[mask]

        # Count the samples that were actually selected. len(mask) is the length of the
        # mask itself, which for a boolean mask is the size of the full test set.
        n_samples = len(target)
        if n_samples < 25 or len(np.unique(target)) == 1:
            continue

        perf_mu, perf_std = compute_bootstrapped_metric(
            targets=target, predictions=preds, metric=perf_metric, n_jobs=-1
        )

        cali_mu, cali_std = compute_bootstrapped_metric(
            targets=target, predictions=preds, uncertainties=uncertainty, metric=cali_metric, n_jobs=-1
        )

        frames.append(
            pd.DataFrame(
                {
                    "dataset": dataset,
                    "algorithm": baseline_algorithm,
                    "representation": representation,
                    "distance": distance,
                    "score_mu": [perf_mu, cali_mu],
                    "score_std": [perf_std, cali_std],
                    "type": ["performance", "calibration"],
                    "metric": [perf_metric.name, cali_metric.name],
                    "n_samples": n_samples,
                }
            )
        )

    # Concatenate once instead of growing a dataframe inside the loop
    df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    logger.info(f"Saving the performance over distance data to {perf_path}")
    df.to_csv(perf_path, index=False)
def train_gp(X, y, is_regression):
    """Fit a Gaussian Process with a linear kernel, retrying with more noise on numerical failure.

    Starting at ``1e-10``, the noise term is multiplied by 10 after every failed fit,
    for at most 10 attempts.

    Args:
        X: Training features.
        y: Training targets.
        is_regression: Fit a regressor if True, otherwise a classifier.

    Returns:
        The fitted model, or None if all attempts failed.
    """
    alpha = 1e-10
    for _ in range(10):
        try:
            if is_regression:
                kernel = PairwiseKernel(metric="linear")
                model = GaussianProcessRegressor(kernel, alpha=alpha, random_state=0)
            else:
                # The classifier is not given an alpha here; the noise is added through the kernel
                kernel = Sum(PairwiseKernel(metric="linear"), WhiteKernel(noise_level=alpha))
                model = GaussianProcessClassifier(kernel, random_state=0)
            model.fit(X, y)
            return model
        except (np.linalg.LinAlgError, ValueError):
            # ValueError: array must not contain infs or NaNs
            # LinAlgError: N-th leading minor of the array is not positive definite
            # LinAlgError: The kernel is not returning a positive definite matrix
            alpha = alpha * 10
    return None


def get_model_space_distances(model, train, queries):
    """Compute kNN distances from each query set to the train set in the model space.

    The embedding size passed to ``ModelSpaceTransformer`` is 25% of the number of
    input features (rounded).

    Args:
        model: Trained model that defines the model space.
        train: Train set features.
        queries: List of query feature sets.

    Returns:
        The result of ``compute_knn_distance`` for the transformed train set and queries.
    """
    embedding_size = int(round(train.shape[1] * 0.25))
    trans = ModelSpaceTransformer(model, embedding_size)

    model_space_train = trans(train)
    model_space_queries = [trans(q) for q in queries]

    return compute_knn_distance(model_space_train, model_space_queries, n_jobs=-1)


def compute_correlations(input_spaces, model_spaces, labels):
    """Correlate input-space and model-space distances after removing outliers.

    A sample is kept only if it lies within the outlier bounds (factor 3.0, computed on the
    pooled distances) in both spaces.

    Args:
        input_spaces: List of 1D arrays with distances in the input space.
        model_spaces: List of 1D arrays with distances in the model space, aligned with ``input_spaces``.
        labels: One label per pair of arrays.
            NOTE(review): the labels only bound the iteration and are not written to the
            result; rows follow the order of the inputs.

    Returns:
        A dataframe with one row per pair and the columns ``pearson``, ``spearman`` and ``r_squared``.
    """
    lower, upper = get_outlier_bounds(np.concatenate(input_spaces), factor=3.0)
    input_masks = [(X >= lower) & (X <= upper) for X in input_spaces]

    lower, upper = get_outlier_bounds(np.concatenate(model_spaces), factor=3.0)
    model_masks = [(X >= lower) & (X <= upper) for X in model_spaces]

    masks = [mask1 & mask2 for mask1, mask2 in zip(input_masks, model_masks)]
    input_spaces = [d[mask] for d, mask in zip(input_spaces, masks)]
    model_spaces = [d[mask] for d, mask in zip(model_spaces, masks)]

    rows = [
        {
            "pearson": pearsonr(input_space, model_space)[0],
            "spearman": spearmanr(input_space, model_space)[0],
            "r_squared": r2_score(input_space, model_space),
        }
        for input_space, model_space, _label in zip(input_spaces, model_spaces, labels)
    ]
    return pd.DataFrame(rows)


def _save_distances(path, distances, overwrite):
    """Save *distances* as a .npy file at *path*, refusing to replace a file unless *overwrite*."""
    if dm.fs.exists(path) and not overwrite:
        raise RuntimeError(f"{path} already exists. Specify --overwrite or --base-save-dir.")
    with fsspec.open(path, "wb") as fd:
        np.save(fd, distances)


def cli(
    base_save_dir: str = RESULTS_DIR,
    sub_save_dir: Optional[str] = None,
    overwrite: bool = False,
    skip_representation: Optional[List[str]] = None,
    skip_dataset: Optional[List[str]] = None,
    batch_size: int = 16,
):
    """Compare distances in the input space with distances in various model spaces.

    For every dataset and representation, an MLP, RF and GP are trained. The kNN distances of
    the train set, the optimization set and the virtual screening set to the train set are
    computed in the input space and in each model space, saved as .npy files, plotted, and
    correlated. All correlations are written to a single ``correlations.csv``.

    Args:
        base_save_dir: Root directory for the results.
        sub_save_dir: Sub-directory for this run. Defaults to today's date (``YYYYMMDD``).
        overwrite: Whether existing result files may be replaced.
        skip_representation: Representations to skip.
        skip_dataset: Datasets to skip.
        batch_size: Batch size forwarded to ``representation_iterator``.

    Raises:
        typer.BadParameter: If ``correlations.csv`` already exists and ``overwrite`` is False.
        RuntimeError: If a .npy file already exists and ``overwrite`` is False.
    """
    if sub_save_dir is None:
        sub_save_dir = datetime.now().strftime("%Y%m%d")

    fig_save_dir = dm.fs.join(base_save_dir, "figures", "compare_spaces", sub_save_dir)
    dm.fs.mkdir(fig_save_dir, exist_ok=True)
    logger.info(f"Saving figures to {fig_save_dir}")

    np_save_dir = dm.fs.join(base_save_dir, "numpy", "compare_spaces", sub_save_dir)
    dm.fs.mkdir(np_save_dir, exist_ok=True)
    logger.info(f"Saving NumPy arrays to {np_save_dir}")

    df_save_dir = dm.fs.join(base_save_dir, "dataframes", "compare_spaces", sub_save_dir)
    dm.fs.mkdir(df_save_dir, exist_ok=True)
    corr_path = dm.fs.join(df_save_dir, "correlations.csv")
    # Check the file, not the directory: the directory was just created above, so testing it
    # would always abort the run unless --overwrite was given.
    if dm.fs.exists(corr_path) and not overwrite:
        raise typer.BadParameter(f"{corr_path} already exists. Specify --overwrite or --base-save-dir.")
    logger.info(f"Saving correlation dataframe to {corr_path}")

    corr_frames = []
    labels = ["Train", "Optimization", "Virtual Screening"]

    dataset_it = dataset_iterator(blacklist=skip_dataset)

    for dataset, (smiles, y) in dataset_it:
        representation_it = representation_iterator(
            smiles,
            n_jobs=-1,
            progress=True,
            blacklist=skip_representation,
            standardize_fn=DEFAULT_PREPROCESSING,
            batch_size=batch_size,
        )

        for representation, (X, mask) in representation_it:
            y_repr = y[mask]

            virtual_screening = load_representation_for_downstream_application(
                "virtual_screening", representation, update_cache=True
            )
            optimization = load_representation_for_downstream_application(
                "optimization", representation, update_cache=True
            )

            is_regression = dataset in MOOD_REGR_DATASETS
            mlp_model = train_baseline_model(X, y_repr, "MLP", is_regression)
            rf_model = train_baseline_model(X, y_repr, "RF", is_regression)
            # We use a custom train function for GPs to include a retry system
            gp_model = train_gp(X, y_repr, is_regression)

            # Distances in input spaces
            input_distances = compute_knn_distance(X, [X, optimization, virtual_screening], n_jobs=-1)

            for dist, label in zip(input_distances, labels):
                path = dm.fs.join(np_save_dir, f"{dataset}_{representation}_{label}_input_space.npy")
                _save_distances(path, dist, overwrite)

            ax = plot_distance_distributions(input_distances, labels=labels)
            ax.set_title(f"Input space ({representation}, {dataset})")
            save_figure_with_fsspec(
                dm.fs.join(fig_save_dir, f"{dataset}_{representation}_input_space.png"), exist_ok=overwrite
            )
            plt.close()

            # Distances in different model spaces
            for name, model in {"MLP": mlp_model, "RF": rf_model, "GP": gp_model}.items():
                if model is None:
                    logger.warning(f"Failed to train a {name} model for {dataset} on {representation}")
                    continue

                model_distances = get_model_space_distances(model, X, [X, optimization, virtual_screening])
                for dist, label in zip(model_distances, labels):
                    path = dm.fs.join(np_save_dir, f"{dataset}_{representation}_{label}_{name}_space.npy")
                    _save_distances(path, dist, overwrite)

                ax = plot_distance_distributions(model_distances, labels=labels)
                ax.set_title(f"{name} space ({representation}, {dataset})")
                save_figure_with_fsspec(
                    dm.fs.join(fig_save_dir, f"{dataset}_{representation}_{name}_space.png"),
                    exist_ok=overwrite,
                )
                plt.close()

                # Compute correlation
                df = compute_correlations(
                    input_distances,
                    model_distances,
                    labels,
                )
                df["model"] = name
                df["dataset"] = dataset
                df["representation"] = representation
                corr_frames.append(df)

    # Concatenate once instead of growing a dataframe inside the loops
    df_corr = pd.concat(corr_frames, ignore_index=True) if corr_frames else pd.DataFrame()
    df_corr.to_csv(corr_path, index=False)
False, 31 | n_jobs: Optional[int] = None, 32 | ): 33 | df = pd.DataFrame() 34 | 35 | if sub_save_dir is None: 36 | sub_save_dir = datetime.now().strftime("%Y%m%d") 37 | 38 | fig_out_dir = dm.fs.join(base_save_dir, "figures", "compare_splits", sub_save_dir) 39 | dm.fs.mkdir(fig_out_dir, exist_ok=True) 40 | 41 | data_out_dir = dm.fs.join(base_save_dir, "numpy", "compare_splits", sub_save_dir) 42 | dm.fs.mkdir(data_out_dir, exist_ok=True) 43 | 44 | dataset_it = dataset_iterator(blacklist=skip_dataset) 45 | for dataset, (smiles, y) in dataset_it: 46 | representation_it = representation_iterator( 47 | smiles, 48 | n_jobs=n_jobs, 49 | progress=verbose, 50 | standardize_fn=DEFAULT_PREPROCESSING, 51 | batch_size=batch_size, 52 | blacklist=skip_representation, 53 | ) 54 | 55 | for representation, (X, mask) in representation_it: 56 | logger.info(f"Loading precomputed distances for virtual screening") 57 | distances_vs = load_distances_for_downstream_application( 58 | "virtual_screening", representation, dataset, update_cache=not use_cache 59 | ) 60 | 61 | logger.info(f"Loading precomputed distances for optimization") 62 | distances_op = load_distances_for_downstream_application( 63 | "optimization", representation, dataset, update_cache=not use_cache 64 | ) 65 | 66 | metric = get_distance_metric(X) 67 | if metric == "jaccard": 68 | X = X.astype(bool) 69 | 70 | splitters = get_mood_splitters(smiles[mask], n_splits, seed, n_jobs=n_jobs) 71 | splitter = MOODSplitter(splitters, np.concatenate((distances_vs, distances_op)), metric, k=5) 72 | splitter.fit(X, progress=verbose, plot=save_figures) 73 | 74 | if save_figures: 75 | out_path = dm.fs.join(fig_out_dir, f"fig_{dataset}_{representation}.png") 76 | if not overwrite and dm.fs.exists(out_path): 77 | raise RuntimeError( 78 | f"{out_path} already exists. 
Specify a different path or use --overwrite" 79 | ) 80 | logger.info(f"Saving figure to {out_path}") 81 | save_figure_with_fsspec(out_path, exist_ok=overwrite) 82 | 83 | for char in splitter._split_chars: 84 | out_path = dm.fs.join(data_out_dir, f"distances_{dataset}_{representation}_{char.label}.npy") 85 | if not overwrite and dm.fs.exists(out_path): 86 | raise RuntimeError( 87 | f"{out_path} already exists. Specify a different path or use --overwrite" 88 | ) 89 | logger.info(f"Saving distance data to {out_path}") 90 | with fsspec.open(out_path, "wb") as fd: 91 | np.save(fd, char.distances) 92 | 93 | df_ = splitter.get_protocol_results() 94 | df_["representation"] = representation 95 | df_["dataset"] = dataset 96 | 97 | df = pd.concat((df, df_), ignore_index=True) 98 | 99 | out_dir = dm.fs.join(base_save_dir, "dataframes", "compare_splits", sub_save_dir) 100 | dm.fs.mkdir(out_dir, exist_ok=True) 101 | 102 | out_path = dm.fs.join(out_dir, "splits.csv") 103 | if not overwrite and dm.fs.exists(out_path): 104 | raise RuntimeError(f"{out_path} already exists. 
def save(distances, compounds, molecule_set, dataset, representation, overwrite):
    """Save the distances of a downstream-application set to a dataset as a .parquet file.

    The file is written to
    ``<DOWNSTREAM_APPS_DATA_DIR>/distances/<molecule_set>/<dataset>/<representation>.parquet``.

    Args:
        distances: Distance per compound.
        compounds: Unique identifier per compound, aligned with ``distances``.
        molecule_set: Name of the downstream application.
        dataset: Name of the dataset the distances were computed to.
        representation: Name of the representation the distances were computed in.
        overwrite: Whether an existing file may be replaced.

    Returns:
        The dataframe that was saved, with the columns ``unique_id`` and ``distance``.

    Raises:
        RuntimeError: If the file already exists and ``overwrite`` is False.
    """
    out_path = dm.fs.join(
        DOWNSTREAM_APPS_DATA_DIR, "distances", molecule_set, dataset, f"{representation}.parquet"
    )
    if dm.fs.exists(out_path) and not overwrite:
        raise RuntimeError(f"{out_path} already exists!")

    df = pd.DataFrame({"unique_id": compounds, "distance": distances})
    df.to_parquet(out_path)
    return df


def cli(
    overwrite: bool = False,
    representation: Optional[List[str]] = None,
    dataset: Optional[List[str]] = None,
    verbose: bool = False,
):
    """Precompute the distances from the downstream applications to each dataset.

    Args:
        overwrite: Whether existing .parquet files may be replaced.
        representation: Representations to include. None or empty means all.
        dataset: Datasets to include. None or empty means all.
        verbose: Currently unused; progress bars are always shown.
    """
    # Both None and an empty collection mean "no whitelist". Calling len() on the default
    # value None raised a TypeError.
    dataset_whitelist = dataset if dataset else None
    representation_whitelist = representation if representation else None

    # The loop variables must not reuse the argument names: rebinding `representation` in
    # the inner loop replaced the whitelist for every dataset after the first one.
    for dataset_name, (smiles, y) in dataset_iterator(whitelist=dataset_whitelist, disable_logs=True):
        it = representation_iterator(
            smiles,
            standardize_fn=DEFAULT_PREPROCESSING,
            n_jobs=-1,
            progress=True,
            whitelist=representation_whitelist,
            disable_logs=True,
        )

        for representation_name, (X, mask) in it:
            virtual_screening, vs_compounds = load_representation_for_downstream_application(
                "virtual_screening",
                representation_name,
                return_compound_ids=True,
                update_cache=True,
            )
            optimization, opt_compounds = load_representation_for_downstream_application(
                "optimization",
                representation_name,
                return_compound_ids=True,
                update_cache=True,
            )

            # The results are ordered like the list of queries: optimization first
            opt_distances, vs_distances = compute_knn_distance(
                X, [optimization, virtual_screening], n_jobs=-1
            )

            save(opt_distances, opt_compounds, "optimization", dataset_name, representation_name, overwrite)
            save(vs_distances, vs_compounds, "virtual_screening", dataset_name, representation_name, overwrite)
dm.fs.join(DOWNSTREAM_APPS_DATA_DIR, f"{molecule_set}.csv") 42 | out_path = dm.fs.join( 43 | DOWNSTREAM_APPS_DATA_DIR, "representations", molecule_set, f"Graphormer.parquet" 44 | ) 45 | 46 | if dm.fs.exists(out_path) and not overwrite: 47 | raise ValueError(f"{out_path} already exists! Use --overwrite to overwrite!") 48 | 49 | # Load 50 | logger.info(f"Loading SMILES from {in_path}") 51 | df = pd.read_csv(in_path) 52 | 53 | # Standardization 54 | df["smiles"] = dm.utils.parallelized( 55 | standardize_fn, 56 | df["canonical_smiles"].values, 57 | n_jobs=-1, 58 | progress=verbose, 59 | ) 60 | 61 | # Setting max length. We don't ignore padding tokens, so best to do this per dataset 62 | graphormer.set_max_length(graphormer.compute_max_length(df["smiles"].values)) 63 | logger.info(f"Computed a max number of nodes of {graphormer.max_length}") 64 | 65 | # Compute the representation 66 | logger.info(f"Precomputing Graphormer representation") 67 | feats = graphormer.batch_transform( 68 | graphormer, df["smiles"].values, batch_size=batch_size, n_jobs=None 69 | ) 70 | 71 | df["representation"] = list(feats) 72 | df = df[~pd.isna(df["representation"])] 73 | 74 | # Save 75 | logger.info(f"Saving results to {out_path}") 76 | df[["unique_id", "representation"]].to_parquet(out_path) 77 | 78 | blacklist = [app for app in MOOD_DATASETS if app in skip] 79 | for dataset, (smiles, _) in dataset_iterator(blacklist=blacklist, disable_logs=True): 80 | logger.info(f"Dataset {dataset}") 81 | out_path = dm.fs.join(DATASET_DATA_DIR, "representations", dataset, f"Graphormer.parquet") 82 | 83 | if dm.fs.exists(out_path) and not overwrite: 84 | raise ValueError(f"{out_path} already exists! Use --overwrite to overwrite!") 85 | 86 | df = pd.DataFrame() 87 | df["smiles"] = dm.utils.parallelized( 88 | standardize_fn, 89 | smiles, 90 | n_jobs=-1, 91 | progress=verbose, 92 | ) 93 | 94 | # Setting max length. 
def cli(
    molecule_set: str,
    representation: str,
    n_jobs: Optional[int] = None,
    batch_size: int = 16,
    verbose: bool = False,
    override: bool = False,
):
    """Featurize a downstream-application set and store the result as a .parquet file.

    The SMILES are read from ``<DOWNSTREAM_APPS_DATA_DIR>/<molecule_set>.csv`` and the result is
    written to ``<DOWNSTREAM_APPS_DATA_DIR>/representations/<molecule_set>/<representation>.parquet``.
    Rows for which the representation is missing are dropped.

    Args:
        molecule_set: Downstream application; must be in ``SUPPORTED_DOWNSTREAM_APPS``.
        representation: Representation to compute; must be in ``MOOD_REPRESENTATIONS``.
        n_jobs: Forwarded to ``featurize``.
        batch_size: Forwarded to ``featurize``.
        verbose: Forwarded to ``featurize`` as ``progress``.
        override: Whether an existing output file may be replaced.

    Raises:
        typer.BadParameter: If ``molecule_set`` or ``representation`` is not supported.
        ValueError: If the output file already exists and ``override`` is False.
    """
    # Validate the arguments
    molecule_set_is_known = molecule_set in SUPPORTED_DOWNSTREAM_APPS
    if not molecule_set_is_known:
        raise typer.BadParameter(f"--molecule-set should be one of {SUPPORTED_DOWNSTREAM_APPS}.")
    representation_is_known = representation in MOOD_REPRESENTATIONS
    if not representation_is_known:
        raise typer.BadParameter(f"--representation should be one of {MOOD_REPRESENTATIONS}.")

    source_path = dm.fs.join(DOWNSTREAM_APPS_DATA_DIR, f"{molecule_set}.csv")
    target_path = dm.fs.join(
        DOWNSTREAM_APPS_DATA_DIR, "representations", molecule_set, f"{representation}.parquet"
    )

    target_exists = dm.fs.exists(target_path)
    if target_exists and not override:
        raise ValueError(f"{target_path} already exists! Use --override to override!")

    # Load
    logger.info(f"Loading SMILES from {source_path}")
    molecules = pd.read_csv(source_path)

    # Compute the representation with the default standardization for it
    preprocess = DEFAULT_PREPROCESSING[representation]
    logger.info(f"Precomputing {representation} representation")
    features = featurize(
        molecules["canonical_smiles"].values,
        representation,
        preprocess,
        n_jobs=n_jobs,
        progress=verbose,
        batch_size=batch_size,
        return_mask=False,
        disable_logs=True,
    )
    molecules["representation"] = list(features)

    # Drop the molecules without a representation
    has_representation = molecules["representation"].notna()
    molecules = molecules[has_representation]

    # Save
    logger.info(f"Saving results to {target_path}")
    molecules[["unique_id", "representation"]].to_parquet(target_path)
def cli(
    dataset: str,
    representation: str,
    model_space: Optional[str] = None,
):
    """Plot k-NN distance distributions from a dataset to the downstream sets.

    Loads ``dataset`` through ``load_data_from_tdc``, featurizes it with
    ``representation`` and compares it against the precomputed
    ``virtual_screening`` and ``optimization`` representations. The k-NN
    distances are computed against the dataset itself and both downstream
    sets, then plotted and shown.

    Args:
        dataset: Name passed to ``load_data_from_tdc``.
        representation: Key into ``DEFAULT_PREPROCESSING``; also selects the
            precomputed downstream representations.
        model_space: If given, a baseline model of this kind is trained on the
            dataset and all three feature sets are mapped through a
            ``ModelSpaceTransformer`` before distances are computed.
    """
    smiles, y = load_data_from_tdc(dataset)
    standardize_fn = DEFAULT_PREPROCESSING[representation]
    X, mask = featurize(smiles, representation, standardize_fn, disable_logs=True)
    # Keep the targets aligned with the rows kept by the featurizer
    # (presumably the successfully featurized molecules; confirm in featurize).
    y = y[mask]

    logger.info("Loading precomputed representations for virtual screening")
    virtual_screening = load_representation_for_downstream_application("virtual_screening", representation)

    logger.info("Loading precomputed representations for optimization")
    optimization = load_representation_for_downstream_application("optimization", representation)

    if model_space is not None:
        logger.info(f"Computing distance in the {model_space} model space")
        is_regression = dataset in MOOD_REGR_DATASETS
        model = train_baseline_model(X, y, model_space, is_regression)
        # The embedding size is a quarter of the input feature dimension.
        embedding_size = int(round(X.shape[1] * 0.25))
        trans = ModelSpaceTransformer(model, embedding_size)

        X = trans(X)
        virtual_screening = trans(virtual_screening)
        optimization = trans(optimization)

    logger.info("Computing the k-NN distance")
    distances = compute_knn_distance(X, [X, optimization, virtual_screening], n_jobs=-1)

    logger.info("Plotting the results")
    # Label order must match the order of the query sets passed above.
    labels = ["Train", "Optimization", "Virtual Screening"]
    plot_distance_distributions(distances, labels=labels)
    plt.show()
get_distance_metric(X) 38 | if metric == "jaccard": 39 | X = X.astype(bool) 40 | 41 | splitters = get_mood_splitters(smiles[mask], n_splits, seed) 42 | splitter = MOODSplitter(splitters, np.concatenate((distances_vs, distances_op)), metric, k=5) 43 | ax = splitter.fit(X, y, plot=True, progress=False) 44 | plt.show() 45 | --------------------------------------------------------------------------------