├── LICENSE ├── PFC tutorial.ipynb ├── README.md ├── fasthigashi ├── FastHigashi_Wrapper.py ├── Fast_process.py ├── __init__.py ├── parafac2_intergrative.py ├── parafac_integrative.py ├── partial_rwr.py ├── preprocessing.py ├── project2orthogonal.py ├── sparse_for_schic.py └── util.py ├── figs ├── fig1.png ├── higashi_cellsystems.png └── higashi_title.png └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ma Lab at CMU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast-Higashi: Ultrafast and interpretable single-cell 3D genome analysis 2 | [https://www.cell.com/cell-systems/fulltext/S2405-4712(22)00395-7](https://www.cell.com/cell-systems/fulltext/S2405-4712(22)00395-7) 3 | 4 | Fast-Higashi is an interpretable model that takes single-cell Hi-C (scHi-C) contact maps as input and jointly infers cell embeddings as well as meta-interactions. 5 | ![figs/fig1.png](https://github.com/ma-compbio/Fast-Higashi/blob/main/figs/fig1.png) 6 | # Installation 7 | 8 | We now have Fast-Higashi on conda! 9 | 10 | Do 11 | `conda install -c ruochiz fasthigashi` 12 | or 13 | `mamba install -c ruochiz fasthigashi` 14 | 15 | After that install the latest pytorch with corresponding CUDA support. Check https://pytorch.org for details. Note that fasthigashi won't check if you have pytorch installed. So, the user would have to install the correct pytorch version individually. 16 | 17 | ```{bash} 18 | git clone https://github.com/ma-compbio/Fast-Higashi/ 19 | cd Fast-Higashi 20 | python setup.py install 21 | ``` 22 | 23 | It is recommended to have pytorch installed (with CUDA support when applicable) before installing higashi. 24 | 25 | # Documentation 26 | The input format would be exactly the same as the Higashi software. 27 | Detailed documentation will be updated here at the [Higashi wiki](https://github.com/ma-compbio/Higashi/wiki/Fast-Higashi-Usage) 28 | 29 | # Tutorial 30 | - [Lee et al. (sn-m3c-seq on PFC)](https://github.com/ma-compbio/Fast-Higashi/blob/main/PFC%20tutorial.ipynb) 31 | 32 | # Cite 33 | 34 | Cite our paper by 35 | 36 | ``` 37 | @article {Zhang2022fast, 38 | author = {Zhang, Ruochi and Zhou, Tianming and Ma, Jian}, 39 | title = {Ultrafast and interpretable single-cell 3D genome analysis with Fast-Higashi}, 40 | year = {2022}, 41 | doi = {10.1016/j.cels.2022.09.004}, 42 | journal={Cell systems}, 43 | volume={13}, 44 | number={10}, 45 | pages={798--807}, 46 | year={2022}, 47 | publisher={Elsevier} 48 | } 49 | ``` 50 | 51 | ![figs/Overview.png](https://github.com/ma-compbio/Fast-Higashi/blob/main/figs/higashi_cellsystems.png) 52 | 53 | 54 | 55 | # Contact 56 | 57 | Please contact zhangruo@broadinstitute.org or raise an issue in the github repo with any questions about installation or usage. 58 | -------------------------------------------------------------------------------- /fasthigashi/FastHigashi_Wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse, os, gc, pickle, sys 3 | from pathlib import Path 4 | from tqdm.auto import tqdm, trange 5 | import math, time, h5py 6 | import numpy as np 7 | import pandas as pd 8 | from scipy.sparse import coo_matrix, csr_matrix 9 | from sklearn.preprocessing import normalize 10 | import multiprocessing as mpl 11 | 12 | try: 13 | from .parafac2_intergrative import Fast_Higashi_core 14 | from .preprocessing import calc_bulk, filter_bin, normalize_per_cell, normalize_by_coverage, Clip, normalize_per_batch 15 | from .sparse_for_schic import Sparse, Chrom_Dataset 16 | except: 17 | try: 18 | from parafac2_intergrative import Fast_Higashi_core 19 | from preprocessing import calc_bulk, filter_bin, normalize_per_cell, normalize_by_coverage, Clip, normalize_per_batch 20 | from sparse_for_schic import Sparse, Chrom_Dataset 21 | except: 22 | raise EOFError 23 | 24 | CPU_per_GPU = 4 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description="Higashi main program") 27 | parser.add_argument('-c', '--config', type=Path, default=Path("../config_dir/config_ramani.JSON")) 28 | parser.add_argument('--path2input_cache', type=Path, default=None) 29 | parser.add_argument('--path2result_dir', type=Path, default=None) 30 | parser.add_argument('--rank', type=int, default=256) 31 | parser.add_argument('--size', type=int, default=15) 32 | parser.add_argument('--size_func', type=str, default='scale') 33 | parser.add_argument('--off_diag', type=int, default=100) 34 | parser.add_argument('--fac_size', type=eval, default=1) 35 | parser.add_argument('--share_factors', type=eval, default=['shared', 'shared', 'shared']) 36 | parser.add_argument('--l2reg', type=float, default=10) 37 | parser.add_argument('--do_conv', action='store_true', default=False) 38 | parser.add_argument('--do_rwr', action='store_true', default=False) 39 | parser.add_argument('--do_col', action='store_true', default=False) 40 | parser.add_argument('--no_col', action='store_true', default=False) 41 | parser.add_argument('--extra', type=str, default="") 42 | parser.add_argument('--cache_extra', type=str, default="") 43 | parser.add_argument('--filter', action='store_true', default=False) 44 | parser.add_argument('--batch_norm', action='store_true', default=False) 45 | parser.add_argument('--tol', type=float, default=2e-5) 46 | 47 | return parser.parse_args() 48 | 49 | 50 | def get_config(config_path = "./config.jSON"): 51 | import json 52 | c = open(config_path,"r") 53 | return json.load(c) 54 | 55 | 56 | def get_free_gpu(num=1): 57 | # os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > ./tmp') 58 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Total > ./tmp1') 59 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Used > ./tmp2') 60 | memory_all = [int(x.split()[2]) for x in open('tmp1', 'r').readlines()] 61 | memory_used = [int(x.split()[2]) for x in open('tmp2', 'r').readlines()] 62 | memory_available = [m1-m2 for m1,m2 in zip(memory_all, memory_used)] 63 | if len(memory_available) > 0: 64 | max_mem = np.max(memory_available) 65 | ids = np.where(memory_available >= max_mem-1000)[0] 66 | chosen_id = int(np.random.choice(ids, 1)[0]) 67 | print("setting to gpu:%d" % chosen_id, "available memory =", max_mem, "MB") 68 | sys.stdout.flush() 69 | torch.cuda.set_device(chosen_id) 70 | return torch.device("cuda:%d" % chosen_id), chosen_id, max_mem * 1000000 71 | else: 72 | print("running on cpu device then") 73 | import psutil 74 | mem = psutil.virtual_memory().available 75 | return None, None, mem 76 | 77 | # def get_memory_free(gpu_index): 78 | # try: 79 | # from py3nvml import py3nvml 80 | # print ("gpu mem") 81 | # py3nvml.nvmlInit() 82 | # handle = py3nvml.nvmlDeviceGetHandleByIndex(int(gpu_index)) 83 | # mem_info = py3nvml.nvmlDeviceGetMemoryInfo(handle) 84 | # print (mem_info) 85 | # return mem_info.free 86 | # except: 87 | # print ("Are you running on CPU devices? If not, check if you have py3nvml installed") 88 | # print ("Otherwise Fast-Higashi would have an incorrect estimation of the gpu memory") 89 | # import psutil 90 | # return psutil.virtual_memory().available 91 | 92 | def parse_embedding(project_list, fac, dim=None): 93 | if dim is None: 94 | dim = fac.shape[1] 95 | else: 96 | dim = min(dim, fac.shape[1]) 97 | embedding_list = [] 98 | for p in project_list: 99 | if type(p).__name__ == 'Tensor': p = p.detach().cpu().numpy() 100 | p = p / np.linalg.norm(p, axis=0, keepdims=True) 101 | embed = fac @ p 102 | embedding_list.append(embed) 103 | 104 | embedding = np.concatenate(embedding_list, axis=1) 105 | from sklearn.decomposition import TruncatedSVD 106 | model = TruncatedSVD(n_components=dim) 107 | embedding = model.fit_transform(embedding) 108 | return embedding 109 | 110 | 111 | 112 | 113 | 114 | 115 | class FastHigashi(): 116 | def __init__(self, config_path, 117 | path2input_cache, 118 | path2result_dir, 119 | off_diag, 120 | filter, 121 | do_conv, 122 | do_rwr, 123 | do_col, 124 | no_col): 125 | super().__init__() 126 | self.off_diag = off_diag 127 | self.filter = filter 128 | self.do_conv = do_conv 129 | self.do_rwr = do_rwr 130 | self.do_col = do_col 131 | self.no_col = no_col 132 | 133 | self.config_path = config_path 134 | self.config = get_config(config_path) 135 | 136 | self.chrom_list = self.config['chrom_list'] 137 | self.temp_dir = self.config['temp_dir'] 138 | self.data_dir = self.config['data_dir'] 139 | self.fh_resolutions = self.config['resolution_fh'] 140 | self.embedding_storage = None 141 | self.model = None 142 | 143 | if path2input_cache is None: 144 | path2input_cache = self.temp_dir 145 | 146 | if path2result_dir is None: 147 | path2result_dir = self.temp_dir 148 | 149 | self.path2input_cache = path2input_cache 150 | if not os.path.exists(path2input_cache): 151 | os.mkdir(path2input_cache) 152 | 153 | self.path2result_dir = path2result_dir 154 | if not os.path.exists(path2result_dir): 155 | os.mkdir(path2result_dir) 156 | 157 | _, self.gpu_id, self.avail_mem = get_free_gpu() 158 | if torch.cuda.is_available(): 159 | self.device = 'cuda' 160 | torch.set_num_threads(CPU_per_GPU) 161 | else: 162 | self.device = 'cpu' 163 | cpu_count = mpl.cpu_count() 164 | torch.set_num_threads(max(cpu_count - 2, 1)) 165 | def fast_process_data(self): 166 | try: 167 | from .Fast_process import create_dir, generate_chrom_start_end, extract_table 168 | except: 169 | from Fast_process import create_dir, generate_chrom_start_end, extract_table 170 | 171 | create_dir(self.config) 172 | generate_chrom_start_end(self.config) 173 | extract_table(self.config) 174 | print ("fast process finishes") 175 | 176 | def preprocess_meta(self): 177 | data_dir = self.config['data_dir'] 178 | if os.path.isfile(os.path.join(self.path2input_cache, "qc.npy")): 179 | qc = np.load(os.path.join(self.path2input_cache, "qc.npy")) 180 | readcount = np.load(os.path.join(self.path2input_cache, "read_count_all.npy")) 181 | else: 182 | qc, readcount = self.get_qc() 183 | good_qc = np.where(qc > 0)[0] 184 | bad_qc = np.where(qc <= 0)[0] 185 | # Always put good quality cells before bad quality cells 186 | reorder = np.concatenate([np.sort(good_qc), 187 | np.sort(bad_qc)], axis=0) 188 | 189 | with open(os.path.join(data_dir, "label_info.pickle"), "rb") as f: 190 | label_info = pd.DataFrame(pickle.load(f)) 191 | if len(label_info) == 0: 192 | label_info = pd.DataFrame(np.ones(len(readcount)), columns=['placeholder']) 193 | # print (label_info) 194 | try: 195 | plot_label = self.config['plot_label'] 196 | except: 197 | plot_label = [] 198 | 199 | np.save(os.path.join(self.path2input_cache, "reorder.npy"), reorder) 200 | label_info = label_info.iloc[reorder].reset_index() 201 | sig_list = label_info[plot_label] 202 | 203 | gc.collect() 204 | 205 | if "batch_id" in self.config: 206 | self.batch_id = np.asarray(label_info[self.config['batch_id']]) 207 | 208 | np.save(os.path.join(self.path2input_cache, "qc.npy"), qc) 209 | np.save(os.path.join(self.path2input_cache, "read_count_all.npy"), readcount) 210 | 211 | return label_info, reorder, sig_list, readcount, qc 212 | 213 | def sum_sparse_by_batch_id(self, m_list): 214 | batch = self.batch_id 215 | avail_batch = np.unique(batch) 216 | return_dict = {b: np.zeros(m_list[0].shape) for b in avail_batch} 217 | for a, b in zip(m_list, batch): 218 | return_dict[b][a.row, a.col] += a.data 219 | return return_dict 220 | 221 | def pack_training_data_one_process(self, 222 | raw_dir, chrom, reorder, 223 | off_diag=None, 224 | fac_size=None, 225 | merge_fac_row=1, merge_fac_col=1, 226 | is_sym=True, 227 | filename_pattern='%s_sparse_adj.npy', 228 | force_shift=None, 229 | batch_norm=True, 230 | bar=None): 231 | 232 | filename = filename_pattern % chrom 233 | a = np.load(os.path.join(raw_dir, filename), allow_pickle=True)[reorder] 234 | # For data with blacklist, block those regions 235 | try: 236 | blacklist = np.load(os.path.join(self.temp_dir, "raw", "blacklist.npy"), allow_pickle=True).item() 237 | except: 238 | # print("no black list") 239 | blacklist = None 240 | 241 | if blacklist is not None: 242 | bl = blacklist[chrom] 243 | bl = bl[bl < a[0].shape[0]] 244 | print("num of bl", bl.shape) 245 | new_sparse_list = [] 246 | for m in a: 247 | m = m.astype('float32').toarray() 248 | m[bl, :] = 0.0 249 | m[:, bl] = 0.0 250 | new_sparse_list.append(csr_matrix(m)) 251 | a = new_sparse_list 252 | 253 | if fac_size is None: fac_size = 1 254 | matrix_list = [m.tocoo() for m in a] 255 | 256 | if merge_fac_row > 1 or merge_fac_col > 1: 257 | for m in matrix_list: 258 | m.col //= merge_fac_col 259 | m.row //= merge_fac_row 260 | 261 | m.sum_duplicates() 262 | m.resize(list(np.ceil(np.array(m.shape) / np.array([merge_fac_col, merge_fac_row])).astype('int'))) 263 | if bar is not None: 264 | bar.set_description("sparse mtx into tensors %s - filter off-diag" % chrom, refresh=True) 265 | for m in matrix_list: 266 | mask = np.abs(m.col - m.row) <= off_diag 267 | m.row = m.row[mask] 268 | m.col = m.col[mask] 269 | m.data = m.data[mask] 270 | 271 | if bar is not None: 272 | bar.set_description("sparse mtx into tensors %s - get bulk" % chrom, refresh=True) 273 | if "batch_id" in self.config: 274 | batch_bulk = self.sum_sparse_by_batch_id(matrix_list) 275 | bulk = 0 276 | for b in batch_bulk: bulk += batch_bulk[b] 277 | else: 278 | bulk = self.sum_sparse_coo(matrix_list) 279 | 280 | bin_id_mapping_row, num_bins_row, bin_id_mapping_col, num_bins_col, v_row, v_col = filter_bin(bulk=bulk / len(matrix_list), 281 | is_sym=is_sym) 282 | if "batch_id" in self.config: 283 | if batch_norm: 284 | matrix_list = normalize_per_batch( 285 | bulk=bulk, 286 | batch_bulk=batch_bulk, 287 | matrix_list=matrix_list, 288 | batch_id=self.batch_id, 289 | off_diag=off_diag+1 290 | ) 291 | if bar is not None: 292 | bar.set_description("sparse mtx into tensors %s - read count normalize" % chrom, refresh=True) 293 | matrix_list = normalize_per_cell( 294 | matrix_list, matrix_list_intra=matrix_list, bulk=None, 295 | per_cell_normalize_func=[ 296 | normalize_by_coverage, 297 | ], 298 | ) 299 | assert not any(np.isnan(m.data).any() for m in matrix_list) 300 | 301 | nnz = sum(m.nnz for m in matrix_list) 302 | indices = np.empty([3, nnz], dtype=np.int32) 303 | values = np.empty([nnz], dtype=np.float32) 304 | del nnz 305 | 306 | shape = (num_bins_row, num_bins_col) 307 | col_offset = None 308 | do_shift = False 309 | if bar is not None: 310 | bar.set_description("sparse mtx into tensors %s - mtx indices 2 tensor indices" % chrom, refresh=True) 311 | idx_nnz = 0 312 | for i, m in enumerate(matrix_list): 313 | if m.nnz == 0: continue 314 | row, col, data = bin_id_mapping_row[m.row], bin_id_mapping_col[m.col], m.data 315 | col_new = col - row 316 | idx = (row != -1) & (col != -1) #& (col_new >= -off_diag) & (col_new <= off_diag) 317 | # assert idx.sum() > 0 318 | row, col, data = row[idx], col_new[idx], data[idx] 319 | if isinstance(fac_size, int) and fac_size == 1: 320 | pass 321 | elif isinstance(fac_size, int) and fac_size > 1: 322 | col += fac_size // 2 323 | if fac_size % 2 == 0: col[col <= 0] -= 1 324 | col //= fac_size 325 | offset = col.min() 326 | col -= offset 327 | tmp = coo_matrix((data, (row, col)), shape=(row.max() + 1, col.max() + 1)) 328 | tmp.sum_duplicates() 329 | row, col, data = tmp.row, tmp.col + offset, tmp.data 330 | del tmp, offset 331 | else: 332 | col_sgn = np.sign(col) 333 | col = col_sgn * fac_size[np.abs(col)] 334 | 335 | if do_shift: 336 | col += col_offset 337 | else: 338 | col += row 339 | 340 | nnz = len(data) 341 | ii = slice(idx_nnz, idx_nnz + nnz) 342 | indices[0, ii] = row 343 | indices[1, ii] = col 344 | indices[2, ii] = i 345 | values[ii] = data 346 | idx_nnz += nnz 347 | del nnz, ii 348 | 349 | indices = np.ascontiguousarray(indices[:, :idx_nnz]) 350 | values = np.ascontiguousarray(values[:idx_nnz]) 351 | shape = shape + (len(matrix_list),) 352 | # print(shape, do_shift) 353 | assert indices.min() >= 0 354 | assert (indices.max(1) < shape).all() 355 | gc.collect() 356 | if bar is not None: 357 | bar.set_description("sparse mtx into tensors %s - final touches" % chrom, refresh=True) 358 | values = np.log1p(values) 359 | if is_sym: 360 | s = 15 361 | else: 362 | s = 15 363 | mean_, std_ = np.mean(values), np.std(values) 364 | values = np.clip(values, a_min=None, a_max=mean_ + s * std_) 365 | 366 | return indices, values, shape 367 | 368 | def preprocess_contact_map(self, config, reorder, path2input_cache, batch_norm, key_fn=lambda c: c, **kwargs): 369 | print(f'cache file = {path2input_cache}') 370 | do_cache = path2input_cache is not None 371 | 372 | if do_cache and os.path.exists(path2input_cache): 373 | print(f'loading cached input from {path2input_cache}') 374 | all_matrix = [] 375 | with open(path2input_cache, 'rb') as f: 376 | for chrom in self.chrom_list: 377 | all_matrix.append(pickle.load(f)) 378 | sys.stdout.flush() 379 | return all_matrix 380 | 381 | chrom_list = config['chrom_list'] 382 | 383 | if "batch_id" in self.config: 384 | if batch_norm: 385 | print ("will do per batch normalization") 386 | size_list = [] 387 | print(f'saving cached input to {path2input_cache}') 388 | bar = trange(len(chrom_list), desc='sparse mtx into tensors') 389 | with open(path2input_cache, 'wb') as f: 390 | for chrom in chrom_list: 391 | indices, values, shape = self.pack_training_data_one_process( 392 | raw_dir=Path(config['temp_dir']) / 'raw', chrom=chrom, reorder=reorder, 393 | batch_norm=batch_norm, bar=bar, 394 | **kwargs,) 395 | bar.set_description("sparse mtx into tensors %s - construct sparse" % chrom, refresh=True) 396 | obj = Sparse(indices, values, shape, copy=False) 397 | bar.set_description("sparse mtx into tensors %s - sort sparse indices" % chrom, refresh=True) 398 | obj.sort_indices() 399 | 400 | bar.update(1) 401 | bar.refresh() 402 | size_list.append(obj.shape[0]) 403 | # all_matrix.append(obj) 404 | pickle.dump(obj, f, protocol=4) 405 | sys.stdout.flush() 406 | all_matrix = [] 407 | with open(path2input_cache, 'rb') as f: 408 | for i in range(len(size_list)): 409 | all_matrix.append(pickle.load(f)) 410 | sys.stdout.flush() 411 | return all_matrix 412 | 413 | @staticmethod 414 | def sum_sparse(m): 415 | x = np.zeros(m[0].shape) 416 | for a in m: 417 | ri = np.repeat(np.arange(a.shape[0]), np.diff(a.indptr)) 418 | x[ri, a.indices] += a.data 419 | return x 420 | 421 | @staticmethod 422 | def sum_sparse_coo(m): 423 | x = np.zeros(m[0].shape) 424 | for a in m: 425 | x[a.row, a.col] += a.data 426 | return x 427 | 428 | def get_qc(self): 429 | temp_dir = self.config['temp_dir'] 430 | raw_dir = os.path.join(temp_dir, "raw") 431 | chrom_list = self.config['chrom_list'] 432 | mask = [] 433 | scale = int(1000000 / self.config['resolution']) 434 | read_count_all = 0 435 | for chrom in chrom_list: 436 | read_count = [] 437 | a = np.load(os.path.join(raw_dir, "%s_sparse_adj.npy" % chrom), allow_pickle=True) 438 | bulk = self.sum_sparse(a) 439 | cov = np.sum(bulk > 0, axis=-1) 440 | n_bin = np.sum(cov > 0.1 * cov.shape[0] * scale) 441 | 442 | mask_chrom = [] 443 | for m in a: 444 | mask_chrom.append((len(m.data) + np.sum(m.diagonal() > 0)) / 2) 445 | read_count.append(np.sum(m)) 446 | read_count = np.asarray(read_count) 447 | mask_chrom = np.array(mask_chrom).astype('float') 448 | if np.sum(mask_chrom > n_bin) > 0.5 * len(mask_chrom): 449 | mask_chrom = (mask_chrom > (n_bin)) #& (mask_chrom < np.quantile(mask_chrom, 0.85)) 450 | else: 451 | mask_chrom = (mask_chrom > np.quantile(mask_chrom, 0.5)) #& (mask_chrom < np.quantile(mask_chrom, 0.85)) 452 | 453 | mask.append(mask_chrom) 454 | read_count_all += np.asarray(read_count) 455 | kept = (np.sum(np.array(mask).astype('float'), axis=0) >= (len(chrom_list))).astype('float32') 456 | 457 | read_count_all = np.log1p(read_count_all) 458 | return kept, read_count_all 459 | 460 | def prep_dataset(self, meta_only=False, batch_norm=True): 461 | if self.device != 'cpu': 462 | torch.set_num_threads(max(mpl.cpu_count() - 2, 1)) 463 | 464 | self.label_info, reorder, self.sig_list, readcount, qc = self.preprocess_meta() 465 | self.reorder = reorder 466 | self.coverage_feats = readcount[reorder].reshape((-1, 1)) 467 | if meta_only: 468 | return 469 | good_qc_num = np.sum(qc > 0) 470 | print("total number of cells that pass qc check", good_qc_num, "bad", len(qc) - good_qc_num, "total:", len(qc)) 471 | tensor_list = [None] * len(self.fh_resolutions) * len(self.chrom_list) 472 | recommend_bs_cell = [] 473 | ct = 0 474 | 475 | 476 | for res in self.fh_resolutions: 477 | all_matrix = [] 478 | try: 479 | cache_extra = args.cache_extra 480 | except: 481 | cache_extra = "" 482 | path2input_cache_intra = os.path.join(self.path2input_cache, 'cache_intra_%d_offdiag_%d_%s.pkl' % ( 483 | res, self.off_diag, cache_extra)) 484 | all_matrix += self.preprocess_contact_map( 485 | self.config, reorder=reorder, path2input_cache=path2input_cache_intra, 486 | batch_norm=batch_norm, 487 | is_sym=True, 488 | off_diag=self.off_diag, 489 | fac_size=1, 490 | merge_fac_row=int(res / self.config['resolution']), merge_fac_col=int(res / self.config['resolution']), 491 | filename_pattern='%s_sparse_adj.npy', 492 | force_shift=False, 493 | ) 494 | 495 | size_list = [m.shape[0] for m in all_matrix] 496 | num_cell = all_matrix[-1].shape[-1] 497 | avail_mem = self.avail_mem 498 | # 4 because of float32 -> bytes, 499 | # 10 because of overhead & cache 500 | max_tensor_size = avail_mem / (4 * 12) 501 | recommend_bs_bin = min(max(int(15000000 / res), 128), 256) 502 | total_cell_num = all_matrix[-1].shape[-1] 503 | 504 | total_reads, total_possible = 0, 0 505 | bar = trange(len(self.chrom_list), desc='breaking into batches') 506 | for i, size in enumerate(size_list): 507 | n_batch = max(math.ceil(size / recommend_bs_bin), 1) 508 | if self.device == 'cpu': 509 | bs_bin_local = size 510 | bs_cell = num_cell 511 | else: 512 | bs_bin_local = math.ceil(size / n_batch) 513 | bs_cell = int(max_tensor_size / (bs_bin_local * (bs_bin_local + 2 * self.off_diag))) 514 | # print ("bs_cell", bs_cell) 515 | n_batch = int(math.ceil((good_qc_num if self.filter else num_cell) / bs_cell)) 516 | bs_cell = int(math.ceil((good_qc_num if self.filter else num_cell) / n_batch)) 517 | bs_cell = min(bs_cell, good_qc_num if self.filter else num_cell) 518 | # print ("bs_bin_local", bs_bin_local, size) 519 | recommend_bs_cell.append(bs_cell) 520 | try: 521 | total_reads += len(all_matrix[i].values) 522 | except: 523 | total_reads += torch.sum(all_matrix[i] > 0) 524 | total_possible += np.prod(all_matrix[i].shape) 525 | tensor_list[ct] = Chrom_Dataset( 526 | tensor=all_matrix[i], 527 | bs_bin=bs_bin_local, 528 | bs_cell=bs_cell, 529 | good_qc_num=good_qc_num if self.filter else -1, 530 | kind='hic', 531 | upper_sim=False, 532 | compact=True, 533 | flank=self.off_diag, 534 | chrom=self.chrom_list[i], 535 | resolution=res) 536 | ct += 1 537 | bar.update(1) 538 | 539 | bar.close() 540 | sparsity = total_reads / total_possible 541 | print("sparsity", sparsity) 542 | del all_matrix 543 | gc.collect() 544 | bar.close() 545 | if sparsity * (500000 / res) ** 2 <= 0.03: 546 | print("sparsity below threshold, automatically col_normalize") 547 | do_col = sparsity * (500000 / res) ** 2 <= 0.03 or self.do_col 548 | if self.no_col: 549 | do_col = False 550 | print("do_conv", self.do_conv, "do_rwr", self.do_rwr, "do_col", do_col) 551 | self.final_do_col = do_col 552 | if self.no_col and self.do_col: 553 | print("choose one between do col or no col!") 554 | raise EOFError 555 | 556 | 557 | print("recommend_bs_cell", recommend_bs_cell, "pinning memory") 558 | all_matrix = tensor_list 559 | for i in range(len(all_matrix)): 560 | all_matrix[i].pin_memory() 561 | gc.collect() 562 | shape_list = np.stack([mtx.shape[:-1] for mtx in all_matrix]) 563 | # print(shape_list, np.sum(shape_list[:, 0])) 564 | self.good_qc_num = good_qc_num 565 | self.all_matrix = all_matrix 566 | if self.device != 'cpu': 567 | torch.set_num_threads(CPU_per_GPU) 568 | 569 | def only_partial_rwr(self): 570 | try: 571 | from .partial_rwr import partial_rwr 572 | except: 573 | from partial_rwr import partial_rwr 574 | chrom_count = 0 575 | impute_result = h5py.File(os.path.join(self.path2result_dir, "impute_prwr.hdf5"), "w") 576 | for chrom_data in tqdm(self.all_matrix, desc="imputing"): 577 | group = impute_result.create_group(self.chrom_list[chrom_count]) 578 | group.create_dataset("shape", data=np.asarray([chrom_data.num_bin, chrom_data.num_bin])) 579 | for cell_batch_id in range(0, chrom_data.num_cell_batch): 580 | slice_cell = chrom_data.cell_slice_list[cell_batch_id] 581 | 582 | imputed_map = None 583 | 584 | 585 | for bin_batch_id in range(0, chrom_data.num_bin_batch): 586 | slice_local = chrom_data.local_bin_slice_list[bin_batch_id] 587 | slice_col = chrom_data.col_bin_slice_list[bin_batch_id] 588 | slice_row = chrom_data.bin_slice_list[bin_batch_id] 589 | 590 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 591 | save_context=dict(device=self.device), 592 | transpose=True, 593 | do_conv=False) 594 | 595 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 596 | if imputed_map is None: 597 | imputed_map = np.zeros((int(chrom_batch_cell_batch.shape[0]), chrom_data.num_bin, chrom_data.num_bin)) 598 | if kind == 'hic': 599 | chrom_batch_cell_batch, n_i = partial_rwr(chrom_batch_cell_batch, 600 | slice_start=slice_local.start, 601 | slice_end=slice_local.stop, 602 | do_conv=self.do_conv, 603 | do_rwr=self.do_rwr, 604 | do_col=False, 605 | bin_cov=torch.ones(1), 606 | return_rwr_iter=True, 607 | force_rwr_epochs=-1, 608 | final_transpose=False) 609 | 610 | imputed_map[:, slice_row, slice_col] = chrom_batch_cell_batch.detach().cpu().numpy() 611 | imputed_map = imputed_map + imputed_map.transpose(0, 2, 1) 612 | for i in range(len(imputed_map)): 613 | m = imputed_map[i] 614 | m = m - np.diag(np.diag(m) / 2) 615 | group.create_dataset(str(self.reorder[slice_cell][i]), data=m.astype('float32')) 616 | i 617 | for cell_batch_id in range(0, chrom_data.num_cell_batch_bad): 618 | slice_cell = chrom_data.cell_slice_list[cell_batch_id + chrom_data.num_cell_batch] 619 | imputed_map = None 620 | 621 | for bin_batch_id in range(0, chrom_data.num_bin_batch): 622 | slice_local = chrom_data.local_bin_slice_list[bin_batch_id] 623 | slice_col = chrom_data.col_bin_slice_list[bin_batch_id] 624 | slice_row = chrom_data.bin_slice_list[bin_batch_id] 625 | 626 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 627 | save_context=dict(device=self.device), 628 | transpose=True, 629 | do_conv=False, 630 | good_qc=False) 631 | 632 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 633 | if imputed_map is None: 634 | imputed_map = np.zeros((int(chrom_batch_cell_batch.shape[0]), chrom_data.num_bin, chrom_data.num_bin)) 635 | if kind == 'hic': 636 | chrom_batch_cell_batch, _ = partial_rwr(chrom_batch_cell_batch, 637 | slice_start=slice_local.start, 638 | slice_end=slice_local.stop, 639 | do_conv=self.do_conv, 640 | do_rwr=self.do_rwr, 641 | do_col=False, 642 | bin_cov=torch.ones(1), 643 | return_rwr_iter=True, 644 | force_rwr_epochs=-1, 645 | final_transpose=False) 646 | imputed_map[:, slice_row, slice_col] = chrom_batch_cell_batch.detach().cpu().numpy() 647 | imputed_map = imputed_map + imputed_map.transpose(0, 2, 1) 648 | for i in range(len(imputed_map)): 649 | m = imputed_map[i] 650 | m = m - np.diag(np.diag(m) / 2) 651 | group.create_dataset(str(self.reorder[slice_cell][i]), data=m.astype('float32')) 652 | 653 | 654 | chrom_count += 1 655 | impute_result.close() 656 | 657 | def run_model(self, dim1=.6, 658 | rank=256, 659 | n_iter_parafac=1, 660 | n_iter_max=None, 661 | tol=2e-5, 662 | extra="", 663 | run_init=True): 664 | self.rank = rank 665 | save_str = "dim1_%.1f_rank_%d_niterp_%d_%s" % (dim1, rank, n_iter_parafac, extra) 666 | self.save_str = save_str 667 | print(save_str) 668 | start = time.time() 669 | if self.model is None: 670 | self.model = Fast_Higashi_core(rank=rank, off_diag=self.off_diag, res_list=self.fh_resolutions).to(self.device) 671 | if n_iter_max is None: 672 | n_iter_max = int(self.good_qc_num / 15) 673 | 674 | result = self.model.fit_transform( 675 | self.all_matrix, 676 | size_ratio=dim1, 677 | n_iter_max=n_iter_max, 678 | n_iter_parafac=n_iter_parafac, 679 | do_conv=self.do_conv, 680 | do_rwr=self.do_rwr, 681 | do_col=self.final_do_col, 682 | tol=tol, 683 | gpu_id=self.gpu_id, 684 | run_init=run_init 685 | ) 686 | print("takes: %.2f s" % (time.time() - start)) 687 | weights_all, factors_all, p_list = result 688 | 689 | A_list, B_list, D_list, meta_embedding = factors_all 690 | 691 | self.meta_embedding = meta_embedding.detach().cpu().numpy() 692 | 693 | self.A_list = [A.detach().cpu().numpy() for A in A_list] 694 | self.B_list = [B.detach().cpu().numpy() for B in B_list] 695 | self.D_list = [D.detach().cpu().numpy() for D in D_list] 696 | self.p_list = [[p.detach().cpu().numpy() for p in temp] for temp in p_list] 697 | del factors_all, p_list, weights_all 698 | pickle.dump([self.A_list, self.B_list, self.D_list, self.meta_embedding, self.p_list], 699 | open(os.path.join(self.path2result_dir, "results_all%s.pkl" % save_str), "wb"), protocol=4) 700 | 701 | pickle.dump([self.meta_embedding, self.D_list], open(os.path.join(self.path2result_dir, "results%s.pkl" % save_str), "wb"), protocol=4) 702 | 703 | 704 | def load_model(self, dim1=.6, 705 | rank=256, 706 | n_iter_parafac=1, 707 | extra=""): 708 | save_str = "dim1_%.1f_rank_%d_niterp_%d_%s" % (dim1, rank, n_iter_parafac, extra) 709 | data = pickle.load(open(os.path.join(self.path2result_dir, "results_all%s.pkl" % save_str), "rb")) 710 | print ("model loaded") 711 | self.A_list, self.B_list, self.D_list, self.meta_embedding, self.p_list = data 712 | 713 | def check_same_score(self, knn_graph, batch_id): 714 | gather = np.zeros(len(batch_id)) 715 | same = (batch_id[knn_graph.row] == batch_id[knn_graph.col]).astype('float') 716 | uniq_dis, uniq_dis_index = np.unique(knn_graph.row, return_index=True) 717 | # skip first because it's always nothing 718 | v = np.split(same, uniq_dis_index)[1:] 719 | for i in range(len(v)): 720 | gather[uniq_dis[i]] = v[i].mean() 721 | return gather 722 | 723 | def eval_batch_mix(self, embed): 724 | from pynndescent import PyNNDescentTransformer 725 | import warnings 726 | from sklearn.decomposition import TruncatedSVD 727 | with warnings.catch_warnings(): 728 | warnings.simplefilter("ignore") 729 | if "batch_id" in self.config: 730 | batch_id = self.batch_id 731 | knn_graph = PyNNDescentTransformer(n_neighbors=50).fit_transform(embed) 732 | knn_graph.data = np.ones_like(knn_graph.data) 733 | # The number of 734 | knn_graph2nd = knn_graph @ knn_graph.T 735 | knn_graph2nd_embed = TruncatedSVD(n_components=embed.shape[-1]).fit_transform(knn_graph2nd) 736 | knn_graph2nd = PyNNDescentTransformer(n_neighbors=25).fit_transform(knn_graph2nd_embed) 737 | knn_graph = PyNNDescentTransformer(n_neighbors=25).fit_transform(embed) 738 | knn_graph = knn_graph.tocoo() 739 | knn_graph2nd = knn_graph2nd.tocoo() 740 | mix_score = self.check_same_score(knn_graph, batch_id) 741 | mix_score2 = self.check_same_score(knn_graph2nd, batch_id) 742 | 743 | return np.array([mix_score, mix_score2]) 744 | 745 | def restore_order_fun(self, x): 746 | new_x = np.zeros_like(x) 747 | new_x[self.reorder] = x 748 | return new_x 749 | 750 | def fetch_cell_embedding(self, final_dim=None, restore_order=False): 751 | print ("fetching embedding") 752 | 753 | final_dim = self.rank if final_dim is None else final_dim 754 | embedding_list = [] 755 | for p in self.D_list: 756 | if type(p).__name__ == 'Tensor': p = p.detach().cpu().numpy() 757 | p = p / np.linalg.norm(p, axis=0, keepdims=True) 758 | embed = self.meta_embedding @ p 759 | embedding_list.append(embed) 760 | 761 | embedding = np.concatenate(embedding_list, axis=1) 762 | from sklearn.preprocessing import quantile_transform 763 | c_v = quantile_transform(self.coverage_feats, n_quantiles=100) 764 | self.label_info['coverage_fh'] = c_v 765 | 766 | from sklearn.decomposition import TruncatedSVD 767 | model = TruncatedSVD(n_components=final_dim) 768 | embed = model.fit_transform(embedding) 769 | 770 | 771 | if restore_order: 772 | embedding = self.restore_order_fun(embedding) 773 | embed = self.restore_order_fun(embed) 774 | 775 | 776 | 777 | embed_l2 = normalize(embed) 778 | 779 | 780 | # embed_l2 = normalize(embed, axis=1) 781 | # embed_correct_l2 = normalize(embed_correct) 782 | # embed_correct2_l2 = normalize(embed_correct2) 783 | store = {'embed_all': embedding, 'embed_raw': embed, 'embed_l2_norm': embed_l2, 'restore_order':restore_order} 784 | self.embedding_storage = store 785 | self.correct_batch_linear('coverage_fh', False) 786 | # self.correct_batch_linear('coverage2', add_intercept_back=True) 787 | 788 | 789 | return store 790 | 791 | def calc_modularity(self, A, label, resolution=1, normalize=True): 792 | num_nodes = A.shape[0] 793 | import itertools 794 | label_a2i = dict(zip(set(label), itertools.count())) 795 | num_labels = len(label_a2i) 796 | if num_labels == 1: return 0. 797 | label = np.fromiter(map(label_a2i.get, label), dtype=int) 798 | A = A.tocoo() 799 | assert (A.col != A.row).all() # Multiplying diagonal values by 2 might works 800 | Asum = A.data.sum() 801 | assert Asum > 0 802 | score = A.data[label[A.row] == label[A.col]].sum() / Asum 803 | 804 | k = np.bincount(label[A.row], weights=A.data, minlength=num_labels) / Asum 805 | score -= k @ k * resolution 806 | 807 | if normalize: 808 | max_score = k @ (1 - k*resolution) 809 | score /= max_score 810 | 811 | return score 812 | 813 | 814 | 815 | def correct_batch_linear(self, var_to_regress_name, add_intercept_back=False): 816 | if self.embedding_storage is None: 817 | print ("Run fetch_cell_embedding() first!") 818 | return None 819 | if type(var_to_regress_name) is str: 820 | try: 821 | var_to_regress = np.array(self.label_info[var_to_regress_name]) 822 | if self.embedding_storage['restore_order']: 823 | var_to_regress = self.restore_order_fun(var_to_regress) 824 | # print (var_to_regress, var_to_regress.dtype,) 825 | if var_to_regress.dtype not in [np.dtype('float32'), np.dtype('float16'), np.dtype('float64')]: 826 | print ("not float var, one hot encoding") 827 | uniq = np.unique(var_to_regress) 828 | var_to_regress_new = np.zeros((len(var_to_regress), len(uniq))) 829 | for i, u_ in enumerate(uniq): 830 | var_to_regress_new[var_to_regress == u_, i] = 1 831 | var_to_regress = var_to_regress_new 832 | 833 | except: 834 | print ("var_to_regress not in label_info.pickle!") 835 | return None 836 | elif type(var_to_regress_name) is list: 837 | var_to_regress_all = [] 838 | for name in var_to_regress_name: 839 | try: 840 | var_to_regress = np.array(self.label_info[name]) 841 | if self.embedding_storage['restore_order']: 842 | var_to_regress = self.restore_order_fun(var_to_regress) 843 | print(var_to_regress, var_to_regress.dtype, ) 844 | if var_to_regress.dtype not in [np.dtype('float32'), np.dtype('float16'), np.dtype('float64')]: 845 | print("not float var, one hot encoding") 846 | uniq = np.unique(var_to_regress) 847 | var_to_regress_new = np.zeros((len(var_to_regress), len(uniq))) 848 | for i, u_ in enumerate(uniq): 849 | var_to_regress_new[var_to_regress == u_, i] = 1 850 | var_to_regress = var_to_regress_new 851 | if len(var_to_regress.shape) == 1: 852 | var_to_regress = var_to_regress.reshape((-1, 1)) 853 | var_to_regress_all.append(var_to_regress) 854 | except: 855 | print("var_to_regress %s not in label_info.pickle!" % name) 856 | return None 857 | var_to_regress = np.concatenate(var_to_regress_all, axis=-1) 858 | var_to_regress_name = "_".join(var_to_regress_name) 859 | else: 860 | print("var_to_regress must be a str or list of strs") 861 | return None 862 | 863 | if len(var_to_regress.shape) == 1: 864 | var_to_regress = var_to_regress.reshape((-1, 1)) 865 | # print (var_to_regress.shape) 866 | from sklearn.linear_model import LinearRegression 867 | model = LinearRegression() 868 | embedding = self.embedding_storage['embed_all'] 869 | embedding = embedding - model.fit(var_to_regress, embedding).predict(var_to_regress) 870 | if add_intercept_back: 871 | embedding = embedding + model.intercept_[None] 872 | from sklearn.decomposition import TruncatedSVD 873 | model = TruncatedSVD(n_components=self.embedding_storage['embed_raw'].shape[-1]) 874 | reduce = model.fit_transform(embedding) 875 | reduce_l2 = normalize(reduce) 876 | self.embedding_storage['embed_correct_%s' % var_to_regress_name] = reduce 877 | self.embedding_storage['embed_l2_norm_correct_%s' % var_to_regress_name] = reduce_l2 878 | return self.embedding_storage 879 | 880 | if __name__ == '__main__': 881 | print(time.ctime()) 882 | # parse all arguments 883 | args = parse_args() 884 | 885 | # initialize the model 886 | wrapper = FastHigashi(config_path=args.config, 887 | path2input_cache=args.path2input_cache, 888 | path2result_dir=args.path2result_dir, 889 | off_diag=args.off_diag, 890 | filter=args.filter, 891 | do_conv=args.do_conv, 892 | do_rwr=args.do_rwr, 893 | do_col=args.do_col, 894 | no_col=args.no_col) 895 | if not os.path.exists(os.path.join(wrapper.temp_dir, "raw", "%s_sparse_adj.npy" % wrapper.chrom_list[0])): 896 | start = time.time() 897 | wrapper.fast_process_data() 898 | print("contact pairs to sparse mtx takes: %.2f s" % (time.time() - start)) 899 | 900 | # packing data from sparse matrices to 901 | start = time.time() 902 | wrapper.prep_dataset(batch_norm=args.batch_norm) 903 | print("packing sparse mtx takes: %.2f s" % (time.time() - start)) 904 | 905 | wrapper.run_model(extra=args.extra, 906 | rank=args.rank, 907 | n_iter_parafac=1, 908 | tol=args.tol) 909 | 910 | # loading existing trained models 911 | wrapper.load_model(extra=args.extra, 912 | rank=args.rank, 913 | n_iter_parafac=1 914 | ) 915 | 916 | # only do partial_rwr for analysis purpose 917 | # wrapper.only_partial_rwr() 918 | 919 | # getting embedding 920 | embed = wrapper.fetch_cell_embedding(final_dim=args.rank, 921 | restore_order=False) 922 | 923 | 924 | 925 | 926 | # prefer stands for the embeddings that the algorithm think might perform the best 927 | print (embed.keys()) 928 | 929 | 930 | 931 | # ## internal uses... code not uploaded 932 | # wrapper.correct_batch_linear("Donor") 933 | # wrapper.correct_batch_linear(["Donor", "Region"]) 934 | # try: 935 | # from .evaluation import evaluate_combine 936 | # except: 937 | # from evaluation import evaluate_combine 938 | # # evaluate_combine(wrapper.config, wrapper.sig_list, [slice(None)], embed['prefer'], project=None, extra=args.cache_extra+"_"+args.extra+"prefer", save_dir=wrapper.path2result_dir, with_CCA=False, label_info=wrapper.label_info, 939 | # # coverage_feats=None, log=None, number_only=False, save_fmt='png', linear_corr=False) 940 | # evaluate_combine(wrapper.config, wrapper.sig_list, [slice(None)], embed['embed_l2_norm'], project=None, 941 | # extra=args.cache_extra + "_" + args.extra + "l2_norm_raw", save_dir=wrapper.path2result_dir, 942 | # with_CCA=False, label_info=wrapper.label_info, 943 | # coverage_feats=None, log=None, number_only=False, save_fmt='png', linear_corr=False) 944 | # evaluate_combine(wrapper.config, wrapper.sig_list, [slice(None)], embed['embed_l2_norm_correct_coverage_fh'], project=None, 945 | # extra=args.cache_extra + "_" + args.extra + "l2_norm_linear", save_dir=wrapper.path2result_dir, 946 | # with_CCA=False, label_info=wrapper.label_info, 947 | # coverage_feats=None, log=None, number_only=False, save_fmt='png', linear_corr=False) 948 | # 949 | # evaluate_combine(wrapper.config, wrapper.sig_list, [slice(None)], embed['embed_l2_norm_correct_Donor'], 950 | # project=None, 951 | # extra=args.cache_extra + "_" + args.extra + "l2_norm_donor", save_dir=wrapper.path2result_dir, 952 | # with_CCA=False, label_info=wrapper.label_info, 953 | # coverage_feats=None, log=None, number_only=False, save_fmt='png', linear_corr=False) 954 | # 955 | # evaluate_combine(wrapper.config, wrapper.sig_list, [slice(None)], embed['embed_l2_norm_correct_Donor_Region'], 956 | # project=None, 957 | # extra=args.cache_extra + "_" + args.extra + "l2_norm_donor_region", save_dir=wrapper.path2result_dir, 958 | # with_CCA=False, label_info=wrapper.label_info, 959 | # coverage_feats=None, log=None, number_only=False, save_fmt='png', linear_corr=False) 960 | -------------------------------------------------------------------------------- /fasthigashi/Fast_process.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm.auto import tqdm, trange 8 | from scipy.sparse import csr_matrix, vstack, SparseEfficiencyWarning, diags, \ 9 | hstack 10 | from concurrent.futures import ProcessPoolExecutor, as_completed 11 | import h5py, math 12 | import pandas as pd 13 | 14 | # try: 15 | # get_ipython() 16 | # print ("jupyter notebook mode") 17 | # from tqdm.notebook import tqdm, trange 18 | # except Exception as e: 19 | # print (e) 20 | # print ("terminal mode") 21 | # pass 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | device_ids = [0, 1] 25 | 26 | def get_config(config_path = "./config.jSON"): 27 | import json 28 | c = open(config_path,"r") 29 | return json.load(c) 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser(description="Higashi Processing") 34 | parser.add_argument('-c', '--config', type=str, default="./config.JSON") 35 | return parser.parse_args() 36 | 37 | 38 | def get_free_gpu(): 39 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > ./tmp') 40 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 41 | if len(memory_available) > 0: 42 | max_mem = np.max(memory_available) 43 | ids = np.where(memory_available == max_mem)[0] 44 | chosen_id = int(np.random.choice(ids, 1)[0]) 45 | print("setting to gpu:%d" % chosen_id) 46 | torch.cuda.set_device(chosen_id) 47 | else: 48 | return 49 | 50 | 51 | def create_dir(config): 52 | temp_dir = config['temp_dir'] 53 | if not os.path.exists(temp_dir): 54 | os.mkdir(temp_dir) 55 | 56 | raw_dir = os.path.join(temp_dir, "raw") 57 | if not os.path.exists(raw_dir): 58 | os.mkdir(raw_dir) 59 | 60 | rw_dir = os.path.join(temp_dir, "rw") 61 | if not os.path.exists(rw_dir): 62 | os.mkdir(rw_dir) 63 | 64 | embed_dir = os.path.join(temp_dir, "embed") 65 | if not os.path.exists(embed_dir): 66 | os.mkdir(embed_dir) 67 | 68 | 69 | # Generate a indexing table of start and end id of each chromosome 70 | def generate_chrom_start_end(config): 71 | # fetch info from config 72 | genome_reference_path = config['genome_reference_path'] 73 | chrom_list = config['chrom_list'] 74 | res = config['resolution'] 75 | temp_dir = config['temp_dir'] 76 | 77 | print("generating start/end dict for chromosome") 78 | chrom_size = pd.read_table(genome_reference_path, sep="\t", header=None) 79 | chrom_size.columns = ['chrom', 'size'] 80 | # build a list that stores the start and end of each chromosome (unit of the number of bins) 81 | chrom_start_end = np.zeros((len(chrom_list), 2), dtype='int') 82 | for i, chrom in enumerate(chrom_list): 83 | size = chrom_size[chrom_size['chrom'] == chrom] 84 | size = size['size'][size.index[0]] 85 | n_bin = int(math.ceil(size / res)) 86 | chrom_start_end[i, 1] = chrom_start_end[i, 0] + n_bin 87 | if i + 1 < len(chrom_list): 88 | chrom_start_end[i + 1, 0] = chrom_start_end[i, 1] 89 | 90 | # print("chrom_start_end", chrom_start_end) 91 | np.save(os.path.join(temp_dir, "chrom_start_end.npy"), chrom_start_end) 92 | 93 | 94 | def data2mtx(config, file, chrom_start_end, verbose, cell_id, blacklist=""): 95 | if type(file) is str: 96 | if "header_included" in config: 97 | if config['header_included']: 98 | tab = pd.read_table(file, sep="\t") 99 | else: 100 | tab = pd.read_table(file, sep="\t", header=None) 101 | tab.columns = config['contact_header'][:len(tab.columns)] 102 | else: 103 | tab = pd.read_table(file, sep="\t", header=None) 104 | tab.columns = config['contact_header'] 105 | if 'count' not in tab.columns: 106 | tab['count'] = 1 107 | else: 108 | tab = file 109 | 110 | data = tab 111 | # fetch info from config 112 | res = config['resolution'] 113 | chrom_list = config['chrom_list'] 114 | 115 | data = data[(data['chrom1'] == data['chrom2']) & ((np.abs(data['pos2'] - data['pos1']) >= 2500) | (np.abs(data['pos2'] - data['pos1']) == 0))] 116 | 117 | if blacklist != "" and len(data) > 0: 118 | data = remove_blacklist(blacklist, data) 119 | 120 | pos1 = np.array(data['pos1']) 121 | pos2 = np.array(data['pos2']) 122 | bin1 = np.floor(pos1 / res).astype('int') 123 | bin2 = np.floor(pos2 / res).astype('int') 124 | 125 | chrom1, chrom2 = np.array(data['chrom1'].values), np.array(data['chrom2'].values) 126 | count = np.array(data['count'].values) 127 | 128 | del data 129 | 130 | m1_list = [] 131 | for i, chrom in enumerate(chrom_list): 132 | mask = (chrom1 == chrom) # & (bin1 != bin2) 133 | size = chrom_start_end[i, 1] - chrom_start_end[i, 0] 134 | temp_weight2 = count[mask] 135 | m1 = csr_matrix((temp_weight2, (bin1[mask], bin2[mask])), shape=(size, size), dtype='float32') 136 | m1 = m1 + m1.T 137 | m1_list.append(m1) 138 | count = count[~mask] 139 | bin1 = bin1[~mask] 140 | bin2 = bin2[~mask] 141 | chrom1 = chrom1[~mask] 142 | 143 | return m1_list, cell_id 144 | 145 | 146 | # Extra the data.txt table 147 | # Memory consumption re-optimize 148 | def extract_table(config): 149 | # fetch info from config 150 | data_dir = config['data_dir'] 151 | temp_dir = config['temp_dir'] 152 | chrom_list = config['chrom_list'] 153 | if "blacklist" in config: 154 | blacklist = config["blacklist"] 155 | else: 156 | blacklist = "" 157 | if 'input_format' in config: 158 | input_format = config['input_format'] 159 | else: 160 | input_format = 'higashi_v1' 161 | 162 | chrom_start_end = np.load(os.path.join(temp_dir, "chrom_start_end.npy")) 163 | import multiprocessing 164 | cpu_num = multiprocessing.cpu_count() 165 | if input_format == 'higashi_v1': 166 | print("extracting from data.txt") 167 | if "structured" in config: 168 | if config["structured"]: 169 | chunksize = int(5e6) 170 | cell_tab = [] 171 | 172 | p_list = [] 173 | pool = ProcessPoolExecutor(max_workers=cpu_num) 174 | print("First calculating how many lines are there") 175 | line_count = sum(1 for i in open(os.path.join(data_dir, "data.txt"), 'rb')) 176 | print("There are %d lines" % line_count) 177 | bar = trange(line_count, desc=' - Processing ', leave=False, ) 178 | cell_num = 0 179 | with open(os.path.join(data_dir, "data.txt"), 'r') as csv_file: 180 | chunk_count = 0 181 | reader = pd.read_csv(csv_file, chunksize=chunksize, sep="\t") 182 | for chunk in reader: 183 | if len(chunk['cell_id'].unique()) == 1: 184 | # Only one cell, keep appending 185 | cell_tab.append(chunk) 186 | else: 187 | # More than one cell, append all but the last part 188 | last_cell = np.array(chunk.tail(1)['cell_id'])[0] 189 | tails = chunk.iloc[np.array(chunk['cell_id']) != last_cell, :] 190 | head = chunk.iloc[np.array(chunk['cell_id']) == last_cell, :] 191 | cell_tab.append(tails) 192 | cell_tab = pd.concat(cell_tab, axis=0).reset_index() 193 | for cell_id in np.unique(cell_tab['cell_id']): 194 | p_list.append( 195 | pool.submit(data2mtx, config, cell_tab[cell_tab['cell_id'] == cell_id].reset_index(), 196 | chrom_start_end, False, cell_id, blacklist)) 197 | cell_num = max(cell_num, cell_id + 1) 198 | 199 | cell_tab = [head] 200 | bar.update(n=chunksize) 201 | bar.refresh() 202 | 203 | 204 | if len(cell_tab) != 0: 205 | cell_tab = pd.concat(cell_tab, axis=0).reset_index() 206 | for cell_id in np.unique(cell_tab['cell_id']): 207 | p_list.append( 208 | pool.submit(data2mtx, config, cell_tab[cell_tab['cell_id'] == cell_id].reset_index(), 209 | chrom_start_end, False, cell_id, blacklist)) 210 | cell_num = max(cell_num, cell_id + 1) 211 | cell_num = int(cell_num) 212 | mtx_all_list = [[0] * cell_num for i in range(len(chrom_list))] 213 | 214 | 215 | for p in as_completed(p_list): 216 | mtx_list, cell_id = p.result() 217 | for i in range(len(chrom_list)): 218 | mtx_all_list[i][cell_id] = mtx_list[i] 219 | 220 | else: 221 | data = pd.read_table(os.path.join(data_dir, "data.txt"), sep="\t") 222 | # ['cell_name','cell_id', 'chrom1', 'pos1', 'chrom2', 'pos2', 'count'] 223 | cell_id_all = np.unique(data['cell_id']) 224 | cell_num = int(np.max(cell_id_all) + 1) 225 | bar = trange(cell_num) 226 | mtx_all_list = [[0] * cell_num for i in range(len(chrom_list))] 227 | p_list = [] 228 | pool = ProcessPoolExecutor(max_workers=cpu_num) 229 | for cell_id in range(cell_num): 230 | p_list.append(pool.submit(data2mtx, config, data[data['cell_id'] == cell_id].reset_index(), 231 | chrom_start_end, False, cell_id, blacklist)) 232 | 233 | for p in as_completed(p_list): 234 | mtx_list, cell_id = p.result() 235 | for i in range(len(chrom_list)): 236 | mtx_all_list[i][cell_id] = mtx_list[i] 237 | bar.update(1) 238 | bar.close() 239 | pool.shutdown(wait=True) 240 | 241 | else: 242 | data = pd.read_table(os.path.join(data_dir, "data.txt"), sep="\t") 243 | cell_id_all = np.unique(data['cell_id']) 244 | cell_num = int(np.max(cell_id_all) + 1) 245 | bar = trange(cell_num) 246 | mtx_all_list = [[0] * cell_num for i in range(len(chrom_list))] 247 | p_list = [] 248 | pool = ProcessPoolExecutor(max_workers=cpu_num) 249 | for cell_id in range(cell_num): 250 | p_list.append( 251 | pool.submit(data2mtx, config, data[data['cell_id'] == cell_id].reset_index(), chrom_start_end, 252 | False, cell_id, blacklist)) 253 | 254 | for p in as_completed(p_list): 255 | mtx_list, cell_id = p.result() 256 | for i in range(len(chrom_list)): 257 | mtx_all_list[i][cell_id] = mtx_list[i] 258 | bar.update(1) 259 | bar.close() 260 | pool.shutdown(wait=True) 261 | for i in range(len(chrom_list)): 262 | np.save(os.path.join(temp_dir, "raw", "%s_sparse_adj.npy" % chrom_list[i]), mtx_all_list[i], 263 | allow_pickle=True) 264 | elif input_format == 'higashi_v2': 265 | print("extracting from filelist.txt") 266 | with open(os.path.join(data_dir, "filelist.txt"), "r") as f: 267 | lines = f.readlines() 268 | filelist = [line.strip() for line in lines] 269 | bar = trange(len(filelist)) 270 | mtx_all_list = [[0] * len(filelist) for i in range(len(chrom_list))] 271 | p_list = [] 272 | pool = ProcessPoolExecutor(max_workers=cpu_num) 273 | for cell_id, file in enumerate(filelist): 274 | p_list.append(pool.submit(data2mtx, config, file, chrom_start_end, False, cell_id, blacklist)) 275 | 276 | for p in as_completed(p_list): 277 | mtx_list, cell_id = p.result() 278 | for i in range(len(chrom_list)): 279 | mtx_all_list[i][cell_id] = mtx_list[i] 280 | bar.update(1) 281 | bar.close() 282 | pool.shutdown(wait=True) 283 | 284 | for i in range(len(chrom_list)): 285 | np.save(os.path.join(temp_dir, "raw", "%s_sparse_adj.npy" % chrom_list[i]), mtx_all_list[i], 286 | allow_pickle=True) 287 | else: 288 | print("invalid input format") 289 | raise EOFError 290 | 291 | 292 | def remove_blacklist(blacklistbed, chromdf): 293 | import pybedtools 294 | blacklist = pybedtools.BedTool(blacklistbed) 295 | left = chromdf[['chrom1', 'pos1', 'pos1']].copy() 296 | left.loc[:, 'temp_indexname'] = np.arange(len(left)) 297 | 298 | right = chromdf[['chrom2', 'pos2', 'pos2']].copy() 299 | right.loc[:, 'temp_indexname'] = np.arange(len(right)) 300 | 301 | import tempfile 302 | f1 = tempfile.NamedTemporaryFile() 303 | left.to_csv(f1, sep="\t", header=False, index=False) 304 | bed_anchor = pybedtools.BedTool(f1.name) 305 | good_anchor = bed_anchor.subtract(blacklist) 306 | good_anchor_left = good_anchor.to_dataframe() 307 | 308 | f2 = tempfile.NamedTemporaryFile() 309 | right.to_csv(f2, sep="\t", header=False, index=False) 310 | bed_anchor = pybedtools.BedTool(f2.name) 311 | good_anchor = bed_anchor.subtract(blacklist) 312 | good_anchor_right = good_anchor.to_dataframe() 313 | 314 | good_index = np.intersect1d(good_anchor_left['name'], good_anchor_right['name']) 315 | ori_len = len(chromdf) 316 | # str1 = "length from %d " % len(chromdf) 317 | chromdf = chromdf.iloc[good_index, :] 318 | # str1 += "to %d" % (len(chromdf)) 319 | # if len(chromdf) < ori_len: 320 | # print (str1) 321 | pybedtools.helpers.cleanup() 322 | return chromdf 323 | 324 | 325 | if __name__ == '__main__': 326 | args = parse_args() 327 | config = get_config(args.config) 328 | 329 | create_dir(config) 330 | generate_chrom_start_end(config) 331 | extract_table(config) -------------------------------------------------------------------------------- /fasthigashi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/Fast-Higashi/71182a9bdee2b96cd1676e1448285c5705a354e7/fasthigashi/__init__.py -------------------------------------------------------------------------------- /fasthigashi/parafac2_intergrative.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import time 3 | import torch.cuda 4 | import torch.jit as jit 5 | from opt_einsum import contract 6 | from sklearn.decomposition import TruncatedSVD 7 | try: 8 | from .parafac_integrative import parafac 9 | from .project2orthogonal import * 10 | from .partial_rwr import partial_rwr, tilte2rec 11 | from .util import * 12 | except: 13 | from parafac_integrative import parafac 14 | from project2orthogonal import * 15 | from partial_rwr import partial_rwr, tilte2rec 16 | from util import * 17 | 18 | import torch.nn.functional as F 19 | import multiprocessing as mpl 20 | cpu_count = mpl.cpu_count() 21 | ## factors are named as : U (dim1, dim2, r), A(dim1, r), B(r, r), D(R, r), meta_embedding(dim3, R) 22 | ## matmul(meta_embedding, D) are also referred to as C 23 | ## size_list: list of r 24 | ## rank: R 25 | def pass_(x, **kwargs): 26 | return x 27 | 28 | gpu_flag = torch.cuda.is_available() 29 | progressbar = pass_ 30 | def save_to_device(model_param, temp_var): 31 | if model_param is None or model_param is temp_var: return 32 | model_param[:] = temp_var.to(model_param.device) 33 | 34 | def summarize_a_tensor(ts, name): 35 | a = torch.sum(torch.isnan(ts)) 36 | if a > 0: 37 | print("name:", name) 38 | print ("shape:", ts.shape) 39 | print ("min/max:", torch.min(ts), torch.max(ts)) 40 | print ("nan num:", torch.sum(torch.isnan(ts))) 41 | raise EOFError 42 | else: 43 | return 44 | 45 | class Fast_Higashi_core(): 46 | 47 | def __init__(self, rank, off_diag, res_list): 48 | self.rank = rank 49 | 50 | self.sparse_conv = False 51 | self.off_diag = off_diag 52 | self.res_list = res_list 53 | # print ("sparse conv", self.sparse_conv) 54 | self.device = torch.device("cpu") 55 | 56 | def to(self, device): 57 | self.device = device 58 | return self 59 | 60 | @torch.no_grad() 61 | def init_params(self, schic, do_conv, do_rwr, do_col): 62 | rank = self.rank 63 | uniq_size_list = list(self.chrom2size.values()) * len(self.res_list) 64 | _t = time.perf_counter() 65 | context = dict(device=self.device, dtype=torch.float32) 66 | context_cpu = dict(device='cpu', dtype=torch.float32) 67 | cum_size_list = np.concatenate([[0], np.cumsum(uniq_size_list)]) 68 | # A_list, one for each matrix 69 | # A of size (#bin, r) 70 | # A could be large in cpu 71 | A_list = [ 72 | torch.randn([chrom_data.shape[0], self.chrom2size[chrom_data.chrom]], **context_cpu) * 1e-2 + 1 73 | for chrom_data in schic] 74 | if gpu_flag: A_list = [a.pin_memory() for a in A_list] 75 | 76 | # B one for each chromosome (multi-resolution) 77 | # B is definitely small in gpu 78 | # B of size (size) 79 | B_dict = {chrom: torch.eye(size, **context).add_(torch.randn(size, **context), alpha=1e-2) 80 | for chrom, size in self.chrom2size.items()} 81 | 82 | C = None 83 | 84 | print(f'time elapsed: {time.perf_counter() - _t:.2f}') 85 | sys.stdout.flush() 86 | bin_cov_list = [] 87 | bad_bin_cov_list = [] 88 | n_i_all = [] 89 | feats_for_SVD = {chrom: [] for chrom in self.chrom2size} 90 | bar_ = tqdm(schic, desc="initializing params") 91 | for chrom_data in bar_: 92 | if C is None: 93 | C = np.empty((chrom_data.num_cell, cum_size_list[-1])) 94 | C_start = 0 95 | n_i_list = [] 96 | bin_cov = torch.ones([chrom_data.num_cell, chrom_data.num_bin], 97 | dtype=torch.float32, device='cpu') * 1e-4 98 | bad_bin_cov = torch.ones([chrom_data.total_cell_num - chrom_data.num_cell, chrom_data.num_bin], 99 | dtype=torch.float32, device='cpu') * 1e-4 100 | # for bin_index in range(0, chrom_data.shape[0], chrom_data.bs_bin): 101 | 102 | num_bin_1m = int(math.ceil(chrom_data.num_bin * chrom_data.resolution / 1000000)) 103 | size1 = min(int(math.ceil(chrom_data.num_bin / chrom_data.num_bin_batch * 104 | chrom_data.resolution / 1000000)) + 2 * self.off_diag + 1, 105 | num_bin_1m) 106 | feats_dim = int(math.ceil(num_bin_1m * size1)) 107 | 108 | feats = np.empty((chrom_data.num_cell, feats_dim)) 109 | 110 | ll = int(math.ceil(1000000 / chrom_data.resolution)) 111 | if ll > 1: 112 | conv_filter = torch.ones(1, 1, ll, ll).to(self.device) / (ll * ll) 113 | conv_flag = True 114 | else: 115 | conv_flag = False 116 | feats_start = 0 117 | for bin_batch_id in range(0, chrom_data.num_bin_batch): 118 | slice_ = chrom_data.bin_slice_list[bin_batch_id] 119 | slice_local = chrom_data.local_bin_slice_list[bin_batch_id] 120 | slice_col = chrom_data.col_bin_slice_list[bin_batch_id] 121 | # feat = [] 122 | for cell_batch_id in range(0, chrom_data.num_cell_batch): 123 | slice_cell = chrom_data.cell_slice_list[cell_batch_id] 124 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 125 | save_context=dict(device=self.device), 126 | transpose=True, 127 | do_conv=do_conv if self.sparse_conv else False) 128 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 129 | # here it's always, cell, row, col 130 | if kind == 'hic': 131 | chrom_batch_cell_batch, n_i = partial_rwr(chrom_batch_cell_batch, 132 | slice_start=slice_local.start, 133 | slice_end=slice_local.stop, 134 | do_conv=False if self.sparse_conv else do_conv, 135 | do_rwr=do_rwr, 136 | do_col=False, 137 | bin_cov=torch.ones(1), 138 | return_rwr_iter=True, 139 | force_rwr_epochs=-1, 140 | final_transpose=False) 141 | b_c = chrom_batch_cell_batch.sum(1) 142 | bin_cov[slice_cell, slice_col] += b_c.detach().cpu() 143 | n_i_list.append(n_i) 144 | 145 | if not do_col: 146 | if conv_flag: 147 | B = F.avg_pool2d(chrom_batch_cell_batch[:,None], 148 | ll, ll, padding=0, ceil_mode=False)[:, 0, :, :] 149 | # print (B.shape, chrom_batch_cell_batch.shape) 150 | else: 151 | B = chrom_batch_cell_batch 152 | B = B.cpu().numpy().reshape((len(B), -1)) 153 | 154 | # if chrom_data.num_cell > 20000: 155 | # # large number of cells, sparsify 156 | # B[B <= 2e-8] = 0.0 157 | try: 158 | feats[slice_cell, feats_start:feats_start+B.shape[-1]] = B 159 | except Exception as e: 160 | print (B.shape, feats_start, feats_dim, num_bin_1m, chrom_data.resolution, chrom_data.num_bin, chrom_data.chrom) 161 | print (e) 162 | raise e 163 | # else: 164 | # feat.append(chrom_batch_cell_batch.permute(2, 0, 1).cpu().numpy().reshape(chrom_batch_cell_batch.shape[2], -1)) 165 | del chrom_batch_cell_batch 166 | if not do_col: 167 | feats_start += B.shape[-1] 168 | # gc.collect() 169 | try: 170 | torch.cuda.empty_cache() 171 | except: 172 | pass 173 | 174 | # if len(feat) > 0: 175 | # feat = np.concatenate(feat, axis=0) 176 | # feat = feat.reshape((len(feat), -1)) 177 | # feats_for_SVD[chrom_data.chrom].append(feat) 178 | 179 | for cell_batch_id in range(0, chrom_data.num_cell_batch_bad): 180 | slice_cell = chrom_data.cell_slice_list[cell_batch_id + chrom_data.num_cell_batch] 181 | slice_cell = slice(slice_cell.start-chrom_data.num_cell, slice_cell.stop-chrom_data.num_cell) 182 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 183 | save_context=dict(device=self.device), 184 | transpose=True, 185 | do_conv=do_conv if self.sparse_conv else False, 186 | good_qc=False) 187 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 188 | if kind == 'hic': 189 | chrom_batch_cell_batch, _ = partial_rwr(chrom_batch_cell_batch, 190 | slice_start=slice_local.start, 191 | slice_end=slice_local.stop, 192 | do_conv=False if self.sparse_conv else do_conv, 193 | do_rwr=do_rwr, 194 | do_col=False, 195 | bin_cov=torch.ones(1), 196 | return_rwr_iter=True, 197 | force_rwr_epochs=-1, 198 | final_transpose=False) 199 | b_c = chrom_batch_cell_batch.sum(1) 200 | bad_bin_cov[slice_cell, slice_col] += b_c.detach().cpu() 201 | if do_col: 202 | for bin_batch_id in range(0, chrom_data.num_bin_batch): 203 | slice_ = chrom_data.bin_slice_list[bin_batch_id] 204 | slice_local = chrom_data.local_bin_slice_list[bin_batch_id] 205 | slice_col = chrom_data.col_bin_slice_list[bin_batch_id] 206 | # feat = [] 207 | for cell_batch_id in range(0, chrom_data.num_cell_batch): 208 | slice_cell = chrom_data.cell_slice_list[cell_batch_id] 209 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 210 | save_context=dict(device=self.device), 211 | transpose=True, 212 | do_conv=do_conv if self.sparse_conv else False) 213 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 214 | # here it's always, cell, row, col 215 | if kind == 'hic': 216 | chrom_batch_cell_batch, n_i = partial_rwr(chrom_batch_cell_batch, 217 | slice_start=slice_local.start, 218 | slice_end=slice_local.stop, 219 | do_conv=False if self.sparse_conv else do_conv, 220 | do_rwr=do_rwr, 221 | do_col=do_col, 222 | bin_cov=bin_cov[slice_cell, slice_col], 223 | return_rwr_iter=True, 224 | force_rwr_epochs=-1, 225 | final_transpose=False) 226 | 227 | 228 | if conv_flag: 229 | B = F.avg_pool2d(chrom_batch_cell_batch[:,None], 230 | ll, ll, padding=0, ceil_mode=False)[:, 0, :, :] 231 | else: 232 | B = chrom_batch_cell_batch 233 | 234 | # feat.append( 235 | # B.cpu().numpy()) 236 | B = B.cpu().numpy().reshape((len(B), -1)) 237 | feats[slice_cell, feats_start:feats_start + B.shape[-1]] = B 238 | 239 | # else: 240 | # feat.append(chrom_batch_cell_batch.permute(2, 0, 1).cpu().numpy().reshape(chrom_batch_cell_batch.shape[2], -1)) 241 | del chrom_batch_cell_batch 242 | feats_start += B.shape[-1] 243 | gc.collect() 244 | try: 245 | torch.cuda.empty_cache() 246 | except: 247 | pass 248 | 249 | 250 | # feats_for_SVD[chrom_data.chrom] = np.concatenate(feats_for_SVD[chrom_data.chrom], axis=1) 251 | size = self.chrom2size[chrom_data.chrom] 252 | if self.device != 'cpu': 253 | torch.set_num_threads(max(cpu_count - 2, 1)) 254 | # print ("turning into sparse") 255 | # feats = csr_matrix(feats) 256 | bar_.set_description("initializing params - SVD ing %s: " %str(feats[:,:feats_start].shape), refresh=True) 257 | svd = TruncatedSVD(n_components=size, n_iter=2) 258 | temp = svd.fit_transform(feats[:, :feats_start]) 259 | if self.device != 'cpu': 260 | torch.set_num_threads(4) 261 | # del feats 262 | C[:, C_start:C_start+temp.shape[-1]] = temp 263 | C_start += temp.shape[-1] 264 | bar_.set_description("initializing params - finished SVD ", refresh=True) 265 | n_i_all.append(np.max(n_i_list) if len(n_i_list) > 0 else 0) 266 | if type(bin_cov) is not float: 267 | bin_cov[bin_cov <= 1e-4] = float('inf') 268 | if chrom_data.num_cell_batch_bad > 0: 269 | if type(bin_cov) is not float: 270 | bad_bin_cov[bad_bin_cov <= 1e-4] = float('inf') 271 | bad_bin_cov_list.append(bad_bin_cov) 272 | else: 273 | bad_bin_cov_list.append(0) 274 | else: 275 | bad_bin_cov_list.append(0) 276 | bin_cov_list.append(bin_cov) 277 | 278 | n_i_all = np.array(n_i_all) 279 | self.n_i = np.array(n_i_all) 280 | 281 | 282 | print("rwr iters:", self.n_i) 283 | C = torch.from_numpy(C).float() 284 | U, S, Vh = torch.linalg.svd(C.to(self.device), full_matrices=False) 285 | meta_embedding = U[:, :rank] 286 | SVh = Vh[:rank].mul_(S[:rank, None]) 287 | D_dict = { 288 | chrom: SVh[:, start: stop].clone() 289 | for chrom, start, stop in zip(list(self.chrom2size.keys()), cum_size_list[:-1], cum_size_list[1:]) 290 | } 291 | del C 292 | 293 | print(f'time elapsed: {time.perf_counter() - _t:.2f}') 294 | sys.stdout.flush() 295 | self.A_list = A_list 296 | self.B_dict = B_dict 297 | self.meta_embedding = meta_embedding 298 | self.D_dict = D_dict 299 | self.bin_cov_list = bin_cov_list 300 | self.bad_bin_cov_list = bad_bin_cov_list 301 | print ("finish init") 302 | 303 | @torch.no_grad() 304 | def update_meta_embedding_interactions(self, 305 | schic, projection_list, 306 | projected_tensor_list=None, 307 | do_conv=True, do_rwr=True, do_col=False, 308 | first_iter=False 309 | ): 310 | if first_iter: 311 | rec_error_tensor_norm = np.zeros([len(schic), 1]) 312 | device = self.device 313 | bin_cov_list = self.bin_cov_list 314 | rank = self.rank 315 | A_list, B_dict, D_dict, meta_embedding = self.A_list, self.B_dict, self.D_dict, self.meta_embedding 316 | 317 | SVD_term = torch.zeros(meta_embedding.shape[::-1], dtype=torch.float32, device=device) 318 | rec_error_x_U = np.zeros([len(schic), 1]) 319 | rec_error_x_V = 0 320 | 321 | densify_time = 0 322 | partial_rwr_time = 0 323 | svd_time = 0 324 | contract_time = 0 325 | 326 | 327 | for chrom_index, (chrom_data, A, projection, bin_cov) in enumerate(zip( 328 | progressbar(schic), A_list, projection_list, 329 | bin_cov_list 330 | )): 331 | B = B_dict[chrom_data.chrom] 332 | D = D_dict[chrom_data.chrom] 333 | size = self.chrom2size[chrom_data.chrom] 334 | 335 | # chromosome specific embedding 336 | C = torch.matmul(meta_embedding, D) 337 | 338 | for bin_batch_id in range(0, chrom_data.num_bin_batch): 339 | slice_ = chrom_data.bin_slice_list[bin_batch_id] 340 | slice_local = chrom_data.local_bin_slice_list[bin_batch_id] 341 | slice_col = chrom_data.col_bin_slice_list[bin_batch_id] 342 | # Fetch and densify the X 343 | temp = None 344 | 345 | for cell_batch_id in range(0, chrom_data.num_cell_batch): 346 | slice_cell = slice(cell_batch_id * chrom_data.bs_cell, 347 | min((cell_batch_id + 1) * chrom_data.bs_cell, chrom_data.num_cell)) 348 | _t1 = time.perf_counter() 349 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 350 | save_context=dict(device=device), 351 | transpose=(do_conv and not self.sparse_conv) or do_rwr, 352 | do_conv=do_conv if self.sparse_conv else False) 353 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 354 | densify_time += np.array(list(t) + [time.perf_counter() - _t1]) 355 | _t = time.perf_counter() 356 | if kind == 'hic': 357 | chrom_batch_cell_batch, t1 = partial_rwr(chrom_batch_cell_batch, 358 | slice_start=slice_local.start, 359 | slice_end=slice_local.stop, 360 | do_conv=False if self.sparse_conv else do_conv, 361 | do_rwr=do_rwr, 362 | do_col=do_col, 363 | bin_cov=bin_cov[slice_cell, slice_col], 364 | bin_cov_row=bin_cov[slice_cell, slice_], 365 | force_rwr_epochs=self.n_i[chrom_index]) 366 | 367 | 368 | if first_iter: 369 | rec_error_tensor_norm[chrom_index] += torch.linalg.norm(chrom_batch_cell_batch).square_().item() 370 | 371 | partial_rwr_time += time.perf_counter() - _t 372 | _t = time.perf_counter() 373 | # lhs: bs_bin, cell, size 374 | lhs = contract('ir,jr,kr->ikj', A[slice_].to(device), B, C[slice_cell].to(device)) 375 | contract_time += time.perf_counter() - _t 376 | _t = time.perf_counter() 377 | # rhs: bs_bin, # bin2, cell 378 | rhs = chrom_batch_cell_batch.to(device) 379 | 380 | if temp is None: 381 | temp = torch.bmm(rhs, lhs) 382 | else: 383 | temp.baddbmm_(rhs, lhs) 384 | if cell_batch_id != chrom_data.num_cell_batch-1: 385 | del chrom_batch_cell_batch 386 | 387 | 388 | _t = time.perf_counter() 389 | # For smaller batch size or small matrix dimension, cpu is much faster 390 | # GPU has advantages when dealing with large batch size or super large matrix 391 | # Here, temp is shape of (bs_bin, total_bin, r) 392 | # bs_bin < 200, r ~ 100, total_bin can goes up to 2280, so... 393 | # svd_device = 'cpu' if temp.shape[1] <= 700 else device 394 | svd_device = device 395 | # try: 396 | U, S = project2orthogonal(temp.to(svd_device), temp.shape[-1], compute_device=device) 397 | # except: 398 | # U, S = project2orthogonal_ill(temp.to(svd_device), temp.shape[-1], compute_device=device) 399 | 400 | svd_time += time.perf_counter() - _t 401 | _t = time.perf_counter() 402 | 403 | # store projections 404 | projection[bin_batch_id] = U.to(projection[bin_batch_id].device) 405 | U = U.to(device) 406 | 407 | 408 | # calc error 409 | if S is None: 410 | rec_error_x_U[chrom_index] += temp.view(-1).inner(U.view(-1)).item() 411 | else: 412 | # assert (S.sum().item() - temp.view(-1).inner(U.view(-1)).item()) / S.sum().item() < 1e-5, ( 413 | # S.sum().item(), temp.view(-1).inner(U.view(-1)).item(), 414 | # S.sum().item() - temp.view(-1).inner(U.view(-1)).item(), 415 | # (S.sum().item() - temp.view(-1).inner(U.view(-1)).item()) / S.sum().item() 416 | # ) 417 | rec_error_x_U[chrom_index] += S.sum().item() 418 | 419 | 420 | 421 | # lhs: rank, # bin1, size 422 | lhs = contract('ir,jr,kr->kij', A[slice_].to(device), B, D) 423 | lhs = lhs.reshape(lhs.shape[0], -1) 424 | _t = time.perf_counter() 425 | 426 | # First use the last densified one 427 | # bin1, size, bin2 * bin1, bin2, bs_cell -> bin1, size, bs_cell 428 | _t = time.perf_counter() 429 | projected = torch.bmm(U.transpose(-1, -2), chrom_batch_cell_batch) 430 | SVD_term[:, slice_cell] += torch.matmul(lhs, projected.reshape(-1, projected.shape[-1])) 431 | contract_time += time.perf_counter() - _t 432 | _t = time.perf_counter() 433 | 434 | # All but the last one (which has been reused ) 435 | for cell_batch_id in range(0, chrom_data.num_cell_batch - 1): 436 | slice_cell = slice(cell_batch_id * chrom_data.bs_cell, 437 | min((cell_batch_id + 1) * chrom_data.bs_cell, chrom_data.num_cell)) 438 | 439 | if chrom_data.bs_cell < chrom_data.num_cell: 440 | _t = time.perf_counter() 441 | # torch.cuda.synchronize() 442 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 443 | save_context=dict(device=device), 444 | transpose=(do_conv and not self.sparse_conv) or do_rwr, 445 | do_conv=do_conv if self.sparse_conv else False) 446 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 447 | # torch.cuda.synchronize() 448 | densify_time += np.array(list(t)+[time.perf_counter() - _t]) 449 | _t = time.perf_counter() 450 | if kind == 'hic': 451 | chrom_batch_cell_batch, t1 = partial_rwr(chrom_batch_cell_batch.clamp_(1e-8), 452 | slice_start=slice_local.start, 453 | slice_end=slice_local.stop, 454 | do_conv=False if self.sparse_conv else do_conv, 455 | do_rwr=do_rwr, 456 | do_col=do_col, 457 | bin_cov=bin_cov[slice_cell, slice_col], 458 | bin_cov_row=bin_cov[slice_cell, slice_], 459 | force_rwr_epochs=self.n_i[chrom_index] 460 | ) 461 | else: 462 | t1 = 0 463 | 464 | partial_rwr_time += t1 465 | partial_rwr_time += time.perf_counter() - _t 466 | _t = time.perf_counter() 467 | 468 | # bin1, size, bin2 * bin1, bin2, bs_cell -> bin1, size, bs_cell 469 | _t = time.perf_counter() 470 | projected = torch.bmm(U.transpose(-1, -2), chrom_batch_cell_batch) 471 | SVD_term[:, slice_cell] += torch.matmul(lhs, projected.reshape(-1, projected.shape[-1])) 472 | # if kind == 'hic': 473 | # SVD_term_1[:, slice_cell] += torch.matmul(lhs, projected.reshape(-1, projected.shape[-1])) 474 | # else: 475 | # SVD_term_2[:, slice_cell] += torch.matmul(lhs, projected.reshape(-1, projected.shape[-1])) 476 | contract_time += time.perf_counter() - _t 477 | del chrom_batch_cell_batch 478 | _t = time.perf_counter() 479 | 480 | del C 481 | # SVD_term: dim3 * R 482 | _t = time.perf_counter() 483 | meta_embedding, S = project2orthogonal(SVD_term.T, rank=rank, compute_device=device) 484 | svd_time += time.perf_counter() - _t 485 | _t = time.perf_counter() 486 | rec_error_x_V += meta_embedding.mul(SVD_term.T).sum().item() 487 | 488 | for chrom_index, (chrom_data, projection, bin_cov) in enumerate( 489 | zip(progressbar(schic), projection_list, bin_cov_list)): 490 | 491 | for bin_batch_id in range(0, chrom_data.num_bin_batch): 492 | gather_project = 0 493 | slice_ = chrom_data.bin_slice_list[bin_batch_id] 494 | slice_local = chrom_data.local_bin_slice_list[bin_batch_id] 495 | slice_col = chrom_data.col_bin_slice_list[bin_batch_id] 496 | for cell_batch_id in range(0, chrom_data.num_cell_batch): 497 | slice_cell = slice(cell_batch_id * chrom_data.bs_cell, 498 | min((cell_batch_id + 1) * chrom_data.bs_cell, chrom_data.num_cell)) 499 | _t = time.perf_counter() 500 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 501 | save_context=dict(device=device), 502 | transpose=(do_conv and not self.sparse_conv) or do_rwr, 503 | do_conv=do_conv if self.sparse_conv else False) 504 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 505 | densify_time += np.array(list(t) + [time.perf_counter() - _t]) 506 | _t = time.perf_counter() 507 | if kind == 'hic': 508 | chrom_batch_cell_batch, t1 = partial_rwr(chrom_batch_cell_batch.clamp_(1e-8), 509 | slice_start=slice_local.start, 510 | slice_end=slice_local.stop, 511 | do_conv=False if self.sparse_conv else do_conv, 512 | do_rwr=do_rwr, 513 | do_col=do_col, 514 | bin_cov=bin_cov[slice_cell, slice_col], 515 | bin_cov_row=bin_cov[slice_cell, slice_], 516 | force_rwr_epochs=self.n_i[chrom_index]) 517 | partial_rwr_time += t1 518 | 519 | 520 | _t = time.perf_counter() 521 | 522 | projected = contract( 523 | "ijk,km, ijl -> ilm", chrom_batch_cell_batch, meta_embedding[slice_cell], 524 | projection[bin_batch_id].to(meta_embedding.device)) 525 | del chrom_batch_cell_batch 526 | gather_project += projected 527 | contract_time += time.perf_counter() - _t 528 | _t = time.perf_counter() 529 | projected_tensor_list[chrom_data.chrom][chrom_data.global_slice_bin][slice_] = gather_project.to(projected_tensor_list[chrom_data.chrom].device) 530 | 531 | self.meta_embedding = meta_embedding.to(self.meta_embedding.device) 532 | # print (densify_time, partial_rwr_time, svd_time, contract_time) 533 | gc.collect() 534 | try: 535 | torch.cuda.empty_cache() 536 | except: 537 | pass 538 | if first_iter: 539 | return projection_list, projected_tensor_list, rec_error_x_U, rec_error_x_V, rec_error_tensor_norm 540 | return projection_list, projected_tensor_list, rec_error_x_U, rec_error_x_V 541 | 542 | 543 | 544 | def fit(self, schic, size_ratio=0.3, 545 | n_iter_max=2000, n_iter_parafac=5, 546 | do_conv=True, do_rwr=False, do_col=False, tol=1e-8, 547 | size_list = None, gpu_id=None, 548 | verbose=True, 549 | run_init=True): 550 | 551 | self.gpu_id = gpu_id 552 | self.all_in_gpu = False 553 | self.benchmark_speed = False 554 | rank = self.rank 555 | device = self.device 556 | # Calculating sizes, the size would be forced to be the same for matrix from the same chromosomes 557 | 558 | if size_list is None: 559 | size_list = [min(int(chrom_data.shape[0] * size_ratio * chrom_data.resolution / 1000000), rank) 560 | for chrom_data in schic] 561 | chrom2size = {} 562 | for chrom_data, size in zip(schic, size_list): 563 | chrom = chrom_data.chrom 564 | if chrom in chrom2size: 565 | chrom2size[chrom] = min(chrom2size[chrom], size) 566 | else: 567 | chrom2size[chrom] = size 568 | else: 569 | chrom2size = {} 570 | for chrom_data, size in zip(schic, size_list): 571 | chrom = chrom_data.chrom 572 | if chrom in chrom2size: 573 | if size != chrom2size[chrom]: 574 | print ("size of the same chromosome must be same!", size, chrom2size[chrom], chrom) 575 | raise EOFError 576 | else: 577 | chrom2size[chrom] = size 578 | 579 | self.chrom2size = chrom2size 580 | self.chrom2num_bin = {} 581 | self.chrom2id = {chrom:[] for chrom in self.chrom2size} 582 | for chrom_index, chrom_data in enumerate(schic): 583 | self.chrom2id[chrom_data.chrom].append(chrom_index) 584 | if chrom_data.chrom not in self.chrom2num_bin: 585 | self.chrom2num_bin[chrom_data.chrom] = chrom_data.num_bin 586 | chrom_data.global_slice_bin = slice(0, chrom_data.num_bin) 587 | else: 588 | chrom_data.global_slice_bin = slice(self.chrom2num_bin[chrom_data.chrom], 589 | self.chrom2num_bin[chrom_data.chrom]+chrom_data.num_bin) 590 | self.chrom2num_bin[chrom_data.chrom] += chrom_data.num_bin 591 | 592 | print ("empty params initialized") 593 | del size_list 594 | if run_init: 595 | self.init_params(schic, do_conv, do_rwr, do_col) 596 | rec_errors = [] 597 | rec_errors_total = [] 598 | 599 | 600 | # create_projection_list: 601 | projection_list = [] 602 | for chrom_data in schic: 603 | temp1 = [] 604 | for bin_batch in chrom_data.tensor_list: 605 | for cell_batch_bin_batch in bin_batch: 606 | a = torch.empty((cell_batch_bin_batch.shape[0] - 2, 607 | cell_batch_bin_batch.shape[1] - 2, 608 | self.chrom2size[chrom_data.chrom] 609 | ), dtype=torch.float32) 610 | if gpu_flag: a = a.pin_memory() 611 | temp1.append(a) 612 | break 613 | projection_list.append(temp1) 614 | # Note the batch_id dim is at dim 1 615 | projected_tensor_list = {chrom: 616 | torch.empty([self.chrom2num_bin[chrom], self.chrom2size[chrom], rank], dtype=torch.float32) 617 | for chrom in self.chrom2size 618 | } 619 | if gpu_flag: 620 | for a in projected_tensor_list: projected_tensor_list[a] = projected_tensor_list[a].pin_memory() 621 | rec_error_core_norm = np.zeros([len(schic), 1]) 622 | 623 | for chrom_index, (chrom_data, A, re_c, bin_cov) in enumerate(zip( 624 | progressbar(schic), self.A_list, rec_error_core_norm, self.bin_cov_list)): 625 | B = self.B_dict[chrom_data.chrom] 626 | D = self.D_dict[chrom_data.chrom] 627 | for i in range(0, chrom_data.num_bin, chrom_data.bs_bin): 628 | slice_ = slice(i, i + chrom_data.bs_bin) 629 | c = contract('ir,jr,kr->kij', A[slice_].to(device), B, D) 630 | 631 | re_c[:] += torch.linalg.norm(c).square_().item() 632 | del c 633 | 634 | rec_error_tensor_norm = None 635 | for iteration in range(n_iter_max): 636 | if (iteration % 10) == 0 and iteration > 0 and n_iter_parafac < 10: 637 | n_iter_parafac += 1 638 | # print ("n_iter_para", n_iter_parafac) 639 | print("Starting iteration", iteration) 640 | sys.stdout.flush() 641 | 642 | start_time = time.time() 643 | 644 | if rec_error_tensor_norm is None: 645 | projection_list, projected_tensor_list, rec_error_x_U, rec_error_x_V, rec_error_tensor_norm = \ 646 | self.update_meta_embedding_interactions( 647 | schic, projection_list, 648 | projected_tensor_list, 649 | do_conv=do_conv, 650 | do_rwr=do_rwr, 651 | do_col=do_col, 652 | first_iter=True 653 | ) 654 | norm_tensor = np.sqrt(rec_error_tensor_norm).reshape((-1)) 655 | norm_tensor_all = float(np.linalg.norm(norm_tensor)) 656 | else: 657 | projection_list, projected_tensor_list, rec_error_x_U, rec_error_x_V = \ 658 | self.update_meta_embedding_interactions( 659 | schic, projection_list, 660 | projected_tensor_list, 661 | do_conv=do_conv, 662 | do_rwr=do_rwr, 663 | do_col=do_col) 664 | 665 | 666 | 667 | rec_error_by_block_U = rec_error_tensor_norm + rec_error_core_norm - 2 * rec_error_x_U 668 | rec_error_V = rec_error_tensor_norm.sum() + rec_error_core_norm.sum() - 2 * rec_error_x_V 669 | del rec_error_x_U, rec_error_x_V 670 | 671 | rec_error_x_core = np.zeros([len(schic), 1]) 672 | 673 | # Run parafac on projected tensors (size of (dim1, size, rank)) 674 | for chrom in self.chrom2id: 675 | ids = self.chrom2id[chrom] 676 | temp_A = torch.cat([self.A_list[i] for i in ids], dim=0) 677 | temp_B = self.B_dict[chrom] 678 | temp_D = self.D_dict[chrom] 679 | temp_factors = [temp_A, temp_B, temp_D] 680 | factors, core_norm, loss_x = parafac( 681 | projected_tensor_list[chrom], 682 | rank=self.chrom2size[chrom], 683 | init=temp_factors, 684 | n_iter_max=n_iter_parafac, 685 | verbose=False, 686 | ) 687 | 688 | # rec_error_core_norm[chrom_index] = core_norm 689 | # rec_error_x_core[chrom_index] = loss_x 690 | 691 | for i in ids: 692 | self.A_list[i][:] = factors[0][schic[i].global_slice_bin].to(self.A_list[i].device) 693 | 694 | self.B_dict[chrom][:] = factors[1].to(self.B_dict[chrom].device) 695 | self.D_dict[chrom][:] = factors[2].to(self.D_dict[chrom].device) 696 | 697 | rec_error_core_norm = np.zeros([len(schic), 1]) 698 | for chrom_index, (chrom_data, A, re_c, bin_cov) in enumerate(zip( 699 | progressbar(schic), self.A_list, rec_error_core_norm, self.bin_cov_list)): 700 | B = self.B_dict[chrom_data.chrom] 701 | D = self.D_dict[chrom_data.chrom] 702 | for i in range(0, chrom_data.num_bin, chrom_data.bs_bin): 703 | slice_ = slice(i, i + chrom_data.bs_bin) 704 | c = contract('ir,jr,kr->kij', A[slice_].to(device), B, D) 705 | re_c[:] += torch.linalg.norm(c).square_().item() 706 | del c 707 | 708 | print() 709 | 710 | rec_error = np.sqrt(rec_error_V.sum()) / norm_tensor_all 711 | rec_errors_total.append(rec_error) 712 | rec_error_by_block = np.sqrt(rec_error_by_block_U.ravel()) / norm_tensor 713 | rec_errors.append(rec_error_by_block) 714 | 715 | if iteration >= 1: 716 | differences = (rec_errors[-2] ** 2 - rec_errors[-1] ** 2) / (rec_errors[-2] ** 2) 717 | total_differences = ( 718 | (rec_errors_total[-2] ** 2 - rec_errors_total[-1] ** 2) / rec_errors_total[-2] ** 2) 719 | 720 | print( 721 | f"PARAFAC2 re={rec_error:.3f} " 722 | f"{total_differences:.2e} " 723 | f"variation min{differences.min().item():.1e} at chrom {differences.argmin().item():d}, " 724 | f"max{differences.max().item():.1e} at chrom {differences.argmax().item():d}", 725 | f"takes {time.time() - start_time:.1f}s" 726 | ) 727 | # if iteration >= 3 and tol > 0 and (total_differences < tol or differences.max() < tol * 2): 728 | if iteration >= 3 and tol > 0 and (total_differences < tol or differences.max() < tol * 2): 729 | print('converged in {} iterations.'.format(iteration)) 730 | break 731 | else: 732 | 733 | print( 734 | f"PARAFAC2 re={rec_error:.3f} " 735 | f"takes {time.time() - start_time:.1f}s" 736 | ) 737 | sys.stdout.flush() 738 | self.projection_list = projection_list 739 | self.projected_tensor_list = projected_tensor_list 740 | return self 741 | 742 | def transform(self, schic, do_conv, do_rwr, do_col): 743 | # if self.device == 'cpu': 744 | # do_conv = False 745 | # do_rwr = False 746 | # do_col = False 747 | # final update of meta-embeddings: 748 | device = self.device 749 | projection_list = self.projection_list 750 | projected_tensor_list = self.projected_tensor_list 751 | bin_cov_list = self.bin_cov_list 752 | bad_bin_cov_list = self.bad_bin_cov_list 753 | rank = self.rank 754 | print ("start transform") 755 | SVD_term = torch.zeros([self.meta_embedding.shape[-1], schic[0].total_cell_num], dtype=torch.float32, device=device) 756 | 757 | lhs_all = 0 758 | 759 | 760 | for chrom_index, (chrom_data, A, projection, bin_cov, bad_bin_cov) in enumerate(zip( 761 | progressbar(schic), self.A_list, projection_list, 762 | bin_cov_list, bad_bin_cov_list 763 | )): 764 | B = self.B_dict[chrom_data.chrom] 765 | D = self.D_dict[chrom_data.chrom] 766 | for bin_batch_id in range(0, chrom_data.num_bin_batch): 767 | # slice_ = slice(bin_batch_id * chrom_data.bs_bin, bin_batch_id * chrom_data.bs_bin + chrom_data.bs_bin) 768 | slice_ = chrom_data.bin_slice_list[bin_batch_id] 769 | slice_local = chrom_data.local_bin_slice_list[bin_batch_id] 770 | slice_col = chrom_data.col_bin_slice_list[bin_batch_id] 771 | # lhs: rank, # bin1, size 772 | lhs = contract('ir,jr,kr->kij', A[slice_].to(device), B, D) 773 | lhs = lhs.reshape(lhs.shape[0], -1) 774 | lhs_all += lhs @ lhs.T 775 | U = projection[bin_batch_id].to(device) 776 | 777 | # Fetch and densify the X 778 | 779 | for cell_batch_id in range(0, chrom_data.num_cell_batch): 780 | slice_cell = slice(cell_batch_id * chrom_data.bs_cell, 781 | min((cell_batch_id + 1) * chrom_data.bs_cell, chrom_data.num_cell)) 782 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 783 | save_context=dict(device=device), 784 | transpose=(do_conv and not self.sparse_conv) or do_rwr, 785 | do_conv=do_conv if self.sparse_conv else False) 786 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 787 | 788 | if kind == 'hic': 789 | chrom_batch_cell_batch, t1 = partial_rwr(chrom_batch_cell_batch, 790 | slice_start=slice_local.start, 791 | slice_end=slice_local.stop, 792 | do_conv=False if self.sparse_conv else do_conv, 793 | do_rwr=do_rwr, 794 | do_col=do_col, 795 | bin_cov=bin_cov[slice_cell, slice_col], 796 | bin_cov_row=bin_cov[slice_cell, slice_], 797 | force_rwr_epochs=self.n_i[chrom_index]) 798 | 799 | projected = torch.bmm(U.transpose(-1, -2), chrom_batch_cell_batch) 800 | 801 | SVD_term[:, slice_cell] += torch.matmul(lhs, projected.reshape(-1, projected.shape[-1])) 802 | 803 | for cell_batch_id in range(0, chrom_data.num_cell_batch_bad): 804 | slice_cell = slice(cell_batch_id * chrom_data.bs_cell, 805 | (cell_batch_id + 1) * chrom_data.bs_cell) 806 | chrom_batch_cell_batch, kind = chrom_data.fetch(bin_batch_id, cell_batch_id, 807 | save_context=dict(device=device), 808 | transpose=(do_conv and not self.sparse_conv) or do_rwr, 809 | do_conv=do_conv if self.sparse_conv else False, 810 | good_qc=False) 811 | chrom_batch_cell_batch, t = chrom_batch_cell_batch 812 | if kind == 'hic': 813 | chrom_batch_cell_batch, t1 = partial_rwr(chrom_batch_cell_batch, 814 | slice_start=slice_local.start, 815 | slice_end=slice_local.stop, 816 | do_conv=False if self.sparse_conv else do_conv, 817 | do_rwr=do_rwr, 818 | do_col=do_col, 819 | bin_cov=bad_bin_cov[slice_cell, slice_col], 820 | bin_cov_row=bad_bin_cov[slice_cell, slice_], 821 | force_rwr_epochs=self.n_i[chrom_index]) 822 | 823 | projected = torch.bmm(U.transpose(-1, -2), chrom_batch_cell_batch) 824 | slice_cell2 = slice(chrom_data.num_cell + cell_batch_id * chrom_data.bs_cell, 825 | chrom_data.num_cell + (cell_batch_id + 1) * chrom_data.bs_cell) 826 | SVD_term[:, slice_cell2] += torch.matmul(lhs, projected.reshape(-1, projected.shape[-1])) 827 | 828 | 829 | 830 | # SVD_term: dim3 * R 831 | meta_embedding, S = project2orthogonal(SVD_term.T, rank=rank, compute_device=device) 832 | parafac2_tensor = (None, (self.A_list, self.B_dict.values(), self.D_dict.values(), meta_embedding), self.projection_list) 833 | 834 | return parafac2_tensor 835 | 836 | def fit_transform(self, schic, size_ratio=0.3, 837 | n_iter_max=2000, n_iter_parafac=5, 838 | do_conv=True, do_rwr=False, do_col=False, tol=1e-8, 839 | size_list = None, gpu_id=None, 840 | verbose=True, run_init=True): 841 | print ("n_iter_parafac", n_iter_parafac) 842 | self.fit(schic, size_ratio, 843 | n_iter_max, n_iter_parafac, 844 | do_conv, do_rwr, do_col, tol, 845 | size_list, gpu_id, 846 | verbose, run_init) 847 | 848 | 849 | 850 | return self.transform(schic, do_conv, do_rwr, do_col) -------------------------------------------------------------------------------- /fasthigashi/parafac_integrative.py: -------------------------------------------------------------------------------- 1 | import math, itertools, sys 2 | from tqdm.auto import tqdm, trange 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import torch 8 | from opt_einsum import contract 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | def update_factor(f, A, B): 13 | t = A.diagonal(dim1=-1, dim2=-2) 14 | t += 1e-10 15 | f[:] = torch.linalg.solve(A, B).T 16 | 17 | 18 | 19 | def balance_norm(factors): 20 | norm = 1 21 | for f in factors: 22 | norm = norm * torch.norm(f, dim=0) 23 | even_split = (norm + 1e-15)#.pow(1.0 / len(factors)) 24 | for i in range(len(factors)): 25 | factors[i] = factors[i] / (torch.norm(factors[i], dim=0) + 1e-15) #* even_split 26 | factors[-1] = factors[-1] * even_split 27 | 28 | def parafac( 29 | X, rank, n_iter_max=100, init=None, verbose=False, 30 | common_factor=None, 31 | ): 32 | context = dict(device=device, dtype=torch.float32) 33 | X = X.to(**context) 34 | ndim = len(X.shape) 35 | 36 | factors = init 37 | 38 | factors = [ 39 | factor.to(**context) for factor in factors 40 | ] 41 | 42 | 43 | def calc_A(factor, i): 44 | # A = torch.full([rank] * 2, 1. if share_factors[i] == 'shared' else (1.+l2_reg), **context) 45 | A = torch.full([rank] * 2, 1., **context) 46 | for j, f in enumerate(factor): 47 | if j != i: A.mul_(f.T @ f) 48 | return A 49 | 50 | formula_B = [ 51 | ','.join(['ijk'] + ['ijk'[j] + 'r' for j in range(ndim) if j != i]) + 52 | '->' + (('ijk'[i] + 'r')) 53 | for i in range(ndim) 54 | ] 55 | def calc_B(factor, i, out=None): 56 | B = torch.zeros(factor[i].shape, **context) 57 | B += contract( 58 | formula_B[i], 59 | X, 60 | *[factor[_] for _ in range(len(factor)) if _ != i], 61 | ) 62 | return B 63 | 64 | 65 | balance_norm(factors) 66 | 67 | history = [] 68 | if verbose: pbar = trange(n_iter_max) 69 | else: pbar = range(n_iter_max) 70 | 71 | iiter = None 72 | for iiter in pbar: 73 | iiter += 1 74 | for i in range(len(factors)): 75 | A = calc_A(factors, i) 76 | B = calc_B(factors, i) 77 | update_factor(factors[i], A, B.T) 78 | 79 | 80 | loss_recon, loss_reg, loss_x, loss_norm = 0, 0, 0, 0 81 | 82 | Xhat = contract('ir,jr,kr->ijk', *factors) 83 | loss_recon += torch.linalg.norm(X - Xhat).square_().item() 84 | loss_norm += torch.linalg.norm(Xhat).square_().item() 85 | 86 | loss_x += Xhat.mul(X).sum().item() 87 | loss = loss_recon 88 | 89 | history.append({ 90 | 'loss recon': loss_recon, 91 | 'loss': loss, 92 | 'loss x': loss_x, 93 | 'loss norm': loss_norm, 94 | }) 95 | 96 | if len(history) < 2: wdiff = np.nan 97 | else: wdiff = (history[-2]['loss'] - history[-1]['loss']) / history[-2]['loss'] 98 | 99 | if isinstance(pbar, tqdm): 100 | pbar.set_description( 101 | f"loss:{loss:.2e} = {loss_recon:.2e} + {loss_reg:.2e} " 102 | # f"norm:{norm_target.min().item():.2e} {norm_target.max().item():.2e} " 103 | f"%diff:{wdiff:.2e}" 104 | ) 105 | pbar.update(1) 106 | 107 | if wdiff < 1e-5: break 108 | balance_norm(factors) 109 | 110 | del X 111 | # print(f"{len(history)}, {wdiff:.1e};", end=' ') 112 | return factors, history[-1]['loss norm'], history[-1]['loss x'] 113 | 114 | -------------------------------------------------------------------------------- /fasthigashi/partial_rwr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import time 4 | import numpy as np 5 | torch.backends.cudnn.benchmark = True 6 | import torch.jit as jit 7 | 8 | def slice_arrange_func(x, slice_,strips=30): 9 | num_bin = int(x.shape[1]) 10 | index_dim0 = torch.tile(torch.arange(num_bin), (strips,)) 11 | index_dim1 = (torch.arange(strips).reshape((-1, 1)) + 12 | torch.arange(slice_.start, num_bin + slice_.start).reshape((1, -1))).reshape((-1)) 13 | index_dim1_ = torch.repeat_interleave(torch.arange(strips), num_bin) 14 | filter = index_dim1 < x.shape[2] 15 | 16 | a = torch.zeros(x.shape[0], x.shape[1], strips).float().to(x.device) 17 | a[:, index_dim0[filter], index_dim1_[filter]] = x[:, index_dim0[filter], index_dim1[filter]] 18 | return a 19 | 20 | 21 | def rec2tilte(x, strips): 22 | num_bin = int(x.shape[1]) 23 | index_dim0 = torch.tile(torch.arange(num_bin), (strips,)) 24 | index_dim1 = (torch.arange(strips).reshape((-1, 1)) + 25 | torch.arange(num_bin).reshape((1, -1))).reshape((-1)) 26 | index_dim1_ = torch.repeat_interleave(torch.arange(strips), num_bin) 27 | filter = index_dim1 < x.shape[2] 28 | 29 | a = torch.zeros(x.shape[0], x.shape[1], x.shape[1]+strips).float().to(x.device) 30 | a[:, index_dim0[filter], index_dim1[filter]] = x[:, index_dim0[filter], index_dim1_[filter]] 31 | return a 32 | 33 | def tilte2rec(x, strips): 34 | num_bin = int(x.shape[1]) 35 | index_dim0 = torch.tile(torch.arange(num_bin), (strips,)) 36 | index_dim1 = (torch.arange(strips).reshape((-1, 1)) + 37 | torch.arange(num_bin).reshape((1, -1))).reshape((-1)) 38 | index_dim1_ = torch.repeat_interleave(torch.arange(strips), num_bin) 39 | filter = index_dim1 < x.shape[2] 40 | 41 | a = torch.zeros(x.shape[0], x.shape[1], strips).float().to(x.device) 42 | a[:, index_dim0[filter], index_dim1_[filter]] = x[:, index_dim0[filter], index_dim1[filter]] 43 | return a 44 | 45 | @torch.no_grad() 46 | # @jit.script 47 | def partial_rwr(x: torch.Tensor, 48 | slice_start: int, 49 | slice_end: int, 50 | do_conv:bool, 51 | do_rwr:bool, 52 | do_col:bool, 53 | bin_cov:torch.Tensor=torch.ones(1), 54 | bin_cov_row:torch.Tensor=torch.ones(1), 55 | return_rwr_iter:bool=False, 56 | force_rwr_epochs:int=-1, 57 | final_transpose:bool=True, 58 | slice_arrange:bool=False, 59 | slice_arrange_size:int=100, 60 | **kw 61 | # compact=False, 62 | # flank:int=10, 63 | # max_dis:int=10 64 | ): 65 | 66 | # The slice_start / end has to be from the local_bin_slice_list not the global one. 67 | # slice_arrange: when true, after rwr, only store the elements that are within max_dis 68 | # The returned element would be of size: (bs_cell, bs_bin, 2*max_dis + 1) 69 | # compact: input tensor is in compact form or dense form. 70 | 71 | # if max_dis > flank and compact: 72 | # print ("max_dis has to be smaller than flank for compact matrix.") 73 | # raise EOFError 74 | 75 | n_iter = 0 76 | if do_conv or do_rwr: 77 | if do_conv and x.shape[1] > 1: 78 | pad = 1 79 | ll = pad * 2 + 1 80 | x = F.avg_pool2d(x[:, None], ll, 1, padding=pad, ceil_mode=True).clamp_(min=1e-8) 81 | x = x[:, 0, :, :] 82 | 83 | if do_rwr: 84 | A = x 85 | local_sim_2nd = (torch.bmm(A, A.permute(0, 2, 1))) # size of (bs, #bin1, #bin1) 86 | t = torch.diagonal(local_sim_2nd, dim1=-2, dim2=-1) 87 | t.zero_() 88 | local_sim_1st = A[:, :, slice_start:slice_end].clone() 89 | 90 | local_sim_2nd = (local_sim_2nd.div_(local_sim_2nd.sum(1, keepdim=True).add_(1e-15))) * 0.25 91 | local_sim_1st = (local_sim_1st.div_(local_sim_1st.sum(1, keepdim=True).add_(1e-15))) * 0.75 92 | 93 | local_sim = local_sim_1st.add_(local_sim_2nd) 94 | 95 | # # fill entries with zero coverage with 1 96 | t = torch.diagonal(local_sim, dim1=-2, dim2=-1) 97 | t[:] += local_sim.sum(1) == 0 98 | 99 | if force_rwr_epochs < 0: 100 | auto_stop = True 101 | rwr_epochs = 60 102 | else: 103 | rwr_epochs = force_rwr_epochs 104 | auto_stop = False 105 | 106 | # Copy paste code here for speed optimization 107 | # For input A of size x 108 | # P, Q, Q_new: 3x 109 | rp = 0.5 110 | ngene = local_sim.shape[1] 111 | P = local_sim.div_(local_sim.sum(1, keepdim=True).add_(1e-15)) 112 | Q = torch.eye(ngene, device=P.device)[None] 113 | Q = Q.repeat(local_sim.shape[0], 1, 1) 114 | epoch_count = 0 115 | for i in range(rwr_epochs): 116 | Q_new = rp * torch.bmm(Q, P) 117 | t = torch.diagonal(Q_new, dim1=-2, dim2=-1) 118 | t.add_(1 - rp) 119 | if auto_stop: 120 | delta = ((Q - Q_new).square_()).sum(dim=1).sum(dim=1).sqrt_() 121 | Q = Q_new 122 | if torch.max(delta) < 0.01: 123 | break 124 | else: 125 | Q = Q_new 126 | epoch_count += 1 127 | 128 | local_sim = Q 129 | n_iter = epoch_count 130 | 131 | if do_col: 132 | local_sim = (local_sim + local_sim.permute(0, 2, 1)) * 0.5 133 | local_sim = local_sim.clamp_(min=0.0) 134 | local_sim = local_sim.div_(local_sim.sum(2, keepdim=True).add_(1e-15)) 135 | A = A.div_(bin_cov[:, None, :].to(A.device)) 136 | 137 | 138 | x = torch.bmm(local_sim, A) 139 | 140 | # if slice_arrange: 141 | # # currect x: (bs_cell, bs_bin, all_bin or flank...) 142 | # num_bin = int(x.shape[1]) 143 | # 144 | # 145 | # 146 | # if compact: 147 | # # When compact, it's the easier, because the main diag is always flank~flank+bs_bin, flank~flank+bs_bin 148 | # # plus we require that max_dis < flank, we don't need to worry negative index or outof boundary index 149 | # row_index = torch.repeat_interleave(torch.arange(num_bin, device=x.device), 2*max_dis+1) 150 | # # It'll be [0,0,0,...0,1,1,...,1,2,2,...,2...] 151 | # 152 | # # for col, slices or (2*max_dis+1), needs to add flank such that starts at the true main diag, minus max_dis, such center is main_diag 153 | # # Then for each row, increase index by one... 154 | # col_index = (torch.arange(flank-max_dis, flank-max_dis+2*max_dis+1, device=x.device)).reshape((1, -1)) + \ 155 | # torch.arange(num_bin, device=x.device).reshape((-1, 1)) 156 | # col_index = col_index.reshape((-1)) 157 | # # Should be [0,1,2,3..., 1,2,3..,2,3,...] correspond to strips of first, second... 158 | # 159 | # else: 160 | # # When not compact things is a little more complicated. We will have cases where the value is out of index(too left or too right) 161 | # row_index = torch.repeat_interleave(torch.arange(num_bin, device=x.device), 2 * max_dis + 1) 162 | # # It'll be [0,0,0,...0,1,1,...,1,2,2,...,2...] 163 | # # For col, it should start at slice_start - max_dis ends at slice_start+max_dis+1 164 | # col_index = (torch.arange(slice_start - max_dis, slice_start+max_dis+1).reshape((1, -1)) + 165 | # torch.arange(num_bin, device=x.device).reshape((-1, 1))) 166 | # col_index = col_index.reshape((-1)).clamp_(min=0, max=x.shape[2]-1) 167 | # # Main different is that, we add clamp to address out of index problem. 168 | # # One Caveat, it will now copy out of index values instead of putting it as 0 169 | # x = x[:, row_index, col_index] 170 | 171 | if final_transpose: 172 | x = x.permute(1, 2, 0) 173 | if return_rwr_iter: 174 | return x, n_iter 175 | return x, 0 176 | 177 | 178 | -------------------------------------------------------------------------------- /fasthigashi/preprocessing.py: -------------------------------------------------------------------------------- 1 | import time, sys, itertools 2 | from tqdm.auto import tqdm, trange 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from scipy.sparse import coo_matrix 7 | 8 | from scipy.ndimage import gaussian_filter 9 | from scipy.sparse import csr_matrix, save_npz, diags, eye, vstack 10 | from sklearn.linear_model import LinearRegression 11 | 12 | # Include VC / VC_SQRT norm 13 | class NormalizerBin: 14 | def __init__(self, method='SQVC'): 15 | self.method = method 16 | self.vec = None 17 | 18 | def fit(self, bulk, eps=1e-10): 19 | row_sum = np.array(bulk.sum(0), copy=False).ravel() + eps 20 | if self.method == 'SQVC': 21 | self.vec = row_sum ** (-.5) 22 | elif self.method == 'VC': 23 | self.vec = row_sum ** -1 24 | else: 25 | raise NotImplementedError 26 | return self 27 | 28 | def transform(self, m, inplace=True): 29 | assert inplace 30 | if type(m).__name__ == 'coo_matrix': 31 | total = m.data.sum() 32 | m.data *= self.vec[m.col] * self.vec[m.row] 33 | m.data *= total / m.data.sum() 34 | elif type(m).__name__ == 'ndarray': 35 | total = m.sum() 36 | m *= self.vec[None, :] * self.vec[:, None] 37 | m *= total / m.sum() 38 | return m 39 | 40 | 41 | class NormalizerOE: 42 | def __init__(self): 43 | self.vec = None 44 | 45 | def fit(self, bulk): 46 | L = len(bulk) 47 | raise NotImplementedError 48 | vec = bulk.ravel()[:-1].reshape(L-1, L+1)[:, :-1].sum(0) 49 | vec[0] += bulk[-1, -1] 50 | self.vec = vec ** -1 51 | return self 52 | 53 | def transform(self, m, inplace=True): 54 | assert type(m).__name__ == 'coo_matrix' 55 | assert inplace 56 | total = m.data.sum() 57 | m.data *= self.vec[np.abs(m.col - m.row)] 58 | m.data *= total / m.data.sum() 59 | return m 60 | 61 | 62 | class Clip: 63 | def __init__(self, axis='entry', s=10.): 64 | self.thr = None 65 | self.axis = axis 66 | self.s = s 67 | 68 | def fit(self, matrix_list, bulk): 69 | bulk_0 = calc_bulk([coo_matrix((np.ones_like(m.data), (m.row, m.col)), shape=m.shape) for m in matrix_list]) 70 | bulk_0 += 1e-10 71 | bulk_2 = calc_bulk([coo_matrix((m.data**2, (m.row, m.col)), shape=m.shape) for m in matrix_list]) 72 | if self.axis == 'entry': 73 | bulk_1 = bulk / bulk_0 74 | bulk_2 /= bulk_0 75 | elif self.axis == 'row': 76 | bulk_1 = bulk.sum(1, keepdims=True) / bulk_0.sum(1, keepdims=True) 77 | bulk_2 = bulk_2.sum(1, keepdims=True) / bulk_0.sum(1, keepdims=True) 78 | else: raise NotImplementedError 79 | mean = bulk_1 80 | std = bulk_2 - bulk_1**2 81 | assert (std >= -1e-5).all(), std.min() 82 | std = np.maximum(0, std) ** .5 83 | self.thr = np.broadcast_to(mean + std * self.s, shape=bulk.shape) 84 | 85 | def transform(self, m, inplace=True): 86 | assert inplace 87 | if isinstance(m, np.ndarray): 88 | m[:] = np.clip(m.data, a_min=None, a_max=self.thr) 89 | else: 90 | thr = self.thr[m.row, m.col] 91 | cnt = m.data > thr 92 | m.data = np.clip(m.data, a_min=None, a_max=thr) 93 | assert (m.data > 0).all() 94 | return m 95 | 96 | 97 | def regress_out(x, y): 98 | # y = np.unique(y, return_inverse=True)[1] 99 | mean = pd.DataFrame(x).groupby(y).mean() 100 | x = x - mean.loc[y].values 101 | return x 102 | 103 | 104 | def quantile_normalization(x, y): 105 | y = np.unique(y, return_inverse=True)[1] 106 | z = np.empty_like(x) 107 | for c, df in pd.DataFrame(x).groupby(y): 108 | rank_mean = df.stack().groupby(df.rank(method='first').stack().astype(int)).mean() 109 | df = df.rank(method='min').stack().astype(int).map(rank_mean).unstack() 110 | z[df.index] = df.values 111 | return z 112 | 113 | 114 | def calc_bulk(matrix_list): 115 | # shape = matrix_list[0].shape 116 | # nnz = sum(m.nnz for m in matrix_list) 117 | # indices = np.empty([3, nnz], dtype=np.int16) 118 | # values = np.empty([nnz], dtype=np.float32) 119 | # del nnz 120 | # idx_nnz = 0 121 | # for i, m in enumerate(tqdm(matrix_list)): 122 | # idx = slice(idx_nnz, idx_nnz + m.nnz) 123 | # indices[0, idx] = m.row 124 | # indices[1, idx] = m.col 125 | # values[idx] = m.data 126 | # idx_nnz += m.nnz 127 | # del idx, m 128 | # bulk = coo_matrix((values[:idx_nnz], tuple(indices[:2, :idx_nnz])), shape) 129 | # del idx_nnz 130 | # bulk = np.array(bulk.todense()) 131 | # bulk /= len(matrix_list) 132 | # return bulk 133 | bulk = sum_sparse(matrix_list) 134 | return bulk / len(matrix_list) 135 | 136 | 137 | def normalize_by_coverage(m, mi=None, scale=None): 138 | scale = m.shape[0] if scale is None else scale 139 | if m is mi or mi is None: n = m.sum() 140 | else: n = m.sum() + mi.sum() 141 | m.data *= scale / (n + 1e-15) 142 | return m 143 | 144 | 145 | def normalize_by_coverage_clip(m, mi=None, scale=None, bulk=None): 146 | scale = m.shape[0] if scale is None else scale 147 | if m is mi or mi is None: n = m.sum() 148 | else: n = m.sum() + mi.sum() 149 | off_diag = (np.sum(m > 0) - np.sum(m.diagonal() > 0)) / 2 150 | if off_diag > m.shape[0]: 151 | m.data *= scale / (n + 1e-15) 152 | else: 153 | # print ("clip low cov_data") 154 | m.data *= scale / (n + 1e-15) 155 | return m 156 | 157 | def conv(m, *args, **kwargs): 158 | A = gaussian_filter((m).astype(np.float32).toarray(), 1, order=0, mode='mirror', truncate=1) 159 | return A 160 | 161 | def log1p_matrix(m): 162 | m.data = np.log1p(m.data) 163 | return m 164 | 165 | 166 | def half_main_diag(m, *args, **kwargs): 167 | m.data[m.col == m.row] /= 2 168 | return m 169 | def zero_main_diag(m, *args, **kwargs): 170 | m.data[m.col == m.row] = 0.0 171 | return m 172 | 173 | def normalize_per_cell( 174 | matrix_list, matrix_list_intra, bulk=None, per_cell_normalize_func=(), 175 | normalizers=(), 176 | ): 177 | if bulk is None: bulk = calc_bulk(matrix_list) 178 | # normalizers = [ 179 | # NormalizerBin(method='SQVC'), 180 | # NormalizerOE(), 181 | # ] 182 | for normalizer in normalizers: 183 | normalizer.fit(matrix_list=matrix_list, bulk=bulk) 184 | bulk = normalizer.transform(bulk) 185 | for i, (m, mi) in enumerate(zip(matrix_list, matrix_list_intra)): 186 | for normalizer in normalizers: 187 | normalizer.transform(m) 188 | for func in per_cell_normalize_func: 189 | m = func(m, mi) 190 | matrix_list[i] = m 191 | 192 | return matrix_list 193 | 194 | 195 | def norm2(mtx_list, info, info2, bk_cov): 196 | mtx_list, batch = mtx_list 197 | for i, m in enumerate(mtx_list): 198 | row, col, data = m.row, m.col, m.data 199 | distance = np.abs(row - col).astype('int') 200 | # if multihic: 201 | # distance = np.ceil((np.sqrt(8 * distance + 1) - 1) / 2).astype('int') 202 | ratio = info[batch[i]] 203 | cov = info2[batch[i]] 204 | data = data / (np.sqrt(cov[row]) * np.sqrt(cov[col])) * (np.sqrt(bk_cov[row]) * np.sqrt(bk_cov[col])) 205 | 206 | d = distance 207 | # divided by batch ratio 208 | new_data = data / (ratio[d] + 1e-15) 209 | 210 | m.data = new_data 211 | mtx_list[i] = m 212 | 213 | 214 | 215 | return mtx_list 216 | 217 | 218 | def sum_sparse(m): 219 | x = np.zeros(m[0].shape) 220 | for a in m: 221 | x[a.row, a.col] += a.data 222 | return x 223 | 224 | 225 | def sum_sparse_by_batch_id(m_list, batch): 226 | avail_batch = np.unique(batch) 227 | return_dict = {b:np.zeros(m_list[0].shape) for b in avail_batch} 228 | for a,b in zip(m_list, batch): 229 | return_dict[b][a.row, a.col] += a.data 230 | return return_dict 231 | 232 | def normalize_per_batch(bulk, batch_bulk, matrix_list, batch_id, off_diag): 233 | info = {} 234 | info2 = {} 235 | matrix_list = np.array(matrix_list) 236 | bk_cov = bulk.sum(axis=-1) 237 | bulk /= (np.sqrt(bk_cov[None]) + 1e-15) 238 | bulk /= (np.sqrt(bk_cov[:, None]) + 1e-15) 239 | bk_sum = bulk.sum() 240 | 241 | import math 242 | # max_size = int(math.ceil((math.sqrt(8*(bulk.shape[0]-1) + 1) - 1) / 2)) + 1 243 | # if multihic: 244 | # bulk_ratio = np.zeros((max_size)) 245 | # else: 246 | bulk_ratio = np.zeros((off_diag)) 247 | for k in range(off_diag): 248 | a = np.diagonal(bulk,k).sum() if k==0 else np.diagonal(bulk,k).sum() * 2 249 | # id_ = int(math.ceil((math.sqrt(8*k + 1) - 1) / 2)) if multihic else k 250 | id_ = k 251 | bulk_ratio[id_] += a 252 | 253 | bulk_ratio = np.array(bulk_ratio) / bk_sum 254 | for b in batch_bulk.keys(): 255 | m = batch_bulk[b] 256 | m_cov = m.sum(axis=-1) 257 | # 258 | m = m / (np.sqrt(m_cov[None]) + 1e-15) 259 | m = m / (np.sqrt(m_cov[:, None]) + 1e-15) 260 | # 261 | m_sum = m.sum() 262 | 263 | # if multihic: 264 | # ratio = np.zeros((max_size)) 265 | # else: 266 | ratio = np.zeros((off_diag)) 267 | for k in range(off_diag): 268 | a = np.diagonal(m, k).sum() if k == 0 else np.diagonal(m, k).sum() * 2 269 | # id_ = int(math.ceil((math.sqrt(8 * k + 1) - 1) / 2)) if multihic else k 270 | id_ = k 271 | ratio[id_] += a 272 | 273 | ratio = np.array(ratio) / (m_sum + 1e-15) 274 | 275 | info[b] = ratio / (bulk_ratio + 1e-15) 276 | info2[b] = np.array(m_cov) 277 | 278 | from tqdm.contrib.concurrent import process_map 279 | # from functools import partial 280 | from multiprocessing import Pool 281 | 282 | # func = partial(norm2, info=info, info2=info2, bk_cov=bk_cov) 283 | # batch_num = int(len(matrix_list) / 250) 284 | matrix_list = norm2((matrix_list, batch_id), info, info2, bk_cov) 285 | # matrix_list = np.array_split(matrix_list, batch_num) 286 | # batch_id = np.array_split(batch_id, batch_num) 287 | # # p = Pool(batch_num) 288 | # matrix_list = process_map(func, zip(matrix_list,batch_id), max_workers=batch_num, total=len(matrix_list)) 289 | # # matrix_list = p.map(func, zip(matrix_list,batch_id)) 290 | # matrix_list = np.concatenate(matrix_list, axis=0) 291 | 292 | return list(matrix_list) 293 | 294 | 295 | def reformat_input(matrix_list, config, valid_bin=None, off_diag=None, fac_size=None, loss_distribution='Gaussian', sparse=False): 296 | m = matrix_list[0] 297 | if off_diag is None: 298 | off_diag = int(50000000 / config['resolution']) 299 | 300 | if fac_size is None: 301 | fac_size = int(300000 / config['resolution']) 302 | if fac_size <= 2: 303 | fac_size = 1 304 | 305 | if valid_bin is None: valid_bin = np.ones(m.shape[0], dtype=bool) 306 | 307 | nnz = sum(m.nnz for m in matrix_list) 308 | indices = np.empty([3, nnz], dtype=np.int16) 309 | values = np.empty([nnz], dtype=np.float32) 310 | del nnz 311 | 312 | patch_size = min(2 * off_diag + 1, m.shape[1]) 313 | shape = (m.shape[0], patch_size) 314 | size_l = patch_size // 2 315 | size_r = (patch_size + 1) // 2 316 | idxs = np.mgrid[[slice(0, s) for s in [m.shape[0], patch_size]]] 317 | mask = (idxs[0] + idxs[1] >= size_l) & (idxs[0] + idxs[1] < sum(shape)-size_r) 318 | if sparse: 319 | new_a = None 320 | else: 321 | indices = None 322 | values = None 323 | new_a = np.empty(shape + (len(matrix_list),), dtype=np.float32) 324 | 325 | idx_nnz = 0 326 | for i, m in enumerate(tqdm(matrix_list)): 327 | # if loss_distribution in ['Gaussian', 'ZIG']: 328 | # pass 329 | # elif loss_distribution in ['NB']: 330 | # pass 331 | # else: 332 | # raise ValueError 333 | 334 | row, col, data = m.row, m.col, m.data 335 | col_new = col - row 336 | idx = valid_bin[row] & valid_bin[col] & (col_new >= -size_l) & (col_new < size_r) 337 | row, col, data = row[idx], col_new[idx], data[idx] 338 | col += size_l 339 | del col_new 340 | 341 | if sparse: 342 | nnz = len(data) 343 | ii = slice(idx_nnz, idx_nnz + nnz) 344 | indices[0, ii] = row 345 | indices[1, ii] = col 346 | indices[2, ii] = i 347 | values[ii] = data 348 | idx_nnz += nnz 349 | del nnz, ii 350 | else: 351 | new_a[..., i] = coo_matrix((data.astype(np.float32), (row, col)), shape=shape).todense() 352 | 353 | if sparse: 354 | return ( 355 | np.ascontiguousarray(indices[:, :idx_nnz]), 356 | np.ascontiguousarray(values[:idx_nnz]), 357 | mask.shape + (len(matrix_list),) 358 | ), mask 359 | else: 360 | return new_a, mask 361 | 362 | 363 | def correct_batch_effect_pre(matrix_list, data_list): 364 | indices, values, shape = matrix_list 365 | if np.all(np.diff(indices[2]) >= 0): 366 | data_list = [0] + list(np.searchsorted(indices[2], [dl_slice.stop for dl_slice in data_list], side='right')) 367 | data_list = [slice(*_) for _ in zip(data_list[:-1], data_list[1:])] 368 | else: raise NotImplementedError 369 | bulks = [ 370 | np.array(coo_matrix((values[slice_], tuple(indices[:2, slice_])), tuple(shape[:2])).todense()) 371 | / (slice_.stop - slice_.start) 372 | for slice_ in data_list 373 | ] 374 | cols_avg = [bulk.mean(0) for bulk in bulks] 375 | rows_avg = [bulk.mean(1) for bulk in bulks] 376 | avg_func = lambda x: np.mean(x, 0) 377 | # def avg_func(x): 378 | # x = np.stack(x) 379 | # idx = x > 0 380 | # x[~idx] = 1 381 | # y = np.exp(np.log(x).sum(0) / idx.sum(0)) 382 | # return y 383 | # col_avg = cols_avg[0] 384 | # row_avg = rows_avg[0] 385 | # col_avg = np.mean(cols_avg, axis=0) 386 | # row_avg = np.mean(rows_avg, axis=0) 387 | col_avg = avg_func(cols_avg) 388 | row_avg = avg_func(rows_avg) 389 | factors_col = [col_avg / avg for avg in cols_avg] 390 | factors_row = [row_avg / avg for avg in rows_avg] 391 | for slice_, factor_col, factor_row in zip(data_list, factors_col, factors_row): 392 | values[slice_] *= factor_col[indices[1, slice_]] 393 | # values[slice_] *= factor_row[indices[0, slice_]] 394 | assert not np.isnan(values).any() 395 | return matrix_list 396 | 397 | 398 | def downsample(matrix_list, data_list, bulk_list=None, mode='stratum', rate_mode='minimum'): 399 | def slicing(a, i, tolist=True): 400 | if isinstance(i, slice): return a[i] 401 | ret = itertools.compress(a, i) 402 | if tolist: return list(ret) 403 | else: return ret 404 | if bulk_list is None: bulk_list = [calc_bulk(slicing(matrix_list, slc)) for slc in data_list] 405 | L = len(bulk_list[0]) 406 | obs_list = [] 407 | if mode == 'global': 408 | obs_list = [[bulk.sum()] for bulk in bulk_list] 409 | elif mode == 'stratum': 410 | for bulk in bulk_list: 411 | tmp = bulk.copy().ravel()[:-1].reshape(L-1, L+1) 412 | np.cumsum(tmp, axis=0, out=tmp) 413 | obs = np.empty(L) 414 | obs[1:] = tmp.ravel()[L-1::L][::-1] 415 | obs[0] = tmp[-1, 0] + bulk[-1, -1] 416 | assert np.isclose(obs[0], np.diag(bulk).sum(), atol=1e-2, rtol=1e-5) 417 | assert np.isclose(obs[1], np.diag(bulk, 1).sum(), atol=1e-2, rtol=1e-5) 418 | obs_list.append(obs) 419 | del tmp 420 | else: raise NotImplementedError 421 | obs_list = np.array(obs_list) 422 | if rate_mode == 'minimum': 423 | target = obs_list.copy() 424 | print(f'# of empty entries = {(target == 0).any(0).sum()}') 425 | sys.stdout.flush() 426 | target[target == 0] = np.nan 427 | target = np.nanmin(target, 0) 428 | assert not (target <= 0).any() 429 | else: raise NotImplementedError 430 | for slc, obs in zip(data_list, obs_list): 431 | rate = target / obs # It's ok to have nan, because these entries won't be used 432 | if (np.nan_to_num(rate, 1.) == 1).all(): continue 433 | assert (rate[obs > 0] > 0).all() 434 | assert (rate[obs > 0] <= 1).all() 435 | for matrix in tqdm(slicing(matrix_list, slc, tolist=False)): 436 | mask_u = matrix.row > matrix.col 437 | mask_l = matrix.row < matrix.col 438 | assert mask_u.sum() == mask_l.sum() 439 | mask = ~mask_l 440 | if mode == 'global': r = rate 441 | elif mode == 'stratum': r = rate[matrix.row[mask] - matrix.col[mask]] 442 | else: raise NotImplementedError 443 | matrix.data[mask] = np.random.binomial(matrix.data[mask].astype(int), r).astype(np.float32) 444 | matrix.col[mask_l] = matrix.row[mask_u] 445 | matrix.row[mask_l] = matrix.col[mask_u] 446 | matrix.data[mask_l] = matrix.data[mask_u] 447 | matrix.eliminate_zeros() 448 | assert not np.isnan(matrix.data).any() 449 | # t = matrix.todense() 450 | # assert (t == t.T).all() 451 | bulk_list = [calc_bulk(slicing(matrix_list, slc)) for slc in data_list] 452 | library_size_list = [bulk.sum() for bulk in bulk_list] 453 | print(f'library sizes =', ' '.join(map('{:.2e}'.format, library_size_list))) 454 | 455 | 456 | def downsample_clip(matrix_list, count, mode='global'): 457 | assert mode == 'global' 458 | for m in matrix_list: 459 | c = m.data.sum() 460 | if c <= count: continue 461 | r = count / c 462 | mask_u = m.row > m.col 463 | mask_l = m.row < m.col 464 | assert mask_u.sum() == mask_l.sum() 465 | mask = ~mask_l 466 | m.data[mask] = np.random.binomial(m.data[mask].astype(int), r).astype(np.float32) 467 | m.col[mask_l] = m.row[mask_u] 468 | m.row[mask_l] = m.col[mask_u] 469 | m.data[mask_l] = m.data[mask_u] 470 | m.eliminate_zeros() 471 | assert not np.isnan(m.data).any() 472 | 473 | 474 | def filter_bin(matrix_list=None, bulk=None, is_sym=True): 475 | if bulk is None: bulk = calc_bulk(matrix_list) 476 | 477 | def get_mapping(c, l): 478 | v = c > min(0., 0.01 * l) 479 | m = np.cumsum(v) - 1 480 | m[~v] = -1 481 | n = v.sum() 482 | # print(f'{n} out of {len(c)} bins are valid') 483 | return m, n, v 484 | bin_id_mapping_row, num_bins_row, v_row = get_mapping(bulk.sum(1), bulk.shape[1]) 485 | if is_sym: 486 | bin_id_mapping_col, num_bins_col, v_col = bin_id_mapping_row, num_bins_row, v_row 487 | else: 488 | bin_id_mapping_col, num_bins_col, v_col = get_mapping(bulk.sum(0), bulk.shape[0]) 489 | return bin_id_mapping_row, num_bins_row, bin_id_mapping_col, num_bins_col, v_row, v_col 490 | 491 | 492 | def slice_rearrange(matrix, size, fac_size): 493 | new_m = [] 494 | patch_size = min(2 * size + 1, matrix.shape[1]) 495 | if matrix.shape[-1] <= patch_size: 496 | return matrix 497 | for i in range(matrix.shape[0]): 498 | temp = matrix[i, max(int(i / fac_size) - size, 0):min(int(i / fac_size) + size + 1, matrix.shape[0])] 499 | if len(temp) == 0: 500 | print(i - size, i + size + 1, matrix.shape) 501 | raise EOFError 502 | if len(temp) < patch_size: 503 | temp = np.concatenate([temp, np.zeros(patch_size - len(temp))]) 504 | new_m.append(temp) 505 | matrix = np.stack(new_m) 506 | return matrix 507 | 508 | 509 | def kth_diag_indices(a, k): 510 | rows, cols = np.diag_indices_from(a) 511 | if k < 0: 512 | return rows[-k:], cols[:k] 513 | elif k > 0: 514 | return rows[:-k], cols[k:] 515 | else: 516 | return rows, cols 517 | 518 | 519 | def get_expected(matrix): 520 | expected = [] 521 | for k in range(len(matrix)): 522 | diag = np.diag(matrix, k) 523 | expected.append(np.mean(diag)) 524 | return np.array(expected) 525 | 526 | 527 | def oe(matrix, expected=None): 528 | new_matrix = np.zeros_like(matrix) 529 | for k in range(len(matrix)): 530 | rows, cols = kth_diag_indices(matrix, k) 531 | diag = np.diag(matrix, k) 532 | if expected is not None: 533 | expect = expected[k] 534 | else: 535 | expect = np.mean(diag) 536 | if expect == 0: 537 | new_matrix[rows, cols] = 0.0 538 | else: 539 | new_matrix[rows, cols] = diag / (expect) 540 | new_matrix = new_matrix + new_matrix.T - np.diag(np.diagonal(new_matrix)) 541 | return new_matrix 542 | -------------------------------------------------------------------------------- /fasthigashi/project2orthogonal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.jit as jit 3 | import numpy as np 4 | from functools import partial 5 | 6 | def project2orthogonal(matrix: torch.Tensor, rank:int, compute_device:torch.device): 7 | dim_1, dim_2 = matrix.shape[-2], matrix.shape[-1] 8 | if rank is None: rank = min(matrix.shape[-2:]) 9 | try: 10 | if matrix.shape[-2] / matrix.shape[-1] >= 0.5: 11 | try: 12 | U, S, Vh = torch.linalg.svd(matrix, full_matrices=False, driver='gesvda') 13 | final = U[..., :rank].to(compute_device) @ Vh[..., :rank, :].to(compute_device) 14 | if torch.any(torch.isnan(final)) or torch.any(torch.isinf(final)): 15 | U, S, Vh = torch.linalg.svd(matrix, full_matrices=False) 16 | final = U[..., :rank].to(compute_device) @ Vh[..., :rank, :].to(compute_device) 17 | except Exception as e: 18 | U, S, Vh = torch.linalg.svd(matrix, full_matrices=False) 19 | final = U[..., :rank].to(compute_device) @ Vh[..., :rank, :].to(compute_device) 20 | else: 21 | U, S, Vh = torch.linalg.svd(matrix, full_matrices=False) 22 | final = U[..., :rank].to(compute_device) @ Vh[..., :rank, :].to(compute_device) 23 | # a = U[..., :rank].to(compute_device) @ Vh[..., :rank, :].to(compute_device) 24 | # if torch.sum(torch.isnan(a)) > 0: 25 | # print("a", a, torch.sum(torch.isnan(a)), a.shape) 26 | if torch.any(torch.isnan(final)): 27 | print("gesvd & default failed") 28 | raise BaseException 29 | return final, S[..., :rank] 30 | except Exception as e: 31 | print(f'error {e}. using eigh, shape = {matrix.shape}') 32 | 33 | kk = 1e-2 34 | U, S, V, Vh = None, None, None, None 35 | mode = dim_2 > dim_1 36 | X = matrix @ matrix.transpose(-1, -2) if mode else matrix.transpose(-1, -2) @ matrix 37 | t = X.diagonal(dim1=-1, dim2=-2) 38 | t += kk 39 | del t 40 | eigvals, eigvecs = torch.linalg.eigh(X) 41 | eigvecs = eigvecs[..., -rank:].flip(-1) 42 | if mode: 43 | U = eigvecs 44 | Vh = U.transpose(-1, -2) @ matrix 45 | V, R = torch.linalg.qr(Vh.transpose(-1, -2)) 46 | V = V.mul_(R.diagonal(dim1=-1, dim2=-2).sign()[..., None, :]) 47 | else: 48 | V = eigvecs 49 | U = matrix @ V 50 | U, R = torch.linalg.qr(U) 51 | U = U.mul_(R.diagonal(dim1=-1, dim2=-2).sign()[..., None, :]) 52 | UVh = U @ V.transpose(-1, -2) 53 | assert (U.transpose(-1, -2) @ U - torch.eye(U.shape[-1], device=U.device)).max().item() < 1e-4 54 | assert (V.transpose(-1, -2) @ V - torch.eye(V.shape[-1], device=V.device)).max().item() < 1e-4 55 | return UVh, S 56 | 57 | 58 | # @jit.script 59 | # def project2orthogonal(matrix: torch.Tensor, rank:int, compute_device:torch.device): 60 | # dim_1, dim_2 = matrix.shape[-2], matrix.shape[-1] 61 | # if rank is None: rank = min(matrix.shape[-2:]) 62 | # # try: 63 | # U, S, Vh = torch.linalg.svd(matrix, full_matrices=False) 64 | # return U[..., :rank].to(compute_device) @ Vh[..., :rank, :].to(compute_device), S[..., :rank] 65 | # 66 | # @jit.script 67 | # def project2orthogonal_ill(matrix: torch.Tensor, rank:int, compute_device:torch.device): 68 | # dim_1, dim_2 = matrix.shape[-2], matrix.shape[-1] 69 | # if rank is None: rank = min(matrix.shape[-2:]) 70 | # # except Exception as e: 71 | # print('ill conditioned matrix. using eigh, shape = {matrix.shape}') 72 | # 73 | # kk = 1e-2 74 | # # U, S, V, Vh = None, None, None, None 75 | # mode = dim_2 > dim_1 76 | # X = matrix @ matrix.transpose(-1, -2) if mode else matrix.transpose(-1, -2) @ matrix 77 | # t = X.diagonal(dim1=-1, dim2=-2) 78 | # t += kk 79 | # del t 80 | # eigvals, eigvecs = torch.linalg.eigh(X) 81 | # eigvecs = eigvecs[..., -rank:].flip(-1) 82 | # if mode: 83 | # U = eigvecs 84 | # Vh = U.transpose(-1, -2) @ matrix 85 | # V, R = torch.linalg.qr(Vh.transpose(-1, -2)) 86 | # V = V.mul_(R.diagonal(dim1=-1, dim2=-2).sign()[..., None, :]) 87 | # else: 88 | # V = eigvecs 89 | # U = matrix @ V 90 | # U, R = torch.linalg.qr(U) 91 | # U = U.mul_(R.diagonal(dim1=-1, dim2=-2).sign()[..., None, :]) 92 | # UVh = U @ V.transpose(-1, -2) 93 | # assert (U.transpose(-1, -2) @ U - torch.eye(U.shape[-1], device=U.device)).max().item() < 1e-4 94 | # assert (V.transpose(-1, -2) @ V - torch.eye(V.shape[-1], device=V.device)).max().item() < 1e-4 95 | # return UVh.to(compute_device), None 96 | # 97 | 98 | def torch_svd_eigh(matrix, rank=None): 99 | if rank is None: rank = min(matrix.shape[-2:]) 100 | dim_1, dim_2 = matrix.shape[-2], matrix.shape[-1] 101 | kk = 1 102 | kkk = 1e-10 103 | if dim_2 > dim_1: 104 | X = matrix @ matrix.transpose(-1, -2) 105 | t = X.diagonal(dim1=-1, dim2=-2) 106 | t += kk 107 | del t 108 | S, U = torch.linalg.eigh(X) 109 | S = S[..., -rank:].flip(-1).sub_(kk).clip_(min=kkk).sqrt_() 110 | U = U[..., -rank:].flip(-1) 111 | Vh = (U.transpose(-1, -2) @ matrix).div_(S[..., None]) 112 | else: 113 | X = matrix.transpose(-1, -2) @ matrix 114 | t = X.diagonal(dim1=-1, dim2=-2) 115 | t += kk 116 | del t 117 | S, V = torch.linalg.eigh(X) 118 | S = S[..., -rank:].flip(-1).sub_(kk).clip_(min=kkk).sqrt_() 119 | V = V[..., -rank:].flip(-1) 120 | U = (matrix @ V).div_(S[..., None, :]) 121 | Vh = V.transpose(-1, -2) 122 | # assert torch.linalg.norm(U * S[..., None, :] @ Vh - matrix) / torch.linalg.norm(matrix) < 1e-6 123 | assert torch.linalg.norm(U.transpose(-1, -2) @ U - torch.eye(U.shape[-1], device=U.device)) / len(matrix) < 1e-6 124 | assert torch.linalg.norm(Vh @ Vh.transpose(-1, -2) - torch.eye(Vh.shape[-2], device=Vh.device)) / len(matrix) < 1e-6 125 | return U, S, Vh -------------------------------------------------------------------------------- /fasthigashi/sparse_for_schic.py: -------------------------------------------------------------------------------- 1 | import copy, time 2 | import gc 3 | import math 4 | import numpy as np 5 | from scipy.sparse._sparsetools import coo_tocsr 6 | from scipy.sparse import coo_matrix, csr_matrix 7 | import torch 8 | from typing import Dict, List, Tuple 9 | from tqdm import tqdm, trange 10 | #TODO: use numpy.unravel_index 11 | 12 | gpu_flag = torch.cuda.is_available() 13 | # gpu_flag = False 14 | 15 | def calc_offset(dims): 16 | offset = np.ones_like(dims) 17 | offset[:-1] = np.cumprod(dims[:0:-1])[::-1] 18 | return offset 19 | 20 | 21 | # May consider negative values in the future. 22 | def get_minimum_dtype(max_value): 23 | for dtype in [np.int8, np.int16, np.int32, np.int64]: 24 | if max_value <= np.iinfo(dtype).max: return dtype 25 | raise ValueError 26 | 27 | 28 | def unfold(indices, shape): 29 | offset = calc_offset(shape) 30 | indices = indices.astype(shape.dtype, copy=False) 31 | indices = offset @ indices 32 | # indices = offset[:-1] @ indices[:-1] + indices[-1] 33 | # indices = (indices * offset[:, None]).sum(0) 34 | assert (indices >= 0).all() 35 | assert (indices < np.prod(shape)).all() 36 | return indices 37 | 38 | 39 | def fold(indices, shape, out=None): 40 | offset = calc_offset(shape) 41 | # indices = indices[None] // offset[:, None] 42 | # indices[1:] %= offset[:-1, None] 43 | if out is None: indices_new = np.empty([len(shape), len(indices)], dtype=indices.dtype) 44 | else: indices_new = out 45 | del out 46 | # out[0] = indices 47 | # indices = out 48 | # del out 49 | # for i in range(len(shape)-1): 50 | # indices[i], indices[i+1] = divmod(indices[i], offset[i]) 51 | for i in range(len(shape)-1): 52 | indices_new[i], indices = divmod(indices, offset[i]) 53 | indices_new[-1] = indices 54 | return indices_new 55 | 56 | 57 | # Some memory are shared. Need caution. 58 | class Sparse: 59 | def __init__( 60 | self, indices, values, shape, indptr=None, copy=True, verbose=False, 61 | ): 62 | self.ndim = len(shape) 63 | self.indices = np.array(indices, copy=copy) 64 | self.values = np.array(values, copy=copy) 65 | self.shape = np.array(shape, copy=copy) 66 | self.indptr = None if indptr is None else np.array(indptr, copy=copy) 67 | self.verbose = verbose 68 | assert self.indices.shape == (self.ndim, len(self.values)) 69 | assert (self.indices >= 0).all() 70 | assert (self.indices < self.shape[:, None]).all(), (self.indices.max(1), self.shape) 71 | 72 | def scale(self, factor=1): 73 | self.shape = [int(math.ceil(self.shape[0] / factor)), 74 | int(math.ceil(self.shape[1] / factor)), 75 | self.shape[2] 76 | ] 77 | self.sort_indices() 78 | self.indices[0] = np.floor(self.indices[0] / factor) 79 | self.indices[1] = np.floor(self.indices[1] / factor) 80 | self.indptr = None 81 | 82 | #sum_duplicates 83 | unique, inv, unique_counts = np.unique(self.indices.T, axis=0, return_inverse=True, return_counts=True) 84 | new_count = np.zeros_like(unique_counts, dtype='float32') 85 | for i, iv in enumerate(inv): 86 | new_count[iv] += self.values[i] 87 | self.indices = unique.T 88 | self.values = new_count 89 | return 90 | 91 | def filter_max_distance(self, max_distance=100): 92 | distance = np.abs(self.indices[1] - self.indices[0]) 93 | mask = distance <= max_distance 94 | self.indices = np.asarray([a[mask] for a in self.indices]) 95 | self.values = self.values[mask] 96 | self.indptr = None 97 | 98 | def permute(self, *dims, inplace=False): 99 | assert tuple(sorted(dims)) == tuple(range(self.ndim)) 100 | dims = np.array(dims) 101 | indices = self.indices[dims] 102 | shape = self.shape[dims] 103 | indptr = self.indptr if dims[0] == 0 else None 104 | if inplace: 105 | self.indices = indices 106 | self.shape = shape 107 | self.indptr = indptr 108 | return self 109 | else: 110 | return Sparse(indices, self.values, shape, indptr, verbose=self.verbose) 111 | 112 | def reshape(self, *dims, inplace=False): 113 | dims = np.array(dims) 114 | if (dims == -1).any(): 115 | assert (dims == -1).sum() <= 1, dims 116 | assert self.shape.prod() % dims.prod() == 0 117 | dims[dims == -1] = self.shape.prod(dtype=self.shape.dtype) // -dims.prod(dtype=self.shape.dtype) 118 | assert self.shape.prod() == dims.prod() 119 | indices = self.indices 120 | indices = unfold(indices, self.shape) 121 | indices = fold(indices, dims) 122 | if inplace: 123 | self.indices = indices 124 | self.shape = dims 125 | self.ndim = len(dims) 126 | return self 127 | else: 128 | return Sparse( 129 | indices, self.values, dims, verbose=self.verbose, 130 | indptr=self.indptr if dims[0] == 0 else None, 131 | ) 132 | 133 | def sort_indices(self, dim=0, force=False): 134 | assert 0 <= dim < self.ndim 135 | assert dim == 0 136 | if not force and self.indptr is not None: return 137 | d = np.diff(self.indices[dim]) 138 | if (d >= 0).all(): 139 | # print('indices are sorted') 140 | indptr = np.full(self.shape[dim]+1, len(self.values)+1, dtype=int) 141 | # indptr[0] = 0 142 | idx = np.concatenate([[0], np.nonzero(d != 0)[0]+1]) 143 | indptr[self.indices[dim, idx]] = idx 144 | indptr[-1] = len(self.values) 145 | np.minimum.accumulate(indptr[::-1], out=indptr[::-1]) 146 | # for i, (l, r) in enumerate(zip(indptr[:-1], indptr[1:])): 147 | # assert (self.indices[dim, l:r] == i).all() 148 | # assert (np.diff(indptr) >= 0).all() 149 | self.indptr = indptr 150 | 151 | return 152 | # print('sorting indices') 153 | _t = time.perf_counter() 154 | dim_other = np.arange(self.ndim) != dim 155 | row = self.indices[dim] 156 | col = self.indices[dim_other] 157 | # print(f'time elapsed = {time.perf_counter() - _t:.2e}') 158 | col = unfold(col, self.shape[dim_other]) 159 | # print(f'time elapsed = {time.perf_counter() - _t:.2e}') 160 | M = self.shape[dim] 161 | N = col.max()+1 162 | indptr = np.empty(M + 1, dtype=col.dtype) 163 | indices = np.empty_like(col) 164 | values = np.empty_like(self.values) 165 | # print(f'time elapsed = {time.perf_counter() - _t:.2e}') 166 | coo_tocsr( 167 | M, N, len(self.values), row, col, self.values, 168 | indptr, indices, values, 169 | ) 170 | # print(f'time elapsed = {time.perf_counter() - _t:.2e}') 171 | self.indices[dim] = np.repeat(np.arange(M), np.diff(indptr)) 172 | # print(f'time elapsed = {time.perf_counter() - _t:.2e}') 173 | self.indices[dim_other] = fold(indices, self.shape[dim_other]) 174 | # print(f'time elapsed = {time.perf_counter() - _t:.2e}') 175 | self.values[:] = values 176 | # self.values = values 177 | self.indptr = indptr 178 | # print(f'time elapsed = {time.perf_counter() - _t:.2e}') 179 | # for i, (l, r) in enumerate(zip(indptr[:-1], indptr[1:])): 180 | # assert (self.indices[dim, l:r] == i).all() 181 | if self.verbose: 182 | print(f'time used in sorting indices = {time.perf_counter() - _t:.2e}') 183 | 184 | def slicing(self, idx, dim=0): 185 | assert dim == 0 186 | assert isinstance(idx, slice) 187 | assert idx.step is None or idx.step == 1 188 | if self.indptr is None: self.sort_indices(dim) 189 | start = idx.start if idx.start is not None else 0 190 | stop = idx.stop if idx.stop is not None else self.shape[dim] 191 | stop = min(stop, self.shape[dim]) 192 | indptr = self.indptr[start: stop+1] 193 | idx = slice(indptr[0], indptr[-1]) 194 | indices = self.indices[:, idx].copy() 195 | # assert (indices[dim] >= start).all() 196 | # assert (indices[dim] < stop).all() 197 | indices[dim] -= start 198 | return Sparse(indices, self.values[idx], (stop-start,) + tuple(self.shape[1:])) 199 | 200 | @torch.no_grad() 201 | def get_slice_idx_value(self, idx, dim=0, device='cpu'): 202 | assert dim == 0 203 | assert isinstance(idx, slice) 204 | assert idx.step is None or idx.step == 1 205 | if self.indptr is None: self.sort_indices(dim) 206 | start = idx.start if idx.start is not None else 0 207 | stop = idx.stop if idx.stop is not None else self.shape[dim] 208 | stop = min(stop, self.shape[dim]) 209 | indptr = self.indptr[start: stop + 1] 210 | idx = slice(indptr[0], indptr[-1]) 211 | indices = self.indices[:, idx] 212 | v = self.values[idx] 213 | return indices, v, (stop-start,) + tuple(self.shape[1:]), start 214 | 215 | def indexing(self, idx, dim=0): 216 | assert dim == 0 217 | assert isinstance(idx, int) 218 | assert 0 <= idx < self.shape[dim], (self.shape, dim, idx) 219 | if self.indptr is None: self.sort_indices(dim) 220 | idx = slice(self.indptr[idx], self.indptr[idx+1]) 221 | return Sparse(self.indices[1:, idx], self.values[idx], tuple(self.shape[1:])) 222 | 223 | def __getitem__(self, item): 224 | if isinstance(item, int): return self.indexing(item) 225 | elif isinstance(item, slice): return self.slicing(item) 226 | else: raise NotImplementedError 227 | 228 | def to_dense(self): 229 | v = np.zeros(tuple(self.shape)) 230 | indices = tuple(self.indices) 231 | values = self.values 232 | v[indices] = values 233 | return v 234 | 235 | def to_scipy(self): 236 | return coo_matrix((self.values, tuple(self.indices)), tuple(self.shape)) 237 | 238 | def to_csr(self): 239 | return csr_matrix((self.values, tuple(self.indices)), tuple(self.shape)) 240 | 241 | def to_pytorch(self, **context): 242 | return torch.sparse_coo_tensor( 243 | # np.ascontiguousarray(self.indices), 244 | # np.ascontiguousarray(self.values), 245 | self.indices, 246 | self.values, 247 | self.shape.tolist(), 248 | **context, 249 | ) 250 | 251 | def __len__(self): 252 | return self.shape[0] 253 | 254 | def split(self, bins, dim): 255 | bins = list(bins) 256 | chunk = np.digitize(self.indices[dim], bins) 257 | order = np.argsort(chunk, kind='stable') 258 | chunk = chunk[order] 259 | boundaries = [0] + (np.nonzero(chunk[:-1] != chunk[1:])[0]+1).tolist() + [len(order)] 260 | slices = [slice(*_) for _ in zip(boundaries[:-1], boundaries[1:])] 261 | # for i, slc in enumerate(slices): 262 | # assert i == 0 or (bins[i-1] <= self.indices[dim, order[slc]]).all() 263 | # assert i == len(slices)-1 or (self.indices[dim, order[slc]] < bins[i]).all() 264 | def f(indices, offset, dim): 265 | indices = indices.copy() 266 | indices[dim] -= offset 267 | return indices 268 | return (Sparse( 269 | f(self.indices[:, order[slc]], start, dim), 270 | self.values[order[slc]], 271 | tuple(self.shape[:dim]) + (stop-start,) + tuple(self.shape[dim+1:]), 272 | ) for slc, start, stop in zip(slices, [0] + bins, bins + [self.shape[dim]])) 273 | 274 | def numel(self): 275 | return int(np.prod(self.shape)) 276 | 277 | import torch.jit as jit 278 | # @jit.script 279 | def densify_jit(shape:torch.Tensor, 280 | indices:List[torch.Tensor], 281 | values:torch.Tensor, 282 | device:torch.device, 283 | transpose:bool=False, 284 | do_conv:bool=False): 285 | 286 | if transpose: 287 | shape_local = [int(shape[2]), int(shape[0]), int(shape[1])] 288 | else: 289 | shape_local = [int(shape[0]), int(shape[1]), int(shape[2])] 290 | 291 | dense_tensor = torch.zeros(shape_local, device=device, dtype=values.dtype) 292 | values = values.to(device, non_blocking=True) 293 | 294 | cell_indices = indices[2].to(device, non_blocking=True).long() 295 | indices_0 = indices[0].to(device, non_blocking=True).long() 296 | indices_1 = indices[1].to(device, non_blocking=True).long() 297 | 298 | 299 | if do_conv: 300 | count = 0 301 | for id_0 in [-1, 0, 1]: 302 | for id_1 in [-1, 0, 1]: 303 | if transpose: 304 | dense_tensor[cell_indices, indices_0 + id_0, indices_1 + id_1] += values 305 | else: 306 | dense_tensor[indices_0 + id_0, indices_1 + id_1, cell_indices] += values 307 | count += 1 308 | dense_tensor /= count 309 | 310 | else: 311 | if transpose: 312 | dense_tensor[cell_indices, indices_0, indices_1] = values 313 | else: 314 | dense_tensor[indices_0, indices_1, cell_indices] = values 315 | 316 | if transpose: 317 | dense_tensor = dense_tensor[:, 1:-1, 1:-1] 318 | else: 319 | dense_tensor = dense_tensor[1:-1, 1:-1, :] 320 | return dense_tensor.clamp_(min=1e-8) 321 | 322 | class Fake_Sparse: 323 | def __init__(self, slice_, indices, values, shape): 324 | self.slice_ = slice_ 325 | 326 | 327 | self.shape = [shape[0]+2, shape[1]+2, shape[2]] 328 | 329 | if gpu_flag: 330 | self.indices = [torch.tensor(indices[0]+1, dtype=torch.short).pin_memory(), 331 | torch.tensor(indices[1]+1, dtype=torch.short).pin_memory(), 332 | torch.tensor(indices[2], dtype=torch.int).pin_memory()] 333 | self.values = torch.tensor(values, dtype=torch.float32).pin_memory()#[mask] 334 | 335 | else: 336 | self.indices = [torch.tensor(indices[0]+1, dtype=torch.int).contiguous(), 337 | torch.tensor(indices[1]+1, dtype=torch.int).contiguous(), 338 | torch.tensor(indices[2], dtype=torch.int).contiguous()] 339 | self.values = torch.tensor(values, dtype=torch.float32).contiguous() # [mask] 340 | 341 | 342 | def compress(self, flank): 343 | print ("compressing") 344 | self.shape = [self.shape[0], self.shape[0] + 2 * flank, self.shape[2]] 345 | self.indices[1] = self.indices[1] - self.slice_.start + flank 346 | 347 | def pin_memory(self): 348 | self.values = self.values.pin_memory() 349 | self.indices = [_.pin_memory() for _ in self.indices] 350 | return self 351 | 352 | def densify(self, save_context, transpose=False, do_conv=False, out=None): 353 | return densify_jit(torch.as_tensor(self.shape), self.indices, self.values, save_context['device'], transpose, do_conv) 354 | 355 | 356 | class Chrom_Dataset: 357 | def __init__(self, tensor, bs_bin, bs_cell, good_qc_num=-1, kind='hic', 358 | upper_sim=False, compact=False, flank=0, chrom='chr1', resolution=10000): 359 | # tensor: big sparse or dense tensor 360 | # bs_bin: batch_size for bin 361 | # bs_cell: batch_size for cell 362 | # good_qc_num: the first good_qc_num cells are good cells 363 | # kind: is it hic or 1d signals 364 | # upper_sim: stores only the upper triangle signals or 365 | # compact: if compact if True: return matrix is size of (bs_bin, bs_bin + 2*flank, bs_cell) otherwise (bs_bin, all bins, bs_cell) 366 | # flank: flanking region size, also equivalent to the max distance. 367 | # chrom: the chromosome of this dataset (which can associate it with other dataset of same chrom but different resolution) 368 | # resolution: the resolution of this dataset 369 | 370 | self.resolution = resolution 371 | self.chrom = chrom 372 | self.length = tensor.shape[0] 373 | if good_qc_num == -1: 374 | good_qc_num = tensor.shape[-1] 375 | self.num_cell = good_qc_num 376 | self.total_cell_num = tensor.shape[-1] 377 | self.num_bin = tensor.shape[0] 378 | self.shape = [self.num_bin, tensor.shape[1], self.num_cell] 379 | self.bs_cell = bs_cell 380 | self.bs_bin = bs_bin 381 | 382 | self.tensor_list = [] 383 | self.bad_tensor_list = [] 384 | self.bin_slice_list = [] # this is the global one, indicating for a contact map of (n_bin, n_bin) where this small tensor correspond to 385 | self.local_bin_slice_list = [] # this is the local one, for compact map, it indicates which col slice in the small compact map correspond to 386 | self.cell_slice_list = [] 387 | self.bad_cell_slice_list = [] 388 | self.col_bin_slice_list = [] 389 | 390 | self.kind_list = [] 391 | self.bad_kind_list = [] 392 | self.upper_sim = upper_sim 393 | self.compact = compact 394 | self.flank = flank 395 | 396 | # the tensor list is ordered by: 397 | # - n_batch_bin 398 | # - - n_batch_cell 399 | 400 | count = 0 401 | for i in range(0, self.num_bin, bs_bin): 402 | # Fetch and densify the X 403 | slice_ = slice(i, i + bs_bin) 404 | self.tensor_list.append([]) 405 | self.bad_tensor_list.append([]) 406 | self.kind_list.append([]) 407 | self.bad_kind_list.append([]) 408 | if not (type(tensor) is torch.Tensor): 409 | indices, values, shape, start = tensor.get_slice_idx_value(slice_, device='cpu') 410 | if self.upper_sim: 411 | # For a rectangular data slice: [x1:x2, :], the duplicated part is the lower triangular of [x1:x2, x1:x2] 412 | # Thus store upper triangle of [x1:x2, :] or , the left part of [x1:x2, :x1] 413 | mask = (indices[1, :] < start) | (indices[1, :] >= indices[0, :]) 414 | indices = indices[:, mask] 415 | values = values[mask] 416 | # Because we'll do a = a + a.T, diag needs to be divided by 2 417 | mask2 = (indices[1, :] == indices[0, :]) 418 | values[mask2] *= 0.5 419 | indices[0, :] -= start 420 | 421 | 422 | cell_start_point = list(np.arange(0, self.num_cell, self.bs_cell)) + \ 423 | list(np.arange(self.num_cell, tensor.shape[-1], self.bs_cell)) 424 | 425 | for cell_index in cell_start_point: 426 | if cell_index < self.num_cell: 427 | rhs = self.num_cell 428 | storage = self.tensor_list 429 | storage_kind = self.kind_list 430 | else: 431 | rhs = tensor.shape[-1] 432 | storage = self.bad_tensor_list 433 | storage_kind = self.bad_kind_list 434 | 435 | cell_start = cell_index 436 | cell_end = min(cell_index + self.bs_cell, rhs) 437 | 438 | if i == 0: 439 | self.cell_slice_list.append(slice(cell_start, cell_end)) 440 | 441 | if (type(tensor) is torch.Tensor): 442 | t = tensor[slice_, :, slice(cell_index, cell_end)] 443 | storage[count].append(t) 444 | storage_kind[count].append(kind) 445 | if cell_index == 0: 446 | self.bin_slice_list.append(slice(i, i + t.shape[0])) 447 | self.local_bin_slice_list.append(slice(i, i + t.shape[0])) 448 | self.col_bin_slice_list.append(slice(None)) 449 | continue 450 | 451 | 452 | mask = (indices[2] >= cell_index) & (indices[2] < cell_end) 453 | local_shape = (shape[0], shape[1], min(cell_end, shape[-1]) - cell_start) 454 | local_indices = [indices[0][mask], indices[1][mask], indices[2][mask] - cell_start] 455 | local_values = values[mask] 456 | 457 | if cell_index == 0: 458 | self.bin_slice_list.append(slice(i, i+shape[0])) 459 | if self.compact: 460 | if i > self.flank: 461 | self.local_bin_slice_list.append(slice(self.flank, self.flank + shape[0])) 462 | extend_shape_right = 0 463 | if self.num_bin - i - shape[0] - flank > 0: 464 | extend_shape_right += flank 465 | else: 466 | extend_shape_right += self.num_bin - i - shape[0] 467 | 468 | self.col_bin_slice_list.append(slice(i-self.flank, i+shape[0]+extend_shape_right)) 469 | else: 470 | self.local_bin_slice_list.append(slice(i, i + shape[0])) 471 | extend_shape_right = 0 472 | if self.num_bin - i - shape[0] - flank > 0: 473 | extend_shape_right += flank 474 | else: 475 | extend_shape_right += self.num_bin - i - shape[0] 476 | self.col_bin_slice_list.append(slice(0, i + shape[0] + extend_shape_right)) 477 | else: 478 | self.local_bin_slice_list.append(slice(i, i + shape[0])) 479 | self.col_bin_slice_list.append(slice(None)) 480 | 481 | if compact: 482 | extend_shape = 0 483 | if i > self.flank: 484 | local_indices[1] = local_indices[1] - i + flank 485 | extend_shape += self.flank 486 | else: 487 | extend_shape += i 488 | if self.num_bin - i - shape[0] - flank > 0: 489 | extend_shape += flank 490 | else: 491 | extend_shape += self.num_bin - i - shape[0] 492 | local_shape = [local_shape[0], shape[0] + extend_shape, local_shape[2]] 493 | # print("esr", extend_shape_right, self.num_bin, i, shape[0], local_shape) 494 | # if np.sum(np.max(local_indices, axis=1) > np.asarray(local_shape)) > 0: 495 | # print (shape, i, extend_shape, self.num_bin) 496 | # print (chrom, local_shape, np.min(local_indices, axis=-1), np.max(local_indices, axis=1), 497 | # np.min(local_values), np.max(local_values)) 498 | # raise EOFError 499 | storage[count].append(Fake_Sparse(slice(i, i+shape[0]), local_indices, local_values, local_shape)) 500 | storage_kind[count].append(kind) 501 | 502 | 503 | count += 1 504 | 505 | self.num_bin_batch = len(self.tensor_list) 506 | self.num_cell_batch = len(self.tensor_list[0]) 507 | self.num_cell_batch_bad = len(self.bad_tensor_list[0]) 508 | self.uniq_kind = np.unique(self.kind_list) 509 | if compact: 510 | self.shape = [self.num_bin, bs_bin + 2 * flank, self.num_cell] 511 | 512 | def __len__(self): 513 | return self.length 514 | 515 | # hasn't adapted to multires 516 | def append_dim0(self, tensor, good_qc_num=-1, kind='hic'): 517 | if good_qc_num == -1: 518 | good_qc_num = tensor.shape[-1] 519 | bs_bin = self.bs_bin 520 | count = self.num_bin_batch 521 | 522 | for i in range(0, tensor.shape[0], bs_bin): 523 | # Fetch and densify the X 524 | slice_ = slice(i, i + bs_bin) 525 | self.tensor_list.append([]) 526 | self.bad_tensor_list.append([]) 527 | self.kind_list.append([]) 528 | self.bad_kind_list.append([]) 529 | if not (type(tensor) is torch.Tensor): 530 | indices, values, shape, start = tensor.get_slice_idx_value(slice_, device='cpu') 531 | indices[0, :] -= start 532 | for cell_index in range(0, self.num_cell, self.bs_cell): 533 | cell_start = cell_index 534 | cell_end = min(cell_index + self.bs_cell, self.num_cell) 535 | if (type(tensor) is torch.Tensor): 536 | t = tensor[slice_, :, slice(cell_index, cell_end)] 537 | self.tensor_list[count].append(t) 538 | self.kind_list[count].append(kind) 539 | if cell_index == 0: 540 | self.bin_slice_list.append(slice(i+self.num_bin, i+self.num_bin+t.shape[0])) 541 | continue 542 | mask = (indices[2] >= cell_index) & (indices[2] < cell_end) 543 | local_shape = (shape[0], shape[1], min(cell_end, shape[-1]) - cell_start) 544 | local_indices = [indices[0][mask], indices[1][mask], indices[2][mask] - cell_start] 545 | local_values = values[mask] 546 | if cell_index == 0: 547 | self.bin_slice_list.append(slice(i + self.num_bin, i + self.num_bin + shape[0])) 548 | self.tensor_list[count].append(Fake_Sparse(local_indices, local_values, local_shape)) 549 | self.kind_list[count].append(kind) 550 | for cell_index in range(self.num_cell, tensor.shape[-1], self.bs_cell): 551 | cell_start = cell_index 552 | cell_end = min(cell_index + self.bs_cell, tensor.shape[-1]) 553 | if (type(tensor) is torch.Tensor): 554 | self.bad_tensor_list[count].append(tensor[slice_, :, slice(cell_index, cell_end)]) 555 | self.bad_kind_list[count].append(kind) 556 | continue 557 | mask = (indices[2] >= cell_index) & (indices[2] < cell_end) 558 | local_shape = (shape[0], shape[1], min(cell_end, shape[-1]) - cell_start) 559 | local_indices = [indices[0][mask], indices[1][mask], indices[2][mask] - cell_start] 560 | local_values = values[mask] 561 | self.bad_tensor_list[count].append(Fake_Sparse(local_indices, local_values, local_shape)) 562 | self.bad_kind_list[count].append(kind) 563 | 564 | count += 1 565 | 566 | self.num_bin += tensor.shape[0] 567 | self.shape[0] += tensor.shape[0] 568 | 569 | 570 | 571 | self.num_bin_batch = len(self.tensor_list) 572 | self.num_cell_batch = len(self.tensor_list[0]) 573 | self.num_cell_batch_bad = len(self.bad_tensor_list[0]) 574 | self.uniq_kind = np.unique(self.kind_list) 575 | 576 | def pin_memory(self): 577 | for i in range(len(self.tensor_list)): 578 | for j in range(len(self.tensor_list[i])): 579 | self.tensor_list[i][j] = self.tensor_list[i][j].pin_memory() 580 | 581 | for i in range(len(self.bad_tensor_list)): 582 | for j in range(len(self.bad_tensor_list[i])): 583 | self.bad_tensor_list[i][j] = self.bad_tensor_list[i][j].pin_memory() 584 | 585 | def fetch_bad(self, bin_id, cell_id, **kwargs): 586 | return self.fetch(bin_id, cell_id, good_qc=False, **kwargs) 587 | 588 | def fetch(self, bin_id, cell_id, save_context, transpose=False, good_qc=True, **kwargs): 589 | if good_qc: 590 | temp = self.tensor_list[bin_id][cell_id] 591 | kind = self.kind_list[bin_id][cell_id] 592 | else: 593 | temp = self.bad_tensor_list[bin_id][cell_id] 594 | kind = self.bad_kind_list[bin_id][cell_id] 595 | 596 | transpose = False if kind != 'hic' else transpose 597 | if type(temp) is torch.Tensor: 598 | if transpose: 599 | temp = temp.permute(2, 0, 1).to(save_context['device']) 600 | return (temp.to(save_context['device']), [0]), kind 601 | else: 602 | _t = time.perf_counter() 603 | a = temp.densify(save_context, transpose, **kwargs) 604 | if self.upper_sim and kind == 'hic': 605 | # Why local? because local indicates the "diag" in the local compact map 606 | slice_ = self.local_bin_slice_list[bin_id] 607 | 608 | if transpose: 609 | a[:, :, slice_] = a[:, :, slice_] + a[:, :, slice_].permute(0, 2, 1) 610 | else: 611 | a[:, slice_, :] = a[:, slice_, :] + a[:, slice_, :].permute(1, 0, 2) 612 | b = [time.perf_counter() - _t] 613 | return (a, b), kind 614 | 615 | def norm(self): 616 | total_norm = 0 617 | for i in range(len(self.tensor_list)): 618 | for j in range(len(self.tensor_list[i])): 619 | total_norm += self.tensor_list[i][j].values.square().sum() 620 | return torch.sqrt(total_norm).item() 621 | 622 | def replace(self, bin_id, cell_id, dense_tensor, sparse_ratio): 623 | # self.tensor_list[bin_id][cell_id] = dense_tensor 624 | print (self.tensor_list[bin_id][cell_id].indices[0].shape) 625 | cutoff = torch.quantile(dense_tensor.permute(2, 0, 1).reshape(dense_tensor.shape[2], -1), 1 - sparse_ratio, dim=1) 626 | local_indices = (dense_tensor > cutoff[None, None, :]).nonzero().T 627 | local_values = dense_tensor[local_indices[0], local_indices[1], local_indices[2]] 628 | print (local_indices.shape, local_values.shape, torch.prod(torch.tensor(dense_tensor.shape))) 629 | # del self.tensor_list[bin_id][cell_id] 630 | self.tensor_list[bin_id][cell_id] = Fake_Sparse(local_indices, 631 | local_values, dense_tensor.shape) 632 | gc.collect() 633 | 634 | def test(): 635 | def new_obj(): 636 | return Sparse([[0, 2, 1, 0], [0, 3, 2, 1]], [1, 2, 3, 4.], (3, 4)) 637 | base = new_obj().to_scipy().todense() 638 | print(base) 639 | o = new_obj() 640 | o.sort_indices() 641 | assert (o.to_scipy().todense() == base).all() 642 | o = new_obj() 643 | o = o.permute(0, 1) 644 | assert (o.to_scipy().todense() == base).all() 645 | o = new_obj() 646 | o = o.permute(1, 0) 647 | assert (o.to_scipy().todense() == base.T).all() 648 | o = new_obj() 649 | o = o.reshape(1, 12) 650 | assert (o.to_scipy().todense() == base.reshape(1, 12)).all() 651 | o = new_obj() 652 | o = o.reshape(4, 3) 653 | assert (o.to_scipy().todense() == base.reshape(4, 3)).all() 654 | for s in [slice(2), slice(10), slice(0, None), slice(2, None), slice(1, 3)]: 655 | o = new_obj() 656 | o = o.slicing(s) 657 | assert (o.to_scipy().todense() == base[s]).all() 658 | 659 | o = new_obj() 660 | o = o.permute(1, 0).reshape(6, 2).slicing(slice(1, 3)).permute(1, 0) 661 | assert (o.to_scipy().todense() == base.T.reshape(6, 2)[1: 3].T).all() 662 | 663 | 664 | if __name__ == '__main__': 665 | test() 666 | -------------------------------------------------------------------------------- /fasthigashi/util.py: -------------------------------------------------------------------------------- 1 | import math, time, itertools, gc, os, pickle, sys, copy 2 | import numpy as np, pandas as pd 3 | 4 | import torch 5 | from tqdm.auto import tqdm, trange 6 | 7 | from sklearn.neighbors import NearestNeighbors 8 | import scipy.sparse 9 | from scipy.sparse import coo_matrix, csr_matrix 10 | 11 | 12 | def shift_csr(a, u, v, m, n): 13 | a = a.tocsr() 14 | indptr = np.full(m + 1, a.indptr[-1]) 15 | indptr[u] = a.indptr[:-1] 16 | np.minimum.accumulate(indptr[::-1], out=indptr[::-1]) 17 | indices = v[a.indices] 18 | return csr_matrix((a.data, indices, indptr), shape=(m, n)) 19 | 20 | 21 | def shift_coo(a, u, v, m, n): 22 | a = a.tocoo() 23 | return coo_matrix((a.data, (u[a.row], v[a.col])), shape=(m, n)) 24 | 25 | 26 | def trim_sparse(a, lb=-np.inf, ub=np.inf): 27 | a.data[(a.data < lb) | (a.data > ub)] = 0. 28 | a.sum_duplicates() 29 | a.prune() 30 | return a 31 | 32 | 33 | def load_data_frame(path2file, open_fn=open, delimiter=',', dtype=float): 34 | decode = lambda s: s if isinstance(s, str) else s.decode() 35 | with open_fn(path2file) as f: 36 | cols = decode(f.readline()).strip().split(delimiter) 37 | rows = [] 38 | values = [] 39 | for line in tqdm(f): 40 | line = decode(line).strip().split(delimiter) 41 | rows.append(line[0]) 42 | values.append(list(map(dtype, line[1:]))) 43 | return pd.DataFrame( 44 | data=np.array(values), 45 | index=pd.Series(data=rows, name=cols[0]), 46 | columns=cols[1:] 47 | ) 48 | 49 | 50 | def is_oom_error(exception: BaseException) -> bool: 51 | print (isinstance(exception, RuntimeError), len(exception.args), "CUDA" in exception.args[0], "out of memory" in exception.args[0]) 52 | return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception) 53 | 54 | 55 | # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 56 | def is_cuda_out_of_memory(exception: BaseException) -> bool: 57 | return ( 58 | isinstance(exception, RuntimeError) 59 | and len(exception.args) == 1 60 | and "CUDA" in exception.args[0] 61 | and "out of memory" in exception.args[0] 62 | ) 63 | 64 | 65 | # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 66 | def is_cudnn_snafu(exception: BaseException) -> bool: 67 | # For/because of https://github.com/pytorch/pytorch/issues/4107 68 | return ( 69 | isinstance(exception, RuntimeError) 70 | and len(exception.args) == 1 71 | and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0] 72 | ) 73 | 74 | 75 | # based on https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py 76 | def is_out_of_cpu_memory(exception: BaseException) -> bool: 77 | return ( 78 | isinstance(exception, RuntimeError) 79 | and len(exception.args) == 1 80 | and "DefaultCPUAllocator: can't allocate memory" in exception.args[0] 81 | ) 82 | 83 | 84 | def garbage_collection_cuda() -> None: 85 | """Garbage collection Torch (CUDA) memory.""" 86 | gc.collect() 87 | try: 88 | # This is the last thing that should cause an OOM error, but seemingly it can. 89 | torch.cuda.empty_cache() 90 | except RuntimeError as exception: 91 | if not is_oom_error(exception): 92 | # Only handle OOM errors 93 | raise -------------------------------------------------------------------------------- /figs/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/Fast-Higashi/71182a9bdee2b96cd1676e1448285c5705a354e7/figs/fig1.png -------------------------------------------------------------------------------- /figs/higashi_cellsystems.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/Fast-Higashi/71182a9bdee2b96cd1676e1448285c5705a354e7/figs/higashi_cellsystems.png -------------------------------------------------------------------------------- /figs/higashi_title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/Fast-Higashi/71182a9bdee2b96cd1676e1448285c5705a354e7/figs/higashi_title.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | print (find_packages()) 3 | setup( 4 | name='fast-higashi', 5 | version='0.1.1a0', 6 | description='Fast-Higashi: Ultrafast and interpretable single-cell 3D genome analysis', 7 | url='https://github.com/ma-compbio/Fast-Higashi', 8 | include_package_data=True, 9 | python_requires='>=3.9', 10 | packages=find_packages(), 11 | install_requires=[], 12 | extras_require={}, 13 | author='Ruochi Zhang', 14 | author_email='ruochiz@andrew.cmu.edu', 15 | license='MIT' 16 | ) 17 | --------------------------------------------------------------------------------