├── .github
│   └── ISSUE_TEMPLATE
│       ├── 1-bug_report.yaml
│       └── config.yml
├── .gitignore
├── LICENSE
├── README.md
├── benchmark_common.py
├── benchmark_evaluation.py
├── benchmark_test.py
├── benchmark_train.py
├── config.py
├── data
│   ├── README.md
│   └── dict
│       ├── README.md
│       ├── dict_label_task1.h5
│       ├── dict_label_task2.h5
│       └── dict_label_task3.h5
├── env.yaml
├── model
│   └── README.md
├── prepare_production_data.ipynb
├── production.py
├── tasks
│   ├── README.md
│   ├── prepare_task_dataset.ipynb
│   ├── task1.ipynb
│   ├── task2.ipynb
│   └── task3.ipynb
├── temp.ipynb
├── tmp
│   └── readme.md
├── tools
│   ├── Attention.py
│   ├── __init__.py
│   ├── baselines.ipynb
│   ├── embdding_onehot.py
│   ├── embedding_esm.py
│   ├── exact_ec_from_uniprot.py
│   ├── filetool.py
│   ├── funclib.py
│   └── uniprottool.py
└── update_production.ipynb

/.github/ISSUE_TEMPLATE/1-bug_report.yaml:
--------------------------------------------------------------------------------
name: Report Problem / 故障报告
description: Report a problem with ECRECer / 报告使用网站时出现的问题
labels: ["problem report"]
body:
  - type: markdown
    id: preamble
    attributes:
      value: |
        如果在使用ECRECer网站时遇到问题,请填写以下表单。
        If you encountered a problem using ECRECer, please fill in the form below.
  - type: checkboxes
    id: prerequisites
    attributes:
      label: 先决条件 / Prerequisites
      options:
        - label: |
            该问题未被其他 [issues](https://github.com/kingstdio/ECRECer/issues) 提出。
            No previous issues have reported this problem.
          required: true
  - type: textarea
    id: what_happened
    attributes:
      label: 发生了什么 / What happened
    validations:
      required: true
  - type: textarea
    id: expected_behavior
    attributes:
      label: 期望的现象 / What you expected
    validations:
      required: true
  - type: textarea
    id: how_to_reproduce
    attributes:
      label: 如何重现此问题 / How to reproduce
    validations:
      required: true
  - type: input
    id: os_version
    attributes:
      label: 操作系统版本 / OS version
    validations:
      required: true
  - type: input
    id: browser
    attributes:
      label: 浏览器或客户端版本 / Browser/client version
    validations:
      required: true
  - type: textarea
    id: other_env
    attributes:
      label: 其他环境 / Other environments
  - type: textarea
    id: others
    attributes:
      label: 其他 / Anything else
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
blank_issues_enabled: false
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


#exclude tempfile
tmp/
temp/
*.feather
*.fasta
*.dmnd
*.txt
*.feature
*.tsv
*.npy
*.csv
*.h5

baselines/
results/

._.DS_Store
.DS_Store
*.gz

*.model
*.st

model/single_multi.model

*.faa
model/multi_many.model

data/uniprot/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 kingstdio

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# DMLF: Enzyme Commission Number Predicting and Benchmarking with Multi-agent Dual-core Learning

This repo contains the source code for an EC number prediction tool named ECRECer, which is an implementation of our paper: 「Enzyme Commission Number Prediction and Benchmarking with Hierarchical Dual-core Multitask Learning Framework」

Detailed information about the framework can be found in our papers:

```text
1. Zhenkun Shi, Qianqian Yuan, Ruoyu Wang, Haoran Li, Xiaoping Liao*, Hongwu Ma* (2022). ECRECer: Enzyme Commission Number Recommendation and Benchmarking based on Multiagent Dual-core Learning. arXiv preprint arXiv:2202.03632.

2. Zhenkun Shi, Rui Deng, Qianqian Yuan, Zhitao Mao, Ruoyu Wang, Haoran Li, Xiaoping Liao*, Hongwu Ma* (2023). Enzyme Commission Number Prediction and Benchmarking with Hierarchical Dual-core Multitask Learning Framework. Research.
```

## Usage

### To simply predict EC numbers with our tool, please visit the ECRECer website at https://ecrecer.biodesign.ac.cn

### For users who want to run ECRECer locally, please follow the steps below:

We provide a Docker image and a Singularity image for running ECRECer locally.

> Docker image:

```bash
# 1. pull the ecrecer docker image
docker pull kingstdio/ecrecer

# 2. run the ecrecer docker image
# gpu version:
sudo docker run -it -d --gpus all --name ecrecer -v ~/:/home/ kingstdio/ecrecer #~/ is your fasta file folder
# cpu version:
sudo docker run -it -d --name ecrecer -v ~/:/home/ kingstdio/ecrecer #~/ is your fasta file folder

# 3. run ECRECer prediction

sudo docker exec ecrecer python /ecrecer/production.py -i /home/input_fasta_file.fasta -o /home/output_tsv_file.tsv -mode h -topk 10

#-topk: top k predicted EC numbers
#-mode p: prediction mode, predict EC numbers only
#-mode r: recommendation mode, recommend EC numbers with predicted probabilities, the higher the better
#-mode h: hybrid mode, use prediction, recommendation and sequence alignment methods

```
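If you have a whole folder of fasta files, the documented `docker exec` call can be wrapped in a small shell loop. Below is a minimal sketch, assuming the container was started as shown above (with `~/` mounted at `/home/`) and that every `*.fasta` file in `~/` should be predicted; the loop itself is a hypothetical helper, not part of the repo:

```bash
# Batch prediction: write one <name>.tsv next to each <name>.fasta
# in the mounted folder, using the same hybrid-mode settings as above.
for f in ~/*.fasta; do
    name=$(basename "$f" .fasta)
    sudo docker exec ecrecer python /ecrecer/production.py \
        -i "/home/${name}.fasta" \
        -o "/home/${name}.tsv" \
        -mode h -topk 10
done
```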
> Singularity image:

```bash
# 1. pull the ecrecer singularity image

# Image ~= 11GB, may take a while to download
wget -c https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/sifimages/ecrecer.sif

# 2. run the ecrecer singularity image
# gpu version:
singularity run --nv ecrecer.sif python /ecrecer/production.py -i input_fasta_file.fasta -o output_tsv_file.tsv -mode h -topk 10
# cpu version:
singularity run ecrecer.sif python /ecrecer/production.py -i input_fasta_file.fasta -o output_tsv_file.tsv -mode h -topk 10

#-topk: top k predicted EC numbers
#-mode p: prediction mode, predict EC numbers only
#-mode r: recommendation mode, recommend EC numbers with predicted probabilities, the higher the better
#-mode h: hybrid mode, use prediction, recommendation and sequence alignment methods

```

# To re-implement our experiments or use ECRECer offline, please read the details below:

# Prerequisites

+ Python >= 3.6
+ scikit-learn
+ XGBoost
+ conda
+ JupyterLab
+ ...

> Create a conda env using [env.yaml](./env.yaml)

```bash
git clone git@github.com:kingstdio/ECRECer.git
conda env create -f env.yaml
```

# Preprocessing

Download and prepare the datasets using the notebook below:

> [prepare_task_dataset.ipynb](./prepare_task_dataset.ipynb)

Or directly download the preprocessed data from the [AWS public dataset](https://tibd-public-datasets.s3.amazonaws.com/ecrecer/ecrecer_datasets.zip) and put it in rootfolder/data/datasets/

# High throughput benchmarking

# Train

```bash
python benchmark_train.py
```

# Test

```bash
python benchmark_test.py
```

# Evaluation

```bash
python benchmark_evaluation.py
```

# Production

```bash
python production.py -i input_fasta_file -o output_tsv_file -mode [p|r] -topk 5
```

# Citations

If you find these methods valuable for your research, we kindly request that you cite the relevant paper:

```bib
@article{shi2023enzyme,
  title={Enzyme Commission Number Prediction and Benchmarking with Hierarchical Dual-core Multitask Learning Framework},
  author={Shi, Zhenkun and Deng, Rui and Yuan, Qianqian and Mao, Zhitao and Wang, Ruoyu and Li, Haoran and Liao, Xiaoping and Ma, Hongwu},
  journal={Research},
  year={2023},
  publisher={AAAS}
}
```

## Stargazers over time

[![Stargazers over time](https://starchart.cc/kingstdio/ECRECer.svg)](https://github.com/kingstdio/ECRECer/)
--------------------------------------------------------------------------------
/benchmark_common.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os,string, random,joblib,sys
4 | from datetime import datetime
5 | import config as cfg
6 | from sklearn import preprocessing
7 | from keras.models import Model
8 | from keras.optimizers import Adam
9 | from keras.layers import Input, Dense, GRU, Bidirectional
10 | from tools import Attention
11 | 
12 | 
13 | # convert softmax results to one-hot
14 | def props_to_onehot(props):
15 |     if isinstance(props, list):
16 |         props = np.array(props)
17 |     a = np.argmax(props, axis=1)
18 |     b = np.zeros((len(a), props.shape[1]))
19 |     b[np.arange(len(a)), a] = 1
20 |     return b
21 | 
22 | def make_onehot_label(label_list, save=False, file_encoder='./encode.h5', type='singel'):
23 |     if type=='singel':
24 |         encoder = preprocessing.OneHotEncoder(sparse=False)
25 | 
results = encoder.fit_transform([[item] for item in label_list]) 26 | 27 | elif type=='multi': 28 | encoder = preprocessing.MultiLabelBinarizer() 29 | results=encoder.fit_transform([ item.split(',') for item in label_list]) 30 | else: 31 | print('lable encoding type error, please check') 32 | sys.exit() 33 | 34 | if save ==True: 35 | joblib.dump(encoder, file_encoder) 36 | 37 | return results 38 | 39 | def mgru_attion_model(input_dimensions, gru_h_size=512, dropout=0.2, lossfunction='binary_crossentropy', 40 | evaluation_metrics='accuracy', activation_method = 'sigmoid', output_dimensions=2 41 | ): 42 | inputs = Input(shape=(1, input_dimensions), name="input") 43 | gru = Bidirectional(GRU(gru_h_size, dropout=dropout, return_sequences=True), name="bi-gru")(inputs) 44 | attention = Attention.Attention(32)(gru) 45 | output = Dense(output_dimensions, activation=activation_method, name="dense")(attention) 46 | model = Model(inputs, output) 47 | model.compile(loss=lossfunction, optimizer=Adam(),metrics=[evaluation_metrics]) 48 | 49 | return model 50 | 51 | # region 需要写fasta的dataFame格式 52 | def save_table2fasta(dataset, file_out): 53 | """[summary] 54 | Args: 55 | dataset ([DataFrame]): 需要写fasta的dataFame格式[id, seq] 56 | file_out ([type]): [description] 57 | """ 58 | if ~os.path.exists(file_out): 59 | file = open(file_out, 'w') 60 | for index, row in dataset.iterrows(): 61 | file.write('>{0}\n'.format(row['id'])) 62 | file.write('{0}\n'.format(row['seq'])) 63 | file.close() 64 | print('Write finished') 65 | # endregion 66 | 67 | # region 获取序列比对结果 68 | def getblast(ref_fasta, query_fasta, results_file): 69 | """[获取序列比对结果] 70 | Args: 71 | ref_fasta ([string]]): [训练集fasta文件] 72 | query_fasta ([string]): [测试数据集fasta文件] 73 | results_file ([string]): [要保存的结果文件] 74 | 75 | Returns: 76 | [DataFrame]: [比对结果] 77 | """ 78 | 79 | if os.path.exists(results_file): 80 | res_data = pd.read_csv(results_file, sep='\t', names=cfg.BLAST_TABLE_HEAD) 81 | return res_data 82 | 83 | 84 | cmd1 = r'diamond makedb --in {0} -d /tmp/train.dmnd'.format(ref_fasta) 85 | cmd2 = r'diamond blastp -d /tmp/train.dmnd -q {0} -o {1} -b8 -c1 -k 1'.format(query_fasta, results_file) 86 | cmd3 = r'rm -rf /tmp/*.dmnd' 87 | 88 | print(cmd1) 89 | os.system(cmd1) 90 | print(cmd2) 91 | os.system(cmd2) 92 | res_data = pd.read_csv(results_file, sep='\t', names=cfg.BLAST_TABLE_HEAD) 93 | os.system(cmd3) 94 | return res_data 95 | # endregion 96 | 97 | #region 创建diamond数据库 98 | def make_diamond_db(dbtable, to_db_file): 99 | """[创建diamond数据库] 100 | 101 | Args: 102 | dbtable ([DataTable]): [[ID,SEQ]DataFame] 103 | to_db_file ([string]): [数据库文件] 104 | """ 105 | save_table2fasta(dataset=dbtable, file_out=cfg.TEMPDIR+'dbtable.fasta') 106 | cmd = r'diamond makedb --in {0} -d {1}'.format(cfg.TEMPDIR+'dbtable.fasta', to_db_file) 107 | os.system(cmd) 108 | cmd = r'rm -rf {0}'.format(cfg.TEMPDIR+'dbtable.fasta') 109 | os.system(cmd) 110 | #endregion 111 | 112 | #region 为blast序列比对添加结果标签 113 | def blast_add_label(blast_df, trainset,): 114 | """[为blast序列比对添加结果标签] 115 | 116 | Args: 117 | blast_df ([DataFrame]): [序列比对结果] 118 | trainset ([DataFrame]): [训练数据] 119 | 120 | Returns: 121 | [Dataframe]: [添加标签后的数据] 122 | """ 123 | res_df = blast_df.merge(trainset, left_on='sseqid', right_on='id', how='left') 124 | res_df = res_df.rename(columns={'id_x': 'id', 125 | 'isemzyme': 'isenzyme_blast', 126 | 'functionCounts': 'functionCounts_blast', 127 | 'ec_number': 'ec_blast', 128 | 'ec_specific_level': 'ec_specific_level_blast' 129 | }) 130 | return 
res_df.iloc[:,np.r_[0,13:16]] 131 | #endregion 132 | 133 | #region loading embedding features 134 | def load_data_embedding(embedding_type): 135 | """loading embedding features 136 | 137 | Args: 138 | embedding_type (string): one of ['one-hot', 'unirep', 'esm0', 'esm32', 'esm33'] 139 | 140 | Returns: 141 | DataFrame: features 142 | """ 143 | 144 | if embedding_type=='one-hot': #one-hot 145 | feature = pd.read_feather(cfg.FILE_FEATURE_ONEHOT) 146 | 147 | if embedding_type=='unirep': #unirep 148 | feature = pd.read_feather(cfg.FILE_FEATURE_UNIREP) 149 | 150 | if embedding_type=='esm0': #esm0 151 | feature = pd.read_feather(cfg.FILE_FEATURE_ESM0) 152 | 153 | if embedding_type =='esm32': #esm32 154 | feature = pd.read_feather(cfg.FILE_FEATURE_ESM32) 155 | 156 | if embedding_type =='esm33': #esm33 157 | feature = pd.read_feather(cfg.FILE_FEATURE_ESM33) 158 | 159 | return feature 160 | #endregion 161 | 162 | 163 | def get_blast_prediction(reference_db, train_frame, test_frame, results_file, identity_thres=0.2): 164 | 165 | save_table2fasta(dataset=test_frame.iloc[:,np.r_[0,5]], file_out=cfg.TEMPDIR+'test.fasta') 166 | cmd = r'diamond blastp -d {0} -q {1} -o {2} -b8 -c1 -k 1'.format(reference_db, cfg.TEMPDIR+'test.fasta', results_file) 167 | print(cmd) 168 | os.system(cmd) 169 | res_data = pd.read_csv(results_file, sep='\t', names=cfg.BLAST_TABLE_HEAD) 170 | res_data = res_data[res_data.pident >= identity_thres] # 按显著性阈值过滤 171 | res_df = res_data.merge(train_frame, left_on='sseqid', right_on='id', how='left') 172 | res_df = res_df.rename(columns={'id_x': 'id', 173 | 'isemzyme': 'isemzyme_blast', 174 | 'functionCounts': 'functionCounts_blast', 175 | 'ec_number': 'ec_number_blast', 176 | 'ec_specific_level': 'ec_specific_level_blast' 177 | }) 178 | 179 | return res_df.iloc[:,np.r_[0,2,13:17]] 180 | 181 | #region 读取cdhit聚类结果 182 | def get_cdhit_results(cdhit_clstr_file): 183 | """读取cdhit聚类结果 184 | 185 | Args: 186 | cdhit_clstr_file (string): 聚类结果文件 187 | 188 | Returns: 189 | DataFrame: ['cluster_id','uniprot_id','identity', 'is_representative'] 190 | """ 191 | counter = 0 192 | res = [] 193 | with open(cdhit_clstr_file,'r') as f: 194 | for line in f: 195 | if 'Cluster' in line: 196 | cluster_id = line.replace('>Cluster','').replace('\n', '').strip() 197 | continue 198 | str_uids= line.replace('\n','').split('>')[1].replace('at ','').split('... ') 199 | 200 | if '*' in str_uids[1]: 201 | identity = 1 202 | isrep = True 203 | else: 204 | identity = float(str_uids[1].strip('%')) /100 205 | isrep = False 206 | 207 | res = res +[[cluster_id, str_uids[0], identity, isrep ]] 208 | 209 | resdf = pd.DataFrame(res, columns=['cluster_id','uniprot_id','identity', 'is_representative']) #转换为DataFrame 210 | return resdf 211 | #endregion 212 | 213 | def pycdhit(uniportid_seq_df, identity=0.4, thred_num=4): 214 | """CD-HIT 序列聚类 215 | 216 | Args: 217 | uniportid_seq_df (DataFrame): [uniprot_id, seq] 蛋白DataFrame 218 | identity (float, optional): 聚类阈值. Defaults to 0.4. 219 | thred_num (int, optional): 聚类线程数. Defaults to 4. 
220 | 221 | Returns: 222 | 聚类结果 DataFrame: [cluster_id,uniprot_id,identity,is_representative,cluster_size] 223 | """ 224 | if identity>=0.7: 225 | word_size = 5 226 | elif identity>=0.6: 227 | word_size = 4 228 | elif identity >=0.5: 229 | word_size = 3 230 | elif identity >=0.4: 231 | word_size =2 232 | else: 233 | word_size = 5 234 | 235 | # 定义输入输出文件名 236 | 237 | 238 | 239 | time_stamp_str = datetime.now().strftime("%Y-%m-%d_%H_%M_%S_")+''.join(random.sample(string.ascii_letters + string.digits, 16)) 240 | cd_hit_fasta = f'{cfg.TEMPDIR}cdhit_test_{time_stamp_str}.fasta' 241 | cd_hit_results = f'{cfg.TEMPDIR}cdhit_results_{time_stamp_str}' 242 | cd_hit_cluster_res_file =f'{cfg.TEMPDIR}cdhit_results_{time_stamp_str}.clstr' 243 | 244 | # 写聚类fasta文件 245 | save_table2fasta(uniportid_seq_df, cd_hit_fasta) 246 | 247 | # cd-hit聚类 248 | cmd = f'cd-hit -i {cd_hit_fasta} -o {cd_hit_results} -c {identity} -n {word_size} -T {thred_num} -M 0 -g 1 -sc 1 -sf 1 > /dev/null 2>&1' 249 | os.system(cmd) 250 | cdhit_cluster = get_cdhit_results(cdhit_clstr_file=cd_hit_cluster_res_file) 251 | 252 | cluster_size = cdhit_cluster.cluster_id.value_counts() 253 | cluster_size = pd.DataFrame({'cluster_id':cluster_size.index,'cluster_size':cluster_size.values}) 254 | cdhit_cluster = cdhit_cluster.merge(cluster_size, on='cluster_id', how='left') 255 | 256 | cmd = f'rm -f {cd_hit_fasta} {cd_hit_results} {cd_hit_cluster_res_file}' 257 | os.system(cmd) 258 | 259 | return cdhit_cluster 260 | 261 | #region 打印模型的重要指标,排名topN指标 262 | def importance_features_top(model, x_train, topN=10): 263 | """[打印模型的重要指标,排名topN指标] 264 | Args: 265 | model ([type]): [description] 266 | x_train ([type]): [description] 267 | topN (int, optional): [description]. Defaults to 10. 268 | """ 269 | print("打印XGBoost重要指标") 270 | feature_importances_ = model.feature_importances_ 271 | feature_names = x_train.columns 272 | importance_col = pd.DataFrame([*zip(feature_names, feature_importances_)], columns=['features', 'weight']) 273 | importance_col_desc = importance_col.sort_values(by='weight', ascending=False) 274 | print(importance_col_desc.iloc[:topN, :]) 275 | #endregion 276 | 277 | #test 278 | 279 | if __name__ =='__main__': 280 | print('success') -------------------------------------------------------------------------------- /benchmark_evaluation.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Zhenkun Shi 3 | Date: 2023-09-28 12:03:17 4 | LastEditors: Zhenkun Shi 5 | LastEditTime: 2023-10-07 06:14:15 6 | FilePath: /ECRECer/benchmark_evaluation.py 7 | Description: 8 | 9 | Copyright (c) 2023 by tibd, All Rights Reserved. 
10 | ''' 11 | from operator import index 12 | from random import sample 13 | import pandas as pd 14 | import numpy as np 15 | from sklearn import metrics 16 | from pandas._config.config import reset_option 17 | import benchmark_common as bcommon 18 | import config as cfg 19 | 20 | 21 | #region loading results from different methods and make a big DataFrame 22 | def load_res_data(file_slice, file_blast, file_deepec, file_ecpred,file_catfam, file_priam, train, test): 23 | """[加载各种算法的预测结果,并拼合成一个大表格] 24 | 25 | Args: 26 | file_slice ([string]]): [slice 集成预测结果文件] 27 | file_blast ([string]): [blast 序列比对结果文件] 28 | file_deepec ([string]): [deepec 预测结果文件] 29 | file_ecpred ([string]): [ecpred 预测结果文件] 30 | train ([DataFrame]): [训练集] 31 | test ([DataFrame]): [测试集] 32 | 33 | Returns: 34 | [DataFrame]: [拼合完成的大表格] 35 | """ 36 | 37 | 38 | # Slice + groundtruth 39 | ground_truth = test.iloc[:, np.r_[0,1,2,3]] 40 | res_slice = pd.read_csv(file_slice, sep='\t') 41 | 42 | big_res = ground_truth.merge(res_slice, on='id', how='left') 43 | 44 | #Blast 45 | res_blast = pd.read_csv(file_blast, sep='\t', names =cfg.BLAST_TABLE_HEAD) 46 | res_blast = bcommon.blast_add_label(blast_df=res_blast, trainset=train) 47 | big_res = big_res.merge(res_blast, on='id', how='left') 48 | 49 | 50 | # DeepEC 51 | res_deepec = pd.read_csv(file_deepec, sep='\t',names=['id', 'ec_number'], header=0 ) 52 | res_deepec.ec_number=res_deepec.apply(lambda x: x['ec_number'].replace('EC:',''), axis=1) 53 | res_deepec.columns = ['id','ec_deepec'] 54 | res_deepec['isemzyme_deepec']=res_deepec.ec_deepec.apply(lambda x: True if str(x)!='nan' else False) 55 | res_deepec['functionCounts_deepec'] = res_deepec.ec_deepec.apply(lambda x :len(str(x).split(','))) 56 | big_res = big_res.merge(res_deepec, on='id', how='left').drop_duplicates(subset='id') 57 | 58 | # ECpred 59 | res_ecpred = pd.read_csv(file_ecpred, sep='\t', header=0) 60 | res_ecpred['isemzyme_ecpred'] = '' 61 | with pd.option_context('mode.chained_assignment', None): 62 | res_ecpred.isemzyme_ecpred[res_ecpred['EC Number']=='non Enzyme'] = False 63 | res_ecpred.isemzyme_ecpred[res_ecpred['EC Number']!='non Enzyme'] = True 64 | 65 | res_ecpred.columns = ['id','ec_ecpred', 'conf', 'isemzyme_ecpred'] 66 | res_ecpred = res_ecpred.iloc[:,np.r_[0,1,3]] 67 | res_ecpred['functionCounts_ecpred'] = res_ecpred.ec_ecpred.apply(lambda x :len(str(x).split(','))) 68 | big_res = big_res.merge(res_ecpred, on='id', how='left').drop_duplicates(subset='id') 69 | 70 | # CATFAM 71 | res_catfam = pd.read_csv(file_catfam, sep='\t', names=['id', 'ec_catfam']) 72 | res_catfam['isenzyme_catfam']=res_catfam.ec_catfam.apply(lambda x: True if str(x)!='nan' else False) 73 | res_catfam['functionCounts_catfam'] = res_catfam.ec_catfam.apply(lambda x :len(str(x).split(','))) 74 | big_res = big_res.merge(res_catfam, on='id', how='left') 75 | 76 | #PRIAM 77 | res_priam = load_praim_res(resfile=file_priam) 78 | big_res = big_res.merge(res_priam, on='id', how='left') 79 | big_res['isenzyme_priam'] = big_res.ec_priam.apply(lambda x: True if str(x)!='nan' else False) 80 | big_res['functionCounts_priam'] = big_res.ec_priam.apply(lambda x :len(str(x).split(','))) 81 | 82 | big_res=big_res.sort_values(by=['isemzyme', 'id'], ascending=False) 83 | big_res = big_res.drop_duplicates(subset='id') 84 | big_res.reset_index(drop=True, inplace=True) 85 | return big_res 86 | #endregion 87 | 88 | #region 89 | def load_praim_res(resfile): 90 | """[加载PRIAM的预测结果] 91 | Args: 92 | resfile ([string]): [结果文件] 93 | Returns: 94 | [DataFrame]: [结果] 95 | 
""" 96 | f = open(resfile) 97 | line = f.readline() 98 | counter =0 99 | reslist=[] 100 | lstr ='' 101 | subec=[] 102 | while line: 103 | if '>' in line: 104 | if counter !=0: 105 | reslist +=[[lstr, ', '.join(subec)]] 106 | subec=[] 107 | lstr = line.replace('>', '').replace('\n', '') 108 | elif line.strip()!='': 109 | ecarray = line.split('\t') 110 | subec += [(ecarray[0].replace('#', '').replace('\n', '').replace(' ', '') )] 111 | 112 | line = f.readline() 113 | counter +=1 114 | f.close() 115 | res_priam=pd.DataFrame(reslist, columns=['id', 'ec_priam']) 116 | return res_priam 117 | #endregion 118 | 119 | def integrate_reslults(big_table): 120 | # 拼合多个标签 121 | big_table['ec_islice']=big_table.apply(lambda x : ', '.join(x.iloc[4:(x.pred_functionCounts+4)].values.astype('str')), axis=1) 122 | 123 | #给非酶赋- 124 | with pd.option_context('mode.chained_assignment', None): 125 | big_table.ec_islice[big_table.ec_islice=='nan']='-' 126 | big_table.ec_islice[big_table.ec_islice=='']='-' 127 | big_table['isemzyme_deepec']=big_table.ec_deepec.apply(lambda x : 0 if str(x)=='nan' else 1) 128 | 129 | big_table=big_table.iloc[:, np.r_[0:4,16:31,31, 14:16]] 130 | big_table = big_table.rename(columns={'isemzyme': 'isenzyme_groundtruth', 131 | 'functionCounts': 'functionCounts_groundtruth', 132 | 'ec_number': 'ec_groundtruth', 133 | 'pred_isEnzyme': 'isenzyme_islice', 134 | 'pred_functionCounts': 'functionCounts_islice', 135 | 'isemzyme_blast':'isenzyme_blast' 136 | }) 137 | # 拼合训练测试样本数 138 | samplecounts = pd.read_csv(cfg.DATADIR + 'ecsamplecounts.tsv', sep = '\t') 139 | 140 | big_table = big_table.merge(samplecounts, left_on='ec_groundtruth', right_on='ec_number', how='left') 141 | big_table = big_table.iloc[:, np.r_[0:22,23:25]] 142 | 143 | big_table.to_excel(cfg.FILE_EVL_RESULTS, index=None) 144 | big_table = big_table.drop_duplicates(subset='id') 145 | return big_table 146 | 147 | def caculateMetrix(groundtruth, predict, baselineName, type='binary'): 148 | 149 | if type == 'binary': 150 | acc = metrics.accuracy_score(groundtruth, predict) 151 | precision = metrics.precision_score(groundtruth, predict, zero_division=True ) 152 | recall = metrics.recall_score(groundtruth, predict, zero_division=True) 153 | f1 = metrics.f1_score(groundtruth, predict, zero_division=True) 154 | tn, fp, fn, tp = metrics.confusion_matrix(groundtruth, predict).ravel() 155 | npv = tn/(fn+tn+1.4E-45) 156 | print(baselineName, '\t\t%f' %acc,'\t%f'% precision,'\t\t%f'%npv,'\t%f'% recall,'\t%f'% f1, '\t', 'tp:',tp,'fp:',fp,'fn:',fn,'tn:',tn) 157 | 158 | if type =='include_unfind': 159 | evadf = pd.DataFrame() 160 | evadf['g'] = groundtruth 161 | evadf['p'] = predict 162 | 163 | evadf_hot = evadf[~evadf.p.isnull()] 164 | evadf_cold = evadf[evadf.p.isnull()] 165 | 166 | tp = len(evadf_hot[(evadf_hot.g.astype('int')==1) & (evadf_hot.p.astype('int')==1)]) 167 | fp = len(evadf_hot[(evadf_hot.g.astype('int')==0) & (evadf_hot.p.astype('int')==1)]) 168 | tn = len(evadf_hot[(evadf_hot.g.astype('int')==0) & (evadf_hot.p.astype('int')==0)]) 169 | fn = len(evadf_hot[(evadf_hot.g.astype('int')==1) & (evadf_hot.p.astype('int')==0)]) 170 | up = len(evadf_cold[evadf_cold.g==1]) 171 | un = len(evadf_cold[evadf_cold.g==0]) 172 | acc = (tp+tn)/(tp+fp+tn+fn+up+un) 173 | precision = tp/(tp+fp) 174 | npv = tn/(tn+fn) 175 | recall = tp/(tp+fn+up) 176 | f1=(2*precision*recall)/(precision+recall) 177 | print( baselineName, 178 | '\t\t%f' %acc, 179 | '\t%f'% precision, 180 | '\t\t%f'%npv, 181 | '\t%f'% recall, 182 | '\t%f'% f1, '\t', 183 | 
'tp:',tp,'fp:',fp,'fn:',fn,'tn:',tn, 'up:',up, 'un:',un) 184 | 185 | if type == 'multi': 186 | acc = metrics.accuracy_score(groundtruth, predict) 187 | precision = metrics.precision_score(groundtruth, predict, average='macro', zero_division=True ) 188 | recall = metrics.recall_score(groundtruth, predict, average='macro', zero_division=True) 189 | f1 = metrics.f1_score(groundtruth, predict, average='macro', zero_division=True) 190 | print('%12s'%baselineName, ' \t\t%f '%acc,'\t%f'% precision, '\t\t%f'% recall,'\t%f'% f1) 191 | 192 | def evalueate_performance(evalutation_table): 193 | print('\n\n1. isEnzyme prediction evalueation metrics') 194 | print('*'*140+'\n') 195 | print('baslineName', '\t', 'accuracy','\t', 'precision(PPV) \t NPV \t\t', 'recall','\t', 'f1', '\t\t', '\t confusion Matrix') 196 | ev_isenzyme={ 197 | 'ours':'isenzyme_islice', 198 | 'blast':'isenzyme_blast', 199 | 'ecpred':'isemzyme_ecpred', 200 | 'deepec':'isemzyme_deepec', 201 | 'catfam':'isenzyme_catfam', 202 | 'priam':'isenzyme_priam' 203 | } 204 | 205 | for k,v in ev_isenzyme.items(): 206 | caculateMetrix( groundtruth=evalutation_table.isenzyme_groundtruth.astype('int'), 207 | predict=evalutation_table[v], 208 | baselineName=k, 209 | type='include_unfind') 210 | 211 | 212 | print('\n\n2. EC prediction evalueation metrics') 213 | print('*'*140+'\n') 214 | print('%12s'%'baslineName', '\t\t', 'accuracy','\t', 'precision-macro \t', 'recall-macro','\t', 'f1-macro') 215 | ev_ec = { 216 | 'ours':'ec_islice', 217 | 'blast':'ec_blast', 218 | 'ecpred':'ec_ecpred', 219 | 'deepec':'ec_deepec', 220 | 'catfam':'ec_catfam', 221 | 'priam':'ec_priam', 222 | } 223 | for k, v in ev_ec.items(): 224 | caculateMetrix( groundtruth=evalutation_table.ec_groundtruth, 225 | predict=evalutation_table[v].fillna('NaN'), 226 | baselineName=k, 227 | type='multi') 228 | 229 | print('\n\n3. Function counts evalueation metrics') 230 | print('*'*140+'\n') 231 | print('%12s'%'baslineName', '\t\t', 'accuracy','\t', 'precision-macro \t', 'recall-macro','\t', 'f1-macro') 232 | 233 | ev_functionCounts = { 234 | 'ours':'functionCounts_islice', 235 | 'blast':'functionCounts_blast', 236 | 'ecpred':'functionCounts_ecpred', 237 | 'deepec':'functionCounts_deepec' , 238 | 'catfam':'functionCounts_catfam', 239 | 'priam': 'functionCounts_priam' 240 | } 241 | 242 | for k, v in ev_functionCounts.items(): 243 | caculateMetrix( groundtruth=evalutation_table.functionCounts_groundtruth, 244 | predict=evalutation_table[v].fillna('-1').astype('int'), 245 | baselineName=k, 246 | type='multi') 247 | 248 | num_enzyme = len(evalutation_table[evalutation_table.isenzyme_groundtruth]) 249 | num_no_enzyme = len(evalutation_table[~evalutation_table.isenzyme_groundtruth]) 250 | num_multi_function = len(evalutation_table[evalutation_table.functionCounts_groundtruth>1]) 251 | num_single_function = len(evalutation_table[evalutation_table.functionCounts_groundtruth==1]) 252 | 253 | 254 | ev_ec = {'ours':'ec_islice', 'blast':'ec_blast', 'ecpred':'ec_ecpred', 'deepec':'ec_deepec', 'catfam':'ec_catfam'} 255 | 256 | print('\n\n4. 
EC Prediction Report') 257 | str_out_head = '\t item' 258 | str_out_multi = '多功能酶的accuracy' 259 | str_out_single = '单功能酶的accuracy' 260 | str_out_all ='整体accuracy\t' 261 | for k, v in ev_ec.items(): 262 | str_out_head += ('\t\t\t'+str(k)) 263 | num_multi = len(evalutation_table[(evalutation_table.functionCounts_groundtruth>1) & 264 | (evalutation_table.ec_groundtruth == evalutation_table[v])]) 265 | num_single = len(evalutation_table[(evalutation_table.functionCounts_groundtruth==1) & 266 | (evalutation_table.ec_groundtruth == evalutation_table[v])]) 267 | str_out_multi += ('\t\t' + str(num_multi)+'/'+ str(num_multi_function) +'='+ str(round(num_multi/num_multi_function,4))) 268 | str_out_single += ('\t\t' + str(num_single)+'/'+ str(num_single_function) +'='+ str(round(num_single/num_single_function,4))) 269 | str_out_all += ('\t\t' + str(num_single + num_multi)+'/'+ str(num_enzyme) +'='+ str(round(num_single/num_enzyme,4))) 270 | print(str_out_head ) 271 | print(str_out_multi) 272 | print(str_out_single) 273 | print(str_out_all) 274 | 275 | def filter_newadded_ec(restable): 276 | #filter newly added ECs by your own 277 | ## TODO caculate the newly addded ecs in the testing set and filter them 278 | 279 | return restable 280 | 281 | if __name__ =='__main__': 282 | 283 | # 1. loading data 284 | train = pd.read_feather(cfg.DATADIR+'train.feather').iloc[:,:6] 285 | test = pd.read_feather(cfg.DATADIR+'test.feather').iloc[:,:6] 286 | # test = test[(test.ec_specific_level>=cfg.TRAIN_USE_SPCIFIC_EC_LEVEL) |(~test.isemzyme)] 287 | test.reset_index(drop=True, inplace=True) 288 | # EC-label dict 289 | dict_ec_label = np.load(cfg.DATADIR + 'ec_label_dict.npy', allow_pickle=True).item() 290 | file_blast_res = cfg.RESULTSDIR + r'test_blast_res.tsv' 291 | flat_table = load_res_data( 292 | file_slice=cfg.FILE_INTE_RESULTS, 293 | file_blast=cfg.FILE_BLAST_RESULTS, 294 | file_deepec=cfg.FILE_DEEPEC_RESULTS, 295 | file_ecpred=cfg.FILE_ECPRED_RESULTS, 296 | file_catfam = cfg.FILE_CATFAM_RESULTS, 297 | file_priam = cfg.FILE_PRIAM_RESULTS, 298 | train=train, 299 | test=test 300 | ) 301 | 302 | evalutation_table = integrate_reslults(flat_table) 303 | evalutation_table = filter_newadded_ec(evalutation_table) #filter newly added ECs by your own 304 | 305 | evalueate_performance(evalutation_table) 306 | evalutation_table['sright'] = evalutation_table.apply(lambda x: True if x.ec_groundtruth == x.ec_islice else False, axis=1) 307 | evalutation_table['bright'] = evalutation_table.apply(lambda x: True if x.ec_groundtruth == x.ec_blast else False, axis=1) 308 | evalutation_table.to_excel(cfg.RESULTSDIR+'evaluationFF.xlsx', index=None) 309 | print('\n Evaluation Finished \n\n') -------------------------------------------------------------------------------- /benchmark_test.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import joblib 4 | import os 5 | import benchmark_common as bcommon 6 | import config as cfg 7 | from Bio import SeqIO 8 | 9 | # region 获取「酶|非酶」预测结果 10 | def get_isEnzymeRes(querydata, model_file): 11 | """[获取「酶|非酶」预测结果] 12 | Args: 13 | querydata ([DataFrame]): [需要预测的数据] 14 | model_file ([string]): [模型文件] 15 | Returns: 16 | [DataFrame]: [预测结果、预测概率] 17 | """ 18 | model = joblib.load(model_file) 19 | predict = model.predict(querydata) 20 | predictprob = model.predict_proba(querydata) 21 | return predict, predictprob[:, 1] 22 | # endregion 23 | 24 | # region 获取「几功能酶」预测结果 25 | def get_howmany_Enzyme(querydata, model_file): 
26 | """获取「几功能酶」预测结果 27 | Args: 28 | querydata ([DataFrame]): [需要预测的数据] 29 | model_file ([string]): [模型文件] 30 | Returns: 31 | [DataFrame]: [预测结果、预测概率] 32 | """ 33 | model = joblib.load(model_file) 34 | predict = model.predict(querydata) 35 | predictprob = model.predict_proba(querydata) 36 | return predict+1, predictprob #标签加1,单功能酶标签为0,统一加1 37 | # endregion 38 | 39 | # region 获取slice预测结果 40 | def get_slice_res(slice_query_file, model_path, dict_ec_label,test_set, res_file): 41 | """[获取slice预测结果] 42 | 43 | Args: 44 | slice_query_file ([string]): [需要预测的数据sliceFile] 45 | model_path ([string]): [Slice模型路径] 46 | res_file ([string]]): [预测结果文件] 47 | Returns: 48 | [DataFrame]: [预测结果] 49 | """ 50 | 51 | cmd = '''./slice_predict {0} {1} {2} -o 32 -b 0 -t 32 -q 0'''.format(slice_query_file, model_path, res_file) 52 | print(cmd) 53 | os.system(cmd) 54 | result_slice = pd.read_csv(res_file, header=None, skiprows=1, sep=' ') 55 | 56 | # 5.3 对预测结果排序 57 | slice_pred_rank, slice_pred_prob = sort_results(result_slice) 58 | 59 | # 5.4 将结果翻译成EC号 60 | slice_pred_ec = translate_slice_pred(slice_pred=slice_pred_rank, ec2label_dict = dict_ec_label, test_set=test_set) 61 | 62 | return slice_pred_ec 63 | 64 | # endregion 65 | 66 | #region 将slice的实验结果排序,并按照推荐顺序以两个矩阵的形式返回 67 | def sort_results(result_slice): 68 | """ 69 | 将slice的实验结果排序,并按照推荐顺序以两个矩阵的形式返回 70 | @pred_top:预测结果排序 71 | @pred_pb_top:预测结果评分排序 72 | """ 73 | pred_top = [] 74 | pred_pb_top = [] 75 | aac = [] 76 | for index, row in result_slice.iterrows(): 77 | row_trans = [*row.apply(lambda x: x.split(':')).values] 78 | row_trans = pd.DataFrame(row_trans).sort_values(by=[1], ascending=False) 79 | pred_top += [list(np.array(row_trans[0]).astype('int'))] 80 | pred_pb_top += [list(np.array(row_trans[1]).astype('float'))] 81 | pred_top = pd.DataFrame(pred_top) 82 | pred_pb_top = pd.DataFrame(pred_pb_top) 83 | return pred_top, pred_pb_top 84 | #endregion 85 | 86 | #region 划分测试集XY 87 | def get_test_set(data): 88 | """[划分测试集XY] 89 | 90 | Args: 91 | data ([DataFrame]): [测试数据] 92 | 93 | Returns: 94 | [DataFrame]: [划分好的XY] 95 | """ 96 | testX = data.iloc[:,7:] 97 | testY = data.iloc[:,:6] 98 | return testX, testY 99 | #endregion 100 | 101 | #region 将slice预测的标签转换为EC号 102 | def translate_slice_pred(slice_pred, ec2label_dict, test_set): 103 | """[将slice预测的标签转换为EC号] 104 | 105 | Args: 106 | slice_pred ([DataFrame]): [slice预测后的排序数据] 107 | ec2label_dict ([dict]]): [ec转label的字典] 108 | test_set ([DataFrame]): [测试集用于取ID] 109 | 110 | Returns: 111 | [type]: [description] 112 | """ 113 | label_ec_dict = {value:key for key, value in ec2label_dict.items()} 114 | res_df = pd.DataFrame() 115 | res_df['id'] = test_set.id 116 | colNames = slice_pred.columns.values 117 | for colName in colNames: 118 | res_df['top'+str(colName)] = slice_pred[colName].apply(lambda x: label_ec_dict.get(x)) 119 | return res_df 120 | #endregion 121 | 122 | 123 | 124 | #region 将测试结果进行集成输出 125 | def run_integrage(slice_pred, dict_ec_transfer): 126 | """[将测试结果进行集成输出] 127 | 128 | Args: 129 | slice_pred ([DataFrame]): [slice 预测结果] 130 | dict_ec_transfer ([dict]): [EC转移dict] 131 | 132 | Returns: 133 | [DataFrame]: [集成后的最终结果] 134 | """ 135 | # 取top10,因为最多有10功能酶 136 | slice_pred = slice_pred.iloc[:,np.r_[0:11, 21:28]] 137 | 138 | #酶集成标签 139 | with pd.option_context('mode.chained_assignment', None): 140 | slice_pred['is_enzyme_i'] = slice_pred.apply(lambda x: int(x.isemzyme_blast) if str(x.isemzyme_blast)!='nan' else x.isEnzyme_pred_xg, axis=1) 141 | 142 | # 清空非酶的EC预测标签 143 | for i in range(9): 144 | with 
pd.option_context('mode.chained_assignment', None): 145 | slice_pred['top'+str(i)] = slice_pred.apply(lambda x: '' if x.is_enzyme_i==0 else x['top'+str(i)], axis=1) 146 | slice_pred['top0'] = slice_pred.apply(lambda x: x.ec_number_blast if str(x.ec_number_blast)!='nan' else x.top0, axis=1) 147 | 148 | # 清空有比对结果的预测标签 149 | for i in range(1,10): 150 | with pd.option_context('mode.chained_assignment', None): 151 | slice_pred['top'+str(i)] = slice_pred.apply(lambda x: '' if str(x.ec_number_blast)!='nan' else x['top'+str(i)], axis=1) #有比对结果的 152 | slice_pred['top'+str(i)] = slice_pred.apply(lambda x: '' if int(x.functionCounts_pred_xg) < int(i+1) else x['top'+str(i)], axis=1) #无比对结果的 153 | with pd.option_context('mode.chained_assignment', None): 154 | slice_pred['top0']=slice_pred['top0'].apply(lambda x: '' if x=='-' else x) 155 | 156 | # 将EC号拆开 157 | for index, row in slice_pred.iterrows(): 158 | ecitems=row['top0'].split(',') 159 | if len(ecitems)>1: 160 | for i in range(len(ecitems)): 161 | slice_pred.loc[index,'top'+str(i)] = ecitems[i].strip() 162 | 163 | slice_pred.reset_index(drop=True, inplace=True) 164 | 165 | # 添加几功能酶预测结果 166 | with pd.option_context('mode.chained_assignment', None): 167 | slice_pred['pred_functionCounts'] = slice_pred.apply(lambda x: int(x['functionCounts_blast']) if str(x['functionCounts_blast'])!='nan' else x.functionCounts_pred_xg ,axis=1) 168 | # 取最终的结果并改名 169 | colnames=[ 'id', 170 | 'pred_ec1', 171 | 'pred_ec2', 172 | 'pred_ec3', 173 | 'pred_ec4', 174 | 'pred_ec5' , 175 | 'pred_ec6', 176 | 'pred_ec7', 177 | 'pred_ec8', 178 | 'pred_ec9', 179 | 'pred_ec10', 180 | 'pred_isEnzyme', 181 | 'pred_functionCounts' 182 | ] 183 | slice_pred=slice_pred.iloc[:, np.r_[0:11, 18,19]] 184 | slice_pred.columns = colnames 185 | 186 | # 计算EC转移情况 187 | for i in range(1,11): 188 | slice_pred['pred_ec'+str(i)] = slice_pred['pred_ec'+str(i)].apply(lambda x: dict_ec_transfer.get(x) if x in dict_ec_transfer.keys() else x) 189 | 190 | # 清空没有EC号预测的酶功能数 191 | with pd.option_context('mode.chained_assignment', None): 192 | slice_pred.pred_functionCounts[slice_pred.pred_ec1.isnull()] = 0 193 | 194 | return slice_pred 195 | #endregion 196 | 197 | if __name__ == '__main__': 198 | 199 | EMBEDDING_METHOD = 'esm32' 200 | TESTSET='test2019' 201 | 202 | # 1. 读入数据 203 | print('step 1: loading data') 204 | train = pd.read_feather(cfg.TRAIN_FEATURE) 205 | test = pd.read_feather(cfg.TEST_FEATURE) 206 | train,test= bcommon.load_data_embedding(train=train, test=test, embedding_type=EMBEDDING_METHOD) 207 | train = train.iloc[:,:7] 208 | 209 | dict_ec_label = np.load(cfg.FILE_EC_LABEL_DICT, allow_pickle=True).item() #EC-标签字典 210 | dict_ec_transfer = np.load(cfg.FILE_TRANSFER_DICT, allow_pickle=True).item() #EC-转移字典 211 | 212 | # 2. 获取序列比对结果 213 | 214 | print('step 2 get blast results') 215 | blast_res = bcommon.get_blast_prediction( reference_db=cfg.FILE_BLAST_TRAIN_DB, 216 | train_frame=train, 217 | test_frame=test.iloc[:,0:7], 218 | results_file=cfg.FILE_BLAST_RESULTS, 219 | identity_thres=cfg.TRAIN_BLAST_IDENTITY_THRES 220 | ) 221 | 222 | # 3.获取酶-非酶预测结果 223 | print('step 3. get isEnzyme results') 224 | testX, testY = get_test_set(data=test) 225 | isEnzyme_pred, isEnzyme_pred_prob = get_isEnzymeRes(querydata=testX, model_file=cfg.ISENZYME_MODEL) 226 | 227 | 228 | # 4. 预测几功能酶预测结果 229 | print('step 4. get howmany functions ') 230 | howmany_Enzyme_pred, howmany_Enzyme_pred_prob = get_howmany_Enzyme(querydata=testX, model_file=cfg.HOWMANY_MODEL) 231 | 232 | # 5.获取Slice预测结果 233 | print('step 5. 
get EC prediction results') 234 | # 5.1 准备slice所用文件 235 | print('predict finished') -------------------------------------------------------------------------------- /benchmark_train.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | from sklearn.model_selection import train_test_split 5 | import benchmark_common as bcommon 6 | import config as cfg 7 | from keras.callbacks import TensorBoard, ModelCheckpoint 8 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 9 | 10 | 11 | #region 获取酶训练的数据集 12 | def get_train_X_Y(traindata, feature_bankfile, task=1): 13 | """[获取酶训练的数据集] 14 | 15 | Args: 16 | traindata ([DataFrame]): [description] 17 | 18 | Returns: 19 | [DataFrame]: [trianX, trainY] 20 | """ 21 | traindata = traindata.merge(feature_bankfile, on='id', how='left') 22 | train_X = np.array(traindata.iloc[:,3:]) 23 | if task == 1: 24 | train_Y = bcommon.make_onehot_label(label_list=traindata['isenzyme'].to_list(), save=True, file_encoder=cfg.DICT_LABEL_T1,type='singel') 25 | if task == 2: 26 | train_Y = bcommon.make_onehot_label(label_list=traindata['functionCounts'].to_list(), save=True, file_encoder=cfg.DICT_LABEL_T2, type='singel') 27 | if task == 3 : 28 | train_Y = bcommon.make_onehot_label(label_list=traindata['ec_number'].to_list(), save=True, file_encoder=cfg.DICT_LABEL_T3, type='multi') 29 | return train_X, train_Y 30 | #endregion 31 | 32 | #region 训练是否是酶模型 33 | def train_isenzyme(X,Y, model_file, vali_ratio=0.2, force_model_update=False, epochs=1): 34 | if os.path.exists(model_file) and (force_model_update==False): 35 | return 36 | else: 37 | x_train, x_vali, y_train, y_vali = train_test_split(np.array(X), np.array(Y), test_size=vali_ratio, shuffle=False) 38 | x_train = x_train.reshape(x_train.shape[0],1,-1) 39 | x_vali = x_vali.reshape(x_vali.shape[0],1,-1) 40 | tbCallBack = TensorBoard(log_dir=f'{cfg.TEMPDIR}model/task1', histogram_freq=1,write_grads=True) 41 | checkpoint = ModelCheckpoint(filepath=model_file, monitor='val_accuracy', mode='auto', save_best_only='True') 42 | instant_model = bcommon.mgru_attion_model(input_dimensions=X.shape[1], gru_h_size=512, dropout=0.2, lossfunction='binary_crossentropy', 43 | evaluation_metrics='accuracy', activation_method = 'sigmoid', output_dimensions=Y.shape[1] ) 44 | 45 | instant_model.fit(x_train, y_train, validation_data=(x_vali, y_vali), batch_size=512, epochs= epochs, callbacks=[tbCallBack, checkpoint]) 46 | # 保存 47 | print(f'train_isenzyme finished, best model saved to: {model_file}') 48 | 49 | #endregion 50 | 51 | #region 构建几功能酶模型 52 | def train_howmany_enzyme(X, Y, model_file, force_model_update=False, epochs=1): 53 | if os.path.exists(model_file) and (force_model_update==False): 54 | return 55 | else: 56 | x_train, x_vali, y_train, y_vali = train_test_split(np.array(X), np.array(Y), test_size=0.3, shuffle=False) 57 | x_train = x_train.reshape(x_train.shape[0],1,-1) 58 | x_vali = x_vali.reshape(x_vali.shape[0],1,-1) 59 | tbCallBack = TensorBoard(log_dir=f'{cfg.TEMPDIR}model/task3', histogram_freq=1,write_grads=True) 60 | checkpoint = ModelCheckpoint(filepath=model_file, monitor='val_accuracy', mode='auto', save_best_only='True') 61 | instant_model = bcommon.mgru_attion_model(input_dimensions=X.shape[1], gru_h_size=128, dropout=0.2, lossfunction='categorical_crossentropy', 62 | evaluation_metrics='accuracy', activation_method = 'softmax', output_dimensions=Y.shape[1]) 63 | 64 | instant_model.fit(x_train, y_train, validation_data=(x_vali, y_vali), 
batch_size=128, epochs= epochs, callbacks=[tbCallBack, checkpoint]) 65 | 66 | print(f'train how many enzyme finished, best model saved to: {model_file}') 67 | #endregion 68 | 69 | def make_ec_label(train_label, test_label, file_save, force_model_update=False): 70 | if os.path.exists(file_save) and (force_model_update==False): 71 | print('ec label dict already exist') 72 | return 73 | ecset = sorted( set(list(train_label) + list(test_label))) 74 | ec_label_dict = {k: v for k, v in zip(ecset, range(len(ecset)))} 75 | np.save(file_save, ec_label_dict) 76 | print('字典保存成功') 77 | return ec_label_dict 78 | 79 | 80 | def train_ec(X, Y, model_file, force_model_update=False, epochs=1): 81 | if os.path.exists(model_file) and (force_model_update==False): 82 | return 83 | else: 84 | x_train, x_vali, y_train, y_vali = train_test_split(np.array(X), np.array(Y), test_size=0.3, shuffle=False) 85 | x_train = x_train.reshape(x_train.shape[0],1,-1) 86 | x_vali = x_vali.reshape(x_vali.shape[0],1,-1) 87 | instant_model = bcommon.mgru_attion_model(input_dimensions=X.shape[1], gru_h_size=512, dropout=0.2, lossfunction='categorical_crossentropy', 88 | evaluation_metrics='accuracy', activation_method = 'softmax', output_dimensions=Y.shape[1] ) 89 | tbCallBack = TensorBoard(log_dir=f'{cfg.TEMPDIR}model/task3', histogram_freq=1,write_grads=True) 90 | checkpoint = ModelCheckpoint(filepath=model_file, monitor='val_accuracy', mode='auto', save_best_only='True') 91 | instant_model.fit(x_train, y_train, validation_data=(x_vali, y_vali), batch_size=3948, epochs= epochs, callbacks=[tbCallBack, checkpoint]) 92 | # 保存 93 | 94 | print(f'train EC model finished, best model saved to: {model_file}') 95 | 96 | 97 | 98 | if __name__ =="__main__": 99 | 100 | EMBEDDING_METHOD = 'esm32' 101 | 102 | # 1. read tranning data 103 | print('step 1 loading task data') 104 | 105 | # please specific the tranning data for each tasks 106 | data_task1_train = pd.read_feather(cfg.FILE_TASK1_TRAIN) 107 | data_task2_train = pd.read_feather(cfg.FILE_TASK2_TRAIN) 108 | data_task3_train = pd.read_feather(cfg.FILE_TASK3_TRAIN) 109 | 110 | # 2. read tranning feature 111 | print(f'step 2: Loading features, embdding method={EMBEDDING_METHOD}') 112 | feature_df = bcommon.load_data_embedding(embedding_type=EMBEDDING_METHOD) 113 | 114 | #3. train task-1 model 115 | print('step 3: train isEnzyme model') 116 | task1_X, task1_Y = get_train_X_Y(traindata=data_task1_train, feature_bankfile=feature_df, task=1) 117 | train_isenzyme(X=task1_X, Y=task1_Y, model_file= cfg.ISENZYME_MODEL, force_model_update=cfg.UPDATE_MODEL, epochs=400) 118 | 119 | #4. task2 train 120 | print('step 4: train how many enzymes model') 121 | task2_X, task2_Y = get_train_X_Y(traindata=data_task2_train, feature_bankfile=feature_df, task=2) 122 | train_howmany_enzyme(X=task2_X, Y=task2_Y, model_file=cfg.HOWMANY_MODEL, force_model_update=cfg.UPDATE_MODEL, epochs=400) 123 | 124 | #5. 
task3 train 125 | print('step 5 train EC model') 126 | task3_X, task3_Y = get_train_X_Y(traindata=data_task3_train, feature_bankfile=feature_df, task=3) 127 | train_ec(X=task3_X, Y=task3_Y, model_file=cfg.EC_MODEL, force_model_update=cfg.UPDATE_MODEL, epochs=400) 128 | 129 | 130 | print('train finished') 131 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Zhenkun Shi 3 | Date: 2020-06-05 05:10:25 4 | LastEditors: Zhenkun Shi 5 | LastEditTime: 2023-08-25 14:32:54 6 | FilePath: /ECRECer/config.py 7 | Description: 8 | 9 | Copyright (c) 2022 by tibd, All Rights Reserved. 10 | ''' 11 | 12 | import os 13 | 14 | 15 | # 1. 定义数据目录 16 | ROOTDIR= f'/hpcfs/fhome/shizhenkun/codebase/ECRECer/' #change to your own absolute data dir 17 | DATADIR = ROOTDIR +'data/' 18 | RESULTSDIR = ROOTDIR +'results/' 19 | MODELDIR = ROOTDIR +'model' 20 | TEMPDIR =ROOTDIR +'tmp/' 21 | DIR_UNIPROT = DATADIR + 'uniprot/' 22 | DIR_DATASETS = DATADIR +'datasets/' 23 | DIR_FEATURES = DATADIR + 'featureBank/' 24 | DIR_DICT = DATADIR +'dict/' 25 | 26 | 27 | #2.URL 28 | URL_SPROT_SNAP201802 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_releases/release-2018_02/knowledgebase/uniprot_sprot-only2018_02.tar.gz' 29 | URL_SPROT_SNAP201902 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_major_releases/release-2019_02/knowledgebase/uniprot_sprot-only2019_02.tar.gz' 30 | URL_SPROT_SNAP202006 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_releases/release-2020_06/knowledgebase/uniprot_sprot-only2020_06.tar.gz' 31 | URL_SPROT_SNAP202102 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_major_releases/release-2021_02/knowledgebase/uniprot_sprot-only2021_02.tar.gz' 32 | URL_SPROT_SNAP202202 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_releases/release-2022_02/knowledgebase/uniprot_sprot-only2022_02.tar.gz' 33 | URL_SPROT_LATEST=f'https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.dat.gz' 34 | 35 | #3.FILES 36 | FILE_SPROT_SNAP201802=DIR_UNIPROT+'uniprot_sprot-only2018_02.tar.gz' 37 | FILE_SPROT_SNAP201902=DIR_UNIPROT+'uniprot_sprot-only2019_02.tar.gz' 38 | FILE_SPROT_SNAP202006=DIR_UNIPROT+'uniprot_sprot-only2020_06.tar.gz' 39 | FILE_SPROT_SNAP202102=DIR_UNIPROT+'uniprot_sprot-only2021_02.tar.gz' 40 | FILE_SPROT_SNAP202202=DIR_UNIPROT+'uniprot_sprot-only2022_02.tar.gz' 41 | 42 | FILE_SPROT_LATEST=DIR_UNIPROT+'uniprot_sprot_leatest.dat.gz' 43 | 44 | FILE_FEATURE_UNIREP = DIR_FEATURES + 'embd_unirep.feather' 45 | FILE_FEATURE_ESM0 = DIR_FEATURES + 'embd_esm0.feather' 46 | FILE_FEATURE_ESM32 = DIR_FEATURES + 'embd_esm32.feather' 47 | FILE_FEATURE_ESM33 = DIR_FEATURES + 'embd_esm33.feather' 48 | FILE_FEATURE_ONEHOT = DIR_FEATURES + 'embd_onehot.feather' 49 | 50 | 51 | FILE_TASK1_TRAIN = DIR_DATASETS + 'task1/train.feather' 52 | FILE_TASK1_TEST_2019 = DIR_DATASETS + 'task1/test_2019.feather' 53 | FILE_TASK1_TEST_2020 = DIR_DATASETS + 'task1/test_2020.feather' 54 | FILE_TASK1_TEST_2021 = DIR_DATASETS + 'task1/test_2021.feather' 55 | FILE_TASK1_TEST_2022 = DIR_DATASETS + 'task1/test_2022.feather' 56 | 57 | FILE_TASK2_TRAIN = DIR_DATASETS + 'task2/train.feather' 58 | FILE_TASK2_TEST_2019 = DIR_DATASETS + 'task2/test_2019.feather' 59 | FILE_TASK2_TEST_2020 = DIR_DATASETS + 'task2/test_2020.feather' 60 | FILE_TASK2_TEST_2021 = DIR_DATASETS + 'task2/test_2021.feather' 61 | FILE_TASK2_TEST_2022 = 
DIR_DATASETS + 'task2/test_2022.feather' 62 | 63 | 64 | FILE_TASK3_TRAIN = DIR_DATASETS + 'task3/train.feather' 65 | FILE_TASK3_TEST_2019 = DIR_DATASETS + 'task3/test_2019.feather' 66 | FILE_TASK3_TEST_2020 = DIR_DATASETS + 'task3/test_2020.feather' 67 | FILE_TASK3_TEST_2021 = DIR_DATASETS + 'task3/test_2021.feather' 68 | FILE_TASK3_TEST_2022 = DIR_DATASETS + 'task3/test_2022.feather' 69 | 70 | FILE_TASK1_TRAIN_FASTA = DIR_DATASETS +'task1/train.fasta' 71 | FILE_TASK1_TEST_2019_FASTA = DIR_DATASETS +'task1/test_2019.fasta' 72 | FILE_TASK1_TEST_2020_FASTA = DIR_DATASETS +'task1/test_2020.fasta' 73 | FILE_TASK1_TEST_2021_FASTA = DIR_DATASETS +'task1/test_2021.fasta' 74 | FILE_TASK1_TEST_2022_FASTA = DIR_DATASETS +'task1/test_2022.fasta' 75 | 76 | FILE_TASK2_TRAIN_FASTA = DIR_DATASETS +'task2/train.fasta' 77 | FILE_TASK2_TEST_2019_FASTA = DIR_DATASETS +'task2/test_2019.fasta' 78 | FILE_TASK2_TEST_2020_FASTA = DIR_DATASETS +'task2/test_2020.fasta' 79 | FILE_TASK2_TEST_2021_FASTA = DIR_DATASETS +'task2/test_2021.fasta' 80 | FILE_TASK2_TEST_2022_FASTA = DIR_DATASETS +'task2/test_2022.fasta' 81 | 82 | FILE_TASK3_TRAIN_FASTA = DIR_DATASETS +'task3/train.fasta' 83 | FILE_TASK3_TEST_2019_FASTA = DIR_DATASETS +'task3/test_2019.fasta' 84 | FILE_TASK3_TEST_2020_FASTA = DIR_DATASETS +'task3/test_2020.fasta' 85 | FILE_TASK3_TEST_2021_FASTA = DIR_DATASETS +'task3/test_2021.fasta' 86 | FILE_TASK3_TEST_2022_FASTA = DIR_DATASETS +'task3/test_2022.fasta' 87 | 88 | 89 | FILE_CASE_SHEWANELLA_FASTA=DATADIR+'shewanella.faa' 90 | 91 | 92 | TRAIN_FEATURE = DATADIR+'train.feather' 93 | TEST_FEATURE = DATADIR+'test.feather' 94 | TRAIN_FASTA = DATADIR+'train.fasta' 95 | TEST_FASTA = DATADIR+'test.fasta' 96 | 97 | FILE_LATEST_SPROT = DATADIR + 'uniprot_sprot_latest.dat.gz' 98 | FILE_LATEST_TREMBL = DATADIR + 'uniprot_trembl_latest.dat.gz' 99 | 100 | FILE_LATEST_SPROT_FEATHER = DATADIR + 'uniprot/sprot_latest.feather' 101 | FILE_LATEST_TREMBL_FEATHER = DATADIR + 'uniprot/trembl_latest.feather' 102 | 103 | 104 | FILE_EC_LABEL_DICT = DATADIR + 'ec_label_dict.npy' 105 | FILE_BLAST_TRAIN_DB = DATADIR + 'train_blast.dmnd' # blast比对数据库 106 | FILE_BLAST_PRODUCTION_DB = DATADIR + 'uniprot_blast_db/production_blast.dmnd' # 生产环境比对数据库 107 | FILE_BLAST_PRODUCTION_FASTA = DATADIR + 'production_blast.fasta' # 生产环境比对数据库 108 | FILE_TRANSFER_DICT = DATADIR + 'ec_transfer_dict.npy' 109 | 110 | 111 | 112 | ISENZYME_MODEL = MODELDIR+'/isenzyme.h5' 113 | HOWMANY_MODEL = MODELDIR+'/howmany_enzyme.h5' 114 | EC_MODEL = MODELDIR+'/ec.h5' 115 | 116 | 117 | FILE_BLAST_RESULTS = RESULTSDIR + r'test_blast_res.tsv' 118 | FILE_BLAST_ISENAYME_RESULTS = RESULTSDIR +r'isEnzyme_blast_results.tsv' 119 | FILE_BLAST_EC_RESULTS = RESULTSDIR +r'ec_blast_results.tsv' 120 | 121 | 122 | FILE_DEEPEC_RESULTS = RESULTSDIR + r'deepec/DeepEC_Result.txt' 123 | FILE_ECPRED_RESULTS = RESULTSDIR + r'ecpred/ecpred.tsv' 124 | FILE_CATFAM_RESULTS = RESULTSDIR + r'catfam_results.output' 125 | FILE_PRIAM_RESULTS = RESULTSDIR + R'priam/PRIAM_20210819134344/ANNOTATION/sequenceECs.txt' 126 | 127 | FILE_EVL_RESULTS = RESULTSDIR + r'evaluation_table.xlsx' 128 | 129 | UPDATE_MODEL = True #强制模型更新标志 130 | EMBEDDING_METHOD={ 'one-hot':1, 131 | 'unirep':2, 132 | 'esm0':3, 133 | 'esm32':4, 134 | 'esm33':5 135 | } 136 | 137 | # 138 | BLAST_TABLE_HEAD = ['id', 139 | 'sseqid', 140 | 'pident', 141 | 'length', 142 | 'mismatch', 143 | 'gapopen', 144 | 'qstart', 145 | 'qend', 146 | 'sstart', 147 | 'send', 148 | 'evalue', 149 | 'bitscore' 150 | ] 151 | 152 | 153 | DICT_LABEL_T1 = 
DIR_DICT+'dict_label_task1.h5'
154 | DICT_LABEL_T2 = DIR_DICT+'dict_label_task2.h5'
155 | DICT_LABEL_T3 = DIR_DICT+'dict_label_task3.h5'
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------

# DATA FOLDER

## Folder Structure

rootfolder/data/datasets/
rootfolder/data/featureBank/
rootfolder/data/dict/

## Get Data

Benchmark task datasets can be downloaded from AWS S3.

### Put in datasets dir
```
wget https://tibd-public-datasets.s3.amazonaws.com/ecrecer/ecrecer_datasets.zip
```

### Put in featureBank dir
```
wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_unirep.feather
wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_onehot.feather
wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_esm33.feather
wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_esm32.feather
wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_esm0.feather
```
--------------------------------------------------------------------------------
/data/dict/README.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/dict/dict_label_task1.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kingstdio/ECRECer/5ccd2c9d163e20faf813d829a868dfa8a0fedc93/data/dict/dict_label_task1.h5
--------------------------------------------------------------------------------
/data/dict/dict_label_task2.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kingstdio/ECRECer/5ccd2c9d163e20faf813d829a868dfa8a0fedc93/data/dict/dict_label_task2.h5
--------------------------------------------------------------------------------
/data/dict/dict_label_task3.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kingstdio/ECRECer/5ccd2c9d163e20faf813d829a868dfa8a0fedc93/data/dict/dict_label_task3.h5
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
name: ECRECer
channels:
  - bjrn
  - bioconda
  - numba
  - rapidsai
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=4.5=1_gnu
  - _py-xgboost-mutex=2.0=cpu_0
  - abseil-cpp=20210324.2=h9c3ff4c_0
  - anyio=3.3.0=py37h89c1867_0
  - argcomplete=1.12.3=pyhd8ed1ab_2
  - arrow-cpp=4.0.1=py37haaf8ab1_7_cuda
  - arrow-cpp-proc=3.0.0=cuda
  - async_generator=1.10=pyhd3eb1b0_0
  - attrs=21.2.0=pyhd3eb1b0_0
  - babel=2.9.1=pyhd3eb1b0_0
  - backcall=0.2.0=pyhd3eb1b0_0
  - biopython=1.79=py37h5e8e339_0
  - blas=1.0=mkl
  - bleach=3.3.0=pyhd3eb1b0_0
  - brotlipy=0.7.0=py37h5e8e339_1001
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.17.1=h7f98852_1
  - cachetools=4.2.2=pyhd8ed1ab_0
  - colorama=0.4.4=pyh9f0ad1d_0
  - cryptography=3.4.7=py37h5d9358c_0
  - cudatoolkit
cudf=21.08.02=cuda_11.2_py37_gf6d31fa95d_0 34 | - cupy=9.4.0=py37hf4b2161_0 35 | - cycler=0.10.0=py_2 36 | - daal4py=2021.2.2=py37ha9443f7_0 37 | - dal=2021.2.2=h06a4308_389 38 | - dbus=1.13.18=hb2f20db_0 39 | - debugpy=1.4.1=py37hcd2ae1e_0 40 | - decorator=5.0.9=pyhd3eb1b0_0 41 | - defusedxml=0.7.1=pyhd3eb1b0_0 42 | - dill=0.3.4=pyhd3eb1b0_0 43 | - dlpack=0.5=h9c3ff4c_0 44 | - entrypoints=0.3=pyhd8ed1ab_1003 45 | - et_xmlfile=1.1.0=py37h06a4308_0 46 | - expat=2.4.1=h2531618_2 47 | - fastavro=1.4.4=py37h5e8e339_0 48 | - fastrlock=0.6=py37hcd2ae1e_1 49 | - fontconfig=2.13.1=h6c09931_0 50 | - freetype=2.10.4=h5ab3b9f_0 51 | - fsspec=2021.8.1=pyhd8ed1ab_0 52 | - gawk=5.1.0=h7b6447c_0 53 | - gflags=2.2.2=he1b5a44_1004 54 | - glib=2.68.2=h36276a3_0 55 | - glog=0.5.0=h48cff8f_0 56 | - grpc-cpp=1.39.0=hf1f433d_2 57 | - gst-plugins-base=1.14.0=h8213a91_2 58 | - gstreamer=1.14.0=h28cd5cc_2 59 | - icu=58.2=he6710b0_3 60 | - idna=2.10=pyhd3eb1b0_0 61 | - ipykernel 62 | - ipyparallel=7.0.1=pyhd8ed1ab_0 63 | - ipython=7.27.0=py37h6531663_0 64 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 65 | - jbig=2.1=h7f98852_2003 66 | - jedi=0.18.0=py37h89c1867_2 67 | - joblib=1.0.1=pyhd3eb1b0_0 68 | - jpeg=9d=h36c2ea0_0 69 | - json5=0.9.5=py_0 70 | - jsonschema=3.2.0=py_2 71 | - jupyterlab 72 | - kiwisolver=1.3.1=py37h2527ec5_1 73 | - krb5=1.17.1=h173b8e3_0 74 | - lcms2=2.12=h3be6417_0 75 | - ld_impl_linux-64=2.35.1=h7274673_9 76 | - lz4-c=1.9.3=h2531618_0 77 | - markupsafe=2.0.1=py37h5e8e339_0 78 | - mistune=0.8.4=py37h5e8e339_1004 79 | - mmseqs2=13.45111=h2d02072_0 80 | - mpi=1.0=mpich 81 | - mpich=3.3.2=hc856adb_0 82 | - mypy=0.910=pyhd3eb1b0_0 83 | - mypy_extensions=0.4.3=py37_0 84 | - ncurses=6.2=he6710b0_1 85 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 86 | - nodejs=6.11.2=h3db8ef7_0 87 | - notebook=6.0.3=py37hc8dfbb8_1 88 | - numba=0.54.0=np1.11py3.7hc13618b_g5888895d3_0 89 | - nvtx=0.2.3=py37h5e8e339_0 90 | - olefile=0.46=py_0 91 | - openpyxl=3.0.9=pyhd3eb1b0_0 92 | - openssl=1.1.1l=h7f8727e_0 93 | - orc=1.6.9=h58a87f1_0 94 | - packaging=20.9=pyhd3eb1b0_0 95 | - pandarallel=1.4.8=py37h39e3cac_0 96 | - pandas=1.2.5=py37h219a48f_0 97 | - pandoc=2.12=h06a4308_0 98 | - pandocfilters=1.4.2=py_1 99 | - parquet-cpp=1.5.1=2 100 | - parso=0.8.2=pyhd3eb1b0_0 101 | - pcre=8.44=he6710b0_0 102 | - pexpect=4.8.0=pyhd3eb1b0_3 103 | - pickleshare=0.7.5=pyhd3eb1b0_1003 104 | - pip=21.2.4=pyhd8ed1ab_0 105 | - prometheus_client=0.11.0=pyhd3eb1b0_0 106 | - prompt-toolkit=3.0.17=pyh06a4308_0 107 | - protobuf=3.16.0=py37hcd2ae1e_0 108 | - psutil=5.8.0=py37h5e8e339_1 109 | - psycopg2=2.8.5=py37hb09aad4_1 110 | - ptyprocess=0.7.0=pyhd3eb1b0_2 111 | - py-xgboost=1.3.3=py37h06a4308_0 112 | - pyarrow=4.0.1=py37h3373063_7_cuda 113 | - pycparser=2.20=py_2 114 | - pygments=2.9.0=pyhd3eb1b0_0 115 | - pyopenssl=20.0.1=pyhd3eb1b0_1 116 | - pyparsing=2.4.7=pyhd3eb1b0_0 117 | - pyrsistent=0.17.3=py37h5e8e339_2 118 | - pysocks=1.7.1=py37h89c1867_3 119 | - python=3.7.10=hffdb5ce_100_cpython 120 | - python-dateutil=2.8.1=pyhd3eb1b0_0 121 | - python_abi=3.7=2_cp37m 122 | - pytz=2021.1=pyhd3eb1b0_0 123 | - pyyaml=5.4.1=py37h5e8e339_0 124 | - pyzmq=22.1.0=py37h336d617_0 125 | - qt=5.9.7=h5867ecd_1 126 | - re2=2021.08.01=h9c3ff4c_0 127 | - readline=8.1=h27cfd23_0 128 | - requests=2.25.1=pyhd3eb1b0_0 129 | - scikit-learn=0.24.2=py37ha9443f7_0 130 | - scikit-learn-intelex=2021.2.2=py37h89c1867_1 131 | - scipy=1.6.2=py37had2a1c9_1 132 | - send2trash=1.5.0=pyhd3eb1b0_1 133 | - setuptools=58.0.4=py37h89c1867_0 134 | - snappy=1.1.8=he1b5a44_3 135 | - 
sniffio=1.2.0=py37h89c1867_1 136 | - spdlog=1.8.5=h4bd325d_0 137 | - sqlalchemy=1.4.0=py37h5e8e339_0 138 | - sqlite=3.35.4=hdfb4753_0 139 | - tbb=2021.2.0=hff7bd54_0 140 | - terminado=0.12.1=py37h89c1867_0 141 | - testpath=0.4.4=pyhd3eb1b0_0 142 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 143 | - tk=8.6.10=hbc83047_0 144 | - toml=0.10.2=pyhd3eb1b0_0 145 | - tornado=6.1=py37h5e8e339_1 146 | - tqdm=4.62.2=pyhd8ed1ab_0 147 | - traitlets=5.0.5=pyhd3eb1b0_0 148 | - urllib3=1.26.4=pyhd3eb1b0_0 149 | - wcwidth=0.2.5=py_0 150 | - webencodings=0.5.1=py_1 151 | - wget=1.20.1=h20c2e04_0 152 | - wheel=0.36.2=pyhd3eb1b0_0 153 | - xgboost=1.3.3=py37h06a4308_0 154 | - xz=5.2.5=h7b6447c_0 155 | - yaml=0.2.5=h516909a_0 156 | - pip: 157 | - absl-py==0.14.1 158 | - alembic==1.7.4 159 | - astunparse==1.6.3 160 | - autopage==0.4.0 161 | - cached-property==1.5.2 162 | - clang==5.0 163 | - cliff==3.9.0 164 | - cmaes==0.8.2 165 | - fair-esm==0.4.0 166 | - gast==0.4.0 167 | - h5py==3.1.0 168 | - keras==2.6.0 169 | - numpy==1.19.5 170 | - pillow==8.3.2 171 | - prettytable==2.2.1 172 | - pyasn1==0.4.8 173 | - pyasn1-modules==0.2.8 174 | - pybind11==2.6.1 175 | - pyperclip==1.8.2 176 | - requests-oauthlib==1.3.0 177 | - rsa==4.7.2 178 | - six==1.15.0 179 | - stevedore==3.4.0 180 | - tensorboard==2.7.0 181 | - tensorboard-data-server==0.6.1 182 | - tensorboard-plugin-wit==1.8.0 183 | - tensorflow==2.6.1 184 | - tensorflow-estimator==2.6.0 185 | - termcolor==1.1.0 186 | - torch==1.9.1+cu111 187 | - typing-extensions==3.7.4.3 188 | - werkzeug==2.0.2 189 | - wrapt==1.12.1 190 | prefix: /hpcfs/fhome/shizhenkun/miniconda3/envs/ECRECer 191 | -------------------------------------------------------------------------------- /model/README.md: -------------------------------------------------------------------------------- 1 | 11 | 12 | # Model FOLDER 13 | 14 | ## Folder Structure 15 | 16 | rootfolder/model 17 | 18 | ## Get Data 19 | 20 | Trained models can be downloaded from AWS S3. 
21 | 
22 | ```
23 | wget https://tibd-public-datasets.s3.amazonaws.com/ecrecer/model/ec.h5
24 | wget https://tibd-public-datasets.s3.amazonaws.com/ecrecer/model/howmany_enzyme.h5
25 | wget https://tibd-public-datasets.s3.amazonaws.com/ecrecer/model/isenzyme.h5
26 | ```
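27 | 
28 | ## Load Models
29 | 
30 | A minimal sketch of loading the downloaded checkpoints (the paths and the custom `Attention` layer come from this repository, see `production.py`; the snippet itself is illustrative):
31 | 
32 | ```python
33 | from keras.models import load_model
34 | from tools.Attention import Attention
35 | import config as cfg
36 | 
37 | # the checkpoints contain a custom layer, so pass it via custom_objects
38 | model_isenzyme = load_model(cfg.ISENZYME_MODEL, custom_objects={"Attention": Attention}, compile=False)
39 | model_howmany = load_model(cfg.HOWMANY_MODEL, custom_objects={"Attention": Attention}, compile=False)
40 | model_ec = load_model(cfg.EC_MODEL, custom_objects={"Attention": Attention}, compile=False)
41 | ```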
--------------------------------------------------------------------------------
/production.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import joblib
4 | import os,sys
5 | import benchmark_common as bcommon
6 | import config as cfg
7 | import argparse
8 | import tools.funclib as funclib
9 | from tools.Attention import Attention
10 | from keras.models import load_model
11 | import tools.embedding_esm as esmebd
12 | import time
13 | from pandarallel import pandarallel # import pandarallel
14 | 
15 | 
16 | #region Integrate output
17 | def integrate_out_put(existing_table, blast_table, dmlf_pred_table, mode='p', topnum=1):
18 |     """[Integrate output]
19 | 
20 |     Args:
21 |         existing_table ([DataFrame]): [db search results table]
22 |         blast_table ([DataFrame]): [sequence alignment results table]
23 |         dmlf_pred_table ([DataFrame]): [DMLF prediction results table]
24 |         mode ([str]): ['p' for annotation, 'r' for recommendation]
25 |         topnum ([int]): [number of top EC recommendations to keep]
26 | 
27 |     Returns:
28 |         [DataFrame]: [final results]
29 |     """
30 |     existing_table['res_type'] = 'db_match'
31 |     blast_table['res_type']='blast_match'
32 |     results_df = ec_table.merge(blast_table, on='id', how='left')
33 | 
34 |     function_df = how_many_table.copy()
35 |     function_df = function_df.merge(isEnzyme_pred_table, on='id', how='left')
36 |     function_df = function_df.merge(blast_table[['id', 'ec_number']], on='id', how='left')
37 |     function_df['pred_function_counts']=function_df.parallel_apply(lambda x :integrate_enzyme_functioncounts(x.ec_number, x.isEnzyme_pred, x.pred_s, x.pred_m), axis=1)
38 |     results_df = results_df.merge(function_df[['id','pred_function_counts']],on='id',how='left')
39 | 
40 |     results_df.loc[results_df[results_df.res_type.isnull()].index,'res_type']='dmlf_pred'
41 |     results_df['pred_ec']=results_df.parallel_apply(lambda x: gather_ec_by_fc(x.iloc[3:23],x.ec_number, x.pred_function_counts), axis=1)
42 |     results_df = results_df.iloc[:,np.r_[0,23,1,2,32,27:31]].rename(columns={'seq_x':'seq','seqlength_x':'seqlength'})
43 | 
44 | 
45 |     if mode=='p':
46 |         existing_table['pred_ec']=''
47 |         result_set = pd.concat([existing_table, results_df], axis=0)
48 |         result_set = result_set.drop_duplicates(subset=['id'], keep='first').sort_values(by='res_type')
49 |         result_set['ec_number'] = result_set.apply(lambda x: x.pred_ec if str(x.ec_number)=='nan' else x.ec_number, axis=1)
50 |         result_set.reset_index(drop=True, inplace=True)
51 |         result_set = result_set.iloc[:,0:9]
52 | 
53 |         result_set['seqlength'] = result_set.seq.apply(lambda x: len(x))
54 |         result_set['ec_number'] = result_set.ec_number.apply(lambda x: 'Non-Enzyme' if len(x)==1 else x)
55 |         result_set = result_set.rename(columns={'ec_number':'ecrecer_pred_ec_number'})
56 | 
57 |         result_set = result_set[['id','ecrecer_pred_ec_number','seq','seqlength']]
58 | 
59 |     if mode =='r':
60 |         result_set= results_df.merge(ec_table, on=['id'], how='left')
61 |         result_set=result_set.iloc[:,np.r_[0:3,30,5:9, 4,10:30]]
62 |         result_set = result_set.rename(columns=dict({'seq_x': 'seq','pred_ec': 'top0','top0_y': 'top1' }, **{'top'+str(i) : 'top'+str(i+1) for i in range(0, 20)}))
63 |         # result_set = result_set.iloc[:,0:(8+topnum)]
64 |         # result_set.loc[result_set[result_set.id.isin(existing_table.id)].index.values,'res_type']= 'db_match'
65 | 
66 |         result_set = result_set.iloc[:,np.r_[0, 2:4,8:(8+topnum)]]
67 | 
68 |     return result_set
69 | 
70 | #endregion
71 | 
72 | #region Predict Function Counts
73 | def predict_function_counts(test_data):
74 |     """[Predict Function Counts]
75 | 
76 |     Args:
77 |         test_data ([DataFrame]): [DF containing protein ID and Seq]
78 | 
79 |     Returns:
80 |         [DataFrame]: [col1: id, col2: single or multi; col3: multi counts]
81 |     """
82 |     res=pd.DataFrame()
83 |     res['id']=test_data.id
84 |     model_s = joblib.load(cfg.MODELDIR+'/single_multi.model')
85 |     model_m = joblib.load(cfg.MODELDIR+'/multi_many.model')
86 |     pred_s=model_s.predict(np.array(test_data.iloc[:,1:]))
87 |     pred_m=model_m.predict(np.array(test_data.iloc[:,1:]))
88 |     res['pred_s']=1-pred_s
89 |     res['pred_m']=pred_m+2
90 | 
91 |     return res
92 | #endregion
93 | 
94 | #region Integrate function counts by blast, single and multi
95 | def integrate_enzyme_functioncounts(blast, isEnzyme, single, multi):
96 |     """[Integrate function counts by blast, single and multi]
97 | 
98 |     Args:
99 |         blast ([string]): [blast-matched EC numbers, comma separated]
100 |         isEnzyme ([int]): [isEnzyme prediction]
101 |         single/multi ([int]): [single- and multi-function predictions]
102 | 
103 |     Returns:
104 |         [int]: [number of predicted functions]
105 |     """
106 |     if str(blast)!='nan':
107 |         if str(blast)=='-':
108 |             return 0
109 |         else:
110 |             return len(blast.split(','))
111 |     if isEnzyme == 0:
112 |         return 0
113 |     if single ==1:
114 |         return 1
115 |     return multi
116 | #endregion
117 | 
118 | #region format final ec by function counts
119 | def gather_ec_by_fc(toplist, ec_blast ,counts):
120 |     """[format final ec by function counts]
121 | 
122 |     Args:
123 |         toplist ([list]): [top 20 predicted EC]
124 |         ec_blast ([string]): [blast results]
125 |         counts ([int]): [function counts]
126 | 
127 |     Returns:
128 |         [string]: [comma separated ec string]
129 |     """
130 |     if counts==0:
131 |         return '-'
132 |     elif str(ec_blast)!='nan':
133 |         return str(ec_blast)
134 |     else:
135 |         return ','.join(toplist[0:counts])
136 | #endregion
137 | 
138 | 
139 | 
140 | 
141 | #region run
142 | def step_by_step_run(input_fasta, output_tsv, mode='p', topnum=1):
143 |     """[run]
144 |     Args:
145 |         input_fasta ([string]): [input fasta file]
146 |         output_tsv ([string]): [output tsv file]
147 |     """
148 |     start = time.process_time()
149 |     if mode =='p':
150 |         print('run in annotation mode')
151 |     if mode =='r':
152 |         print('run in recommendation mode')
153 |     if mode =='h':
154 |         print('run in hybrid mode')
155 | 
156 |     # 1. Load input data
157 |     print('step 1: loading data')
158 |     input_df = funclib.load_fasta_to_table(input_fasta) # test fasta
159 |     latest_sprot = pd.read_feather(cfg.FILE_LATEST_SPROT_FEATHER) #sprot db
160 | 
161 | 
162 |     # 2. Look up existing records
163 |     print('step 2: find existing data')
164 |     find_data = input_df.merge(latest_sprot, on='seq', how='left')
165 |     find_data = latest_sprot[latest_sprot.seq.isin(input_df.seq)]
166 |     find_data = find_data.drop_duplicates(subset='seq').reset_index(drop=True)
167 |     exist_data = find_data.merge(input_df, on='seq', how='left').iloc[:,np.r_[8,0,1:8]].rename(columns={'id_x':'uniprot_id','id_y':'input_id'}).reset_index(drop=True)
168 |     noExist_data = input_df[~input_df.seq.isin(find_data.seq)]
169 | 
170 |     if len(noExist_data) == 0 and mode=='p':
171 |         exist_data=exist_data[['input_id','ec_number']].rename(columns={'ec_number':'ec_pred'})
172 |         exist_data.to_csv(output_tsv, sep='\t', index=False)
173 |         end = time.process_time()
174 |         print('All done running time: %s Seconds'%(end-start))
175 |         return
176 | 
177 | 
178 |     # 3. Embedding
179 |     print('step 3: Embedding')
180 |     featurebank_esm32 = pd.read_feather(cfg.FILE_FEATURE_ESM32)
181 | 
182 |     existing_feature = featurebank_esm32[featurebank_esm32.id.isin(exist_data.uniprot_id)]
183 |     existing_feature = exist_data[['input_id','uniprot_id']].merge(existing_feature.rename(columns={'id':'uniprot_id'}), on='uniprot_id', how='left').rename(columns={'input_id':'id'}).iloc[:,np.r_[0,2:existing_feature.shape[1]+1]]
184 | 
185 |     rep0, rep32, rep33 = esmebd.get_rep_multi_sequence(sequences=noExist_data, model='esm1b_t33_650M_UR50S',seqthres=1022)
186 | 
187 |     rep32 = pd.concat([existing_feature,rep32],axis=0).reset_index(drop=True)
188 | 
189 |     print('step 4: run prediction')
190 | 
191 |     if mode=='p':
192 | 
193 |         # 5. isEnzyme Prediction
194 |         print('step 5: predict isEnzyme')
195 |         pred_dmlf = pd.DataFrame(rep32.id.copy())
196 |         model_isEnzyme = load_model(cfg.ISENZYME_MODEL,custom_objects={"Attention": Attention}, compile=False)
197 |         predicted = model_isEnzyme.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
198 |         encoder_t1=joblib.load(cfg.DICT_LABEL_T1)
199 | 
200 |         pred_dmlf['dmlf_isEnzyme']=(encoder_t1.inverse_transform(bcommon.props_to_onehot(predicted))).reshape(1,-1)[0]
201 | 
202 | 
203 |         # 6. How many Prediction
204 |         print('step 6: predict function counts')
205 |         model_howmany = load_model(cfg.HOWMANY_MODEL,custom_objects={"Attention": Attention}, compile=False)
206 |         predicted = model_howmany.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
207 |         encoder_t2=joblib.load(cfg.DICT_LABEL_T2)
208 |         pred_dmlf['dmlf_howmany']=(encoder_t2.inverse_transform(bcommon.props_to_onehot(predicted))).reshape(1,-1)[0]
209 | 
210 | 
211 |         # 7. EC Prediction
212 |         print('step 7: predict EC')
213 |         model_ec = load_model(cfg.EC_MODEL,custom_objects={"Attention": Attention}, compile=False)
214 |         predicted = model_ec.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
215 |         encoder_t3=joblib.load(cfg.DICT_LABEL_T3)
216 |         pred_dmlf['dmlf_ec']=[','.join(item) for item in (encoder_t3.inverse_transform(bcommon.props_to_onehot(predicted)))]
217 | 
218 | 
219 |         print('step 8: integrate results')
220 |         results = pred_dmlf.merge(exist_data, left_on='id',right_on='input_id', how='left')
221 |         results=results.fillna('#')
222 |         results['ec_pred'] =results.apply(lambda x : x.ec_number if x.ec_number!='#' else ('-' if x.dmlf_isEnzyme==False else x.dmlf_ec) ,axis=1)
223 |         output_df = results[['id', 'ec_pred']].rename(columns={'id':'id_input'})
224 | 
225 |     elif mode =='r':
226 |         # print('step 4: recommendation')
227 |         # label_model_ec = pd.read_feather(f'{cfg.MODELDIR}/task3_labels.feather').label_multi.to_list()
228 |         # model_ec = load_model(f'{cfg.MODELDIR}/task3_esm32_2022.h5',custom_objects={"Attention": Attention}, compile=False)
229 |         # predicted = model_ec.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
230 |         # output_df=pd.DataFrame()
231 |         # output_df['id']=input_df['id'].copy()
232 |         # output_df['ec_recomendations']=pd.DataFrame(predicted).apply(lambda x :sorted(dict(zip((label_model_ec), x)).items(),key = lambda x:x[1], reverse = True)[0:topnum], axis=1 ).values
233 | 
234 |         print('step 4: predict EC')
235 |         pred_dmlf = pd.DataFrame(rep32.id.copy())
236 |         model_ec = load_model(cfg.EC_MODEL, custom_objects={"Attention": Attention}, compile=False)
237 |         predicted = model_ec.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
238 |         encoder_t3=joblib.load(cfg.DICT_LABEL_T3)
239 |         pred_dmlf['dmlf_ec']=[','.join(item) for item in (encoder_t3.inverse_transform(bcommon.props_to_onehot(predicted)))]
240 |         pred_dmlf['dmlf_recomendations']=pd.DataFrame(predicted).apply(lambda x :sorted(dict(zip((encoder_t3.classes_), x)).items(),key = lambda x:x[1], reverse = True)[0:topnum], axis=1 ).values
241 |         output_df = pred_dmlf[['id', 'dmlf_recomendations']].rename(columns={'id':'id_input'})
242 | 
243 |     elif mode =='h':
244 |         print('running in hybrid mode')
245 | 
246 |         # 4. sequence alignment
247 |         print('step 4: sequence alignment')
248 |         if not os.path.exists(cfg.FILE_BLAST_PRODUCTION_DB):
249 |             funclib.table2fasta(latest_sprot, cfg.FILE_BLAST_PRODUCTION_FASTA)
250 |             cmd = r'diamond makedb --in {0} -d {1}'.format(cfg.FILE_BLAST_PRODUCTION_FASTA, cfg.FILE_BLAST_PRODUCTION_DB)
251 |             os.system(cmd)
252 |         blast_res = funclib.getblast_usedb(db=cfg.FILE_BLAST_PRODUCTION_DB, test=input_df)
253 |         blast_res =blast_res[['id', 'sseqid']].merge(latest_sprot, left_on='sseqid', right_on='id', how='left').iloc[:,np.r_[0,2,3:10]].rename(columns={'id_x':'input_id','id_y':'uniprot_id'}).reset_index(drop=True)
254 | 
255 |         # 5. isEnzyme Prediction
256 |         print('step 5: predict isEnzyme')
257 |         pred_dmlf = pd.DataFrame(rep32.id.copy())
258 |         model_isEnzyme = load_model(cfg.ISENZYME_MODEL,custom_objects={"Attention": Attention}, compile=False)
259 |         predicted = model_isEnzyme.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
260 |         encoder_t1=joblib.load(cfg.DICT_LABEL_T1)
261 |         pred_dmlf['dmlf_isEnzyme']=(encoder_t1.inverse_transform(bcommon.props_to_onehot(predicted))).reshape(1,-1)[0]
262 | 
263 | 
264 |         # 6. How many Prediction
265 |         print('step 6: predict function counts')
266 |         model_howmany = load_model(cfg.HOWMANY_MODEL,custom_objects={"Attention": Attention}, compile=False)
267 |         predicted = model_howmany.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
268 |         encoder_t2=joblib.load(cfg.DICT_LABEL_T2)
269 |         pred_dmlf['dmlf_functions']=(encoder_t2.inverse_transform(bcommon.props_to_onehot(predicted))).reshape(1,-1)[0]
270 | 
271 | 
272 |         # 7. EC Prediction
273 |         print('step 7: predict EC')
274 |         model_ec = load_model(cfg.EC_MODEL,custom_objects={"Attention": Attention}, compile=False)
275 |         predicted = model_ec.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
276 |         encoder_t3=joblib.load(cfg.DICT_LABEL_T3)
277 |         pred_dmlf['dmlf_ec']=[','.join(item) for item in (encoder_t3.inverse_transform(bcommon.props_to_onehot(predicted)))]
278 |         pred_dmlf['dmlf_recomendations']=pd.DataFrame(predicted).apply(lambda x :sorted(dict(zip((encoder_t3.classes_), x)).items(),key = lambda x:x[1], reverse = True)[0:topnum], axis=1 ).values
279 | 
280 |         pred_dmlf = pred_dmlf.merge(blast_res[['input_id','ec_number']].rename(columns={'ec_number':'blast_ec'}), left_on='id', right_on='input_id', how='left')
281 |         # pred_dmlf['dmlf_recomendations']=pred_dmlf.apply(lambda x: x.dmlf_recomendations if x.dmlf_isEnzyme else '-', axis=1 )
282 |         pred_dmlf['dmlf_ec']=pred_dmlf.apply(lambda x: x.dmlf_ec if x.dmlf_isEnzyme else '-', axis=1 )
283 |         pred_dmlf = pred_dmlf.merge(exist_data[['input_id','ec_number']].rename(columns={'ec_number':'db_ec'}), on='input_id', how='left')
284 |         pred_dmlf['dmlf_ec']=pred_dmlf.apply(lambda x: x.db_ec if str(x.db_ec)!='nan' else x.dmlf_ec,axis=1)
285 |         pred_dmlf['dmlf_isEnzyme']=pred_dmlf.apply(lambda x: True if (str(x.db_ec)!='nan' and x.db_ec!='-') else x.dmlf_isEnzyme,axis=1)
286 |         pred_dmlf['dmlf_functions']=pred_dmlf.apply(lambda x: len(x.db_ec.split(',')) if str(x.db_ec)!='nan' else x.dmlf_functions,axis=1)
287 | 
288 | 
289 |         output_df = pred_dmlf[['id', 'dmlf_isEnzyme', 'dmlf_functions', 'dmlf_ec', 'dmlf_recomendations', 'blast_ec' ]].rename(columns={'id':'input_id'})
290 | 
291 |     else:
292 |         print(f'mode:{mode} not found')
293 |         sys.exit()
294 | 
295 | 
296 |     print('step 9: writing results')
297 | 
298 |     output_df.to_csv(output_tsv, sep='\t', index=False)
299 | 
300 |     print(output_df)
301 | 
302 |     end = time.process_time()
303 |     print('All done running time: %s Seconds'%(end-start))
304 | #endregion
305 | 
306 | 
307 | if __name__ =='__main__':
308 | 
309 |     parser = argparse.ArgumentParser()
310 |     parser.add_argument('-i', help='input file (fasta format)', type=str, default=cfg.DATADIR + 'sample_10.fasta')
311 |     parser.add_argument('-o', help='output file (tsv table)', type=str, default=cfg.RESULTSDIR + 'sample_10_2023_07_18.tsv')
312 |     parser.add_argument('-mode', help='compute mode. 
p: prediction, r: recommendation, h:hybrid', type=str, default='r') 313 | parser.add_argument('-topk', help='recommendation records, min=1, max=20', type=int, default='50') 314 | 315 | pandarallel.initialize() #init 316 | args = parser.parse_args() 317 | input_file = args.i 318 | output_file = args.o 319 | compute_mode = args.mode 320 | topk = args.topk 321 | 322 | step_by_step_run( input_fasta=input_file, 323 | output_tsv=output_file, 324 | mode=compute_mode, 325 | topnum=topk 326 | ) 327 | -------------------------------------------------------------------------------- /tasks/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarking tasks 2 | 3 | ## 1. Is enzyme task 4 | ## 2. How many function task 5 | ## 3. EC number prediction task -------------------------------------------------------------------------------- /tasks/task1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "tags": [] 7 | }, 8 | "source": [ 9 | "# Task1. Enzyme or Non-enzyme Annotation\n", 10 | "\n", 11 | "> author: Shizhenkun \n", 12 | "> email: zhenkun.shi@tib.cas.cn \n", 13 | "> date: 2021-10-20 " 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "## 1. Import packages" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "metadata": { 27 | "tags": [] 28 | }, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "INFO: Pandarallel will run on 52 workers.\n", 35 | "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "import numpy as np\n", 41 | "import pandas as pd\n", 42 | "import time\n", 43 | "import datetime\n", 44 | "import sys\n", 45 | "import os\n", 46 | "from tqdm import tqdm\n", 47 | "from functools import reduce\n", 48 | "import joblib\n", 49 | "\n", 50 | "sys.path.append(\"../tools/\")\n", 51 | "import funclib\n", 52 | "\n", 53 | "sys.path.append(\"../\")\n", 54 | "import benchmark_train as btrain\n", 55 | "import benchmark_test as btest\n", 56 | "import config as cfg\n", 57 | "import benchmark_evaluation as eva\n", 58 | "\n", 59 | "from sklearn.model_selection import train_test_split\n", 60 | "from xgboost import XGBClassifier\n", 61 | "\n", 62 | "from pandarallel import pandarallel # import pandaralle\n", 63 | "pandarallel.initialize() # init\n", 64 | "\n", 65 | "%load_ext autoreload\n", 66 | "%autoreload 2" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## 2. Load data" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 2, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "name": "stdout", 83 | "output_type": "stream", 84 | "text": [ 85 | "train size: 469134\n", 86 | "test size: 7101\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "#read train test data\n", 92 | "train = pd.read_feather(cfg.DATADIR+'task1/train.feather')\n", 93 | "test = pd.read_feather(cfg.DATADIR+'task1/test.feather')\n", 94 | "print('train size: {0}\\ntest size: {1}'.format(len(train), len(test)))" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "data": { 104 | "text/html": [ 105 | "
\n", 106 | "\n", 119 | "\n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | "
idseqisenzyme
124P00958MSFLISFDKSKKHPAHLQLANNLKIALALEYASKNLKPEVDNDNAA...True
175P00812METGPHYNYYKNRELSIVLAPFSGGQGKLGVEKGPKYMLKHGLQTS...True
248P00959MTQVAKKILVTCALPYANGSIHLGHMLEHIQADVWVRYQRMRGHEV...True
249P00348MAFATRQLVRSLSSSSTAAASAKKILVKHVTVIGGGLMGAGIAQVA...True
250P00469MLEQPYLDLAKKVLDEGHFKPDRTHTGTYSIFGHQMRFDLSKGFPL...True
............
469123Q8I6K2MSNTAVLNDLVALYDRPTEPMFRVKAKKSFKVPKEYVTDRFKNVAV...True
469127O81103MATAPSPTTMGTYSSLISTNSFSTFLPNKSQLSLSGKSKHYVARRS...True
469129Q21221MSSGAPSGSSMSSTPGSPPPRAGGPNSVSFKDLCCLFCCPPFPSSI...True
469130Q6QJ72MSRLLLPKLFSISRTQVPAASLFNNLYRRHKRFVHWTSKMSTDSVR...True
469133D9XDR8MAKMSTTHEEIALAGPDGIPAVDLRDLIDAQLYMPFPFERNPHASE...True
\n", 197 | "

222567 rows × 3 columns

\n", 198 | "
" 199 | ], 200 | "text/plain": [ 201 | " id seq isenzyme\n", 202 | "124 P00958 MSFLISFDKSKKHPAHLQLANNLKIALALEYASKNLKPEVDNDNAA... True\n", 203 | "175 P00812 METGPHYNYYKNRELSIVLAPFSGGQGKLGVEKGPKYMLKHGLQTS... True\n", 204 | "248 P00959 MTQVAKKILVTCALPYANGSIHLGHMLEHIQADVWVRYQRMRGHEV... True\n", 205 | "249 P00348 MAFATRQLVRSLSSSSTAAASAKKILVKHVTVIGGGLMGAGIAQVA... True\n", 206 | "250 P00469 MLEQPYLDLAKKVLDEGHFKPDRTHTGTYSIFGHQMRFDLSKGFPL... True\n", 207 | "... ... ... ...\n", 208 | "469123 Q8I6K2 MSNTAVLNDLVALYDRPTEPMFRVKAKKSFKVPKEYVTDRFKNVAV... True\n", 209 | "469127 O81103 MATAPSPTTMGTYSSLISTNSFSTFLPNKSQLSLSGKSKHYVARRS... True\n", 210 | "469129 Q21221 MSSGAPSGSSMSSTPGSPPPRAGGPNSVSFKDLCCLFCCPPFPSSI... True\n", 211 | "469130 Q6QJ72 MSRLLLPKLFSISRTQVPAASLFNNLYRRHKRFVHWTSKMSTDSVR... True\n", 212 | "469133 D9XDR8 MAKMSTTHEEIALAGPDGIPAVDLRDLIDAQLYMPFPFERNPHASE... True\n", 213 | "\n", 214 | "[222567 rows x 3 columns]" 215 | ] 216 | }, 217 | "execution_count": 5, 218 | "metadata": {}, 219 | "output_type": "execute_result" 220 | } 221 | ], 222 | "source": [ 223 | "train[train.isenzyme]" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "## 3. Sequence aligment" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 3, 236 | "metadata": { 237 | "tags": [] 238 | }, 239 | "outputs": [ 240 | { 241 | "name": "stdout", 242 | "output_type": "stream", 243 | "text": [ 244 | "Write finished\n", 245 | "Write finished\n", 246 | "diamond makedb --in /tmp/train.fasta -d /tmp/train.dmnd\n" 247 | ] 248 | }, 249 | { 250 | "name": "stderr", 251 | "output_type": "stream", 252 | "text": [ 253 | "diamond v2.0.8.146 (C) Max Planck Society for the Advancement of Science\n", 254 | "Documentation, support and updates available at http://www.diamondsearch.org\n", 255 | "\n", 256 | "#CPU threads: 80\n", 257 | "Scoring parameters: (Matrix=BLOSUM62 Lambda=0.267 K=0.041 Penalties=11/1)\n", 258 | "Database input file: /tmp/train.fasta\n", 259 | "Opening the database file... [0.004s]\n", 260 | "Loading sequences... [0.869s]\n", 261 | "Masking sequences... [0.304s]\n", 262 | "Writing sequences... [0.163s]\n", 263 | "Hashing sequences... [0.05s]\n", 264 | "Loading sequences... [0s]\n", 265 | "Writing trailer... [0.003s]\n", 266 | "Closing the input file... [0.001s]\n", 267 | "Closing the database file... 
[0.038s]\n", 268 | "Database hash = eed65be4bf3bb33f8407f23b2e861bca\n", 269 | "Processed 469134 sequences, 176795800 letters.\n", 270 | "Total time = 1.436s\n" 271 | ] 272 | }, 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "diamond blastp -d /tmp/train.dmnd -q /tmp/test.fasta -o /tmp/test_fasta_results.tsv -b5 -c1 -k 1 --quiet\n", 278 | " aligment finished \n", 279 | " query samples:7101\n", 280 | " results samples: 5111\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "# blast\n", 286 | "res_data=funclib.getblast(train,test)\n", 287 | "print(' aligment finished \\n query samples:{0}\\n results samples: {1}'.format(len(test), len(res_data)))\n", 288 | "\n", 289 | "res_data = res_data[['id', 'sseqid']].merge(train, left_on='sseqid', right_on='id', how='left')[['id_x', 'isenzyme']]\n", 290 | "res_data =res_data.rename(columns={'id_x':'id','isenzyme':'isenzyme_blast'})\n", 291 | "res_data = test[['id','isenzyme']].merge(res_data, on='id', how='left')" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 4, 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "baslineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t confusion Matrix\n", 304 | "Blast \t\t0.667934 \t0.966532 \t\t0.884197 \t0.795400 \t0.872655 \t tp: 2628 fp: 91 fn: 277 tn: 2115 up: 399 un: 1591\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "print('baslineName', '\\t', 'accuracy','\\t', 'precision(PPV) \\t NPV \\t\\t', 'recall','\\t', 'f1', '\\t\\t', '\\t confusion Matrix')\n", 310 | "eva.caculateMetrix(groundtruth=res_data.isenzyme, predict=res_data.isenzyme_blast, baselineName='Blast', type='include_unfind')" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": { 316 | "tags": [] 317 | }, 318 | "source": [ 319 | "## 4. 
Embedding Comparison\n",
320 |     "### 4.1 one-hot + ML"
321 |    ]
322 |   },
323 |   {
324 |    "cell_type": "code",
325 |    "execution_count": 4,
326 |    "metadata": {},
327 |    "outputs": [],
328 |    "source": [
329 |     "trainset = train.copy()\n",
330 |     "testset = test.copy()"
331 |    ]
332 |   },
333 |   {
334 |    "cell_type": "code",
335 |    "execution_count": 7,
336 |    "metadata": {},
337 |    "outputs": [],
338 |    "source": [
339 |     "MAX_SEQ_LENGTH = 1500 # maximum sequence length\n",
340 |     "trainset.seq = trainset.seq.map(lambda x : x[0:MAX_SEQ_LENGTH].ljust(MAX_SEQ_LENGTH, 'X'))\n",
341 |     "testset.seq = testset.seq.map(lambda x : x[0:MAX_SEQ_LENGTH].ljust(MAX_SEQ_LENGTH, 'X'))\n",
342 |     "f_train = funclib.dna_onehot(trainset) # encode training set\n",
343 |     "f_test = funclib.dna_onehot(testset) # encode test set"
344 |    ]
345 |   },
346 |   {
347 |    "cell_type": "code",
348 |    "execution_count": 8,
349 |    "metadata": {},
350 |    "outputs": [
351 |     {
352 |      "name": "stdout",
353 |      "output_type": "stream",
354 |      "text": [
355 |       "baslineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n",
356 |       "knn \t\t0.611745 \t0.666667 \t\t0.595238 \t0.331114 \t0.442467 \t tp: 1094 fp: 547 fn: 2210 tn: 3250\n",
357 |       "lr \t\t0.672159 \t0.630412 \t\t0.718666 \t0.713983 \t0.669600 \t tp: 2359 fp: 1383 fn: 945 tn: 2414\n",
358 |       "xg \t\t0.723278 \t0.730942 \t\t0.717991 \t0.641344 \t0.683218 \t tp: 2119 fp: 780 fn: 1185 tn: 3017\n",
359 |       "dt \t\t0.617096 \t0.599932 \t\t0.629133 \t0.531477 \t0.563633 \t tp: 1756 fp: 1171 fn: 1548 tn: 2626\n",
360 |       "rf \t\t0.715111 \t0.691709 \t\t0.735904 \t0.699455 \t0.695561 \t tp: 2311 fp: 1030 fn: 993 tn: 2767\n",
361 |       "gbdt \t\t0.689621 \t0.646510 \t\t0.737974 \t0.734564 \t0.687730 \t tp: 2427 fp: 1327 fn: 877 tn: 2470\n"
362 |      ]
363 |     }
364 |    ],
365 |    "source": [
366 |     "# compute metrics\n",
367 |     "X_train = np.array(f_train.iloc[:,2:])\n",
368 |     "X_test = np.array(f_test.iloc[:,2:])\n",
369 |     "Y_train = np.array(trainset.isenzyme.astype('int'))\n",
370 |     "Y_test = np.array(testset.isenzyme.astype('int'))\n",
371 |     "funclib.run_baseline(X_train, Y_train, X_test, Y_test)"
372 |    ]
373 |   },
374 |   {
375 |    "cell_type": "markdown",
376 |    "metadata": {},
377 |    "source": [
378 |     "### 4.2 Unirep + ML"
379 |    ]
380 |   },
381 |   {
382 |    "cell_type": "code",
383 |    "execution_count": 8,
384 |    "metadata": {},
385 |    "outputs": [],
386 |    "source": [
387 |     "train_unirep = pd.read_feather(cfg.DATADIR + 'train_unirep.feather')\n",
388 |     "test_unirep = pd.read_feather(cfg.DATADIR + 'test_unirep.feather')"
389 |    ]
390 |   },
391 |   {
392 |    "cell_type": "code",
393 |    "execution_count": 9,
394 |    "metadata": {},
395 |    "outputs": [],
396 |    "source": [
397 |     "train_unirep = trainset.merge(train_unirep, on='id', how='left')\n",
398 |     "test_unirep = testset.merge(test_unirep, on='id', how='left')"
399 |    ]
400 |   },
401 |   {
402 |    "cell_type": "code",
403 |    "execution_count": 10,
404 |    "metadata": {},
405 |    "outputs": [
406 |     {
407 |      "name": "stdout",
408 |      "output_type": "stream",
409 |      "text": [
410 |       "baslineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n",
411 |       "knn \t\t0.851852 \t0.854088 \t\t0.850038 \t0.822034 \t0.837754 \t tp: 2716 fp: 464 fn: 588 tn: 3333\n",
412 |       "lr \t\t0.809604 \t0.933008 \t\t0.752218 \t0.636501 \t0.756747 \t tp: 2103 fp: 151 fn: 1201 tn: 3646\n",
413 |       "xg \t\t0.861991 \t0.885790 \t\t0.844461 \t0.807506 \t0.844839 \t tp: 2668 fp: 344 fn: 636 tn: 3453\n",
414 |       "dt \t\t0.769610 \t0.789986 \t\t0.755740 \t0.687651 \t0.735275 \t tp: 2272 fp: 604 fn: 1032 tn: 3193\n",
415 |       "rf \t\t0.841994 
\t0.909535 \t\t0.801442 \t0.733354 \t0.811997 \t tp: 2423 fp: 241 fn: 881 tn: 3556\n", 416 | "gbdt \t\t0.785101 \t0.894060 \t\t0.734365 \t0.610472 \t0.725540 \t tp: 2017 fp: 239 fn: 1287 tn: 3558\n" 417 | ] 418 | } 419 | ], 420 | "source": [ 421 | "X_train =np.array(train_unirep.iloc[:,3:])\n", 422 | "X_test = np.array(test_unirep.iloc[:,3:])\n", 423 | "\n", 424 | "Y_train = np.array(train_unirep.isenzyme.astype('int')).flatten()\n", 425 | "Y_test = np.array(test_unirep.isenzyme.astype('int')).flatten()\n", 426 | "\n", 427 | "funclib.run_baseline(X_train, Y_train, X_test, Y_test)" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": { 433 | "tags": [] 434 | }, 435 | "source": [ 436 | "### 4.3 ESM REP33 + ML" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 9, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "train_esm_33 = pd.read_feather(cfg.DATADIR + 'train_rep33.feather')\n", 446 | "test_esm_33 = pd.read_feather(cfg.DATADIR + 'test_rep33.feather')\n", 447 | "\n", 448 | "train_esm = trainset.merge(train_esm_33, on='id', how='left')\n", 449 | "test_esm = testset.merge(test_esm_33, on='id', how='left')" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": 10, 455 | "metadata": {}, 456 | "outputs": [ 457 | { 458 | "name": "stdout", 459 | "output_type": "stream", 460 | "text": [ 461 | "baslineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n", 462 | "knn \t\t0.927300 \t0.935953 \t\t0.920835 \t0.898296 \t0.916738 \t tp: 3215 fp: 220 fn: 364 tn: 4234\n", 463 | "lr \t\t0.908378 \t0.927005 \t\t0.895196 \t0.862252 \t0.893457 \t tp: 3086 fp: 243 fn: 493 tn: 4211\n", 464 | "xg \t\t0.928047 \t0.952913 \t\t0.910593 \t0.882090 \t0.916135 \t tp: 3157 fp: 156 fn: 422 tn: 4298\n", 465 | "dt \t\t0.833811 \t0.848664 \t\t0.823884 \t0.763062 \t0.803590 \t tp: 2731 fp: 487 fn: 848 tn: 3967\n", 466 | "rf \t\t0.916096 \t0.960965 \t\t0.887136 \t0.846046 \t0.899851 \t tp: 3028 fp: 123 fn: 551 tn: 4331\n", 467 | "gbdt \t\t0.865804 \t0.901703 \t\t0.843089 \t0.784297 \t0.838912 \t tp: 2807 fp: 306 fn: 772 tn: 4148\n" 468 | ] 469 | } 470 | ], 471 | "source": [ 472 | "X_train = np.array(train_esm.iloc[:,4:])\n", 473 | "X_test = np.array(test_esm.iloc[:,4:])\n", 474 | "\n", 475 | "Y_train = np.array(train_esm.isemzyme.astype('int')).flatten()\n", 476 | "Y_test = np.array(test_esm.isemzyme.astype('int')).flatten()\n", 477 | "\n", 478 | "funclib.run_baseline(X_train, Y_train, X_test, Y_test)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "### 4.4 ESM REP32 + ML" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 10, 491 | "metadata": { 492 | "tags": [] 493 | }, 494 | "outputs": [ 495 | { 496 | "name": "stdout", 497 | "output_type": "stream", 498 | "text": [ 499 | "baslineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n", 500 | "knn \t\t0.924799 \t0.939125 \t\t0.913352 \t0.896489 \t0.917312 \t tp: 2962 fp: 192 fn: 342 tn: 3605\n", 501 | "lr \t\t0.908604 \t0.927536 \t\t0.893894 \t0.871671 \t0.898736 \t tp: 2880 fp: 225 fn: 424 tn: 3572\n", 502 | "xg \t\t0.921420 \t0.949869 \t\t0.899975 \t0.877421 \t0.912209 \t tp: 2899 fp: 153 fn: 405 tn: 3644\n", 503 | "dt \t\t0.830587 \t0.855981 \t\t0.812530 \t0.764528 \t0.807674 \t tp: 2526 fp: 425 fn: 778 tn: 3372\n", 504 | "rf \t\t0.908604 \t0.961418 \t\t0.872633 \t0.837167 \t0.895001 \t tp: 2766 fp: 111 fn: 538 tn: 3686\n", 505 | "gbdt 
\t\t0.874947 \t0.905641 \t\t0.852777 \t0.816283 \t0.858644 \t tp: 2697 fp: 281 fn: 607 tn: 3516\n" 506 | ] 507 | } 508 | ], 509 | "source": [ 510 | "train_esm_32 = pd.read_feather(cfg.DATADIR + 'train_rep32.feather')\n", 511 | "test_esm_32 = pd.read_feather(cfg.DATADIR + 'test_rep32.feather')\n", 512 | "\n", 513 | "train_esm = trainset.merge(train_esm_32, on='id', how='left')\n", 514 | "test_esm = testset.merge(test_esm_32, on='id', how='left')\n", 515 | "\n", 516 | "X_train = np.array(train_esm.iloc[:,4:])\n", 517 | "X_test = np.array(test_esm.iloc[:,4:])\n", 518 | "\n", 519 | "Y_train = np.array(train_esm.isenzyme.astype('int')).flatten()\n", 520 | "Y_test = np.array(test_esm.isenzyme.astype('int')).flatten()\n", 521 | "\n", 522 | "funclib.run_baseline(X_train, Y_train, X_test, Y_test)" 523 | ] 524 | }, 525 | { 526 | "cell_type": "markdown", 527 | "metadata": {}, 528 | "source": [ 529 | "### 4.5 ESM REP0 + ML" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 7, 535 | "metadata": {}, 536 | "outputs": [ 537 | { 538 | "name": "stdout", 539 | "output_type": "stream", 540 | "text": [ 541 | "baslineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n", 542 | "knn \t\t0.824599 \t0.789179 \t\t0.855641 \t0.827326 \t0.807802 \t tp: 2961 fp: 791 fn: 618 tn: 3663\n", 543 | "lr \t\t0.757251 \t0.721031 \t\t0.787948 \t0.742386 \t0.731553 \t tp: 2657 fp: 1028 fn: 922 tn: 3426\n", 544 | "xg \t\t0.847504 \t0.844555 \t\t0.849686 \t0.806091 \t0.824875 \t tp: 2885 fp: 531 fn: 694 tn: 3923\n", 545 | "dt \t\t0.760612 \t0.739722 \t\t0.776370 \t0.713887 \t0.726575 \t tp: 2555 fp: 899 fn: 1024 tn: 3555\n", 546 | "rf \t\t0.853853 \t0.863623 \t\t0.847017 \t0.797988 \t0.829509 \t tp: 2856 fp: 451 fn: 723 tn: 4003\n", 547 | "gbdt \t\t0.820988 \t0.810020 \t\t0.829258 \t0.781503 \t0.795506 \t tp: 2797 fp: 656 fn: 782 tn: 3798\n" 548 | ] 549 | } 550 | ], 551 | "source": [ 552 | "train_esm_0 = pd.read_feather(cfg.DATADIR + 'train_rep0.feather')\n", 553 | "test_esm_0 = pd.read_feather(cfg.DATADIR + 'test_rep0.feather')\n", 554 | "\n", 555 | "train_esm = trainset.merge(train_esm_0, on='id', how='left')\n", 556 | "test_esm = testset.merge(test_esm_0, on='id', how='left')\n", 557 | "\n", 558 | "X_train = np.array(train_esm.iloc[:,4:])\n", 559 | "X_test = np.array(test_esm.iloc[:,4:])\n", 560 | "\n", 561 | "Y_train = np.array(train_esm.isemzyme.astype('int')).flatten()\n", 562 | "Y_test = np.array(test_esm.isemzyme.astype('int')).flatten()\n", 563 | "\n", 564 | "funclib.run_baseline(X_train, Y_train, X_test, Y_test)" 565 | ] 566 | }, 567 | { 568 | "cell_type": "markdown", 569 | "metadata": {}, 570 | "source": [ 571 | "## 5. 
Ours"
572 |    ]
573 |   },
574 |   {
575 |    "cell_type": "code",
576 |    "execution_count": null,
577 |    "metadata": {},
578 |    "outputs": [],
579 |    "source": [
580 |     "# get blast results\n",
581 |     "blastres=pd.DataFrame()\n",
582 |     "blastres['id']=res.id\n",
583 |     "blastres['isemzyme_groundtruth']=res.isemzyme\n",
584 |     "blastres['isEmzyme_pred_blast']=res.isEmzyme_pred"
585 |    ]
586 |   },
587 |   {
588 |    "cell_type": "code",
589 |    "execution_count": 6,
590 |    "metadata": {},
591 |    "outputs": [],
592 |    "source": [
593 |     "#res32\n",
594 |     "train_esm_32 = pd.read_feather(cfg.DATADIR + 'train_rep32.feather')\n",
595 |     "test_esm_32 = pd.read_feather(cfg.DATADIR + 'test_rep32.feather')\n",
596 |     "\n",
597 |     "train_esm = trainset.merge(train_esm_32, on='id', how='left')\n",
598 |     "test_esm = testset.merge(test_esm_32, on='id', how='left')\n",
599 |     "\n",
600 |     "X_train = np.array(train_esm.iloc[:,3:])\n",
601 |     "X_test = np.array(test_esm.iloc[:,3:])\n",
602 |     "\n",
603 |     "Y_train = np.array(train_esm.isenzyme.astype('int')).flatten()\n",
604 |     "Y_test = np.array(test_esm.isenzyme.astype('int')).flatten()"
605 |    ]
606 |   },
607 |   {
608 |    "cell_type": "code",
609 |    "execution_count": null,
610 |    "metadata": {},
611 |    "outputs": [],
612 |    "source": [
613 |     "# groundtruth, predict, predictprob = funclib.xgmain(X_train, Y_train, X_test, Y_test, type='binary')\n",
614 |     "groundtruth, predict, predictprob = funclib.knnmain(X_train, Y_train, X_test, Y_test, type='binary')\n",
615 |     "blastres['isEmzyme_pred_xg'] = predict\n",
616 |     "blastres.isEmzyme_pred_xg =blastres.isEmzyme_pred_xg.astype('bool')\n",
617 |     "blastres['isEmzyme_pred_slice']=blastres.apply(lambda x: x.isEmzyme_pred_xg if str(x.isEmzyme_pred_blast)=='nan' else x.isEmzyme_pred_blast, axis=1)\n",
618 |     "print('baslineName', '\\t', 'accuracy','\\t', 'precision(PPV) \\t NPV \\t\\t', 'recall','\\t', 'f1', '\\t\\t', '\\t confusion Matrix')\n",
619 |     "eva.caculateMetrix( groundtruth=blastres.isemzyme_groundtruth, predict=blastres.isEmzyme_pred_slice, baselineName='ours', type='binary')"
620 |    ]
621 |   },
622 |   {
623 |    "cell_type": "code",
624 |    "execution_count": 7,
625 |    "metadata": {},
626 |    "outputs": [
627 |     {
628 |      "data": {
629 |       "text/plain": [
630 |        "['/home/shizhenkun/codebase/DMLF/model/isenzyme.model']"
631 |       ]
632 |      },
633 |      "execution_count": 7,
634 |      "metadata": {},
635 |      "output_type": "execute_result"
636 |     }
637 |    ],
638 |    "source": [
639 |     "groundtruth, predict, predictprob, model = funclib.knnmain(X_train, Y_train, X_test, Y_test, type='binary')\n",
640 |     "joblib.dump(model, cfg.ISENZYME_MODEL)"
641 |    ]
642 |   },
643 |   {
644 |    "cell_type": "code",
645 |    "execution_count": 62,
646 |    "metadata": {},
647 |    "outputs": [],
648 |    "source": [
649 |     "# save results to file\n",
650 |     "blastres.to_csv(cfg.FILE_SLICE_ISENZYME_RESULTS, sep='\\t', index=None)"
651 |    ]
652 |   }
653 |  ],
654 |  "metadata": {
655 |   "kernelspec": {
656 |    "display_name": "DMLF",
657 |    "language": "python",
658 |    "name": "python3"
659 |   },
660 |   "language_info": {
661 |    "codemirror_mode": {
662 |     "name": "ipython",
663 |     "version": 3
664 |    },
665 |    "file_extension": ".py",
666 |    "mimetype": "text/x-python",
667 |    "name": "python",
668 |    "nbconvert_exporter": "python",
669 |    "pygments_lexer": "ipython3",
670 |    "version": "3.8.15"
671 |   },
672 |   "vscode": {
673 |    "interpreter": {
674 |     "hash": "6b0f740237ba4768c544d9b9677983e49b45ca1230fda464ede0b93eba99c7d2"
675 |    }
676 |   },
677 |   "widgets": {
678 |    "application/vnd.jupyter.widget-state+json": {
679 |     "state": {},
680 |     "version_major": 2,
681 |     "version_minor": 0
682 |    }
683 |   }
684 |  },
685 |  "nbformat": 4,
686 |  "nbformat_minor": 4
687 | }
"version_minor": 0 682 | } 683 | } 684 | }, 685 | "nbformat": 4, 686 | "nbformat_minor": 4 687 | } 688 | -------------------------------------------------------------------------------- /tmp/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tools/Attention.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Zhenkun Shi 3 | Date: 2022-10-09 06:57:50 4 | LastEditors: Zhenkun Shi 5 | LastEditTime: 2022-10-09 06:58:50 6 | FilePath: /DMLF/tools/Attention.py 7 | Description: 8 | 9 | Copyright (c) 2022 by tibd, All Rights Reserved. 10 | ''' 11 | 12 | # -*- coding: utf-8 -*- 13 | from keras import backend as K 14 | # from keras.engine.topology import Layer 15 | from keras.layers import Layer, InputSpec 16 | # from visualizer import get_local 17 | 18 | # 利用Keras构造注意力机制层 19 | class Attention(Layer): 20 | def __init__(self, attention_size, **kwargs): 21 | self.attention_size = attention_size 22 | super(Attention, self).__init__(**kwargs) 23 | 24 | def build(self, input_shape): 25 | # W: (EMBED_SIZE, ATTENTION_SIZE) 26 | # b: (ATTENTION_SIZE, 1) 27 | # u: (ATTENTION_SIZE, 1) 28 | self.W = self.add_weight(name="W_{:s}".format(self.name), 29 | shape=(input_shape[-1], self.attention_size), 30 | initializer="glorot_normal", 31 | trainable=True) 32 | self.b = self.add_weight(name="b_{:s}".format(self.name), 33 | shape=(input_shape[1], 1), 34 | initializer="zeros", 35 | trainable=True) 36 | self.u = self.add_weight(name="u_{:s}".format(self.name), 37 | shape=(self.attention_size, 1), 38 | initializer="glorot_normal", 39 | trainable=True) 40 | super(Attention, self).build(input_shape) 41 | 42 | def call(self, x, mask=None): 43 | # input: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE) 44 | # et: (BATCH_SIZE, MAX_TIMESTEPS, ATTENTION_SIZE) 45 | et = K.tanh(K.dot(x, self.W) + self.b) 46 | # at: (BATCH_SIZE, MAX_TIMESTEPS) 47 | at = K.softmax(K.squeeze(K.dot(et, self.u), axis=-1)) 48 | if mask is not None: 49 | at *= K.cast(mask, K.floatx()) 50 | # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE) 51 | atx = K.expand_dims(at, axis=-1) 52 | ot = atx * x 53 | # output: (BATCH_SIZE, EMBED_SIZE) 54 | output = K.sum(ot, axis=1) 55 | return output 56 | 57 | def compute_mask(self, input, input_mask=None): 58 | return None 59 | 60 | def compute_output_shape(self, input_shape): 61 | return (input_shape[0], input_shape[-1]) 62 | 63 | # 该函数用于保存和加载模型时使用 64 | def get_config(self): 65 | config = {"attention_size": self.attention_size} 66 | base_config = super(Attention, self).get_config() 67 | return dict(list(base_config.items()) + list(config.items())) -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kingstdio/ECRECer/5ccd2c9d163e20faf813d829a868dfa8a0fedc93/tools/__init__.py -------------------------------------------------------------------------------- /tools/baselines.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9307705e-50b0-4bae-b67c-ae8d62eb1737", 6 | "metadata": {}, 7 | "source": [ 8 | "# Task1. 
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kingstdio/ECRECer/5ccd2c9d163e20faf813d829a868dfa8a0fedc93/tools/__init__.py
--------------------------------------------------------------------------------
/tools/baselines.ipynb:
--------------------------------------------------------------------------------
  1 | {
  2 |  "cells": [
  3 |   {
  4 |    "cell_type": "markdown",
  5 |    "id": "9307705e-50b0-4bae-b67c-ae8d62eb1737",
  6 |    "metadata": {},
  7 |    "source": [
  8 |     "# Task1. Baseline Methods\n",
  9 |     "\n",
 10 |     "> author: Shizhenkun   \n",
 11 |     "> email: zhenkun.shi@tib.cas.cn   \n",
 12 |     "> date: 2021-10-20   "
 13 |    ]
 14 |   },
 15 |   {
 16 |    "cell_type": "markdown",
 17 |    "id": "cb15383d-9838-48f5-9cf9-b369418a7e6b",
 18 |    "metadata": {},
 19 |    "source": [
 20 |     "## 1. Import packages"
 21 |    ]
 22 |   },
 23 |   {
 24 |    "cell_type": "code",
 25 |    "execution_count": 3,
 26 |    "id": "71bdaefa-a752-4ee3-851a-aed07d6b3363",
 27 |    "metadata": {},
 28 |    "outputs": [
 29 |     {
 30 |      "name": "stdout",
 31 |      "output_type": "stream",
 32 |      "text": [
 33 |       "INFO: Pandarallel will run on 80 workers.\n",
 34 |       "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n",
 35 |       "The autoreload extension is already loaded. To reload it, use:\n",
 36 |       "  %reload_ext autoreload\n"
 37 |      ]
 38 |     }
 39 |    ],
 40 |    "source": [
 41 |     "import numpy as np\n",
 42 |     "import pandas as pd\n",
 43 |     "import time\n",
 44 |     "import datetime\n",
 45 |     "import sys\n",
 46 |     "import os\n",
 47 |     "from tqdm import tqdm\n",
 48 |     "from functools import reduce\n",
 49 |     "\n",
 50 |     "sys.path.append(\"../tools/\")\n",
 51 |     "import funclib\n",
 52 |     "\n",
 53 |     "sys.path.append(\"../\")\n",
 54 |     "import benchmark_train as btrain\n",
 55 |     "import benchmark_test as btest\n",
 56 |     "import config as cfg\n",
 57 |     "import benchmark_evaluation as eva\n",
 58 |     "\n",
 59 |     "from sklearn.model_selection import train_test_split\n",
 60 |     "from xgboost import XGBClassifier\n",
 61 |     "\n",
 62 |     "from pandarallel import pandarallel # import pandarallel\n",
 63 |     "pandarallel.initialize() # init\n",
 64 |     "\n",
 65 |     "%load_ext autoreload\n",
 66 |     "%autoreload 2"
 67 |    ]
 68 |   },
 69 |   {
 70 |    "cell_type": "markdown",
 71 |    "id": "04b3f7db-f81d-481e-b0c5-278c60fe87f0",
 72 |    "metadata": {},
 73 |    "source": [
 74 |     "## 2. Load Data"
 75 |    ]
 76 |   },
 77 |   {
 78 |    "cell_type": "code",
 79 |    "execution_count": 2,
 80 |    "id": "8482e6a8-6523-4ed6-8784-d245538cf0b1",
 81 |    "metadata": {},
 82 |    "outputs": [
 83 |     {
 84 |      "name": "stdout",
 85 |      "output_type": "stream",
 86 |      "text": [
 87 |       "test size: 7101\n"
 88 |      ]
 89 |     }
 90 |    ],
 91 |    "source": [
 92 |     "test = pd.read_feather(cfg.DATADIR+'task1/test.feather')\n",
 93 |     "print('test size: {0}'.format(len(test)))"
 94 |    ]
 95 |   },
 96 |   {
 97 |    "cell_type": "markdown",
 98 |    "id": "5625d510-2da8-47f7-b082-7bd39ce5137a",
 99 |    "metadata": {},
100 |    "source": [
101 |     "## 3. Make fasta"
102 |    ]
103 |   },
104 |   {
105 |    "cell_type": "code",
106 |    "execution_count": 9,
107 |    "id": "4b090ea7-77b4-40a3-ba7c-78fa1dfb6ec3",
108 |    "metadata": {},
109 |    "outputs": [
110 |     {
111 |      "name": "stdout",
112 |      "output_type": "stream",
113 |      "text": [
114 |       "Write finished\n"
115 |      ]
116 |     }
117 |    ],
118 |    "source": [
119 |     "funclib.table2fasta(table=test, file_out='../data/test.fasta')"
120 |    ]
121 |   },
122 |   {
123 |    "cell_type": "markdown",
124 |    "id": "da7733ac-5080-4879-9d05-56cd657cfbd8",
125 |    "metadata": {},
126 |    "source": [
127 |     "## 4. 
ECPred\n", 128 | "\n", 129 | "Please be patient, this method takes a long time to predict" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 3, 135 | "id": "058a7ff1-45ee-4db8-871d-e25b3d81c6d7", 136 | "metadata": { 137 | "collapsed": true, 138 | "jupyter": { 139 | "outputs_hidden": true 140 | }, 141 | "tags": [] 142 | }, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "Hit:1 http://cn.archive.ubuntu.com/ubuntu focal InRelease\n", 149 | "Hit:2 http://cn.archive.ubuntu.com/ubuntu focal-updates InRelease\n", 150 | "Hit:3 http://cn.archive.ubuntu.com/ubuntu focal-backports InRelease\n", 151 | "Hit:4 http://cn.archive.ubuntu.com/ubuntu focal-security InRelease\n", 152 | "Reading package lists... Done\n", 153 | "Reading package lists... Done\n", 154 | "Building dependency tree \n", 155 | "Reading state information... Done\n", 156 | "default-jre is already the newest version (2:1.11-72).\n", 157 | "0 upgraded, 0 newly installed, 0 to remove and 44 not upgraded.\n", 158 | "Reading package lists... Done\n", 159 | "Building dependency tree \n", 160 | "Reading state information... Done\n", 161 | "default-jdk is already the newest version (2:1.11-72).\n", 162 | "0 upgraded, 0 newly installed, 0 to remove and 44 not upgraded.\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "!sudo apt-get update -y\n", 168 | "!sudo apt-get install default-jre -y\n", 169 | "!sudo apt-get install default-jdk -y" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "6f0e293f-9651-4c7a-95b1-921dda74adda", 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "Main classes of input proteins are being predicted ...\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "!java -jar ../baselines/ECPred/ECPred.jar weighted ../data/test.fasta ../baselines/ECPred/ ../results/ecpred.txt" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "id": "2934039b-ddac-45cd-b361-5ae8d218753a", 193 | "metadata": {}, 194 | "source": [ 195 | "## 5. DeepEC" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "6f9c3112-081b-4b2f-8fcd-95cad5b74d61", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "! conda env create -f ../baselines/deepec/environment.yml\n", 206 | "! conda activate deepec" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "id": "e07e3f89-5292-4e63-ad6b-290a786ee043", 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "! python ../baselines/deepec/deepec.py -i ../data/test.fasta -o ../results/deepec/" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "id": "ac05fbbe-ca3b-4b71-87e7-d5f81d8cd675", 222 | "metadata": {}, 223 | "source": [ 224 | "## 6. CatFam" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 1, 230 | "id": "4e11cfc0-8e51-490c-8b6e-538df2c053db", 231 | "metadata": { 232 | "collapsed": true, 233 | "jupyter": { 234 | "outputs_hidden": true 235 | }, 236 | "tags": [] 237 | }, 238 | "outputs": [ 239 | { 240 | "name": "stdout", 241 | "output_type": "stream", 242 | "text": [ 243 | "Reading package lists... Done\n", 244 | "Building dependency tree \n", 245 | "Reading state information... 
Done\n", 246 | "The following additional packages will be installed:\n", 247 | " libmbedcrypto3 libmbedtls12 libmbedx509-0 ncbi-blast+ ncbi-data\n", 248 | "The following NEW packages will be installed:\n", 249 | " libmbedcrypto3 libmbedtls12 libmbedx509-0 ncbi-blast+ ncbi-blast+-legacy\n", 250 | " ncbi-data\n", 251 | "0 upgraded, 6 newly installed, 0 to remove and 44 not upgraded.\n", 252 | "Need to get 14.9 MB of archives.\n", 253 | "After this operation, 75.0 MB of additional disk space will be used.\n", 254 | "Get:1 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 libmbedcrypto3 amd64 2.16.4-1ubuntu2 [150 kB]\n", 255 | "Get:2 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 libmbedx509-0 amd64 2.16.4-1ubuntu2 [42.3 kB]\n", 256 | "Get:3 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 libmbedtls12 amd64 2.16.4-1ubuntu2 [71.8 kB]\n", 257 | "Get:4 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 ncbi-data all 6.1.20170106+dfsg1-8 [3,518 kB]\n", 258 | "Get:5 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 ncbi-blast+ amd64 2.9.0-2 [11.1 MB]\n", 259 | "Get:6 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 ncbi-blast+-legacy all 2.9.0-2 [5,064 B]\n", 260 | "Fetched 14.9 MB in 5s (3,160 kB/s) \n", 261 | "Selecting previously unselected package libmbedcrypto3:amd64.\n", 262 | "(Reading database ... 216316 files and directories currently installed.)\n", 263 | "Preparing to unpack .../0-libmbedcrypto3_2.16.4-1ubuntu2_amd64.deb ...\n", 264 | "Unpacking libmbedcrypto3:amd64 (2.16.4-1ubuntu2) ...\n", 265 | "Selecting previously unselected package libmbedx509-0:amd64.\n", 266 | "Preparing to unpack .../1-libmbedx509-0_2.16.4-1ubuntu2_amd64.deb ...\n", 267 | "Unpacking libmbedx509-0:amd64 (2.16.4-1ubuntu2) ...\n", 268 | "Selecting previously unselected package libmbedtls12:amd64.\n", 269 | "Preparing to unpack .../2-libmbedtls12_2.16.4-1ubuntu2_amd64.deb ...\n", 270 | "Unpacking libmbedtls12:amd64 (2.16.4-1ubuntu2) ...\n", 271 | "Selecting previously unselected package ncbi-data.\n", 272 | "Preparing to unpack .../3-ncbi-data_6.1.20170106+dfsg1-8_all.deb ...\n", 273 | "Unpacking ncbi-data (6.1.20170106+dfsg1-8) ...\n", 274 | "Selecting previously unselected package ncbi-blast+.\n", 275 | "Preparing to unpack .../4-ncbi-blast+_2.9.0-2_amd64.deb ...\n", 276 | "Unpacking ncbi-blast+ (2.9.0-2) ...\n", 277 | "Selecting previously unselected package ncbi-blast+-legacy.\n", 278 | "Preparing to unpack .../5-ncbi-blast+-legacy_2.9.0-2_all.deb ...\n", 279 | "Unpacking ncbi-blast+-legacy (2.9.0-2) ...\n", 280 | "Setting up ncbi-data (6.1.20170106+dfsg1-8) ...\n", 281 | "Setting up libmbedcrypto3:amd64 (2.16.4-1ubuntu2) ...\n", 282 | "Setting up libmbedx509-0:amd64 (2.16.4-1ubuntu2) ...\n", 283 | "Setting up libmbedtls12:amd64 (2.16.4-1ubuntu2) ...\n", 284 | "Setting up ncbi-blast+ (2.9.0-2) ...\n", 285 | "Setting up ncbi-blast+-legacy (2.9.0-2) ...\n", 286 | "Processing triggers for libc-bin (2.31-0ubuntu9.2) ...\n", 287 | "Processing triggers for man-db (2.9.1-1) ...\n", 288 | "Processing triggers for hicolor-icon-theme (0.17-2) ...\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "!sudo apt-get install ncbi-blast+-legacy -y" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 21, 299 | "id": "af231781-8cf8-43f4-9ec9-4f1c8d577ff0", 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "! 
../baselines/catfam/source/catsearch.pl -d ../baselines/catfam/CatFamDB/CatFam_v2.0/CatFam4D99R -i ../data/test.fasta -o ../results/catfam.txt" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "id": "fd0de0c4", 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "! java -Xmx50G -jar ../baselines/priam/PRIAM_search.jar -p /home/shizhenkun/codebase/DMLF/baselines/priam/PRIAM_JAN18 -i /home/shizhenkun/codebase/DMLF/data/gu_bang.fasta -o /home/shizhenkun/codebase/DMLF/results/case_gubang/priam/ --bp /home/shizhenkun/downloads/blast-2.2.13/bin --np 78" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "id": "05ec97bb-f0c4-41c6-9158-ed9ed0e95806", 319 | "metadata": {}, 320 | "source": [ 321 | "## 7. PRIAM" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 26, 327 | "id": "26bcceda-472e-474c-a37f-c20d56ded403", 328 | "metadata": { 329 | "collapsed": true, 330 | "jupyter": { 331 | "outputs_hidden": true 332 | }, 333 | "tags": [] 334 | }, 335 | "outputs": [ 336 | { 337 | "name": "stdout", 338 | "output_type": "stream", 339 | "text": [ 340 | "\n", 341 | "\n", 342 | "PRIAM did not find the profiles library (Normal if it is the first time you use PRIAM with this release).\n", 343 | "So it would now create it. Please do not interupt it during this process else it would results into a corrupted database and any further analysis would be faulty.\n", 344 | "If for any reason you really need to interupt PRIAM before it ends, or if your computer crash during this process, please delete ../baselines/priam/PRIAM_JAN18/PROFILES/LIBRARY, and all files it contains, to force PRIAM to recreate the library next time it would be launched\n", 345 | "\n", 346 | "Executing makeprofiledb -in /home/shizhenkun/codebase/DMLF/tools/../baselines/priam/PRIAM_JAN18/PROFILES/LIBRARY/profiles.list -index T -out PROFILE_EZ -title PRIAM_profiles_database \n", 347 | "\n", 348 | "Profiles database sucessfully created.\n", 349 | "If you want to interupt PRIAM, it can now be done with no risk\n", 350 | "\n", 351 | "### PRIAM Profiles search ###\n", 352 | "7101 sequences found in your query file\n", 353 | "Query file splitted into 2 pieces\n", 354 | "Executing 'rpsblast -num_threads 78 -outfmt 5 -evalue 1.0 -query ../results/priam//PRIAM_20211102154953/TMP/DATA/query_part1.fas -db ../baselines/priam/PRIAM_JAN18/PROFILES/LIBRARY/PROFILE_EZ -out /home/shizhenkun/codebase/DMLF/tools/../results/priam/PRIAM_20211102154953/TMP/RAW_RESULTS/query_part1.res -num_alignments 10000 -num_descriptions 10000'.\n", 355 | "This RPS-BLAST job took 4mn 21s \n", 356 | "RPS-Blast jobs now running from 4mn 21s \n", 357 | "Estimated time to complete all 1 remaining RPS-BLAST jobs: 4mn 21s \n", 358 | "Executing 'rpsblast -num_threads 78 -outfmt 5 -evalue 1.0 -query ../results/priam//PRIAM_20211102154953/TMP/DATA/query_part2.fas -db ../baselines/priam/PRIAM_JAN18/PROFILES/LIBRARY/PROFILE_EZ -out /home/shizhenkun/codebase/DMLF/tools/../results/priam/PRIAM_20211102154953/TMP/RAW_RESULTS/query_part2.res -num_alignments 10000 -num_descriptions 10000'.\n", 359 | "This RPS-BLAST job took 1mn 52s \n", 360 | "Time used to search all sequences against profiles of PRIAM: 6mn 13s \n", 361 | "\n", 362 | "### Individual sequence annotation ###\n", 363 | "Exception in thread \"main\" java.lang.reflect.InvocationTargetException\n", 364 | "\tat java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\n", 365 | "\tat 
java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\n", 366 | "\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n", 367 | "\tat java.base/java.lang.reflect.Method.invoke(Method.java:566)\n", 368 | "\tat org.eclipse.jdt.internal.jarinjarloader.JarRsrcLoader.main(JarRsrcLoader.java:58)\n", 369 | "Caused by: java.lang.NoClassDefFoundError: weka/core/Utils\n", 370 | "\tat priam.common.DistributionEstimator.(DistributionEstimator.java:188)\n", 371 | "\tat priam.common.ProfInfos.setPositivesHitsDistrib(ProfInfos.java:229)\n", 372 | "\tat priam.common.AnnotationRulesXMLHandler2.startElement(AnnotationRulesXMLHandler2.java:133)\n", 373 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.AbstractSAXParser.startElement(AbstractSAXParser.java:510)\n", 374 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.AbstractXMLDocumentParser.emptyElement(AbstractXMLDocumentParser.java:183)\n", 375 | "\tat java.xml/com.sun.org.apache.xerces.internal.impl.XMLDocumentFragmentScannerImpl.scanStartElement(XMLDocumentFragmentScannerImpl.java:1377)\n", 376 | "\tat java.xml/com.sun.org.apache.xerces.internal.impl.XMLDocumentFragmentScannerImpl$FragmentContentDriver.next(XMLDocumentFragmentScannerImpl.java:2710)\n", 377 | "\tat java.xml/com.sun.org.apache.xerces.internal.impl.XMLDocumentScannerImpl.next(XMLDocumentScannerImpl.java:605)\n", 378 | "\tat java.xml/com.sun.org.apache.xerces.internal.impl.XMLDocumentFragmentScannerImpl.scanDocument(XMLDocumentFragmentScannerImpl.java:534)\n", 379 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.XML11Configuration.parse(XML11Configuration.java:888)\n", 380 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.XML11Configuration.parse(XML11Configuration.java:824)\n", 381 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.XMLParser.parse(XMLParser.java:141)\n", 382 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.AbstractSAXParser.parse(AbstractSAXParser.java:1216)\n", 383 | "\tat java.xml/com.sun.org.apache.xerces.internal.jaxp.SAXParserImpl$JAXPSAXParser.parse(SAXParserImpl.java:635)\n", 384 | "\tat java.xml/com.sun.org.apache.xerces.internal.jaxp.SAXParserImpl.parse(SAXParserImpl.java:324)\n", 385 | "\tat java.xml/javax.xml.parsers.SAXParser.parse(SAXParser.java:330)\n", 386 | "\tat priam.modules.Annotate.readAnnotationRulesXML(Annotate.java:65)\n", 387 | "\tat priam.modules.Annotate.AnnotateSequences(Annotate.java:110)\n", 388 | "\tat priam.search.Main.main_standard(Main.java:51)\n", 389 | "\tat priam.search.Main.main(Main.java:16)\n", 390 | "\t... 5 more\n", 391 | "Caused by: java.lang.ClassNotFoundException: weka.core.Utils\n", 392 | "\tat java.base/java.net.URLClassLoader.findClass(URLClassLoader.java:471)\n", 393 | "\tat java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:589)\n", 394 | "\tat java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:522)\n", 395 | "\t... 25 more\n" 396 | ] 397 | } 398 | ], 399 | "source": [ 400 | "! 
java -Xmx50G -jar ../baselines/priam/PRIAM_search.jar -p /home/shizhenkun/codebase/DMLF/baselines/priam/PRIAM_JAN18 -i /home/shizhenkun/codebase/DMLF/data/test.fasta -o /home/shizhenkun/codebase/DMLF/results/priam/ --bp /home/shizhenkun/downloads/blast-2.2.13/bin --np 78" 401 | ] 402 | } 403 | ], 404 | "metadata": { 405 | "kernelspec": { 406 | "display_name": "py38", 407 | "language": "python", 408 | "name": "py38" 409 | }, 410 | "language_info": { 411 | "codemirror_mode": { 412 | "name": "ipython", 413 | "version": 3 414 | }, 415 | "file_extension": ".py", 416 | "mimetype": "text/x-python", 417 | "name": "python", 418 | "nbconvert_exporter": "python", 419 | "pygments_lexer": "ipython3", 420 | "version": "3.7.10 | packaged by conda-forge | (default, Feb 19 2021, 16:07:37) \n[GCC 9.3.0]" 421 | }, 422 | "vscode": { 423 | "interpreter": { 424 | "hash": "5f0598356972e9a098a4a3756f1f561f75f6bcc730be3bc51870d213558f68c7" 425 | } 426 | }, 427 | "widgets": { 428 | "application/vnd.jupyter.widget-state+json": { 429 | "state": {}, 430 | "version_major": 2, 431 | "version_minor": 0 432 | } 433 | } 434 | }, 435 | "nbformat": 4, 436 | "nbformat_minor": 5 437 | } 438 | -------------------------------------------------------------------------------- /tools/embdding_onehot.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Zhenkun Shi 3 | Date: 2022-10-08 06:15:36 4 | LastEditors: Zhenkun Shi 5 | LastEditTime: 2022-10-08 08:00:58 6 | FilePath: /DMLF/tools/embdding_onehot.py 7 | Description: 8 | 9 | Copyright (c) 2022 by tibd, All Rights Reserved. 10 | ''' 11 | 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import os,sys,itertools 16 | from datetime import datetime 17 | from tqdm import tqdm 18 | from pandarallel import pandarallel # 导入pandaralle 19 | sys.path.append(os.path.dirname(os.path.realpath('__file__'))) 20 | sys.path.append('../') 21 | import config as cfg 22 | 23 | # amino acid 编码字典 24 | prot_dict = dict( 25 | A=0b000000000000000000000001, R=0b000000000000000000000010, N=0b000000000000000000000100, 26 | D=0b000000000000000000001000, C=0b000000000000000000010000, E=0b000000000000000000100000, 27 | Q=0b000000000000000001000000, G=0b000000000000000010000000, H=0b000000000000000100000000, 28 | O=0b000000000000001000000000, I=0b000000000000010000000000, L=0b000000000000100000000000, 29 | K=0b000000000001000000000000, M=0b000000000010000000000001, F=0b000000000100000000000000, 30 | P=0b000000001000000000000000, U=0b000000010000000000000000, S=0b000000100000000000000000, 31 | T=0b000001000000000000000000, W=0b000010000000000000000000, Y=0b000100000000000000000000, 32 | V=0b001000000000000000000000, B=0b010000000000000000000000, Z=0b100000000000000000000000, 33 | X=0b000000000000000000000000 34 | ) 35 | 36 | #region using one-hot to encode a amino acid sequence 37 | def protein_sequence_one_hot(protein_seq, padding=False, padding_window=1500): 38 | """ using one-hot to encode a amino acid sequence 39 | 40 | Args: 41 | protein_seq (string): aminio acid sequence 42 | padding (bool, optional): padding to a fixed length. Defaults to False. 43 | padding_window (int, optional): padded sequence size. Defaults to 1500. 
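    Example (illustrative, relying on the prot_dict defined above):
        Each residue maps to a single bit of a 24-bit flag, so individual
        residues encode as powers of two and the unknown residue X as 0:

            prot_dict['A']                    # 1       (1 << 0)
            prot_dict['R']                    # 2       (1 << 1)
            prot_dict['V']                    # 2097152 (1 << 21)
            protein_sequence_one_hot('ARX')   # [1, 2, 0]

        Caveat: the literal for M (0b000000000010000000000001 == 8193) sets two
        bits and overlaps A's bit; a strictly one-hot code for M would be
        1 << 13 == 8192.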
44 | 45 | Returns: 46 | _type_: _description_ 47 | """ 48 | res = [prot_dict.get(item) for item in protein_seq] 49 | if padding==True: 50 | if len(protein_seq)>=padding_window: 51 | res= res[:padding_window] 52 | else: 53 | res=np.pad(res, (0,(padding_window-len(protein_seq))), 'median') 54 | return list(res) 55 | #endregion 56 | 57 | 58 | #region encode amino acid sequences dataframe 59 | def get_onehot(sequences, padding=True, padding_window=1500): 60 | """encode amino acid sequences dataframe 61 | 62 | Args: 63 | sequences (DataFrame): sequences dataframe cols name must contain ['id, 'seq'] 64 | padding (bool, optional): if padding to a fixed length. Defaults to True. 65 | padding_window (int, optional): fixed padding size. Defaults to 1500. 66 | 67 | Returns: 68 | DataFrame: one-hot represented sequences DataFrame 69 | """ 70 | sequences = sequences[['id','seq']].copy() 71 | sequences['onehot']=sequences.parallel_apply(lambda x: protein_sequence_one_hot(protein_seq=x.seq, padding=padding, padding_window=padding_window), axis=1) 72 | one_hot_rep = pd.DataFrame(np.array(list(itertools.chain(*sequences.onehot.values))).reshape(-1,(padding_window)), columns=['f'+str(i) for i in range (1,(padding_window+1))]) 73 | one_hot_rep.insert(0,'id',value=sequences.id.values) 74 | 75 | return one_hot_rep 76 | #endregion 77 | 78 | if __name__ =='__main__': 79 | seqs = pd.DataFrame([['seq1', 'MTTSVIVAGARTPIGKLMGSLKDFSASELGAIAIKGALEKANVPAS'], 80 | ['seq2', 'MAERAPRGEVAVMVAVQSALVDRPGMLATARGLSHFGEHCIGWLIL'] 81 | ], columns=['id', 'seq']) 82 | pandarallel.initialize() 83 | res = get_onehot(sequences=seqs, padding=True, padding_window=50) 84 | print(res) -------------------------------------------------------------------------------- /tools/embedding_esm.py: -------------------------------------------------------------------------------- 1 | from esm import model 2 | import torch 3 | import esm 4 | import re 5 | from tqdm import tqdm 6 | import numpy as np 7 | import pandas as pd 8 | import sys 9 | sys.path.insert(0, "../../") 10 | import config as cfg 11 | 12 | 13 | # region 将字符串拆分成固定长度 14 | def cut_text(text,lenth): 15 | """[将字符串拆分成固定长度] 16 | 17 | Args: 18 | text ([string]): [input string] 19 | lenth ([int]): [sub_sequence length] 20 | 21 | Returns: 22 | [string list]: [string results list] 23 | """ 24 | textArr = re.findall('.{'+str(lenth)+'}', text) 25 | textArr.append(text[(len(textArr)*lenth):]) 26 | return textArr 27 | #endregion 28 | 29 | #region 对单个序列进行embedding 30 | def get_rep_single_seq(seqid, sequence, model,batch_converter, seqthres=1022): 31 | """[对单个序列进行embedding] 32 | 33 | Args: 34 | seqid ([string]): [sequence name]] 35 | sequence ([sting]): [sequence] 36 | model ([model]): [ embedding model]] 37 | batch_converter ([object]): [description] 38 | seqthres (int, optional): [max sequence length]. Defaults to 1022. 
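    Example (illustrative): sequences longer than seqthres are split with
        cut_text() and the per-chunk mean representations are averaged back
        into one vector. Note that cut_text() appends the remainder slice,
        which is an empty string when the length divides evenly:

            cut_text('MKVLAT', 4)     # ['MKVL', 'AT']
            cut_text('MKVLATGG', 4)   # ['MKVL', 'ATGG', '']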
39 | 40 | Returns: 41 | [type]: [description] 42 | """ 43 | 44 | if len(sequence) < seqthres: 45 | data =[(seqid, sequence)] 46 | else: 47 | seqArray = cut_text(sequence, seqthres) 48 | data=[] 49 | for item in seqArray: 50 | data.append((seqid, item)) 51 | batch_labels, batch_strs, batch_tokens = batch_converter(data) 52 | 53 | if torch.cuda.is_available(): 54 | batch_tokens = batch_tokens.to(device="cuda", non_blocking=True) 55 | 56 | REP_LAYERS = [0, 32, 33] 57 | MINI_SIZE = len(batch_labels) 58 | 59 | with torch.no_grad(): 60 | results = model(batch_tokens, repr_layers=REP_LAYERS, return_contacts=False) 61 | 62 | 63 | representations = {layer: t.to(device="cpu") for layer, t in results["representations"].items()} 64 | result ={} 65 | result["label"] = batch_labels[0] 66 | 67 | for i in range(MINI_SIZE): 68 | if i ==0: 69 | result["mean_representations"] = {layer: t[i, 1 : len(batch_strs[0]) + 1].mean(0).clone() for layer, t in representations.items()} 70 | else: 71 | for index, layer in enumerate(REP_LAYERS): 72 | result["mean_representations"][layer] += {layer: t[i, 1 : len(batch_strs[0]) + 1].mean(0).clone() for layer, t in representations.items()}[layer] 73 | 74 | for index, layer in enumerate(REP_LAYERS): 75 | result["mean_representations"][layer] = result["mean_representations"][layer] /MINI_SIZE 76 | 77 | return result 78 | #endregion 79 | 80 | #region 对多个序列进行embedding 81 | def get_rep_multi_sequence(sequences, model='esm_msa1b_t12_100M_UR50S', repr_layers=[0, 32, 33], seqthres=1022): 82 | """[对多个序列进行embedding] 83 | Args: 84 | sequences ([DataFrame]): [ sequence info]] 85 | seqthres (int, optional): [description]. Defaults to 1022. 86 | 87 | Returns: 88 | [DataFrame]: [final_rep0, final_rep32, final_rep33] 89 | """ 90 | final_label_list = [] 91 | final_rep0 =[] 92 | final_rep32 =[] 93 | final_rep33 =[] 94 | 95 | model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() 96 | batch_converter = alphabet.get_batch_converter() 97 | if torch.cuda.is_available(): 98 | model = model.cuda() 99 | print("Transferred model to GPU") 100 | 101 | for i in tqdm(range(len(sequences))): 102 | apd = get_rep_single_seq( 103 | seqid = sequences.iloc[i].id, 104 | sequence=sequences.iloc[i].seq, 105 | model=model, 106 | batch_converter=batch_converter, 107 | seqthres=seqthres) 108 | 109 | final_label_list.append(np.array(apd['label'])) 110 | final_rep0.append(np.array(apd['mean_representations'][0])) 111 | final_rep32.append(np.array(apd['mean_representations'][32])) 112 | final_rep33.append(np.array(apd['mean_representations'][33])) 113 | 114 | final_rep0 = pd.DataFrame(final_rep0) 115 | final_rep32 = pd.DataFrame(final_rep32) 116 | final_rep33 = pd.DataFrame(final_rep33) 117 | final_rep0.insert(loc=0, column='id', value=np.array(final_label_list).flatten()) 118 | final_rep32.insert(loc=0, column='id', value=np.array(final_label_list).flatten()) 119 | final_rep33.insert(loc=0, column='id', value=np.array(final_label_list).flatten()) 120 | 121 | col_name = ['id']+ ['f'+str(i) for i in range (1,final_rep0.shape[1])] 122 | final_rep0.columns = col_name 123 | final_rep32.columns = col_name 124 | final_rep33.columns = col_name 125 | 126 | return final_rep0, final_rep32, final_rep33 127 | #endregion 128 | 129 | if __name__ =='__main__': 130 | SEQTHRES = 1022 131 | RUNMODEL = { 'ESM-1b' :'esm1b_t33_650M_UR50S', 132 | 'ESM-MSA-1b' :'esm_msa1b_t12_100M_UR50S' 133 | } 134 | train = pd.read_feather(cfg.DATADIR+'train.feather').iloc[:,:6] 135 | test = pd.read_feather(cfg.DATADIR+'test.feather').iloc[:,:6] 136 | 
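    # Note on the call below: get_rep_multi_sequence() never uses its `model`
    # argument -- it always loads esm1b_t33_650M_UR50S internally -- so passing
    # the 'ESM-MSA-1b' key still produces ESM-1b embeddings.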
rep0, rep32, rep33 = get_rep_multi_sequence(sequences=train, model=RUNMODEL.get('ESM-MSA-1b'),seqthres=SEQTHRES) 137 | 138 | rep0.to_feather(cfg.DATADIR + 'train_rep0.feather') 139 | rep32.to_feather(cfg.DATADIR + 'train_rep32.feather') 140 | rep33.to_feather(cfg.DATADIR + 'train_rep33.feather') 141 | 142 | print('Embedding Success!') -------------------------------------------------------------------------------- /tools/exact_ec_from_uniprot.py: -------------------------------------------------------------------------------- 1 | from Bio import SeqIO 2 | import gzip 3 | import re 4 | from tqdm import tqdm 5 | import time 6 | import sys,os 7 | sys.path.append(os.getcwd()) 8 | import config as cfg 9 | import pandas as pd 10 | 11 | #region 从gizp读取数据 12 | def read_file_from_gzip(file_in_path, file_out_path, extract_type, save_file_type='tsv'): 13 | """从原始Zip file 中解析数据 14 | 15 | Args: 16 | file_in_path ([string]): [输入文件路径] 17 | file_out_path ([string]): [输出文件路径]] 18 | extract_type ([string]): [抽取数据的类型:with_ec, without_ec, full]] 19 | """ 20 | if save_file_type == 'feather': 21 | outpath = file_out_path 22 | file_out_path = cfg.TEMPDIR +'temprecords.tsv' 23 | 24 | table_head = [ 'id', 25 | 'name', 26 | 'isenzyme', 27 | 'isMultiFunctional', 28 | 'functionCounts', 29 | 'ec_number', 30 | 'ec_specific_level', 31 | 'date_integraged', 32 | 'date_sequence_update', 33 | 'date_annotation_update', 34 | 'seq', 35 | 'seqlength' 36 | ] 37 | counter = 0 38 | saver = 0 39 | file_write_obj = open(file_out_path,'w') 40 | with gzip.open(file_in_path, "rt") as handle: 41 | file_write_obj.writelines('\t'.join(table_head)) 42 | file_write_obj.writelines('\n') 43 | for record in tqdm( SeqIO.parse(handle, 'swiss'), position=1, leave=True): 44 | res = process_record(record, extract_type= extract_type) 45 | counter+=1 46 | if counter %10==0: 47 | file_write_obj.flush() 48 | # if saver%100000==0: 49 | # print(saver) 50 | if len(res) >0: 51 | saver +=1 52 | file_write_obj.writelines('\t'.join(map(str,res))) 53 | file_write_obj.writelines('\n') 54 | else: 55 | continue 56 | file_write_obj.close() 57 | 58 | if save_file_type == 'feather': 59 | indata = pd.read_csv(cfg.TEMPDIR +'temprecords.tsv', sep='\t') 60 | indata.to_feather(outpath) 61 | 62 | 63 | #endregion 64 | 65 | #region 提取单条含有EC号的数据 66 | def process_record(record, extract_type='with_ec'): 67 | """ 68 | 提取单条含有EC号的数据 69 | Args: 70 | record ([type]): uniprot 中的记录节点 71 | extract_type (string, optional): 提取的类型,可选值:with_ec, without_ec, full。默认为有EC号(with_ec). 
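    Example (illustrative): the EC string is rebuilt from the *list repr* of
        re.findall, so the comma separating multifunctional enzymes comes from
        Python's list formatting:

            description = 'RecName: Full=...; EC=1.1.1.1; EC=1.1.1.2;'
            re.findall(r"EC=[0-9,.\-;]*", description)
            # ['EC=1.1.1.1;', 'EC=1.1.1.2;']
            # str(...) plus the chained .replace() calls -> '1.1.1.1, 1.1.1.2',
            # hence isMultiFunctional == True and functionCounts == 2.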
72 | Returns: 73 | [type]: [description] 74 | """ 75 | 76 | 77 | description = record.description 78 | isEnzyme = 'EC=' in description #有EC号的被认为是酶,否则认为是酶 79 | isMultiFunctional = False 80 | functionCounts = 0 81 | ec_specific_level =0 82 | 83 | 84 | if isEnzyme: 85 | ec = str(re.findall(r"EC=[0-9,.\-;]*",description) 86 | ).replace('EC=','').replace('\'','').replace(']','').replace('[','').replace(';','').strip() 87 | 88 | #统计酶的功能数 89 | isMultiFunctional = ',' in ec 90 | functionCounts = ec.count(',') + 1 91 | 92 | 93 | # - 单功能酶 94 | if not isMultiFunctional: 95 | levelCount = ec.count('-') 96 | ec_specific_level = 4-levelCount 97 | 98 | else: # -多功能酶 99 | ecarray = ec.split(',') 100 | for subec in ecarray: 101 | current_ec_level = 4- subec.count('-') 102 | if ec_specific_level < current_ec_level: 103 | ec_specific_level = current_ec_level 104 | else: 105 | ec = '-' 106 | 107 | id = record.id.strip() 108 | name = record.name.strip() 109 | seq = record.seq.strip() 110 | date_integrated = record.annotations.get('date').strip() 111 | date_sequence_update = record.annotations.get('date_last_sequence_update').strip() 112 | date_annotation_update = record.annotations.get('date_last_annotation_update').strip() 113 | seqlength = len(seq) 114 | res = [id, name, isEnzyme, isMultiFunctional, functionCounts, ec,ec_specific_level, date_integrated, date_sequence_update, date_annotation_update, seq, seqlength] 115 | 116 | if extract_type == 'full': 117 | return res 118 | 119 | if extract_type == 'with_ec': 120 | if isEnzyme: 121 | return res 122 | else: 123 | return [] 124 | 125 | if extract_type == 'without_ec': 126 | if isEnzyme: 127 | return [] 128 | else: 129 | return res 130 | #endregion 131 | 132 | def run_exact_task(infile, outfile): 133 | start = time.process_time() 134 | extract_type ='full' 135 | read_file_from_gzip(file_in_path=infile, file_out_path=outfile, extract_type=extract_type) 136 | end = time.process_time() 137 | print('finished use time %6.3f s' % (end - start)) 138 | 139 | if __name__ =="__main__": 140 | start = time.process_time() 141 | in_filepath_sprot = cfg.FILE_LATEST_SPROT 142 | out_filepath_sprot = cfg.FILE_LATEST_SPROT_FEATHER 143 | 144 | in_filepath_trembl = cfg.FILE_LATEST_TREMBL 145 | out_filepath_trembl = cfg.FILE_LATEST_TREMBL_FEATHER 146 | 147 | extract_type ='full' 148 | 149 | 150 | # read_file_from_gzip(file_in_path=in_filepath_sprot, file_out_path=out_filepath_sprot, extract_type=extract_type, save_file_type='feather') 151 | 152 | read_file_from_gzip(file_in_path=in_filepath_trembl, file_out_path=out_filepath_trembl, extract_type=extract_type, save_file_type='feather') 153 | end = time.process_time() 154 | print('finished use time %6.3f s' % (end - start)) 155 | 156 | -------------------------------------------------------------------------------- /tools/filetool.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Zhenkun Shi 3 | Date: 2022-10-05 05:24:14 4 | LastEditors: Zhenkun Shi 5 | LastEditTime: 2022-10-05 05:25:15 6 | FilePath: /DMLF/tools/filetool.py 7 | Description: 8 | 9 | Copyright (c) 2022 by tibd, All Rights Reserved. 
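Usage sketch (illustrative; URL and paths are placeholders):

    from tools import filetool
    filetool.wget('https://example.org/data.zip', '/tmp/data.zip')  # needs wget on PATH
    filetool.unzipfile('/tmp/data.zip', '/tmp/data')
    tsvs = filetool.get_file_names_in_dir('/tmp/data', filetype='tsv')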
10 | ''' 11 | 12 | from retrying import retry 13 | from matplotlib.pyplot import axis 14 | import urllib3 15 | import pandas as pd 16 | import os 17 | import sys 18 | from shutil import copyfile 19 | from sys import exit 20 | import zipfile 21 | 22 | 23 | #region 下载文件 24 | @retry(stop_max_attempt_number=10, wait_random_min=10, wait_random_max=20) 25 | def download(download_url, save_file): 26 | """[从网址下载文件] 27 | Args: 28 | download_url ([Stirng]): [下载路径] 29 | save_file ([String]): [保存文件路径] 30 | """ 31 | http = urllib3.PoolManager() 32 | response = http.request('GET', download_url) 33 | with open(save_file, 'wb') as f: 34 | f.write(response.data) 35 | response.release_conn() 36 | #endregion 37 | 38 | 39 | def wget(download_url, save_file, verbos=False): 40 | process = os.popen('which wget') # return file 41 | output = process.read() 42 | if output =='': 43 | print('wget not installed') 44 | else: 45 | if verbos: 46 | cmd = 'wget ' + download_url + ' -O ' + save_file 47 | else: 48 | cmd = 'wget -q ' + download_url + ' -O ' + save_file 49 | print (cmd) 50 | process = os.popen(cmd) 51 | output = process.read() 52 | process.close() 53 | 54 | def convert_DF_dateTime(inputdf): 55 | """[Covert unisprot csv records datatime] 56 | 57 | Args: 58 | inputdf ([DataFrame]): [input dataFrame] 59 | 60 | Returns: 61 | [DataFrame]: [converted DataFrame] 62 | """ 63 | inputdf.date_integraged = pd.to_datetime(inputdf['date_integraged']) 64 | inputdf.date_sequence_update = pd.to_datetime(inputdf['date_sequence_update']) 65 | inputdf.date_annotation_update = pd.to_datetime(inputdf['date_annotation_update']) 66 | inputdf = inputdf.sort_values(by='date_integraged', ascending=True) 67 | inputdf.reset_index(drop=True, inplace=True) 68 | return inputdf 69 | 70 | def get_file_names_in_dir(dataroot, filetype='all'): 71 | """返回某个文件夹下的文件列表,可以指定文件类型 72 | 73 | Args: 74 | dataroot (string): 文件夹目录 75 | filetype (str, optional): 文件类型. Defaults to ''. 76 | 77 | Returns: 78 | DataFrame: columns=['filename','filetype','filename_no_suffix'] 79 | """ 80 | exist_file_df = pd.DataFrame(os.listdir(dataroot), columns=['filename']) 81 | 82 | if len(exist_file_df)!=0: 83 | exist_file_df['filetype'] = exist_file_df.filename.apply(lambda x: x.split('.')[-1]) 84 | exist_file_df['filename_no_suffix'] = exist_file_df.apply(lambda x: x['filename'].replace(('.'+str(x['filetype']).strip()), ''), axis=1) 85 | if filetype !='all': 86 | exist_file_df = exist_file_df[exist_file_df.filetype==filetype] 87 | return exist_file_df 88 | else: 89 | return pd.DataFrame(columns=['filename','filetype','filename_no_suffix']) 90 | 91 | 92 | def copy(source, target): 93 | """ 拷贝文件 94 | 95 | Args: 96 | source (string): source 97 | target (string): target 98 | """ 99 | try: 100 | copyfile(src=source, dst=target) 101 | except IOError as e: 102 | print("Unable to copy file. %s" % e) 103 | exit(1) 104 | except: 105 | print("Unexpected error:", sys.exc_info()) 106 | exit(1) 107 | 108 | def delete(filepath): 109 | """删除文件 110 | 111 | Args: 112 | filepath (string): 文件全路径 113 | """ 114 | try: 115 | os.remove(filepath) 116 | except IOError as e: 117 | print("Unable to delete file. 
%s" % e) 118 | exit(1) 119 | except: 120 | print("Unexpected error:", sys.exc_info()) 121 | exit(1) 122 | 123 | #region unzip file 124 | def unzipfile(filename, target_dir): 125 | """uzip file 126 | 127 | Args: 128 | zipfile (string): zip file full path 129 | target_dir (string): target dir 130 | """ 131 | with zipfile.ZipFile(filename,"r") as zip_ref: 132 | zip_ref.extractall(target_dir) 133 | #endregion 134 | 135 | 136 | def isfileExists(filepath): 137 | return os.path.exists(filepath) 138 | -------------------------------------------------------------------------------- /tools/funclib.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | from sklearn.model_selection import train_test_split 3 | from sklearn import metrics 4 | from sklearn import linear_model 5 | from sklearn.svm import SVC 6 | from sklearn import tree 7 | from tkinter import _flatten 8 | from sklearn.ensemble import RandomForestClassifier 9 | from sklearn.ensemble import GradientBoostingClassifier 10 | from sklearn.neighbors import KNeighborsClassifier 11 | from xgboost import XGBClassifier 12 | from Bio import SeqIO 13 | 14 | import pandas as pd 15 | import numpy as np 16 | import os 17 | 18 | 19 | table_head = [ 'id', 20 | 'name', 21 | 'isemzyme', 22 | 'isMultiFunctional', 23 | 'functionCounts', 24 | 'ec_number', 25 | 'ec_specific_level', 26 | 'date_integraged', 27 | 'date_sequence_update', 28 | 'date_annotation_update', 29 | 'seq', 30 | 'seqlength' 31 | ] 32 | 33 | def table2fasta(table, file_out): 34 | file = open(file_out, 'w') 35 | for index, row in table.iterrows(): 36 | file.write('>{0}\n'.format(row['id'])) 37 | file.write('{0}\n'.format(row['seq'])) 38 | file.close() 39 | print('Write finished') 40 | 41 | #氨基酸字典(X未知项用于数据对齐) 42 | prot_dict = dict( 43 | A=1, R=2, N=3, D=4, C=5, E=6, Q=7, G=8, H=9, O=10, I=11, L=12, 44 | K=13, M=14, F=15, P=16, U=17, S=18, T=19, W=20, Y=21, V=22, B=23, Z=24, X=0 45 | ) 46 | 47 | # one-hot 编码 48 | def dna_onehot_outdateed(Xdna): 49 | listtmp = list() 50 | for index, row in Xdna.iterrows(): 51 | row = [prot_dict[x] if x in prot_dict else x for x in row['seq']] 52 | listtmp.append(row) 53 | return pd.DataFrame(listtmp) 54 | 55 | # one-hot 编码 56 | def dna_onehot(Xdna): 57 | listtmp = [] 58 | listtmp =Xdna.seq.parallel_apply(lambda x: np.array([prot_dict.get(item) for item in x])) 59 | listtmp = pd.DataFrame(np.stack(listtmp)) 60 | listtmp = pd.concat( [Xdna.iloc[:,0:2], listtmp], axis=1) 61 | return listtmp 62 | 63 | 64 | 65 | 66 | def lrmain(X_train_std, Y_train, X_test_std, Y_test, type='binary'): 67 | logreg = linear_model.LogisticRegression( 68 | solver = 'saga', 69 | multi_class='auto', 70 | verbose=False, 71 | n_jobs=-1, 72 | max_iter=10000 73 | ) 74 | # sc = StandardScaler() 75 | # X_train_std = sc.fit_transform(X_train_std) 76 | logreg.fit(X_train_std, Y_train) 77 | predict = logreg.predict(X_test_std) 78 | lrpredpro = logreg.predict_proba(X_test_std) 79 | groundtruth = Y_test 80 | return groundtruth, predict, lrpredpro, logreg 81 | 82 | def knnmain(X_train_std, Y_train, X_test_std, Y_test, type='binary'): 83 | knn=KNeighborsClassifier(n_neighbors=5, n_jobs=16) 84 | knn.fit(X_train_std, Y_train) 85 | predict = knn.predict(X_test_std) 86 | lrpredpro = knn.predict_proba(X_test_std) 87 | groundtruth = Y_test 88 | return groundtruth, predict, lrpredpro, knn 89 | 90 | def svmmain(X_train_std, Y_train, X_test_std, Y_test): 91 | svcmodel = SVC(probability=True, kernel='rbf', tol=0.001) 92 | 
svcmodel.fit(X_train_std, Y_train.ravel(), sample_weight=None) 93 | predict = svcmodel.predict(X_test_std) 94 | predictprob =svcmodel.predict_proba(X_test_std) 95 | groundtruth = Y_test 96 | return groundtruth, predict, predictprob,svcmodel 97 | 98 | def xgmain(X_train_std, Y_train, X_test_std, Y_test, type='binary', vali=True): 99 | 100 | x_train, x_vali, y_train, y_vali = train_test_split(X_train_std, Y_train, test_size=0.2, random_state=1) 101 | eval_set = [(x_train, y_train), (x_vali, y_vali)] 102 | 103 | if type=='binary': 104 | # model = XGBClassifier( 105 | # objective='binary:logistic', 106 | # random_state=15, 107 | # use_label_encoder=False, 108 | # n_jobs=-1, 109 | # eval_metric='mlogloss', 110 | # min_child_weight=15, 111 | # max_depth=15, 112 | # n_estimators=300, 113 | # tree_method = 'gpu_hist', 114 | # learning_rate = 0.01 115 | # ) 116 | model = XGBClassifier( 117 | objective='binary:logistic', 118 | random_state=13, 119 | use_label_encoder=False, 120 | n_jobs=-2, 121 | eval_metric='auc', 122 | max_depth=6, 123 | n_estimators= 300, 124 | tree_method = 'gpu_hist', 125 | learning_rate = 0.01, 126 | gpu_id=0 127 | ) 128 | model.fit(x_train, y_train, eval_set=eval_set, verbose=False) 129 | # model.fit(t1_x_train, t1_y_train, eval_set=t1_eval_set, verbose=100) 130 | if type=='multi': 131 | model = XGBClassifier( 132 | min_child_weight=6, 133 | max_depth=6, 134 | objective='multi:softmax', 135 | num_class=len(set(Y_train)), 136 | use_label_encoder=False, 137 | n_estimators=120 138 | ) 139 | if vali: 140 | model.fit(x_train, y_train, eval_metric="mlogloss", eval_set=eval_set, verbose=False) 141 | else: 142 | model.fit(X_train_std, Y_train, eval_metric="mlogloss", eval_set=None, verbose=False) 143 | 144 | predict = model.predict(X_test_std) 145 | predictprob = model.predict_proba(X_test_std) 146 | groundtruth = Y_test 147 | return groundtruth, predict, predictprob, model 148 | 149 | def dtmain(X_train_std, Y_train, X_test_std, Y_test): 150 | model = tree.DecisionTreeClassifier() 151 | model.fit(X_train_std, Y_train.ravel()) 152 | predict = model.predict(X_test_std) 153 | predictprob = model.predict_proba(X_test_std) 154 | groundtruth = Y_test 155 | return groundtruth, predict, predictprob,model 156 | 157 | def rfmain(X_train_std, Y_train, X_test_std, Y_test): 158 | model = RandomForestClassifier(oob_score=True, random_state=10, n_jobs=-2) 159 | model.fit(X_train_std, Y_train.ravel()) 160 | predict = model.predict(X_test_std) 161 | predictprob = model.predict_proba(X_test_std) 162 | groundtruth = Y_test 163 | return groundtruth, predict, predictprob,model 164 | 165 | def gbdtmain(X_train_std, Y_train, X_test_std, Y_test): 166 | model = GradientBoostingClassifier(random_state=10) 167 | model.fit(X_train_std, Y_train.ravel()) 168 | predict = model.predict(X_test_std) 169 | predictprob = model.predict_proba(X_test_std) 170 | groundtruth = Y_test 171 | return groundtruth, predict, predictprob, model 172 | 173 | 174 | def caculateMetrix(groundtruth, predict, baselineName, type='binary'): 175 | acc = metrics.accuracy_score(groundtruth, predict) 176 | if type == 'binary': 177 | precision = metrics.precision_score(groundtruth, predict, zero_division=True ) 178 | recall = metrics.recall_score(groundtruth, predict, zero_division=True) 179 | f1 = metrics.f1_score(groundtruth, predict, zero_division=True) 180 | tn, fp, fn, tp = metrics.confusion_matrix(groundtruth, predict).ravel() 181 | npv = tn/(fn+tn+1.4E-45) 182 | print(baselineName, '\t\t%f' %acc,'\t%f'% precision,'\t\t%f'%npv,'\t%f'% 
recall,'\t%f'% f1, '\t', 'tp:',tp,'fp:',fp,'fn:',fn,'tn:',tn) 183 | 184 | if type == 'multi': 185 | precision = metrics.precision_score(groundtruth, predict, average='macro', zero_division=True ) 186 | recall = metrics.recall_score(groundtruth, predict, average='macro', zero_division=True) 187 | f1 = metrics.f1_score(groundtruth, predict, average='macro', zero_division=True) 188 | print('%12s'%baselineName, ' \t\t%f '%acc,'\t%f'% precision, '\t\t%f'% recall,'\t%f'% f1) 189 | 190 | 191 | def evaluate(baslineName, X_train_std, Y_train, X_test_std, Y_test, type='binary'): 192 | 193 | if baslineName == 'lr': 194 | groundtruth, predict, predictprob, model = lrmain (X_train_std, Y_train, X_test_std, Y_test, type=type) 195 | elif baslineName == 'svm': 196 | groundtruth, predict, predictprob, model = svmmain(X_train_std, Y_train, X_test_std, Y_test) 197 | elif baslineName =='xg': 198 | groundtruth, predict, predictprob, model = xgmain(X_train_std, Y_train, X_test_std, Y_test, type=type) 199 | elif baslineName =='dt': 200 | groundtruth, predict, predictprob, model = dtmain(X_train_std, Y_train, X_test_std, Y_test) 201 | elif baslineName =='rf': 202 | groundtruth, predict, predictprob, model = rfmain(X_train_std, Y_train, X_test_std, Y_test) 203 | elif baslineName =='gbdt': 204 | groundtruth, predict, predictprob, model = gbdtmain(X_train_std, Y_train, X_test_std, Y_test) 205 | elif baslineName =='knn': 206 | groundtruth, predict, predictprob, model = knnmain(X_train_std, Y_train, X_test_std, Y_test, type=type) 207 | else: 208 | print('Baseline Name Errror') 209 | 210 | caculateMetrix(groundtruth=groundtruth,predict=predict,baselineName=baslineName, type=type) 211 | 212 | def evaluate_2(baslineName, X_train_std, Y_train, X_test_std, Y_test, type='binary'): 213 | 214 | if baslineName == 'lr': 215 | groundtruth, predict, predictprob, model = lrmain (X_train_std, Y_train, X_test_std, Y_test, type=type) 216 | elif baslineName == 'svm': 217 | groundtruth, predict, predictprob, model = svmmain(X_train_std, Y_train, X_test_std, Y_test) 218 | elif baslineName =='xg': 219 | groundtruth, predict, predictprob, model = xgmain(X_train_std, Y_train, X_test_std, Y_test, type=type, vali=False) 220 | elif baslineName =='dt': 221 | groundtruth, predict, predictprob, model = dtmain(X_train_std, Y_train, X_test_std, Y_test) 222 | elif baslineName =='rf': 223 | groundtruth, predict, predictprob, model = rfmain(X_train_std, Y_train, X_test_std, Y_test) 224 | elif baslineName =='gbdt': 225 | groundtruth, predict, predictprob, model = gbdtmain(X_train_std, Y_train, X_test_std, Y_test) 226 | elif baslineName =='knn': 227 | groundtruth, predict, predictprob, model = knnmain(X_train_std, Y_train, X_test_std, Y_test, type=type) 228 | else: 229 | print('Baseline Name Errror') 230 | 231 | caculateMetrix(groundtruth=groundtruth,predict=predict,baselineName=baslineName, type=type) 232 | 233 | def run_baseline(X_train, Y_train, X_test, Y_test, type='binary'): 234 | methods=['knn','lr', 'xg', 'dt', 'rf', 'gbdt'] 235 | if type == 'binary': 236 | print('baslineName', '\t', 'accuracy','\t', 'precision(PPV) \t NPV \t\t', 'recall','\t', 'f1', '\t\t', '\t\t confusion Matrix') 237 | if type =='multi': 238 | print('%12s'%'baslineName', '\t\t', 'accuracy','\t', 'precision-macro \t', 'recall-macro','\t', 'f1-macro') 239 | for method in methods: 240 | evaluate(method, X_train, Y_train, X_test, Y_test, type=type) 241 | 242 | def run_baseline_2(X_train, Y_train, X_test, Y_test, type='binary'): 243 | methods=['knn', 'xg', 'dt', 'rf', 'gbdt'] 244 
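    # Same as run_baseline() except 'lr' is dropped and, via evaluate_2, XGBoost
    # is fit on the full training set without the internal validation split
    # (vali=False in xgmain).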
| if type == 'binary': 245 | print('baslineName', '\t', 'accuracy','\t', 'precision(PPV) \t NPV \t\t', 'recall','\t', 'f1', '\t\t', '\t\t confusion Matrix') 246 | if type =='multi': 247 | print('%12s'%'baslineName', '\t\t', 'accuracy','\t', 'precision-macro \t', 'recall-macro','\t', 'f1-macro') 248 | for method in methods: 249 | evaluate_2(method, X_train, Y_train, X_test, Y_test, type=type) 250 | 251 | 252 | def static_interval(data, span): 253 | """[summary] 254 | 255 | Args: 256 | data ([dataframe]): [需要统计的数据] 257 | span ([int]): [统计的间隔] 258 | 259 | Returns: 260 | [type]: [完成的统计列表]] 261 | """ 262 | res = [] 263 | count = 0 264 | for i in range(int(len(data)/span) + 1): 265 | lable = str(i*span) + '-' + str((i+1)* span -1 ) 266 | num = data[(data.length>=(i*span)) & (data.length<(i+1)*span)]['count'].sum() 267 | res += [[lable, num]] 268 | return res 269 | 270 | 271 | 272 | def getblast(train, test): 273 | 274 | table2fasta(train, '/tmp/train.fasta') 275 | table2fasta(test, '/tmp/test.fasta') 276 | 277 | cmd1 = r'diamond makedb --in /tmp/train.fasta -d /tmp/train.dmnd --quiet' 278 | cmd2 = r'diamond blastp -d /tmp/train.dmnd -q /tmp/test.fasta -o /tmp/test_fasta_results.tsv -b5 -c1 -k 1 --quiet' 279 | cmd3 = r'rm -rf /tmp/*.fasta /tmp/*.dmnd /tmp/*.tsv' 280 | print(cmd1) 281 | os.system(cmd1) 282 | print(cmd2) 283 | os.system(cmd2) 284 | res_data = pd.read_csv('/tmp/test_fasta_results.tsv', sep='\t', names=['id', 'sseqid', 'pident', 'length','mismatch','gapopen','qstart','qend','sstart','send','evalue','bitscore']) 285 | os.system(cmd3) 286 | return res_data 287 | 288 | def getblast_usedb(db, test): 289 | table2fasta(test, '/tmp/test.fasta') 290 | cmd2 = r'diamond blastp -d {0} -q /tmp/test.fasta -o /tmp/test_fasta_results.tsv -b5 -c1 -k 1 --quiet'.format(db) 291 | cmd3 = r'rm -rf /tmp/*.fasta /tmp/*.dmnd /tmp/*.tsv' 292 | 293 | print(cmd2) 294 | os.system(cmd2) 295 | res_data = pd.read_csv('/tmp/test_fasta_results.tsv', sep='\t', names=['id', 'sseqid', 'pident', 'length','mismatch','gapopen','qstart','qend','sstart','send','evalue','bitscore']) 296 | os.system(cmd3) 297 | return res_data 298 | 299 | def getblast_fasta(trainfasta, testfasta): 300 | 301 | cmd1 = r'diamond makedb --in {0} -d /tmp/train.dmnd --quiet'.format(trainfasta) 302 | cmd2 = r'diamond blastp -d /tmp/train.dmnd -q {0} -o /tmp/test_fasta_results.tsv -b8 -c1 -k 1 --quiet'.format(testfasta) 303 | cmd3 = r'rm -rf /tmp/*.fasta /tmp/*.dmnd /tmp/*.tsv' 304 | # print(cmd1) 305 | os.system(cmd1) 306 | # print(cmd2) 307 | os.system(cmd2) 308 | res_data = pd.read_csv('/tmp/test_fasta_results.tsv', sep='\t', names=['id', 'sseqid', 'pident', 'length','mismatch','gapopen','qstart','qend','sstart','send','evalue','bitscore']) 309 | os.system(cmd3) 310 | return res_data 311 | 312 | 313 | #region 统计EC数 314 | def stiatistic_ec_num(eclist): 315 | """统计EC数量 316 | 317 | Args: 318 | eclist (list): 可包含多功能酶的EC列表,用;分割 319 | 320 | Returns: 321 | int: 列表中包含的独立EC数量 322 | """ 323 | eclist = list(eclist.flatten()) #展成1维 324 | eclist = _flatten([item.split(';') for item in eclist]) #分割多功能酶 325 | eclist = [item.strip() for item in eclist] # 去空格 326 | num_ecs = len(set(eclist)) 327 | return num_ecs 328 | #endregion 329 | 330 | def caculateMetrix_1(baselineName,tp, fp, tn,fn): 331 | sampleNum = tp+fp+tn+fn 332 | accuracy = (tp+tn)/sampleNum 333 | precision = tp/(tp+fp) 334 | npv = tn/(tn+fn) 335 | recall = tp/(tp+fn) 336 | f1 = 2 * (precision * recall) / (precision + recall) 337 | print('{0} \t {1:.6f} \t{2:.6f} \t\t {3:.6f} \t{4:.6f}\t {5:.6f}\t\t \t 
\t \t tp:{6} fp:{7} fn:{8} tn:{9}'.format(baselineName,accuracy, precision, npv, recall, f1, tp,fp,fn,tn)) 338 | 339 | 340 | 341 | def get_integrated_results(res_data, train, test, baslineName): 342 | # 给比对结果添加标签 343 | isEmzyme_dict = {v: k for k,v in zip(train.isemzyme, train.id )} 344 | res_data['diamoion_pred'] = res_data['sseqid'].apply(lambda x: isEmzyme_dict.get(x)) 345 | 346 | blast_res = pd.DataFrame 347 | blast_res = res_data[['id','pident','bitscore', 'diamoion_pred']] 348 | 349 | X_train = train.iloc[:,12:] 350 | X_test = test.iloc[:,12:] 351 | Y_train = train.iloc[:,2].astype('int') 352 | Y_test = test.iloc[:,2].astype('int') 353 | X_train = np.array(X_train) 354 | X_test = np.array(X_test) 355 | Y_train = np.array(Y_train).flatten() 356 | Y_test = np.array(Y_test).flatten() 357 | 358 | if baslineName == 'lr': 359 | groundtruth, predict, predictprob = lrmain (X_train, Y_train, X_test, Y_test) 360 | elif baslineName == 'svm': 361 | groundtruth, predict, predictprob = svmmain(X_train, Y_train, X_test, Y_test) 362 | elif baslineName =='xg': 363 | groundtruth, predict, predictprob = xgmain(X_train, Y_train, X_test, Y_test) 364 | elif baslineName =='dt': 365 | groundtruth, predict, predictprob = dtmain(X_train, Y_train, X_test, Y_test) 366 | elif baslineName =='rf': 367 | groundtruth, predict, predictprob = rfmain(X_train, Y_train, X_test, Y_test) 368 | elif baslineName =='gbdt': 369 | groundtruth, predict, predictprob = gbdtmain(X_train, Y_train, X_test, Y_test) 370 | else: 371 | print('Baseline Name Errror') 372 | 373 | test_res = pd.DataFrame() 374 | test_res[['id', 'name','isemzyme','ec_number']] = test[['id','name','isemzyme','ec_number']] 375 | test_res.reset_index(drop=True, inplace=True) 376 | 377 | #拼合比对结果到测试集 378 | test_merge_res = pd.merge(test_res, blast_res, on='id', how='left') 379 | test_merge_res['xg_pred'] = predict 380 | test_merge_res['xg_pred_prob'] = predictprob 381 | test_merge_res['groundtruth'] = groundtruth 382 | 383 | test_merge_res['final_pred'] = '' 384 | for index, row in test_merge_res.iterrows(): 385 | if (row.diamoion_pred == True) | (row.diamoion_pred == False): 386 | with pd.option_context('mode.chained_assignment', None): 387 | test_merge_res['final_pred'][index] = row.diamoion_pred 388 | else: 389 | with pd.option_context('mode.chained_assignment', None): 390 | test_merge_res['final_pred'][index] = row.xg_pred 391 | 392 | # 计算指标 393 | tp = len(test_merge_res[test_merge_res.groundtruth & test_merge_res.final_pred]) 394 | fp = len(test_merge_res[(test_merge_res.groundtruth ==False) & (test_merge_res.final_pred)]) 395 | tn = len(test_merge_res[(test_merge_res.groundtruth ==False) & (test_merge_res.final_pred ==False)]) 396 | fn = len(test_merge_res[(test_merge_res.groundtruth ) & (test_merge_res.final_pred == False)]) 397 | caculateMetrix_1(baslineName,tp, fp, tn,fn) 398 | 399 | 400 | def run_integrated(res_data, train, test): 401 | methods=['lr','xg', 'dt', 'rf', 'gbdt'] 402 | print('baslineName', '\t\t', 'accuracy','\t', 'precision(PPV) \t NPV \t\t', 'recall','\t', 'f1', '\t\t', 'auroc','\t\t', 'auprc', '\t\t confusion Matrix') 403 | for method in methods: 404 | get_integrated_results(res_data, train, test, method) 405 | 406 | 407 | 408 | #region 将多功能的EC编号展开,返回唯一的EC编号列表 409 | def get_distinct_ec(ecnumbers): 410 | """ 411 | 将多功能的EC编号展开,返回唯一的EC编号列表 412 | Args: 413 | ecnumbers: EC_number 列 414 | 415 | Returns: 排序好的唯一EC列表 416 | 417 | """ 418 | result_list=[] 419 | for item in ecnumbers: 420 | ecarray = item.split(',') 421 | for subitem in ecarray: 422 
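            # strip the blank left by the ', ' separator; e.g.
            # get_distinct_ec(['1.1.1.1, 2.3.1.-', '1.1.1.1']) -> ['1.1.1.1', '2.3.1.-']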
| result_list+=[subitem.strip()] 423 | return sorted(list(set(result_list))) 424 | #endregion 425 | 426 | 427 | 428 | #region 将多功能酶拆解为多个单功能酶 429 | def split_ecdf_to_single_lines(full_table): 430 | """ 431 | 将多功能酶拆解为多个单功能酶 432 | Args: 433 | full_table: 包含EC号的完整列表 434 | 并 1.去除酶号前后空格 435 | 并 2. 将酶号拓展为4位的标准格式 436 | Returns: 展开后的EC列表,每个EC号一行 437 | """ 438 | listres=full_table.parallel_apply(lambda x: split_ecdf_to_single_lines_pr_record(x) , axis=1) 439 | temp_li = [] 440 | for res in tqdm(listres): 441 | for j in res: 442 | temp_li = temp_li + [j] 443 | resDf = pd.DataFrame(temp_li,columns=full_table.columns.values) 444 | 445 | return resDf 446 | #endregion 447 | 448 | 449 | def split_ecdf_to_single_lines_pr_record(row): 450 | resDf = [] 451 | 452 | if row.ec_number.strip()=='-': #若是非酶直接返回 453 | row.ec_number='-' 454 | row.ec_number = row.ec_number.strip() 455 | resDf = row.values 456 | return [[row.id, row.seq, row.ec_number]] 457 | else: 458 | ecs = row.ec_number.split(',') #拆解多功能酶 459 | if len(ecs) ==1: # 单功能酶直接返回 460 | return [[row.id, row.seq, row.ec_number]] 461 | for ec in ecs: 462 | ec = ec.strip() 463 | ecarray=ec.split('.') #拆解每一位 464 | if ecarray[3] == '': #若是最后一位是空,补足_ 465 | ec=ec+'-' 466 | row.ec_number = ec.strip() 467 | resDf = resDf + [[row.id, row.seq, ec]] 468 | return resDf 469 | 470 | 471 | 472 | def load_fasta_to_table(file): 473 | """[Load fasta file to DataFrame] 474 | 475 | Args: 476 | file ([string]): [fasta file location] 477 | 478 | Returns: 479 | [DataFrame]: [loaded fasta in DF format] 480 | """ 481 | if os.path.exists(file) == False: 482 | print('file not found:{0}'.format(file)) 483 | return nullcontext 484 | 485 | input_data=[] 486 | for record in SeqIO.parse(file, format='fasta'): 487 | input_data=input_data +[[record.id, str(record.seq)]] 488 | input_df = pd.DataFrame(input_data, columns=['id','seq']) 489 | return input_df 490 | 491 | def load_deepec_resluts(filepath): 492 | """load deepec predicted resluts 493 | 494 | Args: 495 | filepath (string): deepec predicted file 496 | 497 | Returns: 498 | DataFrame: columns=['id', 'ec_deepec'] 499 | """ 500 | res_deepec = pd.read_csv(f'{filepath}', sep='\t',names=['id', 'ec_number'], header=0 ) 501 | res_deepec.ec_number=res_deepec.apply(lambda x: x['ec_number'].replace('EC:',''), axis=1) 502 | res_deepec.columns = ['id','ec_deepec'] 503 | res = [] 504 | for index, group in res_deepec.groupby('id'): 505 | if len(group)==1: 506 | res = res + [[group.id.values[0], group.ec_deepec.values[0]]] 507 | else: 508 | ecs_str = ','.join(group.ec_deepec.values) 509 | res = res +[[group.id.values[0],ecs_str]] 510 | res_deepec = pd.DataFrame(res, columns=['id', 'ec_deepec']) 511 | return res_deepec 512 | 513 | 514 | #region 515 | def load_praim_res(resfile): 516 | """[加载PRIAM的预测结果] 517 | Args: 518 | resfile ([string]): [结果文件] 519 | Returns: 520 | [DataFrame]: [结果] 521 | """ 522 | f = open(resfile) 523 | line = f.readline() 524 | counter =0 525 | reslist=[] 526 | lstr ='' 527 | subec=[] 528 | while line: 529 | if '>' in line: 530 | if counter !=0: 531 | reslist +=[[lstr, ', '.join(subec)]] 532 | subec=[] 533 | lstr = line.replace('>', '').replace('\n', '') 534 | elif line.strip()!='': 535 | ecarray = line.split('\t') 536 | subec += [(ecarray[0].replace('#', '').replace('\n', '').replace(' ', '') )] 537 | 538 | line = f.readline() 539 | counter +=1 540 | f.close() 541 | res_priam=pd.DataFrame(reslist, columns=['id', 'ec_priam']) 542 | return res_priam 543 | #endregion 544 | 545 | def load_catfam_res(resfile): 546 | res_catfam = 
pd.read_csv(resfile, sep='\t', names=['id', 'ec_catfam']) 547 | return res_catfam 548 | 549 | 550 | def load_ecpred_res(resfile): 551 | res_ecpred = pd.read_csv(f'{resfile}', sep='\t', header=0) 552 | res_ecpred = res_ecpred.rename(columns={'Protein ID':'id','EC Number':'ec_ecpred','Confidence Score(max 1.0)':'pident_ecpred'}) 553 | res_ecpred['ec_ecpred']= res_ecpred.ec_ecpred.apply(lambda x : '-' if x=='non Enzyme' else x) 554 | return res_ecpred -------------------------------------------------------------------------------- /tools/uniprottool.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Zhenkun Shi 3 | Date: 2023-07-17 00:30:14 4 | LastEditors: Zhenkun Shi 5 | LastEditTime: 2023-07-17 00:49:11 6 | FilePath: /ECRECer/tools/uniprottool.py 7 | Description: 8 | 9 | Copyright (c) 2023 by tibd, All Rights Reserved. 10 | ''' 11 | 12 | 13 | from Bio import SeqIO 14 | import gzip 15 | import re 16 | from tqdm import tqdm 17 | import time 18 | import sys,os 19 | sys.path.insert(0, os.path.dirname(os.path.realpath('__file__'))) 20 | sys.path.insert(1,'../') 21 | import config as cfg 22 | import pandas as pd 23 | import requests 24 | from requests.adapters import HTTPAdapter, Retry 25 | 26 | #region 从gizp读取数据 27 | def read_file_from_gzip(file_in_path, file_out_path, extract_type, save_file_type='tsv'): 28 | """从原始Zip file 中解析数据 29 | 30 | Args: 31 | file_in_path ([string]): [输入文件路径] 32 | file_out_path ([string]): [输出文件路径]] 33 | extract_type ([string]): [抽取数据的类型:with_ec, without_ec, full]] 34 | """ 35 | if save_file_type == 'feather': 36 | outpath = file_out_path 37 | file_out_path = cfg.TEMP_DIR +'temprecords.tsv' 38 | 39 | table_head = [ 'id', 40 | 'name', 41 | 'isenzyme', 42 | 'isMultiFunctional', 43 | 'functionCounts', 44 | 'ec_number', 45 | 'ec_specific_level', 46 | 'date_integraged', 47 | 'date_sequence_update', 48 | 'date_annotation_update', 49 | 'seq', 50 | 'seqlength' 51 | ] 52 | counter = 0 53 | saver = 0 54 | file_write_obj = open(file_out_path,'w') 55 | with gzip.open(file_in_path, "rt") as handle: 56 | file_write_obj.writelines('\t'.join(table_head)) 57 | file_write_obj.writelines('\n') 58 | for record in tqdm( SeqIO.parse(handle, 'swiss'), position=1, leave=True): 59 | res = process_record(record, extract_type= extract_type) 60 | counter+=1 61 | if counter %10==0: 62 | file_write_obj.flush() 63 | # if saver%100000==0: 64 | # print(saver) 65 | if len(res) >0: 66 | saver +=1 67 | file_write_obj.writelines('\t'.join(map(str,res))) 68 | file_write_obj.writelines('\n') 69 | else: 70 | continue 71 | file_write_obj.close() 72 | 73 | if save_file_type == 'feather': 74 | indata = pd.read_csv(cfg.TEMP_DIR +'temprecords.tsv', sep='\t') 75 | indata.to_feather(outpath) 76 | 77 | 78 | #endregion 79 | 80 | #region 提取单条含有EC号的数据 81 | def process_record(record, extract_type='with_ec'): 82 | """ 83 | 提取单条含有EC号的数据 84 | Args: 85 | record ([type]): uniprot 中的记录节点 86 | extract_type (string, optional): 提取的类型,可选值:with_ec, without_ec, full。默认为有EC号(with_ec). 
87 | Returns: 88 | [type]: [description] 89 | """ 90 | 91 | description = record.description 92 | isEnzyme = 'EC=' in description #有EC号的被认为是酶,否则认为是酶 93 | isMultiFunctional = False 94 | functionCounts = 0 95 | ec_specific_level =0 96 | 97 | 98 | if isEnzyme: 99 | ec = str(re.findall(r"EC=[0-9,.\-;]*",description) 100 | ).replace('EC=','').replace('\'','').replace(']','').replace('[','').replace(';','').strip() 101 | 102 | #统计酶的功能数 103 | isMultiFunctional = ',' in ec 104 | functionCounts = ec.count(',') + 1 105 | 106 | 107 | # - 单功能酶 108 | if not isMultiFunctional: 109 | levelCount = ec.count('-') 110 | ec_specific_level = 4-levelCount 111 | 112 | else: # -多功能酶 113 | ecarray = ec.split(',') 114 | for subec in ecarray: 115 | current_ec_level = 4- subec.count('-') 116 | if ec_specific_level < current_ec_level: 117 | ec_specific_level = current_ec_level 118 | else: 119 | ec = '-' 120 | 121 | id = record.id.strip() 122 | name = record.name.strip() 123 | seq = record.seq.strip() 124 | date_integrated = record.annotations.get('date').strip() 125 | date_sequence_update = record.annotations.get('date_last_sequence_update').strip() 126 | date_annotation_update = record.annotations.get('date_last_annotation_update').strip() 127 | seqlength = len(seq) 128 | res = [id, name, isEnzyme, isMultiFunctional, functionCounts, ec,ec_specific_level, date_integrated, date_sequence_update, date_annotation_update, seq, seqlength] 129 | 130 | if extract_type == 'full': 131 | return res 132 | 133 | if extract_type == 'with_ec': 134 | if isEnzyme: 135 | return res 136 | else: 137 | return [] 138 | 139 | if extract_type == 'without_ec': 140 | if isEnzyme: 141 | return [] 142 | else: 143 | return res 144 | #endregion 145 | 146 | def run_exact_task(infile, outfile): 147 | start = time.process_time() 148 | extract_type ='full' 149 | read_file_from_gzip(file_in_path=infile, file_out_path=outfile, extract_type=extract_type) 150 | end = time.process_time() 151 | print('finished use time %6.3f s' % (end - start)) 152 | 153 | def get_next_link(headers): 154 | re_next_link = re.compile(r'<(.+)>; rel="next"') 155 | if "Link" in headers: 156 | match = re_next_link.match(headers["Link"]) 157 | if match: 158 | return match.group(1) 159 | 160 | def get_batch(batch_url, session): 161 | while batch_url: 162 | response = session.get(batch_url) 163 | response.raise_for_status() 164 | total = response.headers["x-total-results"] 165 | yield response, total 166 | batch_url = get_next_link(response.headers) 167 | 168 | def get_batch_data_from_uniprot_rest_api(url): 169 | # url = 'https://rest.uniprot.org/uniprotkb/search?fields=accession%2Ccc_interaction&format=tsv&query=Insulin%20AND%20%28reviewed%3Atrue%29&size=500' 170 | res = [] 171 | session = requests.Session() 172 | retries = Retry(total=5, backoff_factor=0.25, status_forcelist=[500, 502, 503, 504]) 173 | session.mount("https://", HTTPAdapter(max_retries=retries)) 174 | 175 | for batch, total in tqdm(get_batch(url, session)): 176 | batch_records = batch.text.splitlines()[1:] 177 | res = res + batch_records 178 | 179 | res = [item.split('\t') for item in res] 180 | return res 181 | 182 | 183 | if __name__ =="__main__": 184 | print('success') 185 | # start = time.process_time() 186 | # in_filepath_sprot = cfg.FILE_LATEST_SPROT 187 | # out_filepath_sprot = cfg.FILE_LATEST_SPROT_FEATHER 188 | 189 | # in_filepath_trembl = cfg.FILE_LATEST_TREMBL 190 | # out_filepath_trembl = cfg.FILE_LATEST_TREMBL_FEATHER 191 | 192 | # extract_type ='full' 193 | 194 | 195 | # # 
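    # Usage sketch for the REST helpers above (illustrative; field names are
    # assumptions). UniProt cursor pagination returns the next page in the
    # response's 'Link: <url>; rel="next"' header, which get_next_link() parses:
    #
    #   url = ('https://rest.uniprot.org/uniprotkb/search'
    #          '?fields=accession,ec&format=tsv&query=insulin&size=500')
    #   rows = get_batch_data_from_uniprot_rest_api(url)  # list of per-record field lists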
read_file_from_gzip(file_in_path=in_filepath_sprot, file_out_path=out_filepath_sprot, extract_type=extract_type, save_file_type='feather') 196 | 197 | # read_file_from_gzip(file_in_path=in_filepath_trembl, file_out_path=out_filepath_trembl, extract_type=extract_type, save_file_type='feather') 198 | # end = time.process_time() 199 | # print('finished use time %6.3f s' % (end - start)) 200 | 201 | -------------------------------------------------------------------------------- /update_production.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5c2801b5-1a62-4abb-9030-0a6ef709cf24", 6 | "metadata": {}, 7 | "source": [ 8 | "# Update Production Data and Model\n", 9 | "\n", 10 | "> author: Shizhenkun \n", 11 | "> email: zhenkun.shi@tib.cas.cn \n", 12 | "> date: 2021-12-24 \n", 13 | "\n", 14 | "This file contains update codes for the production server. The update should be scheduled every eight weeks." 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "48326994-9d5b-4905-96ed-36fa3ed72dd9", 20 | "metadata": {}, 21 | "source": [ 22 | "## 1. Import packages" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "id": "96a76da5-08be-41be-b05a-cd68b6d6a789", 29 | "metadata": { 30 | "tags": [] 31 | }, 32 | "outputs": [ 33 | { 34 | "ename": "ModuleNotFoundError", 35 | "evalue": "No module named 'joblib'", 36 | "output_type": "error", 37 | "traceback": [ 38 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 39 | "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", 40 | "Cell \u001b[0;32mIn[1], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mcfg\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfunctools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m reduce\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjoblib\u001b[39;00m\n\u001b[1;32m 9\u001b[0m sys\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./tools/\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mfunclib\u001b[39;00m\n", 41 | "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'joblib'" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "import numpy as np\n", 47 | "import pandas as pd\n", 48 | "import sys\n", 49 | "\n", 50 | "import config as cfg\n", 51 | "from functools import reduce\n", 52 | "import joblib\n", 53 | "\n", 54 | "sys.path.append(\"./tools/\")\n", 55 | "import funclib\n", 56 | "import exact_ec_from_uniprot as exactec\n", 57 | "import minitools as mtool\n", 58 | "import benchmark_common as bcommon\n", 59 | "import embedding_esm as esmebd\n", 60 | "\n", 61 | "from pandarallel import pandarallel \n", 62 | "pandarallel.initialize() \n", 63 | "import benchmark_train as btrain\n", 64 | "\n", 65 | "%load_ext autoreload\n", 66 | "%autoreload 2" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "b4fe1bf3-da2e-4796-abad-0029f6e81319", 72 | "metadata": {}, 73 | "source": [ 74 | "## 2. 
Define Functions" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 2, 80 | "id": "4ffaeb29-3068-4230-a5a0-13d85cca0f01", 81 | "metadata": { 82 | "tags": [] 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "# install axel for download dataset\n", 87 | "def install_axel():\n", 88 | " isExists = !which axel\n", 89 | " if 'axel' in str(isExists[0]):\n", 90 | " return True\n", 91 | " else:\n", 92 | " !sudo apt install axel -y\n", 93 | "\n", 94 | "# add missing '-' for ec number\n", 95 | "def refill_ec(ec): \n", 96 | " if ec == '-':\n", 97 | " return ec\n", 98 | " levelArray = ec.split('.')\n", 99 | " if levelArray[3]=='':\n", 100 | " levelArray[3] ='-'\n", 101 | " ec = '.'.join(levelArray)\n", 102 | " return ec\n", 103 | "\n", 104 | "def specific_ecs(ecstr):\n", 105 | " if '-' not in ecstr or len(ecstr)<4:\n", 106 | " return ecstr\n", 107 | " ecs = ecstr.split(',')\n", 108 | " if len(ecs)==1:\n", 109 | " return ecstr\n", 110 | " \n", 111 | " reslist=[]\n", 112 | " \n", 113 | " for ec in ecs:\n", 114 | " recs = ecs.copy()\n", 115 | " recs.remove(ec)\n", 116 | " ecarray = np.array([x.split('.') for x in recs])\n", 117 | " \n", 118 | " if '-' not in ec:\n", 119 | " reslist +=[ec]\n", 120 | " continue\n", 121 | " linearray= ec.split('.')\n", 122 | " if linearray[1] == '-':\n", 123 | " #l1 in l1s and l2 not empty\n", 124 | " if (linearray[0] in ecarray[:,0]) and (len(set(ecarray[:,0]) - set({'-'}))>0):\n", 125 | " continue\n", 126 | " if linearray[2] == '-':\n", 127 | " # l1, l2 in l1s l2s, l3 not empty\n", 128 | " if (linearray[0] in ecarray[:,0]) and (linearray[1] in ecarray[:,1]) and (len(set(ecarray[:,2]) - set({'-'}))>0):\n", 129 | " continue\n", 130 | " if linearray[3] == '-':\n", 131 | " # l1, l2, l3 in l1s l2s l3s, l4 not empty\n", 132 | " if (linearray[0] in ecarray[:,0]) and (linearray[1] in ecarray[:,1]) and (linearray[2] in ecarray[:,2]) and (len(set(ecarray[:,3]) - set({'-'}))>0):\n", 133 | " continue\n", 134 | " \n", 135 | " reslist +=[ec]\n", 136 | " return ','.join(reslist)\n", 137 | "\n", 138 | "#format ec\n", 139 | "def format_ec(ecstr):\n", 140 | " ecArray= ecstr.split(',')\n", 141 | " ecArray=[x.strip() for x in ecArray] #strip blank\n", 142 | " ecArray=[refill_ec(x) for x in ecArray] #format ec to full\n", 143 | " ecArray = list(set(ecArray)) # remove duplicates\n", 144 | " \n", 145 | " return ','.join(ecArray)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "b5e4912f-59f0-4555-9172-edcb548056e1", 151 | "metadata": {}, 152 | "source": [ 153 | "## 3. Download latest data from unisprot" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 8, 159 | "id": "a53aba9e-9232-44a2-a554-8f670e536fb0", 160 | "metadata": { 161 | "tags": [] 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "# download location ./tmp\n", 166 | "\n", 167 | "# ! mv $cfg.DATADIR'uniprot_sprot_latest.dat.gz' $cfg.TEMPDIR$currenttime'_uniprot_sprot_latest.dat.gz'\n", 168 | "# install_axel()\n", 169 | "! axel -n 10 https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.dat.gz -o ./data/uniprot_sprot_latest.dat.gz -q -c" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "id": "5d7df6ac-7247-4045-bca2-58277bb61d24", 175 | "metadata": {}, 176 | "source": [ 177 | "## 4. 
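# Worked examples for the EC helpers defined above (illustrative):
#   refill_ec('1.1.1.')             -> '1.1.1.-'  (missing 4th level refilled)
#   format_ec('1.1.1.1, 1.1.1.1')   -> '1.1.1.1'  (trimmed and deduplicated)
#   specific_ecs('1.1.1.-,1.1.1.1') -> '1.1.1.1'  (a partial EC is dropped when a
#                                                  more specific sibling exists)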
Preprocessing" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 9, 183 | "id": "afee4738-dc8e-4637-8d7f-b288035ab3ff", 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stderr", 188 | "output_type": "stream", 189 | "text": [ 190 | "566996it [04:36, 2050.78it/s]\n" 191 | ] 192 | }, 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "finished use time 275.733 s\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "exactec.run_exact_task(infile=cfg.DATADIR+'uniprot_sprot_latest.dat.gz', outfile=cfg.DATADIR+'sprot_latest.tsv')\n", 203 | "\n", 204 | "#加载数据并转换时间格式\n", 205 | "sprot_latest = pd.read_csv(cfg.DATADIR+'sprot_latest.tsv', sep='\\t',header=0) #读入文件\n", 206 | "sprot_latest = mtool.convert_DF_dateTime(inputdf = sprot_latest)\n", 207 | "\n", 208 | "sprot_latest.drop_duplicates(subset=['seq'], keep='first', inplace=True)\n", 209 | "sprot_latest.reset_index(drop=True, inplace=True)\n", 210 | "\n", 211 | "#sprot_latest format EC\n", 212 | "sprot_latest['ec_number'] = sprot_latest.ec_number.parallel_apply(lambda x: format_ec(x))\n", 213 | "sprot_latest['ec_number'] = sprot_latest.ec_number.parallel_apply(lambda x: specific_ecs(x))\n", 214 | "sprot_latest['functionCounts'] = sprot_latest.ec_number.parallel_apply(lambda x: 0 if x=='-' else len(x.split(',')))\n", 215 | "\n", 216 | "# Trim Strging\n", 217 | "with pd.option_context('mode.chained_assignment', None):\n", 218 | " sprot_latest.ec_number = sprot_latest.ec_number.parallel_apply(lambda x : str(x).strip()) #ec trim\n", 219 | " sprot_latest.seq = sprot_latest.seq.parallel_apply(lambda x : str(x).strip()) #seq trim\n", 220 | "\n", 221 | "sprot_latest.to_feather(cfg.DATADIR + 'latest_sprot.feather')\n" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "id": "4ef3c2fa-5187-404d-9df3-bbc657287910", 227 | "metadata": {}, 228 | "source": [ 229 | "## 5. Caculation Features" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 13, 235 | "id": "d4d7fc8a-65ad-4248-a2d2-649e0b813745", 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | "train size: 478954\n" 243 | ] 244 | } 245 | ], 246 | "source": [ 247 | "train= pd.read_feather(cfg.DATADIR + 'latest_sprot.feather')\n", 248 | "print('train size: {0}'.format(len(train)))" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "id": "e5fe8cb0-9e00-4f0c-9045-bef4d106204e", 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "name": "stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_rep0.feather': No such file or directory\n", 262 | "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_rep32.feather': No such file or directory\n", 263 | "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_rep33.feather': No such file or directory\n", 264 | "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_unirep.feather': No such file or directory\n", 265 | "Transferred model to GPU\n" 266 | ] 267 | }, 268 | { 269 | "name": "stderr", 270 | "output_type": "stream", 271 | "text": [ 272 | " 10%|█████████ | 47813/478954 [1:00:08<12:13:55, 9.79it/s]" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "! mv $cfg.DATADIR'sprot_latest_rep0.feather' $cfg.DATADIR'featureBank/sprot_latest_rep0.feather'\n", 278 | "! 
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5fe8cb0-9e00-4f0c-9045-bef4d106204e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_rep0.feather': No such file or directory\n",
      "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_rep32.feather': No such file or directory\n",
      "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_rep33.feather': No such file or directory\n",
      "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_unirep.feather': No such file or directory\n",
      "Transferred model to GPU\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█████████ | 47813/478954 [1:00:08<12:13:55, 9.79it/s]"
     ]
    }
   ],
   "source": [
    "# archive the feature files from the previous run into featureBank;\n",
    "# the 'cannot stat' messages above just mean there was nothing to archive\n",
    "! mv $cfg.DATADIR'sprot_latest_rep0.feather' $cfg.DATADIR'featureBank/sprot_latest_rep0.feather'\n",
    "! mv $cfg.DATADIR'sprot_latest_rep32.feather' $cfg.DATADIR'featureBank/sprot_latest_rep32.feather'\n",
    "! mv $cfg.DATADIR'sprot_latest_rep33.feather' $cfg.DATADIR'featureBank/sprot_latest_rep33.feather'\n",
    "! mv $cfg.DATADIR'sprot_latest_unirep.feather' $cfg.DATADIR'featureBank/sprot_latest_unirep.feather'\n",
    "\n",
    "# compute ESM-1b representations for every sequence and persist them\n",
    "# !pip install fair-esm\n",
    "tr_rep0, tr_rep32, tr_rep33 = esmebd.get_rep_multi_sequence(sequences=train, model='esm1b_t33_650M_UR50S', seqthres=1022)\n",
    "tr_rep0.to_feather(cfg.DATADIR + 'sprot_latest_rep0.feather')\n",
    "tr_rep32.to_feather(cfg.DATADIR + 'sprot_latest_rep32.feather')\n",
    "tr_rep33.to_feather(cfg.DATADIR + 'sprot_latest_rep33.feather')\n"
   ]
  },
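  {
   "cell_type": "markdown",
   "id": "feature-check-md",
   "metadata": {},
   "source": [
    "A lightweight check (an illustrative addition; the paths are the ones written by the cell above) that all three embedding tables exist and cover the full training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feature-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# each table should exist and have one row per training sequence\n",
    "for name in ['sprot_latest_rep0.feather', 'sprot_latest_rep32.feather', 'sprot_latest_rep33.feather']:\n",
    "    path = cfg.DATADIR + name\n",
    "    print(name, os.path.exists(path), len(pd.read_feather(path)) if os.path.exists(path) else 0)\n",
    "print('expected rows:', len(train))"
   ]
  },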
\n", 300 | "\n", 313 | "\n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 
585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | "
idf1f2f3f4f5f6f7f8f9...f1271f1272f1273f1274f1275f1276f1277f1278f1279f1280
0P609950.1709890.0256570.0257440.020174-0.076020-0.1910660.0851430.504738-0.036427...0.210663-0.1027720.087211-0.0487000.1936670.1696940.027922-0.0605040.285435-0.069696
1P02396-0.1111110.0029450.0180400.028491-0.006897-0.017328-0.033452-0.107921-0.010232...0.037690-0.116853-0.0234830.0642090.0174970.0969470.0047600.0005240.0560060.069281
2P02362-0.188286-0.0070290.0066700.029186-0.0027060.0453820.0457750.001362-0.016976...0.034140-0.077455-0.0415540.071670-0.0016900.0971980.0198490.001559-0.0087010.083810
3P02565-0.026709-0.0005620.0137120.003235-0.0388730.023068-0.035896-0.070689-0.016175...0.015740-0.104805-0.0047120.0366990.0342010.067125-0.0004670.003438-0.0483340.014916
4P02827-0.1172810.0041210.0212030.024243-0.0455090.132612-0.0074180.035875-0.022678...-0.003237-0.081184-0.0218810.0246420.0382220.0419060.0099690.000780-0.1094300.042850
..................................................................
478949C0P9J6-0.0924070.0103080.0228300.019094-0.0569880.064407-0.0320280.017303-0.020415...-0.028978-0.085610-0.0159380.0488610.0186880.0827180.016897-0.003126-0.0933330.065130
478950C0PTH8-0.1309440.0062370.0215030.022638-0.0443760.1284830.0233100.020793-0.019937...-0.005362-0.079125-0.0172780.0197710.0329460.0544440.021701-0.001576-0.1170680.037637
478951D5KXD2-0.0984200.0018060.0236560.018481-0.0555790.1172490.0300400.014601-0.021143...0.004568-0.075909-0.0117510.0363920.0196320.0739730.020997-0.001993-0.1012030.055563
478952A0A2Z4HPZ0-0.1237960.0022070.0216130.022929-0.0553020.1127820.0043560.053342-0.014881...0.000617-0.074901-0.0152750.0333630.0188510.0584290.019783-0.003662-0.1266440.052289
478953W7N2P0-0.2070700.0007970.0168630.0267360.013239-0.0619950.0608440.069271-0.026021...0.012974-0.110873-0.0287240.0593150.0347170.0524760.0321700.0077320.0171070.083767
\n", 607 | "

478954 rows × 1281 columns

\n", 608 | "
" 609 | ], 610 | "text/plain": [ 611 | " id f1 f2 f3 f4 f5 \\\n", 612 | "0 P60995 0.170989 0.025657 0.025744 0.020174 -0.076020 \n", 613 | "1 P02396 -0.111111 0.002945 0.018040 0.028491 -0.006897 \n", 614 | "2 P02362 -0.188286 -0.007029 0.006670 0.029186 -0.002706 \n", 615 | "3 P02565 -0.026709 -0.000562 0.013712 0.003235 -0.038873 \n", 616 | "4 P02827 -0.117281 0.004121 0.021203 0.024243 -0.045509 \n", 617 | "... ... ... ... ... ... ... \n", 618 | "478949 C0P9J6 -0.092407 0.010308 0.022830 0.019094 -0.056988 \n", 619 | "478950 C0PTH8 -0.130944 0.006237 0.021503 0.022638 -0.044376 \n", 620 | "478951 D5KXD2 -0.098420 0.001806 0.023656 0.018481 -0.055579 \n", 621 | "478952 A0A2Z4HPZ0 -0.123796 0.002207 0.021613 0.022929 -0.055302 \n", 622 | "478953 W7N2P0 -0.207070 0.000797 0.016863 0.026736 0.013239 \n", 623 | "\n", 624 | " f6 f7 f8 f9 ... f1271 f1272 \\\n", 625 | "0 -0.191066 0.085143 0.504738 -0.036427 ... 0.210663 -0.102772 \n", 626 | "1 -0.017328 -0.033452 -0.107921 -0.010232 ... 0.037690 -0.116853 \n", 627 | "2 0.045382 0.045775 0.001362 -0.016976 ... 0.034140 -0.077455 \n", 628 | "3 0.023068 -0.035896 -0.070689 -0.016175 ... 0.015740 -0.104805 \n", 629 | "4 0.132612 -0.007418 0.035875 -0.022678 ... -0.003237 -0.081184 \n", 630 | "... ... ... ... ... ... ... ... \n", 631 | "478949 0.064407 -0.032028 0.017303 -0.020415 ... -0.028978 -0.085610 \n", 632 | "478950 0.128483 0.023310 0.020793 -0.019937 ... -0.005362 -0.079125 \n", 633 | "478951 0.117249 0.030040 0.014601 -0.021143 ... 0.004568 -0.075909 \n", 634 | "478952 0.112782 0.004356 0.053342 -0.014881 ... 0.000617 -0.074901 \n", 635 | "478953 -0.061995 0.060844 0.069271 -0.026021 ... 0.012974 -0.110873 \n", 636 | "\n", 637 | " f1273 f1274 f1275 f1276 f1277 f1278 f1279 \\\n", 638 | "0 0.087211 -0.048700 0.193667 0.169694 0.027922 -0.060504 0.285435 \n", 639 | "1 -0.023483 0.064209 0.017497 0.096947 0.004760 0.000524 0.056006 \n", 640 | "2 -0.041554 0.071670 -0.001690 0.097198 0.019849 0.001559 -0.008701 \n", 641 | "3 -0.004712 0.036699 0.034201 0.067125 -0.000467 0.003438 -0.048334 \n", 642 | "4 -0.021881 0.024642 0.038222 0.041906 0.009969 0.000780 -0.109430 \n", 643 | "... ... ... ... ... ... ... ... \n", 644 | "478949 -0.015938 0.048861 0.018688 0.082718 0.016897 -0.003126 -0.093333 \n", 645 | "478950 -0.017278 0.019771 0.032946 0.054444 0.021701 -0.001576 -0.117068 \n", 646 | "478951 -0.011751 0.036392 0.019632 0.073973 0.020997 -0.001993 -0.101203 \n", 647 | "478952 -0.015275 0.033363 0.018851 0.058429 0.019783 -0.003662 -0.126644 \n", 648 | "478953 -0.028724 0.059315 0.034717 0.052476 0.032170 0.007732 0.017107 \n", 649 | "\n", 650 | " f1280 \n", 651 | "0 -0.069696 \n", 652 | "1 0.069281 \n", 653 | "2 0.083810 \n", 654 | "3 0.014916 \n", 655 | "4 0.042850 \n", 656 | "... ... 
\n", 657 | "478949 0.065130 \n", 658 | "478950 0.037637 \n", 659 | "478951 0.055563 \n", 660 | "478952 0.052289 \n", 661 | "478953 0.083767 \n", 662 | "\n", 663 | "[478954 rows x 1281 columns]" 664 | ] 665 | }, 666 | "execution_count": 16, 667 | "metadata": {}, 668 | "output_type": "execute_result" 669 | } 670 | ], 671 | "source": [ 672 | "tr_rep0" 673 | ] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "execution_count": 18, 678 | "id": "c9b2137c-ae11-42f7-a8da-fa699cca7520", 679 | "metadata": {}, 680 | "outputs": [], 681 | "source": [ 682 | "train_esm_latest = pd.read_feather(cfg.DATADIR + 'sprot_latest_rep32.feather')\n", 683 | "train_esm_latest = train.merge(train_esm_latest, on='id', how='left')" 684 | ] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "id": "5da8140c-e00f-413b-b7c8-94607a44ea59", 689 | "metadata": {}, 690 | "source": [ 691 | "## 6. Split X Y" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": 19, 697 | "id": "a5ee3dbe-519c-471e-aa60-feecf84db6d0", 698 | "metadata": {}, 699 | "outputs": [ 700 | { 701 | "name": "stderr", 702 | "output_type": "stream", 703 | "text": [ 704 | "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 231834/231834 [05:02<00:00, 766.36it/s]\n" 705 | ] 706 | }, 707 | { 708 | "name": "stdout", 709 | "output_type": "stream", 710 | "text": [ 711 | "loading ec to label dict\n" 712 | ] 713 | } 714 | ], 715 | "source": [ 716 | "# task 1\n", 717 | "X_train_task1 = np.array(train_esm_latest.iloc[:,12:])\n", 718 | "Y_train_task1 = np.array(train_esm_latest.isenzyme.astype('int')).flatten()\n", 719 | "train_enzyme = train_esm_latest[train_esm_latest.isenzyme].reset_index(drop=True)\n", 720 | "\n", 721 | "# task 2\n", 722 | "X_train_task2_s = np.array(train_enzyme.iloc[:,12:])\n", 723 | "Y_train_task2_s = train_enzyme.functionCounts.apply(lambda x : 0 if x==1 else 1).astype('int').values\n", 724 | "\n", 725 | "train_task2M=train_enzyme[train_enzyme.functionCounts>=2].reset_index(drop=True)\n", 726 | "X_train_task2_m = np.array(train_task2M.iloc[:,12:])\n", 727 | "Y_train_task2_m = np.array(train_task2M.functionCounts.astype('int')-2).flatten()\n", 728 | "\n", 729 | "#task 3\n", 730 | "train_set_task3= funclib.split_ecdf_to_single_lines(train_enzyme.iloc[:,np.r_[0,10,5]])\n", 731 | "train_set_task3=train_set_task3.merge(train_esm_latest.iloc[:,np.r_[0,12:1292]], on='id', how='left')\n", 732 | "\n", 733 | "#4. Loading EC Numbers\n", 734 | "print('loading ec to label dict')\n", 735 | "dict_ec_label = btrain.make_ec_label(train_label=train_set_task3['ec_number'], test_label=train_set_task3['ec_number'], file_save= cfg.FILE_EC_LABEL_DICT, force_model_update=cfg.UPDATE_MODEL)\n", 736 | "\n", 737 | "train_set_task3['ec_label']=train_set_task3.ec_number.parallel_apply(lambda x: dict_ec_label.get(x)) \n", 738 | "X_train_task3 = np.array(train_set_task3.iloc[:,3:])\n", 739 | "Y_train_task3 = np.array(train_set_task3.ec_label.astype('int')).flatten()" 740 | ] 741 | }, 742 | { 743 | "cell_type": "markdown", 744 | "id": "2f5e89d3-b5f5-4897-bb23-afaa0db3d388", 745 | "metadata": {}, 746 | "source": [ 747 | "## 7. 
Train Model" 748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "execution_count": null, 753 | "id": "df91a62a-0eee-44b1-9b0a-65b5e60d8bc0", 754 | "metadata": {}, 755 | "outputs": [], 756 | "source": [] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": null, 761 | "id": "a52c6fe4-94de-4b69-be93-427535de880e", 762 | "metadata": {}, 763 | "outputs": [], 764 | "source": [] 765 | } 766 | ], 767 | "metadata": { 768 | "kernelspec": { 769 | "display_name": "ECRECer", 770 | "language": "python", 771 | "name": "ecrecer" 772 | }, 773 | "language_info": { 774 | "codemirror_mode": { 775 | "name": "ipython", 776 | "version": 3 777 | }, 778 | "file_extension": ".py", 779 | "mimetype": "text/x-python", 780 | "name": "python", 781 | "nbconvert_exporter": "python", 782 | "pygments_lexer": "ipython3", 783 | "version": "3.7.10" 784 | }, 785 | "widgets": { 786 | "application/vnd.jupyter.widget-state+json": { 787 | "state": {}, 788 | "version_major": 2, 789 | "version_minor": 0 790 | } 791 | } 792 | }, 793 | "nbformat": 4, 794 | "nbformat_minor": 5 795 | } 796 | --------------------------------------------------------------------------------