├── smt
│   ├── __init__.py
│   ├── db
│   │   ├── __init__.py
│   │   ├── tables.py
│   │   ├── createngramdb.py
│   │   └── createdb.py
│   ├── decoder
│   │   ├── __init__.py
│   │   └── stackdecoder.py
│   ├── phrase
│   │   ├── __init__.py
│   │   ├── word_alignment.py
│   │   └── phrase_extract.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── utility.py
│   ├── ibmmodel
│   │   ├── __init__.py
│   │   ├── test.txt
│   │   ├── ibmmodel1.py
│   │   └── ibmmodel2.py
│   └── langmodel
│       ├── __init__.py
│       └── ngram.py
├── test
│   ├── __init__.py
│   ├── test_ngram.py
│   ├── test_ibmmodel.py
│   ├── test_phrase.py
│   └── test_stackdecoder.py
├── .gitignore
├── setup.py
├── README.rst
└── COPYING

/smt/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/smt/db/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/smt/decoder/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/smt/phrase/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/smt/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/smt/ibmmodel/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/smt/langmodel/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/smt/ibmmodel/test.txt:
--------------------------------------------------------------------------------
the house|||das Haus
the book|||das Buch
a book|||ein Buch
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
smt.egg-info/
*.pyc
terminal.py
twitter
jec_basic_sentence
test/:test:
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from setuptools import setup, find_packages

setup(
    name="smt",
    version="0.1",
    description="Statistical Machine Translation implementation in Python",
    author="Noriyuki Abe",
    author_email="kenko.py@gmail.com",
    url="http://kenkov.jp",
    packages=find_packages(),
    test_suite="test",
)
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
==============================
IBM Model
==============================

IBM models for statistical machine translation

Files
=======

ibmmodel1.py
    implements IBM Model 1

ibmmodel2.py
    implements IBM Model 2

word_alignment.py
    implements symmetrization of word alignments


Usage
======

See each file and the test code under the test directory.
--------------------------------------------------------------------------------
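For a quick start, here is a minimal sketch of training the models on the toy
corpus from smt/ibmmodel/test.txt (loop_count is kept small for speed):

    from smt.ibmmodel import ibmmodel2
    from smt.utils.utility import mkcorpus

    sentences = [("the house", "das Haus"),
                 ("the book", "das Buch"),
                 ("a book", "ein Buch")]
    corpus = mkcorpus(sentences)

    # translation table t and alignment table a from IBM Model 2
    t, a = ibmmodel2._train(corpus, loop_count=10)
    print(ibmmodel2.viterbi_alignment("the book".split(),
                                      "das Buch".split(), t, a))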
/smt/langmodel/ngram.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from __future__ import division, print_function
import itertools


class NgramException(Exception):
    pass


def ngram(sentences, n):
    s_len = len(sentences)
    if s_len < n:
        raise NgramException("the sentence is not long enough:\
 len(sentences)={} < n={}".format(s_len, n))
    # tee the token sequence n times and advance the i-th copy by i
    # steps, so that zipping the copies yields consecutive n-grams
    xs = itertools.tee(sentences, n)
    for i, t in enumerate(xs[1:]):
        for _ in range(i + 1):
            next(t)
    return zip(*xs)


if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------
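A short usage sketch for ngram; the <s>/</s> padding mirrors what
test/test_ngram.py and smt/db/createngramdb.py do:

    from smt.langmodel.ngram import ngram

    tokens = ["<s>", "<s>"] + "I am a teacher".split() + ["</s>"]
    print(list(ngram(tokens, 3)))
    # [('<s>', '<s>', 'I'), ('<s>', 'I', 'am'), ('I', 'am', 'a'),
    #  ('am', 'a', 'teacher'), ('a', 'teacher', '</s>')]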
/COPYING:
--------------------------------------------------------------------------------
Copyright (c) 2013 Noriyuki ABE

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/test/test_ngram.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from __future__ import division, print_function
import unittest
from smt.langmodel.ngram import ngram
from smt.langmodel.ngram import NgramException


class NgramTest(unittest.TestCase):
    def test_ngram_3(self):
        sentence = ["I am teacher",
                    "I am",
                    "I",
                    ""]
        test_sentences = (["<s>", "<s>"] + item.split() + ["</s>"]
                          for item in sentence)
        anss = [[("<s>", "<s>", "I"),
                 ("<s>", "I", "am"),
                 ("I", "am", "teacher"),
                 ("am", "teacher", "</s>")],
                [("<s>", "<s>", "I"),
                 ("<s>", "I", "am"),
                 ("I", "am", "</s>")],
                [("<s>", "<s>", "I"),
                 ("<s>", "I", "</s>")],
                [("<s>", "<s>", "</s>")],
                ]

        for sentences, ans in zip(test_sentences, anss):
            a = ngram(sentences, 3)
            self.assertEqual(list(a), ans)

    def test_ngram_illegal_input(self):
        sentences = ["I", "am"]
        self.assertRaises(NgramException, ngram, sentences, 3)


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/smt/utils/utility.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from __future__ import division, print_function


def mkcorpus(sentences):
    """
    >>> from pprint import pprint
    >>> sent_pairs = [("僕 は 男 です", "I am a man"),
                      ("私 は 女 です", "I am a girl"),
                      ("私 は 先生 です", "I am a teacher"),
                      ("彼女 は 先生 です", "She is a teacher"),
                      ("彼 は 先生 です", "He is a teacher"),
                      ]
    >>> pprint(mkcorpus(sent_pairs))
    [(['僕', 'は', '男', 'です'], ['I', 'am', 'a', 'man']),
     (['私', 'は', '女', 'です'], ['I', 'am', 'a', 'girl']),
     (['私', 'は', '先生', 'です'], ['I', 'am', 'a', 'teacher']),
     (['彼女', 'は', '先生', 'です'], ['She', 'is', 'a', 'teacher']),
     (['彼', 'は', '先生', 'です'], ['He', 'is', 'a', 'teacher'])]
    """
    return [(es.split(), fs.split()) for (es, fs) in sentences]


def matrix(
        m, n, lst,
        m_text: list=None,
        n_text: list=None):
    """
    m: number of rows
    n: number of columns
    lst: items to mark with "x"

    >>> print(matrix(2, 3, [(1, 1), (2, 3)]))
    |x| | |
    | | |x|
    """

    fmt = ""
    if n_text:
        fmt += " {}\n".format(" ".join(n_text))
    for i in range(1, m+1):
        if m_text:
            fmt += "{:<4.4} ".format(m_text[i-1])
        fmt += "|"
        for j in range(1, n+1):
            if (i, j) in lst:
                fmt += "x|"
            else:
                fmt += " |"
        fmt += "\n"
    return fmt
--------------------------------------------------------------------------------
/smt/db/tables.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from __future__ import division, print_function
# import SQLAlchemy
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, TEXT, REAL, INTEGER


class Tables(object):
    """Factories for the ORM table classes used in this package.

    Each method builds its class on a fresh declarative_base() so that
    the same schema can be mapped under different table names.
    """

    def get_sentence_table(self, tablename="sentence"):

        class Sentence(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            lang1 = Column(TEXT)
            lang2 = Column(TEXT)

        return Sentence

    def get_wordprobability_table(self, tablename):

        class WordProbability(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            transto = Column(TEXT)
            transfrom = Column(TEXT)
            prob = Column(REAL)

        return WordProbability

    def get_wordalignment_table(self, tablename):

        class WordAlignment(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            from_pos = Column(INTEGER)
            to_pos = Column(INTEGER)
            to_len = Column(INTEGER)
            from_len = Column(INTEGER)
            prob = Column(REAL)

        return WordAlignment

    def get_phrase_table(self, tablename="phrase"):

        class Phrase(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            lang1p = Column(TEXT)
            lang2p = Column(TEXT)

        return Phrase

    def get_transphraseprob_table(self, tablename="phraseprob"):

        class TransPhraseProb(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            lang1p = Column(TEXT)
            lang2p = Column(TEXT)
            p1_2 = Column(REAL)
            p2_1 = Column(REAL)

        return TransPhraseProb

    def get_trigram_table(self, tablename):

        class Trigram(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            first = Column(TEXT)
            second = Column(TEXT)
            third = Column(TEXT)
            count = Column(INTEGER)

        return Trigram

    def get_trigramprob_table(self, tablename):

        class TrigramProb(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            first = Column(TEXT)
            second = Column(TEXT)
            third = Column(TEXT)
            prob = Column(REAL)

        return TrigramProb

    def get_trigramprobwithoutlast_table(self, tablename):

        class TrigramProbWithoutLast(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            first = Column(TEXT)
            second = Column(TEXT)
            prob = Column(REAL)

        return TrigramProbWithoutLast

    def get_unigram_table(self, tablename):

        class Unigram(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            first = Column(TEXT)
            count = Column(INTEGER)

        return Unigram

    def get_unigramprob_table(self, tablename):

        class UnigramProb(declarative_base()):
            __tablename__ = tablename
            id = Column(INTEGER, primary_key=True)
            first = Column(TEXT)
            prob = Column(REAL)

        return UnigramProb
--------------------------------------------------------------------------------
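A sketch of how these table factories are used elsewhere in smt/db (the
in-memory URL and the sample sentence are only illustrative):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
    from smt.db.tables import Tables

    engine = create_engine("sqlite:///:memory:")
    Sentence = Tables().get_sentence_table()
    Sentence.__table__.create(engine, checkfirst=True)

    Session = sessionmaker(bind=engine)
    session = Session()
    session.add(Sentence(lang1="私 は 先生 です", lang2="I am a teacher"))
    session.commit()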
/test/test_ibmmodel.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from __future__ import division, print_function
import unittest
import collections
#import keitaiso
from smt.ibmmodel.ibmmodel1 import train
from smt.ibmmodel.ibmmodel2 import viterbi_alignment
#import smt.ibmmodel.ibmmodel2 as ibmmodel2
import decimal
from decimal import Decimal as D

# set decimal context
decimal.getcontext().prec = 4
decimal.getcontext().rounding = decimal.ROUND_HALF_UP


class IBMModel1Test(unittest.TestCase):

    #def _format(self, lst):
    #    return {(k, float('{:.4f}'.format(v))) for (k, v) in lst}

    def test_train_loop1(self):
        sent_pairs = [("the house", "das Haus"),
                      ("the book", "das Buch"),
                      ("a book", "ein Buch"),
                      ]
        #t0 = train(sent_pairs, loop_count=0)
        t1 = train(sent_pairs, loop_count=1)

        loop1 = [(('house', 'Haus'), D("0.5")),
                 (('book', 'ein'), D("0.5")),
                 (('the', 'das'), D("0.5")),
                 (('the', 'Buch'), D("0.25")),
                 (('book', 'Buch'), D("0.5")),
                 (('a', 'ein'), D("0.5")),
                 (('book', 'das'), D("0.25")),
                 (('the', 'Haus'), D("0.5")),
                 (('house', 'das'), D("0.25")),
                 (('a', 'Buch'), D("0.25"))]
        # assertion
        # the next assertion doesn't make sense because
        # t is initialized by a defaultdict
        #self.assertEqual(self._format(t0.items()), self._format(loop0))
        self.assertEqual(set(t1.items()), set(loop1))

    def test_train_loop2(self):
        sent_pairs = [("the house", "das Haus"),
                      ("the book", "das Buch"),
                      ("a book", "ein Buch"),
                      ]
        #t0 = train(sent_pairs, loop_count=0)
        t2 = train(sent_pairs, loop_count=2)

        loop2 = [(('house', 'Haus'), D("0.5713")),
                 (('book', 'ein'), D("0.4284")),
                 (('the', 'das'), D("0.6367")),
                 (('the', 'Buch'), D("0.1818")),
                 (('book', 'Buch'), D("0.6367")),
                 (('a', 'ein'), D("0.5713")),
                 (('book', 'das'), D("0.1818")),
                 (('the', 'Haus'), D("0.4284")),
                 (('house', 'das'), D("0.1818")),
                 (('a', 'Buch'), D("0.1818"))]
        # assertion
        # the next assertion doesn't make sense because
        # t is initialized by a defaultdict
        #self.assertEqual(self._format(t0.items()), self._format(loop0))
        self.assertEqual(set(t2.items()), set(loop2))


class IBMModel2Test(unittest.TestCase):

    def test_viterbi_alignment(self):
        x = viterbi_alignment([1, 2, 1],
                              [2, 3, 2],
                              collections.defaultdict(int),
                              collections.defaultdict(int))
        # viterbi_alignment selects the first token
        # when t or a doesn't contain the key,
        # which amounts to returning the NULL token
        # in such a situation.
        self.assertEqual(x, {1: 1, 2: 1, 3: 1})

    #def test_zero_division_error(self):
    #    """
    #    at the beginning, there was this bug for ZeroDivisionError,
    #    so this test was created to check that
    #    """
    #    sentence = [(u"Xではないかとつくづく疑問に思う",
    #                 u"I often wonder if it might be X."),
    #                (u"Xがいいなといつも思います",
    #                 u"I always think X would be nice."),
    #                (u"それがあるようにいつも思います",
    #                 u"It always seems like it is there."),
    #                ]
    #    sentences = [(keitaiso.str2wakati(s1), s2) for
    #                 s1, s2 in sentence]

    #    self.assertRaises(decimal.DivisionByZero,
    #                      ibmmodel2.train,
    #                      sentences, loop_count=1000)


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/smt/ibmmodel/ibmmodel1.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from __future__ import division, print_function
from operator import itemgetter
import collections
from smt.utils import utility
import decimal
from decimal import Decimal as D

# set decimal context
decimal.getcontext().prec = 4
decimal.getcontext().rounding = decimal.ROUND_HALF_UP


def _constant_factory(value):
    '''return a constant factory for uniform probability initialization'''
    #return itertools.repeat(value).next
    return lambda: value


def _train(corpus, loop_count=1000):
    f_keys = set()
    for (es, fs) in corpus:
        for f in fs:
            f_keys.add(f)
    # the default value provides the uniform probability
    t = collections.defaultdict(_constant_factory(D(1) / len(f_keys)))

    # loop
    for i in range(loop_count):
        count = collections.defaultdict(D)
        total = collections.defaultdict(D)
        s_total = collections.defaultdict(D)
        for (es, fs) in corpus:
            # compute normalization
            for e in es:
                s_total[e] = D()
                for f in fs:
                    s_total[e] += t[(e, f)]
            for e in es:
                for f in fs:
                    count[(e, f)] += t[(e, f)] / s_total[e]
                    total[f] += t[(e, f)] / s_total[e]
                    #if e == u"に" and f == u"always":
                    #    print(" BREAK:", i, count[(e, f)])
        # estimate probability
        for (e, f) in count.keys():
            #if count[(e, f)] == 0:
            #    print(e, f, count[(e, f)])
            t[(e, f)] = count[(e, f)] / total[f]

    return t


def train(sentences, loop_count=1000):
    corpus = utility.mkcorpus(sentences)
    return _train(corpus, loop_count)


def _pprint(tbl):
    for (e, f), v in sorted(tbl.items(), key=itemgetter(1), reverse=True):
        print(u"p({e}|{f}) = {v}".format(e=e, f=f, v=v))


def test_train_loop1():
    sent_pairs = [("the house", "das Haus"),
                  ("the book", "das Buch"),
                  ("a book", "ein Buch"),
                  ]
    #t0 = train(sent_pairs, loop_count=0)
    t1 = train(sent_pairs, loop_count=1)

    loop1 = [(('house', 'Haus'), D("0.5")),
             (('book', 'ein'), D("0.5")),
             (('the', 'das'), D("0.5")),
             (('the', 'Buch'), D("0.25")),
             (('book', 'Buch'), D("0.5")),
             (('a', 'ein'), D("0.5")),
             (('book', 'das'), D("0.25")),
             (('the', 'Haus'), D("0.5")),
             (('house', 'das'), D("0.25")),
             (('a', 'Buch'), D("0.25"))]
    # assertion
    # the next assertion doesn't make sense because
    # t is initialized by a defaultdict
    #self.assertEqual(self._format(t0.items()), self._format(loop0))
    assert set(t1.items()) == set(loop1)


def test_train_loop2():
    sent_pairs = [("the house", "das Haus"),
                  ("the book", "das Buch"),
                  ("a book", "ein Buch"),
                  ]
    #t0 = train(sent_pairs, loop_count=0)
    t2 = train(sent_pairs, loop_count=2)

    loop2 = [(('house', 'Haus'), D("0.5713")),
             (('book', 'ein'), D("0.4284")),
             (('the', 'das'), D("0.6367")),
             (('the', 'Buch'), D("0.1818")),
             (('book', 'Buch'), D("0.6367")),
             (('a', 'ein'), D("0.5713")),
             (('book', 'das'), D("0.1818")),
             (('the', 'Haus'), D("0.4284")),
             (('house', 'das'), D("0.1818")),
             (('a', 'Buch'), D("0.1818"))]
    # assertion
    # the next assertion doesn't make sense because
    # t is initialized by a defaultdict
    #self.assertEqual(self._format(t0.items()), self._format(loop0))
    assert set(t2.items()) == set(loop2)


if __name__ == '__main__':
    import sys

    fd = open(sys.argv[1]) if len(sys.argv) >= 2 else sys.stdin
    sentences = [line.strip().split('|||') for line in fd.readlines()]
    t = train(sentences, loop_count=3)
    for (e, f), val in t.items():
        print("{} {}\t{}".format(e, f, val))
--------------------------------------------------------------------------------
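A sanity check on the loop-1 numbers used in the tests above: with the four
target-side words {das, Haus, Buch, ein} the table starts uniform, the first
E-step splits each count evenly, so "das" collects total 2.0 while the pair
(the, das) collects 1.0, giving t(the|das) = 1.0 / 2.0 = 0.5:

    from smt.ibmmodel.ibmmodel1 import train

    sent_pairs = [("the house", "das Haus"),
                  ("the book", "das Buch"),
                  ("a book", "ein Buch")]
    t = train(sent_pairs, loop_count=1)
    print(t[("the", "das")])    # Decimal('0.5')
    print(t[("book", "das")])   # Decimal('0.25')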
/smt/phrase/word_alignment.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from __future__ import division, print_function
from smt.ibmmodel import ibmmodel2
from pprint import pprint


def _alignment(elist, flist, e2f, f2e):
    '''
    elist, flist
        word list for each language
    e2f
        translation alignment from e to f;
        alignment is [(e, f)]
    f2e
        translation alignment from f to e;
        alignment is [(e, f)]
    return
        alignment: {(e, f)}

             flist
          -----------------
        e |               |
        l |               |
        i |               |
        s |               |
        t |               |
          -----------------

    '''
    neighboring = {(-1, 0), (0, -1), (1, 0), (0, 1),
                   (-1, -1), (-1, 1), (1, -1), (1, 1)}
    e2f = set(e2f)
    f2e = set(f2e)
    m = len(elist)
    n = len(flist)
    alignment = e2f.intersection(f2e)
    # merge with neighboring alignment points
    while True:
        set_len = len(alignment)
        for e_word in range(1, m+1):
            for f_word in range(1, n+1):
                if (e_word, f_word) in alignment:
                    for (e_diff, f_diff) in neighboring:
                        e_new = e_word + e_diff
                        f_new = f_word + f_diff
                        if not alignment:
                            if (e_new, f_new) in e2f.union(f2e):
                                alignment.add((e_new, f_new))
                        else:
                            if ((e_new not in list(zip(*alignment))[0]
                                    or f_new not in list(zip(*alignment))[1])
                                    and (e_new, f_new) in e2f.union(f2e)):
                                alignment.add((e_new, f_new))
        if set_len == len(alignment):
            break
    # finalize
    for e_word in range(1, m+1):
        for f_word in range(1, n+1):
            # for alignment = set([])
            if not alignment:
                if (e_word, f_word) in e2f.union(f2e):
                    alignment.add((e_word, f_word))
            else:
                if ((e_word not in list(zip(*alignment))[0]
                        or f_word not in list(zip(*alignment))[1])
                        and (e_word, f_word) in e2f.union(f2e)):
                    alignment.add((e_word, f_word))
    return alignment


def alignment(es, fs, e2f, f2e):
    """
    es: English words
    fs: Foreign words
    f2e: alignment for translation from fs to es
         [(e, f)] or {(e, f)}
    e2f: alignment for translation from es to fs
         [(f, e)] or {(f, e)}
    """
    _e2f = list(zip(*reversed(list(zip(*e2f)))))
    return _alignment(es, fs, _e2f, f2e)

def symmetrization(es, fs, corpus):
    '''
    corpus
        for translation from fs to es
    return
        alignment **from fs to es**
    '''
    f2e_train = ibmmodel2._train(corpus, loop_count=10)
    f2e = ibmmodel2.viterbi_alignment(es, fs, *f2e_train).items()

    e2f_corpus = list(zip(*reversed(list(zip(*corpus)))))
    e2f_train = ibmmodel2._train(e2f_corpus, loop_count=10)
    e2f = ibmmodel2.viterbi_alignment(fs, es, *e2f_train).items()

    return alignment(es, fs, e2f, f2e)


if __name__ == '__main__':
    # test for alignment
    es = "michael assumes that he will stay in the house".split()
    fs = "michael geht davon aus , dass er im haus bleibt".split()
    e2f = [(1, 1), (2, 2), (2, 3), (2, 4), (3, 6),
           (4, 7), (7, 8), (9, 9), (6, 10)]
    f2e = [(1, 1), (2, 2), (3, 6), (4, 7), (7, 8),
           (8, 8), (9, 9), (5, 10), (6, 10)]
    from smt.utils.utility import matrix
    print(matrix(len(es), len(fs), e2f, es, fs))
    print(matrix(len(es), len(fs), f2e, es, fs))
    ali = _alignment(es, fs, e2f, f2e)
    print(matrix(len(es), len(fs), ali, es, fs))

    # test for symmetrization
    from smt.utils.utility import mkcorpus
    sentences = [("僕 は 男 です", "I am a man"),
                 ("私 は 女 です", "I am a girl"),
                 ("私 は 先生 です", "I am a teacher"),
                 ("彼女 は 先生 です", "She is a teacher"),
                 ("彼 は 先生 です", "He is a teacher"),
                 ]
    corpus = mkcorpus(sentences)
    es = "私 は 先生 です".split()
    fs = "I am a teacher".split()
    syn = symmetrization(es, fs, corpus)
    pprint(syn)
    print(matrix(len(es), len(fs), syn, es, fs))
--------------------------------------------------------------------------------
/smt/phrase/phrase_extract.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8


def phrase_extract(es, fs, alignment):
    ext = extract(es, fs, alignment)
    ind = {((x, y), (z, w)) for x, y, z, w in ext}
    es = tuple(es)
    fs = tuple(fs)
    return {(es[e_s-1:e_e], fs[f_s-1:f_e])
            for (e_s, e_e), (f_s, f_e) in ind}


def extract(es, fs, alignment):
    """
    caution:
        alignment indices start from 1, not 0
    """
    phrases = set()
    len_es = len(es)
    for e_start in range(1, len_es+1):
        for e_end in range(e_start, len_es+1):
            # find the minimally matching foreign phrase
            f_start, f_end = (len(fs), 0)
            for (e, f) in alignment:
                if e_start <= e <= e_end:
                    f_start = min(f, f_start)
                    f_end = max(f, f_end)
            phrases.update(_extract(es, fs, e_start,
                                    e_end, f_start,
                                    f_end, alignment))
    return phrases


def _extract(es, fs, e_start, e_end, f_start, f_end, alignment):
    if f_end == 0:
        return set()
    for (e, f) in alignment:
        if (f_start <= f <= f_end) and (e < e_start or e > e_end):
            return set()
    ex = set()
    f_s = f_start
    while True:
        f_e = f_end
        while True:
            ex.add((e_start, e_end, f_s, f_e))
            f_e += 1
            if f_e in list(zip(*alignment))[1] or f_e > len(fs):
                break
        f_s -= 1
        if f_s in list(zip(*alignment))[1] or f_s < 1:
            break
    return ex


def available_phrases(fs, phrases):
    """
    return:
        set of phrase indexed tuple like
            {((1, "I"), (2, "am")),
             ((1, "I"),)
             ...}
    """
    available = set()
    for i, f in enumerate(fs):
        f_rest = ()
        for fr in fs[i:]:
            f_rest += (fr,)
            if f_rest in phrases:
                available.add(tuple(enumerate(f_rest, i+1)))
    return available


def test_phrases():
    from smt.utils.utility import mkcorpus
    from smt.phrase.word_alignment import symmetrization

    sentences = [("僕 は 男 です", "I am a man"),
                 ("私 は 女 です", "I am a girl"),
                 ("私 は 先生 です", "I am a teacher"),
                 ("彼女 は 先生 です", "She is a teacher"),
                 ("彼 は 先生 です", "He is a teacher"),
                 ]

    corpus = mkcorpus(sentences)
    es, fs = ("私 は 先生 です".split(), "I am a teacher".split())
    alignment = symmetrization(es, fs, corpus)
    ext = phrase_extract(es, fs, alignment)
    ans = ("は 先生 です <-> a teacher",
           "先生 <-> teacher",
           "私 <-> I am",
           "私 は 先生 です <-> I am a teacher")
    for e, f in ext:
        print("{} {} {}".format(' '.join(e), "<->", ' '.join(f)))

    ## phrases
    fs = "I am a teacher".split()
    phrases = available_phrases(fs, [fs_ph for (es_ph, fs_ph) in ext])
    print(phrases)
    ans = {((1, 'I'), (2, 'am')),
           ((1, 'I'), (2, 'am'), (3, 'a'), (4, 'teacher')),
           ((4, 'teacher'),),
           ((3, 'a'), (4, 'teacher'))}

    assert ans == phrases


if __name__ == '__main__':

    # test2
    from smt.utils.utility import mkcorpus
    from smt.phrase.word_alignment import alignment
    from smt.ibmmodel import ibmmodel2
    import sys

    delimiter = ","
    # load the file to be trained on
    modelfd = open(sys.argv[1])
    sentences = [line.rstrip().split(delimiter) for line
                 in modelfd.readlines()]
    # make corpus
    corpus = mkcorpus(sentences)

    # train model from corpus
    f2e_train = ibmmodel2._train(corpus, loop_count=10)
    e2f_corpus = list(zip(*reversed(list(zip(*corpus)))))
    e2f_train = ibmmodel2._train(e2f_corpus, loop_count=10)

    # phrase extraction
    for line in sys.stdin:
        _es, _fs = line.rstrip().split(delimiter)
        es = _es.split()
        fs = _fs.split()

        f2e = ibmmodel2.viterbi_alignment(es, fs, *f2e_train).items()
        e2f = ibmmodel2.viterbi_alignment(fs, es, *e2f_train).items()
        align = alignment(es, fs, e2f, f2e)  # symmetrized alignment

        # output matrix
        #from smt.utils.utility import matrix
        #print(matrix(len(es), len(fs), align, es, fs))

        ext = phrase_extract(es, fs, align)
        for e, f in ext:
            print("{}{}{}".format(''.join(e), delimiter, ''.join(f)))
--------------------------------------------------------------------------------
/smt/ibmmodel/ibmmodel2.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

import collections
from smt.ibmmodel import ibmmodel1
from smt.utils import utility
import decimal
from decimal import Decimal as D

# set decimal context
decimal.getcontext().prec = 4
decimal.getcontext().rounding = decimal.ROUND_HALF_UP


class _keydefaultdict(collections.defaultdict):
    '''defaultdict whose default_factory receives the missing key'''
    def __missing__(self, key):
        if self.default_factory is None:
            raise KeyError(key)
        else:
            ret = self[key] = self.default_factory(key)
            return ret


def _train(corpus, loop_count=1000):
    #print(corpus)
    #print(loop_count)
    f_keys = set()
    for (es, fs) in corpus:
        for f in fs:
            f_keys.add(f)
    # initialize t with IBM Model 1
    t = ibmmodel1._train(corpus, loop_count)
    # the default value provides a uniform probability

    def key_fun(key):
        ''' default_factory function for _keydefaultdict '''
        i, j, l_e, l_f = key
        return D("1") / D(l_f + 1)
    a = _keydefaultdict(key_fun)

    # loop
    for _i in range(loop_count):
        # variables for estimating t
        count = collections.defaultdict(D)
        total = collections.defaultdict(D)
        # variables for estimating a
        count_a = collections.defaultdict(D)
        total_a = collections.defaultdict(D)

        s_total = collections.defaultdict(D)
        for (es, fs) in corpus:
            l_e = len(es)
            l_f = len(fs)
            # compute normalization
            for (j, e) in enumerate(es, 1):
                s_total[e] = 0
                for (i, f) in enumerate(fs, 1):
                    s_total[e] += t[(e, f)] * a[(i, j, l_e, l_f)]
            # collect counts
            for (j, e) in enumerate(es, 1):
                for (i, f) in enumerate(fs, 1):
                    c = t[(e, f)] * a[(i, j, l_e, l_f)] / s_total[e]
                    count[(e, f)] += c
                    total[f] += c
                    count_a[(i, j, l_e, l_f)] += c
                    total_a[(j, l_e, l_f)] += c

        #for k, v in total.items():
        #    if v == 0:
        #        print(k, v)
        # estimate probability
        for (e, f) in count.keys():
            try:
                t[(e, f)] = count[(e, f)] / total[f]
            except decimal.DivisionByZero:
                print(u"e: {e}, f: {f}, count[(e, f)]: {ef}, total[f]: \
{totalf}".format(e=e, f=f, ef=count[(e, f)],
                 totalf=total[f]))
                raise
        for (i, j, l_e, l_f) in count_a.keys():
            a[(i, j, l_e, l_f)] = count_a[(i, j, l_e, l_f)] / \
                total_a[(j, l_e, l_f)]
        # output
        #for (e, f), val in t.items():
        #    print("{} {}\t{}".format(e, f, float(val)))
        #for (i, j, l_e, l_f), val in a.items():
        #    print("{} {} {} {}\t{}".format(i, j, l_e, l_f, float(val)))

    return (t, a)

def train(sentences, loop_count=1000):
    #for i, j in sentences:
    #    print(i, j)
    corpus = utility.mkcorpus(sentences)
    return _train(corpus, loop_count)


def viterbi_alignment(es, fs, t, a):
    '''
    return
        dictionary
            e in es -> f in fs
    '''
    max_a = collections.defaultdict(float)
    l_e = len(es)
    l_f = len(fs)
    for (j, e) in enumerate(es, 1):
        current_max = (0, -1)
        for (i, f) in enumerate(fs, 1):
            val = t[(e, f)] * a[(i, j, l_e, l_f)]
            # select the first one among the maximum candidates
            if current_max[1] < val:
                current_max = (i, val)
        max_a[j] = current_max[0]
    return max_a


def show_matrix(es, fs, t, a):
    '''
    print matrix according to viterbi alignment like
        fs
      -------------
    e |           |
    s |           |
      |           |
      -------------
    >>> sentences = [("僕 は 男 です", "I am a man"),
                     ("私 は 女 です", "I am a girl"),
                     ("私 は 先生 です", "I am a teacher"),
                     ("彼女 は 先生 です", "She is a teacher"),
                     ("彼 は 先生 です", "He is a teacher"),
                     ]
    >>> t, a = train(sentences, loop_count=1000)
    >>> args = ("私 は 先生 です".split(), "I am a teacher".split(), t, a)
    >>> print(show_matrix(*args))
    |x| | | |
    | | |x| |
    | | | |x|
    | | |x| |
    '''
    max_a = viterbi_alignment(es, fs, t, a).items()
    m = len(es)
    n = len(fs)
    return utility.matrix(m, n, max_a, es, fs)


def test_viterbi_alignment():
    x = viterbi_alignment([1, 2, 1],
                          [2, 3, 2],
                          collections.defaultdict(int),
                          collections.defaultdict(int))
    # viterbi_alignment selects the first token
    # when t or a doesn't contain the key,
    # which amounts to returning the NULL token
    # in such a situation.
    ans = {1: 1, 2: 1, 3: 1}
    assert dict(x) == ans


if __name__ == '__main__':
    import sys

    fd = open(sys.argv[1]) if len(sys.argv) >= 2 else sys.stdin
    sentences = [line.strip().split('|||') for line in fd.readlines()]
    t, a = train(sentences, loop_count=10)

    es = "私 は 先生 です".split()
    fs = "I am a teacher".split()
    args = (es, fs, t, a)

    print(show_matrix(*args))
--------------------------------------------------------------------------------
/test/test_phrase.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

import unittest
from smt.phrase.word_alignment import _alignment
from smt.phrase.word_alignment import symmetrization
from smt.phrase.phrase_extract import extract
from smt.phrase.phrase_extract import phrase_extract
from smt.phrase.phrase_extract import available_phrases
from smt.utils.utility import mkcorpus


class WordAlignmentTest(unittest.TestCase):

    def test_alignment(self):
        elist = "michael assumes that he will stay in the house".split()
        flist = "michael geht davon aus , dass er im haus bleibt".split()
        e2f = [(1, 1), (2, 2), (2, 3), (2, 4), (3, 6),
               (4, 7), (7, 8), (9, 9), (6, 10)]
        f2e = [(1, 1), (2, 2), (3, 6), (4, 7), (7, 8),
               (8, 8), (9, 9), (5, 10), (6, 10)]
        ans = set([(1, 1),
                   (2, 2),
                   (2, 3),
                   (2, 4),
                   (3, 6),
                   (4, 7),
                   (5, 10),
                   (6, 10),
                   (7, 8),
                   (8, 8),
                   (9, 9)])
        self.assertEqual(_alignment(elist, flist, e2f, f2e), ans)

    def test_symmetrization(self):
        sentences = [("僕 は 男 です", "I am a man"),
                     ("私 は 女 です", "I am a girl"),
                     ("私 は 先生 です", "I am a teacher"),
                     ("彼女 は 先生 です", "She is a teacher"),
                     ("彼 は 先生 です", "He is a teacher"),
                     ]
        corpus = mkcorpus(sentences)
        es = "私 は 先生 です".split()
        fs = "I am a teacher".split()
        syn = symmetrization(es, fs, corpus)
        ans = set([(1, 1), (1, 2), (2, 3), (3, 4), (4, 3)])
        self.assertEqual(syn, ans)


class PhraseExtractTest(unittest.TestCase):
    def test_extract(self):

        # next alignment matrix is like
        #
        # | |x|x| | |
        # |x| | |x| |
        # | | | | |x|
        #
        es = range(1, 4)
        fs = range(1, 6)
        alignment = [(2, 1),
                     (1, 2),
                     (1, 3),
                     (2, 4),
                     (3, 5)]
        ans = set([(1, 1, 2, 3), (1, 3, 1, 5), (3, 3, 5, 5), (1, 2, 1, 4)])
        self.assertEqual(extract(es, fs, alignment), ans)

        # next alignment matrix is like
        #
        # |x| | | | | | | | | |
        # | |x|x|x| | | | | | |
        # | | | | | |x| | | | |
        # | | | | | | |x| | | |
        # | | | | | | | | | |x|
        # | | | | | | | | | |x|
        # | | | | | | | |x| | |
        # | | | | | | | |x| | |
        # | | | | | | | | |x| |
        #
        es = "michael assumes that he will stay in the house".split()
        fs = "michael geht davon aus , dass er im haus bleibt".split()
        alignment = set([(1, 1),
                         (2, 2),
                         (2, 3),
                         (2, 4),
                         (3, 6),
                         (4, 7),
                         (5, 10),
                         (6, 10),
                         (7, 8),
                         (8, 8),
                         (9, 9)])
        ans = set([(1, 1, 1, 1),
                   (1, 2, 1, 4),
                   (1, 2, 1, 5),
                   (1, 3, 1, 6),
                   (1, 4, 1, 7),
                   (1, 9, 1, 10),
                   (2, 2, 2, 4),
                   (2, 2, 2, 5),
                   (2, 3, 2, 6),
                   (2, 4, 2, 7),
                   (2, 9, 2, 10),
                   (3, 3, 5, 6),
                   (3, 3, 6, 6),
                   (3, 4, 5, 7),
                   (3, 4, 6, 7),
                   (3, 9, 5, 10),
                   (3, 9, 6, 10),
                   (4, 4, 7, 7),
                   (4, 9, 7, 10),
                   (5, 6, 10, 10),
                   (5, 9, 8, 10),
                   (7, 8, 8, 8),
                   (7, 9, 8, 9),
                   (9, 9, 9, 9)])

        self.assertEqual(extract(es, fs, alignment), ans)
    def test_phrase_extract(self):
        # next alignment matrix is like
        #
        # |x| | | | | | | | | |
        # | |x|x|x| | | | | | |
        # | | | | | |x| | | | |
        # | | | | | | |x| | | |
        # | | | | | | | | | |x|
        # | | | | | | | | | |x|
        # | | | | | | | |x| | |
        # | | | | | | | |x| | |
        # | | | | | | | | |x| |
        #
        es = "michael assumes that he will stay in the house".split()
        fs = "michael geht davon aus , dass er im haus bleibt".split()
        alignment = set([(1, 1),
                         (2, 2),
                         (2, 3),
                         (2, 4),
                         (3, 6),
                         (4, 7),
                         (5, 10),
                         (6, 10),
                         (7, 8),
                         (8, 8),
                         (9, 9)])
        ans = set([(('assumes',), ('geht', 'davon', 'aus')),
                   (('assumes',), ('geht', 'davon', 'aus', ',')),
                   (('assumes', 'that'),
                    ('geht', 'davon', 'aus', ',', 'dass')),
                   (('assumes', 'that', 'he'),
                    ('geht', 'davon', 'aus', ',', 'dass', 'er')),
                   (('assumes', 'that', 'he',
                     'will', 'stay', 'in', 'the', 'house'),
                    ('geht', 'davon', 'aus', ',', 'dass',
                     'er', 'im', 'haus', 'bleibt')),
                   (('he',), ('er',)),
                   (('he', 'will', 'stay', 'in', 'the', 'house'),
                    ('er', 'im', 'haus', 'bleibt')),
                   (('house',), ('haus',)),
                   (('in', 'the'), ('im',)),
                   (('in', 'the', 'house'), ('im', 'haus')),
                   (('michael',), ('michael',)),
                   (('michael', 'assumes'),
                    ('michael', 'geht', 'davon', 'aus')),
                   (('michael', 'assumes'),
                    ('michael', 'geht', 'davon', 'aus', ',')),
                   (('michael', 'assumes', 'that'),
                    ('michael', 'geht', 'davon', 'aus', ',', 'dass')),
                   (('michael', 'assumes', 'that', 'he'),
                    ('michael', 'geht', 'davon', 'aus', ',', 'dass', 'er')),
                   (('michael',
                     'assumes',
                     'that',
                     'he',
                     'will',
                     'stay',
                     'in',
                     'the',
                     'house'),
                    ('michael',
                     'geht',
                     'davon',
                     'aus',
                     ',',
                     'dass',
                     'er',
                     'im',
                     'haus',
                     'bleibt')),
                   (('that',), (',', 'dass')),
                   (('that',), ('dass',)),
                   (('that', 'he'), (',', 'dass', 'er')),
                   (('that', 'he'), ('dass', 'er')),
                   (('that', 'he', 'will', 'stay', 'in', 'the', 'house'),
                    (',', 'dass', 'er', 'im', 'haus', 'bleibt')),
                   (('that', 'he', 'will', 'stay', 'in', 'the', 'house'),
                    ('dass', 'er', 'im', 'haus', 'bleibt')),
                   (('will', 'stay'), ('bleibt',)),
                   (('will', 'stay', 'in', 'the', 'house'),
                    ('im', 'haus', 'bleibt'))])
        self.assertEqual(phrase_extract(es, fs, alignment), ans)

        # another test
        es, fs = ("私 は 先生 です".split(), "I am a teacher".split())
        sentences = [("僕 は 男 です", "I am a man"),
                     ("私 は 女 です", "I am a girl"),
                     ("私 は 先生 です", "I am a teacher"),
                     ("彼女 は 先生 です", "She is a teacher"),
                     ("彼 は 先生 です", "He is a teacher"),
                     ]
        corpus = mkcorpus(sentences)
        alignment = symmetrization(es, fs, corpus)
        ans = set([(('は', '先生', 'です'),
                    ('a', 'teacher')),
                   (('先生',), ('teacher',)),
                   (('私',), ('I', 'am')),
                   (('私', 'は', '先生', 'です'),
                    ('I', 'am', 'a', 'teacher'))])
        self.assertEqual(phrase_extract(es, fs, alignment), ans)

    def test_available_phrases(self):
        fs = "I am a teacher".split()
        phrases = set([("I", "am"),
                       ("a", "teacher"),
                       ("teacher",),
                       ("I", "am", "a", "teacher")])

        ans = set([((4, 'teacher'),),
                   ((1, 'I'), (2, 'am')),
                   ((3, 'a'), (4, 'teacher')),
                   ((1, 'I'), (2, 'am'), (3, 'a'), (4, 'teacher'))])
        self.assertEqual(available_phrases(fs, phrases), ans)

if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
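The additive smoothing used by create_ngram_prob and create_unigram_prob below,
checked in isolation (alpha matches the constant in the code; the counts here
are only illustrative):

    import math

    alpha = 0.00017
    c, n, v = 3, 10, 1000   # trigram count, history count, vocabulary size
    # smoothed log P(third | first, second), as stored in lang{n}trigramprob
    logprob = math.log((c + alpha) / (n + alpha * v))
    print(logprob)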
/smt/db/createngramdb.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from __future__ import division, print_function
import collections
import sqlite3
# import SQLAlchemy
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# smt
from smt.db.tables import Tables
from smt.langmodel.ngram import ngram
import math


def _create_ngram_count_db(lang, langmethod=lambda x: x,
                           n=3, db="sqlite:///:memory:"):
    engine = create_engine(db)
    # create session
    Session = sessionmaker(bind=engine)
    session = Session()

    Sentence = Tables().get_sentence_table()
    query = session.query(Sentence)

    ngram_dic = collections.defaultdict(float)
    for item in query:
        if lang == 1:
            sentences = langmethod(item.lang1).split()
        elif lang == 2:
            sentences = langmethod(item.lang2).split()
        sentences = ["<s>", "<s>"] + sentences + ["</s>"]
        ngrams = ngram(sentences, n)
        for tpl in ngrams:
            ngram_dic[tpl] += 1

    return ngram_dic


def create_ngram_count_db(lang, langmethod=lambda x: x,
                          n=3, db="sqlite:///:memory:"):
    engine = create_engine(db)
    # create session
    Session = sessionmaker(bind=engine)
    session = Session()

    # trigram table
    tablename = 'lang{}trigram'.format(lang)
    Trigram = Tables().get_trigram_table(tablename)
    # create table
    Trigram.__table__.drop(engine, checkfirst=True)
    Trigram.__table__.create(engine)

    ngram_dic = _create_ngram_count_db(lang, langmethod=langmethod, n=n, db=db)

    # insert items
    for (first, second, third), count in ngram_dic.items():
        print(u"inserting {}, {}, {}".format(first, second, third))
        item = Trigram(first=first,
                       second=second,
                       third=third,
                       count=count)
        session.add(item)
    session.commit()


def create_unigram_count_db(lang, langmethod=lambda x: x,
                            db="sqlite:///:memory:"):
    engine = create_engine(db)
    # create session
    Session = sessionmaker(bind=engine)
    session = Session()

    # unigram table
    tablename = 'lang{}unigram'.format(lang)
    Sentence = Tables().get_sentence_table()
    Unigram = Tables().get_unigram_table(tablename)
    # create table
    Unigram.__table__.drop(engine, checkfirst=True)
    Unigram.__table__.create(engine)

    query = session.query(Sentence)
    ngram_dic = collections.defaultdict(int)
    for item in query:
        if lang == 1:
            sentences = langmethod(item.lang1).split()
        elif lang == 2:
            sentences = langmethod(item.lang2).split()
        ngrams = ngram(sentences, 1)
        for tpl in ngrams:
            ngram_dic[tpl] += 1

    # insert items
    for (first,), count in ngram_dic.items():
        print(u"inserting {}: {}".format(first, count))
        item = Unigram(first=first,
                       count=count)
        session.add(item)
    session.commit()


# create views using SQLite3
def create_ngram_count_without_last_view(lang, db=":memory:"):
    # create phrase_count table
    fromtablename = "lang{}trigram".format(lang)
    table_name = "lang{}trigram_without_last".format(lang)
    # create connection
    con = sqlite3.connect(db)
    cur = con.cursor()
    try:
        cur.execute("drop view {0}".format(table_name))
    except sqlite3.Error:
        print("{0} view does not exist.\n\
=> creating a new view".format(table_name))
cur.execute("""create view {} 116 | as select first, second, sum(count) as count from 117 | {} group by first, second order by count 118 | desc""".format(table_name, fromtablename)) 119 | con.commit() 120 | 121 | 122 | def create_ngram_prob(lang, 123 | db=":memory:"): 124 | 125 | # Create connection in sqlite3 to use view 126 | table_name = "lang{}trigram_without_last".format(lang) 127 | # create connection 128 | con = sqlite3.connect(db) 129 | cur = con.cursor() 130 | 131 | trigram_tablename = 'lang{}trigram'.format(lang) 132 | trigramprob_tablename = 'lang{}trigramprob'.format(lang) 133 | trigramprobwithoutlast_tablename = 'lang{}trigramprob_without_last'\ 134 | .format(lang) 135 | 136 | # tables 137 | Trigram = Tables().get_trigram_table(trigram_tablename) 138 | TrigramProb = Tables().get_trigramprob_table(trigramprob_tablename) 139 | TrigramProbWithoutLast = Tables().get_trigramprobwithoutlast_table( 140 | trigramprobwithoutlast_tablename) 141 | 142 | # create connection in SQLAlchemy 143 | sqlalchemydb = "sqlite:///{}".format(db) 144 | engine = create_engine(sqlalchemydb) 145 | # create session 146 | Session = sessionmaker(bind=engine) 147 | session = Session() 148 | # create table 149 | TrigramProb.__table__.drop(engine, checkfirst=True) 150 | TrigramProb.__table__.create(engine) 151 | TrigramProbWithoutLast.__table__.drop(engine, checkfirst=True) 152 | TrigramProbWithoutLast.__table__.create(engine) 153 | 154 | # calculate total number 155 | query = session.query(Trigram) 156 | totalnumber = len(query.all()) 157 | 158 | # get trigrams 159 | query = session.query(Trigram) 160 | for item in query: 161 | first, second, third = item.first, item.second, item.third 162 | count = item.count 163 | 164 | cur.execute("""select * from {} where \ 165 | first=? 
                    second=?""".format(table_name),
                    (first, second))
        one = cur.fetchone()
        # if the fetch fails, one is None (no exception is raised)
        if not one:
            print("no matching (first, second) row found")
            continue
        else:
            alpha = 0.00017
            c = count
            n = one[2]
            v = totalnumber
            # create logprob
            logprob = math.log((c + alpha) / (n + alpha * v))
            print(u"{}, {}, {}: log(({} + {}) / ({} + {} * {})) = {}".format(
                first, second, third, c, alpha, n, alpha, v, logprob))
            trigramprob = TrigramProb(first=first,
                                      second=second,
                                      third=third,
                                      prob=logprob)
            session.add(trigramprob)
            # for without last
            logprobwithoutlast = math.log(alpha / (n + alpha * v))
            print(u"{}, {}, {}: log({} / ({} + {} * {})) = {}".format(
                first, second, third, alpha, n, alpha, v,
                logprobwithoutlast))
            probwl = TrigramProbWithoutLast(first=first,
                                            second=second,
                                            prob=logprobwithoutlast)
            session.add(probwl)
    session.commit()


def create_unigram_prob(lang, db=":memory:"):

    unigram_tablename = 'lang{}unigram'.format(lang)
    unigramprob_tablename = 'lang{}unigramprob'.format(lang)

    # tables
    Unigram = Tables().get_unigram_table(unigram_tablename)
    UnigramProb = Tables().get_unigramprob_table(unigramprob_tablename)

    # create engine
    sqlalchemydb = "sqlite:///{}".format(db)
    engine = create_engine(sqlalchemydb)
    # create session
    Session = sessionmaker(bind=engine)
    session = Session()
    # create table
    UnigramProb.__table__.drop(engine, checkfirst=True)
    UnigramProb.__table__.create(engine)

    # calculate total number
    query = session.query(Unigram)
    sm = 0
    totalnumber = 0
    for item in query:
        totalnumber += 1
        sm += item.count

    # get unigrams
    query = session.query(Unigram)
    for item in query:
        first = item.first
        count = item.count

        alpha = 0.00017
        c = count
        v = totalnumber
        # create logprob
        logprob = math.log((c + alpha) / (sm + alpha * v))
        print(u"{}: log(({} + {}) / ({} + {} * {})) = {}".format(
            first, c, alpha, sm, alpha, v, logprob))
        unigramprob = UnigramProb(first=first,
                                  prob=logprob)
        session.add(unigramprob)
    session.commit()


def create_ngram_db(lang, langmethod=lambda x: x,
                    n=3, db=":memory:"):

    sqlalchemydb = "sqlite:///{}".format(db)
    create_ngram_count_db(lang=lang, langmethod=langmethod,
                          n=n,
                          db=sqlalchemydb)
    create_ngram_count_without_last_view(lang=lang, db=db)
    create_ngram_prob(lang=lang, db=db)

    create_unigram_count_db(lang=lang, langmethod=langmethod,
                            db=sqlalchemydb)
    create_unigram_prob(lang=lang, db=db)


if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------
/smt/db/createdb.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from __future__ import division, print_function
import collections
from smt.utils import utility
from smt.ibmmodel import ibmmodel2
from smt.phrase import word_alignment
from smt.phrase import phrase_extract
from progressline import ProgressLine
from smt.db.tables import Tables
# import SQLAlchemy
import sqlalchemy
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import sqlite3
import math


def create_corpus(db="sqlite:///:memory:",
                  lang1method=lambda x: x,
                  lang2method=lambda x: x,
                  limit=None):
    engine = create_engine(db)
    # create session
    Session = sessionmaker(bind=engine)
    session = Session()

    Sentence = Tables().get_sentence_table()

    query = session.query(Sentence)[:limit] if limit \
        else session.query(Sentence)

    for item in query:
        yield {"lang1": lang1method(item.lang1),
               "lang2": lang2method(item.lang2)}


def create_train_db(transfrom=2,
                    transto=1,
                    lang1method=lambda x: x,
                    lang2method=lambda x: x,
                    db="sqlite:///:memory:",
                    limit=None,
                    loop_count=1000):
    engine = create_engine(db)
    # create session
    Session = sessionmaker(bind=engine)
    session = Session()

    # tablenames
    table_prefix = "from{0}to{1}".format(transfrom, transto)
    wordprob_tablename = table_prefix + "_" + "wordprob"
    wordalign_tablename = table_prefix + "_" + "wordalign"
    # tables
    WordProbability = Tables().get_wordprobability_table(wordprob_tablename)
    WordAlignment = Tables().get_wordalignment_table(wordalign_tablename)
    # create table for word probability
    WordProbability.__table__.drop(engine, checkfirst=True)
    WordProbability.__table__.create(engine)
    print("created table: from{0}to{1}_wordprob".format(transfrom, transto))

    # create table for alignment probability
    WordAlignment.__table__.drop(engine, checkfirst=True)
    WordAlignment.__table__.create(engine)
    print("created table: from{0}to{1}_wordalign".format(transfrom, transto))

    # IBM learning
    with ProgressLine(0.12, title='IBM Model learning...'):
        # check arguments for create_corpus
        corpus = create_corpus(db=db, limit=limit,
                               lang1method=lang1method,
                               lang2method=lang2method)
        sentences = [(item["lang{0}".format(transto)],
                      item["lang{0}".format(transfrom)])
                     for item in corpus]
        t, a = ibmmodel2.train(sentences=sentences,
                               loop_count=loop_count)
    # insert
    with ProgressLine(0.12, title='Inserting items into database...'):
        for (_to, _from), prob in t.items():
            session.add(WordProbability(transto=_to,
                                        transfrom=_from,
                                        prob=float(prob)))
        for (from_pos, to_pos, to_len, from_len), prob in a.items():
            session.add(WordAlignment(from_pos=from_pos,
                                      to_pos=to_pos,
                                      to_len=to_len,
                                      from_len=from_len,
                                      prob=float(prob)))
        session.commit()


def db_viterbi_alignment(es, fs,
                         transfrom=2,
                         transto=1,
                         db="sqlite:///:memory:",
                         init_val=1.0e-10):
    """
    Calculate the Viterbi alignment using the specified database.

    Arguments:
        transfrom, transto:
            language ids (1 or 2) that select the translation direction
    """
    engine = create_engine(db)
    # create session
    Session = sessionmaker(bind=engine)
    session = Session()

    # tablenames
    table_prefix = "from{0}to{1}".format(transfrom, transto)
    wordprob_tablename = table_prefix + "_" + "wordprob"
    wordalign_tablename = table_prefix + "_" + "wordalign"
    # tables
    WordProbability = Tables().get_wordprobability_table(wordprob_tablename)
    WordAlignment = Tables().get_wordalignment_table(wordalign_tablename)

    def get_wordprob(e, f, init_val=1.0e-10):

        query = session.query(WordProbability).filter_by(transto=e,
                                                         transfrom=f)
        try:
            return query.one().prob
        except sqlalchemy.orm.exc.NoResultFound:
            return init_val

    def get_wordalign(i, j, l_e, l_f, init_val=1.0e-10):

        query = session.query(WordAlignment).filter_by(from_pos=i,
                                                       to_pos=j,
                                                       to_len=l_e,
                                                       from_len=l_f)
        try:
            return query.one().prob
        except sqlalchemy.orm.exc.NoResultFound:
            return init_val

    # algorithm
    max_a = collections.defaultdict(float)
    l_e = len(es)
    l_f = len(fs)
    for (j, e) in enumerate(es, 1):
        current_max = (0, -1)
        for (i, f) in enumerate(fs, 1):
            val = get_wordprob(e, f, init_val=init_val) *\
                get_wordalign(i, j, l_e, l_f, init_val=init_val)
            # select the first one among the maximum candidates
            if current_max[1] < val:
                current_max = (i, val)
        max_a[j] = current_max[0]
    return max_a


def db_show_matrix(es, fs,
                   transfrom=2,
                   transto=1,
                   db="sqlite:///:memory:",
                   init_val=0.00001):
    '''
    print matrix according to viterbi alignment like
        fs
      -------------
    e |           |
    s |           |
      |           |
      -------------
    >>> es = "私 は 先生 です".split()
    >>> fs = "I am a teacher".split()
    >>> print(db_show_matrix(es, fs, db="sqlite:///sample.db"))
    |x| | | |
    | | |x| |
    | | | |x|
    | | |x| |
    '''
    max_a = db_viterbi_alignment(es, fs,
                                 transfrom=transfrom,
                                 transto=transto,
                                 db=db,
                                 init_val=init_val).items()
    m = len(es)
    n = len(fs)
    return utility.matrix(m, n, max_a)


def _db_symmetrization(lang1s, lang2s,
                       init_val=1.0e-10,
                       db="sqlite:///:memory:"):
    '''
    Symmetrize the two directed Viterbi alignments stored in the database.
    '''
    transfrom = 2
    transto = 1
    trans = db_viterbi_alignment(lang1s, lang2s,
                                 transfrom=transfrom,
                                 transto=transto,
                                 db=db,
                                 init_val=init_val).items()
    rev_trans = db_viterbi_alignment(lang2s, lang1s,
                                     transfrom=transto,
                                     transto=transfrom,
                                     db=db,
                                     init_val=init_val).items()
    return word_alignment.alignment(lang1s, lang2s, trans, rev_trans)


def db_phrase_extract(lang1, lang2,
                      lang1method=lambda x: x,
                      lang2method=lambda x: x,
                      init_val=1.0e-10,
                      db="sqlite:///:memory:"):
    lang1s = lang1method(lang1).split()
    lang2s = lang2method(lang2).split()
    alignment = _db_symmetrization(lang1s, lang2s,
                                   init_val=init_val,
                                   db=db)
    return phrase_extract.phrase_extract(lang1s, lang2s, alignment)


def create_phrase_db(limit=None,
                     lang1method=lambda x: x,
                     lang2method=lambda x: x,
                     init_val=1.0e-10,
                     db="sqlite:///:memory:"):
    engine = create_engine(db)
    # create session
    Session = sessionmaker(bind=engine)
    session = Session()
    # tables
    Sentence = Tables().get_sentence_table()
    Phrase = Tables().get_phrase_table()

    # create table for word probability
    Phrase.__table__.drop(engine, checkfirst=True)
    Phrase.__table__.create(engine)
    print("created table: phrase")

    query = session.query(Sentence)[:limit] if limit \
        else session.query(Sentence)

    with ProgressLine(0.12, title='extracting phrases...'):
        for item in query:
            lang1 = item.lang1
            lang2 = item.lang2
            print("  ", lang1, lang2)
            phrases = db_phrase_extract(lang1, lang2,
                                        lang1method=lang1method,
                                        lang2method=lang2method,
                                        init_val=init_val,
                                        db=db)
            for lang1ps, lang2ps in phrases:
                lang1p = u" ".join(lang1ps)
                lang2p = u" ".join(lang2ps)
                ph = Phrase(lang1p=lang1p, lang2p=lang2p)
                session.add(ph)
        session.commit()


# create views using SQLite3
def create_phrase_count_view(db=":memory:"):
    # create phrase_count table
    table_name = "phrasecount"
    con = sqlite3.connect(db)
    cur = con.cursor()
    try:
        cur.execute("drop view {0}".format(table_name))
    except sqlite3.Error:
        print("{0} view does not exist.\n\
=> creating a new view".format(table_name))
    cur.execute("""create view {0}
                as select *, count(*) as count from
                phrase group by lang1p, lang2p order by count
                desc""".format(table_name))
    con.commit()

    # create phrase_count_ja table
    table_name_ja = "lang1_phrasecount"
    con = sqlite3.connect(db)
    cur = con.cursor()
    try:
        cur.execute("drop view {0}".format(table_name_ja))
    except sqlite3.Error:
        print("{0} view does not exist.\n\
=> creating a new view".format(table_name_ja))
    cur.execute("""create view {0}
                as select lang1p as langp,
                sum(count) as count from phrasecount group by
                lang1p order
                by count desc""".format(table_name_ja))
    con.commit()

    # create phrase_count_en table
    table_name_en = "lang2_phrasecount"
    con = sqlite3.connect(db)
    cur = con.cursor()
    try:
        cur.execute("drop view {0}".format(table_name_en))
    except sqlite3.Error:
        print("{0} view does not exist.\n\
=> creating a new view".format(table_name_en))
    cur.execute("""create view {0}
                as select lang2p as langp,
                sum(count) as count from phrasecount group by
                lang2p order
                by count desc""".format(table_name_en))
    con.commit()


# using sqlite
def create_phrase_prob(db=":memory:"):
    """
    Estimate phrase translation probabilities from the count views.
    """
    # create phrase_prob table
    table_name = "phraseprob"
    engine = create_engine("sqlite:///{0}".format(db))
    # create session
    Session = sessionmaker(bind=engine)
    session = Session()
    # tables
    TransPhraseProb = Tables().get_transphraseprob_table()

    # create table for word probability
    TransPhraseProb.__table__.drop(engine, checkfirst=True)
    TransPhraseProb.__table__.create(engine)
    session.commit()
    print("created table: {0}".format(table_name))

    con = sqlite3.connect(db)
    cur = con.cursor()
    cur_sel = con.cursor()
    #cur_rec = con.cursor()
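A sketch of running the whole pipeline defined in this module; it assumes a
sentence table has already been populated in sample.db (the file name is only
an example):

    from smt.db.createdb import createdb
    from smt.db.createngramdb import create_ngram_db

    createdb(db="sample.db", limit=100, loop_count=10)
    create_ngram_db(lang=2, n=3, db="sample.db")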
cur.execute("select lang1p, lang2p, count from phrasecount") 337 | with ProgressLine(0.12, title='phrase learning...'): 338 | for lang1p, lang2p, count in cur: 339 | # for p2_1 340 | cur_sel.execute(u"""select count 341 | from lang1_phrasecount where 342 | langp=?""", 343 | (lang1p,)) 344 | count2_1 = list(cur_sel) 345 | count2_1 = count2_1[0][0] 346 | p2_1 = count / count2_1 347 | # for p1_2 348 | cur_sel.execute(u"""select count 349 | from lang2_phrasecount where 350 | langp=?""", 351 | (lang2p,)) 352 | count1_2 = list(cur_sel) 353 | count1_2 = count1_2[0][0] 354 | p1_2 = count / count1_2 355 | # insert item 356 | transphraseprob = TransPhraseProb(lang1p=lang1p, 357 | lang2p=lang2p, 358 | p1_2=math.log(p1_2), 359 | p2_1=math.log(p2_1)) 360 | session.add(transphraseprob) 361 | print(u" added phraseprob: {0} <=> {1} ".format(lang1p, lang2p)) 362 | session.commit() 363 | 364 | 365 | def createdb(db=":memory:", 366 | lang1method=lambda x: x, 367 | lang2method=lambda x: x, 368 | init_val=1.0e-10, 369 | limit=None, 370 | loop_count=1000, 371 | ): 372 | alchemydb = "sqlite:///{0}".format(db) 373 | create_train_db(transfrom=2, 374 | transto=1, 375 | lang1method=lang1method, 376 | lang2method=lang2method, 377 | db=alchemydb, 378 | limit=limit, 379 | loop_count=loop_count) 380 | create_train_db(transfrom=1, 381 | transto=2, 382 | lang1method=lang1method, 383 | lang2method=lang2method, 384 | db=alchemydb, 385 | limit=limit, 386 | loop_count=loop_count) 387 | create_phrase_db(limit=limit, 388 | lang1method=lang1method, 389 | lang2method=lang2method, 390 | init_val=init_val, 391 | db=alchemydb) 392 | create_phrase_count_view(db=db) 393 | create_phrase_prob(db=db) 394 | 395 | if __name__ == "__main__": 396 | pass 397 | -------------------------------------------------------------------------------- /test/test_stackdecoder.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # coding:utf-8 3 | 4 | import unittest 5 | from fractions import Fraction as Frac 6 | from smt.decoder.stackdecoder import _future_cost_estimate 7 | from smt.decoder.stackdecoder import _create_estimate_dict 8 | from smt.decoder.stackdecoder import ArgumentNotSatisfied 9 | from smt.decoder.stackdecoder import future_cost_estimate 10 | from smt.decoder.stackdecoder import TransPhraseProb 11 | from smt.decoder.stackdecoder import Phrase 12 | # sqlalchemy 13 | from sqlalchemy import create_engine 14 | from sqlalchemy.orm import sessionmaker 15 | 16 | 17 | class DBSetup(object): 18 | 19 | def __init__(self, db="sqlite:///:memory:"): 20 | self.db = db 21 | self.tables = [TransPhraseProb, Phrase] 22 | 23 | def __enter__(self): 24 | self.engine = create_engine(self.db) 25 | # create tables 26 | for Table in self.tables: 27 | Table.__table__.create(self.engine) 28 | 29 | # create session 30 | Session = sessionmaker(bind=self.engine) 31 | self.session = Session() 32 | 33 | return self 34 | 35 | def __exit__(self, exc_type, exc_value, traceback): 36 | # drop table 37 | for Table in self.tables: 38 | Table.__table__.drop(self.engine, checkfirst=True) 39 | self.session.close() 40 | 41 | 42 | class FutureCostEstimateTest(unittest.TestCase): 43 | 44 | def setUp(self): 45 | self.maxDiff = None 46 | 47 | def test_future_cost_estimate_2to1(self): 48 | sentences = u"the tourism initiative addresses this\ 49 | for the first time".split() 50 | transfrom = 2 51 | transto = 1 52 | init_val = 100.0 53 | db = "sqlite:///test/:test:" 54 | 55 | # data set 56 | dataset = [("1", "the", Frac(-1), 0), 57 | ("1", "the", Frac(-2), 0), 58 | # 2 59 | ("1", "tourism", Frac(-2), 0), 60 | ("1", "tourism", Frac(-3), 0), 61 | # 3 62 | ("1", "initiative", Frac(-15, 10), 0), 63 | ("1", "initiative", Frac(-25, 10), 0), 64 | # 4 65 | ("1", "addresses", Frac(-24, 10), 0), 66 | ("1", "addresses", Frac(-34, 10), 0), 67 | # 5 68 | ("1", "this", Frac(-14, 10), 0), 69 | ("1", "this", Frac(-24, 10), 0), 70 | # 6 71 | ("1", "for", Frac(-1), 0), 72 | ("1", "for", Frac(-2), 0), 73 | # 7 74 | ("1", "the", Frac(-1), 0), 75 | ("1", "the", Frac(-2), 0), 76 | # 8 77 | ("1", "first", Frac(-19, 10), 0), 78 | ("1", "first", Frac(-29, 10), 0), 79 | # 9 80 | ("1", "time", Frac(-16, 10), 0), 81 | ("1", "time", Frac(-26, 10), 0), 82 | # 10 83 | ("1", "initiative addresses", Frac(-4), 0), 84 | ("1", "initiative addresses", Frac(-4), 0), 85 | # 11 86 | ("1", "this for", Frac(-25, 10), 0), 87 | ("1", "this for", Frac(-35, 10), 0), 88 | # 12 89 | ("1", "the first", Frac(-22, 10), 0), 90 | ("1", "the first", Frac(-32, 10), 0), 91 | # 13 92 | ("1", "for the", Frac(-13, 10), 0), 93 | ("1", "for the", Frac(-23, 10), 0), 94 | # 14 95 | ("1", "first time", Frac(-24, 10), 0), 96 | ("1", "first time", Frac(-34, 10), 0), 97 | # 15 98 | ("1", "this for the", Frac(-27, 10), 0), 99 | ("1", "this for the", Frac(-37, 10), 0), 100 | # 16 101 | ("1", "for the first", Frac(-23, 10), 0), 102 | ("1", "for the first", Frac(-33, 10), 0), 103 | # 17 104 | ("1", "the first time", Frac(-23, 10), 0), 105 | ("1", "the first time", Frac(-33, 10), 0), 106 | # 18 107 | ("1", "for the first time", Frac(-23, 10), 0), 108 | ("1", "for the first time", Frac(-33, 10), 0), 109 | ] 110 | val = {(1, 1): -1.0, 111 | (1, 2): -3.0, 112 | (1, 3): -4.5, 113 | (1, 4): -6.9, 114 | (1, 5): -8.3, 115 | (1, 6): -9.3, 116 | (1, 7): -9.6, 117 | (1, 8): -10.6, 118 | (1, 9): -10.6, 119 | (2, 2): -2.0, 120 | (2, 3): -3.5, 121 | (2, 4): -5.9, 122 | (2, 5): -7.3, 123 | (2, 6): -8.3, 124 
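                   # expected future-cost table: (i, j) maps to the best
                   # log-probability estimate for covering words i..j,
                   # built by combining the most probable sub-spans,
                   # e.g. (1, 2) = -3.0 is the best "the" entry (-1.0)
                   # plus the best "tourism" entry (-2.0)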
| (2, 7): -8.6, 125 | (2, 8): -9.6, 126 | (2, 9): -9.6, 127 | (3, 3): -1.5, 128 | (3, 4): -3.9, 129 | (3, 5): -5.3, 130 | (3, 6): -6.3, 131 | (3, 7): -6.6, 132 | (3, 8): -7.6, 133 | (3, 9): -7.6, 134 | (4, 4): -2.4, 135 | (4, 5): -3.8, 136 | (4, 6): -4.8, 137 | (4, 7): -5.1, 138 | (4, 8): -6.1, 139 | (4, 9): -6.1, 140 | (5, 5): -1.4, 141 | (5, 6): -2.4, 142 | (5, 7): -2.7, 143 | (5, 8): -3.6999999999999997, 144 | (5, 9): -3.6999999999999997, 145 | (6, 6): -1.0, 146 | (6, 7): -1.3, 147 | (6, 8): -2.3, 148 | (6, 9): -2.3, 149 | (7, 7): -1.0, 150 | (7, 8): -2.2, 151 | (7, 9): -2.3, 152 | (8, 8): -1.9, 153 | (8, 9): -2.4, 154 | (9, 9): -1.6, 155 | } 156 | 157 | with DBSetup(db) as dbobj: 158 | dbobj.session.add_all(TransPhraseProb(lang1p=item[0], 159 | lang2p=item[1], 160 | p2_1=item[2], 161 | p1_2=item[3]) 162 | for item in dataset) 163 | dbobj.session.add_all(Phrase(lang1p=item[0], 164 | lang2p=item[1]) 165 | for item in dataset) 166 | dbobj.session.commit() 167 | 168 | ans = future_cost_estimate(sentences, 169 | transfrom=transfrom, 170 | transto=transto, 171 | init_val=init_val, 172 | db=db) 173 | # assert 174 | self.assertEqual(ans, val) 175 | 176 | def test_future_cost_estimate_2to1_argument_not_satisfied(self): 177 | sentences = u"the tourism initiative addresses this\ 178 | for the first time".split() 179 | transfrom = 2 180 | transto = 1 181 | init_val = 100.0 182 | db = "sqlite:///test/:test:" 183 | 184 | # data set 185 | dataset = [("1", "the", Frac(-1), 0), 186 | ("1", "the", Frac(-2), 0), 187 | # 2 188 | ("1", "tourism", Frac(-2), 0), 189 | ("1", "tourism", Frac(-3), 0), 190 | # 3 191 | ("1", "initiative", Frac(-15, 10), 0), 192 | ("1", "initiative", Frac(-25, 10), 0), 193 | # 4 194 | ("1", "addresses", Frac(-24, 10), 0), 195 | ("1", "addresses", Frac(-34, 10), 0), 196 | # 5 197 | #("1", "this", Frac(-14, 10), 0), 198 | #("1", "this", Frac(-24, 10), 0), 199 | # 6 200 | ("1", "for", Frac(-1), 0), 201 | ("1", "for", Frac(-2), 0), 202 | # 7 203 | ("1", "the", Frac(-1), 0), 204 | ("1", "the", Frac(-2), 0), 205 | # 8 206 | ("1", "first", Frac(-19, 10), 0), 207 | ("1", "first", Frac(-29, 10), 0), 208 | # 9 209 | ("1", "time", Frac(-16, 10), 0), 210 | ("1", "time", Frac(-26, 10), 0), 211 | # 10 212 | ("1", "initiative addresses", Frac(-4), 0), 213 | ("1", "initiative addresses", Frac(-4), 0), 214 | # 11 215 | ("1", "this for", Frac(-25, 10), 0), 216 | ("1", "this for", Frac(-35, 10), 0), 217 | # 12 218 | ("1", "the first", Frac(-22, 10), 0), 219 | ("1", "the first", Frac(-32, 10), 0), 220 | # 13 221 | ("1", "for the", Frac(-13, 10), 0), 222 | ("1", "for the", Frac(-23, 10), 0), 223 | # 14 224 | ("1", "first time", Frac(-24, 10), 0), 225 | ("1", "first time", Frac(-34, 10), 0), 226 | # 15 227 | ("1", "this for the", Frac(-27, 10), 0), 228 | ("1", "this for the", Frac(-37, 10), 0), 229 | # 16 230 | ("1", "for the first", Frac(-23, 10), 0), 231 | ("1", "for the first", Frac(-33, 10), 0), 232 | # 17 233 | ("1", "the first time", Frac(-23, 10), 0), 234 | ("1", "the first time", Frac(-33, 10), 0), 235 | # 18 236 | ("1", "for the first time", Frac(-23, 10), 0), 237 | ("1", "for the first time", Frac(-33, 10), 0), 238 | ] 239 | 240 | val = {(1, 1): -1.0, 241 | (1, 2): -3.0, 242 | (1, 3): -4.5, 243 | (1, 4): -6.9, 244 | (1, 5): -106.9, 245 | (1, 6): -9.4, 246 | (1, 7): -9.6, 247 | (1, 8): -11.5, 248 | (1, 9): -11.7, 249 | (2, 2): -2.0, 250 | (2, 3): -3.5, 251 | (2, 4): -5.9, 252 | (2, 5): -105.9, 253 | (2, 6): -8.4, 254 | (2, 7): -8.6, 255 | (2, 8): -10.5, 256 | (2, 9): -10.7, 257 | (3, 3): -1.5, 
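                   # with the "this" entries commented out of the dataset,
                   # every span containing word 5 absorbs the
                   # init_val=100.0 penalty, e.g. (5, 5) = -100.0
                   # and (3, 5) = -103.9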
258 | (3, 4): -3.9, 259 | (3, 5): -103.9, 260 | (3, 6): -6.4, 261 | (3, 7): -6.6, 262 | (3, 8): -8.5, 263 | (3, 9): -8.7, 264 | (4, 4): -2.4, 265 | (4, 5): -102.4, 266 | (4, 6): -4.9, 267 | (4, 7): -5.1, 268 | (4, 8): -7.0, 269 | (4, 9): -7.199999999999999, 270 | (5, 5): -100.0, 271 | (5, 6): -2.5, 272 | (5, 7): -2.7, 273 | (5, 8): -4.6, 274 | (5, 9): -4.8, 275 | (6, 6): -1.0, 276 | (6, 7): -1.3, 277 | (6, 8): -2.3, 278 | (6, 9): -2.3, 279 | (7, 7): -1.0, 280 | (7, 8): -2.2, 281 | (7, 9): -2.3, 282 | (8, 8): -1.9, 283 | (8, 9): -2.4, 284 | (9, 9): -1.6, 285 | } 286 | 287 | with DBSetup(db) as dbobj: 288 | dbobj.session.add_all(TransPhraseProb(lang1p=item[0], 289 | lang2p=item[1], 290 | p2_1=item[2], 291 | p1_2=item[3]) 292 | for item in dataset) 293 | dbobj.session.add_all(Phrase(lang1p=item[0], 294 | lang2p=item[1]) 295 | for item in dataset) 296 | dbobj.session.commit() 297 | 298 | ans = future_cost_estimate(sentences, 299 | transfrom=transfrom, 300 | transto=transto, 301 | init_val=init_val, 302 | db=db) 303 | 304 | # assert 305 | self.assertEqual(ans, val) 306 | 307 | def test_future_cost_estimate_1to2(self): 308 | sentences = u"the tourism initiative addresses this\ 309 | for the first time".split() 310 | transfrom = 1 311 | transto = 2 312 | init_val = 100.0 313 | db = "sqlite:///test/:test:" 314 | 315 | # data set 316 | dataset = [("1", "the", Frac(-1), 0), 317 | ("1", "the", Frac(-2), 0), 318 | # 2 319 | ("1", "tourism", Frac(-2), 0), 320 | ("1", "tourism", Frac(-3), 0), 321 | # 3 322 | ("1", "initiative", Frac(-15, 10), 0), 323 | ("1", "initiative", Frac(-25, 10), 0), 324 | # 4 325 | ("1", "addresses", Frac(-24, 10), 0), 326 | ("1", "addresses", Frac(-34, 10), 0), 327 | # 5 328 | ("1", "this", Frac(-14, 10), 0), 329 | ("1", "this", Frac(-24, 10), 0), 330 | # 6 331 | ("1", "for", Frac(-1), 0), 332 | ("1", "for", Frac(-2), 0), 333 | # 7 334 | ("1", "the", Frac(-1), 0), 335 | ("1", "the", Frac(-2), 0), 336 | # 8 337 | ("1", "first", Frac(-19, 10), 0), 338 | ("1", "first", Frac(-29, 10), 0), 339 | # 9 340 | ("1", "time", Frac(-16, 10), 0), 341 | ("1", "time", Frac(-26, 10), 0), 342 | # 10 343 | ("1", "initiative addresses", Frac(-4), 0), 344 | ("1", "initiative addresses", Frac(-4), 0), 345 | # 11 346 | ("1", "this for", Frac(-25, 10), 0), 347 | ("1", "this for", Frac(-35, 10), 0), 348 | # 12 349 | ("1", "the first", Frac(-22, 10), 0), 350 | ("1", "the first", Frac(-32, 10), 0), 351 | # 13 352 | ("1", "for the", Frac(-13, 10), 0), 353 | ("1", "for the", Frac(-23, 10), 0), 354 | # 14 355 | ("1", "first time", Frac(-24, 10), 0), 356 | ("1", "first time", Frac(-34, 10), 0), 357 | # 15 358 | ("1", "this for the", Frac(-27, 10), 0), 359 | ("1", "this for the", Frac(-37, 10), 0), 360 | # 16 361 | ("1", "for the first", Frac(-23, 10), 0), 362 | ("1", "for the first", Frac(-33, 10), 0), 363 | # 17 364 | ("1", "the first time", Frac(-23, 10), 0), 365 | ("1", "the first time", Frac(-33, 10), 0), 366 | # 18 367 | ("1", "for the first time", Frac(-23, 10), 0), 368 | ("1", "for the first time", Frac(-33, 10), 0), 369 | ] 370 | 371 | val = {(1, 1): -1.0, 372 | (1, 2): -3.0, 373 | (1, 3): -4.5, 374 | (1, 4): -6.9, 375 | (1, 5): -8.3, 376 | (1, 6): -9.3, 377 | (1, 7): -9.6, 378 | (1, 8): -10.6, 379 | (1, 9): -10.6, 380 | (2, 2): -2.0, 381 | (2, 3): -3.5, 382 | (2, 4): -5.9, 383 | (2, 5): -7.3, 384 | (2, 6): -8.3, 385 | (2, 7): -8.6, 386 | (2, 8): -9.6, 387 | (2, 9): -9.6, 388 | (3, 3): -1.5, 389 | (3, 4): -3.9, 390 | (3, 5): -5.3, 391 | (3, 6): -6.3, 392 | (3, 7): -6.6, 393 | (3, 8): -7.6, 394 | (3, 
9): -7.6, 395 | (4, 4): -2.4, 396 | (4, 5): -3.8, 397 | (4, 6): -4.8, 398 | (4, 7): -5.1, 399 | (4, 8): -6.1, 400 | (4, 9): -6.1, 401 | (5, 5): -1.4, 402 | (5, 6): -2.4, 403 | (5, 7): -2.7, 404 | (5, 8): -3.6999999999999997, 405 | (5, 9): -3.6999999999999997, 406 | (6, 6): -1.0, 407 | (6, 7): -1.3, 408 | (6, 8): -2.3, 409 | (6, 9): -2.3, 410 | (7, 7): -1.0, 411 | (7, 8): -2.2, 412 | (7, 9): -2.3, 413 | (8, 8): -1.9, 414 | (8, 9): -2.4, 415 | (9, 9): -1.6, 416 | } 417 | 418 | with DBSetup(db) as dbobj: 419 | dbobj.session.add_all(TransPhraseProb(lang2p=item[0], 420 | lang1p=item[1], 421 | p1_2=item[2], 422 | p2_1=item[3]) 423 | for item in dataset) 424 | dbobj.session.add_all(Phrase(lang1p=item[0], 425 | lang2p=item[1]) 426 | for item in dataset) 427 | dbobj.session.commit() 428 | 429 | ans = future_cost_estimate(sentences, 430 | transfrom=transfrom, 431 | transto=transto, 432 | init_val=init_val, 433 | db=db) 434 | # assert 435 | self.assertEqual(ans, val) 436 | 437 | def test__future_cost_estimate(self): 438 | sentences = u"the tourism initiative addresses this\ 439 | for the first time".split() 440 | phrase_prob = {(1, 1): Frac(-1), 441 | (2, 2): Frac(-2), 442 | (3, 3): Frac(-15, 10), 443 | (4, 4): Frac(-24, 10), 444 | (5, 5): Frac(-14, 10), 445 | (6, 6): Frac(-1), 446 | (7, 7): Frac(-1), 447 | (8, 8): Frac(-19, 10), 448 | (9, 9): Frac(-16, 10), 449 | (3, 4): Frac(-4), 450 | (5, 6): Frac(-25, 10), 451 | (7, 8): Frac(-22, 10), 452 | (6, 7): Frac(-13, 10), 453 | (8, 9): Frac(-24, 10), 454 | (5, 7): Frac(-27, 10), 455 | (6, 8): Frac(-23, 10), 456 | (7, 9): Frac(-23, 10), 457 | (6, 9): Frac(-23, 10), 458 | } 459 | val = {(1, 1): Frac(-1), 460 | (1, 2): Frac(-3), 461 | (1, 3): Frac(-45, 10), 462 | (1, 4): Frac(-69, 10), 463 | (1, 5): Frac(-83, 10), 464 | (1, 6): Frac(-93, 10), 465 | (1, 7): Frac(-96, 10), 466 | (1, 8): Frac(-106, 10), 467 | (1, 9): Frac(-106, 10), 468 | (2, 2): Frac(-2), 469 | (2, 3): Frac(-35, 10), 470 | (2, 4): Frac(-59, 10), 471 | (2, 5): Frac(-73, 10), 472 | (2, 6): Frac(-83, 10), 473 | (2, 7): Frac(-86, 10), 474 | (2, 8): Frac(-96, 10), 475 | (2, 9): Frac(-96, 10), 476 | (3, 3): Frac(-15, 10), 477 | (3, 4): Frac(-39, 10), 478 | (3, 5): Frac(-53, 10), 479 | (3, 6): Frac(-63, 10), 480 | (3, 7): Frac(-66, 10), 481 | (3, 8): Frac(-76, 10), 482 | (3, 9): Frac(-76, 10), 483 | (4, 4): Frac(-24, 10), 484 | (4, 5): Frac(-38, 10), 485 | (4, 6): Frac(-48, 10), 486 | (4, 7): Frac(-51, 10), 487 | (4, 8): Frac(-61, 10), 488 | (4, 9): Frac(-61, 10), 489 | (5, 5): Frac(-14, 10), 490 | (5, 6): Frac(-24, 10), 491 | (5, 7): Frac(-27, 10), 492 | (5, 8): Frac(-37, 10), 493 | (5, 9): Frac(-37, 10), 494 | (6, 6): Frac(-1), 495 | (6, 7): Frac(-13, 10), 496 | (6, 8): Frac(-23, 10), 497 | (6, 9): Frac(-23, 10), 498 | (7, 7): Frac(-1), 499 | (7, 8): Frac(-22, 10), 500 | (7, 9): Frac(-23, 10), 501 | (8, 8): Frac(-19, 10), 502 | (8, 9): Frac(-24, 10), 503 | (9, 9): Frac(-16, 10)} 504 | ans = _future_cost_estimate(sentences, 505 | phrase_prob) 506 | self.assertEqual(ans, val) 507 | 508 | def test__future_cost_estimate_dict_not_satisfied(self): 509 | sentences = u"the tourism initiative addresses this\ 510 | for the first time".split() 511 | phrase_prob = {(1, 1): Frac(-1), 512 | (2, 2): Frac(-2), 513 | # lack one value 514 | #(3, 3): Frac(-15, 10), 515 | (4, 4): Frac(-24, 10), 516 | (5, 5): Frac(-14, 10), 517 | (6, 6): Frac(-1), 518 | (7, 7): Frac(-1), 519 | (8, 8): Frac(-19, 10), 520 | (9, 9): Frac(-16, 10), 521 | (3, 4): Frac(-4), 522 | (5, 6): Frac(-25, 10), 523 | (7, 8): Frac(-22, 10), 524 | (6, 
7): Frac(-13, 10), 525 | (8, 9): Frac(-24, 10), 526 | (5, 7): Frac(-27, 10), 527 | (6, 8): Frac(-23, 10), 528 | (7, 9): Frac(-23, 10), 529 | (6, 9): Frac(-23, 10), 530 | } 531 | self.assertRaises(ArgumentNotSatisfied, 532 | _future_cost_estimate, 533 | sentences, 534 | phrase_prob) 535 | 536 | def test_create_estimate_dict(self): 537 | sentences = u"the tourism initiative addresses this\ 538 | for the first time".split() 539 | init_val = Frac(-100) 540 | phrase_prob = {(1, 1): Frac(-1), 541 | (2, 2): Frac(-2), 542 | # lack one value 543 | #(3, 3): Frac(-15, 10), 544 | (4, 4): Frac(-24, 10), 545 | (5, 5): Frac(-14, 10), 546 | #(6, 6): Frac(-1), 547 | (7, 7): Frac(-1), 548 | # lack one value 549 | #(8, 8): Frac(-19, 10), 550 | (9, 9): Frac(-16, 10), 551 | (3, 4): Frac(-4), 552 | (5, 6): Frac(-25, 10), 553 | (7, 8): Frac(-22, 10), 554 | (6, 7): Frac(-13, 10), 555 | (8, 9): Frac(-24, 10), 556 | (5, 7): Frac(-27, 10), 557 | (6, 8): Frac(-23, 10), 558 | (7, 9): Frac(-23, 10), 559 | (6, 9): Frac(-23, 10), 560 | } 561 | correct = {(1, 1): Frac(-1), 562 | (2, 2): Frac(-2), 563 | # lack one value 564 | (3, 3): init_val, 565 | (4, 4): Frac(-24, 10), 566 | (5, 5): Frac(-14, 10), 567 | (6, 6): init_val, 568 | (7, 7): Frac(-1), 569 | # lack one value 570 | (8, 8): init_val, 571 | (9, 9): Frac(-16, 10), 572 | (3, 4): Frac(-4), 573 | (5, 6): Frac(-25, 10), 574 | (7, 8): Frac(-22, 10), 575 | (6, 7): Frac(-13, 10), 576 | (8, 9): Frac(-24, 10), 577 | (5, 7): Frac(-27, 10), 578 | (6, 8): Frac(-23, 10), 579 | (7, 9): Frac(-23, 10), 580 | (6, 9): Frac(-23, 10), 581 | } 582 | ans = _create_estimate_dict(sentences, 583 | phrase_prob, 584 | init_val=init_val) 585 | self.assertEqual(ans, correct) 586 | 587 | 588 | if __name__ == '__main__': 589 | unittest.main() 590 | -------------------------------------------------------------------------------- /smt/decoder/stackdecoder.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # coding:utf-8 3 | 4 | from __future__ import division, print_function 5 | import math 6 | # sqlalchemy 7 | import sqlalchemy 8 | from sqlalchemy.ext.declarative import declarative_base 9 | from sqlalchemy import create_engine 10 | from sqlalchemy import Column, TEXT, REAL, INTEGER 11 | from sqlalchemy.orm import sessionmaker 12 | from smt.db.tables import Tables 13 | #from pprint import pprint 14 | 15 | 16 | # prepare classes for sqlalchemy 17 | class Phrase(declarative_base()): 18 | __tablename__ = "phrase" 19 | id = Column(INTEGER, primary_key=True) 20 | lang1p = Column(TEXT) 21 | lang2p = Column(TEXT) 22 | 23 | 24 | class TransPhraseProb(declarative_base()): 25 | __tablename__ = "phraseprob" 26 | id = Column(INTEGER, primary_key=True) 27 | lang1p = Column(TEXT) 28 | lang2p = Column(TEXT) 29 | p1_2 = Column(REAL) 30 | p2_1 = Column(REAL) 31 | 32 | 33 | def phrase_prob(lang1p, lang2p, 34 | transfrom=2, 35 | transto=1, 36 | db="sqlite:///:memory:", 37 | init_val=1.0e-10): 38 | """ 39 | """ 40 | engine = create_engine(db) 41 | Session = sessionmaker(bind=engine) 42 | session = Session() 43 | # search 44 | query = session.query(TransPhraseProb).filter_by(lang1p=lang1p, 45 | lang2p=lang2p) 46 | if transfrom == 2 and transto == 1: 47 | try: 48 | # Be Careful! 
The order of conditional prob is reversed 49 | # as transfrom and transto because of bayes rule 50 | return query.one().p2_1 51 | except sqlalchemy.orm.exc.NoResultFound: 52 | return init_val 53 | elif transfrom == 1 and transto == 2: 54 | try: 55 | return query.one().p1_2 56 | except sqlalchemy.orm.exc.NoResultFound: 57 | return init_val 58 | 59 | 60 | def available_phrases(inputs, transfrom=2, transto=1, db="sqlite:///:memory:"): 61 | """ 62 | >>> decode.available_phrases(u"He is a teacher.".split(), 63 | db_name="sqlite:///:db:")) 64 | set([((1, u'He'),), 65 | ((1, u'He'), (2, u'is')), 66 | ((2, u'is'),), 67 | ((2, u'is'), (3, u'a')), 68 | ((3, u'a'),), 69 | ((4, u'teacher.'),)]) 70 | """ 71 | engine = create_engine(db) 72 | # create session 73 | Session = sessionmaker(bind=engine) 74 | session = Session() 75 | available = set() 76 | for i, f in enumerate(inputs): 77 | f_rest = () 78 | for fr in inputs[i:]: 79 | f_rest += (fr,) 80 | rest_phrase = u" ".join(f_rest) 81 | if transfrom == 2 and transto == 1: 82 | query = session.query(Phrase).filter_by(lang2p=rest_phrase) 83 | elif transfrom == 1 and transto == 2: 84 | query = session.query(Phrase).filter_by(lang1p=rest_phrase) 85 | lst = list(query) 86 | if lst: 87 | available.add(tuple(enumerate(f_rest, i+1))) 88 | return available 89 | 90 | 91 | class HypothesisBase(object): 92 | def __init__(self, 93 | db, 94 | totalnumber, 95 | sentences, 96 | ngram, 97 | ngram_words, 98 | inputps_with_index, 99 | outputps, 100 | transfrom, 101 | transto, 102 | covered, 103 | remained, 104 | start, 105 | end, 106 | prev_start, 107 | prev_end, 108 | remain_phrases, 109 | prob, 110 | prob_with_cost, 111 | prev_hypo, 112 | cost_dict 113 | ): 114 | 115 | self._db = db 116 | self._totalnumber = totalnumber 117 | self._sentences = sentences 118 | self._ngram = ngram 119 | self._ngram_words = ngram_words 120 | self._inputps_with_index = inputps_with_index 121 | self._outputps = outputps 122 | self._transfrom = transfrom 123 | self._transto = transto 124 | self._covered = covered 125 | self._remained = remained 126 | self._start = start 127 | self._end = end 128 | self._prev_start = prev_start 129 | self._prev_end = prev_end 130 | self._remain_phrases = remain_phrases 131 | self._prob = prob 132 | self._prob_with_cost = prob_with_cost 133 | self._prev_hypo = prev_hypo 134 | self._cost_dict = cost_dict 135 | 136 | self._output_sentences = outputps 137 | 138 | @property 139 | def db(self): 140 | return self._db 141 | 142 | @property 143 | def totalnumber(self): 144 | return self._totalnumber 145 | 146 | @property 147 | def sentences(self): 148 | return self._sentences 149 | 150 | @property 151 | def ngram(self): 152 | return self._ngram 153 | 154 | @property 155 | def ngram_words(self): 156 | return self._ngram_words 157 | 158 | @property 159 | def inputps_with_index(self): 160 | return self._inputps_with_index 161 | 162 | @property 163 | def outputps(self): 164 | return self._outputps 165 | 166 | @property 167 | def transfrom(self): 168 | return self._transfrom 169 | 170 | @property 171 | def transto(self): 172 | return self._transto 173 | 174 | @property 175 | def covered(self): 176 | return self._covered 177 | 178 | @property 179 | def remained(self): 180 | return self._remained 181 | 182 | @property 183 | def start(self): 184 | return self._start 185 | 186 | @property 187 | def end(self): 188 | return self._end 189 | 190 | @property 191 | def prev_start(self): 192 | return self._prev_start 193 | 194 | @property 195 | def prev_end(self): 196 | return 
self._prev_end 197 | 198 | @property 199 | def remain_phrases(self): 200 | return self._remain_phrases 201 | 202 | @property 203 | def prob(self): 204 | return self._prob 205 | 206 | @property 207 | def prob_with_cost(self): 208 | return self._prob_with_cost 209 | 210 | @property 211 | def prev_hypo(self): 212 | return self._prev_hypo 213 | 214 | @property 215 | def cost_dict(self): 216 | return self._cost_dict 217 | 218 | @property 219 | def output_sentences(self): 220 | return self._output_sentences 221 | 222 | def __unicode__(self): 223 | d = [("db", self._db), 224 | ("sentences", self._sentences), 225 | ("inputps_with_index", self._inputps_with_index), 226 | ("outputps", self._outputps), 227 | ("ngram", self._ngram), 228 | ("ngram_words", self._ngram_words), 229 | ("transfrom", self._transfrom), 230 | ("transto", self._transto), 231 | ("covered", self._covered), 232 | ("remained", self._remained), 233 | ("start", self._start), 234 | ("end", self._end), 235 | ("prev_start", self._prev_start), 236 | ("prev_end", self._prev_end), 237 | ("remain_phrases", self._remain_phrases), 238 | ("prob", self._prob), 239 | ("prob_with_cost", self._prob_with_cost), 240 | #("cost_dict", self._cost_dict), 241 | #("prev_hypo", ""), 242 | ] 243 | return u"Hypothesis Object\n" +\ 244 | u"\n".join([u" " + k + u": " + 245 | unicode(v) for (k, v) in d]) 246 | 247 | def __str__(self): 248 | return unicode(self).encode('utf-8') 249 | 250 | def __hash__(self): 251 | return hash(unicode(self)) 252 | 253 | 254 | class Hypothesis(HypothesisBase): 255 | """ 256 | Realize like the following class 257 | 258 | >>> args = {"sentences": sentences, 259 | ... "inputps_with_index": phrase, 260 | ... "outputps": outputps, 261 | ... "covered": hyp0.covered.union(set(phrase)), 262 | ... "remained": hyp0.remained.difference(set(phrase)), 263 | ... "start": phrase[0][0], 264 | ... "end": phrase[-1][0], 265 | ... "prev_start": hyp0.start, 266 | ... "prev_end": hyp0.end, 267 | ... "remain_phrases": remain_phrases(phrase, 268 | ... hyp0.remain_phrases), 269 | ... "prev_hypo": hyp0 270 | ... 
} 271 | 272 | >>> hyp1 = decode.HypothesisBase(**args) 273 | """ 274 | 275 | def __init__(self, 276 | prev_hypo, 277 | inputps_with_index, 278 | outputps, 279 | ): 280 | 281 | start = inputps_with_index[0][0] 282 | end = inputps_with_index[-1][0] 283 | prev_start = prev_hypo.start 284 | prev_end = prev_hypo.end 285 | args = {"db": prev_hypo.db, 286 | "totalnumber": prev_hypo.totalnumber, 287 | "prev_hypo": prev_hypo, 288 | "sentences": prev_hypo.sentences, 289 | "ngram": prev_hypo.ngram, 290 | # set later 291 | "ngram_words": prev_hypo.ngram_words, 292 | "inputps_with_index": inputps_with_index, 293 | "outputps": outputps, 294 | "transfrom": prev_hypo.transfrom, 295 | "transto": prev_hypo.transto, 296 | "covered": prev_hypo.covered.union(set(inputps_with_index)), 297 | "remained": prev_hypo.remained.difference( 298 | set(inputps_with_index)), 299 | "start": start, 300 | "end": end, 301 | "prev_start": prev_start, 302 | "prev_end": prev_end, 303 | "remain_phrases": self._calc_remain_phrases( 304 | inputps_with_index, 305 | prev_hypo.remain_phrases), 306 | "cost_dict": prev_hypo.cost_dict, 307 | # set later 308 | "prob": 0, 309 | "prob_with_cost": 0, 310 | } 311 | HypothesisBase.__init__(self, **args) 312 | # set ngram words 313 | self._ngram_words = self._set_ngram_words() 314 | # set the exact probability 315 | self._prob = self._cal_prob(start - prev_end) 316 | # set the exact probability with cost 317 | self._prob_with_cost = self._cal_prob_with_cost(start - prev_end) 318 | # set the output phrases 319 | self._output_sentences = prev_hypo.output_sentences + outputps 320 | 321 | def _set_ngram_words(self): 322 | lst = self._prev_hypo.ngram_words + list(self._outputps) 323 | o_len = len(self._outputps) 324 | return list(reversed(list(reversed(lst))[:o_len - 1 + self._ngram])) 325 | 326 | def _cal_phrase_prob(self): 327 | inputp = u" ".join(zip(*self._inputps_with_index)[1]) 328 | outputp = u" ".join(self._outputps) 329 | 330 | if self._transfrom == 2 and self._transto == 1: 331 | return phrase_prob(lang1p=outputp, 332 | lang2p=inputp, 333 | transfrom=self._transfrom, 334 | transto=self._transto, 335 | db=self._db, 336 | init_val=-100) 337 | elif self._transfrom == 1 and self._transto == 2: 338 | return phrase_prob(lang1p=inputp, 339 | lang2p=outputp, 340 | transfrom=self._transfrom, 341 | transto=self._transto, 342 | db=self._db, 343 | init_val=-100) 344 | else: 345 | raise Exception("specify transfrom and transto") 346 | 347 | def _cal_language_prob(self): 348 | nw = self.ngram_words 349 | triwords = zip(nw, nw[1:], nw[2:]) 350 | prob = 0 351 | for first, second, third in triwords: 352 | prob += language_model(first, second, third, self._totalnumber, 353 | transto=self._transto, 354 | db=self._db) 355 | return prob 356 | 357 | def _cal_prob(self, dist): 358 | val = self._prev_hypo.prob +\ 359 | self._reordering_model(0.1, dist) +\ 360 | self._cal_phrase_prob() +\ 361 | self._cal_language_prob() 362 | return val 363 | 364 | def _sub_cal_prob_with_cost(self, s_len, cvd): 365 | insert_flag = False 366 | lst = [] 367 | sub_lst = [] 368 | for i in range(1, s_len+1): 369 | if i not in cvd: 370 | insert_flag = True 371 | else: 372 | insert_flag = False 373 | if sub_lst: 374 | lst.append(sub_lst) 375 | sub_lst = [] 376 | if insert_flag: 377 | sub_lst.append(i) 378 | else: 379 | if sub_lst: 380 | lst.append(sub_lst) 381 | return lst 382 | 383 | def _cal_prob_with_cost(self, dist): 384 | s_len = len(self._sentences) 385 | cvd = set(i for i, val in self._covered) 386 | lst = 
self._sub_cal_prob_with_cost(s_len, cvd)
387 |         prob = self._cal_prob(dist)
388 |         prob_with_cost = prob
389 |         for item in lst:
390 |             start = item[0]
391 |             end = item[-1]
392 |             cost = self._cost_dict[(start, end)]
393 |             prob_with_cost += cost
394 |         return prob_with_cost
395 |
396 |     def _reordering_model(self, alpha, dist):
397 |         return math.log(math.pow(alpha, math.fabs(dist)))
398 |
399 |     def _calc_remain_phrases(self, phrase, phrases):
400 |         """
401 |         >>> res = remain_phrases(((2, u'is'),),
402 |                                  set([((1, u'he'),),
403 |                                       ((2, u'is'),),
404 |                                       ((3, u'a'),),
405 |                                       ((2, u'is'),
406 |                                        (3, u'a')),
407 |                                       ((4, u'teacher'),)]))
408 |         set([((1, u'he'),), ((3, u'a'),), ((4, u'teacher'),)])
409 |         >>> res = remain_phrases(((2, u'is'), (3, u'a')),
410 |                                  set([((1, u'he'),),
411 |                                       ((2, u'is'),),
412 |                                       ((3, u'a'),),
413 |                                       ((2, u'is'),
414 |                                        (3, u'a')),
415 |                                       ((4, u'teacher'),)]))
416 |         set([((1, u'he'),), ((4, u'teacher'),)])
417 |         """
418 |         s = set()
419 |         for ph in phrases:
420 |             for p in phrase:
421 |                 if p in ph:
422 |                     break
423 |             else:
424 |                 s.add(ph)
425 |         return s
426 |
427 |
428 | def create_empty_hypothesis(sentences, cost_dict,
429 |                             ngram=3, transfrom=2, transto=1,
430 |                             db="sqlite:///:memory:"):
431 |     phrases = available_phrases(sentences,
432 |                                 db=db)
433 |     hyp0 = HypothesisBase(sentences=sentences,
434 |                           db=db,
435 |                           totalnumber=_get_total_number(transto=transto,
436 |                                                         db=db),
437 |                           inputps_with_index=(),
438 |                           outputps=[],
439 |                           ngram=ngram,
440 |                           ngram_words=["", ""]*ngram,
441 |                           transfrom=transfrom,
442 |                           transto=transto,
443 |                           covered=set(),
444 |                           start=0,
445 |                           end=0,
446 |                           prev_start=0,
447 |                           prev_end=0,
448 |                           remained=set(enumerate(sentences, 1)),
449 |                           remain_phrases=phrases,
450 |                           prev_hypo=None,
451 |                           prob=0,
452 |                           cost_dict=cost_dict,
453 |                           prob_with_cost=0)
454 |     #print(_get_total_number(transto=transto, db=db))
455 |     return hyp0
456 |
457 |
458 | class Stack(set):
459 |     def __init__(self, size=10,
460 |                  histogram_pruning=True,
461 |                  threshold_pruning=False):
462 |         set.__init__(self)
463 |         self._min_hyp = None
464 |         self._max_hyp = None
465 |         self._size = size
466 |         self._histogram_pruning = histogram_pruning
467 |         self._threshold_pruning = threshold_pruning
468 |
469 |     def add_hyp(self, hyp):
470 |         #prob = hyp.prob
471 |         # for the first time
472 |         if self == set([]):
473 |             self._min_hyp = hyp
474 |             self._max_hyp = hyp
475 |         else:
476 |             raise Exception("Don't use add_hyp for nonempty stack")
477 |         #else:
478 |         #    if self._min_hyp.prob > prob:
479 |         #        self._min_hyp = hyp
480 |         #    if self._max_hyp.prob < prob:
481 |         #        self._max_hyp = hyp
482 |         self.add(hyp)
483 |
484 |     def _get_min_hyp(self):
485 |         # pick any element as the initial minimum candidate
486 |         lst = list(self)
487 |         mn = lst[0]
488 |         for item in self:
489 |             if item.prob_with_cost < mn.prob_with_cost:
490 |                 mn = item
491 |         return mn
492 |
493 |     def add_with_combine_prune(self, hyp):
494 |         prob_with_cost = hyp.prob_with_cost
495 |         if self == set([]):
496 |             self._min_hyp = hyp
497 |             self._max_hyp = hyp
498 |         else:
499 |             if self._min_hyp.prob_with_cost > prob_with_cost:
500 |                 self._min_hyp = hyp
501 |             if self._max_hyp.prob_with_cost < prob_with_cost:
502 |                 self._max_hyp = hyp
503 |         self.add(hyp)
504 |         # combine: recombine with a hypothesis that shares the same
505 |         # language-model state and end position, keeping the better one
506 |         for _hyp in self:
507 |             if _hyp is not hyp and \
508 |                     hyp.ngram_words[:-1] == _hyp.ngram_words[:-1] and \
509 |                     hyp.end == _hyp.end:
510 |                 if hyp.prob_with_cost > _hyp.prob_with_cost:
511 |                     self.remove(_hyp)
512 |                 break
513 |         # histogram pruning
514 |         if self._histogram_pruning:
515 |             if len(self) > self._size:
516 |                 self.remove(self._min_hyp)
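                # after evicting the cheapest hypothesis, refresh the
                # cached minimum so later prunes compare against the
                # current contents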
                self._min_hyp = self._get_min_hyp()
517 |         # threshold pruning
518 |         if self._threshold_pruning:
519 |             alpha = 1.0e-5
520 |             if hyp.prob_with_cost < self._max_hyp.prob_with_cost + math.log(alpha):
521 |                 self.remove(hyp)
522 |
523 |
524 | def _get_total_number(transto=1, db="sqlite:///:memory:"):
525 |     """
526 |     return the total number of trigram entries (v)
527 |     """
528 |
529 |     Trigram = Tables().get_trigram_table('lang{}trigram'.format(transto))
530 |
531 |     # create connection in SQLAlchemy
532 |     engine = create_engine(db)
533 |     # create session
534 |     Session = sessionmaker(bind=engine)
535 |     session = Session()
536 |
537 |     # calculate total number
538 |     query = session.query(Trigram)
539 |
540 |     return len(list(query))
541 |
542 |
543 | def language_model(first, second, third, totalnumber, transto=1,
544 |                    db="sqlite:///:memory:"):
545 |
546 |     class TrigramProb(declarative_base()):
547 |         __tablename__ = 'lang{}trigramprob'.format(transto)
548 |         id = Column(INTEGER, primary_key=True)
549 |         first = Column(TEXT)
550 |         second = Column(TEXT)
551 |         third = Column(TEXT)
552 |         prob = Column(REAL)
553 |
554 |     class TrigramProbWithoutLast(declarative_base()):
555 |         __tablename__ = 'lang{}trigramprob'.format(transto)
556 |         id = Column(INTEGER, primary_key=True)
557 |         first = Column(TEXT)
558 |         second = Column(TEXT)
559 |         prob = Column(REAL)
560 |
561 |     # create session
562 |     engine = create_engine(db)
563 |     Session = sessionmaker(bind=engine)
564 |     session = Session()
565 |     try:
566 |         # the next line raises an error if the prob is not found
567 |         query = session.query(TrigramProb).filter_by(first=first,
568 |                                                       second=second,
569 |                                                       third=third)
570 |         item = query.one()
571 |         return item.prob
572 |     except sqlalchemy.orm.exc.NoResultFound:
573 |         query = session.query(TrigramProbWithoutLast
574 |                               ).filter_by(first=first,
575 |                                           second=second)
576 |         # I have to modify the database
577 |         item = query.first()
578 |         if item:
579 |             return item.prob
580 |         else:
581 |             return - math.log(totalnumber)
582 |
583 |
584 | class ArgumentNotSatisfied(Exception):
585 |     pass
586 |
587 |
588 | def _future_cost_estimate(sentences,
589 |                           phrase_prob):
590 |     '''
591 |     warning:
592 |         pass the complete one_word_prob
593 |     '''
594 |     s_len = len(sentences)
595 |     cost = {}
596 |
597 |     one_word_prob = {(st, ed): prob for (st, ed), prob in phrase_prob.items()
598 |                      if st == ed}
599 |
600 |     if set(one_word_prob.keys()) != set((x, x) for x in range(1, s_len+1)):
601 |         raise ArgumentNotSatisfied("phrase_prob doesn't satisfy the condition")
602 |
603 |     # add one word prob
604 |     for tpl, prob in one_word_prob.items():
605 |         index = tpl[0]
606 |         cost[(index, index)] = prob
607 |
608 |     for length in range(1, s_len+1):
609 |         for start in range(1, s_len-length+1):
610 |             end = start + length
611 |             try:
612 |                 cost[(start, end)] = phrase_prob[(start, end)]
613 |             except KeyError:
614 |                 cost[(start, end)] = -float('inf')
615 |             for i in range(start, end):
616 |                 _val = cost[(start, i)] + cost[(i+1, end)]
617 |                 if _val > cost[(start, end)]:
618 |                     cost[(start, end)] = _val
619 |     return cost
620 |
621 |
622 | def _create_estimate_dict(sentences,
623 |                           phrase_prob,
624 |                           init_val=-100):
625 |     one_word_prob_dict_nums = set(x for x, y in phrase_prob.keys() if x == y)
626 |     comp_dic = {}
627 |     # complete the one_word_prob
628 |     s_len = len(sentences)
629 |     for i in range(1, s_len+1):
630 |         if i not in one_word_prob_dict_nums:
631 |             comp_dic[(i, i)] = init_val
632 |     for key, val in phrase_prob.items():
633 |         comp_dic[key] = val
634 |     return comp_dic
635 |
636 |
637 | def _get_total_number_for_fce(transto=1,
db="sqlite:///:memory:"): 638 | """ 639 | return v 640 | """ 641 | # create connection in SQLAlchemy 642 | engine = create_engine(db) 643 | # create session 644 | Session = sessionmaker(bind=engine) 645 | session = Session() 646 | 647 | tablename = 'lang{}unigram'.format(transto) 648 | Unigram = Tables().get_unigram_table(tablename) 649 | 650 | # calculate total number 651 | query = session.query(Unigram) 652 | sm = 0 653 | totalnumber = 0 654 | for item in query: 655 | totalnumber += 1 656 | sm += item.count 657 | return {'totalnumber': totalnumber, 658 | 'sm': sm} 659 | 660 | 661 | def _future_cost_langmodel(word, 662 | tn, 663 | transfrom=2, 664 | transto=1, 665 | alpha=0.00017, 666 | db="sqlite:///:memory:"): 667 | tablename = "lang{}unigramprob".format(transto) 668 | # create session 669 | engine = create_engine(db) 670 | Session = sessionmaker(bind=engine) 671 | session = Session() 672 | 673 | UnigramProb = Tables().get_unigramprob_table(tablename) 674 | query = session.query(UnigramProb).filter_by(first=word) 675 | try: 676 | item = query.one() 677 | return item.prob 678 | except sqlalchemy.orm.exc.NoResultFound: 679 | sm = tn['sm'] 680 | totalnumber = tn['totalnumber'] 681 | return math.log(alpha) - math.log(sm + alpha*totalnumber) 682 | 683 | 684 | def future_cost_estimate(sentences, 685 | transfrom=2, 686 | transto=1, 687 | init_val=-100.0, 688 | db="sqlite:///:memory:"): 689 | # create phrase_prob table 690 | engine = create_engine(db) 691 | # create session 692 | Session = sessionmaker(bind=engine) 693 | session = Session() 694 | phrases = available_phrases(sentences, 695 | db=db) 696 | 697 | tn = _get_total_number_for_fce(transto=transto, db=db) 698 | covered = {} 699 | for phrase in phrases: 700 | phrase_str = u" ".join(zip(*phrase)[1]) 701 | if transfrom == 2 and transto == 1: 702 | query = session.query(TransPhraseProb).filter_by( 703 | lang2p=phrase_str).order_by( 704 | sqlalchemy.desc(TransPhraseProb.p2_1)) 705 | elif transfrom == 1 and transto == 2: 706 | query = session.query(TransPhraseProb).filter_by( 707 | lang1p=phrase_str).order_by( 708 | sqlalchemy.desc(TransPhraseProb.p1_2)) 709 | lst = list(query) 710 | if lst: 711 | # extract the maximum val 712 | val = query.first() 713 | start = zip(*phrase)[0][0] 714 | end = zip(*phrase)[0][-1] 715 | pos = (start, end) 716 | if transfrom == 2 and transto == 1: 717 | fcl = _future_cost_langmodel(word=val.lang1p.split()[0], 718 | tn=tn, 719 | transfrom=transfrom, 720 | transto=transto, 721 | alpha=0.00017, 722 | db=db) 723 | print(val.lang1p.split()[0], fcl) 724 | covered[pos] = val.p2_1 + fcl 725 | if transfrom == 1 and transto == 2: 726 | covered[pos] = val.p1_2 727 | # + language_model() 728 | # estimate future costs 729 | phrase_prob = _create_estimate_dict(sentences, covered) 730 | print(phrase_prob) 731 | 732 | return _future_cost_estimate(sentences, 733 | phrase_prob) 734 | 735 | 736 | def stack_decoder(sentence, transfrom=2, transto=1, 737 | stacksize=10, 738 | searchsize=10, 739 | lang1method=lambda x: x, 740 | lang2method=lambda x: x, 741 | db="sqlite:///:memory:", 742 | verbose=False): 743 | # create phrase_prob table 744 | engine = create_engine(db) 745 | # create session 746 | Session = sessionmaker(bind=engine) 747 | session = Session() 748 | 749 | if transfrom == 2 and transto == 1: 750 | sentences = lang2method(sentence).split() 751 | else: 752 | sentences = lang1method(sentence).split() 753 | # create stacks 754 | len_sentences = len(sentences) 755 | stacks = [Stack(size=stacksize, 756 | 
histogram_pruning=True,
757 |                     threshold_pruning=False,
758 |                     ) for i in range(len_sentences+1)]
759 |
760 |     cost_dict = future_cost_estimate(sentences,
761 |                                      transfrom=transfrom,
762 |                                      transto=transto,
763 |                                      db=db)
764 |     # create the initial hypothesis
765 |     hyp0 = create_empty_hypothesis(sentences=sentences,
766 |                                    cost_dict=cost_dict,
767 |                                    ngram=3,
768 |                                    transfrom=transfrom,
769 |                                    transto=transto,
770 |                                    db=db)
771 |     stacks[0].add_hyp(hyp0)
772 |
773 |     # main loop
774 |     for i, stack in enumerate(stacks):
775 |         for hyp in stack:
776 |             for phrase in hyp.remain_phrases:
777 |                 phrase_str = u" ".join(zip(*phrase)[1])
778 |                 if transfrom == 2 and transto == 1:
779 |                     query = session.query(TransPhraseProb).filter_by(
780 |                         lang2p=phrase_str).order_by(
781 |                         sqlalchemy.desc(TransPhraseProb.p2_1))[:searchsize]
782 |                 elif transfrom == 1 and transto == 2:
783 |                     query = session.query(TransPhraseProb).filter_by(
784 |                         lang1p=phrase_str).order_by(
785 |                         sqlalchemy.desc(TransPhraseProb.p1_2))[:searchsize]
786 |                 query = list(query)
787 |                 for item in query:
788 |                     if transfrom == 2 and transto == 1:
789 |                         outputp = item.lang1p
790 |                     elif transfrom == 1 and transto == 2:
791 |                         outputp = item.lang2p
792 |                     #print(u"calculating\n    {0} = {1}\n in stack {2}".format(
793 |                     #    phrase, outputp, i))
794 |                     if transfrom == 2 and transto == 1:
795 |                         outputps = lang1method(outputp).split()
796 |                     elif transfrom == 1 and transto == 2:
797 |                         outputps = lang2method(outputp).split()
798 |                     # place in stack
799 |                     # and recombine with existing hypothesis if possible
800 |                     new_hyp = Hypothesis(prev_hypo=hyp,
801 |                                          inputps_with_index=phrase,
802 |                                          outputps=outputps)
803 |                     if verbose:
804 |                         print(phrase, u' '.join(outputps))
805 |                         print("loop: ", i, "len:", len(new_hyp.covered))
806 |                     stacks[len(new_hyp.covered)].add_with_combine_prune(
807 |                         new_hyp)
808 |     return stacks
809 |
810 |
811 | if __name__ == '__main__':
812 |     #import doctest
813 |     #doctest.testmod()
814 |     pass
815 |
--------------------------------------------------------------------------------
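# A minimal end-to-end usage sketch, assuming a parallel corpus has already
# been loaded into the sentence table and the n-gram language-model tables
# have been built (see smt/db/createngramdb.py). The database filename
# "smt.db" and the input sentence are placeholders, not repository fixtures.

from smt.db.createdb import createdb
from smt.decoder.stackdecoder import stack_decoder

# train IBM models in both directions, extract phrases,
# and estimate phrase translation probabilities
createdb(db="smt.db", limit=1000, loop_count=100)

# decode a lang2 sentence into lang1 with the stack decoder
stacks = stack_decoder(u"das Haus",
                       transfrom=2,
                       transto=1,
                       stacksize=10,
                       searchsize=10,
                       db="sqlite:///smt.db")

# the last stack holds hypotheses covering the whole input;
# take the most probable one
best = max(stacks[-1], key=lambda hyp: hyp.prob)
print(u" ".join(best.output_sentences))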