├── LSTM_peptides ├── README.md ├── requirements.txt ├── lstm_pretrain.sh ├── lstm_pretrain_uniprot.sh ├── sample.py ├── LSTM_peptides.py ├── LSTM_peptides_fine.py └── LSTM_peptides_sample.py ├── Transformer_AA ├── README.md ├── trans_pretrain_2.sh ├── trans_pretrain_1.sh ├── trans_pretrain_all.sh ├── trans_pretrain.sh ├── load_data.py ├── predict.py ├── train_eval.py ├── finetuning.py └── model.py ├── README.md └── LICENSE /LSTM_peptides/README.md: -------------------------------------------------------------------------------- 1 | LSTM_peptides 2 | -------------------------------------------------------------------------------- /Transformer_AA/README.md: -------------------------------------------------------------------------------- 1 | Transformer_AA 2 | -------------------------------------------------------------------------------- /LSTM_peptides/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.2.2 2 | scipy==1.5.0 3 | numpy==1.19.5 4 | scikit-learn==0.23.1 5 | tensorflow==2.5.0 6 | progressbar2==3.53.1 7 | modlamp>=4.2.3 8 | -------------------------------------------------------------------------------- /Transformer_AA/trans_pretrain_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | folder="./pretrain_data/pdb_train_transformer" 5 | 6 | echo "run TRANS_peptides 2..." 7 | 8 | softfiles=$(ls $folder) 9 | 10 | model_file='final_finetuning_model.pt' 11 | 12 | for sfile in ${softfiles} 13 | do 14 | echo "process:" $folder/${sfile} 15 | 16 | python finetuning.py 2 $model_file ${sfile} 17 | 18 | echo "finish:" $folder/${sfile} 19 | 20 | done -------------------------------------------------------------------------------- /Transformer_AA/trans_pretrain_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | folder="./train_data/uniprot_train_transformer" 4 | 5 | echo "run TRANS_peptides 1..." 6 | 7 | softfiles=$(ls $folder) 8 | 9 | #model_file='model.pt' 10 | 11 | model_file='final_finetuning_model.pt' 12 | 13 | for sfile in ${softfiles} 14 | do 15 | echo "process:" $folder/${sfile} 16 | 17 | python finetuning.py 1 $model_file ${sfile} 18 | 19 | echo "finish:" $folder/${sfile} 20 | 21 | 22 | 23 | 24 | done 25 | 26 | -------------------------------------------------------------------------------- /LSTM_peptides/lstm_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #folder="/home/shun/LSTM_peptides/pdb_train_lstm" 4 | folder="/root/autodl-tmp/LSTM_peptides/pdb_train_lstm" 5 | 6 | echo "run LSTM_peptides fine_tune..." 7 | 8 | softfiles=$(ls $folder) 9 | 10 | last_file="pdb" 11 | 12 | for sfile in ${softfiles} 13 | do 14 | echo "process:" $folder/${sfile} 15 | 16 | 17 | python LSTM_peptides.py --name ${sfile%.*} --dataset $folder/${sfile} --modfile ./$last_file/checkpoint/model_epoch_22.hdf5 --epochs 30 18 | 19 | echo "finish:" $folder/${sfile} 20 | 21 | last_file=${sfile%.*} 22 | 23 | 24 | done 25 | -------------------------------------------------------------------------------- /LSTM_peptides/lstm_pretrain_uniprot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #folder="/home/shun/LSTM_peptides/pdb_train_lstm" 4 | folder="/root/autodl-tmp/LSTM_peptides/uniprot_train_lstm" 5 | 6 | echo "run LSTM_peptides fine_tune.2.." 
7 | 8 | softfiles=$(ls $folder) 9 | 10 | last_file="18_pdb_train_lstm" 11 | 12 | for sfile in ${softfiles} 13 | do 14 | echo "process:" $folder/${sfile} 15 | 16 | 17 | python LSTM_peptides.py --name ${sfile%.*} --dataset $folder/${sfile} --modfile ./$last_file/checkpoint/model_epoch_29.hdf5 --epochs 30 18 | 19 | echo "finish:" $folder/${sfile} 20 | 21 | last_file=${sfile%.*} 22 | 23 | 24 | done 25 | -------------------------------------------------------------------------------- /LSTM_peptides/sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Script to sample with different temperatures from a trained model. 5 | Need to turn off 'ask' when initializing model 6 | """ 7 | 8 | import os 9 | 10 | sample = 36209 11 | #temps = [0.5, 1., 1.5] 12 | temps = [1.25] 13 | 14 | name = '9_uniprot_train_lstm' 15 | sepoch = 29 16 | maxlen = 500 17 | 18 | pwd = '/root/autodl-tmp/LSTM_peptides/' 19 | modfile = pwd + name + '/checkpoint/model_epoch_%i.hdf5' % sepoch 20 | 21 | for t in temps: 22 | print("\nSampling %i sequences at %.1f temperature..." % (sample, t)) 23 | cmd = "python %sLSTM_peptides_sample.py --train False --modfile %s " \ 24 | "--temp %.1f --sample %i --maxlen %i --name mjs" % (pwd, modfile, t, sample, maxlen) 25 | os.system(cmd) 26 | -------------------------------------------------------------------------------- /Transformer_AA/trans_pretrain_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | folder="./train_data/uniprot_train_transformer" 4 | 5 | echo "run TRANS_peptides 1..." 6 | 7 | softfiles=$(ls $folder) 8 | 9 | model_file='model.pt' 10 | 11 | for sfile in ${softfiles} 12 | do 13 | echo "process:" $folder/${sfile} 14 | 15 | python finetuning.py 1 $model_file ${sfile} 16 | 17 | echo "finish:" $folder/${sfile} 18 | 19 | model_file='final_finetuning_model.pt' 20 | 21 | 22 | done 23 | 24 | folder="./pretrain_data/pdb_train_transformer" 25 | 26 | echo "run TRANS_peptides 2..." 27 | 28 | softfiles=$(ls $folder) 29 | 30 | model_file='final_finetuning_model.pt' 31 | 32 | for sfile in ${softfiles} 33 | do 34 | echo "process:" $folder/${sfile} 35 | 36 | python finetuning.py 2 $model_file ${sfile} 37 | 38 | echo "finish:" $folder/${sfile} 39 | 40 | done -------------------------------------------------------------------------------- /Transformer_AA/trans_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | folder="./train_data/uniprot_train_transformer" 4 | 5 | echo "run TRANS_peptides 1..." 6 | 7 | softfiles=$(ls $folder) 8 | 9 | model_file='model.pt' 10 | 11 | for sfile in ${softfiles} 12 | do 13 | echo "process:" $folder/${sfile} 14 | 15 | python finetuning.py 1 $model_file ${sfile} 16 | 17 | echo "finish:" $folder/${sfile} 18 | 19 | model_file='final_finetuning_model.pt' 20 | 21 | 22 | done 23 | 24 | folder="./pretrain_data/pdb_train_transformer" 25 | 26 | echo "run TRANS_peptides 2..." 
27 | 28 | softfiles=$(ls $folder) 29 | 30 | model_file='final_finetuning_model.pt' 31 | 32 | for sfile in ${softfiles} 33 | do 34 | echo "process:" $folder/${sfile} 35 | 36 | python finetuning.py 2 $model_file ${sfile} 37 | 38 | echo "finish:" $folder/${sfile} 39 | 40 | done -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AMPTrans-lstm 2 | Application of deep generative model discovers novel and diverse functional peptides against microbial resistance 3 | 4 | 5 | ## Requirements 6 | ```python 7 | matplotlib==3.2.2 8 | scipy==1.5.0 9 | numpy==1.19.5 10 | scikit-learn==0.23.1 11 | tensorflow==2.5.0 12 | progressbar2==3.53.1 13 | modlamp>=4.2.3 14 | 15 | ``` 16 | 17 | ## LSTM_peptides 18 | 19 | 20 | 21 | ## Transformer_AA 22 | ```python 23 | train: 24 | python train_eval.py 25 | 26 | finetune: 27 | python finetuning.py 28 | ``` 29 | ## Cite: 30 | 31 | Mao, Jiashun, Shenghui Guan, Yongqing Chen, Amir Zeb, Qingxiang Sun, Ranlan Lu, Jie Dong, Jianmin Wang, and Dongsheng Cao. "Application of a deep generative model produces novel and diverse functional peptides against microbial resistance." Computational and Structural Biotechnology Journal (2022). 32 | https://doi.org/10.1016/j.csbj.2022.12.029 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jianmin Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Transformer_AA/load_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torchtext.legacy import data 4 | import re 5 | #device = "cuda" if torch.cuda.is_available() else 'cpu' 6 | device ='cpu' 7 | 8 | def tokenizer(text): 9 | token = [tok for tok in list(text)] 10 | return token 11 | 12 | def smiles_atom_tokenizer(smi): 13 | pattern = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 14 | regex = re.compile(pattern) 15 | tokens = [token for token in regex.findall(smi)] 16 | return tokens 17 | 18 | def data_gen(train_path='./train_data',train_data='uniprot_train_data.csv',valid_data='pdb_train_data.csv'): 19 | TEXT = data.Field(tokenize=tokenizer, 20 | init_token = '', 21 | eos_token = '', 22 | lower = False, #True 23 | batch_first = True) 24 | 25 | train, val = data.TabularDataset.splits( 26 | path=train_path, 27 | train=train_data, 28 | validation=valid_data, 29 | format='csv', 30 | skip_header=True, 31 | fields=[('trg', TEXT), ('src', TEXT)]) 32 | 33 | TEXT.build_vocab(train, min_freq=2) 34 | id2vocab = TEXT.vocab.itos 35 | vocab2id = TEXT.vocab.stoi 36 | PAD_IDX = vocab2id[TEXT.pad_token] 37 | UNK_IDX = vocab2id[TEXT.unk_token] 38 | SOS_IDX = vocab2id[TEXT.init_token] 39 | EOS_IDX = vocab2id[TEXT.eos_token] 40 | 41 | #train_iter 自动shuffle, val_iter 按照sort_key排序,传入Decoder或者Encoder的sequence的长度不能超过模型中 position embedding 的 "vocabulary" size 42 | train_iter, val_iter = data.BucketIterator.splits( 43 | (train, val), 44 | batch_sizes=(8, 8), 45 | sort_key=lambda x: len(x.src), 46 | device=device) 47 | return train_iter, val_iter, id2vocab, PAD_IDX,TEXT,vocab2id,UNK_IDX 48 | 49 | ################################################################################################################### 50 | def pre_data_gen(TEXT=None,train_path ='./train_data/uniprot_train_transformer' ,train_data='1_uniprot_train_transformer.csv',valid_data='56_uniprot_train_transformer.csv'): 51 | pre_train, pre_val = data.TabularDataset.splits( 52 | path=train_path, 53 | train=train_data, 54 | validation=valid_data, 55 | format='csv', 56 | skip_header=True, 57 | fields=[('trg', TEXT), ('src', TEXT)]) 58 | 59 | pre_train_iter, pre_val_iter = data.BucketIterator.splits( 60 | (pre_train, pre_val), 61 | batch_sizes=(8, 8), 62 | sort_key=lambda x: len(x.src), 63 | device=device) 64 | return pre_train_iter, pre_val_iter -------------------------------------------------------------------------------- /Transformer_AA/predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from load_data import data_gen,pre_data_gen 4 | pre_train_iter, pre_val_iter, id2vocab, PAD_IDX,TEXT,vocab2id,UNK_IDX = data_gen() 5 | import pandas as pd 6 | from model import Encoder, Decoder, Transformer 7 | import re 8 | device = "cuda" if torch.cuda.is_available() else 'cpu' 9 | 10 | #device ='cpu' 11 | 12 | INPUT_DIM = len(id2vocab) 13 | OUTPUT_DIM = len(id2vocab) 14 | HID_DIM = 256 15 | ENC_LAYERS = 4 16 | DEC_LAYERS = 4 17 | ENC_HEADS = 2 18 | DEC_HEADS = 2 19 | ENC_PF_DIM = 512 20 | DEC_PF_DIM = 512 21 | ENC_DROPOUT = 0.1 22 | DEC_DROPOUT = 0.1 23 | N_EPOCHS = 100 24 | CLIP = 1 25 | max_length = 2000 26 | 27 | enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device,max_length) 28 | dec = 
Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device,max_length) 29 | model = Transformer(enc, dec, PAD_IDX, device).to(device) 30 | model.load_state_dict(torch.load('final_finetuning_model.pt')) 31 | model.eval() 32 | 33 | #embeding_data = "./pretrain_data/AMP_standardseq_trans.csv" 34 | embeding_data = "./final.txt" 35 | embeding_data = pd.read_csv(embeding_data, sep=',',header=0) 36 | #print(embeding_data) 37 | 38 | def smiles_atom_tokenizer(smi): 39 | pattern = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 40 | regex = re.compile(pattern) 41 | tokens = [token for token in regex.findall(smi)] 42 | return tokens 43 | 44 | for i in embeding_data["src"].values: 45 | tokens = smiles_atom_tokenizer(i) 46 | #tokens = [tok.lower() for tok in list(i)] 47 | 48 | tokens = [TEXT.init_token] + tokens + [TEXT.eos_token] 49 | 50 | src_indexes = [vocab2id.get(token, UNK_IDX) for token in tokens] 51 | src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device) 52 | src_mask = model.make_src_mask(src_tensor) 53 | 54 | with torch.no_grad(): 55 | enc_src = model.encoder(src_tensor, src_mask) 56 | 57 | #print(i,enc_src) 58 | 59 | trg_indexes = [vocab2id[TEXT.init_token]] 60 | 61 | for i in range(1000): 62 | trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device) 63 | trg_mask = model.make_trg_mask(trg_tensor) 64 | with torch.no_grad(): 65 | output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask) 66 | 67 | pred_token = output.argmax(2)[:,-1].item() 68 | trg_indexes.append(pred_token) 69 | 70 | if pred_token == vocab2id[TEXT.eos_token]: 71 | trg_indexes = trg_indexes[:-1] 72 | break 73 | 74 | trg_tokens = [id2vocab[i] for i in trg_indexes] 75 | 76 | print("".join(trg_tokens[1:])) 77 | 78 | 79 | ''' 80 | sent = '中新网9月19日电据英国媒体报道,当地时间19日,苏格兰公投结果出炉,55%选民投下反对票,对独立说“不”。在结果公布前,英国广播公司(BBC)预测,苏格兰选民以55%对45%投票反对独立。' 81 | tokens = [tok for tok in jieba.cut(sent)] 82 | tokens = [TEXT.init_token] + tokens + [TEXT.eos_token] 83 | 84 | src_indexes = [vocab2id.get(token, UNK_IDX) for token in tokens] 85 | src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device) 86 | src_mask = model.make_src_mask(src_tensor) 87 | 88 | with torch.no_grad(): 89 | enc_src = model.encoder(src_tensor, src_mask) 90 | 91 | trg_indexes = [vocab2id[TEXT.init_token]] 92 | 93 | for i in range(50): 94 | trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device) 95 | trg_mask = model.make_trg_mask(trg_tensor) 96 | with torch.no_grad(): 97 | output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask) 98 | 99 | pred_token = output.argmax(2)[:,-1].item() 100 | trg_indexes.append(pred_token) 101 | 102 | if pred_token == vocab2id[TEXT.eos_token]: 103 | break 104 | 105 | trg_tokens = [id2vocab[i] for i in trg_indexes] 106 | 107 | print(trg_tokens[1:]) 108 | ''' -------------------------------------------------------------------------------- /Transformer_AA/train_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from model import Encoder, Decoder, Transformer 10 | plt.switch_backend('agg') 11 | device = "cuda" if torch.cuda.is_available() else 'cpu' 12 | #device ='cpu' 13 | print("device:",device) 14 | 15 | from load_data import data_gen,pre_data_gen 16 | train_iter, 
val_iter, id2vocab, PAD_IDX,TEXT,vocab2id,UNK_IDX = data_gen() 17 | train_iter, val_iter = pre_data_gen(TEXT=TEXT,train_path ='./train_data', 18 | train_data='0_uniprot_train_transformer.csv',valid_data='56_uniprot_train_transformer.csv') 19 | 20 | INPUT_DIM = len(id2vocab) 21 | OUTPUT_DIM = len(id2vocab) 22 | HID_DIM = 256 23 | ENC_LAYERS = 4 24 | DEC_LAYERS = 4 25 | ENC_HEADS = 2 26 | DEC_HEADS = 2 27 | ENC_PF_DIM = 512 28 | DEC_PF_DIM = 512 29 | ENC_DROPOUT = 0.1 30 | DEC_DROPOUT = 0.1 31 | N_EPOCHS = 100 32 | CLIP = 1 33 | max_length = 2000 34 | 35 | save_pt_freq = 10 36 | 37 | enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device,max_length) 38 | dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device,max_length) 39 | 40 | model = Transformer(enc, dec, PAD_IDX, device).to(device) 41 | 42 | def initialize_weights(m): 43 | if hasattr(m, 'weight') and m.weight.dim() > 1: 44 | nn.init.xavier_uniform_(m.weight.data) 45 | model.apply(initialize_weights) 46 | 47 | optimizer = optim.Adam(model.parameters(), lr=5e-5) 48 | #we ignore the loss whenever the target token is a padding token. 49 | criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX) 50 | 51 | 52 | 53 | loss_vals = [] 54 | loss_vals_eval = [] 55 | for epoch in range(N_EPOCHS): 56 | model.train() 57 | epoch_loss= [] 58 | pbar = tqdm(train_iter) 59 | pbar.set_description("[Train Epoch {}]".format(epoch)) 60 | for batch in pbar: 61 | trg, src = batch.trg.to(device), batch.src.to(device) 62 | model.zero_grad() 63 | output, _ = model(src, trg[:,:-1]) 64 | #trg = [batch size, trg len] 65 | #output = [batch size, trg len-1, output dim] 66 | output_dim = output.shape[-1] 67 | output = output.contiguous().view(-1, output_dim) 68 | trg = trg[:,1:].contiguous().view(-1) 69 | #trg = [(trg len - 1) * batch size] 70 | #output = [(trg len - 1) * batch size, output dim] 71 | loss = criterion(output, trg) 72 | loss.backward() 73 | torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP) 74 | epoch_loss.append(loss.item()) 75 | optimizer.step() 76 | pbar.set_postfix(loss=loss.item()) 77 | loss_vals.append(np.mean(epoch_loss)) 78 | 79 | model.eval() 80 | epoch_loss_eval= [] 81 | pbar = tqdm(val_iter) 82 | pbar.set_description("[Eval Epoch {}]".format(epoch)) 83 | for batch in pbar: 84 | trg, src = batch.trg.to(device), batch.src.to(device) 85 | model.zero_grad() 86 | output, _ = model(src, trg[:,:-1]) 87 | #trg = [batch size, trg len] 88 | #output = [batch size, trg len-1, output dim] 89 | output_dim = output.shape[-1] 90 | output = output.contiguous().view(-1, output_dim) 91 | trg = trg[:,1:].contiguous().view(-1) 92 | #trg = [(trg len - 1) * batch size] 93 | #output = [(trg len - 1) * batch size, output dim] 94 | loss = criterion(output, trg) 95 | epoch_loss_eval.append(loss.item()) 96 | pbar.set_postfix(loss=loss.item()) 97 | loss_vals_eval.append(np.mean(epoch_loss_eval)) 98 | 99 | if (epoch+1)%save_pt_freq ==0: 100 | torch.save(model.state_dict(), str(epoch+1)+'_model.pt') 101 | print("save model:",str(epoch+1)+'_model.pt') 102 | 103 | torch.save(model.state_dict(), 'model.pt') 104 | print(loss_vals,loss_vals_eval) 105 | 106 | l1, = plt.plot(np.linspace(1, N_EPOCHS, N_EPOCHS).astype(int), loss_vals) 107 | l2, = plt.plot(np.linspace(1, N_EPOCHS, N_EPOCHS).astype(int), loss_vals_eval) 108 | plt.legend(handles=[l1,l2],labels=['Train loss','Eval loss'],loc='best') 109 | filename = "trans.jpg" 110 | plt.savefig(filename) 
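Note on the training loop in train_eval.py above: the decoder is fed `trg[:,:-1]` and its output is scored against `trg[:,1:]` (standard teacher forcing), with both tensors flattened before `CrossEntropyLoss`. The snippet below is a minimal, self-contained sketch of just that reshaping step — it is not part of the repository, and the tensor sizes and `PAD_IDX` value are made up for illustration. The same pattern is reused in finetuning.py, which follows.

```python
import torch
import torch.nn as nn

batch_size, trg_len, vocab_size, PAD_IDX = 8, 12, 30, 1   # hypothetical sizes
# what model(src, trg[:, :-1]) returns: [batch size, trg len - 1, output dim]
output = torch.randn(batch_size, trg_len - 1, vocab_size)
# full target sequence including start/end tokens: [batch size, trg len]
trg = torch.randint(0, vocab_size, (batch_size, trg_len))

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
loss = criterion(
    output.contiguous().view(-1, vocab_size),   # [(trg len - 1) * batch size, output dim]
    trg[:, 1:].contiguous().view(-1),           # [(trg len - 1) * batch size]
)
print(loss.item())
```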
-------------------------------------------------------------------------------- /Transformer_AA/finetuning.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from tqdm import tqdm 8 | from model import Encoder, Decoder, Transformer 9 | import sys 10 | device = "cuda" if torch.cuda.is_available() else 'cpu' 11 | #device ='cpu' 12 | 13 | print("device:",device) 14 | 15 | v2 = int(sys.argv[1]) 16 | model_file = sys.argv[2] 17 | train_data = sys.argv[3] 18 | 19 | from load_data import data_gen,pre_data_gen 20 | pre_train_iter, pre_val_iter, id2vocab, PAD_IDX,TEXT,vocab2id,UNK_IDX = data_gen() 21 | 22 | if v2==1: 23 | #'1_uniprot_train_transformer.csv' 24 | pre_train_iter, pre_val_iter = pre_data_gen(TEXT=TEXT,train_path ='./train_data/uniprot_train_transformer' , 25 | train_data=train_data,valid_data='56_uniprot_train_transformer.csv') 26 | else: 27 | #'0_pdb_train_transformer.csv' 28 | pre_train_iter, pre_val_iter = pre_data_gen(TEXT=TEXT,train_path ='./pretrain_data/pdb_train_transformer' , 29 | train_data=train_data,valid_data='18_pdb_train_transformer.csv') 30 | 31 | INPUT_DIM = len(id2vocab) 32 | OUTPUT_DIM = len(id2vocab) 33 | HID_DIM = 256 34 | ENC_LAYERS = 4 35 | DEC_LAYERS = 4 36 | ENC_HEADS = 2 37 | DEC_HEADS = 2 38 | ENC_PF_DIM = 512 39 | DEC_PF_DIM = 512 40 | ENC_DROPOUT = 0.1 41 | DEC_DROPOUT = 0.1 42 | N_EPOCHS = 100 43 | CLIP = 1 44 | max_length = 2000 45 | 46 | print(INPUT_DIM) 47 | 48 | enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device,max_length) 49 | dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device,max_length) 50 | model = Transformer(enc, dec, PAD_IDX, device).to(device) 51 | model.load_state_dict(torch.load(model_file)) 52 | 53 | save_pt_freq = 10 54 | 55 | optimizer = optim.Adam(model.parameters(), lr=5e-5) 56 | #we ignore the loss whenever the target token is a padding token. 
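# (With ignore_index=PAD_IDX, positions whose target is the pad token contribute neither to the loss value nor to the gradients.)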
57 | criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX) 58 | 59 | loss_vals = [] 60 | loss_vals_eval = [] 61 | for epoch in range(N_EPOCHS): 62 | model.train() 63 | epoch_loss= [] 64 | pbar = tqdm(pre_train_iter) 65 | pbar.set_description("[finetuning_Train Epoch {}]".format(epoch)) 66 | for batch in pbar: 67 | trg, src = batch.trg.to(device), batch.src.to(device) 68 | model.zero_grad() 69 | output, _ = model(src, trg[:,:-1]) 70 | #trg = [batch size, trg len] 71 | #output = [batch size, trg len-1, output dim] 72 | output_dim = output.shape[-1] 73 | output = output.contiguous().view(-1, output_dim) 74 | trg = trg[:,1:].contiguous().view(-1) 75 | #trg = [(trg len - 1) * batch size] 76 | #output = [(trg len - 1) * batch size, output dim] 77 | loss = criterion(output, trg) 78 | loss.backward() 79 | torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP) 80 | epoch_loss.append(loss.item()) 81 | optimizer.step() 82 | pbar.set_postfix(loss=loss.item()) 83 | loss_vals.append(np.mean(epoch_loss)) 84 | 85 | model.eval() 86 | epoch_loss_eval= [] 87 | pbar = tqdm(pre_val_iter) 88 | pbar.set_description("[finetuning_Eval Epoch {}]".format(epoch)) 89 | for batch in pbar: 90 | trg, src = batch.trg.to(device), batch.src.to(device) 91 | model.zero_grad() 92 | output, _ = model(src, trg[:,:-1]) 93 | #trg = [batch size, trg len] 94 | #output = [batch size, trg len-1, output dim] 95 | output_dim = output.shape[-1] 96 | output = output.contiguous().view(-1, output_dim) 97 | trg = trg[:,1:].contiguous().view(-1) 98 | #trg = [(trg len - 1) * batch size] 99 | #output = [(trg len - 1) * batch size, output dim] 100 | loss = criterion(output, trg) 101 | epoch_loss_eval.append(loss.item()) 102 | pbar.set_postfix(loss=loss.item()) 103 | loss_vals_eval.append(np.mean(epoch_loss_eval)) 104 | 105 | if (epoch+1)%save_pt_freq ==0: 106 | torch.save(model.state_dict(), str(epoch+1)+'_finetuning_model.pt') 107 | print("save model:",str(epoch+1)+'_finetuning_model.pt') 108 | 109 | torch.save(model.state_dict(), 'final_finetuning_model.pt') 110 | 111 | print(loss_vals,loss_vals_eval) -------------------------------------------------------------------------------- /Transformer_AA/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch.nn as nn 3 | import torch 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, 7 | input_dim, 8 | hid_dim, 9 | n_layers, 10 | n_heads, 11 | pf_dim, 12 | dropout, 13 | device, 14 | max_length = 500): 15 | super().__init__() 16 | self.device = device 17 | self.tok_embedding = nn.Embedding(input_dim, hid_dim) 18 | self.pos_embedding = nn.Embedding(max_length, hid_dim) 19 | self.layers = nn.ModuleList([EncoderLayer(hid_dim, 20 | n_heads, 21 | pf_dim, 22 | dropout, 23 | device) 24 | for _ in range(n_layers)]) 25 | self.dropout = nn.Dropout(dropout) 26 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) 27 | 28 | def forward(self, src, src_mask): 29 | #src = [batch size, src len] 30 | #src_mask = [batch size, 1, 1, src len] 31 | batch_size = src.shape[0] 32 | src_len = src.shape[1] 33 | pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device) 34 | #pos = [batch size, src len] 35 | src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos)) 36 | #src = [batch size, src len, hid dim] 37 | for layer in self.layers: 38 | src = layer(src, src_mask) 39 | #src = [batch size, src len, hid dim] 40 | return src 41 | 42 | class EncoderLayer(nn.Module): 43 | def 
__init__(self, 44 | hid_dim, 45 | n_heads, 46 | pf_dim, 47 | dropout, 48 | device): 49 | super().__init__() 50 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim) 51 | self.ff_layer_norm = nn.LayerNorm(hid_dim) 52 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 53 | self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 54 | pf_dim, 55 | dropout) 56 | self.dropout = nn.Dropout(dropout) 57 | 58 | def forward(self, src, src_mask): 59 | #src = [batch size, src len, hid dim] 60 | #src_mask = [batch size, 1, 1, src len] 61 | #self attention 62 | _src, _ = self.self_attention(src, src, src, src_mask) 63 | #dropout, residual connection and layer norm 64 | src = self.self_attn_layer_norm(src + self.dropout(_src)) 65 | #src = [batch size, src len, hid dim] 66 | #positionwise feedforward 67 | _src = self.positionwise_feedforward(src) 68 | #dropout, residual and layer norm 69 | src = self.ff_layer_norm(src + self.dropout(_src)) 70 | #src = [batch size, src len, hid dim] 71 | return src 72 | 73 | class MultiHeadAttentionLayer(nn.Module): 74 | def __init__(self, hid_dim, n_heads, dropout, device): 75 | super().__init__() 76 | assert hid_dim % n_heads == 0 77 | self.hid_dim = hid_dim 78 | self.n_heads = n_heads 79 | self.head_dim = hid_dim // n_heads 80 | self.fc_q = nn.Linear(hid_dim, hid_dim) 81 | self.fc_k = nn.Linear(hid_dim, hid_dim) 82 | self.fc_v = nn.Linear(hid_dim, hid_dim) 83 | self.fc_o = nn.Linear(hid_dim, hid_dim) 84 | self.dropout = nn.Dropout(dropout) 85 | self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) 86 | 87 | def forward(self, query, key, value, mask = None): 88 | batch_size = query.shape[0] 89 | #query = [batch size, query len, hid dim] 90 | #key = [batch size, key len, hid dim] 91 | #value = [batch size, value len, hid dim] 92 | Q = self.fc_q(query) 93 | K = self.fc_k(key) 94 | V = self.fc_v(value) 95 | #Q = [batch size, query len, hid dim] 96 | #K = [batch size, key len, hid dim] 97 | #V = [batch size, value len, hid dim] 98 | Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 99 | K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 100 | V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 101 | #Q = [batch size, n heads, query len, head dim] 102 | #K = [batch size, n heads, key len, head dim] 103 | #V = [batch size, n heads, value len, head dim] 104 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale 105 | #energy = [batch size, n heads, query len, key len] 106 | if mask is not None: 107 | energy = energy.masked_fill(mask == 0, -1e10) 108 | attention = torch.softmax(energy, dim = -1) 109 | #attention = [batch size, n heads, query len, key len] 110 | x = torch.matmul(self.dropout(attention), V) 111 | #x = [batch size, n heads, query len, head dim] 112 | x = x.permute(0, 2, 1, 3).contiguous() 113 | #x = [batch size, query len, n heads, head dim] 114 | x = x.view(batch_size, -1, self.hid_dim) 115 | #x = [batch size, query len, hid dim] 116 | x = self.fc_o(x) 117 | #x = [batch size, query len, hid dim] 118 | return x, attention 119 | 120 | class PositionwiseFeedforwardLayer(nn.Module): 121 | def __init__(self, hid_dim, pf_dim, dropout): 122 | super().__init__() 123 | self.fc_1 = nn.Linear(hid_dim, pf_dim) 124 | self.fc_2 = nn.Linear(pf_dim, hid_dim) 125 | self.dropout = nn.Dropout(dropout) 126 | 127 | def forward(self, x): 128 | #x = [batch size, seq len, hid dim] 129 | x = self.dropout(torch.relu(self.fc_1(x))) 130 | #x = [batch 
size, seq len, pf dim] 131 | x = self.fc_2(x) 132 | #x = [batch size, seq len, hid dim] 133 | return x 134 | 135 | class Decoder(nn.Module): 136 | def __init__(self, 137 | output_dim, 138 | hid_dim, 139 | n_layers, 140 | n_heads, 141 | pf_dim, 142 | dropout, 143 | device, 144 | max_length = 500): 145 | super().__init__() 146 | self.device = device 147 | self.tok_embedding = nn.Embedding(output_dim, hid_dim) 148 | self.pos_embedding = nn.Embedding(max_length, hid_dim) 149 | self.layers = nn.ModuleList([DecoderLayer(hid_dim, 150 | n_heads, 151 | pf_dim, 152 | dropout, 153 | device) 154 | for _ in range(n_layers)]) 155 | self.fc_out = nn.Linear(hid_dim, output_dim) 156 | self.dropout = nn.Dropout(dropout) 157 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) 158 | 159 | def forward(self, trg, enc_src, trg_mask, src_mask): 160 | #trg = [batch size, trg len] 161 | #enc_src = [batch size, src len, hid dim] 162 | #trg_mask = [batch size, 1, trg len, trg len] 163 | #src_mask = [batch size, 1, 1, src len] 164 | batch_size = trg.shape[0] 165 | trg_len = trg.shape[1] 166 | pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device) 167 | #pos = [batch size, trg len] 168 | trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos)) 169 | #trg = [batch size, trg len, hid dim] 170 | for layer in self.layers: 171 | trg, attention = layer(trg, enc_src, trg_mask, src_mask) 172 | #trg = [batch size, trg len, hid dim] 173 | #attention = [batch size, n heads, trg len, src len] 174 | output = self.fc_out(trg) 175 | #output = [batch size, trg len, output dim] 176 | return output, attention 177 | 178 | class DecoderLayer(nn.Module): 179 | def __init__(self, 180 | hid_dim, 181 | n_heads, 182 | pf_dim, 183 | dropout, 184 | device): 185 | super().__init__() 186 | 187 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim) 188 | self.enc_attn_layer_norm = nn.LayerNorm(hid_dim) 189 | self.ff_layer_norm = nn.LayerNorm(hid_dim) 190 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 191 | self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 192 | self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 193 | pf_dim, 194 | dropout) 195 | self.dropout = nn.Dropout(dropout) 196 | 197 | def forward(self, trg, enc_src, trg_mask, src_mask): 198 | #trg = [batch size, trg len, hid dim] 199 | #enc_src = [batch size, src len, hid dim] 200 | #trg_mask = [batch size, 1, trg len, trg len] 201 | #src_mask = [batch size, 1, 1, src len] 202 | #self attention 203 | _trg, _ = self.self_attention(trg, trg, trg, trg_mask) 204 | #dropout, residual connection and layer norm 205 | trg = self.self_attn_layer_norm(trg + self.dropout(_trg)) 206 | #trg = [batch size, trg len, hid dim] 207 | #encoder attention 208 | _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask) 209 | #dropout, residual connection and layer norm 210 | trg = self.enc_attn_layer_norm(trg + self.dropout(_trg)) 211 | #trg = [batch size, trg len, hid dim] 212 | #positionwise feedforward 213 | _trg = self.positionwise_feedforward(trg) 214 | #dropout, residual and layer norm 215 | trg = self.ff_layer_norm(trg + self.dropout(_trg)) 216 | #trg = [batch size, trg len, hid dim] 217 | #attention = [batch size, n heads, trg len, src len] 218 | return trg, attention 219 | 220 | class Transformer(nn.Module): 221 | def __init__(self, 222 | encoder, 223 | decoder, 224 | pad_idx, 225 | device): 226 | super().__init__() 227 | self.encoder 
= encoder 228 | self.decoder = decoder 229 | self.pad_idx = pad_idx 230 | self.device = device 231 | 232 | def make_src_mask(self, src): 233 | #src = [batch size, src len] 234 | src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2) 235 | #src_mask = [batch size, 1, 1, src len] 236 | return src_mask 237 | 238 | def make_trg_mask(self, trg): 239 | #trg = [batch size, trg len] 240 | trg_pad_mask = (trg != self.pad_idx).unsqueeze(1).unsqueeze(2) 241 | #trg_pad_mask = [batch size, 1, 1, trg len] 242 | trg_len = trg.shape[1] 243 | trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool() 244 | #trg_sub_mask = [trg len, trg len] 245 | trg_mask = trg_pad_mask & trg_sub_mask 246 | #trg_mask = [batch size, 1, trg len, trg len] 247 | return trg_mask 248 | 249 | def forward(self, src, trg): 250 | #src = [batch size, src len] 251 | #trg = [batch size, trg len] 252 | src_mask = self.make_src_mask(src) 253 | trg_mask = self.make_trg_mask(trg) 254 | #src_mask = [batch size, 1, 1, src len] 255 | #trg_mask = [batch size, 1, trg len, trg len] 256 | enc_src = self.encoder(src, src_mask) 257 | #enc_src = [batch size, src len, hid dim] 258 | output, attention = self.decoder(trg, enc_src, trg_mask, src_mask) 259 | #output = [batch size, trg len, output dim] 260 | #attention = [batch size, n heads, trg len, src len] 261 | return output, attention 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | -------------------------------------------------------------------------------- /LSTM_peptides/LSTM_peptides.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | ..author:: Alex Müller, ETH Zürich, Switzerland. 5 | ..date:: September 2017 6 | 7 | Code for training a LSTM model on peptide sequences followed by sampling novel sequences through the model. 8 | Check the readme for possible flags to use with this script. 
9 | """ 10 | import json 11 | import os 12 | import pickle 13 | import random 14 | import argparse 15 | 16 | 17 | 18 | import matplotlib.pyplot as plt 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from tensorflow.keras.callbacks import ModelCheckpoint 23 | from tensorflow.keras.initializers import RandomNormal 24 | from tensorflow.keras.layers import Dense, LSTM, GRU 25 | from tensorflow.keras.models import Sequential, load_model 26 | from tensorflow.keras.optimizers import Adam 27 | from tensorflow.keras.regularizers import l2 28 | from modlamp.analysis import GlobalAnalysis 29 | from modlamp.core import count_aas 30 | from modlamp.descriptors import PeptideDescriptor, GlobalDescriptor 31 | from modlamp.sequences import Random, Helices 32 | from progressbar import ProgressBar 33 | from scipy.spatial import distance 34 | from sklearn.model_selection import KFold 35 | from sklearn.preprocessing import StandardScaler 36 | 37 | plt.switch_backend('agg') 38 | flags = argparse.ArgumentParser() 39 | flags.add_argument("-d", "--dataset", default="training_sequences_noC.csv", help="dataset file (expecting csv)", type=str) 40 | flags.add_argument("-n", "--name", default="test", help="run name for log and checkpoint files", type=str) 41 | flags.add_argument("-b", "--batch_size", default=128, help="batch size", type=int) 42 | flags.add_argument("-e", "--epochs", default=50, help="epochs to train", type=int) 43 | flags.add_argument("-l", "--layers", default=2, help="number of layers in the network", type=int) 44 | flags.add_argument("-x", "--neurons", default=256, help="number of units per layer", type=int) 45 | flags.add_argument("-c", "--cell", default="LSTM", help="type of neuron to use, available: LSTM, GRU", type=str) 46 | flags.add_argument("-o", "--dropout", default=0.1, help="dropout to use in every layer; layer 1 gets 1*dropout, layer 2 2*dropout etc.", type=float) 47 | flags.add_argument("-t", "--train", default=False, help="whether the network should be trained or just sampled from", type=bool) 48 | flags.add_argument("-v", "--valsplit", default=0.2, help="fraction of the data to use for validation", type=float) 49 | flags.add_argument("-s", "--sample", default=100, help="number of sequences to sample training", type=int) 50 | flags.add_argument("-p", "--temp", default=1.25, help="temperature used for sampling", type=float) 51 | flags.add_argument("-m", "--maxlen", default=0, help="maximum sequence length allowed when sampling new sequences", type=int) 52 | flags.add_argument("-a", "--startchar", default="j", help="starting character to begin sampling. Default='j' for 'begin'", type=str) 53 | flags.add_argument("-r", "--lr", default=0.01, help="learning rate to be used with the Adam optimizer", type=float) 54 | flags.add_argument("--l2", default=None, help="l2 regularization rate. If None, no l2 regularization is used", type=float) 55 | flags.add_argument("--modfile", default=None, help="filename of the pretrained model to used for sampling if train=False", type=str) 56 | flags.add_argument("--finetune", default=True, help="if True, a pretrained model provided in modfile is finetuned on the dataset", type=bool) 57 | flags.add_argument("--cv", default=None, help="number of folds to use for cross-validation; if None, no CV is performed", type=int) 58 | flags.add_argument("--window", default=0, help="window size used to process sequences. 
If 0, all sequences are padded to the longest sequence length in the dataset", type=int) 59 | flags.add_argument("--step", default=1, help="step size to move window or prediction target", type=int) 60 | flags.add_argument("--target", default="all", help="whether to learn all proceeding characters or just the last `one` in sequence", type=str) 61 | flags.add_argument("--padlen", default=0, help="number of spaces to use for padding sequences (if window not 0); if 0, sequences are padded to the length of the longest sequence in the dataset", type=int) 62 | flags.add_argument("--refs", default=True, help="whether reference sequence sets should be generated for the analysis", type=bool) 63 | args = flags.parse_args() 64 | 65 | 66 | 67 | def _save_flags(filename): 68 | """ Function to save used arguments to log-file 69 | 70 | :return: saved file 71 | """ 72 | with open(filename, 'w') as f: 73 | f.write("Used flags:\n-----------\n") 74 | json.dump(args.__dict__, f, indent=2) 75 | 76 | 77 | def _onehotencode(s, vocab=None): 78 | """ Function to one-hot encode a sring. 79 | 80 | :param s: {str} String to encode in one-hot fashion 81 | :param vocab: vocabulary to use fore encoding, if None, default AAs are used 82 | :return: one-hot encoded string as a np.array 83 | : j init char " " pading char 84 | """ 85 | if not vocab: 86 | vocab = ['j','A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', ' '] 87 | 88 | # generate translation dictionary for one-hot encoding 89 | to_one_hot = dict() 90 | for i, a in enumerate(vocab): 91 | v = np.zeros(len(vocab)) 92 | v[i] = 1 93 | to_one_hot[a] = v 94 | 95 | result = [] 96 | for l in s: 97 | result.append(to_one_hot[l]) 98 | result = np.array(result) 99 | return np.reshape(result, (1, result.shape[0], result.shape[1])), to_one_hot, vocab 100 | 101 | 102 | def _onehotdecode(matrix, vocab=None, filename=None): 103 | """ Decode a given one-hot represented matrix back into sequences 104 | 105 | :param matrix: matrix containing sequence patterns that are one-hot encoded 106 | :param vocab: vocabulary, if None, standard AAs are used 107 | :param filename: filename for saving sequences, if ``None``, sequences are returned in a list 108 | :return: list of decoded sequences in the range lenmin-lenmax, if ``filename``, they are saved to a file 109 | """ 110 | if not vocab: 111 | _, _, vocab = _onehotencode('A') 112 | if len(matrix.shape) == 2: # if a matrix containing only one string is supplied 113 | result = [] 114 | for i in range(matrix.shape[0]): 115 | for j in range(matrix.shape[1]): 116 | aa = np.where(matrix[i, j] == 1.)[0][0] 117 | result.append(vocab[aa]) 118 | seq = ''.join(result) 119 | if filename: 120 | with open(filename, 'wb') as f: 121 | f.write(seq) 122 | else: 123 | return seq 124 | 125 | elif len(matrix.shape) == 3: # if a matrix containing several strings is supplied 126 | result = [] 127 | for n in range(matrix.shape[0]): 128 | oneresult = [] 129 | for i in range(matrix.shape[1]): 130 | for j in range(matrix.shape[2]): 131 | aa = np.where(matrix[n, i, j] == 1.)[0][0] 132 | oneresult.append(vocab[aa]) 133 | seq = ''.join(oneresult) 134 | result.append(seq) 135 | if filename: 136 | with open(filename, 'wb') as f: 137 | for s in result: 138 | f.write(s + '\n') 139 | else: 140 | return result 141 | 142 | 143 | def _sample_with_temp(preds, temp=1.0): 144 | """ Helper function to sample one letter from a probability array given a temperature. 
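Temperatures below 1.0 sharpen the predicted distribution (more conservative samples); temperatures above 1.0 flatten it (more diverse samples).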
145 | 146 | :param preds: {np.array} predictions returned by the network 147 | :param temp: {float} temperature value to sample at. 148 | """ 149 | streched = np.log(preds) / temp 150 | stretched_probs = np.exp(streched) / np.sum(np.exp(streched)) 151 | return np.random.choice(len(streched), p=stretched_probs) 152 | 153 | 154 | def load_model_instance(filename): 155 | """ Load a whole Model class instance from a given epoch file 156 | 157 | :param filename: epoch file, e.g. model_epoch_5.hdf5 158 | :return: model instance with trained weights 159 | """ 160 | modfile = os.path.dirname(filename) + '/model.p' 161 | mod = pickle.load(open(modfile, 'rb')) 162 | hdf5_file = ''.join(modfile.split('.')[:-1]) + '.hdf5' 163 | mod.model = load_model(hdf5_file) 164 | return mod 165 | 166 | 167 | def save_model_instance(mod): 168 | """ Save a whole Model instance and the corresponding model with weights to two files (model.p and model.hdf5) 169 | 170 | :param mod: model instance 171 | :return: saved model files in the checkpoint dir 172 | """ 173 | tmp = mod.model 174 | tmp.save(mod.checkpointdir + 'model.hdf5') 175 | mod.model = None 176 | pickle.dump(mod, open(mod.checkpointdir + 'model.p', 'wb')) 177 | mod.model = tmp 178 | 179 | 180 | class SequenceHandler(object): 181 | """ Class for handling peptide sequences, e.g. loading, one-hot encoding or decoding and saving """ 182 | 183 | def __init__(self, window=0, step=2, refs=True): 184 | """ 185 | :param window: {str} window used for chopping up sequences. If 0: False 186 | :param step: {int} size of the steps to move the window forward 187 | :param refs {bool} whether to generate reference sequence sets for analysis 188 | """ 189 | self.sequences = None 190 | self.generated = None 191 | self.ran = None 192 | self.hel = None 193 | self.X = list() 194 | self.y = list() 195 | self.window = window 196 | self.step = step 197 | self.refs = refs 198 | # generate translation dictionary for one-hot encoding 199 | _, self.to_one_hot, self.vocab = _onehotencode('A') 200 | 201 | def load_sequences(self, filename): 202 | """ Method to load peptide sequences from a csv file 203 | 204 | :param filename: {str} filename of the sequence file to be read (``csv``, one sequence per line) 205 | :return: sequences in self.sequences 206 | """ 207 | with open(filename) as f: 208 | self.sequences = [s.strip() for s in f] 209 | self.sequences = random.sample(self.sequences, len(self.sequences)) # shuffle sequences randomly 210 | 211 | def pad_sequences(self, pad_char=' ', padlen=0): 212 | """ Pad all sequences to the longest length (default, padlen=0) or a given length 213 | 214 | :param pad_char: {str} Character to pad sequences with 215 | :param padlen: {int} Custom length of padding to add to all sequences to (optional), default: 0. If 216 | 0, sequences are padded to the length of the longest sequence in the training set. If a window is used and the 217 | padded sequence is shorter than the window size, it is padded to fit the window. 
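Note: when ``padlen`` is 0, every sequence is additionally prefixed with the start character 'j' before padding.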
218 | """ 219 | if padlen: 220 | padded_seqs = [] 221 | for seq in self.sequences: 222 | if len(seq) < self.window: 223 | padded_seq = seq + pad_char * (self.step + self.window - len(seq)) 224 | else: 225 | padded_seq = seq + pad_char * padlen 226 | padded_seqs.append(padded_seq) 227 | else: 228 | length = max([len(seq) for seq in self.sequences]) 229 | padded_seqs = [] 230 | for seq in self.sequences: 231 | padded_seq = 'j' + seq + pad_char * (length - len(seq)) 232 | padded_seqs.append(padded_seq) 233 | 234 | if pad_char not in self.vocab: 235 | self.vocab += [pad_char] 236 | 237 | self.sequences = padded_seqs # overwrite sequences with padded sequences 238 | 239 | def one_hot_encode(self, target='all'): 240 | """ Chop up loaded sequences into patterns of length ``window`` by moving by stepsize ``step`` and translate 241 | them with a one-hot vector encoding 242 | 243 | :param target: {str} whether all proceeding AA should be learned or just the last one in sequence (`all`, `one`) 244 | :return: one-hot encoded sequence patterns in self.X and corresponding target amino acids in self.y 245 | """ 246 | if self.window == 0: 247 | for s in self.sequences: 248 | self.X.append([self.to_one_hot[char] for char in s[:-self.step]]) 249 | if target == 'all': 250 | self.y.append([self.to_one_hot[char] for char in s[self.step:]]) 251 | elif target == 'one': 252 | self.y.append(s[-self.step:]) 253 | 254 | self.X = np.reshape(self.X, (len(self.X), len(self.sequences[0]) - self.step, len(self.vocab))) 255 | self.y = np.reshape(self.y, (len(self.y), len(self.sequences[0]) - self.step, len(self.vocab))) 256 | 257 | else: 258 | for s in self.sequences: 259 | for i in range(0, len(s) - self.window, self.step): 260 | self.X.append([self.to_one_hot[char] for char in s[i: i + self.window]]) 261 | if target == 'all': 262 | self.y.append([self.to_one_hot[char] for char in s[i + 1: i + self.window + 1]]) 263 | elif target == 'one': 264 | self.y.append(s[-self.step:]) 265 | 266 | self.X = np.reshape(self.X, (len(self.X), self.window, len(self.vocab))) 267 | self.y = np.reshape(self.y, (len(self.y), self.window, len(self.vocab))) 268 | 269 | print("\nData shape:\nX: " + str(self.X.shape) + "\ny: " + str(self.y.shape)) 270 | 271 | def analyze_training(self): 272 | """ Method to analyze the distribution of the training data 273 | 274 | :return: prints out information about the length distribution of the sequences in ``self.sequences`` 275 | """ 276 | d = GlobalDescriptor(self.sequences) 277 | d.length() 278 | print("\nLENGTH DISTRIBUTION OF TRAINING DATA:\n") 279 | print("Number of sequences: \t%i" % len(self.sequences)) 280 | print("Mean sequence length: \t%.1f ± %.1f" % (np.mean(d.descriptor), np.std(d.descriptor))) 281 | print("Median sequence length: \t%i" % np.median(d.descriptor)) 282 | print("Minimal sequence length:\t%i" % np.min(d.descriptor)) 283 | print("Maximal sequence length:\t%i" % np.max(d.descriptor)) 284 | 285 | def analyze_generated(self, num, fname='analysis.txt', plot=False,min_length_seq=3): 286 | """ Method to analyze the generated sequences located in `self.generated`. 287 | 288 | :param num: {int} wanted number of sequences to sample 289 | :param fname: {str} filename to save analysis info to 290 | :param plot: {bool} whether to plot an overview of descriptors 291 | :return: file with analysis info (distances) 292 | """ 293 | with open(fname, 'w') as f: 294 | print("Analyzing...") 295 | f.write("ANALYSIS OF SAMPLED SEQUENCES\n==============================\n\n") 296 | f.write("Nr. 
of duplicates in generated sequences: %i\n" % (len(self.generated) - len(set(self.generated)))) 297 | count = len(set(self.generated) & set(self.sequences)) # get shared entries in both lists 298 | f.write("%.1f percent of generated sequences are present in the training data.\n" % 299 | ((count / len(self.generated)) * 100)) 300 | d = GlobalDescriptor(self.generated) 301 | len1 = len(d.sequences) 302 | d.filter_aa('j') 303 | len2 = len(d.sequences) 304 | d.length() 305 | f.write("\n\nLENGTH DISTRIBUTION OF GENERATED DATA:\n\n") 306 | f.write("Number of sequences too short:\t%i\n" % (num - len1)) 307 | f.write("Number of invalid (with j):\t%i\n" % (len1 - len2)) 308 | f.write("Number of valid unique seqs:\t%i\n" % len2) 309 | f.write("Mean sequence length: \t\t%.1f ± %.1f\n" % (np.mean(d.descriptor), np.std(d.descriptor))) 310 | f.write("Median sequence length: \t\t%i\n" % np.median(d.descriptor)) 311 | f.write("Minimal sequence length: \t\t%i\n" % np.min(d.descriptor)) 312 | f.write("Maximal sequence length: \t\t%i\n" % np.max(d.descriptor)) 313 | 314 | 315 | self.sequences = [s[1:].rstrip() for s in self.sequences] 316 | 317 | self.analyze_training() 318 | 319 | d.sequences = [s for s in d.sequences] 320 | 321 | descriptor = 'pepcats' 322 | #seq_desc = PeptideDescriptor([s[1:].rstrip() for s in self.sequences], descriptor) 323 | seq_desc = PeptideDescriptor(self.sequences, descriptor) 324 | seq_desc.calculate_autocorr(min_length_seq) 325 | 326 | gen_desc = PeptideDescriptor(d.sequences, descriptor) 327 | gen_desc.calculate_autocorr(min_length_seq) 328 | 329 | # random comparison set 330 | self.ran = Random(len(self.generated), np.min(d.descriptor), np.max(d.descriptor)) # generate rand seqs 331 | probas = count_aas(''.join(seq_desc.sequences)).values() # get the aa distribution of training seqs 332 | self.ran.generate_sequences(proba=probas) 333 | ran_desc = PeptideDescriptor(self.ran.sequences, descriptor) 334 | ran_desc.calculate_autocorr(min_length_seq) 335 | 336 | # amphipathic helices comparison set 337 | self.hel = Helices(len(self.generated), np.min(d.descriptor), np.max(d.descriptor)) 338 | self.hel.generate_sequences() 339 | hel_desc = PeptideDescriptor(self.hel.sequences, descriptor) 340 | hel_desc.calculate_autocorr(min_length_seq) 341 | 342 | # distance calculation 343 | f.write("\n\nDISTANCE CALCULATION IN '%s' DESCRIPTOR SPACE\n\n" % descriptor.upper()) 344 | desc_dist = distance.cdist(gen_desc.descriptor, seq_desc.descriptor, metric='euclidean') 345 | f.write("Average euclidean distance of sampled to training data:\t%.3f +/- %.3f\n" % 346 | (np.mean(desc_dist), np.std(desc_dist))) 347 | ran_dist = distance.cdist(ran_desc.descriptor, seq_desc.descriptor, metric='euclidean') 348 | f.write("Average euclidean distance if randomly sampled seqs:\t%.3f +/- %.3f\n" % 349 | (np.mean(ran_dist), np.std(ran_dist))) 350 | hel_dist = distance.cdist(hel_desc.descriptor, seq_desc.descriptor, metric='euclidean') 351 | f.write("Average euclidean distance if amphipathic helical seqs:\t%.3f +/- %.3f\n" % 352 | (np.mean(hel_dist), np.std(hel_dist))) 353 | 354 | # more simple descriptors 355 | g_seq = GlobalDescriptor(seq_desc.sequences) 356 | g_gen = GlobalDescriptor(gen_desc.sequences) 357 | g_ran = GlobalDescriptor(ran_desc.sequences) 358 | g_hel = GlobalDescriptor(hel_desc.sequences) 359 | g_seq.calculate_all() 360 | g_gen.calculate_all() 361 | g_ran.calculate_all() 362 | g_hel.calculate_all() 363 | sclr = StandardScaler() 364 | sclr.fit(g_seq.descriptor) 365 | f.write("\n\nDISTANCE 
CALCULATION FOR SCALED GLOBAL DESCRIPTORS\n\n") 366 | desc_dist = distance.cdist(sclr.transform(g_gen.descriptor), sclr.transform(g_seq.descriptor), 367 | metric='euclidean') 368 | f.write("Average euclidean distance of sampled to training data:\t%.2f +/- %.2f\n" % 369 | (np.mean(desc_dist), np.std(desc_dist))) 370 | ran_dist = distance.cdist(sclr.transform(g_ran.descriptor), sclr.transform(g_seq.descriptor), 371 | metric='euclidean') 372 | f.write("Average euclidean distance if randomly sampled seqs:\t%.2f +/- %.2f\n" % 373 | (np.mean(ran_dist), np.std(ran_dist))) 374 | hel_dist = distance.cdist(sclr.transform(g_hel.descriptor), sclr.transform(g_seq.descriptor), 375 | metric='euclidean') 376 | f.write("Average euclidean distance if amphipathic helical seqs:\t%.2f +/- %.2f\n" % 377 | (np.mean(hel_dist), np.std(hel_dist))) 378 | 379 | # hydrophobic moments 380 | uh_seq = PeptideDescriptor(seq_desc.sequences, 'eisenberg') 381 | uh_seq.calculate_moment() 382 | uh_gen = PeptideDescriptor(gen_desc.sequences, 'eisenberg') 383 | uh_gen.calculate_moment() 384 | uh_ran = PeptideDescriptor(ran_desc.sequences, 'eisenberg') 385 | uh_ran.calculate_moment() 386 | uh_hel = PeptideDescriptor(hel_desc.sequences, 'eisenberg') 387 | uh_hel.calculate_moment() 388 | f.write("\n\nHYDROPHOBIC MOMENTS\n\n") 389 | f.write("Hydrophobic moment of training seqs:\t%.3f +/- %.3f\n" % 390 | (np.mean(uh_seq.descriptor), np.std(uh_seq.descriptor))) 391 | f.write("Hydrophobic moment of sampled seqs:\t\t%.3f +/- %.3f\n" % 392 | (np.mean(uh_gen.descriptor), np.std(uh_gen.descriptor))) 393 | f.write("Hydrophobic moment of random seqs:\t\t%.3f +/- %.3f\n" % 394 | (np.mean(uh_ran.descriptor), np.std(uh_ran.descriptor))) 395 | f.write("Hydrophobic moment of amphipathic seqs:\t%.3f +/- %.3f\n" % 396 | (np.mean(uh_hel.descriptor), np.std(uh_hel.descriptor))) 397 | 398 | if plot: 399 | if self.refs: 400 | a = GlobalAnalysis([uh_seq.sequences, uh_gen.sequences, uh_hel.sequences, uh_ran.sequences], 401 | ['training', 'sampled', 'hel', 'ran']) 402 | else: 403 | a = GlobalAnalysis([uh_seq.sequences, uh_gen.sequences], ['training', 'sampled']) 404 | a.plot_summary(filename=fname[:-4] + '.png') 405 | 406 | def save_generated(self, logdir, filename): 407 | """ Save all sequences in `self.generated` to file 408 | 409 | :param logdir: {str} current log directory (used for comparison sequences) 410 | :param filename: {str} filename to save the sequences to 411 | :return: saved file 412 | """ 413 | with open(filename, 'w') as f: 414 | for s in self.generated: 415 | f.write(s + '\n') 416 | 417 | self.ran.save_fasta(logdir + '/random_sequences.fasta') 418 | self.hel.save_fasta(logdir + '/helical_sequences.fasta') 419 | 420 | 421 | class Model(object): 422 | """ 423 | Class containing the LSTM model to learn sequential data 424 | """ 425 | 426 | def __init__(self, n_vocab, outshape, session_name, cell="LSTM", n_units=256, batch=64, layers=2, lr=0.001, 427 | dropoutfract=0.1, loss='categorical_crossentropy', l2_reg=None, ask=True, seed=42): 428 | """ Initialize the model 429 | 430 | :param n_vocab: {int} length of vocabulary 431 | :param outshape: {int} output dimensionality of the model 432 | :param session_name: {str} custom name for the current session. Will create directory with this name to save 433 | results / logs to. 
434 | :param n_units: {int} number of LSTM units per layer 435 | :param batch: {int} batch size 436 | :param layers: {int} number of layers in the network 437 | :param loss: {str} applied loss function, choose from available keras loss functions 438 | :param lr: {float} learning rate to use with Adam optimizer 439 | :param dropoutfract: {float} fraction of dropout to add to each layer. Layer1 gets 1 * value, Layer2 2 * 440 | value and so on. 441 | :param l2_reg: {float} l2 regularization for kernel 442 | :param seed {int} random seed used to initialize weights 443 | """ 444 | random.seed(seed) 445 | self.seed = seed 446 | self.dropout = dropoutfract 447 | self.inshape = (None, n_vocab) 448 | self.outshape = outshape 449 | self.neurons = n_units 450 | self.layers = layers 451 | self.losses = list() 452 | self.val_losses = list() 453 | self.batchsize = batch 454 | self.lr = lr 455 | self.cv_loss = None 456 | self.cv_loss_std = None 457 | self.cv_val_loss = None 458 | self.cv_val_loss_std = None 459 | self.model = None 460 | self.cell = cell 461 | self.losstype = loss 462 | self.session_name = session_name 463 | self.logdir = './' + session_name 464 | self.l2 = l2_reg 465 | if ask and os.path.exists(self.logdir): 466 | decision = input('\nSession folder already exists!\n' 467 | 'Do you want to overwrite the previous session? [y/n] ') 468 | if decision in ['n', 'no', 'N', 'NO', 'No']: 469 | self.logdir = './' + input('Enter new session name: ') 470 | os.makedirs(self.logdir) 471 | self.checkpointdir = self.logdir + '/checkpoint/' 472 | if not os.path.exists(self.checkpointdir): 473 | os.makedirs(self.checkpointdir) 474 | _, _, self.vocab = _onehotencode('A') 475 | 476 | self.initialize_model(seed=self.seed) 477 | 478 | def initialize_model(self, seed=42): 479 | """ Method to initialize the model with all parameters saved in the attributes. This method is used during 480 | initialization of the class, as well as in cross-validation to reinitialize a fresh model for every fold. 
481 | 482 | :param seed: {int} random seed to use for weight initialization 483 | 484 | :return: initialized model in ``self.model`` 485 | """ 486 | self.losses = list() 487 | self.val_losses = list() 488 | self.cv_loss = None 489 | self.cv_loss_std = None 490 | self.cv_val_loss = None 491 | self.cv_val_loss_std = None 492 | self.model = None 493 | weight_init = RandomNormal(mean=0.0, stddev=0.05, seed=seed) # weights drawn from a normal distribution with stddev 0.05 494 | optimizer = Adam(learning_rate=self.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) # 'lr' was renamed to 'learning_rate' in newer Keras versions 495 | 496 | if self.l2: 497 | l2reg = l2(self.l2) 498 | else: 499 | l2reg = None 500 | 501 | self.model = Sequential() 502 | for l in range(self.layers): 503 | if self.cell == "GRU": 504 | self.model.add(GRU(units=self.neurons, 505 | name='GRU%i' % (l + 1), 506 | input_shape=self.inshape, 507 | return_sequences=True, 508 | kernel_initializer=weight_init, 509 | kernel_regularizer=l2reg, 510 | dropout=self.dropout * (l + 1))) 511 | else: 512 | self.model.add(LSTM(units=self.neurons, 513 | name='LSTM%i' % (l + 1), 514 | input_shape=self.inshape, 515 | return_sequences=True, 516 | kernel_initializer=weight_init, 517 | kernel_regularizer=l2reg, 518 | dropout=self.dropout * (l + 1), 519 | recurrent_dropout=self.dropout * (l + 1))) 520 | self.model.add(Dense(self.outshape, 521 | name='Dense', 522 | activation='softmax', 523 | kernel_regularizer=l2reg, # pass the regularizer object, not the raw l2 float 524 | kernel_initializer=weight_init)) 525 | self.model.compile(loss=self.losstype, optimizer=optimizer) 526 | with open(self.checkpointdir + "model.json", 'w') as f: 527 | json.dump(self.model.to_json(), f) 528 | self.model.summary() 529 | 530 | def finetuneinit(self, session_name): 531 | """ Method to generate a new directory for finetuning a pre-existing model on a new dataset with a new name 532 | 533 | :param session_name: {str} new session name for finetuning 534 | :return: generates all necessary session folders 535 | """ 536 | self.session_name = session_name 537 | self.logdir = './' + session_name 538 | if os.path.exists(self.logdir): 539 | decision = input('\nSession folder already exists!\n' 540 | 'Do you want to overwrite the previous session? [y/n] ') 541 | if decision in ['n', 'no', 'N', 'NO', 'No']: 542 | self.logdir = './' + input('Enter new session name: ') 543 | os.makedirs(self.logdir) 544 | self.checkpointdir = self.logdir + '/checkpoint/' 545 | if not os.path.exists(self.checkpointdir): 546 | os.makedirs(self.checkpointdir) 547 | 548 | def train(self, x, y, epochs=100, valsplit=0.2, sample=100): 549 | """ Train the model on given training data.
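# --- Illustrative sketch (not part of the original script) ------------------
# Standalone version of the network that initialize_model() above assembles: a
# stack of sequence-returning recurrent layers followed by a per-position
# softmax over the 27-symbol vocabulary ('j' start token, 25 letters A-Z
# without J, ' ' padding). Unit counts and dropout below are example settings.
import tensorflow as tf
from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.models import Sequential

n_vocab, n_units, n_layers, dropout = 27, 256, 2, 0.1
sketch = Sequential()
for i in range(n_layers):
    kwargs = {'input_shape': (None, n_vocab)} if i == 0 else {}   # variable-length one-hot input
    sketch.add(LSTM(n_units, return_sequences=True, dropout=dropout * (i + 1), **kwargs))
sketch.add(Dense(n_vocab, activation='softmax'))    # probability of the next residue at every position
sketch.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
sketch.summary()
# ----------------------------------------------------------------------------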
550 | 551 | :param x: {array} training data 552 | :param y: {array} targets for training data in X 553 | :param epochs: {int} number of epochs to train 554 | :param valsplit: {float} fraction of data that should be used as validation data during training 555 | :param sample: {int} number of sequences to sample after every training epoch 556 | :return: trained model and measured losses in self.model, self.losses and self.val_losses 557 | """ 558 | writer = tf.summary.create_file_writer('./logs/' + self.session_name) 559 | with writer.as_default(): 560 | for e in range(epochs): 561 | print("Epoch %i" % e) 562 | checkpoints = [ModelCheckpoint(filepath=self.checkpointdir + 'model_epoch_%i.hdf5' % e, verbose=0)] 563 | train_history = self.model.fit(x, y, epochs=1, batch_size=self.batchsize, validation_split=valsplit, 564 | shuffle=False, callbacks=checkpoints) 565 | tf.summary.scalar('loss', train_history.history['loss'][-1], step=e) 566 | self.losses.append(train_history.history['loss']) 567 | if valsplit > 0.: 568 | self.val_losses.append(train_history.history['val_loss']) 569 | tf.summary.scalar('val_loss', train_history.history['val_loss'][-1], step=e) 570 | if sample: 571 | for s in self.sample(sample): # sample sequences after every training epoch 572 | print(s) 573 | writer.close() 574 | 575 | def cross_val(self, x, y, epochs=100, cv=5, plot=True): 576 | """ Method to perform cross-validation with the model given data X, y 577 | 578 | :param x: {array} training data 579 | :param y: {array} targets for training data in X 580 | :param epochs: {int} number of epochs to train 581 | :param cv: {int} fold 582 | :param plot: {bool} whether the losses should be plotted and saved to the session folder 583 | :return: 584 | """ 585 | self.losses = list() # clean losses if already present 586 | self.val_losses = list() 587 | kf = KFold(n_splits=cv) 588 | cntr = 0 589 | for train, test in kf.split(x): 590 | print("\nFold %i" % (cntr + 1)) 591 | self.initialize_model(seed=cntr) # reinitialize every fold, otherwise it will "remember" previous data 592 | train_history = self.model.fit(x[train], y[train], epochs=epochs, batch_size=self.batchsize, 593 | validation_data=(x[test], y[test])) 594 | self.losses.append(train_history.history['loss']) 595 | self.val_losses.append(train_history.history['val_loss']) 596 | cntr += 1 597 | self.cv_loss = np.mean(self.losses, axis=0) 598 | self.cv_loss_std = np.std(self.losses, axis=0) 599 | self.cv_val_loss = np.mean(self.val_losses, axis=0) 600 | self.cv_val_loss_std = np.std(self.val_losses, axis=0) 601 | if plot: 602 | self.plot_losses(cv=True) 603 | 604 | # get best epoch with corresponding val_loss 605 | minloss = np.min(self.cv_val_loss) 606 | e = np.where(minloss == self.cv_val_loss)[0][0] 607 | print("\n%i-fold cross-validation result:\n\nBest epoch:\t%i\nVal_loss:\t%.4f" % (cv, e, minloss)) 608 | with open(self.logdir + '/' + self.session_name + '_best_epoch.txt', 'w') as f: 609 | f.write("%i-fold cross-validation result:\n\nBest epoch:\t%i\nVal_loss:\t%.4f" % (cv, e, minloss)) 610 | 611 | def plot_losses(self, show=False, cv=False): 612 | """Plot the losses obtained in training. 613 | 614 | :param show: {bool} Whether the plot should be shown or saved. If ``False``, the plot is saved to the 615 | session folder. 616 | :param cv: {bool} Whether the losses from cross-validation should be plotted. The standard deviation will be 617 | depicted as filled areas around the mean curve. 
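# --- Illustrative sketch (not part of the original script) ------------------
# train() above writes one weight file per epoch via ModelCheckpoint. A typical
# way to reload such a checkpoint with the helpers defined in this script and
# sample from it, mirroring the sampling branch of main() further down. The
# session name and epoch number are placeholders, not values from a real run:
modfile = 'my_session/checkpoint/model_epoch_29.hdf5'   # hypothetical checkpoint path
mod = load_model_instance(modfile)   # restores the pickled Model wrapper plus model.hdf5
mod.load_model(modfile)              # then swaps in the epoch-29 weights
for seq in mod.sample(num=5, maxlen=36, temp=1.25, show=False):
    print(seq)
# ----------------------------------------------------------------------------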
618 | :return: plot (saved) or shown interactive 619 | """ 620 | fig, ax = plt.subplots() 621 | ax.set_title('LSTM Categorical Crossentropy Loss Plot', fontweight='bold', fontsize=16) 622 | if cv: 623 | filename = self.logdir + '/' + self.session_name + '_cv_loss_plot.pdf' 624 | x = range(1, len(self.cv_loss) + 1) 625 | ax.plot(x, self.cv_loss, '-', color='#FE4365', label='Training') 626 | ax.plot(x, self.cv_val_loss, '-', color='k', label='Validation') 627 | ax.fill_between(x, self.cv_loss + self.cv_loss_std, self.cv_loss - self.cv_loss_std, 628 | facecolors='#FE4365', alpha=0.5) 629 | ax.fill_between(x, self.cv_val_loss + self.cv_val_loss_std, self.cv_val_loss - self.cv_val_loss_std, 630 | facecolors='k', alpha=0.5) 631 | ax.set_xlim([0.5, len(self.cv_loss) + 0.5]) 632 | minloss = np.min(self.cv_val_loss) 633 | plt.text(x=0.5, y=0.5, s='best epoch: ' + str(np.where(minloss == self.cv_val_loss)[0][0]) + ', val_loss: ' 634 | + str(minloss.round(4)), transform=ax.transAxes) 635 | else: 636 | filename = self.logdir + '/' + self.session_name + '_loss_plot.pdf' 637 | x = range(1, len(self.losses) + 1) 638 | ax.plot(x, self.losses, '-', color='#FE4365', label='Training') 639 | if self.val_losses: 640 | ax.plot(x, self.val_losses, '-', color='k', label='Validation') 641 | ax.set_xlim([0.5, len(self.losses) + 0.5]) 642 | ax.set_ylabel('Loss', fontweight='bold', fontsize=14) 643 | ax.set_xlabel('Epoch', fontweight='bold', fontsize=14) 644 | ax.spines['right'].set_visible(False) 645 | ax.spines['top'].set_visible(False) 646 | ax.xaxis.set_ticks_position('bottom') 647 | ax.yaxis.set_ticks_position('left') 648 | plt.legend(loc='best') 649 | if show: 650 | plt.show() 651 | else: 652 | plt.savefig(filename) 653 | 654 | def sample(self, num=100, minlen=7, maxlen=50, start=None, temp=2.5, show=False): 655 | """Invoke generation of sequence patterns through sampling from the trained model. 656 | 657 | :param num: {int} number of sequences to sample 658 | :param minlen {int} minimal allowed sequence length 659 | :param maxlen: {int} maximal length of each pattern generated, if 0, a random length is chosen between 7 and 50 660 | :param start: {str} start AA to be used for sampling. If ``None``, a random AA is chosen 661 | :param temp: {float} temperature value to sample at. 
662 | :param show: {bool} whether the sampled sequences should be printed out 663 | :return: {array} matrix of patterns of shape (num, seqlen, inputshape[0]) 664 | """ 665 | print("\nSampling...\n") 666 | sampled = [] 667 | lcntr = 0 668 | pbar = ProgressBar() 669 | for rs in pbar(range(num)): 670 | random.seed(rs) 671 | if not maxlen: # if the length should be randomly sampled 672 | longest = np.random.randint(7, 50) 673 | else: 674 | longest = maxlen 675 | 676 | if start: 677 | start_aa = start 678 | else: # generate random starting letter 679 | start_aa = 'j' 680 | sequence = start_aa # start with starting letter 681 | 682 | while sequence[-1] != ' ' and len(sequence) <= longest: # sample until padding or maxlen is reached 683 | x, _, _ = _onehotencode(sequence) 684 | preds = self.model.predict(x)[0][-1] 685 | next_aa = _sample_with_temp(preds, temp=temp) 686 | sequence += self.vocab[next_aa] 687 | 688 | if start_aa == 'j': 689 | sequence = sequence[1:].rstrip() 690 | else: # keep starting AA if chosen for sampling 691 | sequence = sequence.rstrip() 692 | 693 | if len(sequence) < minlen: # don't take sequences shorter than the minimal length 694 | lcntr += 1 695 | continue 696 | 697 | sampled.append(sequence) 698 | if show: 699 | print(sequence) 700 | 701 | print("\t%i sequences were shorter than %i" % (lcntr, minlen)) 702 | return sampled 703 | 704 | def load_model(self, filename): 705 | """Method to load a trained model from a hdf5 file 706 | 707 | :return: model loaded from file in ``self.model`` 708 | """ 709 | self.model.load_weights(filename) 710 | 711 | # def get_num_params(self): 712 | # """Method to get the amount of trainable parameters in the model. 713 | # """ 714 | # trainable = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) 715 | # non_trainable = np.sum([np.prod(v.get_shape().as_list()) for v in tf.non_trainable_variables()]) 716 | # print('\nMODEL PARAMETERS') 717 | # print('Total parameters: %i' % (trainable + non_trainable)) 718 | # print('Trainable parameters: %i' % trainable) 719 | # print('Non-trainable parameters: %i' % non_trainable) 720 | 721 | 722 | def main(infile, sessname, neurons=64, layers=2, epochs=100, batchsize=128, window=0, step=1, target='all', 723 | valsplit=0.2, sample=100, aa='j', temperature=2.5, cell="LSTM", dropout=0.1, train=False, learningrate=0.01, 724 | modfile=None, samplelength=36, pad=0, l2_rate=None, cv=None, finetune=True, references=True): 725 | # loading sequence data, analyze, pad and encode it 726 | data = SequenceHandler(window=window, step=step, refs=references) 727 | print("Loading sequences...") 728 | data.load_sequences(infile) 729 | data.analyze_training() 730 | 731 | # pad sequences 732 | print("\nPadding sequences...") 733 | data.pad_sequences(padlen=pad) 734 | 735 | # one-hot encode padded sequences 736 | print("One-hot encoding sequences...") 737 | data.one_hot_encode(target=target) 738 | 739 | if train: 740 | # building the LSTM model 741 | print("\nBuilding model...") 742 | model = Model(n_vocab=len(data.vocab), outshape=len(data.vocab), session_name=sessname, n_units=neurons, 743 | batch=batchsize, layers=layers, cell=cell, loss='categorical_crossentropy', lr=learningrate, 744 | dropoutfract=dropout, l2_reg=l2_rate, ask=True, seed=42) 745 | print("Model built!") 746 | 747 | if cv: 748 | print("\nPERFORMING %i-FOLD CROSS-VALIDATION...\n" % cv) 749 | model.cross_val(data.X, data.y, epochs=epochs, cv=cv) 750 | model.initialize_model(seed=42) 751 | model.train(data.X, data.y, epochs=epochs, 
valsplit=0.0, sample=0) 752 | model.plot_losses() 753 | else: 754 | # training model on data 755 | print("\nTRAINING MODEL FOR %i EPOCHS...\n" % epochs) 756 | model.train(data.X, data.y, epochs=epochs, valsplit=valsplit, sample=0) 757 | model.plot_losses() # plot loss 758 | 759 | save_model_instance(model) 760 | 761 | elif finetune: 762 | print("\nUSING PRETRAINED MODEL FOR FINETUNING... (%s)\n" % modfile) 763 | print("Loading model...") 764 | model = load_model_instance(modfile) 765 | model.load_model(modfile) 766 | model.finetuneinit(sessname) # generate new session folders for finetuning run 767 | print("Finetuning model...") 768 | model.train(data.X, data.y, epochs=epochs, valsplit=valsplit, sample=0) 769 | model.plot_losses() # plot loss 770 | save_model_instance(model) 771 | else: 772 | print("\nUSING PRETRAINED MODEL... (%s)\n" % modfile) 773 | model = load_model_instance(modfile) 774 | model.load_model(modfile) 775 | 776 | print(model.model.summary()) # print number of parameters in the model 777 | 778 | # generating new data through sampling 779 | print("\nSAMPLING %i SEQUENCES...\n" % sample) 780 | data.generated = model.sample(sample, start=aa, maxlen=samplelength, show=False, temp=temperature) 781 | data.analyze_generated(sample, fname=model.logdir + '/analysis_temp' + str(temperature) + '.txt', plot=True) 782 | data.save_generated(model.logdir, model.logdir + '/sampled_sequences_temp' + str(temperature) + '.csv') 783 | 784 | 785 | if __name__ == "__main__": 786 | # run main code 787 | main(infile=args.dataset, sessname=args.name, batchsize=args.batch_size, epochs=args.epochs, 788 | layers=args.layers, valsplit=args.valsplit, neurons=args.neurons, cell=args.cell, sample=args.sample, 789 | temperature=args.temp, dropout=args.dropout, train=True, modfile=args.modfile, 790 | learningrate=args.lr, cv=args.cv, samplelength=args.maxlen, window=args.window, 791 | step=args.step, aa=args.startchar, l2_rate=args.l2, target=args.target, pad=args.padlen, 792 | finetune=False, references=args.refs) 793 | 794 | # save used flags to log file 795 | _save_flags("./" + args.name + "/flags.txt") 796 | -------------------------------------------------------------------------------- /LSTM_peptides/LSTM_peptides_fine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | ..author:: Alex Müller, ETH Zürich, Switzerland. 5 | ..date:: September 2017 6 | 7 | Code for training a LSTM model on peptide sequences followed by sampling novel sequences through the model. 8 | Check the readme for possible flags to use with this script. 
9 | """ 10 | import json 11 | import os 12 | import pickle 13 | import random 14 | import argparse 15 | 16 | 17 | 18 | import matplotlib.pyplot as plt 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from tensorflow.keras.callbacks import ModelCheckpoint 23 | from tensorflow.keras.initializers import RandomNormal 24 | from tensorflow.keras.layers import Dense, LSTM, GRU 25 | from tensorflow.keras.models import Sequential, load_model 26 | from tensorflow.keras.optimizers import Adam 27 | from tensorflow.keras.regularizers import l2 28 | from modlamp.analysis import GlobalAnalysis 29 | from modlamp.core import count_aas 30 | from modlamp.descriptors import PeptideDescriptor, GlobalDescriptor 31 | from modlamp.sequences import Random, Helices 32 | from progressbar import ProgressBar 33 | from scipy.spatial import distance 34 | from sklearn.model_selection import KFold 35 | from sklearn.preprocessing import StandardScaler 36 | 37 | plt.switch_backend('agg') 38 | flags = argparse.ArgumentParser() 39 | flags.add_argument("-d", "--dataset", default="training_sequences_noC.csv", help="dataset file (expecting csv)", type=str) 40 | flags.add_argument("-n", "--name", default="test", help="run name for log and checkpoint files", type=str) 41 | flags.add_argument("-b", "--batch_size", default=128, help="batch size", type=int) 42 | flags.add_argument("-e", "--epochs", default=50, help="epochs to train", type=int) 43 | flags.add_argument("-l", "--layers", default=2, help="number of layers in the network", type=int) 44 | flags.add_argument("-x", "--neurons", default=256, help="number of units per layer", type=int) 45 | flags.add_argument("-c", "--cell", default="LSTM", help="type of neuron to use, available: LSTM, GRU", type=str) 46 | flags.add_argument("-o", "--dropout", default=0.1, help="dropout to use in every layer; layer 1 gets 1*dropout, layer 2 2*dropout etc.", type=float) 47 | flags.add_argument("-t", "--train", default=False, help="whether the network should be trained or just sampled from", type=bool) 48 | flags.add_argument("-v", "--valsplit", default=0.2, help="fraction of the data to use for validation", type=float) 49 | flags.add_argument("-s", "--sample", default=100, help="number of sequences to sample training", type=int) 50 | flags.add_argument("-p", "--temp", default=1.25, help="temperature used for sampling", type=float) 51 | flags.add_argument("-m", "--maxlen", default=0, help="maximum sequence length allowed when sampling new sequences", type=int) 52 | flags.add_argument("-a", "--startchar", default="j", help="starting character to begin sampling. Default='j' for 'begin'", type=str) 53 | flags.add_argument("-r", "--lr", default=0.01, help="learning rate to be used with the Adam optimizer", type=float) 54 | flags.add_argument("--l2", default=None, help="l2 regularization rate. If None, no l2 regularization is used", type=float) 55 | flags.add_argument("--modfile", default=None, help="filename of the pretrained model to used for sampling if train=False", type=str) 56 | flags.add_argument("--finetune", default=True, help="if True, a pretrained model provided in modfile is finetuned on the dataset", type=bool) 57 | flags.add_argument("--cv", default=None, help="number of folds to use for cross-validation; if None, no CV is performed", type=int) 58 | flags.add_argument("--window", default=0, help="window size used to process sequences. 
If 0, all sequences are padded to the longest sequence length in the dataset", type=int) 59 | flags.add_argument("--step", default=1, help="step size to move window or prediction target", type=int) 60 | flags.add_argument("--target", default="all", help="whether to learn all proceeding characters or just the last `one` in sequence", type=str) 61 | flags.add_argument("--padlen", default=0, help="number of spaces to use for padding sequences (if window not 0); if 0, sequences are padded to the length of the longest sequence in the dataset", type=int) 62 | flags.add_argument("--refs", default=True, help="whether reference sequence sets should be generated for the analysis", type=bool) 63 | args = flags.parse_args() 64 | 65 | 66 | 67 | def _save_flags(filename): 68 | """ Function to save used arguments to log-file 69 | 70 | :return: saved file 71 | """ 72 | with open(filename, 'w') as f: 73 | f.write("Used flags:\n-----------\n") 74 | json.dump(args.__dict__, f, indent=2) 75 | 76 | 77 | def _onehotencode(s, vocab=None): 78 | """ Function to one-hot encode a sring. 79 | 80 | :param s: {str} String to encode in one-hot fashion 81 | :param vocab: vocabulary to use fore encoding, if None, default AAs are used 82 | :return: one-hot encoded string as a np.array 83 | : j init char " " pading char 84 | """ 85 | if not vocab: 86 | vocab = ['j','A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', ' '] 87 | 88 | # generate translation dictionary for one-hot encoding 89 | to_one_hot = dict() 90 | for i, a in enumerate(vocab): 91 | v = np.zeros(len(vocab)) 92 | v[i] = 1 93 | to_one_hot[a] = v 94 | 95 | result = [] 96 | for l in s: 97 | result.append(to_one_hot[l]) 98 | result = np.array(result) 99 | return np.reshape(result, (1, result.shape[0], result.shape[1])), to_one_hot, vocab 100 | 101 | 102 | def _onehotdecode(matrix, vocab=None, filename=None): 103 | """ Decode a given one-hot represented matrix back into sequences 104 | 105 | :param matrix: matrix containing sequence patterns that are one-hot encoded 106 | :param vocab: vocabulary, if None, standard AAs are used 107 | :param filename: filename for saving sequences, if ``None``, sequences are returned in a list 108 | :return: list of decoded sequences in the range lenmin-lenmax, if ``filename``, they are saved to a file 109 | """ 110 | if not vocab: 111 | _, _, vocab = _onehotencode('A') 112 | if len(matrix.shape) == 2: # if a matrix containing only one string is supplied 113 | result = [] 114 | for i in range(matrix.shape[0]): 115 | for j in range(matrix.shape[1]): 116 | aa = np.where(matrix[i, j] == 1.)[0][0] 117 | result.append(vocab[aa]) 118 | seq = ''.join(result) 119 | if filename: 120 | with open(filename, 'wb') as f: 121 | f.write(seq) 122 | else: 123 | return seq 124 | 125 | elif len(matrix.shape) == 3: # if a matrix containing several strings is supplied 126 | result = [] 127 | for n in range(matrix.shape[0]): 128 | oneresult = [] 129 | for i in range(matrix.shape[1]): 130 | for j in range(matrix.shape[2]): 131 | aa = np.where(matrix[n, i, j] == 1.)[0][0] 132 | oneresult.append(vocab[aa]) 133 | seq = ''.join(oneresult) 134 | result.append(seq) 135 | if filename: 136 | with open(filename, 'wb') as f: 137 | for s in result: 138 | f.write(s + '\n') 139 | else: 140 | return result 141 | 142 | 143 | def _sample_with_temp(preds, temp=1.0): 144 | """ Helper function to sample one letter from a probability array given a temperature. 
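# --- Illustrative sketch (not part of the original script) ------------------
# What _onehotencode() defined above returns for a short made-up peptide
# ('j' marks the sequence start, ' ' is the padding character):
import numpy as np

x, to_one_hot, vocab = _onehotencode('jKLLK ')
print(x.shape)                                          # (1, 6, 27): 1 sequence, 6 positions, 27 symbols
print(vocab[0], vocab[-1] == ' ')                       # 'j' True -> start token first, padding last
print(''.join(vocab[i] for i in x[0].argmax(axis=1)))   # 'jKLLK ' -> argmax per position recovers the string
# ----------------------------------------------------------------------------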
145 | 146 | :param preds: {np.array} predictions returned by the network 147 | :param temp: {float} temperature value to sample at. 148 | """ 149 | streched = np.log(preds) / temp 150 | stretched_probs = np.exp(streched) / np.sum(np.exp(streched)) 151 | return np.random.choice(len(streched), p=stretched_probs) 152 | 153 | 154 | def load_model_instance(filename): 155 | """ Load a whole Model class instance from a given epoch file 156 | 157 | :param filename: epoch file, e.g. model_epoch_5.hdf5 158 | :return: model instance with trained weights 159 | """ 160 | modfile = os.path.dirname(filename) + '/model.p' 161 | mod = pickle.load(open(modfile, 'rb')) 162 | hdf5_file = ''.join(modfile.split('.')[:-1]) + '.hdf5' 163 | mod.model = load_model(hdf5_file) 164 | return mod 165 | 166 | 167 | def save_model_instance(mod): 168 | """ Save a whole Model instance and the corresponding model with weights to two files (model.p and model.hdf5) 169 | 170 | :param mod: model instance 171 | :return: saved model files in the checkpoint dir 172 | """ 173 | tmp = mod.model 174 | tmp.save(mod.checkpointdir + 'model.hdf5') 175 | mod.model = None 176 | pickle.dump(mod, open(mod.checkpointdir + 'model.p', 'wb')) 177 | mod.model = tmp 178 | 179 | 180 | class SequenceHandler(object): 181 | """ Class for handling peptide sequences, e.g. loading, one-hot encoding or decoding and saving """ 182 | 183 | def __init__(self, window=0, step=2, refs=True): 184 | """ 185 | :param window: {str} window used for chopping up sequences. If 0: False 186 | :param step: {int} size of the steps to move the window forward 187 | :param refs {bool} whether to generate reference sequence sets for analysis 188 | """ 189 | self.sequences = None 190 | self.generated = None 191 | self.ran = None 192 | self.hel = None 193 | self.X = list() 194 | self.y = list() 195 | self.window = window 196 | self.step = step 197 | self.refs = refs 198 | # generate translation dictionary for one-hot encoding 199 | _, self.to_one_hot, self.vocab = _onehotencode('A') 200 | 201 | def load_sequences(self, filename): 202 | """ Method to load peptide sequences from a csv file 203 | 204 | :param filename: {str} filename of the sequence file to be read (``csv``, one sequence per line) 205 | :return: sequences in self.sequences 206 | """ 207 | with open(filename) as f: 208 | self.sequences = [s.strip() for s in f] 209 | self.sequences = random.sample(self.sequences, len(self.sequences)) # shuffle sequences randomly 210 | 211 | def pad_sequences(self, pad_char=' ', padlen=0): 212 | """ Pad all sequences to the longest length (default, padlen=0) or a given length 213 | 214 | :param pad_char: {str} Character to pad sequences with 215 | :param padlen: {int} Custom length of padding to add to all sequences to (optional), default: 0. If 216 | 0, sequences are padded to the length of the longest sequence in the training set. If a window is used and the 217 | padded sequence is shorter than the window size, it is padded to fit the window. 
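# --- Illustrative sketch (not part of the original script) ------------------
# Numeric example of what the temperature in _sample_with_temp() above does to
# a probability vector before sampling (the probabilities are invented):
import numpy as np

preds = np.array([0.7, 0.2, 0.1])

def rescale(p, temp):
    logits = np.log(p) / temp          # same transformation as in _sample_with_temp
    e = np.exp(logits)
    return e / e.sum()

print(rescale(preds, 0.5))   # ~[0.91, 0.07, 0.02] -> low temperature sharpens the distribution
print(rescale(preds, 1.0))   # [0.7, 0.2, 0.1]     -> temperature 1 leaves it unchanged
print(rescale(preds, 2.5))   # ~[0.48, 0.29, 0.22] -> high temperature flattens it (more diverse sampling)
# ----------------------------------------------------------------------------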
218 | """ 219 | if padlen: 220 | padded_seqs = [] 221 | for seq in self.sequences: 222 | if len(seq) < self.window: 223 | padded_seq = seq + pad_char * (self.step + self.window - len(seq)) 224 | else: 225 | padded_seq = seq + pad_char * padlen 226 | padded_seqs.append(padded_seq) 227 | else: 228 | length = max([len(seq) for seq in self.sequences]) 229 | padded_seqs = [] 230 | for seq in self.sequences: 231 | padded_seq = 'j' + seq + pad_char * (length - len(seq)) 232 | padded_seqs.append(padded_seq) 233 | 234 | if pad_char not in self.vocab: 235 | self.vocab += [pad_char] 236 | 237 | self.sequences = padded_seqs # overwrite sequences with padded sequences 238 | 239 | def one_hot_encode(self, target='all'): 240 | """ Chop up loaded sequences into patterns of length ``window`` by moving by stepsize ``step`` and translate 241 | them with a one-hot vector encoding 242 | 243 | :param target: {str} whether all proceeding AA should be learned or just the last one in sequence (`all`, `one`) 244 | :return: one-hot encoded sequence patterns in self.X and corresponding target amino acids in self.y 245 | """ 246 | if self.window == 0: 247 | for s in self.sequences: 248 | self.X.append([self.to_one_hot[char] for char in s[:-self.step]]) 249 | if target == 'all': 250 | self.y.append([self.to_one_hot[char] for char in s[self.step:]]) 251 | elif target == 'one': 252 | self.y.append(s[-self.step:]) 253 | 254 | self.X = np.reshape(self.X, (len(self.X), len(self.sequences[0]) - self.step, len(self.vocab))) 255 | self.y = np.reshape(self.y, (len(self.y), len(self.sequences[0]) - self.step, len(self.vocab))) 256 | 257 | else: 258 | for s in self.sequences: 259 | for i in range(0, len(s) - self.window, self.step): 260 | self.X.append([self.to_one_hot[char] for char in s[i: i + self.window]]) 261 | if target == 'all': 262 | self.y.append([self.to_one_hot[char] for char in s[i + 1: i + self.window + 1]]) 263 | elif target == 'one': 264 | self.y.append(s[-self.step:]) 265 | 266 | self.X = np.reshape(self.X, (len(self.X), self.window, len(self.vocab))) 267 | self.y = np.reshape(self.y, (len(self.y), self.window, len(self.vocab))) 268 | 269 | print("\nData shape:\nX: " + str(self.X.shape) + "\ny: " + str(self.y.shape)) 270 | 271 | def analyze_training(self): 272 | """ Method to analyze the distribution of the training data 273 | 274 | :return: prints out information about the length distribution of the sequences in ``self.sequences`` 275 | """ 276 | d = GlobalDescriptor(self.sequences) 277 | d.length() 278 | print("\nLENGTH DISTRIBUTION OF TRAINING DATA:\n") 279 | print("Number of sequences: \t%i" % len(self.sequences)) 280 | print("Mean sequence length: \t%.1f ± %.1f" % (np.mean(d.descriptor), np.std(d.descriptor))) 281 | print("Median sequence length: \t%i" % np.median(d.descriptor)) 282 | print("Minimal sequence length:\t%i" % np.min(d.descriptor)) 283 | print("Maximal sequence length:\t%i" % np.max(d.descriptor)) 284 | 285 | def analyze_generated(self, num, fname='analysis.txt', plot=False,min_length_seq=3): 286 | """ Method to analyze the generated sequences located in `self.generated`. 287 | 288 | :param num: {int} wanted number of sequences to sample 289 | :param fname: {str} filename to save analysis info to 290 | :param plot: {bool} whether to plot an overview of descriptors 291 | :return: file with analysis info (distances) 292 | """ 293 | with open(fname, 'w') as f: 294 | print("Analyzing...") 295 | f.write("ANALYSIS OF SAMPLED SEQUENCES\n==============================\n\n") 296 | f.write("Nr. 
of duplicates in generated sequences: %i\n" % (len(self.generated) - len(set(self.generated)))) 297 | count = len(set(self.generated) & set(self.sequences)) # get shared entries in both lists 298 | f.write("%.1f percent of generated sequences are present in the training data.\n" % 299 | ((count / len(self.generated)) * 100)) 300 | d = GlobalDescriptor(self.generated) 301 | len1 = len(d.sequences) 302 | d.filter_aa('j') 303 | len2 = len(d.sequences) 304 | d.length() 305 | f.write("\n\nLENGTH DISTRIBUTION OF GENERATED DATA:\n\n") 306 | f.write("Number of sequences too short:\t%i\n" % (num - len1)) 307 | f.write("Number of invalid (with j):\t%i\n" % (len1 - len2)) 308 | f.write("Number of valid unique seqs:\t%i\n" % len2) 309 | f.write("Mean sequence length: \t\t%.1f ± %.1f\n" % (np.mean(d.descriptor), np.std(d.descriptor))) 310 | f.write("Median sequence length: \t\t%i\n" % np.median(d.descriptor)) 311 | f.write("Minimal sequence length: \t\t%i\n" % np.min(d.descriptor)) 312 | f.write("Maximal sequence length: \t\t%i\n" % np.max(d.descriptor)) 313 | 314 | 315 | self.sequences = [s[1:].rstrip() for s in self.sequences] 316 | 317 | self.analyze_training() 318 | 319 | d.sequences = [s for s in d.sequences] 320 | 321 | descriptor = 'pepcats' 322 | #seq_desc = PeptideDescriptor([s[1:].rstrip() for s in self.sequences], descriptor) 323 | seq_desc = PeptideDescriptor(self.sequences, descriptor) 324 | seq_desc.calculate_autocorr(min_length_seq) 325 | 326 | gen_desc = PeptideDescriptor(d.sequences, descriptor) 327 | gen_desc.calculate_autocorr(min_length_seq) 328 | 329 | # random comparison set 330 | self.ran = Random(len(self.generated), np.min(d.descriptor), np.max(d.descriptor)) # generate rand seqs 331 | probas = count_aas(''.join(seq_desc.sequences)).values() # get the aa distribution of training seqs 332 | self.ran.generate_sequences(proba=probas) 333 | ran_desc = PeptideDescriptor(self.ran.sequences, descriptor) 334 | ran_desc.calculate_autocorr(min_length_seq) 335 | 336 | # amphipathic helices comparison set 337 | self.hel = Helices(len(self.generated), np.min(d.descriptor), np.max(d.descriptor)) 338 | self.hel.generate_sequences() 339 | hel_desc = PeptideDescriptor(self.hel.sequences, descriptor) 340 | hel_desc.calculate_autocorr(min_length_seq) 341 | 342 | # distance calculation 343 | f.write("\n\nDISTANCE CALCULATION IN '%s' DESCRIPTOR SPACE\n\n" % descriptor.upper()) 344 | desc_dist = distance.cdist(gen_desc.descriptor, seq_desc.descriptor, metric='euclidean') 345 | f.write("Average euclidean distance of sampled to training data:\t%.3f +/- %.3f\n" % 346 | (np.mean(desc_dist), np.std(desc_dist))) 347 | ran_dist = distance.cdist(ran_desc.descriptor, seq_desc.descriptor, metric='euclidean') 348 | f.write("Average euclidean distance if randomly sampled seqs:\t%.3f +/- %.3f\n" % 349 | (np.mean(ran_dist), np.std(ran_dist))) 350 | hel_dist = distance.cdist(hel_desc.descriptor, seq_desc.descriptor, metric='euclidean') 351 | f.write("Average euclidean distance if amphipathic helical seqs:\t%.3f +/- %.3f\n" % 352 | (np.mean(hel_dist), np.std(hel_dist))) 353 | 354 | # more simple descriptors 355 | g_seq = GlobalDescriptor(seq_desc.sequences) 356 | g_gen = GlobalDescriptor(gen_desc.sequences) 357 | g_ran = GlobalDescriptor(ran_desc.sequences) 358 | g_hel = GlobalDescriptor(hel_desc.sequences) 359 | g_seq.calculate_all() 360 | g_gen.calculate_all() 361 | g_ran.calculate_all() 362 | g_hel.calculate_all() 363 | sclr = StandardScaler() 364 | sclr.fit(g_seq.descriptor) 365 | f.write("\n\nDISTANCE 
CALCULATION FOR SCALED GLOBAL DESCRIPTORS\n\n") 366 | desc_dist = distance.cdist(sclr.transform(g_gen.descriptor), sclr.transform(g_seq.descriptor), 367 | metric='euclidean') 368 | f.write("Average euclidean distance of sampled to training data:\t%.2f +/- %.2f\n" % 369 | (np.mean(desc_dist), np.std(desc_dist))) 370 | ran_dist = distance.cdist(sclr.transform(g_ran.descriptor), sclr.transform(g_seq.descriptor), 371 | metric='euclidean') 372 | f.write("Average euclidean distance if randomly sampled seqs:\t%.2f +/- %.2f\n" % 373 | (np.mean(ran_dist), np.std(ran_dist))) 374 | hel_dist = distance.cdist(sclr.transform(g_hel.descriptor), sclr.transform(g_seq.descriptor), 375 | metric='euclidean') 376 | f.write("Average euclidean distance if amphipathic helical seqs:\t%.2f +/- %.2f\n" % 377 | (np.mean(hel_dist), np.std(hel_dist))) 378 | 379 | # hydrophobic moments 380 | uh_seq = PeptideDescriptor(seq_desc.sequences, 'eisenberg') 381 | uh_seq.calculate_moment() 382 | uh_gen = PeptideDescriptor(gen_desc.sequences, 'eisenberg') 383 | uh_gen.calculate_moment() 384 | uh_ran = PeptideDescriptor(ran_desc.sequences, 'eisenberg') 385 | uh_ran.calculate_moment() 386 | uh_hel = PeptideDescriptor(hel_desc.sequences, 'eisenberg') 387 | uh_hel.calculate_moment() 388 | f.write("\n\nHYDROPHOBIC MOMENTS\n\n") 389 | f.write("Hydrophobic moment of training seqs:\t%.3f +/- %.3f\n" % 390 | (np.mean(uh_seq.descriptor), np.std(uh_seq.descriptor))) 391 | f.write("Hydrophobic moment of sampled seqs:\t\t%.3f +/- %.3f\n" % 392 | (np.mean(uh_gen.descriptor), np.std(uh_gen.descriptor))) 393 | f.write("Hydrophobic moment of random seqs:\t\t%.3f +/- %.3f\n" % 394 | (np.mean(uh_ran.descriptor), np.std(uh_ran.descriptor))) 395 | f.write("Hydrophobic moment of amphipathic seqs:\t%.3f +/- %.3f\n" % 396 | (np.mean(uh_hel.descriptor), np.std(uh_hel.descriptor))) 397 | 398 | if plot: 399 | if self.refs: 400 | a = GlobalAnalysis([uh_seq.sequences, uh_gen.sequences, uh_hel.sequences, uh_ran.sequences], 401 | ['training', 'sampled', 'hel', 'ran']) 402 | else: 403 | a = GlobalAnalysis([uh_seq.sequences, uh_gen.sequences], ['training', 'sampled']) 404 | a.plot_summary(filename=fname[:-4] + '.png') 405 | 406 | def save_generated(self, logdir, filename): 407 | """ Save all sequences in `self.generated` to file 408 | 409 | :param logdir: {str} current log directory (used for comparison sequences) 410 | :param filename: {str} filename to save the sequences to 411 | :return: saved file 412 | """ 413 | with open(filename, 'w') as f: 414 | for s in self.generated: 415 | f.write(s + '\n') 416 | 417 | self.ran.save_fasta(logdir + '/random_sequences.fasta') 418 | self.hel.save_fasta(logdir + '/helical_sequences.fasta') 419 | 420 | 421 | class Model(object): 422 | """ 423 | Class containing the LSTM model to learn sequential data 424 | """ 425 | 426 | def __init__(self, n_vocab, outshape, session_name, cell="LSTM", n_units=256, batch=64, layers=2, lr=0.001, 427 | dropoutfract=0.1, loss='categorical_crossentropy', l2_reg=None, ask=True, seed=42): 428 | """ Initialize the model 429 | 430 | :param n_vocab: {int} length of vocabulary 431 | :param outshape: {int} output dimensionality of the model 432 | :param session_name: {str} custom name for the current session. Will create directory with this name to save 433 | results / logs to. 
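# --- Illustrative sketch (not part of the original script) ------------------
# The hydrophobic-moment comparison above uses the Eisenberg scale via modlamp.
# A minimal standalone version on one invented, helix-like placeholder sequence:
from modlamp.descriptors import PeptideDescriptor

pep = PeptideDescriptor(['KLLKLLKKLLKLLK'], 'eisenberg')
pep.calculate_moment()        # hydrophobic moment per sequence
print(pep.descriptor)         # larger values indicate stronger amphipathicity
# ----------------------------------------------------------------------------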
434 | :param n_units: {int} number of LSTM units per layer 435 | :param batch: {int} batch size 436 | :param layers: {int} number of layers in the network 437 | :param loss: {str} applied loss function, choose from available keras loss functions 438 | :param lr: {float} learning rate to use with Adam optimizer 439 | :param dropoutfract: {float} fraction of dropout to add to each layer. Layer1 gets 1 * value, Layer2 2 * 440 | value and so on. 441 | :param l2_reg: {float} l2 regularization for kernel 442 | :param seed {int} random seed used to initialize weights 443 | """ 444 | random.seed(seed) 445 | self.seed = seed 446 | self.dropout = dropoutfract 447 | self.inshape = (None, n_vocab) 448 | self.outshape = outshape 449 | self.neurons = n_units 450 | self.layers = layers 451 | self.losses = list() 452 | self.val_losses = list() 453 | self.batchsize = batch 454 | self.lr = lr 455 | self.cv_loss = None 456 | self.cv_loss_std = None 457 | self.cv_val_loss = None 458 | self.cv_val_loss_std = None 459 | self.model = None 460 | self.cell = cell 461 | self.losstype = loss 462 | self.session_name = session_name 463 | self.logdir = './' + session_name 464 | self.l2 = l2_reg 465 | if ask and os.path.exists(self.logdir): 466 | decision = input('\nSession folder already exists!\n' 467 | 'Do you want to overwrite the previous session? [y/n] ') 468 | if decision in ['n', 'no', 'N', 'NO', 'No']: 469 | self.logdir = './' + input('Enter new session name: ') 470 | os.makedirs(self.logdir) 471 | self.checkpointdir = self.logdir + '/checkpoint/' 472 | if not os.path.exists(self.checkpointdir): 473 | os.makedirs(self.checkpointdir) 474 | _, _, self.vocab = _onehotencode('A') 475 | 476 | self.initialize_model(seed=self.seed) 477 | 478 | def initialize_model(self, seed=42): 479 | """ Method to initialize the model with all parameters saved in the attributes. This method is used during 480 | initialization of the class, as well as in cross-validation to reinitialize a fresh model for every fold. 
481 | 482 | :param seed: {int} random seed to use for weight initialization 483 | 484 | :return: initialized model in ``self.model`` 485 | """ 486 | self.losses = list() 487 | self.val_losses = list() 488 | self.cv_loss = None 489 | self.cv_loss_std = None 490 | self.cv_val_loss = None 491 | self.cv_val_loss_std = None 492 | self.model = None 493 | weight_init = RandomNormal(mean=0.0, stddev=0.05, seed=seed) # weights drawn from a normal distribution with stddev 0.05 494 | optimizer = Adam(learning_rate=self.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) # 'lr' was renamed to 'learning_rate' in newer Keras versions 495 | 496 | if self.l2: 497 | l2reg = l2(self.l2) 498 | else: 499 | l2reg = None 500 | 501 | self.model = Sequential() 502 | for l in range(self.layers): 503 | if self.cell == "GRU": 504 | self.model.add(GRU(units=self.neurons, 505 | name='GRU%i' % (l + 1), 506 | input_shape=self.inshape, 507 | return_sequences=True, 508 | kernel_initializer=weight_init, 509 | kernel_regularizer=l2reg, 510 | dropout=self.dropout * (l + 1))) 511 | else: 512 | self.model.add(LSTM(units=self.neurons, 513 | name='LSTM%i' % (l + 1), 514 | input_shape=self.inshape, 515 | return_sequences=True, 516 | kernel_initializer=weight_init, 517 | kernel_regularizer=l2reg, 518 | dropout=self.dropout * (l + 1), 519 | recurrent_dropout=self.dropout * (l + 1))) 520 | self.model.add(Dense(self.outshape, 521 | name='Dense', 522 | activation='softmax', 523 | kernel_regularizer=l2reg, # pass the regularizer object, not the raw l2 float 524 | kernel_initializer=weight_init)) 525 | self.model.compile(loss=self.losstype, optimizer=optimizer) 526 | with open(self.checkpointdir + "model.json", 'w') as f: 527 | json.dump(self.model.to_json(), f) 528 | self.model.summary() 529 | 530 | def finetuneinit(self, session_name): 531 | """ Method to generate a new directory for finetuning a pre-existing model on a new dataset with a new name 532 | 533 | :param session_name: {str} new session name for finetuning 534 | :return: generates all necessary session folders 535 | """ 536 | self.session_name = session_name 537 | self.logdir = './' + session_name 538 | if os.path.exists(self.logdir): 539 | decision = input('\nSession folder already exists!\n' 540 | 'Do you want to overwrite the previous session? [y/n] ') 541 | if decision in ['n', 'no', 'N', 'NO', 'No']: 542 | self.logdir = './' + input('Enter new session name: ') 543 | os.makedirs(self.logdir) 544 | self.checkpointdir = self.logdir + '/checkpoint/' 545 | if not os.path.exists(self.checkpointdir): 546 | os.makedirs(self.checkpointdir) 547 | 548 | def train(self, x, y, epochs=100, valsplit=0.2, sample=100): 549 | """ Train the model on given training data.
550 | 551 | :param x: {array} training data 552 | :param y: {array} targets for training data in X 553 | :param epochs: {int} number of epochs to train 554 | :param valsplit: {float} fraction of data that should be used as validation data during training 555 | :param sample: {int} number of sequences to sample after every training epoch 556 | :return: trained model and measured losses in self.model, self.losses and self.val_losses 557 | """ 558 | writer = tf.summary.create_file_writer('./logs/' + self.session_name) 559 | with writer.as_default(): 560 | for e in range(epochs): 561 | print("Epoch %i" % e) 562 | checkpoints = [ModelCheckpoint(filepath=self.checkpointdir + 'model_epoch_%i.hdf5' % e, verbose=0)] 563 | train_history = self.model.fit(x, y, epochs=1, batch_size=self.batchsize, validation_split=valsplit, 564 | shuffle=False, callbacks=checkpoints) 565 | tf.summary.scalar('loss', train_history.history['loss'][-1], step=e) 566 | self.losses.append(train_history.history['loss']) 567 | if valsplit > 0.: 568 | self.val_losses.append(train_history.history['val_loss']) 569 | tf.summary.scalar('val_loss', train_history.history['val_loss'][-1], step=e) 570 | if sample: 571 | for s in self.sample(sample): # sample sequences after every training epoch 572 | print(s) 573 | writer.close() 574 | 575 | def cross_val(self, x, y, epochs=100, cv=5, plot=True): 576 | """ Method to perform cross-validation with the model given data X, y 577 | 578 | :param x: {array} training data 579 | :param y: {array} targets for training data in X 580 | :param epochs: {int} number of epochs to train 581 | :param cv: {int} fold 582 | :param plot: {bool} whether the losses should be plotted and saved to the session folder 583 | :return: 584 | """ 585 | self.losses = list() # clean losses if already present 586 | self.val_losses = list() 587 | kf = KFold(n_splits=cv) 588 | cntr = 0 589 | for train, test in kf.split(x): 590 | print("\nFold %i" % (cntr + 1)) 591 | self.initialize_model(seed=cntr) # reinitialize every fold, otherwise it will "remember" previous data 592 | train_history = self.model.fit(x[train], y[train], epochs=epochs, batch_size=self.batchsize, 593 | validation_data=(x[test], y[test])) 594 | self.losses.append(train_history.history['loss']) 595 | self.val_losses.append(train_history.history['val_loss']) 596 | cntr += 1 597 | self.cv_loss = np.mean(self.losses, axis=0) 598 | self.cv_loss_std = np.std(self.losses, axis=0) 599 | self.cv_val_loss = np.mean(self.val_losses, axis=0) 600 | self.cv_val_loss_std = np.std(self.val_losses, axis=0) 601 | if plot: 602 | self.plot_losses(cv=True) 603 | 604 | # get best epoch with corresponding val_loss 605 | minloss = np.min(self.cv_val_loss) 606 | e = np.where(minloss == self.cv_val_loss)[0][0] 607 | print("\n%i-fold cross-validation result:\n\nBest epoch:\t%i\nVal_loss:\t%.4f" % (cv, e, minloss)) 608 | with open(self.logdir + '/' + self.session_name + '_best_epoch.txt', 'w') as f: 609 | f.write("%i-fold cross-validation result:\n\nBest epoch:\t%i\nVal_loss:\t%.4f" % (cv, e, minloss)) 610 | 611 | def plot_losses(self, show=False, cv=False): 612 | """Plot the losses obtained in training. 613 | 614 | :param show: {bool} Whether the plot should be shown or saved. If ``False``, the plot is saved to the 615 | session folder. 616 | :param cv: {bool} Whether the losses from cross-validation should be plotted. The standard deviation will be 617 | depicted as filled areas around the mean curve. 
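# --- Illustrative sketch (not part of the original script) ------------------
# How cross_val() above condenses the per-fold validation losses into a "best
# epoch": average over folds, then take the epoch with the minimal mean loss.
# The loss values below are invented:
import numpy as np

val_losses = np.array([[2.1, 1.6, 1.4, 1.5],    # fold 1, one value per epoch
                       [2.0, 1.7, 1.3, 1.4],    # fold 2
                       [2.2, 1.5, 1.5, 1.6]])   # fold 3
cv_val_loss = val_losses.mean(axis=0)           # mean validation loss per epoch
best_epoch = int(np.argmin(cv_val_loss))
print(best_epoch, cv_val_loss[best_epoch])      # 2 ~1.4 -> epoch index with the lowest mean loss
# ----------------------------------------------------------------------------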
618 | :return: plot (saved) or shown interactive 619 | """ 620 | fig, ax = plt.subplots() 621 | ax.set_title('LSTM Categorical Crossentropy Loss Plot', fontweight='bold', fontsize=16) 622 | if cv: 623 | filename = self.logdir + '/' + self.session_name + '_cv_loss_plot.pdf' 624 | x = range(1, len(self.cv_loss) + 1) 625 | ax.plot(x, self.cv_loss, '-', color='#FE4365', label='Training') 626 | ax.plot(x, self.cv_val_loss, '-', color='k', label='Validation') 627 | ax.fill_between(x, self.cv_loss + self.cv_loss_std, self.cv_loss - self.cv_loss_std, 628 | facecolors='#FE4365', alpha=0.5) 629 | ax.fill_between(x, self.cv_val_loss + self.cv_val_loss_std, self.cv_val_loss - self.cv_val_loss_std, 630 | facecolors='k', alpha=0.5) 631 | ax.set_xlim([0.5, len(self.cv_loss) + 0.5]) 632 | minloss = np.min(self.cv_val_loss) 633 | plt.text(x=0.5, y=0.5, s='best epoch: ' + str(np.where(minloss == self.cv_val_loss)[0][0]) + ', val_loss: ' 634 | + str(minloss.round(4)), transform=ax.transAxes) 635 | else: 636 | filename = self.logdir + '/' + self.session_name + '_loss_plot.pdf' 637 | x = range(1, len(self.losses) + 1) 638 | ax.plot(x, self.losses, '-', color='#FE4365', label='Training') 639 | if self.val_losses: 640 | ax.plot(x, self.val_losses, '-', color='k', label='Validation') 641 | ax.set_xlim([0.5, len(self.losses) + 0.5]) 642 | ax.set_ylabel('Loss', fontweight='bold', fontsize=14) 643 | ax.set_xlabel('Epoch', fontweight='bold', fontsize=14) 644 | ax.spines['right'].set_visible(False) 645 | ax.spines['top'].set_visible(False) 646 | ax.xaxis.set_ticks_position('bottom') 647 | ax.yaxis.set_ticks_position('left') 648 | plt.legend(loc='best') 649 | if show: 650 | plt.show() 651 | else: 652 | plt.savefig(filename) 653 | 654 | def sample(self, num=100, minlen=7, maxlen=50, start=None, temp=2.5, show=False): 655 | """Invoke generation of sequence patterns through sampling from the trained model. 656 | 657 | :param num: {int} number of sequences to sample 658 | :param minlen {int} minimal allowed sequence length 659 | :param maxlen: {int} maximal length of each pattern generated, if 0, a random length is chosen between 7 and 50 660 | :param start: {str} start AA to be used for sampling. If ``None``, a random AA is chosen 661 | :param temp: {float} temperature value to sample at. 
662 | :param show: {bool} whether the sampled sequences should be printed out 663 | :return: {array} matrix of patterns of shape (num, seqlen, inputshape[0]) 664 | """ 665 | print("\nSampling...\n") 666 | sampled = [] 667 | lcntr = 0 668 | pbar = ProgressBar() 669 | for rs in pbar(range(num)): 670 | random.seed(rs) 671 | if not maxlen: # if the length should be randomly sampled 672 | longest = np.random.randint(7, 50) 673 | else: 674 | longest = maxlen 675 | 676 | if start: 677 | start_aa = start 678 | else: # generate random starting letter 679 | start_aa = 'j' 680 | sequence = start_aa # start with starting letter 681 | 682 | while sequence[-1] != ' ' and len(sequence) <= longest: # sample until padding or maxlen is reached 683 | x, _, _ = _onehotencode(sequence) 684 | preds = self.model.predict(x)[0][-1] 685 | next_aa = _sample_with_temp(preds, temp=temp) 686 | sequence += self.vocab[next_aa] 687 | 688 | if start_aa == 'j': 689 | sequence = sequence[1:].rstrip() 690 | else: # keep starting AA if chosen for sampling 691 | sequence = sequence.rstrip() 692 | 693 | if len(sequence) < minlen: # don't take sequences shorter than the minimal length 694 | lcntr += 1 695 | continue 696 | 697 | sampled.append(sequence) 698 | if show: 699 | print(sequence) 700 | 701 | print("\t%i sequences were shorter than %i" % (lcntr, minlen)) 702 | return sampled 703 | 704 | def load_model(self, filename): 705 | """Method to load a trained model from a hdf5 file 706 | 707 | :return: model loaded from file in ``self.model`` 708 | """ 709 | self.model.load_weights(filename) 710 | 711 | # def get_num_params(self): 712 | # """Method to get the amount of trainable parameters in the model. 713 | # """ 714 | # trainable = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) 715 | # non_trainable = np.sum([np.prod(v.get_shape().as_list()) for v in tf.non_trainable_variables()]) 716 | # print('\nMODEL PARAMETERS') 717 | # print('Total parameters: %i' % (trainable + non_trainable)) 718 | # print('Trainable parameters: %i' % trainable) 719 | # print('Non-trainable parameters: %i' % non_trainable) 720 | 721 | 722 | def main(infile, sessname, neurons=64, layers=2, epochs=100, batchsize=128, window=0, step=1, target='all', 723 | valsplit=0.2, sample=100, aa='j', temperature=2.5, cell="LSTM", dropout=0.1, train=False, learningrate=0.01, 724 | modfile=None, samplelength=36, pad=0, l2_rate=None, cv=None, finetune=True, references=True): 725 | # loading sequence data, analyze, pad and encode it 726 | data = SequenceHandler(window=window, step=step, refs=references) 727 | print("Loading sequences...") 728 | data.load_sequences(infile) 729 | data.analyze_training() 730 | 731 | # pad sequences 732 | print("\nPadding sequences...") 733 | data.pad_sequences(padlen=pad) 734 | 735 | # one-hot encode padded sequences 736 | print("One-hot encoding sequences...") 737 | data.one_hot_encode(target=target) 738 | 739 | if train: 740 | # building the LSTM model 741 | print("\nBuilding model...") 742 | model = Model(n_vocab=len(data.vocab), outshape=len(data.vocab), session_name=sessname, n_units=neurons, 743 | batch=batchsize, layers=layers, cell=cell, loss='categorical_crossentropy', lr=learningrate, 744 | dropoutfract=dropout, l2_reg=l2_rate, ask=True, seed=42) 745 | print("Model built!") 746 | 747 | if cv: 748 | print("\nPERFORMING %i-FOLD CROSS-VALIDATION...\n" % cv) 749 | model.cross_val(data.X, data.y, epochs=epochs, cv=cv) 750 | model.initialize_model(seed=42) 751 | model.train(data.X, data.y, epochs=epochs, 
valsplit=0.0, sample=0) 752 | model.plot_losses() 753 | else: 754 | # training model on data 755 | print("\nTRAINING MODEL FOR %i EPOCHS...\n" % epochs) 756 | model.train(data.X, data.y, epochs=epochs, valsplit=valsplit, sample=0) 757 | model.plot_losses() # plot loss 758 | 759 | save_model_instance(model) 760 | 761 | elif finetune: 762 | print("\nUSING PRETRAINED MODEL FOR FINETUNING... (%s)\n" % modfile) 763 | print("Loading model...") 764 | model = load_model_instance(modfile) 765 | model.load_model(modfile) 766 | model.finetuneinit(sessname) # generate new session folders for finetuning run 767 | print("Finetuning model...") 768 | model.train(data.X, data.y, epochs=epochs, valsplit=valsplit, sample=0) 769 | model.plot_losses() # plot loss 770 | save_model_instance(model) 771 | else: 772 | print("\nUSING PRETRAINED MODEL... (%s)\n" % modfile) 773 | model = load_model_instance(modfile) 774 | model.load_model(modfile) 775 | 776 | print(model.model.summary()) # print number of parameters in the model 777 | 778 | # generating new data through sampling 779 | print("\nSAMPLING %i SEQUENCES...\n" % sample) 780 | data.generated = model.sample(sample, start=aa, maxlen=samplelength, show=False, temp=temperature) 781 | data.analyze_generated(sample, fname=model.logdir + '/analysis_temp' + str(temperature) + '.txt', plot=True) 782 | data.save_generated(model.logdir, model.logdir + '/sampled_sequences_temp' + str(temperature) + '.csv') 783 | 784 | 785 | if __name__ == "__main__": 786 | # run main code 787 | main(infile=args.dataset, sessname=args.name, batchsize=args.batch_size, epochs=args.epochs, 788 | layers=args.layers, valsplit=args.valsplit, neurons=args.neurons, cell=args.cell, sample=args.sample, 789 | temperature=args.temp, dropout=args.dropout, train=False, modfile=args.modfile, 790 | learningrate=args.lr, cv=args.cv, samplelength=args.maxlen, window=args.window, 791 | step=args.step, aa=args.startchar, l2_rate=args.l2, target=args.target, pad=args.padlen, 792 | finetune=True, references=args.refs) 793 | 794 | # save used flags to log file 795 | _save_flags("./" + args.name + "/flags.txt") 796 | -------------------------------------------------------------------------------- /LSTM_peptides/LSTM_peptides_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | ..author:: Alex Müller, ETH Zürich, Switzerland. 5 | ..date:: September 2017 6 | 7 | Code for training a LSTM model on peptide sequences followed by sampling novel sequences through the model. 8 | Check the readme for possible flags to use with this script. 
9 | """ 10 | import json 11 | import os 12 | import pickle 13 | import random 14 | import argparse 15 | 16 | 17 | 18 | import matplotlib.pyplot as plt 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from tensorflow.keras.callbacks import ModelCheckpoint 23 | from tensorflow.keras.initializers import RandomNormal 24 | from tensorflow.keras.layers import Dense, LSTM, GRU 25 | from tensorflow.keras.models import Sequential, load_model 26 | from tensorflow.keras.optimizers import Adam 27 | from tensorflow.keras.regularizers import l2 28 | from modlamp.analysis import GlobalAnalysis 29 | from modlamp.core import count_aas 30 | from modlamp.descriptors import PeptideDescriptor, GlobalDescriptor 31 | from modlamp.sequences import Random, Helices 32 | from progressbar import ProgressBar 33 | from scipy.spatial import distance 34 | from sklearn.model_selection import KFold 35 | from sklearn.preprocessing import StandardScaler 36 | 37 | plt.switch_backend('agg') 38 | flags = argparse.ArgumentParser() 39 | flags.add_argument("-d", "--dataset", default="training_sequences_noC.csv", help="dataset file (expecting csv)", type=str) 40 | flags.add_argument("-n", "--name", default="test", help="run name for log and checkpoint files", type=str) 41 | flags.add_argument("-b", "--batch_size", default=128, help="batch size", type=int) 42 | flags.add_argument("-e", "--epochs", default=50, help="epochs to train", type=int) 43 | flags.add_argument("-l", "--layers", default=2, help="number of layers in the network", type=int) 44 | flags.add_argument("-x", "--neurons", default=256, help="number of units per layer", type=int) 45 | flags.add_argument("-c", "--cell", default="LSTM", help="type of neuron to use, available: LSTM, GRU", type=str) 46 | flags.add_argument("-o", "--dropout", default=0.1, help="dropout to use in every layer; layer 1 gets 1*dropout, layer 2 2*dropout etc.", type=float) 47 | flags.add_argument("-t", "--train", default=False, help="whether the network should be trained or just sampled from", type=bool) 48 | flags.add_argument("-v", "--valsplit", default=0.2, help="fraction of the data to use for validation", type=float) 49 | flags.add_argument("-s", "--sample", default=100, help="number of sequences to sample training", type=int) 50 | flags.add_argument("-p", "--temp", default=1.25, help="temperature used for sampling", type=float) 51 | flags.add_argument("-m", "--maxlen", default=0, help="maximum sequence length allowed when sampling new sequences", type=int) 52 | flags.add_argument("-a", "--startchar", default="j", help="starting character to begin sampling. Default='j' for 'begin'", type=str) 53 | flags.add_argument("-r", "--lr", default=0.01, help="learning rate to be used with the Adam optimizer", type=float) 54 | flags.add_argument("--l2", default=None, help="l2 regularization rate. If None, no l2 regularization is used", type=float) 55 | flags.add_argument("--modfile", default=None, help="filename of the pretrained model to used for sampling if train=False", type=str) 56 | flags.add_argument("--finetune", default=True, help="if True, a pretrained model provided in modfile is finetuned on the dataset", type=bool) 57 | flags.add_argument("--cv", default=None, help="number of folds to use for cross-validation; if None, no CV is performed", type=int) 58 | flags.add_argument("--window", default=0, help="window size used to process sequences. 
If 0, all sequences are padded to the longest sequence length in the dataset", type=int) 59 | flags.add_argument("--step", default=1, help="step size to move window or prediction target", type=int) 60 | flags.add_argument("--target", default="all", help="whether to learn all proceeding characters or just the last `one` in sequence", type=str) 61 | flags.add_argument("--padlen", default=0, help="number of spaces to use for padding sequences (if window not 0); if 0, sequences are padded to the length of the longest sequence in the dataset", type=int) 62 | flags.add_argument("--refs", default=True, help="whether reference sequence sets should be generated for the analysis", type=bool) 63 | args = flags.parse_args() 64 | 65 | 66 | 67 | def _save_flags(filename): 68 | """ Function to save used arguments to log-file 69 | 70 | :return: saved file 71 | """ 72 | with open(filename, 'w') as f: 73 | f.write("Used flags:\n-----------\n") 74 | json.dump(args.__dict__, f, indent=2) 75 | 76 | 77 | def _onehotencode(s, vocab=None): 78 | """ Function to one-hot encode a sring. 79 | 80 | :param s: {str} String to encode in one-hot fashion 81 | :param vocab: vocabulary to use fore encoding, if None, default AAs are used 82 | :return: one-hot encoded string as a np.array 83 | : j init char " " pading char 84 | """ 85 | if not vocab: 86 | vocab = ['j','A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', ' '] 87 | 88 | # generate translation dictionary for one-hot encoding 89 | to_one_hot = dict() 90 | for i, a in enumerate(vocab): 91 | v = np.zeros(len(vocab)) 92 | v[i] = 1 93 | to_one_hot[a] = v 94 | 95 | result = [] 96 | for l in s: 97 | result.append(to_one_hot[l]) 98 | result = np.array(result) 99 | return np.reshape(result, (1, result.shape[0], result.shape[1])), to_one_hot, vocab 100 | 101 | 102 | def _onehotdecode(matrix, vocab=None, filename=None): 103 | """ Decode a given one-hot represented matrix back into sequences 104 | 105 | :param matrix: matrix containing sequence patterns that are one-hot encoded 106 | :param vocab: vocabulary, if None, standard AAs are used 107 | :param filename: filename for saving sequences, if ``None``, sequences are returned in a list 108 | :return: list of decoded sequences in the range lenmin-lenmax, if ``filename``, they are saved to a file 109 | """ 110 | if not vocab: 111 | _, _, vocab = _onehotencode('A') 112 | if len(matrix.shape) == 2: # if a matrix containing only one string is supplied 113 | result = [] 114 | for i in range(matrix.shape[0]): 115 | for j in range(matrix.shape[1]): 116 | aa = np.where(matrix[i, j] == 1.)[0][0] 117 | result.append(vocab[aa]) 118 | seq = ''.join(result) 119 | if filename: 120 | with open(filename, 'wb') as f: 121 | f.write(seq) 122 | else: 123 | return seq 124 | 125 | elif len(matrix.shape) == 3: # if a matrix containing several strings is supplied 126 | result = [] 127 | for n in range(matrix.shape[0]): 128 | oneresult = [] 129 | for i in range(matrix.shape[1]): 130 | for j in range(matrix.shape[2]): 131 | aa = np.where(matrix[n, i, j] == 1.)[0][0] 132 | oneresult.append(vocab[aa]) 133 | seq = ''.join(oneresult) 134 | result.append(seq) 135 | if filename: 136 | with open(filename, 'wb') as f: 137 | for s in result: 138 | f.write(s + '\n') 139 | else: 140 | return result 141 | 142 | 143 | def _sample_with_temp(preds, temp=1.0): 144 | """ Helper function to sample one letter from a probability array given a temperature. 
145 | 146 | :param preds: {np.array} predictions returned by the network 147 | :param temp: {float} temperature value to sample at. 148 | """ 149 | stretched = np.log(preds) / temp 150 | stretched_probs = np.exp(stretched) / np.sum(np.exp(stretched)) 151 | return np.random.choice(len(stretched), p=stretched_probs) 152 | 153 | 154 | def load_model_instance(filename): 155 | """ Load a whole Model class instance from a given epoch file 156 | 157 | :param filename: epoch file, e.g. model_epoch_5.hdf5 158 | :return: model instance with trained weights 159 | """ 160 | modfile = os.path.dirname(filename) + '/model.p' 161 | mod = pickle.load(open(modfile, 'rb')) 162 | hdf5_file = os.path.splitext(modfile)[0] + '.hdf5' # strip only the .p extension; also safe for relative './' paths 163 | mod.model = load_model(hdf5_file) 164 | return mod 165 | 166 | 167 | def save_model_instance(mod): 168 | """ Save a whole Model instance and the corresponding model with weights to two files (model.p and model.hdf5) 169 | 170 | :param mod: model instance 171 | :return: saved model files in the checkpoint dir 172 | """ 173 | tmp = mod.model 174 | tmp.save(mod.checkpointdir + 'model.hdf5') 175 | mod.model = None 176 | pickle.dump(mod, open(mod.checkpointdir + 'model.p', 'wb')) 177 | mod.model = tmp 178 | 179 | 180 | class SequenceHandler(object): 181 | """ Class for handling peptide sequences, e.g. loading, one-hot encoding or decoding and saving """ 182 | 183 | def __init__(self, window=0, step=2, refs=True): 184 | """ 185 | :param window: {int} window size used for chopping up sequences; if 0, sequences are not chopped into windows 186 | :param step: {int} size of the steps to move the window forward 187 | :param refs: {bool} whether to generate reference sequence sets for analysis 188 | """ 189 | self.sequences = None 190 | self.generated = None 191 | self.ran = None 192 | self.hel = None 193 | self.X = list() 194 | self.y = list() 195 | self.window = window 196 | self.step = step 197 | self.refs = refs 198 | # generate translation dictionary for one-hot encoding 199 | _, self.to_one_hot, self.vocab = _onehotencode('A') 200 | 201 | def load_sequences(self, filename): 202 | """ Method to load peptide sequences from a csv file 203 | 204 | :param filename: {str} filename of the sequence file to be read (``csv``, one sequence per line) 205 | :return: sequences in self.sequences 206 | """ 207 | with open(filename) as f: 208 | self.sequences = [s.strip() for s in f] 209 | self.sequences = random.sample(self.sequences, len(self.sequences)) # shuffle sequences randomly 210 | 211 | def pad_sequences(self, pad_char=' ', padlen=0): 212 | """ Pad all sequences to the longest length (default, padlen=0) or a given length 213 | 214 | :param pad_char: {str} Character to pad sequences with 215 | :param padlen: {int} Custom number of padding characters to append to every sequence (optional), default: 0. If 216 | 0, sequences are padded to the length of the longest sequence in the training set. If a window is used and the 217 | padded sequence is shorter than the window size, it is padded to fit the window. 
218 | """ 219 | if padlen: 220 | padded_seqs = [] 221 | for seq in self.sequences: 222 | if len(seq) < self.window: 223 | padded_seq = seq + pad_char * (self.step + self.window - len(seq)) 224 | else: 225 | padded_seq = seq + pad_char * padlen 226 | padded_seqs.append(padded_seq) 227 | else: 228 | length = max([len(seq) for seq in self.sequences]) 229 | padded_seqs = [] 230 | for seq in self.sequences: 231 | padded_seq = 'j' + seq + pad_char * (length - len(seq)) 232 | padded_seqs.append(padded_seq) 233 | 234 | if pad_char not in self.vocab: 235 | self.vocab += [pad_char] 236 | 237 | self.sequences = padded_seqs # overwrite sequences with padded sequences 238 | 239 | def one_hot_encode(self, target='all'): 240 | """ Chop up loaded sequences into patterns of length ``window`` by moving by stepsize ``step`` and translate 241 | them with a one-hot vector encoding 242 | 243 | :param target: {str} whether all proceeding AA should be learned or just the last one in sequence (`all`, `one`) 244 | :return: one-hot encoded sequence patterns in self.X and corresponding target amino acids in self.y 245 | """ 246 | if self.window == 0: 247 | for s in self.sequences: 248 | self.X.append([self.to_one_hot[char] for char in s[:-self.step]]) 249 | if target == 'all': 250 | self.y.append([self.to_one_hot[char] for char in s[self.step:]]) 251 | elif target == 'one': 252 | self.y.append(s[-self.step:]) 253 | 254 | self.X = np.reshape(self.X, (len(self.X), len(self.sequences[0]) - self.step, len(self.vocab))) 255 | self.y = np.reshape(self.y, (len(self.y), len(self.sequences[0]) - self.step, len(self.vocab))) 256 | 257 | else: 258 | for s in self.sequences: 259 | for i in range(0, len(s) - self.window, self.step): 260 | self.X.append([self.to_one_hot[char] for char in s[i: i + self.window]]) 261 | if target == 'all': 262 | self.y.append([self.to_one_hot[char] for char in s[i + 1: i + self.window + 1]]) 263 | elif target == 'one': 264 | self.y.append(s[-self.step:]) 265 | 266 | self.X = np.reshape(self.X, (len(self.X), self.window, len(self.vocab))) 267 | self.y = np.reshape(self.y, (len(self.y), self.window, len(self.vocab))) 268 | 269 | print("\nData shape:\nX: " + str(self.X.shape) + "\ny: " + str(self.y.shape)) 270 | 271 | def analyze_training(self): 272 | """ Method to analyze the distribution of the training data 273 | 274 | :return: prints out information about the length distribution of the sequences in ``self.sequences`` 275 | """ 276 | d = GlobalDescriptor(self.sequences) 277 | d.length() 278 | print("\nLENGTH DISTRIBUTION OF TRAINING DATA:\n") 279 | print("Number of sequences: \t%i" % len(self.sequences)) 280 | print("Mean sequence length: \t%.1f ± %.1f" % (np.mean(d.descriptor), np.std(d.descriptor))) 281 | print("Median sequence length: \t%i" % np.median(d.descriptor)) 282 | print("Minimal sequence length:\t%i" % np.min(d.descriptor)) 283 | print("Maximal sequence length:\t%i" % np.max(d.descriptor)) 284 | 285 | def analyze_generated(self, num, fname='analysis.txt', plot=False,min_length_seq=3): 286 | """ Method to analyze the generated sequences located in `self.generated`. 287 | 288 | :param num: {int} wanted number of sequences to sample 289 | :param fname: {str} filename to save analysis info to 290 | :param plot: {bool} whether to plot an overview of descriptors 291 | :return: file with analysis info (distances) 292 | """ 293 | with open(fname, 'w') as f: 294 | print("Analyzing...") 295 | f.write("ANALYSIS OF SAMPLED SEQUENCES\n==============================\n\n") 296 | f.write("Nr. 
of duplicates in generated sequences: %i\n" % (len(self.generated) - len(set(self.generated)))) 297 | count = len(set(self.generated) & set(self.sequences)) # get shared entries in both lists 298 | f.write("%.1f percent of generated sequences are present in the training data.\n" % 299 | ((count / len(self.generated)) * 100)) 300 | d = GlobalDescriptor(self.generated) 301 | len1 = len(d.sequences) 302 | d.filter_aa('j') 303 | len2 = len(d.sequences) 304 | d.length() 305 | f.write("\n\nLENGTH DISTRIBUTION OF GENERATED DATA:\n\n") 306 | f.write("Number of sequences too short:\t%i\n" % (num - len1)) 307 | f.write("Number of invalid (with j):\t%i\n" % (len1 - len2)) 308 | f.write("Number of valid unique seqs:\t%i\n" % len2) 309 | f.write("Mean sequence length: \t\t%.1f ± %.1f\n" % (np.mean(d.descriptor), np.std(d.descriptor))) 310 | f.write("Median sequence length: \t\t%i\n" % np.median(d.descriptor)) 311 | f.write("Minimal sequence length: \t\t%i\n" % np.min(d.descriptor)) 312 | f.write("Maximal sequence length: \t\t%i\n" % np.max(d.descriptor)) 313 | 314 | 315 | self.sequences = [s[1:].rstrip() for s in self.sequences] 316 | 317 | self.analyze_training() 318 | 319 | d.sequences = [s for s in d.sequences] 320 | 321 | descriptor = 'pepcats' 322 | #seq_desc = PeptideDescriptor([s[1:].rstrip() for s in self.sequences], descriptor) 323 | seq_desc = PeptideDescriptor(self.sequences, descriptor) 324 | seq_desc.calculate_autocorr(min_length_seq) 325 | 326 | gen_desc = PeptideDescriptor(d.sequences, descriptor) 327 | gen_desc.calculate_autocorr(min_length_seq) 328 | 329 | # random comparison set 330 | self.ran = Random(len(self.generated), np.min(d.descriptor), np.max(d.descriptor)) # generate rand seqs 331 | probas = count_aas(''.join(seq_desc.sequences)).values() # get the aa distribution of training seqs 332 | self.ran.generate_sequences(proba=probas) 333 | ran_desc = PeptideDescriptor(self.ran.sequences, descriptor) 334 | ran_desc.calculate_autocorr(min_length_seq) 335 | 336 | # amphipathic helices comparison set 337 | self.hel = Helices(len(self.generated), np.min(d.descriptor), np.max(d.descriptor)) 338 | self.hel.generate_sequences() 339 | hel_desc = PeptideDescriptor(self.hel.sequences, descriptor) 340 | hel_desc.calculate_autocorr(min_length_seq) 341 | 342 | # distance calculation 343 | f.write("\n\nDISTANCE CALCULATION IN '%s' DESCRIPTOR SPACE\n\n" % descriptor.upper()) 344 | desc_dist = distance.cdist(gen_desc.descriptor, seq_desc.descriptor, metric='euclidean') 345 | f.write("Average euclidean distance of sampled to training data:\t%.3f +/- %.3f\n" % 346 | (np.mean(desc_dist), np.std(desc_dist))) 347 | ran_dist = distance.cdist(ran_desc.descriptor, seq_desc.descriptor, metric='euclidean') 348 | f.write("Average euclidean distance if randomly sampled seqs:\t%.3f +/- %.3f\n" % 349 | (np.mean(ran_dist), np.std(ran_dist))) 350 | hel_dist = distance.cdist(hel_desc.descriptor, seq_desc.descriptor, metric='euclidean') 351 | f.write("Average euclidean distance if amphipathic helical seqs:\t%.3f +/- %.3f\n" % 352 | (np.mean(hel_dist), np.std(hel_dist))) 353 | 354 | # more simple descriptors 355 | g_seq = GlobalDescriptor(seq_desc.sequences) 356 | g_gen = GlobalDescriptor(gen_desc.sequences) 357 | g_ran = GlobalDescriptor(ran_desc.sequences) 358 | g_hel = GlobalDescriptor(hel_desc.sequences) 359 | g_seq.calculate_all() 360 | g_gen.calculate_all() 361 | g_ran.calculate_all() 362 | g_hel.calculate_all() 363 | sclr = StandardScaler() 364 | sclr.fit(g_seq.descriptor) 365 | f.write("\n\nDISTANCE 
CALCULATION FOR SCALED GLOBAL DESCRIPTORS\n\n") 366 | desc_dist = distance.cdist(sclr.transform(g_gen.descriptor), sclr.transform(g_seq.descriptor), 367 | metric='euclidean') 368 | f.write("Average euclidean distance of sampled to training data:\t%.2f +/- %.2f\n" % 369 | (np.mean(desc_dist), np.std(desc_dist))) 370 | ran_dist = distance.cdist(sclr.transform(g_ran.descriptor), sclr.transform(g_seq.descriptor), 371 | metric='euclidean') 372 | f.write("Average euclidean distance if randomly sampled seqs:\t%.2f +/- %.2f\n" % 373 | (np.mean(ran_dist), np.std(ran_dist))) 374 | hel_dist = distance.cdist(sclr.transform(g_hel.descriptor), sclr.transform(g_seq.descriptor), 375 | metric='euclidean') 376 | f.write("Average euclidean distance if amphipathic helical seqs:\t%.2f +/- %.2f\n" % 377 | (np.mean(hel_dist), np.std(hel_dist))) 378 | 379 | # hydrophobic moments 380 | uh_seq = PeptideDescriptor(seq_desc.sequences, 'eisenberg') 381 | uh_seq.calculate_moment() 382 | uh_gen = PeptideDescriptor(gen_desc.sequences, 'eisenberg') 383 | uh_gen.calculate_moment() 384 | uh_ran = PeptideDescriptor(ran_desc.sequences, 'eisenberg') 385 | uh_ran.calculate_moment() 386 | uh_hel = PeptideDescriptor(hel_desc.sequences, 'eisenberg') 387 | uh_hel.calculate_moment() 388 | f.write("\n\nHYDROPHOBIC MOMENTS\n\n") 389 | f.write("Hydrophobic moment of training seqs:\t%.3f +/- %.3f\n" % 390 | (np.mean(uh_seq.descriptor), np.std(uh_seq.descriptor))) 391 | f.write("Hydrophobic moment of sampled seqs:\t\t%.3f +/- %.3f\n" % 392 | (np.mean(uh_gen.descriptor), np.std(uh_gen.descriptor))) 393 | f.write("Hydrophobic moment of random seqs:\t\t%.3f +/- %.3f\n" % 394 | (np.mean(uh_ran.descriptor), np.std(uh_ran.descriptor))) 395 | f.write("Hydrophobic moment of amphipathic seqs:\t%.3f +/- %.3f\n" % 396 | (np.mean(uh_hel.descriptor), np.std(uh_hel.descriptor))) 397 | 398 | if plot: 399 | if self.refs: 400 | a = GlobalAnalysis([uh_seq.sequences, uh_gen.sequences, uh_hel.sequences, uh_ran.sequences], 401 | ['training', 'sampled', 'hel', 'ran']) 402 | else: 403 | a = GlobalAnalysis([uh_seq.sequences, uh_gen.sequences], ['training', 'sampled']) 404 | a.plot_summary(filename=fname[:-4] + '.png') 405 | 406 | def save_generated(self, logdir, filename): 407 | """ Save all sequences in `self.generated` to file 408 | 409 | :param logdir: {str} current log directory (used for comparison sequences) 410 | :param filename: {str} filename to save the sequences to 411 | :return: saved file 412 | """ 413 | with open(filename, 'w') as f: 414 | for s in self.generated: 415 | f.write(s + '\n') 416 | 417 | self.ran.save_fasta(logdir + '/random_sequences.fasta') 418 | self.hel.save_fasta(logdir + '/helical_sequences.fasta') 419 | 420 | 421 | class Model(object): 422 | """ 423 | Class containing the LSTM model to learn sequential data 424 | """ 425 | 426 | def __init__(self, n_vocab, outshape, session_name, cell="LSTM", n_units=256, batch=64, layers=2, lr=0.001, 427 | dropoutfract=0.1, loss='categorical_crossentropy', l2_reg=None, ask=True, seed=42): 428 | """ Initialize the model 429 | 430 | :param n_vocab: {int} length of vocabulary 431 | :param outshape: {int} output dimensionality of the model 432 | :param session_name: {str} custom name for the current session. Will create directory with this name to save 433 | results / logs to. 
434 | :param n_units: {int} number of LSTM units per layer 435 | :param batch: {int} batch size 436 | :param layers: {int} number of layers in the network 437 | :param loss: {str} applied loss function, choose from available keras loss functions 438 | :param lr: {float} learning rate to use with Adam optimizer 439 | :param dropoutfract: {float} fraction of dropout to add to each layer. Layer1 gets 1 * value, Layer2 2 * 440 | value and so on. 441 | :param l2_reg: {float} l2 regularization for kernel 442 | :param seed {int} random seed used to initialize weights 443 | """ 444 | random.seed(seed) 445 | self.seed = seed 446 | self.dropout = dropoutfract 447 | self.inshape = (None, n_vocab) 448 | self.outshape = outshape 449 | self.neurons = n_units 450 | self.layers = layers 451 | self.losses = list() 452 | self.val_losses = list() 453 | self.batchsize = batch 454 | self.lr = lr 455 | self.cv_loss = None 456 | self.cv_loss_std = None 457 | self.cv_val_loss = None 458 | self.cv_val_loss_std = None 459 | self.model = None 460 | self.cell = cell 461 | self.losstype = loss 462 | self.session_name = session_name 463 | self.logdir = './' + session_name 464 | self.l2 = l2_reg 465 | if ask and os.path.exists(self.logdir): 466 | decision = input('\nSession folder already exists!\n' 467 | 'Do you want to overwrite the previous session? [y/n] ') 468 | if decision in ['n', 'no', 'N', 'NO', 'No']: 469 | self.logdir = './' + input('Enter new session name: ') 470 | os.makedirs(self.logdir) 471 | self.checkpointdir = self.logdir + '/checkpoint/' 472 | if not os.path.exists(self.checkpointdir): 473 | os.makedirs(self.checkpointdir) 474 | _, _, self.vocab = _onehotencode('A') 475 | 476 | self.initialize_model(seed=self.seed) 477 | 478 | def initialize_model(self, seed=42): 479 | """ Method to initialize the model with all parameters saved in the attributes. This method is used during 480 | initialization of the class, as well as in cross-validation to reinitialize a fresh model for every fold. 
481 | 482 | :param seed: {int} random seed to use for weight initialization 483 | 484 | :return: initialized model in ``self.model`` 485 | """ 486 | self.losses = list() 487 | self.val_losses = list() 488 | self.cv_loss = None 489 | self.cv_loss_std = None 490 | self.cv_val_loss = None 491 | self.cv_val_loss_std = None 492 | self.model = None 493 | weight_init = RandomNormal(mean=0.0, stddev=0.05, seed=seed) # weights drawn from a normal distribution, mean 0, stddev 0.05 494 | optimizer = Adam(lr=self.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) # 'lr' is the legacy keras alias for learning_rate 495 | 496 | if self.l2: 497 | l2reg = l2(self.l2) 498 | else: 499 | l2reg = None 500 | 501 | self.model = Sequential() 502 | for l in range(self.layers): 503 | if self.cell == "GRU": 504 | self.model.add(GRU(units=self.neurons, 505 | name='GRU%i' % (l + 1), 506 | input_shape=self.inshape, 507 | return_sequences=True, 508 | kernel_initializer=weight_init, 509 | kernel_regularizer=l2reg, 510 | dropout=self.dropout * (l + 1))) 511 | else: 512 | self.model.add(LSTM(units=self.neurons, 513 | name='LSTM%i' % (l + 1), 514 | input_shape=self.inshape, 515 | return_sequences=True, 516 | kernel_initializer=weight_init, 517 | kernel_regularizer=l2reg, 518 | dropout=self.dropout * (l + 1), 519 | recurrent_dropout=self.dropout * (l + 1))) 520 | self.model.add(Dense(self.outshape, 521 | name='Dense', 522 | activation='softmax', 523 | kernel_regularizer=l2reg, # pass the regularizer object (not the raw l2 rate), as in the recurrent layers 524 | kernel_initializer=weight_init)) 525 | self.model.compile(loss=self.losstype, optimizer=optimizer) 526 | with open(self.checkpointdir + "model.json", 'w') as f: 527 | json.dump(self.model.to_json(), f) 528 | self.model.summary() 529 | 530 | def finetuneinit(self, session_name): 531 | """ Method to generate a new directory for finetuning a pre-existing model on a new dataset with a new name 532 | 533 | :param session_name: {str} new session name for finetuning 534 | :return: generates all necessary session folders 535 | """ 536 | self.session_name = session_name 537 | self.logdir = './' + session_name 538 | if os.path.exists(self.logdir): 539 | decision = input('\nSession folder already exists!\n' 540 | 'Do you want to overwrite the previous session? [y/n] ') 541 | if decision in ['n', 'no', 'N', 'NO', 'No']: 542 | self.logdir = './' + input('Enter new session name: ') 543 | os.makedirs(self.logdir) 544 | self.checkpointdir = self.logdir + '/checkpoint/' 545 | if not os.path.exists(self.checkpointdir): 546 | os.makedirs(self.checkpointdir) 547 | 548 | def train(self, x, y, epochs=100, valsplit=0.2, sample=100): 549 | """ Train the model on given training data. 
550 | 551 | :param x: {array} training data 552 | :param y: {array} targets for training data in X 553 | :param epochs: {int} number of epochs to train 554 | :param valsplit: {float} fraction of data that should be used as validation data during training 555 | :param sample: {int} number of sequences to sample after every training epoch 556 | :return: trained model and measured losses in self.model, self.losses and self.val_losses 557 | """ 558 | writer = tf.summary.create_file_writer('./logs/' + self.session_name) 559 | with writer.as_default(): 560 | for e in range(epochs): 561 | print("Epoch %i" % e) 562 | checkpoints = [ModelCheckpoint(filepath=self.checkpointdir + 'model_epoch_%i.hdf5' % e, verbose=0)] 563 | train_history = self.model.fit(x, y, epochs=1, batch_size=self.batchsize, validation_split=valsplit, 564 | shuffle=False, callbacks=checkpoints) 565 | tf.summary.scalar('loss', train_history.history['loss'][-1], step=e) 566 | self.losses.append(train_history.history['loss']) 567 | if valsplit > 0.: 568 | self.val_losses.append(train_history.history['val_loss']) 569 | tf.summary.scalar('val_loss', train_history.history['val_loss'][-1], step=e) 570 | if sample: 571 | for s in self.sample(sample): # sample sequences after every training epoch 572 | print(s) 573 | writer.close() 574 | 575 | def cross_val(self, x, y, epochs=100, cv=5, plot=True): 576 | """ Method to perform cross-validation with the model given data X, y 577 | 578 | :param x: {array} training data 579 | :param y: {array} targets for training data in X 580 | :param epochs: {int} number of epochs to train 581 | :param cv: {int} fold 582 | :param plot: {bool} whether the losses should be plotted and saved to the session folder 583 | :return: 584 | """ 585 | self.losses = list() # clean losses if already present 586 | self.val_losses = list() 587 | kf = KFold(n_splits=cv) 588 | cntr = 0 589 | for train, test in kf.split(x): 590 | print("\nFold %i" % (cntr + 1)) 591 | self.initialize_model(seed=cntr) # reinitialize every fold, otherwise it will "remember" previous data 592 | train_history = self.model.fit(x[train], y[train], epochs=epochs, batch_size=self.batchsize, 593 | validation_data=(x[test], y[test])) 594 | self.losses.append(train_history.history['loss']) 595 | self.val_losses.append(train_history.history['val_loss']) 596 | cntr += 1 597 | self.cv_loss = np.mean(self.losses, axis=0) 598 | self.cv_loss_std = np.std(self.losses, axis=0) 599 | self.cv_val_loss = np.mean(self.val_losses, axis=0) 600 | self.cv_val_loss_std = np.std(self.val_losses, axis=0) 601 | if plot: 602 | self.plot_losses(cv=True) 603 | 604 | # get best epoch with corresponding val_loss 605 | minloss = np.min(self.cv_val_loss) 606 | e = np.where(minloss == self.cv_val_loss)[0][0] 607 | print("\n%i-fold cross-validation result:\n\nBest epoch:\t%i\nVal_loss:\t%.4f" % (cv, e, minloss)) 608 | with open(self.logdir + '/' + self.session_name + '_best_epoch.txt', 'w') as f: 609 | f.write("%i-fold cross-validation result:\n\nBest epoch:\t%i\nVal_loss:\t%.4f" % (cv, e, minloss)) 610 | 611 | def plot_losses(self, show=False, cv=False): 612 | """Plot the losses obtained in training. 613 | 614 | :param show: {bool} Whether the plot should be shown or saved. If ``False``, the plot is saved to the 615 | session folder. 616 | :param cv: {bool} Whether the losses from cross-validation should be plotted. The standard deviation will be 617 | depicted as filled areas around the mean curve. 
618 | :return: plot (saved) or shown interactive 619 | """ 620 | fig, ax = plt.subplots() 621 | ax.set_title('LSTM Categorical Crossentropy Loss Plot', fontweight='bold', fontsize=16) 622 | if cv: 623 | filename = self.logdir + '/' + self.session_name + '_cv_loss_plot.pdf' 624 | x = range(1, len(self.cv_loss) + 1) 625 | ax.plot(x, self.cv_loss, '-', color='#FE4365', label='Training') 626 | ax.plot(x, self.cv_val_loss, '-', color='k', label='Validation') 627 | ax.fill_between(x, self.cv_loss + self.cv_loss_std, self.cv_loss - self.cv_loss_std, 628 | facecolors='#FE4365', alpha=0.5) 629 | ax.fill_between(x, self.cv_val_loss + self.cv_val_loss_std, self.cv_val_loss - self.cv_val_loss_std, 630 | facecolors='k', alpha=0.5) 631 | ax.set_xlim([0.5, len(self.cv_loss) + 0.5]) 632 | minloss = np.min(self.cv_val_loss) 633 | plt.text(x=0.5, y=0.5, s='best epoch: ' + str(np.where(minloss == self.cv_val_loss)[0][0]) + ', val_loss: ' 634 | + str(minloss.round(4)), transform=ax.transAxes) 635 | else: 636 | filename = self.logdir + '/' + self.session_name + '_loss_plot.pdf' 637 | x = range(1, len(self.losses) + 1) 638 | ax.plot(x, self.losses, '-', color='#FE4365', label='Training') 639 | if self.val_losses: 640 | ax.plot(x, self.val_losses, '-', color='k', label='Validation') 641 | ax.set_xlim([0.5, len(self.losses) + 0.5]) 642 | ax.set_ylabel('Loss', fontweight='bold', fontsize=14) 643 | ax.set_xlabel('Epoch', fontweight='bold', fontsize=14) 644 | ax.spines['right'].set_visible(False) 645 | ax.spines['top'].set_visible(False) 646 | ax.xaxis.set_ticks_position('bottom') 647 | ax.yaxis.set_ticks_position('left') 648 | plt.legend(loc='best') 649 | if show: 650 | plt.show() 651 | else: 652 | plt.savefig(filename) 653 | 654 | def sample(self, num=100, minlen=7, maxlen=50, start=None, temp=2.5, show=True): 655 | """Invoke generation of sequence patterns through sampling from the trained model. 656 | 657 | :param num: {int} number of sequences to sample 658 | :param minlen {int} minimal allowed sequence length 659 | :param maxlen: {int} maximal length of each pattern generated, if 0, a random length is chosen between 7 and 50 660 | :param start: {str} start AA to be used for sampling. If ``None``, a random AA is chosen 661 | :param temp: {float} temperature value to sample at. 
662 | :param show: {bool} whether the sampled sequences should be printed out 663 | :return: {array} matrix of patterns of shape (num, seqlen, inputshape[0]) 664 | """ 665 | print("\nSampling...\n") 666 | sampled = [] 667 | lcntr = 0 668 | pbar = ProgressBar() 669 | for rs in pbar(range(num)): 670 | random.seed(rs) 671 | if not maxlen: # if the length should be randomly sampled 672 | longest = np.random.randint(7, 50) 673 | else: 674 | longest = maxlen 675 | 676 | if start: 677 | start_aa = start 678 | else: # generate random starting letter 679 | start_aa = 'j' 680 | sequence = start_aa # start with starting letter 681 | 682 | while sequence[-1] != ' ' and len(sequence) <= longest: # sample until padding or maxlen is reached 683 | x, _, _ = _onehotencode(sequence) 684 | preds = self.model.predict(x)[0][-1] 685 | next_aa = _sample_with_temp(preds, temp=temp) 686 | sequence += self.vocab[next_aa] 687 | 688 | if start_aa == 'j': 689 | sequence = sequence[1:].rstrip() 690 | else: # keep starting AA if chosen for sampling 691 | sequence = sequence.rstrip() 692 | 693 | #print(sequence,len(sequence)) 694 | 695 | if len(sequence) < minlen: # don't take sequences shorter than the minimal length 696 | lcntr += 1 697 | continue 698 | 699 | sampled.append(sequence) 700 | if show: 701 | print(sequence) 702 | 703 | print("\t%i sequences were shorter than %i" % (lcntr, minlen)) 704 | return sampled 705 | 706 | def load_model(self, filename): 707 | """Method to load a trained model from a hdf5 file 708 | 709 | :return: model loaded from file in ``self.model`` 710 | """ 711 | self.model.load_weights(filename) 712 | 713 | # def get_num_params(self): 714 | # """Method to get the amount of trainable parameters in the model. 715 | # """ 716 | # trainable = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) 717 | # non_trainable = np.sum([np.prod(v.get_shape().as_list()) for v in tf.non_trainable_variables()]) 718 | # print('\nMODEL PARAMETERS') 719 | # print('Total parameters: %i' % (trainable + non_trainable)) 720 | # print('Trainable parameters: %i' % trainable) 721 | # print('Non-trainable parameters: %i' % non_trainable) 722 | 723 | 724 | def main(infile, sessname, neurons=64, layers=2, epochs=100, batchsize=128, window=0, step=1, target='all', 725 | valsplit=0.2, sample=100, aa='j', temperature=2.5, cell="LSTM", dropout=0.1, train=False, learningrate=0.01, 726 | modfile=None, samplelength=36, pad=0, l2_rate=None, cv=None, finetune=True, references=True): 727 | # loading sequence data, analyze, pad and encode it 728 | data = SequenceHandler(window=window, step=step, refs=references) 729 | print("Loading sequences...") 730 | data.load_sequences(infile) 731 | data.analyze_training() 732 | 733 | # pad sequences 734 | print("\nPadding sequences...") 735 | data.pad_sequences(padlen=pad) 736 | 737 | # one-hot encode padded sequences 738 | print("One-hot encoding sequences...") 739 | data.one_hot_encode(target=target) 740 | 741 | if train: 742 | # building the LSTM model 743 | print("\nBuilding model...") 744 | model = Model(n_vocab=len(data.vocab), outshape=len(data.vocab), session_name=sessname, n_units=neurons, 745 | batch=batchsize, layers=layers, cell=cell, loss='categorical_crossentropy', lr=learningrate, 746 | dropoutfract=dropout, l2_reg=l2_rate, ask=True, seed=42) 747 | print("Model built!") 748 | 749 | if cv: 750 | print("\nPERFORMING %i-FOLD CROSS-VALIDATION...\n" % cv) 751 | model.cross_val(data.X, data.y, epochs=epochs, cv=cv) 752 | model.initialize_model(seed=42) 753 | 
model.train(data.X, data.y, epochs=epochs, valsplit=0.0, sample=0) 754 | model.plot_losses() 755 | else: 756 | # training model on data 757 | print("\nTRAINING MODEL FOR %i EPOCHS...\n" % epochs) 758 | model.train(data.X, data.y, epochs=epochs, valsplit=valsplit, sample=0) 759 | model.plot_losses() # plot loss 760 | 761 | save_model_instance(model) 762 | 763 | elif finetune: 764 | print("\nUSING PRETRAINED MODEL FOR FINETUNING... (%s)\n" % modfile) 765 | print("Loading model...") 766 | model = load_model_instance(modfile) 767 | model.load_model(modfile) 768 | model.finetuneinit(sessname) # generate new session folders for finetuning run 769 | print("Finetuning model...") 770 | model.train(data.X, data.y, epochs=epochs, valsplit=valsplit, sample=0) 771 | model.plot_losses() # plot loss 772 | save_model_instance(model) 773 | else: 774 | print("\nUSING PRETRAINED MODEL... (%s)\n" % modfile) 775 | model = load_model_instance(modfile) 776 | model.load_model(modfile) 777 | 778 | print(model.model.summary()) # print number of parameters in the model 779 | 780 | # generating new data through sampling 781 | print("\nSAMPLING %i SEQUENCES...\n" % sample) 782 | data.generated = model.sample(sample, start=aa, maxlen=samplelength, show=True, temp=temperature) 783 | data.analyze_generated(sample, fname=model.logdir + '/analysis_temp' + str(temperature) + '.txt', plot=True) 784 | data.save_generated(model.logdir, model.logdir + '/sampled_sequences_temp' + str(temperature) + '.csv') 785 | 786 | 787 | if __name__ == "__main__": 788 | # run main code 789 | main(infile=args.dataset, sessname=args.name, batchsize=args.batch_size, epochs=args.epochs, 790 | layers=args.layers, valsplit=args.valsplit, neurons=args.neurons, cell=args.cell, sample=args.sample, 791 | temperature=args.temp, dropout=args.dropout, train=False, modfile=args.modfile, 792 | learningrate=args.lr, cv=args.cv, samplelength=args.maxlen, window=args.window, 793 | step=args.step, aa=args.startchar, l2_rate=args.l2, target=args.target, pad=args.padlen, 794 | finetune=False, references=args.refs) 795 | 796 | # save used flags to log file 797 | _save_flags("./" + args.name + "/flags.txt") 798 | --------------------------------------------------------------------------------
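Note: the snippet below is a minimal standalone sketch (not part of the repository) of the temperature re-scaling performed by _sample_with_temp in LSTM_peptides.py; the probability vector used here is made up purely for illustration.

# temperature_demo.py -- illustrative only, not a repository file.
# Mirrors the math of _sample_with_temp(): log-scale the network's softmax
# output, divide by the temperature, and re-normalize before sampling.
import numpy as np

def sample_with_temp(preds, temp=1.0):
    # same computation as _sample_with_temp in LSTM_peptides.py
    stretched = np.log(preds) / temp
    probs = np.exp(stretched) / np.sum(np.exp(stretched))
    return np.random.choice(len(preds), p=probs)

if __name__ == "__main__":
    preds = np.array([0.6, 0.3, 0.1])  # made-up next-residue probabilities
    for t in (0.5, 1.0, 1.25, 2.0):
        stretched = np.log(preds) / t
        probs = np.exp(stretched) / np.sum(np.exp(stretched))
        print("T=%.2f -> %s (sampled index %i)"
              % (t, np.round(probs, 3), sample_with_temp(preds, t)))
    # T < 1 sharpens the distribution toward the most likely residue,
    # T > 1 flattens it and yields more diverse peptides, which is why
    # sample.py keeps a list of temperatures to sweep over.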