├── src
├── __init__.py
├── models
│ ├── __init__.py
│ ├── comparison.py
│ ├── multitask.py
│ ├── embedding.py
│ └── sequence.py
├── parse_utils.py
├── pfam.py
├── scop.py
├── metrics.pyx
├── pdb.py
├── fasta.py
├── alphabets.py
├── utils.py
├── transmembrane.py
└── alignment.pyx
├── setup.py
├── .gitignore
├── README.md
├── embed_sequences.py
├── eval_contact_scop.py
├── eval_similarity.py
├── train_lm_pfam.py
├── eval_transmembrane.py
├── eval_secstr.py
├── train_similarity.py
├── eval_contact_casp12.py
├── LICENSE
└── train_similarity_and_contact.py
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# Build script for the Cython extensions under src/.
# Run with: python setup.py build_ext --inplace
#
# distutils was removed from the standard library in Python 3.12 (PEP 632),
# so prefer setuptools and only fall back to distutils on old installations
# where setuptools is unavailable.
try:
    from setuptools import setup
except ImportError:
    from distutils.core import setup

from Cython.Build import cythonize
import numpy as np

setup(
    ext_modules=cythonize(['src/metrics.pyx', 'src/alignment.pyx']),
    # metrics.pyx cimports numpy, so the numpy headers must be on the
    # include path when the generated C code is compiled
    include_dirs=[np.get_include()],
)
9 |
--------------------------------------------------------------------------------
/src/parse_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
def parse_3line(f):
    """Parse a 3-line annotated sequence file.

    Each record is a ``>name`` header line followed by one sequence line and
    one annotation line. Lines that are not part of a record are ignored.

    f: iterable of byte-string lines (a file opened in binary mode, a list
       of lines, ...).

    Returns three parallel lists ``(names, xs, ys)`` of byte strings, each
    with surrounding whitespace (including the line ending) removed.
    """
    names = []
    xs = []
    ys = []
    # A single shared iterator lets us pull the two payload lines following
    # each header without relying on f.readline(), so any iterable works.
    lines = iter(f)
    for line in lines:
        if not line.startswith(b'>'):
            continue
        # strip the header too, so names do not carry the trailing newline
        names.append(line[1:].strip())
        # get the sequence (empty if the input ends early)
        xs.append(next(lines, b'').strip())
        # get the transmembrane annotations (empty if the input ends early)
        ys.append(next(lines, b'').strip())
    return names, xs, ys
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/src/pfam.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import numpy as np
4 |
def parse_seed(f):
    """Parse alignments from a stream of byte-string lines.

    Lines starting with ``#`` are skipped, a line starting with ``//`` closes
    the current alignment, and every other non-blank line must consist of
    exactly two whitespace-separated tokens of which the second (the aligned
    sequence) is kept.

    f: iterable of byte-string lines.

    Returns a list of alignments, each a list of byte-string sequences.
    Raises ValueError if a data line does not have exactly two tokens.
    """
    alignments = []
    current = []
    for line in f:
        if line.startswith(b'#'):
            continue
        if line.startswith(b'//'):
            alignments.append(current)
            current = []
            continue
        tokens = line.split()
        if not tokens:
            # blank lines carry no sequence; previously they crashed the
            # two-token unpacking below
            continue
        _, s = tokens
        current.append(s)
    # keep a trailing alignment that was not terminated by '//'
    if len(current) > 0:
        alignments.append(current)

    return alignments
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/src/scop.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import numpy as np
4 | import src.fasta as fasta
5 |
class NullEncoder:
    """Pass-through encoder: ``encode`` hands its argument back untouched."""

    def encode(self, x):
        # identity mapping, no conversion is applied
        return x
9 |
def parse_astral_name(name, encode_struct=True):
    """Split a header into the entry name and its structure levels.

    name: byte string whose first whitespace-separated token is the entry
          name and whose second token is a '.'-separated list of levels.
    encode_struct: when true, every level is right-padded with zero bytes to
          4 bytes and the result is reinterpreted as an int32 array;
          otherwise the levels are returned as an array of byte strings.

    Returns ``(name, struct)``.
    """
    fields = name.split()
    entry = fields[0]
    levels = fields[1].split(b'.')

    if not encode_struct:
        return entry, np.array(levels)

    # pad each level to 4 bytes so that one level maps onto one int32
    packed = b''.join(level + b'\x00' * (4 - len(level)) for level in levels)
    return entry, np.frombuffer(packed, dtype=np.int32)
27 |
def parse_astral(f, encoder=NullEncoder(), encode_struct=True):
    """Read every record of a FASTA stream with ASTRAL-style headers.

    f: stream handed to ``fasta.parse_stream``.
    encoder: object with an ``encode`` method applied to each upper-cased
             sequence (the default returns the sequence unchanged).
    encode_struct: forwarded to ``parse_astral_name``.

    Returns ``(names, structs, sequences)`` where ``structs`` is the
    per-record structure arrays stacked along a new first axis.
    """
    names, structs, sequences = [], [], []
    for header, seq in fasta.parse_stream(f):
        encoded = encoder.encode(seq.upper())
        entry, levels = parse_astral_name(header, encode_struct=encode_struct)
        names.append(entry)
        structs.append(levels)
        sequences.append(encoded)
    return names, np.stack(structs, 0), sequences
40 |
--------------------------------------------------------------------------------
/src/metrics.pyx:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function
2 |
3 | cimport cython
4 | cimport numpy as np
5 | import numpy as np
6 |
# Average precision of the scores `pred` as a ranking of the labels `target`.
#
# target: 1-d float32 array of labels (NOTE(review): presumably 0/1
#         relevance indicators — confirm against callers)
# pred:   1-d float32 array of scores with one entry per target
# N:      normaliser for the final sum; when None, the sum of `target`
#
# Returns the accumulated precision-weighted relevance divided by the
# normaliser, as a float.
#
# NOTE(review): cdivision is enabled, so a zero normaliser is not trapped;
# an all-zero `target` with N=None presumably gives a non-finite result —
# confirm callers never rely on that case.
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def average_precision(np.ndarray[float] target, np.ndarray[float] pred, N=None):
    cdef float n
    if N is None:
        n = target.sum()
    else:
        n = N

    ## copy the target and prediction into matrix
    cdef np.ndarray[float, ndim=2] matrix = np.zeros((target.shape[0],2), dtype=np.float32)
    matrix[:,0] = -pred # negate the prediction to sort in descending order
    matrix[:,1] = target
    # view each row as one (f0, f1) record so the rows are sorted as units by
    # the negated prediction
    matrix.view('f4,f4').sort(order='f0', axis=0) # sort the rows
    #print(matrix[:10])

    # pr:   running mean of the targets seen so far (precision at this rank)
    # relk: sum of the targets inside the current group of tied predictions
    cdef float auprc, count, pr, relk, delta
    auprc = count = pr = relk = 0
    cdef int i = 0

    for i in range(matrix.shape[0]):
        count += 1
        relk += matrix[i,1] # target
        # incremental update of the running mean
        delta = matrix[i,1] - pr
        pr += delta/count
        # only accumulate at the last row of a run of equal predictions, so
        # tied scores are credited with the precision at the end of the tie;
        # the first test short-circuits before i+1 can run past the last row
        if i >= matrix.shape[0] - 1 or matrix[i,0] != matrix[i+1,0]:
            auprc += pr*relk
            relk = 0
    auprc /= n

    return auprc
39 |
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Don't commit the data or pretrained models
2 | data/
3 | pretrained_models/
4 | pretrained_models.tar.gz
5 |
6 | # VIM temp files
7 | *.swp
8 |
9 | # cython compiled .c files
10 | src/*.c
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | env/
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *,cover
57 | .hypothesis/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # IPython Notebook
81 | .ipynb_checkpoints
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # celery beat schedule file
87 | celerybeat-schedule
88 |
89 | # dotenv
90 | .env
91 |
92 | # virtualenv
93 | venv/
94 | ENV/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
--------------------------------------------------------------------------------
/src/pdb.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
def parse_secstr_stream(f, comment=b'#'):
    """Yield ``(name, protein, secstr)`` records from a stream of byte lines.

    Every header has the form ``>name:kind`` where ``kind`` (the part after
    the last ':') is ``sequence`` or ``secstr``. Payload lines following a
    header are concatenated into the matching field of the current record.
    Lines starting with ``comment`` are ignored. Only line endings are
    removed from payload lines, other whitespace is kept.

    Raises Exception on an unknown header kind or on payload before any
    header.
    """
    in_sequence, in_secstr = 0, 1
    state = -1  # nothing has been opened yet
    name = None
    protein = []
    secstr = []

    for raw in f:
        if raw.startswith(comment):
            continue
        # strip newline
        line = raw.rstrip(b'\r\n')

        if not line.startswith(b'>'):
            if state == in_sequence:
                protein.append(line)
            elif state == in_secstr:
                secstr.append(line)
            else:
                raise Exception("Flag not set properly")
            continue

        header = line[1:]
        if name is not None and state == in_secstr:
            # a header after a secstr section closes the current record
            yield name, b''.join(protein), b''.join(secstr)
        elif state == in_sequence:
            # a header after a sequence section must belong to the same entry
            assert header.startswith(name)

        # everything before the last ':' names the entry, the rest is the kind
        name, _, kind = header.rpartition(b':')

        if kind == b'sequence':
            state = in_sequence
            protein = []
            secstr = []
        elif kind == b'secstr':
            state = in_secstr
        else:
            raise Exception("Unrecognized flag: " + kind.decode())

    if name is not None:
        yield name, b''.join(protein), b''.join(secstr)
45 |
def parse_secstr(f, comment=b'#'):
    """Collect all records of ``parse_secstr_stream`` into three lists.

    Returns ``(names, proteins, secstrs)``, parallel lists with one entry
    per record.
    """
    names, proteins, secstrs = [], [], []
    for record in parse_secstr_stream(f, comment=comment):
        names.append(record[0])
        proteins.append(record[1])
        secstrs.append(record[2])
    return names, proteins, secstrs
56 |
57 |
58 |
--------------------------------------------------------------------------------
/src/fasta.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
def parse_stream(f, comment=b'#'):
    """Lazily yield ``(name, sequence)`` pairs from FASTA-formatted byte lines.

    Lines starting with ``comment`` are skipped. The name is the header line
    without its leading '>'; the sequence is the upper-cased concatenation
    of the stripped lines up to the next header.
    """
    header = None
    chunks = []
    for raw in f:
        if raw.startswith(comment):
            continue
        line = raw.strip()
        if not line.startswith(b'>'):
            chunks.append(line.upper())
            continue
        # a new header closes the record that was being read, if any
        if header is not None:
            yield header, b''.join(chunks)
        header = line[1:]
        chunks = []
    if header is not None:
        yield header, b''.join(chunks)
19 |
def parse(f, comment=b'#'):
    """Read all FASTA records from a stream of byte lines.

    Lines starting with ``comment`` are skipped. Returns ``(names,
    sequences)``: parallel lists holding each header (without the leading
    '>') and the upper-cased concatenation of its stripped sequence lines.
    """
    names = []
    sequences = []
    header = None
    chunks = []

    for raw in f:
        if raw.startswith(comment):
            continue
        line = raw.strip()
        if not line.startswith(b'>'):
            chunks.append(line.upper())
            continue
        # store the record that just ended before opening the next one
        if header is not None:
            names.append(header)
            sequences.append(b''.join(chunks))
        header = line[1:]
        chunks = []

    if header is not None:
        names.append(header)
        sequences.append(b''.join(chunks))

    return names, sequences
42 |
43 | #def parse(f, comment='#'):
44 | # names = []
45 | # sequences = []
46 | # name = None
47 | # sequence = []
48 | # for line in f:
49 | # if line.startswith(comment):
50 | # continue
51 | # line = line.strip()
52 | # if line.startswith('>'):
53 | # if name is not None:
54 | # names.append(name)
55 | # sequences.append(''.join(sequence))
56 | # name = line[1:]
57 | # sequence = []
58 | # else:
59 | # sequence.append(line)
60 | # if name is not None:
61 | # names.append(name)
62 | # sequences.append(''.join(sequence))
63 | #
64 | # return names, sequences
65 |
66 |
67 |
--------------------------------------------------------------------------------
/src/alphabets.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import numpy as np
4 |
class Alphabet:
    """Mapping between the bytes of an alphabet and small integer codes."""

    def __init__(self, chars, encoding=None, mask=False, missing=255):
        """
        chars: byte string listing the symbols of the alphabet.
        encoding: optional integer array giving the code of each symbol;
                  by default symbols are numbered in the order given.
        mask: when true the reported size is reduced by one.
        missing: code assigned to every byte that is not in ``chars``.
        """
        self.chars = np.frombuffer(chars, dtype=np.uint8)
        # lookup table over all 256 byte values, pre-filled with `missing`
        table = np.zeros(256, dtype=np.uint8) + missing
        if encoding is not None:
            table[self.chars] = encoding
            size = encoding.max() + 1
        else:
            table[self.chars] = np.arange(len(self.chars))
            size = len(self.chars)
        self.encoding = table
        self.mask = mask
        self.size = size - 1 if mask else size

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        # i-th symbol as a one-character string
        return chr(self.chars[i])

    def encode(self, x):
        """ encode a byte string into alphabet indices """
        raw = np.frombuffer(x, dtype=np.uint8)
        return self.encoding[raw]

    def decode(self, x):
        """ decode index array, x, to byte string of this alphabet """
        return self.chars[x].tobytes()

    def unpack(self, h, k):
        """ unpack integer h into array of this alphabet with length k """
        base = self.size
        kmer = np.zeros(k, dtype=np.uint8)
        # peel off base-`size` digits, least significant (last position) first
        for pos in range(k - 1, -1, -1):
            h, digit = divmod(h, base)
            kmer[pos] = digit
        return kmer

    def get_kmer(self, h, k):
        """ retrieve byte string of length k decoded from integer h """
        return self.decode(self.unpack(h, k))
49 |
# DNA alphabet: the bytes A, C, G, T are encoded as 0, 1, 2, 3; with the
# Alphabet defaults every other byte value maps to 255
DNA = Alphabet(b'ACGT')
51 |
class Uniprot21(Alphabet):
    """Amino acid alphabet with 21 distinct codes.

    The first 21 symbols ('ARNDCQEGHILKMFPSTWYV' and 'X') get the codes
    0-20; 'O', 'U', 'B' and 'Z' reuse existing codes, and any byte outside
    the alphabet is encoded as 20 (the code of 'X').
    """

    def __init__(self, mask=False):
        symbols = b'ARNDCQEGHILKMFPSTWYVXOUBZ'
        codes = np.arange(len(symbols))
        # encode 'OUBZ' as synonyms: O -> K (11), U -> C (4), B and Z -> X (20)
        codes[21:] = [11, 4, 20, 20]
        super(Uniprot21, self).__init__(symbols, encoding=codes, mask=mask, missing=20)
58 |
class SDM12(Alphabet):
    """
    Reduced amino acid alphabet with the residue groups

        A D KER N TSQ YF LIVM C W H G P

    In this encoding 'O' is placed in the KER group, 'U' in the C group, and
    'X', 'B', 'Z' share one additional group.

    See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2732308/#B33
    "Reduced amino acid alphabets exhibit an improved sensitivity and selectivity in fold assignment"
    Peterson et al. 2009. Bioinformatics.
    """

    def __init__(self, mask=False):
        symbols = b'ADKNTYLCWHGPXERSQFIVMOUBZ'
        groups = [b'A',b'D',b'KERO',b'N',b'TSQ',b'YF',b'LIVM',b'CU',b'W',b'H',b'G',b'P',b'XBZ']
        # map every member of a group to the index of that group
        code_of = {}
        for code, members in enumerate(groups):
            for member in members:
                code_of[member] = code
        codes = np.array([code_of[c] for c in symbols])
        super(SDM12, self).__init__(symbols, encoding=codes, mask=mask)
73 |
# 8-symbol alphabet (the last symbol is a space), encoded 0-7 in the order listed
# NOTE(review): presumably the 8-state secondary structure codes — confirm
SecStr8 = Alphabet(b'HBEGITS ')
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning protein sequence embeddings using information from structure
2 |
3 | New and improved embedding models combining sequence and structure training are now available at https://github.com/tbepler/prose!
4 |
5 |
6 |
7 |
8 | This repository contains the source code and links to the data and pretrained embedding models accompanying the ICLR 2019 paper: [Learning protein sequence embeddings using information from structure](https://openreview.net/pdf?id=SygLehCqtm)
9 |
10 | ```
11 | @inproceedings{
12 | bepler2018learning,
13 | title={Learning protein sequence embeddings using information from structure},
14 | author={Tristan Bepler and Bonnie Berger},
15 | booktitle={International Conference on Learning Representations},
16 | year={2019},
17 | }
18 | ```
19 |
20 | ## Setup and dependencies
21 |
22 | Dependencies:
23 | - python 3
24 | - pytorch >= 0.4
25 | - numpy
26 | - scipy
27 | - pandas
28 | - scikit-learn (sklearn)
29 | - cython
30 | - h5py (for embedding script)
31 |
32 | Run setup.py to compile the cython files:
33 |
34 | ```
35 | python setup.py build_ext --inplace
36 | ```
37 |
38 | ## Data sets
39 |
40 | The data sets with train/dev/test splits are provided as .tar.gz files from the links below.
41 |
42 | - [SCOPe data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/scope.tar.gz)
43 | - [Pfam data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/pfam.tar.gz)
44 | - [Protein secondary structure data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/secstr.tar.gz)
45 | - [Transmembrane data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/transmembrane.tar.gz)
46 | - [CASP12 contact map data](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/casp12.tar.gz)
47 |
48 | The training and evaluation scripts assume that these data sets have been extracted into a directory called 'data'.
49 |
50 | ## Pretrained models
51 |
52 | Our trained versions of the structure-based embedding models and the bidirectional language model can be downloaded [here](http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/pretrained_models.tar.gz).
53 |
54 | ## Author
55 |
56 | Tristan Bepler (tbepler@mit.edu)
57 |
58 | ## Cite
59 |
60 | Please cite the above paper if you use this code or pretrained models in your work.
61 |
62 | ## License
63 |
64 | The source code and trained models are provided free for non-commercial use under the terms of the CC BY-NC 4.0 license. See [LICENSE](LICENSE) file and/or https://creativecommons.org/licenses/by-nc/4.0/legalcode for more information.
65 |
66 |
67 | ## Contact
68 |
69 | If you have any questions, comments, or would like to report a bug, please file a GitHub issue or contact me at tbepler@mit.edu.
70 |
71 |
--------------------------------------------------------------------------------
/src/models/comparison.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
class L1(nn.Module):
    """Negative L1 distance between every row of x and every row of y."""

    def forward(self, x, y):
        # broadcasting x as (n,1,d) against y as (m,d) compares all pairs
        pairwise = x.unsqueeze(1) - y
        return pairwise.abs().sum(-1).neg()
11 |
12 |
class L2(nn.Module):
    """Negative squared L2 distance between every row of x and every row of y."""

    def forward(self, x, y):
        # broadcasting x as (n,1,d) against y as (m,d) compares all pairs
        pairwise = x.unsqueeze(1) - y
        return (pairwise**2).sum(-1).neg()
16 |
17 |
class DotProduct(nn.Module):
    """Dot product between every row of x and every row of y."""

    def forward(self, x, y):
        y_t = y.t()
        return x.mm(y_t)
21 |
22 |
def pad_gap_scores(s, gap):
    """Append one column and then one row filled with ``gap`` to matrix ``s``.

    s: 2-d score tensor of shape (n, m).
    gap: one-element tensor holding the gap score.

    Returns a tensor of shape (n+1, m+1).
    """
    gap_col = gap.expand(s.size(0), 1)
    widened = torch.cat([s, gap_col], 1)
    # the extra row spans the already widened matrix
    gap_row = gap.expand(1, widened.size(1))
    return torch.cat([widened, gap_row], 0)
29 |
30 |
class OrdinalRegression(nn.Module):
    """Embedding model with an ordinal regression head on a pair score.

    ``forward`` embeds its input; ``score`` reduces two embedded sequences
    to one scalar comparison value according to ``align_method`` and maps it
    to ``n_classes - 1`` logits.
    """

    def __init__(self, embedding, n_classes, compare=L1()
                , align_method='ssa', beta_init=10
                , allow_insertions=False, gap_init=-10
                ):
        super(OrdinalRegression, self).__init__()

        self.embedding = embedding
        self.n_out = n_classes

        self.compare = compare
        self.align_method = align_method
        self.allow_insertions = allow_insertions
        # score of aligning a position to a gap (used with allow_insertions)
        self.gap = nn.Parameter(torch.FloatTensor([gap_init]))

        # one slope and one offset per class boundary
        self.theta = nn.Parameter(torch.ones(1, n_classes - 1))
        self.beta = nn.Parameter(torch.zeros(n_classes - 1) + beta_init)
        self.clip()

    def forward(self, x):
        return self.embedding(x)

    def clip(self):
        # clip the weights of ordinal regression to be non-negative
        self.theta.data.clamp_(min=0)

    def _soft_alignment_score(self, z_x, z_y):
        # similarity averaged under a soft alignment of the two sequences
        s = self.compare(z_x, z_y)
        if self.allow_insertions:
            s = pad_gap_scores(s, self.gap)

        a = F.softmax(s, 1)
        b = F.softmax(s, 0)

        if self.allow_insertions:
            # zero the weights of the padded gap row of a / gap column of b
            last_row = torch.tensor([s.size(0) - 1], dtype=torch.long, device=s.device)
            a = a.index_fill(0, last_row, 0)

            last_col = torch.tensor([s.size(1) - 1], dtype=torch.long, device=s.device)
            b = b.index_fill(1, last_col, 0)

        a = a + b - a*b
        return torch.sum(a*s)/torch.sum(a)

    def score(self, z_x, z_y):
        method = self.align_method

        if method == 'ssa':
            c = self._soft_alignment_score(z_x, z_y)

        elif method == 'ua':
            # plain mean over all pairwise comparisons
            c = torch.mean(self.compare(z_x, z_y))

        elif method == 'me':
            # compare the mean embeddings of the two sequences
            mean_x = z_x.mean(0).unsqueeze(0)
            mean_y = z_y.mean(0).unsqueeze(0)
            c = self.compare(mean_x, mean_y).squeeze(0)

        else:
            raise Exception('Unknown alignment method: ' + self.align_method)

        logits = c*self.theta + self.beta
        return logits.view(-1)
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/src/models/multitask.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .comparison import L1, pad_gap_scores
8 |
9 |
class SCOPCM(nn.Module):
    """Multitask model: an embedding shared by two prediction heads.

    ``score`` runs the ordinal regression head (5 classes) on a pair of
    embedded sequences and ``predict`` runs the convolutional contact map
    head on embedded sequences.

    embedding: module producing the embeddings; must expose ``nout``.
    similarity_kwargs: optional keyword arguments for OrdinalRegression.
    cmap_kwargs: optional keyword arguments for ConvContactMap.
    """

    def __init__(self, embedding, similarity_kwargs=None,
                 cmap_kwargs=None):
        super(SCOPCM, self).__init__()

        # None instead of a shared mutable {} default; an omitted argument
        # still means "no extra keyword arguments"
        if similarity_kwargs is None:
            similarity_kwargs = {}
        if cmap_kwargs is None:
            cmap_kwargs = {}

        self.embedding = embedding
        embed_dim = embedding.nout

        self.scop_predict = OrdinalRegression(5, **similarity_kwargs)
        self.cmap_predict = ConvContactMap(embed_dim, **cmap_kwargs)

    def clip(self):
        # re-apply the parameter constraints of both heads
        self.scop_predict.clip()
        self.cmap_predict.clip()

    def forward(self, x):
        return self.embedding(x)

    def score(self, z_x, z_y):
        return self.scop_predict(z_x, z_y)

    def predict(self, z):
        return self.cmap_predict(z)
33 |
34 |
class ConvContactMap(nn.Module):
    """Predict pairwise contact logits from per-position embeddings.

    Pair features (absolute difference and product of the two positions'
    embeddings) go through a 1x1 convolution, an activation, and a
    ``width`` x ``width`` convolution with a single output channel.
    """

    def __init__(self, embed_dim, hidden_dim=50, width=7, act=nn.ReLU()):
        super(ConvContactMap, self).__init__()
        self.hidden = nn.Conv2d(in_channels=2*embed_dim, out_channels=hidden_dim, kernel_size=1)
        self.act = act
        self.conv = nn.Conv2d(in_channels=hidden_dim, out_channels=1, kernel_size=width,
                              padding=width//2)
        self.clip()

    def clip(self):
        # force the conv layer to be transpose invariant by averaging the
        # kernel with its own transpose
        weight = self.conv.weight
        symmetric = 0.5*(weight + weight.transpose(2, 3))
        weight.data.copy_(symmetric)

    def forward(self, z):
        return self.predict(z)

    def predict(self, z):
        # z is (b,L,d)
        feats = z.transpose(1, 2)  # (b,d,L)
        along_cols = feats.unsqueeze(2)  # (b,d,1,L)
        along_rows = feats.unsqueeze(3)  # (b,d,L,1)
        difference = torch.abs(along_cols - along_rows)
        product = along_cols*along_rows
        pairs = torch.cat([difference, product], 1)  # (b,2d,L,L)
        hidden = self.act(self.hidden(pairs))
        # drop the singleton channel axis -> (b,L,L)
        return self.conv(hidden).squeeze(1)
61 |
62 |
class OrdinalRegression(nn.Module):
    """Ordinal regression head on a pair of embedded sequences.

    ``forward`` reduces the two embedded sequences to one scalar comparison
    value according to ``align_method`` and maps it to ``n_classes - 1``
    logits.
    """

    def __init__(self, n_classes, compare=L1()
                , align_method='ssa', beta_init=10
                , allow_insertions=False, gap_init=-10
                ):
        super(OrdinalRegression, self).__init__()

        self.n_out = n_classes

        self.compare = compare
        self.align_method = align_method
        self.allow_insertions = allow_insertions
        # score of aligning a position to a gap (used with allow_insertions)
        self.gap = nn.Parameter(torch.FloatTensor([gap_init]))

        # one slope and one offset per class boundary
        self.theta = nn.Parameter(torch.ones(1, n_classes - 1))
        self.beta = nn.Parameter(torch.zeros(n_classes - 1) + beta_init)
        self.clip()

    def clip(self):
        # clip the weights of ordinal regression to be non-negative
        self.theta.data.clamp_(min=0)

    def _soft_alignment_score(self, z_x, z_y):
        # similarity averaged under a soft alignment of the two sequences
        s = self.compare(z_x, z_y)
        if self.allow_insertions:
            s = pad_gap_scores(s, self.gap)

        a = F.softmax(s, 1)
        b = F.softmax(s, 0)

        if self.allow_insertions:
            # zero the weights of the padded gap row of a / gap column of b
            last_row = torch.tensor([s.size(0) - 1], dtype=torch.long, device=s.device)
            a = a.index_fill(0, last_row, 0)

            last_col = torch.tensor([s.size(1) - 1], dtype=torch.long, device=s.device)
            b = b.index_fill(1, last_col, 0)

        a = a + b - a*b
        return torch.sum(a*s)/torch.sum(a)

    def forward(self, z_x, z_y):
        method = self.align_method

        if method == 'ssa':
            c = self._soft_alignment_score(z_x, z_y)

        elif method == 'ua':
            # plain mean over all pairwise comparisons
            c = torch.mean(self.compare(z_x, z_y))

        elif method == 'me':
            # compare the mean embeddings of the two sequences
            mean_x = z_x.mean(0).unsqueeze(0)
            mean_y = z_y.mean(0).unsqueeze(0)
            c = self.compare(mean_x, mean_y).squeeze(0)

        else:
            raise Exception('Unknown alignment method: ' + self.align_method)

        logits = c*self.theta + self.beta
        return logits.view(-1)
121 |
122 |
123 |
124 |
125 |
--------------------------------------------------------------------------------
/src/models/embedding.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import PackedSequence
7 |
8 |
class LMEmbed(nn.Module):
    """Token embedding combined with projected language model features.

    The output is ``transform(embed(x) + proj(lm.encode(x)))``.

    nin: number of tokens in the embedding table.
    nout: output dimension.
    lm: object providing ``hidden_size()`` and ``encode(x)``.
    padding_idx: padding token index; -1 selects the last token (nin-1).
    transform: module applied to the summed features.
    sparse: forwarded to nn.Embedding.
    """

    def __init__(self, nin, nout, lm, padding_idx=-1, transform=nn.ReLU()
                , sparse=False):
        super(LMEmbed, self).__init__()

        # -1 is a sentinel for "use the last token as padding"
        pad = nin - 1 if padding_idx == -1 else padding_idx

        self.lm = lm
        self.embed = nn.Embedding(nin, nout, padding_idx=pad, sparse=sparse)
        self.proj = nn.Linear(lm.hidden_size(), nout)
        self.transform = transform
        self.nout = nout

    def forward(self, x):
        is_packed = type(x) is PackedSequence
        lm_features = self.lm.encode(x)

        # work on the flat data tensors when the input is packed
        if is_packed:
            tokens = x.data
            lm_features = lm_features.data
        else:
            tokens = x

        embedded = self.embed(tokens)
        projected = self.proj(lm_features)
        out = self.transform(embedded + projected)

        # repack if needed
        if is_packed:
            return PackedSequence(out, x.batch_sizes)
        return out
43 |
44 |
class Linear(nn.Module):
    """Position-wise linear embedding of token sequences.

    With a language model, tokens are embedded by LMEmbed (dimension
    ``nhidden``) and then projected to ``nout`` by a linear layer. Without
    one, ``proj`` is an embedding table mapping tokens straight to ``nout``
    and ``nhidden`` is unused.

    nin: number of tokens.
    nhidden: LMEmbed output dimension (only used with ``lm``).
    nout: output dimension.
    padding_idx: padding token index; -1 selects the last token (nin-1).
    sparse: forwarded to the embedding layers.
    lm: optional language model passed to LMEmbed.

    ``forward`` accepts a (batch, length) token tensor or a PackedSequence
    and returns the same kind of container.
    """

    def __init__(self, nin, nhidden, nout, padding_idx=-1,
                 sparse=False, lm=None):
        super(Linear, self).__init__()

        if padding_idx == -1:
            padding_idx = nin-1

        if lm is not None:
            self.embed = LMEmbed(nin, nhidden, lm, padding_idx=padding_idx, sparse=sparse)
            self.proj = nn.Linear(self.embed.nout, nout)
            self.lm = True
        else:
            self.proj = nn.Embedding(nin, nout, padding_idx=padding_idx, sparse=sparse)
            self.lm = False

        self.nout = nout

    def forward(self, x):

        if self.lm:
            h = self.embed(x)
            if type(h) is PackedSequence:
                h = h.data
                z = self.proj(h)
                z = PackedSequence(z, x.batch_sizes)
            else:
                # flatten to (batch*length, features) for the linear layer
                h = h.view(-1, h.size(2))
                z = self.proj(h)
                z = z.view(x.size(0), x.size(1), -1)
        else:
            # Without a language model the embedding table is stored as
            # `proj`; no `embed` attribute exists in this configuration, so
            # the lookup has to go through `proj`.
            if type(x) is PackedSequence:
                z = self.proj(x.data)
                z = PackedSequence(z, x.batch_sizes)
            else:
                z = self.proj(x)

        return z
84 |
85 |
86 | class StackedRNN(nn.Module):
87 | def __init__(self, nin, nembed, nunits, nout, nlayers=2, padding_idx=-1, dropout=0,
88 | rnn_type='lstm', sparse=False, lm=None):
89 | super(StackedRNN, self).__init__()
90 |
91 | if padding_idx == -1:
92 | padding_idx = nin-1
93 |
94 | if lm is not None:
95 | self.embed = LMEmbed(nin, nembed, lm, padding_idx=padding_idx, sparse=sparse)
96 | nembed = self.embed.nout
97 | self.lm = True
98 | else:
99 | self.embed = nn.Embedding(nin, nembed, padding_idx=padding_idx, sparse=sparse)
100 | self.lm = False
101 |
102 | if rnn_type == 'lstm':
103 | RNN = nn.LSTM
104 | elif rnn_type == 'gru':
105 | RNN = nn.GRU
106 |
107 | self.dropout = nn.Dropout(p=dropout)
108 | if nlayers == 1:
109 | dropout = 0
110 |
111 | self.rnn = RNN(nembed, nunits, nlayers, batch_first=True
112 | , bidirectional=True, dropout=dropout)
113 | self.proj = nn.Linear(2*nunits, nout)
114 | self.nout = nout
115 |
116 |
117 |
118 | def forward(self, x):
119 |
120 | if self.lm:
121 | h = self.embed(x)
122 | else:
123 | if type(x) is PackedSequence:
124 | h = self.embed(x.data)
125 | h = PackedSequence(h, x.batch_sizes)
126 | else:
127 | h = self.embed(x)
128 |
129 | h,_ = self.rnn(h)
130 |
131 | if type(h) is PackedSequence:
132 | h = h.data
133 | h = self.dropout(h)
134 | z = self.proj(h)
135 | z = PackedSequence(z, x.batch_sizes)
136 | else:
137 | h = h.view(-1, h.size(2))
138 | h = self.dropout(h)
139 | z = self.proj(h)
140 | z = z.view(x.size(0), x.size(1), -1)
141 |
142 | return z
143 |
144 |
145 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.utils.data
7 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
8 |
def pack_sequences(X, order=None):
    """Pad a list of 1-d tensors and pack them into a PackedSequence.

    X: non-empty list of tensors, indexed along their first dimension.
    order: optional ordering of the sequences; by default they are sorted
           by decreasing length.

    Returns ``(packed, order)`` so the original order can be restored with
    ``unpack_sequences``.
    """
    count = len(X)
    lengths = np.array([len(x) for x in X])
    if order is None:
        # longest sequence first
        order = np.argsort(lengths)[::-1]
    longest = max(len(x) for x in X)

    # zero-padded block with the dtype/device of the first sequence
    padded = X[0].new_zeros((count, longest))
    for row in range(count):
        seq = X[order[row]]
        padded[row, :len(seq)] = seq

    packed = pack_padded_sequence(padded, lengths[order], batch_first=True)
    return packed, order
32 |
33 |
def unpack_sequences(X, order):
    """Undo ``pack_sequences``: return the sequences in their original order.

    X: PackedSequence.
    order: ordering that was used for packing; packed row i is written back
           to position ``order[i]``.

    Returns a list of tensors trimmed to their individual lengths.
    """
    padded, lengths = pad_packed_sequence(X, batch_first=True)
    restored = [None]*len(order)
    for row, dest in enumerate(order):
        restored[dest] = padded[row, :lengths[row]]
    return restored
41 |
42 |
def collate_lists(args):
    """Split a batch of ``(x, y)`` items into a list of x and a list of y."""
    firsts = [item[0] for item in args]
    seconds = [item[1] for item in args]
    return firsts, seconds
47 |
48 |
class ContactMapDataset(torch.utils.data.Dataset):
    """Dataset of (sequence, contact map) pairs with optional cropping.

    X: list of sequences.
    Y: list of square contact maps, entry i belonging to X[i]; negative
       entries are treated as masked.
    augment: optional callable applied to the returned sequence.
    fragment: when true, sequences longer than ``mi`` are replaced by a
              random contiguous fragment of length between ``mi`` and
              ``min(ma, len(x))``, with the contact map cropped to match.
    """

    def __init__(self, X, Y, augment=None, fragment=False, mi=64, ma=500):
        self.X = X
        self.Y = Y
        self.augment = augment
        self.fragment = fragment
        self.mi = mi
        self.ma = ma
        # (a disabled scheme that repeated each sequence in proportion to
        # its expected number of fragments used to live here)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        x = self.X[i]
        y = self.Y[i]
        if self.fragment and len(x) > self.mi:
            shortest = self.mi
            longest = min(self.ma, len(x))
            # draw fragments until one has unmasked observations
            while True:
                size = np.random.randint(shortest, longest+1)
                start = np.random.randint(len(x)-size+1)
                stop = start + size
                y_frag = y[start:stop, start:stop]
                if torch.sum(y_frag >= 0) != 0:
                    break
            x = x[start:stop]
            y = y_frag.contiguous()
        if self.augment is not None:
            x = self.augment(x)
        return x, y
97 |
98 |
class AllPairsDataset(torch.utils.data.Dataset):
    """Dataset over all ordered pairs (i, j) of the items in X.

    X: list of items.
    Y: 2-d array-like indexed as ``Y[i, j]`` giving the label of each pair.
    augment: optional callable applied to both items of a pair.
    """

    def __init__(self, X, Y, augment=None):
        self.X = X
        self.Y = Y
        self.augment = augment

    def __len__(self):
        return len(self.X)**2

    def __getitem__(self, k):
        # the flat index k enumerates the pairs in row-major order
        i, j = divmod(k, len(self.X))

        first = self.X[i]
        second = self.X[j]
        if self.augment is not None:
            first = self.augment(first)
            second = self.augment(second)

        return first, second, self.Y[i,j]
123 |
124 |
class PairedDataset(torch.utils.data.Dataset):
    """Dataset of pre-paired items: entry i is ``(X0[i], X1[i], Y[i])``."""

    def __init__(self, X0, X1, Y):
        self.X0 = X0
        self.X1 = X1
        self.Y = Y

    def __len__(self):
        # the length is taken from the first collection
        return len(self.X0)

    def __getitem__(self, i):
        first = self.X0[i]
        second = self.X1[i]
        label = self.Y[i]
        return first, second, label
136 |
137 |
def collate_paired_sequences(args):
    """Collate a batch of ``(x0, x1, y)`` items.

    The x0 and x1 entries are returned as plain lists; the y tensors are
    stacked along a new first dimension.
    """
    firsts = [item[0] for item in args]
    seconds = [item[1] for item in args]
    labels = [item[2] for item in args]
    return firsts, seconds, torch.stack(labels, 0)
143 |
144 |
class MultinomialResample:
    """Randomly resample tokens from a per-token distribution.

    trans: square matrix whose row t holds the replacement distribution of
           token t.
    p: weight of ``trans``; the remaining weight (1-p) keeps the token
       unchanged.
    """

    def __init__(self, trans, p):
        identity = torch.eye(trans.size(0)).to(trans.device)
        self.p = (1-p)*identity + p*trans

    def __call__(self, x):
        rows = self.p[x]  # get distribution for each x
        sampled = torch.multinomial(rows, 1)  # sample from distribution
        return sampled.view(-1)
153 |
154 |
155 |
156 |
157 |
--------------------------------------------------------------------------------
/src/transmembrane.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
def encode_labels(s):
    """Encode an annotation byte string as an integer array.

    'I' -> 0, 'O' -> 1, 'M' -> 2, 'S' -> 3. Any other symbol raises an
    Exception naming the offending symbol.
    """
    table = ((b'I', 0), (b'O', 1), (b'M', 2), (b'S', 3))
    y = np.zeros(len(s), dtype=int)
    for i in range(len(s)):
        # one-element slice keeps the symbol a byte string
        symbol = s[i:i+1]
        for key, code in table:
            if symbol == key:
                y[i] = code
                break
        else:
            raise Exception('Unrecognized annotation: ' + symbol.decode('utf-8'))
    return y
24 |
25 |
def transmembrane_regions(y):
    """Return the maximal runs of label 2 in ``y``.

    y: sequence of integer labels.

    Returns a list of ``(start, end)`` index pairs with ``end`` exclusive,
    in order of appearance.
    """
    regions = []
    start = -1  # -1 means "currently not inside a region"
    for i, label in enumerate(y):
        if label == 2 and start < 0:
            start = i
        # `>= 0`, not `> 0`: index 0 is a valid start, and with the strict
        # comparison a region beginning at position 0 was never closed
        elif label != 2 and start >= 0:
            regions.append((start, i))
            start = -1
    # a region that is still open ends with the sequence
    if start >= 0:
        regions.append((start, len(y)))
    return regions
38 |
39 |
def is_prediction_correct(y_hat, y):
    """Return 1 if the predicted labels ``y_hat`` match the target ``y``.

    A prediction counts as correct when it has the same number of
    transmembrane regions as the target, every predicted region overlaps
    the corresponding target region by at least 5 positions, and it starts
    with label 3 (signal peptide) whenever the target does. Otherwise 0 is
    returned.
    """
    predicted = transmembrane_regions(y_hat)
    actual = transmembrane_regions(y)
    if len(predicted) != len(actual):
        return 0

    for (p_start, p_end), (t_start, t_end) in zip(predicted, actual):
        # regions that do not intersect at all
        if p_end <= t_start or t_end <= p_start:
            return 0
        overlap = min(p_end, t_end) - max(p_start, t_start)
        if overlap < 5:
            return 0

    # finally, check signal peptide
    if y[0] == 3 and y_hat[0] != 3:
        return 0

    return 1
65 |
66 |
67 | ## TOPCONS uses a very specific state architecture for HMM
68 | ## we can adopt this to describe the transmembrane grammar
class Grammar:
    """Fixed state machine used to Viterbi-decode topology labels.

    States are laid out as: 0 = inner, 1 = outer, 2 = signal peptide,
    followed by ``n_helix`` states for an inner->outer helix and ``n_helix``
    states for an outer->inner helix.  Helix states only chain forward, so a
    decoded helix always spans exactly ``n_helix`` positions.

    Attributes (all numpy arrays):
        start:   log start distribution over states.
        end:     log end distribution over states.
        trans:   log transition matrix, rows normalised.
        emit:    0/1 matrix of which of the 4 labels each state may emit.
        mapping: output label (0=I, 1=O, 2=M, 3=S) for every state.
    """

    def __init__(self, n_helix=21, signal_helix=True):
        # NOTE: signal_helix is kept for backward compatibility; the only code
        # that used it is disabled, so it currently has no effect.
        n_states = 3 + 2*n_helix
        io_first, io_last = 3, 2 + n_helix                  # inner->outer helix
        oi_first, oi_last = 3 + n_helix, 2 + 2*n_helix      # outer->inner helix

        start = np.zeros(n_states)
        start[0] = 1.0 # inner
        start[1] = 1.0 # outer
        start[2] = 1.0 # signal peptide

        end = np.zeros(n_states)
        end[0] = 1.0 # from inner
        end[1] = 1.0 # from outer

        trans = np.zeros((n_states, n_states))
        trans[0,0] = 1.0 # inner -> inner
        trans[0,io_first] = 1.0 # inner -> helix (i->o)
        trans[1,1] = 1.0 # outer -> outer
        trans[1,oi_first] = 1.0 # outer -> helix (o->i)

        trans[2,0] = 1.0 # signal -> inner
        trans[2,1] = 1.0 # signal -> outer

        for i in range(io_first, io_last): # i->o helices
            trans[i,i+1] = 1.0
        trans[io_last,1] = 1.0 # helix (i->o) -> outer

        for i in range(oi_first, oi_last): # o->i helices
            trans[i,i+1] = 1.0
        trans[oi_last,0] = 1.0 # helix (o->i) -> inner

        emit = np.zeros((n_states, 4))
        emit[0,0] = 1.0 # inner
        emit[0,1] = 1.0
        emit[1,0] = 1.0 # outer
        emit[1,1] = 1.0
        emit[2,3] = 1.0 # signal peptide
        for i in range(3,3+2*n_helix): # helices
            emit[i,2] = 1.0

        mapping = np.zeros(n_states, dtype=int)
        mapping[0] = 0
        mapping[1] = 1
        mapping[2] = 3
        mapping[3:3+2*n_helix] = 2

        # forbidden moves are meant to become -inf, so log(0) is intentional
        with np.errstate(divide='ignore'):
            self.start = np.log(start) - np.log(start.sum())
            self.end = np.log(end) - np.log(end.sum())
            self.trans = np.log(trans) - np.log(trans.sum(1, keepdims=True))
        self.emit = emit
        self.mapping = mapping

    def decode(self, logp):
        """Viterbi-decode the most likely state path for one sequence.

        Args:
            logp: array of per-position log probabilities over the 4 labels
                (positions along axis 0, labels along axis 1).

        Returns:
            ``(y, score)``: the decoded label per position (states mapped
            through ``self.mapping``) and the log score of the best path.
        """
        p = np.exp(logp)
        with np.errstate(divide='ignore'):
            # per-state emission score = log of the summed allowed label probs
            z = np.log(np.dot(p, self.emit.T))

        # Traceback pointers hold state indices.  A platform int is used
        # because int8 silently wraps once there are more than 127 states.
        tb = np.full(z.shape, -1, dtype=int)
        p0 = z[0] + self.start
        for i in range(z.shape[0] - 1):
            # trans[a, b]: best score ending in state a, then moving to b
            trans = p0[:,np.newaxis] + self.trans + z[i+1]
            tb[i+1] = np.argmax(trans, 0)
            p0 = np.max(trans, 0)
        # transition to end
        p0 = p0 + self.end
        state = np.argmax(p0)
        score = np.max(p0)
        # traceback most likely sequence of states
        y = np.zeros(z.shape[0], dtype=int)
        j = state
        y[-1] = j
        for i in range(z.shape[0]-1, 0, -1):
            j = tb[i,j]
            y[i-1] = j

        # map the states
        y = self.mapping[y]

        return y, score

    def predict_viterbi(self, xs, model, use_cuda=False):
        """Run ``model`` on every input in ``xs`` and decode its output.

        Each model output is log-softmaxed over dimension 1, moved to numpy
        and passed to :meth:`decode`.  Returns the list of decoded label
        arrays (scores are discarded).
        """
        y_hats = []
        with torch.no_grad():
            for x in xs:
                if use_cuda:
                    x = x.cuda()
                log_p_hat = F.log_softmax(model(x), 1).cpu().numpy()
                y_hat,_ = self.decode(log_p_hat)
                y_hats.append(y_hat)
        return y_hats
160 |
161 |
162 |
163 |
164 |
165 |
--------------------------------------------------------------------------------
/src/models/sequence.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
7 |
class BiLM(nn.Module):
    """LSTM language model run left-to-right and right-to-left.

    The input tokens are embedded and passed through a stack of single-layer
    LSTMs once in the original order and once reversed.  With ``tied=True``
    both directions share one stack (``self.rnn``); otherwise each direction
    has its own (``self.lrnn`` / ``self.rrnn``).

    Args:
        nin: size of the input vocabulary.
        nout: number of output classes of the final linear layer.
        embedding_dim: size of the token embedding.
        hidden_dim: hidden size of every LSTM layer.
        num_layers: number of stacked LSTM layers per direction.
        tied: share LSTM weights between the two directions.
        mask_idx: padding token index (defaults to ``nin-1``).
        dropout: dropout probability applied after every LSTM layer.
    """

    def __init__(self, nin, nout, embedding_dim, hidden_dim, num_layers
                , tied=True, mask_idx=None, dropout=0):
        super(BiLM, self).__init__()

        if mask_idx is None:
            mask_idx = nin-1
        self.mask_idx = mask_idx
        self.embed = nn.Embedding(nin, embedding_dim, padding_idx=mask_idx)
        self.dropout = nn.Dropout(p=dropout)

        self.tied = tied
        # attribute names and creation order are kept so saved models and
        # seeded initialisation stay compatible
        if tied:
            self.rnn = self._build_stack(embedding_dim, hidden_dim, num_layers)
        else:
            self.lrnn = self._build_stack(embedding_dim, hidden_dim, num_layers)
            self.rrnn = self._build_stack(embedding_dim, hidden_dim, num_layers)

        self.linear = nn.Linear(hidden_dim, nout)

    @staticmethod
    def _build_stack(input_dim, hidden_dim, num_layers):
        """Create ``num_layers`` chained single-layer, batch-first LSTMs."""
        layers = []
        for _ in range(num_layers):
            layers.append(nn.LSTM(input_dim, hidden_dim, 1, batch_first=True))
            input_dim = hidden_dim
        return nn.ModuleList(layers)

    def _direction_layers(self):
        """Return ``(forward_lstm, reverse_lstm)`` pairs, one per layer."""
        if self.tied:
            return [(rnn, rnn) for rnn in self.rnn]
        return list(zip(self.lrnn, self.rrnn))

    def hidden_size(self):
        """Total width of the concatenated hidden states returned by encode."""
        h = 0
        if self.tied:
            for layer in self.rnn:
                h += 2*layer.hidden_size
        else:
            for layer in self.lrnn:
                h += layer.hidden_size
            for layer in self.rrnn:
                h += layer.hidden_size
        return h

    def transform(self, z, last_only=False):
        """Run the embedded sequence through both directions of the stack.

        Args:
            z: embedded tokens, batch first, flanked as ``[stop, x, stop]``.
            last_only: return only the top layer instead of every layer.

        Returns:
            ``(h_fwd, h_rvs)`` when ``last_only`` is set, otherwise a list of
            such pairs, one per layer.  ``h_rvs`` is flipped back so that
            both tensors are in the original position order.
        """
        # sequences are flanked by the start/stop token as:
        # [stop, x, stop]
        full = z.size(1)
        z_rvs = z.index_select(1, torch.arange(full-1, -1, -1, device=z.device))

        # each direction predicts the next token, so its last input is dropped
        z = z[:,:-1]
        z_rvs = z_rvs[:,:-1]
        idx = torch.arange(z_rvs.size(1)-1, -1, -1, device=z_rvs.device)

        hidden = []
        h_fwd = z
        h_rvs = z_rvs
        for fwd_rnn, rvs_rnn in self._direction_layers():
            h_fwd,_ = fwd_rnn(h_fwd)
            h_fwd = self.dropout(h_fwd)
            h_rvs,_ = rvs_rnn(h_rvs)
            h_rvs = self.dropout(h_rvs)
            if not last_only:
                # (the untied branch used to reference an undefined name here)
                hidden.append((h_fwd, h_rvs.index_select(1, idx)))

        if last_only:
            return (h_fwd, h_rvs.index_select(1, idx))
        return hidden

    def encode(self, x):
        """Embed raw token sequences as concatenated hidden states.

        ``x`` holds unshifted token indices, either as a batch-first tensor
        or as a ``PackedSequence``.  Tokens are shifted by one and flanked
        with the start/stop token (0) here.  A packed input yields a packed
        output.
        """
        packed = isinstance(x, PackedSequence)
        if packed:
            # pad with mask_idx-1 so that the +1 shift below yields mask_idx
            x,lengths = pad_packed_sequence(x, batch_first=True, padding_value=self.mask_idx-1)
        x = x + 1
        ## append start/stop tokens to x
        x_ = x.new_zeros(x.size(0), x.size(1)+2)
        x_[:,1:-1] = x
        x = x_

        # sequences x are flanked by the start/stop token as:
        # [stop, x, stop]

        z = self.embed(x)
        hidden = self.transform(z)

        concat = []
        for h_fwd,h_rvs in hidden:
            # align both directions with the original (unflanked) positions
            concat.append(h_fwd[:,:-1])
            concat.append(h_rvs[:,1:])

        h = torch.cat(concat, 2)
        if packed:
            h = pack_padded_sequence(h, lengths, batch_first=True)

        return h

    def forward(self, x):
        """Return log probabilities over the output classes per position.

        ``x`` must already be shifted and flanked as ``[stop, x, stop]``
        (tensor or ``PackedSequence``).  The forward and reverse logits are
        summed before the log-softmax; the first position has no forward
        context and the last no reverse context, so zeros are used there.
        """
        packed = isinstance(x, PackedSequence)
        if packed:
            # pad with the mask token
            x,lengths = pad_packed_sequence(x, batch_first=True, padding_value=self.mask_idx)

        z = self.embed(x)

        h_fwd,h_rvs = self.transform(z, last_only=True)

        b = z.size(0)
        n = z.size(1) - 1

        h_fwd = h_fwd.contiguous()
        logp_fwd = self.linear(h_fwd.view(b*n, h_fwd.size(2))).view(b, n, -1)

        zero = h_fwd.new_zeros(b, 1, logp_fwd.size(2))
        logp_fwd = torch.cat([zero, logp_fwd], 1)

        h_rvs = h_rvs.contiguous()
        logp_rvs = self.linear(h_rvs.view(-1, h_rvs.size(2))).view(b, n, -1)
        logp_rvs = torch.cat([logp_rvs, zero], 1)

        logp = F.log_softmax(logp_fwd + logp_rvs, dim=2)
        if packed:
            logp = pack_padded_sequence(logp, lengths, batch_first=True)

        return logp
177 |
178 |
179 |
180 |
--------------------------------------------------------------------------------
/embed_sequences.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import sys
4 | import numpy as np
5 | import h5py
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from src.alphabets import Uniprot21
12 | import src.fasta as fasta
13 | import src.models.sequence
14 |
15 |
def unstack_lstm(lstm):
    """Split a multi-layer bidirectional LSTM into single-layer LSTMs.

    One new bidirectional, batch-first ``nn.LSTM`` is created per layer of
    ``lstm`` on the same device, and the input/hidden weights and biases of
    both directions are copied into it.

    Returns:
        List of the single-layer LSTMs, bottom layer first.
    """
    device = next(iter(lstm.parameters())).device
    hidden_dim = lstm.hidden_size
    prefixes = ('weight_ih_l', 'weight_hh_l', 'bias_ih_l', 'bias_hh_l')

    stack = []
    width = lstm.input_size
    for depth in range(lstm.num_layers):
        single = nn.LSTM(width, hidden_dim, batch_first=True, bidirectional=True)
        single.to(device)

        for prefix in prefixes:
            # forward direction first, then the '_reverse' twin
            for suffix in ('', '_reverse'):
                target = getattr(single, prefix + '0' + suffix)
                target.data[:] = getattr(lstm, prefix + str(depth) + suffix)

        single.flatten_parameters()
        stack.append(single)
        # layers above the first consume both directions of the layer below
        width = 2*hidden_dim
    return stack
41 |
def embed_stack(x, lm_embed, lstm_stack, proj, include_lm=True, final_only=False):
    """Build the concatenated embedding of an index tensor ``x``.

    The result always begins with a 21-way one-hot encoding of ``x``.
    Depending on the flags it is followed by the output of ``lm_embed``, the
    output of every LSTM in ``lstm_stack`` and finally the projection
    ``proj`` of the top LSTM output.

    Args:
        x: integer tensor of token indices (batch, length).
        lm_embed: callable mapping ``x`` to features.
        lstm_stack: iterable of LSTMs applied in order, or ``None`` to stop
            after ``lm_embed``.
        proj: projection applied to the top LSTM output (batch dimension is
            squeezed first, so a batch of one is expected there).
        include_lm: include the ``lm_embed`` output.
        final_only: keep only the one-hot encoding and the projection.

    Returns:
        Tensor with all selected parts concatenated along dimension 2.
    """
    onehot = x.new_zeros(x.size(0), x.size(1), 21, dtype=torch.float)
    onehot.scatter_(2, x.unsqueeze(2), 1)
    parts = [onehot]

    keep_intermediate = not final_only
    h = lm_embed(x)
    if include_lm and keep_intermediate:
        parts.append(h)

    if lstm_stack is not None:
        for lstm in lstm_stack:
            h = lstm(h)[0]
            if keep_intermediate:
                parts.append(h)
        parts.append(proj(h.squeeze(0)).unsqueeze(0))

    return torch.cat(parts, 2)
63 |
64 |
def embed_sequence(x, lm_embed, lstm_stack, proj, include_lm=True, final_only=False
                  , pool='none', use_cuda=False):
    """Embed a single raw sequence and return the result as a numpy array.

    The sequence is upper-cased, encoded with the ``Uniprot21`` alphabet and
    passed through :func:`embed_stack` as a batch of one.

    Args:
        x: raw sequence; an empty sequence yields ``None``.
        lm_embed, lstm_stack, proj, include_lm, final_only: forwarded to
            :func:`embed_stack`.
        pool: ``'sum'``, ``'max'`` or ``'avg'`` reduces over positions; any
            other value leaves one row per position.
        use_cuda: move the encoded sequence to the GPU before embedding.
    """
    if len(x) == 0:
        return None

    encoded = Uniprot21().encode(x.upper())
    tokens = torch.from_numpy(encoded)
    if use_cuda:
        tokens = tokens.cuda()

    with torch.no_grad():
        batch = tokens.long().unsqueeze(0)
        z = embed_stack(batch, lm_embed, lstm_stack, proj
                       , include_lm=include_lm, final_only=final_only)
        z = z.squeeze(0)
        # optional pooling over the position axis
        if pool == 'sum':
            z = z.sum(0)
        elif pool == 'max':
            z = z.max(0)[0]
        elif pool == 'avg':
            z = z.mean(0)
        return z.cpu().numpy()
95 |
96 |
def load_model(path, use_cuda=False):
    """Load a saved model and return the pieces needed for embedding.

    Args:
        path: file written with ``torch.save`` containing the whole model.
        use_cuda: move the model to the GPU after loading.

    Returns:
        ``(lm_embed, lstm_stack, proj)``.  For a bare ``BiLM`` this is
        ``(model.encode, None, None)``; otherwise the parts are taken from
        ``model.embedding``.
    """
    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # model files from a trusted source.
    if use_cuda:
        encoder = torch.load(path)
    else:
        # without remapping, a model saved from the GPU cannot be loaded for
        # CPU inference
        encoder = torch.load(path, map_location='cpu')
    encoder.eval()

    if use_cuda:
        encoder.cuda()

    if type(encoder) is src.models.sequence.BiLM:
        # model is only the LM
        return encoder.encode, None, None

    encoder = encoder.embedding

    lm_embed = encoder.embed
    lstm_stack = unstack_lstm(encoder.rnn)
    proj = encoder.proj

    return lm_embed, lstm_stack, proj
115 |
116 |
def main():
    """Command line entry point: embed every FASTA record into an HDF5 file.

    Each record is stored as one dataset keyed by its name.  Empty sequences
    are skipped with a warning and a name that was already written is not
    embedded a second time.
    """
    import argparse
    parser = argparse.ArgumentParser('Script for embedding fasta format sequences using a saved embedding model. Saves embeddings as HDF5 file.')

    parser.add_argument('path', help='sequences to embed in fasta format')
    parser.add_argument('-m', '--model', help='path to saved embedding model')
    parser.add_argument('-o', '--output', help='path to HDF5 output file')
    parser.add_argument('--lm-only', action='store_true', help='only return the language model hidden layers')
    parser.add_argument('--no-lm', action='store_true', help='do not include LM hidden layers in embedding. by default, all hidden layers of all layers are concatenated and returned by this script.')
    parser.add_argument('--proj-only', action='store_true', help='only return the final structure-learned embedding')
    parser.add_argument('--pool', choices=['none', 'sum', 'max', 'avg'], default='none', help='apply some pooling operation over each sequence (default: none)')
    parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use')

    args = parser.parse_args()

    path = args.path

    # set the device
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    if d >= 0:
        torch.cuda.set_device(d)

    # load the model
    lm_embed, lstm_stack, proj = load_model(args.model, use_cuda=use_cuda)

    lm_only = args.lm_only
    if lm_only:
        lstm_stack = None
        proj = None

    no_lm = args.no_lm
    include_lm = not no_lm
    final_only = args.proj_only
    pool = args.pool

    # parse the sequences and embed them
    # write them to hdf5 file
    print('# writing:', args.output, file=sys.stderr)
    # the context manager guarantees the HDF5 file is flushed and closed,
    # also when embedding fails part-way through
    with h5py.File(args.output, 'w') as h5:
        print('# embedding with lm_only={}, no_lm={}, proj_only={}'.format(lm_only, no_lm, final_only), file=sys.stderr)
        print('# pooling:', pool, file=sys.stderr)

        count = 0
        with open(path, 'rb') as f:
            for name,sequence in fasta.parse_stream(f):
                # use sequence name as HDF key
                pid = name.decode('utf-8')
                if len(sequence) == 0:
                    print('# WARNING: sequence', pid, 'has length=0. Skipping.', file=sys.stderr)
                    continue
                # only do pids we haven't done already...
                if pid not in h5:
                    z = embed_sequence(sequence, lm_embed, lstm_stack, proj
                                      , include_lm=include_lm, final_only=final_only
                                      , pool=pool, use_cuda=use_cuda)
                    # write as hdf5 dataset
                    h5.create_dataset(pid, data=z, compression='lzf')
                count += 1
                print('# {} sequences processed...'.format(count), file=sys.stderr, end='\r')
    print(' '*80, file=sys.stderr, end='\r')
    print('# Done!', file=sys.stderr)
180 |
181 |
182 |
# run the command line interface only when executed as a script
if __name__ == '__main__':
    main()
185 |
186 |
187 |
188 |
189 |
--------------------------------------------------------------------------------
/eval_contact_scop.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import numpy as np
4 | import sys
5 | import os
6 | import glob
7 | from PIL import Image
8 |
9 | import torch
10 | from torch.nn.utils.rnn import PackedSequence
11 | import torch.utils.data
12 |
13 | from src.alphabets import Uniprot21
14 | import src.fasta as fasta
15 | from src.utils import pack_sequences, unpack_sequences
16 | from src.utils import ContactMapDataset, collate_lists
17 | from src.metrics import average_precision
18 |
19 |
def load_data(seq_path, struct_path, alphabet, baselines=False):
    """Load sequences together with their contact-map images.

    Args:
        seq_path: FASTA file with the sequences.
        struct_path: iterable of contact-map image paths; the first 7
            characters of each file name are used as the lookup id.
        alphabet: object whose ``encode`` turns a sequence into an index
            array.
        baselines: unused; kept for backward compatibility.

    Returns:
        ``(x, y, names)``: encoded sequences as long tensors, contact maps
        as float tensors (1 = contact, 0 = no contact, -1 = masked) and the
        sequence names.

    Raises:
        KeyError: if no image is found for a sequence.
    """
    pdb_index = {}
    for path in struct_path:
        pid = os.path.basename(path)[:7]
        pdb_index[pid] = path

    with open(seq_path, 'rb') as f:
        names, sequences = fasta.parse(f)
    names = [name.split()[0].decode('utf-8') for name in names]
    sequences = [alphabet.encode(s.upper()) for s in sequences]

    x = [torch.from_numpy(seq).long() for seq in sequences]

    names_ = []
    x_ = []
    y = []
    for xi,name in zip(x,names):
        pid = name
        if pid not in pdb_index:
            # fall back to the same id with a leading 'd'
            pid = 'd' + pid[1:]
        path = pdb_index[pid]

        # copy the pixels while the file is open so the handle can be closed;
        # np.array(..., copy=False) raises on NumPy >= 2 when a copy is needed
        with Image.open(path) as image:
            im = np.array(image)
        contacts = np.zeros(im.shape, dtype=np.float32)
        contacts[im == 1] = -1
        contacts[im == 255] = 1

        # mask the lower triangle, the diagonal and the first super-diagonal
        mask = np.tril_indices(contacts.shape[0], k=1)
        contacts[mask] = -1

        names_.append(name)
        x_.append(xi)
        y.append(torch.from_numpy(contacts))

    return x_, y, names_
57 |
58 |
def predict_minibatch(model, x, use_cuda):
    """Predict one square logit matrix per sequence of a minibatch.

    The sequences are packed, embedded by calling ``model`` and unpacked
    again; ``model.predict`` is then applied to each embedding on its own
    and reshaped to ``(length, length)``.  ``use_cuda`` is not used here.
    """
    count = len(x)
    packed, order = pack_sequences(x)
    packed = PackedSequence(packed.data, packed.batch_sizes)
    embedded = unpack_sequences(model(packed), order)

    logits = []
    for i in range(count):
        zi = embedded[i]
        length = zi.size(0)
        pairwise = model.predict(zi.unsqueeze(0))
        logits.append(pairwise.view(length, length))
    return logits
73 |
74 |
def calc_metrics(logits, y):
    """Compute precision, recall, F1 and average precision for one example.

    A logit above zero counts as a predicted positive.  Precision defaults
    to 1.0 when nothing is predicted positive and F1 to 0 when precision and
    recall do not sum to a positive number.

    Returns:
        ``(precision, recall, F1, AUPR)``.
    """
    predicted = (logits > 0).astype(np.float32)
    true_pos = (predicted*y).sum()
    n_predicted = predicted.sum()

    precision = true_pos/n_predicted if n_predicted > 0 else 1.0
    recall = true_pos/y.sum()

    denom = precision + recall
    F1 = 2*precision*recall/denom if denom > 0 else 0

    AUPR = average_precision(y, logits)
    return precision, recall, F1, AUPR
87 |
88 |
def _flatten_valid(y, logits, band=None):
    """Keep the unmasked entries (label >= 0) of every contact map.

    When ``band`` is given, entries on or below the ``band``-th
    super-diagonal are dropped as well.
    """
    y_flat = []
    logits_flat = []
    for yi, li in zip(y, logits):
        mask = (yi < 0)
        if band is not None:
            near = np.zeros((len(yi), len(yi)), dtype=bool)
            near[np.tril_indices(len(yi), k=band)] = True
            mask = mask | near
        y_flat.append(yi[~mask])
        logits_flat.append(li[~mask])
    return y_flat, logits_flat


def _contact_metric_table(logits_flat, y_flat, lengths):
    """Per-sequence metrics, one row each for precision, recall, F1, AUPR
    and precision of the top L, L/2 and L/5 predictions."""
    count = len(y_flat)
    table = np.zeros((7, count))
    for i in range(count):
        table[:4, i] = calc_metrics(logits_flat[i], y_flat[i])

        order = np.argsort(logits_flat[i])[::-1]
        n = lengths[i]
        table[4, i] = y_flat[i][order[:n]].mean()
        table[5, i] = y_flat[i][order[:n//2]].mean()
        table[6, i] = y_flat[i][order[:n//5]].mean()
    return table


def main():
    """Command line entry point: evaluate a contact prediction model.

    Writes a tab separated table with one row for all contacts and one row
    for medium/long range contacts to the output file (or stdout).
    """
    import argparse
    parser = argparse.ArgumentParser('Script for evaluating contact map models.')
    parser.add_argument('model', help='path to saved model')
    parser.add_argument('--dataset', default='2.06 test', help='which dataset (default: 2.06 test)')
    parser.add_argument('--batch-size', default=10, type=int, help='number of sequences to process in each batch (default: 10)')
    parser.add_argument('-o', '--output', help='output file path (default: stdout)')
    parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use')
    args = parser.parse_args()

    # load the data
    if args.dataset == '2.06 test':
        fasta_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.fa'
        contact_paths = glob.glob('data/SCOPe/pdbstyle-2.06/*/*.png')
    elif args.dataset == '2.07 test' or args.dataset == '2.07 new test':
        fasta_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.07-new.fa'
        contact_paths = glob.glob('data/SCOPe/pdbstyle-2.07/*/*.png')
    else:
        raise Exception('Bad dataset argument ' + args.dataset)

    alphabet = Uniprot21()
    x,y,names = load_data(fasta_path, contact_paths, alphabet)

    ## set the device
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    if d >= 0:
        torch.cuda.set_device(d)

    if use_cuda:
        x = [x_.cuda() for x_ in x]
        y = [y_.cuda() for y_ in y]

    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # trusted model files.
    if use_cuda:
        model = torch.load(args.model)
    else:
        # allow models saved from the GPU to be evaluated on the CPU
        model = torch.load(args.model, map_location='cpu')
    model.eval()
    if use_cuda:
        model.cuda()

    # predict contact maps
    batch_size = args.batch_size
    dataset = ContactMapDataset(x, y)
    iterator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_lists)
    logits = []
    with torch.no_grad():
        for xmb,ymb in iterator:
            logits += predict_minibatch(model, xmb, use_cuda)

    # calculate performance metrics
    lengths = np.array([len(x_) for x_ in x])
    logits = [logit.cpu().numpy() for logit in logits]
    y = [y_.cpu().numpy() for y_ in y]

    to_file = args.output is not None
    output = open(args.output, 'w') if to_file else sys.stdout
    try:
        line = '\t'.join(['Distance', 'Precision', 'Recall', 'F1', 'AUPR', 'Precision@L', 'Precision@L/2', 'Precision@L/5'])
        print(line, file=output)
        output.flush()

        # for all contacts
        y_flat, logits_flat = _flatten_valid(y, logits)
        table = _contact_metric_table(logits_flat, y_flat, lengths)
        template = 'All\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
        line = template.format(*[row.mean() for row in table])
        print(line, file=output)
        output.flush()

        # for Medium/Long range contacts (more than 11 positions apart);
        # sequences without such contacts give NaN and are ignored
        y_flat, logits_flat = _flatten_valid(y, logits, band=11)
        table = _contact_metric_table(logits_flat, y_flat, lengths)
        template = 'Medium/Long\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'
        line = template.format(*[np.nanmean(row) for row in table])
        print(line, file=output)
        output.flush()
    finally:
        # only close what was opened here, never sys.stdout
        if to_file:
            output.close()
233 |
234 |
235 |
# run the evaluation only when executed as a script
if __name__ == '__main__':
    main()
238 |
239 |
240 |
241 |
242 |
243 |
--------------------------------------------------------------------------------
/eval_similarity.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import os
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.data
11 |
12 | from scipy.stats import pearsonr,spearmanr
13 |
14 | from src.utils import pack_sequences, unpack_sequences
15 | from src.alphabets import Uniprot21
16 | from src.alignment import nw_score
17 | from src.metrics import average_precision
18 |
19 |
def encode_sequence(x, alphabet):
    """Encode a text sequence with ``alphabet``.

    The string is converted to UTF-8 bytes, upper-cased and then handed to
    ``alphabet.encode``, whose result is returned unchanged.
    """
    raw = x.encode('utf-8')
    return alphabet.encode(raw.upper())
26 |
27 |
def load_pairs(path, alphabet):
    """Read a tab separated table of sequence pairs.

    The columns ``sequence_A`` and ``sequence_B`` are encoded with
    :func:`encode_sequence`; ``similarity`` is returned as a numpy array.

    Returns:
        ``(x0, x1, y)``: the two lists of encoded sequences and the labels.
    """
    table = pd.read_csv(path, sep='\t')

    encoded = {}
    for column in ('sequence_A', 'sequence_B'):
        encoded[column] = [encode_sequence(seq, alphabet) for seq in table[column]]

    return encoded['sequence_A'], encoded['sequence_B'], table['similarity'].values
36 |
37 |
class NWAlign:
    """Needleman-Wunsch alignment score baseline using BLOSUM62.

    The substitution matrix is indexed in the order of ``alphabet``; gaps
    cost -11 to open and -1 to extend.
    """

    def __init__(self, alphabet):
        # Bio.SubsMat is absent from current Biopython releases; fall back to
        # the substitution_matrices module that replaced it.
        try:
            from Bio.SubsMat import MatrixInfo as matlist
            table = matlist.blosum62
            # only one ordering of each pair is used as key here
            lookup = lambda a, b: table[(b, a)]
        except ImportError:
            from Bio.Align import substitution_matrices
            table = substitution_matrices.load('BLOSUM62')
            lookup = lambda a, b: table[a, b]

        L = len(alphabet)
        subst = np.zeros((L,L), dtype=np.int32)
        for i in range(L):
            for j in range(i,L):
                # assumes alphabet[i] yields the residue letter -- TODO confirm
                a = alphabet[i]
                b = alphabet[j]
                subst[i,j] = subst[j,i] = lookup(a, b)
        self.subst = subst
        self.gap = -11
        self.extend = -1

    def __call__(self, x, y):
        """Return the alignment score of every pair ``(x[i], y[i])``."""
        b = len(x)
        scores = np.zeros(b)
        for i in range(b):
            scores[i] = nw_score(x[i], y[i], self.subst, self.gap, self.extend)
        return scores
58 |
59 |
class TorchModel:
    """Wrap a trained model so that it scores pairs of encoded sequences.

    Args:
        model: callable embedding a packed batch; ``model.score`` is used in
            alignment mode.
        use_cuda: move the packed batch to the GPU before embedding.
        mode: ``'align'`` (or its alias ``'ssa'``) scores each pair with
            ``model.score``; ``'coarse'`` uses the negative L1 distance of
            the mean embeddings.

    Raises:
        ValueError: for an unknown ``mode``.
    """

    _MODES = ('ssa', 'align', 'coarse')

    def __init__(self, model, use_cuda, mode='ssa'):
        if mode not in self._MODES:
            raise ValueError('Unknown mode: ' + repr(mode))
        self.model = model
        self.use_cuda = use_cuda
        self.mode = mode

    def __call__(self, x, y):
        """Return one score per pair ``(x[i], y[i])`` as a numpy array."""
        n = len(x)
        c = [torch.from_numpy(x_).long() for x_ in x] + [torch.from_numpy(y_).long() for y_ in y]

        c,order = pack_sequences(c)
        if self.use_cuda:
            c = c.cuda()

        with torch.no_grad():
            z = self.model(c) # embed the sequences
            z = unpack_sequences(z, order)

            scores = np.zeros(n)
            # the default 'ssa' used to match no branch and silently
            # produced all-zero scores
            if self.mode in ('ssa', 'align'):
                for i in range(n):
                    z_x = z[i]
                    z_y = z[i+n]

                    logits = self.model.score(z_x, z_y)
                    # turn the "level >= k" probabilities into a
                    # distribution over the levels
                    p = torch.sigmoid(logits).cpu()
                    p_ge = torch.ones(p.size(0)+1)
                    p_ge[1:] = p
                    p_lt = torch.ones(p.size(0)+1)
                    p_lt[:-1] = 1 - p
                    p = p_ge*p_lt
                    p = p/p.sum() # make sure p is normalized
                    levels = torch.arange(p.size(0)).float()
                    # expected level under that distribution
                    scores[i] = torch.sum(p*levels).item()

            elif self.mode == 'coarse':
                mean_x = torch.stack([emb.mean(0) for emb in z[:n]], 0)
                mean_y = torch.stack([emb.mean(0) for emb in z[n:]], 0)
                scores[:] = -torch.sum(torch.abs(mean_x - mean_y), 1).cpu().numpy()

        return scores
103 |
def find_best_threshold(x, y, tr0=-np.inf):
    """Find the cut on ``x`` that best separates binary labels ``y``.

    Every value of ``x`` is tried as a threshold, with scores strictly above
    it predicted positive; the threshold with the highest accuracy wins and
    the smallest one is taken on ties.

    Args:
        x: 1-D array of scores.
        y: array of 0/1 (or boolean) labels, same length as ``x``.
        tr0: returned when predicting everything positive is at least as
            good as any cut, or when the input is empty.

    Returns:
        The selected threshold.
    """
    n = len(x)
    if n == 0:
        # nothing to split; previously this raised IndexError
        return tr0

    order = np.argsort(x)
    # positives moved below the threshold after the k smallest scores
    moved = np.concatenate(([0.0], np.cumsum(np.asarray(y)[order], dtype=float)))
    tp = y.sum() - moved
    tn = np.arange(n + 1) - moved

    acc = (tp + tn)/len(y)
    # entry 0 corresponds to "no threshold", hence the shift by one
    i = np.argmax(acc) - 1
    if i < 0:
        return tr0
    return x[order[i]]
126 |
127 |
def find_best_thresholds(x, y):
    """Fit four increasing score cut-offs separating five ordinal levels.

    Starting from ``-inf``, cut-off ``k+1`` is fitted with
    :func:`find_best_threshold` on the scores above cut-off ``k`` only,
    using ``y > k`` as the binary label.

    Returns:
        Array of 5 values: ``-inf`` followed by the four fitted cut-offs.
    """
    cuts = np.zeros(5)
    cuts[0] = -np.inf
    for level in range(4):
        lower = cuts[level]
        above = (x > lower)
        cuts[level+1] = find_best_threshold(x[above], y[above] > level, tr0=lower)
    return cuts
138 |
139 |
def calculate_metrics(scores, y, thresholds):
    """Summarise how well ``scores`` predict the ordinal labels ``y``.

    Returns:
        ``(accuracy, r, rho, aupr)``: accuracy of the levels obtained by
        binning the scores with ``thresholds[1:]``, Pearson and Spearman
        correlation of scores and labels, and a float32 array holding the
        average precision of ``y > k`` for ``k`` = 0..3.
    """
    ## calculate accuracy, r, rho
    binned = np.digitize(scores, thresholds[1:], right=True)
    accuracy = np.mean(binned == y)
    r = pearsonr(scores, y)[0]
    rho = spearmanr(scores, y)[0]

    ## calculate average-precision score for each structural level
    scores32 = scores.astype(np.float32)
    aupr = np.zeros(4, dtype=np.float32)
    for level in range(4):
        positives = (y > level).astype(np.float32)
        aupr[level] = average_precision(positives, scores32)
    return accuracy, r, rho, aupr
152 |
153 |
def score_pairs(model, x0, x1, batch_size=100):
    """Score sequence pairs in minibatches.

    Args:
        model: callable taking two equally long lists of sequences and
            returning an array with one score per pair.
        x0, x1: the two sides of the pairs.
        batch_size: number of pairs passed to ``model`` at a time.

    Returns:
        1-D array with the scores of all pairs in input order; empty when
        there are no pairs.
    """
    batches = [
        model(x0[i:i+batch_size], x1[i:i+batch_size])
        for i in range(0, len(x0), batch_size)
    ]
    if not batches:
        # np.concatenate refuses an empty list
        return np.zeros(0)
    return np.concatenate(batches, 0)
162 |
163 |
def main():
    """Command line entry point: evaluate a similarity model on SCOP pairs.

    Level thresholds are fitted on the training pairs; accuracy,
    correlations and per-level average precision are then printed for the
    training set and every evaluation set as a tab separated table.
    """
    import argparse
    parser = argparse.ArgumentParser('Script for evaluating similarity model on SCOP test set.')

    parser.add_argument('model', help='path to saved model file or "nw-align" for Needleman-Wunsch alignment score baseline')

    parser.add_argument('--dev', action='store_true', help='use train/dev split')

    parser.add_argument('--batch-size', default=64, type=int, help='number of sequence pairs to process in each batch (default: 64)')

    parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use')

    parser.add_argument('--coarse', action='store_true', help='use coarse comparison rather than full SSA')

    args = parser.parse_args()

    scop_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.sampledpairs.txt'

    eval_paths = [ ( '2.06-test'
                   , 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.sampledpairs.txt')
                 , ( '2.07-new'
                   , 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.07-new.allpairs.txt')
                 ]
    if args.dev:
        scop_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.train.sampledpairs.txt'

        eval_paths = [ ( '2.06-dev'
                       , 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.sampledpairs.txt')
                     ]

    ## load the data
    alphabet = Uniprot21()
    x0_train, x1_train, y_train = load_pairs(scop_train_path, alphabet)

    ## load the model
    if args.model == 'nw-align':
        model = NWAlign(alphabet)
    elif args.model in ['hhalign', 'phmmer', 'TMalign']:
        # scores of these baselines are read from precomputed .npy files
        model = args.model
    else:
        ## set the device
        d = args.device
        use_cuda = (d != -1) and torch.cuda.is_available()
        if d >= 0:
            torch.cuda.set_device(d)

        # NOTE(security): torch.load unpickles arbitrary objects; only load
        # trusted model files.
        if use_cuda:
            model = torch.load(args.model)
        else:
            # allow models saved from the GPU to be evaluated on the CPU
            model = torch.load(args.model, map_location='cpu')
        model.eval()

        if use_cuda:
            model.cuda()

        mode = 'align'
        if args.coarse:
            mode = 'coarse'
        model = TorchModel(model, use_cuda, mode=mode)

    precomputed = isinstance(model, str)
    batch_size = args.batch_size

    ## for calculating the classification accuracy, first find the best partitions using the training set
    if precomputed:
        path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.sampledpairs.' \
               + model + '.npy'
        scores = np.load(path)
        scores = scores.mean(1)
    else:
        scores = score_pairs(model, x0_train, x1_train, batch_size)
    thresholds = find_best_thresholds(scores, y_train)

    print('Dataset\tAccuracy\tPearson\'s r\tSpearman\'s rho\tClass\tFold\tSuperfamily\tFamily')

    accuracy, r, rho, aupr = calculate_metrics(scores, y_train, thresholds)

    template = '{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}'

    line = '2.06-train\t' + template.format(accuracy, r, rho, aupr[0], aupr[1], aupr[2], aupr[3])
    print(line)

    for dset,path in eval_paths:
        x0_test, x1_test, y_test = load_pairs(path, alphabet)
        if precomputed:
            # scores live next to the pair file: <pairs>.<model>.npy
            score_path = os.path.splitext(path)[0] + '.' + model + '.npy'
            scores = np.load(score_path)
            scores = scores.mean(1)
        else:
            scores = score_pairs(model, x0_test, x1_test, batch_size)
        accuracy, r, rho, aupr = calculate_metrics(scores, y_test, thresholds)

        line = dset + '\t' + template.format(accuracy, r, rho, aupr[0], aupr[1], aupr[2], aupr[3])
        print(line)
258 |
259 |
# run the evaluation only when executed as a script
if __name__ == '__main__':
    main()
262 |
263 |
--------------------------------------------------------------------------------
/train_lm_pfam.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import sys
4 | import argparse
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 | import torch.utils.data
12 | from torch.nn.utils.rnn import pack_padded_sequence
13 |
14 | import src.fasta as fasta
15 | from src.alphabets import Uniprot21
16 | import src.models.sequence
17 |
# Command line interface of the language model training script.
parser = argparse.ArgumentParser('Train sequence model')

# training schedule
parser.add_argument(
    '-b', '--minibatch-size',
    default=32, type=int,
    help='minibatch size (default: 32)',
)
parser.add_argument(
    '-n', '--num-epochs',
    default=10, type=int,
    help='number of epochs (default: 10)',
)

# model architecture
parser.add_argument(
    '--hidden-dim',
    default=512, type=int,
    help='hidden dimension of RNN (default: 512)',
)
parser.add_argument(
    '--num-layers',
    default=2, type=int,
    help='number of RNN layers (default: 2)',
)
parser.add_argument(
    '--dropout',
    default=0, type=float,
    help='dropout (default: 0)',
)
parser.add_argument(
    '--untied',
    action='store_true',
    help='use biRNN with untied weights',
)

# optimisation
parser.add_argument(
    '--l2',
    default=0, type=float,
    help='l2 regularizer (default: 0)',
)
parser.add_argument(
    '--lr',
    default=1e-3, type=float,
    help='learning rate (default: 0.001)',
)
parser.add_argument(
    '--clip',
    default=1, type=float,
    help='gradient clipping max norm (default: 1)',
)

# device and output locations
parser.add_argument(
    '-d', '--device',
    default=-2, type=int,
    help='device to use, -1: cpu, 0+: gpu (default: gpu if available, else cpu)',
)
parser.add_argument(
    '-o', '--output',
    help='where to write training curve (default: stdout)',
)
parser.add_argument(
    '--save-prefix',
    help='path prefix for saving models (default: no saving)',
)


# FASTA files with the Pfam training and test sequences
pfam_train = 'data/pfam/Pfam-A.train.fasta'
pfam_test = 'data/pfam/Pfam-A.test.fasta'
42 |
43 |
def preprocess_sequence(s, alphabet):
    """Encode a sequence and surround it with start/stop tokens.

    The codes returned by ``alphabet.encode(s)`` are shifted up by one so
    that value 0 is free to act as the start/stop token placed at both ends
    of the returned array. The result has the dtype of the encoded sequence
    and is two elements longer than it.
    """
    encoded = alphabet.encode(s)
    padded = np.zeros(len(encoded) + 2, dtype=encoded.dtype)
    # interior positions hold the shifted codes; the two ends stay 0
    padded[1:-1] = encoded + 1
    return padded
50 |
def load_pfam(path, alph):
    """Load Pfam sequences and their family labels from a FASTA file.

    Each record yielded by ``fasta.parse_stream`` is upper-cased, encoded with
    ``alph`` and padded with start/stop tokens by ``preprocess_sequence``. The
    family label is the second-to-last ``;``-separated field of the record
    name (an ``IndexError`` is raised if the name has fewer than two fields).

    Returns:
        ``(group, sequences)`` where ``group`` is an array of family labels
        and ``sequences`` is a 1-D object array holding one encoded array per
        record, in file order.
    """
    group = []
    encoded = []
    with open(path, 'rb') as f:
        for name, sequence in fasta.parse_stream(f):
            encoded.append(preprocess_sequence(sequence.upper(), alph))
            group.append(name.split(b';')[-2])
    group = np.array(group)

    # The sequences have different lengths. np.array() on such a ragged list
    # raises ValueError on NumPy >= 1.24, so fill an object array element by
    # element; every entry keeps its own numeric dtype.
    sequences = np.empty(len(encoded), dtype=object)
    for i, x in enumerate(encoded):
        sequences[i] = x
    return group, sequences
64 |
65 |
def _score_batch(model, X, mask_idx):
    """Score one padded minibatch with the language model.

    Returns ``(loss, correct, count)``: ``loss`` is the mean negative
    log-probability of the input tokens over non-padding positions (a tensor,
    so it can be back-propagated), ``correct`` is the number of those
    positions whose arg-max prediction equals the input token and ``count``
    is the number of non-padding positions.
    """
    logp = model(X)

    # positions equal to mask_idx are padding and excluded from all metrics
    mask = (X != mask_idx)

    # padded positions are remapped to index 0 (presumably so gather() stays
    # within logp's last dimension); masked_select drops them again below
    index = X*mask.long()
    loss = -logp.gather(2, index.unsqueeze(2)).squeeze(2)
    loss = torch.mean(loss.masked_select(mask))

    _, y_hat = torch.max(logp, 2)
    correct = torch.sum((y_hat == X).masked_select(mask)).item()

    count = mask.long().sum().item()
    return loss, correct, count


def main():
    """Train the bidirectional language model on Pfam and log per-epoch metrics.

    Reads the command line from the module-level ``parser``, loads the train
    and test FASTA files, fits the model with Adam and gradient-norm clipping,
    and writes one tab-separated row per epoch and split (log-loss,
    perplexity, accuracy) to ``--output`` or stdout. When ``--save-prefix`` is
    given the whole model is saved after every epoch.

    The GPU is only selected when CUDA is actually available, so ``-d 0`` on a
    CPU-only machine falls back to the CPU instead of raising. An output file
    opened here is closed when training ends or fails.
    """
    args = parser.parse_args()

    alph = Uniprot21()
    ntokens = len(alph)

    ## load the training sequences
    _, X_train = load_pfam(pfam_train, alph)
    print('# loaded', len(X_train), 'sequences from', pfam_train, file=sys.stderr)

    ## load the testing sequences
    _, X_test = load_pfam(pfam_test, alph)
    print('# loaded', len(X_test), 'sequences from', pfam_test, file=sys.stderr)

    ## initialize the model
    nin = ntokens + 1
    nout = ntokens
    embedding_dim = 21
    hidden_dim = args.hidden_dim
    num_layers = args.num_layers
    mask_idx = ntokens
    dropout = args.dropout

    tied = not args.untied

    model = src.models.sequence.BiLM(nin, nout, embedding_dim, hidden_dim, num_layers
                                    , mask_idx=mask_idx, dropout=dropout, tied=tied)
    print('# initialized model', file=sys.stderr)

    device = args.device
    use_cuda = torch.cuda.is_available() and (device == -2 or device >= 0)
    # only touch the CUDA runtime when it is available
    if use_cuda and device >= 0:
        torch.cuda.set_device(device)
    if use_cuda:
        model = model.cuda()

    ## form the data iterators and optimizer
    lr = args.lr
    l2 = args.l2
    solver = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2)

    def collate(xs):
        # pad every sequence of the batch to the longest one with mask_idx,
        # ordering the rows from longest to shortest
        B = len(xs)
        N = max(len(x) for x in xs)
        lengths = np.array([len(x) for x in xs], dtype=int)

        order = np.argsort(lengths)[::-1]
        lengths = lengths[order]

        X = torch.LongTensor(B, N).zero_() + mask_idx
        for i in range(B):
            x = xs[order[i]]
            n = len(x)
            X[i,:n] = torch.from_numpy(x)
        return X, lengths

    mb = args.minibatch_size

    train_iterator = torch.utils.data.DataLoader(X_train, batch_size=mb, shuffle=True
                                                , collate_fn=collate)
    test_iterator = torch.utils.data.DataLoader(X_test, batch_size=mb
                                               , collate_fn=collate)

    ## fit the model!

    print('# training model', file=sys.stderr)

    num_epochs = args.num_epochs
    clip = args.clip

    save_prefix = args.save_prefix
    # width used to zero-pad epoch numbers; same value as floor(log10(n)) + 1
    # for n >= 1 but also defined for n == 0
    digits = len(str(num_epochs))

    output = sys.stdout
    if args.output is not None:
        output = open(args.output, 'w')

    try:
        print('epoch\tsplit\tlog_p\tperplexity\taccuracy', file=output)
        output.flush()

        for epoch in range(num_epochs):
            # train epoch
            model.train()
            it = 0
            n = 0
            accuracy = 0
            loss_accum = 0
            for X, _ in train_iterator:
                if use_cuda:
                    X = X.cuda()
                X = Variable(X)

                loss, correct, b = _score_batch(model, X, mask_idx)

                loss.backward()

                # clip the gradient
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

                solver.step()
                solver.zero_grad()

                # running means weighted by the number of unmasked positions
                n += b
                delta = b*(loss.item() - loss_accum)
                loss_accum += delta/n
                delta = correct - b*accuracy
                accuracy += delta/n

                batch = X.size(0)
                it += batch
                # report progress each time another 100 sequences are done
                if (it - batch)//100 < it//100:
                    print('# [{}/{}] training {:.1%} loss={:.5f}, acc={:.5f}'.format(
                              epoch+1, num_epochs, it/len(X_train), loss_accum, accuracy)
                          , end='\r', file=sys.stderr)
            print(' '*80, end='\r', file=sys.stderr)

            perplex = np.exp(loss_accum)
            string = str(epoch+1).zfill(digits) + '\t' + 'train' + '\t' + str(loss_accum) \
                     + '\t' + str(perplex) + '\t' + str(accuracy)
            print(string, file=output)
            output.flush()

            # test epoch
            model.eval()
            it = 0
            n = 0
            accuracy = 0
            loss_accum = 0
            with torch.no_grad():
                for X, _ in test_iterator:
                    if use_cuda:
                        X = X.cuda()
                    X = Variable(X)

                    loss, correct, b = _score_batch(model, X, mask_idx)

                    n += b
                    delta = b*(loss.item() - loss_accum)
                    loss_accum += delta/n
                    delta = correct - b*accuracy
                    accuracy += delta/n

                    batch = X.size(0)
                    it += batch
                    if (it - batch)//100 < it//100:
                        print('# [{}/{}] test {:.1%} loss={:.5f}, acc={:.5f}'.format(
                                  epoch+1, num_epochs, it/len(X_test), loss_accum, accuracy)
                              , end='\r', file=sys.stderr)
            print(' '*80, end='\r', file=sys.stderr)

            perplex = np.exp(loss_accum)
            string = str(epoch+1).zfill(digits) + '\t' + 'test' + '\t' + str(loss_accum) \
                     + '\t' + str(perplex) + '\t' + str(accuracy)
            print(string, file=output)
            output.flush()

            ## save the model
            if save_prefix is not None:
                save_path = save_prefix + '_epoch' + str(epoch+1).zfill(digits) + '.sav'
                # the model is moved to the CPU before pickling and back after
                model = model.cpu()
                torch.save(model, save_path)
                if use_cuda:
                    model = model.cuda()
    finally:
        if output is not sys.stdout:
            output.close()
256 |
257 |
258 |
# Call main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
--------------------------------------------------------------------------------
/eval_transmembrane.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import numpy as np
4 | import sys
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn.utils.rnn import PackedSequence
10 | import torch.utils.data
11 |
12 | from src.alphabets import Uniprot21
13 | from src.parse_utils import parse_3line
14 | import src.transmembrane as tm
15 |
def load_3line(path, alphabet):
    """Read a 3-line file and return its encoded sequences and labels.

    The file is parsed with ``parse_3line``; every sequence is encoded with
    ``alphabet.encode`` and every annotation string with
    ``tm.encode_labels``. Record names are discarded.
    """
    with open(path, 'rb') as f:
        _, sequences, annotations = parse_3line(f)
    encoded = [alphabet.encode(seq) for seq in sequences]
    labels = [tm.encode_labels(ann) for ann in annotations]
    return encoded, labels
22 |
def load_data():
    """Load the four TOPCONS2 datasets.

    Returns a dict mapping the dataset name to the ``(sequences, labels)``
    pair produced by ``load_3line`` for that file.
    """
    alphabet = Uniprot21()

    sources = [
        ('TM', 'data/transmembrane/TOPCONS2_datasets/TM.3line'),
        ('SP+TM', 'data/transmembrane/TOPCONS2_datasets/SP+TM.3line'),
        ('Globular', 'data/transmembrane/TOPCONS2_datasets/Globular.3line'),
        ('Globular+SP', 'data/transmembrane/TOPCONS2_datasets/Globular+SP.3line'),
    ]

    datasets = {}
    for name, path in sources:
        sequences, labels = load_3line(path, alphabet)
        datasets[name] = (sequences, labels)
    return datasets
42 |
def split_dataset(xs, ys, random=np.random, k=5):
    """Randomly deal paired examples into ``k`` folds.

    A permutation of the example indices is drawn from ``random`` and its
    i-th entry is assigned to fold ``i % k``, so fold sizes differ by at most
    one. Returns ``(x_splits, y_splits)``, two lists of ``k`` lists whose
    entries stay aligned.
    """
    order = random.permutation(len(xs))
    # order[fold::k] selects exactly the permutation positions i with
    # i % k == fold, in increasing i
    x_splits = [[xs[j] for j in order[fold::k]] for fold in range(k)]
    y_splits = [[ys[j] for j in order[fold::k]] for fold in range(k)]
    return x_splits, y_splits
54 |
def unstack_lstm(lstm):
    """Split a multi-layer bidirectional LSTM into single-layer LSTMs.

    One new ``nn.LSTM`` (bidirectional, batch-first) is created per layer of
    ``lstm`` and receives a copy of that layer's forward and reverse weights
    and biases. The first layer takes ``lstm.input_size`` inputs, later ones
    ``2 * lstm.hidden_size``. The layers are returned in order.
    """
    prefixes = ('weight_ih_l', 'weight_hh_l', 'bias_ih_l', 'bias_hh_l')
    directions = ('', '_reverse')

    hidden_dim = lstm.hidden_size
    in_size = lstm.input_size
    layers = []
    for depth in range(lstm.num_layers):
        single = nn.LSTM(in_size, hidden_dim, batch_first=True, bidirectional=True)
        for prefix in prefixes:
            for suffix in directions:
                # parameters of layer `depth` land in layer 0 of the new module
                target = getattr(single, prefix + '0' + suffix)
                source = getattr(lstm, prefix + str(depth) + suffix)
                target.data[:] = source
        layers.append(single)
        # every layer after the first consumes both directions of the previous
        in_size = 2*hidden_dim
    return layers
75 |
def featurize(x, lm_embed, lstm_stack, proj, include_lm=True, lm_only=False):
    """Concatenate one-hot, embedding and LSTM features along the last axis.

    The output always starts with a 21-way one-hot encoding of ``x``. The
    output of ``lm_embed`` follows when ``include_lm`` is set. Unless
    ``lm_only`` is set, the output of every LSTM in ``lstm_stack`` (applied in
    sequence) and finally ``proj`` of the last hidden state are appended.

    NOTE(review): ``proj`` is applied after ``squeeze(0)``, so this appears to
    expect a batch of exactly one sequence - confirm against callers.
    """
    one_hot = torch.zeros(x.size(0), x.size(1), 21, dtype=torch.float32, device=x.device)
    one_hot.scatter_(2, x.unsqueeze(2), 1)
    features = [one_hot]

    hidden = lm_embed(x)
    if include_lm:
        features.append(hidden)

    if lm_only:
        return torch.cat(features, 2)

    for lstm in lstm_stack:
        hidden, _ = lstm(hidden)
        features.append(hidden)
    projected = proj(hidden.squeeze(0)).unsqueeze(0)
    features.append(projected)
    return torch.cat(features, 2)
94 |
def featurize_dict(datasets, lm_embed, lstm_stack, proj, use_cuda=False, include_lm=True, lm_only=False):
    """Featurize every sequence of every dataset with ``featurize``.

    ``datasets`` maps a name to a tuple whose first element is a list of
    encoded sequences (numpy arrays). Each sequence is featurized on its own
    as a batch of one, without gradient tracking, and the result is moved
    back to the CPU. Returns a dict mapping each name to the list of feature
    tensors.
    """
    features = {}
    with torch.no_grad():
        for name, value in datasets.items():
            embedded = []
            for seq in value[0]:
                batch = torch.from_numpy(seq).long().unsqueeze(0)
                if use_cuda:
                    batch = batch.cuda()
                z = featurize(batch, lm_embed, lstm_stack, proj,
                              include_lm=include_lm, lm_only=lm_only)
                embedded.append(z.squeeze(0).cpu())
            features[name] = embedded
    return features
109 |
def featurize_one_hot_dict(datasets, n):
    """One-hot encode every sequence of every dataset.

    ``datasets`` maps a name to a tuple whose first element is a list of
    integer-encoded sequences (numpy arrays). Returns a dict mapping each
    name to a list of float tensors with one row per position and ``n``
    columns.
    """
    encoded = {}
    with torch.no_grad():
        for name, value in datasets.items():
            rows = []
            for seq in value[0]:
                idx = torch.from_numpy(seq).long()
                one_hot = torch.zeros(idx.size(0), n, dtype=torch.float32).to(idx.device)
                one_hot.scatter_(1, idx.unsqueeze(1), 1)
                rows.append(one_hot)
            encoded[name] = rows
    return encoded
123 |
def make_train_test(splits, j, k):
    """Build train/test data for cross-validation fold ``j`` out of ``k``.

    ``splits`` maps a dataset name to ``(x_folds, y_folds)``. The training
    lists pool every fold except ``j`` across all datasets, while the test
    data stays separated: dicts mapping each dataset name to its fold ``j``.

    Returns ``(x_train, y_train, x_test, y_test)``.
    """
    x_train, y_train = [], []
    for folds in splits.values():
        x_folds, y_folds = folds[0], folds[1]
        for fold in range(k):
            if fold == j:
                continue
            x_train += x_folds[fold]
            y_train += y_folds[fold]

    x_test = {}
    y_test = {}
    for name, folds in splits.items():
        x_test[name] = folds[0][j]
        y_test[name] = folds[1][j]

    return x_train, y_train, x_test, y_test
137 |
class ListDataset:
    """Indexable pairing of two parallel sequences of inputs and targets."""

    def __init__(self, x, y):
        # both containers are stored by reference, not copied
        self.x, self.y = x, y

    def __len__(self):
        """Return the number of inputs."""
        return len(self.x)

    def __getitem__(self, i):
        """Return the ``(input, target)`` pair at index ``i``."""
        item = self.x[i]
        target = self.y[i]
        return item, target
148 |
class LSTM(nn.Module):
    """Single-layer bidirectional LSTM followed by a per-position linear layer.

    Args:
        n_in: number of input features per position.
        n_hidden: hidden units per direction.
        n_out: number of outputs per position.
    """

    def __init__(self, n_in, n_hidden, n_out):
        super(LSTM, self).__init__()
        self.rnn = nn.LSTM(n_in, n_hidden, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(2*n_hidden, n_out)

    def forward(self, x):
        """Apply the network to ``x``.

        ``x`` may be a ``PackedSequence``, a ``(batch, length, n_in)`` tensor
        or a single unbatched ``(length, n_in)`` tensor. The output has the
        same container type and leading dimensions as the input, with
        ``n_out`` features per position.
        """
        if isinstance(x, PackedSequence):
            h, _ = self.rnn(x)
            z = self.linear(h.data)
            # Carry the (un)sorting indices over where this torch version has
            # them; without them sequences packed with enforce_sorted=False
            # would come back in the wrong order.
            extra = [getattr(h, name)
                     for name in ('sorted_indices', 'unsorted_indices')
                     if hasattr(h, name)]
            return PackedSequence(z, h.batch_sizes, *extra)

        unbatched = (x.dim() == 2)
        if unbatched:
            x = x.unsqueeze(0)
        h, _ = self.rnn(x)
        # reshape rather than view: a batch-first LSTM output is not
        # guaranteed to be contiguous when the batch holds several sequences
        z = self.linear(h.reshape(h.size(0)*h.size(1), -1))
        z = z.view(h.size(0), h.size(1), -1)
        if unbatched:
            z = z.squeeze(0)
        return z
170 |
def train(x_train, y_train, num_epochs=10, hidden_dim=100, use_cuda=False):
    """Fit a 4-output bidirectional LSTM tagger, one sequence per step.

    ``x_train`` holds one feature tensor per sequence (the feature dimension
    is read from the first one) and ``y_train`` the matching per-position
    labels. Training uses Adam (lr 3e-4) with cross-entropy loss, visiting
    the sequences in shuffled order every epoch. Returns the trained model.
    """
    n_features = x_train[0].size(1)

    model = LSTM(n_features, hidden_dim, 4)
    if use_cuda:
        model.cuda()

    loader = torch.utils.data.DataLoader(ListDataset(x_train, y_train),
                                         shuffle=True, batch_size=1)
    optim = torch.optim.Adam(model.parameters(), lr=3e-4)

    for _ in range(num_epochs):
        for x, y in loader:
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            # drop the batch dimension added by the loader
            logits = model(x).squeeze(0)
            loss = F.cross_entropy(logits, y.squeeze(0))
            loss.backward()
            optim.step()
            optim.zero_grad()

    return model
199 |
def evaluate_model(model, grammar, z_test, y_test, use_cuda=False):
    """Score Viterbi predictions per dataset and overall.

    For every key of ``z_test`` the predictions of
    ``grammar.predict_viterbi`` are compared with ``y_test[key]`` using
    ``tm.is_prediction_correct``; the per-dataset score is the mean of those
    results. Returns ``(overall, results)`` where ``overall`` is the
    unweighted mean of the per-dataset scores.
    """
    results = {}
    for key, features in z_test.items():
        predictions = grammar.predict_viterbi(features, model, use_cuda)
        hits = [tm.is_prediction_correct(pred, target)
                for pred, target in zip(predictions, y_test[key])]
        # one slot per prediction; slots without a matching target stay 0
        scores = np.zeros(len(predictions))
        scores[:len(hits)] = hits
        results[key] = scores.mean()
    overall = sum(results.values())/len(results)
    return overall, results
210 |
def evaluate_split(splits, j, k, num_epochs=10, hidden_dim=100, use_cuda=False, grammar=tm.Grammar()):
    """Train on every fold but ``j`` and evaluate on fold ``j``.

    Returns the ``(overall, results)`` pair produced by ``evaluate_model``.
    Note that the default ``grammar`` is created once, when the function is
    defined, and shared by all calls.
    """
    x_train, y_train, x_test, y_test = make_train_test(splits, j, k)
    model = train(
        x_train,
        y_train,
        num_epochs=num_epochs,
        hidden_dim=hidden_dim,
        use_cuda=use_cuda,
    )
    # switch to inference mode before decoding
    model.eval()
    return evaluate_model(model, grammar, x_test, y_test, use_cuda=use_cuda)
217 |
218 |
def main():
    """Cross-validate a transmembrane tagger on top of fixed sequence features.

    The positional ``model`` argument is either the literal ``1-hot`` (plain
    one-hot features) or the path of a saved model whose ``embedding``
    sub-module is used to featurize the sequences. The four datasets are each
    split into 10 folds with a fixed seed; for every fold a new tagger is
    trained and its per-dataset and overall scores are printed as a
    tab-separated row, followed by an ``All`` row with the means over folds.

    The GPU is only selected when CUDA is actually available, so ``-d 0`` on
    a CPU-only machine falls back to the CPU instead of raising.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('model', help='path to saved embedding model')
    parser.add_argument('--hidden-dim', type=int, default=150, help='dimension of LSTM (default: 150)')
    parser.add_argument('--num-epochs', type=int, default=10, help='number of training epochs (default: 10)')
    parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use')
    args = parser.parse_args()

    datasets = load_data()
    num_epochs = args.num_epochs
    hidden_dim = args.hidden_dim

    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    # only touch the CUDA runtime when it is available
    if use_cuda and d >= 0:
        torch.cuda.set_device(d)

    ## load the embedding model
    if args.model == '1-hot':
        print('# featurizing data', file=sys.stderr)
        z = featurize_one_hot_dict(datasets, 21)
    else:
        # NOTE: torch.load unpickles the file; only load trusted model files
        encoder = torch.load(args.model)
        encoder.eval()
        encoder = encoder.embedding

        lm_embed = encoder.embed
        lstm_stack = unstack_lstm(encoder.rnn)
        proj = encoder.proj

        if use_cuda:
            lm_embed.cuda()
            for lstm in lstm_stack:
                lstm.cuda()
            proj.cuda()

        ## featurize the sequences
        print('# featurizing data', file=sys.stderr)
        z = featurize_dict(datasets, lm_embed, lstm_stack, proj, use_cuda=use_cuda)

        # the encoder is not needed once the features are computed
        del lm_embed
        del lstm_stack
        del proj
        del encoder

    # swap the encoded sequences for their features, keeping the labels
    datasets = {k: (z[k], v[1]) for k, v in datasets.items()}

    ## split into folds
    random = np.random.RandomState(10)
    K = 10
    datasets_split = {k: split_dataset(v[0], v[1], random=random, k=K) for k, v in datasets.items()}

    ## train/test on each fold
    print('# training and evaluating with', K, 'folds', file=sys.stderr)
    print('# using', hidden_dim, 'LSTM units', file=sys.stderr)
    tags = ['TM', 'SP+TM', 'Globular', 'Globular+SP']
    print('\t'.join(['Fold'] + tags + ['Overall']))
    split_results = {key: [] for key in tags}
    split_overall = []
    for i in range(K):
        overall, results = evaluate_split(datasets_split, i, K,
                                          num_epochs=num_epochs, hidden_dim=hidden_dim,
                                          use_cuda=use_cuda)
        for key in tags:
            split_results[key].append(results[key])
        split_overall.append(overall)
        cols = [str(i)] + ['{:.5f}'.format(results[key]) for key in tags] + ['{:.5f}'.format(overall)]
        print('\t'.join(cols))

    results = {key: np.mean(values) for key, values in split_results.items()}
    overall = np.mean(split_overall)

    cols = ['All'] + ['{:.5f}'.format(results[key]) for key in tags] + ['{:.5f}'.format(overall)]
    print('\t'.join(cols))
301 |
302 |
# Call main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
305 |
306 |
307 |
308 |
--------------------------------------------------------------------------------
/eval_secstr.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.utils.rnn import PackedSequence
9 | import torch.utils.data
10 |
11 | from src.alphabets import Uniprot21, SecStr8
12 | from src.utils import pack_sequences, unpack_sequences
13 | import src.pdb as pdb
14 |
15 |
# Train/test files read by main() through load_secstr(); the paths are
# relative to the working directory.
secstr_train_path = 'data/secstr/ss_cullpdb_pc40_res3.0_R1.0_d180412_filtered.train.fa'
secstr_test_path = 'data/secstr/ss_cullpdb_pc40_res3.0_R1.0_d180412_filtered.test.fa'
18 |
19 |
def encode_sequence(x, alphabet):
    """Encode a text sequence with ``alphabet``.

    The string is converted to UTF-8 bytes and upper-cased before being
    handed to ``alphabet.encode``, whose result is returned.
    """
    as_bytes = x.encode('utf-8')
    return alphabet.encode(as_bytes.upper())
26 |
27 |
def load_secstr(path, alphabet, secstr):
    """Load names, encoded sequences and encoded secondary structure.

    The file is parsed with ``pdb.parse_secstr``. Amino acid sequences are
    upper-cased and encoded with ``alphabet``, secondary structure strings
    are upper-cased and encoded with ``secstr``. Returns
    ``(names, aa_seqs, ss_seqs)``.
    """
    with open(path, 'rb') as f:
        names, raw_aa, raw_ss = pdb.parse_secstr(f)
    aa_seqs = [alphabet.encode(seq.upper()) for seq in raw_aa]
    ss_seqs = [secstr.encode(seq.upper()) for seq in raw_ss]
    return names, aa_seqs, ss_seqs
34 |
35 |
def unstack_lstm(lstm):
    """Return the layers of a stacked bidirectional LSTM as separate modules.

    Every layer of ``lstm`` is copied into a new single-layer, batch-first,
    bidirectional ``nn.LSTM`` (forward and reverse weights and biases), so
    that the hidden states of each layer can be read individually.
    """
    hidden_dim = lstm.hidden_size
    # input width per layer: the raw input first, then both directions of
    # the layer below
    input_sizes = [lstm.input_size] + [2*hidden_dim]*(lstm.num_layers - 1)

    layers = []
    for i, in_size in enumerate(input_sizes[:lstm.num_layers]):
        layer = nn.LSTM(in_size, hidden_dim, batch_first=True, bidirectional=True)
        for attr in ['weight_ih_l', 'weight_hh_l', 'bias_ih_l', 'bias_hh_l']:
            for direction in ['', '_reverse']:
                copied = getattr(lstm, attr + str(i) + direction)
                getattr(layer, attr + '0' + direction).data[:] = copied
        layers.append(layer)
    return layers
53 |
def featurize(x, lm_embed, lstm_stack, proj, include_lm=True, lm_only=False):
    """Concatenate one-hot, embedding and LSTM features for each position.

    ``x`` is either a 1-D tensor of token indices or a ``PackedSequence`` of
    them; the result has the same container type. Per position the features
    are, in order: a 21-way one-hot of the token, the output of ``lm_embed``
    (when ``include_lm``), and - unless ``lm_only`` - the output of every
    LSTM in ``lstm_stack`` followed by ``proj`` of the last hidden state.
    """
    packed = isinstance(x, PackedSequence)
    batch_sizes = x.batch_sizes if packed else None
    tokens = x.data if packed else x

    def wrap(t):
        # re-attach the packing information when the input was packed
        return PackedSequence(t, batch_sizes) if packed else t

    def unwrap(t):
        return t.data if packed else t

    one_hot = torch.zeros(tokens.size(0), 21, dtype=torch.float32, device=tokens.device)
    one_hot.scatter_(1, tokens.unsqueeze(1), 1)
    features = [one_hot]

    h = lm_embed(wrap(tokens))
    if include_lm:
        features.append(unwrap(h))
    if not lm_only:
        for lstm in lstm_stack:
            h, _ = lstm(h)
            features.append(unwrap(h))
        features.append(proj(unwrap(h)))

    z = torch.cat(features, 1)
    return wrap(z)
90 |
class TorchModel:
    """Callable wrapper that embeds lists of encoded sequences with a model.

    With ``full_features`` the embedding sub-module of ``model`` is taken
    apart (input embedding, unstacked LSTM layers, projection) so that
    ``featurize`` can expose every hidden layer; otherwise ``model`` itself
    is applied to the packed batch.
    """

    def __init__(self, model, use_cuda, full_features=False):
        self.model = model
        self.use_cuda = use_cuda
        self.full_features = full_features
        if not full_features:
            return

        embedding = model.embedding
        self.lm_embed = embedding.embed
        self.lstm_stack = unstack_lstm(embedding.rnn)
        self.proj = embedding.proj
        if use_cuda:
            self.lm_embed.cuda()
            for lstm in self.lstm_stack:
                lstm.cuda()
            self.proj.cuda()

    def __call__(self, x):
        """Embed the numpy-encoded sequences in ``x``.

        The sequences are packed with ``pack_sequences``, embedded, and the
        result of ``unpack_sequences`` for the returned order is returned.
        """
        tensors = [torch.from_numpy(seq).long() for seq in x]

        packed, order = pack_sequences(tensors)
        if self.use_cuda:
            packed = packed.cuda()

        if self.full_features:
            z = featurize(packed, self.lm_embed, self.lstm_stack, self.proj)
        else:
            z = self.model(packed) # embed the sequences

        return unpack_sequences(z, order)
121 |
def kmer_features(xs, n, k):
    """Turn integer sequences into per-position k-mer indices.

    Every sequence is padded on both sides with ``k // 2`` copies of the
    extra symbol ``n`` and each window of ``k`` symbols is folded into a
    single base-``(n + 1)`` number with ``np.convolve``. For ``k == 1`` the
    input is returned unchanged.

    Returns ``(kmers, size)`` where ``size`` is the number of distinct k-mer
    indices: ``n`` for ``k == 1`` and ``(n + 1) ** k`` otherwise.
    """
    if k == 1:
        return xs, n

    base = n + 1
    pad = np.array([n] * (k // 2))
    place_values = base ** np.arange(k)

    kmers = [
        np.convolve(np.concatenate([pad, x, pad], axis=0), place_values, mode='valid')
        for x in xs
    ]
    return kmers, base ** k
133 |
134 |
class Shuffle:
    """Iterable over freshly shuffled minibatches of two aligned tensors.

    Each pass draws a new permutation from ``np.random`` and yields
    ``(x_batch, y_batch)`` slices of at most ``minibatch_size`` rows.
    """

    def __init__(self, x, y, minibatch_size):
        self.x = x
        self.y = y
        self.minibatch_size = minibatch_size

    def __iter__(self):
        total = len(self.x)
        perm = torch.from_numpy(np.random.permutation(total)).long()
        perm = perm.to(self.x.device)
        # shuffle once per pass, then hand out consecutive slices
        shuffled_x = self.x[perm]
        shuffled_y = self.y[perm]
        step = self.minibatch_size
        for start in range(0, total, step):
            stop = start + step
            yield shuffled_x[start:stop], shuffled_y[start:stop]
150 |
151 |
def fit_kmer_potentials(x, y, n, m):
    """Fit a table of log-probabilities of each label given each k-mer.

    Args:
        x: k-mer index of every position (integers in ``[0, n)``).
        y: label of every position (integers in ``[0, m)``), same length
            as ``x``.
        n: number of distinct k-mer indices.
        m: number of labels.

    Returns:
        An ``nn.Embedding(n, m, sparse=True)`` whose row ``i`` holds
        ``log P(label | k-mer i)``. The estimate is the co-occurrence count
        plus the overall label frequency as a pseudo-count, normalised per
        k-mer. A label that never occurs in ``y`` has frequency 0 and thus
        log-probability ``-inf`` everywhere.
    """
    x_idx = torch.as_tensor(np.asarray(x), dtype=torch.long).view(-1)
    y_np = np.asarray(y).reshape(-1).astype(np.int64)
    y_idx = torch.from_numpy(y_np)

    # bincount keeps one slot per label even when a label is absent from y,
    # so the prior always lines up with the m columns of the table
    counts = np.bincount(y_np, minlength=m)
    weights = torch.zeros(n, m)
    weights += torch.from_numpy(counts/counts.sum()).float()

    # add 1 for every observed (k-mer, label) pair; accumulate=True makes
    # repeated pairs add up
    weights.index_put_((x_idx, y_idx), torch.ones(x_idx.numel()), accumulate=True)

    model = nn.Embedding(n, m, sparse=True)
    model.weight.data[:] = torch.log(weights) - torch.log(weights.sum(1, keepdim=True))

    return model
163 |
164 |
def fit_nn_potentials(model, x, y, lr=0.001, num_epochs=10, minibatch_size=256
                     , use_cuda=False):
    """Train ``model`` to predict ``y`` from ``x`` with Adam and cross-entropy.

    The data is visited in shuffled minibatches every epoch. After each epoch
    one line is printed with the epoch number and the running mean loss, its
    exponential (perplexity) and the running accuracy. ``model`` is updated
    in place and left in training mode; nothing is returned.
    """
    solver = torch.optim.Adam(model.parameters(), lr=lr)
    batches = Shuffle(x, y, minibatch_size)

    model.train()
    for epoch in range(1, num_epochs + 1):
        seen = 0
        mean_loss = 0
        mean_acc = 0
        for xb, yb in batches:
            if use_cuda:
                xb = xb.cuda()
                yb = yb.cuda()
            logits = model(xb).view(xb.size(0), -1)
            loss = F.cross_entropy(logits, yb)

            loss.backward()
            solver.step()
            solver.zero_grad()

            _, predicted = logits.max(1)
            hits = torch.sum((predicted == yb).float()).item()

            # running means weighted by the minibatch size
            size = xb.size(0)
            seen += size
            mean_loss += size*(loss.item() - mean_loss)/seen
            mean_acc += (hits - size*mean_acc)/seen

        print('train', epoch, mean_loss, np.exp(mean_loss), mean_acc)
198 |
199 |
def main():
    """Train and evaluate a secondary structure predictor on fixed features.

    The positional ``features`` argument selects the input representation:
    ``1-mer``, ``3-mer`` or ``5-mer`` fit a count-based k-mer table, anything
    else is treated as the path of a saved embedding model whose outputs feed
    a small fully connected network. The test set is scored for 8-class and
    for 3-class (helix / sheet / coil) prediction and the perplexity and
    accuracy of both are printed. With ``-v N`` the first ``N`` test
    sequences are printed with their true and predicted structure.

    The GPU is only selected when CUDA is actually available, so ``-d 0`` on
    a CPU-only machine falls back to the CPU instead of raising.
    """
    import argparse
    parser = argparse.ArgumentParser('Script for evaluating similarity model on SCOP test set.')

    parser.add_argument('features', help='path to saved embedding model file or "1-", "3-", or "5-mer" for k-mer features')

    parser.add_argument('--num-epochs', type=int, default=10, help='number of epochs to train for (default: 10)')
    parser.add_argument('--all-hidden', action='store_true', help='use all hidden layers as features')

    parser.add_argument('-v', '--print-examples', default=0, type=int, help='number of examples to print (default: 0)')

    parser.add_argument('-o', '--output', help='output file path (default: stdout)')
    parser.add_argument('--save-prefix', help='path prefix for saving models')
    parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use')

    args = parser.parse_args()
    num_epochs = args.num_epochs

    ## load the data
    alphabet = Uniprot21()
    secstr = SecStr8

    names_train, x_train, y_train = load_secstr(secstr_train_path, alphabet, secstr)
    names_test, x_test, y_test = load_secstr(secstr_test_path, alphabet, secstr)

    # keep a readable copy of the test sequences for the example printout
    sequences_test = [''.join(alphabet[c] for c in x) for x in x_test]

    y_train = np.concatenate(y_train, 0)

    ## set the device
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    # only touch the CUDA runtime when it is available
    if use_cuda and d >= 0:
        torch.cuda.set_device(d)

    if args.features == '1-mer':
        n = len(alphabet)
        x_test = [x.astype(int) for x in x_test]
    elif args.features == '3-mer':
        x_train,n = kmer_features(x_train, len(alphabet), 3)
        x_test,_ = kmer_features(x_test, len(alphabet), 3)
    elif args.features == '5-mer':
        x_train,n = kmer_features(x_train, len(alphabet), 5)
        x_test,_ = kmer_features(x_test, len(alphabet), 5)
    else:
        # NOTE: torch.load unpickles the file; only load trusted model files
        features = torch.load(args.features)
        features.eval()

        if use_cuda:
            features.cuda()

        features = TorchModel(features, use_cuda, full_features=args.all_hidden)
        batch_size = 32 # batch size for featurizing sequences

        with torch.no_grad():
            z_train = []
            for i in range(0,len(x_train),batch_size):
                for z in features(x_train[i:i+batch_size]):
                    z_train.append(z.cpu().numpy())
            x_train = z_train

            z_test = []
            for i in range(0,len(x_test),batch_size):
                for z in features(x_test[i:i+batch_size]):
                    z_test.append(z.cpu().numpy())
            x_test = z_test

        n = x_train[0].shape[1]
        del features
        del z_train
        del z_test

    print('split', 'epoch', 'loss', 'perplexity', 'accuracy')

    if args.features.endswith('-mer'):
        x_train = np.concatenate(x_train, 0)
        model = fit_kmer_potentials(x_train, y_train, n, len(secstr))
    else:
        x_train = torch.cat([torch.from_numpy(x) for x in x_train], 0)
        # with --all-hidden the features stay on the CPU and are moved
        # minibatch by minibatch instead
        if use_cuda and not args.all_hidden:
            x_train = x_train.cuda()

        num_hidden = 1024
        model = nn.Sequential( nn.Linear(n, num_hidden)
                             , nn.ReLU()
                             , nn.Linear(num_hidden, num_hidden)
                             , nn.ReLU()
                             , nn.Linear(num_hidden, len(secstr))
                             )

        y_train = torch.from_numpy(y_train).long()
        if use_cuda:
            y_train = y_train.cuda()
            model.cuda()

        fit_nn_potentials(model, x_train, y_train, num_epochs=num_epochs, use_cuda=use_cuda)

    if use_cuda:
        model.cuda()
    model.eval()

    num_examples = args.print_examples
    if num_examples > 0:
        # taken before x_test / y_test are concatenated below
        names_examples = names_test[:num_examples]
        x_examples = x_test[:num_examples]
        y_examples = y_test[:num_examples]

    # 8-class index -> 3-class index: 0 helix, 1 sheet, 2 coil
    ss3_of_ss8 = np.array([0, 1, 1, 0, 0, 2, 2, 2], dtype=int)
    # A sums 8-class probabilities into 3-class probabilities (p3 = p8 @ A)
    A = np.zeros((8,3), dtype=np.float32)
    A[np.arange(8), ss3_of_ss8] = 1.0

    A = torch.from_numpy(A)
    # int64 is required both for indexing and as an nll_loss target
    I = torch.from_numpy(ss3_of_ss8).long()
    if use_cuda:
        A = A.cuda()
        I = I.cuda()

    seen = 0
    acc_8 = 0
    acc_3 = 0
    loss_8 = 0
    loss_3 = 0

    x_test = torch.cat([torch.from_numpy(x) for x in x_test], 0)
    y_test = torch.cat([torch.from_numpy(y).long() for y in y_test], 0)

    if use_cuda and not args.all_hidden:
        x_test = x_test.cuda()
        y_test = y_test.cuda()

    mb = 256
    with torch.no_grad():
        for i in range(0, len(x_test), mb):
            x = x_test[i:i+mb]
            y = y_test[i:i+mb]

            if use_cuda:
                x = x.cuda()
                y = y.cuda()

            potentials = model(x).view(x.size(0), -1)

            ## 8-class SS
            l = F.cross_entropy(potentials, y).item()
            _,y_hat = potentials.max(1)
            correct = torch.sum((y == y_hat).float()).item()

            # running means weighted by the minibatch size
            seen += x.size(0)
            delta = x.size(0)*(l - loss_8)
            loss_8 += delta/seen
            delta = correct - x.size(0)*acc_8
            acc_8 += delta/seen

            ## 3-class SS
            y = I[y]
            p = F.softmax(potentials, 1)
            p = torch.mm(p, A) # ss3 probabilities
            log_p = torch.log(p)
            l = F.nll_loss(log_p, y).item()
            _,y_hat = log_p.max(1)
            correct = torch.sum((y == y_hat).float()).item()

            delta = x.size(0)*(l - loss_3)
            loss_3 += delta/seen
            delta = correct - x.size(0)*acc_3
            acc_3 += delta/seen

    print('-', '-', '8-class', '-', '3-class', '-')
    print('split', 'perplexity', 'accuracy', 'perplexity', 'accuracy')
    print('test', np.exp(loss_8), acc_8, np.exp(loss_3), acc_3)

    if num_examples > 0:
        for i in range(num_examples):
            name = names_examples[i].decode('utf-8')
            x = x_examples[i]
            y = y_examples[i]

            seq = sequences_test[i]

            print('>' + name + ' sequence')
            print(seq)
            print('')

            ss = ''.join(secstr[c] for c in y)
            ss = ss.replace(' ', 'C')
            print('>' + name + ' secstr')
            print(ss)
            print('')

            x = torch.from_numpy(x)
            if use_cuda:
                x = x.cuda()
            potentials = model(x)
            _,y_hat = torch.max(potentials, 1)
            y_hat = y_hat.cpu().numpy()

            ss_hat = ''.join(secstr[c] for c in y_hat)
            ss_hat = ss_hat.replace(' ', 'C')
            print('>' + name + ' predicted')
            print(ss_hat)
            print('')
421 |
422 |
423 |
# Call main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
426 |
--------------------------------------------------------------------------------
/train_similarity.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import sys
6 |
7 | from scipy.stats import pearsonr, spearmanr
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.autograd import Variable
13 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
14 | import torch.utils.data
15 |
16 | from src.alphabets import Uniprot21
17 | import src.scop as scop
18 | from src.utils import pack_sequences, unpack_sequences
19 | from src.utils import PairedDataset, AllPairsDataset, collate_paired_sequences
20 | from src.utils import MultinomialResample
21 | import src.models.embedding
22 | import src.models.comparison
23 |
24 |
# Lower bound applied to class probabilities before taking their log, so that
# a saturated sigmoid cannot turn the reported loss into inf.
_LOG_PROB_FLOOR = 1e-12


def _score_pairs(model, x0, x1):
    """Embed both sides of a minibatch in one pass and score each aligned pair.

    ``x0`` and ``x1`` are equal-length lists of encoded sequences. They are
    concatenated, packed and pushed through ``model`` together; the first half
    of the unpacked embeddings is then scored against the second half with
    ``model.score``. Returns the per-pair scores stacked along dimension 0.
    """
    b = len(x0)
    x, order = pack_sequences(x0 + x1)
    x = PackedSequence(x.data, x.batch_sizes)
    z = unpack_sequences(model(x), order)  # embed the sequences
    z0, z1 = z[:b], z[b:]
    return torch.stack([model.score(z_a, z_b) for z_a, z_b in zip(z0, z1)], 0)


def _ordinal_predictions(logits):
    """Turn ordinal-regression logits into class probabilities and predictions.

    With ``s = sigmoid(logits)``, class ``k`` gets unnormalized mass
    ``[1, s][k] * [1 - s, 1][k]``; rows are then normalized to sum to one.

    Returns ``(p, y_hard, y_hat)``: the class distribution, its argmax, and
    its expected level.
    """
    s = torch.sigmoid(logits)
    ones = torch.ones_like(s[:, :1])
    p_ge = torch.cat([ones, s], 1)
    p_lt = torch.cat([1 - s, ones], 1)
    p = p_ge*p_lt
    p = p/p.sum(1, keepdim=True)  # make sure p is normalized

    _, y_hard = torch.max(p, 1)
    # build the level index in p's dtype/device so the product below does not
    # depend on implicit type promotion
    levels = torch.arange(p.size(1), dtype=p.dtype, device=p.device)
    y_hat = torch.sum(p*levels, 1)
    return p, y_hard, y_hat


def _ordinal_nll(p, y):
    """Mean negative log-likelihood of integer levels ``y`` under ``p``.

    ``p`` already holds probabilities, so the log is taken directly;
    ``F.cross_entropy`` would treat ``p`` as logits and apply a second softmax.
    """
    return F.nll_loss(torch.log(p.clamp(min=_LOG_PROB_FLOOR)), y)


def _train_epoch(model, optim, train_iterator, use_cuda, epoch, num_epochs, epoch_size):
    """Run one optimisation pass over ``train_iterator``.

    Returns running averages ``(loss, mse, accuracy)`` over the examples seen,
    where ``loss`` is the NLL of the true level under the predicted
    distribution (the optimised objective is the per-threshold BCE).
    """
    model.train()
    n = 0
    loss_estimate = 0
    mse_estimate = 0
    acc_estimate = 0

    for x0, x1, y in train_iterator:
        if use_cuda:
            y = y.cuda()
        b = len(x0)

        logits = _score_pairs(model, x0, x1)
        loss = F.binary_cross_entropy_with_logits(logits, y.float())
        loss.backward()

        optim.step()
        optim.zero_grad()
        model.clip()  # projected gradient for bounding ordinal regression parameters

        # monitoring only: no graph is needed for these statistics
        with torch.no_grad():
            p, y_hard, y_hat = _ordinal_predictions(logits)
            y_level = torch.sum(y, 1)
            nll = _ordinal_nll(p, y_level).item()
            correct = torch.sum((y_level == y_hard).float()).item()
            mse = torch.sum((y_level.float() - y_hat)**2).item()

        # incremental update of the per-example running means
        n += b
        loss_estimate += b*(nll - loss_estimate)/n
        acc_estimate += (correct - b*acc_estimate)/n
        mse_estimate += (mse - b*mse_estimate)/n

        # refresh the progress line each time another 100 examples are crossed
        if (n - b)//100 < n//100:
            print('# [{}/{}] training {:.1%} loss={:.5f}, mse={:.5f}, acc={:.5f}'.format(epoch+1
                                                                                        , num_epochs
                                                                                        , n/epoch_size
                                                                                        , loss_estimate
                                                                                        , mse_estimate
                                                                                        , acc_estimate
                                                                                        )
                 , end='\r', file=sys.stderr)
    print(' '*80, end='\r', file=sys.stderr)
    return loss_estimate, mse_estimate, acc_estimate


def _evaluate(model, test_iterator, use_cuda):
    """Score every test pair and return ``(loss, mse, accuracy, r, rho)``.

    ``r`` and ``rho`` are the Pearson and Spearman correlations between the
    expected predicted level and the true level.
    """
    model.eval()

    y = []
    logits = []
    with torch.no_grad():
        for x0, x1, y_mb in test_iterator:
            if use_cuda:
                y_mb = y_mb.cuda()
            y.append(y_mb.long())
            logits.append(_score_pairs(model, x0, x1))

        y = torch.cat(y, 0)
        logits = torch.cat(logits, 0)

        p, y_hard, y_hat = _ordinal_predictions(logits)
        loss = _ordinal_nll(p, y).item()
        accuracy = torch.mean((y == y_hard).float()).item()
        mse = torch.mean((y.float() - y_hat)**2).item()

    y = y.cpu().numpy()
    y_hat = y_hat.cpu().numpy()

    r, _ = pearsonr(y_hat, y)
    rho, _ = spearmanr(y_hat, y)
    return loss, mse, accuracy, r, rho


def main():
    """Train an ordinal-regression sequence similarity model on SCOP pairs.

    Parses the command line, loads the ASTRAL training sequences and the
    sampled test pairs from fixed paths under ``data/SCOPe``, builds the
    embedding + comparison model, and trains it with Adam. After every epoch
    one ``train`` and one ``test`` row are written to the tab-separated
    output table and, if ``--save-prefix`` is given, the model is saved.

    The ``loss`` column is the negative log-likelihood of the true similarity
    level under the predicted class distribution.
    """
    import argparse
    parser = argparse.ArgumentParser('Script for training embedding model on SCOP.')

    parser.add_argument('--dev', action='store_true', help='use train/dev split')

    parser.add_argument('-m', '--model', choices=['ssa', 'ua', 'me'], default='ssa', help='alignment scoring method for comparing sequences in embedding space [ssa: soft symmetric alignment, ua: uniform alignment, me: mean embedding] (default: ssa)')
    parser.add_argument('--allow-insert', action='store_true', help='model insertions (default: false)')

    parser.add_argument('--norm', choices=['l1', 'l2'], default='l1', help='comparison norm (default: l1)')

    parser.add_argument('--rnn-type', choices=['lstm', 'gru'], default='lstm', help='type of RNN block to use (default: lstm)')
    parser.add_argument('--embedding-dim', type=int, default=100, help='embedding dimension (default: 100)')
    parser.add_argument('--input-dim', type=int, default=512, help='dimension of input to RNN (default: 512)')
    parser.add_argument('--rnn-dim', type=int, default=512, help='hidden units of RNNs (default: 512)')
    parser.add_argument('--num-layers', type=int, default=3, help='number of RNN layers (default: 3)')
    parser.add_argument('--dropout', type=float, default=0, help='dropout probability (default: 0)')

    parser.add_argument('--epoch-size', type=int, default=100000, help='number of examples per epoch (default: 100,000)')
    parser.add_argument('--epoch-scale', type=int, default=5, help='scaling on epoch size (default: 5)')
    parser.add_argument('--num-epochs', type=int, default=100, help='number of epochs (default: 100)')

    parser.add_argument('--batch-size', type=int, default=64, help='minibatch size (default: 64)')

    parser.add_argument('--weight-decay', type=float, default=0, help='L2 regularization (default: 0)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')

    parser.add_argument('--tau', type=float, default=0.5, help='sampling proportion exponent (default: 0.5)')
    parser.add_argument('--augment', type=float, default=0, help='probability of resampling amino acid for data augmentation (default: 0)')
    parser.add_argument('--lm', help='pretrained LM to use as initial embedding')

    parser.add_argument('-o', '--output', help='output file path (default: stdout)')
    parser.add_argument('--save-prefix', help='path prefix for saving models')
    parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use')

    args = parser.parse_args()

    ## set the device
    # -1 forces CPU; any other value uses CUDA when available, and a
    # non-negative value additionally selects that device index
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    if d >= 0:
        torch.cuda.set_device(d)

    ## make the datasets
    astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.fa'
    astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.sampledpairs.txt'
    if args.dev:
        astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.train.fa'
        astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.sampledpairs.txt'

    alphabet = Uniprot21()

    print('# loading training sequences:', astral_train_path, file=sys.stderr)
    with open(astral_train_path, 'rb') as f:
        names_train, structs_train, sequences_train = scop.parse_astral(f, encoder=alphabet)
    x_train = [torch.from_numpy(x).long() for x in sequences_train]
    if use_cuda:
        x_train = [x.cuda() for x in x_train]
    y_train = torch.from_numpy(structs_train)

    print('# loaded', len(x_train), 'training sequences', file=sys.stderr)

    print('# loading test sequence pairs:', astral_testpairs_path, file=sys.stderr)
    test_pairs_table = pd.read_csv(astral_testpairs_path, sep='\t')
    x0_test = [x.encode('utf-8').upper() for x in test_pairs_table['sequence_A']]
    x0_test = [torch.from_numpy(alphabet.encode(x)).long() for x in x0_test]
    x1_test = [x.encode('utf-8').upper() for x in test_pairs_table['sequence_B']]
    x1_test = [torch.from_numpy(alphabet.encode(x)).long() for x in x1_test]
    if use_cuda:
        x0_test = [x.cuda() for x in x0_test]
        x1_test = [x.cuda() for x in x1_test]
    y_test = test_pairs_table['similarity'].values
    y_test = torch.from_numpy(y_test).long()

    dataset_test = PairedDataset(x0_test, x1_test, y_test)
    print('# loaded', len(x0_test), 'test pairs', file=sys.stderr)

    ## make the dataset iterators
    epoch_size = args.epoch_size
    batch_size = args.batch_size

    # precompute the similarity pairs: entry [i, j, k] is 1 only while every
    # label column up to and including k agrees between sequences i and j
    y_train_levels = torch.cumprod((y_train.unsqueeze(1) == y_train.unsqueeze(0)).long(), 2)

    # data augmentation by resampling amino acids
    augment = None
    p = 0
    if args.augment > 0:
        p = args.augment
        trans = torch.ones(len(alphabet), len(alphabet))
        trans = trans/trans.sum(1, keepdim=True)
        if use_cuda:
            trans = trans.cuda()
        augment = MultinomialResample(trans, p)
    print('# resampling amino acids with p:', p, file=sys.stderr)
    dataset_train = AllPairsDataset(x_train, y_train_levels, augment=augment)

    similarity = y_train_levels.numpy().sum(2)
    levels, counts = np.unique(similarity, return_counts=True)
    order = np.argsort(levels)
    levels = levels[order]
    counts = counts[order]

    print('#', levels, file=sys.stderr)
    print('#', counts/np.sum(counts), file=sys.stderr)

    # show what the level proportions would be under a few reference exponents
    for exponent in (0.5, 0.33, 0.25):
        weight = counts**exponent
        print('#', weight/np.sum(weight), file=sys.stderr)

    tau = args.tau
    print('# using tau:', tau, file=sys.stderr)
    print('#', counts**tau/np.sum(counts**tau), file=sys.stderr)
    # per-pair sampling weight, looked up by the pair's similarity level
    weights = counts**tau/counts
    weights = weights[similarity].ravel()
    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, epoch_size)

    train_iterator = torch.utils.data.DataLoader(dataset_train
                                                , batch_size=batch_size
                                                , sampler=sampler
                                                , collate_fn=collate_paired_sequences
                                                )
    test_iterator = torch.utils.data.DataLoader(dataset_test
                                               , batch_size=batch_size
                                               , collate_fn=collate_paired_sequences
                                               )

    ## initialize the model
    rnn_type = args.rnn_type
    rnn_dim = args.rnn_dim
    num_layers = args.num_layers

    embedding_size = args.embedding_dim
    input_dim = args.input_dim

    dropout = args.dropout

    allow_insert = args.allow_insert

    print('# initializing model with:', file=sys.stderr)
    print('# embedding_size:', embedding_size, file=sys.stderr)
    print('# input_dim:', input_dim, file=sys.stderr)
    print('# rnn_dim:', rnn_dim, file=sys.stderr)
    print('# num_layers:', num_layers, file=sys.stderr)
    print('# dropout:', dropout, file=sys.stderr)
    print('# allow_insert:', allow_insert, file=sys.stderr)

    compare_type = args.model
    print('# comparison method:', compare_type, file=sys.stderr)

    lm = None
    if args.lm is not None:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints
        lm = torch.load(args.lm)
        lm.eval()
        ## do not update the LM parameters
        for param in lm.parameters():
            param.requires_grad = False
        print('# using LM:', args.lm, file=sys.stderr)

    if num_layers > 0:
        embedding = src.models.embedding.StackedRNN(len(alphabet), input_dim, rnn_dim, embedding_size
                                                   , nlayers=num_layers, dropout=dropout, lm=lm)
    else:
        embedding = src.models.embedding.Linear(len(alphabet), input_dim, embedding_size, lm=lm)

    if args.norm == 'l1':
        norm = src.models.comparison.L1()
        print('# norm: l1', file=sys.stderr)
    elif args.norm == 'l2':
        norm = src.models.comparison.L2()
        print('# norm: l2', file=sys.stderr)
    model = src.models.comparison.OrdinalRegression(embedding, 5, align_method=compare_type
                                                   , compare=norm, allow_insertions=allow_insert
                                                   )

    if use_cuda:
        model.cuda()

    ## setup training parameters and optimizer
    num_epochs = args.num_epochs

    weight_decay = args.weight_decay
    lr = args.lr

    print('# training with Adam: lr={}, weight_decay={}'.format(lr, weight_decay), file=sys.stderr)
    params = [param for param in model.parameters() if param.requires_grad]
    optim = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)

    ## train the model
    print('# training model', file=sys.stderr)

    save_prefix = args.save_prefix
    # width used to zero-pad epoch numbers in the table and in file names
    digits = len(str(num_epochs))

    output = sys.stdout if args.output is None else open(args.output, 'w')
    try:
        line = '\t'.join(['epoch', 'split', 'loss', 'mse', 'accuracy', 'r', 'rho'])
        print(line, file=output)

        for epoch in range(num_epochs):
            epoch_tag = str(epoch+1).zfill(digits)

            # train epoch
            loss, mse, accuracy = _train_epoch(model, optim, train_iterator, use_cuda
                                              , epoch, num_epochs, epoch_size)
            line = '\t'.join([epoch_tag, 'train', str(loss), str(mse), str(accuracy), '-', '-'])
            print(line, file=output)
            output.flush()

            # eval
            loss, mse, accuracy, r, rho = _evaluate(model, test_iterator, use_cuda)
            line = '\t'.join([epoch_tag, 'test', str(loss), str(mse)
                             , str(accuracy), str(r), str(rho)])
            print(line, file=output)
            output.flush()

            # save the model (moved to the CPU for saving, then moved back)
            if save_prefix is not None:
                save_path = save_prefix + '_epoch' + epoch_tag + '.sav'
                model.cpu()
                torch.save(model, save_path)
                if use_cuda:
                    model.cuda()
    finally:
        # never close the interpreter's stdout, only a file we opened
        if output is not sys.stdout:
            output.close()
389 |
390 |
391 |
392 |
# Run the training entry point only when executed as a script.
if __name__ == "__main__":
    main()
395 |
396 |
397 |
398 |
399 |
400 |
--------------------------------------------------------------------------------
/eval_contact_casp12.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import numpy as np
4 | import sys
5 | import os
6 | import glob
7 | from PIL import Image
8 |
9 | import torch
10 | from torch.nn.utils.rnn import PackedSequence
11 | import torch.utils.data
12 |
13 | from src.alphabets import Uniprot21
14 | import src.fasta as fasta
15 | from src.utils import pack_sequences, unpack_sequences
16 | from src.utils import ContactMapDataset, collate_lists
17 | from src.metrics import average_precision
18 |
19 |
def load_data(seq_path, struct_path, alphabet, baselines=False):
    """Load domain sequences and their contact maps.

    Parameters:
        seq_path: path of the FASTA file holding the domain sequences.
        struct_path: iterable of contact-map image paths; the file name up to
            the first ``.`` is used as the protein id.
        alphabet: encoder whose ``encode`` is applied to each upper-cased
            sequence.
        baselines: when true, also load the baseline predictions through
            ``load_baselines`` and trim them to the same domain window.

    Returns:
        ``(x, y, names)`` or, with ``baselines=True``,
        ``(x, y, names, preds_domains)``. ``x`` is a list of encoded sequence
        tensors, ``y`` a list of float contact tensors in which ``-1`` marks
        entries to ignore, and ``preds_domains`` maps each baseline name to
        its list of per-domain logit matrices.

    Sequences without a matching contact map are reported on stderr and
    skipped.
    """
    pdb_index = {os.path.basename(path).split('.')[0]: path for path in struct_path}

    with open(seq_path, 'rb') as f:
        names, sequences = fasta.parse(f)
    names = [name.split()[0].decode('utf-8') for name in names]
    sequences = [alphabet.encode(s.upper()) for s in sequences]

    x = [torch.from_numpy(s).long() for s in sequences]

    preds = load_baselines() if baselines else {}
    preds_domains = {}

    names_ = []
    x_ = []
    y = []
    missing = 0
    for xi, name in zip(x, names):
        pid = name
        if pid not in pdb_index:
            print('MISSING:', pid, 'not in structures. Skipping.', file=sys.stderr)
            missing += 1
            continue
        path = pdb_index[pid]

        # np.asarray avoids a copy when possible but, unlike
        # np.array(..., copy=False) under NumPy 2, does not raise when one is
        # required
        im = np.asarray(Image.open(path))
        # pixel value 1 -> ignored (-1), 255 -> contact (1), anything else -> 0
        contacts = np.zeros(im.shape, dtype=np.float32)
        contacts[im == 1] = -1
        contacts[im == 255] = 1

        # trim to the domain: drop leading/trailing positions whose diagonal
        # entry is flagged as ignored
        start = 0
        while contacts[start, start] < 0:
            start += 1
        end = contacts.shape[0]
        while contacts[end-1, end-1] < 0:
            end -= 1
        xi = xi[start:end]
        contacts = contacts[start:end, start:end]

        if baselines:
            tag = name.split('-')[0]
            for key, value in preds.items():
                if tag not in value:
                    print(key, 'missing protein', tag, file=sys.stderr)
                    # a method without a prediction scores every pair -inf
                    logits = np.zeros(contacts.shape, dtype=np.float32) - np.inf
                else:
                    logits = value[tag]
                    logits = logits[start:end, start:end]
                preds_domains.setdefault(key, []).append(logits)

        # mask the lower triangle, the diagonal and the first superdiagonal
        mask = np.tril_indices(contacts.shape[0], k=1)
        contacts[mask] = -1

        names_.append(name)
        x_.append(xi)
        y.append(torch.from_numpy(contacts))

    print('Missing', missing, 'structures from', len(sequences), 'total.', file=sys.stderr)
    print('Reporting on', len(x_), 'structures.', file=sys.stderr)

    if baselines:
        return x_, y, names_, preds_domains

    return x_, y, names_
93 |
# Baseline contact predictors to report, in output order. Each key is the
# identifier matched in the prediction file names by load_baselines; each
# value is the display name used in the results table.
# NOTE(review): the keys look like CASP group numbers - confirm.
baselines = {
    "157": "GREMLIN (Baker)",
    "079": "iFold_1",
    "219": "Deepfold-Contact",
    "013": "MetaPSICOV",
    "451": "RaptorX-Contact",
}
100 |
101 | def load_baselines():
102 | all_preds = {}
103 | for key,value in baselines.items():
104 | path = 'data/casp12/predictions/*/*'+key+'_1'
105 | paths = glob.glob(path)
106 | preds = {}
107 | for path in paths:
108 | name = os.path.basename(path).split('RR')[0]
109 |
110 | with open(path, 'r') as f:
111 | for line in f:
112 | if line.startswith('MODEL'):
113 | break
114 | n = 0
115 | logits = None
116 | for line in f:
117 | if line.startswith('END'):
118 | break
119 | if line[0] not in '0123456789':
120 | n += len(line.strip())
121 | else:
122 | if logits is None:
123 | logits = np.zeros((n,n), dtype=np.float32) - np.inf
124 | i,j,_,_,p = line.strip().split()
125 | i = int(i) - 1
126 | j = int(j) - 1
127 | p = float(p)
128 | logits[i,j] = logits[j,i] = np.log(p) - np.log(1-p)
129 | preds[name] = logits
130 | all_preds[value] = preds
131 | return all_preds
132 |
def predict_minibatch(model, x, use_cuda):
    """Return one square matrix of contact logits per sequence in ``x``.

    The sequences are packed and embedded by ``model`` in a single call; each
    unpacked embedding is then passed on its own (with a leading batch
    dimension) to ``model.predict`` and the result is viewed as ``(L, L)``,
    where ``L`` is the first dimension of that embedding. ``use_cuda`` is
    accepted but not used here.
    """
    packed, order = pack_sequences(x)
    packed = PackedSequence(packed.data, packed.batch_sizes)
    embedded = unpack_sequences(model(packed), order)  # embed the sequences

    predictions = []
    for k in range(len(x)):
        emb = embedded[k]
        length = emb.size(0)
        predictions.append(model.predict(emb.unsqueeze(0)).view(length, length))
    return predictions
147 |
148 |
def calc_metrics(logits, y):
    """Return ``(precision, recall, F1, AUPR)`` for flat logits and labels.

    A pair is predicted positive when its logit is above zero. Precision
    defaults to 1.0 when nothing is predicted positive and F1 to 0 when
    precision + recall is not positive. AUPR comes from ``average_precision``.
    """
    predicted = (logits > 0).astype(np.float32)
    true_positives = (predicted*y).sum()

    n_predicted = predicted.sum()
    precision = true_positives/n_predicted if n_predicted > 0 else 1.0
    recall = true_positives/y.sum()

    denominator = precision + recall
    F1 = 2*precision*recall/denominator if denominator > 0 else 0

    return precision, recall, F1, average_precision(y, logits)
161 |
162 |
def _baseline_range_scores(logits, y, lengths, band=None):
    """Compute the seven per-protein metrics for one baseline and one range.

    Entries of ``y[i]`` below zero are ignored; when ``band`` is given, pairs
    with ``j - i <= band`` are ignored as well. Returns a float array of shape
    ``(7, len(y))`` whose rows are precision, recall, F1, AUPR and the
    precision among the top ``L``, ``L//2`` and ``L//5`` ranked pairs, with
    ``L = lengths[i]``.
    """
    scores = np.zeros((7, len(y)))
    for i in range(len(y)):
        yi = y[i]
        keep = ~(yi < 0)
        if band is not None:
            near = np.zeros((len(yi), len(yi)), dtype=bool)
            near[np.tril_indices(len(yi), k=band)] = True
            keep = keep & ~near

        y_flat = yi[keep]
        logits_flat = logits[i][keep]

        scores[:4, i] = calc_metrics(logits_flat, y_flat)

        # rank the pairs from most to least confident
        order = np.argsort(logits_flat)[::-1]
        n = lengths[i]
        for row, cutoff in zip((4, 5, 6), (n, n//2, n//5)):
            # an empty selection (e.g. n < 5) yields nan, as numpy's mean does
            scores[row, i] = y_flat[order[:cutoff]].mean()
    return scores


def _print_baseline_scores(label, key, names, scores, output, individual):
    """Write one row per protein, or one averaged row, for ``key``."""
    if individual:
        template = label + '\t{}\t{}' + '\t{:.5f}'*7
        for i in range(scores.shape[1]):
            print(template.format(key, names[i], *scores[:, i]), file=output)
    else:
        template = label + '\t{}' + '\t{:.5f}'*7
        print(template.format(key, *[row.mean() for row in scores]), file=output)
    output.flush()


def calc_baselines(baselines, y, lengths, names, output=sys.stdout, individual=False):
    """Write the contact-prediction metrics of every baseline method.

    Parameters:
        baselines: dict mapping method name to its list of per-protein logit
            matrices, aligned with ``y``.
        y: list of contact tensors; negative entries are ignored.
        lengths: per-protein sequence lengths used for the Precision@L cuts.
        names: per-protein names, used when ``individual`` is true.
        output: writable text stream receiving the tab-separated table.
        individual: write one row per protein instead of the mean over
            proteins.

    For each method an ``All`` section (every unmasked pair) and a
    ``Medium/Long`` section (additionally dropping pairs with ``j - i <= 11``)
    are written.
    """
    columns = ['Distance', 'Method']
    if individual:
        columns.append('Protein')
    columns += ['Precision', 'Recall', 'F1', 'AUPR', 'Precision@L', 'Precision@L/2', 'Precision@L/5']
    print('\t'.join(columns), file=output)
    output.flush()

    y = [y_.cpu().numpy() for y_ in y]

    # calculate performance metrics
    for key, logits in baselines.items():
        for label, band in (('All', None), ('Medium/Long', 11)):
            scores = _baseline_range_scores(logits, y, lengths, band)
            _print_baseline_scores(label, key, names, scores, output, individual)
271 |
272 |
def _model_range_scores(logits, y, lengths, band=None):
    """Compute the seven per-protein metrics of the model for one range.

    Entries of ``y[i]`` below zero are ignored; when ``band`` is given, pairs
    with ``j - i <= band`` are ignored as well. Returns a float array of shape
    ``(7, len(y))`` whose rows are precision, recall, F1, AUPR and the
    precision among the top ``L``, ``L//2`` and ``L//5`` ranked pairs, with
    ``L = lengths[i]``.
    """
    scores = np.zeros((7, len(y)))
    for i in range(len(y)):
        yi = y[i]
        keep = ~(yi < 0)
        if band is not None:
            near = np.zeros((len(yi), len(yi)), dtype=bool)
            near[np.tril_indices(len(yi), k=band)] = True
            keep = keep & ~near

        y_flat = yi[keep]
        logits_flat = logits[i][keep]

        scores[:4, i] = calc_metrics(logits_flat, y_flat)

        # rank the pairs from most to least confident
        order = np.argsort(logits_flat)[::-1]
        n = lengths[i]
        for row, cutoff in zip((4, 5, 6), (n, n//2, n//5)):
            # an empty selection (e.g. n < 5) yields nan, as numpy's mean does
            scores[row, i] = y_flat[order[:cutoff]].mean()
    return scores


def _print_model_scores(label, names, scores, output, individual):
    """Write one row per protein, or one averaged row, for the model."""
    if individual:
        template = label + '\t{}' + '\t{:.5f}'*7
        for i in range(scores.shape[1]):
            print(template.format(names[i], *scores[:, i]), file=output)
    else:
        template = label + '\t{:.5f}'*7
        print(template.format(*[row.mean() for row in scores]), file=output)
    output.flush()


def main():
    """Evaluate contact prediction on the CASP12 domains.

    With ``model`` set to the literal string ``baselines``, the baseline
    predictions are scored through ``calc_baselines`` and the process exits
    with status 0. Otherwise the saved model is loaded, contact logits are
    predicted for every domain, and an ``All`` and a ``Medium/Long`` section
    (pairs with ``j - i <= 11`` dropped) are written to the tab-separated
    output, either per protein (``--individual``) or averaged over proteins.
    """
    import argparse
    parser = argparse.ArgumentParser('Script for evaluating contact map models.')
    parser.add_argument('model', help='path to saved model')
    parser.add_argument('--batch-size', default=10, type=int, help='number of sequences to process in each batch (default: 10)')
    parser.add_argument('-o', '--output', help='output file path (default: stdout)')
    parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use')
    parser.add_argument('--individual', action='store_true')
    args = parser.parse_args()

    # load the data
    fasta_path = 'data/casp12/casp12.fm-domains.seq.fa'
    contact_paths = glob.glob('data/casp12/domains_T0/*.png')

    alphabet = Uniprot21()
    baselines = None
    if args.model == 'baselines':
        x, y, names, baselines = load_data(fasta_path, contact_paths, alphabet, baselines=True)
    else:
        x, y, names = load_data(fasta_path, contact_paths, alphabet)

    if baselines is not None:
        output = sys.stdout if args.output is None else open(args.output, 'w')
        try:
            lengths = np.array([len(x_) for x_ in x])
            calc_baselines(baselines, y, lengths, names, output=output, individual=args.individual)
        finally:
            # never close the interpreter's stdout, only a file we opened
            if output is not sys.stdout:
                output.close()

        sys.exit(0)

    ## set the device
    # -1 forces CPU; any other value uses CUDA when available, and a
    # non-negative value additionally selects that device index
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    if d >= 0:
        torch.cuda.set_device(d)

    if use_cuda:
        x = [x_.cuda() for x_ in x]
        y = [y_.cuda() for y_ in y]

    # NOTE: torch.load unpickles the file; only load trusted checkpoints
    model = torch.load(args.model)
    model.eval()
    if use_cuda:
        model.cuda()

    # predict contact maps
    batch_size = args.batch_size
    dataset = ContactMapDataset(x, y)
    iterator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_lists)
    logits = []
    with torch.no_grad():
        for xmb, ymb in iterator:
            logits += predict_minibatch(model, xmb, use_cuda)

    # calculate performance metrics
    lengths = np.array([len(x_) for x_ in x])
    logits = [logit.cpu().numpy() for logit in logits]
    y = [y_.cpu().numpy() for y_ in y]

    output = sys.stdout if args.output is None else open(args.output, 'w')
    try:
        columns = ['Distance']
        if args.individual:
            columns.append('Protein')
        columns += ['Precision', 'Recall', 'F1', 'AUPR', 'Precision@L', 'Precision@L/2', 'Precision@L/5']
        print('\t'.join(columns), file=output)
        output.flush()

        for label, band in (('All', None), ('Medium/Long', 11)):
            scores = _model_range_scores(logits, y, lengths, band)
            _print_model_scores(label, names, scores, output, args.individual)
    finally:
        if output is not sys.stdout:
            output.close()
443 |
444 |
445 |
# Run the evaluation entry point only when executed as a script.
if __name__ == "__main__":
    main()
448 |
449 |
450 |
451 |
452 |
453 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
409 |
--------------------------------------------------------------------------------
/src/alignment.pyx:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | cimport numpy as np
4 | import numpy as np
5 | cimport cython
6 |
@cython.boundscheck(False)
@cython.wraparound(False)
def half_global_alignment(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y,
                          np.ndarray[np.int32_t, ndim=2] S,
                          np.int32_t gap, np.int32_t extend):
    """Align all of x against the best-scoring substring of y.

    Affine-gap dynamic program with three states per cell:
    0 = x[i] paired with y[j], 1 = x[i] consumed alone, 2 = y[j] consumed
    alone. Row 0 of state 2 is left at zero, so skipping a prefix of y is
    free; the end point is the best cell of the last row, so the rest of y
    does not have to be consumed either.

    x, y   -- sequences of indices into S
    S      -- substitution scores, looked up as S[x[i], y[j]]
    gap    -- score added when a gap is opened
    extend -- score added when a gap is extended

    Returns (score, start, end): the best score found in the last row, the
    column of y at which the traceback stops and the column of y at which
    the alignment ends.
    """
    cdef int n = len(x)
    cdef int m = len(y)

    cdef np.float32_t NEG_INF = -np.inf
    cdef np.ndarray[np.float32_t, ndim=3] dp = np.zeros((n+1, m+1, 3), dtype=np.float32)
    # -1 marks a cell without predecessor; the traceback stops there
    cdef np.ndarray[np.int8_t, ndim=3] back = np.zeros((n+1, m+1, 3), dtype=np.int8) - 1

    # first row: only state 2 keeps its zero (free prefix of y)
    dp[0, 1:, 0] = NEG_INF
    dp[0, 1:, 1] = NEG_INF
    # first column: x can only be consumed alone, paying an affine gap
    dp[1:, 0, 0] = NEG_INF
    dp[1:, 0, 1] = gap + np.arange(n)*extend
    dp[1:, 0, 2] = NEG_INF

    cdef np.float32_t best, cand
    cdef int i, j
    cdef np.int8_t state

    for i in range(n):
        for j in range(m):
            # state 0: pair x[i] with y[j], arriving from any state at (i, j)
            state = 0
            best = dp[i, j, 0]
            cand = dp[i, j, 1]
            if cand > best:
                state = 1
                best = cand
            cand = dp[i, j, 2]
            if cand > best:
                state = 2
                best = cand
            dp[i+1, j+1, 0] = best + S[x[i], y[j]]
            back[i+1, j+1, 0] = state

            # state 1: consume x[i] alone, arriving from the cell above
            state = 0
            best = dp[i, j+1, 0] + gap
            cand = dp[i, j+1, 1] + extend
            if cand > best:
                state = 1
                best = cand
            cand = dp[i, j+1, 2] + gap
            if cand > best:
                state = 2
                best = cand
            dp[i+1, j+1, 1] = best
            back[i+1, j+1, 1] = state

            # state 2: consume y[j] alone, arriving from the cell to the left
            state = 0
            best = dp[i+1, j, 0] + gap
            cand = dp[i+1, j, 1] + gap
            if cand > best:
                state = 1
                best = cand
            cand = dp[i+1, j, 2] + extend
            if cand > best:
                state = 2
                best = cand
            dp[i+1, j+1, 2] = best
            back[i+1, j+1, 2] = state

    # the alignment may end at any column of the last row, in any state;
    # strict comparison keeps the first maximum encountered
    cdef int j_end = 0
    cdef int k_end = 0
    cdef np.float32_t top = dp[n, 0, 0]
    for j in range(m+1):
        for state in range(3):
            if dp[n, j, state] > top:
                top = dp[n, j, state]
                j_end = j
                k_end = state

    # walk the back-pointers to find the column where the alignment starts
    cdef int previous
    i = n
    j = j_end
    state = k_end
    while back[i, j, state] >= 0:
        previous = back[i, j, state]
        if state != 2:
            i -= 1      # states 0 and 1 consume a symbol of x
        if state != 1:
            j -= 1      # states 0 and 2 consume a symbol of y
        state = previous

    return top, j, j_end
105 |
cdef int argmax(np.float32_t a, np.float32_t b, np.float32_t c):
    # Position (0, 1 or 2) of the largest argument. Every comparison is
    # strict, so on ties the earliest argument wins.
    cdef int winner = 0
    cdef np.float32_t top = a
    if b > top:
        winner = 1
        top = b
    if c > top:
        winner = 2
    return winner
114 |
@cython.boundscheck(False)
@cython.wraparound(False)
def nw_score_extra(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y, np.ndarray[np.int32_t, ndim=2] S
                  , np.int32_t gap, np.int32_t extend):
    """Global affine-gap alignment score plus match count and length.

    Each cell holds three states: 0 = x[i] paired with y[j], 1 = x[i]
    consumed alone, 2 = y[j] consumed alone. Only two rows of every table
    are kept at a time.

    x, y   -- sequences of indices into S
    S      -- substitution scores, looked up as S[x[i], y[j]]
    gap    -- score added when a gap is opened
    extend -- score added when a gap is extended

    Returns (score, matches, length) for the best state of the final cell:
    the alignment score, the number of paired positions with x[i] == y[j]
    and the number of alignment columns.
    """
    cdef int n,m
    n = len(x)
    m = len(y)

    cdef np.float32_t MI = -np.inf

    cdef np.ndarray[np.float32_t, ndim=2] A_prev = np.zeros((m+1,3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((m+1,3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] A_temp

    A[1:,0] = MI # match scores
    A[1:,1] = MI # gap in x
    A[1:,2] = gap + np.arange(m)*extend # gap in y

    # also calculate the length of the alignment
    cdef np.ndarray[np.int32_t, ndim=2] L_prev = np.zeros((m+1,3), dtype=np.int32)
    cdef np.ndarray[np.int32_t, ndim=2] L = np.zeros((m+1,3), dtype=np.int32)
    cdef np.ndarray[np.int32_t, ndim=2] L_temp

    # a leading run of j gap columns already has length j; without this the
    # leading gaps were left out of the length while trailing ones counted
    L[1:,2] = np.arange(m) + 1

    # and the number of exact matches within the alignment
    cdef np.ndarray[np.int32_t, ndim=2] M_prev = np.zeros((m+1,3), dtype=np.int32)
    cdef np.ndarray[np.int32_t, ndim=2] M = np.zeros((m+1,3), dtype=np.int32)
    cdef np.ndarray[np.int32_t, ndim=2] M_temp

    cdef np.float32_t s, c0, c1, c2
    cdef int i,j,k

    for i in range(n):
        # swap A and A_prev
        A_temp = A_prev
        A_prev = A
        A = A_temp

        L_temp = L_prev
        L_prev = L
        L = L_temp

        M_temp = M_prev
        M_prev = M
        M = M_temp

        # init A[0]
        A[0,0] = MI
        A[0,1] = gap + i*extend
        A[0,2] = MI
        # the first column is a run of i+1 gap columns
        L[0,1] = i + 1
        for j in range(m):
            # match i,j
            s = S[x[i],y[j]]
            k = argmax(A_prev[j,0], A_prev[j,1], A_prev[j,2])
            A[j+1,0] = s + A_prev[j,k]
            L[j+1,0] = L_prev[j,k] + 1
            if x[i] == y[j]:
                M[j+1,0] = M_prev[j,k] + 1
            else:
                M[j+1,0] = M_prev[j,k]

            # insert in x
            c0 = A_prev[j+1,0] + gap
            c1 = A_prev[j+1,1] + extend
            c2 = A_prev[j+1,2] + gap
            k = argmax(c0, c1, c2)
            A[j+1,1] = max(c0, c1, c2)
            L[j+1,1] = L_prev[j+1,k] + 1
            M[j+1,1] = M_prev[j+1,k]

            # insert in y
            c0 = A[j,0] + gap
            c1 = A[j,1] + gap
            c2 = A[j,2] + extend
            k = argmax(c0, c1, c2)
            A[j+1,2] = max(c0, c1, c2)
            L[j+1,2] = L[j,k] + 1
            M[j+1,2] = M[j,k]

    k = argmax(A[m,0], A[m,1], A[m,2])

    return A[m,k], M[m,k], L[m,k]
192 |
193 |
@cython.boundscheck(False)
@cython.wraparound(False)
def sw_score_subst_no_affine(np.ndarray[np.float32_t,ndim=2] subst,
                             np.float32_t gap):
    """Best local-alignment score for a precomputed substitution matrix.

    subst -- subst[i, j] is the score of pairing position i with position j
    gap   -- score added for every gap position (no separate open/extend)

    Every cell is floored at zero, and the result is the largest value
    anywhere in the table.
    """
    cdef int rows = subst.shape[0]
    cdef int cols = subst.shape[1]

    cdef np.ndarray[np.float32_t, ndim=2] H = np.zeros((rows+1, cols+1), dtype=np.float32)

    cdef int r, c
    for r in range(rows):
        for c in range(cols):
            H[r+1, c+1] = max(subst[r, c] + H[r, c],
                              gap + H[r+1, c],
                              gap + H[r, c+1],
                              0)

    return np.max(H)
213 |
@cython.boundscheck(False)
@cython.wraparound(False)
def sw_score_subst(np.ndarray[np.float32_t,ndim=2] subst,
                   np.float32_t gap, np.float32_t extend):
    """Best local-alignment score with affine gaps for a substitution matrix.

    subst  -- subst[i, j] is the score of pairing position i with position j
    gap    -- score added when a gap is opened
    extend -- score added when a gap is extended

    Each cell holds three states (0 = pair, 1 = row position consumed alone,
    2 = column position consumed alone). All three are floored at zero and
    the result is the largest value anywhere in the table.
    """
    cdef int n = subst.shape[0]
    cdef int m = subst.shape[1]

    cdef np.float32_t NEG_INF = -np.inf
    cdef np.ndarray[np.float32_t, ndim=3] H = np.zeros((n+1, m+1, 3), dtype=np.float32)

    # first row
    H[0, 1:, 0] = NEG_INF
    H[0, 1:, 1] = NEG_INF
    # first column
    H[1:, 0, 0] = NEG_INF
    H[1:, 0, 2] = NEG_INF

    cdef np.float32_t pair, result
    cdef int i, j

    for i in range(n):
        for j in range(m):
            pair = subst[i, j]
            # state 0: pair i with j
            H[i+1, j+1, 0] = pair + max(H[i, j, 0], H[i, j, 1], H[i, j, 2], 0)
            # state 1: arrive from the cell above
            H[i+1, j+1, 1] = max(H[i, j+1, 0]+gap, H[i, j+1, 1]+extend, H[i, j+1, 2]+gap, 0)
            # state 2: arrive from the cell to the left
            H[i+1, j+1, 2] = max(H[i+1, j, 0]+gap, H[i+1, j, 1]+gap, H[i+1, j, 2]+extend, 0)

    result = np.max(H)

    return result
250 |
251 |
@cython.boundscheck(False)
@cython.wraparound(False)
def sw_score(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y,
             np.ndarray[np.int32_t, ndim=2] S,
             np.int32_t gap, np.int32_t extend):
    """Best local-alignment score of x and y with affine gaps.

    x, y   -- sequences of indices into S
    S      -- substitution scores, looked up as S[x[i], y[j]]
    gap    -- score added when a gap is opened
    extend -- score added when a gap is extended

    Each cell holds three states (0 = x[i] paired with y[j], 1 = x[i]
    consumed alone, 2 = y[j] consumed alone). All three are floored at zero
    and the result is the largest value anywhere in the table.
    """
    cdef int n = len(x)
    cdef int m = len(y)

    cdef np.float32_t NEG_INF = -np.inf
    cdef np.ndarray[np.float32_t, ndim=3] H = np.zeros((n+1, m+1, 3), dtype=np.float32)

    # first row
    H[0, 1:, 0] = NEG_INF
    H[0, 1:, 1] = NEG_INF
    # first column
    H[1:, 0, 0] = NEG_INF
    H[1:, 0, 2] = NEG_INF

    cdef np.float32_t pair, result
    cdef int i, j

    for i in range(n):
        for j in range(m):
            pair = S[x[i], y[j]]
            # state 0: pair x[i] with y[j]
            H[i+1, j+1, 0] = pair + max(H[i, j, 0], H[i, j, 1], H[i, j, 2], 0)
            # state 1: arrive from the cell above
            H[i+1, j+1, 1] = max(H[i, j+1, 0]+gap, H[i, j+1, 1]+extend, H[i, j+1, 2]+gap, 0)
            # state 2: arrive from the cell to the left
            H[i+1, j+1, 2] = max(H[i+1, j, 0]+gap, H[i+1, j, 1]+gap, H[i+1, j, 2]+extend, 0)

    result = np.max(H)

    return result
290 |
291 |
@cython.boundscheck(False)
@cython.wraparound(False)
def nw_score(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y,
             np.ndarray[np.int32_t, ndim=2] S,
             np.int32_t gap, np.int32_t extend):
    """Global alignment score of x and y with affine gaps.

    Only two rows of the score table are kept at a time. Each cell holds
    three states: 0 = x[i] paired with y[j], 1 = x[i] consumed alone,
    2 = y[j] consumed alone.

    x, y   -- sequences of indices into S
    S      -- substitution scores, looked up as S[x[i], y[j]]
    gap    -- score added when a gap is opened
    extend -- score added when a gap is extended

    Returns the best score over the three states of the final cell.
    """
    cdef int n = len(x)
    cdef int m = len(y)

    cdef np.float32_t NEG_INF = -np.inf

    cdef np.ndarray[np.float32_t, ndim=2] above = np.zeros((m+1, 3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] row = np.zeros((m+1, 3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] spare

    # row 0: a prefix of y can only be consumed alone
    row[1:, 0] = NEG_INF
    row[1:, 1] = NEG_INF
    row[1:, 2] = gap + np.arange(m)*extend

    cdef np.float32_t pair, result
    cdef int i, j

    for i in range(n):
        # the finished row becomes "above"; the older buffer is reused
        spare = above
        above = row
        row = spare
        # column 0: a prefix of x can only be consumed alone
        row[0, 0] = NEG_INF
        row[0, 1] = gap + i*extend
        row[0, 2] = NEG_INF
        for j in range(m):
            pair = S[x[i], y[j]]
            # state 0: pair x[i] with y[j]
            row[j+1, 0] = pair + max(above[j, 0], above[j, 1], above[j, 2])
            # state 1: arrive from the row above
            row[j+1, 1] = max(above[j+1, 0]+gap, above[j+1, 1]+extend, above[j+1, 2]+gap)
            # state 2: arrive from the cell to the left
            row[j+1, 2] = max(row[j, 0]+gap, row[j, 1]+gap, row[j, 2]+extend)

    result = max(row[m, 0], row[m, 1], row[m, 2])

    return result
337 |
@cython.boundscheck(False)
@cython.wraparound(False)
def nw_score_subst(np.ndarray[np.float32_t, ndim=2] S,
                   np.int32_t gap, np.int32_t extend):
    """Global affine-gap alignment score for a precomputed score matrix.

    S      -- S[i, j] is the score of pairing position i with position j
    gap    -- score added when a gap is opened (integer)
    extend -- score added when a gap is extended (integer)

    Two rows of the table are kept at a time; each cell holds three states
    (0 = pair, 1 = row position consumed alone, 2 = column position consumed
    alone). Returns the best score over the states of the final cell.
    """
    cdef int n = S.shape[0]
    cdef int m = S.shape[1]

    cdef np.float32_t NEG_INF = -np.inf

    cdef np.ndarray[np.float32_t, ndim=2] above = np.zeros((m+1, 3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] row = np.zeros((m+1, 3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] spare

    # row 0: column positions can only be consumed alone
    row[1:, 0] = NEG_INF
    row[1:, 1] = NEG_INF
    row[1:, 2] = gap + np.arange(m)*extend

    cdef np.float32_t pair, result
    cdef int i, j

    for i in range(n):
        # rotate the two row buffers
        spare = above
        above = row
        row = spare
        # column 0: row positions can only be consumed alone
        row[0, 0] = NEG_INF
        row[0, 1] = gap + i*extend
        row[0, 2] = NEG_INF
        for j in range(m):
            pair = S[i, j]
            # state 0: pair i with j
            row[j+1, 0] = pair + max(above[j, 0], above[j, 1], above[j, 2])
            # state 1: arrive from the row above
            row[j+1, 1] = max(above[j+1, 0]+gap, above[j+1, 1]+extend, above[j+1, 2]+gap)
            # state 2: arrive from the cell to the left
            row[j+1, 2] = max(row[j, 0]+gap, row[j, 1]+gap, row[j, 2]+extend)

    result = max(row[m, 0], row[m, 1], row[m, 2])

    return result
383 |
@cython.boundscheck(False)
@cython.wraparound(False)
def nw_align_subst(np.ndarray[np.float32_t, ndim=2] S
                  , np.int32_t gap, np.int32_t extend):
    """Global affine-gap alignment for a precomputed score matrix.

    S      -- S[i, j] is the score of pairing position i with position j
    gap    -- score added when a gap is opened (integer)
    extend -- score added when a gap is extended (integer)

    Scores are kept two rows at a time; the back-pointer table tb is kept in
    full. States: 0 = pair i with j, 1 = row position consumed alone,
    2 = column position consumed alone.

    Returns (score, align) where align is an integer array of shape (k, 2)
    listing the paired (i, j) positions in increasing order.
    """
    cdef int n,m
    n = S.shape[0]
    m = S.shape[1]

    cdef np.float32_t MI = -np.inf

    cdef np.ndarray[np.float32_t, ndim=2] A_prev = np.zeros((m+1,3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((m+1,3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] A_temp

    A[1:,0] = MI # match scores
    A[1:,1] = MI # gap in x
    A[1:,2] = gap + np.arange(m)*extend # gap in y

    # back-pointers: -1 = no predecessor; on the borders a gap state simply
    # continues the same gap run
    cdef np.ndarray[np.int8_t, ndim=3] tb = np.zeros((n+1,m+1,3), dtype=np.int8)
    tb[:] = -1
    tb[1:,:,1] = 1
    tb[:,1:,2] = 2

    cdef np.float32_t s
    cdef int i,j,k,k_prev

    for i in range(n):
        # swap A and A_prev
        A_temp = A_prev
        A_prev = A
        A = A_temp
        # init A[0]
        A[0,0] = MI
        A[0,1] = gap + i*extend
        A[0,2] = MI
        for j in range(m):
            # match i,j
            s = S[i,j]
            k = 0
            if A_prev[j,k] < A_prev[j,1]:
                k = 1
            if A_prev[j,k] < A_prev[j,2]:
                k = 2
            A[j+1,0] = s + A_prev[j,k]
            tb[i+1,j+1,0] = k

            # insert in x
            s = A_prev[j+1,0] + gap
            k = 0
            if s < A_prev[j+1,1] + extend:
                s = A_prev[j+1,1] + extend
                k = 1
            if s < A_prev[j+1,2] + gap:
                s = A_prev[j+1,2] + gap
                k = 2
            A[j+1,1] = s
            tb[i+1,j+1,1] = k

            # insert in y
            s = A[j,0] + gap
            k = 0
            if s < A[j,1] + gap:
                s = A[j,1] + gap
                k = 1
            if s < A[j,2] + extend:
                s = A[j,2] + extend
                k = 2
            A[j+1,2] = s
            tb[i+1,j+1,2] = k

    # best state of the final cell
    k = 0
    s = A[m,k]
    if s < A[m,1]:
        s = A[m,1]
        k = 1
    if s < A[m,2]:
        s = A[m,2]
        k = 2

    ## traceback
    align = []
    i = n
    j = m
    # Stop once the origin is reached: it carries no alignment column, and
    # processing it used to prepend a spurious (-1, -1) pair whenever the
    # alignment started with a match.
    while k >= 0 and (i > 0 or j > 0):
        k_prev = tb[i,j,k]
        if k == 0:
            align.append((i-1,j-1))
            i -= 1
            j -= 1
        elif k == 1:
            i -= 1
        elif k == 2:
            j -= 1
        k = k_prev
    # reshape keeps the (k, 2) layout even when no positions are paired
    align = np.array(align[::-1], dtype=int).reshape(-1, 2)

    return s, align
482 |
@cython.boundscheck(False)
@cython.wraparound(False)
def nw_affine_subst_score(np.ndarray[np.float32_t, ndim=2] S,
                          np.float32_t gap, np.float32_t extend):
    """Global affine-gap alignment score with floating-point gap scores.

    S      -- S[i, j] is the score of pairing position i with position j
    gap    -- score added when a gap is opened
    extend -- score added when a gap is extended

    Two rows of the table are kept at a time; each cell holds three states
    (0 = pair, 1 = row position consumed alone, 2 = column position consumed
    alone). Returns the best score over the states of the final cell.
    """
    cdef int n = S.shape[0]
    cdef int m = S.shape[1]

    cdef np.float32_t NEG_INF = -np.inf

    cdef np.ndarray[np.float32_t, ndim=2] above = np.zeros((m+1, 3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] row = np.zeros((m+1, 3), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] spare

    # row 0: column positions can only be consumed alone
    row[1:, 0] = NEG_INF
    row[1:, 1] = NEG_INF
    row[1:, 2] = gap + np.arange(m)*extend

    cdef np.float32_t pair, result
    cdef int i, j

    for i in range(n):
        # rotate the two row buffers
        spare = above
        above = row
        row = spare
        # column 0: row positions can only be consumed alone
        row[0, 0] = NEG_INF
        row[0, 1] = gap + i*extend
        row[0, 2] = NEG_INF
        for j in range(m):
            pair = S[i, j]
            # state 0: pair i with j
            row[j+1, 0] = pair + max(above[j, 0], above[j, 1], above[j, 2])
            # state 1: arrive from the row above
            row[j+1, 1] = max(above[j+1, 0]+gap, above[j+1, 1]+extend, above[j+1, 2]+gap)
            # state 2: arrive from the cell to the left
            row[j+1, 2] = max(row[j, 0]+gap, row[j, 1]+gap, row[j, 2]+extend)

    result = max(row[m, 0], row[m, 1], row[m, 2])

    return result
526 |
@cython.boundscheck(False)
@cython.wraparound(False)
def nw_subst_score(np.ndarray[np.float32_t, ndim=2] S, np.float32_t gap):
    """Global alignment score with a linear gap cost.

    S   -- S[i, j] is the score of pairing position i with position j
    gap -- score added for every gap position

    Only two rows of the table are kept at a time. Returns the score of the
    final cell.
    """
    cdef int n,m
    n = S.shape[0]
    m = S.shape[1]

    cdef np.ndarray[np.float32_t] A_prev = np.zeros(m+1, dtype=np.float32)
    cdef np.ndarray[np.float32_t] A = np.zeros(m+1, dtype=np.float32)
    cdef np.ndarray[np.float32_t] A_temp

    # Row 0: skipping j columns costs j gaps. The previous initialisation,
    # gap*np.arange(m) written into A[1:], charged only j-1 and disagreed
    # with the first-column value gap*(i+1) set inside the loop.
    A[:] = gap*np.arange(m+1)

    cdef np.float32_t s
    cdef int i,j

    for i in range(n):
        # swap A and A_prev
        A_temp = A_prev
        A_prev = A
        A = A_temp
        # init A[0]
        A[0] = gap*(i+1)
        for j in range(m):
            # match i,j
            s = S[i,j]
            A[j+1] = max(A_prev[j]+s, A_prev[j+1]+gap, A[j]+gap)

    s = A[m]

    return s
561 |
@cython.boundscheck(False)
@cython.wraparound(False)
def edit_distance(np.ndarray[np.uint8_t] x, np.ndarray[np.uint8_t] y):
    """Edit distance between x and y with unit costs.

    Substituting, inserting and deleting one element each cost 1; equal
    elements cost nothing. Returns the value of the final table cell.
    """
    cdef int n = len(x)
    cdef int m = len(y)

    cdef np.ndarray[np.int32_t, ndim=2] D = np.zeros((n+1, m+1), dtype=np.int32)
    # turning an empty prefix into a prefix of length k takes k edits
    D[0, 1:] = np.arange(1, m+1)
    D[1:, 0] = np.arange(1, n+1)

    cdef int i, j
    cdef np.int32_t cost

    for i in range(n):
        for j in range(m):
            cost = 0 if x[i] == y[j] else 1
            D[i+1, j+1] = min(D[i, j] + cost, D[i+1, j] + 1, D[i, j+1] + 1)

    return D[n, m]
583 |
@cython.boundscheck(False)
@cython.wraparound(False)
def dtw_subst_score(np.ndarray[np.float32_t, ndim=2] S):
    """Accumulate S along the best monotone path, two rows at a time.

    Each cell adds S[i, j] to the largest of its diagonal, upper and left
    neighbours. Returns the value accumulated in the final cell.

    NOTE(review): the border cells (row 0 and column 0) are never written
    and therefore stay at 0 rather than -inf - confirm this is intended.
    """
    cdef int n = S.shape[0]
    cdef int m = S.shape[1]

    cdef np.ndarray[np.float32_t] above = np.zeros(m+1, dtype=np.float32)
    cdef np.ndarray[np.float32_t] row = np.zeros(m+1, dtype=np.float32)
    cdef np.ndarray[np.float32_t] spare

    cdef np.float32_t cell, result
    cdef int i, j

    for i in range(n):
        # the finished row becomes "above"; the older buffer is reused
        spare = above
        above = row
        row = spare
        for j in range(m):
            cell = S[i, j]
            row[j+1] = cell + max(above[j], above[j+1], row[j])

    result = row[m]

    return result
614 |
@cython.boundscheck(False)
@cython.wraparound(False)
def nw_align( np.ndarray[np.float32_t, ndim=2] S, np.float32_t gap):
    """Global alignment with a linear gap cost, including the traceback.

    S   -- S[i, j] is the score of pairing position i with position j
    gap -- score added for every gap position

    Returns (score, tb) where score is the value of the final cell and tb is
    an integer array of shape (k, 2) listing the paired (i, j) positions in
    increasing order.
    """
    cdef int n,m
    n = S.shape[0]
    m = S.shape[1]

    cdef int i,j
    # typed temporaries keep the inner loop free of Python objects
    cdef np.float32_t diag, left, up
    cdef np.ndarray[np.float32_t, ndim=2] A = np.zeros((n+1,m+1), dtype=np.float32)
    A[:,0] = gap*np.arange(n+1)
    A[0,:] = gap*np.arange(m+1)
    A[1:,1:] = -np.inf

    # step taken to reach each cell: 0 = diagonal, 1 = from the left,
    # 2 = from above, -1 = none (only the origin keeps it for finite scores)
    cdef np.ndarray[np.int8_t, ndim=2] paths = np.zeros((n+1,m+1), dtype=np.int8)
    paths[:] = -1
    paths[0,1:] = 1
    paths[1:,0] = 2

    for i in range(n):
        for j in range(m):
            # strict comparisons: on ties the diagonal wins, then the left step
            diag = A[i,j] + S[i,j]
            if diag > A[i+1,j+1]:
                A[i+1,j+1] = diag
                paths[i+1,j+1] = 0

            left = A[i+1,j] + gap
            if left > A[i+1,j+1]:
                A[i+1,j+1] = left
                paths[i+1,j+1] = 1

            up = A[i,j+1] + gap
            if up > A[i+1,j+1]:
                A[i+1,j+1] = up
                paths[i+1,j+1] = 2

    ## traceback
    tb = []
    i = n
    j = m
    while paths[i,j] >= 0:
        if paths[i,j] == 0:
            tb.append((i-1,j-1))
            i -= 1
            j -= 1
        elif paths[i,j] == 1:
            j -= 1
        elif paths[i,j] == 2:
            i -= 1
    # reshape keeps the (k, 2) layout even when no positions are paired
    tb = np.array(tb[::-1], dtype=int).reshape(-1, 2)

    return A[n,m], tb
672 |
673 |
674 |
675 |
676 |
--------------------------------------------------------------------------------
/train_similarity_and_contact.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import sys
6 | import os
7 | import glob
8 | from PIL import Image
9 | from scipy.stats import pearsonr, spearmanr
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
16 | import torch.utils.data
17 |
18 | from src.alphabets import Uniprot21
19 | import src.scop as scop
20 | from src.utils import pack_sequences, unpack_sequences
21 | from src.utils import ContactMapDataset, collate_lists
22 | from src.utils import PairedDataset, AllPairsDataset, collate_paired_sequences
23 | from src.utils import MultinomialResample
24 | import src.models.embedding
25 | import src.models.multitask
26 | from src.metrics import average_precision
27 |
# Locate every contact-map PNG below the SCOPe pdbstyle directory and index
# the paths by the first seven characters of the file name (the key that
# load_data looks sequences up by).
cmap_paths = glob.glob('data/SCOPe/pdbstyle-2.06/*/*.png')
cmap_dict = dict(
    (os.path.basename(png_path)[:7], png_path) for png_path in cmap_paths
)
30 |
31 |
def load_data(path, alphabet):
    """Load encoded sequences, structure labels and contact maps.

    Parameters
    ----------
    path : str
        Path of the ASTRAL file handed to ``scop.parse_astral`` (opened in
        binary mode).
    alphabet
        Encoder object forwarded to ``scop.parse_astral`` as ``encoder``.

    Returns
    -------
    x : list of LongTensor
        One encoded sequence per entry.
    s : Tensor
        ``structs`` as returned by ``scop.parse_astral`` wrapped with
        ``torch.from_numpy``.
    c : list of FloatTensor
        One contact matrix per entry with values in {-1, 0, 1}: pixels equal
        to 255 become 1, pixels equal to 1 become -1, every entry strictly
        below the diagonal is forced to -1, and everything else is 0.
        Negative entries are the ones the loss functions in this file mask out.

    Raises
    ------
    KeyError
        If no contact-map image is registered in ``cmap_dict`` for a name,
        even after replacing its first character with ``'d'``.
    """
    with open(path, 'rb') as f:
        names, structs, sequences = scop.parse_astral(f, encoder=alphabet)
    x = [torch.from_numpy(seq).long() for seq in sequences]
    s = torch.from_numpy(structs)
    c = []
    for name in names:
        name = name.decode('utf-8')
        if name not in cmap_dict:
            # fall back to the same identifier with a leading 'd'
            name = 'd' + name[1:]
        cmap_path = cmap_dict[name]
        # np.asarray copies only when required; np.array(..., copy=False)
        # raises under NumPy >= 2.0 whenever a copy cannot be avoided.
        # The context manager releases the image file handle promptly.
        with Image.open(cmap_path) as img:
            im = np.asarray(img)
        contacts = np.zeros(im.shape, dtype=np.float32)
        contacts[im == 1] = -1
        contacts[im == 255] = 1
        # mask the matrix below the diagonal
        mask = np.tril_indices(contacts.shape[0], k=-1)
        contacts[mask] = -1
        c.append(torch.from_numpy(contacts))
    return x, s, c
52 |
53 |
def load_scop_testpairs(astral_testpairs_path, alphabet):
    """Read the tab-separated table of test pairs and encode both sequences.

    The table must provide the columns ``sequence_A``, ``sequence_B`` and
    ``similarity``. Each sequence is converted to upper-case UTF-8 bytes
    and passed through ``alphabet.encode``.

    Returns
    -------
    (x0_test, x1_test, y_test)
        Two lists of LongTensors (one per pair) and a LongTensor holding the
        ``similarity`` column.
    """
    print('# loading test sequence pairs:', astral_testpairs_path, file=sys.stderr)
    table = pd.read_csv(astral_testpairs_path, sep='\t')

    def encode_column(column):
        # upper-case bytes -> alphabet indices -> LongTensor, row by row
        tensors = []
        for seq in table[column]:
            raw = seq.encode('utf-8').upper()
            tensors.append(torch.from_numpy(alphabet.encode(raw)).long())
        return tensors

    x0_test = encode_column('sequence_A')
    x1_test = encode_column('sequence_B')
    y_test = torch.from_numpy(table['similarity'].values).long()

    return x0_test, x1_test, y_test
65 |
66 |
def similarity_grad(model, x0, x1, y, use_cuda, weight=0.5):
    """Accumulate gradients of the weighted similarity loss for one minibatch.

    Both halves of every pair are embedded in a single packed batch, scored
    with ``model.score`` and trained with binary cross entropy against ``y``.
    ``weight * loss`` is back-propagated; the optimizer step is left to the
    caller.

    Parameters
    ----------
    model
        Callable embedding model that also exposes ``score(z_a, z_b)``.
    x0, x1 : list
        The two sides of the sequence pairs; ``x0[i]`` is paired with ``x1[i]``.
    y : Tensor
        Binary targets with the same shape as the stacked logits; their sum
        over dimension 1 is used as the integer similarity level.
    use_cuda : bool
        Move ``y`` to the GPU before computing the loss.
    weight : float
        Multiplier applied to the loss before ``backward()``.

    Returns
    -------
    (loss, correct, mse, b)
        ``loss`` is the mean negative log-likelihood of the true level under
        the normalised level distribution, ``correct`` the number of exact
        level predictions, ``mse`` the mean squared error of the expected
        level, and ``b`` the number of pairs.
    """
    if use_cuda:
        y = y.cuda()
    y = Variable(y)

    b = len(x0)
    x = x0 + x1

    x,order = pack_sequences(x)
    x = PackedSequence(Variable(x.data), x.batch_sizes)
    z = model(x) # embed the sequences
    z = unpack_sequences(z, order)

    # first b embeddings belong to x0, the remaining ones to x1
    z0 = z[:b]
    z1 = z[b:]

    logits = torch.stack([model.score(z0[i], z1[i]) for i in range(b)], 0)

    loss = F.binary_cross_entropy_with_logits(logits, y.float())

    # backprop weighted loss
    w_loss = loss*weight
    w_loss.backward()

    # calculate minibatch performance metrics
    with torch.no_grad():
        p = torch.sigmoid(logits)
        ones = p.new_ones(b, 1)
        # combine the per-threshold probabilities into one score per level
        p_ge = torch.cat([ones, p], 1)
        p_lt = torch.cat([1-p, ones], 1)
        p = p_ge*p_lt
        p = p/p.sum(1,keepdim=True) # make sure p is normalized

        _,y_hard = torch.max(p, 1)
        # one level per column of p instead of a hard-coded count
        levels = torch.arange(p.size(1), dtype=p.dtype, device=p.device)
        y_hat = torch.sum(p*levels, 1)
        y = torch.sum(y.data, 1)

        # p already holds probabilities, so take the log and use nll_loss;
        # F.cross_entropy expects unnormalised logits and would apply a
        # second softmax. The clamp keeps log() finite for zero entries.
        log_p = torch.log(p.clamp(min=torch.finfo(p.dtype).tiny))
        loss = F.nll_loss(log_p, y).item()

        correct = torch.sum((y == y_hard).float()).item()
        mse = torch.mean((y.float() - y_hat)**2).item()

    return loss, correct, mse, b
116 |
117 |
def contacts_grad(model, x, y, use_cuda, weight=0.5):
    """Accumulate gradients of the weighted contact-prediction loss.

    Every sequence in ``x`` is embedded, ``model.predict`` produces flattened
    contact logits per sequence, and binary cross entropy is computed over all
    target entries that are not negative (negative targets are masked out).
    ``weight * loss`` is back-propagated; the caller performs the optimizer
    step.

    Returns
    -------
    (loss, tp, gp, pp, b)
        The unweighted loss value, the soft true-positive mass
        ``sum(p_hat * y)``, the number of positive targets ``sum(y)``, the
        predicted-positive mass ``sum(p_hat)`` and the number of unmasked
        entries.
    """
    n_seqs = len(x)
    packed, order = pack_sequences(x)
    packed = PackedSequence(Variable(packed.data), packed.batch_sizes)
    embedded = unpack_sequences(model(packed), order)  # embed the sequences

    per_sequence = [
        model.predict(embedded[k].unsqueeze(0)).view(-1) for k in range(n_seqs)
    ]
    logits = torch.cat(per_sequence, 0)

    targets = torch.cat([yi.view(-1) for yi in y])
    if use_cuda:
        targets = targets.cuda()

    # keep only the entries whose target is not flagged as masked (< 0)
    keep = ~(targets < 0)
    logits = logits[keep]
    targets = Variable(targets[keep])
    b = targets.size(0)

    loss = F.binary_cross_entropy_with_logits(logits, targets)

    # backprop weighted loss
    (loss*weight).backward()

    # soft counts used by the caller for running precision / recall
    with torch.no_grad():
        p_hat = torch.sigmoid(logits)
        tp = torch.sum(p_hat*targets).item()
        gp = targets.sum().item()
        pp = p_hat.sum().item()

    return loss.item(), tp, gp, pp, b
155 |
156 |
def predict_contacts(model, x, y, use_cuda):
    """Predict contact logits for a minibatch and drop masked positions.

    Returns two parallel lists with one entry per sequence: the flattened
    logits from ``model.predict`` and the flattened targets, both restricted
    to positions whose target is not negative.
    """
    n_seqs = len(x)
    packed, order = pack_sequences(x)
    packed = PackedSequence(Variable(packed.data), packed.batch_sizes)
    embedded = unpack_sequences(model(packed), order)  # embed the sequences

    logits = []
    y_list = []
    for k in range(n_seqs):
        pred = model.predict(embedded[k].unsqueeze(0)).view(-1)

        target = y[k].view(-1)
        if use_cuda:
            target = target.cuda()

        # negative targets mark positions to ignore
        keep = ~(target < 0)
        logits.append(pred[keep])
        y_list.append(target[keep])

    return logits, y_list
182 |
183 |
def eval_contacts(model, test_iterator, use_cuda):
    """Evaluate contact prediction over every minibatch of ``test_iterator``.

    Returns
    -------
    (loss, pr, re, f1, aupr)
        Binary cross entropy over all unmasked positions, soft precision and
        recall (computed from predicted probabilities rather than thresholded
        decisions), their harmonic mean, and the value returned by
        ``average_precision`` on the raw logits.
    """
    collected_logits = []
    collected_targets = []
    for x, y_mb in test_iterator:
        batch_logits, batch_targets = predict_contacts(model, x, y_mb, use_cuda)
        collected_logits.extend(batch_logits)
        collected_targets.extend(batch_targets)

    y = torch.cat(collected_targets, 0)
    logits = torch.cat(collected_logits, 0)

    loss = F.binary_cross_entropy_with_logits(logits, y).item()

    p_hat = torch.sigmoid(logits)
    tp = torch.sum(y*p_hat).item()
    predicted_mass = torch.sum(p_hat).item()
    positive_count = torch.sum(y).item()
    pr = tp/predicted_mass
    re = tp/positive_count
    f1 = 2*pr*re/(pr + re)

    aupr = average_precision(y.cpu().numpy(), logits.data.cpu().numpy())

    return loss, pr, re, f1, aupr
210 |
def eval_similarity(model, test_iterator, use_cuda):
    """Evaluate similarity prediction over every minibatch of ``test_iterator``.

    Each minibatch yields ``(x0, x1, y_mb)``; both sides of the pairs are
    embedded together and scored with ``model.score``. This function does not
    disable gradient tracking itself.

    Returns
    -------
    (loss, accuracy, mse, r, rho)
        Mean negative log-likelihood of the true level under the normalised
        level distribution, exact-level accuracy, mean squared error of the
        expected level, and the Pearson / Spearman correlations between the
        expected and the true level.
    """
    y = []
    logits = []
    for x0,x1,y_mb in test_iterator:

        if use_cuda:
            y_mb = y_mb.cuda()
        y.append(y_mb.long())

        b = len(x0)
        x = x0 + x1

        x,order = pack_sequences(x)
        x = PackedSequence(Variable(x.data), x.batch_sizes)
        z = model(x) # embed the sequences
        z = unpack_sequences(z, order)

        # first b embeddings belong to x0, the remaining ones to x1
        z0 = z[:b]
        z1 = z[b:]

        for i in range(b):
            logits.append(model.score(z0[i], z1[i]))

    y = torch.cat(y, 0)
    logits = torch.stack(logits, 0)

    p = torch.sigmoid(logits).data
    ones = p.new_ones(p.size(0), 1)
    # combine the per-threshold probabilities into one score per level
    p_ge = torch.cat([ones, p], 1)
    p_lt = torch.cat([1-p, ones], 1)
    p = p_ge*p_lt
    p = p/p.sum(1,keepdim=True) # make sure p is normalized

    # p already holds probabilities, so take the log and use nll_loss;
    # F.cross_entropy expects unnormalised logits and would apply a second
    # softmax. The clamp keeps log() finite for zero entries.
    log_p = torch.log(p.clamp(min=torch.finfo(p.dtype).tiny))
    loss = F.nll_loss(log_p, y).item()

    _,y_hard = torch.max(p, 1)
    # one level per column of p instead of a hard-coded count
    levels = torch.arange(p.size(1), dtype=p.dtype, device=p.device)
    y_hat = torch.sum(p*levels, 1)

    accuracy = torch.mean((y == y_hard).float()).item()
    mse = torch.mean((y.float() - y_hat)**2).item()

    y = y.cpu().numpy()
    y_hat = y_hat.cpu().numpy()

    r,_ = pearsonr(y_hat, y)
    rho,_ = spearmanr(y_hat, y)

    return loss, accuracy, mse, r, rho
262 |
263 |
def main():
    """Train the joint similarity + contact-map model from the command line.

    Steps, in order: parse arguments, pick the compute device, load the SCOPe
    training/test data and contact maps, build the dataset iterators (the
    similarity iterator is sized to yield as many minibatches as the contact
    iterator), build the embedding and multitask model, then run the training
    loop. One tab-separated ``train`` row is written per epoch; every
    ``--epoch-scale`` epochs a ``test`` row is written as well and, if
    ``--save-prefix`` is given, the model is saved.
    """
    import argparse
    # the text is a description; passed positionally it would become `prog`
    # and replace the program name in the usage line
    parser = argparse.ArgumentParser(description='Script for training contact prediction model')

    parser.add_argument('--dev', action='store_true', help='use train/dev split')

    # NOTE(review): --rnn-type is parsed but never forwarded to the model below
    parser.add_argument('--rnn-type', choices=['lstm', 'gru'], default='lstm', help='type of RNN block to use (default: lstm)')
    parser.add_argument('--embedding-dim', type=int, default=100, help='embedding dimension (default: 100)')
    parser.add_argument('--input-dim', type=int, default=512, help='dimension of input to RNN (default: 512)')
    parser.add_argument('--rnn-dim', type=int, default=512, help='hidden units of RNNs (default: 512)')
    parser.add_argument('--num-layers', type=int, default=3, help='number of RNN layers (default: 3)')
    parser.add_argument('--dropout', type=float, default=0, help='dropout probability (default: 0)')


    parser.add_argument('--hidden-dim', type=int, default=50, help='number of hidden units for comparison layer in contact prediction (default: 50)')
    parser.add_argument('--width', type=int, default=7, help='width of convolutional filter for contact prediction (default: 7)')


    # NOTE(review): --epoch-size is accepted but not read; the epoch size is
    # derived from the number of contact-map minibatches further down
    parser.add_argument('--epoch-size', type=int, default=100000, help='number of examples per epoch (default: 100,000)')
    parser.add_argument('--epoch-scale', type=int, default=5, help='report heldout performance every this many epochs (default: 5)')
    parser.add_argument('--num-epochs', type=int, default=100, help='number of epochs (default: 100)')


    parser.add_argument('--similarity-batch-size', type=int, default=64, help='minibatch size for similarity prediction loss in pairs (default: 64)')
    parser.add_argument('--contact-batch-size', type=int, default=10, help='minibatch size for contact prediction loss (default: 10)')


    parser.add_argument('--weight-decay', type=float, default=0, help='L2 regularization (default: 0)')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lambda', dest='lambda_', type=float, default=0.5, help='weight on the similarity objective, contact map objective weight is one minus this (default: 0.5)')

    parser.add_argument('--tau', type=float, default=0.5, help='sampling proportion exponent (default: 0.5)')
    parser.add_argument('--augment', type=float, default=0, help='probability of resampling amino acid for data augmentation (default: 0)')

    parser.add_argument('--lm', help='pretrained LM to use as initial embedding')

    parser.add_argument('-o', '--output', help='output file path (default: stdout)')
    parser.add_argument('--save-prefix', help='path prefix for saving models')
    parser.add_argument('-d', '--device', type=int, default=-2, help='compute device to use')

    args = parser.parse_args()


    ## set the device
    # -1 forces CPU; any other value uses CUDA when it is available
    d = args.device
    use_cuda = (d != -1) and torch.cuda.is_available()
    # only select a device when CUDA is actually going to be used
    if use_cuda and d >= 0:
        torch.cuda.set_device(d)

    ## make the datasets
    alphabet = Uniprot21()

    astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.fa'
    astral_test_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.fa'
    astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.sampledpairs.txt'
    if args.dev:
        astral_train_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.train.fa'
        astral_test_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.fa'
        astral_testpairs_path = 'data/SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.sampledpairs.txt'


    print('# loading training sequences:', astral_train_path, file=sys.stderr)
    x_train, structs_train, contacts_train = load_data(astral_train_path, alphabet)
    if use_cuda:
        # contact maps stay on the CPU; the loss functions move them per batch
        x_train = [x.cuda() for x in x_train]
    print('# loaded', len(x_train), 'training sequences', file=sys.stderr)

    print('# loading test sequences:', astral_test_path, file=sys.stderr)
    x_test, _, contacts_test = load_data(astral_test_path, alphabet)
    if use_cuda:
        x_test = [x.cuda() for x in x_test]
    print('# loaded', len(x_test), 'contact map test sequences', file=sys.stderr)

    x0_test, x1_test, y_scop_test = load_scop_testpairs(astral_testpairs_path, alphabet)
    if use_cuda:
        x0_test = [x.cuda() for x in x0_test]
        x1_test = [x.cuda() for x in x1_test]
    print('# loaded', len(x0_test), 'scop test pairs', file=sys.stderr)

    ## make the dataset iterators

    # data augmentation by resampling amino acids
    augment = None
    p = 0
    if args.augment > 0:
        p = args.augment
        # uniform replacement distribution over the alphabet
        trans = torch.ones(len(alphabet),len(alphabet))
        trans = trans/trans.sum(1, keepdim=True)
        if use_cuda:
            trans = trans.cuda()
        augment = MultinomialResample(trans, p)
    print('# resampling amino acids with p:', p, file=sys.stderr)

    # SCOP structural similarity datasets
    # cumulative product over the last axis: a level only counts as shared
    # when every preceding level is shared too
    scop_levels = torch.cumprod((structs_train.unsqueeze(1) == structs_train.unsqueeze(0)).long(), 2)
    scop_train = AllPairsDataset(x_train, scop_levels, augment=augment)
    scop_test = PairedDataset(x0_test, x1_test, y_scop_test)

    # contact map datasets
    cmap_train = ContactMapDataset(x_train, contacts_train, augment=augment)
    cmap_test = ContactMapDataset(x_test, contacts_test)

    # iterators for contacts data
    batch_size = args.contact_batch_size
    cmap_train_iterator = torch.utils.data.DataLoader(cmap_train
                                                    , batch_size=batch_size
                                                    , shuffle=True
                                                    , collate_fn=collate_lists
                                                    )
    cmap_test_iterator = torch.utils.data.DataLoader(cmap_test
                                                   , batch_size=batch_size
                                                   , collate_fn=collate_lists
                                                   )

    # make the SCOP training iterator have same number of minibatches
    num_steps = len(cmap_train_iterator)
    batch_size = args.similarity_batch_size
    epoch_size = num_steps*batch_size

    # similarity level of every training pair and how often each level occurs
    similarity = scop_levels.numpy().sum(2)
    levels,counts = np.unique(similarity, return_counts=True)
    order = np.argsort(levels)
    levels = levels[order]
    counts = counts[order]

    tau = args.tau
    print('# using tau:', tau, file=sys.stderr)
    print('#', counts**tau/np.sum(counts**tau), file=sys.stderr)
    # per-pair sampling weight so that a level is drawn proportionally to count**tau
    weights = counts**tau/counts
    weights = weights[similarity].ravel()
    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, epoch_size)
    N = epoch_size

    # iterators for similarity data
    scop_train_iterator = torch.utils.data.DataLoader(scop_train
                                                    , batch_size=batch_size
                                                    , sampler=sampler
                                                    , collate_fn=collate_paired_sequences
                                                    )
    scop_test_iterator = torch.utils.data.DataLoader(scop_test
                                                   , batch_size=batch_size
                                                   , collate_fn=collate_paired_sequences
                                                   )

    report_steps = args.epoch_scale


    ## initialize the model
    rnn_dim = args.rnn_dim
    num_layers = args.num_layers

    embedding_size = args.embedding_dim
    input_dim = args.input_dim
    dropout = args.dropout

    print('# initializing embedding model with:', file=sys.stderr)
    print('# embedding_size:', embedding_size, file=sys.stderr)
    print('# input_dim:', input_dim, file=sys.stderr)
    print('# rnn_dim:', rnn_dim, file=sys.stderr)
    print('# num_layers:', num_layers, file=sys.stderr)
    print('# dropout:', dropout, file=sys.stderr)

    lm = None
    if args.lm is not None:
        print('# using pretrained LM:', args.lm, file=sys.stderr)
        # NOTE(review): torch.load unpickles the file; only load trusted models
        lm = torch.load(args.lm)
        lm.eval()
        ## do not update the LM parameters
        for param in lm.parameters():
            param.requires_grad = False

    embedding = src.models.embedding.StackedRNN(len(alphabet), input_dim, rnn_dim
                                               , embedding_size, nlayers=num_layers
                                               , dropout=dropout, lm=lm)

    # similarity prediction parameters
    similarity_kwargs = {}

    # contact map prediction parameters
    hidden_dim = args.hidden_dim
    width = args.width
    cmap_kwargs = {'hidden_dim': hidden_dim, 'width': width}

    model = src.models.multitask.SCOPCM(embedding, similarity_kwargs=similarity_kwargs,
                                        cmap_kwargs=cmap_kwargs)
    if use_cuda:
        model.cuda()

    ## setup training parameters and optimizer
    num_epochs = args.num_epochs

    weight_decay = args.weight_decay
    lr = args.lr

    print('# training with Adam: lr={}, weight_decay={}'.format(lr, weight_decay), file=sys.stderr)
    # frozen (e.g. pretrained LM) parameters are excluded from the optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)

    scop_weight = args.lambda_
    cmap_weight = 1 - scop_weight

    print('# weighting tasks with SIMILARITY: {:.3f}, CONTACTS: {:.3f}'.format(scop_weight, cmap_weight), file=sys.stderr)

    ## train the model
    print('# training model', file=sys.stderr)

    save_prefix = args.save_prefix
    output = args.output
    if output is None:
        output = sys.stdout
    else:
        output = open(output, 'w')

    try:
        # width of the zero-padded epoch column; also valid for num_epochs == 0
        digits = len(str(max(num_epochs, 1)))
        tokens = ['sim_loss', 'sim_mse', 'sim_acc', 'sim_r', 'sim_rho'
                 ,'cmap_loss', 'cmap_pr', 'cmap_re', 'cmap_f1', 'cmap_aupr']
        line = '\t'.join(['epoch', 'split'] + tokens)
        print(line, file=output)

        prog_template = '# [{}/{}] training {:.1%} sim_loss={:.5f}, sim_acc={:.5f}, cmap_loss={:.5f}, cmap_f1={:.5f}'

        for epoch in range(num_epochs):
            # train epoch
            model.train()

            # running (incrementally updated) means of the training metrics
            scop_n = 0
            scop_loss_accum = 0
            scop_mse_accum = 0
            scop_acc_accum = 0

            cmap_n = 0
            cmap_loss_accum = 0
            cmap_pp = 0
            cmap_pr_accum = 0
            cmap_gp = 0
            cmap_re_accum = 0

            for (cmap_x, cmap_y), (scop_x0, scop_x1, scop_y) in zip(cmap_train_iterator, scop_train_iterator):

                # calculate gradients and metrics for similarity part
                loss, correct, mse, b = similarity_grad(model, scop_x0, scop_x1, scop_y, use_cuda, weight=scop_weight)

                scop_n += b
                delta = b*(loss - scop_loss_accum)
                scop_loss_accum += delta/scop_n
                delta = correct - b*scop_acc_accum
                scop_acc_accum += delta/scop_n
                delta = b*(mse - scop_mse_accum)
                scop_mse_accum += delta/scop_n

                # print progress whenever another 100 pairs have been processed
                report = ((scop_n - b)//100 < scop_n//100)

                # calculate the contact map prediction gradients and metrics
                loss, tp, gp_, pp_, b = contacts_grad(model, cmap_x, cmap_y, use_cuda, weight=cmap_weight)

                cmap_gp += gp_
                delta = tp - gp_*cmap_re_accum
                cmap_re_accum += delta/cmap_gp

                cmap_pp += pp_
                delta = tp - pp_*cmap_pr_accum
                cmap_pr_accum += delta/cmap_pp

                cmap_n += b
                delta = b*(loss - cmap_loss_accum)
                cmap_loss_accum += delta/cmap_n

                ## update the parameters
                # both *_grad calls accumulated gradients; apply them together
                optim.step()
                optim.zero_grad()
                model.clip()

                if report:
                    f1 = 2*cmap_pr_accum*cmap_re_accum/(cmap_pr_accum + cmap_re_accum)
                    line = prog_template.format(epoch+1, num_epochs, scop_n/N, scop_loss_accum
                                               , scop_acc_accum, cmap_loss_accum, f1)
                    print(line, end='\r', file=sys.stderr)
            # blank out the progress line
            print(' '*80, end='\r', file=sys.stderr)
            f1 = 2*cmap_pr_accum*cmap_re_accum/(cmap_pr_accum + cmap_re_accum)
            # metrics that are only computed on the test split are written as '-'
            tokens = [ scop_loss_accum, scop_mse_accum, scop_acc_accum, '-', '-'
                     , cmap_loss_accum, cmap_pr_accum, cmap_re_accum, f1, '-']
            tokens = [x if type(x) is str else '{:.5f}'.format(x) for x in tokens]

            line = '\t'.join([str(epoch+1).zfill(digits), 'train'] + tokens)
            print(line, file=output)
            output.flush()

            # eval and save model
            if (epoch+1) % report_steps == 0:
                model.eval()
                with torch.no_grad():
                    scop_loss, scop_acc, scop_mse, scop_r, scop_rho = \
                            eval_similarity(model, scop_test_iterator, use_cuda)
                    cmap_loss, cmap_pr, cmap_re, cmap_f1, cmap_aupr = \
                            eval_contacts(model, cmap_test_iterator, use_cuda)

                tokens = [ scop_loss, scop_mse, scop_acc, scop_r, scop_rho
                         , cmap_loss, cmap_pr, cmap_re, cmap_f1, cmap_aupr]
                tokens = ['{:.5f}'.format(x) for x in tokens]

                line = '\t'.join([str(epoch+1).zfill(digits), 'test'] + tokens)
                print(line, file=output)
                output.flush()

                # save the model
                if save_prefix is not None:
                    save_path = save_prefix + '_epoch' + str(epoch+1).zfill(digits) + '.sav'
                    # save from the CPU, then move back for further training
                    model.cpu()
                    torch.save(model, save_path)
                    if use_cuda:
                        model.cuda()
    finally:
        # close the results file we opened; never close stdout
        if output is not sys.stdout:
            output.close()
580 |
581 |
582 |
583 |
# Run the training entry point only when the file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
586 |
587 |
588 |
589 |
590 |
591 |
--------------------------------------------------------------------------------