├── .gitmodules ├── src └── state │ ├── py.typed │ ├── __init__.py │ ├── emb │ ├── nn │ │ ├── __init__.py │ │ ├── flash_transformer.py │ │ └── loss.py │ ├── tools │ │ ├── __init__.py │ │ └── slurm.py │ ├── train │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── trainer.py │ │ └── callbacks.py │ ├── eval │ │ ├── __init__.py │ │ └── emb.py │ ├── data │ │ └── __init__.py │ ├── __init__.py │ ├── finetune_decoder.py │ └── vectordb.py │ ├── tx │ ├── __init__.py │ ├── models │ │ ├── cpa │ │ │ ├── __init__.py │ │ │ └── _callbacks.py │ │ ├── scvi │ │ │ ├── __init__.py │ │ │ ├── _callbacks.py │ │ │ └── _base_modules.py │ │ ├── scgpt │ │ │ ├── __init__.py │ │ │ ├── grad_reverse.py │ │ │ ├── loss.py │ │ │ ├── utils.py │ │ │ └── dsbn.py │ │ ├── __init__.py │ │ ├── decoders_nb.py │ │ ├── decoders.py │ │ ├── embed_sum.py │ │ ├── decoder_only.py │ │ └── old_neural_ot.py │ ├── data │ │ └── dataset │ │ │ └── __init__.py │ ├── utils │ │ └── singleton.py │ └── callbacks │ │ ├── __init__.py │ │ ├── batch_speed_monitor.py │ │ ├── cumulative_flops.py │ │ └── model_flops_utilization.py │ ├── configs │ ├── __init__.py │ ├── wandb │ │ └── default.yaml │ ├── training │ │ ├── scgpt.yaml │ │ ├── scvi.yaml │ │ ├── cpa.yaml │ │ └── default.yaml │ ├── model │ │ ├── embedsum.yaml │ │ ├── scvi.yaml │ │ ├── cpa.yaml │ │ ├── old_neuralot.yaml │ │ ├── celltypemean.yaml │ │ ├── context_mean.yaml │ │ ├── decoder_only.yaml │ │ ├── perturb_mean.yaml │ │ ├── globalsimplesum.yaml │ │ ├── tahoe_best.yaml │ │ ├── tahoe_llama_62089464.yaml │ │ ├── tahoe_llama_212693232.yaml │ │ ├── state.yaml │ │ ├── scgpt-genetic.yaml │ │ ├── scgpt-chemical.yaml │ │ ├── state_sm.yaml │ │ ├── state_lg.yaml │ │ └── pertsets.yaml │ ├── data │ │ ├── default.yaml │ │ └── perturbation.yaml │ ├── config.yaml │ └── state-defaults.yaml │ ├── _cli │ ├── __init__.py │ ├── _emb │ │ ├── __init__.py │ │ ├── _fit.py │ │ ├── _query.py │ │ └── _transform.py │ └── _tx │ │ ├── __init__.py │ │ ├── _preprocess_train.py │ │ └── _preprocess_infer.py │ └── __main__.py ├── .python-version ├── .github ├── CODEOWNERS └── workflows │ └── release.yml ├── examples ├── random.h5ad ├── zeroshot.toml ├── mixed.toml └── fewshot.toml ├── assets └── generalization_task.png ├── ruff.toml ├── .gitignore ├── singularity.def ├── pyproject.toml ├── scripts └── state_embed_anndata.py ├── MODEL_ACCEPTABLE_USE_POLICY.md ├── tests └── test_callbacks.py └── MODEL_LICENSE.md /.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/state/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /src/state/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/state/emb/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/state/tx/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/state/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/state/emb/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/state/emb/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @ArcInstitute/state-admins 2 | -------------------------------------------------------------------------------- /examples/random.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArcInstitute/state/HEAD/examples/random.h5ad -------------------------------------------------------------------------------- /src/state/emb/eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .emb import cluster_embedding 2 | 3 | __all__ = ["cluster_embedding"] 4 | -------------------------------------------------------------------------------- /assets/generalization_task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArcInstitute/state/HEAD/assets/generalization_task.png -------------------------------------------------------------------------------- /src/state/tx/models/cpa/__init__.py: -------------------------------------------------------------------------------- 1 | from ._model import CPAPerturbationModel 2 | 3 | __all__ = ["CPAPerturbationModel"] 4 | -------------------------------------------------------------------------------- /src/state/tx/models/scvi/__init__.py: -------------------------------------------------------------------------------- 1 | from ._model import SCVIPerturbationModel 2 | 3 | __all__ = ["SCVIPerturbationModel"] 4 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # Allow fixes for all enabled rules 2 | fix = true 3 | 4 | # Line length 5 | line-length = 120 6 | 7 | [lint] 8 | ignore = ["E722"] 9 | -------------------------------------------------------------------------------- /src/state/tx/data/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .scgpt_perturbation_dataset import scGPTPerturbationDataset 2 | 3 | __all__ = ["scGPTPerturbationDataset"] 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | hostfile 3 | __pycache__/ 4 | vci_job_* 5 | lightning_logs/ 6 | outputs/ 7 | log/ 8 | uv.lock 9 | tmp/ 10 | notebooks/ 11 | *.sif 12 | *.slurm 13 | temp 14 | wandb/ 15 | -------------------------------------------------------------------------------- /src/state/emb/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import H5adSentenceDataset, VCIDatasetSentenceCollator, create_dataloader 2 | 3 | __all__ = ["H5adSentenceDataset", "VCIDatasetSentenceCollator", "create_dataloader"] 4 | 
-------------------------------------------------------------------------------- /src/state/configs/wandb/default.yaml: -------------------------------------------------------------------------------- 1 | # Generic wandb configuration 2 | # Users should customize these values for their own use 3 | entity: your_entity_name 4 | project: state 5 | local_wandb_dir: ./wandb_logs 6 | tags: [] 7 | -------------------------------------------------------------------------------- /src/state/configs/training/scgpt.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 250000 2 | train_seed: 42 3 | val_freq: 5000 4 | test_freq: 9000 5 | gradient_clip_val: 10 # 0 means no clipping 6 | 7 | lr: 5e-5 8 | wd: 4e-7 9 | step_size_lr: 25 10 | do_clip_grad: false 11 | batch_size: 256 12 | -------------------------------------------------------------------------------- /src/state/configs/training/scvi.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 250000 2 | train_seed: 42 3 | val_freq: 5000 4 | test_freq: 9000 5 | gradient_clip_val: 10 # 0 means no clipping 6 | 7 | n_epochs_kl_warmup: 1e4 8 | lr: 5e-4 9 | wd: 4e-7 10 | step_size_lr: 25 11 | do_clip_grad: false 12 | batch_size: 2048 -------------------------------------------------------------------------------- /src/state/configs/model/embedsum.yaml: -------------------------------------------------------------------------------- 1 | name: EmbedSum 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | hidden_dim: 512 8 | n_encoder_layers: 3 9 | n_decoder_layers: 3 10 | dropout: 0.0 11 | embed_key: X_uce # this is not actually used right now but would be good to move from the data loader to model? 
12 | predict_residual: True -------------------------------------------------------------------------------- /src/state/configs/training/cpa.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 250000 2 | train_seed: 42 3 | val_freq: 5000 4 | test_freq: 9000 5 | gradient_clip_val: 10 # 0 means no clipping 6 | 7 | n_epochs_kl_warmup: null 8 | n_steps_adv_warmup: 50000 9 | n_steps_pretrain_ae: 50000 10 | adv_steps: null 11 | reg_adv: 15.0 12 | pen_adv: 20.0 13 | lr: 5e-4 14 | wd: 4e-7 15 | adv_lr: 5e-4 16 | adv_wd: 4e-7 17 | step_size_lr: 25 18 | do_clip_grad: false 19 | adv_loss: "cce" 20 | batch_size: 2048 -------------------------------------------------------------------------------- /src/state/tx/models/scgpt/__init__.py: -------------------------------------------------------------------------------- 1 | from .generation_model import TransformerGenerator 2 | from .lightning_model import scGPTForPerturbation 3 | from .loss import criterion_neg_log_bernoulli, masked_mse_loss, masked_relative_error 4 | from .utils import map_raw_id_to_vocab_id 5 | 6 | __all__ = [ 7 | "scGPTForPerturbation", 8 | "TransformerGenerator", 9 | "masked_mse_loss", 10 | "criterion_neg_log_bernoulli", 11 | "masked_relative_error", 12 | "map_raw_id_to_vocab_id", 13 | ] 14 | -------------------------------------------------------------------------------- /examples/zeroshot.toml: -------------------------------------------------------------------------------- 1 | # Dataset paths - maps dataset names to their directories 2 | [datasets] 3 | example = "/home/aadduri/state/examples" 4 | 5 | # Training specifications 6 | # All cell types in a dataset automatically go into training (excluding zeroshot/fewshot overrides) 7 | [training] 8 | example = "train" 9 | 10 | # Zeroshot specifications - entire cell types go to val or test 11 | [zeroshot] 12 | "example.CT3" = "test" 13 | 14 | # Fewshot specifications - explicit perturbation lists 15 | [fewshot] 16 | -------------------------------------------------------------------------------- /src/state/configs/model/scvi.yaml: -------------------------------------------------------------------------------- 1 | name: scVI 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | n_latent: 84 7 | recon_loss: zinb 8 | pert_embeddings: null 9 | hidden_dim: 256 # not used 10 | n_hidden_encoder: 512 11 | n_layers_encoder: 2 12 | n_hidden_decoder: 512 13 | n_layers_decoder: 2 14 | use_batch_norm: both 15 | use_layer_norm: none 16 | dropout_rate_encoder: 0.1 17 | dropout_rate_decoder: 0.1 18 | expr_transform: none 19 | seed: 2025 20 | cell_sentence_len: 512 21 | nb_decoder: false -------------------------------------------------------------------------------- /src/state/configs/training/default.yaml: -------------------------------------------------------------------------------- 1 | wandb_track: true 2 | weight_decay: 0.0005 3 | batch_size: 16 4 | lr: 1e-4 5 | max_steps: 40000 6 | train_seed: 42 7 | val_freq: 2000 8 | ckpt_every_n_steps: 2000 9 | gradient_clip_val: 10 # 0 means no clipping 10 | loss_fn: mse 11 | devices: 1 # Number of GPUs to use for training 12 | strategy: auto # DDP strategy for multi-GPU training 13 | use_mfu: true 14 | mfu_kwargs: 15 | available_flops: 60e12 16 | use_backward: true 17 | logging_interval: 10 18 | window_size: 2 19 | cumulative_flops_use_backward: true -------------------------------------------------------------------------------- /src/state/tx/models/scgpt/grad_reverse.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class GradReverse(Function): 6 | @staticmethod 7 | def forward(ctx, x: torch.Tensor, lambd: float) -> torch.Tensor: 8 | ctx.lambd = lambd 9 | return x.view_as(x) 10 | 11 | @staticmethod 12 | def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: 13 | return grad_output.neg() * ctx.lambd, None 14 | 15 | 16 | def grad_reverse(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor: 17 | return GradReverse.apply(x, lambd) 18 | -------------------------------------------------------------------------------- /examples/mixed.toml: -------------------------------------------------------------------------------- 1 | # Dataset paths - maps dataset names to their directories 2 | [datasets] 3 | example = "/home/aadduri/state/examples" 4 | 5 | # Training specifications 6 | # All cell types in a dataset automatically go into training (excluding zeroshot/fewshot overrides) 7 | [training] 8 | example = "train" 9 | 10 | # Zeroshot specifications - entire cell types go to val or test 11 | [zeroshot] 12 | "example.CT3" = "test" 13 | 14 | # Fewshot specifications - explicit perturbation lists 15 | [fewshot] 16 | 17 | [fewshot."example.CT4"] 18 | val = ["TARGET3"] 19 | test = ["TARGET4", "TARGET5"] # can overlap with val 20 | -------------------------------------------------------------------------------- /src/state/tx/utils/singleton.py: -------------------------------------------------------------------------------- 1 | # singleton.py 2 | 3 | import logging 4 | 5 | """ 6 | Metaclass for singletons. 7 | """ 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class Singleton(type): 13 | """ 14 | Ensures single instance of a class.
15 | 16 | Example Usage: 17 | class MySingleton(metaclass=Singleton) 18 | pass 19 | """ 20 | 21 | _instances = {} 22 | 23 | def __call__(cls, *args, **kwargs): 24 | if cls not in cls._instances: 25 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 26 | return cls._instances[cls] 27 | -------------------------------------------------------------------------------- /examples/fewshot.toml: -------------------------------------------------------------------------------- 1 | # Dataset paths - maps dataset names to their directories 2 | [datasets] 3 | example = "/home/aadduri/state/examples" # CHANGE THIS TO YOUR DIRECTORY 4 | 5 | # Training specifications 6 | # All cell types in a dataset automatically go into training (excluding zeroshot/fewshot overrides) 7 | [training] 8 | example = "train" 9 | 10 | # Zeroshot specifications - entire cell types go to val or test 11 | [zeroshot] 12 | 13 | # Fewshot specifications - explicit perturbation lists 14 | [fewshot] 15 | 16 | [fewshot."example.CT4"] 17 | val = ["TARGET3"] 18 | test = ["TARGET4", "TARGET5"] # can overlap with val 19 | -------------------------------------------------------------------------------- /src/state/configs/model/cpa.yaml: -------------------------------------------------------------------------------- 1 | name: CPA 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | n_latent: 84 7 | recon_loss: gauss 8 | pert_embeddings: null 9 | hidden_dim: 256 # not used 10 | n_hidden_encoder: 1024 11 | n_layers_encoder: 5 12 | n_hidden_decoder: 1024 13 | n_layers_decoder: 4 14 | use_batch_norm: decoder 15 | use_layer_norm: encoder 16 | dropout_rate_encoder: 0.2 17 | dropout_rate_decoder: 0.2 18 | n_hidden_adv: 128 19 | n_layers_adv: 3 20 | use_norm_adv: batch 21 | dropout_rate_adv: 0.25 22 | variational: False 23 | expr_transform: none 24 | seed: 2025 25 | cell_sentence_len: 512 26 | nb_decoder: false -------------------------------------------------------------------------------- /src/state/configs/data/default.yaml: -------------------------------------------------------------------------------- 1 | name: 2 | kwargs: 3 | embed_key: X_uce 4 | embed_size: null 5 | pert_rep: onehot 6 | basal_rep: sample 7 | only_keep_perts_with_expression: false # only keep perturbations for which expression data is available 8 | esm_perts_only: false 9 | n_basal_samples: 1 10 | sampling_random_state: 42 11 | split_random_state: 42 12 | normalize: true 13 | pseudobulk: false 14 | load_from_path: null 15 | test_cell_type: 16 | dataloader_preprocess: null 17 | k562_rpe1_name: replogle_k562_rpe1_filtered 18 | jurkat_name: replogle_jurkat_filtered 19 | hepg2_name: replogle_hepg2_filtered 20 | output_dir: null 21 | debug: true -------------------------------------------------------------------------------- /src/state/_cli/__init__.py: -------------------------------------------------------------------------------- 1 | from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query, run_emb_preprocess, run_emb_eval 2 | from ._tx import ( 3 | add_arguments_tx, 4 | run_tx_infer, 5 | run_tx_predict, 6 | run_tx_preprocess_infer, 7 | run_tx_preprocess_train, 8 | run_tx_train, 9 | ) 10 | 11 | __all__ = [ 12 | "add_arguments_emb", 13 | "add_arguments_tx", 14 | "run_tx_train", 15 | "run_tx_predict", 16 | "run_tx_infer", 17 | "run_tx_preprocess_train", 18 | "run_tx_preprocess_infer", 19 | "run_emb_fit", 20 | "run_emb_query", 21 | "run_emb_transform", 22 | "run_emb_preprocess", 23 | "run_emb_eval", 24 | ] 25 | 
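Editor's note on the listing above: the package's top-level __main__.py is not reproduced in this section, but the exports of _cli/__init__.py already show how the CLI is assembled — each add_arguments_* helper attaches its own nested subcommands (with dest="subcommand") to whatever parser it is handed, and the run_* functions carry out the work. Below is a minimal, hypothetical sketch of how these helpers could be wired into the `state` console script (declared in pyproject.toml as state = "state.__main__:main"); the "emb"/"tx" command names and the print-based dispatch are illustrative assumptions, not the repository's actual implementation.

import argparse as ap

from state._cli import add_arguments_emb, add_arguments_tx


def main() -> None:
    # Top-level parser; `state emb ...` and `state tx ...` are assumed command names.
    parser = ap.ArgumentParser(prog="state", description="STATE command-line interface")
    commands = parser.add_subparsers(required=True, dest="command")
    # Each helper registers its own subcommands: fit/transform/query/preprocess/eval
    # for `emb`, and train/predict/infer/preprocess_train/preprocess_infer for `tx`.
    add_arguments_emb(commands.add_parser("emb", help="embedding model commands"))
    add_arguments_tx(commands.add_parser("tx", help="perturbation model commands"))
    args = parser.parse_args()
    # The real entry point dispatches to the run_* functions exported above.
    print(f"would dispatch to: {args.command} {args.subcommand}")


if __name__ == "__main__":
    main()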
-------------------------------------------------------------------------------- /src/state/configs/data/perturbation.yaml: -------------------------------------------------------------------------------- 1 | name: PerturbationDataModule 2 | kwargs: 3 | toml_config_path: null 4 | embed_key: null 5 | output_space: all 6 | pert_rep: onehot 7 | basal_rep: sample 8 | num_workers: 12 9 | pin_memory: true 10 | n_basal_samples: 1 11 | basal_mapping_strategy: random 12 | should_yield_control_cells: true 13 | batch_col: gem_group 14 | pert_col: gene 15 | cell_type_key: cell_type 16 | control_pert: DMSO_TF 17 | map_controls: true # for a control cell, should we use it as the target (learn identity) or sample a control? 18 | perturbation_features_file: null 19 | store_raw_basal: false 20 | int_counts: false 21 | barcode: true 22 | output_dir: null 23 | debug: true 24 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | environment: pypi 11 | permissions: 12 | id-token: write # Required for trusted publishing 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - name: Install uv 18 | uses: astral-sh/setup-uv@v4 19 | with: 20 | version: "latest" 21 | 22 | - name: Set up Python 23 | run: uv python install 24 | 25 | - name: Install build dependencies 26 | run: uv sync --all-extras 27 | 28 | - name: Build the project 29 | run: uv build 30 | 31 | - name: Publish to PyPI 32 | run: uv publish 33 | -------------------------------------------------------------------------------- /src/state/configs/model/old_neuralot.yaml: -------------------------------------------------------------------------------- 1 | name: old_neuralot 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | hidden_dim: 328 8 | loss: energy 9 | n_encoder_layers: 4 10 | n_decoder_layers: 4 11 | predict_residual: True 12 | softplus: True 13 | freeze_pert_backbone: False 14 | transformer_decoder: False 15 | finetune_vci_decoder: False 16 | batch_encoder: False 17 | nb_decoder: False 18 | distributional_loss: energy 19 | transformer_backbone_key: GPT2 20 | transformer_backbone_kwargs: 21 | n_positions: ${model.kwargs.cell_set_len} 22 | n_embd: ${model.kwargs.hidden_dim} 23 | d_inner: 1024 24 | n_layer: 8 25 | n_head: 8 26 | resid_pdrop: 0.0 27 | embd_pdrop: 0.0 28 | attn_pdrop: 0.0 29 | use_cache: false 30 | -------------------------------------------------------------------------------- /src/state/tx/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import PerturbationModel 2 | from .context_mean import ContextMeanPerturbationModel 3 | from .decoder_only import DecoderOnlyPerturbationModel 4 | from .embed_sum import EmbedSumPerturbationModel 5 | from .perturb_mean import PerturbMeanPerturbationModel 6 | from .old_neural_ot import OldNeuralOTPerturbationModel 7 | from .state_transition import StateTransitionPerturbationModel 8 | from .pseudobulk import PseudobulkPerturbationModel 9 | 10 | __all__ = [ 11 | "PerturbationModel", 12 | "PerturbMeanPerturbationModel", 13 | "ContextMeanPerturbationModel", 14 | "EmbedSumPerturbationModel", 15 | "StateTransitionPerturbationModel", 16 | "OldNeuralOTPerturbationModel", 17 | "DecoderOnlyPerturbationModel", 18 | 
"PseudobulkPerturbationModel", 19 | ] 20 | -------------------------------------------------------------------------------- /src/state/configs/model/celltypemean.yaml: -------------------------------------------------------------------------------- 1 | name: CellTypeMean 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | hidden_dim: 512 8 | cell_sentence_len: 512 9 | loss: energy 10 | n_encoder_layers: 4 11 | n_decoder_layers: 4 12 | predict_residual: True 13 | softplus: False 14 | freeze_pert_backbone: False 15 | transformer_decoder: False 16 | finetune_vci_decoder: False 17 | batch_encoder: False 18 | nb_decoder: False 19 | distributional_loss: energy 20 | transformer_backbone_key: GPT2 21 | transformer_backbone_kwargs: 22 | n_positions: ${model.kwargs.cell_set_len} 23 | n_embd: 512 24 | d_inner: 1024 25 | n_layer: 8 26 | n_head: 8 27 | resid_pdrop: 0.0 28 | embd_pdrop: 0.0 29 | attn_pdrop: 0.0 30 | use_cache: false 31 | -------------------------------------------------------------------------------- /src/state/configs/model/context_mean.yaml: -------------------------------------------------------------------------------- 1 | name: context_mean 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | hidden_dim: 512 8 | cell_sentence_len: 512 9 | loss: energy 10 | n_encoder_layers: 4 11 | n_decoder_layers: 4 12 | predict_residual: True 13 | softplus: False 14 | freeze_pert_backbone: False 15 | transformer_decoder: False 16 | finetune_vci_decoder: False 17 | batch_encoder: False 18 | nb_decoder: False 19 | distributional_loss: energy 20 | transformer_backbone_key: GPT2 21 | transformer_backbone_kwargs: 22 | n_positions: ${model.kwargs.cell_set_len} 23 | n_embd: 512 24 | d_inner: 1024 25 | n_layer: 8 26 | n_head: 8 27 | resid_pdrop: 0.0 28 | embd_pdrop: 0.0 29 | attn_pdrop: 0.0 30 | use_cache: false 31 | -------------------------------------------------------------------------------- /src/state/configs/model/decoder_only.yaml: -------------------------------------------------------------------------------- 1 | name: decoder_only 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | hidden_dim: 512 8 | cell_sentence_len: 512 9 | loss: energy 10 | n_encoder_layers: 4 11 | n_decoder_layers: 4 12 | predict_residual: True 13 | softplus: False 14 | freeze_pert_backbone: False 15 | transformer_decoder: False 16 | finetune_vci_decoder: False 17 | batch_encoder: False 18 | nb_decoder: False 19 | distributional_loss: energy 20 | transformer_backbone_key: GPT2 21 | transformer_backbone_kwargs: 22 | n_positions: ${model.kwargs.cell_set_len} 23 | n_embd: 512 24 | d_inner: 1024 25 | n_layer: 8 26 | n_head: 8 27 | resid_pdrop: 0.0 28 | embd_pdrop: 0.0 29 | attn_pdrop: 0.0 30 | use_cache: false 31 | -------------------------------------------------------------------------------- /src/state/configs/model/perturb_mean.yaml: -------------------------------------------------------------------------------- 1 | name: perturb_mean 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | hidden_dim: 512 8 | cell_sentence_len: 512 9 | loss: energy 10 | n_encoder_layers: 4 11 | n_decoder_layers: 4 12 | predict_residual: True 13 | batch_encoder: False 14 | softplus: False 15 | freeze_pert_backbone: False 16 | transformer_decoder: False 17 | finetune_vci_decoder: False 18 | nb_decoder: False 19 | distributional_loss: energy 20 | transformer_backbone_key: GPT2 21 | transformer_backbone_kwargs: 22 | n_positions: 
${model.kwargs.cell_set_len} 23 | n_embd: 512 24 | d_inner: 1024 25 | n_layer: 8 26 | n_head: 8 27 | resid_pdrop: 0.0 28 | embd_pdrop: 0.0 29 | attn_pdrop: 0.0 30 | use_cache: false 31 | -------------------------------------------------------------------------------- /src/state/configs/model/globalsimplesum.yaml: -------------------------------------------------------------------------------- 1 | name: GlobalSimpleSum 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | hidden_dim: 512 8 | cell_sentence_len: 512 9 | loss: energy 10 | n_encoder_layers: 4 11 | n_decoder_layers: 4 12 | predict_residual: True 13 | batch_encoder: False 14 | softplus: False 15 | freeze_pert_backbone: False 16 | transformer_decoder: False 17 | finetune_vci_decoder: False 18 | nb_decoder: False 19 | distributional_loss: energy 20 | transformer_backbone_key: GPT2 21 | transformer_backbone_kwargs: 22 | n_positions: ${model.kwargs.cell_set_len} 23 | n_embd: 512 24 | d_inner: 1024 25 | n_layer: 8 26 | n_head: 8 27 | resid_pdrop: 0.0 28 | embd_pdrop: 0.0 29 | attn_pdrop: 0.0 30 | use_cache: false 31 | -------------------------------------------------------------------------------- /src/state/configs/config.yaml: -------------------------------------------------------------------------------- 1 | # This is a template used in the application to generate the config file for 2 | # training tasks 3 | defaults: 4 | - data: perturbation 5 | - model: pertsets 6 | - training: default 7 | - wandb: default 8 | - _self_ 9 | 10 | 11 | # output_dir must be an absolute path (so that launch scripts are fully descriptive) 12 | name: debug 13 | output_dir: ./debugging 14 | use_wandb: true 15 | overwrite: false 16 | return_adatas: false 17 | pred_adata_path: null 18 | true_adata_path: null 19 | 20 | # don't save hydra output 21 | hydra: 22 | output_subdir: null 23 | run: 24 | dir: .
25 | job_logging: 26 | formatters: 27 | simple: 28 | format: "[%(levelname)s] %(message)s" # Simple format for logging 29 | handlers: 30 | console: 31 | class: logging.StreamHandler 32 | formatter: simple 33 | level: INFO 34 | stream: ext://sys.stdout 35 | root: 36 | level: INFO 37 | loggers: 38 | __main__: 39 | level: DEBUG 40 | handlers: [console] 41 | propagate: false 42 | -------------------------------------------------------------------------------- /src/state/_cli/_emb/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse as ap 2 | 3 | from ._fit import add_arguments_fit, run_emb_fit 4 | from ._transform import add_arguments_transform, run_emb_transform 5 | from ._query import add_arguments_query, run_emb_query 6 | from ._preprocess import add_arguments_preprocess, run_emb_preprocess 7 | from ._eval import add_arguments_eval, run_emb_eval 8 | 9 | __all__ = [ 10 | "run_emb_fit", 11 | "run_emb_transform", 12 | "run_emb_query", 13 | "run_emb_preprocess", 14 | "run_emb_eval", 15 | "add_arguments_emb", 16 | ] 17 | 18 | 19 | def add_arguments_emb(parser: ap.ArgumentParser): 20 | """""" 21 | subparsers = parser.add_subparsers(required=True, dest="subcommand") 22 | add_arguments_fit(subparsers.add_parser("fit")) 23 | add_arguments_transform(subparsers.add_parser("transform")) 24 | add_arguments_query(subparsers.add_parser("query")) 25 | add_arguments_preprocess( 26 | subparsers.add_parser("preprocess", help="Preprocess datasets and create embedding profiles") 27 | ) 28 | add_arguments_eval(subparsers.add_parser("eval", help="Evaluate embeddings")) 29 | -------------------------------------------------------------------------------- /src/state/_cli/_tx/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse as ap 2 | 3 | from ._infer import add_arguments_infer, run_tx_infer 4 | from ._predict import add_arguments_predict, run_tx_predict 5 | from ._preprocess_infer import add_arguments_preprocess_infer, run_tx_preprocess_infer 6 | from ._preprocess_train import add_arguments_preprocess_train, run_tx_preprocess_train 7 | from ._train import add_arguments_train, run_tx_train 8 | 9 | __all__ = [ 10 | "run_tx_train", 11 | "run_tx_predict", 12 | "run_tx_infer", 13 | "run_tx_preprocess_train", 14 | "run_tx_preprocess_infer", 15 | "add_arguments_tx", 16 | ] 17 | 18 | 19 | def add_arguments_tx(parser: ap.ArgumentParser): 20 | """""" 21 | subparsers = parser.add_subparsers(required=True, dest="subcommand") 22 | add_arguments_train(subparsers.add_parser("train", add_help=False)) 23 | add_arguments_predict(subparsers.add_parser("predict")) 24 | add_arguments_infer(subparsers.add_parser("infer")) 25 | add_arguments_preprocess_train(subparsers.add_parser("preprocess_train")) 26 | add_arguments_preprocess_infer(subparsers.add_parser("preprocess_infer")) 27 | -------------------------------------------------------------------------------- /src/state/tx/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from lightning.pytorch.callbacks import Callback 3 | from torch.optim import Optimizer 4 | 5 | from ..models import PerturbationModel 6 | from .batch_speed_monitor import BatchSpeedMonitorCallback 7 | from .model_flops_utilization import ModelFLOPSUtilizationCallback 8 | from .cumulative_flops import CumulativeFLOPSCallback 9 | 10 | __all__ = ["PerturbationModel", "BatchSpeedMonitorCallback", 
"ModelFLOPSUtilizationCallback", "CumulativeFLOPSCallback"] 11 | 12 | 13 | class GradNormCallback(Callback): 14 | """ 15 | Logs the gradient norm. 16 | """ 17 | 18 | def on_before_optimizer_step( 19 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer 20 | ) -> None: 21 | pl_module.log("train/gradient_norm", gradient_norm(pl_module)) 22 | 23 | 24 | def gradient_norm(model): 25 | total_norm = 0.0 26 | for p in model.parameters(): 27 | if p.grad is not None: 28 | param_norm = p.grad.detach().data.norm(2) 29 | total_norm += param_norm.item() ** 2 30 | total_norm = total_norm ** (1.0 / 2) 31 | return total_norm 32 | -------------------------------------------------------------------------------- /src/state/tx/models/scgpt/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def masked_mse_loss(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 6 | """ 7 | Compute the masked MSE loss between input and target. 8 | """ 9 | mask = mask.float() 10 | loss = F.mse_loss(input * mask, target * mask, reduction="sum") 11 | return loss / mask.sum() 12 | 13 | 14 | def criterion_neg_log_bernoulli(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 15 | """ 16 | Compute the negative log-likelihood of Bernoulli distribution 17 | """ 18 | mask = mask.float() 19 | bernoulli = torch.distributions.Bernoulli(probs=input) 20 | masked_log_probs = bernoulli.log_prob((target > 0).float()) * mask 21 | return -masked_log_probs.sum() / mask.sum() 22 | 23 | 24 | def masked_relative_error(input: torch.Tensor, target: torch.Tensor, mask: torch.LongTensor) -> torch.Tensor: 25 | """ 26 | Compute the masked relative error between input and target. 27 | """ 28 | assert mask.any() 29 | loss = torch.abs(input[mask] - target[mask]) / (target[mask] + 1e-6) 30 | return loss.mean() 31 | -------------------------------------------------------------------------------- /singularity.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: ubuntu:22.04 3 | 4 | %labels 5 | Author Nick Youngblut 6 | Version 1.0 7 | Description STATE - machine learning model for cellular perturbation prediction 8 | 9 | %help 10 | This container includes STATE (https://github.com/ArcInstitute/state), a machine learning model 11 | that predicts cellular perturbation response across diverse contexts. 12 | 13 | STATE is trained on single-cell RNA-seq data and can predict how cells respond to various 14 | perturbations including drugs, genetic modifications, and environmental changes. 
15 | 16 | To build with singularity, run: 17 | singularity build state.sif singularity.def 18 | 19 | To run the container, run: 20 | singularity run state.sif --help 21 | 22 | %environment 23 | export PATH="/root/.local/bin:$PATH" 24 | 25 | %post 26 | # Install system dependencies 27 | apt-get update && apt-get install -y \ 28 | curl \ 29 | build-essential \ 30 | python3-dev \ 31 | && rm -rf /var/lib/apt/lists/* 32 | 33 | # Install uv 34 | curl -LsSf https://astral.sh/uv/install.sh | sh 35 | 36 | # Install STATE 37 | /root/.local/bin/uv tool install arc-state 38 | 39 | %runscript 40 | exec state "$@" -------------------------------------------------------------------------------- /src/state/tx/models/cpa/_callbacks.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.callbacks import Callback 2 | 3 | 4 | class CPABestModelTracker(Callback): 5 | def __init__(self, monitor: str = "val_loss", mode: str = "min"): 6 | super().__init__() 7 | self.monitor = monitor 8 | self.mode = mode 9 | self.best_model = None 10 | self.best_score = None 11 | 12 | def on_validation_end(self, trainer, pl_module): 13 | if self.best_score is None: 14 | self.best_score = trainer.callback_metrics[self.monitor] 15 | self.best_model = pl_module.state_dict() 16 | else: 17 | if self.mode == "min": 18 | if trainer.callback_metrics[self.monitor] < self.best_score: 19 | self.best_score = trainer.callback_metrics[self.monitor] 20 | self.best_model = pl_module.state_dict() 21 | else: 22 | if trainer.callback_metrics[self.monitor] > self.best_score: 23 | self.best_score = trainer.callback_metrics[self.monitor] 24 | self.best_model = pl_module.state_dict() 25 | 26 | def on_train_end(self, trainer, pl_module): 27 | pl_module.load_state_dict(self.best_model) 28 | print(f"Best model loaded with {self.monitor} = {self.best_score}") 29 | return self.best_model 30 | -------------------------------------------------------------------------------- /src/state/tx/models/scvi/_callbacks.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.callbacks import Callback 2 | 3 | 4 | class CPABestModelTracker(Callback): 5 | def __init__(self, monitor: str = "val_loss", mode: str = "min"): 6 | super().__init__() 7 | self.monitor = monitor 8 | self.mode = mode 9 | self.best_model = None 10 | self.best_score = None 11 | 12 | def on_validation_end(self, trainer, pl_module): 13 | if self.best_score is None: 14 | self.best_score = trainer.callback_metrics[self.monitor] 15 | self.best_model = pl_module.state_dict() 16 | else: 17 | if self.mode == "min": 18 | if trainer.callback_metrics[self.monitor] < self.best_score: 19 | self.best_score = trainer.callback_metrics[self.monitor] 20 | self.best_model = pl_module.state_dict() 21 | else: 22 | if trainer.callback_metrics[self.monitor] > self.best_score: 23 | self.best_score = trainer.callback_metrics[self.monitor] 24 | self.best_model = pl_module.state_dict() 25 | 26 | def on_train_end(self, trainer, pl_module): 27 | pl_module.load_state_dict(self.best_model) 28 | print(f"Best model loaded with {self.monitor} = {self.best_score}") 29 | return self.best_model 30 | -------------------------------------------------------------------------------- /src/state/configs/model/tahoe_best.yaml: -------------------------------------------------------------------------------- 1 | name: PertSets 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | blur: 0.05 8 | hidden_dim: 
1440 # hidden dimension going into the transformer backbone 9 | loss: energy 10 | confidence_head: False 11 | n_encoder_layers: 4 12 | n_decoder_layers: 4 13 | predict_residual: True 14 | softplus: True 15 | freeze_pert: False 16 | transformer_decoder: False 17 | finetune_vci_decoder: False 18 | residual_decoder: False 19 | batch_encoder: False 20 | nb_decoder: False 21 | mask_attn: False 22 | use_effect_gating_token: False 23 | distributional_loss: energy 24 | init_from: null 25 | transformer_backbone_key: llama 26 | transformer_backbone_kwargs: 27 | bidirectional_attention: false 28 | max_position_embeddings: ${model.kwargs.cell_set_len} 29 | hidden_size: ${model.kwargs.hidden_dim} 30 | intermediate_size: 4416 31 | num_hidden_layers: 4 32 | num_attention_heads: 12 33 | num_key_value_heads: 12 34 | head_dim: 120 35 | use_cache: false 36 | attention_dropout: 0.0 37 | hidden_dropout: 0.0 38 | layer_norm_eps: 1e-6 39 | pad_token_id: 0 40 | bos_token_id: 1 41 | eos_token_id: 2 42 | tie_word_embeddings: false 43 | rotary_dim: 0 44 | use_rotary_embeddings: false 45 | -------------------------------------------------------------------------------- /src/state/configs/model/tahoe_llama_62089464.yaml: -------------------------------------------------------------------------------- 1 | name: PertSets 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | blur: 0.05 8 | hidden_dim: 696 # hidden dimension going into the transformer backbone 9 | loss: energy 10 | confidence_head: False 11 | n_encoder_layers: 4 12 | n_decoder_layers: 4 13 | predict_residual: True 14 | softplus: True 15 | freeze_pert: False 16 | transformer_decoder: False 17 | finetune_vci_decoder: False 18 | residual_decoder: False 19 | batch_encoder: False 20 | nb_decoder: False 21 | mask_attn: False 22 | use_effect_gating_token: False 23 | distributional_loss: energy 24 | init_from: null 25 | transformer_backbone_key: llama 26 | transformer_backbone_kwargs: 27 | bidirectional_attention: false 28 | max_position_embeddings: ${model.kwargs.cell_set_len} 29 | hidden_size: ${model.kwargs.hidden_dim} 30 | intermediate_size: 2784 31 | num_hidden_layers: 8 32 | num_attention_heads: 12 33 | num_key_value_heads: 12 34 | head_dim: 58 35 | use_cache: false 36 | attention_dropout: 0.0 37 | hidden_dropout: 0.0 38 | layer_norm_eps: 1e-6 39 | pad_token_id: 0 40 | bos_token_id: 1 41 | eos_token_id: 2 42 | tie_word_embeddings: false 43 | rotary_dim: 0 44 | use_rotary_embeddings: false 45 | -------------------------------------------------------------------------------- /src/state/configs/model/tahoe_llama_212693232.yaml: -------------------------------------------------------------------------------- 1 | name: PertSets 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | blur: 0.05 8 | hidden_dim: 1488 # hidden dimension going into the transformer backbone 9 | loss: energy 10 | confidence_head: False 11 | n_encoder_layers: 4 12 | n_decoder_layers: 4 13 | predict_residual: True 14 | softplus: True 15 | freeze_pert: False 16 | transformer_decoder: False 17 | finetune_vci_decoder: False 18 | residual_decoder: False 19 | decoder_loss_weight: 1.0 20 | batch_encoder: False 21 | nb_decoder: False 22 | mask_attn: False 23 | use_effect_gating_token: False 24 | use_basal_projection: False 25 | distributional_loss: energy 26 | init_from: null 27 | transformer_backbone_key: llama 28 | transformer_backbone_kwargs: 29 | bidirectional_attention: false 30 | max_position_embeddings: ${model.kwargs.cell_set_len} 31 | 
hidden_size: ${model.kwargs.hidden_dim} 32 | intermediate_size: 5952 33 | num_hidden_layers: 6 34 | num_attention_heads: 12 35 | num_key_value_heads: 12 36 | head_dim: 124 37 | use_cache: false 38 | attention_dropout: 0.0 39 | hidden_dropout: 0.0 40 | layer_norm_eps: 1e-6 41 | pad_token_id: 0 42 | bos_token_id: 1 43 | eos_token_id: 2 44 | tie_word_embeddings: false 45 | rotary_dim: 0 46 | use_rotary_embeddings: false 47 | -------------------------------------------------------------------------------- /src/state/tx/models/scgpt/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def map_raw_id_to_vocab_id( 8 | raw_ids: Union[np.ndarray, torch.Tensor], 9 | gene_ids: np.ndarray, 10 | ) -> Union[np.ndarray, torch.Tensor]: 11 | """ 12 | Map some raw ids which are indices of the raw gene names to the indices of the 13 | 14 | Args: 15 | raw_ids: the raw ids to map 16 | gene_ids: the gene ids to map to 17 | """ 18 | if isinstance(raw_ids, torch.Tensor): 19 | device = raw_ids.device 20 | dtype = raw_ids.dtype 21 | return_pt = True 22 | raw_ids = raw_ids.cpu().numpy() 23 | 24 | elif isinstance(raw_ids, np.ndarray): 25 | return_pt = False 26 | dtype = raw_ids.dtype 27 | 28 | else: 29 | raise ValueError("raw_ids must be either torch.Tensor or np.ndarray.") 30 | 31 | if raw_ids.ndim != 1: 32 | raise ValueError(f"raw_ids must be 1d, got {raw_ids.ndim}d.") 33 | 34 | if gene_ids.ndim != 1: 35 | raise ValueError(f"gene_ids must be 1d, got {gene_ids.ndim}d.") 36 | 37 | mapped_ids: np.ndarray = gene_ids[raw_ids] 38 | assert mapped_ids.shape == raw_ids.shape 39 | if return_pt: 40 | if isinstance(mapped_ids, np.ndarray): 41 | return torch.from_numpy(mapped_ids).type(dtype).to(device) 42 | return mapped_ids.to(dtype) 43 | return mapped_ids.astype(dtype) 44 | -------------------------------------------------------------------------------- /src/state/emb/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | # Set up VCI module aliases for backward compatibility 5 | def _setup_vci_aliases(): 6 | """Set up vci.* aliases to point to the current emb module structure.""" 7 | current_module = sys.modules[__name__] 8 | 9 | # Main vci alias 10 | sys.modules["vci"] = current_module 11 | 12 | # Import and alias submodules 13 | try: 14 | from emb import nn 15 | 16 | sys.modules["vci.nn"] = nn 17 | sys.modules["vci.nn.model"] = nn.model 18 | except ImportError: 19 | pass 20 | 21 | try: 22 | from emb import train 23 | 24 | sys.modules["vci.train"] = train 25 | sys.modules["vci.train.trainer"] = train.trainer 26 | except ImportError: 27 | pass 28 | 29 | try: 30 | from emb import data 31 | 32 | sys.modules["vci.data"] = data 33 | if hasattr(data, "loader"): 34 | sys.modules["vci.data.loader"] = data.loader 35 | except ImportError: 36 | pass 37 | 38 | try: 39 | from emb import utils 40 | 41 | sys.modules["vci.utils"] = utils 42 | except ImportError: 43 | pass 44 | 45 | try: 46 | from emb import eval as eval_module 47 | 48 | sys.modules["vci.eval"] = eval_module 49 | except ImportError: 50 | pass 51 | 52 | 53 | # Set up the aliases when this module is imported 54 | _setup_vci_aliases() 55 | 56 | # Your existing exports 57 | from .inference import Inference 58 | 59 | __all__ = ["Inference"] 60 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "arc-state" 3 | version = "0.9.32" 4 | description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts." 5 | readme = "README.md" 6 | authors = [ 7 | { name = "Abhinav Adduri", email = "abhinav.adduri@arcinstitute.org" }, 8 | { name = "Yusuf Roohani", email = "yusuf.roohani@arcinstitute.org" }, 9 | { name = "Noam Teyssier", email = "noam.teyssier@arcinstitute.org" }, 10 | { name = "Rajesh Ilango" }, 11 | { name = "Dhruv Gautam", email = "dhruvgautam@berkeley.edu" }, 12 | ] 13 | requires-python = ">=3.10,<3.13" 14 | dependencies = [ 15 | "anndata>=0.11.4", 16 | "cell-load>=0.8.3", 17 | "numpy>=2.2.6", 18 | "pandas>=2.2.3", 19 | "pyyaml>=6.0.2", 20 | "scanpy>=1.11.2", 21 | "scikit-learn>=1.6.1", 22 | "seaborn>=0.13.2", 23 | "torch>=2.7.0", 24 | "tqdm>=4.67.1", 25 | "wandb>=0.19.11", 26 | "hydra-core>=1.3.2", 27 | "geomloss>=0.2.6", 28 | "transformers>=4.52.3", 29 | "peft>=0.11.0", 30 | "cell-eval>=0.5.22", 31 | "ipykernel>=6.30.1", 32 | "scipy>=1.15.0", 33 | ] 34 | 35 | [project.optional-dependencies] 36 | vectordb = [ 37 | "lancedb>=0.24.0" 38 | ] 39 | 40 | [dependency-groups] 41 | dev = ["ruff>=0.11.11", "vulture>=2.14", "ipython>=8.37.0"] 42 | 43 | [build-system] 44 | requires = ["hatchling"] 45 | build-backend = "hatchling.build" 46 | 47 | [project.scripts] 48 | state = "state.__main__:main" 49 | 50 | [tool.pyright] 51 | venvPath = "." 52 | venv = ".venv" 53 | 54 | [tool.hatch.build.targets.wheel] 55 | packages = ["src/state"] 56 | -------------------------------------------------------------------------------- /src/state/configs/model/state.yaml: -------------------------------------------------------------------------------- 1 | name: state 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | blur: 0.05 8 | hidden_dim: 696 # hidden dimension going into the transformer backbone 9 | loss: energy 10 | confidence_head: False 11 | n_encoder_layers: 1 12 | n_decoder_layers: 1 13 | predict_residual: True 14 | softplus: True 15 | freeze_pert_backbone: False 16 | transformer_decoder: False 17 | finetune_vci_decoder: False 18 | residual_decoder: False 19 | batch_encoder: False 20 | use_batch_token: False 21 | nb_decoder: False 22 | mask_attn: False 23 | use_effect_gating_token: False 24 | distributional_loss: energy 25 | init_from: null 26 | transformer_backbone_key: llama 27 | transformer_backbone_kwargs: 28 | bidirectional_attention: false 29 | max_position_embeddings: ${model.kwargs.cell_set_len} 30 | hidden_size: ${model.kwargs.hidden_dim} 31 | intermediate_size: 2784 32 | num_hidden_layers: 8 33 | num_attention_heads: 12 34 | num_key_value_heads: 12 35 | head_dim: 58 36 | use_cache: false 37 | attention_dropout: 0.0 38 | hidden_dropout: 0.0 39 | layer_norm_eps: 1e-6 40 | pad_token_id: 0 41 | bos_token_id: 1 42 | eos_token_id: 2 43 | tie_word_embeddings: false 44 | rotary_dim: 0 45 | use_rotary_embeddings: false 46 | lora: 47 | enable: false 48 | r: 16 49 | alpha: 32 50 | dropout: 0.05 51 | bias: none 52 | target: auto 53 | adapt_mlp: false 54 | task_type: FEATURE_EXTRACTION 55 | merge_on_eval: false 56 | -------------------------------------------------------------------------------- /src/state/configs/model/scgpt-genetic.yaml: -------------------------------------------------------------------------------- 1 | name: scGPT-genetic 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | hidden_dim: 256 # 
not used 7 | pad_token: "" 8 | special_tokens: 9 | - "" 10 | - "" 11 | - "" 12 | 13 | pad_value: 0 14 | pert_pad_id: 2 15 | 16 | include_zero_gene: "all" # include zero expr genes in training input, "all", "batch-wise", "row-wise", or False 17 | max_seq_len: 1536 18 | 19 | do_MLM: true # whether to use masked language modeling, currently it is always on. 20 | do_CLS: false # celltype classification objective 21 | do_CCE: false # Contrastive cell embedding objective 22 | do_MVC: false # Masked value prediction for cell embedding 23 | do_ECS: false # Elastic cell similarity objective 24 | cell_emb_style: "cls" 25 | mvc_decoder_style: "inner product, detach" 26 | use_amp: true 27 | pretrained_path: "/large_storage/goodarzilab/userspace/mohsen/scGPT/scGPT_human/" 28 | load_param_prefixes: 29 | - "encoder" 30 | - "value_encoder" 31 | - "transformer_encoder" 32 | 33 | # settings for the model 34 | embsize: 512 # embedding dimension 35 | d_hid: 512 # dimension of the feedforward network model in nn.TransformerEncoder 36 | nlayers: 12 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder 37 | nhead: 8 # number of heads in nn.MultiheadAttention 38 | n_layers_cls: 3 39 | dropout: 0.2 # dropout probability 40 | use_fast_transformer: true # whether to use fast transformer 41 | 42 | expr_transform: none 43 | perturbation_type: genetic 44 | seed: 2025 45 | cell_sentence_len: 2048 46 | nb_decoder: false -------------------------------------------------------------------------------- /src/state/configs/model/scgpt-chemical.yaml: -------------------------------------------------------------------------------- 1 | name: scGPT-chemical 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | hidden_dim: 256 # not used 7 | pad_token: "" 8 | special_tokens: 9 | - "" 10 | - "" 11 | - "" 12 | 13 | pad_value: 0 14 | pert_pad_id: 2 15 | 16 | include_zero_gene: "all" # include zero expr genes in training input, "all", "batch-wise", "row-wise", or False 17 | max_seq_len: 1536 18 | 19 | do_MLM: true # whether to use masked language modeling, currently it is always on. 
20 | do_CLS: false # celltype classification objective 21 | do_CCE: false # Contrastive cell embedding objective 22 | do_MVC: false # Masked value prediction for cell embedding 23 | do_ECS: false # Elastic cell similarity objective 24 | cell_emb_style: "cls" 25 | mvc_decoder_style: "inner product, detach" 26 | use_amp: true 27 | pretrained_path: "/large_storage/goodarzilab/userspace/mohsen/scGPT/scGPT_human/" 28 | load_param_prefixes: 29 | - "encoder" 30 | - "value_encoder" 31 | - "transformer_encoder" 32 | 33 | # settings for the model 34 | embsize: 512 # embedding dimension 35 | d_hid: 512 # dimension of the feedforward network model in nn.TransformerEncoder 36 | nlayers: 12 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder 37 | nhead: 8 # number of heads in nn.MultiheadAttention 38 | n_layers_cls: 3 39 | dropout: 0.2 # dropout probability 40 | use_fast_transformer: true # whether to use fast transformer 41 | 42 | expr_transform: none 43 | perturbation_type: "chemical" 44 | cell_sentence_len: 2048 45 | seed: 2025 46 | nb_decoder: false -------------------------------------------------------------------------------- /src/state/configs/model/state_sm.yaml: -------------------------------------------------------------------------------- 1 | name: state 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 128 7 | blur: 0.05 8 | hidden_dim: 672 # hidden dimension going into the transformer backbone 9 | loss: energy 10 | confidence_head: False 11 | n_encoder_layers: 4 12 | n_decoder_layers: 4 13 | predict_residual: True 14 | softplus: True 15 | freeze_pert_backbone: False 16 | transformer_decoder: False 17 | finetune_vci_decoder: False 18 | residual_decoder: False 19 | batch_encoder: False 20 | nb_decoder: False 21 | mask_attn: False 22 | use_effect_gating_token: False 23 | use_basal_projection: False 24 | distributional_loss: energy 25 | gene_decoder_bool: False 26 | init_from: null 27 | transformer_backbone_key: llama 28 | transformer_backbone_kwargs: 29 | bidirectional_attention: false 30 | max_position_embeddings: ${model.kwargs.cell_set_len} 31 | hidden_size: ${model.kwargs.hidden_dim} 32 | intermediate_size: 2688 33 | num_hidden_layers: 4 34 | num_attention_heads: 8 35 | num_key_value_heads: 8 36 | head_dim: 84 37 | use_cache: false 38 | attention_dropout: 0.0 39 | hidden_dropout: 0.0 40 | layer_norm_eps: 1e-6 41 | pad_token_id: 0 42 | bos_token_id: 1 43 | eos_token_id: 2 44 | tie_word_embeddings: false 45 | rotary_dim: 0 46 | use_rotary_embeddings: false 47 | lora: 48 | enable: false 49 | r: 16 50 | alpha: 32 51 | dropout: 0.05 52 | bias: none 53 | target: auto 54 | adapt_mlp: false 55 | task_type: FEATURE_EXTRACTION 56 | merge_on_eval: false 57 | -------------------------------------------------------------------------------- /src/state/configs/model/state_lg.yaml: -------------------------------------------------------------------------------- 1 | name: state 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 7 | blur: 0.05 8 | hidden_dim: 1488 # hidden dimension going into the transformer backbone 9 | loss: energy 10 | confidence_head: False 11 | n_encoder_layers: 4 12 | n_decoder_layers: 4 13 | predict_residual: True 14 | softplus: True 15 | freeze_pert_backbone: False 16 | transformer_decoder: False 17 | finetune_vci_decoder: False 18 | residual_decoder: False 19 | decoder_loss_weight: 1.0 20 | batch_encoder: False 21 | nb_decoder: False 22 | mask_attn: False 23 | use_effect_gating_token: False 24 | use_basal_projection: 
False 25 | distributional_loss: energy 26 | init_from: null 27 | transformer_backbone_key: llama 28 | transformer_backbone_kwargs: 29 | bidirectional_attention: false 30 | max_position_embeddings: ${model.kwargs.cell_set_len} 31 | hidden_size: ${model.kwargs.hidden_dim} 32 | intermediate_size: 5952 33 | num_hidden_layers: 6 34 | num_attention_heads: 12 35 | num_key_value_heads: 12 36 | head_dim: 124 37 | use_cache: false 38 | attention_dropout: 0.0 39 | hidden_dropout: 0.0 40 | layer_norm_eps: 1e-6 41 | pad_token_id: 0 42 | bos_token_id: 1 43 | eos_token_id: 2 44 | tie_word_embeddings: false 45 | rotary_dim: 0 46 | use_rotary_embeddings: false 47 | lora: 48 | enable: false 49 | r: 16 50 | alpha: 32 51 | dropout: 0.05 52 | bias: none 53 | target: auto 54 | adapt_mlp: false 55 | task_type: FEATURE_EXTRACTION 56 | merge_on_eval: false 57 | -------------------------------------------------------------------------------- /src/state/configs/model/pertsets.yaml: -------------------------------------------------------------------------------- 1 | name: PertSets 2 | checkpoint: null 3 | device: cuda 4 | 5 | kwargs: 6 | cell_set_len: 512 # how many cells to group together into a single set of cells 7 | extra_tokens: 1 # configurable buffer for confidence/special tokens 8 | decoder_hidden_dims: [1024, 1024, 512] 9 | blur: 0.05 10 | hidden_dim: 328 # hidden dimension going into the transformer backbone 11 | loss: energy 12 | confidence_token: False # if true, model tries to predict its own confidence 13 | n_encoder_layers: 4 # number of MLP layers for pert, basal encoders 14 | n_decoder_layers: 4 15 | predict_residual: True # if true, predicts the residual in embedding space to the basal cells 16 | freeze_pert_backbone: False # if true, the perturbation model is frozen 17 | finetune_vci_decoder: False # if true, the pretrained state decoder is used in finetuning 18 | residual_decoder: False # if true, the pretrained state decoder is used in finetuning 19 | batch_encoder: False # if true, batch variables are used 20 | use_batch_token: False # if true, batch token is appended to the sequence 21 | nb_decoder: False # if true, use a negative binomial decoder 22 | decoder_loss_weight: 1.0 23 | use_basal_projection: False 24 | mask_attn: False # if true, mask the attention 25 | distributional_loss: energy 26 | regularization: 0.0 27 | init_from: null # initial checkpoint to start the model 28 | transformer_backbone_key: GPT2 29 | transformer_backbone_kwargs: 30 | max_position_embeddings: ${model.kwargs.cell_set_len} # llama 31 | n_positions: ${model.kwargs.cell_set_len} # gpt2 32 | hidden_size: ${model.kwargs.hidden_dim} # llama 33 | n_embd: ${model.kwargs.hidden_dim} # gpt2 34 | n_layer: 8 35 | n_head: 8 36 | resid_pdrop: 0.0 37 | embd_pdrop: 0.0 38 | attn_pdrop: 0.0 39 | use_cache: false 40 | -------------------------------------------------------------------------------- /src/state/_cli/_emb/_fit.py: -------------------------------------------------------------------------------- 1 | import argparse as ap 2 | 3 | 4 | def add_arguments_fit(parser: ap.ArgumentParser): 5 | """Add arguments for embedding training CLI.""" 6 | parser.add_argument("--conf", type=str, default=None, help="Path to config YAML file") 7 | parser.add_argument( 8 | "hydra_overrides", nargs="*", help="Hydra configuration overrides (e.g., embeddings.current=esm2-cellxgene)" 9 | ) 10 | 11 | 12 | def run_emb_fit(cfg, args): 13 | """ 14 | Run state training with the provided config and overrides. 
15 | """ 16 | import logging 17 | import os 18 | import sys 19 | 20 | from omegaconf import OmegaConf 21 | 22 | from ...emb.train.trainer import main as trainer_main 23 | 24 | log = logging.getLogger(__name__) 25 | 26 | # Load the base configuration 27 | if args.conf: 28 | cfg = OmegaConf.load(args.conf) 29 | 30 | # Process the remaining command line arguments as overrides 31 | if args.hydra_overrides: 32 | overrides = OmegaConf.from_dotlist(args.hydra_overrides) 33 | cfg = OmegaConf.merge(cfg, overrides) 34 | 35 | # Validate required configuration 36 | if cfg.embeddings.current is None: 37 | log.error("Gene embeddings are required for training. Please set 'embeddings.current'") 38 | sys.exit(1) 39 | 40 | if cfg.dataset.current is None: 41 | log.error("Please set the desired dataset to 'dataset.current'") 42 | sys.exit(1) 43 | 44 | # Set environment variables 45 | os.environ["MASTER_PORT"] = str(cfg.experiment.port) 46 | # WAR: Workaround for sbatch failing when --ntasks-per-node is set. 47 | # lightning expects this to be set. 48 | os.environ["SLURM_NTASKS_PER_NODE"] = str(cfg.experiment.num_gpus_per_node) 49 | 50 | log.info(f"*************** Training {cfg.experiment.name} ***************") 51 | log.info(OmegaConf.to_yaml(cfg)) 52 | 53 | # Execute the main training logic 54 | trainer_main(cfg) 55 | -------------------------------------------------------------------------------- /src/state/_cli/_tx/_preprocess_train.py: -------------------------------------------------------------------------------- 1 | import argparse as ap 2 | 3 | 4 | def add_arguments_preprocess_train(parser: ap.ArgumentParser): 5 | """Add arguments for the preprocess_train subcommand.""" 6 | parser.add_argument( 7 | "--adata", 8 | type=str, 9 | required=True, 10 | help="Path to input AnnData file (.h5ad)", 11 | ) 12 | parser.add_argument( 13 | "--output", 14 | type=str, 15 | required=True, 16 | help="Path to output preprocessed AnnData file (.h5ad)", 17 | ) 18 | parser.add_argument( 19 | "--num_hvgs", 20 | type=int, 21 | required=True, 22 | help="Number of highly variable genes to select", 23 | ) 24 | 25 | 26 | def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int): 27 | """ 28 | Preprocess training data by normalizing, log-transforming, and selecting highly variable genes. 29 | 30 | Args: 31 | adata_path: Path to input AnnData file 32 | output_path: Path to save preprocessed AnnData file 33 | num_hvgs: Number of highly variable genes to select 34 | """ 35 | import logging 36 | 37 | import anndata as ad 38 | import scanpy as sc 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | logger.info(f"Loading AnnData from {adata_path}") 43 | adata = ad.read_h5ad(adata_path) 44 | 45 | logger.info("Normalizing total counts per cell") 46 | sc.pp.normalize_total(adata) 47 | 48 | logger.info("Applying log1p transformation") 49 | sc.pp.log1p(adata) 50 | 51 | logger.info(f"Finding top {num_hvgs} highly variable genes") 52 | sc.pp.highly_variable_genes(adata, n_top_genes=num_hvgs) 53 | 54 | logger.info("Storing highly variable genes in .obsm['X_hvg']") 55 | adata.obsm["X_hvg"] = adata[:, adata.var.highly_variable].X.toarray() 56 | 57 | logger.info(f"Saving preprocessed data to {output_path}") 58 | adata.write_h5ad(output_path) 59 | 60 | logger.info(f"Preprocessing complete. 
Selected {adata.var.highly_variable.sum()} highly variable genes.") 61 | -------------------------------------------------------------------------------- /scripts/state_embed_anndata.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | VCI Model Embedding Script 4 | 5 | This script computes embeddings for an input anndata file using a pre-trained VCI model checkpoint. 6 | It can be run from any directory and outputs the embedded anndata to a specified location. 7 | 8 | Usage: 9 | python embed_vci.py --checkpoint PATH_TO_CHECKPOINT --input INPUT_ANNDATA --output OUTPUT_ANNDATA 10 | 11 | Example: 12 | python embed_vci.py --checkpoint /path/to/model.ckpt --input data.h5ad --output embedded_data.h5ad 13 | """ 14 | 15 | import argparse 16 | import os 17 | 18 | from omegaconf import OmegaConf 19 | 20 | from state_sets.state.inference import Inference 21 | 22 | 23 | # Parse command line arguments 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description="Compute embeddings for anndata using a VCI model") 26 | parser.add_argument("--checkpoint", required=True, help="Path to the model checkpoint file") 27 | parser.add_argument("--config", required=True, help="Path to the model training config") 28 | parser.add_argument("--input", required=True, help="Path to input anndata file (h5ad)") 29 | parser.add_argument("--output", required=True, help="Path to output embedded anndata file (h5ad)") 30 | parser.add_argument("--dataset-name", default="perturbation", help="Dataset name to be used in dataloader creation") 31 | parser.add_argument("--gpu", action="store_true", help="Use GPU if available") 32 | parser.add_argument("--filter", action="store_true", help="Filter gene set to our esm embeddings only.") 33 | parser.add_argument("--embed-key", help="Name of key to store") 34 | 35 | return parser.parse_args() 36 | 37 | 38 | def main(): 39 | # Parse command line arguments 40 | args = parse_args() 41 | 42 | conf = OmegaConf.load(args.config) 43 | inferer = Inference(conf) 44 | inferer.load_model(args.checkpoint) 45 | os.makedirs(os.path.dirname(args.output), exist_ok=True) 46 | inferer.encode_adata(args.input, args.output, emb_key=args.embed_key, dataset_name=args.dataset_name) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /src/state/emb/eval/emb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | import matplotlib.lines as mlines 7 | from sklearn.decomposition import PCA 8 | 9 | 10 | def cluster_embedding(adata, current_step, emb_key="X_emb", use_pca=True, job_name=""): 11 | embedding = PCA(n_components=2).fit_transform(adata.obsm[emb_key]) 12 | 13 | # Get the cell type information as a categorical series 14 | cell_types = adata.obs["cell_type"].astype("category") 15 | 16 | # Create a color palette based on the number of unique cell types 17 | palette = sns.color_palette("hsv", len(cell_types.cat.categories)) 18 | color_dict = dict(zip(cell_types.cat.categories, palette)) 19 | 20 | # Instead of using .map (which may fail with a MultiIndex), use a list comprehension 21 | colors = [color_dict[ct] for ct in cell_types] 22 | 23 | # Plot the embedding 24 | plt.figure(figsize=(8, 6)) 25 | plt.scatter(embedding[:, 0], embedding[:, 1], c=colors, s=5, alpha=0.7) 26 | plt.xlabel("Component 1") 27 | 
plt.ylabel("Component 2") 28 | plt.title(f"Embedding ({'PCA'}) for {emb_key} ({job_name} - Iteration: {current_step})") 29 | 30 | # Create legend handles for each cell type 31 | handles = [ 32 | mlines.Line2D([], [], color=color_dict[ct], marker="o", linestyle="None", markersize=6, label=ct) 33 | for ct in cell_types.cat.categories 34 | ] 35 | plt.legend(handles=handles, title="Cell Type", bbox_to_anchor=(1.05, 1), loc="upper left") 36 | fig = plt.gcf() 37 | if wandb.run is not None: 38 | wandb.log({f"Clusters using embedding Iteration: {current_step}": fig}) 39 | 40 | # Also save the figure to a results directory instead of logging to wandb 41 | results_dir = "results/cluster_embeddings" 42 | os.makedirs(results_dir, exist_ok=True) 43 | filename = f"{job_name}_iter{current_step}.png" if job_name else f"iter{current_step}.png" 44 | fig_path = os.path.join(results_dir, filename) 45 | fig.savefig(fig_path, bbox_inches="tight") 46 | plt.close(fig) 47 | print(f"Cluster embedding plot saved to {fig_path}") 48 | -------------------------------------------------------------------------------- /src/state/emb/train/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import hydra 5 | from pathlib import Path 6 | from omegaconf import DictConfig, OmegaConf 7 | from typing import Optional 8 | 9 | sys.path.append("../../") 10 | import vci.train.trainer as train 11 | 12 | log = logging.getLogger(__name__) 13 | # os.environ["NCCL_TIMEOUT"] = "36000" 14 | 15 | # Custom command line resolver for hydra 16 | import argparse 17 | 18 | 19 | def main(config_path: Optional[str] = None): 20 | parser = argparse.ArgumentParser(description="VCI pretraining") 21 | parser.add_argument("--conf", type=str, help="Path to config YAML file") 22 | 23 | # First parse just the conf argument 24 | args, override_args = parser.parse_known_args() 25 | 26 | if not args.conf: 27 | parser.error("--conf argument is required") 28 | 29 | # Initialize hydra with the directory of the config file 30 | config_file = Path(args.conf) 31 | config_dir = str(config_file.parent) 32 | config_name = config_file.name 33 | 34 | # Initialize configuration 35 | with hydra.initialize_config_module(config_module=None, version_base=None): 36 | # Load the base configuration 37 | cfg = OmegaConf.load(args.conf) 38 | 39 | # Process the remaining command line arguments as overrides 40 | if override_args: 41 | overrides = OmegaConf.from_dotlist(override_args) 42 | cfg = OmegaConf.merge(cfg, overrides) 43 | 44 | # Execute the main logic 45 | run_with_config(cfg) 46 | 47 | 48 | def run_with_config(cfg: DictConfig): 49 | if cfg.embeddings.current is None: 50 | log.error("Gene embeddings are required for training. Please set 'embeddings.current'") 51 | sys.exit(1) 52 | 53 | if cfg.dataset.current is None: 54 | log.error("Please set the desired dataset to 'dataset.current'") 55 | sys.exit(1) 56 | 57 | os.environ["MASTER_PORT"] = str(cfg.experiment.port) 58 | # WAR: Workaround for sbatch failing when --ntasks-per-node is set. 59 | # lightning expects this to be set. 
60 | os.environ["SLURM_NTASKS_PER_NODE"] = str(cfg.experiment.num_gpus_per_node) 61 | 62 | log.info(f"*************** Training {cfg.experiment.name} ***************") 63 | log.info(OmegaConf.to_yaml(cfg)) 64 | 65 | train.main(cfg) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /src/state/tx/models/scgpt/dsbn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | # The code is modified from https://github.com/wgchang/DSBN/blob/master/model/dsbn.py 8 | class _DomainSpecificBatchNorm(nn.Module): 9 | _version = 2 10 | 11 | def __init__( 12 | self, 13 | num_features: int, 14 | num_domains: int, 15 | eps: float = 1e-5, 16 | momentum: float = 0.1, 17 | affine: bool = True, 18 | track_running_stats: bool = True, 19 | ): 20 | super(_DomainSpecificBatchNorm, self).__init__() 21 | self._cur_domain = None 22 | self.num_domains = num_domains 23 | self.bns = nn.ModuleList( 24 | [self.bn_handle(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_domains)] 25 | ) 26 | 27 | @property 28 | def bn_handle(self) -> nn.Module: 29 | raise NotImplementedError 30 | 31 | @property 32 | def cur_domain(self) -> Optional[int]: 33 | return self._cur_domain 34 | 35 | @cur_domain.setter 36 | def cur_domain(self, domain_label: int): 37 | self._cur_domain = domain_label 38 | 39 | def reset_running_stats(self): 40 | for bn in self.bns: 41 | bn.reset_running_stats() 42 | 43 | def reset_parameters(self): 44 | for bn in self.bns: 45 | bn.reset_parameters() 46 | 47 | def _check_input_dim(self, input: torch.Tensor): 48 | raise NotImplementedError 49 | 50 | def forward(self, x: torch.Tensor, domain_label: int) -> torch.Tensor: 51 | self._check_input_dim(x) 52 | if domain_label >= self.num_domains: 53 | raise ValueError(f"Domain label {domain_label} exceeds the number of domains {self.num_domains}") 54 | bn = self.bns[domain_label] 55 | self.cur_domain = domain_label 56 | return bn(x) 57 | 58 | 59 | class DomainSpecificBatchNorm1d(_DomainSpecificBatchNorm): 60 | @property 61 | def bn_handle(self) -> nn.Module: 62 | return nn.BatchNorm1d 63 | 64 | def _check_input_dim(self, input: torch.Tensor): 65 | if input.dim() > 3: 66 | raise ValueError("expected at most 3D input (got {}D input)".format(input.dim())) 67 | 68 | 69 | class DomainSpecificBatchNorm2d(_DomainSpecificBatchNorm): 70 | @property 71 | def bn_handle(self) -> nn.Module: 72 | return nn.BatchNorm2d 73 | 74 | def _check_input_dim(self, input: torch.Tensor): 75 | if input.dim() != 4: 76 | raise ValueError("expected 4D input (got {}D input)".format(input.dim())) 77 | -------------------------------------------------------------------------------- /src/state/tx/callbacks/batch_speed_monitor.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from lightning.pytorch.callbacks import Callback 4 | 5 | 6 | class BatchSpeedMonitorCallback(Callback): 7 | """ 8 | Callback that logs the number of batches processed per second to wandb. 
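    Minimal usage sketch (assumes a standard Lightning Trainer; logging_interval is optional
    and defaults to 50):

        trainer = Trainer(callbacks=[BatchSpeedMonitorCallback(logging_interval=50)])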
9 | """ 10 | 11 | def __init__(self, logging_interval=50): 12 | """ 13 | Args: 14 | logging_interval: Log the speed every N batches 15 | """ 16 | super().__init__() 17 | self.logging_interval = logging_interval 18 | self.batch_start_time = None 19 | self.batch_times = [] 20 | self.last_logged_batch = 0 21 | 22 | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): 23 | """Record the start time of the batch.""" 24 | self.batch_start_time = time.time() 25 | 26 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 27 | """ 28 | Calculate and log the batch processing speed. 29 | """ 30 | if self.batch_start_time is None: 31 | return 32 | 33 | # Calculate time taken for this batch 34 | batch_end_time = time.time() 35 | batch_time = batch_end_time - self.batch_start_time 36 | self.batch_times.append(batch_time) 37 | 38 | # Log every logging_interval batches 39 | if batch_idx % self.logging_interval == 0 and batch_idx > 0: 40 | # Calculate batches per second over the last interval 41 | if len(self.batch_times) > 0: 42 | avg_batch_time = sum(self.batch_times) / len(self.batch_times) 43 | batches_per_second = 1.0 / avg_batch_time if avg_batch_time > 0 else 0 44 | 45 | # Log to wandb 46 | pl_module.log("batches_per_second", batches_per_second) 47 | 48 | # Also log min, max, and coefficient of variation to help diagnose variability 49 | if len(self.batch_times) > 1: 50 | min_time = min(self.batch_times) 51 | max_time = max(self.batch_times) 52 | std_dev = (sum((t - avg_batch_time) ** 2 for t in self.batch_times) / len(self.batch_times)) ** 0.5 53 | cv = (std_dev / avg_batch_time) * 100 if avg_batch_time > 0 else 0 54 | 55 | pl_module.log("batch_time_min", min_time) 56 | pl_module.log("batch_time_max", max_time) 57 | pl_module.log("batch_time_avg", avg_batch_time) 58 | pl_module.log("batch_time_cv_percent", cv) 59 | 60 | # Log max/min ratio to identify extreme outliers 61 | if min_time > 0: 62 | pl_module.log("batch_time_max_min_ratio", max_time / min_time) 63 | 64 | # Reset for next interval 65 | self.batch_times = [] 66 | self.last_logged_batch = batch_idx 67 | -------------------------------------------------------------------------------- /src/state/tx/models/decoders_nb.py: -------------------------------------------------------------------------------- 1 | # models/decoders_nb.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.distributions import NegativeBinomial 6 | 7 | 8 | class NBDecoder(nn.Module): 9 | """ 10 | scVI‑style decoder that maps a latent embedding (optionally with batch covariates) 11 | to the parameters of a negative‑binomial (or ZINB) distribution over raw counts. 12 | 13 | Y_ig ~ NB(μ_ig, θ_g) where 14 | μ_ig = l_i * softplus(W_g z_i + b_g) 15 | θ_g = softplus(r_g) (gene‑specific inverse dispersion) 16 | 17 | Optionally, a zero‑inflation gate π_ig can be produced (not shown here). 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | latent_dim: int, 23 | gene_dim: int, 24 | hidden_dims=[1024, 256, 256], 25 | dropout: float = 0.0, 26 | use_zero_inflation: bool = False, 27 | ): 28 | super().__init__() 29 | modules = [] 30 | in_features = latent_dim 31 | for h in hidden_dims: 32 | modules += [ 33 | nn.Linear(in_features, h), 34 | nn.LayerNorm(h), 35 | nn.GELU(), 36 | nn.Dropout(dropout), 37 | ] 38 | in_features = h 39 | self.encoder = nn.Sequential(*modules) 40 | 41 | self.skip = nn.Identity() if in_features == latent_dim else nn.Linear(latent_dim, in_features, bias=False) 42 | self.post_norm = nn.LayerNorm(in_features) 43 | 44 | # Mean parameter 45 | self.px_scale = nn.Linear(in_features, gene_dim) 46 | 47 | self.l_encoder = nn.Linear(in_features, 1) 48 | 49 | # Gene‑specific inverse dispersion (log‑space, broadcasted) 50 | self.log_theta = nn.Parameter(torch.randn(gene_dim)) 51 | 52 | # Optional zero‑inflation gate 53 | self.use_zero_inflation = use_zero_inflation 54 | if use_zero_inflation: 55 | self.px_dropout = nn.Linear(in_features, gene_dim) 56 | 57 | @property 58 | def theta(self): 59 | # softplus to keep positive 60 | return F.softplus(self.log_theta) 61 | 62 | def forward(self, z: torch.Tensor, log_library: torch.Tensor | None = None): 63 | """ 64 | z: [B, latent_dim] 65 | log_library: [B, 1] (optional – if None we predict it) 66 | returns μ, θ (and π if requested) 67 | """ 68 | flat = False 69 | if z.dim() == 3: # [B,S,D] → flatten 70 | B, S, D = z.shape 71 | z = z.reshape(-1, D) 72 | flat = True 73 | 74 | h = self.encoder(z) # [B* S, H] 75 | h = self.post_norm(h + self.skip(z)) 76 | 77 | if log_library is None: 78 | log_library = self.l_encoder(h) # [B* S, 1] 79 | px_scale = F.softplus(self.px_scale(h)) # [B* S, G] 80 | mu = torch.exp(log_library) * px_scale # NB mean 81 | 82 | if self.use_zero_inflation: 83 | pi = torch.sigmoid(self.px_dropout(h)) 84 | outs = (mu, self.theta, pi) 85 | else: 86 | outs = (mu, self.theta) 87 | 88 | if flat: # reshape back to [B,S,*] 89 | mu = mu.reshape(B, S, -1) 90 | if self.use_zero_inflation: 91 | pi = pi.reshape(B, S, -1) 92 | return mu, self.theta, pi # θ remains [G] 93 | else: 94 | return mu, self.theta 95 | return outs 96 | 97 | def gene_dim(self) -> int: 98 | return self.px_scale.out_features 99 | 100 | 101 | def nb_nll(x, mu, theta, eps: float = 1e-6): 102 | """ 103 | Negative‑binomial negative log‑likelihood. 104 | x, mu : [..., G] 105 | theta : [G] or [..., G] 106 | returns scalar 107 | """ 108 | logits = (mu + eps).log() - (theta + eps).log() # NB parameterisation 109 | dist = NegativeBinomial(total_count=theta, logits=logits) 110 | return -dist.log_prob(x).mean() 111 | -------------------------------------------------------------------------------- /src/state/tx/callbacks/cumulative_flops.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Optional 3 | 4 | import torch 5 | from lightning import LightningModule, Trainer 6 | from lightning.fabric.utilities.throughput import measure_flops 7 | from lightning.pytorch.callbacks import Callback 8 | 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.INFO) 11 | 12 | 13 | class CumulativeFLOPSCallback(Callback): 14 | """ 15 | PyTorch Lightning callback to track cumulative FLOPS during training. 16 | 17 | - Measures FLOPs once on the first training batch using `measure_flops`. 18 | - Tracks cumulative FLOPs and logs at validation frequency. 
19 | - Logs cumulative_flops to trainer loggers (e.g., W&B, CSV) at validation cadence. 20 | 21 | Args: 22 | use_backward: If True, include backward pass FLOPs in the measurement. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | *, 28 | use_backward: bool = False, 29 | ) -> None: 30 | super().__init__() 31 | self.use_backward = use_backward 32 | 33 | self._flops_per_batch: Optional[int] = None 34 | self._measured: bool = False 35 | self._cumulative_flops: int = 0 36 | self._batch_count: int = 0 37 | 38 | def _trainstep_forward_backward(self, model: LightningModule, batch: Any) -> torch.Tensor: 39 | """Encapsulate calling StateTransitionPerturbationModel.training_step and backward. 40 | 41 | This intentionally targets StateTransitionPerturbationModel's signature and 42 | performs both forward and backward to capture full FLOPs. 43 | 44 | !!WARNING!! 45 | This has only been tested with StateTransitionPerturbationModel. Behavior with any other model has not been verified. 46 | """ 47 | model.zero_grad(set_to_none=True) 48 | loss: torch.Tensor = model.training_step(batch, 0, padded=True) # type: ignore 49 | if self.use_backward: 50 | loss.backward() 51 | return loss 52 | 53 | def _measure_flops_once(self, trainer: Trainer, pl_module: Any, batch: Any) -> None: 54 | if self._measured: 55 | return 56 | 57 | model = pl_module 58 | 59 | def forward_fn(): 60 | return self._trainstep_forward_backward(model, batch) 61 | 62 | self._flops_per_batch = int(measure_flops(model, forward_fn=forward_fn)) 63 | logger.info(f"CumulativeFLOPSCallback: Measured FLOPs per batch: {self._flops_per_batch}") 64 | 65 | model.zero_grad(set_to_none=True) 66 | self._measured = True 67 | 68 | def on_train_batch_start(self, trainer: Trainer, pl_module: Any, batch: dict, batch_idx: int) -> None: 69 | if not self._measured and batch_idx == 0 and trainer.current_epoch == 0: 70 | self._measure_flops_once(trainer, pl_module, batch) 71 | 72 | def on_train_batch_end(self, trainer: Trainer, pl_module: Any, outputs: Any, batch: dict, batch_idx: int) -> None: 73 | if self._flops_per_batch is None: 74 | return 75 | 76 | self._batch_count += 1 77 | self._cumulative_flops += self._flops_per_batch 78 | 79 | # Log cumulative FLOPs after every training batch 80 | pl_module.log( 81 | "cumulative_flops", 82 | float(self._cumulative_flops), 83 | prog_bar=False, 84 | on_step=True, 85 | on_epoch=False, 86 | sync_dist=True, 87 | ) 88 | logger.info(f"CumulativeFLOPSCallback: Logged cumulative FLOPs: {self._cumulative_flops}") 89 | 90 | def on_validation_start(self, trainer: Trainer, pl_module: Any) -> None: 91 | if self._flops_per_batch is None: 92 | return 93 | 94 | # Log cumulative FLOPs at validation frequency for W&B panel alignment 95 | pl_module.log( 96 | "cumulative_flops_val_sync", 97 | float(self._cumulative_flops), 98 | prog_bar=False, 99 | on_step=False, 100 | on_epoch=True, 101 | sync_dist=True, 102 | ) 103 | -------------------------------------------------------------------------------- /src/state/emb/nn/flash_transformer.py: -------------------------------------------------------------------------------- 1 | # File: vci/flash_transformer.py 2 | """ 3 | This module implements a Transformer encoder layer. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class FlashTransformerEncoderLayer(nn.Module): 12 | def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1): 13 | """ 14 | Initializes the encoder layer. 15 | Args: 16 | d_model (int): model dimension. 
17 | nhead (int): number of attention heads. 18 | dim_feedforward (int): dimension of the feed-forward network. 19 | dropout (float): dropout probability. 20 | """ 21 | super().__init__() 22 | torch.backends.cuda.enable_flash_sdp(True) 23 | 24 | self.d_model = d_model 25 | self.nhead = nhead 26 | self.dropout = dropout 27 | 28 | # Linear projections for Q, K, V in one matrix 29 | self.qkv_proj = nn.Linear(d_model, d_model * 3) 30 | self.out_proj = nn.Linear(d_model, d_model) 31 | 32 | self.norm1 = nn.LayerNorm(d_model) 33 | self.norm2 = nn.LayerNorm(d_model) 34 | self.dropout_layer = nn.Dropout(dropout) 35 | 36 | # Feed-forward network 37 | self.linear1 = nn.Linear(d_model, dim_feedforward) 38 | self.linear2 = nn.Linear(dim_feedforward, d_model) 39 | 40 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 41 | """ 42 | Args: 43 | src: Tensor of shape (batch_size, seq_len, d_model) 44 | src_mask: (optional) attention mask. 45 | src_key_padding_mask: (optional) padding mask. 46 | Returns: 47 | Tensor of shape (batch_size, seq_len, d_model) 48 | """ 49 | # For this simple implementation, we'll use either one of the masks. 50 | # You can combine them as needed. 51 | mask = src_key_padding_mask if src_key_padding_mask is not None else src_mask 52 | 53 | # ----- Self-Attention Block ----- 54 | residual = src 55 | 56 | # Compute Q, K, V projections in one go. 57 | qkv = self.qkv_proj(src) # shape: (B, T, 3*d_model) 58 | q, k, v = torch.chunk(qkv, 3, dim=-1) # each: (B, T, d_model) 59 | 60 | # Reshape for multi-head attention. 61 | head_dim = self.d_model // self.nhead 62 | q = q.view(src.size(0), src.size(1), self.nhead, head_dim).transpose(1, 2) # (B, nhead, T, head_dim) 63 | k = k.view(src.size(0), src.size(1), self.nhead, head_dim).transpose(1, 2) 64 | v = v.view(src.size(0), src.size(1), self.nhead, head_dim).transpose(1, 2) 65 | 66 | # Use PyTorch’s built-in scaled_dot_product_attention. 67 | attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=False) 68 | # Merge heads. 69 | attn_output = attn_output.transpose(1, 2).contiguous().view(src.size(0), src.size(1), self.d_model) 70 | attn_output = self.out_proj(attn_output) 71 | src = self.norm1(residual + self.dropout_layer(attn_output)) 72 | 73 | # ----- Feed-Forward Block ----- 74 | residual2 = src 75 | ff_output = self.linear2(self.dropout_layer(F.gelu(self.linear1(src)))) 76 | src = self.norm2(residual2 + self.dropout_layer(ff_output)) 77 | return src 78 | 79 | 80 | class FlashTransformerEncoder(nn.Module): 81 | def __init__(self, layers): 82 | """ 83 | A simple encoder that applies a stack of FlashTransformerEncoderLayer instances. 84 | Args: 85 | layers (list[nn.Module]): list of FlashTransformerEncoderLayer instances. 86 | """ 87 | super().__init__() 88 | self.layers = nn.ModuleList(layers) 89 | 90 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 91 | """ 92 | Applies each encoder layer in sequence. 93 | Args: 94 | src: Tensor of shape (B, T, d_model) 95 | src_mask: (optional) attention mask. 96 | src_key_padding_mask: (optional) padding mask. 97 | Returns: 98 | Tensor of shape (B, T, d_model) 99 | """ 100 | # Use src_key_padding_mask if provided; otherwise use src_mask. 
101 | mask = src_key_padding_mask if src_key_padding_mask is not None else src_mask 102 | output = src 103 | for layer in self.layers: 104 | output = layer(output, src_mask=mask, src_key_padding_mask=mask) 105 | return output 106 | -------------------------------------------------------------------------------- /src/state/_cli/_emb/_query.py: -------------------------------------------------------------------------------- 1 | import argparse as ap 2 | 3 | 4 | def add_arguments_query(parser: ap.ArgumentParser): 5 | """Add arguments for state embedding query CLI.""" 6 | parser.add_argument("--lancedb", required=True, help="Path to existing LanceDB database") 7 | parser.add_argument("--input", required=True, help="Path to input anndata file with query cells") 8 | parser.add_argument("--output", required=True, help="Path to output file for results (csv, parquet)") 9 | parser.add_argument("--k", type=int, default=3, help="Number of nearest neighbors to return") 10 | parser.add_argument("--embed-key", default="X_state", help="Key containing embeddings in input file") 11 | parser.add_argument("--exclude-distances", action="store_true", help="Exclude vector distances in results") 12 | parser.add_argument("--filter", type=str, help="Filter expression (e.g., 'cell_type==\"B cell\"')") 13 | parser.add_argument("--batch-size", type=int, default=100, help="Batch size for query operations") 14 | 15 | 16 | def run_emb_query(args: ap.ArgumentParser): 17 | import logging 18 | import pandas as pd 19 | import anndata 20 | from pathlib import Path 21 | 22 | """ 23 | Query a LanceDB database for similar cells. 24 | """ 25 | logging.basicConfig(level=logging.INFO) 26 | logger = logging.getLogger(__name__) 27 | 28 | from ...emb.vectordb import StateVectorDB 29 | 30 | # check output file extension 31 | if not args.output.endswith((".csv", ".parquet")): 32 | raise ValueError("Output file must have a .csv or .parquet extension") 33 | 34 | # Load query cells 35 | logger.info(f"Loading query cells from {args.input}") 36 | query_adata = anndata.read_h5ad(args.input) 37 | 38 | # Get embeddings 39 | if args.embed_key in query_adata.obsm: 40 | query_embeddings = query_adata.obsm[args.embed_key] 41 | else: 42 | raise ValueError(f"Embedding key '{args.embed_key}' not found in input file") 43 | 44 | logger.info(f"Found {len(query_embeddings)} query cells") 45 | 46 | # Connect to database 47 | vector_db = StateVectorDB(args.lancedb) 48 | 49 | # Get database info 50 | db_info = vector_db.get_table_info() 51 | if db_info: 52 | logger.info(f"Database contains {db_info['num_rows']} cells with {db_info['embedding_dim']}-dim embeddings") 53 | 54 | # Perform batch search 55 | logger.info(f"Searching for {args.k} nearest neighbors per query cell...") 56 | results_list = vector_db.batch_search( 57 | query_vectors=query_embeddings, 58 | k=args.k, 59 | filter=args.filter, 60 | include_distance=not args.exclude_distances, 61 | batch_size=args.batch_size, 62 | show_progress=True, 63 | ) 64 | 65 | # Add query cell IDs and ranks to results 66 | all_results = [] 67 | for query_idx, result_df in enumerate(results_list): 68 | result_df["query_cell_id"] = query_adata.obs.index[query_idx] 69 | result_df["query_rank"] = range(1, len(result_df) + 1) 70 | all_results.append(result_df) 71 | 72 | # Combine results 73 | final_results = pd.concat(all_results, ignore_index=True) 74 | 75 | # Save results 76 | output_path = Path(args.output) 77 | output_path.parent.mkdir(parents=True, exist_ok=True) 78 | 79 | if args.output.endswith(".csv"): 80 | 
final_results.to_csv(args.output, index=False)
 81 |         logger.info(f"Saved results to {args.output}")
 82 |     elif args.output.endswith(".parquet"):
 83 |         final_results.to_parquet(args.output, index=False)
 84 |         logger.info(f"Saved results to {args.output}")
 85 |     else:
 86 |         raise ValueError(f"Unsupported output format: {args.output}")
 87 | 
 88 | 
 89 | def create_result_anndata(query_adata, results_df, k):
 90 |     """Create an anndata object containing query results."""
    |     import numpy as np  # needed for the np.array conversions below
 91 |     # Pivot cell IDs
 92 |     cell_ids_pivot = results_df.pivot(index="query_cell_id", columns="query_rank", values="cell_id")
 93 |     cell_ids_array = np.array(cell_ids_pivot.values, dtype=str)
 94 | 
 95 |     # Handle distances - convert to float64 and handle missing values
 96 |     if "vector_distance" in results_df:
 97 |         distances_pivot = results_df.pivot(index="query_cell_id", columns="query_rank", values="vector_distance")
 98 |         distances_array = np.array(distances_pivot.values, dtype=np.float64)
 99 |     else:
100 |         distances_array = None
101 | 
102 |     # Store match information in uns
103 |     uns_data = {"query_matches": {"cell_ids": cell_ids_array, "distances": distances_array, "k": int(k)}}
104 | 
105 |     # Create result anndata
106 |     result_adata = query_adata.copy()
107 |     result_adata.uns["lancedb_query_results"] = uns_data
108 | 
109 |     return result_adata
110 | 
--------------------------------------------------------------------------------
/src/state/emb/nn/loss.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | import torch.nn.functional as F
  4 | from geomloss import SamplesLoss
  5 | 
  6 | 
  7 | class WassersteinLoss(nn.Module):
  8 |     """
  9 |     Implements a Wasserstein distance loss for distributions represented by logits.
 10 |     The distance is of order p (W1 or W2) and is computed along the last dimension.
 11 |     """
 12 | 
 13 |     def __init__(self, p=1, reduction="mean"):
 14 |         """
 15 |         Args:
 16 |             p (int): Order of Wasserstein distance (1 or 2)
 17 |             reduction (str): 'mean', 'sum', or 'none'
 18 |         """
 19 |         super().__init__()
 20 |         self.p = p
 21 |         self.reduction = reduction
 22 | 
 23 |     def forward(self, p, q):
 24 |         """
 25 |         Compute Wasserstein distance between predicted and target distributions.
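        Note: both inputs are softmax-normalized along the last dimension and the distance is
        taken between the discrete CDFs, i.e. W_p^p = sum_i |F_p(i) - F_q(i)|^p (see code below).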
26 | 27 | Args: 28 | logits (torch.Tensor): Predicted logits of shape (batch_size, num_classes) 29 | target (torch.Tensor): Target probabilities of shape (batch_size, num_classes) 30 | or class indices of shape (batch_size,) 31 | 32 | Returns: 33 | torch.Tensor: Computed Wasserstein distance 34 | """ 35 | 36 | q = torch.nan_to_num(q, nan=0.0) 37 | # Convert logits to probabilities 38 | pred_probs = F.softmax(p, dim=-1) 39 | q = F.softmax(q, dim=-1) 40 | 41 | # Compute cumulative distribution functions (CDFs) 42 | pred_cdf = torch.cumsum(pred_probs, dim=-1) 43 | target_cdf = torch.cumsum(q, dim=-1) 44 | 45 | max_len = max(pred_cdf.size(1), target_cdf.size(1)) 46 | if pred_cdf.size(1) < max_len: 47 | pred_cdf = F.pad(pred_cdf, (0, max_len - pred_cdf.size(1)), "constant", 0) 48 | if target_cdf.size(1) < max_len: 49 | target_cdf = F.pad(target_cdf, (0, max_len - target_cdf.size(1)), "constant", 0) 50 | 51 | # Compute Wasserstein distance 52 | wasserstein_dist = torch.abs(pred_cdf - target_cdf).pow(self.p) 53 | wasserstein_dist = wasserstein_dist.sum(dim=-1) 54 | 55 | # Apply reduction if specified 56 | if self.reduction == "mean": 57 | return wasserstein_dist.mean() 58 | elif self.reduction == "sum": 59 | return wasserstein_dist.sum() 60 | return wasserstein_dist 61 | 62 | 63 | class KLDivergenceLoss(nn.Module): 64 | def __init__(self, apply_normalization=False, epsilon=1e-10): 65 | super().__init__() 66 | self.apply_normalization = apply_normalization 67 | self.epsilon = epsilon 68 | 69 | def forward(self, p, q): 70 | q = torch.nan_to_num(q, nan=0.0) 71 | p = torch.nan_to_num(p, nan=0.0) 72 | 73 | max_len = max(p.size(1), q.size(1)) 74 | if p.size(1) < max_len: 75 | p = F.pad(p, (0, max_len - p.size(1)), "constant", 0) 76 | if q.size(1) < max_len: 77 | q = F.pad(q, (0, max_len - q.size(1)), "constant", 0) 78 | 79 | if self.apply_normalization: 80 | p = F.softmax(p, dim=-1) 81 | q = F.softmax(q, dim=-1) 82 | 83 | return torch.sum(p * torch.log(p / q)) 84 | 85 | 86 | class MMDLoss(nn.Module): 87 | def __init__(self, kernel="energy", blur=0.05, scaling=0.5, downsample=1): 88 | super().__init__() 89 | self.mmd_loss = SamplesLoss(loss=kernel, blur=blur, scaling=scaling) 90 | self.downsample = downsample 91 | 92 | def forward(self, input, target): 93 | input = input.reshape(-1, self.downsample, input.shape[-1]) 94 | target = target.reshape(-1, self.downsample, target.shape[-1]) 95 | 96 | loss = self.mmd_loss(input, target) 97 | return loss.mean() 98 | 99 | 100 | class TabularLoss(nn.Module): 101 | def __init__(self, shared=128, downsample=1): 102 | super().__init__() 103 | self.shared = shared 104 | self.downsample = downsample 105 | 106 | self.gene_loss = SamplesLoss(loss="energy") 107 | self.cell_loss = SamplesLoss(loss="energy") 108 | 109 | def forward(self, input, target): 110 | input = input.reshape(-1, self.downsample, input.shape[-1]) 111 | target = target.reshape(-1, self.downsample, target.shape[-1]) 112 | gene_mmd = self.gene_loss(input, target).nanmean() 113 | 114 | # cell_mmd should only be on the shared genes, and match scale to mse loss 115 | cell_inputs = input[:, :, -self.shared :] 116 | cell_targets = target[:, :, -self.shared :] 117 | 118 | # need to reshape each from (B, self.downsample, F) to (F, self.downsample, B) 119 | cell_inputs = cell_inputs.transpose(2, 0) 120 | cell_targets = cell_targets.transpose(2, 0) 121 | cell_mmd = self.cell_loss(cell_inputs, cell_targets).nanmean() 122 | 123 | final_loss = torch.tensor(0.0).to(cell_mmd.device) 124 | if not gene_mmd.isnan(): 125 | 
final_loss += gene_mmd 126 | if not cell_mmd.isnan(): 127 | final_loss += cell_mmd 128 | 129 | return final_loss 130 | -------------------------------------------------------------------------------- /MODEL_ACCEPTABLE_USE_POLICY.md: -------------------------------------------------------------------------------- 1 | 2 | # Arc Research Institute State Model Acceptable Use Policy 3 | 4 | **_Last updated June 23, 2025_** 5 | 6 | Arc Research Institute (the “**Institute**,” “**we**” or “**us**”) makes its State Model available and free to use for non-commercial purposes, subject to the terms of the [Arc Research Institute State Model Non-Commercial License](MODEL_LICENSE.md) (the “**License**”) and this Acceptable Use Policy (“**Policy**”). The purpose of this Policy is to ensure the State Model is used safely, ethically, and in accordance with all applicable laws and regulations. Any defined terms used but not defined in this Policy have the meaning given in the License. 7 | 8 | ## Purpose and Permitted Use 9 | 10 | The intent of the State Model is to support and enable research that advances knowledge and serves the public interest. This model is made available for use only by non-commercial entities such as government institutions, non-profit organizations, research institutes, and educational institutions. 11 | 12 | Although the State Model may only be used for non-commercial purposes, you are free to use any outputs you create for any purpose. 13 | 14 | ## Prohibited Uses 15 | 16 | Unless otherwise expressly permitted pursuant to the License, you may not use the State Model, or any outputs or derivatives thereof, for any of the following purposes: 17 | 18 | **1) For-Profit Use:** 19 | - Use by or for any for-profit entity, including but not limited to corporations, partnerships, or other commercial organizations. 20 | - Use for any for-profit purpose, including but not limited to commercial research, sponsored research, product or service development, marketing, advertising, or any activity intended to generate revenue or commercial advantage. 21 | - For clarity, Outputs (as defined in the License) are not subject to this Section 1. 22 | 23 | **2) Unlawful, Harmful, or Unethical Activities:** 24 | - Any activity that violates applicable local, national, or international laws, rules, or regulations. 25 | - Any activity that causes or is intended to cause harm, including but not limited to the generation of offensive, abusive, or unlawful content. 26 | - Any use that infringes upon the rights of others, including invasion of privacy or jeopardizing the safety of others. 27 | 28 | **3) Misinforming or Misleading:** 29 | - Generation, dissemination, or promotion of false, incomplete, or otherwise misleading information. 30 | - Any use intended to deceive, misinform, or otherwise mislead individuals or organizations, including but not limited fabrication of data, fabrication or manipulation of outputs, impersonation, or to otherwise misrepresent facts or scientific findings. 31 | 32 | **4) Security and Privacy Violations:** 33 | - Introducing malware, viruses, or other malicious code. 34 | - Circumventing or attempting to circumvent any security or access controls. 35 | - Collecting, storing, or sharing sensitive, health, or personal data without proper authorization or consent. 
36 | 37 | **5) Claiming Endorsement:** 38 | - Making any statements or claims or engaging in any other behavior or communication that indicates or suggests that the Institute endorses your use of the State Model, including any Derivative Works or Outputs (each as defined in the License). 39 | 40 | ## Attribution 41 | 42 | If You redistribute the State Model, or any outputs or derivatives thereof, you must also include a prominent and readable citation to the State Model research paper: Adduri, A. et al. (2025) Predicting cellular responses to perturbation across diverse contexts with State. 43 | 44 | ## Disclaimer 45 | 46 | The State Model is provided AS IS and is intended for informational, theoretical, and research purposes only and is not intended for use in the diagnosis of disease or other conditions, or in the cure, mitigation, treatment, or prevention of disease. It should not be used or substituted for professional, medical advice. 47 | 48 | ## Third Party Integrations 49 | 50 | The Institute is not responsible for the content, security, or privacy practices of any third-party technology, data, materials, or services that you may use in connection with the State Model. Use of such integrations is at your own risk and subject to the terms and policies of the respective third parties. The Institute disclaims any liability for damages or losses resulting from third-party integrations. 51 | 52 | ## Legal and Regulatory Compliance 53 | 54 | The State Model may not be appropriate or available for use in some jurisdictions. Any use of the State Model is at your own risk, and you must comply with applicable laws, rules, and regulations in doing so. This includes, but is not limited to, data protection, privacy, and export control laws. 55 | 56 | ## Enforcement and Disclaimer 57 | 58 | Violation of this Policy may result in suspension or termination of the License and access to the State Model, and may subject you to legal liability. We reserve the right to investigate suspected violations and to cooperate with law enforcement authorities. The Institute disclaims liability for any misuse of or unauthorized training of the State Model. 59 | 60 | ## Contact 61 | 62 | For questions about this Policy or to report suspected violations, please contact the State Model administrator at . 
63 | -------------------------------------------------------------------------------- /src/state/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse as ap 2 | 3 | from hydra import compose, initialize 4 | from omegaconf import DictConfig 5 | 6 | from ._cli import ( 7 | add_arguments_emb, 8 | add_arguments_tx, 9 | run_emb_fit, 10 | run_emb_transform, 11 | run_emb_query, 12 | run_emb_preprocess, 13 | run_emb_eval, 14 | run_tx_infer, 15 | run_tx_predict, 16 | run_tx_preprocess_infer, 17 | run_tx_preprocess_train, 18 | run_tx_train, 19 | ) 20 | 21 | 22 | def get_args() -> tuple[ap.Namespace, list[str]]: 23 | """Parse known args and return remaining args for Hydra overrides""" 24 | parser = ap.ArgumentParser() 25 | subparsers = parser.add_subparsers(required=True, dest="command") 26 | add_arguments_emb(subparsers.add_parser("emb")) 27 | add_arguments_tx(subparsers.add_parser("tx")) 28 | 29 | # Use parse_known_args to get both known args and remaining args 30 | return parser.parse_args() 31 | 32 | 33 | def load_hydra_config(method: str, overrides: list[str] = None) -> DictConfig: 34 | """Load Hydra config with optional overrides""" 35 | if overrides is None: 36 | overrides = [] 37 | 38 | # Initialize Hydra with the path to your configs directory 39 | # Adjust the path based on where this file is relative to configs/ 40 | with initialize(version_base=None, config_path="configs"): 41 | match method: 42 | case "emb": 43 | cfg = compose(config_name="state-defaults", overrides=overrides) 44 | case "tx": 45 | cfg = compose(config_name="config", overrides=overrides) 46 | case _: 47 | raise ValueError(f"Unknown method: {method}") 48 | return cfg 49 | 50 | 51 | def show_hydra_help(method: str): 52 | """Show Hydra configuration help with all parameters""" 53 | from omegaconf import OmegaConf 54 | 55 | # Load the default config to show structure 56 | cfg = load_hydra_config(method) 57 | 58 | print("Hydra Configuration Help") 59 | print("=" * 50) 60 | print(f"Configuration for method: {method}") 61 | print() 62 | print("Full configuration structure:") 63 | print(OmegaConf.to_yaml(cfg)) 64 | print() 65 | print("Usage examples:") 66 | print(" Override single parameter:") 67 | print(" uv run state tx train data.batch_size=64") 68 | print() 69 | print(" Override nested parameter:") 70 | print(" uv run state tx train model.kwargs.hidden_dim=512") 71 | print() 72 | print(" Override multiple parameters:") 73 | print(" uv run state tx train data.batch_size=64 training.lr=0.001") 74 | print() 75 | print(" Change config group:") 76 | print(" uv run state tx train data=custom_data model=custom_model") 77 | print() 78 | print("Available config groups:") 79 | 80 | # Show available config groups 81 | from pathlib import Path 82 | 83 | config_dir = Path(__file__).parent / "configs" 84 | if config_dir.exists(): 85 | for item in config_dir.iterdir(): 86 | if item.is_dir() and not item.name.startswith("."): 87 | configs = [f.stem for f in item.glob("*.yaml")] 88 | if configs: 89 | print(f" {item.name}: {', '.join(configs)}") 90 | 91 | exit(0) 92 | 93 | 94 | def main(): 95 | args = get_args() 96 | 97 | match args.command: 98 | case "emb": 99 | match args.subcommand: 100 | case "fit": 101 | cfg = load_hydra_config("emb", args.hydra_overrides) 102 | run_emb_fit(cfg, args) 103 | case "transform": 104 | run_emb_transform(args) 105 | case "query": 106 | run_emb_query(args) 107 | case "preprocess": 108 | run_emb_preprocess(args) 109 | case "eval": 110 | 
run_emb_eval(args) 111 | case "tx": 112 | match args.subcommand: 113 | case "train": 114 | if hasattr(args, "help") and args.help: 115 | # Show Hydra configuration help 116 | show_hydra_help("tx") 117 | else: 118 | # Load Hydra config with overrides for sets training 119 | cfg = load_hydra_config("tx", args.hydra_overrides) 120 | run_tx_train(cfg) 121 | case "predict": 122 | # For now, predict uses argparse and not hydra 123 | run_tx_predict(args) 124 | case "infer": 125 | # Run inference using argparse, similar to predict 126 | run_tx_infer(args) 127 | case "preprocess_train": 128 | # Run preprocessing using argparse 129 | run_tx_preprocess_train(args.adata, args.output, args.num_hvgs) 130 | case "preprocess_infer": 131 | # Run inference preprocessing using argparse 132 | run_tx_preprocess_infer(args.adata, args.output, args.control_condition, args.pert_col, args.seed) 133 | 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /src/state/tx/models/decoders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from omegaconf import OmegaConf 6 | 7 | from ...emb.finetune_decoder import Finetune 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class FinetuneVCICountsDecoder(nn.Module): 13 | def __init__( 14 | self, 15 | genes, 16 | # model_loc="/large_storage/ctc/userspace/aadduri/vci/checkpoint/rda_tabular_counts_2048_new/step=950000.ckpt", 17 | # config="/large_storage/ctc/userspace/aadduri/vci/checkpoint/rda_tabular_counts_2048_new/tahoe_config.yaml", 18 | model_loc="/home/aadduri/vci_pretrain/vci_1.4.2.ckpt", 19 | config="/large_storage/ctc/userspace/aadduri/vci/checkpoint/large_1e-4_rda_tabular_counts_2048/crossds_config.yaml", 20 | read_depth=1200, 21 | latent_dim=1024, # dimension of pretrained vci model 22 | hidden_dims=[512, 512, 512], # hidden dimensions of the decoder 23 | dropout=0.1, 24 | basal_residual=False, 25 | ): 26 | super().__init__() 27 | self.genes = genes 28 | self.model_loc = model_loc 29 | self.config = config 30 | self.finetune = Finetune(OmegaConf.load(self.config)) 31 | self.finetune.load_model(self.model_loc) 32 | self.read_depth = nn.Parameter(torch.tensor(read_depth, dtype=torch.float), requires_grad=False) 33 | self.basal_residual = basal_residual 34 | 35 | # layers = [ 36 | # nn.Linear(latent_dim, hidden_dims[0]), 37 | # ] 38 | 39 | # self.gene_lora = nn.Sequential(*layers) 40 | 41 | self.latent_decoder = nn.Sequential( 42 | nn.Linear(latent_dim, hidden_dims[0]), 43 | nn.LayerNorm(hidden_dims[0]), 44 | nn.GELU(), 45 | nn.Dropout(dropout), 46 | nn.Linear(hidden_dims[0], hidden_dims[1]), 47 | nn.LayerNorm(hidden_dims[1]), 48 | nn.GELU(), 49 | nn.Dropout(dropout), 50 | nn.Linear(hidden_dims[1], len(self.genes)), 51 | nn.ReLU(), 52 | ) 53 | 54 | self.gene_decoder_proj = nn.Sequential( 55 | nn.Linear(len(self.genes), 128), 56 | nn.Linear(128, len(self.genes)), 57 | ) 58 | 59 | self.binary_decoder = self.finetune.model.binary_decoder 60 | for param in self.binary_decoder.parameters(): 61 | param.requires_grad = False 62 | 63 | def gene_dim(self): 64 | return len(self.genes) 65 | 66 | def forward(self, x: torch.Tensor) -> torch.Tensor: 67 | # x is [B, S, latent_dim]. 
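        # Flatten to [B*S, latent_dim] for decoding; the result is reshaped back to [B, S, n_genes] below.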
68 | if len(x.shape) != 3: 69 | x = x.unsqueeze(0) 70 | batch_size, seq_len, latent_dim = x.shape 71 | x = x.view(batch_size * seq_len, latent_dim) 72 | 73 | # Get gene embeddings 74 | gene_embeds = self.finetune.get_gene_embedding(self.genes) 75 | 76 | # Handle RDA task counts 77 | use_rda = getattr(self.finetune.model.cfg.model, "rda", False) 78 | # Define your sub-batch size (tweak this based on your available memory) 79 | sub_batch_size = 16 80 | logprob_chunks = [] # to store outputs of each sub-batch 81 | 82 | for i in range(0, x.shape[0], sub_batch_size): 83 | # Get the sub-batch of latent vectors 84 | x_sub = x[i : i + sub_batch_size] 85 | 86 | # Create task_counts for the sub-batch if needed 87 | if use_rda: 88 | # task_counts_sub = torch.full( 89 | # (x_sub.shape[0],), self.read_depth, device=x.device 90 | # ) 91 | task_counts_sub = torch.ones((x_sub.shape[0],), device=x.device) * self.read_depth 92 | else: 93 | task_counts_sub = None 94 | 95 | # Compute merged embeddings for the sub-batch 96 | merged_embs_sub = self.finetune.model.resize_batch(x_sub, gene_embeds, task_counts_sub) 97 | 98 | # Run the binary decoder on the sub-batch 99 | logprobs_sub = self.binary_decoder(merged_embs_sub) 100 | 101 | # Squeeze the singleton dimension if needed 102 | if logprobs_sub.dim() == 3 and logprobs_sub.size(-1) == 1: 103 | logprobs_sub = logprobs_sub.squeeze(-1) 104 | 105 | # Collect the results 106 | logprob_chunks.append(logprobs_sub) 107 | 108 | # Concatenate the sub-batches back together 109 | logprobs = torch.cat(logprob_chunks, dim=0) 110 | 111 | # Reshape back to [B, S, gene_dim] 112 | decoded_gene = logprobs.view(batch_size, seq_len, len(self.genes)) 113 | decoded_gene = decoded_gene + self.gene_decoder_proj(decoded_gene) 114 | # decoded_gene = torch.nn.functional.relu(decoded_gene) 115 | 116 | # # normalize the sum of decoded_gene to be read depth 117 | # decoded_gene = decoded_gene / decoded_gene.sum(dim=2, keepdim=True) * self.read_depth 118 | 119 | # decoded_gene = self.gene_lora(decoded_gene) 120 | # TODO: fix this to work with basal counts 121 | 122 | # add logic for basal_residual: 123 | decoded_x = self.latent_decoder(x) 124 | decoded_x = decoded_x.view(batch_size, seq_len, len(self.genes)) 125 | 126 | # Pass through the additional decoder layers 127 | return decoded_gene + decoded_x 128 | -------------------------------------------------------------------------------- /src/state/tx/models/embed_sum.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | 5 | from .base import PerturbationModel 6 | from .utils import build_mlp, get_activation_class 7 | 8 | 9 | class EmbedSumPerturbationModel(PerturbationModel): 10 | """ 11 | Implementation of the EmbedSum model which treats perturbations as learned embeddings 12 | that are added to control cell representations, which are input as gene expression counts 13 | or as embeddings from a foundation model (UCE, scGPT, etc). The outputs are always in 14 | gene expression space. 15 | 16 | This model: 17 | 1. Learns a co-embedding space for perturbations and cell states 18 | 2. Computes perturbation effects in this space 19 | 3. 
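
    In symbols, the forward pass below computes:

        output = project_out( basal_encoder(ctrl_cell_emb) + pert_encoder(pert_emb) )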
Decoder maps perturbed embeddings to gene expression space 20 | 21 | Args: 22 | input_dim: Dimension of input embeddings (either number of genes or latent dim from obsm key) 23 | hidden_dim: Dimension of hidden layers 24 | output_dim: Number of genes to predict 25 | pert_dim: Dimension of perturbation inputs (usually one-hot size) 26 | decode_intermediate_dim: Optional intermediate dimension for decoder 27 | n_encoder_layers: Number of layers in encoder (default: 2) 28 | n_decoder_layers: Number of layers in encoder (default: 2) 29 | dropout: Dropout rate (default: 0.1) 30 | learning_rate: Learning rate for optimizer (default: 1e-3) 31 | loss_fn: Loss function (default: 'nn.MSELoss()') 32 | """ 33 | 34 | def __init__( 35 | self, 36 | input_dim: int, 37 | hidden_dim: int, 38 | output_dim: int, 39 | pert_dim: int, 40 | output_space: str = "gene", 41 | **kwargs, 42 | ): 43 | # Register with parent constructor 44 | super().__init__( 45 | input_dim=input_dim, 46 | hidden_dim=hidden_dim, 47 | output_dim=output_dim, 48 | pert_dim=pert_dim, 49 | output_space=output_space, 50 | **kwargs, 51 | ) 52 | 53 | # Set class specific parameters before registering with parent constructor 54 | self.n_encoder_layers = kwargs.get("n_encoder_layers", 2) 55 | self.n_decoder_layers = kwargs.get("n_decoder_layers", 2) 56 | self.dropout = kwargs.get("dropout", 0.1) 57 | self.activation_class = get_activation_class(kwargs.get("activation", "gelu")) 58 | self.kwargs = kwargs 59 | 60 | # Build model components 61 | self._build_networks() 62 | 63 | def _build_networks(self): 64 | """ 65 | Build the core components: 66 | 1. Perturbation encoder: maps one-hot to learned embedding 67 | 2. Decoder: maps perturbed embedding to gene space 68 | """ 69 | # Map perturbation to effect in embedding space 70 | self.pert_encoder = build_mlp( 71 | in_dim=self.pert_dim, 72 | out_dim=self.hidden_dim, 73 | hidden_dim=self.hidden_dim, 74 | n_layers=self.n_encoder_layers, 75 | dropout=self.dropout, 76 | activation=self.activation_class, 77 | ) 78 | 79 | # Map the input embedding to the hidden space 80 | self.basal_encoder = build_mlp( 81 | in_dim=self.input_dim, 82 | out_dim=self.hidden_dim, 83 | hidden_dim=self.hidden_dim, 84 | n_layers=self.n_encoder_layers, 85 | dropout=self.dropout, 86 | activation=self.activation_class, 87 | ) 88 | 89 | self.project_out = build_mlp( 90 | in_dim=self.hidden_dim, 91 | out_dim=self.output_dim, 92 | hidden_dim=self.hidden_dim, 93 | n_layers=self.n_decoder_layers, 94 | dropout=self.dropout, 95 | activation=self.activation_class, 96 | ) 97 | 98 | def encode_perturbation(self, pert: torch.Tensor) -> torch.Tensor: 99 | """Map perturbation to an effect vector in embedding space.""" 100 | return self.pert_encoder(pert) 101 | 102 | def encode_basal_expression(self, expr: torch.Tensor) -> torch.Tensor: 103 | """Expression is already in embedding space, pass through.""" 104 | return self.basal_encoder(expr) 105 | 106 | def perturb(self, pert: torch.Tensor, basal: torch.Tensor) -> torch.Tensor: 107 | """ 108 | Given a perturbation and basal embeddings, compute the perturbed embedding. 
109 | """ 110 | # Project perturbation and basal cell state to latent space 111 | perturbation = self.encode_perturbation(pert) 112 | basal_encoded = self.basal_encoder(basal) 113 | 114 | # Add perturbation to basal embedding 115 | perturbed_encoded = basal_encoded + perturbation 116 | return perturbed_encoded 117 | 118 | def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: 119 | """ 120 | Given 121 | 122 | Args: 123 | batch: Dictionary containing: 124 | - pert: Perturbation one-hot 125 | - basal: Control expression embedding 126 | """ 127 | pert = batch["pert_emb"] 128 | basal = batch["ctrl_cell_emb"] 129 | 130 | # compute perturbed cell state to perturbation/cell co-embedding space 131 | perturbed_encoded = self.perturb(pert, basal) 132 | 133 | # Decode to gene space or to input cell embedding space 134 | return self.project_out(perturbed_encoded) 135 | -------------------------------------------------------------------------------- /src/state/tx/models/decoder_only.py: -------------------------------------------------------------------------------- 1 | # File: models/decoder_only.py 2 | 3 | import torch 4 | from geomloss import SamplesLoss 5 | 6 | from .base import PerturbationModel 7 | from .utils import get_activation_class 8 | 9 | 10 | class DecoderOnlyPerturbationModel(PerturbationModel): 11 | """ 12 | DecoderOnlyPerturbationModel learns to map the ground truth latent embedding 13 | (provided in batch["pert_cell_emb"]) to the ground truth HVG space (batch["pert_cell_counts"]). 14 | 15 | Unlike the other perturbation models that compute a control mapping (e.g. via a mapping strategy), 16 | this model simply feeds the latent representation through a decoder network. The loss is computed 17 | between the decoder output and the target HVG expression. 18 | 19 | It keeps the overall architectural style (and uses the SamplesLoss loss function from geomloss) 20 | as in the OldNeuralOT model. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | input_dim: int, 26 | hidden_dim: int, 27 | output_dim: int, 28 | pert_dim: int, 29 | n_decoder_layers: int = 2, 30 | dropout: float = 0.0, 31 | distributional_loss: str = "energy", 32 | output_space: str = "gene", 33 | gene_dim=None, 34 | **kwargs, 35 | ): 36 | super().__init__( 37 | input_dim=input_dim, 38 | hidden_dim=hidden_dim, 39 | gene_dim=gene_dim, 40 | output_dim=output_dim, 41 | pert_dim=pert_dim, 42 | output_space=output_space, 43 | **kwargs, 44 | ) 45 | self.n_decoder_layers = n_decoder_layers 46 | self.dropout = dropout 47 | self.distributional_loss = distributional_loss 48 | self.cell_sentence_len = kwargs["transformer_backbone_kwargs"]["n_positions"] 49 | self.activation_class = get_activation_class(kwargs.get("activation", "gelu")) 50 | self.gene_dim = gene_dim 51 | 52 | # Use the same loss function as OldNeuralOT (e.g. using the MMD loss via geomloss) 53 | self.loss_fn = SamplesLoss(loss=self.distributional_loss) 54 | 55 | def _build_networks(self): 56 | pass 57 | 58 | def forward(self, batch: dict) -> torch.Tensor: 59 | """ 60 | Forward pass: use the ground truth latent embedding (batch["pert_cell_emb"]) as the prediction. 61 | """ 62 | latent = batch["pert_cell_emb"] 63 | return latent 64 | 65 | def training_step(self, batch, batch_idx): 66 | """ 67 | Training step: The decoder output is compared against the target HVG expression. 68 | We assume that when output_space=="gene", the target is in batch["pert_cell_counts"]. 
69 | The predictions and targets are reshaped (using a cell sentence length, if provided) 70 | before computing the loss. 71 | """ 72 | pred = self(batch) 73 | # log a zero tensor 74 | self.log("train_loss", 0.0) 75 | 76 | if self.gene_decoder is not None and "pert_cell_counts" in batch: 77 | pert_cell_counts_preds = self.gene_decoder(pred) 78 | pert_cell_counts_preds = pert_cell_counts_preds.reshape(-1, self.cell_sentence_len, self.gene_dim) 79 | gene_targets = batch["pert_cell_counts"] 80 | gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) 81 | decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() 82 | self.log("decoder_loss", decoder_loss) 83 | else: 84 | self.log("decoder_loss", 0.0) 85 | decoder_loss = None 86 | return decoder_loss 87 | 88 | def validation_step(self, batch, batch_idx): 89 | pred = self(batch) 90 | self.log("val_loss", 0.0) 91 | 92 | return {"loss": None, "predictions": pred} 93 | 94 | def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): 95 | preds = outputs["predictions"] 96 | 97 | if self.gene_decoder is not None and "pert_cell_counts" in batch: 98 | pert_cell_counts_preds = self.gene_decoder(preds) 99 | gene_targets = batch["pert_cell_counts"] 100 | pert_cell_counts_preds = pert_cell_counts_preds.reshape(-1, self.cell_sentence_len, self.gene_dim) 101 | gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) 102 | decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() 103 | self.log("decoder_val_loss", decoder_loss) 104 | 105 | def test_step(self, batch, batch_idx): 106 | pred = self(batch) 107 | 108 | if self.gene_decoder is not None and "pert_cell_counts" in batch: 109 | pert_cell_counts_preds = self.gene_decoder(pred) 110 | gene_targets = batch["pert_cell_counts"] 111 | gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) 112 | decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() 113 | self.log("decoder_test_loss", decoder_loss) 114 | return {"loss": None, "predictions": pred} 115 | 116 | def predict_step(self, batch, batch_idx, padded=True, **kwargs): 117 | """ 118 | Typically used for final inference. We'll replicate old logic: 119 | returning 'preds', 'X', 'pert_name', etc. 120 | """ 121 | latent_output = self.forward(batch) # shape [B, ...] 
122 |         output_dict = {
123 |             "preds": latent_output,
124 |             "pert_cell_emb": batch.get("pert_cell_emb", None),
125 |             "pert_cell_counts": batch.get("pert_cell_counts", None),
126 |             "pert_name": batch.get("pert_name", None),
127 |             "celltype_name": batch.get("cell_type", None),
128 |             "batch": batch.get("batch", None),
129 |             "ctrl_cell_emb": batch.get("ctrl_cell_emb", None),
130 |         }
131 | 
132 |         pert_cell_counts_preds = self.gene_decoder(latent_output)
133 |         output_dict["pert_cell_counts_preds"] = pert_cell_counts_preds
134 | 
135 |         return output_dict
136 | 
--------------------------------------------------------------------------------
/src/state/emb/finetune_decoder.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from torch import nn
4 | 
5 | from vci.nn.model import StateEmbeddingModel
6 | from vci.train.trainer import get_embeddings
7 | from vci.utils import get_embedding_cfg
8 | 
9 | log = logging.getLogger(__name__)
10 | 
11 | 
12 | class Finetune:
13 |     def __init__(self, cfg, learning_rate=1e-4):
14 |         """
15 |         Initialize the Finetune class for fine-tuning the binary decoder of a pre-trained model.
16 | 
17 |         Parameters:
18 |         -----------
19 |         cfg : OmegaConf
20 |             Configuration object containing model settings
21 |         learning_rate : float
22 |             Learning rate for fine-tuning the binary decoder
23 |         """
24 |         self.model = None
25 |         self.collator = None
26 |         self.protein_embeds = None
27 |         self._vci_conf = cfg
28 |         self.learning_rate = learning_rate
29 |         self.cached_gene_embeddings = {}
30 |         self.device = None
31 | 
32 |     def load_model(self, checkpoint):
33 |         """
34 |         Load a pre-trained model from a checkpoint and prepare it for fine-tuning.
35 | 
36 |         Parameters:
37 |         -----------
38 |         checkpoint : str
39 |             Path to the checkpoint file
40 |         """
41 |         if self.model:
42 |             raise ValueError("Model already initialized")
43 | 
44 |         # Import locally to avoid circular imports
45 | 
46 |         # Load and initialize model for eval
47 |         self.model = StateEmbeddingModel.load_from_checkpoint(checkpoint, strict=False)
48 | 
49 |         # Ensure model uses the provided config, not the stored one
50 |         if self._vci_conf is not None:
51 |             self.model.update_config(self._vci_conf)
52 | 
53 |         self.device = self.model.device
54 | 
55 |         # Load protein embeddings
56 |         all_pe = get_embeddings(self._vci_conf)
57 |         all_pe.requires_grad = False
58 |         self.model.pe_embedding = nn.Embedding.from_pretrained(all_pe)
59 |         self.model.pe_embedding.to(self.device)
60 | 
61 |         # Load the raw protein embedding table used for per-gene lookups
62 |         self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings)
63 | 
64 |         # Freeze all parameters
65 |         for param in self.model.parameters():
66 |             param.requires_grad = False
67 | 
68 |         # Enable gradients only for binary decoder
69 |         for param in self.model.binary_decoder.parameters():
70 |             param.requires_grad = True
71 | 
72 |         # Ensure the binary decoder is in training mode so gradients are enabled.
73 |         self.model.binary_decoder.train()
74 | 
75 |     def get_gene_embedding(self, genes):
76 |         """
77 |         Get embeddings for a list of genes, with caching to avoid recomputation.
78 | 79 | Parameters: 80 | ----------- 81 | genes : list 82 | List of gene names/identifiers 83 | 84 | Returns: 85 | -------- 86 | torch.Tensor 87 | Tensor of gene embeddings 88 | """ 89 | # Cache key based on genes tuple 90 | cache_key = tuple(genes) 91 | 92 | # Return cached embeddings if available 93 | if cache_key in self.cached_gene_embeddings: 94 | return self.cached_gene_embeddings[cache_key] 95 | 96 | # Compute gene embeddings 97 | protein_embeds = [self.protein_embeds[x] if x in self.protein_embeds else torch.zeros(5120) for x in genes] 98 | protein_embeds = torch.stack(protein_embeds).to(self.device) 99 | gene_embeds = self.model.gene_embedding_layer(protein_embeds) 100 | 101 | # Cache and return 102 | self.cached_gene_embeddings[cache_key] = gene_embeds 103 | return gene_embeds 104 | 105 | def get_counts(self, cell_embs, genes, read_depth=None, batch_size=32): 106 | """ 107 | Generate predictions with the binary decoder with gradients enabled. 108 | 109 | Parameters: 110 | - cell_embs: A tensor or array of cell embeddings. 111 | - genes: List of gene names. 112 | - read_depth: Optional read depth for RDA normalization. 113 | - batch_size: Batch size for processing. 114 | 115 | Returns: 116 | A single tensor of shape [N, num_genes] where N is the total number of cells. 117 | """ 118 | 119 | # Convert cell_embs to a tensor on the correct device. 120 | cell_embs = torch.tensor(cell_embs, dtype=torch.float, device=self.device) 121 | 122 | # Check if RDA is enabled. 123 | use_rda = getattr(self.model.cfg.model, "rda", False) 124 | if use_rda and read_depth is None: 125 | read_depth = 1000.0 126 | 127 | # Retrieve gene embeddings (cached if available). 128 | gene_embeds = self.get_gene_embedding(genes) 129 | 130 | # List to collect the output predictions for each batch. 131 | output_batches = [] 132 | 133 | # Loop over cell embeddings in batches. 134 | for i in range(0, cell_embs.size(0), batch_size): 135 | # Determine batch indices. 136 | end_idx = min(i + batch_size, cell_embs.size(0)) 137 | cell_embeds_batch = cell_embs[i:end_idx] 138 | 139 | # Set up task counts if using RDA. 140 | if use_rda: 141 | task_counts = torch.full((cell_embeds_batch.shape[0],), read_depth, device=self.device) 142 | else: 143 | task_counts = None 144 | 145 | # Resize the batch using the model's method. 146 | merged_embs = self.model.resize_batch(cell_embeds_batch, gene_embeds, task_counts) 147 | 148 | # Forward pass through the binary decoder. 149 | logprobs_batch = self.model.binary_decoder(merged_embs) 150 | 151 | # If the output has an extra singleton dimension (e.g., [B, gene_dim, 1]), squeeze it. 152 | if logprobs_batch.dim() == 3 and logprobs_batch.size(-1) == 1: 153 | logprobs_batch = logprobs_batch.squeeze(-1) 154 | 155 | output_batches.append(logprobs_batch) 156 | 157 | # Concatenate all batch outputs along the first dimension. 158 | return torch.cat(output_batches, dim=0) 159 | -------------------------------------------------------------------------------- /src/state/emb/vectordb.py: -------------------------------------------------------------------------------- 1 | import lancedb 2 | import numpy as np 3 | import pandas as pd 4 | from typing import Optional, List 5 | 6 | 7 | class StateVectorDB: 8 | """Manages LanceDB operations for State embeddings.""" 9 | 10 | def __init__(self, db_path: str = "./state_embeddings.lancedb"): 11 | """Initialize or connect to a LanceDB database. 
12 | 13 | Args: 14 | db_path: Path to the LanceDB database 15 | """ 16 | self.db = lancedb.connect(db_path) 17 | self.table_name = "state_embeddings" 18 | 19 | def create_or_update_table( 20 | self, 21 | embeddings: np.ndarray, 22 | metadata: pd.DataFrame, 23 | embedding_key: str = "X_state", 24 | dataset_name: Optional[str] = None, 25 | batch_size: int = 1000, 26 | ): 27 | """Create or update the embeddings table. 28 | 29 | Args: 30 | embeddings: Cell embeddings array (n_cells x embedding_dim) 31 | metadata: Cell metadata from adata.obs 32 | embedding_key: Name of the embedding (for versioning) 33 | dataset_name: Name of the dataset being processed 34 | batch_size: Batch size for insertion 35 | """ 36 | # Prepare data with metadata 37 | data = [] 38 | for i in range(0, len(embeddings), batch_size): 39 | batch_end = min(i + batch_size, len(embeddings)) 40 | batch_data = [] 41 | 42 | for j in range(i, batch_end): 43 | record = { 44 | "vector": embeddings[j].tolist(), 45 | "cell_id": metadata.index[j], 46 | "embedding_key": embedding_key, 47 | "dataset": dataset_name or "unknown", 48 | **{col: metadata.iloc[j][col] for col in metadata.columns}, 49 | } 50 | batch_data.append(record) 51 | 52 | data.extend(batch_data) 53 | 54 | # Create or append to table 55 | if self.table_name in self.db.table_names(): 56 | table = self.db.open_table(self.table_name) 57 | table.add(data) 58 | else: 59 | self.db.create_table(self.table_name, data=data) 60 | 61 | def search( 62 | self, 63 | query_vector: np.ndarray, 64 | k: int = 10, 65 | filter: str | None = None, 66 | include_distance: bool = True, 67 | columns: List[str] | None = None, 68 | include_vector: bool = False, 69 | ): 70 | """Search for similar embeddings. 71 | 72 | Args: 73 | query_vector: Query embedding vector 74 | k: Number of results to return 75 | filter: Optional filter expression (e.g., 'cell_type == "B cell"') 76 | include_distance: Whether to include distance in results 77 | include_vector: Whether to include the query vector in the results 78 | columns: Specific columns to return (None = all) 79 | Returns: 80 | Search results with metadata 81 | """ 82 | table = self.db.open_table(self.table_name) 83 | 84 | # Build query 85 | query = table.search(query_vector).limit(k) 86 | 87 | if filter: 88 | query = query.where(filter) 89 | 90 | if columns: 91 | query = query.select(columns + ["_distance"] if include_distance else columns) 92 | 93 | results = query.to_pandas() 94 | 95 | # deal with _distance column 96 | if "_distance" in results.columns: 97 | if include_distance: 98 | results = results.rename(columns={"_distance": "query_distance"}) 99 | else: 100 | results = results.drop("_distance", axis=1) 101 | elif include_distance: 102 | results["query_distance"] = 0.0 103 | 104 | # drop vector column if include_vector is False 105 | if not include_vector and "vector" in results.columns: 106 | results = results.drop("vector", axis=1) 107 | 108 | return results 109 | 110 | def batch_search( 111 | self, 112 | query_vectors: np.ndarray, 113 | k: int = 10, 114 | filter: str | None = None, 115 | include_distance: bool = True, 116 | batch_size: int = 100, 117 | show_progress: bool = True, 118 | include_vector: bool = False, 119 | ): 120 | """Batch search for multiple query vectors. 
121 | 122 | Args: 123 | query_vectors: Array of query embedding vectors 124 | k: Number of results per query 125 | filter: Optional filter expression 126 | include_distance: Whether to include distances 127 | include_vector: Whether to include the query vector in the results 128 | batch_size: Number of queries to process at once 129 | show_progress: Show progress bar 130 | Returns: 131 | List of DataFrames with search results 132 | """ 133 | from tqdm import tqdm 134 | 135 | results = [] 136 | iterator = range(0, len(query_vectors), batch_size) 137 | 138 | if show_progress: 139 | iterator = tqdm(iterator, desc="Searching") 140 | 141 | for i in iterator: 142 | batch_end = min(i + batch_size, len(query_vectors)) 143 | batch_queries = query_vectors[i:batch_end] 144 | 145 | batch_results = [] 146 | for query_vec in batch_queries: 147 | result = self.search( 148 | query_vector=query_vec, 149 | k=k, 150 | filter=filter, 151 | include_distance=include_distance, 152 | include_vector=include_vector, 153 | ) 154 | batch_results.append(result) 155 | 156 | results.extend(batch_results) 157 | 158 | return results 159 | 160 | def get_table_info(self): 161 | """Get information about the embeddings table.""" 162 | if self.table_name not in self.db.table_names(): 163 | return None 164 | 165 | table = self.db.open_table(self.table_name) 166 | return { 167 | "num_rows": len(table), 168 | "columns": table.schema.names, 169 | "embedding_dim": len(table.to_pandas().iloc[0]["vector"]) if len(table) > 0 else 0, 170 | } 171 | -------------------------------------------------------------------------------- /src/state/_cli/_emb/_transform.py: -------------------------------------------------------------------------------- 1 | import argparse as ap 2 | 3 | 4 | def add_arguments_transform(parser: ap.ArgumentParser): 5 | """Add arguments for state embedding CLI.""" 6 | parser.add_argument("--model-folder", required=True, help="Path to the model checkpoint folder") 7 | parser.add_argument("--checkpoint", required=False, help="Path to the specific model checkpoint") 8 | parser.add_argument( 9 | "--config", 10 | required=False, 11 | help=( 12 | "Path to config override. If omitted, uses the config embedded in the checkpoint; ignores any config in the model folder." 13 | ), 14 | ) 15 | parser.add_argument("--input", required=True, help="Path to input anndata file (h5ad)") 16 | parser.add_argument("--output", required=False, help="Path to output embedded anndata file (h5ad)") 17 | parser.add_argument("--embed-key", default="X_state", help="Name of key to store embeddings") 18 | parser.add_argument( 19 | "--protein-embeddings", 20 | required=False, 21 | help=( 22 | "Path to protein embeddings override (.pt). If omitted, the CLI will look for 'protein_embeddings.pt' in --model-folder, " 23 | "then fall back to embeddings packaged in the checkpoint, and finally the path from the config." 24 | ), 25 | ) 26 | parser.add_argument("--lancedb", type=str, help="Path to LanceDB database for vector storage") 27 | parser.add_argument( 28 | "--lancedb-update", action="store_true", help="Update existing entries in LanceDB (default: append)" 29 | ) 30 | parser.add_argument("--lancedb-batch-size", type=int, default=1000, help="Batch size for LanceDB operations") 31 | parser.add_argument( 32 | "--batch-size", 33 | type=int, 34 | default=None, 35 | help=( 36 | "Batch size for embedding forward pass (overrides config). " 37 | "Increase to use more VRAM and speed up embedding." 
38 |         ),
39 |     )
40 | 
41 | 
42 | def run_emb_transform(args: ap.Namespace):
43 |     """
44 |     Compute embeddings for an input anndata file using a pre-trained VCI model checkpoint.
45 |     """
46 |     import glob
47 |     import logging
48 |     import os
49 | 
50 |     import torch
51 |     from omegaconf import OmegaConf
52 | 
53 |     logging.basicConfig(level=logging.INFO)
54 |     logger = logging.getLogger(__name__)
55 | 
56 |     from ...emb.inference import Inference
57 | 
58 |     # check for --output or --lancedb
59 |     if not args.output and not args.lancedb:
60 |         logger.error("Either --output or --lancedb must be provided")
61 |         raise ValueError("Either --output or --lancedb must be provided")
62 | 
63 |     # glob for *.ckpt in the model folder; if no checkpoint was given, use the last one (lexicographically) and log it
64 |     model_files = glob.glob(os.path.join(args.model_folder, "*.ckpt"))
65 |     if not model_files:
66 |         logger.error(f"No model checkpoint found in {args.model_folder}")
67 |         raise FileNotFoundError(f"No model checkpoint found in {args.model_folder}")
68 |     if not args.checkpoint:
69 |         args.checkpoint = model_files[-1]
70 |     logger.info(f"Using model checkpoint: {args.checkpoint}")
71 | 
72 |     # Create inference object
73 |     logger.info("Creating inference object")
74 |     # Resolve protein embeddings in priority order:
75 |     # 1) Explicit --protein-embeddings
76 |     # 2) Auto-detect 'protein_embeddings.pt' in --model-folder
77 |     # 3) Let Inference load from checkpoint/config
78 |     protein_embeds = None
79 |     if args.protein_embeddings:
80 |         logger.info(f"Using protein embeddings override: {args.protein_embeddings}")
81 |         protein_embeds = torch.load(args.protein_embeddings, weights_only=False, map_location="cpu")
82 |     else:
83 |         # Try auto-detect in model folder
84 |         try:
85 |             exact_path = os.path.join(args.model_folder, "protein_embeddings.pt")
86 |             cand_path = None
87 |             if os.path.exists(exact_path):
88 |                 cand_path = exact_path
89 |             else:
90 |                 # Consider other variations like protein_embeddings*.pt
91 |                 pe_files = sorted(glob.glob(os.path.join(args.model_folder, "protein_embeddings*.pt")))
92 |                 if pe_files:
93 |                     # Prefer the lexicographically last to mimic checkpoint selection behavior
94 |                     cand_path = pe_files[-1]
95 |             if cand_path is not None:
96 |                 logger.info(
97 |                     f"Found protein embeddings in model folder: {cand_path}. Using these and overriding config."
98 |                 )
99 |                 protein_embeds = torch.load(cand_path, weights_only=False, map_location="cpu")
100 |         except Exception as e:
101 |             logger.warning(
102 |                 f"Failed to load auto-detected protein embeddings: {e}. Will fall back to checkpoint/config."
103 | ) 104 | 105 | # Only use config override if explicitly provided; otherwise use config embedded in the checkpoint 106 | conf = OmegaConf.load(args.config) if args.config else None 107 | inferer = Inference(cfg=conf, protein_embeds=protein_embeds) 108 | 109 | # Load model from checkpoint 110 | logger.info(f"Loading model from checkpoint: {args.checkpoint}") 111 | inferer.load_model(args.checkpoint) 112 | 113 | # Create output directory if it doesn't exist 114 | if args.output: 115 | output_dir = os.path.dirname(args.output) 116 | if output_dir: 117 | os.makedirs(output_dir, exist_ok=True) 118 | logger.info(f"Created output directory: {output_dir}") 119 | 120 | # Generate embeddings 121 | logger.info(f"Computing embeddings for {args.input}") 122 | if args.output: 123 | logger.info(f"Output will be saved to {args.output}") 124 | if args.lancedb: 125 | logger.info(f"Embeddings will be saved to LanceDB at {args.lancedb}") 126 | 127 | inferer.encode_adata( 128 | input_adata_path=args.input, 129 | output_adata_path=args.output, 130 | emb_key=args.embed_key, 131 | batch_size=args.batch_size if getattr(args, "batch_size", None) is not None else None, 132 | lancedb_path=args.lancedb, 133 | update_lancedb=args.lancedb_update, 134 | lancedb_batch_size=args.lancedb_batch_size, 135 | ) 136 | 137 | logger.info("Embedding computation completed successfully!") 138 | -------------------------------------------------------------------------------- /src/state/_cli/_tx/_preprocess_infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse as ap 3 | from typing import Optional 4 | from scipy.sparse import issparse, csc_matrix 5 | 6 | 7 | def add_arguments_preprocess_infer(parser: ap.ArgumentParser): 8 | """Add arguments for the preprocess_infer subcommand.""" 9 | parser.add_argument( 10 | "--adata", 11 | type=str, 12 | required=True, 13 | help="Path to input AnnData file (.h5ad)", 14 | ) 15 | parser.add_argument( 16 | "--output", 17 | type=str, 18 | required=True, 19 | help="Path to output preprocessed AnnData file (.h5ad)", 20 | ) 21 | parser.add_argument( 22 | "--control-condition", 23 | type=str, 24 | required=True, 25 | help="Control condition identifier (e.g., \"[('DMSO_TF', 0.0, 'uM')]\")", 26 | ) 27 | parser.add_argument( 28 | "--pert-col", 29 | type=str, 30 | required=True, 31 | help="Column name containing perturbation information (e.g., 'drugname_drugconc')", 32 | ) 33 | parser.add_argument( 34 | "--seed", 35 | type=int, 36 | default=42, 37 | help="Random seed for reproducibility (default: 42)", 38 | ) 39 | parser.add_argument( 40 | "--embed-key", 41 | type=str, 42 | required=False, 43 | help="obsm key to use/replace instead of X (e.g., 'X_pca')", 44 | ) 45 | 46 | 47 | def _fast_row_reindex_matrix(X, row_indexer): 48 | """ 49 | Return X[row_indexer] efficiently for dense or sparse matrices. 50 | Also convert CSC->CSR once to speed up row gathering on sparse matrices. 51 | """ 52 | if issparse(X): 53 | # Row-wise fancy indexing is faster on CSR than CSC. 
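# --------------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of this file. The CSC -> CSR
# conversion performed just below pays off because CSR stores each row
# contiguously, so gathering many rows is cheap, while CSC must pull entries out
# of every column. A rough stand-alone comparison (matrix size, density, and the
# number of sampled rows are arbitrary choices for illustration):
import numpy as np
from scipy.sparse import random as sparse_random

example_X_csc = sparse_random(50_000, 2_000, density=0.01, format="csc", random_state=0)
example_rows = np.random.default_rng(0).integers(0, example_X_csc.shape[0], size=10_000)

example_fast = example_X_csc.tocsr()[example_rows]   # convert once, then index rows on CSR
example_slow = example_X_csc[example_rows]           # same result, markedly slower on CSC
# [End of editor's note]
# --------------------------------------------------------------------------------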
54 | if isinstance(X, csc_matrix): 55 | print("Converting X from CSC to CSR for faster row indexing...") 56 | X = X.tocsr(copy=True) 57 | return X[row_indexer] 58 | else: 59 | return X[row_indexer, :] 60 | 61 | 62 | def run_tx_preprocess_infer( 63 | adata_path: str, 64 | output_path: str, 65 | control_condition: str, 66 | pert_col: str, 67 | seed: int = 42, 68 | embed_key: Optional[str] = None, 69 | ): 70 | """ 71 | Preprocess inference data by replacing perturbed cells with control expression. 72 | 73 | This creates a 'control template' where all non-control cells receive expression 74 | sampled (with replacement) from control cells, while keeping original annotations. 75 | """ 76 | import logging 77 | 78 | import anndata as ad 79 | import numpy as np 80 | # tqdm removed from the hot path; the main speed-up is vectorization, not progress bars. 81 | 82 | logger = logging.getLogger(__name__) 83 | 84 | print(f"Loading AnnData from {adata_path}") 85 | adata = ad.read_h5ad(adata_path) 86 | 87 | # Set random seed for reproducibility 88 | rng = np.random.default_rng(seed) 89 | print(f"Set random seed to {seed}") 90 | 91 | # Validate columns/keys upfront 92 | if pert_col not in adata.obs.columns: 93 | raise KeyError(f"Column '{pert_col}' not found in adata.obs") 94 | 95 | if embed_key is not None and embed_key not in adata.obsm: 96 | raise KeyError(f"obsm key '{embed_key}' not found in adata.obsm") 97 | 98 | # Identify control cells 99 | print(f"Identifying control cells with condition: {control_condition!r}") 100 | # Use .values to avoid pandas alignment overhead 101 | col_values = adata.obs[pert_col].values 102 | control_mask = col_values == control_condition 103 | control_indices = np.flatnonzero(control_mask) 104 | 105 | print(f"Found {control_indices.size} control cells out of {adata.n_obs} total cells") 106 | if control_indices.size == 0: 107 | raise ValueError(f"No control cells found with condition '{control_condition}' in column '{pert_col}'") 108 | 109 | # Compute unique perturbations for logging (no heavy loop per perturbation) 110 | if hasattr(adata.obs[pert_col], "cat"): 111 | unique_perturbations = adata.obs[pert_col].cat.categories 112 | else: 113 | unique_perturbations = np.unique(col_values) 114 | 115 | non_control_perturbations = [p for p in unique_perturbations if p != control_condition] 116 | n_non_control_cells = int((~control_mask).sum()) 117 | 118 | print(f"Processing {len(non_control_perturbations)} non-control perturbations") 119 | 120 | # Build a source index for every row: control rows map to themselves, 121 | # non-control rows map to randomly sampled control rows. 122 | source_idx = np.arange(adata.n_obs, dtype=np.int64) 123 | if n_non_control_cells > 0: 124 | sampled_controls = rng.choice(control_indices, size=n_non_control_cells, replace=True) 125 | source_idx[~control_mask] = sampled_controls 126 | 127 | # Create a copy to preserve original object structure/metadata (matches original behavior) 128 | adata_modified = adata.copy() 129 | 130 | # Replace data in a single, vectorized operation 131 | if embed_key is not None: 132 | emb = adata.obsm[embed_key] 133 | # emb is expected to be a dense 2D array-like 134 | adata_modified.obsm[embed_key] = emb[source_idx] 135 | total_replaced_cells = n_non_control_cells 136 | else: 137 | X = adata.X 138 | adata_modified.X = _fast_row_reindex_matrix(X, source_idx) 139 | total_replaced_cells = n_non_control_cells 140 | 141 | print(f"Replacement complete! 
Replaced expression in {total_replaced_cells} cells") 142 | print(f"Control cells ({control_indices.size}) retain their original expression") 143 | 144 | # Summary log 145 | print("=" * 60) 146 | print("PREPROCESSING SUMMARY:") 147 | print(f" - Input: {adata.n_obs} cells, {adata.n_vars} genes") 148 | print(f" - Control condition: {control_condition!r}") 149 | print(f" - Control cells: {control_indices.size} (unchanged)") 150 | print(f" - Perturbed cells: {total_replaced_cells} (replaced with control expression)") 151 | print(f" - Perturbations processed: {len(non_control_perturbations)}") 152 | if embed_key is not None: 153 | print(f" - Using obsm key: {embed_key}") 154 | else: 155 | print(" - Using expression matrix (X)") 156 | print("") 157 | print("USAGE:") 158 | print(" The output file contains cells with control expression but original") 159 | print(" perturbation annotations. When passed through state_transition inference,") 160 | print(" the model will apply perturbation effects to simulate the original data.") 161 | print(" Compare: state_transition(output) ≈ original_input") 162 | print("=" * 60) 163 | 164 | print(f"Saving preprocessed data to {output_path}") 165 | # Writing can still be I/O-bound; the heavy compute path is now vectorized. 166 | adata_modified.write_h5ad(output_path) 167 | print("Preprocessing complete!") 168 | -------------------------------------------------------------------------------- /src/state/emb/tools/slurm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import argparse 4 | import logging 5 | import subprocess 6 | 7 | from pathlib import Path 8 | from omegaconf import OmegaConf 9 | from hydra import compose, initialize 10 | from jinja2 import Template 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | logging.basicConfig( 15 | level=logging.INFO, 16 | format="%(asctime)s %(levelname)s: %(message)s", 17 | datefmt="%Y-%m-%d %H:%M:%S", 18 | ) 19 | 20 | sbatch_script_template = """#!/bin/bash 21 | 22 | #SBATCH --job-name={{ exp_name }} 23 | #SBATCH --nodes={{ num_nodes }} 24 | #SBATCH --gres=gpu:{{ num_gpus_per_node }} 25 | #SBATCH --ntasks-per-node={{ num_gpus_per_node }} 26 | #SBATCH --cpus-per-task=16 27 | #SBATCH --mem=1560G 28 | #SBATCH --time={{ duration }} 29 | #SBATCH --signal=B:SIGINT@300 30 | #SBATCH --output=outputs/{{ exp_name }}/training.log 31 | #SBATCH --open-mode=append 32 | #SBATCH --partition={{ partition }} 33 | {{ sbatch_overrides }} 34 | 35 | unset SLURM_TRES_PER_TASK 36 | 37 | export MASTER_ADDR=$(scontrol show hostname ${SLURM_JOB_NODELIST} | head -n 1) 38 | export MASTER_PORT='12357' 39 | 40 | #export PYTHONFAULTHANDLER=1 41 | #export NCCL_DEBUG=INFO 42 | #export NCCL_DEBUG_SUBSYS=ALL 43 | #export NCCL_VERBOSE_MARK=100 44 | #export TORCH_DISTRIBUTED_DEBUG=DETAIL 45 | #export TORCH_CPP_LOG_LEVEL=INFO 46 | 47 | git log --pretty=format:'%h' -n 1 48 | 49 | srun \\ 50 | python -m vci.train --conf {{ traing_config_file }} 51 | """ 52 | 53 | 54 | def parse_vars(extra_vars): 55 | """ 56 | Parses comma seperated key value pair strings into dict. 
57 | """ 58 | vars_list = [] 59 | if extra_vars: 60 | for i in extra_vars: 61 | items = i.split("=") 62 | key = items[0].strip() 63 | if len(items) > 1: 64 | value = "=".join(items[1:]) 65 | vars_list.append((key, value)) 66 | return dict(vars_list) 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser(description="Create dataset list CSV file") 71 | parser.add_argument( 72 | "-c", 73 | "--config", 74 | type=str, 75 | help="Training configuration file.", 76 | ) 77 | parser.add_argument( 78 | "-e", 79 | "--exp_name", 80 | type=str, 81 | help="Experiment name. This will be used to name generated artifacts.", 82 | ) 83 | parser.add_argument( 84 | "-n", 85 | "--num_nodes", 86 | type=int, 87 | default=1, 88 | help="Number of nodes to use for this training job.", 89 | ) 90 | parser.add_argument( 91 | "-g", 92 | "--gpus_per_nodes", 93 | type=int, 94 | default=4, 95 | help="Number of GPUs per node", 96 | ) 97 | parser.add_argument( 98 | "-r", 99 | "--reservation", 100 | dest="reservation", 101 | type=str, 102 | default=None, 103 | help="Slurm reservation to use for this job.", 104 | ) 105 | parser.add_argument( 106 | "-p", 107 | "--partition", 108 | dest="partition", 109 | type=str, 110 | default="gpu_batch,gpu_high_mem,gpu_batch_high_mem,vci_gpu_priority,preemptible", 111 | help="Slurm partition to use.", 112 | ) 113 | parser.add_argument( 114 | "--duration", 115 | dest="duration", 116 | type=str, 117 | default="7-00:00:00", 118 | help="SLURM job durarion. Pleae refer Slurm documenation for time format", 119 | ) 120 | parser.add_argument( 121 | "-f", 122 | "--force", 123 | dest="force", 124 | action="store_true", 125 | default=False, 126 | help="Overwrite config and submit the job.", 127 | ) 128 | parser.add_argument( 129 | "-d", 130 | "--dryrun", 131 | dest="dryrun", 132 | action="store_true", 133 | default=False, 134 | help="Only generate slurm sbatch script", 135 | ) 136 | parser.add_argument( 137 | "--set", 138 | metavar="KEY=VALUE", 139 | nargs="+", 140 | default=None, 141 | help="Values to be overriden for the training.Please refer ./conf/defaults.yaml", 142 | ) 143 | 144 | args = parser.parse_args() 145 | 146 | bind_param = { 147 | "exp_name": args.exp_name, 148 | "num_nodes": args.num_nodes, 149 | "num_gpus_per_node": args.gpus_per_nodes, 150 | "duration": args.duration, 151 | "partition": args.partition, 152 | } 153 | 154 | if args.config: 155 | bind_param["traing_config_file"] = args.config 156 | else: 157 | assert args.exp_name, "Experiment name is required when config is not provided." 158 | log.info(f"Creating config for {args.exp_name}...") 159 | trn_conf_dir = Path(f"outputs/{args.exp_name}/conf") 160 | if not args.force: 161 | assert not os.path.exists(trn_conf_dir.parent), f"Conf dir {trn_conf_dir.parent.absolute()} already exists." 
162 | 163 | overrides = [ 164 | f"experiment.name={args.exp_name}", 165 | f"experiment.num_nodes={args.num_nodes}", 166 | f"experiment.num_gpus_per_node={args.gpus_per_nodes}", 167 | ] 168 | 169 | if args.set: 170 | log.info(f"Applying overrides: {parse_vars(args.set)}") 171 | for key, value in parse_vars(args.set).items(): 172 | overrides.append(f"{key}={value}") 173 | log.info(f"Applying overrides: {overrides}") 174 | 175 | config_dir = Path(os.path.join(os.path.dirname(__file__), "../..", "conf")) 176 | config_dir = os.path.relpath(config_dir, Path(__file__).parent) 177 | log.info(config_dir) 178 | 179 | with initialize(version_base=None, config_path=config_dir): 180 | cfg = compose( 181 | config_name="defaults.yaml", 182 | overrides=overrides, 183 | ) 184 | cfg = OmegaConf.to_container(cfg, resolve=True) 185 | 186 | os.makedirs(trn_conf_dir, exist_ok=True) 187 | trn_conf_file = Path(f"{trn_conf_dir}/training.yaml") 188 | with open(trn_conf_file, "w") as file: 189 | yaml.dump(cfg, file) 190 | bind_param["traing_config_file"] = trn_conf_file.absolute() 191 | 192 | # SLURM changes 193 | sbatch_overrides = None 194 | if args.reservation: 195 | sbatch_overrides = f"#SBATCH --reservation={args.reservation}\n" 196 | 197 | if sbatch_overrides: 198 | bind_param["sbatch_overrides"] = sbatch_overrides 199 | 200 | template = Template(sbatch_script_template) 201 | rendered_script = template.render(bind_param) 202 | 203 | slurm_script = f"outputs/{args.exp_name}/slurm.sh" 204 | with open(slurm_script, "w") as f: 205 | f.write(rendered_script) 206 | 207 | if not args.dryrun: 208 | subprocess.call(["sbatch", slurm_script]) 209 | -------------------------------------------------------------------------------- /src/state/emb/train/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import lightning as L 4 | 5 | from torch import nn 6 | from torch.utils.data import DataLoader 7 | from datetime import timedelta 8 | 9 | from lightning.pytorch.callbacks import ModelCheckpoint 10 | from lightning.pytorch.loggers import WandbLogger 11 | from lightning.pytorch.strategies import DDPStrategy 12 | 13 | from ..nn.model import StateEmbeddingModel 14 | from ..data import H5adSentenceDataset, VCIDatasetSentenceCollator 15 | from ..train.callbacks import ( 16 | LogLR, 17 | ProfilerCallback, 18 | ResumeCallback, 19 | EMACallback, 20 | PerfProfilerCallback, 21 | CumulativeFLOPSCallback, 22 | ) 23 | from ..utils import get_latest_checkpoint, get_embedding_cfg, get_dataset_cfg 24 | 25 | 26 | def get_embeddings(cfg): 27 | # Load in ESM2 embeddings and special tokens 28 | all_pe = torch.load(get_embedding_cfg(cfg).all_embeddings, weights_only=False) 29 | if isinstance(all_pe, dict): 30 | all_pe = torch.vstack(list(all_pe.values())) 31 | 32 | all_pe = all_pe.cuda() 33 | return all_pe 34 | 35 | 36 | def main(cfg): 37 | print(f"Starting training with Embedding {cfg.embeddings.current} and dataset {cfg.dataset.current}") 38 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 39 | os.environ["NCCL_LAUNCH_TIMEOUT"] = str(cfg.experiment.ddp_timeout) 40 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" 41 | TOTAL_N_CELL = cfg.dataset.num_cells 42 | EPOCH_LENGTH = int(TOTAL_N_CELL // cfg.model.batch_size // 24) 43 | # ? not sure why this needs to be included but seems empirical?? 
no clue why this is 6 44 | warmup_steps = EPOCH_LENGTH * 6 45 | 46 | train_dataset_sentence_collator = VCIDatasetSentenceCollator(cfg, is_train=True) 47 | # validation should not do augmentations 48 | val_dataset_sentence_collator = VCIDatasetSentenceCollator(cfg, is_train=False) 49 | 50 | generator = torch.Generator() 51 | generator.manual_seed(cfg.dataset.seed) 52 | 53 | if get_dataset_cfg(cfg).ds_type == "h5ad": 54 | DatasetClass = H5adSentenceDataset 55 | else: 56 | raise ValueError(f"Unknown dataset type: {get_dataset_cfg(cfg).ds_type}") 57 | 58 | # Training dataloader 59 | train_dataset = DatasetClass(cfg) 60 | train_dataloader = DataLoader( 61 | train_dataset, 62 | batch_size=cfg.model.batch_size, 63 | shuffle=True, 64 | collate_fn=train_dataset_sentence_collator, 65 | num_workers=cfg.dataset.num_train_workers, 66 | persistent_workers=True, 67 | pin_memory=True, 68 | prefetch_factor=4, 69 | generator=generator, 70 | ) 71 | 72 | val_dataset = DatasetClass(cfg, test=True) 73 | val_dataloader = DataLoader( 74 | val_dataset, 75 | batch_size=cfg.model.batch_size, 76 | shuffle=True, 77 | collate_fn=val_dataset_sentence_collator, 78 | num_workers=cfg.dataset.num_val_workers, 79 | persistent_workers=True, 80 | generator=generator, 81 | ) 82 | 83 | model = StateEmbeddingModel( 84 | token_dim=get_embedding_cfg(cfg).size, 85 | d_model=cfg.model.emsize, 86 | nhead=cfg.model.nhead, 87 | d_hid=cfg.model.d_hid, 88 | nlayers=cfg.model.nlayers, 89 | output_dim=cfg.model.output_dim, 90 | dropout=cfg.model.dropout, 91 | warmup_steps=warmup_steps, 92 | compiled=False, 93 | max_lr=cfg.optimizer.max_lr, 94 | emb_size=get_embedding_cfg(cfg).size, 95 | collater=val_dataset_sentence_collator, 96 | cfg=cfg, 97 | ) 98 | # Ensure model always uses the current config, even after checkpoint loading 99 | model.update_config(cfg) 100 | # Also update datasets and collaters with current config 101 | train_dataset.cfg = cfg 102 | val_dataset.cfg = cfg 103 | train_dataset_sentence_collator.cfg = cfg 104 | val_dataset_sentence_collator.cfg = cfg 105 | model.collater = val_dataset_sentence_collator 106 | model = model.cuda() 107 | all_pe = get_embeddings(cfg) 108 | all_pe.requires_grad = False 109 | model.pe_embedding = nn.Embedding.from_pretrained(all_pe) 110 | 111 | model = model.train() 112 | 113 | run_name, chk = get_latest_checkpoint(cfg) 114 | checkpoint_callback = ModelCheckpoint( 115 | every_n_train_steps=cfg.experiment.checkpoint.every_n_train_steps, 116 | dirpath=os.path.join(cfg.experiment.checkpoint.path, cfg.experiment.name), 117 | filename=f"{run_name}" + "-{epoch}-{step}", 118 | save_last=True, 119 | save_top_k=cfg.experiment.checkpoint.save_top_k, 120 | monitor=cfg.experiment.checkpoint.monitor, 121 | ) 122 | 123 | if cfg.wandb.enable: 124 | try: 125 | import wandb 126 | 127 | exp_logger = WandbLogger(project=cfg.wandb.project, name=cfg.experiment.name) 128 | exp_logger.watch(model, log_freq=1000) 129 | except ImportError: 130 | print("Warning: wandb is not installed. 
Skipping wandb logging.") 131 | print("To enable wandb logging, install it with: pip install wandb") 132 | exp_logger = None 133 | except Exception as e: 134 | print(f"Warning: Failed to initialize wandb logger: {e}") 135 | print("Continuing without wandb logging.") 136 | exp_logger = None 137 | else: 138 | exp_logger = None 139 | 140 | callbacks = [checkpoint_callback, LogLR(100), ResumeCallback(cfg), PerfProfilerCallback()] 141 | 142 | if getattr(cfg.model, "ema", False): 143 | ema_decay = getattr(cfg.model, "ema_decay", 0.999) 144 | callbacks.append(EMACallback(decay=ema_decay)) 145 | 146 | # Add cumulative FLOPS callback 147 | callbacks.append(CumulativeFLOPSCallback(use_backward=cfg.experiment.cumulative_flops_use_backward)) 148 | 149 | max_steps = -1 150 | if cfg.experiment.profile.enable_profiler: 151 | callbacks.append(ProfilerCallback(cfg=cfg)) 152 | max_steps = cfg.experiment.profile.max_steps 153 | 154 | val_interval = int(cfg.experiment.val_check_interval * cfg.experiment.num_gpus_per_node * cfg.experiment.num_nodes) 155 | trainer = L.Trainer( 156 | max_epochs=cfg.experiment.num_epochs, 157 | max_steps=max_steps, 158 | callbacks=callbacks, 159 | devices=cfg.experiment.num_gpus_per_node, 160 | num_nodes=cfg.experiment.num_nodes, 161 | # Accumulation 162 | gradient_clip_val=cfg.optimizer.max_grad_norm, 163 | accumulate_grad_batches=cfg.optimizer.gradient_accumulation_steps, 164 | precision="bf16-mixed", 165 | strategy=DDPStrategy( 166 | process_group_backend="nccl", 167 | find_unused_parameters=False, 168 | timeout=timedelta(seconds=cfg.experiment.get("ddp_timeout", 3600)), 169 | ), 170 | val_check_interval=val_interval, 171 | # Logging 172 | logger=exp_logger, 173 | fast_dev_run=False, 174 | limit_val_batches=cfg.experiment.limit_val_batches, 175 | ) 176 | 177 | if chk: 178 | print(f"******** Loading chkpoint {run_name} {chk}...") 179 | else: 180 | print(f"******** Initialized fresh {run_name}...") 181 | 182 | trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, ckpt_path=chk) 183 | 184 | trainer.save_checkpoint(os.path.join(cfg.experiment.checkpoint.path, f"{run_name}_final.pt")) 185 | -------------------------------------------------------------------------------- /src/state/tx/callbacks/model_flops_utilization.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from typing import Any, Dict, Optional 4 | 5 | import torch 6 | from lightning import LightningModule, Trainer 7 | from lightning.fabric.utilities.throughput import Throughput, measure_flops 8 | from lightning.pytorch.callbacks import Callback 9 | 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | 13 | 14 | class ModelFLOPSUtilizationCallback(Callback): 15 | """ 16 | PyTorch Lightning callback to measure and log Model FLOPS Utilization (MFU). 17 | 18 | - Measures FLOPs once on the first training batch using `measure_flops`. 19 | - Tracks rolling throughput metrics via `Throughput` with a window equal to 20 | the user input window size. 21 | - Logs MFU to the trainer loggers (e.g., W&B) at the same cadence as other metrics. 22 | 23 | Args: 24 | available_flops: Theoretical peak flops for device in TFLOPS, example: enter 60e12 for 60 TFLOPS. 25 | use_backward: If True, include backward pass FLOPs in the measurement. 26 | logging_interval: The interval at which to log MFU. 27 | cell_set_len: The length of the cell set. 28 | window_size: The size of the rolling window. 
29 | """ 30 | 31 | def __init__( 32 | self, 33 | *, 34 | available_flops: Optional[float] = None, 35 | use_backward: bool = False, 36 | logging_interval: int = 50, 37 | cell_set_len: Optional[int] = None, 38 | window_size: int = 20, 39 | ) -> None: 40 | super().__init__() 41 | self.available_flops = available_flops 42 | print(f"ModelFLOPSUtilizationCallback: Using available flops: {self.available_flops}") 43 | self.use_backward = use_backward 44 | print(f"ModelFLOPSUtilizationCallback: Using use_backward: {self.use_backward}") 45 | self.logging_interval = logging_interval 46 | print(f"ModelFLOPSUtilizationCallback: Using logging interval: {self.logging_interval}") 47 | self.cell_set_len = cell_set_len 48 | print(f"ModelFLOPSUtilizationCallback: Using cell set length: {self.cell_set_len}") 49 | 50 | self._throughput: Optional[Throughput] = None 51 | self._window_size: int = window_size 52 | print(f"ModelFLOPSUtilizationCallback: Using window size: {self._window_size}") 53 | self._flops_per_batch: Optional[int] = None 54 | self._measured: bool = False 55 | self._train_start_time: Optional[float] = None 56 | self._cell_sets_len: Optional[int] = None 57 | # Cumulative counters since training start 58 | self._cumulative_time: float = 0.0 59 | self._cumulative_batches: int = 0 60 | self._cumulative_samples: int = 0 61 | 62 | def setup(self, trainer: Trainer, pl_module: Any, stage: str) -> None: 63 | # Initialize throughput tracker 64 | world_size = getattr(trainer, "num_devices") 65 | assert isinstance(world_size, int), f"world_size must be an integer, got {type(world_size)}" 66 | assert world_size > 0, f"world_size must be greater than 0, got {world_size}" 67 | print(f"ModelFLOPSUtilizationCallback: Initializing throughput tracker with world_size: {world_size}") 68 | 69 | self._throughput = Throughput( 70 | available_flops=self.available_flops, 71 | world_size=world_size, 72 | window_size=self._window_size, 73 | ) 74 | # Reset cumulative counters on setup 75 | self._cumulative_time = 0.0 76 | self._cumulative_batches = 0 77 | self._cumulative_samples = 0 78 | 79 | def _infer_batch_size(self, batch: Any) -> int: 80 | """Infer the logical batch size. 81 | 82 | In the cell-load pipeline, the sampler yields flattened batches of size 83 | batch_size * cell_set_len. Divide the leading dimension by cell_set_len to recover the true batch size. 84 | """ 85 | batch_size = batch["pert_cell_emb"].shape[0] 86 | return batch_size // self.cell_set_len 87 | 88 | def _trainstep_forward_backward(self, model: LightningModule, batch: Any) -> torch.Tensor: 89 | """Encapsulate calling StateTransitionPerturbationModel.training_step and backward. 90 | 91 | This intentionally targets StateTransitionPerturbationModel's signature and 92 | performs both forward and backward to capture full FLOPs. 93 | 94 | !!WARNING!! 95 | This has only been tested with StateTransitionPerturbationModel. Behavior with any other model has not been verified. 
96 | """ 97 | # Clean gradients before measuring 98 | model.zero_grad(set_to_none=True) 99 | # Call training_step with the expected signature 100 | loss: torch.Tensor = model.training_step(batch, 0, padded=True) # type: ignore 101 | # Backward to include backward-pass FLOPs 102 | if self.use_backward: 103 | loss.backward() 104 | return loss 105 | 106 | def _measure_flops_once(self, trainer: Trainer, pl_module: Any, batch: Any) -> None: 107 | if self._measured: 108 | return 109 | 110 | model = pl_module 111 | 112 | # Measure FLOPs using a single callable that runs training_step and backward 113 | def forward_fn(): 114 | return self._trainstep_forward_backward(model, batch) 115 | 116 | self._flops_per_batch = int(measure_flops(model, forward_fn=forward_fn)) 117 | print(f"ModelFLOPSUtilizationCallback: Measured FLOPs per batch: {self._flops_per_batch}") 118 | pl_module.log("flops_per_batch", self._flops_per_batch, prog_bar=False, on_step=True, on_epoch=False) 119 | 120 | # Clear gradients before real training continues (safety) 121 | model.zero_grad(set_to_none=True) 122 | 123 | # Expose on the module for visibility/debugging 124 | setattr(pl_module, "flops_per_batch", self._flops_per_batch) 125 | self._measured = True 126 | 127 | def on_train_batch_start(self, trainer: Trainer, pl_module: Any, batch: dict, batch_idx: int) -> None: 128 | # Only calculate FLOPs on the first batch of the first epoch 129 | if not self._measured and batch_idx == 0 and trainer.current_epoch == 0: 130 | self._measure_flops_once(trainer, pl_module, batch) 131 | if torch.cuda.is_available(): 132 | torch.cuda.synchronize() 133 | self._train_start_time = time.time() 134 | 135 | def on_train_batch_end(self, trainer: Trainer, pl_module: Any, outputs: Any, batch: dict, batch_idx: int) -> None: 136 | if self._train_start_time is None or self._throughput is None: 137 | return 138 | 139 | samples = self._infer_batch_size(batch) 140 | 141 | # Update cumulative totals since training start 142 | self._cumulative_batches += 1 143 | self._cumulative_samples += samples 144 | 145 | # Log at a cadence controled by the logging_interval 146 | if batch_idx % self.logging_interval == 0 and batch_idx > 0: 147 | # Synchronize CUDA if available to ensure accurate timing 148 | if torch.cuda.is_available(): 149 | torch.cuda.synchronize() 150 | # Cumulative duration since training start 151 | self._cumulative_time = time.time() - self._train_start_time 152 | 153 | if batch_idx == self.logging_interval: 154 | flops = self._flops_per_batch * (self.logging_interval + 1) # type: ignore 155 | else: 156 | flops = self._flops_per_batch * self.logging_interval # type: ignore 157 | 158 | # Update throughput tracker 159 | self._throughput.update( 160 | time=self._cumulative_time, 161 | batches=self._cumulative_batches, 162 | samples=self._cumulative_samples, 163 | flops=flops, # type: ignore 164 | ) 165 | 166 | metrics: Dict[str, float] = self._throughput.compute() 167 | # Prefer global MFU when available, otherwise device MFU 168 | mfu = metrics.get("global/mfu", metrics.get("device/mfu", None)) 169 | if mfu is not None: 170 | mfu = 100 * mfu 171 | pl_module.log("mfu (%)", mfu, prog_bar=True, on_step=True, on_epoch=False) 172 | 173 | # Log cell_sets (cell_sentences) per second 174 | cell_sets_per_sec = metrics.get("global/samples_per_sec", metrics.get("device/samples_per_sec", None)) 175 | if cell_sets_per_sec is not None: 176 | pl_module.log( 177 | "cell_sets_per_sec", 178 | cell_sets_per_sec, 179 | prog_bar=False, 180 | on_step=True, 181 | 
on_epoch=False, 182 | ) 183 | -------------------------------------------------------------------------------- /tests/test_callbacks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, cast 2 | import pytest 3 | 4 | from state.tx.callbacks import model_flops_utilization as mfu 5 | from state.tx.callbacks.model_flops_utilization import ModelFLOPSUtilizationCallback 6 | from state.tx.callbacks.cumulative_flops import CumulativeFLOPSCallback 7 | import torch 8 | 9 | 10 | class FakeTrainer: 11 | def __init__(self, num_devices: int = 1, current_epoch: int = 0): 12 | self.num_devices = num_devices 13 | self.current_epoch = current_epoch 14 | 15 | 16 | class FakeModel(torch.nn.Module): 17 | def __init__(self, in_dim: int = 8, out_dim: int = 8) -> None: 18 | super().__init__() 19 | # Keep operations simple and deterministic for FLOPs counting 20 | self.weight = torch.nn.Parameter(torch.ones(in_dim, out_dim)) 21 | self.logged = [] 22 | 23 | def forward(self, x: torch.Tensor) -> torch.Tensor: 24 | # Single matmul: (1 x in_dim) @ (in_dim x out_dim) -> (1 x out_dim) 25 | return x @ self.weight 26 | 27 | def training_step(self, batch, idx, padded: bool = True) -> torch.Tensor: 28 | # Derive a batch size from the fake batch to influence FLOPs 29 | # Our tests use cell_set_len=5 and fake_batch has shape (batch_size * cell_set_len, ...) 30 | bsz = batch["pert_cell_emb"].shape[0] // 5 31 | x = torch.ones(bsz, self.weight.shape[0], requires_grad=True) 32 | y = self.forward(x) 33 | return y.sum() 34 | 35 | def log(self, name, value, *, prog_bar=False, on_step=False, on_epoch=False, sync_dist=False): 36 | self.logged.append( 37 | { 38 | "name": name, 39 | "value": value, 40 | "prog_bar": prog_bar, 41 | "on_step": on_step, 42 | "on_epoch": on_epoch, 43 | "sync_dist": sync_dist, 44 | } 45 | ) 46 | 47 | 48 | @pytest.fixture 49 | def fake_model(): 50 | # Function-scoped fake model implementing the minimal interface used by the callback 51 | # Use 1x1 matmul so forward FLOPs are exactly 2 (multiply + add) 52 | return FakeModel(in_dim=1, out_dim=1) 53 | 54 | 55 | @pytest.fixture 56 | def trainer(): 57 | return FakeTrainer(num_devices=2, current_epoch=0) 58 | 59 | 60 | class _Arr: 61 | def __init__(self, shape): 62 | self.shape = shape 63 | 64 | 65 | @pytest.fixture 66 | def fake_batch(): 67 | # Create a flattened batch where total rows = batch_size * cell_set_len 68 | # We'll use batch_size=4 and cell_set_len=5 consistently in tests 69 | return {"pert_cell_emb": _Arr((20, 3))} 70 | 71 | 72 | def test_measure_flops_once_only_first_batch_and_epoch(fake_model, fake_batch): 73 | cb = ModelFLOPSUtilizationCallback(cell_set_len=5, use_backward=False, logging_interval=1, window_size=10) 74 | trainer = FakeTrainer(num_devices=1, current_epoch=0) 75 | # Initialize throughput to avoid None checks elsewhere 76 | cb.setup(cast(Any, trainer), fake_model, stage="fit") 77 | 78 | # First batch, first epoch -> should measure exactly once 79 | cb.on_train_batch_start(cast(Any, trainer), fake_model, fake_batch, batch_idx=0) 80 | first_logs = [e for e in fake_model.logged if e["name"] == "flops_per_batch"] 81 | assert cb._measured is True and len(first_logs) == 1 82 | 83 | # Subsequent batch in same epoch -> no re-measure 84 | cb.on_train_batch_start(cast(Any, trainer), fake_model, fake_batch, batch_idx=1) 85 | assert len([e for e in fake_model.logged if e["name"] == "flops_per_batch"]) == 1 86 | 87 | # First batch of a later epoch -> still no re-measure because already 
measured 88 | trainer.current_epoch = 1 89 | cb.on_train_batch_start(cast(Any, trainer), fake_model, fake_batch, batch_idx=0) 90 | assert len([e for e in fake_model.logged if e["name"] == "flops_per_batch"]) == 1 91 | 92 | 93 | def test_measure_flops_once_counts_forward_and_backward_flops(fake_model, fake_batch): 94 | # Compare forward-only vs forward+backward FLOPs 95 | trainer = FakeTrainer(num_devices=1, current_epoch=0) 96 | 97 | # Forward-only 98 | cb_fwd = ModelFLOPSUtilizationCallback(cell_set_len=5, use_backward=False) 99 | cb_fwd._measured = False 100 | cb_fwd._flops_per_batch = None 101 | cb_fwd._measure_flops_once(cast(Any, trainer), fake_model, fake_batch) 102 | 103 | # Forward + backward 104 | cb_bwd = ModelFLOPSUtilizationCallback(cell_set_len=5, use_backward=True) 105 | cb_bwd._measured = False 106 | cb_bwd._flops_per_batch = None 107 | cb_bwd._measure_flops_once(cast(Any, trainer), fake_model, fake_batch) 108 | 109 | # Expect backward ≈ 2x forward for matmul (dX and dW), so total ≈ forward + 2*forward = 3x forward 110 | assert cb_fwd._flops_per_batch is not None and cb_bwd._flops_per_batch is not None 111 | fwd = cast(int, cb_fwd._flops_per_batch) 112 | bwd = cast(int, cb_bwd._flops_per_batch) 113 | assert bwd == 3 * fwd 114 | # Ensure it was logged on the model 115 | assert any(e["name"] == "flops_per_batch" and e["value"] == cb_bwd._flops_per_batch for e in fake_model.logged) 116 | 117 | 118 | def test_mfu_is_calculated_correctly(fake_model, fake_batch): 119 | # Setup callback with small window for faster MFU computation 120 | cb = ModelFLOPSUtilizationCallback( 121 | cell_set_len=5, 122 | use_backward=False, 123 | logging_interval=5, 124 | available_flops=1000, 125 | window_size=3, 126 | ) 127 | trainer = FakeTrainer(num_devices=1, current_epoch=0) 128 | cb.setup(cast(Any, trainer), fake_model, stage="fit") 129 | 130 | # Set known FLOPs per batch to avoid measurement 131 | cb._measured = True 132 | cb._flops_per_batch = 1000 133 | 134 | # Simulate training with 1 second per batch 135 | start_time = mfu.time.time() 136 | 137 | for batch_idx in range(16): 138 | if batch_idx == 0: 139 | cb.on_train_batch_start(cast(Any, trainer), fake_model, fake_batch, batch_idx=batch_idx) 140 | cb._train_start_time = start_time 141 | else: 142 | cb.on_train_batch_start(cast(Any, trainer), fake_model, fake_batch, batch_idx=batch_idx) 143 | 144 | # Mock time progression 145 | current_time = start_time + (batch_idx + 1) * 1.0 146 | original_time = mfu.time.time 147 | mfu.time.time = lambda: current_time 148 | 149 | try: 150 | cb.on_train_batch_end(cast(Any, trainer), fake_model, outputs=None, batch=fake_batch, batch_idx=batch_idx) 151 | finally: 152 | mfu.time.time = original_time 153 | 154 | # Verify MFU was logged with reasonable values 155 | mfu_logs = [e for e in fake_model.logged if e["name"] == "mfu (%)"] 156 | assert len(mfu_logs) >= 1 157 | 158 | for mfu_log in mfu_logs: 159 | assert 50 <= mfu_log["value"] <= 150 160 | 161 | # Verify samples per second was logged with reasonable values 162 | sps_logs = [e for e in fake_model.logged if e["name"] == "cell_sets_per_sec"] 163 | assert len(sps_logs) >= 1 164 | 165 | for sps_log in sps_logs: 166 | assert 2 <= sps_log["value"] <= 6 167 | 168 | 169 | class TestCumulativeFLOPSCallback: 170 | def test_cumulative_flops_calculation_accuracy(self, fake_model, fake_batch): 171 | """Test that cumulative FLOPs calculation is accurate.""" 172 | cb = CumulativeFLOPSCallback(use_backward=False) 173 | trainer = FakeTrainer(num_devices=1, 
current_epoch=0) 174 | 175 | # Set known FLOPs per batch to avoid measurement 176 | cb._measured = True 177 | cb._flops_per_batch = 1000 178 | 179 | # Simulate 5 training batches 180 | for batch_idx in range(5): 181 | cb.on_train_batch_end(cast(Any, trainer), fake_model, outputs=None, batch=fake_batch, batch_idx=batch_idx) 182 | 183 | # Check cumulative FLOPs 184 | assert cb._cumulative_flops == 5000 185 | assert cb._batch_count == 5 186 | 187 | def test_cumulative_flops_batch_logging(self, fake_model, fake_batch): 188 | """Test that cumulative FLOPs are logged after every training batch.""" 189 | cb = CumulativeFLOPSCallback(use_backward=False) 190 | trainer = FakeTrainer(num_devices=1, current_epoch=0) 191 | 192 | # Set known FLOPs per batch 193 | cb._measured = True 194 | cb._flops_per_batch = 500 195 | 196 | # Simulate some training batches 197 | for batch_idx in range(3): 198 | cb.on_train_batch_end(cast(Any, trainer), fake_model, outputs=None, batch=fake_batch, batch_idx=batch_idx) 199 | 200 | # Should have cumulative FLOPs and logged after each batch 201 | assert cb._cumulative_flops == 1500 202 | 203 | # Check that cumulative_flops was logged 3 times (once per batch) 204 | cumulative_logs = [log for log in fake_model.logged if log["name"] == "cumulative_flops"] 205 | assert len(cumulative_logs) == 3 206 | assert cumulative_logs[0]["value"] == 500.0 # After batch 0 207 | assert cumulative_logs[1]["value"] == 1000.0 # After batch 1 208 | assert cumulative_logs[2]["value"] == 1500.0 # After batch 2 209 | 210 | # Verify logging parameters 211 | for log in cumulative_logs: 212 | assert log["on_step"] is True 213 | assert log["on_epoch"] is False 214 | -------------------------------------------------------------------------------- /src/state/emb/train/callbacks.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from typing import Any, Optional 4 | 5 | import torch 6 | import lightning as L 7 | from lightning.fabric.utilities.throughput import measure_flops 8 | 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.INFO) 11 | 12 | 13 | class LogLR(L.Callback): 14 | def __init__(self, interval=10): 15 | super().__init__() 16 | self.interval = interval 17 | 18 | def on_train_batch_start( 19 | self, 20 | trainer: L.Trainer, 21 | pl_module: L.LightningModule, 22 | *args, 23 | ) -> None: 24 | if trainer.global_rank == 0: 25 | if trainer.global_step % self.interval == 0 and trainer.logger is not None: 26 | trainer.logger.log_metrics( 27 | {"trainer/learning_rate": pl_module.lr_schedulers().get_last_lr()[0]}, 28 | step=trainer.global_step, 29 | ) 30 | 31 | 32 | class PerfProfilerCallback(L.Callback): 33 | def __init__(self): 34 | super().__init__() 35 | self.batch_start_time = None 36 | self.batch_times = [] 37 | self.iterations_count = 0 38 | self.last_ipm_time = None 39 | self.ipm_history = [] 40 | 41 | def on_train_batch_start(self, trainer: L.Trainer, pl_module, batch, batch_idx): 42 | self.batch_start_time = time.time() 43 | 44 | def on_train_batch_end(self, trainer: L.Trainer, pl_module, outputs, batch, batch_idx): 45 | current_time = time.time() 46 | 47 | # Calculate batch time 48 | if self.batch_start_time: 49 | batch_time = current_time - self.batch_start_time 50 | self.batch_times.append(batch_time) 51 | 52 | # Track iterations per minute 53 | self.iterations_count += 1 54 | if self.last_ipm_time is None: 55 | self.last_ipm_time = current_time 56 | 57 | time_diff = current_time - self.last_ipm_time 
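# --------------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of callbacks.py. The block below
# converts the raw iteration counter into an iterations-per-minute (ipm) reading
# once at least 60 seconds have elapsed, then resets both counters so each
# reading covers a fresh window. Worked example with made-up numbers:
example_iterations_count = 150
example_time_diff = 75.0  # seconds since the last ipm reading
example_ipm = (example_iterations_count / example_time_diff) * 60  # -> 120.0 iterations per minute
# [End of editor's note]
# --------------------------------------------------------------------------------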
58 | if time_diff >= 60: 59 | ipm = (self.iterations_count / time_diff) * 60 60 | self.ipm_history.append(ipm) 61 | trainer.logger.log_metrics({"perf/ipm": ipm}, step=trainer.global_step) 62 | # Reset counters 63 | self.iterations_count = 0 64 | self.last_ipm_time = current_time 65 | 66 | 67 | class ProfilerCallback(L.Callback): 68 | def __init__(self, cfg): 69 | super().__init__() 70 | self.batch_start_time = None 71 | self.batch_times = [] 72 | self.iterations_count = 0 73 | self.last_ipm_time = None 74 | self.ipm_history = [] 75 | self.cfg = cfg 76 | 77 | self.profile_steps = cfg.experiment.profile.profile_steps 78 | 79 | def on_train_batch_start(self, trainer: L.Trainer, pl_module, batch, batch_idx): 80 | self.batch_start_time = time.time() 81 | if batch_idx == self.profile_steps[0]: 82 | logging.info(f"Starting NSys profiling at step {batch_idx}") 83 | torch.cuda.nvtx.range_push("VCIProfiledSection") 84 | 85 | def on_train_batch_end(self, trainer: L.Trainer, pl_module, outputs, batch, batch_idx): 86 | current_time = time.time() 87 | 88 | # Calculate batch time 89 | if self.batch_start_time: 90 | batch_time = current_time - self.batch_start_time 91 | self.batch_times.append(batch_time) 92 | 93 | # Track iterations per minute 94 | self.iterations_count += 1 95 | if self.last_ipm_time is None: 96 | self.last_ipm_time = current_time 97 | 98 | time_diff = current_time - self.last_ipm_time 99 | if time_diff >= 60: 100 | ipm = (self.iterations_count / time_diff) * 60 101 | self.ipm_history.append(ipm) 102 | trainer.logger.log_metrics({"perf/ipm": ipm}, step=trainer.global_step) 103 | # Reset counters 104 | self.iterations_count = 0 105 | self.last_ipm_time = current_time 106 | 107 | if batch_idx == self.profile_steps[1]: 108 | logging.info(f"Stopping NSys profiling at step {batch_idx}") 109 | torch.cuda.nvtx.range_pop() 110 | 111 | 112 | class ResumeCallback(L.Callback): 113 | def __init__(self, cfg): 114 | super().__init__() 115 | self._cfg = cfg 116 | 117 | def on_train_start(self, trainer, pl_module): 118 | if self._cfg.optimizer.get("reset_lr_on_restart", False): 119 | for optimizer in trainer.optimizers: 120 | for param_group in optimizer.param_groups: 121 | original_lr = param_group.get("lr", None) 122 | param_group["lr"] = self._cfg.optimizer.max_lr 123 | logging.info(f"Reset learning rate from {original_lr} to {param_group['lr']}") 124 | 125 | 126 | class EMACallback(L.Callback): 127 | def __init__(self, decay: float = 0.999): 128 | super().__init__() 129 | self.beta = decay 130 | self.velocity = {} 131 | 132 | def on_before_optimizer_step(self, trainer: L.Trainer, pl_module: L.LightningModule, optimizer): 133 | # Check if EMA is enabled via the config flag. 134 | if pl_module.cfg.model.get("ema", False): 135 | with torch.no_grad(): 136 | for param in pl_module.parameters(): 137 | if param.grad is None: 138 | continue 139 | 140 | param_id = id(param) 141 | if param_id not in self.velocity: 142 | self.velocity[param_id] = torch.zeros_like(param.grad) 143 | 144 | self.velocity[param_id] = self.beta * self.velocity[param_id] + (1 - self.beta) * param.grad 145 | param.grad = self.velocity[param_id].clone() 146 | 147 | 148 | class CumulativeFLOPSCallback(L.Callback): 149 | """ 150 | PyTorch Lightning callback to track cumulative FLOPS during SE training. 151 | 152 | - Measures FLOPs once on the first training batch using `measure_flops`. 153 | - Tracks cumulative FLOPs and logs at validation frequency. 
154 | - Logs cumulative_flops to trainer loggers (e.g., W&B, CSV) at validation cadence. 155 | 156 | Args: 157 | use_backward: If True, include backward pass FLOPs in the measurement. 158 | """ 159 | 160 | def __init__( 161 | self, 162 | *, 163 | use_backward: bool = False, 164 | ) -> None: 165 | super().__init__() 166 | self.use_backward = use_backward 167 | 168 | self._flops_per_batch: Optional[int] = None 169 | self._measured: bool = False 170 | self._cumulative_flops: int = 0 171 | self._batch_count: int = 0 172 | 173 | def _trainstep_forward_backward(self, model: L.LightningModule, batch: Any) -> torch.Tensor: 174 | """Encapsulate calling StateEmbeddingModel.training_step and backward. 175 | 176 | This intentionally targets StateEmbeddingModel's signature and 177 | performs both forward and backward to capture full FLOPs. 178 | 179 | !!WARNING!! 180 | This has only been tested with StateEmbeddingModel. Behavior with any other model has not been verified. 181 | """ 182 | model.zero_grad(set_to_none=True) 183 | loss: torch.Tensor = model.training_step(batch, 0) # type: ignore 184 | if self.use_backward: 185 | loss.backward() 186 | return loss 187 | 188 | def _measure_flops_once(self, trainer: L.Trainer, pl_module: Any, batch: Any) -> None: 189 | if self._measured: 190 | return 191 | 192 | model = pl_module 193 | 194 | def forward_fn(): 195 | return self._trainstep_forward_backward(model, batch) 196 | 197 | self._flops_per_batch = int(measure_flops(model, forward_fn=forward_fn)) 198 | logger.info(f"CumulativeFLOPSCallback: Measured FLOPs per batch: {self._flops_per_batch}") 199 | 200 | model.zero_grad(set_to_none=True) 201 | self._measured = True 202 | 203 | def on_train_batch_start(self, trainer: L.Trainer, pl_module: Any, batch: dict, batch_idx: int) -> None: 204 | if not self._measured and batch_idx == 0 and trainer.current_epoch == 0: 205 | self._measure_flops_once(trainer, pl_module, batch) 206 | 207 | def on_train_batch_end(self, trainer: L.Trainer, pl_module: Any, outputs: Any, batch: dict, batch_idx: int) -> None: 208 | if self._flops_per_batch is None: 209 | return 210 | 211 | self._batch_count += 1 212 | self._cumulative_flops += self._flops_per_batch 213 | 214 | # Log cumulative FLOPs after every training batch 215 | pl_module.log( 216 | "cumulative_flops", 217 | float(self._cumulative_flops), 218 | prog_bar=False, 219 | on_step=True, 220 | on_epoch=False, 221 | sync_dist=True, 222 | ) 223 | 224 | def on_validation_start(self, trainer: L.Trainer, pl_module: Any) -> None: 225 | if self._flops_per_batch is None: 226 | return 227 | 228 | # Log cumulative FLOPs at validation frequency for W&B panel alignment 229 | pl_module.log( 230 | "cumulative_flops_val_sync", 231 | float(self._cumulative_flops), 232 | prog_bar=False, 233 | on_step=False, 234 | on_epoch=True, 235 | sync_dist=True, 236 | ) 237 | -------------------------------------------------------------------------------- /src/state/configs/state-defaults.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | name: vci_pretrain_${loss.name}_${model.nhead}_${model.nlayers} 3 | local: local 4 | compiled: false 5 | deaware: false 6 | cumulative_flops_use_backward: true 7 | profile: 8 | enable_profiler: false 9 | profile_steps: 10 | - 10 11 | - 100 12 | max_steps: 110 13 | num_epochs: 16 14 | num_nodes: 1 15 | num_gpus_per_node: 1 16 | port: 12400 17 | val_check_interval: 1000 18 | limit_val_batches: 100 19 | ddp_timeout: 3600 20 | checkpoint: 21 | path: 
/scratch/ctc/ML/vci/checkpoint/pretrain 22 | save_top_k: 4 23 | monitor: trainer/train_loss 24 | every_n_train_steps: 1000 25 | wandb: 26 | enable: true 27 | project: vci 28 | embeddings: 29 | current: esm2-cellxgene 30 | esm2-cellxgene: 31 | all_embeddings: /large_storage/ctc/ML/data/cell/misc/Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt 32 | ds_emb_mapping: /large_storage/ctc/datasets/vci/training/gene_embidx_mapping.torch 33 | valid_genes_masks: null 34 | size: 5120 35 | num: 19790 36 | esm2-cellxgene-basecamp-tahoe: 37 | all_embeddings: /large_storage/ctc/ML/data/cell/misc/Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt 38 | ds_emb_mapping: /large_storage/ctc/datasets/updated1_gene_embidx_mapping_tahoe_basecamp_cellxgene.torch 39 | valid_genes_masks: /large_storage/ctc/datasets/updated1_valid_gene_index_tahoe_basecamp_cellxgene.torch 40 | size: 5120 41 | num: 19790 42 | esm2-cellxgene-tahoe: 43 | all_embeddings: /large_storage/ctc/ML/data/cell/misc/Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt 44 | ds_emb_mapping: /large_storage/ctc/datasets/updated1_gene_embidx_mapping_tahoe_basecamp_cellxgene.torch 45 | valid_genes_masks: /large_storage/ctc/datasets/updated1_valid_gene_index_tahoe_basecamp_cellxgene.torch 46 | size: 5120 47 | num: 19790 48 | evo2-scbasecamp: 49 | all_embeddings: /large_storage/ctc/projects/vci/scbasecamp/Evo2/all_species_Evo2.torch 50 | ds_emb_mapping: /large_storage/ctc/projects/vci/scbasecamp/Evo2/dataset_emb_idx_Evo2_fixed.torch 51 | valid_genes_masks: /large_storage/ctc/projects/vci/scbasecamp/Evo2/valid_gene_index_Evo2.torch 52 | size: 4096 53 | num: 503178 54 | esm2-scbasecamp: 55 | all_embeddings: /large_storage/ctc/projects/vci/scbasecamp/ESM2/all_species_ESM2.torch 56 | ds_emb_mapping: /large_storage/ctc/projects/vci/scbasecamp/ESM2/dataset_emb_idx_ESM2.torch 57 | valid_genes_masks: /large_storage/ctc/projects/vci/scbasecamp/ESM2/valid_gene_index_ESM2.torch 58 | size: 1280 59 | num: 503178 60 | esm2_3B-scbasecamp: 61 | all_embeddings: /large_storage/ctc/projects/vci/scbasecamp/ESM2_3B/all_species.torch 62 | ds_emb_mapping: /large_storage/ctc/projects/vci/scbasecamp/ESM2_3B/dataset_emb_idx.torch 63 | valid_genes_masks: /large_storage/ctc/projects/vci/scbasecamp/ESM2_3B/valid_gene_index.torch 64 | size: 2560 65 | num: 503178 66 | esm2_3B-scbasecamp_cellxgene: 67 | all_embeddings: /large_storage/ctc/projects/vci/scbasecamp/ESM2_3B/all_species.torch 68 | ds_emb_mapping: /home/alishbaimran/scbasecamp/dataset_emb_idx_ESM2_copy.torch 69 | valid_genes_masks: /home/alishbaimran/scbasecamp/valid_gene_index.torch 70 | size: 2560 71 | num: 503178 72 | llm-cellxgene: 73 | all_embeddings: /home/aadduri/state-diffusion/all_embeddings_llm-cellxgene.pt 74 | ds_emb_mapping: /home/aadduri/state-diffusion/ds_emb_mapping_llm-cellxgene.torch 75 | valid_genes_masks: /home/aadduri/state-diffusion/valid_genes_masks_llm-cellxgene.torch 76 | size: 5120 77 | num: 19790 78 | test-cellxgene: 79 | all_embeddings: /home/aadduri/state-diffusion/all_embeddings_test-cellxgene.pt 80 | ds_emb_mapping: /home/aadduri/state-diffusion/ds_emb_mapping_test-cellxgene.torch 81 | valid_genes_masks: /home/aadduri/state-diffusion/valid_genes_masks_test-cellxgene.torch 82 | size: 5120 83 | num: 19790 84 | all_data: 85 | all_embeddings: /home/aadduri/state-diffusion/all_human_data_filtered/all_embeddings_all_data.pt 86 | ds_emb_mapping: /home/aadduri/state-diffusion/all_human_data_filtered/ds_emb_mapping_all_data.torch 87 | valid_genes_masks: 
/home/aadduri/state-diffusion/all_human_data_filtered/valid_genes_masks_all_data.torch 88 | size: 5120 89 | num: 19790 90 | SE-167M: 91 | all_embeddings: /home/aadduri/state/SE-167M-Data/all_embeddings_SE-167M.pt 92 | ds_emb_mapping: /home/aadduri/state/SE-167M-Data/ds_emb_mapping_SE-167M.torch 93 | valid_genes_masks: /home/aadduri/state/SE-167M-Data/valid_genes_masks_SE-167M.torch 94 | size: 5120 95 | num: 19790 96 | validations: 97 | diff_exp: 98 | enable: true 99 | eval_interval_multiple: 10 100 | obs_pert_col: gene 101 | obs_filter_label: non-targeting 102 | top_k_rank: 200 103 | method: null 104 | dataset: /large_storage/ctc/datasets/cellxgene/processed/rpe1_top5000_variable.h5ad 105 | dataset_name: rpe1_top5000_variable 106 | perturbation: 107 | enable: true 108 | eval_interval_multiple: 10 109 | pert_col: gene 110 | ctrl_label: non-targeting 111 | dataset: /large_storage/ctc/datasets/cellxgene/processed/rpe1_top5000_variable.h5ad 112 | dataset_name: rpe1_top5000_variable 113 | dataset: 114 | name: vci 115 | seed: 42 116 | num_train_workers: 16 117 | num_val_workers: 4 118 | current: cellxgene 119 | cellxgene: 120 | data_dir: /large_experiments/goodarzilab/mohsen/cellxgene/processed 121 | ds_type: h5ad 122 | filter: false 123 | train: /scratch/ctc/ML/uce/h5ad_train_dataset.csv 124 | val: /scratch/ctc/ML/uce/h5ad_val_dataset.csv 125 | num_datasets: 1139 126 | scbasecamp: 127 | ds_type: h5ad 128 | train: /home/alishbaimran/scbasecamp/scbasecamp_all.csv 129 | val: /home/alishbaimran/scbasecamp/scbasecamp_all.csv 130 | filter: true 131 | filter_by_species: null 132 | scbasecamp-cellxgene: 133 | ds_type: h5ad 134 | train: /home/alishbaimran/scbasecamp/scBasecamp_cellxgene_all.csv 135 | val: /home/alishbaimran/scbasecamp/scBasecamp_cellxgene_all.csv 136 | filter: true 137 | filter_by_species: null 138 | scbasecamp-cellxgene-tahoe-filtered: 139 | ds_type: h5ad 140 | train: /large_storage/ctc/userspace/rohankshah/19kfilt_combined_train.csv 141 | val: /large_storage/ctc/userspace/rohankshah/19kfilt_combined_val.csv 142 | filter: true 143 | filter_by_species: null 144 | num_datasets: 14420 145 | scbasecamp-cellxgene-tahoe: 146 | ds_type: h5ad 147 | train: /large_storage/ctc/datasets/scbasecamp_filtered_tahoe_cellxgene_train.csv 148 | val: /large_storage/ctc/datasets/scbasecamp_filtered_tahoe_cellxgene_val.csv 149 | filter: true 150 | filter_by_species: null 151 | num_datasets: 15700 152 | cellxgene-tahoe: 153 | ds_type: h5ad 154 | train: /large_storage/ctc/datasets/tahoe_cellxgene_train.csv 155 | val: /large_storage/ctc/datasets/tahoe_cellxgene_val.csv 156 | filter: true 157 | filter_by_species: null 158 | num_datasets: 1139 159 | tahoe: 160 | ds_type: h5ad 161 | train: /scratch/ctc/ML/uce/full_train_datasets.csv 162 | val: /scratch/ctc/ML/uce/full_train_datasets.csv 163 | filter: true 164 | valid_genes_masks: null 165 | tahoe-h5ad: 166 | ds_type: h5ad 167 | train: /scratch/ctc/ML/uce/h5ad_train_dataset_tahoe.csv 168 | val: /scratch/ctc/ML/uce/h5ad_val_dataset_tahoe.csv 169 | filter: true 170 | valid_genes_masks: null 171 | pad_length: 2048 172 | pad_token_idx: 0 173 | cls_token_idx: 3 174 | chrom_token_right_idx: 2 175 | P: 512 176 | 'N': 512 177 | S: 512 178 | num_cells: 36238464 179 | overrides: 180 | rpe1_top5000_variable: /large_storage/ctc/datasets/vci/validation/rpe1_top5000_variable.h5ad 181 | llm-cellxgene: 182 | ds_type: h5ad 183 | train: /home/aadduri/state-diffusion/train_llm-cellxgene.csv 184 | val: /home/aadduri/state-diffusion/val_llm-cellxgene.csv 185 | filter: true 186 | 
num_datasets: 569 187 | test-cellxgene: 188 | ds_type: h5ad 189 | train: /home/aadduri/state-diffusion/train_test-cellxgene.csv 190 | val: /home/aadduri/state-diffusion/val_test-cellxgene.csv 191 | filter: true 192 | num_datasets: 4 193 | all_data: 194 | ds_type: h5ad 195 | train: /home/aadduri/state-diffusion/all_human_data_filtered/train_all_data.csv 196 | val: /home/aadduri/state-diffusion/all_human_data_filtered/val_all_data.csv 197 | filter: true 198 | num_datasets: 31195 199 | SE-167M: 200 | ds_type: h5ad 201 | train: /home/aadduri/state/SE-167M-Data/train_SE-167M.csv 202 | val: /home/aadduri/state/SE-167M-Data/val_SE-167M.csv 203 | filter: true 204 | num_datasets: 14418 205 | tokenizer: 206 | token_dim: 5120 207 | model: 208 | name: vci 209 | batch_size: 128 210 | emsize: 512 211 | d_hid: 1024 212 | nhead: 16 213 | nlayers: 8 214 | dropout: 0.1 215 | output_dim: 512 216 | use_flash_attention: true 217 | rda: true 218 | counts: true 219 | dataset_correction: true 220 | ema: false 221 | ema_decay: 0.999 222 | ema_update_interval: 1000 223 | sample_rda: false 224 | batch_tabular_loss: false 225 | num_downsample: 1 226 | variable_masking: true 227 | task: 228 | mask: 0.2 229 | optimizer: 230 | max_lr: 1.0e-05 231 | weight_decay: 0.01 232 | start: 0.33 233 | end: 1.0 234 | max_grad_norm: 0.8 235 | gradient_accumulation_steps: 8 236 | reset_lr_on_restart: false 237 | zclip: false 238 | loss: 239 | name: tabular 240 | apply_normalization: false 241 | kernel: energy 242 | uniformity: false 243 | -------------------------------------------------------------------------------- /MODEL_LICENSE.md: -------------------------------------------------------------------------------- 1 | # Arc Research Institute State Model Non-Commercial License 2 | 3 | **_Last Updated: June 23, 2025_** 4 | 5 | Arc Research Institute (“**Licensor**”) is releasing certain materials collectively defined below as the “**State Model**”. This license (the “**License**”) is between Licensor and any individual or legal entity exercising permissions granted under this License (“**You**”). 6 | 7 | By downloading, copying, executing, or otherwise exercising any permission granted by this License or using the State Model, You acknowledge that You accept and agree to be bound by the terms and conditions of this License. If use of the State Model is within the scope of Your employment, You hereby represent that you are fully authorized to enter into this Agreement on behalf of any such employer. 8 | 9 | Licensor may make changes to this License. The “_Last Updated_” date above indicates when this License was last changed. The amended License will be effective immediately, and Your continued use of the State Model after such update will confirm Your acceptance of the changes. If You do not agree to the amended License, You must immediately stop using the State Model. 10 | 11 | ## 1. Definitions 12 | 13 | 1.1 “**Commercial Entity**” means any natural person, legal entity, organization, or institution that is for-profit, or otherwise not a Non-Commercial Entity. 14 | 15 | 1.2 “**Derivative Work**” means any modification of, work based upon, or derived from, the State Model or any portion thereof, including adaptations, fine-tunings, adjusted model weights, distilled or pruned versions, additional training, or any other alteration, transformation, or integration, including use as a prior, of the State Model or any part of it. 
16 | 17 | 1.3 “**Non-Commercial Entity**” means a (a) government body or (b) a natural person, academic institution, research institute, or other entity registered with relevant tax authorities as a not-for-profit organization that, in each (a) and (b) above, does not own or control or is not owned or controlled by and is not acting with or on behalf of a Commercial Entity. 18 | 19 | 1.4 “**Non-Commercial Purpose**” means a purpose that is not undertaken for direct or indirect monetary compensation, financial consideration, or any other commercial undertaking or advantage. For the purposes of this License, research, development, or educational activities engaged in or sponsored by a Commercial Entity (whether or not performed by a Non-Commercial Entity or otherwise undertaken for a purpose not associated with monetary compensation, financial consideration or other commercial undertaking or advantage) are not Non-Commercial Purposes. 20 | 21 | 1.5 “**Output**” means any results, data, or other material generated by execution of the unmodified State Model or a Derivative Work. 22 | 23 | 1.6 “**State Model**” includes, individually and collectively, the pretrained model (_i.e._, the model code combined with the trained model weights), model weights, documentation, and any other files released by Licensor under this License (provided that source code and object code are licensed under a separate license made available by Licensor). 24 | 25 | ## 2. Grant of Rights 26 | 27 | 2.1 **License** Subject to the terms and conditions of this License, Licensor hereby grants You a personal, worldwide, royalty-free, non-exclusive, irrevocable (except as provided in Section 8), non-sublicensable license: (a) to use, copy, modify, and display the State Model; (b) to create Derivative Works and Outputs; (c) to reproduce, distribute, and make publicly available (i) the State Model and any Derivative Work, each subject to Section 3 (Redistribution Conditions), and (ii) any Output; and (d) under the Licensor’s patent rights covering the State Model, to make, use, import, and otherwise exploit the State Model and any Derivative Works, each (a) through (c)(i) and (d), solely for Non-Commercial Purposes. 28 | 29 | 2.2 **Patent Infringement** If You initiate patent litigation (including a cross-claim or counterclaim in a lawsuit) alleging that the State Model or a Derivative Work infringes a patent, this License to You is terminated as of the date such litigation is filed. 30 | 31 | 2.3 **Reservation of Rights**. All rights other than those expressly granted to You herein are reserved by Licensor. 32 | 33 | ## 3. Redistribution Conditions 34 | 35 | If You redistribute the State Model or any Derivative Work, You must (a) include a complete, unaltered copy of this License; (b) keep intact all original copyright, license, and disclaimer notices; and (c) provide prominent notice stating that the redistributed State Model is governed by this License. If You redistribute the State Model, any Derivative Work or any Output, You must also include a prominent and readable citation to the State Model research paper: Adduri, A. et al. (2025) Predicting cellular responses to perturbation across diverse contexts with State. You may choose to license a Derivative Work or Output under different terms and conditions, provided that the limitations and restrictions herein (including, in the case of a Derivative Work, the Non-Commercial Purpose) remain in substance and effect. 36 | 37 | ## 4. 
Restrictions 38 | 39 | 4.1 You may not use or distribute the State Model or any Derivative Work for any purpose other than Non-Commercial Purposes, unless You first obtain a separate, express, written license from Licensor granting such rights. If you wish to inquire about such separate license, you may contact . 40 | 41 | 4.2 You may not remove, obfuscate, or alter any copyright, trademark, citation or other proprietary notice included in the State Model. 42 | 43 | 4.3 You must comply with all applicable laws and regulations, including export-control and sanctions regimes, in connection with Your use of the State Model and any Derivative Work or Output. 44 | 45 | 4.4 If You elect to collect or process personal data, health data or any other regulated data using the State Model or any Derivative Work, You are solely responsible for compliance with all data-protection, privacy, and related laws. 46 | 47 | ## 5. Feedback 48 | 49 | If You provide Licensor with ideas, suggestions, or improvements regarding the State Model (“**Feedback**”), Licensor may incorporate such Feedback without obligation or restriction, and You hereby irrevocably assign to Licensor all right, title, and interest in and to such Feedback and all intellectual property rights therein and thereto. To the extent any rights in and to such Feedback cannot be assigned by You, You hereby grant to Licensor a nonexclusive, worldwide, perpetual, irrevocable, transferable, sublicensable, royalty-free, fully paid-up license to use, practice, and otherwise exercise the Feedback without restriction. 50 | 51 | ## 6. Trademarks 52 | 53 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the State Model and reproducing the content of the license as required herein. 54 | 55 | ## 7. Disclaimer of Warranty; Limitation of Liability 56 | 57 | THE STATE MODEL, ANY DERIVATIVE WORK, AND ANY OUTPUT ARE PROVIDED “AS IS” AND “WITH ALL FAULTS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, OR THAT OPERATION OF THE STATE MODEL WILL BE UNINTERRUPTED OR ERROR-FREE. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, LICENSOR SHALL NOT BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY, OR OTHERWISE, FOR ANY INDIRECT, SPECIAL, INCIDENTAL, CONSEQUENTIAL, OR EXEMPLARY DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE OR THE USE OF THE STATE MODEL, DERIVATIVE WORKS, OR OUTPUTS, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. LICENSOR’S TOTAL AGGREGATE LIABILITY UNDER THIS LICENSE SHALL NOT EXCEED TWENTY U.S. DOLLARS (USD $20). 58 | 59 | ## 8. Termination 60 | 61 | 8.1 This License and the rights granted herein terminate automatically without notice if You breach Sections 2, 3, or 4. Upon termination, You shall cease all use, reproduction, and distribution of the State Model, any Derivative Work, and any Output, and Sections 2.3, 4.4, 5, 7, 8.2, and 9 shall remain in effect. 62 | 63 | 8.2 Termination shall not affect any third-party’s legitimate possession, under this License, of copies of the State Model, Derivative Works, or Outputs obtained from You prior to termination, provided such third party is itself in full compliance with this License. 64 | 65 | ## 9. Miscellaneous 66 | 67 | 9.1 **Governing Law**.
This License shall be governed by and construed in accordance with the laws of the state of California, excluding its or any other jurisdiction’s conflict-of-laws principles. Venue for any disputes related to this License shall be in the federal courts in Santa Clara County, California and You and Licensor hereby agree to the exclusive jurisdiction of such courts. 68 | 69 | 9.2 **Severability**. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable, and the remaining provisions shall remain in full force and effect. 70 | 71 | 9.3 **No Waiver**. Failure by Licensor to enforce any provision of this License shall not constitute a waiver of future enforcement of that or any other provision. 72 | 73 | 9.4 **Entire Agreement**. This License constitutes the entire agreement between the parties with respect to the State Model and supersedes all prior or contemporaneous understandings, whether written or oral, regarding such subject matter. References herein to “including” will be interpreted to mean “including but not limited to” and any examples provided are illustrative and not intended to be the sole examples of a particular concept. 74 | 75 | -------------------------------------------------------------------------------- /src/state/tx/models/scvi/_base_modules.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.distributions import Normal 6 | from torch.nn import functional as F 7 | 8 | 9 | class FocalLoss(nn.Module): 10 | """Inspired by https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py 11 | 12 | Focal Loss, as described in https://arxiv.org/abs/1708.02002. 13 | It is essentially an enhancement to cross entropy loss and is 14 | useful for classification tasks when there is a large class imbalance. 15 | x is expected to contain raw, unnormalized scores for each class. 16 | y is expected to contain class labels. 17 | Shape: 18 | - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0. 19 | - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | alpha: Optional[torch.Tensor] = None, 25 | gamma: float = 2.0, 26 | reduction: str = "mean", 27 | ): 28 | """ 29 | Args: 30 | alpha (Tensor, optional): Weights for each class. Defaults to None. 31 | gamma (float, optional): A constant, as described in the paper. 32 | Defaults to 2.0. 33 | reduction (str, optional): 'mean', 'sum' or 'none'. 34 | Defaults to 'mean'.
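            Note: the loss implemented in ``forward`` is ``-alpha * (1 - pt) ** gamma * log(pt)``, i.e. the (optionally alpha-weighted) cross entropy scaled by the focal term ``(1 - pt) ** gamma``.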
35 | """ 36 | if reduction not in ("mean", "sum", "none"): 37 | raise ValueError('Reduction must be one of: "mean", "sum", "none".') 38 | 39 | super().__init__() 40 | self.alpha = alpha 41 | self.gamma = gamma 42 | self.reduction = reduction 43 | 44 | self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none") 45 | 46 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 47 | if len(y_true) == 0: 48 | return torch.tensor(0.0) 49 | 50 | # compute weighted cross entropy term: -alpha * log(pt) 51 | # (alpha is already part of self.nll_loss) 52 | log_p = F.log_softmax(y_pred, dim=-1) 53 | ce = self.nll_loss(log_p, y_true) 54 | 55 | # get true class column from each row 56 | all_rows = torch.arange(len(y_pred)) 57 | log_pt = log_p[all_rows, y_true] 58 | 59 | # compute focal term: (1 - pt)^gamma 60 | pt = log_pt.exp() 61 | focal_term = (1 - pt) ** self.gamma 62 | 63 | # the full loss: -alpha * ((1 - pt)^gamma) * log(pt) 64 | loss = focal_term * ce 65 | 66 | if self.reduction == "mean": 67 | loss = loss.mean() 68 | elif self.reduction == "sum": 69 | loss = loss.sum() 70 | 71 | return loss 72 | 73 | 74 | class MLP(nn.Module): 75 | def __init__( 76 | self, 77 | n_input, 78 | n_output, 79 | n_hidden, 80 | n_layers, 81 | activation_fn: Optional[nn.Module] = nn.ReLU, 82 | use_norm: str = "batch", 83 | dropout_rate: float = 0.3, 84 | drop_norm_last_layer: bool = True, 85 | ): 86 | super().__init__() 87 | if drop_norm_last_layer: 88 | layers = [n_input] + [n_hidden] * n_layers 89 | else: 90 | layers = [n_input] + [n_hidden] * (n_layers - 1) + [n_output] 91 | 92 | network = [] 93 | for n_in, n_out in zip(layers[:-1], layers[1:]): 94 | network.append(nn.Linear(n_in, n_out)) 95 | if use_norm == "batch": 96 | network.append(nn.BatchNorm1d(n_out)) 97 | elif use_norm == "layer": 98 | network.append(nn.LayerNorm(n_out)) 99 | network.append(activation_fn()) 100 | network.append(nn.Dropout(dropout_rate)) 101 | 102 | if drop_norm_last_layer: 103 | network.append(nn.Linear(n_hidden, n_output)) 104 | 105 | self.network = nn.Sequential(*network) 106 | 107 | def forward(self, x): 108 | """ 109 | x: (batch_size, n_input) 110 | """ 111 | return self.network(x) 112 | 113 | 114 | class Classifier(nn.Module): 115 | def __init__( 116 | self, 117 | n_input, 118 | n_labels, 119 | n_hidden, 120 | n_layers, 121 | activation_fn=nn.ReLU, 122 | use_norm: str = "batch", 123 | dropout_rate: float = 0.3, 124 | ): 125 | super().__init__() 126 | self.n_output = n_labels 127 | 128 | self.network = MLP( 129 | n_input=n_input, 130 | n_output=n_labels, 131 | n_layers=n_layers, 132 | n_hidden=n_hidden, 133 | use_norm=use_norm, 134 | dropout_rate=dropout_rate, 135 | activation_fn=activation_fn, 136 | drop_norm_last_layer=True, 137 | ) 138 | 139 | def forward(self, x): 140 | y = self.network(x) 141 | return y 142 | 143 | 144 | class VariationalEncoder(nn.Module): 145 | def __init__( 146 | self, 147 | n_input: int, 148 | n_output: int, 149 | n_layers: int = 1, 150 | n_hidden: int = 128, 151 | dropout_rate: float = 0.1, 152 | use_norm: str = "batch", 153 | var_eps: float = 1e-4, 154 | var_activation=None, 155 | return_dist: bool = False, 156 | **kwargs, 157 | ): 158 | super().__init__() 159 | 160 | self.var_eps = var_eps 161 | self.encoder = MLP( 162 | n_input=n_input, 163 | n_output=n_hidden, 164 | n_layers=n_layers, 165 | n_hidden=n_hidden, 166 | dropout_rate=dropout_rate, 167 | use_norm=use_norm, 168 | drop_norm_last_layer=False, 169 | ) 170 | self.mean_encoder = nn.Linear(n_hidden, n_output) 171 | self.var_encoder 
= nn.Linear(n_hidden, n_output) 172 | self.return_dist = return_dist 173 | 174 | self.var_activation = torch.exp if var_activation is None else var_activation 175 | 176 | def forward(self, x: torch.Tensor, *cat_list: int): 177 | """ """ 178 | q = self.encoder(x, *cat_list) 179 | 180 | q_m = self.mean_encoder(q) 181 | q_v = self.var_activation(self.var_encoder(q)) + self.var_eps 182 | 183 | dist = Normal(q_m, q_v.sqrt()) 184 | latent = dist.rsample() 185 | 186 | if self.return_dist: 187 | return dist, latent 188 | 189 | return q_m, q_v, latent 190 | 191 | 192 | # Inspired by scvi-tools source code: https://github.com/scverse/scvi-tools/blob/d094c9b3c14e8cb3ac3a309b9cf0160aff237393/scvi/nn/_base_components.py 193 | class CountDecoder(nn.Module): 194 | """Decodes data from latent space of ``n_input`` dimensions into ``n_output`` dimensions. 195 | 196 | Uses a fully-connected neural network of ``n_hidden`` layers. 197 | 198 | Parameters 199 | ---------- 200 | n_input 201 | The dimensionality of the input (latent space) 202 | n_output 203 | The dimensionality of the output (data space) 204 | n_cat_list 205 | A list containing the number of categories 206 | for each category of interest. Each category will be 207 | included using a one-hot encoding 208 | n_layers 209 | The number of fully-connected hidden layers 210 | n_hidden 211 | The number of nodes per hidden layer 212 | dropout_rate 213 | Dropout rate to apply to each of the hidden layers 214 | inject_covariates 215 | Whether to inject covariates in each layer, or just the first (default). 216 | use_batch_norm 217 | Whether to use batch norm in layers 218 | use_layer_norm 219 | Whether to use layer norm in layers 220 | scale_activation 221 | Activation layer to use for px_scale_decoder 222 | """ 223 | 224 | def __init__( 225 | self, 226 | n_input: int, 227 | n_output: int, 228 | n_layers: int = 1, 229 | n_hidden: int = 128, 230 | use_norm: Literal["batch", "layer"] = "batch", 231 | scale_activation: Literal["softmax", "softplus"] = "softmax", 232 | ): 233 | super().__init__() 234 | self.px_decoder = MLP( 235 | n_input=n_input, 236 | n_output=n_hidden, 237 | n_layers=n_layers, 238 | n_hidden=n_hidden, 239 | dropout_rate=0.0, 240 | use_norm=use_norm, 241 | drop_norm_last_layer=False, 242 | ) 243 | 244 | # mean gamma 245 | if scale_activation == "softmax": 246 | px_scale_activation = nn.Softmax(dim=-1) 247 | elif scale_activation == "softplus": 248 | px_scale_activation = nn.Softplus() 249 | self.px_scale_decoder = nn.Sequential( 250 | nn.Linear(n_hidden, n_output), 251 | px_scale_activation, 252 | ) 253 | 254 | # dispersion: here we only deal with gene-cell dispersion case 255 | self.px_r_decoder = nn.Linear(n_hidden, n_output) 256 | 257 | # dropout 258 | self.px_dropout_decoder = nn.Linear(n_hidden, n_output) 259 | 260 | def forward( 261 | self, 262 | dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"], 263 | z: torch.Tensor, 264 | library: torch.Tensor, 265 | ): 266 | """The forward computation for a single sample. 267 | 268 | #. Decodes the data from the latent space using the decoder network 269 | #. Returns parameters for the ZINB distribution of expression 270 | #. 
If ``dispersion != 'gene-cell'`` then value for that param will be ``None`` 271 | 272 | Parameters 273 | ---------- 274 | z : 275 | tensor with shape ``(n_input,)`` 276 | library 277 | log library size; its exponential scales the decoded proportions into rates 278 | cat_list 279 | list of category membership(s) for this sample (carried over from the scvi-tools interface; unused here) 280 | dispersion 281 | One of the following 282 | 283 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells 284 | * ``'gene-batch'`` - dispersion can differ between different batches 285 | * ``'gene-label'`` - dispersion can differ between different labels 286 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell 287 | 288 | Returns 289 | ------- 290 | 4-tuple of :py:class:`torch.Tensor` 291 | parameters for the ZINB distribution of expression 292 | 293 | """ 294 | # The decoder returns values for the parameters of the ZINB distribution 295 | px = self.px_decoder(z) 296 | px_scale = self.px_scale_decoder(px) 297 | px_dropout = self.px_dropout_decoder(px) 298 | # Clamp to high value: exp(12) ~ 160000 to avoid nans (computational stability) 299 | px_rate = torch.exp(library) * px_scale # torch.clamp( , max=12) 300 | px_r = self.px_r_decoder(px) if dispersion == "gene-cell" else None 301 | return px_scale, px_r, px_rate, px_dropout 302 | -------------------------------------------------------------------------------- /src/state/tx/models/old_neural_ot.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | from geomloss import SamplesLoss 5 | 6 | from .base import PerturbationModel 7 | from .utils import build_mlp, get_activation_class, get_transformer_backbone 8 | 9 | 10 | class OldNeuralOTPerturbationModel(PerturbationModel): 11 | """ 12 | This model: 13 | 1) Projects basal expression and perturbation encodings into a shared latent space. 14 | 2) Uses an OT-based distributional loss (energy, sinkhorn, etc.) from geomloss. 15 | 3) Enables cells to attend to one another, learning a set-to-set function rather than 16 | a sample-to-sample single-cell map. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | input_dim: int, 22 | hidden_dim: int, 23 | output_dim: int, 24 | pert_dim: int, 25 | predict_residual: bool = True, 26 | distributional_loss: str = "energy", 27 | transformer_backbone_key: str = "GPT2", 28 | transformer_backbone_kwargs: dict = None, 29 | output_space: str = "gene", 30 | gene_dim: Optional[int] = None, 31 | **kwargs, 32 | ): 33 | """ 34 | Args: 35 | input_dim: dimension of the input expression (e.g. number of genes or embedding dimension). 36 | hidden_dim: hidden width used by the encoder/decoder MLPs and the transformer latent space. 37 | output_dim: dimension of the output space (genes or latent). 38 | pert_dim: dimension of perturbation embedding. 39 | transformer_backbone_key: key selecting the transformer backbone (e.g. "GPT2"). 40 | transformer_backbone_kwargs: dictionary passed to that backbone's constructor. 41 | distributional_loss: choice of distributional metric ("sinkhorn", "energy", etc.). 42 | **kwargs: anything else to pass up to PerturbationModel or not used.
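            Note: transformer_backbone_kwargs must provide ``n_positions``, which this model reuses as the cell sentence length when grouping cells into sets.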
43 | """ 44 | # Call the parent PerturbationModel constructor 45 | super().__init__( 46 | input_dim=input_dim, 47 | hidden_dim=hidden_dim, 48 | gene_dim=gene_dim, 49 | output_dim=output_dim, 50 | pert_dim=pert_dim, 51 | output_space=output_space, 52 | **kwargs, 53 | ) 54 | 55 | # Save or store relevant hyperparams 56 | self.predict_residual = predict_residual 57 | self.n_encoder_layers = kwargs.get("n_encoder_layers", 2) 58 | self.n_decoder_layers = kwargs.get("n_decoder_layers", 2) 59 | self.activation_class = get_activation_class(kwargs.get("activation", "gelu")) 60 | self.transformer_backbone_key = transformer_backbone_key 61 | self.transformer_backbone_kwargs = transformer_backbone_kwargs 62 | self.distributional_loss = distributional_loss 63 | self.cell_sentence_len = self.transformer_backbone_kwargs["n_positions"] 64 | self.gene_dim = gene_dim 65 | 66 | # Build the distributional loss from geomloss 67 | self.loss_fn = SamplesLoss(loss=self.distributional_loss) 68 | # self.loss_fn = LearnableAlignmentLoss() 69 | 70 | # Build the underlying neural OT network 71 | self._build_networks() 72 | 73 | def _build_networks(self): 74 | """ 75 | Here we instantiate the actual GPT2-based model or any neuralOT translator 76 | via your old get_model(model_key, model_kwargs) approach. 77 | """ 78 | self.pert_encoder = build_mlp( 79 | in_dim=self.pert_dim, 80 | out_dim=self.hidden_dim, 81 | hidden_dim=self.hidden_dim, 82 | n_layers=self.n_encoder_layers, 83 | dropout=self.dropout, 84 | activation=self.activation_class, 85 | ) 86 | 87 | # Map the input embedding to the hidden space 88 | self.basal_encoder = build_mlp( 89 | in_dim=self.input_dim, 90 | out_dim=self.hidden_dim, 91 | hidden_dim=self.hidden_dim, 92 | n_layers=self.n_encoder_layers, 93 | dropout=self.dropout, 94 | activation=self.activation_class, 95 | ) 96 | 97 | self.transformer_backbone, self.transformer_model_dim = get_transformer_backbone( 98 | self.transformer_backbone_key, 99 | self.transformer_backbone_kwargs, 100 | ) 101 | 102 | self.project_out = build_mlp( 103 | in_dim=self.hidden_dim, 104 | out_dim=self.output_dim, 105 | hidden_dim=self.hidden_dim, 106 | n_layers=self.n_decoder_layers, 107 | dropout=self.dropout, 108 | activation=self.activation_class, 109 | ) 110 | 111 | print(self) 112 | 113 | def encode_perturbation(self, pert: torch.Tensor) -> torch.Tensor: 114 | """If needed, define how we embed the raw perturbation input.""" 115 | return self.pert_encoder(pert) 116 | 117 | def encode_basal_expression(self, expr: torch.Tensor) -> torch.Tensor: 118 | """Define how we embed basal state input, if needed.""" 119 | return self.basal_encoder(expr) 120 | 121 | def perturb(self, pert: torch.Tensor, basal: torch.Tensor) -> torch.Tensor: 122 | """ 123 | Return the latent perturbed state given the perturbation and basal state. 
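        The perturbation and basal inputs are each encoded to hidden_dim, stacked with a zero-initialized CLS slot into a length-3 sequence, and the transformer's last hidden state at the CLS position is taken as the prediction (added to the encoded basal state when predict_residual is True).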
124 | """ 125 | pert_embedding = self.encode_perturbation(pert).unsqueeze(1) # shape: [batch_size, 1, hidden_dim] 126 | control_cells = self.encode_basal_expression(basal).unsqueeze(1) # shape: [batch_size, 1, hidden_dim] 127 | cls_input = torch.zeros_like(pert_embedding) # shape: [batch_size, 1, hidden_dim] 128 | seq_input = torch.cat([pert_embedding, control_cells, cls_input], dim=1) # shape: [batch_size, 3, hidden_dim] 129 | 130 | # forward pass + extract CLS last hidden state 131 | prediction = self.transformer_backbone(inputs_embeds=seq_input).last_hidden_state[:, -1] 132 | 133 | # add to basal if predicting residual 134 | if self.predict_residual: 135 | # treat the actual prediction as a residual sum to basal 136 | return prediction + control_cells.squeeze(1) 137 | else: 138 | return prediction 139 | 140 | def forward(self, batch: dict) -> torch.Tensor: 141 | """ 142 | The main forward call. 143 | """ 144 | prediction = self.perturb(batch["pert_emb"], batch["ctrl_cell_emb"]) 145 | output = self.project_out(prediction) 146 | 147 | return output 148 | 149 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: 150 | """Training step logic for both main model and decoder.""" 151 | # Get model predictions (in latent space) 152 | pred = self(batch) 153 | pred = pred.reshape(-1, self.cell_sentence_len, self.output_dim) 154 | # TODO: please improve this, do not assume self.cell_sentence_len for this model 155 | target = batch["pert_cell_emb"] 156 | target = target.reshape(-1, self.cell_sentence_len, self.output_dim) 157 | main_loss = self.loss_fn(pred, target).mean() 158 | self.log("train_loss", main_loss) 159 | 160 | # Process decoder if available 161 | decoder_loss = None 162 | if self.gene_decoder is not None and "pert_cell_counts" in batch: 163 | # Train decoder to map latent predictions to gene space 164 | with torch.no_grad(): 165 | latent_preds = pred.detach() # Detach to prevent gradient flow back to main model 166 | 167 | pert_cell_counts_preds = self.gene_decoder(latent_preds) 168 | gene_targets = batch["pert_cell_counts"] 169 | gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) 170 | decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() 171 | 172 | # Log decoder loss 173 | self.log("decoder_loss", decoder_loss) 174 | 175 | total_loss = main_loss + decoder_loss 176 | else: 177 | total_loss = main_loss 178 | 179 | return total_loss 180 | 181 | def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: 182 | """Validation step logic.""" 183 | pred = self(batch) 184 | pred = pred.reshape(-1, self.cell_sentence_len, self.output_dim) 185 | target = batch["pert_cell_emb"] 186 | target = target.reshape(-1, self.cell_sentence_len, self.output_dim) 187 | loss = self.loss_fn(pred, target).mean() 188 | self.log("val_loss", loss) 189 | 190 | return {"loss": loss, "predictions": pred} 191 | 192 | def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0) -> None: 193 | """Track decoder performance during validation without training it.""" 194 | if self.gene_decoder is not None and "pert_cell_counts" in batch: 195 | # Get model predictions from validation step 196 | latent_preds = outputs["predictions"] 197 | 198 | # Train decoder to map latent predictions to gene space 199 | pert_cell_counts_preds = self.gene_decoder(latent_preds) # verify this is automatically detached 200 | gene_targets = batch["pert_cell_counts"] 201 | 202 | # Get decoder predictions 203 | 
pert_cell_counts_preds = pert_cell_counts_preds.reshape(-1, self.cell_sentence_len, self.gene_dim) 204 | gene_targets = gene_targets.reshape(-1, self.cell_sentence_len, self.gene_dim) 205 | decoder_loss = self.loss_fn(pert_cell_counts_preds, gene_targets).mean() 206 | 207 | # Log the validation metric 208 | self.log("decoder_val_loss", decoder_loss) 209 | 210 | def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: 211 | pred = self.forward(batch) # forward() only accepts the batch dict 212 | target = batch["pert_cell_emb"] 213 | pred = pred.reshape(1, -1, self.output_dim) 214 | target = target.reshape(1, -1, self.output_dim) 215 | loss = self.loss_fn(pred, target).mean() 216 | self.log("test_loss", loss) 217 | pred = pred.reshape(-1, self.output_dim) 218 | target = target.reshape(-1, self.output_dim) 219 | 220 | def predict_step(self, batch, batch_idx, padded=True, **kwargs): 221 | """ 222 | Typically used for final inference. We'll replicate old logic: 223 | returning 'preds', 'pert_cell_emb', 'pert_name', etc. 224 | """ 225 | latent_output = self.forward(batch) # shape [B, ...] 226 | output_dict = { 227 | "preds": latent_output, 228 | "pert_cell_emb": batch.get("pert_cell_emb", None), 229 | "pert_cell_counts": batch.get("pert_cell_counts", None), 230 | "pert_name": batch.get("pert_name", None), 231 | "celltype_name": batch.get("cell_type", None), 232 | "batch": batch.get("batch", None), 233 | "ctrl_cell_emb": batch.get("ctrl_cell_emb", None), 234 | "pert_cell_barcode": batch.get("pert_cell_barcode", None), 235 | } 236 | 237 | if self.gene_decoder is not None: 238 | pert_cell_counts_preds = self.gene_decoder(latent_output) 239 | output_dict["pert_cell_counts_preds"] = pert_cell_counts_preds 240 | 241 | return output_dict 242 | --------------------------------------------------------------------------------
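A minimal, self-contained sketch of the set-level distributional loss pattern used in OldNeuralOTPerturbationModel.training_step: per-cell predictions and targets are grouped into sets of cell_sentence_len cells and compared as distributions with geomloss.SamplesLoss. The set size, latent dimension, and number of sets below are illustrative assumptions; only the reshape-then-loss pattern mirrors the model code.

# Illustrative sketch with made-up dimensions; mirrors the reshape-then-loss pattern
# of OldNeuralOTPerturbationModel.training_step.
import torch
from geomloss import SamplesLoss

cell_sentence_len = 128  # assumed set size (the backbone's n_positions)
output_dim = 64          # assumed latent dimension
num_sets = 4

# Flat per-cell predictions and targets, as produced inside training_step
pred = torch.randn(num_sets * cell_sentence_len, output_dim)
target = torch.randn(num_sets * cell_sentence_len, output_dim)

# Group cells into sets so the loss compares populations rather than single cells
pred_sets = pred.reshape(-1, cell_sentence_len, output_dim)
target_sets = target.reshape(-1, cell_sentence_len, output_dim)

loss_fn = SamplesLoss(loss="energy")            # same default as distributional_loss="energy"
per_set_loss = loss_fn(pred_sets, target_sets)  # one loss value per set
print(per_set_loss.mean())                      # scalar, as logged under train_loss

Comparing whole sets with an energy (or Sinkhorn) distance is what lets the model learn a set-to-set map, as described in the class docstring, rather than forcing a one-to-one pairing between predicted and observed cells.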