├── .gitignore
├── GenKI
│   ├── __init__.py
│   ├── dataLoader.py
│   ├── model.py
│   ├── pcNet.py
│   ├── preprocesing.py
│   ├── train.py
│   ├── tune.py
│   ├── utils.py
│   └── version.py
├── LICENSE
├── MANIFEST.in
├── README.md
├── data
│   └── README.md
├── environment.yml
├── logo.jpg
├── notebook
│   ├── Compatible_matlab.ipynb
│   ├── Example.ipynb
│   └── SERGIO.py
├── requirements.txt
├── run_eva.sh
└── setup.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # custom 132 | notebook/* 133 | .Rproj.user 134 | BACKUP 135 | data* 136 | log_dir 137 | !notebook/*.py 138 | !notebook/Example.ipynb 139 | !notebook/Compatible_matlab.ipynb 140 | -------------------------------------------------------------------------------- /GenKI/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /GenKI/dataLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import anndata 4 | import os 5 | import torch 6 | from torch_geometric.data import Data 7 | import matplotlib.pyplot as plt 8 | # from tqdm import tqdm 9 | from .pcNet import make_pcNet 10 | from .preprocesing import check_adata 11 | 12 | 13 | class scBase(): 14 | def __init__(self, 15 | adata: anndata.AnnData, 16 | target_gene: list[str], 17 | target_cell: str = None, 18 | obs_label: str = "ident", 19 | GRN_file_dir: str = "GRNs", 20 | rebuild_GRN: bool = False, 21 | pcNet_name: str = "pcNet", 22 | verbose: bool = False, 23 | **kwargs): 24 | 25 | check_adata(adata) 26 | self._gene_names = list(adata.var_names) 27 | if isinstance(target_gene[0], int): # list[int] 28 | target_gene = adata.var_names[target_gene].tolist() 29 | if any(g not in self._gene_names for g in target_gene): 30 | raise IndexError("every input target gene should be in the gene list of adata") 31 | else: 32 | self._target_gene = target_gene 33 | # self._idx_KO = self._gene_names.index(target_gene) 34 | if target_cell is None: 35 | self._counts = adata.X # use all cells, standardized counts 36 | if verbose: 37 | print(f"use all the cells ({self._counts.shape[0]}) in adata") 38 | elif not (obs_label in adata.obs.keys()): 39 | raise IndexError("require a valid cell label in adata.obs") 40 | else: 41 | self._counts = adata[adata.obs[obs_label] == target_cell, :].X 42 | self._counts = scipy.sparse.lil_matrix(self._counts) # sparse 43 | pcNet_path = os.path.join(GRN_file_dir, f"{pcNet_name}.npz") 44 | if rebuild_GRN: 45 | if verbose: 46 | print("build GRN") 47 | self._net = make_pcNet(adata.layers["norm"], nComp = 5, as_sparse = True, timeit = verbose, **kwargs) 48 | os.makedirs(GRN_file_dir, exist_ok = True) # create dir 49 | scipy.sparse.save_npz(pcNet_path, self._net) # save GRN 50 | # scipy.sparse.lil_matrix(pcNet_np) # np to sparse 51 | if verbose: 52 | print(f"GRN has been built and saved in \"{pcNet_path}\"") 53 | else: 54 | try: 55 | if verbose: 56 | print(f"loading GRN from \"{pcNet_path}\"") 57 | self._net = scipy.sparse.load_npz(pcNet_path) 58 | except OSError: 59 | print(f"GRN file \"{pcNet_path}\" not found, set rebuild_GRN = True to build one") 60 | if verbose: 61 | print("init completed\n") 62 | 63 | @property 64 | def counts(self): 65 | return self._counts 66 | 67 | @property 68 | def net(self): 69 | return self._net 70 | 71 | @property 72 | def target_gene(self): 73 | return
self._target_gene 74 | 75 | 76 | def __len__(self): 77 | return len(self._counts) 78 | 79 | 80 | def __call__(self, gene_name: list[str]): 81 | return [self._gene_names.index(g) for g in gene_name] # return gene index 82 | 83 | 84 | def __repr__(self): 85 | info = f"counts: {self._counts.shape}\n"\ 86 | f"net: {self._net.shape}\n"\ 87 | f"target_gene: {self._target_gene}" 88 | return info 89 | 90 | 91 | class DataLoader(scBase): 92 | def __init__(self, 93 | adata: anndata.AnnData, 94 | target_gene: list[str], 95 | target_cell: str = None, 96 | obs_label: str = "ident", 97 | GRN_file_dir: str = "GRNs", 98 | rebuild_GRN: bool = False, 99 | pcNet_name: str = "pcNet", 100 | cutoff: int = 85, 101 | verbose: bool = False, 102 | **kwargs): 103 | super().__init__(adata, 104 | target_gene, 105 | target_cell, 106 | obs_label, 107 | GRN_file_dir, 108 | rebuild_GRN, 109 | pcNet_name, 110 | verbose, 111 | **kwargs) 112 | self.verbose = verbose 113 | self.cutoff = cutoff 114 | self.edge_index = torch.tensor(self._build_edges(), dtype = torch.long) # dense np to tensor 115 | 116 | 117 | def _build_edges(self, net = None): 118 | ''' 119 | net: array-like, GRN built from data. 120 | ''' 121 | if net is None: 122 | net = self._net 123 | grn_to_use = abs(net.toarray()) if scipy.sparse.issparse(net) else abs(net) 124 | grn_to_use[grn_to_use < np.percentile(grn_to_use, self.cutoff)] = 0 125 | edge_index_np = np.asarray(np.where(grn_to_use > 0), dtype = int) 126 | return edge_index_np # dense np 127 | 128 | def load_data(self): 129 | return self._data_init() 130 | 131 | def load_kodata(self): 132 | return self._KO_data_init() 133 | 134 | 135 | def _data_init(self): 136 | counts = self._counts.toarray() if scipy.sparse.issparse(self._counts) else self._counts 137 | x = torch.tensor(counts.T, dtype = torch.float) # define x 138 | return Data(x = x, edge_index = self.edge_index, y = self._gene_names) 139 | 140 | 141 | def _KO_data_init(self): 142 | # KO edges 143 | mask = ~(torch.isin(self.edge_index[0], torch.tensor(self(self._target_gene))) + torch.isin(self.edge_index[1], torch.tensor(self(self._target_gene)))) 144 | edge_index_KO = self.edge_index[:, mask] # torch.long 145 | 146 | # KO counts 147 | counts_KO = self._counts.copy() 148 | counts_KO[:, self(self._target_gene)] = 0 149 | counts_KO = counts_KO.toarray() if scipy.sparse.issparse(counts_KO) else counts_KO 150 | x_KO = torch.tensor(counts_KO.T, dtype = torch.float) # define counts (KO) 151 | # if self.verbose: 152 | # print(f"set expression of {self._target_gene} to zeros and remove edges") 153 | return Data(x = x_KO, edge_index = edge_index_KO, y = self._gene_names) 154 | 155 | 156 | def _gen_ZINB(self, 157 | p_BIN = 0.95, 158 | n_NB = 10, 159 | p_NB = 0.5, 160 | noise = True, 161 | normalize = True, 162 | show = False, 163 | decays: list[float] = [1.]): 164 | ''' 165 | p_BIN: parameter of binomial distribution. 166 | n_NB, p_NB: parameters of negative binomial distribution. 167 | noise: add noise to ZINB samples. 168 | normalize: LogNorm samples, scale_factor = 10e4. 169 | show: show histograms of simulated samples and WT expression. 170 | decay: scales on n_NB simulating target gene expression. 
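Note: decays must start with 1. and contain one scale per target gene; otherwise a ValueError is raised.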
171 | ''' 172 | np.random.seed(11) 173 | if decays[0] != 1.: 174 | raise ValueError("first element in decays should always be 1") 175 | elif len(decays) != len(self._target_gene): 176 | raise ValueError("length of decays should match target gene") 177 | else: 178 | s_BIN = np.random.binomial(1, p_BIN, self._counts.shape[0]) # sample binomial 0 and 1 179 | s = np.array([s_BIN * np.random.negative_binomial(decay*n_NB, p_NB, self._counts.shape[0]) for decay in decays]).T # ZINB 180 | if noise: 181 | s[s > 0] = s[s > 0]+ np.random.choice([-1, 0, 1], size = len(s[s > 0]), p = [0.15, 0.7, 0.15]) # add int noise 182 | if normalize: 183 | scale_factor = 10e4 184 | s = np.log1p(s * scale_factor / s.sum()) # LogNorm 185 | if show: 186 | fig, ax = plt.subplots(ncols = 2, figsize = (12, 5)) 187 | counts = self._counts.toarray() if scipy.sparse.issparse(self._counts) else self._counts 188 | for i, (values, title) in enumerate(zip([counts[:, self(self._target_gene)], s], ['WT', 'Simulated'])): 189 | ax[i].hist(values, bins = counts.shape[0]//5, label = self._target_gene) 190 | ax[i].set_title(title) 191 | ax[i].legend() 192 | plt.show() 193 | if self.verbose: 194 | print(f"sample gene patterns ({self._counts.shape[0]}) from NB{n_NB, p_NB} with P(zero) = {round(1-p_BIN, 2)}") 195 | return s 196 | 197 | 198 | def scale_rows_and_cols(self, scale): # scale target gene edges in pcNet as a whole 199 | mask = np.ones(len(self._gene_names), dtype = bool) 200 | mask[self(self._target_gene)] = 0 201 | mask = np.invert(mask[:, None] @ mask[None, :]) 202 | self._net[mask] *= scale 203 | 204 | 205 | def OE_data_init(self, 206 | weight_scale: list[float] = [10.,], 207 | gene_patterns: list[float] = None, 208 | **kwargs): 209 | # OE edges 210 | weight_scale = np.asarray(weight_scale) 211 | net_OE = self._net.toarray().copy() 212 | net_OE[:, self(self._target_gene)] *= weight_scale[None, :] 213 | net_OE[self(self._target_gene), :] *= weight_scale[:, None] # note cross points get multiplied twice 214 | 215 | edge_index_OE = self._build_edges(net_OE) 216 | edge_index_OE = torch.tensor(edge_index_OE, dtype = torch.long) 217 | 218 | # OE counts 219 | counts_OE = self._counts.copy() 220 | orig_counts = counts_OE[:, self(self._target_gene)] 221 | if gene_patterns is not None: 222 | counts_OE[:, self(self._target_gene)] = gene_patterns 223 | else: 224 | decays = [w/weight_scale[0] for w in weight_scale] # n_NB = coeff * n_NB(10), coeff <= 1 225 | counts_OE[:, self(self._target_gene)] = self._gen_ZINB(n_NB = weight_scale[0], decays = decays, **kwargs) 226 | counts_OE = counts_OE.toarray() if scipy.sparse.issparse(counts_OE) else counts_OE 227 | x_OE = torch.tensor(counts_OE.T, dtype = torch.float) 228 | if self.verbose: 229 | print(f"replace expression of {self._target_gene} to simulated expressions and edges by scale {weight_scale}") 230 | return Data(x = x_OE, edge_index = edge_index_OE, y = self._gene_names) 231 | 232 | 233 | # def run_sys_KO(self, model, genelist): 234 | # ''' 235 | # model: a trained VGAE model. 
236 | # genelist: array-like, gene list to be systematic KO 237 | # ''' 238 | # self.verbose = False 239 | # g_orig = self._target_gene 240 | # data = self.data_init() 241 | # z_m0, z_S0 = get_latent_vars(data, model) 242 | # sys_res = [] 243 | # from tqdm import tqdm 244 | # for g in tqdm(genelist, desc = "systematic KO...", total = len(genelist)): 245 | # if g not in self._gene_names: 246 | # raise IndexError(f'"{g}" is not in the object') 247 | # else: 248 | # self._target_gene = g # reset KO gene 249 | # data_v = self.KO_data_init() 250 | # z_mv, z_Sv = get_latent_vars(data_v, model) 251 | # dis = get_distance(z_mv, z_Sv, z_m0, z_S0, by = "KL") 252 | # sys_res.append(dis) 253 | # self._target_gene = g_orig 254 | # # print(self._target_gene) 255 | # self.verbose = True 256 | # return np.array(sys_res) 257 | -------------------------------------------------------------------------------- /GenKI/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import negative_sampling 3 | 4 | 5 | EPS = 1e-15 6 | MAX_LOGSTD = 10 7 | 8 | 9 | def reset(value): 10 | if hasattr(value, 'reset_parameters'): 11 | value.reset_parameters() 12 | else: 13 | for child in value.children() if hasattr(value, 'children') else []: 14 | reset(child) 15 | 16 | 17 | class InnerProductDecoder(torch.nn.Module): 18 | """The inner product decoder.""" 19 | def forward(self, z, edge_index, sigmoid=True): 20 | value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1) 21 | return torch.sigmoid(value) if sigmoid else value 22 | 23 | def forward_all(self, z, sigmoid=True): 24 | adj = torch.matmul(z, z.t()) 25 | return torch.sigmoid(adj) if sigmoid else adj 26 | 27 | 28 | class GAE(torch.nn.Module): 29 | """The Graph Auto-Encoder model.""" 30 | def __init__(self, encoder, decoder=None): 31 | super().__init__() 32 | self.encoder = encoder 33 | self.decoder = InnerProductDecoder() if decoder is None else decoder 34 | GAE.reset_parameters(self) 35 | 36 | def reset_parameters(self): 37 | reset(self.encoder) 38 | reset(self.decoder) 39 | 40 | def encode(self, *args, **kwargs): 41 | return self.encoder(*args, **kwargs) 42 | 43 | def decode(self, *args, **kwargs): 44 | return self.decoder(*args, **kwargs) 45 | 46 | def recon_loss(self, z, pos_edge_index, neg_edge_index=None): 47 | """Given latent embeddings z, computes the binary cross 48 | entropy loss for positive edges pos_edge_index and negative 49 | sampled edges. 50 | 51 | Args: 52 | z (Tensor): The latent embeddings. 53 | pos_edge_index (LongTensor): The positive edges to train against. 54 | neg_edge_index (LongTensor, optional): The negative edges to train 55 | against. If not given, uses negative sampling to calculate 56 | negative edges. (default: :obj:`None`) 57 | """ 58 | 59 | pos_loss = -torch.log( 60 | self.decoder(z, pos_edge_index, sigmoid=True) + EPS).mean() 61 | 62 | if neg_edge_index is None: 63 | neg_edge_index = negative_sampling(pos_edge_index, z.size(0)) 64 | neg_loss = -torch.log(1 - 65 | self.decoder(z, neg_edge_index, sigmoid=True) + 66 | EPS).mean() 67 | 68 | return pos_loss + neg_loss 69 | 70 | def test(self, z, pos_edge_index, neg_edge_index): 71 | """Given latent embeddings z, positive edges 72 | pos_edge_index and negative edges neg_edge_index 73 | computes metrics. 74 | 75 | Args: 76 | z (Tensor): The latent embeddings. 77 | pos_edge_index (LongTensor): The positive edges to evaluate 78 | against. 79 | neg_edge_index (LongTensor): The negative edges to evaluate 80 | against. 
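Returns:
    (float, float): The AUROC and average precision (AP) scores.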
81 | """ 82 | from sklearn.metrics import average_precision_score, roc_auc_score #f1_score, confusion_matrix 83 | 84 | pos_y = z.new_ones(pos_edge_index.size(1)) 85 | neg_y = z.new_zeros(neg_edge_index.size(1)) 86 | y = torch.cat([pos_y, neg_y], dim=0) 87 | 88 | pos_pred = self.decoder(z, pos_edge_index, sigmoid=True) 89 | neg_pred = self.decoder(z, neg_edge_index, sigmoid=True) 90 | pred = torch.cat([pos_pred, neg_pred], dim=0) 91 | 92 | y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy() 93 | # pred[pred < 0.5] = 0 94 | # pred[pred >= 0.5] = 1 95 | # tn, fp, fn, tp = confusion_matrix(y, pred, labels=[0, 1]).ravel() 96 | 97 | return roc_auc_score(y, pred), average_precision_score(y, pred) # f1_score(y, pred), [tn, fp, fn, tp] 98 | 99 | 100 | class VGAE(GAE): 101 | """The Variational Graph Auto-Encoder model. 102 | 103 | Args: 104 | encoder (Module): The encoder module to compute :math:`\mu` and 105 | :math:`\log\sigma^2`. 106 | decoder (Module, optional): The decoder module. If set to :obj:`None`, 107 | will default to the 108 | :class:`torch_geometric.nn.models.InnerProductDecoder`. 109 | (default: :obj:`None`) 110 | """ 111 | def __init__(self, encoder, decoder=None): 112 | super().__init__(encoder, decoder) 113 | 114 | def reparametrize(self, mu, logstd): 115 | if self.training: 116 | return mu + torch.randn_like(logstd) * torch.exp(logstd) 117 | else: 118 | return mu 119 | 120 | def encode(self, *args, **kwargs): 121 | self.__mu__, self.__logstd__ = self.encoder(*args, **kwargs) 122 | self.__logstd__ = self.__logstd__.clamp(max=MAX_LOGSTD) 123 | z = self.reparametrize(self.__mu__, self.__logstd__) 124 | return z 125 | 126 | def kl_loss(self, mu=None, logstd=None): 127 | """Computes the KL loss, either for the passed arguments :obj:`mu` 128 | and :obj:`logstd`, or based on latent variables from last encoding. 129 | 130 | Args: 131 | mu (Tensor, optional): The latent space for :math:`\mu`. If set to 132 | :obj:`None`, uses the last computation of :math:`mu`. 133 | (default: :obj:`None`) 134 | logstd (Tensor, optional): The latent space for 135 | :math:`\log\sigma`. 
If set to :obj:`None`, uses the last 136 | computation of :math:`\log\sigma^2`.(default: :obj:`None`) 137 | """ 138 | mu = self.__mu__ if mu is None else mu 139 | logstd = self.__logstd__ if logstd is None else logstd.clamp( 140 | max=MAX_LOGSTD) 141 | return -0.5 * torch.mean( 142 | torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1)) 143 | -------------------------------------------------------------------------------- /GenKI/pcNet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.linalg import svd 3 | from scipy import sparse 4 | import os 5 | import time 6 | 7 | try: 8 | import ray 9 | except ImportError: 10 | class dummy_ray(): 11 | def __init__(self): 12 | self.remote = lambda x:0 13 | print("ray not initialized") 14 | def is_initialized(self): 15 | return False 16 | def init(self, num_cpus): 17 | self.num_cpu = num_cpus 18 | def put(self, x): 19 | self.x = x 20 | ray = dummy_ray() 21 | 22 | 23 | def pcCoefficients(X, K, nComp): 24 | y = X[:, K] 25 | Xi = np.delete(X, K, 1) 26 | U, s, VT = svd(Xi, full_matrices=False) 27 | #print ('U:', U.shape, 's:', s.shape, 'VT:', VT.shape) 28 | V = VT[:nComp, :].T 29 | #print('V:', V.shape) 30 | score = Xi@V 31 | t = np.sqrt(np.sum(score**2, axis=0)) 32 | score_lsq = ((score.T / (t**2)[:, None])).T 33 | beta = np.sum(y[:, None]*score_lsq, axis=0) 34 | beta = V@beta 35 | return list(beta) 36 | 37 | 38 | def pcNet(X, # X: cell * gene 39 | nComp: int = 3, 40 | scale: bool = True, 41 | symmetric: bool = True, 42 | q: float = 0., # q: 0-100 43 | as_sparse: bool = True, 44 | random_state: int = 0): 45 | X = X.toarray() if sparse.issparse(X) else X 46 | if nComp < 2 or nComp >= X.shape[1]: 47 | raise ValueError("nComp should be greater or equal than 2 and lower than the total number of genes") 48 | else: 49 | np.random.seed(random_state) 50 | n = X.shape[1] # genes 51 | B = np.array([pcCoefficients(X, k, nComp) for k in range(n)]) 52 | A = np.ones((n, n), dtype=float) 53 | np.fill_diagonal(A, 0) 54 | for i in range(n): 55 | A[i, A[i, :]==1] = B[i, :] 56 | if scale: 57 | absA = abs(A) 58 | A = A / np.max(absA) 59 | if q > 0: 60 | A[absA < np.percentile(absA, q)] = 0 61 | if symmetric: # place in the end 62 | A = (A + A.T)/2 63 | #diag(A) <- 0 64 | if as_sparse: 65 | A = sparse.csc_matrix(A) 66 | return A 67 | 68 | 69 | @ray.remote 70 | def pc_net_parallel(X, # X: cell * gene 71 | nComp: int = 3, 72 | scale: bool = True, 73 | symmetric: bool = True, 74 | q: float = 0., 75 | as_sparse: bool = True, 76 | random_state: int = 0): 77 | return pcNet(X, nComp = nComp, scale = scale, symmetric = symmetric, q = q, 78 | as_sparse = as_sparse, random_state = random_state) 79 | 80 | 81 | def pc_net_single(X, # X: cell * gene 82 | nComp: int = 3, 83 | scale: bool = True, 84 | symmetric: bool = True, 85 | q: float = 0., 86 | as_sparse: bool = True, 87 | random_state: int = 0): 88 | return pcNet(X, nComp = nComp, scale = scale, symmetric = symmetric, q = q, 89 | as_sparse = as_sparse, random_state = random_state) 90 | 91 | 92 | def make_pcNet(X, 93 | nComp: int = 3, 94 | scale: bool = True, 95 | symmetric: bool = True, 96 | q: float = 0., 97 | as_sparse: bool = True, 98 | random_state: int = 0, 99 | n_cpus: int = 1, # -1: use all CPUs 100 | timeit: bool = True): 101 | start_time = time.time() 102 | if n_cpus != 1: 103 | if ray.is_initialized(): 104 | ray.shutdown() 105 | if n_cpus == -1: 106 | n_cpus = os.cpu_count() 107 | ray.init(num_cpus = n_cpus) 108 | print(f"ray init, using {n_cpus} 
CPUs") 109 | 110 | X_ray = ray.put(X) # put X to distributed object store and return object ref (ID) 111 | # print(X_ray) 112 | net = pc_net_parallel.remote(X_ray, nComp = nComp, scale = scale, symmetric = symmetric, q = q, 113 | as_sparse = as_sparse, random_state = random_state) 114 | net = ray.get(net) 115 | else: 116 | net = pc_net_single(X, nComp = nComp, scale = scale, symmetric = symmetric, q = q, 117 | as_sparse = as_sparse, random_state = random_state) 118 | if ray.is_initialized(): 119 | ray.shutdown() 120 | if timeit: 121 | duration = time.time() - start_time 122 | print("execution time of making pcNet: {:.2f} s".format(duration)) 123 | return net 124 | 125 | 126 | def main(): 127 | counts = np.random.randint(0, 10, (5, 100)) 128 | net = make_pcNet(counts, as_sparse = True, timeit = True, n_cpus = 1) 129 | print(f"input counts shape: {counts.shape},\nmake pcNet completed, shape: {net.shape}") 130 | return 100 131 | 132 | 133 | if __name__ == "__main__": 134 | import sys 135 | sys.exit(main()) 136 | -------------------------------------------------------------------------------- /GenKI/preprocesing.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from anndata import ( 3 | AnnData, 4 | read_h5ad, 5 | read_csv, 6 | read_excel, 7 | read_hdf, 8 | read_loom, 9 | read_mtx, 10 | read_text, 11 | ) 12 | import scanpy as sc 13 | import pandas as pd 14 | from typing import Union 15 | from scipy import sparse 16 | import os 17 | import pickle 18 | from torch_geometric.data import Data 19 | from torch_geometric.transforms import RandomLinkSplit 20 | from torch_geometric import seed_everything 21 | sc.settings.verbosity = 0 22 | 23 | CURRENT_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__))) # GenKI 24 | # https://scanpy.readthedocs.io/en/stable/api.html#reading 25 | file_attrs = ["h5ad", "csv", "xlsx", "h5", "loom", "mtx", "txt", "mat"] 26 | read_fun_keys = ["h5ad", "csv", "excel", "hdf", "loom", "mtx", "text"] # fun from scanpy read 27 | 28 | 29 | def _read_counts(counts_path: str, 30 | transpose: bool = False, 31 | **kwargs): 32 | 33 | """Read counts file to build an AnnData. 34 | Args: 35 | counts_path (str): Path to counts file. 36 | transpose (bool, optional): Whether transpose the counts. 
37 | Returns: 38 | AnnData 39 | """ 40 | 41 | file_attr = counts_path.split(".")[-1] 42 | if Path(counts_path).is_file() and file_attr in file_attrs: 43 | print(f"load counts from {counts_path}") 44 | if file_attr == "mat": 45 | import numpy as np 46 | import h5py 47 | f = h5py.File(counts_path,'r') 48 | # print(f.keys()) 49 | counts = np.array(f.get(list(f.keys())[0]), dtype="float64") 50 | if transpose: 51 | counts = counts.T 52 | adata = sc.AnnData(counts) 53 | else: 54 | read_fun_key = read_fun_keys[file_attrs.index(file_attr)] 55 | read_fun = getattr(sc, f"read_{read_fun_key}") # define sc.read_{} function 56 | if transpose: 57 | adata = read_fun(counts_path, **kwargs).transpose() # transpose counts file 58 | else: 59 | adata = read_fun(counts_path, **kwargs) 60 | else: 61 | raise ValueError("incorrect file path given to counts") 62 | return adata 63 | 64 | 65 | def build_adata( 66 | adata: Union[str, AnnData], 67 | meta_gene_path: Union[None, str] = None, 68 | meta_cell_path: Union[None, str] = None, 69 | meta_cell_cols: Union[None, str] = None, 70 | sep = "\t", 71 | header = None, 72 | log_normalize: bool = False, 73 | scale_data: bool = True, 74 | as_sparse: bool = True, 75 | uppercase: bool = True, 76 | **kwargs, 77 | ): 78 | 79 | """Load counts, metadata of genes and cells to build an AnnData input for GenKI. 80 | Args: 81 | adata (Union[str, AnnData]): Path to a counts file or an existing AnnData. 82 | meta_gene_path (Union[None, str]): Path to metadata of variables (genes). 83 | meta_cell_path (Union[None, str]): Path to metadata of observations (cells). 84 | sep (str, optional): The delimiter for metadata. Defaults to "\t". 85 | log_normalize (bool, optional): Whether to log-normalize the counts. Defaults to False. 86 | scale_data (bool, optional): Whether to standardize the counts. Defaults to True. 87 | as_sparse (bool, optional): Whether to make the counts sparse. Defaults to True. 88 | uppercase (bool, optional): Whether to convert gene names to uppercase. Defaults to True.
89 | **kwargs: key words in _read_counts 90 | Returns: 91 | AnnData 92 | """ 93 | if not isinstance(adata, AnnData): 94 | adata = _read_counts(adata, **kwargs) 95 | 96 | if meta_gene_path is not None and Path(meta_gene_path).is_file(): 97 | try: 98 | # print("add metadata for genes") 99 | df_gene = pd.read_csv(meta_gene_path, header=header, sep=sep) 100 | df_gene.index = df_gene.index.astype("str") 101 | adata.var_names = df_gene[0] 102 | except Exception: 103 | raise ValueError("incorrect file path given to meta_gene") 104 | if meta_cell_path is not None and Path(meta_cell_path).is_file(): 105 | try: 106 | # print("add metadata for cells") 107 | df_cell = pd.read_csv(meta_cell_path, header=header, sep=sep) 108 | df_cell.index = df_cell.index.astype("str") 109 | adata.obs = df_cell 110 | if meta_cell_cols is not None: 111 | adata.obs.columns = meta_cell_cols 112 | except Exception: 113 | raise ValueError("incorrect file path given to meta_cell") 114 | 115 | if log_normalize: 116 | # print("normalize counts") 117 | adata.layers["raw"] = adata.X 118 | sc.pp.normalize_total(adata, target_sum=1e4) 119 | sc.pp.log1p(adata) 120 | 121 | 122 | if scale_data: 123 | adata.layers["norm"] = adata.X 124 | # print("standardize counts") 125 | from sklearn import preprocessing 126 | counts = adata.X.toarray() if sparse.issparse(adata.X) else adata.X 127 | scaler = preprocessing.StandardScaler().fit(counts) 128 | adata.X = scaler.transform(counts) 129 | 130 | if as_sparse: 131 | # print("make counts sparse") 132 | adata.X = ( 133 | sparse.csr_matrix(adata.X) if not sparse.issparse(adata.X) else adata.X 134 | ) 135 | 136 | if uppercase: 137 | adata.var_names = adata.var_names.str.upper() # all species use upper case genes 138 | 139 | return adata 140 | 141 | 142 | def check_adata(adata): 143 | if "norm" not in adata.layers: 144 | raise ValueError("normalized counts should be saved at adata.layers[\"norm\"]") 145 | if not (adata.X.toarray() < 0).any(): 146 | raise ValueError("adata.X should be standardized counts") 147 | 148 | 149 | def save_gdata(data, dir: str = "data", name: str = "data"): 150 | path = os.path.join(CURRENT_DIR, dir) 151 | os.makedirs(path, exist_ok = True) 152 | print(f"save {name} to dir {path}") 153 | with open(os.path.join(path, f'{name}.p'), 'wb') as f: 154 | pickle.dump(data.to_dict(), f, protocol=pickle.HIGHEST_PROTOCOL) 155 | 156 | 157 | def load_gdata(dir: str = "data", name: str = "data"): 158 | with open(os.path.join(dir, f"{name}.p"), 'rb') as f: 159 | data = pickle.load(f) 160 | data = Data.from_dict(data) 161 | print(f"load {name} from dir {dir}") 162 | return data 163 | 164 | 165 | def split_data(dir: str = "data", data = None, load: bool = False, save: bool = False): 166 | path = os.path.join(CURRENT_DIR, dir) 167 | if load: 168 | train_data, val_data, test_data = load_gdata(path, "train_data"), load_gdata(path, "val_data"), load_gdata(path, "test_data") 169 | return train_data, val_data, test_data 170 | elif data is None: 171 | data = load_gdata(path, "data") 172 | seed_everything(42) 173 | transform = RandomLinkSplit(is_undirected = True, 174 | split_labels = True, 175 | num_val = 0.05, 176 | num_test = 0.2, 177 | ) 178 | train_data, val_data, test_data = transform(data) 179 | # print("split data into train/valid/test") 180 | if save: 181 | save_gdata(train_data, dir, "train_data") 182 | save_gdata(val_data, dir, "val_data") 183 | save_gdata(test_data, dir, "test_data") 184 | del train_data.pos_edge_label, train_data.neg_edge_label 185 | return train_data, val_data, 
test_data 186 | 187 | 188 | if __name__ == "__main__": 189 | from os import path 190 | 191 | data_dir = path.join( 192 | path.dirname(path.dirname(path.abspath(__file__))), "filtered_gene_bc_matrices" 193 | ) 194 | counts_path = str(path.join(data_dir, "matrix.mtx")) 195 | gene_path = str(path.join(data_dir, "genes.tsv")) 196 | cell_path = str(path.join(data_dir, "barcodes.tsv")) 197 | adata = build_adata(counts_path, gene_path, cell_path) 198 | print(adata) 199 | -------------------------------------------------------------------------------- /GenKI/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch_geometric.data import Data 4 | from torch_geometric.nn import GCNConv 5 | from .model import VGAE 6 | import torch.utils.tensorboard as tb 7 | import os 8 | import matplotlib.pyplot as plt 9 | from tqdm import tqdm 10 | 11 | from .utils import get_distance 12 | from .preprocesing import split_data 13 | 14 | 15 | class VariationalGCNEncoder(torch.nn.Module): # encoder 16 | def __init__(self, in_channels, out_channels, hidden = 2): 17 | super(VariationalGCNEncoder, self).__init__() 18 | self.conv1 = GCNConv(in_channels, hidden * out_channels) 19 | self.conv_mu = GCNConv(hidden * out_channels, out_channels) 20 | self.conv_logstd = GCNConv(hidden * out_channels, out_channels) 21 | 22 | def forward(self, x, edge_index): 23 | x = self.conv1(x, edge_index).relu() 24 | return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index) 25 | 26 | 27 | class VGAE_trainer(): 28 | def __init__(self, 29 | data, 30 | out_channels: int = 2, 31 | epochs: int = 100, 32 | lr: float = 7e-4, 33 | weight_decay = 9e-4, 34 | beta: str = 1e-4, 35 | log_dir: str = None, 36 | verbose: bool = True, 37 | seed: int = None, 38 | **kwargs): 39 | self.num_features = data.num_features 40 | self.out_channels = out_channels 41 | self.data = data 42 | self.epochs = epochs 43 | self.lr = lr 44 | self.verbose = verbose 45 | self.logging = True if log_dir is not None else False 46 | if self.logging: 47 | self.train_logger = tb.SummaryWriter(os.path.join(log_dir, "train")) 48 | self.test_logger = tb.SummaryWriter(os.path.join(log_dir, "test")) 49 | if beta is not None: 50 | self.beta = beta 51 | else: 52 | self.beta = (1 / self.train_data.num_nodes) 53 | self.weight_decay = weight_decay 54 | self.seed = seed 55 | 56 | 57 | def __repr__(self) -> str: 58 | return f"Hyperparameters\n"\ 59 | f"epochs: {self.epochs}, lr: {self.lr}, beta: {self.beta:.4f}\n" 60 | 61 | 62 | # split data 63 | def _transform_data(self, 64 | x_noise: float = None, 65 | x_dropout: float = None, 66 | edge_noise: float = None, 67 | **kwargs 68 | ): 69 | """ 70 | Args: 71 | x_noise: Standard deviation of white noise added to the training data. Defaults to None. 72 | edge_noise: Remove or add edges to the training data. 
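x_dropout: Fraction of entries in data.x randomly set to zero before splitting. Defaults to None.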
73 | """ 74 | # from copy import deepcopy 75 | # data_ = deepcopy(self.data) 76 | if x_dropout is not None: 77 | mask = torch.FloatTensor(self.data.x.shape).uniform_() > x_dropout # % zeros 78 | self.data.x = self.data.x * mask 79 | print(f"force zeros to data x, dropout: {x_dropout}") 80 | 81 | self.train_data, self.val_data, self.test_data = split_data(data = self.data, **kwargs) # fixed split 82 | if x_noise is not None: # white noise on X 83 | gamma = x_noise * torch.randn(self.train_data.x.shape) 84 | self.train_data.x = 2**gamma * self.train_data.x 85 | print(f"add white noise to training data x, level: {x_noise} SD") 86 | 87 | # if x_dropout is not None: 88 | # mask = torch.FloatTensor(self.train_data.x.shape).uniform_() > x_dropout # % zeros 89 | # self.train_data.x = self.train_data.x * mask 90 | # print(f"force zeros to training data x, dropout: {x_dropout}") 91 | 92 | if edge_noise is not None: 93 | n_pos_edge = self.train_data.pos_edge_label_index.shape[1] 94 | n = int(abs(edge_noise) * n_pos_edge) 95 | print("Before:", self.train_data) 96 | print("\n") 97 | if edge_noise > 0: # fold edges 98 | from torch_geometric.utils import negative_sampling 99 | fake_pos_edge = negative_sampling(self.data.edge_index, # then fake edges impossible appeared in test set: for data leakage 100 | num_neg_samples= n, 101 | num_nodes = len(self.train_data.x)) 102 | new_pos_edge = torch.unique(torch.cat((self.train_data.pos_edge_label_index, fake_pos_edge), 1), dim = 1) 103 | new_edge = torch.cat((new_pos_edge, new_pos_edge[torch.LongTensor([1, 0])]), 1) # swap 104 | print(f"add noise to training data edge, level: {edge_noise}: add {n} edges") 105 | else: # ratio edges 106 | if n > n_pos_edge: 107 | raise ValueError("cannot retain edges more than the total") 108 | weights = torch.tensor([1/n_pos_edge] * n_pos_edge, dtype = torch.float) # uniform weights 109 | index = weights.multinomial(num_samples = n, replacement = False) 110 | new_pos_edge = self.train_data.pos_edge_label_index[:, index] 111 | new_edge = torch.cat((new_pos_edge, new_pos_edge[torch.LongTensor([1, 0])]), 1) 112 | # new_neg_edge = self.train_data.neg_edge_label_index[:, index] 113 | # self.train_data.neg_edge_label_index = new_neg_edge 114 | print(f"retain a portion of training data edge, level: {abs(edge_noise)}: retain {n} edges") 115 | self.train_data.edge_index, self.train_data.pos_edge_label_index = new_edge, new_pos_edge 116 | print("After:", self.train_data) 117 | 118 | 119 | # refer to https://github.com/pyg-team/pytorch_geometric/blob/master/examples/autoencoder.py 120 | def train(self, **kwargs): 121 | global_step = 0 122 | if self.seed is not None: 123 | torch.manual_seed(self.seed) # 8096 124 | self._transform_data(**kwargs) # get split data w/o noise 125 | self.model = VGAE(VariationalGCNEncoder(self.num_features, self.out_channels)) 126 | 127 | # to cuda 128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 129 | self.model = self.model.to(device) 130 | self.train_data, self.val_data, self.test_data = self.train_data.to(device), self.val_data.to(device), self.test_data.to(device) 131 | 132 | optimizer = torch.optim.Adam(self.model.parameters(), lr = self.lr, weight_decay = self.weight_decay) 133 | # optimizer = torch.optim.SGD(self.model.parameters(), lr = self.lr, momentum = 0.9, weight_decay = 5e-4) 134 | 135 | for epoch in range(self.epochs): 136 | self.model.train() 137 | optimizer.zero_grad() 138 | z = self.model.encode(self.train_data.x, self.train_data.edge_index) 139 | recon_loss = 
self.model.recon_loss(z, self.train_data.pos_edge_label_index, self.train_data.neg_edge_label_index) 140 | kl_loss = self.beta * self.model.kl_loss() # beta-VAE 141 | loss = recon_loss + kl_loss 142 | if self.logging: 143 | self.train_logger.add_scalar("loss", loss.item(), global_step) 144 | self.train_logger.add_scalar("recon_loss", recon_loss.item(), global_step) 145 | self.train_logger.add_scalar("kl_loss", kl_loss.item(), global_step) 146 | loss.backward() 147 | optimizer.step() 148 | global_step += 1 149 | 150 | with torch.no_grad(): 151 | self.model.eval() 152 | z = self.model.encode(self.test_data.x, self.test_data.edge_index) 153 | auc, ap = self.model.test(z, self.test_data.pos_edge_label_index, self.test_data.neg_edge_label_index) 154 | if self.logging: 155 | self.test_logger.add_scalar("AUROC", auc, global_step) 156 | self.test_logger.add_scalar("AP", ap, global_step) 157 | if self.verbose: 158 | print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, AUROC: {auc:.4f}, AP: {ap:.4f}") 159 | self.final_metrics = global_step, loss.item(), auc, ap 160 | 161 | 162 | def save_model(self, name: str): 163 | """ 164 | name: str, name of .th file that will be saved in model. 165 | """ 166 | if isinstance(self.model, VGAE): 167 | os.makedirs("model", exist_ok=True) 168 | path = os.path.join("model", f"{name}.th") 169 | print(f"save model parameters to {path}") 170 | return torch.save(self.model.state_dict(), f"model/{name}.th") 171 | else: 172 | raise ValueError(f"model type {type(self.model)} not supported") 173 | 174 | 175 | def load_model(self, name: str): 176 | """ 177 | name: str, name of .th file saved in model. 178 | """ 179 | r = VGAE(VariationalGCNEncoder(self.num_features, self.out_channels)) 180 | path = os.path.join("model", f"{name}.th") 181 | print(f"load model parameters from {path}") 182 | r.load_state_dict( 183 | torch.load(os.path.join("model", f"{name}.th"), map_location=torch.device("cpu")) 184 | ) 185 | self.model = r 186 | 187 | 188 | # after training 189 | def get_latent_vars(self, data, plot_latent_mu = False): 190 | """ 191 | data: torch_geometric.data.data.Data. 192 | plot_latent_z: bool, whether to plot nodes with random sampled latent features. 193 | """ 194 | self.model.eval() 195 | _ = self.model.encode(data.x, data.edge_index) 196 | z_m = self.model.__mu__.detach().numpy() 197 | z_S = (self.model.__logstd__.exp() ** 2).detach().numpy() # variance 198 | if plot_latent_mu: 199 | # z_np = z.detach().numpy() 200 | fig, ax = plt.subplots(figsize=(6, 6), dpi=80) 201 | if z_m.shape[1] == 2: 202 | ax.scatter(z_m[:, 0], z_m[:, 1], s=4) 203 | elif z_m.shape[1] == 1: 204 | ax.hist(z_m, bins=60) 205 | elif z_m.shape[1] == 3: 206 | from mpl_toolkits.mplot3d import Axes3D 207 | ax = Axes3D(fig) 208 | ax.scatter(z_m[:, 0], z_m[:, 1], z_m[:, 2], s=4) 209 | plt.show() 210 | if z_m.shape[1] == 1: 211 | z_m = z_m.flatten() 212 | z_S = z_S.flatten() 213 | return z_m, z_S 214 | 215 | 216 | def pmt(self, data_v, n = 100, by = "KL"): 217 | """ 218 | data_v: torch_geometric.data.data.Data, virtual data. 219 | n: int, # of permutation. 220 | by: str, method for distance calculation of distributions. 
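One of "KL", "t" or "EMD". Defaults to "KL".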
221 | """ 222 | # permutate cells order and compute KL div 223 | n_cell = self.data.x.shape[1] 224 | dis_p = [] 225 | np.random.seed(0) 226 | for _ in tqdm(range(n), desc="Permutating", total = n): 227 | # p: pmt, v: virtual, m: mean, S: sigma 228 | idx_pmt = np.random.choice( 229 | np.arange(n_cell), size=n_cell 230 | ) # bootstrap cell labels 231 | data_WT_p = Data( 232 | # x=torch.tensor(data.x[:, idx_pmt], dtype=torch.float), 233 | x=self.data.x[:, idx_pmt].clone().detach().requires_grad_(False), 234 | edge_index=self.data.edge_index, 235 | ) 236 | data_KO_p = Data( 237 | # x=torch.tensor(data_v.x[:, idx_pmt], dtype=torch.float), 238 | x=data_v.x[:, idx_pmt].clone().detach().requires_grad_(False), 239 | edge_index=data_v.edge_index, 240 | ) # construct virtual data (KO) based on pmt WT 241 | z_mp, z_Sp = self.get_latent_vars(data_WT_p) 242 | z_mvp, z_Svp = self.get_latent_vars(data_KO_p) 243 | if by == "KL": 244 | dis_p.append(get_distance(z_mvp, z_Svp, z_mp, z_Sp)) # order KL 245 | # if by == 'reverse_KL': 246 | # dis_p.append(get_distance(z_mp, z_Sp, z_mvp, z_Svp)) # reverse order KL 247 | if by == "t": 248 | dis_p.append(get_distance(z_mvp, z_Svp, z_mp, z_Sp, by="t")) 249 | if by == "EMD": 250 | dis_p.append(get_distance(z_mvp, z_Svp, z_mp, z_Sp, by="EMD")) 251 | return np.array(dis_p) 252 | 253 | 254 | def eva(args): 255 | from .preprocesing import load_gdata 256 | CURRENT_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__))) 257 | load_path = os.path.join(CURRENT_DIR, args.dir) 258 | data = load_gdata(load_path, "data") 259 | 260 | sensei = VGAE_trainer(data, 261 | epochs = args.epochs, 262 | lr = args.lr, 263 | log_dir = args.logdir, 264 | beta = args.beta, 265 | verbose = args.verbose, 266 | seed = args.seed, 267 | ) 268 | sensei.train(edge_noise = args.e_noise, 269 | x_noise = args.x_noise, 270 | x_dropout = args.x_dropout, 271 | dir = args.dir, load = False, # load split data 272 | ) 273 | epoch, loss, auc, ap = sensei.final_metrics 274 | 275 | save_path = os.path.join(CURRENT_DIR, "train_log") 276 | os.makedirs(save_path, exist_ok = True) 277 | try: 278 | f = open(os.path.join(save_path, f'{args.train_out}.txt'), 'r') 279 | except IOError: 280 | f = open(os.path.join(save_path, f'{args.train_out}.txt'), 'w') 281 | f.writelines("Epoch,Loss,AUROC,AP\n") 282 | finally: 283 | f = open(os.path.join(save_path, f'{args.train_out}.txt'), 'a') 284 | f.writelines(f"{epoch:03d}, {loss:.4f}, {auc:.4f}, {ap:.4f}\n") 285 | f.close() 286 | if args.do_test: 287 | data_ko = load_gdata(load_path, "data_ko") 288 | print("continue") 289 | from .utils import get_generank, get_r2_score 290 | z_mu, z_std = sensei.get_latent_vars(data) 291 | z_mu_KO, z_std_KO = sensei.get_latent_vars(data_ko) 292 | dis = get_distance(z_mu_KO, z_std_KO, z_mu, z_std, by = 'KL') 293 | null = sensei.pmt(data_ko, n = 100, by = 'KL') 294 | res = get_generank(data, dis, null, save_significant_as = args.generank_out) 295 | geneset = list(res.index) 296 | r2, r2_adj = get_r2_score(data, geneset[1:], geneset[0]) 297 | try: 298 | f = open(os.path.join(save_path, f'{args.r2_out}.txt'), 'r') 299 | except IOError: 300 | f = open(os.path.join(save_path, f'{args.r2_out}.txt'), 'w') 301 | f.writelines("R2,R2_adj\n") 302 | finally: 303 | f = open(os.path.join(save_path, f'{args.r2_out}.txt'), 'a') 304 | f.writelines(f"{r2:.4f}, {r2_adj:.4f}\n") 305 | f.close() 306 | 307 | 308 | if __name__ == "__main__": 309 | import argparse 310 | parser = argparse.ArgumentParser() 311 | parser.add_argument('--ddir', type = str, default 
= "data") 312 | parser.add_argument('--epochs', type = int, default = 100) 313 | parser.add_argument('--lr', type = float, default = 7e-4) 314 | parser.add_argument('--beta', type = float, default = 1e-4) 315 | parser.add_argument('--seed', type = int, default = None) 316 | parser.add_argument('--logdir', type = str, default = None) 317 | parser.add_argument('-E', '--e_noise', type = float, default = None) 318 | parser.add_argument('-X', '--x_noise', type = float, default = None) 319 | parser.add_argument('-XO', '--x_dropout', type = float, default = None) 320 | parser.add_argument('-v', '--verbose', action = "store_true") 321 | parser.add_argument('--train_out', type = str, default = "train_log") 322 | 323 | parser.add_argument('--do_test', action = "store_true") 324 | parser.add_argument('--generank_out', type = str, default = "gene_list") 325 | parser.add_argument('--r2_out', type = str, default = "r2_score") 326 | args = parser.parse_args() 327 | 328 | eva(args) 329 | # python -m GenKI.train --dir data --logdir log_dir/run0 --seed 8096 -E -0.1 -v 330 | # python -m GenKI.train --dir data_covid --logdir log_dir/sigma0_covid --lr 5e-3 --seed 8096 -v 331 | # python -m GenKI.train --dir data --train_out train_log --do_test --generank_out genelist --r2_out r2_score -v 332 | 333 | 334 | -------------------------------------------------------------------------------- /GenKI/tune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .model import VGAE 3 | from .train import VariationalGCNEncoder 4 | import numpy as np 5 | from ray import tune 6 | from ray.tune.schedulers import ASHAScheduler 7 | import os 8 | from .preprocesing import split_data 9 | 10 | 11 | hyperparams = { 12 | "lr": tune.sample_from(lambda _: np.random.randint(1, 10)*(0.1**np.random.randint(1, 4))), 13 | "beta": tune.sample_from(lambda _: np.random.randint(1, 10)*(0.1**np.random.randint(1, 5))), 14 | "seed": tune.randint(0, 10000), 15 | "weight_decay": tune.sample_from(lambda _: np.random.randint(1, 10)*(0.1**np.random.randint(3, 7))) 16 | } 17 | 18 | 19 | def train(config, checkpoint_dir = None): 20 | train_data, val_data, test_data = split_data(dir = "data", load = True) 21 | torch.manual_seed(config["seed"]) 22 | model = VGAE(VariationalGCNEncoder(train_data.num_features, 2)) 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | model = model.to(device) 26 | train_data, val_data, test_data = train_data.to(device), val_data.to(device), test_data.to(device) 27 | optimizer = torch.optim.Adam(model.parameters(), lr = config["lr"], weight_decay = config["weight_decay"]) 28 | # optimizer = torch.optim.SGD(model.parameters(), lr = config["lr"], momentum = 0.9, weight_decay = config["weight_decay"]) 29 | 30 | # when restore a checkpoint 31 | if checkpoint_dir: 32 | checkpoint = os.path.join(checkpoint_dir, "checkpoint") 33 | model_state, optimizer_state = torch.load(checkpoint) 34 | model.load_state_dict(model_state) 35 | optimizer.load_state_dict(optimizer_state) 36 | 37 | for epoch in range(1000): # search < max_num_epochs 38 | model.train() 39 | optimizer.zero_grad() 40 | z = model.encode(train_data.x, train_data.edge_index) 41 | recon_loss = model.recon_loss(z, train_data.pos_edge_label_index) 42 | kl_loss = config["beta"] * model.kl_loss() # beta-VAE 43 | loss = recon_loss + kl_loss 44 | loss.backward() 45 | optimizer.step() 46 | 47 | # valid set 48 | with torch.no_grad(): 49 | model.eval() 50 | z = model.encode(val_data.x, 
val_data.edge_index) 51 | auc, ap = model.test(z, val_data.pos_edge_label_index, val_data.neg_edge_label_index) 52 | # print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, AUROC: {auc:.4f}, AP: {ap:.4f}') 53 | 54 | # save a checkpoint 55 | with tune.checkpoint_dir(step = epoch) as checkpoint_dir: 56 | path = os.path.join(checkpoint_dir, "checkpoint") 57 | torch.save( 58 | (model.state_dict(), optimizer.state_dict()), path) 59 | # record metrics from valid set 60 | tune.report(Loss = loss.item(), AUROC = auc, AP = ap, F_ = auc*ap) # call: shown as training_iteration 61 | print("Finished Training") 62 | 63 | 64 | def test_best_model(best_trial): 65 | train_data, val_data, test_data = split_data(dir = "data", load = True) 66 | torch.manual_seed(best_trial.config["seed"]) 67 | best_trained_model = VGAE(VariationalGCNEncoder(train_data.num_features, 2)) 68 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 69 | best_trained_model = best_trained_model.to(device) 70 | train_data, val_data, test_data = train_data.to(device), val_data.to(device), test_data.to(device) 71 | 72 | checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint") 73 | model_state, optimizer_state = torch.load(checkpoint_path) 74 | best_trained_model.load_state_dict(model_state) 75 | 76 | # test set 77 | with torch.no_grad(): 78 | best_trained_model.eval() 79 | z = best_trained_model.encode(test_data.x, test_data.edge_index) 80 | auc, ap = best_trained_model.test(z, test_data.pos_edge_label_index, test_data.neg_edge_label_index) 81 | print(f"Best trial test AUROC: {auc:.4f}") 82 | print(f"Best trial test AP: {ap:.4f}") 83 | 84 | 85 | def main(num_samples, 86 | max_num_epochs = 100, 87 | gpus_per_trial = 0, 88 | save_as: str = "tune_result"): 89 | scheduler = ASHAScheduler( 90 | max_t = max_num_epochs, # max iteration 91 | grace_period = 10, # stop at least after this iteration 92 | reduction_factor = 2) 93 | result = tune.run(train, 94 | config = hyperparams, 95 | num_samples = num_samples, # trials: sets sampled from grid of hyperparams 96 | name = "experiment", # saved folder name 97 | metric = "F_", 98 | mode = "max", 99 | resources_per_trial={"cpu": 4, "gpu": gpus_per_trial}, 100 | scheduler = scheduler, # prune bad runs 101 | # stop = {'training_iteration':100}, # when tune.report was called 102 | ) 103 | best_trial = result.get_best_trial("F_", "max", "last") 104 | print(f"Best trial config: {best_trial}") 105 | print("Best trial valid AUROC: {}".format(best_trial.last_result["AUROC"])) 106 | print("Best trial valid AP: {}".format(best_trial.last_result["AP"])) 107 | test_best_model(best_trial) # only should run once 108 | save_path = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), f"{save_as}.csv") 109 | result.dataframe().to_csv(save_path) 110 | 111 | 112 | if __name__ == "__main__": 113 | # seed for Ray Tune's random search algorithm 114 | np.random.seed(42) 115 | main(num_samples = 50) 116 | # report: https://docs.ray.io/en/latest/tune/tutorials/tune-metrics.html#tune-autofilled-metrics 117 | 118 | -------------------------------------------------------------------------------- /GenKI/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | import scipy 5 | import matplotlib.pyplot as plt 6 | from sklearn.manifold import TSNE 7 | from sklearn.preprocessing import StandardScaler 8 | from scipy.cluster.vq import kmeans2 9 | import math 10 | 11 | 12 | def 
boxcox_norm(x): 13 | """ 14 | x: 1-D array-like, requires positive values. 15 | Box-cox transform and standardize x 16 | """ 17 | xt, _ = scipy.stats.boxcox(x) 18 | return StandardScaler().fit_transform(xt[:, None]).flatten() 19 | # (xt - xt.mean())/np.sqrt(xt.var()) # z-score for 1-D 20 | 21 | 22 | def _t_stat(m0, S0, m1, S1): 23 | return (m0 - m1) / math.sqrt( 24 | S0 + S1 # S0, S1 are variances (see get_latent_vars) 25 | ) 26 | 27 | 28 | def _kl_1d(m0, S0, m1, S1): 29 | """ 30 | KL divergence between two gaussian distributions, where S0 and S1 are variances. 31 | https://stats.stackexchange.com/questions/234757/how-to-use-kullback-leibler-divergence-if-mean-and-standard-deviation-of-of-two 32 | """ 33 | return 0.5 * math.log(S1 / S0) + (S0 + (m0 - m1) ** 2) / (2 * S1) - 1 / 2 34 | 35 | 36 | def _kl_mvn(m0, S0, m1, S1): 37 | """ 38 | KL divergence between two multivariate gaussian distributions. 39 | """ 40 | # store inv diag covariance of S1 and diff between means 41 | N = m0.shape[0] 42 | iS1 = np.linalg.pinv(S1) # pseudo-inverse 43 | diff = m1 - m0 44 | tr_term = np.trace(iS1 @ S0) 45 | det_term = np.log( 46 | np.linalg.det(S1) / np.linalg.det(S0) 47 | ) # np.sum(np.log(S1)) - np.sum(np.log(S0)) 48 | quad_term = diff.T @ iS1 @ diff # np.sum( (diff*diff) * iS1, axis=1) 49 | return 0.5 * (tr_term + det_term + quad_term - N) 50 | 51 | 52 | def _wasserstein_dist2(m0, S0, m1, S1): 53 | """ 54 | Earth Mover Distance (wasserstein distance) between two multivariate gaussian distributions. 55 | """ 56 | return np.square(np.linalg.norm(m0 - m1)) + np.trace( 57 | S0 + S1 - 2 * np.sqrt(np.sqrt(S0) @ S1 @ np.sqrt(S0)) 58 | ) 59 | 60 | 61 | def get_distance(z_m0, z_S0, z_m1, z_S1, by = "KL"): 62 | """ 63 | z_m0, z_S0: latent means and variances from WT data. 64 | z_m1, z_S1: latent means and variances from virtual data. 65 | """ 66 | dis = [] 67 | try: 68 | out_channels = z_m0.shape[1] # out_channels >= 2 69 | for m0, S0, m1, S1 in zip(z_m0, z_S0, z_m1, z_S1): 70 | temp_S0 = np.zeros((out_channels, out_channels), float) 71 | np.fill_diagonal(temp_S0, S0) 72 | temp_S1 = np.zeros((out_channels, out_channels), float) 73 | np.fill_diagonal(temp_S1, S1) 74 | if by == "KL": 75 | dis.append(_kl_mvn(m0, temp_S0, m1, temp_S1)) 76 | # dis.append( 77 | # ( 78 | # _kl_mvn(m0, temp_S0, m1, temp_S1) 79 | # + _kl_mvn(m1, temp_S1, m0, temp_S0) 80 | # ) 81 | # / 2 82 | # ) 83 | if by == "EMD": 84 | dis.append(_wasserstein_dist2(m0, temp_S0, m1, temp_S1)) 85 | except IndexError: # when out_channels == 1 86 | for m0, S0, m1, S1 in zip(z_m0, z_S0, z_m1, z_S1): 87 | if by == "KL": 88 | dis.append((_kl_1d(m0, S0, m1, S1) + _kl_1d(m1, S1, m0, S0)) / 2) 89 | if by == "t": 90 | dis.append(_t_stat(m0, S0, m1, S1)) 91 | return np.array(dis) 92 | 93 | 94 | def get_generank( 95 | data, 96 | distance, 97 | null = None, 98 | rank: bool = True, 99 | reverse: bool = False, 100 | bagging: float = 0.05, 101 | cutoff: float = 0.95, 102 | save_significant_as: str = None, 103 | ): 104 | """ 105 | data: torch_geometric.data.data.Data. 106 | distance: array-like, output of get_distance. 107 | null: array-like, output of pmt. rank: whether to sort genes by distance. reverse: whether to rank in ascending order. 108 | bagging: threshold for bagging top names at each permutation. 109 | cutoff: threshold for frequency of bagging after all permutations. 110 | save_significant_as: .txt file name of significant genes that will be saved in result for enrichment test.
111 | """ 112 | if null is not None: 113 | if reverse: 114 | idx = np.argsort(-null, axis=1) 115 | else: 116 | idx = np.argsort(null, axis=1) 117 | thre = int(len(data.y) * (1 - bagging)) # 95% 118 | idx = idx[:, thre:].flatten() # bagging index of top ranked genes 119 | 120 | y = np.bincount(idx) 121 | ii = np.nonzero(y)[0] 122 | freq = np.vstack((ii, y[ii])).T # [gene_index, #hit] 123 | df_KL = pd.DataFrame( 124 | data=distance[freq[:, 0]][:, None], 125 | index=np.array(data.y)[freq[:, 0]], 126 | columns=["dis"], 127 | ) 128 | df_KL[["index", "hit"]] = freq.astype(int) 129 | hit = int(null.shape[0] * cutoff) 130 | df_KL = df_KL[(df_KL.hit >= hit) & (df_KL.dis != 0)] 131 | if rank: 132 | df_KL.sort_values(by=["dis", "hit"], ascending=reverse, inplace=True) 133 | if save_significant_as is not None: 134 | output = list(df_KL.index) 135 | os.makedirs("result", exist_ok=True) 136 | np.savetxt( 137 | os.path.join("result", f"{save_significant_as}.txt"), output, fmt="%s", delimiter="," 138 | ) 139 | print(f"save {len(output)} genes to \"./result/{save_significant_as}.txt\"") 140 | else: 141 | df_KL = pd.DataFrame( 142 | data=distance, index=np.array(data.y), columns=["dis"] 143 | ) # filter pseudo values 144 | if rank: 145 | df_KL.sort_values(by=["dis"], ascending=reverse, inplace=True) 146 | if rank: 147 | df_KL["rank"] = np.arange(len(df_KL)) + 1 148 | return df_KL 149 | 150 | 151 | def get_generank_gsea(data, 152 | distance, 153 | reverse: bool = False, 154 | save_as: str = None): 155 | """ 156 | data: torch_geometric.data.data.Data, 157 | distance: array-like, output of get_distance. 158 | save_as: str, .txt file name that will be saved in result for GSEA. 159 | """ 160 | df_gsea = get_generank(data, distance, reverse=reverse) 161 | df_gsea = df_gsea[df_gsea["dis"] > 0] # remove pinverse 162 | 163 | # Box-cox 164 | df_gsea["dis_norm"] = boxcox_norm(df_gsea["dis"]) # z-score 165 | 166 | df_gsea.sort_values(by="dis_norm", inplace=True, ascending=reverse) 167 | output_gsea = np.stack( 168 | (df_gsea.index, df_gsea["dis_norm"]) 169 | ).T # GSEA: [gene_name, value] 170 | 171 | if save_as is not None: 172 | os.makedirs("result", exist_ok=True) 173 | np.savetxt( 174 | os.path.join("result", "GSEA_{save_as}.txt"), output_gsea, fmt="%s", delimiter="\t" 175 | ) 176 | print(f"save ranked genes to \"./result/GSEA_{save_as}.txt\"") 177 | return df_gsea 178 | 179 | 180 | def get_r2_score(data, geneset: list, target_gene: str): 181 | import statsmodels.api as sm 182 | # np.random.seed(4) 183 | 184 | X = data.x.numpy().T # standardized counts 185 | x_genes_idx = [data.y.index(g) for g in geneset] 186 | y_genes_idx = data.y.index(target_gene) 187 | x_genes = X[:, x_genes_idx].copy()#.toarray() 188 | y_genes = X[:, y_genes_idx].copy()#.toarray() 189 | # print(x_genes.shape, y_genes.shape) 190 | 191 | result = sm.OLS(y_genes, x_genes).fit() # y, X 192 | return result.rsquared, result.rsquared_adj 193 | # for _ in range(30): 194 | # random_idx = np.random.choice(X.shape[1], x_genes.shape[1], replace = False) 195 | # random_genes = X[:, random_idx].copy()#.toarray() 196 | # result = sm.OLS(y_genes, random_genes).fit() 197 | # r2.append([result.rsquared, result.rsquared_adj]) 198 | 199 | 200 | def get_sys_KO_cluster( 201 | obj, 202 | sys_res: np.ndarray, 203 | perplexity=25, 204 | n_cluster=50, 205 | save_as=None, 206 | show_TSNE=True, 207 | verbose=False, 208 | ): 209 | """ 210 | obj: a sc object. 211 | sys_res: np.ndarray, output of run_sys_KO. 212 | perplexity: int, hyperparameter of TSNE. 
213 | n_cluster: int, hyperparameter of k-means. 214 | show_TSNE: bool, whether to show the TSNE plot after clustering. 215 | """ 216 | np.random.seed(100) 217 | scaled_sys_res = StandardScaler().fit_transform( 218 | sys_res 219 | ) # standarize features (cols) 220 | X_embedded = TSNE( 221 | n_components=2, 222 | learning_rate="auto", 223 | perplexity=perplexity, 224 | metric="euclidean", 225 | init="pca", 226 | ).fit_transform(scaled_sys_res) 227 | _, label = kmeans2(X_embedded, n_cluster, minit="points") # centroid 228 | cluster_idx = np.where(label == label[obj(obj._target_gene)]) 229 | cluster_gene_names = np.array(obj._gene_names)[cluster_idx] 230 | if verbose: 231 | print( 232 | f"TSNE perplexity: {perplexity}, # Clustering: {n_cluster}\nThe cluster containing {obj._target_gene} has {len(cluster_gene_names)} genes" 233 | ) 234 | if save_as is not None: 235 | os.makedirs("result", exist_ok=True) 236 | np.savetxt( 237 | os.path.join("result", f"sys_KO_{save_as}.txt"), 238 | cluster_gene_names, 239 | fmt="%s", 240 | delimiter="\t", 241 | ) 242 | print(f"save ranked genes to \"./result/sys_KO_{save_as}.txt\"") 243 | if show_TSNE: 244 | fig, ax = plt.subplots(figsize=(8, 8), dpi=80) 245 | colors = [ 246 | "red" if i == label[obj(obj._target_gene)] else "black" for i in label 247 | ] 248 | ax.set_title(f"TSNE plot of systematic KO {len(sys_res)} genes") 249 | ax.scatter(X_embedded[:, 0], X_embedded[:, 1], s=3, c=colors, alpha=0.5) 250 | ax.axis("tight") 251 | ax.annotate( 252 | obj._target_gene, 253 | xy=(X_embedded[obj(obj._target_gene)]), 254 | xycoords="data", 255 | xytext=(X_embedded[obj(obj._target_gene)] + 15), 256 | textcoords="offset points", 257 | arrowprops=dict(arrowstyle="->", connectionstyle="arc3"), 258 | ) 259 | plt.show() 260 | return cluster_gene_names 261 | 262 | 263 | if __name__ == "__main__": 264 | pm = np.array([0, 0]) 265 | pv = np.array([0, 0]) 266 | qm = np.array([[1, 0], [0, 1]]) 267 | qv = np.array([[1, 0], [0, 1]]) 268 | assert _kl_mvn(pm, qm, pv, qv) == 0 269 | -------------------------------------------------------------------------------- /GenKI/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1.dev" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yongjian Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjgeno/GenKI/6d69789b89859eda75ac75dfb1cc00ef190ada41/MANIFEST.in -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GenKI (Gene Knock-out Inference) 2 | A VGAE (Variational Graph Auto-Encoder) based model that learns gene knock-out perturbations from scRNA-seq data.
3 | New! Data has been added.
4 | [Paper](https://doi.org/10.1093/nar/gkad450) 5 |
6 | <br> 7 | <img src="logo.jpg" alt="drawing"/> 8 | <br> 9 | <br>
10 | 11 | ### Install dependencies 12 | First install the dependencies of GenKI with `conda`: 13 | ```shell 14 | conda env create -f environment.yml 15 | conda activate ogenki 16 | ``` 17 | Then install `pytorch-geometric` following its official installation guide:
18 | https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html 19 |
20 |
21 | 22 | ### Install GenKI with `pip`: 23 | ```shell 24 | pip install git+https://github.com/yjgeno/GenKI.git 25 | ``` 26 | or install it manually from source: 27 | ```shell 28 | git clone https://github.com/yjgeno/GenKI.git 29 | cd GenKI 30 | pip install . 31 | ``` 32 |
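To sanity-check the install (a minimal, optional check run inside the activated environment), the package should import cleanly and report its version:
```python
import GenKI.version as v

print(v.__version__)  # e.g., "0.0.1.dev"
```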
33 | 34 | #### Tutorial 35 | Virtual KO experiment:
https://github.com/yjgeno/GenKI/blob/master/notebook/Example.ipynb
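For quick orientation, below is a minimal end-to-end sketch condensed from the Example notebook; the `.h5ad` path, KO gene, and hyperparameters are the example's own values and should be adapted to your data:
```python
import GenKI as gk
from GenKI.preprocesing import build_adata
from GenKI.dataLoader import DataLoader
from GenKI.train import VGAE_trainer
from GenKI import utils

adata = build_adata("data/microglial_seurat_WT.h5ad")  # AnnData with a "norm" layer
wrapper = DataLoader(adata,
                     target_gene=["TUBG1"],  # gene(s) to knock out in silico
                     target_cell=None,       # None = use all cells
                     obs_label="ident",
                     GRN_file_dir="GRNs",
                     rebuild_GRN=True,       # build a pcNet GRN from scratch
                     pcNet_name="pcNet_example")
data_wt = wrapper.load_data()    # wild-type graph
data_ko = wrapper.load_kodata()  # virtual knock-out graph

sensei = VGAE_trainer(data_wt, epochs=100, lr=7e-4, log_dir=None,
                      beta=1e-4, seed=8096)
sensei.train()

# rank genes by KL divergence between WT and KO latent distributions
z_mu_wt, z_std_wt = sensei.get_latent_vars(data_wt)
z_mu_ko, z_std_ko = sensei.get_latent_vars(data_ko)
dis = gk.utils.get_distance(z_mu_ko, z_std_ko, z_mu_wt, z_std_wt, by="KL")
print(utils.get_generank(data_wt, dis, rank=True).head())
```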
36 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | * The microglial (Trem2-KO) dataset is now available at [GoogleDrive](https://drive.google.com/file/d/1tG9bUGCsWqhg0hJ94lDLtLl8WLl0hDks/view?usp=sharing). 2 | * Other datasets may be provided upon request. 3 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | 4 | dependencies: 5 | - python~=3.9.6 6 | - pip>=22.0.4 7 | # - pycairo>=1.21.0 8 | - anndata>=0.7.4 9 | - numpy>=1.17.0 10 | - matplotlib-base>=3.1.2 11 | - pandas>=0.21 12 | - scipy>=1.4 13 | - seaborn 14 | - h5py>=3 15 | - tqdm 16 | # - statsmodels~=0.13.1 17 | - patsy 18 | - networkx>=2.3 19 | - natsort 20 | - joblib 21 | - numba>=0.41.0 22 | - umap-learn>=0.3.10 23 | - packaging 24 | - setuptools-scm 25 | - black>=20.8b1 26 | - docutils 27 | - sphinx<4.2,>=4.1 28 | - sphinx_rtd_theme>=0.3.1 29 | # - python-igraph~=0.9.8 30 | - leidenalg 31 | - louvain!=0.6.2,>=0.6 32 | # - scikit-misc~=0.1.4 33 | - pytest>=4.4 34 | - pytest-nunit 35 | - dask-core!=2.17.0 36 | - fsspec 37 | - zappy 38 | - zarr 39 | - profimp 40 | - flit-core 41 | - ipywidgets 42 | 43 | - pip: 44 | - flit 45 | - session-info 46 | - scanpydoc<0.7.6,>=0.7.4 47 | - torch==1.11.0 48 | - tensorboard 49 | # - torch-geometric 50 | - scikit-learn 51 | - scikit-misc>=0.1.3 52 | - scanpy>=1.7.2 53 | - ray 54 | - notebook 55 | name: ogenki 56 | 57 | # pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cpu.html -------------------------------------------------------------------------------- /logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjgeno/GenKI/6d69789b89859eda75ac75dfb1cc00ef190ada41/logo.jpg -------------------------------------------------------------------------------- /notebook/Compatible_matlab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6c209667", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import scanpy as sc\n", 15 | "\n", 16 | "sc.settings.verbosity = 0" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "id": "ffae2228", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import GenKI as gk\n", 27 | "from GenKI.preprocesing import build_adata\n", 28 | "from GenKI.dataLoader import DataLoader\n", 29 | "from GenKI.train import VGAE_trainer\n", 30 | "from GenKI import utils\n", 31 | "\n", 32 | "%load_ext autoreload\n", 33 | "%autoreload 2" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "id": "d32e28a9", 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "load counts from data\\filtered_gene_bc_matrices\\X.txt\n", 47 | "AnnData object with n_obs × n_vars = 100 × 300\n", 48 | " obs: 'cell_type'\n", 49 | " uns: 'log1p'\n", 50 | " layers: 'raw', 'norm'\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "# example\n", 56 | "\n", 57 | "data_dir = os.path.join(\"data\", 
\"filtered_gene_bc_matrices\")\n", 58 | "counts_path = str(os.path.join(data_dir, \"X.txt\"))\n", 59 | "gene_path = str(os.path.join(data_dir, \"g.txt\"))\n", 60 | "cell_path = str(os.path.join(data_dir, \"c.txt\"))\n", 61 | "\n", 62 | "adata = build_adata(counts_path, \n", 63 | " gene_path, \n", 64 | " cell_path, \n", 65 | " meta_cell_cols=[\"cell_type\"], # colname of cell type\n", 66 | " delimiter=',', # X.txt\n", 67 | " transpose=True, # X.txt\n", 68 | " log_normalize=True,\n", 69 | " scale_data=True,\n", 70 | " )\n", 71 | "\n", 72 | "adata = adata[:100, :300].copy()\n", 73 | "print(adata)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "id": "071c88c5", 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "'LAMB3'" 86 | ] 87 | }, 88 | "execution_count": 4, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "# gene to ko\n", 95 | "adata.var_names[66]" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "id": "1e37d540", 102 | "metadata": { 103 | "scrolled": true 104 | }, 105 | "outputs": [ 106 | { 107 | "name": "stdout", 108 | "output_type": "stream", 109 | "text": [ 110 | "use all the cells (100) in adata\n", 111 | "build GRN\n", 112 | "ray init, using 8 CPUs\n", 113 | "execution time of making pcNet: 6.34 s\n", 114 | "GRN has been built and saved in \"GRNs\\pcNet_example.npz\"\n", 115 | "init completed\n", 116 | "\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "# load data\n", 122 | "\n", 123 | "data_wrapper = DataLoader(\n", 124 | " adata, # adata object\n", 125 | " target_gene = [66], # KO gene name/index\n", 126 | " target_cell = None, # obsname for cell type, if none use all\n", 127 | " obs_label = \"cell_type\", # colname for genes\n", 128 | " GRN_file_dir = \"GRNs\", # folder name for GRNs\n", 129 | " rebuild_GRN = True, # whether build GRN by pcNet\n", 130 | " pcNet_name = \"pcNet_example\", # GRN file name\n", 131 | " verbose = True, # whether verbose\n", 132 | " n_cpus = 8, # multiprocessing\n", 133 | " )\n", 134 | "\n", 135 | "data_wt = data_wrapper.load_data()\n", 136 | "data_ko = data_wrapper.load_kodata()" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 6, 142 | "id": "7862c768", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "# init trainer\n", 147 | "\n", 148 | "hyperparams = {\"epochs\": 100, \n", 149 | " \"lr\": 7e-4, \n", 150 | " \"beta\": 1e-4, \n", 151 | " \"seed\": 8096}\n", 152 | "log_dir=None \n", 153 | "\n", 154 | "sensei = VGAE_trainer(data_wt, \n", 155 | " epochs=hyperparams[\"epochs\"], \n", 156 | " lr=hyperparams[\"lr\"], \n", 157 | " log_dir=log_dir, \n", 158 | " beta=hyperparams[\"beta\"],\n", 159 | " seed=hyperparams[\"seed\"],\n", 160 | " verbose=True,\n", 161 | " )" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 7, 167 | "id": "31bfb6a0", 168 | "metadata": { 169 | "scrolled": false 170 | }, 171 | "outputs": [ 172 | { 173 | "name": "stdout", 174 | "output_type": "stream", 175 | "text": [ 176 | "Epoch: 000, Loss: 1.5318, AUROC: 0.8618, AP: 0.7514\n", 177 | "Epoch: 001, Loss: 1.4803, AUROC: 0.8643, AP: 0.7555\n", 178 | "Epoch: 002, Loss: 1.5022, AUROC: 0.8668, AP: 0.7596\n", 179 | "Epoch: 003, Loss: 1.4556, AUROC: 0.8692, AP: 0.7634\n", 180 | "Epoch: 004, Loss: 1.4764, AUROC: 0.8715, AP: 0.7671\n", 181 | "Epoch: 005, Loss: 1.4815, AUROC: 0.8739, AP: 0.7708\n", 182 | "Epoch: 006, Loss: 1.4358, AUROC: 0.8760, AP: 0.7744\n", 183 | 
"Epoch: 007, Loss: 1.4775, AUROC: 0.8783, AP: 0.7781\n", 184 | "Epoch: 008, Loss: 1.4364, AUROC: 0.8804, AP: 0.7812\n", 185 | "Epoch: 009, Loss: 1.4689, AUROC: 0.8823, AP: 0.7839\n", 186 | "Epoch: 010, Loss: 1.4504, AUROC: 0.8842, AP: 0.7863\n", 187 | "Epoch: 011, Loss: 1.4228, AUROC: 0.8859, AP: 0.7886\n", 188 | "Epoch: 012, Loss: 1.4194, AUROC: 0.8875, AP: 0.7907\n", 189 | "Epoch: 013, Loss: 1.3903, AUROC: 0.8890, AP: 0.7928\n", 190 | "Epoch: 014, Loss: 1.3900, AUROC: 0.8904, AP: 0.7948\n", 191 | "Epoch: 015, Loss: 1.3866, AUROC: 0.8916, AP: 0.7969\n", 192 | "Epoch: 016, Loss: 1.3610, AUROC: 0.8928, AP: 0.7991\n", 193 | "Epoch: 017, Loss: 1.3729, AUROC: 0.8940, AP: 0.8016\n", 194 | "Epoch: 018, Loss: 1.3424, AUROC: 0.8953, AP: 0.8042\n", 195 | "Epoch: 019, Loss: 1.3410, AUROC: 0.8965, AP: 0.8069\n", 196 | "Epoch: 020, Loss: 1.3542, AUROC: 0.8977, AP: 0.8100\n", 197 | "Epoch: 021, Loss: 1.3635, AUROC: 0.8989, AP: 0.8131\n", 198 | "Epoch: 022, Loss: 1.3064, AUROC: 0.9001, AP: 0.8165\n", 199 | "Epoch: 023, Loss: 1.3163, AUROC: 0.9013, AP: 0.8197\n", 200 | "Epoch: 024, Loss: 1.3282, AUROC: 0.9026, AP: 0.8227\n", 201 | "Epoch: 025, Loss: 1.3091, AUROC: 0.9038, AP: 0.8250\n", 202 | "Epoch: 026, Loss: 1.2974, AUROC: 0.9050, AP: 0.8271\n", 203 | "Epoch: 027, Loss: 1.2778, AUROC: 0.9061, AP: 0.8291\n", 204 | "Epoch: 028, Loss: 1.2874, AUROC: 0.9070, AP: 0.8308\n", 205 | "Epoch: 029, Loss: 1.2723, AUROC: 0.9080, AP: 0.8323\n", 206 | "Epoch: 030, Loss: 1.2780, AUROC: 0.9087, AP: 0.8336\n", 207 | "Epoch: 031, Loss: 1.2760, AUROC: 0.9095, AP: 0.8348\n", 208 | "Epoch: 032, Loss: 1.2589, AUROC: 0.9102, AP: 0.8359\n", 209 | "Epoch: 033, Loss: 1.2620, AUROC: 0.9109, AP: 0.8369\n", 210 | "Epoch: 034, Loss: 1.2600, AUROC: 0.9114, AP: 0.8377\n", 211 | "Epoch: 035, Loss: 1.2685, AUROC: 0.9119, AP: 0.8384\n", 212 | "Epoch: 036, Loss: 1.2250, AUROC: 0.9124, AP: 0.8392\n", 213 | "Epoch: 037, Loss: 1.2224, AUROC: 0.9128, AP: 0.8399\n", 214 | "Epoch: 038, Loss: 1.2461, AUROC: 0.9132, AP: 0.8406\n", 215 | "Epoch: 039, Loss: 1.2168, AUROC: 0.9136, AP: 0.8413\n", 216 | "Epoch: 040, Loss: 1.2123, AUROC: 0.9140, AP: 0.8419\n", 217 | "Epoch: 041, Loss: 1.2369, AUROC: 0.9143, AP: 0.8425\n", 218 | "Epoch: 042, Loss: 1.2207, AUROC: 0.9146, AP: 0.8431\n", 219 | "Epoch: 043, Loss: 1.2010, AUROC: 0.9148, AP: 0.8436\n", 220 | "Epoch: 044, Loss: 1.2111, AUROC: 0.9151, AP: 0.8443\n", 221 | "Epoch: 045, Loss: 1.2058, AUROC: 0.9153, AP: 0.8451\n", 222 | "Epoch: 046, Loss: 1.1937, AUROC: 0.9155, AP: 0.8459\n", 223 | "Epoch: 047, Loss: 1.1681, AUROC: 0.9158, AP: 0.8469\n", 224 | "Epoch: 048, Loss: 1.1873, AUROC: 0.9160, AP: 0.8477\n", 225 | "Epoch: 049, Loss: 1.1639, AUROC: 0.9162, AP: 0.8484\n", 226 | "Epoch: 050, Loss: 1.1858, AUROC: 0.9164, AP: 0.8490\n", 227 | "Epoch: 051, Loss: 1.1696, AUROC: 0.9166, AP: 0.8496\n", 228 | "Epoch: 052, Loss: 1.1834, AUROC: 0.9168, AP: 0.8502\n", 229 | "Epoch: 053, Loss: 1.2075, AUROC: 0.9170, AP: 0.8507\n", 230 | "Epoch: 054, Loss: 1.1903, AUROC: 0.9171, AP: 0.8512\n", 231 | "Epoch: 055, Loss: 1.1360, AUROC: 0.9173, AP: 0.8516\n", 232 | "Epoch: 056, Loss: 1.1702, AUROC: 0.9174, AP: 0.8521\n", 233 | "Epoch: 057, Loss: 1.1352, AUROC: 0.9175, AP: 0.8525\n", 234 | "Epoch: 058, Loss: 1.1411, AUROC: 0.9176, AP: 0.8528\n", 235 | "Epoch: 059, Loss: 1.1275, AUROC: 0.9178, AP: 0.8533\n", 236 | "Epoch: 060, Loss: 1.1570, AUROC: 0.9179, AP: 0.8536\n", 237 | "Epoch: 061, Loss: 1.1273, AUROC: 0.9179, AP: 0.8539\n", 238 | "Epoch: 062, Loss: 1.1276, AUROC: 0.9179, AP: 0.8542\n", 239 | "Epoch: 063, Loss: 1.1243, 
AUROC: 0.9180, AP: 0.8546\n", 240 | "Epoch: 064, Loss: 1.0985, AUROC: 0.9181, AP: 0.8549\n", 241 | "Epoch: 065, Loss: 1.1350, AUROC: 0.9182, AP: 0.8551\n", 242 | "Epoch: 066, Loss: 1.1116, AUROC: 0.9182, AP: 0.8554\n", 243 | "Epoch: 067, Loss: 1.1489, AUROC: 0.9182, AP: 0.8555\n", 244 | "Epoch: 068, Loss: 1.1193, AUROC: 0.9182, AP: 0.8557\n", 245 | "Epoch: 069, Loss: 1.1360, AUROC: 0.9182, AP: 0.8558\n", 246 | "Epoch: 070, Loss: 1.1065, AUROC: 0.9182, AP: 0.8560\n", 247 | "Epoch: 071, Loss: 1.0939, AUROC: 0.9183, AP: 0.8562\n", 248 | "Epoch: 072, Loss: 1.1089, AUROC: 0.9183, AP: 0.8562\n", 249 | "Epoch: 073, Loss: 1.1342, AUROC: 0.9182, AP: 0.8562\n", 250 | "Epoch: 074, Loss: 1.0828, AUROC: 0.9182, AP: 0.8562\n", 251 | "Epoch: 075, Loss: 1.1173, AUROC: 0.9182, AP: 0.8563\n", 252 | "Epoch: 076, Loss: 1.1005, AUROC: 0.9182, AP: 0.8565\n", 253 | "Epoch: 077, Loss: 1.0860, AUROC: 0.9181, AP: 0.8566\n", 254 | "Epoch: 078, Loss: 1.0871, AUROC: 0.9181, AP: 0.8567\n", 255 | "Epoch: 079, Loss: 1.1317, AUROC: 0.9181, AP: 0.8569\n", 256 | "Epoch: 080, Loss: 1.0836, AUROC: 0.9181, AP: 0.8571\n", 257 | "Epoch: 081, Loss: 1.0848, AUROC: 0.9181, AP: 0.8573\n", 258 | "Epoch: 082, Loss: 1.0880, AUROC: 0.9182, AP: 0.8577\n", 259 | "Epoch: 083, Loss: 1.1009, AUROC: 0.9182, AP: 0.8579\n", 260 | "Epoch: 084, Loss: 1.0999, AUROC: 0.9182, AP: 0.8582\n", 261 | "Epoch: 085, Loss: 1.0923, AUROC: 0.9183, AP: 0.8586\n", 262 | "Epoch: 086, Loss: 1.1181, AUROC: 0.9183, AP: 0.8589\n", 263 | "Epoch: 087, Loss: 1.0989, AUROC: 0.9183, AP: 0.8592\n", 264 | "Epoch: 088, Loss: 1.0841, AUROC: 0.9183, AP: 0.8595\n", 265 | "Epoch: 089, Loss: 1.1156, AUROC: 0.9183, AP: 0.8597\n", 266 | "Epoch: 090, Loss: 1.0894, AUROC: 0.9183, AP: 0.8600\n", 267 | "Epoch: 091, Loss: 1.0946, AUROC: 0.9183, AP: 0.8603\n", 268 | "Epoch: 092, Loss: 1.0588, AUROC: 0.9183, AP: 0.8608\n", 269 | "Epoch: 093, Loss: 1.1279, AUROC: 0.9183, AP: 0.8611\n", 270 | "Epoch: 094, Loss: 1.0511, AUROC: 0.9183, AP: 0.8615\n", 271 | "Epoch: 095, Loss: 1.0543, AUROC: 0.9183, AP: 0.8620\n", 272 | "Epoch: 096, Loss: 1.0692, AUROC: 0.9183, AP: 0.8625\n", 273 | "Epoch: 097, Loss: 1.0644, AUROC: 0.9184, AP: 0.8630\n", 274 | "Epoch: 098, Loss: 1.0727, AUROC: 0.9185, AP: 0.8635\n", 275 | "Epoch: 099, Loss: 1.0578, AUROC: 0.9186, AP: 0.8640\n" 276 | ] 277 | } 278 | ], 279 | "source": [ 280 | "# %%timeit\n", 281 | "sensei.train()" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "id": "e16a285e", 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "# sensei.save_model('model_example')" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 8, 297 | "id": "e055e8f3", 298 | "metadata": { 299 | "scrolled": true 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "# get distance between wt and ko\n", 304 | "\n", 305 | "z_mu_wt, z_std_wt = sensei.get_latent_vars(data_wt)\n", 306 | "z_mu_ko, z_std_ko = sensei.get_latent_vars(data_ko)\n", 307 | "dis = gk.utils.get_distance(z_mu_ko, z_std_ko, z_mu_wt, z_std_wt, by=\"KL\")" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 9, 313 | "id": "78e852b7", 314 | "metadata": {}, 315 | "outputs": [ 316 | { 317 | "data": { 318 | "text/html": [ 319 | "
\n", 320 | "\n", 333 | "\n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | "
disrank
LAMB352.7471021
COL3A10.1381862
S100A90.0850783
MDK0.0849884
MFAP20.0799435
\n", 369 | "
" 370 | ], 371 | "text/plain": [ 372 | " dis rank\n", 373 | "LAMB3 52.747102 1\n", 374 | "COL3A1 0.138186 2\n", 375 | "S100A9 0.085078 3\n", 376 | "MDK 0.084988 4\n", 377 | "MFAP2 0.079943 5" 378 | ] 379 | }, 380 | "execution_count": 9, 381 | "metadata": {}, 382 | "output_type": "execute_result" 383 | } 384 | ], 385 | "source": [ 386 | "# raw ranked gene list\n", 387 | "\n", 388 | "res_raw = utils.get_generank(data_wt, dis, rank=True)\n", 389 | "res_raw.head()" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 10, 395 | "id": "9ddf3b86", 396 | "metadata": {}, 397 | "outputs": [ 398 | { 399 | "name": "stderr", 400 | "output_type": "stream", 401 | "text": [ 402 | "Permutating: 100%|██████████| 100/100 [00:02<00:00, 33.71it/s]\n" 403 | ] 404 | }, 405 | { 406 | "data": { 407 | "text/html": [ 408 | "
\n", 409 | "\n", 422 | "\n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | "
disindexhitrank
LAMB352.747102661001
COL3A10.13818691002
S100A90.0850781941003
TIMP10.0336471451004
\n", 463 | "
" 464 | ], 465 | "text/plain": [ 466 | " dis index hit rank\n", 467 | "LAMB3 52.747102 66 100 1\n", 468 | "COL3A1 0.138186 9 100 2\n", 469 | "S100A9 0.085078 194 100 3\n", 470 | "TIMP1 0.033647 145 100 4" 471 | ] 472 | }, 473 | "execution_count": 10, 474 | "metadata": {}, 475 | "output_type": "execute_result" 476 | } 477 | ], 478 | "source": [ 479 | "# if permutation test\n", 480 | "\n", 481 | "null = sensei.pmt(data_ko, n=100, by=\"KL\")\n", 482 | "res = utils.get_generank(data_wt, dis, null,)\n", 483 | "# save_significant_as = 'gene_list_example')\n", 484 | "res" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "id": "fc571eb2", 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [] 494 | } 495 | ], 496 | "metadata": { 497 | "kernelspec": { 498 | "display_name": "Python 3 (ipykernel)", 499 | "language": "python", 500 | "name": "python3" 501 | }, 502 | "language_info": { 503 | "codemirror_mode": { 504 | "name": "ipython", 505 | "version": 3 506 | }, 507 | "file_extension": ".py", 508 | "mimetype": "text/x-python", 509 | "name": "python", 510 | "nbconvert_exporter": "python", 511 | "pygments_lexer": "ipython3", 512 | "version": "3.9.13" 513 | } 514 | }, 515 | "nbformat": 4, 516 | "nbformat_minor": 5 517 | } 518 | -------------------------------------------------------------------------------- /notebook/Example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6323b4c4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import scanpy as sc\n", 15 | "\n", 16 | "sc.settings.verbosity = 0" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "id": "ffae2228", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import GenKI as gk\n", 27 | "from GenKI.preprocesing import build_adata\n", 28 | "from GenKI.dataLoader import DataLoader\n", 29 | "from GenKI.train import VGAE_trainer\n", 30 | "from GenKI import utils\n", 31 | "\n", 32 | "%load_ext autoreload\n", 33 | "%autoreload 2" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "id": "2b15fb16", 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "load counts from data/microglial_seurat_WT.h5ad\n" 47 | ] 48 | }, 49 | { 50 | "data": { 51 | "text/plain": [ 52 | "AnnData object with n_obs × n_vars = 100 × 300\n", 53 | " obs: 'sce_source', 'treatment', 'trem2_genotype', 'snn_cluster', 'nCount_RNA', 'nFeature_RNA'\n", 54 | " var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'\n", 55 | " layers: 'norm'" 56 | ] 57 | }, 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | } 62 | ], 63 | "source": [ 64 | "# subset data as an example\n", 65 | "\n", 66 | "adata = build_adata(\"data/microglial_seurat_WT.h5ad\")\n", 67 | "adata = adata[:100, :300].copy()\n", 68 | "adata" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 4, 74 | "id": "1e37d540", 75 | "metadata": { 76 | "scrolled": true 77 | }, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "use all the cells (100) in adata\n", 84 | "build GRN\n", 85 | "ray init, using 8 CPUs\n", 86 | "execution time of making pcNet: 
6.26 s\n", 87 | "GRN has been built and saved in \"GRNs\\pcNet_example.npz\"\n", 88 | "init completed\n", 89 | "\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "# load data\n", 95 | "\n", 96 | "data_wrapper = DataLoader(\n", 97 | " adata, # adata object\n", 98 | " target_gene = [\"TUBG1\"], # KO gene name\n", 99 | " target_cell = None, # obsname for cell type, if none use all\n", 100 | " obs_label = \"ident\", # colname for genes\n", 101 | " GRN_file_dir = \"GRNs\", # folder name for GRNs\n", 102 | " rebuild_GRN = True, # whether build GRN by pcNet\n", 103 | " pcNet_name = \"pcNet_example\", # GRN file name\n", 104 | " verbose = True, # whether verbose\n", 105 | " n_cpus = 8, # multiprocessing\n", 106 | " )\n", 107 | "\n", 108 | "data_wt = data_wrapper.load_data()\n", 109 | "data_ko = data_wrapper.load_kodata()" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "id": "7862c768", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# init trainer\n", 120 | "\n", 121 | "hyperparams = {\"epochs\": 100, \n", 122 | " \"lr\": 7e-4, \n", 123 | " \"beta\": 1e-4, \n", 124 | " \"seed\": 8096}\n", 125 | "log_dir=None \n", 126 | "\n", 127 | "sensei = VGAE_trainer(data_wt, \n", 128 | " epochs=hyperparams[\"epochs\"], \n", 129 | " lr=hyperparams[\"lr\"], \n", 130 | " log_dir=log_dir, \n", 131 | " beta=hyperparams[\"beta\"],\n", 132 | " seed=hyperparams[\"seed\"],\n", 133 | " verbose=False,\n", 134 | " )" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 6, 140 | "id": "31bfb6a0", 141 | "metadata": { 142 | "scrolled": true 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "# %%timeit\n", 147 | "sensei.train()" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 7, 153 | "id": "e16a285e", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "# sensei.save_model('model_example')" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 8, 163 | "id": "e055e8f3", 164 | "metadata": { 165 | "scrolled": true 166 | }, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | "(300,)\n" 173 | ] 174 | } 175 | ], 176 | "source": [ 177 | "# get distance between wt and ko\n", 178 | "\n", 179 | "z_mu_wt, z_std_wt = sensei.get_latent_vars(data_wt)\n", 180 | "z_mu_ko, z_std_ko = sensei.get_latent_vars(data_ko)\n", 181 | "dis = gk.utils.get_distance(z_mu_ko, z_std_ko, z_mu_wt, z_std_wt, by=\"KL\")\n", 182 | "print(dis.shape)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 9, 188 | "id": "78e852b7", 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "data": { 193 | "text/html": [ 194 | "
\n", 195 | "\n", 208 | "\n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | "
disrank
TUBG12.5654901
TYROBP0.0017532
LST10.0016283
CYBA0.0015174
LAT20.0014385
\n", 244 | "
" 245 | ], 246 | "text/plain": [ 247 | " dis rank\n", 248 | "TUBG1 2.565490 1\n", 249 | "TYROBP 0.001753 2\n", 250 | "LST1 0.001628 3\n", 251 | "CYBA 0.001517 4\n", 252 | "LAT2 0.001438 5" 253 | ] 254 | }, 255 | "execution_count": 9, 256 | "metadata": {}, 257 | "output_type": "execute_result" 258 | } 259 | ], 260 | "source": [ 261 | "# raw ranked gene list\n", 262 | "\n", 263 | "res_raw = utils.get_generank(data_wt, dis, rank=True)\n", 264 | "res_raw.head()" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 10, 270 | "id": "9ddf3b86", 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "name": "stderr", 275 | "output_type": "stream", 276 | "text": [ 277 | "Permutating: 100%|██████████| 100/100 [00:03<00:00, 33.02it/s]\n" 278 | ] 279 | }, 280 | { 281 | "data": { 282 | "text/html": [ 283 | "
\n", 284 | "\n", 297 | "\n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | "
disindexhitrank
TUBG12.565491631001
\n", 317 | "
" 318 | ], 319 | "text/plain": [ 320 | " dis index hit rank\n", 321 | "TUBG1 2.56549 163 100 1" 322 | ] 323 | }, 324 | "execution_count": 10, 325 | "metadata": {}, 326 | "output_type": "execute_result" 327 | } 328 | ], 329 | "source": [ 330 | "# if permutation test\n", 331 | "\n", 332 | "null = sensei.pmt(data_ko, n=100, by=\"KL\")\n", 333 | "res = utils.get_generank(data_wt, dis, null,)\n", 334 | "# save_significant_as = 'gene_list_example')\n", 335 | "res" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "id": "6a349662", 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [] 345 | } 346 | ], 347 | "metadata": { 348 | "kernelspec": { 349 | "display_name": "Python 3 (ipykernel)", 350 | "language": "python", 351 | "name": "python3" 352 | }, 353 | "language_info": { 354 | "codemirror_mode": { 355 | "name": "ipython", 356 | "version": 3 357 | }, 358 | "file_extension": ".py", 359 | "mimetype": "text/x-python", 360 | "name": "python", 361 | "nbconvert_exporter": "python", 362 | "pygments_lexer": "ipython3", 363 | "version": "3.9.13" 364 | } 365 | }, 366 | "nbformat": 4, 367 | "nbformat_minor": 5 368 | } 369 | -------------------------------------------------------------------------------- /notebook/SERGIO.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | # import matplotlib.pyplot as plt 5 | # import seaborn as sns 6 | # from scipy import stats 7 | # from PIL import Image 8 | # import io 9 | import os 10 | from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, average_precision_score 11 | from sklearn.covariance import GraphicalLasso, GraphicalLassoCV 12 | # from time import time 13 | 14 | # import matplotlib as mpl 15 | # mpl.rcParams['interactive'] == False 16 | # mpl.rcParams['lines.linewidth'] = 2 17 | # mpl.rcParams['lines.linestyle'] = '--' 18 | # mpl.rcParams['axes.titlesize'] = 24 19 | # mpl.rcParams['axes.labelsize'] = 16 20 | # mpl.rcParams['lines.markersize'] = 2 21 | # mpl.rcParams['xtick.labelsize'] = 16 22 | # mpl.rcParams['ytick.labelsize'] = 16 23 | # mpl.rcParams['figure.dpi'] = 80 24 | # mpl.rcParams['legend.fontsize'] = 12 25 | # markers = [".",",","o","v","^","<",">"] 26 | # linestyles = ['-', '--', '-.', ':', 'dashed', 'dashdot', 'dotted'] 27 | 28 | import GenKI as gk 29 | from GenKI.preprocesing import build_adata 30 | from GenKI.train import VGAE_trainer 31 | from GenKI import utils 32 | from scTenifold import scTenifoldKnk 33 | 34 | N = 10 35 | 36 | def main(args): 37 | exp = pd.read_csv(f"SERGIO/De-noised_{args.data}/simulated_noNoise_{args.data_id}.csv", index_col=0) 38 | gt_grn_df = pd.read_csv(f"SERGIO/De-noised_{args.data}/gt_GRN.csv", header=None) 39 | inds = list(zip(gt_grn_df[0], gt_grn_df[1])) 40 | gt_grn = np.zeros((len(exp), len(exp))).astype(int) 41 | for ind in inds: 42 | gt_grn[ind] = 1 43 | 44 | genes = np.arange(len(exp)).astype(str) 45 | genes = np.char.add(["g"]*len(genes), genes) 46 | idx = gt_grn_df[0].value_counts().index.to_numpy() 47 | KO_genes = genes[idx] # gene list to be KO 48 | print(KO_genes[:N]) 49 | 50 | ada = sc.AnnData(exp.values.T) 51 | ada.var_names = genes 52 | 53 | corr_mat = np.corrcoef(ada.X.T) # corr 54 | cov = GraphicalLassoCV(alphas=4).fit(ada.X) # GGM 55 | ggm = GraphicalLasso(alpha=cov.alpha_, 56 | max_iter=100, 57 | assume_centered=False) 58 | res = ggm.fit(ada.X) 59 | pre_ = np.around(res.precision_, decimals=6) 60 | 61 | # GenKI 62 | ada_WT = 
build_adata(ada, scale_data = True, uppercase=False) 63 | 64 | for target_gene in KO_genes[:N]: 65 | data_wrapper = gk.DataLoader(ada_WT, 66 | target_gene = [target_gene], 67 | target_cell = None, 68 | # obs_label = None, 69 | GRN_file_dir = 'GRNs', 70 | rebuild_GRN = False, 71 | pcNet_name = f'pcNet_{args.data}_{args.data_id}_man', # build network 72 | cutoff = 85, 73 | verbose = False, 74 | ) 75 | ko_idx = data_wrapper([target_gene])[0] 76 | labels = np.zeros(len(exp)).astype(int) 77 | labels_inds = gt_grn_df.loc[gt_grn_df[0]==ko_idx, 1].to_numpy().astype(int) 78 | labels[labels_inds] = 1 79 | data = data_wrapper.load_data() 80 | data_KO = data_wrapper.load_kodata() 81 | 82 | hyperparams = {"epochs": 75, 83 | "lr": 7e-4, 84 | "beta": 1e-4, 85 | "seed": 8096} 86 | log_dir=None 87 | 88 | sensei = VGAE_trainer(data, 89 | epochs=hyperparams["epochs"], 90 | lr=hyperparams["lr"], 91 | log_dir=log_dir, 92 | beta=hyperparams["beta"], 93 | seed=hyperparams["seed"], 94 | ) 95 | sensei.train() 96 | z_mu, z_std = sensei.get_latent_vars(data) 97 | z_mu_KO, z_std_KO = sensei.get_latent_vars(data_KO) 98 | dis = gk.utils.get_distance(z_mu_KO, z_std_KO, z_mu, z_std, by = 'KL') 99 | res = utils.get_generank(data, dis, rank=False) 100 | scores = res["dis"].to_numpy() 101 | fpr, tpr, thres = roc_curve(labels, scores) 102 | roc_auc = roc_auc_score(labels, scores) 103 | print("AUROC:", roc_auc) 104 | precision, recall, _ = precision_recall_curve(labels, scores) 105 | ap = average_precision_score(labels, scores) 106 | print("AP:", ap) 107 | 108 | # junk classifier 109 | scores_junk = np.random.uniform(0,1,len(exp)) 110 | fpr, tpr, thres = roc_curve(labels, scores_junk) 111 | roc_auc_junk = roc_auc_score(labels, scores_junk) 112 | print("AUROC_baseline:", roc_auc_junk) 113 | precision, recall, _ = precision_recall_curve(labels, scores_junk) 114 | ap_junk = average_precision_score(labels, scores_junk) 115 | print("AP_baseline:", ap_junk) 116 | 117 | # corr 118 | score_corr = corr_mat[ko_idx] 119 | fpr, tpr, thres = roc_curve(labels, score_corr) 120 | roc_auc_corr = roc_auc_score(labels, score_corr) 121 | print("AUROC_corr:", roc_auc_corr) 122 | precision, recall, _ = precision_recall_curve(labels, score_corr) 123 | ap_corr = average_precision_score(labels, score_corr) 124 | print("AP_corr:", ap_corr) 125 | 126 | # GGM 127 | scores_ggm = pre_[ko_idx] 128 | fpr, tpr, thres = roc_curve(labels, scores_ggm) 129 | roc_auc_ggm = roc_auc_score(labels, scores_ggm) 130 | print("AUROC_ggm:", roc_auc_ggm) 131 | precision, recall, _ = precision_recall_curve(labels, scores_ggm) 132 | ap_ggm = average_precision_score(labels, scores_ggm) 133 | print("AP_ggm:", ap_ggm) 134 | 135 | # Knk 136 | exp.index = genes 137 | # sct = scTenifoldKnk(data=exp, 138 | # ko_method="default", 139 | # ko_genes=[target_gene], # the gene you wants to knock out 140 | # qc_kws={"min_lib_size": 1, "min_percent": 0.001}, 141 | # ) 142 | # result = sct.build() 143 | 144 | knk = scTenifoldKnk(data=exp, 145 | qc_kws={"min_lib_size": 1, "min_percent": 0.001}, 146 | ) 147 | knk.run_step("qc") 148 | knk.run_step("nc", n_cpus=1) 149 | knk.run_step("td") 150 | knk.run_step("ko", ko_genes=[target_gene], ko_method="default") 151 | knk.run_step("ma") 152 | knk.run_step("dr", sorted_by="adjusted p-value") 153 | result = knk.d_regulation 154 | 155 | knk_score = dict(zip(result["Gene"], result["FC"])) 156 | res["temp"] = res.index 157 | res["Knk"] = res["temp"].map(knk_score) 158 | del res["temp"] 159 | scores_knk = res["Knk"].to_numpy() 160 | fpr_knk, tpr_knk, 
thres_knk = roc_curve(labels, scores_knk) 161 | roc_auc_knk = roc_auc_score(labels, scores_knk) 162 | print("AUROC_knk:", roc_auc_knk) 163 | precision_knk, recall_knk, _ = precision_recall_curve(labels, scores_knk) 164 | ap_knk = average_precision_score(labels, scores_knk) 165 | print("AP_knk:", ap_knk) 166 | 167 | try: 168 | f = open(os.path.join("result", f'{args.out}.txt'), 'r') 169 | except IOError: 170 | f = open(os.path.join("result", f'{args.out}.txt'), 'w') 171 | f.writelines("file,file_id,KO_gene,ROC,ROC_Knk,ROC_junk,ROC_corr,ROC_ggm,"\ 172 | "AP,AP_Knk,AP_junk,AP_corr,AP_ggm\n") 173 | finally: 174 | f = open(os.path.join("result", f'{args.out}.txt'), 'a') 175 | f.writelines(f"{args.data},{args.data_id},{target_gene},"\ 176 | f"{roc_auc:.4f},{roc_auc_knk:.4f},{roc_auc_junk:.4f},{roc_auc_corr:.4f},{roc_auc_ggm:.4f},"\ 177 | f"{ap:.4f},{ap_knk:.4f},{ap_junk:.4f},{ap_corr:.4f},{ap_ggm:.4f}"\ 178 | "\n") 179 | f.close() 180 | 181 | 182 | if __name__ == '__main__': 183 | import argparse 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument('--data', type = str, default = '100G_9T_300cPerT_4_DS1') 186 | parser.add_argument('--data_id', type = int, default = 0) 187 | parser.add_argument('-O', '--out', default = 'SERGIO_trials') 188 | 189 | args = parser.parse_args() 190 | main(args) 191 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anndata==0.8.0 2 | matplotlib~=3.5.1 3 | numpy>=1.21.6 4 | pandas~=1.4.2 5 | ray>=1.11.0 6 | scanpy==1.9.1 7 | scipy~=1.8.0 8 | statsmodels~=0.13.2 9 | scikit_learn>=1.0.2 10 | torch==1.11.0 11 | tqdm~=4.64.0 12 | -------------------------------------------------------------------------------- /run_eva.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # trap break INT 4 | # for noise in 1 2 3 4 5 5 | # do 6 | # for run in `seq 30 $max` 7 | # do 8 | # echo "noise: $noise, run: $run" 9 | # python -m GenKI.train --dir data --train_out train_X_$noise -X $noise 10 | # echo -e "completed\n" 11 | # done 12 | # done 13 | 14 | 15 | # for noise in 1 2 3 4 5 16 | # do 17 | # for run in `seq 30 $max` 18 | # do 19 | # echo "noise: $noise, run: $run" 20 | # python -m GenKI.train --dir data --train_out train_E_$noise -E $noise 21 | # echo -e "completed\n" 22 | # done 23 | # done 24 | # trap - INT 25 | 26 | 27 | for noise in 0.1 0.3 0.5 0.7 0.9 28 | do 29 | for run in `seq 10 $max` 30 | do 31 | echo "dropout: $noise, run: $run" 32 | python -m GenKI.train --dir data --train_out train_XO_$noise -XO $noise 33 | echo -e "completed\n" 34 | done 35 | done 36 | trap - INT 37 | 38 | 39 | # trap break INT 40 | # for run in `seq 30 $max` 41 | # do 42 | # echo "run: $run" 43 | # python -m GenKI.train --train_out train_log --do_test --generank_out genelist --r2_out r2_score 44 | # echo -e "completed\n" 45 | # done 46 | # trap - INT 47 | 48 | 49 | # trap break INT 50 | # for cutoff in 30 55 70 95 51 | # do 52 | # trap break INT 53 | # for run in `seq 30 $max` 54 | # do 55 | # echo "run: $run" 56 | # python -m GenKI.train --train_out train_log --dir data_$cutoff 57 | # echo -e "completed\n" 58 | # done 59 | # done 60 | # trap - INT 61 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from setuptools import setup, find_packages 3 | 4 | HERE = 
pathlib.Path(__file__).parent 5 | README = (HERE / "README.md").read_text() 6 | DESCRIPTION = "GenKI" 7 | PACKAGES = find_packages(exclude=("tests*",)) 8 | exec(open('GenKI/version.py').read()) # defines __version__ 9 | 10 | INSTALL_REQUIRES = [ 11 | "anndata==0.8.0", 12 | "matplotlib~=3.5.1", 13 | "numpy>=1.21.6", 14 | "pandas~=1.4.2", 15 | "ray>=1.11.0", 16 | "scanpy==1.9.1", 17 | "scipy~=1.8.0", 18 | "statsmodels~=0.13.2", 19 | "scikit_learn>=1.0.2", 20 | # "torch==1.11.0", 21 | "tqdm~=4.64.0", 22 | ] 23 | 24 | setup( 25 | name="GenKI", 26 | version=__version__, 27 | description=DESCRIPTION, 28 | long_description=README, 29 | long_description_content_type="text/markdown", 30 | url="https://github.com/yjgeno/GenKI", 31 | author="Yongjian Yang, TAMU", 32 | author_email="yjyang027@tamu.edu", 33 | license="MIT", 34 | keywords=[ 35 | "neural network", 36 | "graph neural network", 37 | "variational graph neural network", 38 | "computational-biology", 39 | "single-cell", 40 | "gene knock-out", 41 | "gene regulatory network", 42 | ], 43 | classifiers=[ 44 | "License :: OSI Approved :: MIT License", 45 | "Intended Audience :: Science/Research", 46 | "Topic :: Scientific/Engineering :: Bio-Informatics", 47 | "Programming Language :: Python :: 3", 48 | ], 49 | packages=PACKAGES, 50 | include_package_data=True, # MANIFEST 51 | install_requires=INSTALL_REQUIRES, 52 | ) --------------------------------------------------------------------------------