├── mober
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── data_utils.py
│   │   ├── projection.py
│   │   └── train.py
│   ├── loss
│   │   ├── vae.py
│   │   └── classification.py
│   ├── models
│   │   ├── mlp.py
│   │   ├── utils.py
│   │   └── batch_vae.py
│   └── mober.py
├── asset
│   ├── MOBER_logo.png
│   └── MOBER_model.png
├── setup.py
├── LICENSE
├── .gitignore
└── README.md

/mober/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
--------------------------------------------------------------------------------
/mober/core/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
--------------------------------------------------------------------------------
/asset/MOBER_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/MOBER/HEAD/asset/MOBER_logo.png
--------------------------------------------------------------------------------
/asset/MOBER_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Novartis/MOBER/HEAD/asset/MOBER_model.png
--------------------------------------------------------------------------------
/mober/loss/vae.py:
--------------------------------------------------------------------------------
import torch

from torch.distributions import Normal, kl_divergence
from torch.nn import functional


def loss_function_vae(dec, x, mu, stdev, kl_weight=1.0):
    # mean over genes, then sum over samples (cf. trVAE)

    mean = torch.zeros_like(mu)
    scale = torch.ones_like(stdev)

    KLD = kl_divergence(Normal(mu, stdev), Normal(mean, scale)).mean(dim=1)

    reconst_loss = functional.mse_loss(dec, x, reduction='none').mean(dim=1)

    return (reconst_loss + kl_weight * KLD).sum(dim=0)
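
# Usage sketch (hypothetical shapes, not part of the original API): with a
# batch of 8 samples and 1000 genes,
#   dec, x    -> (8, 1000) reconstruction and input
#   mu, stdev -> (8, 64)   posterior parameters from the encoder
# the function returns a scalar tensor: gene-wise means are taken per sample,
# then the per-sample losses are summed, so the value scales with batch size.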
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-


from setuptools import setup, find_packages

setup(
    name='mober',
    version='2.0.0',
    url='ssh://git@bitbucket.prd.nibr.novartis.net/ods/ods-mober.git',
    author='mober team',
    author_email='gang-6.li@novartis.com',
    description='mober',
    packages=find_packages(),
    install_requires=['mlflow', 'scanpy'],

    keywords='mober',

    entry_points={
        'console_scripts': ['mober = mober.mober:main']
    }
)
--------------------------------------------------------------------------------
/mober/loss/classification.py:
--------------------------------------------------------------------------------
import torch

from torch.nn import NLLLoss


def loss_function_classification(pred, target, class_weights):
    """
    Compute negative log likelihood loss.

    :param pred: predicted log-probabilities, shape (n_samples, n_classes)
    :param target: one-hot encoded actual classes, shape (n_samples, n_classes)
    :param class_weights: weights - one per class
    :return: Weighted prediction loss. Summed over all the samples, not averaged.
    """
    loss_function = NLLLoss(weight=class_weights, reduction="none")
    return loss_function(pred, torch.argmax(target, dim=1)).sum(dim=0)
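
# Note: NLLLoss expects log-probabilities, which the MLP in mober/models/mlp.py
# provides via LogSoftmax. A hypothetical call, assuming a 3-class problem:
#   pred    -> (n, 3) log-probabilities
#   target  -> (n, 3) one-hot source labels
#   weights -> torch.ones(3)
#   loss_function_classification(pred, target, weights)  # scalar, summed over n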
--------------------------------------------------------------------------------
/mober/models/mlp.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class MLP(nn.Module):
    """
    MLP module used for multiclass classification on the encodings.
    """
    def __init__(self, enc_dim, output_dim):
        super().__init__()
        self.activation = nn.SELU()
        self.fc1 = nn.Linear(enc_dim, enc_dim)
        self.bn1 = nn.BatchNorm1d(enc_dim, momentum=0.01, eps=0.001)
        self.fc2 = nn.Linear(enc_dim, enc_dim)
        self.bn2 = nn.BatchNorm1d(enc_dim, momentum=0.01, eps=0.001)
        self.fc3 = nn.Linear(enc_dim, output_dim)
        self.soft = nn.LogSoftmax(dim=1)

    def forward(self, x):
        out = self.fc1(x)
        out = self.bn1(out)
        out = self.activation(out)
        out = self.fc2(out)
        out = self.bn2(out)
        out = self.activation(out)
        out = self.fc3(out)
        out = self.soft(out)
        return out
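
# Example (hypothetical dimensions): MLP(enc_dim=64, output_dim=3) maps 64-d
# encodings to log-probabilities over 3 data sources, matching the NLLLoss
# used in mober/loss/classification.py.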
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Novartis

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/mober/core/utils.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import shutil
from pathlib import Path
import mlflow


def create_temp_dirs(tmp_dir):
    Path(os.path.join(tmp_dir, "models")).mkdir(parents=True, exist_ok=True)
    Path(os.path.join(tmp_dir, "metrics")).mkdir(parents=True, exist_ok=True)
    Path(os.path.join(tmp_dir, "projection")).mkdir(parents=True, exist_ok=True)


def remove_temp_dirs(tmp_dir):
    shutil.rmtree(tmp_dir)


class log_obj:
    """
    Thin logging wrapper: sends parameters and metrics to MLflow when enabled,
    otherwise writes them to tab-separated files under run_dir/metrics.
    """
    def __init__(self, use_mlflow, run_dir):
        self.use_mlflow = use_mlflow
        self.run_dir = run_dir
        self.fhands = {}

    def log_params(self, args):
        if self.use_mlflow: mlflow.log_params(vars(args))
        dfparams = pd.DataFrame(data=vars(args), index=['value']).transpose()
        dfparams.to_csv(os.path.join(self.run_dir, 'models', 'params.csv'))

    def log_metric(self, name, value, epoch):
        if self.use_mlflow: mlflow.log_metric(name, value, step=epoch)
        else:
            if name not in self.fhands:
                fhand = open(os.path.join(self.run_dir, 'metrics', name), 'w', buffering=1)
                fhand.write('epoch\tvalue\n')
                self.fhands[name] = fhand

            self.fhands[name].write(f'{epoch}\t{value}\n')

    def end_log(self):
        if self.use_mlflow:
            mlflow.log_artifacts(self.run_dir)
            mlflow.end_run()
            remove_temp_dirs(self.run_dir)
        else:
            for fhand in self.fhands.values(): fhand.close()
--------------------------------------------------------------------------------
/mober/models/utils.py:
--------------------------------------------------------------------------------
import torch

from torch import optim, nn


def create_model(model_cls, device, *args, filename=None, lr=1e-3, **kwargs):
    """
    Build a model and its Adam optimizer, optionally restoring both from a
    checkpoint to resume training from a given epoch.

    :param model_cls: Model definition
    :param device: Device (cpu or gpu)
    :param args: arguments to be passed to the model constructor
    :param filename: checkpoint filename if the model is to be loaded
    :param lr: learning rate to be used by the model optimizer
    :param kwargs: keyword arguments to be used by the model constructor
    :return: tuple of (model moved to device, optimizer)
    """
    model = model_cls(*args, **kwargs)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    if filename is not None:
        checkpoint = torch.load(filename, map_location=torch.device("cpu"))
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        model.load_state_dict(checkpoint["model_state_dict"])

        print(f"Loaded model epoch: {checkpoint['epoch']}, loss {checkpoint['loss']}")

    if device.type == "cuda" and torch.cuda.device_count() > 1:
        print("Loading model on ", torch.cuda.device_count(), "GPUs")
        model = nn.DataParallel(model)
    return model.to(device), optimizer


def save_model(model, optimizer, epoch, loss, filename, device):
    """
    Save the model to a file.

    :param model: model to be saved
    :param optimizer: model optimizer
    :param epoch: epoch number, informational only
    :param loss: loss value, informational only
    :param filename: where to save the model
    :param device: device of the model
    """
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()
    torch.save({
        "epoch": epoch,
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss
    }, filename)
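
# Round-trip sketch (hypothetical paths and dimensions, assuming BatchVAE from
# mober.models.batch_vae):
#   model, opt = create_model(BatchVAE, device, n_genes, 64, n_batch)
#   save_model(model, opt, epoch=10, loss=0.5, filename='ckpt.model', device=device)
#   model, opt = create_model(BatchVAE, device, n_genes, 64, n_batch,
#                             filename='ckpt.model')  # resumes from the checkpoint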
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/mober/core/data_utils.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd
import os

from scipy.sparse import csr_matrix

from sklearn.utils.class_weight import compute_class_weight

# modified from https://discuss.pytorch.org/t/sparse-dataset-and-dataloader/55466, credits to ironv
class SparseDataset(Dataset):
    def __init__(self, sp_matrix, label, device='cpu'):
        if not isinstance(sp_matrix, csr_matrix): csr = csr_matrix(sp_matrix)
        else: csr = sp_matrix

        self.dim = csr.shape
        self.device = torch.device(device)

        self.indptr = torch.tensor(csr.indptr, dtype=torch.int64, device=self.device)
        self.indices = torch.tensor(csr.indices, dtype=torch.int64, device=self.device)
        self.data = torch.tensor(csr.data, dtype=torch.float32, device=self.device)

        self.label = torch.tensor(label, dtype=torch.float32, device=self.device)

    def __len__(self):
        return self.dim[0]

    def __getitem__(self, idx):
        obs = torch.zeros((self.dim[1],), dtype=torch.float32, device=self.device)
        ind1, ind2 = self.indptr[idx], self.indptr[idx + 1]
        obs[self.indices[ind1:ind2]] = self.data[ind1:ind2]

        return obs, self.label[idx]


def get_class_weights(class_series, balanced_sources):
    sorted_classes = sorted(class_series.unique())
    source = class_series.astype(pd.CategoricalDtype(sorted_classes, ordered=True))
    src_weight_factors = np.ones(source.unique().shape)
    if balanced_sources:
        src_weight_factors = compute_class_weight("balanced", classes=sorted_classes, y=source)
    return src_weight_factors


def create_dataloaders_from_adata(adata, batch_size, val_set_size, random_seed, use_sparse_mat=False):

    assert 0 <= val_set_size < 1.0
    samples = adata.obs.index.values
    splt = int(adata.shape[0] * (1 - val_set_size))
    np.random.seed(random_seed)
    sample_inds = np.arange(len(samples))
    np.random.shuffle(sample_inds)
    np.random.seed()  # reset seed

    tr_samples = samples[sample_inds[:splt]]
    val_samples = samples[sample_inds[splt:]]

    label_encode = pd.get_dummies(sorted(adata.obs.data_source.unique()))
    label = pd.get_dummies(adata.obs.data_source)

    if use_sparse_mat:
        train_data = SparseDataset(adata[tr_samples, :].X, label.loc[tr_samples, :].values)
        if len(val_samples) > 0:
            val_data = SparseDataset(adata[val_samples, :].X, label.loc[val_samples, :].values)
    else:
        try: adata.X = adata.X.todense()
        except AttributeError: pass  # X is already dense
        train_data = TensorDataset(torch.Tensor(adata[tr_samples, :].X), torch.Tensor(label.loc[tr_samples, :].values))
        if len(val_samples) > 0:
            val_data = TensorDataset(torch.Tensor(adata[val_samples, :].X), torch.Tensor(label.loc[val_samples, :].values))

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    if len(val_samples) > 0: val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
    else: val_loader = None

    return train_loader, val_loader, label_encode
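
# Usage sketch (hypothetical values): for an AnnData `adata` with a
# 'data_source' column in adata.obs,
#   train_loader, val_loader, label_encode = create_dataloaders_from_adata(
#       adata, batch_size=1600, val_set_size=0.1, random_seed=100)
# label_encode is an identity-like one-hot frame over the sorted data_source
# labels; it is saved at training time and reused later for projection.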
--------------------------------------------------------------------------------
/mober/models/batch_vae.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
from torch.distributions import Normal


class Encoder(nn.Module):
    """
    Encoder that takes the original gene expression and produces the encoding.

    Consists of 3 FC layers.
    """
    def __init__(self, n_genes, enc_dim):
        super().__init__()
        self.activation = nn.SELU()
        self.fc1 = nn.Linear(n_genes, 256)
        self.bn1 = nn.BatchNorm1d(256, momentum=0.01, eps=0.001)
        self.dp1 = nn.Dropout(p=0.1)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128, momentum=0.01, eps=0.001)
        self.dp2 = nn.Dropout(p=0.1)

        self.linear_means = nn.Linear(128, enc_dim)
        self.linear_log_vars = nn.Linear(128, enc_dim)

    def reparameterize(self, means, stdev):

        return Normal(means, stdev).rsample()

    def encode(self, x):
        # encode
        enc = self.fc1(x)
        enc = self.bn1(enc)
        enc = self.activation(enc)
        enc = self.dp1(enc)
        enc = self.fc2(enc)
        enc = self.bn2(enc)
        enc = self.activation(enc)
        enc = self.dp2(enc)

        means = self.linear_means(enc)
        log_vars = self.linear_log_vars(enc)

        stdev = torch.exp(0.5 * log_vars) + 1e-4
        z = self.reparameterize(means, stdev)

        return means, stdev, z

    def forward(self, x):
        return self.encode(x)


class Decoder(nn.Module):
    """
    A decoder model that takes the encodings and a batch (source) matrix and produces decodings.

    Made up of 3 FC layers.
    """
    def __init__(self, n_genes, enc_dim, n_batch):
        super().__init__()
        self.activation = nn.SELU()
        self.final_activation = nn.ReLU()
        self.fcb = nn.Linear(n_batch, n_batch)
        self.bnb = nn.BatchNorm1d(n_batch, momentum=0.01, eps=0.001)
        self.fc4 = nn.Linear(enc_dim + n_batch, 128)
        self.bn4 = nn.BatchNorm1d(128, momentum=0.01, eps=0.001)
        self.fc5 = nn.Linear(128, 256)
        self.bn5 = nn.BatchNorm1d(256, momentum=0.01, eps=0.001)

        self.out_fc = nn.Linear(256, n_genes)


    def forward(self, z, batch):
        # batch input
        b = self.fcb(batch)
        b = self.bnb(b)
        b = self.activation(b)

        # concat with z
        n_z = torch.cat([z, b], dim=1)

        # decode layers
        dec = self.fc4(n_z)
        dec = self.bn4(dec)
        dec = self.activation(dec)
        dec = self.fc5(dec)
        dec = self.bn5(dec)
        dec = self.activation(dec)
        dec = self.final_activation(self.out_fc(dec))

        return dec


class BatchVAE(nn.Module):
    """
    Batch Autoencoder.
    Encoder is composed of 3 FC layers.
    Decoder is symmetrical to encoder + Batch input.
    """

    def __init__(self, n_genes, enc_dim, n_batch):
        super().__init__()

        self.encoder = Encoder(n_genes, enc_dim)
        self.decoder = Decoder(n_genes, enc_dim, n_batch)

    def forward(self, x, batch):
        means, stdev, enc = self.encoder(x)
        dec = self.decoder(enc, batch)

        return dec, enc, means, stdev
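
# Shape sketch for one forward pass (hypothetical sizes):
#   x: (n, n_genes) expression, batch: (n, n_batch) one-hot source matrix
#   dec, enc, means, stdev = BatchVAE(n_genes, 64, n_batch)(x, batch)
#   dec -> (n, n_genes) non-negative reconstruction (final ReLU)
#   enc -> (n, 64) sampled latent z; means/stdev -> (n, 64)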
--------------------------------------------------------------------------------
/mober/core/projection.py:
--------------------------------------------------------------------------------
import os
import torch

import numpy as np
import pandas as pd

from mober.core import data_utils
import scanpy as sc
from scipy.sparse import csr_matrix

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

from mober.models import utils as model_utils
from mober.models.batch_vae import BatchVAE

def decode(data_loader, model, device, decimals):
    """
    Get decodings numpy array from a data loader of data and a trained model

    :param data_loader: data loader that returns data and batch (source) annotations
    :param model: trained model
    :param device: device - cpu or gpu
    :param decimals: number of decimal places used to round the outputs
    :return: a tuple of numpy arrays, one of decodings and another with encodings
    """
    decoded = []
    encoded = []
    model.eval()
    with torch.no_grad():
        for data, batch in data_loader:
            data = data.to(device)
            batch = batch.to(device)
            dec, enc = model(data, batch)[:2]
            encoded.append(enc)
            decoded.append(dec)

    encoded = torch.cat(encoded, dim=0).detach().cpu().numpy().round(decimals=decimals)
    decoded = torch.cat(decoded, dim=0).detach().cpu().numpy().round(decimals=decimals)

    return decoded, encoded

def load_model(model_dir, device):
    features = pd.read_csv(os.path.join(model_dir, 'features.csv'), index_col=0).index
    label_encode = pd.read_csv(os.path.join(model_dir, 'label_encode.csv'), index_col=0)
    params = pd.read_csv(os.path.join(model_dir, 'params.csv'), index_col=0)

    model, _ = model_utils.create_model(BatchVAE,
                                        device,
                                        features.shape[0],
                                        int(params.loc['encoding_dim', 'value']),
                                        label_encode.shape[0],
                                        filename=os.path.join(model_dir, 'batch_ae_final.model'))
    return model, features, label_encode

def do_projection(model, adata, onto, label_encode, device, decimals=4, batch_size=1600, use_sparse_mat=False):

    label = np.array([label_encode[onto].values for _ in range(adata.shape[0])])

    if use_sparse_mat: dataset = data_utils.SparseDataset(adata.X, label)
    else:
        try: X = adata.X.todense()
        except AttributeError: X = adata.X  # X is already dense
        dataset = TensorDataset(torch.Tensor(X), torch.Tensor(label))

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    projected, z = decode(data_loader, model, device, decimals)

    if use_sparse_mat: proj_adata = sc.AnnData(csr_matrix(projected), obs=adata.obs, var=adata.var)
    else: proj_adata = sc.AnnData(projected, obs=adata.obs, var=adata.var)
    proj_adata.obs['projected_onto'] = onto
    z_adata = sc.AnnData(z, obs=adata.obs, var=pd.DataFrame(index=[f'z_{i}' for i in range(z.shape[1])]))

    return proj_adata, z_adata


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    adata = sc.read(args.projection_file)
    model, features, label_encode = load_model(args.model_dir, device)
    adata = adata[:, features]

    proj_adata, z_adata = do_projection(model, adata, args.onto, label_encode, device, decimals=args.decimals, batch_size=1600)
    proj_adata.write(args.output_file)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![img](asset/MOBER_logo.png)

**MOBER** (Multi Origin Batch Effect Remover) is a deep learning-based method that performs biologically relevant integration of transcriptional profiles from pre-clinical models and clinical tumors. MOBER can be used to guide the selection of cell lines and patient-derived xenografts and to identify models that more closely resemble clinical tumors. We applied MOBER to transcriptional profiles from 932 cancer cell lines, 442 patient-derived xenografts and 11205 clinical tumors and identified pre-clinical models with the greatest transcriptional fidelity to clinical tumors, as well as models that are transcriptionally unrepresentative of their respective clinical tumors. MOBER is interpretable by design, allowing drug hunters to better understand the underlying biological differences between models and patients that are responsible for the observed lack of clinical translatability.
MOBER can remove batch effects between any transcriptomics datasets of different origin while conserving relevant biological signals.


![img](asset/MOBER_model.png)

See our latest [manuscript](https://doi.org/10.1101/2022.09.07.506964) and check our [web app](https://mober.pythonanywhere.com/), where the aligned data on cancer cell lines, patient-derived xenografts and clinical tumors can be explored interactively.


### Installing MOBER
1. CUDA and PyTorch
Find the available CUDA version with `module avail cuda`. Install [PyTorch](https://pytorch.org/) according to the latest CUDA version you found.

2. Install mober
```linux
git clone https://github.com/Novartis/mober.git
cd mober
pip install -e .
```

Check that it is successfully installed: run `mober --help` in the terminal from any directory.


### 1. Preparing input h5ad file for training
The input file should be in [anndata](https://anndata.readthedocs.io/en/latest/) format and saved as h5ad. In the file, the column "**data_source**", which specifies the batch ID of each sample, is **required** in the sample annotation `.obs`. The h5ad file can be generated in two ways:

##### 1.1 For R users:
```R
# One possible route (a sketch, assuming the SeuratDisk package): save the
# Seurat object to h5Seurat, then convert to h5ad. 'data_source' must be a
# column in the object's meta.data.
library(SeuratDisk)
seurat_obj$data_source <- "CCLE"  # example batch ID
SaveH5Seurat(seurat_obj, filename = "name.h5Seurat")
Convert("name.h5Seurat", dest = "h5ad")
```

##### 1.2 For Python users:
```python
import scanpy as sc
from scipy.sparse import csr_matrix
# X, expression matrix, samples x genes
# sampInfo, pd.DataFrame, with 'data_source' as one of the columns, and sample IDs as index
# geneInfo, pd.DataFrame, with gene ids as index
# X, sampInfo, geneInfo should be matched, in terms of sample order and gene order.
adata = sc.AnnData(csr_matrix(X), obs=sampInfo, var=geneInfo)
adata.write('name.h5ad')
```
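
Before training, it can be worth a quick sanity check of the file (a minimal sketch; `name.h5ad` is the file produced above):
```python
import scanpy as sc

adata = sc.read('name.h5ad')
assert 'data_source' in adata.obs.columns  # the required batch ID column
print(adata.shape)
print(adata.obs.data_source.value_counts())
```
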
### 2. Train MOBER
```linux
mober train \
--train_file input.h5ad \
--output_dir ../tmp_data/test
```
In this case, the trained model (together with `params.csv`, `features.csv` and `label_encode.csv`) will be in `../tmp_data/test/models`, and the training metrics will be in `../tmp_data/test/metrics`, in tsv format.


### 3. Do projection
Once the model is trained, the projection can be done in two different ways:
#### 3.1 Through the command line
```linux
mober projection \
--model_dir path_to_where_models_and_metrics_folders_are/models \
--onto TCGA \
--projection_file input.h5ad \
--output_file outname.h5ad \
--decimals 4
```
The value of `--onto` should be one of the batch IDs used in training.

#### 3.2 Within Python scripts, as the projection step is fast and does not need a GPU
```python
from mober.core.projection import load_model, do_projection
import scanpy as sc
import torch

model_dir = 'path_to_where_models_and_metrics_folders_are/models'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
adata = sc.read('projection_file.h5ad')
model, features, label_encode = load_model(model_dir, device)
adata = adata[:,features]

onto = 'TCGA'  # one of the batch IDs used in training
proj_adata, z_adata = do_projection(model, adata, onto, label_encode, device, batch_size=1600)
proj_adata.write('outname.h5ad')

# proj_adata contains the projected values.
# z_adata contains the sample embeddings in the latent space

```
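
To eyeball how well the sources are mixed after alignment, one option (a sketch, assuming the standard scanpy plotting API and that `data_source` is present in `.obs`) is a UMAP of the latent embeddings:
```python
import scanpy as sc

# z_adata comes from do_projection above; color by batch ID to inspect mixing
sc.pp.neighbors(z_adata, use_rep='X')
sc.tl.umap(z_adata)
sc.pl.umap(z_adata, color='data_source')
```
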
### Get help about input arguments
1. Train
```linux
mober train --help
```

2. Projection
```linux
mober projection --help
```


### Use GPU on HPC, minimal script
Copy and modify the following content into a text file, e.g. `sub.sh`, then run `qsub sub.sh` to submit the job to the HPC.
```linux
#!/bin/bash
#$ -cwd
#$ -S /bin/bash
#$ -l m_mem_free=32G
#$ -l h_rt=24:00:00
#$ -l gpu_card=4
#$ -m e
#$ -M your@email.com
#$ -N mober
#$ -o running.log
#$ -e error.log
#$ -V
#$ -b n


conda activate yourENV
module load cuda10.2/fft/10.2.89  # found via module avail cuda

mober train \
--train_file path/to/your/input.h5ad \
--output_dir output_path

```

## License

This project is licensed under the terms of the MIT License.
Copyright 2022 Novartis International AG.


## Reference

If you use MOBER in your research, please consider citing our [manuscript](https://doi.org/10.1101/2022.09.07.506964):

```
@article {Dimitrieva2022.09.07.506964,
	author = {Dimitrieva, Slavica and Janssens, Rens and Li, Gang and Szalata, Artur and Gopal, Raja and Parmar, Chintan and Kauffmann, Audrey and Durand, Eric Y.},
	title = {Biologically relevant integration of transcriptomics profiles from cancer cell lines, patient-derived xenografts and clinical tumors using deep learning},
	elocation-id = {2022.09.07.506964},
	year = {2022},
	doi = {10.1101/2022.09.07.506964},
	publisher = {Cold Spring Harbor Laboratory},
	URL = {https://www.biorxiv.org/content/10.1101/2022.09.07.506964v2},
	eprint = {https://www.biorxiv.org/content/10.1101/2022.09.07.506964v2.full.pdf},
	journal = {bioRxiv}
}
```

--------------------------------------------------------------------------------
/mober/mober.py:
--------------------------------------------------------------------------------
import sys
import argparse

def main():
    parser = argparse.ArgumentParser(prog='mober',
                                     description='''MOBER is a deep learning method that allows for identification of \
cancer models (CCLEs and PTXs) that are closest to patient tumors of interest based on \
mRNA expression data, without relying on annotated disease labels. It projects one dataset \
onto another, e.g. transforming the transcriptional profiles of cancer models (CCLEs and PTXs) \
into TCGA patient tumors.''')

    subparsers = parser.add_subparsers(dest='mode')


    ###################### Train Mode #######################
    tparser = subparsers.add_parser('train', help='Train MOBER')

    tparser.add_argument(
        "--train_file",
        metavar='',
        help="A h5ad file that contains all the samples as well as a 'data_source' column."
    )

    tparser.add_argument(
        "--use_sparse_mat",
        action="store_true",
        help="Whether to use the sparse dataloader. Can be used when the training dataset is huge. Default False"
    )

    tparser.add_argument(
        "--src_adv_weight",
        type=float,
        metavar='',
        default=0.01,
        help="Weight of the source adversary loss. Default 0.01",
    )

    tparser.add_argument(
        "--src_adv_lr",
        type=float,
        metavar='',
        help="Learning rate of the source adversary. Default 1e-3",
        default=1e-3
    )

    tparser.add_argument(
        "--batch_ae_lr",
        type=float,
        metavar='',
        help="Learning rate of the batch VAE. Default 1e-3",
        default=1e-3
    )

    tparser.add_argument(
        "--val_set_size",
        type=float,
        metavar='',
        help="Fraction of samples that constitute the validation set. Default 0.0",
        default=0.0
    )

    tparser.add_argument(
        "--encoding_dim",
        type=int,
        default=64,
        metavar='',
        help="Size of the embeddings. Default 64",
    )

    tparser.add_argument(
        "--balanced_sources_ae",
        action="store_true",
        help="Flag that enables sample weights to balance according to the source in the ae loss. Default False"
    )

    tparser.add_argument(
        "--balanced_sources_src_adv",
        action="store_true",
        help="Flag that enables sample weights to balance according to the source in the source adversary loss. Default False"
    )

    tparser.add_argument(
        "--batch_size",
        type=int,
        metavar='',
        default=1600,
        help='Default 1600'
    )

    tparser.add_argument(
        "--epochs",
        type=int,
        metavar='',
        default=3000,
        help='Max number of training epochs. Default 3000. Early stopping implemented.'
    )
    tparser.add_argument(
        "--random_seed",
        type=int,
        default=100,
        metavar='',
        help='Default 100'
    )


    tparser.add_argument(
        "--kl_weight",
        type=float,
        default=1e-5,
        metavar='',
        help='Weight for KL loss. Default 1e-5.'
    )

    tparser.add_argument(
        "--patience",
        type=int,
        default=100,
        metavar='',
        help="Number of patience epochs for early stopping. Default 100"
    )

    tparser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        metavar='',
        help="Output path in case MLflow is not used."
    )


    ##### MLflow arguments ####
    tparser.add_argument(
        "--use_mlflow",
        action="store_true",
        help="Use if all results are to be tracked by MLflow. Default False"
    )

    tparser.add_argument(
        "--mlflow_storage_path",
        type=str,
        metavar='',
        default='http://nrchbs-ldl31318.nibr.novartis.net:5000',
        help='Default: http://nrchbs-ldl31318.nibr.novartis.net:5000'
    )

    tparser.add_argument(
        "--experiment_name",
        type=str,
        default="mober",
        metavar='',
        help='Experiment name for MLflow. Default mober'
    )

    tparser.add_argument(
        "--run_name",
        type=str,
        default="run",
        metavar='',
        help="Run name for MLflow. Default run"
    )

    tparser.add_argument(
        "--tmp_dir",
        type=str,
        default="tmp",
        metavar='',
        help="Temporary directory for MLflow. Default ./tmp"
    )


    ###################### Projection Mode #######################
    pparser = subparsers.add_parser('projection', help='Make projection')

    pparser.add_argument(
        "--model_dir",
        type=str,
        metavar='',
        help="Path to the directory with the trained model files",
    )

    pparser.add_argument(
        "--onto",
        type=str,
        metavar='',
        help="The target 'data_source' ID which all samples will be projected onto",
    )

    pparser.add_argument(
        "--projection_file",
        type=str,
        metavar='',
        help="An input file that contains the gene expression of all samples. h5ad format"
    )

    pparser.add_argument(
        "--output_file",
        type=str,
        metavar='',
        help="Name of the output h5ad file for projected values",
    )

    pparser.add_argument(
        "--decimals",
        type=int,
        default=4,
        metavar='',
        help="Number of decimal places for values in the output file. Default 4",
    )

    args = parser.parse_args()

    if args.mode == 'train':
        from mober.core import train
        train.main(args)

    if args.mode == 'projection':
        from mober.core import projection
        projection.main(args)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/mober/core/train.py:
--------------------------------------------------------------------------------
import os
import time
import torch

from mober.core import utils
from mober.core import data_utils
from mober.models import utils as model_utils

import pandas as pd
import numpy as np
import numexpr
import mlflow

from mober.models.batch_vae import BatchVAE
from mober.models.mlp import MLP

from mober.loss.classification import loss_function_classification
from mober.loss.vae import loss_function_vae

import scanpy as sc

def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def validation(model_BatchAE, model_src_adv, val_loader, device, args, log, src_weights_src_adv, epoch):
    model_BatchAE.eval()
    model_src_adv.eval()

    epoch_ae_loss_val = 0.0
    epoch_src_adv_loss_val = 0.0
    epoch_tot_loss_val = 0.0

    with torch.no_grad():
        for data, batch in val_loader:

            data = data.to(device)
            batch = batch.to(device)

            dec, enc, means, stdev = model_BatchAE(data, batch)
            v_loss = loss_function_vae(dec, data, means, stdev, kl_weight=args.kl_weight)

            src_pred = model_src_adv(enc)
            loss_src_adv = loss_function_classification(src_pred, batch, src_weights_src_adv)
            loss_ae = v_loss - args.src_adv_weight * loss_src_adv

            epoch_ae_loss_val += v_loss.detach().item()
            epoch_src_adv_loss_val += loss_src_adv.detach().item()
            epoch_tot_loss_val += loss_ae.detach().item()

    log.log_metric("val_loss_ae", epoch_ae_loss_val / len(val_loader.dataset), epoch)
    log.log_metric("val_loss_adv", epoch_src_adv_loss_val / len(val_loader.dataset), epoch)
    log.log_metric("val_loss_tot", epoch_tot_loss_val / len(val_loader.dataset), epoch)

    return epoch_ae_loss_val

def train_model(model_BatchAE,
                optimizer_BatchAE,
                model_src_adv,
                optimizer_src_adv,
                train_loader,
                val_loader,
                src_weights_src_adv,
                run_dir,
                device,
                log,
                args):


    # Early stopping settings
    best_model_loss = np.inf
    waited_epochs = 0
    early_stop = False

    ae_model_file = os.path.join(run_dir, "models", "batch_ae_final.model")
    src_model_file = os.path.join(run_dir, "models", "src_adv_final.model")

    for epoch in range(args.epochs):
        if early_stop: break

        epoch_ae_loss = 0.0
        epoch_src_adv_loss = 0.0
        epoch_tot_loss = 0.0

        model_BatchAE.train()
        model_src_adv.train()
        for data, batch in train_loader:
            data = data.to(device)
            batch = batch.to(device)

            dec, enc, means, stdev = model_BatchAE(data, batch)
            v_loss = loss_function_vae(dec, data, means, stdev, kl_weight=args.kl_weight)
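
            # Adversarial update, two phases per mini-batch (a description of the
            # existing logic, not new behavior): first the source classifier is
            # trained to predict the batch ID from the encoding; retain_graph=True
            # keeps the VAE graph alive for the second backward pass below. The
            # classifier is then re-run on the same encodings, and the VAE is
            # updated to reduce reconstruction loss while *increasing* the
            # classifier loss (the subtraction in loss_ae), pushing the encodings
            # to be source-indistinguishable.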
            # Source adversary
            model_src_adv.zero_grad()

            src_pred = model_src_adv(enc)

            loss_src_adv = loss_function_classification(src_pred, batch, src_weights_src_adv)
            loss_src_adv.backward(retain_graph=True)
            epoch_src_adv_loss += loss_src_adv.detach().item()
            optimizer_src_adv.step()

            src_pred = model_src_adv(enc)
            loss_src_adv = loss_function_classification(src_pred, batch, src_weights_src_adv)

            # Update ae
            model_BatchAE.zero_grad()
            loss_ae = v_loss - args.src_adv_weight * loss_src_adv
            loss_ae.backward()
            epoch_ae_loss += v_loss.detach().item()
            optimizer_BatchAE.step()

            epoch_tot_loss += loss_ae.detach().item()

        log.log_metric("train_loss_ae", epoch_ae_loss / len(train_loader.dataset), epoch)
        log.log_metric("train_loss_adv", epoch_src_adv_loss / len(train_loader.dataset), epoch)
        log.log_metric("train_loss_tot", epoch_tot_loss / len(train_loader.dataset), epoch)
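
        # Checkpointing note (existing behavior): with a validation set, only the
        # best-validation-loss models are kept and training stops after
        # args.patience epochs without improvement; with val_set_size == 0, the
        # latest models are re-saved every epoch.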

        # Validation
        if args.val_set_size != 0:
            epoch_ae_loss_val = validation(model_BatchAE, model_src_adv, val_loader, device, args, log, src_weights_src_adv, epoch)

            # Early stop
            if epoch_ae_loss_val < best_model_loss:  # there is an improvement, update the best_val_loss and save the model
                best_model_loss = epoch_ae_loss_val
                waited_epochs = 0
                model_utils.save_model(model_BatchAE, optimizer_BatchAE, epoch, epoch_ae_loss / len(train_loader.dataset), ae_model_file, device)
                model_utils.save_model(model_src_adv, optimizer_src_adv, epoch, epoch_src_adv_loss / len(train_loader.dataset), src_model_file, device)

            else:
                waited_epochs += 1
                if waited_epochs > args.patience: early_stop = True

        if args.val_set_size == 0:
            model_utils.save_model(model_BatchAE, optimizer_BatchAE, epoch, epoch_ae_loss / len(train_loader.dataset), ae_model_file, device)
            model_utils.save_model(model_src_adv, optimizer_src_adv, epoch, epoch_src_adv_loss / len(train_loader.dataset), src_model_file, device)

def main(args):
    if args.use_mlflow:
        run_dir = os.path.join(args.tmp_dir, str(int(time.time())))
        mlflow.set_tracking_uri(args.mlflow_storage_path)
        mlflow.set_experiment(args.experiment_name)
        mlflow.start_run(run_name=args.run_name)
    else:
        run_dir = args.output_dir

    utils.create_temp_dirs(run_dir)


    log = utils.log_obj(args.use_mlflow, run_dir)
    log.log_params(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    numexpr.set_num_threads(numexpr.detect_number_of_cores())

    adata = sc.read(args.train_file)

    train_loader, val_loader, label_encode = data_utils.create_dataloaders_from_adata(adata,
                                                                                      args.batch_size,
                                                                                      args.val_set_size,
                                                                                      args.random_seed,
                                                                                      args.use_sparse_mat)
    # save features and label encoding
    features = adata.var.index.to_frame()
    label_encode.to_csv(os.path.join(run_dir, 'models', 'label_encode.csv'))
    features.to_csv(os.path.join(run_dir, 'models', 'features.csv'))


    set_seed(args.random_seed)


    model_BatchAE, optimizer_BatchAE = model_utils.create_model(BatchVAE,
                                                                device,
                                                                features.shape[0],
                                                                args.encoding_dim,
                                                                label_encode.shape[0],
                                                                lr=args.batch_ae_lr,
                                                                filename=None)

    model_src_adv, optimizer_src_adv = model_utils.create_model(MLP,
                                                                device,
                                                                args.encoding_dim,
                                                                label_encode.shape[0],
                                                                lr=args.src_adv_lr,
                                                                filename=None)


    src_weights_src_adv = torch.tensor(
        data_utils.get_class_weights(adata.obs.data_source, args.balanced_sources_src_adv),
        dtype=torch.float).to(device)


    train_model(model_BatchAE,
                optimizer_BatchAE,
                model_src_adv,
                optimizer_src_adv,
                train_loader,
                val_loader,
                src_weights_src_adv,
                run_dir,
                device,
                log,
                args)

    log.end_log()

--------------------------------------------------------------------------------