├── .gitignore ├── README.md ├── makefile ├── run ├── semi-supervised │ ├── train.py │ └── viterbi.py └── unsupervised │ ├── print_parameters.py │ ├── train.py │ └── viterbi.py ├── src ├── npylm │ ├── common.h │ ├── ctype.cpp │ ├── ctype.h │ ├── hash.cpp │ ├── hash.h │ ├── hashmap │ │ ├── flat_hashmap.h │ │ └── hashmap.h │ ├── lattice.cpp │ ├── lattice.h │ ├── lm │ │ ├── hpylm.cpp │ │ ├── hpylm.h │ │ ├── model.h │ │ ├── node.h │ │ ├── vpylm.cpp │ │ └── vpylm.h │ ├── npylm.cpp │ ├── npylm.h │ ├── sampler.cpp │ ├── sampler.h │ ├── sentence.cpp │ ├── sentence.h │ ├── wordtype.cpp │ └── wordtype.h ├── python.cpp └── python │ ├── corpus.cpp │ ├── corpus.h │ ├── dataset.cpp │ ├── dataset.h │ ├── dictionary.cpp │ ├── dictionary.h │ ├── model.cpp │ ├── model.h │ ├── trainer.cpp │ └── trainer.h └── test ├── generate_test_sequence.py ├── module_tests ├── hash.cpp ├── lattice.cpp ├── npylm.cpp ├── sentence.cpp ├── vpylm.cpp └── wordtype.cpp └── running_tests ├── save.cpp └── train.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore all 2 | * 3 | 4 | # allow any 5 | !*/ 6 | !*.py 7 | !*.cpp 8 | !makefile 9 | !*.h 10 | !*.md 11 | !.gitignore -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nested Pitman-Yor Language Model (NPYLM) 2 | 3 | [ベイズ階層言語モデルによる教師なし形態素解析](http://chasen.org/~daiti-m/paper/nl190segment.pdf)のC++実装です。 4 | 5 | 単語n-gramモデルは3-gramで固定です。2-gramは非対応です。 6 | 7 | 現在も開発途中です。 8 | 9 | 実装について 10 | - [ベイズ階層言語モデルによる教師なし形態素解析](http://musyoku.github.io/2016/12/14/%E3%83%99%E3%82%A4%E3%82%BA%E9%9A%8E%E5%B1%A4%E8%A8%80%E8%AA%9E%E3%83%A2%E3%83%87%E3%83%AB%E3%81%AB%E3%82%88%E3%82%8B%E6%95%99%E5%B8%AB%E3%81%AA%E3%81%97%E5%BD%A2%E6%85%8B%E7%B4%A0%E8%A7%A3%E6%9E%90/) 11 | - [Forward filtering-Backward samplingによる単語分割でアンダーフローを防ぐ](http://musyoku.github.io/2017/04/15/forward-filtering-backward-sampling%E3%81%A7%E3%82%A2%E3%83%B3%E3%83%80%E3%83%BC%E3%83%95%E3%83%AD%E3%83%BC%E3%82%92%E9%98%B2%E3%81%90/) 12 | 13 | 文字列の単語ID化をハッシュで実装しているため、学習結果を違うコンピュータで用いると正しく分割が行えない可能性があります。 14 | 15 | ## 更新履歴 16 | 17 | ### 2017/11/21 18 | - 前向き確率の計算をlogsumexpからスケーリングに変更 19 | 20 | ## 動作環境 21 | 22 | - Boost 23 | - C++14 24 | - Python 3 25 | 26 | ## 準備 27 | 28 | ### macOS 29 | 30 | macOSの場合、PythonとBoostはともにbrewでインストールする必要があります。 31 | 32 | #### Python 3のインストール 33 | 34 | ``` 35 | brew install python3 36 | ``` 37 | 38 | `PYTHONPATH`を変更する必要があるかもしれません。 39 | 40 | #### Boostのインストール 41 | 42 | ``` 43 | brew install boost-python --with-python3 44 | ``` 45 | 46 | ### Ubuntu 47 | 48 | #### Boostのインストール 49 | 50 | ``` 51 | ./bootstrap.sh --with-python=python3 --with-python-version=3.5 52 | ./b2 python=3.5 -d2 -j4 --prefix YOUR_BOOST_DIR install 53 | ``` 54 | 55 | Pythonのバージョンを自身のものと置き換えてください。 56 | 57 | ### ビルド 58 | 59 | 以下のコマンドで`npylm.so`が生成され、Pythonから利用できるようになります。 60 | 61 | ``` 62 | make install 63 | ``` 64 | 65 | `makefile`内のBoostのパスを自身の環境に合わせて書き換えてください。 66 | 67 | Ubuntuでエラーが出る場合は代わりに以下を実行します。 68 | 69 | ``` 70 | make install_ubuntu 71 | ``` 72 | 73 | ### MeCabのインストール 74 | 75 | 半教師あり学習をする場合は必要です。 76 | 77 | ``` 78 | pip install mecab-python3 79 | ``` 80 | 81 | ## 学習(教師なし) 82 | 83 | `run/unsupervised`にコード例があります。 84 | 85 | ### 実行例 86 | 87 | ``` 88 | python3 train.py -split 0.9 -l 8 -file YOUR_TEXTFILE 89 | ``` 90 | 91 | ### オプション 92 | 93 | - -file 94 | - 学習に使うテキストファイル 95 | - -dir 96 | - 学習に使うテキストファイル群が入っているディレクトリ 97 | - 複数ファイルを用いる場合はこちらを指定 98 | - -split 99 | - 
読み込んだ行のうち何割を学習に用いるか 100 | - 0から1の実数を指定 101 | - 1を指定すると全データを用いてモデルを学習する 102 | - -l 103 | - 可能な単語の最大長 104 | - 日本語なら8〜16、英語なら16〜20程度を指定 105 | - 文の長さをN、単語の最大長をLとすると、NPYLMの計算量はO(NL^3)になる 106 | 107 | なお学習の再開はできません。 108 | 109 | ## 学習(半教師あり) 110 | 111 | `run/semi-supervised`にコード例があります。 112 | 113 | ### 実行例 114 | 115 | ``` 116 | python3 train.py -train-split 0.9 -ssl-split 0.01 -l 8 -file YOUR_TEXTFILE 117 | ``` 118 | 119 | ### オプション 120 | 121 | - -file 122 | - 学習に使うテキストファイル 123 | - -dir 124 | - 学習に使うテキストファイル群が入っているディレクトリ 125 | - 複数ファイルを用いる場合はこちらを指定 126 | - -train-split 127 | - 読み込んだ行のうち何割を学習に用いるか 128 | - 0から1の実数を指定 129 | - 1を指定すると全データを用いてモデルを学習する 130 | - -ssl-split 131 | - 学習データのうち何割を教師データに用いるか 132 | - 0から1の実数を指定 133 | - -l 134 | - 可能な単語の最大長 135 | - 日本語なら8〜16、英語なら16〜20程度を指定 136 | - 文の長さをN、単語の最大長をLとすると、NPYLMの計算量はO(NL^3)になる 137 | 138 | なお学習の再開はできません。 139 | 140 | ## 単語分割 141 | 142 | 分割結果をファイルに保存します。 143 | 144 | ``` 145 | python viterbi.py -file YOUR_TEXTFILE 146 | ``` 147 | 148 | ### オプション 149 | 150 | - -file 151 | - 分割するテキストファイル 152 | - -dir 153 | - 分割するテキストファイルが入っているディレクトリ 154 | - 複数ファイルをまとめて分割する場合はこれを指定 155 | - ファイルごとに個別の出力ファイルが作成されます 156 | - -out 157 | - 出力フォルダ 158 | 159 | ## 注意事項 160 | 161 | 研究以外の用途には使用できません。 162 | 163 | https://twitter.com/daiti_m/status/851810748263157760 164 | 165 | 実装に誤りが含まれる可能性があります。 166 | 167 | 質問等、何かありましたらissueにてお知らせください。 168 | 169 | ## 展望 170 | 171 | 現在、[条件付確率場とベイズ階層言語モデルの統合による半教師あり形態素解析](http://chasen.org/~daiti-m/paper/nlp2011semiseg.pdf)と[半教師あり形態素解析NPYCRFの修正](http://www.anlp.jp/proceedings/annual_meeting/2016/pdf_dir/D6-3.pdf)を実装しています。 172 | 173 | ## 実行結果 174 | 175 | #### けものフレンズ実況スレ 176 | 177 | アンダーフローが起きていないかを確認するためにわざと1行あたりの文字数を多くして学習させています。 178 | 179 | 15万行あったデータから改行を削除し4,770行に圧縮しました。 180 | 181 | > / アニメ / に / セルリアン / 設定 / を / 出 / す / 必要 / は / まったく / なかった / と思う / um / nk / は / 原作 / と / 漫画 / で / 終わり方 / 変え / て / た / なあ / 原作 / 有り / だ / と / 色々 / 考え / ちゃ / う / けど / オリジナル / だ / と / 余計な / 事 / 考え / なくていい / ね / 犬 / のフレンズ / だから / さ / 、 / ガキが余計な / レスつけてくんな / って / 言って / る / だろ / ジャパリパーク / 滅亡 / まで / あと / 1 / 4 / 日 / 間 / 。 / まだ / 作品 / は / 完結 / して / ない / し / この / あと / 話 / に / どう / 関わって / くる / か / わから / ない / と思う / んだが / 最終話 / まで / い / って / / そういう / こと / だった / のか / ー / ! / / って / な / った / 後で / 見返すと / また / 唸 / ら / され / そう / な / 場面 / が / 多い / ん / だよ / ね / お別れ / エンド / 想像 / する / だけで / 震え / て / くる / けど / ハッピー / エンド / なら / 受け入れ / る / しかない / よ / ね / この / スレ / は / 子供 / が / 多くて / 最終回の / 感動 / も / 萎える / だ / ろ / う / 百合 / アニメ / として / 楽しんで / る / 奴 / もいる / ん / だね / ぇ / あー / 、 / それ / は / 誤 / り / だ / 。 / 低 / レアリティ / フレンズ / にスポットライト / が / 当て / られ / て / る / こと / もある / 。 / 毒 / 抜いた / っていう / のは / 伏線 / 回収 / は / だいたい / や / る / けど / も / 鬱 / な / 感じ / に / は / して / ない / ていう / 意味 / だ / と / 解釈 / し / た / けど / 【 / � / 】 / 『けものフレンズ』 / 第 / 10話「ろっじ」 / より / 、 / 先行 / 場面 / カット / が / 到着 / ! / 宿泊 / した / ロッジ / に / 幽霊 / ……? / 【 / � / 】 / [無断転載禁止]©2ch. / net [ / 8 / 9 / 1 / 1 / 9 / 1 / 9 / 2 / 3 / ] / 実際のところ / けもの / フレンズ / は / 百合 / なのか / ? / 例えば / どこが / 百合 / っぽい / ? / いや / 、 / ある / だろ / 。 / セルリアン / が / いる / 事 / に / よって / 物語 / に / 緊張感 / が / 生まれ / る / 。 / 伏線 / が / 結構 / 大きい / 物 / な気がする / んだ / けど / これ / あと / 2話 / で / 終わ / る / のか / なぁ / ? / もしかして / 「 / ジャパリ / パーク / の / 外に / 人間 / を / 探しに / 行く / よ / ! / カバンちゃん / たち / の / 冒険は / これから / だ!」 / エンド / じゃある / まい / な / それ / でも / あれ / は / 許容 / し / 難 / かった / とおもう / ぞ / そもそも / 利 / 潤 / 第一 / でない / けもの / プロジェクト / に / 売上 / で / 優劣 / 語 / る / の / は / ナンセンス / や / ぞ / の / タイトル / も / いい / な / カバンちゃん / 「 / さーばる / 島の外に / 出 / ちゃ / う / と / 記憶 / も / なくな / る / ん / だ / よ / ? 
/ 」 / さーばる / 「 / うん / 、 / それ / でも / いい / よ / 。 / 」 / カバンちゃん / 「 / う / うん / 、 / ダメ / 。 / ボク / が / のこ / る / よ / 。 / 」 / さーばる / 「 / え? / ・・・ / 」 / カバンちゃん / 「 / さーばる / は / 大切な / 僕 / のフレンズ / だから / 」 / わざわざ / ng / 宣言 / とか / キッズ / の / 代表 / みたいな / も / ん / だろ / 出来 / の / 良い / おさ / らい / サイト / が / ある / よ / 俺 / も / そこ / で / 予習 / し / た / タイム / アタック / 式 / の / クエスト / で / 低 / レア / 構成 / ボーナス / 入るから / 人権 / ない / ってレベル / じゃ / ない / ぞ / 182 | 183 | > / 予約 / でき / ない / のに / いつまで / 2巻 / は / ランキング / 1位 / なん / だ / よ / けもの / フレンズ / 、 / 縮めて / フレンズ / 。 / ジャパリパーク / の / 不思議 / な / 不思議な / 生き物 / 。 / 空 / に / 山 / に / 海 / に / フレンズ / は / いた / る / ところ / で / そ / の / 姿 / を / 見 / る / こと / が / 出来 / る / 。この / 少女 / 、 / ヒト / と / 呼ばれ / る / かばん。 / 相棒 / の / サーバル / と / 共に / バトル / & / ゲット / 。 / フレンズの / 数だけ / の / 出会い / が / あり / フレンズの / 数だけ / の / 別れ / がある / ( / 石塚運昇) / ニコ動 / 調子 / 悪 / そう / なんだ / けど / 上映会 / だ / いじょうぶ / か / ね / コスプレ / はよ / もう / 何度も / 書かれてる / だろうけど / 、 / 2話 / ed / で / 入 / って / きた / 身としては / ed / やっぱ / 良い / 。 / 歌 / も / 歌詞 / も / 合って / る / し / 、 / 普通の / アニメ / っぽく / ない / 廃墟 / に / した / の / も / 全部 / 良い / 。 / 情報 / が / 氾濫 / し / すぎ / て / 何 / が / ネタバレ / か / さっぱり / 分からん / けど / 、 / 1つ / ぐらい / は / 本物 / が / まじ / っ / て / そう / だな / 。 / ま、 / 来週 / を / 楽しみ / に / して / る / よ / 。 / アライさん / の / 「 / 困難は / 群れで分け合え / 」 / って / 台詞 / もしかして / アプリ版 / で / 出て / きた / こと / あった / りする / ? / それ / なら / 記憶 / の / 引 / 継ぎ / 説 / は / ほぼ / 間違え / なさ / そう / だ / けど / 神 / 展開 / で / ワロタ / これ / は / 地上波 / 流 / せ / ません / ね / ぇ / … / まあ、 / 数 / 打ちゃ当たる / 11話の / 展開 / は / 予想 / されて / なかった / 気配 / だ / が / 汗 / まあ / ニコ動 / ランキング / に / あ / っ / た / アプリ版 / ストーリー / 見 / た / 時 / 点 / で / すでに / 覚悟はできてる / 一 / 人 / でも / いいから / 出 / して / ほしかった / … / マジで / サーバル / / プレーリードッグ / / ヘラジカ / 殷周 / 伝説 / みたい / な / 肉 / マン / を / 想像 / し / ちゃ / っ / た / じゃぱりまん / の / 中身 / で / 肉 / が / ある / なら / だ / が / 無い / から / 安心 / 184 | 185 | > / 知識 / や / 性格 / まで / クローン / は / 無理 / だ / と思う / nhkで / アニメ / 放送 / から / の / 紅白 / 出場 / だと / 嫌な予感がする / 藤子不二 / 雄 / や / 大友 / 克洋や / 鳥山明 / は / エール / を / 送 / られ / た / が、 / 自分 / が / 理解 / でき / ない / もの / が / 流行 / っ / て / る / と / 怒 / った / らしい / な / 巨人の星 / とか / ス / ポ / 根 / も / のは / 、 / どうして / こんな / もの / が / 人気 / なん / だ / って / アシスタント / と / 担当 / に / 怒鳴り / 散らし / た / そう / な / 日本語 / で / しか / 伝わらない / 表現 / ある / もん / ね / ひらがな / カタカナ / 漢字 / でも / ニュアンス / 使い / 分 / け / られ / る / し / また / 同時 / に / 英語 / いい / なあ / と思う / 所 / もある / 親友 / ( / ?) 
/ なのに / そ / の / 欠片 / も / み / られ / ない / 猫 / の / 話 / は / やめ / なさい / … / なん / か / 、 / マズルの / ところ / が / カイ / ゼル / 髭 / みたい / で / 、 / かわいい / っていうより / カッコイイ / とおもう / ん / だが / / 「 / ( / いや / 俺ら / に / 言われ / て / も / ) / 」 / って / 困惑 / する / 様子 / が / w / 藤子不二 / 雄 / は / 貶 / そう / と / 思って / た / ら / 全力で / 尊敬 / されて / 持ち / 上げ / て / くる / ので / 面倒見 / ざるを得な / かった / とか / いう / ホント / か / ウソ / か / 分からない / 逸話 / すこ / 世界 / 最高峰 / の / 日本 / アニメ / を / 字幕 / 無し / で / 観 / られ / る / 幸せ / たぶん / そうな / ん / だろう / とは思う / が / おいしい / とこ / だけ / 取ら / れ / て / 何も / やってない / 扱い / で / うん / ざ / り / して / そう / 結局 / 自分 / の / 好 / み / って事 / か / 本人 / に / しか / わから / ない / 先 / 駆 / 者 / として / の / 強烈な / 自負 / と / 、 / 追い / 抜か / れ / る / 恐怖 / かわ / あった / のだろう / と愚考 / した / 。 / スポーツ / 漫画や / 劇 / 画 / が / 流行 / っ / た / 時 / ノイローゼ / になり / かけ / た / と / 聞く / が / 、 / 90年代 / 以降 / の / トーン / バリバリ / アニメ / 絵柄 / や / 萌え / 文化 / とか / 見 / た / ら / どう / なってしまう / んだろう / あと / サーバルちゃん / かわいい / コクボス / 「 / ジカ / ン / ダヨ / 」 / 礼儀 / は / 守 / る / 人 / だろう / … / 内心 / 穏やか / ではない / が / ニコ動 / の再生数 / なんて / まったく / 当て / に / ならん / ぞ / ニコ生 / の / 来場者 / 数 / と / 有料 / 動画 / の / 再生数 / は / 当て / に / して / いい / けど / 6話 / の / へいげん / の / とき / / ヘラジカ / さん / たち / に / 「 / サーバルキャット / のサーバル / だよ / 」って / 自己紹介 / して / た / のが / なんか / 不思議 / だ / った / な / 「 / 省略 / する / 」って / 文化 / ある / ん / だ / / みたい / な / 一話 / の再生数 / で / 分かる / のは / 洗脳 / 度 / だけ / だ / な / すぐ / 解 / け / る / 洗脳 / かも知れん / が / それ / 見たい / わ / めちゃくちゃ / 好循環 / じゃない / か / … / イッカク / クジラ / の / イッカク / だ。 / って / 名乗 / って / る / キャラ / もいた / し / 。 / いや~ / メンゴメンゴ / / あ / 、 / ボス / さん / きゅー / w / アプリ版 / の / オオ / アルマジロ / の / 声で / 絡んで / ほしかった / ( / cv / 相 / 沢 / 舞 / ) / 名前 / の / 概念 / 例えば / ヘラジカ / は / 種族 / として / 言って / る / のか / 名前 / と / して / 言って / る / のか / かばんちゃん / と / 名 / 付け / て / る / から / 概念 / として / は / 在 / る / ん / だろう / けど / 果たして / クローン / は / 同じ / ポテンシャル / を / 引き / 出 / せ / る / だろう / か / 。 / 藤子f / という / か / ドラえもん / に / は / 完全 / 敗北 / を / 認め / て / た / らしい / から / ね / ドラえもん / を / 超え / る / キャラ / は / 作れ / ない / って / どうぶつ / スクープ / と / ダーウィン / と / wbc / どっち / を優先 / すべき / か / 186 | 187 | > / 捨て / た / り / しない / ん / じゃないかな / ? / ま、 / ちょっと / (ry / ロッ / ソファンタズマ / は / 先代 / サーバル / の / 技 / って / 言いたい / のは / わかる / けど / 。 / かばんちゃん / 視点 / から / 見ると / op / の / 映像 / の意味 / になる / と / いう / ダブルミーニング / なの / かも知れない / これ / は / なかなか / 面白い / 基本的に / ゲスト / フレンズって / カップル / にな / る / けど / この / 二人 / は / その後 / どうなった / ん / だろう / ね / 杏子 / も / 野中 / か / 優しく / て / 元気な / 女の子 / の / 唐突な / 涙 / は / めちゃくちゃ / くる / もの / がある / な / ppp / が / 歌 / って / る / こと / から / 考えると / あり / 得 / る / ね / 、 / 宣伝 / 曲 / そして / まんま / と / ようこそ / され / ました / なんだか / コケ / そう / な / 気がして / ならない / 12話 / だから / 話 / を / 膨らま / せ / て / 伏線 / を / 回収 / して / って / の / が / でき / た / けど / だら / っと / 続け / て / いく / と / すると / ・・・ / たまに / 見 / る / 豆腐 / や / 冷奴 / は / 何 / を / 示唆 / して / る / ん / だ / ? / 姿 / が / フレンズ / だと / 判別 / でき / ない / ん / じゃないか / とりあえず / 初版 / は / 無理 / でも / 重版分 / 買おう / 古典sf / 的に / タイ / ムスリップ / して / アプリ版 / プレイヤー / として / ミライさん / の / 運命 / を / 変え / に / 行く / とか / は / あり / そう / 188 | 189 | > / すごーい / ぞ / おー / ! / が / 正解 / で / 世界遺 / 産! / が / 空耳 / おいおい / エヴァ / か / ? / 聖域 / への / 立ち / 入り / 認可 / 、 / 正体 / 不明 / な / 新規 / フレンズ / の / 評価 / 、 / 困った / とき / の / 相談 / 役 / 色々 / やって / る / な / 輪廻転生 / して / 二人 / は / 一緒 / も / 人間 / は / 滅んで / て / かばんちゃん / は / サーバルちゃん / とずっと一緒 / も / ぶっちゃけ / 百合厨の願望 / だし / たつきが / そんな安直な / 設定 / に / せ / ん / でしょ / ジャパリパーク / って / 時間 / の / 進 / み / が / サンドスター / その他 / の / 影響 / で / 物凄く / 早 / く / なって / る / から / 人工 / 物 / が / 朽ち / て / た / り / する / ん / か / な / 聞こえ / ない / と / 言えば / ツチノコ / の / 「 / ピット器官 / ! 
/ 」 / って / 言 / っ / た / 後 / に / モ / ニャモニャ / … / って / なんか / 言って / る / けど / なんて / 言って / る / のか / 未だに / 分からない / … / メ / イ / ズ / ランナー / 的な / 人類滅亡 / とか / ちょっと / 似てる / し / 13話 / ネタバレ / 中尉 / か / つて / の / ロッジ / で / は / フレンズと / 一 / 夜 / を / 共に / する / こと / が / でき / る / 人工 / 物 / は / ヒト / が / 使わ / なくな / る / と / たった / 数年 / で / 朽ち / る / よ / 林 / 業 / 管理 / も / やって / る / から / な / toki / o / のフレンズ / と / 言われ / て / も / 不思議 / はない / 図書館 / に / 入 / り / 浸 / って / ゴロゴロ / 絵本 / 読んで / る / フレンズ / い / ない / かな / ピット器官 / ! / ・・・ / ・・・ / だとかでぇ、 / 俺には / 赤外線が見えるからな / ! / ( / ・ / ∀・) / 目 / が / か / ゆい / … / ジャパリパーク / に / も / 花粉症 / は / ある / のか / な / サーバル / が / なぜ / ハシビロコウ / だけ / ハシビロ / ちゃん / 呼び / なのか / が / 最大の / 謎 / 12話 / は / op / は / 一番 / 最後 / カバンちゃん / と / サーバルちゃんが / 仲良く / ブランコ / に / 揺ら / れ / る / 絵 / が / 入 / る / よ / けもフレ / を / 細かく / 見 / て / ない / 間違って / セリフ / 覚え / て / る / アプリ / 時代 / の / 知識 / を / 間違って / 捉え / て / る / って / いう / ので / 、 / 悶々と / 考察 / してる / 人 / の / 多 / い / こと / 多い / こと / … / … / 190 | 191 | #### 2ch 192 | 193 | 2chから集めた884,158行の書き込みで学習を行いました。 194 | 195 | 196 | > / どことなく / 硬貨 / の / 気配 / が / ある / な / 197 | 198 | > / 展開 / の / ヒント / と / 世界観 / を / 膨らま / せ / る / ギミック / と / 物語 / の / 伏線 / を / 本当に / 勘違い / して / る / 人 / は / い / ない / でしょ / 199 | 200 | > / トラブっても / その / 前の / 状態 / に / 簡単 / に / 戻 / れ / る / 201 | 202 | > / レシート / を / わ / た / さ / ない / 会社 / は / 100% / 脱税 / している / 203 | 204 | > / すっきり / した / 。 / 実装 / 当時 / は / 2度と / やりたく / 無い / と思った / けど / 205 | 206 | > / 未だ / 趣味 / な / 個人 / 用途 / で / win / 10 / に / 頑なに / 乗り換え / ない / ヤツ / なんて / 新しい / もん / に / 適応 / でき / ない / 老 / 化 / 始ま / っ / ちゃ / って / る / お / 人 / か / 207 | 208 | > / 実家の / 猫 / がよくやる / けど / あんまり / 懐 / かれ / て / る / 気がしない / 209 | 210 | > / ラデ / の / ラインナップ / は / こう / いう / 噂 / のようだ。 / 211 | 212 | > / ダメウォ / なんて / 殆ど / で / ねー / じゃねーか / ど / アホ / 213 | 214 | > / 新 / retina / 、 / 旧 / retina / が / 併売 / され / て / る / 中 / で / 比較 / やら / 機種 / 選び / ごと / に / 別 / スレ / 面倒 / だ / もん / 215 | 216 | > / イオク / 出 / る / だけで / 不快 / 217 | 218 | > / あの / まま / やってりゃ / ジュリア / の / 撃墜 / も / 時間の問題 / だ / っ / た / し / 219 | 220 | > / も / し / 踊ら / され / て / た / ら / 面白 / さ / を / 感じ / られ / る / はず / だ / 221 | 222 | > / 二連 / スレ建て / オ / ッ / ツ / オ / ッ / ツ / 223 | 224 | > / の / ガチャ限 / 定運極化特別ルール / って / 何 / ですか / ? 
/ 225 | 226 | > / 特に / その / 辺 / フォロー / ない / まま / あの / 状況 / で / と / どめ / 刺 / し / 損 / ね / ました / で / 最後 / まで / いく / の / は / な / ・・・ / 227 | 228 | > / こうなると / 意外 / に / ツチノコ / が / ハードル / 低 / そう / 229 | 230 | > / 強制 / アップデート / のたびに / 自分 / の / 使い方 / にあわせた / 細かい / 設定 / を / 勝手に / 戻 / す / だけ / で / なく / 231 | 232 | > / マジか了 / 解した / 233 | 234 | > / 今度 / は / mac / 使い / () / に / 乗り換え / た / ん / だろう / が / ・・・ / 哀れ / よ / のぅ / 235 | / 今 / 後 / も / ノエル / たくさん / 配 / って / くれ / る / なら / 問題ない / けど / 236 | 237 | > / マルチ / 魔窟 / 初めて / やった / けど / フレンド / が / いい人 / で / 上手く / 出来 / た / わ / 238 | 239 | > / 咲 / くん / も / 女 / の / 子 / 声優 / が / よかった / 240 | 241 | > / 確かに / 少し / づつ / エンジン / かか / っ / て / き / た / 感じ / が / する / な / 242 | 243 | > / くっそ / 、 / まず / ラファエル / が / 出 / ねえ / 244 | 245 | > / 第六 / 世代 / cpu / で / 組 / もう / と思って / る / けど / win10 / 買 / う / の / は / は / 待った / 方が / いい / のか / な / これ / … / (´・ω・`) / 246 | 247 | > / 移動 / 先 / で / ある程度 / なんで / も / で / き / る / mbp / は / 本当に / いい / 製品 / だ / と思います / 248 | 249 | > / いや / 俺 / は / そこ / が / 好き / 250 | 251 | > / と言えば / ギャラホルン / 崩壊 / しそう / 252 | 253 | > / オオクニ欲 / しかった / -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | CC = g++ 2 | BOOST = /usr/local/Cellar/boost/1.65.0 3 | INCLUDE = `python3-config --includes` -std=c++14 -I$(BOOST)/include 4 | LDFLAGS = `python3-config --ldflags` -lboost_serialization -lboost_python3 -L$(BOOST)/lib 5 | SOFLAGS = -shared -fPIC -march=native 6 | TESTFLAGS = -O0 -g -Wall 7 | SOURCES = src/python/*.cpp src/npylm/*.cpp src/npylm/lm/*.cpp 8 | 9 | install: ## npylm.soを生成 10 | $(CC) $(INCLUDE) $(SOFLAGS) src/python.cpp $(SOURCES) $(LDFLAGS) -o run/npylm.so -O3 11 | cp run/npylm.so run/semi-supervised/npylm.so 12 | cp run/npylm.so run/unsupervised/npylm.so 13 | rm -rf run/npylm.so 14 | 15 | install_ubuntu: ## npylm.soを生成 16 | $(CC) -Wl,--no-as-needed -Wno-deprecated $(INCLUDE) $(SOFLAGS) src/python.cpp $(SOURCES) $(LDFLAGS) -o run/npylm.so -O3 17 | cp run/npylm.so run/semi-supervised/npylm.so 18 | cp run/npylm.so run/unsupervised/npylm.so 19 | rm -rf run/npylm.so 20 | 21 | check_includes: ## Python.hの場所を確認 22 | python3-config --includes 23 | 24 | check_ldflags: ## libpython3の場所を確認 25 | python3-config --ldflags 26 | 27 | module_tests: ## 各モジュールのテスト. 
28 | $(CC) test/module_tests/wordtype.cpp $(SOURCES) -o test/module_tests/wordtype $(INCLUDE) $(LDFLAGS) $(TESTFLAGS) 29 | ./test/module_tests/wordtype 30 | $(CC) test/module_tests/npylm.cpp $(SOURCES) -o test/module_tests/npylm $(INCLUDE) $(LDFLAGS) $(TESTFLAGS) 31 | ./test/module_tests/npylm 32 | $(CC) test/module_tests/vpylm.cpp $(SOURCES) -o test/module_tests/vpylm $(INCLUDE) $(LDFLAGS) $(TESTFLAGS) 33 | ./test/module_tests/vpylm 34 | $(CC) test/module_tests/sentence.cpp $(SOURCES) -o test/module_tests/sentence $(INCLUDE) $(LDFLAGS) $(TESTFLAGS) 35 | ./test/module_tests/sentence 36 | $(CC) test/module_tests/hash.cpp $(SOURCES) -o test/module_tests/hash $(INCLUDE) $(LDFLAGS) $(TESTFLAGS) 37 | ./test/module_tests/hash 38 | $(CC) test/module_tests/lattice.cpp $(SOURCES) -o test/module_tests/lattice $(INCLUDE) $(LDFLAGS) $(TESTFLAGS) 39 | ./test/module_tests/lattice 40 | 41 | running_tests: ## 運用テスト 42 | $(CC) test/running_tests/train.cpp $(SOURCES) -o test/running_tests/train $(INCLUDE) $(LDFLAGS) -O0 -g 43 | $(CC) test/running_tests/save.cpp $(SOURCES) -o test/running_tests/save $(INCLUDE) $(LDFLAGS) $(TESTFLAGS) 44 | 45 | .PHONY: help 46 | help: 47 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 48 | .DEFAULT_GOAL := help -------------------------------------------------------------------------------- /run/semi-supervised/train.py: -------------------------------------------------------------------------------- 1 | import argparse, sys, os, time, codecs, random 2 | import MeCab 3 | import npylm 4 | 5 | 6 | class stdout: 7 | BOLD = "\033[1m" 8 | END = "\033[0m" 9 | CLEAR = "\033[2K" 10 | 11 | 12 | def printb(string): 13 | print(stdout.BOLD + string + stdout.END) 14 | 15 | 16 | def printr(string): 17 | sys.stdout.write("\r" + stdout.CLEAR) 18 | sys.stdout.write(string) 19 | sys.stdout.flush() 20 | 21 | 22 | def build_corpus(filepath, directory, semisupervised_split_ratio): 23 | assert filepath is not None or directory is not None 24 | corpus = npylm.corpus() 25 | sentence_list = [] 26 | 27 | if filepath is not None: 28 | with codecs.open(filepath, "r", "utf-8") as f: 29 | for sentence_str in f: 30 | sentence_str = sentence_str.strip() 31 | sentence_list.append(sentence_str) 32 | 33 | if directory is not None: 34 | for filename in os.listdir(directory): 35 | with codecs.open(os.path.join(directory, filename), "r", 36 | "utf-8") as f: 37 | for sentence_str in f: 38 | sentence_str = sentence_str.strip() 39 | sentence_list.append(sentence_str) 40 | 41 | random.shuffle(sentence_list) 42 | 43 | semisupervised_split = int(len(sentence_list) * semisupervised_split_ratio) 44 | sentence_list_l = sentence_list[:semisupervised_split] 45 | sentence_list_u = sentence_list[semisupervised_split:] 46 | 47 | tagger = MeCab.Tagger() 48 | tagger.parse("") 49 | for sentence_str in sentence_list_l: 50 | m = tagger.parseToNode(sentence_str) # 形態素解析 51 | words = [] 52 | while m: 53 | word = m.surface 54 | if len(word) > 0: 55 | words.append(word) 56 | m = m.next 57 | if len(words) > 0: 58 | corpus.add_true_segmentation(words) 59 | 60 | for sentence_str in sentence_list_u: 61 | corpus.add_sentence(sentence_str) 62 | 63 | return corpus 64 | 65 | 66 | def main(): 67 | parser = argparse.ArgumentParser() 68 | # 以下のどちらかを必ず指定 69 | parser.add_argument( 70 | "--train-filename", 71 | "-file", 72 | type=str, 73 | default=None, 74 | help="訓練用のテキストファイルのパス") 75 | parser.add_argument( 76 | "--train-directory", 77 | "-dir", 78 | 
type=str, 79 | default=None, 80 | help="訓練用のテキストファイルが入っているディレクトリ") 81 | 82 | parser.add_argument("--seed", type=int, default=1) 83 | parser.add_argument( 84 | "--epochs", "-e", type=int, default=100000, help="総epoch") 85 | parser.add_argument( 86 | "--working-directory", 87 | "-cwd", 88 | type=str, 89 | default="out", 90 | help="ワーキングディレクトリ") 91 | parser.add_argument( 92 | "--train-split", 93 | "-train-split", 94 | type=float, 95 | default=0.9, 96 | help="テキストデータの何割を訓練データにするか") 97 | parser.add_argument( 98 | "--semisupervised-split", 99 | "-ssl-split", 100 | type=float, 101 | default=0.1, 102 | help="テキストデータの何割を教師データにするか") 103 | 104 | parser.add_argument("--lambda-a", "-lam-a", type=float, default=4) 105 | parser.add_argument("--lambda-b", "-lam-b", type=float, default=1) 106 | parser.add_argument( 107 | "--vpylm-beta-stop", "-beta-stop", type=float, default=4) 108 | parser.add_argument( 109 | "--vpylm-beta-pass", "-beta-pass", type=float, default=1) 110 | parser.add_argument( 111 | "--max-word-length", "-l", type=int, default=16, help="可能な単語の最大長.") 112 | args = parser.parse_args() 113 | 114 | assert args.working_directory is not None 115 | try: 116 | os.mkdir(args.working_directory) 117 | except: 118 | pass 119 | 120 | # 訓練データを追加 121 | corpus = build_corpus(args.train_filename, args.train_directory, 122 | args.semisupervised_split) 123 | dataset = npylm.dataset(corpus, args.train_split, args.seed) 124 | 125 | print("#train", dataset.get_num_sentences_train()) 126 | print("#train (supervised)", dataset.get_num_sentences_supervised()) 127 | print("#dev", dataset.get_num_sentences_dev()) 128 | 129 | # 単語辞書を保存 130 | dictionary = dataset.get_dict() 131 | dictionary.save(os.path.join(args.working_directory, "npylm.dict")) 132 | 133 | # モデル 134 | model = npylm.model(dataset, args.max_word_length) # 可能な単語の最大長を指定 135 | 136 | # ハイパーパラメータの設定 137 | model.set_initial_lambda_a(args.lambda_a) 138 | model.set_initial_lambda_b(args.lambda_b) 139 | model.set_vpylm_beta_stop(args.vpylm_beta_stop) 140 | model.set_vpylm_beta_pass(args.vpylm_beta_pass) 141 | 142 | # 学習の準備 143 | trainer = npylm.trainer(dataset, model) 144 | 145 | # 文字列の単語IDが衝突しているかどうかをチェック 146 | # 時間の無駄なので一度したらしなくてよい 147 | # メモリを大量に消費します 148 | if True: 149 | print("ハッシュの衝突を確認中 ...") 150 | num_checked_words = dataset.detect_hash_collision(args.max_word_length) 151 | print("衝突はありません (総単語数 {})".format(num_checked_words)) 152 | 153 | # 学習ループ 154 | for epoch in range(1, args.epochs + 1): 155 | start = time.time() 156 | trainer.gibbs() # ギブスサンプリング 157 | trainer.sample_hpylm_vpylm_hyperparameters( 158 | ) # HPYLMとVPYLMのハイパーパラメータの更新 159 | trainer.sample_lambda() # λの更新 160 | 161 | # p(k|VPYLM)の推定は数イテレーション後にやるほうが精度が良い 162 | if epoch > 3: 163 | trainer.update_p_k_given_vpylm() 164 | 165 | model.save(os.path.join(args.working_directory, "npylm.model")) 166 | 167 | # ログ 168 | elapsed_time = time.time() - start 169 | printr("Iteration {} / {} - {:.3f} sec".format(epoch, args.epochs, 170 | elapsed_time)) 171 | if epoch % 10 == 0: 172 | printr("") 173 | trainer.print_segmentation_train(10) 174 | print("ppl_dev: {}".format(trainer.compute_perplexity_dev())) 175 | 176 | 177 | if __name__ == "__main__": 178 | main() 179 | -------------------------------------------------------------------------------- /run/semi-supervised/viterbi.py: -------------------------------------------------------------------------------- 1 | import argparse, os, codecs, sys 2 | import npylm 3 | 4 | 5 | def main(): 6 | parser = argparse.ArgumentParser() 7 | # 以下のどちらかを必ず指定 8 | 
parser.add_argument( 9 | "--input-filename", 10 | "-file", 11 | type=str, 12 | default=None, 13 | help="訓練用のテキストファイルのパス") 14 | parser.add_argument( 15 | "--input-directory", 16 | "-dir", 17 | type=str, 18 | default=None, 19 | help="訓練用のテキストファイルが入っているディレクトリ") 20 | 21 | parser.add_argument( 22 | "--working-directory", 23 | "-cwd", 24 | type=str, 25 | default="out", 26 | help="ワーキングディレクトリ") 27 | parser.add_argument( 28 | "--output-directory", "-out", type=str, default="out", help="分割結果の出力先") 29 | args = parser.parse_args() 30 | 31 | try: 32 | os.mkdir(args.output_dir) 33 | except: 34 | pass 35 | 36 | model = npylm.model(os.path.join(args.working_directory, "npylm.model")) 37 | 38 | if args.input_filename is not None: 39 | segmentation_list = [] 40 | with codecs.open(args.input_filename, "r", "utf-8") as f: 41 | for sentence_str in f: 42 | sentence_str = sentence_str.strip() 43 | segmentation = model.parse(sentence_str) 44 | if len(segmentation) > 0: 45 | segmentation_list.append(segmentation) 46 | 47 | filename = args.input_filename.split("/")[-1] 48 | with codecs.open( 49 | os.path.join(args.output_directory, filename), "w", 50 | "utf-8") as f: 51 | for segmentation in segmentation_list: 52 | f.write(" ".join(segmentation)) 53 | f.write("\n") 54 | 55 | if args.input_directory is not None: 56 | for filename in os.listdir(args.input_directory): 57 | print("processing {} ...".format(filename)) 58 | segmentation_list = [] 59 | with codecs.open( 60 | os.path.join(args.input_directory, filename), "r", 61 | "utf-8") as f: 62 | for sentence_str in f: 63 | sentence_str = sentence_str.strip() 64 | segmentation = model.parse(sentence_str) 65 | if len(segmentation) > 0: 66 | segmentation_list.append(segmentation) 67 | 68 | with codecs.open( 69 | os.path.join(args.output_directory, filename), "w", 70 | "utf-8") as f: 71 | for segmentation in segmentation_list: 72 | f.write(" ".join(segmentation)) 73 | f.write("\n") 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /run/unsupervised/print_parameters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import codecs 3 | import os 4 | import sys 5 | 6 | import npylm 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--working-directory", 13 | "-cwd", 14 | type=str, 15 | default="out", 16 | help="ワーキングディレクトリ") 17 | args = parser.parse_args() 18 | 19 | model = npylm.model(os.path.join(args.working_directory, "npylm.model")) 20 | lambda_list = model.get_lambda() 21 | word_types = [ 22 | "アルファベット", "数字", "記号", "ひらがな", "カタカナ", "漢字", "漢字+ひらがな", "漢字+カタカナ", 23 | "その他" 24 | ] 25 | for wtype, lam in zip(word_types, lambda_list): 26 | print(wtype, lam) 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /run/unsupervised/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import codecs 3 | import os 4 | import sys 5 | import time 6 | 7 | import npylm 8 | 9 | 10 | class stdout: 11 | BOLD = "\033[1m" 12 | END = "\033[0m" 13 | CLEAR = "\033[2K" 14 | 15 | 16 | def printb(string): 17 | print(stdout.BOLD + string + stdout.END) 18 | 19 | 20 | def printr(string): 21 | sys.stdout.write("\r" + stdout.CLEAR) 22 | sys.stdout.write(string) 23 | sys.stdout.flush() 24 | 25 | 26 | def build_corpus(filepath, directory): 27 | assert filepath is not 
None or directory is not None 28 | corpus = npylm.corpus() 29 | 30 | if filepath is not None: 31 | with codecs.open(filepath, "r", "utf-8") as f: 32 | for sentence_str in f: 33 | sentence_str = sentence_str.strip() 34 | corpus.add_sentence(sentence_str) 35 | 36 | if directory is not None: 37 | for filename in os.listdir(directory): 38 | with codecs.open(os.path.join(directory, filename), "r", 39 | "utf-8") as f: 40 | for sentence_str in f: 41 | sentence_str = sentence_str.strip() 42 | corpus.add_sentence(sentence_str) 43 | 44 | return corpus 45 | 46 | 47 | def main(): 48 | parser = argparse.ArgumentParser() 49 | # 以下のどちらかを必ず指定 50 | parser.add_argument( 51 | "--train-filename", 52 | "-file", 53 | type=str, 54 | default=None, 55 | help="訓練用のテキストファイルのパス") 56 | parser.add_argument( 57 | "--train-directory", 58 | "-dir", 59 | type=str, 60 | default=None, 61 | help="訓練用のテキストファイルが入っているディレクトリ") 62 | 63 | parser.add_argument("--seed", type=int, default=1) 64 | parser.add_argument( 65 | "--epochs", "-e", type=int, default=100000, help="総epoch") 66 | parser.add_argument( 67 | "--working-directory", 68 | "-cwd", 69 | type=str, 70 | default="out", 71 | help="ワーキングディレクトリ") 72 | parser.add_argument( 73 | "--train-split", 74 | "-split", 75 | type=float, 76 | default=0.9, 77 | help="テキストデータの何割を訓練データにするか") 78 | 79 | parser.add_argument("--lambda-a", "-lam-a", type=float, default=4) 80 | parser.add_argument("--lambda-b", "-lam-b", type=float, default=1) 81 | parser.add_argument( 82 | "--vpylm-beta-stop", "-beta-stop", type=float, default=4) 83 | parser.add_argument( 84 | "--vpylm-beta-pass", "-beta-pass", type=float, default=1) 85 | parser.add_argument( 86 | "--max-word-length", "-l", type=int, default=16, help="可能な単語の最大長.") 87 | args = parser.parse_args() 88 | 89 | assert args.working_directory is not None 90 | try: 91 | os.mkdir(args.working_directory) 92 | except: 93 | pass 94 | 95 | # 訓練データを追加 96 | corpus = build_corpus(args.train_filename, args.train_directory) 97 | dataset = npylm.dataset(corpus, args.train_split, args.seed) 98 | 99 | print("#train", dataset.get_num_sentences_train()) 100 | print("#dev", dataset.get_num_sentences_dev()) 101 | 102 | # 単語辞書を保存 103 | dictionary = dataset.get_dict() 104 | dictionary.save(os.path.join(args.working_directory, "npylm.dict")) 105 | 106 | # モデル 107 | model = npylm.model(dataset, args.max_word_length) # 可能な単語の最大長を指定 108 | 109 | # ハイパーパラメータの設定 110 | model.set_initial_lambda_a(args.lambda_a) 111 | model.set_initial_lambda_b(args.lambda_b) 112 | model.set_vpylm_beta_stop(args.vpylm_beta_stop) 113 | model.set_vpylm_beta_pass(args.vpylm_beta_pass) 114 | 115 | # 学習の準備 116 | trainer = npylm.trainer(dataset, model) 117 | 118 | # 文字列の単語IDが衝突しているかどうかをチェック 119 | # 時間の無駄なので一度したらしなくてよい 120 | # メモリを大量に消費します 121 | if True: 122 | print("ハッシュの衝突を確認中 ...") 123 | num_checked_words = dataset.detect_hash_collision(args.max_word_length) 124 | print("衝突はありません (総単語数 {})".format(num_checked_words)) 125 | 126 | # 学習ループ 127 | for epoch in range(1, args.epochs + 1): 128 | start = time.time() 129 | trainer.gibbs() # ギブスサンプリング 130 | trainer.sample_hpylm_vpylm_hyperparameters( 131 | ) # HPYLMとVPYLMのハイパーパラメータの更新 132 | trainer.sample_lambda() # λの更新 133 | 134 | # p(k|VPYLM)の推定は数イテレーション後にやるほうが精度が良い 135 | if epoch > 3: 136 | trainer.update_p_k_given_vpylm() 137 | 138 | model.save(os.path.join(args.working_directory, "npylm.model")) 139 | 140 | # ログ 141 | elapsed_time = time.time() - start 142 | printr("Iteration {} / {} - {:.3f} sec".format(epoch, args.epochs, 143 | elapsed_time)) 144 | if epoch 
% 10 == 0: 145 | printr("") 146 | trainer.print_segmentation_train(10) 147 | print("ppl_dev: {}".format(trainer.compute_perplexity_dev())) 148 | 149 | 150 | if __name__ == "__main__": 151 | main() 152 | -------------------------------------------------------------------------------- /run/unsupervised/viterbi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import codecs 3 | import os 4 | import sys 5 | 6 | import npylm 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | # 以下のどちらかを必ず指定 12 | parser.add_argument( 13 | "--input-filename", 14 | "-file", 15 | type=str, 16 | default=None, 17 | help="訓練用のテキストファイルのパス") 18 | parser.add_argument( 19 | "--input-directory", 20 | "-dir", 21 | type=str, 22 | default=None, 23 | help="訓練用のテキストファイルが入っているディレクトリ") 24 | 25 | parser.add_argument( 26 | "--working-directory", 27 | "-cwd", 28 | type=str, 29 | default="out", 30 | help="ワーキングディレクトリ") 31 | parser.add_argument( 32 | "--output-directory", "-out", type=str, default="out", help="分割結果の出力先") 33 | args = parser.parse_args() 34 | 35 | try: 36 | os.mkdir(args.output_dir) 37 | except: 38 | pass 39 | 40 | model = npylm.model(os.path.join(args.working_directory, "npylm.model")) 41 | 42 | if args.input_filename is not None: 43 | segmentation_list = [] 44 | with codecs.open(args.input_filename, "r", "utf-8") as f: 45 | for sentence_str in f: 46 | sentence_str = sentence_str.strip() 47 | segmentation = model.parse(sentence_str) 48 | if len(segmentation) > 0: 49 | segmentation_list.append(segmentation) 50 | 51 | filename = args.input_filename.split("/")[-1] 52 | with codecs.open( 53 | os.path.join(args.output_directory, filename), "w", 54 | "utf-8") as f: 55 | for segmentation in segmentation_list: 56 | f.write(" ".join(segmentation)) 57 | f.write("\n") 58 | 59 | if args.input_directory is not None: 60 | for filename in os.listdir(args.input_directory): 61 | print("processing {} ...".format(filename)) 62 | segmentation_list = [] 63 | with codecs.open( 64 | os.path.join(args.input_directory, filename), "r", 65 | "utf-8") as f: 66 | for sentence_str in f: 67 | sentence_str = sentence_str.strip() 68 | segmentation = model.parse(sentence_str) 69 | if len(segmentation) > 0: 70 | segmentation_list.append(segmentation) 71 | 72 | with codecs.open( 73 | os.path.join(args.output_directory, filename), "w", 74 | "utf-8") as f: 75 | for segmentation in segmentation_list: 76 | f.write(" ".join(segmentation)) 77 | f.write("\n") 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /src/npylm/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "hashmap/hashmap.h" 4 | #include "hashmap/flat_hashmap.h" 5 | 6 | #ifdef __NO_INLINE__ 7 | #define __DEBUG__ 1 8 | #endif 9 | 10 | template 11 | using hashmap = ska::flat_hash_map>; // これが最速 12 | // using hashmap = ska::flat_hash_map; 13 | // using hashmap = emilib::HashMap; 14 | // using hashmap = std::unordered_map; 15 | 16 | using id = size_t; 17 | 18 | #define HPYLM_INITIAL_D 0.5 19 | #define HPYLM_INITIAL_THETA 2.0 20 | #define HPYLM_BETA_A 1.0 21 | #define HPYLM_BETA_B 1.0 22 | #define HPYLM_GAMMA_ALPHA 1.0 23 | #define HPYLM_GAMMA_BETA 1.0 24 | 25 | #define VPYLM_BETA_STOP 4 26 | #define VPYLM_BETA_PASS 1 27 | #define VPYLM_EPS 1e-12 28 | 29 | #define ID_BOS 0 30 | #define ID_BOW 0 31 | #define ID_EOS 1 32 | #define ID_EOW 2 
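// Quick reference for the definitions above (the BOW/EOW roles are an assumption,
// they are not documented in this header):
// - `hashmap` aliases ska::flat_hash_map, the fastest of the candidate containers listed
//   in the commented-out lines, and is used throughout the library for id -> value tables.
// - `id` (= size_t) is the word-ID type produced by the substring hashing in hash.h.
// - ID_BOS / ID_EOS mark sentence boundaries in the word-level HPYLM; ID_BOW / ID_EOW
//   presumably mark word boundaries in the character-level VPYLM.
// A minimal usage sketch of the alias:
//
//     hashmap<id, int> freq;        // word ID -> frequency
//     freq[some_word_id] += 1;      // same interface as std::unordered_map
//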
-------------------------------------------------------------------------------- /src/npylm/ctype.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #define CTYPE_UNKNOWN 0 5 | #define CTYPE_BASIC_LATIN 1 6 | #define CTYPE_LATIN_1_SUPPLEMENT 2 7 | #define CTYPE_LATIN_EXTENDED_A 3 8 | #define CTYPE_LATIN_EXTENDED_B 4 9 | #define CTYPE_IPA_EXTENSIONS 5 10 | #define CTYPE_SPACING_MODIFIER_LETTERS 6 11 | #define CTYPE_COMBINING_DIACRITICAL_MARKS 7 12 | #define CTYPE_GREEK_AND_COPTIC 8 13 | #define CTYPE_CYRILLIC 9 14 | #define CTYPE_CYRILLIC_SUPPLEMENT 10 15 | #define CTYPE_ARMENIAN 11 16 | #define CTYPE_HEBREW 12 17 | #define CTYPE_ARABIC 13 18 | #define CTYPE_SYRIAC 14 19 | #define CTYPE_ARABIC_SUPPLEMENT 15 20 | #define CTYPE_THAANA 16 21 | #define CTYPE_NKO 17 22 | #define CTYPE_SAMARITAN 18 23 | #define CTYPE_MANDAIC 19 24 | #define CTYPE_SYRIAC_SUPPLEMENT 20 25 | #define CTYPE_ARABIC_EXTENDED_A 21 26 | #define CTYPE_DEVANAGARI 22 27 | #define CTYPE_BENGALI 23 28 | #define CTYPE_GURMUKHI 24 29 | #define CTYPE_GUJARATI 25 30 | #define CTYPE_ORIYA 26 31 | #define CTYPE_TAMIL 27 32 | #define CTYPE_TELUGU 28 33 | #define CTYPE_KANNADA 29 34 | #define CTYPE_MALAYALAM 30 35 | #define CTYPE_SINHALA 31 36 | #define CTYPE_THAI 32 37 | #define CTYPE_LAO 33 38 | #define CTYPE_TIBETAN 34 39 | #define CTYPE_MYANMAR 35 40 | #define CTYPE_GEORGIAN 36 41 | #define CTYPE_HANGUL_JAMO 37 42 | #define CTYPE_ETHIOPIC 38 43 | #define CTYPE_ETHIOPIC_SUPPLEMENT 39 44 | #define CTYPE_CHEROKEE 40 45 | #define CTYPE_UNIFIED_CANADIAN_ABORIGINAL_SYLLABICS 41 46 | #define CTYPE_OGHAM 42 47 | #define CTYPE_RUNIC 43 48 | #define CTYPE_TAGALOG 44 49 | #define CTYPE_HANUNOO 45 50 | #define CTYPE_BUHID 46 51 | #define CTYPE_TAGBANWA 47 52 | #define CTYPE_KHMER 48 53 | #define CTYPE_MONGOLIAN 49 54 | #define CTYPE_UNIFIED_CANADIAN_ABORIGINAL_SYLLABICS_EXTENDED 50 55 | #define CTYPE_LIMBU 51 56 | #define CTYPE_TAI_LE 52 57 | #define CTYPE_NEW_TAI_LUE 53 58 | #define CTYPE_KHMER_SYMBOLS 54 59 | #define CTYPE_BUGINESE 55 60 | #define CTYPE_TAI_THAM 56 61 | #define CTYPE_COMBINING_DIACRITICAL_MARKS_EXTENDED 57 62 | #define CTYPE_BALINESE 58 63 | #define CTYPE_SUNDANESE 59 64 | #define CTYPE_BATAK 60 65 | #define CTYPE_LEPCHA 61 66 | #define CTYPE_OL_CHIKI 62 67 | #define CTYPE_CYRILLIC_EXTENDED_C 63 68 | #define CTYPE_SUNDANESE_SUPPLEMENT 64 69 | #define CTYPE_VEDIC_EXTENSIONS 65 70 | #define CTYPE_PHONETIC_EXTENSIONS 66 71 | #define CTYPE_PHONETIC_EXTENSIONS_SUPPLEMENT 67 72 | #define CTYPE_COMBINING_DIACRITICAL_MARKS_SUPPLEMENT 68 73 | #define CTYPE_LATIN_EXTENDED_ADDITIONAL 69 74 | #define CTYPE_GREEK_EXTENDED 70 75 | #define CTYPE_GENERAL_PUNCTUATION 71 76 | #define CTYPE_SUPERSCRIPTS_AND_SUBSCRIPTS 72 77 | #define CTYPE_CURRENCY_SYMBOLS 73 78 | #define CTYPE_COMBINING_DIACRITICAL_MARKS_FOR_SYMBOLS 74 79 | #define CTYPE_LETTERLIKE_SYMBOLS 75 80 | #define CTYPE_NUMBER_FORMS 76 81 | #define CTYPE_ARROWS 77 82 | #define CTYPE_MATHEMATICAL_OPERATORS 78 83 | #define CTYPE_MISCELLANEOUS_TECHNICAL 79 84 | #define CTYPE_CONTROL_PICTURES 80 85 | #define CTYPE_OPTICAL_CHARACTER_RECOGNITION 81 86 | #define CTYPE_ENCLOSED_ALPHANUMERICS 82 87 | #define CTYPE_BOX_DRAWING 83 88 | #define CTYPE_BLOCK_ELEMENTS 84 89 | #define CTYPE_GEOMETRIC_SHAPES 85 90 | #define CTYPE_MISCELLANEOUS_SYMBOLS 86 91 | #define CTYPE_DINGBATS 87 92 | #define CTYPE_MISCELLANEOUS_MATHEMATICAL_SYMBOLS_A 88 93 | #define CTYPE_SUPPLEMENTAL_ARROWS_A 89 94 | #define CTYPE_BRAILLE_PATTERNS 90 95 | 
#define CTYPE_SUPPLEMENTAL_ARROWS_B 91 96 | #define CTYPE_MISCELLANEOUS_MATHEMATICAL_SYMBOLS_B 92 97 | #define CTYPE_SUPPLEMENTAL_MATHEMATICAL_OPERATORS 93 98 | #define CTYPE_MISCELLANEOUS_SYMBOLS_AND_ARROWS 94 99 | #define CTYPE_GLAGOLITIC 95 100 | #define CTYPE_LATIN_EXTENDED_C 96 101 | #define CTYPE_COPTIC 97 102 | #define CTYPE_GEORGIAN_SUPPLEMENT 98 103 | #define CTYPE_TIFINAGH 99 104 | #define CTYPE_ETHIOPIC_EXTENDED 100 105 | #define CTYPE_CYRILLIC_EXTENDED_A 101 106 | #define CTYPE_SUPPLEMENTAL_PUNCTUATION 102 107 | #define CTYPE_CJK_RADICALS_SUPPLEMENT 103 108 | #define CTYPE_KANGXI_RADICALS 104 109 | #define CTYPE_IDEOGRAPHIC_DESCRIPTION_CHARACTERS 105 110 | #define CTYPE_CJK_SYMBOLS_AND_PUNCTUATION 106 111 | #define CTYPE_HIRAGANA 107 112 | #define CTYPE_KATAKANA 108 113 | #define CTYPE_BOPOMOFO 109 114 | #define CTYPE_HANGUL_COMPATIBILITY_JAMO 110 115 | #define CTYPE_KANBUN 111 116 | #define CTYPE_BOPOMOFO_EXTENDED 112 117 | #define CTYPE_CJK_STROKES 113 118 | #define CTYPE_KATAKANA_PHONETIC_EXTENSIONS 114 119 | #define CTYPE_ENCLOSED_CJK_LETTERS_AND_MONTHS 115 120 | #define CTYPE_CJK_COMPATIBILITY 116 121 | #define CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_A 117 122 | #define CTYPE_YIJING_HEXAGRAM_SYMBOLS 118 123 | #define CTYPE_CJK_UNIFIED_IDEOGRAPHS 119 124 | #define CTYPE_YI_SYLLABLES 120 125 | #define CTYPE_YI_RADICALS 121 126 | #define CTYPE_LISU 122 127 | #define CTYPE_VAI 123 128 | #define CTYPE_CYRILLIC_EXTENDED_B 124 129 | #define CTYPE_BAMUM 125 130 | #define CTYPE_MODIFIER_TONE_LETTERS 126 131 | #define CTYPE_LATIN_EXTENDED_D 127 132 | #define CTYPE_SYLOTI_NAGRI 128 133 | #define CTYPE_COMMON_INDIC_NUMBER_FORMS 129 134 | #define CTYPE_PHAGS_PA 130 135 | #define CTYPE_SAURASHTRA 131 136 | #define CTYPE_DEVANAGARI_EXTENDED 132 137 | #define CTYPE_KAYAH_LI 133 138 | #define CTYPE_REJANG 134 139 | #define CTYPE_HANGUL_JAMO_EXTENDED_A 135 140 | #define CTYPE_JAVANESE 136 141 | #define CTYPE_MYANMAR_EXTENDED_B 137 142 | #define CTYPE_CHAM 138 143 | #define CTYPE_MYANMAR_EXTENDED_A 139 144 | #define CTYPE_TAI_VIET 140 145 | #define CTYPE_MEETEI_MAYEK_EXTENSIONS 141 146 | #define CTYPE_ETHIOPIC_EXTENDED_A 142 147 | #define CTYPE_LATIN_EXTENDED_E 143 148 | #define CTYPE_CHEROKEE_SUPPLEMENT 144 149 | #define CTYPE_MEETEI_MAYEK 145 150 | #define CTYPE_HANGUL_SYLLABLES 146 151 | #define CTYPE_HANGUL_JAMO_EXTENDED_B 147 152 | #define CTYPE_HIGH_SURROGATES 148 153 | #define CTYPE_HIGH_PRIVATE_USE_SURROGATES 149 154 | #define CTYPE_LOW_SURROGATES 150 155 | #define CTYPE_PRIVATE_USE_AREA 151 156 | #define CTYPE_CJK_COMPATIBILITY_IDEOGRAPHS 152 157 | #define CTYPE_ALPHABETIC_PRESENTATION_FORMS 153 158 | #define CTYPE_ARABIC_PRESENTATION_FORMS_A 154 159 | #define CTYPE_VARIATION_SELECTORS 155 160 | #define CTYPE_VERTICAL_FORMS 156 161 | #define CTYPE_COMBINING_HALF_MARKS 157 162 | #define CTYPE_CJK_COMPATIBILITY_FORMS 158 163 | #define CTYPE_SMALL_FORM_VARIANTS 159 164 | #define CTYPE_ARABIC_PRESENTATION_FORMS_B 160 165 | #define CTYPE_HALFWIDTH_AND_FULLWIDTH_FORMS 161 166 | #define CTYPE_SPECIALS 162 167 | #define CTYPE_LINEAR_B_SYLLABARY 163 168 | #define CTYPE_LINEAR_B_IDEOGRAMS 164 169 | #define CTYPE_AEGEAN_NUMBERS 165 170 | #define CTYPE_ANCIENT_GREEK_NUMBERS 166 171 | #define CTYPE_ANCIENT_SYMBOLS 167 172 | #define CTYPE_PHAISTOS_DISC 168 173 | #define CTYPE_LYCIAN 169 174 | #define CTYPE_CARIAN 170 175 | #define CTYPE_COPTIC_EPACT_NUMBERS 171 176 | #define CTYPE_OLD_ITALIC 172 177 | #define CTYPE_GOTHIC 173 178 | #define CTYPE_OLD_PERMIC 174 179 | #define CTYPE_UGARITIC 175 180 
| #define CTYPE_OLD_PERSIAN 176 181 | #define CTYPE_DESERET 177 182 | #define CTYPE_SHAVIAN 178 183 | #define CTYPE_OSMANYA 179 184 | #define CTYPE_OSAGE 180 185 | #define CTYPE_ELBASAN 181 186 | #define CTYPE_CAUCASIAN_ALBANIAN 182 187 | #define CTYPE_LINEAR_A 183 188 | #define CTYPE_CYPRIOT_SYLLABARY 184 189 | #define CTYPE_IMPERIAL_ARAMAIC 185 190 | #define CTYPE_PALMYRENE 186 191 | #define CTYPE_NABATAEAN 187 192 | #define CTYPE_HATRAN 188 193 | #define CTYPE_PHOENICIAN 189 194 | #define CTYPE_LYDIAN 190 195 | #define CTYPE_MEROITIC_HIEROGLYPHS 191 196 | #define CTYPE_MEROITIC_CURSIVE 192 197 | #define CTYPE_KHAROSHTHI 193 198 | #define CTYPE_OLD_SOUTH_ARABIAN 194 199 | #define CTYPE_OLD_NORTH_ARABIAN 195 200 | #define CTYPE_MANICHAEAN 196 201 | #define CTYPE_AVESTAN 197 202 | #define CTYPE_INSCRIPTIONAL_PARTHIAN 198 203 | #define CTYPE_INSCRIPTIONAL_PAHLAVI 199 204 | #define CTYPE_PSALTER_PAHLAVI 200 205 | #define CTYPE_OLD_TURKIC 201 206 | #define CTYPE_OLD_HUNGARIAN 202 207 | #define CTYPE_RUMI_NUMERAL_SYMBOLS 203 208 | #define CTYPE_BRAHMI 204 209 | #define CTYPE_KAITHI 205 210 | #define CTYPE_SORA_SOMPENG 206 211 | #define CTYPE_CHAKMA 207 212 | #define CTYPE_MAHAJANI 208 213 | #define CTYPE_SHARADA 209 214 | #define CTYPE_SINHALA_ARCHAIC_NUMBERS 210 215 | #define CTYPE_KHOJKI 211 216 | #define CTYPE_MULTANI 212 217 | #define CTYPE_KHUDAWADI 213 218 | #define CTYPE_GRANTHA 214 219 | #define CTYPE_NEWA 215 220 | #define CTYPE_TIRHUTA 216 221 | #define CTYPE_SIDDHAM 217 222 | #define CTYPE_MODI 218 223 | #define CTYPE_MONGOLIAN_SUPPLEMENT 219 224 | #define CTYPE_TAKRI 220 225 | #define CTYPE_AHOM 221 226 | #define CTYPE_WARANG_CITI 222 227 | #define CTYPE_ZANABAZAR_SQUARE 223 228 | #define CTYPE_SOYOMBO 224 229 | #define CTYPE_PAU_CIN_HAU 225 230 | #define CTYPE_BHAIKSUKI 226 231 | #define CTYPE_MARCHEN 227 232 | #define CTYPE_MASARAM_GONDI 228 233 | #define CTYPE_CUNEIFORM 229 234 | #define CTYPE_CUNEIFORM_NUMBERS_AND_PUNCTUATION 230 235 | #define CTYPE_EARLY_DYNASTIC_CUNEIFORM 231 236 | #define CTYPE_EGYPTIAN_HIEROGLYPHS 232 237 | #define CTYPE_ANATOLIAN_HIEROGLYPHS 233 238 | #define CTYPE_BAMUM_SUPPLEMENT 234 239 | #define CTYPE_MRO 235 240 | #define CTYPE_BASSA_VAH 236 241 | #define CTYPE_PAHAWH_HMONG 237 242 | #define CTYPE_MIAO 238 243 | #define CTYPE_IDEOGRAPHIC_SYMBOLS_AND_PUNCTUATION 239 244 | #define CTYPE_TANGUT 240 245 | #define CTYPE_TANGUT_COMPONENTS 241 246 | #define CTYPE_KANA_SUPPLEMENT 242 247 | #define CTYPE_KANA_EXTENDED_A 243 248 | #define CTYPE_NUSHU 244 249 | #define CTYPE_DUPLOYAN 245 250 | #define CTYPE_SHORTHAND_FORMAT_CONTROLS 246 251 | #define CTYPE_BYZANTINE_MUSICAL_SYMBOLS 247 252 | #define CTYPE_MUSICAL_SYMBOLS 248 253 | #define CTYPE_ANCIENT_GREEK_MUSICAL_NOTATION 249 254 | #define CTYPE_TAI_XUAN_JING_SYMBOLS 250 255 | #define CTYPE_COUNTING_ROD_NUMERALS 251 256 | #define CTYPE_MATHEMATICAL_ALPHANUMERIC_SYMBOLS 252 257 | #define CTYPE_SUTTON_SIGNWRITING 253 258 | #define CTYPE_GLAGOLITIC_SUPPLEMENT 254 259 | #define CTYPE_MENDE_KIKAKUI 255 260 | #define CTYPE_ADLAM 256 261 | #define CTYPE_ARABIC_MATHEMATICAL_ALPHABETIC_SYMBOLS 257 262 | #define CTYPE_MAHJONG_TILES 258 263 | #define CTYPE_DOMINO_TILES 259 264 | #define CTYPE_PLAYING_CARDS 260 265 | #define CTYPE_ENCLOSED_ALPHANUMERIC_SUPPLEMENT 261 266 | #define CTYPE_ENCLOSED_IDEOGRAPHIC_SUPPLEMENT 262 267 | #define CTYPE_MISCELLANEOUS_SYMBOLS_AND_PICTOGRAPHS 263 268 | #define CTYPE_EMOTICONS 264 269 | #define CTYPE_ORNAMENTAL_DINGBATS 265 270 | #define CTYPE_TRANSPORT_AND_MAP_SYMBOLS 266 271 | 
#define CTYPE_ALCHEMICAL_SYMBOLS 267
272 | #define CTYPE_GEOMETRIC_SHAPES_EXTENDED 268
273 | #define CTYPE_SUPPLEMENTAL_ARROWS_C 269
274 | #define CTYPE_SUPPLEMENTAL_SYMBOLS_AND_PICTOGRAPHS 270
275 | #define CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_B 271
276 | #define CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_C 272
277 | #define CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_D 273
278 | #define CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_E 274
279 | #define CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_F 275
280 | #define CTYPE_CJK_COMPATIBILITY_IDEOGRAPHS_SUPPLEMENT 276
281 | #define CTYPE_TAGS 277
282 | #define CTYPE_VARIATION_SELECTORS_SUPPLEMENT 278
283 | #define CTYPE_SUPPLEMENTARY_PRIVATE_USE_AREA_A 279
284 | #define CTYPE_SUPPLEMENTARY_PRIVATE_USE_AREA_B 280
285 | 
286 | namespace npylm {
287 |     namespace ctype {
288 |         unsigned int get_type(wchar_t c);
289 |         std::string get_name(unsigned int type);
290 |     }
291 | }
--------------------------------------------------------------------------------
/src/npylm/hash.cpp:
--------------------------------------------------------------------------------
1 | #include "hash.h"
2 | 
3 | namespace npylm{
4 |     size_t load_bytes(const char* p, int n){ // assembles the n trailing bytes into a size_t
5 |         size_t result = 0;
6 |         --n;
7 |         do{
8 |             result = (result << 8) + static_cast<unsigned char>(p[n]);
9 |         }while(--n >= 0);
10 |         return result;
11 |     }
12 |     size_t shift_mix(size_t v){
13 |         return v ^ (v >> 47);
14 |     }
15 |     size_t unaligned_load(const char* p){
16 |         size_t result;
17 |         __builtin_memcpy(&result, p, sizeof(result));
18 |         return result;
19 |     }
20 | #if __SIZEOF_SIZE_T__ == 4
21 |     size_t hash_bytes(const void* ptr, size_t len){ // 32-bit branch (MurmurHash2-style)
22 |         const size_t m = 0x5bd1e995, seed = static_cast<size_t>(0xc70f6907UL); // Murmur multiplier and fixed seed
23 |         size_t hash = seed ^ len;
24 |         const char* buf = static_cast<const char*>(ptr);
25 |         while(len >= 4){
26 |             size_t k = unaligned_load(buf);
27 |             k *= m;
28 |             k ^= k >> 24;
29 |             k *= m;
30 |             hash *= m;
31 |             hash ^= k;
32 |             buf += 4;
33 |             len -= 4;
34 |         }
35 |         switch(len){ // mix in the remaining bytes (intentional fallthrough)
36 |             case 3:
37 |                 hash ^= static_cast<unsigned char>(buf[2]) << 16;
38 |             case 2:
39 |                 hash ^= static_cast<unsigned char>(buf[1]) << 8;
40 |             case 1:
41 |                 hash ^= static_cast<unsigned char>(buf[0]);
42 |                 hash *= m;
43 |         };
44 |         hash ^= hash >> 13;
45 |         hash *= m;
46 |         hash ^= hash >> 15;
47 |         return hash;
48 |     }
49 | #elif __SIZEOF_SIZE_T__ == 8
50 |     size_t hash_bytes(const void* ptr, size_t len){ // 64-bit branch (Murmur-inspired hash)
51 |         size_t seed = static_cast<size_t>(0xc70f6907UL);
52 |         static const size_t mul = (((size_t) 0xc6a4a793UL) << 32UL) + (size_t) 0x5bd1e995UL;
53 |         const char* const buf = static_cast<const char*>(ptr);
54 |         const int len_aligned = len & ~0x7;
55 |         const char* const end = buf + len_aligned;
56 |         size_t hash = seed ^ (len * mul);
57 |         for (const char* p = buf; p != end; p += 8){
58 |             const size_t data = shift_mix(unaligned_load(p) * mul) * mul;
59 |             hash ^= data;
60 |             hash *= mul;
61 |         }
62 |         if ((len & 0x7) != 0){
63 |             const size_t data = load_bytes(end, len & 0x7);
64 |             hash ^= data;
65 |             hash *= mul;
66 |         }
67 |         hash = shift_mix(hash) * mul;
68 |         hash = shift_mix(hash);
69 |         return hash;
70 |     }
71 | #endif
72 |     size_t hash_wstring(const std::wstring &str){
73 |         return hash_bytes(str.data(), str.size() * sizeof(wchar_t));
74 |     }
75 |     size_t hash_substring_ptr(wchar_t const* ptr, int start, int end){ // end is inclusive
76 |         return hash_bytes(ptr + start, (end - start + 1) * sizeof(wchar_t));
77 |     }
78 |     size_t hash_substring(const std::wstring &str, int start, int end){ // end is inclusive
79 |         wchar_t const* ptr = str.data();
80 |         return hash_substring_ptr(ptr, start, end);
81 |     }
82 | } // namespace npylm
--------------------------------------------------------------------------------
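The routines above are what turn character substrings into word IDs: `hash_substring(sentence, start, end)` hashes the wide characters from `start` to `end` (inclusive), and the resulting `size_t` is used as the word ID. Because the hash is taken over the raw `wchar_t` bytes, and `sizeof(wchar_t)` differs between platforms, this is presumably why the README warns that a model saved on one machine may not segment correctly on another. A minimal usage sketch follows; it is not part of the repository, and the build line and include path are assumptions:

```
// sketch.cpp -- minimal illustration of NPYLM's word-ID hashing (hypothetical example).
// Build (assumption): g++ -std=c++14 sketch.cpp src/npylm/hash.cpp -Isrc -o sketch
#include <iostream>
#include <string>
#include "npylm/hash.h"

int main() {
    std::wstring sentence = L"ネコとネコ";
    // Endpoints are inclusive, so positions 0..1 and 3..4 both cover "ネコ"
    // and therefore map to the same word ID.
    size_t id_a = npylm::hash_substring(sentence, 0, 1);
    size_t id_b = npylm::hash_substring(sentence, 3, 4);
    std::cout << (id_a == id_b) << std::endl; // prints 1
    // Distinct substrings can still collide in principle; that is what
    // dataset.detect_hash_collision() in the training scripts checks for.
    return 0;
}
```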
/src/npylm/hash.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace npylm { 5 | size_t load_bytes(const char* p, int n); 6 | size_t shift_mix(size_t v); 7 | size_t unaligned_load(const char* p); 8 | #if __SIZEOF_SIZE_T__ == 4 9 | size_t hash_bytes(const void* ptr, size_t len, size_t seed); 10 | #elif __SIZEOF_SIZE_T__ == 8 11 | size_t hash_bytes(const void* ptr, size_t len); 12 | #endif 13 | size_t hash_wstring(const std::wstring &str); 14 | size_t hash_substring_ptr(wchar_t const* ptr, int start, int end); // endを含む 15 | size_t hash_substring(const std::wstring &str, int start, int end); // endを含む 16 | } -------------------------------------------------------------------------------- /src/npylm/hashmap/hashmap.h: -------------------------------------------------------------------------------- 1 | // By Emil Ernerfeldt 2014-2016 2 | // LICENSE: 3 | // This software is dual-licensed to the public domain and under the following 4 | // license: you are granted a perpetual, irrevocable license to copy, modify, 5 | // publish, and distribute this file as you see fit. 6 | #pragma once 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include "../common.h" 17 | 18 | namespace emilib { 19 | 20 | enum class State : uint8_t 21 | { 22 | INACTIVE, // Never been touched 23 | ACTIVE, // Is inside a search-chain, but is empty 24 | FILLED // Is set with key/value 25 | }; 26 | 27 | /// like std::equal_to but no need to #include 28 | template 29 | struct HashMapEqualTo 30 | { 31 | constexpr bool operator()(const T &lhs, const T &rhs) const 32 | { 33 | return lhs == rhs; 34 | } 35 | }; 36 | 37 | /// A cache-friendly hash table with open addressing, linear probing and power-of-two capacity 38 | template , typename CompT = HashMapEqualTo> 39 | class HashMap 40 | { 41 | private: 42 | using MyType = HashMap; 43 | using PairT = std::pair; 44 | public: 45 | using size_type = size_t; 46 | using value_type = PairT; 47 | using reference = PairT&; 48 | using const_reference = const PairT&; 49 | 50 | class iterator 51 | { 52 | public: 53 | using iterator_category = std::forward_iterator_tag; 54 | using difference_type = size_t; 55 | using distance_type = size_t; 56 | using value_type = std::pair; 57 | using pointer = value_type*; 58 | using reference = value_type&; 59 | 60 | iterator() { } 61 | 62 | iterator(MyType* hash_map, size_t bucket) : _map(hash_map), _bucket(bucket) 63 | { 64 | } 65 | 66 | iterator& operator++() 67 | { 68 | this->goto_next_element(); 69 | return *this; 70 | } 71 | 72 | iterator operator++(int) 73 | { 74 | size_t old_index = _bucket; 75 | this->goto_next_element(); 76 | return iterator(_map, old_index); 77 | } 78 | 79 | reference operator*() const 80 | { 81 | return _map->_pairs[_bucket]; 82 | } 83 | 84 | pointer operator->() const 85 | { 86 | return _map->_pairs + _bucket; 87 | } 88 | 89 | bool operator==(const iterator& rhs) 90 | { 91 | assert(_map == rhs._map); 92 | return this->_bucket == rhs._bucket; 93 | } 94 | 95 | bool operator!=(const iterator& rhs) 96 | { 97 | assert(_map == rhs._map); 98 | return this->_bucket != rhs._bucket; 99 | } 100 | 101 | private: 102 | void goto_next_element() 103 | { 104 | assert(_bucket < _map->_num_buckets); 105 | do { 106 | _bucket++; 107 | } while (_bucket < _map->_num_buckets && _map->_states[_bucket] != State::FILLED); 108 | } 109 | 110 | //private: 111 | // friend class MyType; 112 | public: 113 | 
MyType* _map; 114 | size_t _bucket; 115 | }; 116 | 117 | class const_iterator 118 | { 119 | public: 120 | using iterator_category = std::forward_iterator_tag; 121 | using difference_type = size_t; 122 | using distance_type = size_t; 123 | using value_type = const std::pair; 124 | using pointer = value_type*; 125 | using reference = value_type&; 126 | 127 | const_iterator() { } 128 | 129 | const_iterator(iterator proto) : _map(proto._map), _bucket(proto._bucket) 130 | { 131 | } 132 | 133 | const_iterator(const MyType* hash_map, size_t bucket) : _map(hash_map), _bucket(bucket) 134 | { 135 | } 136 | 137 | const_iterator& operator++() 138 | { 139 | this->goto_next_element(); 140 | return *this; 141 | } 142 | 143 | const_iterator operator++(int) 144 | { 145 | size_t old_index = _bucket; 146 | this->goto_next_element(); 147 | return const_iterator(_map, old_index); 148 | } 149 | 150 | reference operator*() const 151 | { 152 | return _map->_pairs[_bucket]; 153 | } 154 | 155 | pointer operator->() const 156 | { 157 | return _map->_pairs + _bucket; 158 | } 159 | 160 | bool operator==(const const_iterator& rhs) 161 | { 162 | assert(_map == rhs._map); 163 | return this->_bucket == rhs._bucket; 164 | } 165 | 166 | bool operator!=(const const_iterator& rhs) 167 | { 168 | assert(_map == rhs._map); 169 | return this->_bucket != rhs._bucket; 170 | } 171 | 172 | private: 173 | void goto_next_element() 174 | { 175 | assert(_bucket < _map->_num_buckets); 176 | do { 177 | _bucket++; 178 | } while (_bucket < _map->_num_buckets && _map->_states[_bucket] != State::FILLED); 179 | } 180 | 181 | //private: 182 | // friend class MyType; 183 | public: 184 | const MyType* _map; 185 | size_t _bucket; 186 | }; 187 | 188 | // ------------------------------------------------------------------------ 189 | 190 | HashMap() = default; 191 | 192 | HashMap(const HashMap& other) 193 | { 194 | reserve(other.size()); 195 | // insert(cbegin(other), cend(other)); 196 | insert(other.begin(), other.end()); 197 | } 198 | 199 | HashMap(HashMap&& other) 200 | { 201 | *this = std::move(other); 202 | } 203 | 204 | HashMap& operator=(const HashMap& other) 205 | { 206 | clear(); 207 | reserve(other.size()); 208 | // insert(cbegin(other), cend(other)); 209 | insert(other.begin(), other.end()); 210 | return *this; 211 | } 212 | 213 | void operator=(HashMap&& other) 214 | { 215 | this->swap(other); 216 | } 217 | 218 | ~HashMap() 219 | { 220 | for (size_t bucket=0; bucket<_num_buckets; ++bucket) { 221 | if (_states[bucket] == State::FILLED) { 222 | _pairs[bucket].~PairT(); 223 | } 224 | } 225 | free(_states); 226 | free(_pairs); 227 | } 228 | 229 | void swap(HashMap& other) 230 | { 231 | std::swap(_hasher, other._hasher); 232 | std::swap(_comp, other._comp); 233 | std::swap(_states, other._states); 234 | std::swap(_pairs, other._pairs); 235 | std::swap(_num_buckets, other._num_buckets); 236 | std::swap(_num_filled, other._num_filled); 237 | std::swap(_max_probe_length, other._max_probe_length); 238 | std::swap(_mask, other._mask); 239 | } 240 | 241 | // ------------------------------------------------------------- 242 | 243 | iterator begin() 244 | { 245 | size_t bucket = 0; 246 | while (bucket<_num_buckets && _states[bucket] != State::FILLED) { 247 | ++bucket; 248 | } 249 | return iterator(this, bucket); 250 | } 251 | 252 | const_iterator begin() const 253 | { 254 | size_t bucket = 0; 255 | while (bucket<_num_buckets && _states[bucket] != State::FILLED) { 256 | ++bucket; 257 | } 258 | return const_iterator(this, bucket); 259 | } 260 | 261 | 
iterator end() 262 | { return iterator(this, _num_buckets); } 263 | 264 | const_iterator end() const 265 | { return const_iterator(this, _num_buckets); } 266 | 267 | size_t size() const 268 | { 269 | return _num_filled; 270 | } 271 | 272 | bool empty() const 273 | { 274 | return _num_filled==0; 275 | } 276 | 277 | // ------------------------------------------------------------ 278 | 279 | iterator find(const KeyT& key) 280 | { 281 | auto bucket = this->find_filled_bucket(key); 282 | if (bucket == (size_t)-1) { 283 | return this->end(); 284 | } 285 | return iterator(this, bucket); 286 | } 287 | 288 | const_iterator find(const KeyT& key) const 289 | { 290 | auto bucket = this->find_filled_bucket(key); 291 | if (bucket == (size_t)-1) 292 | { 293 | return this->end(); 294 | } 295 | return const_iterator(this, bucket); 296 | } 297 | 298 | bool contains(const KeyT& k) const 299 | { 300 | return find_filled_bucket(k) != (size_t)-1; 301 | } 302 | 303 | size_t count(const KeyT& k) const 304 | { 305 | return find_filled_bucket(k) != (size_t)-1 ? 1 : 0; 306 | } 307 | 308 | /// Returns the matching ValueT or nullptr if k isn't found. 309 | ValueT* try_get(const KeyT& k) 310 | { 311 | auto bucket = find_filled_bucket(k); 312 | if (bucket != (size_t)-1) { 313 | return &_pairs[bucket].second; 314 | } else { 315 | return nullptr; 316 | } 317 | } 318 | 319 | /// Const version of the above 320 | const ValueT* try_get(const KeyT& k) const 321 | { 322 | auto bucket = find_filled_bucket(k); 323 | if (bucket != (size_t)-1) { 324 | return &_pairs[bucket].second; 325 | } else { 326 | return nullptr; 327 | } 328 | } 329 | 330 | /// Convenience function. 331 | const ValueT get_or_return_default(const KeyT& k) const 332 | { 333 | const ValueT* ret = try_get(k); 334 | if (ret) { 335 | return *ret; 336 | } else { 337 | return ValueT(); 338 | } 339 | } 340 | 341 | // ----------------------------------------------------- 342 | 343 | /// Returns a pair consisting of an iterator to the inserted element 344 | /// (or to the element that prevented the insertion) 345 | /// and a bool denoting whether the insertion took place. 346 | std::pair insert(const KeyT& key, const ValueT& value) 347 | { 348 | check_expand_need(); 349 | 350 | auto bucket = find_or_allocate(key); 351 | 352 | if (_states[bucket] == State::FILLED) { 353 | return { iterator(this, bucket), false }; 354 | } else { 355 | _states[bucket] = State::FILLED; 356 | new(_pairs + bucket) PairT(key, value); 357 | _num_filled++; 358 | return { iterator(this, bucket), true }; 359 | } 360 | } 361 | 362 | std::pair insert(const std::pair& p) 363 | { 364 | return insert(p.first, p.second); 365 | } 366 | 367 | void insert(const_iterator begin, const_iterator end) 368 | { 369 | for (; begin != end; ++begin) { 370 | insert(begin->first, begin->second); 371 | } 372 | } 373 | 374 | /// Same as above, but contains(key) MUST be false 375 | void insert_unique(KeyT&& key, ValueT&& value) 376 | { 377 | assert(!contains(key)); 378 | check_expand_need(); 379 | auto bucket = find_empty_bucket(key); 380 | _states[bucket] = State::FILLED; 381 | new(_pairs + bucket) PairT(std::move(key), std::move(value)); 382 | _num_filled++; 383 | } 384 | 385 | void insert_unique(std::pair&& p) 386 | { 387 | insert_unique(std::move(p.first), std::move(p.second)); 388 | } 389 | 390 | /// Return the old value or ValueT() if it didn't exist. 
391 | ValueT set_get(const KeyT& key, const ValueT& new_value) 392 | { 393 | check_expand_need(); 394 | 395 | auto bucket = find_or_allocate(key); 396 | 397 | // Check if inserting a new value rather than overwriting an old entry 398 | if (_states[bucket] == State::FILLED) { 399 | ValueT old_value = _pairs[bucket].second; 400 | _pairs[bucket].second = new_value; 401 | return old_value; 402 | } else { 403 | _states[bucket] = State::FILLED; 404 | new(_pairs + bucket) PairT(key, new_value); 405 | _num_filled++; 406 | return ValueT(); 407 | } 408 | } 409 | 410 | /// Like std::map::operator[]. 411 | ValueT& operator[](const KeyT& key) 412 | { 413 | check_expand_need(); 414 | 415 | auto bucket = find_or_allocate(key); 416 | 417 | /* Check if inserting a new value rather than overwriting an old entry */ 418 | if (_states[bucket] != State::FILLED) { 419 | _states[bucket] = State::FILLED; 420 | new(_pairs + bucket) PairT(key, ValueT()); 421 | _num_filled++; 422 | } 423 | 424 | return _pairs[bucket].second; 425 | } 426 | 427 | // ------------------------------------------------------- 428 | 429 | /// Erase an element from the hash table. 430 | /// Returns false if the element was not found. 431 | bool erase(const KeyT& key) 432 | { 433 | auto bucket = find_filled_bucket(key); 434 | if (bucket != (size_t)-1) { 435 | _states[bucket] = State::ACTIVE; 436 | _pairs[bucket].~PairT(); 437 | _num_filled -= 1; 438 | return true; 439 | } else { 440 | return false; 441 | } 442 | } 443 | 444 | /// Erase an element using an iterator. 445 | /// Returns an iterator to the next element (or end()). 446 | iterator erase(iterator it) 447 | { 448 | assert(it._map == this); 449 | assert(it._bucket < _num_buckets); 450 | _states[it._bucket] = State::ACTIVE; 451 | _pairs[it._bucket].~PairT(); 452 | _num_filled -= 1; 453 | return ++it; 454 | } 455 | 456 | /// Remove all elements, keeping full capacity. 
457 | void clear() 458 | { 459 | for (size_t bucket=0; bucket<_num_buckets; ++bucket) { 460 | if (_states[bucket] == State::FILLED) { 461 | _states[bucket] = State::INACTIVE; 462 | _pairs[bucket].~PairT(); 463 | } 464 | } 465 | _num_filled = 0; 466 | _max_probe_length = -1; 467 | } 468 | 469 | /// Make room for this many elements 470 | void reserve(size_t num_elems) 471 | { 472 | size_t required_buckets = num_elems + num_elems/2 + 1; 473 | if (required_buckets <= _num_buckets) { 474 | return; 475 | } 476 | size_t num_buckets = 4; 477 | while (num_buckets < required_buckets) { num_buckets *= 2; } 478 | 479 | auto new_states = (State*)malloc(num_buckets * sizeof(State)); 480 | auto new_pairs = (PairT*)malloc(num_buckets * sizeof(PairT)); 481 | 482 | if (!new_states || !new_pairs) { 483 | free(new_states); 484 | free(new_pairs); 485 | throw std::bad_alloc(); 486 | } 487 | 488 | //auto old_num_filled = _num_filled; 489 | auto old_num_buckets = _num_buckets; 490 | auto old_states = _states; 491 | auto old_pairs = _pairs; 492 | 493 | _num_filled = 0; 494 | _num_buckets = num_buckets; 495 | _mask = _num_buckets - 1; 496 | _states = new_states; 497 | _pairs = new_pairs; 498 | 499 | std::fill_n(_states, num_buckets, State::INACTIVE); 500 | 501 | _max_probe_length = -1; 502 | 503 | for (size_t src_bucket=0; src_bucket _max_probe_length) { 599 | _max_probe_length = offset; 600 | } 601 | return bucket; 602 | } 603 | } 604 | } 605 | 606 | public: 607 | HashT _hasher; 608 | CompT _comp; 609 | State* _states = nullptr; 610 | PairT* _pairs = nullptr; 611 | size_t _num_buckets = 0; 612 | size_t _num_filled = 0; 613 | int _max_probe_length = -1; // Our longest bucket-brigade is this long. ONLY when we have zero elements is this ever negative (-1). 614 | size_t _mask = 0; // _num_buckets minus one 615 | 616 | friend class boost::serialization::access; 617 | template 618 | void serialize(Archive &archive, unsigned int version) 619 | { 620 | boost::serialization::split_free(archive, *this, version); 621 | } 622 | }; 623 | } // namespace emilib 624 | 625 | namespace boost { namespace serialization { 626 | template, typename CompT = emilib::HashMapEqualTo> 627 | void save(Archive &archive, const emilib::HashMap &hmap, unsigned int version) { 628 | archive & hmap.size(); 629 | for(auto itr = hmap.begin();itr != hmap.end();itr++){ 630 | archive & itr->first; 631 | archive & itr->second; 632 | } 633 | } 634 | template, typename CompT = emilib::HashMapEqualTo> 635 | void load(Archive &archive, emilib::HashMap &hmap, unsigned int version) { 636 | size_t map_size = 0; 637 | archive & map_size; 638 | hmap.clear(); 639 | for(int i = 0;i < map_size;i++){ 640 | KeyT key; 641 | ValueT value; 642 | archive & key; 643 | archive & value; 644 | hmap[key] = value; 645 | } 646 | } 647 | }} // namespace boost::serialization 648 | -------------------------------------------------------------------------------- /src/npylm/lattice.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "common.h" 4 | #include "npylm.h" 5 | 6 | namespace npylm { 7 | namespace lattice { 8 | void _init_alpha(double*** &alpha, int size, int max_word_length); 9 | void _delete_alpha(double*** &alpha, int size, int max_word_length); 10 | } 11 | class Lattice { 12 | private: 13 | void _allocate_capacity(int max_word_length, int max_sentence_length); 14 | void _delete_capacity(); 15 | public: 16 | NPYLM* _npylm; 17 | id* _word_ids; 18 | id** _substring_word_id_cache; 19 | double*** 
_alpha; // 前向き確率 20 | double**** _pw_h; // キャッシュ 21 | double* _log_z; // 正規化定数 22 | double* _scaling; // スケーリング係数 23 | double* _backward_sampling_table; 24 | int*** _viterbi_backward; 25 | int _max_word_length; 26 | int _max_sentence_length; 27 | Lattice(NPYLM* npylm, int max_word_length, int max_sentence_length); 28 | ~Lattice(); 29 | void reserve(int max_word_length, int max_sentence_length); 30 | id get_substring_word_id_at_t_k(Sentence* sentence, int t, int k); 31 | void blocked_gibbs(Sentence* sentence, std::vector &segments, bool use_scaling = true); 32 | void viterbi_argmax_alpha_t_k_j(Sentence* sentence, int t, int k, int j); 33 | void viterbi_forward(Sentence* sentence); 34 | void viterbi_argmax_backward_k_and_j_to_eos(Sentence* sentence, int t, int next_word_length, int &argmax_k, int &argmax_j); 35 | void viterbi_backward(Sentence* sentence, std::vector &segments); 36 | void viterbi_decode(Sentence* sentence, std::vector &segments); 37 | double compute_log_forward_probability(Sentence* sentence, bool use_scaling); 38 | void _enumerate_forward_variables(Sentence* sentence, double*** alpha, double* scaling, bool use_scaling); 39 | void _sum_alpha_t_k_j(Sentence* sentence, double*** alpha, double**** pw_h_t_k_j_i, int t, int k, int j, double prod_scaling); 40 | void _forward_filtering(Sentence* sentence, double*** alpha, double* scaling, double**** pw_h_t_k_j_i, bool use_scaling = true); 41 | void _backward_sampling(Sentence* sentence, std::vector &segments, double*** alpha, double**** pw_h_t_k_j_i); 42 | void _sample_backward_k_and_j(Sentence* sentence, double*** alpha, double**** pw_h_t_k_j_i, int t, int next_word_length, int &sampled_k, int &sampled_j); 43 | }; 44 | } // namespace npylm -------------------------------------------------------------------------------- /src/npylm/lm/hpylm.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "../sampler.h" 7 | #include "hpylm.h" 8 | 9 | namespace npylm { 10 | namespace lm { 11 | HPYLM::HPYLM(int ngram){ 12 | // 深さは0から始まることに注意 13 | // 2-gramなら最大深さは1. root(0) -> 2-gram(1) 14 | // 3-gramなら最大深さは2. 
root(0) -> 2-gram(1) -> 3-gram(2) 15 | _depth = ngram - 1; 16 | 17 | _root = new Node(0); 18 | _root->_depth = 0; // ルートは深さ0 19 | 20 | for(int n = 0;n < ngram;n++){ 21 | _d_m.push_back(HPYLM_INITIAL_D); 22 | _theta_m.push_back(HPYLM_INITIAL_THETA); 23 | _a_m.push_back(HPYLM_BETA_A); 24 | _b_m.push_back(HPYLM_BETA_B); 25 | _alpha_m.push_back(HPYLM_GAMMA_ALPHA); 26 | _beta_m.push_back(HPYLM_GAMMA_BETA); 27 | } 28 | } 29 | HPYLM::~HPYLM(){ 30 | _delete_node(_root); 31 | } 32 | template 33 | void HPYLM::serialize(Archive& archive, unsigned int version) 34 | { 35 | archive & _root; 36 | archive & _depth; 37 | archive & _g0; 38 | archive & _d_m; 39 | archive & _theta_m; 40 | archive & _a_m; 41 | archive & _b_m; 42 | archive & _alpha_m; 43 | archive & _beta_m; 44 | } 45 | template void HPYLM::serialize(boost::archive::binary_iarchive &ar, unsigned int version); 46 | template void HPYLM::serialize(boost::archive::binary_oarchive &ar, unsigned int version); 47 | } 48 | } -------------------------------------------------------------------------------- /src/npylm/lm/hpylm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "../common.h" 5 | #include "model.h" 6 | #include "node.h" 7 | 8 | namespace npylm { 9 | namespace lm { 10 | class HPYLM: public Model { 11 | private: 12 | friend class boost::serialization::access; 13 | template 14 | void serialize(Archive& archive, unsigned int version); 15 | public: 16 | HPYLM(int ngram = 2); 17 | ~HPYLM(); 18 | }; 19 | } 20 | } -------------------------------------------------------------------------------- /src/npylm/lm/model.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "../common.h" 9 | #include "../sampler.h" 10 | #include "node.h" 11 | 12 | namespace npylm { 13 | namespace lm { 14 | template 15 | class Model { 16 | public: 17 | Node* _root; // 文脈木のルートノード 18 | int _depth; // 最大の深さ. 
HPYLMは固定 19 | double _g0; // ゼログラム確率 20 | // 深さmのノードに関するパラメータ 21 | std::vector _d_m; // Pitman-Yor過程のディスカウント係数 22 | std::vector _theta_m; // Pitman-Yor過程の集中度 23 | // "A Bayesian Interpretation of Interpolated Kneser-Ney" Appendix C参照 24 | // http://www.gatsby.ucl.ac.uk/~ywteh/research/compling/hpylm.pdf 25 | std::vector _a_m; // ベータ分布のパラメータ dの推定用 26 | std::vector _b_m; // ベータ分布のパラメータ dの推定用 27 | std::vector _alpha_m; // ガンマ分布のパラメータ θの推定用 28 | std::vector _beta_m; // ガンマ分布のパラメータ θの推定用 29 | void _delete_node(Node* node){ 30 | for(auto &elem: node->_children){ 31 | Node* child = elem.second; 32 | _delete_node(child); 33 | } 34 | delete node; 35 | } 36 | int get_num_nodes(){ 37 | return _root->get_num_nodes() + 1; 38 | } 39 | int get_num_customers(){ 40 | return _root->get_num_customers(); 41 | } 42 | int get_num_tables(){ 43 | return _root->get_num_tables(); 44 | } 45 | int get_sum_stop_counts(){ 46 | return _root->sum_stop_counts(); 47 | } 48 | int get_sum_pass_counts(){ 49 | return _root->sum_pass_counts(); 50 | } 51 | void set_g0(double g0){ 52 | _g0 = g0; 53 | } 54 | void init_hyperparameters_at_depth_if_needed(int depth){ 55 | if(depth >= _d_m.size()){ 56 | while(_d_m.size() <= depth){ 57 | _d_m.push_back(HPYLM_INITIAL_D); 58 | } 59 | } 60 | if(depth >= _theta_m.size()){ 61 | while(_theta_m.size() <= depth){ 62 | _theta_m.push_back(HPYLM_INITIAL_THETA); 63 | } 64 | } 65 | if(depth >= _a_m.size()){ 66 | while(_a_m.size() <= depth){ 67 | _a_m.push_back(HPYLM_BETA_A); 68 | } 69 | } 70 | if(depth >= _b_m.size()){ 71 | while(_b_m.size() <= depth){ 72 | _b_m.push_back(HPYLM_BETA_B); 73 | } 74 | } 75 | if(depth >= _alpha_m.size()){ 76 | while(_alpha_m.size() <= depth){ 77 | _alpha_m.push_back(HPYLM_GAMMA_ALPHA); 78 | } 79 | } 80 | if(depth >= _beta_m.size()){ 81 | while(_beta_m.size() <= depth){ 82 | _beta_m.push_back(HPYLM_GAMMA_BETA); 83 | } 84 | } 85 | } 86 | // "A Bayesian Interpretation of Interpolated Kneser-Ney" Appendix C参照 87 | // http://www.gatsby.ucl.ac.uk/~ywteh/research/compling/hpylm.pdf 88 | void sum_auxiliary_variables_recursively(Node* node, std::vector &sum_log_x_u_m, std::vector &sum_y_ui_m, std::vector &sum_1_y_ui_m, std::vector &sum_1_z_uwkj_m, int &bottom){ 89 | for(auto elem: node->_children){ 90 | Node* child = elem.second; 91 | int depth = child->_depth; 92 | 93 | if(depth > bottom){ 94 | bottom = depth; 95 | } 96 | init_hyperparameters_at_depth_if_needed(depth); 97 | 98 | double d = _d_m[depth]; 99 | double theta = _theta_m[depth]; 100 | sum_log_x_u_m[depth] += child->auxiliary_log_x_u(theta); // log(x_u) 101 | sum_y_ui_m[depth] += child->auxiliary_y_ui(d, theta); // y_ui 102 | sum_1_y_ui_m[depth] += child->auxiliary_1_y_ui(d, theta); // 1 - y_ui 103 | sum_1_z_uwkj_m[depth] += child->auxiliary_1_z_uwkj(d); // 1 - z_uwkj 104 | 105 | sum_auxiliary_variables_recursively(child, sum_log_x_u_m, sum_y_ui_m, sum_1_y_ui_m, sum_1_z_uwkj_m, bottom); 106 | } 107 | } 108 | // dとθの推定 109 | void sample_hyperparams(){ 110 | int max_depth = _d_m.size() - 1; 111 | 112 | // 親ノードの深さが0であることに注意 113 | std::vector sum_log_x_u_m(max_depth + 1, 0.0); 114 | std::vector sum_y_ui_m(max_depth + 1, 0.0); 115 | std::vector sum_1_y_ui_m(max_depth + 1, 0.0); 116 | std::vector sum_1_z_uwkj_m(max_depth + 1, 0.0); 117 | 118 | // _root 119 | sum_log_x_u_m[0] = _root->auxiliary_log_x_u(_theta_m[0]); // log(x_u) 120 | sum_y_ui_m[0] = _root->auxiliary_y_ui(_d_m[0], _theta_m[0]); // y_ui 121 | sum_1_y_ui_m[0] = _root->auxiliary_1_y_ui(_d_m[0], _theta_m[0]); // 1 - y_ui 122 | sum_1_z_uwkj_m[0] = 
_root->auxiliary_1_z_uwkj(_d_m[0]); // 1 - z_uwkj 123 | 124 | // それ以外 125 | _depth = 0; 126 | // __depthは以下を実行すると更新される 127 | // HPYLMでは無意味だがVPYLMで最大深さを求める時に使う 128 | sum_auxiliary_variables_recursively(_root, sum_log_x_u_m, sum_y_ui_m, sum_1_y_ui_m, sum_1_z_uwkj_m, _depth); 129 | init_hyperparameters_at_depth_if_needed(_depth); 130 | 131 | for(int u = 0;u <= _depth;u++){ 132 | _d_m[u] = sampler::beta(_a_m[u] + sum_1_y_ui_m[u], _b_m[u] + sum_1_z_uwkj_m[u]); 133 | _theta_m[u] = sampler::gamma(_alpha_m[u] + sum_y_ui_m[u], _beta_m[u] - sum_log_x_u_m[u]); 134 | } 135 | // 不要な深さのハイパーパラメータを削除 136 | int num_remove = _d_m.size() - _depth - 1; 137 | for(int n = 0;n < num_remove;n++){ 138 | _d_m.pop_back(); 139 | _theta_m.pop_back(); 140 | _a_m.pop_back(); 141 | _b_m.pop_back(); 142 | _alpha_m.pop_back(); 143 | _beta_m.pop_back(); 144 | } 145 | } 146 | }; 147 | template class Model; 148 | template class Model; 149 | } // namespace lm 150 | } // namespace npylm -------------------------------------------------------------------------------- /src/npylm/lm/node.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include "../common.h" 15 | #include "../sampler.h" 16 | 17 | namespace npylm { 18 | namespace lm { 19 | template 20 | class Node { 21 | public: 22 | hashmap _children; // 子の文脈木 23 | hashmap> _arrangement; // 客の配置 std::vectorのk番目の要素がテーブルkの客数を表す 24 | Node* _parent; // 親ノード 25 | int _num_tables; // 総テーブル数 26 | int _num_customers; // 客の総数 27 | int _stop_count; // 停止回数. VPYLM用 28 | int _pass_count; // 通過回数. VPYLM用 29 | int _depth; // ノードの深さ. rootが0であることに注意 30 | T _token_id; // このノードに割り当てられた単語ID(または文字ID) 31 | Node(){} 32 | Node(T token_id){ 33 | _num_tables = 0; 34 | _num_customers = 0; 35 | _stop_count = 0; 36 | _pass_count = 0; 37 | _token_id = token_id; 38 | _parent = NULL; 39 | } 40 | bool parent_exists(){ 41 | return !(_parent == NULL); 42 | } 43 | bool child_exists(T token_id){ 44 | return !(_children.find(token_id) == _children.end()); 45 | } 46 | bool need_to_remove_from_parent(){ 47 | if(_parent == NULL){ 48 | return false; 49 | } 50 | if(_children.size() == 0 && _arrangement.size() == 0){ 51 | return true; 52 | } 53 | return false; 54 | } 55 | int get_num_tables_serving_word(T token_id){ 56 | if(_arrangement.find(token_id) == _arrangement.end()){ 57 | return 0; 58 | } 59 | return _arrangement[token_id].size(); 60 | } 61 | int get_num_customers_eating_word(T token_id){ 62 | if(_arrangement.find(token_id) == _arrangement.end()){ 63 | return 0; 64 | } 65 | std::vector &num_customers_at_table = _arrangement[token_id]; 66 | int sum = 0; 67 | for(int i = 0;i < num_customers_at_table.size();i++){ 68 | sum += num_customers_at_table[i]; 69 | } 70 | return sum; 71 | } 72 | Node* find_child_node(T token_id, bool generate_if_not_exist = false){ 73 | auto itr = _children.find(token_id); 74 | if (itr != _children.end()) { 75 | return itr->second; 76 | } 77 | if(generate_if_not_exist == false){ 78 | return NULL; 79 | } 80 | Node* child = new Node(token_id); 81 | child->_parent = this; 82 | child->_depth = _depth + 1; 83 | _children[token_id] = child; 84 | return child; 85 | } 86 | // 客をテーブルに追加 87 | bool add_customer_to_table(T token_id, int table_k, double g0, std::vector &d_m, std::vector &theta_m, int &added_to_table_k_of_root){ 88 | auto itr = _arrangement.find(token_id); 89 | if(itr == 
_arrangement.end()){ 90 | return add_customer_to_new_table(token_id, g0, d_m, theta_m, added_to_table_k_of_root); 91 | } 92 | std::vector &num_customers_at_table = itr->second; 93 | assert(table_k < num_customers_at_table.size()); 94 | num_customers_at_table[table_k]++; 95 | _num_customers++; 96 | return true; 97 | } 98 | bool add_customer_to_table(T token_id, int table_k, double* parent_pw_at_depth, std::vector &d_m, std::vector &theta_m, int &added_to_table_k_of_root){ 99 | auto itr = _arrangement.find(token_id); 100 | if(itr == _arrangement.end()){ 101 | return add_customer_to_new_table(token_id, parent_pw_at_depth, d_m, theta_m, added_to_table_k_of_root); 102 | } 103 | std::vector &num_customers_at_table = itr->second; 104 | assert(table_k < num_customers_at_table.size()); 105 | num_customers_at_table[table_k]++; 106 | _num_customers++; 107 | return true; 108 | } 109 | bool add_customer_to_new_table(T token_id, double g0, std::vector &d_m, std::vector &theta_m, int &added_to_table_k_of_root){ 110 | _add_customer_to_new_table(token_id); 111 | if(_parent != NULL){ 112 | bool success = _parent->add_customer(token_id, g0, d_m, theta_m, false, added_to_table_k_of_root); 113 | assert(success == true); 114 | } 115 | return true; 116 | } 117 | bool add_customer_to_new_table(T token_id, double* parent_pw_at_depth, std::vector &d_m, std::vector &theta_m, int &added_to_table_k_of_root){ 118 | _add_customer_to_new_table(token_id); 119 | if(_parent != NULL){ 120 | bool success = _parent->add_customer(token_id, parent_pw_at_depth, d_m, theta_m, false, added_to_table_k_of_root); 121 | assert(success == true); 122 | } 123 | return true; 124 | } 125 | void _add_customer_to_new_table(T token_id){ 126 | auto itr = _arrangement.find(token_id); 127 | if(itr == _arrangement.end()){ 128 | std::vector tables = {1}; 129 | _arrangement[token_id] = tables; 130 | }else{ 131 | std::vector &num_customers_at_table = itr->second; 132 | num_customers_at_table.push_back(1); 133 | } 134 | _num_tables++; 135 | _num_customers++; 136 | } 137 | bool remove_customer_from_table(T token_id, int table_k, int &removed_from_table_k_of_root){ 138 | auto itr = _arrangement.find(token_id); 139 | assert(itr != _arrangement.end()); 140 | std::vector &num_customers_at_table = itr->second; 141 | assert(table_k < num_customers_at_table.size()); 142 | num_customers_at_table[table_k]--; 143 | _num_customers--; 144 | assert(num_customers_at_table[table_k] >= 0); 145 | if(num_customers_at_table[table_k] == 0){ 146 | if(_parent != NULL){ 147 | bool success = _parent->remove_customer(token_id, false, removed_from_table_k_of_root); 148 | assert(success == true); 149 | } 150 | num_customers_at_table.erase(num_customers_at_table.begin() + table_k); 151 | _num_tables--; 152 | if(num_customers_at_table.size() == 0){ 153 | _arrangement.erase(token_id); 154 | } 155 | } 156 | return true; 157 | } 158 | // NPYLMではルートノードのどのテーブルに代理客が追加されたかを知る必要がある 159 | // 再帰計算があるので遅い 160 | bool add_customer(T token_id, double g0, std::vector &d_m, std::vector &theta_m, bool update_beta_count, int &added_to_table_k_of_root){ 161 | init_hyperparameters_at_depth_if_needed(_depth, d_m, theta_m); 162 | double d_u = d_m[_depth]; 163 | double theta_u = theta_m[_depth]; 164 | double parent_pw = g0; 165 | if(_parent){ 166 | parent_pw = _parent->compute_p_w(token_id, g0, d_m, theta_m); 167 | } 168 | auto itr = _arrangement.find(token_id); 169 | if(itr == _arrangement.end()){ 170 | add_customer_to_new_table(token_id, g0, d_m, theta_m, added_to_table_k_of_root); 171 | 
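// Note: token_id has never been served at this node, so the customer necessarily opens a new
// table here; add_customer_to_new_table() above also seats a proxy customer at the parent node,
// recursively up to the root. When this node is the root, that first table has index 0, which is
// why added_to_table_k_of_root is set to 0 just below.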
if(update_beta_count == true){ 172 | increment_stop_count(); 173 | } 174 | if(_depth == 0){ // if root node 175 | added_to_table_k_of_root = 0; 176 | } 177 | return true; 178 | } 179 | 180 | std::vector &num_customers_at_table = itr->second; 181 | double sum = 0; 182 | for(int k = 0;k < num_customers_at_table.size();k++){ 183 | sum += std::max(0.0, num_customers_at_table[k] - d_u); 184 | } 185 | double t_u = _num_tables; 186 | sum += (theta_u + d_u * t_u) * parent_pw; 187 | 188 | double normalizer = 1.0 / sum; 189 | double bernoulli = sampler::uniform(0, 1); 190 | double stack = 0; 191 | for(int k = 0;k < num_customers_at_table.size();k++){ 192 | stack += std::max(0.0, num_customers_at_table[k] - d_u) * normalizer; 193 | if(bernoulli <= stack){ 194 | add_customer_to_table(token_id, k, g0, d_m, theta_m, added_to_table_k_of_root); 195 | if(update_beta_count){ 196 | increment_stop_count(); 197 | } 198 | if(_depth == 0){ 199 | added_to_table_k_of_root = k; 200 | } 201 | return true; 202 | } 203 | } 204 | add_customer_to_new_table(token_id, g0, d_m, theta_m, added_to_table_k_of_root); 205 | if(update_beta_count){ 206 | increment_stop_count(); 207 | } 208 | if(_depth == 0){ 209 | added_to_table_k_of_root = num_customers_at_table.size() - 1; 210 | } 211 | return true; 212 | } 213 | // NPYLMではルートノードのどのテーブルに代理客が追加されたかを知る必要がある 214 | // 再帰計算を防ぐため親の確率のキャッシュを使う 215 | bool add_customer(T token_id, double* parent_pw_at_depth, std::vector &d_m, std::vector &theta_m, bool update_beta_count, int &added_to_table_k_of_root){ 216 | init_hyperparameters_at_depth_if_needed(_depth, d_m, theta_m); 217 | double d_u = d_m[_depth]; 218 | double theta_u = theta_m[_depth]; 219 | double parent_pw = parent_pw_at_depth[_depth]; 220 | auto itr = _arrangement.find(token_id); 221 | if(itr == _arrangement.end()){ 222 | add_customer_to_new_table(token_id, parent_pw_at_depth, d_m, theta_m, added_to_table_k_of_root); 223 | if(update_beta_count == true){ 224 | increment_stop_count(); 225 | } 226 | if(_depth == 0){ // ルートノードの場合 227 | added_to_table_k_of_root = 0; 228 | } 229 | return true; 230 | } 231 | 232 | std::vector &num_customers_at_table = itr->second; 233 | double sum = 0; 234 | for(int k = 0;k < num_customers_at_table.size();k++){ 235 | // 分母は定数なので無視 236 | sum += std::max(0.0, num_customers_at_table[k] - d_u); 237 | } 238 | double t_u = _num_tables; 239 | sum += (theta_u + d_u * t_u) * parent_pw; 240 | 241 | double normalizer = 1.0 / sum; 242 | double bernoulli = sampler::uniform(0, 1); 243 | double stack = 0; 244 | // 既存のテーブルのどこかに追加 245 | for(int k = 0;k < num_customers_at_table.size();k++){ 246 | stack += std::max(0.0, num_customers_at_table[k] - d_u) * normalizer; 247 | if(bernoulli <= stack){ 248 | add_customer_to_table(token_id, k, parent_pw_at_depth, d_m, theta_m, added_to_table_k_of_root); 249 | if(update_beta_count){ 250 | increment_stop_count(); 251 | } 252 | if(_depth == 0){ // ルートノードの場合 253 | added_to_table_k_of_root = k; 254 | } 255 | return true; 256 | } 257 | } 258 | // 新しいテーブルに追加 259 | add_customer_to_new_table(token_id, parent_pw_at_depth, d_m, theta_m, added_to_table_k_of_root); 260 | if(update_beta_count){ 261 | increment_stop_count(); 262 | } 263 | if(_depth == 0){ 264 | added_to_table_k_of_root = num_customers_at_table.size() - 1; 265 | } 266 | return true; 267 | } 268 | // NPYLMではルートノードのどのテーブルから代理客が削除されたかを知る必要がある 269 | bool remove_customer(T token_id, bool update_beta_count, int &removed_from_table_k_of_root){ 270 | auto itr = _arrangement.find(token_id); 271 | assert(itr != _arrangement.end()); 272 | 
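// Note: the table to remove the customer from is sampled with probability proportional to its
// current number of customers, mirroring the seating probabilities of the Chinese restaurant
// process used in add_customer().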
std::vector &num_customers_at_table = itr->second; 273 | double sum = std::accumulate(num_customers_at_table.begin(), num_customers_at_table.end(), 0); 274 | double normalizer = 1.0 / sum; 275 | double bernoulli = sampler::uniform(0, 1); 276 | double stack = 0; 277 | for(int k = 0;k < num_customers_at_table.size();k++){ 278 | stack += num_customers_at_table[k] * normalizer; 279 | if(bernoulli <= stack){ 280 | remove_customer_from_table(token_id, k, removed_from_table_k_of_root); 281 | if(update_beta_count == true){ 282 | decrement_stop_count(); 283 | } 284 | if(_depth == 0){ 285 | removed_from_table_k_of_root = k; 286 | } 287 | return true; 288 | } 289 | } 290 | remove_customer_from_table(token_id, num_customers_at_table.size() - 1, removed_from_table_k_of_root); 291 | if(update_beta_count == true){ 292 | decrement_stop_count(); 293 | } 294 | if(_depth == 0){ 295 | removed_from_table_k_of_root = num_customers_at_table.size() - 1; 296 | } 297 | return true; 298 | } 299 | // 再帰計算が含まれるので遅い 300 | double compute_p_w(T token_id, double g0, std::vector &d_m, std::vector &theta_m){ 301 | init_hyperparameters_at_depth_if_needed(_depth, d_m, theta_m); 302 | double d_u = d_m[_depth]; 303 | double theta_u = theta_m[_depth]; 304 | double t_u = _num_tables; 305 | double c_u = _num_customers; 306 | auto itr = _arrangement.find(token_id); 307 | if(itr == _arrangement.end()){ 308 | double coeff = (theta_u + d_u * t_u) / (theta_u + c_u); 309 | if(_parent != NULL){ 310 | return _parent->compute_p_w(token_id, g0, d_m, theta_m) * coeff; 311 | } 312 | return g0 * coeff; 313 | } 314 | double parent_pw = g0; 315 | if(_parent != NULL){ 316 | parent_pw = _parent->compute_p_w(token_id, g0, d_m, theta_m); 317 | } 318 | std::vector &num_customers_at_table = itr->second; 319 | double c_uw = std::accumulate(num_customers_at_table.begin(), num_customers_at_table.end(), 0); 320 | double t_uw = num_customers_at_table.size(); 321 | double first_term = std::max(0.0, c_uw - d_u * t_uw) / (theta_u + c_u); 322 | double second_coeff = (theta_u + d_u * t_u) / (theta_u + c_u); 323 | return first_term + second_coeff * parent_pw; 324 | } 325 | // 再帰計算を防ぐ 326 | double compute_p_w_with_parent_p_w(T token_id, double parent_pw, std::vector &d_m, std::vector &theta_m){ 327 | init_hyperparameters_at_depth_if_needed(_depth, d_m, theta_m); 328 | double d_u = d_m[_depth]; 329 | double theta_u = theta_m[_depth]; 330 | double t_u = _num_tables; 331 | double c_u = _num_customers; 332 | auto itr = _arrangement.find(token_id); 333 | if(itr == _arrangement.end()){ 334 | double coeff = (theta_u + d_u * t_u) / (theta_u + c_u); 335 | return parent_pw * coeff; 336 | } 337 | std::vector &num_customers_at_table = itr->second; 338 | double c_uw = std::accumulate(num_customers_at_table.begin(), num_customers_at_table.end(), 0); 339 | double t_uw = num_customers_at_table.size(); 340 | double first_term = std::max(0.0, c_uw - d_u * t_uw) / (theta_u + c_u); 341 | double second_coeff = (theta_u + d_u * t_u) / (theta_u + c_u); 342 | return first_term + second_coeff * parent_pw; 343 | } 344 | // VPYLM 345 | double stop_probability(double beta_stop, double beta_pass, bool recursive = true){ 346 | double p = (_stop_count + beta_stop) / (_stop_count + _pass_count + beta_stop + beta_pass); 347 | if(recursive == false){ 348 | return p; 349 | } 350 | if(_parent != NULL){ 351 | p *= _parent->pass_probability(beta_stop, beta_pass); 352 | } 353 | return p; 354 | } 355 | // VPYLM 356 | double pass_probability(double beta_stop, double beta_pass, bool recursive = true){ 
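// Note: the pass probability at this node is
//   (_pass_count + beta_pass) / (_stop_count + _pass_count + beta_stop + beta_pass);
// with recursive == true it is additionally multiplied by the pass probabilities of all
// ancestor nodes, analogously to stop_probability() above.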
357 | double p = (_pass_count + beta_pass) / (_stop_count + _pass_count + beta_stop + beta_pass); 358 | if(recursive == false){ 359 | return p; 360 | } 361 | if(_parent != NULL){ 362 | p *= _parent->pass_probability(beta_stop, beta_pass); 363 | } 364 | return p; 365 | } 366 | // VPYLM 367 | void increment_stop_count(){ 368 | _stop_count++; 369 | if(_parent != NULL){ 370 | _parent->increment_pass_count(); 371 | } 372 | } 373 | // VPYLM 374 | void decrement_stop_count(){ 375 | _stop_count--; 376 | assert(_stop_count >= 0); 377 | if(_parent != NULL){ 378 | _parent->decrement_pass_count(); 379 | } 380 | } 381 | // VPYLM 382 | void increment_pass_count(){ 383 | _pass_count++; 384 | if(_parent != NULL){ 385 | _parent->increment_pass_count(); 386 | } 387 | } 388 | // VPYLM 389 | void decrement_pass_count(){ 390 | _pass_count--; 391 | assert(_pass_count >= 0); 392 | if(_parent != NULL){ 393 | _parent->decrement_pass_count(); 394 | } 395 | } 396 | bool remove_from_parent(){ 397 | if(_parent == NULL){ 398 | return false; 399 | } 400 | _parent->delete_child_node(_token_id); 401 | return true; 402 | } 403 | void delete_child_node(T token_id){ 404 | Node* child = find_child_node(token_id); 405 | if(child){ 406 | _children.erase(token_id); 407 | delete child; 408 | } 409 | if(_children.size() == 0 && _arrangement.size() == 0){ 410 | remove_from_parent(); 411 | } 412 | } 413 | int get_max_depth(int base){ 414 | int max_depth = base; 415 | for(auto &elem: _children){ 416 | int depth = elem.second->get_max_depth(base + 1); 417 | if(depth > max_depth){ 418 | max_depth = depth; 419 | } 420 | } 421 | return max_depth; 422 | } 423 | int get_num_nodes(){ 424 | int num = _children.size(); 425 | for(auto &elem: _children){ 426 | num += elem.second->get_num_nodes(); 427 | } 428 | return num; 429 | } 430 | int get_num_tables(){ 431 | int num = 0; 432 | for(auto &elem: _arrangement){ 433 | num += elem.second.size(); 434 | } 435 | assert(num == _num_tables); 436 | for(auto &elem: _children){ 437 | num += elem.second->get_num_tables(); 438 | } 439 | return num; 440 | } 441 | int get_num_customers(){ 442 | int num = 0; 443 | for(auto &elem: _arrangement){ 444 | num += std::accumulate(elem.second.begin(), elem.second.end(), 0); 445 | } 446 | assert(num == _num_customers); 447 | for(auto &elem: _children){ 448 | num += elem.second->get_num_customers(); 449 | } 450 | return num; 451 | } 452 | int sum_pass_counts(){ 453 | int sum = _pass_count; 454 | for(auto &elem: _children){ 455 | sum += elem.second->sum_pass_counts(); 456 | } 457 | return sum; 458 | } 459 | int sum_stop_counts(){ 460 | int sum = _stop_count; 461 | for(auto &elem: _children){ 462 | sum += elem.second->sum_stop_counts(); 463 | } 464 | return sum; 465 | } 466 | void enumerate_nodes_at_depth(int depth, std::vector &nodes){ 467 | if(_depth == depth){ 468 | nodes.push_back(this); 469 | } 470 | for(auto &elem: _children){ 471 | elem.second->enumerate_nodes_at_depth(depth, nodes); 472 | } 473 | } 474 | // dとθの推定用 475 | // "A Bayesian Interpretation of Interpolated Kneser-Ney" Appendix C参照 476 | // http://www.gatsby.ucl.ac.uk/~ywteh/research/compling/hpylm.pdf 477 | double auxiliary_log_x_u(double theta_u){ 478 | if(_num_customers >= 2){ 479 | double x_u = sampler::beta(theta_u + 1, _num_customers - 1); 480 | return log(x_u + 1e-8); 481 | } 482 | return 0; 483 | } 484 | double auxiliary_y_ui(double d_u, double theta_u){ 485 | if(_num_tables >= 2){ 486 | double sum_y_ui = 0; 487 | for(int i = 1;i <= _num_tables - 1;i++){ 488 | double denominator = theta_u + d_u 
* i; 489 | assert(denominator > 0); 490 | sum_y_ui += sampler::bernoulli(theta_u / denominator);; 491 | } 492 | return sum_y_ui; 493 | } 494 | return 0; 495 | } 496 | double auxiliary_1_y_ui(double d_u, double theta_u){ 497 | if(_num_tables >= 2){ 498 | double sum_1_y_ui = 0; 499 | for(int i = 1;i <= _num_tables - 1;i++){ 500 | double denominator = theta_u + d_u * i; 501 | assert(denominator > 0); 502 | sum_1_y_ui += 1.0 - sampler::bernoulli(theta_u / denominator); 503 | } 504 | return sum_1_y_ui; 505 | } 506 | return 0; 507 | } 508 | double auxiliary_1_z_uwkj(double d_u){ 509 | double sum_z_uwkj = 0; 510 | // c_u.. 511 | for(auto elem: _arrangement){ 512 | // c_uw. 513 | std::vector &num_customers_at_table = elem.second; 514 | for(int k = 0;k < num_customers_at_table.size();k++){ 515 | // c_uwk 516 | int c_uwk = num_customers_at_table[k]; 517 | if(c_uwk >= 2){ 518 | for(int j = 1;j <= c_uwk - 1;j++){ 519 | assert(j - d_u > 0); 520 | sum_z_uwkj += 1 - sampler::bernoulli((j - 1) / (j - d_u)); 521 | } 522 | } 523 | } 524 | } 525 | return sum_z_uwkj; 526 | } 527 | void init_hyperparameters_at_depth_if_needed(int depth, std::vector &d_m, std::vector &theta_m){ 528 | if(depth >= d_m.size()){ 529 | while(d_m.size() <= depth){ 530 | d_m.push_back(HPYLM_INITIAL_D); 531 | } 532 | while(theta_m.size() <= depth){ 533 | theta_m.push_back(HPYLM_INITIAL_THETA); 534 | } 535 | } 536 | } 537 | template 538 | void serialize(Archive& archive, unsigned int version) 539 | { 540 | archive & _children; 541 | archive & _arrangement; 542 | archive & _num_tables; 543 | archive & _num_customers; 544 | archive & _parent; 545 | archive & _stop_count; 546 | archive & _pass_count; 547 | archive & _token_id; 548 | archive & _depth; 549 | } 550 | }; 551 | template class Node; 552 | template class Node; 553 | } // namespace lm 554 | } // namespace npylm -------------------------------------------------------------------------------- /src/npylm/lm/vpylm.cpp: -------------------------------------------------------------------------------- 1 | #include "vpylm.h" 2 | #include "../sampler.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace npylm { 12 | namespace lm { 13 | VPYLM::VPYLM(double g0, int max_possible_depth, double beta_stop, double beta_pass) 14 | { 15 | assert(g0 > 0); 16 | _root = new Node(0); 17 | _root->_depth = 0; // ルートは深さ0 18 | // http://www.ism.ac.jp/~daichi/paper/ipsj07vpylm.pdfによると初期値は(4, 1) 19 | // しかしVPYLMは初期値にあまり依存しないらしい 20 | _beta_stop = beta_stop; 21 | _beta_pass = beta_pass; 22 | _depth = 0; 23 | _g0 = g0; 24 | _max_depth = max_possible_depth; // 訓練データ中の最大長の文の文字数が可能な最大深さになる 25 | _parent_pw_cache = new double[max_possible_depth + 1]; 26 | _sampling_table = new double[max_possible_depth + 1]; 27 | _path_nodes = new Node*[max_possible_depth + 1]; 28 | } 29 | VPYLM::~VPYLM() 30 | { 31 | _delete_node(_root); 32 | delete[] _sampling_table; 33 | delete[] _parent_pw_cache; 34 | delete[] _path_nodes; 35 | } 36 | bool VPYLM::add_customer_at_time_t(wchar_t const* character_ids, int t, int depth_t) 37 | { 38 | assert(_parent_pw_cache != NULL); 39 | assert(0 <= depth_t && depth_t <= t); 40 | Node* node = find_node_by_tracing_back_context(character_ids, t, depth_t, _parent_pw_cache); 41 | assert(node != NULL); 42 | if (depth_t > 0) { // ルートノードは特殊なので無視 43 | assert(node->_token_id == character_ids[t - depth_t]); 44 | } 45 | assert(node->_depth == depth_t); 46 | wchar_t token_t = character_ids[t]; 47 | int tabke_k; 48 | return node->add_customer(token_t, 
_parent_pw_cache, _d_m, _theta_m, true, tabke_k); 49 | } 50 | // parent_pw_cacheがすでにセットされていてpath_nodesを更新する 51 | // NPYLMから呼ぶ用 52 | bool VPYLM::add_customer_at_time_t(wchar_t const* character_ids, int t, int depth_t, double* parent_pw_cache, Node** path_nodes) 53 | { 54 | assert(path_nodes != NULL); 55 | assert(0 <= depth_t && depth_t <= t); 56 | Node* node = find_node_by_tracing_back_context(character_ids, t, depth_t, path_nodes); 57 | assert(node != NULL); 58 | if (depth_t > 0) { // ルートノードは特殊なので無視 59 | // if(node->_token_id != character_ids[t - depth_t]){ 60 | // for(int i = 0;i <= t;i++){ 61 | // std::wcout << character_ids[i]; 62 | // } 63 | // std::wcout << std::endl; 64 | // } 65 | assert(node->_token_id == character_ids[t - depth_t]); 66 | } 67 | assert(node->_depth == depth_t); 68 | wchar_t token_t = character_ids[t]; 69 | int tabke_k; 70 | return node->add_customer(token_t, parent_pw_cache, _d_m, _theta_m, true, tabke_k); 71 | } 72 | bool VPYLM::remove_customer_at_time_t(wchar_t const* character_ids, int t, int depth_t) 73 | { 74 | assert(0 <= depth_t && depth_t <= t); 75 | Node* node = find_node_by_tracing_back_context(character_ids, t, depth_t, false, false); 76 | assert(node != NULL); 77 | if (depth_t > 0) { 78 | assert(node->_token_id == character_ids[t - depth_t]); 79 | } 80 | assert(node->_depth == depth_t); 81 | wchar_t token_t = character_ids[t]; 82 | int table_k; 83 | node->remove_customer(token_t, true, table_k); 84 | // 客が一人もいなくなったらノードを削除する 85 | if (node->need_to_remove_from_parent()) { 86 | node->remove_from_parent(); 87 | } 88 | return true; 89 | } 90 | // 文字列の位置tからorderだけ遡る 91 | // character_ids: [2, 4, 7, 1, 9, 10] 92 | // t: 3 ^ ^ 93 | // depth_t: 2 |<- <-| 94 | Node* VPYLM::find_node_by_tracing_back_context(wchar_t const* character_ids, int t, int depth_t, bool generate_node_if_needed, bool return_middle_node) 95 | { 96 | if (t - depth_t < 0) { 97 | return NULL; 98 | } 99 | Node* node = _root; 100 | for (int depth = 1; depth <= depth_t; depth++) { 101 | wchar_t context_token_id = character_ids[t - depth]; 102 | Node* child = node->find_child_node(context_token_id, generate_node_if_needed); 103 | if (child == NULL) { 104 | if (return_middle_node) { 105 | return node; 106 | } 107 | return NULL; 108 | } 109 | node = child; 110 | } 111 | assert(node->_depth == depth_t); 112 | if (depth_t > 0) { 113 | assert(node->_token_id == character_ids[t - depth_t]); 114 | } 115 | return node; 116 | } 117 | // add_customer用 118 | // 辿りながら確率をキャッシュ 119 | Node* VPYLM::find_node_by_tracing_back_context(wchar_t const* character_ids, int t, int depth_t, double* parent_pw_cache) 120 | { 121 | assert(parent_pw_cache != NULL); 122 | if (t - depth_t < 0) { 123 | return NULL; 124 | } 125 | wchar_t token_t = character_ids[t]; 126 | Node* node = _root; 127 | double parent_pw = _g0; 128 | parent_pw_cache[0] = _g0; 129 | for (int depth = 1; depth <= depth_t; depth++) { 130 | wchar_t context_token_id = character_ids[t - depth]; 131 | // 事前に確率を計算 132 | double pw = node->compute_p_w_with_parent_p_w(token_t, parent_pw, _d_m, _theta_m); 133 | assert(pw > 0); 134 | parent_pw_cache[depth] = pw; 135 | Node* child = node->find_child_node(context_token_id, true); 136 | assert(child != NULL); 137 | parent_pw = pw; 138 | node = child; 139 | } 140 | assert(node->_depth == depth_t); 141 | if (depth_t > 0) { 142 | assert(node->_token_id == character_ids[t - depth_t]); 143 | } 144 | return node; 145 | } 146 | // すでに辿ったノードのキャッシュを使いながら辿る 147 | Node* VPYLM::find_node_by_tracing_back_context(wchar_t const* 
character_ids, int t, int depth_t, Node** path_nodes_cache) 148 | { 149 | assert(path_nodes_cache != NULL); 150 | if (t - depth_t < 0) { 151 | return NULL; 152 | } 153 | Node* node = _root; 154 | int depth = 0; 155 | for (; depth < depth_t; depth++) { 156 | if (path_nodes_cache[depth + 1] != NULL) { 157 | node = path_nodes_cache[depth + 1]; 158 | assert(node->_depth == depth + 1); 159 | } else { 160 | wchar_t context_token_id = character_ids[t - depth - 1]; 161 | Node* child = node->find_child_node(context_token_id, true); 162 | assert(child != NULL); 163 | node = child; 164 | } 165 | } 166 | assert(node != NULL); 167 | if (depth_t > 0) { 168 | assert(node->_token_id == character_ids[t - depth_t]); 169 | } 170 | return node; 171 | } 172 | double VPYLM::compute_p_w(wchar_t const* character_ids, int character_ids_length) 173 | { 174 | return exp(compute_log_p_w(character_ids, character_ids_length)); 175 | } 176 | double VPYLM::compute_log_p_w(wchar_t const* character_ids, int character_ids_length) 177 | { 178 | wchar_t token_t = character_ids[0]; 179 | double log_pw = 0; 180 | if (token_t != ID_BOW) { 181 | log_pw = log(_root->compute_p_w(token_t, _g0, _d_m, _theta_m)); 182 | } 183 | for (int t = 1; t < character_ids_length; t++) { 184 | log_pw += log(compute_p_w_given_h(character_ids, 0, t - 1)); 185 | } 186 | return log_pw; 187 | } 188 | // 文字列のcontext_substr_startからcontext_substr_endまでの部分文字列を文脈として、context_substr_end+1の文字が生成される確率 189 | double VPYLM::compute_p_w_given_h(wchar_t const* character_ids, int context_substr_start, int context_substr_end) 190 | { 191 | assert(context_substr_start >= 0); 192 | assert(context_substr_end >= context_substr_start); 193 | wchar_t target_id = character_ids[context_substr_end + 1]; 194 | return compute_p_w_given_h(target_id, character_ids, context_substr_start, context_substr_end); 195 | } 196 | // 単語のサンプリングなどで任意のtarget_idの確率を計算することがあるため一般化 197 | // 文字列のcontext_substr_startからcontext_substr_endまでの部分文字列を文脈として、target_idが生成される確率 198 | double VPYLM::compute_p_w_given_h(wchar_t target_id, wchar_t const* character_ids, int context_substr_start, int context_substr_end) 199 | { 200 | assert(context_substr_start >= 0); 201 | assert(context_substr_end >= context_substr_start); 202 | Node* node = _root; 203 | assert(node != NULL); 204 | double parent_pass_probability = 1; 205 | double p = 0; 206 | double parent_pw = _g0; 207 | double eps = VPYLM_EPS; // 停止確率がこの値を下回れば打ち切り 208 | double p_stop = 1; 209 | int depth = 0; 210 | 211 | // 無限の深さまで考える 212 | // 実際のコンテキスト長を超えて確率を計算することもある 213 | while (p_stop > eps) { 214 | // ノードがない場合親の確率とベータ事前分布から計算 215 | if (node == NULL) { 216 | p_stop = (_beta_stop) / (_beta_pass + _beta_stop) * parent_pass_probability; 217 | p += parent_pw * p_stop; 218 | parent_pass_probability *= (_beta_pass) / (_beta_pass + _beta_stop); 219 | } else { 220 | assert(context_substr_end - depth >= 0); 221 | assert(node->_depth == depth); 222 | double pw = node->compute_p_w_with_parent_p_w(target_id, parent_pw, _d_m, _theta_m); 223 | p_stop = node->stop_probability(_beta_stop, _beta_pass, false) * parent_pass_probability; 224 | p += pw * p_stop; 225 | parent_pass_probability *= node->pass_probability(_beta_stop, _beta_pass, false); 226 | parent_pw = pw; 227 | if (context_substr_end - depth <= context_substr_start) { 228 | node = NULL; 229 | } else { 230 | wchar_t context_token_id = character_ids[context_substr_end - depth]; 231 | Node* child = node->find_child_node(context_token_id); 232 | node = child; 233 | if (depth > 0 && node) { 234 | 
assert(node->_token_id == context_token_id); 235 | } 236 | } 237 | } 238 | depth++; 239 | } 240 | assert(p > 0); 241 | return p; 242 | } 243 | // 辿ったノードとそれぞれのノードからの出力確率をキャッシュしながらオーダーをサンプリング 244 | int VPYLM::sample_depth_at_time_t(wchar_t const* character_ids, int t, double* parent_pw_cache, Node** path_nodes) 245 | { 246 | assert(path_nodes != NULL); 247 | assert(parent_pw_cache != NULL); 248 | if (t == 0) { 249 | return 0; 250 | } 251 | // VPYLMは本来無限の深さを考えるが、計算量的な問題から以下の値を下回れば打ち切り 252 | double eps = VPYLM_EPS; 253 | 254 | wchar_t token_t = character_ids[t]; 255 | double sum = 0; 256 | double parent_pw = _g0; 257 | double parent_pass_probability = 1; 258 | parent_pw_cache[0] = _g0; 259 | int sampling_table_size = 0; 260 | Node* node = _root; 261 | for (int n = 0; n <= t; n++) { 262 | if (node) { 263 | assert(n == node->_depth); 264 | double pw = node->compute_p_w_with_parent_p_w(token_t, parent_pw, _d_m, _theta_m); 265 | double p_stop = node->stop_probability(_beta_stop, _beta_pass, false); 266 | double p = pw * p_stop * parent_pass_probability; 267 | parent_pw = pw; 268 | parent_pw_cache[n + 1] = pw; 269 | _sampling_table[n] = p; 270 | path_nodes[n] = node; 271 | sampling_table_size += 1; 272 | parent_pass_probability *= node->pass_probability(_beta_stop, _beta_pass, false); 273 | sum += p; 274 | if (p_stop < eps) { 275 | break; 276 | } 277 | if (n < t) { 278 | wchar_t context_token_id = character_ids[t - n - 1]; 279 | node = node->find_child_node(context_token_id); 280 | } 281 | } else { 282 | double p_stop = (_beta_stop) / (_beta_pass + _beta_stop) * parent_pass_probability; 283 | double p = parent_pw * p_stop; // ノードがない場合親の確率をそのまま使う 284 | parent_pw_cache[n + 1] = parent_pw; 285 | _sampling_table[n] = p; 286 | path_nodes[n] = NULL; 287 | sampling_table_size += 1; 288 | sum += p; 289 | parent_pass_probability *= (_beta_pass) / (_beta_pass + _beta_stop); 290 | if (p_stop < eps) { 291 | break; 292 | } 293 | } 294 | } 295 | assert(sampling_table_size <= t + 1); 296 | double normalizer = 1.0 / sum; 297 | double bernoulli = sampler::uniform(0, 1); 298 | double stack = 0; 299 | for (int n = 0; n < sampling_table_size; n++) { 300 | stack += _sampling_table[n] * normalizer; 301 | if (bernoulli < stack) { 302 | return n; 303 | } 304 | } 305 | return sampling_table_size - 1; 306 | } 307 | template 308 | void VPYLM::serialize(Archive& archive, unsigned int version) 309 | { 310 | boost::serialization::split_member(archive, *this, version); 311 | } 312 | template void VPYLM::serialize(boost::archive::binary_iarchive& ar, unsigned int version); 313 | template void VPYLM::serialize(boost::archive::binary_oarchive& ar, unsigned int version); 314 | void VPYLM::save(boost::archive::binary_oarchive& archive, unsigned int version) const 315 | { 316 | archive& _root; 317 | archive& _depth; 318 | archive& _max_depth; 319 | archive& _beta_stop; 320 | archive& _beta_pass; 321 | archive& _g0; 322 | archive& _d_m; 323 | archive& _theta_m; 324 | archive& _a_m; 325 | archive& _b_m; 326 | archive& _alpha_m; 327 | archive& _beta_m; 328 | } 329 | void VPYLM::load(boost::archive::binary_iarchive& archive, unsigned int version) 330 | { 331 | archive& _root; 332 | archive& _depth; 333 | archive& _max_depth; 334 | archive& _beta_stop; 335 | archive& _beta_pass; 336 | archive& _g0; 337 | archive& _d_m; 338 | archive& _theta_m; 339 | archive& _a_m; 340 | archive& _b_m; 341 | archive& _alpha_m; 342 | archive& _beta_m; 343 | _parent_pw_cache = new double[_max_depth + 1]; 344 | _sampling_table = new double[_max_depth + 1]; 345 
| _path_nodes = new Node*[_max_depth + 1]; 346 | } 347 | }; 348 | } // namespace npylm -------------------------------------------------------------------------------- /src/npylm/lm/vpylm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "../sentence.h" 6 | #include "../common.h" 7 | #include "model.h" 8 | #include "node.h" 9 | 10 | namespace npylm { 11 | namespace lm { 12 | class VPYLM: public Model { 13 | private: 14 | friend class boost::serialization::access; 15 | template 16 | void serialize(Archive& archive, unsigned int version); 17 | void save(boost::archive::binary_oarchive &archive, unsigned int version) const; 18 | void load(boost::archive::binary_iarchive &archive, unsigned int version); 19 | public: 20 | double _beta_stop; // 停止確率q_iのベータ分布の初期パラメータ 21 | double _beta_pass; // 停止確率q_iのベータ分布の初期パラメータ 22 | int _max_depth; 23 | // 計算高速化用 24 | double* _sampling_table; 25 | double* _parent_pw_cache; 26 | Node** _path_nodes; 27 | VPYLM(){} 28 | VPYLM(double g0, int max_possible_depth, double beta_stop, double beta_pass); 29 | ~VPYLM(); 30 | bool add_customer_at_time_t(wchar_t const* character_ids, int t, int depth_t); 31 | bool add_customer_at_time_t(wchar_t const* character_ids, int t, int depth_t, double* parent_pw_cache, Node** path_nodes); 32 | bool remove_customer_at_time_t(wchar_t const* character_ids, int t, int depth_t); 33 | Node* find_node_by_tracing_back_context(wchar_t const* character_ids, int t, int depth_t, bool generate_node_if_needed = false, bool return_middle_node = false); 34 | Node* find_node_by_tracing_back_context(wchar_t const* character_ids, int t, int depth_t, double* parent_pw_cache); 35 | Node* find_node_by_tracing_back_context(wchar_t const* character_ids, int t, int depth_t, Node** path_nodes_cache); 36 | double compute_p_w(wchar_t const* character_ids, int character_ids_length); 37 | double compute_log_p_w(wchar_t const* character_ids, int character_ids_length); 38 | double compute_p_w_given_h(wchar_t const* character_ids, int context_substr_start, int context_substr_end); 39 | double compute_p_w_given_h(wchar_t target_id, wchar_t const* character_ids, int context_substr_start, int context_substr_end); 40 | int sample_depth_at_time_t(wchar_t const* character_ids, int t, double* parent_pw_cache, Node** path_nodes); 41 | }; 42 | } 43 | } // namespace npylm -------------------------------------------------------------------------------- /src/npylm/npylm.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "npylm.h" 7 | #include "hash.h" 8 | #include "sampler.h" 9 | #include "wordtype.h" 10 | 11 | namespace npylm { 12 | using namespace lm; 13 | // character_idsのsubstr_char_t_startからsubstr_char_t_endまでの文字列をで挟んでwrapped_character_idsの先頭に格納 14 | void wrap_bow_eow(wchar_t const* characters, int substr_char_t_start, int substr_char_t_end, wchar_t* token_ids){ 15 | token_ids[0] = ID_BOW; 16 | int i = 0; 17 | for(;i < substr_char_t_end - substr_char_t_start + 1;i++){ 18 | token_ids[i + 1] = characters[i + substr_char_t_start]; 19 | } 20 | token_ids[i + 1] = ID_EOW; 21 | } 22 | double factorial(double n) { 23 | if (n == 0){ 24 | return 1; 25 | } 26 | return n * factorial(n - 1); 27 | } 28 | // lambda_a, lambda_bは単語長のポアソン分布のハイパーパラメータ 29 | // 異なる文字種ごとに違うλを使うが、学習時に個別に推定するため事前分布は共通化する 30 | NPYLM::NPYLM(int max_word_length, int max_sentence_length, double g0, 
double initial_lambda_a, double initial_lambda_b, double vpylm_beta_stop, double vpylm_beta_pass){ 31 | _hpylm = new HPYLM(3); // 3-gram以外を指定すると動かないので注意 32 | _vpylm = new VPYLM(g0, max_sentence_length, vpylm_beta_stop, vpylm_beta_pass); 33 | _lambda_for_type = new double[WORDTYPE_NUM_TYPES + 1]; // 文字種ごとの単語長のポアソン分布のハイパーパラメータ 34 | _hpylm_parent_pw_cache = new double[3]; // 3-gram 35 | set_lambda_prior(initial_lambda_a, initial_lambda_b); 36 | 37 | _max_sentence_length = max_sentence_length; 38 | _max_word_length = max_word_length; 39 | _characters = new wchar_t[max_sentence_length + 2]; // を含める 40 | _pk_vpylm = new double[max_word_length + 2]; // kが1スタート、かつk > max_word_length用の領域も必要なので+2 41 | for(int k = 1;k < max_word_length + 2;k++){ 42 | _pk_vpylm[k] = 1.0 / (max_word_length + 2); 43 | } 44 | 45 | #ifdef __DEBUG__ 46 | std::cout << "warning: debug mode enabled!" << std::endl; 47 | #endif 48 | } 49 | NPYLM::~NPYLM(){ 50 | delete _hpylm; 51 | delete _vpylm; 52 | delete[] _hpylm_parent_pw_cache; 53 | delete[] _lambda_for_type; 54 | delete[] _characters; 55 | delete[] _pk_vpylm; 56 | } 57 | void NPYLM::reserve(int max_sentence_length){ 58 | if(max_sentence_length <= _max_sentence_length){ 59 | return; 60 | } 61 | _delete_capacity(); 62 | _allocate_capacity(max_sentence_length); 63 | _max_sentence_length = max_sentence_length; 64 | } 65 | void NPYLM::_allocate_capacity(int max_sentence_length){ 66 | _max_sentence_length = max_sentence_length; 67 | _characters = new wchar_t[max_sentence_length + 2]; 68 | } 69 | void NPYLM::_delete_capacity(){ 70 | delete[] _characters; 71 | } 72 | void NPYLM::set_vpylm_g0(double g0){ 73 | _vpylm->set_g0(g0); 74 | } 75 | void NPYLM::set_lambda_prior(double a, double b){ 76 | _lambda_a = a; 77 | _lambda_b = b; 78 | sample_lambda_with_initial_params(); 79 | } 80 | void NPYLM::sample_lambda_with_initial_params(){ 81 | for(int type = 1;type <= WORDTYPE_NUM_TYPES;type++){ 82 | _lambda_for_type[type] = sampler::gamma(_lambda_a, _lambda_b); 83 | } 84 | } 85 | bool NPYLM::add_customer_at_time_t(Sentence* sentence, int t){ 86 | assert(_characters != NULL); 87 | assert(t >= 2); 88 | id token_t = sentence->get_word_id_at(t); 89 | Node* node = find_node_by_tracing_back_context_from_time_t(sentence, t, _hpylm_parent_pw_cache, true, false); 90 | assert(node != NULL); 91 | int num_tables_before = _hpylm->_root->_num_tables; 92 | int added_table_k = -1; 93 | int substr_char_t_start = sentence->_start[t]; 94 | int substr_char_t_end = sentence->_start[t] + sentence->_segments[t] - 1; 95 | node->add_customer(token_t, _hpylm_parent_pw_cache, _hpylm->_d_m, _hpylm->_theta_m, true, added_table_k); 96 | int num_tables_after = _hpylm->_root->_num_tables; 97 | // 単語unigramノードでテーブル数が増えた場合VPYLMに追加 98 | if(num_tables_before < num_tables_after){ 99 | _g0_cache.clear(); 100 | if(token_t == ID_EOS){ 101 | _vpylm->_root->add_customer(token_t, _vpylm->_g0, _vpylm->_d_m, _vpylm->_theta_m, true, added_table_k); 102 | return true; 103 | } 104 | assert(added_table_k != -1); 105 | std::vector> &depths = _prev_depth_at_table_of_token[token_t]; 106 | assert(depths.size() <= added_table_k); // 存在してはいけない 107 | std::vector prev_depths; 108 | vpylm_add_customers(sentence->_characters, substr_char_t_start, substr_char_t_end, _characters, prev_depths); 109 | assert(prev_depths.size() == substr_char_t_end - substr_char_t_start + 3); 110 | depths.push_back(prev_depths); 111 | } 112 | return true; 113 | } 114 | void NPYLM::vpylm_add_customers(wchar_t const* characters, int substr_char_t_start, int 
substr_char_t_end, wchar_t* token_ids, std::vector &prev_depths){ 115 | assert(prev_depths.size() == 0); 116 | assert(substr_char_t_end >= substr_char_t_start); 117 | // 先頭にをつける 118 | assert(substr_char_t_end < _max_sentence_length); 119 | wrap_bow_eow(characters, substr_char_t_start, substr_char_t_end, token_ids); 120 | int token_ids_length = substr_char_t_end - substr_char_t_start + 3; // を考慮 121 | // 客を追加 122 | for(int char_t = 0;char_t < token_ids_length;char_t++){ 123 | int depth_t = _vpylm->sample_depth_at_time_t(token_ids, char_t, _vpylm->_parent_pw_cache, _vpylm->_path_nodes); 124 | _vpylm->add_customer_at_time_t(token_ids, char_t, depth_t, _vpylm->_parent_pw_cache, _vpylm->_path_nodes); // キャッシュを使って追加 125 | prev_depths.push_back(depth_t); 126 | } 127 | } 128 | bool NPYLM::remove_customer_at_time_t(Sentence* sentence, int t){ 129 | assert(_characters != NULL); 130 | assert(t >= 2); 131 | id token_t = sentence->get_word_id_at(t); 132 | Node* node = find_node_by_tracing_back_context_from_time_t(sentence->_word_ids, sentence->get_num_segments(), t, false, false); 133 | assert(node != NULL); 134 | int num_tables_before = _hpylm->_root->_num_tables; 135 | int removed_from_table_k = -1; 136 | int substr_char_t_start = sentence->_start[t]; 137 | int substr_char_t_end = sentence->_start[t] + sentence->_segments[t] - 1; 138 | node->remove_customer(token_t, true, removed_from_table_k); 139 | 140 | // 単語unigramノードでテーブル数が増えた場合VPYLMから削除 141 | int num_tables_after = _hpylm->_root->_num_tables; 142 | if(num_tables_before > num_tables_after){ 143 | _g0_cache.clear(); 144 | if(token_t == ID_EOS){ 145 | // は文字列に分解できないので常にVPYLMのルートノードに追加されている 146 | _vpylm->_root->remove_customer(token_t, true, removed_from_table_k); 147 | return true; 148 | } 149 | assert(removed_from_table_k != -1); 150 | auto itr = _prev_depth_at_table_of_token.find(token_t); 151 | assert(itr != _prev_depth_at_table_of_token.end()); 152 | std::vector> &depths = itr->second; 153 | assert(removed_from_table_k < depths.size()); 154 | // 客を除外 155 | std::vector &prev_depths = depths[removed_from_table_k]; 156 | assert(prev_depths.size() > 0); 157 | vpylm_remove_customers(sentence->_characters, substr_char_t_start, substr_char_t_end, _characters, prev_depths); 158 | // シフト 159 | depths.erase(depths.begin() + removed_from_table_k); 160 | } 161 | if(node->need_to_remove_from_parent()){ 162 | node->remove_from_parent(); 163 | } 164 | return true; 165 | } 166 | void NPYLM::vpylm_remove_customers(wchar_t const* characters, int substr_char_t_start, int substr_char_t_end, wchar_t* token_ids, std::vector &prev_depths){ 167 | assert(prev_depths.size() > 0); 168 | assert(substr_char_t_end >= substr_char_t_start); 169 | // 先頭にをつける 170 | assert(substr_char_t_end < _max_sentence_length); 171 | wrap_bow_eow(characters, substr_char_t_start, substr_char_t_end, token_ids); 172 | int token_ids_length = substr_char_t_end - substr_char_t_start + 3; // を考慮 173 | // 客を除外 174 | assert(prev_depths.size() == token_ids_length); 175 | auto prev_depth_t = prev_depths.begin(); 176 | for(int char_t = 0;char_t < token_ids_length;char_t++){ 177 | _vpylm->remove_customer_at_time_t(token_ids, char_t, *prev_depth_t); 178 | prev_depth_t++; 179 | } 180 | } 181 | Node* NPYLM::find_node_by_tracing_back_context_from_time_t(id const* word_ids, int word_ids_length, int word_t_index, bool generate_node_if_needed, bool return_middle_node){ 182 | assert(word_t_index >= 2); 183 | assert(word_t_index < word_ids_length); 184 | Node* node = _hpylm->_root; 185 | for(int depth = 1;depth 
<= 2;depth++){ 186 | id context_id = ID_BOS; 187 | if(word_t_index - depth >= 0){ 188 | context_id = word_ids[word_t_index - depth]; 189 | } 190 | Node* child = node->find_child_node(context_id, generate_node_if_needed); 191 | if(child == NULL){ 192 | if(return_middle_node){ 193 | return node; 194 | } 195 | return NULL; 196 | } 197 | node = child; 198 | } 199 | assert(node->_depth == 2); 200 | return node; 201 | } 202 | // add_customer用 203 | Node* NPYLM::find_node_by_tracing_back_context_from_time_t(Sentence* sentence, int word_t_index, double* parent_pw_cache, int generate_node_if_needed, bool return_middle_node){ 204 | assert(word_t_index >= 2); 205 | assert(word_t_index < sentence->get_num_segments()); 206 | assert(sentence->_segments[word_t_index] > 0); 207 | int substr_char_t_start = sentence->_start[word_t_index]; 208 | int substr_char_t_end = sentence->_start[word_t_index] + sentence->_segments[word_t_index] - 1; 209 | return find_node_by_tracing_back_context_from_time_t( 210 | sentence->_characters, sentence->size(), 211 | sentence->_word_ids, sentence->get_num_segments(), 212 | word_t_index, substr_char_t_start, substr_char_t_end, 213 | parent_pw_cache, generate_node_if_needed, return_middle_node); 214 | } 215 | // 効率のためノードを探しながら確率も計算する 216 | Node* NPYLM::find_node_by_tracing_back_context_from_time_t( 217 | wchar_t const* characters, int character_ids_length, 218 | id const* word_ids, int word_ids_length, 219 | int word_t_index, int substr_char_t_start, int substr_char_t_end, 220 | double* parent_pw_cache, bool generate_node_if_needed, bool return_middle_node){ 221 | assert(word_t_index >= 2); 222 | assert(word_t_index < word_ids_length); 223 | assert(substr_char_t_start >= 0); 224 | assert(substr_char_t_end >= substr_char_t_start); 225 | Node* node = _hpylm->_root; 226 | id word_t_id = word_ids[word_t_index]; 227 | double parent_pw = compute_g0_substring_at_time_t(characters, character_ids_length, substr_char_t_start, substr_char_t_end, word_t_id); 228 | parent_pw_cache[0] = parent_pw; 229 | for(int depth = 1;depth <= 2;depth++){ 230 | id context_id = ID_BOS; 231 | if(word_t_index - depth >= 0){ 232 | context_id = word_ids[word_t_index - depth]; 233 | } 234 | // 事前に確率を計算 235 | double pw = node->compute_p_w_with_parent_p_w(word_t_id, parent_pw, _hpylm->_d_m, _hpylm->_theta_m); 236 | parent_pw_cache[depth] = pw; 237 | Node* child = node->find_child_node(context_id, generate_node_if_needed); 238 | if(child == NULL && return_middle_node == true){ 239 | return node; 240 | } 241 | assert(child != NULL); 242 | parent_pw = pw; 243 | node = child; 244 | } 245 | assert(node->_depth == 2); 246 | return node; 247 | } 248 | // word_idは既知なので再計算を防ぐ 249 | double NPYLM::compute_g0_substring_at_time_t(wchar_t const* characters, int character_ids_length, int substr_char_t_start, int substr_char_t_end, id word_t_id){ 250 | assert(_characters != NULL); 251 | if(word_t_id == ID_EOS){ 252 | return _vpylm->_g0; 253 | } 254 | 255 | #ifdef __DEBUG__ 256 | id a = hash_substring_ptr(characters, substr_char_t_start, substr_char_t_end); 257 | assert(a == word_t_id); 258 | #endif 259 | 260 | assert(substr_char_t_end < _max_sentence_length); 261 | assert(substr_char_t_start >= 0); 262 | assert(substr_char_t_end >= substr_char_t_start); 263 | int word_length = substr_char_t_end - substr_char_t_start + 1; 264 | // if(word_length > _max_word_length){ 265 | // return 0; 266 | // } 267 | auto itr = _g0_cache.find(word_t_id); 268 | if(itr == _g0_cache.end()){ 269 | // 先頭にをつける 270 | wchar_t* token_ids = _characters; 
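// Note: the base measure computed below applies the NPYLM word-length correction
//   g0(w) = p_VPYLM(w) * Po(k | lambda_type) / p(k | VPYLM),
// where k is the word length and lambda_type is the Poisson parameter for the word's character
// type; words longer than _max_word_length skip the correction and use the raw character-VPYLM
// probability.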
271 | wrap_bow_eow(characters, substr_char_t_start, substr_char_t_end, token_ids); 272 | int token_ids_length = substr_char_t_end - substr_char_t_start + 3; 273 | // g0を計算 274 | double pw = _vpylm->compute_p_w(token_ids, token_ids_length); 275 | 276 | // 学習の最初のイテレーションでは文が丸ごと1単語になるので補正する意味はない 277 | if(word_length > _max_word_length){ 278 | _g0_cache[word_t_id] = pw; 279 | return pw; 280 | } 281 | 282 | double p_k_given_vpylm = compute_p_k_given_vpylm(word_length); 283 | int type = wordtype::detect_word_type_substr(characters, substr_char_t_start, substr_char_t_end); 284 | assert(type <= WORDTYPE_NUM_TYPES); 285 | assert(type > 0); 286 | double lambda = _lambda_for_type[type]; 287 | double poisson = compute_poisson_k_lambda(word_length, lambda); 288 | assert(poisson > 0); 289 | double g0 = pw * poisson / p_k_given_vpylm; 290 | 291 | // ごく稀にポアソン補正で1を超えることがある 292 | if((0 < g0 && g0 < 1) == false){ 293 | for(int u = substr_char_t_start;u <= substr_char_t_end;u++){ 294 | std::wcout << characters[u]; 295 | } 296 | std::wcout << std::endl; 297 | std::cout << pw << std::endl; 298 | std::cout << poisson << std::endl; 299 | std::cout << p_k_given_vpylm << std::endl; 300 | std::cout << g0 << std::endl; 301 | std::cout << word_length << std::endl; 302 | } 303 | assert(0 < g0 && g0 < 1); 304 | _g0_cache[word_t_id] = g0; 305 | return g0; 306 | } 307 | return itr->second; 308 | } 309 | double NPYLM::compute_poisson_k_lambda(unsigned int k, double lambda){ 310 | return pow(lambda, k) * exp(-lambda) / factorial(k); 311 | } 312 | double NPYLM::compute_p_k_given_vpylm(int k){ 313 | assert(k > 0); 314 | if(k > _max_word_length){ 315 | return 0; 316 | } 317 | return _pk_vpylm[k]; 318 | } 319 | void NPYLM::sample_hpylm_vpylm_hyperparameters(){ 320 | _hpylm->sample_hyperparams(); 321 | _vpylm->sample_hyperparams(); 322 | } 323 | double NPYLM::compute_log_p_w(Sentence* sentence){ 324 | double pw = 0; 325 | for(int t = 2;t < sentence->get_num_segments();t++){ 326 | pw += log(compute_p_w_given_h(sentence, t)); 327 | } 328 | return pw; 329 | } 330 | double NPYLM::compute_p_w(Sentence* sentence){ 331 | double pw = 1; 332 | for(int t = 2;t < sentence->get_num_segments();t++){ 333 | pw *= compute_p_w_given_h(sentence, t); 334 | } 335 | return pw; 336 | } 337 | double NPYLM::compute_p_w_given_h(Sentence* sentence, int word_t_index){ 338 | assert(word_t_index >= 2); 339 | assert(word_t_index < sentence->get_num_segments()); 340 | assert(sentence->_segments[word_t_index] > 0); 341 | int substr_char_t_start = sentence->_start[word_t_index]; 342 | int substr_char_t_end = sentence->_start[word_t_index] + sentence->_segments[word_t_index] - 1; 343 | return compute_p_w_given_h(sentence->_characters, sentence->size(), sentence->_word_ids, sentence->get_num_segments(), word_t_index, substr_char_t_start, substr_char_t_end); 344 | } 345 | double NPYLM::compute_p_w_given_h( 346 | wchar_t const* characters, int character_ids_length, 347 | id const* word_ids, int word_ids_length, 348 | int word_t_index, int substr_char_t_start, int substr_char_t_end){ 349 | assert(word_t_index < word_ids_length); 350 | assert(substr_char_t_start >= 0); 351 | id word_id = word_ids[word_t_index]; 352 | 353 | if(word_id != ID_EOS){ 354 | assert(substr_char_t_end < character_ids_length); 355 | #ifdef __DEBUG__ 356 | id a = hash_substring_ptr(characters, substr_char_t_start, substr_char_t_end); 357 | assert(a == word_id); 358 | #endif 359 | } 360 | // ノードを探しながら_hpylm_parent_pw_cacheをセット 361 | Node* node = 
find_node_by_tracing_back_context_from_time_t(characters, character_ids_length, word_ids, word_ids_length, word_t_index, substr_char_t_start, substr_char_t_end, _hpylm_parent_pw_cache, false, true); 362 | assert(node != NULL); 363 | double parent_pw = _hpylm_parent_pw_cache[node->_depth]; 364 | // 効率のため親の確率のキャッシュから計算 365 | return node->compute_p_w_with_parent_p_w(word_id, parent_pw, _hpylm->_d_m, _hpylm->_theta_m); 366 | } 367 | template 368 | void NPYLM::serialize(Archive &archive, unsigned int version) 369 | { 370 | boost::serialization::split_member(archive, *this, version); 371 | } 372 | template void NPYLM::serialize(boost::archive::binary_iarchive &ar, unsigned int version); 373 | template void NPYLM::serialize(boost::archive::binary_oarchive &ar, unsigned int version); 374 | void NPYLM::save(boost::archive::binary_oarchive &archive, unsigned int version) const { 375 | archive & _hpylm; 376 | archive & _vpylm; 377 | archive & _max_word_length; 378 | archive & _max_sentence_length; 379 | archive & _lambda_a; 380 | archive & _lambda_b; 381 | 382 | for(int type = 1;type <= WORDTYPE_NUM_TYPES;type++){ 383 | archive & _lambda_for_type[type]; 384 | } 385 | for(int k = 0;k <= _max_word_length + 1;k++){ 386 | archive & _pk_vpylm[k]; 387 | } 388 | } 389 | void NPYLM::load(boost::archive::binary_iarchive &archive, unsigned int version) { 390 | archive & _hpylm; 391 | archive & _vpylm; 392 | archive & _max_word_length; 393 | archive & _max_sentence_length; 394 | archive & _lambda_a; 395 | archive & _lambda_b; 396 | 397 | _pk_vpylm = new double[_max_word_length + 2]; 398 | _lambda_for_type = new double[WORDTYPE_NUM_TYPES + 1]; 399 | 400 | _hpylm_parent_pw_cache = new double[3]; 401 | _characters = new wchar_t[_max_sentence_length + 2]; 402 | 403 | for(int type = 1;type <= WORDTYPE_NUM_TYPES;type++){ 404 | archive & _lambda_for_type[type]; 405 | } 406 | for(int k = 0;k <= _max_word_length + 1;k++){ 407 | archive & _pk_vpylm[k]; 408 | } 409 | } 410 | } // namespace npylm -------------------------------------------------------------------------------- /src/npylm/npylm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "common.h" 4 | #include "lm/node.h" 5 | #include "lm/vpylm.h" 6 | #include "lm/hpylm.h" 7 | 8 | namespace npylm { 9 | // character_idsのsubstr_char_t_startからsubstr_char_t_endまでの文字列をで挟んでwrapped_character_idsの先頭に格納 10 | void wrap_bow_eow(wchar_t const* character_ids, int substr_char_t_start, int substr_char_t_end, wchar_t* wrapped_character_ids); 11 | double factorial(double n); 12 | class NPYLM { 13 | private: 14 | friend class boost::serialization::access; 15 | template 16 | void serialize(Archive &archive, unsigned int version); 17 | void save(boost::archive::binary_oarchive &archive, unsigned int version) const; 18 | void load(boost::archive::binary_iarchive &archive, unsigned int version); 19 | void _allocate_capacity(int max_sentence_length); 20 | void _delete_capacity(); 21 | public: 22 | lm::HPYLM* _hpylm; // 単語n-gram 23 | lm::VPYLM* _vpylm; // 文字n-gram 24 | 25 | // 単語unigramノードで新たなテーブルが作られた時はVPYLMからその単語が生成されたと判断し、単語の文字列をVPYLMに追加する 26 | // その時各文字がVPYLMのどの深さに追加されたかを保存する 27 | // 単語unigramノードのテーブルごと、単語IDごとに保存する必要がある 28 | hashmap>> _prev_depth_at_table_of_token; 29 | 30 | hashmap _g0_cache; 31 | hashmap _vpylm_g0_cache; 32 | double* _lambda_for_type; 33 | double* _pk_vpylm; // 文字n-gramから長さkの単語が生成される確率 34 | int _max_word_length; 35 | int _max_sentence_length; 36 | double _lambda_a; 37 | double _lambda_b; 38 | 
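	// Restated for clarity (see compute_g0_substring_at_time_t in npylm.cpp): for a word w of
	// length k and character type c(w), the base measure handed to the word 3-gram HPYLM is
	//     g0(w) = p_vpylm(w) * Po(k; _lambda_for_type[c(w)]) / _pk_vpylm[k]
	// where Po(k; lambda) = lambda^k * exp(-lambda) / k! (compute_poisson_k_lambda).
	// A small numeric illustration with assumed values:
	//     p_vpylm(w) = 1e-6, k = 2, lambda = 4.0, _pk_vpylm[2] = 0.25
	//     Po(2; 4.0) = 16 * exp(-4) / 2 ≈ 0.1465
	//     g0(w) ≈ 1e-6 * 0.1465 / 0.25 ≈ 5.9e-7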
// 計算高速化用 39 | double* _hpylm_parent_pw_cache; 40 | wchar_t* _characters; 41 | NPYLM(){} 42 | NPYLM(int max_word_length, 43 | int max_sentence_length, 44 | double g0, 45 | double initial_lambda_a, 46 | double initial_lambda_b, 47 | double vpylm_beta_stop, 48 | double vpylm_beta_pass); 49 | ~NPYLM(); 50 | void reserve(int max_sentence_length); 51 | void set_vpylm_g0(double g0); 52 | void set_lambda_prior(double a, double b); 53 | void sample_lambda_with_initial_params(); 54 | bool add_customer_at_time_t(Sentence* sentence, int t); 55 | void vpylm_add_customers(wchar_t const* character_ids, int substr_char_t_start, int substr_char_t_end, wchar_t* wrapped_character_ids, std::vector &prev_depths); 56 | bool remove_customer_at_time_t(Sentence* sentence, int t); 57 | void vpylm_remove_customers(wchar_t const* character_ids, int substr_char_t_start, int substr_char_t_end, wchar_t* wrapped_character_ids, std::vector &prev_depths); 58 | lm::Node* find_node_by_tracing_back_context_from_time_t(id const* word_ids, int word_ids_length, int word_t_index, bool generate_node_if_needed, bool return_middle_node); 59 | lm::Node* find_node_by_tracing_back_context_from_time_t(Sentence* sentence, int word_t_index, double* parent_pw_cache, int generate_node_if_needed, bool return_middle_node); 60 | lm::Node* find_node_by_tracing_back_context_from_time_t( 61 | wchar_t const* character_ids, int character_ids_length, 62 | id const* word_ids, int word_ids_length, 63 | int word_t_index, int substr_char_t_start, int substr_char_t_end, 64 | double* parent_pw_cache, bool generate_node_if_needed, bool return_middle_node); 65 | // word_idは既知なので再計算を防ぐ 66 | double compute_g0_substring_at_time_t(wchar_t const* character_ids, int character_ids_length, int substr_char_t_start, int substr_char_t_end, id word_t_id); 67 | double compute_poisson_k_lambda(unsigned int k, double lambda); 68 | double compute_p_k_given_vpylm(int k); 69 | void sample_hpylm_vpylm_hyperparameters(); 70 | double compute_log_p_w(Sentence* sentence); 71 | double compute_p_w(Sentence* sentence); 72 | double compute_p_w_given_h(Sentence* sentence, int word_t_index); 73 | double compute_p_w_given_h( 74 | wchar_t const* character_ids, int character_ids_length, 75 | id const* word_ids, int word_ids_length, 76 | int word_t_index, int substr_char_t_start, int substr_char_t_end); 77 | }; 78 | } // namespace npylm -------------------------------------------------------------------------------- /src/npylm/sampler.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "sampler.h" 3 | 4 | namespace npylm { 5 | namespace sampler{ 6 | int seed = std::chrono::system_clock::now().time_since_epoch().count(); 7 | // int seed = 1; 8 | std::mt19937 mt(seed); 9 | void set_seed(int seed){ 10 | mt = std::mt19937(seed); 11 | } 12 | double gamma(double a, double b){ 13 | std::gamma_distribution distribution(a, 1.0 / b); 14 | return distribution(mt); 15 | } 16 | double beta(double a, double b){ 17 | double ga = gamma(a, 1.0); 18 | double gb = gamma(b, 1.0); 19 | return ga / (ga + gb); 20 | } 21 | double bernoulli(double p){ 22 | std::uniform_real_distribution rand(0, 1); 23 | double r = rand(mt); 24 | if(r > p){ 25 | return 0; 26 | } 27 | return 1; 28 | } 29 | double uniform(double min, double max){ 30 | std::uniform_real_distribution rand(min, max); 31 | return rand(mt); 32 | } 33 | double uniform_int(int min, int max){ 34 | std::uniform_int_distribution<> rand(min, max); 35 | return rand(mt); 36 | } 37 | double normal(double 
mean, double stddev){ 38 | std::normal_distribution rand(mean, stddev); 39 | return rand(mt); 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /src/npylm/sampler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace npylm { 5 | namespace sampler { 6 | extern std::mt19937 mt; 7 | double gamma(double a, double b); 8 | double beta(double a, double b); 9 | double bernoulli(double p); 10 | double uniform(double min, double max); 11 | double uniform_int(int min, int max); 12 | double normal(double mean, double stddev); 13 | void set_seed(int seed); 14 | } 15 | } -------------------------------------------------------------------------------- /src/npylm/sentence.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "hash.h" 3 | #include "sentence.h" 4 | 5 | // は長さが0文字であることに注意 6 | 7 | namespace npylm { 8 | Sentence::Sentence(std::wstring sentence){ 9 | _sentence_str = sentence; 10 | _characters = _sentence_str.data(); 11 | _word_ids = new id[size() + 3]; 12 | _segments = new int[size() + 3]; 13 | _start = new int[size() + 3]; 14 | for(int i = 0;i < size() + 3;i++){ 15 | _word_ids[i] = 0; 16 | _segments[i] = 0; 17 | } 18 | _word_ids[0] = ID_BOS; 19 | _word_ids[1] = ID_BOS; 20 | _word_ids[2] = get_substr_word_id(0, size() - 1); 21 | _word_ids[3] = ID_EOS; 22 | _segments[0] = 1; 23 | _segments[1] = 1; 24 | _segments[2] = _sentence_str.size(); 25 | _segments[3] = 1; 26 | _start[0] = 0; 27 | _start[1] = 0; 28 | _start[2] = 0; 29 | _start[3] = _sentence_str.size(); 30 | _num_segments = 4; 31 | _supervised = false; 32 | } 33 | Sentence::Sentence(std::wstring sentence, bool supervised): Sentence(sentence){ 34 | _supervised = supervised; 35 | } 36 | Sentence::~Sentence(){ 37 | delete[] _segments; 38 | delete[] _start; 39 | delete[] _word_ids; 40 | } 41 | Sentence* Sentence::copy(){ 42 | Sentence* sentence = new Sentence(_sentence_str); 43 | return sentence; 44 | } 45 | bool Sentence::is_supervised(){ 46 | return _supervised; 47 | } 48 | int Sentence::size(){ 49 | return _sentence_str.size(); 50 | } 51 | int Sentence::get_num_segments(){ 52 | return _num_segments; 53 | } 54 | int Sentence::get_num_segments_without_special_tokens(){ 55 | return _num_segments - 3; 56 | } 57 | int Sentence::get_word_length_at(int t){ 58 | assert(t < _num_segments); 59 | return _segments[t]; 60 | } 61 | id Sentence::get_word_id_at(int t){ 62 | assert(t < _num_segments); 63 | return _word_ids[t]; 64 | } 65 | id Sentence::get_substr_word_id(int start_index, int end_index){ 66 | return hash_substring_ptr(_characters, start_index, end_index); 67 | } 68 | std::wstring Sentence::get_substr_word_str(int start_index, int end_index){ 69 | std::wstring str(_sentence_str.begin() + start_index, _sentence_str.begin() + end_index + 1); 70 | return str; 71 | } 72 | // を考慮 73 | std::wstring Sentence::get_word_str_at(int t){ 74 | assert(t < _num_segments); 75 | if(t < 2){ 76 | return L""; 77 | } 78 | assert(t < _num_segments - 1); 79 | std::wstring str(_sentence_str.begin() + _start[t], _sentence_str.begin() + _start[t] + _segments[t]); 80 | return str; 81 | } 82 | void Sentence::dump_characters(){ 83 | for(int i = 0;i < size();i++){ 84 | std::cout << _characters[i] << ","; 85 | } 86 | std::cout << std::endl; 87 | } 88 | void Sentence::dump_words(){ 89 | std::wcout << L" / "; 90 | for(int i = 2;i < _num_segments - 1;i++){ 91 | for(int j = 0;j < 
_segments[i];j++){ 92 | std::wcout << _characters[j + _start[i]]; 93 | } 94 | std::wcout << L" / "; 95 | } 96 | std::wcout << std::endl; 97 | } 98 | // num_segmentsにはの数は含めない 99 | void Sentence::split(int* segments_without_special_tokens, int num_segments_without_special_tokens){ 100 | int start = 0; 101 | int n = 0; 102 | int sum = 0; 103 | for(;n < num_segments_without_special_tokens;n++){ 104 | if(segments_without_special_tokens[n] == 0){ 105 | assert(n > 0); 106 | break; 107 | } 108 | sum += segments_without_special_tokens[n]; 109 | _segments[n + 2] = segments_without_special_tokens[n]; 110 | _word_ids[n + 2] = get_substr_word_id(start, start + segments_without_special_tokens[n] - 1); 111 | _start[n + 2] = start; 112 | start += segments_without_special_tokens[n]; 113 | } 114 | assert(sum == _sentence_str.size()); 115 | _segments[n + 2] = 1; 116 | _word_ids[n + 2] = ID_EOS; 117 | _start[n + 2] = _start[n + 1]; 118 | n++; 119 | for(;n < _sentence_str.size();n++){ 120 | _segments[n + 2] = 0; 121 | _start[n + 2] = 0; 122 | } 123 | _num_segments = num_segments_without_special_tokens + 3; 124 | } 125 | void Sentence::split(std::vector &segments_without_special_tokens){ 126 | int num_segments_without_special_tokens = segments_without_special_tokens.size(); 127 | int start = 0; 128 | int n = 0; 129 | int sum = 0; 130 | for(;n < num_segments_without_special_tokens;n++){ 131 | assert(segments_without_special_tokens[n] > 0); 132 | sum += segments_without_special_tokens[n]; 133 | _segments[n + 2] = segments_without_special_tokens[n]; 134 | _word_ids[n + 2] = get_substr_word_id(start, start + segments_without_special_tokens[n] - 1); 135 | _start[n + 2] = start; 136 | start += segments_without_special_tokens[n]; 137 | } 138 | assert(sum == _sentence_str.size()); 139 | _segments[n + 2] = 1; 140 | _word_ids[n + 2] = ID_EOS; 141 | _start[n + 2] = _start[n + 1]; 142 | n++; 143 | for(;n < _sentence_str.size();n++){ 144 | _segments[n + 2] = 0; 145 | _start[n + 2] = 0; 146 | } 147 | _num_segments = num_segments_without_special_tokens + 3; 148 | } 149 | } // namespace npylm -------------------------------------------------------------------------------- /src/npylm/sentence.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "common.h" 5 | 6 | // は長さが0文字であることに注意 7 | 8 | namespace npylm { 9 | class Sentence { 10 | public: 11 | int _num_segments; // 2つと1つを含める 12 | int* _segments; // 各単語の長さが入る. 
2つが先頭に来る 13 | int* _start; // 2つが先頭に来る 14 | bool _supervised; // 教師データかどうか 15 | wchar_t const* _characters; // _sentence_strの各文字 16 | id* _word_ids; // 2つと1つを含める 17 | std::wstring _sentence_str; // 生の文データ 18 | Sentence(std::wstring sentence); 19 | Sentence(std::wstring sentence, bool supervised); 20 | ~Sentence(); 21 | Sentence* copy(); 22 | int size(); 23 | bool is_supervised(); 24 | int get_num_segments(); 25 | int get_num_segments_without_special_tokens(); 26 | int get_word_length_at(int t); 27 | id get_word_id_at(int t); 28 | id get_substr_word_id(int start_index, int end_index); // end_indexを含む 29 | std::wstring get_substr_word_str(int start_index, int end_index); // endを含む 30 | std::wstring get_word_str_at(int t); // t=0,1の時はが返る 31 | void dump_characters(); 32 | void dump_words(); 33 | // num_segmentsにはの数は含めない 34 | void split(int* segments_without_special_tokens, int num_segments_without_special_tokens); 35 | void split(std::vector &segments_without_special_tokens); 36 | }; 37 | } // namespace npylm -------------------------------------------------------------------------------- /src/npylm/wordtype.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ctype.h" 3 | #include "wordtype.h" 4 | 5 | namespace npylm { 6 | namespace wordtype { 7 | bool is_dash(wchar_t character){ 8 | if(character == 0x30FC){ 9 | return true; 10 | } 11 | return false; 12 | } 13 | bool is_hiragana(wchar_t character){ 14 | int type = ctype::get_type(character); 15 | if(type == CTYPE_HIRAGANA){ 16 | return true; 17 | } 18 | return is_dash(character); // 長音はひらがなとカタカナ両方で使われる 19 | } 20 | bool is_katakana(wchar_t character){ 21 | int type = ctype::get_type(character); 22 | if(type == CTYPE_KATAKANA){ 23 | return true; 24 | } 25 | if(type == CTYPE_KATAKANA_PHONETIC_EXTENSIONS){ 26 | return true; 27 | } 28 | return is_dash(character); // 長音はひらがなとカタカナ両方で使われる 29 | } 30 | bool is_kanji(wchar_t character){ 31 | int type = ctype::get_type(character); 32 | if(type == CTYPE_CJK_UNIFIED_IDEOGRAPHS){ 33 | return true; 34 | } 35 | if(type == CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_A){ 36 | return true; 37 | } 38 | if(type == CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_B){ 39 | return true; 40 | } 41 | if(type == CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_C){ 42 | return true; 43 | } 44 | if(type == CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_D){ 45 | return true; 46 | } 47 | if(type == CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_E){ 48 | return true; 49 | } 50 | if(type == CTYPE_CJK_UNIFIED_IDEOGRAPHS_EXTENSION_F){ 51 | return true; 52 | } 53 | if(type == CTYPE_CJK_RADICALS_SUPPLEMENT){ 54 | return true; 55 | } 56 | return false; 57 | 58 | } 59 | bool is_number(wchar_t character){ 60 | int type = ctype::get_type(character); 61 | if(type == CTYPE_BASIC_LATIN){ 62 | if(0x30 <= character && character <= 0x39){ 63 | return true; 64 | } 65 | return false; 66 | } 67 | if(type == CTYPE_NUMBER_FORMS){ 68 | return true; 69 | } 70 | if(type == CTYPE_COMMON_INDIC_NUMBER_FORMS){ 71 | return true; 72 | } 73 | if(type == CTYPE_AEGEAN_NUMBERS){ 74 | return true; 75 | } 76 | if(type == CTYPE_ANCIENT_GREEK_NUMBERS){ 77 | return true; 78 | } 79 | if(type == CTYPE_COPTIC_EPACT_NUMBERS){ 80 | return true; 81 | } 82 | if(type == CTYPE_SINHALA_ARCHAIC_NUMBERS){ 83 | return true; 84 | } 85 | if(type == CTYPE_CUNEIFORM_NUMBERS_AND_PUNCTUATION){ 86 | return true; 87 | } 88 | return false; 89 | } 90 | bool is_alphabet(wchar_t character){ 91 | if(0x41 <= character && character <= 0x5a){ 92 | return true; 93 | } 94 | 
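		// 0x41..0x5A is ASCII 'A'..'Z'; the next range, 0x61..0x7A, is ASCII 'a'..'z'.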
if(0x61 <= character && character <= 0x7a){ 95 | return true; 96 | } 97 | return false; 98 | } 99 | bool is_symbol(wchar_t character){ 100 | if(is_alphabet(character)){ 101 | return false; 102 | } 103 | if(is_number(character)){ 104 | return false; 105 | } 106 | if(is_kanji(character)){ 107 | return false; 108 | } 109 | if(is_hiragana(character)){ 110 | return false; 111 | } 112 | return true; 113 | } 114 | int detect_word_type(std::wstring &word){ 115 | return detect_word_type_substr(word.data(), 0, word.size() - 1); 116 | } 117 | // 文字列の指定範囲の単語種判定 118 | int detect_word_type_substr(wchar_t const* characters, int substr_start, int substr_end){ 119 | assert(substr_end >= substr_start); 120 | int num_alphabet = 0; 121 | int num_number = 0; 122 | int num_symbol = 0; 123 | int num_hiragana = 0; 124 | int num_katakana = 0; 125 | int num_kanji = 0; 126 | int num_dash = 0; 127 | int size = substr_end - substr_start + 1; 128 | for(int i = substr_start;i <= substr_end;i++){ 129 | const wchar_t target = characters[i]; 130 | if(is_alphabet(target)){ 131 | num_alphabet += 1; 132 | continue; 133 | } 134 | if(is_number(target)){ 135 | num_number += 1; 136 | continue; 137 | } 138 | if(is_dash(target)){ 139 | num_dash += 1; 140 | continue; 141 | } 142 | if(is_hiragana(target)){ 143 | num_hiragana += 1; 144 | continue; 145 | } 146 | if(is_katakana(target)){ 147 | num_katakana += 1; 148 | continue; 149 | } 150 | if(is_kanji(target)){ 151 | num_kanji += 1; 152 | continue; 153 | } 154 | num_symbol += 1; 155 | } 156 | if(num_alphabet == size){ 157 | return WORDTYPE_ALPHABET; 158 | } 159 | if(num_number == size){ 160 | return WORDTYPE_NUMBER; 161 | } 162 | if(num_hiragana + num_dash == size){ 163 | return WORDTYPE_HIRAGANA; 164 | } 165 | if(num_katakana + num_dash == size){ 166 | return WORDTYPE_KATAKANA; 167 | } 168 | if(num_kanji == size){ 169 | return WORDTYPE_KANJI; 170 | } 171 | if(num_symbol == size){ 172 | return WORDTYPE_SYMBOL; 173 | } 174 | if(num_kanji > 0){ 175 | if(num_hiragana + num_kanji == size){ 176 | return WORDTYPE_KANJI_HIRAGANA; 177 | } 178 | if(num_hiragana > 0){ 179 | if(num_hiragana + num_kanji + num_dash == size){ 180 | return WORDTYPE_KANJI_HIRAGANA; 181 | } 182 | } 183 | if(num_katakana + num_kanji == size){ 184 | return WORDTYPE_KANJI_KATAKANA; 185 | } 186 | if(num_katakana){ 187 | if(num_katakana + num_kanji + num_dash == size){ 188 | return WORDTYPE_KANJI_KATAKANA; 189 | } 190 | } 191 | } 192 | return WORDTYPE_OTHER; 193 | } 194 | } 195 | } -------------------------------------------------------------------------------- /src/npylm/wordtype.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define WORDTYPE_NUM_TYPES 9 6 | 7 | #define WORDTYPE_ALPHABET 1 8 | #define WORDTYPE_NUMBER 2 9 | #define WORDTYPE_SYMBOL 3 10 | #define WORDTYPE_HIRAGANA 4 11 | #define WORDTYPE_KATAKANA 5 12 | #define WORDTYPE_KANJI 6 13 | #define WORDTYPE_KANJI_HIRAGANA 7 14 | #define WORDTYPE_KANJI_KATAKANA 8 15 | #define WORDTYPE_OTHER 9 16 | 17 | namespace npylm { 18 | namespace wordtype { 19 | bool is_dash(wchar_t character); 20 | bool is_hiragana(wchar_t character); 21 | bool is_katakana(wchar_t character); 22 | bool is_kanji(wchar_t character); 23 | bool is_number(wchar_t character); 24 | bool is_alphabet(wchar_t character); 25 | bool is_symbol(wchar_t character); 26 | int detect_word_type(std::wstring &word); 27 | int detect_word_type_substr(wchar_t const* characters, int substr_start, int substr_end); 28 | } 29 | } 
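// Usage sketch (illustrative, not part of the original header): a word is classified by the
// character classes it is made of, e.g.
//     std::wstring w1 = L"ひらがな";   // all hiragana     -> WORDTYPE_HIRAGANA
//     std::wstring w2 = L"データ";     // katakana + dash  -> WORDTYPE_KATAKANA
//     std::wstring w3 = L"形態素";     // all kanji        -> WORDTYPE_KANJI
//     std::wstring w4 = L"見る";       // kanji + hiragana -> WORDTYPE_KANJI_HIRAGANA
//     assert(npylm::wordtype::detect_word_type(w1) == WORDTYPE_HIRAGANA);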
-------------------------------------------------------------------------------- /src/python.cpp: -------------------------------------------------------------------------------- 1 | #include "python/dataset.h" 2 | #include "python/dictionary.h" 3 | #include "python/model.h" 4 | #include "python/trainer.h" 5 | 6 | using namespace npylm; 7 | using boost::python::arg; 8 | 9 | BOOST_PYTHON_MODULE(npylm) 10 | { 11 | boost::python::class_("dictionary") 12 | .def("save", &Dictionary::save) 13 | .def("load", &Dictionary::load); 14 | 15 | boost::python::class_("corpus") 16 | .def("add_textfile", &Corpus::add_textfile) 17 | .def("add_true_segmentation", &Corpus::python_add_true_segmentation) 18 | .def("add_sentence", &Corpus::add_sentence); 19 | 20 | boost::python::class_("dataset", boost::python::init()) 21 | .def("get_max_sentence_length", &Dataset::get_max_sentence_length) 22 | .def("detect_hash_collision", &Dataset::detect_hash_collision) 23 | .def("get_num_sentences_train", &Dataset::get_num_sentences_train) 24 | .def("get_num_sentences_dev", &Dataset::get_num_sentences_dev) 25 | .def("get_num_sentences_supervised", &Dataset::get_num_sentences_supervised) 26 | .def("get_dict", &Dataset::get_dict_obj, boost::python::return_internal_reference<>()); 27 | 28 | boost::python::class_("trainer", boost::python::init((arg("dataset"), arg("model"), arg("always_accept_new_segmentation") = true))) 29 | .def("print_segmentation_train", &Trainer::print_segmentation_train) 30 | .def("print_segmentation_dev", &Trainer::print_segmentation_dev) 31 | .def("sample_hpylm_vpylm_hyperparameters", &Trainer::sample_hpylm_vpylm_hyperparameters) 32 | .def("sample_lambda", &Trainer::sample_lambda) 33 | .def("update_p_k_given_vpylm", &Trainer::update_p_k_given_vpylm) 34 | .def("compute_perplexity_train", &Trainer::compute_perplexity_train) 35 | .def("compute_perplexity_dev", &Trainer::compute_perplexity_dev) 36 | .def("gibbs", &Trainer::gibbs); 37 | 38 | boost::python::class_("model", boost::python::init()) 39 | .def(boost::python::init()) 40 | .def("set_initial_lambda_a", &Model::set_initial_lambda_a) 41 | .def("set_initial_lambda_b", &Model::set_initial_lambda_b) 42 | .def("set_vpylm_beta_stop", &Model::set_vpylm_beta_stop) 43 | .def("set_vpylm_beta_pass", &Model::set_vpylm_beta_pass) 44 | .def("get_lambda", &Model::python_get_lambda) 45 | .def("parse", &Model::python_parse) 46 | .def("save", &Model::save) 47 | .def("load", &Model::load); 48 | } -------------------------------------------------------------------------------- /src/python/corpus.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "corpus.h" 5 | 6 | namespace npylm { 7 | void Corpus::add_textfile(std::string filename){ 8 | std::wifstream ifs(filename.c_str()); 9 | std::wstring sentence_str; 10 | assert(ifs.fail() == false); 11 | while (getline(ifs, sentence_str)){ 12 | if (PyErr_CheckSignals() != 0) { // ctrl+cが押されたかチェック 13 | return; 14 | } 15 | if(sentence_str.empty()){ 16 | continue; 17 | } 18 | add_sentence(sentence_str); 19 | } 20 | } 21 | void Corpus::add_sentence(std::wstring sentence_str){ 22 | _sentence_str_list.push_back(sentence_str); 23 | } 24 | int Corpus::get_num_sentences(){ 25 | return _sentence_str_list.size(); 26 | } 27 | int Corpus::get_num_true_segmentations(){ 28 | return _word_sequence_list.size(); 29 | } 30 | void Corpus::_before_add_true_segmentation(boost::python::list &py_word_str_list, std::vector &word_str_vec){ 31 | int num_words = 
boost::python::len(py_word_str_list); 32 | for(int i = 0;i < num_words;i++){ 33 | std::wstring word = boost::python::extract(py_word_str_list[i]); 34 | word_str_vec.push_back(word); 35 | } 36 | } 37 | void Corpus::python_add_true_segmentation(boost::python::list py_word_str_list){ 38 | std::vector word_str_vec; 39 | _before_add_true_segmentation(py_word_str_list, word_str_vec); 40 | assert(word_str_vec.size() > 0); 41 | add_true_segmentation(word_str_vec); 42 | } 43 | void Corpus::add_true_segmentation(std::vector &word_str_vec){ 44 | assert(word_str_vec.size() > 1); 45 | _word_sequence_list.push_back(word_str_vec); 46 | } 47 | } -------------------------------------------------------------------------------- /src/python/corpus.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace npylm { 6 | class Corpus{ 7 | private: 8 | void _before_add_true_segmentation(boost::python::list &py_word_str_list, std::vector &word_str_vec); 9 | public: 10 | std::vector _sentence_str_list; 11 | std::vector> _word_sequence_list; 12 | Corpus(){} 13 | void add_textfile(std::string filename); 14 | void add_sentence(std::wstring sentence_str); 15 | void add_true_segmentation(std::vector &word_str_vec); // 正解の分割を追加する 16 | void python_add_true_segmentation(boost::python::list py_word_str_list); 17 | int get_num_sentences(); 18 | int get_num_true_segmentations(); 19 | }; 20 | } -------------------------------------------------------------------------------- /src/python/dataset.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "dataset.h" 5 | #include "../npylm/sampler.h" 6 | #include "../npylm/hash.h" 7 | 8 | namespace npylm { 9 | Dataset::Dataset(Corpus* corpus, double train_split, int seed){ 10 | _dict = new Dictionary(); 11 | _corpus = corpus; 12 | _max_sentence_length = 0; 13 | _avg_sentence_length = 0; 14 | int sum_sentence_length = 0; 15 | std::vector rand_indices; 16 | for(int i = 0;i < corpus->get_num_sentences();i++){ 17 | rand_indices.push_back(i); 18 | } 19 | // まず教師なし学習用のデータをtrain/devに振り分ける 20 | sampler::set_seed(seed); 21 | shuffle(rand_indices.begin(), rand_indices.end(), sampler::mt); // データをシャッフル 22 | train_split = std::min(1.0, std::max(0.0, train_split)); 23 | int num_train_data = corpus->get_num_sentences() * train_split; 24 | for(int i = 0;i < rand_indices.size();i++){ 25 | std::wstring &sentence_str = corpus->_sentence_str_list[rand_indices[i]]; 26 | if(i < num_train_data){ 27 | _add_words_to_dataset(sentence_str, _sentence_sequences_train); 28 | }else{ 29 | _add_words_to_dataset(sentence_str, _sentence_sequences_dev); 30 | } 31 | // 統計 32 | if(_max_sentence_length == 0 || sentence_str.size() > _max_sentence_length){ 33 | _max_sentence_length = sentence_str.size(); 34 | } 35 | sum_sentence_length += sentence_str.size(); 36 | } 37 | // 教師分割データがあればすべてtrainに追加 38 | _num_supervised_data = corpus->get_num_true_segmentations(); 39 | for(int i = 0;i < corpus->get_num_true_segmentations();i++){ 40 | // 分割から元の文を復元 41 | std::vector &words = corpus->_word_sequence_list[i]; 42 | std::vector segmentation; 43 | std::wstring sentence_str; 44 | for(auto word_str: words){ 45 | sentence_str += word_str; 46 | segmentation.push_back(word_str.size()); 47 | } 48 | // 構成文字を辞書に追加 49 | for(wchar_t character: sentence_str){ 50 | _dict->add_character(character); 51 | } 52 | // データセットに追加 53 | Sentence* sentence = new Sentence(sentence_str, true); 
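		// Worked example with assumed values: for a supervised line segmented as
		// {"この", "本", "を", "読む"}, the loop above rebuilds
		//     sentence_str = L"この本を読む"
		//     segmentation = {2, 1, 1, 2}
		// and split(segmentation) below re-applies exactly these gold word boundaries to the Sentence.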
54 | sentence->split(segmentation); // 分割 55 | _sentence_sequences_train.push_back(sentence); 56 | // 統計 57 | if(_max_sentence_length == 0 || sentence_str.size() > _max_sentence_length){ 58 | _max_sentence_length = sentence_str.size(); 59 | } 60 | sum_sentence_length += sentence_str.size(); 61 | } 62 | _avg_sentence_length = sum_sentence_length / (double)corpus->get_num_sentences(); 63 | } 64 | Dataset::~Dataset(){ 65 | for(int n = 0;n < _sentence_sequences_train.size();n++){ 66 | Sentence* sentence = _sentence_sequences_train[n]; 67 | delete sentence; 68 | } 69 | for(int n = 0;n < _sentence_sequences_dev.size();n++){ 70 | Sentence* sentence = _sentence_sequences_dev[n]; 71 | delete sentence; 72 | } 73 | delete _dict; 74 | } 75 | int Dataset::get_num_sentences_train(){ 76 | return _sentence_sequences_train.size(); 77 | } 78 | int Dataset::get_num_sentences_dev(){ 79 | return _sentence_sequences_dev.size(); 80 | } 81 | int Dataset::get_num_sentences_supervised(){ 82 | return _num_supervised_data; 83 | } 84 | void Dataset::_add_words_to_dataset(std::wstring &sentence_str, std::vector &dataset){ 85 | assert(sentence_str.size() > 0); 86 | for(wchar_t character: sentence_str){ 87 | _dict->add_character(character); 88 | } 89 | Sentence* sentence = new Sentence(sentence_str); 90 | dataset.push_back(sentence); 91 | } 92 | int Dataset::get_max_sentence_length(){ 93 | return _max_sentence_length; 94 | } 95 | int Dataset::get_average_sentence_length(){ 96 | return _avg_sentence_length; 97 | } 98 | int Dataset::detect_hash_collision(int max_word_length){ 99 | int step = 0; 100 | std::unordered_map pool; 101 | for(Sentence* sentence: _sentence_sequences_train){ 102 | if (PyErr_CheckSignals() != 0) { // ctrl+cが押されたかチェック 103 | return 0; 104 | } 105 | _detect_collision_of_sentence(sentence, pool, max_word_length); 106 | step++; 107 | } 108 | for(Sentence* sentence: _sentence_sequences_dev){ 109 | if (PyErr_CheckSignals() != 0) { // ctrl+cが押されたかチェック 110 | return 0; 111 | } 112 | _detect_collision_of_sentence(sentence, pool, max_word_length); 113 | step++; 114 | } 115 | return pool.size(); 116 | } 117 | void Dataset::_detect_collision_of_sentence(Sentence* sentence, std::unordered_map &pool, int max_word_length){ 118 | for(int t = 1;t <= sentence->size();t++){ 119 | for(int k = 1;k <= std::min(t, max_word_length);k++){ 120 | id word_id = sentence->get_substr_word_id(t - k, t - 1); 121 | std::wstring word = sentence->get_substr_word_str(t - k, t - 1); 122 | assert(word_id == hash_wstring(word)); 123 | auto itr = pool.find(word_id); 124 | if(itr == pool.end()){ 125 | pool[word_id] = word; 126 | }else{ 127 | assert(itr->second == word); 128 | } 129 | } 130 | } 131 | } 132 | Dictionary &Dataset::get_dict_obj(){ 133 | return *_dict; 134 | } 135 | } -------------------------------------------------------------------------------- /src/python/dataset.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "../npylm/common.h" 6 | #include "../npylm/sentence.h" 7 | #include "corpus.h" 8 | #include "dictionary.h" 9 | 10 | namespace npylm { 11 | class Dataset{ 12 | private: 13 | Corpus* _corpus; 14 | void _add_words_to_dataset(std::wstring &sentence_str, std::vector &dataset); 15 | void _detect_collision_of_sentence(Sentence* sentence, std::unordered_map &pool, int max_word_length); 16 | public: 17 | int _max_sentence_length; 18 | int _avg_sentence_length; 19 | int _num_supervised_data; 20 | Dictionary* _dict; 21 | 
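	// Filled by the constructor: sentences are partitioned into a training split and a held-out
	// dev split, and supervised (pre-segmented) sentences are always appended to the training split.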
std::vector _sentence_sequences_train; 22 | std::vector _sentence_sequences_dev; 23 | Dataset(Corpus* corpus, double train_split, int seed); 24 | ~Dataset(); 25 | int get_num_sentences_train(); 26 | int get_num_sentences_supervised(); 27 | int get_num_sentences_dev(); 28 | int get_max_sentence_length(); 29 | int get_average_sentence_length(); 30 | int detect_hash_collision(int max_word_length); 31 | Dictionary &get_dict_obj(); 32 | }; 33 | } -------------------------------------------------------------------------------- /src/python/dictionary.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "dictionary.h" 8 | 9 | namespace npylm { 10 | void Dictionary::add_character(wchar_t character){ 11 | _all_characters.insert(character); 12 | } 13 | int Dictionary::get_num_characters(){ 14 | return _all_characters.size(); 15 | } 16 | bool Dictionary::load(std::string filename){ 17 | std::string dictionary_filename = filename; 18 | std::ifstream ifs(dictionary_filename); 19 | if(ifs.good()){ 20 | boost::archive::binary_iarchive iarchive(ifs); 21 | iarchive >> _all_characters; 22 | ifs.close(); 23 | return true; 24 | } 25 | ifs.close(); 26 | return false; 27 | } 28 | bool Dictionary::save(std::string filename){ 29 | std::ofstream ofs(filename); 30 | boost::archive::binary_oarchive oarchive(ofs); 31 | oarchive << _all_characters; 32 | ofs.close(); 33 | return true; 34 | } 35 | } -------------------------------------------------------------------------------- /src/python/dictionary.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace npylm { 5 | class Dictionary{ 6 | public: 7 | std::unordered_set _all_characters; // すべての文字 8 | Dictionary(){} 9 | void add_character(wchar_t character); 10 | int get_num_characters(); 11 | bool load(std::string filename); 12 | bool save(std::string filename); 13 | }; 14 | } -------------------------------------------------------------------------------- /src/python/model.cpp: -------------------------------------------------------------------------------- 1 | #include "model.h" 2 | #include "../npylm/common.h" 3 | #include "../npylm/wordtype.h" 4 | #include 5 | 6 | namespace npylm { 7 | Model::Model(Dataset* dataset, int max_word_length) 8 | { 9 | _set_locale(); 10 | int max_sentence_length = dataset->get_max_sentence_length(); 11 | double vpylm_g0 = 1.0 / (double)dataset->_dict->get_num_characters(); 12 | _npylm = new NPYLM(max_word_length, max_sentence_length, vpylm_g0, 4, 1, VPYLM_BETA_STOP, VPYLM_BETA_PASS); 13 | _lattice = new Lattice(_npylm, max_word_length, dataset->get_max_sentence_length()); 14 | } 15 | Model::Model(Dataset* dataset, 16 | int max_word_length, // 可能な単語長の最大値. 
英語16, 日本語8程度 17 | double initial_lambda_a, // 単語長のポアソン分布のλの事前分布のハイパーパラメータ 18 | double initial_lambda_b, // 単語長のポアソン分布のλの事前分布のハイパーパラメータ 19 | double vpylm_beta_stop, // VPYLMのハイパーパラメータ 20 | double vpylm_beta_pass) 21 | { // VPYLMのハイパーパラメータ 22 | _set_locale(); 23 | int max_sentence_length = dataset->get_max_sentence_length(); 24 | double vpylm_g0 = 1.0 / (double)dataset->_dict->get_num_characters(); 25 | _npylm = new NPYLM(max_word_length, max_sentence_length, vpylm_g0, initial_lambda_a, initial_lambda_b, vpylm_beta_stop, vpylm_beta_pass); 26 | _lattice = new Lattice(_npylm, max_word_length, dataset->get_max_sentence_length()); 27 | } 28 | Model::Model(std::string filename) 29 | { 30 | _set_locale(); 31 | _npylm = new NPYLM(); 32 | if (load(filename) == false) { 33 | std::cout << filename << " not found." << std::endl; 34 | exit(0); 35 | } 36 | _lattice = new Lattice(_npylm, _npylm->_max_word_length, _npylm->_max_sentence_length); 37 | } 38 | Model::~Model() 39 | { 40 | delete _npylm; 41 | } 42 | // 日本語周り 43 | void Model::_set_locale() 44 | { 45 | setlocale(LC_CTYPE, "ja_JP.UTF-8"); 46 | std::ios_base::sync_with_stdio(false); 47 | std::locale default_loc("ja_JP.UTF-8"); 48 | std::locale::global(default_loc); 49 | std::locale ctype_default(std::locale::classic(), default_loc, std::locale::ctype); //※ 50 | std::wcout.imbue(ctype_default); 51 | std::wcin.imbue(ctype_default); 52 | } 53 | int Model::get_max_word_length() 54 | { 55 | return _npylm->_max_word_length; 56 | } 57 | void Model::set_initial_lambda_a(double lambda) 58 | { 59 | _npylm->_lambda_a = lambda; 60 | _npylm->sample_lambda_with_initial_params(); 61 | } 62 | void Model::set_initial_lambda_b(double lambda) 63 | { 64 | _npylm->_lambda_b = lambda; 65 | _npylm->sample_lambda_with_initial_params(); 66 | } 67 | void Model::set_vpylm_beta_stop(double stop) 68 | { 69 | _npylm->_vpylm->_beta_stop = stop; 70 | } 71 | void Model::set_vpylm_beta_pass(double pass) 72 | { 73 | _npylm->_vpylm->_beta_pass = pass; 74 | } 75 | bool Model::load(std::string filename) 76 | { 77 | bool success = false; 78 | std::ifstream ifs(filename); 79 | if (ifs.good()) { 80 | boost::archive::binary_iarchive iarchive(ifs); 81 | iarchive >> *_npylm; 82 | success = true; 83 | } 84 | ifs.close(); 85 | return success; 86 | } 87 | bool Model::save(std::string filename) 88 | { 89 | bool success = false; 90 | std::ofstream ofs(filename); 91 | if (ofs.good()) { 92 | boost::archive::binary_oarchive oarchive(ofs); 93 | oarchive << *_npylm; 94 | success = true; 95 | } 96 | ofs.close(); 97 | return success; 98 | } 99 | void Model::parse(std::wstring sentence_str, std::vector& words) 100 | { 101 | // 領域の再確保 102 | _lattice->reserve(_npylm->_max_word_length, sentence_str.size()); 103 | _npylm->reserve(sentence_str.size()); 104 | words.clear(); 105 | std::vector segments; // 分割の一時保存用 106 | Sentence* sentence = new Sentence(sentence_str); 107 | _lattice->viterbi_decode(sentence, segments); 108 | sentence->split(segments); 109 | for (int n = 0; n < sentence->get_num_segments_without_special_tokens(); n++) { 110 | std::wstring word = sentence->get_word_str_at(n + 2); 111 | words.push_back(word); 112 | } 113 | delete sentence; 114 | } 115 | boost::python::list Model::python_parse(std::wstring sentence_str) 116 | { 117 | // 領域の再確保 118 | _lattice->reserve(_npylm->_max_word_length, sentence_str.size()); 119 | _npylm->reserve(sentence_str.size()); 120 | std::vector segments; // 分割の一時保存用 121 | Sentence* sentence = new Sentence(sentence_str); 122 | _lattice->viterbi_decode(sentence, 
segments); 123 | sentence->split(segments); 124 | boost::python::list words; 125 | for (int n = 0; n < sentence->get_num_segments_without_special_tokens(); n++) { 126 | std::wstring word = sentence->get_word_str_at(n + 2); 127 | words.append(word); 128 | } 129 | delete sentence; 130 | return words; 131 | } 132 | // use_scaling=trueならアンダーフローを防ぐ 133 | double Model::compute_log_forward_probability(std::wstring sentence_str, bool use_scaling) 134 | { 135 | // キャッシュの再確保 136 | _lattice->reserve(_npylm->_max_word_length, sentence_str.size()); 137 | _npylm->reserve(sentence_str.size()); 138 | Sentence* sentence = new Sentence(sentence_str); 139 | double log_px = _lattice->compute_log_forward_probability(sentence, use_scaling); 140 | delete sentence; 141 | return log_px; 142 | } 143 | 144 | boost::python::list Model::python_get_lambda() 145 | { 146 | boost::python::list ret; 147 | for (int type = 1; type <= WORDTYPE_NUM_TYPES; type++) { 148 | ret.append(_npylm->_lambda_for_type[type]); 149 | } 150 | return ret; 151 | } 152 | } -------------------------------------------------------------------------------- /src/python/model.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "../npylm/npylm.h" 4 | #include "../npylm/lattice.h" 5 | #include "dataset.h" 6 | #include "dictionary.h" 7 | 8 | namespace npylm { 9 | class Model{ 10 | private: 11 | void _set_locale(); 12 | public: 13 | NPYLM* _npylm; 14 | Lattice* _lattice; // forward filtering-backward sampling 15 | Model(Dataset* dataset, int max_word_length); 16 | Model(Dataset* dataset, 17 | int max_word_length, 18 | double initial_lambda_a, 19 | double initial_lambda_b, 20 | double vpylm_beta_stop, 21 | double vpylm_beta_pass); 22 | Model(std::string filename); 23 | ~Model(); 24 | int get_max_word_length(); 25 | void set_initial_lambda_a(double lambda); 26 | void set_initial_lambda_b(double lambda); 27 | void set_vpylm_beta_stop(double stop); 28 | void set_vpylm_beta_pass(double pass); 29 | double compute_log_forward_probability(std::wstring sentence_str, bool use_scaling = true); 30 | bool load(std::string filename); 31 | bool save(std::string filename); 32 | void parse(std::wstring sentence_str, std::vector &words); 33 | boost::python::list python_parse(std::wstring sentence_str); 34 | boost::python::list python_get_lambda(); 35 | }; 36 | } -------------------------------------------------------------------------------- /src/python/trainer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "../npylm/sampler.h" 5 | #include "../npylm/wordtype.h" 6 | #include "trainer.h" 7 | 8 | namespace npylm { 9 | Trainer::Trainer(Dataset* dataset, Model* model, bool always_accept_new_segmentation){ 10 | _dataset = dataset; 11 | _model = model; 12 | _dict = dataset->_dict; 13 | _vpylm_sampling_probability_table = new double[_dict->get_num_characters() + 1]; // を含む 14 | _vpylm_sampling_id_table = new wchar_t[_dict->get_num_characters() + 1]; // を含む 15 | _added_to_npylm_train = new bool[dataset->_sentence_sequences_train.size()]; 16 | for(int data_index = 0;data_index < dataset->_sentence_sequences_train.size();data_index++){ 17 | _rand_indices_train.push_back(data_index); 18 | _added_to_npylm_train[data_index] = false; 19 | } 20 | for(int data_index = 0;data_index < dataset->_sentence_sequences_dev.size();data_index++){ 21 | _rand_indices_dev.push_back(data_index); 22 | } 23 | _always_accept_new_segmentation 
= always_accept_new_segmentation; 24 | _num_segmentation_rejection = 0; 25 | _num_segmentation_acceptance = 0; 26 | } 27 | 28 | // HPYLM,VPYLMのdとthetaをサンプリング 29 | void Trainer::sample_hpylm_vpylm_hyperparameters(){ 30 | _model->_npylm->sample_hpylm_vpylm_hyperparameters(); 31 | } 32 | // 文字種ごとにλのサンプリング 33 | void Trainer::sample_lambda(){ 34 | std::vector a_for_type(WORDTYPE_NUM_TYPES + 1, 0.0); 35 | std::vector b_for_type(WORDTYPE_NUM_TYPES + 1, 0.0); 36 | std::unordered_set words; 37 | NPYLM* npylm = _model->_npylm; 38 | for(int type = 1;type <= WORDTYPE_NUM_TYPES;type++){ 39 | a_for_type[type] = npylm->_lambda_a; 40 | b_for_type[type] = npylm->_lambda_b; 41 | } 42 | for(auto sentence: _dataset->_sentence_sequences_train){ 43 | // は除外 44 | for(int t = 2;t < sentence->get_num_segments() - 1;t++){ 45 | std::wstring word = sentence->get_word_str_at(t); 46 | id word_id = sentence->get_word_id_at(t); 47 | int word_length = sentence->get_word_length_at(t); 48 | if(word_length > npylm->_max_word_length){ 49 | continue; 50 | } 51 | if(words.find(word_id) == words.end()){ 52 | std::vector &tables = npylm->_hpylm->_root->_arrangement[word_id]; 53 | int t_w = tables.size(); 54 | int type = wordtype::detect_word_type(word); 55 | a_for_type[type] += t_w * word_length; 56 | b_for_type[type] += t_w; 57 | words.insert(word_id); 58 | } 59 | } 60 | } 61 | for(int type = 1;type <= WORDTYPE_NUM_TYPES;type++){ 62 | double lambda = sampler::gamma(a_for_type[type], b_for_type[type]); 63 | npylm->_lambda_for_type[type] = lambda; 64 | } 65 | } 66 | // VPYLMに文脈を渡し次の文字を生成 67 | wchar_t Trainer::sample_word_from_vpylm_given_context(wchar_t* context_ids, int context_length, int sample_t, bool skip_eow){ 68 | double sum_probs = 0; 69 | lm::VPYLM* vpylm = _model->_npylm->_vpylm; 70 | int table_index = 0; 71 | auto all_characters = _dict->_all_characters; 72 | int num_characters = _dict->get_num_characters(); 73 | for(wchar_t character_id: all_characters){ 74 | assert(table_index < num_characters); 75 | double pw = vpylm->compute_p_w_given_h(character_id, context_ids, 0, context_length - 1); 76 | sum_probs += pw; 77 | _vpylm_sampling_probability_table[table_index] = pw; 78 | _vpylm_sampling_id_table[table_index] = character_id; 79 | table_index++; 80 | } 81 | if(skip_eow == false){ 82 | assert(table_index < num_characters + 1); 83 | double pw = vpylm->compute_p_w_given_h(ID_EOW, context_ids, 0, context_length - 1); 84 | sum_probs += pw; 85 | _vpylm_sampling_probability_table[table_index] = pw; 86 | _vpylm_sampling_id_table[table_index] = ID_EOW; 87 | } 88 | 89 | double normalizer = 1.0 / sum_probs; 90 | double r = sampler::uniform(0, 1); 91 | double stack = 0; 92 | for(int i = 0;i <= table_index;i++){ 93 | stack += _vpylm_sampling_probability_table[i] * normalizer; 94 | if(r <= stack){ 95 | return _vpylm_sampling_id_table[i]; 96 | } 97 | } 98 | return _vpylm_sampling_id_table[table_index]; 99 | } 100 | // VPYLMから長さkの単語が出現する確率をキャッシュする 101 | void Trainer::update_p_k_given_vpylm(){ 102 | int num_samples = 20000; 103 | int early_stopping_threshold = 10; 104 | int max_word_length = _model->get_max_word_length() + 1; 105 | double* pk_vpylm = _model->_npylm->_pk_vpylm; 106 | int* num_words_of_k = new int[max_word_length]; 107 | for(int i = 0;i <= max_word_length;i++){ 108 | pk_vpylm[i] = 0; 109 | num_words_of_k[i] = 0; 110 | } 111 | wchar_t* wrapped_character_ids = new wchar_t[max_word_length + 2]; 112 | double sum_words = 0; 113 | for(int m = 1;m <= num_samples;m++){ 114 | if (PyErr_CheckSignals() != 0) { // ctrl+cが押されたかチェック 
115 | return; 116 | } 117 | // wcout << "m = " << m << endl; 118 | wrapped_character_ids[0] = ID_BOW; 119 | int k = 0; 120 | for(int j = 0;j < max_word_length;j++){ 121 | bool skip_eow = (j == 0) ? true : false; 122 | wchar_t token_char = sample_word_from_vpylm_given_context(wrapped_character_ids, j + 1, j + 1, skip_eow); 123 | wrapped_character_ids[j + 1] = token_char; 124 | if(token_char == ID_EOW){ 125 | break; 126 | } 127 | k++; 128 | } 129 | sum_words += 1; 130 | if(k == 0){ // 131 | continue; 132 | } 133 | assert(k <= max_word_length); 134 | num_words_of_k[k] += 1; 135 | 136 | // すべてのkが生成されていたら早期終了 137 | if(m % 100 == 0){ 138 | bool stop = true; 139 | for(int k = 1;k <= max_word_length;k++){ 140 | if(num_words_of_k[k] < early_stopping_threshold){ 141 | stop = false; 142 | break; 143 | } 144 | } 145 | if(stop){ 146 | break; 147 | } 148 | } 149 | } 150 | for(int k = 1;k <= max_word_length;k++){ 151 | pk_vpylm[k] = (num_words_of_k[k] + 1) / (sum_words + max_word_length); // ラプラススムージングを入れておく 152 | assert(pk_vpylm[k] > 0); 153 | } 154 | delete[] num_words_of_k; 155 | delete[] wrapped_character_ids; 156 | } 157 | // 単語分割のギブスサンプリング 158 | void Trainer::gibbs(){ 159 | int num_sentences = _dataset->_sentence_sequences_train.size(); 160 | assert(num_sentences > 0); 161 | int max_sentence_length = _dataset->get_max_sentence_length(); 162 | std::vector segments; // 分割の一時保存用 163 | shuffle(_rand_indices_train.begin(), _rand_indices_train.end(), sampler::mt); // データをシャッフル 164 | int* old_segments = new int[max_sentence_length + 3]; 165 | int num_old_segments; 166 | // モデルパラメータを更新 167 | for(int step = 1;step <= num_sentences;step++){ 168 | if (PyErr_CheckSignals() != 0) { // ctrl+cが押されたかチェック 169 | return; 170 | } 171 | // 訓練データを一つ取り出す 172 | int data_index = _rand_indices_train[step - 1]; 173 | assert(data_index < _dataset->_sentence_sequences_train.size()); 174 | Sentence* sentence = _dataset->_sentence_sequences_train[data_index]; 175 | 176 | // 教師あり 177 | if(sentence->is_supervised()){ 178 | // モデルに追加されているかチェック 179 | if(_added_to_npylm_train[data_index] == true){ 180 | // 古い分割をモデルから削除 181 | for(int t = 2;t < sentence->get_num_segments();t++){ 182 | _model->_npylm->remove_customer_at_time_t(sentence, t); 183 | } 184 | } 185 | // 同じ分割結果を再度モデルに追加 186 | // ???「同じ分割を追加するなら最初から削除しなければ良いのでは?」 187 | // 追加と削除を繰り返すことでHPYLMとVPYLMのパラメータ(客の配置)がギブスサンプリングされるので必要 188 | for(int t = 2;t < sentence->get_num_segments();t++){ 189 | _model->_npylm->add_customer_at_time_t(sentence, t); 190 | } 191 | _added_to_npylm_train[data_index] = true; 192 | continue; 193 | } 194 | // 教師なし 195 | // モデルに追加されているかチェック 196 | if(_added_to_npylm_train[data_index] == true){ 197 | double old_log_ps, new_log_ps; 198 | // 古い分割をモデルから削除 199 | for(int t = 2;t < sentence->get_num_segments();t++){ 200 | _model->_npylm->remove_customer_at_time_t(sentence, t); 201 | } 202 | // 新しい分割の棄却判定をするかどうか 203 | if(_always_accept_new_segmentation == false){ 204 | // 古い分割を一時保存 205 | // は無視 206 | for(int i = 0;i < sentence->get_num_segments_without_special_tokens();i++){ 207 | old_segments[i] = sentence->_segments[i + 2]; // は2つ 208 | } 209 | num_old_segments = sentence->get_num_segments_without_special_tokens(); 210 | // 古い分割での文の確率を計算 211 | old_log_ps = _model->_npylm->compute_log_p_w(sentence); 212 | } 213 | 214 | #ifdef __DEBUG__ 215 | // 正規化しない場合の結果と比較するためシードを合わせる 216 | int seed = (unsigned int)time(NULL); 217 | sampler::mt.seed(seed); 218 | #endif 219 | 220 | // 新しい分割を取得 221 | _model->_lattice->blocked_gibbs(sentence, segments, true); 222 | 
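			// (Sketch, restating the optional accept/reject step further below.) When
			// _always_accept_new_segmentation == false, the newly sampled segmentation is kept
			// with Metropolis-style probability
			//     accept = min(1, p_new / p_old) = min(1, exp(new_log_ps - old_log_ps))
			// roughly:
			//     double accept = std::min(1.0, std::exp(new_log_ps - old_log_ps));
			//     if(sampler::uniform(0, 1) <= accept){ /* keep the new segmentation */ }
			//     else{ /* restore the old one via sentence->split(old_segments, num_old_segments) */ }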
sentence->split(segments); 223 | 224 | #ifdef __DEBUG__ 225 | // 正規化しない場合の結果と比較 226 | std::vector a = segments; 227 | sampler::mt.seed(seed); 228 | _model->_lattice->blocked_gibbs(sentence, segments, false); 229 | std::vector b = segments; 230 | assert(a.size() == b.size()); 231 | for(int i = 0;i < a.size();i++){ 232 | assert(a[i] == b[i]); 233 | } 234 | #endif 235 | 236 | // 以前の分割結果と現在の分割結果の確率を求める 237 | // 本来は分割を一定数サンプリングして平均をとるべき 238 | if(_always_accept_new_segmentation == false){ 239 | new_log_ps = _model->_npylm->compute_log_p_w(sentence); 240 | // 新しい分割の方が確率が低い場合、比率のベルヌーイ試行でどちらを採用するか決める. 241 | double bernoulli = std::min(1.0, exp(new_log_ps - old_log_ps)); 242 | double r = sampler::uniform(0, 1); 243 | if(bernoulli < r){ 244 | // 新しい分割を捨てて古いものに差し替える 245 | sentence->split(old_segments, num_old_segments); 246 | _num_segmentation_rejection++; 247 | }else{ 248 | _num_segmentation_acceptance++; 249 | } 250 | } 251 | } 252 | // 新しい分割結果をモデルに追加 253 | for(int t = 2;t < sentence->get_num_segments();t++){ 254 | _model->_npylm->add_customer_at_time_t(sentence, t); 255 | } 256 | _added_to_npylm_train[data_index] = true; 257 | } 258 | // 客数チェック 259 | assert(_model->_npylm->_hpylm->_root->_num_tables <= _model->_npylm->_vpylm->get_num_customers()); 260 | delete[] old_segments; 261 | } 262 | double Trainer::compute_perplexity_train(){ 263 | return _compute_perplexity(_dataset->_sentence_sequences_train); 264 | } 265 | double Trainer::compute_perplexity_dev(){ 266 | return _compute_perplexity(_dataset->_sentence_sequences_dev); 267 | } 268 | // ビタビアルゴリズムによる最尤分割のパープレキシティ 269 | double Trainer::_compute_perplexity(std::vector &dataset){ 270 | if(dataset.size() == 0){ 271 | return 0; 272 | } 273 | double ppl = 0; 274 | int num_sentences = dataset.size(); 275 | std::vector segments; // 分割の一時保存用 276 | for(int data_index = 0;data_index < num_sentences;data_index++){ 277 | if (PyErr_CheckSignals() != 0) { // ctrl+cが押されたかチェック 278 | return 0; 279 | } 280 | Sentence* sentence = dataset[data_index]->copy(); // 干渉を防ぐためコピー 281 | _model->_lattice->viterbi_decode(sentence, segments); 282 | sentence->split(segments); 283 | ppl += _model->_npylm->compute_log_p_w(sentence) / ((double)sentence->get_num_segments() - 2); 284 | delete sentence; 285 | } 286 | ppl = exp(-ppl / num_sentences); 287 | return ppl; 288 | } 289 | double Trainer::compute_log_likelihood_train(){ 290 | return _compute_log_likelihood(_dataset->_sentence_sequences_train); 291 | } 292 | double Trainer::compute_log_likelihood_dev(){ 293 | return _compute_log_likelihood(_dataset->_sentence_sequences_dev); 294 | } 295 | double Trainer::_compute_log_likelihood(std::vector &dataset){ 296 | if(dataset.size() == 0){ 297 | return 0; 298 | } 299 | double sum_log_likelihood = 0; 300 | int num_sentences = dataset.size(); 301 | for(int data_index = 0;data_index < num_sentences;data_index++){ 302 | if (PyErr_CheckSignals() != 0) { // ctrl+cが押されたかチェック 303 | return 0; 304 | } 305 | Sentence* sentence = dataset[data_index]; 306 | double log_px = _model->_lattice->compute_log_forward_probability(sentence, true); 307 | #ifdef __DEBUG__ 308 | double _log_px = _model->_lattice->compute_log_forward_probability(sentence, false); 309 | assert(abs(log_px - _log_px) < 1e-8); 310 | #endif 311 | sum_log_likelihood += log_px; 312 | } 313 | return sum_log_likelihood; 314 | } 315 | // デバッグ用 316 | void Trainer::remove_all_data(){ 317 | int max_sentence_length = _dataset->get_max_sentence_length(); 318 | wchar_t* wrapped_character_ids = new wchar_t[max_sentence_length + 2]; // を追加 319 
| for(int data_index = 0;data_index < _dataset->_sentence_sequences_train.size();data_index++){ 320 | if (PyErr_CheckSignals() != 0) { // ctrl+cが押されたかチェック 321 | return; 322 | } 323 | Sentence* sentence = _dataset->_sentence_sequences_train[data_index]; 324 | // 古い分割をモデルから削除 325 | if(_added_to_npylm_train[data_index] == true){ 326 | for(int t = 2;t < sentence->get_num_segments();t++){ 327 | _model->_npylm->remove_customer_at_time_t(sentence, t); 328 | } 329 | } 330 | } 331 | delete[] wrapped_character_ids; 332 | } 333 | void Trainer::print_segmentation_train(int num_to_print){ 334 | _print_segmentation(num_to_print, _dataset->_sentence_sequences_train, _rand_indices_train); 335 | } 336 | void Trainer::print_segmentation_dev(int num_to_print){ 337 | shuffle(_rand_indices_dev.begin(), _rand_indices_dev.end(), sampler::mt); 338 | _print_segmentation(num_to_print, _dataset->_sentence_sequences_dev, _rand_indices_dev); 339 | } 340 | void Trainer::_print_segmentation(int num_to_print, std::vector &dataset, std::vector &rand_indices){ 341 | num_to_print = std::min((int)dataset.size(), num_to_print); 342 | std::vector segments; // 分割の一時保存用 343 | for(int n = 0;n < num_to_print;n++){ 344 | if (PyErr_CheckSignals() != 0) { // ctrl+cが押されたかチェック 345 | return; 346 | } 347 | int data_index = rand_indices[n]; 348 | Sentence* sentence = dataset[data_index]->copy(); 349 | _model->_lattice->viterbi_decode(sentence, segments); 350 | sentence->split(segments); 351 | sentence->dump_words(); 352 | delete sentence; 353 | } 354 | } 355 | } -------------------------------------------------------------------------------- /src/python/trainer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "dataset.h" 6 | #include "model.h" 7 | #include "dictionary.h" 8 | 9 | namespace npylm { 10 | class Trainer{ 11 | private: 12 | std::vector _rand_indices_train; 13 | std::vector _rand_indices_dev; 14 | Dataset* _dataset; 15 | Dictionary* _dict; 16 | Model* _model; 17 | double* _vpylm_sampling_probability_table; 18 | wchar_t* _vpylm_sampling_id_table; 19 | bool _always_accept_new_segmentation; 20 | bool* _added_to_npylm_train; 21 | int _num_segmentation_rejection; 22 | int _num_segmentation_acceptance; 23 | void _print_segmentation(int num_to_print, std::vector &dataset, std::vector &rand_indices); 24 | double _compute_perplexity(std::vector &dataset); 25 | double _compute_log_likelihood(std::vector &dataset); 26 | public: 27 | Trainer(Dataset* dataset, Model* model, bool always_accept_new_segmentation); 28 | void remove_all_data(); 29 | void gibbs(); 30 | void sample_hpylm_vpylm_hyperparameters(); 31 | void sample_lambda(); 32 | wchar_t sample_word_from_vpylm_given_context(wchar_t* context_ids, int context_length, int sample_t, bool skip_eow = false); 33 | void update_p_k_given_vpylm(); 34 | double compute_perplexity_train(); 35 | double compute_perplexity_dev(); 36 | double compute_log_likelihood_train(); 37 | double compute_log_likelihood_dev(); 38 | void print_segmentation_train(int num_to_print); 39 | void print_segmentation_dev(int num_to_print); 40 | }; 41 | } -------------------------------------------------------------------------------- /test/generate_test_sequence.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse, codecs 3 | import numpy as np 4 | 5 | def main(args): 6 | words = [char + char + char for char in ["a", "b", "c", "d", "e", "f", "g", 
"h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"]] 7 | prob = np.full((len(words), len(words)), 1.0 / len(words), dtype=np.float32) 8 | with codecs.open("../dataset/test.txt", "w", "utf-8") as f: 9 | for n in range(args.num_seq): 10 | sequence = "" 11 | word_index = 0 12 | for l in range(args.seq_length): 13 | word_index = int(np.argwhere(np.random.multinomial(1, prob[word_index]) == 1)) 14 | sequence += str(words[word_index]) 15 | print(sequence) 16 | f.write(sequence + "\n") 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("-l", "--seq-length", type=int, default=20, help="1つの文の長さ.") 21 | parser.add_argument("-n", "--num-seq", type=int, default=20, help="生成する文の個数.") 22 | main(parser.parse_args()) -------------------------------------------------------------------------------- /test/module_tests/hash.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "../../src/npylm/common.h" 8 | #include "../../src/npylm/hash.h" 9 | using std::cout; 10 | using std::flush; 11 | using std::endl; 12 | 13 | void test_hash_substring(){ 14 | std::wstring sentence_str = L"本論文 では, 教師 データ や 辞書 を 必要 とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する. 観測された文字列を, 文字 nグラム-単語 nグラムをノンパラメトリックベイズ法の枠組で統合した確率モデルからの出力とみなし, MCMC 法と動的計画法を用いて, 繰り返し隠れた「単語」を推定する. 提案法は, あらゆる言語の生文字列から直接, 全く知識なしに Kneser-Ney と同等に高精度にスムージングされ, 未知語のない nグラム言語モデルを構築する方法とみなすこともできる.話し言葉や古文を含む日本語, および中国語単語分割の標準的なデータセットでの実験により, 提案法の有効性および効率性を確認した."; 15 | for(int t = 0;t < sentence_str.size();t++){ 16 | for(int k = 0;k < std::min((size_t)t, sentence_str.size());k++){ 17 | size_t hash = npylm::hash_substring(sentence_str, t - k, t); 18 | std::wstring substr(sentence_str.begin() + t - k, sentence_str.begin() + t + 1); 19 | size_t _hash = npylm::hash_wstring(substr); 20 | assert(hash == _hash); 21 | } 22 | } 23 | } 24 | 25 | int main(){ 26 | test_hash_substring(); 27 | cout << "OK" << endl; 28 | return 0; 29 | } -------------------------------------------------------------------------------- /test/module_tests/lattice.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "../../src/npylm/sampler.h" 4 | #include "../../src/python/model.h" 5 | #include "../../src/python/dataset.h" 6 | #include "../../src/python/dictionary.h" 7 | #include "../../src/python/trainer.h" 8 | using namespace npylm; 9 | using std::wcout; 10 | using std::cout; 11 | using std::flush; 12 | using std::endl; 13 | 14 | double compute_forward_probability(Lattice* lattice, Sentence* sentence, bool normalize){ 15 | assert(sentence->size() <= lattice->_max_sentence_length); 16 | int size = sentence->size() + 1; 17 | lattice->_alpha[0][0][0] = 1; 18 | lattice->_log_z[0] = 0; 19 | for(int i = 0;i < size;i++){ 20 | for(int j = 0;j < lattice->_max_word_length + 1;j++){ 21 | lattice->_substring_word_id_cache[i][j] = 0; 22 | } 23 | } 24 | for(int t = 0;t < size;t++){ 25 | lattice->_log_z[t] = 0; 26 | for(int k = 0;k < lattice->_max_word_length + 1;k++){ 27 | for(int j = 0;j < lattice->_max_word_length + 1;j++){ 28 | lattice->_alpha[t][k][j] = -1; 29 | } 30 | } 31 | } 32 | lattice->forward_filtering(sentence, normalize); 33 | double sum_probability = 0; 34 | int t = sentence->size(); 35 | for(int k = 1;k <= std::min(t, lattice->_max_word_length);k++){ 36 | for(int j = 1;j <= std::min(t - k, lattice->_max_word_length);j++){ 
37 | if(normalize){ 38 | sum_probability += lattice->_alpha[t][k][j] * exp(lattice->_log_z[t]); 39 | }else{ 40 | sum_probability += lattice->_alpha[t][k][j]; 41 | } 42 | } 43 | } 44 | return sum_probability; 45 | } 46 | void test_compute_forward_probability(){ 47 | std::string filename = "../../dataset/test.txt"; 48 | Corpus* corpus = new Corpus(); 49 | corpus->add_textfile(filename); 50 | int seed = 0; 51 | Dataset* dataset = new Dataset(corpus, 1, seed); 52 | int max_word_length = 8; 53 | Model* model = new Model(dataset, max_word_length); 54 | Trainer* trainer = new Trainer(dataset, model, false); 55 | Lattice* lattice = model->_lattice; 56 | 57 | for(int epoch = 0;epoch < 20;epoch++){ 58 | trainer->gibbs(); 59 | for(Sentence* sentence: dataset->_sentence_sequences_train){ 60 | double prob_n = compute_forward_probability(lattice, sentence, true); 61 | double prob_u = compute_forward_probability(lattice, sentence, false); 62 | assert(std::abs(prob_n - prob_u) < 1e-16); 63 | } 64 | } 65 | } 66 | 67 | int main(int argc, char *argv[]){ 68 | test_compute_forward_probability(); 69 | cout << "OK" << endl; 70 | } -------------------------------------------------------------------------------- /test/module_tests/npylm.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "../../src/npylm/common.h" 8 | #include "../../src/npylm/lm/vpylm.h" 9 | #include "../../src/npylm/npylm.h" 10 | #include "../../src/npylm/sentence.h" 11 | #include "../../src/npylm/hash.h" 12 | using namespace npylm; 13 | using namespace npylm::lm; 14 | using std::cout; 15 | using std::flush; 16 | using std::endl; 17 | 18 | double compute_p_w_given_h(NPYLM* npylm, Sentence* sentence, int word_t_index){ 19 | int substr_char_t_start = sentence->_start[word_t_index]; 20 | int substr_char_t_end = sentence->_start[word_t_index] + sentence->_segments[word_t_index] - 1; 21 | int word_ids_length = sentence->get_num_segments(); 22 | int character_ids_length = sentence->size(); 23 | assert(word_t_index < word_ids_length); 24 | assert(substr_char_t_start >= 0); 25 | id word_id = sentence->_word_ids[word_t_index]; 26 | if(word_t_index == word_ids_length - 1){ 27 | assert(word_id == ID_EOS); 28 | }else{ 29 | if(word_id != ID_EOS){ 30 | assert(substr_char_t_end < character_ids_length); 31 | #ifdef __DEBUG__ 32 | id a = hash_substring_ptr(sentence->_characters, substr_char_t_start, substr_char_t_end); 33 | assert(a == word_id); 34 | #endif 35 | } 36 | } 37 | Node* node = npylm->find_node_by_tracing_back_context_from_time_t(sentence->_word_ids, word_ids_length, word_t_index, false, true); 38 | assert(node != NULL); 39 | double g0 = npylm->compute_g0_substring_at_time_t(sentence->_characters, sentence->size(), substr_char_t_start, substr_char_t_end, word_id); 40 | return node->compute_p_w(word_id, g0, npylm->_hpylm->_d_m, npylm->_hpylm->_theta_m); 41 | } 42 | 43 | void add_costmers(NPYLM* npylm, Sentence* sentence, int* segments_without_special_tokens){ 44 | for(int repeat = 0;repeat < 5;repeat++){ 45 | for(int i = 1;i <= sentence->_sentence_str.size() - 3;i++){ 46 | for(int m = 1;m <= sentence->_sentence_str.size() - i - 2;m++){ 47 | for(int k = 1;k <= sentence->_sentence_str.size() - i - m - 1;k++){ 48 | int n = sentence->_sentence_str.size() - i - m - k; 49 | assert(i + m + k + n == sentence->_sentence_str.size()); 50 | segments_without_special_tokens[0] = i; 51 | segments_without_special_tokens[1] = m; 52 | 
segments_without_special_tokens[2] = k; 53 | segments_without_special_tokens[3] = n; 54 | sentence->split(segments_without_special_tokens, 4); 55 | for(int t = 2;t < sentence->get_num_segments();t++){ 56 | npylm->add_customer_at_time_t(sentence, t); 57 | } 58 | } 59 | } 60 | } 61 | } 62 | } 63 | 64 | void remove_costmers(NPYLM* npylm, Sentence* sentence, int* segments_without_special_tokens){ 65 | for(int repeat = 0;repeat < 5;repeat++){ 66 | for(int i = 1;i <= sentence->_sentence_str.size() - 3;i++){ 67 | for(int m = 1;m <= sentence->_sentence_str.size() - i - 2;m++){ 68 | for(int k = 1;k <= sentence->_sentence_str.size() - i - m - 1;k++){ 69 | int n = sentence->_sentence_str.size() - i - m - k; 70 | assert(i + m + k + n == sentence->_sentence_str.size()); 71 | segments_without_special_tokens[0] = i; 72 | segments_without_special_tokens[1] = m; 73 | segments_without_special_tokens[2] = k; 74 | segments_without_special_tokens[3] = n; 75 | sentence->split(segments_without_special_tokens, 4); 76 | for(int t = 2;t < sentence->get_num_segments();t++){ 77 | npylm->remove_customer_at_time_t(sentence, t); 78 | } 79 | } 80 | } 81 | } 82 | } 83 | } 84 | 85 | void test_vpylm_add_customers(){ 86 | VPYLM* vpylm = new VPYLM(0.001, 1000, 4, 1); 87 | NPYLM* npylm = new NPYLM(20, 10000, 0.001, 4, 1, 4, 1); 88 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する."; 89 | Sentence* sentence = new Sentence(sentence_str); 90 | wchar_t* token_ids = new wchar_t[sentence_str.size() + 2]; 91 | 92 | double* parent_pw_cache = new double[sentence->size() + 2]; 93 | Node** path_nodes_cache = new Node*[sentence->size() + 2]; 94 | sampler::mt.seed(0); 95 | for(int n = 0;n < 100;n++){ 96 | for(int t = 0;t < sentence->size();t++){ 97 | wrap_bow_eow(sentence->_characters, 0, t, token_ids); 98 | for(int char_t = 0;char_t <= t + 2;char_t++){ 99 | int depth = vpylm->sample_depth_at_time_t(token_ids, char_t, parent_pw_cache, path_nodes_cache); 100 | vpylm->add_customer_at_time_t(token_ids, char_t, depth); 101 | } 102 | } 103 | } 104 | 105 | sampler::mt.seed(0); 106 | for(int n = 0;n < 100;n++){ 107 | for(int t = 0;t < sentence->size();t++){ 108 | std::vector prev_depths; 109 | npylm->vpylm_add_customers(sentence->_characters, 0, t, token_ids, prev_depths); 110 | } 111 | } 112 | assert(vpylm->get_num_nodes() == npylm->_vpylm->get_num_nodes()); 113 | assert(vpylm->get_num_customers() == npylm->_vpylm->get_num_customers()); 114 | assert(vpylm->get_num_tables() == npylm->_vpylm->get_num_tables()); 115 | assert(vpylm->get_sum_stop_counts() == npylm->_vpylm->get_sum_stop_counts()); 116 | assert(vpylm->get_sum_pass_counts() == npylm->_vpylm->get_sum_pass_counts()); 117 | 118 | delete[] parent_pw_cache; 119 | delete[] path_nodes_cache; 120 | delete[] token_ids; 121 | delete sentence; 122 | delete vpylm; 123 | delete npylm; 124 | } 125 | 126 | void test_vpylm_remove_customers(){ 127 | NPYLM* npylm = new NPYLM(20, 10000, 0.001, 4, 1, 4, 1); 128 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する."; 129 | Sentence* sentence = new Sentence(sentence_str); 130 | wchar_t* token_ids = new wchar_t[sentence_str.size() + 2]; 131 | 132 | std::vector> prev_depths_list; 133 | for(int n = 0;n < 100;n++){ 134 | for(int t = 0;t < sentence->size();t++){ 135 | std::vector prev_depths; 136 | npylm->vpylm_add_customers(sentence->_characters, 0, t, token_ids, prev_depths); 137 | prev_depths_list.push_back(prev_depths); 138 | } 139 | } 140 | 141 | auto itr = 
prev_depths_list.begin(); 142 | for(int n = 0;n < 100;n++){ 143 | for(int t = 0;t < sentence->size();t++){ 144 | std::vector &prev_depths = *itr; 145 | npylm->vpylm_remove_customers(sentence->_characters, 0, t, token_ids, prev_depths); 146 | itr++; 147 | } 148 | } 149 | 150 | assert(npylm->_vpylm->get_num_customers() == 0); 151 | assert(npylm->_vpylm->get_num_tables() == 0); 152 | assert(npylm->_vpylm->get_sum_stop_counts() == 0); 153 | assert(npylm->_vpylm->get_sum_pass_counts() == 0); 154 | 155 | delete[] token_ids; 156 | delete sentence; 157 | delete npylm; 158 | } 159 | 160 | void test_remove_customer_at_time_t(){ 161 | NPYLM* npylm = new NPYLM(20, 10000, 0.001, 4, 1, 4, 1); 162 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する."; 163 | Sentence* sentence = new Sentence(sentence_str); 164 | int* segments_without_special_tokens = new int[4]; 165 | add_costmers(npylm, sentence, segments_without_special_tokens); 166 | remove_costmers(npylm, sentence, segments_without_special_tokens); 167 | 168 | assert(npylm->_vpylm->get_num_customers() == 0); 169 | assert(npylm->_vpylm->get_num_tables() == 0); 170 | assert(npylm->_vpylm->get_sum_stop_counts() == 0); 171 | assert(npylm->_vpylm->get_sum_pass_counts() == 0); 172 | 173 | assert(npylm->_hpylm->get_num_customers() == 0); 174 | assert(npylm->_hpylm->get_num_tables() == 0); 175 | assert(npylm->_hpylm->get_sum_stop_counts() == 0); 176 | assert(npylm->_hpylm->get_sum_pass_counts() == 0); 177 | 178 | delete[] segments_without_special_tokens; 179 | delete npylm; 180 | } 181 | 182 | void test_find_node_by_tracing_back_context_from_time_t(){ 183 | NPYLM* npylm = new NPYLM(20, 10000, 0.001, 4, 1, 4, 1); 184 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する."; 185 | Sentence* sentence = new Sentence(sentence_str); 186 | int* segments_without_special_tokens = new int[4]; 187 | add_costmers(npylm, sentence, segments_without_special_tokens); 188 | double* parent_pw_cache = new double[3]; 189 | 190 | for(int t = 2;t < sentence->get_num_segments();t++){ 191 | Node* node = npylm->find_node_by_tracing_back_context_from_time_t(sentence, t, parent_pw_cache, false, false); 192 | Node* parent = node->_parent; 193 | id word_t_id = sentence->get_word_id_at(t); 194 | int substr_char_t_start = sentence->_start[t]; 195 | int substr_char_t_end = sentence->_start[t] + sentence->_segments[t] - 1; 196 | double g0 = npylm->compute_g0_substring_at_time_t(sentence->_characters, sentence->size(), substr_char_t_start, substr_char_t_end, word_t_id); 197 | double pw = parent->compute_p_w(word_t_id, g0, npylm->_hpylm->_d_m, npylm->_hpylm->_theta_m); 198 | assert(pw == parent_pw_cache[2]); 199 | } 200 | 201 | delete[] parent_pw_cache; 202 | delete npylm; 203 | } 204 | 205 | void test_compute_p_w_given_h(){ 206 | NPYLM* npylm = new NPYLM(20, 10000, 0.001, 4, 1, 4, 1); 207 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する."; 208 | Sentence* sentence = new Sentence(sentence_str); 209 | int* segments_without_special_tokens = new int[4]; 210 | add_costmers(npylm, sentence, segments_without_special_tokens); 211 | double* parent_pw_cache = new double[3]; 212 | 213 | for(int t = 2;t < sentence->get_num_segments();t++){ 214 | double p1 = compute_p_w_given_h(npylm, sentence, t); 215 | double p2 = npylm->compute_p_w_given_h(sentence, t); 216 | assert(p1 == p2); 217 | } 218 | 219 | delete[] parent_pw_cache; 220 | delete npylm; 221 | } 222 | 223 | int main(){ 224 | 
setlocale(LC_CTYPE, "ja_JP.UTF-8"); 225 | std::ios_base::sync_with_stdio(false); 226 | std::locale default_loc("ja_JP.UTF-8"); 227 | std::locale::global(default_loc); 228 | std::locale ctype_default(std::locale::classic(), default_loc, std::locale::ctype); //※ 229 | std::wcout.imbue(ctype_default); 230 | std::wcin.imbue(ctype_default); 231 | 232 | test_vpylm_add_customers(); 233 | cout << "OK" << endl; 234 | test_vpylm_remove_customers(); 235 | cout << "OK" << endl; 236 | test_remove_customer_at_time_t(); 237 | cout << "OK" << endl; 238 | test_find_node_by_tracing_back_context_from_time_t(); 239 | cout << "OK" << endl; 240 | test_compute_p_w_given_h(); 241 | cout << "OK" << endl; 242 | return 0; 243 | } -------------------------------------------------------------------------------- /test/module_tests/sentence.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "../../src/npylm/common.h" 9 | #include "../../src/npylm/hash.h" 10 | #include "../../src/npylm/sentence.h" 11 | using namespace npylm; 12 | using std::cout; 13 | using std::flush; 14 | using std::endl; 15 | 16 | void test_get_substr_word_id(){ 17 | std::wstring sentence_str = L"本論文 では, 教師 データ や 辞書 を 必要 とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する. 観測された文字列を, 文字 nグラム-単語 nグラムをノンパラメトリックベイズ法の枠組で統合した確率モデルからの出力とみなし, MCMC 法と動的計画法を用いて, 繰り返し隠れた「単語」を推定する. 提案法は, あらゆる言語の生文字列から直接, 全く知識なしに Kneser-Ney と同等に高精度にスムージングされ, 未知語のない nグラム言語モデルを構築する方法とみなすこともできる.話し言葉や古文を含む日本語, および中国語単語分割の標準的なデータセットでの実験により, 提案法の有効性および効率性を確認した."; 18 | Sentence* sentence = new Sentence(sentence_str); 19 | for(int t = 0;t < sentence_str.size();t++){ 20 | for(int k = 0;k < std::min((size_t)t, sentence_str.size());k++){ 21 | size_t hash = hash_substring(sentence_str, t - k, t); 22 | std::wstring substr(sentence_str.begin() + t - k, sentence_str.begin() + t + 1); 23 | size_t _hash = hash_wstring(substr); 24 | size_t __hash = sentence->get_substr_word_id(t - k, t); 25 | assert(hash == _hash && _hash == __hash); 26 | } 27 | } 28 | delete sentence; 29 | } 30 | 31 | void test_get_substr_word_str(){ 32 | std::wstring sentence_str = L"本論文 では, 教師 データ や 辞書 を 必要 とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する. 観測された文字列を, 文字 nグラム-単語 nグラムをノンパラメトリックベイズ法の枠組で統合した確率モデルからの出力とみなし, MCMC 法と動的計画法を用いて, 繰り返し隠れた「単語」を推定する. 
提案法は, あらゆる言語の生文字列から直接, 全く知識なしに Kneser-Ney と同等に高精度にスムージングされ, 未知語のない nグラム言語モデルを構築する方法とみなすこともできる.話し言葉や古文を含む日本語, および中国語単語分割の標準的なデータセットでの実験により, 提案法の有効性および効率性を確認した."; 33 | Sentence* sentence = new Sentence(sentence_str); 34 | for(int t = 0;t < sentence_str.size();t++){ 35 | for(int k = 0;k < std::min((size_t)t, sentence_str.size());k++){ 36 | std::wstring substr(sentence_str.begin() + t - k, sentence_str.begin() + t + 1); 37 | std::wstring _substr = sentence->get_substr_word_str(t - k, t); 38 | assert(substr.compare(_substr) == 0); 39 | } 40 | } 41 | delete sentence; 42 | } 43 | 44 | void test_split_by_array(){ 45 | std::wstring sentence_str = L"提案法は, あらゆる言語の生文字列から直接, 全く知識なしに Kneser-Ney と同等に高精度にスムージングされ, 未知語のない nグラム言語モデルを構築する方法とみなすこともできる."; 46 | Sentence* sentence = new Sentence(sentence_str); 47 | int* segments_without_special_tokens = new int[4]; 48 | for(int i = 1;i <= sentence_str.size() - 3;i++){ 49 | for(int m = 1;m <= sentence_str.size() - i - 2;m++){ 50 | for(int k = 1;k <= sentence_str.size() - i - m - 1;k++){ 51 | int n = sentence_str.size() - i - m - k; 52 | assert(i + m + k + n == sentence_str.size()); 53 | 54 | std::wstring first(sentence_str.begin(), sentence_str.begin() + i); 55 | std::wstring second(sentence_str.begin() + i, sentence_str.begin() + i + m); 56 | std::wstring third(sentence_str.begin() + i + m, sentence_str.begin() + i + m + k); 57 | std::wstring forth(sentence_str.begin() + i + m + k, sentence_str.begin() + sentence_str.size()); 58 | // std::wcout << first << " / " << second << " / " << third << " / " << forth<< endl; 59 | segments_without_special_tokens[0] = i; 60 | segments_without_special_tokens[1] = m; 61 | segments_without_special_tokens[2] = k; 62 | segments_without_special_tokens[3] = n; 63 | sentence->split(segments_without_special_tokens, 4); 64 | 65 | std::wstring _first = sentence->get_word_str_at(2); 66 | std::wstring _second = sentence->get_word_str_at(3); 67 | std::wstring _third = sentence->get_word_str_at(4); 68 | std::wstring _forth = sentence->get_word_str_at(5); 69 | assert(first.compare(_first) == 0); 70 | assert(second.compare(_second) == 0); 71 | assert(third.compare(_third) == 0); 72 | assert(forth.compare(_forth) == 0); 73 | 74 | size_t hash_first = hash_substring(sentence_str, 0, i - 1); 75 | size_t hash_second = hash_substring(sentence_str, i, i + m - 1); 76 | size_t hash_third = hash_substring(sentence_str, i + m, i + m + k - 1); 77 | size_t hash_forth = hash_substring(sentence_str, i + m + k, i + m + k + n - 1); 78 | assert(hash_first == sentence->get_word_id_at(2)); 79 | assert(hash_second == sentence->get_word_id_at(3)); 80 | assert(hash_third == sentence->get_word_id_at(4)); 81 | assert(hash_forth == sentence->get_word_id_at(5)); 82 | } 83 | } 84 | } 85 | delete[] segments_without_special_tokens; 86 | delete sentence; 87 | } 88 | 89 | void test_split_by_vector(){ 90 | std::wstring sentence_str = L"提案法は, あらゆる言語の生文字列から直接, 全く知識なしに Kneser-Ney と同等に高精度にスムージングされ, 未知語のない nグラム言語モデルを構築する方法とみなすこともできる."; 91 | Sentence* sentence = new Sentence(sentence_str); 92 | std::vector segments_without_special_tokens{0, 0, 0, 0}; 93 | for(int i = 1;i <= sentence_str.size() - 3;i++){ 94 | for(int m = 1;m <= sentence_str.size() - i - 2;m++){ 95 | for(int k = 1;k <= sentence_str.size() - i - m - 1;k++){ 96 | int n = sentence_str.size() - i - m - k; 97 | assert(i + m + k + n == sentence_str.size()); 98 | 99 | std::wstring first(sentence_str.begin(), sentence_str.begin() + i); 100 | std::wstring second(sentence_str.begin() + i, sentence_str.begin() + i + m); 101 
| std::wstring third(sentence_str.begin() + i + m, sentence_str.begin() + i + m + k); 102 | std::wstring forth(sentence_str.begin() + i + m + k, sentence_str.begin() + sentence_str.size()); 103 | // std::wcout << first << " / " << second << " / " << third << " / " << forth<< endl; 104 | segments_without_special_tokens[0] = i; 105 | segments_without_special_tokens[1] = m; 106 | segments_without_special_tokens[2] = k; 107 | segments_without_special_tokens[3] = n; 108 | sentence->split(segments_without_special_tokens); 109 | 110 | std::wstring _first = sentence->get_word_str_at(2); 111 | std::wstring _second = sentence->get_word_str_at(3); 112 | std::wstring _third = sentence->get_word_str_at(4); 113 | std::wstring _forth = sentence->get_word_str_at(5); 114 | assert(first.compare(_first) == 0); 115 | assert(second.compare(_second) == 0); 116 | assert(third.compare(_third) == 0); 117 | assert(forth.compare(_forth) == 0); 118 | 119 | size_t hash_first = hash_substring(sentence_str, 0, i - 1); 120 | size_t hash_second = hash_substring(sentence_str, i, i + m - 1); 121 | size_t hash_third = hash_substring(sentence_str, i + m, i + m + k - 1); 122 | size_t hash_forth = hash_substring(sentence_str, i + m + k, i + m + k + n - 1); 123 | assert(hash_first == sentence->get_word_id_at(2)); 124 | assert(hash_second == sentence->get_word_id_at(3)); 125 | assert(hash_third == sentence->get_word_id_at(4)); 126 | assert(hash_forth == sentence->get_word_id_at(5)); 127 | } 128 | } 129 | } 130 | delete sentence; 131 | } 132 | 133 | int main(){ 134 | setlocale(LC_CTYPE, "ja_JP.UTF-8"); 135 | std::ios_base::sync_with_stdio(false); 136 | std::locale default_loc("ja_JP.UTF-8"); 137 | std::locale::global(default_loc); 138 | std::locale ctype_default(std::locale::classic(), default_loc, std::locale::ctype); //※ 139 | std::wcout.imbue(ctype_default); 140 | std::wcin.imbue(ctype_default); 141 | 142 | test_get_substr_word_id(); 143 | cout << "OK" << endl; 144 | test_get_substr_word_str(); 145 | cout << "OK" << endl; 146 | test_split_by_array(); 147 | cout << "OK" << endl; 148 | test_split_by_vector(); 149 | cout << "OK" << endl; 150 | return 0; 151 | } -------------------------------------------------------------------------------- /test/module_tests/vpylm.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "../../src/npylm/common.h" 8 | #include "../../src/npylm/lm/vpylm.h" 9 | #include "../../src/npylm/npylm.h" 10 | #include "../../src/npylm/sentence.h" 11 | using namespace npylm; 12 | using namespace npylm::lm; 13 | using std::cout; 14 | using std::flush; 15 | using std::endl; 16 | 17 | int sample_depth_at_time_t(VPYLM* vpylm, wchar_t const* token_ids, int t){ 18 | if(t == 0){ 19 | return 0; 20 | } 21 | wchar_t token_t = token_ids[t]; 22 | 23 | // この値を下回れば打ち切り 24 | double eps = VPYLM_EPS; 25 | 26 | double sum = 0; 27 | double p_pass = 0; 28 | double parent_pw = vpylm->_g0; 29 | int sampling_table_size = 0; 30 | Node* node = vpylm->_root; 31 | for(int n = 0;n <= t;n++){ 32 | if(node){ 33 | double pw = node->compute_p_w(token_t, vpylm->_g0, vpylm->_d_m, vpylm->_theta_m); 34 | double p_stop = node->stop_probability(vpylm->_beta_stop, vpylm->_beta_pass); 35 | p_pass = node->pass_probability(vpylm->_beta_stop, vpylm->_beta_pass); 36 | double p = pw * p_stop; 37 | parent_pw = pw; 38 | vpylm->_sampling_table[n] = p; 39 | sampling_table_size += 1; 40 | sum += p; 41 | if(p_stop < eps){ 42 | break; 43 | } 
44 | if(n < t){
45 | wchar_t context_token_id = token_ids[t - n - 1];
46 | node = node->find_child_node(context_token_id);
47 | }
48 | }else{
49 | double p_stop = p_pass * vpylm->_beta_stop / (vpylm->_beta_stop + vpylm->_beta_pass);
50 | double p = parent_pw * p_stop;
51 | // probs.push_back(p);
52 | vpylm->_sampling_table[n] = p;
53 | sampling_table_size += 1;
54 | sum += p;
55 | p_pass *= vpylm->_beta_pass / (vpylm->_beta_stop + vpylm->_beta_pass);
56 | if(p_stop < eps){
57 | break;
58 | }
59 | }
60 | }
61 | // assert(sampling_table_size == t + 1);
62 | double normalizer = 1.0 / sum;
63 | double bernoulli = sampler::uniform(0, 1);
64 | double stack = 0;
65 | for(int n = 0;n < sampling_table_size;n++){
66 | stack += vpylm->_sampling_table[n] * normalizer;
67 | if(bernoulli < stack){
68 | return n;
69 | }
70 | }
71 | return sampling_table_size - 1;
72 | }
73 | 
74 | bool add_customer_at_time_t(VPYLM* vpylm, wchar_t const* token_ids, int t, int depth_t){
75 | assert(0 <= depth_t && depth_t <= t);
76 | Node* node = vpylm->find_node_by_tracing_back_context(token_ids, t, depth_t, true, false);
77 | assert(node != NULL);
78 | assert(node->_depth == depth_t);
79 | if(depth_t > 0){ // the root node is a special case, so skip this check
80 | assert(node->_token_id == token_ids[t - depth_t]);
81 | }
82 | id token_t = token_ids[t];
83 | int table_k;
84 | return node->add_customer(token_t, vpylm->_g0, vpylm->_d_m, vpylm->_theta_m, true, table_k);
85 | }
86 | 
87 | double compute_p_w_given_h(VPYLM* vpylm, wchar_t const* token_ids, int context_start, int context_end){
88 | Node* node = vpylm->_root;
89 | wchar_t target_id = token_ids[context_end + 1];
90 | assert(node != NULL);
91 | double parent_pass_probability = 1;
92 | double p = 0;
93 | double eps = VPYLM_EPS; // stop once the stop probability falls below this value
94 | double parent_pw = vpylm->_g0;
95 | double p_stop = 1;
96 | int depth = 0;
97 | while(p_stop > eps){
98 | // if the node does not exist, compute from the parent's probability and the beta prior
99 | if(node == NULL){
100 | p_stop = (vpylm->_beta_stop) / (vpylm->_beta_pass + vpylm->_beta_stop) * parent_pass_probability;
101 | p += parent_pw * p_stop;
102 | parent_pass_probability *= (vpylm->_beta_pass) / (vpylm->_beta_pass + vpylm->_beta_stop);
103 | }else{
104 | assert(context_end - depth + 1 >= 0);
105 | assert(node->_depth == depth);
106 | wchar_t context_token_id = token_ids[context_end - depth];
107 | double pw = node->compute_p_w(target_id, vpylm->_g0, vpylm->_d_m, vpylm->_theta_m);
108 | p_stop = node->stop_probability(vpylm->_beta_stop, vpylm->_beta_pass);
109 | p += pw * p_stop;
110 | Node* child = node->find_child_node(context_token_id);
111 | parent_pass_probability = node->pass_probability(vpylm->_beta_stop, vpylm->_beta_pass);
112 | parent_pw = pw;
113 | node = child;
114 | if(depth > 0 && node){
115 | assert(node->_token_id == context_token_id);
116 | }
117 | }
118 | depth++;
119 | }
120 | assert(p > 0);
121 | return p;
122 | }
123 | void test_compute_p_w_given_h(){
124 | VPYLM* vpylm = new VPYLM(0.001, 1000, 4, 1);
125 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する.";
126 | Sentence* sentence = new Sentence(sentence_str);
127 | wchar_t* wrapped_character_ids = new wchar_t[sentence_str.size() + 2];
128 | wrap_bow_eow(sentence->_characters, 0, sentence->size() - 1, wrapped_character_ids);
129 | for(int t = 0;t < sentence->size();t++){
130 | for(int depth_t = 0;depth_t <= t;depth_t++){
131 | vpylm->add_customer_at_time_t(wrapped_character_ids, t, depth_t);
132 | }
133 | }
134 | for(int end = 0;end < sentence->size() - 1;end++){
135 | 
for(int start = 0;start < end;start++){ 136 | double a = vpylm->compute_p_w_given_h(wrapped_character_ids, start, end); 137 | double b = compute_p_w_given_h(vpylm, wrapped_character_ids, start, end); 138 | assert(std::abs(a - b) < 1e-16); 139 | } 140 | } 141 | delete sentence; 142 | delete vpylm; 143 | delete[] wrapped_character_ids; 144 | } 145 | 146 | void test_find_node_by_tracing_back_context(){ 147 | VPYLM* vpylm = new VPYLM(0.001, 1000, 4, 1); 148 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する."; 149 | Sentence* sentence = new Sentence(sentence_str); 150 | wchar_t* token_ids = new wchar_t[sentence->size() + 2]; 151 | wrap_bow_eow(sentence->_characters, 0, sentence->size() - 1, token_ids); 152 | for(int t = 0;t < sentence->size() + 2;t++){ 153 | for(int depth_t = 0;depth_t <= t;depth_t++){ 154 | vpylm->add_customer_at_time_t(token_ids, t, depth_t); 155 | } 156 | } 157 | 158 | double* parent_pw_cache = new double[sentence->size() + 2]; 159 | for(int t = 1;t < sentence->size() + 2;t++){ 160 | for(int depth_t = 1;depth_t <= t;depth_t++){ 161 | Node* node0 = vpylm->find_node_by_tracing_back_context(token_ids, t, depth_t - 1); 162 | assert(node0 != NULL); 163 | if(depth_t > 1){ 164 | assert(node0->_token_id == token_ids[t - depth_t + 1]); 165 | } 166 | Node* node1 = vpylm->find_node_by_tracing_back_context(token_ids, t, depth_t, parent_pw_cache); 167 | assert(node1->_token_id == token_ids[t - depth_t]); 168 | 169 | double p = node0->compute_p_w(token_ids[t], vpylm->_g0, vpylm->_d_m, vpylm->_theta_m); 170 | assert(parent_pw_cache[depth_t] == p); 171 | } 172 | } 173 | 174 | Node** path_nodes_cache = new Node*[sentence->size() + 2]; 175 | for(int t = 1;t < sentence->size() + 2;t++){ 176 | vpylm->sample_depth_at_time_t(token_ids, t, parent_pw_cache, path_nodes_cache); 177 | 178 | for(int depth_t = 1;depth_t <= t;depth_t++){ 179 | Node* node0 = vpylm->find_node_by_tracing_back_context(token_ids, t, depth_t - 1); 180 | assert(node0 != NULL); 181 | if(depth_t > 1){ 182 | assert(node0->_token_id == token_ids[t - depth_t + 1]); 183 | } 184 | double p = node0->compute_p_w(token_ids[t], vpylm->_g0, vpylm->_d_m, vpylm->_theta_m); 185 | assert(parent_pw_cache[depth_t] == p); 186 | assert(node0 == path_nodes_cache[depth_t - 1]); 187 | } 188 | 189 | for(int depth_t = 1;depth_t <= t;depth_t++){ 190 | path_nodes_cache[t - depth_t] = NULL; 191 | Node* node0 = vpylm->find_node_by_tracing_back_context(token_ids, t, depth_t - 1); 192 | Node* node1 = vpylm->find_node_by_tracing_back_context(token_ids, t, depth_t - 1, path_nodes_cache); 193 | assert(node0 != NULL); 194 | assert(node1 != NULL); 195 | assert(node0 == node1); 196 | } 197 | } 198 | 199 | delete sentence; 200 | delete vpylm; 201 | delete[] token_ids; 202 | delete[] path_nodes_cache; 203 | delete[] parent_pw_cache; 204 | } 205 | 206 | void test_add_customer(){ 207 | sampler::mt.seed(0); 208 | VPYLM* vpylm1 = new VPYLM(0.001, 1000, 4, 1); 209 | VPYLM* vpylm2 = new VPYLM(0.001, 1000, 4, 1); 210 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する."; 211 | Sentence* sentence = new Sentence(sentence_str); 212 | wchar_t* token_ids = new wchar_t[sentence->size() + 2]; 213 | wrap_bow_eow(sentence->_characters, 0, sentence->size() - 1, token_ids); 214 | for(int t = 0;t < sentence->size();t++){ 215 | for(int depth_t = 0;depth_t <= t;depth_t++){ 216 | vpylm1->add_customer_at_time_t(token_ids, t, depth_t); 217 | } 218 | } 219 | sampler::mt.seed(0); 220 | for(int t = 0;t < 
sentence->size();t++){ 221 | for(int depth_t = 0;depth_t <= t;depth_t++){ 222 | add_customer_at_time_t(vpylm2, token_ids, t, depth_t); 223 | } 224 | } 225 | assert(vpylm1->get_num_nodes() == vpylm2->get_num_nodes()); 226 | assert(vpylm1->get_num_customers() == vpylm2->get_num_customers()); 227 | assert(vpylm1->get_num_tables() == vpylm2->get_num_tables()); 228 | assert(vpylm1->get_sum_stop_counts() == vpylm2->get_sum_stop_counts()); 229 | assert(vpylm1->get_sum_pass_counts() == vpylm2->get_sum_pass_counts()); 230 | for(int end = 1;end < sentence->size() + 2;end++){ 231 | for(int start = 0;start < end;start++){ 232 | double a = vpylm1->compute_p_w(token_ids, end - start + 1); 233 | double b = vpylm2->compute_p_w(token_ids, end - start + 1); 234 | assert(a == b); 235 | } 236 | } 237 | delete sentence; 238 | delete vpylm1; 239 | delete vpylm2; 240 | delete[] token_ids; 241 | } 242 | 243 | void test_remove_customer(){ 244 | VPYLM* vpylm = new VPYLM(0.001, 1000, 4, 1); 245 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する."; 246 | Sentence* sentence = new Sentence(sentence_str); 247 | wchar_t* token_ids = new wchar_t[sentence->size() + 2]; 248 | wrap_bow_eow(sentence->_characters, 0, sentence->size() - 1, token_ids); 249 | for(int n = 0;n < 100;n++){ 250 | for(int t = 0;t < sentence->size();t++){ 251 | for(int depth_t = 0;depth_t <= t;depth_t++){ 252 | vpylm->add_customer_at_time_t(token_ids, t, depth_t); 253 | } 254 | } 255 | } 256 | for(int n = 0;n < 100;n++){ 257 | for(int t = 0;t < sentence->size();t++){ 258 | for(int depth_t = 0;depth_t <= t;depth_t++){ 259 | vpylm->remove_customer_at_time_t(token_ids, t, depth_t); 260 | } 261 | } 262 | } 263 | 264 | assert(vpylm->get_num_customers() == 0); 265 | assert(vpylm->get_num_tables() == 0); 266 | assert(vpylm->get_sum_stop_counts() == 0); 267 | assert(vpylm->get_sum_pass_counts() == 0); 268 | 269 | delete sentence; 270 | delete vpylm; 271 | delete[] token_ids; 272 | } 273 | 274 | void test_sample_depth_at_timestep(){ 275 | sampler::mt.seed(0); 276 | VPYLM* vpylm = new VPYLM(0.001, 1000, 4, 1); 277 | std::wstring sentence_str = L"本論文では, 教師データや辞書を必要とせず, あらゆる言語に適用できる教師なし形態素解析器および言語モデルを提案する."; 278 | Sentence* sentence = new Sentence(sentence_str); 279 | wchar_t* token_ids = new wchar_t[sentence->size() + 2]; 280 | wrap_bow_eow(sentence->_characters, 0, sentence->size() - 1, token_ids); 281 | for(int t = 0;t < sentence->size();t++){ 282 | for(int depth_t = 0;depth_t <= t;depth_t++){ 283 | vpylm->add_customer_at_time_t(token_ids, t, depth_t); 284 | } 285 | } 286 | for(int t = 0;t < sentence->size();t++){ 287 | for(int seed = 0;seed < 256;seed++){ 288 | sampler::mt.seed(seed); 289 | int a = vpylm->sample_depth_at_time_t(token_ids, t, vpylm->_parent_pw_cache, vpylm->_path_nodes); 290 | sampler::mt.seed(seed); 291 | int b = sample_depth_at_time_t(vpylm, token_ids, t); 292 | assert(a == b); 293 | } 294 | } 295 | delete sentence; 296 | delete vpylm; 297 | delete[] token_ids; 298 | } 299 | 300 | int main(){ 301 | setlocale(LC_CTYPE, "ja_JP.UTF-8"); 302 | std::ios_base::sync_with_stdio(false); 303 | std::locale default_loc("ja_JP.UTF-8"); 304 | std::locale::global(default_loc); 305 | std::locale ctype_default(std::locale::classic(), default_loc, std::locale::ctype); //※ 306 | std::wcout.imbue(ctype_default); 307 | std::wcin.imbue(ctype_default); 308 | 309 | test_compute_p_w_given_h(); 310 | cout << "OK" << endl; 311 | test_find_node_by_tracing_back_context(); 312 | cout << "OK" << endl; 313 | test_add_customer(); 314 | 
cout << "OK" << endl; 315 | test_remove_customer(); 316 | cout << "OK" << endl; 317 | test_sample_depth_at_timestep(); 318 | cout << "OK" << endl; 319 | return 0; 320 | } -------------------------------------------------------------------------------- /test/module_tests/wordtype.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "../../src/npylm/ctype.h" 8 | #include "../../src/npylm/wordtype.h" 9 | using std::cout; 10 | using std::flush; 11 | using std::endl; 12 | 13 | void test_wordtype(){ 14 | std::wstring sentence_str = L"本論文では,100教師データや!?dictionaryを必要とせず,"; 15 | int type; 16 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 0, 0); 17 | assert(type == WORDTYPE_KANJI); 18 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 0, 1); 19 | assert(type == WORDTYPE_KANJI); 20 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 0, 2); 21 | assert(type == WORDTYPE_KANJI); 22 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 3, 3); 23 | assert(type == WORDTYPE_HIRAGANA); 24 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 3, 4); 25 | assert(type == WORDTYPE_HIRAGANA); 26 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 0, 3); 27 | assert(type == WORDTYPE_KANJI_HIRAGANA); 28 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 0, 4); 29 | assert(type == WORDTYPE_KANJI_HIRAGANA); 30 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 0, 5); 31 | assert(type == WORDTYPE_OTHER); 32 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 0, 6); 33 | assert(type == WORDTYPE_OTHER); 34 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 0, 7); 35 | assert(type == WORDTYPE_OTHER); 36 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 6, 6); 37 | assert(type == WORDTYPE_NUMBER); 38 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 6, 7); 39 | assert(type == WORDTYPE_NUMBER); 40 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 6, 8); 41 | assert(type == WORDTYPE_NUMBER); 42 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 6, 9); 43 | assert(type == WORDTYPE_OTHER); 44 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 9, 11); 45 | assert(type == WORDTYPE_KANJI_KATAKANA); 46 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 9, 12); 47 | assert(type == WORDTYPE_KANJI_KATAKANA); 48 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 11, 11); 49 | assert(type == WORDTYPE_KATAKANA); 50 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 11, 12); 51 | assert(type == WORDTYPE_KATAKANA); 52 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 13, 14); 53 | assert(type == WORDTYPE_OTHER); 54 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 17, 18); 55 | assert(type == WORDTYPE_ALPHABET); 56 | type = npylm::wordtype::detect_word_type_substr(sentence_str.data(), 26, 27); 57 | assert(type == WORDTYPE_OTHER); 58 | } 59 | 60 | int main(){ 61 | test_wordtype(); 62 | cout << "OK" << endl; 63 | return 0; 64 | } -------------------------------------------------------------------------------- /test/running_tests/save.cpp: -------------------------------------------------------------------------------- 1 | 
#include 
2 | #include 
3 | #include "../../src/npylm/sampler.h"
4 | #include "../../src/python/model.h"
5 | #include "../../src/python/dataset.h"
6 | #include "../../src/python/dictionary.h"
7 | #include "../../src/python/trainer.h"
8 | using namespace npylm;
9 | using std::cout;
10 | using std::flush;
11 | using std::endl;
12 | 
13 | template <typename T>
14 | void compare_node(lm::Node<T>* a, lm::Node<T>* b){
15 | assert(a->_num_tables == b->_num_tables);
16 | assert(a->_num_customers == b->_num_customers);
17 | assert(a->_stop_count == b->_stop_count);
18 | assert(a->_pass_count == b->_pass_count);
19 | assert(a->_depth == b->_depth);
20 | assert(a->_token_id == b->_token_id);
21 | assert(a->_arrangement.size() == b->_arrangement.size());
22 | for(auto elem: a->_arrangement){
23 | T key = elem.first;
24 | auto &table_a = elem.second;
25 | auto &table_b = b->_arrangement[key];
26 | assert(table_a.size() == table_b.size());
27 | }
28 | for(auto elem: a->_children){
29 | T key = elem.first;
30 | lm::Node<T>* children_a = elem.second;
31 | lm::Node<T>* children_b = b->_children[key];
32 | compare_node(children_a, children_b);
33 | }
34 | }
35 | 
36 | void compare_npylm(NPYLM* a, NPYLM* b){
37 | assert(a != NULL);
38 | assert(b != NULL);
39 | compare_node(a->_hpylm->_root, b->_hpylm->_root);
40 | compare_node(a->_vpylm->_root, b->_vpylm->_root);
41 | }
42 | 
43 | int main(int argc, char *argv[]){
44 | std::string filename = "../../dataset/test.txt";
45 | Corpus* corpus = new Corpus();
46 | corpus->add_textfile(filename);
47 | int seed = 0;
48 | Dataset* dataset = new Dataset(corpus, 1, seed);
49 | int max_word_length = 8;
50 | Model* model = new Model(dataset, max_word_length);
51 | Dictionary* dictionary = dataset->_dict;
52 | dictionary->save("npylm.dict");
53 | Trainer* trainer = new Trainer(dataset, model, false);
54 | 
55 | for(int epoch = 0;epoch < 1000;epoch++){
56 | cout << "\r" << epoch << flush;
57 | trainer->gibbs();
58 | trainer->sample_hpylm_vpylm_hyperparameters();
59 | trainer->sample_lambda();
60 | model->save("npylm.model");
61 | Model* _model = new Model("npylm.model");
62 | compare_npylm(model->_npylm, _model->_npylm);
63 | delete _model;
64 | }
65 | }
--------------------------------------------------------------------------------
/test/running_tests/train.cpp:
--------------------------------------------------------------------------------
1 | #include 
2 | #include 
3 | #include "../../src/npylm/sampler.h"
4 | #include "../../src/python/model.h"
5 | #include "../../src/python/dataset.h"
6 | #include "../../src/python/dictionary.h"
7 | #include "../../src/python/trainer.h"
8 | using namespace npylm;
9 | using std::cout;
10 | using std::flush;
11 | using std::endl;
12 | 
13 | void run_training_loop(){
14 | std::string filename = "../../dataset/test.txt";
15 | Corpus* corpus = new Corpus();
16 | corpus->add_textfile(filename);
17 | int seed = 0;
18 | Dataset* dataset = new Dataset(corpus, 0.95, seed);
19 | int max_word_length = 12;
20 | Model* model = new Model(dataset, max_word_length);
21 | Dictionary* dictionary = dataset->_dict;
22 | dictionary->save("npylm.dict");
23 | Trainer* trainer = new Trainer(dataset, model, true);
24 | 
25 | for(int epoch = 1;epoch <= 200;epoch++){
26 | auto start_time = std::chrono::system_clock::now();
27 | trainer->gibbs();
28 | auto diff = std::chrono::system_clock::now() - start_time;
29 | cout << (std::chrono::duration_cast<std::chrono::milliseconds>(diff).count() / 1000.0) << endl;
30 | trainer->sample_hpylm_vpylm_hyperparameters();
31 | trainer->sample_lambda();
32 | if(epoch > 3){
33 | trainer->update_p_k_given_vpylm(); 34 | } 35 | if(epoch % 10 == 0){ 36 | trainer->print_segmentation_train(10); 37 | cout << "ppl: " << trainer->compute_perplexity_train() << endl; 38 | trainer->print_segmentation_dev(10); 39 | cout << "ppl: " << trainer->compute_perplexity_dev() << endl; 40 | cout << "log_likelihood: " << trainer->compute_log_likelihood_train() << endl; 41 | cout << "log_likelihood: " << trainer->compute_log_likelihood_dev() << endl; 42 | } 43 | } 44 | } 45 | 46 | int main(int argc, char *argv[]){ 47 | for(int i = 0;i < 10;i++){ 48 | run_training_loop(); 49 | } 50 | } --------------------------------------------------------------------------------