├── pssp-data
│   ├── seqlen_test.png
│   ├── seqlen_train.png
│   ├── amino_acid_test.png
│   ├── amino_acid_train.png
│   ├── secondary_structure_test.png
│   └── secondary_structure_train.png
├── pssp-transformer
│   ├── transformer
│   │   ├── Constants.py
│   │   ├── __init__.py
│   │   ├── Modules.py
│   │   ├── Optim.py
│   │   ├── Layers.py
│   │   ├── SubLayers.py
│   │   ├── Beam.py
│   │   ├── Translator.py
│   │   └── Models.py
│   ├── utils.py
│   ├── translate.py
│   ├── dataset.py
│   ├── download_dataset.py
│   ├── preprocess.py
│   └── main.py
├── pssp-nn
│   ├── model.py
│   ├── make_dataset.py
│   ├── load_dataset.py
│   ├── utils.py
│   └── main.py
├── .gitignore
└── README.md

/pssp-data/seqlen_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takatex/protein-secondary-structure-prediction/HEAD/pssp-data/seqlen_test.png
--------------------------------------------------------------------------------
/pssp-data/seqlen_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takatex/protein-secondary-structure-prediction/HEAD/pssp-data/seqlen_train.png
--------------------------------------------------------------------------------
/pssp-data/amino_acid_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takatex/protein-secondary-structure-prediction/HEAD/pssp-data/amino_acid_test.png
--------------------------------------------------------------------------------
/pssp-data/amino_acid_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takatex/protein-secondary-structure-prediction/HEAD/pssp-data/amino_acid_train.png
--------------------------------------------------------------------------------
/pssp-data/secondary_structure_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takatex/protein-secondary-structure-prediction/HEAD/pssp-data/secondary_structure_test.png
--------------------------------------------------------------------------------
/pssp-data/secondary_structure_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takatex/protein-secondary-structure-prediction/HEAD/pssp-data/secondary_structure_train.png
--------------------------------------------------------------------------------
/pssp-transformer/transformer/Constants.py:
--------------------------------------------------------------------------------
1 | 
2 | PAD = 0
3 | UNK = 1
4 | BOS = 2
5 | EOS = 3
6 | 
7 | PAD_WORD = '<blank>'
8 | UNK_WORD = '<unk>'
9 | BOS_WORD = '<s>'
10 | EOS_WORD = '</s>'
--------------------------------------------------------------------------------
/pssp-transformer/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | import transformer.Constants
2 | import transformer.Modules
3 | import transformer.Layers
4 | import transformer.SubLayers
5 | import transformer.Models
6 | import transformer.Translator
7 | import transformer.Beam
8 | import transformer.Optim
9 | 
10 | __all__ = [
11 |     transformer.Constants, transformer.Modules, transformer.Layers,
12 |     transformer.SubLayers, transformer.Models, transformer.Optim,
13 |     transformer.Translator, transformer.Beam]
14 | 
--------------------------------------------------------------------------------
/pssp-transformer/transformer/Modules.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' Scaled Dot-Product Attention ''' 9 | 10 | def __init__(self, temperature, attn_dropout=0.1): 11 | super().__init__() 12 | self.temperature = temperature 13 | self.dropout = nn.Dropout(attn_dropout) 14 | self.softmax = nn.Softmax(dim=2) 15 | 16 | def forward(self, q, k, v, mask=None): 17 | 18 | attn = torch.bmm(q, k.transpose(1, 2)) 19 | attn = attn / self.temperature 20 | 21 | if mask is not None: 22 | attn = attn.masked_fill(mask, -np.inf) 23 | 24 | attn = self.softmax(attn) 25 | attn = self.dropout(attn) 26 | output = torch.bmm(attn, v) 27 | 28 | return output, attn 29 | -------------------------------------------------------------------------------- /pssp-transformer/transformer/Optim.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | import numpy as np 3 | 4 | class ScheduledOptim(): 5 | '''A simple wrapper class for learning rate scheduling''' 6 | 7 | def __init__(self, optimizer, d_model, n_warmup_steps): 8 | self._optimizer = optimizer 9 | self.n_warmup_steps = n_warmup_steps 10 | self.n_current_steps = 0 11 | self.init_lr = np.power(d_model, -0.5) 12 | 13 | def step_and_update_lr(self): 14 | "Step with the inner optimizer" 15 | self._update_learning_rate() 16 | self._optimizer.step() 17 | 18 | def zero_grad(self): 19 | "Zero out the gradients by the inner optimizer" 20 | self._optimizer.zero_grad() 21 | 22 | def _get_lr_scale(self): 23 | return np.min([ 24 | np.power(self.n_current_steps, -0.5), 25 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 26 | 27 | def _update_learning_rate(self): 28 | ''' Learning rate scheduling per step ''' 29 | 30 | self.n_current_steps += 1 31 | lr = self.init_lr * self._get_lr_scale() 32 | 33 | for param_group in self._optimizer.param_groups: 34 | param_group['lr'] = lr 35 | 36 | -------------------------------------------------------------------------------- /pssp-nn/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | 7 | class Net(nn.Module): 8 | def __init__(self): 9 | super(Net, self).__init__() 10 | 11 | # Conv1d(in_channels, out_channels, kernel_size, stride, padding) 12 | conv_hidden_size = 64 13 | self.conv1 = nn.Sequential( 14 | nn.Conv1d(42, conv_hidden_size, 3, 1, 3 // 2), 15 | nn.ReLU()) 16 | 17 | self.conv2 = nn.Sequential( 18 | nn.Conv1d(42, conv_hidden_size, 7, 1, 7 // 2), 19 | nn.ReLU()) 20 | 21 | self.conv3 = nn.Sequential( 22 | nn.Conv1d(42, conv_hidden_size, 11, 1, 11 // 2), 23 | nn.ReLU()) 24 | 25 | # LSTM(input_size, hidden_size, num_layers, bias, 26 | # batch_first, dropout, bidirectional) 27 | rnn_hidden_size = 256 28 | self.brnn = nn.GRU(conv_hidden_size*3, rnn_hidden_size, 3, True, True, 0.5, True) 29 | 30 | self.fc = nn.Sequential( 31 | nn.Linear(rnn_hidden_size*2+conv_hidden_size*3, 126), 32 | nn.ReLU(), 33 | nn.Linear(126, 8), 34 | nn.ReLU()) 35 | 36 | def forward(self, x): 37 | # obtain multiple local contextual feature map 38 | conv_out = torch.cat([self.conv1(x), self.conv2(x), self.conv3(x)], dim=1) 39 | 40 | # Turn (batch_size x hidden_size x seq_len) 41 | # into (batch_size x seq_len x hidden_size) 
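        # (self.brnn below is an nn.GRU constructed with batch_first=True, so it
        # expects input of shape (batch_size, seq_len, features) -- hence the
        # transpose on the next line)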
42 | conv_out = conv_out.transpose(1, 2) 43 | 44 | # bidirectional rnn 45 | out, _ = self.brnn(conv_out) 46 | 47 | out = torch.cat([conv_out, out], dim=2) 48 | # print(out.sum()) 49 | 50 | # Output shape is (batch_size x seq_len x classnum) 51 | out = self.fc(out) 52 | out = F.softmax(out, dim=2) 53 | return out 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # optianal 2 | .DS_Store 3 | ._.DS_Store 4 | 5 | notebook/ 6 | result/ 7 | *.pkl 8 | *.gz 9 | *.npz 10 | *.txt 11 | *.pt 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # protein-secondary-structure-prediction 2 | 3 | PyTorch implementations of protein secondary structure prediction on CB513. 4 | 5 | I implemented them based on https://github.com/alrojo/CB513 and https://github.com/jadore801120/attention-is-all-you-need-pytorch. 6 | 7 | 8 | # Dataset 9 | I used CB513 dataset of https://github.com/alrojo/CB513. 
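The raw `.npy.gz` files are downloaded and converted to compressed `.npz` archives by `pssp-nn/make_dataset.py` (this happens automatically on the first run of `pssp-nn/main.py`). A minimal sketch of inspecting the processed arrays, assuming the default `../pssp-data` output paths:

```
import numpy as np

# produced by pssp-nn/make_dataset.py
data = np.load('../pssp-data/train.npz')
X, y, seq_len = data['X'], data['y'], data['seq_len']

# X:       (n_proteins, 42, 700) -- 21 amino-acid one-hot channels + 21 profile channels
# y:       (n_proteins, 700)     -- 8-state secondary-structure label per residue
# seq_len: (n_proteins,)         -- true length of each zero-padded sequence
print(X.shape, y.shape, seq_len.shape)
```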
10 | 11 | |sequence length (train)|sequence length (test)| 12 | |:-:|:-:| 13 | |![](https://github.com/takatex/protein-secondary-structure-prediction/blob/master/pssp-data/seqlen_train.png)|![](https://github.com/takatex/protein-secondary-structure-prediction/blob/master/pssp-data/seqlen_test.png)| 14 | 15 | |amino acid (train)|amino acid (test)| 16 | |:-:|:-:| 17 | |![](https://github.com/takatex/protein-secondary-structure-prediction/blob/master/pssp-data/amino_acid_train.png)|![](https://github.com/takatex/protein-secondary-structure-prediction/blob/master/pssp-data/amino_acid_test.png)| 18 | 19 | |secondary structure label(train)|secondary structure label (test)| 20 | |:-:|:-:| 21 | |![](https://github.com/takatex/protein-secondary-structure-prediction/blob/master/pssp-data/secondary_structure_train.png)|![](https://github.com/takatex/protein-secondary-structure-prediction/blob/master/pssp-data/secondary_structure_test.png)| 22 | 23 | 24 | # Usage 25 | You can get more infomations by adding `-h` option. 26 | 27 | ## pssp-nn 28 | ``` 29 | python main.py 30 | ``` 31 | 32 | ## pssp-transformer 33 | ``` 34 | python preprocess.py 35 | python main.py 36 | ``` 37 | 38 | 39 | # Acknowledgement 40 | - https://github.com/alrojo/CB513 41 | - https://github.com/jadore801120/attention-is-all-you-need-pytorch 42 | - [Li, Zhen; Yu, Yizhou, Protein Secondary Structure Prediction Using Cascaded Convolutional and Recurrent Neural Networks, 2016.](https://arxiv.org/pdf/1604.07176.pdf) 43 | - [Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, Attention Is All You Need](https://arxiv.org/abs/1706.03762) 44 | -------------------------------------------------------------------------------- /pssp-transformer/transformer/Layers.py: -------------------------------------------------------------------------------- 1 | ''' Define the Layers ''' 2 | import torch.nn as nn 3 | from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | ''' Compose with two layers ''' 10 | 11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 12 | super(EncoderLayer, self).__init__() 13 | self.slf_attn = MultiHeadAttention( 14 | n_head, d_model, d_k, d_v, dropout=dropout) 15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 16 | 17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 18 | enc_output, enc_slf_attn = self.slf_attn( 19 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 20 | enc_output *= non_pad_mask 21 | 22 | enc_output = self.pos_ffn(enc_output) 23 | enc_output *= non_pad_mask 24 | 25 | return enc_output, enc_slf_attn 26 | 27 | 28 | class DecoderLayer(nn.Module): 29 | ''' Compose with three layers ''' 30 | 31 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 32 | super(DecoderLayer, self).__init__() 33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 36 | 37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): 38 | dec_output, dec_slf_attn = self.slf_attn( 39 | dec_input, dec_input, dec_input, mask=slf_attn_mask) 40 | dec_output *= non_pad_mask 41 | 42 | dec_output, dec_enc_attn = self.enc_attn( 43 | dec_output, 
enc_output, enc_output, mask=dec_enc_attn_mask) 44 | dec_output *= non_pad_mask 45 | 46 | dec_output = self.pos_ffn(dec_output) 47 | dec_output *= non_pad_mask 48 | 49 | return dec_output, dec_slf_attn, dec_enc_attn 50 | -------------------------------------------------------------------------------- /pssp-nn/make_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Original Code : https://github.com/alrojo/CB513/blob/master/data.py 3 | 4 | import os 5 | import numpy as np 6 | import subprocess 7 | from utils import load_gz 8 | 9 | 10 | TRAIN_PATH = '../pssp-data/cullpdb+profile_6133_filtered.npy.gz' 11 | TEST_PATH = '../pssp-data/cb513+profile_split1.npy.gz' 12 | TRAIN_DATASET_PATH = '../pssp-data/train.npz' 13 | TEST_DATASET_PATH = '../pssp-data/test.npz' 14 | TRAIN_URL = "http://www.princeton.edu/~jzthree/datasets/ICML2014/cullpdb+profile_6133_filtered.npy.gz" 15 | TEST_URL = "http://www.princeton.edu/~jzthree/datasets/ICML2014/cb513+profile_split1.npy.gz" 16 | 17 | 18 | def download_dataset(): 19 | print('[Info] Downloading CB513 dataset ...') 20 | if not (os.path.isfile(TRAIN_PATH) and os.path.isfile(TEST_PATH)): 21 | os.makedirs('../pssp-data', exist_ok=True) 22 | os.system(f'wget -O {TRAIN_PATH} {TRAIN_URL}') 23 | os.system(f'wget -O {TEST_PATH} {TEST_URL}') 24 | 25 | 26 | def make_datasets(): 27 | print('[Info] Making datasets ...') 28 | 29 | # train dataset 30 | X_train, y_train, seq_len_train = make_dataset(TRAIN_PATH) 31 | np.savez_compressed(TRAIN_DATASET_PATH, X=X_train, y=y_train, seq_len=seq_len_train) 32 | print(f'[Info] Saved train dataset in {TRAIN_DATASET_PATH}') 33 | 34 | # test dataset 35 | X_test, y_test, seq_len_test = make_dataset(TEST_PATH) 36 | np.savez_compressed(TEST_DATASET_PATH, X=X_test, y=y_test, seq_len=seq_len_test) 37 | print(f'[Info] Saved test dataset in {TEST_DATASET_PATH}') 38 | 39 | 40 | def make_dataset(path): 41 | data = load_gz(path) 42 | data = data.reshape(-1, 700, 57) 43 | 44 | idx = np.append(np.arange(21), np.arange(35, 56)) 45 | X = data[:, :, idx] 46 | X = X.transpose(0, 2, 1) 47 | X = X.astype('float32') 48 | 49 | y = data[:, :, 22:30] 50 | y = np.array([np.dot(yi, np.arange(8)) for yi in y]) 51 | y = y.astype('float32') 52 | 53 | mask = data[:, :, 30] * -1 + 1 54 | seq_len = mask.sum(axis=1) 55 | seq_len = seq_len.astype('float32') 56 | 57 | return X, y, seq_len 58 | -------------------------------------------------------------------------------- /pssp-nn/load_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | from make_dataset import download_dataset, make_datasets 6 | 7 | 8 | class MyDataset(Dataset): 9 | 10 | def __init__(self, X, y, seq_len): 11 | self.X = X 12 | self.y = y.astype(int) 13 | self.seq_len = seq_len.astype(int) 14 | 15 | def __len__(self): 16 | return len(self.y) 17 | 18 | def __getitem__(self, idx): 19 | x = self.X[idx] 20 | y = self.y[idx] 21 | seq_len = self.seq_len[idx] 22 | return x, y, seq_len 23 | 24 | 25 | class LoadDataset(object): 26 | 27 | def __init__(self, data_dir, batch_size_train, batch_size_test): 28 | self.data_dir = data_dir 29 | self.train_path = os.path.join(data_dir, 'train.npz') 30 | self.test_path = os.path.join(data_dir, 'test.npz') 31 | self.batch_size_train = batch_size_train 32 | self.batch_size_test = batch_size_test 33 | 34 | def load_dataset(self): 35 | if 
not(os.path.isfile(self.train_path) and os.path.isfile(self.test_path)): 36 | download_dataset() 37 | make_datasets() 38 | 39 | # train dataset 40 | train_data = np.load(self.train_path) 41 | X_train, y_train, seq_len_train = train_data['X'], train_data['y'], train_data['seq_len'] 42 | 43 | # test dataset 44 | test_data = np.load(self.test_path) 45 | X_test, y_test, seq_len_test = test_data['X'], test_data['y'], test_data['seq_len'] 46 | 47 | return X_train, y_train, seq_len_train, X_test, y_test, seq_len_test 48 | 49 | def __call__(self): 50 | X_train, y_train, seq_len_train, X_test, y_test, seq_len_test = \ 51 | self.load_dataset() 52 | 53 | D_train = MyDataset(X_train, y_train, seq_len_train) 54 | train_loader = DataLoader(D_train, batch_size=self.batch_size_train, shuffle=True) 55 | 56 | D_test = MyDataset(X_test, y_test, seq_len_test) 57 | test_loader = DataLoader(D_test, batch_size=self.batch_size_test, shuffle=False) 58 | 59 | return train_loader, test_loader 60 | -------------------------------------------------------------------------------- /pssp-transformer/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import time 5 | import json 6 | import pickle 7 | import numpy as np 8 | import gzip 9 | import collections 10 | 11 | import torch 12 | from torch import nn 13 | 14 | 15 | def save_text(data, save_path): 16 | with open(save_path, mode='w') as f: 17 | f.write('\n'.join(data)) 18 | 19 | 20 | def save_picke(data, save_path): 21 | with open(save_path, mode="wb") as f: 22 | pickle.dump(data, f) 23 | 24 | 25 | def args2json(data, path, print_args=True): 26 | data = vars(data) 27 | if print_args: 28 | print(f'\n+ ---------------------------') 29 | for k, v in data.items(): 30 | print(f' {k.upper()} : {v}') 31 | print(f'+ ---------------------------\n') 32 | 33 | with open(os.path.join(path, 'args.json'), 'w') as f: 34 | json.dump(data, f) 35 | 36 | 37 | 38 | def amino_count(t): 39 | c = collections.Counter(t) 40 | keys, values = c.keys(), c.values() 41 | return list(keys), list(values) 42 | 43 | 44 | def acid_accuracy(out, target, seq_len): 45 | out = out.cpu().data.numpy() 46 | target = target.cpu().data.numpy() 47 | seq_len = seq_len.cpu().data.numpy() 48 | 49 | out = out.argmax(axis=2) 50 | 51 | count_1 = np.zeros(8) 52 | count_2 = np.zeros(8) 53 | for o, t, l in zip(out, target, seq_len): 54 | o, t = o[:l], t[:l] 55 | 56 | # org 57 | keys, values = amino_count(t) 58 | count_1[keys] += values 59 | 60 | # pred 61 | keys, values = amino_count(t[np.equal(o, t)]) 62 | count_2[keys] += values 63 | 64 | return np.divide(count_2, count_1, out=np.zeros(8), where=count_1!=0) 65 | 66 | 67 | def load_gz(path): # load a .npy.gz file 68 | if path.endswith(".gz"): 69 | f = gzip.open(path, 'rb') 70 | return np.load(f) 71 | else: 72 | return np.load(path) 73 | 74 | def timestamp(): 75 | return time.strftime("%Y%m%d%H%M", time.localtime()) 76 | 77 | def show_progress(e, e_total, train_loss, test_loss, acc): 78 | print(f'[{e:3d}/{e_total:3d}] train_loss:{train_loss:.2f}, '\ 79 | f'test_loss:{test_loss:.2f}, acc:{acc:.3f}') 80 | 81 | 82 | def save_history(history, save_dir): 83 | save_path = os.path.join(save_dir, 'history.npy') 84 | np.save(save_path, history) 85 | 86 | 87 | def save_model(model, save_dir): 88 | save_path = os.path.join(save_dir, 'model.pth') 89 | torch.save(model.state_dict(), save_path) 90 | -------------------------------------------------------------------------------- 
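The `acid_accuracy` helper above reports per-class recall over the 8 secondary-structure labels, counting only the first `seq_len` positions of each padded sequence. A toy sketch of how it behaves (hypothetical inputs, run from inside `pssp-transformer/`):

```
import torch
from utils import acid_accuracy  # the helper defined above

# 2 sequences, max length 4, 8 classes; random logits just for illustration
out = torch.randn(2, 4, 8)             # (batch_size, seq_len, class_num)
target = torch.tensor([[0, 1, 2, 0],
                       [3, 3, 0, 0]])  # (batch_size, seq_len)
seq_len = torch.tensor([4, 2])         # positions beyond these are padding

per_class = acid_accuracy(out, target, seq_len)
print(per_class)  # length-8 array: per-label fraction of residues predicted correctly
```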
/pssp-transformer/translate.py: -------------------------------------------------------------------------------- 1 | ''' Translate input text with trained model. ''' 2 | 3 | import torch 4 | import torch.utils.data 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | from dataset import collate_fn, TranslationDataset 9 | from transformer.Translator import Translator 10 | from preprocess import read_instances_from_file, convert_instance_to_idx_seq 11 | 12 | def main(): 13 | '''Main Function''' 14 | 15 | parser = argparse.ArgumentParser(description='translate.py') 16 | 17 | parser.add_argument('-model', required=True, 18 | help='Path to model .pt file') 19 | parser.add_argument('-src', required=True, 20 | help='Source sequence to decode (one line per sequence)') 21 | parser.add_argument('-vocab', required=True, 22 | help='Source sequence to decode (one line per sequence)') 23 | parser.add_argument('-output', default='pred.txt', 24 | help="""Path to output the predictions (each line will 25 | be the decoded sequence""") 26 | parser.add_argument('-beam_size', type=int, default=5, 27 | help='Beam size') 28 | parser.add_argument('-batch_size', type=int, default=30, 29 | help='Batch size') 30 | parser.add_argument('-n_best', type=int, default=1, 31 | help="""If verbose is set, will output the n_best 32 | decoded sentences""") 33 | parser.add_argument('-no_cuda', action='store_true') 34 | 35 | opt = parser.parse_args() 36 | opt.cuda = not opt.no_cuda 37 | 38 | # Prepare DataLoader 39 | preprocess_data = torch.load(opt.vocab) 40 | preprocess_settings = preprocess_data['settings'] 41 | test_src_word_insts = read_instances_from_file( 42 | opt.src, 43 | preprocess_settings.max_word_seq_len, 44 | preprocess_settings.keep_case) 45 | test_src_insts = convert_instance_to_idx_seq( 46 | test_src_word_insts, preprocess_data['dict']['src']) 47 | 48 | test_loader = torch.utils.data.DataLoader( 49 | TranslationDataset( 50 | src_word2idx=preprocess_data['dict']['src'], 51 | tgt_word2idx=preprocess_data['dict']['tgt'], 52 | src_insts=test_src_insts), 53 | num_workers=2, 54 | batch_size=opt.batch_size, 55 | collate_fn=collate_fn) 56 | 57 | translator = Translator(opt) 58 | 59 | with open(opt.output, 'w') as f: 60 | for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)', leave=False): 61 | all_hyp, all_scores = translator.translate_batch(*batch) 62 | for idx_seqs in all_hyp: 63 | for idx_seq in idx_seqs: 64 | pred_line = ' '.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq]) 65 | f.write(pred_line + '\n') 66 | print('[Info] Finished.') 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /pssp-transformer/transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | ''' Define the sublayers in encoder/decoder layer ''' 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformer.Modules import ScaledDotProductAttention 6 | 7 | __author__ = "Yu-Hsiang Huang" 8 | 9 | class MultiHeadAttention(nn.Module): 10 | ''' Multi-Head Attention module ''' 11 | 12 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 13 | super().__init__() 14 | 15 | self.n_head = n_head 16 | self.d_k = d_k 17 | self.d_v = d_v 18 | 19 | self.w_qs = nn.Linear(d_model, n_head * d_k) 20 | self.w_ks = nn.Linear(d_model, n_head * d_k) 21 | self.w_vs = nn.Linear(d_model, n_head * d_v) 22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / 
(d_model + d_k))) 23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 25 | 26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 27 | self.layer_norm = nn.LayerNorm(d_model) 28 | 29 | self.fc = nn.Linear(n_head * d_v, d_model) 30 | nn.init.xavier_normal_(self.fc.weight) 31 | 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | 35 | def forward(self, q, k, v, mask=None): 36 | 37 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 38 | 39 | sz_b, len_q, _ = q.size() 40 | sz_b, len_k, _ = k.size() 41 | sz_b, len_v, _ = v.size() 42 | 43 | residual = q 44 | 45 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 46 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 47 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 48 | 49 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 50 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 51 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 52 | 53 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 54 | output, attn = self.attention(q, k, v, mask=mask) 55 | 56 | output = output.view(n_head, sz_b, len_q, d_v) 57 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 58 | 59 | output = self.dropout(self.fc(output)) 60 | output = self.layer_norm(output + residual) 61 | 62 | return output, attn 63 | 64 | class PositionwiseFeedForward(nn.Module): 65 | ''' A two-feed-forward-layer module ''' 66 | 67 | def __init__(self, d_in, d_hid, dropout=0.1): 68 | super().__init__() 69 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 70 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 71 | self.layer_norm = nn.LayerNorm(d_in) 72 | self.dropout = nn.Dropout(dropout) 73 | 74 | def forward(self, x): 75 | residual = x 76 | output = x.transpose(1, 2) 77 | output = self.w_2(F.relu(self.w_1(output))) 78 | output = output.transpose(1, 2) 79 | output = self.dropout(output) 80 | output = self.layer_norm(output + residual) 81 | return output 82 | -------------------------------------------------------------------------------- /pssp-nn/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import time 5 | import json 6 | import numpy as np 7 | import gzip 8 | import collections 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class CrossEntropy(object): 14 | 15 | def __init__(self): 16 | pass 17 | 18 | def __call__(self, out, target, seq_len): 19 | loss = 0 20 | for o, t, l in zip(out, target, seq_len): 21 | loss += nn.CrossEntropyLoss()(o[:l], t[:l]) 22 | return loss 23 | 24 | 25 | # class LossFunc(object): 26 | # 27 | # def __init__(self): 28 | # self.loss = nn.CrossEntropyLoss() 29 | # 30 | # def __call__(self, out, target, seq_len): 31 | # """ 32 | # out.shape : (batch_size, class_num, seq_len) 33 | # target.shape : (batch_size, seq_len) 34 | # """ 35 | # out = torch.clamp(out, 1e-15, 1 - 1e-15) 36 | # return torch.tensor([self.loss(o[:l], t[:l]) 37 | # for o, t, l in zip(out, target, seq_len)], 38 | # requires_grad=True).sum() 39 | 40 | 41 | def args2json(data, path, print_args=True): 42 | data = vars(data) 43 | if print_args: 44 | print(f'\n+ ---------------------------') 45 | for k, v in data.items(): 46 | print(f' {k.upper()} : {v}') 47 | print(f'+ ---------------------------\n') 48 | 49 | with open(os.path.join(path, 
'args.json'), 'w') as f: 50 | json.dump(data, f) 51 | 52 | 53 | def accuracy(out, target, seq_len): 54 | """ 55 | out.shape : (batch_size, seq_len, class_num) 56 | target.shape : (class_num, seq_len) 57 | seq_len.shape : (batch_size) 58 | """ 59 | out = out.cpu().data.numpy() 60 | target = target.cpu().data.numpy() 61 | seq_len = seq_len.cpu().data.numpy() 62 | 63 | out = out.argmax(axis=2) 64 | return np.array([np.equal(o[:l], t[:l]).sum()/l 65 | for o, t, l in zip(out, target, seq_len)]).mean() 66 | 67 | 68 | def amino_count(t): 69 | c = collections.Counter(t) 70 | keys, values = c.keys(), c.values() 71 | return list(keys), list(values) 72 | 73 | 74 | def acid_accuracy(out, target, seq_len): 75 | out = out.cpu().data.numpy() 76 | target = target.cpu().data.numpy() 77 | seq_len = seq_len.cpu().data.numpy() 78 | 79 | out = out.argmax(axis=2) 80 | 81 | count_1 = np.zeros(8) 82 | count_2 = np.zeros(8) 83 | for o, t, l in zip(out, target, seq_len): 84 | o, t = o[:l], t[:l] 85 | 86 | # org 87 | keys, values = amino_count(t) 88 | count_1[keys] += values 89 | 90 | # pred 91 | keys, values = amino_count(t[np.equal(o, t)]) 92 | count_2[keys] += values 93 | 94 | return np.divide(count_2, count_1, out=np.zeros(8), where=count_1!=0) 95 | 96 | 97 | def load_gz(path): # load a .npy.gz file 98 | if path.endswith(".gz"): 99 | f = gzip.open(path, 'rb') 100 | return np.load(f) 101 | else: 102 | return np.load(path) 103 | 104 | 105 | def timestamp(): 106 | return time.strftime("%Y%m%d%H%M", time.localtime()) 107 | 108 | 109 | def show_progress(e, e_total, train_loss, test_loss, acc): 110 | print(f'[{e:3d}/{e_total:3d}] train_loss:{train_loss:.2f}, '\ 111 | f'test_loss:{test_loss:.2f}, acc:{acc:.3f}') 112 | 113 | 114 | def save_history(history, save_dir): 115 | save_path = os.path.join(save_dir, 'history.npy') 116 | np.save(save_path, history) 117 | 118 | 119 | def save_model(model, save_dir): 120 | save_path = os.path.join(save_dir, 'model.pth') 121 | torch.save(model.state_dict(), save_path) 122 | -------------------------------------------------------------------------------- /pssp-transformer/transformer/Beam.py: -------------------------------------------------------------------------------- 1 | """ Manage beam search info structure. 2 | 3 | Heavily borrowed from OpenNMT-py. 4 | For code in OpenNMT-py, please check the following link: 5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import transformer.Constants as Constants 11 | 12 | class Beam(): 13 | ''' Beam search ''' 14 | 15 | def __init__(self, size, device=False): 16 | 17 | self.size = size 18 | self._done = False 19 | 20 | # The score for each translation on the beam. 21 | self.scores = torch.zeros((size,), dtype=torch.float, device=device) 22 | self.all_scores = [] 23 | 24 | # The backpointers at each time-step. 25 | self.prev_ks = [] 26 | 27 | # The outputs at each time-step. 28 | self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] 29 | self.next_ys[0][0] = Constants.BOS 30 | 31 | def get_current_state(self): 32 | "Get the outputs for the current timestep." 33 | return self.get_tentative_hypothesis() 34 | 35 | def get_current_origin(self): 36 | "Get the backpointers for the current timestep." 37 | return self.prev_ks[-1] 38 | 39 | @property 40 | def done(self): 41 | return self._done 42 | 43 | def advance(self, word_prob): 44 | "Update beam status and check if finished or not." 
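        # word_prob holds (beam_size, vocab_size) log-probabilities for the next
        # token of each partial hypothesis: add the running beam scores, flatten,
        # keep the top `size` candidates, and record backpointers in prev_ks so
        # that get_hypothesis() can rebuild the full sequences later.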
45 | num_words = word_prob.size(1) 46 | 47 | # Sum the previous scores. 48 | if len(self.prev_ks) > 0: 49 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) 50 | else: 51 | beam_lk = word_prob[0] 52 | 53 | flat_beam_lk = beam_lk.view(-1) 54 | 55 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort 56 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort 57 | 58 | self.all_scores.append(self.scores) 59 | self.scores = best_scores 60 | 61 | # bestScoresId is flattened as a (beam x word) array, 62 | # so we need to calculate which word and beam each score came from 63 | prev_k = best_scores_id / num_words 64 | self.prev_ks.append(prev_k) 65 | self.next_ys.append(best_scores_id - prev_k * num_words) 66 | 67 | # End condition is when top-of-beam is EOS. 68 | if self.next_ys[-1][0].item() == Constants.EOS: 69 | self._done = True 70 | self.all_scores.append(self.scores) 71 | 72 | return self._done 73 | 74 | def sort_scores(self): 75 | "Sort the scores." 76 | return torch.sort(self.scores, 0, True) 77 | 78 | def get_the_best_score_and_idx(self): 79 | "Get the score of the best in the beam." 80 | scores, ids = self.sort_scores() 81 | return scores[1], ids[1] 82 | 83 | def get_tentative_hypothesis(self): 84 | "Get the decoded sequence for the current timestep." 85 | 86 | if len(self.next_ys) == 1: 87 | dec_seq = self.next_ys[0].unsqueeze(1) 88 | else: 89 | _, keys = self.sort_scores() 90 | hyps = [self.get_hypothesis(k) for k in keys] 91 | hyps = [[Constants.BOS] + h for h in hyps] 92 | dec_seq = torch.LongTensor(hyps) 93 | 94 | return dec_seq 95 | 96 | def get_hypothesis(self, k): 97 | """ Walk back to construct the full hypothesis. """ 98 | hyp = [] 99 | for j in range(len(self.prev_ks) - 1, -1, -1): 100 | hyp.append(self.next_ys[j+1][k]) 101 | k = self.prev_ks[j][k] 102 | 103 | return list(map(lambda x: x.item(), hyp[::-1])) 104 | -------------------------------------------------------------------------------- /pssp-nn/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | sys.path.append(os.pardir) 5 | import numpy as np 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from load_dataset import LoadDataset 11 | from model import Net 12 | from utils import * 13 | 14 | # params 15 | # ---------- 16 | parser = argparse.ArgumentParser(description='Protein Secondary Structure Prediction') 17 | parser.add_argument('-e', '--epochs', type=int, default=1000, 18 | help='The number of epochs to run (default: 1000)') 19 | parser.add_argument('-b', '--batch_size_train', type=int, default=128, 20 | help='input batch size for training (default: 128)') 21 | parser.add_argument('-b_test', '--batch_size_test', type=int, default=1024, 22 | help='input batch size for testing (default: 1024)') 23 | parser.add_argument('--data_dir', type=str, default='../pssp-data', 24 | help='Dataset directory (default: ../pssp-data)') 25 | parser.add_argument('--result_dir', type=str, default='./result', 26 | help='Output directory (default: ./result)') 27 | parser.add_argument('--no_cuda', action='store_true', default=False, 28 | help='disables CUDA training') 29 | parser.add_argument('--seed', type=int, default=1, metavar='S', 30 | help='random seed (default: 1)') 31 | args = parser.parse_args() 32 | 33 | 34 | def train(model, device, train_loader, optimizer, loss_function): 35 | model.train() 36 | 
train_loss = 0 37 | len_ = len(train_loader) 38 | for batch_idx, (data, target, seq_len) in enumerate(train_loader): 39 | data, target, seq_len = data.to(device), target.to(device), seq_len.to(device) 40 | optimizer.zero_grad() 41 | out = model(data) 42 | loss = loss_function(out, target, seq_len) 43 | loss.backward() 44 | optimizer.step() 45 | train_loss += loss.item() 46 | 47 | train_loss /= len_ 48 | return train_loss 49 | 50 | 51 | def test(model, device, test_loader, loss_function): 52 | model.eval() 53 | test_loss = 0 54 | acc = 0 55 | len_ = len(test_loader) 56 | with torch.no_grad(): 57 | for i, (data, target, seq_len) in enumerate(test_loader): 58 | data, target, seq_len = data.to(device), target.to(device), seq_len.to(device) 59 | out = model(data) 60 | test_loss += loss_function(out, target, seq_len).cpu().data.numpy() 61 | acc += accuracy(out, target, seq_len) 62 | 63 | test_loss /= len_ 64 | acc /= len_ 65 | return test_loss, acc 66 | 67 | 68 | def main(): 69 | use_cuda = not args.no_cuda and torch.cuda.is_available() 70 | torch.manual_seed(args.seed) 71 | device = torch.device("cuda" if use_cuda else "cpu") 72 | 73 | # make directory to save train history and model 74 | os.makedirs(args.result_dir, exist_ok=True) 75 | args2json(args, args.result_dir) 76 | 77 | # laod dataset and set k-fold cross validation 78 | D = LoadDataset(args.data_dir, args.batch_size_train, args.batch_size_test) 79 | train_loader, test_loader = D() 80 | 81 | # model, loss_function, optimizer 82 | model = Net().to(device) 83 | loss_function = CrossEntropy() 84 | optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01) 85 | 86 | # train and test 87 | history = [] 88 | for e in range(args.epochs): 89 | train_loss = train(model, device, train_loader, optimizer, loss_function) 90 | test_loss, acc = test(model, device, test_loader, loss_function) 91 | history.append([train_loss, test_loss, acc]) 92 | show_progress(e+1, args.epochs, train_loss, test_loss, acc) 93 | 94 | # save train history and model 95 | save_history(history, args.result_dir) 96 | save_model(model, args.result_dir) 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /pssp-transformer/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | 5 | from transformer import Constants 6 | 7 | def paired_collate_fn(insts): 8 | src_insts, tgt_insts, sp_insts = list(zip(*insts)) 9 | # print(src_insts) 10 | # print(np.array(sp_insts[0]).shape) 11 | 12 | src_insts = collate_fn_x(src_insts, sp_insts) 13 | # src_insts = collate_fn(src_insts) 14 | tgt_insts = collate_fn(tgt_insts) 15 | # sp_insts = collate_fn_ones(sp_insts) 16 | # print(sp_insts[0].shape) 17 | 18 | return (*src_insts, *tgt_insts) 19 | 20 | 21 | def collate_fn_x(insts, sp_insts): 22 | ''' Pad the instance to the max seq length in batch ''' 23 | 24 | max_len = max(len(inst) for inst in insts) 25 | # print(max_len) 26 | 27 | batch_seq = np.array([ 28 | inst + [Constants.PAD] * (max_len - len(inst)) 29 | for inst in insts]) 30 | 31 | batch_sp = np.array([[ 32 | inst.tolist() + [Constants.PAD] * (max_len - len(inst)) 33 | for inst in sp] 34 | for sp in sp_insts]) 35 | 36 | batch_pos = np.array([ 37 | [pos_i+1 if w_i != Constants.PAD else 0 38 | for pos_i, w_i in enumerate(inst)] for inst in batch_seq]) 39 | 40 | batch_seq = torch.LongTensor(batch_seq) 41 | batch_sp = torch.FloatTensor(batch_sp) 
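    # batch_sp: (batch_size, 21, max_len) float sequence profiles, each row padded
    # to max_len; the Encoder later transposes this to (batch_size, max_len, 21)
    # and projects it with a Linear(21, d_word_vec) before adding it to the token
    # embeddings.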
42 | batch_pos = torch.LongTensor(batch_pos) 43 | 44 | return batch_seq, batch_sp, batch_pos 45 | 46 | 47 | def collate_fn(insts): 48 | ''' Pad the instance to the max seq length in batch ''' 49 | 50 | max_len = max(len(inst) for inst in insts) 51 | # print(max_len) 52 | 53 | batch_seq = np.array([ 54 | inst + [Constants.PAD] * (max_len - len(inst)) 55 | for inst in insts]) 56 | 57 | batch_pos = np.array([ 58 | [pos_i+1 if w_i != Constants.PAD else 0 59 | for pos_i, w_i in enumerate(inst)] for inst in batch_seq]) 60 | 61 | batch_seq = torch.LongTensor(batch_seq) 62 | batch_pos = torch.LongTensor(batch_pos) 63 | 64 | return batch_seq, batch_pos 65 | 66 | 67 | class TranslationDataset(torch.utils.data.Dataset): 68 | def __init__( 69 | self, src_word2idx, tgt_word2idx, 70 | src_insts=None, tgt_insts=None, sp_insts=None): 71 | 72 | assert src_insts 73 | assert not tgt_insts or (len(src_insts) == len(tgt_insts)) 74 | 75 | src_idx2word = {idx:word for word, idx in src_word2idx.items()} 76 | self._src_word2idx = src_word2idx 77 | self._src_idx2word = src_idx2word 78 | self._src_insts = src_insts 79 | 80 | tgt_idx2word = {idx:word for word, idx in tgt_word2idx.items()} 81 | self._tgt_word2idx = tgt_word2idx 82 | self._tgt_idx2word = tgt_idx2word 83 | self._tgt_insts = tgt_insts 84 | 85 | self.sp_insts = sp_insts 86 | 87 | @property 88 | def n_insts(self): 89 | ''' Property for dataset size ''' 90 | return len(self._src_insts) 91 | 92 | @property 93 | def src_vocab_size(self): 94 | ''' Property for vocab size ''' 95 | return len(self._src_word2idx) 96 | 97 | @property 98 | def tgt_vocab_size(self): 99 | ''' Property for vocab size ''' 100 | return len(self._tgt_word2idx) 101 | 102 | @property 103 | def src_word2idx(self): 104 | ''' Property for word dictionary ''' 105 | return self._src_word2idx 106 | 107 | @property 108 | def tgt_word2idx(self): 109 | ''' Property for word dictionary ''' 110 | return self._tgt_word2idx 111 | 112 | @property 113 | def src_idx2word(self): 114 | ''' Property for index dictionary ''' 115 | return self._src_idx2word 116 | 117 | @property 118 | def tgt_idx2word(self): 119 | ''' Property for index dictionary ''' 120 | return self._tgt_idx2word 121 | 122 | def __len__(self): 123 | return self.n_insts 124 | 125 | def __getitem__(self, idx): 126 | if self._tgt_insts: 127 | return self._src_insts[idx], self._tgt_insts[idx], self.sp_insts[idx] 128 | return self._src_insts[idx] 129 | -------------------------------------------------------------------------------- /pssp-transformer/download_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Original Code : https://github.com/alrojo/CB513/blob/master/data.py 3 | 4 | import os 5 | import numpy as np 6 | import subprocess 7 | from utils import load_gz, save_text, save_picke 8 | 9 | 10 | TRAIN_PATH = '../pssp-data/cullpdb+profile_6133_filtered.npy.gz' 11 | TEST_PATH = '../pssp-data/cb513+profile_split1.npy.gz' 12 | 13 | TRAIN_URL = "http://www.princeton.edu/~jzthree/datasets/ICML2014/cullpdb+profile_6133_filtered.npy.gz" 14 | TEST_URL = "http://www.princeton.edu/~jzthree/datasets/ICML2014/cb513+profile_split1.npy.gz" 15 | 16 | AA_PATH = lambda key : f'../pssp-data/aa_{key}.txt' 17 | SP_PATH = lambda key : f'../pssp-data/sp_{key}.pkl' 18 | PSS_PATH = lambda key : f'../pssp-data/pss_{key}.txt' 19 | 20 | 21 | def download_dataset(): 22 | print('[Info] Downloading CB513 dataset ...') 23 | if not (os.path.isfile(TRAIN_PATH) and os.path.isfile(TEST_PATH)): 24 | 
os.makedirs('../pssp-data', exist_ok=True) 25 | os.system(f'wget -O {TRAIN_PATH} {TRAIN_URL}') 26 | os.system(f'wget -O {TEST_PATH} {TEST_URL}') 27 | 28 | 29 | def make_datasets(): 30 | print('[Info] Making datasets ...') 31 | 32 | # train dataset 33 | X_train, y_train, seq_len_train = make_dataset(TRAIN_PATH) 34 | make_dataset_for_transformer(X_train, y_train, seq_len_train, 'train') 35 | 36 | # test dataset 37 | X_test, y_test, seq_len_test = make_dataset(TEST_PATH) 38 | make_dataset_for_transformer(X_test, y_test, seq_len_test, 'test') 39 | 40 | 41 | def make_dataset(path): 42 | data = load_gz(path) 43 | data = data.reshape(-1, 700, 57) 44 | 45 | idx = np.append(np.arange(21), np.arange(35, 56)) 46 | X = data[:, :, idx] 47 | X = X.transpose(0, 2, 1) 48 | X = X.astype('float32') 49 | 50 | y = data[:, :, 22:30] 51 | y = np.array([np.dot(yi, np.arange(8)) for yi in y]) 52 | y = y.astype('float32') 53 | 54 | mask = data[:, :, 30] * -1 + 1 55 | seq_len = mask.sum(axis=1) 56 | seq_len = seq_len.astype(int) 57 | 58 | return X, y, seq_len 59 | 60 | 61 | def make_dataset_for_transformer(X, y, seq_len, key): 62 | X_amino = X[:, :21, :] 63 | X_profile = X[:, 21:, :] 64 | 65 | amino_acid_array = get_amino_acid_array(X_amino, seq_len) 66 | save_path = AA_PATH(key) 67 | save_text(amino_acid_array, save_path) 68 | print(f'[Info] Saved amino_acid_array for {key} in {save_path}') 69 | 70 | seq_profile = get_seq_profile(X_profile, seq_len) 71 | save_path = SP_PATH(key) 72 | save_picke(seq_profile, save_path) 73 | print(f'[Info] Saved seq_profile for {key} in {save_path}') 74 | 75 | pss_array = get_pss_array(y, seq_len) 76 | save_path = PSS_PATH(key) 77 | save_text(pss_array, save_path) 78 | print(f'[Info] Saved pss_array for {key} in {save_path}') 79 | 80 | 81 | def get_amino_acid_array(X_amino, seq_len): 82 | amino_acid = ['A', 'C', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 83 | 'L', 'N', 'Q', 'P', 'S', 'R', 'T', 'W', 'V', 'Y', 'X'] 84 | amino_acid_array = [] 85 | for X, l in zip(X_amino, seq_len): 86 | acid = {} 87 | for i, aa in enumerate(amino_acid): 88 | keys = np.where(X[i] == 1)[0] 89 | values = [aa] * len(keys) 90 | acid.update(zip(keys, values)) 91 | aa_str = ' '.join([acid[i] for i in range(l)]) 92 | 93 | amino_acid_array.append(aa_str) 94 | return amino_acid_array 95 | 96 | 97 | def get_pss_array(label, seq_len): 98 | pss_icon = ['L', 'B', 'E', 'G', 'I', 'H', 'S', 'T'] 99 | pss_array = [] 100 | for target, l in zip(label, seq_len): 101 | pss = np.array(['Nofill'] * l) 102 | target = target[:l] 103 | for i, p in enumerate(pss_icon): 104 | idx = np.where(target == i)[0] 105 | pss[idx] = p 106 | 107 | pss_str = ' '.join([pss[i] for i in range(l)]) 108 | pss_array.append(pss_str) 109 | 110 | return pss_array 111 | 112 | 113 | def get_seq_profile(X_profile, seq_len): 114 | seq_profile = [] 115 | for sp, l in zip(X_profile, seq_len): 116 | seq_profile.append(sp[:, :l]) 117 | return seq_profile 118 | -------------------------------------------------------------------------------- /pssp-transformer/preprocess.py: -------------------------------------------------------------------------------- 1 | ''' Handling the data io ''' 2 | import os 3 | import argparse 4 | import torch 5 | import pickle 6 | import transformer.Constants as Constants 7 | from download_dataset import download_dataset, make_datasets 8 | 9 | def read_instances_from_file(inst_file, max_sent_len, keep_case): 10 | ''' Convert file into word seq lists and vocab ''' 11 | 12 | word_insts = [] 13 | trimmed_sent_count = 0 14 | with 
open(inst_file) as f: 15 | for sent in f: 16 | if not keep_case: 17 | sent = sent.lower() 18 | words = sent.split() 19 | if len(words) > max_sent_len: 20 | trimmed_sent_count += 1 21 | word_inst = words[:max_sent_len] 22 | 23 | if word_inst: 24 | word_insts += [[Constants.BOS_WORD] + word_inst + [Constants.EOS_WORD]] 25 | else: 26 | word_insts += [None] 27 | 28 | print('[Info] Get {} instances from {}'.format(len(word_insts), inst_file)) 29 | 30 | if trimmed_sent_count > 0: 31 | print('[Warning] {} instances are trimmed to the max sentence length {}.' 32 | .format(trimmed_sent_count, max_sent_len)) 33 | 34 | return word_insts 35 | 36 | def build_vocab_idx(word_insts, min_word_count): 37 | ''' Trim vocab by number of occurence ''' 38 | 39 | full_vocab = set(w for sent in word_insts for w in sent) 40 | print('[Info] Original Vocabulary size =', len(full_vocab)) 41 | 42 | word2idx = { 43 | Constants.BOS_WORD: Constants.BOS, 44 | Constants.EOS_WORD: Constants.EOS, 45 | Constants.PAD_WORD: Constants.PAD, 46 | Constants.UNK_WORD: Constants.UNK} 47 | word_count = {w: 0 for w in full_vocab} 48 | 49 | for sent in word_insts: 50 | for word in sent: 51 | word_count[word] += 1 52 | 53 | ignored_word_count = 0 54 | for word, count in word_count.items(): 55 | if word not in word2idx: 56 | if count > min_word_count: 57 | word2idx[word] = len(word2idx) 58 | else: 59 | ignored_word_count += 1 60 | 61 | print('[Info] Trimmed vocabulary size = {},'.format(len(word2idx)), 62 | 'each with minimum occurrence = {}'.format(min_word_count)) 63 | print("[Info] Ignored word count = {}".format(ignored_word_count)) 64 | return word2idx 65 | 66 | def convert_instance_to_idx_seq(word_insts, word2idx): 67 | ''' Mapping words to idx sequence. ''' 68 | return [[word2idx.get(w, Constants.UNK) for w in s] for s in word_insts] 69 | 70 | def load_picke_data(path): 71 | with open(path, mode="rb") as f: 72 | data = pickle.load(f) 73 | return data 74 | 75 | 76 | def main(): 77 | ''' Main function ''' 78 | 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('-train_src', default='../pssp-data/aa_train.txt') 81 | parser.add_argument('-train_tgt', default='../pssp-data/pss_train.txt') 82 | parser.add_argument('-train_sp', default='../pssp-data/sp_train.pkl') 83 | 84 | parser.add_argument('-valid_src', default='../pssp-data/aa_test.txt') 85 | parser.add_argument('-valid_tgt', default='../pssp-data/pss_test.txt') 86 | parser.add_argument('-valid_sp', default='../pssp-data/sp_test.pkl') 87 | 88 | parser.add_argument('-save_data', default='../pssp-data/dataset.pt') 89 | parser.add_argument('-max_len', '--max_word_seq_len', type=int, default=700) 90 | parser.add_argument('-min_word_count', type=int, default=5) 91 | parser.add_argument('-keep_case', action='store_true') 92 | parser.add_argument('-share_vocab', action='store_true') 93 | parser.add_argument('-vocab', default=None) 94 | 95 | opt = parser.parse_args() 96 | opt.max_token_seq_len = opt.max_word_seq_len + 2 # include the and 97 | 98 | if not os.path.isfile(opt.save_data): 99 | download_dataset() 100 | make_datasets() 101 | 102 | # Training set 103 | train_src_word_insts = read_instances_from_file( 104 | opt.train_src, opt.max_word_seq_len, opt.keep_case) 105 | train_tgt_word_insts = read_instances_from_file( 106 | opt.train_tgt, opt.max_word_seq_len, opt.keep_case) 107 | 108 | if len(train_src_word_insts) != len(train_tgt_word_insts): 109 | print('[Warning] The training instance count is not equal.') 110 | min_inst_count = min(len(train_src_word_insts), 
len(train_tgt_word_insts)) 111 | train_src_word_insts = train_src_word_insts[:min_inst_count] 112 | train_tgt_word_insts = train_tgt_word_insts[:min_inst_count] 113 | 114 | #- Remove empty instances 115 | train_src_word_insts, train_tgt_word_insts = list(zip(*[ 116 | (s, t) for s, t in zip(train_src_word_insts, train_tgt_word_insts) if s and t])) 117 | 118 | # Validation set 119 | valid_src_word_insts = read_instances_from_file( 120 | opt.valid_src, opt.max_word_seq_len, opt.keep_case) 121 | valid_tgt_word_insts = read_instances_from_file( 122 | opt.valid_tgt, opt.max_word_seq_len, opt.keep_case) 123 | 124 | if len(valid_src_word_insts) != len(valid_tgt_word_insts): 125 | print('[Warning] The validation instance count is not equal.') 126 | min_inst_count = min(len(valid_src_word_insts), len(valid_tgt_word_insts)) 127 | valid_src_word_insts = valid_src_word_insts[:min_inst_count] 128 | valid_tgt_word_insts = valid_tgt_word_insts[:min_inst_count] 129 | 130 | #- Remove empty instances 131 | valid_src_word_insts, valid_tgt_word_insts = list(zip(*[ 132 | (s, t) for s, t in zip(valid_src_word_insts, valid_tgt_word_insts) if s and t])) 133 | 134 | # Build vocabulary 135 | if opt.vocab: 136 | predefined_data = torch.load(opt.vocab) 137 | assert 'dict' in predefined_data 138 | 139 | print('[Info] Pre-defined vocabulary found.') 140 | src_word2idx = predefined_data['dict']['src'] 141 | tgt_word2idx = predefined_data['dict']['tgt'] 142 | else: 143 | if opt.share_vocab: 144 | print('[Info] Build shared vocabulary for source and target.') 145 | word2idx = build_vocab_idx( 146 | train_src_word_insts + train_tgt_word_insts, opt.min_word_count) 147 | src_word2idx = tgt_word2idx = word2idx 148 | else: 149 | print('[Info] Build vocabulary for source.') 150 | src_word2idx = build_vocab_idx(train_src_word_insts, opt.min_word_count) 151 | print('[Info] Build vocabulary for target.') 152 | tgt_word2idx = build_vocab_idx(train_tgt_word_insts, opt.min_word_count) 153 | 154 | # word to index 155 | print('[Info] Convert source word instances into sequences of word index.') 156 | train_src_insts = convert_instance_to_idx_seq(train_src_word_insts, src_word2idx) 157 | valid_src_insts = convert_instance_to_idx_seq(valid_src_word_insts, src_word2idx) 158 | 159 | print('[Info] Convert target word instances into sequences of word index.') 160 | train_tgt_insts = convert_instance_to_idx_seq(train_tgt_word_insts, tgt_word2idx) 161 | valid_tgt_insts = convert_instance_to_idx_seq(valid_tgt_word_insts, tgt_word2idx) 162 | 163 | # read sequences profile 164 | train_seq_profile = load_picke_data(opt.train_sp) 165 | valid_seq_profile = load_picke_data(opt.valid_sp) 166 | 167 | data = { 168 | 'settings': opt, 169 | 'dict': { 170 | 'src': src_word2idx, 171 | 'tgt': tgt_word2idx}, 172 | 'train': { 173 | 'src': train_src_insts, 174 | 'sp' : train_seq_profile, 175 | 'tgt': train_tgt_insts}, 176 | 'valid': { 177 | 'src': valid_src_insts, 178 | 'sp' : valid_seq_profile, 179 | 'tgt': valid_tgt_insts}} 180 | 181 | print('[Info] Dumping the processed data to pickle file', opt.save_data) 182 | torch.save(data, opt.save_data) 183 | print('[Info] Finish.') 184 | 185 | if __name__ == '__main__': 186 | main() 187 | -------------------------------------------------------------------------------- /pssp-transformer/transformer/Translator.py: -------------------------------------------------------------------------------- 1 | ''' This module will handle the text generation with beam search. 
''' 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from transformer.Models import Transformer 8 | from transformer.Beam import Beam 9 | 10 | class Translator(object): 11 | ''' Load with trained model and handle the beam search ''' 12 | 13 | def __init__(self, opt): 14 | self.opt = opt 15 | self.device = torch.device('cuda' if opt.cuda else 'cpu') 16 | 17 | checkpoint = torch.load(opt.model) 18 | model_opt = checkpoint['settings'] 19 | self.model_opt = model_opt 20 | 21 | model = Transformer( 22 | model_opt.src_vocab_size, 23 | model_opt.tgt_vocab_size, 24 | model_opt.max_token_seq_len, 25 | tgt_emb_prj_weight_sharing=model_opt.proj_share_weight, 26 | emb_src_tgt_weight_sharing=model_opt.embs_share_weight, 27 | d_k=model_opt.d_k, 28 | d_v=model_opt.d_v, 29 | d_model=model_opt.d_model, 30 | d_word_vec=model_opt.d_word_vec, 31 | d_inner=model_opt.d_inner_hid, 32 | n_layers=model_opt.n_layers, 33 | n_head=model_opt.n_head, 34 | dropout=model_opt.dropout) 35 | 36 | model.load_state_dict(checkpoint['model']) 37 | print('[Info] Trained model state loaded.') 38 | 39 | model.word_prob_prj = nn.LogSoftmax(dim=1) 40 | 41 | model = model.to(self.device) 42 | 43 | self.model = model 44 | self.model.eval() 45 | 46 | def translate_batch(self, src_seq, src_pos): 47 | ''' Translation work in one batch ''' 48 | 49 | def get_inst_idx_to_tensor_position_map(inst_idx_list): 50 | ''' Indicate the position of an instance in a tensor. ''' 51 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} 52 | 53 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): 54 | ''' Collect tensor parts associated to active instances. ''' 55 | 56 | _, *d_hs = beamed_tensor.size() 57 | n_curr_active_inst = len(curr_active_inst_idx) 58 | new_shape = (n_curr_active_inst * n_bm, *d_hs) 59 | 60 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) 61 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) 62 | beamed_tensor = beamed_tensor.view(*new_shape) 63 | 64 | return beamed_tensor 65 | 66 | def collate_active_info( 67 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list): 68 | # Sentences which are still active are collected, 69 | # so the decoder will not run on completed sentences. 
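            # inst_idx_to_position_map maps the ids of still-active instances to
            # their current row blocks; collect_active_part() below slices the
            # beam-expanded src_seq and encoder output down to just those rows.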
70 | n_prev_active_inst = len(inst_idx_to_position_map) 71 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] 72 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) 73 | 74 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) 75 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm) 76 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 77 | 78 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map 79 | 80 | def beam_decode_step( 81 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm): 82 | ''' Decode and update beam status, and then return active beam idx ''' 83 | 84 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): 85 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] 86 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) 87 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) 88 | return dec_partial_seq 89 | 90 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): 91 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) 92 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) 93 | return dec_partial_pos 94 | 95 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm): 96 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output) 97 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h 98 | word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1) 99 | word_prob = word_prob.view(n_active_inst, n_bm, -1) 100 | 101 | return word_prob 102 | 103 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): 104 | active_inst_idx_list = [] 105 | for inst_idx, inst_position in inst_idx_to_position_map.items(): 106 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) 107 | if not is_inst_complete: 108 | active_inst_idx_list += [inst_idx] 109 | 110 | return active_inst_idx_list 111 | 112 | n_active_inst = len(inst_idx_to_position_map) 113 | 114 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) 115 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) 116 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm) 117 | 118 | # Update the beam with predicted word prob information and collect incomplete instances 119 | active_inst_idx_list = collect_active_inst_idx_list( 120 | inst_dec_beams, word_prob, inst_idx_to_position_map) 121 | 122 | return active_inst_idx_list 123 | 124 | def collect_hypothesis_and_scores(inst_dec_beams, n_best): 125 | all_hyp, all_scores = [], [] 126 | for inst_idx in range(len(inst_dec_beams)): 127 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() 128 | all_scores += [scores[:n_best]] 129 | 130 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] 131 | all_hyp += [hyps] 132 | return all_hyp, all_scores 133 | 134 | with torch.no_grad(): 135 | #-- Encode 136 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device) 137 | src_enc, *_ = self.model.encoder(src_seq, src_pos) 138 | 139 | #-- Repeat data for beam search 140 | n_bm = self.opt.beam_size 141 | n_inst, len_s, d_h = src_enc.size() 142 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) 143 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, 
d_h) 144 | 145 | #-- Prepare beams 146 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] 147 | 148 | #-- Bookkeeping for active or not 149 | active_inst_idx_list = list(range(n_inst)) 150 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 151 | 152 | #-- Decode 153 | for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1): 154 | 155 | active_inst_idx_list = beam_decode_step( 156 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm) 157 | 158 | if not active_inst_idx_list: 159 | break # all instances have finished their path to <EOS> 160 | 161 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info( 162 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list) 163 | 164 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best) 165 | 166 | return batch_hyp, batch_scores 167 | -------------------------------------------------------------------------------- /pssp-transformer/transformer/Models.py: -------------------------------------------------------------------------------- 1 | ''' Define the Transformer model ''' 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import transformer.Constants as Constants 6 | from transformer.Layers import EncoderLayer, DecoderLayer 7 | 8 | __author__ = "Yu-Hsiang Huang" 9 | 10 | def get_non_pad_mask(seq): 11 | assert seq.dim() == 2 12 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) 13 | 14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 15 | ''' Sinusoid position encoding table ''' 16 | 17 | def cal_angle(position, hid_idx): 18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 19 | 20 | def get_posi_angle_vec(position): 21 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 22 | 23 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 24 | 25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 27 | 28 | if padding_idx is not None: 29 | # zero vector for padding dimension 30 | sinusoid_table[padding_idx] = 0. 31 | 32 | return torch.FloatTensor(sinusoid_table) 33 | 34 | def get_attn_key_pad_mask(seq_k, seq_q): 35 | ''' For masking out the padding part of key sequence. ''' 36 | 37 | # Expand to fit the shape of key query attention matrix. 38 | len_q = seq_q.size(1) 39 | padding_mask = seq_k.eq(Constants.PAD) 40 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 41 | 42 | return padding_mask 43 | 44 | def get_subsequent_mask(seq): 45 | ''' For masking out the subsequent info. ''' 46 | 47 | sz_b, len_s = seq.size() 48 | subsequent_mask = torch.triu( 49 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) 50 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 51 | 52 | return subsequent_mask 53 | 54 | class Encoder(nn.Module): 55 | ''' An encoder model with self attention mechanism.
''' 56 | 57 | def __init__( 58 | self, 59 | n_src_vocab, len_max_seq, d_word_vec, 60 | n_layers, n_head, d_k, d_v, 61 | d_model, d_inner, dropout=0.1): 62 | 63 | super().__init__() 64 | 65 | n_position = len_max_seq + 1 66 | 67 | self.linear = nn.Linear(21, d_word_vec) 68 | 69 | self.src_word_emb = nn.Embedding( 70 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD) 71 | 72 | self.position_enc = nn.Embedding.from_pretrained( 73 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 74 | freeze=True) 75 | 76 | self.layer_stack = nn.ModuleList([ 77 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 78 | for _ in range(n_layers)]) 79 | 80 | # def forward(self, src_seq, src_pos, return_attns=False): 81 | def forward(self, src_seq, src_sp, src_pos, return_attns=False): 82 | src_sp = src_sp.transpose(1, 2) 83 | 84 | enc_slf_attn_list = [] 85 | 86 | # -- Prepare masks 87 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) 88 | non_pad_mask = get_non_pad_mask(src_seq) 89 | 90 | # -- Forward 91 | # enc_output = torch.cat([self.src_word_emb(src_seq), 92 | # src_sp.transpose(1, 2).float()], dim=2) 93 | # enc_output += self.position_enc(src_pos) 94 | # a = self.src_word_emb(src_seq) 95 | # b = self.position_enc(src_pos) 96 | # c = self.linear(src_sp) 97 | # print(f'a : {a.shape}') 98 | # print(f'b : {b.shape}') 99 | # enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos) 100 | enc_output = self.src_word_emb(src_seq) + self.linear(src_sp) + self.position_enc(src_pos) 101 | 102 | for enc_layer in self.layer_stack: 103 | enc_output, enc_slf_attn = enc_layer( 104 | enc_output, 105 | non_pad_mask=non_pad_mask, 106 | slf_attn_mask=slf_attn_mask) 107 | if return_attns: 108 | enc_slf_attn_list += [enc_slf_attn] 109 | 110 | if return_attns: 111 | return enc_output, enc_slf_attn_list 112 | return enc_output, 113 | 114 | class Decoder(nn.Module): 115 | ''' A decoder model with self attention mechanism. 
''' 116 | 117 | def __init__( 118 | self, 119 | n_tgt_vocab, len_max_seq, d_word_vec, 120 | n_layers, n_head, d_k, d_v, 121 | d_model, d_inner, dropout=0.1): 122 | 123 | super().__init__() 124 | n_position = len_max_seq + 1 125 | 126 | self.tgt_word_emb = nn.Embedding( 127 | n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD) 128 | 129 | self.position_enc = nn.Embedding.from_pretrained( 130 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 131 | freeze=True) 132 | 133 | self.layer_stack = nn.ModuleList([ 134 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 135 | for _ in range(n_layers)]) 136 | 137 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False): 138 | 139 | dec_slf_attn_list, dec_enc_attn_list = [], [] 140 | # print(enc_output.shape) 141 | 142 | # -- Prepare masks 143 | non_pad_mask = get_non_pad_mask(tgt_seq) 144 | 145 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) 146 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) 147 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) 148 | 149 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) 150 | 151 | # -- Forward 152 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos) 153 | # print('---') 154 | # print(enc_output.shape) 155 | # print(dec_output.shape) 156 | 157 | for dec_layer in self.layer_stack: 158 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer( 159 | dec_output, enc_output, 160 | non_pad_mask=non_pad_mask, 161 | slf_attn_mask=slf_attn_mask, 162 | dec_enc_attn_mask=dec_enc_attn_mask) 163 | 164 | if return_attns: 165 | dec_slf_attn_list += [dec_slf_attn] 166 | dec_enc_attn_list += [dec_enc_attn] 167 | 168 | if return_attns: 169 | return dec_output, dec_slf_attn_list, dec_enc_attn_list 170 | return dec_output, 171 | 172 | class Transformer(nn.Module): 173 | ''' A sequence to sequence model with attention mechanism. ''' 174 | 175 | def __init__( 176 | self, 177 | n_src_vocab, n_tgt_vocab, len_max_seq, 178 | d_word_vec=512, d_model=512, d_inner=2048, 179 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, 180 | tgt_emb_prj_weight_sharing=True, 181 | emb_src_tgt_weight_sharing=True): 182 | 183 | super().__init__() 184 | 185 | self.encoder = Encoder( 186 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq, 187 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 188 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 189 | dropout=dropout) 190 | 191 | self.decoder = Decoder( 192 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, 193 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 194 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 195 | dropout=dropout) 196 | 197 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 198 | nn.init.xavier_normal_(self.tgt_word_prj.weight) 199 | 200 | # assert d_model == d_word_vec, \ 201 | # 'To facilitate the residual connections, \ 202 | # the dimensions of all module outputs shall be the same.' 203 | 204 | if tgt_emb_prj_weight_sharing: 205 | # Share the weight matrix between target word embedding & the final logit dense layer 206 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight 207 | self.x_logit_scale = (d_model ** -0.5) 208 | else: 209 | self.x_logit_scale = 1. 
210 | 211 | if emb_src_tgt_weight_sharing: 212 | # Share the weight matrix between source & target word embeddings 213 | assert n_src_vocab == n_tgt_vocab, \ 214 | "To share word embedding table, the vocabulary size of src/tgt shall be the same." 215 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight 216 | 217 | # def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): 218 | def forward(self, src_seq, src_sp, src_pos, tgt_seq, tgt_pos): 219 | 220 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1] 221 | 222 | # enc_output, *_ = self.encoder(src_seq, src_pos) 223 | enc_output, *_ = self.encoder(src_seq, src_sp, src_pos) 224 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output) 225 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale 226 | 227 | return seq_logit.view(-1, seq_logit.size(2)) 228 | -------------------------------------------------------------------------------- /pssp-transformer/main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script handling the training process. 3 | ''' 4 | 5 | import os 6 | import argparse 7 | import math 8 | import time 9 | 10 | from tqdm import tqdm 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import torch.utils.data 15 | import transformer.Constants as Constants 16 | from dataset import TranslationDataset, paired_collate_fn 17 | from transformer.Models import Transformer 18 | from transformer.Optim import ScheduledOptim 19 | from utils import args2json, save_model, save_history, show_progress 20 | 21 | 22 | def cal_performance(pred, gold, smoothing=False): 23 | ''' Apply label smoothing if needed ''' 24 | 25 | loss = cal_loss(pred, gold, smoothing) 26 | 27 | pred = pred.max(1)[1] 28 | gold = gold.contiguous().view(-1) 29 | non_pad_mask = gold.ne(Constants.PAD) 30 | n_correct = pred.eq(gold) 31 | n_correct = n_correct.masked_select(non_pad_mask).sum().item() 32 | 33 | return loss, n_correct 34 | 35 | 36 | def cal_loss(pred, gold, smoothing): 37 | ''' Calculate cross entropy loss, apply label smoothing if needed. 
''' 38 | 39 | gold = gold.contiguous().view(-1) 40 | 41 | if smoothing: 42 | eps = 0.1 43 | n_class = pred.size(1) 44 | 45 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 46 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 47 | log_prb = F.log_softmax(pred, dim=1) 48 | 49 | non_pad_mask = gold.ne(Constants.PAD) 50 | loss = -(one_hot * log_prb).sum(dim=1) 51 | loss = loss.masked_select(non_pad_mask).sum() # average later 52 | else: 53 | # print(gold) 54 | loss = F.cross_entropy(pred, gold, ignore_index=Constants.PAD, reduction='sum') 55 | 56 | return loss 57 | 58 | 59 | def train_epoch(model, training_data, optimizer, device, smoothing): 60 | ''' Epoch operation in training phase''' 61 | 62 | model.train() 63 | 64 | total_loss = 0 65 | n_word_total = 0 66 | n_word_correct = 0 67 | 68 | for batch in training_data: 69 | 70 | # prepare data 71 | # src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch) 72 | src_seq, src_sp, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch) 73 | # print('-----------') 74 | # print(f'src_seq : {src_seq.shape}') 75 | # print(f'src_pos : {src_pos.shape}') 76 | # print(f'tgt_seq : {tgt_seq.shape}') 77 | # print(f'tgt_pos : {tgt_pos.shape}') 78 | gold = tgt_seq[:, 1:] 79 | 80 | # forward 81 | optimizer.zero_grad() 82 | # pred = model(src_seq, src_pos, tgt_seq, tgt_pos) 83 | pred = model(src_seq, src_sp, src_pos, tgt_seq, tgt_pos) 84 | 85 | # backward 86 | loss, n_correct = cal_performance(pred, gold, smoothing=smoothing) 87 | loss.backward() 88 | 89 | # update parameters 90 | optimizer.step_and_update_lr() 91 | 92 | # note keeping 93 | total_loss += loss.item() 94 | 95 | non_pad_mask = gold.ne(Constants.PAD) 96 | n_word = non_pad_mask.sum().item() 97 | n_word_total += n_word 98 | n_word_correct += n_correct 99 | 100 | loss_per_word = total_loss/n_word_total 101 | accuracy = n_word_correct/n_word_total 102 | return loss_per_word, accuracy 103 | 104 | def eval_epoch(model, validation_data, device): 105 | ''' Epoch operation in evaluation phase ''' 106 | 107 | model.eval() 108 | 109 | total_loss = 0 110 | n_word_total = 0 111 | n_word_correct = 0 112 | 113 | with torch.no_grad(): 114 | for batch in validation_data: 115 | 116 | # prepare data 117 | # src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch) 118 | src_seq, src_sp, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch) 119 | gold = tgt_seq[:, 1:] 120 | 121 | # forward 122 | # pred = model(src_seq, src_pos, tgt_seq, tgt_pos) 123 | pred = model(src_seq, src_sp, src_pos, tgt_seq, tgt_pos) 124 | loss, n_correct = cal_performance(pred, gold, smoothing=False) 125 | 126 | # note keeping 127 | total_loss += loss.item() 128 | 129 | non_pad_mask = gold.ne(Constants.PAD) 130 | n_word = non_pad_mask.sum().item() 131 | n_word_total += n_word 132 | n_word_correct += n_correct 133 | 134 | loss_per_word = total_loss/n_word_total 135 | accuracy = n_word_correct/n_word_total 136 | return loss_per_word, accuracy 137 | 138 | 139 | 140 | def train(model, training_data, validation_data, optimizer, device, opt): 141 | ''' Start training ''' 142 | 143 | log_train_file = None 144 | log_valid_file = None 145 | 146 | if opt.log: 147 | log_train_file = opt.log + '.train.log' 148 | log_valid_file = opt.log + '.valid.log' 149 | 150 | print('[Info] Training performance will be written to file: {} and {}'.format( 151 | log_train_file, log_valid_file)) 152 | 153 | with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf: 154 | 
log_tf.write('epoch,loss,ppl,accuracy\n') 155 | log_vf.write('epoch,loss,ppl,accuracy\n') 156 | 157 | history = [] 158 | valid_accus = [] 159 | for e in range(opt.epoch): 160 | 161 | train_loss, train_accu = train_epoch( 162 | model, training_data, optimizer, device, smoothing=opt.label_smoothing) 163 | 164 | valid_loss, valid_accu = eval_epoch(model, validation_data, device) 165 | 166 | history.append([train_loss, valid_loss, valid_accu]) 167 | valid_accus += [valid_accu] 168 | 169 | if valid_accu >= max(valid_accus): 170 | save_model(model, opt.result_dir) 171 | print('[Info] The checkpoint file has been updated.') 172 | 173 | if log_train_file and log_valid_file: 174 | with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf: 175 | log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format( 176 | epoch=e, loss=train_loss, 177 | ppl=math.exp(min(train_loss, 100)), accu=100*train_accu)) 178 | log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format( 179 | epoch=e, loss=valid_loss, 180 | ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu)) 181 | 182 | show_progress(e+1, opt.epoch, train_loss, valid_loss, valid_accu) 183 | 184 | save_history(history, opt.result_dir) 185 | 186 | def main(): 187 | ''' Main function ''' 188 | parser = argparse.ArgumentParser() 189 | 190 | parser.add_argument('-data', default='../pssp-data/dataset.pt') 191 | 192 | parser.add_argument('-epoch', type=int, default=10) 193 | parser.add_argument('-batch_size', type=int, default=4) 194 | 195 | #parser.add_argument('-d_word_vec', type=int, default=512) 196 | parser.add_argument('-d_model', type=int, default=256) 197 | parser.add_argument('-d_inner_hid', type=int, default=512) 198 | parser.add_argument('-d_k', type=int, default=64) 199 | parser.add_argument('-d_v', type=int, default=64) 200 | 201 | parser.add_argument('-n_head', type=int, default=8) 202 | parser.add_argument('-n_layers', type=int, default=2) 203 | parser.add_argument('-n_warmup_steps', type=int, default=4000) 204 | 205 | parser.add_argument('-dropout', type=float, default=0.5) 206 | parser.add_argument('-embs_share_weight', action='store_true') 207 | parser.add_argument('-proj_share_weight', action='store_true') 208 | 209 | parser.add_argument('-log', default=None) 210 | parser.add_argument('-result_dir', type=str, default='./result') 211 | # parser.add_argument('-save_model', type=str, default='model') 212 | # parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best') 213 | 214 | parser.add_argument('-no_cuda', action='store_true') 215 | parser.add_argument('-label_smoothing', action='store_true') 216 | 217 | opt = parser.parse_args() 218 | opt.cuda = not opt.no_cuda 219 | opt.d_word_vec = opt.d_model 220 | 221 | os.makedirs(opt.result_dir, exist_ok=True) 222 | args2json(opt, opt.result_dir) 223 | 224 | #========= Loading Dataset =========# 225 | data = torch.load(opt.data) 226 | opt.max_token_seq_len = data['settings'].max_token_seq_len 227 | 228 | training_data, validation_data = prepare_dataloaders(data, opt) 229 | 230 | opt.src_vocab_size = training_data.dataset.src_vocab_size 231 | opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size 232 | 233 | #========= Preparing Model =========# 234 | if opt.embs_share_weight: 235 | assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \ 236 | 'The src/tgt word2idx table are different but asked to share word embedding.' 
237 | 238 | device = torch.device('cuda' if opt.cuda else 'cpu') 239 | if opt.cuda: torch.cuda.set_device(1)  # hard-coded GPU index; skipped when running with -no_cuda 240 | transformer = Transformer( 241 | opt.src_vocab_size, 242 | opt.tgt_vocab_size, 243 | opt.max_token_seq_len, 244 | tgt_emb_prj_weight_sharing=opt.proj_share_weight, 245 | emb_src_tgt_weight_sharing=opt.embs_share_weight, 246 | d_k=opt.d_k, 247 | d_v=opt.d_v, 248 | d_model=opt.d_model, 249 | d_word_vec=opt.d_word_vec, 250 | d_inner=opt.d_inner_hid, 251 | n_layers=opt.n_layers, 252 | n_head=opt.n_head, 253 | dropout=opt.dropout).to(device) 254 | 255 | optimizer = ScheduledOptim( 256 | optim.Adam( 257 | filter(lambda x: x.requires_grad, transformer.parameters()), 258 | betas=(0.9, 0.98), eps=1e-09), 259 | opt.d_model, opt.n_warmup_steps) 260 | 261 | train(transformer, training_data, validation_data, optimizer, device, opt) 262 | 263 | 264 | def prepare_dataloaders(data, opt): 265 | # ========= Preparing DataLoader =========# 266 | train_loader = torch.utils.data.DataLoader( 267 | TranslationDataset( 268 | src_word2idx=data['dict']['src'], 269 | tgt_word2idx=data['dict']['tgt'], 270 | src_insts=data['train']['src'], 271 | tgt_insts=data['train']['tgt'], 272 | sp_insts=data['train']['sp']), 273 | num_workers=2, 274 | batch_size=opt.batch_size, 275 | collate_fn=paired_collate_fn, 276 | shuffle=True) 277 | 278 | valid_loader = torch.utils.data.DataLoader( 279 | TranslationDataset( 280 | src_word2idx=data['dict']['src'], 281 | tgt_word2idx=data['dict']['tgt'], 282 | src_insts=data['valid']['src'], 283 | tgt_insts=data['valid']['tgt'], 284 | sp_insts=data['valid']['sp']), 285 | num_workers=2, 286 | batch_size=opt.batch_size, 287 | collate_fn=paired_collate_fn) 288 | return train_loader, valid_loader 289 | 290 | 291 | if __name__ == '__main__': 292 | main() 293 | --------------------------------------------------------------------------------
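
The label-smoothing branch of cal_loss() in pssp-transformer/main.py can be checked in isolation with a small standalone sketch (not a file from this repository). It assumes the PAD = 0 convention from transformer/Constants.py and the eps = 0.1 smoothing strength used above; the batch size and the 8-class output are illustrative placeholders only.

# Standalone sketch (not a repository file): reproduces the smoothed loss computed in cal_loss().
import torch
import torch.nn.functional as F

PAD, eps = 0, 0.1                            # PAD index and smoothing strength as in main.py
pred = torch.randn(6, 8)                     # (batch * seq_len, n_class); sizes are placeholders
gold = torch.tensor([3, 1, PAD, 5, 2, PAD])  # PAD marks positions that must not contribute

n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)

non_pad_mask = gold.ne(PAD)
loss_per_pos = -(one_hot * log_prb).sum(dim=1)                 # per-position smoothed NLL
smooth_loss = loss_per_pos.masked_select(non_pad_mask).sum()
plain_loss = F.cross_entropy(pred, gold, ignore_index=PAD, reduction='sum')
print(smooth_loss.item(), plain_loss.item())                   # both are sums; main.py later divides by n_word_total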
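
The frozen sinusoidal position table built by get_sinusoid_encoding_table() in transformer/Models.py can be inspected the same way. The sketch below (again not a repository file) mirrors how Encoder and Decoder wrap the table in nn.Embedding.from_pretrained; len_max_seq = 700 and d_word_vec = 256 are illustrative placeholders (d_word_vec defaults to d_model = 256 in main.py), and the import assumes pssp-transformer/ is on PYTHONPATH.

# Standalone sketch (not a repository file): the position-encoding lookup used by Encoder/Decoder.
import torch
import torch.nn as nn
from transformer.Models import get_sinusoid_encoding_table

len_max_seq, d_word_vec = 700, 256         # placeholders; n_position = len_max_seq + 1 as in Encoder.__init__
table = get_sinusoid_encoding_table(len_max_seq + 1, d_word_vec, padding_idx=0)
position_enc = nn.Embedding.from_pretrained(table, freeze=True)

pos = torch.arange(1, 11).unsqueeze(0)     # positions 1..10 of one sequence; index 0 is reserved for padding
print(position_enc(pos).shape)             # torch.Size([1, 10, 256])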