├── LICENSE ├── README.md ├── data ├── afdb │ └── readme.md ├── embedding │ └── readme.md ├── new_mol_smi │ └── readme.md ├── new_seq_smi │ └── readme.md ├── new_time │ └── readme.md └── readme.md ├── data_utils.py ├── get_afdb.py ├── logger └── readme.md ├── mat.py ├── model └── readme.md ├── prepare_graphs.py ├── prepare_negative.py ├── pretrained ├── SaProt_650M_PDB │ └── readme.md └── readme.md ├── process_esm.py ├── process_mat.py ├── process_saprot.py ├── retrieval.py ├── retrieval_rnn.py ├── retrieval_tfmr.py ├── train.py ├── train_contra.py ├── train_rnn.py ├── train_tfmr.py └── unimol.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 
20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. 
rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. 
Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. 
Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReactZyme: A Benchmark for Enzyme-Reaction Prediction [[paper](https://www.arxiv.org/abs/2408.13659)] 2 | ### Official GitHub repository of ReactZyme ([arxiv-link](https://www.arxiv.org/abs/2408.13659)). 3 | 4 | ### Check out our newest work, [EnzymeFlow](https://github.com/WillHua127/EnzymeFlow)! 5 | 6 | # Data preparation 7 | 8 | Raw data can be downloaded from [zenodo-reactzyme](https://zenodo.org/records/13635807). Once downloaded, put the raw data into the 'data' folder. 9 | 10 | (1) Raw data 11 | 12 | There should be 4 raw data files: (1) cleaned_uniprot_rhea.tsv; (2) uniprot_molecules.tsv; (3) uniprot_rhea.tsv; (4) rhea_molecules.tsv. 
13 | Additionally, there is a saprot_seq.pt file containing the structure-aware protein sequences for SaProt, obtained by running FoldSeek. 14 | Put these files under the 'data' folder. 15 | 16 | (2) Processed data 17 | 18 | There should be 3 splits: (1) time; (2) seq-smi based; (3) mol-smi based. Put time/seq-smi/mol-smi under the new_time/new_seq_smi/new_mol_smi folders, respectively. Note that we only provide positive enzyme-reaction pairs; the design of negative samples remains an open question. Nevertheless, we provide an example of negative sample generation in [prepare_negative.py](https://github.com/WillHua127/ReactZyme/blob/main/prepare_negative.py). 19 | 20 | # Python file - utils 21 | 22 | SaProt tips: If you want to use SaProt, you have to use FoldSeek to get structure-aware sequence representations. This can be annoying, so we provide [processed structure-aware sequences](https://zenodo.org/records/13635807) for our dataset (the 'saprot_seq.pt' file from Zenodo). If you'd like to do it on your own, you can use the function [get_struc_seq](https://github.com/WillHua127/ReactZyme/blob/main/process_saprot.py) from process_saprot.py. 23 | 24 | 25 | (1) Processing Sequences 26 | 27 | >[get_afdb.py](https://github.com/WillHua127/ReactZyme/blob/main/get_afdb.py): code example of fetching afdb structures for time-based split. 28 | 29 | >[process_saprot.py](https://github.com/WillHua127/ReactZyme/blob/main/process_saprot.py): code example of processing saprot features for afdb structures. 30 | 31 | >[process_esm.py](https://github.com/WillHua127/ReactZyme/blob/main/process_esm.py): code example of processing ESM features for sequences. 32 | 33 | 34 | 35 | (2) Processing Reactions 36 | 37 | >[mat.py](https://github.com/WillHua127/ReactZyme/blob/main/mat.py): code for the MAT model, used for model loading. 38 | 39 | >[process_mat.py](https://github.com/WillHua127/ReactZyme/blob/main/process_mat.py): code example of processing MAT features for reactions. 
40 | 41 | >[prepare_graphs.py](https://github.com/WillHua127/ReactZyme/blob/main/prepare_graphs.py): code for processing molecular graphs. 42 | 43 | 44 | 45 | (3) General dataloading 46 | 47 | >[data_utils.py](https://github.com/WillHua127/ReactZyme/blob/main/data_utils.py): dataloader etc. 48 | 49 | 50 | 51 | 52 | (4) Negative samples 53 | >[prepare_negative.py](https://github.com/WillHua127/ReactZyme/blob/main/prepare_negative.py): code example of preparing negative samples based on reaction SMILES. Once you have the dictionary of negative pairs 'data/negative_mol_dict.pt', you can prepare negative samples for training. 54 | 55 | 56 | (5) Unimol features 57 | 58 | > [unimol.ipynb](https://github.com/WillHua127/ReactZyme/blob/main/unimol.ipynb): code example of generating unimol features for reactions. 59 | 60 | 61 | # Python file - train and evaluation 62 | 63 | 64 | (1) Train MLP 65 | 66 | 67 | >[train.py](https://github.com/WillHua127/ReactZyme/blob/main/train.py): code for MLP training. 68 | 69 | For example, you can run time-based training with ESM and unimol embeddings like this: `CUDA_VISIBLE_DEVICES=0 python train.py --split_type time --mol_embedding_type unimol --pro_embedding_type esm --batch_size 1000` 70 | 71 | >[retrieval.py](https://github.com/WillHua127/ReactZyme/blob/main/retrieval.py): code for MLP evaluation. 72 | 73 | 74 | 75 | (2) Train Contrastive 76 | 77 | 78 | >[train_contra.py](https://github.com/WillHua127/ReactZyme/blob/main/train_contra.py): code for MLP-contrastive training. 79 | 80 | 81 | >[retrieval.py](https://github.com/WillHua127/ReactZyme/blob/main/retrieval.py): code for MLP-contrastive evaluation. 82 | 83 | 84 | 85 | (3) Train Transformer 86 | 87 | 88 | >[train_tfmr.py](https://github.com/WillHua127/ReactZyme/blob/main/train_tfmr.py): code for Transformer training. 89 | 90 | 91 | >[retrieval_tfmr.py](https://github.com/WillHua127/ReactZyme/blob/main/retrieval_tfmr.py): code for Transformer evaluation. 
92 | 93 | 94 | 95 | 96 | (4) Train Bi-RNN 97 | 98 | 99 | >[train_rnn.py](https://github.com/WillHua127/ReactZyme/blob/main/train_rnn.py): code for Bi-RNN training. 100 | 101 | 102 | >[retrieval_rnn.py](https://github.com/WillHua127/ReactZyme/blob/main/retrieval_rnn.py): code for Bi-RNN evaluation. 103 | 104 | 105 | 106 | ## Citation 107 | ``` 108 | @article{hua2024reactzyme, 109 | title={Reactzyme: A Benchmark for Enzyme-Reaction Prediction}, 110 | author={Hua, Chenqing and Zhong, Bozitao and Luan, Sitao and Hong, Liang and Wolf, Guy and Precup, Doina and Zheng, Shuangjia}, 111 | journal={arXiv preprint arXiv:2408.13659}, 112 | year={2024} 113 | } 114 | ``` 115 | -------------------------------------------------------------------------------- /data/afdb/readme.md: -------------------------------------------------------------------------------- 1 | Folder to save AlphaFold DB structures for SaProt 2 | -------------------------------------------------------------------------------- /data/embedding/readme.md: -------------------------------------------------------------------------------- 1 | Folder to save pretrained embeddings for ESM, unimol, etc. For example, esm_seq_embedding.pt and unimol_mol_embedding.pt 2 | -------------------------------------------------------------------------------- /data/new_mol_smi/readme.md: -------------------------------------------------------------------------------- 1 | Folder to save reaction-similarity-based splits via https://zenodo.org/records/11494913. 2 | 3 | Put (1) positive_train_val_mol_smi.pt; (2) positive_test_mol_smi.pt; (3) negative_train_val_mol_smi.pt; (4) negative_test_mol_smi.pt 4 | -------------------------------------------------------------------------------- /data/new_seq_smi/readme.md: -------------------------------------------------------------------------------- 1 | Folder to save sequence-similarity-based splits via https://zenodo.org/records/11494913. 
2 | 3 | Put (1) positive_train_val_seq_smi.pt; (2) positive_test_seq_smi.pt; (3) negative_train_val_seq_smi.pt; (4) negative_test_seq_smi.pt 4 | -------------------------------------------------------------------------------- /data/new_time/readme.md: -------------------------------------------------------------------------------- 1 | Folder to save time based splits via https://zenodo.org/records/11494913. 2 | 3 | Put (1) positive_train_val_time_smi.pt; (2) positive_test_time_smi.pt; (3) negative_train_val_time_smi.pt; (4) negative_test_time_smi.pt 4 | -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | Folder to save raw data, can download from https://zenodo.org/records/11494913 2 | 3 | Put (1) cleaned_uniprot_rhea.tsv; (2) uniprot_molecules.tsv; (3) uniprot_rhea.tsv; (4) rhea_molecules.tsv; (5) saprot_seq.pt (additional if you want to use SaProt) 4 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.nn.utils.rnn import pad_sequence 4 | 5 | from mat import make_model 6 | 7 | def get_samples(pos_path, neg_path): 8 | pos_samples = torch.load(pos_path) 9 | neg_samples = torch.load(neg_path) 10 | 11 | pos_mols = [] 12 | pos_seqs = [] 13 | for unis, values in pos_samples.items(): 14 | mol = values[0] 15 | seq = values[1] 16 | pos_mols.append(mol.replace('*', 'C')) 17 | pos_seqs.append(seq) 18 | 19 | 20 | neg_mols = [] 21 | neg_seqs = [] 22 | for unis, values in neg_samples.items(): 23 | mol = values[0] 24 | seq = values[1] 25 | neg_mols.append(mol.replace('*', 'C')) 26 | neg_seqs.append(seq) 27 | 28 | assert len(pos_mols) == len(pos_seqs) 29 | assert len(neg_mols) == len(neg_seqs) 30 | 31 | return pos_mols, pos_seqs, neg_mols, neg_seqs 32 | 
33 | 34 | def collate_fn(batch): 35 | mols, seqs, labels = zip(*batch) 36 | batch_mols = pad_sequence(mols, batch_first=True, padding_value=0) 37 | batch_seqs = pad_sequence(seqs, batch_first=True, padding_value=1) 38 | batch_labels = torch.stack(labels) 39 | return batch_mols, batch_seqs, batch_labels 40 | 41 | 42 | class EnzymeDataset(Dataset): 43 | def __init__(self, molecules, sequences, mol_tokenizer, seq_tokenizer, positive_sample=True, max_len=7000): 44 | assert len(molecules) == len(sequences) 45 | self.len = len(sequences) 46 | self.mols = molecules 47 | self.seqs = sequences 48 | self.mol_tokenizer = mol_tokenizer 49 | self.seq_tokenizer = seq_tokenizer 50 | self.max_len = max_len 51 | self.labels = torch.ones(self.len) if positive_sample else torch.zeros(self.len) 52 | 53 | def __len__(self): 54 | return self.len 55 | 56 | def __getitem__(self, item): 57 | mols = self.mols[item] 58 | seqs = self.seqs[item] 59 | labels = self.labels[item] 60 | 61 | mol_tok = self.mol_tokenizer(mols, padding=True, truncation=True, max_length=self.max_len)['input_ids'] 62 | seq_tok = self.seq_tokenizer(seqs, padding=True, truncation=True, max_length=self.max_len)['input_ids'] 63 | 64 | return torch.tensor(mol_tok), torch.tensor(seq_tok), labels 65 | 66 | def collate_fn_pretrained(batch): 67 | mols, seqs, labels = zip(*batch) 68 | batch_mols = torch.stack(mols) 69 | batch_seqs = torch.stack(seqs) 70 | batch_labels = torch.stack(labels) 71 | return batch_mols, batch_seqs, batch_labels 72 | 73 | class EnzymeDatasetPretrained(Dataset): 74 | def __init__(self, molecules, sequences, mol_embedding, seq_embedding, positive_sample=True, max_len=5000): 75 | assert len(molecules) == len(sequences) 76 | self.len = len(sequences) 77 | self.mols = molecules 78 | self.seqs = sequences 79 | self.mol_embedding = mol_embedding 80 | self.seq_embedding = seq_embedding 81 | self.max_len = max_len 82 | self.labels = torch.ones(self.len) if positive_sample else torch.zeros(self.len) 83 | 84 | def 
__len__(self): 85 | return self.len 86 | 87 | def __getitem__(self, item): 88 | mols = self.mols[item] 89 | seqs = self.seqs[item] 90 | labels = self.labels[item] 91 | seqs = seqs[:self.max_len] if len(seqs) > self.max_len else seqs 92 | 93 | mol_tok = self.mol_embedding[mols] 94 | seq_tok = self.seq_embedding[seqs] 95 | 96 | if mol_tok.dim() == 2: mol_tok = mol_tok.sum(0) 97 | 98 | return mol_tok, seq_tok, labels 99 | 100 | 101 | def collate_fn_pretrained_single(batch): 102 | data, labels = zip(*batch) 103 | batch_data = torch.stack(data) 104 | batch_labels = torch.stack(labels) 105 | return batch_data, batch_labels 106 | 107 | class EnzymeDatasetPretrainedSingle(Dataset): 108 | def __init__(self, data, embedding, positive_sample=True, max_len=5000): 109 | self.len = len(data) 110 | self.data = data 111 | self.embedding = embedding 112 | self.max_len = max_len 113 | self.labels = torch.ones(self.len) if positive_sample else torch.zeros(self.len) 114 | 115 | def __len__(self): 116 | return self.len 117 | 118 | def __getitem__(self, item): 119 | data = self.data[item] 120 | labels = self.labels[item] 121 | data = data[:self.max_len] if len(data) > self.max_len else data 122 | 123 | emb = self.embedding[data] 124 | 125 | if emb.dim() == 2: emb = emb.sum(0) 126 | 127 | return emb, labels 128 | 129 | 130 | def graph_collate_fn(batch): 131 | mols, seqs, labels = zip(*batch) 132 | return mols, seqs, labels 133 | 134 | 135 | class GraphEnzymeDataset(Dataset): 136 | def __init__(self, molecules, sequences, alphabet, mol_graphs_dict, positive_sample=True, max_len=7000): 137 | assert len(molecules) == len(sequences) 138 | self.len = len(sequences) 139 | self.mols = molecules 140 | self.seqs = sequences 141 | self.alphabet = alphabet 142 | self.mol_graphs_dict = mol_graphs_dict 143 | self.max_len = max_len 144 | self.labels = torch.ones(self.len) if positive_sample else torch.zeros(self.len) 145 | 146 | def __len__(self): 147 | return self.len 148 | 149 | def 
__getitem__(self, item): 150 | mols = self.mols[item] 151 | seqs = self.seqs[item] 152 | labels = self.labels[item] 153 | 154 | split_mols = mols.replace('*', 'C').split('.') 155 | 156 | seq_tok = self.alphabet.encode(seqs[:self.max_len]) 157 | graph_mols = [(self.mol_graphs_dict[mol]) for mol in split_mols] 158 | 159 | return graph_mols, torch.tensor(seq_tok).view(1, -1), labels 160 | 161 | -------------------------------------------------------------------------------- /get_afdb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import argparse 4 | from typing import List, Union 5 | import requests 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | from tqdm import tqdm 8 | import os 9 | from datetime import datetime 10 | import time 11 | 12 | import logging 13 | logging.basicConfig(level=logging.INFO) 14 | 15 | 16 | 17 | class AlphaFetcher: 18 | """ 19 | A class to fetch and download protein metadata and files from the AlphaFold Protein Structure Database using 20 | Uniprot access codes. 21 | 22 | Attributes: 23 | uniprot_access_list (List[str]): A list storing the Uniprot access codes to be fetched. 24 | failed_ids (List[str]): A list storing any Uniprot access codes that failed to be fetched. 25 | metadata_dict (dict): A dictionary storing fetched metadata against each Uniprot access code. 26 | base_savedir (str): The base directory where fetched files will be saved. 27 | """ 28 | 29 | def __init__(self, base_savedir=os.path.join(os.getcwd(), f'alphafetcher_results_' 30 | f'{datetime.now().strftime("%Y%m%d_%H%M%S")}')): 31 | """ 32 | Initializes the AlphaFetcher class with default values. 33 | """ 34 | self.uniprot_access_list = [] 35 | self.failed_ids = [] 36 | self.metadata_dict = {} 37 | self.base_savedir = base_savedir 38 | 39 | def add_proteins(self, proteins: Union[str, List[str]]) -> None: 40 | """ 41 | Adds the provided Uniprot access codes to the list for fetching. 
42 | 43 | Args: 44 | proteins (Union[str, List[str]]): A single Uniprot access code or a list of codes. 45 | 46 | Raises: 47 | ValueError: If the provided proteins parameter is neither a string nor a list of strings. 48 | """ 49 | if isinstance(proteins, str): 50 | self.uniprot_access_list.append(proteins) 51 | elif isinstance(proteins, list): 52 | self.uniprot_access_list.extend(proteins) # Using extend() method to add multiple items from a list. 53 | else: 54 | raise ValueError("Expected a string or a list of strings, but got {}".format(type(proteins))) 55 | 56 | def _fetch_single_metadata(self, uniprot_access: str, alphafold_database_base: str, pbar=None): 57 | """ 58 | Fetches the metadata for a single Uniprot access code. 59 | 60 | Args: 61 | uniprot_access (str): The Uniprot access code to fetch. 62 | alphafold_database_base (str): The base URL for the Alphafold API. 63 | pbar (tqdm, optional): A tqdm progress bar. Defaults to None. 64 | """ 65 | response = requests.get(f"{alphafold_database_base}{uniprot_access}") 66 | 67 | if response.status_code == 200: 68 | alphafold_data = response.json()[0] 69 | self.metadata_dict[uniprot_access] = alphafold_data 70 | 71 | else: 72 | self.failed_ids.append(uniprot_access) 73 | 74 | if pbar: 75 | pbar.update(1) 76 | 77 | def fetch_metadata(self, multithread: bool = False, workers: int = 10): 78 | """ 79 | Fetches metadata for all the Uniprot access codes added to the class. 80 | 81 | Args: 82 | multithread (bool, optional): If true, uses multithreading for faster fetching. Defaults to False. 83 | workers (int, optional): Number of threads to use if multithreading. If -1, uses all available CPUs. 84 | Defaults to 10. 
85 | """ 86 | alphafold_api_base = "https://alphafold.ebi.ac.uk/api/prediction/" 87 | 88 | # Use all available CPUs if workers is set to -1 89 | if workers == -1: 90 | workers = os.cpu_count() or 1 # Default to 1 if os.cpu_count() returns None 91 | 92 | if len(self.uniprot_access_list) == 0: 93 | print('Please a list of Uniprot access codes with the method add_proteins()') 94 | return 95 | 96 | with tqdm(total=len(self.uniprot_access_list), desc="Fetching Metadata") as pbar: 97 | if multithread: 98 | with ThreadPoolExecutor(max_workers=workers) as executor: 99 | 100 | futures = [executor.submit(self._fetch_single_metadata, uniprot_access, alphafold_api_base, 101 | pbar) for uniprot_access in self.uniprot_access_list] 102 | 103 | # Ensure all futures have completed 104 | for _ in as_completed(futures): 105 | pass 106 | 107 | else: 108 | for uniprot_access in self.uniprot_access_list: 109 | self._fetch_single_metadata(uniprot_access, alphafold_api_base, pbar) 110 | 111 | if len(self.failed_ids) > 0: 112 | print(f'Uniprot accessions not found in database: {", ".join(self.failed_ids)}') 113 | 114 | def _download_single_protein(self, uniprot_access: str, pdb: bool = False, cif: bool = False, bcif: bool = False, 115 | pae_image: bool = False, pae_data: bool = False, pbar=None): 116 | """ 117 | Downloads files for a single Uniprot access code. 118 | 119 | Args: 120 | uniprot_access (str): The Uniprot access code to fetch. 121 | pdb (bool, optional): If true, downloads the pdb file. Defaults to False. 122 | cif (bool, optional): If true, downloads the cif file. Defaults to False. 123 | bcif (bool, optional): If true, downloads the bcif file. Defaults to False. 124 | pae_image (bool, optional): If true, downloads the PAE image file. Defaults to False. 125 | pae_data (bool, optional): If true, downloads the PAE data file. Defaults to False. 126 | pbar (tqdm, optional): A tqdm progress bar. Defaults to None. 
127 | """ 128 | 129 | links_to_download = [] 130 | metadata_dict = self.metadata_dict[uniprot_access] 131 | 132 | if pdb: 133 | pdb_savedir = os.path.join(self.base_savedir, 'pdb_files') 134 | extension = 'pdb' 135 | links_to_download.append([metadata_dict['pdbUrl'], pdb_savedir, extension]) 136 | if cif: 137 | cif_savedir = os.path.join(self.base_savedir, 'cif_files') 138 | extension = 'cif' 139 | links_to_download.append([metadata_dict['cifUrl'], cif_savedir, extension]) 140 | if bcif: 141 | bcif_savedir = os.path.join(self.base_savedir, 'bcif_files') 142 | extension = 'bcif' 143 | links_to_download.append([metadata_dict['bcifUrl'], bcif_savedir, extension]) 144 | if pae_image: 145 | pae_image_savedir = os.path.join(self.base_savedir, 'pae_image_files') 146 | extension = 'png' 147 | links_to_download.append([metadata_dict['paeImageUrl'], pae_image_savedir, extension]) 148 | if pae_data: 149 | pae_data_savedir = os.path.join(self.base_savedir, 'pae_data_files') 150 | extension = 'json' 151 | links_to_download.append([metadata_dict['paeDocUrl'], pae_data_savedir, extension]) 152 | 153 | if len(links_to_download) == 0: 154 | print('Please select a type of data to download') 155 | return 156 | 157 | for data_type in links_to_download: 158 | data_type_url = data_type[0] 159 | data_type_savedir = data_type[1] 160 | file_extension = data_type[2] 161 | if not os.path.isdir(data_type_savedir): 162 | os.makedirs(data_type_savedir, exist_ok=True) 163 | 164 | response = requests.get(data_type_url) 165 | 166 | if response.status_code == 200: 167 | save_path = os.path.join(data_type_savedir, f"{uniprot_access}.{file_extension}") 168 | 169 | with open(save_path, 'wb') as f: 170 | f.write(response.content) 171 | 172 | else: 173 | print(f"Error with protein {uniprot_access}") 174 | return 175 | 176 | if pbar: 177 | pbar.update(1) 178 | 179 | def download_all_files(self, multithread: bool = False, workers: int = 10, pdb: bool = False, cif: bool = False, 180 | bcif: bool = False, 
pae_image: bool = False, pae_data: bool = False): 181 | """ 182 | Downloads files for all the Uniprot access codes added to the class. 183 | 184 | Args: 185 | multithread (bool, optional): If true, uses multithreading for faster downloading. Defaults to False. 186 | workers (int, optional): Number of threads to use if multithreading. If -1, uses all available CPUs. 187 | Defaults to 10. 188 | pdb (bool, optional): If true, downloads the pdb file. Defaults to False. 189 | cif (bool, optional): If true, downloads the cif file. Defaults to False. 190 | bcif (bool, optional): If true, downloads the bcif file. Defaults to False. 191 | pae_image (bool, optional): If true, downloads the PAE image file. Defaults to False. 192 | pae_data (bool, optional): If true, downloads the PAE data file. Defaults to False. 193 | """ 194 | 195 | # Use all available CPUs if workers is set to -1 196 | if workers == -1: 197 | workers = os.cpu_count() or 1 # Default to 1 if os.cpu_count() returns None 198 | 199 | if len(self.uniprot_access_list) == 0: 200 | print('Please a list of Uniprot access codes with the method add_proteins()') 201 | return 202 | 203 | # This means that fetch_metadata has not been called. 
If it was called but had invalid codes, self.failed_ids 204 | # would not be empty 205 | if len(self.metadata_dict) == 0 and len(self.failed_ids) == 0: 206 | self.fetch_metadata(multithread=multithread, workers=workers) 207 | 208 | # This means that after fetching the metadata, there were no valid uniprot access codes 209 | if len(self.metadata_dict) == 0 and len(self.failed_ids) > 0: 210 | print('No valid Uniprot access codes provided') 211 | return 212 | 213 | valid_uniprots = self.metadata_dict.keys() 214 | with tqdm(total=len(valid_uniprots), desc="Fetching files") as pbar: 215 | if multithread: 216 | with ThreadPoolExecutor(max_workers=workers) as executor: 217 | futures = {executor.submit(self._download_single_protein, uniprot_access, pdb, cif, bcif, pae_image, 218 | pae_data, pbar): uniprot_access for uniprot_access in valid_uniprots} 219 | 220 | # Ensure all futures have completed and handle exceptions 221 | for future in as_completed(futures): 222 | uniprot_access = futures.get(future) 223 | try: 224 | future.result() 225 | except Exception as e: 226 | logging.error(f"Error in thread for {uniprot_access}: {e}") 227 | 228 | else: 229 | for uniprot_access in valid_uniprots: 230 | self._download_single_protein(uniprot_access, pdb, cif, bcif, pae_image, pae_data, pbar) 231 | 232 | 233 | 234 | def get_alphafold_download_link(uniprot_id): 235 | link_pattern = 'https://alphafold.ebi.ac.uk/files/AF-{}-F1-model_v2.pdb' 236 | return link_pattern.format(uniprot_id) 237 | 238 | def download_alphafold_prediction(uniprot_id, path): 239 | url = get_alphafold_download_link(uniprot_id) 240 | result = subprocess.run(['wget', url, '-O', f'/{path}/{uniprot_id}.pdb']) 241 | return result # Result will be 0 if operation was successful 242 | 243 | 244 | if __name__ == "__main__": 245 | parser = argparse.ArgumentParser() 246 | parser.add_argument('--save_dir', type=str, default='./data/afdb') 247 | parser.add_argument('--data_dir', type=str, default='./data') 248 | args = 
parser.parse_args() 249 | 250 | os.makedirs(args.save_dir, exist_ok=True) 251 | 252 | print('loading...') 253 | trn_data = torch.load(os.path.join(args.data_dir, 'positive_train_val_time.pt')) 254 | tst_data = torch.load(os.path.join(args.data_dir, 'positive_test_time.pt')) 255 | 256 | uniprots = list(trn_data.keys()) + list(tst_data.keys()) 257 | 258 | print(f'fetching {len(uniprots)} pdbs from alphafold database...') 259 | 260 | fetcher = AlphaFetcher(base_savedir=args.save_dir) 261 | # Add desired Uniprot access codes 262 | fetcher.add_proteins(uniprots) 263 | print(f'fetching {len(uniprots)} pdbs from alphafold database...') 264 | 265 | # Retrieve metadata 266 | fetcher.fetch_metadata(multithread=True, workers=-1) 267 | # Metadata available at fetcher.metadata_dict 268 | 269 | # Commence download of specified files 270 | fetcher.download_all_files(pdb=True, cif=False, multithread=True, workers=-1) 271 | 272 | 273 | -------------------------------------------------------------------------------- /logger/readme.md: -------------------------------------------------------------------------------- 1 | Folder to save model logging. 
def xavier_uniform_small_init_(tensor, gain=1.):
    # type: (Tensor, float) -> Tensor
    """Fill ``tensor`` in place with a narrowed Xavier-style uniform distribution.

    The standard deviation uses ``fan_in + 4 * fan_out`` in the denominator,
    and the uniform bounds are ``+/- sqrt(3) * std``.

    Returns:
        The same tensor object, after in-place initialisation.
    """
    fans = _calculate_fan_in_and_fan_out(tensor)
    denominator = float(fans[0] + 4 * fans[1])
    sigma = gain * math.sqrt(2.0 / denominator)
    # A uniform distribution on [-b, b] has standard deviation b / sqrt(3).
    bound = math.sqrt(3.0) * sigma

    return _no_grad_uniform_(tensor, -bound, bound)
class Generator(nn.Module):
    """Pool encoded node features along dim 1 and project them to ``n_output`` values.

    Args:
        d_model: Width of the incoming node features.
        aggregation_type: ``'mean'`` (masked mean), ``'sum'`` (masked sum) or
            ``'dummy_node'`` (take the entry at index 0 of dim 1).
        n_output: Width of the final projection.
        n_layers: ``1`` gives a single ``nn.Linear``; any other value builds
            ``n_layers - 1`` hidden blocks (Linear, LeakyReLU, norm, Dropout)
            followed by the output ``nn.Linear``.
        leaky_relu_slope: Negative slope of the hidden LeakyReLU layers.
        dropout: Dropout probability of the hidden blocks.
        scale_norm: Use ``ScaleNorm`` instead of ``LayerNorm`` in hidden blocks.
    """

    def __init__(self, d_model, aggregation_type='mean', n_output=1, n_layers=1,
                 leaky_relu_slope=0.01, dropout=0.0, scale_norm=False):
        super(Generator, self).__init__()
        if n_layers == 1:
            self.proj = nn.Linear(d_model, n_output)
        else:
            layers = []
            for _ in range(n_layers - 1):
                layers.append(nn.Linear(d_model, d_model))
                layers.append(nn.LeakyReLU(leaky_relu_slope))
                layers.append(ScaleNorm(d_model) if scale_norm else LayerNorm(d_model))
                layers.append(nn.Dropout(dropout))
            layers.append(nn.Linear(d_model, n_output))
            self.proj = torch.nn.Sequential(*layers)
        self.aggregation_type = aggregation_type

    def forward(self, x, mask):
        """Aggregate ``x`` under ``mask`` and apply the projection head.

        Raises:
            ValueError: If ``aggregation_type`` is not one of the supported
                names (previously this surfaced as an ``UnboundLocalError``).
        """
        mask = mask.unsqueeze(-1).float()
        out_masked = x * mask  # zero out masked positions before pooling
        if self.aggregation_type == 'mean':
            pooled = out_masked.sum(dim=1) / mask.sum(dim=1)
        elif self.aggregation_type == 'sum':
            pooled = out_masked.sum(dim=1)
        elif self.aggregation_type == 'dummy_node':
            pooled = out_masked[:, 0]
        else:
            raise ValueError(
                "Unknown aggregation_type {!r}; expected 'mean', 'sum' or "
                "'dummy_node'".format(self.aggregation_type))
        return self.proj(pooled)
class LayerNorm(nn.Module):
    """Normalise the last dimension to zero mean and unit std, then scale and shift.

    ``a_2`` (initialised to ones) is the learnable gain and ``b_2``
    (initialised to zeros) the learnable bias; ``eps`` is added to the
    standard deviation to avoid division by zero.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        # Keep the original operation order: gain first, then the division.
        scaled = self.a_2 * (x - mu)
        return scaled / (sigma + self.eps) + self.b_2
178 | """ 179 | def __init__(self, size, dropout, scale_norm, use_adapter): 180 | super(SublayerConnection, self).__init__() 181 | self.norm = ScaleNorm(size) if scale_norm else LayerNorm(size) 182 | self.dropout = nn.Dropout(dropout) 183 | self.use_adapter = use_adapter 184 | self.adapter = Adapter(size, 8) if use_adapter else None 185 | 186 | def forward(self, x, sublayer): 187 | "Apply residual connection to any sublayer with the same size." 188 | if self.use_adapter: 189 | return x + self.dropout(self.adapter(sublayer(self.norm(x)))) 190 | return x + self.dropout(sublayer(self.norm(x))) 191 | 192 | 193 | class EncoderLayer(nn.Module): 194 | "Encoder is made up of self-attn and feed forward (defined below)" 195 | def __init__(self, size, self_attn, feed_forward, dropout, scale_norm, use_adapter): 196 | super(EncoderLayer, self).__init__() 197 | self.self_attn = self_attn 198 | self.feed_forward = feed_forward 199 | self.sublayer = clones(SublayerConnection(size, dropout, scale_norm, use_adapter), 2) 200 | self.size = size 201 | 202 | def forward(self, x, mask, adj_matrix, distances_matrix, edges_att): 203 | "Follow Figure 1 (left) for connections." 
def attention(query, key, value, adj_matrix, distances_matrix, edges_att,
              mask=None, dropout=None,
              lambdas=(0.3, 0.3, 0.4), trainable_lambda=False,
              distance_matrix_kernel=None, use_edge_features=False, control_edges=False,
              eps=1e-6, inf=1e12):
    """Mix scaled dot-product attention with distance and adjacency weights.

    The result is ``(l_att * softmax(QK^T / sqrt(d_k)) + l_dist * distances_matrix
    + l_adj * row-normalised adjacency) @ value``.

    Args:
        query, key, value: Per-head projections; ``query.shape[1]`` is used as
            the head count and ``query.shape[2]`` as the number of positions.
        adj_matrix: 3-D adjacency matrix; it is row-normalised here and
            repeated across heads.
        distances_matrix: Weights added as-is (no kernel is applied here).
        edges_att: Replaces ``adj_matrix`` (after ``view``) when
            ``use_edge_features`` is true.
        mask: Optional mask; positions where it equals 0 get a score of ``-inf``.
        dropout: Optional callable applied to the mixed weights.
        lambdas: Three mixing weights. A 3-element tensor is expected when
            ``trainable_lambda`` is true, otherwise any 3-item sequence.
        trainable_lambda: Whether ``lambdas`` is a tensor/parameter.
        distance_matrix_kernel, control_edges: Accepted for interface
            compatibility; not used by this function.
        eps: Added to adjacency row sums to avoid division by zero.
        inf: Magnitude used to mask out scores.

    Returns:
        Tuple ``(atoms_features, p_weighted, p_attn)``: the attended values,
        the mixed weights (after dropout) and the plain softmax attention.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(
            mask.unsqueeze(1).repeat(1, query.shape[1], query.shape[2], 1) == 0, -inf)
    p_attn = F.softmax(scores, dim=-1)

    if use_edge_features:
        adj_matrix = edges_att.view(adj_matrix.shape)

    # Row-normalise the adjacency matrix and broadcast it over the heads.
    adj_matrix = adj_matrix / (adj_matrix.sum(dim=-1).unsqueeze(2) + eps)
    p_adj = adj_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)

    p_dist = distances_matrix

    if trainable_lambda:
        # Follow the device of the activations instead of forcing CUDA, so the
        # trainable variant also runs on CPU-only hosts.
        lambda_attention, lambda_distance, lambda_adjacency = lambdas.to(p_attn.device)
    else:
        lambda_attention, lambda_distance, lambda_adjacency = lambdas
    p_weighted = lambda_attention * p_attn + lambda_distance * p_dist + lambda_adjacency * p_adj

    if dropout is not None:
        p_weighted = dropout(p_weighted)

    atoms_features = torch.matmul(p_weighted, value)
    return atoms_features, p_weighted, p_attn
- lambda_attention - lambda_distance 278 | self.lambdas = (lambda_attention, lambda_distance, lambda_adjacency) 279 | 280 | self.linears = clones(nn.Linear(d_model, d_model), 4) 281 | self.attn = None 282 | self.dropout = nn.Dropout(p=dropout) 283 | if distance_matrix_kernel == 'softmax': 284 | self.distance_matrix_kernel = lambda x: F.softmax(-x, dim = -1) 285 | elif distance_matrix_kernel == 'exp': 286 | self.distance_matrix_kernel = lambda x: torch.exp(-x) 287 | self.integrated_distances = integrated_distances 288 | self.use_edge_features = use_edge_features 289 | self.control_edges = control_edges 290 | if use_edge_features: 291 | d_edge = 11 if not integrated_distances else 12 292 | self.edges_feature_layer = EdgeFeaturesLayer(d_model, d_edge, h, dropout) 293 | 294 | def forward(self, query, key, value, adj_matrix, distances_matrix, edges_att, mask=None): 295 | "Implements Figure 2" 296 | if mask is not None: 297 | # Same mask applied to all h heads. 298 | mask = mask.unsqueeze(1) 299 | nbatches = query.size(0) 300 | 301 | # 1) Do all the linear projections in batch from d_model => h x d_k 302 | query, key, value = \ 303 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 304 | for l, x in zip(self.linears, (query, key, value))] 305 | 306 | # Prepare distances matrix 307 | distances_matrix = distances_matrix.masked_fill(mask.repeat(1, mask.shape[-1], 1) == 0, np.inf) 308 | distances_matrix = self.distance_matrix_kernel(distances_matrix) 309 | p_dist = distances_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1) 310 | 311 | if self.use_edge_features: 312 | if self.integrated_distances: 313 | edges_att = torch.cat((edges_att, distances_matrix.unsqueeze(1)), dim=1) 314 | edges_att = self.edges_feature_layer(edges_att) 315 | 316 | # 2) Apply attention on all the projected vectors in batch. 
class PositionwiseFeedForward(nn.Module):
    """Stack of ``N_dense`` square linear layers, each followed by dropout.

    All layers but the last use a leaky ReLU with ``leaky_relu_slope``. The
    last layer uses the activation selected by ``dense_output_nonlinearity``:
    ``'relu'`` (leaky ReLU with the same slope), ``'tanh'`` or ``'none'``.
    With ``N_dense == 0`` the module is the identity.
    """

    def __init__(self, d_model, N_dense, dropout=0.1, leaky_relu_slope=0.0, dense_output_nonlinearity='relu'):
        super().__init__()
        self.N_dense = N_dense
        self.linears = clones(nn.Linear(d_model, d_model), N_dense)
        self.dropout = clones(nn.Dropout(dropout), N_dense)
        self.leaky_relu_slope = leaky_relu_slope

        kind = dense_output_nonlinearity
        if kind == 'none':
            self.dense_output_nonlinearity = lambda x: x
        elif kind == 'tanh':
            self.tanh = torch.nn.Tanh()
            self.dense_output_nonlinearity = lambda x: self.tanh(x)
        elif kind == 'relu':
            # The slope is read at call time, so later changes to the attribute apply.
            self.dense_output_nonlinearity = lambda x: F.leaky_relu(x, negative_slope=self.leaky_relu_slope)

    def forward(self, x):
        if self.N_dense == 0:
            return x

        stages = list(zip(self.linears, self.dropout))
        for dense, drop in stages[:-1]:
            hidden = F.leaky_relu(dense(x), negative_slope=self.leaky_relu_slope)
            x = drop(hidden)

        last_dense, last_drop = stages[-1]
        return last_drop(self.dense_output_nonlinearity(last_dense(x)))
def smiles_to_mol(smiles):
    """Turn a SMILES string into the feature matrices returned by ``featurize_mol``.

    The molecule is parsed (retrying without sanitisation if the first parse
    returns ``None``), then given 3D coordinates via embedding plus UFF
    optimisation. If that fails, 2D coordinates are computed instead.

    Args:
        smiles: SMILES string of a single molecule.

    Returns:
        Tuple ``(afm, adj, dist)`` from ``featurize_mol`` with a dummy node
        added and one-hot formal charges.
    """
    # This file has no module-level ``import logging``; without this local
    # import the handler below would itself fail with a NameError.
    import logging

    try:
        mol = Chem.MolFromSmiles(smiles)

        if mol is None:
            mol = Chem.MolFromSmiles(smiles, sanitize=False)

        try:
            mol = Chem.AddHs(mol)
            AllChem.EmbedMolecule(mol, maxAttempts=1000)
            AllChem.UFFOptimizeMolecule(mol)
            mol = Chem.RemoveHs(mol)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so Ctrl-C during the
            # (potentially long) embedding is not swallowed.
            # NOTE(review): if the failure happens after AddHs, ``mol`` still
            # carries explicit hydrogens when the 2D fallback runs - confirm
            # this is intended.
            AllChem.Compute2DCoords(mol)
    except ValueError as e:
        logging.warning('the SMILES ({}) can not be converted to a graph.\nREASON: {}'.format(smiles, e))

    afm, adj, dist = featurize_mol(mol, add_dummy_node=True, one_hot_formal_charge=True)
    return afm, adj, dist
def one_hot_vector(val, lst):
    """Return a lazy one-hot encoding of ``val`` over the choices in ``lst``.

    A value that is not in ``lst`` is encoded as the last choice. The result
    is a ``map`` iterator of booleans, one per entry of ``lst``.
    """
    target = val if val in lst else lst[-1]
    return map(lambda choice: choice == target, lst)
def find_similar_sequences(target_sequence, sequence_list, n_sample=5):
    """Return up to ``n_sample`` entries of ``sequence_list`` closest to the target.

    Closeness is the edit distance computed by ``distance``. Entries equal to
    ``target_sequence`` are excluded explicitly. The previous positional
    ``[1:n_sample + 1]`` slice assumed the target was always the first sorted
    hit, so it dropped the true nearest neighbour whenever the target was not
    in the list and kept duplicates of the target when it occurred more than
    once.

    Args:
        target_sequence: The query string.
        sequence_list: Candidate strings to rank.
        n_sample: Maximum number of neighbours to return.

    Returns:
        List of candidate strings ordered by increasing distance; ties keep
        their order from ``sequence_list``.
    """
    # Score every candidate except the query itself.
    distances = [(seq, distance(target_sequence, seq))
                 for seq in sequence_list if seq != target_sequence]

    # ``sorted`` is stable, so equally distant candidates keep input order.
    sorted_sequences = sorted(distances, key=lambda x: x[1])
    return [seq for seq, _ in sorted_sequences[:n_sample]]
43 | torch.save(negative_mol_dict, 'data/negative_mol_dict.pt') 44 | 45 | 46 | 47 | # negative_seq_dict = {} 48 | # for seq in tqdm(unique_seqs): 49 | # negative_seq = find_similar_sequences(seq, unique_seqs, n_sample=100) 50 | # negative_seq_dict[seq] = negative_seq 51 | 52 | # torch.save(negative_seq_dict, 'data/negative_seq_dict.pt') 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /pretrained/SaProt_650M_PDB/readme.md: -------------------------------------------------------------------------------- 1 | Folder to save SaProt pretrained weights from https://huggingface.co/westlake-repl/SaProt_650M_PDB. 2 | 3 | Put (1) SaProt_650M_PDB.pt; (2) config.json; (3) tokenizer_config.json; (4) special_tokens_map.json; (5) pytorch_model.bin; (6) vocab.txt; (7) gitattributes. 4 | -------------------------------------------------------------------------------- /pretrained/readme.md: -------------------------------------------------------------------------------- 1 | Folder to save pretrained weights for ESM, SaProt, MAT. Uni-Mol is processed differently following https://github.com/deepmodeling/Uni-Mol. 
if __name__ == "__main__":
    # Embed every unique protein sequence of the train/val and test splits with
    # ESM-2 and save a {sequence: mean-pooled embedding} dict.
    home_dict = './pretrained/'
    model_name = 'esm2_t33_650M_UR50D.pt'
    model, alphabet = esm.pretrained.load_model_and_alphabet(home_dict + model_name)
    #model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
    model = model.to(device)
    model.eval()

    # Index of the representation layer to extract for each supported
    # checkpoint. The layer used to be hard-coded to 33 further down, which
    # breaks as soon as ``model_name`` is switched to the 30-layer model.
    final_layers = {
        'esm2_t33_650M_UR50D.pt': 33,
        'esm2_t30_150M_UR50D.pt': 30,
    }
    if model_name not in final_layers:
        raise ValueError(f'No representation layer configured for {model_name}')
    layer = final_layers[model_name]

    uni_seq_embedding_dict = {}
    pos_trn_mols, pos_trn_seqs, _, _ = get_samples('data/new_time/positive_train_val_time.pt', 'data/new_time/negative_train_val_time.pt')
    pos_tst_mols, pos_tst_seqs, _, _ = get_samples('data/new_time/positive_test_time.pt', 'data/new_time/negative_test_time.pt')
    unique_seqs = list(set(pos_trn_seqs + pos_tst_seqs))

    with torch.no_grad():
        for seq in tqdm(unique_seqs):
            # NOTE(review): sequences longer than 5000 are truncated and the
            # truncated string becomes the dict key - confirm consumers
            # truncate the same way before looking embeddings up.
            if len(seq) > 5000:
                seq = seq[:5000]

            toks = torch.tensor(alphabet.encode(seq)).view(1, -1).to(device)
            out = model(toks, repr_layers=[layer], return_contacts=False)
            # Mean over the token axis gives one vector per sequence.
            uni_seq_embedding_dict[seq] = out['representations'][layer].squeeze().mean(0).detach()

            torch.cuda.empty_cache()

    torch.save(uni_seq_embedding_dict, 'data/seq_embedding.pt')
def get_pretrained_mat(model_path):
    """Build a MAT model from ``model_params`` and copy pretrained weights into it.

    Every entry of the checkpoint whose name contains ``'generator'`` is
    skipped; all other tensors are copied in place into the freshly built
    model's state dict.

    Args:
        model_path: Path to the checkpoint holding a state dict.

    Returns:
        The model with the pretrained weights loaded. It is not moved to
        another device here; callers apply ``.to(device)`` themselves.

    Raises:
        KeyError: If the checkpoint contains a non-generator entry the model
            does not have.
    """
    model = make_model(**model_params)
    # map_location='cpu' lets a checkpoint saved from a GPU load on CPU-only
    # hosts; ``copy_`` below handles the transfer into the model either way.
    pretrained_state_dict = torch.load(model_path, map_location='cpu')
    model_state_dict = model.state_dict()

    for name, param in pretrained_state_dict.items():
        if 'generator' in name:
            continue
        if isinstance(param, torch.nn.Parameter):
            param = param.data
        # state_dict tensors share storage with the model's parameters, so the
        # in-place copy updates the model itself.
        model_state_dict[name].copy_(param)

    return model
def featurize_mol(mol, add_dummy_node, one_hot_formal_charge):
    """Build node-feature, adjacency and distance matrices for a molecule.

    Args:
        mol: Molecule object exposing ``GetAtoms``, ``GetBonds``,
            ``GetNumAtoms`` and ``GetConformer``.
        add_dummy_node: If true, prepend one extra row and column to every
            matrix (plus one extra feature column whose only non-zero entry
            is the dummy node's own flag).
        one_hot_formal_charge: Forwarded to ``get_atom_features``.

    Returns:
        Tuple ``(node_features, adj_matrix, dist_matrix)``.
    """
    def _prepend_row_and_column(matrix, fill):
        # Grow ``matrix`` by one leading row and column holding ``fill``.
        grown = np.full((matrix.shape[0] + 1, matrix.shape[1] + 1), fill)
        grown[1:, 1:] = matrix
        return grown

    atom_rows = [get_atom_features(atom, one_hot_formal_charge) for atom in mol.GetAtoms()]
    node_features = np.array(atom_rows)

    # Self-loops on the diagonal, symmetric entries for every bond.
    adj_matrix = np.eye(mol.GetNumAtoms())
    for bond in mol.GetBonds():
        first = bond.GetBeginAtom().GetIdx()
        second = bond.GetEndAtom().GetIdx()
        adj_matrix[first, second] = 1
        adj_matrix[second, first] = 1

    conformer = mol.GetConformer()
    coordinates = []
    for atom_index in range(mol.GetNumAtoms()):
        position = conformer.GetAtomPosition(atom_index)
        coordinates.append([position.x, position.y, position.z])
    pos_matrix = np.array(coordinates)
    dist_matrix = pairwise_distances(pos_matrix)

    if add_dummy_node:
        node_features = _prepend_row_and_column(node_features, 0.0)
        node_features[0, 0] = 1.

        # The dummy node has no bonds and sits at distance 1e6 from every atom.
        adj_matrix = _prepend_row_and_column(adj_matrix, 0.0)
        dist_matrix = _prepend_row_and_column(dist_matrix, 1e6)

    return node_features, adj_matrix, dist_matrix
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_path: str = None,
                  plddt_threshold: float = 70.) -> dict:
    """Run foldseek on a structure file and return per-chain sequence tuples.

    Args:
        foldseek: Binary executable file of foldseek
        path: Path to pdb file
        chains: Chains to be extracted from pdb file. If None, all chains will be extracted.
        process_id: Process ID for temporary files. This is used for parallel processing.
        plddt_path: Path to plddt file. If None, plddt will not be used.
        plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.

    Returns:
        seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
        (seq, struc_seq, combined_seq).

    Raises:
        AssertionError: If ``foldseek``, ``path`` or a given ``plddt_path`` does not exist.
    """
    import subprocess  # not imported at module level in this file

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"Pdb file not found: {path}"
    assert plddt_path is None or os.path.exists(plddt_path), f"Plddt file not found: {plddt_path}"

    tmp_save_path = f"get_struc_seq_{process_id}.tsv"

    # Argument list instead of a shell string: paths containing spaces or shell
    # metacharacters are passed through verbatim. As before, a non-zero exit
    # status is not checked here; a failed run surfaces when the output file
    # cannot be opened below.
    cmd = [foldseek, "structureto3didescriptor", "-v", "0", "--threads", "1",
           "--chain-name-mode", "1", path, tmp_save_path]
    subprocess.run(cmd)

    seq_dict = {}
    name = os.path.basename(path)
    try:
        # Read the pLDDT scores once instead of once per output line.
        low_plddt_indices = None
        if plddt_path is not None:
            with open(plddt_path, "r") as plddt_file:
                plddts = np.array(json.load(plddt_file)["confidenceScore"])
            low_plddt_indices = np.where(plddts < plddt_threshold)[0]

        with open(tmp_save_path, "r") as r:
            for line in r:
                desc, seq, struc_seq = line.split("\t")[:3]

                # Mask regions with plddt < threshold
                if low_plddt_indices is not None:
                    np_seq = np.array(list(struc_seq))
                    np_seq[low_plddt_indices] = "#"
                    struc_seq = "".join(np_seq)

                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                if chains is None or chain in chains:
                    # Keep only the first record seen for each chain.
                    if chain not in seq_dict:
                        combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
                        seq_dict[chain] = (seq, struc_seq, combined_seq)
    finally:
        # Remove the temporary outputs even when parsing fails.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(leftover):
                os.remove(leftover)

    return seq_dict
class CrossAttention(nn.Module):
    """Single-head scaled dot-product attention between two inputs.

    Queries are projected from one input and keys/values from another, each
    through its own linear layer to ``output_dim`` features.

    Args:
        query_input_dim: Feature size of the query input.
        key_input_dim: Feature size of the key and value inputs.
        output_dim: Feature size of the projected query/key/value and of the output.
    """

    def __init__(self, query_input_dim, key_input_dim, output_dim):
        super().__init__()

        self.out_dim = output_dim
        # Attribute names are kept as-is so existing checkpoints still load.
        self.W_Q = nn.Linear(query_input_dim, output_dim)
        self.W_K = nn.Linear(key_input_dim, output_dim)
        self.W_V = nn.Linear(key_input_dim, output_dim)
        self.scale_val = self.out_dim ** 0.5
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query_input, key_input, value_input, query_input_mask=None, key_input_mask=None):
        """Attend from ``query_input`` over ``key_input`` / ``value_input``.

        Args:
            query_input: Tensor of shape (..., Lq, query_input_dim).
            key_input: Tensor of shape (..., Lk, key_input_dim).
            value_input: Tensor of shape (..., Lk, key_input_dim).
            query_input_mask: Accepted for interface compatibility; not used.
            key_input_mask: Accepted for interface compatibility; not used.

        Returns:
            Tensor of shape (..., Lq, output_dim).
        """
        query = self.W_Q(query_input)
        key = self.W_K(key_input)
        value = self.W_V(value_input)

        # Transpose the last two axes (not axes 1 and 2) so unbatched 2-D inputs
        # and inputs with extra leading dims work as well as (B, L, D).
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) / self.scale_val
        attn_weights = self.softmax(attn_weights)
        output = torch.matmul(attn_weights, value)

        return output
def parse_arguments(argv=None):
    """Parse the command-line options of the retrieval evaluation script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--model_path', type=str, default='model/time/esm_mat_epoch18', help='checkpoint')
    parser.add_argument('--seq_len', type=int, default=5000, help='maximum length')
    parser.add_argument('--topk', type=int, default=1, help='topk')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--n_worker', type=int, default=0, help='number of workers')
    parser.add_argument('--split_type', type=str, default='time')
    parser.add_argument('--hidden', type=int, default=128, help='length of hidden vector')
    # The help strings of the next two options were copy-paste leftovers.
    parser.add_argument('--mol_input_dim', type=int, default=512, help='dimension of the molecule embedding')
    parser.add_argument('--dropout', type=float, default=0., help='dropout probability')
    parser.add_argument('--mol_embedding_type', type=str, default='mat')
    parser.add_argument('--pro_embedding_type', type=str, default='esm')
    return parser.parse_args(argv)
def format_metrics(tag, metrics):
    """Format the 18 values returned by ``topk_accuracy`` as one report line.

    Args:
        tag: Prefix of the line (e.g. 'Pred', 'Data Transpose').
        metrics: Sequence of 18 values: eight 'Acc-N' scores, eight 'Acc'
            scores (both at cutoffs 1, 2, 3, 4, 5, 10, 20, 50), then mean
            rank and MRR.

    Returns:
        The formatted line, every value printed with four decimals.
    """
    cutoffs = (1, 2, 3, 4, 5, 10, 20, 50)
    n = len(cutoffs)
    parts = [f'Top{c} Acc-N: {v:.4f}' for c, v in zip(cutoffs, metrics[:n])]
    parts += [f'Top{c} Acc: {v:.4f}' for c, v in zip(cutoffs, metrics[n:2 * n])]
    parts.append(f'Mean Rank: {metrics[2 * n]:.4f}')
    parts.append(f'MRR: {metrics[2 * n + 1]:.4f}')
    return f'{tag} ' + ', '.join(parts)


if __name__ == '__main__':
    print('loading data...')
    pos_tst_mols, pos_tst_seqs, _, _ = get_samples(
        f'data/new_{args.split_type}/positive_test_{args.split_type}.pt',
        f'data/new_{args.split_type}/negative_test_{args.split_type}.pt',
    )

    unique_mols = list(set(pos_tst_mols))
    unique_seqs = list(set(pos_tst_seqs))

    # Position lookups via dicts: list.index() per pair would make building
    # the label matrix quadratic in the number of unique items.
    mol_index = {mol: i for i, mol in enumerate(unique_mols)}
    seq_index = {seq: i for i, seq in enumerate(unique_seqs)}

    # labels[s, m] == 1 when (sequence s, molecule m) is a positive test pair.
    labels = torch.zeros(len(unique_seqs), len(unique_mols))
    for seq, mol in zip(pos_tst_seqs, pos_tst_mols):
        labels[seq_index[seq], mol_index[mol]] = 1
    labels = labels.to(args.device)

    mol_embedding = torch.load(f'data/embedding/{args.mol_embedding_type}_mol_embedding.pt')
    seq_embedding = torch.load(f'data/embedding/{args.pro_embedding_type}_seq_embedding.pt')

    print('loading data...')
    pos_seqs = EnzymeDatasetPretrainedSingle(unique_seqs, seq_embedding, positive_sample=True, max_len=args.seq_len)
    pos_mols = EnzymeDatasetPretrainedSingle(unique_mols, mol_embedding, positive_sample=True, max_len=100000)
    tst_pos_pair_loader = DataLoader(pos_seqs, batch_size=1, shuffle=False, num_workers=args.n_worker, collate_fn=collate_fn_pretrained_single)
    tst_molecules_loader = DataLoader(pos_mols, batch_size=args.batch_size, shuffle=False, num_workers=args.n_worker, collate_fn=collate_fn_pretrained_single)

    preds, labels = test(tst_pos_pair_loader, tst_molecules_loader, labels, k=args.topk)

    preds_cpu = preds.detach().cpu()
    labels_cpu = labels.detach().cpu()

    # Same four evaluations, in the same order, as the original script:
    # predictions vs. labels and labels vs. themselves, each also transposed
    # so that molecules become the rows.
    evaluations = (
        ('Pred', preds_cpu, labels_cpu),
        ('Data', labels_cpu, labels_cpu),
        ('Pred Transpose', preds_cpu.transpose(0, 1), labels_cpu.transpose(0, 1)),
        ('Data Transpose', labels_cpu.transpose(0, 1), labels_cpu.transpose(0, 1)),
    )
    for tag, scores, targets in evaluations:
        print(format_metrics(tag, topk_accuracy(scores, targets)))
class PretrainedNetwork(nn.Module):
    """Scores a (molecule, sequence) pair from pretrained embeddings (GRU variant).

    Both embeddings go through their own MLP into a shared ``hidden_dim``
    space, are exchanged through two cross-attention modules, run through a
    bidirectional GRU as a two-step sequence and reduced to one logit per pair.
    Layer order inside each ``nn.Sequential`` is fixed so the parameter names
    in saved checkpoints keep matching.
    """

    def __init__(self, mol_input_dim, seq_input_dim, hidden_dim, output_dim, n_layers=1, dropout=0.0):
        super(PretrainedNetwork, self).__init__()
        self.hidden_dim = hidden_dim

        def embed_stack(widths):
            # Linear -> Dropout -> BatchNorm -> SiLU for every width pair but
            # the last, which is a bare bias-free Linear.
            layers = []
            for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
                layers += [
                    nn.Linear(fan_in, fan_out, bias=False),
                    nn.Dropout(dropout),
                    nn.BatchNorm1d(fan_out),
                    nn.SiLU(),
                ]
            layers.append(nn.Linear(widths[-2], widths[-1], bias=False))
            return nn.Sequential(*layers)

        self.lin_mol_embed = embed_stack([mol_input_dim, 256, 256, hidden_dim])
        self.lin_seq_embed = embed_stack([seq_input_dim, 512, 256, hidden_dim])

        self.gru = nn.GRU(hidden_dim, hidden_dim, n_layers, dropout=dropout, bidirectional=True, batch_first=True)

        head = []
        for fan_in, fan_out in ((hidden_dim, hidden_dim), (hidden_dim, output_dim)):
            head += [
                nn.Linear(fan_in, fan_out, bias=False),
                nn.Dropout(dropout),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
            ]
        head += [nn.Linear(output_dim, 16, bias=False), nn.Linear(16, 1, bias=False)]
        self.lin_out = nn.Sequential(*head)

        self.cross_attn_seq = CrossAttention(
            query_input_dim=hidden_dim, key_input_dim=hidden_dim, output_dim=hidden_dim)
        self.cross_attn_mol = CrossAttention(
            query_input_dim=hidden_dim, key_input_dim=hidden_dim, output_dim=hidden_dim)

    def forward(self, mol_src, seq_src, hidden=None):
        """Return one logit per pair.

        Args:
            mol_src: Molecule embeddings, (B, mol_input_dim).
            seq_src: Sequence embeddings, (B, seq_input_dim).
            hidden: Optional initial GRU state, passed to ``nn.GRU`` unchanged.

        Returns:
            Tensor of shape (B, 1).
        """
        n_pairs = mol_src.size(0)

        # Each embedding becomes a length-1 "sequence": (B, 1, H).
        mol_tok = self.lin_mol_embed(mol_src).reshape(n_pairs, 1, -1)
        seq_tok = self.lin_seq_embed(seq_src).reshape(n_pairs, 1, -1)

        # With a single key the softmax weight is 1, so each call returns the
        # value projection of the other modality.
        mol_ctx = self.cross_attn_mol(mol_tok, seq_tok, seq_tok)  # (B, 1, H)
        seq_ctx = self.cross_attn_seq(seq_tok, mol_tok, mol_tok)  # (B, 1, H)

        states, _ = self.gru(torch.cat([mol_ctx, seq_ctx], dim=1), hidden)  # (B, 2, 2H)

        # Sum the forward and backward halves of the bidirectional output.
        forward_half = states[..., :self.hidden_dim]
        backward_half = states[..., self.hidden_dim:]
        merged = forward_half + backward_half  # (B, 2, H)

        return self.lin_out(merged.sum(1))  # (B, 1)
def topk_accuracy(logits, labels, k=1):
    """Compute ranking metrics for a score matrix against binary labels.

    Every row is ranked independently (highest score = rank 1). Rows whose
    scores are all exactly zero get a random ranking instead of the
    index-order ranking a stable sort would give them.

    Args:
        logits: (N, M) score matrix.
        labels: (N, M) matrix, non-zero where the column is a positive for the row.
        k: Unused; kept for backward compatibility. The cutoffs are fixed to
            1, 2, 3, 4, 5, 10, 20 and 50.

    Returns:
        Tuple of 18 scalar tensors: eight "Acc-N" values (positives within the
        top-c divided by c, averaged over rows), eight "Acc" values (fraction
        of rows with at least one positive within the top-c), then the mean
        rank and the mean reciprocal rank of the positives. The last two are
        averaged over rows that have at least one positive, so a row without
        positives no longer turns them into NaN.
    """
    n_rows, n_cols = logits.shape
    device = logits.device

    asrt = torch.argsort(logits, dim=1, descending=True, stable=True)
    all_zero = (logits == 0).all(dim=-1)
    if all_zero.any():
        # One permutation is drawn per row (not only for the all-zero rows) so
        # that the random stream is consumed exactly as before.
        rand_perm = torch.stack([torch.randperm(n_cols) for _ in range(n_rows)])
        rows = torch.where(all_zero)[0]
        asrt[rows] = rand_perm[rows.cpu()].to(device)

    # Invert the sort order: ranking[i, j] is the 1-based rank of column j in row i.
    # Helper tensors live on the logits' device so non-CPU inputs work too.
    positions = torch.arange(n_cols, device=device).repeat(n_rows, 1)
    ranking = torch.empty(n_rows, n_cols, dtype=torch.long, device=device).scatter_(1, asrt, positions)
    ranking = (ranking + 1).to(labels.device)

    positives = labels.float()
    n_pos = positives.sum(dim=-1)
    has_pos = n_pos > 0  # rows without positives would give 0/0

    mean_rank = ((ranking * positives).sum(dim=-1)[has_pos] / n_pos[has_pos]).mean(dim=0)
    mrr = ((1.0 / ranking * positives).sum(dim=-1)[has_pos] / n_pos[has_pos]).mean(dim=0)

    top_accs = []
    top_accs2 = []
    # The loop variable has its own name; it used to overwrite the ``k`` argument.
    for cutoff in (1, 2, 3, 4, 5, 10, 20, 50):
        hits = ((ranking <= cutoff) * positives).sum(dim=-1)
        top_accs.append((hits / cutoff).mean(dim=0))
        top_accs2.append((hits > 0).float().mean(dim=0))

    return (*top_accs, *top_accs2, mean_rank, mrr)
@torch.no_grad()
def test(pos_pair_loader, mol_loader, labels, k=1):
    """Score every sequence against every molecule with the global ``model``.

    Args:
        pos_pair_loader: Loader yielding ``(seqs, _)`` batches of sequence
            embeddings; the caller builds it with batch_size=1.
        mol_loader: Loader yielding ``(mols, _)`` batches of molecule
            embeddings, at most ``args.batch_size`` rows per batch.
        labels: Ground-truth matrix; returned unchanged.
        k: Unused; kept for backward compatibility.

    Returns:
        (preds, labels) where preds has one row per sequence batch and one
        column per molecule, on ``args.device``.
    """
    model.eval()
    # Gradients are already disabled by the decorator for the duration of the
    # call; torch.set_grad_enabled(False) would leave them off globally.

    preds = []
    # Iterate the loaders that were passed in, not the module-level globals.
    for (seqs, _) in tqdm(pos_pair_loader):
        logits = []
        # Tile the sequence embedding so it can be paired with a full molecule
        # batch, and move it to the model's device (it used to stay where the
        # loader produced it).
        seqs = seqs.repeat(args.batch_size, 1).to(args.device)
        for (mols, _) in mol_loader:
            b = mols.size(0)  # the last molecule batch may be smaller
            mols = mols.to(args.device)
            out = model(mols, seqs[:b, :])
            logits.append(out)

        logits = torch.concat(logits, dim=0).view(1, -1)
        preds.append(logits)
    preds = torch.cat(preds, dim=0).to(args.device)

    return preds, labels
num_workers=args.n_worker, collate_fn=collate_fn_pretrained_single) 228 | tst_molecules_loader = DataLoader(pos_mols, batch_size=args.batch_size, shuffle=False, num_workers=args.n_worker, collate_fn=collate_fn_pretrained_single) 229 | 230 | 231 | # unique_mols = list(set(list(pos_trn_mols + pos_tst_mols))) 232 | # labels = torch.tensor([unique_mols.index(mol) for mol in pos_tst_mols]).to(args.device) 233 | 234 | # pos_tst = EnzymeDataset(pos_tst_mols, pos_tst_seqs, mol_tokenizer, seq_tokenizer, positive_sample=True, max_len=args.seq_len) 235 | # all_molecules = EnzymeDataset(unique_mols, unique_mols, mol_tokenizer, mol_tokenizer, positive_sample=True, max_len=args.seq_len) 236 | 237 | 238 | # tst_pos_pair_loader = DataLoader(pos_tst, batch_size=1, shuffle=False, num_workers=args.n_worker, collate_fn=collate_fn) 239 | # tst_molecules_loader = DataLoader(all_molecules, batch_size=1, shuffle=False, num_workers=args.n_worker, collate_fn=collate_fn) 240 | 241 | preds, labels = test(tst_pos_pair_loader, tst_molecules_loader, labels, k=args.topk) 242 | 243 | 244 | top1_acc, top2_acc, top3_acc, top4_acc, top5_acc, top10_acc, top20_acc, top50_acc, top1_acc2, top2_acc2, top3_acc2, top4_acc2, top5_acc2, top10_acc2, top20_acc2, top50_acc2, mean_rank, mrr = topk_accuracy(preds.detach().cpu(), labels.detach().cpu()) 245 | print(f'Pred Top1 Acc-N: {top1_acc:.4f}, Top2 Acc-N: {top2_acc:.4f}, Top3 Acc-N: {top3_acc:.4f}, Top4 Acc-N: {top4_acc:.4f}, Top5 Acc-N: {top5_acc:.4f}, Top10 Acc-N: {top10_acc:.4f}, Top20 Acc-N: {top20_acc:.4f}, Top50 Acc-N: {top50_acc:.4f}, Top1 Acc: {top1_acc2:.4f}, Top2 Acc: {top2_acc2:.4f}, Top3 Acc: {top3_acc2:.4f}, Top4 Acc: {top4_acc2:.4f}, Top5 Acc: {top5_acc2:.4f}, Top10 Acc: {top10_acc2:.4f}, Top20 Acc: {top20_acc2:.4f}, Top50 Acc: {top50_acc2:.4f}, Mean Rank: {mean_rank:.4f}, MRR: {mrr:.4f}') 246 | 247 | top1_acc, top2_acc, top3_acc, top4_acc, top5_acc, top10_acc, top20_acc, top50_acc, top1_acc2, top2_acc2, top3_acc2, top4_acc2, top5_acc2, 
top10_acc2, top20_acc2, top50_acc2, mean_rank, mrr = topk_accuracy(labels.detach().cpu(), labels.detach().cpu()) 248 | print(f'Data Top1 Acc-N: {top1_acc:.4f}, Top2 Acc-N: {top2_acc:.4f}, Top3 Acc-N: {top3_acc:.4f}, Top4 Acc-N: {top4_acc:.4f}, Top5 Acc-N: {top5_acc:.4f}, Top10 Acc-N: {top10_acc:.4f}, Top20 Acc-N: {top20_acc:.4f}, Top50 Acc-N: {top50_acc:.4f}, Top1 Acc: {top1_acc2:.4f}, Top2 Acc: {top2_acc2:.4f}, Top3 Acc: {top3_acc2:.4f}, Top4 Acc: {top4_acc2:.4f}, Top5 Acc: {top5_acc2:.4f}, Top10 Acc: {top10_acc2:.4f}, Top20 Acc: {top20_acc2:.4f}, Top50 Acc: {top50_acc2:.4f}, Mean Rank: {mean_rank:.4f}, MRR: {mrr:.4f}') 249 | 250 | top1_acc, top2_acc, top3_acc, top4_acc, top5_acc, top10_acc, top20_acc, top50_acc, top1_acc2, top2_acc2, top3_acc2, top4_acc2, top5_acc2, top10_acc2, top20_acc2, top50_acc2, mean_rank, mrr = topk_accuracy(preds.transpose(0,1).detach().cpu(), labels.transpose(0,1).detach().cpu()) 251 | print(f'Pred Transpose Top1 Acc-N: {top1_acc:.4f}, Top2 Acc-N: {top2_acc:.4f}, Top3 Acc-N: {top3_acc:.4f}, Top4 Acc-N: {top4_acc:.4f}, Top5 Acc-N: {top5_acc:.4f}, Top10 Acc-N: {top10_acc:.4f}, Top20 Acc-N: {top20_acc:.4f}, Top50 Acc-N: {top50_acc:.4f}, Top1 Acc: {top1_acc2:.4f}, Top2 Acc: {top2_acc2:.4f}, Top3 Acc: {top3_acc2:.4f}, Top4 Acc: {top4_acc2:.4f}, Top5 Acc: {top5_acc2:.4f}, Top10 Acc: {top10_acc2:.4f}, Top20 Acc: {top20_acc2:.4f}, Top50 Acc: {top50_acc2:.4f}, Mean Rank: {mean_rank:.4f}, MRR: {mrr:.4f}') 252 | 253 | top1_acc, top2_acc, top3_acc, top4_acc, top5_acc, top10_acc, top20_acc, top50_acc, top1_acc2, top2_acc2, top3_acc2, top4_acc2, top5_acc2, top10_acc2, top20_acc2, top50_acc2, mean_rank, mrr = topk_accuracy(labels.transpose(0,1).detach().cpu(), labels.transpose(0,1).detach().cpu()) 254 | print(f'Data Transpose Top1 Acc-N: {top1_acc:.4f}, Top2 Acc-N: {top2_acc:.4f}, Top3 Acc-N: {top3_acc:.4f}, Top4 Acc-N: {top4_acc:.4f}, Top5 Acc-N: {top5_acc:.4f}, Top10 Acc-N: {top10_acc:.4f}, Top20 Acc-N: {top20_acc:.4f}, Top50 Acc-N: {top50_acc:.4f}, 
class CrossAttention(nn.Module):
    """Single-head scaled dot-product attention between two inputs.

    Args:
        query_input_dim: Feature size of the query input.
        key_input_dim: Feature size of the key and value inputs.
        output_dim: Feature size of the projections and of the output.
    """

    def __init__(self, query_input_dim, key_input_dim, output_dim):
        super().__init__()

        self.out_dim = output_dim
        # Attribute names are kept as-is so existing checkpoints still load.
        self.W_Q = nn.Linear(query_input_dim, output_dim)
        self.W_K = nn.Linear(key_input_dim, output_dim)
        self.W_V = nn.Linear(key_input_dim, output_dim)
        self.scale_val = self.out_dim ** 0.5
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query_input, key_input, value_input, query_input_mask=None, key_input_mask=None):
        """Return attention output of shape (..., Lq, output_dim).

        ``query_input_mask`` and ``key_input_mask`` are accepted for interface
        compatibility but are not used.
        """
        query = self.W_Q(query_input)
        key = self.W_K(key_input)
        value = self.W_V(value_input)

        # Transpose the last two axes so inputs with any number of leading
        # batch dims (including none) are supported, not only (B, L, D).
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) / self.scale_val
        attn_weights = self.softmax(attn_weights)
        output = torch.matmul(attn_weights, value)

        return output


class PretrainedNetwork(nn.Module):
    """Scores a (molecule, sequence) pair from pretrained embeddings (Transformer variant).

    Both embeddings are projected into a shared ``hidden_dim`` space, exchanged
    through two cross-attention modules, passed through an ``nn.Transformer``
    as a two-token sequence and reduced to one logit per pair.

    Args:
        mol_input_dim: Size of the molecule embedding.
        seq_input_dim: Size of the sequence embedding.
        hidden_dim: Shared hidden size; must be divisible by the 8 attention heads.
        output_dim: Width of the second layer of the output head.
        n_layers: Number of encoder layers and of decoder layers.
        dropout: Dropout probability of the MLPs (the transformer keeps its own default).
        max_len: Unused; kept for backward compatibility.
    """

    def __init__(self, mol_input_dim=1024, seq_input_dim=1280, hidden_dim=128, output_dim=64, n_layers=4, dropout=0.0, max_len=2048):
        super().__init__()
        self.hidden_dim = hidden_dim

        # NOTE: layer order inside each Sequential defines the parameter names
        # stored in checkpoints; do not reorder.
        self.lin_mol_embed = nn.Sequential(
            nn.Linear(mol_input_dim, 256, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, 256, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, hidden_dim, bias=False),
        )

        self.lin_seq_embed = nn.Sequential(
            nn.Linear(seq_input_dim, 512, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(512),
            nn.SiLU(),
            nn.Linear(512, 256, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, hidden_dim, bias=False),
        )

        self.transformer = nn.Transformer(d_model=hidden_dim, nhead=8, num_encoder_layers=n_layers, num_decoder_layers=n_layers, dim_feedforward=hidden_dim, batch_first=True)
        self.lin_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim, bias=False),
            nn.Dropout(dropout),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, 16, bias=False),
            nn.Linear(16, 1, bias=False),
        )

        self.cross_attn_seq = CrossAttention(
            query_input_dim=hidden_dim,
            key_input_dim=hidden_dim,
            output_dim=hidden_dim,
        )

        self.cross_attn_mol = CrossAttention(
            query_input_dim=hidden_dim,
            key_input_dim=hidden_dim,
            output_dim=hidden_dim,
        )

    def forward(self, mol_src, seq_src):
        """Return one logit per pair, shape (B, 1).

        Args:
            mol_src: Molecule embeddings, (B, mol_input_dim).
            seq_src: Sequence embeddings, (B, seq_input_dim).
        """
        b_size = mol_src.size(0)
        mol_embedded = self.lin_mol_embed(mol_src)  # (B, H)
        seq_embedded = self.lin_seq_embed(seq_src)  # (B, H)

        # Treat each embedding as a length-1 sequence: (B, 1, H).
        mol_embedded = mol_embedded.reshape(b_size, 1, -1)
        seq_embedded = seq_embedded.reshape(b_size, 1, -1)

        # With a single key the softmax weight is 1, so each call returns the
        # value projection of the other modality.
        _mol_embedded = self.cross_attn_mol(mol_embedded, seq_embedded, seq_embedded)  # (B, 1, H)
        _seq_embedded = self.cross_attn_seq(seq_embedded, mol_embedded, mol_embedded)  # (B, 1, H)

        embedded = torch.cat([_mol_embedded, _seq_embedded], dim=1)  # (B, 2, H)
        # The same two-token sequence is used as encoder input and decoder target.
        outputs = self.transformer(embedded, embedded)  # (B, 2, H)

        outputs = self.lin_out(outputs.sum(1))  # (B, 1)

        return outputs
torch.stack([torch.randperm(logits.size(1)) for _ in range(logits.size(0))]) 150 | indices = torch.where((logits == 0).all(dim=-1) == 1)[0] 151 | asrt[indices] = rand_perm[indices] 152 | 153 | ranking = torch.empty(logits.shape[0], logits.shape[1], dtype = torch.long).scatter_ (1, asrt, torch.arange(logits.shape[1]).repeat(logits.shape[0], 1)) 154 | ranking = (ranking + 1).to(labels.device) 155 | mean_rank = (ranking * labels.float()).sum(dim=-1) / (labels.sum(dim=-1)) # (num_seq) 156 | #mean_rank = (ranking * labels).sum(-1) / (((labels.argsort(dim=-1, descending=True) + 1) * labels).sum(-1)) 157 | mean_rank = mean_rank.mean(dim=0) 158 | #mrr = (1.0 / ranking * labels.float()).sum(dim=-1) / ((1.0 / (labels.argsort(dim=-1, descending=True) + 1) * labels).sum(-1) + 1e-9) 159 | mrr = (1.0 / ranking * labels.float()).sum(dim=-1) / (labels.sum(dim=-1)) # (num_seq) 160 | mrr = mrr.mean(dim=0) 161 | 162 | top_accs = [] 163 | top_accs2 = [] 164 | for k in [1, 2, 3, 4, 5, 10, 20, 50]: 165 | top_acc = ((ranking <= k) * labels.float()).sum(dim=-1) / k 166 | top_acc = top_acc.mean(dim=0) 167 | top_accs.append(top_acc) 168 | 169 | top_acc2 = (((ranking <= k) * labels.float()).sum(dim=-1) > 0).float() 170 | top_acc2 = top_acc2.mean(dim=0) 171 | top_accs2.append(top_acc2) 172 | 173 | return top_accs[0], top_accs[1], top_accs[2], top_accs[3], top_accs[4], top_accs[5], top_accs[6], top_accs[7], top_accs2[0], top_accs2[1], top_accs2[2], top_accs2[3], top_accs2[4], top_accs2[5], top_accs2[6], top_accs2[7], mean_rank, mrr 174 | 175 | 176 | 177 | 178 | @torch.no_grad() 179 | def test(pos_pair_loader, mol_loader, labels, k=1): 180 | model.eval() 181 | torch.set_grad_enabled(False) 182 | 183 | preds = [] 184 | for (seqs, _) in tqdm(tst_pos_pair_loader): 185 | logits = [] 186 | seqs = seqs.repeat(args.batch_size, 1).to(seqs.device) 187 | for (mols, _) in tst_molecules_loader: 188 | b = mols.size(0) 189 | mols = mols.to(args.device) 190 | out = model(mols, seqs[:b, :]) 191 | 
if __name__ == '__main__':
    # Retrieval evaluation: score every unique test sequence against every
    # unique test molecule and report ranking metrics in both directions.
    print('loading data...')
    pos_tst_mols, pos_tst_seqs, _, _ = get_samples(
        f'data/new_{args.split_type}/positive_test_{args.split_type}.pt',
        f'data/new_{args.split_type}/negative_test_{args.split_type}.pt',
    )

    unique_mols = list(set(pos_tst_mols))
    unique_seqs = list(set(pos_tst_seqs))

    # Position lookups are built once; calling list.index() for every positive
    # pair made the label construction quadratic in the number of items.
    mol_position = {mol: pos for pos, mol in enumerate(unique_mols)}
    seq_position = {seq: pos for pos, seq in enumerate(unique_seqs)}

    # labels[i, j] == 1 when (unique_seqs[i], unique_mols[j]) is a positive pair.
    labels = torch.zeros(len(unique_seqs), len(unique_mols))
    for seq, mol in zip(pos_tst_seqs, pos_tst_mols):
        labels[seq_position[seq], mol_position[mol]] = 1
    labels = labels.to(args.device)

    mol_embedding = torch.load(f'data/embedding/{args.mol_embedding_type}_mol_embedding.pt')
    seq_embedding = torch.load(f'data/embedding/{args.pro_embedding_type}_seq_embedding.pt')

    print('loading data...')
    pos_seqs = EnzymeDatasetPretrainedSingle(unique_seqs, seq_embedding, positive_sample=True, max_len=args.seq_len)
    pos_mols = EnzymeDatasetPretrainedSingle(unique_mols, mol_embedding, positive_sample=True, max_len=100000)

    # NOTE: test() reads these two loaders through their module-level names,
    # so the names must not change.
    tst_pos_pair_loader = DataLoader(
        pos_seqs,
        batch_size=1,
        shuffle=False,
        num_workers=args.n_worker,
        collate_fn=collate_fn_pretrained_single,
    )
    tst_molecules_loader = DataLoader(
        pos_mols,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.n_worker,
        collate_fn=collate_fn_pretrained_single,
    )

    preds, labels = test(tst_pos_pair_loader, tst_molecules_loader, labels, k=args.topk)

    preds_cpu = preds.detach().cpu()
    labels_cpu = labels.detach().cpu()

    # (tag, scores, targets): the "Data" rows score the labels against
    # themselves, and the "Transpose" rows rank sequences per molecule.
    reports = (
        ('Pred', preds_cpu, labels_cpu),
        ('Data', labels_cpu, labels_cpu),
        ('Pred Transpose', preds_cpu.transpose(0, 1), labels_cpu.transpose(0, 1)),
        ('Data Transpose', labels_cpu.transpose(0, 1), labels_cpu.transpose(0, 1)),
    )
    cutoffs = (1, 2, 3, 4, 5, 10, 20, 50)
    for tag, scores, targets in reports:
        metrics = topk_accuracy(scores, targets)
        fields = [f'Top{c} Acc-N: {v:.4f}' for c, v in zip(cutoffs, metrics[:8])]
        fields += [f'Top{c} Acc: {v:.4f}' for c, v in zip(cutoffs, metrics[8:16])]
        fields += [f'Mean Rank: {metrics[16]:.4f}', f'MRR: {metrics[17]:.4f}']
        print(f'{tag} ' + ', '.join(fields))
class CrossAttention(nn.Module):
    """Single-head scaled dot-product attention between two inputs.

    The query input and the key/value inputs are each passed through their
    own linear projection to ``output_dim`` features. The attention scores
    are divided by ``sqrt(output_dim)`` and normalised with a softmax over
    the key axis.
    """

    def __init__(self, query_input_dim, key_input_dim, output_dim):
        super().__init__()

        self.out_dim = output_dim
        # Attribute names double as state-dict keys, so they stay as they are.
        self.W_Q = nn.Linear(query_input_dim, output_dim)
        self.W_K = nn.Linear(key_input_dim, output_dim)
        self.W_V = nn.Linear(key_input_dim, output_dim)
        self.scale_val = self.out_dim ** 0.5
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query_input, key_input, value_input, query_input_mask=None, key_input_mask=None):
        """Attend from ``query_input`` over ``key_input`` / ``value_input``.

        Dimensions 1 and 2 of the projected keys are swapped, so the inputs
        need at least three dimensions. The two mask arguments are accepted
        but not used.
        """
        projected_q = self.W_Q(query_input)
        projected_k = self.W_K(key_input)
        projected_v = self.W_V(value_input)

        scores = projected_q @ projected_k.transpose(1, 2) / self.scale_val
        return self.softmax(scores) @ projected_v
def parse_arguments():
    """Parse the command-line hyper-parameters of the training script.

    Returns:
        argparse.Namespace holding ``epochs``, ``early_stopping``,
        ``seq_len``, ``batch_size``, ``n_worker``, ``hidden``,
        ``mol_input_dim``, ``dropout``, ``lr``, ``weight_decay``,
        ``split_type``, ``mol_embedding_type``, ``pro_embedding_type`` and
        ``checkpoint``.
    """
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--epochs', type=int, default=10000, help='number of epochs')
    parser.add_argument('--early_stopping', type=int, default=300, help='early stopping')
    parser.add_argument('--seq_len', type=int, default=5000, help='maximum length')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--n_worker', type=int, default=0, help='number of workers')
    parser.add_argument('--hidden', type=int, default=128, help='length of hidden vector')
    # The two help strings below used to be copy-pasted from other options.
    parser.add_argument('--mol_input_dim', type=int, default=512, help='dimension of the molecule embedding')
    parser.add_argument('--dropout', type=float, default=0., help='dropout probability')
    parser.add_argument('--lr', type=float, default=1e-4, help='Adam learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-10, help='Adam weight decay')
    parser.add_argument('--split_type', type=str, default='mol_smi')
    parser.add_argument('--mol_embedding_type', type=str, default='unimol')
    parser.add_argument('--pro_embedding_type', type=str, default='esm')
    parser.add_argument('--checkpoint', type=str, default=None)
    return parser.parse_args()
@torch.no_grad()
def test(loader, neg_weight=1, threshold=0.5):
    """Evaluate the module-level ``model`` on ``loader`` without gradients.

    Args:
        loader: DataLoader yielding ``(mols, seqs, labels)`` batches.
        neg_weight: loss weight applied to samples whose label is 0;
            all other samples keep weight 1.
        threshold: probability above which a sample is predicted positive
            (used for accuracy only).

    Returns:
        Tuple ``(mean weighted BCE loss per sample, accuracy, AUROC)`` as
        Python floats.
    """
    model.eval()
    # Gradient tracking is disabled by the decorator for the duration of the
    # call; the global grad mode is no longer switched off as a side effect.

    total_loss = 0.0
    probs = []
    targets = []

    for mols, seqs, labels in tqdm(loader):
        mols = mols.to(args.device)
        seqs = seqs.to(args.device)
        labels = labels.to(args.device)

        out = model(mols, seqs).view(-1)

        weights = torch.ones_like(labels)
        weights[labels == 0] = neg_weight
        loss = F.binary_cross_entropy_with_logits(out, labels, weight=weights)

        # `loss` is a batch mean: scale by the real batch size so a short
        # final batch is not over-counted in the per-sample average.
        total_loss += loss.item() * labels.size(0)

        probs.append(torch.sigmoid(out).cpu())
        targets.append(labels.cpu())

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    probs = torch.cat(probs, dim=-1)
    targets = torch.cat(targets, dim=-1).long()

    acc = accuracy((probs > threshold).long(), targets)
    # AUROC is a ranking metric: it needs the scores, not thresholded labels.
    roc = auroc(probs, targets)

    return total_loss / len(loader.dataset), acc.item(), roc.item()
positive_sample=False, max_len=args.seq_len) 271 | trn_val_dataset = pos_trn_val + neg_trn_val 272 | 273 | pos_tst = EnzymeDatasetPretrained(pos_tst_mols, pos_tst_seqs, mol_embedding, seq_embedding, positive_sample=True, max_len=args.seq_len) 274 | neg_tst = EnzymeDatasetPretrained(neg_tst_mols, neg_tst_seqs, mol_embedding, seq_embedding, positive_sample=False, max_len=args.seq_len) 275 | tst_dataset = pos_tst + neg_tst 276 | 277 | trn_size = int(0.9 * len(trn_val_dataset)) 278 | val_size = len(trn_val_dataset) - trn_size 279 | trn_dataset, val_dataset = torch.utils.data.random_split(trn_val_dataset, [trn_size, val_size]) 280 | 281 | 282 | trn_loader = DataLoader(trn_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.n_worker, collate_fn=collate_fn_pretrained) 283 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_worker, collate_fn=collate_fn_pretrained) 284 | tst_loader = DataLoader(tst_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_worker, collate_fn=collate_fn_pretrained) 285 | 286 | 287 | current_pointer = 0 288 | 289 | 290 | for epoch in range(args.epochs): 291 | trn_loss, trn_acc, trn_roc = train(trn_loader, neg_weight=trn_weight) 292 | val_loss, val_acc, val_roc = test(val_loader, neg_weight=trn_weight) 293 | tst_loss, tst_acc, tst_roc = test(tst_loader, neg_weight=tst_weight) 294 | 295 | current_pointer += 1 296 | if trn_loss < best_val_loss: 297 | best_val_loss = trn_loss 298 | best_tst_acc = tst_acc 299 | best_tst_roc = tst_roc 300 | current_pointer = 0 301 | 302 | torch.save( 303 | { 304 | "epoch": epoch, 305 | "model_state_dict": model.state_dict(), 306 | "optimizer_state_dict": optimizer.state_dict(), 307 | "best_loss": best_val_loss, 308 | "best_acc": best_tst_acc, 309 | "best_roc": best_tst_roc, 310 | }, 311 | f'model/{args.split_type}/{args.pro_embedding_type}_{args.mol_embedding_type}_epoch{epoch}', 312 | ) 313 | 314 | 315 | print(f'Epoch: {epoch:04d}, Trn 
# Contrastive Loss Function
class ContrastiveLoss(torch.nn.Module):
    """Margin-based contrastive loss on the Euclidean distance of two embeddings.

    For each pair the loss is ``(1 - label) * d**2 + label * max(margin - d, 0)**2``
    where ``d`` is the pairwise distance. Pairs with label 0 are therefore
    penalised for being far apart, pairs with label 1 for being closer than
    ``margin``. The per-pair terms are averaged.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label, weights=None):
        """Return the mean (optionally weighted) contrastive loss.

        Args:
            output1, output2: embeddings to compare row by row.
            label: per-pair label selecting which term of the loss applies.
            weights: optional per-pair factors multiplied into the loss
                terms before averaging.
        """
        distance = torch.nn.functional.pairwise_distance(output1, output2)

        attract = (1 - label) * distance.pow(2)
        repel = label * (self.margin - distance).clamp(min=0.0).pow(2)
        per_pair = attract + repel

        if weights is not None:
            per_pair = weights * per_pair

        return per_pair.mean()
| nn.Linear(256, 256, bias=False), 91 | nn.Dropout(dropout), 92 | nn.BatchNorm1d(256), 93 | nn.SiLU(), 94 | nn.Linear(256, hidden_dim, bias=False), 95 | ) 96 | 97 | 98 | self.lin_out = nn.Sequential( 99 | nn.Linear(2*hidden_dim, hidden_dim, bias=False), 100 | nn.Dropout(dropout), 101 | nn.SiLU(), 102 | nn.Linear(hidden_dim, hidden_dim, bias=False), 103 | nn.Dropout(dropout), 104 | nn.SiLU(), 105 | nn.Linear(hidden_dim, output_dim, bias=False), 106 | nn.Dropout(dropout), 107 | nn.SiLU(), 108 | nn.Linear(output_dim, 16, bias=False), 109 | nn.Dropout(dropout), 110 | nn.Linear(16, 1, bias=False), 111 | ) 112 | 113 | self.cross_attn_seq = CrossAttention( 114 | query_input_dim=hidden_dim, 115 | key_input_dim=hidden_dim, 116 | output_dim=hidden_dim, 117 | ) 118 | 119 | self.cross_attn_mol = CrossAttention( 120 | query_input_dim=hidden_dim, 121 | key_input_dim=hidden_dim, 122 | output_dim=hidden_dim, 123 | ) 124 | 125 | def forward(self, mol_src, seq_src): 126 | # src:(B,H) 127 | b_size = mol_src.size(0) 128 | mol_embedded = self.lin_mol_embed(mol_src) #(B,H) 129 | seq_embedded = self.lin_seq_embed(seq_src) #(B,H) 130 | 131 | mol_embedded = mol_embedded.reshape(b_size, 1, -1) 132 | seq_embedded = seq_embedded.reshape(b_size, 1, -1) 133 | 134 | _mol_embedded = self.cross_attn_mol(mol_embedded, seq_embedded, seq_embedded).reshape(b_size, -1) #(B,H) 135 | _seq_embedded = self.cross_attn_seq(seq_embedded, mol_embedded, mol_embedded).reshape(b_size, -1) #(B,H) 136 | 137 | outputs = self.lin_out(torch.cat([_mol_embedded, _seq_embedded], dim=-1)) 138 | 139 | return outputs, _mol_embedded, _seq_embedded 140 | 141 | 142 | def parse_arguments(): 143 | parser = argparse.ArgumentParser(description='Hyperparams') 144 | parser.add_argument('--epochs', type=int, default=10000, help='number of epochs') 145 | parser.add_argument('--early_stopping', type=int, default=300, help='early stopping') 146 | parser.add_argument('--seq_len', type=int, default=5000, help='maximum length') 147 | 
def train(loader, neg_weight=1, threshold=0.5):
    """Run one training epoch over ``loader`` with BCE plus contrastive loss.

    The module-level ``model`` and ``optimizer`` are updated in place. The
    optimised objective is the weighted BCE loss plus ``Contra_Loss`` on the
    two embeddings returned by the model; the reported loss covers only the
    BCE part.

    Args:
        loader: DataLoader yielding ``(mols, seqs, labels)`` batches.
        neg_weight: weight applied to samples whose label is 0 (in both
            loss terms); all other samples keep weight 1.
        threshold: probability above which a sample is predicted positive
            (used for accuracy only).

    Returns:
        Tuple ``(mean weighted BCE loss per sample, accuracy, AUROC)`` as
        Python floats.
    """
    model.train()
    # test() in this file switches the global grad mode off, so it is turned
    # back on explicitly before the epoch starts.
    torch.set_grad_enabled(True)

    total_loss = 0.0
    probs = []
    targets = []

    for mols, seqs, labels in tqdm(loader):
        optimizer.zero_grad()

        mols = mols.to(args.device)
        seqs = seqs.to(args.device)
        labels = labels.to(args.device)

        out, mol_rep, seq_rep = model(mols, seqs)
        out = out.view(-1)

        weights = torch.ones_like(labels)
        weights[labels == 0] = neg_weight
        loss = F.binary_cross_entropy_with_logits(out, labels, weight=weights)
        # NOTE(review): ContrastiveLoss pulls pairs together when label == 0
        # and pushes them apart when label == 1; presumably label 1 marks the
        # positive pairs here -- confirm this is the intended direction.
        contra_loss = Contra_Loss(mol_rep, seq_rep, labels, weights)

        (loss + contra_loss).backward()
        optimizer.step()

        # `loss` is a batch mean: scale by the real batch size so a short
        # final batch is not over-counted in the per-sample average.
        total_loss += loss.item() * labels.size(0)

        probs.append(torch.sigmoid(out.detach()).cpu())
        targets.append(labels.detach().cpu())

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    probs = torch.cat(probs, dim=-1)
    targets = torch.cat(targets, dim=-1).long()

    acc = accuracy((probs > threshold).long(), targets)
    # AUROC is a ranking metric: it needs the scores, not thresholded labels.
    roc = auroc(probs, targets)

    return total_loss / len(loader.dataset), acc.item(), roc.item()
weight=weights) 256 | #loss = (loss * weights).mean() 257 | 258 | pred_labels.append((torch.sigmoid(out) > threshold).long()) 259 | true_labels.append(labels) 260 | 261 | total_loss += loss.item() * args.batch_size 262 | torch.cuda.empty_cache() if torch.cuda.is_available() else None 263 | 264 | 265 | pred_labels = torch.cat(pred_labels, dim=-1).detach().cpu() 266 | true_labels = torch.cat(true_labels, dim=-1).detach().cpu() 267 | 268 | acc = accuracy(pred_labels, true_labels) 269 | roc = auroc(pred_labels, true_labels) 270 | 271 | return total_loss / len(loader.dataset), acc.item(), roc.item() 272 | 273 | 274 | if __name__ == '__main__': 275 | 276 | date = datetime.today().strftime('%Y_%m_%d_%H_%M_%S') 277 | with open(f'logger/{date}.txt', 'a') as logger: 278 | logger.write(f'{args}\n') 279 | logger.close() 280 | 281 | print('loading data...') 282 | pos_trn_mols, pos_trn_seqs, neg_trn_mols, neg_trn_seqs = get_samples(f'data/new_{args.split_type}/positive_train_val_{args.split_type}.pt', f'data/new_{args.split_type}/negative_train_val_{args.split_type}.pt') 283 | pos_tst_mols, pos_tst_seqs, neg_tst_mols, neg_tst_seqs = get_samples(f'data/new_{args.split_type}/positive_test_{args.split_type}.pt', f'data/new_{args.split_type}/negative_test_{args.split_type}.pt') 284 | 285 | trn_weight = len(pos_trn_mols) / len(neg_trn_mols) 286 | tst_weight = len(pos_tst_mols) / len(neg_tst_mols) 287 | 288 | mol_embedding = torch.load(f'data/embedding/{args.mol_embedding_type}_mol_embedding.pt') 289 | seq_embedding = torch.load(f'data/embedding/{args.pro_embedding_type}_seq_embedding.pt') 290 | 291 | 292 | print('loading data...') 293 | pos_trn_val = EnzymeDatasetPretrained(pos_trn_mols, pos_trn_seqs, mol_embedding, seq_embedding, positive_sample=True, max_len=args.seq_len) 294 | neg_trn_val = EnzymeDatasetPretrained(neg_trn_mols, neg_trn_seqs, mol_embedding, seq_embedding, positive_sample=False, max_len=args.seq_len) 295 | trn_val_dataset = pos_trn_val + neg_trn_val 296 | 297 | 
pos_tst = EnzymeDatasetPretrained(pos_tst_mols, pos_tst_seqs, mol_embedding, seq_embedding, positive_sample=True, max_len=args.seq_len) 298 | neg_tst = EnzymeDatasetPretrained(neg_tst_mols, neg_tst_seqs, mol_embedding, seq_embedding, positive_sample=False, max_len=args.seq_len) 299 | tst_dataset = pos_tst + neg_tst 300 | 301 | trn_size = int(0.9 * len(trn_val_dataset)) 302 | val_size = len(trn_val_dataset) - trn_size 303 | trn_dataset, val_dataset = torch.utils.data.random_split(trn_val_dataset, [trn_size, val_size]) 304 | 305 | 306 | trn_loader = DataLoader(trn_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.n_worker, collate_fn=collate_fn_pretrained) 307 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_worker, collate_fn=collate_fn_pretrained) 308 | tst_loader = DataLoader(tst_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_worker, collate_fn=collate_fn_pretrained) 309 | 310 | 311 | current_pointer = 0 312 | 313 | 314 | for epoch in range(args.epochs): 315 | trn_loss, trn_acc, trn_roc = train(trn_loader, neg_weight=trn_weight) 316 | val_loss, val_acc, val_roc = test(val_loader, neg_weight=trn_weight) 317 | tst_loss, tst_acc, tst_roc = test(tst_loader, neg_weight=tst_weight) 318 | 319 | current_pointer += 1 320 | if trn_loss < best_val_loss: 321 | best_val_loss = trn_loss 322 | best_tst_acc = tst_acc 323 | best_tst_roc = tst_roc 324 | current_pointer = 0 325 | 326 | torch.save( 327 | { 328 | "epoch": epoch, 329 | "model_state_dict": model.state_dict(), 330 | "optimizer_state_dict": optimizer.state_dict(), 331 | "best_loss": best_val_loss, 332 | "best_acc": best_tst_acc, 333 | "best_roc": best_tst_roc, 334 | }, 335 | f'model/{args.split_type}/{args.pro_embedding_type}_{args.mol_embedding_type}_epoch{epoch}_contrastive', 336 | ) 337 | 338 | 339 | print(f'Epoch: {epoch:04d}, Trn Loss: {trn_loss:.4f}, Trn Acc: {trn_acc:.4f}, Trn ROC: {trn_roc:.4f}, Val Loss: {val_loss:.4f}, 
"""Train a GRU-based binary classifier on (molecule, sequence) embedding pairs.

This is the ``train_rnn.py`` script.  It loads pre-computed molecule and
sequence embeddings, builds positive/negative pair datasets, and trains
``PretrainedNetwork`` with a class-weighted binary cross-entropy loss.
The checkpoint with the lowest validation loss so far is written to
``model/<split_type>/`` and one line per epoch is appended to
``logger/<timestamp>.txt``.
"""
import argparse
import math
import os
from tqdm import tqdm
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryAccuracy, AUROC

from data_utils import *
from itertools import islice


class CrossAttention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Args:
        query_input_dim: feature size of the query input.
        key_input_dim: feature size of the key and value inputs.
        output_dim: size of the projected query/key/value vectors, and of
            the output.
    """

    def __init__(self, query_input_dim, key_input_dim, output_dim):
        super().__init__()

        self.out_dim = output_dim
        self.W_Q = nn.Linear(query_input_dim, output_dim)
        self.W_K = nn.Linear(key_input_dim, output_dim)
        self.W_V = nn.Linear(key_input_dim, output_dim)
        self.scale_val = self.out_dim ** 0.5
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query_input, key_input, value_input, query_input_mask=None, key_input_mask=None):
        """Attend from ``query_input`` over ``key_input`` / ``value_input``.

        Inputs are ``(..., length, features)`` tensors; the result has the
        query's leading/length dimensions and ``output_dim`` features.
        ``query_input_mask`` and ``key_input_mask`` are accepted for
        interface compatibility but are currently ignored.
        """
        query = self.W_Q(query_input)
        key = self.W_K(key_input)
        value = self.W_V(value_input)

        # Transpose the last two axes (rather than axes 1 and 2) so that
        # unbatched 2-D inputs and extra leading batch axes also work.
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale_val
        return torch.matmul(self.softmax(scores), value)


class PretrainedNetwork(nn.Module):
    """Pair classifier: two MLP encoders, cross-attention, a bidirectional GRU.

    Args:
        mol_input_dim: feature size of the molecule embedding.
        seq_input_dim: feature size of the sequence embedding.
        hidden_dim: shared hidden size after the two encoders.
        output_dim: width of the second layer of the output head.
        n_layers: number of stacked GRU layers.
        dropout: dropout probability used in the MLPs (and between GRU
            layers when ``n_layers > 1``).
    """

    def __init__(self, mol_input_dim, seq_input_dim, hidden_dim, output_dim, n_layers=1, dropout=0.0):
        super().__init__()
        self.hidden_dim = hidden_dim

        # NOTE: attribute names and layer order are part of the checkpoint
        # format (state_dict keys) -- do not rename or reorder.
        self.lin_mol_embed = nn.Sequential(
            nn.Linear(mol_input_dim, 256, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, 256, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, hidden_dim, bias=False),
        )

        self.lin_seq_embed = nn.Sequential(
            nn.Linear(seq_input_dim, 512, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(512),
            nn.SiLU(),
            nn.Linear(512, 256, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, hidden_dim, bias=False),
        )

        # Inter-layer GRU dropout has no effect with a single layer and only
        # triggers a warning, so it is passed only for stacked GRUs.
        self.gru = nn.GRU(
            hidden_dim,
            hidden_dim,
            n_layers,
            dropout=dropout if n_layers > 1 else 0.0,
            bidirectional=True,
            batch_first=True,
        )
        self.lin_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim, bias=False),
            nn.Dropout(dropout),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, 16, bias=False),
            nn.Linear(16, 1, bias=False),
        )

        self.cross_attn_seq = CrossAttention(
            query_input_dim=hidden_dim,
            key_input_dim=hidden_dim,
            output_dim=hidden_dim,
        )

        self.cross_attn_mol = CrossAttention(
            query_input_dim=hidden_dim,
            key_input_dim=hidden_dim,
            output_dim=hidden_dim,
        )

    def forward(self, mol_src, seq_src, hidden=None):
        """Return one logit per pair, shape ``(B, 1)``.

        Assumes ``mol_src`` is ``(B, mol_input_dim)`` and ``seq_src`` is
        ``(B, seq_input_dim)`` (one pooled embedding per item) -- TODO
        confirm against ``collate_fn_pretrained``.  ``hidden`` is the
        optional initial GRU state.
        """
        b_size = mol_src.size(0)

        # Each modality becomes a length-1 sequence: (B, 1, H).
        mol_embedded = self.lin_mol_embed(mol_src).reshape(b_size, 1, -1)
        seq_embedded = self.lin_seq_embed(seq_src).reshape(b_size, 1, -1)

        # With a single key per query the softmax weight is always 1, so each
        # attention output is the value projection of the *other* modality.
        mol_ctx = self.cross_attn_mol(mol_embedded, seq_embedded, seq_embedded)  # (B, 1, H)
        seq_ctx = self.cross_attn_seq(seq_embedded, mol_embedded, mol_embedded)  # (B, 1, H)

        embedded = torch.cat([mol_ctx, seq_ctx], dim=1)  # (B, 2, H)
        outputs, _ = self.gru(embedded, hidden)  # (B, 2, 2H)

        # Sum the forward and backward halves of the bidirectional output.
        outputs = outputs[:, :, :self.hidden_dim] + outputs[:, :, self.hidden_dim:]  # (B, 2, H)

        # Pool over the two sequence positions, then classify: (B, H) -> (B, 1).
        return self.lin_out(outputs.sum(1))


def parse_arguments():
    """Parse the command-line hyper-parameters of the training script."""
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--epochs', type=int, default=10000, help='number of epochs')
    parser.add_argument('--early_stopping', type=int, default=300,
                        help='stop after this many epochs without improvement')
    parser.add_argument('--seq_len', type=int, default=5000, help='maximum length')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--n_worker', type=int, default=0, help='number of workers')
    parser.add_argument('--hidden', type=int, default=128, help='length of hidden vector')
    parser.add_argument('--mol_input_dim', type=int, default=512, help='dimension of the molecule embedding')
    parser.add_argument('--dropout', type=float, default=0., help='dropout probability')
    parser.add_argument('--lr', type=float, default=1e-4, help='Adam learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-10, help='Adam weight decay')
    parser.add_argument('--split_type', type=str, default='mol_smi')
    parser.add_argument('--mol_embedding_type', type=str, default='unimol')
    parser.add_argument('--pro_embedding_type', type=str, default='esm')
    parser.add_argument('--checkpoint', type=str, default=None)
    return parser.parse_args()


PAD_MOL = 0
PAD_SEQ = 1
args = parse_arguments()
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = PretrainedNetwork(
    mol_input_dim=args.mol_input_dim,  # 1024,
    seq_input_dim=1280,
    hidden_dim=args.hidden,
    output_dim=64,
    dropout=args.dropout,
).to(args.device)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

best_val_loss = float('inf')
best_tst_acc = 0
best_tst_roc = 0
if args.checkpoint is not None:
    print('loading model')
    checkpoint = torch.load(args.checkpoint, map_location=args.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    best_val_loss = checkpoint["best_loss"]
    # Restore the metrics saved with the best loss so the logged
    # "Best Tst ..." values stay consistent after a resume.
    best_tst_acc = checkpoint.get("best_acc", best_tst_acc)
    best_tst_roc = checkpoint.get("best_roc", best_tst_roc)
    # NOTE(review): checkpoints also store 'optimizer_state_dict', which is
    # not restored here -- confirm whether resuming should reuse it.

criterion = nn.BCEWithLogitsLoss(reduction='none')

accuracy = BinaryAccuracy().to('cpu')
auroc = AUROC(task="binary").to('cpu')


def _run_epoch(loader, neg_weight, threshold, optimize):
    """Run one pass over ``loader``; return ``(mean loss, accuracy, AUROC)``.

    When ``optimize`` is true, an optimizer step is taken after every batch.
    """
    loss_sum = 0.0
    n_samples = 0
    all_probs = []
    all_targets = []

    for mols, seqs, labels in tqdm(loader):
        mols = mols.to(args.device)
        seqs = seqs.to(args.device)
        labels = labels.to(args.device)

        if optimize:
            optimizer.zero_grad()

        logits = model(mols, seqs).view(-1)

        # Negatives are re-weighted so both classes contribute comparably.
        weights = torch.ones_like(labels)
        weights[labels == 0] = neg_weight
        loss = F.binary_cross_entropy_with_logits(logits, labels, weight=weights)

        if optimize:
            loss.backward()
            optimizer.step()

        # `loss` is a batch mean: weight it by the *actual* batch size so the
        # final, possibly shorter, batch is not over-counted.
        n_batch = labels.size(0)
        loss_sum += loss.item() * n_batch
        n_samples += n_batch

        all_probs.append(torch.sigmoid(logits).detach().cpu())
        all_targets.append(labels.detach().cpu())

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if n_samples == 0:
        raise ValueError('loader yielded no samples')

    probs = torch.cat(all_probs, dim=-1)
    targets = torch.cat(all_targets, dim=-1).long()

    # Accuracy uses hard decisions; AUROC needs the continuous scores
    # (ranking thresholded 0/1 predictions gives a degenerate curve).
    acc = accuracy((probs > threshold).long(), targets).item()
    roc = auroc(probs, targets).item()

    # The metric objects accumulate state on every call; clear it so memory
    # does not grow across epochs.
    accuracy.reset()
    auroc.reset()

    return loss_sum / n_samples, acc, roc


def train(loader, neg_weight=1, threshold=0.5):
    """Train the global ``model`` for one epoch.

    Args:
        loader: DataLoader yielding ``(mols, seqs, labels)`` batches.
        neg_weight: loss weight applied to samples whose label is 0.
        threshold: probability cut-off used for the accuracy metric.

    Returns:
        ``(mean loss per sample, accuracy, AUROC)`` as Python floats.
    """
    model.train()
    torch.set_grad_enabled(True)
    return _run_epoch(loader, neg_weight, threshold, optimize=True)


@torch.no_grad()
def test(loader, neg_weight=1, threshold=0.5):
    """Evaluate the global ``model`` on ``loader`` without gradient tracking.

    Takes the same arguments and returns the same triple as ``train``.
    """
    model.eval()
    return _run_epoch(loader, neg_weight, threshold, optimize=False)


def _build_dataset(pos_mols, pos_seqs, neg_mols, neg_seqs, mol_embedding, seq_embedding):
    """Concatenate the positive and negative pair datasets of one split."""
    positives = EnzymeDatasetPretrained(pos_mols, pos_seqs, mol_embedding, seq_embedding, positive_sample=True, max_len=args.seq_len)
    negatives = EnzymeDatasetPretrained(neg_mols, neg_seqs, mol_embedding, seq_embedding, positive_sample=False, max_len=args.seq_len)
    return positives + negatives


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the script's CLI settings."""
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle, num_workers=args.n_worker, collate_fn=collate_fn_pretrained)


def main():
    """Load the data, then train with validation-based early stopping."""
    date = datetime.today().strftime('%Y_%m_%d_%H_%M_%S')
    log_path = f'logger/{date}.txt'
    checkpoint_dir = f'model/{args.split_type}'
    # Create the output directories up front so neither the log write nor
    # the first checkpoint save fails on a fresh checkout.
    os.makedirs('logger', exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    with open(log_path, 'a') as logger:
        logger.write(f'{args}\n')

    print('loading data...')
    pos_trn_mols, pos_trn_seqs, neg_trn_mols, neg_trn_seqs = get_samples(f'data/new_{args.split_type}/positive_train_val_{args.split_type}.pt', f'data/new_{args.split_type}/negative_train_val_{args.split_type}.pt')
    pos_tst_mols, pos_tst_seqs, neg_tst_mols, neg_tst_seqs = get_samples(f'data/new_{args.split_type}/positive_test_{args.split_type}.pt', f'data/new_{args.split_type}/negative_test_{args.split_type}.pt')

    # Weight of a negative sample relative to a positive one (positives: 1).
    trn_weight = len(pos_trn_mols) / len(neg_trn_mols)
    tst_weight = len(pos_tst_mols) / len(neg_tst_mols)

    mol_embedding = torch.load(f'data/embedding/{args.mol_embedding_type}_mol_embedding.pt')
    seq_embedding = torch.load(f'data/embedding/{args.pro_embedding_type}_seq_embedding.pt')

    print('loading data...')
    trn_val_dataset = _build_dataset(pos_trn_mols, pos_trn_seqs, neg_trn_mols, neg_trn_seqs, mol_embedding, seq_embedding)
    tst_dataset = _build_dataset(pos_tst_mols, pos_tst_seqs, neg_tst_mols, neg_tst_seqs, mol_embedding, seq_embedding)

    # 90 / 10 random train / validation split.
    trn_size = int(0.9 * len(trn_val_dataset))
    val_size = len(trn_val_dataset) - trn_size
    trn_dataset, val_dataset = torch.utils.data.random_split(trn_val_dataset, [trn_size, val_size])

    trn_loader = _make_loader(trn_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    tst_loader = _make_loader(tst_dataset, shuffle=False)

    best_loss = best_val_loss
    best_acc = best_tst_acc
    best_roc = best_tst_roc
    epochs_since_best = 0

    for epoch in range(args.epochs):
        trn_loss, trn_acc, trn_roc = train(trn_loader, neg_weight=trn_weight)
        val_loss, val_acc, val_roc = test(val_loader, neg_weight=trn_weight)
        tst_loss, tst_acc, tst_roc = test(tst_loader, neg_weight=tst_weight)

        epochs_since_best += 1
        # Select on the *validation* loss: the training loss keeps falling
        # while the model overfits, which would defeat early stopping.
        if val_loss < best_loss:
            best_loss = val_loss
            best_acc = tst_acc
            best_roc = tst_roc
            epochs_since_best = 0

            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_loss": best_loss,
                    "best_acc": best_acc,
                    "best_roc": best_roc,
                },
                f'{checkpoint_dir}/{args.pro_embedding_type}_{args.mol_embedding_type}_epoch{epoch}_rnn',
            )

        summary = (
            f'Epoch: {epoch:04d}, Trn Loss: {trn_loss:.4f}, Trn Acc: {trn_acc:.4f}, Trn ROC: {trn_roc:.4f}, '
            f'Val Loss: {val_loss:.4f}, '
            f'Tst Loss: {tst_loss:.4f}, Tst Acc: {tst_acc:.4f}, Tst ROC: {tst_roc:.4f}, '
            f'Best Val Loss: {best_loss:.4f}, Best Tst Acc: {best_acc:.4f}, Best Tst ROC: {best_roc:.4f}'
        )
        print(summary)
        with open(log_path, 'a') as logger:
            logger.write(summary + '\n')

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if epochs_since_best >= args.early_stopping:
            break


if __name__ == '__main__':
    main()
"""Train a Transformer-based binary classifier on (molecule, sequence) embedding pairs.

This is the ``train_tfmr.py`` script.  It loads pre-computed molecule and
sequence embeddings, builds positive/negative pair datasets, and trains
``PretrainedNetwork`` with a class-weighted binary cross-entropy loss.
The checkpoint with the lowest validation loss so far is written to
``model/<split_type>/`` and one line per epoch is appended to
``logger/<timestamp>.txt``.
"""
import argparse
import math
import os
from tqdm import tqdm
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryAccuracy, AUROC

from data_utils import *
from itertools import islice


class CrossAttention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Args:
        query_input_dim: feature size of the query input.
        key_input_dim: feature size of the key and value inputs.
        output_dim: size of the projected query/key/value vectors, and of
            the output.
    """

    def __init__(self, query_input_dim, key_input_dim, output_dim):
        super().__init__()

        self.out_dim = output_dim
        self.W_Q = nn.Linear(query_input_dim, output_dim)
        self.W_K = nn.Linear(key_input_dim, output_dim)
        self.W_V = nn.Linear(key_input_dim, output_dim)
        self.scale_val = self.out_dim ** 0.5
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query_input, key_input, value_input, query_input_mask=None, key_input_mask=None):
        """Attend from ``query_input`` over ``key_input`` / ``value_input``.

        Inputs are ``(..., length, features)`` tensors; the result has the
        query's leading/length dimensions and ``output_dim`` features.
        ``query_input_mask`` and ``key_input_mask`` are accepted for
        interface compatibility but are currently ignored.
        """
        query = self.W_Q(query_input)
        key = self.W_K(key_input)
        value = self.W_V(value_input)

        # Transpose the last two axes (rather than axes 1 and 2) so that
        # unbatched 2-D inputs and extra leading batch axes also work.
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale_val
        return torch.matmul(self.softmax(scores), value)


class PretrainedNetwork(nn.Module):
    """Pair classifier: two MLP encoders, cross-attention, an encoder-decoder Transformer.

    Args:
        mol_input_dim: feature size of the molecule embedding.
        seq_input_dim: feature size of the sequence embedding.
        hidden_dim: shared hidden size after the two encoders; must be
            divisible by the 8 attention heads of the Transformer.
        output_dim: width of the second layer of the output head.
        n_layers: number of Transformer encoder layers and decoder layers.
        dropout: dropout probability used in the MLPs.
        max_len: accepted for interface compatibility; currently unused.
    """

    def __init__(self, mol_input_dim=1024, seq_input_dim=1280, hidden_dim=128, output_dim=64, n_layers=4, dropout=0.0, max_len=2048):
        super().__init__()
        self.hidden_dim = hidden_dim

        # NOTE: attribute names and layer order are part of the checkpoint
        # format (state_dict keys) -- do not rename or reorder.
        self.lin_mol_embed = nn.Sequential(
            nn.Linear(mol_input_dim, 256, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, 256, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, hidden_dim, bias=False),
        )

        self.lin_seq_embed = nn.Sequential(
            nn.Linear(seq_input_dim, 512, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(512),
            nn.SiLU(),
            nn.Linear(512, 256, bias=False),
            nn.Dropout(dropout),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, hidden_dim, bias=False),
        )

        # NOTE(review): `dropout` is not forwarded here, so the Transformer
        # runs with its library-default dropout -- confirm this is intended.
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=8,
            num_encoder_layers=n_layers,
            num_decoder_layers=n_layers,
            dim_feedforward=hidden_dim,
            batch_first=True,
        )
        self.lin_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim, bias=False),
            nn.Dropout(dropout),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, 16, bias=False),
            nn.Linear(16, 1, bias=False),
        )

        self.cross_attn_seq = CrossAttention(
            query_input_dim=hidden_dim,
            key_input_dim=hidden_dim,
            output_dim=hidden_dim,
        )

        self.cross_attn_mol = CrossAttention(
            query_input_dim=hidden_dim,
            key_input_dim=hidden_dim,
            output_dim=hidden_dim,
        )

    def forward(self, mol_src, seq_src):
        """Return one logit per pair, shape ``(B, 1)``.

        Assumes ``mol_src`` is ``(B, mol_input_dim)`` and ``seq_src`` is
        ``(B, seq_input_dim)`` (one pooled embedding per item) -- TODO
        confirm against ``collate_fn_pretrained``.
        """
        b_size = mol_src.size(0)

        # Each modality becomes a length-1 sequence: (B, 1, H).
        mol_embedded = self.lin_mol_embed(mol_src).reshape(b_size, 1, -1)
        seq_embedded = self.lin_seq_embed(seq_src).reshape(b_size, 1, -1)

        # With a single key per query the softmax weight is always 1, so each
        # attention output is the value projection of the *other* modality.
        mol_ctx = self.cross_attn_mol(mol_embedded, seq_embedded, seq_embedded)  # (B, 1, H)
        seq_ctx = self.cross_attn_seq(seq_embedded, mol_embedded, mol_embedded)  # (B, 1, H)

        embedded = torch.cat([mol_ctx, seq_ctx], dim=1)  # (B, 2, H)

        # The same two-token sequence is used as both source and target.
        outputs = self.transformer(embedded, embedded)  # (B, 2, H)

        # Pool over the two sequence positions, then classify: (B, H) -> (B, 1).
        return self.lin_out(outputs.sum(1))


def parse_arguments():
    """Parse the command-line hyper-parameters of the training script."""
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--epochs', type=int, default=10000, help='number of epochs')
    parser.add_argument('--early_stopping', type=int, default=300,
                        help='stop after this many epochs without improvement')
    parser.add_argument('--seq_len', type=int, default=5000, help='maximum length')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--n_worker', type=int, default=0, help='number of workers')
    parser.add_argument('--hidden', type=int, default=128, help='length of hidden vector')
    parser.add_argument('--mol_input_dim', type=int, default=512, help='dimension of the molecule embedding')
    parser.add_argument('--dropout', type=float, default=0., help='dropout probability')
    parser.add_argument('--lr', type=float, default=1e-4, help='Adam learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-10, help='Adam weight decay')
    parser.add_argument('--split_type', type=str, default='mol_smi')
    parser.add_argument('--mol_embedding_type', type=str, default='unimol')
    parser.add_argument('--pro_embedding_type', type=str, default='esm')
    parser.add_argument('--checkpoint', type=str, default=None)
    return parser.parse_args()


PAD_MOL = 0
PAD_SEQ = 1
args = parse_arguments()
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = PretrainedNetwork(
    mol_input_dim=args.mol_input_dim,  # 1024,
    seq_input_dim=1280,
    hidden_dim=args.hidden,
    output_dim=64,
    dropout=args.dropout,
).to(args.device)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

best_val_loss = float('inf')
best_tst_acc = 0
best_tst_roc = 0
if args.checkpoint is not None:
    print('loading model')
    checkpoint = torch.load(args.checkpoint, map_location=args.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    best_val_loss = checkpoint["best_loss"]
    # Restore the metrics saved with the best loss so the logged
    # "Best Tst ..." values stay consistent after a resume.
    best_tst_acc = checkpoint.get("best_acc", best_tst_acc)
    best_tst_roc = checkpoint.get("best_roc", best_tst_roc)
    # NOTE(review): checkpoints also store 'optimizer_state_dict', which is
    # not restored here -- confirm whether resuming should reuse it.

criterion = nn.BCEWithLogitsLoss(reduction='none')

accuracy = BinaryAccuracy().to('cpu')
auroc = AUROC(task="binary").to('cpu')


def _run_epoch(loader, neg_weight, threshold, optimize):
    """Run one pass over ``loader``; return ``(mean loss, accuracy, AUROC)``.

    When ``optimize`` is true, an optimizer step is taken after every batch.
    """
    loss_sum = 0.0
    n_samples = 0
    all_probs = []
    all_targets = []

    for mols, seqs, labels in tqdm(loader):
        mols = mols.to(args.device)
        seqs = seqs.to(args.device)
        labels = labels.to(args.device)

        if optimize:
            optimizer.zero_grad()

        logits = model(mols, seqs).view(-1)

        # Negatives are re-weighted so both classes contribute comparably.
        weights = torch.ones_like(labels)
        weights[labels == 0] = neg_weight
        loss = F.binary_cross_entropy_with_logits(logits, labels, weight=weights)

        if optimize:
            loss.backward()
            optimizer.step()

        # `loss` is a batch mean: weight it by the *actual* batch size so the
        # final, possibly shorter, batch is not over-counted.
        n_batch = labels.size(0)
        loss_sum += loss.item() * n_batch
        n_samples += n_batch

        all_probs.append(torch.sigmoid(logits).detach().cpu())
        all_targets.append(labels.detach().cpu())

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if n_samples == 0:
        raise ValueError('loader yielded no samples')

    probs = torch.cat(all_probs, dim=-1)
    targets = torch.cat(all_targets, dim=-1).long()

    # Accuracy uses hard decisions; AUROC needs the continuous scores
    # (ranking thresholded 0/1 predictions gives a degenerate curve).
    acc = accuracy((probs > threshold).long(), targets).item()
    roc = auroc(probs, targets).item()

    # The metric objects accumulate state on every call; clear it so memory
    # does not grow across epochs.
    accuracy.reset()
    auroc.reset()

    return loss_sum / n_samples, acc, roc


def train(loader, neg_weight=1, threshold=0.5):
    """Train the global ``model`` for one epoch.

    Args:
        loader: DataLoader yielding ``(mols, seqs, labels)`` batches.
        neg_weight: loss weight applied to samples whose label is 0.
        threshold: probability cut-off used for the accuracy metric.

    Returns:
        ``(mean loss per sample, accuracy, AUROC)`` as Python floats.
    """
    model.train()
    torch.set_grad_enabled(True)
    return _run_epoch(loader, neg_weight, threshold, optimize=True)


@torch.no_grad()
def test(loader, neg_weight=1, threshold=0.5):
    """Evaluate the global ``model`` on ``loader`` without gradient tracking.

    Takes the same arguments and returns the same triple as ``train``.
    """
    model.eval()
    return _run_epoch(loader, neg_weight, threshold, optimize=False)


def _build_dataset(pos_mols, pos_seqs, neg_mols, neg_seqs, mol_embedding, seq_embedding):
    """Concatenate the positive and negative pair datasets of one split."""
    positives = EnzymeDatasetPretrained(pos_mols, pos_seqs, mol_embedding, seq_embedding, positive_sample=True, max_len=args.seq_len)
    negatives = EnzymeDatasetPretrained(neg_mols, neg_seqs, mol_embedding, seq_embedding, positive_sample=False, max_len=args.seq_len)
    return positives + negatives


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the script's CLI settings."""
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle, num_workers=args.n_worker, collate_fn=collate_fn_pretrained)


def main():
    """Load the data, then train with validation-based early stopping."""
    date = datetime.today().strftime('%Y_%m_%d_%H_%M_%S')
    log_path = f'logger/{date}.txt'
    checkpoint_dir = f'model/{args.split_type}'
    # Create the output directories up front so neither the log write nor
    # the first checkpoint save fails on a fresh checkout.
    os.makedirs('logger', exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    with open(log_path, 'a') as logger:
        logger.write(f'{args}\n')

    print('loading data...')
    pos_trn_mols, pos_trn_seqs, neg_trn_mols, neg_trn_seqs = get_samples(f'data/new_{args.split_type}/positive_train_val_{args.split_type}.pt', f'data/new_{args.split_type}/negative_train_val_{args.split_type}.pt')
    pos_tst_mols, pos_tst_seqs, neg_tst_mols, neg_tst_seqs = get_samples(f'data/new_{args.split_type}/positive_test_{args.split_type}.pt', f'data/new_{args.split_type}/negative_test_{args.split_type}.pt')

    # Weight of a negative sample relative to a positive one (positives: 1).
    trn_weight = len(pos_trn_mols) / len(neg_trn_mols)
    tst_weight = len(pos_tst_mols) / len(neg_tst_mols)

    mol_embedding = torch.load(f'data/embedding/{args.mol_embedding_type}_mol_embedding.pt')
    seq_embedding = torch.load(f'data/embedding/{args.pro_embedding_type}_seq_embedding.pt')

    print('loading data...')
    trn_val_dataset = _build_dataset(pos_trn_mols, pos_trn_seqs, neg_trn_mols, neg_trn_seqs, mol_embedding, seq_embedding)
    tst_dataset = _build_dataset(pos_tst_mols, pos_tst_seqs, neg_tst_mols, neg_tst_seqs, mol_embedding, seq_embedding)

    # 90 / 10 random train / validation split.
    trn_size = int(0.9 * len(trn_val_dataset))
    val_size = len(trn_val_dataset) - trn_size
    trn_dataset, val_dataset = torch.utils.data.random_split(trn_val_dataset, [trn_size, val_size])

    trn_loader = _make_loader(trn_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    tst_loader = _make_loader(tst_dataset, shuffle=False)

    best_loss = best_val_loss
    best_acc = best_tst_acc
    best_roc = best_tst_roc
    epochs_since_best = 0

    for epoch in range(args.epochs):
        trn_loss, trn_acc, trn_roc = train(trn_loader, neg_weight=trn_weight)
        val_loss, val_acc, val_roc = test(val_loader, neg_weight=trn_weight)
        tst_loss, tst_acc, tst_roc = test(tst_loader, neg_weight=tst_weight)

        epochs_since_best += 1
        # Select on the *validation* loss: the training loss keeps falling
        # while the model overfits, which would defeat early stopping.
        if val_loss < best_loss:
            best_loss = val_loss
            best_acc = tst_acc
            best_roc = tst_roc
            epochs_since_best = 0

            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_loss": best_loss,
                    "best_acc": best_acc,
                    "best_roc": best_roc,
                },
                f'{checkpoint_dir}/{args.pro_embedding_type}_{args.mol_embedding_type}_epoch{epoch}_tfmr',
            )

        summary = (
            f'Epoch: {epoch:04d}, Trn Loss: {trn_loss:.4f}, Trn Acc: {trn_acc:.4f}, Trn ROC: {trn_roc:.4f}, '
            f'Val Loss: {val_loss:.4f}, '
            f'Tst Loss: {tst_loss:.4f}, Tst Acc: {tst_acc:.4f}, Tst ROC: {tst_roc:.4f}, '
            f'Best Val Loss: {best_loss:.4f}, Best Tst Acc: {best_acc:.4f}, Best Tst ROC: {best_roc:.4f}'
        )
        print(summary)
        with open(log_path, 'a') as logger:
            logger.write(summary + '\n')

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if epochs_since_best >= args.early_stopping:
            break


if __name__ == '__main__':
    main()