├── LICENSE ├── README.md ├── autoenc_mnist.json ├── cnn_mnist.json ├── gan_mnist.json ├── index.html ├── layers.md ├── nnboard_digest.gif └── server.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 AL, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nnboard 2 | a GUI editor for neural network (especially for chainer) 3 | 4 | chainer 向けの GUI エディタです。 5 | 6 | ![](/nnboard_digest.gif) 7 | 8 | ## 動作環境 9 | * python 3.5.2 以降 (おそらく 3.5 以降なら動きます) 10 | * chainer 2.0.2 以降 (おそらく 1.19 以降なら動きます) 11 | 12 | ## 使い方 13 | ### ダウンロード 14 | ``` 15 | git clone https://github.com/al4tech/nnboard.git 16 | cd nnboard 17 | ``` 18 | 19 | ### 実行 20 | ``` 21 | python server.py 22 | ``` 23 | `server.py` を実行すると、自動的に `index.html` が開きます。このページ上で、ニューラルネットワークをさくさく設計することができます。 24 | 25 | 初めての方は、まず一番上の `Load Canvas` ボタンから、サンプルファイル(このディレクトリにある .json ファイルは全てサンプルファイルです)を適当にひとつ読み込んでみて、 青字の `Start Learning` ボタンを押してみてください。学習が始まります。別のサンプルを試すときは、先に青字の `Quit Learning` ボタンを押して、学習を停止してから、 `Load Canvas` ボタンで別のサンプルファイルを読み込み、それから `Start Learning` ボタンを押してください。 26 | 27 | ### 終了 28 | 29 | `server.py` を Ctrl-C で終了してください。 30 | 31 | なお、 `index.html` の最下部にある `Shutdown Server` ボタンを押しても、 `server.py` が終了します。 32 | 33 | ### 作業状態の保存と読み込み 34 | 35 | * `index.html` の一番上にある `Download Canvas` リンクを押すと、json ファイルがローカルに保存されます。この中に、構築したネットワークの情報が入っています。 36 | 37 | * `Load Canvas` ボタンを押して json ファイルを選択すると、構築したネットワークが復元されます。 38 | 39 | * 構築したネットワークとともに、optimizerの設定なども保存されます。 40 | 41 | * 学習結果は保存されません・・・ 42 | 43 | 44 | 46 | 47 | 48 | ## 詳しい使い方 49 | ### ネットワーク設計 50 | 51 | * 層や結合を編集するときは、canvas にフォーカスが当たった状態にします(canvas内をどこかクリックすれば良いです)。 52 | 53 | * `a` キーを押すと、層が作られます(`add`)。作った層は、クリックで選択でき、ドラッグで移動できます。 54 | 55 | * ある層(a)を選択中に、`Shift` キーを押しながら別の層(b)をクリックすると、(a)から(b)に結合が生じます。結合もクリックで選択できます。 56 | 57 | * `Del` キーを押すと、選択中の層や結合を削除できます。層を削除すると、層にくっついている結合も一緒に削除されます。 58 | 59 | * ある単一の層を選択中に、様々な英字キーを押すことで、層のタイプを変更することができます。([層タイプ一覧](/layers.md)) 60 | 61 | * 対応キー:`b`(batch Normalization),`c`(convolution),`C`(Concat),`e`(experience replay),`f`(full connected),`i`(input),`m`(mean_squared_loss),`o`(other;任意の関数),`p`(pooling),`r`(random),`R`(Reshape),`s`(softmax_cross_entropy),`T`(Transpose),`+`(足し算),`*`(掛け算),`-`(マイナス) 62 | 63 | * `Options`から、オプション引数を設定できます。`o`の場合は、任意の関数を設定できます。(lambda式も指定可です。例えば `"func":"lambda x,y:F.softmax_cross_entropy(x,y)"`と書けば、タイプ`s`の層と実質的に同じになります。) 64 | 65 | * `Options` は json の書式で書く必要があります。 None は null で指定します。タプルを指定したいときは、(jsの)Array として書きます。例: `"shape":[-1,1,28,28]` 66 | 67 | * ある単一の層を選択中に、`Ctrl` キーを押しながら様々な英字キーを押すことで、層の活性化関数を変更することができます。 68 | 69 | * 対応キー:`e`(elu),`i`(id),`l`(leaky_relu),`r`(relu),`s`(sigmoid),`t`(tanh) 70 | 71 | * `Options`から、(これらに限らない)任意の活性化関数に変更できます(lambda式も指定可:例えば `"act":"lambda x:F.relu(x)"` など)。 72 | 73 | 74 | ### 学習 75 | 76 | * loss と optimizer を設定する必要があります。 77 | 78 | * 単一の層を選択中に数字キー(`0`-`9`)を押すと、層に「タグ」をつけることができます(層の中に `#0` などと表示されます)。 79 | 80 | * 一行目には `optimizee: #0, loss:#4` と書いてあります。これに従って、最適化したい loss の層に `#4` タグを指定します。その loss から計算される勾配に従って最適化したい重みをもつ層に `#0` タグを指定します。 81 | 82 | * 同時に4つまで複数のoptimizerを併用できます。 83 | 84 | * 複数のoptimizerを交互に動かしたい場合などには、`condition` の指定を行ってください。ここでは「非負整数 x を受け取り、(xイテレーション目にこのoptimizerを動かすか)を返す関数」を指定してください。 85 | 86 | * 例:optimizer 0 の condtition が `lambda x: x%6` で、optimizer 1 の condition が `lambda x: not(x%6)` の時、「0を5回」→「1を1回」→「0を5回」→・・・という動かし方になります。 87 | * `Start Learning` ボタンを押すと、学習が始まります。 88 | 89 | * 正常に学習が始まると、ボタンの表示が `Quit Learning` に変化します。`Quit Learning` ボタンを押すと、学習が終了し、ボタンの表示が `Start Learning` に戻ります。何らかのエラーが生じて学習が死んだ場合も、ボタンの表示が `Start Learning` に戻ります。 90 | 91 | * 学習中の表示の見方 92 | 93 | * 各層の右下に表示されているのは shape です。右上は、現在の層の値のプレビューです。 94 | 95 | * 各層をダブルクリックすると、現在の層の値をいつでも可視化することができます 96 | 97 | * いい感じに可視化されない場合は、任意の層を使って、いい感じに shape を整形すると良いです。 98 | 99 | * 学習中の loss の変化が折れ線グラフで表示されます(google chart API を利用しているため、インターネット接続時のみの機能です)。 100 | 101 | * このグラフは15秒ごとに自動更新されます。手動で更新したい場合は `Update graph manually` ボタンを押してください。 102 | 103 | * softmax_cross_entropy の層で loss を集計している場合に限り、 accuracy の変化も折れ線グラフで表示されます。 104 | 105 | * 学習中に任意コードの実行ができます。学習係数を途中で変えたりできます。 `Execute Code` にコードを入力し、`Execute` ボタンを押してください。 106 | 107 | * エラーが出た場合はダイアログで表示されます。 108 | 109 | * 学習中にハイパーパラメータをスライダーで調節できます。 `Tuning Slider` の欄に調節したいハイパーパラメータの変数を入力し、 `GetValue` ボタンを押すと、スライダーに現在の値がセットされ、スライダーが操作可能となります。この状態でスライダーを操作すると、動的にハイパーパラメータの値を変更できます。 110 |     111 | ### 構成 112 | 113 | * `index.html`:編集画面(GUI) 114 | 115 | * `server.py`:chainerでニューラルネットの計算を行うサーバー 116 | 117 | ### FAQ 118 | 119 | * `server.py` 起動時に `Address already in use` などと表示されて起動できない。 120 | 121 | * 以下を確かめてみてください: 122 | 123 | * `server.py` がバックグラウンドで起動したままになっている。 → `index.html` を(手動で)開いて、 `Shutdown Server` ボタンを押せば終了できます。 124 | 125 | * それ以外の何らかのプログラムが `localhost:8000` を使用している。 → 通信に使用するポート番号を変更しましょう。例えば、ポート 12345 番を使いたい場合は、サーバを `server.py -p 12345` で起動し、 `index.html` の最下部にある `Address of Server` を `http://localhost:12345` と変更してください。 126 | 127 | * テストエラー見たい 128 | 129 |    * 現状、そのような機能はありませんが、見られるようにしたいと思っています。 130 | 131 | * 学習結果を保存したい 132 | 133 | * 現状、そのような機能はありません 134 | 135 | * GPU で動かしたい 136 | 137 | * 現状、未対応です 138 | 139 | * 学習中にネットワークいじったらどうなるの 140 | 141 | * 現状、どうにもなりません。 ← 部分的に、ネットワークを学習中に動的にいじれるようにしました。ネットワークを変更してから、 `SendNetworkInfoToServer` ボタンを押すと、変更したことをサーバーの計算に反映できます。(link 層はいじれません) 142 | 143 | * `Start Learning` ボタンを押しても `Quit Learning` に変化しない 144 | 145 | * おそらく一瞬で学習が落ちてます。 146 | 147 | 148 | ## License 149 | * MIT 150 | * see LICENSE 151 | 152 | -------------------------------------------------------------------------------- /autoenc_mnist.json: -------------------------------------------------------------------------------- 1 | {"node":{"0":{"x":319,"y":54,"w":180,"h":40,"text":"input <-- mnist_train_x (id)\n","opt":{"act":"id","source":"mnist_train_x"},"color":"#d0d0d0","ltype":"i","code":"","to":[2,7],"from":[],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"1":{"x":152,"y":181,"w":180,"h":40,"text":"fc (100ch) (relu) #0\n","opt":{"act":"relu","out_channel":100},"color":"#3399ff","ltype":"f","code":"","to":[4],"from":[2],"optiflg":[true,false,false,false,false,false,false,false,false,false]},"3":{"x":150,"y":287,"w":180,"h":40,"text":"fc (784ch) (sigm) #0\n","opt":{"act":"sigm","out_channel":784},"color":"#3399ff","ltype":"f","code":"","to":[6],"from":[4],"optiflg":[true,false,false,false,false,false,false,false,false,false]},"5":{"x":335,"y":417,"w":180,"h":40,"text":"meanSquaredError (id) #4\n","opt":{"act":"id"},"color":"#33ff55","ltype":"m","code":"","to":[],"from":[6,7],"optiflg":[false,false,false,false,true,false,false,false,false,false]}},"edge":{"2":{"pre":"0","post":1},"4":{"pre":"1","post":3},"6":{"pre":"3","post":5},"7":{"pre":"0","post":"5"}},"idCnt":12,"opti":["optimizers.Adam()","optimizers.SGD(lr=0.01)","optimizers.MomentumSGD(lr=0.01)","optimizers.AdaDelta()"],"settings":{"bs":100,"aveintvl":100},"optiee":["0","1","2","3"],"loss":["4","5","6","7"],"cond":["lambda x:1","lambda x:1","lambda x:1","lambda x:1"]} -------------------------------------------------------------------------------- /cnn_mnist.json: -------------------------------------------------------------------------------- 1 | {"node":{"0":{"x":272,"y":29,"w":180,"h":40,"text":"input <-- mnist_train_x (id)\n","opt":{"act":"id","source":"mnist_train_x"},"color":"#d0d0d0","ltype":"i","code":"","to":[17],"from":[],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"1":{"x":255,"y":208,"w":180,"h":40,"text":"conv (4ch)[3]{2}|1| (relu) #0\n","opt":{"act":"relu","out_channel":4,"ksize":3,"stride":2,"pad":1},"color":"#88bbff","ltype":"c","code":"","to":[4,23],"from":[15],"optiflg":[true,false,false,false,false,false,false,false,false,false]},"3":{"x":254,"y":292,"w":180,"h":40,"text":"batchNorm (4ch) (id) #0\n","opt":{"act":"id","size":4},"color":"#33ff55","ltype":"b","code":"","to":[6],"from":[4],"optiflg":[true,false,false,false,false,false,false,false,false,false]},"5":{"x":253,"y":390,"w":180,"h":40,"text":"conv (8ch)[3]{2}|1| (relu) #0\n","opt":{"act":"relu","out_channel":8,"ksize":3,"stride":2,"pad":1},"color":"#88bbff","ltype":"c","code":"","to":[8,25],"from":[6],"optiflg":[true,false,false,false,false,false,false,false,false,false]},"7":{"x":250,"y":480,"w":180,"h":40,"text":"fc (10ch) (id) #0\n","opt":{"act":"id","out_channel":10},"color":"#3399ff","ltype":"f","code":"","to":[10],"from":[8],"optiflg":[true,false,false,false,false,false,false,false,false,false]},"9":{"x":359,"y":562,"w":180,"h":40,"text":"softmaxCrossEnt (id) #4\n","opt":{"act":"id"},"color":"#33ff55","ltype":"s","code":"","to":[],"from":[10,12],"optiflg":[false,false,false,false,true,false,false,false,false,false]},"11":{"x":541,"y":479,"w":180,"h":40,"text":"input <-- mnist_train_t (id)\n","opt":{"act":"id","source":"mnist_train_t"},"color":"#d0d0d0","ltype":"i","code":"","to":[12],"from":[],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"13":{"x":262,"y":145,"w":180,"h":40,"text":"batchNorm (1ch) (id) #0\n","opt":{"act":"id","size":1},"color":"#33ff55","ltype":"b","code":"","to":[15],"from":[18],"optiflg":[true,false,false,false,false,false,false,false,false,false]},"16":{"x":271,"y":82,"w":180,"h":40,"text":"F.reshape (id)\nshape=-1,1,28,28 ","opt":{"act":"id","func":"F.reshape","shape":[-1,1,28,28]},"color":"#f0f0f0","ltype":"o","code":"","to":[18],"from":[17],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"22":{"x":540,"y":213,"w":180,"h":40,"text":"lambda x:x[:,:3,:,:] (id)\n","opt":{"act":"id","func":"lambda x:x[:,:3,:,:]"},"color":"#f0f0f0","ltype":"o","code":"","to":[],"from":[23],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"24":{"x":525,"y":380,"w":180,"h":40,"text":"lambda x: x[:,:3,:,:] (id)\n","opt":{"act":"id","func":"lambda x: x[:,:3,:,:]"},"color":"#f0f0f0","ltype":"o","code":"","to":[],"from":[25],"optiflg":[false,false,false,false,false,false,false,false,false,false]}},"edge":{"4":{"pre":"1","post":3},"6":{"pre":"3","post":5},"8":{"pre":"5","post":7},"10":{"pre":"7","post":9},"12":{"pre":"11","post":"9"},"15":{"pre":"13","post":"1"},"17":{"pre":"0","post":16},"18":{"pre":"16","post":"13"},"23":{"pre":"1","post":22},"25":{"pre":"5","post":24}},"idCnt":26,"opti":["optimizers.Adam()","optimizers.SGD(lr=0.01)","optimizers.MomentumSGD(lr=0.01)","optimizers.AdaDelta()"],"settings":{"bs":100,"aveintvl":100},"optiee":["0","1","2","3"],"loss":["4","5","6","7"],"cond":["lambda x:1","lambda x:1","lambda x:1","lambda x:1"]} -------------------------------------------------------------------------------- /gan_mnist.json: -------------------------------------------------------------------------------- 1 | {"node":{"8":{"x":637,"y":58,"w":180,"h":40,"text":"random (50) (id)\nmu=0 sigma=1 ","opt":{"act":"id","type":"normal","mu":0,"sigma":1,"sample_shape":[50]},"color":"#d0d0ff","ltype":"r","code":"","to":[21],"from":[],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"9":{"x":531,"y":139,"w":180,"h":40,"text":"fc (100ch) (l_relu) #0\n","opt":{"act":"l_relu","out_channel":100},"color":"#3399ff","ltype":"f","code":"","to":[22],"from":[21],"optiflg":[true,false,false,false,false,false,false,false,false,false]},"10":{"x":445,"y":210,"w":180,"h":40,"text":"fc (784ch) (sigm) #0\n","opt":{"act":"sigm","out_channel":784},"color":"#3399ff","ltype":"f","code":"","to":[47],"from":[22],"optiflg":[true,false,false,false,false,false,false,false,false,false]},"11":{"x":140,"y":317,"w":180,"h":40,"text":"Concat (id)\n","opt":{"act":"id","type":"batch_dim"},"color":"#a0a0a0","ltype":"C","code":"","to":[24],"from":[48,49],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"12":{"x":119,"y":91,"w":180,"h":40,"text":"input <-- mnist_train_x (id)\n","opt":{"act":"id","source":"mnist_train_x"},"color":"#d0d0d0","ltype":"i","code":"","to":[49],"from":[],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"15":{"x":128,"y":402,"w":180,"h":40,"text":"fc (100ch) (l_relu) #1\n","opt":{"act":"l_relu","out_channel":100},"color":"#3399ff","ltype":"f","code":"","to":[25],"from":[24],"optiflg":[false,true,false,false,false,false,false,false,false,false]},"16":{"x":125,"y":489,"w":180,"h":40,"text":"fc (2ch) (id) #1\n","opt":{"act":"id","out_channel":2},"color":"#3399ff","ltype":"f","code":"","to":[26,55],"from":[25],"optiflg":[false,true,false,false,false,false,false,false,false,false]},"17":{"x":119,"y":566,"w":180,"h":40,"text":"softmaxCrossEnt (id) #4 #5\n","opt":{"act":"id"},"color":"#33ff55","ltype":"s","code":"","to":[],"from":[26,28],"optiflg":[false,false,false,false,true,true,false,false,false,false]},"18":{"x":331,"y":569,"w":180,"h":40,"text":"Concat (id)\n","opt":{"act":"id","type":"batch_dim"},"color":"#a0a0a0","ltype":"C","code":"","to":[28],"from":[27,29],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"19":{"x":372,"y":507,"w":180,"h":40,"text":"value 0 (type:int32) (id)\n","opt":{"act":"id","value":0,"type":"int32"},"color":"#f0f0f0","ltype":"v","code":"","to":[27],"from":[],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"20":{"x":565,"y":570,"w":180,"h":40,"text":"value 1 (type:int32) (id)\n","opt":{"act":"id","value":1,"type":"int32"},"color":"#f0f0f0","ltype":"v","code":"","to":[29],"from":[],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"46":{"x":318,"y":263,"w":180,"h":40,"text":"exp.replay (size=10) (id)\n","opt":{"act":"id","size":10},"color":"#33aa55","ltype":"e","code":"","to":[48],"from":[47],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"54":{"x":401,"y":432,"w":180,"h":40,"text":"F.softmax (id)\n","opt":{"act":"id","func":"F.softmax"},"color":"#f0f0f0","ltype":"o","code":"","to":[59],"from":[55],"optiflg":[false,false,false,false,false,false,false,false,false,false]},"58":{"x":635,"y":434,"w":180,"h":40,"text":"lambda x:x[:,1] (id)\n","opt":{"act":"id","func":"lambda x:x[:,1]"},"color":"#f0f0f0","ltype":"o","code":"","to":[],"from":[59],"optiflg":[false,false,false,false,false,false,false,false,false,false]}},"edge":{"21":{"pre":"8","post":"9"},"22":{"pre":"9","post":"10"},"24":{"pre":"11","post":"15"},"25":{"pre":"15","post":"16"},"26":{"pre":"16","post":"17"},"27":{"pre":"19","post":"18"},"28":{"pre":"18","post":"17"},"29":{"pre":"20","post":"18"},"47":{"pre":"10","post":"46"},"48":{"pre":"46","post":"11"},"49":{"pre":"12","post":"11"},"55":{"pre":"16","post":54},"59":{"pre":"54","post":58}},"idCnt":60,"opti":["optimizers.Adam(alpha=-0.003)","optimizers.SGD(lr=0.001)","optimizers.MomentumSGD(lr=0.01)","optimizers.AdaDelta()"],"settings":{"bs":100,"aveintvl":100},"optiee":["0","1","2","3"],"loss":["4","5","6","7"],"cond":["lambda x:not(x%6)","lambda x:x%6","lambda x:1","lambda x:1"]} -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Board 5 | 6 | 7 | 845 | 857 | 858 | 859 | 860 | Load Canvas: 861 | Download Canvas
862 |
863 | 864 |
865 | Show: 866 | 867 | 868 | 869 | 870 | 871 |
872 | Options: 873 |
874 | training settings:
875 |
877 | (Experimental)
879 | 884 |
885 | optimizer 0: 886 | optimizee: # 887 | loss: # 888 | condition:
889 | optimizer 1: 890 | optimizee: # 891 | loss: # 892 | condition:
893 | optimizer 2: 894 | optimizee: # 895 | loss: # 896 | condition:
897 | optimizer 3: 898 | optimizee: # 899 | loss: # 900 | condition:
901 |
902 | Tuning Slider: 903 |   0.00000  
904 |
905 | |
906 |
907 |
908 |
909 | Execute Code: 910 |
911 |
912 | 913 | 914 |
915 |
916 |
917 |
918 | Address of Server:
919 |
920 | 921 |
922 |
923 |
924 |
925 | 973 | 974 | -------------------------------------------------------------------------------- /layers.md: -------------------------------------------------------------------------------- 1 | # 層の一覧 2 | 3 | * `b` : Batch Normalization 4 | 5 | * バッチ正規化層です。`size` に、出力チャネル数を指定してください。内部では `L.BatchNormalization(size, **kwargs)` がインスタンス化されます。 6 | 7 | * `c`: Convolution 2D 8 | 9 | * 畳み込み層です。`out_channel` に、出力チャネル数を指定してください。内部では `L.Convolution2D(None, out_channel, **kwargs)` がインスタンス化されます。 10 | 11 | * `C`: Concat 12 | 13 | * Variableを結合します。`type` には `"batch_dim"` (バッチ次元(chainerの場合0次元目))か `"channel_dim"`(チャネル次元(chainerの場合1次元目))を指定してください。内部では `F.concat(list_of_args, axis=<0 or 1>` が呼ばれます。引数の順番は、引数を生み出す層のx座標の昇順となります。 14 | 15 | * `d`: Dropout 16 | 17 | * ドロップアウト層です。内部では `F.dropout(arg, **kwargs)` が呼ばれます。 18 | 19 | * `e`: experience replay 20 | 21 | * 体験再生 (experience replay) を行う層です。過去に入力されたバッチをいくつか貯めておいて、その中からランダムに1つを選んで出力します。貯めておくバッチの個数を `size` で指定してください。 22 | 23 | * `f`: full connected 24 | 25 | * 全結合層です。`out_channel` に、出力チャネル数を指定してください。内部では `L.Linear(None, out_channel, **kwargs)` がインスタンス化されます。 26 | 27 | * `i`: input 28 | 29 | * 入力層です。`source` に、データのソースを指定してください(現時点では、`mnist_train_x` `mnist_train_t` `mnist_test_x` `mnist_test_t` のみが選択可能です)。 30 | 31 | * `m`: mean squared error 32 | 33 | * 平均2乗和誤差を求める層です。内部では `F.mean_squared_error(arg0, arg1)` が呼ばれます。 34 | 35 | * `o`: others 36 | 37 | * 任意の層を設計できます。`func` に指定したオブジェクトが、層が生成されたタイミングでインスタンス化され、学習中、そのインスタンスが引数付きで呼び出され続けます。引数の与えられる順番は、引数を生み出す層のx座標の昇順となります。(ここに記載している層の多くは、この `o` でも実装可能です。) 38 | 39 | * 例1:"func":"L.Linear(None, 32)"` と指定すれば、`f` 層で `"out_channel":32` と指定するのと同じになります。 40 | 41 | * 例2:"func":"F.relu"` と指定すれば、ReLU活性化のみを行う層となります。 42 | 43 | * 例3:`"func":"lambda x, y:x+2*y"` と指定すれば、x と y を受け取って x + 2*y を出力する層となります。 44 | 45 | * `p`: pooling 46 | 47 | * プーリング層です。`type` には、`max` か `average` を指定してください。内部では `F.max_pooling_2d(arg, **kwargs)` または `F.average_pooling_2d(arg, **kwargs)` が呼ばれます。 48 | 49 | * `r`: random 50 | 51 | * 乱数を生成する層です。現状、`type` 指定に関わらず、各要素が独立な正規乱数によって生成されます。`sample_shape` に、生成される Variable のサンプル部分の shape(つまり、0次元目のバッチサイズを除いたshape)を指定してください(例:`"shape":[3, 32, 32]`)。`mu` に正規分布の平均、`sigma` に正規分布の標準偏差を指定してください。 52 | 53 | * `R`: Reshape 54 | 55 | * shape の変更を行う層です。(現状、`o` 層に `"func":"F.reshape"` と指定したものが出てきます。)`shape` に、変更後の shape を指定してください(例:`"shape":[-1, 1, 28, 28]`)。 56 | 57 | * `s`: softmax cross entropy 58 | 59 | * ソフトマックス交差エントロピー誤差を求める層です。内部では `F.softmax_cross_entropy(arg0, arg1)` が呼ばれます。(arg0 と arg1 の順番は、 60 | ノードの x 座標の小さな方が arg0 となります。)なお、この層は正答率集計の対象となります(裏で `F.accuracy(arg0, arg1)` を呼んでいる)。 61 | 62 | * `T`: Transpose 63 | 64 | * Variable を転置するための層です。最後の2つの軸が入れ替わります。(これ以外の軸の転置を行いたい場合は、`o` 層で `"func":"F.transpose"` としてください) 65 | 66 | * `v`: value 67 | 68 | * 固定値を出力するための層です。`value` に値を指定してください。`type` には、"float32" "int32" など、型を指定してください(`np.dtype` の引数に渡せる文字列を指定してください)。 69 | 70 | * `+`: 和 71 | 72 | * 引数を全て足し合わせる層です。 73 | 74 | * `*`: 積 75 | 76 | * 引数を全て掛け合わせる層です。 77 | 78 | * `-`: 負号 79 | 80 | * 引数にマイナスをつけます。(引数は 1 つしか受け取れません。引き算記号ではありません。) 81 | 82 | 83 | # 活性化関数について 84 | 85 | * 全ての層には活性化関数がついています。上記の計算を行った後に、活性化関数が適用され、その層の出力となります。 86 | 87 | * 活性化関数は、`Ctrl`+{`e`,`i`,`l`,`r`,`s`,`t`} で切り替え可能です。また、Options の `act` を編集すれば、任意のものに変更可能です。 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /nnboard_digest.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/al4tech/nnboard/89c3a513d0f32b2350380fe169888135b7efe722/nnboard_digest.gif -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | # 参考: 2 | # http://blog.sarabande.jp/post/81479479934 3 | # http://coreblog.org/ats/stuff/minpy_web/13/03.html 4 | # http://d.hatena.ne.jp/matasaburou/touch/20151003/1443882557 5 | import numpy as np 6 | import argparse 7 | import urllib 8 | import time 9 | from http.server import HTTPServer, SimpleHTTPRequestHandler 10 | import cgi 11 | 12 | import json 13 | from http.server import BaseHTTPRequestHandler 14 | 15 | import threading, sys 16 | 17 | from chainer import Chain, ChainList, cuda, gradient_check, Function, Link, optimizers, serializers, utils, Variable, datasets 18 | from chainer import functions as F 19 | from chainer import links as L 20 | 21 | 22 | def get_error_message(sys_exc_info=None): 23 | ex, ms, tb = sys.exc_info() if sys_exc_info is None else sys_exc_info 24 | return '[Error]\n' + str(ex) + '\n' + str(ms) 25 | 26 | def float_if_float(x): 27 | # np.float系列がJSON serializableでないのを回避するため 28 | if isinstance(x, list): 29 | return [float_if_float(i) for i in x] 30 | return float(x) if isinstance(x, np.float) or isinstance(x, np.float32) or isinstance(x, np.float64) else x 31 | # もっと綺麗に書けるはず 32 | 33 | def filter_dic(dic, filt=None, omit=None): 34 | ret = {} 35 | for k,v in dic.items(): 36 | if (filt is None or (k in filt)) and (omit is None or (not k in omit)): ret[k] = v 37 | return ret 38 | 39 | def cov2ovl(cov): 40 | # covmatをovlmatに 41 | norm = np.sqrt(np.diag(cov)).reshape(-1, 1) 42 | return np.maximum(np.minimum(1., cov / np.dot(norm, norm.T)), -1.) 43 | 44 | def cov2deg(cov): 45 | # covmatをdegmatに 46 | return np.arccos(cov2ovl(cov)) * 180 / np.pi 47 | 48 | def cov2maxovl(cov): 49 | return np.max(np.abs(cov2ovl(cov) * (np.ones(cov.shape) - np.eye(cov.shape[0])))) 50 | 51 | def cov2minnorm(cov): 52 | return np.sqrt(np.min(np.diag(cov))) 53 | 54 | class linkset(ChainList): 55 | def __init__(self, links): 56 | super(linkset, self).__init__(*links) 57 | # class nnmdl(Chain): # 一つのlinkを複数のchainに登録できないようなので、全体を管理するchainはchainでなくした(chainにする必要もないので) 58 | class nnmdl: 59 | def __init__(self, net_info): 60 | self.initialize(net_info) 61 | def initialize(self, net_info, net_info_old=None): 62 | self.err_info = {} # 学習中に生じたエラーはここに格納しておく。get_info(typ=='err')で取得されたらクリアされる。 63 | self.net_info = net_info 64 | if net_info_old is None: # ガチ初期化の場合 65 | self.l = {} # Linkを格納する。'e'のみ記憶の溜め込み場のリストとして使う 66 | self.h = {} # Variableを格納する。 67 | else: # 動的更新の場合 68 | # 削除されたノードの情報を捨て去る 69 | for k,v in net_info_old['node'].items(): 70 | if not (k in net_info['node']): # ノード k は削除された。 71 | print('info: node',k,'is deleted. Deleting related info.') 72 | if k in self.l:del self.l[k] 73 | if k in self.h: del self.h[k] 74 | self.bs = self.net_info['settings']['bs'] 75 | self.aveintvl = self.net_info['settings']['aveintvl'] 76 | # トポロジカルソートする。結果は self.torder リストに格納される。 77 | self.torder, self.visited_for_tsort = [], set() 78 | for k in self.net_info['node'].keys(): 79 | self.recur_for_tsort(k) 80 | self.torder.reverse() 81 | # 閉路検出して警告を出す 82 | idx_to_torder = {self.torder[i]: i for i in range(len(self.torder))} 83 | for k,v in net_info['edge'].items(): 84 | if idx_to_torder[str(v['pre'])] > idx_to_torder[str(v['post'])]: 85 | print('WARNING!!!!: computation graph contains some cycles. (will be dealt as some acyclic graph.)') 86 | break 87 | # Linkを用意する 88 | random_seed = np.random.randint(1000000007) 89 | for k,v in net_info['node'].items(): 90 | existed, updated = False, False 91 | if net_info_old is not None and k in net_info_old['node']: 92 | existed = True # ノード k は以前から存在した。 93 | # if v != net_info_old['node'][k]: 94 | v_old = net_info_old['node'][k] 95 | if v['ltype'] != v_old['ltype'] or v['opt'] != v_old['opt']: 96 | updated = True # ノード k は今回アップデートされた。 97 | print('WARNING: node',k,'is updated!') 98 | if existed: continue # 本当は and not updated としたいが、とりあえず今は updated を無視することにする。TODO 99 | 100 | if v['ltype']=='f': 101 | self.l[k] = L.Linear(None, int(v['opt']['out_channel']), **filter_dic(v['opt'], omit=['act', 'out_channel'])) 102 | # self.add_link(k, self.l[k]) 103 | elif v['ltype']=='c': 104 | self.l[k] = L.Convolution2D(None, int(v['opt']['out_channel']), **filter_dic(v['opt'], omit=['act', 'out_channel'])) 105 | elif v['ltype']=='r': 106 | self.l[k] = Sampler(source='random_normal', bs=self.bs, opt=v['opt'], sample_shape=v['opt']['sample_shape'], random_seed=random_seed) 107 | elif v['ltype']=='i': 108 | self.l[k] = Sampler(source=v['opt']['source'], bs=self.bs, opt=v['opt'], sample_shape=None, random_seed=random_seed) 109 | elif v['ltype']=='o': 110 | if v['opt']['func'][:2] == 'L.': 111 | self.l[k] = eval(v['opt']['func']) # optで指定できるのはインスタンス生成時引数でなく__call__時引数とする(インスタンス生成時引数はfuncの中に直接書いて。) 112 | else: 113 | self.l[k] = eval(v['opt']['func']) # Fも同じ扱い 114 | elif v['ltype']=='b': 115 | self.l[k] = L.BatchNormalization(v['opt']['size'], **filter_dic(v['opt'], omit=['act', 'size'])) 116 | elif v['ltype']=='e': 117 | self.l[k] = [] # 溜め込み場としてのリスト 118 | elif v['ltype'] in ['i','C','s','v','m','p','+','-','*','T']: 119 | pass 120 | else: 121 | raise NotImplementedError('Unknown layer type: ' + v['ltype']) 122 | 123 | 124 | if net_info_old is not None: return 125 | # 現状、optimizerの更新は動的にできないようにしとく。(多分linkset重複定義などでエラー出る) TODO 126 | 127 | # optimizer準備 128 | self.num_of_opt = 4 129 | self.optchain = [None] * self.num_of_opt 130 | self.opt = [None] * self.num_of_opt 131 | self.lossidx = [[] for i in range(self.num_of_opt)] 132 | linkflg = False 133 | for i in range(self.num_of_opt): # optimizer番号のループ 134 | optiee_tag = int(self.net_info['optiee'][i]) # タグ番号 135 | links = [] 136 | for j in self.net_info['node'].keys(): 137 | n = self.net_info['node'][j] 138 | if n['optiflg'][optiee_tag]: 139 | if j in self.l and isinstance(self.l[j], Link): 140 | links.append(self.l[j]) 141 | linkflg = True 142 | if len(links) > 0: # optimizeeがない場合はoptimizerはインスタンス化されないようにしてる。(self.opt[i]はNoneのままとなる) 143 | self.optchain[i] = linkset(links) # optimizeしたいlinksを束ねたchainを作る (opt.setupが引数にlinkしか受け付けないため...) 144 | self.opt[i] = eval(self.net_info['opti'][i]) # optimizerのインスタンス取得 145 | self.opt[i].setup(self.optchain[i]) 146 | # loss調べとく 147 | loss_tag = int(self.net_info['loss'][i]) # タグ番号 148 | for j in self.net_info['node'].keys(): 149 | n = self.net_info['node'][j] 150 | if n['optiflg'][loss_tag]: 151 | self.lossidx[i].append(j) 152 | if len(self.lossidx[i])==0: print('WARNING: please specify at least one node as the loss. (optimizer '+str(i)+')') 153 | if not linkflg: print('WARNING: please specify at least one node as the optimizee.') 154 | 155 | self.lossfrac = [np.zeros(2) for i in range(self.num_of_opt)] 156 | self.accfrac = [np.zeros(2) for i in range(self.num_of_opt)] 157 | self.aveloss = None 158 | self.aveacc = None 159 | self.update_cnt = 0 160 | def update_net(self, net_info_new): 161 | ''' 162 | ネットワークの動的更新 163 | 基本的には __init__() と同じなんだけど重みやoptimizerの状態を引き継ぐ必要がある 164 | ''' 165 | self.initialize(net_info_new, self.net_info) 166 | def recur_for_tsort(self, node_id): 167 | # self.net_infoを元にtsortするときの再帰用 self.visited_for_tsort を更新しつつ self.torder に結果入れてく 168 | if not(isinstance(node_id, str)): node_id = str(node_id) 169 | if node_id in self.visited_for_tsort: return 170 | self.visited_for_tsort.add(node_id) 171 | n = self.net_info['node'][node_id] 172 | for i in n['to']: 173 | to_node_id = str(self.net_info['edge'][str(i)]['post']) 174 | if not (to_node_id in self.visited_for_tsort): 175 | self.recur_for_tsort(to_node_id) 176 | self.torder.append(node_id) 177 | def __call__(self, mode='train'): # 順計算 178 | self.acc = {} 179 | for nid in self.torder: 180 | try: # あるノードでエラーが起きても残りのノードは計算して欲しいので、forの内側にtry書く必要ある。 181 | n = self.net_info['node'][str(nid)] 182 | n_ltype = n['ltype'] 183 | if n_ltype in ['r', 'i']: # 0 変数 184 | vs = self.prevars(n, sort=True) 185 | assert len(vs)==0, '0 arg is expected, but '+str(len(vs))+' are given.' 186 | self.h[nid] = self.l[nid]() # サンプル取得 187 | elif n_ltype in ['f', 'c', 'b']: # 1 変数 link 188 | vs = self.prevars(n, sort=True) 189 | assert len(vs)==1, '1 arg is expected, but '+str(len(vs))+' are given.' 190 | self.h[nid] = self.l[nid](vs[0]) 191 | elif n_ltype == 'C': 192 | vs = self.prevars(n, sort=True) 193 | typ = n['opt']['type'] 194 | if typ == 'batch_dim': 195 | axis = 0 196 | elif typ == 'channel_dim': 197 | axis = 1 198 | else: 199 | raise ValueError("Unknown Concat type: "+typ + " (expected: 'batch_dim' or 'channel_dim')") 200 | self.h[nid] = F.concat(vs, axis=axis) 201 | elif n_ltype in ['+', '*']: 202 | vs = self.prevars(n) 203 | self.h[nid] = 0 204 | for v in vs: 205 | if n_ltype == '+': 206 | self.h[nid] += v 207 | else: 208 | self.h[nid] *= v 209 | elif n_ltype == '-': 210 | vs = self.prevars(n) 211 | assert len(vs)==1, '1 arg is expected, but '+str(len(vs))+' are given.' 212 | self.h[nid] = -vs[0] 213 | elif n_ltype == 'm': 214 | vs = self.prevars(n) 215 | assert len(vs)==2, '2 args are expected, but '+str(len(vs))+' are given.' 216 | self.h[nid] = F.mean_squared_error(vs[0], vs[1]) 217 | elif n_ltype == 's': 218 | vs = self.prevars(n, sort=True) 219 | assert len(vs)==2, '2 args are expected, but '+str(len(vs))+' are given.' 220 | self.h[nid] = F.softmax_cross_entropy(vs[0], vs[1]) 221 | self.acc[nid] = F.accuracy(vs[0], vs[1]) 222 | elif n_ltype == 'd': # TODO: 1変数 function もまとめ時か 223 | vs = self.prevars(n) 224 | assert len(vs)==1, '1 arg is expected, but '+str(len(vs))+' are given.' 225 | self.h[nid] = F.dropout(vs[0], **filter_dic(n['opt'], omit=['act'])) 226 | elif n_ltype == 'p': 227 | vs = self.prevars(n) 228 | assert len(vs)==1, '1 arg is expected, but '+str(len(vs))+' are given.' 229 | typ = n['opt']['type'] 230 | if typ == 'max': 231 | self.h[nid] = F.max_pooling_2d(vs[0], **filter_dic(n['opt'], omit=['act', 'type'])) 232 | elif typ == 'average': 233 | self.h[nid] = F.average_pooling_2d(vs[0], **filter_dic(n['opt'], omit=['act', 'type'])) 234 | else: 235 | raise ValueError("Unknown pooling type: " + typ + " (expected: 'max' or 'average')") 236 | elif n_ltype == 'e': 237 | # experience replay やってきたサンプルを (リストself.l[nid]に) 溜め込む。そして、同じ数だけランダムに吐き出す。 238 | # Variableのitem assignmentが無いようなので、以下の「サンプル」は「ミニバッチ」に読み替えて実装することとする。 239 | vs = self.prevars(n) 240 | assert len(vs)==1, '1 arg is expected, but '+str(len(vs))+' are given.' 241 | mem = self.l[nid] 242 | insiz = 1 # vs[0].data.shape[0] # 入ってきたサンプル数 243 | maxsiz = n['opt']['size'] # 溜め込む最大サンプル数 244 | nowsiz = len(mem) # 現在溜め込んでるサンプル数 245 | rest = maxsiz - nowsiz 246 | if rest > 0: mem += [vs[0]] 247 | if rest < insiz: 248 | # memに溜め込まれているサンプルをランダムに上書きする 249 | mem[np.random.randint(maxsiz)] = vs[0] 250 | self.h[nid] = mem[np.random.randint(len(mem))] # memからランダムにinsiz個のサンプルを吐き出す 251 | elif n_ltype == 'v': # value 252 | typ = n['opt']['type'] if 'type' in n['opt'] else 'np.float32' 253 | self.h[nid] = Variable(np.array([n['opt']['value']]*self.bs, dtype=np.dtype(typ))) 254 | elif n_ltype == 'T': # transpose (last 2 dim) それ以外やりたいなら F.transpose でやって 255 | vs = self.prevars(n) 256 | assert len(vs)==1, '1 arg is expected, but '+str(len(vs))+' are given.' 257 | nd = vs[0].data.ndim 258 | self.h[nid] = F.transpose(vs[0], axes=list(np.arange(nd-2))+[nd-1, nd-2]) 259 | elif n_ltype == 'o': # 任意の層 260 | vs = self.prevars(n, sort=True) 261 | self.h[nid] = eval(n['opt']['func'])(*vs, **filter_dic(n['opt'], omit=['act','func'])) 262 | # print(self.update_cnt,'calculated! shape=',self.h[nid].data.shape) 263 | 264 | # あとは活性化 265 | a = n['opt']['act'] 266 | if a == 'relu': 267 | self.h[nid] = F.relu(self.h[nid]) 268 | elif a in ['sigm', 'sigmoid']: 269 | self.h[nid] = F.sigmoid(self.h[nid]) 270 | elif a == 'elu': 271 | self.h[nid] = F.elu(self.h[nid]) 272 | elif a in ['l_relu', 'leaky_relu']: 273 | self.h[nid] = F.leaky_relu(self.h[nid]) 274 | elif a == 'tanh': 275 | self.h[nid] = F.tanh(self.h[nid]) 276 | elif a in ['id', 'identity']: 277 | pass 278 | else: 279 | self.h[nid] = eval(a)(self.h[nid]) # 任意の活性化関数を使える 280 | except: 281 | self.err_info[nid] = get_error_message() 282 | if nid in self.h: del self.h[nid] # 前回の情報が残りっぱなしになるのを防ぐ 283 | 284 | def prevars(self, n, sort=False): 285 | vs = [] 286 | for i in n['from']: 287 | pre_nid = self.net_info['edge'][str(i)]['pre'] 288 | if sort: 289 | vs.append((self.net_info['node'][str(pre_nid)]['x'], self.h[pre_nid])) 290 | else: 291 | vs.append(self.h[pre_nid]) 292 | if sort: 293 | vs.sort() 294 | return [v[1] for v in vs] 295 | return vs 296 | def update(self): 297 | try: 298 | self() # とりあえず順計算 299 | # ロスを計算 300 | self.loss = [0] * self.num_of_opt 301 | for i in range(self.num_of_opt): 302 | if not eval(self.net_info['cond'][i])(self.update_cnt): continue # conditionが合致してるかチェック 303 | for j in self.lossidx[i]: # まずはlossを計算(足し合わせ)accfracも計算 304 | self.loss[i] += self.h[j] 305 | if j in self.acc: self.accfrac[i] += np.array([self.acc[j].data, 1]) 306 | if isinstance(self.loss[i], Variable) and self.opt[i] is not None: 307 | # あとは逆伝播 (TODO: loss指定が前と同じならわざわざ逆伝播し直す必要はないはず。) 308 | for j in range(self.num_of_opt): 309 | if self.opt[j] is not None: self.optchain[j].cleargrads() # TODO: 一体どの範囲をcleargradsすれば十分なのかはよくわかっていない。毎回全部やっとくか、て感じになってる 310 | self.loss[i].grad = np.ones(self.loss[i].shape, dtype=np.float32) 311 | self.loss[i].backward() 312 | self.opt[i].update() 313 | self.lossfrac[i] += np.array([np.sum(self.loss[i].data), 1]) 314 | self.update_cnt += 1 315 | if self.update_cnt % self.aveintvl == 0: 316 | self.aveloss = self.get_aveloss(clear=True) 317 | self.aveacc = self.get_aveacc(clear=True) 318 | except: 319 | self.err_info['general'] = get_error_message() 320 | def get_aveloss(self, clear=False): 321 | ret = [None for i in range(self.num_of_opt)] 322 | for i in range(self.num_of_opt): 323 | if self.lossfrac[i][1] == 0: continue 324 | ret[i] = self.lossfrac[i][0]/self.lossfrac[i][1] 325 | if clear: self.lossfrac = [np.zeros(2) for i in range(self.num_of_opt)] 326 | return ret 327 | def get_aveacc(self, clear=False): 328 | ret = [None for i in range(self.num_of_opt)] 329 | for i in range(self.num_of_opt): 330 | if self.accfrac[i][1] == 0: continue 331 | ret[i] = self.accfrac[i][0]/self.accfrac[i][1] 332 | if clear: self.accfrac = [np.zeros(2) for i in range(self.num_of_opt)] 333 | return ret 334 | def W(self, idx): 335 | return self.l[idx].W.data 336 | 337 | 338 | class Sampler: 339 | def __init__(self, source, bs, opt, sample_shape, random_seed=None): 340 | _ = np.random.get_state() # 保存 341 | if random_seed is not None: np.random.seed(random_seed) 342 | self.random_state = np.random.get_state() 343 | np.random.set_state(_) # 復元 344 | self.source = source 345 | self.bs = bs 346 | self.opt = opt 347 | self.sample_shape = list(sample_shape) if isinstance(sample_shape, tuple) else sample_shape 348 | if self.source=='random_normal': 349 | self.sample_num = self.bs 350 | elif self.source in ['mnist_train_x', 'mnist_train_t', 'mnist_test_x', 'mnist_test_t']: 351 | # self.dataをロードする 352 | mnist_train, mnist_test = datasets.get_mnist() 353 | if self.source == 'mnist_train_x': 354 | self.data = np.array([d[0] for d in mnist_train], dtype=np.float32) 355 | if self.source == 'mnist_train_t': 356 | self.data = np.array([d[1] for d in mnist_train], dtype=np.int32) 357 | if self.source == 'mnist_test_x': 358 | self.data = np.array([d[0] for d in mnist_test], dtype=np.float32) 359 | if self.source == 'mnist_test_t': 360 | self.data = np.array([d[1] for d in mnist_test], dtype=np.int32) 361 | self.sample_num = len(self.data) 362 | 363 | elif self.source in ['fashion_mnist_train_x', 'fashion_mnist_train_t', 'fashion_mnist_test_x', 'fashion_mnist_test_t']: 364 | # self.dataをロードする 365 | fashion_mnist_train, fashion_mnist_test = datasets.get_fashion_mnist() 366 | if self.source == 'fashion_mnist_train_x': 367 | self.data = np.array([d[0] for d in fashion_mnist_train], dtype=np.float32) 368 | if self.source == 'fashion_mnist_train_t': 369 | self.data = np.array([d[1] for d in fashion_mnist_train], dtype=np.int32) 370 | if self.source == 'fashion_mnist_test_x': 371 | self.data = np.array([d[0] for d in fashion_mnist_test], dtype=np.float32) 372 | if self.source == 'fashion_mnist_test_t': 373 | self.data = np.array([d[1] for d in fashion_mnist_test], dtype=np.int32) 374 | self.sample_num = len(self.data) 375 | else: 376 | raise NotImplementedError 377 | self.epoch = 0 378 | self.sample_cnt = 0 379 | def __call__(self): 380 | if self.sample_cnt == 0: 381 | self.epoch += 1 382 | if self.source == 'random_normal': # 新しいサンプルを作ろう 383 | self.data = (np.random.randn(*([self.bs] + self.sample_shape)) * self.opt['sigma'] + self.opt['mu']).astype(np.float32) 384 | else: 385 | self.data = self.shuffled(self.data) # 自身の random_state の下でシャッフルするだけ 386 | 387 | ret = Variable(self.data[self.sample_cnt:self.sample_cnt+self.bs]) 388 | self.sample_cnt += self.bs 389 | if self.sample_cnt >= self.sample_num: self.sample_cnt = 0 390 | return ret 391 | def shuffled(self, ary): 392 | # こいつが持つrandom stateの下でaryをシャッフルしたものを返す 393 | # ary は np.array とする 394 | _ = np.random.get_state() # 保存 395 | np.random.set_state(self.random_state) 396 | # np.random.shuffle(ary) # TODO: shuffleが同じseedで必ず同じ順列に並べ替える(データ長以外のデータの性質に依存せず)ことを仮定しているので注意!! 397 | # データ非依存性がやや不安なので、indexをshuffleすることにする。 398 | perm = np.arange(len(ary)) 399 | np.random.shuffle(perm) 400 | ret = ary[perm] 401 | self.random_state = np.random.get_state() 402 | np.random.set_state(_) # 復元 403 | return ret 404 | 405 | class ComputationThreadManager(): # これは一度しかインスタンス化されない。 406 | def __init__(self): 407 | self.start_event = threading.Event() # 計算を開始させるかのフラグ 408 | self.stop_event = threading.Event() # 計算を停止させるかのフラグ 409 | self.exit_event = threading.Event() # スレッドを終了させるイベント 410 | self.update_net_event = threading.Event() # ネットワークを動的に更新するためのフラグ 411 | self.mdl = None 412 | self.thread = None 413 | def target(self): 414 | """別スレッド""" 415 | self.computing = False 416 | if self.exit_event.is_set(): self.exit_event.clear() 417 | while not self.exit_event.is_set(): # 別スレッドのメインループ 418 | if self.start_event.is_set(): # 学習開始時の処理 419 | self.start_event.clear() 420 | self.computing = True 421 | # 計算のための補助情報を self.net_info から計算する 422 | # トポロジカルソート順、サンプラーのロード、LinkやChain作成など 423 | self.mdl = nnmdl(self.net_info) 424 | if self.update_net_event.is_set(): # ネットワークの動的更新 425 | self.update_net_event.clear() 426 | self.mdl.update_net(self.net_info_new) 427 | if self.stop_event.is_set(): # 学習終了時の処理 428 | self.stop_event.clear() 429 | self.computing = False 430 | if self.computing: # 学習中の処理 431 | self.mdl.update() 432 | else: # 学習してない時の処理 433 | time.sleep(0.1) 434 | self.exit_event.clear() 435 | # self.mdl = None 436 | print('[end of thread]') 437 | """別スレッド終了""" 438 | def start_computing(self, net_info): 439 | try: 440 | self.net_info = net_info 441 | self.thread = threading.Thread(target = self.target) # スレッド作成(thread can only be started onceなので毎回インスタンス化が必要) 442 | self.thread.start() # スレッド開始! 443 | self.start_event.set() 444 | return '[computation thread started.]' 445 | except: 446 | print('[main thread] error') 447 | return get_error_message() # 学習用スレッドで生じたエラーはここには届かない!!!▲▲▲ 448 | def stop_computing(self): 449 | try: 450 | self.stop_event.set() 451 | self.exit_event.set() 452 | self.thread.join() #スレッドが停止するのを待つ 453 | return '[computation thread stopped.]' 454 | except: 455 | print('[main thread] error') 456 | return get_error_message() 457 | def update_net(self, net_info): # 学習中の、ネットワークの動的な更新! 458 | try: 459 | self.net_info_new = net_info 460 | self.update_net_event.set() 461 | return '[update_net_event set.]' 462 | except: 463 | print('[main thread] error') 464 | return get_error_message() 465 | def get_info(self, params): # これ、別スレッド内で情報収集させた方が、バッチ同期とれて良いのでは? TODO 466 | typ = params['type'] 467 | dic = {} 468 | if self.mdl is not None: 469 | if typ=='shape': 470 | for k,v in self.mdl.h.items(): 471 | dic[k] = v.data.shape if isinstance(v, Variable) and v.data is not None else () 472 | elif typ=='weight_summary': 473 | for k,v in self.mdl.l.items(): 474 | rec = {} 475 | if hasattr(v, 'W') and isinstance(v.W, Variable) and v.W.data is not None: 476 | rec['W_shape'] = v.W.data.shape 477 | rec['W_norm'] = float_if_float(np.linalg.norm(v.W.data)) 478 | if v.W.data.ndim==2: 479 | W_pre = np.dot(v.W.data.T, v.W.data) 480 | rec['W_pre_maxovl'] = float_if_float(cov2maxovl(W_pre)) 481 | rec['W_pre_minnorm'] = float_if_float(cov2minnorm(W_pre)) 482 | W_post = np.dot(v.W.data, v.W.data.T) 483 | rec['W_post_maxovl'] = float_if_float(cov2maxovl(W_post)) 484 | rec['W_post_minnorm'] = float_if_float(cov2minnorm(W_post)) 485 | if hasattr(v, 'b') and isinstance(v.b, Variable): 486 | rec['b_shape'] = v.b.data.shape 487 | rec['b_norm'] = float_if_float(np.linalg.norm(v.b.data)) 488 | dic[k] = rec 489 | elif typ=='learning_status': 490 | dic = {'aveloss': float_if_float(self.mdl.aveloss), 'aveacc': float_if_float(self.mdl.aveacc),'update_cnt': self.mdl.update_cnt, 'thread_alive': self.thread.is_alive()} 491 | elif typ=='image_sample': 492 | # image型なノードすべてから現在のイメージを送り返す ←image型に限らずにした! 493 | for k,v in self.mdl.h.items(): 494 | if isinstance(v, Variable) and v.data is not None: 495 | if v.data.ndim >= 1: # もともと ==4 にしてた 496 | dic[k] = [v.data[0].tolist()] 497 | if v.data.shape[0]>1: # もう一個送ってやる 498 | dic[k].append(v.data[1].tolist()) 499 | else: 500 | dic[k] = [v.data.tolist()] 501 | elif typ=='activation_detail': 502 | # nid番のノードの現在のactivationを丸ごと送り返す 503 | v = self.mdl.h[params['id']] 504 | dic = v.data.tolist() if isinstance(v, Variable) and v.data is not None else [] 505 | elif typ=='err': # 学習中に生じたエラーを取得 506 | dic = self.mdl.err_info 507 | self.mdl.err_info = {} # 取得されたので、エラーをクリアする 508 | return json.dumps(dic) 509 | def exec(self, com): 510 | exec(com) # エラー処理はこの外側で行います 511 | def eval(self, com): 512 | return eval(com) # エラー処理はこの外側で行います 513 | 514 | 515 | 516 | class MyHandler(BaseHTTPRequestHandler): 517 | def __init__(self, *initargs): 518 | # リクエスト来るたびにここから毎回実行されることに注意!!! 519 | super(BaseHTTPRequestHandler, self).__init__(*initargs) 520 | def do_GET(self): 521 | i=self.path.rfind('?') 522 | if i>=0: 523 | path, query=self.path[:i], self.path[i+1:] 524 | else: 525 | path=self.path 526 | query='' 527 | unquoted_query = urllib.parse.unquote(query) 528 | 529 | body = (str(np.random.randint(1000000007))+' Do nothing.
').encode('utf-8') 530 | self.send_response(200) 531 | self.send_header('Content-type', 'text/html; charset=utf-8') 532 | self.send_header('Content-length', len(body)) 533 | self.send_header('Access-Control-Allow-Origin', '*') # CORS解決 534 | self.end_headers() 535 | self.wfile.write(body) 536 | def do_POST(self): 537 | print("[POST] self.path =", self.path) 538 | content_len = int(self.headers.get('content-length')) 539 | requestBody = self.rfile.read(content_len).decode('UTF-8') 540 | # print('requestBody=' + requestBody) 541 | jsonData = json.loads(requestBody) 542 | # print('**JSON**') 543 | # print(json.dumps(jsonData, sort_keys=False, indent=4, separators={',', ':'})) 544 | body = '' 545 | 546 | com = jsonData['command'] 547 | if com == 'set': 548 | if ctm.thread is None or (not ctm.thread.is_alive()): # 学習用スレッドが生きてない場合は・・・ 549 | body = ctm.start_computing(net_info = jsonData['data']) # 新規の学習開始 550 | else: # 学習用スレッドが生きている場合は、学習中のネットワークの動的な更新! 551 | body = ctm.update_net(net_info = jsonData['data']) 552 | elif com == 'getinfo': 553 | body = ctm.get_info(jsonData['params']) # 'data'でもいいんでは 中身 typeだけだし 554 | elif com == 'exec': # 文字通りexecする 555 | try: 556 | ctm.exec(jsonData['data']) 557 | body = 'executed successfully' 558 | except: 559 | body = get_error_message() 560 | elif com == 'eval': # 文字通りevalする 561 | try: 562 | body = json.dumps({'error':0, 'ret':str(ctm.eval(jsonData['data']))}) 563 | except: 564 | body = json.dumps({'error':1, 'ret':get_error_message()}) 565 | elif com == 'stop': 566 | body = ctm.stop_computing() 567 | elif com == 'shutdown': 568 | if ctm.thread is not None and ctm.thread.is_alive(): ctm.stop_computing() # 学習用スレッド生きてるなら殺す 569 | global httpd 570 | threading.Thread(target=httpd.shutdown).start() # サーバーを落とす。他のスレッドからじゃないと落とせないようだ。 571 | else: 572 | body = 'unknown command. computation thread is '+('alive' if ctm.thread.is_alive() else 'dead') 573 | body = body.encode('utf-8') 574 | self.send_response(200) 575 | self.send_header('Content-type', 'text/plain') 576 | self.send_header('Access-Control-Allow-Origin', '*') # CORS解決 577 | self.end_headers() 578 | self.wfile.write(body) 579 | 580 | 581 | # index.htmlを開く 582 | import webbrowser, os 583 | webbrowser.open('file:///' + os.path.abspath(".") + '/index.html') 584 | 585 | parser = argparse.ArgumentParser(description='server') 586 | parser.add_argument('-p', '--port', 587 | default=8000, help='port number') 588 | commandargs = parser.parse_args() 589 | 590 | 591 | ctm = ComputationThreadManager() # 計算用スレッドの開始 592 | host = 'localhost' 593 | port = int(commandargs.port) 594 | httpd = HTTPServer((host, port), MyHandler) 595 | print('serving at port', port) 596 | httpd.serve_forever() --------------------------------------------------------------------------------