├── README.md
├── analysis.spatial_KO.xenium_human_glioblastoma_gpu.ipynb
├── celcomen
│   ├── .ipynb_checkpoints
│   │   └── __init__-checkpoint.py
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-310.pyc
│   ├── datareaders
│   │   ├── __init__.py
│   │   └── datareader.py
│   ├── models
│   │   ├── .ipynb_checkpoints
│   │   │   ├── __init__-checkpoint.py
│   │   │   ├── celcomen-checkpoint.py
│   │   │   └── simcomen-checkpoint.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── celcomen.cpython-310.pyc
│   │   │   ├── celcomen.cpython-39.pyc
│   │   │   ├── simcomen.cpython-310.pyc
│   │   │   └── simcomen.cpython-39.pyc
│   │   ├── celcomen.py
│   │   └── simcomen.py
│   ├── training_plan
│   │   ├── .ipynb_checkpoints
│   │   │   └── train-checkpoint.py
│   │   └── train.py
│   └── utils
│       ├── .ipynb_checkpoints
│       │   └── helpers-checkpoint.py
│       ├── __init__.py
│       └── helpers.py
├── docs
│   ├── .ipynb_checkpoints
│   │   ├── Makefile-checkpoint
│   │   └── make-checkpoint.bat
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── requirements.txt
│   └── source
│       └── .ipynb_checkpoints
│           ├── conf-checkpoint.py
│           └── index-checkpoint.rst
├── images
│   └── disentangling graphs and gene colocalization-2.png
├── pyproject.toml
└── readthedocs.yaml
/README.md:
--------------------------------------------------------------------------------
1 | # Cell Communication Energy (celcomen)
2 | A causal generative model designed to disentangle intercellular and intracellular gene regulation with theoretical identifiability guarantees. Celcomen can generate counterfactual spatial transcriptomic samples by simulating the effect of local perturbations, such as gene activations/inhibitions or cell insertions/deletions.
3 |
4 | You can find out more by reading our [manuscript](https://arxiv.org/abs/2409.05804).
5 |
6 |
7 |
8 |
9 |
10 | Installation
11 | ============
12 | Conda Environment
13 | --
14 | We recommend using [Anaconda](https://www.anaconda.com/)/[Miniconda](https://docs.conda.io/projects/miniconda/en/latest/) to create a conda environment for celcomen. You can create a Python environment with the following command:
15 |
16 | conda create -n celcomen_env python=3.9
17 |
18 | Then, you can activate the environment using:
19 |
20 | conda activate celcomen_env
21 |
22 | Install celcomen
23 | --
24 | Then install celcomen from GitHub with pip:
25 | ```
26 | pip install git+https://github.com/stathismegas/celcomen
27 | ```
28 |
29 | Causal Disentanglement and Spatial Counterfactuals
30 | ============
31 | To learn intracellular and intercellular gene regulation, and then use it to simulate inflammation counterfactuals at specific locations in the tissue, follow the tutorial `analysis.perturbation_newest_celcomen.ipynb`.
32 |
33 | As explained in the tutorial, the AnnData object (`adata`) should contain raw count data, without any prior normalization or log-transformation.
34 |
35 | To speed up training on a GPU, refer to the tutorial `train_using_dataloaders_gpu.ipynb`.
36 |
37 |
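Before training, a quick sanity check can confirm that `adata.X` still holds raw counts rather than normalised or log-transformed values. The sketch below is a heuristic only; the function name `looks_like_raw_counts` is hypothetical and not part of celcomen. It accepts a dense NumPy array or a SciPy sparse matrix and checks that every stored value is a non-negative integer.

```python
import numpy as np


def looks_like_raw_counts(X) -> bool:
    """Heuristically check that a cells x genes matrix holds raw counts.

    Works on dense NumPy arrays and on SciPy sparse matrices; for sparse
    input only the explicitly stored values are inspected.
    """
    # SciPy sparse matrices expose their stored values via `.data`
    values = X.data if hasattr(X, "tocsr") else np.asarray(X)
    values = np.asarray(values, dtype=np.float64)
    if values.size == 0:
        return True
    non_negative = bool(np.all(values >= 0))
    integer_valued = bool(np.all(values == np.round(values)))
    return non_negative and integer_valued


# Usage, assuming `adata` is the AnnData object you will pass to celcomen:
# assert looks_like_raw_counts(adata.X), "adata.X looks normalised or log-transformed"
```

A matrix of integer-valued floats (a common storage format for counts) passes the check, while library-size-normalised or `log1p`-transformed data fails it.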
--------------------------------------------------------------------------------
/celcomen/__init__.py:
--------------------------------------------------------------------------------
1 | """Method for causal disentanglement of inter-cellular and intra-cellular gene regulation in spatial transcriptomics"""
2 | __version__ = "0.0.1"
3 |
4 | from . import datareaders
5 | from . import models
6 | from . import utils
7 |
--------------------------------------------------------------------------------
/celcomen/datareaders/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datareader
2 |
--------------------------------------------------------------------------------
/celcomen/datareaders/datareader.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.loader import DataLoader
2 | import scanpy as sc
3 | import torch
4 | import torch_geometric
5 | #from torch_cluster.knn import knn_graph
6 | from sklearn.neighbors import kneighbors_graph
7 | import numpy as np
8 |
9 | from scipy.spatial.distance import pdist, squareform
10 |
11 | def get_dataset_loaders(h5ad_path: str, sample_id_name: str, n_neighbors: int, distance: float, device: str, verbose: bool):
12 | """
13 |     Prepares and returns a PyTorch Geometric DataLoader from a single-cell spatial transcriptomics dataset.
14 |
15 |     The function reads an AnnData object from an H5AD file, splits it into one subset per sample, normalises
16 |     each cell's expression vector, and builds one graph per sample: nodes are cells, and edges connect spatially
17 |     close cells (k-nearest neighbours or a distance threshold). The graphs are wrapped in a PyTorch Geometric `DataLoader`.
18 |
19 | Parameters
20 | ----------
21 | h5ad_path : str
22 | Path to the H5AD file containing the raw counts of the single-cell spatial transcriptomics data.
23 | sample_id_name : str
24 | Name of the sample ID column in `adata.obs` to separate the dataset into different samples.
25 |     n_neighbors : int or None
26 |         Number of neighbours used to build the k-nearest neighbours graph from the spatial coordinates.
27 |     distance : float or None
28 |         Radius for a distance-based graph: cells closer than `distance` (in the units of `adata.obsm["spatial"]`) are connected. Used only when `n_neighbors` is None.
29 |     device : str
30 |         Device to which the edge indices are moved, e.g. ``"cpu"`` or ``"cuda"``.
31 |     verbose : bool
32 |         If True, prints detailed information about the DataLoader during the loading process.
33 |
34 |     Returns
35 |     -------
36 |     DataLoader
37 |         A PyTorch Geometric DataLoader containing the processed graph data, with each graph representing
38 |         a sample of cells in the dataset.
39 |
40 |     Raises
41 |     ------
42 |     ValueError
43 |         If both `n_neighbors` and `distance` are None, or if a graph fails PyTorch Geometric's validation.
44 |
45 |     Notes
46 |     -----
47 |     - Edges connect spatially close cells, using either a k-nearest neighbours graph or a distance threshold.
48 |     - The input features for the graph (`x`) are normalised before constructing the graph.
49 |     - `adata.obsm["spatial"]` is used to extract the spatial coordinates of the cells.
50 |     - The graph data is validated using PyTorch Geometric's built-in validation method.
51 |
52 |     Examples
53 |     --------
54 |     >>> loader = get_dataset_loaders('data.h5ad', 'sample_id', n_neighbors=6, distance=None, device='cpu', verbose=True)
55 |     Step 1
56 |     =====
57 |     Number of graphs in the current batch: 1
58 |     Data(x=[100, 33500], edge_index=[2, 500], pos=[100, 2], y=[1])
59 |     """
60 |
61 |     adata = sc.read_h5ad(h5ad_path)
62 |     # sc.pp.normalize_total(adata, target_sum=1e6)
63 |     # sc.pp.log1p(adata)
64 |
65 |     adata_list = [adata[adata.obs[sample_id_name] == i] for i in set(adata.obs[sample_id_name])]
66 |     if n_neighbors is None and distance is None:
67 |         raise ValueError("Set at least one of `n_neighbors` or `distance`.")
68 |     data_list = []
69 |
70 | for adata in adata_list:
71 |         pos = torch.from_numpy(adata.obsm["spatial"])
72 |         x = torch.from_numpy(np.asarray(adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X))
73 |         # normalise each cell's expression vector to unit L2 norm
74 |         norm_factor = torch.pow(x, 2).sum(1).sqrt().reshape(-1, 1)
75 |         x = torch.div(x, norm_factor)
76 |         y = torch.Tensor([0])
77 |         if n_neighbors is not None:
78 |             # k-nearest neighbours graph on the spatial coordinates, excluding self-edges
79 |             knn = kneighbors_graph(adata.obsm["spatial"], n_neighbors=n_neighbors, include_self=False)
80 |             edge_index = torch.from_numpy(np.vstack(knn.nonzero()).astype(np.int64)).to(device)
81 |         else:
82 |             # connect cells closer than `distance` (e.g. ~30 µm, about two cell widths)
83 |             distances = squareform(pdist(adata.obsm["spatial"]))
84 |             edge_index = torch.from_numpy(np.array(np.where((distances < distance) & (distances != 0)))).to(device)
85 | data = torch_geometric.data.Data(x=x, pos=pos, y=y, edge_index=edge_index)
86 | data.validate(raise_on_error=True) # performs basic checks on the graph
87 | data_list.append(data)
88 |
89 | loader = DataLoader( data_list, batch_size=1, shuffle=True)
90 |
91 | if verbose:
92 | for step, data in enumerate(loader):
93 | print(f'Step {step+1}')
94 | print("=====")
95 | print(f'Number of graphs in the current batch: {data.num_graphs}')
96 | print(data)
97 | print()
98 |
99 | return loader
100 |
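The two edge rules used per sample in `get_dataset_loaders` can be illustrated without PyTorch or scikit-learn. The NumPy-only sketch below uses hypothetical helper names (`pairwise_distances`, `radius_edges`, `knn_edges`) and mirrors the distance-threshold rule (which also skips zero distances) and a k-nearest-neighbours rule that excludes each cell from its own neighbourhood. It builds a dense distance matrix, so it is only meant for small examples, and ties between equally distant neighbours may be broken differently than by `sklearn.neighbors.kneighbors_graph`.

```python
import numpy as np


def pairwise_distances(pos: np.ndarray) -> np.ndarray:
    """Dense Euclidean distance matrix, shape (n_cells, n_cells)."""
    diff = pos[:, None, :] - pos[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def radius_edges(pos: np.ndarray, distance: float) -> np.ndarray:
    """Edges (2, E) between cells closer than `distance`, skipping zero distances (self-pairs)."""
    d = pairwise_distances(pos)
    return np.array(np.where((d < distance) & (d != 0)))


def knn_edges(pos: np.ndarray, n_neighbors: int) -> np.ndarray:
    """Edges (2, n_cells * n_neighbors): row 0 is each cell, row 1 one of its nearest neighbours."""
    d = pairwise_distances(pos)
    np.fill_diagonal(d, np.inf)  # exclude each cell from its own neighbourhood
    nbrs = np.argsort(d, axis=1, kind="stable")[:, :n_neighbors]
    src = np.repeat(np.arange(len(pos)), n_neighbors)
    return np.vstack([src, nbrs.ravel()])


pos = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
print(radius_edges(pos, distance=1.5))  # [[0 1]
                                        #  [1 0]]
print(knn_edges(pos, n_neighbors=1))    # [[0 1 2]
                                        #  [1 0 1]]
```

Note the asymmetry of the k-nearest-neighbours rule: cell 2 points to cell 1, but cell 1 does not point back, whereas the radius rule always yields both directions.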
--------------------------------------------------------------------------------
/celcomen/models/.ipynb_checkpoints/celcomen-checkpoint.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch_geometric.nn import GCNConv
4 | class celcomen(torch.nn.Module):
5 | # define initialization function
6 | def __init__(self, input_dim, output_dim, n_neighbors, seed=0):
7 | super(celcomen, self).__init__()
8 | # define the seed
9 | torch.manual_seed(seed)
10 | # set up the graph convolution
11 | self.conv1 = GCNConv(input_dim, output_dim, add_self_loops=False)
12 | # set up the linear layer for intracellular gene regulation
13 | self.lin = torch.nn.Linear(input_dim, output_dim)
14 | # define the neighbors
15 | self.n_neighbors = n_neighbors
16 | # define a tracking variable for the gene expression x matrix
17 | self.gex = None
18 |
19 | # define a function to artificially set the g2g matrix
20 | def set_g2g(self, g2g):
21 | """
22 |         Artificially sets the core g2g matrix to a specified interaction matrix.
23 | """
24 | # set the weight as the input
25 | self.conv1.lin.weight = torch.nn.Parameter(g2g, requires_grad=True)
26 | # and then set the bias as all zeros
27 | self.conv1.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(len(g2g)).astype('float32')), requires_grad=False)
28 |
29 |     # define a function to artificially set the intracellular g2g matrix
30 |     def set_g2g_intra(self, g2g_intra):
31 |         """
32 |         Artificially sets the core intracellular g2g matrix to a specified matrix.
33 | """
34 | # set the weight as the input
35 | self.lin.weight = torch.nn.Parameter(g2g_intra, requires_grad=True)
36 | # and then set the bias as all zeros
37 | self.lin.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(len(g2g_intra)).astype('float32')), requires_grad=False)
38 |
39 |     # define a function to artificially set the gene expression (gex) matrix
40 |     def set_gex(self, gex):
41 |         """
42 |         Artificially sets the current gene expression (gex) matrix.
43 | """
44 | self.gex = torch.nn.Parameter(gex, requires_grad=False)
45 |
46 | # define the forward pass
47 | def forward(self, edge_index, batch):
48 | """
49 |         Forward pass for prediction or training: convolves the gene expression with the expected
50 |         interactions and returns the messages together with log(Z_mft).
51 | """
52 | # compute the message
53 | msg = self.conv1(self.gex, edge_index)
54 | # compute intracellular message
55 | msg_intra = self.lin(self.gex)
56 | # compute the log z mft
57 | log_z_mft = self.log_Z_mft(edge_index, batch)
58 | return msg, msg_intra, log_z_mft
59 |
60 | # define approximation function
61 | def log_Z_mft(self, edge_index, batch):
62 |         r"""
63 |         Mean Field Theory approximation to the partition function. Assumptions used are:
64 |         - gene expression values are close to their mean values over the Visium slide
65 |         - \sum_b g_{a,b} m^b > 0 \forall a, where m is the mean gene expression and g is the gene-gene
66 |           interaction matrix.
67 | """
68 | # retrieve number of spots
69 | num_spots = self.gex.shape[0]
70 | # calculate mean gene expression
71 | mean_genes = torch.mean(self.gex, axis=0).reshape(-1,1) # the mean should be per connected graph
72 | # calculate the norm of the sum of mean genes
73 | g = torch.norm(torch.mm( self.n_neighbors*self.conv1.lin.weight + 2*self.lin.weight, mean_genes)) # maybe needs to change to g = torch.norm(torch.mm(mean_genes, self.conv1.lin.weight))
74 | # calculate the contribution for mean values
75 | z_mean = - num_spots * torch.mm(torch.mm(torch.t(mean_genes), self.lin.weight + 0.5 * self.n_neighbors * self.conv1.lin.weight), mean_genes)
76 | # calculate the contribution gene interactions
77 | z_interaction = self.z_interaction(num_spots=num_spots, g=g)
78 | # add the two contributions
79 | log_z_mft = z_mean + z_interaction
80 | return log_z_mft
81 |
82 | def z_interaction(self, num_spots, g):
83 | """
84 | Avoid exploding exponentials by returning an approximate interaction term for the partition function.
85 | """
86 | if g>20:
87 | z_interaction = num_spots * ( g - torch.log( g) )
88 | else:
89 | z_interaction = num_spots * torch.log((torch.exp( g) - torch.exp(- g))/( g))
90 | return z_interaction
--------------------------------------------------------------------------------
/celcomen/models/.ipynb_checkpoints/simcomen-checkpoint.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch_geometric.nn import GCNConv
4 | n_neighbors = 6  # default number of neighbors (six for Visium)
5 | # define the simcomen class
6 | class simcomen(torch.nn.Module):
7 | # define initialization function
8 | def __init__(self, input_dim, output_dim, n_neighbors, seed=0):
9 | super(simcomen, self).__init__()
10 | # define the seed
11 | torch.manual_seed(seed)
12 | # set up the graph convolution
13 | self.conv1 = GCNConv(input_dim, output_dim, add_self_loops=False)
14 | # set up the linear layer for intracellular gene regulation
15 | self.lin = torch.nn.Linear(input_dim, output_dim)
16 | # define the neighbors
17 | self.n_neighbors = n_neighbors
18 | # define a tracking variable for the gene expression x matrix
19 | self.sphex = None
20 | self.gex = None
21 |
22 | # define a function to artificially set the g2g matrix
23 | def set_g2g(self, g2g):
24 | """
25 |         Artificially sets the core g2g matrix to a specified interaction matrix.
26 | """
27 | # set the weight as the input
28 | self.conv1.lin.weight = torch.nn.Parameter(g2g, requires_grad=False)
29 | # and then set the bias as all zeros
30 |         self.conv1.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(len(g2g)).astype('float32')), requires_grad=False)
31 |
32 |     # define a function to artificially set the intracellular g2g matrix
33 |     def set_g2g_intra(self, g2g_intra):
34 |         """
35 |         Artificially sets the core intracellular g2g matrix to a specified matrix.
36 | """
37 | # set the weight as the input
38 | self.lin.weight = torch.nn.Parameter(g2g_intra, requires_grad=False)
39 | # and then set the bias as all zeros
40 | self.lin.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(len(g2g_intra)).astype('float32')), requires_grad=False)
41 |
42 | # define a function to artificially set the sphex matrix
43 | def set_sphex(self, sphex):
44 | """
45 |         Artificially sets the current spherical expression (sphex) matrix.
46 | """
47 | self.sphex = torch.nn.Parameter(sphex, requires_grad=True)
48 |
49 | # define the forward pass
50 | def forward(self, edge_index, batch):
51 | """
52 |         Forward pass for prediction or training: converts sphex to gene expression, convolves it
53 |         with the expected interactions and returns the messages together with log(Z_mft).
54 | """
55 | # compute the gex
56 | self.gex = calc_gex(self.sphex)
57 | # compute the message
58 | msg = self.conv1(self.gex, edge_index)
59 | # compute intracellular message
60 | msg_intra = self.lin(self.gex)
61 | # compute the log z mft
62 | log_z_mft = self.log_Z_mft(edge_index, batch)
63 | return msg, msg_intra, log_z_mft
64 |
65 | # define approximation function
66 | def log_Z_mft(self, edge_index, batch):
67 |         r"""
68 |         Mean Field Theory approximation to the partition function. Assumptions used are:
69 |         - gene expression values are close to their mean values over the Visium slide
70 |         - \sum_b g_{a,b} m^b > 0 \forall a, where m is the mean gene expression and g is the gene-gene
71 |           interaction matrix.
72 | """
73 | # retrieve number of spots
74 | num_spots = self.gex.shape[0]
75 | # calculate mean gene expression
76 | mean_genes = torch.mean(self.gex, axis=0).reshape(-1,1) # the mean should be per connected graph
77 | # calculate the norm of the sum of mean genes
78 | g = torch.norm(torch.mm( self.n_neighbors*self.conv1.lin.weight + 2*self.lin.weight, mean_genes)) # maybe needs to change to g = torch.norm(torch.mm(mean_genes, self.conv1.lin.weight))
79 | # calculate the contribution for mean values
80 | z_mean = - num_spots * torch.mm(torch.mm(torch.t(mean_genes), self.lin.weight + 0.5 * self.n_neighbors * self.conv1.lin.weight), mean_genes)
81 | # calculate the contribution gene interactions
82 | z_interaction = self.z_interaction(num_spots=num_spots, g=g)
83 | # add the two contributions
84 | log_z_mft = z_mean + z_interaction
85 | return log_z_mft
86 |
87 | def z_interaction(self, num_spots, g):
88 | """
89 | Avoid exploding exponentials by returning an approximate interaction term for the partition function.
90 | """
91 | if g>20:
92 | z_interaction = num_spots * ( g - torch.log( g) )
93 | else:
94 | z_interaction = num_spots * torch.log((torch.exp( g) - torch.exp(- g))/( g))
95 | return z_interaction
--------------------------------------------------------------------------------
/celcomen/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import celcomen
2 | from . import simcomen
3 |
--------------------------------------------------------------------------------
/celcomen/models/celcomen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.nn import GCNConv
3 | from sklearn.neighbors import kneighbors_graph
4 | import numpy as np
5 |
6 | # define the celcomen class
7 | class celcomen(torch.nn.Module):
8 | """
9 |     A k-hop graph neural network model that disentangles inter- and intracellular gene regulation and then uses the learned regulation to predict spatial counterfactuals.
10 |
11 | Parameters
12 | ----------
13 | input_dim : int
14 | Dimensionality of the input features (gene expression data).
15 | output_dim : int
16 |         Dimensionality of the output features; must equal `input_dim`, because `log_Z_mft` treats the gene-to-gene matrices as square.
17 | n_neighbors : int
18 | The number of neighbours used in the spatial graph to model cell-cell interactions.
19 | seed : int, optional
20 | Seed for random number generation to ensure reproducibility. Default is 0.
21 |
22 | Attributes
23 | ----------
24 | conv1 : GCNConv
25 | A graph convolutional layer that models gene-to-gene interactions (G2G).
26 | lin : torch.nn.Linear
27 | A linear layer that models intracellular gene regulation.
28 | n_neighbors : int
29 | The number of neighbours for spatial graph construction.
30 | gex : torch.nn.Parameter or None
31 | Stores the gene expression matrix used for the forward pass. Set to None initially.
32 |
33 | Methods
34 | -------
35 | set_g2g(g2g)
36 | Sets the gene-to-gene (G2G) interaction matrix artificially.
37 | set_g2g_intra(g2g_intra)
38 | Sets the intracellular regulation matrix artificially.
39 | set_gex(gex)
40 | Sets the gene expression matrix artificially.
41 | forward(edge_index, batch)
42 | Forward pass to compute the gene-to-gene and intracellular messages,
43 | and the log partition function estimate.
44 | log_Z_mft(edge_index, batch)
45 | Computes the Mean Field Theory (MFT) approximation to the partition function.
46 | z_interaction(num_spots, g)
47 | Provides an approximation for the interaction term in the partition function
48 | to prevent numerical instability due to exploding exponentials.
49 |
50 | Examples
51 | --------
52 |     >>> model = celcomen(input_dim=1000, output_dim=1000, n_neighbors=6, seed=42)
53 |     >>> edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
54 |     >>> batch = torch.zeros(100, dtype=torch.long)  # all 100 cells belong to one graph
55 | >>> model.set_gex(torch.randn(100, 1000))
56 | >>> msg, msg_intra, log_z_mft = model(edge_index, batch)
57 | >>> print(log_z_mft)
58 | """
59 |
60 | def __init__(self, input_dim, output_dim, n_neighbors, seed=0):
61 | """
62 | Initializes the celcomen model with a graph convolution layer and a linear
63 | layer for gene-to-gene and intracellular regulation, respectively.
64 |
65 | Parameters
66 | ----------
67 | input_dim : int
68 | Dimensionality of the input features.
69 | output_dim : int
70 | Dimensionality of the output features.
71 | n_neighbors : int
72 | Number of neighbours for constructing the spatial graph.
73 | seed : int, optional
74 | Random seed for reproducibility (default is 0).
75 | """
76 | super(celcomen, self).__init__()
77 | # define the seed
78 | torch.manual_seed(seed)
79 | # set up the graph convolution
80 | self.conv1 = GCNConv(input_dim, output_dim, add_self_loops=False)
81 | # set up the linear layer for intracellular gene regulation
82 | self.lin = torch.nn.Linear(input_dim, output_dim)
83 | # define the neighbors
84 | self.n_neighbors = n_neighbors
85 | # define a tracking variable for the gene expression x matrix
86 | self.gex = None
87 |
88 | # define a function to artificially set the g2g matrix
89 | def set_g2g(self, g2g):
90 | """
91 | Artificially sets the gene-to-gene (G2G) interaction matrix.
92 |
93 | Parameters
94 | ----------
95 | g2g : torch.Tensor
96 | A matrix representing gene-to-gene interactions to be used for graph convolution.
97 | """
98 | # set the weight as the input
99 | self.conv1.lin.weight = torch.nn.Parameter(g2g, requires_grad=True)
100 | # and then set the bias as all zeros
101 | self.conv1.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(len(g2g)).astype('float32')), requires_grad=False)
102 |
103 |     # define a function to artificially set the intracellular g2g matrix
104 | def set_g2g_intra(self, g2g_intra):
105 | """
106 | Artificially sets the intracellular gene regulation matrix.
107 |
108 | Parameters
109 | ----------
110 | g2g_intra : torch.Tensor
111 | A matrix representing intracellular gene regulation interactions.
112 | """
113 | # set the weight as the input
114 | self.lin.weight = torch.nn.Parameter(g2g_intra, requires_grad=True)
115 | # and then set the bias as all zeros
116 | self.lin.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(len(g2g_intra)).astype('float32')), requires_grad=False)
117 |
118 |     # define a function to artificially set the gene expression (gex) matrix
119 | def set_gex(self, gex):
120 | """
121 | Sets the gene expression matrix to be used during the forward pass.
122 |
123 | Parameters
124 | ----------
125 | gex : torch.Tensor
126 | A matrix representing the gene expression of the cells.
127 | """
128 | self.gex = torch.nn.Parameter(gex, requires_grad=False)
129 |
130 | # define the forward pass
131 | def forward(self, edge_index, batch):
132 | """
133 | Forward pass for the model, computing gene-to-gene and intracellular messages,
134 | and estimating the log partition function using Mean Field Theory (MFT).
135 |
136 | Parameters
137 | ----------
138 | edge_index : torch.Tensor
139 | Tensor representing the graph edges (connectivity between nodes/cells).
140 | batch : torch.Tensor
141 |             Batch assignment vector for the nodes; currently not used by the computation.
142 |
143 | Returns
144 | -------
145 | msg : torch.Tensor
146 | The message propagated between cells based on gene-to-gene interactions.
147 | msg_intra : torch.Tensor
148 | The message based on intracellular gene regulation.
149 | log_z_mft : torch.Tensor
150 | The Mean Field Theory approximation to the log partition function.
151 | """
152 | # compute the message
153 | msg = self.conv1(self.gex, edge_index)
154 | # compute intracellular message
155 | msg_intra = self.lin(self.gex)
156 | # compute the log z mft
157 | log_z_mft = self.log_Z_mft(edge_index, batch)
158 | return msg, msg_intra, log_z_mft
159 |
160 | # define approximation function
161 | def log_Z_mft(self, edge_index, batch):
162 | """
163 |         Computes the Mean Field Theory (MFT) approximation to the log partition function, i.e. the log of the
164 |         normalising constant of the model's distribution over gene expression states.
165 |
166 | Parameters
167 | ----------
168 |         edge_index : torch.Tensor
169 |             Tensor representing the graph edges; currently not used by this method.
170 |         batch : torch.Tensor
171 |             Batch assignment vector for the nodes; currently not used by this method.
172 |
173 | Returns
174 | -------
175 | log_z_mft : torch.Tensor
176 | The log partition function estimated using Mean Field Theory (MFT).
177 | """
178 | # retrieve number of spots
179 | num_spots = self.gex.shape[0]
180 | # calculate mean gene expression
181 | mean_genes = torch.mean(self.gex, axis=0).reshape(-1,1) # the mean should be per connected graph
182 |         # g: norm of (n_neighbors * g2g + 2 * g2g_intra) applied to the mean expression vector
183 | g = torch.norm(torch.mm( self.n_neighbors*self.conv1.lin.weight + 2*self.lin.weight, mean_genes))
184 | # calculate the contribution for mean values
185 | z_mean = - num_spots * torch.mm(torch.mm(torch.t(mean_genes), self.lin.weight + 0.5 * self.n_neighbors * self.conv1.lin.weight), mean_genes)
186 | # calculate the contribution gene interactions
187 | z_interaction = self.z_interaction(num_spots=num_spots, g=g)
188 | # add the two contributions
189 | log_z_mft = z_mean + z_interaction
190 | return log_z_mft
191 |
192 | def z_interaction(self, num_spots, g):
193 | """
194 | Avoids exploding exponentials in the partition function approximation by
195 | returning an approximate interaction term.
196 |
197 | Parameters
198 | ----------
199 | num_spots : int
200 | Number of spots (cells) in the dataset.
201 | g : torch.Tensor
202 | Norm of the sum of mean gene expressions weighted by gene-to-gene interactions.
203 |
204 | Returns
205 | -------
206 | z_interaction : torch.Tensor
207 | The approximated interaction term for the partition function.
208 | """
209 | if g>20:
210 | z_interaction = num_spots * ( g - torch.log( g) )
211 | else:
212 | z_interaction = num_spots * torch.log((torch.exp( g) - torch.exp(- g))/( g))
213 | return z_interaction
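The arithmetic in `log_Z_mft` and `z_interaction` can be restated with plain NumPy arrays, which makes the shapes explicit: the mean expression vector m has one entry per gene, so both gene-to-gene matrices must be square. Writing G for the intercellular matrix (`conv1.lin.weight`), G_intra for the intracellular one (`lin.weight`), N for the number of spots and k for `n_neighbors`, the module computes g = ||(k G + 2 G_intra) m||, a mean term -N m^T (G_intra + k/2 G) m, and an interaction term N log((e^g - e^-g) / g), which for g > 20 is replaced by N (g - log g) because e^-g is then negligible. The functions below are standalone illustrations that only borrow the method names; unlike the module, `log_z_mft` here returns a float rather than a (1, 1) tensor.

```python
import numpy as np


def z_interaction(num_spots: int, g: float, threshold: float = 20.0) -> float:
    """Interaction term num_spots * log((e^g - e^-g) / g), with a large-g shortcut."""
    if g > threshold:
        # e^-g is negligible, so log(e^g / g) = g - log(g) avoids overflow in exp(g)
        return num_spots * (g - np.log(g))
    return num_spots * np.log((np.exp(g) - np.exp(-g)) / g)


def log_z_mft(gex: np.ndarray, g2g: np.ndarray, g2g_intra: np.ndarray, n_neighbors: int) -> float:
    """Mean-field log partition function for a (num_spots, n_genes) expression matrix."""
    num_spots = gex.shape[0]
    mean_genes = gex.mean(axis=0).reshape(-1, 1)  # (n_genes, 1)
    # norm of the combined interaction matrices applied to the mean expression
    g = float(np.linalg.norm((n_neighbors * g2g + 2 * g2g_intra) @ mean_genes))
    # contribution of the mean values (requires square g2g matrices)
    z_mean = -num_spots * (mean_genes.T @ (g2g_intra + 0.5 * n_neighbors * g2g) @ mean_genes).item()
    return z_mean + z_interaction(num_spots, g)


rng = np.random.default_rng(0)
n_genes = 5
gex = rng.normal(size=(10, n_genes))
gex /= np.linalg.norm(gex, axis=1, keepdims=True)  # unit-norm expression per spot
print(log_z_mft(gex, 0.1 * np.eye(n_genes), 0.05 * np.eye(n_genes), n_neighbors=6))
```

Both branches of `z_interaction` agree closely around the threshold, so the switch at g = 20 only trades an overflow-prone expression for its asymptotic form.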
--------------------------------------------------------------------------------
/celcomen/models/simcomen.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.nn import GCNConv
2 | from sklearn.neighbors import kneighbors_graph
3 | import torch
4 | import numpy as np
5 | #from ..utils.helpers import calc_gex
6 |
7 | # define the number of neighbors (six for visium)
8 | n_neighbors = 6
9 | # define the simcomen class
10 | class simcomen(torch.nn.Module):
11 | """
12 | A k-hop Graph Neural Network model for predicting spatial counterfactuals, such as localised perturbations.
13 |
14 | Parameters
15 | ----------
16 | input_dim : int
17 | The dimensionality of the input gene expression features.
18 | output_dim : int
19 | The dimensionality of the output features after processing through graph convolution and linear layers.
20 | n_neighbors : int
21 | The number of neighbors to use in constructing the k-nearest neighbor graph.
22 | seed : int, optional
23 | Random seed for reproducibility, default is 0.
24 |
25 | Attributes
26 | ----------
27 | conv1 : GCNConv
28 | Graph convolutional layer for gene-to-gene (G2G) interactions.
29 | lin : torch.nn.Linear
30 | Linear layer for intracellular gene regulation.
31 | n_neighbors : int
32 | Number of spatial neighbors used for constructing the graph.
33 | sphex : torch.nn.Parameter or None
34 | Spherical gene expression matrix, set via `set_sphex`.
35 | gex : torch.nn.Parameter or None
36 | Gene expression matrix, calculated from the spherical expression matrix.
37 | output_dim : int
38 | Output dimensionality of the model.
39 |
40 | Methods
41 | -------
42 | set_g2g(g2g)
43 | Sets the gene-to-gene (G2G) interaction matrix artificially.
44 | set_g2g_intra(g2g_intra)
45 | Sets the intracellular gene regulation matrix artificially.
46 | set_sphex(sphex)
47 | Sets the spherical gene expression matrix artificially.
48 | forward(edge_index, batch)
49 | Forward pass of the model, calculating messages from gene-to-gene interactions,
50 | intracellular interactions, and the log partition function (log(Z_mft)).
51 | log_Z_mft(edge_index, batch)
52 | Computes the Mean Field Theory (MFT) approximation to the partition function for the current gene expressions.
53 | z_interaction(num_spots, g)
54 | Calculates the interaction term for the partition function while avoiding numerical instability.
55 | calc_gex(sphex)
56 | Converts the spherical gene expression matrix into a regular gene expression matrix.
57 | calc_sphex(gex)
58 | Converts the regular gene expression matrix into a spherical gene expression matrix.
59 | get_pos(n_x, n_y)
60 | Generates a 2D hexagonal grid of positions for spatial modeling.
61 | normalize_g2g(g2g)
62 | Symmetrizes and normalizes the gene-to-gene interaction matrix.
63 |
64 | Examples
65 | --------
66 |     >>> model = simcomen(input_dim=1000, output_dim=1000, n_neighbors=6, seed=42)
67 |     >>> edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
68 |     >>> batch = torch.zeros(100, dtype=torch.long)  # all 100 cells belong to one graph
69 | >>> model.set_sphex(torch.randn(100, 1000))
70 | >>> msg, msg_intra, log_z_mft = model(edge_index, batch)
71 | >>> print(log_z_mft)
72 | """
73 | def __init__(self, input_dim, output_dim, n_neighbors, seed=0):
74 | """
75 | Initializes the `simcomen` model with a graph convolution layer and a linear layer
76 | for gene-to-gene interactions and intracellular regulation, respectively.
77 |
78 | Parameters
79 | ----------
80 | input_dim : int
81 | Dimensionality of the input features.
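82 | output_dim : int
83 | Dimensionality of the output features. `log_Z_mft` uses both weight matrices in a quadratic form with the mean expression, so in practice this must equal `input_dim`.
84 | n_neighbors : int
85 | Number of spatial neighbors per cell, used as the neighbor count in the mean-field approximation (the graph itself is passed to `forward` via `edge_index`).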
86 | seed : int, optional
87 | Random seed for reproducibility (default is 0).
88 | """
89 | super(simcomen, self).__init__()
90 | # define the seed
91 | torch.manual_seed(seed)
92 | # set up the graph convolution
93 | self.conv1 = GCNConv(input_dim, output_dim, add_self_loops=False)
94 | # set up the linear layer for intracellular gene regulation
95 | self.lin = torch.nn.Linear(input_dim, output_dim)
96 | # define the neighbors
97 | self.n_neighbors = n_neighbors
98 | # define a tracking variable for the gene expression x matrix
99 | self.sphex = None
100 | self.gex = None
101 | self.output_dim = output_dim
102 |
103 | # define a function to artificially set the g2g matrix
104 | def set_g2g(self, g2g):
105 | """
106 | Artificially sets the gene-to-gene (G2G) interaction matrix.
107 |
108 | Parameters
109 | ----------
110 | g2g : torch.Tensor
111 | A tensor representing gene-to-gene interactions to be used in the graph convolution.
112 | """
113 | # set the weight as the input
114 | self.conv1.lin.weight = torch.nn.Parameter(g2g, requires_grad=False)
115 | # and then set the bias as all zeros
116 | self.conv1.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(self.output_dim).astype('float32')), requires_grad=False)
117 |
118 | # define a function to artificially set the g2g_intra matrix
119 | def set_g2g_intra(self, g2g_intra):
120 | """
121 | Artificially sets the intracellular regulation matrix.
122 |
123 | Parameters
124 | ----------
125 | g2g_intra : torch.Tensor
126 | A tensor representing intracellular gene regulation interactions.
127 | """
128 | # set the weight as the input
129 | self.lin.weight = torch.nn.Parameter(g2g_intra, requires_grad=False)
130 | # and then set the bias as all zeros
131 | self.lin.bias = torch.nn.Parameter(torch.from_numpy(np.zeros(len(g2g_intra)).astype('float32')), requires_grad=False)
132 |
133 | # define a function to artificially set the sphex matrix
134 | def set_sphex(self, sphex):
135 | """
136 | Sets the spherical gene expression matrix for the forward pass.
137 |
138 | Parameters
139 | ----------
140 | sphex : torch.Tensor
141 | A tensor representing the spherical expression matrix.
142 | """
143 | self.sphex = torch.nn.Parameter(sphex, requires_grad=True)
144 |
145 | # define the forward pass
146 | def forward(self, edge_index, batch):
147 | """
148 | Forward pass of the model, calculates the messages between nodes using gene-to-gene interactions
149 | and intracellular gene regulation. Also calculates the log partition function using Mean Field Theory.
150 |
151 | Parameters
152 | ----------
153 | edge_index : torch.Tensor
154 | Tensor representing the graph edges (connections between nodes/cells).
155 | batch : torch.Tensor or int
156 | Batch indicator, forwarded to `log_Z_mft`, which does not currently use it (the training loops pass the integer 1).
157 |
158 | Returns
159 | -------
160 | msg : torch.Tensor
161 | Message passed between nodes based on gene-to-gene interactions.
162 | msg_intra : torch.Tensor
163 | Message passed within nodes based on intracellular gene regulation.
164 | log_z_mft : torch.Tensor
165 | Mean Field Theory approximation of the log partition function.
166 | """
167 | # compute the gex
168 | self.gex = self.calc_gex(self.sphex)
169 | #print( f"self.gex device is {self.gex.device}")
170 | #print( f"edge_index device is {edge_index.device}")
171 | # compute the message
172 | msg = self.conv1(self.gex, edge_index)
173 | # compute intracellular message
174 | msg_intra = self.lin(self.gex)
175 | # compute the log z mft
176 | log_z_mft = self.log_Z_mft(edge_index, batch)
177 | return msg, msg_intra, log_z_mft
178 |
179 | # define approximation function
180 | def log_Z_mft(self, edge_index, batch):
181 | """
182 | Computes the Mean Field Theory (MFT) approximation of the log partition function.
183 | This function assumes that gene expression values are close to their mean across the spatial slide.
184 |
185 | Parameters
186 | ----------
187 | edge_index : torch.Tensor
188 | Tensor representing the graph edges (connections between nodes/cells).
189 | batch : torch.Tensor or int
190 | Batch indicator; not used by the current computation, which averages over all spots.
191 |
192 | Returns
193 | -------
194 | log_z_mft : torch.Tensor
195 | Mean Field Theory approximation of the log partition function.
196 | """
197 | # retrieve number of spots
198 | num_spots = self.gex.shape[0]
199 | # calculate mean gene expression
200 | mean_genes = torch.mean(self.gex, axis=0).reshape(-1,1) # the mean should be per connected graph
201 | # calculate the norm of the sum of mean genes
202 | g = torch.norm(torch.mm( self.n_neighbors*self.conv1.lin.weight + 2*self.lin.weight, mean_genes))
203 | # calculate the contribution for mean values
204 | z_mean = - num_spots * torch.mm(torch.mm(torch.t(mean_genes), self.lin.weight + 0.5 * self.n_neighbors * self.conv1.lin.weight), mean_genes)
205 | # calculate the contribution of gene interactions
206 | z_interaction = self.z_interaction(num_spots=num_spots, g=g)
207 | # add the two contributions
208 | log_z_mft = z_mean + z_interaction
209 | return log_z_mft
210 |
211 | def z_interaction(self, num_spots, g):
212 | """
213 | Calculates the interaction term for the partition function approximation, avoiding exploding exponentials.
214 |
215 | Parameters
216 | ----------
217 | num_spots : int
218 | Number of spots (cells) in the dataset.
219 | g : torch.Tensor
220 | Norm of the mean gene expression vector after multiplication by the combined weights `n_neighbors * G2G + 2 * intracellular`.
221 |
222 | Returns
223 | -------
224 | z_interaction : torch.Tensor
225 | Approximated interaction term for the partition function.
226 | """
227 | if g>20:
228 | z_interaction = num_spots * ( g - torch.log( g) )
229 | else:
230 | z_interaction = num_spots * torch.log((torch.exp( g) - torch.exp(- g))/( g))
231 | return z_interaction
232 |
233 | # define a function to derive the gex from the sphex
234 | def calc_gex(self, sphex):
235 | """
236 | Converts the spherical expression matrix into a regular gene expression matrix.
237 |
238 | Parameters
239 | ----------
240 | sphex : torch.Tensor
241 | The spherical gene expression matrix.
242 |
243 | Returns
244 | -------
245 | gex : torch.Tensor
246 | The converted regular gene expression matrix.
247 | """
248 | # setup the gex
249 | n_genes = sphex.shape[1]+1
250 | #gex = torch.from_numpy(np.zeros((sphex.shape[0], n_genes)).astype('float32'), device=next(self.parameters()).device)
251 | gex = torch.zeros((sphex.shape[0], n_genes), dtype=torch.float32, device=next(self.parameters()).device)
252 | # compute the gex
253 | for idx in range(n_genes):
254 | if idx == n_genes-1:
255 | gex[:,idx] = torch.sin(sphex[:,idx-1])
256 | else:
257 | gex[:,idx] = torch.cos(sphex[:,idx])
258 | for idx_ in range(idx):
259 | gex[:,idx] *= torch.sin(sphex[:,idx_])
260 | return torch.nan_to_num(gex)
261 |
262 | # define a function to gather positions
263 | def get_pos(n_x, n_y):
264 | """
265 | Generates a 2D hexagonal grid of positions for spatial modelling.
266 |
267 | Parameters
268 | ----------
269 | n_x : int
270 | Number of positions along the x-axis.
271 | n_y : int
272 | Number of positions along the y-axis.
273 |
274 | Returns
275 | -------
276 | pos : numpy.ndarray
277 | Array of 2D positions for the grid.
278 | """
279 | xs = np.array([np.arange(0, n_x) + 0.5 if idx % 2 == 0 else np.arange(0, n_x) for idx in range(n_y)])
280 | # row spacing sqrt(1.25): same-row neighbors are 1 apart, adjacent-row neighbors sqrt(1.5) apart
281 | y_step = np.sqrt(1**2+0.5**2)
282 | ys = np.array([[y_step * idy] * n_x for idy in range(n_y)])
283 | # define the positions
284 | pos = np.vstack([xs.flatten(), ys.flatten()]).T
285 | return pos
286 |
287 |
288 | # define a function to normalize the g2g
289 | def normalize_g2g(g2g):
290 | """
291 | Symmetrizes and normalizes the gene-to-gene interaction matrix.
292 |
293 | Parameters
294 | ----------
295 | g2g : numpy.ndarray
296 | The gene-to-gene interaction matrix.
297 |
298 | Returns
299 | -------
300 | g2g : numpy.ndarray
301 | The normalized and symmetrized gene-to-gene interaction matrix.
302 | """
303 | # symmetrize the values
304 | g2g = (g2g + g2g.T) / 2
305 | # force them to be between 0-1
306 | g2g[g2g < 0] = 0
307 | g2g[g2g > 1] = 1
308 | # force the central line to be 1
309 | for idx in range(len(g2g)):
310 | g2g[idx, idx] = 1
311 | return g2g
312 |
313 | # define a function to derive the sphex from the gex
314 | def calc_sphex(self, gex):
315 | """
316 | Converts the regular gene expression matrix into a spherical gene expression matrix.
317 |
318 | Parameters
319 | ----------
320 | gex : torch.Tensor
321 | The regular gene expression matrix.
322 |
323 | Returns
324 | -------
325 | sphex : torch.Tensor
326 | The converted spherical gene expression matrix.
327 | """
328 | # set up the sphex
329 | n_sgenes = gex.shape[1]-1
330 | #sphex = torch.from_numpy(np.zeros((gex.shape[0], n_sgenes)).astype('float32'), device=next(self.parameters()).device)
331 | sphex = torch.zeros((gex.shape[0], n_sgenes), dtype=torch.float32, device=next(self.parameters()).device)
332 | # compute the sphex
333 | for idx in range(n_sgenes):
334 | sphex[:,idx] = gex[:,idx]
335 | for idx_ in range(idx):
336 | sphex[:,idx] /= torch.sin(sphex[:,idx_])
337 | sphex[:,idx] = torch.arccos(sphex[:,idx])
338 | return sphex
339 |
340 |
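`z_interaction` evaluates the per-spot term `log((exp(g) - exp(-g)) / g)` directly for `g <= 20` and switches to `g - log(g)` above that to avoid overflowing `exp(g)`. Factoring out `exp(g)` gives the exact identity `log((exp(g) - exp(-g)) / g) = g + log(1 - exp(-2g)) - log(g)`, which never overflows and shows that the dropped term at the threshold, `log(1 - exp(-40))`, is about -4e-18. The pure-Python sketch below (hypothetical function names, no torch, not the package's implementation) compares the branch-free form with the two-branch version:

```python
import math


def log_z_interaction_per_spot(g: float) -> float:
    """log((exp(g) - exp(-g)) / g) for g > 0, evaluated without overflow."""
    if g <= 0:
        raise ValueError("g must be positive")
    # -expm1(-2g) == 1 - exp(-2g), accurate even when g is tiny
    return g + math.log(-math.expm1(-2.0 * g)) - math.log(g)


def z_interaction_branching(g: float) -> float:
    """Per-spot version of the two branches used in z_interaction (threshold g > 20)."""
    if g > 20:
        return g - math.log(g)
    return math.log((math.exp(g) - math.exp(-g)) / g)


for g in (0.5, 5.0, 19.9, 20.1, 50.0):
    print(g, log_z_interaction_per_spot(g), z_interaction_branching(g))
```

Multiplying either value by `num_spots` gives the quantity returned by `z_interaction`. The `expm1` form also stays accurate for very small `g` (tending to `log(2)`), where the direct difference of exponentials loses precision.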
--------------------------------------------------------------------------------
/celcomen/training_plan/.ipynb_checkpoints/train-checkpoint.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import numpy as np
3 | import torch
4 | from ..utils.helpers import normalize_g2g
5 | # def train(model, optimizer, train_loader, num_epochs, seed=1, device="cpu"):
6 |
7 | # if device is None:
8 | # device = torch.device(device)
9 |
10 | # torch.manual_seed(seed)
11 |
12 | # for epoch in range(num_epochs):
13 |
14 | # model = model.train()
15 | # for batch_idx, (features, labels) in enumerate(train_loader):
16 |
17 | # features, labels = features.to(device), labels.to(device)
18 |
19 | # logits = model(features)
20 |
21 | # loss = F.cross_entropy(logits, labels)
22 | # optimizer.zero_grad()
23 | # loss.backward()
24 | # optimizer.step()
25 |
26 |
27 | def train(num_epochs, learning_rate, model, loader, zmft_scalar=1e-1, seed=1, device="cpu"):
28 |
29 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0)
30 | losses = []
31 | model.train()
32 | torch.manual_seed(seed)
33 |
34 | for epoch in range(num_epochs):
35 | losses_= []
36 |
37 | for data in loader:
38 | # move data to device
39 | data = data.to(device)
40 | # train loader # Iterate in batches over the training dataset.
41 | # set the appropriate gex
42 | model.set_gex(data.x)
43 | # derive the message as well as the mean field approximation
44 | msg, msg_intra, log_z_mft = model(data.edge_index, 1)
45 | # compute the loss and track it
46 | loss = -(-log_z_mft + zmft_scalar * torch.trace(torch.mm(msg, torch.t(model.gex))) + zmft_scalar * torch.trace(torch.mm(msg_intra, torch.t(model.gex))) )
47 | losses_.append(loss.detach().numpy()[0][0])
48 | # derive the gradients, update, and clear
49 | loss.backward()
50 | optimizer.step()
51 | optimizer.zero_grad()
52 | # repeatedly force a normalization
53 | model.conv1.lin.weight = torch.nn.Parameter(normalize_g2g(model.conv1.lin.weight), requires_grad=True)
54 | model.lin.weight = torch.nn.Parameter(normalize_g2g(model.lin.weight), requires_grad=True)
55 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0)
56 |
57 | print(f"Loss={np.mean(losses_)}")
58 | losses.append(np.mean(losses_))
59 |
60 | return losses
61 |
62 |
63 |
--------------------------------------------------------------------------------
/celcomen/training_plan/train.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import numpy as np
3 | import torch
4 | from ..utils.helpers import normalize_g2g, calc_sphex, calc_gex
5 |
6 | def train(epochs, learning_rate, model, loader, zmft_scalar=1e-1, seed=1, device="cpu", verbose=False):
7 | """
8 | Trains the model using a stochastic gradient descent (SGD) optimizer over the specified number of epochs.
9 |
10 | During training, the model passes messages between neighboring cells, weighted by the gene-to-gene (G2G) interaction matrix,
11 | and approximates the log partition function with Mean Field Theory (MFT).
12 | The G2G and intracellular regulation matrices are re-normalized with `normalize_g2g` after each optimizer step.
13 |
14 | Parameters
15 | ----------
16 | epochs : int
17 | Number of training epochs to run.
18 | learning_rate : float
19 | Learning rate for the SGD optimizer.
20 | model : torch.nn.Module
21 | The model to be trained, which includes a graph convolutional layer and a linear layer.
22 | loader : torch_geometric.loader.DataLoader
23 | DataLoader that provides the data for each batch during training.
24 | zmft_scalar : float, optional
25 | Scalar weighting the G2G and intracellular interaction (trace) terms relative to the log partition function term in the loss. Default is 1e-1.
26 | seed : int, optional
27 | Seed for random number generation to ensure reproducibility. Default is 1.
28 | device : str, optional
29 | Device to use for training, e.g., "cpu" or "cuda". Default is "cpu".
30 | verbose : bool, optional
31 | If True, prints the loss at each epoch. Default is False.
32 |
33 | Returns
34 | -------
35 | losses : list of float
36 | List of losses recorded at each epoch.
37 |
38 | Examples
39 | --------
40 | >>> losses = train(epochs=100, learning_rate=0.01, model=my_model, loader=my_loader, zmft_scalar=1e-1, seed=42, device="cuda")
41 | >>> print(losses[-1]) # Final loss
42 | """
43 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0)
44 | losses = []
45 | model.train()
46 | torch.manual_seed(seed)
47 |
48 | for epoch in tqdm(range(epochs), total=epochs):
49 | losses_= []
50 |
51 | for data in loader:
52 | # move data to device
53 | data = data.to(device)
54 | # each iteration processes one batch of the training dataset
55 | # set the appropriate gex
56 | model.set_gex(data.x)
57 | # derive the message as well as the mean field approximation
58 | msg, msg_intra, log_z_mft = model(data.edge_index, 1)
59 | # compute the loss and track it
60 | loss = -(-log_z_mft + zmft_scalar * torch.trace(torch.mm(msg, torch.t(model.gex))) + zmft_scalar * torch.trace(torch.mm(msg_intra, torch.t(model.gex))) )
61 | if device=="cpu":
62 | losses_.append(loss.detach().numpy()[0][0])
63 | else:
64 | losses_.append(loss.detach().cpu().numpy()[0][0])
65 | # derive the gradients, update, and clear
66 | loss.backward()
67 | optimizer.step()
68 | optimizer.zero_grad()
69 | # repeatedly force a normalization
70 | model.conv1.lin.weight = torch.nn.Parameter(normalize_g2g(model.conv1.lin.weight), requires_grad=True)
71 | model.lin.weight = torch.nn.Parameter(normalize_g2g(model.lin.weight), requires_grad=True)
72 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0)
73 |
74 | if verbose: print(f"Epoch={epoch} | Loss={np.mean(losses_)}")
75 | losses.append(np.mean(losses_))
76 |
77 | return losses
78 |
79 |
80 | def train_simcomen(epochs, learning_rate, model, edge_index, zmft_scalar=1e-1, seed=1, device="cpu", verbose=False):
81 | """
82 | Trains the `simcomen` model using stochastic gradient descent (SGD) over the specified number of epochs.
83 |
84 | The training process computes the Mean Field Theory (MFT) approximation of the log partition function
85 | and updates the model's trainable parameters; when the interaction matrices were fixed with `set_g2g` and `set_g2g_intra`, only the
86 | spherical gene expression matrix (`sphex`) is optimized. Unlike `train`, the G2G matrices are not re-normalized after each step.
87 |
88 | Parameters
89 | ----------
90 | epochs : int
91 | Number of training epochs to run.
92 | learning_rate : float
93 | Learning rate for the SGD optimizer.
94 | model : torch.nn.Module
95 | The `simcomen` model to be trained, which includes graph convolution and linear layers.
96 | edge_index : torch.Tensor
97 | Tensor representing the edges in the graph, i.e., the connections between nodes (cells).
98 | zmft_scalar : float, optional
99 | Scalar weighting the G2G and intracellular interaction (trace) terms relative to the log partition function term in the loss. Default is 1e-1.
100 | seed : int, optional
101 | Seed for random number generation to ensure reproducibility. Default is 1.
102 | device : str, optional
103 | Device to use for training, e.g., "cpu" or "cuda". Default is "cpu".
104 | verbose : bool, optional
105 | If True, prints the loss at each epoch. Default is False.
106 |
107 | Returns
108 | -------
109 | losses : list of float
110 | List of losses recorded at each epoch.
111 |
112 | Examples
113 | --------
114 | >>> losses = train_simcomen(epochs=100, learning_rate=0.01, model=my_model, edge_index=my_edge_index, zmft_scalar=1e-1, seed=42, device="cuda")
115 | >>> print(losses[-1]) # Final loss
116 | """
117 | # set up the optimizer
118 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0)
119 | # keep track of the losses per data object
120 | losses = []
121 | model.train()
122 | torch.manual_seed(seed)
123 |
124 | tmp_gexs = []
125 | # work through epochs
126 | for epoch in tqdm(range(epochs), total=epochs):
127 | # derive the message as well as the mean field approximation
128 | msg, msg_intra, log_z_mft = model(edge_index, 1)
129 | if (epoch % 5) == 0:
130 | if device=="cpu":
131 | tmp_gex = model.gex.clone().detach().numpy()
132 | else:
133 | tmp_gex = model.gex.clone().detach().cpu().numpy()
134 | tmp_gexs.append(tmp_gex)
135 | # compute the loss and track it
136 | loss = -(-log_z_mft + zmft_scalar * torch.trace(torch.mm(msg, torch.t(model.gex))) + zmft_scalar * torch.trace(torch.mm(msg_intra, torch.t(model.gex))) )
137 | if device=="cpu":
138 | losses.append(loss.detach().numpy()[0][0])
139 | else:
140 | losses.append(loss.detach().cpu().numpy()[0][0])
141 | if verbose: print(f"Epoch={epoch} | Loss={losses[-1]}")
142 | # derive the gradients, update, and clear
143 |
144 | loss.backward()
145 | optimizer.step()
146 | optimizer.zero_grad()
147 |
148 | return losses
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
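Both `train` and `train_simcomen` build the loss from `torch.trace(torch.mm(msg, torch.t(model.gex)))`. Because tr(M Gᵀ) equals the sum over rows of the dot products ⟨m_i, g_i⟩, each trace term is the sum over cells of the dot product between a cell's message and its own expression vector. The pure-Python sketch below (hypothetical helper names, no torch dependency) checks that identity and reproduces the loss arithmetic:

```python
def matmul(a, b):
    """Plain matrix product of nested lists (a: n x k, b: k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def trace(a):
    return sum(a[i][i] for i in range(len(a)))


def interaction_energy(msg, gex):
    """tr(msg @ gex.T) written as a sum of per-cell dot products <msg_i, gex_i>."""
    return sum(sum(m * x for m, x in zip(m_row, x_row)) for m_row, x_row in zip(msg, gex))


def celcomen_style_loss(log_z_mft, msg, msg_intra, gex, zmft_scalar=0.1):
    """Same arithmetic as -(-log_z_mft + s*tr(msg gex^T) + s*tr(msg_intra gex^T))."""
    return log_z_mft - zmft_scalar * (interaction_energy(msg, gex)
                                      + interaction_energy(msg_intra, gex))


# usage: 3 cells x 2 genes
msg = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
gex = [[0.5, 0.5], [1.0, 0.0], [0.0, 1.0]]
print(trace(matmul(msg, transpose(gex))), interaction_energy(msg, gex))  # both 10.5
```

In torch, the same per-cell sum can be written as `(msg * model.gex).sum()`, which avoids materializing the n_cells × n_cells product; whether that matters depends on dataset size, and it is not what the package currently does.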
--------------------------------------------------------------------------------
/celcomen/utils/.ipynb_checkpoints/helpers-checkpoint.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | # define a function to derive the gex from the sphex
5 | def calc_gex(sphex):
6 | """
7 | Calculates the gene expression matrix from the spherical expression matrix.
8 | """
9 | # setup the gex
10 | n_genes = sphex.shape[1]+1
11 | gex = torch.from_numpy(np.zeros((sphex.shape[0], n_genes)).astype('float32'))
12 | # compute the gex
13 | for idx in range(n_genes):
14 | if idx == n_genes-1:
15 | gex[:,idx] = torch.sin(sphex[:,idx-1])
16 | else:
17 | gex[:,idx] = torch.cos(sphex[:,idx])
18 | for idx_ in range(idx):
19 | gex[:,idx] *= torch.sin(sphex[:,idx_])
20 | return gex
21 |
22 | # define a function to gather positions
23 | def get_pos(n_x, n_y):
24 | # create the hex lattice
25 | xs = np.array([np.arange(0, n_x) + 0.5 if idx % 2 == 0 else np.arange(0, n_x) for idx in range(n_y)])
26 | # row spacing sqrt(1.25): same-row neighbors are 1 apart, adjacent-row neighbors sqrt(1.5) apart
27 | y_step = np.sqrt(1**2+0.5**2)
28 | ys = np.array([[y_step * idy] * n_x for idy in range(n_y)])
29 | # define the positions
30 | pos = np.vstack([xs.flatten(), ys.flatten()]).T
31 | return pos
32 |
33 |
34 | # define a function to normalize the g2g
35 | def normalize_g2g(g2g):
36 | """
37 | Addresses any small fluctuations in symmetrical weights
38 | """
39 | # symmetrize the values
40 | g2g = (g2g + g2g.T) / 2
41 | # force them to be between 0-1
42 | g2g[g2g < 0] = 0
43 | g2g[g2g > 1] = 1
44 | # force the central line to be 1
45 | for idx in range(len(g2g)):
46 | g2g[idx, idx] = 1
47 | return g2g
48 |
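In `get_pos`, `y_step = np.sqrt(1**2+0.5**2)` ≈ 1.118. With the half-unit x offset between rows, this places the two same-row neighbors at distance 1 and the four adjacent-row neighbors at sqrt(0.5² + 1.25) = sqrt(1.5) ≈ 1.225. A regular hexagonal lattice with unit nearest-neighbor spacing would instead use a row spacing of sqrt(1 - 0.5²) ≈ 0.866. Both layouts give each interior cell the same six nearest neighbors, so a 6-nearest-neighbor graph is unaffected, but radius-based graphs would differ. The pure-Python sketch below (hypothetical helper names, no numpy) checks both spacings:

```python
import math


def hex_positions(n_x, n_y, y_step):
    """Pure-Python mirror of get_pos: even rows shifted +0.5 in x, rows y_step apart."""
    pos = []
    for row in range(n_y):
        shift = 0.5 if row % 2 == 0 else 0.0
        for col in range(n_x):
            pos.append((col + shift, row * y_step))
    return pos


def nearest_distances(pos, index, k=6):
    """Sorted distances from pos[index] to its k nearest other points."""
    p = pos[index]
    return sorted(math.dist(p, q) for i, q in enumerate(pos) if i != index)[:k]


# interior point of a 5 x 5 grid: row 2, column 2 (row-major index, as in get_pos)
centre = 2 * 5 + 2
as_in_get_pos = nearest_distances(hex_positions(5, 5, math.sqrt(1**2 + 0.5**2)), centre)
unit_hexagonal = nearest_distances(hex_positions(5, 5, math.sqrt(1 - 0.5**2)), centre)
print(as_in_get_pos)   # two distances of 1.0, four of sqrt(1.5) ~ 1.2247
print(unit_hexagonal)  # six distances of ~1.0
```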
--------------------------------------------------------------------------------
/celcomen/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import helpers
2 |
--------------------------------------------------------------------------------
/celcomen/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | # define a function to derive the gex from the sphex
5 | def calc_gex(sphex):
6 | """
7 | Converts a spherical gene expression matrix into a standard gene expression matrix.
8 |
9 | The function takes a spherical expression matrix (`sphex`) and calculates the corresponding
10 | gene expression matrix (`gex`) using trigonometric functions such as sine and cosine.
11 | The last gene is computed using sine, and the others are computed using cosine, followed by
12 | multiplying by the sine of all preceding dimensions.
13 |
14 | Parameters
15 | ----------
16 | sphex : torch.Tensor
17 | The spherical gene expression matrix. Expected shape is (n_samples, n_features - 1).
18 |
19 | Returns
20 | -------
21 | gex : torch.Tensor
22 | The computed gene expression matrix of shape (n_samples, n_features), where `n_features` is one more than in `sphex`.
23 | All NaN values are replaced with 0.
24 |
25 | Examples
26 | --------
27 | >>> sphex = torch.tensor([[0.5, 0.6], [0.3, 0.4]])
28 | >>> gex = calc_gex(sphex)
29 | >>> gex.shape
30 | torch.Size([2, 3])
31 | >>> # the first column is cos(0.5) ≈ 0.8776 and cos(0.3) ≈ 0.9553
32 | """
33 | # setup the gex
34 | n_genes = sphex.shape[1]+1
35 | gex = torch.from_numpy(np.zeros((sphex.shape[0], n_genes)).astype('float32'))
36 | # compute the gex
37 | for idx in range(n_genes):
38 | if idx == n_genes-1:
39 | gex[:,idx] = torch.sin(sphex[:,idx-1])
40 | else:
41 | gex[:,idx] = torch.cos(sphex[:,idx])
42 | for idx_ in range(idx):
43 | gex[:,idx] *= torch.sin(sphex[:,idx_])
44 | return torch.nan_to_num(gex)
45 |
46 | # define a function to gather positions
47 | def get_pos(n_x, n_y):
48 | """
49 | Generates a 2D hexagonal grid of positions.
50 |
51 | This function creates a hexagonal lattice for a 2D grid, where the x-coordinates are adjusted
52 | for alternating rows. The y-coordinates are spaced based on a predefined step size derived
53 | from the geometry of a hexagonal grid.
54 |
55 | Parameters
56 | ----------
57 | n_x : int
58 | Number of positions along the x-axis.
59 | n_y : int
60 | Number of positions along the y-axis.
61 |
62 | Returns
63 | -------
64 | pos : numpy.ndarray
65 | A 2D array of shape (n_x * n_y, 2) representing the coordinates of the positions on the hexagonal grid.
66 |
67 | Examples
68 | --------
69 | >>> pos = get_pos(3, 3)
70 | >>> pos
71 | array([[0.5, 0. ],
72 | [1.5, 0. ],
73 | [2.5, 0. ],
74 | [0. , 1.11803399],
75 | [1. , 1.11803399],
76 | [2. , 1.11803399],
77 | [0.5, 2.23606798],
78 | [1.5, 2.23606798],
79 | [2.5, 2.23606798]])
80 | """
81 | # create the hex lattice
82 | xs = np.array([np.arange(0, n_x) + 0.5 if idx % 2 == 0 else np.arange(0, n_x) for idx in range(n_y)])
84 | # row spacing sqrt(1.25): same-row neighbors are 1 apart, adjacent-row neighbors sqrt(1.5) apart
84 | y_step = np.sqrt(1**2+0.5**2)
85 | ys = np.array([[y_step * idy] * n_x for idy in range(n_y)])
86 | # define the positions
87 | pos = np.vstack([xs.flatten(), ys.flatten()]).T
88 | return pos
89 |
90 |
91 | # define a function to normalize the g2g
92 | def normalize_g2g(g2g):
93 | """
94 | Symmetrizes and normalizes a gene-to-gene (G2G) interaction matrix.
95 |
96 | This function ensures that the matrix is symmetrical, normalizes values to be between 0 and 1,
97 | and forces the diagonal to be 1 (representing self-interactions).
98 |
99 | Parameters
100 | ----------
101 | g2g : numpy.ndarray
102 | The gene-to-gene interaction matrix, typically of shape (n_genes, n_genes).
103 |
104 | Returns
105 | -------
106 | g2g : numpy.ndarray
107 | The normalized and symmetrized gene-to-gene interaction matrix.
108 |
109 | Examples
110 | --------
111 | >>> g2g = np.array([[0.8, 0.2], [0.1, 0.7]])
112 | >>> normalized_g2g = normalize_g2g(g2g)
113 | >>> normalized_g2g
114 | array([[1. , 0.15],
115 | [0.15, 1. ]])
116 | """
117 | # symmetrize the values
118 | g2g = (g2g + g2g.T) / 2
119 | # force them to be between 0-1
120 | g2g[g2g < 0] = 0
121 | g2g[g2g > 1] = 1
122 | # force the central line to be 1
123 | for idx in range(len(g2g)):
124 | g2g[idx, idx] = 1
125 | return g2g
126 |
127 | # define a function to derive the sphex from the gex
128 | def calc_sphex(gex):
129 | """
130 | Converts a standard gene expression matrix into a spherical expression matrix.
131 |
132 | This function calculates the spherical representation of a gene expression matrix (`gex`): each coordinate is divided by
133 | the product of the sines of the angles already recovered, and `arccos` of the result gives the next angle.
134 |
135 | Parameters
136 | ----------
137 | gex : torch.Tensor
138 | The standard gene expression matrix, of shape (n_samples, n_genes).
139 |
140 | Returns
141 | -------
142 | sphex : torch.Tensor
143 | The spherical gene expression matrix, of shape (n_samples, n_genes - 1).
144 |
145 | Examples
146 | --------
147 | >>> gex = torch.tensor([[0.0, 1.0, 0.0], [0.6, 0.0, 0.8]])
148 | >>> sphex = calc_sphex(gex)
149 | >>> sphex.shape
150 | torch.Size([2, 2])
151 | >>> # each row of gex is expected to have unit norm
152 | """
153 | # set up the sphex
154 | n_sgenes = gex.shape[1]-1
155 | sphex = torch.from_numpy(np.zeros((gex.shape[0], n_sgenes)).astype('float32'))
156 | # compute the sphex
157 | for idx in range(n_sgenes):
158 | sphex[:,idx] = gex[:,idx]
159 | for idx_ in range(idx):
160 | sphex[:,idx] /= torch.sin(sphex[:,idx_])
161 | sphex[:,idx] = torch.arccos(sphex[:,idx])
162 | return sphex
163 |
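`calc_sphex` divides each coordinate by the running product of sines of the angles already recovered and then applies `arccos`. That is the inverse of the textbook hyperspherical map x_i = cos(φ_i)·∏_{k<i} sin(φ_k) for i < n-1, with the last coordinate x_{n-1} = ∏_k sin(φ_k). The pure-Python sketch below assumes this standard convention (hypothetical function names, no torch) and includes a round-trip check:

```python
import math


def sphex_to_gex(angles):
    """Map n-1 hyperspherical angles to a unit-norm vector of length n (textbook convention)."""
    gex = []
    sin_prod = 1.0  # running product of sin(phi_k) for k < i
    for phi in angles:
        gex.append(math.cos(phi) * sin_prod)
        sin_prod *= math.sin(phi)
    gex.append(sin_prod)  # last coordinate uses every sine exactly once
    return gex


def gex_to_sphex(gex):
    """Inverse in the style of calc_sphex: divide by the running sine product, then arccos.

    Assumes a unit-norm input with a non-negative last coordinate and non-zero
    intermediate sines (otherwise the division raises ZeroDivisionError).
    """
    angles = []
    sin_prod = 1.0
    for x in gex[:-1]:
        ratio = max(-1.0, min(1.0, x / sin_prod))  # clamp rounding error into acos's domain
        phi = math.acos(ratio)
        angles.append(phi)
        sin_prod *= math.sin(phi)
    return angles


angles = [0.5, 0.6, 1.2]
gex = sphex_to_gex(angles)
print(sum(v * v for v in gex))  # squared norm, 1 up to rounding
print(gex_to_sphex(gex))        # recovers the input angles
```

Because only `arccos` is used, the inverse cannot recover the sign of the last coordinate. Rows with a zero intermediate sine, such as the first standard basis vector, make the division undefined. In that case `calc_sphex` produces NaNs, and this sketch raises an error.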
--------------------------------------------------------------------------------
/docs/.ipynb_checkpoints/Makefile-checkpoint:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/.ipynb_checkpoints/make-checkpoint.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | project = 'Celcomen'
10 | copyright = '2025, Stathis Megas'
11 | author = 'Stathis Megas'
12 | import celcomen as cce
13 | release = cce.__version__
14 |
15 | # -- General configuration ---------------------------------------------------
16 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
17 |
18 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary']
19 | templates_path = ['_templates']
20 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
21 |
22 | # -- Options for HTML output -------------------------------------------------
23 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
24 |
25 | html_theme = 'sphinx_rtd_theme'
26 | html_static_path = ['_static']
27 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Celcomen documentation master file, created by
2 | sphinx-quickstart on Thu Jan 9 18:35:31 2025.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Celcomen
7 | ========
8 | Project home page `here `_.
9 |
10 | Main workflow classes
11 | -----------------------
12 |
13 | The ``celcomen`` class
14 | ----------------------
15 |
16 | .. autoclass:: celcomen.models.celcomen.celcomen
17 | :members:
18 | :undoc-members:
19 | :show-inheritance:
20 |
21 | The ``simcomen`` class
22 | ----------------------
23 |
24 | .. autoclass:: celcomen.models.simcomen.simcomen
25 | :members:
26 | :undoc-members:
27 | :show-inheritance:
28 |
29 | Utility functions
30 | -----------------
31 | .. autosummary::
32 | :toctree:
33 |
34 | celcomen.utils.helpers.calc_gex
35 | celcomen.datareaders.datareader.get_dataset_loaders
36 |
37 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx-rtd-theme
2 |
--------------------------------------------------------------------------------
/docs/source/.ipynb_checkpoints/conf-checkpoint.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | project = 'Celcomen'
10 | copyright = '2025, Stathis Megas'
11 | author = 'Stathis Megas'
12 | release = '0.0.1'
13 |
14 | # -- General configuration ---------------------------------------------------
15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
16 |
17 | extensions = []
18 |
19 | templates_path = ['_templates']
20 | exclude_patterns = []
21 |
22 |
23 |
24 | # -- Options for HTML output -------------------------------------------------
25 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
26 |
27 | html_theme = 'sphinx_rtd_theme'
28 | html_static_path = ['_static']
29 |
--------------------------------------------------------------------------------
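This checkpointed configuration hard-codes `release = '0.0.1'`, whereas `pyproject.toml` declares `version = "0.0.2"`, so the documented version can silently drift from the package. One common way to keep them aligned, sketched here as an option rather than what the project currently does, is to read the installed distribution's version with the standard-library `importlib.metadata`. Read the Docs can see that version because `readthedocs.yaml` pip-installs the package before building. `resolve_release` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version


def resolve_release(dist_name: str, fallback: str) -> str:
    """Return the installed distribution's version, or `fallback` if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # e.g. building the docs from a checkout without running `pip install .`
        return fallback


# Sketch of the conf.py line: keep Sphinx's `release` in sync with pyproject.toml.
release = resolve_release("celcomen", fallback="0.0.2")
```

The fallback only matters when the docs are built without installing the package. In that case the value still has to be updated by hand.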
/docs/source/.ipynb_checkpoints/index-checkpoint.rst:
--------------------------------------------------------------------------------
1 | .. Celcomen documentation master file, created by
2 |    sphinx-quickstart on Thu Jan 9 18:35:31 2025.
3 |    You can adapt this file completely to your liking, but it should at least
4 |    contain the root `toctree` directive.
5 |
6 | Celcomen documentation
7 | ======================
8 |
9 | Add your content using ``reStructuredText`` syntax. See the
10 | `reStructuredText `_
11 | documentation for details.
12 |
13 |
14 | .. toctree::
15 |    :maxdepth: 2
16 |    :caption: Contents:
17 |
18 |
--------------------------------------------------------------------------------
/images/disentangling graphs and gene colocalization-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teichlab/celcomen/f5d19bd24e968a45b057b1ea3aef5c6ce5974303/images/disentangling graphs and gene colocalization-2.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["flit_core >=3.2,<4"]
3 | build-backend = "flit_core.buildapi"
4 |
5 | [project]
6 | name = "celcomen"
7 | authors = [{name = "Stathis Megas", email = "stathismegas@gmail.com"}]
8 | version = "0.0.2"
9 | description = "Celcomen is a first step towards models of Virtual Tissue, which generalize models of Virtual Cells to the tissue context, and can be leveraged to predict tissue counterfactuals such as spatially localised gene knockouts"
10 | dependencies = [
11 |     "numpy",
12 |     "torch",
13 |     "torch_geometric",
14 |     "scikit-learn",
15 |     "scanpy",
16 |     "tqdm"
17 | ]
18 |
--------------------------------------------------------------------------------
/readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the version of Python and other tools you might need
9 | build:
10 |   os: ubuntu-20.04
11 |   tools:
12 |     python: "3.10"
13 |
14 | # Build documentation in the docs/ directory with Sphinx
15 | sphinx:
16 |   configuration: docs/conf.py
17 |
18 | # Install our python package before building the docs
19 | python:
20 |   install:
21 |     - requirements: docs/requirements.txt
22 |     - method: pip
23 |       path: .
24 |
--------------------------------------------------------------------------------
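Read the Docs installs `docs/requirements.txt`, then installs the package itself with pip, and only then runs Sphinx against `docs/conf.py`. Installing the package first matters because autodoc has to import `celcomen`, and presumably its `torch`-based dependencies, to render the class pages. To reproduce that order locally in a fresh virtual environment, a rough sketch follows. `rtd_like_commands` and `run_all` are hypothetical helpers. The Sphinx step is an approximation using `python -m sphinx -b html`, not the exact command or output path that Read the Docs uses.

```python
import subprocess
import sys
from pathlib import Path
from typing import List


def rtd_like_commands(repo_root: str = ".") -> List[List[str]]:
    """Approximate the readthedocs.yaml steps as argv lists (hypothetical helper)."""
    root = Path(repo_root)
    docs = root / "docs"
    py = sys.executable
    return [
        # python.install[0]: - requirements: docs/requirements.txt
        [py, "-m", "pip", "install", "-r", str(docs / "requirements.txt")],
        # python.install[1]: - method: pip, path: .
        [py, "-m", "pip", "install", str(root)],
        # sphinx.configuration: docs/conf.py, so docs/ is the Sphinx source directory.
        [py, "-m", "sphinx", "-b", "html", str(docs), str(docs / "_build" / "html")],
    ]


def run_all(commands: List[List[str]]) -> None:
    """Run each step in order, stopping at the first non-zero exit status."""
    for cmd in commands:
        print("+", " ".join(cmd))
        subprocess.run(cmd, check=True)
```

From the repository root, `run_all(rtd_like_commands())` would perform the three steps in the same order as the configuration above.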