├── README.md
├── analysis.spatial_KO.xenium_human_glioblastoma_gpu.ipynb
├── celcomen
│   ├── .ipynb_checkpoints
│   │   └── __init__-checkpoint.py
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-310.pyc
│   ├── datareaders
│   │   ├── __init__.py
│   │   └── datareader.py
│   ├── models
│   │   ├── .ipynb_checkpoints
│   │   │   ├── __init__-checkpoint.py
│   │   │   ├── celcomen-checkpoint.py
│   │   │   └── simcomen-checkpoint.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── celcomen.cpython-310.pyc
│   │   │   ├── celcomen.cpython-39.pyc
│   │   │   ├── simcomen.cpython-310.pyc
│   │   │   └── simcomen.cpython-39.pyc
│   │   ├── celcomen.py
│   │   └── simcomen.py
│   ├── training_plan
│   │   ├── .ipynb_checkpoints
│   │   │   └── train-checkpoint.py
│   │   └── train.py
│   └── utils
│       ├── .ipynb_checkpoints
│       │   └── helpers-checkpoint.py
│       ├── __init__.py
│       └── helpers.py
├── docs
│   ├── .ipynb_checkpoints
│   │   ├── Makefile-checkpoint
│   │   └── make-checkpoint.bat
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── requirements.txt
│   └── source
│       └── .ipynb_checkpoints
│           ├── conf-checkpoint.py
│           └── index-checkpoint.rst
├── images
│   └── disentangling graphs and gene colocalization-2.png
├── pyproject.toml
└── readthedocs.yaml

/README.md:
--------------------------------------------------------------------------------
# Cell Communication Energy (celcomen)
Causal generative model designed to disentangle intercellular and intracellular gene regulation with theoretical identifiability guarantees. Celcomen can generate counterfactual spatial transcriptomic samples by simulating the effect of local perturbations, such as gene activations/inhibitions or cell insertions/deletions.

You can find out more by reading our [manuscript](https://arxiv.org/abs/2409.05804).

Installation
============
Conda Environment
--
We recommend using [Anaconda](https://www.anaconda.com/)/[Miniconda](https://docs.conda.io/projects/miniconda/en/latest/) to create a conda environment for celcomen. You can create a Python environment with the following command:

    conda create -n celcomen_env python=3.9

Then activate the environment:

    conda activate celcomen_env

Install celcomen
--
Then install celcomen from GitHub:
```
pip install git+https://github.com/stathismegas/celcomen
```

Causal Disentanglement and Spatial Counterfactuals
============
To learn intracellular and extracellular gene regulation and then use it to simulate inflammation counterfactuals in specific locations of the tissue, follow the tutorial `analysis.perturbation_newest_celcomen.ipynb`.

As explained in the tutorial, the adata object should contain raw count data, without any prior normalization or log-transformation.

To speed up training on a GPU, refer to the tutorial `train_using_dataloaders_gpu.ipynb`.
--------------------------------------------------------------------------------
/celcomen/__init__.py:
--------------------------------------------------------------------------------
"""Method for causal disentanglement of inter-cellular and intra-cellular gene regulation in spatial transcriptomics"""
__version__ = "0.0.1"

from . import datareaders
from . import models
from . import utils
--------------------------------------------------------------------------------
/celcomen/datareaders/__init__.py:
--------------------------------------------------------------------------------
from . import datareader
--------------------------------------------------------------------------------
/celcomen/datareaders/datareader.py:
--------------------------------------------------------------------------------
from torch_geometric.loader import DataLoader
import scanpy as sc
import scipy.sparse
import torch
import torch_geometric
from sklearn.neighbors import kneighbors_graph
import numpy as np

from scipy.spatial.distance import pdist, squareform

def get_dataset_loaders(h5ad_path: str, sample_id_name: str, n_neighbors: int, distance: float, device: str, verbose: bool):
    """
    Prepares and returns a PyTorch Geometric DataLoader from a single-cell spatial transcriptomics dataset.

    The function reads a single-cell AnnData object from an H5AD file, scales each cell's expression
    vector to unit L2 norm, and builds one graph per sample, where nodes correspond to cells and edges
    connect spatially close cells (either a k-nearest neighbours graph or a fixed-radius graph).
    The graphs are then loaded into a PyTorch Geometric `DataLoader`.

    Parameters
    ----------
    h5ad_path : str
        Path to the H5AD file containing the raw counts of the single-cell spatial transcriptomics data.
    sample_id_name : str
        Name of the sample ID column in `adata.obs` used to split the dataset into samples (one graph per sample).
    n_neighbors : int or None
        Number of neighbours used to construct the k-nearest neighbours graph. Takes precedence over
        `distance` when both are given.
    distance : float or None
        Radius, in the units of `adata.obsm["spatial"]`, within which two cells are connected.
        Used only when `n_neighbors` is None.
    device : str
        Device on which the graph tensors are placed (e.g. "cpu" or "cuda").
    verbose : bool
        If True, prints detailed information about the DataLoader during the loading process.

    Returns
    -------
    DataLoader
        A PyTorch Geometric DataLoader containing the processed graph data, with each graph representing
        a sample of cells in the dataset.

    Notes
    -----
    - The spatial positions of the cells are used to connect each cell either to its `n_neighbors`
      nearest cells or to every cell closer than `distance`.
    - The input features for the graph (`x`) are scaled to unit L2 norm per cell before constructing the graph.
    - `adata.obsm["spatial"]` is used to extract the spatial coordinates of the cells.
    - The graph data is validated using PyTorch Geometric's built-in validation method.

    Examples
    --------
    >>> loader = get_dataset_loaders('data.h5ad', 'sample_id', n_neighbors=6, distance=None, device='cpu', verbose=True)
    Step 1
    =====
    Number of graphs in the current batch: 1
    Data(x=[100, 33500], edge_index=[2, 600], pos=[100, 2], y=[1])

    Raises
    ------
    ValueError
        If neither `n_neighbors` nor `distance` is given, or if a graph fails validation
        (e.g. it is not well-formed).
    """

    adata = sc.read_h5ad(h5ad_path)

    # raw counts are used as-is: no library-size normalisation or log-transformation here

    if n_neighbors is None and distance is None:
        raise ValueError("Provide either `n_neighbors` or `distance` to build the spatial graph.")

    # one AnnData view per sample
    adata_list = [adata[adata.obs[sample_id_name] == i] for i in adata.obs[sample_id_name].unique()]

    data_list = []

    for sample in adata_list:
        coords = np.asarray(sample.obsm["spatial"])
        pos = torch.from_numpy(coords)
        # densify the count matrix only if it is stored as a sparse matrix
        counts = sample.X.toarray() if scipy.sparse.issparse(sample.X) else np.asarray(sample.X)
        x = torch.from_numpy(counts.astype(np.float32))
        # normalise x: scale every cell to unit L2 norm (clamped to avoid dividing by zero for empty cells)
        norm_factor = torch.linalg.norm(x, dim=1, keepdim=True).clamp_min(1e-12)
        x = torch.div(x, norm_factor)
        y = torch.Tensor([0])
        if n_neighbors is not None:
            # k-nearest neighbours graph: each cell is connected to its n_neighbors closest cells
            knn = kneighbors_graph(coords, n_neighbors=n_neighbors, include_self=False)
            edge_index = torch.from_numpy(np.vstack(knn.nonzero())).long()
        else:
            # radius graph, e.g. distance=30 connects cells about two cell widths (30 µm) apart
            distances = squareform(pdist(coords))
            edge_index = torch.from_numpy(np.vstack(np.where((distances < distance) & (distances != 0)))).long()
        data = torch_geometric.data.Data(x=x, pos=pos, y=y, edge_index=edge_index)
        data.validate(raise_on_error=True)  # performs basic checks on the graph
        # keep all tensors of a graph on the same device
        data_list.append(data.to(device))

    loader = DataLoader(data_list, batch_size=1, shuffle=True)

    if verbose:
        for step, data in enumerate(loader):
            print(f'Step {step+1}')
            print("=====")
            print(f'Number of graphs in the current batch: {data.num_graphs}')
            print(data)
            print()

    return loader
--------------------------------------------------------------------------------
/celcomen/models/__init__.py:
--------------------------------------------------------------------------------
from . import celcomen
from . import simcomen
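The `z_interaction` method in `celcomen.py` evaluates the interaction part of the mean-field log partition function, N * log((exp(g) - exp(-g)) / g), where N is the number of spots and g is the norm computed in `log_Z_mft`. For g > 20 it switches to N * (g - log(g)), because exp(g) overflows float64 for large g. The NumPy sketch below illustrates why that switch is safe. It is not package code: the three helper names are hypothetical, and the overflow-free variant, based on the identity log(exp(g) - exp(-g)) = g + log1p(-exp(-2g)) (exact for g > 0), is shown only as one possible alternative to a hard threshold, not as what the package does.

```python
import numpy as np


def z_interaction_naive(num_spots, g):
    """Hypothetical helper: direct evaluation, overflows once exp(g) exceeds float64 range."""
    return num_spots * np.log((np.exp(g) - np.exp(-g)) / g)


def z_interaction_stable(num_spots, g):
    """Hypothetical helper: same quantity via log(e^g - e^-g) = g + log1p(-e^(-2g))."""
    return num_spots * (g + np.log1p(-np.exp(-2.0 * g)) - np.log(g))


def z_interaction_asymptotic(num_spots, g):
    """Hypothetical helper: the large-g form, num_spots * (g - log(g))."""
    return num_spots * (g - np.log(g))


num_spots = 100

# at the switch point g = 20 the naive and asymptotic forms agree to machine precision
print(abs(z_interaction_naive(num_spots, 20.0) - z_interaction_asymptotic(num_spots, 20.0)))

# for very large g the naive form overflows, while the rewritten form stays finite
with np.errstate(over="ignore"):
    print(z_interaction_naive(num_spots, 800.0))
print(z_interaction_stable(num_spots, 800.0))
```

At g = 20 the neglected correction is N * log1p(-exp(-40)), roughly N * 4e-18, so the threshold introduces no visible discontinuity.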
--------------------------------------------------------------------------------
/celcomen/models/celcomen.py:
--------------------------------------------------------------------------------
import torch
from torch_geometric.nn import GCNConv
import numpy as np

# define the celcomen class
class celcomen(torch.nn.Module):
    """
    A k-hop Graph Neural Network model for disentangling inter- and intra-cellular gene regulation, and then leveraging it to predict spatial counterfactuals.

    Parameters
    ----------
    input_dim : int
        Dimensionality of the input features (gene expression data).
    output_dim : int
        Dimensionality of the output features. Must equal `input_dim`, because `log_Z_mft`
        contracts the gene-to-gene matrices with the mean expression vector on both sides.
    n_neighbors : int
        The number of neighbours used in the spatial graph to model cell-cell interactions.
    seed : int, optional
        Seed for random number generation to ensure reproducibility. Default is 0.

    Attributes
    ----------
    conv1 : GCNConv
        A graph convolutional layer that models intercellular gene-to-gene (G2G) interactions.
    lin : torch.nn.Linear
        A linear layer that models intracellular gene regulation.
    n_neighbors : int
        The number of neighbours for spatial graph construction.
    gex : torch.nn.Parameter or None
        Stores the gene expression matrix used for the forward pass. Set to None initially.

    Methods
    -------
    set_g2g(g2g)
        Sets the gene-to-gene (G2G) interaction matrix artificially.
    set_g2g_intra(g2g_intra)
        Sets the intracellular regulation matrix artificially.
    set_gex(gex)
        Sets the gene expression matrix artificially.
    forward(edge_index, batch)
        Forward pass to compute the gene-to-gene and intracellular messages,
        and the log partition function estimate.
    log_Z_mft(edge_index, batch)
        Computes the Mean Field Theory (MFT) approximation to the log partition function.
    z_interaction(num_spots, g)
        Provides an approximation for the interaction term in the partition function
        to prevent numerical instability due to exploding exponentials.

    Examples
    --------
    >>> model = celcomen(input_dim=50, output_dim=50, n_neighbors=6, seed=42)
    >>> edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
    >>> batch = torch.zeros(10, dtype=torch.long)
    >>> model.set_gex(torch.randn(10, 50))
    >>> msg, msg_intra, log_z_mft = model(edge_index, batch)
    >>> print(log_z_mft)
    """

    def __init__(self, input_dim, output_dim, n_neighbors, seed=0):
        """
        Initializes the celcomen model with a graph convolution layer and a linear
        layer for gene-to-gene and intracellular regulation, respectively.

        Parameters
        ----------
        input_dim : int
            Dimensionality of the input features.
        output_dim : int
            Dimensionality of the output features (must equal `input_dim`).
        n_neighbors : int
            Number of neighbours for constructing the spatial graph.
        seed : int, optional
            Random seed for reproducibility (default is 0).
        """
        super(celcomen, self).__init__()
        # define the seed
        torch.manual_seed(seed)
        # set up the graph convolution for intercellular gene-to-gene interactions
        self.conv1 = GCNConv(input_dim, output_dim, add_self_loops=False)
        # set up the linear layer for intracellular gene regulation
        self.lin = torch.nn.Linear(input_dim, output_dim)
        # define the neighbors
        self.n_neighbors = n_neighbors
        # define a tracking variable for the gene expression x matrix
        self.gex = None

    # define a function to artificially set the g2g matrix
    def set_g2g(self, g2g):
        """
        Artificially sets the gene-to-gene (G2G) interaction matrix.

        Parameters
        ----------
        g2g : torch.Tensor
            A matrix representing gene-to-gene interactions to be used for graph convolution.
        """
        # set the weight as the input
        self.conv1.lin.weight = torch.nn.Parameter(g2g, requires_grad=True)
        # and then set the bias as all zeros
        self.conv1.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(len(g2g)).astype('float32')), requires_grad=False)

    # define a function to artificially set the intracellular g2g matrix
    def set_g2g_intra(self, g2g_intra):
        """
        Artificially sets the intracellular gene regulation matrix.

        Parameters
        ----------
        g2g_intra : torch.Tensor
            A matrix representing intracellular gene regulation interactions.
        """
        # set the weight as the input
        self.lin.weight = torch.nn.Parameter(g2g_intra, requires_grad=True)
        # and then set the bias as all zeros
        self.lin.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(len(g2g_intra)).astype('float32')), requires_grad=False)

    # define a function to artificially set the gex matrix
    def set_gex(self, gex):
        """
        Sets the gene expression matrix to be used during the forward pass.

        Parameters
        ----------
        gex : torch.Tensor
            A matrix representing the gene expression of the cells.
        """
        self.gex = torch.nn.Parameter(gex, requires_grad=False)

    # define the forward pass
    def forward(self, edge_index, batch):
        """
        Forward pass for the model, computing gene-to-gene and intracellular messages,
        and estimating the log partition function using Mean Field Theory (MFT).

        Parameters
        ----------
        edge_index : torch.Tensor
            Tensor representing the graph edges (connectivity between nodes/cells).
        batch : torch.Tensor
            Batch vector assigning each node to a graph (currently unused by the computation).

        Returns
        -------
        msg : torch.Tensor
            The message propagated between cells based on gene-to-gene interactions.
        msg_intra : torch.Tensor
            The message based on intracellular gene regulation.
        log_z_mft : torch.Tensor
            The Mean Field Theory approximation to the log partition function.
        """
        # compute the message
        msg = self.conv1(self.gex, edge_index)
        # compute intracellular message
        msg_intra = self.lin(self.gex)
        # compute the log z mft
        log_z_mft = self.log_Z_mft(edge_index, batch)
        return msg, msg_intra, log_z_mft

    # define approximation function
    def log_Z_mft(self, edge_index, batch):
        """
        Computes the Mean Field Theory (MFT) approximation to the log partition function,
        i.e. the log normalising constant of the model's distribution over gene expression states.

        Parameters
        ----------
        edge_index : torch.Tensor
            Tensor representing the graph edges (currently unused by the approximation).
        batch : torch.Tensor
            Batch vector assigning each node to a graph (currently unused by the approximation).

        Returns
        -------
        log_z_mft : torch.Tensor
            The log partition function estimated using Mean Field Theory (MFT).

        Notes
        -----
        The approximation assumes that:
        - gene expression values are close to their mean values over the slide;
        - sum_b g_{a,b} m^b > 0 for all a, where m is the mean gene expression and g is the
          gene-gene interaction matrix.
        """
        # retrieve number of spots
        num_spots = self.gex.shape[0]
        # calculate mean gene expression
        mean_genes = torch.mean(self.gex, dim=0).reshape(-1, 1)  # the mean should be per connected graph
        # norm of the effective field (inter- plus intracellular) generated by the mean expression
        g = torch.norm(torch.mm(self.n_neighbors*self.conv1.lin.weight + 2*self.lin.weight, mean_genes))
        # calculate the contribution for mean values
        z_mean = - num_spots * torch.mm(torch.mm(torch.t(mean_genes), self.lin.weight + 0.5 * self.n_neighbors * self.conv1.lin.weight), mean_genes)
        # calculate the contribution of gene interactions
        z_interaction = self.z_interaction(num_spots=num_spots, g=g)
        # add the two contributions
        log_z_mft = z_mean + z_interaction
        return log_z_mft

    def z_interaction(self, num_spots, g):
        """
        Avoids exploding exponentials in the partition function approximation by
        returning an approximate interaction term. For g > 20 the exact term
        num_spots * log((exp(g) - exp(-g)) / g) is replaced by its asymptotic form
        num_spots * (g - log(g)).

        Parameters
        ----------
        num_spots : int
            Number of spots (cells) in the current graph.
        g : torch.Tensor
            Norm of the mean gene expression weighted by the gene-to-gene interaction matrices.

        Returns
        -------
        z_interaction : torch.Tensor
            The approximated interaction term for the partition function.
        """
        if g > 20:
            z_interaction = num_spots * (g - torch.log(g))
        else:
            z_interaction = num_spots * torch.log((torch.exp(g) - torch.exp(-g)) / g)
        return z_interaction
--------------------------------------------------------------------------------
/celcomen/models/simcomen.py:
--------------------------------------------------------------------------------
from torch_geometric.nn import GCNConv
from sklearn.neighbors import kneighbors_graph
import torch
import numpy as np
#from ..utils.helpers import calc_gex

# define the number of neighbors (six for Visium)
n_neighbors = 6
# define the simcomen class
class simcomen(torch.nn.Module):
    """
    A k-hop Graph Neural Network model for predicting spatial counterfactuals, such as localised perturbations.

    Parameters
    ----------
    input_dim : int
        The dimensionality of the input gene expression features.
    output_dim : int
        The dimensionality of the output features after processing through graph convolution and linear layers.
    n_neighbors : int
        The number of neighbors to use in constructing the k-nearest neighbor graph.
    seed : int, optional
        Random seed for reproducibility, default is 0.

    Attributes
    ----------
    conv1 : GCNConv
        Graph convolutional layer for gene-to-gene (G2G) interactions.
    lin : torch.nn.Linear
        Linear layer for intracellular gene regulation.
    n_neighbors : int
        Number of spatial neighbors used for constructing the graph.
    sphex : torch.nn.Parameter or None
        Spherical gene expression matrix, set via `set_sphex`.
    gex : torch.nn.Parameter or None
        Gene expression matrix, calculated from the spherical expression matrix.
    output_dim : int
        Output dimensionality of the model.

    Methods
    -------
    set_g2g(g2g)
        Sets the gene-to-gene (G2G) interaction matrix artificially.
44 | set_g2g_intra(g2g_intra) 45 | Sets the intracellular gene regulation matrix artificially. 46 | set_sphex(sphex) 47 | Sets the spherical gene expression matrix artificially. 48 | forward(edge_index, batch) 49 | Forward pass of the model, calculating messages from gene-to-gene interactions, 50 | intracellular interactions, and the log partition function (log(Z_mft)). 51 | log_Z_mft(edge_index, batch) 52 | Computes the Mean Field Theory (MFT) approximation to the partition function for the current gene expressions. 53 | z_interaction(num_spots, g) 54 | Calculates the interaction term for the partition function while avoiding numerical instability. 55 | calc_gex(sphex) 56 | Converts the spherical gene expression matrix into a regular gene expression matrix. 57 | calc_sphex(gex) 58 | Converts the regular gene expression matrix into a spherical gene expression matrix. 59 | get_pos(n_x, n_y) 60 | Generates a 2D hexagonal grid of positions for spatial modeling. 61 | normalize_g2g(g2g) 62 | Symmetrizes and normalizes the gene-to-gene interaction matrix. 63 | 64 | Examples 65 | -------- 66 | >>> model = simcomen(input_dim=1000, output_dim=100, n_neighbors=6, seed=42) 67 | >>> edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long) 68 | >>> batch = torch.tensor([0, 1], dtype=torch.long) 69 | >>> model.set_sphex(torch.randn(100, 1000)) 70 | >>> msg, msg_intra, log_z_mft = model(edge_index, batch) 71 | >>> print(log_z_mft) 72 | """ 73 | def __init__(self, input_dim, output_dim, n_neighbors, seed=0): 74 | """ 75 | Initializes the `simcomen` model with a graph convolution layer and a linear layer 76 | for gene-to-gene interactions and intracellular regulation, respectively. 77 | 78 | Parameters 79 | ---------- 80 | input_dim : int 81 | Dimensionality of the input features. 82 | output_dim : int 83 | Dimensionality of the output features. 84 | n_neighbors : int 85 | Number of neighbors to use for constructing the spatial graph. 
86 | seed : int, optional 87 | Random seed for reproducibility (default is 0). 88 | """ 89 | super(simcomen, self).__init__() 90 | # define the seed 91 | torch.manual_seed(seed) 92 | # set up the graph convolution 93 | self.conv1 = GCNConv(input_dim, output_dim, add_self_loops=False) 94 | # set up the linear layer for intracellular gene regulation 95 | self.lin = torch.nn.Linear(input_dim, output_dim) 96 | # define the neighbors 97 | self.n_neighbors = n_neighbors 98 | # define tracking variables for the spherical (sphex) and regular (gex) expression matrices 99 | self.sphex = None 100 | self.gex = None 101 | self.output_dim = output_dim 102 | 103 | # define a function to artificially set the g2g matrix 104 | def set_g2g(self, g2g): 105 | """ 106 | Artificially sets the gene-to-gene (G2G) interaction matrix. 107 | 108 | Parameters 109 | ---------- 110 | g2g : torch.Tensor 111 | A tensor representing gene-to-gene interactions to be used in the graph convolution. 112 | """ 113 | # set the weight as the input 114 | self.conv1.lin.weight = torch.nn.Parameter(g2g, requires_grad=False) 115 | # and then set the bias as all zeros, on the same device as the weights 116 | self.conv1.bias = torch.nn.Parameter(torch.zeros(self.output_dim, dtype=torch.float32, device=g2g.device), requires_grad=False) 117 | 118 | # define a function to artificially set the intracellular g2g matrix 119 | def set_g2g_intra(self, g2g_intra): 120 | """ 121 | Artificially sets the intracellular regulation matrix. 122 | 123 | Parameters 124 | ---------- 125 | g2g_intra : torch.Tensor 126 | A tensor representing intracellular gene regulation interactions.
127 | """ 128 | # set the weight as the input 129 | self.lin.weight = torch.nn.Parameter(g2g_intra, requires_grad=False) 130 | # and then set the bias as all zeros, on the same device as the weights 131 | self.lin.bias = torch.nn.Parameter(torch.zeros(len(g2g_intra), dtype=torch.float32, device=g2g_intra.device), requires_grad=False) 132 | 133 | # define a function to artificially set the sphex matrix 134 | def set_sphex(self, sphex): 135 | """ 136 | Sets the spherical gene expression matrix for the forward pass. 137 | 138 | Parameters 139 | ---------- 140 | sphex : torch.Tensor 141 | A tensor of angles, of shape (n_spots, n_genes - 1), representing the spherical expression matrix. 142 | """ 143 | self.sphex = torch.nn.Parameter(sphex, requires_grad=True) 144 | 145 | # define the forward pass 146 | def forward(self, edge_index, batch): 147 | """ 148 | Forward pass of the model. Calculates the messages between nodes using gene-to-gene interactions 149 | and intracellular gene regulation, and the log partition function using Mean Field Theory. 150 | 151 | Parameters 152 | ---------- 153 | edge_index : torch.Tensor 154 | Tensor representing the graph edges (connections between nodes/cells). 155 | batch : torch.Tensor or int 156 | Batch assignment of the nodes; currently not used in the computation (the training functions pass 1). 157 | 158 | Returns 159 | ------- 160 | msg : torch.Tensor 161 | Message passed between nodes based on gene-to-gene interactions. 162 | msg_intra : torch.Tensor 163 | Message passed within nodes based on intracellular gene regulation. 164 | log_z_mft : torch.Tensor 165 | Mean Field Theory approximation of the log partition function.
166 | """ 167 | # compute the gex 168 | self.gex = self.calc_gex(self.sphex) 169 | #print( f"self.gex device is {self.gex.device}") 170 | #print( f"edge_index device is {edge_index.device}") 171 | # compute the message 172 | msg = self.conv1(self.gex, edge_index) 173 | # compute intracellular message 174 | msg_intra = self.lin(self.gex) 175 | # compute the log z mft 176 | log_z_mft = self.log_Z_mft(edge_index, batch) 177 | return msg, msg_intra, log_z_mft 178 | 179 | # define approximation function 180 | def log_Z_mft(self, edge_index, batch): 181 | """ 182 | Computes the Mean Field Theory (MFT) approximation of the log partition function. 183 | This function assumes that gene expression values are close to their mean across the spatial slide. 184 | 185 | Parameters 186 | ---------- 187 | edge_index : torch.Tensor 188 | Tensor representing the graph edges (connections between nodes/cells). 189 | batch : torch.Tensor or int 190 | Batch assignment of the nodes; currently not used in the computation. 191 | 192 | Returns 193 | ------- 194 | log_z_mft : torch.Tensor 195 | Mean Field Theory approximation of the log partition function.
196 | """ 197 | # retrieve number of spots 198 | num_spots = self.gex.shape[0] 199 | # calculate mean gene expression 200 | mean_genes = torch.mean(self.gex, axis=0).reshape(-1,1) # the mean should be per connected graph 201 | # calculate the norm of the sum of mean genes 202 | g = torch.norm(torch.mm( self.n_neighbors*self.conv1.lin.weight + 2*self.lin.weight, mean_genes)) 203 | # calculate the contribution for mean values 204 | z_mean = - num_spots * torch.mm(torch.mm(torch.t(mean_genes), self.lin.weight + 0.5 * self.n_neighbors * self.conv1.lin.weight), mean_genes) 205 | # calculate the contribution gene interactions 206 | z_interaction = self.z_interaction(num_spots=num_spots, g=g) 207 | # add the two contributions 208 | log_z_mft = z_mean + z_interaction 209 | return log_z_mft 210 | 211 | def z_interaction(self, num_spots, g): 212 | """ 213 | Calculates the interaction term for the partition function approximation, avoiding exploding exponentials. 214 | 215 | Parameters 216 | ---------- 217 | num_spots : int 218 | Number of spots (cells) in the dataset. 219 | g : torch.Tensor 220 | Norm of the sum of mean gene expressions weighted by gene-to-gene interactions. 221 | 222 | Returns 223 | ------- 224 | z_interaction : torch.Tensor 225 | Approximated interaction term for the partition function. 226 | """ 227 | if g>20: 228 | z_interaction = num_spots * ( g - torch.log( g) ) 229 | else: 230 | z_interaction = num_spots * torch.log((torch.exp( g) - torch.exp(- g))/( g)) 231 | return z_interaction 232 | 233 | # define a function to derive the gex from the sphex 234 | def calc_gex(self, sphex): 235 | """ 236 | Converts the spherical expression matrix into a regular gene expression matrix. 237 | 238 | Parameters 239 | ---------- 240 | sphex : torch.Tensor 241 | The spherical gene expression matrix. 242 | 243 | Returns 244 | ------- 245 | gex : torch.Tensor 246 | The converted regular gene expression matrix. 
247 | """ 248 | # set up the gex 249 | n_genes = sphex.shape[1]+1 250 | #gex = torch.from_numpy(np.zeros((sphex.shape[0], n_genes)).astype('float32'), device=next(self.parameters()).device) 251 | gex = torch.zeros((sphex.shape[0], n_genes), dtype=torch.float32, device=next(self.parameters()).device) 252 | # compute the gex: gex_i = cos(sphex_i) * prod_{k<i} sin(sphex_k), and the last gene is the product of all sines 253 | for idx in range(n_genes): 254 | if idx == n_genes-1: 255 | gex[:,idx] = torch.sin(sphex[:,idx-1]) 256 | else: 257 | gex[:,idx] = torch.cos(sphex[:,idx]) 258 | for idx_ in range(min(idx, n_genes-2)): # the last gene already holds sin(sphex[:,-1]) 259 | gex[:,idx] *= torch.sin(sphex[:,idx_]) 260 | return torch.nan_to_num(gex) 261 | 262 | # define a function to gather positions 263 | def get_pos(n_x, n_y): 264 | """ 265 | Generates a 2D hexagonal grid of positions for spatial modelling. 266 | 267 | Parameters 268 | ---------- 269 | n_x : int 270 | Number of positions along the x-axis. 271 | n_y : int 272 | Number of positions along the y-axis. 273 | 274 | Returns 275 | ------- 276 | pos : numpy.ndarray 277 | Array of 2D positions for the grid. 278 | """ 279 | xs = np.array([np.arange(0, n_x) + 0.5 if idx % 2 == 0 else np.arange(0, n_x) for idx in range(n_y)]) 280 | # derive the row spacing (horizontal neighbors are 1 apart, diagonal neighbors sqrt(1.5) apart) 281 | y_step = np.sqrt(1**2+0.5**2) 282 | ys = np.array([[y_step * idy] * n_x for idy in range(n_y)]) 283 | # define the positions 284 | pos = np.vstack([xs.flatten(), ys.flatten()]).T 285 | return pos 286 | 287 | 288 | # define a function to normalize the g2g 289 | def normalize_g2g(g2g): 290 | """ 291 | Symmetrizes and normalizes the gene-to-gene interaction matrix. 292 | 293 | Parameters 294 | ---------- 295 | g2g : numpy.ndarray 296 | The gene-to-gene interaction matrix. 297 | 298 | Returns 299 | ------- 300 | g2g : numpy.ndarray 301 | The normalized and symmetrized gene-to-gene interaction matrix.
302 | """ 303 | # symmetrize the values 304 | g2g = (g2g + g2g.T) / 2 305 | # force them to be between 0-1 306 | g2g[g2g < 0] = 0 307 | g2g[g2g > 1] = 1 308 | # force the central line to be 1 309 | for idx in range(len(g2g)): 310 | g2g[idx, idx] = 1 311 | return g2g 312 | 313 | # define a function to derive the sphex from the gex 314 | def calc_sphex(self, gex): 315 | """ 316 | Converts the regular gene expression matrix into a spherical gene expression matrix. 317 | 318 | Parameters 319 | ---------- 320 | gex : torch.Tensor 321 | The regular gene expression matrix; rows are expected to lie on the unit sphere. 322 | 323 | Returns 324 | ------- 325 | sphex : torch.Tensor 326 | The converted spherical gene expression matrix. 327 | """ 328 | # set up the sphex 329 | n_sgenes = gex.shape[1]-1 330 | #sphex = torch.from_numpy(np.zeros((gex.shape[0], n_sgenes)).astype('float32'), device=next(self.parameters()).device) 331 | sphex = torch.zeros((gex.shape[0], n_sgenes), dtype=torch.float32, device=next(self.parameters()).device) 332 | # compute the sphex 333 | for idx in range(n_sgenes): 334 | sphex[:,idx] = gex[:,idx] 335 | for idx_ in range(idx): 336 | sphex[:,idx] /= torch.sin(sphex[:,idx_]) 337 | sphex[:,idx] = torch.arccos(sphex[:,idx]) 338 | return sphex 339 | 340 | -------------------------------------------------------------------------------- /celcomen/training_plan/.ipynb_checkpoints/train-checkpoint.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import torch 4 | 5 | # def train(model, optimizer, train_loader, num_epochs, seed=1, device="cpu"): 6 | 7 | # if device is None: 8 | # device = torch.device(device) 9 | 10 | # torch.manual_seed(seed) 11 | 12 | # for epoch in range(num_epochs): 13 | 14 | # model = model.train() 15 | # for batch_idx, (features, labels) in enumerate(train_loader): 16 | 17 | # features, labels = features.to(device), labels.to(device) 18 | 19 | # logits = model(features) 20 | 21 | # loss =
F.cross_entropy(logits, labels) 22 | # optimizer.zero_grad() 23 | # loss.backward() 24 | # optimizer.step() 25 | 26 | 27 | def train(num_epochs, learning_rate, model, loader, seed=1, device="cpu"): 28 | 29 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0) 30 | losses = [] 31 | model.train() 32 | torch.manual_seed(seed) 33 | 34 | for epoch in range(num_epochs): 35 | losses_= [] 36 | 37 | for data in loader: 38 | # move data to device 39 | data = data.to(device) 40 | # train loader # Iterate in batches over the training dataset. 41 | # set the appropriate gex 42 | model.set_gex(data.x) 43 | # derive the message as well as the mean field approximation 44 | msg, msg_intra, log_z_mft = model(data.edge_index, 1) 45 | # compute the loss and track it 46 | loss = -(-log_z_mft + zmft_scalar * torch.trace(torch.mm(msg, torch.t(model.gex))) + zmft_scalar * torch.trace(torch.mm(msg_intra, torch.t(model.gex))) ) 47 | losses_.append(loss.detach().numpy()[0][0]) 48 | # derive the gradients, update, and clear 49 | loss.backward() 50 | optimizer.step() 51 | optimizer.zero_grad() 52 | # repeatedly force a normalization 53 | model.conv1.lin.weight = torch.nn.Parameter(normalize_g2g(model.conv1.lin.weight), requires_grad=True) 54 | model.lin.weight = torch.nn.Parameter(normalize_g2g(model.lin.weight), requires_grad=True) 55 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0) 56 | 57 | print(f"Loss={np.mean(losses_)}") 58 | losses.append(np.mean(losses_)) 59 | 60 | return losses 61 | 62 | 63 | -------------------------------------------------------------------------------- /celcomen/training_plan/train.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import torch 4 | from ..utils.helpers import normalize_g2g, calc_sphex, calc_gex 5 | 6 | def train(epochs, learning_rate, model, loader, zmft_scalar=1e-1, seed=1, device="cpu", verbose=False): 7 
| """ 8 | Trains the model using a stochastic gradient descent (SGD) optimizer over the specified number of epochs. 9 | 10 | During training, the model computes messages between neighboring cells (spots) from the gene-to-gene (G2G) interaction matrix, plus intracellular messages, 11 | and applies a Mean Field Theory (MFT) approximation to the log partition function. 12 | The G2G and intracellular regulation matrices are normalized after each step. 13 | 14 | Parameters 15 | ---------- 16 | epochs : int 17 | Number of training epochs to run. 18 | learning_rate : float 19 | Learning rate for the SGD optimizer. 20 | model : torch.nn.Module 21 | The model to be trained, which includes a graph convolutional layer and a linear layer. 22 | loader : torch_geometric.loader.DataLoader 23 | DataLoader that provides the data for each batch during training. 24 | zmft_scalar : float, optional 25 | Scalar that weights the interaction-energy (trace) terms relative to the Mean Field Theory log partition term in the loss function. Default is 1e-1. 26 | seed : int, optional 27 | Seed for random number generation to ensure reproducibility. Default is 1. 28 | device : str, optional 29 | Device to use for training, e.g., "cpu" or "cuda". Default is "cpu". 30 | verbose : bool, optional 31 | If True, prints the loss at each epoch. Default is False. 32 | 33 | Returns 34 | ------- 35 | losses : list of float 36 | Mean loss over all batches, one value per epoch. 37 | 38 | Examples 39 | -------- 40 | >>> losses = train(epochs=100, learning_rate=0.01, model=my_model, loader=my_loader, zmft_scalar=1e-1, seed=42, device="cuda") 41 | >>> print(losses[-1]) # Final loss 42 | """ 43 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0) 44 | losses = [] 45 | model.train() 46 | torch.manual_seed(seed) 47 | 48 | for epoch in tqdm(range(epochs), total=epochs): 49 | losses_= [] 50 | 51 | for data in loader: 52 | # move data to device 53 | data = data.to(device) 54 | # each `data` is one batch from the training loader
55 | # set the appropriate gex 56 | model.set_gex(data.x) 57 | # derive the message as well as the mean field approximation 58 | msg, msg_intra, log_z_mft = model(data.edge_index, 1) 59 | # compute the loss and track it 60 | loss = -(-log_z_mft + zmft_scalar * torch.trace(torch.mm(msg, torch.t(model.gex))) + zmft_scalar * torch.trace(torch.mm(msg_intra, torch.t(model.gex))) ) 61 | if device=="cpu": 62 | losses_.append(loss.detach().numpy()[0][0]) 63 | else: 64 | losses_.append(loss.detach().cpu().numpy()[0][0]) 65 | # derive the gradients, update, and clear 66 | loss.backward() 67 | optimizer.step() 68 | optimizer.zero_grad() 69 | # repeatedly force a normalization 70 | model.conv1.lin.weight = torch.nn.Parameter(normalize_g2g(model.conv1.lin.weight), requires_grad=True) 71 | model.lin.weight = torch.nn.Parameter(normalize_g2g(model.lin.weight), requires_grad=True) 72 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0) 73 | 74 | if verbose: print(f"Epoch={epoch} | Loss={np.mean(losses_)}") 75 | losses.append(np.mean(losses_)) 76 | 77 | return losses 78 | 79 | 80 | def train_simcomen(epochs, learning_rate, model, edge_index, zmft_scalar=1e-1, seed=1, device="cpu", verbose=False): 81 | """ 82 | Trains the `simcomen` model using stochastic gradient descent (SGD) over the specified number of epochs. 83 | 84 | The training process computes the Mean Field Theory (MFT) approximation of the partition function 85 | and normalizes gene expression data using a spherical representation. The loss is tracked and updated 86 | based on both gene-to-gene (G2G) and intracellular interactions. 87 | 88 | Parameters 89 | ---------- 90 | epochs : int 91 | Number of training epochs to run. 92 | learning_rate : float 93 | Learning rate for the SGD optimizer. 94 | model : torch.nn.Module 95 | The `simcomen` model to be trained, which includes graph convolution and linear layers. 
96 | edge_index : torch.Tensor 97 | Tensor representing the edges in the graph, i.e., the connections between nodes (cells). 98 | zmft_scalar : float, optional 99 | Scalar that weights the interaction-energy (trace) terms relative to the Mean Field Theory log partition term in the loss function. Default is 1e-1. 100 | seed : int, optional 101 | Seed for random number generation to ensure reproducibility. Default is 1. 102 | device : str, optional 103 | Device to use for training, e.g., "cpu" or "cuda". Default is "cpu". 104 | verbose : bool, optional 105 | If True, prints the loss at each epoch. Default is False. 106 | 107 | Returns 108 | ------- 109 | losses : list of float 110 | List of losses, one per epoch. 111 | 112 | Examples 113 | -------- 114 | >>> losses = train_simcomen(epochs=100, learning_rate=0.01, model=my_model, edge_index=my_edge_index, zmft_scalar=1e-1, seed=42, device="cuda") 115 | >>> print(losses[-1]) # Final loss 116 | """ 117 | # set up the optimizer 118 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0) 119 | # keep track of the loss at each epoch 120 | losses = [] 121 | model.train() 122 | torch.manual_seed(seed) 123 | # snapshots of the gene expression every 5 epochs (collected but not returned) 124 | tmp_gexs = [] 125 | # work through epochs 126 | for epoch in tqdm(range(epochs), total=epochs): 127 | # derive the message as well as the mean field approximation 128 | msg, msg_intra, log_z_mft = model(edge_index, 1) 129 | if (epoch % 5) == 0: 130 | if device=="cpu": 131 | tmp_gex = model.gex.clone().detach().numpy() 132 | else: 133 | tmp_gex = model.gex.clone().detach().cpu().numpy() 134 | tmp_gexs.append(tmp_gex) 135 | # compute the loss and track it 136 | loss = -(-log_z_mft + zmft_scalar * torch.trace(torch.mm(msg, torch.t(model.gex))) + zmft_scalar * torch.trace(torch.mm(msg_intra, torch.t(model.gex))) ) 137 | if device=="cpu": 138 | losses.append(loss.detach().numpy()[0][0]) 139 | else: 140 | losses.append(loss.detach().cpu().numpy()[0][0]) 141 | # derive the gradients, update, and clear 142 | if verbose: print(f"Epoch={epoch} |
Loss={np.mean(losses[-1])}") 143 | 144 | loss.backward() 145 | optimizer.step() 146 | optimizer.zero_grad() 147 | 148 | return losses 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /celcomen/utils/.ipynb_checkpoints/helpers-checkpoint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | # define a function to derive the gex from the sphex 5 | def calc_gex(sphex): 6 | """ 7 | Calculates the gene expression matrix from the spherical 8 | """ 9 | # setup the gex 10 | n_genes = sphex.shape[1]+1 11 | gex = torch.from_numpy(np.zeros((sphex.shape[0], n_genes)).astype('float32')) 12 | # compute the gex 13 | for idx in range(n_genes): 14 | if idx == n_genes-1: 15 | gex[:,idx] = torch.sin(sphex[:,idx-1]) 16 | else: 17 | gex[:,idx] = torch.cos(sphex[:,idx]) 18 | for idx_ in range(idx): 19 | gex[:,idx] *= torch.sin(sphex[:,idx_]) 20 | return gex 21 | 22 | # define a function to gather positions 23 | def get_pos(n_x, n_y): 24 | # create the hex lattice 25 | xs = np.array([np.arange(0, n_x) + 0.5 if idx % 2 == 0 else np.arange(0, n_x) for idx in range(n_y)]) 26 | # derive the y-step given a distance of one 27 | y_step = np.sqrt(1**2+0.5**2) 28 | ys = np.array([[y_step * idy] * n_x for idy in range(n_y)]) 29 | # define the positions 30 | pos = np.vstack([xs.flatten(), ys.flatten()]).T 31 | return pos 32 | 33 | 34 | # define a function to normalize the g2g 35 | def normalize_g2g(g2g): 36 | """ 37 | Addresses any small fluctuations in symmetrical weights 38 | """ 39 | # symmetrize the values 40 | g2g = (g2g + g2g.T) / 2 41 | # force them to be between 0-1 42 | g2g[g2g < 0] = 0 43 | g2g[g2g > 1] = 1 44 | # force the central line to be 1 45 | for idx in range(len(g2g)): 46 | g2g[idx, idx] = 1 47 | return g2g 48 | -------------------------------------------------------------------------------- 
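Both training loops above minimise `log_z_mft - zmft_scalar * (tr(msg @ gex.T) + tr(msg_intra @ gex.T))`, and the interaction part of `log_z_mft` is `num_spots * log((exp(g) - exp(-g)) / g)`, i.e. `num_spots * log(2 * sinh(g) / g)`. The `z_interaction` method switches to the asymptotic form `num_spots * (g - log(g))` once `g > 20` so that `exp(g)` cannot overflow. The sketch below, in plain Python with hypothetical function names that are not part of celcomen, shows a branch-free, overflow-safe form of the same quantity based on the identity log(e^g - e^-g) = g + log(1 - e^(-2g)), next to a float mirror of the two-branch logic for comparison.

```python
import math


def log_z_interaction_stable(num_spots: int, g: float) -> float:
    """Hypothetical helper (not part of celcomen): num_spots * log((e^g - e^-g) / g), g > 0.

    Uses log(e^g - e^-g) = g + log(1 - e^(-2g)) and evaluates 1 - e^(-2g)
    as -expm1(-2g), so no large exponential is ever formed.
    """
    if g <= 0:
        raise ValueError("g must be positive")
    return num_spots * (g + math.log(-math.expm1(-2.0 * g)) - math.log(g))


def log_z_interaction_branch(num_spots: int, g: float) -> float:
    """Hypothetical float mirror of the two-branch logic in z_interaction."""
    if g > 20:
        # asymptotic form: the neglected term log(1 - e^(-2g)) is below 1e-17 here
        return num_spots * (g - math.log(g))
    return num_spots * math.log((math.exp(g) - math.exp(-g)) / g)
```

For moderate `g` the two functions agree to floating-point precision, the stable form stays finite for arbitrarily large `g`, and as `g` approaches 0 the per-spot value tends to log 2, the limit of log(2·sinh(g)/g).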
/celcomen/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import helpers 2 | -------------------------------------------------------------------------------- /celcomen/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | # define a function to derive the gex from the sphex 5 | def calc_gex(sphex): 6 | """ 7 | Converts a spherical gene expression matrix into a standard gene expression matrix. 8 | 9 | The function takes a spherical expression matrix (`sphex`) and calculates the corresponding 10 | gene expression matrix (`gex`) using trigonometric functions such as sine and cosine. 11 | The last gene is computed using sine, and the others are computed using cosine, followed by 12 | multiplying by the sine of all preceding dimensions. 13 | 14 | Parameters 15 | ---------- 16 | sphex : torch.Tensor 17 | The spherical gene expression matrix. Expected shape is (n_samples, n_features - 1). 18 | 19 | Returns 20 | ------- 21 | gex : torch.Tensor 22 | The computed gene expression matrix of shape (n_samples, n_features), where `n_features` is one more than in `sphex`. 23 | All NaN values are replaced with 0. 
24 | 25 | Examples 26 | -------- 27 | >>> sphex = torch.tensor([[0.5, 0.6], [0.3, 0.4]]) 28 | >>> gex = calc_gex(sphex) 29 | >>> print(gex) 30 | tensor([[0.8776, 0.3957, 0.2707], 31 | [0.9553, 0.2722, 0.1151]]) 32 | """ 33 | # set up the gex 34 | n_genes = sphex.shape[1]+1 35 | gex = torch.from_numpy(np.zeros((sphex.shape[0], n_genes)).astype('float32')) 36 | # compute the gex: gex_i = cos(sphex_i) * prod_{k<i} sin(sphex_k), and the last gene is the product of all sines 37 | for idx in range(n_genes): 38 | if idx == n_genes-1: 39 | gex[:,idx] = torch.sin(sphex[:,idx-1]) 40 | else: 41 | gex[:,idx] = torch.cos(sphex[:,idx]) 42 | for idx_ in range(min(idx, n_genes-2)): # the last gene already holds sin(sphex[:,-1]) 43 | gex[:,idx] *= torch.sin(sphex[:,idx_]) 44 | return torch.nan_to_num(gex) 45 | 46 | # define a function to gather positions 47 | def get_pos(n_x, n_y): 48 | """ 49 | Generates a 2D hexagonal grid of positions. 50 | 51 | This function creates a hexagonal lattice for a 2D grid, where the x-coordinates are adjusted 52 | for alternating rows. The y-coordinates are spaced based on a predefined step size derived 53 | from the geometry of a hexagonal grid. 54 | 55 | Parameters 56 | ---------- 57 | n_x : int 58 | Number of positions along the x-axis. 59 | n_y : int 60 | Number of positions along the y-axis. 61 | 62 | Returns 63 | ------- 64 | pos : numpy.ndarray 65 | A 2D array of shape (n_x * n_y, 2) representing the coordinates of the positions on the hexagonal grid. 66 | 67 | Examples 68 | -------- 69 | >>> pos = get_pos(3, 3) 70 | >>> print(pos) 71 | array([[0.5, 0. ], 72 | [1.5, 0. ], 73 | [2.5, 0. ], 74 | [0. , 1.11803399], 75 | [1. , 1.11803399], 76 | [2.
, 1.11803399], 77 | [0.5, 2.23606798], 78 | [1.5, 2.23606798], 79 | [2.5, 2.23606798]]) 80 | """ 81 | # create the hex lattice 82 | xs = np.array([np.arange(0, n_x) + 0.5 if idx % 2 == 0 else np.arange(0, n_x) for idx in range(n_y)]) 83 | # derive the row spacing (horizontal neighbors are 1 apart, diagonal neighbors sqrt(1.5) apart) 84 | y_step = np.sqrt(1**2+0.5**2) 85 | ys = np.array([[y_step * idy] * n_x for idy in range(n_y)]) 86 | # define the positions 87 | pos = np.vstack([xs.flatten(), ys.flatten()]).T 88 | return pos 89 | 90 | 91 | # define a function to normalize the g2g 92 | def normalize_g2g(g2g): 93 | """ 94 | Symmetrizes and normalizes a gene-to-gene (G2G) interaction matrix. 95 | 96 | This function ensures that the matrix is symmetric, clips values to the range [0, 1], 97 | and forces the diagonal to be 1 (representing self-interactions). 98 | 99 | Parameters 100 | ---------- 101 | g2g : numpy.ndarray or torch.Tensor 102 | The gene-to-gene interaction matrix, typically of shape (n_genes, n_genes). 103 | 104 | Returns 105 | ------- 106 | g2g : numpy.ndarray or torch.Tensor 107 | The normalized and symmetrized gene-to-gene interaction matrix. 108 | 109 | Examples 110 | -------- 111 | >>> g2g = np.array([[0.8, 0.2], [0.1, 0.7]]) 112 | >>> normalized_g2g = normalize_g2g(g2g) 113 | >>> normalized_g2g 114 | array([[1.  , 0.15], 115 | [0.15, 1.  ]]) 116 | """ 117 | # symmetrize the values 118 | g2g = (g2g + g2g.T) / 2 119 | # force them to be between 0-1 120 | g2g[g2g < 0] = 0 121 | g2g[g2g > 1] = 1 122 | # force the central line to be 1 123 | for idx in range(len(g2g)): 124 | g2g[idx, idx] = 1 125 | return g2g 126 | 127 | # define a function to derive the sphex from the gex 128 | def calc_sphex(gex): 129 | """ 130 | Converts a standard gene expression matrix into a spherical expression matrix. 131 | 132 | This function calculates the spherical representation of a gene expression matrix (`gex`), 133 | where the new features are derived using the sine and arccosine functions.
134 | 135 | Parameters 136 | ---------- 137 | gex : torch.Tensor 138 | The standard gene expression matrix, of shape (n_samples, n_genes); rows are expected to lie on the unit sphere, otherwise `arccos` can return NaN. 139 | 140 | Returns 141 | ------- 142 | sphex : torch.Tensor 143 | The spherical gene expression matrix, of shape (n_samples, n_genes - 1). 144 | 145 | Examples 146 | -------- 147 | >>> gex = torch.tensor([[0.0, 1.0, 0.0]]) 148 | >>> sphex = calc_sphex(gex) 149 | >>> print(sphex) 150 | tensor([[1.5708, 0.0000]]) 151 | 152 | """ 153 | # set up the sphex 154 | n_sgenes = gex.shape[1]-1 155 | sphex = torch.from_numpy(np.zeros((gex.shape[0], n_sgenes)).astype('float32')) 156 | # compute the sphex 157 | for idx in range(n_sgenes): 158 | sphex[:,idx] = gex[:,idx] 159 | for idx_ in range(idx): 160 | sphex[:,idx] /= torch.sin(sphex[:,idx_]) 161 | sphex[:,idx] = torch.arccos(sphex[:,idx]) 162 | return sphex 163 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/Makefile-checkpoint: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/make-checkpoint.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. 
$(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = 'Celcomen' 10 | copyright = '2025, Stathis Megas' 11 | author = 'Stathis Megas' 12 | import celcomen as cce 13 | release = cce.__version__ 14 | 15 | # -- General configuration --------------------------------------------------- 16 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 17 | 18 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary'] 19 | templates_path = ['_templates'] 20 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 21 | 22 | # -- Options for HTML output ------------------------------------------------- 23 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 24 | 25 | html_theme = 'sphinx_rtd_theme' 26 | html_static_path = ['_static'] 27 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Celcomen documentation master file, created by 2 | sphinx-quickstart on Thu Jan 9 18:35:31 2025. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Celcomen 7 | ======== 8 | Project home page `here `_. 
9 | 10 | Main workflow classes 11 | ----------------------- 12 | 13 | The ``celcomen`` class 14 | ~~~~~~~~~~~~~~~~~~~~~~ 15 | 16 | .. autoclass:: celcomen.models.celcomen.celcomen 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | 21 | The ``simcomen`` class 22 | ~~~~~~~~~~~~~~~~~~~~~~ 23 | 24 | .. autoclass:: celcomen.models.simcomen.simcomen 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | 29 | Utility functions 30 | ----------------- 31 | .. autosummary:: 32 | :toctree: 33 | 34 | celcomen.utils.helpers.calc_gex 35 | celcomen.datareaders.datareader.get_dataset_loaders 36 | 37 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-rtd-theme 2 | -------------------------------------------------------------------------------- /docs/source/.ipynb_checkpoints/conf-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = 'Celcomen' 10 | copyright = '2025, Stathis Megas' 11 | author = 'Stathis Megas' 12 | release = '0.0.1' 13 | 14 | # -- General configuration --------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 16 | 17 | extensions = [] 18 | 19 | templates_path = ['_templates'] 20 | exclude_patterns = [] 21 | 22 | 23 | 24 | # -- Options for HTML output ------------------------------------------------- 25 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 26 | 27 | html_theme = 'sphinx_rtd_theme' 28 | html_static_path = ['_static'] 29 | -------------------------------------------------------------------------------- /docs/source/.ipynb_checkpoints/index-checkpoint.rst: 
--------------------------------------------------------------------------------
1 | .. Celcomen documentation master file, created by
2 |    sphinx-quickstart on Thu Jan 9 18:35:31 2025.
3 |    You can adapt this file completely to your liking, but it should at least
4 |    contain the root `toctree` directive.
5 | 
6 | Celcomen documentation
7 | ======================
8 | 
9 | Add your content using ``reStructuredText`` syntax. See the
10 | `reStructuredText `_
11 | documentation for details.
12 | 
13 | 
14 | .. toctree::
15 |    :maxdepth: 2
16 |    :caption: Contents:
17 | 
18 | 
--------------------------------------------------------------------------------
/images/disentangling graphs and gene colocalization-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teichlab/celcomen/f5d19bd24e968a45b057b1ea3aef5c6ce5974303/images/disentangling graphs and gene colocalization-2.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["flit_core >=3.2,<4"]
3 | build-backend = "flit_core.buildapi"
4 | 
5 | [project]
6 | name = "celcomen"
7 | authors = [{name = "Stathis Megas", email = "stathismegas@gmail.com"}]
8 | version = "0.0.2"
9 | description = "Celcomen is a first step towards models of Virtual Tissue, which generalize models of Virtual Cells to the tissue context, and can be leveraged to predict tissue counterfactuals such as spatially localised gene knockouts"
10 | dependencies = [
11 |     "numpy",
12 |     "torch",
13 |     "torch_geometric",
14 |     "scikit-learn",
15 |     "scanpy",
16 |     "tqdm"
17 | ]
18 | 
--------------------------------------------------------------------------------
/readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 | 
5 | # Required
6 | version: 2
7 | 
8 | # Set the version of Python and other tools you might need
9 | build:
10 |   os: ubuntu-20.04
11 |   tools:
12 |     python: "3.10"
13 | 
14 | # Build documentation in the docs/ directory with Sphinx
15 | sphinx:
16 |   configuration: docs/conf.py
17 | 
18 | # Install our python package before building the docs
19 | python:
20 |   install:
21 |     - requirements: docs/requirements.txt
22 |     - method: pip
23 |       path: .
24 | 
--------------------------------------------------------------------------------
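The `docs/index.rst` in this listing uses the `autoclass` and `autosummary` directives. Sphinx only provides these when the `sphinx.ext.autodoc` and `sphinx.ext.autosummary` extensions are enabled, yet the checkpoint copy of `conf.py` shown here still has `extensions = []`. The live `docs/conf.py` that `readthedocs.yaml` points to is not part of this excerpt, so the following is only a hedged sketch of what it might contain. It assumes that `docs/` itself is the Sphinx source directory and that excluding the build output and Jupyter `.ipynb_checkpoints` folders is desirable; neither assumption is taken from the repository.

```python
# docs/conf.py: a minimal sketch, NOT the repository's actual file.
# Assumption: conf.py and index.rst both live directly in docs/.
import os
import sys

# Make the repository root importable so autodoc can find the ``celcomen``
# package during local builds (Read the Docs pip-installs it anyway).
sys.path.insert(0, os.path.abspath(".."))

project = "Celcomen"
author = "Stathis Megas"
copyright = "2025, Stathis Megas"

extensions = [
    "sphinx.ext.autodoc",      # provides the ``autoclass`` directive
    "sphinx.ext.autosummary",  # provides the ``autosummary`` directive
]
# Write stub pages for the entries listed under ``autosummary``'s :toctree:.
autosummary_generate = True

templates_path = ["_templates"]
# Assumed exclusions: build output plus Jupyter checkpoint copies at the top
# level and in any subdirectory (e.g. docs/source/.ipynb_checkpoints).
exclude_patterns = [
    "_build",
    "build",
    ".ipynb_checkpoints",
    "**/.ipynb_checkpoints",
]

html_theme = "sphinx_rtd_theme"  # installed via docs/requirements.txt
```

Under these assumptions, running `sphinx-build -M html . build` from inside `docs/` should render the class and utility-function pages. Excluding the checkpoint folders also keeps `index-checkpoint.rst` from being picked up as a stray document.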