├── .DS_Store ├── .gitignore ├── Code ├── Modules.py ├── config.JSON ├── denoise_contact.py ├── generate_kmers.py ├── main.py ├── plot_embedding.py ├── predict_multiway.py ├── process.py └── utils.py ├── History_version ├── .DS_Store ├── Code │ ├── Modules.py │ ├── analysis_SPRITE.py │ ├── hg19.chrom.sizes.txt │ ├── hg38.chrom.sizes.txt │ ├── main_SPRITE.py │ ├── main_drop.py │ ├── process_SPRITE.py │ ├── random_walk.py │ ├── random_walk_hyper.py │ ├── torchsummary.py │ ├── utils.py │ └── word2vec_ops.so ├── Readme.md └── data │ ├── .DS_Store │ ├── SPRITE │ ├── .DS_Store │ ├── bin2node.npy │ ├── node2bin.npy │ ├── node2chrom.npy │ └── tuples │ │ ├── .DS_Store │ │ ├── occ_3_8.zip │ │ └── occ_above_8.zip │ └── drop │ ├── .DS_Store │ ├── coor2id.npy │ ├── id2coor.npy │ ├── test_data.npz │ └── train_data.npz ├── LICENSE └── Readme.md /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | .DS_Store 4 | .DS_Store 5 | *.json 6 | -------------------------------------------------------------------------------- /Code/Modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from utils import * 4 | import scipy 5 | 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | device_ids = [0, 1] 9 | activation = torch.tanh 10 | 11 | 12 | def get_non_pad_mask(seq): 13 | assert seq.dim() == 2 14 | return seq.ne(0).type(torch.float).unsqueeze(-1) 15 | 16 | 17 | def get_attn_key_pad_mask(seq_k, seq_q): 18 | ''' For masking out the padding part of key sequence. ''' 19 | 20 | # Expand to fit the shape of key query attention matrix. 21 | len_q = seq_q.size(1) 22 | padding_mask = seq_k.eq(0) 23 | padding_mask = padding_mask.unsqueeze( 24 | 1).expand(-1, len_q, -1) # b x lq x lk 25 | 26 | return padding_mask 27 | 28 | 29 | class Wrap_Embedding(torch.nn.Embedding): 30 | def __init__(self, *args, **kwargs): 31 | super().__init__(*args, **kwargs) 32 | 33 | def forward(self, *input): 34 | return super().forward(*input), torch.Tensor([0]).to(device) 35 | 36 | 37 | # Used only for really big adjacency matrix 38 | class SparseEmbedding(nn.Module): 39 | def __init__(self, embedding_weight, sparse=False): 40 | super().__init__() 41 | print(embedding_weight.shape) 42 | self.sparse = sparse 43 | if self.sparse: 44 | self.embedding = embedding_weight 45 | else: 46 | try: 47 | try: 48 | self.embedding = torch.from_numpy( 49 | np.asarray(embedding_weight.todense())).to(device) 50 | except BaseException: 51 | self.embedding = torch.from_numpy( 52 | np.asarray(embedding_weight)).to(device) 53 | except Exception as e: 54 | print("Sparse Embedding Error", e) 55 | self.sparse = True 56 | self.embedding = embedding_weight 57 | 58 | def forward(self, x): 59 | 60 | if self.sparse: 61 | x = x.cpu().numpy() 62 | x = x.reshape((-1)) 63 | temp = np.asarray((self.embedding[x, :]).todense()) 64 | 65 | return torch.from_numpy(temp).to(device) 66 | else: 67 | return self.embedding[x, :] 68 | 69 | 70 | class TiedAutoEncoder(nn.Module): 71 | def __init__(self, shape_list, use_bias=True): 72 | super().__init__() 73 | self.weight_list = [] 74 | self.bias_list = [] 75 | self.use_bias = use_bias 76 | self.recon_bias_list = [] 77 | for i in range(len(shape_list) - 1): 78 | self.weight_list.append(nn.parameter.Parameter(torch.Tensor(shape_list[i + 1], shape_list[i]).to(device))) 79 | self.bias_list.append(nn.parameter.Parameter(torch.Tensor(shape_list[i + 1]).to(device))) 80 | self.recon_bias_list.append(nn.parameter.Parameter(torch.Tensor(shape_list[i]).to(device))) 81 | self.recon_bias_list = self.recon_bias_list[::-1] 82 | 83 | for i, w in enumerate(self.weight_list): 84 | self.register_parameter('tied weight_%d' % i, w) 85 | self.register_parameter('tied bias1', self.bias_list[i]) 86 | self.register_parameter('tied bias2', self.recon_bias_list[i]) 87 | 88 | self.reset_parameters() 89 | 90 | def reset_parameters(self): 91 | for i, w in enumerate(self.weight_list): 92 | torch.nn.init.kaiming_uniform_(self.weight_list[i], a=math.sqrt(5)) 93 | 94 | for i, b in enumerate(self.bias_list): 95 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight_list[i]) 96 | bound = 1 / math.sqrt(fan_in) 97 | torch.nn.init.uniform_(self.bias_list[i], -bound, bound) 98 | temp_weight_list = self.weight_list[::-1] 99 | for i, b in enumerate(self.recon_bias_list): 100 | fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(temp_weight_list[i]) 101 | bound = 1 / math.sqrt(fan_out) 102 | torch.nn.init.uniform_(self.recon_bias_list[i], -bound, bound) 103 | 104 | def forward(self, input): 105 | # return input, input 106 | encoded_feats = input 107 | for i in range(len(self.weight_list)): 108 | if self.use_bias: 109 | encoded_feats = F.linear(encoded_feats, self.weight_list[i], self.bias_list[i]) 110 | else: 111 | encoded_feats = F.linear(encoded_feats, self.weight_list[i]) 112 | if i < len(self.weight_list) - 1: 113 | encoded_feats = activation(encoded_feats) 114 | 115 | reverse_weight_list = self.weight_list[::-1] 116 | reconstructed_output = encoded_feats 117 | for i in range(len(self.recon_bias_list)): 118 | reconstructed_output = F.linear(reconstructed_output, reverse_weight_list[i].t(), self.recon_bias_list[i]) 119 | if i < len(self.recon_bias_list) - 1: 120 | reconstructed_output = activation(reconstructed_output) 121 | 122 | return encoded_feats, reconstructed_output 123 | 124 | 125 | class MultipleEmbedding(nn.Module): 126 | def __init__( 127 | self, 128 | embedding_weights, 129 | dim, 130 | sparse=True, 131 | num_list=None, 132 | chrom_range=None, 133 | inter_initial=None): 134 | super().__init__() 135 | print(dim) 136 | self.chrom_range = chrom_range 137 | print("chrom_range", chrom_range) 138 | self.num_list = torch.tensor([0] + list(num_list)).to(device) 139 | print(self.num_list) 140 | self.dim = dim 141 | 142 | self.embeddings = [] 143 | for i, w in enumerate(embedding_weights): 144 | self.embeddings.append(SparseEmbedding(w, sparse)) 145 | 146 | if inter_initial is not None: 147 | for i in trange(len(inter_initial)): 148 | temp = inter_initial[i, :] 149 | inter_initial[i, temp > 0] = scipy.stats.mstats.zscore(temp[temp > 0]).astype('float32') 150 | 151 | # inter_initial[inter_initial > 0] = scipy.stats.mstats.zscore(inter_initial[inter_initial > 0], axis=1).astype('float32') 152 | inter_initial[np.isnan(inter_initial)] = 0.0 153 | 154 | self.inter_initial = SparseEmbedding(inter_initial, sparse) 155 | else: 156 | self.inter_initial = SparseEmbedding(w, sparse) 157 | 158 | test = torch.zeros(1, device=device).long() 159 | self.input_size = [] 160 | for w in self.embeddings: 161 | self.input_size.append(w(test).shape[-1]) 162 | 163 | self.wstack = [TiedAutoEncoder([self.input_size[i], self.dim, self.dim], use_bias=False).to(device) for i, w in 164 | enumerate(self.embeddings)] 165 | self.next_w = FeedForward([self.dim, self.dim]).to(device) 166 | self.recon = [FeedForward([self.dim, v[1] - v[0]]).to(device) for i, v in enumerate(self.chrom_range)] 167 | 168 | for i, w in enumerate(self.wstack): 169 | self.add_module("Embedding_Linear%d" % (i), w) 170 | # self.add_module("Embedding_Linear", self.next_w) 171 | self.add_module("Embedding_recon%d" % (i), self.recon[i]) 172 | 173 | 174 | self.dropout = nn.Dropout(0.2) 175 | 176 | def forward(self, x): 177 | 178 | final = torch.zeros((len(x), self.dim)).to(device) 179 | recon_loss = torch.Tensor([0.0]).to(device) 180 | for i in range(len(self.num_list) - 1): 181 | select = (x >= (self.num_list[i] + 1)) & (x < (self.num_list[i + 1] + 1)) 182 | if torch.sum(select) == 0: 183 | continue 184 | adj = self.embeddings[i](x[select] - self.num_list[i] - 1) 185 | # output = adj 186 | output = self.dropout(adj) 187 | output, recon = self.wstack[i](output) 188 | final[select] = output 189 | 190 | final = final 191 | 192 | random_chrom = np.random.choice(np.arange(len(self.chrom_range)), 1)[0] 193 | # Get the bins in the other chromosome, and it cannot be 0 (because 0 is padding) 194 | other_chrom = ((x < self.num_list[random_chrom] + 1) | (x >= self.num_list[random_chrom + 1] + 1)) & (x != 0) 195 | if torch.sum(other_chrom) != 0: 196 | target = self.inter_initial(x[other_chrom] - 1) 197 | target = target[:, self.num_list[random_chrom]:self.num_list[random_chrom + 1]] 198 | recon = self.recon[random_chrom](activation(final[other_chrom])) 199 | recon_loss += (target - recon).pow(2).mean(dim=-1).mean() * 100 200 | 201 | return final, recon_loss 202 | 203 | 204 | class Classifier(nn.Module): 205 | def __init__( 206 | self, 207 | n_head, 208 | d_model, 209 | d_k, 210 | d_v, 211 | node_embedding, 212 | diag_mask, 213 | bottle_neck, 214 | attribute_dict=None, 215 | **args): 216 | super().__init__() 217 | 218 | self.pff_classifier = PositionwiseFeedForward([d_model, 1], reshape=True, use_bias=True) 219 | 220 | self.node_embedding = node_embedding 221 | self.encode1 = EncoderLayer( 222 | n_head, 223 | d_model, 224 | d_k, 225 | d_v, 226 | dropout_mul=0.3, 227 | dropout_pff=0.4, 228 | diag_mask=diag_mask, 229 | bottle_neck=bottle_neck) 230 | self.encode2 = EncoderLayer( 231 | n_head, 232 | d_model, 233 | d_k, 234 | d_v, 235 | dropout_mul=0.3, 236 | dropout_pff=0.4, 237 | diag_mask=diag_mask, 238 | bottle_neck=bottle_neck) 239 | self.diag_mask_flag = diag_mask 240 | self.layer_norm1 = nn.LayerNorm(d_model) 241 | self.layer_norm2 = nn.LayerNorm(d_model) 242 | self.next_w = FeedForward([bottle_neck, bottle_neck]).to(device) 243 | if attribute_dict is not None: 244 | self.attribute_dict = torch.from_numpy(attribute_dict).to(device) 245 | self.attribute_dict_embedding = nn.Embedding(len(self.attribute_dict),1 ,padding_idx=0) 246 | self.attribute_dict_embedding.weight = nn.Parameter(self.attribute_dict) 247 | self.attribute_dict_embedding.weight.requires_grad = False 248 | self.attribute_nn = nn.Linear(self.attribute_dict.shape[-1], bottle_neck) 249 | self.attribute_dict = self.attribute_dict_embedding 250 | 251 | 252 | def get_node_embeddings(self, x, return_recon=False): 253 | # shape of x: (b, tuple) 254 | sz_b, len_seq = x.shape 255 | x, recon_loss = self.node_embedding(x.view(-1)) 256 | if return_recon: 257 | return x.view(sz_b, len_seq, -1), recon_loss 258 | else: 259 | return x.view(sz_b, len_seq, -1) 260 | 261 | def get_embedding(self, x, slf_attn_mask, non_pad_mask, return_recon=False): 262 | sz_b, len_seq = x.shape 263 | attribute_embed = self.attribute_dict(x.view(-1)) 264 | attribute_embed = self.attribute_nn(attribute_embed).view(sz_b, len_seq, -1) 265 | if return_recon: 266 | x, recon_loss = self.get_node_embeddings(x, return_recon) 267 | else: 268 | x = self.get_node_embeddings(x, return_recon) 269 | x += attribute_embed 270 | x = activation(self.next_w(x)) 271 | dynamic, static, attn = self.encode1(x, x, slf_attn_mask, non_pad_mask) 272 | # dynamic, static1, attn = self.encode2(dynamic, static,slf_attn_mask, non_pad_mask) 273 | if return_recon: 274 | return dynamic, static, attn, recon_loss 275 | else: 276 | return dynamic, static, attn 277 | 278 | def forward(self, x, mask=None, get_outlier=None, return_recon=False): 279 | x = x.long() 280 | 281 | slf_attn_mask = get_attn_key_pad_mask(seq_k=x, seq_q=x) 282 | non_pad_mask = get_non_pad_mask(x) 283 | 284 | # output, recon_loss = self.get_node_embeddings(x,return_recon=True) 285 | # output = output.view(len(output),1,-1) 286 | if return_recon: 287 | dynamic, static, attn, recon_loss = self.get_embedding(x, slf_attn_mask, non_pad_mask, return_recon) 288 | else: 289 | dynamic, static, attn = self.get_embedding(x, slf_attn_mask, non_pad_mask, return_recon) 290 | dynamic = self.layer_norm1(dynamic) 291 | static = self.layer_norm2(static) 292 | sz_b, len_seq, dim = dynamic.shape 293 | 294 | if self.diag_mask_flag: 295 | output = (dynamic - static) ** 2 296 | # output = dynamic * static 297 | else: 298 | output = dynamic 299 | output = self.pff_classifier(output) 300 | 301 | mode = 'sum' 302 | 303 | if mode == 'min': 304 | output, _ = torch.max( 305 | (1 - output) * non_pad_mask, dim=-2, keepdim=False) 306 | output = 1 - output 307 | 308 | elif mode == 'sum': 309 | output = torch.sum(output * non_pad_mask, dim=-2, keepdim=False) 310 | mask_sum = torch.sum(non_pad_mask, dim=-2, keepdim=False) + 1e-15 311 | output /= mask_sum 312 | elif mode == 'first': 313 | output = output[:, 0, :] 314 | # output = F.softplus(output) 315 | if return_recon: 316 | return output, recon_loss 317 | else: 318 | return output 319 | 320 | 321 | # A custom position-wise MLP. 322 | # dims is a list, it would create multiple layer with tanh between them 323 | # If dropout, it would add the dropout at the end. Before residual and 324 | # layer-norm 325 | 326 | 327 | class PositionwiseFeedForward(nn.Module): 328 | def __init__( 329 | self, 330 | dims, 331 | dropout=None, 332 | reshape=False, 333 | use_bias=True, 334 | residual=False, 335 | layer_norm=False): 336 | super(PositionwiseFeedForward, self).__init__() 337 | self.w_stack = [] 338 | self.dims = dims 339 | for i in range(len(dims) - 1): 340 | self.w_stack.append(nn.Conv1d(dims[i], dims[i + 1], 1, bias=use_bias)) 341 | self.add_module("PWF_Conv%d" % (i), self.w_stack[-1]) 342 | self.reshape = reshape 343 | self.layer_norm = nn.LayerNorm(dims[-1]) 344 | 345 | if dropout is not None: 346 | self.dropout = nn.Dropout(dropout) 347 | else: 348 | self.dropout = None 349 | 350 | self.residual = residual 351 | self.layer_norm_flag = layer_norm 352 | 353 | def forward(self, x): 354 | output = x.transpose(1, 2) 355 | 356 | for i in range(len(self.w_stack) - 1): 357 | output = self.w_stack[i](output) 358 | output = activation(output) 359 | if self.dropout is not None: 360 | output = self.dropout(output) 361 | 362 | output = self.w_stack[-1](output) 363 | output = output.transpose(1, 2) 364 | 365 | if self.reshape: 366 | output = output.view(output.shape[0], -1, 1) 367 | 368 | if self.dims[0] == self.dims[-1]: 369 | # residual 370 | if self.residual: 371 | output += x 372 | 373 | if self.layer_norm_flag: 374 | output = self.layer_norm(output) 375 | 376 | return output 377 | 378 | 379 | # A custom position wise MLP. 380 | # dims is a list, it would create multiple layer with torch.tanh between them 381 | # We don't do residual and layer-norm, because this is only used as the 382 | # final classifier 383 | 384 | 385 | class FeedForward(nn.Module): 386 | ''' A two-feed-forward-layer module ''' 387 | 388 | def __init__(self, dims, dropout=None, reshape=False, use_bias=True): 389 | super(FeedForward, self).__init__() 390 | self.w_stack = [] 391 | for i in range(len(dims) - 1): 392 | self.w_stack.append(nn.Linear(dims[i], dims[i + 1], use_bias)) 393 | self.add_module("FF_Linear%d" % (i), self.w_stack[-1]) 394 | 395 | if dropout is not None: 396 | self.dropout = nn.Dropout(dropout) 397 | else: 398 | self.dropout = None 399 | 400 | self.reshape = reshape 401 | 402 | def forward(self, x): 403 | output = x 404 | for i in range(len(self.w_stack) - 1): 405 | output = self.w_stack[i](output) 406 | output = activation(output) 407 | if self.dropout is not None: 408 | output = self.dropout(output) 409 | output = self.w_stack[-1](output) 410 | 411 | if self.reshape: 412 | output = output.view(output.shape[0], -1, 1) 413 | 414 | return output 415 | 416 | 417 | class ScaledDotProductAttention(nn.Module): 418 | ''' Scaled Dot-Product Attention ''' 419 | 420 | def __init__(self, temperature): 421 | super().__init__() 422 | self.temperature = temperature 423 | 424 | def masked_softmax(self, vector: torch.Tensor, 425 | mask: torch.Tensor, 426 | dim: int = -1, 427 | memory_efficient: bool = False, 428 | mask_fill_value: float = -1e32) -> torch.Tensor: 429 | 430 | if mask is None: 431 | result = torch.nn.functional.softmax(vector, dim=dim) 432 | else: 433 | mask = mask.float() 434 | while mask.dim() < vector.dim(): 435 | mask = mask.unsqueeze(1) 436 | if not memory_efficient: 437 | # To limit numerical errors from large vector elements outside 438 | # the mask, we zero these out. 439 | result = torch.nn.functional.softmax(vector * mask, dim=dim) 440 | result = result * mask 441 | result = result / (result.sum(dim=dim, keepdim=True) + 1e-13) 442 | else: 443 | masked_vector = vector.masked_fill( 444 | (1 - mask).bool(), mask_fill_value) 445 | result = torch.nn.functional.softmax(masked_vector, dim=dim) 446 | return result 447 | 448 | def forward(self, q, k, v, diag_mask, mask=None): 449 | attn = torch.bmm(q, k.transpose(1, 2)) 450 | attn = attn / self.temperature 451 | 452 | if mask is not None: 453 | attn = attn.masked_fill(mask, -float('inf')) 454 | 455 | attn = self.masked_softmax( 456 | attn, diag_mask, dim=-1, memory_efficient=True) 457 | 458 | output = torch.bmm(attn, v) 459 | 460 | return output, attn 461 | 462 | 463 | class MultiHeadAttention(nn.Module): 464 | ''' Multi-Head Attention module ''' 465 | 466 | def __init__( 467 | self, 468 | n_head, 469 | d_model, 470 | d_k, 471 | d_v, 472 | dropout, 473 | diag_mask, 474 | input_dim): 475 | super().__init__() 476 | 477 | self.n_head = n_head 478 | self.d_k = d_k 479 | self.d_v = d_v 480 | 481 | self.w_qs = nn.Linear(input_dim, n_head * d_k, bias=False) 482 | self.w_ks = nn.Linear(input_dim, n_head * d_k, bias=False) 483 | self.w_vs = nn.Linear(input_dim, n_head * d_v, bias=False) 484 | 485 | nn.init.normal_(self.w_qs.weight, mean=0, 486 | std=np.sqrt(2.0 / (d_model + d_k))) 487 | nn.init.normal_(self.w_ks.weight, mean=0, 488 | std=np.sqrt(2.0 / (d_model + d_k))) 489 | nn.init.normal_(self.w_vs.weight, mean=0, 490 | std=np.sqrt(2.0 / (d_model + d_v))) 491 | 492 | self.attention = ScaledDotProductAttention( 493 | temperature=np.power(d_k, 0.5)) 494 | 495 | self.fc1 = nn.Linear(n_head * d_v, d_model) 496 | self.fc2 = nn.Linear(n_head * d_v, d_model) 497 | 498 | self.layer_norm1 = nn.LayerNorm(input_dim) 499 | self.layer_norm2 = nn.LayerNorm(input_dim) 500 | self.layer_norm3 = nn.LayerNorm(input_dim) 501 | 502 | if dropout is not None: 503 | self.dropout = nn.Dropout(dropout) 504 | else: 505 | self.dropout = dropout 506 | 507 | self.diag_mask_flag = diag_mask 508 | self.diag_mask = None 509 | 510 | def pass_(self, inputs): 511 | return inputs 512 | 513 | def forward(self, q, k, v, diag_mask, mask=None): 514 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 515 | 516 | residual_dynamic = q 517 | residual_static = v 518 | 519 | q = self.layer_norm1(q) 520 | k = self.layer_norm2(k) 521 | v = self.layer_norm3(v) 522 | 523 | sz_b, len_q, _ = q.shape 524 | sz_b, len_k, _ = k.shape 525 | sz_b, len_v, _ = v.shape 526 | 527 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 528 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 529 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 530 | 531 | q = q.permute(2, 0, 1, 3).contiguous( 532 | ).view(-1, len_q, d_k) # (n*b) x lq x dk 533 | k = k.permute(2, 0, 1, 3).contiguous( 534 | ).view(-1, len_k, d_k) # (n*b) x lk x dk 535 | v = v.permute(2, 0, 1, 3).contiguous( 536 | ).view(-1, len_v, d_v) # (n*b) x lv x dv 537 | 538 | n = sz_b * n_head 539 | 540 | if self.diag_mask is not None: 541 | if (len(self.diag_mask) <= n) or ( 542 | self.diag_mask.shape[1] != len_v): 543 | self.diag_mask = torch.ones((len_v, len_v), device=device) 544 | if self.diag_mask_flag: 545 | self.diag_mask -= torch.eye(len_v, len_v, device=device) 546 | self.diag_mask = self.diag_mask.repeat(n, 1, 1) 547 | diag_mask = self.diag_mask 548 | else: 549 | diag_mask = self.diag_mask[:n] 550 | 551 | else: 552 | self.diag_mask = (torch.ones((len_v, len_v), device=device)) 553 | if self.diag_mask_flag: 554 | self.diag_mask -= torch.eye(len_v, len_v, device=device) 555 | self.diag_mask = self.diag_mask.repeat(n, 1, 1) 556 | diag_mask = self.diag_mask 557 | 558 | if mask is not None: 559 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 560 | 561 | dynamic, attn = self.attention(q, k, v, diag_mask, mask=mask) 562 | 563 | dynamic = dynamic.view(n_head, sz_b, len_q, d_v) 564 | dynamic = dynamic.permute( 565 | 1, 2, 0, 3).contiguous().view( 566 | sz_b, len_q, -1) # b x lq x (n*dv) 567 | static = v.view(n_head, sz_b, len_q, d_v) 568 | static = static.permute( 569 | 1, 2, 0, 3).contiguous().view( 570 | sz_b, len_q, -1) # b x lq x (n*dv) 571 | 572 | dynamic = self.dropout(self.fc1(dynamic)) if self.dropout is not None else self.fc1(dynamic) 573 | static = self.dropout(self.fc2(static)) if self.dropout is not None else self.fc2(static) 574 | 575 | return dynamic, static, attn 576 | 577 | 578 | class EncoderLayer(nn.Module): 579 | '''A self-attention layer + 2 layered pff''' 580 | 581 | def __init__( 582 | self, 583 | n_head, 584 | d_model, 585 | d_k, 586 | d_v, 587 | dropout_mul, 588 | dropout_pff, 589 | diag_mask, 590 | bottle_neck): 591 | super().__init__() 592 | self.n_head = n_head 593 | self.d_k = d_k 594 | self.d_v = d_v 595 | 596 | self.mul_head_attn = MultiHeadAttention( 597 | n_head, 598 | d_model, 599 | d_k, 600 | d_v, 601 | dropout=dropout_mul, 602 | diag_mask=diag_mask, 603 | input_dim=bottle_neck) 604 | self.pff_n1 = PositionwiseFeedForward( 605 | [d_model, d_model, d_model], dropout=dropout_pff, residual=True, layer_norm=True) 606 | self.pff_n2 = PositionwiseFeedForward( 607 | [bottle_neck, d_model, d_model], dropout=dropout_pff, residual=False, layer_norm=True) 608 | 609 | # self.dropout = nn.Dropout(0.2) 610 | 611 | def forward(self, dynamic, static, slf_attn_mask, non_pad_mask): 612 | dynamic, static1, attn = self.mul_head_attn( 613 | dynamic, dynamic, static, slf_attn_mask) 614 | dynamic = self.pff_n1(dynamic * non_pad_mask) * non_pad_mask 615 | static1 = self.pff_n2(static * non_pad_mask) * non_pad_mask 616 | 617 | return dynamic, static, attn 618 | 619 | 620 | class DataGenerator(): 621 | def __init__(self, edges, edge_weight, batch_size, num_batch_per_iter, min_size = 2, max_size = 2, flag=False): 622 | self.edges = [[] for i in range(max_size + 1)] 623 | self.edge_weight = [[] for i in range(max_size + 1)] 624 | 625 | for e, ew in zip(edges, edge_weight): 626 | self.edges[len(e)].append(e) 627 | self.edge_weight[len(e)].append(ew) 628 | 629 | self.batch_size = batch_size 630 | self.num_batch_per_iter = num_batch_per_iter 631 | self.min_size = min_size 632 | self.max_size = max_size 633 | self.flag = flag 634 | for i in range(min_size, max_size + 1): 635 | self.edges[i] = np.array(self.edges[i]) 636 | self.edge_weight[i] = np.array(self.edge_weight[i]) 637 | 638 | while len(self.edges[i]) <= self.num_batch_per_iter * self.batch_size: 639 | self.edges[i] = np.concatenate([self.edges[i], self.edges[i]]) 640 | self.edge_weight[i] = np.concatenate([self.edge_weight[i], self.edge_weight[i]]) 641 | self.shuffle(i) 642 | self.pointer = np.zeros((len(self.edges)),dtype='int') 643 | 644 | 645 | 646 | def shuffle(self, i): 647 | if self.flag: 648 | print("reach end, shuffling") 649 | index = np.random.permutation(len(self.edges[i])) 650 | self.edges[i] = (self.edges[i])[index] 651 | self.edge_weight[i] = (self.edge_weight[i])[index] 652 | 653 | def next_iter(self): 654 | 655 | return_edges = [] 656 | return_edge_weight = [] 657 | 658 | for i in range(self.min_size, self.max_size + 1): 659 | self.pointer[i] += self.num_batch_per_iter * self.batch_size 660 | 661 | if self.pointer[i] <= len(self.edges[i]): 662 | index = range(self.pointer[i] - self.num_batch_per_iter * self.batch_size, min(self.pointer[i], len(self.edges[i]))) 663 | edges = (self.edges[i])[index] 664 | edge_weight = (self.edge_weight[i])[index] 665 | return_edges+= list(edges) 666 | return_edge_weight += list(edge_weight) 667 | else: 668 | index = range(self.pointer[i] - self.num_batch_per_iter * self.batch_size, 669 | min(self.pointer[i], len(self.edges[i]))) 670 | edges = (self.edges[i])[index] 671 | edge_weight = (self.edge_weight[i])[index] 672 | 673 | self.shuffle(i) 674 | left = self.num_batch_per_iter * self.batch_size - len(index) 675 | self.pointer[i] = 0 676 | self.pointer[i] += left 677 | index = range(0, self.pointer[i]) 678 | edges, edge_weight = np.concatenate([edges, self.edges[i][index]]), np.concatenate([edge_weight, self.edge_weight[i][index]]) 679 | return_edges += list(edges) 680 | return_edge_weight += list(edge_weight) 681 | return np.asarray(return_edges), np.asarray(return_edge_weight) 682 | 683 | def balance_num(self, edges): 684 | cell = edges[:, 0] 685 | final = [] 686 | choice, counts_ = np.unique(cell, return_counts=True) 687 | # num = int(np.mean(counts_)) 688 | num = 50 689 | for c in tqdm(choice): 690 | final.append(np.random.choice(np.where(cell == c)[0], num, replace=True)) 691 | final = np.concatenate(final, axis=-1) 692 | return final 693 | -------------------------------------------------------------------------------- /Code/config.JSON: -------------------------------------------------------------------------------- 1 | { 2 | "cluster_path": "abc.cluster", 3 | "mcool_path": "abc.mcool", 4 | "resolution": 1000000, 5 | "chrom_list": ["chr1","chr2","chr3","chr4","chr5", 6 | "chr6","chr7","chr8","chr9","chr10", 7 | "chr11","chr12","chr13","chr14","chr15", 8 | "chr16","chr17","chr18","chr19","chr20", 9 | "chr21","chr22","chrX"], 10 | "chrom_size": "../hg38.chrom.sizes.txt", 11 | "temp_dir": "../Temp", 12 | "max_cluster_size": 25, 13 | "min_distance": 0, 14 | "k-mer_size": [2, 3, 4, 5], 15 | "min_freq_cutoff": 2, 16 | "quantile_cutoff_for_positive": 0.6, 17 | "quantile_cutoff_for_unlabel": 0.4, 18 | "embed_dim": 64 19 | } -------------------------------------------------------------------------------- /Code/denoise_contact.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib as mpl 3 | mpl.use("Agg") 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | from utils import * 7 | from torch.nn.utils.rnn import pad_sequence 8 | from sklearn.preprocessing import QuantileTransformer 9 | import os 10 | import torch.nn.functional as F 11 | from utils import * 12 | 13 | def get_free_gpu(): 14 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > ./tmp') 15 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 16 | if len(memory_available) > 0: 17 | id = int(np.argmax(memory_available)) 18 | print("setting to gpu:%d" % id) 19 | torch.cuda.set_device(id) 20 | return "cuda:%d" % id 21 | else: 22 | return 23 | 24 | if torch.cuda.is_available(): 25 | current_device = get_free_gpu() 26 | else: 27 | current_device = 'cpu' 28 | 29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | 31 | def proba2matrix(sample, weight=None, proba=None, intra=True): 32 | 33 | sample_left = sample 34 | weight_left = weight 35 | if intra: 36 | sample_left -= np.min(sample_left) 37 | size = int(np.max(sample_left) + 1 ) 38 | m = np.zeros((size, size), dtype='float32') 39 | if weight is not None: 40 | for i in range(sample_left.shape[-1] - 1): 41 | for j in range( i +1, sample_left.shape[-1]): 42 | m[sample_left[: ,i], sample_left[: ,j]] += np.maximum(proba * weight_left, proba) 43 | 44 | else: 45 | for i in range(sample_left.shape[-1] - 1): 46 | for j in range(i + 1, sample_left.shape[-1]): 47 | m[sample_left[:, i], sample_left[:, j]] += proba 48 | 49 | m = m + m.T 50 | else: 51 | size1 = int(np.max(sample_left[:, 0]) - np.min(sample_left[:, 0]) + 1) 52 | size2 = int(np.max(sample_left[:, 1]) - np.min(sample_left[:, 1]) + 1) 53 | m = np.zeros((size1, size2), dtype='float32') 54 | if weight is not None: 55 | m[sample_left[:, 0] - np.min(sample_left[:, 0]), sample_left[:, 1]- np.min(sample_left[:, 1])] += np.maximum(proba * weight_left, proba) 56 | 57 | else: 58 | m[sample_left[:, 0] - np.min(sample_left[:, 0]), sample_left[:, 1] - np.min(sample_left[:, 1])] += proba 59 | 60 | 61 | return m 62 | 63 | 64 | 65 | 66 | 67 | def generate_pair_wise(chrom_id): 68 | samples = [] 69 | for i in range(chrom_range[chrom_id ,0] ,chrom_range[chrom_id ,1]): 70 | for j in range( i +min_dis, chrom_range[chrom_id ,1]): 71 | samples.append([i ,j]) 72 | 73 | samples = np.array(samples) 74 | return samples 75 | 76 | def predict(model, input): 77 | model.eval() 78 | output = [] 79 | new_batch_size = int(1e4) 80 | with torch.no_grad(): 81 | for j in trange(math.ceil(len(input) / new_batch_size)): 82 | x = input[j * new_batch_size:min((j + 1) * new_batch_size, len(input))] 83 | x = np2tensor_hyper(x, dtype=torch.long) 84 | x = pad_sequence(x, batch_first=True, padding_value=0).to(device) 85 | output.append(model(x).detach().cpu().numpy()) 86 | output = np.concatenate(output, axis=0) 87 | torch.cuda.empty_cache() 88 | return output 89 | 90 | config = get_config() 91 | min_dis = config['min_distance'] 92 | temp_dir = config['temp_dir'] 93 | res = config['resolution'] 94 | vmin = -0.0 95 | vmax = 1.0 96 | 97 | 98 | chrom_range = np.load(os.path.join(temp_dir,"chrom_range.npy")) 99 | classifier_model = torch.load(os.path.join(temp_dir, "model2load"), map_location=current_device) 100 | 101 | print ("device", classifier_model.layer_norm1.weight.device) 102 | device_info = classifier_model.layer_norm1.weight.device 103 | device_info = str(device_info).split(":")[-1] 104 | torch.cuda.set_device(int(device_info)) 105 | transformer = QuantileTransformer(n_quantiles=1000, output_distribution='uniform') 106 | task_mode = 'class' 107 | 108 | origin = np.load(os.path.join(temp_dir, "intra_adj.npy")).astype('float32') 109 | # origin = np.load(os.path.join(temp_dir, "edge_list_adj.npy")).astype('float32') 110 | 111 | import h5py 112 | 113 | ## generating mcool file 114 | f = h5py.File("../denoised.mcool", "w") 115 | grp = f.create_group("resolutions") 116 | grp = grp.create_group("%d" % res) 117 | cooler_bin = grp.create_group("bins") 118 | node2bin = np.load(os.path.join(temp_dir,"node2bin.npy"), allow_pickle=True).item() 119 | 120 | chrom_list = [] 121 | chrom_start = [] 122 | chrom_end = [] 123 | chrom_name = config['chrom_list'] 124 | 125 | for i in range(1, int(np.max(list(node2bin.keys())) + 1)): 126 | bin_ = node2bin[i] 127 | chrom, start = bin_.split(":") 128 | chrom_list.append(chrom_name.index(chrom)) 129 | chrom_start.append(int(start)) 130 | chrom_end.append(int(start) + res) 131 | 132 | print (np.array(chrom_list), np.array(chrom_start)) 133 | cooler_bin.create_dataset("chrom", data = chrom_list) 134 | cooler_bin.create_dataset("start", data = chrom_start) 135 | cooler_bin.create_dataset("end", data = chrom_end) 136 | 137 | cooler_chrom = grp.create_group("chroms") 138 | cooler_chrom.create_dataset("name", data = [l.encode('utf8') for l in chrom_name], dtype = h5py.special_dtype(vlen=str)) 139 | 140 | cooler_pixels = grp.create_group("pixels") 141 | bin_id1 = [] 142 | bin_id2 = [] 143 | balanced = [] 144 | 145 | 146 | 147 | for i in range(len(chrom_name)): 148 | pair_wise = generate_pair_wise(i) 149 | print (pair_wise) 150 | bin_id1.append(np.copy(pair_wise[:, 0]) - 1) 151 | bin_id2.append(np.copy(pair_wise[:, 1]) - 1) 152 | # print (pair_wise.shape, pair_wise) 153 | proba = predict(classifier_model, pair_wise).reshape((-1)) 154 | if task_mode == 'class': 155 | proba = torch.sigmoid(torch.from_numpy(proba)).numpy() 156 | else: 157 | proba = F.softplus(torch.from_numpy(proba)).numpy() 158 | # print ( np.sum(proba >= 0.5) ,proba.shape) 159 | 160 | pair_wise_weight = np.array([origin[e[0 ] -1 ,e[1 ] -1] for e in tqdm(pair_wise)]) 161 | 162 | my_proba = proba2matrix(pair_wise, None, proba) 163 | coverage1 = np.sqrt(np.mean(my_proba, axis=-1, keepdims=True)) 164 | coverage2 = np.sqrt(np.mean(my_proba, axis=0, keepdims=True)) 165 | my_proba = my_proba / (coverage1 + 1e-15) 166 | my_proba = my_proba / (coverage2 + 1e-15) 167 | 168 | origin_part = proba2matrix(pair_wise, None, pair_wise_weight) 169 | gap1 = np.sum(origin_part, axis=-1) == 0 170 | gap2 = np.sum(origin_part, axis=0) == 0 171 | coverage1 = np.sqrt(np.mean(origin_part, axis=-1, keepdims=True)) 172 | coverage2 = np.sqrt(np.mean(origin_part, axis=0, keepdims=True)) 173 | origin_part = origin_part / (coverage1+ 1e-15) 174 | origin_part = origin_part / (coverage2 + 1e-15) 175 | 176 | 177 | my = my_proba * origin_part 178 | my = np.maximum(my_proba * origin_part, my_proba) 179 | coverage1 = np.sqrt(np.mean(my, axis=-1, keepdims=True)) 180 | coverage2 = np.sqrt(np.mean(my, axis=0, keepdims=True)) 181 | my = my / (coverage1 + 1e-15) 182 | my = my / (coverage2 + 1e-15) 183 | 184 | 185 | my[gap1 ,:] = 0.0 186 | my[:, gap2] = 0.0 187 | my_proba[gap1 ,:] = 0.0 188 | my_proba[:, gap2] = 0.0 189 | 190 | my = transformer.fit_transform(my.reshape((-1, 1))).reshape((len(my), -1)) 191 | origin_part = transformer.fit_transform(origin_part.reshape((-1, 1))).reshape((len(origin_part), -1)) 192 | my_proba = transformer.fit_transform(my_proba.reshape((-1, 1))).reshape((len(my), -1)) 193 | 194 | fig = plt.figure(figsize=(5, 5)) 195 | plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0) 196 | mask =None 197 | # print ("matrix", matrix, np.min(matrix), np.max(matrix)) 198 | ax = sns.heatmap(my, cmap="Reds", square=True, mask=mask ,cbar=False, vmin=vmin, vmax=vmax) 199 | ax.get_xaxis().set_visible(False) 200 | ax.get_yaxis().set_visible(False) 201 | plt.savefig("../%s_denoise.png" %chrom_name[i], dpi=300) 202 | plt.close(fig) 203 | 204 | 205 | min_ = np.min(pair_wise) 206 | # print (pair_wise[:, 0] - min_, pair_wise[:, 1] - min_, my.shape) 207 | value = my[pair_wise[:, 0] - min_, pair_wise[:, 1] - min_] 208 | balanced.append(value) 209 | 210 | # fig = plt.figure(figsize=(5, 5)) 211 | # plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0) 212 | # mask =None 213 | # # print ("matrix", matrix, np.min(matrix), np.max(matrix)) 214 | # ax = sns.heatmap(my_proba, cmap="Reds", square=True, mask=mask ,cbar=False, vmin=vmin, vmax=vmax) 215 | # ax.get_xaxis().set_visible(False) 216 | # ax.get_yaxis().set_visible(False) 217 | # plt.savefig("../chr%d_denoise_proba.png" %(i+1), dpi=300) 218 | # plt.close(fig) 219 | 220 | fig = plt.figure(figsize=(5, 5)) 221 | plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0) 222 | mask =None 223 | # print ("matrix", matrix, np.min(matrix), np.max(matrix)) 224 | ax = sns.heatmap(origin_part, cmap="Reds", square=True, mask=mask ,cbar=False, vmin=vmin, vmax=vmax) 225 | ax.get_xaxis().set_visible(False) 226 | ax.get_yaxis().set_visible(False) 227 | plt.savefig("../%s_origin.png" %chrom_name[i], dpi=300) 228 | plt.close(fig) 229 | 230 | 231 | bin_id1 = np.concatenate(bin_id1, axis=0) 232 | bin_id2 = np.concatenate(bin_id2, axis=0) 233 | balanced = np.concatenate(balanced, axis=0) 234 | cooler_pixels.create_dataset("bin1_id", data=bin_id1) 235 | cooler_pixels.create_dataset("bin2_id", data=bin_id2) 236 | cooler_pixels.create_dataset("balanced", data=balanced) 237 | -------------------------------------------------------------------------------- /Code/generate_kmers.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | import time 3 | import os 4 | from collections import Counter 5 | import multiprocessing 6 | from utils import * 7 | 8 | def build_dict(size, i_list): 9 | list1 = [] 10 | list1_freq = [] 11 | 12 | for i in tqdm(i_list): 13 | hash_counter = Counter() 14 | for j,index in enumerate(node2usefulindex[i]): 15 | datum = new_data[index] 16 | n = len(datum) 17 | # weight = weights[n] 18 | combs = combinations(datum[datum > i + min_dis], size - 1) 19 | keys = np.array(list(combs)) 20 | 21 | if len(keys) <= 0: 22 | continue 23 | 24 | if size > 2: 25 | dis_list = np.zeros((len(keys), size - 2)) 26 | 27 | for j in range(size - 2): 28 | dis_list[:, j] = keys[:, j + 1] - keys[:, j] 29 | 30 | dis_list_min = np.min(dis_list, axis=-1) 31 | 32 | length = len(keys) 33 | keys = keys[dis_list_min > min_dis] 34 | 35 | for comb in keys: 36 | comb = tuple(comb) 37 | hash_counter[comb] += 1 38 | # hash_dict[comb] += weight 39 | # print (hash_counter) 40 | 41 | new_hash = {el:hash_counter[el] for el in hash_counter if hash_counter[el] >= min_freq_cutoff} 42 | hash_dict = new_hash 43 | 44 | keys = list(hash_dict.keys()) 45 | freq = np.array([hash_dict[k] for k in keys]) 46 | keys = np.array(keys) 47 | if len(keys) > 0: 48 | keys = np.concatenate([np.ones((len(keys), 1), dtype='int') * i, keys], axis=-1) 49 | 50 | temp = keys 51 | temp_freq = freq 52 | if len(temp) == 0: 53 | temp = np.array([]) 54 | temp_freq = np.array([]) 55 | if len(temp) > 0: 56 | list1.append(temp) 57 | list1_freq.append(temp_freq) 58 | 59 | 60 | 61 | if len(list1) > 0: 62 | list1 = np.concatenate(list1, axis=0) 63 | list1_freq = np.concatenate(list1_freq, axis=0) 64 | else: 65 | list1 = np.array([]) 66 | list1_freq = np.array([]) 67 | 68 | del new_hash, hash_counter 69 | return i_list, list1, list1_freq 70 | 71 | 72 | 73 | config = get_config() 74 | max_size = config['max_cluster_size'] 75 | k_list = config['k-mer_size'] 76 | temp_dir = config['temp_dir'] 77 | min_dis = config['min_distance'] 78 | min_freq_cutoff = config['min_freq_cutoff'] 79 | 80 | chrom_range = np.load(os.path.join(temp_dir, "chrom_range.npy")) 81 | node_num = np.max(chrom_range) + 1 82 | print(chrom_range) 83 | data = np.load(os.path.join(temp_dir, "edge_list.npy"), allow_pickle=True) 84 | MAX_WORKER = multiprocessing.cpu_count() 85 | 86 | for k in k_list: 87 | size = k 88 | 89 | new_data = [] 90 | for datum in tqdm(data): 91 | if (len(datum) >= size) & (len(datum) <= max_size): 92 | new_data.append(np.array(datum)) 93 | new_data = np.array(new_data) 94 | node2usefulindex = [[] for i in range(node_num)] 95 | for i, datum in enumerate(tqdm(new_data)): 96 | for n in datum: 97 | node2usefulindex[n].append(i) 98 | node2usefulindex = np.array(node2usefulindex) 99 | 100 | process_list = [] 101 | list1 = [] 102 | list1_freq = [] 103 | pool = ProcessPoolExecutor(max_workers=MAX_WORKER) 104 | 105 | node_list = np.arange(node_num).astype('int') 106 | batch_size = 50 107 | job_iter = np.array_split(node_list, int(len(node_list) / batch_size)) 108 | job_iter = iter(job_iter) 109 | jobs_left = len(node_list) 110 | 111 | while jobs_left > 0: 112 | for i in job_iter: 113 | process_list.append(pool.submit(build_dict, size, i)) 114 | time.sleep(0.2) 115 | if len(process_list) > 1.3*MAX_WORKER: 116 | break 117 | 118 | start = time.time() 119 | for p in as_completed(process_list): 120 | a = p.result() 121 | i_list, temp, temp_freq = a 122 | jobs_left -= len(i_list) 123 | if len(temp) > 0: 124 | list1.append(temp) 125 | list1_freq.append(temp_freq) 126 | 127 | process_list.remove(p) 128 | del p 129 | if time.time() - start >= 10: 130 | break 131 | print ("jobs_left", jobs_left) 132 | pool.shutdown(wait=True) 133 | 134 | 135 | if len(list1) > 0: 136 | list1 = np.concatenate(list1,axis = 0) 137 | list1_freq = np.concatenate(list1_freq, axis=0) 138 | print() 139 | print (list1.shape) 140 | np.save(os.path.join(temp_dir,"all_%d_counter.npy" % size) ,list1) 141 | np.save(os.path.join(temp_dir,"all_%d_freq_counter.npy" % size) , list1_freq) 142 | print ("Quick summarize") 143 | print ("total data", len(list1_freq)) 144 | for c in [2,3,4,5,6,7,8]: 145 | print (">= %d" %c, np.sum(list1_freq >= c)) -------------------------------------------------------------------------------- /Code/main.py: -------------------------------------------------------------------------------- 1 | from pybloom_live import BloomFilter 2 | import multiprocessing 3 | from torch.nn.utils.rnn import pad_sequence 4 | 5 | import time 6 | from tqdm import tqdm, trange 7 | import argparse 8 | import warnings 9 | import random 10 | from Modules import * 11 | from utils import * 12 | 13 | 14 | import datetime 15 | 16 | 17 | cpu_num = multiprocessing.cpu_count() 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | torch.backends.cudnn.benchmark = True 21 | torch.backends.cudnn.deterministic = False 22 | 23 | warnings.filterwarnings("ignore") 24 | def get_free_gpu(): 25 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > ./tmp') 26 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 27 | if len(memory_available) > 0: 28 | id = int(np.argmax(memory_available)) 29 | print("setting to gpu:%d" % id) 30 | torch.cuda.set_device(id) 31 | else: 32 | return 33 | 34 | if torch.cuda.is_available(): 35 | get_free_gpu() 36 | 37 | def forward_op_batch( 38 | model, 39 | loss_func, 40 | batch_data, 41 | batch_weight, 42 | y=""): 43 | x = batch_data 44 | w = batch_weight 45 | 46 | # When label is not generated, prepare the data 47 | if len(y) == 0: 48 | x, y, w, s = generate_negative(x, "train_dict", w, neg_num=neg_num) 49 | x, y, w, s = sync_shuffle([x, y, w, s]) 50 | else: 51 | s = torch.ones((len(y), 1)) 52 | 53 | # forward 54 | pred, recon_loss = model(x, return_recon=True) 55 | 56 | loss = loss_func(pred, y, weight=w) 57 | 58 | return F.sigmoid(pred), y, loss, recon_loss, w, s 59 | 60 | def forward_op_batch_regress( 61 | model, 62 | loss_func, 63 | batch_data, 64 | batch_weight, 65 | y=""): 66 | x = batch_data 67 | w = batch_weight 68 | 69 | # When label is not generated, prepare the data 70 | if len(y) == 0: 71 | x, y, w, s = generate_negative(x, "train_dict", w,neg_num=1) 72 | x, y, w, s = sync_shuffle([x, y, w, s]) 73 | else: 74 | s = torch.ones((len(y), 1)) 75 | 76 | if len(x) % 2 == 1: 77 | batch_length = (len(x) - 1) 78 | x = x[:batch_length] 79 | w = w[:batch_length] 80 | y = y[:batch_length] 81 | s = s[:batch_length] 82 | batch_length = int(batch_length / 2) 83 | else: 84 | batch_length = int(len(x) / 2) 85 | 86 | # forward 87 | pred, recon_loss = model(x, return_recon=True) 88 | pred = F.softplus(pred) 89 | 90 | loss = F.mse_loss(pred, y) 91 | 92 | pred = pred.view(batch_length,2) 93 | y = y.view(batch_length,2) 94 | w = w.view(batch_length,2) 95 | s = s.view(batch_length,2) 96 | l_back = torch.argmin(y,dim=-1,keepdim=False) 97 | l = l_back.clone() 98 | 99 | l[l == 0] = -1 100 | mask = y[:,0]!=y[:,1] 101 | l = l[mask].float() 102 | l_back = l_back[mask].float() 103 | pred = pred[mask] 104 | w = w[mask,0] 105 | s = s[mask,0] 106 | # print ("l,pred",l, pred) 107 | # , weight=s.float().view(-1, 1).to(device) 108 | # loss = loss_func(pred[:,0], pred[:,1], l, margin=0.1) 109 | 110 | y = l_back 111 | pred = pred[:,0] - pred[:,1] 112 | # pred = F.sigmoid(pred) 113 | # loss = loss_func(pred, y) 114 | # print (y) 115 | return F.sigmoid(pred), y, loss, recon_loss, w, s 116 | 117 | 118 | 119 | def train_epoch( 120 | model, 121 | loss_func, 122 | training_data, 123 | optimizer, 124 | batch_size): 125 | # Epoch operation in training phase 126 | # print (len(train_dict[min_size]), train_dict[min_size].capacity, len(test_dict[min_size])) 127 | edges, edge_weight = training_data 128 | y = torch.tensor([]) 129 | # y = training_y 130 | # Permutate all the data 131 | if len(y) > 0: 132 | print("existing y") 133 | edges, edge_weight, y = sync_shuffle([edges, edge_weight, y]) 134 | else: 135 | edges, edge_weight = sync_shuffle([edges, edge_weight]) 136 | 137 | model.train() 138 | 139 | if task_mode == 'class': 140 | forward_func = forward_op_batch 141 | elif task_mode == 'regress': 142 | forward_func = forward_op_batch_regress 143 | 144 | bce_total_loss = 0 145 | recon_total_loss = 0 146 | acc_list, y_list, pred_list, weight_list, size_list = [], [], [], [], [] 147 | 148 | batch_num = int(math.floor(len(edges) / batch_size)) 149 | bar = trange( 150 | batch_num, 151 | mininterval=0.1, 152 | desc=' - (Training) ', 153 | leave=False, 154 | ) 155 | for i in bar: 156 | batch_edge = edges[i * batch_size:(i + 1) * batch_size] 157 | batch_edge_weight = edge_weight[i * batch_size:(i + 1) * batch_size] 158 | batch_y = "" 159 | if len(y) > 0: 160 | batch_y = y[i * batch_size:(i + 1) * batch_size] 161 | if len(batch_y) == 0: 162 | continue 163 | 164 | pred, batch_y, loss_bce, loss_recon, batch_w, batch_s = forward_func( 165 | model, loss_func, batch_edge, batch_edge_weight, y=batch_y) 166 | loss = loss_bce * alpha + loss_recon * beta 167 | # loss = loss_bce + loss_recon 168 | 169 | # acc_list.append(accuracy(pred, batch_y)) 170 | y_list.append(batch_y) 171 | pred_list.append(pred) 172 | weight_list.append(batch_w) 173 | size_list.append(batch_s) 174 | 175 | for opt in optimizer: 176 | opt.zero_grad() 177 | 178 | # backward 179 | loss.backward() 180 | 181 | # update parameters 182 | for opt in optimizer: 183 | opt.step() 184 | 185 | bar.set_description(" - (Training) BCE: %.4f recon: %.4f" % 186 | (bce_total_loss / (i + 1), recon_total_loss / (i + 1))) 187 | bce_total_loss += loss_bce.item() 188 | recon_total_loss += loss_recon.item() 189 | y = torch.cat(y_list) 190 | pred = torch.cat(pred_list) 191 | size_list = torch.cat(size_list) 192 | weight_list = torch.cat(weight_list) 193 | 194 | auc1, auc2 = roc_auc_cuda(y, pred, size_list, max_size) 195 | acc = accuracy(pred, y, size_list, max_size) 196 | 197 | return bce_total_loss / batch_num, recon_total_loss / batch_num, acc, auc1, auc2 198 | 199 | 200 | def eval_epoch(model, loss_func, validation_data, batch_size): 201 | ''' Epoch operation in evaluation phase ''' 202 | bce_total_loss = 0 203 | recon_total_loss = 0 204 | 205 | model.eval() 206 | 207 | if task_mode == 'class': 208 | forward_func = forward_op_batch 209 | elif task_mode == 'regress': 210 | forward_func = forward_op_batch_regress 211 | 212 | with torch.no_grad(): 213 | validation_data, validation_weight = validation_data 214 | y = "" 215 | 216 | validation_data, validation_weight = sync_shuffle( 217 | [validation_data, validation_weight], 10000) 218 | 219 | pred, label, size_list, weight_list = [], [], [], [] 220 | 221 | for i in tqdm(range(int(math.floor(len(validation_data) / batch_size))), 222 | mininterval=0.1, desc=' - (Validation) ', leave=False): 223 | # prepare data 224 | batch_edge = validation_data[i * batch_size:(i + 1) * batch_size] 225 | batch_edge_weight = validation_weight[i * batch_size:(i + 1) * batch_size] 226 | 227 | # if len(y) == 0: 228 | # batch_x, batch_y, batch_w, batch_s = generate_negative( 229 | # batch_x, "test_dict", weight=batch_w, neg_num=neg_num) 230 | # else: 231 | # batch_y = y[i * batch_size:(i + 1) * batch_size] 232 | # 233 | # batch_x, batch_y, batch_w, batch_s = sync_shuffle( 234 | # [batch_x, batch_y, batch_w, batch_s]) 235 | # pred_batch, recon_loss = model(batch_x, return_recon=True) 236 | # loss = loss_func(pred_batch, batch_y) 237 | pred_batch, batch_y, loss, recon_loss, batch_w, batch_s = forward_func( 238 | model, loss_func, batch_edge, batch_edge_weight) 239 | 240 | size_list.append(batch_s) 241 | pred.append(pred_batch) 242 | label.append(batch_y) 243 | weight_list.append(batch_edge_weight) 244 | 245 | 246 | recon_total_loss += recon_loss.item() 247 | bce_total_loss += loss.item() 248 | 249 | pred = torch.cat(pred, dim=0) 250 | label = torch.cat(label, dim=0) 251 | size_list = torch.cat(size_list, dim=0) 252 | # weight_list = torch.cat(weight_list, dim=0) 253 | 254 | acc = accuracy(pred, label, size_list, max_size) 255 | auc1, auc2 = roc_auc_cuda(label, pred, size_list, max_size) 256 | 257 | return bce_total_loss / (i + 1), recon_total_loss / \ 258 | (i + 1), acc, auc1, auc2 259 | 260 | 261 | def train(model, 262 | loss, 263 | training_data, 264 | validation_data, 265 | optimizer, 266 | epochs, 267 | batch_size): 268 | valid_accus = [0] 269 | edges, edge_weight = training_data 270 | training_data_generator = DataGenerator( 271 | edges, edge_weight, int(batch_size),1000, min_size=min_size, max_size=max_size) 272 | start = time.time() 273 | for epoch_i in range(epochs): 274 | 275 | save_embeddings(model, True) 276 | print('[ Epoch', epoch_i, 'of', epochs, ']') 277 | 278 | start = time.time() 279 | edges_part, edge_weight_part = training_data_generator.next_iter() 280 | training_data_new = edges_part, edge_weight_part 281 | 282 | bce_loss, recon_loss, train_accu, auc1, auc2 = train_epoch(model, loss, training_data_new, optimizer, 283 | batch_size) 284 | 285 | print( 286 | ' - (Training) bce: {bce_loss: 7.4f},' 287 | 'recon: {recon_loss: 7.4f}' 288 | ' acc: {accu}, auc: {auc1}, aupr: {auc2}, ' 289 | 'elapse: {elapse:3.3f} s'.format( 290 | bce_loss=bce_loss, 291 | recon_loss=recon_loss, 292 | accu=train_accu, 293 | auc1=auc1, 294 | auc2=auc2, 295 | elapse=( 296 | time.time() - start))) 297 | 298 | start = time.time() 299 | valid_bce_loss, recon_loss, valid_accu, valid_auc1, valid_auc2 = eval_epoch(model, loss, validation_data, 300 | batch_size) 301 | print( 302 | ' - (Validation-hyper) bce: {bce_loss: 7.4f}, recon: {recon_loss: 7.4f},' 303 | ' acc: {accu},' 304 | ' auc: {auc1}, aupr: {auc2},' 305 | 'elapse: {elapse:3.3f} s'.format( 306 | bce_loss=valid_bce_loss, 307 | recon_loss=recon_loss, 308 | accu=valid_accu, 309 | auc1=valid_auc1, 310 | auc2=valid_auc2, 311 | elapse=( 312 | time.time() - start))) 313 | valid_aupr_final = float(valid_auc2.split(" ")[-2]) 314 | valid_accus += [valid_aupr_final] 315 | 316 | checkpoint = { 317 | 'model_link': model.state_dict(), 318 | 'epoch': epoch_i} 319 | 320 | if valid_aupr_final >= max(valid_accus): 321 | torch.save(checkpoint, os.path.join(temp_dir, model_name)) 322 | torch.save(model, os.path.join(temp_dir, "model2load")) 323 | 324 | torch.cuda.empty_cache() 325 | 326 | checkpoint = torch.load(os.path.join(temp_dir, model_name)) 327 | model.load_state_dict(checkpoint['model_link']) 328 | 329 | valid_bce_loss, recon_loss, valid_accu, valid_auc1, valid_auc2 = eval_epoch(model, loss, validation_data, 330 | batch_size) 331 | print( 332 | ' - (Validation-hyper) bce: {bce_loss: 7.4f}, recon: {recon_loss: 7.4f},' 333 | ' acc: {accu},' 334 | ' auc: {auc1}, aupr: {auc2},' 335 | 'elapse: {elapse:3.3f} s'.format( 336 | bce_loss=valid_bce_loss, 337 | recon_loss=recon_loss, 338 | accu=valid_accu, 339 | auc1=valid_auc1, 340 | auc2=valid_auc2, 341 | elapse=( 342 | time.time() - start))) 343 | 344 | 345 | def neighbor_check(temp, dict): 346 | return tuple(temp) in dict 347 | # flag = False 348 | # for i in range(len(temp)): 349 | # for j in [-1, 0, 1]: 350 | # a = np.copy(temp) 351 | # a[i] += j 352 | # a.sort() 353 | # if tuple(a) in dict: 354 | # flag = True 355 | # break 356 | # if flag: 357 | # break 358 | # return flag 359 | 360 | 361 | def generate_negative(x, dict1, weight="", neg_num=1): 362 | if len(weight) == 0: 363 | weight = torch.ones(len(x), dtype=torch.float) 364 | if dict1 == 'train_dict': 365 | dict1 = train_dict 366 | elif dict1 == 'test_dict': 367 | dict1 = test_dict 368 | 369 | change_num_list = [[] for i in range(max_size + 1)] 370 | for s in range(min_size, max_size + 1): 371 | change_num = np.random.binomial(s, 0.5, int(len(x) * (math.ceil(neg_num) * 2))) 372 | change_num = change_num[change_num != 0] 373 | 374 | change_num_list[s] = list(change_num) 375 | 376 | neg_list = [] 377 | new_x = [] 378 | new_index = [] 379 | neg_weight = [] 380 | size_list = [] 381 | size_neg_list = [] 382 | 383 | for j, sample in enumerate(x): 384 | for i in range(int(math.ceil(neg_num))): 385 | 386 | decompose_sample = np.copy(sample) 387 | list1 = change_num_list[decompose_sample.shape[-1]] 388 | change_num = list1.pop() 389 | changes = np.random.choice(np.arange(decompose_sample.shape[-1]), change_num, replace=False) 390 | temp = np.copy(decompose_sample) 391 | trial = 0 392 | while neighbor_check(temp, dict1[(len(temp))]): 393 | temp = np.copy(decompose_sample) 394 | # trial += 1 395 | # if trial >= 10000: 396 | # temp = "" 397 | # break 398 | 399 | for change in changes: 400 | if temp[change] not in node2chrom: 401 | print(temp, decompose_sample) 402 | chrom = node2chrom[temp[change]] 403 | start, end = chrom_range[chrom] 404 | 405 | temp[change] = int( 406 | math.floor( 407 | (end - start) * random.random())) + start 408 | 409 | 410 | temp = list(set(temp)) 411 | 412 | if len(temp) < len(decompose_sample): 413 | temp = np.copy(decompose_sample) 414 | continue 415 | 416 | temp.sort() 417 | dis_list = [] 418 | for k in range(len(temp) - 1): 419 | dis_list.append(temp[k + 1] - temp[k]) 420 | if np.min(dis_list) <= min_dis: 421 | temp = np.copy(decompose_sample) 422 | 423 | if i == 0: 424 | size_list.append(len(decompose_sample)) 425 | if len(temp) > 0: 426 | neg_list.append(temp) 427 | size_neg_list.append(len(temp)) 428 | neg_weight.append(weight[j]) 429 | 430 | pos_weight = weight 431 | pos_weight = torch.tensor(pos_weight).to(device) 432 | size_list = torch.tensor(size_list + size_neg_list) 433 | pos_part = np2tensor_hyper(list(x), dtype=torch.long) 434 | neg = np2tensor_hyper(neg_list, dtype=torch.long) 435 | if type(pos_part) == list: 436 | pos_part = pad_sequence(pos_part, batch_first=True, padding_value=0) 437 | neg = pad_sequence(neg, batch_first=True, padding_value=0) 438 | 439 | if len(neg) == 0: 440 | neg = torch.zeros((1, pos_part.shape[-1]),dtype=torch.long, device=device) 441 | pos_part = pos_part.to(device) 442 | neg = neg.to(device) 443 | if task_mode == 'class': 444 | y = torch.cat([torch.ones((len(pos_part), 1), device=device), 445 | torch.zeros((len(neg), 1), device=device)], dim=0) 446 | w = torch.cat([torch.ones((len(pos_part), 1), device=device) * pos_weight.view(-1, 1), 447 | torch.ones((len(neg), 1), device=device)]) 448 | x = torch.cat([pos_part, neg]) 449 | elif task_mode == 'regress': 450 | w = torch.cat([torch.ones((len(pos_part), 1), device=device), 451 | torch.ones((len(neg), 1), device=device)], dim=0) 452 | y = torch.cat([torch.ones((len(pos_part), 1), device=device) * pos_weight.view(-1, 1), 453 | torch.zeros((len(neg), 1), device=device)]) 454 | x = torch.cat([pos_part, neg]) 455 | else: 456 | print ("Wrong task mode") 457 | raise EOFError 458 | 459 | return x, y, w, size_list 460 | 461 | 462 | def save_embeddings(model, origin=False): 463 | model.eval() 464 | with torch.no_grad(): 465 | ids = np.arange(num_list[-1]) + 1 466 | ids = torch.Tensor(ids).long().to(device).view(-1, 1) 467 | embeddings = [] 468 | for j in range(math.ceil(len(ids) / batch_size)): 469 | x = ids[j * batch_size:min((j + 1) * batch_size, len(ids))] 470 | if origin: 471 | embed = model.get_node_embeddings(x) 472 | embed = embed.detach().cpu().numpy() 473 | embeddings.append(embed) 474 | 475 | embeddings = np.concatenate(embeddings, axis=0)[:, 0, :] 476 | np.save("../embeddings.npy" , embeddings) 477 | 478 | torch.cuda.empty_cache() 479 | return embeddings 480 | 481 | 482 | def predict(model, input): 483 | model.eval() 484 | output = [] 485 | new_batch_size = int(1e5) 486 | with torch.no_grad(): 487 | for j in trange(math.ceil(len(input) / new_batch_size)): 488 | x = input[j * new_batch_size:min((j + 1) * new_batch_size, len(input))] 489 | x = np2tensor_hyper(x, dtype=torch.long) 490 | x = pad_sequence(x, batch_first=True, padding_value=0).to(device) 491 | output.append(model(x).detach().cpu().numpy()) 492 | output = np.concatenate(output, axis=0) 493 | torch.cuda.empty_cache() 494 | return output 495 | 496 | 497 | def get_attributes(): 498 | attribute_all = [] 499 | for i in range(len(num)): 500 | chrom = np.zeros((num[i], len(chrom_list))) 501 | chrom[:, i] = 1 502 | coor = np.arange(num[i]).reshape((-1, 1)).astype('float32') 503 | coor /= num[0] 504 | attribute = np.concatenate([chrom, coor], axis=-1) 505 | attribute_all.append(attribute) 506 | 507 | attribute_all = np.concatenate(attribute_all, axis=0) 508 | attribute_dict = np.concatenate([np.zeros((1, attribute_all.shape[-1])), attribute_all], axis=0).astype( 509 | 'float32') 510 | 511 | print("attribute_dict", attribute_dict.shape) 512 | return attribute_dict 513 | 514 | 515 | 516 | config = get_config() 517 | bottle_neck = config['embed_dim'] 518 | size_list = config['k-mer_size'] 519 | min_size, max_size = int(np.min(size_list)), int(np.max(size_list)) 520 | temp_dir = config['temp_dir'] 521 | 522 | quantile_cutoff_for_positive = config['quantile_cutoff_for_positive'] 523 | quantile_cutoff_for_unlabel = config['quantile_cutoff_for_unlabel'] 524 | 525 | min_dis = config['min_distance'] 526 | chrom_list = config['chrom_list'] 527 | neg_num = 3 528 | batch_size = 96 529 | loss = F.binary_cross_entropy_with_logits 530 | model_name = 'model.chkpt' 531 | current_time = datetime.datetime.now() 532 | task_mode = 'class' 533 | 534 | 535 | 536 | neighbor_mask = [] 537 | 538 | chrom_range = np.load(os.path.join(temp_dir,"chrom_range.npy")) 539 | node2chrom = np.load(os.path.join(temp_dir,"node2chrom.npy"), allow_pickle=True).item() 540 | num = [] 541 | for v in chrom_range: 542 | num.append(v[1] - v[0]) 543 | 544 | num_list = np.cumsum(num) 545 | zero_num_list = np.array([0] + list(num_list)) 546 | print("Node type num", num) 547 | 548 | data_list = [] 549 | weight_list = [] 550 | from sklearn.preprocessing import QuantileTransformer 551 | for size in size_list: 552 | data = np.load(os.path.join(temp_dir,"all_%d_counter.npy" % size)).astype('int') 553 | weight = np.load(os.path.join(temp_dir,"all_%d_freq_counter.npy" % size)).astype('float32') 554 | print("before filter", "size", size, "length", len(data)) 555 | weight = QuantileTransformer(n_quantiles=1000, output_distribution='uniform').fit_transform(weight.reshape((-1,1))).reshape((-1)) 556 | mask = weight > quantile_cutoff_for_positive 557 | # mask = weight >= cutoff 558 | data = data[mask] 559 | weight = weight[mask] 560 | print("after filter", "size", size, "length", len(data)) 561 | for datum in data: 562 | data_list.append(datum) 563 | weight_list.append(weight) 564 | 565 | data = np.array(data_list) 566 | weight = np.concatenate(weight_list,axis = 0) 567 | 568 | 569 | embeddings_initial = [] 570 | inter_initial = np.load(os.path.join(temp_dir, "inter_adj.npy")).astype('float32') 571 | adj = np.load(os.path.join(temp_dir, "intra_adj.npy")).astype('float32') 572 | for v in chrom_range: 573 | temp = adj[v[0] - 1:v[1] - 1, v[0] - 1:v[1] - 1] 574 | temp = np.corrcoef(temp).astype('float32') 575 | temp[np.isnan(temp)] = 0.0 576 | print (temp.shape) 577 | embeddings_initial.append(temp) 578 | 579 | 580 | attribute_dict = get_attributes() 581 | 582 | num = torch.as_tensor(num) 583 | num_list = torch.as_tensor(num_list) 584 | print(num, num_list) 585 | 586 | compress = True 587 | # Note that, no matter how many node types are here, make sure the 588 | # hyperedge (N1,N2,N3,...) has id, N1 < N2 < N3... 589 | train_dict = test_dict = [set() for i in range(max_size+1)] 590 | 591 | index = np.arange(len(data)) 592 | 593 | print ("weight",weight) 594 | weight /= np.mean(weight) 595 | weight *= neg_num 596 | print (weight) 597 | 598 | np.random.shuffle(index) 599 | split = int(0.8 * len(index)) 600 | train_data = data[index[:split]] 601 | test_data = data[index[split:]] 602 | train_weight = weight[index[:split]] 603 | test_weight = weight[index[split:]] 604 | 605 | print("train data amount", len(train_data)) 606 | 607 | print("dict_size", len(train_dict[-1]), len(test_dict[-1])) 608 | 609 | node_embedding = MultipleEmbedding( 610 | embeddings_initial, 611 | bottle_neck, 612 | False, 613 | num_list, chrom_range, inter_initial).to(device) 614 | 615 | classifier_model = Classifier( 616 | n_head=8, 617 | d_model=bottle_neck, 618 | d_k=bottle_neck, 619 | d_v=bottle_neck, 620 | node_embedding=node_embedding, 621 | diag_mask=True, 622 | bottle_neck=bottle_neck, 623 | attribute_dict=attribute_dict).to(device) 624 | 625 | save_embeddings(classifier_model, True) 626 | 627 | 628 | params_list = list(classifier_model.parameters()) 629 | 630 | optimizer = torch.optim.AdamW(params_list, lr=1e-3, amsgrad=False) 631 | 632 | model_parameters = filter(lambda p: p.requires_grad, params_list) 633 | params = sum([np.prod(p.size()) for p in model_parameters]) 634 | print("params to be trained", params) 635 | 636 | 637 | alpha = 0.0 638 | beta = 1.0 639 | train(classifier_model, 640 | loss=loss, 641 | training_data=(train_data, train_weight), 642 | validation_data=(test_data, test_weight), 643 | optimizer=[optimizer], epochs=3, batch_size=batch_size) 644 | 645 | 646 | data_list = [] 647 | weight_list = [] 648 | from sklearn.preprocessing import QuantileTransformer 649 | for size in size_list: 650 | data = np.load(os.path.join(temp_dir,"all_%d_counter.npy" % size)).astype('int') 651 | weight = np.load(os.path.join(temp_dir,"all_%d_freq_counter.npy" % size)).astype('float32') 652 | print("before filter", "size", size, "length", len(data)) 653 | weight = QuantileTransformer(n_quantiles=1000, output_distribution='uniform').fit_transform(weight.reshape((-1,1))).reshape((-1)) 654 | mask = weight > quantile_cutoff_for_unlabel 655 | data = data[mask] 656 | weight = weight[mask] 657 | print("after filter", "size", size, "length", len(data)) 658 | for datum in data: 659 | data_list.append(datum) 660 | weight_list.append(weight) 661 | 662 | dict_data = np.array(data_list) 663 | 664 | test_dict = build_hash(dict_data, compress=compress, max_size=max_size, 665 | min_size=min_size) 666 | 667 | train_dict = test_dict 668 | 669 | print ("Finish building Dict") 670 | 671 | optimizer = torch.optim.AdamW(params_list, lr=1e-3, amsgrad=False) 672 | alpha = 1.0 673 | beta = 0.001 674 | 675 | train(classifier_model, 676 | loss=loss, 677 | training_data=(train_data, train_weight), 678 | validation_data=(test_data, test_weight), 679 | optimizer=[optimizer], epochs=30, batch_size=batch_size) 680 | 681 | checkpoint = torch.load(os.path.join(temp_dir, model_name)) 682 | classifier_model.load_state_dict(checkpoint['model_link']) 683 | 684 | save_embeddings(classifier_model, True) 685 | torch.save(classifier_model, os.path.join(temp_dir, "model2load")) 686 | 687 | -------------------------------------------------------------------------------- /Code/plot_embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import numpy as np 5 | from sklearn.decomposition import PCA 6 | 7 | 8 | vec = np.load("../embeddings.npy") 9 | label = np.load("../subcompartment_label_hg38_1Mb.npy") 10 | vec = vec[label != -1] 11 | label = label[label != -1] 12 | vec = PCA(n_components=2).fit_transform(vec) 13 | 14 | 15 | 16 | label = np.array(["State"+str(label[i]) for i in range(len(label))]) 17 | g = sns.scatterplot(x = vec[:,0],y = vec[:,1],hue = label,alpha = 1.0,linewidth=0, s = 30, ) 18 | plt.savefig("../scatter.png") -------------------------------------------------------------------------------- /Code/predict_multiway.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import * 3 | 4 | import numpy as np 5 | from utils import * 6 | from torch.nn.utils.rnn import pad_sequence 7 | from sklearn.preprocessing import QuantileTransformer 8 | import os 9 | import sys 10 | import torch.nn.functional as F 11 | import random 12 | import argparse 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description="predict multi-way interactions") 17 | parser.add_argument("-i", "--file", type=str) 18 | parser.add_argument("-o", "--output", type=str, default="./output.txt") 19 | 20 | return parser.parse_args() 21 | 22 | 23 | def parse_file(filepath): 24 | file1 = open(filepath, "r") 25 | 26 | bin2node = np.load(os.path.join(temp_dir, "bin2node.npy"), allow_pickle=True).item() 27 | 28 | line = file1.readline() 29 | count = 0 30 | final = [] 31 | 32 | while line: 33 | info_list = line.strip().split("\t") 34 | temp = [] 35 | 36 | for info in info_list: 37 | try: 38 | chrom, bin_ = info.split(":") 39 | except: 40 | print(info) 41 | raise EOFError 42 | if chrom not in chrom_list: 43 | continue 44 | bin_ = int(math.floor(int(bin_) / res)) * res 45 | bin_ = "%s:%d" % (chrom, bin_) 46 | node = bin2node[bin_] 47 | temp.append(node) 48 | temp = list(set(temp)) 49 | 50 | temp.sort() 51 | count += 1 52 | if count % 100 == 0: 53 | print("%d\r" % count, end="") 54 | sys.stdout.flush() 55 | if len(temp) > 1: 56 | final.append(temp) 57 | 58 | line = file1.readline() 59 | 60 | return final 61 | 62 | def get_free_gpu(): 63 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > ./tmp') 64 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 65 | if len(memory_available) > 0: 66 | id = int(np.argmax(memory_available)) 67 | print("setting to gpu:%d" % id) 68 | torch.cuda.set_device(id) 69 | return "cuda:%d" % id 70 | else: 71 | return 72 | 73 | 74 | def predict(model, input): 75 | model.eval() 76 | output = [] 77 | new_batch_size = int(1e4) 78 | with torch.no_grad(): 79 | for j in trange(math.ceil(len(input) / new_batch_size)): 80 | x = input[j * new_batch_size:min((j + 1) * new_batch_size, len(input))] 81 | x = np2tensor_hyper(x, dtype=torch.long) 82 | x = pad_sequence(x, batch_first=True, padding_value=0).to(device) 83 | output.append(model(x).detach().cpu().numpy()) 84 | torch.cuda.empty_cache() 85 | output = np.concatenate(output, axis=0) 86 | 87 | return output 88 | 89 | 90 | if torch.cuda.is_available(): 91 | current_device = get_free_gpu() 92 | else: 93 | current_device = 'cpu' 94 | 95 | 96 | config = get_config() 97 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 98 | temp_dir = config['temp_dir'] 99 | res = config['resolution'] 100 | chrom_list = config['chrom_list'] 101 | 102 | 103 | args = parse_args() 104 | 105 | if type(args.file) != str: 106 | print ("invalid filepath") 107 | raise EOFError 108 | else: 109 | samples = np.array(parse_file(args.file)) 110 | 111 | classifier_model = torch.load(os.path.join(temp_dir, "model2load"), map_location=current_device) 112 | proba = predict(classifier_model, samples) 113 | proba = torch.sigmoid(torch.from_numpy(proba)).numpy() 114 | np.savetxt(args.output, proba) -------------------------------------------------------------------------------- /Code/process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import math 4 | from tqdm import tqdm, trange 5 | import sys, os 6 | from utils import * 7 | import h5py 8 | 9 | 10 | def build_node_dict(): 11 | tab = pd.read_table(chrom_size, header=None, sep="\t") 12 | tab.columns = ['chr', 'size'] 13 | print(tab) 14 | 15 | bin2node = {} 16 | node2bin = {} 17 | node2chrom = {} 18 | chrom_range = [] 19 | count = 1 20 | 21 | for j, chrom in enumerate(chrom_list): 22 | size = np.max(tab['size'][tab['chr'] == chrom]) 23 | max_bin_chrom = math.ceil(size / res) 24 | 25 | temp = [count] 26 | for i in range(max_bin_chrom + 1): 27 | bin_ = "%s:%d" % (chrom, i * res) 28 | bin2node[bin_] = count 29 | node2bin[count] = bin_ 30 | node2chrom[count] = j 31 | count += 1 32 | temp.append(count) 33 | chrom_range.append(temp) 34 | print(chrom_range) 35 | 36 | np.save(os.path.join(temp_dir, "chrom_range.npy"), chrom_range) 37 | np.save(os.path.join(temp_dir, "bin2node.npy"), bin2node) 38 | np.save(os.path.join(temp_dir, "node2chrom.npy"), node2chrom) 39 | np.save(os.path.join(temp_dir,"node2bin.npy"), node2bin) 40 | 41 | 42 | def parse_file(): 43 | file1 = open(cluster_path, "r") 44 | 45 | bin2node = np.load(os.path.join(temp_dir, "bin2node.npy"), allow_pickle=True).item() 46 | 47 | line = file1.readline() 48 | count = 0 49 | final = [] 50 | 51 | while line: 52 | info_list = line.strip().split("\t")[1:] 53 | temp = [] 54 | if (len(info_list) < 2) or (len(info_list) > max_cluster_size * 50): 55 | line = file1.readline() 56 | continue 57 | 58 | for info in info_list: 59 | try: 60 | chrom, bin_ = info.split(":") 61 | except: 62 | print (info) 63 | raise EOFError 64 | if chrom not in chrom_list: 65 | continue 66 | bin_ = int(math.floor(int(bin_) / res)) * res 67 | bin_ = "%s:%d" % (chrom, bin_) 68 | node = bin2node[bin_] 69 | temp.append(node) 70 | temp = list(set(temp)) 71 | 72 | 73 | if len(temp) > max_cluster_size: 74 | line = file1.readline() 75 | continue 76 | 77 | temp.sort() 78 | count += 1 79 | if count % 100 == 0: 80 | print("%d\r" % count, end="") 81 | sys.stdout.flush() 82 | if len(temp) > 1: 83 | final.append(temp) 84 | 85 | line = file1.readline() 86 | 87 | np.save(os.path.join(temp_dir, "edge_list.npy"), final) 88 | 89 | 90 | def edgelist2adj(): 91 | edge_list = np.load(os.path.join(temp_dir, "edge_list.npy"), allow_pickle=True) 92 | chrom_range = np.load(os.path.join(temp_dir, "chrom_range.npy"), allow_pickle=True) 93 | 94 | node_num = int(np.max(chrom_range)) 95 | print(node_num) 96 | 97 | adj = np.zeros((node_num - 1, node_num - 1)) 98 | 99 | for e in tqdm(edge_list): 100 | for i in e: 101 | for j in e: 102 | if i!=j: 103 | adj[i-1,j-1] += 1 104 | print (adj) 105 | np.save(os.path.join(temp_dir, "edge_list_adj.npy"), adj) 106 | 107 | def parse_cool_contact(): 108 | f = h5py.File(mcool_path, "r") 109 | f = f['resolutions'] 110 | f = f[str(res)] 111 | 112 | bin2node = np.load(os.path.join(temp_dir, "bin2node.npy"), allow_pickle=True).item() 113 | node2chrom = np.load(os.path.join(temp_dir, "node2chrom.npy"), allow_pickle=True).item() 114 | 115 | cool_bin_info_chrom = np.array(f['bins']['chrom']) 116 | cool_bin_info_start = np.array(f['bins']['start']) 117 | chrom_name = np.array(f['chroms']['name']).astype('str') 118 | print ("chrom_name", chrom_name) 119 | 120 | cool_index2node = {} 121 | print ("Building dict to map cool bin to MATCHA node id") 122 | for i in trange(len(cool_bin_info_chrom)): 123 | chrom = cool_bin_info_chrom[i] 124 | start = cool_bin_info_start[i] 125 | chrom = chrom_name[chrom] 126 | 127 | if chrom not in chrom_list: 128 | print (chrom) 129 | continue 130 | bin = "%s:%d" %(chrom, start) 131 | 132 | node = bin2node[bin] 133 | cool_index2node[i] = node 134 | 135 | chrom_range = np.load(os.path.join(temp_dir, "chrom_range.npy")) 136 | 137 | node_num = int(np.max(chrom_range)) 138 | print(node_num) 139 | 140 | intra_adj = np.zeros((node_num - 1, node_num - 1)) 141 | inter_adj = np.zeros((node_num - 1, node_num - 1)) 142 | 143 | 144 | cool_index_bin1 = np.array(f['pixels']['bin1_id']) 145 | cool_index_bin2 = np.array(f['pixels']['bin2_id']) 146 | if 'balanced' in f['pixels'].keys(): 147 | cool_count = np.array(f['pixels']['balanced']) 148 | else: 149 | cool_count = np.array(f['pixels']['count']) 150 | 151 | print ("Building adjacency matrxi from mcool file") 152 | for i in trange(len(cool_index_bin1)): 153 | index1 = cool_index_bin1[i] 154 | index2 = cool_index_bin2[i] 155 | if (not index1 in cool_index2node) or (not index2 in cool_index2node): 156 | continue 157 | # minus 1 because, our node id starts at 1 158 | node1 = cool_index2node[index1] - 1 159 | node2 = cool_index2node[index2] - 1 160 | count = float(cool_count[i]) 161 | 162 | if not np.isnan(count): 163 | chrom1 = node2chrom[node1 + 1] 164 | chrom2 = node2chrom[node2 + 1] 165 | 166 | 167 | if chrom1 == chrom2: 168 | intra_adj[node1, node2] += count 169 | intra_adj[node2, node1] += count 170 | else: 171 | inter_adj[node1, node2] += count 172 | inter_adj[node2, node1] += count 173 | 174 | print(intra_adj, inter_adj) 175 | np.save(os.path.join(temp_dir, "intra_adj.npy"), intra_adj) 176 | np.save(os.path.join(temp_dir, "inter_adj.npy"), inter_adj) 177 | 178 | def build_subcompartment_label(): 179 | tab = pd.read_table("../gm12878_subcompartments_hg38.bed", sep="\t", header=None) 180 | tab = tab[tab.columns[:4]] 181 | bin2node = np.load(os.path.join(temp_dir, "bin2node.npy"), allow_pickle=True).item() 182 | tab.columns = ['chrom', 'start', 'end', 'label'] 183 | print(tab) 184 | chrom_range = np.load(os.path.join(temp_dir, "chrom_range.npy"), allow_pickle=True) 185 | 186 | node_num = int(np.max(chrom_range)) 187 | print(node_num) 188 | 189 | state_dict = {'A1': 0, 'A2': 1, 'B1': 2, 'B2': 3, 'B3': 4} 190 | 191 | label_list = np.ones((node_num, 10)) * -1 192 | for i in range(len(tab)): 193 | chrom = tab['chrom'][i] 194 | start = tab['start'][i] 195 | end = tab['end'][i] 196 | label = tab['label'][i] 197 | if label in state_dict: 198 | label = state_dict[label] 199 | else: 200 | label = -1 201 | 202 | start = int(math.floor(start / 100000)) 203 | end = int(math.floor(end / 100000)) 204 | 205 | for j in range(start, end + 1): 206 | larger_bin = int(math.floor(j / 10)) 207 | coord = "%s:%d" % (chrom, larger_bin * 1000000) 208 | if coord in bin2node: 209 | coord = bin2node[coord] 210 | label_list[coord, j % 10] = label 211 | 212 | print(label_list, np.min(label_list), np.max(label_list)) 213 | 214 | final = [] 215 | 216 | for vec in tqdm(label_list): 217 | unique, count = np.unique(vec, return_counts=True) 218 | if np.max(count) >= 6: 219 | pick = unique[np.argmax(count)] 220 | final.append(pick) 221 | else: 222 | final.append(-1) 223 | final = np.array(final) 224 | print(final, np.sum(final != -1)) 225 | final = final[1:] 226 | np.save("../subcompartment_label_hg38_1Mb.npy", final) 227 | 228 | 229 | config = get_config() 230 | res = config['resolution'] 231 | chrom_list = config['chrom_list'] 232 | temp_dir = config['temp_dir'] 233 | cluster_path = config['cluster_path'] 234 | mcool_path = config['mcool_path'] 235 | chrom_size = config['chrom_size'] 236 | max_cluster_size = config['max_cluster_size'] 237 | if not os.path.exists(temp_dir): 238 | os.mkdir(temp_dir) 239 | build_node_dict() 240 | parse_file() 241 | # edgelist2adj() 242 | parse_cool_contact() 243 | 244 | # build_subcompartment_label() -------------------------------------------------------------------------------- /Code/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import torch 4 | from tqdm import tqdm 5 | from sklearn.metrics import average_precision_score 6 | from sklearn.metrics import roc_auc_score 7 | from concurrent.futures import as_completed, ProcessPoolExecutor 8 | from copy import copy, deepcopy 9 | from pybloom_live import BloomFilter 10 | import math 11 | from tqdm import tqdm, trange 12 | import os 13 | 14 | def add_padding_idx(vec): 15 | if len(vec.shape) == 1: 16 | return np.asarray([np.sort(np.asarray(v) + 1).astype('int') 17 | for v in tqdm(vec)]) 18 | else: 19 | vec = np.asarray(vec) + 1 20 | vec = np.sort(vec, axis=-1) 21 | return vec.astype('int') 22 | 23 | 24 | def np2tensor_hyper(vec, dtype): 25 | vec = np.asarray(vec) 26 | if len(vec.shape) == 1: 27 | return [torch.as_tensor(v, dtype=dtype) for v in vec] 28 | else: 29 | return torch.as_tensor(vec, dtype=dtype) 30 | 31 | 32 | def roc_auc_cuda(y_true, y_pred, size_list, max_size): 33 | roc_str, aupr_str = "", "" 34 | try: 35 | y_t = (y_true > 0.5).float().cpu().detach().numpy().reshape((-1, 1)) 36 | y_p = y_pred.cpu().detach().numpy().reshape((-1, 1)) 37 | roc, aupr = roc_auc_score( 38 | y_t, y_p), average_precision_score( 39 | y_t, y_p) 40 | roc_str += "%s %.3f " % ('all', roc) 41 | aupr_str += "%s %.3f " % ('all', aupr) 42 | 43 | for s in np.unique(size_list): 44 | y_t = (y_true[size_list == s] > 0.5).float().cpu().detach().numpy().reshape((-1, 1)) 45 | y_p = y_pred[size_list == s].cpu().detach().numpy().reshape((-1, 1)) 46 | roc, aupr = roc_auc_score( 47 | y_t, y_p), average_precision_score( 48 | y_t, y_p) 49 | roc_str += "%s %.3f " % (str(s), roc) 50 | aupr_str += "%s %.3f " % (str(s), aupr) 51 | 52 | return roc_str[:-1], aupr_str[:-1] 53 | except BaseException: 54 | return 0.0, 0.0 55 | 56 | 57 | def accuracy(output, target, size_list=None, max_size=None): 58 | acc_str = "" 59 | if size_list is not None: 60 | for s in np.unique(size_list): 61 | pred = output[size_list == s] >= 0.5 62 | truth = target[size_list == s] >= 0.5 63 | acc = torch.sum(pred.eq(truth)) 64 | acc = float(acc) * 1.0 / (truth.shape[0] * 1.0) 65 | acc_str += "%s %.3f " % (str(s), acc) 66 | else: 67 | pred = output >= 0.5 68 | truth = target >= 0.5 69 | acc = torch.sum(pred.eq(truth)) 70 | acc = float(acc) * 1.0 / (truth.shape[0] * 1.0) 71 | acc_str += "%.3f " % (acc) 72 | return acc_str 73 | 74 | 75 | def build_hash(data, compress, min_size, max_size, capacity=None): 76 | if capacity is None: 77 | capacity = len(data) * 5 78 | capacity = int(math.ceil(capacity)) + 1000 79 | print("total_capacity", capacity) 80 | dict_list = [] 81 | for i in range(max_size + 1): 82 | if i < min_size: 83 | dict_list.append(BloomFilter(10, 1e-3)) 84 | else: 85 | dict_list.append(BloomFilter(capacity, 1e-3)) 86 | 87 | print(len(dict_list)) 88 | for datum in tqdm(data): 89 | dict_list[len(datum)].add(tuple(datum)) 90 | 91 | print(len(dict_list[min_size]) / dict_list[min_size].capacity) 92 | 93 | print(len(dict_list[-1])) 94 | length_list = [len(dict_list[i]) for i in range(len(dict_list))] 95 | print(length_list) 96 | # np.save("../data/SPRITE/length.npy", length_list) 97 | return dict_list 98 | 99 | 100 | def parallel_build_hash(data, func, initial=None, compress=False, min_size=-1, max_size=-1): 101 | import multiprocessing 102 | cpu_num = multiprocessing.cpu_count() 103 | np.random.shuffle(data) 104 | data = np.array_split(data, min(cpu_num * 1, 32)) 105 | length = len(data) 106 | dict1 = deepcopy(initial) 107 | pool = ProcessPoolExecutor(max_workers=cpu_num) 108 | process_list = [] 109 | 110 | if func == 'build_hash': 111 | func = build_hash 112 | 113 | for datum in data: 114 | process_list.append(pool.submit(func, datum, compress, min_size, max_size, length)) 115 | 116 | for p in as_completed(process_list): 117 | a = p.result() 118 | if dict1 is None: 119 | dict1 = a 120 | elif compress: 121 | for i, d in enumerate(dict1): 122 | dict1[i] = d.union(a[i]) 123 | else: 124 | for i, d in enumerate(dict1): 125 | dict1[i] = d.update(a[i]) 126 | del a 127 | pool.shutdown(wait=True) 128 | 129 | # if args.data in ['schic','ramani']: 130 | # print (num[0]) 131 | # new_list_of_set = [set() for i in range(int(num[0]+1))] 132 | # for s in dict1: 133 | # try: 134 | # new_list_of_set[s[0]].add(s) 135 | # except: 136 | # print (s) 137 | # raise EOFError 138 | # dict1 = new_list_of_set 139 | return dict1 140 | 141 | 142 | def sync_shuffle(sample_list, max_num=-1): 143 | index = torch.randperm(len(sample_list[0])) 144 | if max_num > 0: 145 | index = index[:max_num] 146 | new_list = [] 147 | for s in sample_list: 148 | new_list.append(s[index]) 149 | return new_list 150 | 151 | 152 | def pass_(x): 153 | return x 154 | 155 | 156 | 157 | def get_config(): 158 | c = open("./config.JSON","r") 159 | return json.load(c) -------------------------------------------------------------------------------- /History_version/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/.DS_Store -------------------------------------------------------------------------------- /History_version/Code/Modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm, trange 6 | import copy 7 | import math 8 | from torch.autograd import Function 9 | from utils import * 10 | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | device_ids = [0, 1] 13 | activation = F.tanh 14 | 15 | def gelu_accurate(x): 16 | if not hasattr(gelu_accurate, "_a"): 17 | gelu_accurate._a = math.sqrt(2 / math.pi) 18 | return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 19 | 20 | 21 | def gelu(x: torch.Tensor) -> torch.Tensor: 22 | if hasattr(torch.nn.functional, 'gelu'): 23 | return torch.nn.functional.gelu(x.float()).type_as(x) 24 | else: 25 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 26 | 27 | 28 | def get_non_pad_mask(seq): 29 | assert seq.dim() == 2 30 | return seq.ne(0).type(torch.float).unsqueeze(-1) 31 | 32 | def get_attn_key_pad_mask(seq_k, seq_q): 33 | ''' For masking out the padding part of key sequence. ''' 34 | 35 | # Expand to fit the shape of key query attention matrix. 36 | len_q = seq_q.size(1) 37 | padding_mask = seq_k.eq(0) 38 | padding_mask = padding_mask.unsqueeze( 39 | 1).expand(-1, len_q, -1) # b x lq x lk 40 | 41 | return padding_mask 42 | 43 | class Wrap_Embedding(torch.nn.Embedding): 44 | def __init__(self, *args, **kwargs): 45 | super().__init__(*args, **kwargs) 46 | 47 | def forward(self, *input): 48 | return super().forward(*input), torch.Tensor([0]).to(device) 49 | 50 | # Used only for really big adjacency matrix 51 | class SparseEmbedding(nn.Module): 52 | def __init__(self, embedding_weight, sparse=False): 53 | super().__init__() 54 | print(embedding_weight.shape) 55 | self.sparse = sparse 56 | if self.sparse: 57 | self.embedding = embedding_weight 58 | else: 59 | try: 60 | try: 61 | self.embedding = torch.from_numpy( 62 | np.asarray(embedding_weight.todense())).to(device) 63 | except BaseException: 64 | self.embedding = torch.from_numpy( 65 | np.asarray(embedding_weight)).to(device) 66 | except Exception as e: 67 | print("Sparse Embedding Error",e) 68 | self.sparse = True 69 | self.embedding = embedding_weight 70 | 71 | def forward(self, x): 72 | 73 | if self.sparse: 74 | x = x.cpu().numpy() 75 | x = x.reshape((-1)) 76 | temp = np.asarray((self.embedding[x, :]).todense()) 77 | 78 | return torch.from_numpy(temp).to(device) 79 | else: 80 | return self.embedding[x, :] 81 | 82 | class TiedAutoEncoder(nn.Module): 83 | def __init__(self, shape_list,use_bias = True): 84 | super().__init__() 85 | self.weight_list = [] 86 | self.bias_list = [] 87 | self.use_bias = use_bias 88 | self.recon_bias_list = [] 89 | for i in range(len(shape_list) - 1): 90 | self.weight_list.append(nn.parameter.Parameter(torch.Tensor(shape_list[i+1],shape_list[i]).to(device))) 91 | self.bias_list.append(nn.parameter.Parameter(torch.Tensor(shape_list[i+1]).to(device))) 92 | self.recon_bias_list.append(nn.parameter.Parameter(torch.Tensor(shape_list[i]).to(device))) 93 | self.recon_bias_list = self.recon_bias_list[::-1] 94 | 95 | for i,w in enumerate(self.weight_list): 96 | self.register_parameter('tied weight_%d' % i,w) 97 | self.register_parameter('tied bias1', self.bias_list[i]) 98 | self.register_parameter('tied bias2', self.recon_bias_list[i]) 99 | 100 | self.reset_parameters() 101 | 102 | def reset_parameters(self): 103 | for i,w in enumerate(self.weight_list): 104 | torch.nn.init.kaiming_uniform_(self.weight_list[i], a=math.sqrt(5)) 105 | 106 | for i, b in enumerate(self.bias_list): 107 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight_list[i]) 108 | bound = 1 / math.sqrt(fan_in) 109 | torch.nn.init.uniform_(self.bias_list[i], -bound, bound) 110 | temp_weight_list = self.weight_list[::-1] 111 | for i, b in enumerate(self.recon_bias_list): 112 | fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(temp_weight_list[i]) 113 | bound = 1 / math.sqrt(fan_out) 114 | torch.nn.init.uniform_(self.recon_bias_list[i], -bound, bound) 115 | 116 | def forward(self, input): 117 | # return input, input 118 | encoded_feats = input 119 | for i in range(len(self.weight_list)): 120 | if self.use_bias: 121 | encoded_feats = F.linear(encoded_feats, self.weight_list[i], self.bias_list[i]) 122 | else: 123 | encoded_feats = F.linear(encoded_feats, self.weight_list[i]) 124 | if i < len(self.weight_list) - 1: 125 | encoded_feats = activation(encoded_feats) 126 | 127 | reverse_weight_list = self.weight_list[::-1] 128 | reconstructed_output = encoded_feats 129 | for i in range(len(self.recon_bias_list)): 130 | reconstructed_output = F.linear(reconstructed_output, reverse_weight_list[i].t(), self.recon_bias_list[i]) 131 | if i < len(self.recon_bias_list) - 1: 132 | reconstructed_output = activation(reconstructed_output) 133 | 134 | 135 | return encoded_feats, reconstructed_output 136 | 137 | class MultipleEmbedding(nn.Module): 138 | def __init__( 139 | self, 140 | embedding_weights, 141 | dim, 142 | sparse=True, 143 | num_list=None, 144 | chrom_range = None, 145 | inter_initial = None): 146 | super().__init__() 147 | print(dim) 148 | self.chrom_range = chrom_range 149 | print (chrom_range) 150 | self.num_list = torch.tensor([0] + list(num_list)).to(device) 151 | print(self.num_list) 152 | self.dim = dim 153 | 154 | self.embeddings = [] 155 | for i, w in enumerate(embedding_weights): 156 | self.embeddings.append(SparseEmbedding(w, sparse)) 157 | 158 | import scipy 159 | for i in trange(len(inter_initial)): 160 | temp = inter_initial[i, :] 161 | inter_initial[i, temp > 0] = scipy.stats.mstats.zscore(temp[temp > 0]).astype('float32') 162 | 163 | # inter_initial[inter_initial > 0] = scipy.stats.mstats.zscore(inter_initial[inter_initial > 0], axis=1).astype('float32') 164 | inter_initial[np.isnan(inter_initial)] = 0.0 165 | 166 | self.inter_initial = SparseEmbedding(inter_initial, sparse) 167 | 168 | 169 | self.label_info = torch.Tensor(np.load("../data/SPRITE/subcompartment_label_hg38_1Mb.npy", allow_pickle=True)).long().to(device) 170 | self.label_info = torch.Tensor( 171 | np.load("../data/SPRITE/compartment_hg38.npy", allow_pickle=True)).long().to(device) 172 | self.label_info = torch.cat([torch.zeros((1)).long().to(device), self.label_info],dim = 0) 173 | test = torch.zeros(1, device=device).long() 174 | self.input_size = [] 175 | for w in self.embeddings: 176 | self.input_size.append(w(test).shape[-1]) 177 | 178 | self.wstack = [TiedAutoEncoder([self.input_size[i],self.dim],use_bias=False).to(device) for i,w in enumerate(self.embeddings)] 179 | self.next_w = FeedForward([self.dim, self.dim]).to(device) 180 | self.recon = [FeedForward([self.dim, v[1] - v[0]]).to(device) for i,v in enumerate(self.chrom_range)] 181 | self.classifier = nn.Linear(self.dim, 2).to(device) 182 | # self.wstack = [nn.Linear(self.input_size[i],self.dim).to(device) for i,w in enumerate(self.embeddings)] 183 | self.norm_stack =[nn.BatchNorm1d(self.dim, affine=False).to(device) for w in self.embeddings] 184 | self.norm = nn.LayerNorm(self.dim).to(device) 185 | # self.norm = nn.BatchNorm1d(self.dim).to(device) 186 | self.domain_classifier = FeedForward([self.dim, 22]) 187 | self.add_module("Embedding_norm", self.norm) 188 | for i, w in enumerate(self.wstack): 189 | self.add_module("Embedding_Linear%d" % (i), w) 190 | self.add_module("Embedding_Linear", self.next_w) 191 | self.add_module("Embedding_recon%d" % (i), self.recon[i]) 192 | self.add_module("Embedding_norm%d" % (i), self.norm_stack[i]) 193 | self.add_module("domain_classifier", self.classifier) 194 | 195 | self.dropout = nn.Dropout(0.2) 196 | 197 | def forward(self, x): 198 | 199 | final = torch.zeros((len(x), self.dim)).to(device) 200 | recon_loss = torch.Tensor([0.0]).to(device) 201 | for i in range(len(self.num_list) - 1): 202 | select = (x >= (self.num_list[i] + 1)) & (x < (self.num_list[i + 1] + 1)) 203 | if torch.sum(select) == 0: 204 | continue 205 | adj = self.embeddings[i](x[select] - self.num_list[i] - 1) 206 | output = adj 207 | output = self.dropout(adj) 208 | output, recon = self.wstack[i](output) 209 | output = self.norm_stack[i](output) 210 | # try: 211 | # output = self.norm(output) 212 | # except: 213 | # print (output) 214 | # output = F.tanh(output) 215 | final[select] = output 216 | # recon_loss += sparse_autoencoder_error(recon, adj) 217 | # recon_loss += F.mse_loss(recon, adj) 218 | 219 | final = self.next_w(activation(final)) 220 | # final = self.norm(final) 221 | # pred = self.classifier(self.dropout(final)) 222 | # y = self.label_info[x] 223 | # recon_loss += F.cross_entropy(pred[y!= -1],y[y!= -1] ) 224 | 225 | random_chrom = np.random.choice(np.arange(len(self.chrom_range)),1)[0] 226 | other_chrom = (x < self.num_list[random_chrom] + 1) | (x >= self.num_list[random_chrom + 1] + 1) 227 | target = self.inter_initial(x[other_chrom] - 1) 228 | target = target[:,self.num_list[random_chrom]:self.num_list[random_chrom + 1]] 229 | recon = self.recon[random_chrom](final[other_chrom]) 230 | recon_loss += (target - recon).pow(2).mean(dim = -1).mean() * 100 231 | 232 | # domains = np.random.choice(np.arange(len(self.chrom_range)),2,replace=False) 233 | # source, target = domains 234 | # batch_size = len(x) 235 | # source_x = torch.arange(self.chrom_range[source][0], self.chrom_range[source][1]).long().to(device) 236 | # target_x = torch.arange(self.chrom_range[target][0], self.chrom_range[target][1]).long().to(device) 237 | # index_s = torch.tensor(np.random.choice(np.arange(len(source_x)),batch_size,replace=True)).long().to(device) 238 | # index_t = torch.tensor(np.random.choice(np.arange(len(target_x)),batch_size,replace=True)).long().to(device) 239 | # source_f,_ = self.wstack[source](self.embeddings[source](source_x[index_s]- self.num_list[source])) 240 | # target_f,_ =self.wstack[target](self.embeddings[target](target_x[index_t] - self.num_list[target])) 241 | # source_f = self.next_w(source_f) 242 | # target_f = self.next_w(target_f) 243 | # recon_loss += coral(source_f, target_f) 244 | 245 | return final, recon_loss 246 | 247 | class Classifier(nn.Module): 248 | def __init__( 249 | self, 250 | n_head, 251 | d_model, 252 | d_k, 253 | d_v, 254 | node_embedding, 255 | diag_mask, 256 | bottle_neck, 257 | attribute_dict=None, 258 | **args): 259 | super().__init__() 260 | 261 | self.pff_classifier = PositionwiseFeedForward([d_model, 1], reshape=True, use_bias=True) 262 | 263 | self.node_embedding = node_embedding 264 | self.encode1 = EncoderLayer( 265 | n_head, 266 | d_model, 267 | d_k, 268 | d_v, 269 | dropout_mul=0.3, 270 | dropout_pff=0.4, 271 | diag_mask=diag_mask, 272 | bottle_neck=bottle_neck) 273 | self.encode2 = EncoderLayer(n_head, d_model, d_k, d_v, dropout_mul=0.0, dropout_pff=0.0, diag_mask = diag_mask, bottle_neck=bottle_neck) 274 | self.diag_mask_flag = diag_mask 275 | self.layer_norm1 = nn.LayerNorm(d_model) 276 | self.layer_norm2 = nn.LayerNorm(d_model) 277 | 278 | def get_node_embeddings(self, x,return_recon = False): 279 | # shape of x: (b, tuple) 280 | sz_b, len_seq = x.shape 281 | x, recon_loss = self.node_embedding(x.view(-1)) 282 | if return_recon: 283 | return x.view(sz_b, len_seq, -1), recon_loss 284 | else: 285 | return x.view(sz_b, len_seq, -1) 286 | 287 | def get_embedding(self, x, slf_attn_mask, non_pad_mask,return_recon = False): 288 | if return_recon: 289 | x, recon_loss = self.get_node_embeddings(x,return_recon) 290 | else: 291 | x = self.get_node_embeddings(x, return_recon) 292 | dynamic, static, attn = self.encode1(x, x, slf_attn_mask, non_pad_mask) 293 | # dynamic, static1, attn = self.encode2(dynamic, static,slf_attn_mask, non_pad_mask) 294 | if return_recon: 295 | return dynamic, static, attn, recon_loss 296 | else: 297 | return dynamic, static, attn 298 | 299 | def get_embedding_static(self, x): 300 | if len(x.shape) == 1: 301 | x = x.view(-1, 1) 302 | flag = True 303 | else: 304 | flag = False 305 | slf_attn_mask = get_attn_key_pad_mask(seq_k=x, seq_q=x) 306 | non_pad_mask = get_non_pad_mask(x) 307 | x = self.get_node_embeddings(x) 308 | dynamic, static, attn = self.encode1(x, x, slf_attn_mask, non_pad_mask) 309 | # dynamic, static, attn = self.encode2(dynamic, static,slf_attn_mask, non_pad_mask) 310 | if flag: 311 | return static[:, 0, :] 312 | return static 313 | 314 | def forward(self, x, mask=None, get_outlier=None, return_recon = False): 315 | x = x.long() 316 | 317 | 318 | slf_attn_mask = get_attn_key_pad_mask(seq_k=x, seq_q=x) 319 | non_pad_mask = get_non_pad_mask(x) 320 | 321 | # output, recon_loss = self.get_node_embeddings(x,return_recon=True) 322 | # output = output.view(len(output),1,-1) 323 | if return_recon: 324 | dynamic, static, attn, recon_loss = self.get_embedding(x, slf_attn_mask, non_pad_mask,return_recon) 325 | else: 326 | dynamic, static, attn = self.get_embedding(x, slf_attn_mask, non_pad_mask, return_recon) 327 | dynamic = self.layer_norm1(dynamic) 328 | static = self.layer_norm2(static) 329 | sz_b, len_seq, dim = dynamic.shape 330 | 331 | if self.diag_mask_flag == 'True': 332 | output = (dynamic - static) ** 2 333 | # output = dynamic * static 334 | else: 335 | output = dynamic 336 | output = self.pff_classifier(output) 337 | output = F.sigmoid(output) 338 | 339 | mode = 'sum' 340 | 341 | if mode == 'min': 342 | output, _ = torch.max( 343 | (1 - output) * non_pad_mask, dim=-2, keepdim=False) 344 | output = 1 - output 345 | 346 | elif mode == 'sum': 347 | output = torch.sum(output * non_pad_mask, dim=-2, keepdim=False) 348 | mask_sum = torch.sum(non_pad_mask, dim=-2, keepdim=False) + 1e-15 349 | output /= mask_sum 350 | elif mode == 'first': 351 | output = output[:, 0, :] 352 | 353 | if return_recon: 354 | return output, recon_loss 355 | else: 356 | return output 357 | 358 | # A custom position-wise MLP. 359 | # dims is a list, it would create multiple layer with tanh between them 360 | # If dropout, it would add the dropout at the end. Before residual and 361 | # layer-norm 362 | 363 | 364 | class PositionwiseFeedForward(nn.Module): 365 | def __init__( 366 | self, 367 | dims, 368 | dropout=None, 369 | reshape=False, 370 | use_bias=True, 371 | residual=False, 372 | layer_norm=False): 373 | super(PositionwiseFeedForward, self).__init__() 374 | self.w_stack = [] 375 | self.dims = dims 376 | for i in range(len(dims) - 1): 377 | self.w_stack.append(nn.Conv1d(dims[i], dims[i + 1], 1, use_bias)) 378 | self.add_module("PWF_Conv%d" % (i), self.w_stack[-1]) 379 | self.reshape = reshape 380 | self.layer_norm = nn.LayerNorm(dims[-1]) 381 | 382 | if dropout is not None: 383 | self.dropout = nn.Dropout(dropout) 384 | else: 385 | self.dropout = None 386 | 387 | self.residual = residual 388 | self.layer_norm_flag = layer_norm 389 | 390 | def forward(self, x): 391 | output = x.transpose(1, 2) 392 | 393 | 394 | for i in range(len(self.w_stack) - 1): 395 | output = self.w_stack[i](output) 396 | output = activation(output) 397 | if self.dropout is not None: 398 | output = self.dropout(output) 399 | 400 | output = self.w_stack[-1](output) 401 | output = output.transpose(1, 2) 402 | 403 | if self.reshape: 404 | output = output.view(output.shape[0], -1, 1) 405 | 406 | if self.dims[0] == self.dims[-1]: 407 | # residual 408 | if self.residual: 409 | output += x 410 | 411 | if self.layer_norm_flag: 412 | output = self.layer_norm(output) 413 | 414 | return output 415 | 416 | 417 | # A custom position wise MLP. 418 | # dims is a list, it would create multiple layer with torch.tanh between them 419 | # We don't do residual and layer-norm, because this is only used as the 420 | # final classifier 421 | 422 | 423 | class FeedForward(nn.Module): 424 | ''' A two-feed-forward-layer module ''' 425 | 426 | def __init__(self, dims, dropout=None, reshape=False, use_bias=True): 427 | super(FeedForward, self).__init__() 428 | self.w_stack = [] 429 | for i in range(len(dims) - 1): 430 | self.w_stack.append(nn.Linear(dims[i], dims[i + 1], use_bias)) 431 | self.add_module("FF_Linear%d" % (i), self.w_stack[-1]) 432 | 433 | if dropout is not None: 434 | self.dropout = nn.Dropout(dropout) 435 | else: 436 | self.dropout = None 437 | 438 | self.reshape = reshape 439 | 440 | def forward(self, x): 441 | output = x 442 | for i in range(len(self.w_stack) - 1): 443 | output = self.w_stack[i](output) 444 | output = activation(output) 445 | if self.dropout is not None: 446 | output = self.dropout(output) 447 | output = self.w_stack[-1](output) 448 | 449 | if self.reshape: 450 | output = output.view(output.shape[0], -1, 1) 451 | 452 | return output 453 | 454 | 455 | class ScaledDotProductAttention(nn.Module): 456 | ''' Scaled Dot-Product Attention ''' 457 | 458 | def __init__(self, temperature): 459 | super().__init__() 460 | self.temperature = temperature 461 | 462 | def masked_softmax(self, vector: torch.Tensor, 463 | mask: torch.Tensor, 464 | dim: int = -1, 465 | memory_efficient: bool = False, 466 | mask_fill_value: float = -1e32) -> torch.Tensor: 467 | 468 | if mask is None: 469 | result = torch.nn.functional.softmax(vector, dim=dim) 470 | else: 471 | mask = mask.float() 472 | while mask.dim() < vector.dim(): 473 | mask = mask.unsqueeze(1) 474 | if not memory_efficient: 475 | # To limit numerical errors from large vector elements outside 476 | # the mask, we zero these out. 477 | result = torch.nn.functional.softmax(vector * mask, dim=dim) 478 | result = result * mask 479 | result = result / (result.sum(dim=dim, keepdim=True) + 1e-13) 480 | else: 481 | masked_vector = vector.masked_fill( 482 | (1 - mask).byte(), mask_fill_value) 483 | result = torch.nn.functional.softmax(masked_vector, dim=dim) 484 | return result 485 | 486 | def forward(self, q, k, v, diag_mask, mask=None): 487 | attn = torch.bmm(q, k.transpose(1, 2)) 488 | attn = attn / self.temperature 489 | 490 | if mask is not None: 491 | attn = attn.masked_fill(mask, -float('inf')) 492 | 493 | attn = self.masked_softmax( 494 | attn, diag_mask, dim=-1, memory_efficient=True) 495 | 496 | 497 | output = torch.bmm(attn, v) 498 | 499 | return output, attn 500 | 501 | 502 | class MultiHeadAttention(nn.Module): 503 | ''' Multi-Head Attention module ''' 504 | 505 | def __init__( 506 | self, 507 | n_head, 508 | d_model, 509 | d_k, 510 | d_v, 511 | dropout, 512 | diag_mask, 513 | input_dim): 514 | super().__init__() 515 | 516 | self.n_head = n_head 517 | self.d_k = d_k 518 | self.d_v = d_v 519 | 520 | self.w_qs = nn.Linear(input_dim, n_head * d_k, bias=False) 521 | self.w_ks = nn.Linear(input_dim, n_head * d_k, bias=False) 522 | self.w_vs = nn.Linear(input_dim, n_head * d_v, bias=False) 523 | 524 | nn.init.normal_(self.w_qs.weight, mean=0, 525 | std=np.sqrt(2.0 / (d_model + d_k))) 526 | nn.init.normal_(self.w_ks.weight, mean=0, 527 | std=np.sqrt(2.0 / (d_model + d_k))) 528 | nn.init.normal_(self.w_vs.weight, mean=0, 529 | std=np.sqrt(2.0 / (d_model + d_v))) 530 | 531 | self.attention = ScaledDotProductAttention( 532 | temperature=np.power(d_k, 0.5)) 533 | 534 | self.fc1 = nn.Linear(n_head * d_v, d_model) 535 | self.fc2 = nn.Linear(n_head * d_v, d_model) 536 | 537 | self.layer_norm1 = nn.LayerNorm(input_dim) 538 | self.layer_norm2 = nn.LayerNorm(input_dim) 539 | self.layer_norm3 = nn.LayerNorm(input_dim) 540 | 541 | if dropout is not None: 542 | self.dropout = nn.Dropout(dropout) 543 | else: 544 | self.dropout = dropout 545 | 546 | self.diag_mask_flag = diag_mask 547 | self.diag_mask = None 548 | 549 | def pass_(self, inputs): 550 | return inputs 551 | 552 | def forward(self, q, k, v, diag_mask, mask=None): 553 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 554 | 555 | residual_dynamic = q 556 | residual_static = v 557 | 558 | q = self.layer_norm1(q) 559 | k = self.layer_norm2(k) 560 | v = self.layer_norm3(v) 561 | 562 | sz_b, len_q, _ = q.shape 563 | sz_b, len_k, _ = k.shape 564 | sz_b, len_v, _ = v.shape 565 | 566 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 567 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 568 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 569 | 570 | q = q.permute(2, 0, 1, 3).contiguous( 571 | ).view(-1, len_q, d_k) # (n*b) x lq x dk 572 | k = k.permute(2, 0, 1, 3).contiguous( 573 | ).view(-1, len_k, d_k) # (n*b) x lk x dk 574 | v = v.permute(2, 0, 1, 3).contiguous( 575 | ).view(-1, len_v, d_v) # (n*b) x lv x dv 576 | 577 | n = sz_b * n_head 578 | 579 | if self.diag_mask is not None: 580 | if (len(self.diag_mask) <= n) or ( 581 | self.diag_mask.shape[1] != len_v): 582 | self.diag_mask = torch.ones((len_v, len_v), device=device) 583 | if self.diag_mask_flag == 'True': 584 | self.diag_mask -= torch.eye(len_v, len_v, device=device) 585 | self.diag_mask = self.diag_mask.repeat(n, 1, 1) 586 | diag_mask = self.diag_mask 587 | else: 588 | diag_mask = self.diag_mask[:n] 589 | 590 | else: 591 | self.diag_mask = (torch.ones((len_v, len_v), device=device)) 592 | if self.diag_mask_flag == 'True': 593 | self.diag_mask -= torch.eye(len_v, len_v, device=device) 594 | self.diag_mask = self.diag_mask.repeat(n, 1, 1) 595 | diag_mask = self.diag_mask 596 | 597 | if mask is not None: 598 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 599 | 600 | dynamic, attn = self.attention(q, k, v, diag_mask, mask=mask) 601 | 602 | dynamic = dynamic.view(n_head, sz_b, len_q, d_v) 603 | dynamic = dynamic.permute( 604 | 1, 2, 0, 3).contiguous().view( 605 | sz_b, len_q, -1) # b x lq x (n*dv) 606 | static = v.view(n_head, sz_b, len_q, d_v) 607 | static = static.permute( 608 | 1, 2, 0, 3).contiguous().view( 609 | sz_b, len_q, -1) # b x lq x (n*dv) 610 | 611 | dynamic = self.dropout(self.fc1(dynamic)) if self.dropout is not None else self.fc1(dynamic) 612 | static = self.dropout(self.fc2(static)) if self.dropout is not None else self.fc2(static) 613 | 614 | 615 | return dynamic, static, attn 616 | 617 | 618 | class EncoderLayer(nn.Module): 619 | '''A self-attention layer + 2 layered pff''' 620 | 621 | def __init__( 622 | self, 623 | n_head, 624 | d_model, 625 | d_k, 626 | d_v, 627 | dropout_mul, 628 | dropout_pff, 629 | diag_mask, 630 | bottle_neck): 631 | super().__init__() 632 | self.n_head = n_head 633 | self.d_k = d_k 634 | self.d_v = d_v 635 | 636 | self.mul_head_attn = MultiHeadAttention( 637 | n_head, 638 | d_model, 639 | d_k, 640 | d_v, 641 | dropout=dropout_mul, 642 | diag_mask=diag_mask, 643 | input_dim=bottle_neck) 644 | self.pff_n1 = PositionwiseFeedForward( 645 | [d_model, d_model, d_model], dropout=dropout_pff, residual=True, layer_norm=True) 646 | self.pff_n2 = PositionwiseFeedForward( 647 | [bottle_neck, d_model, d_model], dropout=dropout_pff, residual=False, layer_norm=True) 648 | 649 | # self.dropout = nn.Dropout(0.2) 650 | 651 | def forward(self, dynamic, static, slf_attn_mask, non_pad_mask): 652 | dynamic, static1, attn = self.mul_head_attn( 653 | dynamic, dynamic, static, slf_attn_mask) 654 | dynamic = self.pff_n1(dynamic * non_pad_mask) * non_pad_mask 655 | static1 = self.pff_n2(static * non_pad_mask) * non_pad_mask 656 | 657 | return dynamic, static1, attn 658 | 659 | 660 | class DataGenerator(): 661 | def __init__(self, edges, edge_weight, batch_size, num_batch_per_iter, flag=False): 662 | self.edges = edges 663 | self.edge_weight = edge_weight 664 | self.batch_size = batch_size 665 | self.num_batch_per_iter = num_batch_per_iter 666 | self.pointer = 0 667 | self.flag = flag 668 | self.shuffle() 669 | 670 | def shuffle(self): 671 | if self.flag: 672 | print("reach end, shuffling") 673 | index = np.random.permutation(len(self.edges)) 674 | self.edges = self.edges[index] 675 | self.edge_weight = self.edge_weight[index] 676 | 677 | def next_iter(self): 678 | # if self.flag: 679 | # index = self.balance_num(self.edges) 680 | # edges = self.edges[index] 681 | # edge_weight = self.edge_weight[index] 682 | # return edges, edge_weight 683 | 684 | self.pointer += self.num_batch_per_iter * self.batch_size 685 | 686 | if self.pointer <= len(self.edges): 687 | index = range(self.pointer - self.num_batch_per_iter * self.batch_size, min(self.pointer, len(self.edges))) 688 | edges = self.edges[index] 689 | edge_weight = self.edge_weight[index] 690 | return edges, edge_weight 691 | else: 692 | # print(self.pointer, len(self.edges)) 693 | index = range(self.pointer - self.num_batch_per_iter * self.batch_size, min(self.pointer, len(self.edges))) 694 | edges = self.edges[index] 695 | edge_weight = self.edge_weight[index] 696 | 697 | self.shuffle() 698 | left = self.num_batch_per_iter * self.batch_size - len(index) 699 | self.pointer = 0 700 | self.pointer += left 701 | index = range(0, self.pointer) 702 | 703 | return np.concatenate([edges, self.edges[index]]), np.concatenate([edge_weight, self.edge_weight[index]]) 704 | 705 | def balance_num(self, edges): 706 | cell = edges[:, 0] 707 | final = [] 708 | choice, counts_ = np.unique(cell, return_counts=True) 709 | # num = int(np.mean(counts_)) 710 | num = 50 711 | for c in tqdm(choice): 712 | final.append(np.random.choice(np.where(cell == c)[0], num, replace=True)) 713 | final = np.concatenate(final, axis=-1) 714 | return final -------------------------------------------------------------------------------- /History_version/Code/analysis_SPRITE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib as mpl 4 | mpl.use("Agg") 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | from tqdm import tqdm, trange 8 | from itertools import combinations 9 | import time 10 | import os 11 | import argparse 12 | 13 | def parse_args(): 14 | # Parses the node2vec arguments. 15 | parser = argparse.ArgumentParser(description="Hyper-SAGNN") 16 | 17 | parser.add_argument('-n','--num', type=int, default=0) 18 | parser.add_argument('-s', '--size', type=int, default=3) 19 | args = parser.parse_args() 20 | return args 21 | 22 | args = parse_args() 23 | size = args.size 24 | thresh_list = [[2,3],[3,5],[5,8],[8,12]] 25 | # thresh_list = [[0,2]] 26 | def build_dict(i): 27 | hash_dict = {} 28 | for j,datum in enumerate(tqdm(data)): 29 | if i not in data_set_list[j]: 30 | continue 31 | combs = combinations(datum[datum > i + 5], size - 1) 32 | for comb in combs: 33 | new_comb = (i,) + comb 34 | if new_comb in hash_dict: 35 | hash_dict[new_comb] += 1 36 | else: 37 | hash_dict[new_comb] = 1 38 | new_hash = hash_dict 39 | 40 | np.save("./dict_%dnode/hash_tuples_%d.npy" % (size,i), new_hash) 41 | del hash_dict,new_hash 42 | return 0 43 | 44 | 45 | chrom_range = np.load("../data/SPRITE/chrom_range.npy") 46 | 47 | 48 | # data = np.load("../data/SPRITE/edge_list.npy", allow_pickle=True) 49 | # new_data = [] 50 | # for datum in data: 51 | # if (len(datum) >= size) & (len(datum) < 25): 52 | # new_data.append(np.array(datum)) 53 | # data = np.array(new_data) 54 | # 55 | # np.save("./shrink_SPRITE.npy",data) 56 | data= np.load("./shrink_SPRITE.npy",allow_pickle=True) 57 | data_set_list = [set(datum) for datum in tqdm(data)] 58 | from concurrent.futures import ProcessPoolExecutor, as_completed 59 | 60 | 61 | process_list = [] 62 | MAX_WORKER = 10 63 | pool = ProcessPoolExecutor(max_workers=MAX_WORKER) 64 | 65 | node_list = np.arange(2746) 66 | job_iter = iter(node_list) 67 | jobs_left = len(node_list) 68 | while jobs_left: 69 | for i in job_iter: 70 | process_list.append(pool.submit(build_dict,i)) 71 | if len(process_list) > MAX_WORKER * 1.3: 72 | break 73 | time.sleep(1) 74 | start = time.time() 75 | for p in as_completed(process_list): 76 | a = p.result() 77 | process_list.remove(p) 78 | del p 79 | jobs_left -= 1 80 | print (jobs_left) 81 | 82 | if time.time() - start > 5: 83 | break 84 | 85 | pool.shutdown(wait=True) 86 | 87 | 88 | for thres in thresh_list: 89 | if not os.path.exists("./%d_freq/%d/" % (size,thres[0])): 90 | os.mkdir("./%d_freq/%d/" % (size,thres[0])) 91 | if not os.path.exists("./%d_freq/upper/" % size): 92 | os.mkdir ("./%d_freq/upper/" % size) 93 | def dict2freq(i): 94 | print ("start %d" % i) 95 | hash_dict = np.load("./dict_%dnode/hash_tuples_%d.npy" % (size,i), allow_pickle=True).item() 96 | keys = np.array(list(hash_dict.keys())) 97 | if len(keys) > 0: 98 | dis_list = np.zeros((len(keys),size-1)) 99 | 100 | 101 | for j in range(size - 1): 102 | dis_list[:,j] = keys[:,j+1] - keys[:,j] 103 | 104 | dis_list = np.min(dis_list,axis = -1) 105 | old_length = len(keys) 106 | keys = keys[dis_list > 5] 107 | print (old_length,len(keys)) 108 | freq = np.array([hash_dict[tuple(k)] for k in tqdm(keys)]) 109 | 110 | for thres in thresh_list: 111 | temp = keys[(freq >= thres[0]) & (freq < thres[1])] 112 | np.save("./%d_freq/%d/%d.npy" % (size,thres[0], i), temp) 113 | 114 | temp = keys[freq >= thres[1]] 115 | np.save("./%d_freq/upper/%d.npy" % (size,i), temp) 116 | del hash_dict, temp 117 | 118 | 119 | 120 | from concurrent.futures import ProcessPoolExecutor, as_completed 121 | 122 | 123 | process_list = [] 124 | MAX_WORKER = 100 125 | pool = ProcessPoolExecutor(max_workers=MAX_WORKER) 126 | node_list = np.arange(2746) 127 | job_iter = iter(node_list) 128 | jobs_left = len(node_list) 129 | while jobs_left: 130 | for i in job_iter: 131 | process_list.append(pool.submit(dict2freq,i)) 132 | if len(process_list) > MAX_WORKER * 1.3: 133 | break 134 | time.sleep(1) 135 | start = time.time() 136 | for p in as_completed(process_list): 137 | a = p.result() 138 | process_list.remove(p) 139 | del p 140 | jobs_left -= 1 141 | time.sleep(1) 142 | 143 | if time.time() - start > 10: 144 | break 145 | 146 | pool.shutdown(wait=True) 147 | 148 | 149 | 150 | for thres in thresh_list: 151 | list1 = [] 152 | for i in trange(2746): 153 | temp = np.load("./%d_freq/%d/%d.npy" % (size,thres[0], i)) 154 | if len(temp) > 0: 155 | list1.append(temp) 156 | list1 = np.concatenate(list1,axis = 0) 157 | print (list1.shape) 158 | np.save("./%d_%d_%d.npy" %(thres[0], thres[1],size),list1) 159 | list1 = [] 160 | for i in trange(2746): 161 | if i == 0: 162 | continue 163 | temp = np.load("./%d_freq/upper/%d.npy" % (size,i)) 164 | if len(temp) > 0: 165 | list1.append(temp) 166 | list1 = np.concatenate(list1,axis = 0) 167 | print (list1.shape) 168 | np.save("./upper_%d.npy" % size,list1) -------------------------------------------------------------------------------- /History_version/Code/hg19.chrom.sizes.txt: -------------------------------------------------------------------------------- 1 | chr1 249250621 2 | chr2 243199373 3 | chr3 198022430 4 | chr4 191154276 5 | chr5 180915260 6 | chr6 171115067 7 | chr7 159138663 8 | chr8 146364022 9 | chr9 141213431 10 | chr10 135534747 11 | chr11 135006516 12 | chr12 133851895 13 | chr13 115169878 14 | chr14 107349540 15 | chr15 102531392 16 | chr16 90354753 17 | chr17 81195210 18 | chr18 78077248 19 | chr19 59128983 20 | chr20 63025520 21 | chr21 48129895 22 | chr22 51304566 -------------------------------------------------------------------------------- /History_version/Code/hg38.chrom.sizes.txt: -------------------------------------------------------------------------------- 1 | chr1 248956422 2 | chr2 242193529 3 | chr3 198295559 4 | chr4 190214555 5 | chr5 181538259 6 | chr6 170805979 7 | chr7 159345973 8 | chr8 145138636 9 | chr9 138394717 10 | chr10 133797422 11 | chr11 135086622 12 | chr12 133275309 13 | chr13 114364328 14 | chr14 107043718 15 | chr15 101991189 16 | chr16 90338345 17 | chr17 83257441 18 | chr18 80373285 19 | chr19 58617616 20 | chr20 64444167 21 | chr21 46709983 22 | chr22 50818468 23 | -------------------------------------------------------------------------------- /History_version/Code/main_SPRITE.py: -------------------------------------------------------------------------------- 1 | from pybloomfilter import BloomFilter 2 | import multiprocessing 3 | from torch.nn.utils.rnn import pad_sequence 4 | from torchsummary import summary 5 | from gensim.models import Word2Vec 6 | 7 | import time 8 | import argparse 9 | import warnings 10 | import random 11 | from random_walk import random_walk 12 | from random_walk_hyper import random_walk_hyper 13 | from Modules import * 14 | from utils import * 15 | 16 | import matplotlib as mpl 17 | import seaborn as sns 18 | import matplotlib.pyplot as plt 19 | mpl.use("Agg") 20 | 21 | cpu_num = multiprocessing.cpu_count() 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | torch.backends.cudnn.benchmark = True 25 | torch.backends.cudnn.deterministic = False 26 | 27 | warnings.filterwarnings("ignore") 28 | 29 | 30 | def parse_args(): 31 | # Parses the node2vec arguments. 32 | parser = argparse.ArgumentParser(description="Hyper-SAGNN") 33 | 34 | parser.add_argument('--data', type=str, default='SPRITE') 35 | parser.add_argument('--TRY', action='store_true') 36 | parser.add_argument('--FILTER', action='store_true') 37 | parser.add_argument('--grid', type=str, default='') 38 | parser.add_argument('--remark', type=str, default='') 39 | 40 | parser.add_argument('--random-walk', action='store_true') 41 | 42 | parser.add_argument('--dimensions', type=int, default=64, 43 | help='Number of dimensions. Default is 64.') 44 | 45 | parser.add_argument('-l', '--walk-length', type=int, default=80, 46 | help='Length of walk per source. Default is 40.') 47 | 48 | parser.add_argument('-r', '--num-walks', type=int, default=40, 49 | help='Number of walks per source. Default is 10.') 50 | 51 | parser.add_argument('-k', '--window-size', type=int, default=10, 52 | help='Context size for optimization. Default is 10.') 53 | 54 | parser.add_argument('--p', type=float, default=2, 55 | help='Return hyperparameter. Default is 2.') 56 | 57 | parser.add_argument('--q', type=float, default=0.25, 58 | help='Inout hyperparameter. Default is 0.25.') 59 | 60 | parser.add_argument( 61 | '-a', 62 | '--alpha', 63 | type=float, 64 | default=0.0, 65 | help='The weight of random walk -skip-gram loss. Default is ') 66 | parser.add_argument( 67 | '--rw', 68 | type=float, 69 | default=0.01, 70 | help='The weight of reconstruction of adjacency matrix loss. Default is ') 71 | parser.add_argument('-w', '--walk', type=str, default='hyper', 72 | help='The walk type, empty stands for normal rw') 73 | parser.add_argument('-d', '--diag', type=str, default='True', 74 | help='Use the diag mask or not') 75 | parser.add_argument( 76 | '-f', 77 | '--feature', 78 | type=str, 79 | default='adj', 80 | help='Features used in the first step') 81 | 82 | args = parser.parse_args() 83 | 84 | if not args.random_walk: 85 | args.model_name = 'model_no_randomwalk' 86 | args.epoch = 25 87 | else: 88 | args.model_name = 'model_{}_'.format(args.data) 89 | args.epoch = 25 90 | if args.TRY: 91 | args.model_name = 'try' + args.model_name 92 | if not args.random_walk: 93 | args.epoch = 5 94 | else: 95 | args.epoch = 1 96 | # args.epoch = 1 97 | args.model_name += args.remark 98 | print (args.model_name) 99 | 100 | args.save_path = os.path.join( 101 | '../checkpoints/', args.data, args.model_name) 102 | if not os.path.exists(args.save_path): 103 | os.makedirs(args.save_path) 104 | return args 105 | 106 | 107 | def train_batch_hyperedge( 108 | model, 109 | loss_func, 110 | batch_data, 111 | batch_weight, 112 | y=""): 113 | 114 | x = batch_data 115 | w = batch_weight 116 | 117 | # When label is not generated, prepare the data 118 | if len(y) == 0: 119 | x, y, w, s = generate_negative(x, "train_dict", w) 120 | x, y, w, s = sync_shuffle([x, y, w, s]) 121 | else: 122 | s = torch.ones((len(y), 1)) 123 | 124 | # forward 125 | pred, recon_loss = model(x, return_recon=True) 126 | # , weight=s.float().view(-1, 1).to(device) 127 | loss = loss_func(pred, y) 128 | return pred, y, loss, recon_loss, w, s 129 | 130 | 131 | def train_epoch( 132 | model, 133 | loss_func, 134 | training_data, 135 | optimizer, 136 | batch_size): 137 | # Epoch operation in training phase 138 | # print (len(train_dict[min_size]), train_dict[min_size].capacity, len(test_dict[min_size])) 139 | edges, edge_weight = training_data 140 | y = torch.tensor([]) 141 | # y = training_y 142 | # Permutate all the data 143 | if len(y) > 0: 144 | print ("existing y") 145 | edges, edge_weight, y = sync_shuffle([edges, edge_weight, y]) 146 | else: 147 | edges, edge_weight = sync_shuffle([edges, edge_weight]) 148 | 149 | model.train() 150 | 151 | bce_total_loss = 0 152 | recon_total_loss = 0 153 | acc_list, y_list, pred_list, weight_list, size_list = [], [], [], [], [] 154 | 155 | batch_num = int(math.floor(len(edges) / batch_size)) 156 | bar = trange( 157 | batch_num, 158 | mininterval=0.1, 159 | desc=' - (Training) ', 160 | leave=False, 161 | ) 162 | for i in bar: 163 | batch_edge = edges[i * batch_size:(i + 1) * batch_size] 164 | batch_edge_weight = edge_weight[i * batch_size:(i + 1) * batch_size] 165 | batch_y = "" 166 | if len(y) > 0: 167 | batch_y = y[i * batch_size:(i + 1) * batch_size] 168 | if len(batch_y) == 0: 169 | continue 170 | 171 | pred, batch_y, loss_bce, loss_recon, batch_w, batch_s = train_batch_hyperedge( 172 | model, loss_func, batch_edge, batch_edge_weight, y=batch_y) 173 | loss = loss_bce * alpha + loss_recon * beta 174 | # loss = loss_bce + loss_recon 175 | 176 | # acc_list.append(accuracy(pred, batch_y)) 177 | y_list.append(batch_y) 178 | pred_list.append(pred) 179 | weight_list.append(batch_w) 180 | size_list.append(batch_s) 181 | 182 | for opt in optimizer: 183 | opt.zero_grad() 184 | 185 | # backward 186 | loss.backward() 187 | 188 | # update parameters 189 | for opt in optimizer: 190 | opt.step() 191 | 192 | bar.set_description(" - (Training) BCE: %.4f recon: %.4f" % 193 | (bce_total_loss / (i + 1), recon_total_loss / (i + 1))) 194 | bce_total_loss += loss_bce.item() 195 | recon_total_loss += loss_recon.item() 196 | y = torch.cat(y_list) 197 | pred = torch.cat(pred_list) 198 | size_list = torch.cat(size_list) 199 | weight_list = torch.cat(weight_list) 200 | auc1_1, auc2_1 = roc_auc_cuda(y, pred, weight_list, max_size) 201 | acc_1 = accuracy(pred, y, weight_list, max_size) 202 | 203 | auc1, auc2 = roc_auc_cuda(y, pred, size_list, max_size) 204 | acc = accuracy(pred, y, size_list, max_size) 205 | 206 | return bce_total_loss / batch_num, recon_total_loss / batch_num, acc_1+acc, auc1_1 + auc1, auc2_1+auc2 207 | 208 | 209 | def eval_epoch(model, loss_func, validation_data, batch_size): 210 | ''' Epoch operation in evaluation phase ''' 211 | bce_total_loss = 0 212 | recon_total_loss = 0 213 | 214 | model.eval() 215 | with torch.no_grad(): 216 | validation_data, validation_weight = validation_data 217 | y = "" 218 | 219 | validation_data, validation_weight = sync_shuffle( 220 | [validation_data, validation_weight],10000) 221 | 222 | pred, label, size_list, weight_list = [], [], [], [] 223 | 224 | for i in tqdm(range(int(math.floor(len(validation_data) / batch_size))), 225 | mininterval=0.1, desc=' - (Validation) ', leave=False): 226 | # prepare data 227 | batch_x = validation_data[i * batch_size:(i + 1) * batch_size] 228 | batch_w = validation_weight[i * batch_size:(i + 1) * batch_size] 229 | 230 | if len(y) == 0: 231 | batch_x, batch_y, batch_w, batch_s = generate_negative( 232 | batch_x, "test_dict", weight=batch_w) 233 | else: 234 | batch_y = y[i * batch_size:(i + 1) * batch_size] 235 | 236 | batch_x, batch_y, batch_w, batch_s = sync_shuffle( 237 | [batch_x, batch_y, batch_w, batch_s]) 238 | pred_batch, recon_loss = model(batch_x, return_recon=True) 239 | size_list.append(batch_s) 240 | pred.append(pred_batch) 241 | label.append(batch_y) 242 | weight_list.append(batch_w) 243 | # weight=batch_s.float().view(-1, 1).to(device) 244 | loss = loss_func(pred_batch, batch_y) 245 | recon_total_loss += recon_loss.item() 246 | bce_total_loss += loss.item() 247 | 248 | pred = torch.cat(pred, dim=0) 249 | label = torch.cat(label, dim=0) 250 | size_list = torch.cat(size_list, dim=0) 251 | weight_list = torch.cat(weight_list, dim = 0) 252 | acc_1 = accuracy(pred, label, weight_list, max_size) 253 | auc1_1, auc2_1 = roc_auc_cuda(label, pred, weight_list, max_size) 254 | 255 | acc = accuracy(pred, label, size_list, max_size) 256 | auc1, auc2 = roc_auc_cuda(label, pred, size_list, max_size) 257 | 258 | return bce_total_loss / (i + 1), recon_total_loss / \ 259 | (i + 1), acc_1 + acc, auc1_1 + auc1, auc2_1 + auc2 260 | 261 | 262 | def train(model, 263 | loss, 264 | training_data, 265 | validation_data, 266 | optimizer, 267 | epochs, 268 | batch_size): 269 | valid_accus = [0] 270 | # outlier_data = generate_outlier() 271 | edges, edge_weight = training_data 272 | training_data_new = training_data 273 | training_data_generator = DataGenerator( 274 | edges, edge_weight, int(batch_size), 300, True) 275 | 276 | for epoch_i in range(epochs): 277 | 278 | save_embeddings(model, True) 279 | print ('[ Epoch', epoch_i, 'of', epochs, ']') 280 | 281 | start = time.time() 282 | edges_part, edge_weight_part = training_data_generator.next_iter() 283 | training_data_new = edges_part, edge_weight_part 284 | 285 | bce_loss, recon_loss, train_accu, auc1, auc2 = train_epoch(model, loss, training_data_new, optimizer, batch_size) 286 | 287 | print ( 288 | ' - (Training) bce: {bce_loss: 7.4f},' 289 | 'recon: {recon_loss: 7.4f}' 290 | ' acc: {accu}, auc: {auc1}, aupr: {auc2}, ' 291 | 'elapse: {elapse:3.3f} s'.format( 292 | bce_loss=bce_loss, 293 | recon_loss=recon_loss, 294 | accu=train_accu, 295 | auc1=auc1, 296 | auc2=auc2, 297 | elapse=( 298 | time.time() - start))) 299 | 300 | start = time.time() 301 | valid_bce_loss, recon_loss, valid_accu, valid_auc1, valid_auc2 = eval_epoch(model, loss, validation_data, batch_size) 302 | print ( 303 | ' - (Validation-hyper) bce: {bce_loss: 7.4f}, recon: {recon_loss: 7.4f},' 304 | ' acc: {accu},' 305 | ' auc: {auc1}, aupr: {auc2},' 306 | 'elapse: {elapse:3.3f} s'.format( 307 | bce_loss=valid_bce_loss, 308 | recon_loss=recon_loss, 309 | accu=valid_accu, 310 | auc1=valid_auc1, 311 | auc2=valid_auc2, 312 | elapse=( 313 | time.time() - start))) 314 | valid_aupr_final = float(valid_auc2.split(" ")[-2]) 315 | valid_accus += [valid_aupr_final] 316 | 317 | checkpoint = { 318 | 'model_link': model.state_dict(), 319 | 'epoch': epoch_i} 320 | 321 | model_name = 'model.chkpt' 322 | 323 | if valid_aupr_final >= max(valid_accus): 324 | torch.save(checkpoint, os.path.join(args.save_path, model_name)) 325 | 326 | torch.cuda.empty_cache() 327 | 328 | checkpoint = torch.load(os.path.join(args.save_path, model_name)) 329 | model.load_state_dict(checkpoint['model_link']) 330 | 331 | valid_bce_loss, recon_loss, valid_accu, valid_auc1, valid_auc2 = eval_epoch(model, loss, validation_data, 332 | batch_size) 333 | print( 334 | ' - (Validation-hyper) bce: {bce_loss: 7.4f}, recon: {recon_loss: 7.4f},' 335 | ' acc: {accu},' 336 | ' auc: {auc1}, aupr: {auc2},' 337 | 'elapse: {elapse:3.3f} s'.format( 338 | bce_loss=valid_bce_loss, 339 | recon_loss=recon_loss, 340 | accu=valid_accu, 341 | auc1=valid_auc1, 342 | auc2=valid_auc2, 343 | elapse=( 344 | time.time() - start))) 345 | 346 | 347 | 348 | def neighbor_check(temp, dict): 349 | return tuple(temp) in dict 350 | flag = False 351 | for i in range(len(temp)): 352 | for j in [-1, 0, 1]: 353 | a = np.copy(temp) 354 | a[i] += j 355 | a.sort() 356 | if tuple(a) in dict: 357 | flag = True 358 | break 359 | if flag: 360 | break 361 | return flag 362 | 363 | 364 | def generate_negative(x, dict1, weight=""): 365 | if len(weight) == 0: 366 | weight = torch.ones(len(x), dtype=torch.float) 367 | mode = "" 368 | if dict1 == 'train_dict': 369 | dict1 = train_dict 370 | mode = "train" 371 | elif dict1 == 'test_dict': 372 | dict1 = test_dict 373 | mode = "test" 374 | 375 | change_num_list = [[] for i in range(max_size + 1)] 376 | for s in range(min_size, max_size + 1): 377 | change_num = np.random.binomial(s, 0.5, len(x) * (neg_num * 2)) 378 | change_num = change_num[change_num != 0] 379 | 380 | change_num_list[s] = list(change_num) 381 | 382 | neg_list = [] 383 | new_x = [] 384 | new_index = [] 385 | neg_weight = [] 386 | max_id = int(num[-1]) 387 | size_list = [] 388 | size_neg_list = [] 389 | 390 | 391 | for j, sample in enumerate(x): 392 | for i in range(neg_num): 393 | # generate decomposed sample 394 | # if len(sample) > min_size: 395 | # decompose_sample = np.copy(sample) 396 | # decompose_size = int( 397 | # min(max_size - min_size + 1, len(sample) - min_size + 1) * random.random()) + min_size 398 | # if decompose_size == len(sample): 399 | # decompose_sample = np.copy(sample) 400 | # else: 401 | # decompose_sample = np.copy(sample) 402 | # np.random.shuffle(decompose_sample) 403 | # decompose_sample = decompose_sample[:decompose_size] 404 | # decompose_sample.sort() 405 | # 406 | # if tuple(decompose_sample) not in dict1[len(decompose_sample)]: 407 | # dict1[len(decompose_sample)].add(tuple(decompose_sample)) 408 | # if mode == 'train': 409 | # test_dict[len(decompose_sample)].add(tuple(decompose_sample)) 410 | # 411 | # else: 412 | # decompose_sample = np.copy(sample) 413 | # 414 | # if tuple(decompose_sample) not in dict1[len(decompose_sample)]: 415 | # dict1[len(decompose_sample)].add(tuple(decompose_sample)) 416 | # if mode == 'train': 417 | # test_dict[len(decompose_sample)].add(tuple(decompose_sample)) 418 | 419 | decompose_sample = np.copy(sample) 420 | list1 = change_num_list[decompose_sample.shape[-1]] 421 | change_num = list1.pop() 422 | changes = np.random.choice(np.arange(decompose_sample.shape[-1]),change_num,replace=False) 423 | simple_or_hard = np.random.rand() 424 | temp = np.copy(decompose_sample) 425 | trial = 0 426 | flag = False 427 | while neighbor_check(temp, dict1[(len(temp))]): 428 | temp = np.copy(decompose_sample) 429 | trial += 1 430 | if trial >= 1000: 431 | temp = "" 432 | break 433 | 434 | for change in changes: 435 | if temp[change] not in node2chrom: 436 | print (temp, decompose_sample) 437 | chrom = node2chrom[temp[change]] 438 | start, end = chrom_range[chrom] 439 | 440 | # Only change one node 441 | if simple_or_hard <= pair_ratio: 442 | # temp[change] = np.random.randint(int(start), int(end), 1) 443 | temp[change] = int( 444 | math.floor( 445 | (end - start) * random.random())) + start 446 | else: 447 | # Only one node type 448 | temp = np.random.randint( 449 | 1, max_id, decompose_sample.shape[-1]) 450 | 451 | temp = list(set(temp)) 452 | 453 | if len(temp) < len(decompose_sample): 454 | temp = np.copy(decompose_sample) 455 | 456 | temp.sort() 457 | dis_list = [] 458 | for k in range(len(temp) - 1): 459 | dis_list.append(temp[k + 1] - temp[k]) 460 | if np.min(dis_list) <= 5: 461 | temp = np.copy(decompose_sample) 462 | 463 | if len(temp) > 0: 464 | if i == 0: 465 | new_x.append(decompose_sample) 466 | new_index.append(j) 467 | size_list.append(len(decompose_sample)) 468 | 469 | neg_list.append(temp) 470 | size_neg_list.append(len(temp)) 471 | neg_weight.append(weight[j]) 472 | 473 | 474 | new_weight = weight[np.array(new_index)] 475 | new_weight = torch.tensor(new_weight)#.to(device) 476 | neg_weight = torch.tensor(neg_weight) 477 | size_list = torch.Tensor(np.concatenate( 478 | [np.array(size_list), np.array(size_neg_list)], axis=0)) 479 | x = np2tensor_hyper(new_x, dtype=torch.long) 480 | neg = np2tensor_hyper(neg_list, dtype=torch.long) 481 | x = pad_sequence(x, batch_first=True, padding_value=0).to(device) 482 | neg = pad_sequence(neg, batch_first=True, padding_value=0).to(device) 483 | 484 | a = torch.cat([x, neg]) 485 | 486 | return a,\ 487 | torch.cat([torch.ones((len(x), 1), device=device), (torch.zeros((len(neg), 1), device=device))]),\ 488 | torch.cat([new_weight, neg_weight], dim=0),\ 489 | size_list 490 | 491 | 492 | def save_embeddings(model, origin=False): 493 | model.eval() 494 | with torch.no_grad(): 495 | ids = np.arange(num_list[-1]) + 1 496 | ids = torch.Tensor(ids).long().to(device).view(-1, 1) 497 | embeddings = [] 498 | for j in range(math.ceil(len(ids) / batch_size)): 499 | x = ids[j * batch_size:min((j + 1) * batch_size, len(ids))] 500 | if origin: 501 | embed = model.get_node_embeddings(x) 502 | else: 503 | embed = model.get_embedding_static(x) 504 | embed = embed.detach().cpu().numpy() 505 | embeddings.append(embed) 506 | 507 | embeddings = np.concatenate(embeddings, axis=0)[:, 0, :] 508 | 509 | np.save("../mymodel_%d.npy" % (0), embeddings) 510 | 511 | if origin: 512 | old_static = np.load("../mymodel_%d_origin.npy" % (0)) 513 | try: 514 | update_rate = np.sum((old_static - embeddings) ** 2, axis=-1) / np.sum(old_static ** 2, axis=-1) 515 | print("update_rate: %f\t%f" % (np.min(update_rate), np.max(update_rate))) 516 | except: 517 | pass 518 | np.save("../mymodel_%d_origin.npy" % (0), embeddings) 519 | 520 | torch.cuda.empty_cache() 521 | return embeddings 522 | 523 | def oe(matrix): 524 | for i in range(len(matrix)): 525 | if i == 0: 526 | continue 527 | x = [] 528 | y = [] 529 | for j in range(len(matrix) - i): 530 | x.append(j) 531 | y.append(j + i) 532 | x = np.array(x) 533 | y = np.array(y) 534 | # print (x,y) 535 | 536 | matrix[x , y] /= (np.mean(matrix[x,y])+1e-15) 537 | matrix[y, x] /= (np.mean(matrix[y,x])+1e-15) 538 | matrix = np.log(1+matrix) 539 | return matrix 540 | 541 | def predict(model, input): 542 | model.eval() 543 | output = [] 544 | with torch.no_grad(): 545 | for j in trange(math.ceil(len(input) / batch_size)): 546 | x = input[j * batch_size:min((j + 1) * batch_size, len(input))] 547 | x = np2tensor_hyper(x, dtype=torch.long) 548 | x = pad_sequence(x, batch_first=True, padding_value=0).to(device) 549 | output.append(model(x).detach().cpu().numpy()) 550 | output = np.concatenate(output, axis=0) 551 | torch.cuda.empty_cache() 552 | return output 553 | 554 | 555 | args = parse_args() 556 | neg_num = 3 557 | batch_size = 96 558 | neg_num_w2v = 5 559 | bottle_neck = args.dimensions 560 | pair_ratio = 1.0 561 | dynamic_dict = False 562 | max_size = 3 563 | min_size = 3 564 | loss = F.binary_cross_entropy 565 | 566 | neighbor_mask = [] 567 | 568 | 569 | chrom_range = np.load("../data/SPRITE/chrom_range.npy") 570 | node2chrom = np.load("../data/SPRITE/node2chrom.npy", allow_pickle=True).item() 571 | num = [] 572 | for v in chrom_range: 573 | num.append(v[1] - v[0]) 574 | 575 | num_list = np.cumsum(num) 576 | zero_num_list = np.array([0] + list(num_list)) 577 | print ("Node type num", num) 578 | 579 | data_list = [] 580 | for size in range(min_size, max_size + 1): 581 | for thresh in [[3,5],[5,8],[8,12],"upper"]: 582 | if type(thresh) == list: 583 | name = "%d_%d" % (thresh[0],thresh[1]) 584 | else: 585 | name = thresh 586 | 587 | data = np.load("../data/SPRITE/tuples/%s_filter_%d.npy" %(name,size)).astype('int') 588 | for datum in data: 589 | data_list.append(datum) 590 | 591 | data = np.array(data_list) 592 | print (len(data)) 593 | 594 | attribute_dict = None 595 | 596 | if args.feature == 'adj': 597 | embeddings_initial = [] 598 | inter_initial = np.load("../data/SPRITE/inter_adj_SPRITE.npy").astype('float32') 599 | 600 | adj = np.load("../data/SPRITE/intra_adj_SPRITE.npy").astype('float32') 601 | # adj = np.load("../data/SPRITE/adj.npy").astype('float32') 602 | for v in chrom_range: 603 | # adj_list = [] 604 | # for i in range(5): 605 | # adj = np.load("../data/SPRITE/adj_%d.npy" % i).astype('float32') 606 | # adj = adj[v[0] - 1 :v[1] - 1, v[0] - 1:v[1] - 1] 607 | # # adj = np.log(1+adj) 608 | # # # adj = oe(adj) 609 | # adj = np.corrcoef(adj).astype('float32') 610 | # adj_list.append(adj) 611 | # temp = np.concatenate(adj_list, axis=-1) 612 | 613 | temp = adj[v[0] - 1:v[1] - 1,v[0] - 1:v[1] - 1] 614 | # temp = np.log(1 + temp) 615 | temp = np.corrcoef(temp).astype('float32') 616 | temp[np.isnan(temp)] = 0.0 617 | # temp = oe(temp) 618 | # temp = np.concatenate( 619 | # [adj[v[0]:v[1], 0: v[0]], adj[v[0]:v[1], v[1]:]], axis=-1) 620 | # temp /= (np.sum(temp, axis=1, keepdims=True) + 1e-10) * 100 621 | # temp /= (np.max(temp, axis=0, keepdims=True) + 1e-10) 622 | embeddings_initial.append(temp) 623 | # embeddings_initial = [adj] 624 | 625 | print(chrom_range) 626 | # print (train_weight) 627 | # print (train_weight, np.min(train_weight), np.max(train_weight)) 628 | # train_weight_mean = np.mean(train_weight) 629 | # train_weight = train_weight / train_weight_mean * neg_num 630 | # test_weight = test_weight / train_weight_mean * neg_num 631 | 632 | num = torch.as_tensor(num) 633 | num_list = torch.as_tensor(num_list) 634 | print (num, num_list) 635 | print ("walk type", args.walk) 636 | 637 | if args.feature == 'walk': 638 | node_list = np.arange(num_list[-1]).astype('int') 639 | if args.walk == 'hyper': 640 | walk_path = random_walk_hyper(args, node_list, data) 641 | else: 642 | walk_path = random_walk(args, num, data) 643 | del node_list 644 | 645 | compress = True 646 | # Note that, no matter how many node types are here, make sure the 647 | # hyperedge (N1,N2,N3,...) has id, N1 < N2 < N3... 648 | if not dynamic_dict: 649 | test_dict = build_hash(data, compress=compress, max_size=max_size, 650 | min_size=min_size, fname="test") 651 | train_dict = test_dict 652 | # train_dict = build_hash(train_data, compress = compress, max_size=max_size, min_size = min_size, fname="test") 653 | else: 654 | train_dict = [BloomFilter(1e8, 1e-3) for i in range(max_size + 1)] 655 | test_dict = [BloomFilter(1e8, 1e-3) for i in range(max_size + 1)] 656 | 657 | data = [] 658 | intra_inter = [] 659 | for size in range(min_size, max_size + 1): 660 | temp_list = [] 661 | for thresh in [[3, 5], [5, 8], [8, 12], "upper"]: 662 | if type(thresh) == list: 663 | if size == 3 and thresh[0] <= 5: 664 | continue 665 | if size == 4 and thresh[0] <= 3: 666 | continue 667 | 668 | if type(thresh) == list: 669 | name = "%d_%d" % (thresh[0], thresh[1]) 670 | else: 671 | name = thresh 672 | 673 | temp = np.load("../data/SPRITE/tuples/%s_filter_%d.npy" % (name, size)).astype('int') 674 | temp_list.append(temp) 675 | intra_inter.append(np.load("../data/SPRITE/tuples/%s_%d_intra_inter.npy" % (name, size))) 676 | temp_list = np.concatenate(temp_list,axis = 0) 677 | for datum in temp_list: 678 | data.append(datum) 679 | 680 | 681 | data = np.array(data) 682 | intra_inter = np.concatenate(intra_inter,axis = 0) 683 | 684 | index = np.arange(len(data)) 685 | weight = np.ones((len(data)), dtype='float32') 686 | weight = intra_inter 687 | np.random.shuffle(index) 688 | split = int(0.5 * len(index)) 689 | train_data = data[index[:split]] 690 | test_data = data[index[split:]] 691 | train_weight = weight[index[:split]] 692 | test_weight = weight[index[split:]] 693 | 694 | 695 | del data 696 | 697 | print ("train data amount", len(train_data)) 698 | 699 | print ("dict_size", len(train_dict[-1]), len(test_dict[-1])) 700 | 701 | if args.feature == 'walk': 702 | # Note that for this part, the word2vec still takes sentences with 703 | # words starts at "0" 704 | if not args.TRY and os.path.exists( 705 | "../%s_wv_%d_%s.npy" % 706 | (args.data, args.dimensions, args.walk)): 707 | A = np.load( 708 | "../%s_wv_%d_%s.npy" % 709 | (args.data, 710 | args.dimensions, 711 | args.walk), 712 | allow_pickle=True) 713 | else: 714 | print ("start loading") 715 | walks = np.loadtxt(walk_path, delimiter=" ").astype('int') 716 | start = time.time() 717 | split_num = 20 718 | pool = ProcessPoolExecutor(max_workers=split_num) 719 | process_list = [] 720 | walks = np.array_split(walks, split_num) 721 | 722 | result = [] 723 | print ("Start turning path to strs") 724 | for walk in walks: 725 | process_list.append(pool.submit(walkpath2str, walk)) 726 | 727 | for p in as_completed(process_list): 728 | result += p.result() 729 | 730 | pool.shutdown(wait=True) 731 | 732 | walks = result 733 | print ( 734 | "Finishing Loading and processing %.2f s" % 735 | (time.time() - start)) 736 | print ("Start Word2vec") 737 | import multiprocessing 738 | 739 | print ("num cpu cores", multiprocessing.cpu_count()) 740 | w2v = Word2Vec( 741 | walks, 742 | size=args.dimensions, 743 | window=args.window_size, 744 | min_count=0, 745 | sg=1, 746 | iter=1, 747 | workers=multiprocessing.cpu_count()) 748 | wv = w2v.wv 749 | A = [wv[str(i)] for i in range(num_list[-1])] 750 | np.save("../%s_wv_%d_%s.npy" % 751 | (args.data, args.dimensions, args.walk), A) 752 | 753 | from sklearn.preprocessing import StandardScaler 754 | 755 | A = StandardScaler().fit_transform(A) 756 | 757 | A = np.concatenate( 758 | (np.zeros((1, A.shape[-1]), dtype='float32'), A), axis=0) 759 | A = A.astype('float32') 760 | A = torch.tensor(A).to(device) 761 | print (A.shape) 762 | 763 | node_embedding = Wrap_Embedding(int( 764 | num_list[-1] + 1), args.dimensions, scale_grad_by_freq=False, padding_idx=0, sparse=False) 765 | node_embedding.weight = nn.Parameter(A) 766 | 767 | elif args.feature == 'adj': 768 | flag = False 769 | # node_embedding = MultipleEmbedding_back( 770 | # embeddings_initial, 771 | # bottle_neck, 772 | # flag, 773 | # num_list).to(device) 774 | 775 | node_embedding = MultipleEmbedding( 776 | embeddings_initial, 777 | bottle_neck, 778 | flag, 779 | num_list, chrom_range, inter_initial).to(device) 780 | # node_embedding = Wrap_Embedding(int( 781 | # num_list[-1] + 1), args.dimensions, scale_grad_by_freq=False, padding_idx=0, sparse=False) 782 | 783 | classifier_model = Classifier( 784 | n_head=8, 785 | d_model=args.dimensions, 786 | d_k=16, 787 | d_v=16, 788 | node_embedding=node_embedding, 789 | diag_mask=args.diag, 790 | bottle_neck=bottle_neck).to(device) 791 | 792 | save_embeddings(classifier_model, True) 793 | 794 | 795 | 796 | summary(classifier_model, (3,)) 797 | 798 | params_list = list(classifier_model.parameters()) 799 | 800 | if args.feature == 'adj': 801 | # optimizer = torch.optim.RMSprop(params_list, lr=1e-3) 802 | optimizer = torch.optim.AdamW(params_list, lr=1e-3, amsgrad=False) 803 | else: 804 | optimizer = torch.optim.RMSprop(params_list, lr=1e-3) 805 | 806 | model_parameters = filter(lambda p: p.requires_grad, params_list) 807 | params = sum([np.prod(p.size()) for p in model_parameters]) 808 | print ("params to be trained", params) 809 | # 810 | alpha = 0.0 811 | beta = 1.0 812 | train(classifier_model, 813 | loss=loss, 814 | training_data=(train_data, train_weight), 815 | validation_data=(test_data, test_weight), 816 | optimizer=[optimizer], epochs=5, batch_size=batch_size) 817 | 818 | alpha = 1.0 819 | beta = 1.0 820 | train(classifier_model, 821 | loss=loss, 822 | training_data=(train_data, train_weight), 823 | validation_data=(test_data, test_weight), 824 | optimizer=[optimizer], epochs=30, batch_size=batch_size) 825 | 826 | model_name = 'model.chkpt' 827 | checkpoint = torch.load(os.path.join(args.save_path, model_name)) 828 | classifier_model.load_state_dict(checkpoint['model_link']) 829 | 830 | save_embeddings(classifier_model,True) -------------------------------------------------------------------------------- /History_version/Code/main_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pad_sequence 5 | from torchsummary import summary 6 | from gensim.models import Word2Vec 7 | import random 8 | from scipy.sparse import csr_matrix 9 | from scipy.sparse import vstack as s_vstack 10 | import numpy as np 11 | import os 12 | import time 13 | import datetime 14 | import math 15 | import argparse 16 | import warnings 17 | from concurrent.futures import as_completed, ProcessPoolExecutor 18 | from tqdm import tqdm, trange 19 | import umap 20 | from random_walk import random_walk 21 | from random_walk_hyper import random_walk_hyper 22 | from Modules import * 23 | from utils import * 24 | 25 | from sklearn.manifold import t_sne, TSNE, MDS, Isomap 26 | from sklearn.decomposition import PCA 27 | from sklearn.cluster import KMeans, AgglomerativeClustering 28 | from sklearn.metrics.cluster import adjusted_rand_score as ARI 29 | from torch.utils.tensorboard import SummaryWriter 30 | import matplotlib as mpl 31 | 32 | mpl.use("Agg") 33 | import matplotlib.pyplot as plt 34 | import seaborn as sns 35 | import multiprocessing 36 | from pybloomfilter import BloomFilter 37 | 38 | cpu_num = multiprocessing.cpu_count() 39 | 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | torch.backends.cudnn.benchmark = True 42 | torch.backends.cudnn.deterministic = True 43 | device_ids = [0, 1] 44 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 45 | 46 | warnings.filterwarnings("ignore") 47 | 48 | 49 | def parse_args(): 50 | # Parses the node2vec arguments. 51 | parser = argparse.ArgumentParser(description="Run node2vec.") 52 | 53 | parser.add_argument('--data', type=str, default='drop') 54 | parser.add_argument('--TRY', action='store_true') 55 | parser.add_argument('--FILTER', action='store_true') 56 | parser.add_argument('--grid', type=str, default='') 57 | parser.add_argument('--remark', type=str, default='') 58 | 59 | parser.add_argument('--random-walk', action='store_true') 60 | 61 | parser.add_argument('--dimensions', type=int, default=64, 62 | help='Number of dimensions. Default is 64.') 63 | 64 | parser.add_argument('-l', '--walk-length', type=int, default=80, 65 | help='Length of walk per source. Default is 80.') 66 | 67 | parser.add_argument('-r', '--num-walks', type=int, default=40, 68 | help='Number of walks per source. Default is 10.') 69 | 70 | parser.add_argument('-k', '--window-size', type=int, default=40, 71 | help='Context size for optimization. Default is 10.') 72 | 73 | parser.add_argument('-i', '--iter', default=1, type=int, 74 | help='Number of epochs in SGD') 75 | 76 | parser.add_argument('--workers', type=int, default=8, 77 | help='Number of parallel workers. Default is 8.') 78 | 79 | parser.add_argument('--p', type=float, default=2, 80 | help='Return hyperparameter. Default is 1.') 81 | 82 | parser.add_argument('--q', type=float, default=0.25, 83 | help='Inout hyperparameter. Default is 1.') 84 | 85 | parser.add_argument('-a', '--alpha', type=float, default=1.0, 86 | help='The weight of node2vec loss. Default is ') 87 | parser.add_argument('-w', '--walk', type=str, default='hyper', 88 | help='The walk type, empty stands for normal rw') 89 | parser.add_argument('-d', '--diag', type=str, default='True', 90 | help='Use the diagz mask or not') 91 | parser.add_argument( 92 | '-f', 93 | '--feature', 94 | type=str, 95 | default='adj', 96 | help='Features used in the first step') 97 | 98 | args = parser.parse_args() 99 | 100 | if not args.random_walk: 101 | args.model_name = 'model_no_randomwalk' 102 | args.epoch = 25 103 | else: 104 | args.model_name = 'model_{}_'.format(args.data) 105 | args.epoch = 25 106 | if args.TRY: 107 | args.model_name = 'try' + args.model_name 108 | if not args.random_walk: 109 | args.epoch = 5 110 | else: 111 | args.epoch = 1 112 | # args.epoch = 1 113 | args.model_name += args.remark 114 | print(args.model_name) 115 | 116 | args.save_path = os.path.join( 117 | '../checkpoints/', args.data, args.model_name) 118 | if not os.path.exists(args.save_path): 119 | os.makedirs(args.save_path) 120 | return args 121 | 122 | 123 | def train_batch_hyperedge( 124 | model, 125 | loss_func, 126 | batch_data, 127 | batch_weight, 128 | y=""): 129 | x = batch_data 130 | w = batch_weight 131 | 132 | # When label is not generated, prepare the data 133 | if len(y) == 0: 134 | x, y, w, s = generate_negative(x, "train_dict", w) 135 | x, y, w, s = sync_shuffle([x, y, w, s]) 136 | else: 137 | s = torch.ones((len(y), 1)) 138 | 139 | # forward 140 | pred, recon_loss = model(x, return_recon=True) 141 | # , weight=s.float().view(-1, 1).to(device) 142 | loss = loss_func(pred, y) 143 | return pred, y, loss, recon_loss, w, s 144 | 145 | 146 | def train_epoch( 147 | model, 148 | loss_func, 149 | training_data, 150 | optimizer, 151 | batch_size): 152 | # Epoch operation in training phase 153 | # print (len(train_dict[min_size]), train_dict[min_size].capacity, len(test_dict[min_size])) 154 | edges, edge_weight = training_data 155 | y = torch.tensor([]) 156 | # y = training_y 157 | # Permutate all the data 158 | if len(y) > 0: 159 | print("existing y") 160 | edges, edge_weight, y = sync_shuffle([edges, edge_weight, y]) 161 | else: 162 | edges, edge_weight = sync_shuffle([edges, edge_weight]) 163 | 164 | model.train() 165 | 166 | bce_total_loss = 0 167 | recon_total_loss = 0 168 | acc_list, y_list, pred_list, weight_list, size_list = [], [], [], [], [] 169 | 170 | batch_num = int(math.floor(len(edges) / batch_size)) 171 | bar = trange( 172 | batch_num, 173 | mininterval=0.1, 174 | desc=' - (Training) ', 175 | leave=False, 176 | ) 177 | for i in bar: 178 | batch_edge = edges[i * batch_size:(i + 1) * batch_size] 179 | batch_edge_weight = edge_weight[i * batch_size:(i + 1) * batch_size] 180 | batch_y = "" 181 | if len(y) > 0: 182 | batch_y = y[i * batch_size:(i + 1) * batch_size] 183 | if len(batch_y) == 0: 184 | continue 185 | 186 | pred, batch_y, loss_bce, loss_recon, batch_w, batch_s = train_batch_hyperedge( 187 | model, loss_func, batch_edge, batch_edge_weight, y=batch_y) 188 | loss = loss_bce + loss_recon 189 | # loss = loss_bce + loss_recon 190 | 191 | # acc_list.append(accuracy(pred, batch_y)) 192 | y_list.append(batch_y) 193 | pred_list.append(pred) 194 | weight_list.append(batch_w) 195 | size_list.append(batch_s) 196 | 197 | for opt in optimizer: 198 | opt.zero_grad() 199 | 200 | # backward 201 | loss.backward() 202 | 203 | # update parameters 204 | for opt in optimizer: 205 | opt.step() 206 | 207 | bar.set_description(" - (Training) BCE: %.4f recon: %.4f" % 208 | (bce_total_loss / (i + 1), recon_total_loss / (i + 1))) 209 | bce_total_loss += loss_bce.item() 210 | recon_total_loss += loss_recon.item() 211 | y = torch.cat(y_list) 212 | pred = torch.cat(pred_list) 213 | size_list = torch.cat(size_list) 214 | weight_list = torch.cat(weight_list) 215 | 216 | auc1, auc2 = roc_auc_cuda(y, pred, size_list, max_size) 217 | acc = accuracy(pred, y, size_list, max_size) 218 | 219 | return bce_total_loss / batch_num, recon_total_loss / batch_num, acc, auc1, auc2 220 | 221 | 222 | def eval_epoch(model, loss_func, validation_data, batch_size,final=False): 223 | ''' Epoch operation in evaluation phase ''' 224 | bce_total_loss = 0 225 | recon_total_loss = 0 226 | 227 | model.eval() 228 | with torch.no_grad(): 229 | validation_data, validation_weight = validation_data 230 | y = "" 231 | if final: 232 | validation_data, validation_weight = sync_shuffle( 233 | [validation_data, validation_weight], -1) 234 | else: 235 | validation_data, validation_weight = sync_shuffle( 236 | [validation_data, validation_weight], 50000) 237 | 238 | pred, label, size_list, weight_list = [], [], [], [] 239 | 240 | for i in tqdm(range(int(math.floor(len(validation_data) / batch_size))), 241 | mininterval=0.1, desc=' - (Validation) ', leave=False): 242 | # prepare data 243 | batch_x = validation_data[i * batch_size:(i + 1) * batch_size] 244 | batch_w = validation_weight[i * batch_size:(i + 1) * batch_size] 245 | 246 | if len(y) == 0: 247 | batch_x, batch_y, batch_w, batch_s = generate_negative( 248 | batch_x, "test_dict", weight=batch_w) 249 | else: 250 | batch_y = y[i * batch_size:(i + 1) * batch_size] 251 | 252 | batch_x, batch_y, batch_w, batch_s = sync_shuffle( 253 | [batch_x, batch_y, batch_w, batch_s]) 254 | pred_batch, recon_loss = model(batch_x, return_recon=True) 255 | size_list.append(batch_s) 256 | pred.append(pred_batch) 257 | label.append(batch_y) 258 | weight_list.append(batch_w) 259 | # weight=batch_s.float().view(-1, 1).to(device) 260 | loss = loss_func(pred_batch, batch_y) 261 | recon_total_loss += recon_loss.item() 262 | bce_total_loss += loss.item() 263 | 264 | pred = torch.cat(pred, dim=0) 265 | label = torch.cat(label, dim=0) 266 | size_list = torch.cat(size_list, dim=0) 267 | weight_list = torch.cat(weight_list, dim=0) 268 | 269 | acc = accuracy(pred, label, size_list, max_size) 270 | auc1, auc2 = roc_auc_cuda(label, pred, size_list, max_size) 271 | 272 | return bce_total_loss / (i + 1), recon_total_loss / \ 273 | (i + 1), acc, auc1, auc2 274 | def train(model, 275 | loss, 276 | training_data, 277 | validation_data, 278 | optimizer, 279 | epochs, 280 | batch_size): 281 | valid_accus = [0] 282 | # outlier_data = generate_outlier() 283 | edges, edge_weight = training_data 284 | training_data_new = training_data 285 | training_data_generator = DataGenerator( 286 | edges, edge_weight, int(batch_size), 300, True) 287 | 288 | for epoch_i in range(epochs): 289 | 290 | save_embeddings(model, True) 291 | print ('[ Epoch', epoch_i, 'of', epochs, ']') 292 | 293 | start = time.time() 294 | edges_part, edge_weight_part = training_data_generator.next_iter() 295 | training_data_new = edges_part, edge_weight_part 296 | 297 | bce_loss, recon_loss, train_accu, auc1, auc2 = train_epoch(model, loss, training_data_new, optimizer, batch_size) 298 | 299 | print ( 300 | ' - (Training) bce: {bce_loss: 7.4f},' 301 | 'recon: {recon_loss: 7.4f}' 302 | ' acc: {accu}, auc: {auc1}, aupr: {auc2}, ' 303 | 'elapse: {elapse:3.3f} s'.format( 304 | bce_loss=bce_loss, 305 | recon_loss=recon_loss, 306 | accu=train_accu, 307 | auc1=auc1, 308 | auc2=auc2, 309 | elapse=( 310 | time.time() - start))) 311 | 312 | start = time.time() 313 | valid_bce_loss, recon_loss, valid_accu, valid_auc1, valid_auc2 = eval_epoch(model, loss, validation_data, batch_size) 314 | print ( 315 | ' - (Validation-hyper) bce: {bce_loss: 7.4f}, recon: {recon_loss: 7.4f},' 316 | ' acc: {accu},' 317 | ' auc: {auc1}, aupr: {auc2},' 318 | 'elapse: {elapse:3.3f} s'.format( 319 | bce_loss=valid_bce_loss, 320 | recon_loss=recon_loss, 321 | accu=valid_accu, 322 | auc1=valid_auc1, 323 | auc2=valid_auc2, 324 | elapse=( 325 | time.time() - start))) 326 | 327 | valid_aupr_final = float(valid_auc2.split(" ")[-1]) 328 | valid_accus += [valid_aupr_final] 329 | 330 | checkpoint = { 331 | 'model_link': model.state_dict(), 332 | 'epoch': epoch_i} 333 | 334 | model_name = 'model.chkpt' 335 | 336 | if valid_aupr_final >= max(valid_accus): 337 | torch.save(checkpoint, os.path.join(args.save_path, model_name)) 338 | 339 | torch.cuda.empty_cache() 340 | 341 | # checkpoint = torch.load(os.path.join(args.save_path, model_name)) 342 | # model.load_state_dict(checkpoint['model_link']) 343 | # 344 | # 345 | # valid_bce_loss, recon_loss, valid_accu, valid_auc1, valid_auc2 = eval_epoch(model, loss, validation_data, 346 | # batch_size,final=True) 347 | # print( 348 | # ' - (Validation-hyper) bce: {bce_loss: 7.4f}, recon: {recon_loss: 7.4f},' 349 | # ' acc: {accu},' 350 | # ' auc: {auc1}, aupr: {auc2},' 351 | # 'elapse: {elapse:3.3f} s'.format( 352 | # bce_loss=valid_bce_loss, 353 | # recon_loss=recon_loss, 354 | # accu=valid_accu, 355 | # auc1=valid_auc1, 356 | # auc2=valid_auc2, 357 | # elapse=( 358 | # time.time() - start))) 359 | 360 | def neighbor_check(temp, dict): 361 | flag = False 362 | # return tuple(temp) in dict 363 | for i in range(len(temp)): 364 | for j in [-1, 0, 1]: 365 | a = np.copy(temp) 366 | a[i] += j 367 | a.sort() 368 | if tuple(a) in dict: 369 | flag = True 370 | break 371 | if flag: 372 | break 373 | return flag 374 | 375 | 376 | def generate_negative(x, dict1, weight=""): 377 | if len(weight) == 0: 378 | weight = torch.ones(len(x), dtype=torch.float) 379 | mode = "" 380 | if dict1 == 'train_dict': 381 | dict1 = train_dict 382 | mode = "train" 383 | elif dict1 == 'test_dict': 384 | dict1 = test_dict 385 | mode = "test" 386 | 387 | neg_list = [] 388 | new_x = [] 389 | new_index = [] 390 | neg_weight = [] 391 | max_id = int(num[-1]) 392 | size_list = [] 393 | size_neg_list = [] 394 | for j, sample in enumerate(x): 395 | for i in range(neg_num): 396 | # generate decomposed sample 397 | # if len(sample) > min_size: 398 | # decompose_sample = np.copy(sample) 399 | # decompose_size = int( 400 | # min(max_size - min_size + 1, len(sample) - min_size + 1) * random.random()) + min_size 401 | # if decompose_size == len(sample): 402 | # decompose_sample = np.copy(sample) 403 | # else: 404 | # decompose_sample = np.copy(sample) 405 | # np.random.shuffle(decompose_sample) 406 | # decompose_sample = decompose_sample[:decompose_size] 407 | # decompose_sample.sort() 408 | # 409 | # if tuple(decompose_sample) not in dict1[len(decompose_sample)]: 410 | # dict1[len(decompose_sample)].add(tuple(decompose_sample)) 411 | # if mode == 'train': 412 | # test_dict[len(decompose_sample)].add(tuple(decompose_sample)) 413 | # 414 | # else: 415 | # decompose_sample = np.copy(sample) 416 | # 417 | # if tuple(decompose_sample) not in dict1[len(decompose_sample)]: 418 | # dict1[len(decompose_sample)].add(tuple(decompose_sample)) 419 | # if mode == 'train': 420 | # test_dict[len(decompose_sample)].add(tuple(decompose_sample)) 421 | decompose_sample = np.copy(sample) 422 | change_num = np.random.binomial(decompose_sample.shape[-1], 0.3, 1) 423 | # change_num = 1 424 | while change_num == 0: 425 | change_num = np.random.binomial(decompose_sample.shape[-1], 0.3, 1) 426 | changes = np.random.choice(np.arange(decompose_sample.shape[-1]), change_num, replace=False) 427 | simple_or_hard = np.random.rand() 428 | temp = np.copy(decompose_sample) 429 | trial = 0 430 | flag = False 431 | while neighbor_check(temp, dict1[(len(temp))]): 432 | temp = np.copy(decompose_sample) 433 | trial += 1 434 | if trial >= 10000: 435 | temp = "" 436 | break 437 | 438 | for change in changes: 439 | # Only change one node 440 | if simple_or_hard <= pair_ratio: 441 | if isinstance(start_end_dict, dict): 442 | start, end = start_end_dict[int(temp[change])] 443 | temp[change] = np.random.randint( 444 | int(start), int(end), 1) + 1 445 | else: 446 | print ("error") 447 | 448 | temp = list(set(temp)) 449 | 450 | if len(temp) < len(decompose_sample): 451 | temp = np.copy(decompose_sample) 452 | continue 453 | 454 | temp.sort() 455 | 456 | if len(temp) > 0: 457 | if i == 0: 458 | new_x.append(decompose_sample) 459 | new_index.append(j) 460 | size_list.append(len(decompose_sample)) 461 | 462 | neg_list.append(temp) 463 | size_neg_list.append(len(temp)) 464 | neg_weight.append(weight[j]) 465 | 466 | new_weight = weight[np.array(new_index)] 467 | new_weight = torch.tensor(new_weight) # .to(device) 468 | neg_weight = torch.tensor(neg_weight) 469 | size_list = torch.Tensor(np.concatenate( 470 | [np.array(size_list), np.array(size_neg_list)], axis=0)) 471 | x = np2tensor_hyper(new_x, dtype=torch.long) 472 | neg = np2tensor_hyper(neg_list, dtype=torch.long) 473 | if type(x) == list: 474 | a = x + neg 475 | else: 476 | a = torch.cat([x,neg],dim = 0) 477 | a = pad_sequence(a, batch_first=True, padding_value=0).to(device) 478 | 479 | return a, \ 480 | torch.cat([torch.ones((len(x), 1), device=device), (torch.zeros((len(neg), 1), device=device))]), \ 481 | torch.cat([new_weight, neg_weight], dim=0), \ 482 | size_list 483 | 484 | 485 | def predict(model, input): 486 | model.eval() 487 | output = [] 488 | with torch.no_grad(): 489 | for j in trange(math.ceil(len(input) / batch_size)): 490 | x = input[j * batch_size:min((j + 1) * batch_size, len(input))] 491 | x = np2tensor_hyper(x, dtype=torch.long) 492 | x = pad_sequence(x, batch_first=True, padding_value=0).to(device) 493 | output.append(model(x).detach().cpu().numpy()) 494 | output = np.concatenate(output, axis=0) 495 | torch.cuda.empty_cache() 496 | return output 497 | 498 | 499 | def save_embeddings(model, origin=False): 500 | model.eval() 501 | with torch.no_grad(): 502 | ids = np.arange(num_list[-1]) + 1 503 | ids = torch.Tensor(ids).long().to(device).view(-1, 1) 504 | embeddings = [] 505 | for j in range(math.ceil(len(ids) / batch_size)): 506 | x = ids[j * batch_size:min((j + 1) * batch_size, len(ids))] 507 | if origin: 508 | embed = model.get_node_embeddings(x) 509 | else: 510 | embed = model.get_embedding_static(x) 511 | embed = embed.detach().cpu().numpy() 512 | embeddings.append(embed) 513 | 514 | embeddings = np.concatenate(embeddings, axis=0)[:, 0, :] 515 | for i in range(len(num_list)): 516 | start = 0 if i == 0 else num_list[i - 1] 517 | static = embeddings[int(start):int(num_list[i])] 518 | np.save("../mymodel_%d.npy" % (i), static) 519 | 520 | if origin: 521 | if i == 0: 522 | old_static = np.load("../mymodel_%d_origin.npy" % (i)) 523 | try: 524 | update_rate = np.sum((old_static - static) ** 2, axis=-1) / np.sum(old_static ** 2, axis=-1) 525 | print("update_rate: %f\t%f" % (np.min(update_rate), np.max(update_rate))) 526 | except: 527 | pass 528 | np.save("../mymodel_%d_origin.npy" % (i), static) 529 | 530 | torch.cuda.empty_cache() 531 | return embeddings 532 | 533 | 534 | # New different from 535 | 536 | 537 | 538 | def generate_embeddings(edge, nums_type, H=None, weight=1): 539 | if len(num) == 1: 540 | return [get_adjacency(edge, weight, True)] 541 | 542 | 543 | def get_adjacency(data, weight, norm=True): 544 | A = np.zeros((num_list[-1], num_list[-1])) 545 | 546 | for index, datum in enumerate(tqdm(data)): 547 | for i in range(datum.shape[-1]): 548 | for j in range(datum.shape[-1]): 549 | if i != j: 550 | A[datum[i], datum[j]] += weight[index] 551 | 552 | if norm: 553 | temp = np.concatenate((np.zeros((1), dtype='int'), num), axis=0) 554 | temp = np.cumsum(temp) 555 | 556 | for i in range(len(temp) - 1): 557 | A[temp[i]:temp[i + 1], 558 | :] /= (np.max(A[temp[i]:temp[i + 1], 559 | :], 560 | axis=0, 561 | keepdims=True) + 1e-10) 562 | 563 | return csr_matrix(A).astype('float32') 564 | 565 | 566 | args = parse_args() 567 | neg_num = 5 568 | batch_size = 96 569 | neg_num_w2v = 5 570 | bottle_neck = args.dimensions 571 | pair_ratio = 1.0 572 | dynamic_dict = False 573 | max_size = 5 574 | min_size = 3 575 | train_type = 'hyper' 576 | loss = F.binary_cross_entropy 577 | 578 | 579 | train_zip = np.load("../data/%s/train_data.npz" % (args.data), allow_pickle=True) 580 | test_zip = np.load("../data/%s/test_data.npz" % (args.data), allow_pickle=True) 581 | train_data, test_data = train_zip['train_data'], test_zip['test_data'] 582 | 583 | try: 584 | train_weight, test_weight = train_zip["train_weight"].astype('float32'), test_zip["test_weight"].astype('float32') 585 | except BaseException: 586 | print("no specific train weight") 587 | test_weight = np.ones(len(test_data), dtype='float32') 588 | train_weight = np.ones(len(train_data), dtype='float32') * neg_num 589 | 590 | num = train_zip['nums_type'] 591 | num_list = np.cumsum(num) 592 | print("Node type num", num) 593 | 594 | try: 595 | start_end_dict = train_zip["start_end_dict"].item() 596 | start_end_dict = dict(start_end_dict) 597 | new_dict = {} 598 | for k in tqdm(start_end_dict): 599 | v1, v2 = start_end_dict[k] 600 | new_dict[k + 1] = int(v1), int(v2) 601 | 602 | start_end_dict = new_dict 603 | except Exception as e: 604 | print("no specific start_end_dict", e) 605 | start_end_dict = "nothing" 606 | 607 | try: 608 | attribute_dict = train_zip["attribute_dict"].reshape((-1, 1)) 609 | 610 | except Exception as e: 611 | print("no specific attribute_dict", e) 612 | attribute_dict = None 613 | 614 | print("before weight filter", train_data.shape, test_data.shape) 615 | 616 | # l = np.load("../label.npy") 617 | # for i in np.unique(train_data[:,0]): 618 | # weight_sum = np.sum(train_weight[train_data[:,0] == i]) 619 | # print (i, l[i], weight_sum) 620 | 621 | if args.feature == 'adj': 622 | temp_filter_num = 0 if args.data in ['drop', 'drop_only', 'drop_non'] else 0 623 | embeddings_initial = generate_embeddings(train_data[(train_weight >= temp_filter_num)], num, H=None, 624 | weight=train_weight[(train_weight >= temp_filter_num)]) 625 | 626 | if attribute_dict is not None: 627 | attribute_dict = np.concatenate([attribute_dict % 1e7, np.floor(attribute_dict / 1e7)]) 628 | print(attribute_dict) 629 | attribute_dict /= np.max(attribute_dict) 630 | attribute_dict = np.concatenate([np.zeros((num[0] + 1, 1)), attribute_dict], axis=0).astype('float32') 631 | print(attribute_dict, attribute_dict.shape) 632 | 633 | num = torch.as_tensor(num) 634 | num_list = torch.as_tensor(num_list) 635 | 636 | print("walk type", args.walk) 637 | 638 | filter_num = 0 if args.data in ['drop', 'drop_only', 'drop_non'] else 0 639 | 640 | train_mask = (train_weight >= filter_num) 641 | train_data = train_data[train_mask] 642 | train_weight = train_weight[train_mask] 643 | test_mask = (test_weight >= filter_num) 644 | test_data = test_data[test_mask] 645 | test_weight = test_weight[test_mask] 646 | 647 | dict_data = np.concatenate([train_data, test_data]) 648 | dict_weight = np.concatenate([train_weight, test_weight]) 649 | 650 | 651 | # At this stage, the index still starts from zero 652 | 653 | node_list = np.arange(num_list[-1]).astype('int') 654 | if args.walk == 'hyper': 655 | walk_path = random_walk_hyper(args, node_list, train_data) 656 | else: 657 | walk_path = random_walk(args, num, train_data) 658 | del node_list 659 | 660 | # Add 1 for the padding index 661 | print("adding pad idx") 662 | dict_data = add_padding_idx(dict_data) 663 | train_data = add_padding_idx(train_data) 664 | test_data = add_padding_idx(test_data) 665 | 666 | # Note that, no matter how many node types are here, make sure the 667 | # hyperedge (N1,N2,N3,...) has id, N1 < N2 < N3... 668 | 669 | compress = True 670 | # Note that, no matter how many node types are here, make sure the 671 | # hyperedge (N1,N2,N3,...) has id, N1 < N2 < N3... 672 | if not dynamic_dict: 673 | test_dict = build_hash(dict_data, compress=compress, max_size=max_size, 674 | min_size=min_size, fname="test") 675 | train_dict = test_dict 676 | # train_dict = build_hash(train_data, compress = compress, max_size=max_size, min_size = min_size, fname="test") 677 | else: 678 | train_dict = [BloomFilter(1e8, 1e-3) for i in range(max_size + 1)] 679 | test_dict = [BloomFilter(1e8, 1e-3) for i in range(max_size + 1)] 680 | print("dict_size", len(train_dict), len(test_dict)) 681 | 682 | 683 | 684 | print("after weight filter", train_data.shape, test_data.shape, dict_data.shape) 685 | print(train_weight, np.min(train_weight), np.max(train_weight)) 686 | train_weight_mean = np.mean(train_weight) 687 | train_weight = train_weight / train_weight_mean * neg_num 688 | test_weight = test_weight / train_weight_mean * neg_num 689 | dict_weight = dict_weight / train_weight_mean * neg_num 690 | print("train data amount", len(train_data)) 691 | 692 | 693 | if args.feature == 'walk': 694 | # Note that for this part, the word2vec still takes sentences with 695 | # words starts at "0" 696 | if not args.TRY and os.path.exists( 697 | "../%s_wv_%d_%s.npy" % 698 | (args.data, args.dimensions, args.walk)): 699 | A = np.load( 700 | "../%s_wv_%d_%s.npy" % 701 | (args.data, 702 | args.dimensions, 703 | args.walk), 704 | allow_pickle=True) 705 | else: 706 | print ("start loading") 707 | walks = np.loadtxt(walk_path, delimiter=" ").astype('int') 708 | start = time.time() 709 | split_num = 20 710 | pool = ProcessPoolExecutor(max_workers=split_num) 711 | process_list = [] 712 | walks = np.array_split(walks, split_num) 713 | 714 | result = [] 715 | print ("Start turning path to strs") 716 | for walk in walks: 717 | process_list.append(pool.submit(walkpath2str, walk)) 718 | 719 | for p in as_completed(process_list): 720 | result += p.result() 721 | 722 | pool.shutdown(wait=True) 723 | 724 | walks = result 725 | print ( 726 | "Finishing Loading and processing %.2f s" % 727 | (time.time() - start)) 728 | print ("Start Word2vec") 729 | import multiprocessing 730 | 731 | print ("num cpu cores", multiprocessing.cpu_count()) 732 | w2v = Word2Vec( 733 | walks, 734 | size=args.dimensions, 735 | window=args.window_size, 736 | min_count=0, 737 | sg=1, 738 | iter=1, 739 | workers=multiprocessing.cpu_count()) 740 | wv = w2v.wv 741 | A = [wv[str(i)] for i in range(num_list[-1])] 742 | np.save("../%s_wv_%d_%s.npy" % 743 | (args.data, args.dimensions, args.walk), A) 744 | 745 | from sklearn.preprocessing import StandardScaler 746 | 747 | A = StandardScaler().fit_transform(A) 748 | 749 | A = np.concatenate( 750 | (np.zeros((1, A.shape[-1]), dtype='float32'), A), axis=0) 751 | A = A.astype('float32') 752 | A = torch.tensor(A).to(device) 753 | print (A.shape) 754 | 755 | node_embedding = Wrap_Embedding(int( 756 | num_list[-1] + 1), args.dimensions, scale_grad_by_freq=False, padding_idx=0, sparse=False) 757 | node_embedding.weight = nn.Parameter(A) 758 | 759 | elif args.feature == 'adj': 760 | flag = False 761 | node_embedding = MultipleEmbedding( 762 | embeddings_initial, 763 | bottle_neck, 764 | flag, 765 | num_list).to(device) 766 | 767 | classifier_model = Classifier( 768 | n_head=8, 769 | d_model=args.dimensions, 770 | d_k=16, 771 | d_v=16, 772 | node_embedding=node_embedding, 773 | diag_mask=args.diag, 774 | bottle_neck=bottle_neck).to(device) 775 | 776 | save_embeddings(classifier_model, True) 777 | 778 | 779 | 780 | summary(classifier_model, (3,)) 781 | 782 | params_list = list(classifier_model.parameters()) 783 | 784 | if args.feature == 'adj': 785 | # optimizer = torch.optim.RMSprop(params_list, lr=1e-3) 786 | optimizer = torch.optim.AdamW(params_list, lr=1e-3, amsgrad=False) 787 | else: 788 | optimizer = torch.optim.RMSprop(params_list, lr=1e-3) 789 | 790 | model_parameters = filter(lambda p: p.requires_grad, params_list) 791 | params = sum([np.prod(p.size()) for p in model_parameters]) 792 | print ("params to be trained", params) 793 | 794 | 795 | train(classifier_model, 796 | loss=loss, 797 | training_data=(train_data, train_weight), 798 | validation_data=(test_data, test_weight), 799 | optimizer=[optimizer], epochs=30, batch_size=batch_size) 800 | model_name = 'model.chkpt' 801 | checkpoint = torch.load(os.path.join(args.save_path, model_name)) 802 | classifier_model.load_state_dict(checkpoint['model_link']) -------------------------------------------------------------------------------- /History_version/Code/process_SPRITE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import math 4 | from tqdm import tqdm, trange 5 | res = 1000000 6 | chrom_list = ["chr1", "chr2", "chr3", "chr4", "chr5", "chr6", "chr7", "chr8", "chr9", \ 7 | "chr10", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", "chr17", \ 8 | "chr18", "chr19", "chr20","chr21","chr22"] 9 | def filter(): 10 | file1 = open("../../SPRITE/4DNFIBEVVTN5.clusters","r") 11 | result = open("../../SPRITE/filtered.txt","w") 12 | 13 | 14 | line = file1.readline() 15 | 16 | while line: 17 | info = line.strip().split("\t") 18 | if len(info) <= 2: 19 | line = file1.readline() 20 | continue 21 | else: 22 | result.write("\t".join(info[1:])+"\n") 23 | 24 | line = file1.readline() 25 | 26 | result.close() 27 | 28 | 29 | def build_node_dict(): 30 | tab = pd.read_table("../../SPRITE/hg38.chrom.sizes.txt",header = None,sep = "\t") 31 | tab.columns = ['chr','size'] 32 | print (tab) 33 | 34 | bin2node = {} 35 | node2bin = {} 36 | node2chrom = {} 37 | chrom_range = [] 38 | count = 1 39 | 40 | for j, chrom in enumerate(chrom_list): 41 | size = np.max(tab['size'][tab['chr'] == chrom]) 42 | max_bin_chrom = math.ceil(size / res) 43 | 44 | temp = [count] 45 | for i in range(max_bin_chrom + 1): 46 | bin_ = "%s:%d" %(chrom, i * res) 47 | bin2node[bin_] = count 48 | node2bin[count] = bin_ 49 | node2chrom[count] = j 50 | count += 1 51 | temp.append(count) 52 | chrom_range.append(temp) 53 | print (chrom_range) 54 | np.save("../data/SPRITE/chrom_range.npy",chrom_range) 55 | np.save("../data/SPRITE/bin2node.npy", bin2node) 56 | np.save("../data/SPRITE/node2chrom.npy",node2chrom) 57 | np.save("../data/SPRITE/node2bin.npy",node2bin) 58 | 59 | def parse_file(): 60 | result = open("../../SPRITE/filtered.txt", "r") 61 | line = result.readline() 62 | bin2node = np.load("../data/SPRITE/bin2node.npy",allow_pickle=True).item() 63 | node2bin = np.load("../data/SPRITE/node2bin.npy",allow_pickle=True).item() 64 | node2chrom = np.load("../data/SPRITE/node2chrom.npy",allow_pickle=True).item() 65 | final = [] 66 | while line: 67 | info_list = line.strip().split("\t") 68 | temp = [] 69 | if len(info_list) > 1000: 70 | line = result.readline() 71 | continue 72 | for info in info_list: 73 | chrom, bin_ = info.split(":") 74 | if chrom not in chrom_list: 75 | continue 76 | bin_ = int(math.floor(int(bin_) / res)) * res 77 | bin_ = "%s:%d" %(chrom,bin_) 78 | node = bin2node[bin_] 79 | temp.append(node) 80 | temp = list(set(temp)) 81 | temp.sort() 82 | 83 | if len(temp) > 1: 84 | final.append(temp) 85 | 86 | 87 | line = result.readline() 88 | 89 | np.save("../data/SPRITE/edge_list.npy", final) 90 | 91 | final = np.load("../data/SPRITE/edge_list.npy", allow_pickle=True) 92 | chrom_range = np.load("../data/SPRITE/chrom_range.npy", allow_pickle=True) 93 | node_freq = np.zeros((np.max(chrom_range))) 94 | for e in tqdm(final): 95 | if len(e) > 25: 96 | continue 97 | 98 | for n in e: 99 | node_freq[n] += 1 100 | print (node_freq) 101 | 102 | drop_list = np.where(node_freq <= 50)[0] 103 | print (drop_list, len(drop_list)) 104 | 105 | node2newnode = {} 106 | dropnode2newnode = {} 107 | newnode2chrom = {} 108 | 109 | count = 1 110 | for n in range(np.max(chrom_range)): 111 | if n == 0: 112 | continue 113 | elif n in drop_list: 114 | dropnode2newnode[n] = count 115 | else: 116 | node2newnode[n] = count 117 | count += 1 118 | dropnode2newnode[n+1] = count 119 | print ("remap") 120 | 121 | new_node2bin = {} 122 | new_bin2node = {} 123 | 124 | for node in node2bin: 125 | if node in node2newnode: 126 | new_node2bin[node2newnode[node]] = node2bin[node] 127 | new_bin2node[node2bin[node]] = node2newnode[node] 128 | newnode2chrom[node2newnode[node]] = node2chrom[node] 129 | 130 | np.save("../data/SPRITE/bin2node.npy", new_bin2node) 131 | np.save("../data/SPRITE/node2bin.npy", new_node2bin) 132 | np.save("../data/SPRITE/node2chrom.npy", newnode2chrom) 133 | 134 | new_final = [] 135 | for e in tqdm(final): 136 | temp = [] 137 | for n in e: 138 | if n in node2newnode: 139 | temp.append(node2newnode[n]) 140 | if len(temp) >= 2: 141 | new_final.append(temp) 142 | final = new_final 143 | new_chrom_range = [] 144 | for v in chrom_range: 145 | temp = [] 146 | if v[0] in node2newnode: 147 | temp.append(node2newnode[v[0]]) 148 | else: 149 | temp.append(dropnode2newnode[v[0]]) 150 | 151 | if v[1] in node2newnode: 152 | temp.append(node2newnode[v[1]]) 153 | else: 154 | temp.append(dropnode2newnode[v[1]]) 155 | 156 | new_chrom_range.append(temp) 157 | print (chrom_range,new_chrom_range) 158 | 159 | # print (final) 160 | np.save("../data/SPRITE/edge_list.npy",final) 161 | np.save("../data/SPRITE/chrom_range.npy", new_chrom_range) 162 | 163 | 164 | def parse_cool_contact(): 165 | file = pd.read_table("../data/SPRITE/SPRITE_contact.txt",sep = "\t") 166 | bin2node = np.load("../data/SPRITE/bin2node.npy", allow_pickle=True).item() 167 | chrom_range = np.load("../data/SPRITE/chrom_range.npy") 168 | 169 | node_num = int(np.max(chrom_range)) 170 | print (node_num) 171 | 172 | intra_adj = np.zeros((node_num - 1,node_num - 1)) 173 | inter_adj = np.zeros((node_num - 1,node_num - 1)) 174 | for i in trange(len(file)): 175 | chrom1 = file['chrom1'][i] 176 | start1 = file['start1'][i] 177 | chrom2 = file['chrom2'][i] 178 | start2 = file['start2'][i] 179 | 180 | if chrom1 not in chrom_list or chrom2 not in chrom_list: 181 | continue 182 | 183 | w = file['balanced'][i] 184 | 185 | if not np.isnan(w): 186 | bin1 = "%s:%d" %(chrom1,start1) 187 | bin2 = "%s:%d" %(chrom2,start2) 188 | if bin1 in bin2node and bin2 in bin2node: 189 | node1 = bin2node[bin1] - 1 190 | node2 = bin2node[bin2] - 1 191 | if chrom1 == chrom2: 192 | intra_adj[node1, node2] += w 193 | intra_adj[node2,node1] += w 194 | else: 195 | inter_adj[node1, node2] += w 196 | inter_adj[node2, node1] += w 197 | else: 198 | print (bin1,bin2) 199 | 200 | print(intra_adj, inter_adj) 201 | np.save("../data/SPRITE/intra_adj_SPRITE.npy", intra_adj) 202 | np.save("../data/SPRITE/inter_adj_SPRITE.npy", inter_adj) 203 | 204 | 205 | build_node_dict() 206 | parse_file() 207 | parse_cool_contact() -------------------------------------------------------------------------------- /History_version/Code/random_walk.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import networkx as nx 5 | import random 6 | from tqdm import tqdm 7 | import torch 8 | from concurrent.futures import as_completed, ProcessPoolExecutor 9 | 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | device_ids = [0, 1] 12 | 13 | 14 | class Graph(): 15 | def __init__(self, nx_G, p, q, is_directed=False): 16 | self.G = nx_G 17 | self.is_directed = is_directed 18 | self.p = p 19 | self.q = q 20 | self.neighbors = [] 21 | print("initialization") 22 | for i in range(len(nx_G.nodes()) 23 | ): # actualy nx_G.nodes() is already increasing order 24 | self.neighbors.append(sorted(nx_G.neighbors(i))) 25 | self.degree = np.zeros((len(nx_G.nodes()))) 26 | for i in range(len(nx_G.nodes())): 27 | self.degree[i] = np.sum([nx_G[i][nbr]['weight'] 28 | for nbr in self.neighbors[i]]) 29 | print(self.degree) 30 | 31 | 32 | def get_alias_edge(src, dst): 33 | ''' 34 | Get the alias edge setup lists for a given edge. 35 | ''' 36 | global sG 37 | G = sG.G 38 | p = sG.p 39 | q = sG.q 40 | 41 | unnormalized_probs = [] 42 | for dst_nbr in sG.neighbors[dst]: 43 | if dst_nbr == src: 44 | unnormalized_probs.append( 45 | (G[dst][dst_nbr]['weight'] / p) / np.sqrt(sG.degree[dst_nbr])) 46 | # unnormalized_probs.append((G[dst][dst_nbr]['weight'] / p)) 47 | elif G.has_edge(dst_nbr, src): 48 | unnormalized_probs.append( 49 | (G[dst][dst_nbr]['weight']) / 50 | np.sqrt( 51 | sG.degree[dst_nbr])) 52 | # unnormalized_probs.append((G[dst][dst_nbr]['weight'])) 53 | else: 54 | unnormalized_probs.append( 55 | (G[dst][dst_nbr]['weight'] / q) / np.sqrt(sG.degree[dst_nbr])) 56 | # unnormalized_probs.append((G[dst][dst_nbr]['weight'] / q)) 57 | norm_const = sum(unnormalized_probs) 58 | normalized_probs = [ 59 | float(u_prob) / 60 | norm_const for u_prob in unnormalized_probs] 61 | 62 | return alias_setup(normalized_probs) 63 | 64 | 65 | def alias_some_edges(edges): 66 | alias_edges = {} 67 | for edge in tqdm(edges): 68 | alias_edges[(edge[0], edge[1])] = get_alias_edge(edge[0], edge[1]) 69 | alias_edges[(edge[1], edge[0])] = get_alias_edge(edge[1], edge[0]) 70 | return alias_edges 71 | 72 | 73 | def preprocess_transition_probs(sg): 74 | ''' 75 | Preprocessing of transition probabilities for guiding the random walks. 76 | ''' 77 | global sG 78 | sG = sg 79 | G = sG.G 80 | is_directed = sG.is_directed 81 | 82 | print("transition probs: ") 83 | alias_nodes = {} 84 | for node in tqdm(G.nodes()): 85 | unnormalized_probs = [ 86 | G[node][nbr]['weight'] / 87 | np.sqrt( 88 | sG.degree[nbr]) for nbr in sG.neighbors[node]] 89 | # unnormalized_probs = [G[node][nbr]['weight'] for nbr in sG.neighbors[node]] 90 | norm_const = sum(unnormalized_probs) 91 | normalized_probs = [float(u_prob) / 92 | norm_const for u_prob in unnormalized_probs] 93 | alias_nodes[node] = alias_setup(normalized_probs) 94 | 95 | triads = {} 96 | 97 | # Parallel alias edges 98 | print("alias edges: ") 99 | edges = G.edges() 100 | 101 | threads_num = 100 102 | pool = ProcessPoolExecutor(max_workers=threads_num) 103 | process_list = [] 104 | 105 | edges = np.array_split(edges, threads_num * 2) 106 | for e in edges: 107 | process_list.append(pool.submit(alias_some_edges, e)) 108 | 109 | alias_edges = {} 110 | for p in as_completed(process_list): 111 | alias_t = p.result() 112 | alias_edges.update(alias_t) 113 | pool.shutdown(wait=True) 114 | 115 | sG.alias_nodes = alias_nodes 116 | sG.alias_edges = alias_edges 117 | 118 | 119 | def alias_setup(probs): 120 | ''' 121 | Compute utility lists for non-uniform sampling from discrete distributions. 122 | Refer to https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 123 | for details 124 | ''' 125 | K = len(probs) 126 | q = np.zeros(K) 127 | J = np.zeros(K, dtype=np.int) 128 | 129 | smaller = [] 130 | larger = [] 131 | for kk, prob in enumerate(probs): 132 | q[kk] = K * prob 133 | if q[kk] < 1.0: 134 | smaller.append(kk) 135 | else: 136 | larger.append(kk) 137 | 138 | while len(smaller) > 0 and len(larger) > 0: 139 | small = smaller.pop() 140 | large = larger.pop() 141 | 142 | J[small] = large 143 | q[large] = q[large] + q[small] - 1.0 144 | if q[large] < 1.0: 145 | smaller.append(large) 146 | else: 147 | larger.append(large) 148 | 149 | return J, q 150 | 151 | 152 | def alias_draw(J, q): 153 | ''' 154 | Draw sample from a non-uniform discrete distribution using alias sampling. 155 | ''' 156 | K = len(J) 157 | 158 | kk = int(np.floor(np.random.rand() * K)) 159 | if np.random.rand() < q[kk]: 160 | return kk 161 | else: 162 | return J[kk] 163 | 164 | 165 | def add_weight(G, u, v): 166 | if 'weight' not in G[u][v]: 167 | G[u][v]['weight'] = 1 168 | else: 169 | G[u][v]['weight'] += 1 170 | 171 | 172 | def node2vec_walk(sG, walk_length, start_node): 173 | ''' 174 | Simulate a random walk starting from start node. 175 | ''' 176 | alias_nodes = sG.alias_nodes 177 | alias_edges = sG.alias_edges 178 | 179 | walk = [start_node] 180 | 181 | while len(walk) < walk_length: 182 | cur = walk[-1] 183 | cur_nbrs = sG.neighbors[cur] 184 | if len(cur_nbrs) > 0: 185 | if len(walk) == 1: 186 | walk.append(cur_nbrs[alias_draw( 187 | alias_nodes[cur][0], alias_nodes[cur][1])]) 188 | else: 189 | prev = walk[-2] 190 | next_n = cur_nbrs[alias_draw(alias_edges[(prev, cur)][0], 191 | alias_edges[(prev, cur)][1])] 192 | walk.append(next_n) 193 | else: 194 | walk.append(cur) 195 | continue 196 | 197 | return walk 198 | 199 | 200 | def simulate_walks(sG, num_walks, walk_length): 201 | ''' 202 | Repeatedly simulate random walks from each node. 203 | ''' 204 | print("sample walks:") 205 | walks = [] 206 | nodes = sG.G.nodes() 207 | for node in tqdm(nodes): 208 | for walk_iter in range(num_walks): 209 | temp = node2vec_walk(sG, walk_length, node) 210 | if len(temp) == walk_length: 211 | walks.append(temp) 212 | 213 | random.shuffle(walks) 214 | return walks 215 | 216 | 217 | def read_graph(num, hyperedge_list): 218 | ''' 219 | Transfer the hyperedge to pairwise edge & Reads the input network in networkx. 220 | ''' 221 | G = nx.Graph() 222 | tot = sum(num) 223 | G.add_nodes_from(range(tot)) 224 | for ee in tqdm(hyperedge_list): 225 | e = ee 226 | edges_to_add = [] 227 | for i in range(len(e)): 228 | for j in range(i + 1, len(e)): 229 | edges_to_add.append((e[i], e[j])) 230 | G.add_edges_from(edges_to_add) 231 | for i in range(len(e)): 232 | for j in range(i + 1, len(e)): 233 | add_weight(G, e[i], e[j]) 234 | 235 | G = G.to_undirected() 236 | 237 | return G 238 | 239 | 240 | def toint(hyperedge_list): 241 | return np.array([h.astype('int') for h in hyperedge_list]) 242 | 243 | 244 | def random_walk(args, num, hyperedge_list): 245 | ''' 246 | Learn embeddings by optimizing the Skipgram objective using SGD. 247 | ''' 248 | # p, q = 1, 1 249 | # num_walks, walk_length, window_size = 10, 80, 10 250 | hyperedge_list = toint(hyperedge_list) 251 | p, q = args.p, args.q 252 | num_walks, walk_length, window_size = args.num_walks, args.walk_length, args.window_size 253 | # emb_save_path = '../embs/{}/p{}_q{}_r{}_l{}_k{}_i{}.embs'.format(args.data, p, q, num_walks, walk_length, window_size, iteration) 254 | if not os.path.exists("../walks/{}/".format(args.data)): 255 | os.mkdir("../walks/{}/".format(args.data)) 256 | walks_save_path = '../walks/{}/p{}_q{}_r{}_l{}_walks.txt'.format( 257 | args.data, p, q, num_walks, walk_length) 258 | start = time.time() 259 | 260 | if not args.TRY and os.path.exists(walks_save_path): 261 | return walks_save_path 262 | else: 263 | nx_G = read_graph(num.numpy(), hyperedge_list) 264 | G = Graph(nx_G, p, q) 265 | preprocess_transition_probs(G) 266 | walks = simulate_walks(G, num_walks, walk_length) 267 | walks = np.array(walks) 268 | 269 | print(walks.shape) 270 | np.savetxt(walks_save_path, walks, fmt="%d", delimiter=" ") 271 | #np.save(walks_save_path, walks) 272 | 273 | print("RandomWalk running time: %.2lf" % (time.time() - start)) 274 | 275 | return walks_save_path 276 | -------------------------------------------------------------------------------- /History_version/Code/random_walk_hyper.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import as_completed, ProcessPoolExecutor 2 | from scipy.sparse import csr_matrix, lil_matrix, csc_matrix 3 | from tqdm import tqdm, trange 4 | import time 5 | import numpy as np 6 | import os 7 | 8 | # os.environ["OMP_DISPLAY_ENV"] = "FALSE" 9 | # os.environ["OMP_NUM_THREADS"] = "20" 10 | os.environ["KMP_AFFINITY"] = 'none' 11 | # os.environ["KMP_AFFINITY"]="scatter" 12 | 13 | 14 | # FIXME: may be there is more efficient method 15 | 16 | weight_1st = 1.0 17 | weight_degree = -0.5 18 | 19 | print(weight_1st, weight_degree) 20 | 21 | 22 | def make_sparse_matrix(raw_data, m, n): 23 | indptr = [len(row) for row in raw_data] 24 | indptr = np.cumsum([0] + indptr) 25 | indices = [i for row in raw_data for i in row] 26 | data = [1] * len(indices) 27 | return csr_matrix((data, indices, indptr), shape=(m, n), dtype='float32') 28 | 29 | 30 | def alias_setup(probs): 31 | ''' 32 | Compute utility lists for non-uniform sampling from discrete distributions. 33 | Refer to https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 34 | for details 35 | ''' 36 | K = len(probs) 37 | q = np.zeros(K) 38 | J = np.zeros(K, dtype=np.int) 39 | 40 | smaller = [] 41 | larger = [] 42 | for kk, prob in enumerate(probs): 43 | q[kk] = K * prob 44 | if q[kk] < 1.0: 45 | smaller.append(kk) 46 | else: 47 | larger.append(kk) 48 | 49 | while len(smaller) > 0 and len(larger) > 0: 50 | small = smaller.pop() 51 | large = larger.pop() 52 | 53 | J[small] = large 54 | q[large] = q[large] + q[small] - 1.0 55 | if q[large] < 1.0: 56 | smaller.append(large) 57 | else: 58 | larger.append(large) 59 | 60 | return (J, q) 61 | 62 | 63 | def alias_draw(P): 64 | ''' 65 | Draw sample from a non-uniform discrete distribution using alias sampling. 66 | ''' 67 | J, q = P 68 | K = len(J) 69 | 70 | kk = int(np.floor(np.random.rand() * K)) 71 | if np.random.rand() < q[kk]: 72 | return kk 73 | else: 74 | return J[kk] 75 | 76 | 77 | class HyperGraphRandomWalk(): 78 | def __init__(self, p, q, is_weighted=False): 79 | self.p = p 80 | self.q = q 81 | # FIXME: current version is only for unweighted graph 82 | self.is_weighted = is_weighted 83 | 84 | def build_graph(self, node_list, edge_list): 85 | # is considered to be range(num_node) FIXME: maybe a dict for nodes 86 | # will be better 87 | self.nodes = node_list 88 | self.edges = edge_list # the neighbors of hyperedges (without weight) 89 | 90 | # the neighbors of nodes (with weight) 91 | n_edge = [[] for _ in range(int(np.max(node_list) + 1))] 92 | 93 | self.node_degree = np.zeros((int(np.max(node_list) + 1))) 94 | self.edge_degree = np.array([len(e) for e in self.edges]) 95 | for i, e in enumerate(tqdm(edge_list)): 96 | if isinstance(e, tuple): 97 | e = list(e) 98 | e.sort() 99 | ww = 1 # FIXME: unweighted case 100 | for v in e: 101 | n_edge[v].append((i, ww)) 102 | 103 | self.node_degree[v] += 1 104 | 105 | for v in tqdm(node_list): 106 | n_edge_i = sorted(n_edge[v]) 107 | n_edge[v] = np.array(n_edge_i) 108 | 109 | self.n_edge = n_edge 110 | # adjacent matrices of V x E, E x V, E x E 111 | print('adj matrix:') 112 | self.EV = make_sparse_matrix( 113 | self.edges, len( 114 | self.edges), int( 115 | np.max(node_list) + 1)) 116 | self.delta = lil_matrix((self.EV.shape[0], self.EV.shape[0])) 117 | size = np.array([1 / np.sqrt(len(e)) for e in self.edges]) 118 | self.delta.setdiag(size) 119 | 120 | self.EV_over_delta = self.delta * self.EV 121 | 122 | self.VE = self.EV.T 123 | self.VE_over_delta = self.EV_over_delta.T 124 | 125 | print("EV size", self.EV.shape) 126 | 127 | 128 | def get_first_order_part(nodes): 129 | alias_n2n_1st = {} 130 | node2ff_1st = {} 131 | 132 | for src in tqdm(nodes): 133 | dsts = node_nbr[src] 134 | ff_1st = np.array( 135 | (VE_over_delta[src, :] * EV_over_delta[:, dsts]).todense()).reshape((-1)) 136 | node2ff_1st[src] = ff_1st 137 | unnormalized_probs = ff_1st / np.sqrt(node_degree[dsts]) 138 | normalized_probs = unnormalized_probs / np.sum(unnormalized_probs) 139 | alias_n2n_1st[src] = alias_setup(normalized_probs) 140 | 141 | return alias_n2n_1st, node2ff_1st 142 | 143 | 144 | def get_first_order(G): 145 | print("1st order: ") 146 | global EV, VE, EV_over_delta, VE_over_delta, node_nbr, node_degree 147 | 148 | EV = G.EV 149 | VE = G.VE 150 | EV_over_delta = G.EV_over_delta 151 | VE_over_delta = G.VE_over_delta 152 | node_nbr = G.node_nbr 153 | node_degree = G.node_degree 154 | 155 | processes_num = 80 156 | pool = ProcessPoolExecutor(max_workers=processes_num) 157 | process_list = [] 158 | 159 | nodes = np.copy(G.nodes) 160 | 161 | split_num = min(processes_num, int(len(nodes) / 100)) + 1 162 | print("split_num", split_num) 163 | np.random.shuffle(nodes) 164 | nodes = np.array_split(nodes, split_num) 165 | 166 | print("Start get first order") 167 | for node in nodes: 168 | process_list.append(pool.submit(get_first_order_part, node)) 169 | 170 | alias_n2n_1st = {} 171 | node2ff_1st = {} 172 | for p in as_completed(process_list): 173 | alias_t1, alias_t2 = p.result() 174 | alias_n2n_1st.update(alias_t1) 175 | node2ff_1st.update(alias_t2) 176 | 177 | pool.shutdown(wait=True) 178 | 179 | print("start turn dict to list") 180 | 181 | nodes = np.copy(G.nodes) 182 | 183 | alias_n2n_1st_list = [[] for n in nodes] 184 | node2ff_1st_list = [[] for n in nodes] 185 | 186 | for n in nodes: 187 | alias_n2n_1st_list[n] = alias_n2n_1st[n] 188 | node2ff_1st_list[n] = node2ff_1st[n] 189 | 190 | return alias_n2n_1st_list, node2ff_1st_list 191 | 192 | 193 | def get_src_dst2e(G, edges): 194 | src_dst_2e = {} 195 | node_nbr = [[] for n in range(int(np.max(G.nodes)) + 1)] 196 | 197 | for e1 in tqdm(edges): 198 | for src in G.edges[e1]: 199 | for dst in G.edges[e1]: 200 | if src != dst: 201 | if (src, dst) in src_dst_2e: 202 | src_dst_2e[(src, dst)].append(e1) 203 | else: 204 | src_dst_2e[(src, dst)] = [e1] 205 | 206 | node_nbr[src].append(dst) 207 | node_nbr[dst].append(src) 208 | 209 | print("get node nbr") 210 | 211 | for k in trange(len(node_nbr)): 212 | list1 = node_nbr[k] 213 | list1 = sorted(set(list1)) 214 | node_nbr[k] = list1 215 | for k in tqdm(src_dst_2e.keys()): 216 | list1 = sorted(src_dst_2e[k]) 217 | src_dst_2e[k] = list1 218 | G.src_dst_2e = src_dst_2e 219 | G.node_nbr = np.array(node_nbr) 220 | 221 | 222 | def get_alias_n2n_2nd(src, dst): 223 | dst_nbr = node_nbr[dst] 224 | 225 | pp = np.ones(len(dst_nbr)) 226 | pp /= q 227 | 228 | e1_all = src_dst_2e[(src, dst)] 229 | # ff_all_1 = EV[e1_all, :dst] * VE[:dst] 230 | # ff_all_2 = EV[e1_all, dst+1:] * VE[dst+1:] 231 | condition = np.array(VE[dst_nbr, :][:, e1_all].sum(axis=-1)).reshape((-1)) 232 | pp[condition > 0] /= p 233 | 234 | for i, nb in enumerate(dst_nbr): 235 | if nb == src: 236 | pp[i] *= q 237 | elif (src, nb) in src_dst_2e: 238 | pp[i] *= q 239 | # e2_all = src_dst_2e[(dst, nb)] 240 | # ff_all_1 = EV[e1_all, :dst] * VE[:dst, e2_all] 241 | # ff_all_2 = EV[e1_all, dst+1:] * VE[dst+1:, e2_all] 242 | # 243 | # 244 | # pp[i] *= ((ff_all_1.sum() + ff_all_2.sum()) ** 0.5) 245 | 246 | ff_1st = node2ff_1st[dst] 247 | #pp += np.random.randn(pp.shape[0]) * 0.5 248 | pp *= (ff_1st ** weight_1st) 249 | pp *= (node_degree[dst_nbr] ** weight_degree) 250 | 251 | unnormalized_probs = pp 252 | normalized_probs = unnormalized_probs / np.sum(unnormalized_probs) 253 | normalized_probs = normalized_probs / np.sum(normalized_probs) 254 | return alias_setup(normalized_probs) 255 | 256 | 257 | def get_alias_n2n_2nd_dropped(src, dst): 258 | dst_nbr = node_nbr[dst] 259 | 260 | pp = np.zeros(len(dst_nbr)) 261 | 262 | e1_all = src_dst_2e[(src, dst)] 263 | # ff_all_1 = EV[e1_all, :dst] * VE[:dst] 264 | # ff_all_2 = EV[e1_all, dst+1:] * VE[dst+1:] 265 | condition = np.array(VE[dst_nbr, :][:, e1_all].sum(axis=-1)).reshape((-1)) 266 | pp[condition > 0] += p * condition[condition > 0] 267 | 268 | for i, nb in enumerate(dst_nbr): 269 | if nb == src: 270 | pp[i] += node_degree[src] 271 | elif (src, nb) in src_dst_2e: 272 | pp[i] += len(src_dst_2e[(src, nb)]) 273 | else: 274 | pp[i] += 1 / q 275 | # e2_all = src_dst_2e[(dst, nb)] 276 | # ff_all_1 = EV[e1_all, :dst] * VE[:dst, e2_all] 277 | # ff_all_2 = EV[e1_all, dst+1:] * VE[dst+1:, e2_all] 278 | # 279 | # 280 | # pp[i] *= ((ff_all_1.sum() + ff_all_2.sum()) ** 0.5) 281 | 282 | ff_1st = node2ff_1st[dst] 283 | # pp += np.random.randn(pp.shape[0]) * 0.5 284 | pp *= (ff_1st ** weight_1st) 285 | pp *= (node_degree[dst_nbr] ** weight_degree) 286 | 287 | unnormalized_probs = pp 288 | normalized_probs = unnormalized_probs / np.sum(unnormalized_probs) 289 | normalized_probs = normalized_probs / np.sum(normalized_probs) 290 | return alias_setup(normalized_probs) 291 | 292 | 293 | def get_second_order(nodes): 294 | alias_n2n_2nd = {} 295 | for i in trange(len(nodes)): 296 | src = nodes[i] 297 | dsts = node_nbr[src] 298 | 299 | for dst_index, dst in enumerate(dsts): 300 | alias_n2n_2nd[(src, dst)] = get_alias_n2n_2nd(src, dst) 301 | return alias_n2n_2nd 302 | # for multi-processing 303 | 304 | 305 | def parallel_get_second_order(G): 306 | print("2nd order: ") 307 | global p, q, node_nbr, VE, EV, src_dst_2e, node2ff_1st, node_degree, node_nbr 308 | p, q = G.p, G.q 309 | node_nbr = G.node_nbr 310 | VE = G.VE 311 | EV = G.EV 312 | src_dst_2e = G.src_dst_2e 313 | node2ff_1st = G.node2ff_1st 314 | node_degree = G.node_degree 315 | node_nbr = G.node_nbr 316 | 317 | # f is a csr-matrix 318 | # O(\sum_v (\sum_e\in nbr(v) |e|)^2) 319 | 320 | processes_num = 80 321 | pool = ProcessPoolExecutor(max_workers=processes_num) 322 | process_list = [] 323 | 324 | second_start = time.time() 325 | 326 | nodes = np.copy(G.nodes) 327 | 328 | split_num = min(processes_num, int(len(nodes) / 100)) * 2 + 1 329 | print("split_num", split_num) 330 | np.random.shuffle(nodes) 331 | nodes = np.array_split(nodes, split_num) 332 | 333 | print("Start get second order alias") 334 | for node in nodes: 335 | process_list.append(pool.submit(get_second_order, node)) 336 | 337 | alias_n2n_2nd = {} 338 | for p in as_completed(process_list): 339 | alias_t1 = p.result() 340 | alias_n2n_2nd.update(alias_t1) 341 | 342 | print("get-second-order-term running time: " + 343 | str(time.time() - second_start)) 344 | 345 | print("Start to turn the dict into list") 346 | alias_n2n_2nd_list = [] 347 | alias_n2n_toid = {} 348 | for i, k in enumerate(tqdm(alias_n2n_2nd.keys())): 349 | alias_n2n_toid[k] = i 350 | alias_n2n_2nd_list.append(alias_n2n_2nd[k]) 351 | 352 | G.alias_n2n_toid = alias_n2n_toid 353 | G.alias_n2n_2nd_list = alias_n2n_2nd_list 354 | 355 | pool.shutdown(wait=True) 356 | return alias_n2n_2nd 357 | 358 | 359 | def random_walk_list(walk_length, start): 360 | walk = [start] 361 | while len(walk) < (walk_length): 362 | cur = walk[-1] 363 | cur_ns = node_nbr[cur] 364 | if len(cur_ns) < 1: 365 | walk.append(cur) 366 | continue 367 | 368 | try: 369 | if len(walk) == 1: 370 | next_n = cur_ns[alias_draw(alias_n2n_1st[cur])] 371 | else: 372 | prev_n = walk[-2] 373 | next_n = cur_ns[alias_draw( 374 | alias_n2n_2nd_list[alias_n2n_toid[(prev_n, cur)]])] 375 | 376 | except Exception as e: 377 | print("error", e) 378 | break 379 | walk.append(next_n) 380 | 381 | return walk 382 | 383 | 384 | def simulate_walks_part(num_walks, walk_length, nodes): 385 | walks = [] 386 | for node in tqdm(nodes): 387 | for walk_iter in range(num_walks): 388 | walk = random_walk_list(walk_length, node) 389 | walks.append(walk) 390 | return walks 391 | 392 | 393 | def simulate_walks_para(G, num_walks, walk_length): 394 | ''' 395 | Repeatedly simulate random walks from each node. 396 | ''' 397 | global alias_n2n_1st, alias_n2n_2nd_list, alias_n2n_toid 398 | alias_n2n_1st = G.alias_n2n_1st 399 | alias_n2n_2nd_list = G.alias_n2n_2nd_list 400 | alias_n2n_toid = G.alias_n2n_toid 401 | 402 | processes_num = 30 403 | pool = ProcessPoolExecutor(max_workers=processes_num) 404 | process_list = [] 405 | 406 | print("sample walks:") 407 | walks = [] 408 | 409 | nodes = np.copy(G.nodes) 410 | 411 | split_num = processes_num 412 | print("split_num", split_num) 413 | np.random.shuffle(nodes) 414 | nodes = np.array_split(nodes, split_num) 415 | 416 | for node in nodes: 417 | process_list.append( 418 | pool.submit( 419 | simulate_walks_part, 420 | num_walks, 421 | walk_length, 422 | node)) 423 | 424 | for p in as_completed(process_list): 425 | alias_t1 = p.result() 426 | walks += alias_t1 427 | 428 | pool.shutdown(wait=True) 429 | 430 | print("start permutation") 431 | idx = np.random.permutation(len(walks)) 432 | walks = np.array(walks, dtype='int') 433 | return walks[idx] 434 | 435 | 436 | def toint(hyperedge_list): 437 | return np.array([np.array(h).astype('int') - 1 for h in hyperedge_list]) 438 | 439 | 440 | def random_walk_hyper(args, node_list, hyperedge_list): 441 | p, q = args.p, args.q 442 | 443 | num_walks, walk_length, window_size = args.num_walks, args.walk_length, args.window_size 444 | walks_save_path = '../walks/{}/p{}_q{}_r{}_l{}_hyper_walks.txt'.format( 445 | args.data, p, q, num_walks, walk_length) 446 | if not os.path.exists("../walks/{}/".format(args.data)): 447 | os.mkdir("../walks/{}/".format(args.data)) 448 | start = time.time() 449 | 450 | if not args.TRY and os.path.exists(walks_save_path): 451 | return walks_save_path 452 | else: 453 | G = HyperGraphRandomWalk(p, q) 454 | G.data = args.data 455 | # FIXME: take care when the input are tensors, but I think other 456 | # dataset they will not be 457 | print('build') 458 | hyperedge_list = toint(hyperedge_list) 459 | G.build_graph(node_list, hyperedge_list) 460 | edges = np.array(range(len(G.edges))) 461 | print("Building pairwise to hyper dict") 462 | get_src_dst2e(G, edges) 463 | G.alias_n2n_1st, G.node2ff_1st = get_first_order(G) 464 | parallel_get_second_order(G) 465 | print("RandomWalk getting edges time: %.2lf" % (time.time() - start)) 466 | print(G.__dict__.keys()) 467 | 468 | name = [ 469 | 'data', 470 | 'edges', 471 | 'node_degree', 472 | 'edge_degree', 473 | 'n_edge', 474 | 'EV', 475 | 'delta', 476 | 'EV_over_delta', 477 | 'VE', 478 | 'VE_over_delta', 479 | 'src_dst_2e', 480 | 'node_nbr', 481 | 'node2ff_1st'] 482 | 483 | for n in name: 484 | delattr(G, n) 485 | 486 | walks = simulate_walks_para(G, num_walks, walk_length) 487 | print("RandomWalk running time: %.2lf" % (time.time() - start)) 488 | np.savetxt(walks_save_path, walks, fmt="%d", delimiter=" ") 489 | # np.save(walks_save_path,walks) 490 | del G 491 | del walks 492 | print("RandomWalk running time: %.2lf" % (time.time() - start)) 493 | 494 | return walks_save_path 495 | -------------------------------------------------------------------------------- /History_version/Code/torchsummary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | from collections import OrderedDict 6 | import numpy as np 7 | 8 | 9 | def summary(model, input_size, batch_size=-1, device="cuda"): 10 | 11 | def register_hook(module): 12 | 13 | def hook(module, input, output): 14 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 15 | module_idx = len(summary) 16 | 17 | m_key = "%s-%i" % (class_name, module_idx + 1) 18 | summary[m_key] = OrderedDict() 19 | summary[m_key]["input_shape"] = list(input[0].size()) 20 | summary[m_key]["input_shape"][0] = batch_size 21 | if isinstance(output, (list, tuple)): 22 | summary[m_key]["output_shape"] = [ 23 | [-1] + list(o.size())[1:] for o in output 24 | ] 25 | else: 26 | summary[m_key]["output_shape"] = list(output.size()) 27 | summary[m_key]["output_shape"][0] = batch_size 28 | 29 | params = 0 30 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 31 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 32 | summary[m_key]["trainable"] = module.weight.requires_grad 33 | 34 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 35 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 36 | summary[m_key]["nb_params"] = params 37 | 38 | if ( 39 | not isinstance(module, nn.Sequential) 40 | and not isinstance(module, nn.ModuleList) 41 | and not (module == model) 42 | ): 43 | hooks.append(module.register_forward_hook(hook)) 44 | 45 | device = device.lower() 46 | assert device in [ 47 | "cuda", 48 | "cpu", 49 | ], "Input device is not valid, please specify 'cuda' or 'cpu'" 50 | 51 | if device == "cuda" and torch.cuda.is_available(): 52 | dtype = torch.cuda.FloatTensor 53 | else: 54 | dtype = torch.FloatTensor 55 | 56 | # multiple inputs to the network 57 | if isinstance(input_size, tuple): 58 | input_size = [input_size] 59 | 60 | # batch_size of 2 for batchnorm 61 | x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] 62 | # print(type(x[0])) 63 | 64 | # create properties 65 | summary = OrderedDict() 66 | hooks = [] 67 | 68 | # register hook 69 | model.apply(register_hook) 70 | 71 | # make a forward pass 72 | # print(x.shape) 73 | model(*x) 74 | 75 | # remove these hooks 76 | for h in hooks: 77 | h.remove() 78 | 79 | print("----------------------------------------------------------------") 80 | line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") 81 | print(line_new) 82 | print("================================================================") 83 | total_params = 0 84 | total_output = 0 85 | trainable_params = 0 86 | for layer in summary: 87 | # input_shape, output_shape, trainable, nb_params 88 | line_new = "{:>20} {:>25} {:>15}".format( 89 | layer, 90 | str(summary[layer]["output_shape"]), 91 | "{0:,}".format(summary[layer]["nb_params"]), 92 | ) 93 | total_params += summary[layer]["nb_params"] 94 | try: 95 | total_output += np.prod(summary[layer]["output_shape"]) 96 | except: 97 | print("error", layer) 98 | if "trainable" in summary[layer]: 99 | if summary[layer]["trainable"] == True: 100 | trainable_params += summary[layer]["nb_params"] 101 | print(line_new) 102 | 103 | # assume 4 bytes/number (float on cuda). 104 | total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) 105 | total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients 106 | total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.)) 107 | total_size = total_params_size + total_output_size + total_input_size 108 | 109 | print("================================================================") 110 | print("Total params: {0:,}".format(total_params)) 111 | print("Trainable params: {0:,}".format(trainable_params)) 112 | print("Non-trainable params: {0:,}".format(total_params - trainable_params)) 113 | print("----------------------------------------------------------------") 114 | print("Input size (MB): %0.2f" % total_input_size) 115 | print("Forward/backward pass size (MB): %0.2f" % total_output_size) 116 | print("Params size (MB): %0.2f" % total_params_size) 117 | print("Estimated Total Size (MB): %0.2f" % total_size) 118 | print("----------------------------------------------------------------") 119 | # return summary -------------------------------------------------------------------------------- /History_version/Code/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm, trange 4 | from sklearn.metrics import average_precision_score, precision_score, recall_score, f1_score 5 | from sklearn.metrics import roc_auc_score, accuracy_score, matthews_corrcoef 6 | from concurrent.futures import as_completed, ProcessPoolExecutor 7 | from pybloom_live import ScalableBloomFilter 8 | from copy import copy,deepcopy 9 | from itertools import combinations 10 | from pybloomfilter import BloomFilter 11 | import os 12 | import time 13 | def add_padding_idx(vec): 14 | if len(vec.shape) == 1: 15 | return np.asarray([np.sort(np.asarray(v) + 1).astype('int') 16 | for v in tqdm(vec)]) 17 | else: 18 | vec = np.asarray(vec) + 1 19 | vec = np.sort(vec, axis=-1) 20 | return vec.astype('int') 21 | 22 | 23 | def np2tensor_hyper(vec, dtype): 24 | vec = np.asarray(vec) 25 | if len(vec.shape) == 1: 26 | return [torch.as_tensor(v, dtype=dtype) for v in vec] 27 | else: 28 | return torch.as_tensor(vec, dtype = dtype) 29 | 30 | 31 | def walkpath2str(walk): 32 | return [list(map(str, w)) for w in tqdm(walk)] 33 | 34 | 35 | def roc_auc_cuda(y_true, y_pred, size_list, max_size): 36 | roc_str, aupr_str = "", "" 37 | try: 38 | for s in np.unique(size_list): 39 | y_t = (y_true[size_list == s] > 0.5).float().cpu().detach().numpy().reshape((-1, 1)) 40 | y_p = y_pred[size_list == s].cpu().detach().numpy().reshape((-1, 1)) 41 | roc, aupr = roc_auc_score( 42 | y_t, y_p), average_precision_score( 43 | y_t, y_p) 44 | roc_str += "%d %.3f " % (s, roc) 45 | aupr_str += "%d %.3f " % (s, aupr) 46 | return roc_str, aupr_str 47 | except BaseException: 48 | return 0.0, 0.0 49 | 50 | 51 | def accuracy(output, target, size_list = None, max_size = None): 52 | acc_str = "" 53 | if size_list is not None: 54 | for s in np.unique(size_list): 55 | pred = output[size_list == s] >= 0.5 56 | truth = target[size_list == s] >= 0.5 57 | acc = torch.sum(pred.eq(truth)) 58 | acc = float(acc) * 1.0 / (truth.shape[0] * 1.0) 59 | acc_str += "%d %.3f " % (s, acc) 60 | else: 61 | pred = output >= 0.5 62 | truth = target >= 0.5 63 | acc = torch.sum(pred.eq(truth)) 64 | acc = float(acc) * 1.0 / (truth.shape[0] * 1.0) 65 | acc_str += "%.3f " % (acc) 66 | return acc_str 67 | 68 | 69 | def build_hash(data,compress,max_size,min_size,fname): 70 | # if os.path.isfile("../data/SPRITE/%s_dict_%d" %(fname, 0)): 71 | # print ("existing dict") 72 | # dict_list = [BloomFilter.open("../data/SPRITE/%s_dict_%d" %(fname, i)) for i in range(max_size + 1)] 73 | # else: 74 | # dict_list = [] 75 | # for i in range(max_size + 1): 76 | # if i <= 2: 77 | # dict_list.append(BloomFilter(10, 1e-3, "../data/SPRITE/%s_dict_%d" %(fname, i))) 78 | # else: 79 | # dict_list.append(BloomFilter(3e9, 1e-3, "../data/SPRITE/%s_dict_%d" % (fname, i))) 80 | 81 | dict_list = [] 82 | for i in range(max_size + 1): 83 | if i <= 2: 84 | dict_list.append(BloomFilter(10, 1e-3)) 85 | # dict_list.append(set()) 86 | else: 87 | dict_list.append(BloomFilter(5e8, 1e-3)) 88 | # dict_list.append(set()) 89 | print (len(dict_list)) 90 | data_list = [[] for i in range(max_size + 1)] 91 | for datum in tqdm(data): 92 | # datum = np.array(datum).astype('int') 93 | if (min_size < 0) or (len(datum) >= min_size): 94 | # if len(datum) >= 100: 95 | # continue 96 | # datum.sort() 97 | # for j in range(min_size, min(len(datum), max_size) + 1): 98 | # combs = combinations(datum, j) 99 | # data_list[j].append(combs) 100 | data_list[len(datum)].append(tuple(datum)) 101 | 102 | if len(data_list[min_size]) > 1e7: 103 | start = time.time() 104 | for i in range(max_size + 1): 105 | dict_list[i].update(data_list[i]) 106 | # for combs in data_list[i]: 107 | # dict_list[i].update(combs) 108 | # for d in combs: 109 | # if d in dict_list[i]: 110 | # # final_data.append(d) 111 | # continue 112 | # else: 113 | # # final_data.append(d) 114 | # dict_list[i].add(d) 115 | print(len(dict_list[min_size]) / dict_list[min_size].capacity, "%.2f s" %(time.time() - start)) 116 | data_list = [[] for i in range(max_size + 1)] 117 | 118 | 119 | for i in range(max_size + 1): 120 | for combs in data_list[i]: 121 | dict_list[i].add(combs) 122 | 123 | 124 | print (len(dict_list[-1])) 125 | length_list = [len(dict_list[i]) for i in range(len(dict_list))] 126 | print (length_list) 127 | np.save("../data/SPRITE/length.npy", length_list) 128 | return dict_list 129 | 130 | 131 | def parallel_build_hash(data, func, args, num, initial=None, compress=False, max_size=-1): 132 | import multiprocessing 133 | cpu_num = multiprocessing.cpu_count() 134 | data = np.array_split(data, cpu_num * 1) 135 | dict1 = deepcopy(initial) 136 | pool = ProcessPoolExecutor(max_workers=cpu_num) 137 | process_list = [] 138 | 139 | if func == 'build_hash': 140 | func = build_hash 141 | if func == 'build_hash2': 142 | func = build_hash2 143 | if func == 'build_hash3': 144 | func = build_hash3 145 | 146 | for datum in data: 147 | process_list.append(pool.submit(func, datum, compress, max_size)) 148 | 149 | for p in as_completed(process_list): 150 | a = p.result() 151 | if compress: 152 | dict1 = dict1.union(a) 153 | else: 154 | dict1.update(a) 155 | del a 156 | pool.shutdown(wait=True) 157 | 158 | # if args.data in ['schic','ramani']: 159 | # print (num[0]) 160 | # new_list_of_set = [set() for i in range(int(num[0]+1))] 161 | # for s in dict1: 162 | # try: 163 | # new_list_of_set[s[0]].add(s) 164 | # except: 165 | # print (s) 166 | # raise EOFError 167 | # dict1 = new_list_of_set 168 | return dict1 169 | 170 | def sync_shuffle(sample_list, max_num = -1): 171 | index = torch.randperm(len(sample_list[0])) 172 | if max_num > 0: 173 | index = index[:max_num] 174 | new_list = [] 175 | for s in sample_list: 176 | new_list.append(s[index]) 177 | return new_list 178 | 179 | 180 | def pass_(x): 181 | return x 182 | 183 | 184 | def generate_outlier_part(data, dict_pair, k=20): 185 | inputs = [] 186 | negs = [] 187 | 188 | for e in tqdm(data): 189 | point = int(np.where(e == 0)[0]) 190 | start = 0 if point == 0 else int(num_list[point - 1]) 191 | end = int(num_list[point]) 192 | 193 | count = 0 194 | trial = 0 195 | while count < k: 196 | trial += 1 197 | if trial >= 100: 198 | break 199 | j = np.random.randint(start, end) + 1 200 | condition = [(j, n) in dict_pair for n in e] 201 | if np.sum(condition) > 0: 202 | continue 203 | else: 204 | temp = np.copy(e) 205 | temp[point] = j 206 | inputs.append(temp) 207 | negs.append(point) 208 | count += 1 209 | inputs, index = np.unique(inputs, axis=0, return_index=True) 210 | negs = np.array(negs)[index] 211 | return np.array(inputs), np.array(negs) 212 | 213 | 214 | def check_outlier(model, data_): 215 | data, negs = data_ 216 | bs = 1024 217 | num_of_batches = int(np.floor(data.shape[0] / bs)) + 1 218 | k = 3 219 | outlier_prec = torch.zeros(k).to(device) 220 | 221 | model.eval() 222 | with torch.no_grad(): 223 | for i in tqdm(range(num_of_batches)): 224 | inputs = data[i * bs:(i + 1) * bs] 225 | neg = negs[i * bs:(i + 1) * bs] 226 | outlier = model(inputs, get_outlier=k) 227 | outlier_prec += (outlier.transpose(1, 0) == neg).sum(dim=1).float() 228 | # for kk in range(k): 229 | # outlier_prec[kk] += (outlier[:,kk].view(-1)==neg).sum().float() 230 | outlier_prec = outlier_prec.cumsum(dim=0) 231 | outlier_prec /= data.shape[0] 232 | for kk in range(k): 233 | print("outlier top %d hitting: %.5f" % (kk + 1, outlier_prec[kk])) 234 | 235 | 236 | class Word2Vec_Skipgram_Data_Empty(object): 237 | """Word2Vec model (Skipgram).""" 238 | 239 | def __init__(self): 240 | return 241 | 242 | def next_batch(self): 243 | """Train the model.""" 244 | 245 | return 0, 0, 0, 0, 0 246 | -------------------------------------------------------------------------------- /History_version/Code/word2vec_ops.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/Code/word2vec_ops.so -------------------------------------------------------------------------------- /History_version/Readme.md: -------------------------------------------------------------------------------- 1 | This is the old version of the MATCHA for replicating the results in the manuscript. It is no longer maintained 2 | 3 | ## Required Package 4 | 5 | Python (>= 3.6.8) 6 | 7 | pytorch(tested on 1.2.0) 8 | 9 | Numpy(tested on 1.16.3) 10 | 11 | tqdm 12 | 13 | Umap-learn(tested on 0.3.10) 14 | 15 | pybloomfiltermmap3 16 | 17 | 18 | 19 | ## Running command 20 | 21 | unzip the data file under SPRITE folder before proceeding, the hyperedges with occurrence frequency 2 are not included due to the file size. 22 | 23 | But they can be generated by the script (preprocess_SPRITE.py and analysis_SPRITE.py) 24 | 25 | The required files for these two scripts are: 26 | 27 | 1. SPRITE cluster file from 4DN data portal. 28 | 29 | 2. The .mcool file from 4DN data portal. (Dumped into pair-wise contact named by SPRITE_contact.txt) 30 | 31 | 32 | 33 | python main_drop.py -f adj -w hyper 34 | 35 | for running on ChIA-Drop data 36 | 37 | python main_SPRITE.py -f adj -w hyper 38 | 39 | for running on SPRITE data -------------------------------------------------------------------------------- /History_version/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/.DS_Store -------------------------------------------------------------------------------- /History_version/data/SPRITE/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/SPRITE/.DS_Store -------------------------------------------------------------------------------- /History_version/data/SPRITE/bin2node.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/SPRITE/bin2node.npy -------------------------------------------------------------------------------- /History_version/data/SPRITE/node2bin.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/SPRITE/node2bin.npy -------------------------------------------------------------------------------- /History_version/data/SPRITE/node2chrom.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/SPRITE/node2chrom.npy -------------------------------------------------------------------------------- /History_version/data/SPRITE/tuples/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/SPRITE/tuples/.DS_Store -------------------------------------------------------------------------------- /History_version/data/SPRITE/tuples/occ_3_8.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/SPRITE/tuples/occ_3_8.zip -------------------------------------------------------------------------------- /History_version/data/SPRITE/tuples/occ_above_8.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/SPRITE/tuples/occ_above_8.zip -------------------------------------------------------------------------------- /History_version/data/drop/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/drop/.DS_Store -------------------------------------------------------------------------------- /History_version/data/drop/coor2id.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/drop/coor2id.npy -------------------------------------------------------------------------------- /History_version/data/drop/id2coor.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/drop/id2coor.npy -------------------------------------------------------------------------------- /History_version/data/drop/test_data.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/drop/test_data.npz -------------------------------------------------------------------------------- /History_version/data/drop/train_data.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ma-compbio/MATCHA/ff18cd9db3e20e527882163fb0d263a27f6646ef/History_version/data/drop/train_data.npz -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ma Lab at CMU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # MATCHA: Probing multi-way chromatin interaction with hypergraph representation learning 2 | 3 | This is the implementation of the algorithm MATCHA for analyzing multi-way chromatin interaction data via hypergraph representation learning. 4 | 5 | ## Requirements 6 | The main part of the alogrithm (`process.py, generate_kmers.py, main.py`) requires 7 | 8 | 9 | - h5py 10 | - numpy 11 | - pytorch 12 | - pybloom_live (https://github.com/joseph-fox/python-bloomfilter) 13 | - scikit-learn 14 | - tqdm 15 | 16 | 17 | The visualization part of the algorihtm (`denoise_contact.py`) requires 18 | 19 | - seaborn 20 | - matplotlib 21 | 22 | ## Configure the parameters 23 | 24 | All the input parameters are stored in the config.JSON file. 25 | Please fill in this file before running the program. 26 | Note that, some scripts only use part of these parameters, so these parameters can be filled in before running those specific script. 27 | 28 | | params | description | example | used in | 29 | |--------------|------------------------------|---------------------------|------------| 30 | | cluster_path | the path of the cluster file | "./4DNFIBEVVTN5.clusters" | process.py | 31 | |mcool_path | the path of the mcool file | "./4DNFIUOOYQC3.mcool" | process.py| 32 | |resolution | the resolution to consider (bin size) | 1000000 | process.py| 33 | |chrom_list | list of the chromosomes to consider | ["chr1", "chr2"] | process.py, main.py| 34 | |chrom_size | the path of the chromatin size file | "./hg38.chrom.sizes.txt" | process.py| 35 | |temp_dir | the directory of the temp files to store | "../Temp" | all| 36 | |max_cluster_size| the maximum cluster size to consider | 25| process.py, generate_kmers.py | 37 | |min_distance | minimum pairwise genomic distance constraint for multi-way interactions (in unit of the number of bins) |0| generate_kmers.py, main.py, denoise_contact.py| 38 | |k-mer_size| list of the size of the k-mers to considier | [2,3,4,5] | generate_kmers.py, main.py, 39 | |min_freq_cutoff | only consider k-mers with occurrence frequency >= | 2 | generate_kmers.py| 40 | |quantile_cutoff_for_positive | the quantile cutoff of hyperedges to be considered as positive samples. For instance, 0.6 represents the hyperedges with occurrence frequency in the top 40% (>= 0.6) would be used as positive samples. The cut-off is applied to different sized hyperedges separately| 0.6 | main.py | 41 | |quantile_cutoff_for_unlabel | the quantile cutoff of hyperedges to be considered as non-negative samples (positive + samples that cannot be confidently classified as either positive or negative samples) | 0.4 | main.py | 42 | |embed_dim | embedding dimensions for the bins | 64| main.py| 43 | 44 | 45 | ## Usage 46 | 1. `cd Code` 47 | 2. Run `python process.py`, which will parse the input cluster file, mcool file and the chromosome size files. There will be 3 key output files: 48 | 1. `bin2node.npy, node2bin.npy` within the `temp_dir` above. As the name indicates, it's a dictionary that maps the genomic bin to the node id and vice verse. The genomic bin has the format of `chr1:2000000` 49 | 2. `node2chrom.npy`. It maps the node id to the chromosome. 50 | 3. All these dictionaries can be loaded through `np.load(FILEPATH, allow_picke=True).item()` 51 | 3. Run `python generate_kmers.py`, which will further transfer the parsed cluster file into a list of k-mers (hyperedges) with the corresponding occurrence frequencies. The output files are 52 | 1. `all__counter.npy`: the generated k-mers 53 | 2. `all__freq_counter.npy`: the occurrence frequency corresponds to the generated k-mers 54 | 4. Run `python main.py`, which will train the model based on the generated dataset. The output includes: 55 | 1. `model2load` within the `temp_dir` above. The model can be loaded by `model = torch.load(FILEPATH)`. The model can return predictions through `model(x)`. Note that the `x` should be a pytorch tensor of dtype `torch.long` 56 | 2. `embeddings.npy` lies in the root dir. It's the embedding vectors for the genomic bins. The shape of the vectors are `(num of genomic bins, embed_dim chosen above)`. The mapping relationship between the genomic bin and its index in this vector can be retrived in the dictionary `node2bin.npy, bin2node.npy` mentioned above. 57 | 5. To generate the denoised contact matrix, run `python denoise_contact.py` There will be output figures named as `chr1_origin.png` and `chr1_denoise.png`, etc... produced in the root dir. There will also be an mcool file named as `denoised.mcool` in the root dir, which contains the denoised intra-chromosomal contact matrix at the given resolution. 58 | 59 | 6. To predict the probabilities of forming multi-way chromatin interactions for a custom list of genome coordinate, run `python predict_multiway.py -i INPUT_FILE -o OUTPUT_FILE`. The `INPUT_FILE` should be a text file where each line is a tab separated list of genome coordinates. For example: 60 | ```text 61 | chr1:1000000chr2:20000chr3:40000 62 | chr1:1000000chr2:20000chr3:40000chr1:12345 63 | ``` 64 | The output file will be a list of the probability scores stored in the `OUTPUT_FILE` 65 | 66 | ## Cite 67 | 68 | If you want to cite our paper 69 | 70 | ``` 71 | @article{zhang2020matcha, 72 | title={MATCHA: Probing Multi-way Chromatin Interaction with Hypergraph Representation Learning}, 73 | author={Zhang, Ruochi and Ma, Jian}, 74 | journal={Cell Systems}, 75 | volume={10}, 76 | number={5}, 77 | pages={397--407}, 78 | year={2020}, 79 | publisher={Elsevier} 80 | } 81 | ``` 82 | --------------------------------------------------------------------------------