├── .gitignore ├── LICENSE ├── README.md ├── cellnet ├── __init__.py ├── datamodules.py ├── estimators.py ├── models.py ├── tabnet │ ├── __init__.py │ ├── sparsemax.py │ └── tab_network.py └── utils │ ├── __init__.py │ ├── cell_ontology.py │ ├── data_loading.py │ └── tabnet_explain.py ├── docs ├── classification-evaluation-metrics.md ├── data.md └── models.md ├── notebooks-tutorials ├── data_loading.ipynb ├── model_inference.ipynb └── model_inference.md ├── notebooks ├── data_augmentation │ ├── data_augmentation.ipynb │ ├── evaluation.ipynb │ ├── shortend_cell_types.yaml │ └── visualize_augmentations.ipynb ├── loss_curve_plotting │ └── tabnet-vs-linear.ipynb ├── model_evaluation │ ├── __init__.py │ ├── classification-CIForm.ipynb │ ├── classification-celltypist.ipynb │ ├── classification-linear.ipynb │ ├── classification-sc-foundation-models.ipynb │ ├── classification-tabnet-ensembl.ipynb │ ├── classification-tabnet.ipynb │ ├── classification-xgboost.ipynb │ ├── integration.ipynb │ ├── model-comparision.ipynb │ ├── model-improvement.ipynb │ ├── model-scaling.ipynb │ ├── predictions.py │ ├── shortend_cell_types.yaml │ ├── tabnet-attention-masks.ipynb │ ├── tabnet-coarse-celltypes.ipynb │ ├── tabnet-detailed-eval.ipynb │ ├── tabnet_non_10X.ipynb │ └── utils.py ├── store_creation │ ├── 01_download_data.ipynb │ ├── 02_create_train_val_test_split.ipynb │ ├── 03_write_store_merlin.ipynb │ ├── 04_create_hierarchy_matrices.ipynb │ ├── 05_compute_pca.ipynb │ ├── 06_check_written_store.ipynb │ ├── 07_data_summary.ipynb │ ├── features.parquet │ └── subsetted_data │ │ ├── subset_to_lung_only.ipynb │ │ ├── write_store_merlin_lung_only.ipynb │ │ └── write_store_merlin_subset.ipynb └── training │ ├── scGPT-finetuning.ipynb │ ├── train_celltypist.ipynb │ ├── train_linear.ipynb │ ├── train_mlp.ipynb │ ├── train_tabnet.ipynb │ ├── train_xgboost.ipynb │ └── zheng68k.ipynb ├── requirements-gpu.txt ├── requirements.txt ├── scripts ├── create_venv.sh ├── py_scripts │ ├── CIForm.py │ ├── scGPT-inference.py │ ├── train_linear.py │ ├── train_mlp.py │ ├── train_tabnet.py │ ├── train_xgboost.py │ └── utils.py ├── scGPT-inference.sh ├── train_CIForm.sh ├── train_linear_jsc.sh ├── train_linear_lrz.sh ├── train_mlp_lrz.sh ├── train_tabnet_jsc.sh ├── train_tabnet_lrz.sh └── train_xgboost_lrz.sh ├── setup.py ├── tests └── test_sample.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | .idea/ 4 | .ipynb_checkpoints/ 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Packages 10 | *.egg 11 | *.egg-info 12 | dist 13 | build 14 | eggs 15 | parts 16 | bin 17 | var 18 | sdist 19 | develop-eggs 20 | .installed.cfg 21 | lib 22 | lib64 23 | __pycache__ 24 | 25 | # Installer logs 26 | pip-log.txt 27 | 28 | # Unit test / coverage reports 29 | .coverage 30 | .tox 31 | nosetests.xml 32 | 33 | # Translations 34 | *.mo 35 | 36 | # Mr Developer 37 | .mr.developer.cfg 38 | .project 39 | .pydevproject 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023 Felix Fischer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | 
copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | scTab 2 | ======= 3 | De novo cell type prediction model for single-cell RNA-seq data that can be trained across a large-scale collection of 4 | curated datasets. 5 | 6 | Model checkpoints and traning data 7 | ----- 8 | * Training data (compatible with Merlin Dataloader infrastructure): https://pklab.med.harvard.edu/felix/data/merlin_cxg_2023_05_15_sf-log1p.tar.gz (164GB) 9 | * Model checkpoints: https://pklab.med.harvard.edu/felix/data/scTab-checkpoints.tar.gz (8.1GB) 10 | * Minimal subset of the training, validation and test data: https://pklab.med.harvard.edu/felix/data/merlin_cxg_2023_05_15_sf-log1p_minimal.tar.gz (0.5GB) 11 | 12 | Project structure 13 | ----- 14 | * ``cellnet``: code for models + data loading infrastructure 15 | * ``docs``: 16 | * ``data.md``: Details about data preparation 17 | * ``models.md``: Details about used models 18 | * ``classification-evaluation-metrics.md``: Details about used evaluation metrics 19 | * ``notebooks``: 20 | * ``data_augmentation``: Notebooks related to data augmentation → calculation of augmentation vectors + 21 | evaluation 22 | * ``model_evaluation``: Notebooks containing all evaluation code from this paper 23 | * ``loss_curve_plotting``: Notebooks to plot and compare loss curves 24 | * ``store_creation``: Notebooks used to create and reproduce the datasets used in this paper 25 | * ``training``: Notebooks to train models 26 | * ``notebooks-tutorials``: 27 | * ``data_loading.ipynb``: Example notebook about how to use data loading 28 | * ``model_inference.ipynb``: Example notebook how to use trained models for inference 29 | * ``scripts``: Scripts used to train models 30 | 31 | Installation 32 | ------------ 33 | 34 | ### Installation via Nvidia Enroot / Docker (easy) 35 | A base docker image with most packages preinstalled can be pulled from here: 36 | nvcr.io/nvidia/merlin/merlin-pytorch:23.02 37 | 38 | Moreover, the Nvidia Enroot (https://github.com/NVIDIA/enroot) container image which was used to run all the experiments 39 | in this paper can be found to download here: https://pklab.med.harvard.edu/felix/data/merlin-2302.sqsh 40 | 41 | For ease of use, we recommend to use the above supplied Enroot container image as it comes with all relevant software 42 | preinstalled. 43 | 44 | ### Installation via pip 45 | Run the following command the project folder to install the ``cellnet`` package: 46 | ``pip install -e .`` 47 | 48 | To install GPU dependencies install the dependencies from the ``requirements-gpu.txt`` file first. 
49 | To do so, use ``--extra-index-url https://pypi.nvidia.com/`` argument when installing packages via pip. 50 | 51 | Installation time on a local computer should be a couple of minutes. 52 | 53 | System requirements 54 | ------------ 55 | Operating system: Ubuntu 20.04.5 LTS (used OS version)\ 56 | Python version: 3.8 or 3.10\ 57 | Packages: See requirements.txt and requirements-gpu.txt 58 | 59 | Hardware requirements 60 | ------------ 61 | Due to high computational demands, a modern GPU (e.g. Nvidia A100 or V100 GPU with at least 16GB of VRAM) is needed to 62 | run the training and evaluation scripts in this repository.\ 63 | On a normal desktop computer without GPU acceleration runtime will probably exceed several days. 64 | 65 | Licence 66 | ------- 67 | MIT license 68 | 69 | Authors 70 | ------- 71 | `scTab` was written by `Felix Fischer ` 72 | 73 | Support for software development, testing, modeling, and benchmarking provided by the Cell Annotation Platform team 74 | (Roman Mukhin, Andrey Isaev, Uğur Bayındır) 75 | 76 | Citation 77 | -------- 78 | If scTab is helpful in your research, please consider citing the following [paper](https://www.nature.com/articles/s41467-024-51059-5) 79 | 80 | ``` 81 | Fischer, Felix, David S. Fischer, Roman Mukhin, Andrey Isaev, Evan Biederstedt, Alexandra-Chloé Villani, and Fabian J. Theis. 2024. “scTab: Scaling Cross-Tissue Single-Cell Annotation Models.” Nature Communications 15 (1). https://doi.org/10.1038/s41467-024-51059-5. 82 | ``` 83 | -------------------------------------------------------------------------------- /cellnet/__init__.py: -------------------------------------------------------------------------------- 1 | """cellnet - Scaling single cell models to bigger data sets.""" 2 | 3 | __version__ = '0.1.0' 4 | __author__ = 'Felix Fischer ' 5 | __all__ = [] 6 | -------------------------------------------------------------------------------- /cellnet/datamodules.py: -------------------------------------------------------------------------------- 1 | import os 2 | from math import ceil 3 | from os.path import join 4 | from typing import Dict, List 5 | 6 | import lightning.pytorch as pl 7 | import merlin.io 8 | from merlin.dataloader.torch import Loader 9 | from merlin.dtypes import boolean 10 | from merlin.dtypes import float32, int64 11 | from merlin.schema import ColumnSchema, Schema 12 | 13 | 14 | PARQUET_SCHEMA = { 15 | 'X': float32, 16 | 'soma_joinid': int64, 17 | 'is_primary_data': boolean, 18 | 'dataset_id': int64, 19 | 'donor_id': int64, 20 | 'assay': int64, 21 | 'cell_type': int64, 22 | 'development_stage': int64, 23 | 'disease': int64, 24 | 'tissue': int64, 25 | 'tissue_general': int64, 26 | 'tech_sample': int64, 27 | 'idx': int64, 28 | } 29 | 30 | 31 | def merlin_dataset_factory(path: str, columns: List[str], dataset_kwargs: Dict[str, any]): 32 | return merlin.io.Dataset( 33 | path, 34 | engine='parquet', 35 | schema=Schema( 36 | [ 37 | ColumnSchema( 38 | 'X', dtype=PARQUET_SCHEMA['X'], 39 | is_list=True, is_ragged=False, 40 | properties={'value_count': {'max': 19331}} 41 | ) 42 | ] + 43 | [ColumnSchema(col, dtype=PARQUET_SCHEMA[col]) for col in columns] 44 | ), 45 | **dataset_kwargs 46 | ) 47 | 48 | 49 | def set_default_kwargs_dataloader(kwargs: Dict[str, any] = None, training: bool = True): 50 | assert isinstance(training, bool) 51 | if kwargs is None: 52 | kwargs = {} 53 | if 'parts_per_chunk' not in kwargs: 54 | kwargs['parts_per_chunk'] = 8 if training else 1 55 | if 'drop_last' not in kwargs: 56 | kwargs['drop_last'] = 
training 57 | if'shuffle' not in kwargs: 58 | kwargs['shuffle'] = training 59 | 60 | return kwargs 61 | 62 | 63 | def set_default_kwargs_dataset(kwargs: Dict[str, any] = None, training: bool = True): 64 | if kwargs is None: 65 | kwargs = {} 66 | if all(['part_size' not in kwargs, 'part_mem_fraction' not in kwargs]): 67 | kwargs['part_size'] = '100MB' if training else '325MB' 68 | 69 | return kwargs 70 | 71 | 72 | def _get_data_files(base_path: str, split: str, sub_sample_frac: float): 73 | if sub_sample_frac == 1.: 74 | # if no subsampling -> just return base path and merlin takes care of the rest 75 | return join(base_path, split) 76 | else: 77 | files = [file for file in os.listdir(join(base_path, split)) if file.endswith('.parquet')] 78 | files = [join(base_path, split, file) for file in sorted(files, key=lambda x: int(x.split('.')[1]))] 79 | return files[:ceil(sub_sample_frac * len(files))] 80 | 81 | 82 | class MerlinDataModule(pl.LightningDataModule): 83 | 84 | def __init__( 85 | self, 86 | path: str, 87 | columns: List[str], 88 | batch_size: int, 89 | sub_sample_frac: float = 1., 90 | dataloader_kwargs_train: Dict[str, any] = None, 91 | dataloader_kwargs_inference: Dict[str, any] = None, 92 | dataset_kwargs_train: Dict[str, any] = None, 93 | dataset_kwargs_inference: Dict[str, any] = None 94 | ): 95 | super(MerlinDataModule).__init__() 96 | for col in columns: 97 | assert col in PARQUET_SCHEMA 98 | 99 | self.dataloader_kwargs_train = set_default_kwargs_dataloader(dataloader_kwargs_train, training=True) 100 | self.dataloader_kwargs_inference = set_default_kwargs_dataloader(dataloader_kwargs_inference, training=False) 101 | 102 | self.train_dataset = merlin_dataset_factory( 103 | _get_data_files(path, 'train', sub_sample_frac), 104 | columns, 105 | set_default_kwargs_dataset(dataset_kwargs_train, training=True) 106 | ) 107 | self.val_dataset = merlin_dataset_factory( 108 | _get_data_files(path, 'val', sub_sample_frac), 109 | columns, 110 | set_default_kwargs_dataset(dataset_kwargs_inference, training=False) 111 | ) 112 | self.test_dataset = merlin_dataset_factory( 113 | join(path, 'test'), columns, set_default_kwargs_dataset(dataset_kwargs_inference, training=False)) 114 | 115 | self.batch_size = batch_size 116 | 117 | def train_dataloader(self): 118 | return Loader(self.train_dataset, batch_size=self.batch_size, **self.dataloader_kwargs_train) 119 | 120 | def val_dataloader(self): 121 | return Loader(self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs_inference) 122 | 123 | def test_dataloader(self): 124 | return Loader(self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs_inference) 125 | 126 | def predict_dataloader(self): 127 | return Loader(self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs_inference) 128 | -------------------------------------------------------------------------------- /cellnet/estimators.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from typing import Dict, List 3 | 4 | import lightning.pytorch as pl 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from lightning.pytorch.tuner.tuning import Tuner 9 | 10 | from cellnet.datamodules import MerlinDataModule 11 | from cellnet.models import TabnetClassifier, LinearClassifier, MLPClassifier 12 | 13 | 14 | class EstimatorCellTypeClassifier: 15 | 16 | datamodule: MerlinDataModule 17 | model: pl.LightningModule 18 | trainer: pl.Trainer 19 | 20 | def __init__(self, data_path: 
str): 21 | self.data_path = data_path 22 | 23 | def init_datamodule( 24 | self, 25 | batch_size: int = 2048, 26 | sub_sample_frac: float = 1., 27 | dataloader_kwargs_train: Dict = None, 28 | dataloader_kwargs_inference: Dict = None, 29 | merlin_dataset_kwargs_train: Dict = None, 30 | merlin_dataset_kwargs_inference: Dict = None 31 | ): 32 | self.datamodule = MerlinDataModule( 33 | self.data_path, 34 | columns=['cell_type'], 35 | batch_size=batch_size, 36 | sub_sample_frac=sub_sample_frac, 37 | dataloader_kwargs_train=dataloader_kwargs_train, 38 | dataloader_kwargs_inference=dataloader_kwargs_inference, 39 | dataset_kwargs_train=merlin_dataset_kwargs_train, 40 | dataset_kwargs_inference=merlin_dataset_kwargs_inference 41 | ) 42 | 43 | def init_model(self, model_type: str, model_kwargs): 44 | if model_type == 'tabnet': 45 | self.model = TabnetClassifier(**{**self.get_fixed_model_params(model_type), **model_kwargs}) 46 | elif model_type == 'linear': 47 | self.model = LinearClassifier(**{**self.get_fixed_model_params(model_type), **model_kwargs}) 48 | elif model_type == 'mlp': 49 | self.model = MLPClassifier(**{**self.get_fixed_model_params(model_type), **model_kwargs}) 50 | else: 51 | raise ValueError(f'model_type has to be in ["linear", "tabnet"]. You supplied: {model_type}') 52 | 53 | def init_trainer(self, trainer_kwargs): 54 | self.trainer = pl.Trainer(**trainer_kwargs) 55 | 56 | def _check_is_initialized(self): 57 | if not self.model: 58 | raise RuntimeError('You need to call self.init_model before calling self.train') 59 | if not self.datamodule: 60 | raise RuntimeError('You need to call self.init_datamodule before calling self.train') 61 | if not self.trainer: 62 | raise RuntimeError('You need to call self.init_trainer before calling self.train') 63 | 64 | def get_fixed_model_params(self, model_type: str): 65 | model_params = { 66 | 'gene_dim': len(pd.read_parquet(join(self.data_path, 'var.parquet'))), 67 | 'type_dim': len(pd.read_parquet(join(self.data_path, 'categorical_lookup/cell_type.parquet'))), 68 | 'class_weights': np.load(join(self.data_path, 'class_weights.npy')), 69 | 'child_matrix': np.load(join(self.data_path, 'cell_type_hierarchy/child_matrix.npy')), 70 | 'train_set_size': sum(self.datamodule.train_dataset.partition_lens), 71 | 'val_set_size': sum(self.datamodule.val_dataset.partition_lens), 72 | 'batch_size': self.datamodule.batch_size, 73 | } 74 | if model_type in ['tabnet', 'mlp']: 75 | model_params['augmentations'] = np.load(join(self.data_path, 'augmentations.npy')) 76 | 77 | return model_params 78 | 79 | def find_lr(self, lr_find_kwargs, plot_results: bool = False): 80 | self._check_is_initialized() 81 | tuner = Tuner(self.trainer) 82 | lr_finder = tuner.lr_find( 83 | self.model, 84 | train_dataloaders=self.datamodule.train_dataloader(), 85 | val_dataloaders=self.datamodule.val_dataloader(), 86 | **lr_find_kwargs 87 | ) 88 | if plot_results: 89 | lr_finder.plot(suggest=True) 90 | 91 | return lr_finder.suggestion(), lr_finder.results 92 | 93 | def train(self, ckpt_path: str = None): 94 | self._check_is_initialized() 95 | self.trainer.fit( 96 | self.model, 97 | train_dataloaders=self.datamodule.train_dataloader(), 98 | val_dataloaders=self.datamodule.val_dataloader(), 99 | ckpt_path=ckpt_path 100 | ) 101 | 102 | def validate(self, ckpt_path: str = None): 103 | self._check_is_initialized() 104 | return self.trainer.validate(self.model, dataloaders=self.datamodule.val_dataloader(), ckpt_path=ckpt_path) 105 | 106 | def test(self, ckpt_path: str = None): 107 | 
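        """Evaluate the model on the held-out test set via the Lightning trainer (optionally restoring weights from ckpt_path)."""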
self._check_is_initialized() 108 | return self.trainer.test(self.model, dataloaders=self.datamodule.test_dataloader(), ckpt_path=ckpt_path) 109 | 110 | def predict(self, dataloader=None, ckpt_path: str = None) -> np.ndarray: 111 | self._check_is_initialized() 112 | predictions_batched: List[torch.Tensor] = self.trainer.predict( 113 | self.model, 114 | dataloaders=dataloader if dataloader else self.datamodule.predict_dataloader(), 115 | ckpt_path=ckpt_path 116 | ) 117 | return torch.vstack(predictions_batched).numpy() 118 | -------------------------------------------------------------------------------- /cellnet/tabnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/scTab/5ede7f2ba1f9618b86924f2ff587931de18f4ada/cellnet/tabnet/__init__.py -------------------------------------------------------------------------------- /cellnet/tabnet/sparsemax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.autograd import Function 5 | 6 | """ 7 | Other possible implementations: 8 | https://github.com/KrisKorrel/sparsemax-pytorch/blob/master/sparsemax.py 9 | https://github.com/msobroza/SparsemaxPytorch/blob/master/mnist/sparsemax.py 10 | https://github.com/vene/sparse-structured-attention/blob/master/pytorch/torchsparseattn/sparsemax.py 11 | """ 12 | 13 | 14 | # credits to Yandex https://github.com/Qwicen/node/blob/master/lib/nn_utils.py 15 | def _make_ix_like(input, dim=0): 16 | d = input.size(dim) 17 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 18 | view = [1] * input.dim() 19 | view[0] = -1 20 | return rho.view(view).transpose(0, dim) 21 | 22 | 23 | class SparsemaxFunction(Function): 24 | """ 25 | An implementation of sparsemax (Martins & Astudillo, 2016). See 26 | :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 
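    Sparsemax computes the Euclidean projection of the input scores onto the probability simplex: unlike softmax it can assign exactly zero probability to low-scoring entries, which is how TabNet obtains sparse feature masks.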
27 | By Ben Peters and Vlad Niculae 28 | """ 29 | 30 | @staticmethod 31 | def forward(ctx, input, dim=-1): 32 | """sparsemax: normalizing sparse transform (a la softmax) 33 | 34 | Parameters 35 | ---------- 36 | ctx : torch.autograd.function._ContextMethodMixin 37 | input : torch.Tensor 38 | any shape 39 | dim : int 40 | dimension along which to apply sparsemax 41 | 42 | Returns 43 | ------- 44 | output : torch.Tensor 45 | same shape as input 46 | 47 | """ 48 | ctx.dim = dim 49 | max_val, _ = input.max(dim=dim, keepdim=True) 50 | input -= max_val # same numerical stability trick as for softmax 51 | tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) 52 | output = torch.clamp(input - tau, min=0) 53 | ctx.save_for_backward(supp_size, output) 54 | return output 55 | 56 | @staticmethod 57 | def backward(ctx, grad_output): 58 | supp_size, output = ctx.saved_tensors 59 | dim = ctx.dim 60 | grad_input = grad_output.clone() 61 | grad_input[output == 0] = 0 62 | 63 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 64 | v_hat = v_hat.unsqueeze(dim) 65 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 66 | return grad_input, None 67 | 68 | @staticmethod 69 | def _threshold_and_support(input, dim=-1): 70 | """Sparsemax building block: compute the threshold 71 | 72 | Parameters 73 | ---------- 74 | input: torch.Tensor 75 | any dimension 76 | dim : int 77 | dimension along which to apply the sparsemax 78 | 79 | Returns 80 | ------- 81 | tau : torch.Tensor 82 | the threshold value 83 | support_size : torch.Tensor 84 | 85 | """ 86 | 87 | input_srt, _ = torch.sort(input, descending=True, dim=dim) 88 | input_cumsum = input_srt.cumsum(dim) - 1 89 | rhos = _make_ix_like(input, dim) 90 | support = rhos * input_srt > input_cumsum 91 | 92 | support_size = support.sum(dim=dim).unsqueeze(dim) 93 | tau = input_cumsum.gather(dim, support_size - 1) 94 | tau /= support_size.to(input.dtype) 95 | return tau, support_size 96 | 97 | 98 | sparsemax = SparsemaxFunction.apply 99 | 100 | 101 | class Sparsemax(nn.Module): 102 | 103 | def __init__(self, dim=-1): 104 | self.dim = dim 105 | super(Sparsemax, self).__init__() 106 | 107 | def forward(self, input): 108 | return sparsemax(input, self.dim) 109 | 110 | 111 | class Entmax15Function(Function): 112 | """ 113 | An implementation of exact Entmax with alpha=1.5 (B. Peters, V. Niculae, A. Martins). See 114 | :cite:`https://arxiv.org/abs/1905.05702 for detailed description. 
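    Entmax with alpha=1.5 lies between softmax (alpha=1) and sparsemax (alpha=2): it can still produce exact zeros for low-scoring entries, but the resulting distribution is smoother than with sparsemax.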
115 | Source: https://github.com/deep-spin/entmax 116 | """ 117 | 118 | @staticmethod 119 | def forward(ctx, input, dim=-1): 120 | ctx.dim = dim 121 | 122 | max_val, _ = input.max(dim=dim, keepdim=True) 123 | input = input - max_val # same numerical stability trick as for softmax 124 | input = input / 2 # divide by 2 to solve actual Entmax 125 | 126 | tau_star, _ = Entmax15Function._threshold_and_support(input, dim) 127 | output = torch.clamp(input - tau_star, min=0) ** 2 128 | ctx.save_for_backward(output) 129 | return output 130 | 131 | @staticmethod 132 | def backward(ctx, grad_output): 133 | Y, = ctx.saved_tensors 134 | gppr = Y.sqrt() # = 1 / g'' (Y) 135 | dX = grad_output * gppr 136 | q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) 137 | q = q.unsqueeze(ctx.dim) 138 | dX -= q * gppr 139 | return dX, None 140 | 141 | @staticmethod 142 | def _threshold_and_support(input, dim=-1): 143 | Xsrt, _ = torch.sort(input, descending=True, dim=dim) 144 | 145 | rho = _make_ix_like(input, dim) 146 | mean = Xsrt.cumsum(dim) / rho 147 | mean_sq = (Xsrt ** 2).cumsum(dim) / rho 148 | ss = rho * (mean_sq - mean ** 2) 149 | delta = (1 - ss) / rho 150 | 151 | # NOTE this is not exactly the same as in reference algo 152 | # Fortunately it seems the clamped values never wrongly 153 | # get selected by tau <= sorted_z. Prove this! 154 | delta_nz = torch.clamp(delta, 0) 155 | tau = mean - torch.sqrt(delta_nz) 156 | 157 | support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim) 158 | tau_star = tau.gather(dim, support_size - 1) 159 | return tau_star, support_size 160 | 161 | 162 | class Entmoid15(Function): 163 | """ A highly optimized equivalent of lambda x: Entmax15([x, 0]) """ 164 | 165 | @staticmethod 166 | def forward(ctx, input): 167 | output = Entmoid15._forward(input) 168 | ctx.save_for_backward(output) 169 | return output 170 | 171 | @staticmethod 172 | def _forward(input): 173 | input, is_pos = abs(input), input >= 0 174 | tau = (input + torch.sqrt(F.relu(8 - input ** 2))) / 2 175 | tau.masked_fill_(tau <= input, 2.0) 176 | y_neg = 0.25 * F.relu(tau - input, inplace=True) ** 2 177 | return torch.where(is_pos, 1 - y_neg, y_neg) 178 | 179 | @staticmethod 180 | def backward(ctx, grad_output): 181 | return Entmoid15._backward(ctx.saved_tensors[0], grad_output) 182 | 183 | @staticmethod 184 | def _backward(output, grad_output): 185 | gppr0, gppr1 = output.sqrt(), (1 - output).sqrt() 186 | grad_input = grad_output * gppr0 187 | q = grad_input / (gppr0 + gppr1) 188 | grad_input -= q * gppr0 189 | return grad_input 190 | 191 | 192 | entmax15 = Entmax15Function.apply 193 | entmoid15 = Entmoid15.apply 194 | 195 | 196 | class Entmax15(nn.Module): 197 | 198 | def __init__(self, dim=-1): 199 | self.dim = dim 200 | super(Entmax15, self).__init__() 201 | 202 | def forward(self, input): 203 | return entmax15(input, self.dim) 204 | -------------------------------------------------------------------------------- /cellnet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/scTab/5ede7f2ba1f9618b86924f2ff587931de18f4ada/cellnet/utils/__init__.py -------------------------------------------------------------------------------- /cellnet/utils/cell_ontology.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union 2 | 3 | from SPARQLWrapper import SPARQLWrapper, JSON 4 | 5 | 6 | def chunks(lst, n): 7 | for i in range(0, len(lst), n): 8 | yield lst[i: i + n] 9 | 10 | 11 | def 
get_child_node_query(value_term: str, cell_type_list: List[str]) -> str: 12 | """Returns a SPARQL query to retrieve child nodes of given cell type 13 | 14 | Args: 15 | value_term: It should be '?parent' for CURIE queries and should be '?cell_type' for label queries 16 | cell_type_list: Cell type label 17 | 18 | Returns: 19 | A SPARQL query 20 | 21 | """ 22 | if value_term == "?cell_type": 23 | updated_list = ["'" + element + "'" for element in cell_type_list] 24 | elif value_term == "?parent": 25 | updated_list = cell_type_list 26 | else: 27 | raise ValueError(f"value_term cannot be {value_term}") 28 | return ( 29 | f"PREFIX rdfs: " 30 | f"PREFIX CL: " 31 | f"PREFIX owl: " 32 | f"SELECT * WHERE {{ ?parent rdfs:label ?cell. " 33 | f"?child rdfs:subClassOf ?parent. ?child rdfs:label ?child_label." 34 | f"?parent rdfs:isDefinedBy . " 35 | f"?child rdfs:isDefinedBy . " 36 | f" owl:versionIRI ?version." 37 | f"BIND(str(?cell) AS ?cell_type) VALUES {value_term} {{{' '.join(updated_list)}}} }} " 38 | ) 39 | 40 | 41 | def retrieve_child_nodes_from_ubergraph(cell_list: List[str]) -> Dict[str, Union[List[str], str]]: 42 | """This method returns a dictionary containing the child nodes of the specified cell types. Additionally, 43 | the dictionary includes the corresponding CL version from which the information has been retrieved. 44 | 45 | Args: 46 | cell_list: List of cell type labels, labels should match the CL term labels 47 | 48 | Returns: 49 | Cell type to list of corresponding child nodes with CL version they have retrieved 50 | 51 | """ 52 | sparql = SPARQLWrapper("https://ubergraph.apps.renci.org/sparql") 53 | sparql.method = 'POST' 54 | sparql.setReturnFormat(JSON) 55 | child_nodes_dict = {} 56 | cl_version = "" 57 | for chunk in chunks(cell_list, 90): 58 | sparql.setQuery(get_child_node_query("?parent" if ":" in cell_list[0] else "?cell_type", chunk)) 59 | ret = sparql.queryAndConvert() 60 | for row in ret["results"]["bindings"]: 61 | parent = row["cell_type"]["value"] 62 | child = row["child_label"]["value"] 63 | if parent in child_nodes_dict: 64 | child_nodes_dict[parent].append(child) 65 | else: 66 | child_nodes_dict[parent] = [child] 67 | if not cl_version: 68 | cl_version = row["version"]["value"] 69 | 70 | return child_nodes_dict 71 | -------------------------------------------------------------------------------- /cellnet/utils/data_loading.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from scipy.sparse import csc_matrix, csr_matrix, issparse 5 | from sklearn.utils import sparsefuncs 6 | from torch.utils.data import Dataset, DataLoader, BatchSampler, SequentialSampler, RandomSampler 7 | 8 | """ 9 | Data streamlining. 
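streamline_count_matrix aligns a raw count matrix (CSC format) to the fixed gene ordering the model was trained on, leaving genes that are missing from the query data as zero-filled columns; sf_normalize then scales every cell to 10,000 counts.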
10 | """ 11 | 12 | 13 | def streamline_count_matrix(x_raw, gene_names_raw, gene_names_model): 14 | assert len(gene_names_raw) == len(set(gene_names_raw)) 15 | assert len(gene_names_model) == len(set(gene_names_model)) 16 | assert len(gene_names_raw) == x_raw.shape[1] 17 | assert np.isin(gene_names_raw, gene_names_model).sum() == x_raw.shape[1] 18 | # For fast column-wise slicing matrix has to be in csc format 19 | assert isinstance(x_raw, csc_matrix) 20 | # skip zero filling missing genes if all genes are present 21 | if len(gene_names_raw) == len(gene_names_model): 22 | return x_raw.tocsr().astype('f4') 23 | gene_names_raw, gene_names_model = np.array(gene_names_raw), np.array(gene_names_model) 24 | row, col = np.empty(x_raw.nnz, dtype='i8'), np.empty(x_raw.nnz, dtype='i8') 25 | data = np.empty(x_raw.nnz, dtype='f4') 26 | 27 | ctr = 0 28 | for i, gene in enumerate(gene_names_model): 29 | if gene in gene_names_raw: 30 | gene_idx = int(np.where(gene == gene_names_raw)[0]) 31 | x_col = x_raw[:, gene_idx] 32 | idxs_nnz = x_col.indices.tolist() 33 | n_nnz = len(idxs_nnz) 34 | col[ctr:ctr+n_nnz] = i 35 | row[ctr:ctr+n_nnz] = idxs_nnz 36 | data[ctr:ctr+n_nnz] = x_col.data 37 | ctr += n_nnz 38 | 39 | return csr_matrix( 40 | (data, (row, col)), 41 | shape=(x_raw.shape[0], len(gene_names_model)), 42 | dtype='f4' 43 | ) 44 | 45 | 46 | def sf_normalize(x): 47 | """Normalize each cell to have 10000 counts. """ 48 | x = x.copy() 49 | counts = np.array(x.sum(axis=1)) 50 | # avoid zero division error 51 | counts += counts == 0. 52 | # normalize to 10000. counts 53 | scaling_factor = 10000. / counts 54 | 55 | if issparse(x): 56 | sparsefuncs.inplace_row_scale(x, scaling_factor) 57 | else: 58 | np.multiply(x, scaling_factor.reshape((-1, 1)), out=x) 59 | 60 | return x 61 | 62 | 63 | """ 64 | Data Loaders. 
65 | """ 66 | 67 | 68 | class CustomDataset(Dataset): 69 | 70 | def __init__(self, x, obs=None): 71 | super(CustomDataset).__init__() 72 | assert any([isinstance(x, np.ndarray), isinstance(x, csr_matrix)]) 73 | self.x = x 74 | self.obs = obs 75 | 76 | def __len__(self): 77 | return self.x.shape[0] 78 | 79 | def __getitem__(self, idx): 80 | if isinstance(idx, int): 81 | idx = [idx] 82 | x = self.x[idx, :] 83 | if isinstance(x, csr_matrix): 84 | x = x.toarray() 85 | 86 | if self.obs is not None: 87 | # replicate merlin dataloader output format 88 | out = ( 89 | { 90 | 'X': torch.tensor(x.squeeze()), 91 | 'cell_type': torch.tensor( 92 | self.obs.iloc[idx]['cell_type'].cat.codes.to_numpy().reshape((-1, 1)).astype('i8') 93 | ) 94 | }, None 95 | ) 96 | else: 97 | out = ({'X': torch.tensor(x.squeeze())}, None) 98 | 99 | return out 100 | 101 | 102 | def dataloader_factory(x, obs=None, batch_size=2048, shuffle=False): 103 | if shuffle: 104 | sampler = BatchSampler(RandomSampler(range(x.shape[0])), batch_size=batch_size, drop_last=True) 105 | else: 106 | sampler = BatchSampler(SequentialSampler(range(x.shape[0])), batch_size=batch_size, drop_last=False) 107 | 108 | return DataLoader(CustomDataset(x, obs), sampler=sampler, batch_size=None) 109 | -------------------------------------------------------------------------------- /cellnet/utils/tabnet_explain.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy as np 3 | from scipy.sparse import csc_matrix 4 | from tqdm import tqdm 5 | 6 | from cellnet.tabnet.tab_network import TabNet 7 | 8 | 9 | @numba.njit 10 | def _get_nnz_idxs_per_row(x: np.ndarray): 11 | nnz_idxs = [] 12 | for ix in range(x.shape[0]): 13 | nnz_idxs.append(np.where(x[ix, :] > 0.)[0]) 14 | 15 | return nnz_idxs 16 | 17 | 18 | def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim): 19 | """ 20 | This is a computational trick. 21 | In order to rapidly sum importances from same embeddings 22 | to the initial index. 
23 | 24 | Parameters 25 | ---------- 26 | input_dim : int 27 | Initial input dim 28 | cat_emb_dim : int or list of int 29 | if int : size of embedding for all categorical feature 30 | if list of int : size of embedding for each categorical feature 31 | cat_idxs : list of int 32 | Initial position of categorical features 33 | post_embed_dim : int 34 | Post embedding inputs dimension 35 | 36 | Returns 37 | ------- 38 | reducing_matrix : np.ndarray 39 | Matrix of dim (post_embed_dim, input_dim) to perform reduce 40 | """ 41 | 42 | if isinstance(cat_emb_dim, int): 43 | all_emb_impact = [cat_emb_dim - 1] * len(cat_idxs) 44 | else: 45 | all_emb_impact = [emb_dim - 1 for emb_dim in cat_emb_dim] 46 | 47 | acc_emb = 0 48 | nb_emb = 0 49 | indices_trick = [] 50 | for i in range(input_dim): 51 | if i not in cat_idxs: 52 | indices_trick.append([i + acc_emb]) 53 | else: 54 | indices_trick.append( 55 | range(i + acc_emb, i + acc_emb + all_emb_impact[nb_emb] + 1) 56 | ) 57 | acc_emb += all_emb_impact[nb_emb] 58 | nb_emb += 1 59 | 60 | reducing_matrix = np.zeros((post_embed_dim, input_dim)) 61 | for i, cols in enumerate(indices_trick): 62 | reducing_matrix[cols, i] = 1 63 | 64 | return csc_matrix(reducing_matrix) 65 | 66 | 67 | def explain( 68 | model: TabNet, 69 | dataloader, 70 | only_return_nnz_idxs: bool = True, 71 | normalize: bool = False, 72 | device: str = 'cuda' 73 | ): 74 | reducing_matrix = create_explain_matrix( 75 | model.classifier.input_dim,1,[], model.classifier.input_dim) 76 | model.to(device) 77 | model.eval() 78 | 79 | res_explain = [] 80 | for batch_idx, data in tqdm(enumerate(dataloader)): 81 | data = data[0]['X'].to(device) 82 | M_explain, _ = model.classifier.forward_masks(data) 83 | original_feat_explain = csc_matrix.dot(M_explain.cpu().detach().numpy(), reducing_matrix) 84 | if not only_return_nnz_idxs: 85 | res_explain.append(original_feat_explain) 86 | else: 87 | res_explain += _get_nnz_idxs_per_row(original_feat_explain) 88 | 89 | if not only_return_nnz_idxs: 90 | res_explain = np.vstack(res_explain) 91 | if normalize: 92 | res_explain /= np.sum(res_explain, axis=1)[:, None] 93 | 94 | return res_explain 95 | 96 | 97 | def get_feature_masks( 98 | model: TabNet, 99 | dataloader, 100 | device: str = 'cuda' 101 | ): 102 | model.to(device) 103 | model.eval() 104 | 105 | res_explain = {i: [] for i in range(model.n_steps)} 106 | for batch_idx, data in tqdm(enumerate(dataloader)): 107 | data = data[0]['X'].to(device) 108 | _, masks = model.forward_masks(data) 109 | for i in range(model.n_steps): 110 | res_explain[i].append(masks[i].cpu().detach().numpy()) 111 | 112 | return {i: np.vstack(res_explain[i]) for i in range(model.n_steps)} 113 | -------------------------------------------------------------------------------- /docs/classification-evaluation-metrics.md: -------------------------------------------------------------------------------- 1 | # Model Evaluation 2 | **Author:** Felix Fischer (GitHub @felix0097) \ 3 | **Date:** 31.05.2023 4 | 5 | 6 | ## Evaluation metrics 7 | 8 | * F1-score: https://en.wikipedia.org/wiki/F-score 9 | * Macro F1-score (macro average F1-score) 10 | * Macro average == calculate F1-score per cell type and take the average of the per cell type F1-scores 11 | * Macro average better reflects performance across all cell types as there is a strong class imbalance in the data set 12 | * The macro F1-score is often taken as the default metric to evaluate cell type classification performance 13 | * Weighted F1-score (weighted average F1-score) 14 | * Weighted 
average == calculate F1-score per cell type and take the weighted average of the per cell type F1-scores 15 | (classes with more samples get a higher weight) 16 | * Weighted average better reflects how many of the cells are classified correctly 17 | 18 | 19 | ## Evaluation data 20 | 21 | * Evaluation covers the same cell types as seen during model training 22 | * Evaluation data consists of donors the model has not seen during training 23 | * Random donor holdout and not random subsampling! 24 | * This better represents how well the classifier generalises to unseen donors than just random subsampling of cells 25 | 26 | 27 | ## Dealing with different cell type annotation granularity 28 | 29 | * Different data sets are often annotated with vastly different granularity, e.g. `T cell` vs 30 | `CD4-positive, alpha-beta T cell` 31 | * To account for this, the following rule applies when evaluating whether a prediction is right or wrong: 32 | * Prediction is `right` or `true` if: 33 | * Classifier predicts the same label as the author 34 | * Classifier predicts a subtype of the true label 35 | * This is considered a right prediction as the prediction agrees with the true label up to the annotation 36 | granularity the author provided 37 | * e.g. classifier predicts `CD4-positive, alpha-beta T cell` when the author annotated the cell with `T cell` 38 | * Prediction is `wrong` or `false` if: 39 | * Classifier predicts a parent cell type of the true label 40 | * This is considered a wrong prediction as the author supplied a more fine-grained label 41 | * e.g. classifier predicts `T cell` instead of `CD4-positive, alpha-beta T cell` 42 | * Anything else 43 | * The code to find the child nodes based on the cell ontology (https://www.ebi.ac.uk/ols/ontologies/cl) can be 44 | found under `cellnet/utils/cell_ontology.py`; a sketch of how this rule can be applied when scoring predictions is shown at the end of this document 45 | 46 | 47 | ## Evaluation results 48 | 49 | * The cell type classification evaluation notebooks for TabNet + reference models can be found here: 50 | * TabNet: `notebooks/model_evaluation/classification-tabnet.ipynb` 51 | * Linear reference model: `notebooks/model_evaluation/classification-linear.ipynb` 52 | * XGBoost reference model: `notebooks/model_evaluation/classification-xgboost.ipynb` 53 | 54 | * The evaluation notebooks contain the following metrics: 55 | * Overall classification performance measured by macro F1-score (shows overall performance of the classifier) 56 | * Plot of per cell type F1-scores (can be used to spot cell types the model currently struggles with) 57 | * TSNE visualization of predicted and true labels 58 | * TSNE visualization is calculated based on the first 50 PCA components of the test set 59 | * Additionally, a binary indicator whether a prediction is `right` or `wrong` is overlaid on the TSNE plots 60 | 61 | * Evaluation notebooks for the TabNet model: 62 | * `notebooks/model_evaluation/model-scaling.ipynb`: Classification performance vs training data size for the TabNet 63 | model 64 | * `notebooks/model_evaluation/classification-tabnet-ensembl.ipynb`: Deep ensemble of TabNet models + evaluation of 65 | uncertainty quantification of the predictions
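A minimal sketch of how the subtype rule from the "Dealing with different cell type annotation granularity" section can be applied before computing the F1-scores (this is not the repository's evaluation code itself). It assumes integer-encoded labels and a boolean lookup matrix where `child_lookup[i, j]` is `True` if cell type `j` is a subtype of cell type `i`; the repository ships a related matrix under `cell_type_hierarchy/child_matrix.npy`, but its exact indexing convention is an assumption here.

```python
import numpy as np
from sklearn.metrics import f1_score


def relabel_subtype_predictions(y_true: np.ndarray, y_pred: np.ndarray, child_lookup: np.ndarray) -> np.ndarray:
    # a prediction that is a subtype of the (coarser) author-provided label counts as correct:
    # map such predictions back to the true label before scoring
    is_subtype_of_true = child_lookup[y_true, y_pred]
    return np.where(is_subtype_of_true, y_true, y_pred)


# hypothetical usage with integer-encoded test set labels y_true / y_pred:
# y_pred_corrected = relabel_subtype_predictions(y_true, y_pred, child_lookup)
# macro_f1 = f1_score(y_true, y_pred_corrected, average='macro')
# weighted_f1 = f1_score(y_true, y_pred_corrected, average='weighted')
```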
-------------------------------------------------------------------------------- /docs/data.md: -------------------------------------------------------------------------------- 1 | # Data set creation 2 | **Author:** Felix Fischer (GitHub @felix0097) \ 3 | **Date:** 29.08.2023 4 | 5 | ## Data set curation 6 | * Based on Cell-by-Gene (CxG) census version `2023-05-15` 7 | * CxG census: https://chanzuckerberg.github.io/cellxgene-census/index.html 8 | ```python 9 | import cellxgene_census 10 | 11 | census = cellxgene_census.open_soma(census_version='2023-05-15') 12 | ``` 13 | * The CxG census can be used to query CxG data with a standardised API: 14 | https://chanzuckerberg.github.io/cellxgene-census/python-api.html 15 | * The census version `2023-05-15` is a long-term supported (LTS) release and will be hosted by cellxgene for at least 16 | 5 years 17 | * Gene space subset to `19331` protein coding genes (see `notebooks/store_creation/features.parquet` for the full list) 18 | * Cells from the census are subset according to the following criteria: 19 | 1. Data has to be primary data to prevent label leakage between the train and test set: `is_primary_data == True` 20 | 2. Sequencing protocol has to be in `PROTOCOLS` (`assay in PROTOCOLS`) with 21 | ```python 22 | PROTOCOLS = [ 23 | "10x 5' v2", 24 | "10x 3' v3", 25 | "10x 3' v2", 26 | "10x 5' v1", 27 | "10x 3' v1", 28 | "10x 3' transcription profiling", 29 | "10x 5' transcription profiling" 30 | ] 31 | ``` 32 | 3. Annotated cell type has to be a subtype of `native cell` 33 | 4. There have to be at least `5000` cells for a specific cell type 34 | 5. Each cell type has to be observed in at least `30` donors to reliably quantify whether the classifier can 35 | generalize to new donors. 36 | 6. Each cell type needs to have at least `7` parent nodes according to the cell type ontology. This criterion is used 37 | as a heuristic to filter out cell type labels that are 38 | too general / coarse 39 | * Split data into train, val and test set 40 | * Splits are based on donors: 41 | * E.g. each donor is either in the train, val or test set (unlike with random subsampling of cells) 42 | * A donor based split better represents how the classifier generalises to unseen donors / data sets 43 | * A donor based split roughly resembles a random split when looking at the overall proportion of cells in the 44 | train, val and test set. If the number of donors is large enough, this holds well. On average 68% of the samples per 45 | cell type are in the train set in the current setting (ideal would be 70% with a 70% - 15% - 15% split). For the worst 46 | outlier cell type, only 37% of the cells are in the training set. 47 | * Split fraction: train=0.7, val=0.15, test=0.15 48 | * The code to reproduce the data set creation can be found under 49 | `notebooks/store_creation/02_create_train_val_test_split.ipynb` 50 | 51 | 52 | ## Data preprocessing 53 | Preprocessing includes the following steps: 54 | 1. Normalize each cell to have `10000` counts (size factor normalization) 55 | ```python 56 | import numpy as np 57 | 58 | 59 | def sf_normalize(X): 60 | X = X.copy() 61 | counts = np.array(X.sum(axis=1)) 62 | # avoid zero division error 63 | counts += counts == 0. 64 | # normalize to 10000. counts 65 | scaling_factor = 10000. / counts 66 | np.multiply(X, scaling_factor.reshape((-1, 1)), out=X) 67 | 68 | return X 69 | ``` 70 | 2.
After size factor normalization the data is Log1p transformed 71 | ```python 72 | import numpy as np 73 | 74 | 75 | def log1p_norm(x): 76 | return np.log1p(x) 77 | ``` 78 | 79 | 80 | ## Data statistics 81 | 82 | * `164` cell types 83 | * `5052` unique donors 84 | * Data set size: `22.189.056` cells 85 | * train: `15.240.192` cells 86 | * val: `3.500.,032` cells 87 | * test: `3.448.832` cells 88 | * `19331` genes (protein coding genes) 89 | * `56` different tissues (`197` with more fine-grained tissue annotation) 90 | 91 | 92 | ## Data preparation pipeline 93 | 94 | * The data preparation pipeline can be found under `notebooks/store_creation`: 95 | 1. `01_download_data.ipynb`: Subset and download data from CELLxGENE census. 96 | 2. `02_create_train_val_test_split.ipynb`: Split data into train, val and test sets 97 | 3. `03_write_store_merlin.ipynb`: Save data into on-disk format that can be used by Nvidia Merlin dataloader 98 | (https://github.com/NVIDIA-Merlin/dataloader) 99 | 4. `04_create_hierarchy_matrices.ipynb`: Create child node lookup matrix to find subtypes based on cell type 100 | ontology 101 | 5. `05_compute_pca.ipynb`: Compute PCA embeddings for visualization (50 components) and model training 102 | (256 components) 103 | 6. `06_check_written_store.ipynb` (can be skipped): Sanity check written data 104 | 7. `07_data_summary.ipynb` (can be skipped): Data summary statistics 105 | -------------------------------------------------------------------------------- /docs/models.md: -------------------------------------------------------------------------------- 1 | # Model overview 2 | **Author:** Felix Fischer (GitHub @felix0097) \ 3 | **Date:** 12.06.2023 4 | 5 | ## TabNet model 6 | * Based on https://arxiv.org/abs/1908.07442 7 | * Implementation: `TabnetClassifier` class under `cellnet/models.py` 8 | * Trained with cross-entropy loss 9 | * Input features: 19331 protein coding genes 10 | * Feature normalization: Normalize each cell to have 10000 counts + log1p transform 11 | * Output: class probabilities 12 | 13 | 14 | ## Linear model 15 | * Linear model (single fully connected layer) trained with cross-entropy loss 16 | * Implementation: `LinearClassifier` class under `cellnet/models.py` 17 | * Input features: 19331 protein coding genes 18 | * Feature normalization: Normalize each cell to have 10000 counts + log1p transform 19 | * Output: class probabilities 20 | 21 | 22 | ## MLP model 23 | * Multi-layer perceptron (MLP) model trained with cross-entropy loss 24 | * Implementation: `MLPClassifier` class under `cellnet/models.py` 25 | * Input features: 19331 protein coding genes 26 | * Feature normalization: Normalize each cell to have 10000 counts + log1p transform 27 | * Output: class probabilities 28 | 29 | 30 | ## XGBoost model 31 | * Based on official XGBoost model: https://xgboost.readthedocs.io/en/stable/ (version 1.6.2) 32 | * Trained with `multi:softprob` objective 33 | * Input features: 34 | * 256 PCA components (calculated based on all protein coding genes / 19331 genes) 35 | * PCA is computed based on training data 36 | * Each cell is normalized to have 10000 counts + log1p transformed before computing PCA 37 | * Output: class probabilities 38 | -------------------------------------------------------------------------------- /notebooks-tutorials/data_loading.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "21eef928-9a7c-447e-b1cb-a81330027c3c", 6 | "metadata": {}, 7 | 
"source": [ 8 | "# Details about Nvidia Merlin" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "a1ddbe18-a511-4182-80b5-8df0a847727c", 14 | "metadata": {}, 15 | "source": [ 16 | "Documentation links:\n", 17 | "\n", 18 | "* General documentation: https://nvidia-merlin.github.io/Merlin/stable/README.html\n", 19 | "* GitHub: https://github.com/NVIDIA-Merlin/dataloader\n", 20 | "* Docker containers: https://nvidia-merlin.github.io/Merlin/stable/containers.html\n", 21 | "\n", 22 | "\n", 23 | "Installation via pip:\n", 24 | "1. Install cudf + dask-cudf: `python -m pip install cudf-cu11==23.08 rmm-cu11==23.08 dask-cudf-cu11==23.08 --extra-index-url https://pypi.nvidia.com/`\n", 25 | "2. Install merlin dataloader: `python -m pip install merlin-dataloader`\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "id": "abd58faa-5248-4ee8-aaf9-0273ba21baf3", 31 | "metadata": {}, 32 | "source": [ 33 | "# Details about data" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "27bf368e-f476-45e4-b58d-332164d9c779", 39 | "metadata": {}, 40 | "source": [ 41 | "The parquet files have the folowing columns:" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 1, 47 | "id": "0553ff7a-56bd-4027-adbb-0002691274e9", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stderr", 52 | "output_type": "stream", 53 | "text": [ 54 | "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", 55 | " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", 56 | "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 57 | " from .autonotebook import tqdm as notebook_tqdm\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "from merlin.dtypes import boolean, float32, int64\n", 63 | "\n", 64 | "\n", 65 | "PARQUET_SCHEMA = {\n", 66 | " 'X': float32, # -> gene expression values (normalized to 10.000 counts per cell + log1p transformed)\n", 67 | " 'soma_joinid': int64, # soma_joinid from CELLxGENE\n", 68 | " 'is_primary_data': boolean, # binary indicator whether data is primary data or not (currently all data is primary data)\n", 69 | " 'dataset_id': int64, # name of the associated data set\n", 70 | " 'donor_id': int64, # name of the donor (caution! This might not be unique across datasets -> use tech_sample column instead)\n", 71 | " 'assay': int64, # name of the used assay\n", 72 | " 'cell_type': int64, # cell type label\n", 73 | " 'development_stage': int64, # development stage label\n", 74 | " 'disease': int64, # disease state label\n", 75 | " 'tissue': int64, # specfic tissue label\n", 76 | " 'tissue_general': int64, # general tissue label\n", 77 | " 'tech_sample': int64, # batch indicator \n", 78 | " 'idx': int64, # consecutive enumeration of all cells in the train, val and test data\n", 79 | "}" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "id": "fe2671f8-1f32-4337-b4ae-b23f3123e703", 85 | "metadata": {}, 86 | "source": [ 87 | "All categorical meta data (['dataset_id', 'donor_id', 'assay', 'cell_type', 'development_stage', 'disease', 'tissue', 'tissue_general', 'tech_sample']) are encoded as integers. 
\n", 88 | "\n", 89 | "The lookup tables to map the integer labels to their corresponding string labels can be found under: `join(DATA_PATH, categorical_lookup)`\n", 90 | "\n", 91 | "E.g. the mapping for the `cell_type` column can be found in the `cell_type.parquet` file." 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "c3376667-0555-4d9d-8864-3b26fdf31da4", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "8f0be050-9c18-489f-8fef-31d402ba288c", 105 | "metadata": {}, 106 | "source": [ 107 | "# Use with PyTorch Lightning DataModule" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 2, 113 | "id": "e6c007e2-5ca5-4eb4-a717-354626b73a68", 114 | "metadata": { 115 | "tags": [] 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "from cellnet.datamodules import MerlinDataModule" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 3, 125 | "id": "53c267d4-c197-4d13-a3ab-28a49c76e3ad", 126 | "metadata": { 127 | "tags": [] 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "# path to merlin store\n", 132 | "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 4, 138 | "id": "7d74cb9b-3dc5-497d-b8e1-352b16d3a437", 139 | "metadata": { 140 | "tags": [] 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "datamodule = MerlinDataModule(\n", 145 | " path=DATA_PATH,\n", 146 | " columns=['cell_type'],\n", 147 | " batch_size=2048,\n", 148 | " sub_sample_frac=1., # randomly subsample data (can be between (0., 1.])\n", 149 | ")" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 5, 155 | "id": "89058ebb-62b4-4919-8d31-9850ee58656a", 156 | "metadata": { 157 | "tags": [] 158 | }, 159 | "outputs": [ 160 | { 161 | "name": "stdout", 162 | "output_type": "stream", 163 | "text": [ 164 | "X: tensor([[0., 0., 0., ..., 0., 0., 0.],\n", 165 | " [0., 0., 0., ..., 0., 0., 0.],\n", 166 | " [0., 0., 0., ..., 0., 0., 0.],\n", 167 | " ...,\n", 168 | " [0., 0., 0., ..., 0., 0., 0.],\n", 169 | " [0., 0., 0., ..., 0., 0., 0.],\n", 170 | " [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')\n", 171 | "cell_type: tensor([ 7, 127, 152, ..., 22, 127, 4], device='cuda:0')\n", 172 | "X: tensor([[0., 0., 0., ..., 0., 0., 0.],\n", 173 | " [0., 0., 0., ..., 0., 0., 0.],\n", 174 | " [0., 0., 0., ..., 0., 0., 0.],\n", 175 | " ...,\n", 176 | " [0., 0., 0., ..., 0., 0., 0.],\n", 177 | " [0., 0., 0., ..., 0., 0., 0.],\n", 178 | " [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')\n", 179 | "cell_type: tensor([ 9, 132, 118, ..., 14, 129, 127], device='cuda:0')\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "import gc\n", 185 | "\n", 186 | "\n", 187 | "# get dataloaders for train, valiation and test set\n", 188 | "train_loader = datamodule.train_dataloader()\n", 189 | "val_loader = datamodule.val_dataloader()\n", 190 | "test_loader = datamodule.test_dataloader()\n", 191 | "\n", 192 | "\n", 193 | "# how to use dataloaders\n", 194 | "for ix, (batch, _) in enumerate(train_loader):\n", 195 | " # put your training code here:\n", 196 | " print('X:', batch['X'])\n", 197 | " print('cell_type:', batch['cell_type'])\n", 198 | "\n", 199 | " # Merlin tends to use a lot of GPU memory if the garbage collection isn't called regularly\n", 200 | " # -> manually call python garbage collection every 10 steps \n", 201 | " if ix % 10 == 0:\n", 202 | " gc.collect()\n", 203 | "\n", 204 | " 
# don't iterate over all traning data for this tutorial\n", 205 | " if ix == 1:\n", 206 | " break\n" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "id": "f07e2569-e3c8-4bf6-af2b-80f1c361a8c2", 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "id": "0abb0c02-7f7f-4521-b866-45c6fa3e395d", 220 | "metadata": {}, 221 | "source": [ 222 | "# Use as standalone PyTorch DataLoader" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 6, 228 | "id": "b308cc65-a2dd-4b37-9c2c-0c324c9825d5", 229 | "metadata": { 230 | "tags": [] 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "from os.path import join\n", 235 | "\n", 236 | "from cellnet.datamodules import merlin_dataset_factory, set_default_kwargs_dataset\n", 237 | "from merlin.dataloader.torch import Loader" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 7, 243 | "id": "96d04467-242d-4beb-b89b-ecf7c964d3ad", 244 | "metadata": { 245 | "tags": [] 246 | }, 247 | "outputs": [], 248 | "source": [ 249 | "# path to merlin store\n", 250 | "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 8, 256 | "id": "4732b0e7-ad05-43d7-b817-65cad0fca24e", 257 | "metadata": { 258 | "tags": [] 259 | }, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "X: tensor([[0., 0., 0., ..., 0., 0., 0.],\n", 266 | " [0., 0., 0., ..., 0., 0., 0.],\n", 267 | " [0., 0., 0., ..., 0., 0., 0.],\n", 268 | " ...,\n", 269 | " [0., 0., 0., ..., 0., 0., 0.],\n", 270 | " [0., 0., 0., ..., 0., 0., 0.],\n", 271 | " [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')\n", 272 | "cell_type: tensor([46, 44, 38, ..., 67, 19, 60], device='cuda:0')\n", 273 | "X: tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 274 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 275 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 276 | " ...,\n", 277 | " [0.0000, 0.0000, 0.0000, ..., 1.4717, 0.0000, 0.0000],\n", 278 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 279 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", 280 | " device='cuda:0')\n", 281 | "cell_type: tensor([ 14, 64, 131, ..., 60, 63, 132], device='cuda:0')\n" 282 | ] 283 | } 284 | ], 285 | "source": [ 286 | "# manually create data loaders for train and validation set\n", 287 | "train_dataset = merlin_dataset_factory(\n", 288 | " join(DATA_PATH, 'train'), \n", 289 | " columns=['cell_type'], \n", 290 | " dataset_kwargs=set_default_kwargs_dataset(training=True)\n", 291 | ")\n", 292 | "train_loader = Loader(train_dataset, batch_size=2048, shuffle=True)\n", 293 | "\n", 294 | "\n", 295 | "val_dataset = merlin_dataset_factory(\n", 296 | " join(DATA_PATH, 'val'), \n", 297 | " columns=['cell_type'], \n", 298 | " dataset_kwargs=set_default_kwargs_dataset(training=False)\n", 299 | ")\n", 300 | "val_loader = Loader(val_dataset, batch_size=2048, shuffle=False)\n", 301 | "\n", 302 | "\n", 303 | "# how to use dataloaders\n", 304 | "for ix, (batch, _) in enumerate(train_loader):\n", 305 | " # put your training code here:\n", 306 | " print('X:', batch['X'])\n", 307 | " print('cell_type:', batch['cell_type'])\n", 308 | "\n", 309 | " # Merlin tends to use a lot of GPU memory if the garbage collection isn't called regularly\n", 310 | " # -> manually call python garbage collection every 10 steps \n", 
311 | " if ix % 10 == 0:\n", 312 | " gc.collect()\n", 313 | "\n", 314 | " # don't iterate over all traning data for this tutorial\n", 315 | " if ix == 1:\n", 316 | " break\n" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "id": "85addfa5-f3da-4991-9222-1beb319be814", 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [] 326 | } 327 | ], 328 | "metadata": { 329 | "kernelspec": { 330 | "display_name": "Python 3 (ipykernel)", 331 | "language": "python", 332 | "name": "python3" 333 | }, 334 | "language_info": { 335 | "codemirror_mode": { 336 | "name": "ipython", 337 | "version": 3 338 | }, 339 | "file_extension": ".py", 340 | "mimetype": "text/x-python", 341 | "name": "python", 342 | "nbconvert_exporter": "python", 343 | "pygments_lexer": "ipython3", 344 | "version": "3.8.10" 345 | } 346 | }, 347 | "nbformat": 4, 348 | "nbformat_minor": 5 349 | } 350 | -------------------------------------------------------------------------------- /notebooks-tutorials/model_inference.md: -------------------------------------------------------------------------------- 1 | # Inference Steps 2 | 3 | ## Phase 1: Data Preprocessing 4 | 1. Collect Data 5 | * Raw count data is required (no normalization) 6 | * Data should be supplied in cellxgene `.h5ad` format (raw counts under `.raw.X`) 7 | 2. Align gene feature space 8 | * Order of columns in count matrix is given in `var.parquet` file. Reorder genes accordingly. Genes have to be in 9 | exactly same order! 10 | * Zero fill genes if they're missing in the supplied data 11 | * For the code in this notebook to work correctly, both data sets need to have the same ensembl release 12 | (release 104 in this case) 13 | 3. Wrap count data array into PyTorch data loader 14 | 15 | ## Phase 2: Load Trained Model 16 | 1. Load model checkpoint via `torch.load()` 17 | 2. Initialize new model 18 | 1. Load model architecture / hyperparameters from `hparams.yaml` file 19 | 2. Initialize model according to model architecture from step i. 20 | 3. Load weights + set model to inference model 21 | 1. Load pretrained weights via `model.load_state_dict(weights)` 22 | 2. Set model to eval mode `model.eval()` 23 | 24 | ## Phase 3: Run Model Inference 25 | 26 | 1. Run model inference 27 | 2. Map integer predictions to string labels via `cell_type.parquet` file 28 | -------------------------------------------------------------------------------- /notebooks/data_augmentation/shortend_cell_types.yaml: -------------------------------------------------------------------------------- 1 | B cell: B cell 2 | Bergmann glial cell: Bergmann glial cell 3 | CD14-low, CD16-positive monocyte: CD14-low, CD16+ monocyte 4 | CD14-positive monocyte: CD14+ monocyte 5 | CD14-positive, CD16-negative classical monocyte: CD14+, CD16- clas. monocyte 6 | CD14-positive, CD16-positive monocyte: CD14+, CD16+ monocyte 7 | CD16-negative, CD56-bright natural killer cell, human: CD16-, CD56-bright NK cell 8 | CD16-positive, CD56-dim natural killer cell, human: CD16+, CD56-dim NK cell 9 | CD1c-positive myeloid dendritic cell: CD1c+ myeloid dendritic cell 10 | CD4-positive helper T cell: CD4+ helper T cell 11 | CD4-positive, alpha-beta T cell: CD4+, $\alpha$-$\beta$ T cell 12 | CD4-positive, alpha-beta cytotoxic T cell: CD4+, $\alpha$-$\beta$ cytotoxic T cell 13 | CD4-positive, alpha-beta memory T cell: CD4+, $\alpha$-$\beta$ mem. T cell 14 | CD8-alpha-alpha-positive, alpha-beta intraepithelial T cell: "CD8-$\\alpha$-$\\alpha$+,\ 15 | \ $\alpha$-$\beta$ intraepith. 
T cell" 16 | CD8-positive, alpha-beta T cell: CD8+, $\alpha$-$\beta$ T cell 17 | CD8-positive, alpha-beta cytotoxic T cell: CD8+, $\alpha$-$\beta$ cytotoxic T cell 18 | CD8-positive, alpha-beta memory T cell: CD8+, $\alpha$-$\beta$ mem. T cell 19 | IgA plasma cell: IgA plasma cell 20 | IgG plasma cell: IgG plasma cell 21 | L2/3-6 intratelencephalic projecting glutamatergic cortical neuron: L2/3-6 IT projecting 22 | neuron 23 | L6b glutamatergic cortical neuron: L6b glutamatergic cortical neuron 24 | Schwann cell: Schwann cell 25 | T cell: T cell 26 | T follicular helper cell: T follicular helper cell 27 | T-helper 17 cell: T-helper 17 cell 28 | T-helper 22 cell: T-helper 22 cell 29 | activated CD4-positive, alpha-beta T cell: activated CD4+, $\alpha$-$\beta$ T cell 30 | activated CD8-positive, alpha-beta T cell: activated CD8+, $\alpha$-$\beta$ T cell 31 | alpha-beta T cell: $\alpha$-$\beta$ T cell 32 | alternatively activated macrophage: alternatively activated macrophage 33 | alveolar macrophage: alveolar macrophage 34 | alveolar type 1 fibroblast cell: alveolar type 1 fibroblast cell 35 | alveolar type 2 fibroblast cell: alveolar type 2 fibroblast cell 36 | amacrine cell: amacrine cell 37 | astrocyte: astrocyte 38 | astrocyte of the cerebral cortex: astrocyte of the cerebral cortex 39 | basal cell: basal cell 40 | basal cell of epithelium of trachea: basal cell of epithelium of trachea 41 | blood vessel endothelial cell: blood vessel endothelial cell 42 | bronchus fibroblast of lung: bronchus fibroblast of lung 43 | capillary endothelial cell: capillary endothelial cell 44 | cardiac muscle cell: cardiac muscle cell 45 | cardiac neuron: cardiac neuron 46 | caudal ganglionic eminence derived GABAergic cortical interneuron: caudal ganglionic 47 | eminence der. GABAergic cortical interneuron 48 | central memory CD4-positive, alpha-beta T cell: central mem. CD4+, $\alpha$-$\beta$ 49 | T cell 50 | central memory CD8-positive, alpha-beta T cell: central mem. CD8+, $\alpha$-$\beta$ 51 | T cell 52 | central nervous system macrophage: central nervous sys. macrophage 53 | chandelier pvalb GABAergic cortical interneuron: chandelier pvalb GABAergic cortical 54 | interneuron 55 | ciliated columnar cell of tracheobronchial tree: ciliated columnar cell of tracheobronchial 56 | tree 57 | class switched memory B cell: class switched mem. B cell 58 | classical monocyte: clas. monocyte 59 | club cell: club cell 60 | conventional dendritic cell: conventional dendritic cell 61 | corticothalamic-projecting glutamatergic cortical neuron: corticothalamic-project. 62 | glutamatergic cortical neuron 63 | dendritic cell: dendritic cell 64 | double negative T regulatory cell: double negative T regulatory cell 65 | double negative thymocyte: double negative thymocyte 66 | double-positive, alpha-beta thymocyte: double+, $\alpha$-$\beta$ thymocyte 67 | effector CD8-positive, alpha-beta T cell: effector CD8+, $\alpha$-$\beta$ T cell 68 | effector memory CD4-positive, alpha-beta T cell: effector mem. CD4+, $\alpha$-$\beta$ 69 | T cell 70 | effector memory CD8-positive, alpha-beta T cell: effector mem. CD8+, $\alpha$-$\beta$ 71 | T cell 72 | effector memory CD8-positive, alpha-beta T cell, terminally differentiated: effector 73 | mem. CD8+, $\alpha$-$\beta$ T cell, term. diff. 
74 | elicited macrophage: elicited macrophage 75 | endothelial cell: endothelial cell 76 | endothelial cell of artery: endothelial cell of artery 77 | endothelial cell of lymphatic vessel: endothelial cell of lymphatic vessel 78 | enteric smooth muscle cell: enteric smooth muscle cell 79 | enterocyte: enterocyte 80 | enteroendocrine cell: enteroendocrine cell 81 | ependymal cell: ependymal cell 82 | epithelial cell of proximal tubule: epith. cell of proximal tubule 83 | erythroblast: erythroblast 84 | erythrocyte: erythrocyte 85 | exhausted T cell: exhausted T cell 86 | fallopian tube secretory epithelial cell: fallopian tube secretory epith. cell 87 | fibroblast of cardiac tissue: fibroblast of cardiac tissue 88 | foveolar cell of stomach: foveolar cell of stomach 89 | gamma-delta T cell: $\gamma$-$\delta$ T cell 90 | glutamatergic neuron: glutamatergic neuron 91 | goblet cell: goblet cell 92 | granulocyte: granulocyte 93 | granulosa cell: granulosa cell 94 | hematopoietic stem cell: hematopoietic stem cell 95 | immature B cell: immature B cell 96 | immature innate lymphoid cell: immature innate lymphoid cell 97 | inflammatory macrophage: inflammatory macrophage 98 | innate lymphoid cell: innate lymphoid cell 99 | intermediate monocyte: intermediate monocyte 100 | intestine goblet cell: intestine goblet cell 101 | intraepithelial lymphocyte: intraepithelial lymphocyte 102 | keratinocyte: keratinocyte 103 | kidney collecting duct intercalated cell: kidney collecting duct intercalated cell 104 | kidney collecting duct principal cell: kidney collecting duct principal cell 105 | kidney connecting tubule epithelial cell: kidney connecting tubule epith. cell 106 | kidney distal convoluted tubule epithelial cell: kidney distal convoluted tubule epith. 107 | cell 108 | kidney interstitial fibroblast: kidney interstitial fibroblast 109 | kidney loop of Henle thick ascending limb epithelial cell: kidney loop Henle thick 110 | asc. limb epith. cell 111 | kidney loop of Henle thin ascending limb epithelial cell: kidney loop Henle thin asc. 112 | limb epith. cell 113 | kidney loop of Henle thin descending limb epithelial cell: kidney loop Henle thin 114 | des. limb epith. cell 115 | lamp5 GABAergic cortical interneuron: lamp5 GABAergic cortical interneuron 116 | leukocyte: leukocyte 117 | luminal epithelial cell of mammary gland: luminal epith. cell of mammary gland 118 | lung macrophage: lung macrophage 119 | lung pericyte: lung pericyte 120 | lymphocyte: lymphocyte 121 | lymphoid lineage restricted progenitor cell: lymphoid lineage restricted progenitor 122 | cell 123 | macrophage: macrophage 124 | mast cell: mast cell 125 | mature B cell: mature B cell 126 | mature NK T cell: mature NK T cell 127 | mature T cell: mature T cell 128 | mature alpha-beta T cell: mature $\alpha$-$\beta$ T cell 129 | mature gamma-delta T cell: mature gamma-delta T cell 130 | megakaryocyte-erythroid progenitor cell: megakaryocyte-erythroid progenitor cell 131 | memory B cell: mem. B cell 132 | memory T cell: mem. T cell 133 | mesothelial cell: mesothelial cell 134 | microglial cell: microglial cell 135 | monocyte: monocyte 136 | mucosal invariant T cell: mucosal invariant T cell 137 | mucous neck cell: mucous neck cell 138 | myoepithelial cell of mammary gland: myoepithelial cell of mammary gland 139 | naive B cell: naive B cell 140 | naive T cell: naive T cell 141 | naive thymus-derived CD4-positive, alpha-beta T cell: naive thym.-der. 
CD4+, $\alpha$-$\beta$ 142 | T cell 143 | naive thymus-derived CD8-positive, alpha-beta T cell: naive thym.-der. CD8+, $\alpha$-$\beta$ 144 | T cell 145 | nasal mucosa goblet cell: nasal mucosa goblet cell 146 | natural killer cell: NK cell 147 | near-projecting glutamatergic cortical neuron: near-project. glutamatergic cortical 148 | neuron 149 | neuron: neuron 150 | neutrophil: neutrophil 151 | non-classical monocyte: non-classical monocyte 152 | oligodendrocyte: oligodendrocyte 153 | oligodendrocyte precursor cell: oligodendrocyte pre. cell 154 | paneth cell: paneth cell 155 | pericyte: pericyte 156 | plasma cell: plasma cell 157 | plasmablast: plasmablast 158 | plasmacytoid dendritic cell: plasmacytoid dendritic cell 159 | platelet: platelet 160 | precursor B cell: pre. B cell 161 | pro-B cell: pro-B cell 162 | promonocyte: promonocyte 163 | pulmonary artery endothelial cell: pulmonary artery endothelial cell 164 | pvalb GABAergic cortical interneuron: pvalb GABAergic cortical interneuron 165 | regulatory T cell: regulatory T cell 166 | renal interstitial pericyte: renal interstitial pericyte 167 | respiratory basal cell: respiratory basal cell 168 | respiratory hillock cell: respiratory hillock cell 169 | retina horizontal cell: retina horizontal cell 170 | retinal cone cell: retinal cone cell 171 | retinal ganglion cell: retinal ganglion cell 172 | retinal rod cell: retinal rod cell 173 | smooth muscle cell: smooth muscle cell 174 | sncg GABAergic cortical interneuron: sncg GABAergic cortical interneuron 175 | sst GABAergic cortical interneuron: sst GABAergic cortical interneuron 176 | tracheal goblet cell: tracheal goblet cell 177 | tracheobronchial smooth muscle cell: tracheobronchial smooth muscle cell 178 | transitional stage B cell: transitional stage B cell 179 | type I pneumocyte: type I pneumocyte 180 | type II pneumocyte: type II pneumocyte 181 | vascular associated smooth muscle cell: vascular assoc. 
smooth muscle cell 182 | vein endothelial cell: vein endothelial cell 183 | vip GABAergic cortical interneuron: vip GABAergic cortical interneuron 184 | -------------------------------------------------------------------------------- /notebooks/model_evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/scTab/5ede7f2ba1f9618b86924f2ff587931de18f4ada/notebooks/model_evaluation/__init__.py -------------------------------------------------------------------------------- /notebooks/model_evaluation/predictions.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from os.path import join 3 | 4 | import anndata 5 | import dask.dataframe as dd 6 | import lightning.pytorch as pl 7 | import numpy as np 8 | import pandas as pd 9 | import xgboost as xgb 10 | from scipy.sparse import csr_matrix 11 | from cellnet.estimators import EstimatorCellTypeClassifier 12 | from cellnet.models import TabnetClassifier, LinearClassifier, MLPClassifier 13 | 14 | 15 | def get_count_matrix(ddf): 16 | x = ( 17 | ddf['X'] 18 | .map_partitions( 19 | lambda xx: pd.DataFrame(np.vstack(xx.tolist())), 20 | meta={col: 'f4' for col in range(19331)} 21 | ) 22 | .to_dask_array(lengths=[1024] * ddf.npartitions) 23 | ) 24 | 25 | return x 26 | 27 | 28 | def eval_tabnet(ckpts, data_path): 29 | estim = EstimatorCellTypeClassifier(data_path) 30 | estim.init_datamodule(batch_size=2048) 31 | estim.trainer = pl.Trainer(logger=[], accelerator='gpu', devices=1) 32 | 33 | preds = [] 34 | for ckpt in ckpts: 35 | estim.model = TabnetClassifier.load_from_checkpoint(ckpt, **estim.get_fixed_model_params('tabnet')) 36 | probas = estim.predict(estim.datamodule.test_dataloader()) 37 | preds.append(np.argmax(probas, axis=1)) 38 | gc.collect() 39 | 40 | return preds 41 | 42 | 43 | def eval_linear(ckpts, data_path): 44 | estim = EstimatorCellTypeClassifier(data_path) 45 | estim.init_datamodule(batch_size=2048) 46 | estim.trainer = pl.Trainer(logger=[], accelerator='gpu', devices=1) 47 | 48 | preds = [] 49 | for ckpt in ckpts: 50 | estim.model = LinearClassifier.load_from_checkpoint(ckpt, **estim.get_fixed_model_params('linear')) 51 | probas = estim.predict(estim.datamodule.test_dataloader()) 52 | preds.append(np.argmax(probas, axis=1)) 53 | gc.collect() 54 | 55 | return preds 56 | 57 | 58 | def eval_xgboost(ckpts, data_path): 59 | x_test = np.load(join(data_path, 'pca/x_pca_training_test_split_256.npy')) 60 | 61 | preds = [] 62 | for ckpt in ckpts: 63 | clf = xgb.XGBClassifier() 64 | clf.load_model(ckpt) 65 | clf.set_params(predictor='gpu_predictor') 66 | preds.append(clf.predict(x_test)) 67 | 68 | return preds 69 | 70 | 71 | def eval_mlp(ckpts, data_path): 72 | estim = EstimatorCellTypeClassifier(data_path) 73 | estim.init_datamodule(batch_size=2048) 74 | estim.trainer = pl.Trainer(logger=[], accelerator='gpu', devices=1) 75 | 76 | preds = [] 77 | for ckpt in ckpts: 78 | estim.model = MLPClassifier.load_from_checkpoint(ckpt, **estim.get_fixed_model_params('mlp')) 79 | probas = estim.predict(estim.datamodule.test_dataloader()) 80 | preds.append(np.argmax(probas, axis=1)) 81 | gc.collect() 82 | 83 | return preds 84 | 85 | 86 | def eval_celltypist(ckpts, data_path): 87 | import celltypist 88 | 89 | ddf = dd.read_parquet(join(data_path, 'test'), split_row_groups=True) 90 | x = get_count_matrix(ddf) 91 | var = pd.read_parquet(join(data_path, 'var.parquet')) 92 | 93 | preds = [] 94 | for ckpt in ckpts: 95 | 
preds_ckpt = [] 96 | # run this in batches to keep the memory footprint in check 97 | for i, idxs in enumerate(np.array_split(np.arange(x.shape[0]), 20)): 98 | # data is already normalized 99 | adata_test = anndata.AnnData( 100 | X=x[idxs, :].map_blocks(csr_matrix).compute(), 101 | var=var.set_index('feature_name') 102 | ) 103 | preds_ckpt.append(celltypist.annotate(adata_test, model=ckpt)) 104 | 105 | preds.append( 106 | np.concatenate([batch.predicted_labels.to_numpy().flatten() for batch in preds_ckpt]) 107 | ) 108 | 109 | return preds 110 | -------------------------------------------------------------------------------- /notebooks/model_evaluation/shortend_cell_types.yaml: -------------------------------------------------------------------------------- 1 | B cell: B cell 2 | Bergmann glial cell: Bergmann glial cell 3 | CD14-low, CD16-positive monocyte: CD14-low, CD16+ monocyte 4 | CD14-positive monocyte: CD14+ monocyte 5 | CD14-positive, CD16-negative classical monocyte: CD14+, CD16- clas. monocyte 6 | CD14-positive, CD16-positive monocyte: CD14+, CD16+ monocyte 7 | CD16-negative, CD56-bright natural killer cell, human: CD16-, CD56-bright NK cell 8 | CD16-positive, CD56-dim natural killer cell, human: CD16+, CD56-dim NK cell 9 | CD1c-positive myeloid dendritic cell: CD1c+ myeloid dendritic cell 10 | CD4-positive helper T cell: CD4+ helper T cell 11 | CD4-positive, alpha-beta T cell: CD4+, $\alpha$-$\beta$ T cell 12 | CD4-positive, alpha-beta cytotoxic T cell: CD4+, $\alpha$-$\beta$ cytotoxic T cell 13 | CD4-positive, alpha-beta memory T cell: CD4+, $\alpha$-$\beta$ mem. T cell 14 | CD8-alpha-alpha-positive, alpha-beta intraepithelial T cell: CD8-$\alpha$-$\alpha$+,\ $\alpha$-$\beta$ intraepith. T cell 15 | CD8-positive, alpha-beta T cell: CD8+, $\alpha$-$\beta$ T cell 16 | CD8-positive, alpha-beta cytotoxic T cell: CD8+, $\alpha$-$\beta$ cytotoxic T cell 17 | CD8-positive, alpha-beta memory T cell: CD8+, $\alpha$-$\beta$ mem. 
T cell 18 | IgA plasma cell: IgA plasma cell 19 | IgG plasma cell: IgG plasma cell 20 | L2/3-6 intratelencephalic projecting glutamatergic cortical neuron: L2/3-6 IT projecting neuron 21 | L6b glutamatergic cortical neuron: L6b glutamatergic cortical neuron 22 | Schwann cell: Schwann cell 23 | T cell: T cell 24 | T follicular helper cell: T follicular helper cell 25 | T-helper 17 cell: T-helper 17 cell 26 | T-helper 22 cell: T-helper 22 cell 27 | activated CD4-positive, alpha-beta T cell: activated CD4+, $\alpha$-$\beta$ T cell 28 | activated CD8-positive, alpha-beta T cell: activated CD8+, $\alpha$-$\beta$ T cell 29 | alpha-beta T cell: $\alpha$-$\beta$ T cell 30 | alternatively activated macrophage: alternatively activated macrophage 31 | alveolar macrophage: alveolar macrophage 32 | alveolar type 1 fibroblast cell: alveolar type 1 fibroblast 33 | alveolar type 2 fibroblast cell: alveolar type 2 fibroblast 34 | amacrine cell: amacrine cell 35 | astrocyte: astrocyte 36 | astrocyte of the cerebral cortex: astrocyte of the cerebral cortex 37 | basal cell: basal cell 38 | basal cell of epithelium of trachea: basal cell of epithelium of trachea 39 | blood vessel endothelial cell: blood vessel endothelial cell 40 | bronchus fibroblast of lung: bronchus fibroblast of lung 41 | capillary endothelial cell: capillary endothelial cell 42 | cardiac muscle cell: cardiac muscle cell 43 | cardiac neuron: cardiac neuron 44 | caudal ganglionic eminence derived GABAergic cortical interneuron: CGE cortical interneuron 45 | central memory CD4-positive, alpha-beta T cell: central mem. CD4+, $\alpha$-$\beta$ T cell 46 | central memory CD8-positive, alpha-beta T cell: central mem. CD8+, $\alpha$-$\beta$ T cell 47 | central nervous system macrophage: central nervous sys. macrophage 48 | chandelier pvalb GABAergic cortical interneuron: Chandelier PV 49 | ciliated columnar cell of tracheobronchial tree: ciliated columnar cell of tracheobronchial tree 50 | class switched memory B cell: class switched mem. B cell 51 | classical monocyte: classical monocyte 52 | club cell: club cell 53 | conventional dendritic cell: conventional dendritic cell 54 | corticothalamic-projecting glutamatergic cortical neuron: CT projecting neuron 55 | dendritic cell: dendritic cell 56 | double negative T regulatory cell: double negative T regulatory cell 57 | double negative thymocyte: double negative thymocyte 58 | double-positive, alpha-beta thymocyte: double+, $\alpha$-$\beta$ thymocyte 59 | effector CD8-positive, alpha-beta T cell: effector CD8+, $\alpha$-$\beta$ T cell 60 | effector memory CD4-positive, alpha-beta T cell: effector mem. CD4+, $\alpha$-$\beta$ T cell 61 | effector memory CD8-positive, alpha-beta T cell: effector mem. CD8+, $\alpha$-$\beta$ T cell 62 | effector memory CD8-positive, alpha-beta T cell, terminally differentiated: effector mem. CD8+, $\alpha$-$\beta$ T cell, term. diff. 63 | elicited macrophage: elicited macrophage 64 | endothelial cell: endothelial cell 65 | endothelial cell of artery: endothelial cell of artery 66 | endothelial cell of lymphatic vessel: endothelial cell of lymphatic vessel 67 | enteric smooth muscle cell: enteric smooth muscle cell 68 | enterocyte: enterocyte 69 | enteroendocrine cell: enteroendocrine cell 70 | ependymal cell: ependymal cell 71 | epithelial cell of proximal tubule: epith. 
cell of proximal tubule 72 | erythroblast: erythroblast 73 | erythrocyte: erythrocyte 74 | exhausted T cell: exhausted T cell 75 | fallopian tube secretory epithelial cell: fallopian tube secretory epith. cell 76 | fibroblast of cardiac tissue: fibroblast of cardiac tissue 77 | foveolar cell of stomach: foveolar cell of stomach 78 | gamma-delta T cell: $\gamma$-$\delta$ T cell 79 | glutamatergic neuron: glutamatergic neuron 80 | goblet cell: goblet cell 81 | granulocyte: granulocyte 82 | granulosa cell: granulosa cell 83 | hematopoietic stem cell: hematopoietic stem cell 84 | immature B cell: immature B cell 85 | immature innate lymphoid cell: immature innate lymphoid cell 86 | inflammatory macrophage: inflammatory macrophage 87 | innate lymphoid cell: innate lymphoid cell 88 | intermediate monocyte: intermediate monocyte 89 | intestine goblet cell: intestine goblet cell 90 | intraepithelial lymphocyte: intraepithelial lymphocyte 91 | keratinocyte: keratinocyte 92 | kidney collecting duct intercalated cell: kidney collecting duct intercalated cell 93 | kidney collecting duct principal cell: kidney collecting duct principal cell 94 | kidney connecting tubule epithelial cell: kidney connecting tubule epith. cell 95 | kidney distal convoluted tubule epithelial cell: kidney distal convoluted tubule epith. cell 96 | kidney interstitial fibroblast: kidney interstitial fibroblast 97 | kidney loop of Henle thick ascending limb epithelial cell: kidney loop Henle thick asc. limb epith. cell 98 | kidney loop of Henle thin ascending limb epithelial cell: kidney loop Henle thin asc. limb epith. cell 99 | kidney loop of Henle thin descending limb epithelial cell: kidney loop Henle thin des. limb epith. cell 100 | lamp5 GABAergic cortical interneuron: lamp5 GABAergic cortical interneuron 101 | leukocyte: leukocyte 102 | luminal epithelial cell of mammary gland: luminal epith. cell of mammary gland 103 | lung macrophage: lung macrophage 104 | lung pericyte: lung pericyte 105 | lymphocyte: lymphocyte 106 | lymphoid lineage restricted progenitor cell: lymphoid lineage restricted progenitor cell 107 | macrophage: macrophage 108 | mast cell: mast cell 109 | mature B cell: mature B cell 110 | mature NK T cell: mature NK T cell 111 | mature T cell: mature T cell 112 | mature alpha-beta T cell: mature $\alpha$-$\beta$ T cell 113 | mature gamma-delta T cell: mature gamma-delta T cell 114 | megakaryocyte-erythroid progenitor cell: megakaryocyte-erythroid progenitor cell 115 | memory B cell: mem. B cell 116 | memory T cell: mem. T cell 117 | mesothelial cell: mesothelial cell 118 | microglial cell: microglial cell 119 | monocyte: monocyte 120 | mucosal invariant T cell: mucosal invariant T cell 121 | mucous neck cell: mucous neck cell 122 | myoepithelial cell of mammary gland: myoepithelial cell of mammary gland 123 | naive B cell: naive B cell 124 | naive T cell: naive T cell 125 | naive thymus-derived CD4-positive, alpha-beta T cell: naive thym.-der. CD4+, $\alpha$-$\beta$ T cell 126 | naive thymus-derived CD8-positive, alpha-beta T cell: naive thym.-der. CD8+, $\alpha$-$\beta$ T cell 127 | nasal mucosa goblet cell: nasal mucosa goblet cell 128 | natural killer cell: NK cell 129 | near-projecting glutamatergic cortical neuron: near-project. glutamatergic cortical neuron 130 | neuron: neuron 131 | neutrophil: neutrophil 132 | non-classical monocyte: non-classical monocyte 133 | oligodendrocyte: oligodendrocyte 134 | oligodendrocyte precursor cell: oligodendrocyte pre. 
cell 135 | paneth cell: paneth cell 136 | pericyte: pericyte 137 | plasma cell: plasma cell 138 | plasmablast: plasmablast 139 | plasmacytoid dendritic cell: plasmacytoid dendritic cell 140 | platelet: platelet 141 | precursor B cell: pre. B cell 142 | pro-B cell: pro-B cell 143 | promonocyte: promonocyte 144 | pulmonary artery endothelial cell: pulmonary artery endothelial cell 145 | pvalb GABAergic cortical interneuron: pvalb GABAergic cortical interneuron 146 | regulatory T cell: regulatory T cell 147 | renal interstitial pericyte: renal interstitial pericyte 148 | respiratory basal cell: respiratory basal cell 149 | respiratory hillock cell: respiratory hillock cell 150 | retina horizontal cell: retina horizontal cell 151 | retinal cone cell: retinal cone cell 152 | retinal ganglion cell: retinal ganglion cell 153 | retinal rod cell: retinal rod cell 154 | smooth muscle cell: smooth muscle cell 155 | sncg GABAergic cortical interneuron: sncg GABAergic cortical interneuron 156 | sst GABAergic cortical interneuron: sst GABAergic cortical interneuron 157 | tracheal goblet cell: tracheal goblet cell 158 | tracheobronchial smooth muscle cell: tracheobronchial smooth muscle cell 159 | transitional stage B cell: transitional stage B cell 160 | type I pneumocyte: type I pneumocyte 161 | type II pneumocyte: type II pneumocyte 162 | vascular associated smooth muscle cell: vascular assoc. smooth muscle cell 163 | vein endothelial cell: vein endothelial cell 164 | vip GABAergic cortical interneuron: vip GABAergic cortical interneuron 165 | -------------------------------------------------------------------------------- /notebooks/model_evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from os.path import join 4 | from typing import Dict, List 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from numba import njit 9 | from sklearn.metrics import classification_report 10 | 11 | 12 | @njit 13 | def correct_labels(y_true: np.ndarray, y_pred: np.ndarray, child_matrix: np.ndarray): 14 | """ 15 | Update predictions. 16 | If prediction is actually a child node of the true label -> update prediction to true value. 
17 | 18 | E.g: Label='T cell' and prediction='CD8 positive T cell' -> update prediction to 'T cell' 19 | """ 20 | updated_predictions = y_pred.copy() 21 | # precalculate child nodes 22 | child_nodes = {i: np.where(child_matrix[i, :])[0] for i in range(child_matrix.shape[0])} 23 | 24 | for i, (pred, true_label) in enumerate(zip(y_pred, y_true)): 25 | if pred in child_nodes[true_label]: 26 | updated_predictions[i] = true_label 27 | else: 28 | updated_predictions[i] = pred 29 | 30 | return updated_predictions 31 | 32 | 33 | def get_best_ckpts(logs_path, versions): 34 | best_ckpts = [] 35 | 36 | for version in versions: 37 | # sort first -> in case both f1-score are the same -> take the one which was trained for fewer epochs 38 | files = sorted([file for file in os.listdir(join(logs_path, version, 'checkpoints')) if 'val_f1_macro' in file]) 39 | f1_scores = [float(re.search('val_f1_macro=(.*?).ckpt', file).group(1)) for file in files] 40 | best_ckpt = files[np.argmax(f1_scores)] 41 | best_ckpts.append(join(logs_path, version, 'checkpoints', best_ckpt)) 42 | 43 | return best_ckpts 44 | 45 | 46 | def macro_f1_per_group(y_true, y_pred, group_variable, grouping: Dict[str, List[str]]): 47 | assert len(y_true) == len(y_pred) == len(group_variable) 48 | groups = [] 49 | f1_macro = [] 50 | 51 | for group, group_assignments in grouping.items(): 52 | y_pred_group = y_pred[np.isin(group_variable, group_assignments).squeeze()] 53 | y_true_group = y_true[np.isin(group_variable, group_assignments).squeeze()] 54 | clf_report = pd.DataFrame(classification_report( 55 | y_true=y_true_group, 56 | y_pred=y_pred_group, 57 | labels=np.unique(y_true_group), 58 | output_dict=True, 59 | zero_division=0 60 | )).T 61 | groups.append(group) 62 | f1_macro.append(clf_report.loc['macro avg', 'f1-score']) 63 | 64 | return pd.DataFrame({'group': groups, 'f1_score': f1_macro}) 65 | 66 | 67 | """ 68 | Constants 69 | """ 70 | 71 | BIONETWORK_GROUPING = { 72 | 'adipose': ['adipose tissue'], 73 | 'breast': ['breast', 'exocrine gland'], 74 | 'eye': ['eye'], 75 | 'gut': [ 76 | 'digestive system', 'small intestine', 'colon', 'intestine', 'stomach', 'esophagus', 'large intestine', 77 | 'omentum', 'spleen', 'peritoneum', 'mucosa', 'abdomen', 'exocrine gland', 'endocrine gland' 78 | ], 79 | 'heart': ['heart', 'vasculature'], 80 | 'blood_and_immune': ['immune system', 'lymph node', 'blood', 'bone marrow', 'spleen'], 81 | 'kidney': ['kidney'], 82 | 'liver': ['liver'], 83 | 'lung': ['lung', 'respiratory system', 'pleural fluid'], 84 | 'musculoskeletal': ['musculature', 'bone marrow', 'vasculature'], 85 | 'nervous_system': ['brain', 'endocrine gland'], 86 | 'oral_and_craniofacial': ['tongue', 'nose', 'mucosa', 'exocrine gland', 'saliva', 'endocrine gland'], 87 | 'pancreas': ['pancreas', 'exocrine gland', 'endocrine gland'], 88 | 'reproduction': [ 89 | 'reproductive system', 'uterus', 'fallopian tube', 'ovary', 'prostate gland', 'endocrine gland', 90 | 'ascitic fluid', 'urinary bladder', 'peritoneum', 'bladder organ', 'placenta' 91 | ], 92 | 'skin': ['skin of body'] 93 | } 94 | -------------------------------------------------------------------------------- /notebooks/store_creation/02_create_train_val_test_split.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3151a0f1-2f37-4202-aecd-5fee9f9cc909", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q zarr" 11 | ] 12 | }, 13 | { 14 | 
"cell_type": "code", 15 | "execution_count": null, 16 | "id": "0d59d417-2348-486f-aaab-9d43b5418e2a", 17 | "metadata": { 18 | "tags": [] 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "import os\n", 23 | "from os.path import join\n", 24 | "\n", 25 | "import anndata\n", 26 | "import dask\n", 27 | "import dask.array as da\n", 28 | "import pandas as pd\n", 29 | "import numpy as np\n", 30 | "\n", 31 | "from scipy.sparse import csr_matrix\n", 32 | "\n", 33 | "dask.config.set(scheduler='threads');" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "5a962492-ae17-4612-926f-d5c14ae31246", 40 | "metadata": { 41 | "tags": [] 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "def read_X(path):\n", 46 | " return anndata.read_h5ad(path).X\n", 47 | "\n", 48 | "\n", 49 | "def read_obs(path):\n", 50 | " obs = anndata.read_h5ad(path, backed='r').obs\n", 51 | " obs['tech_sample'] = obs.dataset_id.astype(str) + '_' + obs.donor_id.astype(str)\n", 52 | " return obs\n", 53 | "\n", 54 | "\n", 55 | "def read_var(path):\n", 56 | " return anndata.read_h5ad(path, backed='r').var\n" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "1152c31d-90b6-42cf-8e53-c5ffe8caafed", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "id": "a46eb975-b4c8-4c41-af47-6c8873ed4221", 70 | "metadata": { 71 | "tags": [] 72 | }, 73 | "source": [ 74 | "## Training data" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "a67522c8-192d-4657-999f-13cc24ea0470", 81 | "metadata": { 82 | "tags": [] 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "BASE_PATH = '/mnt/dssfs02/cxg_census/h5ad_raw_2023_05_15'" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "id": "116db53b-3cb9-4b98-b181-e2334fe638d9", 92 | "metadata": {}, 93 | "source": [ 94 | "### Convert to zarr + DataFrame" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "bd764cbd-52c6-4833-8910-21e5ba14ef98", 101 | "metadata": { 102 | "tags": [] 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "files = [\n", 107 | " join(BASE_PATH, file) for file \n", 108 | " in sorted(os.listdir(BASE_PATH), key=lambda x: int(x.split('.')[0])) \n", 109 | " if file.endswith('.h5ad')\n", 110 | "]\n", 111 | "\n", 112 | "# read obs\n", 113 | "print('Loading obs...')\n", 114 | "obs = pd.concat([read_obs(file) for file in files]).reset_index(drop=True)\n", 115 | "for col in obs.columns:\n", 116 | " if obs[col].dtype == object:\n", 117 | " obs[col] = obs[col].astype('category')\n", 118 | " obs[col].cat.remove_unused_categories()\n", 119 | "# read var\n", 120 | "print('Loading var...')\n", 121 | "var = read_var(files[0])\n", 122 | "# read X\n", 123 | "print('Loading X...')\n", 124 | "split_lens = [len(split) for split in np.array_split(obs.soma_joinid.to_numpy(), 20)]\n", 125 | "X = da.concatenate([\n", 126 | " da.from_delayed(dask.delayed(read_X)(file), (split_len, len(var)), dtype='f4') \n", 127 | " for file, split_len in zip(files, split_lens)\n", 128 | "]).rechunk((32768, -1)).persist()\n" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "38c08a38-3153-4e9f-8ba6-744670230044", 135 | "metadata": { 136 | "tags": [] 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "X" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "d19cc3c8-9e20-437a-8605-d5350757180e", 146 | "metadata": {}, 147 | "source": [ 148 | "### Create train, 
val, test split" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "a5b43814-ea3e-4680-9cb6-767d5624fbb7", 155 | "metadata": { 156 | "tags": [] 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "from statistics import mode\n", 161 | "from scipy.sparse import csr_matrix" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "c2284694-ae8a-4afc-8858-33359c99a075", 168 | "metadata": { 169 | "tags": [] 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "from math import ceil\n", 174 | "\n", 175 | "\n", 176 | "def get_split(samples, val_split: float = 0.15, test_split: float = 0.15, seed=1):\n", 177 | " rng = np.random.default_rng(seed=seed)\n", 178 | "\n", 179 | " samples = np.array(samples)\n", 180 | " rng.shuffle(samples)\n", 181 | " n_samples = len(samples)\n", 182 | "\n", 183 | " n_samples_val = ceil(val_split * n_samples)\n", 184 | " n_samples_test = ceil(test_split * n_samples)\n", 185 | " n_samples_train = n_samples - n_samples_val - n_samples_test\n", 186 | "\n", 187 | " return {\n", 188 | " 'train': samples[:n_samples_train],\n", 189 | " 'val': samples[n_samples_train:(n_samples_train + n_samples_val)],\n", 190 | " 'test': samples[(n_samples_train + n_samples_val):]\n", 191 | " }\n", 192 | "\n", 193 | "\n", 194 | "def subset(splits, frac):\n", 195 | " assert 0. < frac <= 1.\n", 196 | " if frac == 1.:\n", 197 | " return splits\n", 198 | " else:\n", 199 | " return splits[:ceil(frac * len(splits))]\n" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "3650eeb4-9d73-45ab-b6f3-6e91f4a2d140", 206 | "metadata": { 207 | "tags": [] 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "# subsample_fracs: 0.15, 0.3, 0.5, 0.7, 1.\n", 212 | "SUBSAMPLE_FRAC = 1." 
213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "f14ea53f-9cbb-4359-aaa8-7acf51217577", 219 | "metadata": { 220 | "tags": [] 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "splits = {'train': [], 'val': [], 'test': []}\n", 225 | "tech_sample_splits = get_split(obs.tech_sample.unique().tolist())\n", 226 | "for x in ['train', 'val', 'test']:\n", 227 | " # tech_samples are already shuffled in the get_split method -> just subselect to subsample donors\n", 228 | " if x == 'train':\n", 229 | " # only subset training data set\n", 230 | " splits[x] = obs[obs.tech_sample.isin(subset(tech_sample_splits[x], SUBSAMPLE_FRAC))].index.to_numpy()\n", 231 | " else:\n", 232 | " splits[x] = obs[obs.tech_sample.isin(tech_sample_splits[x])].index.to_numpy()\n", 233 | "\n", 234 | "splits" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "id": "b59614c4-27a8-464e-809c-75860636b368", 241 | "metadata": { 242 | "tags": [] 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "assert len(np.intersect1d(splits['train'], splits['val'])) == 0\n", 247 | "assert len(np.intersect1d(splits['train'], splits['test'])) == 0\n", 248 | "assert len(np.intersect1d(splits['val'], splits['train'])) == 0\n", 249 | "assert len(np.intersect1d(splits['val'], splits['test'])) == 0" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "c3bd131f-32fb-4f7e-b8e0-516e0caf73a3", 256 | "metadata": { 257 | "tags": [] 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "print(f\"train: {len(obs.loc[splits['train'], :]):,} cells\")\n", 262 | "print(f\"val: {len(obs.loc[splits['val'], :]):,} cells\")\n", 263 | "print(f\"test: {len(obs.loc[splits['test'], :]):,} cells\")" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "id": "2ddcd10e-d67f-416d-ba65-9f51124e8459", 270 | "metadata": { 271 | "tags": [] 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "print(f\"train: {len(np.unique(obs.loc[splits['train'], 'cell_type']))} celltypes\")\n", 276 | "print(f\"val: {len(np.unique(obs.loc[splits['val'], 'cell_type']))} celltypes\")\n", 277 | "print(f\"test: {len(np.unique(obs.loc[splits['test'], 'cell_type']))} celltypes\")" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "id": "b5d72e17-7a52-4a93-9c7b-aaaeb6653779", 284 | "metadata": { 285 | "tags": [] 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "print(f\"train: {len(np.unique(obs.loc[splits['train'], 'tech_sample']))} donors\")\n", 290 | "print(f\"val: {len(np.unique(obs.loc[splits['val'], 'tech_sample']))} donors\")\n", 291 | "print(f\"test: {len(np.unique(obs.loc[splits['test'], 'tech_sample']))} donors\")" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "id": "4af5351c-bdf9-4f4f-8323-8975ac0945fc", 298 | "metadata": { 299 | "tags": [] 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "rng = np.random.default_rng(seed=1)\n", 304 | "\n", 305 | "splits['train'] = rng.permutation(splits['train'])\n", 306 | "splits['val'] = rng.permutation(splits['val'])\n", 307 | "splits['test'] = rng.permutation(splits['test'])\n", 308 | "\n", 309 | "splits" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "id": "ee8aa4f4-a72a-43b0-9872-dff02654894d", 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "id": 
"69bc24e5-c047-46d2-94f2-ea901cc68438", 323 | "metadata": {}, 324 | "source": [ 325 | "### Save data" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "id": "885bce5a-4389-411c-869c-c76ef126fdeb", 332 | "metadata": { 333 | "tags": [] 334 | }, 335 | "outputs": [], 336 | "source": [ 337 | "SAVE_PATH = f'/mnt/dssfs02/cxg_census/data_2023_05_15'\n", 338 | "if SUBSAMPLE_FRAC < 1.:\n", 339 | " SAVE_PATH = SAVE_PATH + f'_subsample_{round(SUBSAMPLE_FRAC * 100)}'\n", 340 | "\n", 341 | "CHUNK_SIZE = 16384" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "id": "7dfe0471-2295-4e31-b4bb-bfa996f43d79", 348 | "metadata": { 349 | "tags": [] 350 | }, 351 | "outputs": [], 352 | "source": [ 353 | "if SUBSAMPLE_FRAC < 1.:\n", 354 | " # only save train data for subset stores\n", 355 | " # val + test can be copyed later from non subset store\n", 356 | " splits_to_save = ['train']\n", 357 | "else:\n", 358 | " splits_to_save = ['train', 'val', 'test']\n", 359 | "\n", 360 | "\n", 361 | "for split, idxs in splits.items():\n", 362 | " if split in splits_to_save:\n", 363 | " # out-of-order indexing is on purpose here as we want to shuffle the data to break up data sets\n", 364 | " X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))\n", 365 | " obs_split = obs.loc[idxs, :]\n", 366 | "\n", 367 | " save_dir = join(SAVE_PATH, split)\n", 368 | " os.makedirs(save_dir)\n", 369 | "\n", 370 | " var.to_parquet(path=join(save_dir, 'var.parquet'), engine='pyarrow', compression='snappy', index=None)\n", 371 | " obs_split.to_parquet(path=join(save_dir, 'obs.parquet'), engine='pyarrow', compression='snappy', index=None)\n", 372 | " da.to_zarr(\n", 373 | " X_split.map_blocks(lambda xx: xx.toarray(), dtype='f4'),\n", 374 | " join(save_dir, 'zarr'),\n", 375 | " component='X',\n", 376 | " compute=True,\n", 377 | " compressor='default', \n", 378 | " order='C'\n", 379 | " )\n" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "id": "1455be19-10cc-4e0f-9133-9e9fe9814de6", 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [] 389 | } 390 | ], 391 | "metadata": { 392 | "kernelspec": { 393 | "display_name": "Python 3 (ipykernel)", 394 | "language": "python", 395 | "name": "python3" 396 | }, 397 | "language_info": { 398 | "codemirror_mode": { 399 | "name": "ipython", 400 | "version": 3 401 | }, 402 | "file_extension": ".py", 403 | "mimetype": "text/x-python", 404 | "name": "python", 405 | "nbconvert_exporter": "python", 406 | "pygments_lexer": "ipython3", 407 | "version": "3.8.10" 408 | } 409 | }, 410 | "nbformat": 4, 411 | "nbformat_minor": 5 412 | } 413 | -------------------------------------------------------------------------------- /notebooks/store_creation/03_write_store_merlin.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "8aaf8878-307d-49c6-a21b-90607770ab0d", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install zarr\n", 13 | "!pip install scipy" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "bb3fb8cb-5e98-4af3-ab6e-51c33852e95b", 20 | "metadata": { 21 | "tags": [] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import os\n", 26 | "\n", 27 | "import dask\n", 28 | "import dask.array as da\n", 29 | "import dask.dataframe as dd\n", 30 | "import pandas as pd\n", 31 | "import numpy as np\n", 32 | 
"import pyarrow as pa\n", 33 | "\n", 34 | "from os.path import join\n", 35 | "import shutil" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "4c75ee8b-373f-44a8-83c3-a8ec7f0b3fd4", 42 | "metadata": { 43 | "tags": [] 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "from dask.distributed import Client, LocalCluster\n", 48 | "\n", 49 | "\n", 50 | "cluster = LocalCluster(n_workers=5) # assume 20 cores on LRZ -> 5 workers with 4 threads each\n", 51 | "client = Client(cluster)\n", 52 | "client" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "09d5274a-e048-441a-ab06-d162770febe5", 59 | "metadata": { 60 | "tags": [] 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "NORMALIZATION = 'sf-log1p'\n", 65 | "\n", 66 | "# sf-log1p -> normalize to 10000 counts + log1p transform data\n", 67 | "# raw -> don't normalize data\n", 68 | "\n", 69 | "assert NORMALIZATION in ['sf-log1p', 'raw']" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "63a4aea0-0af9-47c6-9ec6-5743aeae6375", 76 | "metadata": { 77 | "tags": [] 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "from scipy.sparse import csc_matrix, csr_matrix, issparse\n", 82 | "from sklearn.utils import sparsefuncs\n", 83 | "\n", 84 | "\n", 85 | "def sf_normalize(X):\n", 86 | " X = X.copy()\n", 87 | " counts = np.array(X.sum(axis=1))\n", 88 | " # avoid zero devision error\n", 89 | " counts += counts == 0.\n", 90 | " # normalize to 10000. counts\n", 91 | " scaling_factor = 10000. / counts\n", 92 | "\n", 93 | " if issparse(X):\n", 94 | " sparsefuncs.inplace_row_scale(X, scaling_factor)\n", 95 | " else:\n", 96 | " np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)\n", 97 | "\n", 98 | " return X\n", 99 | "\n", 100 | "\n", 101 | "def sf_log1p_norm(x):\n", 102 | " x = sf_normalize(x)\n", 103 | " return np.log1p(x).astype('f4')\n", 104 | "\n", 105 | "\n", 106 | "def preprocess_count_matrix(x, normalization):\n", 107 | " if normalization == 'sf-log1p':\n", 108 | " return x.map_blocks(sf_log1p_norm, dtype='f4')\n", 109 | " elif normalization == 'raw':\n", 110 | " return x\n", 111 | " else:\n", 112 | " raise ValueError(f'NORMALIZATION has to be in [\"sf-log1p\", \"raw\"]')\n", 113 | "\n", 114 | "\n", 115 | "@dask.delayed\n", 116 | "def convert_to_dataframe(x, start, end):\n", 117 | " return pd.DataFrame(\n", 118 | " {'X': [arr.squeeze().astype('f4') for arr in np.vsplit(x, x.shape[0])]},\n", 119 | " index=pd.RangeIndex(start, end)\n", 120 | " )\n" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "5851124f-a8ab-4747-9cc5-1254088c196d", 126 | "metadata": { 127 | "tags": [] 128 | }, 129 | "source": [ 130 | "# Training data" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "ec9ef276-f4d0-40b4-bb68-c886633a66da", 137 | "metadata": { 138 | "tags": [] 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "DATA_PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15'\n", 143 | "OUT_PATH = f'/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_{NORMALIZATION}'\n", 144 | "\n", 145 | "os.makedirs(OUT_PATH)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "9e78af09-08ec-4c5b-9b8f-90e3cc85f9e5", 151 | "metadata": {}, 152 | "source": [ 153 | "## Copy var dataframe + norm data" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "6d13a0bc-c223-4760-9f7e-686844c255c8", 160 | "metadata": { 161 | "tags": [] 162 | }, 163 | "outputs": [], 164 | "source": [ 
165 | "shutil.copy(join(DATA_PATH, 'train', 'var.parquet'), join(OUT_PATH, 'var.parquet'));" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "c8401d2a-0dcb-4980-b712-0ce28c60679c", 172 | "metadata": { 173 | "tags": [] 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "# only run if NORMALIZATION == 'sf-quantile'\n", 178 | "!cp -r {join(DATA_PATH, 'norm')} {join(OUT_PATH, 'norm')}" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "id": "01ed62d7-f3ed-4470-8f16-786145f2dafe", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "id": "f279a9d4-4146-4b76-b276-bef7c29e1103", 192 | "metadata": {}, 193 | "source": [ 194 | "## Create lookup tables for categorical variables" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "id": "6a07e899-81ce-40cf-a5f9-5c3feb33f4ac", 201 | "metadata": { 202 | "tags": [] 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "from pandas import testing as tm" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "id": "ddf6b670-6ba1-4235-9fde-05947fa520bd", 213 | "metadata": { 214 | "tags": [] 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "obs_train = pd.read_parquet(join(DATA_PATH, 'train', 'obs.parquet')).reset_index(drop=True)\n", 219 | "obs_val = pd.read_parquet(join(DATA_PATH, 'val', 'obs.parquet')).reset_index(drop=True)\n", 220 | "obs_test = pd.read_parquet(join(DATA_PATH, 'test', 'obs.parquet')).reset_index(drop=True)\n", 221 | "\n", 222 | "obs = pd.concat([obs_train, obs_val, obs_test])" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "55d7697a-57f2-4615-949a-f4840896e2d1", 229 | "metadata": { 230 | "tags": [] 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "cols_train = obs_train.columns.tolist()\n", 235 | "assert cols_train == obs_val.columns.tolist()\n", 236 | "assert cols_train == obs_test.columns.tolist()" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "id": "a5f4941e-000a-4264-962b-9b2347c2a749", 243 | "metadata": { 244 | "tags": [] 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "for col in cols_train:\n", 249 | " if obs[col].dtype.name == 'category':\n", 250 | " obs[col] = obs[col].cat.remove_unused_categories()\n", 251 | "\n", 252 | "\n", 253 | "for col in cols_train:\n", 254 | " if obs[col].dtype.name == 'category':\n", 255 | " categories = list(obs[col].cat.categories)\n", 256 | " obs_train[col] = pd.Categorical(obs_train[col], categories, ordered=False)\n", 257 | " obs_val[col] = pd.Categorical(obs_val[col], categories, ordered=False)\n", 258 | " obs_test[col] = pd.Categorical(obs_test[col], categories, ordered=False)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "id": "ee8fedbb-0e17-4e48-8db5-3f9db7218a69", 265 | "metadata": { 266 | "tags": [] 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "lookup_path = join(OUT_PATH, 'categorical_lookup')\n", 271 | "os.makedirs(lookup_path)\n", 272 | "\n", 273 | "for col in cols_train:\n", 274 | " if obs_train[col].dtype.name == 'category':\n", 275 | " cats_train = pd.Series(dict(enumerate(obs_train[col].cat.categories))).to_frame().rename(columns={0: 'label'})\n", 276 | " cats_val = pd.Series(dict(enumerate(obs_val[col].cat.categories))).to_frame().rename(columns={0: 'label'})\n", 277 | " cats_test = 
pd.Series(dict(enumerate(obs_test[col].cat.categories))).to_frame().rename(columns={0: 'label'})\n", 278 | "\n", 279 | " tm.assert_frame_equal(cats_train, cats_val)\n", 280 | " tm.assert_frame_equal(cats_train, cats_test)\n", 281 | "\n", 282 | " cats_train.to_parquet(join(lookup_path, f'{col}.parquet'), index=True)\n" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "id": "398a8a0a-007b-4058-8ead-b90085bb412d", 289 | "metadata": { 290 | "tags": [] 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "# only use integer labels from now on\n", 295 | "for col in cols_train:\n", 296 | " if obs_train[col].dtype.name == 'category':\n", 297 | " obs_train[col] = obs_train[col].cat.codes.astype('i8')\n", 298 | " obs_val[col] = obs_val[col].cat.codes.astype('i8')\n", 299 | " obs_test[col] = obs_test[col].cat.codes.astype('i8')" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "id": "5268e753-5aed-460a-814d-83fa9e2d17c1", 306 | "metadata": { 307 | "tags": [] 308 | }, 309 | "outputs": [], 310 | "source": [ 311 | "obs_dict = {'train': obs_train, 'val': obs_val, 'test': obs_test}" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "id": "cfcbbc45-4c01-4f36-aeea-98c2ad0689dc", 318 | "metadata": { 319 | "tags": [] 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "from sklearn.utils.class_weight import compute_class_weight\n", 324 | "\n", 325 | "# calculate and save class weights\n", 326 | "class_weights = compute_class_weight('balanced', classes=np.unique(obs_train['cell_type']), y=obs_train['cell_type'])\n", 327 | "\n", 328 | "with open(join(OUT_PATH, 'class_weights.npy'), 'wb') as f:\n", 329 | " np.save(f, class_weights)" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "id": "13bfdecc-e1ec-42a3-a134-4c0c59064ce4", 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "id": "32a708b7-fde4-4f3c-b9d0-c1ba820a4a5c", 343 | "metadata": {}, 344 | "source": [ 345 | "## Write store" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "id": "238f9e13-ce0d-498e-b832-fa824737c625", 352 | "metadata": { 353 | "tags": [] 354 | }, 355 | "outputs": [], 356 | "source": [ 357 | "CHUNK_SIZE = 32768\n", 358 | "ROW_GROUP_SIZE = 1024\n", 359 | "\n", 360 | "\n", 361 | "for split in ['train', 'val', 'test']:\n", 362 | " X = preprocess_count_matrix(da.from_zarr(join(DATA_PATH, split, 'zarr'), 'X'), NORMALIZATION)\n", 363 | " obs_ = obs_dict[split]\n", 364 | " # cut off samples that all row groups are full\n", 365 | " n_samples = X.shape[0]\n", 366 | " n_samples = (n_samples // ROW_GROUP_SIZE) * ROW_GROUP_SIZE\n", 367 | " X = X[:n_samples].rechunk((CHUNK_SIZE, -1))\n", 368 | " obs_ = obs_.iloc[:n_samples].copy()\n", 369 | " # add an index column to identifiy each sample\n", 370 | " obs_['idx'] = np.arange(len(obs_), dtype='i8')\n", 371 | " start_index = [0] + list(np.cumsum(X.chunks[0]))[:-1]\n", 372 | " end_index = list(np.cumsum(X.chunks[0]))\n", 373 | " # calculate divisons for dask dataframe\n", 374 | " divisions = [0] + list(np.cumsum(X.chunks[0]))\n", 375 | " divisions[-1] = divisions[-1] - 1\n", 376 | " ddf = dd.from_delayed(\n", 377 | " [\n", 378 | " convert_to_dataframe(arr, start, end) for arr, start, end in \n", 379 | " zip(X.to_delayed().flatten().tolist(), start_index, end_index)\n", 380 | " ],\n", 381 | " divisions=divisions\n", 382 | " )\n", 383 | " 
obs_dask = dd.from_pandas(obs_, chunksize=CHUNK_SIZE)\n", 384 | " assert np.allclose(ddf.divisions, obs_dask.divisions)\n", 385 | " ddf = dd.multi.concat([ddf, obs_dask], axis=1)\n", 386 | "\n", 387 | " schema = pa.schema([\n", 388 | " ('X', pa.list_(pa.float32())),\n", 389 | " ('soma_joinid', pa.int64()),\n", 390 | " ('is_primary_data', pa.bool_()),\n", 391 | " ('dataset_id', pa.int64()),\n", 392 | " ('donor_id', pa.int64()),\n", 393 | " ('assay', pa.int64()),\n", 394 | " ('cell_type', pa.int64()),\n", 395 | " ('development_stage', pa.int64()),\n", 396 | " ('disease', pa.int64()),\n", 397 | " ('tissue', pa.int64()),\n", 398 | " ('tissue_general', pa.int64()),\n", 399 | " ('tech_sample', pa.int64()),\n", 400 | " ('idx', pa.int64()),\n", 401 | " ])\n", 402 | " print(f'{split}: {X.shape[0]} cells')\n", 403 | " ddf.to_parquet(\n", 404 | " join(OUT_PATH, split), \n", 405 | " engine='pyarrow',\n", 406 | " schema=schema,\n", 407 | " write_metadata_file=True,\n", 408 | " row_group_size=ROW_GROUP_SIZE\n", 409 | " )\n", 410 | " \n", 411 | " # free up memory\n", 412 | " client.restart()\n" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "id": "46c1696f-97d5-401d-b43d-36dd5b77bc48", 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [] 422 | } 423 | ], 424 | "metadata": { 425 | "kernelspec": { 426 | "display_name": "Python 3 (ipykernel)", 427 | "language": "python", 428 | "name": "python3" 429 | }, 430 | "language_info": { 431 | "codemirror_mode": { 432 | "name": "ipython", 433 | "version": 3 434 | }, 435 | "file_extension": ".py", 436 | "mimetype": "text/x-python", 437 | "name": "python", 438 | "nbconvert_exporter": "python", 439 | "pygments_lexer": "ipython3", 440 | "version": "3.8.10" 441 | } 442 | }, 443 | "nbformat": 4, 444 | "nbformat_minor": 5 445 | } 446 | -------------------------------------------------------------------------------- /notebooks/store_creation/04_create_hierarchy_matrices.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "c4f0f77e-761c-4eb5-9c55-6eb8ea42dbf2", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install -q SPARQLWrapper" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "19a8fe66-f1f8-4505-a541-e70531372762", 19 | "metadata": { 20 | "tags": [] 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "!pip install -e /dss/dsshome1/04/di93zer/git/cellnet --no-deps" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "efec7d77-fd77-4224-becf-6b51dba6adc4", 31 | "metadata": { 32 | "tags": [] 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "import os\n", 37 | "from os.path import join\n", 38 | "\n", 39 | "import pandas as pd\n", 40 | "import numpy as np" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "2b02397a-ad12-4797-9e4f-58e3243f5856", 47 | "metadata": { 48 | "tags": [] 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "39bc8541-9f3c-4567-b8a3-93ce7dfb8675", 58 | "metadata": {}, 59 | "source": [ 60 | "# Compute lookup matrices " 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "1aeafb86-ead1-4b93-aa3f-80dfb880b663", 67 | "metadata": { 68 | "tags": [] 69 | }, 70 | "outputs": 
[], 71 | "source": [ 72 | "cell_type_mapping = pd.read_parquet(join(DATA_PATH, 'categorical_lookup/cell_type.parquet'))\n", 73 | "\n", 74 | "inverse_mapping = (\n", 75 | " cell_type_mapping\n", 76 | " .assign(idx=range(len(cell_type_mapping)))\n", 77 | " .set_index('label', drop=True)\n", 78 | ")\n", 79 | "inverse_mapping.head()" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "bdab6fee-d5dd-4218-8700-12c04942c3cd", 86 | "metadata": { 87 | "tags": [] 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "from cellnet.utils.cell_ontology import retrieve_child_nodes_from_ubergraph\n", 92 | "\n", 93 | "\n", 94 | "celltypes = cell_type_mapping.label.tolist()\n", 95 | "child_nodes_dict = {}\n", 96 | "for k, v in retrieve_child_nodes_from_ubergraph(celltypes).items():\n", 97 | " child_nodes_dict[k] = [elem for elem in v if elem in celltypes]" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "667fd053-155c-40e9-9b54-acf0b0bbeb22", 104 | "metadata": { 105 | "tags": [] 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "children_idx = []\n", 110 | "\n", 111 | "for cell_type in cell_type_mapping.label:\n", 112 | " child_nodes = child_nodes_dict[cell_type]\n", 113 | " children_idx.append(inverse_mapping.loc[child_nodes].idx.sort_values().tolist())\n", 114 | "\n", 115 | "cell_type_mapping['children'] = children_idx" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "a05dc6e8-7efb-4f52-903d-e9a0ce9647a0", 122 | "metadata": { 123 | "tags": [] 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "os.makedirs(join(DATA_PATH, 'cell_type_hierarchy'), exist_ok=True)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "46a01942-506b-4a34-ae11-d249f1fcf2ab", 134 | "metadata": { 135 | "tags": [] 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "child_matrix = np.eye(len(cell_type_mapping))\n", 140 | "\n", 141 | "for i, child_nodes in enumerate(cell_type_mapping.children):\n", 142 | " child_matrix[i, child_nodes] = 1.\n", 143 | " \n", 144 | "with open(join(DATA_PATH, 'cell_type_hierarchy/child_matrix.npy'), 'wb') as f:\n", 145 | " np.save(f, child_matrix)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "94d8b5d4-e17a-4f52-9813-540ff2aeae20", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "id": "56c3c7b2-2a4d-482a-9922-6d20a3a287b7", 159 | "metadata": {}, 160 | "source": [ 161 | "# Sanity check lookup matrices" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "edef0da1-0a58-4e1c-b785-7d9af54f9486", 168 | "metadata": { 169 | "tags": [] 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "cell_type_mapping.loc[np.where(child_matrix[0, :] == 1.)[0]].label.tolist()" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "2cf96a2d-c451-4a6d-bf55-3654d9658788", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "Python 3 (ipykernel)", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": 
"ipython3", 201 | "version": "3.8.10" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 5 206 | } 207 | -------------------------------------------------------------------------------- /notebooks/store_creation/05_compute_pca.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "ef2525df-6418-493c-bfed-cecfc934a467", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import os\n", 13 | "from os.path import join\n", 14 | "\n", 15 | "import dask.dataframe as dd\n", 16 | "import dask.array as da\n", 17 | "import numpy as np\n", 18 | "import pandas as pd\n", 19 | "\n", 20 | "from dask_ml.decomposition import IncrementalPCA" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "e91c67be-8ca7-4496-94d8-b6a71299e560", 27 | "metadata": { 28 | "tags": [] 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "417a896b-5667-43fc-a018-c34acf9ea27d", 39 | "metadata": { 40 | "tags": [] 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "def get_count_matrix(ddf):\n", 45 | " x = (\n", 46 | " ddf['X']\n", 47 | " .map_partitions(\n", 48 | " lambda xx: pd.DataFrame(np.vstack(xx.tolist())), \n", 49 | " meta={col: 'f4' for col in range(19331)}\n", 50 | " )\n", 51 | " .to_dask_array(lengths=[1024] * ddf.npartitions)\n", 52 | " )\n", 53 | "\n", 54 | " return x\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "id": "f0f99d72-6183-412b-b06b-423c53f5b06a", 60 | "metadata": {}, 61 | "source": [ 62 | "# Compute PCA for visualization" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "43fb96d0-7155-4e7d-8428-2757183474a0", 69 | "metadata": { 70 | "tags": [] 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "os.makedirs(join(PATH, 'pca'), exist_ok=True)\n", 75 | "\n", 76 | "\n", 77 | "n_comps = 50\n", 78 | "\n", 79 | "\n", 80 | "for split in ['test', 'val', 'train']:\n", 81 | " x = get_count_matrix(dd.read_parquet(join(PATH, split), split_row_groups=True))\n", 82 | " pca = IncrementalPCA(n_components=n_comps, iterated_power=3)\n", 83 | " x_pca = da.compute(pca.fit_transform(x))[0]\n", 84 | " with open(join(PATH, 'pca', f'x_pca_{split}_{n_comps}.npy'), 'wb') as f:\n", 85 | " np.save(f, x_pca)\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "dea42c83-eabf-4a4a-886a-f7de38411343", 91 | "metadata": {}, 92 | "source": [ 93 | "# Compute PCA for model training" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "fed1c4a0-5c42-4911-9e04-d55ae4bca680", 100 | "metadata": { 101 | "tags": [] 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "os.makedirs(join(PATH, 'pca'), exist_ok=True)\n", 106 | "\n", 107 | "\n", 108 | "n_comps = 256\n", 109 | "\n", 110 | "\n", 111 | "x_train = get_count_matrix(dd.read_parquet(join(PATH, 'train'), split_row_groups=True))\n", 112 | "x_val = get_count_matrix(dd.read_parquet(join(PATH, 'val'), split_row_groups=True))\n", 113 | "x_test = get_count_matrix(dd.read_parquet(join(PATH, 'test'), split_row_groups=True))\n", 114 | "\n", 115 | "\n", 116 | "pca = IncrementalPCA(n_components=n_comps, iterated_power=3)\n", 117 | "x_pca_train, x_pca_val, x_pca_test = da.compute(\n", 118 | " [pca.fit_transform(x_train), pca.transform(x_val), 
pca.transform(x_test)]\n", 119 | ")[0]\n", 120 | "\n", 121 | "\n", 122 | "with open(join(PATH, f'pca/x_pca_training_train_split_{n_comps}.npy'), 'wb') as f:\n", 123 | " np.save(f, x_pca_train)\n", 124 | "with open(join(PATH, f'pca/x_pca_training_val_split_{n_comps}.npy'), 'wb') as f:\n", 125 | " np.save(f, x_pca_val)\n", 126 | "with open(join(PATH, f'pca/x_pca_training_test_split_{n_comps}.npy'), 'wb') as f:\n", 127 | " np.save(f, x_pca_test)\n" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "b997843e-7ba9-4b77-8303-d8c0ab0c5525", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [] 137 | } 138 | ], 139 | "metadata": { 140 | "kernelspec": { 141 | "display_name": "Python 3 (ipykernel)", 142 | "language": "python", 143 | "name": "python3" 144 | }, 145 | "language_info": { 146 | "codemirror_mode": { 147 | "name": "ipython", 148 | "version": 3 149 | }, 150 | "file_extension": ".py", 151 | "mimetype": "text/x-python", 152 | "name": "python", 153 | "nbconvert_exporter": "python", 154 | "pygments_lexer": "ipython3", 155 | "version": "3.8.10" 156 | } 157 | }, 158 | "nbformat": 4, 159 | "nbformat_minor": 5 160 | } 161 | -------------------------------------------------------------------------------- /notebooks/store_creation/features.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/scTab/5ede7f2ba1f9618b86924f2ff587931de18f4ada/notebooks/store_creation/features.parquet -------------------------------------------------------------------------------- /notebooks/store_creation/subsetted_data/subset_to_lung_only.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "965e4d64-71ac-4f23-9db2-6e742ece1da7", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install -q zarr" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "64ac080d-112a-428d-9566-dd97a5ecf98a", 19 | "metadata": { 20 | "tags": [] 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import os\n", 25 | "from os.path import join\n", 26 | "\n", 27 | "import dask\n", 28 | "import dask.array as da\n", 29 | "import pandas as pd\n", 30 | "import numpy as np\n", 31 | "import numba\n", 32 | "\n", 33 | "from numba.typed import Dict\n", 34 | "from numba import prange" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "ed416b30-357c-4e2e-8b0e-aa0cea2fb52f", 41 | "metadata": { 42 | "tags": [] 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15'" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "fe375488-7aeb-4045-8220-810dd2ae57c5", 52 | "metadata": {}, 53 | "source": [ 54 | "# Get idxs for subsampling" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "0d7ae22f-901f-4530-a47a-4b24eafc195d", 61 | "metadata": { 62 | "tags": [] 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "obs_train = pd.read_parquet(join(PATH, 'train/obs.parquet')).reset_index(drop=True)\n", 67 | "x_train = da.from_zarr(join(PATH, 'train/zarr'), component='X')\n", 68 | "\n", 69 | "obs_val = pd.read_parquet(join(PATH, 'val/obs.parquet')).reset_index(drop=True)\n", 70 | "x_val = da.from_zarr(join(PATH, 'val/zarr'), component='X')\n", 71 | "\n", 72 | "obs_test = pd.read_parquet(join(PATH, 
'test/obs.parquet')).reset_index(drop=True)\n", 73 | "x_test = da.from_zarr(join(PATH, 'test/zarr'), component='X')" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "d7dc528c-971b-40c9-b6cb-aa98507a5e11", 80 | "metadata": { 81 | "tags": [] 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "var = pd.read_parquet(join(PATH, 'train/var.parquet'))" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "399386bf-69a5-49f5-91ca-c1f716da42ca", 92 | "metadata": { 93 | "tags": [] 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "for col in obs_train.columns:\n", 98 | " if obs_train[col].dtype.name == 'category':\n", 99 | " obs_train[col] = obs_train[col].cat.remove_unused_categories()\n", 100 | "\n", 101 | "\n", 102 | "for col in obs_val.columns:\n", 103 | " if obs_val[col].dtype.name == 'category':\n", 104 | " obs_val[col] = obs_val[col].cat.remove_unused_categories()\n", 105 | " \n", 106 | "\n", 107 | "for col in obs_test.columns:\n", 108 | " if obs_test[col].dtype.name == 'category':\n", 109 | " obs_test[col] = obs_test[col].cat.remove_unused_categories()\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "3a22890a-a765-49b2-91dc-f3741f9c8cc4", 116 | "metadata": { 117 | "tags": [] 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "rng = np.random.default_rng(seed=1)\n", 122 | "\n", 123 | "subset_idxs = {}\n", 124 | "\n", 125 | "\n", 126 | "for split, obs in [('train', obs_train), ('val', obs_val), ('test', obs_test)]:\n", 127 | " idx_subset = obs[obs.tissue_general == 'lung'].index.to_numpy()\n", 128 | " rng.shuffle(idx_subset)\n", 129 | " subset_idxs[split] = idx_subset\n" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "403b29c0-3959-49da-b18e-c7fffec55400", 136 | "metadata": { 137 | "tags": [] 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "subset_idxs" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "id": "82b31021-3af7-4e38-97f6-36aa52f29dd1", 148 | "metadata": { 149 | "tags": [] 150 | }, 151 | "outputs": [], 152 | "source": [] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "id": "bdb667bf-2a00-4b44-837d-5720f4727e3b", 157 | "metadata": {}, 158 | "source": [ 159 | "# Store balanced data to disk" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "id": "ded15de8-8238-4bf0-ba19-a45de784d794", 166 | "metadata": { 167 | "tags": [] 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "SAVE_PATH = f'/mnt/dssfs02/cxg_census/data_2023_05_15_lung_only'\n", 172 | "CHUNK_SIZE = 16384" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "1e23192f-3d9b-4f5f-bb71-ddf10b6fe63a", 179 | "metadata": { 180 | "tags": [] 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "for split, x, obs in [\n", 185 | " ('train', x_train, obs_train),\n", 186 | " ('val', x_val, obs_val),\n", 187 | " ('test', x_test, obs_test)\n", 188 | "]:\n", 189 | " # out-of-order indexing is on purpose here as we want to shuffle the data to break up data sets\n", 190 | " X_split = x[subset_idxs[split], :].rechunk((CHUNK_SIZE, -1))\n", 191 | " obs_split = obs.iloc[subset_idxs[split], :]\n", 192 | "\n", 193 | " save_dir = join(SAVE_PATH, split)\n", 194 | " os.makedirs(save_dir)\n", 195 | "\n", 196 | " var.to_parquet(path=join(save_dir, 'var.parquet'), engine='pyarrow', compression='snappy', index=None)\n", 197 | " 
obs_split.to_parquet(path=join(save_dir, 'obs.parquet'), engine='pyarrow', compression='snappy', index=None)\n", 198 | " da.to_zarr(\n", 199 | " X_split,\n", 200 | " join(save_dir, 'zarr'),\n", 201 | " component='X',\n", 202 | " compute=True,\n", 203 | " compressor='default', \n", 204 | " order='C'\n", 205 | " )\n" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "e964f8d5-489e-40b5-a382-3019f72a83fb", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [] 215 | } 216 | ], 217 | "metadata": { 218 | "kernelspec": { 219 | "display_name": "Python 3 (ipykernel)", 220 | "language": "python", 221 | "name": "python3" 222 | }, 223 | "language_info": { 224 | "codemirror_mode": { 225 | "name": "ipython", 226 | "version": 3 227 | }, 228 | "file_extension": ".py", 229 | "mimetype": "text/x-python", 230 | "name": "python", 231 | "nbconvert_exporter": "python", 232 | "pygments_lexer": "ipython3", 233 | "version": "3.8.10" 234 | } 235 | }, 236 | "nbformat": 4, 237 | "nbformat_minor": 5 238 | } 239 | -------------------------------------------------------------------------------- /notebooks/store_creation/subsetted_data/write_store_merlin_lung_only.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "8aaf8878-307d-49c6-a21b-90607770ab0d", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install zarr\n", 13 | "!pip install scipy" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "bb3fb8cb-5e98-4af3-ab6e-51c33852e95b", 20 | "metadata": { 21 | "tags": [] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import os\n", 26 | "\n", 27 | "import dask\n", 28 | "import dask.array as da\n", 29 | "import dask.dataframe as dd\n", 30 | "import pandas as pd\n", 31 | "import numpy as np\n", 32 | "import pyarrow as pa\n", 33 | "\n", 34 | "from os.path import join\n", 35 | "import shutil" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "4c75ee8b-373f-44a8-83c3-a8ec7f0b3fd4", 42 | "metadata": { 43 | "tags": [] 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "from dask.distributed import Client, LocalCluster\n", 48 | "\n", 49 | "\n", 50 | "cluster = LocalCluster(n_workers=5) # assume 20 cores on LRZ -> 5 workers with 4 threads each\n", 51 | "client = Client(cluster)\n", 52 | "client" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "09d5274a-e048-441a-ab06-d162770febe5", 59 | "metadata": { 60 | "tags": [] 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "NORMALIZATION = 'sf-log1p'\n", 65 | "\n", 66 | "# sf-log1p -> normalize to 10000 counts + log1p transform data\n", 67 | "# raw -> don't normalize data\n", 68 | "\n", 69 | "assert NORMALIZATION in ['sf-log1p', 'raw']" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "63a4aea0-0af9-47c6-9ec6-5743aeae6375", 76 | "metadata": { 77 | "tags": [] 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "from scipy.sparse import csc_matrix, csr_matrix, issparse\n", 82 | "from sklearn.utils import sparsefuncs\n", 83 | "\n", 84 | "\n", 85 | "def sf_normalize(X):\n", 86 | " X = X.copy()\n", 87 | " counts = np.array(X.sum(axis=1))\n", 88 | " # avoid zero division error\n", 89 | " counts += counts == 0.\n", 90 | " # normalize to 10000. counts\n", 91 | " scaling_factor = 10000.
/ counts\n", 92 | "\n", 93 | " if issparse(X):\n", 94 | " sparsefuncs.inplace_row_scale(X, scaling_factor)\n", 95 | " else:\n", 96 | " np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)\n", 97 | "\n", 98 | " return X\n", 99 | "\n", 100 | "\n", 101 | "def sf_log1p_norm(x):\n", 102 | " x = sf_normalize(x)\n", 103 | " return np.log1p(x).astype('f4')\n", 104 | "\n", 105 | "\n", 106 | "def preprocess_count_matrix(x, normalization):\n", 107 | " if normalization == 'sf-log1p':\n", 108 | " return x.map_blocks(sf_log1p_norm, dtype='f4')\n", 109 | " elif normalization == 'raw':\n", 110 | " return x\n", 111 | " else:\n", 112 | " raise ValueError(f'NORMALIZATION has to be in [\"sf-log1p\", \"raw\"]')\n", 113 | "\n", 114 | "\n", 115 | "@dask.delayed\n", 116 | "def convert_to_dataframe(x, start, end):\n", 117 | " return pd.DataFrame(\n", 118 | " {'X': [arr.squeeze().astype('f4') for arr in np.vsplit(x, x.shape[0])]},\n", 119 | " index=pd.RangeIndex(start, end)\n", 120 | " )\n" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "5851124f-a8ab-4747-9cc5-1254088c196d", 126 | "metadata": { 127 | "tags": [] 128 | }, 129 | "source": [ 130 | "# Training data" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "ec9ef276-f4d0-40b4-bb68-c886633a66da", 137 | "metadata": { 138 | "tags": [] 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "DATA_PATH = f'/mnt/dssfs02/cxg_census/data_2023_05_15_lung_only_extended'\n", 143 | "OUT_PATH = f'/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_{NORMALIZATION}_lung_only_extended'\n", 144 | "\n", 145 | "DATA_PATH_FULL = '/mnt/dssfs02/cxg_census/data_2023_05_15'\n", 146 | "OUT_PATH_FULL = f'/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_{NORMALIZATION}'\n", 147 | "\n", 148 | "os.makedirs(OUT_PATH)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "771c4c00-f1f7-4ca4-9d49-52444eca7e4e", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "id": "9e78af09-08ec-4c5b-9b8f-90e3cc85f9e5", 162 | "metadata": {}, 163 | "source": [ 164 | "## Copy var dataframe + norm data" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "id": "6d13a0bc-c223-4760-9f7e-686844c255c8", 171 | "metadata": { 172 | "tags": [] 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "shutil.copy(join(DATA_PATH, 'train', 'var.parquet'), join(OUT_PATH, 'var.parquet'));" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "id": "01ed62d7-f3ed-4470-8f16-786145f2dafe", 183 | "metadata": { 184 | "tags": [] 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "# copy categorical mapping from full data set -> use same mapping for subset stores\n", 189 | "!cp -r {join(OUT_PATH_FULL, 'categorical_lookup')} {join(OUT_PATH, 'categorical_lookup')}" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "id": "d526c191-4988-4a70-aa1a-a30910e8ff8c", 196 | "metadata": { 197 | "tags": [] 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "# copy cell_type_hierarchy matrices from full data set\n", 202 | "!cp -r {join(OUT_PATH_FULL, 'cell_type_hierarchy')} {join(OUT_PATH, 'cell_type_hierarchy')}" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "48a4f936-4e41-45a8-9e57-d0d2e1ec7aff", 209 | "metadata": { 210 | "tags": [] 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "# copy augmentations from
full data set\n", 215 | "!cp {join(OUT_PATH_FULL, 'augmentations.npy')} {join(OUT_PATH, 'augmentations.npy')}" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "id": "bc8d6cd1-38db-4aa6-bde3-a5ed85760e8d", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "id": "f279a9d4-4146-4b76-b276-bef7c29e1103", 229 | "metadata": {}, 230 | "source": [ 231 | "## Create lookup tables for categorical variables" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "6a07e899-81ce-40cf-a5f9-5c3feb33f4ac", 238 | "metadata": { 239 | "tags": [] 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "from pandas import testing as tm" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "id": "412af217-74cf-405a-bb10-5c5ebb336ffe", 250 | "metadata": { 251 | "tags": [] 252 | }, 253 | "outputs": [], 254 | "source": [ 255 | "obs_train = pd.read_parquet(join(DATA_PATH, 'train', 'obs.parquet')).reset_index(drop=True)\n", 256 | "obs_val = pd.read_parquet(join(DATA_PATH, 'val', 'obs.parquet')).reset_index(drop=True)\n", 257 | "obs_test = pd.read_parquet(join(DATA_PATH, 'test', 'obs.parquet')).reset_index(drop=True)\n", 258 | "\n", 259 | "cols_train = obs_train.columns.tolist()\n", 260 | "\n", 261 | "assert cols_train == obs_val.columns.tolist()\n", 262 | "assert cols_train == obs_test.columns.tolist()" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "id": "e1cd800c-706a-4126-91b0-eeaa986557e3", 269 | "metadata": { 270 | "tags": [] 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "# load category mapping from full data set\n", 275 | "category_mapping = {}\n", 276 | "for col in cols_train:\n", 277 | " if obs_train[col].dtype.name == 'category':\n", 278 | " category_mapping[col] = pd.read_parquet(join(OUT_PATH, 'categorical_lookup', f'{col}.parquet'))\n", 279 | "\n", 280 | "# use same mapping as used in the full data set\n", 281 | "for col in cols_train:\n", 282 | " if obs_train[col].dtype.name == 'category':\n", 283 | " categories = category_mapping[col].label.tolist()\n", 284 | " obs_train[col] = pd.Categorical(obs_train[col], categories, ordered=False)\n", 285 | " obs_val[col] = pd.Categorical(obs_val[col], categories, ordered=False)\n", 286 | " obs_test[col] = pd.Categorical(obs_test[col], categories, ordered=False)\n" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "id": "f6d2d45e-1737-46ce-a9eb-41656e03ffa0", 293 | "metadata": { 294 | "tags": [] 295 | }, 296 | "outputs": [], 297 | "source": [ 298 | "# only use integer labels from now on\n", 299 | "for col in cols_train:\n", 300 | " if obs_train[col].dtype.name == 'category':\n", 301 | " obs_train[col] = obs_train[col].cat.codes.astype('i8')\n", 302 | " obs_val[col] = obs_val[col].cat.codes.astype('i8')\n", 303 | " obs_test[col] = obs_test[col].cat.codes.astype('i8')\n" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "id": "09a63581-b53f-4e9a-bb7e-6508f5587ee8", 310 | "metadata": { 311 | "tags": [] 312 | }, 313 | "outputs": [], 314 | "source": [ 315 | "obs_dict = {'train': obs_train, 'val': obs_val, 'test': obs_test}" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "id": "368ee607-974a-4208-8393-287c62aed384", 322 | "metadata": { 323 | "tags": [] 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "from 
sklearn.utils.class_weight import compute_class_weight\n", 328 | "\n", 329 | "# calculate and save class weights\n", 330 | "class_weights = np.zeros(len(category_mapping['cell_type']))\n", 331 | "class_weights[np.unique(obs_train['cell_type'])] = compute_class_weight(\n", 332 | " 'balanced', \n", 333 | " classes=np.unique(obs_train['cell_type']), \n", 334 | " y=obs_train['cell_type']\n", 335 | ")\n", 336 | "\n", 337 | "with open(join(OUT_PATH, 'class_weights.npy'), 'wb') as f:\n", 338 | " np.save(f, class_weights)\n" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "id": "13bfdecc-e1ec-42a3-a134-4c0c59064ce4", 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "id": "32a708b7-fde4-4f3c-b9d0-c1ba820a4a5c", 352 | "metadata": {}, 353 | "source": [ 354 | "## Write store" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "id": "238f9e13-ce0d-498e-b832-fa824737c625", 361 | "metadata": { 362 | "tags": [] 363 | }, 364 | "outputs": [], 365 | "source": [ 366 | "CHUNK_SIZE = 32768\n", 367 | "ROW_GROUP_SIZE = 1024\n", 368 | "\n", 369 | "\n", 370 | "for split in ['train', 'val', 'test']:\n", 371 | " X = preprocess_count_matrix(da.from_zarr(join(DATA_PATH, split, 'zarr'), 'X'), NORMALIZATION)\n", 372 | " obs_ = obs_dict[split]\n", 373 | " # cut off samples so that all row groups are full\n", 374 | " n_samples = X.shape[0]\n", 375 | " n_samples = (n_samples // ROW_GROUP_SIZE) * ROW_GROUP_SIZE\n", 376 | " X = X[:n_samples].rechunk((CHUNK_SIZE, -1))\n", 377 | " obs_ = obs_.iloc[:n_samples].copy()\n", 378 | " # add an index column to identify each sample\n", 379 | " obs_['idx'] = np.arange(len(obs_), dtype='i8')\n", 380 | " start_index = [0] + list(np.cumsum(X.chunks[0]))[:-1]\n", 381 | " end_index = list(np.cumsum(X.chunks[0]))\n", 382 | " # calculate divisions for dask dataframe\n", 383 | " divisions = [0] + list(np.cumsum(X.chunks[0]))\n", 384 | " divisions[-1] = divisions[-1] - 1\n", 385 | " ddf = dd.from_delayed(\n", 386 | " [\n", 387 | " convert_to_dataframe(arr, start, end) for arr, start, end in \n", 388 | " zip(X.to_delayed().flatten().tolist(), start_index, end_index)\n", 389 | " ],\n", 390 | " divisions=divisions\n", 391 | " )\n", 392 | " obs_dask = dd.from_pandas(obs_, chunksize=CHUNK_SIZE)\n", 393 | " assert np.allclose(ddf.divisions, obs_dask.divisions)\n", 394 | " ddf = dd.multi.concat([ddf, obs_dask], axis=1)\n", 395 | "\n", 396 | " schema = pa.schema([\n", 397 | " ('X', pa.list_(pa.float32())),\n", 398 | " ('soma_joinid', pa.int64()),\n", 399 | " ('is_primary_data', pa.bool_()),\n", 400 | " ('dataset_id', pa.int64()),\n", 401 | " ('donor_id', pa.int64()),\n", 402 | " ('assay', pa.int64()),\n", 403 | " ('cell_type', pa.int64()),\n", 404 | " ('development_stage', pa.int64()),\n", 405 | " ('disease', pa.int64()),\n", 406 | " ('tissue', pa.int64()),\n", 407 | " ('tissue_general', pa.int64()),\n", 408 | " ('tech_sample', pa.int64()),\n", 409 | " ('idx', pa.int64()),\n", 410 | " ])\n", 411 | " print(f'{split}: {X.shape[0]} cells')\n", 412 | " ddf.to_parquet(\n", 413 | " join(OUT_PATH, split), \n", 414 | " engine='pyarrow',\n", 415 | " schema=schema,\n", 416 | " write_metadata_file=True,\n", 417 | " row_group_size=ROW_GROUP_SIZE\n", 418 | " )\n", 419 | " \n", 420 | " # free up memory\n", 421 | " client.restart()\n" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "id": "46c1696f-97d5-401d-b43d-36dd5b77bc48", 428 |
"metadata": {}, 429 | "outputs": [], 430 | "source": [] 431 | } 432 | ], 433 | "metadata": { 434 | "kernelspec": { 435 | "display_name": "Python 3 (ipykernel)", 436 | "language": "python", 437 | "name": "python3" 438 | }, 439 | "language_info": { 440 | "codemirror_mode": { 441 | "name": "ipython", 442 | "version": 3 443 | }, 444 | "file_extension": ".py", 445 | "mimetype": "text/x-python", 446 | "name": "python", 447 | "nbconvert_exporter": "python", 448 | "pygments_lexer": "ipython3", 449 | "version": "3.8.10" 450 | } 451 | }, 452 | "nbformat": 4, 453 | "nbformat_minor": 5 454 | } 455 | -------------------------------------------------------------------------------- /notebooks/training/train_celltypist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "88e11407-68bb-4075-922d-a50125eeaac2", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install celltypist" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "4cd684c3-173f-4662-8575-e4c5411b4fa7", 19 | "metadata": { 20 | "tags": [] 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "from os.path import join\n", 25 | "\n", 26 | "import anndata\n", 27 | "import scanpy as sc\n", 28 | "import numpy as np\n", 29 | "import pandas as pd\n", 30 | "import dask.dataframe as dd\n", 31 | "import dask.array as da\n", 32 | "\n", 33 | "from scipy.sparse import csr_matrix" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "01af4fbc-bd8d-4a08-bb25-4fb66487e2a9", 39 | "metadata": {}, 40 | "source": [ 41 | "# Get subset training data" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "c3efabd1-ec3a-4550-be1f-cf8a8b4d119d", 48 | "metadata": { 49 | "tags": [] 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "def get_count_matrix_and_obs(ddf):\n", 54 | " x = (\n", 55 | " ddf['X']\n", 56 | " .map_partitions(\n", 57 | " lambda xx: pd.DataFrame(np.vstack(xx.tolist())), \n", 58 | " meta={col: 'f4' for col in range(19331)}\n", 59 | " )\n", 60 | " .to_dask_array(lengths=[1024] * ddf.npartitions)\n", 61 | " )\n", 62 | " obs = ddf[['cell_type']].compute()\n", 63 | " \n", 64 | " return x, obs" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "5fcc7d3b-f557-4e2a-914b-aba95e9b7feb", 71 | "metadata": { 72 | "tags": [] 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "a0b716d9-65e9-48c3-a799-338afe2473ad", 83 | "metadata": { 84 | "tags": [] 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "ddf = dd.read_parquet(join(PATH, 'train'), split_row_groups=True)\n", 89 | "x, obs = get_count_matrix_and_obs(ddf)\n", 90 | "var = pd.read_parquet(join(PATH, 'var.parquet'))" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "31651587-9547-40fa-a180-7a40085dd677", 97 | "metadata": { 98 | "tags": [] 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "start = 0\n", 103 | "subsample_size = 1_500_000\n", 104 | "# data is already shuffled -> just take first x cells\n", 105 | "# data is already normalized\n", 106 | "adata_train = anndata.AnnData(\n", 107 | " X=x[start:start+subsample_size].map_blocks(csr_matrix).compute(), \n", 108 | " obs=obs.iloc[start:start+subsample_size],\n", 109 | " 
var=var.set_index('feature_name')\n", 110 | ")\n", 111 | "\n", 112 | "adata_train" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "c83cf482-4f87-478a-b3de-69e834223df5", 118 | "metadata": {}, 119 | "source": [ 120 | "# Fit celltypist model" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "5e82fe7a-2ff5-4a83-aa97-a8a04647376f", 127 | "metadata": { 128 | "tags": [] 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "import celltypist" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "80fa0837-7363-4798-8539-275ba20040b5", 139 | "metadata": { 140 | "tags": [] 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "new_model = celltypist.train(\n", 145 | " adata_train, \n", 146 | " labels='cell_type', \n", 147 | " n_jobs=20, \n", 148 | " feature_selection=True,\n", 149 | " use_SGD=True, \n", 150 | " mini_batch=True,\n", 151 | " batch_number=1500,\n", 152 | " with_mean=False,\n", 153 | " random_state=1\n", 154 | ")" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "12231d4b-0d54-467e-9444-b888d809fbf1", 161 | "metadata": { 162 | "tags": [] 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "new_model.write(f'/mnt/dssfs02/tb_logs/cxg_2023_05_15_celltypist/model_{subsample_size}_cells_run1.pkl')" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "4ae7117f-80f2-4c8e-8049-bdde8a6aa397", 173 | "metadata": { 174 | "tags": [] 175 | }, 176 | "outputs": [], 177 | "source": [] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "id": "afe3044e-5946-4279-a50d-f144ce0d6205", 183 | "metadata": { 184 | "tags": [] 185 | }, 186 | "outputs": [], 187 | "source": [] 188 | } 189 | ], 190 | "metadata": { 191 | "kernelspec": { 192 | "display_name": "Python 3 (ipykernel)", 193 | "language": "python", 194 | "name": "python3" 195 | }, 196 | "language_info": { 197 | "codemirror_mode": { 198 | "name": "ipython", 199 | "version": 3 200 | }, 201 | "file_extension": ".py", 202 | "mimetype": "text/x-python", 203 | "name": "python", 204 | "nbconvert_exporter": "python", 205 | "pygments_lexer": "ipython3", 206 | "version": "3.8.10" 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 5 211 | } 212 | -------------------------------------------------------------------------------- /notebooks/training/train_linear.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "6ffb0367-a4cd-45ad-91ea-53075470a3e1", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install -e /dss/dsshome1/04/di93zer/git/cellnet --no-deps" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "ed98a4a8-9985-4d40-994b-4822658cf9ae", 19 | "metadata": { 20 | "tags": [] 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import os\n", 25 | "import seaborn as sns\n", 26 | "import torch\n", 27 | "\n", 28 | "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar\n", 29 | "from lightning.pytorch.loggers import TensorBoardLogger\n", 30 | "from lightning.pytorch.utilities.model_summary import ModelSummary\n", 31 | "from lightning.pytorch import seed_everything" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "f920fa0f-e133-439f-b765-a68be8ac8ed4", 38 |
"metadata": { 39 | "tags": [] 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "torch.set_float32_matmul_precision('high')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "40e1942f-4780-4f29-9ecb-4863a9251159", 50 | "metadata": { 51 | "tags": [] 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "%load_ext autoreload" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "56ebb8a5-0acf-4ffc-a5fc-95ff9c56c29e", 62 | "metadata": { 63 | "tags": [] 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "%autoreload\n", 68 | "from cellnet.estimators import EstimatorCellTypeClassifier" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "96c00147-7e0b-4c7f-8935-7f10b3d8767b", 74 | "metadata": {}, 75 | "source": [ 76 | "# Init model" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "3119057e-2f21-4f31-9265-b9b51cbfc5c1", 83 | "metadata": { 84 | "tags": [] 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "# config parameters\n", 89 | "MODEL = 'cxg_2023_05_15_linear'\n", 90 | "CHECKPOINT_PATH = os.path.join('/mnt/dssfs02/tb_logs', MODEL)\n", 91 | "LOGS_PATH = os.path.join('/mnt/dssfs02/tb_logs', MODEL)\n", 92 | "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p_subsample_15'\n", 93 | "SEED = 1\n", 94 | "\n", 95 | "\n", 96 | "estim = EstimatorCellTypeClassifier(DATA_PATH)\n", 97 | "seed_everything(SEED)\n", 98 | "estim.init_datamodule(batch_size=2048)\n", 99 | "estim.init_trainer(\n", 100 | " trainer_kwargs={\n", 101 | " 'max_epochs': 12,\n", 102 | " 'default_root_dir': CHECKPOINT_PATH,\n", 103 | " 'accelerator': 'gpu',\n", 104 | " 'devices': 1,\n", 105 | " 'num_sanity_val_steps': 0,\n", 106 | " 'check_val_every_n_epoch': 1,\n", 107 | " 'logger': [TensorBoardLogger(LOGS_PATH, name='default')],\n", 108 | " 'log_every_n_steps': 100,\n", 109 | " 'detect_anomaly': False,\n", 110 | " 'enable_progress_bar': True,\n", 111 | " 'enable_model_summary': False,\n", 112 | " 'enable_checkpointing': True,\n", 113 | " 'callbacks': [\n", 114 | " TQDMProgressBar(refresh_rate=50),\n", 115 | " LearningRateMonitor(logging_interval='step'),\n", 116 | " ModelCheckpoint(filename='val_f1_macro_{epoch}_{val_f1_macro:.3f}', monitor='val_f1_macro', mode='max',\n", 117 | " every_n_epochs=1, save_top_k=2),\n", 118 | " ModelCheckpoint(filename='val_loss_{epoch}_{val_loss:.3f}', monitor='val_loss', mode='min',\n", 119 | " every_n_epochs=1, save_top_k=2)\n", 120 | " ],\n", 121 | " }\n", 122 | ")\n", 123 | "estim.init_model(\n", 124 | " model_type='linear',\n", 125 | " model_kwargs={\n", 126 | " 'learning_rate': 0.0005,\n", 127 | " 'weight_decay': 0.05,\n", 128 | " 'optimizer': torch.optim.AdamW,\n", 129 | " 'lr_scheduler': torch.optim.lr_scheduler.StepLR,\n", 130 | " 'lr_scheduler_kwargs': {\n", 131 | " 'step_size': 3,\n", 132 | " 'gamma': 0.9,\n", 133 | " 'verbose': True\n", 134 | " },\n", 135 | " },\n", 136 | ")\n", 137 | "print(ModelSummary(estim.model))" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "64c24b97-c444-4903-85dd-37b972e2da13", 144 | "metadata": { 145 | "tags": [] 146 | }, 147 | "outputs": [], 148 | "source": [] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "7c968df4-ad6c-403f-82ba-fbd189d9691b", 153 | "metadata": { 154 | "tags": [] 155 | }, 156 | "source": [ 157 | "# Find learning rate" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "c7a14e18-707d-4671-ae83-0910c149ea25", 
164 | "metadata": { 165 | "tags": [] 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "lr_find_res = estim.find_lr(lr_find_kwargs={'early_stop_threshold': 10., 'min_lr': 1e-8, 'max_lr': 10., 'num_training': 120})" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "dbe0d2b0-8688-49da-a4c3-84c988689199", 176 | "metadata": { 177 | "tags": [] 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "ax = sns.lineplot(x=lr_find_res[1]['lr'], y=lr_find_res[1]['loss'])\n", 182 | "ax.set_xscale('log')\n", 183 | "ax.set_ylim(2., top=9.)\n", 184 | "ax.set_xlim(1e-6, 10.)\n", 185 | "print(f'Suggested learning rate: {lr_find_res[0]}')" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "id": "1497c99c-db63-469b-8f98-401f2e075424", 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "id": "90bf1b36-fd69-4ee8-aa10-003e50b9e7d4", 199 | "metadata": {}, 200 | "source": [ 201 | "# Fit model" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "dd8a3efe-fb56-4794-9991-4f36816b0d85", 208 | "metadata": { 209 | "tags": [] 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "estim.train()" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "id": "9bd533cd-f029-492e-af35-49c841dc6394", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [] 223 | } 224 | ], 225 | "metadata": { 226 | "kernelspec": { 227 | "display_name": "Python 3 (ipykernel)", 228 | "language": "python", 229 | "name": "python3" 230 | }, 231 | "language_info": { 232 | "codemirror_mode": { 233 | "name": "ipython", 234 | "version": 3 235 | }, 236 | "file_extension": ".py", 237 | "mimetype": "text/x-python", 238 | "name": "python", 239 | "nbconvert_exporter": "python", 240 | "pygments_lexer": "ipython3", 241 | "version": "3.8.10" 242 | } 243 | }, 244 | "nbformat": 4, 245 | "nbformat_minor": 5 246 | } 247 | -------------------------------------------------------------------------------- /notebooks/training/train_mlp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "!pip install -e /dss/dsshome1/04/di93zer/git/cellnet --no-deps" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": { 18 | "tags": [] 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "import os\n", 23 | "import seaborn as sns\n", 24 | "import torch\n", 25 | "\n", 26 | "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar\n", 27 | "from lightning.pytorch.loggers import TensorBoardLogger\n", 28 | "from lightning.pytorch.utilities.model_summary import ModelSummary\n", 29 | "from lightning.pytorch import seed_everything" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "tags": [] 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "torch.set_float32_matmul_precision('high')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "collapsed": false, 48 | "jupyter": { 49 | "outputs_hidden": false 50 | }, 51 | "tags": [] 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "%load_ext autoreload" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 
62 | "collapsed": false, 63 | "jupyter": { 64 | "outputs_hidden": false 65 | }, 66 | "tags": [] 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "%autoreload\n", 71 | "from cellnet.estimators import EstimatorCellTypeClassifier" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "# Init model" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "collapsed": false, 86 | "jupyter": { 87 | "outputs_hidden": false 88 | }, 89 | "tags": [] 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "# config parameters\n", 94 | "MODEL = 'cxg_2023_05_15_mlp'\n", 95 | "CHECKPOINT_PATH = os.path.join('/mnt/dssfs02/tb_logs', MODEL)\n", 96 | "LOGS_PATH = os.path.join('/mnt/dssfs02/tb_logs', MODEL)\n", 97 | "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'\n", 98 | "\n", 99 | "\n", 100 | "estim = EstimatorCellTypeClassifier(DATA_PATH)\n", 101 | "seed_everything(1)\n", 102 | "estim.init_datamodule(batch_size=2048)\n", 103 | "estim.init_trainer(\n", 104 | " trainer_kwargs={\n", 105 | " 'max_epochs': 48,\n", 106 | " 'gradient_clip_val': 1.,\n", 107 | " 'gradient_clip_algorithm': 'norm',\n", 108 | " 'default_root_dir': CHECKPOINT_PATH,\n", 109 | " 'accelerator': 'gpu',\n", 110 | " 'devices': 1,\n", 111 | " 'num_sanity_val_steps': 0,\n", 112 | " 'check_val_every_n_epoch': 1,\n", 113 | " 'logger': [TensorBoardLogger(LOGS_PATH, name='default', version='version_7')],\n", 114 | " 'log_every_n_steps': 100,\n", 115 | " 'detect_anomaly': False,\n", 116 | " 'enable_progress_bar': True,\n", 117 | " 'enable_model_summary': False,\n", 118 | " 'enable_checkpointing': True,\n", 119 | " 'callbacks': [\n", 120 | " TQDMProgressBar(refresh_rate=50),\n", 121 | " LearningRateMonitor(logging_interval='step'),\n", 122 | " ModelCheckpoint(filename='val_f1_macro_{epoch}_{val_f1_macro:.3f}', monitor='val_f1_macro', mode='max',\n", 123 | " every_n_epochs=1, save_top_k=2),\n", 124 | " ModelCheckpoint(filename='val_loss_{epoch}_{val_loss:.3f}', monitor='val_loss', mode='min',\n", 125 | " every_n_epochs=1, save_top_k=2)\n", 126 | " ],\n", 127 | " }\n", 128 | ")\n", 129 | "estim.init_model(\n", 130 | " model_type='mlp',\n", 131 | " model_kwargs={\n", 132 | " 'learning_rate': 0.002,\n", 133 | " 'weight_decay': 0.05,\n", 134 | " 'lr_scheduler': torch.optim.lr_scheduler.StepLR,\n", 135 | " 'lr_scheduler_kwargs': {\n", 136 | " 'step_size': 1,\n", 137 | " 'gamma': 0.9,\n", 138 | " 'verbose': True\n", 139 | " },\n", 140 | " 'optimizer': torch.optim.AdamW,\n", 141 | " 'n_hidden': 8,\n", 142 | " 'hidden_size': 128,\n", 143 | " 'dropout': 0.1,\n", 144 | " 'augment_training_data': True\n", 145 | " },\n", 146 | ")\n", 147 | "print(ModelSummary(estim.model))\n" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": { 160 | "tags": [] 161 | }, 162 | "source": [ 163 | "# Find learning rate" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": { 170 | "collapsed": false, 171 | "jupyter": { 172 | "outputs_hidden": false 173 | }, 174 | "tags": [] 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "lr_find_res = estim.find_lr(lr_find_kwargs={'early_stop_threshold': 10., 'min_lr': 1e-8, 'max_lr': 10., 'num_training': 100})" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": { 185 | "tags": [] 186 | }, 187 
| "outputs": [], 188 | "source": [ 189 | "ax = sns.lineplot(x=lr_find_res[1]['lr'], y=lr_find_res[1]['loss'])\n", 190 | "ax.set_xscale('log')\n", 191 | "ax.set_ylim(4., top=7.)\n", 192 | "ax.set_xlim(1e-6, 10.)\n", 193 | "print(f'Suggested learning rate: {lr_find_res[0]}')" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "# Fit model" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "collapsed": false, 208 | "jupyter": { 209 | "outputs_hidden": false 210 | }, 211 | "tags": [] 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "estim.train()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3 (ipykernel)", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.8.10" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 4 247 | } 248 | -------------------------------------------------------------------------------- /notebooks/training/train_tabnet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "!pip install -e /dss/dsshome1/04/di93zer/git/cellnet --no-deps" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": { 18 | "tags": [] 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "import os\n", 23 | "import seaborn as sns\n", 24 | "import torch\n", 25 | "\n", 26 | "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar\n", 27 | "from lightning.pytorch.loggers import TensorBoardLogger\n", 28 | "from lightning.pytorch.utilities.model_summary import ModelSummary\n", 29 | "from lightning.pytorch import seed_everything" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "tags": [] 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "torch.set_float32_matmul_precision('high')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "collapsed": false, 48 | "jupyter": { 49 | "outputs_hidden": false 50 | }, 51 | "tags": [] 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "%load_ext autoreload" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "collapsed": false, 63 | "jupyter": { 64 | "outputs_hidden": false 65 | }, 66 | "tags": [] 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "%autoreload\n", 71 | "from cellnet.estimators import EstimatorCellTypeClassifier" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "# Init model" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "collapsed": false, 86 | "jupyter": { 87 | "outputs_hidden": false 88 | }, 89 | "tags": [] 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "# config parameters\n", 94 | "MODEL = 'cxg_2023_05_15_lung_only_tabnet'\n", 95 
| "CHECKPOINT_PATH = os.path.join('/mnt/dssfs02/tb_logs', MODEL)\n", 96 | "LOGS_PATH = os.path.join('/mnt/dssfs02/tb_logs', MODEL)\n", 97 | "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p_lung_only'\n", 98 | "\n", 99 | "\n", 100 | "estim = EstimatorCellTypeClassifier(DATA_PATH)\n", 101 | "seed_everything(1)\n", 102 | "estim.init_datamodule(batch_size=2048)\n", 103 | "estim.init_trainer(\n", 104 | " trainer_kwargs={\n", 105 | " 'max_epochs': 50,\n", 106 | " 'gradient_clip_val': 1.,\n", 107 | " 'gradient_clip_algorithm': 'norm',\n", 108 | " 'default_root_dir': CHECKPOINT_PATH,\n", 109 | " 'accelerator': 'gpu',\n", 110 | " 'devices': 1,\n", 111 | " 'num_sanity_val_steps': 0,\n", 112 | " 'check_val_every_n_epoch': 2,\n", 113 | " 'logger': [TensorBoardLogger(LOGS_PATH, name='default', version='version_2_no_augment')],\n", 114 | " 'log_every_n_steps': 100,\n", 115 | " 'detect_anomaly': False,\n", 116 | " 'enable_progress_bar': True,\n", 117 | " 'enable_model_summary': False,\n", 118 | " 'enable_checkpointing': True,\n", 119 | " 'callbacks': [\n", 120 | " TQDMProgressBar(refresh_rate=50),\n", 121 | " LearningRateMonitor(logging_interval='step'),\n", 122 | " ModelCheckpoint(filename='val_f1_macro_{epoch}_{val_f1_macro:.3f}', monitor='val_f1_macro', mode='max',\n", 123 | " every_n_epochs=1, save_top_k=2),\n", 124 | " ModelCheckpoint(filename='val_loss_{epoch}_{val_loss:.3f}', monitor='val_loss', mode='min',\n", 125 | " every_n_epochs=1, save_top_k=2)\n", 126 | " ],\n", 127 | " }\n", 128 | ")\n", 129 | "estim.init_model(\n", 130 | " model_type='tabnet',\n", 131 | " model_kwargs={\n", 132 | " 'learning_rate': 0.005,\n", 133 | " 'weight_decay': 0.05,\n", 134 | " 'lr_scheduler': torch.optim.lr_scheduler.StepLR,\n", 135 | " 'lr_scheduler_kwargs': {\n", 136 | " 'step_size': 2,\n", 137 | " 'gamma': 0.9,\n", 138 | " 'verbose': True\n", 139 | " },\n", 140 | " 'optimizer': torch.optim.AdamW,\n", 141 | " 'lambda_sparse': 1e-5,\n", 142 | " 'n_d': 128,\n", 143 | " 'n_a': 64,\n", 144 | " 'n_steps': 1,\n", 145 | " 'gamma': 1.3,\n", 146 | " 'n_independent': 7,\n", 147 | " 'n_shared': 3,\n", 148 | " 'virtual_batch_size': 256,\n", 149 | " 'mask_type': 'entmax',\n", 150 | " 'augment_training_data': False\n", 151 | " },\n", 152 | ")\n", 153 | "print(ModelSummary(estim.model))\n" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": { 166 | "tags": [] 167 | }, 168 | "source": [ 169 | "# Find learning rate" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": { 176 | "collapsed": false, 177 | "jupyter": { 178 | "outputs_hidden": false 179 | }, 180 | "tags": [] 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "lr_find_res = estim.find_lr(lr_find_kwargs={'early_stop_threshold': 10., 'min_lr': 1e-8, 'max_lr': 10., 'num_training': 100})" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": { 191 | "tags": [] 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "ax = sns.lineplot(x=lr_find_res[1]['lr'], y=lr_find_res[1]['loss'])\n", 196 | "ax.set_xscale('log')\n", 197 | "ax.set_ylim(5.25, top=7.)\n", 198 | "ax.set_xlim(1e-6, 10.)\n", 199 | "print(f'Suggested learning rate: {lr_find_res[0]}')" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "# Fit model" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 
211 | "execution_count": null, 212 | "metadata": { 213 | "collapsed": false, 214 | "jupyter": { 215 | "outputs_hidden": false 216 | }, 217 | "tags": [] 218 | }, 219 | "outputs": [], 220 | "source": [ 221 | "estim.train()" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [] 230 | } 231 | ], 232 | "metadata": { 233 | "kernelspec": { 234 | "display_name": "Python 3 (ipykernel)", 235 | "language": "python", 236 | "name": "python3" 237 | }, 238 | "language_info": { 239 | "codemirror_mode": { 240 | "name": "ipython", 241 | "version": 3 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython3", 248 | "version": "3.8.10" 249 | } 250 | }, 251 | "nbformat": 4, 252 | "nbformat_minor": 4 253 | } 254 | -------------------------------------------------------------------------------- /notebooks/training/train_xgboost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "948708a6-e94e-449a-90fc-43d327d12674", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "from os.path import join\n", 13 | "\n", 14 | "import xgboost as xgb\n", 15 | "import numpy as np\n", 16 | "import dask.dataframe as dd" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "98089ad9-efdd-459f-8545-85ddcc4dcb46", 23 | "metadata": { 24 | "tags": [] 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "abdf1337-e80e-4e15-83b3-056a2a6a01f2", 35 | "metadata": { 36 | "tags": [] 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "x_train = np.load(join(DATA_PATH, 'pca/x_pca_training_train_split_256.npy'))\n", 41 | "y_train = dd.read_parquet(join(DATA_PATH, 'train'), columns='cell_type').compute().to_numpy()\n", 42 | "\n", 43 | "x_val = np.load(join(DATA_PATH, 'pca/x_pca_training_val_split_256.npy'))\n", 44 | "y_val = dd.read_parquet(join(DATA_PATH, 'val'), columns='cell_type').compute().to_numpy()\n", 45 | "\n", 46 | "class_weights = np.load(join(DATA_PATH, 'class_weights.npy'))" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "ac43212b-edf1-4cee-9116-c2d887b5a6f1", 53 | "metadata": { 54 | "tags": [] 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "class_weights = {i: weight for i, weight in enumerate(np.load(join(DATA_PATH, 'class_weights.npy')))}\n", 59 | "weights = np.array([class_weights[label] for label in y_train])" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "72a702fa-3256-409c-b522-4e7869a020ab", 66 | "metadata": { 67 | "tags": [] 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "clf = xgb.XGBClassifier(\n", 72 | " tree_method='gpu_hist',\n", 73 | " gpu_id=0,\n", 74 | " n_estimators=1000,\n", 75 | " eta=0.075,\n", 76 | " subsample=0.75,\n", 77 | " max_depth=10,\n", 78 | " n_jobs=20,\n", 79 | " early_stopping_rounds=10\n", 80 | ")\n", 81 | "clf = clf.fit(\n", 82 | " x_train, y_train, sample_weight=weights, \n", 83 | " eval_set=[(x_val, y_val)]\n", 84 | ")" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "8723fffb-786e-48e0-a021-6a5d1e4fdf83", 91 | "metadata": { 92 | 
"tags": [] 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "clf.save_model('model.json')" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "f01dc191-6f72-4d8f-9a7b-6f5bc46f2ad7", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [] 106 | } 107 | ], 108 | "metadata": { 109 | "kernelspec": { 110 | "display_name": "Python 3 (ipykernel)", 111 | "language": "python", 112 | "name": "python3" 113 | }, 114 | "language_info": { 115 | "codemirror_mode": { 116 | "name": "ipython", 117 | "version": 3 118 | }, 119 | "file_extension": ".py", 120 | "mimetype": "text/x-python", 121 | "name": "python", 122 | "nbconvert_exporter": "python", 123 | "pygments_lexer": "ipython3", 124 | "version": "3.8.10" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 5 129 | } 130 | -------------------------------------------------------------------------------- /requirements-gpu.txt: -------------------------------------------------------------------------------- 1 | cudf-cu11>=23.02 2 | rmm-cu11>=23.02 3 | dask-cudf-cu11>=23.02 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools>=62.1.0 2 | pyarrow>=10.0.0 3 | merlin-dataloader>=23.2 4 | merlin-core>=23.2 5 | numpy>=1.21.0 6 | pandas>=1.5 7 | torch>=1.12 8 | lightning>2.0.0 9 | torchmetrics>=0.11.1 10 | tensorboard>=2.11 11 | scipy>=1.10 12 | scikit-learn==1.2.2 13 | sparqlwrapper>=2.0.0 14 | -------------------------------------------------------------------------------- /scripts/create_venv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ~/"$USER" || exit 4 | # load modules 5 | ml purge 6 | ml Stages/2023 7 | ml GCC/11.3.0 8 | ml OpenMPI/4.1.4 9 | ml CUDA/11.7 10 | ml cuDNN/8.6.0.163-CUDA-11.7 11 | ml NCCL/default-CUDA-11.7 12 | ml Python/3.10.4 13 | 14 | python -m venv --system-site-packages merlin-torch 15 | source merlin-torch/bin/activate 16 | 17 | python -m pip install cudf-cu11==23.02 rmm-cu11==23.02 dask-cudf-cu11==23.02 --extra-index-url https://pypi.nvidia.com/ 18 | python -m pip install torch torchvision torchaudio 19 | python -m pip install merlin-dataloader 20 | python -m pip install lightning 21 | python -m pip install tensorboard 22 | python -m pip install -e git/cellnet --no-deps 23 | -------------------------------------------------------------------------------- /scripts/py_scripts/CIForm.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from os.path import join 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | warnings.filterwarnings('ignore') 8 | import anndata 9 | import math 10 | import scanpy as sc 11 | import dask.dataframe as dd 12 | import pandas as pd 13 | from tqdm.auto import tqdm 14 | from torch.utils.data import (DataLoader,Dataset) 15 | torch.set_default_tensor_type(torch.DoubleTensor) 16 | import numpy as np 17 | import random 18 | from sklearn import preprocessing 19 | 20 | 21 | ##Set random seeds 22 | def same_seeds(seed): 23 | random.seed(seed) 24 | # Numpy 25 | np.random.seed(seed) 26 | # Torch 27 | torch.manual_seed(seed) 28 | if torch.cuda.is_available(): 29 | torch.cuda.manual_seed(seed) 30 | torch.cuda.manual_seed_all(seed) 31 | torch.backends.cudnn.benchmark = False 32 | torch.backends.cudnn.deterministic = True 33 | 34 | 35 | same_seeds(2021) 36 | 37 | 38 | """ 39 | Code taken and slightly 
adapted from: https://github.com/zhanglab-wbgcas/CIForm/blob/main/Tutorial/Tutorial_Inter.ipynb 40 | """ 41 | 42 | 43 | ##Gene embedding, 44 | # Function 45 | # The pre-processed scRNA-seq data is converted into a form acceptable to the Transformer encoder 46 | # Parameters 47 | # gap: The length of a sub-vector 48 | # adata: pre-processed scRNA-seq data. The rows represent the cells and the columns represent the genes 49 | # Traindata_paths: the paths of the cell type labels file(.csv) corresponding to the training data 50 | def getXY(gap, adata): 51 | # Converting the gene expression matrix into sub-vectors 52 | # (n_cells,n_genes) -> (n_cells,gap_num,gap) gap_num = int(gene_num / gap) + 1 53 | X = adata.X # getting the gene expression matrix 54 | single_cell_list = [] 55 | for single_cell in X: 56 | feature = [] 57 | length = len(single_cell) 58 | # splitting the gene expression vector into some sub-vectors whose length is gap 59 | for k in range(0, length, gap): 60 | if (k + gap <= length): 61 | a = single_cell[k:k + gap] 62 | else: 63 | a = single_cell[length - gap:length] 64 | # scaling each sub-vector 65 | a = preprocessing.scale(a) 66 | feature.append(a) 67 | feature = np.asarray(feature) 68 | single_cell_list.append(feature) 69 | 70 | single_cell_list = np.asarray(single_cell_list) # (n_cells,gap_num,gap) 71 | 72 | return single_cell_list, adata.obs.cell_type.to_numpy() 73 | 74 | 75 | ##Function 76 | # Converting label annotation to numeric form 77 | ##Parameters 78 | # cells: all cell type labels 79 | # cell_types: all cell types of Training datasets 80 | def getNewData(cells, cell_types): 81 | labels = [] 82 | for i in range(len(cells)): 83 | cell = cells[i] 84 | cell = str(cell).upper() 85 | 86 | if (cell_types.__contains__(cell)): 87 | indexs = cell_types.index(cell) 88 | labels.append(indexs + 1) 89 | else: 90 | labels.append(0) # 0 denotes the unknown cell types 91 | 92 | return np.asarray(labels) 93 | 94 | 95 | class TrainDataSet(Dataset): 96 | def __init__(self, data, label): 97 | self.data = data 98 | self.label = label 99 | self.length = len(data) 100 | 101 | def __len__(self): 102 | return self.length 103 | 104 | def __getitem__(self, index): 105 | data = torch.from_numpy(self.data) 106 | label = torch.from_numpy(self.label) 107 | 108 | return data[index], label[index] 109 | 110 | 111 | class TestDataSet(Dataset): 112 | def __init__(self, data): 113 | self.data = data 114 | 115 | self.length = len(data) 116 | 117 | def __len__(self): 118 | return self.length 119 | 120 | def __getitem__(self, index): 121 | data = torch.from_numpy(self.data) 122 | return data[index] 123 | 124 | 125 | ##Positional Encoder Layer 126 | class PositionalEncoding(nn.Module): 127 | def __init__(self, d_model, dropout=0.1, max_len=5000): 128 | super(PositionalEncoding, self).__init__() 129 | self.dropout = nn.Dropout(p=dropout) 130 | pe = torch.zeros(max_len, d_model) 131 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 132 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 133 | 134 | ##the sine function fills the even indices (0, 2, 4, ...) of the encoding dimension 135 | pe[:, 0::2] = torch.sin(position * div_term) 136 | ##the cosine function fills the odd indices (1, 3, 5, ...) of the encoding dimension 137 | pe[:, 1::2] = torch.cos(position * div_term) 138 | pe = pe.unsqueeze(0).transpose(0, 1) 139 | self.register_buffer('pe', pe) 140 | 141 | def forward(self, x): 142 | x = x + self.pe[:x.size(0), :] 143 | return self.dropout(x) 144 | 145 | 146 |
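# --- Editor's note: illustrative sketch, not part of the original CIForm.py -----------------
# A minimal usage example of the PositionalEncoding module defined above. The gap (d_model),
# batch size and number of sub-vectors below are purely hypothetical values; the sketch only
# demonstrates the (seq_len, batch, d_model) tensor layout that CIForm's forward pass feeds
# into this module after permuting the (n_cells, gap_num, gap) output of getXY.
def _positional_encoding_shape_demo():
    gap = 80                      # assumed sub-vector length, i.e. d_model
    n_cells, gap_num = 3, 242     # assumed batch size / sub-vectors per cell
    pos_enc = PositionalEncoding(d_model=gap, dropout=0.0)
    dummy = torch.zeros(gap_num, n_cells, gap)   # (seq_len, batch, d_model)
    out = pos_enc(dummy)          # adds the sin/cos position signal at each sequence position
    assert out.shape == (gap_num, n_cells, gap)
    return out
# ---------------------------------------------------------------------------------------------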
##CIForm 147 | ##function 148 | # annotating the cell types of scRNA-seq data 149 | ##parameters 150 | # input_dim :Default is equal to gap 151 | # nhead :Number of heads in the attention mechanism 152 | # d_model :Default is equal to gap 153 | # num_classes :Number of cell types 154 | # dropout :dropout rate which is used to prevent model overfitting 155 | class CIForm(nn.Module): 156 | def __init__(self, input_dim, nhead=2, d_model=80, num_classes=2, dropout=0.1): 157 | super().__init__() 158 | # TransformerEncoderLayer with self-attention 159 | self.encoder_layer = nn.TransformerEncoderLayer( 160 | d_model=d_model, dim_feedforward=1024, nhead=nhead, dropout=dropout 161 | ) 162 | 163 | # Positional Encoding layer 164 | self.positionalEncoding = PositionalEncoding(d_model=d_model, dropout=dropout) 165 | 166 | # Classification layer 167 | self.pred_layer = nn.Sequential( 168 | nn.Linear(d_model, d_model), 169 | nn.ReLU(), 170 | nn.Linear(d_model, num_classes) 171 | ) 172 | 173 | def forward(self, mels): 174 | out = mels.permute(1, 0, 2) 175 | # Positional Encoding layer 176 | out = self.positionalEncoding(out) 177 | # Transformer Encoder layer 178 | out = self.encoder_layer(out) 179 | out = out.transpose(0, 1) 180 | # Pooling layer 181 | out = out.mean(dim=1) 182 | # Classification layer 183 | out = self.pred_layer(out) 184 | return out 185 | 186 | 187 | ##main 188 | ##parameters 189 | # s :the length of a sub-vector 190 | # adata_train :the pre-processed reference (training) dataset 191 | # adata_test :the pre-processed query (test) dataset 192 | # (compared to the original CIForm tutorial, the file-path arguments are 193 | # replaced by in-memory AnnData objects here) 194 | def main(s, adata_train, adata_test): 195 | gap = s # the length of a sub-vector 196 | d_models = s 197 | heads = 64 # the number of heads in the self-attention mechanism 198 | 199 | lr = 0.0001 # learning rate 200 | dp = 0.1 # dropout rate 201 | batch_sizes = 256 # the batch size 202 | n_epochs = 20 # the number of epochs 203 | 204 | # Getting the data which is input into CIForm 205 | train_data, labels = getXY(gap, adata_train) 206 | test_data, _ = getXY(gap, adata_test) 207 | num_classes = 164 208 | 209 | # Constructing the CIForm model 210 | model = CIForm(input_dim=d_models, nhead=heads, d_model=d_models, num_classes=num_classes, dropout=dp) 211 | # Setting the loss function 212 | criterion = nn.CrossEntropyLoss() 213 | # Setting the optimizer 214 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5) 215 | # Setting the training dataset 216 | train_dataset = TrainDataSet(data=train_data, label=labels) 217 | train_loader = DataLoader(train_dataset, batch_size=batch_sizes, shuffle=True, pin_memory=True) 218 | # Setting the test dataset 219 | test_dataset = TestDataSet(data=test_data) 220 | test_loader = DataLoader(test_dataset, batch_size=batch_sizes, shuffle=False, pin_memory=True) 221 | 222 | # Start training CIForm on the training data 223 | # n_epochs: the number of training epochs 224 | model.train() 225 | for epoch in range(n_epochs): 226 | for batch in tqdm(train_loader): 227 | # A batch consists of scRNA-seq data and the corresponding cell type annotations.
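# data: float tensor of shape (batch_size, gap_num, gap); labels: integer-encoded cell-type labels.
# (labels already arrives as a tensor from TrainDataSet, so labels.long() below would be
# enough and avoids the extra copy made by torch.tensor(labels, dtype=torch.long).)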
228 | data, labels = batch 229 | logits = model(data) 230 | labels = torch.tensor(labels, dtype=torch.long) 231 | loss = criterion(logits, labels) 232 | optimizer.zero_grad() 233 | loss.backward() 234 | optimizer.step() 235 | 236 | torch.save(model.state_dict(), '/mnt/dssfs02/tb_logs/CIForm/CIForm.tar') 237 | ##Starting the validation model, which predicts the cell types in the test dataset 238 | model.eval() 239 | y_predict = [] 240 | for batch in tqdm(test_loader): 241 | data = batch 242 | with torch.no_grad(): 243 | logits = model(data) 244 | # Getting the predicted cell type 245 | preds = logits.argmax(1) 246 | preds = preds.cpu().numpy().tolist() 247 | y_predict.extend(preds) 248 | 249 | with open('/mnt/dssfs02/tb_logs/CIForm/CIForm_preds.npy', 'wb') as f: 250 | np.save(f, np.array(y_predict)) 251 | 252 | 253 | def get_count_matrix(ddf): 254 | x = ( 255 | ddf['X'] 256 | .map_partitions( 257 | lambda xx: pd.DataFrame(np.vstack(xx.tolist())), 258 | meta={col: 'f4' for col in range(19331)} 259 | ) 260 | .to_dask_array(lengths=[1024] * ddf.npartitions) 261 | ) 262 | 263 | return x 264 | 265 | 266 | def get_adata(split, hvg_mask=None, max_cells: int = None): 267 | data_path = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p' 268 | 269 | ddf = dd.read_parquet(join(data_path, split), split_row_groups=True) 270 | if hvg_mask is None: 271 | x = get_count_matrix(ddf)[:max_cells, :].compute() 272 | var = pd.read_parquet(join(data_path, 'var.parquet')) 273 | else: 274 | x = get_count_matrix(ddf)[:max_cells, hvg_mask].compute() 275 | var = pd.read_parquet(join(data_path, 'var.parquet')).iloc[hvg_mask] 276 | obs = dd.read_parquet(join(data_path, split), columns=['cell_type']).compute().iloc[:max_cells] 277 | 278 | return anndata.AnnData(X=x, obs=obs, var=var) 279 | 280 | 281 | if __name__ == '__main__': 282 | adata_train = get_adata('train', max_cells=750_000) 283 | sc.pp.highly_variable_genes(adata_train, n_top_genes=2000) 284 | hvgs = adata_train.var.highly_variable.to_numpy() 285 | adata_train = adata_train[:, hvgs].copy() 286 | adata_test = get_adata('test', hvg_mask=hvgs, max_cells=None) 287 | 288 | s = 1024 # the length of a sub-vector 289 | main(s, adata_train, adata_test) 290 | -------------------------------------------------------------------------------- /scripts/py_scripts/scGPT-inference.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from pathlib import Path 3 | 4 | import scanpy as sc 5 | import scgpt as scg 6 | 7 | 8 | SAVE_PATH = '/mnt/dssfs02/scGPT-splits' 9 | model_dir = Path("/mnt/dssfs02/scTab-checkpoints/scGPT") 10 | cell_type_key = "cell_type" 11 | gene_col = "index" 12 | 13 | 14 | for i in range(10): 15 | adata = sc.read_h5ad(join(SAVE_PATH, 'train', f'{i}.h5ad')) 16 | adata = scg.tasks.embed_data( 17 | adata, 18 | model_dir, 19 | cell_type_key=cell_type_key, 20 | gene_col=gene_col, 21 | batch_size=64, 22 | return_new_adata=True, 23 | ).write_h5ad(join(SAVE_PATH, 'train', f'{i}_embed.h5ad')) 24 | 25 | 26 | for i in range(30): 27 | adata = sc.read_h5ad(join(SAVE_PATH, 'test', f'{i}.h5ad')) 28 | adata = scg.tasks.embed_data( 29 | adata, 30 | model_dir, 31 | cell_type_key=cell_type_key, 32 | gene_col=gene_col, 33 | batch_size=64, 34 | return_new_adata=True, 35 | ).write_h5ad(join(SAVE_PATH, 'test', f'{i}_embed.h5ad')) 36 | -------------------------------------------------------------------------------- /scripts/py_scripts/train_linear.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | from random import uniform 3 | from time import sleep 4 | 5 | import torch 6 | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar 7 | from lightning.pytorch.loggers import TensorBoardLogger 8 | from lightning.pytorch.utilities.model_summary import ModelSummary 9 | from lightning.pytorch import seed_everything 10 | 11 | from cellnet.estimators import EstimatorCellTypeClassifier 12 | from utils import get_paths, get_model_checkpoint 13 | 14 | 15 | torch.set_float32_matmul_precision('high') 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--cluster', type=str) 21 | parser.add_argument('--data_path', type=str, default=None) 22 | parser.add_argument('--logging_dir', type=str, default='cxg_2023_05_15_linear') 23 | 24 | parser.add_argument('--epochs', default=1000, type=int) 25 | parser.add_argument('--batch_size', default=2048, type=int) 26 | parser.add_argument('--lr', default=0.0005, type=float) 27 | parser.add_argument('--weight_decay', default=0.01, type=float) 28 | parser.add_argument('--use_class_weights', default=True, type=lambda x: x.lower() in ['true', '1', '1.']) 29 | parser.add_argument('--lr_scheduler_step_size', default=1, type=int) 30 | parser.add_argument('--lr_scheduler_gamma', default=0.9, type=float) 31 | parser.add_argument('--version', default=None, type=str) 32 | 33 | parser.add_argument('--resume_from_checkpoint', type=str, default=None) 34 | parser.add_argument('--checkpoint_interval', default=1, type=int) 35 | 36 | parser.add_argument('--seed', default=1, type=int) 37 | 38 | return parser.parse_args() 39 | 40 | 41 | if __name__ == '__main__': 42 | args = parse_args() 43 | print(args) 44 | 45 | # config parameters 46 | MODEL = args.logging_dir 47 | CHECKPOINT_PATH, LOGS_PATH, DATA_PATH = get_paths(args.cluster, MODEL) 48 | if args.data_path is not None: 49 | DATA_PATH = args.data_path 50 | 51 | sleep(uniform(0., 30.)) # add random sleep interval to avoid duplicated tensorboard log dirs 52 | estim = EstimatorCellTypeClassifier(DATA_PATH) 53 | seed_everything(args.seed) 54 | estim.init_datamodule(batch_size=args.batch_size) 55 | estim.init_trainer( 56 | trainer_kwargs={ 57 | 'max_epochs': args.epochs, 58 | 'default_root_dir': CHECKPOINT_PATH, 59 | 'accelerator': 'gpu', 60 | 'devices': 1, 61 | 'num_sanity_val_steps': 0, 62 | 'check_val_every_n_epoch': 1, 63 | 'logger': [TensorBoardLogger(LOGS_PATH, name='default', version=args.version)], 64 | 'log_every_n_steps': 100, 65 | 'detect_anomaly': False, 66 | 'enable_progress_bar': True, 67 | 'enable_model_summary': False, 68 | 'enable_checkpointing': True, 69 | 'callbacks': [ 70 | TQDMProgressBar(refresh_rate=250), 71 | LearningRateMonitor(logging_interval='step'), 72 | ModelCheckpoint(filename='last_{epoch}', every_n_epochs=args.checkpoint_interval), 73 | ModelCheckpoint(filename='val_f1_macro_{epoch}_{val_f1_macro:.3f}', monitor='val_f1_macro', mode='max', 74 | every_n_epochs=args.checkpoint_interval, save_top_k=2), 75 | ModelCheckpoint(filename='val_loss_{epoch}_{val_loss:.3f}', monitor='val_loss', mode='min', 76 | every_n_epochs=args.checkpoint_interval, save_top_k=2) 77 | ], 78 | } 79 | ) 80 | estim.init_model( 81 | model_type='linear', 82 | model_kwargs={ 83 | 'learning_rate': args.lr, 84 | 'weight_decay': args.weight_decay, 85 | 'use_class_weights': args.use_class_weights, 86 | 'optimizer': torch.optim.AdamW, 87 | 'lr_scheduler': 
torch.optim.lr_scheduler.StepLR, 88 | 'lr_scheduler_kwargs': { 89 | 'step_size': args.lr_scheduler_step_size, 90 | 'gamma': args.lr_scheduler_gamma, 91 | 'verbose': True 92 | }, 93 | }, 94 | ) 95 | print(ModelSummary(estim.model)) 96 | estim.train(ckpt_path=get_model_checkpoint(CHECKPOINT_PATH, args.resume_from_checkpoint)) 97 | -------------------------------------------------------------------------------- /scripts/py_scripts/train_mlp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from random import uniform 3 | from time import sleep 4 | 5 | import torch 6 | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar 7 | from lightning.pytorch.loggers import TensorBoardLogger 8 | from lightning.pytorch.utilities.model_summary import ModelSummary 9 | from lightning.pytorch import seed_everything 10 | 11 | from cellnet.estimators import EstimatorCellTypeClassifier 12 | from utils import get_paths, get_model_checkpoint 13 | 14 | 15 | torch.set_float32_matmul_precision('high') 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--cluster', type=str) 21 | parser.add_argument('--data_path', type=str, default=None) 22 | parser.add_argument('--logging_dir', type=str, default='cxg_2023_05_15_mlp') 23 | 24 | parser.add_argument('--epochs', default=1000, type=int) 25 | parser.add_argument('--batch_size', default=2048, type=int) 26 | parser.add_argument('--sub_sample_frac', default=1., type=float) 27 | parser.add_argument('--lr', default=0.002, type=float) 28 | parser.add_argument('--weight_decay', default=0.05, type=float) 29 | parser.add_argument('--hidden_size', default=128, type=int) 30 | parser.add_argument('--n_hidden', default=8, type=int) 31 | parser.add_argument('--dropout', default=0.1, type=float) 32 | parser.add_argument('--augment_training_data', default=True, type=lambda x: x.lower() in ['true', '1', '1.']) 33 | parser.add_argument('--lr_scheduler_step_size', default=1, type=int) 34 | parser.add_argument('--lr_scheduler_gamma', default=0.9, type=float) 35 | parser.add_argument('--version', default=None, type=str) 36 | 37 | parser.add_argument('--resume_from_checkpoint', type=str, default=None) 38 | parser.add_argument('--checkpoint_interval', default=1, type=int) 39 | parser.add_argument('--check_val_every_n_epoch', default=1, type=int) 40 | 41 | parser.add_argument('--seed', default=1, type=int) 42 | 43 | return parser.parse_args() 44 | 45 | 46 | if __name__ == '__main__': 47 | args = parse_args() 48 | print(args) 49 | 50 | # config parameters 51 | MODEL = args.logging_dir 52 | CHECKPOINT_PATH, LOGS_PATH, DATA_PATH = get_paths(args.cluster, MODEL) 53 | if args.data_path is not None: 54 | DATA_PATH = args.data_path 55 | 56 | sleep(uniform(0., 30.)) # add random sleep interval to avoid duplicated tensorboard log dirs 57 | estim = EstimatorCellTypeClassifier(DATA_PATH) 58 | seed_everything(args.seed) 59 | estim.init_datamodule(batch_size=args.batch_size, sub_sample_frac=args.sub_sample_frac) 60 | estim.init_trainer( 61 | trainer_kwargs={ 62 | 'max_epochs': args.epochs, 63 | 'gradient_clip_val': 1., 64 | 'gradient_clip_algorithm': 'norm', 65 | 'accelerator': 'gpu', 66 | 'devices': 1, 67 | 'num_sanity_val_steps': 0, 68 | 'check_val_every_n_epoch': args.check_val_every_n_epoch, 69 | 'logger': [TensorBoardLogger(LOGS_PATH, name='default', version=args.version)], 70 | 'log_every_n_steps': 200, 71 | 'detect_anomaly': False, 72 | 'enable_progress_bar': 
True, 73 | 'enable_model_summary': False, 74 | 'enable_checkpointing': True, 75 | 'default_root_dir': CHECKPOINT_PATH, 76 | 'callbacks': [ 77 | TQDMProgressBar(refresh_rate=250), 78 | LearningRateMonitor(logging_interval='step'), 79 | ModelCheckpoint(filename='last_{epoch}', every_n_epochs=args.checkpoint_interval), 80 | ModelCheckpoint(filename='val_f1_macro_{epoch}_{val_f1_macro:.3f}', monitor='val_f1_macro', mode='max', 81 | every_n_epochs=args.checkpoint_interval, save_top_k=2), 82 | ModelCheckpoint(filename='val_loss_{epoch}_{val_loss:.3f}', monitor='val_loss', mode='min', 83 | every_n_epochs=args.checkpoint_interval, save_top_k=2) 84 | ], 85 | } 86 | ) 87 | estim.init_model( 88 | model_type='mlp', 89 | model_kwargs={ 90 | 'learning_rate': args.lr, 91 | 'weight_decay': args.weight_decay, 92 | 'lr_scheduler': torch.optim.lr_scheduler.StepLR, 93 | 'lr_scheduler_kwargs': { 94 | 'step_size': args.lr_scheduler_step_size, 95 | 'gamma': args.lr_scheduler_gamma, 96 | 'verbose': True 97 | }, 98 | 'optimizer': torch.optim.AdamW, 99 | 'hidden_size': args.hidden_size, 100 | 'n_hidden': args.n_hidden, 101 | 'augment_training_data': args.augment_training_data 102 | }, 103 | ) 104 | print(ModelSummary(estim.model)) 105 | estim.train(ckpt_path=get_model_checkpoint(CHECKPOINT_PATH, args.resume_from_checkpoint)) 106 | -------------------------------------------------------------------------------- /scripts/py_scripts/train_tabnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from random import uniform 3 | from time import sleep 4 | 5 | import torch 6 | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar 7 | from lightning.pytorch.loggers import TensorBoardLogger 8 | from lightning.pytorch.utilities.model_summary import ModelSummary 9 | from lightning.pytorch import seed_everything 10 | 11 | from cellnet.estimators import EstimatorCellTypeClassifier 12 | from utils import get_paths, get_model_checkpoint 13 | 14 | 15 | torch.set_float32_matmul_precision('high') 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--cluster', type=str) 21 | parser.add_argument('--data_path', type=str, default=None) 22 | parser.add_argument('--logging_dir', type=str, default='cxg_2023_05_15_tabnet') 23 | 24 | parser.add_argument('--epochs', default=1000, type=int) 25 | parser.add_argument('--batch_size', default=2048, type=int) 26 | parser.add_argument('--sub_sample_frac', default=1., type=float) 27 | parser.add_argument('--lr', default=0.005, type=float) 28 | parser.add_argument('--weight_decay', default=0.05, type=float) 29 | parser.add_argument('--use_class_weights', default=True, type=lambda x: x.lower() in ['true', '1', '1.']) 30 | parser.add_argument('--lambda_sparse', default=1e-5, type=float) 31 | parser.add_argument('--n_d', default=128, type=int) 32 | parser.add_argument('--n_a', default=64, type=int) 33 | parser.add_argument('--n_steps', default=1, type=int) 34 | parser.add_argument('--gamma', default=1.3, type=float) 35 | parser.add_argument('--n_independent', default=5, type=int) 36 | parser.add_argument('--n_shared', default=3, type=int) 37 | parser.add_argument('--virtual_batch_size', default=256, type=int) 38 | parser.add_argument('--mask_type', default='entmax', type=str) 39 | parser.add_argument('--augment_training_data', default=True, type=lambda x: x.lower() in ['true', '1', '1.']) 40 | parser.add_argument('--lr_scheduler_step_size', default=1, type=int) 
41 | parser.add_argument('--lr_scheduler_gamma', default=0.9, type=float) 42 | parser.add_argument('--version', default=None, type=str) 43 | 44 | parser.add_argument('--resume_from_checkpoint', type=str, default=None) 45 | parser.add_argument('--checkpoint_interval', default=1, type=int) 46 | parser.add_argument('--check_val_every_n_epoch', default=1, type=int) 47 | 48 | parser.add_argument('--seed', default=1, type=int) 49 | 50 | return parser.parse_args() 51 | 52 | 53 | if __name__ == '__main__': 54 | args = parse_args() 55 | print(args) 56 | 57 | # config parameters 58 | MODEL = args.logging_dir 59 | CHECKPOINT_PATH, LOGS_PATH, DATA_PATH = get_paths(args.cluster, MODEL) 60 | if args.data_path is not None: 61 | DATA_PATH = args.data_path 62 | 63 | sleep(uniform(0., 30.)) # add random sleep interval to avoid duplicated tensorboard log dirs 64 | estim = EstimatorCellTypeClassifier(DATA_PATH) 65 | seed_everything(args.seed) 66 | estim.init_datamodule(batch_size=args.batch_size, sub_sample_frac=args.sub_sample_frac) 67 | estim.init_trainer( 68 | trainer_kwargs={ 69 | 'max_epochs': args.epochs, 70 | 'gradient_clip_val': 1., 71 | 'gradient_clip_algorithm': 'norm', 72 | 'accelerator': 'gpu', 73 | 'devices': 1, 74 | 'num_sanity_val_steps': 0, 75 | 'check_val_every_n_epoch': args.check_val_every_n_epoch, 76 | 'logger': [TensorBoardLogger(LOGS_PATH, name='default', version=args.version)], 77 | 'log_every_n_steps': 200, 78 | 'detect_anomaly': False, 79 | 'enable_progress_bar': True, 80 | 'enable_model_summary': False, 81 | 'enable_checkpointing': True, 82 | 'default_root_dir': CHECKPOINT_PATH, 83 | 'callbacks': [ 84 | TQDMProgressBar(refresh_rate=250), 85 | LearningRateMonitor(logging_interval='step'), 86 | ModelCheckpoint(filename='last_{epoch}', every_n_epochs=args.checkpoint_interval), 87 | ModelCheckpoint(filename='val_f1_macro_{epoch}_{val_f1_macro:.3f}', monitor='val_f1_macro', mode='max', 88 | every_n_epochs=args.checkpoint_interval, save_top_k=2), 89 | ModelCheckpoint(filename='val_loss_{epoch}_{val_loss:.3f}', monitor='val_loss', mode='min', 90 | every_n_epochs=args.checkpoint_interval, save_top_k=2) 91 | ], 92 | } 93 | ) 94 | estim.init_model( 95 | model_type='tabnet', 96 | model_kwargs={ 97 | 'learning_rate': args.lr, 98 | 'weight_decay': args.weight_decay, 99 | 'use_class_weights': args.use_class_weights, 100 | 'lr_scheduler': torch.optim.lr_scheduler.StepLR, 101 | 'lr_scheduler_kwargs': { 102 | 'step_size': args.lr_scheduler_step_size, 103 | 'gamma': args.lr_scheduler_gamma, 104 | 'verbose': True 105 | }, 106 | 'optimizer': torch.optim.AdamW, 107 | 'lambda_sparse': args.lambda_sparse, 108 | 'n_d': args.n_d, 109 | 'n_a': args.n_a, 110 | 'n_steps': args.n_steps, 111 | 'gamma': args.gamma, 112 | 'n_independent': args.n_independent, 113 | 'n_shared': args.n_shared, 114 | 'virtual_batch_size': args.virtual_batch_size, 115 | 'mask_type': args.mask_type, 116 | 'augment_training_data': args.augment_training_data 117 | }, 118 | ) 119 | print(ModelSummary(estim.model)) 120 | estim.train(ckpt_path=get_model_checkpoint(CHECKPOINT_PATH, args.resume_from_checkpoint)) 121 | -------------------------------------------------------------------------------- /scripts/py_scripts/train_xgboost.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from os.path import join 4 | 5 | import dask.dataframe as dd 6 | import numpy as np 7 | import xgboost as xgb 8 | 9 | from utils import get_paths 10 | 11 | 12 | def parse_args(): 13 | parser = 
argparse.ArgumentParser() 14 | 15 | parser.add_argument('--cluster', type=str) 16 | parser.add_argument('--logging_dir', type=str, default='cxg_2023_05_15_xgboost') 17 | parser.add_argument('--version', type=str) 18 | parser.add_argument('--seed', default=1, type=int) 19 | 20 | parser.add_argument('--n_estimators', type=int, default=800) 21 | parser.add_argument('--eta', type=float, default=0.05) 22 | parser.add_argument('--subsample', type=float, default=0.75) 23 | parser.add_argument('--max_depth', type=int, default=10) 24 | parser.add_argument('--early_stopping_rounds', type=int, default=10) 25 | 26 | return parser.parse_args() 27 | 28 | 29 | if __name__ == '__main__': 30 | args = parse_args() 31 | print(args) 32 | 33 | # config parameters 34 | MODEL = args.logging_dir 35 | CHECKPOINT_PATH, LOGS_PATH, DATA_PATH = get_paths(args.cluster, MODEL) 36 | # save hparams to json 37 | with open(join(CHECKPOINT_PATH, f'{args.version}_hparams.json'), 'w') as f: 38 | json.dump(vars(args), f, indent=4) 39 | 40 | # load training + val data 41 | x_train = np.load(join(DATA_PATH, 'pca/x_pca_training_train_split_256.npy')) 42 | y_train = dd.read_parquet(join(DATA_PATH, 'train'), columns='cell_type').compute().to_numpy() 43 | x_val = np.load(join(DATA_PATH, 'pca/x_pca_training_val_split_256.npy')) 44 | y_val = dd.read_parquet(join(DATA_PATH, 'val'), columns='cell_type').compute().to_numpy() 45 | class_weights = {i: weight for i, weight in enumerate(np.load(join(DATA_PATH, 'class_weights.npy')))} 46 | weights = np.array([class_weights[label] for label in y_train]) 47 | 48 | clf = xgb.XGBClassifier( 49 | tree_method='gpu_hist', 50 | n_estimators=args.n_estimators, 51 | eta=args.eta, 52 | subsample=args.subsample, 53 | max_depth=args.max_depth, 54 | early_stopping_rounds=args.early_stopping_rounds, 55 | random_state=args.seed 56 | ) 57 | clf = clf.fit(x_train, y_train, sample_weight=weights, eval_set=[(x_val, y_val)]) 58 | clf.save_model(join(CHECKPOINT_PATH, f'{args.version}.json')) 59 | -------------------------------------------------------------------------------- /scripts/py_scripts/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def get_paths(cluster: str, model: str): 5 | if cluster == 'jsc': 6 | return ( 7 | os.path.join('/p/scratch/hai_cellnet/tb_logs/', model), 8 | os.path.join('/p/scratch/hai_cellnet/tb_logs', model), 9 | '/p/scratch/hai_cellnet/merlin_cxg_2023_05_15_sf-log1p' 10 | ) 11 | elif cluster == 'lrz': 12 | return ( 13 | os.path.join('/mnt/dssfs02/tb_logs/', model), 14 | os.path.join('/mnt/dssfs02/tb_logs/', model), 15 | '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p' 16 | ) 17 | elif cluster == 'icb': 18 | return ( 19 | os.path.join('/lustre/scratch/users/felix.fischer/tb_logs', model), 20 | os.path.join('/lustre/scratch/users/felix.fischer/tb_logs', model), 21 | '/lustre/scratch/users/felix.fischer/merlin_cxg_2023_05_15_sf-log1p' 22 | ) 23 | else: 24 | raise ValueError(f'Only "jsc", "icb" or "lrz" are supported as cluster. 
You supplied: {cluster}') 25 | 26 | 27 | def get_model_checkpoint(checkpoint_path, checkpoint): 28 | if checkpoint is None: 29 | return None 30 | else: 31 | return os.path.join(checkpoint_path, 'default', checkpoint) 32 | -------------------------------------------------------------------------------- /scripts/scGPT-inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J scGPT 4 | #SBATCH --output=slurm_out/scGPT_out.%j 5 | #SBATCH --error=slurm_out/scGPT_err.%j 6 | #SBATCH --partition=mcml-dgx-a100-40x8 7 | #SBATCH --qos mcml 8 | #SBATCH --gres=gpu:1 9 | #SBATCH --time 3-00:00:00 10 | #SBATCH --mem=90GB 11 | #SBATCH --cpus-per-task=6 12 | 13 | DSSFS02_HOME="/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zer" 14 | DSSMCML_HOME="/dss/dssmcmlfs01/pn36po/pn36po-dss-0000/di93zer" 15 | 16 | 17 | srun --cpu-bind=verbose,socket --accel-bind=g --gres=gpu:1 \ 18 | --container-mounts="/dss:/dss,${DSSFS02_HOME}:/mnt/dssfs02,${DSSMCML_HOME}:/mnt/dssmcmlfs01" \ 19 | --container-image="/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zer/enroot-images/scGPT-jupyter.sqsh" \ 20 | --no-container-remap-root \ 21 | bash -c "python -u /dss/dsshome1/04/di93zer/git/cellnet/scripts/py_scripts/scGPT-inference.py" 22 | -------------------------------------------------------------------------------- /scripts/train_CIForm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J CIForm 4 | #SBATCH --output=slurm_out/CIForm_out.%j 5 | #SBATCH --error=slurm_out/CIForm_err.%j 6 | #SBATCH --partition=mcml-dgx-a100-40x8 7 | #SBATCH --qos mcml 8 | #SBATCH --gres=gpu:1 9 | #SBATCH --time 3-00:00:00 10 | #SBATCH --mem=200GB 11 | #SBATCH --cpus-per-task=6 12 | 13 | DSSFS02_HOME="/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zer" 14 | DSSMCML_HOME="/dss/dssmcmlfs01/pn36po/pn36po-dss-0000/di93zer" 15 | 16 | 17 | SCRIPT="/dss/dsshome1/04/di93zer/git/cellnet/scripts/py_scripts/CIForm.py" 18 | 19 | srun --cpu-bind=verbose,socket --accel-bind=g --gres=gpu:1 \ 20 | --container-mounts="/dss:/dss,${DSSFS02_HOME}:/mnt/dssfs02,${DSSMCML_HOME}:/mnt/dssmcmlfs01" \ 21 | --container-image="/dss/dsshome1/04/di93zer/merlin-2302.sqsh" \ 22 | --no-container-remap-root \ 23 | bash -c "python -u ${SCRIPT}" 24 | -------------------------------------------------------------------------------- /scripts/train_linear_jsc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --account=hai_cellnet 4 | #SBATCH --nodes=1 5 | #SBATCH --output=slurm_out/out.%j 6 | #SBATCH --error=slurm_out/err.%j 7 | #SBATCH --time=24:00:00 8 | #SBATCH --partition=booster 9 | #SBATCH --gres=gpu:4 10 | 11 | cd ~/"$USER" || exit 12 | # load modules 13 | ml purge 14 | ml Stages/2023 15 | ml GCC/11.3.0 16 | ml OpenMPI/4.1.4 17 | ml CUDA/11.7 18 | ml cuDNN/8.6.0.163-CUDA-11.7 19 | ml NCCL/default-CUDA-11.7 20 | ml Python/3.10.4 21 | 22 | source merlin-torch/bin/activate 23 | cd git/cellnet/scripts || exit 24 | 25 | srun -n 1 --cpus-per-task=12 --exclusive --output=slurm_out/out_gpu0.%j --error=slurm_out/err_gpu0.%j --cpu-bind=verbose,socket --gres=gpu:1 \ 26 | python -u py_scripts/train_linear.py \ 27 | --cluster="jsc" \ 28 | --version='version_1' --epochs=35 --seed=1 & 29 | srun -n 1 --cpus-per-task=12 --exclusive --output=slurm_out/out_gpu1.%j --error=slurm_out/err_gpu1.%j --cpu-bind=verbose,socket --gres=gpu:1 \ 30 | python -u py_scripts/train_linear.py \ 31 | 
--cluster="jsc" \ 32 | --version='version_2' --epochs=35 --seed=2 & 33 | srun -n 1 --cpus-per-task=12 --exclusive --output=slurm_out/out_gpu2.%j --error=slurm_out/err_gpu2.%j --cpu-bind=verbose,socket --gres=gpu:1 \ 34 | python -u py_scripts/train_linear.py \ 35 | --cluster="jsc" \ 36 | --version='version_3' --epochs=35 --seed=3 & 37 | srun -n 1 --cpus-per-task=12 --exclusive --output=slurm_out/out_gpu3.%j --error=slurm_out/err_gpu3.%j --cpu-bind=verbose,socket --gres=gpu:1 \ 38 | python -u py_scripts/train_linear.py \ 39 | --cluster="jsc" \ 40 | --version='version_4' --epochs=35 --seed=4 & 41 | 42 | wait 43 | -------------------------------------------------------------------------------- /scripts/train_linear_lrz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J linear 4 | #SBATCH --output=slurm_out/linear_out.%j 5 | #SBATCH --error=slurm_out/linear_err.%j 6 | #SBATCH --partition=mcml-dgx-a100-40x8 7 | #SBATCH --qos mcml 8 | #SBATCH --gres=gpu:1 9 | #SBATCH --time 3-00:00:00 10 | #SBATCH --mem=90GB 11 | #SBATCH --cpus-per-task=6 12 | 13 | DSSFS02_HOME="/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zer" 14 | DSSMCML_HOME="/dss/dssmcmlfs01/pn36po/pn36po-dss-0000/di93zer" 15 | 16 | 17 | SCRIPT="/dss/dsshome1/04/di93zer/git/cellnet/scripts/py_scripts/train_linear.py" 18 | GIT_REPO="/dss/dsshome1/04/di93zer/git/cellnet" 19 | ARGS="--version=version_1 --epochs=20" 20 | 21 | srun --cpu-bind=verbose,socket --accel-bind=g --gres=gpu:1 \ 22 | --container-mounts="/dss:/dss,${DSSFS02_HOME}:/mnt/dssfs02,${DSSMCML_HOME}:/mnt/dssmcmlfs01" \ 23 | --container-image="/dss/dsshome1/04/di93zer/merlin-2302.sqsh" \ 24 | --no-container-remap-root \ 25 | bash -c "pip install -e ${GIT_REPO} --no-deps && python -u ${SCRIPT} --cluster=lrz ${ARGS}" 26 | -------------------------------------------------------------------------------- /scripts/train_mlp_lrz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J mlp 4 | #SBATCH --output=slurm_out/mlp_out.%j 5 | #SBATCH --error=slurm_out/mlp_err.%j 6 | #SBATCH --partition=mcml-dgx-a100-40x8 7 | #SBATCH --qos mcml 8 | #SBATCH --gres=gpu:1 9 | #SBATCH --time 3-00:00:00 10 | #SBATCH --mem=90GB 11 | #SBATCH --cpus-per-task=6 12 | 13 | DSSFS02_HOME="/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zer" 14 | DSSMCML_HOME="/dss/dssmcmlfs01/pn36po/pn36po-dss-0000/di93zer" 15 | 16 | 17 | SCRIPT="/dss/dsshome1/04/di93zer/git/cellnet/scripts/py_scripts/train_mlp.py" 18 | GIT_REPO="/dss/dsshome1/04/di93zer/git/cellnet" 19 | ARGS="--version=version_1 --epochs=48" 20 | 21 | srun --cpu-bind=verbose,socket --accel-bind=g --gres=gpu:1 \ 22 | --container-mounts="/dss:/dss,${DSSFS02_HOME}:/mnt/dssfs02,${DSSMCML_HOME}:/mnt/dssmcmlfs01" \ 23 | --container-image="/dss/dsshome1/04/di93zer/merlin-2302.sqsh" \ 24 | --no-container-remap-root \ 25 | bash -c "pip install -e ${GIT_REPO} --no-deps && python -u ${SCRIPT} --cluster=lrz ${ARGS}" 26 | -------------------------------------------------------------------------------- /scripts/train_tabnet_jsc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --account=hai_cellnet 4 | #SBATCH --nodes=1 5 | #SBATCH --output=slurm_out/out.%j 6 | #SBATCH --error=slurm_out/err.%j 7 | #SBATCH --time=24:00:00 8 | #SBATCH --partition=booster 9 | #SBATCH --gres=gpu:4 10 | 11 | cd ~/"$USER" || exit 12 | # load modules 13 | ml purge 14 | ml Stages/2023 15 | 
ml GCC/11.3.0 16 | ml OpenMPI/4.1.4 17 | ml CUDA/11.7 18 | ml cuDNN/8.6.0.163-CUDA-11.7 19 | ml NCCL/default-CUDA-11.7 20 | ml Python/3.10.4 21 | 22 | source merlin-torch/bin/activate 23 | cd git/cellnet/scripts || exit 24 | 25 | srun -n 1 --cpus-per-task=12 --exclusive --output=slurm_out/out_gpu0.%j --error=slurm_out/err_gpu0.%j --cpu-bind=verbose,socket --gres=gpu:1 \ 26 | python -u py_scripts/train_tabnet.py \ 27 | --cluster="jsc" \ 28 | --version='version_1' & 29 | srun -n 1 --cpus-per-task=12 --exclusive --output=slurm_out/out_gpu1.%j --error=slurm_out/err_gpu1.%j --cpu-bind=verbose,socket --gres=gpu:1 \ 30 | python -u py_scripts/train_tabnet.py \ 31 | --cluster="jsc" \ 32 | --version='version_2' & 33 | srun -n 1 --cpus-per-task=12 --exclusive --output=slurm_out/out_gpu2.%j --error=slurm_out/err_gpu2.%j --cpu-bind=verbose,socket --gres=gpu:1 \ 34 | python -u py_scripts/train_tabnet.py \ 35 | --cluster="jsc" \ 36 | --version='version_3' & 37 | srun -n 1 --cpus-per-task=12 --exclusive --output=slurm_out/out_gpu3.%j --error=slurm_out/err_gpu3.%j --cpu-bind=verbose,socket --gres=gpu:1 \ 38 | python -u py_scripts/train_tabnet.py \ 39 | --cluster="jsc" \ 40 | --version='version_4' & 41 | 42 | wait 43 | -------------------------------------------------------------------------------- /scripts/train_tabnet_lrz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J tabnet 4 | #SBATCH --output=slurm_out/tabnet_out.%j 5 | #SBATCH --error=slurm_out/tabnet_err.%j 6 | #SBATCH --partition=mcml-dgx-a100-40x8 7 | #SBATCH --qos mcml 8 | #SBATCH --gres=gpu:1 9 | #SBATCH --time 3-00:00:00 10 | #SBATCH --mem=90GB 11 | #SBATCH --cpus-per-task=6 12 | 13 | DSSFS02_HOME="/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zer" 14 | DSSMCML_HOME="/dss/dssmcmlfs01/pn36po/pn36po-dss-0000/di93zer" 15 | 16 | 17 | SCRIPT="/dss/dsshome1/04/di93zer/git/cellnet/scripts/py_scripts/train_tabnet.py" 18 | GIT_REPO="/dss/dsshome1/04/di93zer/git/cellnet" 19 | ARGS="--version=version_1 --epochs=48" 20 | 21 | srun --cpu-bind=verbose,socket --accel-bind=g --gres=gpu:1 \ 22 | --container-mounts="/dss:/dss,${DSSFS02_HOME}:/mnt/dssfs02,${DSSMCML_HOME}:/mnt/dssmcmlfs01" \ 23 | --container-image="/dss/dsshome1/04/di93zer/merlin-2302.sqsh" \ 24 | --no-container-remap-root \ 25 | bash -c "pip install -e ${GIT_REPO} --no-deps && python -u ${SCRIPT} --cluster=lrz ${ARGS}" 26 | -------------------------------------------------------------------------------- /scripts/train_xgboost_lrz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J xgboost 4 | #SBATCH --output=slurm_out/xgboost_out.%j 5 | #SBATCH --error=slurm_out/xgboost_err.%j 6 | #SBATCH --partition=lrz-dgx-a100-80x8 7 | #SBATCH --gres=gpu:1 8 | #SBATCH --time 3-00:00:00 9 | #SBATCH --mem=200GB 10 | #SBATCH --cpus-per-task=12 11 | 12 | 13 | DSSFS02_HOME="/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zer" 14 | DSSMCML_HOME="/dss/dssmcmlfs01/pn36po/pn36po-dss-0000/di93zer" 15 | 16 | 17 | srun --cpu-bind=verbose,socket --accel-bind=g --gres=gpu:1 \ 18 | --container-mounts="/dss:/dss,${DSSFS02_HOME}:/mnt/dssfs02,${DSSMCML_HOME}:/mnt/dssmcmlfs01" \ 19 | --container-image="/dss/dsshome1/04/di93zer/merlin-2302.sqsh" \ 20 | --no-container-remap-root \ 21 | python -u /dss/dsshome1/04/di93zer/git/cellnet/scripts/py_scripts/train_xgboost.py --cluster="lrz" \ 22 | --version='version_1' --seed=1 23 | 
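# Note: train_xgboost.py fits the classifier on the precomputed PCA features
# (pca/x_pca_training_*_256.npy) and writes the booster to <checkpoint_dir>/version_1.json,
# where <checkpoint_dir> is resolved by get_paths() in py_scripts/utils.py.
# A rough sketch for reloading the trained model later (path is an assumption, adjust to your mounts):
#   import xgboost as xgb
#   clf = xgb.XGBClassifier()
#   clf.load_model('/mnt/dssfs02/tb_logs/cxg_2023_05_15_xgboost/version_1.json')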
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | 5 | from setuptools import find_packages 6 | from setuptools import setup 7 | 8 | 9 | def read(filename): 10 | filename = os.path.join(os.path.dirname(__file__), filename) 11 | text_type = type(u"") 12 | with io.open(filename, mode="r", encoding='utf-8') as fd: 13 | return re.sub(text_type(r':[a-z]+:`~?(.*?)`'), text_type(r'``\1``'), fd.read()) 14 | 15 | 16 | with open('requirements.txt') as f: 17 | requirements = f.read().splitlines() 18 | 19 | 20 | setup( 21 | name="cellnet", 22 | version="0.1.0", 23 | url="https://github.com/theislab/cellnet", 24 | license='MIT', 25 | author="Felix Fischer", 26 | author_email="felix.fischer@helmholtz-muenchen.de", 27 | description="Scaling single cell models to bigger data sets.", 28 | long_description=read("README.md"), 29 | packages=find_packages(exclude=('tests',)), 30 | install_requires=requirements, 31 | classifiers=[ 32 | 'Development Status :: 2 - Pre-Alpha', 33 | 'License :: OSI Approved :: MIT License', 34 | 'Programming Language :: Python :: 3.8', 35 | 'Programming Language :: Python :: 3.9', 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /tests/test_sample.py: -------------------------------------------------------------------------------- 1 | # Sample Test passing with nose and pytest 2 | 3 | def test_pass(): 4 | assert True, "dummy sample test" 5 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py38,py39 3 | 4 | [testenv] 5 | commands = py.test cellnet 6 | deps = pytest 7 | --------------------------------------------------------------------------------