├── mixmil
│   ├── paths.py
│   ├── __init__.py
│   ├── data.py
│   ├── posterior.py
│   ├── simulation.py
│   ├── utils.py
│   └── model.py
├── tests
│   ├── test_simulation.py
│   └── test_model.py
├── pyproject.toml
├── scripts
│   └── dsmil_data_download.py
├── README.md
├── .gitignore
├── experiments
│   ├── simulation.ipynb
│   ├── histopathology_camelyon16.ipynb
│   └── reduced_histopathology_camelyon16.ipynb
└── LICENSE

--------------------------------------------------------------------------------
/mixmil/paths.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | ROOT = Path(__file__).parent.parent
4 | 
5 | DATA = ROOT / "data"
6 | 

--------------------------------------------------------------------------------
/mixmil/__init__.py:
--------------------------------------------------------------------------------
1 | from mixmil.model import MixMIL
2 | 
3 | __all__ = ["MixMIL", "utils", "model", "posterior", "data", "simulation", "paths"]
4 | __version__ = "0.1.2"
5 | 

--------------------------------------------------------------------------------
/tests/test_simulation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats as st
3 | from sklearn.metrics import roc_auc_score
4 | 
5 | from mixmil import MixMIL
6 | from mixmil.data import load_data
7 | 
8 | 
9 | def calc_metrics(model, X, u, w):
10 |     u_pred = model.predict(X["test"]).cpu().numpy().ravel()
11 |     w_pred = model.get_weights(X["test"])[0].cpu().numpy().ravel()
12 |     rho_bag = st.spearmanr(u_pred, u["test"]).correlation  # bag level correlation
13 |     is_top_instance = (w["test"] > np.quantile(w["test"], 0.90)).long().ravel()
14 |     auc_instance = roc_auc_score(is_top_instance, w_pred)  # instance-retrieval AUC
15 |     return rho_bag, auc_instance
16 | 
17 | 
18 | def test_simulation():
19 |     X, F, Y, u, w = load_data(P=1, seed=0)
20 |     model = MixMIL.init_with_mean_model(X["train"], F["train"], Y["train"], likelihood="binomial", n_trials=2)
21 | 
22 |     start_rho_bag, start_auc_instance = calc_metrics(model, X, u, w)
23 |     model.train(X["train"], F["train"], Y["train"], n_epochs=40)
24 |     end_rho_bag, end_auc_instance = calc_metrics(model, X, u, w)
25 | 
26 |     # assert that both metrics improved by at least 10%
27 |     assert end_rho_bag > start_rho_bag * 1.1
28 |     assert end_auc_instance > start_auc_instance * 1.1
29 | 

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "mixmil"
3 | dynamic = ["version"]
4 | description = "Attention-based Multi-instance Mixed Models"
5 | readme = "README.md"
6 | license.file = "LICENSE"
7 | authors = [
8 |     { name = "Jan Engelmann", email = "jan.engelmann@helmholtz-munich.de" },
9 |     { name = "Alessandro Palma", email = "alessandro.palma@helmholtz-munich.de" },
10 |     { name = "Paolo Casale", email = "francescopaolo.casale@helmholtz-munich.de" },
11 | ]
12 | dependencies = [
13 |     "numpy>=1.22.0",
14 |     "torch>=1.4.0",
15 |     "torch_scatter>=2.0.1",
16 |     "scipy>=1.8.0",
17 |     "scikit-learn>=1.3.0",
18 |     "tqdm>=4.0.0",
19 |     "statsmodels>=0.11.0",
20 | ]
21 | requires-python = ">=3.9"
22 | classifiers = [
23 |     "Development Status :: 3 - Alpha",
24 |     "License :: OSI Approved :: Apache Software License",
25 |     "Programming Language :: Python :: 3 :: Only",
26 |     "Programming Language :: Python :: 3.9",
27 |     "Programming Language :: Python :: 3.10",
28 |     "Programming Language :: Python :: 3.11",
29 |     "Programming Language :: Python :: 3.12",
30 |     "Topic :: 
Scientific/Engineering :: Bio-Informatics", 31 | "Intended Audience :: Developers", 32 | "Intended Audience :: Science/Research", 33 | "Natural Language :: English", 34 | "Operating System :: MacOS :: MacOS X", 35 | "Operating System :: Microsoft :: Windows", 36 | "Operating System :: POSIX :: Linux", 37 | ] 38 | 39 | [project.urls] 40 | Homepage = "https://github.com/AIH-SGML/mixmil" 41 | "Bug Tracker" = "https://github.com/AIH-SGML/mixmil/issues" 42 | Discussions = "https://github.com/AIH-SGML/mixmil/discussions" 43 | 44 | [project.optional-dependencies] 45 | experiments = ["anndata>=0.8.0", "jupyterlab>=3.0.0"] 46 | test = ["pytest>=6.0.0"] 47 | all = ["mixmil[experiments,test]"] 48 | 49 | [build-system] 50 | requires = ["hatchling"] 51 | build-backend = "hatchling.build" 52 | 53 | [tool.hatch.build.targets.wheel] 54 | packages = ["mixmil"] 55 | 56 | [tool.hatch] 57 | version.path = "mixmil/__init__.py" 58 | 59 | [tool.black] 60 | line-length = 120 61 | 62 | [tool.ruff] 63 | line-length = 120 64 | ignore = ["E741"] 65 | -------------------------------------------------------------------------------- /mixmil/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | from mixmil.simulation import load_simulation 6 | 7 | 8 | def setup_scatter(Xs): 9 | device = Xs[0].device 10 | x = torch.cat(Xs, dim=0) 11 | i = torch.cat([torch.full((x.shape[0],), idx) for idx, x in enumerate(Xs)]).to(device) 12 | i_ptr = torch.cat([torch.tensor([0], device=device), i.bincount().cumsum(0)]) 13 | return x, i, i_ptr 14 | 15 | 16 | def xgower_factor(X): 17 | a = np.power(X, 2).sum() 18 | b = X.dot(X.sum(0)).sum() 19 | return np.sqrt((a - b / X.shape[0]) / (X.shape[0] - 1)) 20 | 21 | 22 | class MILDataset(Dataset): 23 | def __init__(self, Xs, F, Y): 24 | self.Xs = Xs 25 | self.F = F 26 | self.Y = Y 27 | 28 | def __len__(self): 29 | return len(self.Xs) 30 | 31 | def __getitem__(self, idx): 32 | X = self.Xs[idx] 33 | F = self.F[idx] 34 | Y = self.Y[idx] 35 | return X, F, Y 36 | 37 | 38 | def mil_collate_fn(batch): 39 | X = [item[0] for item in batch] 40 | F = torch.stack([item[1] for item in batch]) 41 | Y = torch.stack([item[2] for item in batch]) 42 | return X, F, Y 43 | 44 | 45 | def normalize_feats(X, norm_factor="std_sqrt"): 46 | assert norm_factor in ["std", "std_sqrt"] 47 | train_data = ( 48 | torch.cat(X["train"], dim=0) if isinstance(X["train"], list) else X["train"].reshape(-1, X["train"].shape[2]) 49 | ) 50 | mean = train_data.mean(0, keepdims=True) 51 | std = train_data.std(0, keepdims=True) 52 | factor = std * np.sqrt(train_data.shape[1]) if norm_factor == "std_sqrt" else std 53 | 54 | for key in X: 55 | X[key] = [(x - mean) / factor for x in X[key]] if isinstance(X[key], list) else (X[key] - mean) / factor 56 | 57 | return X 58 | 59 | 60 | def load_data(dataset="simulation", norm_factor="std_sqrt", **kwargs): 61 | if dataset == "simulation": 62 | X, F, Y, u, w = load_simulation(**kwargs) 63 | X = normalize_feats(X, norm_factor) 64 | else: 65 | raise ValueError(f"Unknown dataset: {dataset}") 66 | return X, F, Y, u, w 67 | -------------------------------------------------------------------------------- /mixmil/posterior.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.distributions import LowRankMultivariateNormal 4 | 5 | 6 | def get_params(vc_mean, vc_sd, n_outs, n_vars, mean_field): 7 | mu_z 
= torch.Tensor(vc_mean)
8 |     mu_u = torch.zeros_like(mu_z)
9 |     mu = torch.cat([mu_u, mu_z], 1)
10 |     sd_z = torch.Tensor(vc_sd)
11 |     sd_u = torch.sqrt(0.1 * torch.ones_like(sd_z))
12 | 
13 |     if mean_field:
14 |         cov_factor = torch.zeros(n_outs, n_vars, 1)
15 |         cov_logdiag = 2.0 * torch.log(torch.cat([sd_u, sd_z], 1))
16 |     else:
17 |         diag = torch.diag_embed(torch.cat([sd_u, sd_z], 1))
18 |         cov_factor = 1e-4 * torch.randn(diag.shape) + diag
19 |         cov_logdiag = np.log(1e-4) * torch.ones(n_outs, n_vars)
20 | 
21 |     return mu, cov_factor, cov_logdiag
22 | 
23 | 
24 | class GaussianVariationalPosterior(torch.nn.Module):
25 |     def __init__(self, n_vars, n_outs, mean_field=True, init_params=None):
26 |         super().__init__()
27 |         self.n_vars = n_vars
28 |         self.n_outs = n_outs
29 |         self.mean_field = mean_field
30 | 
31 |         if init_params is not None:
32 |             mu_z, sd_z, *_ = init_params
33 |             mu_z = mu_z.T
34 |             sd_z = sd_z.T
35 |         else:
36 |             mu_z = 1e-3 * torch.randn(n_outs, n_vars // 2)
37 |             sd_z = 1e-3 * torch.randn(n_outs, n_vars // 2).abs()  # sd must be positive: its log is taken in get_params
38 | 
39 |         mu, cov_factor, cov_logdiag = get_params(mu_z, sd_z, n_outs, n_vars, mean_field)
40 | 
41 |         self.mu = torch.nn.Parameter(mu)
42 |         if mean_field:
43 |             self.register_buffer("cov_factor", cov_factor)
44 |             self.cov_logdiag = torch.nn.Parameter(cov_logdiag)
45 |         else:
46 |             self.cov_factor = torch.nn.Parameter(cov_factor)
47 |             self.register_buffer("cov_logdiag", cov_logdiag)
48 | 
49 |     @property
50 |     def distribution(self):
51 |         return LowRankMultivariateNormal(self.mu, self.cov_factor, torch.exp(self.cov_logdiag))
52 | 
53 |     @property
54 |     def q_mu(self):
55 |         return self.mu.T
56 | 
57 |     def sample(self, n_samples):
58 |         return self.distribution.rsample([n_samples]).permute([2, 1, 0])
59 | 
60 |     def extra_repr(self) -> str:
61 |         return f"n_vars=2*{self.n_vars//2}, n_outs={self.n_outs}, mean_field={self.mean_field}"
62 | 

--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | 
5 | from mixmil import MixMIL
6 | 
7 | 
8 | @pytest.fixture
9 | def mock_data_binomial():
10 |     Xs = [torch.randn(10, 3) for _ in range(5)]  # List of tensors
11 |     F = torch.randn(5, 4)  # Fixed effects
12 |     Y = torch.randint(0, 2, (5, 1))  # Labels for binomial
13 |     return Xs, F, Y
14 | 
15 | 
16 | @pytest.fixture
17 | def mock_data_categorical():
18 |     N, Q, K = 50, 10, 4
19 |     bag_sizes = torch.randint(5, 15, (N,))
20 |     Xs = [torch.randn(bag_sizes[n], Q) for n in range(N)]  # List of tensors
21 |     F = torch.randn(N, K)  # Fixed effects
22 |     Y = torch.randint(0, 5, (N, 1))  # Labels for categorical
23 |     return Xs, F, Y
24 | 
25 | 
26 | def test_init_with_mean_model_binomial(mock_data_binomial):
27 |     Xs, F, Y = mock_data_binomial
28 |     model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood="binomial", n_trials=2)
29 |     model.train(Xs, F, Y, n_epochs=3)
30 |     assert isinstance(model, MixMIL)
31 |     assert model.likelihood_name == "binomial"
32 |     assert model.n_trials == 2
33 |     assert model.log_sigma_u.numel() == 1
34 | 
35 | 
36 | def test_init_with_mean_model_categorical(mock_data_categorical):
37 |     Xs, F, Y = mock_data_categorical
38 |     model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood="categorical")
39 |     model.train(Xs, F, Y, n_epochs=3)
40 |     assert isinstance(model, MixMIL)
41 |     assert model.likelihood_name == "categorical"
42 |     assert model.n_trials is None
43 |     assert model.log_sigma_u.numel() == len(np.unique(Y))  # separate prior for 
each class
44 | 
45 | 
46 | def test_initialization():
47 |     model = MixMIL(Q=10, K=5, P=2, likelihood="binomial", n_trials=2)
48 |     assert model.Q == 10
49 |     assert model.alpha.shape == (5, 2)
50 | 
51 | 
52 | @pytest.mark.parametrize(
53 |     "Q, K, P, likelihood, n_trials, mean_field",
54 |     [
55 |         (10, 5, 2, "categorical", None, True),
56 |         (10, 5, 2, "categorical", None, False),
57 |         (10, 5, 2, "binomial", 1, True),
58 |         (10, 5, 2, "binomial", 2, False),
59 |         (10, 5, 1, "binomial", 2, False),
60 |     ],
61 | )
62 | def test_init_model(Q, K, P, likelihood, n_trials, mean_field):
63 |     MixMIL(Q=Q, K=K, P=P, likelihood=likelihood, n_trials=n_trials, mean_field=mean_field)
64 | 

--------------------------------------------------------------------------------
/scripts/dsmil_data_download.py:
--------------------------------------------------------------------------------
1 | """
2 | Script adapted from
3 | https://github.com/binli123/dsmil-wsi/blob/master/download.py
4 | 
5 | Script and data provided by the DSMIL authors:
6 | 
7 | @inproceedings{li2021dual,
8 |   title={Dual-stream multiple instance learning network for whole slide image classification with self-supervised contrastive learning},
9 |   author={Li, Bin and Li, Yin and Eliceiri, Kevin W},
10 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
11 |   pages={14318--14328},
12 |   year={2021}
13 | }
14 | """
15 | 
16 | import argparse
17 | import os
18 | import urllib.request
19 | import zipfile
20 | 
21 | from tqdm import tqdm
22 | 
23 | 
24 | class DownloadProgressBar(tqdm):
25 |     def update_to(self, b=1, bsize=1, tsize=None):
26 |         if tsize is not None:
27 |             self.total = tsize
28 |         self.update(b * bsize - self.n)
29 | 
30 | 
31 | def download_url(url, output_path):
32 |     with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]) as t:
33 |         urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
34 | 
35 | 
36 | def unzip_data(zip_path, data_path):
37 |     with zipfile.ZipFile(zip_path, "r") as zip_ref:
38 |         zip_ref.extractall(data_path)
39 | 
40 | 
41 | DATASET_URLS = {
42 |     "camelyon16": "https://uwmadison.box.com/shared/static/l9ou15iwup73ivdjq0bc61wcg5ae8dwe.zip",
43 | }
44 | 
45 | 
46 | def main():
47 |     parser = argparse.ArgumentParser()
48 |     parser.add_argument("--dataset", type=str, default="camelyon16", help="Dataset to be downloaded: camelyon16")
49 |     parser.add_argument("--keep-zip", action="store_true", help="Keep the downloaded zip file")
50 |     args = parser.parse_args()
51 | 
52 |     assert args.dataset in DATASET_URLS, f"Dataset {args.dataset} not found"
53 | 
54 |     print(f"downloading dataset: {args.dataset}")
55 |     unzip_dir = f"data/{args.dataset}"
56 |     zip_file_path = f"data/{args.dataset}-dataset.zip"
57 |     os.makedirs(unzip_dir, exist_ok=True)
58 |     download_url(DATASET_URLS[args.dataset], zip_file_path)
59 |     unzip_data(zip_file_path, unzip_dir)
60 | 
61 |     if not args.keep_zip:
62 |         os.remove(zip_file_path)  # remove the archive under data/, matching where it was saved
63 | 
64 |     print("done!")
65 | 
66 | 
67 | if __name__ == "__main__":
68 |     main()
69 | 

--------------------------------------------------------------------------------
/mixmil/simulation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.model_selection import train_test_split
4 | 
5 | 
6 | def get_X(N=1_000, I=50, Q=30, N_test=200):
7 |     N = N + N_test
8 | 
9 |     X = torch.randn([N * I, Q], dtype=torch.float32)
10 | 
11 |     X = (X - X.mean(0)) / X.std(0)
12 |     X = X / 
np.sqrt(X.shape[1])
13 |     X = X.reshape([N, I, Q])
14 | 
15 |     return X
16 | 
17 | 
18 | def simulate(X, v_beta=0.5, v_gamma=0.8, b=-1, F=None, P=1):
19 |     if F is None:
20 |         F = torch.ones([X.shape[0], 1])
21 | 
22 |     # broadcast the intercept and variance parameters across the P phenotypes
23 |     K = F.shape[1]
24 |     b = b * torch.ones([K, P])
25 |     v_beta = v_beta * torch.ones(P)
26 |     v_gamma = v_gamma * torch.ones(P)
27 | 
28 |     # sample weights
29 |     gamma = torch.randn((X.shape[2], v_gamma.shape[0]))
30 |     _w = torch.einsum("nik,kp->nip", X, gamma)
31 |     _scale_w = torch.sqrt(v_gamma / _w.var([0, 1]))
32 |     gamma = _scale_w[None, :] * gamma
33 |     _w = _scale_w[None, None, :] * _w
34 | 
35 |     w = torch.softmax(_w, dim=1)
36 | 
37 |     # sample z
38 |     beta = torch.randn((X.shape[2], v_beta.shape[0]))
39 |     beta = beta / torch.sqrt((beta**2).mean(0, keepdim=True))
40 |     z = torch.einsum("nik,kp->nip", X, beta)
41 |     u = torch.einsum("nip,nip->np", w, z)
42 |     u = (u - u.mean(0)) / u.std(0)
43 |     u = torch.sqrt(v_beta) * u
44 |     beta = torch.sqrt(v_beta) * beta
45 | 
46 |     # compute rates
47 |     logits = F.mm(b) + u
48 | 
49 |     # sample Y
50 |     Y = torch.distributions.Binomial(2, logits=logits).sample()
51 | 
52 |     return F, Y, u, w
53 | 
54 | 
55 | def split_data(Xs, test_size=200, val_size=0.0, test_rs=127, val_rs=412):
56 |     idxs_all = np.arange(Xs[0].shape[0])
57 |     idxs = {}
58 |     idxs["train_val"], idxs["test"] = train_test_split(idxs_all, test_size=test_size, random_state=test_rs)
59 | 
60 |     if not np.isclose(val_size, 0):
61 |         idxs["train"], idxs["val"] = train_test_split(idxs["train_val"], test_size=val_size, random_state=val_rs)
62 |     else:
63 |         idxs["train"] = idxs["train_val"]
64 |     del idxs["train_val"]
65 | 
66 |     outs = []
67 |     for X in Xs:
68 |         out = {}
69 |         for key in idxs.keys():
70 |             out[key] = X[idxs[key]]
71 |         outs.append(out)
72 |     return outs
73 | 
74 | 
75 | def load_simulation(seed=42, N=1_000, I=50, Q=30, N_test=200, P=1, v_beta=0.5, v_gamma=0.8, b=-1, F=None):
76 |     np.random.seed(seed)
77 |     torch.manual_seed(seed)
78 | 
79 |     X = get_X(N, I, Q, N_test)
80 |     data = simulate(X, v_beta, v_gamma, b, F, P)
81 |     X, F, Y, u, w = split_data((X, *data))
82 |     return X, F, Y, u, w
83 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MixMIL
2 | Code for the paper: [Mixed Models with Multiple Instance Learning](https://arxiv.org/abs/2311.02455)
3 | 
4 | Accepted at AISTATS 24 as an oral presentation & [Outstanding Student Paper Highlight](https://aistats.org/aistats2024/awards.html).
5 | 
6 | Please raise an issue for questions and bug reports.
7 | ## Installation
8 | Install with:
9 | ```
10 | pip install mixmil
11 | ```
12 | Alternatively, if you want to include the optional experiment and test dependencies, use:
13 | ```
14 | pip install "mixmil[experiments,test]"
15 | ```
16 | or, if you want to adapt the code:
17 | ```
18 | git clone https://github.com/AIH-SGML/mixmil.git
19 | cd mixmil
20 | pip install -e ".[experiments,test]"
21 | ```
22 | To enable computations on the GPU, please follow the installation instructions of [PyTorch](https://pytorch.org/) and [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter).
23 | MixMIL works with, e.g., PyTorch 2.1.
24 | ## Experiments
25 | See the notebooks in the `experiments` folder for examples of how to run the simulation and histopathology experiments. 
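
For a quick start, here is a minimal end-to-end sketch on simulated data; it mirrors `tests/test_simulation.py`, so all functions and arguments shown are taken from that test:

```python
from mixmil import MixMIL
from mixmil.data import load_data

# Simulated bags: embeddings X, fixed effects F, labels Y,
# plus the ground-truth bag effects u and instance weights w.
X, F, Y, u, w = load_data(P=1, seed=0)

# Initialize from a GLMM mean model, then fit the variational posterior.
model = MixMIL.init_with_mean_model(X["train"], F["train"], Y["train"], likelihood="binomial", n_trials=2)
model.train(X["train"], F["train"], Y["train"], n_epochs=40)

u_pred = model.predict(X["test"])  # bag-level predictions
w_pred, _ = model.get_weights(X["test"])  # instance weights after the softmax
```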
26 | 27 | Make sure the `experiments` requirements are installed: 28 | ``` 29 | pip install "mixmil[experiments]" 30 | ``` 31 | ### Histopathology 32 | The histopathology experiment was performed on the [CAMELYON16](https://camelyon16.grand-challenge.org/) dataset. 33 | #### Download Data 34 | To download the embeddings provided by the DSMIL authors, either: 35 | - Full embeddings: `python scripts/dsmil_data_download.py` 36 | - PCA reduced embeddings: [Google Drive](https://drive.google.com/drive/folders/1X9ho1_W5ixyHSw_2hCfQsBb5nzkjMviA?usp=sharing) 37 | 38 | ### Microscopy 39 | The full BBBC021 dataset can be downloaded [here](https://bbbc.broadinstitute.org/BBBC021). 40 | #### Download Data 41 | - We make the featurized cells available at [BBBC021](https://drive.google.com/file/d/1OyH3zg22N107qrPVp3p-GLoFa1KzeoID/view?usp=sharing) 42 | - The features are stored as an [AnnData](https://anndata.readthedocs.io/en/latest/) object. We recommend using the [scanpy](https://scanpy.readthedocs.io/en/stable/) package to read and process them 43 | - The weights of the featurizer trained with the SimCLR algorithm can be downloaded from the original [GitHub repository](https://github.com/SamriddhiJain/SimCLR-for-cell-profiling?tab=readme-ov-file) 44 | 45 | ## Citation 46 | ``` 47 | @inproceedings{engelmann2024mixed, 48 | title={Mixed Models with Multiple Instance Learning}, 49 | author={Engelmann, Jan P. and Palma, Alessandro and Tomczak, Jakub M. and Theis, Fabian and Casale, Francesco Paolo}, 50 | booktitle={International Conference on Artificial Intelligence and Statistics}, 51 | pages={3664--3672}, 52 | year={2024}, 53 | organization={PMLR} 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /mixmil/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg as la 3 | import statsmodels.api as sm 4 | import torch 5 | from sklearn.linear_model import LogisticRegressionCV 6 | from tqdm.auto import trange 7 | 8 | from mixmil.data import xgower_factor 9 | 10 | 11 | def regressOut(Y, X, return_b=False, return_pinv=False): 12 | """ 13 | regresses out X from Y 14 | """ 15 | Xd = la.pinv(X) 16 | b = Xd.dot(Y) 17 | Y_out = Y - X.dot(b) 18 | out = [Y_out] 19 | if return_b: 20 | out.append(b) 21 | if return_pinv: 22 | out.append(Xd) 23 | return out if len(out) > 1 else out[0] 24 | 25 | 26 | def _get_single_binomial_init_params(X, F, y): 27 | ident = np.zeros((X.shape[1]), dtype=int) 28 | model = sm.BinomialBayesMixedGLM(y, F, X, ident).fit_vb() 29 | 30 | u = np.dot(X, model.vc_mean)[::2] 31 | 32 | _scale = u.std() / np.sqrt((model.vc_mean**2).mean(0)) 33 | mu_beta = _scale * model.vc_mean 34 | sd_beta = _scale * model.vc_sd 35 | var_z = (mu_beta**2 + sd_beta**2).mean().reshape(1) 36 | alpha = model.fe_mean 37 | 38 | return mu_beta, sd_beta, var_z, alpha 39 | 40 | 41 | def get_binomial_init_params(X, F, Y): 42 | results = [_get_single_binomial_init_params(X, F, Y[:, p]) for p in trange(Y.shape[1], desc="GLMM Init")] 43 | 44 | mu_beta, sd_beta, var_z, alpha = [_list2tensor(listo) for listo in zip(*results)] 45 | 46 | return mu_beta, sd_beta, var_z, alpha 47 | 48 | 49 | def get_lr_init_params(X, Y, b, Fiv): 50 | model = LogisticRegressionCV( 51 | Cs=10, 52 | fit_intercept=True, 53 | penalty="l2", 54 | multi_class="multinomial", 55 | solver="lbfgs", 56 | n_jobs=1, 57 | verbose=0, 58 | random_state=42, 59 | max_iter=1000, 60 | refit=True, 61 | ) 62 | 63 | model.fit(X, Y.ravel()) 64 | 
65 | alpha = model.intercept_[None] 66 | beta = model.coef_.T 67 | 68 | # Compute bag prediction u and reparametrize 69 | u = X.dot(beta) 70 | um = u.mean(0)[None] 71 | us = u.std(0)[None] 72 | alpha = alpha + um 73 | mu_beta = us * beta / np.sqrt((beta**2).mean(0)[None]) 74 | sd_beta = np.sqrt(0.1 * (mu_beta**2).mean()) * np.ones_like(mu_beta) 75 | 76 | alpha = Fiv.dot(np.ones((Fiv.shape[1], 1))).dot(alpha) - b.dot(mu_beta) 77 | 78 | # init prior 79 | var_z = (mu_beta**2 + sd_beta**2).mean(axis=0, keepdims=True) 80 | 81 | return [torch.Tensor(el) for el in (mu_beta, sd_beta, var_z, alpha)] 82 | 83 | 84 | def _list2tensor(_list): 85 | return torch.Tensor(np.stack(_list, axis=1)) 86 | 87 | 88 | def get_init_params(Xs, F, Y, likelihood, n_trials): 89 | Xm = np.concatenate([x.mean(0, keepdims=True) for x in Xs], axis=0) 90 | Fe, Ye = F.numpy(), Y.long().numpy() 91 | 92 | if likelihood == "binomial": 93 | Xm = (Xm - Xm.mean(0, keepdims=True)) / xgower_factor(Xm) 94 | 95 | if n_trials == 2: 96 | Xm, Fe = Xm.repeat(2, axis=0), Fe.repeat(2, axis=0) 97 | to_expanded = np.array(([[0, 0], [1, 0], [1, 1]])) 98 | Ye = to_expanded[Y.long().numpy().T].transpose(1, 2, 0).reshape(-1, Y.shape[1]) 99 | 100 | mu_z, sd_z, var_z, alpha = get_binomial_init_params(Xm, Fe, Ye) 101 | 102 | elif likelihood == "categorical": 103 | Xm, b, Fiv = regressOut(Xm, Fe, return_b=True, return_pinv=True) 104 | Xm = (Xm - Xm.mean(0, keepdims=True)) / (Xm.std(0, keepdims=True) * np.sqrt(Xm.shape[-1])) 105 | 106 | mu_z, sd_z, var_z, alpha = get_lr_init_params(Xm, Ye, b, Fiv) 107 | 108 | return mu_z, sd_z, var_z, alpha 109 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 | 
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
162 | .DS_Store
163 | data
164 | *.out
165 | assets
166 | 

--------------------------------------------------------------------------------
/experiments/simulation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Simulation Experiment"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "markdown",
12 |    "metadata": {},
13 |    "source": [
14 |     "> In this notebook we demonstrate how to train the MixMIL model on data simulated as specified in the paper, in the binomial likelihood setting."
15 |    ]
16 |   },
17 |   {
18 |    "cell_type": "code",
19 |    "execution_count": 2,
20 |    "metadata": {},
21 |    "outputs": [
22 |     {
23 |      "name": "stderr",
24 |      "output_type": "stream",
25 |      "text": [
26 |       "/home/icb/alessandro.palma/miniconda3/envs/sslbio-env/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
27 |       "  from .autonotebook import tqdm as notebook_tqdm\n"
28 |      ]
29 |     }
30 |    ],
31 |    "source": [
32 |     "import numpy as np\n",
33 |     "import scipy.stats as st\n",
34 |     "import torch\n",
35 |     "from sklearn.metrics import roc_auc_score\n",
36 |     "\n",
37 |     "from mixmil import MixMIL\n",
38 |     "from mixmil.data import load_data\n",
39 |     "import pandas as pd"
40 |    ]
41 |   },
42 |   {
43 |    "cell_type": "markdown",
44 |    "metadata": {},
45 |    "source": [
46 |     "## Utility Functions"
47 |    ]
48 |   },
49 |   {
50 |    "cell_type": "code",
51 |    "execution_count": 3,
52 |    "metadata": {},
53 |    "outputs": [],
54 |    "source": [
55 |     "import numpy as np\n",
56 |     "from sklearn.metrics import roc_auc_score\n",
57 |     "import scipy.stats as st\n",
58 |     "\n",
59 |     "def _calc_metrics(u_pred, w_pred, u, w):\n",
60 |     "    \"\"\"\n",
61 |     "    Calculate correlation and AUC metrics using real and predicted bag effects and instance weights.\n",
62 |     "\n",
63 |     "    Parameters:\n",
64 |     "    - u_pred (numpy.ndarray): Predicted bag-level effects.\n",
65 |     "    - w_pred (numpy.ndarray): Predicted instance-level weights as instance proportions.\n",
66 |     "    - u (numpy.ndarray): True bag-level effects.\n",
67 |     "    - w (numpy.ndarray): True instance-level weights as instance proportions.\n",
68 |     "\n",
69 |     "    Returns:\n",
70 |     "    - rho_bag (float): Bag-level effect correlation (Spearman's rank correlation coefficient).\n",
71 |     "    - auc_instance (float): Instance retrieval AUC (Area Under the Receiver Operating Characteristic curve).\n",
72 |     "    \"\"\"\n",
73 |     "    rho_bag = st.spearmanr(u_pred, u).correlation  # bag level correlation\n",
74 |     "    is_top_instance = (w > np.quantile(w, 0.90)).long().ravel()\n",
75 |     "    auc_instance = roc_auc_score(is_top_instance, w_pred)  # instance-retrieval AUC\n",
76 |     "    return rho_bag, auc_instance\n",
77 |     "\n",
78 |     "\n",
79 |     "def calc_metrics(model, X, u, w):\n",
80 |     "    \"\"\"\n",
81 |     "    Calculate aggregated metrics over multiple bags or instances for a given model.\n",
82 |     "\n",
83 |     "    Parameters:\n",
84 |     "    - model: Trained model.\n",
85 |     "    - X (dict): Dictionary containing input data.\n",
86 |     "    - u (dict): Dictionary containing true bag-level effects.\n",
87 |     "    - w (dict): Dictionary containing true values for instance-level weights as proportions.\n",
88 |     "\n",
89 |     "    Returns:\n",
90 |     "    - res_dict (dict): Dictionary containing aggregated metrics including:\n",
91 |     "        - 'rho_bag' (float): Mean bag-level effect correlation.\n",
92 |     "        - 'rho_bag_err' (float): Standard error of the mean for bag-level correlation.\n",
93 |     "        - 'auc_instance' (float): Mean instance retrieval AUC.\n",
94 |     "        - 'auc_instance_err' (float): Standard error of the mean for instance retrieval AUC.\n",
95 |     "    \"\"\"\n",
96 |     "    u_pred = model.predict(X[\"test\"]).cpu().numpy()\n",
97 |     "    w_pred = model.get_weights(X[\"test\"])[0].cpu().numpy()\n",
98 |     "\n",
99 |     "    P = u_pred.shape[1]\n",
100 |     "    if P > 0:\n",
101 |     "        rho_bag, auc_instance = [], []\n",
102 |     "        for i in range(P):\n",
103 |     "            _rho_bag, _auc_instance = _calc_metrics(\n",
104 |     "                u_pred[..., i], w_pred[..., i].ravel(), u[\"test\"][..., i], w[\"test\"][..., i].ravel()\n",
105 |     "            )\n",
106 |     "            rho_bag.append(_rho_bag)\n",
107 |     "            auc_instance.append(_auc_instance)\n",
108 |     "\n",
109 |     "        res_dict = {\n",
110 |     "            \"rho_bag\": np.mean(rho_bag),\n",
111 |     "            \"rho_bag_err\": np.std(rho_bag) / np.sqrt(P),\n",
112 |     "            \"auc_instance\": np.mean(auc_instance),\n",
113 |     "            \"auc_instance_err\": np.std(auc_instance) / 
np.sqrt(P),\n",
114 |     "        }\n",
115 |     "    else:\n",
116 |     "        rho_bag, auc_instance = _calc_metrics(u_pred, w_pred.ravel(), u[\"test\"], w[\"test\"].ravel())\n",
117 |     "        res_dict = {\"rho_bag\": rho_bag, \"auc_instance\": auc_instance}\n",
118 |     "    return res_dict\n",
119 |     "\n",
120 |     "def print_metrics(prefix, metrics):\n",
121 |     "    \"\"\"\n",
122 |     "    Print a formatted representation of metrics with a specified prefix.\n",
123 |     "\n",
124 |     "    Parameters:\n",
125 |     "    - prefix (str): Prefix to be added to the printed metrics, for better identification.\n",
126 |     "    - metrics (dict): Dictionary containing metrics to be printed.\n",
127 |     "\n",
128 |     "    Returns:\n",
129 |     "    - None: This function prints the metrics to the console without returning any value.\n",
130 |     "    \"\"\"\n",
131 |     "    print(f\"{prefix} metrics:\")\n",
132 |     "    for k, v in metrics.items():\n",
133 |     "        print(f\"{k}: {v:.4f}\")\n",
134 |     "    print()"
135 |    ]
136 |   },
137 |   {
138 |    "cell_type": "markdown",
139 |    "metadata": {},
140 |    "source": [
141 |     "## Training"
142 |    ]
143 |   },
144 |   {
145 |    "cell_type": "markdown",
146 |    "metadata": {},
147 |    "source": [
148 |     "Train the model on simulated data using a binomial likelihood."
149 |    ]
150 |   },
151 |   {
152 |    "cell_type": "code",
153 |    "execution_count": 5,
154 |    "metadata": {},
155 |    "outputs": [
156 |     {
157 |      "name": "stderr",
158 |      "output_type": "stream",
159 |      "text": [
160 |       "GLMM Init: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]\n"
161 |      ]
162 |     },
163 |     {
164 |      "name": "stdout",
165 |      "output_type": "stream",
166 |      "text": [
167 |       "[START] metrics:\n",
168 |       "rho_bag: 0.6004\n",
169 |       "rho_bag_err: 0.0131\n",
170 |       "auc_instance: 0.5000\n",
171 |       "auc_instance_err: 0.0000\n",
172 |       "\n"
173 |      ]
174 |     },
175 |     {
176 |      "name": "stderr",
177 |      "output_type": "stream",
178 |      "text": [
179 |       "Epoch: 100%|██████████| 2000/2000 [04:30<00:00,  7.39it/s]"
180 |      ]
181 |     },
182 |     {
183 |      "name": "stdout",
184 |      "output_type": "stream",
185 |      "text": [
186 |       "[END] metrics:\n",
187 |       "rho_bag: 0.8316\n",
188 |       "rho_bag_err: 0.0111\n",
189 |       "auc_instance: 0.9362\n",
190 |       "auc_instance_err: 0.0078\n",
191 |       "\n"
192 |      ]
193 |     },
194 |     {
195 |      "name": "stderr",
196 |      "output_type": "stream",
197 |      "text": [
198 |       "\n"
199 |      ]
200 |     }
201 |    ],
202 |    "source": [
203 |     "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
204 |     "\n",
205 |     "# Simulate data as described in the paper\n",
206 |     "# embeddings, fixed effects, labels, sim bag predictions, sim instance weights\n",
207 |     "# P: number of outputs, simulated from the same embeddings X\n",
208 |     "X, F, Y, u, w = load_data(P=10, seed=0)\n",
209 |     "model = MixMIL.init_with_mean_model(X[\"train\"], F[\"train\"], Y[\"train\"], likelihood=\"binomial\", n_trials=2).to(device)\n",
210 |     "X, F, Y = [{key: val.to(device) for key, val in el.items()} for el in [X, F, Y]]\n",
211 |     "\n",
212 |     "print_metrics(\"[START]\", calc_metrics(model, X, u, w))\n",
213 |     "# Fit model in parallel to each output separately\n",
214 |     "model.train(X[\"train\"], F[\"train\"], Y[\"train\"], n_epochs=2_000)\n",
215 |     "print_metrics(\"[END]\", calc_metrics(model, X, u, w))"
216 |    ]
217 |   }
218 |  ],
219 |  "metadata": {
220 |   "kernelspec": {
221 |    "display_name": "Python 3 (ipykernel)",
222 |    "language": "python",
223 |    "name": "python3"
224 |   },
225 |   "language_info": {
226 |    "codemirror_mode": {
227 |     "name": "ipython",
228 |     "version": 3
229 |    },
230 |    "file_extension": ".py",
231 |    "mimetype": "text/x-python",
232 |    "name": "python",
233 |    
"nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.9.13" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 2 240 | } 241 | -------------------------------------------------------------------------------- /mixmil/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.distributions import Binomial, Categorical, LowRankMultivariateNormal 4 | from torch.distributions.kl import kl_divergence 5 | from torch.utils.data import DataLoader 6 | from torch_scatter import scatter_softmax, segment_add_csr 7 | from tqdm.auto import trange 8 | 9 | from mixmil.data import MILDataset, mil_collate_fn, setup_scatter 10 | from mixmil.posterior import GaussianVariationalPosterior 11 | from mixmil.utils import get_init_params 12 | 13 | 14 | class MixMIL(torch.nn.Module): 15 | """Attention-based Multi-instance Mixed Models 16 | 17 | https://arxiv.org/abs/2311.02455 18 | 19 | """ 20 | 21 | def __init__(self, Q, K, P=1, likelihood="binomial", n_trials=2, mean_field=False, init_params=None): 22 | r"""Initialize the MixMil class. 23 | 24 | Parameters: 25 | - Q (int): The dimension of the latent space. 26 | - K (int): The number of fixed effects. 27 | - P (int): The number of outputs. 28 | - likelihood (str, optional): The likelihood to use. Either "binomial" or "categorical". Default is "binomial". 29 | - n_trials (int, optional): Number of trials for binomial likelihood. Not used for categorical. Default is 2. 30 | - mean_field (bool, optional): Toggle mean field approximation for the posterior. Default is False. 31 | - init_params (tuple, optional): Tuple of (mean, var, var_z, alpha) to initialize the model. Default is None. 32 | mean (torch.Tensor): The mean of the posterior. Shape: (Q, P). d 33 | var (torch.Tensor): The variance of the posterior. Shape: (Q, P). 34 | var_z (torch.Tensor): The $\sigma_{\beta}^2$ hparam of the prior. 35 | Shape: (1, P) with separate and (1, 1) with shared priors . 36 | alpha (torch.Tensor): The fixed effect parameters. Shape: (K, P). 
37 | """ 38 | super().__init__() 39 | self.Q = Q 40 | 41 | alpha = torch.zeros((K, P)) 42 | log_sigma_u = torch.full((1, P), 0.5 * np.log(0.5)) 43 | log_sigma_z = torch.full((1, P), 0.5 * np.log(0.5)) 44 | 45 | if init_params is not None: 46 | *_, var_z, alpha = init_params 47 | log_sigma_z = 0.5 * torch.log(var_z) 48 | 49 | self.alpha = torch.nn.Parameter(alpha) 50 | self.log_sigma_u = torch.nn.Parameter(log_sigma_u) 51 | self.log_sigma_z = torch.nn.Parameter(log_sigma_z) 52 | 53 | self.posterior = GaussianVariationalPosterior(2 * Q, P, mean_field, init_params) 54 | 55 | self.likelihood_name = likelihood 56 | self.n_trials = n_trials if likelihood == "binomial" else None 57 | self.is_trained = False 58 | 59 | def init_with_mean_model(Xs, F, Y, likelihood="binomial", n_trials=None, mean_field=False): 60 | assert (likelihood == "binomial" and n_trials is not None and 0 < n_trials <= 2) or ( 61 | likelihood == "categorical" and n_trials is None 62 | ), f"n_trials must be 1 or 2 to initialize with binomial mean model, got {n_trials=} and {likelihood=}" 63 | init_params = get_init_params(Xs, F, Y, likelihood, n_trials) 64 | Q, K, P = Xs[0].shape[1], F.shape[1], init_params[0].shape[1] 65 | return MixMIL(Q, K, P, likelihood, n_trials, mean_field, init_params) 66 | 67 | @property 68 | def prior_distribution(self): 69 | device = self.log_sigma_u.device 70 | scale_u = self.log_sigma_u.T * torch.ones([1, self.Q], device=device) 71 | scale_z = self.log_sigma_z.T * torch.ones([1, self.Q], device=device) 72 | cov_logdiag = 2 * torch.cat([scale_u, scale_z], 1) 73 | cov_factor = torch.zeros_like(cov_logdiag)[:, :, None] 74 | mu = torch.zeros_like(cov_logdiag) 75 | return LowRankMultivariateNormal(mu, cov_factor, torch.exp(cov_logdiag)) 76 | 77 | @property 78 | def posterior_distribution(self): 79 | return self.posterior.distribution 80 | 81 | @property 82 | def qu_mu(self): 83 | return self.posterior.q_mu[: self.Q] 84 | 85 | @property 86 | def qz_mu(self): 87 | return self.posterior.q_mu[self.Q :] 88 | 89 | def likelihood(self, logits, y): 90 | if self.likelihood_name == "binomial": 91 | return Binomial(total_count=self.n_trials, logits=logits).log_prob(y[:, :, None]).sum(1).mean() 92 | elif self.likelihood_name == "categorical": 93 | logits = logits.permute(0, 2, 1) 94 | if logits.shape[-1] == 1: 95 | logits = torch.cat([-logits, logits], 2) 96 | return Categorical(logits=logits).log_prob(y).mean() 97 | 98 | def loss(self, u, f, y, kld_w=1.0, return_dict=False): 99 | logits = f.mm(self.alpha)[:, :, None] + u 100 | 101 | ll = self.likelihood(logits, y) 102 | kld = kl_divergence(self.posterior_distribution, self.prior_distribution) 103 | kld_term = kld_w * kld.sum() / y.shape[0] 104 | loss = -ll + kld_term 105 | if return_dict: 106 | return loss, dict(loss=loss.item(), ll=ll.item(), kld=kld_term.item()) 107 | return loss 108 | 109 | def get_betas(self, n_samples=None, predict=False): 110 | assert not (n_samples and predict) 111 | if n_samples: 112 | beta = self.posterior.sample(n_samples) 113 | beta_u = beta[: self.Q, :, :] 114 | beta_z = beta[self.Q :, :, :] 115 | else: 116 | beta_u = self.qu_mu[:, :, None] 117 | beta_z = self.qz_mu[:, :, None] 118 | return beta_u, beta_z 119 | 120 | def forward(self, Xs, n_samples=8, scaling=None, predict=False): 121 | beta_u, beta_z = self.get_betas(n_samples, predict) 122 | b = torch.sqrt((beta_z**2).mean(0, keepdim=True)) 123 | eta = beta_z / b 124 | 125 | if torch.is_tensor(Xs): 126 | u = self._calc_bag_emb_effect_tensor(beta_u, eta, Xs) 127 | else: 128 | u = 
self._calc_bag_emb_effect_scatter(beta_u, eta, Xs) 129 | 130 | mean, std = (u.mean(0), u.std(0)) if scaling is None else scaling 131 | if std.isnan().any(): 132 | std = 1 133 | u = b * (u - mean) / std 134 | return u 135 | 136 | def _calc_bag_emb_effect_tensor(self, beta_u, eta, Xs): 137 | _w = torch.einsum("niq,qps->nips", Xs, beta_u) 138 | w = torch.softmax(_w, dim=1) 139 | t = torch.einsum("niq,qps->nips", Xs, eta) 140 | u = torch.einsum("nips,nips->nps", w, t) 141 | return u 142 | 143 | def _calc_bag_emb_effect_scatter(self, beta_u, eta, Xs): 144 | x, i, i_ptr = setup_scatter(Xs) 145 | 146 | _w = torch.einsum("iq,qps->ips", x, beta_u) 147 | w = scatter_softmax(_w, i, dim=0) 148 | t = torch.einsum("iq,qps->ips", x, eta) 149 | u = segment_add_csr(w * t, i_ptr) 150 | return u 151 | 152 | def train(self, X, F, Y, n_epochs=2_000, batch_size=64, lr=1e-3, verbose=True): 153 | train_loader = DataLoader( 154 | MILDataset(X, F, Y), 155 | shuffle=True, 156 | batch_size=batch_size, 157 | collate_fn=None if torch.is_tensor(X) else mil_collate_fn, 158 | ) 159 | optim = torch.optim.Adam(lr=lr, params=self.parameters()) 160 | 161 | history = [] 162 | for epoch in trange(1, n_epochs + 1, desc="Epoch", disable=not verbose): 163 | for step, (xs, f, y) in enumerate(train_loader): 164 | u = self(xs) 165 | loss, ldict = self.loss(u, f, y, kld_w=len(xs) / len(Y), return_dict=True) 166 | ldict["epoch"], ldict["step"] = epoch, step 167 | history.append(ldict) 168 | optim.zero_grad() 169 | loss.backward() 170 | optim.step() 171 | 172 | self.is_trained = True 173 | return history 174 | 175 | @torch.inference_mode() 176 | def predict(self, Xs, scaling=None): 177 | return self(Xs, n_samples=None, predict=True, scaling=scaling).squeeze(2) 178 | 179 | @torch.inference_mode() 180 | def get_weights(self, Xs, ravel=False): 181 | """Get instance weights after and before softmax""" 182 | beta_u, _ = self.get_betas(predict=True) 183 | beta_u = beta_u.squeeze(2) # not taking mcmc samples 184 | if torch.is_tensor(Xs): 185 | _w = torch.einsum("niq,qp->nip", Xs, beta_u) 186 | w = torch.softmax(_w, dim=1) 187 | 188 | else: 189 | x, i, _ = setup_scatter(Xs) 190 | _w = torch.einsum("iq,qp->ip", x, beta_u) 191 | w = scatter_softmax(_w, i, dim=0) 192 | 193 | if ravel: 194 | w, _w = w.ravel(), _w.ravel() 195 | elif not torch.is_tensor(Xs): 196 | _w = [_w[i == idx] for idx in range(len(Xs))] 197 | w = [w[i == idx] for idx in range(len(Xs))] 198 | return w, _w 199 | 200 | def extra_repr(self): 201 | string = f"Q={self.Q}, K={self.alpha.shape[0]}, P={self.alpha.shape[1]}, likelihood={self.likelihood_name}" 202 | if self.likelihood_name == "binomial": 203 | string += f", n_trials={self.n_trials}" 204 | string += f", device={self.alpha.device}, trained={self.is_trained}" 205 | string += f"\n(alpha): Parameter(shape={tuple(self.alpha.shape)})\n" 206 | string += f"(log_sigma_u): Parameter(shape={tuple(self.log_sigma_u.shape)})\n" 207 | string += f"(log_sigma_z): Parameter(shape={tuple(self.log_sigma_z.shape)})" 208 | return string 209 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 | 
195 |    http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 

--------------------------------------------------------------------------------
/experiments/histopathology_camelyon16.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Histopathology Experiment"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "code",
12 |    "execution_count": 10,
13 |    "metadata": {},
14 |    "outputs": [],
15 |    "source": [
16 |     "from mixmil.paths import DATA\n",
17 |     "import pandas as pd\n",
18 |     "from tqdm import tqdm\n",
19 |     "from sklearn.preprocessing import StandardScaler\n",
20 |     "import numpy as np\n",
21 |     "import anndata as ad\n",
22 |     "from mixmil import MixMIL\n",
23 |     "import torch\n",
24 |     "from sklearn.metrics import roc_auc_score\n",
25 |     "import matplotlib.pyplot as plt\n",
26 |     "import scipy.stats as st"
27 |    ]
28 |   },
29 |   {
30 |    "cell_type": "markdown",
31 |    "metadata": {},
32 |    "source": [
33 |     "In this notebook we show how to apply MixMIL to the Camelyon dataset to classify histopathological slides as containing tumor tissue or not."
34 |    ]
35 |   },
36 |   {
37 |    "cell_type": "markdown",
38 |    "metadata": {},
39 |    "source": [
40 |     "## Utility Functions"
41 |    ]
42 |   },
43 |   {
44 |    "cell_type": "code",
45 |    "execution_count": 2,
46 |    "metadata": {},
47 |    "outputs": [],
48 |    "source": [
49 |     "def to_device(el, device):\n",
50 |     "    \"\"\"\n",
51 |     "    Move a nested structure of elements (dict, list, tuple, torch.Tensor, torch.nn.Module) to the specified device.\n",
52 |     "\n",
53 |     "    Parameters:\n",
54 |     "    - el: Element or nested structure of elements to be moved to the device.\n",
55 |     "    - device (torch.device): The target device, such as 'cuda' for GPU or 'cpu' for CPU.\n",
56 |     "\n",
57 |     "    Returns:\n",
58 |     "    - Transferred element(s) in the same structure: Elements moved to the specified device.\n",
59 |     "    \"\"\"\n",
60 |     "    if isinstance(el, dict):\n",
61 |     "        return {k: to_device(v, device) for k, v in el.items()}\n",
62 |     "    elif isinstance(el, (list, tuple)):\n",
63 |     "        return [to_device(x, device) for x in el]\n",
64 |     "    elif isinstance(el, (torch.Tensor, torch.nn.Module)):\n",
65 |     "        return el.to(device)\n",
66 |     "    else:\n",
67 |     "        return el"
68 |    ]
69 |   },
70 |   {
71 |    "cell_type": "markdown",
72 |    "metadata": {},
73 |    "source": [
74 |     "## Data Loading"
75 |    ]
76 |   },
77 |   {
78 |    "cell_type": "markdown",
79 |    "metadata": {},
80 |    "source": [
81 |     "We start by reading the Camelyon dataset. The bag features are stored in multiple `.csv` files. The association of the files with their labels, and whether they are train or test samples, is stored in the `Camelyon16.csv` file."
82 |    ]
83 |   },
84 |   {
85 |    "cell_type": "code",
86 |    "execution_count": 3,
87 |    "metadata": {},
88 |    "outputs": [
89 |     {
90 |      "data": {
91 |       "text/html": [
92 |        "
\n", 93 | "\n", 106 | "\n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | "
labelsplit
file
1-tumor/test_033.csv1test
0-normal/normal_148.csv0train
0-normal/test_095.csv0test
0-normal/normal_025.csv0train
0-normal/test_087.csv0test
.........
0-normal/normal_006.csv0train
1-tumor/tumor_003.csv1train
0-normal/normal_018.csv0train
1-tumor/tumor_017.csv1train
0-normal/test_067.csv0test
\n", 177 | "

399 rows × 2 columns

\n", 178 | "
" 179 | ], 180 | "text/plain": [ 181 | " label split\n", 182 | "file \n", 183 | "1-tumor/test_033.csv 1 test\n", 184 | "0-normal/normal_148.csv 0 train\n", 185 | "0-normal/test_095.csv 0 test\n", 186 | "0-normal/normal_025.csv 0 train\n", 187 | "0-normal/test_087.csv 0 test\n", 188 | "... ... ...\n", 189 | "0-normal/normal_006.csv 0 train\n", 190 | "1-tumor/tumor_003.csv 1 train\n", 191 | "0-normal/normal_018.csv 0 train\n", 192 | "1-tumor/tumor_017.csv 1 train\n", 193 | "0-normal/test_067.csv 0 test\n", 194 | "\n", 195 | "[399 rows x 2 columns]" 196 | ] 197 | }, 198 | "execution_count": 3, 199 | "metadata": {}, 200 | "output_type": "execute_result" 201 | } 202 | ], 203 | "source": [ 204 | "dataset_index_file = DATA / \"camelyon16\" / \"Camelyon16.csv\"\n", 205 | "bagdf = pd.read_csv(dataset_index_file)\n", 206 | "bagdf.columns = [\"file\", \"label\"]\n", 207 | "bagdf[\"file\"] = bagdf[\"file\"].str.replace(\"datasets/Camelyon16/\", \"\")\n", 208 | "bagdf[\"split\"] = bagdf[\"file\"].apply(lambda x: \"test\" if \"test\" in x else \"train\")\n", 209 | "bagdf = bagdf.set_index(\"file\")\n", 210 | "bagdf" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "Compute and save `anndata` files containing train and test sets collected from the true " 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 4, 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "name": "stderr", 227 | "output_type": "stream", 228 | "text": [ 229 | "100%|██████████| 270/270 [01:58<00:00, 2.27it/s]\n", 230 | "100%|██████████| 129/129 [00:59<00:00, 2.18it/s]\n", 231 | "/home/icb/alessandro.palma/miniconda3/envs/sslbio-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n", 232 | " warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "dtype = \"float32\"\n", 238 | "\n", 239 | "if not (DATA / \"camelyon16\" / \"full_test.h5ad\").exists() or not (DATA / \"full_camelyon16\" / \"full_train.h5ad\").exists():\n", 240 | " train_data = []\n", 241 | " train_bag_indices = []\n", 242 | " for _, row in tqdm(list(bagdf[bagdf[\"split\"] == \"train\"].iterrows())):\n", 243 | " train_data.append(pd.read_csv(dataset_index_file.parent / row.name).values.astype(dtype))\n", 244 | " train_bag_indices.extend([row.name] * len(train_data[-1]))\n", 245 | "\n", 246 | " test_data = []\n", 247 | " test_bag_indices = []\n", 248 | " for _, row in tqdm(list(bagdf[bagdf[\"split\"] == \"test\"].iterrows())):\n", 249 | " test_data.append(pd.read_csv(dataset_index_file.parent / row.name).values.astype(dtype))\n", 250 | " test_bag_indices.extend([row.name] * len(test_data[-1]))\n", 251 | "\n", 252 | " i_train = np.array([idx for idx, x in enumerate(train_data) for _ in range(len(x))])\n", 253 | " X_train = np.concatenate(train_data, 0).astype(dtype)\n", 254 | " X_test = np.concatenate(test_data, 0).astype(dtype)\n", 255 | " scaler = StandardScaler()\n", 256 | " X_train = scaler.fit_transform(X_train)\n", 257 | " X_test = scaler.transform(X_test)\n", 258 | "\n", 259 | " train_obs = pd.DataFrame(\n", 260 | " {\"bag\": train_bag_indices, \"label\": bagdf.loc[train_bag_indices][\"label\"].values, \"split\": \"train\"}\n", 261 | " )\n", 262 | " train_adata = ad.AnnData(X_train, obs=train_obs)\n", 263 | " test_obs = pd.DataFrame(\n", 264 | " {\"bag\": test_bag_indices, \"label\": bagdf.loc[test_bag_indices][\"label\"].values, \"split\": \"test\"}\n", 
265 | " )\n", 266 | " test_adata = ad.AnnData(X_test, obs=test_obs)\n", 267 | "\n", 268 | " test_adata.write(DATA / \"camelyon16\" / \"full_test.h5ad\")\n", 269 | " train_adata.write(DATA / \"camelyon16\" / \"full_train.h5ad\")\n", 270 | "else:\n", 271 | " print(\"Loading precomputed anndatas\")\n", 272 | " train_adata = ad.read_h5ad(DATA / \"camelyon16\" / \"full_train.h5ad\")\n", 273 | " test_adata = ad.read_h5ad(DATA / \"camelyon16\" / \"full_test.h5ad\")" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "Initialize bags of observations as tensors. " 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 5, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "# prepare train data\n", 290 | "train_bags = train_adata.obs[\"bag\"].unique().tolist()\n", 291 | "Xs = [torch.Tensor(train_adata[train_adata.obs[\"bag\"] == bag].X) for bag in train_bags]\n", 292 | "F = torch.ones((len(train_bags), 1))\n", 293 | "Y = torch.Tensor(train_adata.obs[[\"bag\", \"label\"]].drop_duplicates().set_index(\"bag\").loc[train_bags].values)" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 6, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "# prepare test data, following official train-test split\n", 303 | "test_bags = test_adata.obs[\"bag\"].unique().tolist()\n", 304 | "test_Xs = [torch.Tensor(test_adata[test_adata.obs[\"bag\"] == bag].X) for bag in test_bags]\n", 305 | "test_Y = torch.Tensor(test_adata.obs[[\"bag\", \"label\"]].drop_duplicates().set_index(\"bag\").loc[test_bags].values)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "metadata": {}, 311 | "source": [ 312 | "## Training" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "Now, we initialize MixMIL with a simple GLMM and use it for prediction." 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 7, 325 | "metadata": {}, 326 | "outputs": [ 327 | { 328 | "name": "stderr", 329 | "output_type": "stream", 330 | "text": [ 331 | "GLMM Init: 100%|██████████| 1/1 [00:01<00:00, 1.78s/it]\n" 332 | ] 333 | }, 334 | { 335 | "name": "stdout", 336 | "output_type": "stream", 337 | "text": [ 338 | "Test AUC: 0.698 Spearman: 0.333\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "# initialize model with mean model and Bernoulli likelihood\n", 344 | "model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood=\"binomial\", n_trials=1)\n", 345 | "y_pred_mean = model.predict(test_Xs)\n", 346 | "print(\n", 347 | " \"Test AUC:\",\n", 348 | " round(roc_auc_score(test_Y, y_pred_mean), 3),\n", 349 | " \"Spearman:\",\n", 350 | " round(st.spearmanr(test_Y, y_pred_mean).correlation, 3),\n", 351 | ")" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "Finally, we train MixMIL starting from the GLMM initialization." 
359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 16, 364 | "metadata": {}, 365 | "outputs": [ 366 | { 367 | "name": "stdout", 368 | "output_type": "stream", 369 | "text": [ 370 | "Test AUC: 0.977 Spearman: 0.802\n" 371 | ] 372 | } 373 | ], 374 | "source": [ 375 | "# train model for 1000 epochs\n", 376 | "device = \"cuda:0\"\n", 377 | "model, Xs, F, Y, test_Xs, test_Y = to_device((model, Xs, F, Y, test_Xs, test_Y), device)\n", 378 | "model.train(Xs, F, Y, n_epochs=1000)\n", 379 | "model.to(\"cpu\")\n", 380 | "test_Xs = [x.cpu() for x in test_Xs]\n", 381 | "y_pred = model.predict(test_Xs).cpu().numpy()\n", 382 | "y_true = test_Y.cpu().numpy()\n", 383 | "print(\n", 384 | " \"Test AUC:\",\n", 385 | " round(roc_auc_score(y_true, y_pred), 3),\n", 386 | " \"Spearman:\",\n", 387 | " round(st.spearmanr(y_true, y_pred).correlation, 3),\n", 388 | ")" 389 | ] 390 | } 391 | ], 392 | "metadata": { 393 | "kernelspec": { 394 | "display_name": "Python 3 (ipykernel)", 395 | "language": "python", 396 | "name": "python3" 397 | }, 398 | "language_info": { 399 | "codemirror_mode": { 400 | "name": "ipython", 401 | "version": 3 402 | }, 403 | "file_extension": ".py", 404 | "mimetype": "text/x-python", 405 | "name": "python", 406 | "nbconvert_exporter": "python", 407 | "pygments_lexer": "ipython3", 408 | "version": "3.9.13" 409 | } 410 | }, 411 | "nbformat": 4, 412 | "nbformat_minor": 2 413 | } 414 | -------------------------------------------------------------------------------- /experiments/reduced_histopathology_camelyon16.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Histopathology Experiment (Reduced Dimension)\n", 8 | "> In this version we run PCA to half the data dimensionality to enable in-memory training" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 2, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "from mixmil.paths import DATA\n", 18 | "import pandas as pd\n", 19 | "from tqdm import tqdm\n", 20 | "from sklearn.preprocessing import StandardScaler\n", 21 | "from sklearn.decomposition import PCA\n", 22 | "import numpy as np\n", 23 | "import anndata as ad\n", 24 | "from mixmil import MixMIL\n", 25 | "import torch\n", 26 | "from sklearn.metrics import roc_auc_score\n", 27 | "import scipy.stats as st" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## Utility function" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "def to_device(el, device):\n", 44 | " \"\"\"\n", 45 | " Move a nested structure of elements (dict, list, tuple, torch.Tensor, torch.nn.Module) to the specified device.\n", 46 | "\n", 47 | " Parameters:\n", 48 | " - el: Element or nested structure of elements to be moved to the device.\n", 49 | " - device (torch.device): The target device, such as 'cuda' for GPU or 'cpu' for CPU.\n", 50 | "\n", 51 | " Returns:\n", 52 | " - Transferred element(s) in the same structure: Elements moved to the specified device.\n", 53 | " \"\"\"\n", 54 | " if isinstance(el, dict):\n", 55 | " return {k: to_device(v, device) for k, v in el.items()}\n", 56 | " elif isinstance(el, (list, tuple)):\n", 57 | " return [to_device(x, device) for x in el]\n", 58 | " elif isinstance(el, (torch.Tensor, torch.nn.Module)):\n", 59 | " return el.to(device)\n", 60 | " else:\n", 61 | " 
return el" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Data Loading" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "We start from reading the Camelyon dataset. The bag features are stored in multiple `.csv` files. The association of the files, their labels and whether they are train or test samples are stored in the `Camelyon16.csv` file." 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/html": [ 86 | "
\n", 87 | "\n", 100 | "\n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | "
labelsplit
file
1-tumor/test_033.csv1test
0-normal/normal_148.csv0train
0-normal/test_095.csv0test
0-normal/normal_025.csv0train
0-normal/test_087.csv0test
.........
0-normal/normal_006.csv0train
1-tumor/tumor_003.csv1train
0-normal/normal_018.csv0train
1-tumor/tumor_017.csv1train
0-normal/test_067.csv0test
\n", 171 | "

399 rows × 2 columns

\n", 172 | "
" 173 | ], 174 | "text/plain": [ 175 | " label split\n", 176 | "file \n", 177 | "1-tumor/test_033.csv 1 test\n", 178 | "0-normal/normal_148.csv 0 train\n", 179 | "0-normal/test_095.csv 0 test\n", 180 | "0-normal/normal_025.csv 0 train\n", 181 | "0-normal/test_087.csv 0 test\n", 182 | "... ... ...\n", 183 | "0-normal/normal_006.csv 0 train\n", 184 | "1-tumor/tumor_003.csv 1 train\n", 185 | "0-normal/normal_018.csv 0 train\n", 186 | "1-tumor/tumor_017.csv 1 train\n", 187 | "0-normal/test_067.csv 0 test\n", 188 | "\n", 189 | "[399 rows x 2 columns]" 190 | ] 191 | }, 192 | "execution_count": 3, 193 | "metadata": {}, 194 | "output_type": "execute_result" 195 | } 196 | ], 197 | "source": [ 198 | "dataset_index_file = DATA / \"camelyon16\" / \"Camelyon16.csv\"\n", 199 | "bagdf = pd.read_csv(dataset_index_file)\n", 200 | "bagdf.columns = [\"file\", \"label\"]\n", 201 | "bagdf[\"file\"] = bagdf[\"file\"].str.replace(\"datasets/Camelyon16/\", \"\")\n", 202 | "bagdf[\"split\"] = bagdf[\"file\"].apply(lambda x: \"test\" if \"test\" in x else \"train\")\n", 203 | "bagdf = bagdf.set_index(\"file\")\n", 204 | "bagdf" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "Compute and save anndata files containing train and test sets collected from the true. We collect and apply the model to 128 PCA features, scaled using the training set statistics and saved in memory." 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 6, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "Loading precomputed anndatas\n" 224 | ] 225 | } 226 | ], 227 | "source": [ 228 | "dtype = \"float32\"\n", 229 | "n_pcs = 128\n", 230 | "\n", 231 | "if not (DATA / \"camelyon16\" / \"test.h5ad\").exists() or not (DATA / \"camelyon16\" / \"train.h5ad\").exists():\n", 232 | " train_data = []\n", 233 | " train_bag_indices = []\n", 234 | " for _, row in tqdm(list(bagdf[bagdf[\"split\"] == \"train\"].iterrows())):\n", 235 | " train_data.append(pd.read_csv(dataset_index_file.parent / row.name).values.astype(dtype))\n", 236 | " train_bag_indices.extend([row.name] * len(train_data[-1]))\n", 237 | "\n", 238 | " test_data = []\n", 239 | " test_bag_indices = []\n", 240 | " for _, row in tqdm(list(bagdf[bagdf[\"split\"] == \"test\"].iterrows())):\n", 241 | " test_data.append(pd.read_csv(dataset_index_file.parent / row.name).values.astype(dtype))\n", 242 | " test_bag_indices.extend([row.name] * len(test_data[-1]))\n", 243 | "\n", 244 | " i_train = np.array([idx for idx, x in enumerate(train_data) for _ in range(len(x))])\n", 245 | " X_train = np.concatenate(train_data, 0).astype(dtype)\n", 246 | " X_test = np.concatenate(test_data, 0).astype(dtype)\n", 247 | " pca = PCA(n_components=n_pcs)\n", 248 | " scaler = StandardScaler()\n", 249 | " X_train = scaler.fit_transform(pca.fit_transform(X_train))\n", 250 | " X_test = scaler.transform(pca.transform(X_test))\n", 251 | "\n", 252 | " train_obs = pd.DataFrame(\n", 253 | " {\"bag\": train_bag_indices, \"label\": bagdf.loc[train_bag_indices][\"label\"].values, \"split\": \"train\"}\n", 254 | " )\n", 255 | " train_adata = ad.AnnData(X_train, obs=train_obs, var=pd.DataFrame(index=[f\"PC{i}\" for i in range(n_pcs)]))\n", 256 | " test_obs = pd.DataFrame(\n", 257 | " {\"bag\": test_bag_indices, \"label\": bagdf.loc[test_bag_indices][\"label\"].values, \"split\": \"test\"}\n", 258 | " )\n", 259 | " test_adata = ad.AnnData(X_test, obs=test_obs, 
var=pd.DataFrame(index=[f\"PC{i}\" for i in range(n_pcs)]))\n", 260 | "\n", 261 | " test_adata.write(DATA / \"camelyon16\" / \"test.h5ad\")\n", 262 | " train_adata.write(DATA / \"camelyon16\" / \"train.h5ad\")\n", 263 | "else:\n", 264 | " print(\"Loading precomputed anndatas\")\n", 265 | " train_adata = ad.read_h5ad(DATA / \"camelyon16\" / \"train.h5ad\")\n", 266 | " test_adata = ad.read_h5ad(DATA / \"camelyon16\" / \"test.h5ad\")" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 8, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "# prepare train data\n", 276 | "train_bags = train_adata.obs[\"bag\"].unique().tolist()\n", 277 | "Xs = [torch.Tensor(train_adata[train_adata.obs[\"bag\"] == bag].X) for bag in train_bags]\n", 278 | "F = torch.ones((len(train_bags), 1))\n", 279 | "Y = torch.Tensor(train_adata.obs[[\"bag\", \"label\"]].drop_duplicates().set_index(\"bag\").loc[train_bags].values)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 9, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "# prepare test data, following official train-test split\n", 289 | "test_bags = test_adata.obs[\"bag\"].unique().tolist()\n", 290 | "test_Xs = [torch.Tensor(test_adata[test_adata.obs[\"bag\"] == bag].X) for bag in test_bags]\n", 291 | "test_Y = torch.Tensor(test_adata.obs[[\"bag\", \"label\"]].drop_duplicates().set_index(\"bag\").loc[test_bags].values)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": {}, 297 | "source": [ 298 | "## Training" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": {}, 304 | "source": [ 305 | "Initialize MixMIL with a simple GLMM and use it for prediction." 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 12, 311 | "metadata": {}, 312 | "outputs": [ 313 | { 314 | "data": { 315 | "application/vnd.jupyter.widget-view+json": { 316 | "model_id": "66eb150f90564e659021f5bcb6cafdd3", 317 | "version_major": 2, 318 | "version_minor": 0 319 | }, 320 | "text/plain": [ 321 | "GLMM Init: 0%| | 0/1 [00:00
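Both Camelyon16 notebooks (and the simulation test) follow the same four-step MixMIL workflow: initialize from a GLMM mean model, score the initialization, train the attention-based variational model, and score again. The sketch below condenses that call sequence into one runnable snippet; the random bags are stand-in data invented for illustration, and n_epochs=100 is an arbitrary choice, while the calls themselves (MixMIL.init_with_mean_model, train, predict, get_weights) are the ones used in the notebooks above.

import torch

from mixmil import MixMIL

# Stand-in data: 20 bags, each holding a variable number of 8-dimensional instance embeddings.
torch.manual_seed(0)
Xs = [torch.randn(int(torch.randint(5, 15, (1,))), 8) for _ in range(20)]
F = torch.ones((len(Xs), 1))                   # intercept-only fixed effects, as in the notebooks
Y = torch.randint(0, 2, (len(Xs), 1)).float()  # binary bag labels

# 1) Initialize from the GLMM mean model (Bernoulli = binomial with a single trial).
model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood="binomial", n_trials=1)

# 2) Train the attention-based variational model on the bags.
model.train(Xs, F, Y, n_epochs=100)

# 3) Bag-level predictions and instance-level importance weights.
y_pred = model.predict(Xs)      # one score per bag
w = model.get_weights(Xs)[0]    # per-instance attention weights

With random labels the fit itself is meaningless; the snippet only documents the call sequence and the tensor shapes (a list of [n_instances, n_features] bags, an [n_bags, 1] fixed-effects matrix, and an [n_bags, 1] label tensor) that the notebooks assemble from the Camelyon16 data.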