├── src ├── __init__.py ├── models │ ├── __init__.py │ ├── comparison.py │ ├── multitask.py │ ├── embedding.py │ └── sequence.py ├── parse_utils.py ├── pfam.py ├── scop.py ├── metrics.pyx ├── pdb.py ├── fasta.py ├── alphabets.py ├── utils.py ├── transmembrane.py └── alignment.pyx ├── setup.py ├── .gitignore ├── README.md ├── embed_sequences.py ├── eval_contact_scop.py ├── eval_similarity.py ├── train_lm_pfam.py ├── eval_transmembrane.py ├── eval_secstr.py ├── train_similarity.py ├── eval_contact_casp12.py ├── LICENSE └── train_similarity_and_contact.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy as np 4 | 5 | setup( 6 | ext_modules = cythonize(['src/metrics.pyx', 'src/alignment.pyx']), 7 | include_dirs=[np.get_include()] 8 | ) 9 | -------------------------------------------------------------------------------- /src/parse_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | def parse_3line(f): 4 | names = [] 5 | xs = [] 6 | ys = [] 7 | for line in f: 8 | if line.startswith(b'>'): 9 | name = line[1:] 10 | # get the sequence 11 | x = f.readline().strip() 12 | # get the transmembrane annotations 13 | y = f.readline().strip() 14 | names.append(name) 15 | xs.append(x) 16 | ys.append(y) 17 | return names, xs, ys 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /src/pfam.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import numpy as np 4 | 5 | def parse_seed(f): 6 | alignments = [] 7 | a = [] 8 | for line in f: 9 | if line.startswith(b'#'): 10 | continue 11 | if line.startswith(b'//'): 12 | alignments.append(a) 13 | a = [] 14 | else: 15 | _,s = line.split() 16 | a.append(s) 17 | if len(a) > 0: 18 | alignments.append(a) 19 | 20 | return alignments 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /src/scop.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import numpy as np 4 | import src.fasta as fasta 5 | 6 | class NullEncoder: 7 | def encode(self, x): 8 | return x 9 | 10 | def parse_astral_name(name, encode_struct=True): 11 | tokens = name.split() 12 | name = tokens[0] 13 | 14 | # encode structure levels as integer by right-padding with zero byte 15 | # to 4 bytes 16 | if encode_struct: 17 | struct = b'' 18 | for s in tokens[1].split(b'.'): 19 | n = len(s) 20 | s = s + b'\x00'*(4-n) 21 | struct += s 22 | struct = np.frombuffer(struct, dtype=np.int32) 23 | else: 24 | struct = np.array(tokens[1].split(b'.')) 25 | 26 | return name, struct 27 | 28 | def parse_astral(f, encoder=NullEncoder(), encode_struct=True): 29 | names = [] 30 | structs = [] 31 | sequences = [] 32 | for name,sequence in fasta.parse_stream(f): 33 | x = encoder.encode(sequence.upper()) 34 | name, struct = parse_astral_name(name, encode_struct=encode_struct) 35 | names.append(name) 36 | structs.append(struct) 37 | sequences.append(x) 38 | structs = np.stack(structs, 0) 39 | return names, structs, sequences 40 | -------------------------------------------------------------------------------- /src/metrics.pyx: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | cimport cython 4 | cimport numpy as np 5 | import numpy as np 6 | 7 | @cython.boundscheck(False) 8 | @cython.wraparound(False) 9 | @cython.cdivision(True) 10 | def average_precision(np.ndarray[float] target, np.ndarray[float] pred, N=None): 11 | cdef float n 12 | if N is None: 13 | n = target.sum() 14 | else: 15 | n = N 16 | 17 | ## copy the target and prediction into matrix 18 | cdef np.ndarray[float, ndim=2] matrix = np.zeros((target.shape[0],2), dtype=np.float32) 19 | matrix[:,0] = -pred # negate the prediction to sort in descending order 20 | matrix[:,1] = target 21 | matrix.view('f4,f4').sort(order='f0', axis=0) # sort the rows 22 | #print(matrix[:10]) 23 | 24 | cdef float auprc, count, pr, relk, delta 25 | auprc = count = pr = relk = 0 26 | cdef int i = 0 27 | 28 | for i in range(matrix.shape[0]): 29 | count += 1 30 | relk += matrix[i,1] # target 31 | delta = matrix[i,1] - pr 32 | pr += delta/count 33 | if i >= matrix.shape[0] - 1 or matrix[i,0] != matrix[i+1,0]: 34 | auprc += pr*relk 35 | relk = 0 36 | auprc /= n 37 | 38 | return auprc 39 | 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Don't commit the data or pretrained models 2 | data/ 3 | pretrained_models/ 4 | pretrained_models.tar.gz 5 | 6 | # VIM temp files 7 | *.swp 8 | 9 | # cython compiled .c files 10 | src/*.c 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | env/ 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *,cover 57 | .hypothesis/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # IPython Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | -------------------------------------------------------------------------------- /src/pdb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | def parse_secstr_stream(f, comment=b'#'): 4 | name = None 5 | flag = -1 6 | protein = [] 7 | secstr = [] 8 | for line in f: 9 | if line.startswith(comment): 10 | continue 11 | # strip newline 12 | line = line.rstrip(b'\r\n') 13 | if line.startswith(b'>'): 14 | if name is not None and flag==1: 15 | yield name, b''.join(protein), b''.join(secstr) 16 | elif flag == 0: 17 | assert line[1:].startswith(name) 18 | 19 | # each protein has an amino acid sequence 20 | # and secstr sequence associated with it 21 | 22 | name = line[1:] 23 | tokens = name.split(b':') 24 | name = b':'.join(tokens[:-1]) 25 | flag = tokens[-1] 26 | 27 | if flag == b'sequence': 28 | flag = 0 29 | protein = [] 30 | secstr = [] 31 | elif flag == b'secstr': 32 | flag = 1 33 | else: 34 | raise Exception("Unrecognized flag: " + flag.decode()) 35 | 36 | elif flag==0: 37 | protein.append(line) 38 | elif flag==1: 39 | secstr.append(line) 40 | else: 41 | raise Exception("Flag not set properly") 42 | 43 | if name is not None: 44 | yield name, b''.join(protein), b''.join(secstr) 45 | 46 | def parse_secstr(f, comment=b'#'): 47 | 48 | names = [] 49 | proteins = [] 50 | secstrs = [] 51 | for name,protein,secstr in parse_secstr_stream(f, comment=comment): 52 | names.append(name) 53 | proteins.append(protein) 54 | secstrs.append(secstr) 55 | return names, proteins, secstrs 56 | 57 | 58 | -------------------------------------------------------------------------------- /src/fasta.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | def parse_stream(f, comment=b'#'): 4 | name = None 5 | sequence = [] 6 | for line in f: 7 | if line.startswith(comment): 8 | continue 9 | line = line.strip() 10 | if line.startswith(b'>'): 11 | if name is not None: 12 | yield name, b''.join(sequence) 13 | name = line[1:] 14 | sequence = [] 15 | else: 16 | sequence.append(line.upper()) 17 | if name is not None: 18 | yield name, b''.join(sequence) 19 | 20 | def parse(f, comment=b'#'): 21 | names = [] 22 | sequences = [] 23 | name = None 24 | sequence = [] 25 | for line in f: 26 | if line.startswith(comment): 27 | continue 28 | line = line.strip() 29 | if line.startswith(b'>'): 30 | if name is not None: 31 | names.append(name) 32 | sequences.append(b''.join(sequence)) 33 | name = line[1:] 34 | sequence = [] 35 | else: 36 | sequence.append(line.upper()) 37 | if name is not None: 38 | names.append(name) 39 | sequences.append(b''.join(sequence)) 40 | 41 | return names, sequences 42 | 43 | #def parse(f, comment='#'): 44 | # names = [] 45 | # sequences = [] 46 | # name = None 47 | # sequence = [] 48 | # for line in f: 49 | # if line.startswith(comment): 50 | # continue 51 | # line = line.strip() 52 | # if line.startswith('>'): 53 | # if name is not None: 54 | # names.append(name) 55 | # sequences.append(''.join(sequence)) 56 | # name = line[1:] 57 | # sequence = [] 58 | # else: 59 | # sequence.append(line) 60 | # if name is not None: 61 | # names.append(name) 62 | # sequences.append(''.join(sequence)) 63 | # 64 | # return names, sequences 65 | 66 | 67 | -------------------------------------------------------------------------------- /src/alphabets.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import numpy as np 4 | 5 | class Alphabet: 6 | def __init__(self, chars, encoding=None, mask=False, missing=255): 7 | self.chars = np.frombuffer(chars, dtype=np.uint8) 8 | self.encoding = np.zeros(256, dtype=np.uint8) + missing 9 | if encoding is None: 10 | self.encoding[self.chars] = np.arange(len(self.chars)) 11 | self.size = len(self.chars) 12 | else: 13 | self.encoding[self.chars] = encoding 14 | self.size = encoding.max() + 1 15 | self.mask = mask 16 | if mask: 17 | self.size -= 1 18 | 19 | def __len__(self): 20 | return self.size 21 | 22 | def __getitem__(self, i): 23 | return chr(self.chars[i]) 24 | 25 | def encode(self, x): 26 | """ encode a byte string into alphabet indices """ 27 | x = np.frombuffer(x, dtype=np.uint8) 28 | return self.encoding[x] 29 | 30 | def decode(self, x): 31 | """ decode index array, x, to byte string of this alphabet """ 32 | string = self.chars[x] 33 | return string.tobytes() 34 | 35 | def unpack(self, h, k): 36 | """ unpack integer h into array of this alphabet with length k """ 37 | n = self.size 38 | kmer = np.zeros(k, dtype=np.uint8) 39 | for i in reversed(range(k)): 40 | c = h % n 41 | kmer[i] = c 42 | h = h // n 43 | return kmer 44 | 45 | def get_kmer(self, h, k): 46 | """ retrieve byte string of length k decoded from integer h """ 47 | kmer = self.unpack(h, k) 48 | return self.decode(kmer) 49 | 50 | DNA = Alphabet(b'ACGT') 51 | 52 | class Uniprot21(Alphabet): 53 | def __init__(self, mask=False): 54 | chars = alphabet = b'ARNDCQEGHILKMFPSTWYVXOUBZ' 55 | encoding = np.arange(len(chars)) 56 | encoding[21:] = [11,4,20,20] # encode 'OUBZ' as synonyms 57 | super(Uniprot21, self).__init__(chars, encoding=encoding, mask=mask, missing=20) 58 | 59 | class SDM12(Alphabet): 60 | """ 61 | A D KER N TSQ YF LIVM C W H G P 62 | 63 | See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2732308/#B33 64 | "Reduced amino acid alphabets exhibit an improved sensitivity and selectivity in fold assignment" 65 | Peterson et al. 2009. Bioinformatics. 66 | """ 67 | def __init__(self, mask=False): 68 | chars = alphabet = b'ADKNTYLCWHGPXERSQFIVMOUBZ' 69 | groups = [b'A',b'D',b'KERO',b'N',b'TSQ',b'YF',b'LIVM',b'CU',b'W',b'H',b'G',b'P',b'XBZ'] 70 | groups = {c:i for i in range(len(groups)) for c in groups[i]} 71 | encoding = np.array([groups[c] for c in chars]) 72 | super(SDM12, self).__init__(chars, encoding=encoding, mask=mask) 73 | 74 | SecStr8 = Alphabet(b'HBEGITS ') 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning protein sequence embeddings using information from structure 2 | 3 | New and improved embedding models combining sequence and structure training are now available at https://github.com/tbepler/prose! 4 | 5 |
6 |
7 | 8 | This repository contains the source code and links to the data and pretrained embedding models accompanying the ICLR 2019 paper: [Learning protein sequence embeddings using information from structure](https://openreview.net/pdf?id=SygLehCqtm) 9 | 10 | ``` 11 | @inproceedings{ 12 | bepler2018learning, 13 | title={Learning protein sequence embeddings using information from structure}, 14 | author={Tristan Bepler and Bonnie Berger}, 15 | booktitle={International Conference on Learning Representations}, 16 | year={2019}, 17 | } 18 | ``` 19 | 20 | ## Setup and dependencies 21 | 22 | Dependencies: 23 | - python 3 24 | - pytorch >= 0.4 25 | - numpy 26 | - scipy 27 | - pandas 28 | - sklearn 29 | - cython 30 | - h5py (for embedding script) 31 | 32 | Run setup.py to compile the cython files: 33 | 34 | ``` 35 | python setup.py build_ext --inplace 36 | ``` 37 | 38 | ## Data sets 39 | 40 | The data sets with train/dev/test splits are provided as .tar.gz files from the links below. 41 | 42 | - [SCOPe data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/scope.tar.gz) 43 | - [Pfam data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/pfam.tar.gz) 44 | - [Protein secondary structure data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/secstr.tar.gz) 45 | - [Transmembrane data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/transmembrane.tar.gz) 46 | - [CASP12 contact map data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/casp12.tar.gz) 47 | 48 | The training and evaluation scripts assume that these data sets have been extracted into a directory called 'data'. 49 | 50 | ## Pretrained models 51 | 52 | Our trained versions of the structure-based embedding models and the bidirectional language model can be downloaded [here](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/pretrained_models.tar.gz). 53 | 54 | ## Author 55 | 56 | Tristan Bepler (tbepler@mit.edu) 57 | 58 | ## Cite 59 | 60 | Please cite the above paper if you use this code or pretrained models in your work. 61 | 62 | ## License 63 | 64 | The source code and trained models are provided free for non-commercial use under the terms of the CC BY-NC 4.0 license. See [LICENSE](LICENSE) file and/or https://creativecommons.org/licenses/by-nc/4.0/legalcode for more information. 65 | 66 | 67 | ## Contact 68 | 69 | If you have any questions, comments, or would like to report a bug, please file a Github issue or contact me at tbepler@mit.edu. 70 | 71 | -------------------------------------------------------------------------------- /src/models/comparison.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class L1(nn.Module): 9 | def forward(self, x, y): 10 | return -torch.sum(torch.abs(x.unsqueeze(1)-y), -1) 11 | 12 | 13 | class L2(nn.Module): 14 | def forward(self, x, y): 15 | return -torch.sum((x.unsqueeze(1)-y)**2, -1) 16 | 17 | 18 | class DotProduct(nn.Module): 19 | def forward(self, x, y): 20 | return torch.mm(x, y.t()) 21 | 22 | 23 | def pad_gap_scores(s, gap): 24 | col = gap.expand(s.size(0), 1) 25 | s = torch.cat([s, col], 1) 26 | row = gap.expand(1, s.size(1)) 27 | s = torch.cat([s, row], 0) 28 | return s 29 | 30 | 31 | class OrdinalRegression(nn.Module): 32 | def __init__(self, embedding, n_classes, compare=L1() 33 | , align_method='ssa', beta_init=10 34 | , allow_insertions=False, gap_init=-10 35 | ): 36 | super(OrdinalRegression, self).__init__() 37 | 38 | self.embedding = embedding 39 | self.n_out = n_classes 40 | 41 | self.compare = compare 42 | self.align_method = align_method 43 | self.allow_insertions = allow_insertions 44 | self.gap = nn.Parameter(torch.FloatTensor([gap_init])) 45 | 46 | self.theta = nn.Parameter(torch.ones(1,n_classes-1)) 47 | self.beta = nn.Parameter(torch.zeros(n_classes-1)+beta_init) 48 | self.clip() 49 | 50 | def forward(self, x): 51 | return self.embedding(x) 52 | 53 | def clip(self): 54 | # clip the weights of ordinal regression to be non-negative 55 | self.theta.data.clamp_(min=0) 56 | 57 | def score(self, z_x, z_y): 58 | 59 | if self.align_method == 'ssa': 60 | s = self.compare(z_x, z_y) 61 | if self.allow_insertions: 62 | s = pad_gap_scores(s, self.gap) 63 | 64 | a = F.softmax(s, 1) 65 | b = F.softmax(s, 0) 66 | 67 | if self.allow_insertions: 68 | index = s.size(0)-1 69 | index = s.data.new(1).long().fill_(index) 70 | a = a.index_fill(0, index, 0) 71 | 72 | index = s.size(1)-1 73 | index = s.data.new(1).long().fill_(index) 74 | b = b.index_fill(1, index, 0) 75 | 76 | a = a + b - a*b 77 | c = torch.sum(a*s)/torch.sum(a) 78 | 79 | elif self.align_method == 'ua': 80 | s = self.compare(z_x, z_y) 81 | c = torch.mean(s) 82 | 83 | elif self.align_method == 'me': 84 | z_x = z_x.mean(0) 85 | z_y = z_y.mean(0) 86 | c = self.compare(z_x.unsqueeze(0), z_y.unsqueeze(0)).squeeze(0) 87 | 88 | else: 89 | raise Exception('Unknown alignment method: ' + self.align_method) 90 | 91 | logits = c*self.theta + self.beta 92 | return logits.view(-1) 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /src/models/multitask.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .comparison import L1, pad_gap_scores 8 | 9 | 10 | class SCOPCM(nn.Module): 11 | def __init__(self, embedding, similarity_kwargs={}, 12 | cmap_kwargs={}): 13 | super(SCOPCM, self).__init__() 14 | 15 | self.embedding = embedding 16 | embed_dim = embedding.nout 17 | 18 | self.scop_predict = OrdinalRegression(5, **similarity_kwargs) 19 | self.cmap_predict = ConvContactMap(embed_dim, **cmap_kwargs) 20 | 21 | def clip(self): 22 | self.scop_predict.clip() 23 | self.cmap_predict.clip() 24 | 25 | def forward(self, x): 26 | return self.embedding(x) 27 | 28 | def score(self, z_x, z_y): 29 | return self.scop_predict(z_x, z_y) 30 | 31 | def predict(self, z): 32 | return self.cmap_predict(z) 33 | 34 | 35 | class ConvContactMap(nn.Module): 36 | def __init__(self, embed_dim, hidden_dim=50, width=7, act=nn.ReLU()): 37 | super(ConvContactMap, self).__init__() 38 | self.hidden = nn.Conv2d(2*embed_dim, hidden_dim, 1) 39 | self.act = act 40 | self.conv = nn.Conv2d(hidden_dim, 1, width, padding=width//2) 41 | self.clip() 42 | 43 | def clip(self): 44 | # force the conv layer to be transpose invariant 45 | w = self.conv.weight 46 | self.conv.weight.data[:] = 0.5*(w + w.transpose(2,3)) 47 | 48 | def forward(self, z): 49 | return self.predict(z) 50 | 51 | def predict(self, z): 52 | # z is (b,L,d) 53 | z = z.transpose(1, 2) # (b,d,L) 54 | z_dif = torch.abs(z.unsqueeze(2) - z.unsqueeze(3)) 55 | z_mul = z.unsqueeze(2)*z.unsqueeze(3) 56 | z = torch.cat([z_dif, z_mul], 1) 57 | # (b,2d,L,L) 58 | h = self.act(self.hidden(z)) 59 | logits = self.conv(h).squeeze(1) 60 | return logits 61 | 62 | 63 | class OrdinalRegression(nn.Module): 64 | def __init__(self, n_classes, compare=L1() 65 | , align_method='ssa', beta_init=10 66 | , allow_insertions=False, gap_init=-10 67 | ): 68 | super(OrdinalRegression, self).__init__() 69 | 70 | self.n_out = n_classes 71 | 72 | self.compare = compare 73 | self.align_method = align_method 74 | self.allow_insertions = allow_insertions 75 | self.gap = nn.Parameter(torch.FloatTensor([gap_init])) 76 | 77 | self.theta = nn.Parameter(torch.ones(1,n_classes-1)) 78 | self.beta = nn.Parameter(torch.zeros(n_classes-1)+beta_init) 79 | self.clip() 80 | 81 | def clip(self): 82 | # clip the weights of ordinal regression to be non-negative 83 | self.theta.data.clamp_(min=0) 84 | 85 | def forward(self, z_x, z_y): 86 | 87 | if self.align_method == 'ssa': 88 | s = self.compare(z_x, z_y) 89 | if self.allow_insertions: 90 | s = pad_gap_scores(s, self.gap) 91 | 92 | a = F.softmax(s, 1) 93 | b = F.softmax(s, 0) 94 | 95 | if self.allow_insertions: 96 | index = s.size(0)-1 97 | index = s.data.new(1).long().fill_(index) 98 | a = a.index_fill(0, index, 0) 99 | 100 | index = s.size(1)-1 101 | index = s.data.new(1).long().fill_(index) 102 | b = b.index_fill(1, index, 0) 103 | 104 | a = a + b - a*b 105 | c = torch.sum(a*s)/torch.sum(a) 106 | 107 | elif self.align_method == 'ua': 108 | s = self.compare(z_x, z_y) 109 | c = torch.mean(s) 110 | 111 | elif self.align_method == 'me': 112 | z_x = z_x.mean(0) 113 | z_y = z_y.mean(0) 114 | c = self.compare(z_x.unsqueeze(0), z_y.unsqueeze(0)).squeeze(0) 115 | 116 | else: 117 | raise Exception('Unknown alignment method: ' + self.align_method) 118 | 119 | logits = c*self.theta + self.beta 120 | return logits.view(-1) 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /src/models/embedding.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import PackedSequence 7 | 8 | 9 | class LMEmbed(nn.Module): 10 | def __init__(self, nin, nout, lm, padding_idx=-1, transform=nn.ReLU() 11 | , sparse=False): 12 | super(LMEmbed, self).__init__() 13 | 14 | if padding_idx == -1: 15 | padding_idx = nin-1 16 | 17 | self.lm = lm 18 | self.embed = nn.Embedding(nin, nout, padding_idx=padding_idx, sparse=sparse) 19 | self.proj = nn.Linear(lm.hidden_size(), nout) 20 | self.transform = transform 21 | self.nout = nout 22 | 23 | def forward(self, x): 24 | packed = type(x) is PackedSequence 25 | h_lm = self.lm.encode(x) 26 | 27 | # embed and unpack if packed 28 | if packed: 29 | h = self.embed(x.data) 30 | h_lm = h_lm.data 31 | else: 32 | h = self.embed(x) 33 | 34 | # project 35 | h_lm = self.proj(h_lm) 36 | h = self.transform(h + h_lm) 37 | 38 | # repack if needed 39 | if packed: 40 | h = PackedSequence(h, x.batch_sizes) 41 | 42 | return h 43 | 44 | 45 | class Linear(nn.Module): 46 | def __init__(self, nin, nhidden, nout, padding_idx=-1, 47 | sparse=False, lm=None): 48 | super(Linear, self).__init__() 49 | 50 | if padding_idx == -1: 51 | padding_idx = nin-1 52 | 53 | if lm is not None: 54 | self.embed = LMEmbed(nin, nhidden, lm, padding_idx=padding_idx, sparse=sparse) 55 | self.proj = nn.Linear(self.embed.nout, nout) 56 | self.lm = True 57 | else: 58 | self.proj = nn.Embedding(nin, nout, padding_idx=padding_idx, sparse=sparse) 59 | self.lm = False 60 | 61 | self.nout = nout 62 | 63 | 64 | def forward(self, x): 65 | 66 | if self.lm: 67 | h = self.embed(x) 68 | if type(h) is PackedSequence: 69 | h = h.data 70 | z = self.proj(h) 71 | z = PackedSequence(z, x.batch_sizes) 72 | else: 73 | h = h.view(-1, h.size(2)) 74 | z = self.proj(h) 75 | z = z.view(x.size(0), x.size(1), -1) 76 | else: 77 | if type(x) is PackedSequence: 78 | z = self.embed(x.data) 79 | z = PackedSequence(z, x.batch_sizes) 80 | else: 81 | z = self.embed(x) 82 | 83 | return z 84 | 85 | 86 | class StackedRNN(nn.Module): 87 | def __init__(self, nin, nembed, nunits, nout, nlayers=2, padding_idx=-1, dropout=0, 88 | rnn_type='lstm', sparse=False, lm=None): 89 | super(StackedRNN, self).__init__() 90 | 91 | if padding_idx == -1: 92 | padding_idx = nin-1 93 | 94 | if lm is not None: 95 | self.embed = LMEmbed(nin, nembed, lm, padding_idx=padding_idx, sparse=sparse) 96 | nembed = self.embed.nout 97 | self.lm = True 98 | else: 99 | self.embed = nn.Embedding(nin, nembed, padding_idx=padding_idx, sparse=sparse) 100 | self.lm = False 101 | 102 | if rnn_type == 'lstm': 103 | RNN = nn.LSTM 104 | elif rnn_type == 'gru': 105 | RNN = nn.GRU 106 | 107 | self.dropout = nn.Dropout(p=dropout) 108 | if nlayers == 1: 109 | dropout = 0 110 | 111 | self.rnn = RNN(nembed, nunits, nlayers, batch_first=True 112 | , bidirectional=True, dropout=dropout) 113 | self.proj = nn.Linear(2*nunits, nout) 114 | self.nout = nout 115 | 116 | 117 | 118 | def forward(self, x): 119 | 120 | if self.lm: 121 | h = self.embed(x) 122 | else: 123 | if type(x) is PackedSequence: 124 | h = self.embed(x.data) 125 | h = PackedSequence(h, x.batch_sizes) 126 | else: 127 | h = self.embed(x) 128 | 129 | h,_ = self.rnn(h) 130 | 131 | if type(h) is PackedSequence: 132 | h = h.data 133 | h = self.dropout(h) 134 | z = self.proj(h) 135 | z = PackedSequence(z, x.batch_sizes) 136 | else: 137 | h = h.view(-1, h.size(2)) 138 | h = self.dropout(h) 139 | z = self.proj(h) 140 | z = z.view(x.size(0), x.size(1), -1) 141 | 142 | return z 143 | 144 | 145 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.utils.data 7 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence 8 | 9 | def pack_sequences(X, order=None): 10 | 11 | #X = [x.squeeze(0) for x in X] 12 | 13 | n = len(X) 14 | lengths = np.array([len(x) for x in X]) 15 | if order is None: 16 | order = np.argsort(lengths)[::-1] 17 | m = max(len(x) for x in X) 18 | 19 | X_block = X[0].new(n,m).zero_() 20 | 21 | for i in range(n): 22 | j = order[i] 23 | x = X[j] 24 | X_block[i,:len(x)] = x 25 | 26 | #X_block = torch.from_numpy(X_block) 27 | 28 | lengths = lengths[order] 29 | X = pack_padded_sequence(X_block, lengths, batch_first=True) 30 | 31 | return X, order 32 | 33 | 34 | def unpack_sequences(X, order): 35 | X,lengths = pad_packed_sequence(X, batch_first=True) 36 | X_block = [None]*len(order) 37 | for i in range(len(order)): 38 | j = order[i] 39 | X_block[j] = X[i,:lengths[i]] 40 | return X_block 41 | 42 | 43 | def collate_lists(args): 44 | x = [a[0] for a in args] 45 | y = [a[1] for a in args] 46 | return x, y 47 | 48 | 49 | class ContactMapDataset(torch.utils.data.Dataset): 50 | def __init__(self, X, Y, augment=None, fragment=False, mi=64, ma=500): 51 | self.X = X 52 | self.Y = Y 53 | self.augment = augment 54 | self.fragment = fragment 55 | self.mi = mi 56 | self.ma = ma 57 | """ 58 | if fragment: # multiply sequence occurence by expected number of fragments 59 | lengths = np.array([len(x) for x in X]) 60 | mi = np.clip(lengths, None, mi) 61 | ma = np.clip(lengths, None, ma) 62 | weights = 2*lengths/(ma + mi) 63 | mul = np.ceil(weights).astype(int) 64 | X_ = [] 65 | Y_ = [] 66 | for i,n in enumerate(mul): 67 | X_ += [X[i]]*n 68 | Y_ += [Y[i]]*n 69 | self.X = X_ 70 | self.Y = Y_ 71 | """ 72 | 73 | def __len__(self): 74 | return len(self.X) 75 | 76 | def __getitem__(self, i): 77 | x = self.X[i] 78 | y = self.Y[i] 79 | if self.fragment and len(x) > self.mi: 80 | mi = self.mi 81 | ma = min(self.ma, len(x)) 82 | l = np.random.randint(mi, ma+1) 83 | i = np.random.randint(len(x)-l+1) 84 | xl = x[i:i+l] 85 | yl = y[i:i+l,i:i+l] 86 | # make sure there are unmasked observations 87 | while torch.sum(yl >= 0) == 0: 88 | l = np.random.randint(mi, ma+1) 89 | i = np.random.randint(len(x)-l+1) 90 | xl = x[i:i+l] 91 | yl = y[i:i+l,i:i+l] 92 | y = yl.contiguous() 93 | x = xl 94 | if self.augment is not None: 95 | x = self.augment(x) 96 | return x, y 97 | 98 | 99 | class AllPairsDataset(torch.utils.data.Dataset): 100 | def __init__(self, X, Y, augment=None): 101 | self.X = X 102 | self.Y = Y 103 | self.augment = augment 104 | 105 | def __len__(self): 106 | return len(self.X)**2 107 | 108 | def __getitem__(self, k): 109 | n = len(self.X) 110 | i = k//n 111 | j = k%n 112 | 113 | x0 = self.X[i] 114 | x1 = self.X[j] 115 | if self.augment is not None: 116 | x0 = self.augment(x0) 117 | x1 = self.augment(x1) 118 | 119 | y = self.Y[i,j] 120 | #y = torch.cumprod((self.Y[i] == self.Y[j]).long(), 0).sum() 121 | 122 | return x0, x1, y 123 | 124 | 125 | class PairedDataset(torch.utils.data.Dataset): 126 | def __init__(self, X0, X1, Y): 127 | self.X0 = X0 128 | self.X1 = X1 129 | self.Y = Y 130 | 131 | def __len__(self): 132 | return len(self.X0) 133 | 134 | def __getitem__(self, i): 135 | return self.X0[i], self.X1[i], self.Y[i] 136 | 137 | 138 | def collate_paired_sequences(args): 139 | x0 = [a[0] for a in args] 140 | x1 = [a[1] for a in args] 141 | y = [a[2] for a in args] 142 | return x0, x1, torch.stack(y, 0) 143 | 144 | 145 | class MultinomialResample: 146 | def __init__(self, trans, p): 147 | self.p = (1-p)*torch.eye(trans.size(0)).to(trans.device) + p*trans 148 | 149 | def __call__(self, x): 150 | #print(x.size(), x.dtype) 151 | p = self.p[x] # get distribution for each x 152 | return torch.multinomial(p, 1).view(-1) # sample from distribution 153 | 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /src/transmembrane.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def encode_labels(s): 11 | y = np.zeros(len(s), dtype=int) 12 | for i in range(len(s)): 13 | if s[i:i+1] == b'I': 14 | y[i] = 0 15 | elif s[i:i+1] == b'O': 16 | y[i] = 1 17 | elif s[i:i+1] == b'M': 18 | y[i] = 2 19 | elif s[i:i+1] == b'S': 20 | y[i] = 3 21 | else: 22 | raise Exception('Unrecognized annotation: ' + s[i:i+1].decode('utf-8')) 23 | return y 24 | 25 | 26 | def transmembrane_regions(y): 27 | regions = [] 28 | start = -1 29 | for i in range(len(y)): 30 | if y[i] == 2 and start < 0: 31 | start = i 32 | elif y[i] != 2 and start > 0: 33 | regions.append((start,i)) 34 | start = -1 35 | if start > 0: 36 | regions.append((start, len(y))) 37 | return regions 38 | 39 | 40 | def is_prediction_correct(y_hat, y): 41 | ## prediction is correct if it has the same number of transmembrane regions 42 | ## and those overlap real transmembrane regions by at least 5 bases 43 | ## and it starts with a signaling peptide when y does 44 | pred_regions = transmembrane_regions(y_hat) 45 | target_regions = transmembrane_regions(y) 46 | if len(pred_regions) != len(target_regions): 47 | return 0 48 | 49 | for p,t in zip(pred_regions, target_regions): 50 | if p[1] <= t[0]: 51 | return 0 52 | if t[1] <= p[0]: 53 | return 0 54 | s = max(p[0], t[0]) 55 | e = min(p[1], t[1]) 56 | overlap = e - s 57 | if overlap < 5: 58 | return 0 59 | 60 | # finally, check signal peptide 61 | if y[0] == 3 and y_hat[0] != 3: 62 | return 0 63 | 64 | return 1 65 | 66 | 67 | ## TOPCONS uses a very specific state architecture for HMM 68 | ## we can adopt this to describe the transmembrane grammar 69 | class Grammar: 70 | def __init__(self, n_helix=21, signal_helix=True): 71 | ## describe the transmembrane states 72 | n_states = 3 + 2*n_helix 73 | 74 | start = np.zeros(n_states) 75 | start[0] = 1.0 # inner 76 | start[1] = 1.0 # outer 77 | start[2] = 1.0 # signal peptide 78 | 79 | end = np.zeros(n_states) 80 | end[0] = 1.0 # from inner 81 | end[1] = 1.0 # from outer 82 | 83 | trans = np.zeros((n_states, n_states)) 84 | trans[0,0] = 1.0 # inner -> inner 85 | trans[0,3] = 1.0 # inner -> helix (i->o) 86 | trans[1,1] = 1.0 # outer -> outer 87 | trans[1,3+n_helix] = 1.0 # outer -> helix (o->i) 88 | 89 | trans[2,0] = 1.0 # signal -> inner 90 | trans[2,1] = 1.0 # signal -> outer 91 | 92 | for i in range(3,2+n_helix): # i->o helices 93 | trans[i,i+1] = 1.0 94 | trans[2+n_helix,1] = 1.0 # helix (i->o) -> outer 95 | 96 | for i in range(3+n_helix,2+2*n_helix): # o->i helices 97 | trans[i,i+1] = 1.0 98 | trans[2+2*n_helix,0] = 1.0 # helix (o->i) -> inner 99 | 100 | emit = np.zeros((n_states, 4)) 101 | emit[0,0] = 1.0 # inner 102 | emit[0,1] = 1.0 103 | emit[1,0] = 1.0 # outer 104 | emit[1,1] = 1.0 105 | emit[2,3] = 1.0 # signal peptide 106 | #if signal_helix: 107 | # emit[2,2] = 1.0 108 | for i in range(3,3+2*n_helix): # helices 109 | emit[i,2] = 1.0 110 | 111 | mapping = np.zeros(n_states, dtype=int) 112 | mapping[0] = 0 113 | mapping[1] = 1 114 | mapping[2] = 3 115 | mapping[3:3+2*n_helix] = 2 116 | 117 | self.start = np.log(start) - np.log(start.sum()) 118 | self.end = np.log(end) - np.log(end.sum()) 119 | self.trans = np.log(trans) - np.log(trans.sum(1, keepdims=True)) 120 | self.emit = emit 121 | self.mapping = mapping 122 | 123 | def decode(self, logp): 124 | p = np.exp(logp) 125 | z = np.log(np.dot(p, self.emit.T)) 126 | 127 | tb = np.zeros(z.shape, dtype=np.int8) - 1 128 | p0 = z[0] + self.start 129 | for i in range(z.shape[0] - 1): 130 | trans = p0[:,np.newaxis] + self.trans + z[i+1] # 131 | tb[i+1] = np.argmax(trans, 0) 132 | p0 = np.max(trans, 0) 133 | # transition to end 134 | p0 = p0 + self.end 135 | state = np.argmax(p0) 136 | score = np.max(p0) 137 | # traceback most likely sequence of states 138 | y = np.zeros(z.shape[0], dtype=int) 139 | j = state 140 | y[-1] = j 141 | for i in range(z.shape[0]-1, 0, -1): 142 | j = tb[i,j] 143 | y[i-1] = j 144 | 145 | # map the states 146 | y = self.mapping[y] 147 | 148 | return y, score 149 | 150 | def predict_viterbi(self, xs, model, use_cuda=False): 151 | y_hats = [] 152 | with torch.no_grad(): 153 | for x in xs: 154 | if use_cuda: 155 | x = x.cuda() 156 | log_p_hat = F.log_softmax(model(x), 1).cpu().numpy() 157 | y_hat,_ = self.decode(log_p_hat) 158 | y_hats.append(y_hat) 159 | return y_hats 160 | 161 | 162 | 163 | 164 | 165 | -------------------------------------------------------------------------------- /src/models/sequence.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence 7 | 8 | class BiLM(nn.Module): 9 | def __init__(self, nin, nout, embedding_dim, hidden_dim, num_layers 10 | , tied=True, mask_idx=None, dropout=0): 11 | super(BiLM, self).__init__() 12 | 13 | if mask_idx is None: 14 | mask_idx = nin-1 15 | self.mask_idx = mask_idx 16 | self.embed = nn.Embedding(nin, embedding_dim, padding_idx=mask_idx) 17 | self.dropout = nn.Dropout(p=dropout) 18 | 19 | self.tied = tied 20 | if tied: 21 | layers = [] 22 | nin = embedding_dim 23 | for _ in range(num_layers): 24 | layers.append(nn.LSTM(nin, hidden_dim, 1, batch_first=True)) 25 | nin = hidden_dim 26 | self.rnn = nn.ModuleList(layers) 27 | else: 28 | layers = [] 29 | nin = embedding_dim 30 | for _ in range(num_layers): 31 | layers.append(nn.LSTM(nin, hidden_dim, 1, batch_first=True)) 32 | nin = hidden_dim 33 | self.lrnn = nn.ModuleList(layers) 34 | 35 | layers = [] 36 | nin = embedding_dim 37 | for _ in range(num_layers): 38 | layers.append(nn.LSTM(nin, hidden_dim, 1, batch_first=True)) 39 | nin = hidden_dim 40 | self.rrnn = nn.ModuleList(layers) 41 | 42 | self.linear = nn.Linear(hidden_dim, nout) 43 | 44 | def hidden_size(self): 45 | h = 0 46 | if self.tied: 47 | for layer in self.rnn: 48 | h += 2*layer.hidden_size 49 | else: 50 | for layer in self.lrnn: 51 | h += layer.hidden_size 52 | for layer in self.rrnn: 53 | h += layer.hidden_size 54 | return h 55 | 56 | 57 | def transform(self, z, last_only=False): 58 | # sequences are flanked by the start/stop token as: 59 | # [stop, x, stop] 60 | 61 | idx = [i for i in range(z.size(1)-1, -1, -1)] 62 | idx = torch.LongTensor(idx).to(z.device) 63 | z_rvs = z.index_select(1, idx) 64 | 65 | z = z[:,:-1] 66 | z_rvs = z_rvs[:,:-1] 67 | idx = [i for i in range(z_rvs.size(1)-1, -1, -1)] 68 | idx = torch.LongTensor(idx).to(z_rvs.device) 69 | 70 | if last_only: 71 | if self.tied: 72 | h_fwd = z 73 | h_rvs = z_rvs 74 | for rnn in self.rnn: 75 | h_fwd,_ = rnn(h_fwd) 76 | h_fwd = self.dropout(h_fwd) 77 | h_rvs,_ = rnn(h_rvs) 78 | h_rvs = self.dropout(h_rvs) 79 | else: 80 | h_fwd = z 81 | h_rvs = z_rvs 82 | for lrnn,rrnn in zip(self.lrnn, self.rrnn): 83 | h_fwd,_ = lrnn(h_fwd) 84 | h_fwd = self.dropout(h_fwd) 85 | h_rvs,_ = rrnn(h_rvs) 86 | h_rvs = self.dropout(h_rvs) 87 | hidden = (h_fwd, h_rvs.index_select(1, idx)) 88 | else: 89 | hidden = [] 90 | if self.tied: 91 | h_fwd = z 92 | h_rvs = z_rvs 93 | for rnn in self.rnn: 94 | h_fwd,_ = rnn(h_fwd) 95 | h_fwd = self.dropout(h_fwd) 96 | h_rvs,_ = rnn(h_rvs) 97 | h_rvs = self.dropout(h_rvs) 98 | hidden.append((h_fwd, h_rvs.index_select(1, idx))) 99 | else: 100 | h_fwd = z 101 | h_rvs = z_rvs 102 | for lrnn,rrnn in zip(self.lrnn, self.rrnn): 103 | h_fwd,_ = lrnn(h_fwd) 104 | h_fwd = self.dropout(h_fwd) 105 | h_rvs,_ = rrnn(h_rvs) 106 | h_rvs = self.dropout(h_rvs) 107 | hidden.append((h_fwd, h_rvs.index_select(1, dx))) 108 | 109 | return hidden 110 | 111 | 112 | def encode(self, x): 113 | packed = type(x) is PackedSequence 114 | if packed: 115 | # pad with the start/stop token 116 | x,batch_sizes = pad_packed_sequence(x, batch_first=True, padding_value=self.mask_idx-1) 117 | x = x + 1 118 | ## append start/stop tokens to x 119 | x_ = x.data.new(x.size(0), x.size(1)+2).zero_() 120 | x_[:,1:-1] = x 121 | x = x_ 122 | 123 | 124 | # sequences x are flanked by the start/stop token as: 125 | # [stop, x, stop] 126 | 127 | z = self.embed(x) 128 | hidden = self.transform(z) 129 | 130 | concat = [] 131 | for h_fwd,h_rvs in hidden: 132 | h_fwd = h_fwd[:,:-1] 133 | h_rvs = h_rvs[:,1:] 134 | concat.append(h_fwd) 135 | concat.append(h_rvs) 136 | 137 | h = torch.cat(concat, 2) 138 | if packed: 139 | h = pack_padded_sequence(h, batch_sizes, batch_first=True) 140 | 141 | return h 142 | 143 | 144 | def forward(self, x): 145 | packed = type(x) is PackedSequence 146 | if packed: 147 | # pad with the start/stop token 148 | x,batch_sizes = pad_packed_sequence(x, batch_first=True, padding_value=self.mask_idx) 149 | 150 | # sequences x are flanked by the start/stop token as: 151 | # [stop, x, stop] 152 | 153 | z = self.embed(x) 154 | 155 | h_fwd,h_rvs = self.transform(z, last_only=True) 156 | 157 | b = z.size(0) 158 | n = z.size(1) - 1 159 | 160 | h_fwd = h_fwd.contiguous() 161 | h_flat = h_fwd.view(b*n, h_fwd.size(2)) 162 | logp_fwd = self.linear(h_flat) 163 | logp_fwd = logp_fwd.view(b, n, -1) 164 | 165 | zero = h_fwd.data.new(b,1,logp_fwd.size(2)).zero_() 166 | logp_fwd = torch.cat([zero, logp_fwd], 1) 167 | 168 | h_rvs = h_rvs.contiguous() 169 | logp_rvs = self.linear(h_rvs.view(-1, h_rvs.size(2))).view(b, n, -1) 170 | logp_rvs = torch.cat([logp_rvs, zero], 1) 171 | 172 | logp = F.log_softmax(logp_fwd + logp_rvs, dim=2) 173 | if packed: 174 | logp = pack_padded_sequence(logp, batch_sizes, batch_first=True) 175 | 176 | return logp 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /embed_sequences.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import sys 4 | import numpy as np 5 | import h5py 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from src.alphabets import Uniprot21 12 | import src.fasta as fasta 13 | import src.models.sequence 14 | 15 | 16 | def unstack_lstm(lstm): 17 | device = next(iter(lstm.parameters())).device 18 | 19 | in_size = lstm.input_size 20 | hidden_dim = lstm.hidden_size 21 | layers = [] 22 | for i in range(lstm.num_layers): 23 | layer = nn.LSTM(in_size, hidden_dim, batch_first=True, bidirectional=True) 24 | layer.to(device) 25 | 26 | attributes = ['weight_ih_l', 'weight_hh_l', 'bias_ih_l', 'bias_hh_l'] 27 | for attr in attributes: 28 | dest = attr + '0' 29 | src = attr + str(i) 30 | getattr(layer, dest).data[:] = getattr(lstm, src) 31 | #setattr(layer, dest, getattr(lstm, src)) 32 | 33 | dest = attr + '0_reverse' 34 | src = attr + str(i) + '_reverse' 35 | getattr(layer, dest).data[:] = getattr(lstm, src) 36 | #setattr(layer, dest, getattr(lstm, src)) 37 | layer.flatten_parameters() 38 | layers.append(layer) 39 | in_size = 2*hidden_dim 40 | return layers 41 | 42 | def embed_stack(x, lm_embed, lstm_stack, proj, include_lm=True, final_only=False): 43 | zs = [] 44 | 45 | x_onehot = x.new(x.size(0),x.size(1), 21).float().zero_() 46 | x_onehot.scatter_(2,x.unsqueeze(2),1) 47 | zs.append(x_onehot) 48 | 49 | h = lm_embed(x) 50 | if include_lm and not final_only: 51 | zs.append(h) 52 | 53 | if lstm_stack is not None: 54 | for lstm in lstm_stack: 55 | h,_ = lstm(h) 56 | if not final_only: 57 | zs.append(h) 58 | h = proj(h.squeeze(0)).unsqueeze(0) 59 | zs.append(h) 60 | 61 | z = torch.cat(zs, 2) 62 | return z 63 | 64 | 65 | def embed_sequence(x, lm_embed, lstm_stack, proj, include_lm=True, final_only=False 66 | , pool='none', use_cuda=False): 67 | 68 | if len(x) == 0: 69 | return None 70 | 71 | alphabet = Uniprot21() 72 | x = x.upper() 73 | # convert to alphabet index 74 | x = alphabet.encode(x) 75 | x = torch.from_numpy(x) 76 | if use_cuda: 77 | x = x.cuda() 78 | 79 | # embed the sequence 80 | with torch.no_grad(): 81 | x = x.long().unsqueeze(0) 82 | z = embed_stack(x, lm_embed, lstm_stack, proj 83 | , include_lm=include_lm, final_only=final_only) 84 | # pool if needed 85 | z = z.squeeze(0) 86 | if pool == 'sum': 87 | z = z.sum(0) 88 | elif pool == 'max': 89 | z,_ = z.max(0) 90 | elif pool == 'avg': 91 | z = z.mean(0) 92 | z = z.cpu().numpy() 93 | 94 | return z 95 | 96 | 97 | def load_model(path, use_cuda=False): 98 | encoder = torch.load(path) 99 | encoder.eval() 100 | 101 | if use_cuda: 102 | encoder.cuda() 103 | 104 | if type(encoder) is src.models.sequence.BiLM: 105 | # model is only the LM 106 | return encoder.encode, None, None 107 | 108 | encoder = encoder.embedding 109 | 110 | lm_embed = encoder.embed 111 | lstm_stack = unstack_lstm(encoder.rnn) 112 | proj = encoder.proj 113 | 114 | return lm_embed, lstm_stack, proj 115 | 116 | 117 | def main(): 118 | import argparse 119 | parser = argparse.ArgumentParser('Script for embedding fasta format sequences using a saved embedding model. Saves embeddings as HDF5 file.') 120 | 121 | parser.add_argument('path', help='sequences to embed in fasta format') 122 | parser.add_argument('-m', '--model', help='path to saved embedding model') 123 | parser.add_argument('-o', '--output', help='path to HDF5 output file') 124 | parser.add_argument('--lm-only', action='store_true', help='only return the language model hidden layers') 125 | parser.add_argument('--no-lm', action='store_true', help='do not include LM hidden layers in embedding. by default, all hidden layers of all layers are concatenated and returned by this script.') 126 | parser.add_argument('--proj-only', action='store_true', help='only return the final structure-learned embedding') 127 | parser.add_argument('--pool', choices=['none', 'sum', 'max', 'avg'], default='none', help='apply some pooling operation over each sequence (default: none)') 128 | parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use') 129 | 130 | args = parser.parse_args() 131 | 132 | path = args.path 133 | 134 | # set the device 135 | d = args.device 136 | use_cuda = (d != -1) and torch.cuda.is_available() 137 | if d >= 0: 138 | torch.cuda.set_device(d) 139 | 140 | # load the model 141 | lm_embed, lstm_stack, proj = load_model(args.model, use_cuda=use_cuda) 142 | 143 | # parse the sequences and embed them 144 | # write them to hdf5 file 145 | print('# writing:', args.output, file=sys.stderr) 146 | h5 = h5py.File(args.output, 'w') 147 | 148 | lm_only = args.lm_only 149 | if lm_only: 150 | lstm_stack = None 151 | proj = None 152 | 153 | no_lm = args.no_lm 154 | include_lm = not no_lm 155 | final_only = args.proj_only 156 | 157 | pool = args.pool 158 | print('# embedding with lm_only={}, no_lm={}, proj_only={}'.format(lm_only, no_lm, final_only), file=sys.stderr) 159 | print('# pooling:', pool, file=sys.stderr) 160 | 161 | count = 0 162 | with open(path, 'rb') as f: 163 | for name,sequence in fasta.parse_stream(f): 164 | # use sequence name as HDF key 165 | pid = name.decode('utf-8') 166 | if len(sequence) == 0: 167 | print('# WARNING: sequence', pid, 'has length=0. Skipping.', file=sys.stderr) 168 | continue 169 | # only do pids we haven't done already... 170 | if pid not in h5: 171 | z = embed_sequence(sequence, lm_embed, lstm_stack, proj 172 | , include_lm=include_lm, final_only=final_only 173 | , pool=pool, use_cuda=use_cuda) 174 | # write as hdf5 dataset 175 | h5.create_dataset(pid, data=z, compression='lzf') 176 | count += 1 177 | print('# {} sequences processed...'.format(count), file=sys.stderr, end='\r') 178 | print(' '*80, file=sys.stderr, end='\r') 179 | print('# Done!', file=sys.stderr) 180 | 181 | 182 | 183 | if __name__ == '__main__': 184 | main() 185 | 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /eval_contact_scop.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import numpy as np 4 | import sys 5 | import os 6 | import glob 7 | from PIL import Image 8 | 9 | import torch 10 | from torch.nn.utils.rnn import PackedSequence 11 | import torch.utils.data 12 | 13 | from src.alphabets import Uniprot21 14 | import src.fasta as fasta 15 | from src.utils import pack_sequences, unpack_sequences 16 | from src.utils import ContactMapDataset, collate_lists 17 | from src.metrics import average_precision 18 | 19 | 20 | def load_data(seq_path, struct_path, alphabet, baselines=False): 21 | 22 | pdb_index = {} 23 | for path in struct_path: 24 | pid = os.path.basename(path)[:7] 25 | pdb_index[pid] = path 26 | 27 | with open(seq_path, 'rb') as f: 28 | names, sequences = fasta.parse(f) 29 | names = [name.split()[0].decode('utf-8') for name in names] 30 | sequences = [alphabet.encode(s.upper()) for s in sequences] 31 | 32 | x = [torch.from_numpy(x).long() for x in sequences] 33 | 34 | names_ = [] 35 | x_ = [] 36 | y = [] 37 | for xi,name in zip(x,names): 38 | pid = name 39 | if pid not in pdb_index: 40 | pid = 'd' + pid[1:] 41 | path = pdb_index[pid] 42 | 43 | im = np.array(Image.open(path), copy=False) 44 | contacts = np.zeros(im.shape, dtype=np.float32) 45 | contacts[im == 1] = -1 46 | contacts[im == 255] = 1 47 | 48 | # mask the matrix below the diagonal 49 | mask = np.tril_indices(contacts.shape[0], k=1) 50 | contacts[mask] = -1 51 | 52 | names_.append(name) 53 | x_.append(xi) 54 | y.append(torch.from_numpy(contacts)) 55 | 56 | return x_, y, names_ 57 | 58 | 59 | def predict_minibatch(model, x, use_cuda): 60 | b = len(x) 61 | x,order = pack_sequences(x) 62 | x = PackedSequence(x.data, x.batch_sizes) 63 | z = model(x) # embed the sequences 64 | z = unpack_sequences(z, order) 65 | 66 | logits = [] 67 | for i in range(b): 68 | zi = z[i] 69 | lp = model.predict(zi.unsqueeze(0)).view(zi.size(0), zi.size(0)) 70 | logits.append(lp) 71 | 72 | return logits 73 | 74 | 75 | def calc_metrics(logits, y): 76 | y_hat = (logits > 0).astype(np.float32) 77 | TP = (y_hat*y).sum() 78 | precision = 1.0 79 | if y_hat.sum() > 0: 80 | precision = TP/y_hat.sum() 81 | recall = TP/y.sum() 82 | F1 = 0 83 | if precision + recall > 0: 84 | F1 = 2*precision*recall/(precision + recall) 85 | AUPR = average_precision(y, logits) 86 | return precision, recall, F1, AUPR 87 | 88 | 89 | def main(): 90 | import argparse 91 | parser = argparse.ArgumentParser('Script for evaluating contact map models.') 92 | parser.add_argument('model', help='path to saved model') 93 | parser.add_argument('--dataset', default='2.06 test', help='which dataset (default: 2.06 test)') 94 | parser.add_argument('--batch-size', default=10, type=int, help='number of sequences to process in each batch (default: 10)') 95 | parser.add_argument('-o', '--output', help='output file path (default: stdout)') 96 | parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use') 97 | args = parser.parse_args() 98 | 99 | # load the data 100 | if args.dataset == '2.06 test': 101 | fasta_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.fa' 102 | contact_paths = glob.glob('data/SCOPe/pdbstyle-2.06/*/*.png') 103 | elif args.dataset == '2.07 test' or args.dataset == '2.07 new test': 104 | fasta_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.07-new.fa' 105 | contact_paths = glob.glob('data/SCOPe/pdbstyle-2.07/*/*.png') 106 | else: 107 | raise Exception('Bad dataset argument ' + args.dataset) 108 | 109 | alphabet = Uniprot21() 110 | x,y,names = load_data(fasta_path, contact_paths, alphabet) 111 | 112 | ## set the device 113 | d = args.device 114 | use_cuda = (d != -1) and torch.cuda.is_available() 115 | if d >= 0: 116 | torch.cuda.set_device(d) 117 | 118 | if use_cuda: 119 | x = [x_.cuda() for x_ in x] 120 | y = [y_.cuda() for y_ in y] 121 | 122 | model = torch.load(args.model) 123 | model.eval() 124 | if use_cuda: 125 | model.cuda() 126 | 127 | # predict contact maps 128 | batch_size = args.batch_size 129 | dataset = ContactMapDataset(x, y) 130 | iterator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_lists) 131 | logits = [] 132 | with torch.no_grad(): 133 | for xmb,ymb in iterator: 134 | lmb = predict_minibatch(model, xmb, use_cuda) 135 | logits += lmb 136 | 137 | # calculate performance metrics 138 | lengths = np.array([len(x_) for x_ in x]) 139 | logits = [logit.cpu().numpy() for logit in logits] 140 | y = [y_.cpu().numpy() for y_ in y] 141 | 142 | output = args.output 143 | if output is None: 144 | output = sys.stdout 145 | else: 146 | output = open(output, 'w') 147 | line = '\t'.join(['Distance', 'Precision', 'Recall', 'F1', 'AUPR', 'Precision@L', 'Precision@L/2', 'Precision@L/5']) 148 | print(line, file=output) 149 | output.flush() 150 | 151 | # for all contacts 152 | y_flat = [] 153 | logits_flat = [] 154 | for i in range(len(y)): 155 | yi = y[i] 156 | mask = (yi < 0) 157 | y_flat.append(yi[~mask]) 158 | logits_flat.append(logits[i][~mask]) 159 | 160 | # calculate precision, recall, F1, and area under the precision recall curve for all contacts 161 | precision = np.zeros(len(x)) 162 | recall = np.zeros(len(x)) 163 | F1 = np.zeros(len(x)) 164 | AUPR = np.zeros(len(x)) 165 | prL = np.zeros(len(x)) 166 | prL2 = np.zeros(len(x)) 167 | prL5 = np.zeros(len(x)) 168 | for i in range(len(x)): 169 | pr,re,f1,aupr = calc_metrics(logits_flat[i], y_flat[i]) 170 | precision[i] = pr 171 | recall[i] = re 172 | F1[i] = f1 173 | AUPR[i] = aupr 174 | 175 | order = np.argsort(logits_flat[i])[::-1] 176 | n = lengths[i] 177 | topL = order[:n] 178 | prL[i] = y_flat[i][topL].mean() 179 | topL2 = order[:n//2] 180 | prL2[i] = y_flat[i][topL2].mean() 181 | topL5 = order[:n//5] 182 | prL5[i] = y_flat[i][topL5].mean() 183 | 184 | template = 'All\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 185 | line = template.format(precision.mean(), recall.mean(), F1.mean(), AUPR.mean(), prL.mean(), prL2.mean(), prL5.mean()) 186 | print(line, file=output) 187 | output.flush() 188 | 189 | # for Medium/Long range contacts 190 | y_flat = [] 191 | logits_flat = [] 192 | for i in range(len(y)): 193 | yi = y[i] 194 | mask = (yi < 0) 195 | 196 | medlong = np.tril_indices(len(yi), k=11) 197 | medlong_mask = np.zeros((len(yi),len(yi)), dtype=np.uint8) 198 | medlong_mask[medlong] = 1 199 | mask = mask | (medlong_mask == 1) 200 | 201 | y_flat.append(yi[~mask]) 202 | logits_flat.append(logits[i][~mask]) 203 | 204 | # calculate precision, recall, F1, and area under the precision recall curve for all contacts 205 | precision = np.zeros(len(x)) 206 | recall = np.zeros(len(x)) 207 | F1 = np.zeros(len(x)) 208 | AUPR = np.zeros(len(x)) 209 | prL = np.zeros(len(x)) 210 | prL2 = np.zeros(len(x)) 211 | prL5 = np.zeros(len(x)) 212 | for i in range(len(x)): 213 | pr,re,f1,aupr = calc_metrics(logits_flat[i], y_flat[i]) 214 | precision[i] = pr 215 | recall[i] = re 216 | F1[i] = f1 217 | AUPR[i] = aupr 218 | 219 | order = np.argsort(logits_flat[i])[::-1] 220 | n = lengths[i] 221 | topL = order[:n] 222 | prL[i] = y_flat[i][topL].mean() 223 | topL2 = order[:n//2] 224 | prL2[i] = y_flat[i][topL2].mean() 225 | topL5 = order[:n//5] 226 | prL5[i] = y_flat[i][topL5].mean() 227 | 228 | template = 'Medium/Long\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 229 | line = template.format(np.nanmean(precision), np.nanmean(recall), np.nanmean(F1), np.nanmean(AUPR), np.nanmean(prL) 230 | , np.nanmean(prL2), np.nanmean(prL5)) 231 | print(line, file=output) 232 | output.flush() 233 | 234 | 235 | 236 | if __name__ == '__main__': 237 | main() 238 | 239 | 240 | 241 | 242 | 243 | -------------------------------------------------------------------------------- /eval_similarity.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.data 11 | 12 | from scipy.stats import pearsonr,spearmanr 13 | 14 | from src.utils import pack_sequences, unpack_sequences 15 | from src.alphabets import Uniprot21 16 | from src.alignment import nw_score 17 | from src.metrics import average_precision 18 | 19 | 20 | def encode_sequence(x, alphabet): 21 | # convert to bytes and uppercase 22 | x = x.encode('utf-8').upper() 23 | # convert to alphabet index 24 | x = alphabet.encode(x) 25 | return x 26 | 27 | 28 | def load_pairs(path, alphabet): 29 | table = pd.read_csv(path, sep='\t') 30 | 31 | x0 = [encode_sequence(x, alphabet) for x in table['sequence_A']] 32 | x1 = [encode_sequence(x, alphabet) for x in table['sequence_B']] 33 | y = table['similarity'].values 34 | 35 | return x0, x1, y 36 | 37 | 38 | class NWAlign: 39 | def __init__(self, alphabet): 40 | from Bio.SubsMat import MatrixInfo as matlist 41 | L = len(alphabet) 42 | subst = np.zeros((L,L), dtype=np.int32) 43 | for i in range(L): 44 | for j in range(i,L): 45 | a = alphabet[i] 46 | b = alphabet[j] 47 | subst[i,j] = subst[j,i] = matlist.blosum62[(b,a)] 48 | self.subst = subst 49 | self.gap = -11 50 | self.extend = -1 51 | 52 | def __call__(self, x, y): 53 | b = len(x) 54 | scores = np.zeros(b) 55 | for i in range(b): 56 | scores[i] = nw_score(x[i], y[i], self.subst, self.gap, self.extend) 57 | return scores 58 | 59 | 60 | class TorchModel: 61 | def __init__(self, model, use_cuda, mode='ssa'): 62 | self.model = model 63 | self.use_cuda = use_cuda 64 | self.mode = mode 65 | 66 | def __call__(self, x, y): 67 | n = len(x) 68 | c = [torch.from_numpy(x_).long() for x_ in x] + [torch.from_numpy(y_).long() for y_ in y] 69 | 70 | c,order = pack_sequences(c) 71 | if self.use_cuda: 72 | c = c.cuda() 73 | 74 | with torch.no_grad(): 75 | z = self.model(c) # embed the sequences 76 | z = unpack_sequences(z, order) 77 | 78 | scores = np.zeros(n) 79 | if self.mode == 'align': 80 | for i in range(n): 81 | z_x = z[i] 82 | z_y = z[i+n] 83 | 84 | logits = self.model.score(z_x, z_y) 85 | p = torch.sigmoid(logits).cpu() 86 | p_ge = torch.ones(p.size(0)+1) 87 | p_ge[1:] = p 88 | p_lt = torch.ones(p.size(0)+1) 89 | p_lt[:-1] = 1 - p 90 | p = p_ge*p_lt 91 | p = p/p.sum() # make sure p is normalized 92 | levels = torch.arange(5).float() 93 | scores[i] = torch.sum(p*levels).item() 94 | 95 | elif self.mode == 'coarse': 96 | z_x = z[:n] 97 | z_y = z[n:] 98 | z_x = torch.stack([z.mean(0) for z in z_x], 0) 99 | z_y = torch.stack([z.mean(0) for z in z_y], 0) 100 | scores[:] = -torch.sum(torch.abs(z_x - z_y), 1).cpu().numpy() 101 | 102 | return scores 103 | 104 | def find_best_threshold(x, y, tr0=-np.inf): 105 | order = np.argsort(x) 106 | 107 | tp = np.zeros(len(x)+1) 108 | tp[0] = y.sum() 109 | tn = np.zeros(len(x)+1) 110 | tn[0] = 0 111 | 112 | 113 | for i in range(len(x)): 114 | j = order[i] 115 | tp[i+1] = tp[i] - y[j] 116 | tn[i+1] = tn[i] + 1 - y[j] 117 | 118 | acc = (tp + tn)/len(y) 119 | i = np.argmax(acc) - 1 120 | 121 | tr = x[order[i]] 122 | if i < 0: 123 | tr = tr0 124 | 125 | return tr 126 | 127 | 128 | def find_best_thresholds(x, y): 129 | thresholds = np.zeros(5) 130 | thresholds[0] = -np.inf 131 | for i in range(4): 132 | mask = (x > thresholds[i]) 133 | xi = x[mask] 134 | labels = (y[mask] > i) 135 | tr = find_best_threshold(xi, labels, tr0=thresholds[i]) 136 | thresholds[i+1] = tr 137 | return thresholds 138 | 139 | 140 | def calculate_metrics(scores, y, thresholds): 141 | ## calculate accuracy, r, rho 142 | pred_level = np.digitize(scores, thresholds[1:], right=True) 143 | accuracy = np.mean(pred_level == y) 144 | r,_ = pearsonr(scores, y) 145 | rho,_ = spearmanr(scores, y) 146 | ## calculate average-precision score for each structural level 147 | aupr = np.zeros(4, dtype=np.float32) 148 | for i in range(4): 149 | target = (y > i).astype(np.float32) 150 | aupr[i] = average_precision(target, scores.astype(np.float32)) 151 | return accuracy, r, rho, aupr 152 | 153 | 154 | def score_pairs(model, x0, x1, batch_size=100): 155 | scores = [] 156 | for i in range(0, len(x0), batch_size): 157 | x0_mb = x0[i:i+batch_size] 158 | x1_mb = x1[i:i+batch_size] 159 | scores.append(model(x0_mb, x1_mb)) 160 | scores = np.concatenate(scores, 0) 161 | return scores 162 | 163 | 164 | def main(): 165 | import argparse 166 | parser = argparse.ArgumentParser('Script for evaluating similarity model on SCOP test set.') 167 | 168 | parser.add_argument('model', help='path to saved model file or "nw-align" for Needleman-Wunsch alignment score baseline') 169 | 170 | parser.add_argument('--dev', action='store_true', help='use train/dev split') 171 | 172 | parser.add_argument('--batch-size', default=64, type=int, help='number of sequence pairs to process in each batch (default: 64)') 173 | 174 | parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use') 175 | 176 | parser.add_argument('--coarse', action='store_true', help='use coarse comparison rather than full SSA') 177 | 178 | args = parser.parse_args() 179 | 180 | scop_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.sampledpairs.txt' 181 | 182 | eval_paths = [ ( '2.06-test' 183 | , 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.sampledpairs.txt') 184 | , ( '2.07-new' 185 | , 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.07-new.allpairs.txt') 186 | ] 187 | if args.dev: 188 | scop_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.train.sampledpairs.txt' 189 | 190 | eval_paths = [ ( '2.06-dev' 191 | , 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.sampledpairs.txt') 192 | ] 193 | 194 | 195 | ## load the data 196 | alphabet = Uniprot21() 197 | x0_train, x1_train, y_train = load_pairs(scop_train_path, alphabet) 198 | 199 | ## load the model 200 | if args.model == 'nw-align': 201 | model = NWAlign(alphabet) 202 | elif args.model in ['hhalign', 'phmmer', 'TMalign']: 203 | model = args.model 204 | else: 205 | model = torch.load(args.model) 206 | model.eval() 207 | 208 | ## set the device 209 | d = args.device 210 | use_cuda = (d != -1) and torch.cuda.is_available() 211 | if d >= 0: 212 | torch.cuda.set_device(d) 213 | 214 | if use_cuda: 215 | model.cuda() 216 | 217 | mode = 'align' 218 | if args.coarse: 219 | mode = 'coarse' 220 | model = TorchModel(model, use_cuda, mode=mode) 221 | 222 | batch_size = args.batch_size 223 | 224 | ## for calculating the classification accuracy, first find the best partitions using the training set 225 | if type(model) is str: 226 | path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.sampledpairs.' \ 227 | + model + '.npy' 228 | scores = np.load(path) 229 | scores = scores.mean(1) 230 | else: 231 | scores = score_pairs(model, x0_train, x1_train, batch_size) 232 | thresholds = find_best_thresholds(scores, y_train) 233 | 234 | print('Dataset\tAccuracy\tPearson\'s r\tSpearman\'s rho\tClass\tFold\tSuperfamily\tFamily') 235 | 236 | accuracy, r, rho, aupr = calculate_metrics(scores, y_train, thresholds) 237 | 238 | template = '{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 239 | 240 | line = '2.06-train\t' + template.format(accuracy, r, rho, aupr[0], aupr[1], aupr[2], aupr[3]) 241 | #line = '\t'.join(['2.06-train', str(accuracy), str(r), str(rho), str(aupr[0]), str(aupr[1]), str(aupr[2]), str(aupr[3])]) 242 | print(line) 243 | 244 | for dset,path in eval_paths: 245 | x0_test, x1_test, y_test = load_pairs(path, alphabet) 246 | if type(model) is str: 247 | path = os.path.splitext(path)[0] 248 | path = path + '.' + model + '.npy' 249 | scores = np.load(path) 250 | scores = scores.mean(1) 251 | else: 252 | scores = score_pairs(model, x0_test, x1_test, batch_size) 253 | accuracy, r, rho, aupr = calculate_metrics(scores, y_test, thresholds) 254 | 255 | line = dset + '\t' + template.format(accuracy, r, rho, aupr[0], aupr[1], aupr[2], aupr[3]) 256 | #line = '\t'.join([dset, str(accuracy), str(r), str(rho), str(aupr[0]), str(aupr[1]), str(aupr[2]), str(aupr[3])]) 257 | print(line) 258 | 259 | 260 | if __name__ == '__main__': 261 | main() 262 | 263 | -------------------------------------------------------------------------------- /train_lm_pfam.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import sys 4 | import argparse 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | from torch.nn.utils.rnn import pack_padded_sequence 13 | 14 | import src.fasta as fasta 15 | from src.alphabets import Uniprot21 16 | import src.models.sequence 17 | 18 | parser = argparse.ArgumentParser('Train sequence model') 19 | 20 | parser.add_argument('-b', '--minibatch-size', type=int, default=32, help='minibatch size (default: 32)') 21 | parser.add_argument('-n', '--num-epochs', type=int, default=10, help='number of epochs (default: 10)') 22 | 23 | parser.add_argument('--hidden-dim', type=int, default=512, help='hidden dimension of RNN (default: 512)') 24 | parser.add_argument('--num-layers', type=int, default=2, help='number of RNN layers (default: 2)') 25 | parser.add_argument('--dropout', type=float, default=0, help='dropout (default: 0)') 26 | 27 | parser.add_argument('--untied', action='store_true', help='use biRNN with untied weights') 28 | 29 | parser.add_argument('--l2', type=float, default=0, help='l2 regularizer (default: 0)') 30 | 31 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 0.001)') 32 | parser.add_argument('--clip', type=float, default=1, help='gradient clipping max norm (default: 1)') 33 | 34 | parser.add_argument('-d', '--device', type=int, default=-2, help='device to use, -1: cpu, 0+: gpu (default: gpu if available, else cpu)') 35 | 36 | parser.add_argument('-o', '--output', help='where to write training curve (default: stdout)') 37 | parser.add_argument('--save-prefix', help='path prefix for saving models (default: no saving)') 38 | 39 | 40 | pfam_train = 'data/pfam/Pfam-A.train.fasta' 41 | pfam_test = 'data/pfam/Pfam-A.test.fasta' 42 | 43 | 44 | def preprocess_sequence(s, alphabet): 45 | x = alphabet.encode(s) 46 | # pad with start/stop token 47 | z = np.zeros(len(x)+2, dtype=x.dtype) 48 | z[1:-1] = x + 1 49 | return z 50 | 51 | def load_pfam(path, alph): 52 | # load path sequences and families 53 | with open(path, 'rb') as f: 54 | group = [] 55 | sequences = [] 56 | for name,sequence in fasta.parse_stream(f): 57 | x = preprocess_sequence(sequence.upper(), alph) 58 | sequences.append(x) 59 | family = name.split(b';')[-2] 60 | group.append(family) 61 | group = np.array(group) 62 | sequences = np.array(sequences) 63 | return group, sequences 64 | 65 | 66 | def main(): 67 | args = parser.parse_args() 68 | 69 | alph = Uniprot21() 70 | ntokens = len(alph) 71 | 72 | ## load the training sequences 73 | train_group, X_train = load_pfam(pfam_train, alph) 74 | print('# loaded', len(X_train), 'sequences from', pfam_train, file=sys.stderr) 75 | 76 | ## load the testing sequences 77 | test_group, X_test = load_pfam(pfam_test, alph) 78 | print('# loaded', len(X_test), 'sequences from', pfam_test, file=sys.stderr) 79 | 80 | ## initialize the model 81 | nin = ntokens + 1 82 | nout = ntokens 83 | embedding_dim = 21 84 | hidden_dim = args.hidden_dim 85 | num_layers = args.num_layers 86 | mask_idx = ntokens 87 | dropout = args.dropout 88 | 89 | tied = not args.untied 90 | 91 | model = src.models.sequence.BiLM(nin, nout, embedding_dim, hidden_dim, num_layers 92 | , mask_idx=mask_idx, dropout=dropout, tied=tied) 93 | print('# initialized model', file=sys.stderr) 94 | 95 | device = args.device 96 | use_cuda = torch.cuda.is_available() and (device == -2 or device >= 0) 97 | if device >= 0: 98 | torch.cuda.set_device(device) 99 | if use_cuda: 100 | model = model.cuda() 101 | 102 | ## form the data iterators and optimizer 103 | lr = args.lr 104 | l2 = args.l2 105 | solver = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2) 106 | 107 | def collate(xs): 108 | B = len(xs) 109 | N = max(len(x) for x in xs) 110 | lengths = np.array([len(x) for x in xs], dtype=int) 111 | 112 | order = np.argsort(lengths)[::-1] 113 | lengths = lengths[order] 114 | 115 | X = torch.LongTensor(B, N).zero_() + mask_idx 116 | for i in range(B): 117 | x = xs[order[i]] 118 | n = len(x) 119 | X[i,:n] = torch.from_numpy(x) 120 | return X, lengths 121 | 122 | mb = args.minibatch_size 123 | 124 | train_iterator = torch.utils.data.DataLoader(X_train, batch_size=mb, shuffle=True 125 | , collate_fn=collate) 126 | test_iterator = torch.utils.data.DataLoader(X_test, batch_size=mb 127 | , collate_fn=collate) 128 | 129 | ## fit the model! 130 | 131 | print('# training model', file=sys.stderr) 132 | 133 | output = sys.stdout 134 | if args.output is not None: 135 | output = open(args.output, 'w') 136 | 137 | num_epochs = args.num_epochs 138 | clip = args.clip 139 | 140 | save_prefix = args.save_prefix 141 | digits = int(np.floor(np.log10(num_epochs))) + 1 142 | 143 | print('epoch\tsplit\tlog_p\tperplexity\taccuracy', file=output) 144 | output.flush() 145 | 146 | for epoch in range(num_epochs): 147 | # train epoch 148 | model.train() 149 | it = 0 150 | n = 0 151 | accuracy = 0 152 | loss_accum = 0 153 | for X,lengths in train_iterator: 154 | if use_cuda: 155 | X = X.cuda() 156 | X = Variable(X) 157 | logp = model(X) 158 | 159 | mask = (X != mask_idx) 160 | 161 | index = X*mask.long() 162 | loss = -logp.gather(2, index.unsqueeze(2)).squeeze(2) 163 | loss = torch.mean(loss.masked_select(mask)) 164 | 165 | loss.backward() 166 | 167 | # clip the gradient 168 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 169 | 170 | solver.step() 171 | solver.zero_grad() 172 | 173 | _,y_hat = torch.max(logp, 2) 174 | correct = torch.sum((y_hat == X).masked_select(mask)) 175 | #correct = torch.sum((y_hat == X)[mask.nonzero()].float()) 176 | 177 | b = mask.long().sum().item() 178 | n += b 179 | delta = b*(loss.item() - loss_accum) 180 | loss_accum += delta/n 181 | delta = correct.item() - b*accuracy 182 | accuracy += delta/n 183 | 184 | b = X.size(0) 185 | it += b 186 | if (it - b)//100 < it//100: 187 | print('# [{}/{}] training {:.1%} loss={:.5f}, acc={:.5f}'.format(epoch+1 188 | , num_epochs 189 | , it/len(X_train) 190 | , loss_accum 191 | , accuracy 192 | ) 193 | , end='\r', file=sys.stderr) 194 | print(' '*80, end='\r', file=sys.stderr) 195 | 196 | perplex = np.exp(loss_accum) 197 | string = str(epoch+1).zfill(digits) + '\t' + 'train' + '\t' + str(loss_accum) \ 198 | + '\t' + str(perplex) + '\t' + str(accuracy) 199 | print(string, file=output) 200 | output.flush() 201 | 202 | # test epoch 203 | model.eval() 204 | it = 0 205 | n = 0 206 | accuracy = 0 207 | loss_accum = 0 208 | with torch.no_grad(): 209 | for X,lengths in test_iterator: 210 | if use_cuda: 211 | X = X.cuda() 212 | X = Variable(X) 213 | logp = model(X) 214 | 215 | mask = (X != mask_idx) 216 | 217 | index = X*mask.long() 218 | loss = -logp.gather(2, index.unsqueeze(2)).squeeze(2) 219 | loss = torch.mean(loss.masked_select(mask)) 220 | 221 | _,y_hat = torch.max(logp, 2) 222 | correct = torch.sum((y_hat == X).masked_select(mask)) 223 | 224 | b = mask.long().sum().item() 225 | n += b 226 | delta = b*(loss.item() - loss_accum) 227 | loss_accum += delta/n 228 | delta = correct.item() - b*accuracy 229 | accuracy += delta/n 230 | 231 | b = X.size(0) 232 | it += b 233 | if (it - b)//100 < it//100: 234 | print('# [{}/{}] test {:.1%} loss={:.5f}, acc={:.5f}'.format(epoch+1 235 | , num_epochs 236 | , it/len(X_test) 237 | , loss_accum 238 | , accuracy 239 | ) 240 | , end='\r', file=sys.stderr) 241 | print(' '*80, end='\r', file=sys.stderr) 242 | 243 | perplex = np.exp(loss_accum) 244 | string = str(epoch+1).zfill(digits) + '\t' + 'test' + '\t' + str(loss_accum) \ 245 | + '\t' + str(perplex) + '\t' + str(accuracy) 246 | print(string, file=output) 247 | output.flush() 248 | 249 | ## save the model 250 | if save_prefix is not None: 251 | save_path = save_prefix + '_epoch' + str(epoch+1).zfill(digits) + '.sav' 252 | model = model.cpu() 253 | torch.save(model, save_path) 254 | if use_cuda: 255 | model = model.cuda() 256 | 257 | 258 | 259 | if __name__ == '__main__': 260 | main() 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | -------------------------------------------------------------------------------- /eval_transmembrane.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import numpy as np 4 | import sys 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn.utils.rnn import PackedSequence 10 | import torch.utils.data 11 | 12 | from src.alphabets import Uniprot21 13 | from src.parse_utils import parse_3line 14 | import src.transmembrane as tm 15 | 16 | def load_3line(path, alphabet): 17 | with open(path, 'rb') as f: 18 | names, x, y = parse_3line(f) 19 | x = [alphabet.encode(x) for x in x] 20 | y = [tm.encode_labels(y) for y in y] 21 | return x, y 22 | 23 | def load_data(): 24 | alphabet = Uniprot21() 25 | 26 | path = 'data/transmembrane/TOPCONS2_datasets/TM.3line' 27 | x_tm, y_tm = load_3line(path, alphabet) 28 | 29 | path = 'data/transmembrane/TOPCONS2_datasets/SP+TM.3line' 30 | x_tm_sp, y_tm_sp = load_3line(path, alphabet) 31 | 32 | path = 'data/transmembrane/TOPCONS2_datasets/Globular.3line' 33 | x_glob, y_glob = load_3line(path, alphabet) 34 | 35 | path = 'data/transmembrane/TOPCONS2_datasets/Globular+SP.3line' 36 | x_glob_sp, y_glob_sp = load_3line(path, alphabet) 37 | 38 | datasets = {'TM': (x_tm, y_tm), 'SP+TM': (x_tm_sp, y_tm_sp), 39 | 'Globular': (x_glob, y_glob), 'Globular+SP': (x_glob_sp, y_glob_sp)} 40 | 41 | return datasets 42 | 43 | def split_dataset(xs, ys, random=np.random, k=5): 44 | x_splits = [[] for _ in range(k)] 45 | y_splits = [[] for _ in range(k)] 46 | order = random.permutation(len(xs)) 47 | for i in range(len(order)): 48 | j = order[i] 49 | x_s = x_splits[i%k] 50 | y_s = y_splits[i%k] 51 | x_s.append(xs[j]) 52 | y_s.append(ys[j]) 53 | return x_splits, y_splits 54 | 55 | def unstack_lstm(lstm): 56 | in_size = lstm.input_size 57 | hidden_dim = lstm.hidden_size 58 | layers = [] 59 | for i in range(lstm.num_layers): 60 | layer = nn.LSTM(in_size, hidden_dim, batch_first=True, bidirectional=True) 61 | attributes = ['weight_ih_l', 'weight_hh_l', 'bias_ih_l', 'bias_hh_l'] 62 | for attr in attributes: 63 | dest = attr + '0' 64 | src = attr + str(i) 65 | getattr(layer, dest).data[:] = getattr(lstm, src) 66 | #setattr(layer, dest, getattr(lstm, src)) 67 | 68 | dest = attr + '0_reverse' 69 | src = attr + str(i) + '_reverse' 70 | getattr(layer, dest).data[:] = getattr(lstm, src) 71 | #setattr(layer, dest, getattr(lstm, src)) 72 | layers.append(layer) 73 | in_size = 2*hidden_dim 74 | return layers 75 | 76 | def featurize(x, lm_embed, lstm_stack, proj, include_lm=True, lm_only=False): 77 | zs = [] 78 | 79 | x_onehot = x.new(x.size(0),x.size(1), 21).float().zero_() 80 | x_onehot.scatter_(2,x.unsqueeze(2),1) 81 | zs.append(x_onehot) 82 | 83 | h = lm_embed(x) 84 | if include_lm: 85 | zs.append(h) 86 | if not lm_only: 87 | for lstm in lstm_stack: 88 | h,_ = lstm(h) 89 | zs.append(h) 90 | h = proj(h.squeeze(0)).unsqueeze(0) 91 | zs.append(h) 92 | z = torch.cat(zs, 2) 93 | return z 94 | 95 | def featurize_dict(datasets, lm_embed, lstm_stack, proj, use_cuda=False, include_lm=True, lm_only=False): 96 | z = {} 97 | for k,v in datasets.items(): 98 | x_k = v[0] 99 | z[k] = [] 100 | with torch.no_grad(): 101 | for x in x_k: 102 | x = torch.from_numpy(x).long().unsqueeze(0) 103 | if use_cuda: 104 | x = x.cuda() 105 | z_x = featurize(x, lm_embed, lstm_stack, proj, include_lm=include_lm, lm_only=lm_only) 106 | z_x = z_x.squeeze(0).cpu() 107 | z[k].append(z_x) 108 | return z 109 | 110 | def featurize_one_hot_dict(datasets, n): 111 | z = {} 112 | for k,v in datasets.items(): 113 | x_k = v[0] 114 | z[k] = [] 115 | with torch.no_grad(): 116 | for x in x_k: 117 | x = torch.from_numpy(x).long() 118 | one_hot = torch.FloatTensor(x.size(0), n).to(x.device) 119 | one_hot.zero_() 120 | one_hot.scatter_(1, x.unsqueeze(1), 1) 121 | z[k].append(one_hot) 122 | return z 123 | 124 | def make_train_test(splits, j, k): 125 | x_train = [] 126 | y_train = [] 127 | for v in splits.values(): 128 | for i in range(k): 129 | if i != j: 130 | x_train += v[0][i] 131 | y_train += v[1][i] 132 | 133 | x_test = {k:v[0][j] for k,v in splits.items()} 134 | y_test = {k:v[1][j] for k,v in splits.items()} 135 | 136 | return x_train, y_train, x_test, y_test 137 | 138 | class ListDataset: 139 | def __init__(self, x, y): 140 | self.x = x 141 | self.y = y 142 | 143 | def __len__(self): 144 | return len(self.x) 145 | 146 | def __getitem__(self, i): 147 | return self.x[i], self.y[i] 148 | 149 | class LSTM(nn.Module): 150 | def __init__(self, n_in, n_hidden, n_out): 151 | super(LSTM, self).__init__() 152 | self.rnn = nn.LSTM(n_in, n_hidden, bidirectional=True, batch_first=True) 153 | self.linear = nn.Linear(2*n_hidden, n_out) 154 | 155 | def forward(self, x): 156 | if type(x) is not PackedSequence: 157 | ndim = len(x.size()) 158 | if ndim == 2: 159 | x = x.unsqueeze(0) 160 | h,_ = self.rnn(x) 161 | if type(h) is PackedSequence: 162 | z = self.linear(h.data) 163 | return PackedSequence(z, h.batch_sizes) 164 | else: 165 | z = self.linear(h.view(h.size(0)*h.size(1), -1)) 166 | z = z.view(h.size(0), h.size(1), -1) 167 | if ndim == 2: 168 | z = z.squeeze(0) 169 | return z 170 | 171 | def train(x_train, y_train, num_epochs=10, hidden_dim=100, use_cuda=False): 172 | 173 | d = x_train[0].size(1) 174 | 175 | model = LSTM(d, hidden_dim, 4) 176 | if use_cuda: 177 | model.cuda() 178 | 179 | batch_size = 1 180 | dataset = ListDataset(x_train, y_train) 181 | iterator = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size) 182 | 183 | optim = torch.optim.Adam(model.parameters(), lr=3e-4) 184 | 185 | for epoch in range(num_epochs): 186 | for x,y in iterator: 187 | if use_cuda: 188 | x = x.cuda() 189 | y = y.cuda() 190 | log_p = model(x).squeeze(0) 191 | x = x.squeeze(0) 192 | y = y.squeeze(0) 193 | loss = F.cross_entropy(log_p, y) 194 | loss.backward() 195 | optim.step() 196 | optim.zero_grad() 197 | 198 | return model 199 | 200 | def evaluate_model(model, grammar, z_test, y_test, use_cuda=False): 201 | results = {} 202 | for key in z_test: 203 | y_hats = grammar.predict_viterbi(z_test[key], model, use_cuda) 204 | correct = np.zeros(len(y_hats)) 205 | for i,(pred,target) in enumerate(zip(y_hats, y_test[key])): 206 | correct[i] = tm.is_prediction_correct(pred, target) 207 | results[key] = correct.mean() 208 | overall = sum(results.values())/len(results) 209 | return overall, results 210 | 211 | def evaluate_split(splits, j, k, num_epochs=10, hidden_dim=100, use_cuda=False, grammar=tm.Grammar()): 212 | x_train, y_train, x_test, y_test = make_train_test(splits, j, k) 213 | model = train(x_train, y_train, num_epochs=num_epochs, hidden_dim=hidden_dim, use_cuda=use_cuda) 214 | model.eval() 215 | overall,results = evaluate_model(model, grammar, x_test, y_test, use_cuda=use_cuda) 216 | return overall, results 217 | 218 | 219 | def main(): 220 | import argparse 221 | parser = argparse.ArgumentParser() 222 | parser.add_argument('model', help='path to saved embedding model') 223 | parser.add_argument('--hidden-dim', type=int, default=150, help='dimension of LSTM (default: 150)') 224 | parser.add_argument('--num-epochs', type=int, default=10, help='number of training epochs (default: 10)') 225 | parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use') 226 | args = parser.parse_args() 227 | 228 | datasets = load_data() 229 | num_epochs = args.num_epochs 230 | hidden_dim = args.hidden_dim 231 | 232 | d = args.device 233 | use_cuda = (d != -1) and torch.cuda.is_available() 234 | if d >= 0: 235 | torch.cuda.set_device(d) 236 | 237 | 238 | ## load the embedding model 239 | if args.model == '1-hot': 240 | print('# featurizing data', file=sys.stderr) 241 | z = featurize_one_hot_dict(datasets, 21) 242 | datasets = {k: (z[k],v[1]) for k,v in datasets.items()} 243 | 244 | else: 245 | encoder = torch.load(args.model) 246 | encoder.eval() 247 | encoder = encoder.embedding 248 | 249 | lm_embed = encoder.embed 250 | lstm_stack = unstack_lstm(encoder.rnn) 251 | proj = encoder.proj 252 | 253 | if use_cuda: 254 | lm_embed.cuda() 255 | for lstm in lstm_stack: 256 | lstm.cuda() 257 | proj.cuda() 258 | 259 | ## featurize the sequences 260 | print('# featurizing data', file=sys.stderr) 261 | z = featurize_dict(datasets, lm_embed, lstm_stack, proj, use_cuda=use_cuda) 262 | 263 | del lm_embed 264 | del lstm_stack 265 | del proj 266 | del encoder 267 | 268 | datasets = {k: (z[k],v[1]) for k,v in datasets.items()} 269 | 270 | ## split into folds 271 | random = np.random.RandomState(10) 272 | K = 10 273 | datasets_split = {k: split_dataset(v[0], v[1], random=random, k=K) for k,v in datasets.items()} 274 | 275 | ## train/test on each fold 276 | print('# training and evaluating with', K, 'folds', file=sys.stderr) 277 | print('# using', hidden_dim, 'LSTM units', file=sys.stderr) 278 | tags = ['TM', 'SP+TM', 'Globular', 'Globular+SP'] 279 | print('\t'.join(['Fold'] + tags + ['Overall'])) 280 | split_results = {} 281 | split_overall = [] 282 | for i in range(K): 283 | overall, results = evaluate_split(datasets_split, i, K, 284 | num_epochs=num_epochs, hidden_dim=hidden_dim, 285 | use_cuda=use_cuda) 286 | for key in tags: 287 | this = split_results.get(key, []) 288 | this.append(results[key]) 289 | split_results[key] = this 290 | split_overall.append(overall) 291 | cols = [str(i)] + ['{:.5f}'.format(results[key]) for key in tags] + ['{:.5f}'.format(overall)] 292 | line = '\t'.join(cols) 293 | print(line) 294 | 295 | results = {key:np.mean(values) for key,values in split_results.items()} 296 | overall = np.mean(split_overall) 297 | 298 | cols = ['All'] + ['{:.5f}'.format(results[key]) for key in tags] + ['{:.5f}'.format(overall)] 299 | line = '\t'.join(cols) 300 | print(line) 301 | 302 | 303 | if __name__ == '__main__': 304 | main() 305 | 306 | 307 | 308 | -------------------------------------------------------------------------------- /eval_secstr.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.utils.rnn import PackedSequence 9 | import torch.utils.data 10 | 11 | from src.alphabets import Uniprot21, SecStr8 12 | from src.utils import pack_sequences, unpack_sequences 13 | import src.pdb as pdb 14 | 15 | 16 | secstr_train_path = 'data/secstr/ss_cullpdb_pc40_res3.0_R1.0_d180412_filtered.train.fa' 17 | secstr_test_path = 'data/secstr/ss_cullpdb_pc40_res3.0_R1.0_d180412_filtered.test.fa' 18 | 19 | 20 | def encode_sequence(x, alphabet): 21 | # convert to bytes and uppercase 22 | x = x.encode('utf-8').upper() 23 | # convert to alphabet index 24 | x = alphabet.encode(x) 25 | return x 26 | 27 | 28 | def load_secstr(path, alphabet, secstr): 29 | with open(path, 'rb') as f: 30 | names,aa_seqs,ss_seqs = pdb.parse_secstr(f) 31 | aa_seqs = [alphabet.encode(x.upper()) for x in aa_seqs] 32 | ss_seqs = [secstr.encode(x.upper()) for x in ss_seqs] 33 | return names,aa_seqs,ss_seqs 34 | 35 | 36 | def unstack_lstm(lstm): 37 | in_size = lstm.input_size 38 | hidden_dim = lstm.hidden_size 39 | layers = [] 40 | for i in range(lstm.num_layers): 41 | layer = nn.LSTM(in_size, hidden_dim, batch_first=True, bidirectional=True) 42 | attributes = ['weight_ih_l', 'weight_hh_l', 'bias_ih_l', 'bias_hh_l'] 43 | for attr in attributes: 44 | dest = attr + '0' 45 | src = attr + str(i) 46 | getattr(layer, dest).data[:] = getattr(lstm, src) 47 | dest = attr + '0_reverse' 48 | src = attr + str(i) + '_reverse' 49 | getattr(layer, dest).data[:] = getattr(lstm, src) 50 | layers.append(layer) 51 | in_size = 2*hidden_dim 52 | return layers 53 | 54 | def featurize(x, lm_embed, lstm_stack, proj, include_lm=True, lm_only=False): 55 | zs = [] 56 | 57 | packed = type(x) is PackedSequence 58 | if packed: 59 | batch_sizes = x.batch_sizes 60 | x = x.data 61 | 62 | x_onehot = x.new(x.size(0), 21).float().zero_() 63 | x_onehot.scatter_(1,x.unsqueeze(1),1) 64 | 65 | if packed: 66 | x_onehot = PackedSequence(x_onehot, batch_sizes) 67 | x = PackedSequence(x, batch_sizes) 68 | 69 | zs.append(x_onehot) 70 | 71 | h = lm_embed(x) 72 | if include_lm: 73 | zs.append(h) 74 | if not lm_only: 75 | for lstm in lstm_stack: 76 | h,_ = lstm(h) 77 | zs.append(h) 78 | if packed: 79 | h = h.data 80 | h = proj(h) 81 | if packed: 82 | h = PackedSequence(h, batch_sizes) 83 | zs.append(h) 84 | if packed: 85 | zs = [z.data for z in zs] 86 | z = torch.cat(zs, 1) 87 | if packed: 88 | z = PackedSequence(z, batch_sizes) 89 | return z 90 | 91 | class TorchModel: 92 | def __init__(self, model, use_cuda, full_features=False): 93 | self.model = model 94 | self.use_cuda = use_cuda 95 | self.full_features = full_features 96 | if full_features: 97 | self.lm_embed = model.embedding.embed 98 | self.lstm_stack = unstack_lstm(model.embedding.rnn) 99 | self.proj = model.embedding.proj 100 | if use_cuda: 101 | self.lm_embed.cuda() 102 | for lstm in self.lstm_stack: 103 | lstm.cuda() 104 | self.proj.cuda() 105 | 106 | 107 | def __call__(self, x): 108 | c = [torch.from_numpy(x_).long() for x_ in x] 109 | 110 | c,order = pack_sequences(c) 111 | if self.use_cuda: 112 | c = c.cuda() 113 | 114 | if self.full_features: 115 | z = featurize(c, self.lm_embed, self.lstm_stack, self.proj) 116 | else: 117 | z = self.model(c) # embed the sequences 118 | z = unpack_sequences(z, order) 119 | 120 | return z 121 | 122 | def kmer_features(xs, n, k): 123 | if k == 1: 124 | return xs, n 125 | pad = np.array([n]*(k//2)) 126 | f = (n+1)**np.arange(k) 127 | kmers = [] 128 | for x in xs: 129 | x = np.concatenate([pad, x, pad], axis=0) 130 | z = np.convolve(x, f, mode='valid') 131 | kmers.append(z) 132 | return kmers, (n+1)**k 133 | 134 | 135 | class Shuffle: 136 | def __init__(self, x, y, minibatch_size): 137 | self.x = x 138 | self.y = y 139 | self.minibatch_size = minibatch_size 140 | 141 | def __iter__(self): 142 | n = len(self.x) 143 | order = np.random.permutation(n) 144 | order = torch.from_numpy(order).long().to(self.x.device) 145 | x = self.x[order] 146 | y = self.y[order] 147 | b = self.minibatch_size 148 | for i in range(0, n, b): 149 | yield x[i:i+b], y[i:i+b] 150 | 151 | 152 | def fit_kmer_potentials(x, y, n, m): 153 | _,counts = np.unique(y, return_counts=True) 154 | weights = torch.zeros(n, m) 155 | weights += torch.from_numpy(counts/counts.sum()).float() 156 | for i in range(len(x)): 157 | weights[x[i],y[i]] += 1 158 | 159 | model = nn.Embedding(n, m, sparse=True) 160 | model.weight.data[:] = torch.log(weights) - torch.log(weights.sum(1, keepdim=True)) 161 | 162 | return model 163 | 164 | 165 | def fit_nn_potentials(model, x, y, lr=0.001, num_epochs=10, minibatch_size=256 166 | , use_cuda=False): 167 | solver = torch.optim.Adam(model.parameters(), lr=lr) 168 | 169 | iterator = Shuffle(x, y, minibatch_size) 170 | 171 | model.train() 172 | for epoch in range(num_epochs): 173 | n = 0 174 | loss_accum = 0 175 | acc = 0 176 | for x,y in iterator: 177 | if use_cuda: 178 | x = x.cuda() 179 | y = y.cuda() 180 | potentials = model(x).view(x.size(0), -1) 181 | loss = F.cross_entropy(potentials, y) 182 | 183 | loss.backward() 184 | solver.step() 185 | solver.zero_grad() 186 | 187 | _,y_hat = potentials.max(1) 188 | correct = torch.sum((y_hat == y).float()) 189 | 190 | b = x.size(0) 191 | n += b 192 | delta = b*(loss.item() - loss_accum) 193 | loss_accum += delta/n 194 | delta = correct.item() - b*acc 195 | acc += delta/n 196 | 197 | print('train', epoch+1, loss_accum, np.exp(loss_accum), acc) 198 | 199 | 200 | def main(): 201 | import argparse 202 | parser = argparse.ArgumentParser('Script for evaluating similarity model on SCOP test set.') 203 | 204 | parser.add_argument('features', help='path to saved embedding model file or "1-", "3-", or "5-mer" for k-mer features') 205 | 206 | parser.add_argument('--num-epochs', type=int, default=10, help='number of epochs to train for (default: 10)') 207 | parser.add_argument('--all-hidden', action='store_true', help='use all hidden layers as features') 208 | 209 | parser.add_argument('-v', '--print-examples', default=0, type=int, help='number of examples to print (default: 0)') 210 | 211 | parser.add_argument('-o', '--output', help='output file path (default: stdout)') 212 | parser.add_argument('--save-prefix', help='path prefix for saving models') 213 | parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use') 214 | 215 | args = parser.parse_args() 216 | num_epochs = args.num_epochs 217 | 218 | ## load the data 219 | alphabet = Uniprot21() 220 | secstr = SecStr8 221 | 222 | names_train, x_train, y_train = load_secstr(secstr_train_path, alphabet, secstr) 223 | names_test, x_test, y_test = load_secstr(secstr_test_path, alphabet, secstr) 224 | 225 | sequences_test = [''.join(alphabet[c] for c in x_test[i]) for i in range(len(x_test))] 226 | 227 | y_train = np.concatenate(y_train, 0) 228 | 229 | ## set the device 230 | d = args.device 231 | use_cuda = (d != -1) and torch.cuda.is_available() 232 | if d >= 0: 233 | torch.cuda.set_device(d) 234 | 235 | 236 | if args.features == '1-mer': 237 | n = len(alphabet) 238 | x_test = [x.astype(int) for x in x_test] 239 | elif args.features == '3-mer': 240 | x_train,n = kmer_features(x_train, len(alphabet), 3) 241 | x_test,_ = kmer_features(x_test, len(alphabet), 3) 242 | elif args.features == '5-mer': 243 | x_train,n = kmer_features(x_train, len(alphabet), 5) 244 | x_test,_ = kmer_features(x_test, len(alphabet), 5) 245 | else: 246 | features = torch.load(args.features) 247 | features.eval() 248 | 249 | if use_cuda: 250 | features.cuda() 251 | 252 | features = TorchModel(features, use_cuda, full_features=args.all_hidden) 253 | batch_size = 32 # batch size for featurizing sequences 254 | 255 | with torch.no_grad(): 256 | z_train = [] 257 | for i in range(0,len(x_train),batch_size): 258 | for z in features(x_train[i:i+batch_size]): 259 | z_train.append(z.cpu().numpy()) 260 | x_train = z_train 261 | 262 | z_test = [] 263 | for i in range(0,len(x_test),batch_size): 264 | for z in features(x_test[i:i+batch_size]): 265 | z_test.append(z.cpu().numpy()) 266 | x_test = z_test 267 | 268 | n = x_train[0].shape[1] 269 | del features 270 | del z_train 271 | del z_test 272 | 273 | print('split', 'epoch', 'loss', 'perplexity', 'accuracy') 274 | 275 | if args.features.endswith('-mer'): 276 | x_train = np.concatenate(x_train, 0) 277 | model = fit_kmer_potentials(x_train, y_train, n, len(secstr)) 278 | else: 279 | x_train = torch.cat([torch.from_numpy(x) for x in x_train], 0) 280 | if use_cuda and not args.all_hidden: 281 | x_train = x_train.cuda() 282 | 283 | num_hidden = 1024 284 | model = nn.Sequential( nn.Linear(n, num_hidden) 285 | , nn.ReLU() 286 | , nn.Linear(num_hidden, num_hidden) 287 | , nn.ReLU() 288 | , nn.Linear(num_hidden, len(secstr)) 289 | ) 290 | 291 | y_train = torch.from_numpy(y_train).long() 292 | if use_cuda: 293 | y_train = y_train.cuda() 294 | model.cuda() 295 | 296 | fit_nn_potentials(model, x_train, y_train, num_epochs=num_epochs, use_cuda=use_cuda) 297 | 298 | if use_cuda: 299 | model.cuda() 300 | model.eval() 301 | 302 | num_examples = args.print_examples 303 | if num_examples > 0: 304 | names_examples = names_test[:num_examples] 305 | x_examples = x_test[:num_examples] 306 | y_examples = y_test[:num_examples] 307 | 308 | A = np.zeros((8,3), dtype=np.float32) 309 | I = np.zeros(8, dtype=int) 310 | # helix 311 | A[0,0] = 1.0 312 | A[3,0] = 1.0 313 | A[4,0] = 1.0 314 | I[0] = 0 315 | I[3] = 0 316 | I[4] = 0 317 | # sheet 318 | A[1,1] = 1.0 319 | A[2,1] = 1.0 320 | I[1] = 1 321 | I[2] = 1 322 | # coil 323 | A[5,2] = 1.0 324 | A[6,2] = 1.0 325 | A[7,2] = 1.0 326 | I[5] = 2 327 | I[6] = 2 328 | I[7] = 2 329 | 330 | A = torch.from_numpy(A) 331 | I = torch.from_numpy(I) 332 | if use_cuda: 333 | A = A.cuda() 334 | I = I.cuda() 335 | 336 | n = 0 337 | acc_8 = 0 338 | acc_3 = 0 339 | loss_8 = 0 340 | loss_3 = 0 341 | 342 | x_test = torch.cat([torch.from_numpy(x) for x in x_test], 0) 343 | y_test = torch.cat([torch.from_numpy(y).long() for y in y_test], 0) 344 | 345 | if use_cuda and not args.all_hidden: 346 | x_test = x_test.cuda() 347 | y_test = y_test.cuda() 348 | 349 | mb = 256 350 | with torch.no_grad(): 351 | for i in range(0, len(x_test), mb): 352 | x = x_test[i:i+mb] 353 | y = y_test[i:i+mb] 354 | 355 | if use_cuda: 356 | x = x.cuda() 357 | y = y.cuda() 358 | 359 | potentials = model(x).view(x.size(0), -1) 360 | 361 | ## 8-class SS 362 | l = F.cross_entropy(potentials, y).item() 363 | _,y_hat = potentials.max(1) 364 | correct = torch.sum((y == y_hat).float()).item() 365 | 366 | n += x.size(0) 367 | delta = x.size(0)*(l - loss_8) 368 | loss_8 += delta/n 369 | delta = correct - x.size(0)*acc_8 370 | acc_8 += delta/n 371 | 372 | ## 3-class SS 373 | y = I[y] 374 | p = F.softmax(potentials, 1) 375 | p = torch.mm(p, A) # ss3 probabilities 376 | log_p = torch.log(p) 377 | l = F.nll_loss(log_p, y).item() 378 | _,y_hat = log_p.max(1) 379 | correct = torch.sum((y == y_hat).float()).item() 380 | 381 | delta = x.size(0)*(l - loss_3) 382 | loss_3 += delta/n 383 | delta = correct - x.size(0)*acc_3 384 | acc_3 += delta/n 385 | 386 | print('-', '-', '8-class', '-', '3-class', '-') 387 | print('split', 'perplexity', 'accuracy', 'perplexity', 'accuracy') 388 | print('test', np.exp(loss_8), acc_8, np.exp(loss_3), acc_3) 389 | 390 | 391 | if num_examples > 0: 392 | for i in range(num_examples): 393 | name = names_examples[i].decode('utf-8') 394 | x = x_examples[i] 395 | y = y_examples[i] 396 | 397 | seq = sequences_test[i] 398 | 399 | print('>' + name + ' sequence') 400 | print(seq) 401 | print('') 402 | 403 | ss = ''.join(secstr[c] for c in y) 404 | ss = ss.replace(' ', 'C') 405 | print('>' + name + ' secstr') 406 | print(ss) 407 | print('') 408 | 409 | x = torch.from_numpy(x) 410 | if use_cuda: 411 | x = x.cuda() 412 | potentials = model(x) 413 | _,y_hat = torch.max(potentials, 1) 414 | y_hat = y_hat.cpu().numpy() 415 | 416 | ss_hat = ''.join(secstr[c] for c in y_hat) 417 | ss_hat = ss_hat.replace(' ', 'C') 418 | print('>' + name + ' predicted') 419 | print(ss_hat) 420 | print('') 421 | 422 | 423 | 424 | if __name__ == '__main__': 425 | main() 426 | -------------------------------------------------------------------------------- /train_similarity.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import sys 6 | 7 | from scipy.stats import pearsonr, spearmanr 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence 14 | import torch.utils.data 15 | 16 | from src.alphabets import Uniprot21 17 | import src.scop as scop 18 | from src.utils import pack_sequences, unpack_sequences 19 | from src.utils import PairedDataset, AllPairsDataset, collate_paired_sequences 20 | from src.utils import MultinomialResample 21 | import src.models.embedding 22 | import src.models.comparison 23 | 24 | 25 | def main(): 26 | import argparse 27 | parser = argparse.ArgumentParser('Script for training embedding model on SCOP.') 28 | 29 | parser.add_argument('--dev', action='store_true', help='use train/dev split') 30 | 31 | parser.add_argument('-m', '--model', choices=['ssa', 'ua', 'me'], default='ssa', help='alignment scoring method for comparing sequences in embedding space [ssa: soft symmetric alignment, ua: uniform alignment, me: mean embedding] (default: ssa)') 32 | parser.add_argument('--allow-insert', action='store_true', help='model insertions (default: false)') 33 | 34 | parser.add_argument('--norm', choices=['l1', 'l2'], default='l1', help='comparison norm (default: l1)') 35 | 36 | parser.add_argument('--rnn-type', choices=['lstm', 'gru'], default='lstm', help='type of RNN block to use (default: lstm)') 37 | parser.add_argument('--embedding-dim', type=int, default=100, help='embedding dimension (default: 100)') 38 | parser.add_argument('--input-dim', type=int, default=512, help='dimension of input to RNN (default: 512)') 39 | parser.add_argument('--rnn-dim', type=int, default=512, help='hidden units of RNNs (default: 512)') 40 | parser.add_argument('--num-layers', type=int, default=3, help='number of RNN layers (default: 3)') 41 | parser.add_argument('--dropout', type=float, default=0, help='dropout probability (default: 0)') 42 | 43 | parser.add_argument('--epoch-size', type=int, default=100000, help='number of examples per epoch (default: 100,000)') 44 | parser.add_argument('--epoch-scale', type=int, default=5, help='scaling on epoch size (default: 5)') 45 | parser.add_argument('--num-epochs', type=int, default=100, help='number of epochs (default: 100)') 46 | 47 | parser.add_argument('--batch-size', type=int, default=64, help='minibatch size (default: 64)') 48 | 49 | parser.add_argument('--weight-decay', type=float, default=0, help='L2 regularization (default: 0)') 50 | parser.add_argument('--lr', type=float, default=0.001) 51 | 52 | parser.add_argument('--tau', type=float, default=0.5, help='sampling proportion exponent (default: 0.5)') 53 | parser.add_argument('--augment', type=float, default=0, help='probability of resampling amino acid for data augmentation (default: 0)') 54 | parser.add_argument('--lm', help='pretrained LM to use as initial embedding') 55 | 56 | parser.add_argument('-o', '--output', help='output file path (default: stdout)') 57 | parser.add_argument('--save-prefix', help='path prefix for saving models') 58 | parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use') 59 | 60 | args = parser.parse_args() 61 | 62 | 63 | prefix = args.output 64 | 65 | 66 | ## set the device 67 | d = args.device 68 | use_cuda = (d != -1) and torch.cuda.is_available() 69 | if d >= 0: 70 | torch.cuda.set_device(d) 71 | 72 | ## make the datasets 73 | astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.fa' 74 | astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.sampledpairs.txt' 75 | if args.dev: 76 | astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.train.fa' 77 | astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.sampledpairs.txt' 78 | 79 | alphabet = Uniprot21() 80 | 81 | print('# loading training sequences:', astral_train_path, file=sys.stderr) 82 | with open(astral_train_path, 'rb') as f: 83 | names_train, structs_train, sequences_train = scop.parse_astral(f, encoder=alphabet) 84 | x_train = [torch.from_numpy(x).long() for x in sequences_train] 85 | if use_cuda: 86 | x_train = [x.cuda() for x in x_train] 87 | y_train = torch.from_numpy(structs_train) 88 | 89 | print('# loaded', len(x_train), 'training sequences', file=sys.stderr) 90 | 91 | 92 | print('# loading test sequence pairs:', astral_testpairs_path, file=sys.stderr) 93 | test_pairs_table = pd.read_csv(astral_testpairs_path, sep='\t') 94 | x0_test = [x.encode('utf-8').upper() for x in test_pairs_table['sequence_A']] 95 | x0_test = [torch.from_numpy(alphabet.encode(x)).long() for x in x0_test] 96 | x1_test = [x.encode('utf-8').upper() for x in test_pairs_table['sequence_B']] 97 | x1_test = [torch.from_numpy(alphabet.encode(x)).long() for x in x1_test] 98 | if use_cuda: 99 | x0_test = [x.cuda() for x in x0_test] 100 | x1_test = [x.cuda() for x in x1_test] 101 | y_test = test_pairs_table['similarity'].values 102 | y_test = torch.from_numpy(y_test).long() 103 | 104 | dataset_test = PairedDataset(x0_test, x1_test, y_test) 105 | print('# loaded', len(x0_test), 'test pairs', file=sys.stderr) 106 | 107 | ## make the dataset iterators 108 | scale = args.epoch_scale 109 | 110 | epoch_size = args.epoch_size 111 | batch_size = args.batch_size 112 | 113 | # precompute the similarity pairs 114 | y_train_levels = torch.cumprod((y_train.unsqueeze(1) == y_train.unsqueeze(0)).long(), 2) 115 | 116 | # data augmentation by resampling amino acids 117 | augment = None 118 | p = 0 119 | if args.augment > 0: 120 | p = args.augment 121 | trans = torch.ones(len(alphabet),len(alphabet)) 122 | trans = trans/trans.sum(1, keepdim=True) 123 | if use_cuda: 124 | trans = trans.cuda() 125 | augment = MultinomialResample(trans, p) 126 | print('# resampling amino acids with p:', p, file=sys.stderr) 127 | dataset_train = AllPairsDataset(x_train, y_train_levels, augment=augment) 128 | 129 | similarity = y_train_levels.numpy().sum(2) 130 | levels,counts = np.unique(similarity, return_counts=True) 131 | order = np.argsort(levels) 132 | levels = levels[order] 133 | counts = counts[order] 134 | 135 | print('#', levels, file=sys.stderr) 136 | print('#', counts/np.sum(counts), file=sys.stderr) 137 | 138 | weight = counts**0.5 139 | print('#', weight/np.sum(weight), file=sys.stderr) 140 | 141 | weight = counts**0.33 142 | print('#', weight/np.sum(weight), file=sys.stderr) 143 | 144 | weight = counts**0.25 145 | print('#', weight/np.sum(weight), file=sys.stderr) 146 | 147 | tau = args.tau 148 | print('# using tau:', tau, file=sys.stderr) 149 | print('#', counts**tau/np.sum(counts**tau), file=sys.stderr) 150 | weights = counts**tau/counts 151 | weights = weights[similarity].ravel() 152 | #weights = np.ones(len(dataset_train)) 153 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, epoch_size) 154 | 155 | # two training dataset iterators for sampling pairs of sequences for training 156 | train_iterator = torch.utils.data.DataLoader(dataset_train 157 | , batch_size=batch_size 158 | , sampler=sampler 159 | , collate_fn=collate_paired_sequences 160 | ) 161 | test_iterator = torch.utils.data.DataLoader(dataset_test 162 | , batch_size=batch_size 163 | , collate_fn=collate_paired_sequences 164 | ) 165 | 166 | 167 | ## initialize the model 168 | rnn_type = args.rnn_type 169 | rnn_dim = args.rnn_dim 170 | num_layers = args.num_layers 171 | 172 | embedding_size = args.embedding_dim 173 | input_dim = args.input_dim 174 | 175 | dropout = args.dropout 176 | 177 | allow_insert = args.allow_insert 178 | 179 | print('# initializing model with:', file=sys.stderr) 180 | print('# embedding_size:', embedding_size, file=sys.stderr) 181 | print('# input_dim:', input_dim, file=sys.stderr) 182 | print('# rnn_dim:', rnn_dim, file=sys.stderr) 183 | print('# num_layers:', num_layers, file=sys.stderr) 184 | print('# dropout:', dropout, file=sys.stderr) 185 | print('# allow_insert:', allow_insert, file=sys.stderr) 186 | 187 | compare_type = args.model 188 | print('# comparison method:', compare_type, file=sys.stderr) 189 | 190 | lm = None 191 | if args.lm is not None: 192 | lm = torch.load(args.lm) 193 | lm.eval() 194 | ## do not update the LM parameters 195 | for param in lm.parameters(): 196 | param.requires_grad = False 197 | print('# using LM:', args.lm, file=sys.stderr) 198 | 199 | if num_layers > 0: 200 | embedding = src.models.embedding.StackedRNN(len(alphabet), input_dim, rnn_dim, embedding_size 201 | , nlayers=num_layers, dropout=dropout, lm=lm) 202 | else: 203 | embedding = src.models.embedding.Linear(len(alphabet), input_dim, embedding_size, lm=lm) 204 | 205 | if args.norm == 'l1': 206 | norm = src.models.comparison.L1() 207 | print('# norm: l1', file=sys.stderr) 208 | elif args.norm == 'l2': 209 | norm = src.models.comparison.L2() 210 | print('# norm: l2', file=sys.stderr) 211 | model = src.models.comparison.OrdinalRegression(embedding, 5, align_method=compare_type 212 | , compare=norm, allow_insertions=allow_insert 213 | ) 214 | 215 | if use_cuda: 216 | model.cuda() 217 | 218 | ## setup training parameters and optimizer 219 | num_epochs = args.num_epochs 220 | 221 | weight_decay = args.weight_decay 222 | lr = args.lr 223 | 224 | print('# training with Adam: lr={}, weight_decay={}'.format(lr, weight_decay), file=sys.stderr) 225 | params = [p for p in model.parameters() if p.requires_grad] 226 | optim = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay) 227 | 228 | ## train the model 229 | print('# training model', file=sys.stderr) 230 | 231 | save_prefix = args.save_prefix 232 | output = args.output 233 | if output is None: 234 | output = sys.stdout 235 | else: 236 | output = open(output, 'w') 237 | digits = int(np.floor(np.log10(num_epochs))) + 1 238 | line = '\t'.join(['epoch', 'split', 'loss', 'mse', 'accuracy', 'r', 'rho' ]) 239 | print(line, file=output) 240 | 241 | 242 | for epoch in range(num_epochs): 243 | # train epoch 244 | model.train() 245 | it = 0 246 | n = 0 247 | loss_estimate = 0 248 | mse_estimate = 0 249 | acc_estimate = 0 250 | 251 | for x0,x1,y in train_iterator: # zip(train_iterator_0, train_iterator_1): 252 | 253 | if use_cuda: 254 | y = y.cuda() 255 | y = Variable(y) 256 | 257 | b = len(x0) 258 | x = x0 + x1 259 | 260 | x,order = pack_sequences(x) 261 | x = PackedSequence(Variable(x.data), x.batch_sizes) 262 | z = model(x) # embed the sequences 263 | z = unpack_sequences(z, order) 264 | 265 | z0 = z[:b] 266 | z1 = z[b:] 267 | 268 | logits = [] 269 | for i in range(b): 270 | z_a = z0[i] 271 | z_b = z1[i] 272 | logits.append(model.score(z_a, z_b)) 273 | logits = torch.stack(logits, 0) 274 | 275 | loss = F.binary_cross_entropy_with_logits(logits, y.float()) 276 | loss.backward() 277 | 278 | optim.step() 279 | optim.zero_grad() 280 | model.clip() # projected gradient for bounding ordinal regressionn parameters 281 | 282 | p = F.sigmoid(logits) 283 | ones = p.new(b,1).zero_() + 1 284 | p_ge = torch.cat([ones, p], 1) 285 | p_lt = torch.cat([1-p, ones], 1) 286 | p = p_ge*p_lt 287 | p = p/p.sum(1,keepdim=True) # make sure p is normalized 288 | 289 | _,y_hard = torch.max(p, 1) 290 | levels = torch.arange(5).to(p.device) 291 | y_hat = torch.sum(p*levels, 1) 292 | y = torch.sum(y.data, 1) 293 | 294 | loss = F.cross_entropy(p, y) # calculate cross entropy loss from p vector 295 | 296 | correct = torch.sum((y == y_hard).float()) 297 | mse = torch.sum((y.float() - y_hat)**2) 298 | 299 | n += b 300 | delta = b*(loss.item() - loss_estimate) 301 | loss_estimate += delta/n 302 | delta = correct.item() - b*acc_estimate 303 | acc_estimate += delta/n 304 | delta = mse.item() - b*mse_estimate 305 | mse_estimate += delta/n 306 | 307 | 308 | if (n - b)//100 < n//100: 309 | print('# [{}/{}] training {:.1%} loss={:.5f}, mse={:.5f}, acc={:.5f}'.format(epoch+1 310 | , num_epochs 311 | , n/epoch_size 312 | , loss_estimate 313 | , mse_estimate 314 | , acc_estimate 315 | ) 316 | , end='\r', file=sys.stderr) 317 | print(' '*80, end='\r', file=sys.stderr) 318 | line = '\t'.join([str(epoch+1).zfill(digits), 'train', str(loss_estimate) 319 | , str(mse_estimate), str(acc_estimate), '-', '-']) 320 | print(line, file=output) 321 | output.flush() 322 | 323 | # eval and save model 324 | model.eval() 325 | 326 | y = [] 327 | logits = [] 328 | with torch.no_grad(): 329 | for x0,x1,y_mb in test_iterator: 330 | 331 | if use_cuda: 332 | y_mb = y_mb.cuda() 333 | y.append(y_mb.long()) 334 | 335 | b = len(x0) 336 | x = x0 + x1 337 | 338 | x,order = pack_sequences(x) 339 | x = PackedSequence(Variable(x.data), x.batch_sizes) 340 | z = model(x) # embed the sequences 341 | z = unpack_sequences(z, order) 342 | 343 | z0 = z[:b] 344 | z1 = z[b:] 345 | 346 | for i in range(b): 347 | z_a = z0[i] 348 | z_b = z1[i] 349 | logits.append(model.score(z_a, z_b)) 350 | 351 | y = torch.cat(y, 0) 352 | logits = torch.stack(logits, 0) 353 | 354 | p = F.sigmoid(logits).data 355 | ones = p.new(p.size(0),1).zero_() + 1 356 | p_ge = torch.cat([ones, p], 1) 357 | p_lt = torch.cat([1-p, ones], 1) 358 | p = p_ge*p_lt 359 | p = p/p.sum(1,keepdim=True) # make sure p is normalized 360 | 361 | loss = F.cross_entropy(p, y).item() 362 | 363 | _,y_hard = torch.max(p, 1) 364 | levels = torch.arange(5).to(p.device) 365 | y_hat = torch.sum(p*levels, 1) 366 | 367 | accuracy = torch.mean((y == y_hard).float()).item() 368 | mse = torch.mean((y.float() - y_hat)**2).item() 369 | 370 | y = y.cpu().numpy() 371 | y_hat = y_hat.cpu().numpy() 372 | 373 | r,_ = pearsonr(y_hat, y) 374 | rho,_ = spearmanr(y_hat, y) 375 | 376 | line = '\t'.join([str(epoch+1).zfill(digits), 'test', str(loss), str(mse) 377 | , str(accuracy), str(r), str(rho)]) 378 | print(line, file=output) 379 | output.flush() 380 | 381 | 382 | # save the model 383 | if save_prefix is not None: 384 | save_path = save_prefix + '_epoch' + str(epoch+1).zfill(digits) + '.sav' 385 | model.cpu() 386 | torch.save(model, save_path) 387 | if use_cuda: 388 | model.cuda() 389 | 390 | 391 | 392 | 393 | if __name__ == '__main__': 394 | main() 395 | 396 | 397 | 398 | 399 | 400 | -------------------------------------------------------------------------------- /eval_contact_casp12.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import numpy as np 4 | import sys 5 | import os 6 | import glob 7 | from PIL import Image 8 | 9 | import torch 10 | from torch.nn.utils.rnn import PackedSequence 11 | import torch.utils.data 12 | 13 | from src.alphabets import Uniprot21 14 | import src.fasta as fasta 15 | from src.utils import pack_sequences, unpack_sequences 16 | from src.utils import ContactMapDataset, collate_lists 17 | from src.metrics import average_precision 18 | 19 | 20 | def load_data(seq_path, struct_path, alphabet, baselines=False): 21 | 22 | pdb_index = {} 23 | for path in struct_path: 24 | pid = os.path.basename(path).split('.')[0] 25 | pdb_index[pid] = path 26 | 27 | with open(seq_path, 'rb') as f: 28 | names, sequences = fasta.parse(f) 29 | names = [name.split()[0].decode('utf-8') for name in names] 30 | sequences = [alphabet.encode(s.upper()) for s in sequences] 31 | 32 | x = [torch.from_numpy(x).long() for x in sequences] 33 | 34 | if baselines: 35 | preds = load_baselines() 36 | preds_domains = {} 37 | 38 | names_ = [] 39 | x_ = [] 40 | y = [] 41 | missing = 0 42 | for xi,name in zip(x,names): 43 | pid = name 44 | if pid not in pdb_index: 45 | print('MISSING:', pid, 'not in structures. Skipping.', file=sys.stderr) 46 | missing += 1 47 | continue 48 | path = pdb_index[pid] 49 | 50 | im = np.array(Image.open(path), copy=False) 51 | contacts = np.zeros(im.shape, dtype=np.float32) 52 | contacts[im == 1] = -1 53 | contacts[im == 255] = 1 54 | 55 | # trim to the domain 56 | start = 0 57 | while contacts[start,start] < 0: 58 | start += 1 59 | end = contacts.shape[0] 60 | while contacts[end-1,end-1] < 0: 61 | end -= 1 62 | xi = xi[start:end] 63 | contacts = contacts[start:end,start:end] 64 | 65 | if baselines: 66 | tag = name.split('-')[0] 67 | for key,value in preds.items(): 68 | if tag not in value: 69 | print(key, 'missing protein', tag, file=sys.stderr) 70 | logits = np.zeros(contacts.shape, dtype=np.float32) - np.inf 71 | else: 72 | logits = value[tag] 73 | logits = logits[start:end,start:end] 74 | domains = preds_domains.get(key, []) 75 | domains.append(logits) 76 | preds_domains[key] = domains 77 | 78 | # mask the matrix below the diagonal 79 | mask = np.tril_indices(contacts.shape[0], k=1) 80 | contacts[mask] = -1 81 | 82 | names_.append(name) 83 | x_.append(xi) 84 | y.append(torch.from_numpy(contacts)) 85 | 86 | print('Missing', missing, 'structures from', len(sequences), 'total.', file=sys.stderr) 87 | print('Reporting on', len(x_), 'structures.', file=sys.stderr) 88 | 89 | if baselines: 90 | return x_, y, names_, preds_domains 91 | 92 | return x_, y, names_ 93 | 94 | baselines = {'157': 'GREMLIN (Baker)', 95 | '079': 'iFold_1', 96 | '219': 'Deepfold-Contact', 97 | '013': 'MetaPSICOV', 98 | '451': 'RaptorX-Contact', 99 | } 100 | 101 | def load_baselines(): 102 | all_preds = {} 103 | for key,value in baselines.items(): 104 | path = 'data/casp12/predictions/*/*'+key+'_1' 105 | paths = glob.glob(path) 106 | preds = {} 107 | for path in paths: 108 | name = os.path.basename(path).split('RR')[0] 109 | 110 | with open(path, 'r') as f: 111 | for line in f: 112 | if line.startswith('MODEL'): 113 | break 114 | n = 0 115 | logits = None 116 | for line in f: 117 | if line.startswith('END'): 118 | break 119 | if line[0] not in '0123456789': 120 | n += len(line.strip()) 121 | else: 122 | if logits is None: 123 | logits = np.zeros((n,n), dtype=np.float32) - np.inf 124 | i,j,_,_,p = line.strip().split() 125 | i = int(i) - 1 126 | j = int(j) - 1 127 | p = float(p) 128 | logits[i,j] = logits[j,i] = np.log(p) - np.log(1-p) 129 | preds[name] = logits 130 | all_preds[value] = preds 131 | return all_preds 132 | 133 | def predict_minibatch(model, x, use_cuda): 134 | b = len(x) 135 | x,order = pack_sequences(x) 136 | x = PackedSequence(x.data, x.batch_sizes) 137 | z = model(x) # embed the sequences 138 | z = unpack_sequences(z, order) 139 | 140 | logits = [] 141 | for i in range(b): 142 | zi = z[i] 143 | lp = model.predict(zi.unsqueeze(0)).view(zi.size(0), zi.size(0)) 144 | logits.append(lp) 145 | 146 | return logits 147 | 148 | 149 | def calc_metrics(logits, y): 150 | y_hat = (logits > 0).astype(np.float32) 151 | TP = (y_hat*y).sum() 152 | precision = 1.0 153 | if y_hat.sum() > 0: 154 | precision = TP/y_hat.sum() 155 | recall = TP/y.sum() 156 | F1 = 0 157 | if precision + recall > 0: 158 | F1 = 2*precision*recall/(precision + recall) 159 | AUPR = average_precision(y, logits) 160 | return precision, recall, F1, AUPR 161 | 162 | 163 | def calc_baselines(baselines, y, lengths, names, output=sys.stdout, individual=False): 164 | 165 | line = '\t'.join(['Distance', 'Method', 'Precision', 'Recall', 'F1', 'AUPR', 'Precision@L', 'Precision@L/2', 'Precision@L/5']) 166 | if individual: 167 | line = '\t'.join(['Distance', 'Method', 'Protein', 'Precision', 'Recall', 'F1', 'AUPR', 'Precision@L', 'Precision@L/2', 'Precision@L/5']) 168 | print(line, file=output) 169 | output.flush() 170 | 171 | y = [y_.cpu().numpy() for y_ in y] 172 | 173 | # calculate performance metrics 174 | for key,logits in baselines.items(): 175 | 176 | # for all contacts 177 | y_flat = [] 178 | logits_flat = [] 179 | for i in range(len(y)): 180 | yi = y[i] 181 | mask = (yi < 0) 182 | y_flat.append(yi[~mask]) 183 | logits_flat.append(logits[i][~mask]) 184 | 185 | # calculate precision, recall, F1, and area under the precision recall curve for all contacts 186 | precision = np.zeros(len(y)) 187 | recall = np.zeros(len(y)) 188 | F1 = np.zeros(len(y)) 189 | AUPR = np.zeros(len(y)) 190 | prL = np.zeros(len(y)) 191 | prL2 = np.zeros(len(y)) 192 | prL5 = np.zeros(len(y)) 193 | for i in range(len(y)): 194 | pr,re,f1,aupr = calc_metrics(logits_flat[i], y_flat[i]) 195 | precision[i] = pr 196 | recall[i] = re 197 | F1[i] = f1 198 | AUPR[i] = aupr 199 | 200 | order = np.argsort(logits_flat[i])[::-1] 201 | n = lengths[i] 202 | topL = order[:n] 203 | prL[i] = y_flat[i][topL].mean() 204 | topL2 = order[:n//2] 205 | prL2[i] = y_flat[i][topL2].mean() 206 | topL5 = order[:n//5] 207 | prL5[i] = y_flat[i][topL5].mean() 208 | 209 | if individual: 210 | template = 'All\t{}\t{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 211 | for i in range(len(y)): 212 | name = names[i] 213 | line = template.format(key, name, precision[i], recall[i], F1[i], AUPR[i], prL[i], prL2[i], prL5[i]) 214 | print(line, file=output) 215 | else: 216 | template = 'All\t{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 217 | line = template.format(key, precision.mean(), recall.mean(), F1.mean(), AUPR.mean(), prL.mean(), prL2.mean(), prL5.mean()) 218 | print(line, file=output) 219 | output.flush() 220 | 221 | # for Medium/Long range contacts 222 | y_flat = [] 223 | logits_flat = [] 224 | for i in range(len(y)): 225 | yi = y[i] 226 | mask = (yi < 0) 227 | 228 | medlong = np.tril_indices(len(yi), k=11) 229 | medlong_mask = np.zeros((len(yi),len(yi)), dtype=np.uint8) 230 | medlong_mask[medlong] = 1 231 | mask = mask | (medlong_mask == 1) 232 | 233 | y_flat.append(yi[~mask]) 234 | logits_flat.append(logits[i][~mask]) 235 | 236 | # calculate precision, recall, F1, and area under the precision recall curve for med/long range contacts 237 | precision = np.zeros(len(y)) 238 | recall = np.zeros(len(y)) 239 | F1 = np.zeros(len(y)) 240 | AUPR = np.zeros(len(y)) 241 | prL = np.zeros(len(y)) 242 | prL2 = np.zeros(len(y)) 243 | prL5 = np.zeros(len(y)) 244 | for i in range(len(y)): 245 | pr,re,f1,aupr = calc_metrics(logits_flat[i], y_flat[i]) 246 | precision[i] = pr 247 | recall[i] = re 248 | F1[i] = f1 249 | AUPR[i] = aupr 250 | 251 | order = np.argsort(logits_flat[i])[::-1] 252 | n = lengths[i] 253 | topL = order[:n] 254 | prL[i] = y_flat[i][topL].mean() 255 | topL2 = order[:n//2] 256 | prL2[i] = y_flat[i][topL2].mean() 257 | topL5 = order[:n//5] 258 | prL5[i] = y_flat[i][topL5].mean() 259 | 260 | if individual: 261 | template = 'Medium/Long\t{}\t{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 262 | for i in range(len(y)): 263 | name = names[i] 264 | line = template.format(key, name, precision[i], recall[i], F1[i], AUPR[i], prL[i], prL2[i], prL5[i]) 265 | print(line, file=output) 266 | else: 267 | template = 'Medium/Long\t{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 268 | line = template.format(key, precision.mean(), recall.mean(), F1.mean(), AUPR.mean(), prL.mean(), prL2.mean(), prL5.mean()) 269 | print(line, file=output) 270 | output.flush() 271 | 272 | 273 | def main(): 274 | import argparse 275 | parser = argparse.ArgumentParser('Script for evaluating contact map models.') 276 | parser.add_argument('model', help='path to saved model') 277 | parser.add_argument('--batch-size', default=10, type=int, help='number of sequences to process in each batch (default: 10)') 278 | parser.add_argument('-o', '--output', help='output file path (default: stdout)') 279 | parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use') 280 | parser.add_argument('--individual', action='store_true') 281 | args = parser.parse_args() 282 | 283 | # load the data 284 | fasta_path = 'data/casp12/casp12.fm-domains.seq.fa' 285 | contact_paths = glob.glob('data/casp12/domains_T0/*.png') 286 | 287 | alphabet = Uniprot21() 288 | baselines = None 289 | if args.model == 'baselines': 290 | x,y,names,baselines = load_data(fasta_path, contact_paths, alphabet, baselines=True) 291 | else: 292 | x,y,names = load_data(fasta_path, contact_paths, alphabet) 293 | 294 | if baselines is not None: 295 | output = args.output 296 | if output is None: 297 | output = sys.stdout 298 | else: 299 | output = open(output, 'w') 300 | 301 | lengths = np.array([len(x_) for x_ in x]) 302 | calc_baselines(baselines, y, lengths, names, output=output, individual=args.individual) 303 | 304 | sys.exit(0) 305 | 306 | ## set the device 307 | d = args.device 308 | use_cuda = (d != -1) and torch.cuda.is_available() 309 | if d >= 0: 310 | torch.cuda.set_device(d) 311 | 312 | if use_cuda: 313 | x = [x_.cuda() for x_ in x] 314 | y = [y_.cuda() for y_ in y] 315 | 316 | model = torch.load(args.model) 317 | model.eval() 318 | if use_cuda: 319 | model.cuda() 320 | 321 | # predict contact maps 322 | batch_size = args.batch_size 323 | dataset = ContactMapDataset(x, y) 324 | iterator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_lists) 325 | logits = [] 326 | with torch.no_grad(): 327 | for xmb,ymb in iterator: 328 | lmb = predict_minibatch(model, xmb, use_cuda) 329 | logits += lmb 330 | 331 | # calculate performance metrics 332 | lengths = np.array([len(x_) for x_ in x]) 333 | logits = [logit.cpu().numpy() for logit in logits] 334 | y = [y_.cpu().numpy() for y_ in y] 335 | 336 | output = args.output 337 | if output is None: 338 | output = sys.stdout 339 | else: 340 | output = open(output, 'w') 341 | if args.individual: 342 | line = '\t'.join(['Distance', 'Protein', 'Precision', 'Recall', 'F1', 'AUPR', 'Precision@L', 'Precision@L/2', 'Precision@L/5']) 343 | else: 344 | line = '\t'.join(['Distance', 'Precision', 'Recall', 'F1', 'AUPR', 'Precision@L', 'Precision@L/2', 'Precision@L/5']) 345 | print(line, file=output) 346 | output.flush() 347 | 348 | # for all contacts 349 | y_flat = [] 350 | logits_flat = [] 351 | for i in range(len(y)): 352 | yi = y[i] 353 | mask = (yi < 0) 354 | y_flat.append(yi[~mask]) 355 | logits_flat.append(logits[i][~mask]) 356 | 357 | # calculate precision, recall, F1, and area under the precision recall curve for all contacts 358 | precision = np.zeros(len(x)) 359 | recall = np.zeros(len(x)) 360 | F1 = np.zeros(len(x)) 361 | AUPR = np.zeros(len(x)) 362 | prL = np.zeros(len(x)) 363 | prL2 = np.zeros(len(x)) 364 | prL5 = np.zeros(len(x)) 365 | for i in range(len(x)): 366 | pr,re,f1,aupr = calc_metrics(logits_flat[i], y_flat[i]) 367 | precision[i] = pr 368 | recall[i] = re 369 | F1[i] = f1 370 | AUPR[i] = aupr 371 | 372 | order = np.argsort(logits_flat[i])[::-1] 373 | n = lengths[i] 374 | topL = order[:n] 375 | prL[i] = y_flat[i][topL].mean() 376 | topL2 = order[:n//2] 377 | prL2[i] = y_flat[i][topL2].mean() 378 | topL5 = order[:n//5] 379 | prL5[i] = y_flat[i][topL5].mean() 380 | 381 | if args.individual: 382 | template = 'All\t{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 383 | for i in range(len(x)): 384 | name = names[i] 385 | line = template.format(name,precision[i], recall[i], F1[i], AUPR[i], prL[i], prL2[i], prL5[i]) 386 | print(line, file=output) 387 | else: 388 | template = 'All\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 389 | line = template.format(precision.mean(), recall.mean(), F1.mean(), AUPR.mean(), prL.mean(), prL2.mean(), prL5.mean()) 390 | print(line, file=output) 391 | output.flush() 392 | 393 | # for Medium/Long range contacts 394 | y_flat = [] 395 | logits_flat = [] 396 | for i in range(len(y)): 397 | yi = y[i] 398 | mask = (yi < 0) 399 | 400 | medlong = np.tril_indices(len(yi), k=11) 401 | medlong_mask = np.zeros((len(yi),len(yi)), dtype=np.uint8) 402 | medlong_mask[medlong] = 1 403 | mask = mask | (medlong_mask == 1) 404 | 405 | y_flat.append(yi[~mask]) 406 | logits_flat.append(logits[i][~mask]) 407 | 408 | # calculate precision, recall, F1, and area under the precision recall curve for all contacts 409 | precision = np.zeros(len(x)) 410 | recall = np.zeros(len(x)) 411 | F1 = np.zeros(len(x)) 412 | AUPR = np.zeros(len(x)) 413 | prL = np.zeros(len(x)) 414 | prL2 = np.zeros(len(x)) 415 | prL5 = np.zeros(len(x)) 416 | for i in range(len(x)): 417 | pr,re,f1,aupr = calc_metrics(logits_flat[i], y_flat[i]) 418 | precision[i] = pr 419 | recall[i] = re 420 | F1[i] = f1 421 | AUPR[i] = aupr 422 | 423 | order = np.argsort(logits_flat[i])[::-1] 424 | n = lengths[i] 425 | topL = order[:n] 426 | prL[i] = y_flat[i][topL].mean() 427 | topL2 = order[:n//2] 428 | prL2[i] = y_flat[i][topL2].mean() 429 | topL5 = order[:n//5] 430 | prL5[i] = y_flat[i][topL5].mean() 431 | 432 | if args.individual: 433 | template = 'Medium/Long\t{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 434 | for i in range(len(x)): 435 | name = names[i] 436 | line = template.format(name,precision[i], recall[i], F1[i], AUPR[i], prL[i], prL2[i], prL5[i]) 437 | print(line, file=output) 438 | else: 439 | template = 'Medium/Long\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}' 440 | line = template.format(precision.mean(), recall.mean(), F1.mean(), AUPR.mean(), prL.mean(), prL2.mean(), prL5.mean()) 441 | print(line, file=output) 442 | output.flush() 443 | 444 | 445 | 446 | if __name__ == '__main__': 447 | main() 448 | 449 | 450 | 451 | 452 | 453 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | 409 | -------------------------------------------------------------------------------- /src/alignment.pyx: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | cimport numpy as np 4 | import numpy as np 5 | cimport cython 6 | 7 | @cython.boundscheck(False) 8 | @cython.wraparound(False) 9 | def half_global_alignment(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y, np.ndarray[np.int32_t, ndim=2] S 10 | , np.int32_t gap, np.int32_t extend): 11 | 12 | """ Matches all of x to a substring of y""" 13 | 14 | cdef int n,m 15 | n = len(x) 16 | m = len(y) 17 | 18 | cdef np.float32_t MI = -np.inf 19 | cdef np.ndarray[np.float32_t, ndim=3] A = np.zeros((n+1,m+1,3), dtype=np.float32) 20 | 21 | A[0,1:,0] = MI 22 | A[1:,0,0] = MI 23 | 24 | A[0,1:,1] = MI # gap in x 25 | A[1:,0,1] = gap + np.arange(n)*extend 26 | 27 | # starting from gap in y costs 0 28 | A[1:,0,2] = MI 29 | 30 | # initialize the traceback matrix 31 | cdef np.ndarray[np.int8_t, ndim=3] tb = np.zeros((n+1,m+1,3), dtype=np.int8) - 1 32 | 33 | cdef np.float32_t s 34 | cdef int i,j 35 | cdef np.int8_t k 36 | 37 | for i in range(n): 38 | for j in range(m): 39 | # match i,j 40 | k = 0 41 | s = A[i,j,0] 42 | if A[i,j,1] > s: 43 | k = 1 44 | s = A[i,j,1] 45 | if A[i,j,2] > s: 46 | k = 2 47 | s = A[i,j,2] 48 | A[i+1,j+1,0] = s + S[x[i],y[j]] 49 | tb[i+1,j+1,0] = k 50 | # insert in x 51 | k = 0 52 | s = A[i,j+1,0] + gap 53 | if A[i,j+1,1] + extend > s: 54 | k = 1 55 | s = A[i,j+1,1] + extend 56 | if A[i,j+1,2] + gap > s: 57 | k = 2 58 | s = A[i,j+1,2] + gap 59 | A[i+1,j+1,1] = s 60 | tb[i+1,j+1,1] = k 61 | # insert in y 62 | k = 0 63 | s = A[i+1,j,0] + gap 64 | if A[i+1,j,1] + gap > s: 65 | k = 1 66 | s = A[i+1,j,1] + gap 67 | if A[i+1,j,2] + extend > s: 68 | k = 2 69 | s = A[i+1,j,2] + extend 70 | A[i+1,j+1,2] = s 71 | tb[i+1,j+1,2] = k 72 | 73 | 74 | # find the end of the best alignment 75 | cdef int j_max, k_max 76 | cdef np.float32_t s_max 77 | 78 | j_max = 0 79 | k_max = 0 80 | s_max = A[n,j_max,k_max] 81 | for j in range(m+1): 82 | for k in range(3): 83 | if A[n,j,k] > s_max: 84 | s_max = A[n,j,k] 85 | j_max = j 86 | k_max = k 87 | 88 | # backtrack the alignment 89 | cdef int k_next 90 | i = n 91 | j = j_max 92 | k = k_max 93 | while tb[i,j,k] >= 0: 94 | k_next = tb[i,j,k] 95 | if k == 0: 96 | i = i - 1 97 | j = j - 1 98 | elif k == 1: 99 | i = i - 1 100 | elif k == 2: 101 | j = j - 1 102 | k = k_next 103 | 104 | return s_max, j, j_max 105 | 106 | cdef int argmax(np.float32_t a, np.float32_t b, np.float32_t c): 107 | if b > a: 108 | if c > b: 109 | return 2 110 | return 1 111 | if c > a: 112 | return 2 113 | return 0 114 | 115 | @cython.boundscheck(False) 116 | @cython.wraparound(False) 117 | def nw_score_extra(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y, np.ndarray[np.int32_t, ndim=2] S 118 | , np.int32_t gap, np.int32_t extend): 119 | 120 | cdef int n,m 121 | n = len(x) 122 | m = len(y) 123 | 124 | cdef np.float32_t MI = -np.inf 125 | 126 | cdef np.ndarray[np.float32_t, ndim=2] A_prev = np.zeros((m+1,3), dtype=np.float32) 127 | cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((m+1,3), dtype=np.float32) 128 | cdef np.ndarray[np.float32_t, ndim=2] A_temp 129 | 130 | A[1:,0] = MI # match scores 131 | A[1:,1] = MI # gap in x 132 | A[1:,2] = gap + np.arange(m)*extend # gap in y 133 | 134 | # also calculate the length of the alignment 135 | cdef np.ndarray[np.int32_t, ndim=2] L_prev = np.zeros((m+1,3), dtype=np.int32) 136 | cdef np.ndarray[np.int32_t, ndim=2] L = np.zeros((m+1,3), dtype=np.int32) 137 | cdef np.ndarray[np.int32_t, ndim=2] L_temp 138 | 139 | # and the number of exact matches within the alignment 140 | cdef np.ndarray[np.int32_t, ndim=2] M_prev = np.zeros((m+1,3), dtype=np.int32) 141 | cdef np.ndarray[np.int32_t, ndim=2] M = np.zeros((m+1,3), dtype=np.int32) 142 | cdef np.ndarray[np.int32_t, ndim=2] M_temp 143 | 144 | 145 | cdef np.float32_t s 146 | cdef int i,j,k 147 | 148 | for i in range(n): 149 | # swap A and A_prev 150 | A_temp = A_prev 151 | A_prev = A 152 | A = A_temp 153 | 154 | L_temp = L_prev 155 | L_prev = L 156 | L = L_temp 157 | 158 | M_temp = M_prev 159 | M_prev = M 160 | M = M_temp 161 | 162 | # init A[0] 163 | A[0,0] = MI 164 | A[0,1] = gap + i*extend 165 | A[0,2] = MI 166 | for j in range(m): 167 | # match i,j 168 | s = S[x[i],y[j]] 169 | k = argmax(A_prev[j,0], A_prev[j,1], A_prev[j,2]) 170 | A[j+1,0] = s + A_prev[j,k] 171 | L[j+1,0] = L_prev[j,k] + 1 172 | if x[i] == y[j]: 173 | M[j+1,0] = M_prev[j,k] + 1 174 | else: 175 | M[j+1,0] = M_prev[j,k] 176 | 177 | # insert in x 178 | k = argmax(A_prev[j+1,0]+gap, A_prev[j+1,1]+extend, A_prev[j+1,2]+gap) 179 | A[j+1,1] = max(A_prev[j+1,0]+gap, A_prev[j+1,1]+extend, A_prev[j+1,2]+gap) 180 | L[j+1,1] = L_prev[j+1,k] + 1 181 | M[j+1,1] = M_prev[j+1,k] 182 | 183 | # insert in y 184 | k = argmax(A[j,0]+gap, A[j,1]+gap, A[j,2]+extend) 185 | A[j+1,2] = max(A[j,0]+gap, A[j,1]+gap, A[j,2]+extend) 186 | L[j+1,2] = L[j,k] + 1 187 | M[j+1,2] = M[j,k] 188 | 189 | k = argmax(A[m,0], A[m,1], A[m,2]) 190 | 191 | return A[m,k], M[m,k], L[m,k] 192 | 193 | 194 | @cython.boundscheck(False) 195 | @cython.wraparound(False) 196 | def sw_score_subst_no_affine(np.ndarray[np.float32_t,ndim=2] subst 197 | , np.float32_t gap): 198 | 199 | cdef int n,m 200 | n = subst.shape[0] 201 | m = subst.shape[1] 202 | 203 | cdef np.float32_t MI = -np.inf 204 | 205 | cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((n+1,m+1), dtype=np.float32) 206 | 207 | cdef int i,j 208 | for i in range(n): 209 | for j in range(m): 210 | A[i+1,j+1] = max(subst[i,j] + A[i,j], gap+A[i+1,j], gap+A[i,j+1], 0) 211 | 212 | return np.max(A) 213 | 214 | @cython.boundscheck(False) 215 | @cython.wraparound(False) 216 | def sw_score_subst(np.ndarray[np.float32_t,ndim=2] subst 217 | , np.float32_t gap, np.float32_t extend): 218 | 219 | cdef int n,m 220 | n = subst.shape[0] 221 | m = subst.shape[1] 222 | 223 | cdef np.float32_t MI = -np.inf 224 | 225 | cdef np.ndarray[np.float32_t, ndim=3] A = np.zeros((n+1,m+1,3), dtype=np.float32) 226 | 227 | A[0,1:,0] = MI # match scores 228 | A[1:,0,0] = MI 229 | 230 | A[0,1:,1] = MI # gap in x 231 | A[1:,0,2] = MI # gap in y 232 | 233 | cdef np.float32_t s 234 | cdef int i,j 235 | 236 | for i in range(n): 237 | for j in range(m): 238 | # match i,j 239 | s = subst[i,j] 240 | A[i+1,j+1,0] = s + max(A[i,j,0], A[i,j,1], A[i,j,2], 0) 241 | # insert in x 242 | A[i+1,j+1,1] = max(A[i,j+1,0]+gap, A[i,j+1,1]+extend, A[i,j+1,2]+gap, 0) 243 | # insert in y 244 | A[i+1,j+1,2] = max(A[i+1,j,0]+gap, A[i+1,j,1]+gap, A[i+1,j,2]+extend, 0) 245 | 246 | s = np.max(A) 247 | #s = max(A[m,0], A[m,1], A[m,2]) 248 | 249 | return s 250 | 251 | 252 | @cython.boundscheck(False) 253 | @cython.wraparound(False) 254 | def sw_score(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y, np.ndarray[np.int32_t, ndim=2] S 255 | , np.int32_t gap, np.int32_t extend): 256 | 257 | cdef int n,m 258 | n = len(x) 259 | m = len(y) 260 | 261 | cdef np.float32_t MI = -np.inf 262 | 263 | cdef np.ndarray[np.float32_t, ndim=3] A = np.zeros((n+1,m+1,3), dtype=np.float32) 264 | 265 | A[0,1:,0] = MI # match scores 266 | A[1:,0,0] = MI 267 | 268 | A[0,1:,1] = MI # gap in x 269 | A[1:,0,2] = MI # gap in y 270 | 271 | #A[1:,2] = gap + np.arange(m)*extend # gap in y 272 | 273 | cdef np.float32_t s 274 | cdef int i,j 275 | 276 | for i in range(n): 277 | for j in range(m): 278 | # match i,j 279 | s = S[x[i],y[j]] 280 | A[i+1,j+1,0] = s + max(A[i,j,0], A[i,j,1], A[i,j,2], 0) 281 | # insert in x 282 | A[i+1,j+1,1] = max(A[i,j+1,0]+gap, A[i,j+1,1]+extend, A[i,j+1,2]+gap, 0) 283 | # insert in y 284 | A[i+1,j+1,2] = max(A[i+1,j,0]+gap, A[i+1,j,1]+gap, A[i+1,j,2]+extend, 0) 285 | 286 | s = np.max(A) 287 | #s = max(A[m,0], A[m,1], A[m,2]) 288 | 289 | return s 290 | 291 | 292 | @cython.boundscheck(False) 293 | @cython.wraparound(False) 294 | def nw_score(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y, np.ndarray[np.int32_t, ndim=2] S 295 | , np.int32_t gap, np.int32_t extend): 296 | 297 | cdef int n,m 298 | n = len(x) 299 | m = len(y) 300 | 301 | cdef np.float32_t MI = -np.inf 302 | 303 | cdef np.ndarray[np.float32_t, ndim=2] A_prev = np.zeros((m+1,3), dtype=np.float32) 304 | cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((m+1,3), dtype=np.float32) 305 | cdef np.ndarray[np.float32_t, ndim=2] A_temp 306 | 307 | A[1:,0] = MI # match scores 308 | A[1:,1] = MI # gap in x 309 | A[1:,2] = gap + np.arange(m)*extend # gap in y 310 | 311 | cdef np.float32_t s 312 | cdef int i,j 313 | 314 | for i in range(n): 315 | # swap A and A_prev 316 | A_temp = A_prev 317 | A_prev = A 318 | A = A_temp 319 | # init A[0] 320 | A[0,0] = MI 321 | A[0,1] = gap + i*extend 322 | A[0,2] = MI 323 | for j in range(m): 324 | # match i,j 325 | s = S[x[i],y[j]] 326 | A[j+1,0] = s + max(A_prev[j,0], A_prev[j,1], A_prev[j,2]) 327 | # insert in x 328 | A[j+1,1] = max(A_prev[j+1,0]+gap, A_prev[j+1,1]+extend, A_prev[j+1,2]+gap) 329 | #A[j+1,1] = max(A_prev[j+1,0]+gap, A_prev[j+1,1]+extend) 330 | # insert in y 331 | A[j+1,2] = max(A[j,0]+gap, A[j,1]+gap, A[j,2]+extend) 332 | #A[j+1,2] = max(A[j,0]+gap, A[j,2]+extend) 333 | 334 | s = max(A[m,0], A[m,1], A[m,2]) 335 | 336 | return s 337 | 338 | @cython.boundscheck(False) 339 | @cython.wraparound(False) 340 | def nw_score_subst(np.ndarray[np.float32_t, ndim=2] S 341 | , np.int32_t gap, np.int32_t extend): 342 | 343 | cdef int n,m 344 | n = S.shape[0] 345 | m = S.shape[1] 346 | 347 | cdef np.float32_t MI = -np.inf 348 | 349 | cdef np.ndarray[np.float32_t, ndim=2] A_prev = np.zeros((m+1,3), dtype=np.float32) 350 | cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((m+1,3), dtype=np.float32) 351 | cdef np.ndarray[np.float32_t, ndim=2] A_temp 352 | 353 | A[1:,0] = MI # match scores 354 | A[1:,1] = MI # gap in x 355 | A[1:,2] = gap + np.arange(m)*extend # gap in y 356 | 357 | cdef np.float32_t s 358 | cdef int i,j 359 | 360 | for i in range(n): 361 | # swap A and A_prev 362 | A_temp = A_prev 363 | A_prev = A 364 | A = A_temp 365 | # init A[0] 366 | A[0,0] = MI 367 | A[0,1] = gap + i*extend 368 | A[0,2] = MI 369 | for j in range(m): 370 | # match i,j 371 | s = S[i,j] 372 | A[j+1,0] = s + max(A_prev[j,0], A_prev[j,1], A_prev[j,2]) 373 | # insert in x 374 | A[j+1,1] = max(A_prev[j+1,0]+gap, A_prev[j+1,1]+extend, A_prev[j+1,2]+gap) 375 | #A[j+1,1] = max(A_prev[j+1,0]+gap, A_prev[j+1,1]+extend) 376 | # insert in y 377 | A[j+1,2] = max(A[j,0]+gap, A[j,1]+gap, A[j,2]+extend) 378 | #A[j+1,2] = max(A[j,0]+gap, A[j,2]+extend) 379 | 380 | s = max(A[m,0], A[m,1], A[m,2]) 381 | 382 | return s 383 | 384 | @cython.boundscheck(False) 385 | @cython.wraparound(False) 386 | def nw_align_subst(np.ndarray[np.float32_t, ndim=2] S 387 | , np.int32_t gap, np.int32_t extend): 388 | 389 | cdef int n,m 390 | n = S.shape[0] 391 | m = S.shape[1] 392 | 393 | cdef np.float32_t MI = -np.inf 394 | 395 | cdef np.ndarray[np.float32_t, ndim=2] A_prev = np.zeros((m+1,3), dtype=np.float32) 396 | cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((m+1,3), dtype=np.float32) 397 | cdef np.ndarray[np.float32_t, ndim=2] A_temp 398 | 399 | A[1:,0] = MI # match scores 400 | A[1:,1] = MI # gap in x 401 | A[1:,2] = gap + np.arange(m)*extend # gap in y 402 | 403 | cdef np.ndarray[np.int8_t, ndim=3] tb = np.zeros((n+1,m+1,3), dtype=np.int8) 404 | tb[:] = -1 405 | tb[1:,:,1] = 1 406 | tb[:,1:,2] = 2 407 | 408 | cdef np.float32_t s 409 | cdef int i,j,k 410 | 411 | for i in range(n): 412 | # swap A and A_prev 413 | A_temp = A_prev 414 | A_prev = A 415 | A = A_temp 416 | # init A[0] 417 | A[0,0] = MI 418 | A[0,1] = gap + i*extend 419 | A[0,2] = MI 420 | for j in range(m): 421 | # match i,j 422 | s = S[i,j] 423 | k = 0 424 | if A_prev[j,k] < A_prev[j,1]: 425 | k = 1 426 | if A_prev[j,k] < A_prev[j,2]: 427 | k = 2 428 | A[j+1,0] = s + A_prev[j,k] 429 | tb[i+1,j+1,0] = k 430 | 431 | # insert in x 432 | s = A_prev[j+1,0] + gap 433 | k = 0 434 | if s < A_prev[j+1,1] + extend: 435 | s = A_prev[j+1,1] + extend 436 | k = 1 437 | if s < A_prev[j+1,2] + gap: 438 | s = A_prev[j+1,2] + gap 439 | k = 2 440 | A[j+1,1] = s 441 | tb[i+1,j+1,1] = k 442 | 443 | # insert in y 444 | s = A[j,0] + gap 445 | k = 0 446 | if s < A[j,1] + gap: 447 | s = A[j,1] + gap 448 | k = 1 449 | if s < A[j,2] + extend: 450 | s = A[j,2] + extend 451 | k = 2 452 | A[j+1,2] = s 453 | tb[i+1,j+1,2] = k 454 | 455 | k = 0 456 | s = A[m,k] 457 | if s < A[m,1]: 458 | s = A[m,1] 459 | k = 1 460 | if s < A[m,2]: 461 | s = A[m,2] 462 | k = 2 463 | 464 | ## traceback 465 | align = [] 466 | i = n 467 | j = m 468 | while k >= 0: 469 | k_ = tb[i,j,k] 470 | if k == 0: 471 | align.append((i-1,j-1)) 472 | i -= 1 473 | j -= 1 474 | elif k == 1: 475 | i -= 1 476 | elif k == 2: 477 | j -= 1 478 | k = k_ 479 | align = np.array(align[::-1], dtype=int) 480 | 481 | return s, align 482 | 483 | @cython.boundscheck(False) 484 | @cython.wraparound(False) 485 | def nw_affine_subst_score(np.ndarray[np.float32_t, ndim=2] S, np.float32_t gap 486 | , np.float32_t extend): 487 | 488 | cdef int n,m 489 | n = S.shape[0] 490 | m = S.shape[1] 491 | 492 | cdef np.float32_t MI = -np.inf 493 | 494 | cdef np.ndarray[np.float32_t, ndim=2] A_prev = np.zeros((m+1,3), dtype=np.float32) 495 | cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((m+1,3), dtype=np.float32) 496 | cdef np.ndarray[np.float32_t, ndim=2] A_temp 497 | 498 | A[1:,0] = MI # match scores 499 | A[1:,1] = MI # gap in x 500 | A[1:,2] = gap + np.arange(m)*extend # gap in y 501 | 502 | cdef np.float32_t s 503 | cdef int i,j 504 | 505 | for i in range(n): 506 | # swap A and A_prev 507 | A_temp = A_prev 508 | A_prev = A 509 | A = A_temp 510 | # init A[0] 511 | A[0,0] = MI 512 | A[0,1] = gap + i*extend 513 | A[0,2] = MI 514 | for j in range(m): 515 | # match i,j 516 | s = S[i,j] 517 | A[j+1,0] = s + max(A_prev[j,0], A_prev[j,1], A_prev[j,2]) 518 | # insert in x 519 | A[j+1,1] = max(A_prev[j+1,0]+gap, A_prev[j+1,1]+extend, A_prev[j+1,2]+gap) 520 | # insert in y 521 | A[j+1,2] = max(A[j,0]+gap, A[j,1]+gap, A[j,2]+extend) 522 | 523 | s = max(A[m,0], A[m,1], A[m,2]) 524 | 525 | return s 526 | 527 | @cython.boundscheck(False) 528 | @cython.wraparound(False) 529 | def nw_subst_score(np.ndarray[np.float32_t, ndim=2] S, np.float32_t gap): #, np.float32_t extend): 530 | 531 | cdef int n,m 532 | n = S.shape[0] 533 | m = S.shape[1] 534 | 535 | cdef np.float32_t MI = -np.inf 536 | 537 | cdef np.ndarray[np.float32_t] A_prev = np.zeros(m+1, dtype=np.float32) 538 | cdef np.ndarray[np.float32_t] A = np.zeros(m+1, dtype=np.float32) 539 | cdef np.ndarray[np.float32_t] A_temp 540 | 541 | A[1:] = gap*np.arange(m) 542 | 543 | cdef np.float32_t s 544 | cdef int i,j 545 | 546 | for i in range(n): 547 | # swap A and A_prev 548 | A_temp = A_prev 549 | A_prev = A 550 | A = A_temp 551 | # init A[0] 552 | A[0] = gap*(i+1) 553 | for j in range(m): 554 | # match i,j 555 | s = S[i,j] 556 | A[j+1] = max(A_prev[j]+s, A_prev[j+1]+gap, A[j]+gap) 557 | 558 | s = A[m] 559 | 560 | return s 561 | 562 | @cython.boundscheck(False) 563 | @cython.wraparound(False) 564 | def edit_distance(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y): 565 | 566 | cdef int n,m 567 | n = len(x) 568 | m = len(y) 569 | 570 | cdef np.ndarray[np.int32_t, ndim=2] A = np.zeros((n+1,m+1), dtype=np.int32) 571 | A[0,1:] = np.arange(m) + 1 572 | A[1:,0] = np.arange(n) + 1 573 | 574 | cdef int i,j 575 | cdef np.int32_t s 576 | 577 | for i in range(n): 578 | for j in range(m): 579 | s = int(x[i] != y[j]) 580 | A[i+1,j+1] = min(A[i,j] + s, A[i+1,j] + 1, A[i,j+1] + 1) 581 | 582 | return A[n,m] 583 | 584 | @cython.boundscheck(False) 585 | @cython.wraparound(False) 586 | def dtw_subst_score(np.ndarray[np.float32_t, ndim=2] S): 587 | 588 | cdef int n,m 589 | n = S.shape[0] 590 | m = S.shape[1] 591 | 592 | cdef np.float32_t MI = -np.inf 593 | 594 | cdef np.ndarray[np.float32_t] A_prev = np.zeros(m+1, dtype=np.float32) 595 | cdef np.ndarray[np.float32_t] A = np.zeros(m+1, dtype=np.float32) 596 | cdef np.ndarray[np.float32_t] A_temp 597 | 598 | cdef np.float32_t s 599 | cdef int i,j 600 | 601 | for i in range(n): 602 | # swap A and A_prev 603 | A_temp = A_prev 604 | A_prev = A 605 | A = A_temp 606 | for j in range(m): 607 | # match i,j 608 | s = S[i,j] 609 | A[j+1] = s + max(A_prev[j], A_prev[j+1], A[j]) 610 | 611 | s = A[m] 612 | 613 | return s 614 | 615 | @cython.boundscheck(False) 616 | @cython.wraparound(False) 617 | def nw_align( np.ndarray[np.float32_t, ndim=2] S, np.float32_t gap): 618 | 619 | cdef int n,m 620 | n = S.shape[0] 621 | m = S.shape[1] 622 | 623 | cdef int i,j 624 | cdef np.float32_t score 625 | cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((n+1,m+1), dtype=np.float32) 626 | A[:,0] = gap*np.arange(n+1) 627 | A[0,:] = gap*np.arange(m+1) 628 | A[1:,1:] = -np.inf 629 | 630 | cdef np.ndarray[np.int8_t, ndim=2] paths = np.zeros((n+1,m+1), dtype=np.int8) 631 | paths[:] = -1 632 | paths[0,1:] = 1 633 | paths[1:,0] = 2 634 | 635 | i = j = 0 636 | score = -np.inf 637 | 638 | for i in range(n): 639 | for j in range(m): 640 | s = A[i,j] + S[i,j] 641 | if s > A[i+1,j+1]: 642 | A[i+1,j+1] = s 643 | paths[i+1,j+1] = 0 644 | 645 | gi = A[i+1,j] + gap 646 | if gi > A[i+1,j+1]: 647 | A[i+1,j+1] = gi 648 | paths[i+1,j+1] = 1 649 | 650 | gj = A[i,j+1] + gap 651 | if gj > A[i+1,j+1]: 652 | A[i+1,j+1] = gj 653 | paths[i+1,j+1] = 2 654 | 655 | ## traceback 656 | tb = [] 657 | i = n 658 | j = m 659 | while paths[i,j] >= 0: 660 | if paths[i,j] == 0: 661 | tb.append((i-1,j-1)) 662 | i -= 1 663 | j -= 1 664 | elif paths[i,j] == 1: 665 | j -= 1 666 | elif paths[i,j] == 2: 667 | i -= 1 668 | tb = np.array(tb[::-1], dtype=int) 669 | 670 | 671 | return A[n,m], tb 672 | 673 | 674 | 675 | 676 | -------------------------------------------------------------------------------- /train_similarity_and_contact.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import sys 6 | import os 7 | import glob 8 | from PIL import Image 9 | from scipy.stats import pearsonr, spearmanr 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence 16 | import torch.utils.data 17 | 18 | from src.alphabets import Uniprot21 19 | import src.scop as scop 20 | from src.utils import pack_sequences, unpack_sequences 21 | from src.utils import ContactMapDataset, collate_lists 22 | from src.utils import PairedDataset, AllPairsDataset, collate_paired_sequences 23 | from src.utils import MultinomialResample 24 | import src.models.embedding 25 | import src.models.multitask 26 | from src.metrics import average_precision 27 | 28 | cmap_paths = glob.glob('data/SCOPe/pdbstyle-2.06/*/*.png') 29 | cmap_dict = {os.path.basename(path)[:7] : path for path in cmap_paths} 30 | 31 | 32 | def load_data(path, alphabet): 33 | with open(path, 'rb') as f: 34 | names, structs, sequences = scop.parse_astral(f, encoder=alphabet) 35 | x = [torch.from_numpy(x).long() for x in sequences] 36 | s = torch.from_numpy(structs) 37 | c = [] 38 | for name in names: 39 | name = name.decode('utf-8') 40 | if name not in cmap_dict: 41 | name = 'd' + name[1:] 42 | path = cmap_dict[name] 43 | im = np.array(Image.open(path), copy=False) 44 | contacts = np.zeros(im.shape, dtype=np.float32) 45 | contacts[im == 1] = -1 46 | contacts[im == 255] = 1 47 | # mask the matrix below the diagonal 48 | mask = np.tril_indices(contacts.shape[0], k=-1) 49 | contacts[mask] = -1 50 | c.append(torch.from_numpy(contacts)) 51 | return x, s, c 52 | 53 | 54 | def load_scop_testpairs(astral_testpairs_path, alphabet): 55 | print('# loading test sequence pairs:', astral_testpairs_path, file=sys.stderr) 56 | test_pairs_table = pd.read_csv(astral_testpairs_path, sep='\t') 57 | x0_test = [x.encode('utf-8').upper() for x in test_pairs_table['sequence_A']] 58 | x0_test = [torch.from_numpy(alphabet.encode(x)).long() for x in x0_test] 59 | x1_test = [x.encode('utf-8').upper() for x in test_pairs_table['sequence_B']] 60 | x1_test = [torch.from_numpy(alphabet.encode(x)).long() for x in x1_test] 61 | y_test = test_pairs_table['similarity'].values 62 | y_test = torch.from_numpy(y_test).long() 63 | 64 | return x0_test, x1_test, y_test 65 | 66 | 67 | def similarity_grad(model, x0, x1, y, use_cuda, weight=0.5): 68 | if use_cuda: 69 | y = y.cuda() 70 | y = Variable(y) 71 | 72 | b = len(x0) 73 | x = x0 + x1 74 | 75 | x,order = pack_sequences(x) 76 | x = PackedSequence(Variable(x.data), x.batch_sizes) 77 | z = model(x) # embed the sequences 78 | z = unpack_sequences(z, order) 79 | 80 | z0 = z[:b] 81 | z1 = z[b:] 82 | 83 | logits = [] 84 | for i in range(b): 85 | z_a = z0[i] 86 | z_b = z1[i] 87 | logits.append(model.score(z_a, z_b)) 88 | logits = torch.stack(logits, 0) 89 | 90 | loss = F.binary_cross_entropy_with_logits(logits, y.float()) 91 | 92 | # backprop weighted loss 93 | w_loss = loss*weight 94 | w_loss.backward() 95 | 96 | # calculate minibatch performance metrics 97 | with torch.no_grad(): 98 | p = torch.sigmoid(logits) 99 | ones = p.new(b,1).zero_() + 1 100 | p_ge = torch.cat([ones, p], 1) 101 | p_lt = torch.cat([1-p, ones], 1) 102 | p = p_ge*p_lt 103 | p = p/p.sum(1,keepdim=True) # make sure p is normalized 104 | 105 | _,y_hard = torch.max(p, 1) 106 | levels = torch.arange(5).to(p.device) 107 | y_hat = torch.sum(p*levels, 1) 108 | y = torch.sum(y.data, 1) 109 | 110 | loss = F.cross_entropy(p, y).item() # calculate cross entropy loss from p vector 111 | 112 | correct = torch.sum((y == y_hard).float()).item() 113 | mse = torch.mean((y.float() - y_hat)**2).item() 114 | 115 | return loss, correct, mse, b 116 | 117 | 118 | def contacts_grad(model, x, y, use_cuda, weight=0.5): 119 | b = len(x) 120 | x,order = pack_sequences(x) 121 | x = PackedSequence(Variable(x.data), x.batch_sizes) 122 | z = model(x) # embed the sequences 123 | z = unpack_sequences(z, order) 124 | 125 | logits = [] 126 | for i in range(b): 127 | zi = z[i] 128 | lp = model.predict(zi.unsqueeze(0)).view(-1) 129 | logits.append(lp) 130 | logits = torch.cat(logits, 0) 131 | 132 | y = torch.cat([yi.view(-1) for yi in y]) 133 | if use_cuda: 134 | y = y.cuda() 135 | mask = (y < 0) 136 | 137 | logits = logits[~mask] 138 | y = Variable(y[~mask]) 139 | b = y.size(0) 140 | 141 | loss = F.binary_cross_entropy_with_logits(logits, y) 142 | 143 | # backprop weighted loss 144 | w_loss = loss*weight 145 | w_loss.backward() 146 | 147 | # calculate the recall and precision 148 | with torch.no_grad(): 149 | p_hat = torch.sigmoid(logits) 150 | tp = torch.sum(p_hat*y).item() 151 | gp = y.sum().item() 152 | pp = p_hat.sum().item() 153 | 154 | return loss.item(), tp, gp, pp, b 155 | 156 | 157 | def predict_contacts(model, x, y, use_cuda): 158 | b = len(x) 159 | x,order = pack_sequences(x) 160 | x = PackedSequence(Variable(x.data), x.batch_sizes) 161 | z = model(x) # embed the sequences 162 | z = unpack_sequences(z, order) 163 | 164 | logits = [] 165 | y_list = [] 166 | for i in range(b): 167 | zi = z[i] 168 | lp = model.predict(zi.unsqueeze(0)).view(-1) 169 | 170 | yi = y[i].view(-1) 171 | if use_cuda: 172 | yi = yi.cuda() 173 | mask = (yi < 0) 174 | 175 | lp = lp[~mask] 176 | yi = yi[~mask] 177 | 178 | logits.append(lp) 179 | y_list.append(yi) 180 | 181 | return logits, y_list 182 | 183 | 184 | def eval_contacts(model, test_iterator, use_cuda): 185 | logits = [] 186 | y = [] 187 | 188 | for x,y_mb in test_iterator: 189 | logits_this, y_this = predict_contacts(model, x, y_mb, use_cuda) 190 | logits += logits_this 191 | y += y_this 192 | 193 | y = torch.cat(y, 0) 194 | logits = torch.cat(logits, 0) 195 | 196 | loss = F.binary_cross_entropy_with_logits(logits, y).item() 197 | 198 | p_hat = torch.sigmoid(logits) 199 | tp = torch.sum(y*p_hat).item() 200 | pr = tp/torch.sum(p_hat).item() 201 | re = tp/torch.sum(y).item() 202 | f1 = 2*pr*re/(pr + re) 203 | 204 | y = y.cpu().numpy() 205 | logits = logits.data.cpu().numpy() 206 | 207 | aupr = average_precision(y, logits) 208 | 209 | return loss, pr, re, f1, aupr 210 | 211 | def eval_similarity(model, test_iterator, use_cuda): 212 | y = [] 213 | logits = [] 214 | for x0,x1,y_mb in test_iterator: 215 | 216 | if use_cuda: 217 | y_mb = y_mb.cuda() 218 | y.append(y_mb.long()) 219 | 220 | b = len(x0) 221 | x = x0 + x1 222 | 223 | x,order = pack_sequences(x) 224 | x = PackedSequence(Variable(x.data), x.batch_sizes) 225 | z = model(x) # embed the sequences 226 | z = unpack_sequences(z, order) 227 | 228 | z0 = z[:b] 229 | z1 = z[b:] 230 | 231 | for i in range(b): 232 | z_a = z0[i] 233 | z_b = z1[i] 234 | logits.append(model.score(z_a, z_b)) 235 | 236 | y = torch.cat(y, 0) 237 | logits = torch.stack(logits, 0) 238 | 239 | p = torch.sigmoid(logits).data 240 | ones = p.new(p.size(0),1).zero_() + 1 241 | p_ge = torch.cat([ones, p], 1) 242 | p_lt = torch.cat([1-p, ones], 1) 243 | p = p_ge*p_lt 244 | p = p/p.sum(1,keepdim=True) # make sure p is normalized 245 | 246 | loss = F.cross_entropy(p, y).item() 247 | 248 | _,y_hard = torch.max(p, 1) 249 | levels = torch.arange(5).to(p.device) 250 | y_hat = torch.sum(p*levels, 1) 251 | 252 | accuracy = torch.mean((y == y_hard).float()).item() 253 | mse = torch.mean((y.float() - y_hat)**2).item() 254 | 255 | y = y.cpu().numpy() 256 | y_hat = y_hat.cpu().numpy() 257 | 258 | r,_ = pearsonr(y_hat, y) 259 | rho,_ = spearmanr(y_hat, y) 260 | 261 | return loss, accuracy, mse, r, rho 262 | 263 | 264 | def main(): 265 | import argparse 266 | parser = argparse.ArgumentParser('Script for training contact prediction model') 267 | 268 | parser.add_argument('--dev', action='store_true', help='use train/dev split') 269 | 270 | parser.add_argument('--rnn-type', choices=['lstm', 'gru'], default='lstm', help='type of RNN block to use (default: lstm)') 271 | parser.add_argument('--embedding-dim', type=int, default=100, help='embedding dimension (default: 40)') 272 | parser.add_argument('--input-dim', type=int, default=512, help='dimension of input to RNN (default: 512)') 273 | parser.add_argument('--rnn-dim', type=int, default=512, help='hidden units of RNNs (default: 128)') 274 | parser.add_argument('--num-layers', type=int, default=3, help='number of RNN layers (default: 3)') 275 | parser.add_argument('--dropout', type=float, default=0, help='dropout probability (default: 0)') 276 | 277 | 278 | parser.add_argument('--hidden-dim', type=int, default=50, help='number of hidden units for comparison layer in contact predictionn (default: 50)') 279 | parser.add_argument('--width', type=int, default=7, help='width of convolutional filter for contact prediction (default: 7)') 280 | 281 | 282 | parser.add_argument('--epoch-size', type=int, default=100000, help='number of examples per epoch (default: 100,000)') 283 | parser.add_argument('--epoch-scale', type=int, default=5, help='report heldout performance every this many epochs (default: 5)') 284 | parser.add_argument('--num-epochs', type=int, default=100, help='number of epochs (default: 100)') 285 | 286 | 287 | parser.add_argument('--similarity-batch-size', type=int, default=64, help='minibatch size for similarity prediction loss in pairs (default: 64)') 288 | parser.add_argument('--contact-batch-size', type=int, default=10, help='minibatch size for contact predictionn loss (default: 10)') 289 | 290 | 291 | parser.add_argument('--weight-decay', type=float, default=0, help='L2 regularization (default: 0)') 292 | parser.add_argument('--lr', type=float, default=0.001) 293 | parser.add_argument('--lambda', dest='lambda_', type=float, default=0.5, help='weight on the similarity objective, contact map objective weight is one minus this (default: 0.5)') 294 | 295 | parser.add_argument('--tau', type=float, default=0.5, help='sampling proportion exponent (default: 0.5)') 296 | parser.add_argument('--augment', type=float, default=0, help='probability of resampling amino acid for data augmentation (default: 0)') 297 | 298 | parser.add_argument('--lm', help='pretrained LM to use as initial embedding') 299 | 300 | parser.add_argument('-o', '--output', help='output file path (default: stdout)') 301 | parser.add_argument('--save-prefix', help='path prefix for saving models') 302 | parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use') 303 | 304 | args = parser.parse_args() 305 | 306 | 307 | prefix = args.output 308 | 309 | 310 | ## set the device 311 | d = args.device 312 | use_cuda = (d != -1) and torch.cuda.is_available() 313 | if d >= 0: 314 | torch.cuda.set_device(d) 315 | 316 | ## make the datasets 317 | alphabet = Uniprot21() 318 | 319 | astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.fa' 320 | astral_test_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.fa' 321 | astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.sampledpairs.txt' 322 | if args.dev: 323 | astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.train.fa' 324 | astral_test_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.fa' 325 | astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.sampledpairs.txt' 326 | 327 | 328 | print('# loading training sequences:', astral_train_path, file=sys.stderr) 329 | x_train, structs_train, contacts_train = load_data(astral_train_path, alphabet) 330 | if use_cuda: 331 | x_train = [x.cuda() for x in x_train] 332 | #contacts_train = [c.cuda() for c in contacts_train] 333 | print('# loaded', len(x_train), 'training sequences', file=sys.stderr) 334 | 335 | print('# loading test sequences:', astral_test_path, file=sys.stderr) 336 | x_test, _, contacts_test = load_data(astral_test_path, alphabet) 337 | if use_cuda: 338 | x_test = [x.cuda() for x in x_test] 339 | #contacts_test = [c.cuda() for c in contacts_test] 340 | print('# loaded', len(x_test), 'contact map test sequences', file=sys.stderr) 341 | 342 | x0_test, x1_test, y_scop_test = load_scop_testpairs(astral_testpairs_path, alphabet) 343 | if use_cuda: 344 | x0_test = [x.cuda() for x in x0_test] 345 | x1_test = [x.cuda() for x in x1_test] 346 | print('# loaded', len(x0_test), 'scop test pairs', file=sys.stderr) 347 | 348 | ## make the dataset iterators 349 | 350 | # data augmentation by resampling amino acids 351 | augment = None 352 | p = 0 353 | if args.augment > 0: 354 | p = args.augment 355 | trans = torch.ones(len(alphabet),len(alphabet)) 356 | trans = trans/trans.sum(1, keepdim=True) 357 | if use_cuda: 358 | trans = trans.cuda() 359 | augment = MultinomialResample(trans, p) 360 | print('# resampling amino acids with p:', p, file=sys.stderr) 361 | 362 | # SCOP structural similarity datasets 363 | scop_levels = torch.cumprod((structs_train.unsqueeze(1) == structs_train.unsqueeze(0)).long(), 2) 364 | scop_train = AllPairsDataset(x_train, scop_levels, augment=augment) 365 | scop_test = PairedDataset(x0_test, x1_test, y_scop_test) 366 | 367 | # contact map datasets 368 | cmap_train = ContactMapDataset(x_train, contacts_train, augment=augment) 369 | cmap_test = ContactMapDataset(x_test, contacts_test) 370 | 371 | # iterators for contacts data 372 | batch_size = args.contact_batch_size 373 | cmap_train_iterator = torch.utils.data.DataLoader(cmap_train 374 | , batch_size=batch_size 375 | , shuffle=True 376 | , collate_fn=collate_lists 377 | ) 378 | cmap_test_iterator = torch.utils.data.DataLoader(cmap_test 379 | , batch_size=batch_size 380 | , collate_fn=collate_lists 381 | ) 382 | 383 | # make the SCOP training iterator have same number of minibatches 384 | num_steps = len(cmap_train_iterator) 385 | batch_size = args.similarity_batch_size 386 | epoch_size = num_steps*batch_size 387 | 388 | similarity = scop_levels.numpy().sum(2) 389 | levels,counts = np.unique(similarity, return_counts=True) 390 | order = np.argsort(levels) 391 | levels = levels[order] 392 | counts = counts[order] 393 | 394 | tau = args.tau 395 | print('# using tau:', tau, file=sys.stderr) 396 | print('#', counts**tau/np.sum(counts**tau), file=sys.stderr) 397 | weights = counts**tau/counts 398 | weights = weights[similarity].ravel() 399 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, epoch_size) 400 | N = epoch_size 401 | 402 | # iterators for similarity data 403 | scop_train_iterator = torch.utils.data.DataLoader(scop_train 404 | , batch_size=batch_size 405 | , sampler=sampler 406 | , collate_fn=collate_paired_sequences 407 | ) 408 | scop_test_iterator = torch.utils.data.DataLoader(scop_test 409 | , batch_size=batch_size 410 | , collate_fn=collate_paired_sequences 411 | ) 412 | 413 | report_steps = args.epoch_scale 414 | 415 | 416 | ## initialize the model 417 | rnn_type = args.rnn_type 418 | rnn_dim = args.rnn_dim 419 | num_layers = args.num_layers 420 | 421 | embedding_size = args.embedding_dim 422 | input_dim = args.input_dim 423 | dropout = args.dropout 424 | 425 | print('# initializing embedding model with:', file=sys.stderr) 426 | print('# embedding_size:', embedding_size, file=sys.stderr) 427 | print('# input_dim:', input_dim, file=sys.stderr) 428 | print('# rnn_dim:', rnn_dim, file=sys.stderr) 429 | print('# num_layers:', num_layers, file=sys.stderr) 430 | print('# dropout:', dropout, file=sys.stderr) 431 | 432 | lm = None 433 | if args.lm is not None: 434 | print('# using pretrained LM:', args.lm, file=sys.stderr) 435 | lm = torch.load(args.lm) 436 | lm.eval() 437 | ## do not update the LM parameters 438 | for param in lm.parameters(): 439 | param.requires_grad = False 440 | 441 | embedding = src.models.embedding.StackedRNN(len(alphabet), input_dim, rnn_dim 442 | , embedding_size, nlayers=num_layers 443 | , dropout=dropout, lm=lm) 444 | 445 | # similarity prediction parameters 446 | similarity_kwargs = {} 447 | 448 | # contact map prediction parameters 449 | hidden_dim = args.hidden_dim 450 | width = args.width 451 | cmap_kwargs = {'hidden_dim': hidden_dim, 'width': width} 452 | 453 | model = src.models.multitask.SCOPCM(embedding, similarity_kwargs=similarity_kwargs, 454 | cmap_kwargs=cmap_kwargs) 455 | if use_cuda: 456 | model.cuda() 457 | 458 | ## setup training parameters and optimizer 459 | num_epochs = args.num_epochs 460 | 461 | weight_decay = args.weight_decay 462 | lr = args.lr 463 | 464 | print('# training with Adam: lr={}, weight_decay={}'.format(lr, weight_decay), file=sys.stderr) 465 | params = [p for p in model.parameters() if p.requires_grad] 466 | optim = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay) 467 | 468 | scop_weight = args.lambda_ 469 | cmap_weight = 1 - scop_weight 470 | 471 | print('# weighting tasks with SIMILARITY: {:.3f}, CONTACTS: {:.3f}'.format(scop_weight, cmap_weight), file=sys.stderr) 472 | 473 | ## train the model 474 | print('# training model', file=sys.stderr) 475 | 476 | save_prefix = args.save_prefix 477 | output = args.output 478 | if output is None: 479 | output = sys.stdout 480 | else: 481 | output = open(output, 'w') 482 | digits = int(np.floor(np.log10(num_epochs))) + 1 483 | tokens = ['sim_loss', 'sim_mse', 'sim_acc', 'sim_r', 'sim_rho' 484 | ,'cmap_loss', 'cmap_pr', 'cmap_re', 'cmap_f1', 'cmap_aupr'] 485 | line = '\t'.join(['epoch', 'split'] + tokens) 486 | print(line, file=output) 487 | 488 | prog_template = '# [{}/{}] training {:.1%} sim_loss={:.5f}, sim_acc={:.5f}, cmap_loss={:.5f}, cmap_f1={:.5f}' 489 | 490 | for epoch in range(num_epochs): 491 | # train epoch 492 | model.train() 493 | 494 | scop_n = 0 495 | scop_loss_accum = 0 496 | scop_mse_accum = 0 497 | scop_acc_accum = 0 498 | 499 | cmap_n = 0 500 | cmap_loss_accum = 0 501 | cmap_pp = 0 502 | cmap_pr_accum = 0 503 | cmap_gp = 0 504 | cmap_re_accum = 0 505 | 506 | for (cmap_x, cmap_y), (scop_x0, scop_x1, scop_y) in zip(cmap_train_iterator, scop_train_iterator): 507 | 508 | # calculate gradients and metrics for similarity part 509 | loss, correct, mse, b = similarity_grad(model, scop_x0, scop_x1, scop_y, use_cuda, weight=scop_weight) 510 | 511 | scop_n += b 512 | delta = b*(loss - scop_loss_accum) 513 | scop_loss_accum += delta/scop_n 514 | delta = correct - b*scop_acc_accum 515 | scop_acc_accum += delta/scop_n 516 | delta = b*(mse - scop_mse_accum) 517 | scop_mse_accum += delta/scop_n 518 | 519 | report = ((scop_n - b)//100 < scop_n//100) 520 | 521 | # calculate the contact map prediction gradients and metrics 522 | loss, tp, gp_, pp_, b = contacts_grad(model, cmap_x, cmap_y, use_cuda, weight=cmap_weight) 523 | 524 | cmap_gp += gp_ 525 | delta = tp - gp_*cmap_re_accum 526 | cmap_re_accum += delta/cmap_gp 527 | 528 | cmap_pp += pp_ 529 | delta = tp - pp_*cmap_pr_accum 530 | cmap_pr_accum += delta/cmap_pp 531 | 532 | cmap_n += b 533 | delta = b*(loss - cmap_loss_accum) 534 | cmap_loss_accum += delta/cmap_n 535 | 536 | ## update the parameters 537 | optim.step() 538 | optim.zero_grad() 539 | model.clip() 540 | 541 | if report: 542 | f1 = 2*cmap_pr_accum*cmap_re_accum/(cmap_pr_accum + cmap_re_accum) 543 | line = prog_template.format(epoch+1, num_epochs, scop_n/N, scop_loss_accum 544 | , scop_acc_accum, cmap_loss_accum, f1) 545 | print(line, end='\r', file=sys.stderr) 546 | print(' '*80, end='\r', file=sys.stderr) 547 | f1 = 2*cmap_pr_accum*cmap_re_accum/(cmap_pr_accum + cmap_re_accum) 548 | tokens = [ scop_loss_accum, scop_mse_accum, scop_acc_accum, '-', '-' 549 | , cmap_loss_accum, cmap_pr_accum, cmap_re_accum, f1, '-'] 550 | tokens = [x if type(x) is str else '{:.5f}'.format(x) for x in tokens] 551 | 552 | line = '\t'.join([str(epoch+1).zfill(digits), 'train'] + tokens) 553 | print(line, file=output) 554 | output.flush() 555 | 556 | # eval and save model 557 | if (epoch+1) % report_steps == 0: 558 | model.eval() 559 | with torch.no_grad(): 560 | scop_loss, scop_acc, scop_mse, scop_r, scop_rho = \ 561 | eval_similarity(model, scop_test_iterator, use_cuda) 562 | cmap_loss, cmap_pr, cmap_re, cmap_f1, cmap_aupr = \ 563 | eval_contacts(model, cmap_test_iterator, use_cuda) 564 | 565 | tokens = [ scop_loss, scop_mse, scop_acc, scop_r, scop_rho 566 | , cmap_loss, cmap_pr, cmap_re, cmap_f1, cmap_aupr] 567 | tokens = ['{:.5f}'.format(x) for x in tokens] 568 | 569 | line = '\t'.join([str(epoch+1).zfill(digits), 'test'] + tokens) 570 | print(line, file=output) 571 | output.flush() 572 | 573 | # save the model 574 | if save_prefix is not None: 575 | save_path = save_prefix + '_epoch' + str(epoch+1).zfill(digits) + '.sav' 576 | model.cpu() 577 | torch.save(model, save_path) 578 | if use_cuda: 579 | model.cuda() 580 | 581 | 582 | 583 | 584 | if __name__ == '__main__': 585 | main() 586 | 587 | 588 | 589 | 590 | 591 | --------------------------------------------------------------------------------