5 | Close the testing-deployment gap in molecular scoring.
6 |
7 |
8 |
9 | ---
10 |
11 | [Code style: black](https://github.com/psf/black)
12 | [Paper DOI: 10.1021/acs.jcim.3c01774](https://doi.org/10.1021/acs.jcim.3c01774)
13 |
14 |
15 | This repository contains all the code used for the [MOOD paper](https://doi.org/10.1021/acs.jcim.3c01774).
16 |
17 | ## Setup
18 | We recommend using `mamba` ([learn more](https://github.com/mamba-org/mamba)).
19 |
20 | ```shell
21 | mamba env create -n mood -f env.yml
22 | conda activate mood
23 | pip install -e .
24 | ```
25 |
26 | ## Overview
27 | The repository is set up to make the results easy to reproduce. If you get stuck or would like to learn more, please feel free to open an issue.
28 |
29 | ### CLI
30 | After installation, the MOOD CLI can be used to reproduce the results.
31 | ```shell
32 | mood --help
33 | ```
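The main entrypoints are registered in `mood/cli.py`; consult `--help` for the exact arguments of each command:

```shell
mood tune --help  # Hyper-param search for a model with a specific configuration
mood rct --help   # Randomly sample a configuration for training
```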
34 |
35 | ### Data
36 | All data has been made available in a public bucket. See [`https://storage.valencelabs.com/mood-data/`](https://storage.valencelabs.com/mood-data/).
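These locations are also exposed as constants in `mood/constants.py`, so they can be referenced programmatically:

```python
from mood.constants import DOWNSTREAM_APPS_DATA_DIR, DATASET_DATA_DIR, RESULTS_DIR

print(DOWNSTREAM_APPS_DATA_DIR)  # Data for the downstream applications
print(DATASET_DATA_DIR)          # Data related to specific datasets
print(RESULTS_DIR)               # Where the results of MOOD are saved
```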
37 |
38 | ### Code
39 | - `mood/`: The main part of the codebase. It contains Python implementations of several reusable components and defines the CLI (see the example below).
40 | - `notebooks/`: Notebooks used to visualize and otherwise explore the results. All plots in the paper can be reproduced through these notebooks.
41 | - `scripts/`: Messier pieces of code that were used to generate (intermediate) results.
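For example, the dataset utilities in `mood/` can be used on their own. A minimal sketch (note that this downloads the `BBB` dataset from TDC on first use):

```python
from mood.dataset import MOOD_DATASETS, load_data_from_tdc

print(MOOD_DATASETS)  # The TDC datasets used in the MOOD study, ordered by size

# Returns standardized SMILES strings and the corresponding targets
smiles, y = load_data_from_tdc("BBB")
```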
42 |
43 | ## Use the MOOD splitting protocol
44 |
45 |
46 |
47 |
48 | One of the main results of the MOOD paper is the MOOD protocol. This protocol helps close the testing-deployment gap in molecular scoring by finding the splitting method that is most representative of deployment. We made an effort to make it easy for others to experiment with this protocol.
49 |
50 | ```python
51 | import datamol as dm
52 | import numpy as np
53 | from sklearn.model_selection import ShuffleSplit
54 | from mood.splitter import PerimeterSplit, MaxDissimilaritySplit, PredefinedGroupShuffleSplit, MOODSplitter
55 |
56 | # Load your data
57 | data = dm.data.freesolv()
58 | smiles = data["smiles"].values
59 | X = np.stack([dm.to_fp(dm.to_mol(smi)) for smi in smiles])
60 | y = data["expt"].values
61 |
62 | # Load your deployment data
63 | X_deployment = np.random.random((100, 2048)).round()
64 |
65 | # Set-up your candidate splitting methods
66 | scaffolds = [dm.to_smiles(dm.to_scaffold_murcko(dm.to_mol(smi))) for smi in smiles]
67 | candidate_splitters = {
68 | "Random": ShuffleSplit(n_splits=5), # MOOD is Scikit-learn compatible!
69 | "Scaffold": PredefinedGroupShuffleSplit(groups=scaffolds, n_splits=5),
70 | "Perimeter": PerimeterSplit(n_splits=5),
71 | "Maximum Dissimilarity": MaxDissimilaritySplit(n_splits=5),
72 | }
73 |
74 | # Set-up the MOOD splitter
75 | mood_splitter = MOODSplitter(candidate_splitters, metric="jaccard")
76 | mood_splitter.fit(X, y, X_deployment=X_deployment)
77 |
78 | for train, test in mood_splitter.split(X, y):
79 | # Work your magic!
80 | ...
81 | ```
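Roughly speaking, `fit()` characterizes each candidate splitter by how closely the distances from its test folds to its train folds resemble the distances from the deployment set to the training data, and prescribes the best-matching candidate; `split()` then yields the train-test indices of that prescribed splitter. This is what makes `MOODSplitter` a drop-in, scikit-learn-compatible replacement for the candidates it wraps.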
82 |
83 | ## How to cite
84 | Please cite MOOD if you use it in your research: [DOI: 10.1021/acs.jcim.3c01774](https://doi.org/10.1021/acs.jcim.3c01774)
85 |
86 | ```
87 | Real-World Molecular Out-Of-Distribution: Specification and Investigation
88 | Prudencio Tossou, Cas Wognum, Michael Craig, Hadrien Mary, and Emmanuel Noutahi
89 | Journal of Chemical Information and Modeling 2024 64 (3), 697-711
90 | DOI: 10.1021/acs.jcim.3c01774
91 | ```
92 |
93 | ```bib
94 | @article{doi:10.1021/acs.jcim.3c01774,
95 | author = {Tossou, Prudencio and Wognum, Cas and Craig, Michael and Mary, Hadrien and Noutahi, Emmanuel},
96 | title = {Real-World Molecular Out-Of-Distribution: Specification and Investigation},
97 | journal = {Journal of Chemical Information and Modeling},
98 | volume = {64},
99 | number = {3},
100 | pages = {697-711},
101 | year = {2024},
102 | doi = {10.1021/acs.jcim.3c01774},
103 | note = {PMID: 38300258},
104 | URL = {https://doi.org/10.1021/acs.jcim.3c01774},
105 | eprint = {https://doi.org/10.1021/acs.jcim.3c01774}
106 | }
107 | ```
108 |
109 |
--------------------------------------------------------------------------------
/docs/images/logo.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/docs/images/protocol.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/valence-labs/mood-experiments/4788e0c57f557916792247eadebbe61d2fa91714/docs/images/protocol.png
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: mood_v2
2 |
3 | channels:
4 | - conda-forge
5 |
6 | dependencies:
7 | - pip
8 | - python =3.10
9 | - pandas
10 | - matplotlib
11 | - scikit-learn
12 | - torchmetrics
13 | - pytorch-lightning <2.0
14 | - pytorch >=1.10.2
15 | - numpy <1.24
16 | - tqdm
17 | - optuna
18 | - datamol
19 | - notebook
20 | - pytdc
21 | - typer
22 | - gcsfs
23 | - pyarrow
24 | - fastparquet
25 | - transformers
26 |
27 | # Dev
28 | - black >=20.8b1
29 |
30 |
--------------------------------------------------------------------------------
/mood/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/mood/__main__.py:
--------------------------------------------------------------------------------
1 | from mood.cli import app
2 |
3 |
4 | if __name__ == "__main__":
5 | app()
6 |
--------------------------------------------------------------------------------
/mood/baselines.py:
--------------------------------------------------------------------------------
1 | import optuna
2 | import numpy as np
3 | import datamol as dm
4 |
5 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier, VotingRegressor, VotingClassifier
6 | from sklearn.neural_network import MLPRegressor, MLPClassifier
7 | from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier
8 | from sklearn.gaussian_process.kernels import PairwiseKernel, Sum, WhiteKernel
9 | from sklearn.calibration import CalibratedClassifierCV
10 | from sklearn.base import ClassifierMixin, clone
11 |
12 |
13 | """
14 | These are the three baselines we consider in the MOOD study
15 | For benchmarking, we use a torch MLP rather than the one from scikit-learn.
16 | """
17 | MOOD_BASELINES = ["MLP", "RF", "GP"]
18 |
19 |
20 | def get_baseline_cls(name, is_regression):
21 | """Simple method that allows a model to be identified by its name"""
22 |
23 | target_type = "regression" if is_regression else "classification"
24 | data = {
25 | "MLP": {
26 | "regression": MLPRegressor,
27 | "classification": MLPClassifier,
28 | },
29 | "RF": {
30 | "regression": RandomForestRegressor,
31 | "classification": RandomForestClassifier,
32 | },
33 | "GP": {
34 | "regression": GaussianProcessRegressor,
35 | "classification": GaussianProcessClassifier,
36 | },
37 | }
38 |
39 | return data[name][target_type]
40 |
41 |
42 | def get_baseline_model(
43 | name: str,
44 | is_regression: bool,
45 | params: dict,
46 | for_uncertainty_estimation: bool = False,
47 | ensemble_size: int = 10,
48 | calibrate: bool = True,
49 | ):
50 | """Entrypoint for constructing a baseline model from scikit-learn"""
51 | model = get_baseline_cls(name, is_regression)(**params)
52 | if for_uncertainty_estimation:
53 | model = uncertainty_wrapper(model, ensemble_size, calibrate)
54 | return model
55 |
56 |
57 | def uncertainty_wrapper(model, ensemble_size: int = 10, calibrate: bool = True):
58 | """Wraps the model so that it can be used for uncertainty estimation.
59 |     This includes at most two steps: turning MLPs into an ensemble and calibrating RF and MLP classifiers
60 | """
61 | if isinstance(model, MLPClassifier) or isinstance(model, MLPRegressor):
62 | models = []
63 | for idx in range(ensemble_size):
64 |             estimator = clone(model)  # clone the original model each time, not the previous clone
65 |             estimator.set_params(random_state=model.random_state + idx)
66 |             models.append((f"mlp_{idx}", estimator))
67 |
68 | if isinstance(model, MLPClassifier):
69 | model = VotingClassifier(models, voting="soft", n_jobs=-1)
70 | else:
71 | model = VotingRegressor(models, n_jobs=-1)
72 |
73 |     if calibrate and isinstance(model, (RandomForestClassifier, VotingClassifier)):  # MLP classifiers were wrapped in a VotingClassifier above
74 | model = CalibratedClassifierCV(model)
75 | return model
76 |
77 |
78 | def predict_baseline_uncertainty(model, X):
79 | """Predicts the uncertainty of the model.
80 |
81 | For GP regressors, we use the included uncertainty estimation.
82 | For classifiers, the entropy of the prediction is used as uncertainty.
83 | For regressors, the variance of the prediction is used as uncertainty.
84 | """
85 | if isinstance(model, ClassifierMixin):
86 | uncertainty = model.predict_proba(X)
87 |
88 | elif isinstance(model, GaussianProcessRegressor):
89 | std = model.predict(X, return_std=True)[1]
90 | uncertainty = std**2
91 |
92 | else:
93 | # VotingRegressor or RandomForestRegressor
94 | preds = dm.utils.parallelized(lambda x: x.predict(X), model.estimators_, n_jobs=model.n_jobs)
95 | uncertainty = np.var(preds, axis=0)
96 |
97 | return uncertainty
98 |
99 |
100 | def suggest_mlp_hparams(trial, is_regression):
101 | """Sample the hyper-parameter search space for MLPs"""
102 | architectures = [[width] * depth for width in [64, 128, 256] for depth in range(1, 4)]
103 | arch = trial.suggest_categorical("hidden_layer_sizes", architectures)
104 | lr = trial.suggest_float("learning_rate_init", 1e-7, 1e0, log=True)
105 | alpha = trial.suggest_float("alpha", 1e-10, 1e0, log=True)
106 | max_iter = trial.suggest_int("max_iter", 1, 300)
107 | batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
108 |
109 | return {
110 | "max_iter": max_iter,
111 | "alpha": alpha,
112 | "learning_rate_init": lr,
113 | "hidden_layer_sizes": arch,
114 | "batch_size": batch_size,
115 | }
116 |
117 |
118 | def suggest_rf_hparams(trial, is_regression):
119 | """Sample the hyper-parameter search space for RFs"""
120 | n_estimators = trial.suggest_int("n_estimators", 100, 1000)
121 | max_depth = trial.suggest_categorical("max_depth", [None] + list(range(1, 11)))
122 |
123 | params = {
124 | "max_depth": max_depth,
125 | "n_estimators": n_estimators,
126 | }
127 | return params
128 |
129 |
130 | def construct_kernel(is_regression, params):
131 | """Constructs a scikit-learn kernel based on provided hyper-parameters"""
132 |
133 | metric = params.pop("kernel_metric")
134 | gamma = params.pop("kernel_gamma")
135 | coef0 = params.pop("kernel_coef0")
136 |
137 | if is_regression:
138 | kernel = PairwiseKernel(metric=metric, gamma=gamma, pairwise_kernels_kwargs={"coef0": coef0})
139 | else:
140 | noise_level = params.pop("kernel_noise_level")
141 | kernel = Sum(
142 | PairwiseKernel(
143 | metric=metric,
144 | gamma=gamma,
145 | pairwise_kernels_kwargs={"coef0": coef0},
146 | ),
147 | WhiteKernel(noise_level=noise_level),
148 | )
149 | return kernel, params
150 |
151 |
152 | def suggest_gp_hparams(trial, is_regression):
153 | """Sample the hyper-parameter search space for GPs"""
154 |
155 | kernel_types = ["linear", "poly", "polynomial", "rbf", "laplacian", "sigmoid", "cosine"]
156 | metric = trial.suggest_categorical("kernel_metric", kernel_types)
157 | gamma = trial.suggest_float("kernel_gamma", 1e-5, 1e0, log=True)
158 | coef0 = trial.suggest_float("kernel_coef0", 1e-5, 1.0, log=True)
159 | params = {"kernel_gamma": gamma, "kernel_coef0": coef0, "kernel_metric": metric}
160 |
161 | n_restarts_optimizer = trial.suggest_int("n_restarts_optimizer", 0, 10)
162 | params["n_restarts_optimizer"] = n_restarts_optimizer
163 |
164 | if is_regression:
165 | params["alpha"] = trial.suggest_float("alpha", 1e-10, 1e0, log=True)
166 | else:
167 | max_iter_predict = trial.suggest_int("max_iter_predict", 10, 250)
168 | noise_level = trial.suggest_float("kernel_noise_level", 1e-5, 1, log=True)
169 | params["kernel_noise_level"] = noise_level
170 | params["max_iter_predict"] = max_iter_predict
171 |
172 | return params
173 |
174 |
175 | def suggest_baseline_hparams(name: str, is_regression: bool, trial: optuna.Trial):
176 | """Endpoint for sampling the hyper-parameter search space of the baselines"""
177 | fs = {
178 | "MLP": suggest_mlp_hparams,
179 | "RF": suggest_rf_hparams,
180 | "GP": suggest_gp_hparams,
181 | }
182 | return fs[name](trial, is_regression)
183 |
--------------------------------------------------------------------------------
/mood/chemistry.py:
--------------------------------------------------------------------------------
1 | import datamol as dm
2 | from rdkit.Chem.Scaffolds import MurckoScaffold
3 | from loguru import logger
4 |
5 |
6 | def compute_murcko_scaffold(mol):
7 | """Computes the Bemis-Murcko scaffold of a compounds."""
8 | mol = dm.to_mol(mol)
9 | scaffold = dm.to_scaffold_murcko(mol)
10 | scaffold = dm.to_smiles(scaffold)
11 | return scaffold
12 |
13 |
14 | def compute_generic_scaffold(smi: str):
15 | """Computes the scaffold (i.e. the domain) for the datapoint. The generic scaffold is the
16 | structural graph of the Murcko scaffold
17 |
18 | Args:
19 | smi (str): The SMILES string of the molecule to find the generic scaffold for
20 | Returns:
21 | The SMILES of the Generic scaffold of the input SMILES
22 | """
23 |
24 | scaffold = compute_murcko_scaffold(smi)
25 | with dm.without_rdkit_log(mute_errors=False):
26 | try:
27 | scaffold = dm.to_mol(scaffold)
28 | scaffold = MurckoScaffold.MakeScaffoldGeneric(mol=scaffold)
29 | scaffold = dm.to_smiles(scaffold)
30 | except Exception as exception:
31 | logger.debug(f"Failed to compute the GenericScaffold for {smi} due to {exception}")
32 | logger.debug(f"Returning the empty SMILES as the scaffold")
33 | scaffold = ""
34 | return scaffold
35 |
--------------------------------------------------------------------------------
/mood/cli.py:
--------------------------------------------------------------------------------
1 | import typer
2 | from scripts.cli import app as scripts_app
3 | from mood.experiment import tune_cmd
4 | from mood.experiment import rct_cmd
5 |
6 |
7 | app = typer.Typer(add_completion=False)
8 | app.add_typer(scripts_app, name="scripts")
9 | app.command(name="tune", help="Hyper-param search for a model with a specific configuration")(tune_cmd)
10 | app.command(name="rct", help="Randomly sample a configuration for training")(rct_cmd)
11 |
12 |
13 | if __name__ == "__main__":
14 | app()
15 |
--------------------------------------------------------------------------------
/mood/constants.py:
--------------------------------------------------------------------------------
1 | import datamol as dm
2 |
3 |
4 | """
5 | Where results and data are saved to
6 | """
7 | CACHE_DIR = dm.fs.get_cache_dir("MOOD")
8 |
9 | """
10 | For the downstream applications (optimization and virtual screening)
11 | we save all related data to this directory
12 | """
13 | DOWNSTREAM_APPS_DATA_DIR = "https://storage.valencelabs.com/mood-data/downstream_applications/"
14 |
15 | """
16 | Where the results of MOOD are saved
17 | """
18 | RESULTS_DIR = "https://storage.valencelabs.com/mood-data/results/"
19 |
20 | """
21 | The two downstream applications we consider for MOOD as application areas of molecular scoring
22 | """
23 | SUPPORTED_DOWNSTREAM_APPS = ["virtual_screening", "optimization"]
24 |
25 | """
26 | Where data related to specific datasets is saved
27 | """
28 | DATASET_DATA_DIR = "https://storage.valencelabs.com/mood-data/datasets/"
29 |
30 |
31 | """The number of epochs to train NNs for"""
32 | NUM_EPOCHS = 100
33 |
--------------------------------------------------------------------------------
/mood/criteria.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Sequence
3 |
4 | import numpy as np
5 |
6 | from mood.dataset import SimpleMolecularDataset
7 | from mood.distance import compute_knn_distance, get_distance_metric
8 | from mood.metrics import Metric
9 |
10 |
11 | MOOD_CRITERIA = [
12 | "Performance",
13 | "Domain Weighted Performance",
14 | "Distance Weighted Performance",
15 | "Calibration",
16 | "Calibration x Performance",
17 | ]
18 |
19 |
20 | def get_mood_criteria(performance_metric, calibration_metric):
21 | """Endpoint for easily creating a criterion by name"""
22 |
23 | return {
24 | "Performance": PerformanceCriterion(performance_metric),
25 | "Calibration": CalibrationCriterion(calibration_metric),
26 | "Domain Weighted Performance": DomainWeightedPerformanceCriterion(performance_metric),
27 | "Distance Weighted Performance": DistanceWeightedPerformanceCriterion(performance_metric),
28 | "Calibration x Performance": CombinedCriterion(performance_metric, calibration_metric),
29 | }
30 |
31 |
32 | class ModelSelectionCriterion(abc.ABC):
33 | """
34 |     In MOOD, we argue that one of the tools to improve _model selection_ is the criterion we use
35 |     to select. Besides selecting for raw validation performance, we suspect there could be better alternatives.
36 | This class defines the interface for a criterion to implement.
37 |
38 | We distinguish multiple iterations within a hyper-parameter search trial.
39 | For example, you might train and evaluate a model on N different splits before scoring the hyper-parameters.
40 | """
41 |
42 | def __init__(self, mode, needs_uncertainty: bool):
43 | self.mode = mode
44 | self.needs_uncertainty = needs_uncertainty
45 | self.scores = []
46 |
47 | @abc.abstractmethod
48 | def score(self, predictions, uncertainties, train: SimpleMolecularDataset, val: SimpleMolecularDataset):
49 | pass
50 |
51 | def compute_weights(self, train: SimpleMolecularDataset, val: SimpleMolecularDataset):
52 | return None
53 |
54 | def __call__(self, *args, **kwargs):
55 | return self.score(*args, **kwargs)
56 |
57 | def update(self, predictions, uncertainties, train: SimpleMolecularDataset, val: SimpleMolecularDataset):
58 | """Scores a single iteration with the hyper-parameter search."""
59 | self.scores.append(self.score(predictions, uncertainties, train, val))
60 |
61 | def critique(self):
62 | """Aggregates the scores of individual iterations."""
63 | if len(self.scores) == 0:
64 | raise RuntimeError("Cannot critique when no scores have been computed yet")
65 |
66 | if isinstance(self.scores[0], Sequence):
67 | lengths = set(len(s) for s in self.scores)
68 | if len(lengths) != 1:
69 | raise RuntimeError("All scores need to have the same number of dimensions")
70 | n = lengths.pop()
71 |             score = [np.mean([s[i] for s in self.scores]) for i in range(n)]
72 | else:
73 | score = np.mean(self.scores)
74 |
75 | self.reset()
76 | return score
77 |
78 | def reset(self):
79 | self.scores = []
80 |
81 |
82 | class PerformanceCriterion(ModelSelectionCriterion):
83 | """Select models based on the mean validation performance."""
84 |
85 | def __init__(self, metric: Metric):
86 | super().__init__(mode=metric.mode, needs_uncertainty=False)
87 | if metric.is_calibration:
88 | raise ValueError(f"{metric.name} cannot be used with {type(self).__name__}")
89 | self.metric = metric
90 |
91 | def score(self, predictions, uncertainties, train: SimpleMolecularDataset, val: SimpleMolecularDataset):
92 | sample_weights = self.compute_weights(train, val) # noqa
93 | return self.metric(y_true=val.y, y_pred=predictions, sample_weights=sample_weights)
94 |
95 |
96 | class DomainWeightedPerformanceCriterion(PerformanceCriterion):
97 | """Select models based on the mean weighted validation performance, where the weight of each sample is
98 | 1 over the domain frequency of the domain it is part of."""
99 |
100 | def compute_weights(self, train: SimpleMolecularDataset, val: SimpleMolecularDataset):
101 | _, inverse, counts = np.unique(val.domains, return_counts=True, return_inverse=True)
102 |         frequencies = [n / sum(counts) for n in counts]
103 |         weights = [1.0 / frequencies[idx] for idx in inverse]  # 1 over the domain frequency
104 | return weights
105 |
106 |
107 | class DistanceWeightedPerformanceCriterion(PerformanceCriterion):
108 | """Select models based on the mean weighted validation performance,
109 | where the weight of each sample is its distance to the train set."""
110 |
111 | def compute_weights(self, train: SimpleMolecularDataset, val: SimpleMolecularDataset):
112 | distance_metric = get_distance_metric(val.X)
113 | return compute_knn_distance(train.X, val.X, distance_metric, k=5, n_jobs=-1)
114 |
115 |
116 | class CalibrationCriterion(ModelSelectionCriterion):
117 | """Select a model based on the mean validation calibration"""
118 |
119 | def __init__(self, metric: Metric):
120 | super().__init__(mode=metric.mode, needs_uncertainty=True)
121 | if not metric.is_calibration:
122 | raise ValueError(f"{metric.name} cannot be used with {type(self).__name__}")
123 | self.metric = metric
124 |
125 | def score(self, predictions, uncertainties, train: SimpleMolecularDataset, val: SimpleMolecularDataset):
126 | return self.metric(y_true=val.y, y_pred=predictions, uncertainty=uncertainties)
127 |
128 |
129 | class CombinedCriterion(ModelSelectionCriterion):
130 | """Selects a model based on a combined score of the validation calibration
131 | and the validation performance. Since calibration score is between [0, 1], does so by
132 | either multiplying (when maximizing) or dividing (when minimizing)
133 | the performance score by the calibration score"""
134 |
135 | def __init__(self, performance_metric: Metric, calibration_metric: Metric):
136 | super().__init__(mode=[performance_metric.mode, calibration_metric.mode], needs_uncertainty=True)
137 | self.performance_criterion = PerformanceCriterion(performance_metric)
138 | self.calibration_criterion = CalibrationCriterion(calibration_metric)
139 |
140 | def score(self, predictions, uncertainties, train: SimpleMolecularDataset, val: SimpleMolecularDataset):
141 | prf_score = self.performance_criterion.score(predictions, uncertainties, train, val)
142 | cal_score = self.calibration_criterion.score(predictions, uncertainties, train, val)
143 | return prf_score, cal_score
144 |
--------------------------------------------------------------------------------
/mood/dataset.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import datamol as dm
4 | import numpy as np
5 |
6 | from typing import Optional, List
7 |
8 | import torch.utils.data
9 | from tdc.single_pred import ADME, Tox
10 | from tdc.metadata import dataset_names
11 | from torch.utils.data import default_collate
12 |
13 | from mood.chemistry import compute_generic_scaffold
14 | from mood.constants import CACHE_DIR
15 |
16 |
17 | class SimpleMolecularDataset(torch.utils.data.Dataset):
18 | """
19 |     Simple wrapper to use these datasets with PyTorch (Lightning).
20 |     What makes it special is that it includes the notion of a molecular domain,
21 |     which is needed for Domain Generalization.
22 | """
23 |
24 | def __init__(self, smiles, X, y):
25 | self.smiles = smiles
26 | self.X = X
27 | self.y = y
28 | self.random_state = None
29 | self.domains = np.array([compute_generic_scaffold(smi) for smi in self.smiles])
30 | self.domain_representations = None
31 |
32 | def __getitem__(self, index):
33 | x = self.X[index]
34 | if self.domains is not None:
35 | d = self.domains[index]
36 | if self.domain_representations is not None:
37 | d = (d, self.domain_representations[d])
38 | x = (x, d)
39 | return x, self.y[index]
40 |
41 | def __len__(self):
42 | return len(self.X)
43 |
44 | def compute_domain_representations(self):
45 | self.domain_representations = {}
46 | unique_domains, inverse = np.unique(self.domains, return_inverse=True, axis=0)
47 | for idx in np.unique(inverse):
48 | indices = np.flatnonzero(inverse == idx)
49 | representation = self.X[indices].mean(axis=0)
50 | self.domain_representations[unique_domains[idx]] = representation
51 |
52 | def filter_by_indices(self, indices):
53 | cpy = deepcopy(self)
54 | cpy.smiles = cpy.smiles[indices]
55 | cpy.X = cpy.X[indices]
56 | cpy.y = cpy.y[indices]
57 | if cpy.domains is not None:
58 | cpy.domains = cpy.domains[indices]
59 | return cpy
60 |
61 |
62 | class DAMolecularDataset(torch.utils.data.Dataset):
63 | """Simple wrapper that creates a dataset with a supervised source domain and unsupervised target domain.
64 | This is needed for Domain Adaptation algorithms."""
65 |
66 | def __init__(self, source_dataset: SimpleMolecularDataset, target_dataset: SimpleMolecularDataset):
67 | self.src = source_dataset
68 | self.tgt = target_dataset
69 |
70 | def __getitem__(self, item):
71 | src = self.src.__getitem__(item)
72 | tgt_index = np.random.default_rng(item).integers(0, len(self.tgt))
73 | (x, domain), y = self.tgt.__getitem__(tgt_index)
74 | return {"source": src, "target": (x, domain)}
75 |
76 | def __len__(self):
77 | return len(self.src)
78 |
79 |
80 | def domain_based_collate(batch):
81 | """Custom collate function that splits a single batch into several mini-batches based on the domain.
82 | Doing that here instead of on the GPU is faster."""
83 |
84 | domains = [domain[0] if isinstance(domain, tuple) else domain for (X, domain), y in batch]
85 | batch = [
86 | ((X, domain[1]), y) if isinstance(domain, tuple) else ((X, domain), y) for (X, domain), y in batch
87 | ]
88 | _, inverse = np.unique(domains, return_inverse=True, axis=0)
89 |
90 | mini_batches = []
91 | for idx in np.unique(inverse):
92 | indices = np.flatnonzero(inverse == idx)
93 | mini_batch = [batch[i] for i in indices]
94 | mini_batch = default_collate(mini_batch)
95 | mini_batches.append(mini_batch)
96 | return mini_batches
97 |
98 |
99 | def domain_based_inference_collate(batch):
100 | """Custom collate function that preprocesses our custom dataset for inference."""
101 | batch = [
102 | ((X, domain[1]), y) if isinstance(domain, tuple) else ((X, domain), y) for (X, domain), y in batch
103 | ]
104 | return default_collate(batch)
105 |
106 |
107 | def load_data_from_tdc(name: str, disable_logs: bool = False):
108 | """
109 |     Endpoint for loading a dataset from TDC
110 | """
111 |
112 | original_name = name
113 | if name in MOOD_TO_TDC:
114 | name = MOOD_TO_TDC[name]
115 |
116 | path = dm.fs.join(CACHE_DIR, "TDC")
117 |
118 | # Load the dataset
119 | if name.lower() in dataset_names["ADME"]:
120 | dataset = ADME(name=name, path=path)
121 | elif name.lower() in dataset_names["Tox"]:
122 | dataset = Tox(name=name, path=path)
123 | else:
124 | msg = f"{original_name} is not supported. Choose from {MOOD_DATASETS}."
125 | raise RuntimeError(msg)
126 |
127 | # Standardize the SMILES
128 | with dm.without_rdkit_log(enable=disable_logs):
129 | smiles = dataset.entity1
130 | smiles = np.array([dm.to_smiles(dm.to_mol(smi)) for smi in smiles])
131 |
132 | # Load the targets
133 | y = np.array(dataset.y)
134 |
135 | # Mask out NaN that might be the result of standardization
136 | mask = [i for i, x in enumerate(smiles) if x is not None]
137 | smiles = smiles[mask]
138 | y = y[mask]
139 |
140 | return smiles, y
141 |
142 |
143 | def dataset_iterator(
144 | disable_logs: bool = True,
145 | whitelist: Optional[List[str]] = None,
146 | blacklist: Optional[List[str]] = None,
147 | ):
148 | """Endpoint for iterating over several datasets from TDC"""
149 |
150 | if whitelist is not None and blacklist is not None:
151 | msg = "You cannot use a blacklist and whitelist at the same time"
152 | raise ValueError(msg)
153 |
154 | all_datasets = MOOD_DATASETS
155 |
156 | if whitelist is not None:
157 | all_datasets = [d for d in all_datasets if d in whitelist]
158 | if blacklist is not None:
159 | all_datasets = [d for d in all_datasets if d not in blacklist]
160 |
161 | for name in all_datasets:
162 | yield name, load_data_from_tdc(name, disable_logs)
163 |
164 |
165 | TDC_TO_MOOD = {
166 | "BBB_Martins": "BBB",
167 | "CYP2C9_Veith": "CYP2C9",
168 | "Caco2_Wang": "Caco-2",
169 | "Clearance_Hepatocyte_AZ": "Clearance",
170 | "DILI": "DILI",
171 | "HIA_Hou": "HIA",
172 | "Half_Life_Obach": "HalfLife",
173 | "Lipophilicity_AstraZeneca": "Lipophilicity",
174 | "PPBR_AZ": "PPBR",
175 | "Pgp_Broccatelli": "Pgp",
176 | "hERG": "hERG",
177 | }
178 |
179 | # Ordered by size
180 | MOOD_DATASETS = [
181 | "DILI",
182 | "HIA",
183 | "hERG",
184 | "HalfLife",
185 | "Caco-2",
186 | "Clearance",
187 | "Pgp",
188 | "PPBR",
189 | "BBB",
190 | "Lipophilicity",
191 | "CYP2C9",
192 | ]
193 |
194 | MOOD_TO_TDC = {v: k for k, v in TDC_TO_MOOD.items()}
195 | MOOD_CLSF_DATASETS = ["BBB", "CYP2C9", "DILI", "HIA", "Pgp", "hERG"]
196 | MOOD_REGR_DATASETS = [d for d in MOOD_DATASETS if d not in MOOD_CLSF_DATASETS]
197 |
--------------------------------------------------------------------------------
/mood/distance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from typing import Optional, Union, List
5 |
6 | from sklearn.neighbors import NearestNeighbors
7 |
8 |
9 | def get_distance_metric(example):
10 | """Get the appropriate distance metric given an exemplary datapoint"""
11 |
12 | # By default we use the Euclidean distance
13 | metric = "euclidean"
14 |
15 | # For binary vectors we use jaccard
16 | if isinstance(example, pd.DataFrame):
17 | example = example.values # DataFrames would require all().all() otherwise
18 | if ((example == 0) | (example == 1)).all():
19 | metric = "jaccard"
20 |
21 | return metric
22 |
23 |
24 | def compute_knn_distance(
25 | X: np.ndarray,
26 | Y: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
27 | metric: Optional[str] = None,
28 | k: int = 5,
29 | n_jobs: Optional[int] = None,
30 | return_indices: bool = False,
31 | ):
32 | """
33 | Computes the mean k-Nearest Neighbors distance
34 | between a set of database embeddings and a set of query embeddings
35 |
36 | Args:
37 |         X: The set of samples that form the kNN candidates
38 |         Y: The samples to find the kNN for. If None, finds the kNN within `X` itself
39 | metric: The pairwise distance metric to define the neighborhood
40 | k: The number of neighbors to find
41 | n_jobs: Controls the parallelization
42 | return_indices: Whether to return the indices of the NNs as well
43 | """
44 |
45 | if metric is None:
46 | metric = get_distance_metric(X[0])
47 |
48 | knn = NearestNeighbors(n_neighbors=k, metric=metric, n_jobs=n_jobs)
49 | knn.fit(X)
50 |
51 | if not isinstance(Y, list):
52 | Y = [Y]
53 |
54 | distances, indices = [], []
55 | for queries in Y:
56 | if np.array_equal(X, queries):
57 | # Use k + 1 and filter out the first
58 | # because the sample will always be its own neighbor
59 | dist, ind = knn.kneighbors(queries, n_neighbors=k + 1)
60 | dist, ind = dist[:, 1:], ind[:, 1:]
61 | else:
62 | dist, ind = knn.kneighbors(queries, n_neighbors=k)
63 |
64 | distances.append(dist)
65 | indices.append(ind)
66 |
67 | # The distance from the query molecule to its NNs is the mean of all pairwise distances
68 | distances = [np.mean(dist, axis=1) for dist in distances]
69 |
70 | if len(distances) == 1:
71 | assert len(indices) == 1
72 | distances = distances[0]
73 | indices = indices[0]
74 |
75 | if return_indices:
76 | return distances, indices
77 | return distances
78 |
--------------------------------------------------------------------------------
/mood/experiment.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import fsspec
3 | import optuna
4 | import numpy as np
5 | import datamol as dm
6 |
7 | from copy import deepcopy
8 | from datetime import datetime
9 | from typing import Optional
10 | from loguru import logger
11 | from sklearn.base import BaseEstimator
12 | from sklearn.preprocessing import StandardScaler
13 | from torch.utils.data import DataLoader
14 | from torchmetrics.functional.regression import (
15 | r2_score,
16 | spearman_corrcoef,
17 | pearson_corrcoef,
18 | mean_absolute_error,
19 | )
20 |
21 | from mood.baselines import suggest_baseline_hparams, predict_baseline_uncertainty
22 | from mood.constants import RESULTS_DIR
23 | from mood.model import MOOD_DA_DG_ALGORITHMS, needs_domain_representation, is_domain_generalization
24 | from mood.model.base import Ensemble
25 | from mood.train import train_baseline_model, train
26 | from mood.criteria import get_mood_criteria
27 | from mood.dataset import (
28 | load_data_from_tdc,
29 | SimpleMolecularDataset,
30 | MOOD_REGR_DATASETS,
31 | domain_based_inference_collate,
32 | )
33 | from mood.distance import get_distance_metric
34 | from mood.metrics import Metric
35 | from mood.representations import featurize
36 | from mood.preprocessing import DEFAULT_PREPROCESSING
37 | from mood.splitter import get_mood_splitters, MOODSplitter
38 | from mood.utils import load_distances_for_downstream_application
39 | from mood.rct import get_experimental_configurations
40 |
41 |
42 | def run_study(metric, algorithm, n_startup_trials, n_trials, trial_fn, seed):
43 | """Endpoint for running an Optuna study"""
44 |
45 | sampler = optuna.samplers.TPESampler(seed=seed, n_startup_trials=n_startup_trials)
46 |
47 | if isinstance(metric.mode, list):
48 | directions = ["maximize" if m == "max" else "minimize" for m in metric.mode]
49 | study = optuna.create_study(directions=directions, sampler=sampler)
50 | else:
51 | direction = "maximize" if metric.mode == "max" else "minimize"
52 | study = optuna.create_study(direction=direction, sampler=sampler)
53 |
54 | if algorithm == "GP":
55 | # ValueError: array must not contain infs or NaNs
56 | # LinAlgError: N-th leading minor of the array is not positive definite
57 | # LinAlgError: The kernel is not returning a positive definite matrix
58 | catch = (np.linalg.LinAlgError, ValueError)
59 | elif algorithm == "Mixup":
60 | # RuntimeError: all elements of input should be between 0 and 1
61 |         # NOTE: This is not robust (as other RuntimeErrors could be thrown)
62 | # but is an easy, performant way to check for NaN values which often
63 | # occurred for Mixup due to high losses on the first few batches
64 | catch = (RuntimeError,)
65 | else:
66 | catch = ()
67 |
68 | study.optimize(trial_fn, n_trials=n_trials, catch=catch)
69 | return study
70 |
71 |
72 | def basic_tuning_loop(
73 | X_train,
74 | X_test,
75 | y_train,
76 | y_test,
77 | name: str,
78 | is_regression: bool,
79 | metric: Metric,
80 | global_seed: int,
81 | for_uncertainty_estimation: bool = False,
82 | ensemble_size: int = 10,
83 | n_trials: int = 50,
84 | n_startup_trials: int = 10,
85 | ):
86 | """
87 | This hyper-parameter search loop is used to train baseline models for the MOOD specification.
88 | All baselines are from scikit-learn.
89 |
90 | NOTE: This could be merged with the more elaborate tuning loop we wrote later
91 | However, for the sake of reproducibility, I wanted to keep this code intact.
92 | This way, the exact code used to generate results is still easily accessible
93 | in the code base
94 | """
95 |
96 | def run_trial(trial):
97 | random_state = global_seed + trial.number
98 | params = suggest_baseline_hparams(name, is_regression, trial)
99 | model = train_baseline_model(
100 | X_train,
101 | y_train,
102 | name,
103 | is_regression,
104 | params,
105 | random_state,
106 | for_uncertainty_estimation,
107 | ensemble_size,
108 | calibrate=True,
109 | )
110 | y_pred = model.predict(X_test)
111 | score = metric(y_test, y_pred)
112 | return score
113 |
114 | study = run_study(
115 | metric=metric,
116 | algorithm=name,
117 | n_startup_trials=n_startup_trials,
118 | n_trials=n_trials,
119 | trial_fn=run_trial,
120 | seed=global_seed,
121 | )
122 | return study
123 |
124 |
125 | def rct_dataset_setup(dataset, train_indices, val_indices, test_dataset, is_regression):
126 | """Sets up the dataset. Specifically, splits the dataset and standardizes the targets for regression tasks"""
127 |
128 | train_dataset = dataset.filter_by_indices(train_indices)
129 | val_dataset = dataset.filter_by_indices(val_indices)
130 | test_dataset = deepcopy(test_dataset)
131 |
132 | # Z-standardization of the targets
133 | if is_regression:
134 | scaler = StandardScaler()
135 | train_dataset.y = scaler.fit_transform(train_dataset.y)
136 | val_dataset.y = scaler.transform(val_dataset.y)
137 | test_dataset.y = scaler.transform(test_dataset.y)
138 |
139 | return train_dataset, val_dataset, test_dataset
140 |
141 |
142 | def rct_predict_step(model, dataset):
143 | """Get the predictions and uncertainty estimates from either a scikit-learn model or torch model"""
144 | if isinstance(model, BaseEstimator):
145 | y_pred = model.predict(dataset.X).reshape(-1, 1)
146 | uncertainty = predict_baseline_uncertainty(model, dataset.X)
147 | elif isinstance(model, Ensemble):
148 | collate_fn = domain_based_inference_collate if is_domain_generalization(model.models[0]) else None
149 | dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=model.models[0].batch_size)
150 | y_pred = model.predict(dataloader)
151 | uncertainty = model.predict_uncertainty(dataloader)
152 | else:
153 | raise NotImplementedError
154 | return y_pred, uncertainty
155 |
156 |
157 | def rct_compute_metrics(
158 | y_true, y_pred, y_uncertainty, performance_metric, calibration_metric, is_regression, prefix, suffix
159 | ):
160 | prf_score = performance_metric(y_true, y_pred)
161 |
162 | # NOTE: Ideally we would always compute the calibration metric,
163 |     # but that was too computationally expensive due to the need for ensembles
164 | if y_uncertainty is not None:
165 | cal_score = calibration_metric(y_true, y_pred, y_uncertainty)
166 | else:
167 | cal_score = None
168 |
169 | ret = {
170 | f"{prefix}_calibration_{calibration_metric.name}_{suffix}": cal_score,
171 | f"{prefix}_performance_{performance_metric.name}_{suffix}": prf_score,
172 | }
173 |
174 | if is_regression:
175 | targets = Metric.preprocess_targets(y_true, is_regression)
176 | predictions = Metric.preprocess_predictions(y_pred, targets.device)
177 |
178 | # NOTE: Before starting the RCT, we were not sure what metric to use
179 | # to compare models for regression tasks, that's why we compute some extra here
180 | ret[f"{prefix}_extra_r2_{suffix}"] = r2_score(preds=predictions, target=targets).item()
181 | ret[f"{prefix}_extra_spearman_{suffix}"] = spearman_corrcoef(preds=predictions, target=targets).item()
182 | ret[f"{prefix}_extra_pearson_{suffix}"] = pearson_corrcoef(preds=predictions, target=targets).item()
183 | ret[f"{prefix}_extra_mae_{suffix}"] = mean_absolute_error(preds=predictions, target=targets).item()
184 |
185 | return ret
186 |
187 |
188 | def rct_evaluate_step(
189 | model,
190 | train_dataset,
191 | val_dataset,
192 | test_dataset,
193 | performance_metric,
194 | calibration_metric,
195 | is_regression,
196 | suffix,
197 | criterion: Optional = None,
198 | ):
199 | val_y_pred, val_uncertainty = rct_predict_step(model, val_dataset)
200 | val_metrics = rct_compute_metrics(
201 | y_true=val_dataset.y,
202 | y_pred=val_y_pred,
203 | y_uncertainty=val_uncertainty,
204 | performance_metric=performance_metric,
205 | calibration_metric=calibration_metric,
206 | is_regression=is_regression,
207 | prefix="val",
208 | suffix=suffix,
209 | )
210 |
211 | test_y_pred, test_uncertainty = rct_predict_step(model, test_dataset)
212 | test_metrics = rct_compute_metrics(
213 | y_true=test_dataset.y,
214 | y_pred=test_y_pred,
215 | y_uncertainty=test_uncertainty,
216 | performance_metric=performance_metric,
217 | calibration_metric=calibration_metric,
218 | is_regression=is_regression,
219 | prefix="test",
220 | suffix=suffix,
221 | )
222 |
223 |     # Update the criterion used to select which model is best
224 | if criterion is not None:
225 | criterion.update(val_y_pred, val_uncertainty, train_dataset, val_dataset)
226 |
227 | metrics = val_metrics
228 | metrics.update(test_metrics)
229 | return metrics
230 |
231 |
232 | def rct_tuning_loop(
233 | train_val_dataset: SimpleMolecularDataset,
234 | test_dataset: SimpleMolecularDataset,
235 | algorithm: str,
236 | train_val_split: str,
237 | criterion_name: str,
238 | performance_metric: Metric,
239 | calibration_metric: Metric,
240 | is_regression: bool,
241 | global_seed: int,
242 | num_repeated_splits: int = 3,
243 | num_trials: int = 50,
244 | num_startup_trials: int = 10,
245 | ):
246 | """
247 | This hyper-parameter search loop is used to benchmark different tools to improve generalization in
248 | the MOOD investigation. It combines training scikit-learn and pytorch (lightning) models.
249 | """
250 |
251 | rng = np.random.default_rng(global_seed)
252 | seeds = rng.integers(0, 2**16, num_trials)
253 |
254 | def run_trial(trial: optuna.Trial):
255 | random_state = seeds[trial.number].item()
256 | trial.set_user_attr("trial_seed", random_state)
257 |
258 | splitters = get_mood_splitters(train_val_dataset.smiles, num_repeated_splits, random_state, n_jobs=-1)
259 | train_val_splitter = splitters[train_val_split]
260 |
261 | for split_idx, (train_ind, val_ind) in enumerate(train_val_splitter.split(train_val_dataset.X)):
262 | train_dataset, val_dataset, test_dataset_inner = rct_dataset_setup(
263 | train_val_dataset, train_ind, val_ind, test_dataset, is_regression
264 | )
265 |
266 | # NOTE: AUROC is not defined when there's just a single ground truth class.
267 | # Since this only happens for the unbalanced and small HIA dataset, we just skip.
268 | if performance_metric.name == "AUROC" and len(np.unique(val_dataset.y)) == 1:
269 | continue
270 |
271 | if algorithm in MOOD_DA_DG_ALGORITHMS:
272 | params = MOOD_DA_DG_ALGORITHMS[algorithm].suggest_params(trial)
273 | else:
274 | params = suggest_baseline_hparams(algorithm, is_regression, trial)
275 |
276 | model = train(
277 | train_dataset=train_dataset,
278 | val_dataset=val_dataset,
279 | test_dataset=test_dataset_inner,
280 | algorithm=algorithm,
281 | is_regression=is_regression,
282 | params=params,
283 | seed=random_state,
284 | calibrate=False,
285 | # NOTE: If we do not select models based on uncertainty,
286 | # we don't train an ensemble to reduce computational cost
287 | ensemble_size=5 if criterion.needs_uncertainty else 1,
288 | )
289 |
290 | metrics = rct_evaluate_step(
291 | model=model,
292 | train_dataset=train_dataset,
293 | val_dataset=val_dataset,
294 | test_dataset=test_dataset_inner,
295 | performance_metric=performance_metric,
296 | calibration_metric=calibration_metric,
297 | is_regression=is_regression,
298 | suffix=str(split_idx),
299 | criterion=criterion,
300 | )
301 |
302 | # We save the val and test performance for each trial to analyze the success
303 | # of the model selection procedure (gap between best and selected model)
304 | for k, v in metrics.items():
305 | trial.set_user_attr(k, v)
306 |
307 | return criterion.critique()
308 |
309 | criterion = get_mood_criteria(performance_metric, calibration_metric)[criterion_name]
310 |
311 | study = run_study(
312 | metric=criterion,
313 | algorithm=algorithm,
314 | n_startup_trials=num_startup_trials,
315 | n_trials=num_trials,
316 | trial_fn=run_trial,
317 | seed=global_seed,
318 | )
319 |
320 | return study
321 |
322 |
323 | def tune_cmd(
324 | dataset,
325 | algorithm,
326 | representation,
327 | train_val_split,
328 | criterion,
329 | seed: int = 0,
330 | use_cache: bool = False,
331 | base_save_dir: str = RESULTS_DIR,
332 | sub_save_dir: Optional[str] = None,
333 | overwrite: bool = False,
334 | ):
335 | """
336 | The MOOD tuning loop: Runs a hyper-parameter search.
337 |
338 | Prescribes a train-test split based on the MOOD specification and runs a hyper-parameter search
339 | for the training set.
340 | """
341 |
342 | if sub_save_dir is None:
343 | sub_save_dir = datetime.now().strftime("%Y%m%d")
344 |
345 | csv_out_dir = dm.fs.join(base_save_dir, "dataframes", "RCT", sub_save_dir)
346 | csv_fname = f"rct_study_{dataset}_{algorithm}_{representation}_{train_val_split}_{criterion}_{seed}.csv"
347 | csv_path = dm.fs.join(csv_out_dir, csv_fname)
348 | dm.fs.mkdir(csv_out_dir, exist_ok=True)
349 |
350 | yaml_out_dir = dm.fs.join(base_save_dir, "YAML", "RCT", sub_save_dir)
351 | yaml_fname = (
352 | f"rct_selected_model_{dataset}_{algorithm}_{representation}_{train_val_split}_{criterion}_{seed}.yaml"
353 | )
354 | yaml_path = dm.fs.join(yaml_out_dir, yaml_fname)
355 | dm.fs.mkdir(yaml_out_dir, exist_ok=True)
356 |
357 | if not overwrite and dm.fs.exists(yaml_path) and dm.fs.exists(csv_path):
358 | logger.info(f"Both the files already exists and overwrite=False. Skipping!")
359 | return
360 |
361 | # Load and preprocess the data
362 | smiles, y = load_data_from_tdc(dataset, disable_logs=True)
363 | X, mask = featurize(smiles, representation, DEFAULT_PREPROCESSING[representation], disable_logs=True)
364 | X = X.astype(np.float32)
365 | smiles = smiles[mask]
366 | y = y[mask]
367 |
368 | is_regression = dataset in MOOD_REGR_DATASETS
369 | if is_regression:
370 | y = y.reshape(-1, 1)
371 |
372 | # Prescribe a train-test split
373 | distances_vs = load_distances_for_downstream_application(
374 | "virtual_screening", representation, dataset, update_cache=not use_cache
375 | )
376 | distances_op = load_distances_for_downstream_application(
377 | "optimization", representation, dataset, update_cache=not use_cache
378 | )
379 |
380 | distance_metric = get_distance_metric(X)
381 | splitters = get_mood_splitters(smiles, 5, seed, n_jobs=-1)
382 | train_test_splitter = MOODSplitter(
383 | splitters, np.concatenate((distances_vs, distances_op)), distance_metric, k=5
384 | )
385 | train_test_splitter.fit(X)
386 |
387 | # Split the data using the prescribed split
388 | trainval, test = next(train_test_splitter.split(X, y))
389 | train_val_dataset = SimpleMolecularDataset(smiles[trainval], X[trainval], y[trainval])
390 | test_dataset = SimpleMolecularDataset(smiles[test], X[test], y[test])
391 |
392 | if needs_domain_representation(algorithm):
393 | train_val_dataset.compute_domain_representations()
394 | test_dataset.compute_domain_representations()
395 |
396 | # Get metrics for this dataset
397 | performance_metric = Metric.get_default_performance_metric(dataset)
398 | calibration_metric = Metric.get_default_calibration_metric(dataset)
399 |
400 | # Run the hyper-parameter search
401 | study = rct_tuning_loop(
402 | train_val_dataset=train_val_dataset,
403 | test_dataset=test_dataset,
404 | algorithm=algorithm,
405 | train_val_split=train_val_split,
406 | criterion_name=criterion,
407 | performance_metric=performance_metric,
408 | calibration_metric=calibration_metric,
409 | is_regression=is_regression,
410 | global_seed=seed,
411 | )
412 |
413 | # Train the best model found again, but this time as an ensemble
414 | # to evaluate the test performance and calibration
415 | if len(study.directions) > 1:
416 | logger.info(f"There's {len(study.best_trials)} models on the Pareto front. Picking one randomly!")
417 | rng = np.random.default_rng(seed)
418 | best_trial = rng.choice(study.best_trials)
419 | else:
420 | best_trial = study.best_trial
421 |
422 | # NOTE: Some methods are really sensitive to hyper-parameters (e.g. GPs, Mixup)
423 | # So with a different train-val split, these might no longer succeed to train.
424 | random_state = best_trial.user_attrs["trial_seed"]
425 | splitters = get_mood_splitters(train_val_dataset.smiles, 1, random_state, n_jobs=-1)
426 | train_val_splitter = splitters[train_val_split]
427 | train_ind, val_ind = next(train_val_splitter.split(train_val_dataset.X))
428 |
429 | train_dataset, val_dataset, test_dataset = rct_dataset_setup(
430 | train_val_dataset, train_ind, val_ind, test_dataset, is_regression
431 | )
432 | model = train(
433 | train_dataset=train_dataset,
434 | val_dataset=val_dataset,
435 | test_dataset=test_dataset,
436 | algorithm=algorithm,
437 | is_regression=is_regression,
438 | params=best_trial.params,
439 | seed=random_state,
440 | calibrate=False,
441 | ensemble_size=5,
442 | )
443 |
444 | metrics = rct_evaluate_step(
445 | model=model,
446 | train_dataset=train_dataset,
447 | val_dataset=val_dataset,
448 | test_dataset=test_dataset,
449 | performance_metric=performance_metric,
450 | calibration_metric=calibration_metric,
451 | is_regression=is_regression,
452 | suffix="final",
453 | )
454 |
455 | # Save the full trial results as a CSV
456 | logger.info(f"Saving the full study data to {csv_path}")
457 | df = study.trials_dataframe()
458 | df["dataset"] = dataset
459 | df["algorithm"] = algorithm
460 | df["representation"] = representation
461 | df["train-val split"] = train_val_split
462 | df["criterion"] = criterion
463 | df["seed"] = seed
464 | df.to_csv(csv_path)
465 |
466 | # Save the most important information as YAML (higher precision)
467 | data = {
468 | "hparams": best_trial.params,
469 | "criterion_final": best_trial.values,
470 | "dataset": dataset,
471 | "algorithm": algorithm,
472 | "representation": representation,
473 | "train_val_split": train_val_split,
474 | "criterion": criterion,
475 | "seed": seed,
476 | **best_trial.user_attrs,
477 | **metrics,
478 | }
479 |
480 | logger.info(f"Saving the data of the best model to {yaml_path}")
481 | with fsspec.open(yaml_path, "w") as fd:
482 | yaml.dump(data, fd)
483 |
484 |
485 | def rct_cmd(
486 | dataset: str,
487 | index: int,
488 | base_save_dir: str = RESULTS_DIR,
489 | sub_save_dir: Optional[str] = None,
490 | overwrite: bool = False,
491 | ):
492 | """
493 | Entrypoint for the benchmarking study in the MOOD Investigation.
494 |
495 |     Deterministically samples one of the unordered set of experimental configurations in the RCT
496 |     and runs the tuning loop for that experimental configuration.
497 |
498 | Here an experimental configuration consists of an algorithm, representation, train-val split,
499 | model selection criterion and seed.
500 | """
501 |
502 | configs = get_experimental_configurations(dataset)
503 | logger.info(f"Sampled configuration #{index} / {len(configs)} for {dataset}: {configs[index]}")
504 | algorithm, representation, train_val_split, criterion, seed = configs[index]
505 |
506 | tune_cmd(
507 | dataset=dataset,
508 | algorithm=algorithm,
509 | representation=representation,
510 | train_val_split=train_val_split,
511 | criterion=criterion,
512 | seed=seed,
513 | base_save_dir=base_save_dir,
514 | sub_save_dir=sub_save_dir,
515 | overwrite=overwrite,
516 | )
517 |
--------------------------------------------------------------------------------
/mood/metrics.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import torch
3 | import datamol as dm
4 | import numpy as np
5 | from typing import Callable, Optional
6 |
7 | from sklearn.metrics import roc_auc_score
8 | from torchmetrics.functional import mean_absolute_error, mean_squared_error
9 | from torchmetrics.functional.classification import binary_auroc
10 | from torchmetrics.functional.regression.pearson import pearson_corrcoef
11 | from torchmetrics.wrappers.bootstrapping import _bootstrap_sampler
12 |
13 | from mood.dataset import MOOD_REGR_DATASETS
14 |
15 |
16 | def weighted_pearson(preds, target, sample_weights=None):
17 | """
18 |     The weighted Pearson correlation coefficient. Based on:
19 | https://stats.stackexchange.com/a/222107
20 | """
21 | if sample_weights is None:
22 |         # If no sample weights are provided, just rely on TorchMetrics' implementation
23 | return pearson_corrcoef(preds=preds, target=target)
24 |
25 | def _weighted_mean(x, w):
26 | return torch.sum(w * x) / torch.sum(w)
27 |
28 | # Copied over the implementation from TorchMetric and made three changes:
29 | # 1) Instead of computing the mean, compute the weighted mean
30 | # 2) Computing the weighted covariance
31 | # 3) Computing the weighted correlation
32 |
33 | preds_diff = preds - _weighted_mean(preds, sample_weights)
34 | target_diff = target - _weighted_mean(target, sample_weights)
35 |
36 | cov = _weighted_mean(preds_diff * target_diff, sample_weights)
37 |
38 | preds_std = torch.sqrt(_weighted_mean(preds_diff * preds_diff, sample_weights))
39 | target_std = torch.sqrt(_weighted_mean(target_diff * target_diff, sample_weights))
40 |
41 | corrcoef = cov / (preds_std * target_std + 1e-6)
42 | return torch.clamp(corrcoef, -1.0, 1.0)
43 |
44 |
45 | def weighted_pearson_calibration(preds, target, uncertainty, sample_weights=None):
46 | error = torch.abs(preds - target)
47 | return weighted_pearson(error, uncertainty, sample_weights)
48 |
49 |
50 | def weighted_mae(preds, target, sample_weights=None):
51 | if sample_weights is None:
52 | return mean_absolute_error(preds=preds, target=target)
53 | summed_mae = torch.abs(sample_weights * (preds - target)).sum()
54 | return summed_mae / torch.sum(sample_weights)
55 |
56 |
57 | def weighted_brier_score(target, uncertainty, sample_weights=None):
58 | confidence = torch.stack([unc[tar] for unc, tar in zip(uncertainty, target)])
59 | if sample_weights is None:
60 | return mean_squared_error(preds=confidence, target=target)
61 |     summed_mse = (sample_weights * torch.square(confidence - target)).sum()  # weight each squared error once
62 | brier_score = summed_mse / torch.sum(sample_weights)
63 | return brier_score
64 |
65 |
66 | def weighted_auroc(preds, target, sample_weights=None):
67 | if sample_weights is None:
68 | return binary_auroc(preds=preds, target=target)
69 |
70 | # TorchMetrics does not actually support sample weights, so we rely on the sklearn implementation
71 | preds = preds.cpu().numpy()
72 | target = target.cpu().numpy()
73 | sample_weights = sample_weights.cpu().numpy()
74 | return roc_auc_score(y_true=target, y_score=preds, sample_weight=sample_weights)
75 |
76 |
77 | class TargetType(enum.Enum):
78 | REGRESSION = "regression"
79 | BINARY_CLASSIFICATION = "binary_classification"
80 |
81 | def is_regression(self):
82 | return self == TargetType.REGRESSION
83 |
84 |
85 | class Metric:
86 | def __init__(
87 | self,
88 | name: str,
89 | fn: Callable,
90 | mode: str,
91 | target_type: TargetType,
92 | needs_predictions: bool = True,
93 | needs_uncertainty: bool = False,
94 | range_min: Optional[float] = None,
95 | range_max: Optional[float] = None,
96 | ):
97 | self.fn_ = fn
98 | self.name = name
99 | self.mode = mode
100 | self.target_type = target_type
101 | self.needs_predictions = needs_predictions
102 | self.needs_uncertainty = needs_uncertainty
103 | self.range_min = range_min
104 | self.range_max = range_max
105 |
106 | @property
107 | def is_calibration(self):
108 | return self.needs_uncertainty
109 |
110 | @classmethod
111 | def get_default_calibration_metric(cls, dataset):
112 | is_regression = dataset in MOOD_REGR_DATASETS
113 | if is_regression:
114 | metric = cls.by_name("Pearson")
115 | else:
116 | metric = cls.by_name("Brier score")
117 | return metric
118 |
119 | @classmethod
120 | def get_default_performance_metric(cls, dataset):
121 | is_regression = dataset in MOOD_REGR_DATASETS
122 | if is_regression:
123 | metric = cls.by_name("MAE")
124 | else:
125 | metric = cls.by_name("AUROC")
126 | return metric
127 |
128 | @classmethod
129 | def by_name(cls, name):
130 | if name == "MAE":
131 | return cls(
132 | name="MAE",
133 | fn=weighted_mae,
134 | mode="min",
135 | target_type=TargetType.REGRESSION,
136 | needs_predictions=True,
137 | needs_uncertainty=False,
138 | range_min=0,
139 | range_max=None,
140 | )
141 | elif name == "Pearson":
142 | return cls(
143 | name="Pearson",
144 | fn=weighted_pearson_calibration,
145 | mode="max",
146 | target_type=TargetType.REGRESSION,
147 | needs_predictions=True,
148 | needs_uncertainty=True,
149 | range_min=-1,
150 | range_max=1,
151 | )
152 | elif name == "AUROC":
153 | return cls(
154 | name="AUROC",
155 | fn=weighted_auroc,
156 | mode="max",
157 | target_type=TargetType.BINARY_CLASSIFICATION,
158 | needs_predictions=True,
159 | needs_uncertainty=False,
160 | range_min=0,
161 | range_max=1,
162 | )
163 | elif name == "Brier score":
164 | return cls(
165 | name="Brier score",
166 | fn=weighted_brier_score,
167 | mode="min",
168 | target_type=TargetType.BINARY_CLASSIFICATION,
169 | needs_predictions=False,
170 | needs_uncertainty=True,
171 | range_min=0,
172 | range_max=1,
173 | )
174 |
175 | def __call__(
176 | self, y_true, y_pred: Optional = None, uncertainty: Optional = None, sample_weights: Optional = None
177 | ):
178 | if self.needs_uncertainty and uncertainty is None:
179 | raise ValueError("Uncertainty estimates needed, but not provided.")
180 | if self.needs_predictions and y_pred is None:
181 | raise ValueError("Predictions needed, but not provided.")
182 | kwargs = self.to_kwargs(y_true, y_pred, uncertainty, sample_weights)
183 | return self.fn_(**kwargs).item()
184 |
185 | @staticmethod
186 | def preprocess_targets(y_true, is_regression: bool):
187 | if not isinstance(y_true, torch.Tensor):
188 | y_true = torch.tensor(y_true)
189 | if is_regression:
190 | y_true = y_true.float().squeeze()
191 | if y_true.ndim == 0:
192 | y_true = y_true.unsqueeze(0)
193 | else:
194 | y_true = y_true.int()
195 |
196 | return y_true
197 |
198 | @staticmethod
199 | def preprocess_predictions(y_pred, device):
200 | if not isinstance(y_pred, torch.Tensor):
201 | y_pred = torch.tensor(y_pred, device=device)
202 | y_pred = y_pred.float().squeeze()
203 | if y_pred.ndim == 0:
204 | y_pred = y_pred.unsqueeze(0)
205 | return y_pred
206 |
207 | @staticmethod
208 | def preprocess_uncertainties(uncertainty, device):
209 | return Metric.preprocess_predictions(uncertainty, device)
210 |
211 | @staticmethod
212 | def preprocess_sample_weights(sample_weights, device):
213 | if sample_weights is not None and not isinstance(sample_weights, torch.Tensor):
214 | sample_weights = torch.tensor(sample_weights, device=device)
215 | return sample_weights
216 |
217 | @property
218 | def range(self):
219 | return self.range_min, self.range_max
220 |
221 | def to_kwargs(self, y_true, y_pred, uncertainty, sample_weights):
222 | kwargs = {"target": self.preprocess_targets(y_true, self.target_type.is_regression())}
223 | kwargs["sample_weights"] = self.preprocess_sample_weights(sample_weights, kwargs["target"].device)
224 | if self.needs_predictions:
225 | kwargs["preds"] = self.preprocess_predictions(y_pred, kwargs["target"].device)
226 | if self.needs_uncertainty:
227 | kwargs["uncertainty"] = self.preprocess_uncertainties(uncertainty, kwargs["target"].device)
228 | return kwargs
229 |
230 |
231 | def compute_bootstrapped_metric(
232 | targets,
233 | metric: Metric,
234 | predictions: Optional = None,
235 | uncertainties: Optional = None,
236 | sample_weights: Optional = None,
237 | sampling_strategy: str = "poisson",
238 | n_bootstraps: int = 1000,
239 | n_jobs: Optional[int] = None,
240 | ):
241 | """
242 | Bootstrapping to compute confidence intervals for a metric.
243 | Inspired by https://stackoverflow.com/a/19132400 and
244 | https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/wrappers/bootstrapping.py
245 | """
246 |
247 | def fn(it):
248 | indices = _bootstrap_sampler(len(predictions), sampling_strategy=sampling_strategy)
249 | _sample_weights = None if sample_weights is None else sample_weights[indices]
250 | _uncertainties = None if uncertainties is None else uncertainties[indices]
251 | _predictions = None if predictions is None else predictions[indices]
252 | score = metric(targets[indices], _predictions, _uncertainties, _sample_weights)
253 | return score
254 |
255 | bootstrapped_scores = dm.utils.parallelized(fn, range(n_bootstraps), n_jobs=n_jobs)
256 | bootstrapped_scores = [score for score in bootstrapped_scores if score is not None]
257 | return np.mean(bootstrapped_scores), np.std(bootstrapped_scores)
258 |
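For a quick sense of how the `Metric` API and the bootstrapping helper fit together, here is a minimal sketch with synthetic data (the toy arrays and the explicit uniform sample weights below are illustrative assumptions, not from the repository):

```python
import numpy as np
from mood.metrics import Metric, compute_bootstrapped_metric

# Hypothetical toy data for a binary classification task
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=100)
y_pred = rng.random(100)
weights = np.ones(100)  # uniform sample weights

metric = Metric.by_name("AUROC")
score = metric(y_true, y_pred, sample_weights=weights)

# Mean and standard deviation of the metric over bootstrap resamples
mean, std = compute_bootstrapped_metric(
    y_true, metric, predictions=y_pred, sample_weights=weights, n_bootstraps=100
)
```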
--------------------------------------------------------------------------------
/mood/model/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from sklearn.base import BaseEstimator
4 |
5 | from mood.model.base import BaseModel
6 | from mood.model.vrex import VREx
7 | from mood.model.coral import CORAL
8 | from mood.model.dann import DANN
9 | from mood.model.ib_erm import InformationBottleneckERM
10 | from mood.model.mixup import Mixup
11 | from mood.model.erm import ERM
12 | from mood.model.mtl import MTL
13 |
14 |
15 | MOOD_DA_DG_ALGORITHMS = {
16 | "VREx": VREx,
17 | "CORAL": CORAL,
18 | "DANN": DANN,
19 | "IB-ERM": InformationBottleneckERM,
20 | "Mixup": Mixup,
21 | "MLP": ERM,
22 | "MTL": MTL,
23 | }
24 |
25 | MOOD_ALGORITHMS = ["RF", "GP", "MLP", "MTL", "VREx", "IB-ERM", "CORAL", "DANN", "Mixup"]
26 |
27 |
28 | def _get_type(model: Union[BaseEstimator, BaseModel, str]):
29 | if isinstance(model, str):
30 | model_type = MOOD_DA_DG_ALGORITHMS.get(model)
31 | else:
32 | model_type = type(model)
33 | if not (model_type is None or issubclass(model_type, BaseEstimator) or issubclass(model_type, BaseModel)):
34 |         raise TypeError(f"Can only test models from scikit-learn or MOOD, not {model_type}")
35 | return model_type
36 |
37 |
38 | def is_domain_adaptation(model: Union[BaseEstimator, BaseModel, str]):
39 | model_type = _get_type(model)
40 | return model_type in [Mixup, DANN, CORAL]
41 |
42 |
43 | def is_domain_generalization(model: Union[BaseEstimator, BaseModel, str]):
44 | model_type = _get_type(model)
45 | return model_type in [MTL, InformationBottleneckERM, VREx]
46 |
47 |
48 | def needs_domain_representation(model: Union[BaseEstimator, BaseModel, str]):
49 | model_type = _get_type(model)
50 | return model_type == MTL
51 |
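A small sketch of how these helpers classify algorithms by name (the printed values follow directly from the mappings above):

```python
from mood.model import MOOD_DA_DG_ALGORITHMS, is_domain_adaptation, is_domain_generalization

for name in MOOD_DA_DG_ALGORITHMS:
    # e.g. "CORAL" is domain adaptation, "VREx" is domain generalization,
    # and "MLP" (plain ERM) is neither.
    print(name, is_domain_adaptation(name), is_domain_generalization(name))
```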
--------------------------------------------------------------------------------
/mood/model/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from functools import partial
3 |
4 | import torch
5 | from torch import nn
6 | from itertools import tee
7 | from typing import Optional, Union
8 | from pytorch_lightning import LightningModule
9 |
10 |
11 | class BaseModel(LightningModule, abc.ABC):
12 | def __init__(
13 | self,
14 | base_network: nn.Module,
15 | prediction_head: nn.Module,
16 | loss_fn: nn.Module,
17 | lr: float,
18 | weight_decay: Union[float, str],
19 | batch_size: int,
20 | ):
21 | super().__init__()
22 | self.base_network = base_network
23 | self.prediction_head = prediction_head
24 | self.loss_fn = partial(self.loss_function_wrapper, loss_fn=loss_fn)
25 | self.l2 = weight_decay
26 | self.lr = lr
27 | self.batch_size = batch_size
28 |
29 | def forward(self, x, domains: Optional = None, return_embedding: bool = False):
30 | embedding = self.base_network(x)
31 | label = self.prediction_head(embedding)
32 | out = (label, embedding) if return_embedding else label
33 | return out
34 |
35 | def training_step(
36 | self, batch, batch_idx, optimizer_idx: Optional[int] = None, dataset_idx: Optional[int] = None
37 | ):
38 | return self._step(batch, batch_idx, optimizer_idx)
39 |
40 | def validation_step(self, batch, batch_idx, dataset_idx: Optional[int] = None):
41 | return self._step(batch, batch_idx, optimizer_idx=None)
42 |
43 | def predict(self, dataloader):
44 |         self.eval()  # eval() propagates training=False to all submodules
45 | with torch.inference_mode():
46 | res = torch.cat([self.forward(*X) for X, y in dataloader], dim=0)
47 | return res
48 |
49 | @staticmethod
50 | def loss_function_wrapper(preds, targets, loss_fn):
51 | if preds.ndim > 1:
52 | preds = preds.squeeze(dim=-1)
53 | targets = targets.float()
54 | return loss_fn(preds, targets)
55 |
56 | @abc.abstractmethod
57 | def _step(self, batch, batch_idx, optimizer_idx: Optional[int] = None):
58 | raise NotImplementedError
59 |
60 | def configure_optimizers(self):
61 | if self.l2 == "auto":
62 | parameters, parameters_copy = tee(self.parameters())
63 | self.l2 = 1.0 / sum(p.numel() for p in parameters_copy if p.requires_grad)
64 |
65 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.l2)
66 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
67 | return {
68 | "optimizer": optimizer,
69 | "lr_scheduler": {
70 | "scheduler": lr_scheduler,
71 | "interval": "epoch",
72 | "frequency": 1,
73 | "monitor": "val_loss",
74 | "strict": True,
75 | },
76 | }
77 |
78 | def log(self, name: str, *args, **kwargs):
79 | prefix = "train_" if self.training else "val_"
80 | super(BaseModel, self).log(prefix + name, *args, batch_size=self.batch_size, **kwargs)
81 |
82 | @staticmethod
83 | def suggest_params(trial):
84 | width = trial.suggest_categorical("mlp_width", [64, 128, 256, 512])
85 | depth = trial.suggest_int("mlp_depth", 1, 5)
86 | lr = trial.suggest_float("lr", 1e-8, 1e-1, log=True)
87 | batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
88 | weight_decay = trial.suggest_categorical(
89 | "weight_decay", ["auto", 0.0, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1.0]
90 | )
91 | return {
92 | "mlp_width": width,
93 | "mlp_depth": depth,
94 | "lr": lr,
95 | "weight_decay": weight_decay,
96 | "batch_size": batch_size,
97 | }
98 |
99 |
100 | class Ensemble(LightningModule):
101 | def __init__(self, models, is_regression):
102 | super().__init__()
103 | self.models = models
104 | self.is_regression = is_regression
105 |
106 | def predict(self, dataloader):
107 | return torch.stack([model.predict(dataloader) for model in self.models]).mean(dim=0)
108 |
109 | def predict_uncertainty(self, dataloader):
110 | if self.is_regression:
111 | if len(self.models) == 1:
112 | uncertainty = None
113 | else:
114 | uncertainty = torch.stack([model.predict(dataloader) for model in self.models]).var(dim=0)
115 | else:
116 | proba = self.predict(dataloader)[:, 0]
117 | uncertainty = torch.stack([proba, 1.0 - proba], dim=1)
118 | return uncertainty
119 |
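The `weight_decay="auto"` branch in `configure_optimizers` sets the L2 coefficient to one over the number of trainable parameters. A standalone sketch of that heuristic (the network below is hypothetical):

```python
import torch
from torch import nn

# Hypothetical network; the "auto" heuristic is weight_decay = 1 / n_trainable_params
net = nn.Sequential(nn.Linear(2048, 128), nn.ReLU(), nn.Linear(128, 1))
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1.0 / n_params)
```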
--------------------------------------------------------------------------------
/mood/model/coral.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional, List
3 |
4 | from torch import nn
5 | from mood.model.base import BaseModel
6 | from mood.model.utils import linear_interpolation
7 |
8 |
9 | class CORAL(BaseModel):
10 | """CORAL is a **domain adaptation** method.
11 |
12 | In addition to the traditional loss, adds a penalty based on the difference of the first and second moment
13 | of the source and target features.
14 |
15 | References:
16 | Sun, B., & Saenko, K. (2016, October). Deep coral: Correlation alignment for deep domain adaptation.
17 | In European conference on computer vision (pp. 443-450). Springer, Cham.
18 | https://arxiv.org/abs/1607.01719
19 | """
20 |
21 | def __init__(
22 | self,
23 | base_network: nn.Module,
24 | prediction_head: nn.Module,
25 | loss_fn: nn.Module,
26 | batch_size: int,
27 | penalty_weight: float,
28 | penalty_weight_schedule: List[int],
29 | lr: float = 1e-4,
30 | weight_decay: float = 0,
31 | ):
32 | """
33 | Args:
34 |             base_network: The neural network architecture encoding the features.
35 |             prediction_head: The neural network architecture that maps the feature
36 |                 embedding to a task-specific prediction.
37 |             loss_fn: The loss function to optimize for.
38 |             penalty_weight: The weight to multiply the penalty by.
39 |             penalty_weight_schedule: List of two integers as a rudimentary way of scheduling the
40 |                 penalty weight. The first integer is the last epoch at which the penalty weight is 0.
41 |                 The second integer is the first epoch at which the penalty weight is at its max value.
42 |                 Linearly interpolates between the two.
43 | """
44 |
45 | super().__init__(
46 | lr=lr,
47 | weight_decay=weight_decay,
48 | base_network=base_network,
49 | prediction_head=prediction_head,
50 | loss_fn=loss_fn,
51 | batch_size=batch_size,
52 | )
53 |
54 | self.penalty_weight = penalty_weight
55 | if len(penalty_weight_schedule) != 2:
56 |             raise ValueError("The penalty weight schedule needs to define two values: the start and end epoch")
57 | self.start = penalty_weight_schedule[0]
58 | self.duration = penalty_weight_schedule[1] - penalty_weight_schedule[0]
59 |
60 | def _step(self, batch, batch_idx=0, optimizer_idx=None):
61 | if self.training:
62 | batch_src, batch_tgt = batch["source"], batch["target"]
63 | (x_src, domains_src), y_true = batch_src
64 | x_tgt, domains_tgt = batch_tgt
65 | y_pred, phis_src = self.forward(x_src, return_embedding=True)
66 | _, phis_tgt = self.forward(x_tgt, return_embedding=True)
67 | loss = self._loss(y_pred, y_true, phis_src, phis_tgt)
68 |
69 | else:
70 | (x, domains), y_true = batch
71 | y_pred = self.forward(x)
72 | loss = self._loss(y_pred, y_true)
73 |
74 | self.log("loss", loss)
75 | return loss
76 |
77 | def _loss(self, y_pred, y_true, phis_src: Optional = None, phis_tgt: Optional = None):
78 | erm_loss = self.loss_fn(y_pred, y_true)
79 |
80 | if not self.training:
81 | return erm_loss
82 |
83 | penalty = self._coral_penalty(phis_src, phis_tgt)
84 |
85 | penalty_weight = linear_interpolation(
86 | self.current_epoch, self.duration, self.penalty_weight, start=self.start
87 | )
88 |
89 | loss = erm_loss + penalty_weight * penalty
90 | return loss
91 |
92 | @staticmethod
93 | def _coral_penalty(x, y):
94 | """The CORAL penalty aligns the Covariance matrix of the features across domains"""
95 | mean_x = x.mean(0, keepdim=True)
96 | mean_y = y.mean(0, keepdim=True)
97 | cent_x = x - mean_x
98 | cent_y = y - mean_y
99 | cova_x = (cent_x.T @ cent_x) / max(1, len(x) - 1)
100 | cova_y = (cent_y.T @ cent_y) / max(1, len(y) - 1)
101 |
102 | mean_diff = (mean_x - mean_y).pow(2).mean()
103 | cova_diff = (cova_x - cova_y).pow(2).mean()
104 | return mean_diff + cova_diff
105 |
106 | @staticmethod
107 | def suggest_params(trial):
108 | params = BaseModel.suggest_params(trial)
109 | params["penalty_weight"] = trial.suggest_float("penalty_weight", 0.0001, 100, log=True)
110 | params["penalty_weight_schedule"] = trial.suggest_categorical(
111 | "penalty_weight_schedule", [[0, 25], [0, 50], [0, 0], [25, 50]]
112 | )
113 | return params
114 |
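Since `_coral_penalty` is a static method, it can be exercised on its own. A sketch with random features (the shift applied to the target batch is an illustrative assumption):

```python
import torch
from mood.model.coral import CORAL

phis_src = torch.randn(32, 64)        # source-domain features
phis_tgt = torch.randn(48, 64) + 0.5  # hypothetical shifted target-domain features

# Penalizes differences in both the means and the covariance matrices
penalty = CORAL._coral_penalty(phis_src, phis_tgt)
```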
--------------------------------------------------------------------------------
/mood/model/dann.py:
--------------------------------------------------------------------------------
1 | from itertools import chain, tee
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from typing import Optional, List
6 | from torch import nn
7 |
8 | from mood.model.base import BaseModel
9 | from mood.model.nn import FCLayer
10 | from mood.model.utils import linear_interpolation
11 |
12 |
13 | class DANN(BaseModel):
14 | """Domain Adversarial Neural Network (DANN) is a **domain adaptation** method.
15 |
16 | Adversarial framework that includes a prediction and discriminator network. The goal of the discriminator
17 | is to classify the domain (source or target) from the hidden embedding. The goal of the predictor is to achieve
18 | a good task-specific performance. By optimizing these in an adversarial fashion, the goal is to have
19 | domain-invariant features.
20 |
21 | References:
22 | Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., ... & Lempitsky, V. (2016).
23 | Domain-adversarial training of neural networks. The journal of machine learning research, 17(1), 2096-2030.
24 | https://arxiv.org/abs/1505.07818
25 | """
26 |
27 | def __init__(
28 | self,
29 | base_network: nn.Module,
30 | prediction_head: FCLayer,
31 | loss_fn: nn.Module,
32 | batch_size: int,
33 | penalty_weight: float,
34 | penalty_weight_schedule: List[int],
35 | discriminator_network: Optional[nn.Module] = None,
36 | lr: float = 1e-3,
37 | discr_lr: float = 1e-3,
38 | weight_decay: float = "auto",
39 | discr_weight_decay: float = "auto",
40 | lambda_reg: float = 0.1,
41 | n_discr_steps_per_predictor_step=3,
42 | ):
43 | """
44 | Args:
45 |             base_network: The neural network architecture encoding the features.
46 |             prediction_head: The neural network architecture that maps the feature
47 |                 embedding to a task-specific prediction.
48 |             loss_fn: The loss function to optimize for.
49 |             penalty_weight: The weight to multiply the penalty by.
50 |             penalty_weight_schedule: List of two integers as a rudimentary way of scheduling the
51 |                 penalty weight. The first integer is the last epoch at which the penalty weight is 0.
52 |                 The second integer is the first epoch at which the penalty weight is at its max value.
53 |                 Linearly interpolates between the two.
54 |             discriminator_network: The discriminator network that predicts the domain from the hidden embedding.
55 |             lambda_reg: An additional weighting factor for the penalty. Following the implementation of DomainBed.
56 | """
57 | super().__init__(
58 | lr=lr,
59 | weight_decay=weight_decay,
60 | base_network=base_network,
61 | prediction_head=prediction_head,
62 | loss_fn=loss_fn,
63 | batch_size=batch_size,
64 | )
65 |
66 | self._discriminator_network = discriminator_network
67 | if self._discriminator_network is None:
68 | self._discriminator_network = FCLayer(prediction_head.input_dim, 2, activation=None)
69 |
70 | self.penalty_weight = penalty_weight
71 | if len(penalty_weight_schedule) != 2:
72 |             raise ValueError("The penalty weight schedule needs to define two values: the start and end epoch")
73 | self.start = penalty_weight_schedule[0]
74 | self.duration = penalty_weight_schedule[1] - penalty_weight_schedule[0]
75 |
76 | self._discriminator_loss = nn.CrossEntropyLoss()
77 | self.discr_lr = discr_lr
78 | self.discr_l2 = discr_weight_decay
79 | self.lambda_reg = lambda_reg
80 | self._n_discr_steps_per_predictor_step = n_discr_steps_per_predictor_step
81 |
82 | @staticmethod
83 | def get_optimizer(parameters, lr, weight_decay, monitor: str = "val_loss"):
84 | if weight_decay == "auto":
85 | parameters, parameters_copy = tee(parameters)
86 | weight_decay = 1.0 / sum(p.numel() for p in parameters_copy if p.requires_grad)
87 |
88 | optimizer = torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
89 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
90 | return {
91 | "optimizer": optimizer,
92 | "lr_scheduler": {
93 | "scheduler": lr_scheduler,
94 | "interval": "epoch",
95 | "frequency": 1,
96 | "monitor": monitor,
97 | "strict": True,
98 | },
99 | }
100 |
101 | def configure_optimizers(self):
102 | optimizer_predictor = self.get_optimizer(
103 | chain(self.base_network.parameters(), self.prediction_head.parameters()),
104 | self.lr,
105 | self.l2,
106 | )
107 | optimizer_predictor["frequency"] = self._n_discr_steps_per_predictor_step
108 |
109 | optimizer_discriminator = self.get_optimizer(
110 | chain(self.base_network.parameters(), self._discriminator_network.parameters()),
111 | self.discr_lr,
112 | self.discr_l2,
113 | )
114 | optimizer_discriminator["frequency"] = 1
115 |
116 | return optimizer_predictor, optimizer_discriminator
117 |
118 | def forward(
119 | self, x, domains: Optional = None, return_embedding: bool = False, return_discriminator: bool = False
120 | ):
121 | input_embeddings = self.base_network(x)
122 | label = self.prediction_head(input_embeddings)
123 |
124 | ret = (label,)
125 | if return_embedding:
126 | ret += (input_embeddings,)
127 |
128 | if return_discriminator:
129 | discr_out = self._discriminator_network(input_embeddings)
130 | ret += (discr_out,)
131 |
132 | if len(ret) == 1:
133 | return ret[0]
134 | return ret
135 |
136 | def _step(self, batch, batch_idx=0, optimizer_idx=None):
137 | if not self.training:
138 | (x, domains), y_true = batch
139 | y_pred = self.forward(x)
140 | loss = self.loss_fn(y_pred, y_true)
141 | self.log("loss", loss)
142 | return loss
143 |
144 | batch_src, batch_tgt = batch["source"], batch["target"]
145 |
146 | (x_src, domains_src), y_true = batch_src
147 | x_tgt, domains_tgt = batch_tgt
148 |
149 | y_pred, input_embeddings_src, discr_out_src = self.forward(
150 | x_src, return_embedding=True, return_discriminator=True
151 | )
152 | _, input_embeddings_tgt, discr_out_tgt = self.forward(
153 | x_tgt, return_embedding=True, return_discriminator=True
154 | )
155 |
156 | erm_loss = self.loss_fn(y_pred, y_true)
157 |
158 | # For losses w.r.t. the discriminator
159 | domain_pred = torch.cat([discr_out_src, discr_out_tgt], dim=0)
160 | domain_true = torch.cat([torch.zeros_like(discr_out_src), torch.ones_like(discr_out_tgt)], dim=0)
161 | loss = self._discriminator_loss(domain_pred, domain_true)
162 |
163 | domain_pred_softmax = F.softmax(domain_pred, dim=1)
164 | penalty_loss = domain_pred_softmax[:, domain_true.long()].sum()
165 |
166 |         # When optimizing the discriminator, PTL automatically sets requires_grad to False
167 |         # for all parameters not handled by the current optimizer. However, we still need it to
168 |         # be True to be able to compute the gradient penalty
169 | if optimizer_idx == 1:
170 | for param in self.base_network.parameters():
171 | param.requires_grad = True
172 |
173 | penalty = self.gradient_reversal(penalty_loss, [input_embeddings_src, input_embeddings_tgt])
174 |
175 | if optimizer_idx == 1:
176 | for param in self.base_network.parameters():
177 | param.requires_grad = False
178 |
179 | penalty_weight = linear_interpolation(
180 | self.current_epoch, self.duration, self.penalty_weight, start=self.start
181 | )
182 | loss += penalty * penalty_weight
183 |
184 | if optimizer_idx == 1:
185 | # Train the discriminator
186 | self.log("discriminator_loss", loss)
187 | return loss
188 | else:
189 | # Train the predictor and add a penalty for making features domain-invariant
190 | loss = erm_loss + (self.lambda_reg * -loss)
191 | self.log("predictive_loss", loss)
192 | return loss
193 |
194 | @staticmethod
195 | def gradient_reversal(loss, inputs):
196 | grad = torch.cat(torch.autograd.grad(loss, inputs, create_graph=True))
197 | return (grad**2).sum(dim=1).mean(dim=0)
198 |
199 | @staticmethod
200 | def suggest_params(trial):
201 | params = BaseModel.suggest_params(trial)
202 | params["penalty_weight"] = trial.suggest_float("penalty_weight", 1e-10, 1.0, log=True)
203 | params["penalty_weight_schedule"] = trial.suggest_categorical(
204 | "penalty_weight_schedule", [[0, 25], [0, 50], [0, 0], [25, 50]]
205 | )
206 | params["discr_lr"] = trial.suggest_float("discr_lr", 1e-8, 1.0, log=True)
207 | params["discr_weight_decay"] = trial.suggest_categorical(
208 | "discr_weight_decay", ["auto", 0.0, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1.0]
209 | )
210 | params["lambda_reg"] = trial.suggest_float("lambda_reg", 0.001, 10, log=True)
211 | params["n_discr_steps_per_predictor_step"] = trial.suggest_int(
212 | "n_discr_steps_per_predictor_step", 1, 5
213 | )
214 | return params
215 |
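A minimal sketch of constructing a DANN model from the building blocks in this repo (all dimensions and hyperparameters below are hypothetical):

```python
from torch import nn
from mood.model.dann import DANN
from mood.model.nn import get_simple_mlp, FCLayer

base = get_simple_mlp(input_size=2048, width=128, depth=2, out_size=None)  # feature extractor
head = FCLayer(base.output_dim, 1, activation="Sigmoid")                   # binary prediction head

model = DANN(
    base_network=base,
    prediction_head=head,
    loss_fn=nn.BCELoss(),
    batch_size=64,
    penalty_weight=1.0,
    penalty_weight_schedule=[0, 25],  # ramp the penalty from epoch 0 to epoch 25
)
```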
--------------------------------------------------------------------------------
/mood/model/erm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from mood.model.base import BaseModel
3 |
4 |
5 | class ERM(BaseModel):
6 | """Empirical Risk Minimization
7 |
8 | The "vanilla" neural network. Updates the weight to minimize the loss of the batch.
9 |
10 | References:
11 | Vapnik, V. N. (1998). Statistical Learning Theory. Wiley-Interscience.
12 | https://www.wiley.com/en-fr/Statistical+Learning+Theory-p-9780471030034
13 | """
14 |
15 | def _step(self, batch, batch_idx, optimizer_idx: Optional[int] = None):
16 | (x, domain), y_true = batch
17 | y_pred = self.forward(x)
18 | loss = self._loss(y_pred, y_true)
19 | self.log("loss", loss, prog_bar=True)
20 | return loss
21 |
22 | def _loss(self, y_pred, y_true):
23 | return self.loss_fn(y_pred, y_true)
24 |
--------------------------------------------------------------------------------
/mood/model/ib_erm.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from mood.model.base import BaseModel
7 | from mood.model.utils import linear_interpolation
8 |
9 |
10 | class InformationBottleneckERM(BaseModel):
11 | """Information Bottleneck ERM (IB-ERM) is a **domain generalization** method.
12 |
13 | Similar to MTL, computes the loss per domain and computes the mean of these domain-specific losses.
14 | Additionally, adds a penalty based on the variance of the feature dimensions.
15 |
16 | References:
17 | Ahuja, K., Caballero, E., Zhang, D., Gagnon-Audet, J. C., Bengio, Y., Mitliagkas, I., & Rish, I. (2021).
18 | Invariance principle meets information bottleneck for out-of-distribution generalization.
19 | Advances in Neural Information Processing Systems, 34, 3438-3450.
20 | https://arxiv.org/abs/2106.06607
21 | """
22 |
23 | def __init__(
24 | self,
25 | base_network: nn.Module,
26 | prediction_head: nn.Module,
27 | loss_fn: nn.Module,
28 | batch_size: int,
29 | penalty_weight: float,
30 | penalty_weight_schedule: List[int],
31 | lr=1e-3,
32 | weight_decay=0,
33 | ):
34 | """
35 | Args:
36 |             base_network: The neural network architecture encoding the features.
37 |             prediction_head: The neural network architecture that maps the feature
38 |                 embedding to a task-specific prediction.
39 |             loss_fn: The loss function to optimize for.
40 |             penalty_weight: The weight to multiply the penalty by.
41 |             penalty_weight_schedule: List of two integers as a rudimentary way of scheduling the
42 |                 penalty weight. The first integer is the last epoch at which the penalty weight is 0.
43 |                 The second integer is the first epoch at which the penalty weight is at its max value.
44 |                 Linearly interpolates between the two.
45 | """
46 |
47 | super().__init__(
48 | lr=lr,
49 | weight_decay=weight_decay,
50 | base_network=base_network,
51 | prediction_head=prediction_head,
52 | loss_fn=loss_fn,
53 | batch_size=batch_size,
54 | )
55 |
56 | self.penalty_weight = penalty_weight
57 | if len(penalty_weight_schedule) != 2:
58 |             raise ValueError("The penalty weight schedule needs to define two values: the start and end epoch")
59 | self.start = penalty_weight_schedule[0]
60 | self.duration = penalty_weight_schedule[1] - penalty_weight_schedule[0]
61 |
62 | def _step(self, batch, batch_idx=0, optimizer_idx=None):
63 | phis = []
64 | erm_loss = 0
65 |
66 | for mini_batch in batch:
67 | (xs, _), y_true = mini_batch
68 | y_pred, phi = self.forward(xs, return_embedding=True)
69 | erm_loss += self.loss_fn(y_pred, y_true)
70 | phis.append(phi)
71 |
72 | erm_loss /= len(batch)
73 | phis = torch.cat(phis, dim=0)
74 | loss = self._loss(erm_loss, phis)
75 | self.log("loss", loss)
76 | return loss
77 |
78 | def _loss(self, erm_loss, phis):
79 | if not self.training:
80 | return erm_loss
81 |
82 | penalty_weight = linear_interpolation(
83 | step=self.current_epoch,
84 | duration=self.duration,
85 | max_value=self.penalty_weight,
86 | start=self.start,
87 | )
88 |
89 | penalty = 0
90 | for i in range(len(phis)):
91 | # Add the Information Bottleneck penalty
92 | penalty += self.ib_penalty(phis[i])
93 | penalty /= len(phis)
94 |
95 | loss = erm_loss + penalty_weight * penalty
96 | return loss
97 |
98 | @staticmethod
99 | def ib_penalty(features):
100 | if len(features) == 1:
101 | return 0.0
102 | return features.var(dim=0).mean()
103 |
104 | @staticmethod
105 | def suggest_params(trial):
106 | params = BaseModel.suggest_params(trial)
107 | params["penalty_weight"] = trial.suggest_float("penalty_weight", 0.0001, 100, log=True)
108 | params["penalty_weight_schedule"] = trial.suggest_categorical(
109 | "penalty_weight_schedule", [[0, 25], [0, 50], [0, 0], [25, 50]]
110 | )
111 | return params
112 |
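The IB penalty itself is just the mean variance of the feature dimensions; a standalone sketch:

```python
import torch
from mood.model.ib_erm import InformationBottleneckERM

phis = torch.randn(32, 64)  # hypothetical feature batch
penalty = InformationBottleneckERM.ib_penalty(phis)  # equals phis.var(dim=0).mean()
```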
--------------------------------------------------------------------------------
/mood/model/mixup.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from mood.model.base import BaseModel
7 | from mood.model.utils import linear_interpolation
8 |
9 |
10 | class Mixup(BaseModel):
11 | """Mixup is a **domain adaptation** method.
12 |
13 |     Mixup interpolates both the features and (pseudo-)targets inter- and intra-domain and trains on
14 |     these interpolated samples instead.
15 |
16 | References:
17 | Yan, S., Song, H., Li, N., Zou, L., & Ren, L. (2020). Improve unsupervised domain adaptation with mixup training.
18 | https://arxiv.org/abs/2001.00677
19 | """
20 |
21 | def __init__(
22 | self,
23 | base_network: nn.Module,
24 | prediction_head: nn.Module,
25 | loss_fn: nn.Module,
26 | batch_size: int,
27 | penalty_weight: float,
28 | penalty_weight_schedule: List[int],
29 | lr: float = 1e-3,
30 | weight_decay: float = "auto",
31 | augmentation_std: float = 0.1,
32 | no_augmentations: int = 10,
33 | alpha: float = 0.1,
34 | ):
35 | """
36 | Args:
37 |             base_network: The neural network architecture encoding the features.
38 |             prediction_head: The neural network architecture that maps the feature
39 |                 embedding to a task-specific prediction.
40 |             loss_fn: The loss function to optimize for.
41 |             penalty_weight: The weight to multiply the penalty by.
42 |             penalty_weight_schedule: List of two integers as a rudimentary way of scheduling the
43 |                 penalty weight. The first integer is the last epoch at which the penalty weight is 0.
44 |                 The second integer is the first epoch at which the penalty weight is at its max value.
45 |                 Linearly interpolates between the two.
46 |             augmentation_std: The standard deviation of the multiplicative noise used to augment
47 |                 each sample.
48 |             no_augmentations: The number of augmentations used to compute pseudo-labels for the
49 |                 unlabeled samples of the target domain.
50 |             alpha: The parameter of the Beta distribution used to sample the interpolation factor
51 | """
52 | super().__init__(
53 | lr=lr,
54 | weight_decay=weight_decay,
55 | base_network=base_network,
56 | prediction_head=prediction_head,
57 | loss_fn=loss_fn,
58 | batch_size=batch_size,
59 | )
60 |
61 | self._classification_loss = torch.nn.BCELoss()
62 | self._target_domain_loss = torch.nn.MSELoss()
63 |
64 | self._augmentation_std = augmentation_std
65 | self._no_augmentations = no_augmentations
66 |
67 | self.distribution_lambda = torch.distributions.Beta(alpha, alpha)
68 |
69 | self.penalty_weight = penalty_weight
70 | if len(penalty_weight_schedule) != 2:
71 |             raise ValueError("The penalty weight schedule needs to define two values: the start and end epoch")
72 | self.start = penalty_weight_schedule[0]
73 | self.duration = penalty_weight_schedule[1] - penalty_weight_schedule[0]
74 |
75 | def _step(self, batch, batch_idx=0, optimizer_idx=None):
76 | if not self.training:
77 | (x, domains), y_true = batch
78 | y_pred = self.forward(x)
79 | loss = self.loss_fn(y_pred, y_true)
80 | self.log("loss", loss)
81 | return loss
82 |
83 | batch_src, batch_tgt = batch["source"], batch["target"]
84 | (x_src, _), y_src = batch_src
85 | x_tgt, _ = batch_tgt
86 | y_tgt = self._get_pseudo_labels(x_tgt)
87 |
88 | penalty_weight = linear_interpolation(
89 | step=self.current_epoch,
90 | duration=self.duration,
91 | max_value=self.penalty_weight,
92 | start=self.start,
93 | )
94 |
95 | loss = 0
96 |
97 | # Inter-domain, from source to target
98 | loss += self._loss(x_src, x_tgt, y_src, y_tgt, True)
99 | # Intra-domain, from source to source
100 | loss += self._loss(x_src, x_src, y_src, y_src, False)
101 | # Intra-domain, from target to target
102 | # Quote from the paper: "We set a linear schedule for w_t in training,
103 | # from 0 to a predefined maximum value. From initial experiments, we observe that
104 | # the algorithm is robust to other weighting parameters. Therefore we only
105 | # search w_t while simply fixing all other weightings to 1."
106 | loss += penalty_weight * self._loss(x_tgt, x_tgt, y_tgt, y_tgt, False)
107 |
108 | self.log("loss", loss)
109 | return loss
110 |
111 | def _get_pseudo_labels(self, xs):
112 | augmented_labels = []
113 | for i in range(self._no_augmentations):
114 | sample = xs * torch.normal(1, self._augmentation_std, xs.size(), device=xs.device)
115 | augmented_labels.append(self.forward(sample))
116 | pseudo_labels = torch.stack(augmented_labels).mean(0).squeeze()
117 | return torch.atleast_1d(pseudo_labels)
118 |
119 | def _loss(self, x_src, x_tgt, y_src, y_tgt, inter_domain: bool):
120 | lam = self.distribution_lambda.sample()
121 | lam_prime = torch.max(lam, 1.0 - lam)
122 |
123 | x_src, x_tgt, y_src, y_tgt = self._get_random_pairs(x_src, x_tgt, y_src, y_tgt, round(len(x_src) / 3))
124 | x_st, y_st = self._mixup(x_src, x_tgt, y_src, y_tgt, lam_prime)
125 | y_pred_st, phi_st = self.forward(x_st, return_embedding=True)
126 |
127 | if inter_domain:
128 | # Predictive loss
129 | loss_q = self.loss_fn(y_pred_st, y_st)
130 |
131 | # Consistency regularizer
132 | y_pred_s, phi_s = self.forward(x_src, return_embedding=True)
133 | y_pred_t, phi_t = self.forward(x_tgt, return_embedding=True)
134 | zi_st = lam_prime * phi_s + (1.0 - lam_prime) * phi_t
135 | loss_z = torch.norm(zi_st - phi_st, dim=0).mean()
136 | loss = loss_q + loss_z
137 |
138 | # Intra target domain
139 | else:
140 | loss = self.loss_fn(y_pred_st, y_st)
141 | return loss
142 |
143 | def _get_random_pairs(self, x_src, x_tgt, y_src, y_tgt, size):
144 | assert len(x_src) == len(y_src)
145 | assert len(x_tgt) == len(y_tgt)
146 |
147 | size = max(torch.tensor(1), size)
148 | indices = torch.multinomial(torch.ones(len(x_src), device=x_src.device), size, replacement=True)
149 | x_src = x_src[indices]
150 | y_src = y_src[indices]
151 |
152 | indices = torch.multinomial(torch.ones(len(x_tgt), device=x_tgt.device), size, replacement=True)
153 | x_tgt = x_tgt[indices]
154 | y_tgt = y_tgt[indices]
155 |
156 | return x_src, x_tgt, y_src, y_tgt
157 |
158 | def _mixup(self, x_s, x_t, y_s, y_t, lam_prime):
159 | xi_st = lam_prime * x_s + (1.0 - lam_prime) * x_t
160 | yi_st = lam_prime * y_s + (1.0 - lam_prime) * y_t
161 | return xi_st, yi_st
162 |
163 | @staticmethod
164 | def suggest_params(trial):
165 | params = BaseModel.suggest_params(trial)
166 | params["penalty_weight"] = trial.suggest_float("penalty_weight", 0.0001, 100, log=True)
167 | params["penalty_weight_schedule"] = trial.suggest_categorical(
168 | "penalty_weight_schedule", [[0, 25], [0, 50], [0, 0], [25, 50]]
169 | )
170 | params["augmentation_std"] = trial.suggest_float("augmentation_std", 0.001, 0.15)
171 | params["no_augmentations"] = trial.suggest_categorical("no_augmentations", [3, 5, 10])
172 | params["alpha"] = trial.suggest_float("alpha", 0.0, 1.0)
173 | return params
174 |
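A standalone sketch of the interpolation performed in `_mixup` (toy tensors; `lam_prime` is biased towards the first argument, as in `_loss`):

```python
import torch

alpha = 0.1
lam = torch.distributions.Beta(alpha, alpha).sample()
lam_prime = torch.max(lam, 1.0 - lam)

x_src, x_tgt = torch.randn(8, 2048), torch.randn(8, 2048)  # hypothetical features
y_src, y_tgt = torch.rand(8), torch.rand(8)                # labels / pseudo-labels

x_mix = lam_prime * x_src + (1.0 - lam_prime) * x_tgt
y_mix = lam_prime * y_src + (1.0 - lam_prime) * y_tgt
```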
--------------------------------------------------------------------------------
/mood/model/mtl.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import nn
6 | from mood.model.base import BaseModel
7 |
8 |
9 | class MTL(BaseModel):
10 | """Marginal Transfer Learning (MTL) is a **domain generalization** method.
11 |
12 | MTL uses a representation of the domain.
13 | Additionally, rather than computing the loss for the entire batch at once, it computes the loss
14 |     for each domain individually and then returns the mean of these domain-specific losses.
15 |
16 | Originally proposed as a kernel method, it is here implemented as a Neural Network.
17 |
18 | References:
19 | Blanchard, G., Deshmukh, A. A., Dogan, U., Lee, G., & Scott, C. (2017).
20 | Domain generalization by marginal transfer learning.
21 | https://arxiv.org/abs/1711.07910
22 | """
23 |
24 | def __init__(
25 | self,
26 | base_network: nn.Module,
27 | prediction_head: nn.Module,
28 | loss_fn: nn.Module,
29 | batch_size: int,
30 | lr: float = 1e-3,
31 | weight_decay: float = 0,
32 | ):
33 | """
34 | Args:
35 |             base_network: The neural network architecture encoding the features. A deep copy
36 |                 of this network is used to encode the domain.
37 |             prediction_head: The neural network architecture that takes the concatenated
38 |                 representation of the domain and features and returns a task-specific prediction.
39 |             loss_fn: The loss function to optimize for.
40 | """
41 | super().__init__(
42 | lr=lr,
43 | weight_decay=weight_decay,
44 | base_network=base_network,
45 | loss_fn=loss_fn,
46 | prediction_head=prediction_head,
47 | batch_size=batch_size,
48 | )
49 | self.domain_network = deepcopy(base_network)
50 |
51 | def forward(self, x, domains: Optional = None, return_embedding: bool = False):
52 | input_embeddings = self.base_network(x)
53 | domain_embeddings = self.domain_network(domains)
54 | # add noise to domains to avoid overfitting
55 | domain_embeddings = domain_embeddings + torch.randn_like(domain_embeddings)
56 | embeddings = torch.cat((domain_embeddings, input_embeddings), 1)
57 | label = self.prediction_head(embeddings)
58 | out = (label, embeddings) if return_embedding else label
59 | return out
60 |
61 | def _step(self, batch, batch_idx=0, optimizer_idx=None):
62 | xs = torch.cat([xs for (xs, _), _ in batch], dim=0)
63 | domains = torch.cat([ds for (_, ds), _ in batch], dim=0)
64 | y_true = torch.cat([ys for _, ys in batch], dim=0)
65 | ns = [len(ys) for _, ys in batch]
66 |
67 | y_pred, phis = self.forward(xs, domains, return_embedding=True)
68 |
69 | loss = self._loss(y_pred, y_true, ns)
70 | self.log("loss", loss)
71 | return loss
72 |
73 | def _loss(self, y_pred, y_true, ns):
74 | loss, i = 0, 0
75 | for n in ns:
76 | loss += self.loss_fn(y_pred[i : i + n], y_true[i : i + n])
77 | i += n
78 | loss = loss / len(ns)
79 | return loss
80 |
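A minimal construction sketch; note that because `forward` concatenates the feature and domain embeddings, the prediction head input must be twice the embedding size (dimensions below are hypothetical):

```python
import torch
from mood.model.mtl import MTL
from mood.model.nn import get_simple_mlp, FCLayer

base = get_simple_mlp(input_size=2048, width=128, depth=2, out_size=None)
head = FCLayer(2 * base.output_dim, 1, activation=None)
model = MTL(base_network=base, prediction_head=head, loss_fn=torch.nn.MSELoss(), batch_size=64)

x = torch.randn(16, 2048)        # features
domains = torch.randn(16, 2048)  # hypothetical domain representation (same input size as base)
y = model(x, domains)
```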
--------------------------------------------------------------------------------
/mood/model/nn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 | from torch import nn
6 | from torch.nn import Flatten
7 |
8 | from mood.model.utils import get_activation
9 |
10 |
11 | class FCLayer(nn.Module):
12 | r"""
13 | A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module.
14 | The order in which transformations are applied is:
15 |
16 | #. Dense Layer
17 | #. Activation
18 | #. Dropout (if applicable)
19 | #. Batch Normalization (if applicable)
20 |
21 | Arguments
22 | ----------
23 | in_size: int
24 | Input dimension of the layer (the torch.nn.Linear)
25 | out_size: int
26 |         Output dimension of the layer (the torch.nn.Linear).
27 | dropout: float, optional
28 | The ratio of units to dropout. No dropout by default.
29 | (Default value = 0.)
30 | activation: str or callable, optional
31 |         Activation function to use. Should be one supported by :func:`mood.model.utils.get_activation`.
32 | (Default value = relu)
33 | b_norm: bool, optional
34 | Whether to use batch normalization
35 | (Default value = False)
36 | bias: bool, optional
37 | Whether to enable bias in for the linear layer.
38 | (Default value = True)
39 | init_fn: callable, optional
40 | Initialization function to use for the weight of the layer. Default is
41 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` with :math:`k=\frac{1}{ \text{node_feats_dim}}`
42 | (Default value = None)
43 |
44 | Attributes
45 | ----------
46 |     dropout: float
47 |         The ratio of units to dropout.
48 |     b_norm: bool
49 | Whether to use batch normalization
50 | linear: torch.nn.Linear
51 | The linear layer
52 | activation: the torch.nn.Module
53 | The activation layer
54 | init_fn: function
55 | Initialization function used for the weight of the layer
56 | in_size: int
57 | Input dimension of the linear layer
58 | out_size: int
59 | Output dimension of the linear layer
60 |
61 | """
62 |
63 | def __init__(
64 | self, in_size, out_size, activation="relu", dropout=0.0, b_norm=False, bias=True, init_fn=None
65 | ):
66 | super(FCLayer, self).__init__()
67 |         # Although I disagree with storing locals() like this, it is simple enough
68 |         # and robust if we trust the user base
69 | self._params = locals()
70 | self.in_size = in_size
71 | self.out_size = out_size
72 | activation = get_activation(activation)
73 | linear = nn.Linear(in_size, out_size, bias=bias)
74 | if init_fn:
75 | init_fn(linear)
76 | layers = [linear]
77 | if activation is not None:
78 | layers.append(activation)
79 | if dropout:
80 | layers.append(nn.Dropout(p=dropout))
81 | if b_norm:
82 | layers.append(nn.BatchNorm1d(out_size))
83 | self.net = nn.Sequential(*layers)
84 |
85 | @property
86 | def output_dim(self):
87 | return self.out_size
88 |
89 | @property
90 | def input_dim(self):
91 | return self.in_size
92 |
93 | def forward(self, x):
94 | return self.net(x)
95 |
96 |
97 | class MLP(nn.Module):
98 | r"""
99 | Feature extractor using a Fully Connected Neural Network
100 |
101 | Arguments
102 | ----------
103 | input_size: int
104 | size of the input
105 | hidden_sizes: int list or int
106 | size of the hidden layers
107 | out_size: int list or int or None
108 | if None, uses the last hidden size as the output
109 | activation: str or callable
110 |         activation function. Should be supported by :func:`mood.model.utils.get_activation`
111 | (Default value = 'relu')
112 |     b_norm: bool, optional
113 | Whether batch norm is used or not.
114 | (Default value = False)
115 | dropout: float, optional
116 | Dropout probability to regularize the network. No dropout by default.
117 | (Default value = .0)
118 |
119 | Attributes
120 | ----------
121 | extractor: torch.nn.Module
122 | The underlying feature extractor of the model.
123 | """
124 |
125 | def __init__(
126 | self,
127 | input_size,
128 | hidden_sizes=None,
129 | out_size=None,
130 | activation="ReLU",
131 | out_activation=None,
132 | b_norm=False,
133 | l_norm=False,
134 | dropout=0.0,
135 | ):
136 | super(MLP, self).__init__()
137 | self._params = locals()
138 | layers = []
139 |
140 | if out_size is None and hidden_sizes is None:
141 |             raise ValueError("You need to specify either hidden_sizes or out_size")
142 |
143 | if out_size is None:
144 | out_size = hidden_sizes[-1]
145 | hidden_sizes = hidden_sizes[:-1]
146 |
147 | if hidden_sizes is None:
148 | hidden_sizes = []
149 |
150 | in_ = input_size
151 | if l_norm:
152 | layers.append(nn.LayerNorm(input_size))
153 |
154 | for i, out_ in enumerate(hidden_sizes):
155 | layer = FCLayer(
156 | in_,
157 | out_,
158 | activation=activation,
159 | b_norm=False,
160 | dropout=dropout,
161 | )
162 | layers.append(layer)
163 | in_ = out_
164 |
165 | layers.append(FCLayer(in_, out_size, activation=out_activation, b_norm=b_norm, dropout=False))
166 | self.extractor = nn.Sequential(*layers)
167 |
168 | @property
169 | def output_dim(self):
170 | return self.extractor[-1].output_dim
171 |
172 | @property
173 | def input_dim(self):
174 | return self.extractor[0].input_dim
175 |
176 | def forward(self, x):
177 | x = Flatten()(x)
178 | res = self.extractor(x)
179 | return res
180 |
181 |
182 | def get_simple_mlp(
183 | input_size: int,
184 | width: int = 0,
185 | depth: int = 0,
186 | out_size: Optional[int] = 1,
187 | is_regression: Optional[bool] = None,
188 | ):
189 | if out_size is not None and not isinstance(is_regression, bool):
190 | raise TypeError("Specify is_regression to be True or False")
191 |
192 | if out_size is None:
193 | out_activation = "ReLU"
194 | else:
195 | out_activation = None if is_regression else "Sigmoid"
196 |
197 | return MLP(
198 | input_size=input_size,
199 | hidden_sizes=[width] * depth,
200 | out_size=out_size,
201 | activation="ReLU",
202 | out_activation=out_activation,
203 | )
204 |
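Two typical uses of `get_simple_mlp` (argument values are illustrative):

```python
from mood.model.nn import get_simple_mlp

# A binary classifier: is_regression=False selects a Sigmoid output activation
clf = get_simple_mlp(input_size=2048, width=256, depth=3, out_size=1, is_regression=False)

# A feature extractor: out_size=None keeps ReLU activations on the last layer
extractor = get_simple_mlp(input_size=2048, width=256, depth=3, out_size=None)
```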
--------------------------------------------------------------------------------
/mood/model/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | import torch
3 |
4 |
5 | def linear_interpolation(step, duration, max_value, start: int = 0, min_value: float = 0):
6 | if max_value < min_value:
7 | raise ValueError("max_value cannot be smaller than min value")
8 | if step < start:
9 | return min_value
10 | if step >= start + duration:
11 | return max_value
12 | step_size = (max_value - min_value) / duration
13 | step = step - start
14 |     return min_value + step_size * step
15 |
16 |
17 | def get_activation(activation_spec):
18 | if isinstance(activation_spec, Callable):
19 | return activation_spec
20 | if activation_spec is None:
21 | return None
22 |
23 | activation_fs = vars(torch.nn.modules.activation)
24 | for activation in activation_fs:
25 | if activation.lower() == activation_spec.lower():
26 | return activation_fs[activation]()
27 |
28 | raise ValueError(f"{activation_spec} is not a valid activation function")
29 |
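A worked example of the penalty-weight schedule implemented by `linear_interpolation` (the numbers are illustrative):

```python
from mood.model.utils import linear_interpolation

# Zero until epoch 10, then a linear ramp to 1.0 at epoch 10 + 40 = 50
for epoch in (0, 10, 30, 50, 100):
    w = linear_interpolation(step=epoch, duration=40, max_value=1.0, start=10)
    # -> 0.0, 0.0, 0.5, 1.0, 1.0
```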
--------------------------------------------------------------------------------
/mood/model/vrex.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional, List
3 |
4 | import torch
5 | from torch import nn
6 |
7 | from mood.model.base import BaseModel
8 | from mood.model.utils import linear_interpolation
9 |
10 |
11 | class VREx(BaseModel):
12 | """Variance Risk Extrapolation (VREx) is a **domain generalization** method.
13 |
14 | Similar to MTL, computes the loss per domain and computes the mean of these domain-specific losses.
15 | Additionally, following equation 8 of the paper, returns an additional penalty based on the variance
16 | of the domain specific losses.
17 |
18 | References:
19 | Krueger, D., Caballero, E., Jacobsen, J. H., Zhang, A., Binas, J., Zhang, D., ... & Courville, A. (2021, July).
20 | Out-of-distribution generalization via risk extrapolation (rex).
21 | In International Conference on Machine Learning (pp. 5815-5826). PMLR.
22 | https://arxiv.org/abs/2003.00688
23 | """
24 |
25 | def __init__(
26 | self,
27 | base_network: nn.Module,
28 | prediction_head: nn.Module,
29 | loss_fn: nn.Module,
30 | batch_size: int,
31 | penalty_weight: float,
32 | penalty_weight_schedule: List[int],
33 | lr=1e-3,
34 | weight_decay=0,
35 | ):
36 | """
37 | Args:
38 | base_network: The neural network architecture endoing the features
39 | prediction_head: The neural network architecture that takes the concatenated
40 | representation of the domain and features and returns a task-specific prediction.
41 | loss_fn: The loss function to optimize for.
42 | penalty_weight: The weight to multiply the penalty by.
43 | penalty_weight_schedule: List of two integers as a very rudimental way of scheduling the
44 | penalty weight. The first integer is the last epoch at which the penalty weight is 0.
45 | The second integer is the first epoch at which the penalty weight is its max value.
46 | Linearly interpolates between the two.
47 | """
48 | super().__init__(
49 | lr=lr,
50 | weight_decay=weight_decay,
51 | base_network=base_network,
52 | prediction_head=prediction_head,
53 | loss_fn=loss_fn,
54 | batch_size=batch_size,
55 | )
56 |
57 | self.penalty_weight = penalty_weight
58 | if len(penalty_weight_schedule) != 2:
59 | raise ValueError("The penalty weight schedule needs to define two values; The start and end step")
60 | self.start = penalty_weight_schedule[0]
61 | self.duration = penalty_weight_schedule[1] - penalty_weight_schedule[0]
62 |
63 | def _step(self, batch, batch_idx=0, optimizer_idx=None):
64 | erm_losses = []
65 | for mini_batch in batch:
66 | (xs, _), y_true = mini_batch
67 | y_pred = self.forward(xs, return_embedding=False)
68 | erm_losses.append(self.loss_fn(y_pred, y_true))
69 |
70 | penalty_weight = linear_interpolation(
71 | step=self.current_epoch,
72 | duration=self.duration,
73 | max_value=self.penalty_weight,
74 | start=self.start,
75 | )
76 |
77 | erm_losses = torch.stack(erm_losses)
78 | erm_loss = erm_losses.mean()
79 |
80 | # NOTE: A batch can have just a single domain
81 | if len(batch) == 1:
82 | loss = erm_loss
83 | else:
84 | rex_penalty = erm_losses.var()
85 | loss = erm_loss + penalty_weight * rex_penalty
86 |
87 | self.log("loss", loss)
88 | return loss
89 |
90 | @staticmethod
91 | def suggest_params(trial):
92 | params = BaseModel.suggest_params(trial)
93 | params["penalty_weight"] = trial.suggest_float("penalty_weight", 0.0001, 100, log=True)
94 | params["penalty_weight_schedule"] = trial.suggest_categorical(
95 | "penalty_weight_schedule", [[0, 25], [0, 50], [0, 0], [25, 50]]
96 | )
97 | return params
98 |
--------------------------------------------------------------------------------
/mood/model_space.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from typing import Union
4 |
5 | from sklearn.neural_network import MLPRegressor, MLPClassifier
6 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
7 | from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier
8 | from sklearn.gaussian_process.kernels import PairwiseKernel, Sum
9 |
10 |
11 | _SKLEARN_MLP_TYPE = Union[MLPRegressor, MLPClassifier]
12 | _SKLEARN_RF_TYPE = Union[RandomForestRegressor, RandomForestClassifier]
13 | _SKLEARN_GP_TYPE = Union[GaussianProcessRegressor, GaussianProcessClassifier]
14 |
15 |
16 | def is_linear_kernel(kernel):
17 | return isinstance(kernel, PairwiseKernel) and kernel.metric == "linear"
18 |
19 |
20 | class ModelSpaceTransformer:
21 | SUPPORTED_TYPES = Union[
22 | _SKLEARN_MLP_TYPE,
23 | _SKLEARN_RF_TYPE,
24 | _SKLEARN_GP_TYPE,
25 | ]
26 |
27 | def __init__(self, model, embedding_size: int):
28 | if not isinstance(model, self.SUPPORTED_TYPES):
29 | raise TypeError(f"{type(model)} is not supported")
30 | self._model = model
31 | self._embedding_size = embedding_size
32 |
33 | def __call__(self, X):
34 | """Transforms a list of datapoints"""
35 | return self.transform(X)
36 |
37 | def transform(self, X):
38 | """Transforms a single datapoint"""
39 | if isinstance(self._model, _SKLEARN_RF_TYPE):
40 | return self.get_rf_embedding(self._model, X)
41 | elif isinstance(self._model, _SKLEARN_GP_TYPE):
42 | return self.get_gp_embedding(self._model, X)
43 | elif isinstance(self._model, _SKLEARN_MLP_TYPE):
44 | return self.get_mlp_embedding(self._model, X)
45 | # This should never be reached given the
46 | # type check in the constructor
47 | raise NotImplementedError
48 |
49 | def get_rf_embedding(self, model, X):
50 | """
51 |         For a random forest, the model space embedding is given by the subset
52 |         of `embedding_size` features with the highest feature importance
53 | """
54 | importances = model.feature_importances_
55 | mask = np.argsort(importances)[-self._embedding_size :]
56 | return X[:, mask]
57 |
58 | def get_gp_embedding(self, model, X):
59 | """
60 |         In a Gaussian Process, the model space embedding is given by the subset
61 |         of `embedding_size` features with the highest importance.
62 | This importance is computed based on alpha and the train set.
63 | For now, this only supports a linear kernel.
64 | """
65 | # Check the target type
66 | is_regression = isinstance(model, GaussianProcessRegressor)
67 | if not is_regression and model.n_classes_ != 2:
68 |             msg = "We only support regression and binary classification"
69 | raise ValueError(msg)
70 |
71 | # Check the kernel type
72 | is_linear = (
73 | is_linear_kernel(model.kernel_)
74 | or isinstance(model.kernel_, Sum)
75 | and (is_linear_kernel(model.kernel_.k1) or is_linear_kernel(model.kernel_.k2))
76 | )
77 | if not is_linear:
78 | msg = f"We only support the linear kernel, not {model.kernel_}"
79 | raise NotImplementedError(msg)
80 |
81 | if is_regression:
82 | alpha = model.alpha_
83 | X_train = model.X_train_
84 | else:
85 | est = model.base_estimator_
86 | alpha = est.y_train_ - est.pi_
87 | X_train = est.X_train_
88 |
89 | importances = (alpha[:, None] * X_train).sum(axis=0)
90 | importances = np.abs(importances)
91 | mask = np.argsort(importances)[-self._embedding_size :]
92 | return X[:, mask]
93 |
94 | def get_mlp_embedding(self, model, X):
95 | """
96 |         For a multi-layer perceptron, the model space embedding is given by
97 | the activations of the second-to-last layer
98 | """
99 | hidden_layer_sizes = model.hidden_layer_sizes
100 |
101 | # Get the MLP architecture
102 | if not hasattr(hidden_layer_sizes, "__iter__"):
103 | hidden_layer_sizes = [hidden_layer_sizes]
104 | hidden_layer_sizes = list(hidden_layer_sizes)
105 | layer_units = [X.shape[1]] + hidden_layer_sizes + [model.n_outputs_]
106 |
107 | # Create empty arrays to save all activations in
108 | activations = [X]
109 | for i in range(model.n_layers_ - 1):
110 | activations.append(np.empty((X.shape[0], layer_units[i + 1])))
111 |
112 | # Actually populate the empty arrays
113 | model._forward_pass(activations)
114 |
115 | # Return the activations of the second-to-last layer
116 | hidden_rep = activations[-2]
117 |
118 | importances = model.coefs_[-1][:, 0]
119 | mask = np.argsort(importances)[-self._embedding_size :]
120 | return hidden_rep[:, mask]
121 |
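A sketch of embedding data into a model's space with a fitted random forest (toy data; `embedding_size=100` matches the subset size described in the docstrings above):

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from mood.model_space import ModelSpaceTransformer

rng = np.random.default_rng(0)
X = rng.random((500, 2048)).round()  # hypothetical binary fingerprints
y = rng.random(500)

model = RandomForestRegressor(n_estimators=10).fit(X, y)
transformer = ModelSpaceTransformer(model, embedding_size=100)
X_embedded = transformer(X)  # shape (500, 100)
```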
--------------------------------------------------------------------------------
/mood/preprocessing.py:
--------------------------------------------------------------------------------
1 | import datamol as dm
2 | from functools import partial
3 | from rdkit.Chem import SaltRemover
4 |
5 |
6 | def standardize_smiles(smi, for_text_based_model: bool = False, disable_logs: bool = False):
7 | """A good default standardization function for fingerprints and GNNs"""
8 |
9 | with dm.without_rdkit_log(enable=disable_logs):
10 | mol = dm.to_mol(smi, ordered=True, sanitize=False)
11 | mol = dm.sanitize_mol(mol)
12 |
13 | if for_text_based_model:
14 | mol = dm.standardize_mol(mol)
15 |
16 | else:
17 | mol = dm.standardize_mol(mol, disconnect_metals=True)
18 | remover = SaltRemover.SaltRemover()
19 | mol = remover.StripMol(mol, dontRemoveEverything=True)
20 |
21 | return dm.to_smiles(mol)
22 |
23 |
24 | DEFAULT_PREPROCESSING = {
25 | "MACCS": partial(standardize_smiles, for_text_based_model=False),
26 | "ECFP6": partial(standardize_smiles, for_text_based_model=False),
27 | "Desc2D": partial(standardize_smiles, for_text_based_model=False),
28 | "WHIM": partial(standardize_smiles, for_text_based_model=False),
29 | "ChemBERTa": partial(standardize_smiles, for_text_based_model=True),
30 | "Graphormer": partial(standardize_smiles, for_text_based_model=False),
31 | }
32 |
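Typical usage of the standardization helpers (the SMILES below is an arbitrary example):

```python
from mood.preprocessing import standardize_smiles, DEFAULT_PREPROCESSING

# Standardize for fingerprint / GNN featurization (disconnect metals, strip salts)
smi = standardize_smiles("CC(=O)Oc1ccccc1C(=O)[O-].[Na+]", for_text_based_model=False)

# Or look up the default preprocessing function for a given representation
fn = DEFAULT_PREPROCESSING["ECFP6"]
smi = fn("CC(=O)Oc1ccccc1C(=O)O")
```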
--------------------------------------------------------------------------------
/mood/rct.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import zlib
3 |
4 | import numpy as np
5 |
6 | from mood.model import MOOD_DA_DG_ALGORITHMS
7 | from mood.representations import MOOD_REPRESENTATIONS
8 | from mood.splitter import MOOD_SPLITTERS
9 | from mood.baselines import MOOD_BASELINES
10 | from mood.criteria import get_mood_criteria
11 | from mood.metrics import Metric
12 |
13 |
14 | RCT_SEED = 1234
15 | NUM_SEEDS = 10
16 |
17 |
18 | def get_experimental_configurations(dataset):
19 | """
20 | To randomly sample different configurations of the RCT experiment, we use a deterministic approach.
21 | This facilitates reproducibility, but also makes it easy to run the experiment on preemptible instances.
22 | Otherwise, it could happen that models that take longer to train have a higher chance of failing,
23 | biasing the experiment.
24 | """
25 |
26 |     # NOTE: We should not rely on the order of a dict for creating these configurations, as
27 |     # relying on dict ordering is fragile. We unfortunately only realized this halfway through
28 |     # generating the results. Luckily, for our use case the ordering turned out to be consistent.
29 |     # As updating the code would change the ordering of the RCT, we kept it like this for now.
30 |
31 | prf_metric = Metric.get_default_performance_metric(dataset)
32 | cal_metric = Metric.get_default_calibration_metric(dataset)
33 |
34 | mood_criteria = get_mood_criteria(prf_metric, cal_metric).keys()
35 |
36 | mood_baselines = MOOD_BASELINES.copy()
37 | mood_baselines.pop(mood_baselines.index("MLP"))
38 | mood_algorithms = mood_baselines + list(MOOD_DA_DG_ALGORITHMS.keys())
39 |
40 | all_options = list(
41 | itertools.product(
42 | mood_algorithms,
43 | MOOD_REPRESENTATIONS,
44 | MOOD_SPLITTERS,
45 | mood_criteria,
46 | list(range(NUM_SEEDS)),
47 | )
48 | )
49 | # NOTE: We add the hash of the dataset to make the sampled configurations dataset-dependent
50 | rng = np.random.default_rng(RCT_SEED + zlib.adler32(dataset.encode("utf-8")))
51 | rng.shuffle(all_options)
52 | return all_options
53 |
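A sketch of consuming the deterministic configuration stream (the dataset name is a placeholder):

```python
from mood.rct import get_experimental_configurations

configs = get_experimental_configurations("Lipophilicity")  # hypothetical dataset name
# Each entry is (algorithm, representation, splitter, criterion, seed)
algorithm, representation, splitter, criterion, seed = configs[0]
```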
--------------------------------------------------------------------------------
/mood/representations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 |
4 | import pandas as pd
5 | import datamol as dm
6 | import numpy as np
7 |
8 | from loguru import logger
9 | from collections import OrderedDict
10 | from typing import Optional, List, Callable, Union, Dict
11 | from functools import partial
12 | from copy import deepcopy
13 |
14 | from rdkit import Chem
15 | from rdkit.Chem.QED import properties
16 | from rdkit.Chem import Descriptors
17 | from rdkit.Chem import FindMolChiralCenters
18 | from rdkit.Chem import rdPartialCharges
19 | from rdkit.Chem import rdMolDescriptors
20 | from rdkit.Chem import AllChem
21 |
22 | from transformers import AutoTokenizer, AutoModelForMaskedLM
23 | from mood.constants import DATASET_DATA_DIR
24 | from mood.utils import get_mask_for_distances_or_representations
25 |
26 |
27 | _CHEMBERTA_HF_ID = "seyonec/PubChem10M_SMILES_BPE_450k"
28 |
29 |
30 | def representation_iterator(
31 | smiles,
32 | standardize_fn: Union[Callable, Dict[str, Callable]],
33 | n_jobs: Optional[int] = None,
34 | progress: bool = True,
35 | mask_nan: bool = True,
36 | return_mask: bool = True,
37 | disable_logs: bool = True,
38 | whitelist: Optional[List[str]] = None,
39 | blacklist: Optional[List[str]] = None,
40 | batch_size: int = 16,
41 | ):
42 | if whitelist is not None and blacklist is not None:
43 | msg = "You cannot use a blacklist and whitelist at the same time"
44 | raise ValueError(msg)
45 |
46 | all_representations = MOOD_REPRESENTATIONS
47 |
48 | if whitelist is not None:
49 | all_representations = [d for d in all_representations if d in whitelist]
50 | if blacklist is not None:
51 | all_representations = [d for d in all_representations if d not in blacklist]
52 | if not isinstance(standardize_fn, dict):
53 | standardize_fn = {repr_: standardize_fn for repr_ in all_representations}
54 |
55 | for name in all_representations:
56 | feats = featurize(
57 | smiles,
58 | name,
59 | standardize_fn[name],
60 | n_jobs,
61 | mask_nan,
62 | return_mask,
63 | progress,
64 | disable_logs,
65 | batch_size,
66 | )
67 | yield name, feats
68 |
69 |
70 | def featurize(
71 | smiles,
72 | name,
73 | standardize_fn,
74 | n_jobs: Optional[int] = None,
75 | mask_nan: bool = True,
76 | return_mask: bool = True,
77 | progress: bool = True,
78 | disable_logs: bool = False,
79 | batch_size: int = 16,
80 | ):
81 | if name not in _REPR_TO_FUNC:
82 | msg = f"{name} is not supported. Choose from {MOOD_REPRESENTATIONS}"
83 | raise NotImplementedError(msg)
84 |
85 | fn = partial(standardize_fn, disable_logs=disable_logs)
86 | smiles = np.array(
87 | dm.utils.parallelized(fn, smiles, progress=progress, tqdm_kwargs={"desc": f"Preprocess {name}"})
88 | )
89 |
90 | fn = _REPR_TO_FUNC[name]
91 | fn = partial(fn, disable_logs=disable_logs)
92 |
93 | if name in BATCHED_FEATURIZERS:
94 | reprs = fn(smiles, batch_size=batch_size)
95 |
96 | else:
97 | reprs = dm.utils.parallelized(
98 | fn, smiles, n_jobs=n_jobs, progress=progress, tqdm_kwargs={"desc": name}
99 | )
100 | reprs = np.array(reprs)
101 |
102 | # Mask out invalid features
103 | mask = get_mask_for_distances_or_representations(reprs)
104 |
105 | logger.info(f"Succesfully computed representations for {len(reprs[mask])}/{len(smiles)} compounds")
106 |
107 | if mask_nan:
108 | reprs = reprs[mask]
109 |
110 |         # If the array had any Nones, it would not be a proper
111 |         # 2D array, so we convert it to one here.
112 | reprs = np.stack(reprs)
113 |
114 | if return_mask:
115 | reprs = reprs, mask
116 | return reprs
117 |
118 |
119 | def compute_whim(smi, disable_logs: bool = False):
120 | """
121 |     Compute the WHIM 3D descriptor for a SMILES string.
122 |     Code adapted from MoleculeACE, van Tilborg et al. (2022)
123 | """
124 |
125 | smi = dm.to_smiles(dm.keep_largest_fragment(dm.to_mol(smi)))
126 |
127 | with dm.without_rdkit_log(enable=disable_logs):
128 | mol = dm.to_mol(smi)
129 | if mol is None:
130 | # Failed
131 | return
132 |
133 | mol = Chem.AddHs(mol)
134 |
135 | # Use distance geometry to obtain initial coordinates for a molecule
136 | ret = AllChem.EmbedMolecule(
137 | mol, useRandomCoords=True, useBasicKnowledge=True, randomSeed=0, clearConfs=True, maxAttempts=5
138 | )
139 | if ret == -1:
140 | # Failed
141 | return
142 |
143 | AllChem.MMFFOptimizeMolecule(mol, maxIters=1000, mmffVariant="MMFF94")
144 |
145 | # calculate WHIM 3D descriptor
146 | whim = rdMolDescriptors.CalcWHIM(mol)
147 | whim = np.array(whim).astype(np.float32)
148 | return whim
149 |
150 |
151 | def _compute_extra_2d_features(mol):
152 | """Computes some additional descriptors besides the default ones RDKit offers"""
153 | mol = deepcopy(mol)
154 | FindMolChiralCenters(mol, force=True)
155 | p_obj = rdMolDescriptors.Properties()
156 | props = OrderedDict(zip(p_obj.GetPropertyNames(), p_obj.ComputeProperties(mol)))
157 | qed_props = properties(mol)
158 | props["Alerts"] = qed_props.ALERTS
159 | return props
160 |
161 |
162 | def _charge_descriptors_fix(mol: dm.Mol):
163 | """Recompute the RDKIT 2D Descriptors related to charge
164 |
165 | We change the procedure:
166 | 1. We disconnect the metal from the molecule
167 | 2. We add the hydrogen atoms
168 | 3. We make sure that gasteiger is recomputed.
169 |
170 | This fixes an issue where these descriptors could be NaN or Inf,
171 | while also making sure we are closer to the proper interpretation
172 | """
173 | descrs = {}
174 | mol = dm.standardize_mol(mol, disconnect_metals=True)
175 | mol = dm.add_hs(mol, explicit_only=False)
176 | rdPartialCharges.ComputeGasteigerCharges(mol)
177 | atomic_charges = [float(at.GetProp("_GasteigerCharge")) for at in mol.GetAtoms()]
178 | atomic_charges = np.clip(atomic_charges, a_min=-500, a_max=500)
179 | min_charge, max_charge = np.nanmin(atomic_charges), np.nanmax(atomic_charges)
180 | descrs["MaxPartialCharge"] = max_charge
181 | descrs["MinPartialCharge"] = min_charge
182 | descrs["MaxAbsPartialCharge"] = max(np.abs(min_charge), np.abs(max_charge))
183 | descrs["MinAbsPartialCharge"] = min(np.abs(min_charge), np.abs(max_charge))
184 | return descrs
185 |
186 |
187 | def compute_desc2d(smi, disable_logs: bool = False):
188 | descr_fns = {name: fn for (name, fn) in Descriptors.descList}
189 |
190 | all_features = [d[0] for d in Descriptors.descList]
191 | all_features += [
192 | "NumAtomStereoCenters",
193 | "NumUnspecifiedAtomStereoCenters",
194 | "NumBridgeheadAtoms",
195 | "NumAmideBonds",
196 | "NumSpiroAtoms",
197 | "Alerts",
198 | ]
199 |
200 | with dm.without_rdkit_log(enable=disable_logs):
201 | mol = dm.to_mol(smi)
202 | descr_extra = _compute_extra_2d_features(mol)
203 | descr_charge = _charge_descriptors_fix(mol)
204 |
205 | descr = []
206 |
207 | for name in all_features:
208 | val = float("nan")
209 | if name in descr_charge:
210 | val = descr_charge[name]
211 | elif name == "Ipc":
212 | # Fixes a bug for the RDKit IPC value. For context, see:
213 | # https://github.com/rdkit/rdkit/issues/1527
214 | val = descr_fns[name](mol, avg=True)
215 | elif name in descr_fns:
216 | val = descr_fns[name](mol)
217 | else:
218 | assert name in descr_extra
219 | val = descr_extra[name]
220 | descr.append(val)
221 |
222 | descr = np.asarray(descr)
223 | return descr
224 |
225 |
226 | def compute_ecfp6(smi, disable_logs: bool = False):
227 | with dm.without_rdkit_log(enable=disable_logs):
228 | return dm.to_fp(smi, fp_type="ecfp", radius=3)
229 |
230 |
231 | def compute_maccs(smi, disable_logs: bool = False):
232 | with dm.without_rdkit_log(enable=disable_logs):
233 | return dm.to_fp(smi, fp_type="maccs")
234 |
235 |
236 | def compute_chemberta(smis, disable_logs: bool = False, batch_size: int = 16):
237 | # Batch the input
238 |     n_batches = int(np.ceil(len(smis) / batch_size))
239 |     batched = np.array_split(smis, n_batches)
240 |
241 | # Load the model
242 | tokenizer = AutoTokenizer.from_pretrained(_CHEMBERTA_HF_ID)
243 | model = AutoModelForMaskedLM.from_pretrained(_CHEMBERTA_HF_ID)
244 |
245 | # Use the GPU if it is available
246 | device = "cuda" if torch.cuda.is_available() else "cpu"
247 | model.to(device)
248 | model.eval()
249 |
250 | hidden_states = []
251 | for batch in tqdm.tqdm(batched, desc="Batch"):
252 | model_input = tokenizer(
253 | batch.tolist(),
254 | return_tensors="pt",
255 | add_special_tokens=True,
256 | truncation=True,
257 | padding=True,
258 | max_length=512,
259 | ).to(device)
260 |
261 | with torch.no_grad():
262 | model_output = model(
263 | model_input["input_ids"],
264 | attention_mask=model_input["attention_mask"],
265 | output_hidden_states=True,
266 | )
267 |
268 |         # We use mean aggregation over the non-padding token embeddings
269 |         h = model_output.hidden_states[-1]
270 |         h = [h_[mask.bool()].mean(0) for h_, mask in zip(h, model_input["attention_mask"])]
271 | h = torch.stack(h)
272 | h = h.cpu().detach().numpy()
273 |
274 | hidden_states.append(h)
275 |
276 | hidden_states = np.concatenate(hidden_states)
277 | return hidden_states
278 |
279 |
280 | def load_graphormer(smis, disable_logs: bool = False, batch_size: int = 16):
281 | # NOTE: Since we needed to make some changes to the Graphormer repo
282 | # to actually extract the embeddings, including that code in MOOD
283 | # would pollute the repo quite a bit. Therefore, we precomputed the
284 | # representations and just load these here.
285 |
286 | pattern = dm.fs.join(DATASET_DATA_DIR, "representations", "**", "*.parquet")
287 | paths = dm.fs.glob(pattern)
288 | df_repr = pd.concat([pd.read_parquet(p) for p in paths])
289 |
290 | unique_ids = dm.utils.parallelized(dm.unique_id, smis, progress=True)
291 | df = pd.DataFrame(index=unique_ids)
292 |
293 | df_repr = df_repr[df_repr["unique_id"].isin(unique_ids)]
294 | df_repr = df_repr.set_index("unique_id")
295 |
296 | df = df.join(df_repr)
297 | df = df[~df.index.duplicated(keep="first")]
298 | df = df.reindex(unique_ids)
299 |
300 | feats = df["representation"].to_numpy()
301 | return feats
302 |
303 |
304 | _REPR_TO_FUNC = {
305 | "MACCS": compute_maccs,
306 | "ECFP6": compute_ecfp6,
307 | "Desc2D": compute_desc2d,
308 | "WHIM": compute_whim,
309 | "ChemBERTa": compute_chemberta,
310 | "Graphormer": load_graphormer,
311 | }
312 |
313 | MOOD_REPRESENTATIONS = list(_REPR_TO_FUNC.keys())
314 | BATCHED_FEATURIZERS = ["ChemBERTa", "Graphormer"]
315 | TEXTUAL_FEATURIZERS = ["ChemBERTa"]
316 |
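
A minimal sketch of the featurization API above. It assumes the `DEFAULT_PREPROCESSING` mapping from `mood.preprocessing` (used in `scripts/compare_performance.py`) is keyed by representation name:

```python
from mood.preprocessing import DEFAULT_PREPROCESSING
from mood.representations import featurize

smiles = ["CCO", "c1ccccc1", "CC(=O)O"]

# Returns the features and a mask of the compounds that could be featurized
X, mask = featurize(
    smiles,
    "ECFP6",
    standardize_fn=DEFAULT_PREPROCESSING["ECFP6"],  # assumed key, matching the representation name
)
```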
--------------------------------------------------------------------------------
/mood/splitter.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 |
3 | import numpy as np
4 | import datamol as dm
5 | import seaborn as sns
6 | import pandas as pd
7 |
8 | from sklearn.model_selection import ShuffleSplit
9 | from dataclasses import dataclass
10 | from loguru import logger
11 | from typing import Union, List, Optional, Callable, Dict
12 |
13 | from scipy.stats import gaussian_kde
14 | from scipy.spatial.distance import jensenshannon
15 | from sklearn.metrics import pairwise_distances
16 | from sklearn.neighbors import NearestNeighbors
17 | from sklearn.model_selection import BaseShuffleSplit, GroupShuffleSplit
18 | from sklearn.model_selection._split import _validate_shuffle_split, _num_samples
19 | from sklearn.cluster import MiniBatchKMeans
20 |
21 | from mood.transformer import EmpiricalKernelMapTransformer
22 | from mood.distance import get_distance_metric
23 | from mood.visualize import plot_distance_distributions
24 | from mood.utils import get_outlier_bounds
25 |
26 |
27 | MOOD_SPLITTERS = ["Random", "Scaffold", "Perimeter", "Maximum Dissimilarity"]
28 |
29 |
30 | def get_mood_splitters(smiles, n_splits: int = 5, random_state: int = 0, n_jobs: Optional[int] = None):
31 | scaffolds = [dm.to_smiles(dm.to_scaffold_murcko(dm.to_mol(smi))) for smi in smiles]
32 | splitters = {
33 | "Random": ShuffleSplit(n_splits=n_splits, random_state=random_state),
34 | "Scaffold": PredefinedGroupShuffleSplit(
35 | groups=scaffolds, n_splits=n_splits, random_state=random_state
36 | ),
37 | "Perimeter": PerimeterSplit(
38 | n_clusters=25, n_splits=n_splits, random_state=random_state, n_jobs=n_jobs
39 | ),
40 | "Maximum Dissimilarity": MaxDissimilaritySplit(
41 | n_clusters=25, n_splits=n_splits, random_state=random_state, n_jobs=n_jobs
42 | ),
43 | }
44 | return splitters
45 |
46 |
47 | @dataclass
48 | class SplitCharacterization:
49 | """
50 | Within the context of MOOD, a split is characterized by
51 | a distribution of distances and an associated representativeness score
52 | """
53 |
54 | distances: np.ndarray
55 | representativeness: float
56 | label: str
57 |
58 | @classmethod
59 | def concat(cls, splits):
60 |         names = {obj.label for obj in splits}
61 | if len(names) != 1:
62 | raise RuntimeError("Can only concatenate equally labeled split characterizations")
63 |
64 | dist = np.concatenate([obj.distances for obj in splits])
65 | score = np.mean([obj.representativeness for obj in splits])
66 | return cls(dist, score, names.pop())
67 |
68 | @staticmethod
69 | def best(splits):
70 | return max(splits, key=lambda spl: spl.representativeness)
71 |
72 | @staticmethod
73 | def as_dataframe(splits):
74 | df = pd.DataFrame()
75 | best = SplitCharacterization.best(splits)
76 | for split in splits:
77 | df_ = pd.DataFrame(
78 | {
79 | "split": split.label,
80 | "representativeness": split.representativeness,
81 | "best": split == best,
82 | },
83 | index=[0],
84 | )
85 | df = pd.concat((df, df_), ignore_index=True)
86 | df["rank"] = df["representativeness"].rank(ascending=False)
87 | return df
88 |
89 | def __eq__(self, other):
90 | return self.label == other.label and self.representativeness == other.representativeness
91 |
92 | def __repr__(self):
93 | return self.__str__()
94 |
95 | def __str__(self):
96 | return f"{self.__class__.__name__}[{self.label}]"
97 |
98 |
99 | class MOODSplitter(BaseShuffleSplit):
100 | """
101 | The MOOD splitter takes in multiple splitters and a set of
102 | downstream molecules and prescribes one splitting method
103 | that creates the test set that is most representative of
104 | downstream applications.
105 | """
106 |
107 | def __init__(
108 | self,
109 | splitters: Dict[str, BaseShuffleSplit],
110 | downstream_distances: Optional[np.ndarray] = None,
111 | metric: Union[str, Callable] = "minkowski",
112 | p: int = 2,
113 | k: int = 5,
114 | ):
115 | """
116 | Args:
117 | splitters: A list of splitter methods you are considering
118 | downstream_distances: A list of precomputed distances for the downstream application
119 | metric: The distance metric to use
120 | p: If the metric is the minkowski distance, this is the p in that distance.
121 | k: The number of nearest neighbors to use to compute the distance.
122 | """
123 | super().__init__()
124 | if not all(isinstance(obj, BaseShuffleSplit) for obj in splitters.values()):
125 | raise TypeError("All splitters should be BaseShuffleSplit objects")
126 |
127 | n_splits_per_splitter = [obj.get_n_splits() for obj in splitters.values()]
128 | if not len(set(n_splits_per_splitter)) == 1:
129 | raise TypeError("n_splits is inconsistent across the different splitters")
130 | self._n_splits = n_splits_per_splitter[0]
131 |
132 | self._p = p
133 | self._k = k
134 | self._metric = metric
135 | self._splitters = splitters
136 | self._downstream_distances = downstream_distances
137 |
138 | self._split_chars = None
139 | self._prescribed_splitter_label = None
140 |
141 | @staticmethod
142 | def visualize(downstream_distances: np.ndarray, splits: List[SplitCharacterization], ax: Optional = None):
143 | splits = sorted(splits, key=lambda spl: spl.representativeness)
144 | cmap = sns.color_palette("rocket", len(splits) + 1)
145 |
146 | distances = [spl.distances for spl in splits]
147 | colors = [cmap[rank + 1] for rank, spl in enumerate(splits)]
148 | labels = [spl.label for spl in splits]
149 |
150 | ax = plot_distance_distributions(distances, labels, colors, ax=ax)
151 |
152 | lower, upper = get_outlier_bounds(downstream_distances, factor=3.0)
153 | mask = (downstream_distances >= lower) & (downstream_distances <= upper)
154 | downstream_distances = downstream_distances[mask]
155 |
156 | sns.kdeplot(downstream_distances, color=cmap[0], linestyle="--", alpha=0.3, ax=ax)
157 | return ax
158 |
159 | @staticmethod
160 | def score_representativeness(downstream_distances, distances, num_samples: int = 100):
161 | """Scores a representativeness score between two distributions
162 | A higher score should be interpreted as _more_ representative"""
163 | pdf_split = gaussian_kde(distances)
164 | pdf_downstream = gaussian_kde(downstream_distances)
165 |
166 | vmin = np.min(np.concatenate((downstream_distances, distances)))
167 | vmax = np.max(np.concatenate((downstream_distances, distances)))
168 | positions = np.linspace(vmin, vmax, num=num_samples)
169 |
170 | samples_split = pdf_split(positions)
171 | samples_downstream = pdf_downstream(positions)
172 |
173 | return 1.0 - jensenshannon(samples_downstream, samples_split, base=2)
174 |
175 | @property
176 | def prescribed_splitter_label(self):
177 | if not self.fitted:
178 |             raise RuntimeError("The splitter has not been fitted yet")
179 | return self._prescribed_splitter_label
180 |
181 | @property
182 | def fitted(self):
183 | return self._prescribed_splitter_label is not None
184 |
185 | def _compute_distance(self, X_from, X_to):
186 | """
187 | Computes the k-NN distance from one set to another
188 |
189 | Args:
190 | X_from: The set to compute the distance for
191 | X_to: The set to compute the distance to (i.e. the neighbor candidates)
192 | """
193 | knn = NearestNeighbors(n_neighbors=self._k, metric=self._metric, p=self._p).fit(X_to)
194 | distances, ind = knn.kneighbors(X_from)
195 | distances = np.mean(distances, axis=1)
196 | return distances
197 |
198 | def get_prescribed_splitter(self):
199 | return self._splitters[self.prescribed_splitter_label]
200 |
201 | def get_protocol_results(self):
202 | return SplitCharacterization.as_dataframe(self._split_chars)
203 |
204 | def fit(self, X, y=None, groups=None, X_deployment=None, plot: bool = False, progress: bool = False):
205 | """Follows the MOOD specification to prescribe a train-test split
206 | that is most representative of downstream applications.
207 |
208 | In MOOD, the k-NN distance in the representation space functions
209 | as a proxy of difficulty. The further a datapoint is from the training
210 | set, in general the lower a model's performance. Using that observation,
211 | we prescribe the train-test split that best replicates the distance
212 | distribution (i.e. "the difficulty") of a downstream application.
213 | """
214 |
215 | if self._downstream_distances is None:
216 | self._downstream_distances = self._compute_distance(X_deployment, X)
217 |
218 | # Precompute all splits. Since splitters are implemented as generators,
219 | # we store the resulting splits so we can replicate them later on.
220 | split_chars = list()
221 |
222 | it = self._splitters.items()
223 | if progress:
224 | it = tqdm.tqdm(it, desc="Splitter")
225 |
226 | for name, splitter in it:
227 | # We possibly repeat the split multiple times to
228 | # get a more reliable estimate
229 | chars = []
230 |
231 | it_ = splitter.split(X, y, groups)
232 | if progress:
233 | it_ = tqdm.tqdm(it_, leave=False, desc="Split", total=self._n_splits)
234 |
235 | for split in it_:
236 | train, test = split
237 | distances = self._compute_distance(X[test], X[train])
238 |                 # np.isfinite already excludes NaN values
239 |                 distances = distances[np.isfinite(distances)]
240 |
241 | score = self.score_representativeness(self._downstream_distances, distances)
242 | chars.append(SplitCharacterization(distances, score, name))
243 |
244 | split_chars.append(SplitCharacterization.concat(chars))
245 |
246 | # Rank different splitting methods by their ability to
247 | # replicate the downstream distance distribution.
248 | chosen = SplitCharacterization.best(split_chars)
249 |
250 | self._split_chars = split_chars
251 | self._prescribed_splitter_label = chosen.label
252 |
253 | logger.info(
254 | f"Ranked all different splitting methods:\n{SplitCharacterization.as_dataframe(split_chars)}"
255 | )
256 | logger.info(f"Selected {chosen.label} as the most representative splitting method")
257 |
258 | if plot:
259 | # Visualize the results
260 | return self.visualize(self._downstream_distances, split_chars)
261 |
262 | def _iter_indices(self, X=None, y=None, groups=None):
263 | """Generate (train, test) indices"""
264 | if not self.fitted:
265 |             raise RuntimeError("The splitter has not been fitted yet")
266 | yield from self.get_prescribed_splitter()._iter_indices(X, y, groups)
267 |
268 |
269 | class PredefinedGroupShuffleSplit(GroupShuffleSplit):
270 | """Simple class that tackles the limitation of the MOODSplitter
271 | that all splitters need to use the same grouping."""
272 |
273 | def __init__(self, groups, n_splits=5, *, test_size=None, train_size=None, random_state=None):
274 | super().__init__(
275 | n_splits=n_splits,
276 | test_size=test_size,
277 | train_size=train_size,
278 | random_state=random_state,
279 | )
280 | self._groups = groups
281 |
282 | def _iter_indices(self, X=None, y=None, groups=None):
283 | """Generate (train, test) indices"""
284 | if groups is not None:
285 | logger.warning("Ignoring the groups parameter in favor of the predefined groups")
286 | yield from super()._iter_indices(X, y, self._groups)
287 |
288 |
289 | class KMeansSplit(GroupShuffleSplit):
290 | """Split based on the k-Mean clustering in input space"""
291 |
292 | def __init__(
293 | self, n_clusters: int = 10, n_splits: int = 5, *, test_size=None, train_size=None, random_state=None
294 | ):
295 | super().__init__(
296 | n_splits=n_splits,
297 | test_size=test_size,
298 | train_size=train_size,
299 | random_state=random_state,
300 | )
301 | self._n_clusters = n_clusters
302 |
303 | def compute_kmeans_clustering(self, X, random_state_offset: int = 0, return_centers: bool = False):
304 | metric = get_distance_metric(X)
305 |
306 | if self.random_state is not None:
307 | seed = self.random_state + random_state_offset
308 | else:
309 | seed = None
310 |
311 | if metric != "euclidean":
312 | logger.debug(f"To use KMeans with the {metric} metric, we use the Empirical Kernel Map")
313 | transformer = EmpiricalKernelMapTransformer(
314 | n_samples=min(512, len(X)),
315 | metric=metric,
316 | random_state=seed,
317 | )
318 | X = transformer(X)
319 |
320 | model = MiniBatchKMeans(self._n_clusters, random_state=seed, compute_labels=True)
321 | model.fit(X)
322 |
323 | indices = model.labels_
324 | if not return_centers:
325 | return indices
326 |
327 | centers = model.cluster_centers_[indices]
328 | return indices, centers
329 |
330 | def _iter_indices(self, X=None, y=None, groups=None):
331 | """Generate (train, test) indices"""
332 | if groups is not None:
333 |             logger.warning("Ignoring the groups parameter in favor of the k-means clusters")
334 | groups = self.compute_kmeans_clustering(X)
335 | yield from super()._iter_indices(X, y, groups)
336 |
337 |
338 | class PerimeterSplit(KMeansSplit):
339 | """
340 | Places the pairs of data points with maximal pairwise distance in the test set.
341 |     This was originally called the extrapolation-oriented split, introduced in Szántai-Kis et al., 2003
342 | """
343 |
344 | def __init__(
345 | self,
346 | n_clusters: int = 10,
347 | n_splits: int = 5,
348 | n_jobs: Optional[int] = None,
349 | *,
350 | test_size=None,
351 | train_size=None,
352 | random_state=None,
353 | ):
354 | super().__init__(
355 | n_clusters=n_clusters,
356 | n_splits=n_splits,
357 | test_size=test_size,
358 | train_size=train_size,
359 | random_state=random_state,
360 | )
361 | self._n_jobs = n_jobs
362 |
363 | def _iter_indices(self, X, y=None, groups=None):
364 | if groups is not None:
365 |             logger.warning("Ignoring the groups parameter in favor of the k-means clusters")
366 |
367 | n_samples = _num_samples(X)
368 | n_train, n_test = _validate_shuffle_split(
369 | n_samples,
370 | self.test_size,
371 | self.train_size,
372 | default_test_size=self._default_test_size,
373 | )
374 |
375 | for i in range(self.n_splits):
376 | groups, centers = self.compute_kmeans_clustering(X, random_state_offset=i, return_centers=True)
377 | centers, group_indices, group_counts = np.unique(
378 | centers, return_inverse=True, return_counts=True, axis=0
379 | )
380 | groups_set = np.unique(group_indices)
381 |
382 | # We always use the euclidean metric. For binary vectors we would have
383 | # used the jaccard metric normally, but because of the k-Means clustering this
384 | # data would be transformed using the Empirical Kernel Map.
385 | distance_matrix = pairwise_distances(centers, metric="euclidean", n_jobs=self._n_jobs)
386 |
387 | # Sort the distance matrix to find the groups that are the furthest away from one another
388 | tril_indices = np.tril_indices_from(distance_matrix, k=-1)
389 | maximum_distance_indices = np.argsort(distance_matrix[tril_indices])[::-1]
390 |
391 | test_indices = []
392 | remaining = set(groups_set)
393 |
394 | for pos in maximum_distance_indices:
395 | if len(test_indices) >= n_test:
396 | break
397 |
398 | i, j = (
399 | tril_indices[0][pos],
400 | tril_indices[1][pos],
401 | )
402 |
403 |                 # If one of the clusters in this pair is already in the test set, skip to the next pair
404 | if not (i in remaining and j in remaining):
405 | continue
406 |
407 | remaining.remove(i)
408 | test_indices.extend(list(np.flatnonzero(group_indices == groups_set[i])))
409 | remaining.remove(j)
410 | test_indices.extend(list(np.flatnonzero(group_indices == groups_set[j])))
411 |
412 | train_indices = []
413 | for i in remaining:
414 | train_indices.extend(list(np.flatnonzero(group_indices == groups_set[i])))
415 |
416 | yield np.array(train_indices), np.array(test_indices)
417 |
418 |
419 | class MaxDissimilaritySplit(KMeansSplit):
420 | """Splits the data such that the train and test set are maximally dissimilar."""
421 |
422 | def __init__(
423 | self,
424 | n_clusters: int = 10,
425 | n_splits: int = 5,
426 | n_jobs: Optional[int] = None,
427 | *,
428 | test_size=None,
429 | train_size=None,
430 | random_state=None,
431 | ):
432 | super().__init__(
433 | n_clusters=n_clusters,
434 | n_splits=n_splits,
435 | test_size=test_size,
436 | train_size=train_size,
437 | random_state=random_state,
438 | )
439 | self._n_jobs = n_jobs
440 |
441 | def _iter_indices(self, X, y=None, groups=None):
442 | """Generate (train, test) indices"""
443 |
444 | if groups is not None:
445 |             logger.warning("Ignoring the groups parameter in favor of the k-means clusters")
446 |
447 | metric = get_distance_metric(X)
448 |
449 | n_samples = _num_samples(X)
450 | n_train, n_test = _validate_shuffle_split(
451 | n_samples,
452 | self.test_size,
453 | self.train_size,
454 | default_test_size=self._default_test_size,
455 | )
456 |
457 | for i in range(self.n_splits):
458 | # We introduce some stochasticity through the k-Means clustering
459 | groups, centers = self.compute_kmeans_clustering(X, random_state_offset=i, return_centers=True)
460 | centers, group_indices, group_counts = np.unique(
461 | centers, return_inverse=True, return_counts=True, axis=0
462 | )
463 | groups_set = np.unique(group_indices)
464 |
465 | # We always use the euclidean metric. For binary vectors we would have
466 | # used the jaccard metric normally, but because of the k-Means clustering this
467 | # data would be transformed using the Empirical Kernel Map.
468 | distance_matrix = pairwise_distances(centers, metric="euclidean", n_jobs=self._n_jobs)
469 |
470 | # The initial test cluster is the one with the
471 | # highest mean distance to all other clusters
472 | test_idx = np.argmax(distance_matrix.mean(axis=0))
473 |
474 | # The initial train cluster is the one furthest from
475 | # the initial test cluster
476 | train_idx = np.argmax(distance_matrix[test_idx])
477 |
478 | train_indices = np.flatnonzero(group_indices == groups_set[train_idx])
479 | test_indices = np.flatnonzero(group_indices == groups_set[test_idx])
480 |
481 |             # Iteratively grow the train set with the clusters
482 |             # closest to the _initial_ train cluster.
483 | sorted_groups = np.argsort(distance_matrix[train_idx])
484 | for group_idx in sorted_groups:
485 | if len(train_indices) >= n_train:
486 | break
487 |
488 | if group_idx == train_idx or group_idx == test_idx:
489 | continue
490 |
491 | indices_to_add = np.flatnonzero(group_indices == groups_set[group_idx])
492 | train_indices = np.concatenate([train_indices, indices_to_add])
493 |
494 |             # Construct the test set from all samples not yet assigned
495 |             remaining_indices = list(set(range(n_samples)) - set(train_indices) - set(test_indices))
496 |             test_indices = np.concatenate([test_indices, remaining_indices]).astype(int)
497 |
498 | yield train_indices, test_indices
499 |
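
The extrapolative splitters can also be used standalone; a minimal sketch on random data:

```python
import numpy as np

from mood.splitter import PerimeterSplit

X = np.random.default_rng(0).random((200, 16))
splitter = PerimeterSplit(n_clusters=10, n_splits=3, random_state=0)

for train_idx, test_idx in splitter.split(X):
    # The pairs of k-means clusters that are furthest apart end up in the test set
    print(len(train_idx), len(test_idx))
```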
--------------------------------------------------------------------------------
/mood/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 |
4 | from typing import Optional
5 | from pytorch_lightning import Trainer
6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
7 | from pytorch_lightning.utilities.seed import seed_everything
8 | from torch.utils.data import DataLoader
9 |
10 | from mood.baselines import construct_kernel, get_baseline_model, MOOD_BASELINES
11 | from mood.constants import NUM_EPOCHS
12 | from mood.dataset import SimpleMolecularDataset, DAMolecularDataset, domain_based_collate
13 | from mood.model import MOOD_DA_DG_ALGORITHMS, is_domain_generalization, is_domain_adaptation
14 | from mood.model.base import Ensemble
15 | from mood.model.nn import get_simple_mlp
16 |
17 |
18 | def train_baseline_model(
19 | X,
20 | y,
21 | algorithm: str,
22 | is_regression: bool,
23 | params: Optional[dict] = None,
24 | seed: Optional[int] = None,
25 | for_uncertainty_estimation: bool = False,
26 | ensemble_size: int = 10,
27 | calibrate: bool = False,
28 | n_jobs: int = -1,
29 | ):
30 | if params is None:
31 | params = {}
32 | if seed is not None:
33 | params["random_state"] = seed
34 |
35 | if algorithm == "RF" or (algorithm == "GP" and not is_regression):
36 | params["n_jobs"] = n_jobs
37 |
38 | if algorithm == "RF" and not is_regression:
39 | params["class_weight"] = "balanced"
40 | if algorithm == "GP":
41 | params["kernel"], params = construct_kernel(is_regression, params)
42 |
43 | model = get_baseline_model(
44 | name=algorithm,
45 | is_regression=is_regression,
46 | params=params,
47 | for_uncertainty_estimation=for_uncertainty_estimation,
48 | ensemble_size=ensemble_size,
49 | calibrate=calibrate,
50 | )
51 | model.fit(X, y)
52 |
53 | return model
54 |
55 |
56 | def train_torch_model(
57 | train_dataset: SimpleMolecularDataset,
58 | val_dataset: SimpleMolecularDataset,
59 | test_dataset: SimpleMolecularDataset,
60 | algorithm: str,
61 | is_regression: bool,
62 | params: Optional[dict] = None,
63 | seed: Optional[int] = None,
64 | ensemble_size: int = 5,
65 | ):
66 | logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
67 | seed_everything(seed, workers=True)
68 |
69 | width = params.pop("mlp_width")
70 | depth = params.pop("mlp_depth")
71 | batch_size = params["batch_size"]
72 |
73 | # NOTE: Since the datasets are all very small,
74 | # setting up and syncing the threads takes longer than
75 | # what we gain by using the threads
76 | no_workers = 0
77 |
78 | models = []
79 | for i in range(ensemble_size):
80 | base = get_simple_mlp(len(train_dataset.X[0]), width, depth, out_size=None)
81 | head = get_simple_mlp(
82 | input_size=width * 2 if algorithm == "MTL" else width, is_regression=is_regression
83 | )
84 |
85 | model = MOOD_DA_DG_ALGORITHMS[algorithm](
86 | base_network=base,
87 | prediction_head=head,
88 | loss_fn=torch.nn.MSELoss() if is_regression else torch.nn.BCELoss(),
89 | **params,
90 | )
91 |
92 | if is_domain_adaptation(model):
93 | train_dataset_da = DAMolecularDataset(source_dataset=train_dataset, target_dataset=test_dataset)
94 | train_dataloader = DataLoader(
95 | train_dataset_da, batch_size=batch_size, shuffle=True, num_workers=no_workers
96 | )
97 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=no_workers)
98 | elif is_domain_generalization(model):
99 | train_dataloader = DataLoader(
100 | train_dataset,
101 | batch_size=batch_size,
102 | shuffle=True,
103 | collate_fn=domain_based_collate,
104 | num_workers=no_workers,
105 | )
106 | val_dataloader = DataLoader(
107 | val_dataset, batch_size=batch_size, collate_fn=domain_based_collate, num_workers=no_workers
108 | )
109 | else:
110 | train_dataloader = DataLoader(
111 | train_dataset, batch_size=batch_size, shuffle=True, num_workers=no_workers
112 | )
113 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=no_workers)
114 |
115 |         # NOTE: For smaller datasets, moving data between the CPU and GPU will be the bottleneck
116 | use_gpu = torch.cuda.is_available() and len(train_dataset) > 2500
117 |
118 | callbacks = [EarlyStopping("val_loss", patience=10, mode="min")]
119 | trainer = Trainer(
120 | max_epochs=NUM_EPOCHS,
121 | deterministic="warn",
122 | callbacks=callbacks,
123 | enable_model_summary=False,
124 | enable_progress_bar=False,
125 | num_sanity_val_steps=0,
126 | logger=False,
127 | accelerator="gpu" if use_gpu else None,
128 | devices=1 if use_gpu else None,
129 | enable_checkpointing=False,
130 | )
131 | trainer.fit(model, train_dataloader, val_dataloader)
132 | models.append(model)
133 |
134 | return Ensemble(models, is_regression)
135 |
136 |
137 | def train(
138 | train_dataset: SimpleMolecularDataset,
139 | val_dataset: SimpleMolecularDataset,
140 | test_dataset: SimpleMolecularDataset,
141 | algorithm: str,
142 | is_regression: bool,
143 | params: dict,
144 | seed: int,
145 | calibrate: bool = False,
146 | ensemble_size: int = 5,
147 | ):
148 |     # NOTE: The order of these checks matters, since there are two MLP implementations.
149 |     # For MLP, we want to use the torch implementation.
150 |
151 | if algorithm in MOOD_DA_DG_ALGORITHMS:
152 | if calibrate:
153 | raise NotImplementedError("We only support calibration for scikit-learn models")
154 |
155 | return train_torch_model(
156 | train_dataset=train_dataset,
157 | val_dataset=val_dataset,
158 | test_dataset=test_dataset,
159 | algorithm=algorithm,
160 | is_regression=is_regression,
161 | params=params,
162 | seed=seed,
163 | ensemble_size=ensemble_size,
164 | )
165 |
166 | elif algorithm in MOOD_BASELINES:
167 | return train_baseline_model(
168 | X=train_dataset.X,
169 | y=train_dataset.y,
170 | algorithm=algorithm,
171 | is_regression=is_regression,
172 | params=params,
173 | seed=seed,
174 | for_uncertainty_estimation=True,
175 | ensemble_size=ensemble_size,
176 | calibrate=calibrate,
177 | )
178 |
179 | else:
180 | raise NotImplementedError(f"{algorithm} is not supported")
181 |
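
A minimal sketch of training a scikit-learn baseline through the helper above (random data for illustration; "RF" is one of the baselines referenced in this module):

```python
import numpy as np

from mood.train import train_baseline_model

rng = np.random.default_rng(0)
X, y = rng.random((128, 32)), rng.random(128)

model = train_baseline_model(X, y, algorithm="RF", is_regression=True, seed=0)
preds = model.predict(X)
```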
--------------------------------------------------------------------------------
/mood/transformer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from typing import Union, Optional
4 |
5 | from scipy.spatial.distance import cdist
6 | from sklearn.neural_network import MLPRegressor, MLPClassifier
7 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
8 | from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier
9 |
10 |
11 | _SKLEARN_MLP_TYPE = Union[MLPRegressor, MLPClassifier]
12 | _SKLEARN_RF_TYPE = Union[RandomForestRegressor, RandomForestClassifier]
13 | _SKLEARN_GP_TYPE = Union[GaussianProcessRegressor, GaussianProcessClassifier]
14 |
15 |
16 | class EmpiricalKernelMapTransformer:
17 | def __init__(self, n_samples: int, metric: str, random_state: Optional[int] = None):
18 | self._n_samples = n_samples
19 | self._random_state = random_state
20 | self._samples = None
21 | self._metric = metric
22 |
23 | def __call__(self, X):
24 | """Transforms a list of datapoints"""
25 | return self.transform(X)
26 |
27 | def transform(self, X):
28 | """Transforms a single datapoint"""
29 | if self._samples is None:
30 | rng = np.random.default_rng(self._random_state)
31 | self._samples = X[rng.choice(np.arange(len(X)), self._n_samples)]
32 | X = cdist(X, self._samples, metric=self._metric)
33 | return X
34 |
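
A short sketch of the empirical kernel map on binary, fingerprint-like data (random data for illustration):

```python
import numpy as np

from mood.transformer import EmpiricalKernelMapTransformer

X = np.random.default_rng(0).integers(0, 2, size=(100, 2048)).astype(float)
transformer = EmpiricalKernelMapTransformer(n_samples=50, metric="jaccard", random_state=0)

# Each datapoint is re-represented by its distance to 50 randomly sampled anchor points
X_transformed = transformer(X)
assert X_transformed.shape == (100, 50)
```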
--------------------------------------------------------------------------------
/mood/utils.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import uuid
3 | from datetime import datetime
4 | from typing import Optional
5 |
6 | import datamol as dm
7 | import pandas as pd
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 |
11 | from loguru import logger
12 |
13 | from mood.constants import DOWNSTREAM_APPS_DATA_DIR, CACHE_DIR
14 |
15 |
16 | def load_representation_for_downstream_application(
17 | name,
18 | representation,
19 | update_cache: bool = False,
20 | return_compound_ids: bool = False,
21 | ):
22 | suffix = ["representations", name, f"{representation}.parquet"]
23 |
24 | lpath = dm.fs.join(CACHE_DIR, "downstream_applications", *suffix)
25 | if not dm.fs.exists(lpath) or update_cache:
26 | rpath = dm.fs.join(DOWNSTREAM_APPS_DATA_DIR, *suffix)
27 | logger.debug(f"Downloading {rpath} to {lpath}")
28 | dm.fs.copy_file(rpath, lpath, force=update_cache)
29 | else:
30 | logger.debug(f"Using cache at {lpath}")
31 |
32 | data = pd.read_parquet(lpath)
33 |
34 | X = np.stack(data["representation"].values)
35 |
36 | mask = get_mask_for_distances_or_representations(X)
37 |
38 | if not return_compound_ids:
39 | return X[mask]
40 |
41 | indices = data.iloc[mask]["unique_id"].to_numpy()
42 | return X[mask], indices
43 |
44 |
45 | def load_distances_for_downstream_application(
46 | name,
47 | representation,
48 | dataset,
49 | update_cache: bool = False,
50 | return_compound_ids: bool = False,
51 | ):
52 | suffix = ["distances", name, dataset, f"{representation}.parquet"]
53 |
54 | lpath = dm.fs.join(CACHE_DIR, "downstream_applications", *suffix)
55 | if not dm.fs.exists(lpath) or update_cache:
56 | rpath = dm.fs.join(DOWNSTREAM_APPS_DATA_DIR, *suffix)
57 | logger.debug(f"Downloading {rpath} to {lpath}")
58 | dm.fs.copy_file(rpath, lpath, force=update_cache)
59 | else:
60 | logger.debug(f"Using cache at {lpath}")
61 |
62 | data = pd.read_parquet(lpath)
63 |
64 | distances = data["distance"].to_numpy()
65 | mask = get_mask_for_distances_or_representations(distances)
66 |
67 | if not return_compound_ids:
68 | return distances[mask]
69 |
70 | indices = data.iloc[mask]["unique_id"].to_numpy()
71 | return distances[mask], indices
72 |
73 |
74 | def save_figure_with_fsspec(path, exist_ok=False):
75 | if dm.fs.exists(path) and not exist_ok:
76 | raise RuntimeError(f"{path} already exists")
77 |
78 | if dm.fs.is_local_path(path):
79 | plt.savefig(path)
80 | return
81 |
82 | mapper = dm.fs.get_mapper(path)
83 | clean_path = path.rstrip(mapper.fs.sep)
84 | dir_components = str(clean_path).split(mapper.fs.sep)[:-1]
85 | dir_path = mapper.fs.sep.join(dir_components)
86 |
87 | dm.fs.mkdir(dir_path, exist_ok=True)
88 |
89 | with tempfile.TemporaryDirectory() as tmpdir:
90 | lpath = dm.fs.join(tmpdir, f"{str(uuid.uuid4())}.png")
91 |
92 | plt.savefig(lpath)
93 | dm.fs.copy_file(lpath, path, force=exist_ok)
94 |
95 |
96 | def get_outlier_bounds(X, factor: float = 1.5):
97 | q1 = np.quantile(X, 0.25)
98 | q3 = np.quantile(X, 0.75)
99 | iqr = q3 - q1
100 |
101 | lower = max(np.min(X), q1 - factor * iqr)
102 | upper = min(np.max(X), q3 + factor * iqr)
103 |
104 | return lower, upper
105 |
106 |
107 | def bin_with_overlap(data, filter_outliers: bool = True):
108 | if filter_outliers:
109 | minimum, maximum = get_outlier_bounds(data)
110 | window_size = (maximum - minimum) / 10
111 | yield minimum, np.nonzero(data <= minimum)[0]
112 |
113 | else:
114 | minimum = np.min(data)
115 | maximum = np.max(data)
116 | window_size = (maximum - minimum) / 10
117 |
118 | assert minimum >= 0, "A distance cannot be lower than 0"
119 |
120 | x = minimum
121 | step_size = window_size / 20
122 | while x + window_size < maximum:
123 | yield x + 0.5 * window_size, np.nonzero(np.logical_and(data >= x, data < x + window_size))[0]
124 | x += step_size
125 |
126 |     # Yield the remaining data
127 | yield x + ((maximum - x) / 2.0), np.nonzero(data >= x)[0]
128 |
129 |
130 | def get_mask_for_distances_or_representations(X):
131 | # The 1e4 threshold is somewhat arbitrary, but manually chosen to
132 | # filter out compounds that don't make sense (and are outliers).
133 | # (e.g. WHIM for [C-]#N.[C-]#N.[C-]#N.[C-]#N.[C-]#N.[Fe+4].[N-]=O)
134 | # Propagating such high-values would cause issues in downstream
135 | # functions (e.g. KMeans)
136 | mask = [
137 | i
138 | for i, a in enumerate(X)
139 | if a is not None and ~np.isnan(a).any() and np.isfinite(a).all() and ~(a > 1e4).any()
140 | ]
141 | return mask
142 |
143 |
144 | class Timer:
145 | """Context manager for timing operations"""
146 |
147 | def __init__(self, name: Optional[str] = None):
148 | self.name = name if name is not None else "operation"
149 | self.start_time = None
150 | self.end_time = None
151 |
152 | @property
153 | def duration(self):
154 | if self.start_time is None:
155 | raise RuntimeError("Cannot get the duration for an operation that has not started yet")
156 | if self.end_time is None:
157 | return datetime.now() - self.start_time
158 | return self.end_time - self.start_time
159 |
160 | def __enter__(self):
161 | self.start_time = datetime.now()
162 | logger.debug(f"Starting {self.name}")
163 | return self
164 |
165 | def __exit__(self, exc_type, exc_val, exc_tb):
166 | self.end_time = datetime.now()
167 | logger.info(f"Finished {self.name}. Duration: {self.duration}")
168 |
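
A short example of the `Timer` context manager above:

```python
import time

from mood.utils import Timer

with Timer("featurization"):
    time.sleep(0.1)  # stand-in for real work; the duration is logged on exit
```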
--------------------------------------------------------------------------------
/mood/visualize.py:
--------------------------------------------------------------------------------
1 | import fsspec
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | import seaborn as sns
5 | import numpy as np
6 | from typing import Optional, List
7 | from mood.utils import get_outlier_bounds
8 | from mood.metrics import Metric
9 |
10 |
11 | def plot_performance_over_distance(
12 | performance_data: pd.DataFrame,
13 | calibration_data: pd.DataFrame,
14 | dataset_name: str,
15 | ax: Optional = None,
16 | show_legend: bool = True,
17 | show_title: bool = True,
18 | show_xlabel: bool = True,
19 | show_ylabel: bool = True,
20 | ):
21 | if ax is None:
22 | _, ax = plt.subplots(figsize=(12, 6))
23 |
24 | expected_columns = ["distance", "score_lower", "score_mu", "score_upper"]
25 | if not all(c in performance_data.columns for c in expected_columns):
26 | raise ValueError(
27 | f"For performance_data, expecting {expected_columns}, found {performance_data.columns}"
28 | )
29 | if not all(c in calibration_data.columns for c in expected_columns):
30 | raise ValueError(
31 | f"For calibration_data, expecting {expected_columns}, found {calibration_data.columns}"
32 | )
33 |
34 | def _plot(data, color, ax):
35 | sns.lineplot(x=data[:, 0], y=data[:, 2], color=color, ax=ax, lw=4)
36 | ax.fill_between(data[:, 0], data[:, 1], data[:, 3], color=color, alpha=0.2)
37 | return ax
38 |
39 | ax = _plot(performance_data[expected_columns].to_numpy(), "tab:blue", ax)
40 | ax_calibration = _plot(calibration_data[expected_columns].to_numpy(), "tab:orange", ax.twinx())
41 |
42 | perf_metric = Metric.get_default_performance_metric(dataset_name)
43 | cali_metric = Metric.get_default_calibration_metric(dataset_name)
44 |
45 | if perf_metric.mode == "min":
46 | ax.invert_yaxis()
47 | if cali_metric.mode == "min":
48 | ax_calibration.invert_yaxis()
49 |
50 | if show_ylabel:
51 | label = f"Calibration ({cali_metric.name})"
52 | ax_calibration.set_ylabel(label, rotation=-90, labelpad=18, fontsize=12)
53 |
54 | label = f"Performance ({perf_metric.name})"
55 | ax.set_ylabel(label, fontsize=12)
56 |
57 | if show_xlabel:
58 | ax.set_xlabel("Distance")
59 |
60 | if show_title:
61 | ax.set_title(dataset_name, fontsize=18)
62 |
63 | if show_legend:
64 | legend_lines = [
65 | plt.Line2D([0], [0], color="tab:blue", lw=4),
66 | plt.Line2D([0], [0], color="tab:orange", lw=4),
67 | ]
68 | labels = ["Performance", "Calibration"]
69 |
70 | ax.legend(legend_lines, labels, fontsize=12, loc="lower center", ncol=len(labels), fancybox=True)
71 |
72 | return ax, ax_calibration
73 |
74 |
75 | def plot_distance_distributions(
76 | distances,
77 | labels: Optional[List[str]] = None,
78 | colors: Optional[List[str]] = None,
79 | styles: Optional[List[str]] = None,
80 | ax: Optional = None,
81 | outlier_factor: Optional[float] = 3.0,
82 | ):
83 | n = len(distances)
84 | show_legend = True
85 |
86 | # Set defaults
87 | if ax is None:
88 | fig, ax = plt.subplots(figsize=(12, 6))
89 | if colors is None:
90 | cmap = sns.color_palette("rocket", n)
91 | colors = [cmap[i] for i in range(n)]
92 | if labels is None:
93 | show_legend = False
94 | labels = [""] * n
95 | if styles is None:
96 | styles = ["-"] * n
97 |
98 | ax.spines["right"].set_visible(False)
99 | ax.spines["left"].set_visible(False)
100 | ax.spines["top"].set_visible(False)
101 | ax.yaxis.set_ticklabels([])
102 | ax.yaxis.set_ticks([])
103 |
104 | if outlier_factor is not None:
105 | all_distances = np.concatenate(distances)
106 | lower, upper = get_outlier_bounds(all_distances, factor=outlier_factor)
107 | distances = [X[(X >= lower) & (X <= upper)] for X in distances]
108 |
109 | # Visualize all splitting methods
110 | for idx, dist in enumerate(distances):
111 | sns.kdeplot(dist, color=colors[idx], linestyle=styles[idx], ax=ax, label=labels[idx])
112 |
113 | ax.set_xlabel(f"Distance")
114 |
115 | if show_legend:
116 | ax.legend()
117 |
118 | return ax
119 |
120 |
121 | def axes_grid_iterator(
122 | col_labels: List[str],
123 | row_labels: List[str],
124 | col_size: int = 5,
125 | row_size: int = 5,
126 | fontsize: int = 24,
127 | margin: float = 0.25,
128 | ):
129 | ncols = len(col_labels)
130 | nrows = len(row_labels)
131 |
132 | fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(col_size * ncols, row_size * nrows))
133 | axs = np.atleast_2d(axs)
134 |
135 | for ri, row in enumerate(row_labels):
136 | for ci, col in enumerate(col_labels):
137 | ax = axs[ri][ci]
138 | if ci == 0:
139 | ax.text(
140 | -margin,
141 | 0.5,
142 | row,
143 | rotation="vertical",
144 | va="center",
145 | ha="center",
146 | transform=ax.transAxes,
147 | fontsize=fontsize,
148 | )
149 | if ri == 0:
150 | ax.text(
151 | 0.5, 1 + margin, col, transform=ax.transAxes, va="center", ha="center", fontsize=fontsize
152 | )
153 | yield ax, ri, ci
154 |
--------------------------------------------------------------------------------
/notebooks/ToC_graphic.tiff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/valence-labs/mood-experiments/4788e0c57f557916792247eadebbe61d2fa91714/notebooks/ToC_graphic.tiff
--------------------------------------------------------------------------------
/notebooks/assets/checkmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/valence-labs/mood-experiments/4788e0c57f557916792247eadebbe61d2fa91714/notebooks/assets/checkmark.png
--------------------------------------------------------------------------------
/notebooks/assets/cross.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/valence-labs/mood-experiments/4788e0c57f557916792247eadebbe61d2fa91714/notebooks/assets/cross.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "mood"
7 | description = "A python library to accompany the MOOD paper from Tossou et al. (2023)"
8 | authors = [{ name = "Cas Wognum", email = "cas@valencediscovery.com" }]
9 | readme = "README.md"
10 | dynamic = ["version"]
11 | requires-python = ">=3.10,<3.11"
12 | license = { text = "Apache" }
13 | classifiers = [
14 | "Development Status :: 2 - Pre-Alpha",
15 | "Intended Audience :: Developers",
16 | "Intended Audience :: Healthcare Industry",
17 | "Intended Audience :: Science/Research",
18 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
19 | "Topic :: Scientific/Engineering :: Bio-Informatics",
20 | "Topic :: Scientific/Engineering :: Information Analysis",
21 | "Topic :: Scientific/Engineering :: Medical Science Apps.",
22 | "Natural Language :: English",
23 | "Operating System :: OS Independent",
24 | "Programming Language :: Python",
25 | "Programming Language :: Python :: 3",
26 | "Programming Language :: Python :: 3.10",
27 | ]
28 |
29 | dependencies = [
30 | "pandas",
31 | "matplotlib",
32 | "scikit-learn",
33 | "torchmetrics",
34 | "pytorch-lightning <2.0",
35 | "torch >=1.10.2",
36 | "numpy <1.24",
37 | "tqdm",
38 | "optuna",
39 | "datamol",
40 | "notebook",
41 | "pytdc",
42 | "typer",
43 | "gcsfs",
44 | "pyarrow",
45 | "fastparquet",
46 | "transformers",
47 | ]
48 |
49 | [project.scripts]
50 | mood = "mood.__main__:app"
51 |
52 | [tool.black]
53 | line-length = 110
54 | target-version = ['py310']
55 | include = '\.pyi?$'
56 |
57 | [project.urls]
58 | "Source Code" = "https://github.com/cwognum/mood-experiments/tree/main"
59 | "Paper" = "TODO"
60 |
61 | [tool.setuptools]
62 | include-package-data = true
63 |
64 | [tool.setuptools_scm]
65 | fallback_version = "dev"
66 |
67 | [tool.setuptools.packages.find]
68 | where = ["."]
69 | include = ["mood", "mood.*", "scripts", "scripts.*"]
70 | exclude = []
71 | namespaces = true
72 |
73 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/valence-labs/mood-experiments/4788e0c57f557916792247eadebbe61d2fa91714/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/__main__.py:
--------------------------------------------------------------------------------
1 | from .cli import app
2 |
3 |
4 | if __name__ == "__main__":
5 | app()
6 |
--------------------------------------------------------------------------------
/scripts/cli.py:
--------------------------------------------------------------------------------
1 | import typer
2 | from scripts.compare_spaces import cli as model_vs_input_space_cmd
3 | from scripts.compare_splits import cli as compare_splits_cmd
4 | from scripts.compare_performance import cli as iid_ood_gap_cmd
5 | from scripts.precompute_representations import cli as precompute_representation_cmd
6 | from scripts.precompute_distances import cli as precompute_distances_cmd
7 | from scripts.visualize_shift import cli as visualize_shift_cmd
8 | from scripts.visualize_splits import cli as visualize_splits_cmd
9 | from scripts.visualize_perf_over_distance import cli as perf_over_distance_cmd
10 |
11 |
12 | compare_app = typer.Typer(help="Various CLIs that involve comparing two things")
13 |
14 | compare_app.command(
15 | name="spaces", help="Compare how distances in the Model and various Input spaces correlate"
16 | )(model_vs_input_space_cmd)
17 |
18 | compare_app.command(
19 | name="performance", help="Compare how the model performs on compounds in the IID and OOD range"
20 | )(iid_ood_gap_cmd)
21 |
22 | compare_app.command(
23 | name="splits",
24 | help="Compare how different splits replicate the shift between train and downstream applications",
25 | )(compare_splits_cmd)
26 |
27 |
28 | precompute_app = typer.Typer(help="Various CLIs that precompute data used later on")
29 |
30 | precompute_app.command(
31 | name="representation", help="Precompute representations and save these as .parquet files"
32 | )(precompute_representation_cmd)
33 |
34 | precompute_app.command(
35 | name="distances", help="Precompute distances from downstream applications to the different train sets"
36 | )(precompute_distances_cmd)
37 |
38 |
39 | visualize_app = typer.Typer(help="Various CLIs to visualize results")
40 |
41 | visualize_app.command(name="shift", help="Visualize the shift from train to downstream applications")(
42 | visualize_shift_cmd
43 | )
44 |
45 | visualize_app.command(
46 | name="splits", help="Visualize how representative different splits are of downstream applications"
47 | )(visualize_splits_cmd)
48 |
49 | visualize_app.command(
50 | name="performance_over_distance", help="Visualize how performance and calibration evolve over distance"
51 | )(perf_over_distance_cmd)
52 |
53 |
54 | app = typer.Typer(help="CLI for the various stand-alone scripts of MOOD")
55 | app.add_typer(compare_app, name="compare")
56 | app.add_typer(precompute_app, name="precompute")
57 | app.add_typer(visualize_app, name="visualize")
58 |
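
These commands compose a small Typer app that is separate from the main `mood` CLI; a sketch of invoking it from the repository root:

```shell
python -m scripts --help
python -m scripts compare splits --help
```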
--------------------------------------------------------------------------------
/scripts/compare_performance.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import fsspec
3 |
4 | import pandas as pd
5 | import numpy as np
6 | import datamol as dm
7 |
8 | from typing import Optional
9 | from datetime import datetime
10 | from loguru import logger
11 | from sklearn.model_selection import train_test_split
12 | from tqdm import tqdm
13 |
14 | from mood.constants import RESULTS_DIR
15 | from mood.dataset import load_data_from_tdc, MOOD_REGR_DATASETS
16 | from mood.metrics import Metric, compute_bootstrapped_metric
17 | from mood.representations import featurize
18 | from mood.baselines import predict_baseline_uncertainty
19 | from mood.train import train_baseline_model
20 | from mood.experiment import basic_tuning_loop
21 | from mood.utils import bin_with_overlap, load_distances_for_downstream_application
22 | from mood.distance import compute_knn_distance
23 | from mood.preprocessing import DEFAULT_PREPROCESSING
24 |
25 |
26 | def cli(
27 | baseline_algorithm: str,
28 | representation: str,
29 | dataset: str,
30 | n_seeds: int = 5,
31 | n_trials: int = 50,
32 | n_startup_trials: int = 10,
33 | base_save_dir: str = RESULTS_DIR,
34 | sub_save_dir: Optional[str] = None,
35 | overwrite: bool = False,
36 | ):
37 | if sub_save_dir is None:
38 | sub_save_dir = datetime.now().strftime("%Y%m%d")
39 | out_dir = dm.fs.join(base_save_dir, "dataframes", "compare_performance", sub_save_dir)
40 | dm.fs.mkdir(out_dir, exist_ok=True)
41 |
42 | # Load the dataset
43 | smiles, y = load_data_from_tdc(dataset)
44 | X, mask = featurize(
45 | smiles,
46 | representation,
47 | standardize_fn=DEFAULT_PREPROCESSING[representation],
48 | disable_logs=True,
49 | )
50 | y = y[mask]
51 | is_regression = dataset in MOOD_REGR_DATASETS
52 |
53 | # Get the metrics
54 | perf_metric = Metric.get_default_performance_metric(dataset)
55 | cali_metric = Metric.get_default_calibration_metric(dataset)
56 |
57 | # Generate all data needed for these plots
58 | dist_train = []
59 | dist_test = []
60 | y_pred = []
61 | y_true = []
62 | y_uncertainty = []
63 |
64 | for seed in range(n_seeds):
65 | # Randomly split the dataset
66 | # This ensures that the distribution of distances from val to train is relatively uniform
67 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)
68 | X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=seed)
69 |
70 | file_name = f"best_hparams_{dataset}_{baseline_algorithm}_{representation}_{seed}.yaml"
71 | out_path = dm.fs.join(out_dir, file_name)
72 |
73 | if dm.fs.exists(out_path):
74 | # Load the data of the completed hyper-param study if it already exists
75 | logger.info(f"Loading the best hyper-params from {out_path}")
76 | with fsspec.open(out_path) as fd:
77 | params = yaml.safe_load(fd)
78 |
79 | else:
80 | # Run a hyper-parameter search
81 | study = basic_tuning_loop(
82 | X_train=X_train,
83 | X_test=X_val,
84 | y_train=y_train,
85 | y_test=y_val,
86 | name=baseline_algorithm,
87 | is_regression=is_regression,
88 | metric=perf_metric,
89 | global_seed=seed,
90 | n_trials=n_trials,
91 | n_startup_trials=n_startup_trials,
92 | )
93 |
94 | params = study.best_params
95 | random_state = seed + study.best_trial.number
96 | params["random_state"] = random_state
97 |
98 | logger.info(f"Saving the best hyper-params to {out_path}")
99 | with fsspec.open(out_path, "w") as fd:
100 | yaml.dump(params, fd)
101 |
102 | file_name = f"trials_{dataset}_{baseline_algorithm}_{representation}_{seed}.csv"
103 | out_path = dm.fs.join(out_dir, file_name)
104 |
105 | logger.info(f"Saving the trials dataframe to {out_path}")
106 | study.trials_dataframe().to_csv(out_path)
107 |
108 | random_state = params.pop("random_state")
109 | model = train_baseline_model(
110 | X_train,
111 | y_train,
112 | baseline_algorithm,
113 | is_regression,
114 | params,
115 | random_state,
116 | for_uncertainty_estimation=True,
117 | ensemble_size=10,
118 | )
119 |
120 | y_pred_ = model.predict(X_test)
121 | y_uncertainty_ = predict_baseline_uncertainty(model, X_test)
122 |
123 | y_pred.append(y_pred_)
124 | y_true.append(y_test)
125 | y_uncertainty.append(y_uncertainty_)
126 |
127 | dist_train_, dist_test_ = compute_knn_distance(X_train, [X_train, X_test])
128 | dist_train.append(dist_train_)
129 | dist_test.append(dist_test_)
130 |
131 | dist_test = np.concatenate(dist_test)
132 | dist_train = np.concatenate(dist_train)
133 | y_pred = np.concatenate(y_pred)
134 | y_true = np.concatenate(y_true)
135 | y_uncertainty = np.concatenate(y_uncertainty)
136 |
137 | # Collect the distances of the downstream applications
138 | dist_scr = load_distances_for_downstream_application(
139 | "virtual_screening", representation, dataset, update_cache=True
140 | )
141 | dist_opt = load_distances_for_downstream_application(
142 | "optimization", representation, dataset, update_cache=True
143 | )
144 | dist_app = np.concatenate((dist_opt, dist_scr))
145 |
146 | # Compute the difference in IID and OOD performance and calibration
147 | lower, upper = np.quantile(dist_train, 0.025), np.quantile(dist_train, 0.975)
148 | mask = np.logical_and(dist_test >= lower, dist_test <= upper)
149 | score_iid = perf_metric(y_true[mask], y_pred[mask])
150 | calibration_iid = cali_metric(y_true[mask], y_pred[mask], y_uncertainty[mask])
151 | logger.info(f"Found an IID {perf_metric.name} score of {score_iid:.3f}")
152 | logger.info(f"Found an IID {cali_metric.name} calibration score of {calibration_iid:.3f}")
153 |
154 | lower, upper = np.quantile(dist_app, 0.025), np.quantile(dist_app, 0.975)
155 | mask = np.logical_and(dist_test >= lower, dist_test <= upper)
156 | score_ood = perf_metric(y_true[mask], y_pred[mask])
157 | calibration_ood = cali_metric(y_true[mask], y_pred[mask], y_uncertainty[mask])
158 | logger.info(f"Found an OOD {perf_metric.name} score of {score_ood:.3f}")
159 | logger.info(f"Found an OOD {cali_metric.name} calibration score of {calibration_ood:.3f}")
160 |
161 | file_name = f"gap_{dataset}_{baseline_algorithm}_{representation}.csv"
162 | out_path = dm.fs.join(out_dir, file_name)
163 | if dm.fs.exists(out_path) and not overwrite:
164 | raise RuntimeError(f"{out_path} already exists!")
165 |
166 | # Saving this as a CSV might be a bit wasteful,
167 | # but it's convenient
168 | logger.info(f"Saving the IID/OOD gap data to {out_path}")
169 |
170 | pd.DataFrame(
171 | {
172 | "dataset": dataset,
173 | "algorithm": baseline_algorithm,
174 | "representation": representation,
175 | "iid_score": [score_iid, calibration_iid],
176 | "ood_score": [score_ood, calibration_ood],
177 | "metric": [perf_metric.name, cali_metric.name],
178 | "type": ["performance", "calibration"],
179 | }
180 | ).to_csv(out_path, index=False)
181 |
182 | # Compute the performance over distance
183 | df = pd.DataFrame()
184 | for distance, mask in tqdm(bin_with_overlap(dist_test)):
185 | target = y_true[mask]
186 | preds = y_pred[mask]
187 | uncertainty = y_uncertainty[mask]
188 |
189 | n_samples = len(mask)
190 | if n_samples < 25 or len(np.unique(target)) == 1:
191 | continue
192 |
193 | perf_mu, perf_std = compute_bootstrapped_metric(
194 | targets=target, predictions=preds, metric=perf_metric, n_jobs=-1
195 | )
196 |
197 | cali_mu, cali_std = compute_bootstrapped_metric(
198 | targets=target, predictions=preds, uncertainties=uncertainty, metric=cali_metric, n_jobs=-1
199 | )
200 |
201 | df_ = pd.DataFrame(
202 | {
203 | "dataset": dataset,
204 | "algorithm": baseline_algorithm,
205 | "representation": representation,
206 | "distance": distance,
207 | "score_mu": [perf_mu, cali_mu],
208 | "score_std": [perf_std, cali_std],
209 | "type": ["performance", "calibration"],
210 | "metric": [perf_metric.name, cali_metric.name],
211 | "n_samples": n_samples,
212 | }
213 | )
214 | df = pd.concat((df, df_), ignore_index=True)
215 |
216 | file_name = f"perf_over_distance_{dataset}_{baseline_algorithm}_{representation}.csv"
217 | out_path = dm.fs.join(out_dir, file_name)
218 | if dm.fs.exists(out_path) and not overwrite:
219 | raise RuntimeError(f"{out_path} already exists!")
220 |
221 | logger.info(f"Saving the performance over distance data to {out_path}")
222 | df.to_csv(out_path, index=False)
223 |
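A note on the gap computation above: the same test predictions are scored twice, once restricted to the distance band that is typical of the training set (IID) and once restricted to the band that is typical of the downstream applications (OOD). Below is a minimal, self-contained sketch of that masking; the synthetic arrays and `mean_absolute_error` are stand-ins for the real distances and the dataset's default metric.

```python
import numpy as np
from sklearn.metrics import mean_absolute_error

rng = np.random.default_rng(0)

# Synthetic stand-ins for the arrays accumulated across seeds above
dist_train = rng.gamma(2.0, 0.1, size=1000)  # train-to-train k-NN distances
dist_test = rng.gamma(3.0, 0.1, size=1000)   # test-to-train k-NN distances
dist_app = rng.gamma(5.0, 0.1, size=1000)    # distances seen in a downstream application
y_true = rng.normal(size=1000)
y_pred = y_true + dist_test * rng.normal(size=1000)  # errors grow with distance

def band_mask(dist_ref, dist, lower_q=0.025, upper_q=0.975):
    """Select points whose distance falls inside the central band of a reference distribution."""
    lower, upper = np.quantile(dist_ref, lower_q), np.quantile(dist_ref, upper_q)
    return (dist >= lower) & (dist <= upper)

# IID: test points that look like the training set; OOD: test points that look like deployment
iid = band_mask(dist_train, dist_test)
ood = band_mask(dist_app, dist_test)
print(f"IID MAE: {mean_absolute_error(y_true[iid], y_pred[iid]):.3f}")
print(f"OOD MAE: {mean_absolute_error(y_true[ood], y_pred[ood]):.3f}")
```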
--------------------------------------------------------------------------------
/scripts/compare_spaces.py:
--------------------------------------------------------------------------------
1 | import typer
2 | import fsspec
3 | import datamol as dm
4 | import numpy as np
5 | import pandas as pd
6 | import matplotlib.pyplot as plt
7 |
8 | from datetime import datetime
9 | from loguru import logger
10 | from typing import List, Optional
11 |
12 | from scipy.stats import pearsonr, spearmanr
13 | from sklearn.metrics import r2_score
14 | from sklearn.gaussian_process.kernels import PairwiseKernel, Sum, WhiteKernel
15 | from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier
16 |
17 | from mood.dataset import dataset_iterator, MOOD_REGR_DATASETS
18 | from mood.model_space import ModelSpaceTransformer
19 | from mood.preprocessing import DEFAULT_PREPROCESSING
20 | from mood.distance import compute_knn_distance
21 | from mood.visualize import plot_distance_distributions
22 | from mood.representations import representation_iterator, featurize
23 | from mood.constants import RESULTS_DIR
24 | from mood.utils import (
25 | load_representation_for_downstream_application,
26 | save_figure_with_fsspec,
27 | get_outlier_bounds,
28 | )
29 | from mood.train import train_baseline_model
30 |
31 |
32 | def train_gp(X, y, is_regression):
33 | alpha = 1e-10
34 |     for _ in range(10):
35 | try:
36 | if is_regression:
37 | kernel = PairwiseKernel(metric="linear")
38 | model = GaussianProcessRegressor(kernel, alpha=alpha, random_state=0)
39 | else:
40 | kernel = Sum(PairwiseKernel(metric="linear"), WhiteKernel(noise_level=alpha))
41 | model = GaussianProcessClassifier(kernel, random_state=0)
42 | model.fit(X, y)
43 | return model
44 | except (np.linalg.LinAlgError, ValueError):
45 | # ValueError: array must not contain infs or NaNs
46 | # LinAlgError: N-th leading minor of the array is not positive definite
47 | # LinAlgError: The kernel is not returning a positive definite matrix
48 | alpha = alpha * 10
49 | return None
50 |
51 |
52 | def get_model_space_distances(model, train, queries):
53 | embedding_size = int(round(train.shape[1] * 0.25))
54 | trans = ModelSpaceTransformer(model, embedding_size)
55 |
56 | model_space_train = trans(train)
57 | queries = [trans(q) for q in queries]
58 |
59 | distances = compute_knn_distance(model_space_train, queries, n_jobs=-1)
60 | return distances
61 |
62 |
63 | def compute_correlations(input_spaces, model_spaces, labels):
64 | lower, upper = get_outlier_bounds(np.concatenate(input_spaces), factor=3.0)
65 | input_masks = [(X >= lower) & (X <= upper) for X in input_spaces]
66 |
67 | lower, upper = get_outlier_bounds(np.concatenate(model_spaces), factor=3.0)
68 | model_masks = [(X >= lower) & (X <= upper) for X in model_spaces]
69 |
70 | masks = [mask1 & mask2 for mask1, mask2 in zip(input_masks, model_masks)]
71 | input_spaces = [d[mask] for d, mask in zip(input_spaces, masks)]
72 | model_spaces = [d[mask] for d, mask in zip(model_spaces, masks)]
73 |
74 | df = pd.DataFrame()
75 | for input_space, model_space, label in zip(input_spaces, model_spaces, labels):
76 | df_ = pd.DataFrame(
77 | {
78 | "pearson": pearsonr(input_space, model_space)[0],
79 | "spearman": spearmanr(input_space, model_space)[0],
80 | "r_squared": r2_score(input_space, model_space),
81 | },
82 | index=[0],
83 | )
84 | df = pd.concat((df, df_), ignore_index=True)
85 |
86 | return df
87 |
88 |
89 | def cli(
90 | base_save_dir: str = RESULTS_DIR,
91 | sub_save_dir: Optional[str] = None,
92 | overwrite: bool = False,
93 | skip_representation: Optional[List[str]] = None,
94 | skip_dataset: Optional[List[str]] = None,
95 | batch_size: int = 16,
96 | ):
97 | if sub_save_dir is None:
98 | sub_save_dir = datetime.now().strftime("%Y%m%d")
99 |
100 | fig_save_dir = dm.fs.join(base_save_dir, "figures", "compare_spaces", sub_save_dir)
101 | dm.fs.mkdir(fig_save_dir, exist_ok=True)
102 | logger.info(f"Saving figures to {fig_save_dir}")
103 |
104 | np_save_dir = dm.fs.join(base_save_dir, "numpy", "compare_spaces", sub_save_dir)
105 | dm.fs.mkdir(np_save_dir, exist_ok=True)
106 | logger.info(f"Saving NumPy arrays to {np_save_dir}")
107 |
108 | df_save_dir = dm.fs.join(base_save_dir, "dataframes", "compare_spaces", sub_save_dir)
109 | dm.fs.mkdir(df_save_dir, exist_ok=True)
110 | corr_path = dm.fs.join(df_save_dir, "correlations.csv")
111 |     if dm.fs.exists(corr_path) and not overwrite:
112 | raise typer.BadParameter(f"{corr_path} already exists. Specify --overwrite or --base-save-dir.")
113 | logger.info(f"Saving correlation dataframe to {corr_path}")
114 |
115 | df_corr = pd.DataFrame()
116 |
117 | dataset_it = dataset_iterator(blacklist=skip_dataset)
118 |
119 | for dataset, (smiles, y) in dataset_it:
120 | representation_it = representation_iterator(
121 | smiles,
122 | n_jobs=-1,
123 | progress=True,
124 | blacklist=skip_representation,
125 | standardize_fn=DEFAULT_PREPROCESSING,
126 | batch_size=batch_size,
127 | )
128 |
129 | for representation, (X, mask) in representation_it:
130 | y_repr = y[mask]
131 |
132 | virtual_screening = load_representation_for_downstream_application(
133 | "virtual_screening", representation, update_cache=True
134 | )
135 | optimization = load_representation_for_downstream_application(
136 | "optimization", representation, update_cache=True
137 | )
138 |
139 | is_regression = dataset in MOOD_REGR_DATASETS
140 | mlp_model = train_baseline_model(X, y_repr, "MLP", is_regression)
141 | rf_model = train_baseline_model(X, y_repr, "RF", is_regression)
142 | # We use a custom train function for GPs to include a retry system
143 | gp_model = train_gp(X, y_repr, is_regression)
144 |
145 | # Distances in input spaces
146 | input_distances = compute_knn_distance(X, [X, optimization, virtual_screening], n_jobs=-1)
147 | labels = ["Train", "Optimization", "Virtual Screening"]
148 |
149 | for dist, label in zip(input_distances, labels):
150 | path = dm.fs.join(np_save_dir, f"{dataset}_{representation}_{label}_input_space.npy")
151 |                 if dm.fs.exists(path) and not overwrite:
152 | raise RuntimeError(f"{path} already exists. Specify --overwrite or --base-save-dir.")
153 | with fsspec.open(path, "wb") as fd:
154 | np.save(fd, dist)
155 |
156 | ax = plot_distance_distributions(input_distances, labels=labels)
157 | ax.set_title(f"Input space ({representation}, {dataset})")
158 | save_figure_with_fsspec(
159 | dm.fs.join(fig_save_dir, f"{dataset}_{representation}_input_space.png"), exist_ok=overwrite
160 | )
161 | plt.close()
162 |
163 | # Distances in different model spaces
164 | for name, model in {"MLP": mlp_model, "RF": rf_model, "GP": gp_model}.items():
165 | if model is None:
166 | logger.warning(f"Failed to train a {name} model for {dataset} on {representation}")
167 | continue
168 |
169 | model_distances = get_model_space_distances(model, X, [X, optimization, virtual_screening])
170 | for dist, label in zip(model_distances, labels):
171 | path = dm.fs.join(np_save_dir, f"{dataset}_{representation}_{label}_{name}_space.npy")
172 | if dm.fs.exists(df_save_dir) and not overwrite:
173 | raise RuntimeError(f"{path} already exists. Specify --overwrite or --base-save-dir.")
174 | with fsspec.open(path, "wb") as fd:
175 | np.save(fd, dist)
176 |
177 | ax = plot_distance_distributions(model_distances, labels=labels)
178 | ax.set_title(f"{name} space ({representation}, {dataset})")
179 | save_figure_with_fsspec(
180 | dm.fs.join(fig_save_dir, f"{dataset}_{representation}_{name}_space.png"),
181 | exist_ok=overwrite,
182 | )
183 | plt.close()
184 |
185 | # Compute correlation
186 | df = compute_correlations(
187 | input_distances,
188 | model_distances,
189 | labels,
190 | )
191 | df["model"] = name
192 | df["dataset"] = dataset
193 | df["representation"] = representation
194 | df_corr = pd.concat((df_corr, df), ignore_index=True)
195 |
196 | df_corr.to_csv(corr_path, index=False)
197 |
198 |
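The `train_gp` helper above exists because linear-kernel GPs frequently hit ill-conditioned kernel matrices; each failed fit is retried with ten times more jitter. The same pattern in isolation (regression branch only), runnable on synthetic data:

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import PairwiseKernel

def fit_gp_with_jitter(X, y, max_retries=10, alpha=1e-10):
    """Retry the fit with 10x more diagonal jitter whenever the kernel is ill-conditioned."""
    for _ in range(max_retries):
        try:
            model = GaussianProcessRegressor(PairwiseKernel(metric="linear"), alpha=alpha, random_state=0)
            model.fit(X, y)
            return model
        except (np.linalg.LinAlgError, ValueError):
            alpha *= 10  # stronger regularization of the kernel diagonal
    return None  # the caller decides what a failed fit means

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8))
y = X @ rng.normal(size=8)
model = fit_gp_with_jitter(X, y)
print(model is not None)
```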
--------------------------------------------------------------------------------
/scripts/compare_splits.py:
--------------------------------------------------------------------------------
1 | import fsspec
2 | import numpy as np
3 | import pandas as pd
4 | import datamol as dm
5 |
6 | from loguru import logger
7 | from typing import Optional, List
8 | from datetime import datetime
9 |
10 | from mood.dataset import dataset_iterator
11 | from mood.representations import representation_iterator
12 | from mood.splitter import MOODSplitter, get_mood_splitters
13 | from mood.preprocessing import DEFAULT_PREPROCESSING
14 | from mood.utils import load_distances_for_downstream_application, save_figure_with_fsspec
15 | from mood.constants import RESULTS_DIR
16 | from mood.distance import get_distance_metric
17 |
18 |
19 | def cli(
20 | base_save_dir: str = RESULTS_DIR,
21 | sub_save_dir: Optional[str] = None,
22 | skip_representation: Optional[List[str]] = None,
23 | skip_dataset: Optional[List[str]] = None,
24 | save_figures: bool = True,
25 | n_splits: int = 5,
26 | seed: Optional[int] = 0,
27 | use_cache: bool = False,
28 | batch_size: int = 16,
29 | verbose: bool = False,
30 | overwrite: bool = False,
31 | n_jobs: Optional[int] = None,
32 | ):
33 | df = pd.DataFrame()
34 |
35 | if sub_save_dir is None:
36 | sub_save_dir = datetime.now().strftime("%Y%m%d")
37 |
38 | fig_out_dir = dm.fs.join(base_save_dir, "figures", "compare_splits", sub_save_dir)
39 | dm.fs.mkdir(fig_out_dir, exist_ok=True)
40 |
41 | data_out_dir = dm.fs.join(base_save_dir, "numpy", "compare_splits", sub_save_dir)
42 | dm.fs.mkdir(data_out_dir, exist_ok=True)
43 |
44 | dataset_it = dataset_iterator(blacklist=skip_dataset)
45 | for dataset, (smiles, y) in dataset_it:
46 | representation_it = representation_iterator(
47 | smiles,
48 | n_jobs=n_jobs,
49 | progress=verbose,
50 | standardize_fn=DEFAULT_PREPROCESSING,
51 | batch_size=batch_size,
52 | blacklist=skip_representation,
53 | )
54 |
55 | for representation, (X, mask) in representation_it:
56 |             logger.info("Loading precomputed distances for virtual screening")
57 | distances_vs = load_distances_for_downstream_application(
58 | "virtual_screening", representation, dataset, update_cache=not use_cache
59 | )
60 |
61 |             logger.info("Loading precomputed distances for optimization")
62 | distances_op = load_distances_for_downstream_application(
63 | "optimization", representation, dataset, update_cache=not use_cache
64 | )
65 |
66 | metric = get_distance_metric(X)
67 | if metric == "jaccard":
68 | X = X.astype(bool)
69 |
70 | splitters = get_mood_splitters(smiles[mask], n_splits, seed, n_jobs=n_jobs)
71 | splitter = MOODSplitter(splitters, np.concatenate((distances_vs, distances_op)), metric, k=5)
72 | splitter.fit(X, progress=verbose, plot=save_figures)
73 |
74 | if save_figures:
75 | out_path = dm.fs.join(fig_out_dir, f"fig_{dataset}_{representation}.png")
76 | if not overwrite and dm.fs.exists(out_path):
77 | raise RuntimeError(
78 | f"{out_path} already exists. Specify a different path or use --overwrite"
79 | )
80 | logger.info(f"Saving figure to {out_path}")
81 | save_figure_with_fsspec(out_path, exist_ok=overwrite)
82 |
83 | for char in splitter._split_chars:
84 | out_path = dm.fs.join(data_out_dir, f"distances_{dataset}_{representation}_{char.label}.npy")
85 | if not overwrite and dm.fs.exists(out_path):
86 | raise RuntimeError(
87 | f"{out_path} already exists. Specify a different path or use --overwrite"
88 | )
89 | logger.info(f"Saving distance data to {out_path}")
90 | with fsspec.open(out_path, "wb") as fd:
91 | np.save(fd, char.distances)
92 |
93 | df_ = splitter.get_protocol_results()
94 | df_["representation"] = representation
95 | df_["dataset"] = dataset
96 |
97 | df = pd.concat((df, df_), ignore_index=True)
98 |
99 | out_dir = dm.fs.join(base_save_dir, "dataframes", "compare_splits", sub_save_dir)
100 | dm.fs.mkdir(out_dir, exist_ok=True)
101 |
102 | out_path = dm.fs.join(out_dir, "splits.csv")
103 | if not overwrite and dm.fs.exists(out_path):
104 | raise RuntimeError(f"{out_path} already exists. Specify a different path or use --overwrite")
105 |
106 |     logger.info(f"Saving dataframe to {out_path}")
107 |     df.to_csv(out_path, index=False)
108 |
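Note that `X` is cast to booleans whenever the selected metric is Jaccard, here and in `visualize_splits.py`. The sketch below illustrates the idea; `pick_distance_metric` is a hypothetical stand-in for `mood.distance.get_distance_metric`, which holds the actual rule.

```python
import numpy as np

def pick_distance_metric(X) -> str:
    """Hypothetical stand-in: binary feature matrices get Jaccard, continuous ones Euclidean."""
    is_binary = bool(np.all((X == 0) | (X == 1)))
    return "jaccard" if is_binary else "euclidean"

rng = np.random.default_rng(0)
X_fp = rng.integers(0, 2, size=(10, 2048)).astype(float)  # fingerprint-like features

metric = pick_distance_metric(X_fp)
if metric == "jaccard":
    # Jaccard is defined on sets, so the common convention is to compare boolean arrays
    X_fp = X_fp.astype(bool)
print(metric, X_fp.dtype)
```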
--------------------------------------------------------------------------------
/scripts/precompute_distances.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import datamol as dm
3 |
4 | from typing import Optional, List
5 | from mood.dataset import dataset_iterator
6 | from mood.representations import representation_iterator
7 | from mood.constants import DOWNSTREAM_APPS_DATA_DIR
8 | from mood.utils import load_representation_for_downstream_application
9 | from mood.distance import compute_knn_distance
10 | from mood.preprocessing import DEFAULT_PREPROCESSING
11 |
12 |
13 | def save(distances, compounds, molecule_set, dataset, representation, overwrite):
14 | out_path = dm.fs.join(
15 | DOWNSTREAM_APPS_DATA_DIR, "distances", molecule_set, dataset, f"{representation}.parquet"
16 | )
17 | if dm.fs.exists(out_path) and not overwrite:
18 | raise RuntimeError(f"{out_path} already exists!")
19 |
20 | df = pd.DataFrame({"unique_id": compounds, "distance": distances})
21 | df.to_parquet(out_path)
22 | return df
23 |
24 |
25 | def cli(
26 | overwrite: bool = False,
27 | representation: Optional[List[str]] = None,
28 | dataset: Optional[List[str]] = None,
29 | verbose: bool = False,
30 | ):
31 |     if not dataset:  # Typer passes an empty list when the option is not given; handle None too
32 |         dataset = None
33 |     if not representation:
34 |         representation = None
35 |
36 | for dataset, (smiles, y) in dataset_iterator(whitelist=dataset, disable_logs=True):
37 | it = representation_iterator(
38 | smiles,
39 | standardize_fn=DEFAULT_PREPROCESSING,
40 | n_jobs=-1,
41 | progress=True,
42 | whitelist=representation,
43 | disable_logs=True,
44 | )
45 |
46 | for representation, (X, mask) in it:
47 | virtual_screening, vs_compounds = load_representation_for_downstream_application(
48 | "virtual_screening",
49 | representation,
50 | return_compound_ids=True,
51 | update_cache=True,
52 | )
53 | optimization, opt_compounds = load_representation_for_downstream_application(
54 | "optimization",
55 | representation,
56 | return_compound_ids=True,
57 | update_cache=True,
58 | )
59 |
60 | input_distances = compute_knn_distance(X, [optimization, virtual_screening], n_jobs=-1)
61 |
62 | save(input_distances[0], opt_compounds, "optimization", dataset, representation, overwrite)
63 | save(input_distances[1], vs_compounds, "virtual_screening", dataset, representation, overwrite)
64 |
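`compute_knn_distance` does the heavy lifting in this script. The sketch below shows its presumed semantics, the mean distance from each query molecule to its k nearest training neighbors, written directly against scikit-learn; the `k=5` default here is an assumption, not the library's documented value.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def knn_distance(train, queries, k=5, metric="euclidean"):
    """Mean distance from each query to its k nearest training points (assumed semantics)."""
    nn = NearestNeighbors(n_neighbors=k, metric=metric).fit(train)
    distances, _ = nn.kneighbors(queries)
    return distances.mean(axis=1)

rng = np.random.default_rng(0)
train = rng.normal(size=(200, 16))
screening = rng.normal(loc=0.5, size=(50, 16))  # a deliberately shifted "virtual screening" set
print(knn_distance(train, screening)[:5])
```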
--------------------------------------------------------------------------------
/scripts/precompute_graphormer.py:
--------------------------------------------------------------------------------
1 | import typer
2 |
3 | import pandas as pd
4 | import datamol as dm
5 | from molfeat.trans.base import MoleculeTransformer
6 |
7 | from loguru import logger
8 | from functools import partial
9 | from typing import Optional, List
10 | from mood.preprocessing import DEFAULT_PREPROCESSING
11 | from mood.constants import DOWNSTREAM_APPS_DATA_DIR, DATASET_DATA_DIR, SUPPORTED_DOWNSTREAM_APPS
12 | from mood.dataset import dataset_iterator, MOOD_DATASETS
13 |
14 | STATE_DICT = {
15 | "_molfeat_version": "0.5.2",
16 | "args": {
17 | "max_length": 256,
18 | "name": "pcqm4mv1_graphormer_base",
19 | "pooling": "mean",
20 | "precompute_cache": False,
21 | "version": None,
22 | },
23 | "name": "GraphormerTransformer",
24 | }
25 |
26 |
27 | def cli(
28 | batch_size: int = 16,
29 | verbose: bool = False,
30 | overwrite: bool = False,
31 | skip: Optional[List[str]] = None,
32 | ):
33 | if skip is None:
34 | skip = []
35 |
36 | graphormer = MoleculeTransformer.from_state_dict(STATE_DICT)
37 | standardize_fn = partial(DEFAULT_PREPROCESSING["Graphormer"], disable_logs=True)
38 |
39 | # Precompute Graphormer for the downstream applications
40 | for molecule_set in [app for app in SUPPORTED_DOWNSTREAM_APPS if app not in skip]:
41 | in_path = dm.fs.join(DOWNSTREAM_APPS_DATA_DIR, f"{molecule_set}.csv")
42 | out_path = dm.fs.join(
43 |             DOWNSTREAM_APPS_DATA_DIR, "representations", molecule_set, "Graphormer.parquet"
44 | )
45 |
46 | if dm.fs.exists(out_path) and not overwrite:
47 | raise ValueError(f"{out_path} already exists! Use --overwrite to overwrite!")
48 |
49 | # Load
50 | logger.info(f"Loading SMILES from {in_path}")
51 | df = pd.read_csv(in_path)
52 |
53 | # Standardization
54 | df["smiles"] = dm.utils.parallelized(
55 | standardize_fn,
56 | df["canonical_smiles"].values,
57 | n_jobs=-1,
58 | progress=verbose,
59 | )
60 |
61 | # Setting max length. We don't ignore padding tokens, so best to do this per dataset
62 | graphormer.set_max_length(graphormer.compute_max_length(df["smiles"].values))
63 | logger.info(f"Computed a max number of nodes of {graphormer.max_length}")
64 |
65 | # Compute the representation
66 |         logger.info("Precomputing Graphormer representation")
67 | feats = graphormer.batch_transform(
68 | graphormer, df["smiles"].values, batch_size=batch_size, n_jobs=None
69 | )
70 |
71 | df["representation"] = list(feats)
72 | df = df[~pd.isna(df["representation"])]
73 |
74 | # Save
75 | logger.info(f"Saving results to {out_path}")
76 | df[["unique_id", "representation"]].to_parquet(out_path)
77 |
78 | blacklist = [app for app in MOOD_DATASETS if app in skip]
79 | for dataset, (smiles, _) in dataset_iterator(blacklist=blacklist, disable_logs=True):
80 | logger.info(f"Dataset {dataset}")
81 |         out_path = dm.fs.join(DATASET_DATA_DIR, "representations", dataset, "Graphormer.parquet")
82 |
83 | if dm.fs.exists(out_path) and not overwrite:
84 | raise ValueError(f"{out_path} already exists! Use --overwrite to overwrite!")
85 |
86 | df = pd.DataFrame()
87 | df["smiles"] = dm.utils.parallelized(
88 | standardize_fn,
89 | smiles,
90 | n_jobs=-1,
91 | progress=verbose,
92 | )
93 |
94 | # Setting max length. We don't ignore padding tokens, so best to do this per dataset
95 | graphormer.set_max_length(graphormer.compute_max_length(df["smiles"].values))
96 | logger.info(f"Computed a max number of nodes of {graphormer.max_length}")
97 |
98 | df["unique_id"] = dm.utils.parallelized(
99 | dm.unique_id,
100 | df["smiles"].values,
101 | n_jobs=-1,
102 | progress=verbose,
103 | )
104 |
105 | feats = graphormer.batch_transform(
106 | graphormer, df["smiles"].values, batch_size=batch_size, n_jobs=None
107 | )
108 | df["representation"] = list(feats)
109 |
110 | # Save
111 | logger.info(f"Saving results to {out_path}")
112 | df.to_parquet(out_path)
113 |
114 |
115 | if __name__ == "__main__":
116 | typer.run(cli)
117 |
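The parquet files written by this script are keyed by `unique_id`, so downstream scripts can join the precomputed embeddings back to molecules. A sketch of reading one back; the relative path is hypothetical, as the real files live under `DOWNSTREAM_APPS_DATA_DIR` or `DATASET_DATA_DIR`.

```python
import numpy as np
import pandas as pd

# Hypothetical local path; adjust to wherever the parquet files were written
df = pd.read_parquet("representations/virtual_screening/Graphormer.parquet")

# One row per molecule: a unique_id plus the pooled Graphormer embedding
X = np.stack(df["representation"].values)
ids = df["unique_id"].values
print(X.shape, ids[:3])
```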
--------------------------------------------------------------------------------
/scripts/precompute_representations.py:
--------------------------------------------------------------------------------
1 | import typer
2 |
3 | import pandas as pd
4 | import datamol as dm
5 |
6 | from loguru import logger
7 | from typing import Optional
8 | from mood.preprocessing import DEFAULT_PREPROCESSING
9 | from mood.representations import MOOD_REPRESENTATIONS, featurize
10 | from mood.constants import DOWNSTREAM_APPS_DATA_DIR, SUPPORTED_DOWNSTREAM_APPS
11 |
12 |
13 | def cli(
14 | molecule_set: str,
15 | representation: str,
16 | n_jobs: Optional[int] = None,
17 | batch_size: int = 16,
18 | verbose: bool = False,
19 | override: bool = False,
20 | ):
21 | if molecule_set not in SUPPORTED_DOWNSTREAM_APPS:
22 | raise typer.BadParameter(f"--molecule-set should be one of {SUPPORTED_DOWNSTREAM_APPS}.")
23 | if representation not in MOOD_REPRESENTATIONS:
24 | raise typer.BadParameter(f"--representation should be one of {MOOD_REPRESENTATIONS}.")
25 |
26 | in_path = dm.fs.join(DOWNSTREAM_APPS_DATA_DIR, f"{molecule_set}.csv")
27 | out_path = dm.fs.join(
28 | DOWNSTREAM_APPS_DATA_DIR, "representations", molecule_set, f"{representation}.parquet"
29 | )
30 |
31 | if dm.fs.exists(out_path) and not override:
32 | raise ValueError(f"{out_path} already exists! Use --override to override!")
33 |
34 | # Load
35 | logger.info(f"Loading SMILES from {in_path}")
36 | df = pd.read_csv(in_path)
37 |
38 | # Standardization fn
39 | standardize_fn = DEFAULT_PREPROCESSING[representation]
40 |
41 | # Compute the representation
42 | logger.info(f"Precomputing {representation} representation")
43 | df["representation"] = list(
44 | featurize(
45 | df["canonical_smiles"].values,
46 | representation,
47 | standardize_fn,
48 | n_jobs=n_jobs,
49 | progress=verbose,
50 | batch_size=batch_size,
51 | return_mask=False,
52 | disable_logs=True,
53 | )
54 | )
55 | df = df[~pd.isna(df["representation"])]
56 |
57 | # Save
58 | logger.info(f"Saving results to {out_path}")
59 | df[["unique_id", "representation"]].to_parquet(out_path)
60 |
--------------------------------------------------------------------------------
/scripts/visualize_perf_over_distance.py:
--------------------------------------------------------------------------------
1 | import datamol as dm
2 | import pandas as pd
3 |
4 | from datetime import datetime
5 | from typing import Optional
6 |
7 | from matplotlib import pyplot as plt
8 |
9 | from mood.constants import RESULTS_DIR
10 | from mood.visualize import plot_performance_over_distance
11 |
12 |
13 | def cli(
14 | baseline_algorithm: str,
15 | representation: str,
16 | dataset: str,
17 | base_save_dir: str = RESULTS_DIR,
18 | sub_save_dir: Optional[str] = None,
19 | ):
20 | if sub_save_dir is None:
21 | sub_save_dir = datetime.now().strftime("%Y%m%d")
22 | out_dir = dm.fs.join(base_save_dir, "dataframes", "compare_performance", sub_save_dir)
23 | dm.fs.mkdir(out_dir, exist_ok=True)
24 |
25 | file_name = f"perf_over_distance_{dataset}_{baseline_algorithm}_{representation}.csv"
26 | out_path = dm.fs.join(out_dir, file_name)
27 |
28 | df = pd.read_csv(out_path)
29 | df["score_lower"] = df["score_mu"] - df["score_std"]
30 | df["score_upper"] = df["score_mu"] + df["score_std"]
31 |
32 | plot_performance_over_distance(
33 | performance_data=df[df["type"] == "performance"],
34 | calibration_data=df[df["type"] == "calibration"],
35 | dataset_name=dataset,
36 | )
37 | plt.show()
38 |
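The `score_lower`/`score_upper` columns describe a mean ± one standard deviation band. A generic matplotlib rendering of that idea, not the actual `plot_performance_over_distance` implementation:

```python
import numpy as np
import matplotlib.pyplot as plt

distance = np.linspace(0.0, 1.0, 50)
score_mu = 0.9 - 0.5 * distance    # performance tends to degrade with distance
score_std = 0.05 + 0.1 * distance  # and the estimate gets noisier

fig, ax = plt.subplots()
ax.plot(distance, score_mu, label="mean score")
ax.fill_between(distance, score_mu - score_std, score_mu + score_std, alpha=0.3, label="± 1 std")
ax.set_xlabel("Distance to training set")
ax.set_ylabel("Score")
ax.legend()
plt.show()
```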
--------------------------------------------------------------------------------
/scripts/visualize_shift.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | from typing import Optional
4 | from loguru import logger
5 |
6 | from mood.dataset import load_data_from_tdc, MOOD_REGR_DATASETS
7 | from mood.representations import featurize
8 | from mood.preprocessing import DEFAULT_PREPROCESSING
9 | from mood.visualize import plot_distance_distributions
10 | from mood.distance import compute_knn_distance
11 | from mood.utils import load_representation_for_downstream_application
12 | from mood.train import train_baseline_model
13 | from mood.model_space import ModelSpaceTransformer
14 |
15 |
16 | def cli(
17 | dataset: str,
18 | representation: str,
19 | model_space: Optional[str] = None,
20 | ):
21 | smiles, y = load_data_from_tdc(dataset)
22 | standardize_fn = DEFAULT_PREPROCESSING[representation]
23 | X, mask = featurize(smiles, representation, standardize_fn, disable_logs=True)
24 | y = y[mask]
25 |
26 |     logger.info("Loading precomputed representations for virtual screening")
27 | virtual_screening = load_representation_for_downstream_application("virtual_screening", representation)
28 |
29 |     logger.info("Loading precomputed representations for optimization")
30 | optimization = load_representation_for_downstream_application("optimization", representation)
31 |
32 | if model_space is not None:
33 | logger.info(f"Computing distance in the {model_space} model space")
34 | is_regression = dataset in MOOD_REGR_DATASETS
35 | model = train_baseline_model(X, y, model_space, is_regression)
36 | embedding_size = int(round(X.shape[1] * 0.25))
37 | trans = ModelSpaceTransformer(model, embedding_size)
38 |
39 | X = trans(X)
40 | virtual_screening = trans(virtual_screening)
41 | optimization = trans(optimization)
42 |
43 | logger.info("Computing the k-NN distance")
44 | distances = compute_knn_distance(X, [X, optimization, virtual_screening], n_jobs=-1)
45 |
46 | logger.info("Plotting the results")
47 | labels = ["Train", "Optimization", "Virtual Screening"]
48 | ax = plot_distance_distributions(distances, labels=labels)
49 | plt.show()
50 |
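`ModelSpaceTransformer` re-embeds molecules with a trained model before distances are computed, using an embedding of roughly a quarter of the input width. As a conceptual stand-in (not the actual implementation), the same idea with a single-hidden-layer scikit-learn MLP whose hidden activations act as the model space:

```python
import numpy as np
from sklearn.neural_network import MLPRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 64))
y = X[:, 0] - X[:, 1]

embedding_size = int(round(X.shape[1] * 0.25))  # the 25% rule used above
model = MLPRegressor(hidden_layer_sizes=(embedding_size,), max_iter=1000, random_state=0).fit(X, y)

def hidden_embedding(model, X):
    """Forward pass up to the single hidden layer: relu(X @ W0 + b0)."""
    return np.maximum(X @ model.coefs_[0] + model.intercepts_[0], 0.0)

X_model_space = hidden_embedding(model, X)
print(X_model_space.shape)  # (300, 16)
```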
--------------------------------------------------------------------------------
/scripts/visualize_splits.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | from typing import Optional
5 | from loguru import logger
6 |
7 | from mood.distance import get_distance_metric
8 | from mood.representations import featurize
9 | from mood.dataset import load_data_from_tdc
10 | from mood.utils import load_distances_for_downstream_application
11 | from mood.splitter import MOODSplitter, get_mood_splitters
12 | from mood.preprocessing import DEFAULT_PREPROCESSING
13 |
14 |
15 | def cli(
16 | dataset: str,
17 | representation: str,
18 | n_splits: int = 5,
19 | use_cache: bool = True,
20 | seed: Optional[int] = None,
21 | ):
22 |     logger.info("Loading precomputed distances for virtual screening")
23 | distances_vs = load_distances_for_downstream_application(
24 | "virtual_screening", representation, dataset, update_cache=not use_cache
25 | )
26 |
27 |     logger.info("Loading precomputed distances for optimization")
28 | distances_op = load_distances_for_downstream_application(
29 | "optimization", representation, dataset, update_cache=not use_cache
30 | )
31 |
32 | smiles, y = load_data_from_tdc(dataset)
33 | standardize_fn = DEFAULT_PREPROCESSING[representation]
34 | X, mask = featurize(smiles, representation, standardize_fn, disable_logs=True)
35 | y = y[mask]
36 |
37 | metric = get_distance_metric(X)
38 | if metric == "jaccard":
39 | X = X.astype(bool)
40 |
41 | splitters = get_mood_splitters(smiles[mask], n_splits, seed)
42 | splitter = MOODSplitter(splitters, np.concatenate((distances_vs, distances_op)), metric, k=5)
43 | ax = splitter.fit(X, y, plot=True, progress=False)
44 | plt.show()
45 |
--------------------------------------------------------------------------------