├── .gitignore ├── README.mkd ├── bin ├── mc_learn └── mc_talk ├── markovchains ├── __init__.py ├── database.py ├── markovchains.py ├── mysql.py ├── postgresql.py ├── settings.ini.sample ├── test.py └── util.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | syntax: glob 2 | settings.ini 3 | .DS_Store 4 | *.pyc 5 | build/* 6 | dist/* 7 | -------------------------------------------------------------------------------- /README.mkd: -------------------------------------------------------------------------------- 1 | # markovchains 2 | 3 | マルコフ連鎖による発言の自動生成を行うライブラリです 4 | 5 | ## 必要なライブラリ 6 | 7 | - [extractword](http://github.com/yono/py-extractword) 8 | 9 | ### データベースを用いる場合 10 | 11 | - [MySQLdb](http://sourceforge.net/projects/mysql-python/) 12 | - [psycopg2](http://initd.org/psycopg/) 13 | 14 | ## インストール 15 | 16 | % git clone git://github.com/yono/python-markovchains.git 17 | % cd python-markovchains 18 | % sudo python setup.py install 19 | 20 | ### DBを使う場合 21 | 22 | 1. bin/settings.ini.sample を settings.ini に変更 23 | 2. settings.ini にDB作成に使うユーザ名とパスワードを書く 24 | 3. settings.ini を site-packages/markovchains/ にコピー 25 | 26 | ## ライブラリの使い方 27 | 28 | ### 文章を読み込んで発言生成 29 | 30 | #!/usr/bin/env python 31 | # -*- coding:utf-8 -*- 32 | 33 | from markovchains import markovchains 34 | 35 | ## インスタンス生成 36 | m = markovchains.MarkovChains() 37 | 38 | text = u""" 39 | 親譲(おやゆず)りの無鉄砲(むてっぽう)で小供の時から損ばかりしている。 40 | 小学校に居る時分学校の二階から飛び降りて一週間ほど腰(こし)を抜(ぬ)か 41 | した事がある。なぜそんな無闇(むやみ)をしたと聞く人があるかも知れぬ。別 42 | 段深い理由でもない。新築の二階から首を出していたら、同級生の一人が冗談( 43 | じょうだん)に、いくら威張(いば)っても、そこから飛び降りる事は出来まい 44 | 。弱虫やーい。と囃(はや)したからである。小使(こづかい)に負ぶさって帰 45 | って来た時、おやじが大きな眼(め)をして二階ぐらいから飛び降りて腰を抜か 46 | す奴(やつ)があるかと云(い)ったから、この次は抜かさずに飛んで見せます 47 | と答えた。 48 | """ 49 | 50 | ## 文章解析 51 | m.analyze_sentence(text) 52 | 53 | ## 文章生成 54 | print m.make_sentence() 55 | 56 | 57 | ### 文章をデータベースに読み込む 58 | 59 | #!/usr/bin/env python 60 | # -*- coding:utf-8 -*- 61 | 62 | from markovchains import markovchains 63 | 64 | ## インスタンス生成 65 | m = markovchains.MarkovChains() 66 | 67 | text = u""" 68 | 親譲(おやゆず)りの無鉄砲(むてっぽう)で小供の時から損ばかりしている。 69 | 小学校に居る時分学校の二階から飛び降りて一週間ほど腰(こし)を抜(ぬ)か 70 | した事がある。なぜそんな無闇(むやみ)をしたと聞く人があるかも知れぬ。別 71 | 段深い理由でもない。新築の二階から首を出していたら、同級生の一人が冗談( 72 | じょうだん)に、いくら威張(いば)っても、そこから飛び降りる事は出来まい 73 | 。弱虫やーい。と囃(はや)したからである。小使(こづかい)に負ぶさって帰 74 | って来た時、おやじが大きな眼(め)をして二階ぐらいから飛び降りて腰を抜か 75 | す奴(やつ)があるかと云(い)ったから、この次は抜かさずに飛んで見せます 76 | と答えた。 77 | """ 78 | 79 | ## 文章解析 80 | m.analyze_sentence(text) 81 | 82 | ## DB モジュールをロード 83 | m.load_db('mysql', 'markov') 84 | 85 | ## 読み込んだ文章を DB に保存 86 | m.db.register_data() 87 | 88 | ### データベースに読み込んだデータから文章生成 89 | 90 | #!/usr/bin/env python 91 | # -*- coding:utf-8 -*- 92 | 93 | from markovchains import markovchains 94 | 95 | ## インスタンス生成 96 | m = markovchains.MarkovChains() 97 | 98 | ## DB モジュールをロード 99 | m.load_db('mysql', 'markov') 100 | 101 | ## 文章生成 102 | m.db.make_sentence() 103 | 104 | ## コマンドの使い方 105 | 106 | markovchains では、プログラムを書かずともマルコフ連鎖による文章生成を行えるようにライブラリの機能と対応したコマンドを用意しています。 107 | 108 | ### 用意した発言データを元に発言生成 109 | 110 | 発言データが書かれたファイル/ファイルが入ったディレクトリを引数にとって発言を生成します。 111 | 112 | % mc_talk -f FILEorDIRECTORY [-n num] 113 | 114 | この機能を使う場合、データベースの設定は必要ありません。 115 | 116 | options: 117 | -n, --number 単語を組み合わせる数。この数字が小さいと支離滅裂な発言が生成されることが多くなります。逆にこの数字が大きいと日本語らしい文が生成されやすくなりますが、同じような文が何度も生成されやすくなります。 118 | 119 | ### 発言データをデータベースに読み込む 120 | 121 | 発言データをファイル/ファイルが入ったディレクトリからデータベースに読み込みます。読み込んだ発言データは文章生成に使うことができます。 122 | 123 | % mc_learn FILEorDIRECTORY [-u username] [-d dbname] [-b database] [-n num] 124 | 125 | 引数には発言データが書かれたテキストファイルか,テキストファイルが入ってるディレクトリを指定します。 126 | 127 | option: 128 | -u, --user ユーザー 129 | -d, --dbname データベース名 130 | -b, --database RDBMSの種類(mysql or postgresql) 131 | -n, --num: 単語を組み合わせる数 132 | 133 | ### データベースに読み込んだデータを元に発言生成 134 | 135 | mc_learn で読み込んだ発言データを元に文章を生成します。 136 | 137 | % mc_talk [-u user] [-d dbname] [-b database] [-n name] 138 | 139 | option: 140 | -u, --user ユーザー 141 | -d, --dbname データベース名 142 | -b, --database RDBMSの種類(mysql or postgresql) 143 | -n, --num 単語を組み合わせる数 144 | -------------------------------------------------------------------------------- /bin/mc_learn: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import os 4 | import sys 5 | from optparse import OptionParser, OptionValueError 6 | import markovchains 7 | 8 | if __name__ == '__main__': 9 | usage = "usage: %prog filename/dirname [options]" 10 | parser = OptionParser(usage) 11 | parser.add_option('-u', '--user', action='store', 12 | help=u'発言したユーザ') 13 | parser.add_option('-d', '--dbname', action='store', 14 | help=u'データベース名') 15 | parser.add_option('-n', '--num', action='store', 16 | help=u'N階の数値') 17 | parser.add_option('-b', '--database', action='store', 18 | help=u'RDBMSの種類') 19 | (options, args) = parser.parse_args() 20 | user = '' 21 | dbname = 'markov' 22 | num = 3 23 | database = 'mysql' 24 | if options.user != None: 25 | user = options.user 26 | if options.dbname != None: 27 | dbname = options.dbname 28 | if options.num != None: 29 | num = int(options.num) 30 | if options.database != None: 31 | database = options.database 32 | 33 | fileordir = os.path.join(os.environ['PWD'], sys.argv[1]) 34 | files = [] 35 | if os.path.isdir(fileordir): 36 | files = os.listdir(fileordir) 37 | for i in xrange(len(files)): 38 | files[i] = os.path.join(fileordir, files[i]) 39 | elif os.path.isfile(fileordir): 40 | files.append(fileordir) 41 | else: 42 | quit() 43 | 44 | m = markovchains.MarkovChains(num) 45 | m.load_db(databsae,dbname) 46 | sentences = [] 47 | print "Loading file....." 48 | for file in files: 49 | print file 50 | f = unicode(open(file).read(), 'utf-8', 'ignore').splitlines() 51 | for line in f: 52 | last = m.analyze_sentence(line, user) 53 | print "Registering data....." 54 | m.db.register_data() 55 | -------------------------------------------------------------------------------- /bin/mc_talk: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import os 4 | from optparse import OptionParser, OptionValueError 5 | import markovchains 6 | 7 | if __name__ == '__main__': 8 | usage = "usage: %prog filename/dirname [options]" 9 | parser = OptionParser(usage) 10 | parser.add_option('-u', '--user', action='store', 11 | help=u'発言したユーザ') 12 | parser.add_option('-d', '--dbname', action='store', 13 | help=u'データベース名') 14 | parser.add_option('-n', '--num', action='store', 15 | help=u'N階の数値') 16 | parser.add_option('-b', '--database', action='store', 17 | help=u'RDBMSの種類') 18 | parser.add_option('-f', '--fileordir', action='store', 19 | help=u'発言の元になる文') 20 | (options, args) = parser.parse_args() 21 | user = '' 22 | dbname = 'markov' 23 | order_num = 3 24 | database = 'mysql' 25 | fileordir = '' 26 | if options.user != None: 27 | user = options.user 28 | if options.dbname != None: 29 | dbname = options.dbname 30 | if options.num != None: 31 | order_num = int(options.num) 32 | if options.database != None: 33 | database = options.database 34 | if options.fileordir != None: 35 | fileordir = options.fileordir 36 | m = markovchains.MarkovChains(order_num) 37 | if fileordir: 38 | files = [] 39 | if os.path.isdir(fileordir): 40 | files = os.listdir(fileordir) 41 | for i in xrange(len(files)): 42 | files[i] = os.path.join(fileordir, files[i]) 43 | elif os.path.isfile(fileordir): 44 | files.append(fileordir) 45 | else: 46 | quit() 47 | for file in files: 48 | f = unicode(open(file).read(), 'utf-8', 'ignore').splitlines() 49 | for line in f: 50 | m.analyze_sentence(line, user) 51 | print m.make_sentence(user=user) 52 | else: 53 | m.load_db(database, dbname) 54 | print m.db.make_sentence(user=user) 55 | -------------------------------------------------------------------------------- /markovchains/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yono/python-markovchains/60cd924f598d38c9162b00bbc448129d7f648da5/markovchains/__init__.py -------------------------------------------------------------------------------- /markovchains/database.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import mysql 4 | import postgresql 5 | 6 | class Database(object): 7 | 8 | @classmethod 9 | def create(cls, db, dbname): 10 | if db == 'mysql': 11 | return mysql.MySQL(dbname) 12 | elif db == 'postgresql': 13 | return postgresql.PostgreSQL(dbname) 14 | 15 | -------------------------------------------------------------------------------- /markovchains/markovchains.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import re 4 | import random 5 | import MySQLdb 6 | 7 | from util import * 8 | from database import Database 9 | 10 | from extractword import Sentence 11 | 12 | 13 | class MarkovChains(object): 14 | 15 | def __init__(self, order_num=3): 16 | 17 | self.num = order_num 18 | 19 | self.chaindic = {} 20 | self.userchaindic = {} 21 | 22 | def load_db(self, database, dbname='markov'): 23 | self.db = Database.create(database, dbname) 24 | self.db.load_db() 25 | 26 | def _get_punctuation(self): 27 | punctuation_words = {u'。': 0, u'.': 0, u'?': 0, u'!': 0, 28 | u'!': 0, u'?': 0, u'w': 0, u'…': 0} 29 | return punctuation_words 30 | 31 | """ 32 | 文章を解析し、連想配列に保存 33 | """ 34 | def analyze_sentence(self, text, user=''): 35 | sentences = self._split_sentences(text) 36 | for sentence in sentences: 37 | words = self._get_words(sentence + u'。') 38 | self._update_newchains_ins(words) 39 | if user: 40 | self._update_newchains_ins(words, user) 41 | 42 | def _split_sentences(self, text): 43 | ps = self._get_punctuation() 44 | ps = re.compile(u'[%s]' % ('|'.join(ps.keys()))) 45 | return ps.split(text) 46 | 47 | def _get_chains(self, words): 48 | chain = [] 49 | chains = [] 50 | for word in words: 51 | if len(chain) == self.num: 52 | values = [x['name'] for x in chain] 53 | chains.append(values) 54 | chain.pop(0) 55 | chain.append(word) 56 | return chains 57 | 58 | def _get_chaindic(self, chains, user=''): 59 | is_start = True 60 | if user: 61 | if user not in self.userchaindic: 62 | self.userchaindic[user] = {} 63 | chaindic = self.userchaindic[user] 64 | else: 65 | chaindic = self.chaindic 66 | 67 | for chain in chains: 68 | prewords = tuple(chain[0:len(chain)-1]) 69 | postword = chain[-1] 70 | if prewords not in chaindic: 71 | chaindic[prewords] = {} 72 | if postword not in chaindic[prewords]: 73 | chaindic[prewords][postword] = Chain(0, 0, is_start) 74 | chaindic[prewords][postword].count += 1 75 | is_start = False 76 | 77 | def _update_newchains_ins(self, words, user=''): 78 | chainlist = self._get_chains(words) 79 | self._get_chaindic(chainlist, user) 80 | 81 | def _get_words(self, text): 82 | sentence = Sentence() 83 | sentence.analysis_text(text) 84 | words = sentence.get_words() 85 | result = [] 86 | first = True 87 | for word in words: 88 | result.append({'name': word, 'isstart': first}) 89 | first = False 90 | return result 91 | 92 | """ 93 | 連想配列を DB に保存 94 | """ 95 | def register_data(self): 96 | self.register_words() 97 | self.register_chains() 98 | self.register_userchains() 99 | 100 | def register_words(self): 101 | existwords = self.db.get_allwords() 102 | 103 | words = {} 104 | for prewords in self.chaindic: 105 | for preword in prewords: 106 | words[preword] = 0 107 | for postword in self.chaindic[prewords]: 108 | words[postword] = 0 109 | for user in self.userchaindic: 110 | for prewords in self.chaindic: 111 | for preword in prewords: 112 | words[preword] = 0 113 | for postword in self.chaindic[prewords]: 114 | words[postword] = 0 115 | 116 | sql = ["('%s')" % (MySQLdb.escape_string(x)) for x in words \ 117 | if x not in existwords] 118 | 119 | if sql: 120 | self.db.insert_words(sql) 121 | 122 | def register_chains(self): 123 | 124 | # 現在持ってる chain を全て持ってくる 125 | exists = self.db.get_allchain(self.num) 126 | words = self.db.get_allwords() 127 | 128 | # 連想配列から同じ形の chain を作成 129 | chains = {} 130 | for prewords in self.chaindic: 131 | for postword in self.chaindic[prewords]: 132 | chain = [] 133 | chain.extend(list(prewords)) 134 | chain.append(postword) 135 | chains[tuple(chain)] = self.chaindic[prewords][postword] 136 | 137 | # ない場合は新たに作成、ある場合は更新 138 | insert_step = 1000 139 | sql = [] 140 | for chain in chains: 141 | if chain in exists: 142 | ids = [words[chain[i]] for i in xrange(len(chain))] 143 | count = chains[chain].count + exists[chain].count 144 | isstart = chains[chain].isstart or exists[chain].isstart 145 | self.db.update_chains(ids, count, isstart) 146 | else: 147 | values = [] 148 | for i in xrange(self.num): 149 | values.append("%d" % (words[chain[i]])) 150 | sql.append("(%s,%s,%d)" % (','.join(values), 151 | str(chains[chain].isstart).upper(), 152 | chains[chain].count)) 153 | if (len(sql) % insert_step) == 0 and len(sql) > 0: 154 | self.db.insert_chains(sql) 155 | sql = [] 156 | 157 | if sql: 158 | self.db.insert_chains(sql) 159 | 160 | def register_userchains(self): 161 | words = self.db.get_allwords() 162 | allchains = self.db.get_allchain(self.num) 163 | 164 | for user in self.userchaindic: 165 | userid = self.db.get_user(user) 166 | exists = self.db.get_userchain(self.num, userid) 167 | chains = {} 168 | for prewords in self.userchaindic[user]: 169 | for postword in self.userchaindic[user][prewords]: 170 | chain = [] 171 | chain.extend(list(prewords)) 172 | chain.append(postword) 173 | chains[tuple(chain)] = self.chaindic[prewords][postword] 174 | 175 | insert_step = 1000 176 | sql = [] 177 | for chain in chains: 178 | try: 179 | node = allchains[chain] 180 | except: 181 | continue 182 | 183 | if chain in exists: 184 | id = exists[chain][1] 185 | count = chains[chain].count + node.count 186 | self.db.update_userchains(count, id) 187 | else: 188 | id = allchains[chain].id 189 | count = chains[chain].count 190 | sql.append("(%d,%d,%d)" % (userid, id, count)) 191 | 192 | if (len(sql) % insert_step) == 0 and len(sql) > 0: 193 | self.db.insert_userchains(sql) 194 | sql = [] 195 | 196 | if sql: 197 | self.db.insert_userchains(sql) 198 | 199 | """ 200 | 文章生成 201 | """ 202 | def make_sentence(self, user=''): 203 | limit = 1 204 | 205 | if user == '' or user not in self.userchaindic: 206 | chaindic = self.chaindic 207 | else: 208 | chaindic = self.userchaindic[user] 209 | 210 | while True: 211 | prewords = random.choice(chaindic.keys()) 212 | postword = random.choice(chaindic[prewords].keys()) 213 | if chaindic[prewords][postword].isstart: 214 | break 215 | 216 | words = [] 217 | words.extend(prewords) 218 | words.append(postword) 219 | 220 | while True: 221 | if postword in self._get_punctuation() and limit < len(words): 222 | return ''.join(words) 223 | next_prewords = list(prewords[1:len(prewords)]) 224 | next_prewords.append(postword) 225 | if tuple(next_prewords) not in chaindic: 226 | return ''.join(words) 227 | 228 | postword = self._select_nextword_from_dic(chaindic, prewords) 229 | 230 | postword = random.choice(chaindic[tuple(next_prewords)].keys()) 231 | prewords = next_prewords 232 | words.append(postword) 233 | 234 | def _select_nextword_from_dic(self, chaindic, _prewords): 235 | sum_count = 0 236 | prewords = tuple(_prewords) 237 | for postword in chaindic[prewords]: 238 | sum_count += chaindic[prewords][postword].count 239 | 240 | postwords = [] 241 | 242 | for postword in chaindic[prewords]: 243 | info = Word(id=1, name=postword, 244 | count=chaindic[prewords][postword].count/float(sum_count)) 245 | postwords.append(info) 246 | 247 | return Util.select_nextword(postwords) 248 | 249 | if __name__ == '__main__': 250 | obj = MarkovChains(order_num=3) 251 | obj.load_db('mysql', dbname='markov3') 252 | print obj.db.make_sentence() 253 | -------------------------------------------------------------------------------- /markovchains/mysql.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import os 4 | from ConfigParser import SafeConfigParser 5 | import copy 6 | import random 7 | 8 | try: 9 | import MySQLdb 10 | except: 11 | pass 12 | 13 | from util import * 14 | 15 | 16 | class MySQL(object): 17 | 18 | def __init__(self, dbname): 19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 20 | self.inifile = os.path.join(BASE_DIR, 'settings.ini') 21 | user, password = self._load_ini() 22 | self.con = MySQLdb.connect(user=user, passwd=password, 23 | charset='utf8', use_unicode=True) 24 | self.cur = self.con.cursor() 25 | self.dbname = dbname 26 | self.num = 3 27 | 28 | def __del__(self): 29 | self.cur.close() 30 | self.con.close() 31 | 32 | def load_db(self): 33 | self.cur.execute('show databases') 34 | rows = self.cur.fetchall() 35 | for row in rows: 36 | if row[0] == self.dbname: 37 | self.cur.execute('use %s' % (self.dbname)) 38 | return 39 | self._create_db() 40 | 41 | def _load_ini(self): 42 | parser = SafeConfigParser() 43 | parser.readfp(open(self.inifile)) 44 | user = parser.get('mysql', 'user') 45 | password = parser.get('mysql', 'password') 46 | return (user, password) 47 | 48 | 49 | """ 50 | データベース初期化 & テーブル作成 51 | """ 52 | def _create_db(self): 53 | self.cur.execute('create database %s default character set utf8' %\ 54 | (self.dbname)) 55 | self.cur.execute('use %s' % (self.dbname)) 56 | self._init_tables() 57 | 58 | def _init_tables(self): 59 | self._init_user() 60 | self._init_word() 61 | self._init_chain() 62 | self._init_userchain() 63 | 64 | def _init_user(self): 65 | self.cur.execute(''' 66 | CREATE TABLE user ( 67 | id int(11) NOT NULL auto_increment, 68 | name varchar(100) NOT NULL default '', 69 | PRIMARY KEY (id) 70 | ) DEFAULT CHARACTER SET=utf8 71 | ''') 72 | 73 | def _init_word(self): 74 | self.cur.execute(''' 75 | CREATE TABLE word ( 76 | id int(11) NOT NULL auto_increment, 77 | name varchar(100) NOT NULL default '', 78 | PRIMARY KEY (id) 79 | ) DEFAULT CHARACTER SET=utf8 80 | ''') 81 | 82 | def _init_chain(self): 83 | sql = [] 84 | sql.append(''' 85 | CREATE TABLE chain ( 86 | id int(11) NOT NULL auto_increment, 87 | ''') 88 | 89 | for i in xrange(self.num): 90 | sql.append("word%d_id int(11) NOT NULL default '0'," % i) 91 | 92 | sql.append(''' 93 | isstart BOOL NOT NULL default '0', 94 | count int(11) NOT NULL default '0', 95 | PRIMARY KEY (id), 96 | ''') 97 | 98 | for i in xrange(self.num): 99 | sql.append("FOREIGN KEY (word%d_id) REFERENCES word(id)," % i) 100 | 101 | sql.append("INDEX ") 102 | ids = [] 103 | for i in xrange(self.num - 1): 104 | ids.append("word%d_id" % i) 105 | sql.append("(%s)" % (','.join(ids))) 106 | sql.append(") DEFAULT CHARACTER SET=utf8") 107 | self.cur.execute('\n'.join(sql)) 108 | 109 | def _init_userchain(self): 110 | self.cur.execute(u''' 111 | CREATE TABLE userchain ( 112 | id int(11) NOT NULL auto_increment, 113 | user_id int(11) NOT NULL default '0', 114 | chain_id int(11) NOT NULL default '0', 115 | count int(11) NOT NULL default '0', 116 | PRIMARY KEY (id), 117 | FOREIGN KEY (user_id) REFERENCES user(id), 118 | FOREIGN KEY (chain_id) REFERENCES word(id) 119 | ) DEFAULT CHARACTER SET=utf8; 120 | ''') 121 | 122 | 123 | """ 124 | データ挿入 & 更新 125 | """ 126 | def insert_words(self, sql): 127 | self.cur.execute(u'INSERT INTO word (name) VALUES %s' %\ 128 | (','.join(sql))) 129 | 130 | def insert_chains(self, values): 131 | sql = [] 132 | sql.append("insert into chain(") 133 | sql.append(",".join(["word%d_id" % i for i in xrange(self.num)])) 134 | sql.append(",isstart,count)") 135 | sql.append("values %s" % (','.join(values))) 136 | self.cur.execute('\n'.join(sql)) 137 | 138 | def insert_userchains(self, sql): 139 | self.cur.execute(''' 140 | INSERT INTO userchain(user_id,chain_id,count) VALUES %s 141 | ''' % (','.join(sql))) 142 | 143 | def update_chains(self, ids, count, isstart): 144 | sql = [] 145 | sql.append('UPDATE chain SET count=%d' % (count)) 146 | sql.append('WHERE') 147 | for i in xrange(self.num): 148 | sql.append('word%d_id = %d and ' % (i, ids[i])) 149 | sql.append('isstart = %d' % (isstart)) 150 | self.cur.execute('\n'.join(sql)) 151 | 152 | def update_userchains(self, count, userchainid): 153 | self.cur.execute(''' 154 | UPDATE userchain 155 | SET count=%d 156 | WHERE id = %d 157 | ''' % (count, userchainid)) 158 | 159 | 160 | """ 161 | データ取得 162 | """ 163 | def get_nextwords(self, words, userid, num): 164 | sql = [] 165 | sql.append('select c.word%d_id, w.name, c.count' % (num - 1)) 166 | sql.append('from chain c') 167 | sql.append('inner join word w on c.word%d_id = w.id' % (num - 1)) 168 | sql.append(self._cond_join_userchain(userid)) 169 | sql.append(' where ') 170 | ids = [] 171 | for i in xrange(num - 1): 172 | ids.append(' c.word%d_id = %d' % (i, words[i + 1].id)) 173 | sql.append(' and'.join(ids)) 174 | sql.append(self._cond_userid(userid)) 175 | sql.append(' order by count desc') 176 | 177 | self.cur.execute('\n'.join(sql)) 178 | rows = self.cur.fetchall() 179 | result = [] 180 | for row in rows: 181 | result.append(Word(int(row[0]), row[1], int(row[2]))) 182 | return result 183 | 184 | def get_startword(self, num, userid=-1, word=None): 185 | sql = [] 186 | sql.append('select ') 187 | for i in xrange(num): 188 | sql.append('c.word%d_id, w%d.name,' % (i, i)) 189 | sql.append(' c.count') 190 | sql.append(' from chain c') 191 | for i in xrange(num): 192 | sql.append(' inner join word w%d on c.word%d_id = w%d.id'\ 193 | % (i, i, i)) 194 | sql.append(self._cond_join_userchain(userid)) 195 | sql.append(' where ') 196 | sql.append(' c.isstart = True') 197 | sql.append(self._cond_userid(userid)) 198 | sql.append(self._cond_wordname(word)) 199 | 200 | self.cur.execute('\n'.join(sql)) 201 | 202 | rows = self.cur.fetchall() 203 | row = random.choice(rows) 204 | 205 | result = [] 206 | for i in xrange(0, (num * 2) - 1, 2): 207 | result.append(Word(int(row[i]), row[i + 1], 208 | int(row[(num * 2)]))) 209 | return tuple(result) 210 | 211 | def get_allwords(self): 212 | self.cur.execute('select name,id from word') 213 | rows = self.cur.fetchall() 214 | words = dict(rows) 215 | return words 216 | 217 | def get_allchain(self, num): 218 | sql = [] 219 | sql.append("select") 220 | sql.append(','.join(["w%d.name" % i for i in xrange(num)])) 221 | sql.append(',c.isstart, c.count, c.id') 222 | sql.append('from chain c') 223 | for i in xrange(self.num): 224 | sql.append('inner join word w%d on c.word%d_id = w%d.id' %\ 225 | (i, i, i)) 226 | self.cur.execute('\n'.join(sql)) 227 | rows = self.cur.fetchall() 228 | words = {} 229 | for row in rows: 230 | id = int(row[-1]) 231 | count = int(row[-2]) 232 | isstart = row[-3] 233 | words[tuple(row[0: len(row) - 3])] = Chain(id, count, isstart) 234 | return words 235 | 236 | def get_userchain(self, num, userid): 237 | sql = [] 238 | sql.append("select") 239 | sql.append(','.join(["w%d.name" % i for i in xrange(self.num)])) 240 | sql.append(',uc.user_id, c.count, uc.id') 241 | sql.append('from chain c') 242 | sql.append('inner join userchain uc on uc.chain_id = c.id') 243 | for i in xrange(self.num): 244 | sql.append('inner join word w%d on c.word%d_id = w%d.id' %\ 245 | (i, i, i)) 246 | sql.append("where uc.user_id = %d" % (userid)) 247 | self.cur.execute('\n'.join(sql)) 248 | rows = self.cur.fetchall() 249 | words = {} 250 | for row in rows: 251 | count = int(row[-2]) 252 | id = int(row[-1]) 253 | words[tuple(row[0: len(row) - 3])] = (count, id) 254 | return words 255 | 256 | def get_user(self, user): 257 | self.cur.execute('select id from user where name = "%s"' % (user)) 258 | row = self.cur.fetchone() 259 | if row is None: 260 | self.cur.execute("insert into user (name) values ('%s')" % \ 261 | (user)) 262 | self.cur.execute('select id from user where name = "%s"' % \ 263 | (user)) 264 | row = self.cur.fetchone() 265 | return int(row[0]) 266 | 267 | def get_userid(self, user): 268 | if user: 269 | self.cur.execute("select id from user where name = '%s'" % \ 270 | (user)) 271 | row = self.cur.fetchone() 272 | userid = int(row[0]) 273 | else: 274 | userid = 0 275 | return userid 276 | 277 | 278 | """ 279 | SQL 条件追加 280 | """ 281 | def _cond_join_userchain(self, userid): 282 | if (userid > 0): 283 | return ' inner join userchain uc on uc.chain_id = c.id' 284 | else: 285 | return '' 286 | 287 | def _cond_userid(self, userid): 288 | if (userid > 0): 289 | return ' and uc.user_id = %d' % (userid) 290 | else: 291 | return '' 292 | 293 | def _cond_wordname(self, word): 294 | if word: 295 | return ' and w0.name = "%s"' % (word) 296 | else: 297 | return '' 298 | 299 | def make_sentence(self, user='', word=None): 300 | limit = 1 301 | 302 | userid = self.get_userid(user) 303 | words = self.get_startword(self.num, userid, word) 304 | sentenceid = list(copy.copy(words)) 305 | 306 | count = 0 307 | punctuation_words = {u'。': 0, u'.': 0, u'?': 0, u'!': 0, 308 | u'!': 0, u'?': 0, u'w': 0, u'…': 0} 309 | punctuations = punctuation_words 310 | while True: 311 | end_cond = (count > limit) and (words[-1].name in punctuations) 312 | if end_cond: 313 | break 314 | 315 | nextwords = self.get_nextwords(words, userid, self.num) 316 | if len(nextwords) == 0: 317 | break 318 | 319 | nextword = Util.select_nextword(nextwords) 320 | sentenceid.append(nextword) 321 | tmp = [words[i] for i in xrange(1, self.num)] 322 | tmp.append(nextword) 323 | words = tuple(tmp) 324 | count += 1 325 | 326 | return ''.join([x.name for x in sentenceid]) 327 | 328 | -------------------------------------------------------------------------------- /markovchains/postgresql.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | import os 5 | from ConfigParser import SafeConfigParser 6 | import copy 7 | import random 8 | 9 | try: 10 | import psycopg2 11 | except: 12 | pass 13 | 14 | from util import * 15 | 16 | 17 | class PostgreSQL(object): 18 | 19 | def __init__(self, dbname): 20 | if psycopg2: 21 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 22 | self.inifile = os.path.join(BASE_DIR, 'settings.ini') 23 | user, password = self._load_ini() 24 | self.user = user 25 | self.password = password 26 | self.con = psycopg2.connect("user=%s password=%s" % (self.user, 27 | self.password)) 28 | self.con.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) 29 | 30 | self.cur = self.con.cursor() 31 | self.dbname = dbname 32 | self.num = 3 33 | else: 34 | raise BaseException 35 | 36 | def __del__(self): 37 | self.cur.close() 38 | self.con.close() 39 | 40 | def _reconnect_db(self): 41 | self.cur.close() 42 | self.con.close() 43 | self.con = psycopg2.connect("user=%s dbname=%s" % (self.user, self.dbname)) 44 | self.con.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) 45 | self.cur = self.con.cursor() 46 | 47 | def load_db(self): 48 | self.cur.execute('select datname from pg_database') 49 | rows = self.cur.fetchall() 50 | for row in rows: 51 | if row[0] == self.dbname: 52 | self._reconnect_db() 53 | return 54 | self._create_db() 55 | 56 | def _load_ini(self): 57 | parser = SafeConfigParser() 58 | parser.readfp(open(self.inifile)) 59 | user = parser.get('postgresql', 'user') 60 | password = parser.get('postgresql', 'password') 61 | return (user, password) 62 | 63 | 64 | """ 65 | データベース初期化 & テーブル作成 66 | """ 67 | def _create_db(self): 68 | self.cur.execute("create database %s encoding 'utf8'" %(self.dbname)) 69 | self._reconnect_db() 70 | self._init_tables() 71 | 72 | def _init_tables(self): 73 | self._init_user() 74 | self._init_word() 75 | self._init_chain() 76 | self._init_userchain() 77 | 78 | def _init_user(self): 79 | self.cur.execute(""" 80 | CREATE TABLE mc_user ( 81 | id serial PRIMARY KEY, 82 | name varchar NOT NULL 83 | ) 84 | """) 85 | 86 | def _init_word(self): 87 | self.cur.execute(''' 88 | CREATE TABLE word ( 89 | id serial PRIMARY KEY, 90 | name varchar(100) NOT NULL default '' 91 | ) 92 | ''') 93 | 94 | def _init_chain(self): 95 | sql = [] 96 | sql.append(''' 97 | CREATE TABLE chain ( 98 | id serial PRIMARY KEY, 99 | ''') 100 | for i in xrange(self.num): 101 | sql.append("word%d_id int NOT NULL default '0'," % i) 102 | sql.append(''' 103 | isstart bool NOT NULL default false, 104 | count int NOT NULL default '0' 105 | ''') 106 | for i in xrange(self.num): 107 | sql.append(",FOREIGN KEY (word%d_id) REFERENCES word(id)" % i) 108 | sql.append(")") 109 | self.cur.execute('\n'.join(sql)) 110 | 111 | sql = [] 112 | sql.append('CREATE INDEX chain_idx ON chain (') 113 | ids = [] 114 | for i in xrange(self.num): 115 | ids.append('word%d_id' % (i)) 116 | sql.append(','.join(ids)) 117 | sql.append(')') 118 | 119 | def _init_userchain(self): 120 | self.cur.execute(u''' 121 | CREATE TABLE userchain ( 122 | id serial PRIMARY KEY, 123 | user_id int NOT NULL default '0', 124 | chain_id int NOT NULL default '0', 125 | count int NOT NULL default '0', 126 | FOREIGN KEY (user_id) REFERENCES mc_user(id), 127 | FOREIGN KEY (chain_id) REFERENCES word(id) 128 | ) 129 | ''') 130 | 131 | 132 | """ 133 | データ挿入 & 更新 134 | """ 135 | def insert_words(self, sql): 136 | self.cur.execute(u'INSERT INTO word (name) VALUES %s' %\ 137 | (','.join(sql))) 138 | 139 | def insert_chains(self, values): 140 | sql = [] 141 | sql.append("INSERT INTO chain(") 142 | sql.append(",".join(["word%d_id" % i for i in xrange(self.num)])) 143 | sql.append(",isstart,count)") 144 | sql.append("VALUES %s" % (','.join(values))) 145 | self.cur.execute('\n'.join(sql)) 146 | 147 | def insert_userchains(self, sql): 148 | self.cur.execute(''' 149 | INSERT INTO userchain(user_id,chain_id,count) VALUES %s 150 | ''',(','.join(sql))) 151 | 152 | def update_chains(self, ids, count, isstart): 153 | sql = [] 154 | sql.append('UPDATE chain SET count=%d' % (count)) 155 | sql.append('WHERE') 156 | for i in xrange(self.num): 157 | sql.append('word%d_id = %d and ' % (i, ids[i])) 158 | sql.append('isstart = %s' % (str(isstart).upper())) 159 | self.cur.execute('\n'.join(sql)) 160 | 161 | def update_userchains(self, count, userchainid): 162 | self.cur.execute(''' 163 | UPDATE userchain 164 | SET count=%d 165 | WHERE id = %d 166 | ''',(count, userchainid)) 167 | 168 | 169 | """ 170 | データ取得 171 | """ 172 | def get_nextwords(self, words, userid, num): 173 | sql = [] 174 | sql.append('select c.word%d_id, w.name, c.count' % (num - 1)) 175 | sql.append('from chain as c') 176 | sql.append('inner join word as w on c.word%d_id = w.id' % (num - 1)) 177 | sql.append(self._cond_join_userchain(userid)) 178 | sql.append(' where ') 179 | ids = [] 180 | for i in xrange(num - 1): 181 | ids.append(' c.word%d_id = %d' % (i, words[i + 1].id)) 182 | sql.append(' and'.join(ids)) 183 | sql.append(self._cond_userid(userid)) 184 | sql.append(' order by count desc') 185 | 186 | self.cur.execute('\n'.join(sql)) 187 | rows = self.cur.fetchall() 188 | result = [] 189 | for row in rows: 190 | result.append(Word(int(row[0]), row[1].decode(), int(row[2]))) 191 | return result 192 | 193 | def get_startword(self, num, userid=-1, word=None): 194 | sql = [] 195 | sql.append('select ') 196 | for i in xrange(num): 197 | sql.append('c.word%d_id, w%d.name,' % (i, i)) 198 | sql.append(' c.count') 199 | sql.append(' from chain as c') 200 | for i in xrange(num): 201 | sql.append(' inner join word as w%d on c.word%d_id = w%d.id'\ 202 | % (i, i, i)) 203 | sql.append(self._cond_join_userchain(userid)) 204 | sql.append(' where ') 205 | sql.append(' c.isstart = TRUE') 206 | sql.append(self._cond_userid(userid)) 207 | sql.append(self._cond_wordname(word)) 208 | 209 | self.cur.execute('\n'.join(sql)) 210 | 211 | rows = self.cur.fetchall() 212 | row = random.choice(rows) 213 | 214 | result = [] 215 | for i in xrange(0, (num * 2) - 1, 2): 216 | result.append(Word(int(row[i]), row[i + 1].decode(), 217 | int(row[(num * 2)]))) 218 | return tuple(result) 219 | 220 | def get_allwords(self): 221 | self.cur.execute('select name,id from word') 222 | rows = self.cur.fetchall() 223 | words = {} 224 | for row in rows: 225 | words[row[0].decode()] = int(row[1]) 226 | return words 227 | 228 | def get_allchain(self, num): 229 | sql = [] 230 | sql.append("select") 231 | sql.append(','.join(["w%d.name" % i for i in xrange(num)])) 232 | sql.append(',c.isstart, c.count, c.id') 233 | sql.append('from chain as c') 234 | for i in xrange(self.num): 235 | sql.append('inner join word as w%d on c.word%d_id = w%d.id' %\ 236 | (i, i, i)) 237 | self.cur.execute('\n'.join(sql)) 238 | rows = self.cur.fetchall() 239 | words = {} 240 | for row in rows: 241 | id = int(row[-1]) 242 | count = int(row[-2]) 243 | isstart = row[-3] 244 | w1 = row[0].decode() 245 | w2 = row[1].decode() 246 | w3 = row[2].decode() 247 | words[(w1, w2, w3)] = Chain(id, count, isstart) 248 | return words 249 | 250 | def get_userchain(self, num, userid): 251 | sql = [] 252 | sql.append("select") 253 | sql.append(','.join(["w%d.name" % i for i in xrange(self.num)])) 254 | sql.append(',uc.user_id, c.count, uc.id') 255 | sql.append('from chain as c') 256 | sql.append('inner join userchain as uc on uc.chain_id = c.id') 257 | for i in xrange(self.num): 258 | sql.append('inner join word as w%d on c.word%d_id = w%d.id' %\ 259 | (i, i, i)) 260 | sql.append("where uc.user_id = %d" % (userid)) 261 | self.cur.execute('\n'.join(sql)) 262 | rows = self.cur.fetchall() 263 | words = {} 264 | for row in rows: 265 | count = int(row[-2]) 266 | id = int(row[-1]) 267 | words[tuple(row[0: len(row) - 3])] = (count, id) 268 | return words 269 | 270 | def get_user(self, user): 271 | self.cur.execute('select id from mc_user where name = %s', 272 | (user,)) 273 | row = self.cur.fetchone() 274 | if row is None: 275 | self.cur.execute("insert into mc_user (name) values (%s)", 276 | (user,)) 277 | self.cur.execute('select id from mc_user where name = %s', 278 | (user,)) 279 | row = self.cur.fetchone() 280 | return int(row[0]) 281 | 282 | def get_userid(self, user): 283 | if user: 284 | self.cur.execute("select id from mc_user where name = %s", 285 | (user,)) 286 | row = self.cur.fetchone() 287 | userid = int(row[0]) 288 | else: 289 | userid = 0 290 | return userid 291 | 292 | """ 293 | SQL 条件追加 294 | """ 295 | def _cond_join_userchain(self, userid): 296 | if (userid > 0): 297 | return ' inner join userchain uc on uc.chain_id = c.id' 298 | else: 299 | return '' 300 | 301 | def _cond_userid(self, userid): 302 | if (userid > 0): 303 | return ' and uc.user_id = %d' % (userid) 304 | else: 305 | return '' 306 | 307 | def _cond_wordname(self, word): 308 | if word: 309 | return ' and w0.name = "%s"' % (word) 310 | else: 311 | return '' 312 | 313 | def make_sentence(self, user='', word=None): 314 | limit = 1 315 | 316 | userid = self.get_userid(user) 317 | words = self.get_startword(self.num, userid, word) 318 | sentenceid = list(copy.copy(words)) 319 | 320 | count = 0 321 | punctuation_words = {u'。': 0, u'.': 0, u'?': 0, u'!': 0, 322 | u'!': 0, u'?': 0, u'w': 0, u'…': 0} 323 | punctuations = punctuation_words 324 | while True: 325 | end_cond = (count > limit) and (words[-1].name in punctuations) 326 | if end_cond: 327 | break 328 | 329 | nextwords = self.get_nextwords(words, userid, self.num) 330 | if len(nextwords) == 0: 331 | break 332 | 333 | nextword = Util.select_nextword(nextwords) 334 | sentenceid.append(nextword) 335 | tmp = [words[i] for i in xrange(1, self.num)] 336 | tmp.append(nextword) 337 | words = tuple(tmp) 338 | count += 1 339 | 340 | return ''.join([x.name for x in sentenceid]) 341 | 342 | -------------------------------------------------------------------------------- /markovchains/settings.ini.sample: -------------------------------------------------------------------------------- 1 | [mysql] 2 | user = user 3 | password = password 4 | -------------------------------------------------------------------------------- /markovchains/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import unittest 4 | from markovchains import * 5 | 6 | 7 | class TestWord(unittest.TestCase): 8 | 9 | def setUp(self): 10 | pass 11 | 12 | def test_init(self): 13 | id = 1 14 | name = 'name' 15 | count = 1 16 | word = Word(id, name, count) 17 | self.assert_(id, word.id) 18 | self.assert_(name, word.name) 19 | self.assert_(count, word.count) 20 | 21 | 22 | class TestMarkovChains(unittest.TestCase): 23 | 24 | def setUp(self): 25 | self.dbname = '__test_markovchains__' 26 | 27 | def tearDown(self): 28 | pass 29 | 30 | if __name__ == '__main__': 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /markovchains/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import random 4 | 5 | class Util(object): 6 | 7 | @classmethod 8 | def select_nextword(cls, words): 9 | sum_count = sum([x.count for x in words]) 10 | probs = [] 11 | for word in words: 12 | probs.append(word) 13 | probs[-1].count = float(probs[-1].count) / sum_count 14 | probs.sort(lambda x, y: cmp(x.count, y.count), reverse=True) 15 | randnum = random.random() 16 | sum_prob = 0 17 | nextword = '' 18 | for word in probs: 19 | sum_prob += word.count 20 | if randnum < sum_prob: 21 | nextword = word 22 | break 23 | return nextword 24 | 25 | class Word(object): 26 | 27 | def __init__(self, id, name, count): 28 | self.id = id 29 | self.name = name 30 | self.count = count 31 | 32 | 33 | class Chain(object): 34 | 35 | def __init__(self, id, count, isstart): 36 | self.id = id 37 | self.count = count 38 | self.isstart = isstart 39 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | setup( name = 'markovchains', 3 | version = '0.0.1', 4 | scripts = ['bin/mc_learn','bin/mc_talk'], 5 | packages=['markovchains'], 6 | ) 7 | 8 | --------------------------------------------------------------------------------