├── mixmil
│   ├── paths.py
│   ├── __init__.py
│   ├── data.py
│   ├── posterior.py
│   ├── simulation.py
│   ├── utils.py
│   └── model.py
├── tests
│   ├── test_simulation.py
│   └── test_model.py
├── pyproject.toml
├── scripts
│   └── dsmil_data_download.py
├── README.md
├── .gitignore
├── experiments
│   ├── simulation.ipynb
│   ├── histopathology_camelyon16.ipynb
│   └── reduced_histopathology_camelyon16.ipynb
└── LICENSE

/mixmil/paths.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | ROOT = Path(__file__).parent.parent
4 | 
5 | DATA = ROOT / "data"
6 | 
--------------------------------------------------------------------------------
/mixmil/__init__.py:
--------------------------------------------------------------------------------
1 | from mixmil.model import MixMIL
2 | 
3 | __all__ = ["MixMIL", "utils", "likelihood", "posterior", "data", "simulation", "paths"]
4 | __version__ = "0.1.2"
5 | 
--------------------------------------------------------------------------------
/tests/test_simulation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats as st
3 | from sklearn.metrics import roc_auc_score
4 | 
5 | from mixmil import MixMIL
6 | from mixmil.data import load_data
7 | 
8 | 
9 | def calc_metrics(model, X, u, w):
10 |     u_pred = model.predict(X["test"]).cpu().numpy().ravel()
11 |     w_pred = model.get_weights(X["test"])[0].cpu().numpy().ravel()
12 |     rho_bag = st.spearmanr(u_pred, u["test"]).correlation  # bag level correlation
13 |     is_top_instance = (w["test"] > np.quantile(w["test"], 0.90)).long().ravel()
14 |     auc_instance = roc_auc_score(is_top_instance, w_pred)  # instance-retrieval AUC
15 |     return rho_bag, auc_instance
16 | 
17 | 
18 | def test_simulation():
19 |     X, F, Y, u, w = load_data(P=1, seed=0)
20 |     model = MixMIL.init_with_mean_model(X["train"], F["train"], Y["train"], likelihood="binomial", n_trials=2)
21 | 
22 |     start_rho_bag, start_auc_instance = calc_metrics(model, X, u, w)
23 |     model.train(X["train"], F["train"], Y["train"], n_epochs=40)
24 |     end_rho_bag, end_auc_instance = calc_metrics(model, X, u, w)
25 | 
26 |     # assert that both metrics improved by at least 10%
27 |     assert end_rho_bag > start_rho_bag * 1.1
28 |     assert end_auc_instance > start_auc_instance * 1.1
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "mixmil"
3 | dynamic = ["version"]
4 | description = "Attention-based Multi-instance Mixed Models"
5 | readme = "README.md"
6 | license.file = "LICENSE"
7 | authors = [
8 |     { name = "Jan Engelmann", email = "jan.engelmann@helmholtz-munich.de" },
9 |     { name = "Alessandro Palma", email = "alessandro.palma@helmholtz-munich.de" },
10 |     { name = "Paolo Casale", email = "francescopaolo.casale@helmholtz-munich.de" },
11 | ]
12 | dependencies = [
13 |     "numpy>=1.22.0",
14 |     "torch>=1.4.0",
15 |     "torch_scatter>=2.0.1",
16 |     "scipy>=1.8.0",
17 |     "scikit-learn>=1.3.0",
18 |     "tqdm>=4.0.0",
19 |     "statsmodels>=0.11.0",
20 | ]
21 | requires-python = ">=3.9"
22 | classifiers = [
23 |     "Development Status :: 3 - Alpha",
24 |     "License :: OSI Approved :: Apache Software License",
25 |     "Programming Language :: Python :: 3 :: Only",
26 |     "Programming Language :: Python :: 3.9",
27 |     "Programming Language :: Python :: 3.10",
28 |     "Programming Language :: Python :: 3.11",
29 |     "Programming Language :: Python :: 3.12",
30 |     "Topic :: 
Scientific/Engineering :: Bio-Informatics", 31 | "Intended Audience :: Developers", 32 | "Intended Audience :: Science/Research", 33 | "Natural Language :: English", 34 | "Operating System :: MacOS :: MacOS X", 35 | "Operating System :: Microsoft :: Windows", 36 | "Operating System :: POSIX :: Linux", 37 | ] 38 | 39 | [project.urls] 40 | Homepage = "https://github.com/AIH-SGML/mixmil" 41 | "Bug Tracker" = "https://github.com/AIH-SGML/mixmil/issues" 42 | Discussions = "https://github.com/AIH-SGML/mixmil/discussions" 43 | 44 | [project.optional-dependencies] 45 | experiments = ["anndata>=0.8.0", "jupyterlab>=3.0.0"] 46 | test = ["pytest>=6.0.0"] 47 | all = ["mixmil[experiments,test]"] 48 | 49 | [build-system] 50 | requires = ["hatchling"] 51 | build-backend = "hatchling.build" 52 | 53 | [tool.hatch.build.targets.wheel] 54 | packages = ["mixmil"] 55 | 56 | [tool.hatch] 57 | version.path = "mixmil/__init__.py" 58 | 59 | [tool.black] 60 | line-length = 120 61 | 62 | [tool.ruff] 63 | line-length = 120 64 | ignore = ["E741"] 65 | -------------------------------------------------------------------------------- /mixmil/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | from mixmil.simulation import load_simulation 6 | 7 | 8 | def setup_scatter(Xs): 9 | device = Xs[0].device 10 | x = torch.cat(Xs, dim=0) 11 | i = torch.cat([torch.full((x.shape[0],), idx) for idx, x in enumerate(Xs)]).to(device) 12 | i_ptr = torch.cat([torch.tensor([0], device=device), i.bincount().cumsum(0)]) 13 | return x, i, i_ptr 14 | 15 | 16 | def xgower_factor(X): 17 | a = np.power(X, 2).sum() 18 | b = X.dot(X.sum(0)).sum() 19 | return np.sqrt((a - b / X.shape[0]) / (X.shape[0] - 1)) 20 | 21 | 22 | class MILDataset(Dataset): 23 | def __init__(self, Xs, F, Y): 24 | self.Xs = Xs 25 | self.F = F 26 | self.Y = Y 27 | 28 | def __len__(self): 29 | return len(self.Xs) 30 | 31 | def __getitem__(self, idx): 32 | X = self.Xs[idx] 33 | F = self.F[idx] 34 | Y = self.Y[idx] 35 | return X, F, Y 36 | 37 | 38 | def mil_collate_fn(batch): 39 | X = [item[0] for item in batch] 40 | F = torch.stack([item[1] for item in batch]) 41 | Y = torch.stack([item[2] for item in batch]) 42 | return X, F, Y 43 | 44 | 45 | def normalize_feats(X, norm_factor="std_sqrt"): 46 | assert norm_factor in ["std", "std_sqrt"] 47 | train_data = ( 48 | torch.cat(X["train"], dim=0) if isinstance(X["train"], list) else X["train"].reshape(-1, X["train"].shape[2]) 49 | ) 50 | mean = train_data.mean(0, keepdims=True) 51 | std = train_data.std(0, keepdims=True) 52 | factor = std * np.sqrt(train_data.shape[1]) if norm_factor == "std_sqrt" else std 53 | 54 | for key in X: 55 | X[key] = [(x - mean) / factor for x in X[key]] if isinstance(X[key], list) else (X[key] - mean) / factor 56 | 57 | return X 58 | 59 | 60 | def load_data(dataset="simulation", norm_factor="std_sqrt", **kwargs): 61 | if dataset == "simulation": 62 | X, F, Y, u, w = load_simulation(**kwargs) 63 | X = normalize_feats(X, norm_factor) 64 | else: 65 | raise ValueError(f"Unknown dataset: {dataset}") 66 | return X, F, Y, u, w 67 | -------------------------------------------------------------------------------- /mixmil/posterior.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.distributions import LowRankMultivariateNormal 4 | 5 | 6 | def get_params(vc_mean, vc_sd, n_outs, n_vars, mean_field): 7 | mu_z 
= torch.Tensor(vc_mean) 8 | mu_u = torch.zeros_like(mu_z) 9 | mu = torch.cat([mu_u, mu_z], 1) 10 | sd_z = torch.Tensor(vc_sd) 11 | sd_u = torch.sqrt(0.1 * torch.ones_like(sd_z)) 12 | 13 | if mean_field: 14 | cov_factor = torch.zeros(n_outs, n_vars, 1) 15 | cov_logdiag = 2.0 * torch.log(torch.cat([sd_u, sd_z], 1)) 16 | else: 17 | diag = torch.diag_embed(torch.cat([sd_u, sd_z], 1)) 18 | cov_factor = 1e-4 * torch.randn(diag.shape) + diag 19 | cov_logdiag = np.log(1e-4) * torch.ones(n_outs, n_vars) 20 | 21 | return mu, cov_factor, cov_logdiag 22 | 23 | 24 | class GaussianVariationalPosterior(torch.nn.Module): 25 | def __init__(self, n_vars, n_outs, mean_field=True, init_params=None): 26 | super().__init__() 27 | self.n_vars = n_vars 28 | self.n_outs = n_outs 29 | self.mean_field = mean_field 30 | 31 | if init_params is not None: 32 | mu_z, sd_z, *_ = init_params 33 | mu_z = mu_z.T 34 | sd_z = sd_z.T 35 | else: 36 | mu_z = 1e-3 * torch.randn(n_outs, n_vars // 2) 37 | sd_z = 1e-3 * torch.randn(n_outs, n_vars // 2) 38 | 39 | mu, cov_factor, cov_logdiag = get_params(mu_z, sd_z, n_outs, n_vars, mean_field) 40 | 41 | self.mu = torch.nn.Parameter(mu) 42 | if mean_field: 43 | self.register_buffer("cov_factor", cov_factor) 44 | self.cov_logdiag = torch.nn.Parameter(cov_logdiag) 45 | else: 46 | self.cov_factor = torch.nn.Parameter(cov_factor) 47 | self.register_buffer("cov_logdiag", cov_logdiag) 48 | 49 | @property 50 | def distribution(self): 51 | return LowRankMultivariateNormal(self.mu, self.cov_factor, torch.exp(self.cov_logdiag)) 52 | 53 | @property 54 | def q_mu(self): 55 | return self.mu.T 56 | 57 | def sample(self, n_samples): 58 | return self.distribution.rsample([n_samples]).permute([2, 1, 0]) 59 | 60 | def extra_repr(self) -> str: 61 | return f"n_vars=2*{self.n_vars//2}, n_outs={self.n_outs}, mean_field={self.mean_field}" 62 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from mixmil import MixMIL 6 | 7 | 8 | @pytest.fixture 9 | def mock_data_binomial(): 10 | Xs = [torch.randn(10, 3) for _ in range(5)] # List of tensors 11 | F = torch.randn(5, 4) # Fixed effects 12 | Y = torch.randint(0, 2, (5, 1)) # Labels for binomial 13 | return Xs, F, Y 14 | 15 | 16 | @pytest.fixture 17 | def mock_data_categorical(): 18 | N, Q, K = 50, 10, 4 19 | bag_sizes = torch.randint(5, 15, (N,)) 20 | Xs = [torch.randn(bag_sizes[n], Q) for n in range(N)] # List of tensors 21 | F = torch.randn(N, K) # Fixed effects 22 | Y = torch.randint(0, 5, (N, 1)) # Labels for categorical 23 | return Xs, F, Y 24 | 25 | 26 | def test_init_with_mean_model_binomial(mock_data_binomial): 27 | Xs, F, Y = mock_data_binomial 28 | model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood="binomial", n_trials=2) 29 | model.train(Xs, F, Y, n_epochs=3) 30 | assert isinstance(model, MixMIL) 31 | assert model.likelihood_name == "binomial" 32 | assert model.n_trials == 2 33 | assert model.log_sigma_u.numel() == 1 34 | 35 | 36 | def test_init_with_mean_model_categorical(mock_data_categorical): 37 | Xs, F, Y = mock_data_categorical 38 | model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood="categorical") 39 | model.train(Xs, F, Y, n_epochs=3) 40 | assert isinstance(model, MixMIL) 41 | assert model.likelihood_name == "categorical" 42 | assert model.n_trials is None 43 | assert model.log_sigma_u.numel() == len(np.unique(Y)) # separate prior for 
each class
44 | 
45 | 
46 | def test_initialization():
47 |     model = MixMIL(Q=10, K=5, P=2, likelihood="binomial", n_trials=2)
48 |     assert model.Q == 10
49 |     assert model.alpha.shape == (5, 2)
50 | 
51 | 
52 | @pytest.mark.parametrize(
53 |     "Q, K, P, likelihood, n_trials, mean_field",
54 |     [
55 |         (10, 5, 2, "categorical", None, True),
56 |         (10, 5, 2, "categorical", None, False),
57 |         (10, 5, 2, "binomial", 1, True),
58 |         (10, 5, 2, "binomial", 2, False),
59 |         (10, 5, 1, "binomial", 2, False),
60 |     ],
61 | )
62 | def test_init_model(Q, K, P, likelihood, n_trials, mean_field):
63 |     MixMIL(Q=Q, K=K, P=P, likelihood=likelihood, n_trials=n_trials, mean_field=mean_field)
64 | 
--------------------------------------------------------------------------------
/scripts/dsmil_data_download.py:
--------------------------------------------------------------------------------
1 | """
2 | Script adapted from
3 | https://github.com/binli123/dsmil-wsi/blob/master/download.py
4 | 
5 | Script and data provided by the DSMIL authors:
6 | 
7 | @inproceedings{li2021dual,
8 |   title={Dual-stream multiple instance learning network for whole slide image classification with self-supervised contrastive learning},
9 |   author={Li, Bin and Li, Yin and Eliceiri, Kevin W},
10 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
11 |   pages={14318--14328},
12 |   year={2021}
13 | }
14 | """
15 | 
16 | import argparse
17 | import os
18 | import urllib.request
19 | import zipfile
20 | 
21 | from tqdm import tqdm
22 | 
23 | 
24 | class DownloadProgressBar(tqdm):
25 |     def update_to(self, b=1, bsize=1, tsize=None):
26 |         if tsize is not None:
27 |             self.total = tsize
28 |         self.update(b * bsize - self.n)
29 | 
30 | 
31 | def download_url(url, output_path):
32 |     with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]) as t:
33 |         urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
34 | 
35 | 
36 | def unzip_data(zip_path, data_path):
37 |     with zipfile.ZipFile(zip_path, "r") as zip_ref:
38 |         zip_ref.extractall(data_path)
39 | 
40 | 
41 | DATASET_URLS = {
42 |     "camelyon16": "https://uwmadison.box.com/shared/static/l9ou15iwup73ivdjq0bc61wcg5ae8dwe.zip",
43 | }
44 | 
45 | 
46 | def main():
47 |     parser = argparse.ArgumentParser()
48 |     parser.add_argument("--dataset", type=str, default="camelyon16", help="Dataset to be downloaded: camelyon16")
49 |     parser.add_argument("--keep-zip", action="store_true", help="Keep the downloaded zip file")
50 |     args = parser.parse_args()
51 | 
52 |     assert args.dataset in DATASET_URLS, f"Dataset {args.dataset} not found"
53 | 
54 |     print(f"downloading dataset: {args.dataset}")
55 |     unzip_dir = f"data/{args.dataset}"
56 |     zip_file_path = f"data/{args.dataset}-dataset.zip"
57 |     os.makedirs(unzip_dir, exist_ok=True)
58 |     download_url(DATASET_URLS[args.dataset], zip_file_path)
59 |     unzip_data(zip_file_path, unzip_dir)
60 | 
61 |     if not args.keep_zip:
62 |         os.remove(zip_file_path)
63 | 
64 |     print("done!")
65 | 
66 | 
67 | if __name__ == "__main__":
68 |     main()
69 | 
--------------------------------------------------------------------------------
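The simulation module that follows generates the synthetic benchmark used by `tests/test_simulation.py` and `experiments/simulation.ipynb`. For orientation, here is a hedged sketch of the shapes its defaults produce when accessed through `mixmil.data.load_data`; the shape comments are inferred from the code below rather than from recorded output:

```
from mixmil.data import load_data

# Defaults below: N=1_000 training bags, N_test=200 test bags,
# I=50 instances per bag, Q=30 features per instance, P=1 output.
X, F, Y, u, w = load_data(dataset="simulation", P=1, seed=0)

print(X["train"].shape)  # (1000, 50, 30): bags x instances x features
print(F["train"].shape)  # (1000, 1): fixed-effect covariates (all ones by default)
print(Y["train"].shape)  # (1000, 1): binomial counts in {0, 1, 2}
print(u["train"].shape)  # (1000, 1): true bag-level effects
print(w["train"].shape)  # (1000, 50, 1): true instance weights (softmax over instances)
```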
/mixmil/simulation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.model_selection import train_test_split
4 | 
5 | 
6 | def get_X(N=1_000, I=50, Q=30, N_test=200):
7 |     N = N + N_test
8 | 
9 |     X = torch.randn([N * I, Q], dtype=torch.float32)
10 | 
11 |     X = (X - X.mean(0)) / X.std(0)
12 |     X = X / np.sqrt(X.shape[1])
13 |     X = X.reshape([N, I, Q])
14 | 
15 |     return X
16 | 
17 | 
18 | def simulate(X, v_beta=0.5, v_gamma=0.8, b=-1, F=None, P=1):
19 |     if F is None:
20 |         F = torch.ones([X.shape[0], 1])
21 | 
22 |     # broadcast the scalar parameters to one value per simulated output (P phenotypes)
23 |     K = F.shape[1]
24 |     b = b * torch.ones([K, P])
25 |     v_beta = v_beta * torch.ones(P)
26 |     v_gamma = v_gamma * torch.ones(P)
27 | 
28 |     # sample weights
29 |     gamma = torch.randn((X.shape[2], v_gamma.shape[0]))
30 |     _w = torch.einsum("nik,kp->nip", X, gamma)
31 |     _scale_w = torch.sqrt(v_gamma / _w.var([0, 1]))
32 |     gamma = _scale_w[None, :] * gamma
33 |     _w = _scale_w[None, None, :] * _w
34 | 
35 |     w = torch.softmax(_w, dim=1)
36 | 
37 |     # sample z
38 |     beta = torch.randn((X.shape[2], v_beta.shape[0]))
39 |     beta = beta / torch.sqrt((beta**2).mean(0, keepdim=True))
40 |     z = torch.einsum("nik,kp->nip", X, beta)
41 |     u = torch.einsum("nip,nip->np", w, z)
42 |     u = (u - u.mean(0)) / u.std(0)
43 |     u = torch.sqrt(v_beta) * u
44 |     beta = torch.sqrt(v_beta) * beta
45 | 
46 |     # compute rates
47 |     logits = F.mm(b) + u
48 | 
49 |     # sample Y
50 |     Y = torch.distributions.Binomial(2, logits=logits).sample()
51 | 
52 |     return F, Y, u, w
53 | 
54 | 
55 | def split_data(Xs, test_size=200, val_size=0.0, test_rs=127, val_rs=412):
56 |     idxs_all = np.arange(Xs[0].shape[0])
57 |     idxs = {}
58 |     idxs["train_val"], idxs["test"] = train_test_split(idxs_all, test_size=test_size, random_state=test_rs)
59 | 
60 |     if not np.isclose(val_size, 0):
61 |         idxs["train"], idxs["val"] = train_test_split(idxs["train_val"], test_size=val_size, random_state=val_rs)
62 |     else:
63 |         idxs["train"] = idxs["train_val"]
64 |     del idxs["train_val"]
65 | 
66 |     outs = []
67 |     for X in Xs:
68 |         out = {}
69 |         for key in idxs.keys():
70 |             out[key] = X[idxs[key]]
71 |         outs.append(out)
72 |     return outs
73 | 
74 | 
75 | def load_simulation(seed=42, N=1_000, I=50, Q=30, N_test=200, P=1, v_beta=0.5, v_gamma=0.8, b=-1, F=None):
76 |     np.random.seed(seed)
77 |     torch.manual_seed(seed)
78 | 
79 |     X = get_X(N, I, Q, N_test)
80 |     data = simulate(X, v_beta, v_gamma, b, F, P)
81 |     X, F, Y, u, w = split_data((X, *data))
82 |     return X, F, Y, u, w
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MixMIL
2 | Code for the paper: [Mixed Models with Multiple Instance Learning](https://arxiv.org/abs/2311.02455)
3 | 
4 | Accepted at AISTATS 2024 as an oral presentation & [Outstanding Student Paper Highlight](https://aistats.org/aistats2024/awards.html).
5 | 
6 | Please raise an issue for questions and bug reports.
7 | ## Installation
8 | Install with:
9 | ```
10 | pip install mixmil
11 | ```
12 | Alternatively, if you want to include the optional experiment and test dependencies, use:
13 | ```
14 | pip install "mixmil[experiments,test]"
15 | ```
16 | Or, if you want to adapt the code:
17 | ```
18 | git clone https://github.com/AIH-SGML/mixmil.git
19 | cd mixmil
20 | pip install -e ".[experiments,test]"
21 | ```
22 | To enable computations on GPU please follow the installation instructions of [PyTorch](https://pytorch.org/) and [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter).
23 | MixMIL works with, e.g., PyTorch 2.1.
24 | ## Experiments
25 | See the notebooks in the `experiments` folder for examples on how to run the simulation and histopathology experiments. 
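For a quick start outside the notebooks, the following is a minimal end-to-end sketch on the bundled simulation data; it mirrors `tests/test_simulation.py`, so every name in it comes from this repository:
```
from mixmil import MixMIL
from mixmil.data import load_data

# embeddings X, fixed effects F, labels Y, true bag effects u, true instance weights w
X, F, Y, u, w = load_data(P=1, seed=0)

model = MixMIL.init_with_mean_model(X["train"], F["train"], Y["train"], likelihood="binomial", n_trials=2)
model.train(X["train"], F["train"], Y["train"], n_epochs=40)

u_pred = model.predict(X["test"])         # bag-level predictions
w_pred, _ = model.get_weights(X["test"])  # instance attention weights (after softmax)
```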
26 | 27 | Make sure the `experiments` requirements are installed: 28 | ``` 29 | pip install "mixmil[experiments]" 30 | ``` 31 | ### Histopathology 32 | The histopathology experiment was performed on the [CAMELYON16](https://camelyon16.grand-challenge.org/) dataset. 33 | #### Download Data 34 | To download the embeddings provided by the DSMIL authors, either: 35 | - Full embeddings: `python scripts/dsmil_data_download.py` 36 | - PCA reduced embeddings: [Google Drive](https://drive.google.com/drive/folders/1X9ho1_W5ixyHSw_2hCfQsBb5nzkjMviA?usp=sharing) 37 | 38 | ### Microscopy 39 | The full BBBC021 dataset can be downloaded [here](https://bbbc.broadinstitute.org/BBBC021). 40 | #### Download Data 41 | - We make the featurized cells available at [BBBC021](https://drive.google.com/file/d/1OyH3zg22N107qrPVp3p-GLoFa1KzeoID/view?usp=sharing) 42 | - The features are stored as an [AnnData](https://anndata.readthedocs.io/en/latest/) object. We recommend using the [scanpy](https://scanpy.readthedocs.io/en/stable/) package to read and process them 43 | - The weights of the featurizer trained with the SimCLR algorithm can be downloaded from the original [GitHub repository](https://github.com/SamriddhiJain/SimCLR-for-cell-profiling?tab=readme-ov-file) 44 | 45 | ## Citation 46 | ``` 47 | @inproceedings{engelmann2024mixed, 48 | title={Mixed Models with Multiple Instance Learning}, 49 | author={Engelmann, Jan P. and Palma, Alessandro and Tomczak, Jakub M. and Theis, Fabian and Casale, Francesco Paolo}, 50 | booktitle={International Conference on Artificial Intelligence and Statistics}, 51 | pages={3664--3672}, 52 | year={2024}, 53 | organization={PMLR} 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /mixmil/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg as la 3 | import statsmodels.api as sm 4 | import torch 5 | from sklearn.linear_model import LogisticRegressionCV 6 | from tqdm.auto import trange 7 | 8 | from mixmil.data import xgower_factor 9 | 10 | 11 | def regressOut(Y, X, return_b=False, return_pinv=False): 12 | """ 13 | regresses out X from Y 14 | """ 15 | Xd = la.pinv(X) 16 | b = Xd.dot(Y) 17 | Y_out = Y - X.dot(b) 18 | out = [Y_out] 19 | if return_b: 20 | out.append(b) 21 | if return_pinv: 22 | out.append(Xd) 23 | return out if len(out) > 1 else out[0] 24 | 25 | 26 | def _get_single_binomial_init_params(X, F, y): 27 | ident = np.zeros((X.shape[1]), dtype=int) 28 | model = sm.BinomialBayesMixedGLM(y, F, X, ident).fit_vb() 29 | 30 | u = np.dot(X, model.vc_mean)[::2] 31 | 32 | _scale = u.std() / np.sqrt((model.vc_mean**2).mean(0)) 33 | mu_beta = _scale * model.vc_mean 34 | sd_beta = _scale * model.vc_sd 35 | var_z = (mu_beta**2 + sd_beta**2).mean().reshape(1) 36 | alpha = model.fe_mean 37 | 38 | return mu_beta, sd_beta, var_z, alpha 39 | 40 | 41 | def get_binomial_init_params(X, F, Y): 42 | results = [_get_single_binomial_init_params(X, F, Y[:, p]) for p in trange(Y.shape[1], desc="GLMM Init")] 43 | 44 | mu_beta, sd_beta, var_z, alpha = [_list2tensor(listo) for listo in zip(*results)] 45 | 46 | return mu_beta, sd_beta, var_z, alpha 47 | 48 | 49 | def get_lr_init_params(X, Y, b, Fiv): 50 | model = LogisticRegressionCV( 51 | Cs=10, 52 | fit_intercept=True, 53 | penalty="l2", 54 | multi_class="multinomial", 55 | solver="lbfgs", 56 | n_jobs=1, 57 | verbose=0, 58 | random_state=42, 59 | max_iter=1000, 60 | refit=True, 61 | ) 62 | 63 | model.fit(X, Y.ravel()) 64 | 
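    # Reparametrization note (a reading of the code below, not original to it):
    # the fitted logistic-regression solution is converted into MixMIL's
    # initialization by (1) computing bag scores u = X @ beta and folding their
    # mean into the intercept alpha, (2) rescaling beta to unit root-mean-square
    # times the bag-score standard deviation to obtain mu_beta, (3) setting
    # sd_beta so that sd_beta**2 = 0.1 * mean(mu_beta**2), and (4) mapping the
    # intercept through the fixed-effect pseudo-inverse Fiv while subtracting
    # the component b @ mu_beta that was regressed out earlier.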
65 | alpha = model.intercept_[None] 66 | beta = model.coef_.T 67 | 68 | # Compute bag prediction u and reparametrize 69 | u = X.dot(beta) 70 | um = u.mean(0)[None] 71 | us = u.std(0)[None] 72 | alpha = alpha + um 73 | mu_beta = us * beta / np.sqrt((beta**2).mean(0)[None]) 74 | sd_beta = np.sqrt(0.1 * (mu_beta**2).mean()) * np.ones_like(mu_beta) 75 | 76 | alpha = Fiv.dot(np.ones((Fiv.shape[1], 1))).dot(alpha) - b.dot(mu_beta) 77 | 78 | # init prior 79 | var_z = (mu_beta**2 + sd_beta**2).mean(axis=0, keepdims=True) 80 | 81 | return [torch.Tensor(el) for el in (mu_beta, sd_beta, var_z, alpha)] 82 | 83 | 84 | def _list2tensor(_list): 85 | return torch.Tensor(np.stack(_list, axis=1)) 86 | 87 | 88 | def get_init_params(Xs, F, Y, likelihood, n_trials): 89 | Xm = np.concatenate([x.mean(0, keepdims=True) for x in Xs], axis=0) 90 | Fe, Ye = F.numpy(), Y.long().numpy() 91 | 92 | if likelihood == "binomial": 93 | Xm = (Xm - Xm.mean(0, keepdims=True)) / xgower_factor(Xm) 94 | 95 | if n_trials == 2: 96 | Xm, Fe = Xm.repeat(2, axis=0), Fe.repeat(2, axis=0) 97 | to_expanded = np.array(([[0, 0], [1, 0], [1, 1]])) 98 | Ye = to_expanded[Y.long().numpy().T].transpose(1, 2, 0).reshape(-1, Y.shape[1]) 99 | 100 | mu_z, sd_z, var_z, alpha = get_binomial_init_params(Xm, Fe, Ye) 101 | 102 | elif likelihood == "categorical": 103 | Xm, b, Fiv = regressOut(Xm, Fe, return_b=True, return_pinv=True) 104 | Xm = (Xm - Xm.mean(0, keepdims=True)) / (Xm.std(0, keepdims=True) * np.sqrt(Xm.shape[-1])) 105 | 106 | mu_z, sd_z, var_z, alpha = get_lr_init_params(Xm, Ye, b, Fiv) 107 | 108 | return mu_z, sd_z, var_z, alpha 109 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 | 
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
162 | .DS_Store
163 | data
164 | *.out
165 | assets
--------------------------------------------------------------------------------
/experiments/simulation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Simulation Experiment"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "markdown",
12 |    "metadata": {},
13 |    "source": [
14 |     "> In this notebook we demonstrate how to train the MixMIL model on data simulated as specified in the paper, in the binomial likelihood setting. "
15 |    ]
16 |   },
17 |   {
18 |    "cell_type": "code",
19 |    "execution_count": 2,
20 |    "metadata": {},
21 |    "outputs": [
22 |     {
23 |      "name": "stderr",
24 |      "output_type": "stream",
25 |      "text": [
26 |       "/home/icb/alessandro.palma/miniconda3/envs/sslbio-env/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
27 |       " from .autonotebook import tqdm as notebook_tqdm\n"
28 |      ]
29 |     }
30 |    ],
31 |    "source": [
32 |     "import numpy as np\n",
33 |     "import scipy.stats as st\n",
34 |     "import torch\n",
35 |     "from sklearn.metrics import roc_auc_score\n",
36 |     "\n",
37 |     "from mixmil import MixMIL\n",
38 |     "from mixmil.data import load_data\n",
39 |     "import pandas as pd"
40 |    ]
41 |   },
42 |   {
43 |    "cell_type": "markdown",
44 |    "metadata": {},
45 |    "source": [
46 |     "## Utility Functions"
47 |    ]
48 |   },
49 |   {
50 |    "cell_type": "code",
51 |    "execution_count": 3,
52 |    "metadata": {},
53 |    "outputs": [],
54 |    "source": [
55 |     "import numpy as np\n",
56 |     "from sklearn.metrics import roc_auc_score\n",
57 |     "import scipy.stats as st\n",
58 |     "\n",
59 |     "def _calc_metrics(u_pred, w_pred, u, w):\n",
60 |     "    \"\"\"\n",
61 |     "    Calculate bag-level correlation and instance-retrieval AUC from true and predicted values.\n",
62 |     "\n",
63 |     "    Parameters:\n",
64 |     "    - u_pred (numpy.ndarray): Predicted bag-level effects.\n",
65 |     "    - w_pred (numpy.ndarray): Predicted instance-level attention weights.\n",
66 |     "    - u (numpy.ndarray): True bag-level effects.\n",
67 |     "    - w (numpy.ndarray): True instance-level weights as instance proportions.\n",
68 |     "\n",
69 |     "    Returns:\n",
70 |     "    - rho_bag (float): Bag-effect correlation (Spearman's rank correlation coefficient).\n",
71 |     "    - auc_instance (float): Instance retrieval AUC (Area Under the Receiver Operating Characteristic curve).\n",
72 |     "    \"\"\"\n",
73 |     "    rho_bag = st.spearmanr(u_pred, u).correlation  # bag level correlation\n",
74 |     "    is_top_instance = (w > np.quantile(w, 0.90)).long().ravel()\n",
75 |     "    auc_instance = roc_auc_score(is_top_instance, w_pred)  # instance-retrieval AUC\n",
76 |     "    return rho_bag, auc_instance\n",
77 |     "\n",
78 |     "\n",
79 |     "def calc_metrics(model, X, u, w):\n",
80 |     "    \"\"\"\n",
81 |     "    Calculate aggregated metrics over multiple bags or instances for a given model.\n",
82 |     "\n",
83 |     "    Parameters:\n",
84 |     "    - model: Trained model.\n",
85 |     "    - X (dict): Dictionary containing input data.\n",
86 |     "    - u (dict): Dictionary containing true bag-level effects.\n",
87 |     "    - w (dict): Dictionary containing true instance-level weights as proportions.\n",
88 |     "\n",
89 |     "    Returns:\n",
90 |     "    - res_dict (dict): Dictionary containing aggregated metrics including:\n",
91 |     "        - 'rho_bag' (float): Mean bag-level effect correlation.\n",
92 |     "        - 'rho_bag_err' (float): Standard error of the mean for bag-level correlation.\n",
93 |     "        - 'auc_instance' (float): Mean instance retrieval AUC.\n",
94 |     "        - 'auc_instance_err' (float): Standard error of the mean for instance retrieval AUC.\n",
95 |     "    \"\"\"\n",
96 |     "    u_pred = model.predict(X[\"test\"]).cpu().numpy()\n",
97 |     "    w_pred = model.get_weights(X[\"test\"])[0].cpu().numpy()\n",
98 |     "\n",
99 |     "    P = u_pred.shape[1]\n",
100 |     "    if P > 0:\n",
101 |     "        rho_bag, auc_instance = [], []\n",
102 |     "        for i in range(P):\n",
103 |     "            _rho_bag, _auc_instance = _calc_metrics(\n",
104 |     "                u_pred[..., i], w_pred[..., i].ravel(), u[\"test\"][..., i], w[\"test\"][..., i].ravel()\n",
105 |     "            )\n",
106 |     "            rho_bag.append(_rho_bag)\n",
107 |     "            auc_instance.append(_auc_instance)\n",
108 |     "\n",
109 |     "        res_dict = {\n",
110 |     "            \"rho_bag\": np.mean(rho_bag),\n",
111 |     "            \"rho_bag_err\": np.std(rho_bag) / np.sqrt(P),\n",
112 |     "            \"auc_instance\": np.mean(auc_instance),\n",
113 |     "            \"auc_instance_err\": np.std(auc_instance) / 
np.sqrt(P),\n",
114 |     "        }\n",
115 |     "    else:\n",
116 |     "        rho_bag, auc_instance = _calc_metrics(u_pred, w_pred.ravel(), u[\"test\"], w[\"test\"].ravel())\n",
117 |     "        res_dict = {\"rho_bag\": rho_bag, \"auc_instance\": auc_instance}\n",
118 |     "    return res_dict\n",
119 |     "\n",
120 |     "def print_metrics(prefix, metrics):\n",
121 |     "    \"\"\"\n",
122 |     "    Print a formatted representation of metrics with a specified prefix.\n",
123 |     "\n",
124 |     "    Parameters:\n",
125 |     "    - prefix (str): Prefix to be added to the printed metrics, for better identification.\n",
126 |     "    - metrics (dict): Dictionary containing metrics to be printed.\n",
127 |     "\n",
128 |     "    Returns:\n",
129 |     "    - None: This function prints the metrics to the console without returning any value.\n",
130 |     "    \"\"\"\n",
131 |     "    print(f\"{prefix} metrics:\")\n",
132 |     "    for k, v in metrics.items():\n",
133 |     "        print(f\"{k}: {v:.4f}\")\n",
134 |     "    print()"
135 |    ]
136 |   },
137 |   {
138 |    "cell_type": "markdown",
139 |    "metadata": {},
140 |    "source": [
141 |     "## Training"
142 |    ]
143 |   },
144 |   {
145 |    "cell_type": "markdown",
146 |    "metadata": {},
147 |    "source": [
148 |     "Train the model on simulated data using a binomial likelihood."
149 |    ]
150 |   },
151 |   {
152 |    "cell_type": "code",
153 |    "execution_count": 5,
154 |    "metadata": {},
155 |    "outputs": [
156 |     {
157 |      "name": "stderr",
158 |      "output_type": "stream",
159 |      "text": [
160 |       "GLMM Init: 100%|██████████| 10/10 [00:07<00:00, 1.33it/s]\n"
161 |      ]
162 |     },
163 |     {
164 |      "name": "stdout",
165 |      "output_type": "stream",
166 |      "text": [
167 |       "[START] metrics:\n",
168 |       "rho_bag: 0.6004\n",
169 |       "rho_bag_err: 0.0131\n",
170 |       "auc_instance: 0.5000\n",
171 |       "auc_instance_err: 0.0000\n",
172 |       "\n"
173 |      ]
174 |     },
175 |     {
176 |      "name": "stderr",
177 |      "output_type": "stream",
178 |      "text": [
179 |       "Epoch: 100%|██████████| 2000/2000 [04:30<00:00, 7.39it/s]"
180 |      ]
181 |     },
182 |     {
183 |      "name": "stdout",
184 |      "output_type": "stream",
185 |      "text": [
186 |       "[END] metrics:\n",
187 |       "rho_bag: 0.8316\n",
188 |       "rho_bag_err: 0.0111\n",
189 |       "auc_instance: 0.9362\n",
190 |       "auc_instance_err: 0.0078\n",
191 |       "\n"
192 |      ]
193 |     },
194 |     {
195 |      "name": "stderr",
196 |      "output_type": "stream",
197 |      "text": [
198 |       "\n"
199 |      ]
200 |     }
201 |    ],
202 |    "source": [
203 |     "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
204 |     "\n",
205 |     "# Simulate data as described in the paper\n",
206 |     "# embeddings, fixed effects, labels, sim bag predictions, sim instance weights\n",
207 |     "# P: number of outputs, simulated from the same embeddings X\n",
208 |     "X, F, Y, u, w = load_data(P=10, seed=0)\n",
209 |     "model = MixMIL.init_with_mean_model(X[\"train\"], F[\"train\"], Y[\"train\"], likelihood=\"binomial\", n_trials=2).to(device)\n",
210 |     "X, F, Y = [{key: val.to(device) for key, val in el.items()} for el in [X, F, Y]]\n",
211 |     "\n",
212 |     "print_metrics(\"[START]\", calc_metrics(model, X, u, w))\n",
213 |     "# Fit model in parallel to each output separately\n",
214 |     "model.train(X[\"train\"], F[\"train\"], Y[\"train\"], n_epochs=2_000)\n",
215 |     "print_metrics(\"[END]\", calc_metrics(model, X, u, w))"
216 |    ]
217 |   }
218 |  ],
219 |  "metadata": {
220 |   "kernelspec": {
221 |    "display_name": "Python 3 (ipykernel)",
222 |    "language": "python",
223 |    "name": "python3"
224 |   },
225 |   "language_info": {
226 |    "codemirror_mode": {
227 |     "name": "ipython",
228 |     "version": 3
229 |    },
230 |    "file_extension": ".py",
231 |    "mimetype": "text/x-python",
232 |    "name": "python",
233 |    "nbconvert_exporter": "python",
234 |    "pygments_lexer": "ipython3",
235 |    "version": "3.9.13"
236 |   }
237 |  },
238 |  "nbformat": 4,
239 |  "nbformat_minor": 2
240 | }
241 | 
--------------------------------------------------------------------------------
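Before the full implementation of `mixmil/model.py` below, a minimal single-bag sketch of the attention pooling that `MixMIL.forward` vectorizes may help; the variable names here are illustrative, and the model additionally standardizes the pooled effect and rescales it by the posterior effect size (see `forward` and `_calc_bag_emb_effect_tensor` for the exact computation):

```
import torch

I, Q = 50, 30                # instances in one bag, feature dimension
x = torch.randn(I, Q)        # one bag of instance embeddings
beta_u = torch.randn(Q, 1)   # attention weights (one output, one posterior sample)
eta = torch.randn(Q, 1)      # normalized effect weights

w = torch.softmax(x @ beta_u, dim=0)  # instance attention, sums to 1 over the bag
u = (w * (x @ eta)).sum(0)            # attention-pooled bag-level effect
```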
"nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.9.13" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 2 240 | } 241 | -------------------------------------------------------------------------------- /mixmil/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.distributions import Binomial, Categorical, LowRankMultivariateNormal 4 | from torch.distributions.kl import kl_divergence 5 | from torch.utils.data import DataLoader 6 | from torch_scatter import scatter_softmax, segment_add_csr 7 | from tqdm.auto import trange 8 | 9 | from mixmil.data import MILDataset, mil_collate_fn, setup_scatter 10 | from mixmil.posterior import GaussianVariationalPosterior 11 | from mixmil.utils import get_init_params 12 | 13 | 14 | class MixMIL(torch.nn.Module): 15 | """Attention-based Multi-instance Mixed Models 16 | 17 | https://arxiv.org/abs/2311.02455 18 | 19 | """ 20 | 21 | def __init__(self, Q, K, P=1, likelihood="binomial", n_trials=2, mean_field=False, init_params=None): 22 | r"""Initialize the MixMil class. 23 | 24 | Parameters: 25 | - Q (int): The dimension of the latent space. 26 | - K (int): The number of fixed effects. 27 | - P (int): The number of outputs. 28 | - likelihood (str, optional): The likelihood to use. Either "binomial" or "categorical". Default is "binomial". 29 | - n_trials (int, optional): Number of trials for binomial likelihood. Not used for categorical. Default is 2. 30 | - mean_field (bool, optional): Toggle mean field approximation for the posterior. Default is False. 31 | - init_params (tuple, optional): Tuple of (mean, var, var_z, alpha) to initialize the model. Default is None. 32 | mean (torch.Tensor): The mean of the posterior. Shape: (Q, P). d 33 | var (torch.Tensor): The variance of the posterior. Shape: (Q, P). 34 | var_z (torch.Tensor): The $\sigma_{\beta}^2$ hparam of the prior. 35 | Shape: (1, P) with separate and (1, 1) with shared priors . 36 | alpha (torch.Tensor): The fixed effect parameters. Shape: (K, P). 
37 | """ 38 | super().__init__() 39 | self.Q = Q 40 | 41 | alpha = torch.zeros((K, P)) 42 | log_sigma_u = torch.full((1, P), 0.5 * np.log(0.5)) 43 | log_sigma_z = torch.full((1, P), 0.5 * np.log(0.5)) 44 | 45 | if init_params is not None: 46 | *_, var_z, alpha = init_params 47 | log_sigma_z = 0.5 * torch.log(var_z) 48 | 49 | self.alpha = torch.nn.Parameter(alpha) 50 | self.log_sigma_u = torch.nn.Parameter(log_sigma_u) 51 | self.log_sigma_z = torch.nn.Parameter(log_sigma_z) 52 | 53 | self.posterior = GaussianVariationalPosterior(2 * Q, P, mean_field, init_params) 54 | 55 | self.likelihood_name = likelihood 56 | self.n_trials = n_trials if likelihood == "binomial" else None 57 | self.is_trained = False 58 | 59 | def init_with_mean_model(Xs, F, Y, likelihood="binomial", n_trials=None, mean_field=False): 60 | assert (likelihood == "binomial" and n_trials is not None and 0 < n_trials <= 2) or ( 61 | likelihood == "categorical" and n_trials is None 62 | ), f"n_trials must be 1 or 2 to initialize with binomial mean model, got {n_trials=} and {likelihood=}" 63 | init_params = get_init_params(Xs, F, Y, likelihood, n_trials) 64 | Q, K, P = Xs[0].shape[1], F.shape[1], init_params[0].shape[1] 65 | return MixMIL(Q, K, P, likelihood, n_trials, mean_field, init_params) 66 | 67 | @property 68 | def prior_distribution(self): 69 | device = self.log_sigma_u.device 70 | scale_u = self.log_sigma_u.T * torch.ones([1, self.Q], device=device) 71 | scale_z = self.log_sigma_z.T * torch.ones([1, self.Q], device=device) 72 | cov_logdiag = 2 * torch.cat([scale_u, scale_z], 1) 73 | cov_factor = torch.zeros_like(cov_logdiag)[:, :, None] 74 | mu = torch.zeros_like(cov_logdiag) 75 | return LowRankMultivariateNormal(mu, cov_factor, torch.exp(cov_logdiag)) 76 | 77 | @property 78 | def posterior_distribution(self): 79 | return self.posterior.distribution 80 | 81 | @property 82 | def qu_mu(self): 83 | return self.posterior.q_mu[: self.Q] 84 | 85 | @property 86 | def qz_mu(self): 87 | return self.posterior.q_mu[self.Q :] 88 | 89 | def likelihood(self, logits, y): 90 | if self.likelihood_name == "binomial": 91 | return Binomial(total_count=self.n_trials, logits=logits).log_prob(y[:, :, None]).sum(1).mean() 92 | elif self.likelihood_name == "categorical": 93 | logits = logits.permute(0, 2, 1) 94 | if logits.shape[-1] == 1: 95 | logits = torch.cat([-logits, logits], 2) 96 | return Categorical(logits=logits).log_prob(y).mean() 97 | 98 | def loss(self, u, f, y, kld_w=1.0, return_dict=False): 99 | logits = f.mm(self.alpha)[:, :, None] + u 100 | 101 | ll = self.likelihood(logits, y) 102 | kld = kl_divergence(self.posterior_distribution, self.prior_distribution) 103 | kld_term = kld_w * kld.sum() / y.shape[0] 104 | loss = -ll + kld_term 105 | if return_dict: 106 | return loss, dict(loss=loss.item(), ll=ll.item(), kld=kld_term.item()) 107 | return loss 108 | 109 | def get_betas(self, n_samples=None, predict=False): 110 | assert not (n_samples and predict) 111 | if n_samples: 112 | beta = self.posterior.sample(n_samples) 113 | beta_u = beta[: self.Q, :, :] 114 | beta_z = beta[self.Q :, :, :] 115 | else: 116 | beta_u = self.qu_mu[:, :, None] 117 | beta_z = self.qz_mu[:, :, None] 118 | return beta_u, beta_z 119 | 120 | def forward(self, Xs, n_samples=8, scaling=None, predict=False): 121 | beta_u, beta_z = self.get_betas(n_samples, predict) 122 | b = torch.sqrt((beta_z**2).mean(0, keepdim=True)) 123 | eta = beta_z / b 124 | 125 | if torch.is_tensor(Xs): 126 | u = self._calc_bag_emb_effect_tensor(beta_u, eta, Xs) 127 | else: 128 | u = 
self._calc_bag_emb_effect_scatter(beta_u, eta, Xs) 129 | 130 | mean, std = (u.mean(0), u.std(0)) if scaling is None else scaling 131 | if std.isnan().any(): 132 | std = 1 133 | u = b * (u - mean) / std 134 | return u 135 | 136 | def _calc_bag_emb_effect_tensor(self, beta_u, eta, Xs): 137 | _w = torch.einsum("niq,qps->nips", Xs, beta_u) 138 | w = torch.softmax(_w, dim=1) 139 | t = torch.einsum("niq,qps->nips", Xs, eta) 140 | u = torch.einsum("nips,nips->nps", w, t) 141 | return u 142 | 143 | def _calc_bag_emb_effect_scatter(self, beta_u, eta, Xs): 144 | x, i, i_ptr = setup_scatter(Xs) 145 | 146 | _w = torch.einsum("iq,qps->ips", x, beta_u) 147 | w = scatter_softmax(_w, i, dim=0) 148 | t = torch.einsum("iq,qps->ips", x, eta) 149 | u = segment_add_csr(w * t, i_ptr) 150 | return u 151 | 152 | def train(self, X, F, Y, n_epochs=2_000, batch_size=64, lr=1e-3, verbose=True): 153 | train_loader = DataLoader( 154 | MILDataset(X, F, Y), 155 | shuffle=True, 156 | batch_size=batch_size, 157 | collate_fn=None if torch.is_tensor(X) else mil_collate_fn, 158 | ) 159 | optim = torch.optim.Adam(lr=lr, params=self.parameters()) 160 | 161 | history = [] 162 | for epoch in trange(1, n_epochs + 1, desc="Epoch", disable=not verbose): 163 | for step, (xs, f, y) in enumerate(train_loader): 164 | u = self(xs) 165 | loss, ldict = self.loss(u, f, y, kld_w=len(xs) / len(Y), return_dict=True) 166 | ldict["epoch"], ldict["step"] = epoch, step 167 | history.append(ldict) 168 | optim.zero_grad() 169 | loss.backward() 170 | optim.step() 171 | 172 | self.is_trained = True 173 | return history 174 | 175 | @torch.inference_mode() 176 | def predict(self, Xs, scaling=None): 177 | return self(Xs, n_samples=None, predict=True, scaling=scaling).squeeze(2) 178 | 179 | @torch.inference_mode() 180 | def get_weights(self, Xs, ravel=False): 181 | """Get instance weights after and before softmax""" 182 | beta_u, _ = self.get_betas(predict=True) 183 | beta_u = beta_u.squeeze(2) # not taking mcmc samples 184 | if torch.is_tensor(Xs): 185 | _w = torch.einsum("niq,qp->nip", Xs, beta_u) 186 | w = torch.softmax(_w, dim=1) 187 | 188 | else: 189 | x, i, _ = setup_scatter(Xs) 190 | _w = torch.einsum("iq,qp->ip", x, beta_u) 191 | w = scatter_softmax(_w, i, dim=0) 192 | 193 | if ravel: 194 | w, _w = w.ravel(), _w.ravel() 195 | elif not torch.is_tensor(Xs): 196 | _w = [_w[i == idx] for idx in range(len(Xs))] 197 | w = [w[i == idx] for idx in range(len(Xs))] 198 | return w, _w 199 | 200 | def extra_repr(self): 201 | string = f"Q={self.Q}, K={self.alpha.shape[0]}, P={self.alpha.shape[1]}, likelihood={self.likelihood_name}" 202 | if self.likelihood_name == "binomial": 203 | string += f", n_trials={self.n_trials}" 204 | string += f", device={self.alpha.device}, trained={self.is_trained}" 205 | string += f"\n(alpha): Parameter(shape={tuple(self.alpha.shape)})\n" 206 | string += f"(log_sigma_u): Parameter(shape={tuple(self.log_sigma_u.shape)})\n" 207 | string += f"(log_sigma_z): Parameter(shape={tuple(self.log_sigma_z.shape)})" 208 | return string 209 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /experiments/histopathology_camelyon16.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Histopathology Experiment" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 10, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from mixmil.paths import DATA\n", 17 | "import pandas as pd\n", 18 | "from tqdm import tqdm\n", 19 | "from sklearn.preprocessing import StandardScaler\n", 20 | "import numpy as np\n", 21 | "import anndata as ad\n", 22 | "from mixmil import MixMIL\n", 23 | "import torch\n", 24 | "from sklearn.metrics import roc_auc_score\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "import scipy.stats as st" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "In this notebook we show how to apply MixMIL to the Camelyon dataset for the classification of histopathological slides as containing tumor or not. " 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Utility functions " 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "def to_device(el, device):\n", 50 | " \"\"\"\n", 51 | " Move a nested structure of elements (dict, list, tuple, torch.Tensor, torch.nn.Module) to the specified device.\n", 52 | "\n", 53 | " Parameters:\n", 54 | " - el: Element or nested structure of elements to be moved to the device.\n", 55 | " - device (torch.device): The target device, such as 'cuda' for GPU or 'cpu' for CPU.\n", 56 | "\n", 57 | " Returns:\n", 58 | " - Transferred element(s) in the same structure: Elements moved to the specified device.\n", 59 | " \"\"\"\n", 60 | " if isinstance(el, dict):\n", 61 | " return {k: to_device(v, device) for k, v in el.items()}\n", 62 | " elif isinstance(el, (list, tuple)):\n", 63 | " return [to_device(x, device) for x in el]\n", 64 | " elif isinstance(el, (torch.Tensor, torch.nn.Module)):\n", 65 | " return el.to(device)\n", 66 | " else:\n", 67 | " return el" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Data Loading" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "We start from reading the Camelyon dataset. The bag features are stored in multiple `.csv` files. The association of the files, their labels and whether they are train or test samples is stored in the `Camelyon16.csv` file." 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/html": [ 92 | "
| \n", 110 | " | label | \n", 111 | "split | \n", 112 | "
|---|---|---|
| file | \n", 115 | "\n", 116 | " | \n", 117 | " |
| 1-tumor/test_033.csv | \n", 122 | "1 | \n", 123 | "test | \n", 124 | "
| 0-normal/normal_148.csv | \n", 127 | "0 | \n", 128 | "train | \n", 129 | "
| 0-normal/test_095.csv | \n", 132 | "0 | \n", 133 | "test | \n", 134 | "
| 0-normal/normal_025.csv | \n", 137 | "0 | \n", 138 | "train | \n", 139 | "
| 0-normal/test_087.csv | \n", 142 | "0 | \n", 143 | "test | \n", 144 | "
| ... | \n", 147 | "... | \n", 148 | "... | \n", 149 | "
| 0-normal/normal_006.csv | \n", 152 | "0 | \n", 153 | "train | \n", 154 | "
| 1-tumor/tumor_003.csv | \n", 157 | "1 | \n", 158 | "train | \n", 159 | "
| 0-normal/normal_018.csv | \n", 162 | "0 | \n", 163 | "train | \n", 164 | "
| 1-tumor/tumor_017.csv | \n", 167 | "1 | \n", 168 | "train | \n", 169 | "
| 0-normal/test_067.csv | \n", 172 | "0 | \n", 173 | "test | \n", 174 | "
399 rows × 2 columns
\n", 178 | "| \n", 104 | " | label | \n", 105 | "split | \n", 106 | "
|---|---|---|
| file | \n", 109 | "\n", 110 | " | \n", 111 | " |
| 1-tumor/test_033.csv | \n", 116 | "1 | \n", 117 | "test | \n", 118 | "
| 0-normal/normal_148.csv | \n", 121 | "0 | \n", 122 | "train | \n", 123 | "
| 0-normal/test_095.csv | \n", 126 | "0 | \n", 127 | "test | \n", 128 | "
| 0-normal/normal_025.csv | \n", 131 | "0 | \n", 132 | "train | \n", 133 | "
| 0-normal/test_087.csv | \n", 136 | "0 | \n", 137 | "test | \n", 138 | "
| ... | \n", 141 | "... | \n", 142 | "... | \n", 143 | "
| 0-normal/normal_006.csv | \n", 146 | "0 | \n", 147 | "train | \n", 148 | "
| 1-tumor/tumor_003.csv | \n", 151 | "1 | \n", 152 | "train | \n", 153 | "
| 0-normal/normal_018.csv | \n", 156 | "0 | \n", 157 | "train | \n", 158 | "
| 1-tumor/tumor_017.csv | \n", 161 | "1 | \n", 162 | "train | \n", 163 | "
| 0-normal/test_067.csv | \n", 166 | "0 | \n", 167 | "test | \n", 168 | "
399 rows × 2 columns
\n", 172 | "