├── smt
│   ├── __init__.py
│   ├── db
│   │   ├── __init__.py
│   │   ├── tables.py
│   │   ├── createngramdb.py
│   │   └── createdb.py
│   ├── decoder
│   │   ├── __init__.py
│   │   └── stackdecoder.py
│   ├── phrase
│   │   ├── __init__.py
│   │   ├── word_alignment.py
│   │   └── phrase_extract.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── utility.py
│   ├── ibmmodel
│   │   ├── __init__.py
│   │   ├── test.txt
│   │   ├── ibmmodel1.py
│   │   └── ibmmodel2.py
│   └── langmodel
│       ├── __init__.py
│       └── ngram.py
├── test
│   ├── __init__.py
│   ├── test_ngram.py
│   ├── test_ibmmodel.py
│   ├── test_phrase.py
│   └── test_stackdecoder.py
├── .gitignore
├── setup.py
├── README.rst
└── COPYING
/smt/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smt/db/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smt/decoder/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smt/phrase/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smt/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smt/ibmmodel/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smt/langmodel/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/smt/ibmmodel/test.txt:
--------------------------------------------------------------------------------
1 | the house|||das Haus
2 | the book|||das Buch
3 | a book|||ein Buch
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | smt.egg-info/
2 | *.pyc
3 | terminal.py
4 | twitter
5 | jec_basic_sentence
6 | test/:test:
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from setuptools import setup, find_packages
5 |
6 | setup(
7 | name="smt",
8 | version="0.1",
9 | description="Statistical Machine Translation implementation by Python",
10 | author="Noriyuki Abe",
11 | author_email="kenko.py@gmail.com",
12 | url="http://kenkov.jp",
13 | packages=find_packages(),
14 | test_suite="test",
15 | )
16 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | ==============================
2 | IBM Model
3 | ==============================
4 |
5 | IBM models of statistical machine translation
6 |
7 | Files
8 | =======
9 |
10 | ibmmodel1.py
11 | implements IBM Model1
12 |
13 | ibmmodel2.py
14 | implements IBM Model2
15 |
16 | word_alignment.py
17 | implements symmetrization of word alignments
18 |
19 |
20 | Usage
21 | ======
22 | 
23 | See each file and the test code in the ``test`` directory.
24 |
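25 | For example, IBM Model 1 lexical translation probabilities can be
26 | trained directly from sentence pairs (a minimal sketch based on the
27 | test code)::
28 | 
29 |     >>> from smt.ibmmodel.ibmmodel1 import train
30 |     >>> sent_pairs = [("the house", "das Haus"),
31 |     ...               ("the book", "das Buch"),
32 |     ...               ("a book", "ein Buch")]
33 |     >>> t = train(sent_pairs, loop_count=2)
34 | 
35 | The whole test suite can be run with ``python setup.py test``.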
--------------------------------------------------------------------------------
/smt/langmodel/ngram.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from __future__ import division, print_function
5 | import itertools
6 |
7 |
8 | class NgramException(Exception):
9 | pass
10 |
11 |
12 | def ngram(sentences, n):
13 | s_len = len(sentences)
14 | if s_len < n:
15 | raise NgramException("the sentences length is not enough:\
16 | len(sentences)={} < n={}".format(s_len, n))
17 | xs = itertools.tee(sentences, n)
18 | for i, t in enumerate(xs[1:]):
19 |             for _ in range(i+1):
20 | next(t)
21 |     return zip(*xs)
22 |
23 |
24 | if __name__ == '__main__':
25 | pass
26 |
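27 |     # minimal demo (an illustrative sketch, not used by the library):
28 |     # print the padded trigrams of a short sentence, as the tests do
29 |     for tpl in ngram(["", ""] + "I am a teacher".split() + [""], 3):
30 |         print(tpl)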
--------------------------------------------------------------------------------
/COPYING:
--------------------------------------------------------------------------------
1 | Copyright (c) 2013 Noriyuki ABE
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | this software and associated documentation files (the "Software"),
5 | to deal in the Software without restriction, including without limitation the rights to use,
6 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
7 | and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8 |
9 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
12 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
13 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
14 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
15 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16 |
--------------------------------------------------------------------------------
/test/test_ngram.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from __future__ import division, print_function
5 | import unittest
6 | from smt.langmodel.ngram import ngram
7 | from smt.langmodel.ngram import NgramException
8 |
9 |
10 | class NgramTest(unittest.TestCase):
11 | def test_ngram_3(self):
12 | sentence = ["I am teacher",
13 | "I am",
14 | "I",
15 | ""]
16 | test_sentences = (["", ""] + item.split() + [""]
17 | for item in sentence)
18 | anss = [[("", "", "I"),
19 | ("", "I", "am"),
20 | ("I", "am", "teacher"),
21 | ("am", "teacher", "")],
22 | [("", "", "I"),
23 | ("", "I", "am"),
24 | ("I", "am", "")],
25 | [("", "", "I"),
26 | ("", "I", "")],
27 | [("", "", "")],
28 | ]
29 |
30 | for sentences, ans in zip(test_sentences, anss):
31 | a = ngram(sentences, 3)
32 | self.assertEqual(list(a), ans)
33 |
34 | def test_ngram_illegal_input(self):
35 | sentences = ["I", "am"]
36 | self.assertRaises(NgramException, ngram, sentences, 3)
37 |
38 |
39 | if __name__ == '__main__':
40 | unittest.main()
41 |
--------------------------------------------------------------------------------
/smt/utils/utility.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from __future__ import division, print_function
5 |
6 |
7 | def mkcorpus(sentences):
8 | """
9 | >>> sent_pairs = [("僕 は 男 です", "I am a man"),
10 | ("私 は 女 です", "I am a girl"),
11 | ("私 は 先生 です", "I am a teacher"),
12 | ("彼女 は 先生 です", "She is a teacher"),
13 | ("彼 は 先生 です", "He is a teacher"),
14 | ]
15 | >>> pprint(mkcorpus(sent_pairs))
16 |     [(['僕',
17 |        'は',
18 |        '男',
19 |        'です'],
20 |       ['I', 'am', 'a', 'man']),
21 |      (['私',
22 |        'は',
23 |        '女',
24 |        'です'],
25 |       ['I', 'am', 'a', 'girl']),
26 |      (['私',
27 |        'は',
28 |        '先生',
29 |        'です'],
30 |       ['I', 'am', 'a', 'teacher']),
31 |      (['彼女',
32 |        'は',
33 |        '先生',
34 |        'です'],
35 |       ['She', 'is', 'a', 'teacher']),
36 |      (['彼',
37 |        'は',
38 |        '先生',
39 |        'です'],
40 |       ['He', 'is', 'a', 'teacher'])]
41 | """
42 | return [(es.split(), fs.split()) for (es, fs) in sentences]
43 |
44 |
45 | def matrix(
46 | m, n, lst,
47 | m_text: list=None,
48 | n_text: list=None):
49 | """
50 | m: row
51 | n: column
52 | lst: items
53 |
54 |     >>> print(matrix(2, 3, [(1, 1), (2, 3)]))
55 | |x| | |
56 | | | |x|
57 | """
58 |
59 | fmt = ""
60 | if n_text:
61 | fmt += " {}\n".format(" ".join(n_text))
62 | for i in range(1, m+1):
63 | if m_text:
64 | fmt += "{:<4.4} ".format(m_text[i-1])
65 | fmt += "|"
66 | for j in range(1, n+1):
67 | if (i, j) in lst:
68 | fmt += "x|"
69 | else:
70 | fmt += " |"
71 | fmt += "\n"
72 | return fmt
73 |
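74 | 
75 | if __name__ == '__main__':
76 |     # minimal demo (an illustrative sketch): render the alignment
77 |     # matrix from the docstring example above
78 |     print(matrix(2, 3, [(1, 1), (2, 3)]))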
--------------------------------------------------------------------------------
/smt/db/tables.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from __future__ import division, print_function
5 | # import SQLAlchemy
6 | from sqlalchemy.ext.declarative import declarative_base
7 | from sqlalchemy import Column, TEXT, REAL, INTEGER
8 |
9 |
10 | class Tables(object):
11 |     '''factory of ORM table classes; each getter binds a fresh declarative base'''
12 | def get_sentence_table(self, tablename="sentence"):
13 |
14 | class Sentence(declarative_base()):
15 | __tablename__ = tablename
16 | id = Column(INTEGER, primary_key=True)
17 | lang1 = Column(TEXT)
18 | lang2 = Column(TEXT)
19 |
20 | return Sentence
21 |
22 | def get_wordprobability_table(self, tablename):
23 |
24 | class WordProbability(declarative_base()):
25 | __tablename__ = tablename
26 | id = Column(INTEGER, primary_key=True)
27 | transto = Column(TEXT)
28 | transfrom = Column(TEXT)
29 | prob = Column(REAL)
30 |
31 | return WordProbability
32 |
33 | def get_wordalignment_table(self, tablename):
34 |
35 | class WordAlignment(declarative_base()):
36 | __tablename__ = tablename
37 | id = Column(INTEGER, primary_key=True)
38 | from_pos = Column(INTEGER)
39 | to_pos = Column(INTEGER)
40 | to_len = Column(INTEGER)
41 | from_len = Column(INTEGER)
42 | prob = Column(REAL)
43 |
44 | return WordAlignment
45 |
46 | def get_phrase_table(self, tablename="phrase"):
47 |
48 | class Phrase(declarative_base()):
49 | __tablename__ = tablename
50 | id = Column(INTEGER, primary_key=True)
51 | lang1p = Column(TEXT)
52 | lang2p = Column(TEXT)
53 |
54 | return Phrase
55 |
56 | def get_transphraseprob_table(self, tablename="phraseprob"):
57 |
58 | class TransPhraseProb(declarative_base()):
59 | __tablename__ = tablename
60 | id = Column(INTEGER, primary_key=True)
61 | lang1p = Column(TEXT)
62 | lang2p = Column(TEXT)
63 | p1_2 = Column(REAL)
64 | p2_1 = Column(REAL)
65 |
66 | return TransPhraseProb
67 |
68 | def get_trigram_table(self, tablename):
69 |
70 | class Trigram(declarative_base()):
71 | __tablename__ = tablename
72 | id = Column(INTEGER, primary_key=True)
73 | first = Column(TEXT)
74 | second = Column(TEXT)
75 | third = Column(TEXT)
76 | count = Column(INTEGER)
77 |
78 | return Trigram
79 |
80 | def get_trigramprob_table(self, tablename):
81 |
82 | class TrigramProb(declarative_base()):
83 | __tablename__ = tablename
84 | id = Column(INTEGER, primary_key=True)
85 | first = Column(TEXT)
86 | second = Column(TEXT)
87 | third = Column(TEXT)
88 | prob = Column(REAL)
89 |
90 | return TrigramProb
91 |
92 | def get_trigramprobwithoutlast_table(self, tablename):
93 |
94 | class TrigramProbWithoutLast(declarative_base()):
95 | __tablename__ = tablename
96 | id = Column(INTEGER, primary_key=True)
97 | first = Column(TEXT)
98 | second = Column(TEXT)
99 | prob = Column(REAL)
100 |
101 | return TrigramProbWithoutLast
102 |
103 | def get_unigram_table(self, tablename):
104 |
105 | class Unigram(declarative_base()):
106 | __tablename__ = tablename
107 | id = Column(INTEGER, primary_key=True)
108 | first = Column(TEXT)
109 | count = Column(INTEGER)
110 |
111 | return Unigram
112 |
113 | def get_unigramprob_table(self, tablename):
114 |
115 | class UnigramProb(declarative_base()):
116 | __tablename__ = tablename
117 | id = Column(INTEGER, primary_key=True)
118 | first = Column(TEXT)
119 | prob = Column(REAL)
120 |
121 | return UnigramProb
122 |
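123 | 
124 | if __name__ == '__main__':
125 |     # minimal sketch (assumes SQLAlchemy is installed): create the
126 |     # sentence table against a throwaway in-memory SQLite engine
127 |     from sqlalchemy import create_engine
128 |     engine = create_engine("sqlite:///:memory:")
129 |     Sentence = Tables().get_sentence_table()
130 |     Sentence.__table__.create(engine)
131 |     print("created table: {}".format(Sentence.__tablename__))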
--------------------------------------------------------------------------------
/test/test_ibmmodel.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from __future__ import division, print_function
5 | import unittest
6 | import collections
7 | #import keitaiso
8 | from smt.ibmmodel.ibmmodel1 import train
9 | from smt.ibmmodel.ibmmodel2 import viterbi_alignment
10 | #import smt.ibmmodel.ibmmodel2 as ibmmodel2
11 | import decimal
12 | from decimal import Decimal as D
13 |
14 | # set decimal context
15 | decimal.getcontext().prec = 4
16 | decimal.getcontext().rounding = decimal.ROUND_HALF_UP
17 |
18 |
19 | class IBMModel1Test(unittest.TestCase):
20 |
21 | #def _format(self, lst):
22 | # return {(k, float('{:.4f}'.format(v))) for (k, v) in lst}
23 |
24 | def test_train_loop1(self):
25 | sent_pairs = [("the house", "das Haus"),
26 | ("the book", "das Buch"),
27 | ("a book", "ein Buch"),
28 | ]
29 | #t0 = train(sent_pairs, loop_count=0)
30 | t1 = train(sent_pairs, loop_count=1)
31 |
32 | loop1 = [(('house', 'Haus'), D("0.5")),
33 | (('book', 'ein'), D("0.5")),
34 | (('the', 'das'), D("0.5")),
35 | (('the', 'Buch'), D("0.25")),
36 | (('book', 'Buch'), D("0.5")),
37 | (('a', 'ein'), D("0.5")),
38 | (('book', 'das'), D("0.25")),
39 | (('the', 'Haus'), D("0.5")),
40 | (('house', 'das'), D("0.25")),
41 | (('a', 'Buch'), D("0.25"))]
42 |         # assertion
43 |         # the loop-0 assertion doesn't make sense because
44 |         # t is initialized by a defaultdict
45 | #self.assertEqual(self._format(t0.items()), self._format(loop0))
46 | self.assertEqual(set(t1.items()), set(loop1))
47 |
48 | def test_train_loop2(self):
49 | sent_pairs = [("the house", "das Haus"),
50 | ("the book", "das Buch"),
51 | ("a book", "ein Buch"),
52 | ]
53 | #t0 = train(sent_pairs, loop_count=0)
54 | t2 = train(sent_pairs, loop_count=2)
55 |
56 | loop2 = [(('house', 'Haus'), D("0.5713")),
57 | (('book', 'ein'), D("0.4284")),
58 | (('the', 'das'), D("0.6367")),
59 | (('the', 'Buch'), D("0.1818")),
60 | (('book', 'Buch'), D("0.6367")),
61 | (('a', 'ein'), D("0.5713")),
62 | (('book', 'das'), D("0.1818")),
63 | (('the', 'Haus'), D("0.4284")),
64 | (('house', 'das'), D("0.1818")),
65 | (('a', 'Buch'), D("0.1818"))]
66 |         # assertion
67 |         # the loop-0 assertion doesn't make sense because
68 |         # t is initialized by a defaultdict
69 | #self.assertEqual(self._format(t0.items()), self._format(loop0))
70 | self.assertEqual(set(t2.items()), set(loop2))
71 |
72 |
73 | class IBMModel2Test(unittest.TestCase):
74 |
75 | def test_viterbi_alignment(self):
76 | x = viterbi_alignment([1, 2, 1],
77 | [2, 3, 2],
78 | collections.defaultdict(int),
79 | collections.defaultdict(int))
80 |         # viterbi_alignment selects the first token
81 |         # if t or a doesn't contain the key.
82 |         # This means it returns the NULL token
83 |         # in such a situation.
84 | self.assertEqual(x, {1: 1, 2: 1, 3: 1})
85 |
86 | #def test_zero_division_error(self):
87 | # """
88 | # at the beginning, there was this bug for ZeroDivisionError,
89 | # so this test was created to check that
90 | # """
91 | # sentence = [(u"Xではないかとつくづく疑問に思う",
92 | # u"I often wonder if it might be X."),
93 | # (u"Xがいいなといつも思います",
94 | # u"I always think X would be nice."),
95 | # (u"それがあるようにいつも思います",
96 | # u"It always seems like it is there."),
97 | # ]
98 | # sentences = [(keitaiso.str2wakati(s1), s2) for
99 | # s1, s2 in sentence]
100 |
101 | # self.assertRaises(decimal.DivisionByZero,
102 | # ibmmodel2.train,
103 | # sentences, loop_count=1000)
104 |
105 |
106 | if __name__ == '__main__':
107 | unittest.main()
108 |
--------------------------------------------------------------------------------
/smt/ibmmodel/ibmmodel1.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from operator import itemgetter
5 | import collections
6 | from smt.utils import utility
7 | import decimal
8 | from decimal import Decimal as D
9 |
10 | # set decimal context
11 | decimal.getcontext().prec = 4
12 | decimal.getcontext().rounding = decimal.ROUND_HALF_UP
13 |
14 |
15 | def _constant_factory(value):
16 | '''define a local function for uniform probability initialization'''
17 | #return itertools.repeat(value).next
18 | return lambda: value
19 |
20 |
21 | def _train(corpus, loop_count=1000):
22 | f_keys = set()
23 | for (es, fs) in corpus:
24 | for f in fs:
25 | f_keys.add(f)
26 |     # default value provided as uniform probability
27 |     t = collections.defaultdict(_constant_factory(D(1) / len(f_keys)))
28 |
29 | # loop
30 | for i in range(loop_count):
31 | count = collections.defaultdict(D)
32 | total = collections.defaultdict(D)
33 | s_total = collections.defaultdict(D)
34 | for (es, fs) in corpus:
35 | # compute normalization
36 | for e in es:
37 | s_total[e] = D()
38 | for f in fs:
39 | s_total[e] += t[(e, f)]
40 | for e in es:
41 | for f in fs:
42 | count[(e, f)] += t[(e, f)] / s_total[e]
43 | total[f] += t[(e, f)] / s_total[e]
44 | #if e == u"に" and f == u"always":
45 | # print(" BREAK:", i, count[(e, f)])
46 | # estimate probability
47 | for (e, f) in count.keys():
48 | #if count[(e, f)] == 0:
49 | # print(e, f, count[(e, f)])
50 | t[(e, f)] = count[(e, f)] / total[f]
51 |
52 | return t
53 |
54 |
55 | def train(sentences, loop_count=1000):
56 | corpus = utility.mkcorpus(sentences)
57 | return _train(corpus, loop_count)
58 |
59 |
60 | def _pprint(tbl):
61 | for (e, f), v in sorted(tbl.items(), key=itemgetter(1), reverse=True):
62 | print(u"p({e}|{f}) = {v}".format(e=e, f=f, v=v))
63 |
64 |
65 | def test_train_loop1():
66 | sent_pairs = [("the house", "das Haus"),
67 | ("the book", "das Buch"),
68 | ("a book", "ein Buch"),
69 | ]
70 | #t0 = train(sent_pairs, loop_count=0)
71 | t1 = train(sent_pairs, loop_count=1)
72 |
73 | loop1 = [(('house', 'Haus'), D("0.5")),
74 | (('book', 'ein'), D("0.5")),
75 | (('the', 'das'), D("0.5")),
76 | (('the', 'Buch'), D("0.25")),
77 | (('book', 'Buch'), D("0.5")),
78 | (('a', 'ein'), D("0.5")),
79 | (('book', 'das'), D("0.25")),
80 | (('the', 'Haus'), D("0.5")),
81 | (('house', 'das'), D("0.25")),
82 | (('a', 'Buch'), D("0.25"))]
83 |     # assertion
84 |     # the loop-0 assertion doesn't make sense because
85 |     # t is initialized by a defaultdict
86 | #self.assertEqual(self._format(t0.items()), self._format(loop0))
87 | assert set(t1.items()) == set(loop1)
88 |
89 |
90 | def test_train_loop2():
91 | sent_pairs = [("the house", "das Haus"),
92 | ("the book", "das Buch"),
93 | ("a book", "ein Buch"),
94 | ]
95 | #t0 = train(sent_pairs, loop_count=0)
96 | t2 = train(sent_pairs, loop_count=2)
97 |
98 | loop2 = [(('house', 'Haus'), D("0.5713")),
99 | (('book', 'ein'), D("0.4284")),
100 | (('the', 'das'), D("0.6367")),
101 | (('the', 'Buch'), D("0.1818")),
102 | (('book', 'Buch'), D("0.6367")),
103 | (('a', 'ein'), D("0.5713")),
104 | (('book', 'das'), D("0.1818")),
105 | (('the', 'Haus'), D("0.4284")),
106 | (('house', 'das'), D("0.1818")),
107 | (('a', 'Buch'), D("0.1818"))]
108 |     # assertion
109 |     # the loop-0 assertion doesn't make sense because
110 |     # t is initialized by a defaultdict
111 | #self.assertEqual(self._format(t0.items()), self._format(loop0))
112 | assert set(t2.items()) == set(loop2)
113 |
114 |
115 | if __name__ == '__main__':
116 | import sys
117 |
118 | fd = open(sys.argv[1]) if len(sys.argv) >= 2 else sys.stdin
119 | sentences = [line.strip().split('|||') for line in fd.readlines()]
120 | t = train(sentences, loop_count=3)
121 | for (e, f), val in t.items():
122 | print("{} {}\t{}".format(e, f, val))
123 |
--------------------------------------------------------------------------------
/smt/phrase/word_alignment.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from __future__ import division, print_function
5 | from smt.ibmmodel import ibmmodel2
6 | from pprint import pprint
7 |
8 |
9 | def _alignment(elist, flist, e2f, f2e):
10 | '''
11 | elist, flist
12 | wordlist for each language
13 | e2f
14 |         translation alignment from e to f
15 | alignment is
16 | [(e, f)]
17 | f2e
18 |         translation alignment from f to e
19 | alignment is
20 | [(e, f)]
21 | return
22 | alignment: {(f, e)}
23 | flist
24 | -----------------
25 | e | |
26 | l | |
27 | i | |
28 | s | |
29 | t | |
30 | -----------------
31 |
32 | '''
33 | neighboring = {(-1, 0), (0, -1), (1, 0), (0, 1),
34 | (-1, -1), (-1, 1), (1, -1), (1, 1)}
35 | e2f = set(e2f)
36 | f2e = set(f2e)
37 | m = len(elist)
38 | n = len(flist)
39 | alignment = e2f.intersection(f2e)
40 |     # merge with neighboring alignment points
41 | while True:
42 | set_len = len(alignment)
43 | for e_word in range(1, m+1):
44 | for f_word in range(1, n+1):
45 | if (e_word, f_word) in alignment:
46 | for (e_diff, f_diff) in neighboring:
47 | e_new = e_word + e_diff
48 | f_new = f_word + f_diff
49 | if not alignment:
50 | if (e_new, f_new) in e2f.union(f2e):
51 | alignment.add((e_new, f_new))
52 | else:
53 | if ((e_new not in list(zip(*alignment))[0]
54 | or f_new not in list(zip(*alignment))[1])
55 | and (e_new, f_new) in e2f.union(f2e)):
56 | alignment.add((e_new, f_new))
57 | if set_len == len(alignment):
58 | break
59 | # finalize
60 | for e_word in range(1, m+1):
61 | for f_word in range(1, n+1):
62 | # for alignment = set([])
63 | if not alignment:
64 | if (e_word, f_word) in e2f.union(f2e):
65 | alignment.add((e_word, f_word))
66 | else:
67 | if ((e_word not in list(zip(*alignment))[0]
68 | or f_word not in list(zip(*alignment))[1])
69 | and (e_word, f_word) in e2f.union(f2e)):
70 | alignment.add((e_word, f_word))
71 | return alignment
72 |
73 |
74 | def alignment(es, fs, e2f, f2e):
75 | """
76 | es: English words
77 | fs: Foreign words
78 | f2e: alignment for translation from fs to es
79 | [(e, f)] or {(e, f)}
80 | e2f: alignment for translation from es to fs
81 | [(f, e)] or {(f, e)}
82 | """
83 | _e2f = list(zip(*reversed(list(zip(*e2f)))))
84 | return _alignment(es, fs, _e2f, f2e)
85 |
86 |
87 | def symmetrization(es, fs, corpus):
88 | '''
89 |     corpus
90 | for translation from fs to es
91 | return
92 | alignment **from fs to es**
93 | '''
94 | f2e_train = ibmmodel2._train(corpus, loop_count=10)
95 | f2e = ibmmodel2.viterbi_alignment(es, fs, *f2e_train).items()
96 |
97 | e2f_corpus = list(zip(*reversed(list(zip(*corpus)))))
98 | e2f_train = ibmmodel2._train(e2f_corpus, loop_count=10)
99 | e2f = ibmmodel2.viterbi_alignment(fs, es, *e2f_train).items()
100 |
101 | return alignment(es, fs, e2f, f2e)
102 |
103 |
104 | if __name__ == '__main__':
105 | # test for alignment
106 | es = "michael assumes that he will stay in the house".split()
107 | fs = "michael geht davon aus , dass er im haus bleibt".split()
108 | e2f = [(1, 1), (2, 2), (2, 3), (2, 4), (3, 6),
109 | (4, 7), (7, 8), (9, 9), (6, 10)]
110 | f2e = [(1, 1), (2, 2), (3, 6), (4, 7), (7, 8),
111 | (8, 8), (9, 9), (5, 10), (6, 10)]
112 | from smt.utils.utility import matrix
113 | print(matrix(len(es), len(fs), e2f, es, fs))
114 | print(matrix(len(es), len(fs), f2e, es, fs))
115 | ali = _alignment(es, fs, e2f, f2e)
116 | print(matrix(len(es), len(fs), ali, es, fs))
117 |
118 | # test for symmetrization
119 | from smt.utils.utility import mkcorpus
120 |     sentences = [("僕 は 男 です", "I am a man"),
121 | ("私 は 女 です", "I am a girl"),
122 | ("私 は 先生 です", "I am a teacher"),
123 | ("彼女 は 先生 です", "She is a teacher"),
124 | ("彼 は 先生 です", "He is a teacher"),
125 | ]
126 |     corpus = mkcorpus(sentences)
127 | es = "私 は 先生 です".split()
128 | fs = "I am a teacher".split()
129 | syn = symmetrization(es, fs, corpus)
130 | pprint(syn)
131 | print(matrix(len(es), len(fs), syn, es, fs))
132 |
--------------------------------------------------------------------------------
/smt/phrase/phrase_extract.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 |
5 | def phrase_extract(es, fs, alignment):
6 | ext = extract(es, fs, alignment)
7 | ind = {((x, y), (z, w)) for x, y, z, w in ext}
8 | es = tuple(es)
9 | fs = tuple(fs)
10 | return {(es[e_s-1:e_e], fs[f_s-1:f_e])
11 | for (e_s, e_e), (f_s, f_e) in ind}
12 |
13 |
14 | def extract(es, fs, alignment):
15 | """
16 | caution:
17 | alignment starts from 1 - not 0
18 | """
19 | phrases = set()
20 | len_es = len(es)
21 | for e_start in range(1, len_es+1):
22 | for e_end in range(e_start, len_es+1):
23 | # find the minimally matching foreign phrase
24 | f_start, f_end = (len(fs), 0)
25 | for (e, f) in alignment:
26 | if e_start <= e <= e_end:
27 | f_start = min(f, f_start)
28 | f_end = max(f, f_end)
29 | phrases.update(_extract(es, fs, e_start,
30 | e_end, f_start,
31 | f_end, alignment))
32 | return phrases
33 |
34 |
35 | def _extract(es, fs, e_start, e_end, f_start, f_end, alignment):
36 | if f_end == 0:
37 |         return set()
38 | for (e, f) in alignment:
39 | if (f_start <= f <= f_end) and (e < e_start or e > e_end):
40 |             return set()
41 | ex = set()
42 | f_s = f_start
43 | while True:
44 | f_e = f_end
45 | while True:
46 | ex.add((e_start, e_end, f_s, f_e))
47 | f_e += 1
48 | if f_e in list(zip(*alignment))[1] or f_e > len(fs):
49 | break
50 | f_s -= 1
51 | if f_s in list(zip(*alignment))[1] or f_s < 1:
52 | break
53 | return ex
54 |
55 |
56 | def available_phrases(fs, phrases):
57 | """
58 | return:
59 | set of phrase indexed tuple like
60 | {((1, "I"), (2, "am")),
61 | ((1, "I"),)
62 | ...}
63 | """
64 | available = set()
65 | for i, f in enumerate(fs):
66 | f_rest = ()
67 | for fr in fs[i:]:
68 | f_rest += (fr,)
69 | if f_rest in phrases:
70 | available.add(tuple(enumerate(f_rest, i+1)))
71 | return available
72 |
73 |
74 | def test_phrases():
75 | from smt.utils.utility import mkcorpus
76 | from smt.phrase.word_alignment import symmetrization
77 |
78 |     sentences = [("僕 は 男 です", "I am a man"),
79 | ("私 は 女 です", "I am a girl"),
80 | ("私 は 先生 です", "I am a teacher"),
81 | ("彼女 は 先生 です", "She is a teacher"),
82 | ("彼 は 先生 です", "He is a teacher"),
83 | ]
84 |
85 |     corpus = mkcorpus(sentences)
86 | es, fs = ("私 は 先生 です".split(), "I am a teacher".split())
87 | alignment = symmetrization(es, fs, corpus)
88 | ext = phrase_extract(es, fs, alignment)
89 | ans = ("は 先生 です <-> a teacher",
90 | "先生 <-> teacher"
91 | "私 <-> I am"
92 | "私 は 先生 です <-> I am a teacher")
93 | for e, f in ext:
94 | print("{} {} {}".format(' '.join(e), "<->", ' '.join(f)))
95 |
96 | ## phrases
97 | fs = "I am a teacher".split()
98 | phrases = available_phrases(fs, [fs_ph for (es_ph, fs_ph) in ext])
99 | print(phrases)
100 | ans = {((1, 'I'), (2, 'am')),
101 | ((1, 'I'), (2, 'am'), (3, 'a'), (4, 'teacher')),
102 | ((4, 'teacher'),),
103 | ((3, 'a'), (4, 'teacher'))}
104 |
105 | phrases = available_phrases(fs, [fs_ph for (es_ph, fs_ph) in ext])
106 | assert ans == phrases
107 |
108 |
109 | if __name__ == '__main__':
110 |
111 | # test2
112 | from smt.utils.utility import mkcorpus
113 | from word_alignment import alignment
114 | from smt.ibmmodel import ibmmodel2
115 | import sys
116 |
117 | delimiter = ","
118 | # load file which will be trained
119 | modelfd = open(sys.argv[1])
120 |     sentences = [line.rstrip().split(delimiter) for line
121 | in modelfd.readlines()]
122 | # make corpus
123 |     corpus = mkcorpus(sentences)
124 |
125 | # train model from corpus
126 | f2e_train = ibmmodel2._train(corpus, loop_count=10)
127 | e2f_corpus = list(zip(*reversed(list(zip(*corpus)))))
128 | e2f_train = ibmmodel2._train(e2f_corpus, loop_count=10)
129 |
130 | # phrase extraction
131 | for line in sys.stdin:
132 | _es, _fs = line.rstrip().split(delimiter)
133 | es = _es.split()
134 | fs = _fs.split()
135 |
136 | f2e = ibmmodel2.viterbi_alignment(es, fs, *f2e_train).items()
137 | e2f = ibmmodel2.viterbi_alignment(fs, es, *e2f_train).items()
138 | align = alignment(es, fs, e2f, f2e) # symmetrized alignment
139 |
140 | # output matrix
141 | #from smt.utils.utility import matrix
142 | #print(matrix(len(es), len(fs), align, es, fs))
143 |
144 | ext = phrase_extract(es, fs, align)
145 | for e, f in ext:
146 | print("{}{}{}".format(''.join(e), delimiter, ''.join(f)))
147 |
--------------------------------------------------------------------------------
/smt/ibmmodel/ibmmodel2.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | import collections
5 | from smt.ibmmodel import ibmmodel1
6 | from smt.utils import utility
7 | import decimal
8 | from decimal import Decimal as D
9 |
10 | # set decimal context
11 | decimal.getcontext().prec = 4
12 | decimal.getcontext().rounding = decimal.ROUND_HALF_UP
13 |
14 |
15 | class _keydefaultdict(collections.defaultdict):
16 |     '''defaultdict whose default_factory is called with the missing key'''
17 | def __missing__(self, key):
18 | if self.default_factory is None:
19 | raise KeyError(key)
20 | else:
21 | ret = self[key] = self.default_factory(key)
22 | return ret
23 |
24 |
25 | def _train(corpus, loop_count=1000):
26 | #print(corpus)
27 | #print(loop_count)
28 | f_keys = set()
29 | for (es, fs) in corpus:
30 | for f in fs:
31 | f_keys.add(f)
32 | # initialize t
33 | t = ibmmodel1._train(corpus, loop_count)
34 |     # default value provided as uniform probability
35 |
36 | def key_fun(key):
37 | ''' default_factory function for keydefaultdict '''
38 | i, j, l_e, l_f = key
39 | return D("1") / D(l_f + 1)
40 | a = _keydefaultdict(key_fun)
41 |
42 | # loop
43 | for _i in range(loop_count):
44 | # variables for estimating t
45 | count = collections.defaultdict(D)
46 | total = collections.defaultdict(D)
47 | # variables for estimating a
48 | count_a = collections.defaultdict(D)
49 | total_a = collections.defaultdict(D)
50 |
51 | s_total = collections.defaultdict(D)
52 | for (es, fs) in corpus:
53 | l_e = len(es)
54 | l_f = len(fs)
55 | # compute normalization
56 | for (j, e) in enumerate(es, 1):
57 |                 s_total[e] = D()
58 | for (i, f) in enumerate(fs, 1):
59 | s_total[e] += t[(e, f)] * a[(i, j, l_e, l_f)]
60 | # collect counts
61 | for (j, e) in enumerate(es, 1):
62 | for (i, f) in enumerate(fs, 1):
63 | c = t[(e, f)] * a[(i, j, l_e, l_f)] / s_total[e]
64 | count[(e, f)] += c
65 | total[f] += c
66 | count_a[(i, j, l_e, l_f)] += c
67 | total_a[(j, l_e, l_f)] += c
68 |
69 | #for k, v in total.items():
70 | # if v == 0:
71 | # print(k, v)
72 | # estimate probability
73 | for (e, f) in count.keys():
74 | try:
75 | t[(e, f)] = count[(e, f)] / total[f]
76 | except decimal.DivisionByZero:
77 | print(u"e: {e}, f: {f}, count[(e, f)]: {ef}, total[f]: \
78 | {totalf}".format(e=e, f=f, ef=count[(e, f)],
79 | totalf=total[f]))
80 | raise
81 | for (i, j, l_e, l_f) in count_a.keys():
82 | a[(i, j, l_e, l_f)] = count_a[(i, j, l_e, l_f)] / \
83 | total_a[(j, l_e, l_f)]
84 | # output
85 | #for (e, f), val in t.items():
86 | # print("{} {}\t{}".format(e, f, float(val)))
87 | #for (i, j, l_e, l_f), val in a.items():
88 | # print("{} {} {} {}\t{}".format(i, j, l_e, l_f, float(val)))
89 |
90 | return (t, a)
91 |
92 |
93 | def train(sentences, loop_count=1000):
94 | #for i, j in sentences:
95 | # print(i, j)
96 | corpus = utility.mkcorpus(sentences)
97 | return _train(corpus, loop_count)
98 |
99 |
100 | def viterbi_alignment(es, fs, t, a):
101 | '''
102 | return
103 | dictionary
104 | e in es -> f in fs
105 | '''
106 | max_a = collections.defaultdict(float)
107 | l_e = len(es)
108 | l_f = len(fs)
109 | for (j, e) in enumerate(es, 1):
110 | current_max = (0, -1)
111 | for (i, f) in enumerate(fs, 1):
112 | val = t[(e, f)] * a[(i, j, l_e, l_f)]
113 | # select the first one among the maximum candidates
114 | if current_max[1] < val:
115 | current_max = (i, val)
116 | max_a[j] = current_max[0]
117 | return max_a
118 |
119 |
120 | def show_matrix(es, fs, t, a):
121 | '''
122 | print matrix according to viterbi alignment like
123 | fs
124 | -------------
125 | e| |
126 | s| |
127 | | |
128 | -------------
129 | >>> sentences = [("僕 は 男 です", "I am a man"),
130 | ("私 は 女 です", "I am a girl"),
131 | ("私 は 先生 です", "I am a teacher"),
132 | ("彼女 は 先生 です", "She is a teacher"),
133 | ("彼 は 先生 です", "He is a teacher"),
134 | ]
135 | >>> t, a = train(sentences, loop_count=1000)
136 | >>> args = ("私 は 先生 です".split(), "I am a teacher".split(), t, a)
137 | |x| | | |
138 | | | |x| |
139 | | | | |x|
140 | | | |x| |
141 | '''
142 | max_a = viterbi_alignment(es, fs, t, a).items()
143 | m = len(es)
144 | n = len(fs)
145 | return utility.matrix(m, n, max_a, es, fs)
146 |
147 |
148 |
149 | def test_viterbi_alignment():
150 | x = viterbi_alignment([1, 2, 1],
151 | [2, 3, 2],
152 | collections.defaultdict(int),
153 | collections.defaultdict(int))
154 |     # viterbi_alignment selects the first token
155 |     # if t or a doesn't contain the key.
156 |     # This means it returns the NULL token
157 |     # in such a situation.
158 | ans = {1: 1, 2: 1, 3: 1}
159 | assert dict(x) == ans
160 |
161 |
162 | if __name__ == '__main__':
163 | import sys
164 |
165 | fd = open(sys.argv[1]) if len(sys.argv) >= 2 else sys.stdin
166 | sentences = [line.strip().split('|||') for line in fd.readlines()]
167 | t, a = train(sentences, loop_count=10)
168 |
169 | es = "私 は 先生 です".split()
170 | fs = "I am a teacher".split()
171 | args = (es, fs, t, a)
172 |
173 | print(show_matrix(*args))
174 |
--------------------------------------------------------------------------------
/test/test_phrase.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | import unittest
5 | from smt.phrase.word_alignment import _alignment
6 | from smt.phrase.word_alignment import symmetrization
7 | from smt.phrase.phrase_extract import extract
8 | from smt.phrase.phrase_extract import phrase_extract
9 | from smt.phrase.phrase_extract import available_phrases
10 | from smt.utils.utility import mkcorpus
11 |
12 |
13 | class WordAlignmentTest(unittest.TestCase):
14 |
15 | def test_alignment(self):
16 | elist = "michael assumes that he will stay in the house".split()
17 | flist = "michael geht davon aus , dass er im haus bleibt".split()
18 | e2f = [(1, 1), (2, 2), (2, 3), (2, 4), (3, 6),
19 | (4, 7), (7, 8), (9, 9), (6, 10)]
20 | f2e = [(1, 1), (2, 2), (3, 6), (4, 7), (7, 8),
21 | (8, 8), (9, 9), (5, 10), (6, 10)]
22 | ans = set([(1, 1),
23 | (2, 2),
24 | (2, 3),
25 | (2, 4),
26 | (3, 6),
27 | (4, 7),
28 | (5, 10),
29 | (6, 10),
30 | (7, 8),
31 | (8, 8),
32 | (9, 9)])
33 | self.assertEqual(_alignment(elist, flist, e2f, f2e), ans)
34 |
35 | def test_symmetrization(self):
36 |         sentences = [("僕 は 男 です", "I am a man"),
37 | ("私 は 女 です", "I am a girl"),
38 | ("私 は 先生 です", "I am a teacher"),
39 | ("彼女 は 先生 です", "She is a teacher"),
40 | ("彼 は 先生 です", "He is a teacher"),
41 | ]
42 |         corpus = mkcorpus(sentences)
43 | es = "私 は 先生 です".split()
44 | fs = "I am a teacher".split()
45 | syn = symmetrization(es, fs, corpus)
46 | ans = set([(1, 1), (1, 2), (2, 3), (3, 4), (4, 3)])
47 | self.assertEqual(syn, ans)
48 |
49 |
50 | class PhraseExtractTest(unittest.TestCase):
51 | def test_extract(self):
52 |
53 | # next alignment matrix is like
54 | #
55 | # | |x|x| | |
56 | # |x| | |x| |
57 | # | | | | |x|
58 | #
59 | es = range(1, 4)
60 | fs = range(1, 6)
61 | alignment = [(2, 1),
62 | (1, 2),
63 | (1, 3),
64 | (2, 4),
65 | (3, 5)]
66 | ans = set([(1, 1, 2, 3), (1, 3, 1, 5), (3, 3, 5, 5), (1, 2, 1, 4)])
67 | self.assertEqual(extract(es, fs, alignment), ans)
68 |
69 | # next alignment matrix is like
70 | #
71 | # |x| | | | | | | | | |
72 | # | |x|x|x| | | | | | |
73 | # | | | | | |x| | | | |
74 | # | | | | | | |x| | | |
75 | # | | | | | | | | | |x|
76 | # | | | | | | | | | |x|
77 | # | | | | | | | |x| | |
78 | # | | | | | | | |x| | |
79 | # | | | | | | | | |x| |
80 | #
81 | es = "michael assumes that he will stay in the house".split()
82 | fs = "michael geht davon aus , dass er im haus bleibt".split()
83 | alignment = set([(1, 1),
84 | (2, 2),
85 | (2, 3),
86 | (2, 4),
87 | (3, 6),
88 | (4, 7),
89 | (5, 10),
90 | (6, 10),
91 | (7, 8),
92 | (8, 8),
93 | (9, 9)])
94 | ans = set([(1, 1, 1, 1),
95 | (1, 2, 1, 4),
96 | (1, 2, 1, 5),
97 | (1, 3, 1, 6),
98 | (1, 4, 1, 7),
99 | (1, 9, 1, 10),
100 | (2, 2, 2, 4),
101 | (2, 2, 2, 5),
102 | (2, 3, 2, 6),
103 | (2, 4, 2, 7),
104 | (2, 9, 2, 10),
105 | (3, 3, 5, 6),
106 | (3, 3, 6, 6),
107 | (3, 4, 5, 7),
108 | (3, 4, 6, 7),
109 | (3, 9, 5, 10),
110 | (3, 9, 6, 10),
111 | (4, 4, 7, 7),
112 | (4, 9, 7, 10),
113 | (5, 6, 10, 10),
114 | (5, 9, 8, 10),
115 | (7, 8, 8, 8),
116 | (7, 9, 8, 9),
117 | (9, 9, 9, 9)])
118 |
119 | self.assertEqual(extract(es, fs, alignment), ans)
120 |
121 | def test_phrase_extract(self):
122 | # next alignment matrix is like
123 | #
124 | # |x| | | | | | | | | |
125 | # | |x|x|x| | | | | | |
126 | # | | | | | |x| | | | |
127 | # | | | | | | |x| | | |
128 | # | | | | | | | | | |x|
129 | # | | | | | | | | | |x|
130 | # | | | | | | | |x| | |
131 | # | | | | | | | |x| | |
132 | # | | | | | | | | |x| |
133 | #
134 | es = "michael assumes that he will stay in the house".split()
135 | fs = "michael geht davon aus , dass er im haus bleibt".split()
136 | alignment = set([(1, 1),
137 | (2, 2),
138 | (2, 3),
139 | (2, 4),
140 | (3, 6),
141 | (4, 7),
142 | (5, 10),
143 | (6, 10),
144 | (7, 8),
145 | (8, 8),
146 | (9, 9)])
147 | ans = set([(('assumes',), ('geht', 'davon', 'aus')),
148 | (('assumes',), ('geht', 'davon', 'aus', ',')),
149 | (('assumes', 'that'),
150 | ('geht', 'davon', 'aus', ',', 'dass')),
151 | (('assumes', 'that', 'he'),
152 | ('geht', 'davon', 'aus', ',', 'dass', 'er')),
153 | (('assumes', 'that', 'he',
154 | 'will', 'stay', 'in', 'the', 'house'),
155 | ('geht', 'davon', 'aus', ',', 'dass',
156 | 'er', 'im', 'haus', 'bleibt')),
157 | (('he',), ('er',)),
158 | (('he', 'will', 'stay', 'in', 'the', 'house'),
159 | ('er', 'im', 'haus', 'bleibt')),
160 | (('house',), ('haus',)),
161 | (('in', 'the'), ('im',)),
162 | (('in', 'the', 'house'), ('im', 'haus')),
163 | (('michael',), ('michael',)),
164 | (('michael', 'assumes'),
165 | ('michael', 'geht', 'davon', 'aus')),
166 | (('michael', 'assumes'),
167 | ('michael', 'geht', 'davon', 'aus', ',')),
168 | (('michael', 'assumes', 'that'),
169 | ('michael', 'geht', 'davon', 'aus', ',', 'dass')),
170 | (('michael', 'assumes', 'that', 'he'),
171 | ('michael', 'geht', 'davon', 'aus', ',', 'dass', 'er')),
172 | (('michael',
173 | 'assumes',
174 | 'that',
175 | 'he',
176 | 'will',
177 | 'stay',
178 | 'in',
179 | 'the',
180 | 'house'),
181 | ('michael',
182 | 'geht',
183 | 'davon',
184 | 'aus',
185 | ',',
186 | 'dass',
187 | 'er',
188 | 'im',
189 | 'haus',
190 | 'bleibt')),
191 | (('that',), (',', 'dass')),
192 | (('that',), ('dass',)),
193 | (('that', 'he'), (',', 'dass', 'er')),
194 | (('that', 'he'), ('dass', 'er')),
195 | (('that', 'he', 'will', 'stay', 'in', 'the', 'house'),
196 | (',', 'dass', 'er', 'im', 'haus', 'bleibt')),
197 | (('that', 'he', 'will', 'stay', 'in', 'the', 'house'),
198 | ('dass', 'er', 'im', 'haus', 'bleibt')),
199 | (('will', 'stay'), ('bleibt',)),
200 | (('will', 'stay', 'in', 'the', 'house'),
201 | ('im', 'haus', 'bleibt'))])
202 | self.assertEqual(phrase_extract(es, fs, alignment), ans)
203 |
204 | # another test
205 | es, fs = ("私 は 先生 です".split(), "I am a teacher".split())
206 |         sentences = [("僕 は 男 です", "I am a man"),
207 | ("私 は 女 です", "I am a girl"),
208 | ("私 は 先生 です", "I am a teacher"),
209 | ("彼女 は 先生 です", "She is a teacher"),
210 | ("彼 は 先生 です", "He is a teacher"),
211 | ]
212 |         corpus = mkcorpus(sentences)
213 | alignment = symmetrization(es, fs, corpus)
214 |         ans = set([(('は',
215 |                      '先生',
216 |                      'です'),
217 |                     ('a', 'teacher')),
218 |                    (('先生',), ('teacher',)),
219 |                    (('私',), ('I', 'am')),
220 |                    (('私',
221 |                      'は',
222 |                      '先生',
223 |                      'です'),
224 |                     ('I', 'am', 'a', 'teacher'))])
225 | self.assertEqual(phrase_extract(es, fs, alignment), ans)
226 |
227 | def test_available_phrases(self):
228 | fs = "I am a teacher".split()
229 | phrases = set([("I", "am"),
230 | ("a", "teacher"),
231 | ("teacher",),
232 | ("I", "am", "a", "teacher")])
233 |
234 | ans = set([((4, 'teacher'),),
235 | ((1, 'I'), (2, 'am')),
236 | ((3, 'a'), (4, 'teacher')),
237 | ((1, 'I'), (2, 'am'), (3, 'a'), (4, 'teacher'))])
238 | self.assertEqual(available_phrases(fs, phrases), ans)
239 |
240 | if __name__ == '__main__':
241 | unittest.main()
242 |
--------------------------------------------------------------------------------
/smt/db/createngramdb.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from __future__ import division, print_function
5 | import collections
6 | import sqlite3
7 | # import SQLAlchemy
8 | from sqlalchemy import create_engine
9 | from sqlalchemy.orm import sessionmaker
10 | # smt
11 | from smt.db.tables import Tables
12 | from smt.langmodel.ngram import ngram
13 | import math
14 |
15 |
16 | def _create_ngram_count_db(lang, langmethod=lambda x: x,
17 | n=3, db="sqilte:///:memory:"):
18 | engine = create_engine(db)
19 | # create session
20 | Session = sessionmaker(bind=engine)
21 | session = Session()
22 |
23 | Sentence = Tables().get_sentence_table()
24 | query = session.query(Sentence)
25 |
26 |     ngram_dic = collections.defaultdict(int)
27 | for item in query:
28 | if lang == 1:
29 | sentences = langmethod(item.lang1).split()
30 | elif lang == 2:
31 | sentences = langmethod(item.lang2).split()
32 | sentences = ["", ""] + sentences + [""]
33 | ngrams = ngram(sentences, n)
34 | for tpl in ngrams:
35 | ngram_dic[tpl] += 1
36 |
37 | return ngram_dic
38 |
39 |
40 | def create_ngram_count_db(lang, langmethod=lambda x: x,
41 | n=3, db="sqilte:///:memory:"):
42 | engine = create_engine(db)
43 | # create session
44 | Session = sessionmaker(bind=engine)
45 | session = Session()
46 |
47 | # trigram table
48 | tablename = 'lang{}trigram'.format(lang)
49 | Trigram = Tables().get_trigram_table(tablename)
50 | # create table
51 | Trigram.__table__.drop(engine, checkfirst=True)
52 | Trigram.__table__.create(engine)
53 |
54 | ngram_dic = _create_ngram_count_db(lang, langmethod=langmethod, n=n, db=db)
55 |
56 | # insert items
57 | for (first, second, third), count in ngram_dic.items():
58 | print(u"inserting {}, {}, {}".format(first, second, third))
59 | item = Trigram(first=first,
60 | second=second,
61 | third=third,
62 | count=count)
63 | session.add(item)
64 | session.commit()
65 |
66 |
67 | def create_unigram_count_db(lang, langmethod=lambda x: x,
68 | db="sqilte:///:memory:"):
69 | engine = create_engine(db)
70 | # create session
71 | Session = sessionmaker(bind=engine)
72 | session = Session()
73 |
74 |     # unigram table
75 | tablename = 'lang{}unigram'.format(lang)
76 | Sentence = Tables().get_sentence_table()
77 | Unigram = Tables().get_unigram_table(tablename)
78 | # create table
79 | Unigram.__table__.drop(engine, checkfirst=True)
80 | Unigram.__table__.create(engine)
81 |
82 | query = session.query(Sentence)
83 | ngram_dic = collections.defaultdict(int)
84 | for item in query:
85 | if lang == 1:
86 | sentences = langmethod(item.lang1).split()
87 | elif lang == 2:
88 | sentences = langmethod(item.lang2).split()
89 | ngrams = ngram(sentences, 1)
90 | for tpl in ngrams:
91 | ngram_dic[tpl] += 1
92 |
93 | # insert items
94 | for (first,), count in ngram_dic.items():
95 | print(u"inserting {}: {}".format(first, count))
96 | item = Unigram(first=first,
97 | count=count)
98 | session.add(item)
99 | session.commit()
100 |
101 |
102 | # create views using SQLite3
103 | def create_ngram_count_without_last_view(lang, db=":memory:"):
104 | # create phrase_count table
105 | fromtablename = "lang{}trigram".format(lang)
106 | table_name = "lang{}trigram_without_last".format(lang)
107 | # create connection
108 | con = sqlite3.connect(db)
109 | cur = con.cursor()
110 | try:
111 | cur.execute("drop view {0}".format(table_name))
112 | except sqlite3.Error:
113 | print("{0} view does not exists.\n\
114 | => creating a new view".format(table_name))
115 | cur.execute("""create view {}
116 | as select first, second, sum(count) as count from
117 | {} group by first, second order by count
118 | desc""".format(table_name, fromtablename))
119 | con.commit()
120 |
121 |
122 | def create_ngram_prob(lang,
123 | db=":memory:"):
124 |
125 | # Create connection in sqlite3 to use view
126 | table_name = "lang{}trigram_without_last".format(lang)
127 | # create connection
128 | con = sqlite3.connect(db)
129 | cur = con.cursor()
130 |
131 | trigram_tablename = 'lang{}trigram'.format(lang)
132 | trigramprob_tablename = 'lang{}trigramprob'.format(lang)
133 | trigramprobwithoutlast_tablename = 'lang{}trigramprob_without_last'\
134 | .format(lang)
135 |
136 | # tables
137 | Trigram = Tables().get_trigram_table(trigram_tablename)
138 | TrigramProb = Tables().get_trigramprob_table(trigramprob_tablename)
139 | TrigramProbWithoutLast = Tables().get_trigramprobwithoutlast_table(
140 | trigramprobwithoutlast_tablename)
141 |
142 | # create connection in SQLAlchemy
143 | sqlalchemydb = "sqlite:///{}".format(db)
144 | engine = create_engine(sqlalchemydb)
145 | # create session
146 | Session = sessionmaker(bind=engine)
147 | session = Session()
148 | # create table
149 | TrigramProb.__table__.drop(engine, checkfirst=True)
150 | TrigramProb.__table__.create(engine)
151 | TrigramProbWithoutLast.__table__.drop(engine, checkfirst=True)
152 | TrigramProbWithoutLast.__table__.create(engine)
153 |
154 | # calculate total number
155 | query = session.query(Trigram)
156 | totalnumber = len(query.all())
157 |
158 | # get trigrams
159 | query = session.query(Trigram)
160 | for item in query:
161 | first, second, third = item.first, item.second, item.third
162 | count = item.count
163 |
164 | cur.execute("""select * from {} where \
165 | first=? and\
166 | second=?""".format(table_name),
167 | (first, second))
168 | one = cur.fetchone()
169 | # if fetch is failed, one is NONE (no exceptions are raised)
170 | if not one:
171 | print("not found correspont first and second")
172 | continue
173 | else:
174 | alpha = 0.00017
175 | c = count
176 | n = one[2]
177 | v = totalnumber
178 | # create logprob
179 | logprob = math.log((c + alpha) / (n + alpha * v))
180 | print(u"{}, {}, {}:\
181 | log({} + {} / {} + {} + {}) = {}".format(first,
182 | second,
183 | third,
184 | c,
185 | alpha,
186 | n,
187 | alpha,
188 | v,
189 | logprob))
190 | trigramprob = TrigramProb(first=first,
191 | second=second,
192 | third=third,
193 | prob=logprob)
194 | session.add(trigramprob)
195 | # for without last
196 | logprobwithoutlast = math.log(alpha / (n + alpha * v))
197 | print(u"{}, {}, {}:\
198 | log({} / {} + {} + {}) = {}".format(first,
199 | second,
200 | third,
201 | alpha,
202 | n,
203 | alpha,
204 | v,
205 | logprobwithoutlast))
206 | probwl = TrigramProbWithoutLast(first=first,
207 | second=second,
208 | prob=logprobwithoutlast)
209 | session.add(probwl)
210 | session.commit()
211 |
212 |
213 | def create_unigram_prob(lang, db=":memory:"):
214 |
215 | unigram_tablename = 'lang{}unigram'.format(lang)
216 | unigramprob_tablename = 'lang{}unigramprob'.format(lang)
217 |
218 | # tables
219 | Unigram = Tables().get_unigram_table(unigram_tablename)
220 | UnigramProb = Tables().get_unigramprob_table(unigramprob_tablename)
221 |
222 | # create engine
223 | sqlalchemydb = "sqlite:///{}".format(db)
224 | engine = create_engine(sqlalchemydb)
225 | # create session
226 | Session = sessionmaker(bind=engine)
227 | session = Session()
228 | # create table
229 | UnigramProb.__table__.drop(engine, checkfirst=True)
230 | UnigramProb.__table__.create(engine)
231 |
232 | # calculate total number
233 | query = session.query(Unigram)
234 | sm = 0
235 | totalnumber = 0
236 | for item in query:
237 | totalnumber += 1
238 | sm += item.count
239 |
240 |     # get unigrams
241 | query = session.query(Unigram)
242 | for item in query:
243 | first = item.first
244 | count = item.count
245 |
246 | alpha = 0.00017
247 | c = count
248 | v = totalnumber
249 | # create logprob
250 | logprob = math.log((c + alpha) / (sm + alpha * v))
251 | print(u"{}:\
252 | log({}+{} / {} + {}*{}) = {}".format(first,
253 | c,
254 | alpha,
255 | sm,
256 | alpha,
257 | v,
258 | logprob))
259 | unigramprob = UnigramProb(first=first,
260 | prob=logprob)
261 | session.add(unigramprob)
262 | session.commit()
263 |
264 |
265 | def create_ngram_db(lang, langmethod=lambda x: x,
266 | n=3, db=":memory:"):
267 |
268 | sqlalchemydb = "sqlite:///{}".format(db)
269 | create_ngram_count_db(lang=lang, langmethod=langmethod,
270 | n=n,
271 | db=sqlalchemydb)
272 | create_ngram_count_without_last_view(lang=lang, db=db)
273 | create_ngram_prob(lang=lang, db=db)
274 |
275 | create_unigram_count_db(lang=lang, langmethod=langmethod,
276 | db=sqlalchemydb)
277 | create_unigram_prob(lang=lang, db=db)
278 |
279 |
280 | if __name__ == '__main__':
281 | pass
282 |
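283 |     # usage sketch (assumes a SQLite database file, e.g. a hypothetical
284 |     # "smt.db", whose sentence table was populated beforehand):
285 |     #     create_ngram_db(lang=1, n=3, db="smt.db")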
--------------------------------------------------------------------------------
/smt/db/createdb.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from __future__ import division, print_function
5 | import collections
6 | from smt.utils import utility
7 | from smt.ibmmodel import ibmmodel2
8 | from smt.phrase import word_alignment
9 | from smt.phrase import phrase_extract
10 | from progressline import ProgressLine
11 | from smt.db.tables import Tables
12 | # import SQLAlchemy
13 | import sqlalchemy
14 | from sqlalchemy import create_engine
15 | from sqlalchemy.orm import sessionmaker
16 | import sqlite3
17 | import math
18 |
19 |
20 | def create_corpus(db="sqlite:///:memory:",
21 | lang1method=lambda x: x,
22 | lang2method=lambda x: x,
23 | limit=None):
24 | engine = create_engine(db)
25 | # create session
26 | Session = sessionmaker(bind=engine)
27 | session = Session()
28 |
29 | Sentence = Tables().get_sentence_table()
30 |
31 | query = session.query(Sentence)[:limit] if limit \
32 | else session.query(Sentence)
33 |
34 | for item in query:
35 | yield {"lang1": lang1method(item.lang1),
36 | "lang2": lang2method(item.lang2)}
37 |
38 |
39 | def create_train_db(transfrom=2,
40 | transto=1,
41 | lang1method=lambda x: x,
42 | lang2method=lambda x: x,
43 | db="sqlite:///:memory:",
44 | limit=None,
45 | loop_count=1000):
46 | engine = create_engine(db)
47 | # create session
48 | Session = sessionmaker(bind=engine)
49 | session = Session()
50 |
51 | # tablenames
52 | table_prefix = "from{0}to{1}".format(transfrom, transto)
53 | wordprob_tablename = table_prefix + "_" + "wordprob"
54 | wordalign_tablename = table_prefix + "_" + "wordalign"
55 | # tables
56 | WordProbability = Tables().get_wordprobability_table(wordprob_tablename)
57 | WordAlignment = Tables().get_wordalignment_table(wordalign_tablename)
58 | # create table for word probability
59 | WordProbability.__table__.drop(engine, checkfirst=True)
60 | WordProbability.__table__.create(engine)
61 | print("created table: {0}to{1}_wordprob".format(transfrom, transto))
62 |
63 | # create table for alignment probability
64 | WordAlignment.__table__.drop(engine, checkfirst=True)
65 | WordAlignment.__table__.create(engine)
66 | print("created table: {0}to{1}_wordalign".format(transfrom, transto))
67 |
68 | # IBM learning
69 | with ProgressLine(0.12, title='IBM Model learning...'):
70 |         # check arguments for create_corpus
71 | corpus = create_corpus(db=db, limit=limit,
72 | lang1method=lang1method,
73 | lang2method=lang2method)
74 | sentences = [(item["lang{0}".format(transto)],
75 | item["lang{0}".format(transfrom)])
76 | for item in corpus]
77 | t, a = ibmmodel2.train(sentences=sentences,
78 | loop_count=loop_count)
79 | # insert
80 | with ProgressLine(0.12, title='Inserting items into database...'):
81 | for (_to, _from), prob in t.items():
82 | session.add(WordProbability(transto=_to,
83 | transfrom=_from,
84 | prob=float(prob)))
85 | for (from_pos, to_pos, to_len, from_len), prob in a.items():
86 | session.add(WordAlignment(from_pos=from_pos,
87 | to_pos=to_pos,
88 | to_len=to_len,
89 | from_len=from_len,
90 | prob=float(prob)))
91 | session.commit()
92 |
93 |
94 | def db_viterbi_alignment(es, fs,
95 | transfrom=2,
96 | transto=1,
97 | db="sqlite:///:memory:",
98 | init_val=1.0e-10):
99 | """
100 | Calculating viterbi_alignment using specified database.
101 |
102 | Arguments:
103 | trans:
104 | it can take "en2ja" or "ja2en"
105 | """
106 | engine = create_engine(db)
107 | # create session
108 | Session = sessionmaker(bind=engine)
109 | session = Session()
110 |
111 | # tablenames
112 | table_prefix = "from{0}to{1}".format(transfrom, transto)
113 | wordprob_tablename = table_prefix + "_" + "wordprob"
114 | wordalign_tablename = table_prefix + "_" + "wordalign"
115 | # tables
116 | WordProbability = Tables().get_wordprobability_table(wordprob_tablename)
117 | WordAlignment = Tables().get_wordalignment_table(wordalign_tablename)
118 |
119 | def get_wordprob(e, f, init_val=1.0e-10):
120 |
121 | query = session.query(WordProbability).filter_by(transto=e,
122 | transfrom=f)
123 | try:
124 | return query.one().prob
125 | except sqlalchemy.orm.exc.NoResultFound:
126 | return init_val
127 |
128 | def get_wordalign(i, j, l_e, l_f, init_val=1.0e-10):
129 |
130 | query = session.query(WordAlignment).filter_by(from_pos=i,
131 | to_pos=j,
132 | to_len=l_e,
133 | from_len=l_f)
134 | try:
135 | return query.one().prob
136 | except sqlalchemy.orm.exc.NoResultFound:
137 | return init_val
138 |
139 | # algorithm
140 | max_a = collections.defaultdict(float)
141 | l_e = len(es)
142 | l_f = len(fs)
143 | for (j, e) in enumerate(es, 1):
144 | current_max = (0, -1)
145 | for (i, f) in enumerate(fs, 1):
146 | val = get_wordprob(e, f, init_val=init_val) *\
147 | get_wordalign(i, j, l_e, l_f, init_val=init_val)
148 | # select the first one among the maximum candidates
149 | if current_max[1] < val:
150 | current_max = (i, val)
151 | max_a[j] = current_max[0]
152 | return max_a
153 |
154 |
155 | def db_show_matrix(es, fs,
156 | transfrom=2,
157 | transto=1,
158 | db="sqlite:///:memory:",
159 | init_val=0.00001):
160 | '''
161 | print matrix according to viterbi alignment like
162 | fs
163 | -------------
164 | e| |
165 | s| |
166 | | |
167 | -------------
168 | >>> sentences = [("僕 は 男 です", "I am a man"),
169 | ("私 は 女 です", "I am a girl"),
170 | ("私 は 先生 です", "I am a teacher"),
171 | ("彼女 は 先生 です", "She is a teacher"),
172 | ("彼 は 先生 です", "He is a teacher"),
173 | ]
174 | >>> t, a = train(sentences, loop_count=1000)
175 | >>> args = ("私 は 先生 です".split(), "I am a teacher".split(), t, a)
176 | |x| | | |
177 | | | |x| |
178 | | | | |x|
179 | | | |x| |
180 | '''
181 | max_a = db_viterbi_alignment(es, fs,
182 | transfrom=transfrom,
183 | transto=transto,
184 | db=db,
185 | init_val=init_val).items()
186 | m = len(es)
187 | n = len(fs)
188 | return utility.matrix(m, n, max_a)
189 |
190 |
191 | def _db_symmetrization(lang1s, lang2s,
192 | init_val=1.0e-10,
193 | db="sqlite:///:memory:"):
194 |     '''symmetrize the word alignments between lang1s and lang2s
195 |     using the trained word/alignment probability tables'''
196 | transfrom = 2
197 | transto = 1
198 | trans = db_viterbi_alignment(lang1s, lang2s,
199 | transfrom=transfrom,
200 | transto=transto,
201 | db=db,
202 | init_val=init_val).items()
203 | rev_trans = db_viterbi_alignment(lang2s, lang1s,
204 | transfrom=transto,
205 | transto=transfrom,
206 | db=db,
207 | init_val=init_val).items()
208 | return word_alignment.alignment(lang1s, lang2s, trans, rev_trans)
209 |
210 |
211 | def db_phrase_extract(lang1, lang2,
212 | lang1method=lambda x: x,
213 | lang2method=lambda x: x,
214 | init_val=1.0e-10,
215 | db="sqlite:///:memory:"):
216 | lang1s = lang1method(lang1).split()
217 |     lang2s = lang2method(lang2).split()
218 | alignment = _db_symmetrization(lang1s, lang2s,
219 | init_val=init_val,
220 | db=db)
221 | return phrase_extract.phrase_extract(lang1s, lang2s, alignment)
222 |
223 |
224 | def create_phrase_db(limit=None,
225 | lang1method=lambda x: x,
226 | lang2method=lambda x: x,
227 | init_val=1.0e-10,
228 | db="sqlite:///:memory:"):
229 | engine = create_engine(db)
230 | # create session
231 | Session = sessionmaker(bind=engine)
232 | session = Session()
233 | # tables
234 | Sentence = Tables().get_sentence_table()
235 | Phrase = Tables().get_phrase_table()
236 |
237 | # create table for word probability
238 | Phrase.__table__.drop(engine, checkfirst=True)
239 | Phrase.__table__.create(engine)
240 | print("created table: phrase")
241 |
242 | query = session.query(Sentence)[:limit] if limit \
243 | else session.query(Sentence)
244 |
245 | with ProgressLine(0.12, title='extracting phrases...'):
246 | for item in query:
247 | lang1 = item.lang1
248 | lang2 = item.lang2
249 | print(" ", lang1, lang2)
250 | phrases = db_phrase_extract(lang1, lang2,
251 | lang1method=lang1method,
252 | lang2method=lang2method,
253 | init_val=init_val,
254 | db=db)
255 | for lang1ps, lang2ps in phrases:
256 | lang1p = u" ".join(lang1ps)
257 | lang2p = u" ".join(lang2ps)
258 | ph = Phrase(lang1p=lang1p, lang2p=lang2p)
259 | session.add(ph)
260 | session.commit()
261 |
262 |
263 | # create views using SQLite3
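    | # phrasecount counts each (lang1p, lang2p) pair; lang1_phrasecount and
    | # lang2_phrasecount hold the marginal counts used as denominators when
    | # estimating the conditional phrase probabilities in create_phrase_prob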
264 | def create_phrase_count_view(db=":memory:"):
265 | # create phrase_count table
266 | table_name = "phrasecount"
267 | con = sqlite3.connect(db)
268 | cur = con.cursor()
269 | try:
270 | cur.execute("drop view {0}".format(table_name))
271 | except sqlite3.Error:
272 |         print("{0} view does not exist.\n\
273 | => creating a new view".format(table_name))
274 | cur.execute("""create view {0}
275 | as select *, count(*) as count from
276 | phrase group by lang1p, lang2p order by count
277 | desc""".format(table_name))
278 | con.commit()
279 |
280 | # create phrase_count_ja table
281 | table_name_ja = "lang1_phrasecount"
282 | con = sqlite3.connect(db)
283 | cur = con.cursor()
284 | try:
285 | cur.execute("drop view {0}".format(table_name_ja))
286 | except sqlite3.Error:
287 |         print("{0} view does not exist.\n\
288 | => creating a new view".format(table_name_ja))
289 | cur.execute("""create view {0}
290 | as select lang1p as langp,
291 | sum(count) as count from phrasecount group by
292 | lang1p order
293 | by count desc""".format(table_name_ja))
294 | con.commit()
295 |
296 | # create phrase_count_en table
297 | table_name_en = "lang2_phrasecount"
298 | con = sqlite3.connect(db)
299 | cur = con.cursor()
300 | try:
301 | cur.execute("drop view {0}".format(table_name_en))
302 | except sqlite3.Error:
303 |         print("{0} view does not exist.\n\
304 | => creating a new view".format(table_name_en))
305 | cur.execute("""create view {0}
306 | as select lang2p as langp,
307 | sum(count) as count from phrasecount group by
308 | lang2p order
309 | by count desc""".format(table_name_en))
310 | con.commit()
311 |
312 |
313 | # using sqlite
314 | def create_phrase_prob(db=":memory:"):
315 |     """
    |     Estimate phrase translation probabilities by relative frequency from
    |     the phrasecount views and store their logarithms in the phraseprob
    |     table.
316 |     """
317 | # create phrase_prob table
318 | table_name = "phraseprob"
319 | engine = create_engine("sqlite:///{0}".format(db))
320 | # create session
321 | Session = sessionmaker(bind=engine)
322 | session = Session()
323 | # tables
324 | TransPhraseProb = Tables().get_transphraseprob_table()
325 |
326 | # create table for word probability
327 | TransPhraseProb.__table__.drop(engine, checkfirst=True)
328 | TransPhraseProb.__table__.create(engine)
329 | session.commit()
330 | print("created table: {0}".format(table_name))
331 |
332 | con = sqlite3.connect(db)
333 | cur = con.cursor()
334 | cur_sel = con.cursor()
336 | cur.execute("select lang1p, lang2p, count from phrasecount")
337 | with ProgressLine(0.12, title='phrase learning...'):
338 | for lang1p, lang2p, count in cur:
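    |             # relative-frequency estimates, stored as log probabilities:
    |             #   p2_1 = count(lang1p, lang2p) / count(lang1p)
    |             #   p1_2 = count(lang1p, lang2p) / count(lang2p)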
339 | # for p2_1
340 | cur_sel.execute(u"""select count
341 | from lang1_phrasecount where
342 | langp=?""",
343 | (lang1p,))
344 |             count2_1 = cur_sel.fetchone()[0]
346 | p2_1 = count / count2_1
347 | # for p1_2
348 | cur_sel.execute(u"""select count
349 | from lang2_phrasecount where
350 | langp=?""",
351 | (lang2p,))
352 |             count1_2 = cur_sel.fetchone()[0]
354 | p1_2 = count / count1_2
355 | # insert item
356 | transphraseprob = TransPhraseProb(lang1p=lang1p,
357 | lang2p=lang2p,
358 | p1_2=math.log(p1_2),
359 | p2_1=math.log(p2_1))
360 | session.add(transphraseprob)
361 | print(u" added phraseprob: {0} <=> {1} ".format(lang1p, lang2p))
362 | session.commit()
363 |
364 |
365 | def createdb(db=":memory:",
366 | lang1method=lambda x: x,
367 | lang2method=lambda x: x,
368 | init_val=1.0e-10,
369 | limit=None,
370 | loop_count=1000,
371 | ):
372 | alchemydb = "sqlite:///{0}".format(db)
373 | create_train_db(transfrom=2,
374 | transto=1,
375 | lang1method=lang1method,
376 | lang2method=lang2method,
377 | db=alchemydb,
378 | limit=limit,
379 | loop_count=loop_count)
380 | create_train_db(transfrom=1,
381 | transto=2,
382 | lang1method=lang1method,
383 | lang2method=lang2method,
384 | db=alchemydb,
385 | limit=limit,
386 | loop_count=loop_count)
387 | create_phrase_db(limit=limit,
388 | lang1method=lang1method,
389 | lang2method=lang2method,
390 | init_val=init_val,
391 | db=alchemydb)
392 | create_phrase_count_view(db=db)
393 | create_phrase_prob(db=db)
394 |
395 | if __name__ == "__main__":
396 | pass
397 |
--------------------------------------------------------------------------------
/test/test_stackdecoder.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | import unittest
5 | from fractions import Fraction as Frac
6 | from smt.decoder.stackdecoder import _future_cost_estimate
7 | from smt.decoder.stackdecoder import _create_estimate_dict
8 | from smt.decoder.stackdecoder import ArgumentNotSatisfied
9 | from smt.decoder.stackdecoder import future_cost_estimate
10 | from smt.decoder.stackdecoder import TransPhraseProb
11 | from smt.decoder.stackdecoder import Phrase
12 | # sqlalchemy
13 | from sqlalchemy import create_engine
14 | from sqlalchemy.orm import sessionmaker
15 |
16 |
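    | # context manager that creates the phrase/phraseprob tables for a test
    | # database and drops them again on exit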
17 | class DBSetup(object):
18 |
19 | def __init__(self, db="sqlite:///:memory:"):
20 | self.db = db
21 | self.tables = [TransPhraseProb, Phrase]
22 |
23 | def __enter__(self):
24 | self.engine = create_engine(self.db)
25 | # create tables
26 | for Table in self.tables:
27 | Table.__table__.create(self.engine)
28 |
29 | # create session
30 | Session = sessionmaker(bind=self.engine)
31 | self.session = Session()
32 |
33 | return self
34 |
35 | def __exit__(self, exc_type, exc_value, traceback):
36 | # drop table
37 | for Table in self.tables:
38 | Table.__table__.drop(self.engine, checkfirst=True)
39 | self.session.close()
40 |
41 |
42 | class FutureCostEstimateTest(unittest.TestCase):
43 |
44 | def setUp(self):
45 | self.maxDiff = None
46 |
47 | def test_future_cost_estimate_2to1(self):
48 | sentences = u"the tourism initiative addresses this\
49 | for the first time".split()
50 | transfrom = 2
51 | transto = 1
52 | init_val = 100.0
53 | db = "sqlite:///test/:test:"
54 |
55 | # data set
56 | dataset = [("1", "the", Frac(-1), 0),
57 | ("1", "the", Frac(-2), 0),
58 | # 2
59 | ("1", "tourism", Frac(-2), 0),
60 | ("1", "tourism", Frac(-3), 0),
61 | # 3
62 | ("1", "initiative", Frac(-15, 10), 0),
63 | ("1", "initiative", Frac(-25, 10), 0),
64 | # 4
65 | ("1", "addresses", Frac(-24, 10), 0),
66 | ("1", "addresses", Frac(-34, 10), 0),
67 | # 5
68 | ("1", "this", Frac(-14, 10), 0),
69 | ("1", "this", Frac(-24, 10), 0),
70 | # 6
71 | ("1", "for", Frac(-1), 0),
72 | ("1", "for", Frac(-2), 0),
73 | # 7
74 | ("1", "the", Frac(-1), 0),
75 | ("1", "the", Frac(-2), 0),
76 | # 8
77 | ("1", "first", Frac(-19, 10), 0),
78 | ("1", "first", Frac(-29, 10), 0),
79 | # 9
80 | ("1", "time", Frac(-16, 10), 0),
81 | ("1", "time", Frac(-26, 10), 0),
82 | # 10
83 | ("1", "initiative addresses", Frac(-4), 0),
84 | ("1", "initiative addresses", Frac(-4), 0),
85 | # 11
86 | ("1", "this for", Frac(-25, 10), 0),
87 | ("1", "this for", Frac(-35, 10), 0),
88 | # 12
89 | ("1", "the first", Frac(-22, 10), 0),
90 | ("1", "the first", Frac(-32, 10), 0),
91 | # 13
92 | ("1", "for the", Frac(-13, 10), 0),
93 | ("1", "for the", Frac(-23, 10), 0),
94 | # 14
95 | ("1", "first time", Frac(-24, 10), 0),
96 | ("1", "first time", Frac(-34, 10), 0),
97 | # 15
98 | ("1", "this for the", Frac(-27, 10), 0),
99 | ("1", "this for the", Frac(-37, 10), 0),
100 | # 16
101 | ("1", "for the first", Frac(-23, 10), 0),
102 | ("1", "for the first", Frac(-33, 10), 0),
103 | # 17
104 | ("1", "the first time", Frac(-23, 10), 0),
105 | ("1", "the first time", Frac(-33, 10), 0),
106 | # 18
107 | ("1", "for the first time", Frac(-23, 10), 0),
108 | ("1", "for the first time", Frac(-33, 10), 0),
109 | ]
110 | val = {(1, 1): -1.0,
111 | (1, 2): -3.0,
112 | (1, 3): -4.5,
113 | (1, 4): -6.9,
114 | (1, 5): -8.3,
115 | (1, 6): -9.3,
116 | (1, 7): -9.6,
117 | (1, 8): -10.6,
118 | (1, 9): -10.6,
119 | (2, 2): -2.0,
120 | (2, 3): -3.5,
121 | (2, 4): -5.9,
122 | (2, 5): -7.3,
123 | (2, 6): -8.3,
124 | (2, 7): -8.6,
125 | (2, 8): -9.6,
126 | (2, 9): -9.6,
127 | (3, 3): -1.5,
128 | (3, 4): -3.9,
129 | (3, 5): -5.3,
130 | (3, 6): -6.3,
131 | (3, 7): -6.6,
132 | (3, 8): -7.6,
133 | (3, 9): -7.6,
134 | (4, 4): -2.4,
135 | (4, 5): -3.8,
136 | (4, 6): -4.8,
137 | (4, 7): -5.1,
138 | (4, 8): -6.1,
139 | (4, 9): -6.1,
140 | (5, 5): -1.4,
141 | (5, 6): -2.4,
142 | (5, 7): -2.7,
143 | (5, 8): -3.6999999999999997,
144 | (5, 9): -3.6999999999999997,
145 | (6, 6): -1.0,
146 | (6, 7): -1.3,
147 | (6, 8): -2.3,
148 | (6, 9): -2.3,
149 | (7, 7): -1.0,
150 | (7, 8): -2.2,
151 | (7, 9): -2.3,
152 | (8, 8): -1.9,
153 | (8, 9): -2.4,
154 | (9, 9): -1.6,
155 | }
156 |
157 | with DBSetup(db) as dbobj:
158 | dbobj.session.add_all(TransPhraseProb(lang1p=item[0],
159 | lang2p=item[1],
160 | p2_1=item[2],
161 | p1_2=item[3])
162 | for item in dataset)
163 | dbobj.session.add_all(Phrase(lang1p=item[0],
164 | lang2p=item[1])
165 | for item in dataset)
166 | dbobj.session.commit()
167 |
168 | ans = future_cost_estimate(sentences,
169 | transfrom=transfrom,
170 | transto=transto,
171 | init_val=init_val,
172 | db=db)
173 | # assert
174 | self.assertEqual(ans, val)
175 |
176 | def test_future_cost_estimate_2to1_argument_not_satisfied(self):
177 | sentences = u"the tourism initiative addresses this\
178 | for the first time".split()
179 | transfrom = 2
180 | transto = 1
181 | init_val = 100.0
182 | db = "sqlite:///test/:test:"
183 |
184 | # data set
185 | dataset = [("1", "the", Frac(-1), 0),
186 | ("1", "the", Frac(-2), 0),
187 | # 2
188 | ("1", "tourism", Frac(-2), 0),
189 | ("1", "tourism", Frac(-3), 0),
190 | # 3
191 | ("1", "initiative", Frac(-15, 10), 0),
192 | ("1", "initiative", Frac(-25, 10), 0),
193 | # 4
194 | ("1", "addresses", Frac(-24, 10), 0),
195 | ("1", "addresses", Frac(-34, 10), 0),
196 | # 5
197 | #("1", "this", Frac(-14, 10), 0),
198 | #("1", "this", Frac(-24, 10), 0),
199 | # 6
200 | ("1", "for", Frac(-1), 0),
201 | ("1", "for", Frac(-2), 0),
202 | # 7
203 | ("1", "the", Frac(-1), 0),
204 | ("1", "the", Frac(-2), 0),
205 | # 8
206 | ("1", "first", Frac(-19, 10), 0),
207 | ("1", "first", Frac(-29, 10), 0),
208 | # 9
209 | ("1", "time", Frac(-16, 10), 0),
210 | ("1", "time", Frac(-26, 10), 0),
211 | # 10
212 | ("1", "initiative addresses", Frac(-4), 0),
213 | ("1", "initiative addresses", Frac(-4), 0),
214 | # 11
215 | ("1", "this for", Frac(-25, 10), 0),
216 | ("1", "this for", Frac(-35, 10), 0),
217 | # 12
218 | ("1", "the first", Frac(-22, 10), 0),
219 | ("1", "the first", Frac(-32, 10), 0),
220 | # 13
221 | ("1", "for the", Frac(-13, 10), 0),
222 | ("1", "for the", Frac(-23, 10), 0),
223 | # 14
224 | ("1", "first time", Frac(-24, 10), 0),
225 | ("1", "first time", Frac(-34, 10), 0),
226 | # 15
227 | ("1", "this for the", Frac(-27, 10), 0),
228 | ("1", "this for the", Frac(-37, 10), 0),
229 | # 16
230 | ("1", "for the first", Frac(-23, 10), 0),
231 | ("1", "for the first", Frac(-33, 10), 0),
232 | # 17
233 | ("1", "the first time", Frac(-23, 10), 0),
234 | ("1", "the first time", Frac(-33, 10), 0),
235 | # 18
236 | ("1", "for the first time", Frac(-23, 10), 0),
237 | ("1", "for the first time", Frac(-33, 10), 0),
238 | ]
239 |
240 | val = {(1, 1): -1.0,
241 | (1, 2): -3.0,
242 | (1, 3): -4.5,
243 | (1, 4): -6.9,
244 | (1, 5): -106.9,
245 | (1, 6): -9.4,
246 | (1, 7): -9.6,
247 | (1, 8): -11.5,
248 | (1, 9): -11.7,
249 | (2, 2): -2.0,
250 | (2, 3): -3.5,
251 | (2, 4): -5.9,
252 | (2, 5): -105.9,
253 | (2, 6): -8.4,
254 | (2, 7): -8.6,
255 | (2, 8): -10.5,
256 | (2, 9): -10.7,
257 | (3, 3): -1.5,
258 | (3, 4): -3.9,
259 | (3, 5): -103.9,
260 | (3, 6): -6.4,
261 | (3, 7): -6.6,
262 | (3, 8): -8.5,
263 | (3, 9): -8.7,
264 | (4, 4): -2.4,
265 | (4, 5): -102.4,
266 | (4, 6): -4.9,
267 | (4, 7): -5.1,
268 | (4, 8): -7.0,
269 | (4, 9): -7.199999999999999,
270 | (5, 5): -100.0,
271 | (5, 6): -2.5,
272 | (5, 7): -2.7,
273 | (5, 8): -4.6,
274 | (5, 9): -4.8,
275 | (6, 6): -1.0,
276 | (6, 7): -1.3,
277 | (6, 8): -2.3,
278 | (6, 9): -2.3,
279 | (7, 7): -1.0,
280 | (7, 8): -2.2,
281 | (7, 9): -2.3,
282 | (8, 8): -1.9,
283 | (8, 9): -2.4,
284 | (9, 9): -1.6,
285 | }
286 |
287 | with DBSetup(db) as dbobj:
288 | dbobj.session.add_all(TransPhraseProb(lang1p=item[0],
289 | lang2p=item[1],
290 | p2_1=item[2],
291 | p1_2=item[3])
292 | for item in dataset)
293 | dbobj.session.add_all(Phrase(lang1p=item[0],
294 | lang2p=item[1])
295 | for item in dataset)
296 | dbobj.session.commit()
297 |
298 | ans = future_cost_estimate(sentences,
299 | transfrom=transfrom,
300 | transto=transto,
301 | init_val=init_val,
302 | db=db)
303 |
304 | # assert
305 | self.assertEqual(ans, val)
306 |
307 | def test_future_cost_estimate_1to2(self):
308 | sentences = u"the tourism initiative addresses this\
309 | for the first time".split()
310 | transfrom = 1
311 | transto = 2
312 | init_val = 100.0
313 | db = "sqlite:///test/:test:"
314 |
315 | # data set
316 | dataset = [("1", "the", Frac(-1), 0),
317 | ("1", "the", Frac(-2), 0),
318 | # 2
319 | ("1", "tourism", Frac(-2), 0),
320 | ("1", "tourism", Frac(-3), 0),
321 | # 3
322 | ("1", "initiative", Frac(-15, 10), 0),
323 | ("1", "initiative", Frac(-25, 10), 0),
324 | # 4
325 | ("1", "addresses", Frac(-24, 10), 0),
326 | ("1", "addresses", Frac(-34, 10), 0),
327 | # 5
328 | ("1", "this", Frac(-14, 10), 0),
329 | ("1", "this", Frac(-24, 10), 0),
330 | # 6
331 | ("1", "for", Frac(-1), 0),
332 | ("1", "for", Frac(-2), 0),
333 | # 7
334 | ("1", "the", Frac(-1), 0),
335 | ("1", "the", Frac(-2), 0),
336 | # 8
337 | ("1", "first", Frac(-19, 10), 0),
338 | ("1", "first", Frac(-29, 10), 0),
339 | # 9
340 | ("1", "time", Frac(-16, 10), 0),
341 | ("1", "time", Frac(-26, 10), 0),
342 | # 10
343 | ("1", "initiative addresses", Frac(-4), 0),
344 | ("1", "initiative addresses", Frac(-4), 0),
345 | # 11
346 | ("1", "this for", Frac(-25, 10), 0),
347 | ("1", "this for", Frac(-35, 10), 0),
348 | # 12
349 | ("1", "the first", Frac(-22, 10), 0),
350 | ("1", "the first", Frac(-32, 10), 0),
351 | # 13
352 | ("1", "for the", Frac(-13, 10), 0),
353 | ("1", "for the", Frac(-23, 10), 0),
354 | # 14
355 | ("1", "first time", Frac(-24, 10), 0),
356 | ("1", "first time", Frac(-34, 10), 0),
357 | # 15
358 | ("1", "this for the", Frac(-27, 10), 0),
359 | ("1", "this for the", Frac(-37, 10), 0),
360 | # 16
361 | ("1", "for the first", Frac(-23, 10), 0),
362 | ("1", "for the first", Frac(-33, 10), 0),
363 | # 17
364 | ("1", "the first time", Frac(-23, 10), 0),
365 | ("1", "the first time", Frac(-33, 10), 0),
366 | # 18
367 | ("1", "for the first time", Frac(-23, 10), 0),
368 | ("1", "for the first time", Frac(-33, 10), 0),
369 | ]
370 |
371 | val = {(1, 1): -1.0,
372 | (1, 2): -3.0,
373 | (1, 3): -4.5,
374 | (1, 4): -6.9,
375 | (1, 5): -8.3,
376 | (1, 6): -9.3,
377 | (1, 7): -9.6,
378 | (1, 8): -10.6,
379 | (1, 9): -10.6,
380 | (2, 2): -2.0,
381 | (2, 3): -3.5,
382 | (2, 4): -5.9,
383 | (2, 5): -7.3,
384 | (2, 6): -8.3,
385 | (2, 7): -8.6,
386 | (2, 8): -9.6,
387 | (2, 9): -9.6,
388 | (3, 3): -1.5,
389 | (3, 4): -3.9,
390 | (3, 5): -5.3,
391 | (3, 6): -6.3,
392 | (3, 7): -6.6,
393 | (3, 8): -7.6,
394 | (3, 9): -7.6,
395 | (4, 4): -2.4,
396 | (4, 5): -3.8,
397 | (4, 6): -4.8,
398 | (4, 7): -5.1,
399 | (4, 8): -6.1,
400 | (4, 9): -6.1,
401 | (5, 5): -1.4,
402 | (5, 6): -2.4,
403 | (5, 7): -2.7,
404 | (5, 8): -3.6999999999999997,
405 | (5, 9): -3.6999999999999997,
406 | (6, 6): -1.0,
407 | (6, 7): -1.3,
408 | (6, 8): -2.3,
409 | (6, 9): -2.3,
410 | (7, 7): -1.0,
411 | (7, 8): -2.2,
412 | (7, 9): -2.3,
413 | (8, 8): -1.9,
414 | (8, 9): -2.4,
415 | (9, 9): -1.6,
416 | }
417 |
418 | with DBSetup(db) as dbobj:
419 | dbobj.session.add_all(TransPhraseProb(lang2p=item[0],
420 | lang1p=item[1],
421 | p1_2=item[2],
422 | p2_1=item[3])
423 | for item in dataset)
424 | dbobj.session.add_all(Phrase(lang1p=item[0],
425 | lang2p=item[1])
426 | for item in dataset)
427 | dbobj.session.commit()
428 |
429 | ans = future_cost_estimate(sentences,
430 | transfrom=transfrom,
431 | transto=transto,
432 | init_val=init_val,
433 | db=db)
434 | # assert
435 | self.assertEqual(ans, val)
436 |
437 | def test__future_cost_estimate(self):
438 | sentences = u"the tourism initiative addresses this\
439 | for the first time".split()
440 | phrase_prob = {(1, 1): Frac(-1),
441 | (2, 2): Frac(-2),
442 | (3, 3): Frac(-15, 10),
443 | (4, 4): Frac(-24, 10),
444 | (5, 5): Frac(-14, 10),
445 | (6, 6): Frac(-1),
446 | (7, 7): Frac(-1),
447 | (8, 8): Frac(-19, 10),
448 | (9, 9): Frac(-16, 10),
449 | (3, 4): Frac(-4),
450 | (5, 6): Frac(-25, 10),
451 | (7, 8): Frac(-22, 10),
452 | (6, 7): Frac(-13, 10),
453 | (8, 9): Frac(-24, 10),
454 | (5, 7): Frac(-27, 10),
455 | (6, 8): Frac(-23, 10),
456 | (7, 9): Frac(-23, 10),
457 | (6, 9): Frac(-23, 10),
458 | }
459 | val = {(1, 1): Frac(-1),
460 | (1, 2): Frac(-3),
461 | (1, 3): Frac(-45, 10),
462 | (1, 4): Frac(-69, 10),
463 | (1, 5): Frac(-83, 10),
464 | (1, 6): Frac(-93, 10),
465 | (1, 7): Frac(-96, 10),
466 | (1, 8): Frac(-106, 10),
467 | (1, 9): Frac(-106, 10),
468 | (2, 2): Frac(-2),
469 | (2, 3): Frac(-35, 10),
470 | (2, 4): Frac(-59, 10),
471 | (2, 5): Frac(-73, 10),
472 | (2, 6): Frac(-83, 10),
473 | (2, 7): Frac(-86, 10),
474 | (2, 8): Frac(-96, 10),
475 | (2, 9): Frac(-96, 10),
476 | (3, 3): Frac(-15, 10),
477 | (3, 4): Frac(-39, 10),
478 | (3, 5): Frac(-53, 10),
479 | (3, 6): Frac(-63, 10),
480 | (3, 7): Frac(-66, 10),
481 | (3, 8): Frac(-76, 10),
482 | (3, 9): Frac(-76, 10),
483 | (4, 4): Frac(-24, 10),
484 | (4, 5): Frac(-38, 10),
485 | (4, 6): Frac(-48, 10),
486 | (4, 7): Frac(-51, 10),
487 | (4, 8): Frac(-61, 10),
488 | (4, 9): Frac(-61, 10),
489 | (5, 5): Frac(-14, 10),
490 | (5, 6): Frac(-24, 10),
491 | (5, 7): Frac(-27, 10),
492 | (5, 8): Frac(-37, 10),
493 | (5, 9): Frac(-37, 10),
494 | (6, 6): Frac(-1),
495 | (6, 7): Frac(-13, 10),
496 | (6, 8): Frac(-23, 10),
497 | (6, 9): Frac(-23, 10),
498 | (7, 7): Frac(-1),
499 | (7, 8): Frac(-22, 10),
500 | (7, 9): Frac(-23, 10),
501 | (8, 8): Frac(-19, 10),
502 | (8, 9): Frac(-24, 10),
503 | (9, 9): Frac(-16, 10)}
504 | ans = _future_cost_estimate(sentences,
505 | phrase_prob)
506 | self.assertEqual(ans, val)
507 |
508 | def test__future_cost_estimate_dict_not_satisfied(self):
509 | sentences = u"the tourism initiative addresses this\
510 | for the first time".split()
511 | phrase_prob = {(1, 1): Frac(-1),
512 | (2, 2): Frac(-2),
513 | # lack one value
514 | #(3, 3): Frac(-15, 10),
515 | (4, 4): Frac(-24, 10),
516 | (5, 5): Frac(-14, 10),
517 | (6, 6): Frac(-1),
518 | (7, 7): Frac(-1),
519 | (8, 8): Frac(-19, 10),
520 | (9, 9): Frac(-16, 10),
521 | (3, 4): Frac(-4),
522 | (5, 6): Frac(-25, 10),
523 | (7, 8): Frac(-22, 10),
524 | (6, 7): Frac(-13, 10),
525 | (8, 9): Frac(-24, 10),
526 | (5, 7): Frac(-27, 10),
527 | (6, 8): Frac(-23, 10),
528 | (7, 9): Frac(-23, 10),
529 | (6, 9): Frac(-23, 10),
530 | }
531 | self.assertRaises(ArgumentNotSatisfied,
532 | _future_cost_estimate,
533 | sentences,
534 | phrase_prob)
535 |
536 | def test_create_estimate_dict(self):
537 | sentences = u"the tourism initiative addresses this\
538 | for the first time".split()
539 | init_val = Frac(-100)
540 | phrase_prob = {(1, 1): Frac(-1),
541 | (2, 2): Frac(-2),
542 | # lack one value
543 | #(3, 3): Frac(-15, 10),
544 | (4, 4): Frac(-24, 10),
545 | (5, 5): Frac(-14, 10),
546 | #(6, 6): Frac(-1),
547 | (7, 7): Frac(-1),
548 | # lack one value
549 | #(8, 8): Frac(-19, 10),
550 | (9, 9): Frac(-16, 10),
551 | (3, 4): Frac(-4),
552 | (5, 6): Frac(-25, 10),
553 | (7, 8): Frac(-22, 10),
554 | (6, 7): Frac(-13, 10),
555 | (8, 9): Frac(-24, 10),
556 | (5, 7): Frac(-27, 10),
557 | (6, 8): Frac(-23, 10),
558 | (7, 9): Frac(-23, 10),
559 | (6, 9): Frac(-23, 10),
560 | }
561 | correct = {(1, 1): Frac(-1),
562 | (2, 2): Frac(-2),
563 | # lack one value
564 | (3, 3): init_val,
565 | (4, 4): Frac(-24, 10),
566 | (5, 5): Frac(-14, 10),
567 | (6, 6): init_val,
568 | (7, 7): Frac(-1),
569 | # lack one value
570 | (8, 8): init_val,
571 | (9, 9): Frac(-16, 10),
572 | (3, 4): Frac(-4),
573 | (5, 6): Frac(-25, 10),
574 | (7, 8): Frac(-22, 10),
575 | (6, 7): Frac(-13, 10),
576 | (8, 9): Frac(-24, 10),
577 | (5, 7): Frac(-27, 10),
578 | (6, 8): Frac(-23, 10),
579 | (7, 9): Frac(-23, 10),
580 | (6, 9): Frac(-23, 10),
581 | }
582 | ans = _create_estimate_dict(sentences,
583 | phrase_prob,
584 | init_val=init_val)
585 | self.assertEqual(ans, correct)
586 |
587 |
588 | if __name__ == '__main__':
589 | unittest.main()
590 |
--------------------------------------------------------------------------------
/smt/decoder/stackdecoder.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # coding:utf-8
3 |
4 | from __future__ import division, print_function
5 | import math
6 | # sqlalchemy
7 | import sqlalchemy
8 | from sqlalchemy.ext.declarative import declarative_base
9 | from sqlalchemy import create_engine
10 | from sqlalchemy import Column, TEXT, REAL, INTEGER
11 | from sqlalchemy.orm import sessionmaker
12 | from smt.db.tables import Tables
13 | #from pprint import pprint
14 |
15 |
16 | # prepare classes for sqlalchemy
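    | # lightweight ORM mappings onto the 'phrase' and 'phraseprob' tables
    | # built by the createdb step; p1_2 and p2_1 hold log probabilities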
17 | class Phrase(declarative_base()):
18 | __tablename__ = "phrase"
19 | id = Column(INTEGER, primary_key=True)
20 | lang1p = Column(TEXT)
21 | lang2p = Column(TEXT)
22 |
23 |
24 | class TransPhraseProb(declarative_base()):
25 | __tablename__ = "phraseprob"
26 | id = Column(INTEGER, primary_key=True)
27 | lang1p = Column(TEXT)
28 | lang2p = Column(TEXT)
29 | p1_2 = Column(REAL)
30 | p2_1 = Column(REAL)
31 |
32 |
33 | def phrase_prob(lang1p, lang2p,
34 | transfrom=2,
35 | transto=1,
36 | db="sqlite:///:memory:",
37 | init_val=1.0e-10):
38 |     """
   |     Return the log translation probability of the phrase pair
   |     (lang1p, lang2p) from the phraseprob table, or init_val if the
   |     pair is unseen.
39 |     """
40 | engine = create_engine(db)
41 | Session = sessionmaker(bind=engine)
42 | session = Session()
43 | # search
44 | query = session.query(TransPhraseProb).filter_by(lang1p=lang1p,
45 | lang2p=lang2p)
46 | if transfrom == 2 and transto == 1:
47 | try:
48 |             # note: translating lang2 -> lang1 scores with
49 |             # p2_1 = P(lang2p|lang1p), the noisy-channel direction
   |             # given by Bayes' rule
50 | return query.one().p2_1
51 | except sqlalchemy.orm.exc.NoResultFound:
52 | return init_val
53 |     elif transfrom == 1 and transto == 2:
54 |         try:
55 |             return query.one().p1_2
56 |         except sqlalchemy.orm.exc.NoResultFound:
57 |             return init_val
   |     else:
   |         raise ValueError(
   |             "transfrom and transto must be (2, 1) or (1, 2)")
58 |
59 |
60 | def available_phrases(inputs, transfrom=2, transto=1, db="sqlite:///:memory:"):
61 | """
62 |     >>> available_phrases(u"He is a teacher.".split(),
63 |     ...                   db="sqlite:///:db:")
64 | set([((1, u'He'),),
65 | ((1, u'He'), (2, u'is')),
66 | ((2, u'is'),),
67 | ((2, u'is'), (3, u'a')),
68 | ((3, u'a'),),
69 | ((4, u'teacher.'),)])
70 | """
71 | engine = create_engine(db)
72 | # create session
73 | Session = sessionmaker(bind=engine)
74 | session = Session()
75 | available = set()
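    |     # enumerate every contiguous subphrase of the input (with 1-based
    |     # positions) and keep those that appear in the phrase table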
76 | for i, f in enumerate(inputs):
77 | f_rest = ()
78 | for fr in inputs[i:]:
79 | f_rest += (fr,)
80 | rest_phrase = u" ".join(f_rest)
81 | if transfrom == 2 and transto == 1:
82 | query = session.query(Phrase).filter_by(lang2p=rest_phrase)
83 | elif transfrom == 1 and transto == 2:
84 | query = session.query(Phrase).filter_by(lang1p=rest_phrase)
85 | lst = list(query)
86 | if lst:
87 | available.add(tuple(enumerate(f_rest, i+1)))
88 | return available
89 |
90 |
91 | class HypothesisBase(object):
92 | def __init__(self,
93 | db,
94 | totalnumber,
95 | sentences,
96 | ngram,
97 | ngram_words,
98 | inputps_with_index,
99 | outputps,
100 | transfrom,
101 | transto,
102 | covered,
103 | remained,
104 | start,
105 | end,
106 | prev_start,
107 | prev_end,
108 | remain_phrases,
109 | prob,
110 | prob_with_cost,
111 | prev_hypo,
112 | cost_dict
113 | ):
114 |
115 | self._db = db
116 | self._totalnumber = totalnumber
117 | self._sentences = sentences
118 | self._ngram = ngram
119 | self._ngram_words = ngram_words
120 | self._inputps_with_index = inputps_with_index
121 | self._outputps = outputps
122 | self._transfrom = transfrom
123 | self._transto = transto
124 | self._covered = covered
125 | self._remained = remained
126 | self._start = start
127 | self._end = end
128 | self._prev_start = prev_start
129 | self._prev_end = prev_end
130 | self._remain_phrases = remain_phrases
131 | self._prob = prob
132 | self._prob_with_cost = prob_with_cost
133 | self._prev_hypo = prev_hypo
134 | self._cost_dict = cost_dict
135 |
136 | self._output_sentences = outputps
137 |
138 | @property
139 | def db(self):
140 | return self._db
141 |
142 | @property
143 | def totalnumber(self):
144 | return self._totalnumber
145 |
146 | @property
147 | def sentences(self):
148 | return self._sentences
149 |
150 | @property
151 | def ngram(self):
152 | return self._ngram
153 |
154 | @property
155 | def ngram_words(self):
156 | return self._ngram_words
157 |
158 | @property
159 | def inputps_with_index(self):
160 | return self._inputps_with_index
161 |
162 | @property
163 | def outputps(self):
164 | return self._outputps
165 |
166 | @property
167 | def transfrom(self):
168 | return self._transfrom
169 |
170 | @property
171 | def transto(self):
172 | return self._transto
173 |
174 | @property
175 | def covered(self):
176 | return self._covered
177 |
178 | @property
179 | def remained(self):
180 | return self._remained
181 |
182 | @property
183 | def start(self):
184 | return self._start
185 |
186 | @property
187 | def end(self):
188 | return self._end
189 |
190 | @property
191 | def prev_start(self):
192 | return self._prev_start
193 |
194 | @property
195 | def prev_end(self):
196 | return self._prev_end
197 |
198 | @property
199 | def remain_phrases(self):
200 | return self._remain_phrases
201 |
202 | @property
203 | def prob(self):
204 | return self._prob
205 |
206 | @property
207 | def prob_with_cost(self):
208 | return self._prob_with_cost
209 |
210 | @property
211 | def prev_hypo(self):
212 | return self._prev_hypo
213 |
214 | @property
215 | def cost_dict(self):
216 | return self._cost_dict
217 |
218 | @property
219 | def output_sentences(self):
220 | return self._output_sentences
221 |
222 | def __unicode__(self):
223 | d = [("db", self._db),
224 | ("sentences", self._sentences),
225 | ("inputps_with_index", self._inputps_with_index),
226 | ("outputps", self._outputps),
227 | ("ngram", self._ngram),
228 | ("ngram_words", self._ngram_words),
229 | ("transfrom", self._transfrom),
230 | ("transto", self._transto),
231 | ("covered", self._covered),
232 | ("remained", self._remained),
233 | ("start", self._start),
234 | ("end", self._end),
235 | ("prev_start", self._prev_start),
236 | ("prev_end", self._prev_end),
237 | ("remain_phrases", self._remain_phrases),
238 | ("prob", self._prob),
239 | ("prob_with_cost", self._prob_with_cost),
240 | #("cost_dict", self._cost_dict),
241 | #("prev_hypo", ""),
242 | ]
243 | return u"Hypothesis Object\n" +\
244 | u"\n".join([u" " + k + u": " +
245 | unicode(v) for (k, v) in d])
246 |
247 | def __str__(self):
248 | return unicode(self).encode('utf-8')
249 |
250 | def __hash__(self):
251 | return hash(unicode(self))
252 |
253 |
254 | class Hypothesis(HypothesisBase):
255 | """
256 |     A hypothesis is built from its predecessor roughly as follows:
257 |
258 | >>> args = {"sentences": sentences,
259 | ... "inputps_with_index": phrase,
260 | ... "outputps": outputps,
261 | ... "covered": hyp0.covered.union(set(phrase)),
262 | ... "remained": hyp0.remained.difference(set(phrase)),
263 | ... "start": phrase[0][0],
264 | ... "end": phrase[-1][0],
265 | ... "prev_start": hyp0.start,
266 | ... "prev_end": hyp0.end,
267 | ... "remain_phrases": remain_phrases(phrase,
268 | ... hyp0.remain_phrases),
269 | ... "prev_hypo": hyp0
270 | ... }
271 |
272 | >>> hyp1 = decode.HypothesisBase(**args)
273 | """
274 |
275 | def __init__(self,
276 | prev_hypo,
277 | inputps_with_index,
278 | outputps,
279 | ):
280 |
281 | start = inputps_with_index[0][0]
282 | end = inputps_with_index[-1][0]
283 | prev_start = prev_hypo.start
284 | prev_end = prev_hypo.end
285 | args = {"db": prev_hypo.db,
286 | "totalnumber": prev_hypo.totalnumber,
287 | "prev_hypo": prev_hypo,
288 | "sentences": prev_hypo.sentences,
289 | "ngram": prev_hypo.ngram,
290 | # set later
291 | "ngram_words": prev_hypo.ngram_words,
292 | "inputps_with_index": inputps_with_index,
293 | "outputps": outputps,
294 | "transfrom": prev_hypo.transfrom,
295 | "transto": prev_hypo.transto,
296 | "covered": prev_hypo.covered.union(set(inputps_with_index)),
297 | "remained": prev_hypo.remained.difference(
298 | set(inputps_with_index)),
299 | "start": start,
300 | "end": end,
301 | "prev_start": prev_start,
302 | "prev_end": prev_end,
303 | "remain_phrases": self._calc_remain_phrases(
304 | inputps_with_index,
305 | prev_hypo.remain_phrases),
306 | "cost_dict": prev_hypo.cost_dict,
307 | # set later
308 | "prob": 0,
309 | "prob_with_cost": 0,
310 | }
311 | HypothesisBase.__init__(self, **args)
312 | # set ngram words
313 | self._ngram_words = self._set_ngram_words()
314 | # set the exact probability
315 | self._prob = self._cal_prob(start - prev_end)
316 | # set the exact probability with cost
317 | self._prob_with_cost = self._cal_prob_with_cost(start - prev_end)
318 | # set the output phrases
319 | self._output_sentences = prev_hypo.output_sentences + outputps
320 |
321 | def _set_ngram_words(self):
322 | lst = self._prev_hypo.ngram_words + list(self._outputps)
323 | o_len = len(self._outputps)
324 |         # keep only the last (o_len - 1 + ngram) words as the new context
    |         return lst[-(o_len - 1 + self._ngram):]
325 |
326 | def _cal_phrase_prob(self):
327 | inputp = u" ".join(zip(*self._inputps_with_index)[1])
328 | outputp = u" ".join(self._outputps)
329 |
330 | if self._transfrom == 2 and self._transto == 1:
331 | return phrase_prob(lang1p=outputp,
332 | lang2p=inputp,
333 | transfrom=self._transfrom,
334 | transto=self._transto,
335 | db=self._db,
336 | init_val=-100)
337 | elif self._transfrom == 1 and self._transto == 2:
338 | return phrase_prob(lang1p=inputp,
339 | lang2p=outputp,
340 | transfrom=self._transfrom,
341 | transto=self._transto,
342 | db=self._db,
343 | init_val=-100)
344 | else:
345 | raise Exception("specify transfrom and transto")
346 |
347 | def _cal_language_prob(self):
348 | nw = self.ngram_words
349 | triwords = zip(nw, nw[1:], nw[2:])
350 | prob = 0
351 | for first, second, third in triwords:
352 | prob += language_model(first, second, third, self._totalnumber,
353 | transto=self._transto,
354 | db=self._db)
355 | return prob
356 |
357 | def _cal_prob(self, dist):
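    |         # log-linear score: predecessor score plus reordering penalty,
    |         # phrase translation log prob and trigram LM log prob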
358 | val = self._prev_hypo.prob +\
359 | self._reordering_model(0.1, dist) +\
360 | self._cal_phrase_prob() +\
361 | self._cal_language_prob()
362 | return val
363 |
364 | def _sub_cal_prob_with_cost(self, s_len, cvd):
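    |         # collect the maximal runs of still-uncovered positions,
    |         # e.g. s_len=5, cvd={2, 3} -> [[1], [4, 5]]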
365 | insert_flag = False
366 | lst = []
367 | sub_lst = []
368 | for i in range(1, s_len+1):
369 | if i not in cvd:
370 | insert_flag = True
371 | else:
372 | insert_flag = False
373 | if sub_lst:
374 | lst.append(sub_lst)
375 | sub_lst = []
376 | if insert_flag:
377 | sub_lst.append(i)
378 | else:
379 | if sub_lst:
380 | lst.append(sub_lst)
381 | return lst
382 |
383 | def _cal_prob_with_cost(self, dist):
384 | s_len = len(self._sentences)
385 | cvd = set(i for i, val in self._covered)
386 | lst = self._sub_cal_prob_with_cost(s_len, cvd)
387 | prob = self._cal_prob(dist)
388 | prob_with_cost = prob
389 | for item in lst:
390 | start = item[0]
391 | end = item[-1]
392 | cost = self._cost_dict[(start, end)]
393 | prob_with_cost += cost
394 | return prob_with_cost
395 |
396 | def _reordering_model(self, alpha, dist):
397 |         # distance-based reordering penalty: |dist| * log(alpha)
    |         return math.fabs(dist) * math.log(alpha)
398 |
399 | def _calc_remain_phrases(self, phrase, phrases):
400 | """
401 | >>> res = remain_phrases(((2, u'is'),),
402 | set([((1, u'he'),),
403 | ((2, u'is'),),
404 | ((3, u'a'),),
405 | ((2, u'is'),
406 | (3, u'a')),
407 | ((4, u'teacher'),)]))
408 | set([((1, u'he'),), ((3, u'a'),), ((4, u'teacher'),)])
409 | >>> res = remain_phrases(((2, u'is'), (3, u'a')),
410 | set([((1, u'he'),),
411 | ((2, u'is'),),
412 | ((3, u'a'),),
413 | ((2, u'is'),
414 | (3, u'a')),
415 | ((4, u'teacher'),)]))
416 | set([((1, u'he'),), ((4, u'teacher'),)])
417 | """
418 | s = set()
419 | for ph in phrases:
420 | for p in phrase:
421 | if p in ph:
422 | break
423 | else:
424 | s.add(ph)
425 | return s
426 |
427 |
428 | def create_empty_hypothesis(sentences, cost_dict,
429 | ngram=3, transfrom=2, transto=1,
430 | db="sqlite:///:memory:"):
431 |     phrases = available_phrases(sentences,
    |                                 transfrom=transfrom,
    |                                 transto=transto,
432 |                                 db=db)
433 | hyp0 = HypothesisBase(sentences=sentences,
434 | db=db,
435 | totalnumber=_get_total_number(transto=transto,
436 | db=db),
437 | inputps_with_index=(),
438 | outputps=[],
439 | ngram=ngram,
440 | ngram_words=["", ""]*ngram,
441 | transfrom=transfrom,
442 | transto=transto,
443 | covered=set(),
444 | start=0,
445 | end=0,
446 | prev_start=0,
447 | prev_end=0,
448 | remained=set(enumerate(sentences, 1)),
449 | remain_phrases=phrases,
450 | prev_hypo=None,
451 | prob=0,
452 | cost_dict=cost_dict,
453 | prob_with_cost=0)
454 | #print(_get_total_number(transto=transto, db=db))
455 | return hyp0
456 |
457 |
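    | # a stack (beam) of hypotheses covering the same number of input words,
    | # with optional histogram pruning (cap the number of hypotheses) and
    | # threshold pruning (drop hypotheses far below the current best)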
458 | class Stack(set):
459 | def __init__(self, size=10,
460 | histogram_pruning=True,
461 | threshold_pruning=False):
462 | set.__init__(self)
463 | self._min_hyp = None
464 | self._max_hyp = None
465 | self._size = size
466 | self._histogram_pruning = histogram_pruning
467 | self._threshold_pruning = threshold_pruning
468 |
469 | def add_hyp(self, hyp):
470 | #prob = hyp.prob
471 | # for the first time
472 |         if not self:
473 | self._min_hyp = hyp
474 | self._max_hyp = hyp
475 | else:
476 | raise Exception("Don't use add_hyp for nonempty stack")
477 | #else:
478 | # if self._min_hyp.prob > prob:
479 | # self._min_hyp = hyp
480 | # if self._max_hyp.prob < prob:
481 | # self._max_hyp = hyp
482 | self.add(hyp)
483 |
484 |     def _get_min_hyp(self):
485 |         # linear scan for the hypothesis with the lowest prob_with_cost
486 |         return min(self, key=lambda h: h.prob_with_cost)
492 |
493 | def add_with_combine_prune(self, hyp):
494 | prob_with_cost = hyp.prob_with_cost
495 |         if not self:
496 | self._min_hyp = hyp
497 | self._max_hyp = hyp
498 | else:
499 | if self._min_hyp.prob_with_cost > prob_with_cost:
500 | self._min_hyp = hyp
501 | if self._max_hyp.prob_with_cost < prob_with_cost:
502 | self._max_hyp = hyp
503 | self.add(hyp)
504 |         # recombine: if an existing hypothesis shares the same ngram
    |         # context and end position, keep only the better-scoring one
505 |         for _hyp in self:
506 |             if hyp.ngram_words[:-1] == _hyp.ngram_words[:-1] and \
507 |                     hyp.end == _hyp.end and hyp is not _hyp:
508 |                 if hyp.prob_with_cost > _hyp.prob_with_cost:
509 |                     self.remove(_hyp)
510 |                     break
512 | # histogram pruning
513 | if self._histogram_pruning:
514 | if len(self) > self._size:
515 | self.remove(self._min_hyp)
516 | self._min_hyp = self._get_min_hyp()
517 | # threshold pruning
518 | if self._threshold_pruning:
519 | alpha = 1.0e-5
520 |             if hyp.prob_with_cost < \
    |                     self._max_hyp.prob_with_cost + math.log(alpha):
    |                 # discard, not remove: hyp may already have been
    |                 # dropped by recombination above
521 |                 self.discard(hyp)
522 |
523 |
524 | def _get_total_number(transto=1, db="sqlite:///:memory:"):
525 | """
526 |     Return the number of rows in the target-language trigram table.
527 | """
528 |
529 | Trigram = Tables().get_trigram_table('lang{}trigram'.format(transto))
530 |
531 | # create connection in SQLAlchemy
532 | engine = create_engine(db)
533 | # create session
534 | Session = sessionmaker(bind=engine)
535 | session = Session()
536 |
537 |     # count rows without materializing them all
538 |     return session.query(Trigram).count()
541 |
542 |
543 | def language_model(first, second, third, totalnumber, transto=1,
544 |                    db="sqlite:///:memory:"):
545 |
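    |     # trigram lookup with a crude backoff: use the exact trigram log
    |     # prob if present, else the first matching bigram-prefix row, else
    |     # a uniform -log(totalnumber) estimate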
546 | class TrigramProb(declarative_base()):
547 | __tablename__ = 'lang{}trigramprob'.format(transto)
548 | id = Column(INTEGER, primary_key=True)
549 | first = Column(TEXT)
550 | second = Column(TEXT)
551 | third = Column(TEXT)
552 | prob = Column(REAL)
553 |
554 | class TrigramProbWithoutLast(declarative_base()):
555 | __tablename__ = 'lang{}trigramprob'.format(transto)
556 | id = Column(INTEGER, primary_key=True)
557 | first = Column(TEXT)
558 | second = Column(TEXT)
559 | prob = Column(REAL)
560 |
561 | # create session
562 | engine = create_engine(db)
563 | Session = sessionmaker(bind=engine)
564 | session = Session()
565 | try:
566 | # next line can raise error if the prob is not found
567 | query = session.query(TrigramProb).filter_by(first=first,
568 | second=second,
569 | third=third)
570 | item = query.one()
571 | return item.prob
572 | except sqlalchemy.orm.exc.NoResultFound:
573 | query = session.query(TrigramProbWithoutLast
574 | ).filter_by(first=first,
575 | second=second)
576 |         # TODO: first() picks an arbitrary matching row; the schema
    |         # should expose a proper bigram table for this backoff
577 | item = query.first()
578 | if item:
579 | return item.prob
580 | else:
581 | return - math.log(totalnumber)
582 |
583 |
584 | class ArgumentNotSatisfied(Exception):
585 | pass
586 |
587 |
588 | def _future_cost_estimate(sentences,
589 | phrase_prob):
590 | '''
591 | warning:
592 | pass the complete one_word_prob
593 | '''
594 | s_len = len(sentences)
595 | cost = {}
596 |
597 | one_word_prob = {(st, ed): prob for (st, ed), prob in phrase_prob.items()
598 | if st == ed}
599 |
600 | if set(one_word_prob.keys()) != set((x, x) for x in range(1, s_len+1)):
601 | raise ArgumentNotSatisfied("phrase_prob doesn't satisfy the condition")
602 |
603 | # add one word prob
604 | for tpl, prob in one_word_prob.items():
605 | index = tpl[0]
606 | cost[(index, index)] = prob
607 |
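    |     # dynamic program over spans: cost(start, end) is the best of the
    |     # direct phrase estimate and every split cost(start, i) + cost(i+1, end)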
608 | for length in range(1, s_len+1):
609 | for start in range(1, s_len-length+1):
610 | end = start + length
611 | try:
612 | cost[(start, end)] = phrase_prob[(start, end)]
613 | except KeyError:
614 | cost[(start, end)] = -float('inf')
615 | for i in range(start, end):
616 | _val = cost[(start, i)] + cost[(i+1, end)]
617 | if _val > cost[(start, end)]:
618 | cost[(start, end)] = _val
619 | return cost
620 |
621 |
622 | def _create_estimate_dict(sentences,
623 | phrase_prob,
624 | init_val=-100):
625 | one_word_prob_dict_nums = set(x for x, y in phrase_prob.keys() if x == y)
626 | comp_dic = {}
627 | # complete the one_word_prob
628 | s_len = len(sentences)
629 | for i in range(1, s_len+1):
630 | if i not in one_word_prob_dict_nums:
631 | comp_dic[(i, i)] = init_val
632 | for key, val in phrase_prob.items():
633 | comp_dic[key] = val
634 | return comp_dic
635 |
636 |
637 | def _get_total_number_for_fce(transto=1, db="sqlite:///:memory:"):
638 | """
639 |     Return a dict with the unigram vocabulary size ('totalnumber')
    |     and the total unigram count ('sm').
640 | """
641 | # create connection in SQLAlchemy
642 | engine = create_engine(db)
643 | # create session
644 | Session = sessionmaker(bind=engine)
645 | session = Session()
646 |
647 | tablename = 'lang{}unigram'.format(transto)
648 | Unigram = Tables().get_unigram_table(tablename)
649 |
650 | # calculate total number
651 | query = session.query(Unigram)
652 | sm = 0
653 | totalnumber = 0
654 | for item in query:
655 | totalnumber += 1
656 | sm += item.count
657 | return {'totalnumber': totalnumber,
658 | 'sm': sm}
659 |
660 |
661 | def _future_cost_langmodel(word,
662 | tn,
663 | transfrom=2,
664 | transto=1,
665 | alpha=0.00017,
666 | db="sqlite:///:memory:"):
667 | tablename = "lang{}unigramprob".format(transto)
668 | # create session
669 | engine = create_engine(db)
670 | Session = sessionmaker(bind=engine)
671 | session = Session()
672 |
673 | UnigramProb = Tables().get_unigramprob_table(tablename)
674 | query = session.query(UnigramProb).filter_by(first=word)
675 | try:
676 | item = query.one()
677 | return item.prob
678 | except sqlalchemy.orm.exc.NoResultFound:
679 | sm = tn['sm']
680 | totalnumber = tn['totalnumber']
681 | return math.log(alpha) - math.log(sm + alpha*totalnumber)
682 |
683 |
684 | def future_cost_estimate(sentences,
685 | transfrom=2,
686 | transto=1,
687 | init_val=-100.0,
688 | db="sqlite:///:memory:"):
689 | # create phrase_prob table
690 | engine = create_engine(db)
691 | # create session
692 | Session = sessionmaker(bind=engine)
693 | session = Session()
694 |     phrases = available_phrases(sentences,
    |                                 transfrom=transfrom,
    |                                 transto=transto,
695 |                                 db=db)
696 |
697 | tn = _get_total_number_for_fce(transto=transto, db=db)
698 | covered = {}
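    |     # seed the cost table: for each available input span keep the best
    |     # phrase translation log prob (plus a cheap unigram LM estimate in
    |     # the 2 -> 1 direction)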
699 | for phrase in phrases:
700 | phrase_str = u" ".join(zip(*phrase)[1])
701 | if transfrom == 2 and transto == 1:
702 | query = session.query(TransPhraseProb).filter_by(
703 | lang2p=phrase_str).order_by(
704 | sqlalchemy.desc(TransPhraseProb.p2_1))
705 | elif transfrom == 1 and transto == 2:
706 | query = session.query(TransPhraseProb).filter_by(
707 | lang1p=phrase_str).order_by(
708 | sqlalchemy.desc(TransPhraseProb.p1_2))
709 | lst = list(query)
710 | if lst:
711 | # extract the maximum val
712 | val = query.first()
713 | start = zip(*phrase)[0][0]
714 | end = zip(*phrase)[0][-1]
715 | pos = (start, end)
716 | if transfrom == 2 and transto == 1:
717 | fcl = _future_cost_langmodel(word=val.lang1p.split()[0],
718 | tn=tn,
719 | transfrom=transfrom,
720 | transto=transto,
721 | alpha=0.00017,
722 | db=db)
723 | print(val.lang1p.split()[0], fcl)
724 | covered[pos] = val.p2_1 + fcl
725 | if transfrom == 1 and transto == 2:
726 | covered[pos] = val.p1_2
727 | # + language_model()
728 | # estimate future costs
729 | phrase_prob = _create_estimate_dict(sentences, covered)
730 | print(phrase_prob)
731 |
732 | return _future_cost_estimate(sentences,
733 | phrase_prob)
734 |
735 |
736 | def stack_decoder(sentence, transfrom=2, transto=1,
737 | stacksize=10,
738 | searchsize=10,
739 | lang1method=lambda x: x,
740 | lang2method=lambda x: x,
741 | db="sqlite:///:memory:",
742 | verbose=False):
743 | # create phrase_prob table
744 | engine = create_engine(db)
745 | # create session
746 | Session = sessionmaker(bind=engine)
747 | session = Session()
748 |
749 | if transfrom == 2 and transto == 1:
750 | sentences = lang2method(sentence).split()
751 | else:
752 | sentences = lang1method(sentence).split()
753 | # create stacks
754 | len_sentences = len(sentences)
755 | stacks = [Stack(size=stacksize,
756 | histogram_pruning=True,
757 | threshold_pruning=False,
758 | ) for i in range(len_sentences+1)]
759 |
760 | cost_dict = future_cost_estimate(sentences,
761 | transfrom=transfrom,
762 | transto=transto,
763 | db=db)
764 |     # create the initial (empty) hypothesis
765 |     hyp0 = create_empty_hypothesis(sentences=sentences,
766 |                                    cost_dict=cost_dict,
767 |                                    ngram=3,
768 |                                    transfrom=transfrom,
769 |                                    transto=transto,
770 |                                    db=db)
771 | stacks[0].add_hyp(hyp0)
772 |
773 | # main loop
774 | for i, stack in enumerate(stacks):
775 | for hyp in stack:
776 | for phrase in hyp.remain_phrases:
777 | phrase_str = u" ".join(zip(*phrase)[1])
778 | if transfrom == 2 and transto == 1:
779 | query = session.query(TransPhraseProb).filter_by(
780 | lang2p=phrase_str).order_by(
781 | sqlalchemy.desc(TransPhraseProb.p2_1))[:searchsize]
782 | elif transfrom == 1 and transto == 2:
783 | query = session.query(TransPhraseProb).filter_by(
784 | lang1p=phrase_str).order_by(
785 | sqlalchemy.desc(TransPhraseProb.p1_2))[:searchsize]
786 | query = list(query)
787 | for item in query:
788 | if transfrom == 2 and transto == 1:
789 | outputp = item.lang1p
790 | elif transfrom == 1 and transto == 2:
791 | outputp = item.lang2p
792 | #print(u"calculating\n {0} = {1}\n in stack {2}".format(
793 | # phrase, outputp, i))
794 | if transfrom == 2 and transto == 1:
795 | outputps = lang1method(outputp).split()
796 | elif transfrom == 1 and transto == 2:
797 | outputps = lang2method(outputp).split()
798 | # place in stack
799 | # and recombine with existing hypothesis if possible
800 | new_hyp = Hypothesis(prev_hypo=hyp,
801 | inputps_with_index=phrase,
802 | outputps=outputps)
803 | if verbose:
804 | print(phrase, u' '.join(outputps))
805 | print("loop: ", i, "len:", len(new_hyp.covered))
806 | stacks[len(new_hyp.covered)].add_with_combine_prune(
807 | new_hyp)
808 | return stacks
809 |
810 |
811 | if __name__ == '__main__':
812 | #import doctest
813 | #doctest.testmod()
814 | pass
815 |
--------------------------------------------------------------------------------
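
A minimal end-to-end usage sketch (illustrative only, not part of the
repository: it assumes the createdb helpers above live in smt.db.createdb
and that the sentence table has already been filled with parallel text;
the file name :test: and the limit are arbitrary):

    #! /usr/bin/env python
    # coding:utf-8

    from smt.db.createdb import createdb
    from smt.decoder.stackdecoder import stack_decoder

    db = ":test:"
    # train the IBM models, extract phrases and build the count views
    createdb(db=db, limit=1000, loop_count=10)
    # decode a lang2 sentence into lang1
    # (stack_decoder expects a SQLAlchemy URL, createdb a plain path)
    stacks = stack_decoder(u"I am a teacher",
                           transfrom=2, transto=1,
                           db="sqlite:///{0}".format(db))
    # the last stack holds hypotheses covering the whole input sentence
    for hyp in stacks[-1]:
        print(u" ".join(hyp.output_sentences))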