├── .gitignore
├── GenKI
│   ├── __init__.py
│   ├── dataLoader.py
│   ├── model.py
│   ├── pcNet.py
│   ├── preprocesing.py
│   ├── train.py
│   ├── tune.py
│   ├── utils.py
│   └── version.py
├── LICENSE
├── MANIFEST.in
├── README.md
├── data
│   └── README.md
├── environment.yml
├── logo.jpg
├── notebook
│   ├── Compatible_matlab.ipynb
│   ├── Example.ipynb
│   └── SERGIO.py
├── requirements.txt
├── run_eva.sh
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # custom
132 | notebook/*
133 | .Rproj.user
134 | BACKUP
135 | data*
136 | log_dir
137 | !notebook/*.py
138 | !notebook/Example.ipynb
139 | !notebook/Compatible_matlab.ipynb
140 |
--------------------------------------------------------------------------------
/GenKI/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/GenKI/dataLoader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import anndata
4 | import os
5 | import torch
6 | from torch_geometric.data import Data
7 | import matplotlib.pyplot as plt
8 | # from tqdm import tqdm
9 | from .pcNet import make_pcNet
10 | from .preprocesing import check_adata
11 |
12 |
13 | class scBase():
14 | def __init__(self,
15 | adata: anndata.AnnData,
16 | target_gene: list[str],
17 | target_cell: str = None,
18 | obs_label: str = "ident",
19 | GRN_file_dir: str = "GRNs",
20 | rebuild_GRN: bool = False,
21 | pcNet_name: str = "pcNet",
22 | verbose: bool = False,
23 | **kwargs):
24 |
25 | check_adata(adata)
26 | self._gene_names = list(adata.var_names)
27 | if isinstance(target_gene[0], int): # list[int]
28 | target_gene = adata.var_names[target_gene].tolist()
29 |         if any(g not in self._gene_names for g in target_gene):
30 |             raise IndexError("every input target gene should be in the gene list of adata")
31 | else:
32 | self._target_gene = target_gene
33 | # self._idx_KO = self._gene_names.index(target_gene)
34 | if target_cell is None:
35 | self._counts = adata.X # use all cells, standardized counts
36 | if verbose:
37 | print(f"use all the cells ({self._counts.shape[0]}) in adata")
38 | elif not (obs_label in adata.obs.keys()):
39 | raise IndexError("require a valid cell label in adata.obs")
40 | else:
41 | self._counts = adata[adata.obs[obs_label] == target_cell, :].X
42 | self._counts = scipy.sparse.lil_matrix(self._counts) # sparse
43 | pcNet_path = os.path.join(GRN_file_dir, f"{pcNet_name}.npz")
44 | if rebuild_GRN:
45 | if verbose:
46 | print("build GRN")
47 | self._net = make_pcNet(adata.layers["norm"], nComp = 5, as_sparse = True, timeit = verbose, **kwargs)
48 | os.makedirs(GRN_file_dir, exist_ok = True) # create dir
49 | scipy.sparse.save_npz(pcNet_path, self._net) # save GRN
50 | # scipy.sparse.lil_matrix(pcNet_np) # np to sparse
51 | if verbose:
52 | print(f"GRN has been built and saved in \"{pcNet_path}\"")
53 | else:
54 | try:
55 | if verbose:
56 | print(f"loading GRN from \"{pcNet_path}\"")
57 | self._net = scipy.sparse.load_npz(pcNet_path)
58 |             except FileNotFoundError:
59 |                 raise FileNotFoundError(f"no GRN file found at \"{pcNet_path}\"; set rebuild_GRN = True to build one")
60 | if verbose:
61 | print("init completed\n")
62 |
63 | @property
64 | def counts(self):
65 | return self._counts
66 |
67 | @property
68 | def net(self):
69 | return self._net
70 |
71 | @property
72 | def target_gene(self):
73 | return self._target_gene
74 |
75 |
76 | def __len__(self):
77 | return len(self._counts)
78 |
79 |
80 | def __call__(self, gene_name: list[str]):
81 | return [self._gene_names.index(g) for g in gene_name] # return gene index
82 |
83 |
84 | def __repr__(self):
85 | info = f"counts: {self._counts.shape}\n"\
86 | f"net: {self._net.shape}\n"\
87 | f"target_gene: {self._target_gene}"
88 | return info
89 |
90 |
91 | class DataLoader(scBase):
92 | def __init__(self,
93 | adata: anndata.AnnData,
94 | target_gene: list[str],
95 | target_cell: str = None,
96 | obs_label: str = "ident",
97 | GRN_file_dir: str = "GRNs",
98 | rebuild_GRN: bool = False,
99 | pcNet_name: str = "pcNet",
100 | cutoff: int = 85,
101 | verbose: bool = False,
102 | **kwargs):
103 | super().__init__(adata,
104 | target_gene,
105 | target_cell,
106 | obs_label,
107 | GRN_file_dir,
108 | rebuild_GRN,
109 | pcNet_name,
110 | verbose,
111 | **kwargs)
112 | self.verbose = verbose
113 | self.cutoff = cutoff
114 | self.edge_index = torch.tensor(self._build_edges(), dtype = torch.long) # dense np to tensor
115 |
116 |
117 | def _build_edges(self, net = None):
118 | '''
119 | net: array-like, GRN built from data.
120 | '''
121 | if net is None:
122 | net = self._net
123 | grn_to_use = abs(net.toarray()) if scipy.sparse.issparse(net) else abs(net)
124 | grn_to_use[grn_to_use < np.percentile(grn_to_use, self.cutoff)] = 0
125 | edge_index_np = np.asarray(np.where(grn_to_use > 0), dtype = int)
126 | return edge_index_np # dense np
127 |
128 | def load_data(self):
129 | return self._data_init()
130 |
131 | def load_kodata(self):
132 | return self._KO_data_init()
133 |
134 |
135 | def _data_init(self):
136 | counts = self._counts.toarray() if scipy.sparse.issparse(self._counts) else self._counts
137 | x = torch.tensor(counts.T, dtype = torch.float) # define x
138 | return Data(x = x, edge_index = self.edge_index, y = self._gene_names)
139 |
140 |
141 | def _KO_data_init(self):
142 | # KO edges
143 |         mask = ~(torch.isin(self.edge_index[0], torch.tensor(self(self._target_gene))) | torch.isin(self.edge_index[1], torch.tensor(self(self._target_gene))))
144 | edge_index_KO = self.edge_index[:, mask] # torch.long
145 |
146 | # KO counts
147 | counts_KO = self._counts.copy()
148 | counts_KO[:, self(self._target_gene)] = 0
149 | counts_KO = counts_KO.toarray() if scipy.sparse.issparse(counts_KO) else counts_KO
150 | x_KO = torch.tensor(counts_KO.T, dtype = torch.float) # define counts (KO)
151 | # if self.verbose:
152 | # print(f"set expression of {self._target_gene} to zeros and remove edges")
153 | return Data(x = x_KO, edge_index = edge_index_KO, y = self._gene_names)
154 |
155 |
156 | def _gen_ZINB(self,
157 | p_BIN = 0.95,
158 | n_NB = 10,
159 | p_NB = 0.5,
160 | noise = True,
161 | normalize = True,
162 | show = False,
163 | decays: list[float] = [1.]):
164 | '''
165 | p_BIN: parameter of binomial distribution.
166 | n_NB, p_NB: parameters of negative binomial distribution.
167 | noise: add noise to ZINB samples.
168 | normalize: LogNorm samples, scale_factor = 10e4.
169 | show: show histograms of simulated samples and WT expression.
170 |         decays: scales on n_NB simulating target gene expression; the first element must be 1.
171 | '''
172 | np.random.seed(11)
173 | if decays[0] != 1.:
174 | raise ValueError("first element in decays should always be 1")
175 | elif len(decays) != len(self._target_gene):
176 | raise ValueError("length of decays should match target gene")
177 | else:
178 | s_BIN = np.random.binomial(1, p_BIN, self._counts.shape[0]) # sample binomial 0 and 1
179 | s = np.array([s_BIN * np.random.negative_binomial(decay*n_NB, p_NB, self._counts.shape[0]) for decay in decays]).T # ZINB
180 | if noise:
181 |             s[s > 0] = s[s > 0] + np.random.choice([-1, 0, 1], size = len(s[s > 0]), p = [0.15, 0.7, 0.15])  # add integer noise
182 | if normalize:
183 | scale_factor = 10e4
184 | s = np.log1p(s * scale_factor / s.sum()) # LogNorm
185 | if show:
186 | fig, ax = plt.subplots(ncols = 2, figsize = (12, 5))
187 | counts = self._counts.toarray() if scipy.sparse.issparse(self._counts) else self._counts
188 | for i, (values, title) in enumerate(zip([counts[:, self(self._target_gene)], s], ['WT', 'Simulated'])):
189 | ax[i].hist(values, bins = counts.shape[0]//5, label = self._target_gene)
190 | ax[i].set_title(title)
191 | ax[i].legend()
192 | plt.show()
193 | if self.verbose:
194 | print(f"sample gene patterns ({self._counts.shape[0]}) from NB{n_NB, p_NB} with P(zero) = {round(1-p_BIN, 2)}")
195 | return s
196 |
197 |
198 | def scale_rows_and_cols(self, scale): # scale target gene edges in pcNet as a whole
199 | mask = np.ones(len(self._gene_names), dtype = bool)
200 | mask[self(self._target_gene)] = 0
201 | mask = np.invert(mask[:, None] @ mask[None, :])
202 | self._net[mask] *= scale
203 |
204 |
205 | def OE_data_init(self,
206 | weight_scale: list[float] = [10.,],
207 | gene_patterns: list[float] = None,
208 | **kwargs):
209 | # OE edges
210 | weight_scale = np.asarray(weight_scale)
211 | net_OE = self._net.toarray().copy()
212 | net_OE[:, self(self._target_gene)] *= weight_scale[None, :]
213 | net_OE[self(self._target_gene), :] *= weight_scale[:, None] # note cross points get multiplied twice
214 |
215 | edge_index_OE = self._build_edges(net_OE)
216 | edge_index_OE = torch.tensor(edge_index_OE, dtype = torch.long)
217 |
218 | # OE counts
219 | counts_OE = self._counts.copy()
220 | orig_counts = counts_OE[:, self(self._target_gene)]
221 | if gene_patterns is not None:
222 | counts_OE[:, self(self._target_gene)] = gene_patterns
223 | else:
224 | decays = [w/weight_scale[0] for w in weight_scale] # n_NB = coeff * n_NB(10), coeff <= 1
225 | counts_OE[:, self(self._target_gene)] = self._gen_ZINB(n_NB = weight_scale[0], decays = decays, **kwargs)
226 | counts_OE = counts_OE.toarray() if scipy.sparse.issparse(counts_OE) else counts_OE
227 | x_OE = torch.tensor(counts_OE.T, dtype = torch.float)
228 | if self.verbose:
229 |             print(f"replace expression of {self._target_gene} with simulated values and scale their edges by {weight_scale}")
230 | return Data(x = x_OE, edge_index = edge_index_OE, y = self._gene_names)
231 |
232 |
233 | # def run_sys_KO(self, model, genelist):
234 | # '''
235 | # model: a trained VGAE model.
236 | # genelist: array-like, gene list to be systematic KO
237 | # '''
238 | # self.verbose = False
239 | # g_orig = self._target_gene
240 | # data = self.data_init()
241 | # z_m0, z_S0 = get_latent_vars(data, model)
242 | # sys_res = []
243 | # from tqdm import tqdm
244 | # for g in tqdm(genelist, desc = "systematic KO...", total = len(genelist)):
245 | # if g not in self._gene_names:
246 | # raise IndexError(f'"{g}" is not in the object')
247 | # else:
248 | # self._target_gene = g # reset KO gene
249 | # data_v = self.KO_data_init()
250 | # z_mv, z_Sv = get_latent_vars(data_v, model)
251 | # dis = get_distance(z_mv, z_Sv, z_m0, z_S0, by = "KL")
252 | # sys_res.append(dis)
253 | # self._target_gene = g_orig
254 | # # print(self._target_gene)
255 | # self.verbose = True
256 | # return np.array(sys_res)
257 |
--------------------------------------------------------------------------------
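Below is a minimal usage sketch of `DataLoader` (not part of the package; the AnnData and target gene are hypothetical). It assumes `adata` was prepared by `GenKI.preprocesing.build_adata` with `log_normalize=True` and `scale_data=True`, so that `adata.layers["norm"]` exists and `adata.X` is standardized, as `check_adata` requires:

```python
from GenKI.dataLoader import DataLoader

loader = DataLoader(
    adata,                    # AnnData passing check_adata (hypothetical)
    target_gene=["GENE_X"],   # hypothetical KO target present in adata.var_names
    rebuild_GRN=True,         # build pcNet from adata.layers["norm"] and cache it
    GRN_file_dir="GRNs",
    pcNet_name="pcNet",
    verbose=True,
)
data_wt = loader.load_data()    # Data: x = genes x cells, edge_index from thresholded pcNet
data_ko = loader.load_kodata()  # same graph with the target gene's expression and edges removed
```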
/GenKI/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import negative_sampling
3 |
4 |
5 | EPS = 1e-15
6 | MAX_LOGSTD = 10
7 |
8 |
9 | def reset(value):
10 | if hasattr(value, 'reset_parameters'):
11 | value.reset_parameters()
12 | else:
13 | for child in value.children() if hasattr(value, 'children') else []:
14 | reset(child)
15 |
16 |
17 | class InnerProductDecoder(torch.nn.Module):
18 | """The inner product decoder."""
19 | def forward(self, z, edge_index, sigmoid=True):
20 | value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)
21 | return torch.sigmoid(value) if sigmoid else value
22 |
23 | def forward_all(self, z, sigmoid=True):
24 | adj = torch.matmul(z, z.t())
25 | return torch.sigmoid(adj) if sigmoid else adj
26 |
27 |
28 | class GAE(torch.nn.Module):
29 | """The Graph Auto-Encoder model."""
30 | def __init__(self, encoder, decoder=None):
31 | super().__init__()
32 | self.encoder = encoder
33 | self.decoder = InnerProductDecoder() if decoder is None else decoder
34 | GAE.reset_parameters(self)
35 |
36 | def reset_parameters(self):
37 | reset(self.encoder)
38 | reset(self.decoder)
39 |
40 | def encode(self, *args, **kwargs):
41 | return self.encoder(*args, **kwargs)
42 |
43 | def decode(self, *args, **kwargs):
44 | return self.decoder(*args, **kwargs)
45 |
46 | def recon_loss(self, z, pos_edge_index, neg_edge_index=None):
47 | """Given latent embeddings z, computes the binary cross
48 | entropy loss for positive edges pos_edge_index and negative
49 | sampled edges.
50 |
51 | Args:
52 | z (Tensor): The latent embeddings.
53 | pos_edge_index (LongTensor): The positive edges to train against.
54 | neg_edge_index (LongTensor, optional): The negative edges to train
55 | against. If not given, uses negative sampling to calculate
56 | negative edges. (default: :obj:`None`)
57 | """
58 |
59 | pos_loss = -torch.log(
60 | self.decoder(z, pos_edge_index, sigmoid=True) + EPS).mean()
61 |
62 | if neg_edge_index is None:
63 | neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
64 | neg_loss = -torch.log(1 -
65 | self.decoder(z, neg_edge_index, sigmoid=True) +
66 | EPS).mean()
67 |
68 | return pos_loss + neg_loss
69 |
70 | def test(self, z, pos_edge_index, neg_edge_index):
71 | """Given latent embeddings z, positive edges
72 | pos_edge_index and negative edges neg_edge_index
73 | computes metrics.
74 |
75 | Args:
76 | z (Tensor): The latent embeddings.
77 | pos_edge_index (LongTensor): The positive edges to evaluate
78 | against.
79 | neg_edge_index (LongTensor): The negative edges to evaluate
80 | against.
81 | """
82 | from sklearn.metrics import average_precision_score, roc_auc_score #f1_score, confusion_matrix
83 |
84 | pos_y = z.new_ones(pos_edge_index.size(1))
85 | neg_y = z.new_zeros(neg_edge_index.size(1))
86 | y = torch.cat([pos_y, neg_y], dim=0)
87 |
88 | pos_pred = self.decoder(z, pos_edge_index, sigmoid=True)
89 | neg_pred = self.decoder(z, neg_edge_index, sigmoid=True)
90 | pred = torch.cat([pos_pred, neg_pred], dim=0)
91 |
92 | y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()
93 | # pred[pred < 0.5] = 0
94 | # pred[pred >= 0.5] = 1
95 | # tn, fp, fn, tp = confusion_matrix(y, pred, labels=[0, 1]).ravel()
96 |
97 | return roc_auc_score(y, pred), average_precision_score(y, pred) # f1_score(y, pred), [tn, fp, fn, tp]
98 |
99 |
100 | class VGAE(GAE):
101 | """The Variational Graph Auto-Encoder model.
102 |
103 | Args:
104 | encoder (Module): The encoder module to compute :math:`\mu` and
105 | :math:`\log\sigma^2`.
106 | decoder (Module, optional): The decoder module. If set to :obj:`None`,
107 | will default to the
108 | :class:`torch_geometric.nn.models.InnerProductDecoder`.
109 | (default: :obj:`None`)
110 | """
111 | def __init__(self, encoder, decoder=None):
112 | super().__init__(encoder, decoder)
113 |
114 | def reparametrize(self, mu, logstd):
115 | if self.training:
116 | return mu + torch.randn_like(logstd) * torch.exp(logstd)
117 | else:
118 | return mu
119 |
120 | def encode(self, *args, **kwargs):
121 | self.__mu__, self.__logstd__ = self.encoder(*args, **kwargs)
122 | self.__logstd__ = self.__logstd__.clamp(max=MAX_LOGSTD)
123 | z = self.reparametrize(self.__mu__, self.__logstd__)
124 | return z
125 |
126 | def kl_loss(self, mu=None, logstd=None):
127 | """Computes the KL loss, either for the passed arguments :obj:`mu`
128 | and :obj:`logstd`, or based on latent variables from last encoding.
129 |
130 | Args:
131 | mu (Tensor, optional): The latent space for :math:`\mu`. If set to
132 | :obj:`None`, uses the last computation of :math:`mu`.
133 | (default: :obj:`None`)
134 | logstd (Tensor, optional): The latent space for
135 | :math:`\log\sigma`. If set to :obj:`None`, uses the last
136 | computation of :math:`\log\sigma^2`.(default: :obj:`None`)
137 | """
138 | mu = self.__mu__ if mu is None else mu
139 | logstd = self.__logstd__ if logstd is None else logstd.clamp(
140 | max=MAX_LOGSTD)
141 | return -0.5 * torch.mean(
142 | torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1))
143 |
--------------------------------------------------------------------------------
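A quick sanity check of the model on a toy graph (a sketch, not part of the package; it borrows `VariationalGCNEncoder` from `GenKI/train.py`):

```python
import torch
from GenKI.model import VGAE
from GenKI.train import VariationalGCNEncoder

x = torch.randn(10, 8)                                   # 10 nodes with 8 features each
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges
model = VGAE(VariationalGCNEncoder(in_channels=8, out_channels=2))

z = model.encode(x, edge_index)  # reparametrized latent embeddings, shape (10, 2)
loss = model.recon_loss(z, edge_index) + 1e-4 * model.kl_loss()  # beta-weighted objective
loss.backward()
print(float(loss))
```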
/GenKI/pcNet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import svd
3 | from scipy import sparse
4 | import os
5 | import time
6 |
7 | try:
8 |     import ray
9 | except ImportError:
10 |     class dummy_ray():
11 |         def __init__(self):
12 |             self.remote = lambda f: f  # no-op decorator so @ray.remote leaves the function callable
13 |             print("ray not installed, falling back to single-process pcNet")
14 |         def is_initialized(self):
15 |             return False
16 |         def init(self, num_cpus):
17 |             self.num_cpu = num_cpus
18 |         def put(self, x):
19 |             self.x = x
20 |     ray = dummy_ray()
21 |
22 |
23 | def pcCoefficients(X, K, nComp):
24 | y = X[:, K]
25 | Xi = np.delete(X, K, 1)
26 | U, s, VT = svd(Xi, full_matrices=False)
27 | #print ('U:', U.shape, 's:', s.shape, 'VT:', VT.shape)
28 | V = VT[:nComp, :].T
29 | #print('V:', V.shape)
30 | score = Xi@V
31 | t = np.sqrt(np.sum(score**2, axis=0))
32 | score_lsq = ((score.T / (t**2)[:, None])).T
33 | beta = np.sum(y[:, None]*score_lsq, axis=0)
34 | beta = V@beta
35 | return list(beta)
36 |
37 |
38 | def pcNet(X, # X: cell * gene
39 | nComp: int = 3,
40 | scale: bool = True,
41 | symmetric: bool = True,
42 | q: float = 0., # q: 0-100
43 | as_sparse: bool = True,
44 | random_state: int = 0):
45 | X = X.toarray() if sparse.issparse(X) else X
46 | if nComp < 2 or nComp >= X.shape[1]:
47 |         raise ValueError("nComp should be greater than or equal to 2 and lower than the total number of genes")
48 | else:
49 | np.random.seed(random_state)
50 | n = X.shape[1] # genes
51 | B = np.array([pcCoefficients(X, k, nComp) for k in range(n)])
52 | A = np.ones((n, n), dtype=float)
53 | np.fill_diagonal(A, 0)
54 | for i in range(n):
55 | A[i, A[i, :]==1] = B[i, :]
56 | if scale:
57 | absA = abs(A)
58 | A = A / np.max(absA)
59 | if q > 0:
60 | A[absA < np.percentile(absA, q)] = 0
61 | if symmetric: # place in the end
62 | A = (A + A.T)/2
63 | #diag(A) <- 0
64 | if as_sparse:
65 | A = sparse.csc_matrix(A)
66 | return A
67 |
68 |
69 | @ray.remote
70 | def pc_net_parallel(X, # X: cell * gene
71 | nComp: int = 3,
72 | scale: bool = True,
73 | symmetric: bool = True,
74 | q: float = 0.,
75 | as_sparse: bool = True,
76 | random_state: int = 0):
77 | return pcNet(X, nComp = nComp, scale = scale, symmetric = symmetric, q = q,
78 | as_sparse = as_sparse, random_state = random_state)
79 |
80 |
81 | def pc_net_single(X, # X: cell * gene
82 | nComp: int = 3,
83 | scale: bool = True,
84 | symmetric: bool = True,
85 | q: float = 0.,
86 | as_sparse: bool = True,
87 | random_state: int = 0):
88 | return pcNet(X, nComp = nComp, scale = scale, symmetric = symmetric, q = q,
89 | as_sparse = as_sparse, random_state = random_state)
90 |
91 |
92 | def make_pcNet(X,
93 | nComp: int = 3,
94 | scale: bool = True,
95 | symmetric: bool = True,
96 | q: float = 0.,
97 | as_sparse: bool = True,
98 | random_state: int = 0,
99 | n_cpus: int = 1, # -1: use all CPUs
100 | timeit: bool = True):
101 | start_time = time.time()
102 | if n_cpus != 1:
103 | if ray.is_initialized():
104 | ray.shutdown()
105 | if n_cpus == -1:
106 | n_cpus = os.cpu_count()
107 | ray.init(num_cpus = n_cpus)
108 | print(f"ray init, using {n_cpus} CPUs")
109 |
110 | X_ray = ray.put(X) # put X to distributed object store and return object ref (ID)
111 | # print(X_ray)
112 | net = pc_net_parallel.remote(X_ray, nComp = nComp, scale = scale, symmetric = symmetric, q = q,
113 | as_sparse = as_sparse, random_state = random_state)
114 | net = ray.get(net)
115 | else:
116 | net = pc_net_single(X, nComp = nComp, scale = scale, symmetric = symmetric, q = q,
117 | as_sparse = as_sparse, random_state = random_state)
118 | if ray.is_initialized():
119 | ray.shutdown()
120 | if timeit:
121 | duration = time.time() - start_time
122 | print("execution time of making pcNet: {:.2f} s".format(duration))
123 | return net
124 |
125 |
126 | def main():
127 | counts = np.random.randint(0, 10, (5, 100))
128 | net = make_pcNet(counts, as_sparse = True, timeit = True, n_cpus = 1)
129 | print(f"input counts shape: {counts.shape},\nmake pcNet completed, shape: {net.shape}")
130 | return 100
131 |
132 |
133 | if __name__ == "__main__":
134 | import sys
135 | sys.exit(main())
136 |
--------------------------------------------------------------------------------
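A small smoke test of `pcNet` on synthetic counts (a sketch mirroring `main()` above; shapes are arbitrary and the package is assumed importable as `GenKI`). It checks the structural properties the function guarantees: square output, zero diagonal, symmetry when `symmetric=True`, and unit-scaled weights when `scale=True`:

```python
import numpy as np
from GenKI.pcNet import pcNet

counts = np.random.default_rng(0).poisson(2.0, size=(50, 20))  # 50 cells x 20 genes
net = pcNet(counts.astype(float), nComp=3, as_sparse=False)

assert net.shape == (20, 20)         # one row/column per gene
assert np.allclose(np.diag(net), 0)  # no self-edges
assert np.allclose(net, net.T)       # symmetric=True averages A with A.T
assert np.max(np.abs(net)) <= 1      # scale=True normalizes by the largest |weight|
```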
/GenKI/preprocesing.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from anndata import (
3 | AnnData,
4 | read_h5ad,
5 | read_csv,
6 | read_excel,
7 | read_hdf,
8 | read_loom,
9 | read_mtx,
10 | read_text,
11 | )
12 | import scanpy as sc
13 | import pandas as pd
14 | from typing import Union
15 | from scipy import sparse
16 | import os
17 | import pickle
18 | from torch_geometric.data import Data
19 | from torch_geometric.transforms import RandomLinkSplit
20 | from torch_geometric import seed_everything
21 | sc.settings.verbosity = 0
22 |
23 | CURRENT_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__))) # GenKI
24 | # https://scanpy.readthedocs.io/en/stable/api.html#reading
25 | file_attrs = ["h5ad", "csv", "xlsx", "h5", "loom", "mtx", "txt", "mat"]
26 | read_fun_keys = ["h5ad", "csv", "excel", "hdf", "loom", "mtx", "text"] # fun from scanpy read
27 |
28 |
29 | def _read_counts(counts_path: str,
30 | transpose: bool = False,
31 | **kwargs):
32 |
33 | """Read counts file to build an AnnData.
34 | Args:
35 | counts_path (str): Path to counts file.
36 |         transpose (bool, optional): Whether to transpose the counts.
37 | Returns:
38 | AnnData
39 | """
40 |
41 | file_attr = counts_path.split(".")[-1]
42 | if Path(counts_path).is_file() and file_attr in file_attrs:
43 | print(f"load counts from {counts_path}")
44 | if file_attr == "mat":
45 | import numpy as np
46 | import h5py
47 |             with h5py.File(counts_path, "r") as f:  # close the handle after reading
48 |                 # assume the first key holds the counts matrix
49 |                 counts = np.array(f.get(list(f.keys())[0]), dtype="float64")
50 | if transpose:
51 | counts = counts.T
52 | adata = sc.AnnData(counts)
53 | else:
54 | read_fun_key = read_fun_keys[file_attrs.index(file_attr)]
55 | read_fun = getattr(sc, f"read_{read_fun_key}") # define sc.read_{} function
56 | if transpose:
57 | adata = read_fun(counts_path, **kwargs).transpose() # transpose counts file
58 | else:
59 | adata = read_fun(counts_path, **kwargs)
60 | else:
61 | raise ValueError("incorrect file path given to counts")
62 | return adata
63 |
64 |
65 | def build_adata(
66 | adata: Union[str, AnnData],
67 | meta_gene_path: Union[None, str] = None,
68 | meta_cell_path: Union[None, str] = None,
69 | meta_cell_cols: Union[None, str] = None,
70 | sep = "\t",
71 | header = None,
72 | log_normalize: bool = False,
73 | scale_data: bool = True,
74 | as_sparse: bool = True,
75 | uppercase: bool = True,
76 | **kwargs,
77 | ):
78 |
79 | """Load counts, metadata of genes and cells to build an AnnData input for scTenifoldXct.
80 | Args:
81 | adata (str): Path to counts or adata.
82 | meta_gene (Union[None, str]): Path to metadata of variables.
83 | meta_cell (Union[None, str]): Path to metadata of Observations
84 | sep (str, optional): The delimiter for metadata. Defaults to "\t".
85 | log_normalize (bool, optional): Whether log-normalize the counts. Defaults to False.
86 | scale_data (bool, optional): Whether standardize the counts. Defaults to False.
87 | as_sparse (bool, optional): Whether make the counts sparse. Defaults to True.
88 | uppercase (bool, optional): Whether convert gene names to uppercase. Defaults to True.
89 | **kwargs: key words in _read_counts
90 | Returns:
91 | AnnData
92 | """
93 | if not isinstance(adata, AnnData):
94 | adata = _read_counts(adata, **kwargs)
95 |
96 | if meta_gene_path is not None and Path(meta_gene_path).is_file():
97 | try:
98 | # print("add metadata for genes")
99 | df_gene = pd.read_csv(meta_gene_path, header=header, sep=sep)
100 | df_gene.index = df_gene.index.astype("str")
101 | adata.var_names = df_gene[0]
102 | except Exception:
103 | raise ValueError("incorrect file path given to meta_gene")
104 | if meta_cell_path is not None and Path(meta_cell_path).is_file():
105 | try:
106 | # print("add metadata for cells")
107 | df_cell = pd.read_csv(meta_cell_path, header=header, sep=sep)
108 | df_cell.index = df_cell.index.astype("str")
109 | adata.obs = df_cell
110 | if meta_cell_cols is not None:
111 | adata.obs.columns = meta_cell_cols
112 | except Exception:
113 | raise ValueError("incorrect file path given to meta_cell")
114 |
115 | if log_normalize:
116 | # print("normalize counts")
117 | adata.layers["raw"] = adata.X
118 | sc.pp.normalize_total(adata, target_sum=1e4)
119 | sc.pp.log1p(adata)
120 |
121 |
122 | if scale_data:
123 | adata.layers["norm"] = adata.X
124 | # print("standardize counts")
125 | from sklearn import preprocessing
126 | counts = adata.X.toarray() if sparse.issparse(adata.X) else adata.X
127 | scaler = preprocessing.StandardScaler().fit(counts)
128 | adata.X = scaler.transform(counts)
129 |
130 | if as_sparse:
131 | # print("make counts sparse")
132 | adata.X = (
133 | sparse.csr_matrix(adata.X) if not sparse.issparse(adata.X) else adata.X
134 | )
135 |
136 | if uppercase:
137 | adata.var_names = adata.var_names.str.upper() # all species use upper case genes
138 |
139 | return adata
140 |
141 |
142 | def check_adata(adata):
143 | if "norm" not in adata.layers:
144 | raise ValueError("normalized counts should be saved at adata.layers[\"norm\"]")
145 |     X = adata.X.toarray() if sparse.issparse(adata.X) else adata.X  # handle dense or sparse X
146 |     if not (X < 0).any():  # standardized counts must contain negative values
147 |         raise ValueError("adata.X should be standardized counts")
148 |
149 | def save_gdata(data, dir: str = "data", name: str = "data"):
150 | path = os.path.join(CURRENT_DIR, dir)
151 | os.makedirs(path, exist_ok = True)
152 | print(f"save {name} to dir {path}")
153 | with open(os.path.join(path, f'{name}.p'), 'wb') as f:
154 | pickle.dump(data.to_dict(), f, protocol=pickle.HIGHEST_PROTOCOL)
155 |
156 |
157 | def load_gdata(dir: str = "data", name: str = "data"):
158 | with open(os.path.join(dir, f"{name}.p"), 'rb') as f:
159 | data = pickle.load(f)
160 | data = Data.from_dict(data)
161 | print(f"load {name} from dir {dir}")
162 | return data
163 |
164 |
165 | def split_data(dir: str = "data", data = None, load: bool = False, save: bool = False):
166 | path = os.path.join(CURRENT_DIR, dir)
167 | if load:
168 | train_data, val_data, test_data = load_gdata(path, "train_data"), load_gdata(path, "val_data"), load_gdata(path, "test_data")
169 | return train_data, val_data, test_data
170 | elif data is None:
171 | data = load_gdata(path, "data")
172 | seed_everything(42)
173 | transform = RandomLinkSplit(is_undirected = True,
174 | split_labels = True,
175 | num_val = 0.05,
176 | num_test = 0.2,
177 | )
178 | train_data, val_data, test_data = transform(data)
179 | # print("split data into train/valid/test")
180 | if save:
181 | save_gdata(train_data, dir, "train_data")
182 | save_gdata(val_data, dir, "val_data")
183 | save_gdata(test_data, dir, "test_data")
184 | del train_data.pos_edge_label, train_data.neg_edge_label
185 | return train_data, val_data, test_data
186 |
187 |
188 | if __name__ == "__main__":
189 | from os import path
190 |
191 | data_dir = path.join(
192 | path.dirname(path.dirname(path.abspath(__file__))), "filtered_gene_bc_matrices"
193 | )
194 | counts_path = str(path.join(data_dir, "matrix.mtx"))
195 | gene_path = str(path.join(data_dir, "genes.tsv"))
196 | cell_path = str(path.join(data_dir, "barcodes.tsv"))
197 | adata = build_adata(counts_path, gene_path, cell_path)
198 | print(adata)
199 |
--------------------------------------------------------------------------------
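A minimal sketch of the preprocessing contract on synthetic counts (random data; not part of the package): `build_adata` log-normalizes into `adata.layers["norm"]` and standardizes `adata.X`, which is exactly what `check_adata` verifies before `DataLoader` runs:

```python
import numpy as np
import scanpy as sc
from GenKI.preprocesing import build_adata, check_adata

counts = np.random.poisson(1.0, size=(100, 50)).astype(np.float32)  # 100 cells x 50 genes
adata = build_adata(sc.AnnData(counts), log_normalize=True, scale_data=True)

assert "norm" in adata.layers         # log-normalized counts kept for GRN building
assert (adata.X.toarray() < 0).any()  # standardized counts contain negatives
check_adata(adata)                    # passes silently
```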
/GenKI/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch_geometric.data import Data
4 | from torch_geometric.nn import GCNConv
5 | from .model import VGAE
6 | import torch.utils.tensorboard as tb
7 | import os
8 | import matplotlib.pyplot as plt
9 | from tqdm import tqdm
10 |
11 | from .utils import get_distance
12 | from .preprocesing import split_data
13 |
14 |
15 | class VariationalGCNEncoder(torch.nn.Module): # encoder
16 | def __init__(self, in_channels, out_channels, hidden = 2):
17 | super(VariationalGCNEncoder, self).__init__()
18 | self.conv1 = GCNConv(in_channels, hidden * out_channels)
19 | self.conv_mu = GCNConv(hidden * out_channels, out_channels)
20 | self.conv_logstd = GCNConv(hidden * out_channels, out_channels)
21 |
22 | def forward(self, x, edge_index):
23 | x = self.conv1(x, edge_index).relu()
24 | return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)
25 |
26 |
27 | class VGAE_trainer():
28 | def __init__(self,
29 | data,
30 | out_channels: int = 2,
31 | epochs: int = 100,
32 | lr: float = 7e-4,
33 | weight_decay = 9e-4,
34 |                  beta: float = 1e-4,
35 | log_dir: str = None,
36 | verbose: bool = True,
37 | seed: int = None,
38 | **kwargs):
39 | self.num_features = data.num_features
40 | self.out_channels = out_channels
41 | self.data = data
42 | self.epochs = epochs
43 | self.lr = lr
44 | self.verbose = verbose
45 | self.logging = True if log_dir is not None else False
46 | if self.logging:
47 | self.train_logger = tb.SummaryWriter(os.path.join(log_dir, "train"))
48 | self.test_logger = tb.SummaryWriter(os.path.join(log_dir, "test"))
49 | if beta is not None:
50 | self.beta = beta
51 |         else:
52 |             self.beta = 1 / self.data.num_nodes  # default: scale KL term by the number of nodes
53 | self.weight_decay = weight_decay
54 | self.seed = seed
55 |
56 |
57 | def __repr__(self) -> str:
58 | return f"Hyperparameters\n"\
59 | f"epochs: {self.epochs}, lr: {self.lr}, beta: {self.beta:.4f}\n"
60 |
61 |
62 | # split data
63 | def _transform_data(self,
64 | x_noise: float = None,
65 | x_dropout: float = None,
66 | edge_noise: float = None,
67 | **kwargs
68 | ):
69 | """
70 | Args:
71 |         x_noise: Standard deviation of multiplicative white noise applied to the training data x.
72 |         x_dropout: Fraction of data.x entries forced to zero; edge_noise: add (positive) or retain (negative) that fraction of training edges.
73 |         """
74 | # from copy import deepcopy
75 | # data_ = deepcopy(self.data)
76 | if x_dropout is not None:
77 | mask = torch.FloatTensor(self.data.x.shape).uniform_() > x_dropout # % zeros
78 | self.data.x = self.data.x * mask
79 | print(f"force zeros to data x, dropout: {x_dropout}")
80 |
81 | self.train_data, self.val_data, self.test_data = split_data(data = self.data, **kwargs) # fixed split
82 | if x_noise is not None: # white noise on X
83 | gamma = x_noise * torch.randn(self.train_data.x.shape)
84 | self.train_data.x = 2**gamma * self.train_data.x
85 | print(f"add white noise to training data x, level: {x_noise} SD")
86 |
87 | # if x_dropout is not None:
88 | # mask = torch.FloatTensor(self.train_data.x.shape).uniform_() > x_dropout # % zeros
89 | # self.train_data.x = self.train_data.x * mask
90 | # print(f"force zeros to training data x, dropout: {x_dropout}")
91 |
92 | if edge_noise is not None:
93 | n_pos_edge = self.train_data.pos_edge_label_index.shape[1]
94 | n = int(abs(edge_noise) * n_pos_edge)
95 | print("Before:", self.train_data)
96 | print("\n")
97 | if edge_noise > 0: # fold edges
98 | from torch_geometric.utils import negative_sampling
99 |                 fake_pos_edge = negative_sampling(self.data.edge_index,  # sample non-edges of the full graph so fake edges cannot leak into the test set
100 |                                                   num_neg_samples = n,
101 |                                                   num_nodes = len(self.train_data.x))
102 |                 new_pos_edge = torch.unique(torch.cat((self.train_data.pos_edge_label_index, fake_pos_edge), 1), dim = 1)
103 |                 new_edge = torch.cat((new_pos_edge, new_pos_edge[torch.LongTensor([1, 0])]), 1)  # append swapped copies to keep edges undirected
104 |                 print(f"add noise to training edges, level: {edge_noise}: add {n} edges")
105 | else: # ratio edges
106 | if n > n_pos_edge:
107 |                     raise ValueError("cannot retain more edges than the total")
108 | weights = torch.tensor([1/n_pos_edge] * n_pos_edge, dtype = torch.float) # uniform weights
109 | index = weights.multinomial(num_samples = n, replacement = False)
110 | new_pos_edge = self.train_data.pos_edge_label_index[:, index]
111 | new_edge = torch.cat((new_pos_edge, new_pos_edge[torch.LongTensor([1, 0])]), 1)
112 | # new_neg_edge = self.train_data.neg_edge_label_index[:, index]
113 | # self.train_data.neg_edge_label_index = new_neg_edge
114 |                 print(f"retain a portion of the training edges, level: {abs(edge_noise)}: keep {n} edges")
115 | self.train_data.edge_index, self.train_data.pos_edge_label_index = new_edge, new_pos_edge
116 | print("After:", self.train_data)
117 |
118 |
119 | # refer to https://github.com/pyg-team/pytorch_geometric/blob/master/examples/autoencoder.py
120 | def train(self, **kwargs):
121 | global_step = 0
122 | if self.seed is not None:
123 | torch.manual_seed(self.seed) # 8096
124 | self._transform_data(**kwargs) # get split data w/o noise
125 | self.model = VGAE(VariationalGCNEncoder(self.num_features, self.out_channels))
126 |
127 | # to cuda
128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
129 | self.model = self.model.to(device)
130 | self.train_data, self.val_data, self.test_data = self.train_data.to(device), self.val_data.to(device), self.test_data.to(device)
131 |
132 | optimizer = torch.optim.Adam(self.model.parameters(), lr = self.lr, weight_decay = self.weight_decay)
133 | # optimizer = torch.optim.SGD(self.model.parameters(), lr = self.lr, momentum = 0.9, weight_decay = 5e-4)
134 |
135 | for epoch in range(self.epochs):
136 | self.model.train()
137 | optimizer.zero_grad()
138 | z = self.model.encode(self.train_data.x, self.train_data.edge_index)
139 | recon_loss = self.model.recon_loss(z, self.train_data.pos_edge_label_index, self.train_data.neg_edge_label_index)
140 | kl_loss = self.beta * self.model.kl_loss() # beta-VAE
141 | loss = recon_loss + kl_loss
142 | if self.logging:
143 | self.train_logger.add_scalar("loss", loss.item(), global_step)
144 | self.train_logger.add_scalar("recon_loss", recon_loss.item(), global_step)
145 | self.train_logger.add_scalar("kl_loss", kl_loss.item(), global_step)
146 | loss.backward()
147 | optimizer.step()
148 | global_step += 1
149 |
150 | with torch.no_grad():
151 | self.model.eval()
152 | z = self.model.encode(self.test_data.x, self.test_data.edge_index)
153 | auc, ap = self.model.test(z, self.test_data.pos_edge_label_index, self.test_data.neg_edge_label_index)
154 | if self.logging:
155 | self.test_logger.add_scalar("AUROC", auc, global_step)
156 | self.test_logger.add_scalar("AP", ap, global_step)
157 | if self.verbose:
158 | print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, AUROC: {auc:.4f}, AP: {ap:.4f}")
159 | self.final_metrics = global_step, loss.item(), auc, ap
160 |
161 |
162 | def save_model(self, name: str):
163 | """
164 | name: str, name of .th file that will be saved in model.
165 | """
166 | if isinstance(self.model, VGAE):
167 | os.makedirs("model", exist_ok=True)
168 | path = os.path.join("model", f"{name}.th")
169 | print(f"save model parameters to {path}")
170 |             torch.save(self.model.state_dict(), path)  # use the path printed above
171 | else:
172 | raise ValueError(f"model type {type(self.model)} not supported")
173 |
174 |
175 | def load_model(self, name: str):
176 | """
177 | name: str, name of .th file saved in model.
178 | """
179 | r = VGAE(VariationalGCNEncoder(self.num_features, self.out_channels))
180 | path = os.path.join("model", f"{name}.th")
181 | print(f"load model parameters from {path}")
182 | r.load_state_dict(
183 | torch.load(os.path.join("model", f"{name}.th"), map_location=torch.device("cpu"))
184 | )
185 | self.model = r
186 |
187 |
188 | # after training
189 | def get_latent_vars(self, data, plot_latent_mu = False):
190 | """
191 | data: torch_geometric.data.data.Data.
192 | plot_latent_z: bool, whether to plot nodes with random sampled latent features.
193 | """
194 | self.model.eval()
195 | _ = self.model.encode(data.x, data.edge_index)
196 | z_m = self.model.__mu__.detach().numpy()
197 | z_S = (self.model.__logstd__.exp() ** 2).detach().numpy() # variance
198 | if plot_latent_mu:
199 | # z_np = z.detach().numpy()
200 | fig, ax = plt.subplots(figsize=(6, 6), dpi=80)
201 | if z_m.shape[1] == 2:
202 | ax.scatter(z_m[:, 0], z_m[:, 1], s=4)
203 | elif z_m.shape[1] == 1:
204 | ax.hist(z_m, bins=60)
205 | elif z_m.shape[1] == 3:
206 | from mpl_toolkits.mplot3d import Axes3D
207 | ax = Axes3D(fig)
208 | ax.scatter(z_m[:, 0], z_m[:, 1], z_m[:, 2], s=4)
209 | plt.show()
210 | if z_m.shape[1] == 1:
211 | z_m = z_m.flatten()
212 | z_S = z_S.flatten()
213 | return z_m, z_S
214 |
215 |
216 | def pmt(self, data_v, n = 100, by = "KL"):
217 | """
218 | data_v: torch_geometric.data.data.Data, virtual data.
219 | n: int, # of permutation.
220 | by: str, method for distance calculation of distributions.
221 | """
222 |         # permute the cell order and recompute distances to build a null distribution
223 |         n_cell = self.data.x.shape[1]
224 |         dis_p = []
225 |         np.random.seed(0)
226 |         for _ in tqdm(range(n), desc="Permuting", total = n):
227 | # p: pmt, v: virtual, m: mean, S: sigma
228 | idx_pmt = np.random.choice(
229 | np.arange(n_cell), size=n_cell
230 | ) # bootstrap cell labels
231 | data_WT_p = Data(
232 | # x=torch.tensor(data.x[:, idx_pmt], dtype=torch.float),
233 | x=self.data.x[:, idx_pmt].clone().detach().requires_grad_(False),
234 | edge_index=self.data.edge_index,
235 | )
236 | data_KO_p = Data(
237 | # x=torch.tensor(data_v.x[:, idx_pmt], dtype=torch.float),
238 | x=data_v.x[:, idx_pmt].clone().detach().requires_grad_(False),
239 | edge_index=data_v.edge_index,
240 | ) # construct virtual data (KO) based on pmt WT
241 | z_mp, z_Sp = self.get_latent_vars(data_WT_p)
242 | z_mvp, z_Svp = self.get_latent_vars(data_KO_p)
243 | if by == "KL":
244 | dis_p.append(get_distance(z_mvp, z_Svp, z_mp, z_Sp)) # order KL
245 | # if by == 'reverse_KL':
246 | # dis_p.append(get_distance(z_mp, z_Sp, z_mvp, z_Svp)) # reverse order KL
247 | if by == "t":
248 | dis_p.append(get_distance(z_mvp, z_Svp, z_mp, z_Sp, by="t"))
249 | if by == "EMD":
250 | dis_p.append(get_distance(z_mvp, z_Svp, z_mp, z_Sp, by="EMD"))
251 | return np.array(dis_p)
252 |
253 |
254 | def eva(args):
255 | from .preprocesing import load_gdata
256 | CURRENT_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
257 | load_path = os.path.join(CURRENT_DIR, args.dir)
258 | data = load_gdata(load_path, "data")
259 |
260 | sensei = VGAE_trainer(data,
261 | epochs = args.epochs,
262 | lr = args.lr,
263 | log_dir = args.logdir,
264 | beta = args.beta,
265 | verbose = args.verbose,
266 | seed = args.seed,
267 | )
268 | sensei.train(edge_noise = args.e_noise,
269 | x_noise = args.x_noise,
270 | x_dropout = args.x_dropout,
271 |                  dir = args.dir, load = False,  # split freshly instead of loading a saved split
272 | )
273 | epoch, loss, auc, ap = sensei.final_metrics
274 |
275 | save_path = os.path.join(CURRENT_DIR, "train_log")
276 | os.makedirs(save_path, exist_ok = True)
277 | try:
278 | f = open(os.path.join(save_path, f'{args.train_out}.txt'), 'r')
279 | except IOError:
280 | f = open(os.path.join(save_path, f'{args.train_out}.txt'), 'w')
281 | f.writelines("Epoch,Loss,AUROC,AP\n")
282 | finally:
283 | f = open(os.path.join(save_path, f'{args.train_out}.txt'), 'a')
284 | f.writelines(f"{epoch:03d}, {loss:.4f}, {auc:.4f}, {ap:.4f}\n")
285 | f.close()
286 | if args.do_test:
287 | data_ko = load_gdata(load_path, "data_ko")
288 | print("continue")
289 | from .utils import get_generank, get_r2_score
290 | z_mu, z_std = sensei.get_latent_vars(data)
291 | z_mu_KO, z_std_KO = sensei.get_latent_vars(data_ko)
292 | dis = get_distance(z_mu_KO, z_std_KO, z_mu, z_std, by = 'KL')
293 | null = sensei.pmt(data_ko, n = 100, by = 'KL')
294 | res = get_generank(data, dis, null, save_significant_as = args.generank_out)
295 | geneset = list(res.index)
296 | r2, r2_adj = get_r2_score(data, geneset[1:], geneset[0])
297 | try:
298 | f = open(os.path.join(save_path, f'{args.r2_out}.txt'), 'r')
299 | except IOError:
300 | f = open(os.path.join(save_path, f'{args.r2_out}.txt'), 'w')
301 | f.writelines("R2,R2_adj\n")
302 | finally:
303 | f = open(os.path.join(save_path, f'{args.r2_out}.txt'), 'a')
304 | f.writelines(f"{r2:.4f}, {r2_adj:.4f}\n")
305 | f.close()
306 |
307 |
308 | if __name__ == "__main__":
309 | import argparse
310 | parser = argparse.ArgumentParser()
311 |     parser.add_argument('--dir', type = str, default = "data")
312 | parser.add_argument('--epochs', type = int, default = 100)
313 | parser.add_argument('--lr', type = float, default = 7e-4)
314 | parser.add_argument('--beta', type = float, default = 1e-4)
315 | parser.add_argument('--seed', type = int, default = None)
316 | parser.add_argument('--logdir', type = str, default = None)
317 | parser.add_argument('-E', '--e_noise', type = float, default = None)
318 | parser.add_argument('-X', '--x_noise', type = float, default = None)
319 | parser.add_argument('-XO', '--x_dropout', type = float, default = None)
320 | parser.add_argument('-v', '--verbose', action = "store_true")
321 | parser.add_argument('--train_out', type = str, default = "train_log")
322 |
323 | parser.add_argument('--do_test', action = "store_true")
324 | parser.add_argument('--generank_out', type = str, default = "gene_list")
325 | parser.add_argument('--r2_out', type = str, default = "r2_score")
326 | args = parser.parse_args()
327 |
328 | eva(args)
329 | # python -m GenKI.train --dir data --logdir log_dir/run0 --seed 8096 -E -0.1 -v
330 | # python -m GenKI.train --dir data_covid --logdir log_dir/sigma0_covid --lr 5e-3 --seed 8096 -v
331 | # python -m GenKI.train --dir data --train_out train_log --do_test --generank_out genelist --r2_out r2_score -v
332 |
333 |
334 |
--------------------------------------------------------------------------------
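The noise arguments documented in `_transform_data` flow through `train(**kwargs)`. A sketch of how they are used (hypothetical values; `data` is a `Data` object from `DataLoader.load_data()`):

```python
from GenKI.train import VGAE_trainer

sensei = VGAE_trainer(data, epochs=100, lr=7e-4, beta=1e-4, seed=8096, verbose=False)
sensei.train()                   # clean training on a fixed train/val/test split
# robustness experiments (each call re-splits and retrains):
# sensei.train(x_noise=0.1)      # multiply training x by 2**N(0, 0.1)
# sensei.train(x_dropout=0.2)    # zero out 20% of entries in data.x
# sensei.train(edge_noise=-0.1)  # negative value: retain only 10% of training edges
print(sensei.final_metrics)      # (step, loss, AUROC, AP) on the test split
```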
/GenKI/tune.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .model import VGAE
3 | from .train import VariationalGCNEncoder
4 | import numpy as np
5 | from ray import tune
6 | from ray.tune.schedulers import ASHAScheduler
7 | import os
8 | from .preprocesing import split_data
9 |
10 |
11 | hyperparams = {
12 | "lr": tune.sample_from(lambda _: np.random.randint(1, 10)*(0.1**np.random.randint(1, 4))),
13 | "beta": tune.sample_from(lambda _: np.random.randint(1, 10)*(0.1**np.random.randint(1, 5))),
14 | "seed": tune.randint(0, 10000),
15 | "weight_decay": tune.sample_from(lambda _: np.random.randint(1, 10)*(0.1**np.random.randint(3, 7)))
16 | }
17 |
18 |
19 | def train(config, checkpoint_dir = None):
20 | train_data, val_data, test_data = split_data(dir = "data", load = True)
21 | torch.manual_seed(config["seed"])
22 | model = VGAE(VariationalGCNEncoder(train_data.num_features, 2))
23 |
24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 | model = model.to(device)
26 | train_data, val_data, test_data = train_data.to(device), val_data.to(device), test_data.to(device)
27 | optimizer = torch.optim.Adam(model.parameters(), lr = config["lr"], weight_decay = config["weight_decay"])
28 | # optimizer = torch.optim.SGD(model.parameters(), lr = config["lr"], momentum = 0.9, weight_decay = config["weight_decay"])
29 |
30 | # when restore a checkpoint
31 | if checkpoint_dir:
32 | checkpoint = os.path.join(checkpoint_dir, "checkpoint")
33 | model_state, optimizer_state = torch.load(checkpoint)
34 | model.load_state_dict(model_state)
35 | optimizer.load_state_dict(optimizer_state)
36 |
37 | for epoch in range(1000): # search < max_num_epochs
38 | model.train()
39 | optimizer.zero_grad()
40 | z = model.encode(train_data.x, train_data.edge_index)
41 | recon_loss = model.recon_loss(z, train_data.pos_edge_label_index)
42 | kl_loss = config["beta"] * model.kl_loss() # beta-VAE
43 | loss = recon_loss + kl_loss
44 | loss.backward()
45 | optimizer.step()
46 |
47 | # valid set
48 | with torch.no_grad():
49 | model.eval()
50 | z = model.encode(val_data.x, val_data.edge_index)
51 | auc, ap = model.test(z, val_data.pos_edge_label_index, val_data.neg_edge_label_index)
52 | # print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, AUROC: {auc:.4f}, AP: {ap:.4f}')
53 |
54 | # save a checkpoint
55 | with tune.checkpoint_dir(step = epoch) as checkpoint_dir:
56 | path = os.path.join(checkpoint_dir, "checkpoint")
57 | torch.save(
58 | (model.state_dict(), optimizer.state_dict()), path)
59 | # record metrics from valid set
60 | tune.report(Loss = loss.item(), AUROC = auc, AP = ap, F_ = auc*ap) # call: shown as training_iteration
61 | print("Finished Training")
62 |
63 |
64 | def test_best_model(best_trial):
65 | train_data, val_data, test_data = split_data(dir = "data", load = True)
66 | torch.manual_seed(best_trial.config["seed"])
67 | best_trained_model = VGAE(VariationalGCNEncoder(train_data.num_features, 2))
68 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69 | best_trained_model = best_trained_model.to(device)
70 | train_data, val_data, test_data = train_data.to(device), val_data.to(device), test_data.to(device)
71 |
72 | checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
73 | model_state, optimizer_state = torch.load(checkpoint_path)
74 | best_trained_model.load_state_dict(model_state)
75 |
76 | # test set
77 | with torch.no_grad():
78 | best_trained_model.eval()
79 | z = best_trained_model.encode(test_data.x, test_data.edge_index)
80 | auc, ap = best_trained_model.test(z, test_data.pos_edge_label_index, test_data.neg_edge_label_index)
81 | print(f"Best trial test AUROC: {auc:.4f}")
82 | print(f"Best trial test AP: {ap:.4f}")
83 |
84 |
85 | def main(num_samples,
86 | max_num_epochs = 100,
87 | gpus_per_trial = 0,
88 | save_as: str = "tune_result"):
89 | scheduler = ASHAScheduler(
90 | max_t = max_num_epochs, # max iteration
91 | grace_period = 10, # stop at least after this iteration
92 | reduction_factor = 2)
93 | result = tune.run(train,
94 | config = hyperparams,
95 | num_samples = num_samples, # trials: sets sampled from grid of hyperparams
96 | name = "experiment", # saved folder name
97 | metric = "F_",
98 | mode = "max",
99 | resources_per_trial={"cpu": 4, "gpu": gpus_per_trial},
100 | scheduler = scheduler, # prune bad runs
101 | # stop = {'training_iteration':100}, # when tune.report was called
102 | )
103 | best_trial = result.get_best_trial("F_", "max", "last")
104 |     print(f"Best trial config: {best_trial.config}")
105 | print("Best trial valid AUROC: {}".format(best_trial.last_result["AUROC"]))
106 | print("Best trial valid AP: {}".format(best_trial.last_result["AP"]))
107 | test_best_model(best_trial) # only should run once
108 | save_path = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), f"{save_as}.csv")
109 | result.dataframe().to_csv(save_path)
110 |
111 |
112 | if __name__ == "__main__":
113 | # seed for Ray Tune's random search algorithm
114 | np.random.seed(42)
115 | main(num_samples = 50)
116 | # report: https://docs.ray.io/en/latest/tune/tutorials/tune-metrics.html#tune-autofilled-metrics
117 |
118 |
--------------------------------------------------------------------------------
/GenKI/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import scipy
5 | import matplotlib.pyplot as plt
6 | from sklearn.manifold import TSNE
7 | from sklearn.preprocessing import StandardScaler
8 | from scipy.cluster.vq import kmeans2
9 | import math
10 |
11 |
12 | def boxcox_norm(x):
13 | """
14 |     x: 1-D array-like, requires positive values.
15 |     Box-Cox transform and standardize x.
16 | """
17 | xt, _ = scipy.stats.boxcox(x)
18 | return StandardScaler().fit_transform(xt[:, None]).flatten()
19 | # (xt - xt.mean())/np.sqrt(xt.var()) # z-score for 1-D
20 |
21 |
22 | def _t_stat(m0, S0, m1, S1):
23 | return (m0 - m1) / math.sqrt(
24 | S0**2 + S1**2
25 | )
26 |
27 |
28 | def _kl_1d(m0, S0, m1, S1):
29 | """
30 | KL divergence between two gaussian distributions.
31 | https://stats.stackexchange.com/questions/234757/how-to-use-kullback-leibler-divergence-if-mean-and-standard-deviation-of-of-two
32 | """
33 | return 0.5 * math.log(S1 / S0) + (S0**2 + (m0 - m1) ** 2) / (2 * S1**2) - 1 / 2
34 |
35 |
36 | def _kl_mvn(m0, S0, m1, S1):
37 | """
38 | KL divergence between two multivariate gaussian distributions.
39 | """
40 | # store inv diag covariance of S1 and diff between means
41 | N = m0.shape[0]
42 | iS1 = np.linalg.pinv(S1) # pseudo-inverse
43 | diff = m1 - m0
44 | tr_term = np.trace(iS1 @ S0)
45 | det_term = np.log(
46 | np.linalg.det(S1) / np.linalg.det(S0)
47 | ) # np.sum(np.log(S1)) - np.sum(np.log(S0))
48 | quad_term = diff.T @ iS1 @ diff # np.sum( (diff*diff) * iS1, axis=1)
49 | return 0.5 * (tr_term + det_term + quad_term - N)
50 |
51 |
52 | def _wasserstein_dist2(m0, S0, m1, S1):
53 | """
54 |     Squared 2-Wasserstein (earth mover's) distance between two gaussians; np.sqrt is element-wise, so S0/S1 are assumed diagonal.
55 | """
56 | return np.square(np.linalg.norm(m0 - m1)) + np.trace(
57 | S0 + S1 - 2 * np.sqrt(np.sqrt(S0) @ S1 @ np.sqrt(S0))
58 | )
59 |
60 |
61 | def get_distance(z_m0, z_S0, z_m1, z_S1, by = "KL"):
62 | """
63 | z_m0, z_S0: latent means and sigma from WT data.
64 | z_mv, z_Sv: latent means and sigma from virtual data.
65 | """
66 | dis = []
67 | try:
68 | out_channels = z_m0.shape[1] # out_channels >= 2
69 | for m0, S0, m1, S1 in zip(z_m0, z_S0, z_m1, z_S1):
70 | temp_S0 = np.zeros((out_channels, out_channels), float)
71 | np.fill_diagonal(temp_S0, S0)
72 | temp_S1 = np.zeros((out_channels, out_channels), float)
73 | np.fill_diagonal(temp_S1, S1)
74 | if by == "KL":
75 | dis.append(_kl_mvn(m0, temp_S0, m1, temp_S1))
76 | # dis.append(
77 | # (
78 | # _kl_mvn(m0, temp_S0, m1, temp_S1)
79 | # + _kl_mvn(m1, temp_S1, m0, temp_S0)
80 | # )
81 | # / 2
82 | # )
83 | if by == "EMD":
84 | dis.append(_wasserstein_dist2(m0, temp_S0, m1, temp_S1))
85 | except IndexError: # when out_channels == 1
86 | for m0, S0, m1, S1 in zip(z_m0, z_S0, z_m1, z_S1):
87 | if by == "KL":
88 | dis.append((_kl_1d(m0, S0, m1, S1) + _kl_1d(m1, S1, m0, S0)) / 2)
89 | if by == "t":
90 | dis.append(_t_stat(m0, S0, m1, S1))
91 | return np.array(dis)
92 |
93 |
94 | def get_generank(
95 | data,
96 | distance,
97 | null = None,
98 | rank: bool = True,
99 | reverse: bool = False,
100 | bagging: float = 0.05,
101 | cutoff: float = 0.95,
102 | save_significant_as: str = None,
103 | ):
104 | """
105 | data: torch_geometric.data.data.Data.
106 | distance: array-like, output of get_distance.
107 | null: array-like, output of pmt.
108 | bagging: threshold for bagging top names at each permutation.
109 | cutoff: threshold for frequency of bagging after all permutations.
110 | save_significant_as: .txt file name of significant genes that will be saved in result for enrichment test.
111 | """
112 | if null is not None:
113 | if reverse:
114 | idx = np.argsort(-null, axis=1)
115 | else:
116 | idx = np.argsort(null, axis=1)
117 | thre = int(len(data.y) * (1 - bagging)) # 95%
118 | idx = idx[:, thre:].flatten() # bagging index of top ranked genes
119 |
120 | y = np.bincount(idx)
121 | ii = np.nonzero(y)[0]
122 | freq = np.vstack((ii, y[ii])).T # [gene_index, #hit]
123 | df_KL = pd.DataFrame(
124 | data=distance[freq[:, 0]][:, None],
125 | index=np.array(data.y)[freq[:, 0]],
126 | columns=["dis"],
127 | )
128 | df_KL[["index", "hit"]] = freq.astype(int)
129 | hit = int(null.shape[0] * cutoff)
130 | df_KL = df_KL[(df_KL.hit >= hit) & (df_KL.dis != 0)]
131 | if rank:
132 | df_KL.sort_values(by=["dis", "hit"], ascending=reverse, inplace=True)
133 | if save_significant_as is not None:
134 | output = list(df_KL.index)
135 | os.makedirs("result", exist_ok=True)
136 | np.savetxt(
137 | os.path.join("result", f"{save_significant_as}.txt"), output, fmt="%s", delimiter=","
138 | )
139 | print(f"save {len(output)} genes to \"./result/{save_significant_as}.txt\"")
140 | else:
141 | df_KL = pd.DataFrame(
142 | data=distance, index=np.array(data.y), columns=["dis"]
143 | ) # filter pseudo values
144 | if rank:
145 | df_KL.sort_values(by=["dis"], ascending=reverse, inplace=True)
146 | if rank:
147 | df_KL["rank"] = np.arange(len(df_KL)) + 1
148 | return df_KL
149 |
150 |
151 | def get_generank_gsea(data,
152 | distance,
153 | reverse: bool = False,
154 | save_as: str = None):
155 | """
156 | data: torch_geometric.data.data.Data,
157 | distance: array-like, output of get_distance.
158 | save_as: str, .txt file name that will be saved in result for GSEA.
159 | """
160 | df_gsea = get_generank(data, distance, reverse=reverse)
161 |     df_gsea = df_gsea[df_gsea["dis"] > 0]  # drop non-positive distances (pseudo-inverse artifacts)
162 |
163 | # Box-cox
164 | df_gsea["dis_norm"] = boxcox_norm(df_gsea["dis"]) # z-score
165 |
166 | df_gsea.sort_values(by="dis_norm", inplace=True, ascending=reverse)
167 | output_gsea = np.stack(
168 | (df_gsea.index, df_gsea["dis_norm"])
169 | ).T # GSEA: [gene_name, value]
170 |
171 | if save_as is not None:
172 | os.makedirs("result", exist_ok=True)
173 | np.savetxt(
174 |         os.path.join("result", f"GSEA_{save_as}.txt"), output_gsea, fmt="%s", delimiter="\t"
175 | )
176 | print(f"save ranked genes to \"./result/GSEA_{save_as}.txt\"")
177 | return df_gsea
178 |
179 |
180 | def get_r2_score(data, geneset: list, target_gene: str):
181 | import statsmodels.api as sm
182 | # np.random.seed(4)
183 |
184 | X = data.x.numpy().T # standardized counts
185 | x_genes_idx = [data.y.index(g) for g in geneset]
186 | y_genes_idx = data.y.index(target_gene)
187 | x_genes = X[:, x_genes_idx].copy()#.toarray()
188 | y_genes = X[:, y_genes_idx].copy()#.toarray()
189 | # print(x_genes.shape, y_genes.shape)
190 |
191 | result = sm.OLS(y_genes, x_genes).fit() # y, X
192 | return result.rsquared, result.rsquared_adj
193 | # for _ in range(30):
194 | # random_idx = np.random.choice(X.shape[1], x_genes.shape[1], replace = False)
195 | # random_genes = X[:, random_idx].copy()#.toarray()
196 | # result = sm.OLS(y_genes, random_genes).fit()
197 | # r2.append([result.rsquared, result.rsquared_adj])
198 |
199 |
200 | def get_sys_KO_cluster(
201 | obj,
202 | sys_res: np.ndarray,
203 | perplexity=25,
204 | n_cluster=50,
205 | save_as=None,
206 | show_TSNE=True,
207 | verbose=False,
208 | ):
209 | """
210 | obj: a sc object.
211 | sys_res: np.ndarray, output of run_sys_KO.
212 | perplexity: int, hyperparameter of TSNE.
213 | n_cluster: int, hyperparameter of k-means.
214 | show_TSNE: bool, whether to show the TSNE plot after clustering.
215 | """
216 | np.random.seed(100)
217 | scaled_sys_res = StandardScaler().fit_transform(
218 | sys_res
219 | ) # standarize features (cols)
220 | X_embedded = TSNE(
221 | n_components=2,
222 | learning_rate="auto",
223 | perplexity=perplexity,
224 | metric="euclidean",
225 | init="pca",
226 | ).fit_transform(scaled_sys_res)
227 | _, label = kmeans2(X_embedded, n_cluster, minit="points") # centroid
228 | cluster_idx = np.where(label == label[obj(obj._target_gene)])
229 | cluster_gene_names = np.array(obj._gene_names)[cluster_idx]
230 | if verbose:
231 | print(
232 | f"TSNE perplexity: {perplexity}, # Clustering: {n_cluster}\nThe cluster containing {obj._target_gene} has {len(cluster_gene_names)} genes"
233 | )
234 | if save_as is not None:
235 | os.makedirs("result", exist_ok=True)
236 | np.savetxt(
237 | os.path.join("result", f"sys_KO_{save_as}.txt"),
238 | cluster_gene_names,
239 | fmt="%s",
240 | delimiter="\t",
241 | )
242 | print(f"save ranked genes to \"./result/sys_KO_{save_as}.txt\"")
243 | if show_TSNE:
244 | fig, ax = plt.subplots(figsize=(8, 8), dpi=80)
245 | colors = [
246 | "red" if i == label[obj(obj._target_gene)] else "black" for i in label
247 | ]
248 | ax.set_title(f"TSNE plot of systematic KO {len(sys_res)} genes")
249 | ax.scatter(X_embedded[:, 0], X_embedded[:, 1], s=3, c=colors, alpha=0.5)
250 | ax.axis("tight")
251 | ax.annotate(
252 | obj._target_gene,
253 | xy=(X_embedded[obj(obj._target_gene)]),
254 | xycoords="data",
255 |             xytext=(15, 15),  # offset of the label from the point, in points
256 |             textcoords="offset points",
257 | arrowprops=dict(arrowstyle="->", connectionstyle="arc3"),
258 | )
259 | plt.show()
260 | return cluster_gene_names
261 |
262 |
263 | if __name__ == "__main__":
264 |     # sanity check: the KL divergence between two identical Gaussians is zero
265 |     p_mean, q_mean = np.array([0, 0]), np.array([0, 0])
266 |     p_cov = np.array([[1, 0], [0, 1]])
267 |     q_cov = np.array([[1, 0], [0, 1]])
268 |     assert _kl_mvn(p_mean, p_cov, q_mean, q_cov) == 0
269 |
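270 | # Illustrative sketch (comments only, not executed): how the utilities above
271 | # chain together. `data_wt`, `dis`, and `null` follow notebook/Example.ipynb,
272 | # and the gene name is a placeholder:
273 | #
274 | #   res = get_generank(data_wt, dis, null)                        # permutation-filtered ranking
275 | #   df_gsea = get_generank_gsea(data_wt, dis, save_as="demo")    # Box-Cox z-scores for GSEA
276 | #   r2, r2_adj = get_r2_score(data_wt, list(res.index), "TUBG1") # variance explained by hits
277 | #   genes = get_sys_KO_cluster(obj, sys_res, show_TSNE=False)    # obj/sys_res from run_sys_KO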
--------------------------------------------------------------------------------
/GenKI/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1.dev"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Yongjian Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjgeno/GenKI/6d69789b89859eda75ac75dfb1cc00ef190ada41/MANIFEST.in
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GenKI (Gene Knock-out Inference)
2 | A VGAE (Variational Graph Auto-Encoder) based model to learn perturbation using scRNA-seq data.
3 | New! Example data is now available (see `data/README.md`).
4 | [Paper](https://doi.org/10.1093/nar/gkad450)
5 |
6 |
7 |
8 |
9 |
10 |
11 | ### Install dependencies
12 | First install the dependencies of GenKI with `conda`:
13 | ```shell
14 | conda env create -f environment.yml
15 | conda activate ogenki
16 | ```
17 | Then install `pytorch-geometric` following its installation guide:
18 | https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html
19 |
20 |
21 |
22 | ### Install GenKI with `pip`:
23 | ```shell
24 | pip install git+https://github.com/yjgeno/GenKI.git
25 | ```
26 | or install it manually from source:
27 | ```shell
28 | git clone https://github.com/yjgeno/GenKI.git
29 | cd GenKI
30 | pip install .
31 | ```
32 |
33 |
34 | #### Tutorial
35 | Virtual KO experiment:
https://github.com/yjgeno/GenKI/blob/master/notebook/Example.ipynb
36 |
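37 | A condensed sketch of the Python workflow from the example notebook (the `.h5ad`
38 | path and the KO gene `TUBG1` come from that notebook; substitute your own data):
39 | ```python
40 | import GenKI as gk
41 | from GenKI.preprocesing import build_adata
42 | from GenKI.dataLoader import DataLoader
43 | from GenKI.train import VGAE_trainer
44 | from GenKI import utils
45 | 
46 | adata = build_adata("data/microglial_seurat_WT.h5ad")
47 | wrapper = DataLoader(adata,
48 |                      target_gene=["TUBG1"],  # gene(s) to knock out in silico
49 |                      target_cell=None,       # None: use all cells
50 |                      obs_label="ident",
51 |                      GRN_file_dir="GRNs",
52 |                      rebuild_GRN=True,
53 |                      pcNet_name="pcNet_example")
54 | data_wt = wrapper.load_data()    # wild-type graph
55 | data_ko = wrapper.load_kodata()  # virtual-KO graph
56 | 
57 | sensei = VGAE_trainer(data_wt, epochs=100, lr=7e-4, beta=1e-4, seed=8096, log_dir=None)
58 | sensei.train()
59 | 
60 | z_mu_wt, z_std_wt = sensei.get_latent_vars(data_wt)
61 | z_mu_ko, z_std_ko = sensei.get_latent_vars(data_ko)
62 | dis = gk.utils.get_distance(z_mu_ko, z_std_ko, z_mu_wt, z_std_wt, by="KL")
63 | print(utils.get_generank(data_wt, dis, rank=True).head())  # genes most perturbed by the KO
64 | ```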
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | * The microglial (Trem2-KO) dataset is now available at [GoogleDrive](https://drive.google.com/file/d/1tG9bUGCsWqhg0hJ94lDLtLl8WLl0hDks/view?usp=sharing).
2 | * Other datasets may be provided upon request.
3 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - conda-forge
3 |
4 | dependencies:
5 | - python~=3.9.6
6 | - pip>=22.0.4
7 | # - pycairo>=1.21.0
8 | - anndata>=0.7.4
9 | - numpy>=1.17.0
10 | - matplotlib-base>=3.1.2
11 | - pandas>=0.21
12 | - scipy>=1.4
13 | - seaborn
14 | - h5py>=3
15 | - tqdm
16 | # - statsmodels~=0.13.1
17 | - patsy
18 | - networkx>=2.3
19 | - natsort
20 | - joblib
21 | - numba>=0.41.0
22 | - umap-learn>=0.3.10
23 | - packaging
24 | - setuptools-scm
25 | - black>=20.8b1
26 | - docutils
27 | - sphinx<4.2,>=4.1
28 | - sphinx_rtd_theme>=0.3.1
29 | # - python-igraph~=0.9.8
30 | - leidenalg
31 | - louvain!=0.6.2,>=0.6
32 | # - scikit-misc~=0.1.4
33 | - pytest>=4.4
34 | - pytest-nunit
35 | - dask-core!=2.17.0
36 | - fsspec
37 | - zappy
38 | - zarr
39 | - profimp
40 | - flit-core
41 | - ipywidgets
42 |
43 | - pip:
44 | - flit
45 | - session-info
46 | - scanpydoc<0.7.6,>=0.7.4
47 | - torch==1.11.0
48 | - tensorboard
49 | # - torch-geometric
50 | - scikit-learn
51 | - scikit-misc>=0.1.3
52 | - scanpy>=1.7.2
53 | - ray
54 | - notebook
55 | name: ogenki
56 |
57 | # pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cpu.html
--------------------------------------------------------------------------------
/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjgeno/GenKI/6d69789b89859eda75ac75dfb1cc00ef190ada41/logo.jpg
--------------------------------------------------------------------------------
/notebook/Compatible_matlab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "6c209667",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import os\n",
11 | "import numpy as np\n",
12 | "import pandas as pd\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import scanpy as sc\n",
15 | "\n",
16 | "sc.settings.verbosity = 0"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "id": "ffae2228",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "import GenKI as gk\n",
27 | "from GenKI.preprocesing import build_adata\n",
28 | "from GenKI.dataLoader import DataLoader\n",
29 | "from GenKI.train import VGAE_trainer\n",
30 | "from GenKI import utils\n",
31 | "\n",
32 | "%load_ext autoreload\n",
33 | "%autoreload 2"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 3,
39 | "id": "d32e28a9",
40 | "metadata": {},
41 | "outputs": [
42 | {
43 | "name": "stdout",
44 | "output_type": "stream",
45 | "text": [
46 | "load counts from data\\filtered_gene_bc_matrices\\X.txt\n",
47 | "AnnData object with n_obs × n_vars = 100 × 300\n",
48 | " obs: 'cell_type'\n",
49 | " uns: 'log1p'\n",
50 | " layers: 'raw', 'norm'\n"
51 | ]
52 | }
53 | ],
54 | "source": [
55 | "# example\n",
56 | "\n",
57 | "data_dir = os.path.join(\"data\", \"filtered_gene_bc_matrices\")\n",
58 | "counts_path = str(os.path.join(data_dir, \"X.txt\"))\n",
59 | "gene_path = str(os.path.join(data_dir, \"g.txt\"))\n",
60 | "cell_path = str(os.path.join(data_dir, \"c.txt\"))\n",
61 | "\n",
62 | "adata = build_adata(counts_path, \n",
63 | " gene_path, \n",
64 | " cell_path, \n",
65 | " meta_cell_cols=[\"cell_type\"], # colname of cell type\n",
66 | " delimiter=',', # X.txt\n",
67 | " transpose=True, # X.txt\n",
68 | " log_normalize=True,\n",
69 | " scale_data=True,\n",
70 | " )\n",
71 | "\n",
72 | "adata = adata[:100, :300].copy()\n",
73 | "print(adata)"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 4,
79 | "id": "071c88c5",
80 | "metadata": {},
81 | "outputs": [
82 | {
83 | "data": {
84 | "text/plain": [
85 | "'LAMB3'"
86 | ]
87 | },
88 | "execution_count": 4,
89 | "metadata": {},
90 | "output_type": "execute_result"
91 | }
92 | ],
93 | "source": [
94 | "# gene to ko\n",
95 | "adata.var_names[66]"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 5,
101 | "id": "1e37d540",
102 | "metadata": {
103 | "scrolled": true
104 | },
105 | "outputs": [
106 | {
107 | "name": "stdout",
108 | "output_type": "stream",
109 | "text": [
110 | "use all the cells (100) in adata\n",
111 | "build GRN\n",
112 | "ray init, using 8 CPUs\n",
113 | "execution time of making pcNet: 6.34 s\n",
114 | "GRN has been built and saved in \"GRNs\\pcNet_example.npz\"\n",
115 | "init completed\n",
116 | "\n"
117 | ]
118 | }
119 | ],
120 | "source": [
121 | "# load data\n",
122 | "\n",
123 | "data_wrapper = DataLoader(\n",
124 | " adata, # adata object\n",
125 | " target_gene = [66], # KO gene name/index\n",
126 | " target_cell = None, # obsname for cell type, if none use all\n",
127 | " obs_label = \"cell_type\", # colname for genes\n",
128 | " GRN_file_dir = \"GRNs\", # folder name for GRNs\n",
129 | " rebuild_GRN = True, # whether build GRN by pcNet\n",
130 | " pcNet_name = \"pcNet_example\", # GRN file name\n",
131 | " verbose = True, # whether verbose\n",
132 | " n_cpus = 8, # multiprocessing\n",
133 | " )\n",
134 | "\n",
135 | "data_wt = data_wrapper.load_data()\n",
136 | "data_ko = data_wrapper.load_kodata()"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 6,
142 | "id": "7862c768",
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "# init trainer\n",
147 | "\n",
148 | "hyperparams = {\"epochs\": 100, \n",
149 | " \"lr\": 7e-4, \n",
150 | " \"beta\": 1e-4, \n",
151 | " \"seed\": 8096}\n",
152 | "log_dir=None \n",
153 | "\n",
154 | "sensei = VGAE_trainer(data_wt, \n",
155 | " epochs=hyperparams[\"epochs\"], \n",
156 | " lr=hyperparams[\"lr\"], \n",
157 | " log_dir=log_dir, \n",
158 | " beta=hyperparams[\"beta\"],\n",
159 | " seed=hyperparams[\"seed\"],\n",
160 | " verbose=True,\n",
161 | " )"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": 7,
167 | "id": "31bfb6a0",
168 | "metadata": {
169 | "scrolled": false
170 | },
171 | "outputs": [
172 | {
173 | "name": "stdout",
174 | "output_type": "stream",
175 | "text": [
176 | "Epoch: 000, Loss: 1.5318, AUROC: 0.8618, AP: 0.7514\n",
177 | "Epoch: 001, Loss: 1.4803, AUROC: 0.8643, AP: 0.7555\n",
178 | "Epoch: 002, Loss: 1.5022, AUROC: 0.8668, AP: 0.7596\n",
179 | "Epoch: 003, Loss: 1.4556, AUROC: 0.8692, AP: 0.7634\n",
180 | "Epoch: 004, Loss: 1.4764, AUROC: 0.8715, AP: 0.7671\n",
181 | "Epoch: 005, Loss: 1.4815, AUROC: 0.8739, AP: 0.7708\n",
182 | "Epoch: 006, Loss: 1.4358, AUROC: 0.8760, AP: 0.7744\n",
183 | "Epoch: 007, Loss: 1.4775, AUROC: 0.8783, AP: 0.7781\n",
184 | "Epoch: 008, Loss: 1.4364, AUROC: 0.8804, AP: 0.7812\n",
185 | "Epoch: 009, Loss: 1.4689, AUROC: 0.8823, AP: 0.7839\n",
186 | "Epoch: 010, Loss: 1.4504, AUROC: 0.8842, AP: 0.7863\n",
187 | "Epoch: 011, Loss: 1.4228, AUROC: 0.8859, AP: 0.7886\n",
188 | "Epoch: 012, Loss: 1.4194, AUROC: 0.8875, AP: 0.7907\n",
189 | "Epoch: 013, Loss: 1.3903, AUROC: 0.8890, AP: 0.7928\n",
190 | "Epoch: 014, Loss: 1.3900, AUROC: 0.8904, AP: 0.7948\n",
191 | "Epoch: 015, Loss: 1.3866, AUROC: 0.8916, AP: 0.7969\n",
192 | "Epoch: 016, Loss: 1.3610, AUROC: 0.8928, AP: 0.7991\n",
193 | "Epoch: 017, Loss: 1.3729, AUROC: 0.8940, AP: 0.8016\n",
194 | "Epoch: 018, Loss: 1.3424, AUROC: 0.8953, AP: 0.8042\n",
195 | "Epoch: 019, Loss: 1.3410, AUROC: 0.8965, AP: 0.8069\n",
196 | "Epoch: 020, Loss: 1.3542, AUROC: 0.8977, AP: 0.8100\n",
197 | "Epoch: 021, Loss: 1.3635, AUROC: 0.8989, AP: 0.8131\n",
198 | "Epoch: 022, Loss: 1.3064, AUROC: 0.9001, AP: 0.8165\n",
199 | "Epoch: 023, Loss: 1.3163, AUROC: 0.9013, AP: 0.8197\n",
200 | "Epoch: 024, Loss: 1.3282, AUROC: 0.9026, AP: 0.8227\n",
201 | "Epoch: 025, Loss: 1.3091, AUROC: 0.9038, AP: 0.8250\n",
202 | "Epoch: 026, Loss: 1.2974, AUROC: 0.9050, AP: 0.8271\n",
203 | "Epoch: 027, Loss: 1.2778, AUROC: 0.9061, AP: 0.8291\n",
204 | "Epoch: 028, Loss: 1.2874, AUROC: 0.9070, AP: 0.8308\n",
205 | "Epoch: 029, Loss: 1.2723, AUROC: 0.9080, AP: 0.8323\n",
206 | "Epoch: 030, Loss: 1.2780, AUROC: 0.9087, AP: 0.8336\n",
207 | "Epoch: 031, Loss: 1.2760, AUROC: 0.9095, AP: 0.8348\n",
208 | "Epoch: 032, Loss: 1.2589, AUROC: 0.9102, AP: 0.8359\n",
209 | "Epoch: 033, Loss: 1.2620, AUROC: 0.9109, AP: 0.8369\n",
210 | "Epoch: 034, Loss: 1.2600, AUROC: 0.9114, AP: 0.8377\n",
211 | "Epoch: 035, Loss: 1.2685, AUROC: 0.9119, AP: 0.8384\n",
212 | "Epoch: 036, Loss: 1.2250, AUROC: 0.9124, AP: 0.8392\n",
213 | "Epoch: 037, Loss: 1.2224, AUROC: 0.9128, AP: 0.8399\n",
214 | "Epoch: 038, Loss: 1.2461, AUROC: 0.9132, AP: 0.8406\n",
215 | "Epoch: 039, Loss: 1.2168, AUROC: 0.9136, AP: 0.8413\n",
216 | "Epoch: 040, Loss: 1.2123, AUROC: 0.9140, AP: 0.8419\n",
217 | "Epoch: 041, Loss: 1.2369, AUROC: 0.9143, AP: 0.8425\n",
218 | "Epoch: 042, Loss: 1.2207, AUROC: 0.9146, AP: 0.8431\n",
219 | "Epoch: 043, Loss: 1.2010, AUROC: 0.9148, AP: 0.8436\n",
220 | "Epoch: 044, Loss: 1.2111, AUROC: 0.9151, AP: 0.8443\n",
221 | "Epoch: 045, Loss: 1.2058, AUROC: 0.9153, AP: 0.8451\n",
222 | "Epoch: 046, Loss: 1.1937, AUROC: 0.9155, AP: 0.8459\n",
223 | "Epoch: 047, Loss: 1.1681, AUROC: 0.9158, AP: 0.8469\n",
224 | "Epoch: 048, Loss: 1.1873, AUROC: 0.9160, AP: 0.8477\n",
225 | "Epoch: 049, Loss: 1.1639, AUROC: 0.9162, AP: 0.8484\n",
226 | "Epoch: 050, Loss: 1.1858, AUROC: 0.9164, AP: 0.8490\n",
227 | "Epoch: 051, Loss: 1.1696, AUROC: 0.9166, AP: 0.8496\n",
228 | "Epoch: 052, Loss: 1.1834, AUROC: 0.9168, AP: 0.8502\n",
229 | "Epoch: 053, Loss: 1.2075, AUROC: 0.9170, AP: 0.8507\n",
230 | "Epoch: 054, Loss: 1.1903, AUROC: 0.9171, AP: 0.8512\n",
231 | "Epoch: 055, Loss: 1.1360, AUROC: 0.9173, AP: 0.8516\n",
232 | "Epoch: 056, Loss: 1.1702, AUROC: 0.9174, AP: 0.8521\n",
233 | "Epoch: 057, Loss: 1.1352, AUROC: 0.9175, AP: 0.8525\n",
234 | "Epoch: 058, Loss: 1.1411, AUROC: 0.9176, AP: 0.8528\n",
235 | "Epoch: 059, Loss: 1.1275, AUROC: 0.9178, AP: 0.8533\n",
236 | "Epoch: 060, Loss: 1.1570, AUROC: 0.9179, AP: 0.8536\n",
237 | "Epoch: 061, Loss: 1.1273, AUROC: 0.9179, AP: 0.8539\n",
238 | "Epoch: 062, Loss: 1.1276, AUROC: 0.9179, AP: 0.8542\n",
239 | "Epoch: 063, Loss: 1.1243, AUROC: 0.9180, AP: 0.8546\n",
240 | "Epoch: 064, Loss: 1.0985, AUROC: 0.9181, AP: 0.8549\n",
241 | "Epoch: 065, Loss: 1.1350, AUROC: 0.9182, AP: 0.8551\n",
242 | "Epoch: 066, Loss: 1.1116, AUROC: 0.9182, AP: 0.8554\n",
243 | "Epoch: 067, Loss: 1.1489, AUROC: 0.9182, AP: 0.8555\n",
244 | "Epoch: 068, Loss: 1.1193, AUROC: 0.9182, AP: 0.8557\n",
245 | "Epoch: 069, Loss: 1.1360, AUROC: 0.9182, AP: 0.8558\n",
246 | "Epoch: 070, Loss: 1.1065, AUROC: 0.9182, AP: 0.8560\n",
247 | "Epoch: 071, Loss: 1.0939, AUROC: 0.9183, AP: 0.8562\n",
248 | "Epoch: 072, Loss: 1.1089, AUROC: 0.9183, AP: 0.8562\n",
249 | "Epoch: 073, Loss: 1.1342, AUROC: 0.9182, AP: 0.8562\n",
250 | "Epoch: 074, Loss: 1.0828, AUROC: 0.9182, AP: 0.8562\n",
251 | "Epoch: 075, Loss: 1.1173, AUROC: 0.9182, AP: 0.8563\n",
252 | "Epoch: 076, Loss: 1.1005, AUROC: 0.9182, AP: 0.8565\n",
253 | "Epoch: 077, Loss: 1.0860, AUROC: 0.9181, AP: 0.8566\n",
254 | "Epoch: 078, Loss: 1.0871, AUROC: 0.9181, AP: 0.8567\n",
255 | "Epoch: 079, Loss: 1.1317, AUROC: 0.9181, AP: 0.8569\n",
256 | "Epoch: 080, Loss: 1.0836, AUROC: 0.9181, AP: 0.8571\n",
257 | "Epoch: 081, Loss: 1.0848, AUROC: 0.9181, AP: 0.8573\n",
258 | "Epoch: 082, Loss: 1.0880, AUROC: 0.9182, AP: 0.8577\n",
259 | "Epoch: 083, Loss: 1.1009, AUROC: 0.9182, AP: 0.8579\n",
260 | "Epoch: 084, Loss: 1.0999, AUROC: 0.9182, AP: 0.8582\n",
261 | "Epoch: 085, Loss: 1.0923, AUROC: 0.9183, AP: 0.8586\n",
262 | "Epoch: 086, Loss: 1.1181, AUROC: 0.9183, AP: 0.8589\n",
263 | "Epoch: 087, Loss: 1.0989, AUROC: 0.9183, AP: 0.8592\n",
264 | "Epoch: 088, Loss: 1.0841, AUROC: 0.9183, AP: 0.8595\n",
265 | "Epoch: 089, Loss: 1.1156, AUROC: 0.9183, AP: 0.8597\n",
266 | "Epoch: 090, Loss: 1.0894, AUROC: 0.9183, AP: 0.8600\n",
267 | "Epoch: 091, Loss: 1.0946, AUROC: 0.9183, AP: 0.8603\n",
268 | "Epoch: 092, Loss: 1.0588, AUROC: 0.9183, AP: 0.8608\n",
269 | "Epoch: 093, Loss: 1.1279, AUROC: 0.9183, AP: 0.8611\n",
270 | "Epoch: 094, Loss: 1.0511, AUROC: 0.9183, AP: 0.8615\n",
271 | "Epoch: 095, Loss: 1.0543, AUROC: 0.9183, AP: 0.8620\n",
272 | "Epoch: 096, Loss: 1.0692, AUROC: 0.9183, AP: 0.8625\n",
273 | "Epoch: 097, Loss: 1.0644, AUROC: 0.9184, AP: 0.8630\n",
274 | "Epoch: 098, Loss: 1.0727, AUROC: 0.9185, AP: 0.8635\n",
275 | "Epoch: 099, Loss: 1.0578, AUROC: 0.9186, AP: 0.8640\n"
276 | ]
277 | }
278 | ],
279 | "source": [
280 | "# %%timeit\n",
281 | "sensei.train()"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "id": "e16a285e",
288 | "metadata": {},
289 | "outputs": [],
290 | "source": [
291 | "# sensei.save_model('model_example')"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": 8,
297 | "id": "e055e8f3",
298 | "metadata": {
299 | "scrolled": true
300 | },
301 | "outputs": [],
302 | "source": [
303 | "# get distance between wt and ko\n",
304 | "\n",
305 | "z_mu_wt, z_std_wt = sensei.get_latent_vars(data_wt)\n",
306 | "z_mu_ko, z_std_ko = sensei.get_latent_vars(data_ko)\n",
307 | "dis = gk.utils.get_distance(z_mu_ko, z_std_ko, z_mu_wt, z_std_wt, by=\"KL\")"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": 9,
313 | "id": "78e852b7",
314 | "metadata": {},
315 | "outputs": [
316 | {
317 | "data": {
318 | "text/html": [
319 | "\n",
320 | "\n",
333 | "
\n",
334 | " \n",
335 | " \n",
336 | " | \n",
337 | " dis | \n",
338 | " rank | \n",
339 | "
\n",
340 | " \n",
341 | " \n",
342 | " \n",
343 | " LAMB3 | \n",
344 | " 52.747102 | \n",
345 | " 1 | \n",
346 | "
\n",
347 | " \n",
348 | " COL3A1 | \n",
349 | " 0.138186 | \n",
350 | " 2 | \n",
351 | "
\n",
352 | " \n",
353 | " S100A9 | \n",
354 | " 0.085078 | \n",
355 | " 3 | \n",
356 | "
\n",
357 | " \n",
358 | " MDK | \n",
359 | " 0.084988 | \n",
360 | " 4 | \n",
361 | "
\n",
362 | " \n",
363 | " MFAP2 | \n",
364 | " 0.079943 | \n",
365 | " 5 | \n",
366 | "
\n",
367 | " \n",
368 | "
\n",
369 | "
"
370 | ],
371 | "text/plain": [
372 | " dis rank\n",
373 | "LAMB3 52.747102 1\n",
374 | "COL3A1 0.138186 2\n",
375 | "S100A9 0.085078 3\n",
376 | "MDK 0.084988 4\n",
377 | "MFAP2 0.079943 5"
378 | ]
379 | },
380 | "execution_count": 9,
381 | "metadata": {},
382 | "output_type": "execute_result"
383 | }
384 | ],
385 | "source": [
386 | "# raw ranked gene list\n",
387 | "\n",
388 | "res_raw = utils.get_generank(data_wt, dis, rank=True)\n",
389 | "res_raw.head()"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": 10,
395 | "id": "9ddf3b86",
396 | "metadata": {},
397 | "outputs": [
398 | {
399 | "name": "stderr",
400 | "output_type": "stream",
401 | "text": [
402 | "Permutating: 100%|██████████| 100/100 [00:02<00:00, 33.71it/s]\n"
403 | ]
404 | },
405 | {
406 | "data": {
407 | "text/html": [
408 | "\n",
409 | "\n",
422 | "
\n",
423 | " \n",
424 | " \n",
425 | " | \n",
426 | " dis | \n",
427 | " index | \n",
428 | " hit | \n",
429 | " rank | \n",
430 | "
\n",
431 | " \n",
432 | " \n",
433 | " \n",
434 | " LAMB3 | \n",
435 | " 52.747102 | \n",
436 | " 66 | \n",
437 | " 100 | \n",
438 | " 1 | \n",
439 | "
\n",
440 | " \n",
441 | " COL3A1 | \n",
442 | " 0.138186 | \n",
443 | " 9 | \n",
444 | " 100 | \n",
445 | " 2 | \n",
446 | "
\n",
447 | " \n",
448 | " S100A9 | \n",
449 | " 0.085078 | \n",
450 | " 194 | \n",
451 | " 100 | \n",
452 | " 3 | \n",
453 | "
\n",
454 | " \n",
455 | " TIMP1 | \n",
456 | " 0.033647 | \n",
457 | " 145 | \n",
458 | " 100 | \n",
459 | " 4 | \n",
460 | "
\n",
461 | " \n",
462 | "
\n",
463 | "
"
464 | ],
465 | "text/plain": [
466 | " dis index hit rank\n",
467 | "LAMB3 52.747102 66 100 1\n",
468 | "COL3A1 0.138186 9 100 2\n",
469 | "S100A9 0.085078 194 100 3\n",
470 | "TIMP1 0.033647 145 100 4"
471 | ]
472 | },
473 | "execution_count": 10,
474 | "metadata": {},
475 | "output_type": "execute_result"
476 | }
477 | ],
478 | "source": [
479 | "# if permutation test\n",
480 | "\n",
481 | "null = sensei.pmt(data_ko, n=100, by=\"KL\")\n",
482 | "res = utils.get_generank(data_wt, dis, null,)\n",
483 | "# save_significant_as = 'gene_list_example')\n",
484 | "res"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": null,
490 | "id": "fc571eb2",
491 | "metadata": {},
492 | "outputs": [],
493 | "source": []
494 | }
495 | ],
496 | "metadata": {
497 | "kernelspec": {
498 | "display_name": "Python 3 (ipykernel)",
499 | "language": "python",
500 | "name": "python3"
501 | },
502 | "language_info": {
503 | "codemirror_mode": {
504 | "name": "ipython",
505 | "version": 3
506 | },
507 | "file_extension": ".py",
508 | "mimetype": "text/x-python",
509 | "name": "python",
510 | "nbconvert_exporter": "python",
511 | "pygments_lexer": "ipython3",
512 | "version": "3.9.13"
513 | }
514 | },
515 | "nbformat": 4,
516 | "nbformat_minor": 5
517 | }
518 |
--------------------------------------------------------------------------------
/notebook/Example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "6323b4c4",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import os\n",
11 | "import numpy as np\n",
12 | "import pandas as pd\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import scanpy as sc\n",
15 | "\n",
16 | "sc.settings.verbosity = 0"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "id": "ffae2228",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "import GenKI as gk\n",
27 | "from GenKI.preprocesing import build_adata\n",
28 | "from GenKI.dataLoader import DataLoader\n",
29 | "from GenKI.train import VGAE_trainer\n",
30 | "from GenKI import utils\n",
31 | "\n",
32 | "%load_ext autoreload\n",
33 | "%autoreload 2"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 3,
39 | "id": "2b15fb16",
40 | "metadata": {},
41 | "outputs": [
42 | {
43 | "name": "stdout",
44 | "output_type": "stream",
45 | "text": [
46 | "load counts from data/microglial_seurat_WT.h5ad\n"
47 | ]
48 | },
49 | {
50 | "data": {
51 | "text/plain": [
52 | "AnnData object with n_obs × n_vars = 100 × 300\n",
53 | " obs: 'sce_source', 'treatment', 'trem2_genotype', 'snn_cluster', 'nCount_RNA', 'nFeature_RNA'\n",
54 | " var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'\n",
55 | " layers: 'norm'"
56 | ]
57 | },
58 | "execution_count": 3,
59 | "metadata": {},
60 | "output_type": "execute_result"
61 | }
62 | ],
63 | "source": [
64 | "# subset data as an example\n",
65 | "\n",
66 | "adata = build_adata(\"data/microglial_seurat_WT.h5ad\")\n",
67 | "adata = adata[:100, :300].copy()\n",
68 | "adata"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 4,
74 | "id": "1e37d540",
75 | "metadata": {
76 | "scrolled": true
77 | },
78 | "outputs": [
79 | {
80 | "name": "stdout",
81 | "output_type": "stream",
82 | "text": [
83 | "use all the cells (100) in adata\n",
84 | "build GRN\n",
85 | "ray init, using 8 CPUs\n",
86 | "execution time of making pcNet: 6.26 s\n",
87 | "GRN has been built and saved in \"GRNs\\pcNet_example.npz\"\n",
88 | "init completed\n",
89 | "\n"
90 | ]
91 | }
92 | ],
93 | "source": [
94 | "# load data\n",
95 | "\n",
96 | "data_wrapper = DataLoader(\n",
97 | " adata, # adata object\n",
98 | " target_gene = [\"TUBG1\"], # KO gene name\n",
99 | " target_cell = None, # obsname for cell type, if none use all\n",
100 | " obs_label = \"ident\", # colname for genes\n",
101 | " GRN_file_dir = \"GRNs\", # folder name for GRNs\n",
102 | " rebuild_GRN = True, # whether build GRN by pcNet\n",
103 | " pcNet_name = \"pcNet_example\", # GRN file name\n",
104 | " verbose = True, # whether verbose\n",
105 | " n_cpus = 8, # multiprocessing\n",
106 | " )\n",
107 | "\n",
108 | "data_wt = data_wrapper.load_data()\n",
109 | "data_ko = data_wrapper.load_kodata()"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 5,
115 | "id": "7862c768",
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "# init trainer\n",
120 | "\n",
121 | "hyperparams = {\"epochs\": 100, \n",
122 | " \"lr\": 7e-4, \n",
123 | " \"beta\": 1e-4, \n",
124 | " \"seed\": 8096}\n",
125 | "log_dir=None \n",
126 | "\n",
127 | "sensei = VGAE_trainer(data_wt, \n",
128 | " epochs=hyperparams[\"epochs\"], \n",
129 | " lr=hyperparams[\"lr\"], \n",
130 | " log_dir=log_dir, \n",
131 | " beta=hyperparams[\"beta\"],\n",
132 | " seed=hyperparams[\"seed\"],\n",
133 | " verbose=False,\n",
134 | " )"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 6,
140 | "id": "31bfb6a0",
141 | "metadata": {
142 | "scrolled": true
143 | },
144 | "outputs": [],
145 | "source": [
146 | "# %%timeit\n",
147 | "sensei.train()"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 7,
153 | "id": "e16a285e",
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "# sensei.save_model('model_example')"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 8,
163 | "id": "e055e8f3",
164 | "metadata": {
165 | "scrolled": true
166 | },
167 | "outputs": [
168 | {
169 | "name": "stdout",
170 | "output_type": "stream",
171 | "text": [
172 | "(300,)\n"
173 | ]
174 | }
175 | ],
176 | "source": [
177 | "# get distance between wt and ko\n",
178 | "\n",
179 | "z_mu_wt, z_std_wt = sensei.get_latent_vars(data_wt)\n",
180 | "z_mu_ko, z_std_ko = sensei.get_latent_vars(data_ko)\n",
181 | "dis = gk.utils.get_distance(z_mu_ko, z_std_ko, z_mu_wt, z_std_wt, by=\"KL\")\n",
182 | "print(dis.shape)"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 9,
188 | "id": "78e852b7",
189 | "metadata": {},
190 | "outputs": [
191 | {
192 | "data": {
193 | "text/html": [
194 | "\n",
195 | "\n",
208 | "
\n",
209 | " \n",
210 | " \n",
211 | " | \n",
212 | " dis | \n",
213 | " rank | \n",
214 | "
\n",
215 | " \n",
216 | " \n",
217 | " \n",
218 | " TUBG1 | \n",
219 | " 2.565490 | \n",
220 | " 1 | \n",
221 | "
\n",
222 | " \n",
223 | " TYROBP | \n",
224 | " 0.001753 | \n",
225 | " 2 | \n",
226 | "
\n",
227 | " \n",
228 | " LST1 | \n",
229 | " 0.001628 | \n",
230 | " 3 | \n",
231 | "
\n",
232 | " \n",
233 | " CYBA | \n",
234 | " 0.001517 | \n",
235 | " 4 | \n",
236 | "
\n",
237 | " \n",
238 | " LAT2 | \n",
239 | " 0.001438 | \n",
240 | " 5 | \n",
241 | "
\n",
242 | " \n",
243 | "
\n",
244 | "
"
245 | ],
246 | "text/plain": [
247 | " dis rank\n",
248 | "TUBG1 2.565490 1\n",
249 | "TYROBP 0.001753 2\n",
250 | "LST1 0.001628 3\n",
251 | "CYBA 0.001517 4\n",
252 | "LAT2 0.001438 5"
253 | ]
254 | },
255 | "execution_count": 9,
256 | "metadata": {},
257 | "output_type": "execute_result"
258 | }
259 | ],
260 | "source": [
261 | "# raw ranked gene list\n",
262 | "\n",
263 | "res_raw = utils.get_generank(data_wt, dis, rank=True)\n",
264 | "res_raw.head()"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 10,
270 | "id": "9ddf3b86",
271 | "metadata": {},
272 | "outputs": [
273 | {
274 | "name": "stderr",
275 | "output_type": "stream",
276 | "text": [
277 | "Permutating: 100%|██████████| 100/100 [00:03<00:00, 33.02it/s]\n"
278 | ]
279 | },
280 | {
281 | "data": {
282 | "text/html": [
283 | "\n",
284 | "\n",
297 | "
\n",
298 | " \n",
299 | " \n",
300 | " | \n",
301 | " dis | \n",
302 | " index | \n",
303 | " hit | \n",
304 | " rank | \n",
305 | "
\n",
306 | " \n",
307 | " \n",
308 | " \n",
309 | " TUBG1 | \n",
310 | " 2.56549 | \n",
311 | " 163 | \n",
312 | " 100 | \n",
313 | " 1 | \n",
314 | "
\n",
315 | " \n",
316 | "
\n",
317 | "
"
318 | ],
319 | "text/plain": [
320 | " dis index hit rank\n",
321 | "TUBG1 2.56549 163 100 1"
322 | ]
323 | },
324 | "execution_count": 10,
325 | "metadata": {},
326 | "output_type": "execute_result"
327 | }
328 | ],
329 | "source": [
330 | "# if permutation test\n",
331 | "\n",
332 | "null = sensei.pmt(data_ko, n=100, by=\"KL\")\n",
333 | "res = utils.get_generank(data_wt, dis, null,)\n",
334 | "# save_significant_as = 'gene_list_example')\n",
335 | "res"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "id": "6a349662",
342 | "metadata": {},
343 | "outputs": [],
344 | "source": []
345 | }
346 | ],
347 | "metadata": {
348 | "kernelspec": {
349 | "display_name": "Python 3 (ipykernel)",
350 | "language": "python",
351 | "name": "python3"
352 | },
353 | "language_info": {
354 | "codemirror_mode": {
355 | "name": "ipython",
356 | "version": 3
357 | },
358 | "file_extension": ".py",
359 | "mimetype": "text/x-python",
360 | "name": "python",
361 | "nbconvert_exporter": "python",
362 | "pygments_lexer": "ipython3",
363 | "version": "3.9.13"
364 | }
365 | },
366 | "nbformat": 4,
367 | "nbformat_minor": 5
368 | }
369 |
--------------------------------------------------------------------------------
/notebook/SERGIO.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import scanpy as sc
4 | # import matplotlib.pyplot as plt
5 | # import seaborn as sns
6 | # from scipy import stats
7 | # from PIL import Image
8 | # import io
9 | import os
10 | from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, average_precision_score
11 | from sklearn.covariance import GraphicalLasso, GraphicalLassoCV
12 | # from time import time
13 |
14 | # import matplotlib as mpl
15 | # mpl.rcParams['interactive'] == False
16 | # mpl.rcParams['lines.linewidth'] = 2
17 | # mpl.rcParams['lines.linestyle'] = '--'
18 | # mpl.rcParams['axes.titlesize'] = 24
19 | # mpl.rcParams['axes.labelsize'] = 16
20 | # mpl.rcParams['lines.markersize'] = 2
21 | # mpl.rcParams['xtick.labelsize'] = 16
22 | # mpl.rcParams['ytick.labelsize'] = 16
23 | # mpl.rcParams['figure.dpi'] = 80
24 | # mpl.rcParams['legend.fontsize'] = 12
25 | # markers = [".",",","o","v","^","<",">"]
26 | # linestyles = ['-', '--', '-.', ':', 'dashed', 'dashdot', 'dotted']
27 |
28 | import GenKI as gk
29 | from GenKI.preprocesing import build_adata
30 | from GenKI.train import VGAE_trainer
31 | from GenKI import utils
32 | from scTenifold import scTenifoldKnk
33 |
34 | N = 10  # number of top KO genes (ranked by out-degree in the ground-truth GRN) to benchmark
35 |
36 | def main(args):
37 | exp = pd.read_csv(f"SERGIO/De-noised_{args.data}/simulated_noNoise_{args.data_id}.csv", index_col=0)
38 | gt_grn_df = pd.read_csv(f"SERGIO/De-noised_{args.data}/gt_GRN.csv", header=None)
39 | inds = list(zip(gt_grn_df[0], gt_grn_df[1]))
40 | gt_grn = np.zeros((len(exp), len(exp))).astype(int)
41 | for ind in inds:
42 | gt_grn[ind] = 1
43 |
44 | genes = np.arange(len(exp)).astype(str)
45 | genes = np.char.add(["g"]*len(genes), genes)
46 | idx = gt_grn_df[0].value_counts().index.to_numpy()
47 | KO_genes = genes[idx] # gene list to be KO
48 | print(KO_genes[:N])
49 |
50 | ada = sc.AnnData(exp.values.T)
51 | ada.var_names = genes
52 |
53 | corr_mat = np.corrcoef(ada.X.T) # corr
54 | cov = GraphicalLassoCV(alphas=4).fit(ada.X) # GGM
55 | ggm = GraphicalLasso(alpha=cov.alpha_,
56 | max_iter=100,
57 | assume_centered=False)
58 | res = ggm.fit(ada.X)
59 | pre_ = np.around(res.precision_, decimals=6)
60 |
61 | # GenKI
62 | ada_WT = build_adata(ada, scale_data = True, uppercase=False)
63 |
64 | for target_gene in KO_genes[:N]:
65 | data_wrapper = gk.DataLoader(ada_WT,
66 | target_gene = [target_gene],
67 | target_cell = None,
68 | # obs_label = None,
69 | GRN_file_dir = 'GRNs',
70 |                 rebuild_GRN = False,  # load the prebuilt GRN named below
71 |                 pcNet_name = f'pcNet_{args.data}_{args.data_id}_man',
72 | cutoff = 85,
73 | verbose = False,
74 | )
75 | ko_idx = data_wrapper([target_gene])[0]
76 | labels = np.zeros(len(exp)).astype(int)
77 | labels_inds = gt_grn_df.loc[gt_grn_df[0]==ko_idx, 1].to_numpy().astype(int)
78 | labels[labels_inds] = 1
79 | data = data_wrapper.load_data()
80 | data_KO = data_wrapper.load_kodata()
81 |
82 | hyperparams = {"epochs": 75,
83 | "lr": 7e-4,
84 | "beta": 1e-4,
85 | "seed": 8096}
86 | log_dir=None
87 |
88 | sensei = VGAE_trainer(data,
89 | epochs=hyperparams["epochs"],
90 | lr=hyperparams["lr"],
91 | log_dir=log_dir,
92 | beta=hyperparams["beta"],
93 | seed=hyperparams["seed"],
94 | )
95 | sensei.train()
96 | z_mu, z_std = sensei.get_latent_vars(data)
97 | z_mu_KO, z_std_KO = sensei.get_latent_vars(data_KO)
98 | dis = gk.utils.get_distance(z_mu_KO, z_std_KO, z_mu, z_std, by = 'KL')
99 | res = utils.get_generank(data, dis, rank=False)
100 | scores = res["dis"].to_numpy()
101 | fpr, tpr, thres = roc_curve(labels, scores)
102 | roc_auc = roc_auc_score(labels, scores)
103 | print("AUROC:", roc_auc)
104 | precision, recall, _ = precision_recall_curve(labels, scores)
105 | ap = average_precision_score(labels, scores)
106 | print("AP:", ap)
107 |
108 | # junk classifier
109 | scores_junk = np.random.uniform(0,1,len(exp))
110 | fpr, tpr, thres = roc_curve(labels, scores_junk)
111 | roc_auc_junk = roc_auc_score(labels, scores_junk)
112 | print("AUROC_baseline:", roc_auc_junk)
113 | precision, recall, _ = precision_recall_curve(labels, scores_junk)
114 | ap_junk = average_precision_score(labels, scores_junk)
115 | print("AP_baseline:", ap_junk)
116 |
117 | # corr
118 | score_corr = corr_mat[ko_idx]
119 | fpr, tpr, thres = roc_curve(labels, score_corr)
120 | roc_auc_corr = roc_auc_score(labels, score_corr)
121 | print("AUROC_corr:", roc_auc_corr)
122 | precision, recall, _ = precision_recall_curve(labels, score_corr)
123 | ap_corr = average_precision_score(labels, score_corr)
124 | print("AP_corr:", ap_corr)
125 |
126 | # GGM
127 | scores_ggm = pre_[ko_idx]
128 | fpr, tpr, thres = roc_curve(labels, scores_ggm)
129 | roc_auc_ggm = roc_auc_score(labels, scores_ggm)
130 | print("AUROC_ggm:", roc_auc_ggm)
131 | precision, recall, _ = precision_recall_curve(labels, scores_ggm)
132 | ap_ggm = average_precision_score(labels, scores_ggm)
133 | print("AP_ggm:", ap_ggm)
134 |
135 | # Knk
136 | exp.index = genes
137 | # sct = scTenifoldKnk(data=exp,
138 | # ko_method="default",
139 | # ko_genes=[target_gene], # the gene you wants to knock out
140 | # qc_kws={"min_lib_size": 1, "min_percent": 0.001},
141 | # )
142 | # result = sct.build()
143 |
144 | knk = scTenifoldKnk(data=exp,
145 | qc_kws={"min_lib_size": 1, "min_percent": 0.001},
146 | )
147 | knk.run_step("qc")
148 | knk.run_step("nc", n_cpus=1)
149 | knk.run_step("td")
150 | knk.run_step("ko", ko_genes=[target_gene], ko_method="default")
151 | knk.run_step("ma")
152 | knk.run_step("dr", sorted_by="adjusted p-value")
153 | result = knk.d_regulation
154 |
155 | knk_score = dict(zip(result["Gene"], result["FC"]))
156 | res["temp"] = res.index
157 | res["Knk"] = res["temp"].map(knk_score)
158 | del res["temp"]
159 | scores_knk = res["Knk"].to_numpy()
160 | fpr_knk, tpr_knk, thres_knk = roc_curve(labels, scores_knk)
161 | roc_auc_knk = roc_auc_score(labels, scores_knk)
162 | print(roc_auc_knk)
163 | precision_knk, recall_knk, _ = precision_recall_curve(labels, scores_knk)
164 | ap_knk = average_precision_score(labels, scores_knk)
165 | print(ap_knk)
166 |
167 |         # append one CSV row per KO gene; write the header once if the file is new
168 |         os.makedirs("result", exist_ok=True)
169 |         out_path = os.path.join("result", f'{args.out}.txt')
170 |         if not os.path.exists(out_path):
171 |             with open(out_path, 'w') as f:
172 |                 f.writelines("file,file_id,KO_gene,ROC,ROC_Knk,ROC_junk,ROC_corr,ROC_ggm,"\
173 |                              "AP,AP_Knk,AP_junk,AP_corr,AP_ggm\n")
174 |         with open(out_path, 'a') as f:
175 |             f.writelines(f"{args.data},{args.data_id},{target_gene},"\
176 |                          f"{roc_auc:.4f},{roc_auc_knk:.4f},{roc_auc_junk:.4f},{roc_auc_corr:.4f},{roc_auc_ggm:.4f},"\
177 |                          f"{ap:.4f},{ap_knk:.4f},{ap_junk:.4f},{ap_corr:.4f},{ap_ggm:.4f}"\
178 |                          "\n")
180 |
181 |
182 | if __name__ == '__main__':
183 | import argparse
184 | parser = argparse.ArgumentParser()
185 | parser.add_argument('--data', type = str, default = '100G_9T_300cPerT_4_DS1')
186 | parser.add_argument('--data_id', type = int, default = 0)
187 | parser.add_argument('-O', '--out', default = 'SERGIO_trials')
188 |
189 | args = parser.parse_args()
190 | main(args)
191 |
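192 | # Usage sketch (values are the argparse defaults above; assumes the SERGIO csv
193 | # files and a prebuilt pcNet under ./GRNs exist):
194 | #   python notebook/SERGIO.py --data 100G_9T_300cPerT_4_DS1 --data_id 0 -O SERGIO_trials
195 | # Each KO gene appends one CSV row of AUROC/AP scores (GenKI vs. random,
196 | # correlation, GGM and scTenifoldKnk baselines) to ./result/<out>.txt.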
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | anndata==0.8.0
2 | matplotlib~=3.5.1
3 | numpy>=1.21.6
4 | pandas~=1.4.2
5 | ray>=1.11.0
6 | scanpy==1.9.1
7 | scipy~=1.8.0
8 | statsmodels~=0.13.2
9 | scikit_learn>=1.0.2
10 | torch==1.11.0
11 | tqdm~=4.64.0
12 |
--------------------------------------------------------------------------------
/run_eva.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # trap break INT
4 | # for noise in 1 2 3 4 5
5 | # do
6 | # for run in `seq 30 $max`
7 | # do
8 | # echo "noise: $noise, run: $run"
9 | # python -m GenKI.train --dir data --train_out train_X_$noise -X $noise
10 | # echo -e "completed\n"
11 | # done
12 | # done
13 |
14 |
15 | # for noise in 1 2 3 4 5
16 | # do
17 | # for run in `seq 30 $max`
18 | # do
19 | # echo "noise: $noise, run: $run"
20 | # python -m GenKI.train --dir data --train_out train_E_$noise -E $noise
21 | # echo -e "completed\n"
22 | # done
23 | # done
24 | # trap - INT
25 |
26 |
27 | for noise in 0.1 0.3 0.5 0.7 0.9
28 | do
29 | for run in `seq 10 $max` # with $max unset this expands to `seq 10`, i.e. runs 1..10
30 | do
31 | echo "dropout: $noise, run: $run"
32 | python -m GenKI.train --dir data --train_out train_XO_$noise -XO $noise
33 | echo -e "completed\n"
34 | done
35 | done
36 | trap - INT # reset the INT handler (its matching `trap break INT` is commented out above)
37 |
38 |
39 | # trap break INT
40 | # for run in `seq 30 $max`
41 | # do
42 | # echo "run: $run"
43 | # python -m GenKI.train --train_out train_log --do_test --generank_out genelist --r2_out r2_score
44 | # echo -e "completed\n"
45 | # done
46 | # trap - INT
47 |
48 |
49 | # trap break INT
50 | # for cutoff in 30 55 70 95
51 | # do
52 | # trap break INT
53 | # for run in `seq 30 $max`
54 | # do
55 | # echo "run: $run"
56 | # python -m GenKI.train --train_out train_log --dir data_$cutoff
57 | # echo -e "completed\n"
58 | # done
59 | # done
60 | # trap - INT
61 |
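62 | # Usage sketch: `bash run_eva.sh` from the repo root (assumes GenKI is installed
63 | # and ./data is populated); export `max` beforehand to widen the run range.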
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from setuptools import setup, find_packages
3 |
4 | HERE = pathlib.Path(__file__).parent
5 | README = (HERE / "README.md").read_text()
6 | DESCRIPTION = "GenKI"
7 | PACKAGES = find_packages(exclude=("tests*",))
8 | exec(open('GenKI/version.py').read())
9 |
10 | INSTALL_REQUIRES = [
11 | "anndata==0.8.0",
12 | "matplotlib~=3.5.1",
13 | "numpy>=1.21.6",
14 | "pandas~=1.4.2",
15 | "ray>=1.11.0",
16 | "scanpy==1.9.1",
17 | "scipy~=1.8.0",
18 | "statsmodels~=0.13.2",
19 | "scikit_learn>=1.0.2",
20 | # "torch==1.11.0",
21 | "tqdm~=4.64.0",
22 | ]
23 |
24 | setup(
25 | name="GenKI",
26 | version=__version__,
27 | description=DESCRIPTION,
28 | long_description=README,
29 | long_description_content_type="text/markdown",
30 | url="https://github.com/yjgeno/GenKI",
31 | author="Yongjian Yang, TAMU",
32 | author_email="yjyang027@tamu.edu",
33 | license="MIT",
34 | keywords=[
35 | "neural network",
36 | "graph neural network",
37 | "variational graph neural network",
38 | "computational-biology",
39 | "single-cell",
40 | "gene knock-out",
41 | "gene regulatroy network",
42 | ],
43 | classifiers=[
44 | "License :: OSI Approved :: MIT License",
45 | "Intended Audience :: Science/Research",
46 | "Topic :: Scientific/Engineering :: Bio-Informatics",
47 | "Programming Language :: Python :: 3",
48 | ],
49 | packages=PACKAGES,
50 | include_package_data=True, # MANIFEST
51 | install_requires=INSTALL_REQUIRES,
52 | )
--------------------------------------------------------------------------------