├── .github
│   └── ISSUE_TEMPLATE
│       ├── 1-bug_report.yaml
│       └── config.yml
├── .gitignore
├── LICENSE
├── README.md
├── benchmark_common.py
├── benchmark_evaluation.py
├── benchmark_test.py
├── benchmark_train.py
├── config.py
├── data
│   ├── README.md
│   └── dict
│       ├── README.md
│       ├── dict_label_task1.h5
│       ├── dict_label_task2.h5
│       └── dict_label_task3.h5
├── env.yaml
├── model
│   └── README.md
├── prepare_production_data.ipynb
├── production.py
├── tasks
│   ├── README.md
│   ├── prepare_task_dataset.ipynb
│   ├── task1.ipynb
│   ├── task2.ipynb
│   └── task3.ipynb
├── temp.ipynb
├── tmp
│   └── readme.md
├── tools
│   ├── Attention.py
│   ├── __init__.py
│   ├── baselines.ipynb
│   ├── embdding_onehot.py
│   ├── embedding_esm.py
│   ├── exact_ec_from_uniprot.py
│   ├── filetool.py
│   ├── funclib.py
│   └── uniprottool.py
└── update_production.ipynb
/.github/ISSUE_TEMPLATE/1-bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: Report Problem / 故障报告
2 | description: Report a problem with ECRecer / 报告使用网站时出现的问题
3 | labels: ["problem report"]
4 | body:
5 | - type: markdown
6 | id: preamble
7 | attributes:
8 | value: |
9 | 如果在使用ECRecer网站时遇到问题,请填写以下表单。
10 | If you encountered a problem using ECRecer, please fill in the form below.
11 | - type: checkboxes
12 | id: prerequisites
13 | attributes:
14 | label: 先决条件 / Prerequisites
15 | options:
16 | - label: |
17 | 该问题未被其他 [issues](https://github.com/kingstdio/ECRECer/issues) 提出。
18 | No previous issues have reported this problem.
19 | required: true
20 | - type: textarea
21 | id: what_happened
22 | attributes:
23 | label: 发生了什么 / What happened
24 | validations:
25 | required: true
26 | - type: textarea
27 | id: expected_behavior
28 | attributes:
29 | label: 期望的现象 / What you expected
30 | validations:
31 | required: true
32 | - type: textarea
33 | id: how_to_reproduce
34 | attributes:
35 | label: 如何重现此问题 / How to reproduce
36 | validations:
37 | required: true
38 | - type: input
39 | id: os_version
40 | attributes:
41 | label: 操作系统版本 / OS version
42 | validations:
43 | required: true
44 | - type: input
45 | id: browser
46 | attributes:
47 | label: 浏览器或客户端版本 / Browser/client version
48 | validations:
49 | required: true
50 | - type: textarea
51 | id: other_env
52 | attributes:
53 | label: 其他环境 / Other environments
54 | - type: textarea
55 | id: others
56 | attributes:
57 | label: 其他 / Anything else
58 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | #exclude tempfile
133 | tmp/
134 | temp/
135 | *.feather
136 | *.fasta
137 | *.dmnd
138 | *.txt
139 | *.feature
140 | *.tsv
141 | *.npy
142 | *.csv
143 | *.h5
144 |
145 | baselines/
146 | results/
147 |
148 | ._.DS_Store
149 | .DS_Store
150 | *.gz
151 |
152 | *.model
153 | *.st
154 |
155 | model/single_multi.model
156 |
157 | *.faa
158 | model/multi_many.model
161 |
162 | data/uniprot/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 kingstdio
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
11 |
12 | # DMLF: Enzyme Commission Number Predicting and Benchmarking with Multi-agent Dual-core Learning
13 |
14 | This repo contains the source code for ECRECer, an EC number prediction tool that implements our paper: 「Enzyme Commission Number Prediction and Benchmarking with Hierarchical Dual-core Multitask Learning Framework」
15 |
16 | Detailed information about the framework can be found in our papers:
17 |
18 | ```bash
19 | 1. Zhenkun Shi, Qianqian Yuan, Ruoyu Wang, Haoran Li, Xiaoping Liao*, Hongwu Ma* (2022). ECRECer: Enzyme Commission Number Recommendation and Benchmarking based on Multiagent Dual-core Learning. arXiv preprint arXiv:2202.03632.
20 |
21 | 2. Zhenkun Shi, Rui Deng, Qianqian Yuan, Zhitao Mao, Ruoyu Wang, Haoran Li, Xiaoping Liao*, Hongwu Ma* (2023). Enzyme Commission Number Prediction and Benchmarking with Hierarchical Dual-core Multitask Learning Framework. Research.
22 | ```
23 |
24 |
25 | ## Usage
26 | ### To simply predict EC numbers with our tools, please visit the ECRECer website at https://ecrecer.biodesign.ac.cn
27 |
28 |
29 | ### For users who want to run ECRECer locally, please follow the steps below:
30 | We provide a Docker image and a Singularity image for running ECRECer locally.
31 |
32 | > Docker image:
33 |
34 | ```bash
35 | # 1. pull ecrecer docker image
36 | docker pull kingstdio/ecrecer
37 |
38 | # 2. run ecrecer docker image
39 | # gpu version:
40 | sudo docker run -it -d --gpus all --name ecrecer -v ~/:/home/ kingstdio/ecrecer #~/ is your fasta file folder
41 | # cpu version:
42 | sudo docker run -it -d --name ecrecer -v ~/:/home/ kingstdio/ecrecer #~/ is your fasta file folder
43 |
44 | # 3. run ECRECer prediction
45 |
46 | sudo docker exec ecrecer python /ecrecer/production.py -i /home/input_fasta_file.fasta -o /home/output_tsv_file.tsv -mode h -topk 10
47 |
48 | #-topk: top k predicted EC numbers
49 | #-mode p: prediction mode, predict EC numbers only
50 | #-mode r: recommendation mode, recommend EC numbers with predicted probabilities, the higher the better
51 | #-mode h: hybrid mode, combining prediction, recommendation, and sequence alignment methods
52 |
53 | ```
54 |
55 | > Singularity image:
56 |
57 | ```bash
58 | # 1. pull ecrecer singularity image
59 |
60 | # Image ~= 11GB, may take a while to download
61 | wget -c https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/sifimages/ecrecer.sif
62 |
63 | # 2. run ecrecer singularity image
64 | # gpu version:
65 | singularity run --nv ecrecer.sif python /ecrecer/production.py -i input_fasta_file.fasta -o output_tsv_file.tsv -mode h -topk 10
66 | # cpu version:
67 | singularity run ecrecer.sif python /ecrecer/production.py -i input_fasta_file.fasta -o output_tsv_file.tsv -mode h -topk 10
68 |
69 | #-topk: top k predicted EC numbers
70 | #-mode p: prediction mode, predict EC numbers only
71 | #-mode r: recommendation mode, recommend EC numbers with predicted probabilities, the higher the better
72 | #-mode h: hybrid mode, combining prediction, recommendation, and sequence alignment methods
73 |
74 | ```
75 |
76 |
77 |
78 |
79 | # To re-implement our experiments or use ECRECer offline, please read the details below:
80 |
81 | # Prerequisites
82 |
83 | + Python >= 3.6
84 | + scikit-learn
85 | + XGBoost
86 | + conda
87 | + JupyterLab
88 | + ...
89 |
90 | > Create a conda env using [env.yaml](./env.yaml)
91 |
92 | ```bash
93 | git clone git@github.com:kingstdio/ECRECer.git
94 | conda env create -f env.yaml
95 | ```
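96 | Then activate the environment; the env name `ECRECer` comes from [env.yaml](./env.yaml):
97 |
98 | ```bash
99 | conda activate ECRECer
100 | ```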
96 |
97 | # Preprocessing
98 |
99 | Download and prepare the datasets using the notebook below:
100 |
101 | > [prepare_task_dataset.ipynb](./prepare_task_dataset.ipynb)
102 |
103 | Or directly download the preprocessed data from the [AWS public dataset](https://tibd-public-datasets.s3.amazonaws.com/ecrecer/ecrecer_datasets.zip) and put it in rootfolder/data/datasets/
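104 | For example, a minimal sketch (assuming the archive unpacks directly into the datasets folder):
105 |
106 | ```bash
107 | wget https://tibd-public-datasets.s3.amazonaws.com/ecrecer/ecrecer_datasets.zip
108 | unzip ecrecer_datasets.zip -d ./data/datasets/
109 | ```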
104 |
105 |
118 |
119 | # High throughput benchmarking
120 |
121 | # Train
122 |
123 | ```bash
124 | python benchmark_train.py
125 | ```
126 |
127 | # Test
128 |
129 | ```bash
130 | python benchmark_test.py
131 | ```
132 |
133 | # Evaluation
134 |
135 | ```bash
136 | python benchmark_evaluation.py
137 | ```
138 |
139 | # Production
140 |
141 | ```bash
142 | python production.py -i input_fasta_file -o output_tsv_file -mode [p|r|h] -topk 5
143 | ```
144 |
145 | # Citations
146 |
147 | If you find these methods valuable for your research, we kindly request that you reference the pertinent paper:
148 |
149 | ```bib
150 | @article{shi2023enzyme,
151 | title={Enzyme Commission Number Prediction and Benchmarking with Hierarchical Dual-core Multitask Learning Framework},
152 | author={Shi, Zhenkun and Deng, Rui and Yuan, Qianqian and Mao, Zhitao and Wang, Ruoyu and Li, Haoran and Liao, Xiaoping and Ma, Hongwu},
153 | journal={Research},
154 | year={2023},
155 | publisher={AAAS}
156 | }
157 | ```
158 |
159 | ## Stargazers over time
160 |
161 | [](https://github.com/kingstdio/ECRECer/)
162 |
--------------------------------------------------------------------------------
/benchmark_common.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os,string, random,joblib,sys
4 | from datetime import datetime
5 | import config as cfg
6 | from sklearn import preprocessing
7 | from keras.models import Model
8 | from keras.optimizers import Adam
9 | from keras.layers import Input, Dense, GRU, Bidirectional
10 | from tools import Attention
11 |
12 |
13 | # convert softmax results to one-hot
14 | def props_to_onehot(props):
15 | if isinstance(props, list):
16 | props = np.array(props)
17 | a = np.argmax(props, axis=1)
18 | b = np.zeros((len(a), props.shape[1]))
19 | b[np.arange(len(a)), a] = 1
20 | return b
21 |
22 | def make_onehot_label(label_list, save=False, file_encoder='./encode.h5', type='singel'):
23 | if type=='singel':
24 | encoder = preprocessing.OneHotEncoder(sparse=False)
25 | results = encoder.fit_transform([[item] for item in label_list])
26 |
27 | elif type=='multi':
28 | encoder = preprocessing.MultiLabelBinarizer()
29 | results=encoder.fit_transform([ item.split(',') for item in label_list])
30 | else:
31 | print('label encoding type error, please check')
32 | sys.exit()
33 |
34 | if save ==True:
35 | joblib.dump(encoder, file_encoder)
36 |
37 | return results
38 |
39 | def mgru_attion_model(input_dimensions, gru_h_size=512, dropout=0.2, lossfunction='binary_crossentropy',
40 | evaluation_metrics='accuracy', activation_method = 'sigmoid', output_dimensions=2
41 | ):
42 | inputs = Input(shape=(1, input_dimensions), name="input")
43 | gru = Bidirectional(GRU(gru_h_size, dropout=dropout, return_sequences=True), name="bi-gru")(inputs)
44 | attention = Attention.Attention(32)(gru)
45 | output = Dense(output_dimensions, activation=activation_method, name="dense")(attention)
46 | model = Model(inputs, output)
47 | model.compile(loss=lossfunction, optimizer=Adam(),metrics=[evaluation_metrics])
48 |
49 | return model
50 |
51 | # region write an [id, seq] DataFrame to a fasta file
52 | def save_table2fasta(dataset, file_out):
53 | """[write an [id, seq] DataFrame to a fasta file]
54 | Args:
55 | dataset ([DataFrame]): [DataFrame [id, seq] to write as fasta]
56 | file_out ([string]): [output fasta file path]
57 | """
58 | if not os.path.exists(file_out):
59 | file = open(file_out, 'w')
60 | for index, row in dataset.iterrows():
61 | file.write('>{0}\n'.format(row['id']))
62 | file.write('{0}\n'.format(row['seq']))
63 | file.close()
64 | print('Write finished')
65 | # endregion
66 |
67 | # region get sequence alignment results
68 | def getblast(ref_fasta, query_fasta, results_file):
69 | """[get sequence alignment results via diamond blastp]
70 | Args:
71 | ref_fasta ([string]): [training-set fasta file]
72 | query_fasta ([string]): [test-set fasta file]
73 | results_file ([string]): [results file to save]
74 |
75 | Returns:
76 | [DataFrame]: [alignment results]
77 | """
78 |
79 | if os.path.exists(results_file):
80 | res_data = pd.read_csv(results_file, sep='\t', names=cfg.BLAST_TABLE_HEAD)
81 | return res_data
82 |
83 |
84 | cmd1 = r'diamond makedb --in {0} -d /tmp/train.dmnd'.format(ref_fasta)
85 | cmd2 = r'diamond blastp -d /tmp/train.dmnd -q {0} -o {1} -b8 -c1 -k 1'.format(query_fasta, results_file)
86 | cmd3 = r'rm -rf /tmp/*.dmnd'
87 |
88 | print(cmd1)
89 | os.system(cmd1)
90 | print(cmd2)
91 | os.system(cmd2)
92 | res_data = pd.read_csv(results_file, sep='\t', names=cfg.BLAST_TABLE_HEAD)
93 | os.system(cmd3)
94 | return res_data
95 | # endregion
96 |
97 | #region build a diamond database
98 | def make_diamond_db(dbtable, to_db_file):
99 | """[build a diamond database]
100 |
101 | Args:
102 | dbtable ([DataFrame]): [[ID,SEQ] DataFrame]
103 | to_db_file ([string]): [database file]
104 | """
105 | save_table2fasta(dataset=dbtable, file_out=cfg.TEMPDIR+'dbtable.fasta')
106 | cmd = r'diamond makedb --in {0} -d {1}'.format(cfg.TEMPDIR+'dbtable.fasta', to_db_file)
107 | os.system(cmd)
108 | cmd = r'rm -rf {0}'.format(cfg.TEMPDIR+'dbtable.fasta')
109 | os.system(cmd)
110 | #endregion
111 |
112 | #region add result labels to blast alignment hits
113 | def blast_add_label(blast_df, trainset,):
114 | """[add result labels to blast alignment hits]
115 |
116 | Args:
117 | blast_df ([DataFrame]): [blast alignment results]
118 | trainset ([DataFrame]): [training data]
119 |
120 | Returns:
121 | [DataFrame]: [labeled results]
122 | """
123 | res_df = blast_df.merge(trainset, left_on='sseqid', right_on='id', how='left')
124 | res_df = res_df.rename(columns={'id_x': 'id',
125 | 'isemzyme': 'isenzyme_blast',
126 | 'functionCounts': 'functionCounts_blast',
127 | 'ec_number': 'ec_blast',
128 | 'ec_specific_level': 'ec_specific_level_blast'
129 | })
130 | return res_df.iloc[:,np.r_[0,13:16]]
131 | #endregion
132 |
133 | #region loading embedding features
134 | def load_data_embedding(embedding_type):
135 | """loading embedding features
136 |
137 | Args:
138 | embedding_type (string): one of ['one-hot', 'unirep', 'esm0', 'esm32', 'esm33']
139 |
140 | Returns:
141 | DataFrame: features
142 | """
143 |
144 | if embedding_type=='one-hot': #one-hot
145 | feature = pd.read_feather(cfg.FILE_FEATURE_ONEHOT)
146 |
147 | if embedding_type=='unirep': #unirep
148 | feature = pd.read_feather(cfg.FILE_FEATURE_UNIREP)
149 |
150 | if embedding_type=='esm0': #esm0
151 | feature = pd.read_feather(cfg.FILE_FEATURE_ESM0)
152 |
153 | if embedding_type =='esm32': #esm32
154 | feature = pd.read_feather(cfg.FILE_FEATURE_ESM32)
155 |
156 | if embedding_type =='esm33': #esm33
157 | feature = pd.read_feather(cfg.FILE_FEATURE_ESM33)
158 |
159 | return feature
160 | #endregion
161 |
162 |
163 | def get_blast_prediction(reference_db, train_frame, test_frame, results_file, identity_thres=0.2):
164 |
165 | save_table2fasta(dataset=test_frame.iloc[:,np.r_[0,5]], file_out=cfg.TEMPDIR+'test.fasta')
166 | cmd = r'diamond blastp -d {0} -q {1} -o {2} -b8 -c1 -k 1'.format(reference_db, cfg.TEMPDIR+'test.fasta', results_file)
167 | print(cmd)
168 | os.system(cmd)
169 | res_data = pd.read_csv(results_file, sep='\t', names=cfg.BLAST_TABLE_HEAD)
170 | res_data = res_data[res_data.pident >= identity_thres] # filter by identity threshold
171 | res_df = res_data.merge(train_frame, left_on='sseqid', right_on='id', how='left')
172 | res_df = res_df.rename(columns={'id_x': 'id',
173 | 'isemzyme': 'isemzyme_blast',
174 | 'functionCounts': 'functionCounts_blast',
175 | 'ec_number': 'ec_number_blast',
176 | 'ec_specific_level': 'ec_specific_level_blast'
177 | })
178 |
179 | return res_df.iloc[:,np.r_[0,2,13:17]]
180 |
181 | #region parse cd-hit clustering results
182 | def get_cdhit_results(cdhit_clstr_file):
183 | """parse cd-hit clustering results
184 |
185 | Args:
186 | cdhit_clstr_file (string): clustering results file
187 |
188 | Returns:
189 | DataFrame: ['cluster_id','uniprot_id','identity', 'is_representative']
190 | """
191 | counter = 0
192 | res = []
193 | with open(cdhit_clstr_file,'r') as f:
194 | for line in f:
195 | if 'Cluster' in line:
196 | cluster_id = line.replace('>Cluster','').replace('\n', '').strip()
197 | continue
198 | str_uids= line.replace('\n','').split('>')[1].replace('at ','').split('... ')
199 |
200 | if '*' in str_uids[1]:
201 | identity = 1
202 | isrep = True
203 | else:
204 | identity = float(str_uids[1].strip('%')) /100
205 | isrep = False
206 |
207 | res = res +[[cluster_id, str_uids[0], identity, isrep ]]
208 |
209 | resdf = pd.DataFrame(res, columns=['cluster_id','uniprot_id','identity', 'is_representative']) # convert to DataFrame
210 | return resdf
211 | #endregion
212 |
213 | def pycdhit(uniportid_seq_df, identity=0.4, thred_num=4):
214 | """CD-HIT 序列聚类
215 |
216 | Args:
217 | uniportid_seq_df (DataFrame): [uniprot_id, seq] 蛋白DataFrame
218 | identity (float, optional): 聚类阈值. Defaults to 0.4.
219 | thred_num (int, optional): 聚类线程数. Defaults to 4.
220 |
221 | Returns:
222 | 聚类结果 DataFrame: [cluster_id,uniprot_id,identity,is_representative,cluster_size]
223 | """
224 | if identity>=0.7:
225 | word_size = 5
226 | elif identity>=0.6:
227 | word_size = 4
228 | elif identity >=0.5:
229 | word_size = 3
230 | elif identity >=0.4:
231 | word_size =2
232 | else:
233 | word_size = 5
234 |
235 | # define input/output file names
236 |
237 |
238 |
239 | time_stamp_str = datetime.now().strftime("%Y-%m-%d_%H_%M_%S_")+''.join(random.sample(string.ascii_letters + string.digits, 16))
240 | cd_hit_fasta = f'{cfg.TEMPDIR}cdhit_test_{time_stamp_str}.fasta'
241 | cd_hit_results = f'{cfg.TEMPDIR}cdhit_results_{time_stamp_str}'
242 | cd_hit_cluster_res_file =f'{cfg.TEMPDIR}cdhit_results_{time_stamp_str}.clstr'
243 |
244 | # write the fasta file for clustering
245 | save_table2fasta(uniportid_seq_df, cd_hit_fasta)
246 |
247 | # run cd-hit clustering
248 | cmd = f'cd-hit -i {cd_hit_fasta} -o {cd_hit_results} -c {identity} -n {word_size} -T {thred_num} -M 0 -g 1 -sc 1 -sf 1 > /dev/null 2>&1'
249 | os.system(cmd)
250 | cdhit_cluster = get_cdhit_results(cdhit_clstr_file=cd_hit_cluster_res_file)
251 |
252 | cluster_size = cdhit_cluster.cluster_id.value_counts()
253 | cluster_size = pd.DataFrame({'cluster_id':cluster_size.index,'cluster_size':cluster_size.values})
254 | cdhit_cluster = cdhit_cluster.merge(cluster_size, on='cluster_id', how='left')
255 |
256 | cmd = f'rm -f {cd_hit_fasta} {cd_hit_results} {cd_hit_cluster_res_file}'
257 | os.system(cmd)
258 |
259 | return cdhit_cluster
260 |
261 | #region print the model's top-N most important features
262 | def importance_features_top(model, x_train, topN=10):
263 | """[print the model's top-N most important features]
264 | Args:
265 | model ([type]): [trained model exposing feature_importances_]
266 | x_train ([DataFrame]): [training features]
267 | topN (int, optional): [number of features to show]. Defaults to 10.
268 | """
269 | print("Top XGBoost feature importances")
270 | feature_importances_ = model.feature_importances_
271 | feature_names = x_train.columns
272 | importance_col = pd.DataFrame([*zip(feature_names, feature_importances_)], columns=['features', 'weight'])
273 | importance_col_desc = importance_col.sort_values(by='weight', ascending=False)
274 | print(importance_col_desc.iloc[:topN, :])
275 | #endregion
276 |
277 | #test
278 |
279 | if __name__ =='__main__':
280 | print('success')
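281 |
282 | # minimal usage sketch: softmax probability rows become one-hot rows
283 | print(props_to_onehot([[0.1, 0.9], [0.8, 0.2]])) # -> [[0. 1.], [1. 0.]]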
--------------------------------------------------------------------------------
/benchmark_evaluation.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Zhenkun Shi
3 | Date: 2023-09-28 12:03:17
4 | LastEditors: Zhenkun Shi
5 | LastEditTime: 2023-10-07 06:14:15
6 | FilePath: /ECRECer/benchmark_evaluation.py
7 | Description:
8 |
9 | Copyright (c) 2023 by tibd, All Rights Reserved.
10 | '''
11 | from operator import index
12 | from random import sample
13 | import pandas as pd
14 | import numpy as np
15 | from sklearn import metrics
16 | from pandas._config.config import reset_option
17 | import benchmark_common as bcommon
18 | import config as cfg
19 |
20 |
21 | #region loading results from different methods and make a big DataFrame
22 | def load_res_data(file_slice, file_blast, file_deepec, file_ecpred,file_catfam, file_priam, train, test):
23 | """[load the predictions of each method and merge them into one big table]
24 |
25 | Args:
26 | file_slice ([string]): [slice ensemble prediction results file]
27 | file_blast ([string]): [blast alignment results file]
28 | file_deepec ([string]): [deepec prediction results file]
29 | file_ecpred ([string]): [ecpred prediction results file]
30 | train ([DataFrame]): [training set]
31 | test ([DataFrame]): [test set]
32 |
33 | Returns:
34 | [DataFrame]: [merged results table]
35 | """
36 |
37 |
38 | # Slice + groundtruth
39 | ground_truth = test.iloc[:, np.r_[0,1,2,3]]
40 | res_slice = pd.read_csv(file_slice, sep='\t')
41 |
42 | big_res = ground_truth.merge(res_slice, on='id', how='left')
43 |
44 | #Blast
45 | res_blast = pd.read_csv(file_blast, sep='\t', names =cfg.BLAST_TABLE_HEAD)
46 | res_blast = bcommon.blast_add_label(blast_df=res_blast, trainset=train)
47 | big_res = big_res.merge(res_blast, on='id', how='left')
48 |
49 |
50 | # DeepEC
51 | res_deepec = pd.read_csv(file_deepec, sep='\t',names=['id', 'ec_number'], header=0 )
52 | res_deepec.ec_number=res_deepec.apply(lambda x: x['ec_number'].replace('EC:',''), axis=1)
53 | res_deepec.columns = ['id','ec_deepec']
54 | res_deepec['isemzyme_deepec']=res_deepec.ec_deepec.apply(lambda x: True if str(x)!='nan' else False)
55 | res_deepec['functionCounts_deepec'] = res_deepec.ec_deepec.apply(lambda x :len(str(x).split(',')))
56 | big_res = big_res.merge(res_deepec, on='id', how='left').drop_duplicates(subset='id')
57 |
58 | # ECpred
59 | res_ecpred = pd.read_csv(file_ecpred, sep='\t', header=0)
60 | res_ecpred['isemzyme_ecpred'] = ''
61 | with pd.option_context('mode.chained_assignment', None):
62 | res_ecpred.isemzyme_ecpred[res_ecpred['EC Number']=='non Enzyme'] = False
63 | res_ecpred.isemzyme_ecpred[res_ecpred['EC Number']!='non Enzyme'] = True
64 |
65 | res_ecpred.columns = ['id','ec_ecpred', 'conf', 'isemzyme_ecpred']
66 | res_ecpred = res_ecpred.iloc[:,np.r_[0,1,3]]
67 | res_ecpred['functionCounts_ecpred'] = res_ecpred.ec_ecpred.apply(lambda x :len(str(x).split(',')))
68 | big_res = big_res.merge(res_ecpred, on='id', how='left').drop_duplicates(subset='id')
69 |
70 | # CATFAM
71 | res_catfam = pd.read_csv(file_catfam, sep='\t', names=['id', 'ec_catfam'])
72 | res_catfam['isenzyme_catfam']=res_catfam.ec_catfam.apply(lambda x: True if str(x)!='nan' else False)
73 | res_catfam['functionCounts_catfam'] = res_catfam.ec_catfam.apply(lambda x :len(str(x).split(',')))
74 | big_res = big_res.merge(res_catfam, on='id', how='left')
75 |
76 | #PRIAM
77 | res_priam = load_praim_res(resfile=file_priam)
78 | big_res = big_res.merge(res_priam, on='id', how='left')
79 | big_res['isenzyme_priam'] = big_res.ec_priam.apply(lambda x: True if str(x)!='nan' else False)
80 | big_res['functionCounts_priam'] = big_res.ec_priam.apply(lambda x :len(str(x).split(',')))
81 |
82 | big_res=big_res.sort_values(by=['isemzyme', 'id'], ascending=False)
83 | big_res = big_res.drop_duplicates(subset='id')
84 | big_res.reset_index(drop=True, inplace=True)
85 | return big_res
86 | #endregion
87 |
88 | #region load PRIAM predictions
89 | def load_praim_res(resfile):
90 | """[load PRIAM prediction results]
91 | Args:
92 | resfile ([string]): [results file]
93 | Returns:
94 | [DataFrame]: [results]
95 | """
96 | f = open(resfile)
97 | line = f.readline()
98 | counter =0
99 | reslist=[]
100 | lstr =''
101 | subec=[]
102 | while line:
103 | if '>' in line:
104 | if counter !=0:
105 | reslist +=[[lstr, ', '.join(subec)]]
106 | subec=[]
107 | lstr = line.replace('>', '').replace('\n', '')
108 | elif line.strip()!='':
109 | ecarray = line.split('\t')
110 | subec += [(ecarray[0].replace('#', '').replace('\n', '').replace(' ', '') )]
111 |
112 | line = f.readline()
113 | counter +=1
114 | f.close()
115 | res_priam=pd.DataFrame(reslist, columns=['id', 'ec_priam'])
116 | return res_priam
117 | #endregion
118 |
119 | def integrate_reslults(big_table):
120 | # concatenate the multiple predicted labels
121 | big_table['ec_islice']=big_table.apply(lambda x : ', '.join(x.iloc[4:(x.pred_functionCounts+4)].values.astype('str')), axis=1)
122 |
123 | # assign '-' to non-enzymes
124 | with pd.option_context('mode.chained_assignment', None):
125 | big_table.ec_islice[big_table.ec_islice=='nan']='-'
126 | big_table.ec_islice[big_table.ec_islice=='']='-'
127 | big_table['isemzyme_deepec']=big_table.ec_deepec.apply(lambda x : 0 if str(x)=='nan' else 1)
128 |
129 | big_table=big_table.iloc[:, np.r_[0:4,16:31,31, 14:16]]
130 | big_table = big_table.rename(columns={'isemzyme': 'isenzyme_groundtruth',
131 | 'functionCounts': 'functionCounts_groundtruth',
132 | 'ec_number': 'ec_groundtruth',
133 | 'pred_isEnzyme': 'isenzyme_islice',
134 | 'pred_functionCounts': 'functionCounts_islice',
135 | 'isemzyme_blast':'isenzyme_blast'
136 | })
137 | # merge in train/test sample counts
138 | samplecounts = pd.read_csv(cfg.DATADIR + 'ecsamplecounts.tsv', sep = '\t')
139 |
140 | big_table = big_table.merge(samplecounts, left_on='ec_groundtruth', right_on='ec_number', how='left')
141 | big_table = big_table.iloc[:, np.r_[0:22,23:25]]
142 |
143 | big_table.to_excel(cfg.FILE_EVL_RESULTS, index=None)
144 | big_table = big_table.drop_duplicates(subset='id')
145 | return big_table
146 |
147 | def caculateMetrix(groundtruth, predict, baselineName, type='binary'):
148 |
149 | if type == 'binary':
150 | acc = metrics.accuracy_score(groundtruth, predict)
151 | precision = metrics.precision_score(groundtruth, predict, zero_division=True )
152 | recall = metrics.recall_score(groundtruth, predict, zero_division=True)
153 | f1 = metrics.f1_score(groundtruth, predict, zero_division=True)
154 | tn, fp, fn, tp = metrics.confusion_matrix(groundtruth, predict).ravel()
155 | npv = tn/(fn+tn+1.4E-45)
156 | print(baselineName, '\t\t%f' %acc,'\t%f'% precision,'\t\t%f'%npv,'\t%f'% recall,'\t%f'% f1, '\t', 'tp:',tp,'fp:',fp,'fn:',fn,'tn:',tn)
157 |
158 | if type =='include_unfind':
159 | evadf = pd.DataFrame()
160 | evadf['g'] = groundtruth
161 | evadf['p'] = predict
162 |
163 | evadf_hot = evadf[~evadf.p.isnull()]
164 | evadf_cold = evadf[evadf.p.isnull()]
165 |
166 | tp = len(evadf_hot[(evadf_hot.g.astype('int')==1) & (evadf_hot.p.astype('int')==1)])
167 | fp = len(evadf_hot[(evadf_hot.g.astype('int')==0) & (evadf_hot.p.astype('int')==1)])
168 | tn = len(evadf_hot[(evadf_hot.g.astype('int')==0) & (evadf_hot.p.astype('int')==0)])
169 | fn = len(evadf_hot[(evadf_hot.g.astype('int')==1) & (evadf_hot.p.astype('int')==0)])
170 | up = len(evadf_cold[evadf_cold.g==1])
171 | un = len(evadf_cold[evadf_cold.g==0])
172 | acc = (tp+tn)/(tp+fp+tn+fn+up+un)
173 | precision = tp/(tp+fp)
174 | npv = tn/(tn+fn)
175 | recall = tp/(tp+fn+up)
176 | f1=(2*precision*recall)/(precision+recall)
177 | print( baselineName,
178 | '\t\t%f' %acc,
179 | '\t%f'% precision,
180 | '\t\t%f'%npv,
181 | '\t%f'% recall,
182 | '\t%f'% f1, '\t',
183 | 'tp:',tp,'fp:',fp,'fn:',fn,'tn:',tn, 'up:',up, 'un:',un)
184 |
185 | if type == 'multi':
186 | acc = metrics.accuracy_score(groundtruth, predict)
187 | precision = metrics.precision_score(groundtruth, predict, average='macro', zero_division=True )
188 | recall = metrics.recall_score(groundtruth, predict, average='macro', zero_division=True)
189 | f1 = metrics.f1_score(groundtruth, predict, average='macro', zero_division=True)
190 | print('%12s'%baselineName, ' \t\t%f '%acc,'\t%f'% precision, '\t\t%f'% recall,'\t%f'% f1)
191 |
192 | def evalueate_performance(evalutation_table):
193 | print('\n\n1. isEnzyme prediction evaluation metrics')
194 | print('*'*140+'\n')
195 | print('baselineName', '\t', 'accuracy','\t', 'precision(PPV) \t NPV \t\t', 'recall','\t', 'f1', '\t\t', '\t confusion Matrix')
196 | ev_isenzyme={
197 | 'ours':'isenzyme_islice',
198 | 'blast':'isenzyme_blast',
199 | 'ecpred':'isemzyme_ecpred',
200 | 'deepec':'isemzyme_deepec',
201 | 'catfam':'isenzyme_catfam',
202 | 'priam':'isenzyme_priam'
203 | }
204 |
205 | for k,v in ev_isenzyme.items():
206 | caculateMetrix( groundtruth=evalutation_table.isenzyme_groundtruth.astype('int'),
207 | predict=evalutation_table[v],
208 | baselineName=k,
209 | type='include_unfind')
210 |
211 |
212 | print('\n\n2. EC prediction evaluation metrics')
213 | print('*'*140+'\n')
214 | print('%12s'%'baselineName', '\t\t', 'accuracy','\t', 'precision-macro \t', 'recall-macro','\t', 'f1-macro')
215 | ev_ec = {
216 | 'ours':'ec_islice',
217 | 'blast':'ec_blast',
218 | 'ecpred':'ec_ecpred',
219 | 'deepec':'ec_deepec',
220 | 'catfam':'ec_catfam',
221 | 'priam':'ec_priam',
222 | }
223 | for k, v in ev_ec.items():
224 | caculateMetrix( groundtruth=evalutation_table.ec_groundtruth,
225 | predict=evalutation_table[v].fillna('NaN'),
226 | baselineName=k,
227 | type='multi')
228 |
229 | print('\n\n3. Function counts evaluation metrics')
230 | print('*'*140+'\n')
231 | print('%12s'%'baselineName', '\t\t', 'accuracy','\t', 'precision-macro \t', 'recall-macro','\t', 'f1-macro')
232 |
233 | ev_functionCounts = {
234 | 'ours':'functionCounts_islice',
235 | 'blast':'functionCounts_blast',
236 | 'ecpred':'functionCounts_ecpred',
237 | 'deepec':'functionCounts_deepec' ,
238 | 'catfam':'functionCounts_catfam',
239 | 'priam': 'functionCounts_priam'
240 | }
241 |
242 | for k, v in ev_functionCounts.items():
243 | caculateMetrix( groundtruth=evalutation_table.functionCounts_groundtruth,
244 | predict=evalutation_table[v].fillna('-1').astype('int'),
245 | baselineName=k,
246 | type='multi')
247 |
248 | num_enzyme = len(evalutation_table[evalutation_table.isenzyme_groundtruth])
249 | num_no_enzyme = len(evalutation_table[~evalutation_table.isenzyme_groundtruth])
250 | num_multi_function = len(evalutation_table[evalutation_table.functionCounts_groundtruth>1])
251 | num_single_function = len(evalutation_table[evalutation_table.functionCounts_groundtruth==1])
252 |
253 |
254 | ev_ec = {'ours':'ec_islice', 'blast':'ec_blast', 'ecpred':'ec_ecpred', 'deepec':'ec_deepec', 'catfam':'ec_catfam'}
255 |
256 | print('\n\n4. EC Prediction Report')
257 | str_out_head = '\t item'
258 | str_out_multi = 'multi-function enzyme accuracy'
259 | str_out_single = 'single-function enzyme accuracy'
260 | str_out_all = 'overall accuracy\t'
261 | for k, v in ev_ec.items():
262 | str_out_head += ('\t\t\t'+str(k))
263 | num_multi = len(evalutation_table[(evalutation_table.functionCounts_groundtruth>1) &
264 | (evalutation_table.ec_groundtruth == evalutation_table[v])])
265 | num_single = len(evalutation_table[(evalutation_table.functionCounts_groundtruth==1) &
266 | (evalutation_table.ec_groundtruth == evalutation_table[v])])
267 | str_out_multi += ('\t\t' + str(num_multi)+'/'+ str(num_multi_function) +'='+ str(round(num_multi/num_multi_function,4)))
268 | str_out_single += ('\t\t' + str(num_single)+'/'+ str(num_single_function) +'='+ str(round(num_single/num_single_function,4)))
269 | str_out_all += ('\t\t' + str(num_single + num_multi)+'/'+ str(num_enzyme) +'='+ str(round((num_single + num_multi)/num_enzyme,4)))
270 | print(str_out_head )
271 | print(str_out_multi)
272 | print(str_out_single)
273 | print(str_out_all)
274 |
275 | def filter_newadded_ec(restable):
276 | # filter newly added ECs on your own
277 | ## TODO calculate the newly added ECs in the testing set and filter them
278 |
279 | return restable
280 |
281 | if __name__ =='__main__':
282 |
283 | # 1. loading data
284 | train = pd.read_feather(cfg.DATADIR+'train.feather').iloc[:,:6]
285 | test = pd.read_feather(cfg.DATADIR+'test.feather').iloc[:,:6]
286 | # test = test[(test.ec_specific_level>=cfg.TRAIN_USE_SPCIFIC_EC_LEVEL) |(~test.isemzyme)]
287 | test.reset_index(drop=True, inplace=True)
288 | # EC-label dict
289 | dict_ec_label = np.load(cfg.DATADIR + 'ec_label_dict.npy', allow_pickle=True).item()
290 | file_blast_res = cfg.RESULTSDIR + r'test_blast_res.tsv'
291 | flat_table = load_res_data(
292 | file_slice=cfg.FILE_INTE_RESULTS,
293 | file_blast=cfg.FILE_BLAST_RESULTS,
294 | file_deepec=cfg.FILE_DEEPEC_RESULTS,
295 | file_ecpred=cfg.FILE_ECPRED_RESULTS,
296 | file_catfam = cfg.FILE_CATFAM_RESULTS,
297 | file_priam = cfg.FILE_PRIAM_RESULTS,
298 | train=train,
299 | test=test
300 | )
301 |
302 | evalutation_table = integrate_reslults(flat_table)
303 | evalutation_table = filter_newadded_ec(evalutation_table) # filter newly added ECs on your own
304 |
305 | evalueate_performance(evalutation_table)
306 | evalutation_table['sright'] = evalutation_table.apply(lambda x: True if x.ec_groundtruth == x.ec_islice else False, axis=1)
307 | evalutation_table['bright'] = evalutation_table.apply(lambda x: True if x.ec_groundtruth == x.ec_blast else False, axis=1)
308 | evalutation_table.to_excel(cfg.RESULTSDIR+'evaluationFF.xlsx', index=None)
309 | print('\n Evaluation Finished \n\n')
--------------------------------------------------------------------------------
/benchmark_test.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import joblib
4 | import os
5 | import benchmark_common as bcommon
6 | import config as cfg
7 | from Bio import SeqIO
8 |
9 | # region get enzyme|non-enzyme predictions
10 | def get_isEnzymeRes(querydata, model_file):
11 | """[get enzyme|non-enzyme predictions]
12 | Args:
13 | querydata ([DataFrame]): [data to predict]
14 | model_file ([string]): [model file]
15 | Returns:
16 | [DataFrame]: [predictions, prediction probabilities]
17 | """
18 | model = joblib.load(model_file)
19 | predict = model.predict(querydata)
20 | predictprob = model.predict_proba(querydata)
21 | return predict, predictprob[:, 1]
22 | # endregion
23 |
24 | # region get how-many-functions predictions
25 | def get_howmany_Enzyme(querydata, model_file):
26 | """get how-many-functions predictions
27 | Args:
28 | querydata ([DataFrame]): [data to predict]
29 | model_file ([string]): [model file]
30 | Returns:
31 | [DataFrame]: [predictions, prediction probabilities]
32 | """
33 | model = joblib.load(model_file)
34 | predict = model.predict(querydata)
35 | predictprob = model.predict_proba(querydata)
36 | return predict+1, predictprob # labels are 0-based (single-function enzyme is 0), so shift by 1
37 | # endregion
38 |
39 | # region get slice predictions
40 | def get_slice_res(slice_query_file, model_path, dict_ec_label,test_set, res_file):
41 | """[get slice predictions]
42 |
43 | Args:
44 | slice_query_file ([string]): [slice file with the data to predict]
45 | model_path ([string]): [Slice model path]
46 | res_file ([string]): [prediction results file]
47 | Returns:
48 | [DataFrame]: [predictions]
49 | """
50 |
51 | cmd = '''./slice_predict {0} {1} {2} -o 32 -b 0 -t 32 -q 0'''.format(slice_query_file, model_path, res_file)
52 | print(cmd)
53 | os.system(cmd)
54 | result_slice = pd.read_csv(res_file, header=None, skiprows=1, sep=' ')
55 |
56 | # 5.3 sort the prediction results
57 | slice_pred_rank, slice_pred_prob = sort_results(result_slice)
58 |
59 | # 5.4 translate the results into EC numbers
60 | slice_pred_ec = translate_slice_pred(slice_pred=slice_pred_rank, ec2label_dict = dict_ec_label, test_set=test_set)
61 |
62 | return slice_pred_ec
63 |
64 | # endregion
65 |
66 | #region sort slice results and return them as two matrices in recommendation order
67 | def sort_results(result_slice):
68 | """
69 | Sort the slice results and return them as two matrices in recommendation order.
70 | @pred_top: ranked predicted labels
71 | @pred_pb_top: ranked prediction scores
72 | """
73 | pred_top = []
74 | pred_pb_top = []
75 | aac = []
76 | for index, row in result_slice.iterrows():
77 | row_trans = [*row.apply(lambda x: x.split(':')).values]
78 | row_trans = pd.DataFrame(row_trans).sort_values(by=[1], ascending=False)
79 | pred_top += [list(np.array(row_trans[0]).astype('int'))]
80 | pred_pb_top += [list(np.array(row_trans[1]).astype('float'))]
81 | pred_top = pd.DataFrame(pred_top)
82 | pred_pb_top = pd.DataFrame(pred_pb_top)
83 | return pred_top, pred_pb_top
84 | #endregion
85 |
86 | #region split the test set into X and Y
87 | def get_test_set(data):
88 | """[split the test set into X and Y]
89 |
90 | Args:
91 | data ([DataFrame]): [test data]
92 |
93 | Returns:
94 | [DataFrame]: [the split X and Y]
95 | """
96 | testX = data.iloc[:,7:]
97 | testY = data.iloc[:,:6]
98 | return testX, testY
99 | #endregion
100 |
101 | #region translate slice-predicted labels into EC numbers
102 | def translate_slice_pred(slice_pred, ec2label_dict, test_set):
103 | """[translate slice-predicted labels into EC numbers]
104 |
105 | Args:
106 | slice_pred ([DataFrame]): [ranked slice predictions]
107 | ec2label_dict ([dict]): [EC-to-label dict]
108 | test_set ([DataFrame]): [test set, used for IDs]
109 |
110 | Returns:
111 | [DataFrame]: [per-ID top-k EC predictions]
112 | """
113 | label_ec_dict = {value:key for key, value in ec2label_dict.items()}
114 | res_df = pd.DataFrame()
115 | res_df['id'] = test_set.id
116 | colNames = slice_pred.columns.values
117 | for colName in colNames:
118 | res_df['top'+str(colName)] = slice_pred[colName].apply(lambda x: label_ec_dict.get(x))
119 | return res_df
120 | #endregion
121 |
122 |
123 |
124 | #region integrate the test results for output
125 | def run_integrage(slice_pred, dict_ec_transfer):
126 | """[integrate the test results for output]
127 |
128 | Args:
129 | slice_pred ([DataFrame]): [slice predictions]
130 | dict_ec_transfer ([dict]): [EC transfer dict]
131 |
132 | Returns:
133 | [DataFrame]: [final integrated results]
134 | """
135 | # take top 10, since an enzyme has at most 10 functions
136 | slice_pred = slice_pred.iloc[:,np.r_[0:11, 21:28]]
137 |
138 | # integrated enzyme label
139 | with pd.option_context('mode.chained_assignment', None):
140 | slice_pred['is_enzyme_i'] = slice_pred.apply(lambda x: int(x.isemzyme_blast) if str(x.isemzyme_blast)!='nan' else x.isEnzyme_pred_xg, axis=1)
141 |
142 | # clear EC predictions for non-enzymes
143 | for i in range(9):
144 | with pd.option_context('mode.chained_assignment', None):
145 | slice_pred['top'+str(i)] = slice_pred.apply(lambda x: '' if x.is_enzyme_i==0 else x['top'+str(i)], axis=1)
146 | slice_pred['top0'] = slice_pred.apply(lambda x: x.ec_number_blast if str(x.ec_number_blast)!='nan' else x.top0, axis=1)
147 |
148 | # clear predicted labels where a blast hit exists
149 | for i in range(1,10):
150 | with pd.option_context('mode.chained_assignment', None):
151 | slice_pred['top'+str(i)] = slice_pred.apply(lambda x: '' if str(x.ec_number_blast)!='nan' else x['top'+str(i)], axis=1) # has a blast hit
152 | slice_pred['top'+str(i)] = slice_pred.apply(lambda x: '' if int(x.functionCounts_pred_xg) < int(i+1) else x['top'+str(i)], axis=1) # no blast hit
153 | with pd.option_context('mode.chained_assignment', None):
154 | slice_pred['top0']=slice_pred['top0'].apply(lambda x: '' if x=='-' else x)
155 |
156 | # split comma-separated EC numbers
157 | for index, row in slice_pred.iterrows():
158 | ecitems=row['top0'].split(',')
159 | if len(ecitems)>1:
160 | for i in range(len(ecitems)):
161 | slice_pred.loc[index,'top'+str(i)] = ecitems[i].strip()
162 |
163 | slice_pred.reset_index(drop=True, inplace=True)
164 |
165 | # add how-many-functions predictions
166 | with pd.option_context('mode.chained_assignment', None):
167 | slice_pred['pred_functionCounts'] = slice_pred.apply(lambda x: int(x['functionCounts_blast']) if str(x['functionCounts_blast'])!='nan' else x.functionCounts_pred_xg ,axis=1)
168 | # take the final columns and rename them
169 | colnames=[ 'id',
170 | 'pred_ec1',
171 | 'pred_ec2',
172 | 'pred_ec3',
173 | 'pred_ec4',
174 | 'pred_ec5' ,
175 | 'pred_ec6',
176 | 'pred_ec7',
177 | 'pred_ec8',
178 | 'pred_ec9',
179 | 'pred_ec10',
180 | 'pred_isEnzyme',
181 | 'pred_functionCounts'
182 | ]
183 | slice_pred=slice_pred.iloc[:, np.r_[0:11, 18,19]]
184 | slice_pred.columns = colnames
185 |
186 | # apply EC transfer mapping
187 | for i in range(1,11):
188 | slice_pred['pred_ec'+str(i)] = slice_pred['pred_ec'+str(i)].apply(lambda x: dict_ec_transfer.get(x) if x in dict_ec_transfer.keys() else x)
189 |
190 | # zero the function count where no EC number was predicted
191 | with pd.option_context('mode.chained_assignment', None):
192 | slice_pred.pred_functionCounts[slice_pred.pred_ec1.isnull()] = 0
193 |
194 | return slice_pred
195 | #endregion
196 |
197 | if __name__ == '__main__':
198 |
199 | EMBEDDING_METHOD = 'esm32'
200 | TESTSET='test2019'
201 |
202 | # 1. load data
203 | print('step 1: loading data')
204 | train = pd.read_feather(cfg.TRAIN_FEATURE)
205 | test = pd.read_feather(cfg.TEST_FEATURE)
206 | train, test = (df.merge(bcommon.load_data_embedding(embedding_type=EMBEDDING_METHOD), on='id', how='left') for df in (train, test)) # merge embedding features on 'id'
207 | train = train.iloc[:,:7]
208 |
209 | dict_ec_label = np.load(cfg.FILE_EC_LABEL_DICT, allow_pickle=True).item() # EC-label dict
210 | dict_ec_transfer = np.load(cfg.FILE_TRANSFER_DICT, allow_pickle=True).item() # EC-transfer dict
211 |
212 | # 2. get sequence alignment results
213 |
214 | print('step 2 get blast results')
215 | blast_res = bcommon.get_blast_prediction( reference_db=cfg.FILE_BLAST_TRAIN_DB,
216 | train_frame=train,
217 | test_frame=test.iloc[:,0:7],
218 | results_file=cfg.FILE_BLAST_RESULTS,
219 | identity_thres=cfg.TRAIN_BLAST_IDENTITY_THRES
220 | )
221 |
222 | # 3. get enzyme/non-enzyme predictions
223 | print('step 3. get isEnzyme results')
224 | testX, testY = get_test_set(data=test)
225 | isEnzyme_pred, isEnzyme_pred_prob = get_isEnzymeRes(querydata=testX, model_file=cfg.ISENZYME_MODEL)
226 |
227 |
228 | # 4. predict how many functions each enzyme has
229 | print('step 4. get howmany functions ')
230 | howmany_Enzyme_pred, howmany_Enzyme_pred_prob = get_howmany_Enzyme(querydata=testX, model_file=cfg.HOWMANY_MODEL)
231 |
232 | # 5. get Slice predictions
233 | print('step 5. get EC prediction results')
234 | # 5.1 prepare the files used by slice
235 | print('predict finished')
--------------------------------------------------------------------------------
/benchmark_train.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os
4 | from sklearn.model_selection import train_test_split
5 | import benchmark_common as bcommon
6 | import config as cfg
7 | from keras.callbacks import TensorBoard, ModelCheckpoint
8 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
9 |
10 |
11 | #region build the training set for a task
12 | def get_train_X_Y(traindata, feature_bankfile, task=1):
13 | """[build the training set for a task]
14 |
15 | Args:
16 | traindata ([DataFrame]): [training data]
17 |
18 | Returns:
19 | [DataFrame]: [trainX, trainY]
20 | """
21 | traindata = traindata.merge(feature_bankfile, on='id', how='left')
22 | train_X = np.array(traindata.iloc[:,3:])
23 | if task == 1:
24 | train_Y = bcommon.make_onehot_label(label_list=traindata['isenzyme'].to_list(), save=True, file_encoder=cfg.DICT_LABEL_T1,type='singel')
25 | if task == 2:
26 | train_Y = bcommon.make_onehot_label(label_list=traindata['functionCounts'].to_list(), save=True, file_encoder=cfg.DICT_LABEL_T2, type='singel')
27 | if task == 3 :
28 | train_Y = bcommon.make_onehot_label(label_list=traindata['ec_number'].to_list(), save=True, file_encoder=cfg.DICT_LABEL_T3, type='multi')
29 | return train_X, train_Y
30 | #endregion
31 |
32 | #region train the isEnzyme model
33 | def train_isenzyme(X,Y, model_file, vali_ratio=0.2, force_model_update=False, epochs=1):
34 | if os.path.exists(model_file) and (force_model_update==False):
35 | return
36 | else:
37 | x_train, x_vali, y_train, y_vali = train_test_split(np.array(X), np.array(Y), test_size=vali_ratio, shuffle=False)
38 | x_train = x_train.reshape(x_train.shape[0],1,-1)
39 | x_vali = x_vali.reshape(x_vali.shape[0],1,-1)
40 | tbCallBack = TensorBoard(log_dir=f'{cfg.TEMPDIR}model/task1', histogram_freq=1,write_grads=True)
41 | checkpoint = ModelCheckpoint(filepath=model_file, monitor='val_accuracy', mode='auto', save_best_only=True)
42 | instant_model = bcommon.mgru_attion_model(input_dimensions=X.shape[1], gru_h_size=512, dropout=0.2, lossfunction='binary_crossentropy',
43 | evaluation_metrics='accuracy', activation_method = 'sigmoid', output_dimensions=Y.shape[1] )
44 |
45 | instant_model.fit(x_train, y_train, validation_data=(x_vali, y_vali), batch_size=512, epochs= epochs, callbacks=[tbCallBack, checkpoint])
46 | # save
47 | print(f'train_isenzyme finished, best model saved to: {model_file}')
48 |
49 | #endregion
50 |
51 | #region train the how-many-functions model
52 | def train_howmany_enzyme(X, Y, model_file, force_model_update=False, epochs=1):
53 | if os.path.exists(model_file) and (force_model_update==False):
54 | return
55 | else:
56 | x_train, x_vali, y_train, y_vali = train_test_split(np.array(X), np.array(Y), test_size=0.3, shuffle=False)
57 | x_train = x_train.reshape(x_train.shape[0],1,-1)
58 | x_vali = x_vali.reshape(x_vali.shape[0],1,-1)
59 | tbCallBack = TensorBoard(log_dir=f'{cfg.TEMPDIR}model/task3', histogram_freq=1,write_grads=True)
60 | checkpoint = ModelCheckpoint(filepath=model_file, monitor='val_accuracy', mode='auto', save_best_only=True)
61 | instant_model = bcommon.mgru_attion_model(input_dimensions=X.shape[1], gru_h_size=128, dropout=0.2, lossfunction='categorical_crossentropy',
62 | evaluation_metrics='accuracy', activation_method = 'softmax', output_dimensions=Y.shape[1])
63 |
64 | instant_model.fit(x_train, y_train, validation_data=(x_vali, y_vali), batch_size=128, epochs= epochs, callbacks=[tbCallBack, checkpoint])
65 |
66 | print(f'train how many enzyme finished, best model saved to: {model_file}')
67 | #endregion
68 |
69 | def make_ec_label(train_label, test_label, file_save, force_model_update=False):
70 | if os.path.exists(file_save) and (force_model_update==False):
71 | print('ec label dict already exist')
72 | return
73 | ecset = sorted( set(list(train_label) + list(test_label)))
74 | ec_label_dict = {k: v for k, v in zip(ecset, range(len(ecset)))}
75 | np.save(file_save, ec_label_dict)
76 | print('EC label dict saved')
77 | return ec_label_dict
78 |
79 |
80 | def train_ec(X, Y, model_file, force_model_update=False, epochs=1):
81 | if os.path.exists(model_file) and (force_model_update==False):
82 | return
83 | else:
84 | x_train, x_vali, y_train, y_vali = train_test_split(np.array(X), np.array(Y), test_size=0.3, shuffle=False)
85 | x_train = x_train.reshape(x_train.shape[0],1,-1)
86 | x_vali = x_vali.reshape(x_vali.shape[0],1,-1)
87 | instant_model = bcommon.mgru_attion_model(input_dimensions=X.shape[1], gru_h_size=512, dropout=0.2, lossfunction='categorical_crossentropy',
88 | evaluation_metrics='accuracy', activation_method = 'softmax', output_dimensions=Y.shape[1] )
89 | tbCallBack = TensorBoard(log_dir=f'{cfg.TEMPDIR}model/task3', histogram_freq=1,write_grads=True)
90 | checkpoint = ModelCheckpoint(filepath=model_file, monitor='val_accuracy', mode='auto', save_best_only=True)
91 | instant_model.fit(x_train, y_train, validation_data=(x_vali, y_vali), batch_size=3948, epochs= epochs, callbacks=[tbCallBack, checkpoint])
92 | # save
93 |
94 | print(f'train EC model finished, best model saved to: {model_file}')
95 |
96 |
97 |
98 | if __name__ =="__main__":
99 |
100 | EMBEDDING_METHOD = 'esm32'
101 |
102 | # 1. read training data
103 | print('step 1 loading task data')
104 |
105 | # please specify the training data for each task
106 | data_task1_train = pd.read_feather(cfg.FILE_TASK1_TRAIN)
107 | data_task2_train = pd.read_feather(cfg.FILE_TASK2_TRAIN)
108 | data_task3_train = pd.read_feather(cfg.FILE_TASK3_TRAIN)
109 |
110 | # 2. read training features
111 | print(f'step 2: Loading features, embdding method={EMBEDDING_METHOD}')
112 | feature_df = bcommon.load_data_embedding(embedding_type=EMBEDDING_METHOD)
113 |
114 | #3. train task-1 model
115 | print('step 3: train isEnzyme model')
116 | task1_X, task1_Y = get_train_X_Y(traindata=data_task1_train, feature_bankfile=feature_df, task=1)
117 | train_isenzyme(X=task1_X, Y=task1_Y, model_file= cfg.ISENZYME_MODEL, force_model_update=cfg.UPDATE_MODEL, epochs=400)
118 |
119 | #4. task2 train
120 | print('step 4: train how many enzymes model')
121 | task2_X, task2_Y = get_train_X_Y(traindata=data_task2_train, feature_bankfile=feature_df, task=2)
122 | train_howmany_enzyme(X=task2_X, Y=task2_Y, model_file=cfg.HOWMANY_MODEL, force_model_update=cfg.UPDATE_MODEL, epochs=400)
123 |
124 | #5. task3 train
125 | print('step 5 train EC model')
126 | task3_X, task3_Y = get_train_X_Y(traindata=data_task3_train, feature_bankfile=feature_df, task=3)
127 | train_ec(X=task3_X, Y=task3_Y, model_file=cfg.EC_MODEL, force_model_update=cfg.UPDATE_MODEL, epochs=400)
128 |
129 |
130 | print('train finished')
131 |
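132 | # To inspect the training curves logged by the TensorBoard callbacks above (a sketch;
133 | # the log root follows cfg.TEMPDIR as used in the callbacks):
134 | # tensorboard --logdir <cfg.TEMPDIR>/model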
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Zhenkun Shi
3 | Date: 2020-06-05 05:10:25
4 | LastEditors: Zhenkun Shi
5 | LastEditTime: 2023-08-25 14:32:54
6 | FilePath: /ECRECer/config.py
7 | Description:
8 |
9 | Copyright (c) 2022 by tibd, All Rights Reserved.
10 | '''
11 |
12 | import os
13 |
14 |
15 | # 1. data directories
16 | ROOTDIR= f'/hpcfs/fhome/shizhenkun/codebase/ECRECer/' #change to your own absolute data dir
17 | DATADIR = ROOTDIR +'data/'
18 | RESULTSDIR = ROOTDIR +'results/'
19 | MODELDIR = ROOTDIR +'model'
20 | TEMPDIR =ROOTDIR +'tmp/'
21 | DIR_UNIPROT = DATADIR + 'uniprot/'
22 | DIR_DATASETS = DATADIR +'datasets/'
23 | DIR_FEATURES = DATADIR + 'featureBank/'
24 | DIR_DICT = DATADIR +'dict/'
25 |
26 |
27 | #2.URL
28 | URL_SPROT_SNAP201802 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_releases/release-2018_02/knowledgebase/uniprot_sprot-only2018_02.tar.gz'
29 | URL_SPROT_SNAP201902 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_major_releases/release-2019_02/knowledgebase/uniprot_sprot-only2019_02.tar.gz'
30 | URL_SPROT_SNAP202006 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_releases/release-2020_06/knowledgebase/uniprot_sprot-only2020_06.tar.gz'
31 | URL_SPROT_SNAP202102 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_major_releases/release-2021_02/knowledgebase/uniprot_sprot-only2021_02.tar.gz'
32 | URL_SPROT_SNAP202202 = f'https://ftp.uniprot.org/pub/databases/uniprot/previous_releases/release-2022_02/knowledgebase/uniprot_sprot-only2022_02.tar.gz'
33 | URL_SPROT_LATEST=f'https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.dat.gz'
34 |
35 | #3.FILES
36 | FILE_SPROT_SNAP201802=DIR_UNIPROT+'uniprot_sprot-only2018_02.tar.gz'
37 | FILE_SPROT_SNAP201902=DIR_UNIPROT+'uniprot_sprot-only2019_02.tar.gz'
38 | FILE_SPROT_SNAP202006=DIR_UNIPROT+'uniprot_sprot-only2020_06.tar.gz'
39 | FILE_SPROT_SNAP202102=DIR_UNIPROT+'uniprot_sprot-only2021_02.tar.gz'
40 | FILE_SPROT_SNAP202202=DIR_UNIPROT+'uniprot_sprot-only2022_02.tar.gz'
41 |
42 | FILE_SPROT_LATEST=DIR_UNIPROT+'uniprot_sprot_leatest.dat.gz'
43 |
44 | FILE_FEATURE_UNIREP = DIR_FEATURES + 'embd_unirep.feather'
45 | FILE_FEATURE_ESM0 = DIR_FEATURES + 'embd_esm0.feather'
46 | FILE_FEATURE_ESM32 = DIR_FEATURES + 'embd_esm32.feather'
47 | FILE_FEATURE_ESM33 = DIR_FEATURES + 'embd_esm33.feather'
48 | FILE_FEATURE_ONEHOT = DIR_FEATURES + 'embd_onehot.feather'
49 |
50 |
51 | FILE_TASK1_TRAIN = DIR_DATASETS + 'task1/train.feather'
52 | FILE_TASK1_TEST_2019 = DIR_DATASETS + 'task1/test_2019.feather'
53 | FILE_TASK1_TEST_2020 = DIR_DATASETS + 'task1/test_2020.feather'
54 | FILE_TASK1_TEST_2021 = DIR_DATASETS + 'task1/test_2021.feather'
55 | FILE_TASK1_TEST_2022 = DIR_DATASETS + 'task1/test_2022.feather'
56 |
57 | FILE_TASK2_TRAIN = DIR_DATASETS + 'task2/train.feather'
58 | FILE_TASK2_TEST_2019 = DIR_DATASETS + 'task2/test_2019.feather'
59 | FILE_TASK2_TEST_2020 = DIR_DATASETS + 'task2/test_2020.feather'
60 | FILE_TASK2_TEST_2021 = DIR_DATASETS + 'task2/test_2021.feather'
61 | FILE_TASK2_TEST_2022 = DIR_DATASETS + 'task2/test_2022.feather'
62 |
63 |
64 | FILE_TASK3_TRAIN = DIR_DATASETS + 'task3/train.feather'
65 | FILE_TASK3_TEST_2019 = DIR_DATASETS + 'task3/test_2019.feather'
66 | FILE_TASK3_TEST_2020 = DIR_DATASETS + 'task3/test_2020.feather'
67 | FILE_TASK3_TEST_2021 = DIR_DATASETS + 'task3/test_2021.feather'
68 | FILE_TASK3_TEST_2022 = DIR_DATASETS + 'task3/test_2022.feather'
69 |
70 | FILE_TASK1_TRAIN_FASTA = DIR_DATASETS +'task1/train.fasta'
71 | FILE_TASK1_TEST_2019_FASTA = DIR_DATASETS +'task1/test_2019.fasta'
72 | FILE_TASK1_TEST_2020_FASTA = DIR_DATASETS +'task1/test_2020.fasta'
73 | FILE_TASK1_TEST_2021_FASTA = DIR_DATASETS +'task1/test_2021.fasta'
74 | FILE_TASK1_TEST_2022_FASTA = DIR_DATASETS +'task1/test_2022.fasta'
75 |
76 | FILE_TASK2_TRAIN_FASTA = DIR_DATASETS +'task2/train.fasta'
77 | FILE_TASK2_TEST_2019_FASTA = DIR_DATASETS +'task2/test_2019.fasta'
78 | FILE_TASK2_TEST_2020_FASTA = DIR_DATASETS +'task2/test_2020.fasta'
79 | FILE_TASK2_TEST_2021_FASTA = DIR_DATASETS +'task2/test_2021.fasta'
80 | FILE_TASK2_TEST_2022_FASTA = DIR_DATASETS +'task2/test_2022.fasta'
81 |
82 | FILE_TASK3_TRAIN_FASTA = DIR_DATASETS +'task3/train.fasta'
83 | FILE_TASK3_TEST_2019_FASTA = DIR_DATASETS +'task3/test_2019.fasta'
84 | FILE_TASK3_TEST_2020_FASTA = DIR_DATASETS +'task3/test_2020.fasta'
85 | FILE_TASK3_TEST_2021_FASTA = DIR_DATASETS +'task3/test_2021.fasta'
86 | FILE_TASK3_TEST_2022_FASTA = DIR_DATASETS +'task3/test_2022.fasta'
87 |
88 |
89 | FILE_CASE_SHEWANELLA_FASTA=DATADIR+'shewanella.faa'
90 |
91 |
92 | TRAIN_FEATURE = DATADIR+'train.feather'
93 | TEST_FEATURE = DATADIR+'test.feather'
94 | TRAIN_FASTA = DATADIR+'train.fasta'
95 | TEST_FASTA = DATADIR+'test.fasta'
96 |
97 | FILE_LATEST_SPROT = DATADIR + 'uniprot_sprot_latest.dat.gz'
98 | FILE_LATEST_TREMBL = DATADIR + 'uniprot_trembl_latest.dat.gz'
99 |
100 | FILE_LATEST_SPROT_FEATHER = DATADIR + 'uniprot/sprot_latest.feather'
101 | FILE_LATEST_TREMBL_FEATHER = DATADIR + 'uniprot/trembl_latest.feather'
102 |
103 |
104 | FILE_EC_LABEL_DICT = DATADIR + 'ec_label_dict.npy'
105 | FILE_BLAST_TRAIN_DB = DATADIR + 'train_blast.dmnd' # blast alignment database
106 | FILE_BLAST_PRODUCTION_DB = DATADIR + 'uniprot_blast_db/production_blast.dmnd' # production alignment database
107 | FILE_BLAST_PRODUCTION_FASTA = DATADIR + 'production_blast.fasta' # production alignment fasta
108 | FILE_TRANSFER_DICT = DATADIR + 'ec_transfer_dict.npy'
109 |
110 |
111 |
112 | ISENZYME_MODEL = MODELDIR+'/isenzyme.h5'
113 | HOWMANY_MODEL = MODELDIR+'/howmany_enzyme.h5'
114 | EC_MODEL = MODELDIR+'/ec.h5'
115 |
116 |
117 | FILE_BLAST_RESULTS = RESULTSDIR + r'test_blast_res.tsv'
118 | FILE_BLAST_ISENAYME_RESULTS = RESULTSDIR +r'isEnzyme_blast_results.tsv'
119 | FILE_BLAST_EC_RESULTS = RESULTSDIR +r'ec_blast_results.tsv'
120 |
121 |
122 | FILE_DEEPEC_RESULTS = RESULTSDIR + r'deepec/DeepEC_Result.txt'
123 | FILE_ECPRED_RESULTS = RESULTSDIR + r'ecpred/ecpred.tsv'
124 | FILE_CATFAM_RESULTS = RESULTSDIR + r'catfam_results.output'
125 | FILE_PRIAM_RESULTS = RESULTSDIR + R'priam/PRIAM_20210819134344/ANNOTATION/sequenceECs.txt'
126 |
127 | FILE_EVL_RESULTS = RESULTSDIR + r'evaluation_table.xlsx'
128 |
129 | UPDATE_MODEL = True # force model update flag
130 | EMBEDDING_METHOD={ 'one-hot':1,
131 | 'unirep':2,
132 | 'esm0':3,
133 | 'esm32':4,
134 | 'esm33':5
135 | }
136 |
137 | # blast tabular output column names
138 | BLAST_TABLE_HEAD = ['id',
139 | 'sseqid',
140 | 'pident',
141 | 'length',
142 | 'mismatch',
143 | 'gapopen',
144 | 'qstart',
145 | 'qend',
146 | 'sstart',
147 | 'send',
148 | 'evalue',
149 | 'bitscore'
150 | ]
151 |
152 |
153 | DICT_LABEL_T1 = DIR_DICT+'dict_label_task1.h5'
154 | DICT_LABEL_T2 = DIR_DICT+'dict_label_task2.h5'
155 | DICT_LABEL_T3 = DIR_DICT+'dict_label_task3.h5'
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 |
11 |
12 | # DATA FOLDER
13 |
14 | ## Folder Structure
15 |
16 | rootfolder/data/datasets/
17 | rootfolder/data/featureBank/
18 | rootfolder/data/dict/
19 |
20 | ## Get Data
21 |
22 | Benchmark task datasets can be downloaded from AWS S3.
23 |
24 | ### Put in datasets dir
25 | ```
26 | wget https://tibd-public-datasets.s3.amazonaws.com/ecrecer/ecrecer_datasets.zip
27 | ```
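
After downloading, the archive is assumed to be extracted into the datasets folder listed in the structure above, e.g.:

```
unzip ecrecer_datasets.zip -d rootfolder/data/datasets/
```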
28 |
29 |
30 | ### Put in featureBank dir
31 | ```
32 | wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_unirep.feather
33 | wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_onehot.feather
34 | wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_esm33.feather
35 | wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_esm32.feather
36 | wget https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_esm0.feather
37 | ```
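
The feather files are assumed to live under rootfolder/data/featureBank/; wget's `-P` option can download them into that folder directly, e.g.:

```
wget -P rootfolder/data/featureBank/ https://tibd-public-datasets.s3.us-east-1.amazonaws.com/ecrecer/data/featureBank/embd_esm32.feather
```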
38 |
--------------------------------------------------------------------------------
/data/dict/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/dict/dict_label_task1.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kingstdio/ECRECer/5ccd2c9d163e20faf813d829a868dfa8a0fedc93/data/dict/dict_label_task1.h5
--------------------------------------------------------------------------------
/data/dict/dict_label_task2.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kingstdio/ECRECer/5ccd2c9d163e20faf813d829a868dfa8a0fedc93/data/dict/dict_label_task2.h5
--------------------------------------------------------------------------------
/data/dict/dict_label_task3.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kingstdio/ECRECer/5ccd2c9d163e20faf813d829a868dfa8a0fedc93/data/dict/dict_label_task3.h5
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: ECRECer
2 | channels:
3 | - bjrn
4 | - bioconda
5 | - numba
6 | - rapidsai
7 | - nvidia
8 | - conda-forge
9 | - defaults
10 | dependencies:
11 | - _libgcc_mutex=0.1=main
12 | - _openmp_mutex=4.5=1_gnu
13 | - _py-xgboost-mutex=2.0=cpu_0
14 | - abseil-cpp=20210324.2=h9c3ff4c_0
15 | - anyio=3.3.0=py37h89c1867_0
16 | - argcomplete=1.12.3=pyhd8ed1ab_2
17 | - arrow-cpp=4.0.1=py37haaf8ab1_7_cuda
18 | - arrow-cpp-proc=3.0.0=cuda
19 | - async_generator=1.10=pyhd3eb1b0_0
20 | - attrs=21.2.0=pyhd3eb1b0_0
21 | - babel=2.9.1=pyhd3eb1b0_0
22 | - backcall=0.2.0=pyhd3eb1b0_0
23 | - biopython=1.79=py37h5e8e339_0
24 | - blas=1.0=mkl
25 | - bleach=3.3.0=pyhd3eb1b0_0
26 | - brotlipy=0.7.0=py37h5e8e339_1001
27 | - bzip2=1.0.8=h7b6447c_0
28 | - c-ares=1.17.1=h7f98852_1
29 | - cachetools=4.2.2=pyhd8ed1ab_0
30 | - colorama=0.4.4=pyh9f0ad1d_0
31 | - cryptography=3.4.7=py37h5d9358c_0
32 | - cudatoolkit
33 | - cudf=21.08.02=cuda_11.2_py37_gf6d31fa95d_0
34 | - cupy=9.4.0=py37hf4b2161_0
35 | - cycler=0.10.0=py_2
36 | - daal4py=2021.2.2=py37ha9443f7_0
37 | - dal=2021.2.2=h06a4308_389
38 | - dbus=1.13.18=hb2f20db_0
39 | - debugpy=1.4.1=py37hcd2ae1e_0
40 | - decorator=5.0.9=pyhd3eb1b0_0
41 | - defusedxml=0.7.1=pyhd3eb1b0_0
42 | - dill=0.3.4=pyhd3eb1b0_0
43 | - dlpack=0.5=h9c3ff4c_0
44 | - entrypoints=0.3=pyhd8ed1ab_1003
45 | - et_xmlfile=1.1.0=py37h06a4308_0
46 | - expat=2.4.1=h2531618_2
47 | - fastavro=1.4.4=py37h5e8e339_0
48 | - fastrlock=0.6=py37hcd2ae1e_1
49 | - fontconfig=2.13.1=h6c09931_0
50 | - freetype=2.10.4=h5ab3b9f_0
51 | - fsspec=2021.8.1=pyhd8ed1ab_0
52 | - gawk=5.1.0=h7b6447c_0
53 | - gflags=2.2.2=he1b5a44_1004
54 | - glib=2.68.2=h36276a3_0
55 | - glog=0.5.0=h48cff8f_0
56 | - grpc-cpp=1.39.0=hf1f433d_2
57 | - gst-plugins-base=1.14.0=h8213a91_2
58 | - gstreamer=1.14.0=h28cd5cc_2
59 | - icu=58.2=he6710b0_3
60 | - idna=2.10=pyhd3eb1b0_0
61 | - ipykernel
62 | - ipyparallel=7.0.1=pyhd8ed1ab_0
63 | - ipython=7.27.0=py37h6531663_0
64 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
65 | - jbig=2.1=h7f98852_2003
66 | - jedi=0.18.0=py37h89c1867_2
67 | - joblib=1.0.1=pyhd3eb1b0_0
68 | - jpeg=9d=h36c2ea0_0
69 | - json5=0.9.5=py_0
70 | - jsonschema=3.2.0=py_2
71 | - jupyterlab
72 | - kiwisolver=1.3.1=py37h2527ec5_1
73 | - krb5=1.17.1=h173b8e3_0
74 | - lcms2=2.12=h3be6417_0
75 | - ld_impl_linux-64=2.35.1=h7274673_9
76 | - lz4-c=1.9.3=h2531618_0
77 | - markupsafe=2.0.1=py37h5e8e339_0
78 | - mistune=0.8.4=py37h5e8e339_1004
79 | - mmseqs2=13.45111=h2d02072_0
80 | - mpi=1.0=mpich
81 | - mpich=3.3.2=hc856adb_0
82 | - mypy=0.910=pyhd3eb1b0_0
83 | - mypy_extensions=0.4.3=py37_0
84 | - ncurses=6.2=he6710b0_1
85 | - nest-asyncio=1.5.1=pyhd3eb1b0_0
86 | - nodejs=6.11.2=h3db8ef7_0
87 | - notebook=6.0.3=py37hc8dfbb8_1
88 | - numba=0.54.0=np1.11py3.7hc13618b_g5888895d3_0
89 | - nvtx=0.2.3=py37h5e8e339_0
90 | - olefile=0.46=py_0
91 | - openpyxl=3.0.9=pyhd3eb1b0_0
92 | - openssl=1.1.1l=h7f8727e_0
93 | - orc=1.6.9=h58a87f1_0
94 | - packaging=20.9=pyhd3eb1b0_0
95 | - pandarallel=1.4.8=py37h39e3cac_0
96 | - pandas=1.2.5=py37h219a48f_0
97 | - pandoc=2.12=h06a4308_0
98 | - pandocfilters=1.4.2=py_1
99 | - parquet-cpp=1.5.1=2
100 | - parso=0.8.2=pyhd3eb1b0_0
101 | - pcre=8.44=he6710b0_0
102 | - pexpect=4.8.0=pyhd3eb1b0_3
103 | - pickleshare=0.7.5=pyhd3eb1b0_1003
104 | - pip=21.2.4=pyhd8ed1ab_0
105 | - prometheus_client=0.11.0=pyhd3eb1b0_0
106 | - prompt-toolkit=3.0.17=pyh06a4308_0
107 | - protobuf=3.16.0=py37hcd2ae1e_0
108 | - psutil=5.8.0=py37h5e8e339_1
109 | - psycopg2=2.8.5=py37hb09aad4_1
110 | - ptyprocess=0.7.0=pyhd3eb1b0_2
111 | - py-xgboost=1.3.3=py37h06a4308_0
112 | - pyarrow=4.0.1=py37h3373063_7_cuda
113 | - pycparser=2.20=py_2
114 | - pygments=2.9.0=pyhd3eb1b0_0
115 | - pyopenssl=20.0.1=pyhd3eb1b0_1
116 | - pyparsing=2.4.7=pyhd3eb1b0_0
117 | - pyrsistent=0.17.3=py37h5e8e339_2
118 | - pysocks=1.7.1=py37h89c1867_3
119 | - python=3.7.10=hffdb5ce_100_cpython
120 | - python-dateutil=2.8.1=pyhd3eb1b0_0
121 | - python_abi=3.7=2_cp37m
122 | - pytz=2021.1=pyhd3eb1b0_0
123 | - pyyaml=5.4.1=py37h5e8e339_0
124 | - pyzmq=22.1.0=py37h336d617_0
125 | - qt=5.9.7=h5867ecd_1
126 | - re2=2021.08.01=h9c3ff4c_0
127 | - readline=8.1=h27cfd23_0
128 | - requests=2.25.1=pyhd3eb1b0_0
129 | - scikit-learn=0.24.2=py37ha9443f7_0
130 | - scikit-learn-intelex=2021.2.2=py37h89c1867_1
131 | - scipy=1.6.2=py37had2a1c9_1
132 | - send2trash=1.5.0=pyhd3eb1b0_1
133 | - setuptools=58.0.4=py37h89c1867_0
134 | - snappy=1.1.8=he1b5a44_3
135 | - sniffio=1.2.0=py37h89c1867_1
136 | - spdlog=1.8.5=h4bd325d_0
137 | - sqlalchemy=1.4.0=py37h5e8e339_0
138 | - sqlite=3.35.4=hdfb4753_0
139 | - tbb=2021.2.0=hff7bd54_0
140 | - terminado=0.12.1=py37h89c1867_0
141 | - testpath=0.4.4=pyhd3eb1b0_0
142 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
143 | - tk=8.6.10=hbc83047_0
144 | - toml=0.10.2=pyhd3eb1b0_0
145 | - tornado=6.1=py37h5e8e339_1
146 | - tqdm=4.62.2=pyhd8ed1ab_0
147 | - traitlets=5.0.5=pyhd3eb1b0_0
148 | - urllib3=1.26.4=pyhd3eb1b0_0
149 | - wcwidth=0.2.5=py_0
150 | - webencodings=0.5.1=py_1
151 | - wget=1.20.1=h20c2e04_0
152 | - wheel=0.36.2=pyhd3eb1b0_0
153 | - xgboost=1.3.3=py37h06a4308_0
154 | - xz=5.2.5=h7b6447c_0
155 | - yaml=0.2.5=h516909a_0
156 | - pip:
157 | - absl-py==0.14.1
158 | - alembic==1.7.4
159 | - astunparse==1.6.3
160 | - autopage==0.4.0
161 | - cached-property==1.5.2
162 | - clang==5.0
163 | - cliff==3.9.0
164 | - cmaes==0.8.2
165 | - fair-esm==0.4.0
166 | - gast==0.4.0
167 | - h5py==3.1.0
168 | - keras==2.6.0
169 | - numpy==1.19.5
170 | - pillow==8.3.2
171 | - prettytable==2.2.1
172 | - pyasn1==0.4.8
173 | - pyasn1-modules==0.2.8
174 | - pybind11==2.6.1
175 | - pyperclip==1.8.2
176 | - requests-oauthlib==1.3.0
177 | - rsa==4.7.2
178 | - six==1.15.0
179 | - stevedore==3.4.0
180 | - tensorboard==2.7.0
181 | - tensorboard-data-server==0.6.1
182 | - tensorboard-plugin-wit==1.8.0
183 | - tensorflow==2.6.1
184 | - tensorflow-estimator==2.6.0
185 | - termcolor==1.1.0
186 | - torch==1.9.1+cu111
187 | - typing-extensions==3.7.4.3
188 | - werkzeug==2.0.2
189 | - wrapt==1.12.1
190 | prefix: /hpcfs/fhome/shizhenkun/miniconda3/envs/ECRECer
191 |
--------------------------------------------------------------------------------
/model/README.md:
--------------------------------------------------------------------------------
1 |
11 |
12 | # MODEL FOLDER
13 |
14 | ## Folder Structure
15 |
16 | rootfolder/model
17 |
18 | ## Get Data
19 |
20 | Trained models can be downloaded from AWS S3.
21 |
22 | ```
23 | wget https://tibd-public-datasets.s3.amazonaws.com/ecrecer/model/ec.h5
24 | wget https://tibd-public-datasets.s3.amazonaws.com/ecrecer/model/howmany_enzyme.h5
25 | wget https://tibd-public-datasets.s3.amazonaws.com/ecrecer/model/isenzyme.h5
26 | ```
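
A minimal loading check, mirroring how production.py opens these models (the custom Attention layer from tools/Attention.py must be registered; the relative path assumes the files sit under rootfolder/model/):

```
from keras.models import load_model
from tools.Attention import Attention

# compile=False: the models are only used for inference
model = load_model('model/ec.h5', custom_objects={"Attention": Attention}, compile=False)
model.summary()
```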
--------------------------------------------------------------------------------
/production.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import joblib
4 | import os,sys
5 | import benchmark_common as bcommon
6 | import config as cfg
7 | import argparse
8 | import tools.funclib as funclib
9 | from tools.Attention import Attention
10 | from keras.models import load_model
11 | import tools.embedding_esm as esmebd
12 | import time
13 | from pandarallel import pandarallel # import pandarallel
14 |
15 |
16 | #region Integrate output
17 | def integrate_out_put(existing_table, blast_table, isEnzyme_pred_table, how_many_table, ec_table, mode='p', topnum=1):
18 | """[Integrate output]
19 |
20 | Args:
21 | existing_table ([DataFrame]): [db search results table]
22 | blast_table ([DataFrame]): [sequence alignment results table]
23 | isEnzyme_pred_table (DataFrame): [isEnzyme prediction results table]
24 | how_many_table ([DataFrame]): [function counts prediction results table]
25 | ec_table ([DataFrame]): [ec prediction table]
26 |
27 | Returns:
28 | [DataFrame]: [final results]
29 | """
30 | existing_table['res_type'] = 'db_match'
31 | blast_table['res_type']='blast_match'
32 | results_df = ec_table.merge(blast_table, on='id', how='left')
33 |
34 | function_df = how_many_table.copy()
35 | function_df = function_df.merge(isEnzyme_pred_table, on='id', how='left')
36 | function_df = function_df.merge(blast_table[['id', 'ec_number']], on='id', how='left')
37 | function_df['pred_function_counts']=function_df.parallel_apply(lambda x :integrate_enzyme_functioncounts(x.ec_number, x.isEnzyme_pred, x.pred_s, x.pred_m), axis=1)
38 | results_df = results_df.merge(function_df[['id','pred_function_counts']],on='id',how='left')
39 |
40 | results_df.loc[results_df[results_df.res_type.isnull()].index,'res_type']='dmlf_pred'
41 | results_df['pred_ec']=results_df.parallel_apply(lambda x: gather_ec_by_fc(x.iloc[3:23],x.ec_number, x.pred_function_counts), axis=1)
42 | results_df = results_df.iloc[:,np.r_[0,23,1,2,32,27:31]].rename(columns={'seq_x':'seq','seqlength_x':'seqlength'})
43 |
44 |
45 | if mode=='p':
46 | existing_table['pred_ec']=''
47 | result_set = pd.concat([existing_table, results_df], axis=0)
48 | result_set = result_set.drop_duplicates(subset=['id'], keep='first').sort_values(by='res_type')
49 | result_set['ec_number'] = result_set.apply(lambda x: x.pred_ec if str(x.ec_number)=='nan' else x.ec_number, axis=1)
50 | result_set.reset_index(drop=True, inplace=True)
51 | result_set = result_set.iloc[:,0:9]
52 |
53 | result_set['seqlength'] = result_set.seq.apply(lambda x: len(x))
54 | result_set['ec_number'] = result_set.ec_number.apply(lambda x: 'Non-Enzyme' if len(x)==1 else x)
55 | result_set = result_set.rename(columns={'ec_number':'ecrecer_pred_ec_number'})
56 |
57 | result_set = result_set[['id','ecrecer_pred_ec_number','seq','seqlength']]
58 |
59 | if mode =='r':
60 | result_set= results_df.merge(ec_table, on=['id'], how='left')
61 | result_set=result_set.iloc[:,np.r_[0:3,30,5:9, 4,10:30]]
62 | result_set = result_set.rename(columns=dict({'seq_x': 'seq','pred_ec': 'top0','top0_y': 'top1' }, **{'top'+str(i) : 'top'+str(i+1) for i in range(0, 20)}))
63 | # result_set = result_set.iloc[:,0:(8+topnum)]
64 | # result_set.loc[result_set[result_set.id.isin(existing_table.id)].index.values,'res_type']= 'db_match'
65 |
66 | result_set = result_set.iloc[:,np.r_[0, 2:4,8:(8+topnum)]]
67 |
68 | return result_set
69 |
70 | #endregion
71 |
72 | #region Predict Function Counts
73 | def predict_function_counts(test_data):
74 | """[Predict Function Counts]
75 |
76 | Args:
77 | test_data ([DataFrame]): [DF contain protein ID and Seq]
78 |
79 | Returns:
80 | [DataFrame]: [col1:id, col2: single or multi; col3: multi counts]
81 | """
82 | res=pd.DataFrame()
83 | res['id']=test_data.id
84 | model_s = joblib.load(cfg.MODELDIR+'/single_multi.model')
85 | model_m = joblib.load(cfg.MODELDIR+'/multi_many.model')
86 | pred_s=model_s.predict(np.array(test_data.iloc[:,1:]))
87 | pred_m=model_m.predict(np.array(test_data.iloc[:,1:]))
88 | res['pred_s']=1-pred_s
89 | res['pred_m']=pred_m+2
90 |
91 | return res
92 | #endregion
93 |
94 | #region Integrate function counts by blast, single and multi
95 | def integrate_enzyme_functioncounts(blast, isEnzyme, single, multi):
96 | """[Integrate function counts by blast, single and multi]
97 |
98 | Args:
99 |         blast ([string]): [blast ec_number result]
100 |         isEnzyme ([bool]): [isEnzyme prediction]
101 |         single, multi ([int]): [single/multi function-count predictions]
102 |
103 | Returns:
104 | [type]: [description]
105 | """
106 | if str(blast)!='nan':
107 | if str(blast)=='-':
108 | return 0
109 | else:
110 | return len(blast.split(','))
111 | if isEnzyme == 0:
112 | return 0
113 | if single ==1:
114 | return 1
115 | return multi
116 | #endregion
117 |
118 | #region format final ec by function counts
119 | def gather_ec_by_fc(toplist, ec_blast ,counts):
120 |     """[format final ec by function counts]
121 |
122 | Args:
123 | toplist ([list]): [top 20 predicted EC]
124 | ec_blast ([string]): [blast results]
125 | counts ([int]): [function counts]
126 |
127 | Returns:
128 |         [string]: [comma-separated EC string]
129 | """
130 | if counts==0:
131 | return '-'
132 | elif str(ec_blast)!='nan':
133 | return str(ec_blast)
134 | else:
135 | return ','.join(toplist[0:counts])
136 | #endregion
137 |
138 |
139 |
140 |
141 | #region run
142 | def step_by_step_run(input_fasta, output_tsv, mode='p', topnum=1):
143 | """[run]
144 | Args:
145 | input_fasta ([string]): [input fasta file]
146 | output_tsv ([string]): [output tsv file]
147 | """
148 | start = time.process_time()
149 | if mode =='p':
150 |         print('run in annotation mode')
151 | if mode =='r':
152 | print('run in recommendation mode')
153 | if mode =='h':
154 | print('run in hybrid mode')
155 |
156 |     # 1. Load input data
157 | print('step 1: loading data')
158 | input_df = funclib.load_fasta_to_table(input_fasta) # test fasta
159 | latest_sprot = pd.read_feather(cfg.FILE_LATEST_SPROT_FEATHER) #sprot db
160 |
161 |
162 |     # 2. Look up sequences that already exist in the database
163 |     print('step 2: find existing data')
164 |     find_data = latest_sprot[latest_sprot.seq.isin(input_df.seq)]
166 | find_data = find_data.drop_duplicates(subset='seq').reset_index(drop=True)
167 | exist_data = find_data.merge(input_df, on='seq', how='left').iloc[:,np.r_[8,0,1:8]].rename(columns={'id_x':'uniprot_id','id_y':'input_id'}).reset_index(drop=True)
168 | noExist_data = input_df[~input_df.seq.isin(find_data.seq)]
169 |
170 | if len(noExist_data) == 0 and mode=='p':
171 | exist_data=exist_data[['input_id','ec_number']].rename(columns={'ec_number':'ec_pred'})
172 | exist_data.to_csv(output_tsv, sep='\t', index=False)
173 | end = time.process_time()
174 | print('All done running time: %s Seconds'%(end-start))
175 | return
176 |
177 |
178 |     # 3. Embedding
179 | print('step 3: Embedding')
180 | featurebank_esm32 = pd.read_feather(cfg.FILE_FEATURE_ESM32)
181 |
182 | existing_feature = featurebank_esm32[featurebank_esm32.id.isin(exist_data.uniprot_id)]
183 | existing_feature = exist_data[['input_id','uniprot_id']].merge(existing_feature.rename(columns={'id':'uniprot_id'}), on='uniprot_id', how='left').rename(columns={'input_id':'id'}).iloc[:,np.r_[0,2:existing_feature.shape[1]+1]]
184 |
185 | rep0, rep32, rep33 = esmebd.get_rep_multi_sequence(sequences=noExist_data, model='esm1b_t33_650M_UR50S',seqthres=1022)
186 |
187 | rep32 = pd.concat([existing_feature,rep32],axis=0).reset_index(drop=True)
188 |
189 | print('step 4: run prediction')
190 |
191 | if mode=='p':
192 |
193 | # 5. isEnzyme Prediction
194 | print('step 5: predict isEnzyme')
195 | pred_dmlf = pd.DataFrame(rep32.id.copy())
196 | model_isEnzyme = load_model(cfg.ISENZYME_MODEL,custom_objects={"Attention": Attention}, compile=False)
197 | predicted = model_isEnzyme.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
198 | encoder_t1=joblib.load(cfg.DICT_LABEL_T1)
199 |
200 | pred_dmlf['dmlf_isEnzyme']=(encoder_t1.inverse_transform(bcommon.props_to_onehot(predicted))).reshape(1,-1)[0]
201 |
202 |
203 | # 6. How many Prediction
204 | print('step 6: predict function counts')
205 | model_howmany = load_model(cfg.HOWMANY_MODEL,custom_objects={"Attention": Attention}, compile=False)
206 | predicted = model_howmany.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
207 | encoder_t2=joblib.load(cfg.DICT_LABEL_T2)
208 | pred_dmlf['dmlf_howmany']=(encoder_t2.inverse_transform(bcommon.props_to_onehot(predicted))).reshape(1,-1)[0]
209 |
210 |
211 | # 7. EC Prediction
212 | print('step 7: predict EC')
213 | model_ec = load_model(cfg.EC_MODEL,custom_objects={"Attention": Attention}, compile=False)
214 | predicted = model_ec.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
215 | encoder_t3=joblib.load(cfg.DICT_LABEL_T3)
216 | pred_dmlf['dmlf_ec']=[','.join(item) for item in (encoder_t3.inverse_transform(bcommon.props_to_onehot(predicted)))]
217 |
218 |
219 | print('step 8: integrate results')
220 | results = pred_dmlf.merge(exist_data, left_on='id',right_on='input_id', how='left')
221 | results=results.fillna('#')
222 | results['ec_pred'] =results.apply(lambda x : x.ec_number if x.ec_number!='#' else ('-' if x.dmlf_isEnzyme==False else x.dmlf_ec) ,axis=1)
223 | output_df = results[['id', 'ec_pred']].rename(columns={'id':'id_input'})
224 |
225 | elif mode =='r':
226 | # print('step 4: recommendation')
227 | # label_model_ec = pd.read_feather(f'{cfg.MODELDIR}/task3_labels.feather').label_multi.to_list()
228 | # model_ec = load_model(f'{cfg.MODELDIR}/task3_esm32_2022.h5',custom_objects={"Attention": Attention}, compile=False)
229 | # predicted = model_ec.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
230 | # output_df=pd.DataFrame()
231 | # output_df['id']=input_df['id'].copy()
232 | # output_df['ec_recomendations']=pd.DataFrame(predicted).apply(lambda x :sorted(dict(zip((label_model_ec), x)).items(),key = lambda x:x[1], reverse = True)[0:topnum], axis=1 ).values
233 |
234 | print('step 4: predict EC')
235 | pred_dmlf = pd.DataFrame(rep32.id.copy())
236 | model_ec = load_model(cfg.EC_MODEL, custom_objects={"Attention": Attention}, compile=False)
237 | predicted = model_ec.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
238 | encoder_t3=joblib.load(cfg.DICT_LABEL_T3)
239 | pred_dmlf['dmlf_ec']=[','.join(item) for item in (encoder_t3.inverse_transform(bcommon.props_to_onehot(predicted)))]
240 |         pred_dmlf['dmlf_recommendations']=pd.DataFrame(predicted).apply(lambda x :sorted(dict(zip((encoder_t3.classes_), x)).items(),key = lambda x:x[1], reverse = True)[0:topnum], axis=1 ).values
241 |         output_df = pred_dmlf[['id', 'dmlf_recommendations']].rename(columns={'id':'id_input'})
242 |
243 | elif mode =='h':
244 |         print('running in hybrid mode')
245 |
246 | # 4. sequence alignment
247 | print('step 4: sequence alignment')
248 | if not os.path.exists(cfg.FILE_BLAST_PRODUCTION_DB):
249 | funclib.table2fasta(latest_sprot, cfg.FILE_BLAST_PRODUCTION_FASTA)
250 | cmd = r'diamond makedb --in {0} -d {1}'.format(cfg.FILE_BLAST_PRODUCTION_FASTA, cfg.FILE_BLAST_PRODUCTION_DB)
251 | os.system(cmd)
252 | blast_res = funclib.getblast_usedb(db=cfg.FILE_BLAST_PRODUCTION_DB, test=input_df)
253 | blast_res =blast_res[['id', 'sseqid']].merge(latest_sprot, left_on='sseqid', right_on='id', how='left').iloc[:,np.r_[0,2,3:10]].rename(columns={'id_x':'input_id','id_y':'uniprot_id'}).reset_index(drop=True)
254 |
255 | # 5. isEnzyme Prediction
256 | print('step 5: predict isEnzyme')
257 | pred_dmlf = pd.DataFrame(rep32.id.copy())
258 | model_isEnzyme = load_model(cfg.ISENZYME_MODEL,custom_objects={"Attention": Attention}, compile=False)
259 | predicted = model_isEnzyme.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
260 | encoder_t1=joblib.load(cfg.DICT_LABEL_T1)
261 | pred_dmlf['dmlf_isEnzyme']=(encoder_t1.inverse_transform(bcommon.props_to_onehot(predicted))).reshape(1,-1)[0]
262 |
263 |
264 | # 6. How many Prediction
265 | print('step 6: predict function counts')
266 | model_howmany = load_model(cfg.HOWMANY_MODEL,custom_objects={"Attention": Attention}, compile=False)
267 | predicted = model_howmany.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
268 | encoder_t2=joblib.load(cfg.DICT_LABEL_T2)
269 | pred_dmlf['dmlf_functions']=(encoder_t2.inverse_transform(bcommon.props_to_onehot(predicted))).reshape(1,-1)[0]
270 |
271 |
272 | # 7. EC Prediction
273 | print('step 7: predict EC')
274 | model_ec = load_model(cfg.EC_MODEL,custom_objects={"Attention": Attention}, compile=False)
275 | predicted = model_ec.predict(np.array(rep32.iloc[:,1:]).reshape(rep32.shape[0],1,-1))
276 | encoder_t3=joblib.load(cfg.DICT_LABEL_T3)
277 | pred_dmlf['dmlf_ec']=[','.join(item) for item in (encoder_t3.inverse_transform(bcommon.props_to_onehot(predicted)))]
278 |         pred_dmlf['dmlf_recommendations']=pd.DataFrame(predicted).apply(lambda x :sorted(dict(zip((encoder_t3.classes_), x)).items(),key = lambda x:x[1], reverse = True)[0:topnum], axis=1 ).values
279 |
280 | pred_dmlf = pred_dmlf.merge(blast_res[['input_id','ec_number']].rename(columns={'ec_number':'blast_ec'}), left_on='id', right_on='input_id', how='left')
281 |         # pred_dmlf['dmlf_recommendations']=pred_dmlf.apply(lambda x: x.dmlf_recommendations if x.dmlf_isEnzyme else '-', axis=1 )
282 | pred_dmlf['dmlf_ec']=pred_dmlf.apply(lambda x: x.dmlf_ec if x.dmlf_isEnzyme else '-', axis=1 )
283 | pred_dmlf = pred_dmlf.merge(exist_data[['input_id','ec_number']].rename(columns={'ec_number':'db_ec'}), on='input_id', how='left')
284 | pred_dmlf['dmlf_ec']=pred_dmlf.apply(lambda x: x.db_ec if str(x.db_ec)!='nan' else x.dmlf_ec,axis=1)
285 | pred_dmlf['dmlf_isEnzyme']=pred_dmlf.apply(lambda x: True if (str(x.db_ec)!='nan' and x.db_ec!='-') else x.dmlf_isEnzyme,axis=1)
286 | pred_dmlf['dmlf_functions']=pred_dmlf.apply(lambda x: len(x.db_ec.split(',')) if str(x.db_ec)!='nan' else x.dmlf_functions,axis=1)
287 |
288 |
289 |         output_df = pred_dmlf[['id', 'dmlf_isEnzyme', 'dmlf_functions', 'dmlf_ec', 'dmlf_recommendations', 'blast_ec' ]].rename(columns={'id':'input_id'})
290 |
291 | else:
292 | print(f'mode:{mode} not found')
293 | sys.exit()
294 |
295 |
296 |     print('step 9: writing results')
297 |
298 | output_df.to_csv(output_tsv, sep='\t', index=False)
299 |
300 | print(output_df)
301 |
302 | end = time.process_time()
303 | print('All done running time: %s Seconds'%(end-start))
304 | #endregion
305 |
306 |
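# Example invocations (hypothetical file names; see the argument definitions below):
#   python production.py -i data/sample_10.fasta -o results/pred.tsv -mode p
#   python production.py -i data/sample_10.fasta -o results/rec.tsv -mode r -topk 5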
307 | if __name__ =='__main__':
308 |
309 | parser = argparse.ArgumentParser()
310 | parser.add_argument('-i', help='input file (fasta format)', type=str, default=cfg.DATADIR + 'sample_10.fasta')
311 | parser.add_argument('-o', help='output file (tsv table)', type=str, default=cfg.RESULTSDIR + 'sample_10_2023_07_18.tsv')
312 | parser.add_argument('-mode', help='compute mode. p: prediction, r: recommendation, h:hybrid', type=str, default='r')
313 | parser.add_argument('-topk', help='recommendation records, min=1, max=20', type=int, default='50')
314 |
315 | pandarallel.initialize() #init
316 | args = parser.parse_args()
317 | input_file = args.i
318 | output_file = args.o
319 | compute_mode = args.mode
320 | topk = args.topk
321 |
322 | step_by_step_run( input_fasta=input_file,
323 | output_tsv=output_file,
324 | mode=compute_mode,
325 | topnum=topk
326 | )
327 |
--------------------------------------------------------------------------------
/tasks/README.md:
--------------------------------------------------------------------------------
1 | # Benchmarking tasks
2 |
3 | ## 1. Is enzyme task
4 | ## 2. How many function task
5 | ## 3. EC number prediction task
--------------------------------------------------------------------------------
/tasks/task1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "tags": []
7 | },
8 | "source": [
9 | "# Task1. Enzyme or Non-enzyme Annotation\n",
10 | "\n",
11 | "> author: Shizhenkun \n",
12 | "> email: zhenkun.shi@tib.cas.cn \n",
13 | "> date: 2021-10-20 "
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {},
19 | "source": [
20 | "## 1. Import packages"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 1,
26 | "metadata": {
27 | "tags": []
28 | },
29 | "outputs": [
30 | {
31 | "name": "stdout",
32 | "output_type": "stream",
33 | "text": [
34 | "INFO: Pandarallel will run on 52 workers.\n",
35 | "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n"
36 | ]
37 | }
38 | ],
39 | "source": [
40 | "import numpy as np\n",
41 | "import pandas as pd\n",
42 | "import time\n",
43 | "import datetime\n",
44 | "import sys\n",
45 | "import os\n",
46 | "from tqdm import tqdm\n",
47 | "from functools import reduce\n",
48 | "import joblib\n",
49 | "\n",
50 | "sys.path.append(\"../tools/\")\n",
51 | "import funclib\n",
52 | "\n",
53 | "sys.path.append(\"../\")\n",
54 | "import benchmark_train as btrain\n",
55 | "import benchmark_test as btest\n",
56 | "import config as cfg\n",
57 | "import benchmark_evaluation as eva\n",
58 | "\n",
59 | "from sklearn.model_selection import train_test_split\n",
60 | "from xgboost import XGBClassifier\n",
61 | "\n",
62 | "from pandarallel import pandarallel # import pandarallel\n",
63 | "pandarallel.initialize() # init\n",
64 | "\n",
65 | "%load_ext autoreload\n",
66 | "%autoreload 2"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {},
72 | "source": [
73 | "## 2. Load data"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 2,
79 | "metadata": {},
80 | "outputs": [
81 | {
82 | "name": "stdout",
83 | "output_type": "stream",
84 | "text": [
85 | "train size: 469134\n",
86 | "test size: 7101\n"
87 | ]
88 | }
89 | ],
90 | "source": [
91 | "#read train test data\n",
92 | "train = pd.read_feather(cfg.DATADIR+'task1/train.feather')\n",
93 | "test = pd.read_feather(cfg.DATADIR+'task1/test.feather')\n",
94 | "print('train size: {0}\\ntest size: {1}'.format(len(train), len(test)))"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 5,
100 | "metadata": {},
101 | "outputs": [
102 | {
103 | "data": {
200 | "text/plain": [
201 | " id seq isenzyme\n",
202 | "124 P00958 MSFLISFDKSKKHPAHLQLANNLKIALALEYASKNLKPEVDNDNAA... True\n",
203 | "175 P00812 METGPHYNYYKNRELSIVLAPFSGGQGKLGVEKGPKYMLKHGLQTS... True\n",
204 | "248 P00959 MTQVAKKILVTCALPYANGSIHLGHMLEHIQADVWVRYQRMRGHEV... True\n",
205 | "249 P00348 MAFATRQLVRSLSSSSTAAASAKKILVKHVTVIGGGLMGAGIAQVA... True\n",
206 | "250 P00469 MLEQPYLDLAKKVLDEGHFKPDRTHTGTYSIFGHQMRFDLSKGFPL... True\n",
207 | "... ... ... ...\n",
208 | "469123 Q8I6K2 MSNTAVLNDLVALYDRPTEPMFRVKAKKSFKVPKEYVTDRFKNVAV... True\n",
209 | "469127 O81103 MATAPSPTTMGTYSSLISTNSFSTFLPNKSQLSLSGKSKHYVARRS... True\n",
210 | "469129 Q21221 MSSGAPSGSSMSSTPGSPPPRAGGPNSVSFKDLCCLFCCPPFPSSI... True\n",
211 | "469130 Q6QJ72 MSRLLLPKLFSISRTQVPAASLFNNLYRRHKRFVHWTSKMSTDSVR... True\n",
212 | "469133 D9XDR8 MAKMSTTHEEIALAGPDGIPAVDLRDLIDAQLYMPFPFERNPHASE... True\n",
213 | "\n",
214 | "[222567 rows x 3 columns]"
215 | ]
216 | },
217 | "execution_count": 5,
218 | "metadata": {},
219 | "output_type": "execute_result"
220 | }
221 | ],
222 | "source": [
223 | "train[train.isenzyme]"
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {},
229 | "source": [
230 | "## 3. Sequence alignment"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": 3,
236 | "metadata": {
237 | "tags": []
238 | },
239 | "outputs": [
240 | {
241 | "name": "stdout",
242 | "output_type": "stream",
243 | "text": [
244 | "Write finished\n",
245 | "Write finished\n",
246 | "diamond makedb --in /tmp/train.fasta -d /tmp/train.dmnd\n"
247 | ]
248 | },
249 | {
250 | "name": "stderr",
251 | "output_type": "stream",
252 | "text": [
253 | "diamond v2.0.8.146 (C) Max Planck Society for the Advancement of Science\n",
254 | "Documentation, support and updates available at http://www.diamondsearch.org\n",
255 | "\n",
256 | "#CPU threads: 80\n",
257 | "Scoring parameters: (Matrix=BLOSUM62 Lambda=0.267 K=0.041 Penalties=11/1)\n",
258 | "Database input file: /tmp/train.fasta\n",
259 | "Opening the database file... [0.004s]\n",
260 | "Loading sequences... [0.869s]\n",
261 | "Masking sequences... [0.304s]\n",
262 | "Writing sequences... [0.163s]\n",
263 | "Hashing sequences... [0.05s]\n",
264 | "Loading sequences... [0s]\n",
265 | "Writing trailer... [0.003s]\n",
266 | "Closing the input file... [0.001s]\n",
267 | "Closing the database file... [0.038s]\n",
268 | "Database hash = eed65be4bf3bb33f8407f23b2e861bca\n",
269 | "Processed 469134 sequences, 176795800 letters.\n",
270 | "Total time = 1.436s\n"
271 | ]
272 | },
273 | {
274 | "name": "stdout",
275 | "output_type": "stream",
276 | "text": [
277 | "diamond blastp -d /tmp/train.dmnd -q /tmp/test.fasta -o /tmp/test_fasta_results.tsv -b5 -c1 -k 1 --quiet\n",
278 | " alignment finished \n",
279 | " query samples:7101\n",
280 | " results samples: 5111\n"
281 | ]
282 | }
283 | ],
284 | "source": [
285 | "# blast\n",
286 | "res_data=funclib.getblast(train,test)\n",
287 | "print(' alignment finished \\n query samples:{0}\\n results samples: {1}'.format(len(test), len(res_data)))\n",
288 | "\n",
289 | "res_data = res_data[['id', 'sseqid']].merge(train, left_on='sseqid', right_on='id', how='left')[['id_x', 'isenzyme']]\n",
290 | "res_data =res_data.rename(columns={'id_x':'id','isenzyme':'isenzyme_blast'})\n",
291 | "res_data = test[['id','isenzyme']].merge(res_data, on='id', how='left')"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": 4,
297 | "metadata": {},
298 | "outputs": [
299 | {
300 | "name": "stdout",
301 | "output_type": "stream",
302 | "text": [
303 | "baselineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t confusion Matrix\n",
304 | "Blast \t\t0.667934 \t0.966532 \t\t0.884197 \t0.795400 \t0.872655 \t tp: 2628 fp: 91 fn: 277 tn: 2115 up: 399 un: 1591\n"
305 | ]
306 | }
307 | ],
308 | "source": [
309 | "print('baselineName', '\\t', 'accuracy','\\t', 'precision(PPV) \\t NPV \\t\\t', 'recall','\\t', 'f1', '\\t\\t', '\\t confusion Matrix')\n",
310 | "eva.caculateMetrix(groundtruth=res_data.isenzyme, predict=res_data.isenzyme_blast, baselineName='Blast', type='include_unfind')"
311 | ]
312 | },
313 | {
314 | "cell_type": "markdown",
315 | "metadata": {
316 | "tags": []
317 | },
318 | "source": [
319 | "## 4. Embedding Comparison\n",
320 | "### 4.1 one-hot + ML"
321 | ]
322 | },
323 | {
324 | "cell_type": "code",
325 | "execution_count": 4,
326 | "metadata": {},
327 | "outputs": [],
328 | "source": [
329 | "trainset = train.copy()\n",
330 | "testset = test.copy()"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": 7,
336 | "metadata": {},
337 | "outputs": [],
338 | "source": [
339 | "MAX_SEQ_LENGTH = 1500 # maximum sequence length\n",
340 | "trainset.seq = trainset.seq.map(lambda x : x[0:MAX_SEQ_LENGTH].ljust(MAX_SEQ_LENGTH, 'X'))\n",
341 | "testset.seq = testset.seq.map(lambda x : x[0:MAX_SEQ_LENGTH].ljust(MAX_SEQ_LENGTH, 'X'))\n",
342 | "f_train = funclib.dna_onehot(trainset) # one-hot encode the training set\n",
343 | "f_test = funclib.dna_onehot(testset) # one-hot encode the test set"
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": 8,
349 | "metadata": {},
350 | "outputs": [
351 | {
352 | "name": "stdout",
353 | "output_type": "stream",
354 | "text": [
355 | "baselineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n",
356 | "knn \t\t0.611745 \t0.666667 \t\t0.595238 \t0.331114 \t0.442467 \t tp: 1094 fp: 547 fn: 2210 tn: 3250\n",
357 | "lr \t\t0.672159 \t0.630412 \t\t0.718666 \t0.713983 \t0.669600 \t tp: 2359 fp: 1383 fn: 945 tn: 2414\n",
358 | "xg \t\t0.723278 \t0.730942 \t\t0.717991 \t0.641344 \t0.683218 \t tp: 2119 fp: 780 fn: 1185 tn: 3017\n",
359 | "dt \t\t0.617096 \t0.599932 \t\t0.629133 \t0.531477 \t0.563633 \t tp: 1756 fp: 1171 fn: 1548 tn: 2626\n",
360 | "rf \t\t0.715111 \t0.691709 \t\t0.735904 \t0.699455 \t0.695561 \t tp: 2311 fp: 1030 fn: 993 tn: 2767\n",
361 | "gbdt \t\t0.689621 \t0.646510 \t\t0.737974 \t0.734564 \t0.687730 \t tp: 2427 fp: 1327 fn: 877 tn: 2470\n"
362 | ]
363 | }
364 | ],
365 | "source": [
366 | "# compute metrics\n",
367 | "X_train = np.array(f_train.iloc[:,2:])\n",
368 | "X_test = np.array(f_test.iloc[:,2:])\n",
369 | "Y_train = np.array(trainset.isenzyme.astype('int'))\n",
370 | "Y_test = np.array(testset.isenzyme.astype('int'))\n",
371 | "funclib.run_baseline(X_train, Y_train, X_test, Y_test)"
372 | ]
373 | },
374 | {
375 | "cell_type": "markdown",
376 | "metadata": {},
377 | "source": [
378 | "### 4.2 Unirep + ML"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 8,
384 | "metadata": {},
385 | "outputs": [],
386 | "source": [
387 | "train_unirep = pd.read_feather(cfg.DATADIR + 'train_unirep.feather')\n",
388 | "test_unirep = pd.read_feather(cfg.DATADIR + 'test_unirep.feather')"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": 9,
394 | "metadata": {},
395 | "outputs": [],
396 | "source": [
397 | "train_unirep = trainset.merge(train_unirep, on='id', how='left')\n",
398 | "test_unirep = testset.merge(test_unirep, on='id', how='left')"
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": 10,
404 | "metadata": {},
405 | "outputs": [
406 | {
407 | "name": "stdout",
408 | "output_type": "stream",
409 | "text": [
410 | "baselineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n",
411 | "knn \t\t0.851852 \t0.854088 \t\t0.850038 \t0.822034 \t0.837754 \t tp: 2716 fp: 464 fn: 588 tn: 3333\n",
412 | "lr \t\t0.809604 \t0.933008 \t\t0.752218 \t0.636501 \t0.756747 \t tp: 2103 fp: 151 fn: 1201 tn: 3646\n",
413 | "xg \t\t0.861991 \t0.885790 \t\t0.844461 \t0.807506 \t0.844839 \t tp: 2668 fp: 344 fn: 636 tn: 3453\n",
414 | "dt \t\t0.769610 \t0.789986 \t\t0.755740 \t0.687651 \t0.735275 \t tp: 2272 fp: 604 fn: 1032 tn: 3193\n",
415 | "rf \t\t0.841994 \t0.909535 \t\t0.801442 \t0.733354 \t0.811997 \t tp: 2423 fp: 241 fn: 881 tn: 3556\n",
416 | "gbdt \t\t0.785101 \t0.894060 \t\t0.734365 \t0.610472 \t0.725540 \t tp: 2017 fp: 239 fn: 1287 tn: 3558\n"
417 | ]
418 | }
419 | ],
420 | "source": [
421 | "X_train =np.array(train_unirep.iloc[:,3:])\n",
422 | "X_test = np.array(test_unirep.iloc[:,3:])\n",
423 | "\n",
424 | "Y_train = np.array(train_unirep.isenzyme.astype('int')).flatten()\n",
425 | "Y_test = np.array(test_unirep.isenzyme.astype('int')).flatten()\n",
426 | "\n",
427 | "funclib.run_baseline(X_train, Y_train, X_test, Y_test)"
428 | ]
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "metadata": {
433 | "tags": []
434 | },
435 | "source": [
436 | "### 4.3 ESM REP33 + ML"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 9,
442 | "metadata": {},
443 | "outputs": [],
444 | "source": [
445 | "train_esm_33 = pd.read_feather(cfg.DATADIR + 'train_rep33.feather')\n",
446 | "test_esm_33 = pd.read_feather(cfg.DATADIR + 'test_rep33.feather')\n",
447 | "\n",
448 | "train_esm = trainset.merge(train_esm_33, on='id', how='left')\n",
449 | "test_esm = testset.merge(test_esm_33, on='id', how='left')"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": 10,
455 | "metadata": {},
456 | "outputs": [
457 | {
458 | "name": "stdout",
459 | "output_type": "stream",
460 | "text": [
461 | "baselineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n",
462 | "knn \t\t0.927300 \t0.935953 \t\t0.920835 \t0.898296 \t0.916738 \t tp: 3215 fp: 220 fn: 364 tn: 4234\n",
463 | "lr \t\t0.908378 \t0.927005 \t\t0.895196 \t0.862252 \t0.893457 \t tp: 3086 fp: 243 fn: 493 tn: 4211\n",
464 | "xg \t\t0.928047 \t0.952913 \t\t0.910593 \t0.882090 \t0.916135 \t tp: 3157 fp: 156 fn: 422 tn: 4298\n",
465 | "dt \t\t0.833811 \t0.848664 \t\t0.823884 \t0.763062 \t0.803590 \t tp: 2731 fp: 487 fn: 848 tn: 3967\n",
466 | "rf \t\t0.916096 \t0.960965 \t\t0.887136 \t0.846046 \t0.899851 \t tp: 3028 fp: 123 fn: 551 tn: 4331\n",
467 | "gbdt \t\t0.865804 \t0.901703 \t\t0.843089 \t0.784297 \t0.838912 \t tp: 2807 fp: 306 fn: 772 tn: 4148\n"
468 | ]
469 | }
470 | ],
471 | "source": [
472 | "X_train = np.array(train_esm.iloc[:,4:])\n",
473 | "X_test = np.array(test_esm.iloc[:,4:])\n",
474 | "\n",
475 | "Y_train = np.array(train_esm.isenzyme.astype('int')).flatten()\n",
476 | "Y_test = np.array(test_esm.isenzyme.astype('int')).flatten()\n",
477 | "\n",
478 | "funclib.run_baseline(X_train, Y_train, X_test, Y_test)"
479 | ]
480 | },
481 | {
482 | "cell_type": "markdown",
483 | "metadata": {},
484 | "source": [
485 | "### 4.4 ESM REP32 + ML"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": 10,
491 | "metadata": {
492 | "tags": []
493 | },
494 | "outputs": [
495 | {
496 | "name": "stdout",
497 | "output_type": "stream",
498 | "text": [
499 | "baselineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n",
500 | "knn \t\t0.924799 \t0.939125 \t\t0.913352 \t0.896489 \t0.917312 \t tp: 2962 fp: 192 fn: 342 tn: 3605\n",
501 | "lr \t\t0.908604 \t0.927536 \t\t0.893894 \t0.871671 \t0.898736 \t tp: 2880 fp: 225 fn: 424 tn: 3572\n",
502 | "xg \t\t0.921420 \t0.949869 \t\t0.899975 \t0.877421 \t0.912209 \t tp: 2899 fp: 153 fn: 405 tn: 3644\n",
503 | "dt \t\t0.830587 \t0.855981 \t\t0.812530 \t0.764528 \t0.807674 \t tp: 2526 fp: 425 fn: 778 tn: 3372\n",
504 | "rf \t\t0.908604 \t0.961418 \t\t0.872633 \t0.837167 \t0.895001 \t tp: 2766 fp: 111 fn: 538 tn: 3686\n",
505 | "gbdt \t\t0.874947 \t0.905641 \t\t0.852777 \t0.816283 \t0.858644 \t tp: 2697 fp: 281 fn: 607 tn: 3516\n"
506 | ]
507 | }
508 | ],
509 | "source": [
510 | "train_esm_32 = pd.read_feather(cfg.DATADIR + 'train_rep32.feather')\n",
511 | "test_esm_32 = pd.read_feather(cfg.DATADIR + 'test_rep32.feather')\n",
512 | "\n",
513 | "train_esm = trainset.merge(train_esm_32, on='id', how='left')\n",
514 | "test_esm = testset.merge(test_esm_32, on='id', how='left')\n",
515 | "\n",
516 | "X_train = np.array(train_esm.iloc[:,4:])\n",
517 | "X_test = np.array(test_esm.iloc[:,4:])\n",
518 | "\n",
519 | "Y_train = np.array(train_esm.isenzyme.astype('int')).flatten()\n",
520 | "Y_test = np.array(test_esm.isenzyme.astype('int')).flatten()\n",
521 | "\n",
522 | "funclib.run_baseline(X_train, Y_train, X_test, Y_test)"
523 | ]
524 | },
525 | {
526 | "cell_type": "markdown",
527 | "metadata": {},
528 | "source": [
529 | "### 4.5 ESM REP0 + ML"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": 7,
535 | "metadata": {},
536 | "outputs": [
537 | {
538 | "name": "stdout",
539 | "output_type": "stream",
540 | "text": [
541 | "baselineName \t accuracy \t precision(PPV) \t NPV \t\t recall \t f1 \t\t \t\t confusion Matrix\n",
542 | "knn \t\t0.824599 \t0.789179 \t\t0.855641 \t0.827326 \t0.807802 \t tp: 2961 fp: 791 fn: 618 tn: 3663\n",
543 | "lr \t\t0.757251 \t0.721031 \t\t0.787948 \t0.742386 \t0.731553 \t tp: 2657 fp: 1028 fn: 922 tn: 3426\n",
544 | "xg \t\t0.847504 \t0.844555 \t\t0.849686 \t0.806091 \t0.824875 \t tp: 2885 fp: 531 fn: 694 tn: 3923\n",
545 | "dt \t\t0.760612 \t0.739722 \t\t0.776370 \t0.713887 \t0.726575 \t tp: 2555 fp: 899 fn: 1024 tn: 3555\n",
546 | "rf \t\t0.853853 \t0.863623 \t\t0.847017 \t0.797988 \t0.829509 \t tp: 2856 fp: 451 fn: 723 tn: 4003\n",
547 | "gbdt \t\t0.820988 \t0.810020 \t\t0.829258 \t0.781503 \t0.795506 \t tp: 2797 fp: 656 fn: 782 tn: 3798\n"
548 | ]
549 | }
550 | ],
551 | "source": [
552 | "train_esm_0 = pd.read_feather(cfg.DATADIR + 'train_rep0.feather')\n",
553 | "test_esm_0 = pd.read_feather(cfg.DATADIR + 'test_rep0.feather')\n",
554 | "\n",
555 | "train_esm = trainset.merge(train_esm_0, on='id', how='left')\n",
556 | "test_esm = testset.merge(test_esm_0, on='id', how='left')\n",
557 | "\n",
558 | "X_train = np.array(train_esm.iloc[:,4:])\n",
559 | "X_test = np.array(test_esm.iloc[:,4:])\n",
560 | "\n",
561 | "Y_train = np.array(train_esm.isenzyme.astype('int')).flatten()\n",
562 | "Y_test = np.array(test_esm.isenzyme.astype('int')).flatten()\n",
563 | "\n",
564 | "funclib.run_baseline(X_train, Y_train, X_test, Y_test)"
565 | ]
566 | },
567 | {
568 | "cell_type": "markdown",
569 | "metadata": {},
570 | "source": [
571 | "## 5. Ours"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": null,
577 | "metadata": {},
578 | "outputs": [],
579 | "source": [
580 | "# get blast results\n",
581 | "blastres=pd.DataFrame()\n",
582 | "blastres['id']=res_data.id\n",
583 | "blastres['isenzyme_groundtruth']=res_data.isenzyme\n",
584 | "blastres['isenzyme_pred_blast']=res_data.isenzyme_blast"
585 | ]
586 | },
587 | {
588 | "cell_type": "code",
589 | "execution_count": 6,
590 | "metadata": {},
591 | "outputs": [],
592 | "source": [
593 | "#res32\n",
594 | "train_esm_32 = pd.read_feather(cfg.DATADIR + 'train_rep32.feather')\n",
595 | "test_esm_32 = pd.read_feather(cfg.DATADIR + 'test_rep32.feather')\n",
596 | "\n",
597 | "train_esm = trainset.merge(train_esm_32, on='id', how='left')\n",
598 | "test_esm = testset.merge(test_esm_32, on='id', how='left')\n",
599 | "\n",
600 | "X_train = np.array(train_esm.iloc[:,3:])\n",
601 | "X_test = np.array(test_esm.iloc[:,3:])\n",
602 | "\n",
603 | "Y_train = np.array(train_esm.isenzyme.astype('int')).flatten()\n",
604 | "Y_test = np.array(test_esm.isenzyme.astype('int')).flatten()"
605 | ]
606 | },
607 | {
608 | "cell_type": "code",
609 | "execution_count": null,
610 | "metadata": {},
611 | "outputs": [],
612 | "source": [
613 | "# groundtruth, predict, predictprob = funclib.xgmain(X_train, Y_train, X_test, Y_test, type='binary')\n",
614 | "groundtruth, predict, predictprob, model = funclib.knnmain(X_train, Y_train, X_test, Y_test, type='binary')\n",
615 | "blastres['isenzyme_pred_knn'] = predict\n",
616 | "blastres.isenzyme_pred_knn = blastres.isenzyme_pred_knn.astype('bool')\n",
617 | "blastres['isenzyme_pred_slice']=blastres.apply(lambda x: x.isenzyme_pred_knn if str(x.isenzyme_pred_blast)=='nan' else x.isenzyme_pred_blast, axis=1)\n",
618 | "print('baselineName', '\\t', 'accuracy','\\t', 'precision(PPV) \\t NPV \\t\\t', 'recall','\\t', 'f1', '\\t\\t', '\\t confusion Matrix')\n",
619 | "eva.caculateMetrix( groundtruth=blastres.isenzyme_groundtruth, predict=blastres.isenzyme_pred_slice, baselineName='ours', type='binary')"
620 | ]
621 | },
622 | {
623 | "cell_type": "code",
624 | "execution_count": 7,
625 | "metadata": {},
626 | "outputs": [
627 | {
628 | "data": {
629 | "text/plain": [
630 | "['/home/shizhenkun/codebase/DMLF/model/isenzyme.model']"
631 | ]
632 | },
633 | "execution_count": 7,
634 | "metadata": {},
635 | "output_type": "execute_result"
636 | }
637 | ],
638 | "source": [
639 | "groundtruth, predict, predictprob, model = funclib.knnmain(X_train, Y_train, X_test, Y_test, type='binary')\n",
640 | "joblib.dump(model, cfg.ISENZYME_MODEL)"
641 | ]
642 | },
643 | {
644 | "cell_type": "code",
645 | "execution_count": 62,
646 | "metadata": {},
647 | "outputs": [],
648 | "source": [
649 | "# save results to file\n",
650 | "blastres.to_csv(cfg.FILE_SLICE_ISENZYME_RESULTS, sep='\\t', index=None)"
651 | ]
652 | }
653 | ],
654 | "metadata": {
655 | "kernelspec": {
656 | "display_name": "DMLF",
657 | "language": "python",
658 | "name": "python3"
659 | },
660 | "language_info": {
661 | "codemirror_mode": {
662 | "name": "ipython",
663 | "version": 3
664 | },
665 | "file_extension": ".py",
666 | "mimetype": "text/x-python",
667 | "name": "python",
668 | "nbconvert_exporter": "python",
669 | "pygments_lexer": "ipython3",
670 | "version": "3.8.15"
671 | },
672 | "vscode": {
673 | "interpreter": {
674 | "hash": "6b0f740237ba4768c544d9b9677983e49b45ca1230fda464ede0b93eba99c7d2"
675 | }
676 | },
677 | "widgets": {
678 | "application/vnd.jupyter.widget-state+json": {
679 | "state": {},
680 | "version_major": 2,
681 | "version_minor": 0
682 | }
683 | }
684 | },
685 | "nbformat": 4,
686 | "nbformat_minor": 4
687 | }
688 |
--------------------------------------------------------------------------------
/tmp/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tools/Attention.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Zhenkun Shi
3 | Date: 2022-10-09 06:57:50
4 | LastEditors: Zhenkun Shi
5 | LastEditTime: 2022-10-09 06:58:50
6 | FilePath: /DMLF/tools/Attention.py
7 | Description:
8 |
9 | Copyright (c) 2022 by tibd, All Rights Reserved.
10 | '''
11 |
12 | # -*- coding: utf-8 -*-
13 | from keras import backend as K
14 | # from keras.engine.topology import Layer
15 | from keras.layers import Layer, InputSpec
16 | # from visualizer import get_local
17 |
18 | # Attention layer implemented with Keras
19 | class Attention(Layer):
20 | def __init__(self, attention_size, **kwargs):
21 | self.attention_size = attention_size
22 | super(Attention, self).__init__(**kwargs)
23 |
24 | def build(self, input_shape):
25 | # W: (EMBED_SIZE, ATTENTION_SIZE)
26 | # b: (ATTENTION_SIZE, 1)
27 | # u: (ATTENTION_SIZE, 1)
28 | self.W = self.add_weight(name="W_{:s}".format(self.name),
29 | shape=(input_shape[-1], self.attention_size),
30 | initializer="glorot_normal",
31 | trainable=True)
32 | self.b = self.add_weight(name="b_{:s}".format(self.name),
33 | shape=(input_shape[1], 1),
34 | initializer="zeros",
35 | trainable=True)
36 | self.u = self.add_weight(name="u_{:s}".format(self.name),
37 | shape=(self.attention_size, 1),
38 | initializer="glorot_normal",
39 | trainable=True)
40 | super(Attention, self).build(input_shape)
41 |
42 | def call(self, x, mask=None):
43 | # input: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
44 | # et: (BATCH_SIZE, MAX_TIMESTEPS, ATTENTION_SIZE)
45 | et = K.tanh(K.dot(x, self.W) + self.b)
46 | # at: (BATCH_SIZE, MAX_TIMESTEPS)
47 | at = K.softmax(K.squeeze(K.dot(et, self.u), axis=-1))
48 | if mask is not None:
49 | at *= K.cast(mask, K.floatx())
50 | # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
51 | atx = K.expand_dims(at, axis=-1)
52 | ot = atx * x
53 | # output: (BATCH_SIZE, EMBED_SIZE)
54 | output = K.sum(ot, axis=1)
55 | return output
56 |
57 | def compute_mask(self, input, input_mask=None):
58 | return None
59 |
60 | def compute_output_shape(self, input_shape):
61 | return (input_shape[0], input_shape[-1])
62 |
63 |     # required so the layer can be serialized when saving/loading the model
64 | def get_config(self):
65 | config = {"attention_size": self.attention_size}
66 | base_config = super(Attention, self).get_config()
67 | return dict(list(base_config.items()) + list(config.items()))
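
# Usage sketch (hypothetical model; shapes follow production.py, which reshapes
# features to (n_samples, 1, feature_dim) before prediction):
#   from keras import layers, models
#   inp = layers.Input(shape=(1, 1280))
#   x = Attention(attention_size=32)(inp)            # -> (batch, 1280)
#   out = layers.Dense(1, activation='sigmoid')(x)
#   model = models.Model(inp, out)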
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kingstdio/ECRECer/5ccd2c9d163e20faf813d829a868dfa8a0fedc93/tools/__init__.py
--------------------------------------------------------------------------------
/tools/baselines.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "9307705e-50b0-4bae-b67c-ae8d62eb1737",
6 | "metadata": {},
7 | "source": [
8 | "# Task1. Baseline Methods\n",
9 | "\n",
10 | "> author: Shizhenkun \n",
11 | "> email: zhenkun.shi@tib.cas.cn \n",
12 | "> date: 2021-10-20 "
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "id": "cb15383d-9838-48f5-9cf9-b369418a7e6b",
18 | "metadata": {},
19 | "source": [
20 | "## 1. Import packages"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 3,
26 | "id": "71bdaefa-a752-4ee3-851a-aed07d6b3363",
27 | "metadata": {},
28 | "outputs": [
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "INFO: Pandarallel will run on 80 workers.\n",
34 | "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n",
35 | "The autoreload extension is already loaded. To reload it, use:\n",
36 | " %reload_ext autoreload\n"
37 | ]
38 | }
39 | ],
40 | "source": [
41 | "import numpy as np\n",
42 | "import pandas as pd\n",
43 | "import time\n",
44 | "import datetime\n",
45 | "import sys\n",
46 | "import os\n",
47 | "from tqdm import tqdm\n",
48 | "from functools import reduce\n",
49 | "\n",
50 | "sys.path.append(\"../tools/\")\n",
51 | "import funclib\n",
52 | "\n",
53 | "sys.path.append(\"../\")\n",
54 | "import benchmark_train as btrain\n",
55 | "import benchmark_test as btest\n",
56 | "import config as cfg\n",
57 | "import benchmark_evaluation as eva\n",
58 | "\n",
59 | "from sklearn.model_selection import train_test_split\n",
60 | "from xgboost import XGBClassifier\n",
61 | "\n",
62 | "from pandarallel import pandarallel # import pandarallel\n",
63 | "pandarallel.initialize() # init\n",
64 | "\n",
65 | "%load_ext autoreload\n",
66 | "%autoreload 2"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "id": "04b3f7db-f81d-481e-b0c5-278c60fe87f0",
72 | "metadata": {},
73 | "source": [
74 | "## 2. Load Data"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 2,
80 | "id": "8482e6a8-6523-4ed6-8784-d245538cf0b1",
81 | "metadata": {},
82 | "outputs": [
83 | {
84 | "name": "stdout",
85 | "output_type": "stream",
86 | "text": [
87 | "test size: 7101\n"
88 | ]
89 | }
90 | ],
91 | "source": [
92 | "test = pd.read_feather(cfg.DATADIR+'task1/test.feather')\n",
93 | "print('test size: {0}'.format(len(test)))"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "id": "5625d510-2da8-47f7-b082-7bd39ce5137a",
99 | "metadata": {},
100 | "source": [
101 | "## 3. Make FASTA"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 9,
107 | "id": "4b090ea7-77b4-40a3-ba7c-78fa1dfb6ec3",
108 | "metadata": {},
109 | "outputs": [
110 | {
111 | "name": "stdout",
112 | "output_type": "stream",
113 | "text": [
114 | "Write finished\n"
115 | ]
116 | }
117 | ],
118 | "source": [
119 | "funclib.table2fasta(table=test, file_out='../data/test.fasta')"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "id": "da7733ac-5080-4879-9d05-56cd657cfbd8",
125 | "metadata": {},
126 | "source": [
127 | "## 4. ECPred\n",
128 | "\n",
129 | "Please be patient; this method takes a long time to run"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 3,
135 | "id": "058a7ff1-45ee-4db8-871d-e25b3d81c6d7",
136 | "metadata": {
137 | "collapsed": true,
138 | "jupyter": {
139 | "outputs_hidden": true
140 | },
141 | "tags": []
142 | },
143 | "outputs": [
144 | {
145 | "name": "stdout",
146 | "output_type": "stream",
147 | "text": [
148 | "Hit:1 http://cn.archive.ubuntu.com/ubuntu focal InRelease\n",
149 | "Hit:2 http://cn.archive.ubuntu.com/ubuntu focal-updates InRelease\n",
150 | "Hit:3 http://cn.archive.ubuntu.com/ubuntu focal-backports InRelease\n",
151 | "Hit:4 http://cn.archive.ubuntu.com/ubuntu focal-security InRelease\n",
152 | "Reading package lists... Done\n",
153 | "Reading package lists... Done\n",
154 | "Building dependency tree \n",
155 | "Reading state information... Done\n",
156 | "default-jre is already the newest version (2:1.11-72).\n",
157 | "0 upgraded, 0 newly installed, 0 to remove and 44 not upgraded.\n",
158 | "Reading package lists... Done\n",
159 | "Building dependency tree \n",
160 | "Reading state information... Done\n",
161 | "default-jdk is already the newest version (2:1.11-72).\n",
162 | "0 upgraded, 0 newly installed, 0 to remove and 44 not upgraded.\n"
163 | ]
164 | }
165 | ],
166 | "source": [
167 | "!sudo apt-get update -y\n",
168 | "!sudo apt-get install default-jre -y\n",
169 | "!sudo apt-get install default-jdk -y"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": null,
175 | "id": "6f0e293f-9651-4c7a-95b1-921dda74adda",
176 | "metadata": {},
177 | "outputs": [
178 | {
179 | "name": "stdout",
180 | "output_type": "stream",
181 | "text": [
182 | "Main classes of input proteins are being predicted ...\n"
183 | ]
184 | }
185 | ],
186 | "source": [
187 | "!java -jar ../baselines/ECPred/ECPred.jar weighted ../data/test.fasta ../baselines/ECPred/ ../results/ecpred.txt"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "id": "2934039b-ddac-45cd-b361-5ae8d218753a",
193 | "metadata": {},
194 | "source": [
195 | "## 5. DeepEC"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "id": "6f9c3112-081b-4b2f-8fcd-95cad5b74d61",
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "! conda env create -f ../baselines/deepec/environment.yml\n",
206 | "! conda activate deepec"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": null,
212 | "id": "e07e3f89-5292-4e63-ad6b-290a786ee043",
213 | "metadata": {},
214 | "outputs": [],
215 | "source": [
216 | "! python ../baselines/deepec/deepec.py -i ../data/test.fasta -o ../results/deepec/"
217 | ]
218 | },
219 | {
220 | "cell_type": "markdown",
221 | "id": "ac05fbbe-ca3b-4b71-87e7-d5f81d8cd675",
222 | "metadata": {},
223 | "source": [
224 | "## 6. CatFam"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 1,
230 | "id": "4e11cfc0-8e51-490c-8b6e-538df2c053db",
231 | "metadata": {
232 | "collapsed": true,
233 | "jupyter": {
234 | "outputs_hidden": true
235 | },
236 | "tags": []
237 | },
238 | "outputs": [
239 | {
240 | "name": "stdout",
241 | "output_type": "stream",
242 | "text": [
243 | "Reading package lists... Done\n",
244 | "Building dependency tree \n",
245 | "Reading state information... Done\n",
246 | "The following additional packages will be installed:\n",
247 | " libmbedcrypto3 libmbedtls12 libmbedx509-0 ncbi-blast+ ncbi-data\n",
248 | "The following NEW packages will be installed:\n",
249 | " libmbedcrypto3 libmbedtls12 libmbedx509-0 ncbi-blast+ ncbi-blast+-legacy\n",
250 | " ncbi-data\n",
251 | "0 upgraded, 6 newly installed, 0 to remove and 44 not upgraded.\n",
252 | "Need to get 14.9 MB of archives.\n",
253 | "After this operation, 75.0 MB of additional disk space will be used.\n",
254 | "Get:1 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 libmbedcrypto3 amd64 2.16.4-1ubuntu2 [150 kB]\n",
255 | "Get:2 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 libmbedx509-0 amd64 2.16.4-1ubuntu2 [42.3 kB]\n",
256 | "Get:3 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 libmbedtls12 amd64 2.16.4-1ubuntu2 [71.8 kB]\n",
257 | "Get:4 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 ncbi-data all 6.1.20170106+dfsg1-8 [3,518 kB]\n",
258 | "Get:5 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 ncbi-blast+ amd64 2.9.0-2 [11.1 MB]\n",
259 | "Get:6 http://cn.archive.ubuntu.com/ubuntu focal/universe amd64 ncbi-blast+-legacy all 2.9.0-2 [5,064 B]\n",
260 | "Fetched 14.9 MB in 5s (3,160 kB/s) \n",
261 | "Selecting previously unselected package libmbedcrypto3:amd64.\n",
262 | "(Reading database ... 216316 files and directories currently installed.)\n",
263 | "Preparing to unpack .../0-libmbedcrypto3_2.16.4-1ubuntu2_amd64.deb ...\n",
264 | "Unpacking libmbedcrypto3:amd64 (2.16.4-1ubuntu2) ...\n",
265 | "Selecting previously unselected package libmbedx509-0:amd64.\n",
266 | "Preparing to unpack .../1-libmbedx509-0_2.16.4-1ubuntu2_amd64.deb ...\n",
267 | "Unpacking libmbedx509-0:amd64 (2.16.4-1ubuntu2) ...\n",
268 | "Selecting previously unselected package libmbedtls12:amd64.\n",
269 | "Preparing to unpack .../2-libmbedtls12_2.16.4-1ubuntu2_amd64.deb ...\n",
270 | "Unpacking libmbedtls12:amd64 (2.16.4-1ubuntu2) ...\n",
271 | "Selecting previously unselected package ncbi-data.\n",
272 | "Preparing to unpack .../3-ncbi-data_6.1.20170106+dfsg1-8_all.deb ...\n",
273 | "Unpacking ncbi-data (6.1.20170106+dfsg1-8) ...\n",
274 | "Selecting previously unselected package ncbi-blast+.\n",
275 | "Preparing to unpack .../4-ncbi-blast+_2.9.0-2_amd64.deb ...\n",
276 | "Unpacking ncbi-blast+ (2.9.0-2) ...\n",
277 | "Selecting previously unselected package ncbi-blast+-legacy.\n",
278 | "Preparing to unpack .../5-ncbi-blast+-legacy_2.9.0-2_all.deb ...\n",
279 | "Unpacking ncbi-blast+-legacy (2.9.0-2) ...\n",
280 | "Setting up ncbi-data (6.1.20170106+dfsg1-8) ...\n",
281 | "Setting up libmbedcrypto3:amd64 (2.16.4-1ubuntu2) ...\n",
282 | "Setting up libmbedx509-0:amd64 (2.16.4-1ubuntu2) ...\n",
283 | "Setting up libmbedtls12:amd64 (2.16.4-1ubuntu2) ...\n",
284 | "Setting up ncbi-blast+ (2.9.0-2) ...\n",
285 | "Setting up ncbi-blast+-legacy (2.9.0-2) ...\n",
286 | "Processing triggers for libc-bin (2.31-0ubuntu9.2) ...\n",
287 | "Processing triggers for man-db (2.9.1-1) ...\n",
288 | "Processing triggers for hicolor-icon-theme (0.17-2) ...\n"
289 | ]
290 | }
291 | ],
292 | "source": [
293 | "!sudo apt-get install ncbi-blast+-legacy -y"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": 21,
299 | "id": "af231781-8cf8-43f4-9ec9-4f1c8d577ff0",
300 | "metadata": {},
301 | "outputs": [],
302 | "source": [
303 | "! ../baselines/catfam/source/catsearch.pl -d ../baselines/catfam/CatFamDB/CatFam_v2.0/CatFam4D99R -i ../data/test.fasta -o ../results/catfam.txt"
304 | ]
305 |   },
306 |   {
307 |    "cell_type": "markdown",
308 |    "id": "05ec97bb-f0c4-41c6-9158-ed9ed0e95806",
309 |    "metadata": {},
310 |    "source": [
311 |     "## 7. PRIAM"
312 |    ]
313 |   },
314 |   {
315 |    "cell_type": "code",
316 |    "execution_count": null,
317 |    "id": "fd0de0c4",
318 |    "metadata": {},
319 |    "outputs": [],
320 |    "source": [
321 |     "! java -Xmx50G -jar ../baselines/priam/PRIAM_search.jar -p /home/shizhenkun/codebase/DMLF/baselines/priam/PRIAM_JAN18 -i /home/shizhenkun/codebase/DMLF/data/gu_bang.fasta -o /home/shizhenkun/codebase/DMLF/results/case_gubang/priam/ --bp /home/shizhenkun/downloads/blast-2.2.13/bin --np 78"
322 |    ]
323 |   },
324 | {
325 | "cell_type": "code",
326 | "execution_count": 26,
327 | "id": "26bcceda-472e-474c-a37f-c20d56ded403",
328 | "metadata": {
329 | "collapsed": true,
330 | "jupyter": {
331 | "outputs_hidden": true
332 | },
333 | "tags": []
334 | },
335 | "outputs": [
336 | {
337 | "name": "stdout",
338 | "output_type": "stream",
339 | "text": [
340 | "\n",
341 | "\n",
342 | "PRIAM did not find the profiles library (Normal if it is the first time you use PRIAM with this release).\n",
343 | "So it would now create it. Please do not interupt it during this process else it would results into a corrupted database and any further analysis would be faulty.\n",
344 | "If for any reason you really need to interupt PRIAM before it ends, or if your computer crash during this process, please delete ../baselines/priam/PRIAM_JAN18/PROFILES/LIBRARY, and all files it contains, to force PRIAM to recreate the library next time it would be launched\n",
345 | "\n",
346 | "Executing makeprofiledb -in /home/shizhenkun/codebase/DMLF/tools/../baselines/priam/PRIAM_JAN18/PROFILES/LIBRARY/profiles.list -index T -out PROFILE_EZ -title PRIAM_profiles_database \n",
347 | "\n",
348 | "Profiles database sucessfully created.\n",
349 | "If you want to interupt PRIAM, it can now be done with no risk\n",
350 | "\n",
351 | "### PRIAM Profiles search ###\n",
352 | "7101 sequences found in your query file\n",
353 | "Query file splitted into 2 pieces\n",
354 | "Executing 'rpsblast -num_threads 78 -outfmt 5 -evalue 1.0 -query ../results/priam//PRIAM_20211102154953/TMP/DATA/query_part1.fas -db ../baselines/priam/PRIAM_JAN18/PROFILES/LIBRARY/PROFILE_EZ -out /home/shizhenkun/codebase/DMLF/tools/../results/priam/PRIAM_20211102154953/TMP/RAW_RESULTS/query_part1.res -num_alignments 10000 -num_descriptions 10000'.\n",
355 | "This RPS-BLAST job took 4mn 21s \n",
356 | "RPS-Blast jobs now running from 4mn 21s \n",
357 | "Estimated time to complete all 1 remaining RPS-BLAST jobs: 4mn 21s \n",
358 | "Executing 'rpsblast -num_threads 78 -outfmt 5 -evalue 1.0 -query ../results/priam//PRIAM_20211102154953/TMP/DATA/query_part2.fas -db ../baselines/priam/PRIAM_JAN18/PROFILES/LIBRARY/PROFILE_EZ -out /home/shizhenkun/codebase/DMLF/tools/../results/priam/PRIAM_20211102154953/TMP/RAW_RESULTS/query_part2.res -num_alignments 10000 -num_descriptions 10000'.\n",
359 | "This RPS-BLAST job took 1mn 52s \n",
360 | "Time used to search all sequences against profiles of PRIAM: 6mn 13s \n",
361 | "\n",
362 | "### Individual sequence annotation ###\n",
363 | "Exception in thread \"main\" java.lang.reflect.InvocationTargetException\n",
364 | "\tat java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\n",
365 | "\tat java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\n",
366 | "\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n",
367 | "\tat java.base/java.lang.reflect.Method.invoke(Method.java:566)\n",
368 | "\tat org.eclipse.jdt.internal.jarinjarloader.JarRsrcLoader.main(JarRsrcLoader.java:58)\n",
369 | "Caused by: java.lang.NoClassDefFoundError: weka/core/Utils\n",
370 | "\tat priam.common.DistributionEstimator.(DistributionEstimator.java:188)\n",
371 | "\tat priam.common.ProfInfos.setPositivesHitsDistrib(ProfInfos.java:229)\n",
372 | "\tat priam.common.AnnotationRulesXMLHandler2.startElement(AnnotationRulesXMLHandler2.java:133)\n",
373 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.AbstractSAXParser.startElement(AbstractSAXParser.java:510)\n",
374 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.AbstractXMLDocumentParser.emptyElement(AbstractXMLDocumentParser.java:183)\n",
375 | "\tat java.xml/com.sun.org.apache.xerces.internal.impl.XMLDocumentFragmentScannerImpl.scanStartElement(XMLDocumentFragmentScannerImpl.java:1377)\n",
376 | "\tat java.xml/com.sun.org.apache.xerces.internal.impl.XMLDocumentFragmentScannerImpl$FragmentContentDriver.next(XMLDocumentFragmentScannerImpl.java:2710)\n",
377 | "\tat java.xml/com.sun.org.apache.xerces.internal.impl.XMLDocumentScannerImpl.next(XMLDocumentScannerImpl.java:605)\n",
378 | "\tat java.xml/com.sun.org.apache.xerces.internal.impl.XMLDocumentFragmentScannerImpl.scanDocument(XMLDocumentFragmentScannerImpl.java:534)\n",
379 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.XML11Configuration.parse(XML11Configuration.java:888)\n",
380 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.XML11Configuration.parse(XML11Configuration.java:824)\n",
381 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.XMLParser.parse(XMLParser.java:141)\n",
382 | "\tat java.xml/com.sun.org.apache.xerces.internal.parsers.AbstractSAXParser.parse(AbstractSAXParser.java:1216)\n",
383 | "\tat java.xml/com.sun.org.apache.xerces.internal.jaxp.SAXParserImpl$JAXPSAXParser.parse(SAXParserImpl.java:635)\n",
384 | "\tat java.xml/com.sun.org.apache.xerces.internal.jaxp.SAXParserImpl.parse(SAXParserImpl.java:324)\n",
385 | "\tat java.xml/javax.xml.parsers.SAXParser.parse(SAXParser.java:330)\n",
386 | "\tat priam.modules.Annotate.readAnnotationRulesXML(Annotate.java:65)\n",
387 | "\tat priam.modules.Annotate.AnnotateSequences(Annotate.java:110)\n",
388 | "\tat priam.search.Main.main_standard(Main.java:51)\n",
389 | "\tat priam.search.Main.main(Main.java:16)\n",
390 | "\t... 5 more\n",
391 | "Caused by: java.lang.ClassNotFoundException: weka.core.Utils\n",
392 | "\tat java.base/java.net.URLClassLoader.findClass(URLClassLoader.java:471)\n",
393 | "\tat java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:589)\n",
394 | "\tat java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:522)\n",
395 | "\t... 25 more\n"
396 | ]
397 | }
398 | ],
399 | "source": [
400 | "! java -Xmx50G -jar ../baselines/priam/PRIAM_search.jar -p /home/shizhenkun/codebase/DMLF/baselines/priam/PRIAM_JAN18 -i /home/shizhenkun/codebase/DMLF/data/test.fasta -o /home/shizhenkun/codebase/DMLF/results/priam/ --bp /home/shizhenkun/downloads/blast-2.2.13/bin --np 78"
401 | ]
402 | }
403 | ],
404 | "metadata": {
405 | "kernelspec": {
406 | "display_name": "py38",
407 | "language": "python",
408 | "name": "py38"
409 | },
410 | "language_info": {
411 | "codemirror_mode": {
412 | "name": "ipython",
413 | "version": 3
414 | },
415 | "file_extension": ".py",
416 | "mimetype": "text/x-python",
417 | "name": "python",
418 | "nbconvert_exporter": "python",
419 | "pygments_lexer": "ipython3",
420 | "version": "3.7.10 | packaged by conda-forge | (default, Feb 19 2021, 16:07:37) \n[GCC 9.3.0]"
421 | },
422 | "vscode": {
423 | "interpreter": {
424 | "hash": "5f0598356972e9a098a4a3756f1f561f75f6bcc730be3bc51870d213558f68c7"
425 | }
426 | },
427 | "widgets": {
428 | "application/vnd.jupyter.widget-state+json": {
429 | "state": {},
430 | "version_major": 2,
431 | "version_minor": 0
432 | }
433 | }
434 | },
435 | "nbformat": 4,
436 | "nbformat_minor": 5
437 | }
438 |
--------------------------------------------------------------------------------
/tools/embdding_onehot.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Zhenkun Shi
3 | Date: 2022-10-08 06:15:36
4 | LastEditors: Zhenkun Shi
5 | LastEditTime: 2022-10-08 08:00:58
6 | FilePath: /DMLF/tools/embdding_onehot.py
7 | Description:
8 |
9 | Copyright (c) 2022 by tibd, All Rights Reserved.
10 | '''
11 |
12 |
13 | import numpy as np
14 | import pandas as pd
15 | import os,sys,itertools
16 | from datetime import datetime
17 | from tqdm import tqdm
18 | from pandarallel import pandarallel # import pandarallel
19 | sys.path.append(os.path.dirname(os.path.realpath('__file__')))
20 | sys.path.append('../')
21 | import config as cfg
22 |
23 | # amino acid encoding dictionary (one bit per residue; X is the all-zero unknown/padding symbol)
24 | prot_dict = dict(
25 | A=0b000000000000000000000001, R=0b000000000000000000000010, N=0b000000000000000000000100,
26 | D=0b000000000000000000001000, C=0b000000000000000000010000, E=0b000000000000000000100000,
27 | Q=0b000000000000000001000000, G=0b000000000000000010000000, H=0b000000000000000100000000,
28 | O=0b000000000000001000000000, I=0b000000000000010000000000, L=0b000000000000100000000000,
29 |     K=0b000000000001000000000000, M=0b000000000010000000000000, F=0b000000000100000000000000,
30 | P=0b000000001000000000000000, U=0b000000010000000000000000, S=0b000000100000000000000000,
31 | T=0b000001000000000000000000, W=0b000010000000000000000000, Y=0b000100000000000000000000,
32 | V=0b001000000000000000000000, B=0b010000000000000000000000, Z=0b100000000000000000000000,
33 | X=0b000000000000000000000000
34 | )
35 |
36 | #region using one-hot to encode a amino acid sequence
37 | def protein_sequence_one_hot(protein_seq, padding=False, padding_window=1500):
38 | """ using one-hot to encode a amino acid sequence
39 |
40 | Args:
41 |         protein_seq (string): amino acid sequence
42 | padding (bool, optional): padding to a fixed length. Defaults to False.
43 | padding_window (int, optional): padded sequence size. Defaults to 1500.
44 |
45 | Returns:
46 |         list: encoded sequence, one 24-bit one-hot integer per residue
47 | """
48 | res = [prot_dict.get(item) for item in protein_seq]
49 |     if padding:
50 | if len(protein_seq)>=padding_window:
51 | res= res[:padding_window]
52 | else:
53 | res=np.pad(res, (0,(padding_window-len(protein_seq))), 'median')
54 | return list(res)
55 | #endregion
56 |
57 |
58 | #region encode amino acid sequences dataframe
59 | def get_onehot(sequences, padding=True, padding_window=1500):
60 | """encode amino acid sequences dataframe
61 |
62 | Args:
63 |         sequences (DataFrame): sequences dataframe; columns must contain ['id', 'seq']
64 | padding (bool, optional): if padding to a fixed length. Defaults to True.
65 | padding_window (int, optional): fixed padding size. Defaults to 1500.
66 |
67 | Returns:
68 | DataFrame: one-hot represented sequences DataFrame
69 | """
70 | sequences = sequences[['id','seq']].copy()
71 | sequences['onehot']=sequences.parallel_apply(lambda x: protein_sequence_one_hot(protein_seq=x.seq, padding=padding, padding_window=padding_window), axis=1)
72 | one_hot_rep = pd.DataFrame(np.array(list(itertools.chain(*sequences.onehot.values))).reshape(-1,(padding_window)), columns=['f'+str(i) for i in range (1,(padding_window+1))])
73 | one_hot_rep.insert(0,'id',value=sequences.id.values)
74 |
75 | return one_hot_rep
76 | #endregion
77 |
78 | if __name__ =='__main__':
79 | seqs = pd.DataFrame([['seq1', 'MTTSVIVAGARTPIGKLMGSLKDFSASELGAIAIKGALEKANVPAS'],
80 | ['seq2', 'MAERAPRGEVAVMVAVQSALVDRPGMLATARGLSHFGEHCIGWLIL']
81 | ], columns=['id', 'seq'])
82 | pandarallel.initialize()
83 | res = get_onehot(sequences=seqs, padding=True, padding_window=50)
84 | print(res)
--------------------------------------------------------------------------------
/tools/embedding_esm.py:
--------------------------------------------------------------------------------
1 | # from esm import model  # unused; the pretrained model is loaded via esm.pretrained below
2 | import torch
3 | import esm
4 | import re
5 | from tqdm import tqdm
6 | import numpy as np
7 | import pandas as pd
8 | import sys
9 | sys.path.insert(0, "../../")
10 | import config as cfg
11 |
12 |
13 | # region split a string into fixed-length chunks
14 | def cut_text(text, length):
15 |     """[split a string into fixed-length chunks]
16 | 
17 |     Args:
18 |         text ([string]): [input string]
19 |         length ([int]): [sub-sequence length]
20 | 
21 |     Returns:
22 |         [string list]: [list of sub-sequences]
23 |     """
24 |     textArr = re.findall('.{'+str(length)+'}', text)
25 |     textArr.append(text[(len(textArr)*length):])
26 |     return textArr
27 | #endregion
28 |
29 | #region embed a single sequence
30 | def get_rep_single_seq(seqid, sequence, model, batch_converter, seqthres=1022):
31 |     """[embed a single amino acid sequence]
32 | 
33 |     Args:
34 |         seqid ([string]): [sequence name]
35 |         sequence ([string]): [sequence]
36 |         model ([model]): [embedding model]
37 |         batch_converter ([object]): [batch converter built from the model alphabet]
38 |         seqthres (int, optional): [max sequence length; longer sequences are chunked]. Defaults to 1022.
39 | 
40 |     Returns:
41 |         [dict]: [label plus per-layer mean representations]
42 |     """
43 |
44 | if len(sequence) < seqthres:
45 | data =[(seqid, sequence)]
46 | else:
47 | seqArray = cut_text(sequence, seqthres)
48 | data=[]
49 | for item in seqArray:
50 | data.append((seqid, item))
51 | batch_labels, batch_strs, batch_tokens = batch_converter(data)
52 |
53 | if torch.cuda.is_available():
54 | batch_tokens = batch_tokens.to(device="cuda", non_blocking=True)
55 |
56 | REP_LAYERS = [0, 32, 33]
57 | MINI_SIZE = len(batch_labels)
58 |
59 | with torch.no_grad():
60 | results = model(batch_tokens, repr_layers=REP_LAYERS, return_contacts=False)
61 |
62 |
63 | representations = {layer: t.to(device="cpu") for layer, t in results["representations"].items()}
64 | result ={}
65 | result["label"] = batch_labels[0]
66 |
67 | for i in range(MINI_SIZE):
68 | if i ==0:
69 | result["mean_representations"] = {layer: t[i, 1 : len(batch_strs[0]) + 1].mean(0).clone() for layer, t in representations.items()}
70 | else:
71 | for index, layer in enumerate(REP_LAYERS):
72 | result["mean_representations"][layer] += {layer: t[i, 1 : len(batch_strs[0]) + 1].mean(0).clone() for layer, t in representations.items()}[layer]
73 |
74 | for index, layer in enumerate(REP_LAYERS):
75 | result["mean_representations"][layer] = result["mean_representations"][layer] /MINI_SIZE
76 |
77 | return result
78 | #endregion
79 |
80 | #region embed multiple sequences
81 | def get_rep_multi_sequence(sequences, model='esm_msa1b_t12_100M_UR50S', repr_layers=[0, 32, 33], seqthres=1022):
82 |     """[embed multiple sequences; note: the model argument is currently ignored and ESM-1b is always loaded]
83 |     Args:
84 |         sequences ([DataFrame]): [sequence info with columns id and seq]
85 |         seqthres (int, optional): [max sub-sequence length]. Defaults to 1022.
86 | 
87 |     Returns:
88 |         [DataFrame]: [final_rep0, final_rep32, final_rep33]
89 |     """
90 | final_label_list = []
91 | final_rep0 =[]
92 | final_rep32 =[]
93 | final_rep33 =[]
94 |
95 | model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
96 | batch_converter = alphabet.get_batch_converter()
97 | if torch.cuda.is_available():
98 | model = model.cuda()
99 | print("Transferred model to GPU")
100 |
101 | for i in tqdm(range(len(sequences))):
102 | apd = get_rep_single_seq(
103 | seqid = sequences.iloc[i].id,
104 | sequence=sequences.iloc[i].seq,
105 | model=model,
106 | batch_converter=batch_converter,
107 | seqthres=seqthres)
108 |
109 | final_label_list.append(np.array(apd['label']))
110 | final_rep0.append(np.array(apd['mean_representations'][0]))
111 | final_rep32.append(np.array(apd['mean_representations'][32]))
112 | final_rep33.append(np.array(apd['mean_representations'][33]))
113 |
114 | final_rep0 = pd.DataFrame(final_rep0)
115 | final_rep32 = pd.DataFrame(final_rep32)
116 | final_rep33 = pd.DataFrame(final_rep33)
117 | final_rep0.insert(loc=0, column='id', value=np.array(final_label_list).flatten())
118 | final_rep32.insert(loc=0, column='id', value=np.array(final_label_list).flatten())
119 | final_rep33.insert(loc=0, column='id', value=np.array(final_label_list).flatten())
120 |
121 | col_name = ['id']+ ['f'+str(i) for i in range (1,final_rep0.shape[1])]
122 | final_rep0.columns = col_name
123 | final_rep32.columns = col_name
124 | final_rep33.columns = col_name
125 |
126 | return final_rep0, final_rep32, final_rep33
127 | #endregion
128 |
129 | if __name__ =='__main__':
130 | SEQTHRES = 1022
131 | RUNMODEL = { 'ESM-1b' :'esm1b_t33_650M_UR50S',
132 | 'ESM-MSA-1b' :'esm_msa1b_t12_100M_UR50S'
133 | }
134 | train = pd.read_feather(cfg.DATADIR+'train.feather').iloc[:,:6]
135 |     test = pd.read_feather(cfg.DATADIR+'test.feather').iloc[:,:6] # loaded but not embedded below; see the sketch at the end of this file
136 | rep0, rep32, rep33 = get_rep_multi_sequence(sequences=train, model=RUNMODEL.get('ESM-MSA-1b'),seqthres=SEQTHRES)
137 |
138 | rep0.to_feather(cfg.DATADIR + 'train_rep0.feather')
139 | rep32.to_feather(cfg.DATADIR + 'train_rep32.feather')
140 | rep33.to_feather(cfg.DATADIR + 'train_rep33.feather')
141 |
142 | print('Embedding Success!')
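143 | 
144 |     # A minimal sketch (assumption: test shares the id/seq schema of train);
145 |     # the test split can be embedded and saved the same way:
146 |     # rep0, rep32, rep33 = get_rep_multi_sequence(sequences=test, model=RUNMODEL.get('ESM-1b'), seqthres=SEQTHRES)
147 |     # rep33.to_feather(cfg.DATADIR + 'test_rep33.feather')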
--------------------------------------------------------------------------------
/tools/exact_ec_from_uniprot.py:
--------------------------------------------------------------------------------
1 | from Bio import SeqIO
2 | import gzip
3 | import re
4 | from tqdm import tqdm
5 | import time
6 | import sys,os
7 | sys.path.append(os.getcwd())
8 | import config as cfg
9 | import pandas as pd
10 |
11 | #region read data from gzip
12 | def read_file_from_gzip(file_in_path, file_out_path, extract_type, save_file_type='tsv'):
13 |     """Parse records from the raw gzipped UniProt file
14 | 
15 |     Args:
16 |         file_in_path ([string]): [input file path]
17 |         file_out_path ([string]): [output file path]
18 |         extract_type ([string]): [type of records to extract: with_ec, without_ec, full]
19 |     """
20 | if save_file_type == 'feather':
21 | outpath = file_out_path
22 | file_out_path = cfg.TEMPDIR +'temprecords.tsv'
23 |
24 | table_head = [ 'id',
25 | 'name',
26 | 'isenzyme',
27 | 'isMultiFunctional',
28 | 'functionCounts',
29 | 'ec_number',
30 | 'ec_specific_level',
31 | 'date_integraged',
32 | 'date_sequence_update',
33 | 'date_annotation_update',
34 | 'seq',
35 | 'seqlength'
36 | ]
37 | counter = 0
38 | saver = 0
39 | file_write_obj = open(file_out_path,'w')
40 | with gzip.open(file_in_path, "rt") as handle:
41 | file_write_obj.writelines('\t'.join(table_head))
42 | file_write_obj.writelines('\n')
43 | for record in tqdm( SeqIO.parse(handle, 'swiss'), position=1, leave=True):
44 | res = process_record(record, extract_type= extract_type)
45 | counter+=1
46 | if counter %10==0:
47 | file_write_obj.flush()
48 | # if saver%100000==0:
49 | # print(saver)
50 | if len(res) >0:
51 | saver +=1
52 | file_write_obj.writelines('\t'.join(map(str,res)))
53 | file_write_obj.writelines('\n')
54 | else:
55 | continue
56 | file_write_obj.close()
57 |
58 | if save_file_type == 'feather':
59 | indata = pd.read_csv(cfg.TEMPDIR +'temprecords.tsv', sep='\t')
60 | indata.to_feather(outpath)
61 |
62 |
63 | #endregion
64 |
65 | #region extract a single record with EC annotation
66 | def process_record(record, extract_type='with_ec'):
67 |     """
68 |     Extract a single record with EC annotation
69 |     Args:
70 |         record ([type]): record node from UniProt
71 |         extract_type (string, optional): extraction type, one of with_ec, without_ec, full. Defaults to with_ec.
72 |     Returns:
73 |         [list]: parsed fields, or [] when the record is filtered out
74 |     """
75 |
76 |
77 | description = record.description
78 |     isEnzyme = 'EC=' in description # records with an EC number are treated as enzymes, all others as non-enzymes
79 | isMultiFunctional = False
80 | functionCounts = 0
81 | ec_specific_level =0
82 |
83 |
84 | if isEnzyme:
85 | ec = str(re.findall(r"EC=[0-9,.\-;]*",description)
86 | ).replace('EC=','').replace('\'','').replace(']','').replace('[','').replace(';','').strip()
87 |
88 |         # count the enzyme's functions
89 | isMultiFunctional = ',' in ec
90 | functionCounts = ec.count(',') + 1
91 |
92 |
93 |         # - mono-functional enzyme
94 | if not isMultiFunctional:
95 | levelCount = ec.count('-')
96 | ec_specific_level = 4-levelCount
97 |
98 |     else: # - multi-functional enzyme
99 | ecarray = ec.split(',')
100 | for subec in ecarray:
101 | current_ec_level = 4- subec.count('-')
102 | if ec_specific_level < current_ec_level:
103 | ec_specific_level = current_ec_level
104 | else:
105 | ec = '-'
106 |
107 | id = record.id.strip()
108 | name = record.name.strip()
109 | seq = record.seq.strip()
110 | date_integrated = record.annotations.get('date').strip()
111 | date_sequence_update = record.annotations.get('date_last_sequence_update').strip()
112 | date_annotation_update = record.annotations.get('date_last_annotation_update').strip()
113 | seqlength = len(seq)
114 | res = [id, name, isEnzyme, isMultiFunctional, functionCounts, ec,ec_specific_level, date_integrated, date_sequence_update, date_annotation_update, seq, seqlength]
115 |
116 | if extract_type == 'full':
117 | return res
118 |
119 | if extract_type == 'with_ec':
120 | if isEnzyme:
121 | return res
122 | else:
123 | return []
124 |
125 | if extract_type == 'without_ec':
126 | if isEnzyme:
127 | return []
128 | else:
129 | return res
130 | #endregion
131 |
132 | def run_exact_task(infile, outfile):
133 | start = time.process_time()
134 | extract_type ='full'
135 | read_file_from_gzip(file_in_path=infile, file_out_path=outfile, extract_type=extract_type)
136 | end = time.process_time()
137 |     print('finished, time used %6.3f s' % (end - start))
138 |
139 | if __name__ =="__main__":
140 | start = time.process_time()
141 | in_filepath_sprot = cfg.FILE_LATEST_SPROT
142 | out_filepath_sprot = cfg.FILE_LATEST_SPROT_FEATHER
143 |
144 | in_filepath_trembl = cfg.FILE_LATEST_TREMBL
145 | out_filepath_trembl = cfg.FILE_LATEST_TREMBL_FEATHER
146 |
147 | extract_type ='full'
148 |
149 |
150 | # read_file_from_gzip(file_in_path=in_filepath_sprot, file_out_path=out_filepath_sprot, extract_type=extract_type, save_file_type='feather')
151 |
152 | read_file_from_gzip(file_in_path=in_filepath_trembl, file_out_path=out_filepath_trembl, extract_type=extract_type, save_file_type='feather')
153 | end = time.process_time()
154 |     print('finished, time used %6.3f s' % (end - start))
155 |
156 |
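157 | 
158 |     # A usage sketch for the wrapper above (paths are placeholders taken from
159 |     # config; run_exact_task writes a TSV by default):
160 |     # run_exact_task(infile=cfg.FILE_LATEST_SPROT, outfile=cfg.TEMPDIR + 'sprot_records.tsv')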
--------------------------------------------------------------------------------
/tools/filetool.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Zhenkun Shi
3 | Date: 2022-10-05 05:24:14
4 | LastEditors: Zhenkun Shi
5 | LastEditTime: 2022-10-05 05:25:15
6 | FilePath: /DMLF/tools/filetool.py
7 | Description:
8 |
9 | Copyright (c) 2022 by tibd, All Rights Reserved.
10 | '''
11 |
12 | from retrying import retry
13 | # from matplotlib.pyplot import axis  # unused import
14 | import urllib3
15 | import pandas as pd
16 | import os
17 | import sys
18 | from shutil import copyfile
19 | from sys import exit
20 | import zipfile
21 |
22 |
23 | #region download a file
24 | @retry(stop_max_attempt_number=10, wait_random_min=10, wait_random_max=20)
25 | def download(download_url, save_file):
26 |     """[download a file from a URL]
27 |     Args:
28 |         download_url ([String]): [download URL]
29 |         save_file ([String]): [path to save the downloaded file]
30 |     """
31 | http = urllib3.PoolManager()
32 | response = http.request('GET', download_url)
33 | with open(save_file, 'wb') as f:
34 | f.write(response.data)
35 | response.release_conn()
36 | #endregion
37 |
38 |
39 | def wget(download_url, save_file, verbose=False):
40 |     process = os.popen('which wget') # returns the wget path if installed
41 |     output = process.read()
42 |     if output == '':
43 |         print('wget not installed')
44 |     else:
45 |         if verbose:
46 |             cmd = 'wget ' + download_url + ' -O ' + save_file
47 |         else:
48 |             cmd = 'wget -q ' + download_url + ' -O ' + save_file
49 |         print(cmd)
50 |         process = os.popen(cmd)
51 |         output = process.read()
52 |         process.close()
53 |
54 | def convert_DF_dateTime(inputdf):
55 |     """[Convert uniprot record date columns to datetime]
56 |
57 | Args:
58 | inputdf ([DataFrame]): [input dataFrame]
59 |
60 | Returns:
61 | [DataFrame]: [converted DataFrame]
62 | """
63 | inputdf.date_integraged = pd.to_datetime(inputdf['date_integraged'])
64 | inputdf.date_sequence_update = pd.to_datetime(inputdf['date_sequence_update'])
65 | inputdf.date_annotation_update = pd.to_datetime(inputdf['date_annotation_update'])
66 | inputdf = inputdf.sort_values(by='date_integraged', ascending=True)
67 | inputdf.reset_index(drop=True, inplace=True)
68 | return inputdf
69 |
70 | def get_file_names_in_dir(dataroot, filetype='all'):
71 |     """Return the list of files in a directory, optionally filtered by file type
72 | 
73 |     Args:
74 |         dataroot (string): directory path
75 |         filetype (str, optional): file type filter. Defaults to 'all'.
76 |
77 | Returns:
78 | DataFrame: columns=['filename','filetype','filename_no_suffix']
79 | """
80 | exist_file_df = pd.DataFrame(os.listdir(dataroot), columns=['filename'])
81 |
82 | if len(exist_file_df)!=0:
83 | exist_file_df['filetype'] = exist_file_df.filename.apply(lambda x: x.split('.')[-1])
84 | exist_file_df['filename_no_suffix'] = exist_file_df.apply(lambda x: x['filename'].replace(('.'+str(x['filetype']).strip()), ''), axis=1)
85 | if filetype !='all':
86 | exist_file_df = exist_file_df[exist_file_df.filetype==filetype]
87 | return exist_file_df
88 | else:
89 | return pd.DataFrame(columns=['filename','filetype','filename_no_suffix'])
90 |
91 |
92 | def copy(source, target):
93 |     """copy a file
94 |
95 | Args:
96 | source (string): source
97 | target (string): target
98 | """
99 | try:
100 | copyfile(src=source, dst=target)
101 | except IOError as e:
102 | print("Unable to copy file. %s" % e)
103 | exit(1)
104 | except:
105 | print("Unexpected error:", sys.exc_info())
106 | exit(1)
107 |
108 | def delete(filepath):
109 |     """delete a file
110 | 
111 |     Args:
112 |         filepath (string): full file path
113 |     """
114 | try:
115 | os.remove(filepath)
116 | except IOError as e:
117 | print("Unable to delete file. %s" % e)
118 | exit(1)
119 | except:
120 | print("Unexpected error:", sys.exc_info())
121 | exit(1)
122 |
123 | #region unzip file
124 | def unzipfile(filename, target_dir):
125 |     """unzip a file
126 | 
127 |     Args:
128 |         filename (string): zip file full path
129 |         target_dir (string): target directory
130 | """
131 | with zipfile.ZipFile(filename,"r") as zip_ref:
132 | zip_ref.extractall(target_dir)
133 | #endregion
134 |
135 |
136 | def isfileExists(filepath):
137 | return os.path.exists(filepath)
138 |
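139 | 
140 | if __name__ == '__main__':
141 |     # A minimal usage sketch; the URL and paths below are hypothetical examples.
142 |     download('https://example.org/data.txt', '/tmp/data.txt')
143 |     print(get_file_names_in_dir('/tmp', filetype='txt'))
144 |     if isfileExists('/tmp/data.txt'):
145 |         delete('/tmp/data.txt')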
--------------------------------------------------------------------------------
/tools/funclib.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from sklearn.model_selection import train_test_split
3 | from sklearn import metrics
4 | from sklearn import linear_model
5 | from sklearn.svm import SVC
6 | from sklearn import tree
7 | from itertools import chain
8 | from sklearn.ensemble import RandomForestClassifier
9 | from sklearn.ensemble import GradientBoostingClassifier
10 | from sklearn.neighbors import KNeighborsClassifier
11 | from xgboost import XGBClassifier
12 | from Bio import SeqIO
13 |
14 | import pandas as pd
15 | import numpy as np
16 | import os
17 |
18 |
19 | table_head = [ 'id',
20 | 'name',
21 | 'isemzyme',
22 | 'isMultiFunctional',
23 | 'functionCounts',
24 | 'ec_number',
25 | 'ec_specific_level',
26 | 'date_integraged',
27 | 'date_sequence_update',
28 | 'date_annotation_update',
29 | 'seq',
30 | 'seqlength'
31 | ]
32 |
33 | def table2fasta(table, file_out):
34 | file = open(file_out, 'w')
35 | for index, row in table.iterrows():
36 | file.write('>{0}\n'.format(row['id']))
37 | file.write('{0}\n'.format(row['seq']))
38 | file.close()
39 | print('Write finished')
40 |
41 | # amino acid dictionary (the unknown symbol X is used for padding/alignment)
42 | prot_dict = dict(
43 | A=1, R=2, N=3, D=4, C=5, E=6, Q=7, G=8, H=9, O=10, I=11, L=12,
44 | K=13, M=14, F=15, P=16, U=17, S=18, T=19, W=20, Y=21, V=22, B=23, Z=24, X=0
45 | )
46 |
47 | # sequence encoding via prot_dict (outdated row-by-row version)
48 | def dna_onehot_outdateed(Xdna):
49 | listtmp = list()
50 | for index, row in Xdna.iterrows():
51 | row = [prot_dict[x] if x in prot_dict else x for x in row['seq']]
52 | listtmp.append(row)
53 | return pd.DataFrame(listtmp)
54 |
55 | # sequence encoding via prot_dict (ordinal integer codes, despite the one-hot name)
56 | def dna_onehot(Xdna):
57 | listtmp = []
58 | listtmp =Xdna.seq.parallel_apply(lambda x: np.array([prot_dict.get(item) for item in x]))
59 | listtmp = pd.DataFrame(np.stack(listtmp))
60 | listtmp = pd.concat( [Xdna.iloc[:,0:2], listtmp], axis=1)
61 | return listtmp
62 |
63 |
64 |
65 |
66 | def lrmain(X_train_std, Y_train, X_test_std, Y_test, type='binary'):
67 | logreg = linear_model.LogisticRegression(
68 | solver = 'saga',
69 | multi_class='auto',
70 | verbose=False,
71 | n_jobs=-1,
72 | max_iter=10000
73 | )
74 | # sc = StandardScaler()
75 | # X_train_std = sc.fit_transform(X_train_std)
76 | logreg.fit(X_train_std, Y_train)
77 | predict = logreg.predict(X_test_std)
78 | lrpredpro = logreg.predict_proba(X_test_std)
79 | groundtruth = Y_test
80 | return groundtruth, predict, lrpredpro, logreg
81 |
82 | def knnmain(X_train_std, Y_train, X_test_std, Y_test, type='binary'):
83 | knn=KNeighborsClassifier(n_neighbors=5, n_jobs=16)
84 | knn.fit(X_train_std, Y_train)
85 | predict = knn.predict(X_test_std)
86 | lrpredpro = knn.predict_proba(X_test_std)
87 | groundtruth = Y_test
88 | return groundtruth, predict, lrpredpro, knn
89 |
90 | def svmmain(X_train_std, Y_train, X_test_std, Y_test):
91 | svcmodel = SVC(probability=True, kernel='rbf', tol=0.001)
92 | svcmodel.fit(X_train_std, Y_train.ravel(), sample_weight=None)
93 | predict = svcmodel.predict(X_test_std)
94 | predictprob =svcmodel.predict_proba(X_test_std)
95 | groundtruth = Y_test
96 | return groundtruth, predict, predictprob,svcmodel
97 |
98 | def xgmain(X_train_std, Y_train, X_test_std, Y_test, type='binary', vali=True):
99 |
100 | x_train, x_vali, y_train, y_vali = train_test_split(X_train_std, Y_train, test_size=0.2, random_state=1)
101 | eval_set = [(x_train, y_train), (x_vali, y_vali)]
102 |
103 | if type=='binary':
104 | # model = XGBClassifier(
105 | # objective='binary:logistic',
106 | # random_state=15,
107 | # use_label_encoder=False,
108 | # n_jobs=-1,
109 | # eval_metric='mlogloss',
110 | # min_child_weight=15,
111 | # max_depth=15,
112 | # n_estimators=300,
113 | # tree_method = 'gpu_hist',
114 | # learning_rate = 0.01
115 | # )
116 | model = XGBClassifier(
117 | objective='binary:logistic',
118 | random_state=13,
119 | use_label_encoder=False,
120 | n_jobs=-2,
121 | eval_metric='auc',
122 | max_depth=6,
123 | n_estimators= 300,
124 | tree_method = 'gpu_hist',
125 | learning_rate = 0.01,
126 | gpu_id=0
127 | )
128 | model.fit(x_train, y_train, eval_set=eval_set, verbose=False)
129 | # model.fit(t1_x_train, t1_y_train, eval_set=t1_eval_set, verbose=100)
130 | if type=='multi':
131 | model = XGBClassifier(
132 | min_child_weight=6,
133 | max_depth=6,
134 | objective='multi:softmax',
135 | num_class=len(set(Y_train)),
136 | use_label_encoder=False,
137 | n_estimators=120
138 | )
139 | if vali:
140 | model.fit(x_train, y_train, eval_metric="mlogloss", eval_set=eval_set, verbose=False)
141 | else:
142 | model.fit(X_train_std, Y_train, eval_metric="mlogloss", eval_set=None, verbose=False)
143 |
144 | predict = model.predict(X_test_std)
145 | predictprob = model.predict_proba(X_test_std)
146 | groundtruth = Y_test
147 | return groundtruth, predict, predictprob, model
148 |
149 | def dtmain(X_train_std, Y_train, X_test_std, Y_test):
150 | model = tree.DecisionTreeClassifier()
151 | model.fit(X_train_std, Y_train.ravel())
152 | predict = model.predict(X_test_std)
153 | predictprob = model.predict_proba(X_test_std)
154 | groundtruth = Y_test
155 | return groundtruth, predict, predictprob,model
156 |
157 | def rfmain(X_train_std, Y_train, X_test_std, Y_test):
158 | model = RandomForestClassifier(oob_score=True, random_state=10, n_jobs=-2)
159 | model.fit(X_train_std, Y_train.ravel())
160 | predict = model.predict(X_test_std)
161 | predictprob = model.predict_proba(X_test_std)
162 | groundtruth = Y_test
163 | return groundtruth, predict, predictprob,model
164 |
165 | def gbdtmain(X_train_std, Y_train, X_test_std, Y_test):
166 | model = GradientBoostingClassifier(random_state=10)
167 | model.fit(X_train_std, Y_train.ravel())
168 | predict = model.predict(X_test_std)
169 | predictprob = model.predict_proba(X_test_std)
170 | groundtruth = Y_test
171 | return groundtruth, predict, predictprob, model
172 |
173 |
174 | def caculateMetrix(groundtruth, predict, baselineName, type='binary'):
175 | acc = metrics.accuracy_score(groundtruth, predict)
176 | if type == 'binary':
177 | precision = metrics.precision_score(groundtruth, predict, zero_division=True )
178 | recall = metrics.recall_score(groundtruth, predict, zero_division=True)
179 | f1 = metrics.f1_score(groundtruth, predict, zero_division=True)
180 | tn, fp, fn, tp = metrics.confusion_matrix(groundtruth, predict).ravel()
181 | npv = tn/(fn+tn+1.4E-45)
182 | print(baselineName, '\t\t%f' %acc,'\t%f'% precision,'\t\t%f'%npv,'\t%f'% recall,'\t%f'% f1, '\t', 'tp:',tp,'fp:',fp,'fn:',fn,'tn:',tn)
183 |
184 | if type == 'multi':
185 | precision = metrics.precision_score(groundtruth, predict, average='macro', zero_division=True )
186 | recall = metrics.recall_score(groundtruth, predict, average='macro', zero_division=True)
187 | f1 = metrics.f1_score(groundtruth, predict, average='macro', zero_division=True)
188 | print('%12s'%baselineName, ' \t\t%f '%acc,'\t%f'% precision, '\t\t%f'% recall,'\t%f'% f1)
189 |
190 |
191 | def evaluate(baslineName, X_train_std, Y_train, X_test_std, Y_test, type='binary'):
192 |
193 | if baslineName == 'lr':
194 | groundtruth, predict, predictprob, model = lrmain (X_train_std, Y_train, X_test_std, Y_test, type=type)
195 | elif baslineName == 'svm':
196 | groundtruth, predict, predictprob, model = svmmain(X_train_std, Y_train, X_test_std, Y_test)
197 | elif baslineName =='xg':
198 | groundtruth, predict, predictprob, model = xgmain(X_train_std, Y_train, X_test_std, Y_test, type=type)
199 | elif baslineName =='dt':
200 | groundtruth, predict, predictprob, model = dtmain(X_train_std, Y_train, X_test_std, Y_test)
201 | elif baslineName =='rf':
202 | groundtruth, predict, predictprob, model = rfmain(X_train_std, Y_train, X_test_std, Y_test)
203 | elif baslineName =='gbdt':
204 | groundtruth, predict, predictprob, model = gbdtmain(X_train_std, Y_train, X_test_std, Y_test)
205 | elif baslineName =='knn':
206 | groundtruth, predict, predictprob, model = knnmain(X_train_std, Y_train, X_test_std, Y_test, type=type)
207 | else:
208 |         print('Baseline Name Error')
209 |
210 | caculateMetrix(groundtruth=groundtruth,predict=predict,baselineName=baslineName, type=type)
211 |
212 | def evaluate_2(baslineName, X_train_std, Y_train, X_test_std, Y_test, type='binary'):
213 |
214 | if baslineName == 'lr':
215 | groundtruth, predict, predictprob, model = lrmain (X_train_std, Y_train, X_test_std, Y_test, type=type)
216 | elif baslineName == 'svm':
217 | groundtruth, predict, predictprob, model = svmmain(X_train_std, Y_train, X_test_std, Y_test)
218 | elif baslineName =='xg':
219 | groundtruth, predict, predictprob, model = xgmain(X_train_std, Y_train, X_test_std, Y_test, type=type, vali=False)
220 | elif baslineName =='dt':
221 | groundtruth, predict, predictprob, model = dtmain(X_train_std, Y_train, X_test_std, Y_test)
222 | elif baslineName =='rf':
223 | groundtruth, predict, predictprob, model = rfmain(X_train_std, Y_train, X_test_std, Y_test)
224 | elif baslineName =='gbdt':
225 | groundtruth, predict, predictprob, model = gbdtmain(X_train_std, Y_train, X_test_std, Y_test)
226 | elif baslineName =='knn':
227 | groundtruth, predict, predictprob, model = knnmain(X_train_std, Y_train, X_test_std, Y_test, type=type)
228 | else:
229 |         print('Baseline Name Error')
230 |
231 | caculateMetrix(groundtruth=groundtruth,predict=predict,baselineName=baslineName, type=type)
232 |
233 | def run_baseline(X_train, Y_train, X_test, Y_test, type='binary'):
234 | methods=['knn','lr', 'xg', 'dt', 'rf', 'gbdt']
235 | if type == 'binary':
236 | print('baslineName', '\t', 'accuracy','\t', 'precision(PPV) \t NPV \t\t', 'recall','\t', 'f1', '\t\t', '\t\t confusion Matrix')
237 | if type =='multi':
238 | print('%12s'%'baslineName', '\t\t', 'accuracy','\t', 'precision-macro \t', 'recall-macro','\t', 'f1-macro')
239 | for method in methods:
240 | evaluate(method, X_train, Y_train, X_test, Y_test, type=type)
241 |
242 | def run_baseline_2(X_train, Y_train, X_test, Y_test, type='binary'):
243 | methods=['knn', 'xg', 'dt', 'rf', 'gbdt']
244 | if type == 'binary':
245 | print('baslineName', '\t', 'accuracy','\t', 'precision(PPV) \t NPV \t\t', 'recall','\t', 'f1', '\t\t', '\t\t confusion Matrix')
246 | if type =='multi':
247 | print('%12s'%'baslineName', '\t\t', 'accuracy','\t', 'precision-macro \t', 'recall-macro','\t', 'f1-macro')
248 | for method in methods:
249 | evaluate_2(method, X_train, Y_train, X_test, Y_test, type=type)
250 |
251 |
252 | def static_interval(data, span):
253 |     """Aggregate counts over fixed-size length intervals
254 | 
255 |     Args:
256 |         data ([dataframe]): [data to summarize]
257 |         span ([int]): [interval size]
258 | 
259 |     Returns:
260 |         [list]: [list of [label, count] pairs, one per interval]
261 |     """
262 | res = []
263 | count = 0
264 | for i in range(int(len(data)/span) + 1):
265 | lable = str(i*span) + '-' + str((i+1)* span -1 )
266 | num = data[(data.length>=(i*span)) & (data.length<(i+1)*span)]['count'].sum()
267 | res += [[lable, num]]
268 | return res
269 |
270 |
271 |
272 | def getblast(train, test):
273 |
274 | table2fasta(train, '/tmp/train.fasta')
275 | table2fasta(test, '/tmp/test.fasta')
276 |
277 | cmd1 = r'diamond makedb --in /tmp/train.fasta -d /tmp/train.dmnd --quiet'
278 | cmd2 = r'diamond blastp -d /tmp/train.dmnd -q /tmp/test.fasta -o /tmp/test_fasta_results.tsv -b5 -c1 -k 1 --quiet'
279 | cmd3 = r'rm -rf /tmp/*.fasta /tmp/*.dmnd /tmp/*.tsv'
280 | print(cmd1)
281 | os.system(cmd1)
282 | print(cmd2)
283 | os.system(cmd2)
284 | res_data = pd.read_csv('/tmp/test_fasta_results.tsv', sep='\t', names=['id', 'sseqid', 'pident', 'length','mismatch','gapopen','qstart','qend','sstart','send','evalue','bitscore'])
285 | os.system(cmd3)
286 | return res_data
287 |
288 | def getblast_usedb(db, test):
289 | table2fasta(test, '/tmp/test.fasta')
290 | cmd2 = r'diamond blastp -d {0} -q /tmp/test.fasta -o /tmp/test_fasta_results.tsv -b5 -c1 -k 1 --quiet'.format(db)
291 | cmd3 = r'rm -rf /tmp/*.fasta /tmp/*.dmnd /tmp/*.tsv'
292 |
293 | print(cmd2)
294 | os.system(cmd2)
295 | res_data = pd.read_csv('/tmp/test_fasta_results.tsv', sep='\t', names=['id', 'sseqid', 'pident', 'length','mismatch','gapopen','qstart','qend','sstart','send','evalue','bitscore'])
296 | os.system(cmd3)
297 | return res_data
298 |
299 | def getblast_fasta(trainfasta, testfasta):
300 |
301 | cmd1 = r'diamond makedb --in {0} -d /tmp/train.dmnd --quiet'.format(trainfasta)
302 | cmd2 = r'diamond blastp -d /tmp/train.dmnd -q {0} -o /tmp/test_fasta_results.tsv -b8 -c1 -k 1 --quiet'.format(testfasta)
303 | cmd3 = r'rm -rf /tmp/*.fasta /tmp/*.dmnd /tmp/*.tsv'
304 | # print(cmd1)
305 | os.system(cmd1)
306 | # print(cmd2)
307 | os.system(cmd2)
308 | res_data = pd.read_csv('/tmp/test_fasta_results.tsv', sep='\t', names=['id', 'sseqid', 'pident', 'length','mismatch','gapopen','qstart','qend','sstart','send','evalue','bitscore'])
309 | os.system(cmd3)
310 | return res_data
311 |
312 |
313 | #region count distinct EC numbers
314 | def stiatistic_ec_num(eclist):
315 |     """Count the number of distinct EC numbers
316 | 
317 |     Args:
318 |         eclist (list): EC list; multi-functional enzymes are separated by ';'
319 | 
320 |     Returns:
321 |         int: number of distinct ECs in the list
322 |     """
323 |     eclist = list(eclist.flatten()) # flatten to 1-D
324 |     eclist = list(chain(*[item.split(';') for item in eclist])) # split multi-functional enzymes
325 |     eclist = [item.strip() for item in eclist] # strip whitespace
326 |     num_ecs = len(set(eclist))
327 |     return num_ecs
328 | #endregion
329 |
330 | def caculateMetrix_1(baselineName,tp, fp, tn,fn):
331 | sampleNum = tp+fp+tn+fn
332 | accuracy = (tp+tn)/sampleNum
333 | precision = tp/(tp+fp)
334 | npv = tn/(tn+fn)
335 | recall = tp/(tp+fn)
336 | f1 = 2 * (precision * recall) / (precision + recall)
337 | print('{0} \t {1:.6f} \t{2:.6f} \t\t {3:.6f} \t{4:.6f}\t {5:.6f}\t\t \t \t \t tp:{6} fp:{7} fn:{8} tn:{9}'.format(baselineName,accuracy, precision, npv, recall, f1, tp,fp,fn,tn))
338 |
339 |
340 |
341 | def get_integrated_results(res_data, train, test, baslineName):
342 |     # attach enzyme/non-enzyme labels to the alignment results
343 |     isEmzyme_dict = {v: k for k, v in zip(train.isemzyme, train.id)}
344 |     res_data['diamoion_pred'] = res_data['sseqid'].apply(lambda x: isEmzyme_dict.get(x))
345 | 
346 |     # keep only the columns needed from the alignment results
347 |     blast_res = res_data[['id', 'pident', 'bitscore', 'diamoion_pred']]
348 |
349 | X_train = train.iloc[:,12:]
350 | X_test = test.iloc[:,12:]
351 | Y_train = train.iloc[:,2].astype('int')
352 | Y_test = test.iloc[:,2].astype('int')
353 | X_train = np.array(X_train)
354 | X_test = np.array(X_test)
355 | Y_train = np.array(Y_train).flatten()
356 | Y_test = np.array(Y_test).flatten()
357 |
358 |     if baslineName == 'lr':
359 |         groundtruth, predict, predictprob, model = lrmain(X_train, Y_train, X_test, Y_test)
360 |     elif baslineName == 'svm':
361 |         groundtruth, predict, predictprob, model = svmmain(X_train, Y_train, X_test, Y_test)
362 |     elif baslineName == 'xg':
363 |         groundtruth, predict, predictprob, model = xgmain(X_train, Y_train, X_test, Y_test)
364 |     elif baslineName == 'dt':
365 |         groundtruth, predict, predictprob, model = dtmain(X_train, Y_train, X_test, Y_test)
366 |     elif baslineName == 'rf':
367 |         groundtruth, predict, predictprob, model = rfmain(X_train, Y_train, X_test, Y_test)
368 |     elif baslineName == 'gbdt':
369 |         groundtruth, predict, predictprob, model = gbdtmain(X_train, Y_train, X_test, Y_test)
370 |     else:
371 |         print('Baseline Name Error')
372 |
373 | test_res = pd.DataFrame()
374 | test_res[['id', 'name','isemzyme','ec_number']] = test[['id','name','isemzyme','ec_number']]
375 | test_res.reset_index(drop=True, inplace=True)
376 |
377 |     # merge the alignment results into the test set
378 | test_merge_res = pd.merge(test_res, blast_res, on='id', how='left')
379 | test_merge_res['xg_pred'] = predict
380 |     test_merge_res['xg_pred_prob'] = predictprob[:, 1] # probability of the positive class
381 | test_merge_res['groundtruth'] = groundtruth
382 |
383 | test_merge_res['final_pred'] = ''
384 | for index, row in test_merge_res.iterrows():
385 | if (row.diamoion_pred == True) | (row.diamoion_pred == False):
386 | with pd.option_context('mode.chained_assignment', None):
387 | test_merge_res['final_pred'][index] = row.diamoion_pred
388 | else:
389 | with pd.option_context('mode.chained_assignment', None):
390 | test_merge_res['final_pred'][index] = row.xg_pred
391 |
392 |     # compute metrics
393 | tp = len(test_merge_res[test_merge_res.groundtruth & test_merge_res.final_pred])
394 | fp = len(test_merge_res[(test_merge_res.groundtruth ==False) & (test_merge_res.final_pred)])
395 | tn = len(test_merge_res[(test_merge_res.groundtruth ==False) & (test_merge_res.final_pred ==False)])
396 | fn = len(test_merge_res[(test_merge_res.groundtruth ) & (test_merge_res.final_pred == False)])
397 | caculateMetrix_1(baslineName,tp, fp, tn,fn)
398 |
399 |
400 | def run_integrated(res_data, train, test):
401 | methods=['lr','xg', 'dt', 'rf', 'gbdt']
402 | print('baslineName', '\t\t', 'accuracy','\t', 'precision(PPV) \t NPV \t\t', 'recall','\t', 'f1', '\t\t', 'auroc','\t\t', 'auprc', '\t\t confusion Matrix')
403 | for method in methods:
404 | get_integrated_results(res_data, train, test, method)
405 |
406 |
407 |
408 | #region expand multi-functional EC numbers into a unique EC list
409 | def get_distinct_ec(ecnumbers):
410 |     """
411 |     Expand multi-functional EC numbers and return the list of unique ECs
412 |     Args:
413 |         ecnumbers: the EC_number column
414 | 
415 |     Returns: sorted list of unique ECs
416 | 
417 |     """
418 | result_list=[]
419 | for item in ecnumbers:
420 | ecarray = item.split(',')
421 | for subitem in ecarray:
422 | result_list+=[subitem.strip()]
423 | return sorted(list(set(result_list)))
424 | #endregion
425 |
426 |
427 |
428 | #region split multi-functional enzymes into multiple single-function rows
429 | def split_ecdf_to_single_lines(full_table):
430 |     """
431 |     Split multi-functional enzymes into multiple single-function rows,
432 |     additionally 1. stripping whitespace around EC numbers
433 |     and 2. expanding EC numbers to the standard 4-level format
434 |     Args:
435 |         full_table: full table containing EC numbers
436 |     Returns: expanded EC table, one EC number per row
437 |     """
438 | listres=full_table.parallel_apply(lambda x: split_ecdf_to_single_lines_pr_record(x) , axis=1)
439 | temp_li = []
440 | for res in tqdm(listres):
441 | for j in res:
442 | temp_li = temp_li + [j]
443 | resDf = pd.DataFrame(temp_li,columns=full_table.columns.values)
444 |
445 | return resDf
446 | #endregion
447 |
448 |
449 | def split_ecdf_to_single_lines_pr_record(row):
450 |     resDf = []
451 | 
452 |     if row.ec_number.strip() == '-': # non-enzyme: return directly
453 |         row.ec_number = '-'
454 |         row.ec_number = row.ec_number.strip()
455 |         resDf = row.values
456 |         return [[row.id, row.seq, row.ec_number]]
457 |     else:
458 |         ecs = row.ec_number.split(',') # split multi-functional enzymes
459 |         if len(ecs) == 1: # mono-functional enzyme: return directly
460 |             return [[row.id, row.seq, row.ec_number]]
461 |         for ec in ecs:
462 |             ec = ec.strip()
463 |             ecarray = ec.split('.') # split the four EC levels
464 |             if ecarray[3] == '': # if the last level is empty, pad it with '-'
465 |                 ec = ec + '-'
466 |             row.ec_number = ec.strip()
467 |             resDf = resDf + [[row.id, row.seq, ec]]
468 |         return resDf
469 |
470 |
471 |
472 | def load_fasta_to_table(file):
473 | """[Load fasta file to DataFrame]
474 |
475 | Args:
476 | file ([string]): [fasta file location]
477 |
478 | Returns:
479 | [DataFrame]: [loaded fasta in DF format]
480 | """
481 |     if not os.path.exists(file):
482 |         print('file not found:{0}'.format(file))
483 |         return None
484 |
485 | input_data=[]
486 | for record in SeqIO.parse(file, format='fasta'):
487 | input_data=input_data +[[record.id, str(record.seq)]]
488 | input_df = pd.DataFrame(input_data, columns=['id','seq'])
489 | return input_df
490 |
491 | def load_deepec_resluts(filepath):
492 |     """load DeepEC prediction results
493 |
494 | Args:
495 | filepath (string): deepec predicted file
496 |
497 | Returns:
498 | DataFrame: columns=['id', 'ec_deepec']
499 | """
500 | res_deepec = pd.read_csv(f'{filepath}', sep='\t',names=['id', 'ec_number'], header=0 )
501 | res_deepec.ec_number=res_deepec.apply(lambda x: x['ec_number'].replace('EC:',''), axis=1)
502 | res_deepec.columns = ['id','ec_deepec']
503 | res = []
504 | for index, group in res_deepec.groupby('id'):
505 | if len(group)==1:
506 | res = res + [[group.id.values[0], group.ec_deepec.values[0]]]
507 | else:
508 | ecs_str = ','.join(group.ec_deepec.values)
509 | res = res +[[group.id.values[0],ecs_str]]
510 | res_deepec = pd.DataFrame(res, columns=['id', 'ec_deepec'])
511 | return res_deepec
512 |
513 |
514 | #region
515 | def load_praim_res(resfile):
516 |     """[load PRIAM prediction results]
517 |     Args:
518 |         resfile ([string]): [result file]
519 |     Returns:
520 |         [DataFrame]: [parsed results]
521 |     """
522 | f = open(resfile)
523 | line = f.readline()
524 | counter =0
525 | reslist=[]
526 | lstr =''
527 | subec=[]
528 | while line:
529 | if '>' in line:
530 | if counter !=0:
531 | reslist +=[[lstr, ', '.join(subec)]]
532 | subec=[]
533 | lstr = line.replace('>', '').replace('\n', '')
534 | elif line.strip()!='':
535 | ecarray = line.split('\t')
536 | subec += [(ecarray[0].replace('#', '').replace('\n', '').replace(' ', '') )]
537 |
538 | line = f.readline()
539 | counter +=1
540 | f.close()
541 | res_priam=pd.DataFrame(reslist, columns=['id', 'ec_priam'])
542 | return res_priam
543 | #endregion
544 |
545 | def load_catfam_res(resfile):
546 | res_catfam = pd.read_csv(resfile, sep='\t', names=['id', 'ec_catfam'])
547 | return res_catfam
548 |
549 |
550 | def load_ecpred_res(resfile):
551 | res_ecpred = pd.read_csv(f'{resfile}', sep='\t', header=0)
552 | res_ecpred = res_ecpred.rename(columns={'Protein ID':'id','EC Number':'ec_ecpred','Confidence Score(max 1.0)':'pident_ecpred'})
553 | res_ecpred['ec_ecpred']= res_ecpred.ec_ecpred.apply(lambda x : '-' if x=='non Enzyme' else x)
554 | return res_ecpred
--------------------------------------------------------------------------------
/tools/uniprottool.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Zhenkun Shi
3 | Date: 2023-07-17 00:30:14
4 | LastEditors: Zhenkun Shi
5 | LastEditTime: 2023-07-17 00:49:11
6 | FilePath: /ECRECer/tools/uniprottool.py
7 | Description:
8 |
9 | Copyright (c) 2023 by tibd, All Rights Reserved.
10 | '''
11 |
12 |
13 | from Bio import SeqIO
14 | import gzip
15 | import re
16 | from tqdm import tqdm
17 | import time
18 | import sys,os
19 | sys.path.insert(0, os.path.dirname(os.path.realpath('__file__')))
20 | sys.path.insert(1,'../')
21 | import config as cfg
22 | import pandas as pd
23 | import requests
24 | from requests.adapters import HTTPAdapter, Retry
25 |
26 | #region read data from gzip
27 | def read_file_from_gzip(file_in_path, file_out_path, extract_type, save_file_type='tsv'):
28 |     """Parse records from the raw gzipped UniProt file
29 | 
30 |     Args:
31 |         file_in_path ([string]): [input file path]
32 |         file_out_path ([string]): [output file path]
33 |         extract_type ([string]): [type of records to extract: with_ec, without_ec, full]
34 |     """
35 | if save_file_type == 'feather':
36 | outpath = file_out_path
37 | file_out_path = cfg.TEMP_DIR +'temprecords.tsv'
38 |
39 | table_head = [ 'id',
40 | 'name',
41 | 'isenzyme',
42 | 'isMultiFunctional',
43 | 'functionCounts',
44 | 'ec_number',
45 | 'ec_specific_level',
46 | 'date_integraged',
47 | 'date_sequence_update',
48 | 'date_annotation_update',
49 | 'seq',
50 | 'seqlength'
51 | ]
52 | counter = 0
53 | saver = 0
54 | file_write_obj = open(file_out_path,'w')
55 | with gzip.open(file_in_path, "rt") as handle:
56 | file_write_obj.writelines('\t'.join(table_head))
57 | file_write_obj.writelines('\n')
58 | for record in tqdm( SeqIO.parse(handle, 'swiss'), position=1, leave=True):
59 | res = process_record(record, extract_type= extract_type)
60 | counter+=1
61 | if counter %10==0:
62 | file_write_obj.flush()
63 | # if saver%100000==0:
64 | # print(saver)
65 | if len(res) >0:
66 | saver +=1
67 | file_write_obj.writelines('\t'.join(map(str,res)))
68 | file_write_obj.writelines('\n')
69 | else:
70 | continue
71 | file_write_obj.close()
72 |
73 | if save_file_type == 'feather':
74 | indata = pd.read_csv(cfg.TEMP_DIR +'temprecords.tsv', sep='\t')
75 | indata.to_feather(outpath)
76 |
77 |
78 | #endregion
79 |
80 | #region extract a single record with EC annotation
81 | def process_record(record, extract_type='with_ec'):
82 |     """
83 |     Extract a single record with EC annotation
84 |     Args:
85 |         record ([type]): record node from UniProt
86 |         extract_type (string, optional): extraction type, one of with_ec, without_ec, full. Defaults to with_ec.
87 |     Returns:
88 |         [list]: parsed fields, or [] when the record is filtered out
89 |     """
90 |
91 | description = record.description
92 |     isEnzyme = 'EC=' in description # records with an EC number are treated as enzymes, all others as non-enzymes
93 | isMultiFunctional = False
94 | functionCounts = 0
95 | ec_specific_level =0
96 |
97 |
98 | if isEnzyme:
99 | ec = str(re.findall(r"EC=[0-9,.\-;]*",description)
100 | ).replace('EC=','').replace('\'','').replace(']','').replace('[','').replace(';','').strip()
101 |
102 |         # count the enzyme's functions
103 | isMultiFunctional = ',' in ec
104 | functionCounts = ec.count(',') + 1
105 |
106 |
107 |         # - mono-functional enzyme
108 | if not isMultiFunctional:
109 | levelCount = ec.count('-')
110 | ec_specific_level = 4-levelCount
111 |
112 |     else: # - multi-functional enzyme
113 | ecarray = ec.split(',')
114 | for subec in ecarray:
115 | current_ec_level = 4- subec.count('-')
116 | if ec_specific_level < current_ec_level:
117 | ec_specific_level = current_ec_level
118 | else:
119 | ec = '-'
120 |
121 | id = record.id.strip()
122 | name = record.name.strip()
123 | seq = record.seq.strip()
124 | date_integrated = record.annotations.get('date').strip()
125 | date_sequence_update = record.annotations.get('date_last_sequence_update').strip()
126 | date_annotation_update = record.annotations.get('date_last_annotation_update').strip()
127 | seqlength = len(seq)
128 | res = [id, name, isEnzyme, isMultiFunctional, functionCounts, ec,ec_specific_level, date_integrated, date_sequence_update, date_annotation_update, seq, seqlength]
129 |
130 | if extract_type == 'full':
131 | return res
132 |
133 | if extract_type == 'with_ec':
134 | if isEnzyme:
135 | return res
136 | else:
137 | return []
138 |
139 | if extract_type == 'without_ec':
140 | if isEnzyme:
141 | return []
142 | else:
143 | return res
144 | #endregion
145 |
146 | def run_exact_task(infile, outfile):
147 | start = time.process_time()
148 | extract_type ='full'
149 | read_file_from_gzip(file_in_path=infile, file_out_path=outfile, extract_type=extract_type)
150 | end = time.process_time()
151 |     print('finished, time used %6.3f s' % (end - start))
152 |
153 | def get_next_link(headers):
154 | re_next_link = re.compile(r'<(.+)>; rel="next"')
155 | if "Link" in headers:
156 | match = re_next_link.match(headers["Link"])
157 | if match:
158 | return match.group(1)
159 |
160 | def get_batch(batch_url, session):
161 | while batch_url:
162 | response = session.get(batch_url)
163 | response.raise_for_status()
164 | total = response.headers["x-total-results"]
165 | yield response, total
166 | batch_url = get_next_link(response.headers)
167 |
168 | def get_batch_data_from_uniprot_rest_api(url):
169 | # url = 'https://rest.uniprot.org/uniprotkb/search?fields=accession%2Ccc_interaction&format=tsv&query=Insulin%20AND%20%28reviewed%3Atrue%29&size=500'
170 | res = []
171 | session = requests.Session()
172 | retries = Retry(total=5, backoff_factor=0.25, status_forcelist=[500, 502, 503, 504])
173 | session.mount("https://", HTTPAdapter(max_retries=retries))
174 |
175 | for batch, total in tqdm(get_batch(url, session)):
176 | batch_records = batch.text.splitlines()[1:]
177 | res = res + batch_records
178 |
179 | res = [item.split('\t') for item in res]
180 | return res
181 |
182 |
183 | if __name__ =="__main__":
184 | print('success')
185 | # start = time.process_time()
186 | # in_filepath_sprot = cfg.FILE_LATEST_SPROT
187 | # out_filepath_sprot = cfg.FILE_LATEST_SPROT_FEATHER
188 |
189 | # in_filepath_trembl = cfg.FILE_LATEST_TREMBL
190 | # out_filepath_trembl = cfg.FILE_LATEST_TREMBL_FEATHER
191 |
192 | # extract_type ='full'
193 |
194 |
195 | # # read_file_from_gzip(file_in_path=in_filepath_sprot, file_out_path=out_filepath_sprot, extract_type=extract_type, save_file_type='feather')
196 |
197 | # read_file_from_gzip(file_in_path=in_filepath_trembl, file_out_path=out_filepath_trembl, extract_type=extract_type, save_file_type='feather')
198 | # end = time.process_time()
199 | # print('finished use time %6.3f s' % (end - start))
200 |
201 |
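202 | 
203 |     # A minimal sketch of the paginated REST helper; the URL is the example
204 |     # query from the comment in get_batch_data_from_uniprot_rest_api:
205 |     # url = 'https://rest.uniprot.org/uniprotkb/search?fields=accession%2Ccc_interaction&format=tsv&query=Insulin%20AND%20%28reviewed%3Atrue%29&size=500'
206 |     # records = get_batch_data_from_uniprot_rest_api(url)
207 |     # print(len(records), 'records fetched')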
--------------------------------------------------------------------------------
/update_production.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5c2801b5-1a62-4abb-9030-0a6ef709cf24",
6 | "metadata": {},
7 | "source": [
8 | "# Update Production Data and Model\n",
9 | "\n",
10 | "> author: Shizhenkun \n",
11 | "> email: zhenkun.shi@tib.cas.cn \n",
12 | "> date: 2021-12-24 \n",
13 | "\n",
14 | "This file contains update codes for the production server. The update should be scheduled every eight weeks."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "id": "48326994-9d5b-4905-96ed-36fa3ed72dd9",
20 | "metadata": {},
21 | "source": [
22 | "## 1. Import packages"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "id": "96a76da5-08be-41be-b05a-cd68b6d6a789",
29 | "metadata": {
30 | "tags": []
31 | },
32 | "outputs": [
33 | {
34 | "ename": "ModuleNotFoundError",
35 | "evalue": "No module named 'joblib'",
36 | "output_type": "error",
37 | "traceback": [
38 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
39 | "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
40 | "Cell \u001b[0;32mIn[1], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mconfig\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mcfg\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfunctools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m reduce\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjoblib\u001b[39;00m\n\u001b[1;32m 9\u001b[0m sys\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./tools/\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mfunclib\u001b[39;00m\n",
41 | "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'joblib'"
42 | ]
43 | }
44 | ],
45 | "source": [
46 | "import numpy as np\n",
47 | "import pandas as pd\n",
48 | "import sys\n",
49 | "\n",
50 | "import config as cfg\n",
51 | "from functools import reduce\n",
52 | "import joblib\n",
53 | "\n",
54 | "sys.path.append(\"./tools/\")\n",
55 | "import funclib\n",
56 | "import exact_ec_from_uniprot as exactec\n",
57 | "import minitools as mtool\n",
58 | "import benchmark_common as bcommon\n",
59 | "import embedding_esm as esmebd\n",
60 | "\n",
61 | "from pandarallel import pandarallel \n",
62 | "pandarallel.initialize() \n",
63 | "import benchmark_train as btrain\n",
64 | "\n",
65 | "%load_ext autoreload\n",
66 | "%autoreload 2"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "id": "b4fe1bf3-da2e-4796-abad-0029f6e81319",
72 | "metadata": {},
73 | "source": [
74 | "## 2. Define Functions"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 2,
80 | "id": "4ffaeb29-3068-4230-a5a0-13d85cca0f01",
81 | "metadata": {
82 | "tags": []
83 | },
84 | "outputs": [],
85 | "source": [
86 | "# install axel for download dataset\n",
87 | "def install_axel():\n",
88 | " isExists = !which axel\n",
89 | " if 'axel' in str(isExists[0]):\n",
90 | " return True\n",
91 | " else:\n",
92 | " !sudo apt install axel -y\n",
93 | "\n",
94 | "# add missing '-' for ec number\n",
95 | "def refill_ec(ec): \n",
96 | " if ec == '-':\n",
97 | " return ec\n",
98 | " levelArray = ec.split('.')\n",
99 | " if levelArray[3]=='':\n",
100 | " levelArray[3] ='-'\n",
101 | " ec = '.'.join(levelArray)\n",
102 | " return ec\n",
103 | "\n",
104 | "def specific_ecs(ecstr):\n",
105 | " if '-' not in ecstr or len(ecstr)<4:\n",
106 | " return ecstr\n",
107 | " ecs = ecstr.split(',')\n",
108 | " if len(ecs)==1:\n",
109 | " return ecstr\n",
110 | " \n",
111 | " reslist=[]\n",
112 | " \n",
113 | " for ec in ecs:\n",
114 | " recs = ecs.copy()\n",
115 | " recs.remove(ec)\n",
116 | " ecarray = np.array([x.split('.') for x in recs])\n",
117 | " \n",
118 | " if '-' not in ec:\n",
119 | " reslist +=[ec]\n",
120 | " continue\n",
121 | " linearray= ec.split('.')\n",
122 | " if linearray[1] == '-':\n",
123 | " #l1 in l1s and l2 not empty\n",
124 | " if (linearray[0] in ecarray[:,0]) and (len(set(ecarray[:,0]) - set({'-'}))>0):\n",
125 | " continue\n",
126 | " if linearray[2] == '-':\n",
127 | " # l1, l2 in l1s l2s, l3 not empty\n",
128 | " if (linearray[0] in ecarray[:,0]) and (linearray[1] in ecarray[:,1]) and (len(set(ecarray[:,2]) - set({'-'}))>0):\n",
129 | " continue\n",
130 | " if linearray[3] == '-':\n",
131 | " # l1, l2, l3 in l1s l2s l3s, l4 not empty\n",
132 | " if (linearray[0] in ecarray[:,0]) and (linearray[1] in ecarray[:,1]) and (linearray[2] in ecarray[:,2]) and (len(set(ecarray[:,3]) - set({'-'}))>0):\n",
133 | " continue\n",
134 | " \n",
135 | " reslist +=[ec]\n",
136 | " return ','.join(reslist)\n",
137 | "\n",
138 | "#format ec\n",
139 | "def format_ec(ecstr):\n",
140 | " ecArray= ecstr.split(',')\n",
141 | " ecArray=[x.strip() for x in ecArray] #strip blank\n",
142 | " ecArray=[refill_ec(x) for x in ecArray] #format ec to full\n",
143 | " ecArray = list(set(ecArray)) # remove duplicates\n",
144 | " \n",
145 | " return ','.join(ecArray)"
146 | ]
147 | },
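148 | {
149 | "cell_type": "markdown",
150 | "id": "f0e1d2c3-0000-4000-8000-000000000001",
151 | "metadata": {},
152 | "source": [
153 | "A quick, hypothetical sanity check for the helpers above (added for illustration, not part of the original update pipeline): `refill_ec` completes a truncated EC, `specific_ecs` drops a partial EC that a more specific sibling already covers, and `format_ec` de-duplicates after trimming."
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": null,
159 | "id": "f0e1d2c3-0000-4000-8000-000000000002",
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "# hypothetical examples, added for illustration only\n",
164 | "assert refill_ec('2.7.1.') == '2.7.1.-' # pad the trailing empty level\n",
165 | "assert specific_ecs('1.1.1.1,1.1.-.-') == '1.1.1.1' # drop the covered partial EC\n",
166 | "assert format_ec('1.1.1.1, 1.1.1.1') == '1.1.1.1' # trim + de-duplicate\n"
167 | ]
168 | },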
148 | {
149 | "cell_type": "markdown",
150 | "id": "b5e4912f-59f0-4555-9172-edcb548056e1",
151 | "metadata": {},
152 | "source": [
153 | "## 3. Download latest data from unisprot"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": 8,
159 | "id": "a53aba9e-9232-44a2-a554-8f670e536fb0",
160 | "metadata": {
161 | "tags": []
162 | },
163 | "outputs": [],
164 | "source": [
165 | "# download location ./tmp\n",
166 | "\n",
167 | "# ! mv $cfg.DATADIR'uniprot_sprot_latest.dat.gz' $cfg.TEMPDIR$currenttime'_uniprot_sprot_latest.dat.gz'\n",
168 | "# install_axel()\n",
169 | "! axel -n 10 https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.dat.gz -o ./data/uniprot_sprot_latest.dat.gz -q -c"
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "id": "5d7df6ac-7247-4045-bca2-58277bb61d24",
175 | "metadata": {},
176 | "source": [
177 | "## 4. Preprocessing"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 9,
183 | "id": "afee4738-dc8e-4637-8d7f-b288035ab3ff",
184 | "metadata": {},
185 | "outputs": [
186 | {
187 | "name": "stderr",
188 | "output_type": "stream",
189 | "text": [
190 | "566996it [04:36, 2050.78it/s]\n"
191 | ]
192 | },
193 | {
194 | "name": "stdout",
195 | "output_type": "stream",
196 | "text": [
197 | "finished use time 275.733 s\n"
198 | ]
199 | }
200 | ],
201 | "source": [
202 | "exactec.run_exact_task(infile=cfg.DATADIR+'uniprot_sprot_latest.dat.gz', outfile=cfg.DATADIR+'sprot_latest.tsv')\n",
203 | "\n",
204 | "#加载数据并转换时间格式\n",
205 | "sprot_latest = pd.read_csv(cfg.DATADIR+'sprot_latest.tsv', sep='\\t',header=0) #读入文件\n",
206 | "sprot_latest = mtool.convert_DF_dateTime(inputdf = sprot_latest)\n",
207 | "\n",
208 | "sprot_latest.drop_duplicates(subset=['seq'], keep='first', inplace=True)\n",
209 | "sprot_latest.reset_index(drop=True, inplace=True)\n",
210 | "\n",
211 | "#sprot_latest format EC\n",
212 | "sprot_latest['ec_number'] = sprot_latest.ec_number.parallel_apply(lambda x: format_ec(x))\n",
213 | "sprot_latest['ec_number'] = sprot_latest.ec_number.parallel_apply(lambda x: specific_ecs(x))\n",
214 | "sprot_latest['functionCounts'] = sprot_latest.ec_number.parallel_apply(lambda x: 0 if x=='-' else len(x.split(',')))\n",
215 | "\n",
216 | "# Trim Strging\n",
217 | "with pd.option_context('mode.chained_assignment', None):\n",
218 | " sprot_latest.ec_number = sprot_latest.ec_number.parallel_apply(lambda x : str(x).strip()) #ec trim\n",
219 | " sprot_latest.seq = sprot_latest.seq.parallel_apply(lambda x : str(x).strip()) #seq trim\n",
220 | "\n",
221 | "sprot_latest.to_feather(cfg.DATADIR + 'latest_sprot.feather')\n"
222 | ]
223 | },
224 | {
225 | "cell_type": "markdown",
226 | "id": "4ef3c2fa-5187-404d-9df3-bbc657287910",
227 | "metadata": {},
228 | "source": [
229 | "## 5. Caculation Features"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": 13,
235 | "id": "d4d7fc8a-65ad-4248-a2d2-649e0b813745",
236 | "metadata": {},
237 | "outputs": [
238 | {
239 | "name": "stdout",
240 | "output_type": "stream",
241 | "text": [
242 | "train size: 478954\n"
243 | ]
244 | }
245 | ],
246 | "source": [
247 | "train= pd.read_feather(cfg.DATADIR + 'latest_sprot.feather')\n",
248 | "print('train size: {0}'.format(len(train)))"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "id": "e5fe8cb0-9e00-4f0c-9045-bef4d106204e",
255 | "metadata": {},
256 | "outputs": [
257 | {
258 | "name": "stdout",
259 | "output_type": "stream",
260 | "text": [
261 | "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_rep0.feather': No such file or directory\n",
262 | "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_rep32.feather': No such file or directory\n",
263 | "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_rep33.feather': No such file or directory\n",
264 | "mv: cannot stat '/home/shizhenkun/codebase/DMLF/data/sprot_latest_unirep.feather': No such file or directory\n",
265 | "Transferred model to GPU\n"
266 | ]
267 | },
268 | {
269 | "name": "stderr",
270 | "output_type": "stream",
271 | "text": [
272 | " 10%|█████████ | 47813/478954 [1:00:08<12:13:55, 9.79it/s]"
273 | ]
274 | }
275 | ],
276 | "source": [
277 | "! mv $cfg.DATADIR'sprot_latest_rep0.feather' $cfg.DATADIR'featureBank/sprot_latest_rep0.feather'\n",
278 | "! mv $cfg.DATADIR'sprot_latest_rep32.feather' $cfg.DATADIR'featureBank/sprot_latest_rep32.feather'\n",
279 | "! mv $cfg.DATADIR'sprot_latest_rep33.feather' $cfg.DATADIR'featureBank/sprot_latest_rep33.feather'\n",
280 | "! mv $cfg.DATADIR'sprot_latest_unirep.feather' $cfg.DATADIR'featureBank/sprot_latest_unirep.feather'\n",
281 | "\n",
282 | "# !pip install fair-esm\n",
283 | "tr_rep0, tr_rep32, tr_rep33 = esmebd.get_rep_multi_sequence(sequences=train, model='esm1b_t33_650M_UR50S',seqthres=1022)\n",
284 | "tr_rep0.to_feather(cfg.DATADIR + 'sprot_latest_rep0.feather')\n",
285 | "tr_rep32.to_feather(cfg.DATADIR + 'sprot_latest_rep32.feather')\n",
286 | "tr_rep33.to_feather(cfg.DATADIR + 'sprot_latest_rep33.feather')\n",
287 | "\n"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": 16,
293 | "id": "1ac61a46-8250-4852-8dcd-96453ba52e2d",
294 | "metadata": {},
295 | "outputs": [
296 | {
297 | "data": {
298 | "text/html": [
299 | "\n",
300 | "\n",
313 | "
\n",
314 | " \n",
315 | " \n",
316 | " | \n",
317 | " id | \n",
318 | " f1 | \n",
319 | " f2 | \n",
320 | " f3 | \n",
321 | " f4 | \n",
322 | " f5 | \n",
323 | " f6 | \n",
324 | " f7 | \n",
325 | " f8 | \n",
326 | " f9 | \n",
327 | " ... | \n",
328 | " f1271 | \n",
329 | " f1272 | \n",
330 | " f1273 | \n",
331 | " f1274 | \n",
332 | " f1275 | \n",
333 | " f1276 | \n",
334 | " f1277 | \n",
335 | " f1278 | \n",
336 | " f1279 | \n",
337 | " f1280 | \n",
338 | "
\n",
339 | " \n",
340 | " \n",
341 | " \n",
342 | " 0 | \n",
343 | " P60995 | \n",
344 | " 0.170989 | \n",
345 | " 0.025657 | \n",
346 | " 0.025744 | \n",
347 | " 0.020174 | \n",
348 | " -0.076020 | \n",
349 | " -0.191066 | \n",
350 | " 0.085143 | \n",
351 | " 0.504738 | \n",
352 | " -0.036427 | \n",
353 | " ... | \n",
354 | " 0.210663 | \n",
355 | " -0.102772 | \n",
356 | " 0.087211 | \n",
357 | " -0.048700 | \n",
358 | " 0.193667 | \n",
359 | " 0.169694 | \n",
360 | " 0.027922 | \n",
361 | " -0.060504 | \n",
362 | " 0.285435 | \n",
363 | " -0.069696 | \n",
364 | "
\n",
365 | " \n",
366 | " 1 | \n",
367 | " P02396 | \n",
368 | " -0.111111 | \n",
369 | " 0.002945 | \n",
370 | " 0.018040 | \n",
371 | " 0.028491 | \n",
372 | " -0.006897 | \n",
373 | " -0.017328 | \n",
374 | " -0.033452 | \n",
375 | " -0.107921 | \n",
376 | " -0.010232 | \n",
377 | " ... | \n",
378 | " 0.037690 | \n",
379 | " -0.116853 | \n",
380 | " -0.023483 | \n",
381 | " 0.064209 | \n",
382 | " 0.017497 | \n",
383 | " 0.096947 | \n",
384 | " 0.004760 | \n",
385 | " 0.000524 | \n",
386 | " 0.056006 | \n",
387 | " 0.069281 | \n",
388 | "
\n",
389 | " \n",
390 | " 2 | \n",
391 | " P02362 | \n",
392 | " -0.188286 | \n",
393 | " -0.007029 | \n",
394 | " 0.006670 | \n",
395 | " 0.029186 | \n",
396 | " -0.002706 | \n",
397 | " 0.045382 | \n",
398 | " 0.045775 | \n",
399 | " 0.001362 | \n",
400 | " -0.016976 | \n",
401 | " ... | \n",
402 | " 0.034140 | \n",
403 | " -0.077455 | \n",
404 | " -0.041554 | \n",
405 | " 0.071670 | \n",
406 | " -0.001690 | \n",
407 | " 0.097198 | \n",
408 | " 0.019849 | \n",
409 | " 0.001559 | \n",
410 | " -0.008701 | \n",
411 | " 0.083810 | \n",
412 | "
\n",
413 | " \n",
414 | " 3 | \n",
415 | " P02565 | \n",
416 | " -0.026709 | \n",
417 | " -0.000562 | \n",
418 | " 0.013712 | \n",
419 | " 0.003235 | \n",
420 | " -0.038873 | \n",
421 | " 0.023068 | \n",
422 | " -0.035896 | \n",
423 | " -0.070689 | \n",
424 | " -0.016175 | \n",
425 | " ... | \n",
426 | " 0.015740 | \n",
427 | " -0.104805 | \n",
428 | " -0.004712 | \n",
429 | " 0.036699 | \n",
430 | " 0.034201 | \n",
431 | " 0.067125 | \n",
432 | " -0.000467 | \n",
433 | " 0.003438 | \n",
434 | " -0.048334 | \n",
435 | " 0.014916 | \n",
436 | "
\n",
437 | " \n",
438 | " 4 | \n",
439 | " P02827 | \n",
440 | " -0.117281 | \n",
441 | " 0.004121 | \n",
442 | " 0.021203 | \n",
443 | " 0.024243 | \n",
444 | " -0.045509 | \n",
445 | " 0.132612 | \n",
446 | " -0.007418 | \n",
447 | " 0.035875 | \n",
448 | " -0.022678 | \n",
449 | " ... | \n",
450 | " -0.003237 | \n",
451 | " -0.081184 | \n",
452 | " -0.021881 | \n",
453 | " 0.024642 | \n",
454 | " 0.038222 | \n",
455 | " 0.041906 | \n",
456 | " 0.009969 | \n",
457 | " 0.000780 | \n",
458 | " -0.109430 | \n",
459 | " 0.042850 | \n",
460 | "
\n",
461 | " \n",
462 | " ... | \n",
463 | " ... | \n",
464 | " ... | \n",
465 | " ... | \n",
466 | " ... | \n",
467 | " ... | \n",
468 | " ... | \n",
469 | " ... | \n",
470 | " ... | \n",
471 | " ... | \n",
472 | " ... | \n",
473 | " ... | \n",
474 | " ... | \n",
475 | " ... | \n",
476 | " ... | \n",
477 | " ... | \n",
478 | " ... | \n",
479 | " ... | \n",
480 | " ... | \n",
481 | " ... | \n",
482 | " ... | \n",
483 | " ... | \n",
484 | "
\n",
485 | " \n",
486 | " 478949 | \n",
487 | " C0P9J6 | \n",
488 | " -0.092407 | \n",
489 | " 0.010308 | \n",
490 | " 0.022830 | \n",
491 | " 0.019094 | \n",
492 | " -0.056988 | \n",
493 | " 0.064407 | \n",
494 | " -0.032028 | \n",
495 | " 0.017303 | \n",
496 | " -0.020415 | \n",
497 | " ... | \n",
498 | " -0.028978 | \n",
499 | " -0.085610 | \n",
500 | " -0.015938 | \n",
501 | " 0.048861 | \n",
502 | " 0.018688 | \n",
503 | " 0.082718 | \n",
504 | " 0.016897 | \n",
505 | " -0.003126 | \n",
506 | " -0.093333 | \n",
507 | " 0.065130 | \n",
508 | "
\n",
509 | " \n",
510 | " 478950 | \n",
511 | " C0PTH8 | \n",
512 | " -0.130944 | \n",
513 | " 0.006237 | \n",
514 | " 0.021503 | \n",
515 | " 0.022638 | \n",
516 | " -0.044376 | \n",
517 | " 0.128483 | \n",
518 | " 0.023310 | \n",
519 | " 0.020793 | \n",
520 | " -0.019937 | \n",
521 | " ... | \n",
522 | " -0.005362 | \n",
523 | " -0.079125 | \n",
524 | " -0.017278 | \n",
525 | " 0.019771 | \n",
526 | " 0.032946 | \n",
527 | " 0.054444 | \n",
528 | " 0.021701 | \n",
529 | " -0.001576 | \n",
530 | " -0.117068 | \n",
531 | " 0.037637 | \n",
532 | "
\n",
533 | " \n",
534 | " 478951 | \n",
535 | " D5KXD2 | \n",
536 | " -0.098420 | \n",
537 | " 0.001806 | \n",
538 | " 0.023656 | \n",
539 | " 0.018481 | \n",
540 | " -0.055579 | \n",
541 | " 0.117249 | \n",
542 | " 0.030040 | \n",
543 | " 0.014601 | \n",
544 | " -0.021143 | \n",
545 | " ... | \n",
546 | " 0.004568 | \n",
547 | " -0.075909 | \n",
548 | " -0.011751 | \n",
549 | " 0.036392 | \n",
550 | " 0.019632 | \n",
551 | " 0.073973 | \n",
552 | " 0.020997 | \n",
553 | " -0.001993 | \n",
554 | " -0.101203 | \n",
555 | " 0.055563 | \n",
556 | "
\n",
557 | " \n",
558 | " 478952 | \n",
559 | " A0A2Z4HPZ0 | \n",
560 | " -0.123796 | \n",
561 | " 0.002207 | \n",
562 | " 0.021613 | \n",
563 | " 0.022929 | \n",
564 | " -0.055302 | \n",
565 | " 0.112782 | \n",
566 | " 0.004356 | \n",
567 | " 0.053342 | \n",
568 | " -0.014881 | \n",
569 | " ... | \n",
570 | " 0.000617 | \n",
571 | " -0.074901 | \n",
572 | " -0.015275 | \n",
573 | " 0.033363 | \n",
574 | " 0.018851 | \n",
575 | " 0.058429 | \n",
576 | " 0.019783 | \n",
577 | " -0.003662 | \n",
578 | " -0.126644 | \n",
579 | " 0.052289 | \n",
580 | "
\n",
581 | " \n",
582 | " 478953 | \n",
583 | " W7N2P0 | \n",
584 | " -0.207070 | \n",
585 | " 0.000797 | \n",
586 | " 0.016863 | \n",
587 | " 0.026736 | \n",
588 | " 0.013239 | \n",
589 | " -0.061995 | \n",
590 | " 0.060844 | \n",
591 | " 0.069271 | \n",
592 | " -0.026021 | \n",
593 | " ... | \n",
594 | " 0.012974 | \n",
595 | " -0.110873 | \n",
596 | " -0.028724 | \n",
597 | " 0.059315 | \n",
598 | " 0.034717 | \n",
599 | " 0.052476 | \n",
600 | " 0.032170 | \n",
601 | " 0.007732 | \n",
602 | " 0.017107 | \n",
603 | " 0.083767 | \n",
604 | "
\n",
605 | " \n",
606 | "
\n",
607 | "
478954 rows × 1281 columns
\n",
608 | "
"
609 | ],
610 | "text/plain": [
611 | " id f1 f2 f3 f4 f5 \\\n",
612 | "0 P60995 0.170989 0.025657 0.025744 0.020174 -0.076020 \n",
613 | "1 P02396 -0.111111 0.002945 0.018040 0.028491 -0.006897 \n",
614 | "2 P02362 -0.188286 -0.007029 0.006670 0.029186 -0.002706 \n",
615 | "3 P02565 -0.026709 -0.000562 0.013712 0.003235 -0.038873 \n",
616 | "4 P02827 -0.117281 0.004121 0.021203 0.024243 -0.045509 \n",
617 | "... ... ... ... ... ... ... \n",
618 | "478949 C0P9J6 -0.092407 0.010308 0.022830 0.019094 -0.056988 \n",
619 | "478950 C0PTH8 -0.130944 0.006237 0.021503 0.022638 -0.044376 \n",
620 | "478951 D5KXD2 -0.098420 0.001806 0.023656 0.018481 -0.055579 \n",
621 | "478952 A0A2Z4HPZ0 -0.123796 0.002207 0.021613 0.022929 -0.055302 \n",
622 | "478953 W7N2P0 -0.207070 0.000797 0.016863 0.026736 0.013239 \n",
623 | "\n",
624 | " f6 f7 f8 f9 ... f1271 f1272 \\\n",
625 | "0 -0.191066 0.085143 0.504738 -0.036427 ... 0.210663 -0.102772 \n",
626 | "1 -0.017328 -0.033452 -0.107921 -0.010232 ... 0.037690 -0.116853 \n",
627 | "2 0.045382 0.045775 0.001362 -0.016976 ... 0.034140 -0.077455 \n",
628 | "3 0.023068 -0.035896 -0.070689 -0.016175 ... 0.015740 -0.104805 \n",
629 | "4 0.132612 -0.007418 0.035875 -0.022678 ... -0.003237 -0.081184 \n",
630 | "... ... ... ... ... ... ... ... \n",
631 | "478949 0.064407 -0.032028 0.017303 -0.020415 ... -0.028978 -0.085610 \n",
632 | "478950 0.128483 0.023310 0.020793 -0.019937 ... -0.005362 -0.079125 \n",
633 | "478951 0.117249 0.030040 0.014601 -0.021143 ... 0.004568 -0.075909 \n",
634 | "478952 0.112782 0.004356 0.053342 -0.014881 ... 0.000617 -0.074901 \n",
635 | "478953 -0.061995 0.060844 0.069271 -0.026021 ... 0.012974 -0.110873 \n",
636 | "\n",
637 | " f1273 f1274 f1275 f1276 f1277 f1278 f1279 \\\n",
638 | "0 0.087211 -0.048700 0.193667 0.169694 0.027922 -0.060504 0.285435 \n",
639 | "1 -0.023483 0.064209 0.017497 0.096947 0.004760 0.000524 0.056006 \n",
640 | "2 -0.041554 0.071670 -0.001690 0.097198 0.019849 0.001559 -0.008701 \n",
641 | "3 -0.004712 0.036699 0.034201 0.067125 -0.000467 0.003438 -0.048334 \n",
642 | "4 -0.021881 0.024642 0.038222 0.041906 0.009969 0.000780 -0.109430 \n",
643 | "... ... ... ... ... ... ... ... \n",
644 | "478949 -0.015938 0.048861 0.018688 0.082718 0.016897 -0.003126 -0.093333 \n",
645 | "478950 -0.017278 0.019771 0.032946 0.054444 0.021701 -0.001576 -0.117068 \n",
646 | "478951 -0.011751 0.036392 0.019632 0.073973 0.020997 -0.001993 -0.101203 \n",
647 | "478952 -0.015275 0.033363 0.018851 0.058429 0.019783 -0.003662 -0.126644 \n",
648 | "478953 -0.028724 0.059315 0.034717 0.052476 0.032170 0.007732 0.017107 \n",
649 | "\n",
650 | " f1280 \n",
651 | "0 -0.069696 \n",
652 | "1 0.069281 \n",
653 | "2 0.083810 \n",
654 | "3 0.014916 \n",
655 | "4 0.042850 \n",
656 | "... ... \n",
657 | "478949 0.065130 \n",
658 | "478950 0.037637 \n",
659 | "478951 0.055563 \n",
660 | "478952 0.052289 \n",
661 | "478953 0.083767 \n",
662 | "\n",
663 | "[478954 rows x 1281 columns]"
664 | ]
665 | },
666 | "execution_count": 16,
667 | "metadata": {},
668 | "output_type": "execute_result"
669 | }
670 | ],
671 | "source": [
672 | "tr_rep0"
673 | ]
674 | },
675 | {
676 | "cell_type": "code",
677 | "execution_count": 18,
678 | "id": "c9b2137c-ae11-42f7-a8da-fa699cca7520",
679 | "metadata": {},
680 | "outputs": [],
681 | "source": [
682 | "train_esm_latest = pd.read_feather(cfg.DATADIR + 'sprot_latest_rep32.feather')\n",
683 | "train_esm_latest = train.merge(train_esm_latest, on='id', how='left')"
684 | ]
685 | },
686 | {
687 | "cell_type": "markdown",
688 | "id": "5da8140c-e00f-413b-b7c8-94607a44ea59",
689 | "metadata": {},
690 | "source": [
691 | "## 6. Split X Y"
692 | ]
693 | },
694 | {
695 | "cell_type": "code",
696 | "execution_count": 19,
697 | "id": "a5ee3dbe-519c-471e-aa60-feecf84db6d0",
698 | "metadata": {},
699 | "outputs": [
700 | {
701 | "name": "stderr",
702 | "output_type": "stream",
703 | "text": [
704 | "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 231834/231834 [05:02<00:00, 766.36it/s]\n"
705 | ]
706 | },
707 | {
708 | "name": "stdout",
709 | "output_type": "stream",
710 | "text": [
711 | "loading ec to label dict\n"
712 | ]
713 | }
714 | ],
715 | "source": [
716 | "# task 1\n",
717 | "X_train_task1 = np.array(train_esm_latest.iloc[:,12:])\n",
718 | "Y_train_task1 = np.array(train_esm_latest.isenzyme.astype('int')).flatten()\n",
719 | "train_enzyme = train_esm_latest[train_esm_latest.isenzyme].reset_index(drop=True)\n",
720 | "\n",
721 | "# task 2\n",
722 | "X_train_task2_s = np.array(train_enzyme.iloc[:,12:])\n",
723 | "Y_train_task2_s = train_enzyme.functionCounts.apply(lambda x : 0 if x==1 else 1).astype('int').values\n",
724 | "\n",
725 | "train_task2M=train_enzyme[train_enzyme.functionCounts>=2].reset_index(drop=True)\n",
726 | "X_train_task2_m = np.array(train_task2M.iloc[:,12:])\n",
727 | "Y_train_task2_m = np.array(train_task2M.functionCounts.astype('int')-2).flatten()\n",
728 | "\n",
729 | "#task 3\n",
730 | "train_set_task3= funclib.split_ecdf_to_single_lines(train_enzyme.iloc[:,np.r_[0,10,5]])\n",
731 | "train_set_task3=train_set_task3.merge(train_esm_latest.iloc[:,np.r_[0,12:1292]], on='id', how='left')\n",
732 | "\n",
733 | "#4. Loading EC Numbers\n",
734 | "print('loading ec to label dict')\n",
735 | "dict_ec_label = btrain.make_ec_label(train_label=train_set_task3['ec_number'], test_label=train_set_task3['ec_number'], file_save= cfg.FILE_EC_LABEL_DICT, force_model_update=cfg.UPDATE_MODEL)\n",
736 | "\n",
737 | "train_set_task3['ec_label']=train_set_task3.ec_number.parallel_apply(lambda x: dict_ec_label.get(x)) \n",
738 | "X_train_task3 = np.array(train_set_task3.iloc[:,3:])\n",
739 | "Y_train_task3 = np.array(train_set_task3.ec_label.astype('int')).flatten()"
740 | ]
741 | },
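742 | {
743 | "cell_type": "markdown",
744 | "id": "f0e1d2c3-0000-4000-8000-000000000003",
745 | "metadata": {},
746 | "source": [
747 | "The `iloc[:,12:]` slices above rely on the column layout written by `exact_ec_from_uniprot`: columns 0-11 hold the metadata (`id` through `seqlength`) and columns 12:1292 hold the 1280-dimensional ESM-1b embedding. A quick, optional shape check (added for illustration, under that assumed layout):"
748 | ]
749 | },
750 | {
751 | "cell_type": "code",
752 | "execution_count": null,
753 | "id": "f0e1d2c3-0000-4000-8000-000000000004",
754 | "metadata": {},
755 | "outputs": [],
756 | "source": [
757 | "# optional sanity check on the metadata/feature column split (assumed layout)\n",
758 | "assert X_train_task1.shape[1] == 1280 # 1280-dim ESM-1b features\n",
759 | "assert len(X_train_task1) == len(Y_train_task1)\n"
760 | ]
761 | },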
742 | {
743 | "cell_type": "markdown",
744 | "id": "2f5e89d3-b5f5-4897-bb23-afaa0db3d388",
745 | "metadata": {},
746 | "source": [
747 | "## 7. Train Model"
748 | ]
749 | },
750 | {
751 | "cell_type": "code",
752 | "execution_count": null,
753 | "id": "df91a62a-0eee-44b1-9b0a-65b5e60d8bc0",
754 | "metadata": {},
755 | "outputs": [],
756 | "source": []
757 | },
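758 | {
759 | "cell_type": "markdown",
760 | "id": "f0e1d2c3-0000-4000-8000-000000000005",
761 | "metadata": {},
762 | "source": [
763 | "A minimal training sketch for task 1 (enzyme vs. non-enzyme), assuming an XGBoost classifier persisted with `joblib`; the hyper-parameters and the output path are placeholders, and the repository's actual training code lives in `benchmark_train.py`."
764 | ]
765 | },
766 | {
767 | "cell_type": "code",
768 | "execution_count": null,
769 | "id": "f0e1d2c3-0000-4000-8000-000000000006",
770 | "metadata": {},
771 | "outputs": [],
772 | "source": [
773 | "# Sketch only: assumes XGBoost + joblib persistence; hyper-parameters and the\n",
774 | "# model path below are placeholders, see benchmark_train.py for the real routine.\n",
775 | "from xgboost import XGBClassifier\n",
776 | "\n",
777 | "task1_model = XGBClassifier(n_estimators=500, max_depth=6, n_jobs=-1)\n",
778 | "task1_model.fit(X_train_task1, Y_train_task1)\n",
779 | "joblib.dump(task1_model, './model/isenzyme_latest.model') # hypothetical path\n"
780 | ]
781 | },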
758 | {
759 | "cell_type": "code",
760 | "execution_count": null,
761 | "id": "a52c6fe4-94de-4b69-be93-427535de880e",
762 | "metadata": {},
763 | "outputs": [],
764 | "source": []
765 | }
766 | ],
767 | "metadata": {
768 | "kernelspec": {
769 | "display_name": "ECRECer",
770 | "language": "python",
771 | "name": "ecrecer"
772 | },
773 | "language_info": {
774 | "codemirror_mode": {
775 | "name": "ipython",
776 | "version": 3
777 | },
778 | "file_extension": ".py",
779 | "mimetype": "text/x-python",
780 | "name": "python",
781 | "nbconvert_exporter": "python",
782 | "pygments_lexer": "ipython3",
783 | "version": "3.7.10"
784 | },
785 | "widgets": {
786 | "application/vnd.jupyter.widget-state+json": {
787 | "state": {},
788 | "version_major": 2,
789 | "version_minor": 0
790 | }
791 | }
792 | },
793 | "nbformat": 4,
794 | "nbformat_minor": 5
795 | }
796 |
--------------------------------------------------------------------------------