├── .gitignore
├── inference
    ├── cc
    │   ├── www
    │   │   ├── __init__.py
    │   │   ├── log
    │   │   │   └── __init__.py
    │   │   ├── handlers
    │   │   │   ├── __init__.py
    │   │   │   ├── base.py
    │   │   │   └── index.py
    │   │   ├── static
    │   │   │   ├── img
    │   │   │   │   └── README
    │   │   │   ├── fonts
    │   │   │   │   ├── glyphicons-halflings-regular.eot
    │   │   │   │   ├── glyphicons-halflings-regular.ttf
    │   │   │   │   └── glyphicons-halflings-regular.woff
    │   │   │   ├── css
    │   │   │   │   └── dashboard.css
    │   │   │   └── js
    │   │   │   │   └── application.js
    │   │   ├── templates
    │   │   │   ├── error.html
    │   │   │   ├── _hcheck.hdn
    │   │   │   ├── index.html
    │   │   │   └── layout.html
    │   │   ├── env.sh
    │   │   ├── stop.sh
    │   │   ├── start.sh
    │   │   └── etagger_dm.py
    │   ├── src
    │   │   ├── Config.cc
    │   │   ├── inference_iris.cc
    │   │   ├── inference.cc
    │   │   ├── Vocab.cc
    │   │   ├── inference_example.cc
    │   │   ├── TFUtil.cc
    │   │   ├── Input.cc
    │   │   └── Etagger.cc
    │   ├── include
    │   │   ├── result_obj.h
    │   │   ├── Config.h
    │   │   ├── Etagger.h
    │   │   ├── TFUtil.h
    │   │   ├── Vocab.h
    │   │   └── Input.h
    │   ├── wrapper
    │   │   ├── Etagger.py
    │   │   └── inference.py
    │   └── CMakeLists.txt
    ├── python
    │   ├── www
    │   │   ├── __init__.py
    │   │   ├── log
    │   │   │   └── __init__.py
    │   │   ├── static
    │   │   │   ├── img
    │   │   │   │   └── README
    │   │   │   ├── fonts
    │   │   │   │   ├── glyphicons-halflings-regular.eot
    │   │   │   │   ├── glyphicons-halflings-regular.ttf
    │   │   │   │   └── glyphicons-halflings-regular.woff
    │   │   │   ├── css
    │   │   │   │   └── dashboard.css
    │   │   │   └── js
    │   │   │   │   └── application.js
    │   │   ├── handlers
    │   │   │   ├── __init__.py
    │   │   │   ├── base.py
    │   │   │   └── index.py
    │   │   ├── templates
    │   │   │   ├── error.html
    │   │   │   ├── _hcheck.hdn
    │   │   │   ├── index.html
    │   │   │   └── layout.html
    │   │   ├── env.sh
    │   │   ├── stop.sh
    │   │   ├── start.sh
    │   │   └── etagger_dm.py
    │   ├── inference_example.py
    │   ├── inference_iris.py
    │   ├── inference.py
    │   └── inference_trt.py
    ├── train_example.py
    ├── export.py
    ├── train_iris.py
    ├── etc
    │   └── iris.txt
    └── freeze.py
├── requirements.txt
├── etc
    ├── graph-2.png
    ├── graph-3.png
    ├── graph-4.png
    ├── graph-5.png
    ├── graph-6.png
    ├── warmup-1.png
    ├── warmup-2.png
    ├── webapi-1.png
    ├── webapi-2.png
    ├── test_flair.py
    ├── conv.py
    ├── inspect.py
    ├── repair.py
    ├── test_spacy.py
    ├── chunk_eval.py
    └── token_eval.py
├── data
    └── config.json
├── early_stopping.py
├── test_berttok.py
├── test_bilm.py
├── progbar.py
├── inference.py
├── feed.py
├── embvec.py
└── config.py
/.gitignore:
--------------------------------------------------------------------------------
1 | embeddings
2 |
--------------------------------------------------------------------------------
/inference/cc/www/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inference/cc/www/log/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inference/python/www/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inference/cc/www/handlers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inference/cc/www/static/img/README:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/inference/python/www/log/__init__.py:
--------------------------------------------------------------------------------
1 | -------------------------------------------------------------------------------- /inference/python/www/static/img/README: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inference/python/www/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | seqeval 3 | tornado 4 | -------------------------------------------------------------------------------- /etc/graph-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/etc/graph-2.png -------------------------------------------------------------------------------- /etc/graph-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/etc/graph-3.png -------------------------------------------------------------------------------- /etc/graph-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/etc/graph-4.png -------------------------------------------------------------------------------- /etc/graph-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/etc/graph-5.png -------------------------------------------------------------------------------- /etc/graph-6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/etc/graph-6.png -------------------------------------------------------------------------------- /etc/warmup-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/etc/warmup-1.png -------------------------------------------------------------------------------- /etc/warmup-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/etc/warmup-2.png -------------------------------------------------------------------------------- /etc/webapi-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/etc/webapi-1.png -------------------------------------------------------------------------------- /etc/webapi-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/etc/webapi-2.png -------------------------------------------------------------------------------- /inference/cc/www/templates/error.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | ERROR 4 | 5 | -------------------------------------------------------------------------------- /inference/python/www/templates/error.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | ERROR 4 | 5 | -------------------------------------------------------------------------------- /inference/cc/www/templates/_hcheck.hdn: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | HealthCheck OK 4 | 5 | -------------------------------------------------------------------------------- /inference/python/www/templates/_hcheck.hdn: -------------------------------------------------------------------------------- 1 | 2 | 3 | HealthCheck OK 4 | 5 | -------------------------------------------------------------------------------- /inference/cc/www/static/fonts/glyphicons-halflings-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/inference/cc/www/static/fonts/glyphicons-halflings-regular.eot -------------------------------------------------------------------------------- /inference/cc/www/static/fonts/glyphicons-halflings-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/inference/cc/www/static/fonts/glyphicons-halflings-regular.ttf -------------------------------------------------------------------------------- /inference/cc/www/static/fonts/glyphicons-halflings-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/inference/cc/www/static/fonts/glyphicons-halflings-regular.woff -------------------------------------------------------------------------------- /inference/python/www/static/fonts/glyphicons-halflings-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/inference/python/www/static/fonts/glyphicons-halflings-regular.eot -------------------------------------------------------------------------------- /inference/python/www/static/fonts/glyphicons-halflings-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/inference/python/www/static/fonts/glyphicons-halflings-regular.ttf -------------------------------------------------------------------------------- /inference/python/www/static/fonts/glyphicons-halflings-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsindex/etagger/HEAD/inference/python/www/static/fonts/glyphicons-halflings-regular.woff -------------------------------------------------------------------------------- /inference/cc/src/Config.cc: -------------------------------------------------------------------------------- 1 | #include "Config.h" 2 | 3 | /* 4 | * public methods 5 | */ 6 | 7 | Config::Config() 8 | { 9 | } 10 | 11 | Config::Config(int word_length) 12 | { 13 | this->word_length = word_length; 14 | } 15 | 16 | Config::~Config() 17 | { 18 | } 19 | -------------------------------------------------------------------------------- /inference/cc/include/result_obj.h: -------------------------------------------------------------------------------- 1 | #ifndef RESULT_OBJ 2 | #define RESULT_OBJ 3 | 4 | #define MAX_WORD 64 5 | #define MAX_POS 64 6 | #define MAX_CHK 64 7 | #define MAX_TAG 64 8 | struct result_obj { 9 | char word[MAX_WORD]; 10 | char pos[MAX_POS]; 11 | char chk[MAX_CHK]; 12 | char tag[MAX_TAG]; 13 | char predict[MAX_TAG]; 14 | }; 15 | #endif 16 | -------------------------------------------------------------------------------- /inference/cc/include/Config.h: 
-------------------------------------------------------------------------------- 1 | #ifndef CONFIG_H 2 | #define CONFIG_H 3 | 4 | class Config { 5 | 6 | public: 7 | Config(); 8 | Config(int word_length); 9 | void SetClassSize(int class_size) { this->class_size = class_size; } 10 | int GetClassSize() { return class_size; } 11 | int GetWordLength() { return word_length; } 12 | ~Config(); 13 | 14 | private: 15 | int class_size; // assigned after loading vocab 16 | int word_length; 17 | }; 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /inference/cc/www/handlers/base.py: -------------------------------------------------------------------------------- 1 | import tornado.web 2 | import logging 3 | 4 | class BaseHandler(tornado.web.RequestHandler): 5 | @property 6 | def log(self): 7 | return self.application.log 8 | @property 9 | def ppid(self): 10 | return self.application.ppid 11 | @property 12 | def Etagger(self): 13 | return self.application.Etagger 14 | @property 15 | def etagger(self): 16 | return self.application.etagger 17 | @property 18 | def nlp(self): 19 | return self.application.nlp 20 | -------------------------------------------------------------------------------- /inference/python/www/handlers/base.py: -------------------------------------------------------------------------------- 1 | import tornado.web 2 | import logging 3 | 4 | class BaseHandler(tornado.web.RequestHandler): 5 | @property 6 | def log(self): 7 | return self.application.log 8 | @property 9 | def ppid(self): 10 | return self.application.ppid 11 | @property 12 | def etagger(self): 13 | return self.application.etagger 14 | @property 15 | def config(self): 16 | return self.application.config 17 | @property 18 | def nlp(self): 19 | return self.application.nlp 20 | -------------------------------------------------------------------------------- /inference/cc/include/Etagger.h: -------------------------------------------------------------------------------- 1 | #ifndef ETAGGER_H 2 | #define ETAGGER_H 3 | 4 | #include "TFUtil.h" 5 | #include "Input.h" 6 | #include "result_obj.h" // for c, python wrapper 7 | 8 | class Etagger { 9 | public: 10 | Etagger(string frozen_graph_fn, string vocab_fn, int word_length, bool lowercase, bool is_memmapped, int num_threads); 11 | int Analyze(vector<string>& bucket); 12 | ~Etagger(); 13 | 14 | private: 15 | TFUtil* util; 16 | tensorflow::Session* sess; 17 | Config* config; 18 | Vocab* vocab; 19 | 20 | }; 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /inference/cc/www/templates/index.html: -------------------------------------------------------------------------------- 1 | {% extends 'layout.html' %} 2 | 3 | {% block pagetitle%} 4 | 7 | {% end %} 8 | 9 | {% block content %} 10 |
30 | {%end%} 31 | -------------------------------------------------------------------------------- /inference/python/www/templates/index.html: -------------------------------------------------------------------------------- 1 | {% extends 'layout.html' %} 2 | 3 | {% block pagetitle%} 4 | 7 | {% end %} 8 | 9 | {% block content %} 10 |
30 | {%end%} 31 | -------------------------------------------------------------------------------- /inference/python/inference_example.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | val = np.array([[1, 1]], dtype=np.float32) 5 | 6 | with tf.Session() as sess: 7 | 8 | # restore meta graph 9 | loader = tf.train.import_meta_graph('./exported/my_model.meta') 10 | # mapping placeholders and tensors 11 | x = tf.get_default_graph().get_tensor_by_name('input:0') 12 | output = tf.get_default_graph().get_tensor_by_name('output:0') 13 | kernel = tf.get_default_graph().get_tensor_by_name('dense/kernel:0') 14 | bias = tf.get_default_graph().get_tensor_by_name('dense/bias:0') 15 | # restore actual values 16 | loader = loader.restore(sess, './exported/my_model') 17 | 18 | x, output, kernel, bias = sess.run([x, output, kernel, bias], {x: val}) 19 | 20 | print(tf.global_variables()) 21 | print("input ", x) 22 | print("output ", output) 23 | print("dense/kernel:0 ", kernel) 24 | print("dense/bias:0 ", bias) 25 | -------------------------------------------------------------------------------- /data/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "chr_dim": 25, 3 | "pos_dim": 7, 4 | "chk_dim": 10, 5 | "keep_prob": 0.7, 6 | "chr_conv_type": "conv1d", 7 | "filter_sizes": [3], 8 | "num_filters": 53, 9 | "highway_used": false, 10 | "rnn_used": true, 11 | "rnn_num_layers": 2, 12 | "rnn_type": "fused", 13 | "rnn_size": 200, 14 | "tf_used": false, 15 | "tf_num_layers": 4, 16 | "tf_keep_prob": 0.8, 17 | "tf_mh_num_heads": 4, 18 | "tf_mh_num_units": 64, 19 | "tf_mh_keep_prob": 0.8, 20 | "tf_ffn_kernel_size": 3, 21 | "tf_ffn_keep_prob": 0.8, 22 | "qrnn_size": 200, 23 | "qrnn_filter_size": 3, 24 | "qrnn_num_layers": 1, 25 | "starter_learning_rate": 0.001, 26 | "num_warmup_epoch": 0, 27 | "decay_steps": 12000, 28 | "decay_rate": 0.9, 29 | "clip_norm": 10, 30 | "elmo_word_length": 50, 31 | "elmo_keep_prob": 0.7, 32 | "bert_keep_prob": 0.7, 33 | "use_bert_optimization": false, 34 | "starter_learning_rate_for_tf": 0.0003, 35 | "num_warmup_epoch_for_bert": 2 36 | } 37 | -------------------------------------------------------------------------------- /inference/train_example.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | x = tf.placeholder(tf.float32, shape=[1, 2], name='input') 5 | output = tf.identity(tf.layers.dense(x, 1), name='output') 6 | 7 | val = np.array([[1, 1]], dtype=np.float32) 8 | 9 | with tf.Session() as sess: 10 | sess.run(tf.global_variables_initializer()) 11 | 12 | # save graph and weights 13 | saver = tf.train.Saver(tf.global_variables()) 14 | saver.save(sess, './exported/my_model') 15 | 16 | tf.train.write_graph(sess.graph, '.', "./exported/graph.pb", as_text=False) 17 | tf.train.write_graph(sess.graph, '.', "./exported/graph.pb_txt", as_text=True) 18 | 19 | t1 = tf.get_default_graph().get_tensor_by_name('output:0') 20 | t2 = tf.get_default_graph().get_tensor_by_name('dense/kernel:0') 21 | t3 = tf.get_default_graph().get_tensor_by_name('dense/bias:0') 22 | 23 | t1, t2, t3, x = sess.run([t1, t2, t3, x], {x: val}) 24 | 25 | print(tf.global_variables()) 26 | print("input ", x) 27 | print("output ", t1) 28 | print("dense/kernel:0 ", t2) 29 | print("dense/bias:0 ", t3) 30 | -------------------------------------------------------------------------------- 
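Note: train_example.py and inference_example.py above round-trip a checkpoint through import_meta_graph. The step that turns such a checkpoint into the single-file frozen graphs consumed by the rest of the inference code lives in inference/freeze.py, which is not reproduced in this dump. A minimal TF1 sketch of that step only, assuming the 'output' node name from train_example.py; everything else here is an assumption, not freeze.py's actual contents:

# Hedged sketch, not the repo's freeze.py: folds the variables of
# ./exported/my_model into constants so the graph can be served from a
# single .pb file, the format the inference examples above expect.
import tensorflow as tf

with tf.Session() as sess:
    loader = tf.train.import_meta_graph('./exported/my_model.meta')
    loader.restore(sess, './exported/my_model')
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['output'])  # output node named in train_example.py
    with tf.gfile.GFile('./exported/my_model_frozen.pb', 'wb') as f:
        f.write(frozen.SerializeToString())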
/inference/cc/www/static/css/dashboard.css: -------------------------------------------------------------------------------- 1 | /* 2 | * Base structure 3 | */ 4 | 5 | /* Move down content because we have a fixed navbar that is 50px tall */ 6 | body { 7 | padding-top: 50px; 8 | } 9 | 10 | /* 11 | * Global add-ons 12 | */ 13 | 14 | .sub-header { 15 | padding-bottom: 10px; 16 | border-bottom: 1px solid #eee; 17 | } 18 | 19 | /* 20 | * Top navigation 21 | * Hide default border to remove 1px line. 22 | */ 23 | .navbar-fixed-top { 24 | border: 0; 25 | } 26 | 27 | /* 28 | * Main content 29 | */ 30 | 31 | .main { 32 | padding: 20px; 33 | } 34 | 35 | @media (min-width: 768px) { 36 | .main { 37 | padding-right: 40px; 38 | padding-left: 40px; 39 | } 40 | } 41 | 42 | .main .page-header { 43 | margin-top: 0; 44 | } 45 | 46 | /* 47 | * Sidebar 48 | */ 49 | 50 | .nav-sidebar li a { 51 | font-size: 0.7em; 52 | } 53 | 54 | /* 55 | * Placeholder dashboard ideas 56 | */ 57 | 58 | .placeholders { 59 | margin-bottom: 30px; 60 | text-align: center; 61 | } 62 | 63 | .placeholders h4 { 64 | margin-bottom: 0; 65 | } 66 | 67 | .placeholder { 68 | margin-bottom: 20px; 69 | } 70 | 71 | .placeholder img { 72 | display: inline-block; 73 | border-radius: 50%; 74 | } 75 | -------------------------------------------------------------------------------- /inference/cc/include/TFUtil.h: -------------------------------------------------------------------------------- 1 | #ifndef TFUTIL_H 2 | #define TFUTIL_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | 16 | typedef vector> tensor_dict; 17 | 18 | class TFUtil { 19 | 20 | public: 21 | TFUtil(); 22 | tensorflow::MemmappedEnv* CreateMemmappedEnv(string graph_fn); 23 | tensorflow::Session* CreateSession(tensorflow::MemmappedEnv* memmapped_env, int num_threads); 24 | void DestroySession(tensorflow::Session* sess); 25 | tensorflow::Status LoadFrozenModel(tensorflow::Session *sess, string graph_fn); 26 | tensorflow::Status LoadFrozenMemmappedModel(tensorflow::MemmappedEnv* memmapped_env, tensorflow::Session *sess); 27 | tensorflow::Status LoadModel(tensorflow::Session *sess, string graph_fn, string checkpoint_fn); 28 | ~TFUtil(); 29 | 30 | private: 31 | void load_lstm_lib(); 32 | void load_qrnn_lib(); 33 | }; 34 | 35 | #endif 36 | -------------------------------------------------------------------------------- /inference/python/www/static/css/dashboard.css: -------------------------------------------------------------------------------- 1 | /* 2 | * Base structure 3 | */ 4 | 5 | /* Move down content because we have a fixed navbar that is 50px tall */ 6 | body { 7 | padding-top: 50px; 8 | } 9 | 10 | /* 11 | * Global add-ons 12 | */ 13 | 14 | .sub-header { 15 | padding-bottom: 10px; 16 | border-bottom: 1px solid #eee; 17 | } 18 | 19 | /* 20 | * Top navigation 21 | * Hide default border to remove 1px line. 
22 | */ 23 | .navbar-fixed-top { 24 | border: 0; 25 | } 26 | 27 | /* 28 | * Main content 29 | */ 30 | 31 | .main { 32 | padding: 20px; 33 | } 34 | 35 | @media (min-width: 768px) { 36 | .main { 37 | padding-right: 40px; 38 | padding-left: 40px; 39 | } 40 | } 41 | 42 | .main .page-header { 43 | margin-top: 0; 44 | } 45 | 46 | /* 47 | * Sidebar 48 | */ 49 | 50 | .nav-sidebar li a { 51 | font-size: 0.7em; 52 | } 53 | 54 | /* 55 | * Placeholder dashboard ideas 56 | */ 57 | 58 | .placeholders { 59 | margin-bottom: 30px; 60 | text-align: center; 61 | } 62 | 63 | .placeholders h4 { 64 | margin-bottom: 0; 65 | } 66 | 67 | .placeholder { 68 | margin-bottom: 20px; 69 | } 70 | 71 | .placeholder img { 72 | display: inline-block; 73 | border-radius: 50%; 74 | } 75 | -------------------------------------------------------------------------------- /etc/test_flair.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import argparse 4 | 5 | from flair.data import Sentence 6 | from flair.models import SequenceTagger 7 | 8 | def spill_bucket(tagger, bucket): 9 | sentence = [] 10 | for line in bucket: 11 | tokens = line.split() 12 | word = tokens[0] 13 | sentence.append(word) 14 | # make a sentence 15 | sentence = Sentence(' '.join(sentence)) 16 | # run NER over sentence 17 | tagger.predict(sentence) 18 | print(sentence) 19 | print('The following NER tags are found:') 20 | print(sentence.to_tagged_string()) 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser() 25 | args = parser.parse_args() 26 | 27 | # load the NER tagger 28 | tagger = SequenceTagger.load('ner') 29 | 30 | bucket = [] 31 | while 1: 32 | try: line = sys.stdin.readline() 33 | except KeyboardInterrupt: break 34 | if not line: break 35 | line = line.strip() 36 | if not line and len(bucket) >= 1: 37 | spill_bucket(tagger, bucket) 38 | bucket = [] 39 | if line : bucket.append(line) 40 | 41 | if len(bucket) != 0: 42 | spill_bucket(tagger, bucket) 43 | -------------------------------------------------------------------------------- /etc/conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import argparse 4 | 5 | class Conv: 6 | def __init__(self): 7 | self.task = 'conv' 8 | 9 | def conv_bucket(self, bucket): 10 | for line in bucket: 11 | line = line.replace('\t', ' ') 12 | tokens = line.split() 13 | size = len(tokens) 14 | assert(size == 5) 15 | morph = tokens[0] 16 | mtag = tokens[1] 17 | etype = tokens[2] 18 | em_etype = tokens[3] 19 | tag = tokens[4] 20 | 21 | etype = etype.replace('_B', '').replace('_I', '') 22 | print(morph, mtag, etype, tag) 23 | print('') 24 | 25 | def conv(self): 26 | bucket = [] 27 | while 1: 28 | try: line = sys.stdin.readline() 29 | except KeyboardInterrupt: break 30 | if not line: break 31 | line = line.strip() 32 | if not line and len(bucket) >= 1: 33 | self.conv_bucket(bucket) 34 | bucket = [] 35 | if line : bucket.append(line) 36 | if len(bucket) != 0: 37 | self.conv_bucket(bucket) 38 | 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser() 41 | 42 | args = parser.parse_args() 43 | 44 | c = Conv() 45 | c.conv() 46 | -------------------------------------------------------------------------------- /inference/cc/include/Vocab.h: -------------------------------------------------------------------------------- 1 | #ifndef VOCAB_H 2 | #define VOCAB_H 3 | 4 | #include 5 | #include 6 | #include 7 
| #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | 16 | class Vocab { 17 | 18 | public: 19 | Vocab(string vocab_fn, bool lowercase); 20 | int GetTagVocabSize() { return tag_vocab.size(); } 21 | void Split(string s, vector& tokens); 22 | int GetWid(string word); 23 | int GetCid(string ch); 24 | int GetPadCid() { return pad_cid; } 25 | int GetPid(string pos); 26 | int GetKid(string chk); 27 | string GetTag(int tid); 28 | ~Vocab(); 29 | 30 | private: 31 | // same as config.py 32 | bool lowercase; 33 | int pad_wid = 0; 34 | int unk_wid = 1; 35 | int pad_cid = 0; 36 | int unk_cid = 1; 37 | int pad_pid = 0; 38 | int unk_pid = 1; 39 | int pad_kid = 0; 40 | int unk_kid = 1; 41 | int oot_tid = 0; 42 | int xot_tid = 1; 43 | string oot_tag = "O"; 44 | string xot_tag = "X"; 45 | map wrd_vocab; 46 | map chr_vocab; 47 | map pos_vocab; 48 | map chk_vocab; 49 | map tag_vocab; 50 | map itag_vocab; 51 | 52 | bool load_vocab(string vocab_fn); 53 | }; 54 | 55 | #endif 56 | -------------------------------------------------------------------------------- /inference/cc/www/static/js/application.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function() { 2 | var submitElem = $('#btnSubmit'); 3 | if (submitElem) { 4 | submitElem.on('click', function() { 5 | $.ajax({ 6 | url: '/etaggertest', 7 | data: $('#docForm').serialize(), 8 | method: 'POST' 9 | }).done(function(data) { 10 | if (data.success) { 11 | $('#info').empty(); 12 | if(data.info) { 13 | $('textarea#info').text(data.info); 14 | } 15 | $('#record_table').empty(); 16 | if(data.record) { 17 | var trHTML = ''; 18 | trHTML += ""; 19 | trHTML += "id"; 20 | trHTML += "word"; 21 | trHTML += "pos"; 22 | trHTML += "chk"; 23 | trHTML += "tag"; 24 | trHTML += "predict"; 25 | trHTML += ""; 26 | $.each(data.record, function (i, entry) { 27 | $.each(entry, function (j, item) { 28 | trHTML += '' + item.id; 29 | trHTML += '' + item.word; 30 | trHTML += '' + item.pos; 31 | trHTML += '' + item.chk; 32 | trHTML += '' + item.tag; 33 | trHTML += '' + item.predict; 34 | trHTML += ''; 35 | }); 36 | trHTML += ''; 37 | }); 38 | $('#record_table').append(trHTML); 39 | } 40 | } 41 | }); 42 | }); 43 | } 44 | }); 45 | -------------------------------------------------------------------------------- /inference/python/www/static/js/application.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function() { 2 | var submitElem = $('#btnSubmit'); 3 | if (submitElem) { 4 | submitElem.on('click', function() { 5 | $.ajax({ 6 | url: '/etaggertest', 7 | data: $('#docForm').serialize(), 8 | method: 'POST' 9 | }).done(function(data) { 10 | if (data.success) { 11 | $('#info').empty(); 12 | if(data.info) { 13 | $('textarea#info').text(data.info); 14 | } 15 | $('#record_table').empty(); 16 | if(data.record) { 17 | var trHTML = ''; 18 | trHTML += ""; 19 | trHTML += "id"; 20 | trHTML += "word"; 21 | trHTML += "pos"; 22 | trHTML += "chk"; 23 | trHTML += "tag"; 24 | trHTML += "predict"; 25 | trHTML += ""; 26 | $.each(data.record, function (i, entry) { 27 | $.each(entry, function (j, item) { 28 | trHTML += '' + item.id; 29 | trHTML += '' + item.word; 30 | trHTML += '' + item.pos; 31 | trHTML += '' + item.chk; 32 | trHTML += '' + item.tag; 33 | trHTML += '' + item.predict; 34 | trHTML += ''; 35 | }); 36 | trHTML += ''; 37 | }); 38 | $('#record_table').append(trHTML); 39 | } 40 | } 41 | }); 42 | }); 43 | } 44 | }); 45 | 
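Note: both copies of application.js above read data.success, data.info, and data.record out of the /etaggertest response; the handler that produces it (handlers/index.py) is not included in this dump. Inferred from the fields the callback uses, the JSON body has roughly the following shape (a sketch; the field values are illustrative only):

# Assumed response shape for POST /etaggertest, written as a Python literal.
# Inferred from application.js above, not from the handler itself.
response = {
    'success': True,
    'info': 'some diagnostic text',    # rendered into the info textarea
    'record': [                        # one entry per sentence
        [                              # one item per token row of #record_table
            {'id': 1, 'word': 'Obama', 'pos': 'NNP',
             'chk': 'B-NP', 'tag': 'B-PER', 'predict': 'B-PER'},
        ],
    ],
}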
-------------------------------------------------------------------------------- /etc/inspect.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import argparse 4 | 5 | class Inspect: 6 | def __init__(self): 7 | self.task = 'inspect' 8 | 9 | def inspect_bucket(self, bucket): 10 | for line in bucket: 11 | line = line.replace('\t', ' ') 12 | tokens = line.split() 13 | size = len(tokens) 14 | morph = tokens[0] 15 | mtag = tokens[1] 16 | etype = tokens[2] 17 | tag = tokens[3] 18 | pred = tokens[4] 19 | comment = 'SUCC' 20 | if tag != pred: comment = 'FAIL' 21 | l = [morph, mtag, etype, tag, pred, comment] 22 | out = '\t'.join(l) 23 | print(out) 24 | print('') 25 | 26 | def inspect(self): 27 | bucket = [] 28 | while 1: 29 | try: line = sys.stdin.readline() 30 | except KeyboardInterrupt: break 31 | if not line: break 32 | line = line.strip() 33 | if not line and len(bucket) >= 1: 34 | self.inspect_bucket(bucket) 35 | bucket = [] 36 | if line : bucket.append(line) 37 | if len(bucket) != 0: 38 | self.inspect_bucket(bucket) 39 | 40 | if __name__ == '__main__': 41 | parser = argparse.ArgumentParser() 42 | 43 | args = parser.parse_args() 44 | 45 | i = Inspect() 46 | i.inspect() 47 | -------------------------------------------------------------------------------- /inference/cc/include/Input.h: -------------------------------------------------------------------------------- 1 | #ifndef INPUT_H 2 | #define INPUT_H 3 | 4 | #include 5 | #include "Config.h" 6 | #include "Vocab.h" 7 | 8 | class Input { 9 | public: 10 | Input(Config* config, Vocab* vocab, vector& bucket); 11 | int GetMaxSentenceLength() { return max_sentence_length; } 12 | tensorflow::Tensor* GetSentenceWordIds() { return sentence_word_ids; } 13 | tensorflow::Tensor* GetSentenceWordChrIds() { return sentence_wordchr_ids; } 14 | tensorflow::Tensor* GetSentencePosIds() { return sentence_pos_ids; } 15 | tensorflow::Tensor* GetSentenceChkIds() { return sentence_chk_ids; } 16 | tensorflow::Tensor* GetSentenceLength() { return sentence_length; } 17 | tensorflow::Tensor* GetIsTrain() { return is_train; } 18 | ~Input(); 19 | 20 | private: 21 | // same as input.py 22 | int max_sentence_length; 23 | tensorflow::Tensor* sentence_word_ids; // (1, max_sentence_length) 24 | tensorflow::Tensor* sentence_wordchr_ids; // (1, max_sentence_length, word_length) 25 | tensorflow::Tensor* sentence_pos_ids; // (1, max_sentence_length) 26 | tensorflow::Tensor* sentence_chk_ids; // (1, max_sentence_length) 27 | tensorflow::Tensor* sentence_length; // scalar tensor 28 | tensorflow::Tensor* is_train; // scalar tensor 29 | 30 | int utf8_len(char chr); 31 | unsigned int* build_coffarr(const char* in, int in_size); 32 | 33 | }; 34 | 35 | #endif 36 | -------------------------------------------------------------------------------- /inference/python/inference_iris.py: -------------------------------------------------------------------------------- 1 | #!./bin/env python 2 | 3 | import sys 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | def load_frozen_graph(frozen_graph_filename, prefix='prefix'): 8 | with tf.gfile.GFile(frozen_graph_filename, "rb") as f: 9 | graph_def = tf.GraphDef() 10 | graph_def.ParseFromString(f.read()) 11 | 12 | with tf.Graph().as_default() as graph: 13 | tf.import_graph_def( 14 | graph_def, 15 | input_map=None, 16 | return_elements=None, 17 | op_dict=None, 18 | producer_op_list=None, 19 | name=prefix, 20 | ) 21 | 22 | return graph 23 | 24 | 
frozen_graph_filename = './exported/iris_frozen.pb' 25 | graph = load_frozen_graph(frozen_graph_filename, prefix='prefix') 26 | for op in graph.get_operations(): 27 | print(op.name) 28 | 29 | W = graph.get_tensor_by_name('prefix/W:0') 30 | b = graph.get_tensor_by_name('prefix/b:0') 31 | X = graph.get_tensor_by_name('prefix/X:0') 32 | logits = graph.get_tensor_by_name('prefix/logits:0') 33 | 34 | 35 | with tf.Session(graph=graph) as sess: 36 | print(tf.global_variables()) 37 | 38 | p = sess.run(logits, feed_dict={X:[[2,14,33,50]]}) # 1 0 0 -> type 0 39 | print(p, sess.run(tf.argmax(p, 1))) 40 | 41 | p = sess.run(logits, feed_dict={X:[[24,56,31,67]]}) # 0 1 0 -> type 1 42 | print(p, sess.run(tf.argmax(p, 1))) 43 | 44 | p = sess.run(logits, feed_dict={X:[[2,14,33,50], [24,56,31,67]]}) 45 | print(p, sess.run(tf.argmax(p, 1))) 46 | -------------------------------------------------------------------------------- /early_stopping.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | ''' 4 | source from http://forensics.tistory.com/29 5 | ''' 6 | 7 | class EarlyStopping(): 8 | 9 | def __init__(self, patience=0, measure='loss', verbose=0): 10 | """Set early stopping condition 11 | 12 | Args: 13 | patience: how many times to be patient before early stopping. 14 | measure: checking measure, loss | f1 | accuracy. 15 | verbose: if 1, enable verbose mode. 16 | """ 17 | self._step = 0 18 | if measure == 'loss': # loss 19 | self._value = float('inf') 20 | else: # f1, accuracy 21 | self._value = 0 22 | self.patience = patience 23 | self.verbose = verbose 24 | 25 | def reset(self, value): 26 | self._step = 0 27 | self._value = value 28 | 29 | def status(self): 30 | print('Status: step / patience = %d / %d, value = %f\n' % (self._step, self.patience, self._value)) 31 | 32 | def step(self): 33 | return self._step 34 | 35 | def validate(self, value, measure='loss'): 36 | going_worse = False 37 | if measure == 'loss': # loss 38 | if self._value < value: going_worse = True 39 | else: # f1, accuracy 40 | if self._value > value: going_worse = True 41 | if going_worse: 42 | self._step += 1 43 | if self._step > self.patience: 44 | if self.verbose: 45 | print('Training process is stopped early!') 46 | return True 47 | else: 48 | self.reset(value) 49 | return False 50 | 51 | -------------------------------------------------------------------------------- /inference/python/www/env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -o errexit 3 | 4 | # code from http://stackoverflow.com/a/1116890 5 | function readlink() 6 | { 7 | TARGET_FILE=$2 8 | cd `dirname $TARGET_FILE` 9 | TARGET_FILE=`basename $TARGET_FILE` 10 | 11 | # Iterate down a (possible) chain of symlinks 12 | while [ -L "$TARGET_FILE" ] 13 | do 14 | TARGET_FILE=`readlink $TARGET_FILE` 15 | cd `dirname $TARGET_FILE` 16 | TARGET_FILE=`basename $TARGET_FILE` 17 | done 18 | 19 | # Compute the canonicalized name by finding the physical path 20 | # for the directory we're in and appending the target file. 21 | PHYS_DIR=`pwd -P` 22 | RESULT=$PHYS_DIR/$TARGET_FILE 23 | echo $RESULT 24 | } 25 | export -f readlink 26 | 27 | export LANG=ko_KR.UTF-8 28 | 29 | # directory 30 | ## current dir of this script 31 | CDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))) 32 | PDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/..) 
33 | 34 | # server 35 | daemon_name='etagger_dm.py' 36 | port_devel=8898 37 | port_service=8898 38 | 39 | # command setting 40 | python='env python' 41 | 42 | # setting 43 | #EMB_FILENAME=kor.glove.300k.300d.txt.pkl # kor 44 | EMB_FILENAME=glove.6B.100d.txt.pkl # eng 45 | EMB_CLASS=glove 46 | CONFIG_FILENAME=config.json 47 | WRD_DIM=100 48 | FROZEN_FILENAME=ner_frozen.pb 49 | export CUDA_VISIBLE_DEVICES=0 50 | 51 | # functions 52 | 53 | function make_calmness() 54 | { 55 | exec 3>&2 # save 2 to 3 56 | exec 2> /dev/null 57 | } 58 | 59 | function revert_calmness() 60 | { 61 | exec 2>&3 # restore 2 from previous saved 3(originally 2) 62 | } 63 | 64 | function close_fd() 65 | { 66 | exec 3>&- 67 | } 68 | 69 | function jumpto 70 | { 71 | label=$1 72 | cmd=$(sed -n "/$label:/{:a;n;p;ba};" $0 | grep -v ':$') 73 | eval "$cmd" 74 | exit 75 | } 76 | -------------------------------------------------------------------------------- /inference/cc/www/templates/layout.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | {% block title %}ETAGGER{% end %} 5 | 6 | 7 | 8 | 9 | 29 |
31 | {% block pagetitle%}{% end %} 32 |
36 | {% block content %} 37 | 38 | {% end %} 39 |
41 | 42 | {% block templates %} 43 | {% end %} 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /inference/python/www/templates/layout.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | {% block title %}ETAGGER{% end %} 5 | 6 | 7 | 8 | 9 | 29 |
31 | {% block pagetitle%}{% end %} 32 |
36 | {% block content %} 37 | 38 | {% end %} 39 |
41 | 42 | {% block templates %} 43 | {% end %} 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /etc/repair.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import argparse 4 | 5 | class Repair: 6 | def __init__(self): 7 | self.task = 'repair' 8 | 9 | def repair_bucket(self, bucket): 10 | length = len(bucket) 11 | for i in range(length): 12 | line = bucket[i] 13 | line = line.replace('\t', ' ') 14 | tokens = line.split() 15 | word = tokens[0] 16 | pos = tokens[1] 17 | chunk = tokens[2] 18 | tag = tokens[3] 19 | pred = tokens[4] 20 | # 'X' -> 'O' 21 | if pred == 'X': pred = 'O' 22 | # begining 'I-' -> 'O' 23 | if pred[:2] == 'I-': 24 | if i == 0: 25 | pred = 'O' 26 | else: 27 | p_line = bucket[i-1] 28 | p_line = p_line.replace('\t', ' ') 29 | p_tokens = p_line.split() 30 | p_pred = p_tokens[4] 31 | if p_pred == 'O': 32 | pred = 'O' 33 | l = [word, pos, chunk, tag, pred] 34 | out = ' '.join(l) 35 | print(out) 36 | print('') 37 | 38 | def repair(self): 39 | bucket = [] 40 | while 1: 41 | try: line = sys.stdin.readline() 42 | except KeyboardInterrupt: break 43 | if not line: break 44 | line = line.strip() 45 | if not line and len(bucket) >= 1: 46 | self.repair_bucket(bucket) 47 | bucket = [] 48 | if line : bucket.append(line) 49 | if len(bucket) != 0: 50 | self.repair_bucket(bucket) 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser() 54 | 55 | args = parser.parse_args() 56 | 57 | r = Repair() 58 | r.repair() 59 | -------------------------------------------------------------------------------- /inference/cc/src/inference_iris.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | typedef std::vector> tensor_dict; 8 | 9 | tensorflow::Status LoadFrozenModel(tensorflow::Session *sess, std::string graph_fn) { 10 | tensorflow::Status status; 11 | 12 | // Read in the protobuf graph we exported 13 | tensorflow::GraphDef graph_def; 14 | status = ReadBinaryProto(tensorflow::Env::Default(), graph_fn, &graph_def); 15 | if (status != tensorflow::Status::OK()) return status; 16 | 17 | // Create the graph in the current session 18 | status = sess->Create(graph_def); 19 | if (status != tensorflow::Status::OK()) return status; 20 | 21 | return tensorflow::Status::OK(); 22 | } 23 | 24 | int main(int argc, char const *argv[]) { 25 | 26 | const std::string graph_fn = "./exported/iris_frozen.pb"; 27 | 28 | // Prepare session 29 | tensorflow::Session *sess; 30 | tensorflow::SessionOptions options; 31 | TF_CHECK_OK(tensorflow::NewSession(options, &sess)); 32 | TF_CHECK_OK(LoadFrozenModel(sess, graph_fn)); 33 | 34 | // Prepare inputs 35 | tensorflow::TensorShape data_shape({1, 4}); 36 | tensorflow::Tensor data(tensorflow::DT_FLOAT, data_shape); 37 | auto data_ = data.flat().data(); 38 | data_[0] = 2; 39 | data_[1] = 14; 40 | data_[2] = 33; 41 | data_[3] = 50; 42 | tensor_dict feed_dict = { 43 | {"X", data}, 44 | }; 45 | 46 | std::vector outputs; 47 | TF_CHECK_OK(sess->Run(feed_dict, {"logits"}, 48 | {}, &outputs)); 49 | 50 | std::cout << "input " << data.DebugString() << std::endl; 51 | std::cout << "logits " << outputs[0].DebugString() << std::endl; 52 | 53 | return 0; 54 | } 55 | -------------------------------------------------------------------------------- /inference/cc/www/env.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -o errexit 3 | 4 | # code from http://stackoverflow.com/a/1116890 5 | function readlink() 6 | { 7 | TARGET_FILE=$2 8 | cd `dirname $TARGET_FILE` 9 | TARGET_FILE=`basename $TARGET_FILE` 10 | 11 | # Iterate down a (possible) chain of symlinks 12 | while [ -L "$TARGET_FILE" ] 13 | do 14 | TARGET_FILE=`readlink $TARGET_FILE` 15 | cd `dirname $TARGET_FILE` 16 | TARGET_FILE=`basename $TARGET_FILE` 17 | done 18 | 19 | # Compute the canonicalized name by finding the physical path 20 | # for the directory we're in and appending the target file. 21 | PHYS_DIR=`pwd -P` 22 | RESULT=$PHYS_DIR/$TARGET_FILE 23 | echo $RESULT 24 | } 25 | export -f readlink 26 | 27 | export LC_ALL=ko_KR.UTF-8 28 | export LANG=ko_KR.UTF-8 29 | 30 | # directory 31 | ## current dir of this script 32 | CDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))) 33 | PDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/..) 34 | PPDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/../..) 35 | PPPDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/../../..) 36 | 37 | # server 38 | daemon_name='etagger_dm.py' 39 | port_devel=8898 40 | port_service=8898 41 | 42 | # command setting 43 | python='env python' 44 | 45 | # setting 46 | FROZEN_FILENAME=ner_frozen.pb.memmapped 47 | VOCAB_FILENAME=vocab.txt 48 | SO_FILENAME=libetagger.so 49 | WRD_LEN=15 50 | LOWERCASE=True 51 | IS_MEMMAPPED=True 52 | NUM_THREADS=0 53 | 54 | # functions 55 | 56 | function make_calmness() 57 | { 58 | exec 3>&2 # save 2 to 3 59 | exec 2> /dev/null 60 | } 61 | 62 | function revert_calmness() 63 | { 64 | exec 2>&3 # restore 2 from previous saved 3(originally 2) 65 | } 66 | 67 | function close_fd() 68 | { 69 | exec 3>&- 70 | } 71 | 72 | function jumpto 73 | { 74 | label=$1 75 | cmd=$(sed -n "/$label:/{:a;n;p;ba};" $0 | grep -v ':$') 76 | eval "$cmd" 77 | exit 78 | } 79 | -------------------------------------------------------------------------------- /inference/cc/src/inference.cc: -------------------------------------------------------------------------------- 1 | #include "Etagger.h" 2 | 3 | #include 4 | #include 5 | 6 | int main(int argc, char const *argv[]) 7 | { 8 | if( argc < 3 ) { 9 | cerr << argv[0] << " [is_memmapped(1 | 0:default)]" << endl; 10 | return 1; 11 | } 12 | 13 | const string frozen_graph_fn = argv[1]; 14 | const string vocab_fn = argv[2]; 15 | bool is_memmapped = false; 16 | if( argc == 4 && argv[3][0] == '1' ) is_memmapped = true; 17 | 18 | Etagger etagger = Etagger(frozen_graph_fn, 19 | vocab_fn, 20 | 15, // word_length = 15 21 | true, // lowercase = true 22 | is_memmapped, 23 | 0); // 0(all cores) | n(n cores) 24 | 25 | struct timeval t1,t2,t3,t4; 26 | int num_buckets = 0; 27 | double total_duration_time = 0.0; 28 | gettimeofday(&t1, NULL); 29 | 30 | vector bucket; 31 | for( string line; getline(cin, line); ) { 32 | if( line == "" ) { 33 | gettimeofday(&t3, NULL); 34 | 35 | int ret = etagger.Analyze(bucket); 36 | if( ret < 0 ) continue; 37 | for( int i = 0; i < ret; i++ ) { 38 | cout << bucket[i] << endl; 39 | } 40 | cout << endl; 41 | 42 | num_buckets += 1; 43 | bucket.clear(); 44 | 45 | gettimeofday(&t4, NULL); 46 | double duration_time = ((t4.tv_sec - t3.tv_sec)*1000000 + t4.tv_usec - t3.tv_usec)/(double)1000000; 47 | fprintf(stderr,"elapsed time per sentence = %lf sec\n", duration_time); 48 | if( num_buckets != 1) { // first one may takes longer time, so ignore in computing duration. 
49 | total_duration_time += duration_time; 50 | } 51 | } else { 52 | bucket.push_back(line); 53 | } 54 | } 55 | gettimeofday(&t2, NULL); 56 | double duration_time = ((t2.tv_sec - t1.tv_sec)*1000000 + t2.tv_usec - t1.tv_usec)/(double)1000000; 57 | fprintf(stderr,"elapsed time = %lf sec\n", duration_time); 58 | fprintf(stderr,"duration time on average = %lf sec\n", total_duration_time / (num_buckets-1)); 59 | 60 | return 0; 61 | } 62 | -------------------------------------------------------------------------------- /etc/test_spacy.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import argparse 4 | 5 | def get_entity(doc, begin, end): 6 | for ent in doc.ents: 7 | # check included 8 | if ent.start_char <= begin and end <= ent.end_char: 9 | if ent.start_char == begin: return 'B-' + ent.label_ 10 | else: return 'I-' + ent.label_ 11 | return 'O' 12 | 13 | def build_bucket(nlp, line): 14 | bucket = [] 15 | uline = line.decode('utf-8','ignore') # unicode 16 | doc = nlp(uline) 17 | print('postag:') 18 | seq = 0 19 | for token in doc: 20 | begin = token.idx 21 | end = begin + len(token.text) - 1 22 | temp = [] 23 | print(token.i, token.text, token.lemma_, token.pos_, token.tag_, token.dep_, 24 | token.shape_, token.is_alpha, token.is_stop, begin, end) 25 | temp.append(token.text) 26 | temp.append(token.tag_) 27 | temp.append('O') # no chunking info 28 | entity = get_entity(doc, begin, end) 29 | temp.append(entity) 30 | utemp = ' '.join(temp) 31 | bucket.append(utemp.encode('utf-8')) 32 | seq += 1 33 | print('') 34 | print('named entity:') 35 | for ent in doc.ents: 36 | print(ent.text, ent.start_char, ent.end_char, ent.label_) 37 | print('') 38 | print('noun chunk:') 39 | for chunk in doc.noun_chunks: 40 | print(chunk.text, chunk.root.text, chunk.root.dep_,chunk.root.head.text) 41 | print('') 42 | return bucket 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser() 46 | args = parser.parse_args() 47 | 48 | import spacy 49 | nlp = spacy.load('en') 50 | 51 | while 1: 52 | try: line = sys.stdin.readline() 53 | except KeyboardInterrupt: break 54 | if not line: break 55 | line = line.strip() 56 | if not line: continue 57 | # Create bucket 58 | try: bucket = build_bucket(nlp, line) 59 | except Exception as e: 60 | sys.stderr.write(str(e) +'\n') 61 | continue 62 | for i in range(len(bucket)): 63 | out = bucket[i] 64 | sys.stdout.write(out + '\n') 65 | sys.stdout.write('\n') 66 | 67 | -------------------------------------------------------------------------------- /test_berttok.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import argparse 4 | 5 | from bert import tokenization 6 | 7 | class Tok: 8 | def __init__(self, tokenizer): 9 | self.task = 'tok' 10 | self.tokenizer = tokenizer 11 | self.max_seq_length = 0 12 | 13 | def __proc_bucket(self, bucket): 14 | seq_length = 0 15 | for line in bucket: 16 | tokens = line.split() 17 | word = tokens[0] 18 | pos = tokens[1] 19 | chunk = tokens[2] 20 | tag = tokens[3] 21 | # for '-DOCSTART-' 22 | if word == '-DOCSTART-': 23 | ''' 24 | print(word, pos, chunk, tag) 25 | break 26 | ''' 27 | return None 28 | word_exts = self.tokenizer.tokenize(word) 29 | for m in range(len(word_exts)): 30 | if m == 0: 31 | print(word_exts[m], pos, chunk, tag) 32 | else: 33 | print(word_exts[m], pos, chunk, 'X') 34 | seq_length += 1 35 | print('') 36 | if seq_length > 
self.max_seq_length: self.max_seq_length = seq_length 37 | return None 38 | 39 | def proc(self): 40 | bucket = [] 41 | while 1: 42 | try: line = sys.stdin.readline() 43 | except KeyboardInterrupt: break 44 | if not line: break 45 | line = line.strip() 46 | if not line and len(bucket) >= 1: 47 | self.__proc_bucket(bucket) 48 | bucket = [] 49 | if line : bucket.append(line) 50 | if len(bucket) != 0: 51 | self.__proc_bucket(bucket) 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument('--vocab_file', type=str, help='path to bert vocab file', required=True) 56 | parser.add_argument('--do_lower_case', type=str, help='whether to lower case the input text', required=True) 57 | 58 | args = parser.parse_args() 59 | 60 | do_lower_case = True if args.do_lower_case.lower() == 'true' else False 61 | tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=do_lower_case) 62 | 63 | tok = Tok(tokenizer) 64 | tok.proc() 65 | sys.stderr.write('max_seq_length = %s\n' % (tok.max_seq_length)) 66 | -------------------------------------------------------------------------------- /inference/cc/wrapper/Etagger.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import ctypes as c 5 | 6 | libetagger = None 7 | 8 | # Result class interface to 'struct result_obj'. 9 | # this values should be same as those in 'result_obj.h'. 10 | MAX_WORD = 64 11 | MAX_POS = 64 12 | MAX_CHK = 64 13 | MAX_TAG = 64 14 | class Result( c.Structure ): 15 | _fields_ = [('word', c.c_char * MAX_WORD ), 16 | ('pos', c.c_char * MAX_POS ), 17 | ('chk', c.c_char * MAX_CHK ), 18 | ('tag', c.c_char * MAX_TAG ), 19 | ('predict', c.c_char * MAX_TAG )] 20 | 21 | def initialize(so_path, frozen_graph_fn, vocab_fn, word_length=15, lowercase=True, is_memmapped=False, num_threads=0): 22 | global libetagger 23 | if not libetagger: 24 | libetagger = c.cdll.LoadLibrary(so_path) 25 | c_frozen_graph_fn = c.c_char_p(frozen_graph_fn.encode('utf-8')) # unicode -> utf-8 26 | c_vocab_fn = c.c_char_p(vocab_fn.encode('utf-8')) # unicode -> utf-8 27 | c_word_length = c.c_int(word_length) 28 | c_lowercase = c.c_int(0) 29 | if lowercase == True: c_lowercase = c.c_int(1) 30 | c_is_memmapped = c.c_int(0) 31 | if is_memmapped == True: c_is_memmapped = c.c_int(1) 32 | c_num_threads = c.c_int(num_threads) 33 | etagger = libetagger.initialize(c_frozen_graph_fn, 34 | c_vocab_fn, 35 | c_word_length, 36 | c_lowercase, 37 | c_is_memmapped, 38 | c_num_threads) 39 | return etagger 40 | 41 | def analyze(etagger, bucket): 42 | global libetagger 43 | max_sentence_length = len(bucket) 44 | robj = (Result * max_sentence_length)() 45 | for i in range(max_sentence_length): 46 | tokens = bucket[i].split() 47 | robj[i].word = tokens[0].encode('utf-8') 48 | robj[i].pos = tokens[1].encode('utf-8') 49 | robj[i].chk = tokens[2].encode('utf-8') 50 | robj[i].tag = tokens[3].encode('utf-8') 51 | robj[i].predict = b'O' # initial value 'O'(out of tag) 52 | c_max_sentence_length = c.c_int(max_sentence_length) 53 | ret = libetagger.analyze(etagger, c.byref(robj), c_max_sentence_length) 54 | if ret < 0: return None 55 | out = [] 56 | for r in robj: 57 | out.append([r.word.decode('utf-8'), 58 | r.pos.decode('utf-8'), 59 | r.chk.decode('utf-8'), 60 | r.tag.decode('utf-8'), 61 | r.predict.decode('utf-8')]) 62 | return out 63 | 64 | def finalize(etagger): 65 | global libetagger 66 | libetagger.finalize(etagger) 67 | 
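A minimal usage sketch for the ctypes wrapper above. The signatures (initialize, analyze, finalize) come from the file itself; the paths are placeholders, and the artifact names (libetagger.so, ner_frozen.pb.memmapped, vocab.txt) are the ones listed in inference/cc/www/env.sh:

# Usage sketch for inference/cc/wrapper/Etagger.py; paths are placeholders.
# Each bucket line is CoNLL-style: word pos chunk tag.
import Etagger

etagger = Etagger.initialize('./lib/libetagger.so',
                             './exported/ner_frozen.pb.memmapped',
                             './exported/vocab.txt',
                             word_length=15,
                             lowercase=True,
                             is_memmapped=True,
                             num_threads=0)
bucket = ['Obama NNP B-NP B-PER', 'visited VBD B-VP O', 'Seoul NNP B-NP B-LOC']
out = Etagger.analyze(etagger, bucket)  # returns None on failure
if out:
    for word, pos, chk, tag, predict in out:
        print(word, pos, chk, tag, predict)
Etagger.finalize(etagger)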
-------------------------------------------------------------------------------- /inference/export.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import time 4 | import argparse 5 | import tensorflow as tf 6 | # for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369 7 | tf.contrib.rnn 8 | # for QRNN 9 | try: import qrnn 10 | except: sys.stderr.write('import qrnn, failed\n') 11 | 12 | def export(args): 13 | session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 14 | sess = tf.Session(config=session_conf) 15 | with sess.as_default(): 16 | # restore meta graph 17 | meta_file = args.restore + '.meta' 18 | loader = tf.train.import_meta_graph(meta_file, clear_devices=True) 19 | # mapping placeholders and tensors 20 | graph = tf.get_default_graph() 21 | p_is_train = graph.get_tensor_by_name('is_train:0') 22 | p_sentence_length = graph.get_tensor_by_name('sentence_length:0') 23 | p_input_data_pos_ids = graph.get_tensor_by_name('input_data_pos_ids:0') 24 | p_input_data_chk_ids = graph.get_tensor_by_name('input_data_chk_ids:0') 25 | p_input_data_word_ids = graph.get_tensor_by_name('input_data_word_ids:0') 26 | p_input_data_wordchr_ids = graph.get_tensor_by_name('input_data_wordchr_ids:0') 27 | t_logits_indices = graph.get_tensor_by_name('logits_indices:0') 28 | t_sentence_lengths = graph.get_tensor_by_name('sentence_lengths:0') 29 | print('is_train', p_is_train) 30 | print('sentence_length', p_sentence_length) 31 | print('input_data_pos_ids', p_input_data_pos_ids) 32 | print('input_data_chk_ids', p_input_data_chk_ids) 33 | print('input_data_word_ids', p_input_data_word_ids) 34 | print('input_data_wordchr_ids', p_input_data_wordchr_ids) 35 | print('logits_indices', t_logits_indices) 36 | print('sentence_lengths', t_sentence_lengths) 37 | # restore actual values 38 | loader.restore(sess, args.restore) 39 | print(tf.global_variables()) 40 | print(tf.trainable_variables()) 41 | print('model restored') 42 | 43 | # save 44 | saver = tf.train.Saver(tf.global_variables()) 45 | saver.save(sess, args.export) 46 | tf.train.write_graph(sess.graph, args.export_pb, "graph.pb", as_text=False) 47 | tf.train.write_graph(sess.graph, args.export_pb, "graph.pb_txt", as_text=True) 48 | print('model exported') 49 | sess.close() 50 | 51 | if __name__ == '__main__': 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument('--restore', type=str, help='path to saved model(ex, ../checkpoint/ner_model)', required=True) 54 | parser.add_argument('--export', type=str, help='path to exporting model(ex, exported/ner_model)', required=True) 55 | parser.add_argument('--export-pb', type=str, help='path to exporting graph proto(ex, exported)', required=True) 56 | 57 | args = parser.parse_args() 58 | export(args) 59 | -------------------------------------------------------------------------------- /inference/train_iris.py: -------------------------------------------------------------------------------- 1 | #!./bin/env python 2 | 3 | import sys 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | def one_hot(y_data) : 8 | a = np.array(y_data, dtype=int) 9 | b = np.zeros((a.size, a.max()+1)) 10 | b[np.arange(a.size),a] = 1 11 | return b 12 | 13 | def prepare_data(xy_data): 14 | x_data = xy_data[1:] 15 | x_data = np.transpose(x_data) # None x 4 16 | ''' 17 | [ [2 14 33 50], 18 | [24 56 31 67], 19 | [23 51 31 69], 20 | .... 
] 21 | ''' 22 | print(x_data) 23 | y_data = xy_data[0] # 1 x None 24 | ''' 25 | [ [0 1 1 0 1 2 .... ] ] 26 | ''' 27 | y_data = one_hot(y_data) # None x 3 28 | ''' 29 | [ [1 0 0], 30 | [0 1 0], 31 | [0 1 0], 32 | ... ] 33 | ''' 34 | print(y_data) 35 | return x_data, y_data 36 | 37 | X = tf.placeholder("float", [None, 4], name='X') 38 | Y = tf.placeholder("float", [None, 3], name='Y') 39 | W = tf.Variable(tf.truncated_normal([4,3], stddev=0.01), name='W') 40 | b = tf.Variable(tf.constant(0.1, shape=[3]), name='b') 41 | logits = tf.nn.softmax(tf.matmul(X, W) + b, name='logits') 42 | 43 | cost = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(logits), reduction_indices=1)) 44 | learning_rate = tf.Variable(0.001) 45 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 46 | train = optimizer.minimize(cost) 47 | correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(Y,1)) 48 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 49 | 50 | # training 51 | xy_data = np.loadtxt('./etc/iris.txt', unpack=True, dtype='float32') 52 | with tf.Session() as sess: 53 | sess.run(tf.global_variables_initializer()) 54 | x_data, y_data = prepare_data(xy_data) 55 | for i in range(2000): 56 | if i % 100 == 0 : 57 | print("step : ", i) 58 | print("cost : ", sess.run(cost, feed_dict={X: x_data, Y: y_data})) 59 | print(sess.run(W)) 60 | print("training accuracy :", sess.run(accuracy, feed_dict={X: x_data, Y: y_data})) 61 | sess.run(train, feed_dict={X:x_data, Y:y_data}) 62 | 63 | # save graph and weights 64 | saver = tf.train.Saver(tf.global_variables()) 65 | checkpoint_dir = './exported' 66 | checkpoint_file = 'iris_model' 67 | saver.save(sess, checkpoint_dir + '/' + checkpoint_file) 68 | tf.train.write_graph(sess.graph, '.', "./exported/graph.pb", as_text=False) 69 | tf.train.write_graph(sess.graph, '.', "./exported/graph.pb_txt", as_text=True) 70 | graph = tf.get_default_graph() 71 | for op in graph.get_operations(): 72 | print(op.name) 73 | t1 = graph.get_tensor_by_name('logits:0') 74 | t2 = graph.get_tensor_by_name('W:0') 75 | t3 = graph.get_tensor_by_name('b:0') 76 | 77 | t1, t2, t3, X = sess.run([t1, t2, t3, X], {X: x_data}) 78 | 79 | print(tf.global_variables()) 80 | print('X ', X) 81 | print('logits ', t1) 82 | print('W ', t2) 83 | print('b ', t3) 84 | 85 | 86 | -------------------------------------------------------------------------------- /inference/etc/iris.txt: -------------------------------------------------------------------------------- 1 | #Type PW PL SW SL 2 | 0 2 14 33 50 3 | 1 24 56 31 67 4 | 1 23 51 31 69 5 | 0 2 10 36 46 6 | 1 20 52 30 65 7 | 1 19 51 27 58 8 | 2 13 45 28 57 9 | 2 16 47 33 63 10 | 1 17 45 25 49 11 | 2 14 47 32 70 12 | 0 2 16 31 48 13 | 1 19 50 25 63 14 | 0 1 14 36 49 15 | 0 2 13 32 44 16 | 2 12 40 26 58 17 | 1 18 49 27 63 18 | 2 10 33 23 50 19 | 0 2 16 38 51 20 | 0 2 16 30 50 21 | 1 21 56 28 64 22 | 0 4 19 38 51 23 | 0 2 14 30 49 24 | 2 10 41 27 58 25 | 2 15 45 29 60 26 | 0 2 14 36 50 27 | 1 19 51 27 58 28 | 0 4 15 34 54 29 | 1 18 55 31 64 30 | 2 10 33 24 49 31 | 0 2 14 42 55 32 | 1 15 50 22 60 33 | 2 14 39 27 52 34 | 0 2 14 29 44 35 | 2 12 39 27 58 36 | 1 23 57 32 69 37 | 2 15 42 30 59 38 | 1 20 49 28 56 39 | 1 18 58 25 67 40 | 2 13 44 23 63 41 | 2 15 49 25 63 42 | 2 11 30 25 51 43 | 1 21 54 31 69 44 | 1 25 61 36 72 45 | 2 13 36 29 56 46 | 1 21 55 30 68 47 | 0 1 14 30 48 48 | 0 3 17 38 57 49 | 2 14 44 30 66 50 | 0 4 15 37 51 51 | 2 17 50 30 67 52 | 1 22 56 28 64 53 | 1 15 51 28 63 54 | 2 15 45 22 62 55 | 2 14 46 30 61 56 | 2 11 39 25 56 57 | 1 23 59 32 68
58 | 1 23 54 34 62 59 | 1 25 57 33 67 60 | 0 2 13 35 55 61 | 2 15 45 32 64 62 | 1 18 51 30 59 63 | 1 23 53 32 64 64 | 2 15 45 30 54 65 | 1 21 57 33 67 66 | 0 2 13 30 44 67 | 0 2 16 32 47 68 | 1 18 60 32 72 69 | 1 18 49 30 61 70 | 0 2 12 32 50 71 | 0 1 11 30 43 72 | 2 14 44 31 67 73 | 0 2 14 35 51 74 | 0 4 16 34 50 75 | 2 10 35 26 57 76 | 1 23 61 30 77 77 | 2 13 42 26 57 78 | 0 1 15 41 52 79 | 1 18 48 30 60 80 | 2 13 42 27 56 81 | 0 2 15 31 49 82 | 0 4 17 39 54 83 | 2 16 45 34 60 84 | 2 10 35 20 50 85 | 0 2 13 32 47 86 | 2 13 54 29 62 87 | 0 2 15 34 51 88 | 2 10 50 22 60 89 | 0 1 15 31 49 90 | 0 2 15 37 54 91 | 2 12 47 28 61 92 | 2 13 41 28 57 93 | 0 4 13 39 54 94 | 1 20 51 32 65 95 | 2 15 49 31 69 96 | 2 13 40 25 55 97 | 0 3 13 23 45 98 | 0 3 15 38 51 99 | 2 14 48 28 68 100 | 0 2 15 35 52 101 | 1 25 60 33 63 102 | 2 15 46 28 65 103 | 0 3 14 34 46 104 | 2 18 48 32 59 105 | 2 16 51 27 60 106 | 1 18 55 30 65 107 | 0 5 17 33 51 108 | 1 22 67 38 77 109 | 1 21 66 30 76 110 | 1 13 52 30 67 111 | 2 13 40 28 61 112 | 2 11 38 24 55 113 | 0 2 14 34 52 114 | 1 20 64 38 79 115 | 0 6 16 35 50 116 | 1 20 67 28 77 117 | 2 12 44 26 55 118 | 0 3 14 30 48 119 | 0 2 19 34 48 120 | 1 14 56 26 61 121 | 0 2 12 40 58 122 | 1 18 48 28 62 123 | 2 15 45 30 56 124 | 0 2 14 32 46 125 | 0 4 15 44 57 126 | 1 24 56 34 63 127 | 1 16 58 30 72 128 | 1 21 59 30 71 129 | 1 18 56 29 63 130 | 2 12 42 30 57 131 | 1 23 69 26 77 132 | 2 13 56 29 66 133 | 0 2 15 34 52 134 | 2 10 37 24 55 135 | 0 2 15 31 46 136 | 1 19 61 28 74 137 | 0 3 13 35 50 138 | 1 18 63 29 73 139 | 2 15 47 31 67 140 | 2 13 41 30 56 141 | 2 13 43 29 64 142 | 1 22 58 30 65 143 | 0 3 14 35 51 144 | 2 14 47 29 61 145 | 1 19 53 27 64 146 | 0 2 16 34 48 147 | 1 20 50 25 57 148 | 2 13 40 23 55 149 | 0 2 17 34 54 150 | 1 24 51 28 58 151 | 0 2 15 37 53 152 | -------------------------------------------------------------------------------- /test_bilm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | from bilm import Batcher, BidirectionalLanguageModel, weight_layers 4 | 5 | def set_cuda_visible_devices(is_train): 6 | import os 7 | os.environ['CUDA_VISIBLE_DEVICES']='2' 8 | if is_train: 9 | from tensorflow.python.client import device_lib 10 | print(device_lib.list_local_devices()) 11 | return True 12 | 13 | set_cuda_visible_devices(True) 14 | 15 | """ 16 | Load resources 17 | """ 18 | # Location of pretrained LM. Here we use the test fixtures. 19 | datadir = os.path.join('embeddings') 20 | vocab_file = os.path.join(datadir, 'elmo_vocab.txt') 21 | options_file = os.path.join(datadir, 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json') 22 | weight_file = os.path.join(datadir, 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5') 23 | # Create a Batcher to map text to character ids. 24 | batcher = Batcher(vocab_file, 50) 25 | # Build the biLM graph. 26 | bilm = BidirectionalLanguageModel(options_file, weight_file) 27 | 28 | """ 29 | Build graph 30 | """ 31 | # Input placeholders to the biLM. 32 | question_character_ids = tf.placeholder('int32', shape=(None, None, 50)) # word_length = 50 33 | # Get ops to compute the LM embeddings. 
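# (added note) A minimal sketch of what the two ops built below are expected to yield,
# assuming the bilm-tf API: bilm(character_ids) returns a dict of ops whose
# 'lm_embeddings' entry has shape (batch_size, n_lm_layers=3, time, 1024) for this
# 5.5B model, and weight_layers(name, ops, l2_coef) returns a dict
# {'weighted_op': ..., 'regularization_op': ...}, where 'weighted_op' is the
# layer-weighted ELMo vector of shape (batch_size, time, 1024). The shapes here are
# assumptions from the model options file, not verified in this script.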
34 | question_embeddings_op = bilm(question_character_ids) 35 | # Get an op to compute ELMo (weighted average of the internal biLM layers) 36 | elmo_question_input = weight_layers('input', question_embeddings_op, l2_coef=0.0) 37 | elmo_question_output = weight_layers('output', question_embeddings_op, l2_coef=0.0) 38 | print(elmo_question_input['weighted_op'].get_shape()) 39 | 40 | """ 41 | Prepare input 42 | """ 43 | tokenized_question = [ 44 | ['What', 'are', 'biLMs', 'useful', 'for', '?'] 45 | ] 46 | # Create batches of data. 47 | question_ids = batcher.batch_sentences(tokenized_question) # (batch_size, sentence_length, word_length) 48 | 49 | # padding 50 | question_ids = question_ids.tolist() 51 | print('length = ', len(question_ids[0])) 52 | print(question_ids) 53 | max_sentence_length = 10 54 | for i in range(max_sentence_length - len(question_ids[0]) + 2): 55 | question_ids[0].append([0]*50) 56 | print('length = ', len(question_ids[0])) 57 | print(question_ids) 58 | 59 | 60 | """ 61 | Compute ELMo embedding 62 | """ 63 | with tf.Session() as sess: 64 | # It is necessary to initialize variables once before running inference. 65 | sess.run(tf.global_variables_initializer()) 66 | 67 | # Compute ELMo representations (here for the input only, for simplicity). 68 | elmo_question_input_ = sess.run([elmo_question_input['weighted_op']], 69 | feed_dict={question_character_ids: question_ids}) # (batch_size, sentence_length, model_dim) 70 | print(elmo_question_input_) 71 | # check padding 72 | for i in range(len(elmo_question_input_[0][0])): 73 | print(i, len(elmo_question_input_[0][0][i]), elmo_question_input_[0][0][i]) 74 | 75 | ##### general usage ##### 76 | """ 77 | 1. we have 'tokenized_question' for real input texts. 78 | 2. get elmo_question_input_ 79 | 3. concat glove embedding + elmo_question_input_ 80 | 4. take contextual encoding (via LSTM or Transformer encoder) 81 | 5. concat contextual encoding + elmo_question_output_ 82 | """ 83 | -------------------------------------------------------------------------------- /inference/cc/www/stop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -o nounset 4 | set -o errexit 5 | 6 | # code from http://stackoverflow.com/a/1116890 7 | function readlink() 8 | { 9 | TARGET_FILE=$2 10 | cd `dirname $TARGET_FILE` 11 | TARGET_FILE=`basename $TARGET_FILE` 12 | 13 | # Iterate down a (possible) chain of symlinks 14 | while [ -L "$TARGET_FILE" ] 15 | do 16 | TARGET_FILE=`readlink $TARGET_FILE` 17 | cd `dirname $TARGET_FILE` 18 | TARGET_FILE=`basename $TARGET_FILE` 19 | done 20 | 21 | # Compute the canonicalized name by finding the physical path 22 | # for the directory we're in and appending the target file.
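# (added note) Illustrative walk-through with hypothetical paths: the function is
# called as `readlink -f <path>`, so the path arrives in $2. Given a symlink
# /usr/local/bin/run -> ../share/app/run.sh, the loop above resolves the link,
# leaving the shell cd'ed into /usr/local/share/app with TARGET_FILE=run.sh, so
# the lines below echo /usr/local/share/app/run.sh.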
23 | PHYS_DIR=`pwd -P` 24 | RESULT=$PHYS_DIR/$TARGET_FILE 25 | echo $RESULT 26 | } 27 | export -f readlink 28 | 29 | VERBOSE_MODE=0 30 | 31 | function error_handler() 32 | { 33 | local STATUS=${1:-1} 34 | [ ${VERBOSE_MODE} == 0 ] && exit ${STATUS} 35 | echo "Exits abnormally at line "`caller 0` 36 | exit ${STATUS} 37 | } 38 | trap "error_handler" ERR 39 | 40 | PROGNAME=`basename ${BASH_SOURCE}` 41 | 42 | function print_usage_and_exit() 43 | { 44 | set +x 45 | local STATUS=$1 46 | echo "Usage: ${PROGNAME} [-v] [-v] [-h] [--help]" 47 | echo "" 48 | echo " Options -" 49 | echo " -v enables verbose mode 1" 50 | echo " -v -v enables verbose mode 2" 51 | echo " -h, --help shows this help message" 52 | exit ${STATUS:-0} 53 | } 54 | 55 | function debug() 56 | { 57 | if [ "$VERBOSE_MODE" != 0 ]; then 58 | echo $@ 59 | fi 60 | } 61 | 62 | GETOPT=`getopt vh $*` 63 | if [ $? != 0 ] ; then print_usage_and_exit 1; fi 64 | 65 | eval set -- "${GETOPT}" 66 | 67 | while true 68 | do case "$1" in 69 | -v) let VERBOSE_MODE+=1; shift;; 70 | -h|--help) print_usage_and_exit 0;; 71 | --) shift; break;; 72 | *) echo "Internal error!"; exit 1;; 73 | esac 74 | done 75 | 76 | if (( VERBOSE_MODE > 1 )); then 77 | set -x 78 | fi 79 | 80 | 81 | # template area is ended. 82 | # ----------------------------------------------------------------------------- 83 | if [ ${#} != 0 ]; then print_usage_and_exit 1; fi 84 | 85 | # current dir of this script 86 | CDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))) 87 | PDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/..) 88 | PPDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/../..) 89 | [[ -f ${CDIR}/env.sh ]] && . ${CDIR}/env.sh || exit 90 | 91 | # ----------------------------------------------------------------------------- 92 | # functions 93 | 94 | 95 | 96 | # end functions 97 | # ----------------------------------------------------------------------------- 98 | 99 | 100 | 101 | # ----------------------------------------------------------------------------- 102 | # main 103 | 104 | make_calmness 105 | child_verbose="" 106 | if (( VERBOSE_MODE > 1 )); then 107 | revert_calmness 108 | child_verbose="-v -v" 109 | fi 110 | 111 | for pid in `pgrep -f ${daemon_name}` 112 | do 113 | sudo kill -9 ${pid} 114 | done 115 | 116 | 117 | close_fd 118 | 119 | # end main 120 | # ----------------------------------------------------------------------------- 121 | -------------------------------------------------------------------------------- /inference/python/www/stop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -o nounset 4 | set -o errexit 5 | 6 | # code from http://stackoverflow.com/a/1116890 7 | function readlink() 8 | { 9 | TARGET_FILE=$2 10 | cd `dirname $TARGET_FILE` 11 | TARGET_FILE=`basename $TARGET_FILE` 12 | 13 | # Iterate down a (possible) chain of symlinks 14 | while [ -L "$TARGET_FILE" ] 15 | do 16 | TARGET_FILE=`readlink $TARGET_FILE` 17 | cd `dirname $TARGET_FILE` 18 | TARGET_FILE=`basename $TARGET_FILE` 19 | done 20 | 21 | # Compute the canonicalized name by finding the physical path 22 | # for the directory we're in and appending the target file. 
23 | PHYS_DIR=`pwd -P` 24 | RESULT=$PHYS_DIR/$TARGET_FILE 25 | echo $RESULT 26 | } 27 | export -f readlink 28 | 29 | VERBOSE_MODE=0 30 | 31 | function error_handler() 32 | { 33 | local STATUS=${1:-1} 34 | [ ${VERBOSE_MODE} == 0 ] && exit ${STATUS} 35 | echo "Exits abnormally at line "`caller 0` 36 | exit ${STATUS} 37 | } 38 | trap "error_handler" ERR 39 | 40 | PROGNAME=`basename ${BASH_SOURCE}` 41 | 42 | function print_usage_and_exit() 43 | { 44 | set +x 45 | local STATUS=$1 46 | echo "Usage: ${PROGNAME} [-v] [-v] [-h] [--help]" 47 | echo "" 48 | echo " Options -" 49 | echo " -v enables verbose mode 1" 50 | echo " -v -v enables verbose mode 2" 51 | echo " -h, --help shows this help message" 52 | exit ${STATUS:-0} 53 | } 54 | 55 | function debug() 56 | { 57 | if [ "$VERBOSE_MODE" != 0 ]; then 58 | echo $@ 59 | fi 60 | } 61 | 62 | GETOPT=`getopt vh $*` 63 | if [ $? != 0 ] ; then print_usage_and_exit 1; fi 64 | 65 | eval set -- "${GETOPT}" 66 | 67 | while true 68 | do case "$1" in 69 | -v) let VERBOSE_MODE+=1; shift;; 70 | -h|--help) print_usage_and_exit 0;; 71 | --) shift; break;; 72 | *) echo "Internal error!"; exit 1;; 73 | esac 74 | done 75 | 76 | if (( VERBOSE_MODE > 1 )); then 77 | set -x 78 | fi 79 | 80 | 81 | # template area is ended. 82 | # ----------------------------------------------------------------------------- 83 | if [ ${#} != 0 ]; then print_usage_and_exit 1; fi 84 | 85 | # current dir of this script 86 | CDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))) 87 | PDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/..) 88 | PPDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/../..) 89 | [[ -f ${CDIR}/env.sh ]] && . ${CDIR}/env.sh || exit 90 | 91 | # ----------------------------------------------------------------------------- 92 | # functions 93 | 94 | 95 | 96 | # end functions 97 | # ----------------------------------------------------------------------------- 98 | 99 | 100 | 101 | # ----------------------------------------------------------------------------- 102 | # main 103 | 104 | make_calmness 105 | child_verbose="" 106 | if (( VERBOSE_MODE > 1 )); then 107 | revert_calmness 108 | child_verbose="-v -v" 109 | fi 110 | 111 | for pid in `pgrep -f ${daemon_name}` 112 | do 113 | sudo kill -9 ${pid} 114 | done 115 | 116 | 117 | close_fd 118 | 119 | # end main 120 | # ----------------------------------------------------------------------------- 121 | -------------------------------------------------------------------------------- /inference/cc/src/Vocab.cc: -------------------------------------------------------------------------------- 1 | #include "Vocab.h" 2 | 3 | /* 4 | * public methods 5 | */ 6 | 7 | Vocab::Vocab(string vocab_fn, bool lowercase=true) 8 | { 9 | bool loaded = load_vocab(vocab_fn); 10 | if( !loaded ) { 11 | throw runtime_error("load_vocab() failed!"); 12 | } 13 | this->lowercase = lowercase; 14 | } 15 | 16 | void Vocab::Split(string s, vector<string>& tokens) 17 | { 18 | istringstream iss(s); 19 | for( string ts; iss >> ts; ) 20 | tokens.push_back(ts); 21 | } 22 | 23 | int Vocab::GetWid(string word) 24 | { 25 | if( this->lowercase ) { 26 | transform(word.begin(), word.end(), word.begin(),::tolower); 27 | } 28 | if( this->wrd_vocab.find(word) != this->wrd_vocab.end() ) { 29 | return this->wrd_vocab[word]; 30 | } 31 | return this->unk_wid; 32 | } 33 | 34 | int Vocab::GetCid(string ch) 35 | { 36 | if( this->chr_vocab.find(ch) != this->chr_vocab.end() ) { 37 | return this->chr_vocab[ch]; 38 | } 39 | return this->unk_cid; 40 | } 41 | 42 | 
int Vocab::GetPid(string pos) 43 | { 44 | if( this->pos_vocab.find(pos) != this->pos_vocab.end() ) { 45 | return this->pos_vocab[pos]; 46 | } 47 | return this->unk_pid; 48 | } 49 | 50 | int Vocab::GetKid(string chk) 51 | { 52 | if( this->chk_vocab.find(chk) != this->chk_vocab.end() ) { 53 | return this->chk_vocab[chk]; 54 | } 55 | return this->unk_kid; 56 | } 57 | 58 | string Vocab::GetTag(int tid) 59 | { 60 | if( this->itag_vocab.find(tid) != this->itag_vocab.end() ) { 61 | return this->itag_vocab[tid]; 62 | } 63 | return this->oot_tag; 64 | } 65 | 66 | Vocab::~Vocab() 67 | { 68 | } 69 | 70 | /* 71 | * private methods 72 | */ 73 | 74 | bool Vocab::load_vocab(string vocab_fn) 75 | { 76 | fstream fs(vocab_fn, ios_base::in); 77 | if( !fs.is_open() ) { 78 | cerr << "Can't find " << vocab_fn << endl; 79 | return false; 80 | } 81 | string line = ""; 82 | int mode = 0; 83 | string key = ""; 84 | int id = 0; 85 | while( getline(fs, line) ) { 86 | if( line.find("# wrd_vocab") != string::npos ) mode = 1; // wrd_vocab 87 | if( line.find("# chr_vocab") != string::npos ) mode = 2; // chr_vocab 88 | if( line.find("# pos_vocab") != string::npos ) mode = 3; // pos_vocab 89 | if( line.find("# chk_vocab") != string::npos ) mode = 4; // chk_vocab 90 | if( line.find("# tag_vocab") != string::npos ) mode = 5; // tag_vocab 91 | vector<string> tokens; 92 | Split(line, tokens); 93 | if( tokens.size() != 2 ) continue; 94 | key = tokens[0]; 95 | id = atoi(tokens[1].c_str()); 96 | if( mode == 1 ) { 97 | this->wrd_vocab.insert(make_pair(key, id)); 98 | } 99 | if( mode == 2 ) { 100 | this->chr_vocab.insert(make_pair(key, id)); 101 | } 102 | if( mode == 3 ) { 103 | this->pos_vocab.insert(make_pair(key, id)); 104 | } 105 | if( mode == 4 ) { 106 | this->chk_vocab.insert(make_pair(key, id)); 107 | } 108 | if( mode == 5 ) { 109 | this->tag_vocab.insert(make_pair(key, id)); 110 | this->itag_vocab.insert(make_pair(id, key)); 111 | } 112 | } 113 | fs.close(); 114 | 115 | #ifdef DEBUG 116 | for( auto itr = this->wrd_vocab.cbegin(); itr != this->wrd_vocab.cend(); ++itr ) { 117 | cout << itr->first << " " << itr->second << endl; 118 | } 119 | for( auto itr = this->itag_vocab.cbegin(); itr != this->itag_vocab.cend(); ++itr ) { 120 | cout << itr->first << " " << itr->second << endl; 121 | } 122 | #endif 123 | 124 | return true; 125 | } 126 | 127 | -------------------------------------------------------------------------------- /inference/cc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required( VERSION 3.11 ) 2 | 3 | # set version and project 4 | set(etagger_VERSION_MAJOR 1) 5 | set(etagger_VERSION_MINOR 0) 6 | set(etagger_VERSION_PATCH 0) 7 | set(etagger_VERSION ${etagger_VERSION_MAJOR}.${etagger_VERSION_MINOR}.${etagger_VERSION_PATCH}) 8 | if (POLICY CMP0048) 9 | cmake_policy(SET CMP0048 NEW) 10 | endif (POLICY CMP0048) 11 | project(etagger VERSION ${etagger_VERSION}) 12 | 13 | # check dependencies 14 | list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/../cmake/modules) 15 | find_package(TensorFlow 1.11 EXACT REQUIRED) 16 | set(CMAKE_CXX_STANDARD 11) 17 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 18 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=${TensorFlow_ABI}") 19 | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=${TensorFlow_ABI}" ) 20 | TensorFlow_REQUIRE_C_LIBRARY() 21 | TensorFlow_REQUIRE_SOURCE() 22 | 23 | # for linking libtensorflow_cc.so, libtensorflow_framework.so in the ${TENSORFLOW_BUILD_DIR} 24 | 
find_library(TensorFlow_CC_LIBRARY 25 | NAMES libtensorflow_cc.so 26 | PATHS $ENV{TENSORFLOW_BUILD_DIR} 27 | DOC "TensorFlow C library." ) 28 | add_library(TensorFlow_CC_DEP INTERFACE) 29 | TARGET_LINK_LIBRARIES(TensorFlow_CC_DEP INTERFACE -Wl,--allow-multiple-definition -Wl,--whole-archive ${TensorFlow_CC_LIBRARY} -Wl,--no-whole-archive) 30 | 31 | find_library(TensorFlow_FW_LIBRARY 32 | NAMES libtensorflow_framework.so 33 | PATHS $ENV{TENSORFLOW_BUILD_DIR} 34 | DOC "TensorFlow framework library." ) 35 | add_library(TensorFlow_FW_DEP INTERFACE) 36 | TARGET_LINK_LIBRARIES(TensorFlow_FW_DEP INTERFACE -Wl,--allow-multiple-definition -Wl,--whole-archive ${TensorFlow_FW_LIBRARY} -Wl,--no-whole-archive) 37 | 38 | # build libraries 39 | include_directories(include) 40 | set(etagger_src "src/Config.cc" "src/Vocab.cc" "src/Input.cc" "src/TFUtil.cc" "src/Etagger.cc") 41 | add_library(etagger SHARED ${etagger_src}) 42 | target_include_directories(etagger PRIVATE TensorFlow_DEP) 43 | target_link_libraries(etagger PRIVATE TensorFlow_DEP) 44 | target_link_libraries(etagger PRIVATE TensorFlow_FW_DEP) 45 | target_link_libraries(etagger PRIVATE TensorFlow_CC_DEP) 46 | set_target_properties(etagger PROPERTIES VERSION ${etagger_VERSION} 47 | SOVERSION ${etagger_VERSION_MAJOR}) 48 | add_library(etagger_static STATIC ${etagger_src}) 49 | target_include_directories(etagger_static PRIVATE TensorFlow_DEP) 50 | target_link_libraries(etagger_static PRIVATE TensorFlow_DEP) 51 | target_link_libraries(etagger_static PRIVATE TensorFlow_CC_DEP) 52 | target_link_libraries(etagger_static PRIVATE TensorFlow_FW_DEP) 53 | 54 | # build executable binaries 55 | add_executable (inference src/inference.cc) 56 | target_include_directories(inference PRIVATE TensorFlow_DEP) 57 | target_link_libraries(inference PRIVATE TensorFlow_DEP) 58 | target_link_libraries(inference PRIVATE TensorFlow_CC_DEP) 59 | target_link_libraries(inference PRIVATE TensorFlow_FW_DEP) 60 | target_link_libraries(inference PRIVATE etagger) 61 | 62 | 63 | # build misc executable binaries 64 | add_executable (inference_example src/inference_example.cc) 65 | target_include_directories(inference_example PRIVATE TensorFlow_DEP) 66 | target_link_libraries(inference_example PRIVATE TensorFlow_DEP) 67 | target_link_libraries(inference_example PRIVATE TensorFlow_CC_DEP) 68 | target_link_libraries(inference_example PRIVATE TensorFlow_FW_DEP) 69 | 70 | add_executable (inference_iris src/inference_iris.cc) 71 | target_include_directories(inference_iris PRIVATE TensorFlow_DEP) 72 | target_link_libraries(inference_iris PRIVATE TensorFlow_DEP) 73 | target_link_libraries(inference_iris PRIVATE TensorFlow_CC_DEP) 74 | target_link_libraries(inference_iris PRIVATE TensorFlow_FW_DEP) 75 | 76 | -------------------------------------------------------------------------------- /inference/cc/wrapper/inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import time 5 | import argparse 6 | 7 | # etagger 8 | import Etagger 9 | 10 | ############################################################################### 11 | # nlp : spacy 12 | import spacy 13 | nlp = spacy.load('en') 14 | 15 | def get_entity(doc, begin, end): 16 | for ent in doc.ents: 17 | # check included 18 | if ent.start_char <= begin and end <= ent.end_char: 19 | if ent.start_char == begin: return 'B-' + ent.label_ 20 | else: return 'I-' + ent.label_ 21 | return 'O' 22 | 23 | def build_bucket(nlp, line): 24 | 
bucket = [] 25 | doc = nlp(line) 26 | for token in doc: 27 | begin = token.idx 28 | end = begin + len(token.text) - 1 29 | temp = [] 30 | ''' 31 | print(token.i, token.text, token.lemma_, token.pos_, token.tag_, token.dep_, 32 | token.shape_, token.is_alpha, token.is_stop, begin, end) 33 | ''' 34 | temp.append(token.text) 35 | temp.append(token.tag_) 36 | temp.append('O') # no chunking info 37 | entity = get_entity(doc, begin, end) 38 | temp.append(entity) # entity by spacy 39 | temp = ' '.join(temp) 40 | bucket.append(temp) 41 | return bucket 42 | ############################################################################### 43 | 44 | def inference(so_path, frozen_graph_fn, vocab_fn, word_length, lowercase=True, is_memmapped=False): 45 | 46 | etagger = Etagger.initialize(so_path, 47 | frozen_graph_fn, 48 | vocab_fn, 49 | word_length=word_length, 50 | lowercase=lowercase, 51 | is_memmapped=is_memmapped, 52 | num_threads=0) 53 | 54 | num_buckets = 0 55 | total_duration_time = 0.0 56 | while 1: 57 | try: line = sys.stdin.readline() 58 | except KeyboardInterrupt: break 59 | if not line: break 60 | line = line.strip() 61 | bucket = build_bucket(nlp, line) 62 | start_time = time.time() 63 | out = Etagger.analyze(etagger, bucket) 64 | if not out: continue 65 | for o in out: 66 | print(' '.join(o)) 67 | print('') 68 | duration_time = time.time() - start_time 69 | out = 'duration_time : ' + str(duration_time) + ' sec' 70 | sys.stderr.write(out + '\n') 71 | num_buckets += 1 72 | if num_buckets != 1: # the first bucket may take longer, so ignore it when computing the average duration. 73 | total_duration_time += duration_time 74 | 75 | out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n' 76 | out += 'average processing time / bucket : ' + str(total_duration_time / (num_buckets-1)) + ' sec' 77 | sys.stderr.write(out + '\n') 78 | 79 | Etagger.finalize(etagger) 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--frozen_graph_fn', type=str, help='path to frozen model (e.g. ./exported/ner_frozen.pb)', required=True) 84 | parser.add_argument('--vocab_fn', type=str, help='path to vocab (e.g. vocab.txt)', required=True) 85 | parser.add_argument('--word_length', type=int, default=15, help='max word length') 86 | parser.add_argument('--is_memmapped', type=str, default='False', help='is memory mapped graph, True | False') 87 | 88 | args = parser.parse_args() 89 | is_memmapped = False 90 | if args.is_memmapped == 'True': is_memmapped = True 91 | 92 | # etagger library path 93 | so_path = os.path.dirname(os.path.abspath(__file__)) + '/../build' + '/' + 'libetagger.so' 94 | 95 | inference(so_path, args.frozen_graph_fn, args.vocab_fn, args.word_length, lowercase=True, is_memmapped=is_memmapped) 96 | 
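# --- illustrative usage (added note; file locations are hypothetical) ----------
# The flags below are the ones defined in the argparse block above; input is
# read line by line from stdin, tagged with spaCy, then analyzed by the C++
# etagger library:
#
#   echo "Obama was born in Hawaii ." | \
#       python inference.py --frozen_graph_fn=./exported/ner_frozen.pb \
#                           --vocab_fn=./exported/vocab.txt \
#                           --word_length=15 --is_memmapped=False
#
# Per-token results are printed to stdout and timing to stderr.
# --------------------------------------------------------------------------------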
-------------------------------------------------------------------------------- /etc/chunk_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import argparse 4 | 5 | class ChunkEval: 6 | """Chunk-based evaluation 7 | """ 8 | 9 | def __init__(self): 10 | self.tag_sents = [] 11 | self.pred_sents = [] 12 | 13 | def __eval_bucket(self, bucket): 14 | tag_sent = [] 15 | pred_sent = [] 16 | line_num = 0 17 | for line in bucket: 18 | tokens = line.split() 19 | size = len(tokens) 20 | if line_num == 0 and size == 3: # skip 'USING SKIP CONNECTIONS' 21 | line_num += 1 22 | continue 23 | assert(size == 5) 24 | w = tokens[0] 25 | pos = tokens[1] 26 | chunk = tokens[2] 27 | tag = tokens[3] 28 | pred = tokens[4] 29 | tag_sent.append(tag) 30 | pred_sent.append(pred) 31 | line_num += 1 32 | self.tag_sents.append(tag_sent) 33 | self.pred_sents.append(pred_sent) 34 | 35 | def eval(self): 36 | """Compute micro chunk fscore, along with precision and recall, from standard input. 37 | """ 38 | bucket = [] 39 | while 1: 40 | try: line = sys.stdin.readline() 41 | except KeyboardInterrupt: break 42 | if not line: break 43 | line = line.strip() 44 | if not line and len(bucket) >= 1: 45 | self.__eval_bucket(bucket) 46 | bucket = [] 47 | if line : bucket.append(line) 48 | if len(bucket) != 0: 49 | self.__eval_bucket(bucket) 50 | fscore = self.compute_f1(self.pred_sents, self.tag_sents) 51 | print('precision, recall, fscore = ', fscore) 52 | 53 | @staticmethod 54 | def compute_precision(guessed_sentences, correct_sentences): 55 | """Compute micro precision given tag-predictions (guessed sentences) 56 | and tag-corrects (correct sentences). 57 | """ 58 | assert(len(guessed_sentences) == len(correct_sentences)) 59 | correctCount = 0 60 | count = 0 61 | for sentenceIdx in range(len(guessed_sentences)): 62 | guessed = guessed_sentences[sentenceIdx] 63 | correct = correct_sentences[sentenceIdx] 64 | assert(len(guessed) == len(correct)) 65 | idx = 0 66 | while idx < len(guessed): 67 | if guessed[idx][0] == 'B': # A new chunk starts 68 | count += 1 69 | ''' 70 | print('guessed, correct : ', guessed[idx], correct[idx]) 71 | ''' 72 | if guessed[idx] == correct[idx]: 73 | idx += 1 74 | correctlyFound = True 75 | while idx < len(guessed) and guessed[idx][0] == 'I': # Scan until it no longer starts with I. 76 | if guessed[idx] != correct[idx]: 77 | correctlyFound = False 78 | idx += 1 79 | if idx < len(guessed): 80 | if correct[idx][0] == 'I': # The chunk in correct was longer. 81 | correctlyFound = False 82 | if correctlyFound: 83 | correctCount += 1 84 | else: 85 | idx += 1 86 | else: 87 | idx += 1 88 | precision = 0 89 | if count > 0: 90 | precision = float(correctCount) / count 91 | return precision 92 | 93 | @staticmethod 94 | def compute_f1(tag_preds, tag_corrects): 95 | """Compute micro Fscore given tag-predictions and tag-corrects 96 | along with Precision, Recall. 97 | """ 98 | prec = ChunkEval.compute_precision(tag_preds, tag_corrects) 99 | rec = ChunkEval.compute_precision(tag_corrects, tag_preds) 100 | f1 = 0 101 | if (rec+prec) > 0: 102 | f1 = 2.0 * prec * rec / (prec + rec) 103 | return prec, rec, f1 104 | 105 | 106 | if __name__ == '__main__': 107 | parser = argparse.ArgumentParser() 108 | 109 | args = parser.parse_args() 110 | 111 | ev = ChunkEval() 112 | ev.eval() 113 | -------------------------------------------------------------------------------- /inference/cc/src/inference_example.cc: -------------------------------------------------------------------------------- 1 | // source is from https://github.com/PatWie/tensorflow-cmake/blob/master/inference/cc/inference_cc.cc 2 | // 2018, Patrick Wieschollek 3 | #include <tensorflow/core/protobuf/meta_graph.pb.h> 4 | #include <tensorflow/core/public/session.h> 5 | #include <tensorflow/core/public/session_options.h> 6 | #include <iostream> 7 | #include <string> 8 | 9 | typedef std::vector<std::pair<std::string, tensorflow::Tensor>> tensor_dict; 10 | 11 | /** 12 | * @brief load a previously stored model 13 | * @details [long description] 14 | * 15 | * in Python run: 16 | * 17 | * saver = tf.train.Saver(tf.global_variables()) 18 | * saver.save(sess, './exported/my_model') 19 | * tf.train.write_graph(sess.graph, '.', './exported/graph.pb', as_text=False) 20 | * 21 | * this relies on a graph which has an operation called `init` responsible to 22 | * initialize all variables, eg.
23 | * 24 | * sess.run(tf.global_variables_initializer()) # somewhere in the python 25 | * file 26 | * 27 | * @param sess active tensorflow session 28 | * @param graph_fn path to graph file (eg. "./exported/graph.pb") 29 | * @param checkpoint_fn path to checkpoint file (eg. "./exported/my_model", 30 | * optional) 31 | * @return status of reloading 32 | */ 33 | tensorflow::Status LoadModel(tensorflow::Session *sess, std::string graph_fn, 34 | std::string checkpoint_fn = "") { 35 | tensorflow::Status status; 36 | 37 | // Read in the protobuf graph we exported 38 | tensorflow::MetaGraphDef graph_def; 39 | status = ReadBinaryProto(tensorflow::Env::Default(), graph_fn, &graph_def); 40 | if (status != tensorflow::Status::OK()) return status; 41 | 42 | // create the graph in the current session 43 | status = sess->Create(graph_def.graph_def()); 44 | if (status != tensorflow::Status::OK()) return status; 45 | 46 | // restore model from checkpoint, iff checkpoint is given 47 | if (checkpoint_fn != "") { 48 | const std::string restore_op_name = graph_def.saver_def().restore_op_name(); 49 | const std::string filename_tensor_name = 50 | graph_def.saver_def().filename_tensor_name(); 51 | 52 | tensorflow::Tensor filename_tensor(tensorflow::DT_STRING, 53 | tensorflow::TensorShape()); 54 | filename_tensor.scalar<std::string>()() = checkpoint_fn; 55 | 56 | tensor_dict feed_dict = {{filename_tensor_name, filename_tensor}}; 57 | status = sess->Run(feed_dict, {}, {restore_op_name}, nullptr); 58 | if (status != tensorflow::Status::OK()) return status; 59 | } else { 60 | // virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs, 61 | // const std::vector<string>& output_tensor_names, 62 | // const std::vector<string>& target_node_names, 63 | // std::vector<Tensor>* outputs) = 0; 64 | status = sess->Run({}, {}, {"init"}, nullptr); 65 | if (status != tensorflow::Status::OK()) return status; 66 | } 67 | 68 | return tensorflow::Status::OK(); 69 | } 70 | 71 | int main(int argc, char const *argv[]) { 72 | const std::string graph_fn = "./exported/my_model.meta"; 73 | const std::string checkpoint_fn = "./exported/my_model"; 74 | 75 | // prepare session 76 | tensorflow::Session *sess; 77 | tensorflow::SessionOptions options; 78 | TF_CHECK_OK(tensorflow::NewSession(options, &sess)); 79 | TF_CHECK_OK(LoadModel(sess, graph_fn, checkpoint_fn)); 80 | 81 | // prepare inputs 82 | tensorflow::TensorShape data_shape({1, 2}); 83 | tensorflow::Tensor data(tensorflow::DT_FLOAT, data_shape); 84 | 85 | // same as in python file 86 | auto data_ = data.flat<float>().data(); 87 | for (int i = 0; i < 2; ++i) data_[i] = 1; 88 | 89 | tensor_dict feed_dict = { 90 | {"input", data}, 91 | }; 92 | 93 | std::vector<tensorflow::Tensor> outputs; 94 | TF_CHECK_OK(sess->Run(feed_dict, {"output", "dense/kernel:0", "dense/bias:0"}, 95 | {}, &outputs)); 96 | 97 | std::cout << "input " << data.DebugString() << std::endl; 98 | std::cout << "output " << outputs[0].DebugString() << std::endl; 99 | std::cout << "dense/kernel:0 " << outputs[1].DebugString() << std::endl; 100 | std::cout << "dense/bias:0 " << outputs[2].DebugString() << std::endl; 101 | 102 | return 0; 103 | } 104 | -------------------------------------------------------------------------------- /progbar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import logging 4 | import numpy as np 5 | 6 | ''' 7 | source from https://github.com/guillaumegenthial/sequence_tagging/blob/master/model/general_utils.py 8 | ''' 9 | 10 | class Progbar(object): 11 | """Progbar class copied from keras
(https://github.com/fchollet/keras/) 12 | 13 | Displays a progress bar. 14 | Small edit : added strict arg to update 15 | # Arguments 16 | target: Total number of steps expected. 17 | interval: Minimum visual progress update interval (in seconds). 18 | """ 19 | 20 | def __init__(self, target, width=30, verbose=1): 21 | self.width = width 22 | self.target = target 23 | self.sum_values = {} 24 | self.unique_values = [] 25 | self.start = time.time() 26 | self.total_width = 0 27 | self.seen_so_far = 0 28 | self.verbose = verbose 29 | 30 | def update(self, current, values=[], exact=[], strict=[]): 31 | """ 32 | Updates the progress bar. 33 | # Arguments 34 | current: Index of current step. 35 | values: List of tuples (name, value_for_last_step). 36 | The progress bar will display averages for these values. 37 | exact: List of tuples (name, value_for_last_step). 38 | The progress bar will display these values directly. 39 | """ 40 | 41 | for k, v in values: 42 | if k not in self.sum_values: 43 | self.sum_values[k] = [v * (current - self.seen_so_far), 44 | current - self.seen_so_far] 45 | self.unique_values.append(k) 46 | else: 47 | self.sum_values[k][0] += v * (current - self.seen_so_far) 48 | self.sum_values[k][1] += (current - self.seen_so_far) 49 | for k, v in exact: 50 | if k not in self.sum_values: 51 | self.unique_values.append(k) 52 | self.sum_values[k] = [v, 1] 53 | 54 | for k, v in strict: 55 | if k not in self.sum_values: 56 | self.unique_values.append(k) 57 | self.sum_values[k] = v 58 | 59 | self.seen_so_far = current 60 | 61 | now = time.time() 62 | if self.verbose == 1: 63 | prev_total_width = self.total_width 64 | sys.stdout.write("\b" * prev_total_width) 65 | sys.stdout.write("\r") 66 | 67 | numdigits = int(np.floor(np.log10(self.target))) + 1 68 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 69 | bar = barstr % (current, self.target) 70 | prog = float(current)/self.target 71 | prog_width = int(self.width*prog) 72 | if prog_width > 0: 73 | bar += ('='*(prog_width-1)) 74 | if current < self.target: 75 | bar += '>' 76 | else: 77 | bar += '=' 78 | bar += ('.'*(self.width-prog_width)) 79 | bar += ']' 80 | sys.stdout.write(bar) 81 | self.total_width = len(bar) 82 | 83 | if current: 84 | time_per_unit = (now - self.start) / current 85 | else: 86 | time_per_unit = 0 87 | eta = time_per_unit*(self.target - current) 88 | info = '' 89 | if current < self.target: 90 | info += ' - ETA: %ds' % eta 91 | else: 92 | info += ' - %ds' % (now - self.start) 93 | for k in self.unique_values: 94 | if type(self.sum_values[k]) is list: 95 | info += ' - %s: %.6f' % (k, 96 | self.sum_values[k][0] / max(1, self.sum_values[k][1])) 97 | else: 98 | info += ' - %s: %s' % (k, self.sum_values[k]) 99 | 100 | self.total_width += len(info) 101 | if prev_total_width > self.total_width: 102 | info += ((prev_total_width-self.total_width) * " ") 103 | 104 | sys.stdout.write(info) 105 | sys.stdout.flush() 106 | 107 | if current >= self.target: 108 | sys.stdout.write("\n") 109 | 110 | if self.verbose == 2: 111 | if current >= self.target: 112 | info = '%ds' % (now - self.start) 113 | for k in self.unique_values: 114 | info += ' - %s: %.6f' % (k, 115 | self.sum_values[k][0] / max(1, self.sum_values[k][1])) 116 | sys.stdout.write(info + "\n") 117 | 118 | def add(self, n, values=[]): 119 | self.update(self.seen_so_far+n, values) 120 | 121 | 122 | -------------------------------------------------------------------------------- /inference/python/www/start.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -o nounset 4 | set -o errexit 5 | 6 | # code from http://stackoverflow.com/a/1116890 7 | function readlink() 8 | { 9 | TARGET_FILE=$2 10 | cd `dirname $TARGET_FILE` 11 | TARGET_FILE=`basename $TARGET_FILE` 12 | 13 | # Iterate down a (possible) chain of symlinks 14 | while [ -L "$TARGET_FILE" ] 15 | do 16 | TARGET_FILE=`readlink $TARGET_FILE` 17 | cd `dirname $TARGET_FILE` 18 | TARGET_FILE=`basename $TARGET_FILE` 19 | done 20 | 21 | # Compute the canonicalized name by finding the physical path 22 | # for the directory we're in and appending the target file. 23 | PHYS_DIR=`pwd -P` 24 | RESULT=$PHYS_DIR/$TARGET_FILE 25 | echo $RESULT 26 | } 27 | export -f readlink 28 | 29 | VERBOSE_MODE=0 30 | 31 | function error_handler() 32 | { 33 | local STATUS=${1:-1} 34 | [ ${VERBOSE_MODE} == 0 ] && exit ${STATUS} 35 | echo "Exits abnormally at line "`caller 0` 36 | exit ${STATUS} 37 | } 38 | trap "error_handler" ERR 39 | 40 | PROGNAME=`basename ${BASH_SOURCE}` 41 | 42 | function print_usage_and_exit() 43 | { 44 | set +x 45 | local STATUS=$1 46 | echo "Usage: ${PROGNAME} [-v] [-v] [-h] [--help] [mode] [process]" 47 | echo "" 48 | echo " mode 0 : devel, 1 : service" 49 | echo " process 0 : max to #core, [1...n] : number of process" 50 | echo " Options -" 51 | echo " -v enables verbose mode 1" 52 | echo " -v -v enables verbose mode 2" 53 | echo " -h, --help shows this help message" 54 | exit ${STATUS:-0} 55 | } 56 | 57 | function debug() 58 | { 59 | if [ "$VERBOSE_MODE" != 0 ]; then 60 | echo $@ 61 | fi 62 | } 63 | 64 | GETOPT=`getopt vh $*` 65 | if [ $? != 0 ] ; then print_usage_and_exit 1; fi 66 | 67 | eval set -- "${GETOPT}" 68 | 69 | while true 70 | do case "$1" in 71 | -v) let VERBOSE_MODE+=1; shift;; 72 | -h|--help) print_usage_and_exit 0;; 73 | --) shift; break;; 74 | *) echo "Internal error!"; exit 1;; 75 | esac 76 | done 77 | 78 | if (( VERBOSE_MODE > 1 )); then 79 | set -x 80 | fi 81 | 82 | 83 | # template area is ended. 84 | # ----------------------------------------------------------------------------- 85 | if [ ${#} != 2 ]; then print_usage_and_exit 1; fi 86 | 87 | # current dir of this script 88 | CDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))) 89 | PDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/..) 90 | PPDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/../..) 91 | PPPDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/../../..) 92 | [[ -f ${CDIR}/env.sh ]] && . 
${CDIR}/env.sh || exit 93 | 94 | # ----------------------------------------------------------------------------- 95 | # functions 96 | 97 | function check_running 98 | { 99 | progname=$1 100 | count_pgrep=`pgrep -f ${progname} | wc -l` 101 | count_pgrep=$(( ${count_pgrep} - 1 )) 102 | if (( count_pgrep > 0 )); then 103 | revert_calmness 104 | echo "count_pgrep = ${count_pgrep}" 105 | echo "${progname} is already running" 106 | exit 0 107 | fi 108 | } 109 | 110 | 111 | # end functions 112 | # ----------------------------------------------------------------------------- 113 | 114 | 115 | 116 | # ----------------------------------------------------------------------------- 117 | # main 118 | 119 | make_calmness 120 | child_verbose="" 121 | if (( VERBOSE_MODE > 1 )); then 122 | revert_calmness 123 | child_verbose="-v -v" 124 | fi 125 | 126 | MODE=$1 127 | PROCESS=$2 128 | 129 | check_running ${daemon_name} 130 | 131 | mkdir -p ${CDIR}/data 132 | mkdir -p ${CDIR}/lib 133 | 134 | function copy_resources { 135 | # data 136 | cp -rf ${PPPDIR}/embeddings/${EMB_FILENAME} ${CDIR}/data 137 | cp -rf ${PPDIR}/exported/${FROZEN_FILENAME} ${CDIR}/data 138 | cp -rf ${PPDIR}/data/${CONFIG_FILENAME} ${CDIR}/data 139 | # lib 140 | cp -rf ${PPPDIR}/embvec.py ${CDIR}/lib 141 | cp -rf ${PPPDIR}/config.py ${CDIR}/lib 142 | cp -rf ${PPPDIR}/input.py ${CDIR}/lib 143 | cp -rf ${PPPDIR}/feed.py ${CDIR}/lib 144 | # for bert 145 | case "${EMB_CLASS}" in 146 | *bert*) 147 | cp -rf ${PPPDIR}/bert ${CDIR}/lib 148 | ;; 149 | esac 150 | } 151 | copy_resources 152 | 153 | EMB_PATH=${CDIR}/data/${EMB_FILENAME} 154 | CONFIG_PATH=${CDIR}/data/${CONFIG_FILENAME} 155 | FROZEN_PATH=${CDIR}/data/${FROZEN_FILENAME} 156 | 157 | cd ${CDIR} 158 | 159 | if (( MODE == 0 )); then 160 | debug=True 161 | port=${port_devel} 162 | else 163 | debug=False 164 | port=${port_service} 165 | fi 166 | 167 | nohup ${python} ${CDIR}/${daemon_name} \ 168 | --debug=${debug} \ 169 | --port=${port} \ 170 | --emb_path=${EMB_PATH} \ 171 | --emb_class=${EMB_CLASS} \ 172 | --config_path=${CONFIG_PATH} \ 173 | --wrd_dim=${WRD_DIM} \ 174 | --frozen_path=${FROZEN_PATH} \ 175 | --log_file_prefix=${CDIR}/log/access.log \ 176 | > /dev/null 2> /dev/null & 177 | 178 | cd ${CDIR} 179 | 180 | close_fd 181 | 182 | # end main 183 | # ----------------------------------------------------------------------------- 184 | -------------------------------------------------------------------------------- /inference/cc/www/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -o nounset 4 | set -o errexit 5 | 6 | # code from http://stackoverflow.com/a/1116890 7 | function readlink() 8 | { 9 | TARGET_FILE=$2 10 | cd `dirname $TARGET_FILE` 11 | TARGET_FILE=`basename $TARGET_FILE` 12 | 13 | # Iterate down a (possible) chain of symlinks 14 | while [ -L "$TARGET_FILE" ] 15 | do 16 | TARGET_FILE=`readlink $TARGET_FILE` 17 | cd `dirname $TARGET_FILE` 18 | TARGET_FILE=`basename $TARGET_FILE` 19 | done 20 | 21 | # Compute the canonicalized name by finding the physical path 22 | # for the directory we're in and appending the target file. 
23 | PHYS_DIR=`pwd -P` 24 | RESULT=$PHYS_DIR/$TARGET_FILE 25 | echo $RESULT 26 | } 27 | export -f readlink 28 | 29 | VERBOSE_MODE=0 30 | 31 | function error_handler() 32 | { 33 | local STATUS=${1:-1} 34 | [ ${VERBOSE_MODE} == 0 ] && exit ${STATUS} 35 | echo "Exits abnormally at line "`caller 0` 36 | exit ${STATUS} 37 | } 38 | trap "error_handler" ERR 39 | 40 | PROGNAME=`basename ${BASH_SOURCE}` 41 | 42 | function print_usage_and_exit() 43 | { 44 | set +x 45 | local STATUS=$1 46 | echo "Usage: ${PROGNAME} [-v] [-v] [-h] [--help] [mode] [process]" 47 | echo "" 48 | echo " mode 0 : devel, 1 : service" 49 | echo " process 0 : max to #core, [1...n] : number of process" 50 | echo " Options -" 51 | echo " -v enables verbose mode 1" 52 | echo " -v -v enables verbose mode 2" 53 | echo " -h, --help shows this help message" 54 | exit ${STATUS:-0} 55 | } 56 | 57 | function debug() 58 | { 59 | if [ "$VERBOSE_MODE" != 0 ]; then 60 | echo $@ 61 | fi 62 | } 63 | 64 | GETOPT=`getopt vh $*` 65 | if [ $? != 0 ] ; then print_usage_and_exit 1; fi 66 | 67 | eval set -- "${GETOPT}" 68 | 69 | while true 70 | do case "$1" in 71 | -v) let VERBOSE_MODE+=1; shift;; 72 | -h|--help) print_usage_and_exit 0;; 73 | --) shift; break;; 74 | *) echo "Internal error!"; exit 1;; 75 | esac 76 | done 77 | 78 | if (( VERBOSE_MODE > 1 )); then 79 | set -x 80 | fi 81 | 82 | 83 | # template area is ended. 84 | # ----------------------------------------------------------------------------- 85 | if [ ${#} != 2 ]; then print_usage_and_exit 1; fi 86 | 87 | # current dir of this script 88 | CDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))) 89 | PDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/..) 90 | PPDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/../..) 91 | PPPDIR=$(readlink -f $(dirname $(readlink -f ${BASH_SOURCE[0]}))/../../..) 92 | [[ -f ${CDIR}/env.sh ]] && . 
${CDIR}/env.sh || exit 93 | 94 | # ----------------------------------------------------------------------------- 95 | # functions 96 | 97 | function check_running 98 | { 99 | progname=$1 100 | count_pgrep=`pgrep -f ${progname} | wc -l` 101 | count_pgrep=$(( ${count_pgrep} - 1 )) 102 | if (( count_pgrep > 0 )); then 103 | revert_calmness 104 | echo "count_pgrep = ${count_pgrep}" 105 | echo "${progname} is already running" 106 | exit 0 107 | fi 108 | } 109 | 110 | 111 | # end functions 112 | # ----------------------------------------------------------------------------- 113 | 114 | 115 | 116 | # ----------------------------------------------------------------------------- 117 | # main 118 | 119 | make_calmness 120 | child_verbose="" 121 | if (( VERBOSE_MODE > 1 )); then 122 | revert_calmness 123 | child_verbose="-v -v" 124 | fi 125 | 126 | MODE=$1 127 | PROCESS=$2 128 | 129 | check_running ${daemon_name} 130 | 131 | mkdir -p ${CDIR}/data 132 | mkdir -p ${CDIR}/lib 133 | 134 | function copy_resources { 135 | # data 136 | cp -rf ${PPDIR}/exported/${FROZEN_FILENAME} ${CDIR}/data 137 | cp -rf ${PPPDIR}/embeddings/${VOCAB_FILENAME} ${CDIR}/data 138 | # lib 139 | cp -rf ${PPDIR}/cc/build/${SO_FILENAME}* ${CDIR}/lib 140 | } 141 | copy_resources 142 | FROZEN_PATH=${CDIR}/data/${FROZEN_FILENAME} 143 | VOCAB_PATH=${CDIR}/data/${VOCAB_FILENAME} 144 | SO_PATH=${CDIR}/lib/${SO_FILENAME} 145 | 146 | cd ${CDIR} 147 | 148 | if (( MODE == 0 )); then 149 | nohup ${python} ${CDIR}/${daemon_name} \ 150 | --debug=True \ 151 | --port=${port_devel} \ 152 | --so_path=${SO_PATH} \ 153 | --frozen_graph_fn=${FROZEN_PATH} \ 154 | --vocab_fn=${VOCAB_PATH} \ 155 | --word_length=${WRD_LEN} \ 156 | --lowercase=${LOWERCASE} \ 157 | --is_memmapped=${IS_MEMMAPPED} \ 158 | --num_threads=${NUM_THREADS} \ 159 | --log_file_prefix=${CDIR}/log/access.log \ 160 | > /dev/null 2> /dev/null & 161 | else 162 | nohup ${python} ${CDIR}/${daemon_name} \ 163 | --debug=False \ 164 | --port=${port_service} \ 165 | --process=${PROCESS} \ 166 | --so_path=${SO_PATH} \ 167 | --frozen_graph_fn=${FROZEN_PATH} \ 168 | --vocab_fn=${VOCAB_PATH} \ 169 | --word_length=${WRD_LEN} \ 170 | --lowercase=${LOWERCASE} \ 171 | --is_memmapped=${IS_MEMMAPPED} \ 172 | --num_threads=${NUM_THREADS} \ 173 | --log_file_prefix=${CDIR}/log/access.log \ 174 | > /dev/null 2> /dev/null & 175 | fi 176 | cd ${CDIR} 177 | 178 | close_fd 179 | 180 | # end main 181 | # ----------------------------------------------------------------------------- 182 | -------------------------------------------------------------------------------- /inference/cc/src/TFUtil.cc: -------------------------------------------------------------------------------- 1 | #include "TFUtil.h" 2 | 3 | /* 4 | * public methods 5 | */ 6 | 7 | TFUtil::TFUtil() 8 | { 9 | } 10 | 11 | tensorflow::MemmappedEnv* TFUtil::CreateMemmappedEnv(string graph_fn) 12 | { 13 | tensorflow::MemmappedEnv* memmapped_env = new tensorflow::MemmappedEnv(tensorflow::Env::Default()); 14 | TF_CHECK_OK(memmapped_env->InitializeFromFile(graph_fn)); 15 | return memmapped_env; 16 | } 17 | 18 | tensorflow::Session* TFUtil::CreateSession(tensorflow::MemmappedEnv* memmapped_env, int num_threads = 0) 19 | { 20 | tensorflow::Session* sess; 21 | tensorflow::SessionOptions options; 22 | 23 | if( memmapped_env ) { 24 | options.config.mutable_graph_options()->mutable_optimizer_options()->set_opt_level(::tensorflow::OptimizerOptions::L0); 25 | options.env = memmapped_env; 26 | } 27 | 28 | tensorflow::ConfigProto& conf = options.config; 29 | if( 
num_threads > 0 ) { 30 | conf.set_inter_op_parallelism_threads(num_threads); 31 | conf.set_intra_op_parallelism_threads(num_threads); 32 | } 33 | TF_CHECK_OK(tensorflow::NewSession(options, &sess)); 34 | return sess; 35 | } 36 | 37 | void TFUtil::DestroySession(tensorflow::Session* sess) 38 | { 39 | if( sess ) sess->Close(); 40 | } 41 | 42 | tensorflow::Status TFUtil::LoadFrozenModel(tensorflow::Session* sess, string graph_fn) 43 | { 44 | tensorflow::Status status; 45 | 46 | load_lstm_lib(); 47 | /* load_qrnn_lib(); */ 48 | 49 | // Read in the frozen protobuf graph 50 | tensorflow::GraphDef graph_def; 51 | status = ReadBinaryProto(tensorflow::Env::Default(), graph_fn, &graph_def); 52 | if( status != tensorflow::Status::OK() ) return status; 53 | 54 | // Create the graph in the current session 55 | status = sess->Create(graph_def); 56 | if( status != tensorflow::Status::OK() ) return status; 57 | 58 | return tensorflow::Status::OK(); 59 | } 60 | 61 | tensorflow::Status TFUtil::LoadFrozenMemmappedModel(tensorflow::MemmappedEnv* memmapped_env, tensorflow::Session* sess) 62 | { 63 | tensorflow::Status status; 64 | 65 | load_lstm_lib(); 66 | /* load_qrnn_lib(); */ 67 | 68 | // Read the memory-mapped graph 69 | tensorflow::GraphDef graph_def; 70 | status = ReadBinaryProto(memmapped_env, tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, &graph_def); 71 | if( status != tensorflow::Status::OK() ) return status; 72 | // Create the graph in the current session 73 | status = sess->Create(graph_def); 74 | if( status != tensorflow::Status::OK() ) return status; 75 | 76 | return tensorflow::Status::OK(); 77 | }
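// --- illustrative usage sketch (added note; graph paths and the "logits" output
// name are hypothetical, not taken from this file) ------------------------------
// Wiring the helpers above together, a caller would do roughly:
//
//   TFUtil util;
//   tensorflow::Session* sess = util.CreateSession(nullptr, 0);
//   TF_CHECK_OK(util.LoadFrozenModel(sess, "ner_frozen.pb"));
//   std::vector<tensorflow::Tensor> outputs;
//   TF_CHECK_OK(sess->Run(feed_dict, {"logits"}, {}, &outputs));
//   util.DestroySession(sess);
//
// For the memory-mapped variant, create the env first and pass it in:
//
//   tensorflow::MemmappedEnv* env = util.CreateMemmappedEnv("ner_frozen.pb.mm");
//   tensorflow::Session* mm_sess = util.CreateSession(env, 0);
//   TF_CHECK_OK(util.LoadFrozenMemmappedModel(env, mm_sess));
// --------------------------------------------------------------------------------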
78 | 79 | tensorflow::Status TFUtil::LoadModel(tensorflow::Session *sess, 80 | string graph_fn, 81 | string checkpoint_fn = "") 82 | { 83 | 84 | /* 85 | * source is from https://github.com/PatWie/tensorflow-cmake/blob/master/inference/cc/inference_cc.cc 86 | */ 87 | tensorflow::Status status; 88 | 89 | // Read in the protobuf graph we exported 90 | tensorflow::MetaGraphDef graph_def; 91 | status = ReadBinaryProto(tensorflow::Env::Default(), graph_fn, &graph_def); 92 | if (status != tensorflow::Status::OK()) return status; 93 | 94 | // create the graph in the current session 95 | status = sess->Create(graph_def.graph_def()); 96 | if (status != tensorflow::Status::OK()) return status; 97 | 98 | // restore model from checkpoint, iff checkpoint is given 99 | if (checkpoint_fn != "") { 100 | const string restore_op_name = graph_def.saver_def().restore_op_name(); 101 | const string filename_tensor_name = 102 | graph_def.saver_def().filename_tensor_name(); 103 | 104 | tensorflow::Tensor filename_tensor(tensorflow::DT_STRING, 105 | tensorflow::TensorShape()); 106 | filename_tensor.scalar<string>()() = checkpoint_fn; 107 | 108 | tensor_dict feed_dict = {{filename_tensor_name, filename_tensor}}; 109 | status = sess->Run(feed_dict, {}, {restore_op_name}, nullptr); 110 | if (status != tensorflow::Status::OK()) return status; 111 | } else { 112 | // virtual Status Run(const vector<pair<string, Tensor> >& inputs, 113 | // const vector<string>& output_tensor_names, 114 | // const vector<string>& target_node_names, 115 | // vector<Tensor>* outputs) = 0; 116 | status = sess->Run({}, {}, {"init"}, nullptr); 117 | if (status != tensorflow::Status::OK()) return status; 118 | } 119 | 120 | return tensorflow::Status::OK(); 121 | } 122 | 123 | TFUtil::~TFUtil() 124 | { 125 | } 126 | 127 | /* 128 | * private methods 129 | */ 130 | 131 | void TFUtil::load_lstm_lib() 132 | { 133 | /* 134 | * Load _lstm_ops.so library (from LD_LIBRARY_PATH) for LSTMBlockFusedCell() 135 | */ 136 | TF_Status* status = TF_NewStatus(); 137 | TF_LoadLibrary("_lstm_ops.so", status); 138 | if( TF_GetCode(status) != TF_OK ) { 139 | throw runtime_error("fail to load _lstm_ops.so"); 140 | } 141 | TF_DeleteStatus(status); 142 | } 143 | 144 | void TFUtil::load_qrnn_lib() 145 | { 146 | /* 147 | * Load qrnn_lib.cpython-36m-x86_64-linux-gnu.so library (from LD_LIBRARY_PATH) for QRNN 148 | */ 149 | TF_Status* status = TF_NewStatus(); 150 | TF_LoadLibrary("qrnn_lib.cpython-36m-x86_64-linux-gnu.so", status); 151 | if( TF_GetCode(status) != TF_OK ) { 152 | throw runtime_error("fail to load qrnn_lib.cpython-36m-x86_64-linux-gnu.so"); 153 | } 154 | TF_DeleteStatus(status); 155 | } 156 | -------------------------------------------------------------------------------- /inference/freeze.py: -------------------------------------------------------------------------------- 1 | import sys, os, argparse 2 | import tensorflow as tf 3 | # for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369 4 | tf.contrib.rnn 5 | # for QRNN 6 | try: import qrnn 7 | except: sys.stderr.write('import qrnn, failed\n') 8 | 9 | ''' 10 | source is from https://gist.github.com/morgangiraud/249505f540a5e53a48b0c1a869d370bf#file-medium-tffreeze-1-py 11 | ''' 12 | 13 | # The original freeze_graph function 14 | # from tensorflow.python.tools.freeze_graph import freeze_graph 15 | 16 | dir = os.path.dirname(os.path.realpath(__file__)) 17 | 18 | def modify_op(graph_def): 19 | """ 20 | reference : https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091 21 | """ 22 | for node in graph_def.node: 23 | if node.op == 'Assign': 24 | node.op = 'Identity' 25 | if 'use_locking' in node.attr: del node.attr['use_locking'] 26 | if 'validate_shape' in node.attr: del node.attr['validate_shape'] 27 | if len(node.input) == 2: 28 | # input0: ref: Should be from a Variable node. May be uninitialized. 29 | # input1: value: The value to be assigned to the variable. 30 | node.input[0] = node.input[1] 31 | del node.input[1] 32 | return graph_def
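# --- illustrative note (added; 'W' is a hypothetical variable name) --------------
# A sketch of the rewrite modify_op() performs on a graph node:
#   before: node { name: 'W/Assign'  op: 'Assign'    input: 'W'  input: 'W/initial_value' }
#   after : node { name: 'W/Assign'  op: 'Identity'  input: 'W/initial_value' }
# i.e. the Assign becomes a pass-through of the assigned value, the workaround
# suggested in the tensorflow-onnx issue referenced above.
# ----------------------------------------------------------------------------------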
33 | 34 | def freeze_graph(model_dir, output_node_names, frozen_model_name, optimize_graph_def=0): 35 | """Extract the sub graph defined by the output nodes and convert 36 | all its variables into constants 37 | Args: 38 | model_dir: the root folder containing the checkpoint state file 39 | output_node_names: a string, containing all the output node's names, 40 | comma separated 41 | frozen_model_name: a string, the name of the frozen model 42 | optimize_graph_def: int, 1 for optimizing graph_def via tensorRT 43 | """ 44 | if not tf.gfile.Exists(model_dir): 45 | raise AssertionError( 46 | "Export directory doesn't exist. Please specify an export " 47 | "directory: %s" % model_dir) 48 | 49 | if not output_node_names: 50 | print("You need to supply the name of a node to --output_node_names.") 51 | return -1 52 | 53 | # We retrieve our checkpoint full path 54 | checkpoint = tf.train.get_checkpoint_state(model_dir) 55 | input_checkpoint = checkpoint.model_checkpoint_path 56 | 57 | # We work out the full filename of our frozen graph 58 | absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1]) 59 | output_graph_path = absolute_model_dir + "/" + frozen_model_name 60 | 61 | # We clear devices to let TensorFlow decide on which device it will load operations 62 | clear_devices = True 63 | 64 | # We start a session using a temporary fresh Graph 65 | with tf.Session(graph=tf.Graph()) as sess: 66 | # We import the meta graph in the current default Graph 67 | saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices) 68 | 69 | # We restore the weights 70 | saver.restore(sess, input_checkpoint) 71 | 72 | # We use a built-in TF helper to export variables to constants 73 | output_graph_def = tf.graph_util.convert_variables_to_constants( 74 | sess, # The session is used to retrieve the weights 75 | tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 76 | output_node_names.split(',') # The output node names are used to select the useful nodes 77 | ) 78 | 79 | # Modify for 'float_ref' 80 | output_graph_def = modify_op(output_graph_def) 81 | 82 | # Optimize graph_def via tensorRT 83 | if optimize_graph_def: 84 | from tensorflow.contrib import tensorrt as trt 85 | # get optimized graph_def 86 | trt_graph_def = trt.create_inference_graph( 87 | input_graph_def=output_graph_def, 88 | outputs=output_node_names.split(','), 89 | max_batch_size=128, 90 | max_workspace_size_bytes=1 << 30, 91 | precision_mode='FP16', # TRT Engine precision "FP32","FP16" or "INT8" 92 | minimum_segment_size=3 # minimum number of nodes in an engine 93 | ) 94 | output_graph_def = trt_graph_def 95 | 96 | # Finally we serialize and dump the output graph to the filesystem 97 | with tf.gfile.GFile(output_graph_path, "wb") as f: 98 | f.write(output_graph_def.SerializeToString()) 99 | print("%d ops in the final graph." 
% len(output_graph_def.node)) 100 | 101 | return output_graph_def 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--model_dir", type=str, help="Model folder to export", required=True) 106 | parser.add_argument("--frozen_model_name", type=str, help="The name of the frozen model", required=True) 107 | parser.add_argument("--output_node_names", type=str, help="The name of the output nodes, comma separated.", required=True) 108 | parser.add_argument("--optimize_graph_def", type=int, help="1 for optimizing graph_def via tensorRT, default 0", default=0, required=False) 109 | args = parser.parse_args() 110 | 111 | freeze_graph(args.model_dir, args.output_node_names, args.frozen_model_name, args.optimize_graph_def) 112 | 113 | 114 | -------------------------------------------------------------------------------- /etc/token_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import argparse 4 | import numpy as np 5 | 6 | class TokenEval: 7 | """Token-based evaluation 8 | """ 9 | 10 | def __init__(self): 11 | self.cls = {} 12 | self.tp = {} 13 | self.fp = {} 14 | self.fn = {} 15 | self.precision = {} 16 | self.recall = {} 17 | self.fscore = {} 18 | 19 | def __eval_bucket(self, bucket): 20 | line_num = 0 21 | for line in bucket: 22 | tokens = line.split() 23 | size = len(tokens) 24 | if line_num == 0 and size == 3: # skip 'USING SKIP CONNECTIONS' 25 | line_num += 1 26 | continue 27 | assert(size == 5) 28 | w = tokens[0] 29 | pos = tokens[1] 30 | chunk = tokens[2] 31 | tag = tokens[3] 32 | pred = tokens[4] 33 | if pred not in self.tp: self.tp[pred] = 0 34 | if tag not in self.tp: self.tp[tag] = 0 35 | if pred not in self.fp: self.fp[pred] = 0 36 | if tag not in self.fp: self.fp[tag] = 0 37 | if pred not in self.fn: self.fn[pred] = 0 38 | if tag not in self.fn: self.fn[tag] = 0 39 | if tag == pred: 40 | self.tp[pred] += 1 41 | else: 42 | self.fp[pred] += 1 43 | self.fn[tag] += 1 44 | self.cls[pred] = None 45 | self.cls[tag] = None 46 | line_num += 1 47 | 48 | def eval(self): 49 | """Compute micro precision, recall, fscore given file. 
50 | """ 51 | bucket = [] 52 | while 1: 53 | try: line = sys.stdin.readline() 54 | except KeyboardInterrupt: break 55 | if not line: break 56 | line = line.strip() 57 | if not line and len(bucket) >= 1: 58 | self.__eval_bucket(bucket) 59 | bucket = [] 60 | if line : bucket.append(line) 61 | if len(bucket) != 0: 62 | self.__eval_bucket(bucket) 63 | 64 | # in_class vs out_class 65 | in_class = 'I' 66 | out_classes = ['O', 'X'] 67 | self.tp[in_class] = 0 68 | self.fp[in_class] = 0 69 | self.fn[in_class] = 0 70 | for c, _ in self.cls.items(): 71 | if c not in out_classes: 72 | self.tp[in_class] += self.tp[c] 73 | self.fp[in_class] += self.fp[c] 74 | self.fn[in_class] += self.fn[c] 75 | self.cls[in_class] = None 76 | 77 | print(self.tp) 78 | print(self.fp) 79 | print(self.fn) 80 | 81 | for c, _ in self.cls.items(): 82 | if self.tp[c] + self.fp[c] != 0: 83 | self.precision[c] = self.tp[c]*1.0 / (self.tp[c] + self.fp[c]) 84 | else: 85 | self.precision[c] = 0 86 | if self.tp[c] + self.fn[c] != 0: 87 | self.recall[c] = self.tp[c]*1.0 / (self.tp[c] + self.fn[c]) 88 | else: 89 | self.recall[c] = 0 90 | if self.precision[c] + self.recall[c] != 0: 91 | self.fscore[c] = 2.0*self.precision[c]*self.recall[c] / (self.precision[c] + self.recall[c]) 92 | else: 93 | self.fscore[c] = 0 94 | 95 | print('') 96 | print('precision:') 97 | for c, _ in self.precision.items(): 98 | print(c + ',' + str(self.precision[c])) 99 | print('') 100 | print('recall:') 101 | for c, _ in self.recall.items(): 102 | print(c + ',' + str(self.recall[c])) 103 | print('') 104 | print('fscore:') 105 | for c, _ in self.fscore.items(): 106 | print(c + ',' + str(self.fscore[c])) 107 | print('') 108 | print('total fscore:') 109 | print(self.fscore[in_class]) 110 | 111 | @staticmethod 112 | def compute_f1(class_size, prediction, target, length): 113 | """Compute micro Fscore given prediction and target 114 | along with list of Precision, Recall, Fscore for each class. 
115 | """ 116 | tp = np.array([0] * (class_size + 1)) 117 | fp = np.array([0] * (class_size + 1)) 118 | fn = np.array([0] * (class_size + 1)) 119 | for i in range(len(target)): 120 | for j in range(length[i]): 121 | if target[i, j] == prediction[i, j]: 122 | tp[prediction[i, j]] += 1 123 | else: 124 | fp[prediction[i, j]] += 1 125 | fn[target[i, j]] += 1 126 | out_of_classes = [0, 1] # SEE embvec.oot_tid, embvec.xot_tid 127 | for i in range(class_size): 128 | if i not in out_of_classes: 129 | tp[class_size] += tp[i] 130 | fp[class_size] += fp[i] 131 | fn[class_size] += fn[i] 132 | precision = [] 133 | recall = [] 134 | fscore = [] 135 | for i in range(class_size + 1): 136 | if tp[i] + fp[i] == 0: precision.append(0.0) 137 | else: precision.append(tp[i] * 1.0 / (tp[i] + fp[i])) 138 | if tp[i] + fn[i] == 0: recall.append(0.0) 139 | else: recall.append(tp[i] * 1.0 / (tp[i] + fn[i])) 140 | if precision[i] + recall[i] == 0: fscore.append(0.0) 141 | else: fscore.append(2.0 * precision[i] * recall[i] / (precision[i] + recall[i])) 142 | return fscore[class_size], precision, recall, fscore 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser() 146 | 147 | args = parser.parse_args() 148 | 149 | ev = TokenEval() 150 | ev.eval() 151 | -------------------------------------------------------------------------------- /inference/cc/www/handlers/index.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import logging 5 | import tornado.web 6 | from handlers.base import BaseHandler 7 | import json 8 | import time 9 | 10 | ############################################################################################### 11 | # nlp : spacy 12 | def get_entity(doc, begin, end): 13 | for ent in doc.ents: 14 | # check included 15 | if ent.start_char <= begin and end <= ent.end_char: 16 | if ent.start_char == begin: return 'B-' + ent.label_ 17 | else: return 'I-' + ent.label_ 18 | return 'O' 19 | 20 | def build_bucket(nlp, line): 21 | bucket = [] 22 | doc = nlp(line) 23 | for token in doc: 24 | begin = token.idx 25 | end = begin + len(token.text) - 1 26 | temp = [] 27 | temp.append(token.text) 28 | temp.append(token.tag_) 29 | temp.append('O') # no chunking info 30 | entity = get_entity(doc, begin, end) 31 | temp.append(entity) # entity by spacy 32 | temp = ' '.join(temp) 33 | bucket.append(temp) 34 | return bucket 35 | 36 | def analyze(Etagger, etagger, nlp, query): 37 | """Analyze query by nlp, etagger 38 | """ 39 | bucket = build_bucket(nlp, query) 40 | result = Etagger.analyze(etagger, bucket) 41 | ## build output 42 | out = [] 43 | for i in range(len(result)): 44 | tl = result[i] 45 | entry = {} 46 | entry['id'] = i 47 | entry['word'] = tl[0] 48 | entry['pos'] = tl[1] 49 | entry['chk'] = tl[2] 50 | entry['tag'] = tl[3] 51 | entry['predict'] = tl[4] 52 | out.append(entry) 53 | return out 54 | ############################################################################################### 55 | 56 | class IndexHandler(BaseHandler): 57 | def get(self): 58 | q = self.get_argument('q', '') 59 | self.render('index.html', q=q) 60 | 61 | class HCheckHandler(BaseHandler): 62 | def get(self): 63 | self.set_header('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0') 64 | templates_dir = 'templates' 65 | hdn_filename = '_hcheck.hdn' 66 | err_filename = 'error.html' 67 | try : fid = open(templates_dir + "/" + hdn_filename, 'r') 68 | except : 69 | self.set_status(404) 70 | 
self.render(err_filename) 71 | else : 72 | fid.close() 73 | self.render(hdn_filename) 74 | 75 | class EtaggerHandler(BaseHandler): 76 | def get(self) : 77 | start_time = time.time() 78 | 79 | callback = self.get_argument('callback', '') 80 | mode = self.get_argument('mode', 'product') 81 | try : 82 | query = self.get_argument('q', '') 83 | except : 84 | query = "Invalid unicode in q" 85 | 86 | debug = {} 87 | debug['callback'] = callback 88 | debug['mode'] = mode 89 | pid = os.getpid() 90 | debug['pid'] = pid 91 | 92 | rst = {} 93 | rst['msg'] = '' 94 | rst['query'] = query 95 | if mode == 'debug' : rst['debug'] = debug 96 | 97 | Etagger = self.Etagger 98 | etagger = self.etagger[pid] 99 | nlp = self.nlp 100 | try : 101 | out = analyze(Etagger, etagger, nlp, query) 102 | rst['status'] = 200 103 | rst['output'] = out 104 | except : 105 | rst['status'] = 500 106 | rst['output'] = [] 107 | rst['msg'] = 'analyze() fail' 108 | 109 | if mode == 'debug' : 110 | duration_time = time.time() - start_time 111 | debug['exectime'] = duration_time 112 | 113 | try : 114 | ret = json.dumps(rst) 115 | except : 116 | msg = "json.dumps() fail for query %s" % (query) 117 | self.log.debug(msg + "\n") 118 | err = {} 119 | err['status'] = 500 120 | err['msg'] = msg 121 | ret = json.dumps(err) 122 | 123 | if mode == 'debug' : 124 | self.set_header('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0') 125 | 126 | if callback.strip() : 127 | self.set_header('Content-Type', 'application/javascript; charset=utf-8') 128 | ret = 'if (typeof %s === "function") %s(%s);' % (callback, callback, ret) 129 | else : 130 | self.set_header('Content-Type', 'application/json; charset=utf-8') 131 | 132 | self.write(ret) 133 | self.finish() 134 | 135 | 136 | def post(self): 137 | self.get() 138 | 139 | class EtaggerTestHandler(BaseHandler): 140 | def post(self): 141 | if self.request.body : 142 | try: 143 | json_data = json.loads(self.request.body) 144 | self.request.arguments.update(json_data) 145 | content = '' 146 | if 'content' in json_data : content = json_data['content'] 147 | is_json_request = True 148 | except: 149 | content = self.get_argument('content', "", True) 150 | is_json_request = False 151 | else: 152 | self.write(dict(success=False, info='no request body for post')) 153 | self.finish(); return  # without this return, 'content' and 'is_json_request' below would be unbound 154 | 155 | pid = os.getpid() 156 | Etagger = self.Etagger 157 | etagger = self.etagger[pid] 158 | nlp = self.nlp 159 | 160 | if is_json_request : lines = content 161 | else: lines = content.split('\n') 162 | try: 163 | out_list=[] 164 | for line in lines : 165 | line = line.strip() 166 | if not line : continue 167 | out = analyze(Etagger, etagger, nlp, line) 168 | out_list.append(out) 169 | self.write(dict(success=True, record=out_list, info=None)) 170 | except Exception as e: 171 | msg = str(e) 172 | self.write(dict(success=False, info=msg)) 173 | 174 | self.finish() 175 | -------------------------------------------------------------------------------- /inference/python/inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | path = os.path.dirname(os.path.abspath(__file__)) + '/../..'
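# two directory levels up is the repo root, so the embvec/config/input/feed modules imported below resolve once it is on sys.path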
5 | sys.path.append(path) 6 | import time 7 | import argparse 8 | import tensorflow as tf 9 | import numpy as np 10 | # for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369 11 | tf.contrib.rnn 12 | # for QRNN 13 | try: import qrnn 14 | except: sys.stderr.write('import qrnn, failed\n') 15 | 16 | from embvec import EmbVec 17 | from config import Config 18 | from input import Input 19 | import feed 20 | 21 | def load_frozen_graph(frozen_graph_filename, prefix='prefix'): 22 | with tf.gfile.GFile(frozen_graph_filename, "rb") as f: 23 | graph_def = tf.GraphDef() 24 | graph_def.ParseFromString(f.read()) 25 | with tf.Graph().as_default() as graph: 26 | tf.import_graph_def( 27 | graph_def, 28 | input_map=None, 29 | return_elements=None, 30 | op_dict=None, 31 | producer_op_list=None, 32 | name=prefix, 33 | ) 34 | return graph 35 | 36 | def inference(config, frozen_pb_path): 37 | """Inference for bucket 38 | """ 39 | 40 | # load graph 41 | graph = load_frozen_graph(frozen_pb_path) 42 | for op in graph.get_operations(): 43 | sys.stderr.write(op.name + '\n') 44 | 45 | # create session with graph 46 | # if graph is optimized by tensorRT, then 47 | # from tensorflow.contrib import tensorrt as trt 48 | # gpu_ops = tf.GPUOptions(per_process_gpu_memory_fraction = 0.50) 49 | gpu_ops = tf.GPUOptions() 50 | ''' 51 | session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_ops) 52 | ''' 53 | session_conf = tf.ConfigProto(allow_soft_placement=True, 54 | log_device_placement=False, 55 | gpu_options=gpu_ops, 56 | inter_op_parallelism_threads=0, 57 | intra_op_parallelism_threads=0) 58 | sess = tf.Session(graph=graph, config=session_conf) 59 | 60 | # mapping output/input tensors for bert 61 | if 'bert' in config.emb_class: 62 | t_bert_embeddings_subgraph = graph.get_tensor_by_name('prefix/bert_embeddings_subgraph:0') 63 | p_bert_embeddings = graph.get_tensor_by_name('prefix/bert_embeddings:0') 64 | # mapping output tensors 65 | t_logits_indices = graph.get_tensor_by_name('prefix/logits_indices:0') 66 | t_sentence_lengths = graph.get_tensor_by_name('prefix/sentence_lengths:0') 67 | 68 | num_buckets = 0 69 | total_duration_time = 0.0 70 | bucket = [] 71 | while 1: 72 | try: line = sys.stdin.readline() 73 | except KeyboardInterrupt: break 74 | if not line: break 75 | line = line.strip() 76 | if not line and len(bucket) >= 1: 77 | start_time = time.time() 78 | inp, feed_dict = feed.build_input_feed_dict_with_graph(graph, config, bucket, Input) 79 | if 'bert' in config.emb_class: 80 | # compute bert embedding at runtime 81 | bert_embeddings = sess.run([t_bert_embeddings_subgraph], feed_dict=feed_dict) 82 | # update feed_dict 83 | feed_dict[p_bert_embeddings] = feed.align_bert_embeddings(config, bert_embeddings, inp.example['bert_wordidx2tokenidx'], -1) 84 | logits_indices, sentence_lengths = sess.run([t_logits_indices, t_sentence_lengths], feed_dict=feed_dict) 85 | tags = config.logit_indices_to_tags(logits_indices[0], sentence_lengths[0]) 86 | for i in range(len(bucket)): 87 | out = bucket[i] + ' ' + tags[i] 88 | sys.stdout.write(out + '\n') 89 | sys.stdout.write('\n') 90 | bucket = [] 91 | duration_time = time.time() - start_time 92 | out = 'duration_time : ' + str(duration_time) + ' sec' 93 | sys.stderr.write(out + '\n') 94 | num_buckets += 1 95 | if num_buckets != 1: # first one may take longer time, so ignore in computing duration.
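# (warm-up exclusion: the matching denominator below divides by num_buckets-1 for the same reason)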
96 | total_duration_time += duration_time 97 | if line : bucket.append(line) 98 | if len(bucket) != 0: 99 | start_time = time.time() 100 | inp, feed_dict = feed.build_input_feed_dict_with_graph(graph, config, bucket, Input) 101 | if 'bert' in config.emb_class: 102 | # compute bert embedding at runtime 103 | bert_embeddings = sess.run([t_bert_embeddings_subgraph], feed_dict=feed_dict) 104 | # update feed_dict 105 | feed_dict[p_bert_embeddings] = feed.align_bert_embeddings(config, bert_embeddings, inp.example['bert_wordidx2tokenidx'], -1) 106 | logits_indices, sentence_lengths = sess.run([t_logits_indices, t_sentence_lengths], feed_dict=feed_dict) 107 | tags = config.logit_indices_to_tags(logits_indices[0], sentence_lengths[0]) 108 | for i in range(len(bucket)): 109 | out = bucket[i] + ' ' + tags[i] 110 | sys.stdout.write(out + '\n') 111 | sys.stdout.write('\n') 112 | duration_time = time.time() - start_time 113 | out = 'duration_time : ' + str(duration_time) + ' sec' 114 | tf.logging.info(out) 115 | num_buckets += 1 116 | total_duration_time += duration_time 117 | 118 | out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n' 119 | out += 'average processing time / bucket : ' + str(total_duration_time / (num_buckets-1)) + ' sec' 120 | tf.logging.info(out) 121 | 122 | sess.close() 123 | 124 | if __name__ == '__main__': 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--emb_path', type=str, help='path to word embedding vector + vocab(.pkl)', required=True) 127 | parser.add_argument('--config_path', type=str, default='data/config.json', help='path to config.json') 128 | parser.add_argument('--wrd_dim', type=int, help='dimension of word embedding vector', required=True) 129 | parser.add_argument('--word_length', type=int, default=15, help='max word length') 130 | parser.add_argument('--frozen_path', type=str, help='path to frozen model(ex, ./exported/ner_frozen.pb)', required=True) 131 | 132 | args = parser.parse_args() 133 | tf.logging.set_verbosity(tf.logging.INFO) 134 | 135 | args.restore = None 136 | config = Config(args, is_training=False, emb_class='glove', use_crf=True) 137 | inference(config, args.frozen_path) 138 | -------------------------------------------------------------------------------- /inference/cc/src/Input.cc: -------------------------------------------------------------------------------- 1 | #include "Input.h" 2 | #include 3 | #include 4 | 5 | /* 6 | * public methods 7 | */ 8 | 9 | Input::Input(Config* config, Vocab* vocab, vector<string>& bucket) 10 | { 11 | /* 12 | * Args: 13 | * config: configuration info. class_size, word_length, etc. 14 | * vocab: vocab info. word id, pos id, chk id, tag id, etc. 15 | * bucket: list of 'word pos chk tag'.
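*          ex) a hypothetical entry: "John NNP B-NP O" (the tag slot holds the answer or a dummy 'O' at inference time).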
16 | */ 17 | this->max_sentence_length = bucket.size(); 18 | 19 | // create input tensors 20 | int word_length = config->GetWordLength(); 21 | tensorflow::TensorShape shape1({1, this->max_sentence_length}); 22 | this->sentence_word_ids = new tensorflow::Tensor(tensorflow::DT_INT32, shape1); 23 | tensorflow::TensorShape shape2({1, this->max_sentence_length, word_length}); 24 | this->sentence_wordchr_ids = new tensorflow::Tensor(tensorflow::DT_INT32, shape2); 25 | this->sentence_pos_ids = new tensorflow::Tensor(tensorflow::DT_INT32, shape1); 26 | this->sentence_chk_ids = new tensorflow::Tensor(tensorflow::DT_INT32, shape1); 27 | // additional scalar tensor for sentence_length, is_train 28 | this->sentence_length = new tensorflow::Tensor(tensorflow::DT_INT32, tensorflow::TensorShape()); 29 | this->is_train = new tensorflow::Tensor(tensorflow::DT_BOOL, tensorflow::TensorShape()); 30 | 31 | auto data_word_ids = this->sentence_word_ids->flat<int>().data(); 32 | auto data_wordchr_ids = this->sentence_wordchr_ids->flat<int>().data(); 33 | auto data_pos_ids = this->sentence_pos_ids->flat<int>().data(); 34 | auto data_chk_ids = this->sentence_chk_ids->flat<int>().data(); 35 | auto data_sentence_length = this->sentence_length->flat<int>().data(); 36 | auto data_is_train = this->is_train->flat<bool>().data(); 37 | 38 | for( int i = 0; i < max_sentence_length; i++ ) { 39 | string line = bucket[i]; 40 | vector<string> tokens; 41 | vocab->Split(line, tokens); 42 | if( tokens.size() != 4 ) { 43 | throw runtime_error("input tokens must be size 4"); 44 | } 45 | string word = tokens[0]; 46 | string pos = tokens[1]; 47 | string chk = tokens[2]; 48 | string tag = tokens[3]; // correct tag(answer) or dummy 'O' 49 | // build sentence_word_ids 50 | int wid = vocab->GetWid(word); 51 | data_word_ids[i] = wid; 52 | // build sentence_wordchr_ids 53 | int wlen = word.length(); 54 | unsigned int* coffarr = build_coffarr(word.c_str(), wlen); 55 | int cpos_prev = -1; 56 | string ch = string(); 57 | int index = 0; 58 | for( int bpos = 0; bpos < wlen && index < word_length; bpos++ ) { 59 | int cpos = coffarr[bpos]; 60 | if( cpos == cpos_prev ) { 61 | ch = ch + word[bpos]; 62 | } else { 63 | if( !ch.empty() ) { 64 | // 1 character, ex) '가', 'a', '1', '!'
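// the byte offset changed, so the collected (possibly multi-byte) character is flushed to its char id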
65 | int cid = vocab->GetCid(ch); 66 | data_wordchr_ids[i*word_length + index] = cid; 67 | index += 1; 68 | } 69 | ch.clear(); 70 | ch = word[bpos]; 71 | } 72 | cpos_prev = cpos; 73 | } 74 | if( !ch.empty() ) { 75 | int cid = vocab->GetCid(ch); 76 | data_wordchr_ids[i*word_length + index] = cid; 77 | index += 1; 78 | } 79 | for( int j = 0; j < word_length - index; j++ ) { // padding cid 80 | int pad_cid = vocab->GetPadCid(); 81 | data_wordchr_ids[i*word_length + index + j] = pad_cid; 82 | } 83 | if( coffarr ) free(coffarr); 84 | // build sentence_pos_ids 85 | int pid = vocab->GetPid(pos); 86 | data_pos_ids[i] = pid; 87 | // build sentence_chk_ids 88 | int kid = vocab->GetKid(chk); 89 | data_chk_ids[i] = kid; 90 | } 91 | *data_sentence_length = this->max_sentence_length; 92 | *data_is_train = false; 93 | } 94 | 95 | Input::~Input() 96 | { 97 | if( this->sentence_word_ids ) delete this->sentence_word_ids; 98 | if( this->sentence_wordchr_ids ) delete this->sentence_wordchr_ids; 99 | if( this->sentence_pos_ids ) delete this->sentence_pos_ids; 100 | if( this->sentence_chk_ids ) delete this->sentence_chk_ids; 101 | if( this->sentence_length ) delete this->sentence_length; 102 | if( this->is_train ) delete this->is_train; 103 | } 104 | 105 | /* 106 | * private methods 107 | */ 108 | 109 | int Input::utf8_len(char chr) 110 | { 111 | /* 112 | * get utf8 character length 113 | * 114 | * Args: 115 | * chr: beginning byte in utf8 string. 116 | * 117 | * Returns: 118 | * character length for the chr, i.e., range. 119 | */ 120 | if( (chr & 0x80) == 0x00 ) 121 | return 1; 122 | else if( (chr & 0xE0) == 0xC0 ) 123 | return 2; 124 | else if( (chr & 0xF0) == 0xE0 ) 125 | return 3; 126 | else if( (chr & 0xF8) == 0xF0 ) 127 | return 4; 128 | else if( (chr & 0xFC) == 0xF8 ) 129 | return 5; 130 | else if( (chr & 0xFE) == 0xFC ) 131 | return 6; 132 | else if( (chr & 0xFE ) == 0xFE ) 133 | return 1; 134 | return 0; 135 | } 136 | 137 | unsigned int* Input::build_coffarr(const char* in, int in_size) 138 | { 139 | /* 140 | * compute character offset array 141 | * returned pointer must be released 142 | * ex) utf-8 string : 가나다라abcd가나'\0' 143 | * ----------------------------------------------------------- 144 | * 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 145 | * 0 0 0 1 1 1 2 2 2 3 3 3 4 5 6 7 8 8 8 9 9 9 10 146 | * ----------------------------------------------------------- 147 | * 148 | * usage) 149 | * const char* in = "가나다라abcd가나"; 150 | * int in_size = strlen(in); 151 | * unsigned int* coffarr = build_coffarr(in, in_size); 152 | * int character_pos = coffarr[byte_pos]; 153 | * if( coffarr ) free(coffarr); 154 | * 155 | * Args: 156 | * in: utf8 string. 157 | * in_size: size of in(byte length). 158 | * 159 | * Returns: 160 | * unsigned int array. this should be freed later.
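*    note) the array is allocated with in_size+2 slots so that the offset of the trailing '\0' is also recorded.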
161 | */ 162 | int i, j; 163 | int index; 164 | int codelen; 165 | const char *s = in; 166 | unsigned int *char_offset_array; 167 | 168 | char_offset_array = (unsigned int*)malloc(sizeof(unsigned int) * (in_size+2)); 169 | if( char_offset_array == NULL ) { 170 | fprintf(stderr, "char_offset_array : malloc fail!\n"); 171 | return NULL; 172 | } 173 | index=0; 174 | // compute offset for last '\0' 175 | for( i = 0; i < in_size+1; i = i + codelen ) { 176 | codelen = this->utf8_len(s[i]); 177 | if( codelen == 0 ) { 178 | fprintf(stderr, "%s contains invalid utf8 begin code\n", in); 179 | if( char_offset_array != NULL ) { 180 | free(char_offset_array); 181 | return NULL; 182 | } 183 | } 184 | for( j = 0; j < codelen; j++ ) { 185 | if( codelen == 1 ) 186 | char_offset_array[i] = index; 187 | else { 188 | if( j == 0 ) { 189 | char_offset_array[i] = index; 190 | } else { 191 | if( this->utf8_len(s[i+j]) == 0 ) { // valid inner code 192 | char_offset_array[i+j] = index; 193 | } else { 194 | fprintf(stderr, "%s contains invalid utf8 inner code\n", in); 195 | free(char_offset_array); 196 | return NULL; 197 | } 198 | } 199 | } 200 | } 201 | index++; 202 | } 203 | return char_offset_array; 204 | } 205 | -------------------------------------------------------------------------------- /inference/python/www/handlers/index.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import logging 5 | import tornado.web 6 | from handlers.base import BaseHandler 7 | import json 8 | import time 9 | 10 | ############################################################################################### 11 | # etagger 12 | path = os.path.dirname(os.path.abspath(__file__)) + '/lib' 13 | sys.path.append(path) 14 | # although the `import tensorflow as tf` statement is in `input.py`, 15 | # it will not be executed by 16 | # `from handlers.index import IndexHandler, HCheckHandler, EtaggerHandler, EtaggerTestHandler`.
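# (the bundled lib directory added above is what provides input.py and feed.py for these handlers)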
17 | from input import Input 18 | import feed 19 | 20 | def get_entity(doc, begin, end): 21 | for ent in doc.ents: 22 | # check included 23 | if ent.start_char <= begin and end <= ent.end_char: 24 | if ent.start_char == begin: return 'B-' + ent.label_ 25 | else: return 'I-' + ent.label_ 26 | return 'O' 27 | 28 | def build_bucket(nlp, line): 29 | bucket = [] 30 | doc = nlp(line) 31 | for token in doc: 32 | begin = token.idx 33 | end = begin + len(token.text) - 1 34 | temp = [] 35 | temp.append(token.text) 36 | temp.append(token.tag_) 37 | temp.append('O') # no chunking info 38 | entity = get_entity(doc, begin, end) 39 | temp.append(entity) # entity by spacy 40 | temp = ' '.join(temp) 41 | bucket.append(temp) 42 | return bucket 43 | 44 | def analyze(graph, sess, query, config, nlp): 45 | """Analyze query by nlp, etagger 46 | """ 47 | bucket = build_bucket(nlp, query) 48 | inp, feed_dict = feed.build_input_feed_dict_with_graph(graph, config, bucket, Input) 49 | ## mapping output/input tensors for bert 50 | if 'bert' in config.emb_class: 51 | t_bert_embeddings_subgraph = graph.get_tensor_by_name('prefix/bert_embeddings_subgraph:0') 52 | p_bert_embeddings = graph.get_tensor_by_name('prefix/bert_embeddings:0') 53 | ## mapping output tensors 54 | t_logits_indices = graph.get_tensor_by_name('prefix/logits_indices:0') 55 | t_sentence_lengths = graph.get_tensor_by_name('prefix/sentence_lengths:0') 56 | ## analyze 57 | if 'bert' in config.emb_class: 58 | # compute bert embedding at runtime 59 | bert_embeddings = sess.run([t_bert_embeddings_subgraph], feed_dict=feed_dict) 60 | # update feed_dict 61 | feed_dict[p_bert_embeddings] = feed.align_bert_embeddings(config, bert_embeddings, inp.example['bert_wordidx2tokenidx'], -1) 62 | logits_indices, sentence_lengths = sess.run([t_logits_indices, t_sentence_lengths], feed_dict=feed_dict) 63 | tags = config.logit_indices_to_tags(logits_indices[0], sentence_lengths[0]) 64 | ## build output 65 | out = [] 66 | for i in range(len(bucket)): 67 | tmp = bucket[i] + ' ' + tags[i] 68 | tl = tmp.split() 69 | entry = {} 70 | entry['id'] = i 71 | entry['word'] = tl[0] 72 | entry['pos'] = tl[1] 73 | entry['chk'] = tl[2] 74 | entry['tag'] = tl[3] 75 | entry['predict'] = tl[4] 76 | out.append(entry) 77 | return out 78 | ############################################################################################### 79 | 80 | class IndexHandler(BaseHandler): 81 | def get(self): 82 | q = self.get_argument('q', '') 83 | self.render('index.html', q=q) 84 | 85 | class HCheckHandler(BaseHandler): 86 | def get(self): 87 | self.set_header('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0') 88 | templates_dir = 'templates' 89 | hdn_filename = '_hcheck.hdn' 90 | err_filename = 'error.html' 91 | try : fid = open(templates_dir + "/" + hdn_filename, 'r') 92 | except : 93 | self.set_status(404) 94 | self.render(err_filename) 95 | else : 96 | fid.close() 97 | self.render(hdn_filename) 98 | 99 | class EtaggerHandler(BaseHandler): 100 | def get(self) : 101 | start_time = time.time() 102 | 103 | callback = self.get_argument('callback', '') 104 | mode = self.get_argument('mode', 'product') 105 | try : 106 | query = self.get_argument('q', '') 107 | except : 108 | query = "Invalid unicode in q" 109 | 110 | debug = {} 111 | debug['callback'] = callback 112 | debug['mode'] = mode 113 | pid = os.getpid() 114 | debug['pid'] = pid 115 | 116 | rst = {} 117 | rst['msg'] = '' 118 | rst['query'] = query 119 | if mode == 'debug' : rst['debug'] = debug 120 | 121 | config = self.config 122 | 
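# each forked child registered its own session/graph pair under its pid at initialize() time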
m = self.etagger[pid] 123 | sess = m['sess'] 124 | graph = m['graph'] 125 | nlp = self.nlp 126 | try : 127 | out = analyze(graph, sess, query, config, nlp) 128 | rst['status'] = 200 129 | rst['output'] = out 130 | except : 131 | rst['status'] = 500 132 | rst['output'] = [] 133 | rst['msg'] = 'analyze() fail' 134 | 135 | if mode == 'debug' : 136 | duration_time = time.time() - start_time 137 | debug['exectime'] = duration_time 138 | 139 | try : 140 | ret = json.dumps(rst) 141 | except : 142 | msg = "json.dumps() fail for query %s" % (query) 143 | self.log.debug(msg + "\n") 144 | err = {} 145 | err['status'] = 500 146 | err['msg'] = msg 147 | ret = json.dumps(err) 148 | 149 | if mode == 'debug' : 150 | self.set_header('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0') 151 | 152 | if callback.strip() : 153 | self.set_header('Content-Type', 'application/javascript; charset=utf-8') 154 | ret = 'if (typeof %s === "function") %s(%s);' % (callback, callback, ret) 155 | else : 156 | self.set_header('Content-Type', 'application/json; charset=utf-8') 157 | 158 | self.write(ret) 159 | self.finish() 160 | 161 | 162 | def post(self): 163 | self.get() 164 | 165 | class EtaggerTestHandler(BaseHandler): 166 | def post(self): 167 | if self.request.body : 168 | try: 169 | json_data = json.loads(self.request.body) 170 | self.request.arguments.update(json_data) 171 | content = '' 172 | if 'content' in json_data : content = json_data['content'] 173 | is_json_request = True 174 | except: 175 | content = self.get_argument('content', "", True) 176 | is_json_request = False 177 | else: 178 | self.write(dict(success=False, info='no request body for post')) 179 | self.finish(); return  # without this return, 'content' and 'is_json_request' below would be unbound 180 | 181 | pid = os.getpid() 182 | config = self.config 183 | m = self.etagger[pid] 184 | sess = m['sess'] 185 | graph = m['graph'] 186 | nlp = self.nlp 187 | 188 | if is_json_request : lines = content 189 | else: lines = content.split('\n') 190 | try: 191 | out_list=[] 192 | for line in lines : 193 | line = line.strip() 194 | if not line : continue 195 | out = analyze(graph, sess, line, config, nlp) 196 | out_list.append(out) 197 | self.write(dict(success=True, record=out_list, info=None)) 198 | except Exception as e: 199 | msg = str(e) 200 | self.write(dict(success=False, info=msg)) 201 | 202 | self.finish() 203 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import time 4 | import argparse 5 | import tensorflow as tf 6 | import numpy as np 7 | from embvec import EmbVec 8 | from config import Config 9 | from model import Model 10 | from input import Input 11 | import feed 12 | 13 | def inference_bucket(config): 14 | """Inference for bucket.
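Reads 'word pos chunk tag' lines from stdin, one token per line; a blank line flushes the current bucket through the model and each line is echoed with the predicted tag appended.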
15 | """ 16 | 17 | # create model and compile 18 | model = Model(config) 19 | model.compile() 20 | sess = model.sess 21 | 22 | # restore model 23 | saver = tf.train.Saver() 24 | saver.restore(sess, config.restore) 25 | sys.stderr.write('model restored' +'\n') 26 | ''' 27 | print(tf.global_variables()) 28 | print(tf.trainable_variables()) 29 | ''' 30 | num_buckets = 0 31 | total_duration_time = 0.0 32 | bucket = [] 33 | while 1: 34 | try: line = sys.stdin.readline() 35 | except KeyboardInterrupt: break 36 | if not line: break 37 | line = line.strip() 38 | if not line and len(bucket) >= 1: 39 | start_time = time.time() 40 | inp, feed_dict = feed.build_input_feed_dict(model, bucket, Input) 41 | if 'bert' in config.emb_class: 42 | # compute bert embedding at runtime 43 | bert_embeddings = sess.run([model.bert_embeddings_subgraph], feed_dict=feed_dict) 44 | # update feed_dict 45 | feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(config, bert_embeddings, inp.example['bert_wordidx2tokenidx'], -1) 46 | logits_indices, sentence_lengths = sess.run([model.logits_indices, model.sentence_lengths], feed_dict=feed_dict) 47 | tags = config.logit_indices_to_tags(logits_indices[0], sentence_lengths[0]) 48 | for i in range(len(bucket)): 49 | predict = config.embvec.oot_tag # ex) 'O' 50 | if i < sentence_lengths[0]: predict = tags[i] 51 | out = bucket[i] + ' ' + predict 52 | sys.stdout.write(out + '\n') 53 | sys.stdout.write('\n') 54 | bucket = [] 55 | duration_time = time.time() - start_time 56 | out = 'duration_time : ' + str(duration_time) + ' sec' 57 | tf.logging.info(out) 58 | num_buckets += 1 59 | if num_buckets != 1: # first one may take longer time, so ignore in computing duration. 60 | total_duration_time += duration_time 61 | if line : bucket.append(line) 62 | if len(bucket) != 0: 63 | start_time = time.time() 64 | inp, feed_dict = feed.build_input_feed_dict(model, bucket, Input) 65 | if 'bert' in config.emb_class: 66 | # compute bert embedding at runtime 67 | bert_embeddings = sess.run([model.bert_embeddings_subgraph], feed_dict=feed_dict) 68 | # update feed_dict 69 | feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(config, bert_embeddings, inp.example['bert_wordidx2tokenidx'], -1) 70 | logits_indices, sentence_lengths = sess.run([model.logits_indices, model.sentence_lengths], feed_dict=feed_dict) 71 | tags = config.logit_indices_to_tags(logits_indices[0], sentence_lengths[0]) 72 | for i in range(len(bucket)): 73 | predict = config.embvec.oot_tag # ex) 'O' 74 | if i < sentence_lengths[0]: predict = tags[i] 75 | out = bucket[i] + ' ' + predict 76 | sys.stdout.write(out + '\n') 77 | sys.stdout.write('\n') 78 | duration_time = time.time() - start_time 79 | out = 'duration_time : ' + str(duration_time) + ' sec' 80 | tf.logging.info(out) 81 | num_buckets += 1 82 | total_duration_time += duration_time 83 | 84 | out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n' 85 | out += 'average processing time / bucket : ' + str(total_duration_time / (num_buckets-1)) + ' sec' 86 | tf.logging.info(out) 87 | 88 | sess.close() 89 | 90 | def inference_line(config): 91 | """Inference for raw string. 
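Each stdin line is tokenized with spaCy into a 'word pos chunk entity' bucket before tagging.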
92 | """ 93 | def get_entity(doc, begin, end): 94 | for ent in doc.ents: 95 | # check included 96 | if ent.start_char <= begin and end <= ent.end_char: 97 | if ent.start_char == begin: return 'B-' + ent.label_ 98 | else: return 'I-' + ent.label_ 99 | return 'O' 100 | 101 | def build_bucket(nlp, line): 102 | bucket = [] 103 | doc = nlp(line) 104 | for token in doc: 105 | begin = token.idx 106 | end = begin + len(token.text) - 1 107 | temp = [] 108 | ''' 109 | print(token.i, token.text, token.lemma_, token.pos_, token.tag_, token.dep_, 110 | token.shape_, token.is_alpha, token.is_stop, begin, end) 111 | ''' 112 | temp.append(token.text) 113 | temp.append(token.tag_) 114 | temp.append('O') # no chunking info 115 | entity = get_entity(doc, begin, end) 116 | temp.append(entity) # entity by spacy 117 | temp = ' '.join(temp) 118 | bucket.append(temp) 119 | return bucket 120 | 121 | import spacy 122 | nlp = spacy.load('en') 123 | 124 | # create model and compile 125 | model = Model(config) 126 | model.compile() 127 | sess = model.sess 128 | 129 | # restore model 130 | saver = tf.train.Saver() 131 | saver.restore(sess, config.restore) 132 | tf.logging.info('model restored' +'\n') 133 | 134 | while 1: 135 | try: line = sys.stdin.readline() 136 | except KeyboardInterrupt: break 137 | if not line: break 138 | line = line.strip() 139 | if not line: continue 140 | # create bucket 141 | try: bucket = build_bucket(nlp, line) 142 | except Exception as e: 143 | sys.stderr.write(str(e) +'\n') 144 | continue 145 | inp, feed_dict = feed.build_input_feed_dict(model, bucket, Input)  # the Input argument was missing; build_input_feed_dict() requires it 146 | if 'bert' in config.emb_class: 147 | # compute bert embedding at runtime 148 | bert_embeddings = sess.run([model.bert_embeddings_subgraph], feed_dict=feed_dict) 149 | # update feed_dict 150 | feed_dict[model.bert_embeddings] = feed.align_bert_embeddings(config, bert_embeddings, inp.example['bert_wordidx2tokenidx'], -1) 151 | logits_indices, sentence_lengths = sess.run([model.logits_indices, model.sentence_lengths], feed_dict=feed_dict) 152 | tags = config.logit_indices_to_tags(logits_indices[0], sentence_lengths[0]) 153 | for i in range(len(bucket)): 154 | out = bucket[i] + ' ' + tags[i] 155 | sys.stdout.write(out + '\n') 156 | sys.stdout.write('\n') 157 | 158 | sess.close() 159 | 160 | if __name__ == '__main__': 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument('--emb_path', type=str, help='path to word embedding vector + vocab(.pkl)', required=True) 163 | parser.add_argument('--config_path', type=str, default='data/config.json', help='path to config.json') 164 | parser.add_argument('--wrd_dim', type=int, help='dimension of word embedding vector', required=True) 165 | parser.add_argument('--word_length', type=int, default=15, help='max word length') 166 | parser.add_argument('--restore', type=str, help='path to saved model(ex, ./checkpoint/ner_model)', required=True) 167 | parser.add_argument('--mode', type=str, default='bulk', help='bulk, bucket, line') 168 | 169 | args = parser.parse_args() 170 | tf.logging.set_verbosity(tf.logging.INFO) 171 | 172 | 173 | config = Config(args, is_training=False, emb_class='glove', use_crf=True) 174 | if args.mode == 'bucket': inference_bucket(config) 175 | if args.mode == 'line': inference_line(config) 176 | -------------------------------------------------------------------------------- /feed.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import time 5 | import argparse
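# feed.py centralizes feed_dict construction: build_feed_dict() for training batches, build_input_feed_dict() for restored models, build_input_feed_dict_with_graph() for frozen graphs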
6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | def build_feed_dict(model, dataset, max_sentence_length, is_train): 10 | """Build feed_dict for dataset 11 | """ 12 | config = model.config 13 | feed_dict={model.input_data_pos_ids: dataset['pos_ids'], 14 | model.input_data_chk_ids: dataset['chk_ids'], 15 | model.output_data: dataset['tags'], 16 | model.is_train: is_train, 17 | model.sentence_length: max_sentence_length} 18 | feed_dict[model.input_data_word_ids] = dataset['word_ids'] 19 | feed_dict[model.input_data_wordchr_ids] = dataset['wordchr_ids'] 20 | if 'elmo' in config.emb_class: 21 | feed_dict[model.elmo_input_data_wordchr_ids] = dataset['elmo_wordchr_ids'] 22 | if 'bert' in config.emb_class: 23 | feed_dict[model.bert_input_data_token_ids] = dataset['bert_token_ids'] 24 | feed_dict[model.bert_input_data_token_masks] = dataset['bert_token_masks'] 25 | feed_dict[model.bert_input_data_segment_ids] = dataset['bert_segment_ids'] 26 | return feed_dict 27 | 28 | def build_input_feed_dict(model, bucket, Input): 29 | """Build input and feed_dict for bucket(inference only), by default, with model 30 | """ 31 | config = model.config 32 | inp = Input(bucket, config, build_output=False) 33 | feed_dict = {model.input_data_pos_ids: inp.example['pos_ids'], 34 | model.input_data_chk_ids: inp.example['chk_ids'], 35 | model.is_train: False, 36 | model.sentence_length: inp.max_sentence_length} 37 | feed_dict[model.input_data_word_ids] = inp.example['word_ids'] 38 | feed_dict[model.input_data_wordchr_ids] = inp.example['wordchr_ids'] 39 | if 'elmo' in config.emb_class: 40 | feed_dict[model.elmo_input_data_wordchr_ids] = inp.example['elmo_wordchr_ids'] 41 | if 'bert' in config.emb_class: 42 | feed_dict[model.bert_input_data_token_ids] = inp.example['bert_token_ids'] 43 | feed_dict[model.bert_input_data_token_masks] = inp.example['bert_token_masks'] 44 | feed_dict[model.bert_input_data_segment_ids] = inp.example['bert_segment_ids'] 45 | return inp, feed_dict 46 | 47 | def build_input_feed_dict_with_graph(graph, config, bucket, Input): 48 | """Build input and feed_dict for bucket(inference only) with graph 49 | """ 50 | # mapping placeholders 51 | p_is_train = graph.get_tensor_by_name('prefix/is_train:0') 52 | p_sentence_length = graph.get_tensor_by_name('prefix/sentence_length:0') 53 | p_input_data_pos_ids = graph.get_tensor_by_name('prefix/input_data_pos_ids:0') 54 | p_input_data_chk_ids = graph.get_tensor_by_name('prefix/input_data_chk_ids:0') 55 | p_input_data_word_ids = graph.get_tensor_by_name('prefix/input_data_word_ids:0') 56 | p_input_data_wordchr_ids = graph.get_tensor_by_name('prefix/input_data_wordchr_ids:0') 57 | if 'elmo' in config.emb_class: 58 | p_elmo_input_data_wordchr_ids = graph.get_tensor_by_name('prefix/elmo_input_data_wordchr_ids:0') 59 | if 'bert' in config.emb_class: 60 | p_bert_input_data_token_ids = graph.get_tensor_by_name('prefix/bert_input_data_token_ids:0') 61 | p_bert_input_data_token_masks = graph.get_tensor_by_name('prefix/bert_input_data_token_masks:0') 62 | p_bert_input_data_segment_ids = graph.get_tensor_by_name('prefix/bert_input_data_segment_ids:0') 63 | 64 | inp = Input(bucket, config, build_output=False) 65 | feed_dict = {p_input_data_pos_ids: inp.example['pos_ids'], 66 | p_input_data_chk_ids: inp.example['chk_ids'], 67 | p_is_train: False, 68 | p_sentence_length: inp.max_sentence_length} 69 | feed_dict[p_input_data_word_ids] = inp.example['word_ids'] 70 | feed_dict[p_input_data_wordchr_ids] = inp.example['wordchr_ids'] 71 | if 'elmo' in config.emb_class: 
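# (all placeholder lookups above use the 'prefix/' scope given to tf.import_graph_def at graph-loading time)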
72 | feed_dict[p_elmo_input_data_wordchr_ids] = inp.example['elmo_wordchr_ids'] 73 | if 'bert' in config.emb_class: 74 | feed_dict[p_bert_input_data_token_ids] = inp.example['bert_token_ids'] 75 | feed_dict[p_bert_input_data_token_masks] = inp.example['bert_token_masks'] 76 | feed_dict[p_bert_input_data_segment_ids] = inp.example['bert_segment_ids'] 77 | return inp, feed_dict 78 | 79 | def align_bert_embeddings(config, bert_embeddings, bert_wordidx2tokenidx, idx): 80 | """Align bert_embeddings via bert_wordidx2tokenidx 81 | ex) word : 'johanson was a guy to' [0 ~ 4] 82 | token : 'johan ##son was a gu ##y t ##o' [0 ~ 7] 83 | wordidx2tokenidx : [1 3 4 5 7 9 0 0 ...] (bert embedding begins with [CLS] token) 84 | bert embedding : [em('CLS'), em('johan'), em('##son'), em('was'), em('a'), em('gu'), em('##y'), em('t'), em('##o'), 0, ...] 85 | """ 86 | def mean_pooling(ls): 87 | '''Reduce by averaging along with rows. 88 | Args: 89 | ls: list of embedding 90 | code from https://github.com/Adaxry/get_aligned_BERT_emb/blob/master/get_aligned_bert_emb.py#L27 91 | ''' 92 | if len(ls) == 1: 93 | return ls[0] 94 | for item in ls[1:]: 95 | for index, value in enumerate(item): 96 | ls[0][index] += value 97 | return [value / len(ls) for value in ls[0]] 98 | 99 | def mean_pooling_with_cls(ls, cls): 100 | '''Reduce by averaging along with rows. 101 | Args: 102 | ls: list of embedding 103 | cls: '[CLS]' sentence embedding for BERT 104 | ''' 105 | for item in ls: 106 | for index, value in enumerate(item): 107 | cls[index] += value 108 | return [value / (len(ls)+1) for value in cls] 109 | 110 | if idx == 0: 111 | tf.logging.debug('# bert_embeddings') 112 | t = bert_embeddings[0] 113 | tf.logging.debug(' '.join([str(x) for x in np.shape(t)])) 114 | t = bert_embeddings[0][0][1] # first (batch, seq, token) embedding 115 | tf.logging.debug(' '.join([str(x) for x in t])) 116 | 117 | # 4-dim -> 3-dim 118 | bert_embeddings = bert_embeddings[0] 119 | 120 | bert_embeddings_updated = [] 121 | batch_size = len(bert_wordidx2tokenidx) 122 | for i in range(batch_size): # batch 123 | bert_embedding_updated = [] 124 | prev = 1 125 | for j in range(len(bert_wordidx2tokenidx[i])): # seq 126 | cur = bert_wordidx2tokenidx[i][j] 127 | if j == 0: 128 | prev = cur 129 | continue # skip first for '[CLS]' 130 | if cur == 0: break # process before padding area 131 | 132 | # mean prev ~ cur 133 | try: 134 | pooled = mean_pooling(bert_embeddings[i][prev:cur]) 135 | ''' 136 | cls = bert_embeddings[i][0] 137 | pooled = mean_pooling_with_cls(bert_embeddings[i][prev:cur], cls) 138 | ''' 139 | bert_embedding_updated.append(pooled) 140 | except: 141 | tf.logging.debug('[ERROR] ' + 'seq:' + str(i) + '\t' + 'prev:' + str(prev) + '\t' + 'cur:' + str(cur)) 142 | # error padding 143 | padding = [0.0] * config.bert_dim 144 | bert_embedding_updated.append(padding) 145 | 146 | prev = cur 147 | # padding 148 | while len(bert_embedding_updated) < config.bert_max_seq_length: 149 | padding = [0.0] * config.bert_dim 150 | bert_embedding_updated.append(padding) 151 | bert_embeddings_updated.append(bert_embedding_updated) 152 | 153 | if idx == 0: 154 | tf.logging.debug('# bert_embeddings_updated') 155 | t = bert_embeddings_updated[0][0] # first (batch, seq, token) embedding 156 | tf.logging.debug(' '.join([str(x) for x in t])) 157 | tf.logging.debug('# batch size: ' + str(len(bert_embeddings_updated))) 158 | tf.logging.debug('# seq size: ' + str(len(bert_embeddings_updated[0]))) 159 | tf.logging.debug('# emb size: ' + str(len(bert_embeddings_updated[0][0]))) 
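# final shape: (batch, bert_max_seq_length, bert_dim), with exactly one pooled vector per original word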
160 | 161 | return bert_embeddings_updated 162 | 163 | -------------------------------------------------------------------------------- /inference/cc/www/etagger_dm.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import logging 5 | from logging.handlers import RotatingFileHandler 6 | import signal 7 | import time 8 | import math 9 | 10 | import tornado.web 11 | import tornado.ioloop 12 | import tornado.autoreload 13 | import tornado.web 14 | import tornado.httpserver 15 | import tornado.process 16 | import tornado.autoreload as autoreload 17 | from tornado.options import define, options 18 | 19 | ############################################################################################### 20 | # etagger 21 | path = os.path.dirname(os.path.abspath(__file__)) + '/../wrapper' 22 | sys.path.append(path) 23 | import Etagger 24 | 25 | # etagger arguments 26 | define('so_path', default='', help='path to libetagger.so.', type=str) 27 | define('frozen_graph_fn', default='', help='path to frozen model(ex, ./exported/ner_frozen.pb).', type=str) 28 | define('vocab_fn', default='', help='path to vocab(ex, vocab.txt).', type=str) 29 | define('word_length', default=15, help='max word length.', type=int) 30 | define('lowercase', default='True', help='True if vocab file was all lowercased, otherwise False.', type=str) 31 | define('is_memmapped', default='False', help='is memory mapped graph, True | False.', type=str) 32 | define('num_threads', default=1, help='number of threads for tensorflow. 0 for all cores, n for n cores.', type=int) 33 | ############################################################################################### 34 | 35 | ############################################################################################### 36 | # nlp : spacy 37 | import spacy 38 | ############################################################################################### 39 | 40 | from handlers.index import IndexHandler, HCheckHandler, EtaggerHandler, EtaggerTestHandler 41 | define('port', default=8897, help='run on the given port.', type=int) 42 | define('debug', default=True, help='run on debug mode.', type=bool) 43 | define('process', default=3, help='number of process for service mode.', type=int) 44 | 45 | 46 | log = logging.getLogger('tornado.application') 47 | 48 | def setupAppLogger(): 49 | fmtStr = '%(asctime)s - %(levelname)s - %(module)s - %(message)s' 50 | formatter = logging.Formatter(fmt=fmtStr) 51 | 52 | cdir = os.path.dirname(os.path.abspath(options.log_file_prefix)) 53 | logfile = cdir + '/' + 'application.log' 54 | 55 | rotatingHandler = RotatingFileHandler(logfile, 'a', options.log_file_max_size, options.log_file_num_backups) 56 | rotatingHandler.setFormatter(formatter) 57 | 58 | if options.logging != 'none': 59 | log.setLevel(getattr(logging, options.logging.upper())) 60 | else: 61 | log.setLevel(logging.ERROR) 62 | 63 | log.propagate = False 64 | log.addHandler(rotatingHandler) 65 | 66 | return log 67 | 68 | class Application(tornado.web.Application): 69 | def __init__(self): 70 | settings = dict( 71 | static_path = os.path.join(os.path.dirname(__file__), 'static'), 72 | template_path = os.path.join(os.path.dirname(__file__), 'templates'), 73 | autoescape = None, 74 | debug = options.debug, 75 | gzip = True 76 | ) 77 | 78 | handlers = [ 79 | (r'/', IndexHandler), 80 | (r'/_hcheck.hdn', HCheckHandler), 81 | (r'/etagger', EtaggerHandler), 82 | (r'/etaggertest', 
EtaggerTestHandler), 83 | ] 84 | 85 | tornado.web.Application.__init__(self, handlers, **settings) 86 | autoreload.add_reload_hook(self.finalize) 87 | 88 | self.log = setupAppLogger() 89 | ppid = os.getpid() 90 | self.ppid = ppid 91 | self.log.info('initialize parent process[%s] ... done' % (ppid)) 92 | 93 | ############################################################################################### 94 | # save Etagger(python instance) for passing to handlers. 95 | self.Etagger = Etagger 96 | # create nlp(spacy) only once. 97 | self.nlp = spacy.load('en') 98 | self.log.info('initialize spacy on parent process[%s] ... done' % (ppid)) 99 | ############################################################################################### 100 | 101 | log.info('http start...') 102 | 103 | def initialize(self) : 104 | pid = os.getpid() 105 | self.log.info('initialize per child process[%s] ...' % (pid)) 106 | ############################################################################################### 107 | # create etagger instance for each child process. 108 | self.etagger = {} 109 | lowercase = False 110 | if options.lowercase == 'True': lowercase = True 111 | is_memmapped = False 112 | if options.is_memmapped == 'True': is_memmapped = True 113 | etagger = Etagger.initialize(options.so_path, 114 | options.frozen_graph_fn, 115 | options.vocab_fn, 116 | word_length=options.word_length, 117 | lowercase=lowercase, 118 | is_memmapped=is_memmapped, 119 | num_threads=options.num_threads) 120 | 121 | self.etagger[pid] = etagger 122 | ############################################################################################### 123 | self.log.info('initialize per child process[%s] ... done' % (pid)) 124 | 125 | def finalize(self): 126 | # finalize resources 127 | self.log.info('finalize resources...') 128 | ## finalize something.... 129 | for pid, etagger in self.etagger.items() :  # items() works on both Python 2 and 3; iteritems() was Python-2-only 130 | Etagger.finalize(etagger) 131 | 132 | log.info('Close logger...') 133 | x = list(log.handlers) 134 | for i in x: 135 | log.removeHandler(i) 136 | i.flush() 137 | i.close() 138 | self.log.info('finalize resources... done') 139 | 140 | def main(): 141 | tornado.options.parse_command_line() 142 | 143 | ''' 144 | # you can prefork tornado before creating application.
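# (binding the sockets before fork_processes lets every forked child serve the same listening port)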
145 | # code snippet: 146 | sockets = tornado.netutil.bind_sockets(options.port) 147 | tornado.process.fork_processes(options.process) 148 | application = Application() 149 | httpServer = tornado.httpserver.HTTPServer(application, no_keep_alive=True) 150 | httpServer.add_sockets(sockets) 151 | ''' 152 | 153 | application = Application() 154 | httpServer = tornado.httpserver.HTTPServer(application, no_keep_alive=True) 155 | if options.debug == True : 156 | httpServer.listen(options.port) 157 | application.initialize() 158 | else : 159 | httpServer.bind(options.port) 160 | if options.process == 0 : 161 | httpServer.start(0) # Forks multiple sub-processes, maximum to number of cores 162 | else : 163 | if options.process < 0 : 164 | options.process = 1 165 | httpServer.start(options.process) # Forks multiple sub-processes, given number 166 | pid = os.getpid() 167 | if pid != application.ppid : 168 | application.initialize() 169 | 170 | MAX_WAIT_SECONDS_BEFORE_SHUTDOWN = 3 171 | 172 | def sig_handler(sig, frame): 173 | log.warning('Caught signal: %s', sig) 174 | tornado.ioloop.IOLoop.instance().add_callback(shutdown) 175 | 176 | def shutdown(): 177 | log.info('Stopping http server') 178 | httpServer.stop() 179 | 180 | log.info('Will shutdown in %s seconds ...', MAX_WAIT_SECONDS_BEFORE_SHUTDOWN) 181 | io_loop = tornado.ioloop.IOLoop.instance() 182 | 183 | deadline = time.time() + MAX_WAIT_SECONDS_BEFORE_SHUTDOWN 184 | 185 | def stop_loop(): 186 | now = time.time() 187 | if now < deadline and (io_loop._callbacks or io_loop._timeouts): 188 | io_loop.add_timeout(now + 1, stop_loop) 189 | else: 190 | io_loop.stop() 191 | log.info('Shutdown') 192 | 193 | stop_loop() 194 | 195 | signal.signal(signal.SIGTERM, sig_handler) 196 | signal.signal(signal.SIGINT, sig_handler) 197 | 198 | tornado.ioloop.IOLoop.instance().start() 199 | 200 | log.info('Exit...') 201 | 202 | if __name__ == '__main__': 203 | main() 204 | -------------------------------------------------------------------------------- /inference/python/inference_trt.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | path = os.path.dirname(os.path.abspath(__file__)) + '/../..' 
5 | sys.path.append(path) 6 | import time 7 | import argparse 8 | import tensorflow as tf 9 | import numpy as np 10 | # for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369 11 | tf.contrib.rnn 12 | # for QRNN 13 | try: import qrnn 14 | except: sys.stderr.write('import qrnn, failed\n') 15 | # for tensorRT 16 | from tensorflow.contrib import tensorrt as trt 17 | 18 | from embvec import EmbVec 19 | from config import Config 20 | from input import Input 21 | 22 | def load_frozen_graph_def(frozen_graph_filename): 23 | with tf.gfile.GFile(frozen_graph_filename, "rb") as f: 24 | graph_def = tf.GraphDef() 25 | graph_def.ParseFromString(f.read()) 26 | return graph_def 27 | 28 | def load_graph(graph_def, prefix='prefix'): 29 | with tf.Graph().as_default() as graph: 30 | tf.import_graph_def( 31 | graph_def, 32 | input_map=None, 33 | return_elements=None, 34 | op_dict=None, 35 | producer_op_list=None, 36 | name=prefix, 37 | ) 38 | 39 | return graph 40 | 41 | def build_input_feed_dict(graph, bucket, config): 42 | """Build input and feed_dict for bucket(inference only) 43 | """ 44 | # mapping placeholders 45 | p_is_train = graph.get_tensor_by_name('prefix/is_train:0') 46 | p_sentence_length = graph.get_tensor_by_name('prefix/sentence_length:0') 47 | p_input_data_pos_ids = graph.get_tensor_by_name('prefix/input_data_pos_ids:0') 48 | p_input_data_chk_ids = graph.get_tensor_by_name('prefix/input_data_chk_ids:0') 49 | p_input_data_word_ids = graph.get_tensor_by_name('prefix/input_data_word_ids:0') 50 | p_input_data_wordchr_ids = graph.get_tensor_by_name('prefix/input_data_wordchr_ids:0') 51 | if 'elmo' in config.emb_class: 52 | p_elmo_input_data_wordchr_ids = graph.get_tensor_by_name('prefix/elmo_input_data_wordchr_ids:0') 53 | if 'bert' in config.emb_class: 54 | p_bert_input_data_token_ids = graph.get_tensor_by_name('prefix/bert_input_data_token_ids:0') 55 | p_bert_input_data_token_masks = graph.get_tensor_by_name('prefix/bert_input_data_token_masks:0') 56 | p_bert_input_data_segment_ids = graph.get_tensor_by_name('prefix/bert_input_data_segment_ids:0') 57 | if 'elmo' in config.emb_class: 58 | p_bert_input_data_elmo_indices = graph.get_tensor_by_name('prefix/bert_input_data_elmo_indices:0') 59 | 60 | inp = Input(bucket, config, build_output=False) 61 | feed_dict = {p_input_data_pos_ids: inp.example['pos_ids'], 62 | p_input_data_chk_ids: inp.example['chk_ids'], 63 | p_is_train: False, 64 | p_sentence_length: inp.max_sentence_length} 65 | feed_dict[p_input_data_word_ids] = inp.example['word_ids'] 66 | feed_dict[p_input_data_wordchr_ids] = inp.example['wordchr_ids'] 67 | if 'elmo' in config.emb_class: 68 | feed_dict[p_elmo_input_data_wordchr_ids] = inp.example['elmo_wordchr_ids'] 69 | if 'bert' in config.emb_class: 70 | feed_dict[p_bert_input_data_token_ids] = inp.example['bert_token_ids'] 71 | feed_dict[p_bert_input_data_token_masks] = inp.example['bert_token_masks'] 72 | feed_dict[p_bert_input_data_segment_ids] = inp.example['bert_segment_ids'] 73 | if 'elmo' in config.emb_class: 74 | feed_dict[p_bert_input_data_elmo_indices] = inp.example['bert_elmo_indices'] 75 | return inp, feed_dict 76 | 77 | def inference(config, frozen_pb_path): 78 | """Inference for bucket 79 | """ 80 | 81 | # load graph_def 82 | graph_def = load_frozen_graph_def(frozen_pb_path) 83 | 84 | # get optimized graph_def 85 | trt_graph_def = trt.create_inference_graph( 86 | input_graph_def=graph_def, 87 | outputs=['logits_indices', 'sentence_lengths'], 88 | max_batch_size=128, 89 | max_workspace_size_bytes=1 << 
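# (1 << 30 bytes = 1 GiB of workspace memory for building the TensorRT engine)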
30, 90 | precision_mode='FP16', # TRT Engine precision "FP32","FP16" or "INT8" 91 | minimum_segment_size=3 # minimum number of nodes in an engine 92 | ) 93 | 94 | # reset graph 95 | tf.reset_default_graph() 96 | 97 | # load optimized graph_def to default graph 98 | graph = load_graph(trt_graph_def, prefix='prefix') 99 | for op in graph.get_operations(): 100 | sys.stderr.write(op.name + '\n') 101 | 102 | # create session with optimized graph 103 | gpu_ops = tf.GPUOptions(per_process_gpu_memory_fraction = 0.50) 104 | session_conf = tf.ConfigProto(allow_soft_placement=True, 105 | log_device_placement=False, 106 | gpu_options=gpu_ops, 107 | inter_op_parallelism_threads=0, 108 | intra_op_parallelism_threads=0) 109 | sess = tf.Session(graph=graph, config=session_conf) 110 | 111 | # mapping output tensors 112 | t_logits_indices = graph.get_tensor_by_name('prefix/logits_indices:0') 113 | t_sentence_lengths = graph.get_tensor_by_name('prefix/sentence_lengths:0') 114 | 115 | num_buckets = 0 116 | total_duration_time = 0.0 117 | bucket = [] 118 | while 1: 119 | try: line = sys.stdin.readline() 120 | except KeyboardInterrupt: break 121 | if not line: break 122 | line = line.strip() 123 | if not line and len(bucket) >= 1: 124 | start_time = time.time() 125 | inp, feed_dict = build_input_feed_dict(graph, bucket, config) 126 | logits_indices, sentence_lengths = sess.run([t_logits_indices, t_sentence_lengths], feed_dict=feed_dict) 127 | tags = config.logit_indices_to_tags(logits_indices[0], sentence_lengths[0]) 128 | for i in range(len(bucket)): 129 | if 'bert' in config.emb_class: 130 | j = inp.example['bert_wordidx2tokenidx'][0][i] 131 | out = bucket[i] + ' ' + tags[j] 132 | else: 133 | out = bucket[i] + ' ' + tags[i] 134 | sys.stdout.write(out + '\n') 135 | sys.stdout.write('\n') 136 | bucket = [] 137 | duration_time = time.time() - start_time 138 | out = 'duration_time : ' + str(duration_time) + ' sec' 139 | tf.logging.info(out) 140 | num_buckets += 1 141 | total_duration_time += duration_time 142 | if line : bucket.append(line) 143 | if len(bucket) != 0: 144 | start_time = time.time() 145 | inp, feed_dict = build_input_feed_dict(graph, bucket, config) 146 | logits_indices, sentence_lengths = sess.run([t_logits_indices, t_sentence_lengths], feed_dict=feed_dict) 147 | tags = config.logit_indices_to_tags(logits_indices[0], sentence_lengths[0]) 148 | for i in range(len(bucket)): 149 | if 'bert' in config.emb_class: 150 | j = inp.example['bert_wordidx2tokenidx'][0][i] 151 | out = bucket[i] + ' ' + tags[j] 152 | else: 153 | out = bucket[i] + ' ' + tags[i] 154 | sys.stdout.write(out + '\n') 155 | sys.stdout.write('\n') 156 | duration_time = time.time() - start_time 157 | out = 'duration_time : ' + str(duration_time) + ' sec' 158 | tf.logging.info(out) 159 | num_buckets += 1 160 | total_duration_time += duration_time 161 | 162 | out = 'total_duration_time : ' + str(total_duration_time) + ' sec' + '\n' 163 | out += 'average processing time / bucket : ' + str(total_duration_time / num_buckets) + ' sec' 164 | tf.logging.info(out) 165 | 166 | sess.close() 167 | 168 | if __name__ == '__main__': 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument('--emb_path', type=str, help='path to word embedding vector + vocab(.pkl)', required=True) 171 | parser.add_argument('--config_path', type=str, default='data/config.json', help='path to config.json') 172 | parser.add_argument('--wrd_dim', type=int, help='dimension of word embedding vector', required=True) 173 | parser.add_argument('--word_length', type=int, 
default=15, help='max word length') 174 | parser.add_argument('--frozen_path', type=str, help='path to frozen model(ex, ./exported/ner_frozen.pb)', required=True) 175 | 176 | args = parser.parse_args() 177 | tf.logging.set_verbosity(tf.logging.INFO) 178 | 179 | args.restore = None 180 | config = Config(args, is_training=False, emb_class='glove', use_crf=True) 181 | inference(config, args.frozen_path) 182 | -------------------------------------------------------------------------------- /embvec.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import argparse 4 | import numpy as np 5 | import pickle as pkl 6 | from random import random 7 | 8 | class EmbVec: 9 | 10 | def __init__(self, args): 11 | """Build embedding, vocabularies, other resources 12 | 13 | Args: 14 | args: args from this script(embvec.py). 15 | """ 16 | self.pad = '#PAD#' 17 | self.unk = '#UNK#' 18 | self.lowercase = True 19 | if args.lowercase == 'False': self.lowercase = False 20 | 21 | self.wrd_vocab = {} # word vocab 22 | self.pad_wid = 0 # for padding word embedding 23 | self.unk_wid = 1 # for unknown word 24 | self.wrd_vocab[self.pad] = self.pad_wid 25 | self.wrd_vocab[self.unk] = self.unk_wid 26 | 27 | self.chr_vocab = {} # character vocab 28 | self.pad_cid = 0 # for padding char embedding 29 | self.unk_cid = 1 # for unknown char 30 | self.chr_vocab[self.pad] = self.pad_cid 31 | self.chr_vocab[self.unk] = self.unk_cid 32 | 33 | self.pos_vocab = {} # pos vocab 34 | self.pad_pid = 0 # for padding pos embedding 35 | self.unk_pid = 1 # for unknown pos 36 | self.pos_vocab[self.pad] = self.pad_pid 37 | self.pos_vocab[self.unk] = self.unk_pid 38 | 39 | self.chk_vocab = {} # chunk vocab 40 | self.pad_kid = 0 # for padding chunk embedding 41 | self.unk_kid = 1 # for unknown chunk 42 | self.chk_vocab[self.pad] = self.pad_kid 43 | self.chk_vocab[self.unk] = self.unk_kid 44 | 45 | self.oot_tid = 0 # out of tag id 46 | self.oot_tag = 'O' # out of tag, this is fixed for convenience 47 | self.xot_tid = 1 # 'X' tag id 48 | self.xot_tag = 'X' # 'X' tag, fixed for convenience 49 | self.tag_vocab = {} # tag vocab (tag -> id) 50 | self.itag_vocab = {} # inverse tag vocab (id -> tag) 51 | self.tag_vocab[self.oot_tag] = self.oot_tid 52 | self.itag_vocab[0] = self.oot_tag 53 | self.tag_vocab[self.xot_tag] = self.xot_tid 54 | self.itag_vocab[1] = self.xot_tag 55 | 56 | # elmo 57 | self.elmo_vocab = {} # elmo vocab 58 | self.elmo_vocab_path = args.elmo_vocab_path 59 | self.elmo_options_path = args.elmo_options_path 60 | self.elmo_weight_path = args.elmo_weight_path 61 | 62 | # bert 63 | self.bert_config_path = args.bert_config_path 64 | self.bert_vocab_path = args.bert_vocab_path 65 | self.bert_do_lower_case = False 66 | if args.bert_do_lower_case == 'True': self.bert_do_lower_case = True 67 | self.bert_init_checkpoint = args.bert_init_checkpoint 68 | self.bert_max_seq_length = args.bert_max_seq_length 69 | self.bert_dim = args.bert_dim 70 | 71 | # build character/pos/chunk/tag/elmo vocab. 
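# ids 0 and 1 are reserved (pad/unk, and for tags 'O'/'X'), so fresh ids start from 2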
72 |         cid = self.unk_cid + 1
73 |         pid = self.unk_pid + 1
74 |         kid = self.unk_kid + 1
75 |         tid = self.xot_tid + 1
76 |         for line in open(args.train_path):
77 |             line = line.strip()
78 |             if not line: continue
79 |             tokens = line.split()
80 |             assert(len(tokens) == 4)
81 |             word = tokens[0]
82 |             pos = tokens[1]
83 |             chk = tokens[2]
84 |             tag = tokens[3]
85 |             # character vocab
86 |             for ch in word:
87 |                 if ch not in self.chr_vocab:
88 |                     self.chr_vocab[ch] = cid
89 |                     cid += 1
90 |             # elmo vocab(case sensitive)
91 |             if word not in self.elmo_vocab: self.elmo_vocab[word] = 1
92 |             else: self.elmo_vocab[word] += 1
93 |             # pos vocab
94 |             if pos not in self.pos_vocab:
95 |                 self.pos_vocab[pos] = pid
96 |                 pid += 1
97 |             # chunk vocab
98 |             if chk not in self.chk_vocab:
99 |                 self.chk_vocab[chk] = kid
100 |                 kid += 1
101 |             # tag, itag vocab
102 |             if tag not in self.tag_vocab:
103 |                 self.tag_vocab[tag] = tid
104 |                 self.itag_vocab[tid] = tag
105 |                 tid += 1
106 |         # write elmo vocab.
107 |         if self.elmo_vocab_path:
108 |             elmo_vocab_fd = open(self.elmo_vocab_path, 'w')
109 |             elmo_vocab_fd.write('<S>' + '\n')   # a bilm vocab file must begin with <S>, </S>, <UNK>
110 |             elmo_vocab_fd.write('</S>' + '\n')
111 |             elmo_vocab_fd.write('<UNK>' + '\n')
112 |             for word, freq in sorted(self.elmo_vocab.items(), key=lambda x: x[1], reverse=True):
113 |                 elmo_vocab_fd.write(word + '\n')
114 |             elmo_vocab_fd.close()
115 |             del self.elmo_vocab
116 | 
117 |         # build word embeddings and word vocab.
118 |         wrd_vocab_size = 0
119 |         for line in open(args.emb_path): wrd_vocab_size += 1
120 |         wrd_vocab_size += 2 # for pad, unk
121 |         sys.stderr.write('wrd_vocab_size = %s\n' % (wrd_vocab_size))
122 |         self.wrd_dim = args.wrd_dim
123 |         self.wrd_embeddings = np.zeros((wrd_vocab_size, self.wrd_dim))
124 |         # 0 id for padding
125 |         vector = np.array([float(0) for i in range(self.wrd_dim)])
126 |         self.wrd_embeddings[self.pad_wid] = vector
127 |         # 1 wid for unknown
128 |         vector = np.array([random() for i in range(self.wrd_dim)])
129 |         self.wrd_embeddings[self.unk_wid] = vector
130 |         wid = self.unk_wid + 1
131 |         for line in open(args.emb_path):
132 |             line = line.strip()
133 |             tokens = line.split()
134 |             word = tokens[0]
135 |             try: vector = np.array([float(val) for val in tokens[1:]])
136 |             except ValueError: continue # skip malformed lines
137 |             if len(vector) != self.wrd_dim: continue
138 |             if self.lowercase: word = word.lower()
139 |             self.wrd_embeddings[wid] = vector
140 |             self.wrd_vocab[word] = wid
141 |             wid += 1
142 | 
143 |     def get_wid(self, word):
144 |         if self.lowercase: word = word.lower()
145 |         if word in self.wrd_vocab:
146 |             return self.wrd_vocab[word]
147 |         return self.unk_wid
148 | 
149 |     def get_cid(self, ch):
150 |         if ch in self.chr_vocab:
151 |             return self.chr_vocab[ch]
152 |         return self.unk_cid
153 | 
154 |     def get_pid(self, pos):
155 |         if pos in self.pos_vocab:
156 |             return self.pos_vocab[pos]
157 |         return self.unk_pid
158 | 
159 |     def get_kid(self, chk):
160 |         if chk in self.chk_vocab:
161 |             return self.chk_vocab[chk]
162 |         return self.unk_kid
163 | 
164 |     def get_tid(self, tag):
165 |         if tag in self.tag_vocab:
166 |             return self.tag_vocab[tag]
167 |         return self.oot_tid
168 | 
169 |     def get_tag(self, tid):
170 |         if tid in self.itag_vocab:
171 |             return self.itag_vocab[tid]
172 |         return self.oot_tag
173 | 
174 | if __name__ == '__main__':
175 |     parser = argparse.ArgumentParser()
176 |     parser.add_argument('--emb_path', type=str, help='path to a file of word embedding vector(.txt)', required=True)
177 |     parser.add_argument('--wrd_dim', type=int, help='embedding vector dimension', required=True)
178 |     parser.add_argument('--train_path', type=str, 
help='path to a train-dev file to build vocabularies', required=True)
179 |     parser.add_argument('--lowercase', type=str, help='apply lower case for word embedding', default='True')
180 |     parser.add_argument('--elmo_vocab_path', type=str, help='path to elmo vocab file(write)', default='')
181 |     parser.add_argument('--elmo_options_path', type=str, help='path to elmo options file', default='')
182 |     parser.add_argument('--elmo_weight_path', type=str, help='path to elmo weight file', default='')
183 |     parser.add_argument('--bert_config_path', type=str, help='path to bert config file', default='')
184 |     parser.add_argument('--bert_vocab_path', type=str, help='path to bert vocab file', default='')
185 |     parser.add_argument('--bert_do_lower_case', type=str, help='apply lower case for bert', default='False')
186 |     parser.add_argument('--bert_init_checkpoint', type=str, help='path to bert init checkpoint', default='')
187 |     parser.add_argument('--bert_max_seq_length', type=int, help='maximum total input sequence length after WordPiece tokenization.', default=180)
188 |     parser.add_argument('--bert_dim', type=int, help='bert output dimension size', default=1024)
189 |     args = parser.parse_args()
190 |     embvec = EmbVec(args)
191 |     pkl.dump(embvec, open(args.emb_path + '.pkl', 'wb'))
192 | 
193 |     # print all vocab for inference by C++.
194 |     # 1. wrd_vocab
195 |     print('# wrd_vocab', len(embvec.wrd_vocab))
196 |     for word, wid in embvec.wrd_vocab.items():
197 |         print(word, wid)
198 |     # 2. chr_vocab
199 |     print('# chr_vocab', len(embvec.chr_vocab))
200 |     for ch, cid in embvec.chr_vocab.items():
201 |         print(ch, cid)
202 |     # 3. pos_vocab
203 |     print('# pos_vocab', len(embvec.pos_vocab))
204 |     for pos, pid in embvec.pos_vocab.items():
205 |         print(pos, pid)
206 |     # 4. chk_vocab
207 |     print('# chk_vocab', len(embvec.chk_vocab))
208 |     for chk, kid in embvec.chk_vocab.items():
209 |         print(chk, kid)
210 |     # 5. tag_vocab
211 |     print('# tag_vocab', len(embvec.tag_vocab))
212 |     for tag, tid in embvec.tag_vocab.items():
213 |         print(tag, tid)
214 | 
--------------------------------------------------------------------------------
/inference/python/www/etagger_dm.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import os
4 | import logging
5 | from logging.handlers import RotatingFileHandler
6 | import signal
7 | import time
8 | import math
9 | 
10 | import tornado.web
11 | import tornado.ioloop
12 | import tornado.autoreload
13 | import tornado.web
14 | import tornado.httpserver
15 | import tornado.process
16 | import tornado.autoreload as autoreload
17 | from tornado.options import define, options
18 | 
19 | ###############################################################################################
20 | # etagger
21 | ## do not `import tensorflow` before forking processes.
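## (forked children can hang or crash if the TF runtime was already
##  initialized in the parent; see the issue below. each child imports
##  tensorflow itself inside initialize().)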
22 | ## see : https://github.com/tensorflow/tensorflow/issues/5448
23 | path = os.path.dirname(os.path.abspath(__file__)) + '/lib'
24 | sys.path.append(path)
25 | from embvec import EmbVec
26 | from config import Config
27 | # etagger arguments
28 | define('emb_path', default='', help='path to word embedding vector + vocab(.pkl)', type=str)
29 | define('emb_class', default='glove', help='class of embedding(glove, elmo, bert, bert+elmo)', type=str)
30 | define('config_path', default='', help='path to config.json', type=str)
31 | define('wrd_dim', default=100, help='dimension of word embedding vector', type=int)
32 | define('word_length', default=15, help='max word length', type=int)
33 | define('frozen_path', default='', help='path to frozen graph', type=str)
34 | define('restore', default='', help='dummy path for config', type=str)
35 | ###############################################################################################
36 | 
37 | ###############################################################################################
38 | # nlp : spacy
39 | import spacy
40 | ###############################################################################################
41 | 
42 | from handlers.index import IndexHandler, HCheckHandler, EtaggerHandler, EtaggerTestHandler
43 | define('port', default=8897, help='run on the given port', type=int)
44 | define('debug', default=True, help='run on debug mode', type=bool)
45 | define('process', default=3, help='number of processes for service mode', type=int)
46 | 
47 | 
48 | log = logging.getLogger('tornado.application')
49 | 
50 | def setupAppLogger():
51 |     fmtStr = '%(asctime)s - %(levelname)s - %(module)s - %(message)s'
52 |     formatter = logging.Formatter(fmt=fmtStr)
53 | 
54 |     cdir = os.path.dirname(os.path.abspath(options.log_file_prefix))
55 |     logfile = cdir + '/' + 'application.log'
56 | 
57 |     rotatingHandler = RotatingFileHandler(logfile, 'a', options.log_file_max_size, options.log_file_num_backups)
58 |     rotatingHandler.setFormatter(formatter)
59 | 
60 |     if options.logging != 'none':
61 |         log.setLevel(getattr(logging, options.logging.upper()))
62 |     else:
63 |         log.setLevel(logging.ERROR)
64 | 
65 |     log.propagate = False
66 |     log.addHandler(rotatingHandler)
67 | 
68 |     return log
69 | 
70 | class Application(tornado.web.Application):
71 |     def __init__(self):
72 |         settings = dict(
73 |             static_path = os.path.join(os.path.dirname(__file__), 'static'),
74 |             template_path = os.path.join(os.path.dirname(__file__), 'templates'),
75 |             autoescape = None,
76 |             debug = options.debug,
77 |             gzip = True
78 |         )
79 | 
80 |         handlers = [
81 |             (r'/', IndexHandler),
82 |             (r'/_hcheck.hdn', HCheckHandler),
83 |             (r'/etagger', EtaggerHandler),
84 |             (r'/etaggertest', EtaggerTestHandler),
85 |         ]
86 | 
87 |         tornado.web.Application.__init__(self, handlers, **settings)
88 |         autoreload.add_reload_hook(self.finalize)
89 | 
90 |         self.log = setupAppLogger()
91 |         ppid = os.getpid()
92 |         self.ppid = ppid
93 |         self.log.info('initialize parent process[%s] ... done' % (ppid))
94 | 
95 |         ###############################################################################################
96 |         # create etagger config only once
97 |         self.config = Config(options, is_training=False, emb_class=options.emb_class, use_crf=True)
98 |         self.log.info('initialize config on parent process[%s] ... done' % (ppid))
99 |         # create nlp(spacy) only once
100 |         self.nlp = spacy.load('en')
101 |         self.log.info('initialize spacy on parent process[%s] ... 
done' % (ppid)) 102 | ############################################################################################### 103 | 104 | log.info('http start...') 105 | 106 | def initialize(self) : 107 | ############################################################################################### 108 | # tensorflow should be imported here for child process. 109 | # see : https://github.com/tensorflow/tensorflow/issues/5448 110 | import tensorflow as tf 111 | ## for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369 112 | tf.contrib.rnn 113 | ## for QRNN 114 | try: import qrnn 115 | except: sys.stderr.write('import qrnn, failed\n') 116 | ############################################################################################### 117 | 118 | pid = os.getpid() 119 | self.log.info('initialize per child process[%s] ...' % (pid)) 120 | ############################################################################################### 121 | # loading frozen model for each child process. 122 | self.etagger = {} 123 | graph = self.load_frozen_graph(tf, options.frozen_path) 124 | gpu_ops = tf.GPUOptions() 125 | session_conf = tf.ConfigProto(allow_soft_placement=True, 126 | log_device_placement=False, 127 | gpu_options=gpu_ops, 128 | inter_op_parallelism_threads=0, 129 | intra_op_parallelism_threads=0) 130 | sess = tf.Session(graph=graph, config=session_conf) 131 | m = {} 132 | m['sess'] = sess 133 | m['graph'] = graph 134 | self.etagger[pid] = m 135 | ############################################################################################### 136 | self.log.info('initialize per child process[%s] ... done' % (pid)) 137 | 138 | def load_frozen_graph(self, tf, frozen_graph_filename, prefix='prefix'): 139 | with tf.gfile.GFile(frozen_graph_filename, "rb") as f: 140 | graph_def = tf.GraphDef() 141 | graph_def.ParseFromString(f.read()) 142 | with tf.Graph().as_default() as graph: 143 | tf.import_graph_def( 144 | graph_def, 145 | input_map=None, 146 | return_elements=None, 147 | op_dict=None, 148 | producer_op_list=None, 149 | name=prefix, 150 | ) 151 | return graph 152 | 153 | def finalize(self): 154 | # finalize resources 155 | self.log.info('finalize resources...') 156 | ## finalize something.... 157 | for pid, m in self.etagger.items() : 158 | sess = m['sess'] 159 | sess.close() 160 | 161 | log.info('Close logger...') 162 | x = list(log.handlers) 163 | for i in x: 164 | log.removeHandler(i) 165 | i.flush() 166 | i.close() 167 | self.log.info('finalize resources... done') 168 | 169 | def main(): 170 | tornado.options.parse_command_line() 171 | 172 | ''' 173 | # you can prefork tornado before creating application. 
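    # (bind_sockets() binds the listening sockets once in the parent;
    #  fork_processes() then forks the workers, and every worker serves
    #  the same pre-bound sockets via add_sockets().)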
174 |     # code snippet:
175 |     sockets = tornado.netutil.bind_sockets(options.port)
176 |     tornado.process.fork_processes(options.process)
177 |     application = Application()
178 |     httpServer = tornado.httpserver.HTTPServer(application, no_keep_alive=True)
179 |     httpServer.add_sockets(sockets)
180 |     '''
181 | 
182 |     application = Application()
183 |     httpServer = tornado.httpserver.HTTPServer(application, no_keep_alive=True)
184 |     if options.debug == True :
185 |         httpServer.listen(options.port)
186 |         application.initialize()
187 |     else :
188 |         httpServer.bind(options.port)
189 |         if options.process == 0 :
190 |             httpServer.start(0) # forks sub-processes, one per CPU core
191 |         else :
192 |             if options.process < 0 :
193 |                 options.process = 1
194 |             httpServer.start(options.process) # forks the given number of sub-processes
195 |         pid = os.getpid()
196 |         if pid != application.ppid :
197 |             application.initialize()
198 | 
199 |     MAX_WAIT_SECONDS_BEFORE_SHUTDOWN = 3
200 | 
201 |     def sig_handler(sig, frame):
202 |         log.warning('Caught signal: %s', sig)
203 |         tornado.ioloop.IOLoop.instance().add_callback(shutdown)
204 | 
205 |     def shutdown():
206 |         log.info('Stopping http server')
207 |         httpServer.stop()
208 | 
209 |         log.info('Will shutdown in %s seconds ...', MAX_WAIT_SECONDS_BEFORE_SHUTDOWN)
210 |         io_loop = tornado.ioloop.IOLoop.instance()
211 | 
212 |         deadline = time.time() + MAX_WAIT_SECONDS_BEFORE_SHUTDOWN
213 | 
214 |         def stop_loop():
215 |             now = time.time()
216 |             if now < deadline and (io_loop._callbacks or io_loop._timeouts):
217 |                 io_loop.add_timeout(now + 1, stop_loop)
218 |             else:
219 |                 io_loop.stop()
220 |                 log.info('Shutdown')
221 | 
222 |         stop_loop()
223 | 
224 |     signal.signal(signal.SIGTERM, sig_handler)
225 |     signal.signal(signal.SIGINT, sig_handler)
226 | 
227 |     tornado.ioloop.IOLoop.instance().start()
228 | 
229 |     log.info('Exit...')
230 | 
231 | if __name__ == '__main__':
232 |     main()
233 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import numpy as np
4 | import pickle as pkl
5 | import json
6 | 
7 | class Config:
8 | 
9 |     def __init__(self, args, is_training=True, emb_class='glove', use_crf=True):
10 |         """Set all parameters for model.
11 | 
12 |         Args:
13 |             args: args from train.py, inference.py.
14 |             is_training: True for training(from train.py), False for inference(inference.py)
15 |             emb_class: class of embedding, glove | elmo | bert | bert+elmo.
16 |             use_crf: if True, use crf decoder; if False, bypass it.
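
        Example (mirroring the call in the inference script):
            config = Config(args, is_training=False, emb_class='glove', use_crf=True)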
17 |         """
18 |         config = self.__load_config(args)
19 |         self.emb_path = args.emb_path
20 |         self.embvec = pkl.load(open(self.emb_path, 'rb')) # resources(glove, vocab, path, etc)
21 |         self.wrd_dim = args.wrd_dim # dim of word embedding(glove)
22 |         self.chr_dim = config['chr_dim'] # dim of character embedding
23 |         self.pos_dim = config['pos_dim'] # dim of part of speech embedding
24 |         self.chk_dim = config['chk_dim'] # dim of chunk embedding
25 |         self.class_size = len(self.embvec.tag_vocab) # number of classes(tags)
26 |         self.word_length = args.word_length # maximum character size of word for convolution
27 |         self.restore = args.restore # checkpoint path if available
28 |         self.use_crf = use_crf # use crf decoder or not
29 |         self.emb_class = emb_class # class of embedding(glove, elmo, bert, bert+elmo)
30 | 
31 |         self.keep_prob = config['keep_prob'] # keep probability for dropout
32 |         self.chr_conv_type = config['chr_conv_type'] # conv1d | conv2d
33 |         self.filter_sizes = config['filter_sizes'] # filter sizes
34 |         self.num_filters = config['num_filters'] # number of filters
35 |         self.highway_used = config['highway_used'] # use highway network on the concatenated input
36 |         self.rnn_used = config['rnn_used'] # use rnn layer or not
37 |         self.rnn_num_layers = config['rnn_num_layers'] # number of RNN layers
38 |         self.rnn_type = config['rnn_type'] # normal | fused | qrnn
39 |         self.rnn_size = config['rnn_size'] # size of RNN hidden unit
40 |         self.tf_used = config['tf_used'] # use transformer encoder layer or not
41 |         self.tf_num_layers = config['tf_num_layers'] # number of layers for transformer encoder
42 |         self.tf_keep_prob = config['tf_keep_prob'] # keep probability for transformer encoder
43 |         self.tf_mh_num_heads = config['tf_mh_num_heads'] # number of heads for multi head attention
44 |         self.tf_mh_num_units = config['tf_mh_num_units'] # Q,K,V dimension for multi head attention
45 |         self.tf_mh_keep_prob = config['tf_mh_keep_prob'] # keep probability for multi head attention
46 |         self.tf_ffn_kernel_size = config['tf_ffn_kernel_size'] # conv1d kernel size for feed forward net
47 |         self.tf_ffn_keep_prob = config['tf_ffn_keep_prob'] # keep probability for feed forward net
48 | 
49 |         self.starter_learning_rate = config['starter_learning_rate'] # default learning rate
50 |         self.num_train_steps = 0 # number of total training steps, assigned by update()
51 |         self.num_warmup_epoch = config['num_warmup_epoch'] # number of warmup epochs
52 |         self.num_warmup_steps = 0 # number of warmup steps, assigned by update()
53 |         self.decay_steps = config['decay_steps']
54 |         self.decay_rate = config['decay_rate']
55 |         self.clip_norm = config['clip_norm']
56 |         if self.tf_used: # modified for transformer
57 |             self.starter_learning_rate = config['starter_learning_rate_for_tf']
58 |         if self.rnn_type == 'qrnn': # modified for QRNN
59 |             self.qrnn_size = config['qrnn_size'] # size of QRNN hidden units(number of filters)
60 |             self.qrnn_filter_size = config['qrnn_filter_size'] # size of filter for QRNN
61 |             self.rnn_num_layers = config['qrnn_num_layers']
62 | 
63 |         self.is_training = is_training
64 |         if self.is_training:
65 |             self.epoch = args.epoch
66 |             self.batch_size = args.batch_size
67 |             self.checkpoint_dir = args.checkpoint_dir
68 |             self.summary_dir = args.summary_dir
69 | 
70 |         '''for CRZ without chk
71 |         self.chk_dim = 10
72 |         self.highway_used = False
73 |         '''
74 |         '''for CRZ with chk
75 |         self.chk_dim = 64
76 |         self.highway_used = True
77 |         '''
78 | 
79 |         if 'elmo' in self.emb_class:
80 |             from bilm import Batcher, 
BidirectionalLanguageModel 81 | self.word_length = config['elmo_word_length'] # replace to fixed word length for the pre-trained elmo : 'max_characters_per_token' 82 | self.elmo_batcher = Batcher(self.embvec.elmo_vocab_path, self.word_length) # map text to character ids 83 | self.elmo_bilm = BidirectionalLanguageModel(self.embvec.elmo_options_path, self.embvec.elmo_weight_path) # biLM graph 84 | self.elmo_keep_prob = config['elmo_keep_prob'] 85 | '''for KOR 86 | self.rnn_size = 250 87 | ''' 88 | if 'bert' in self.emb_class: 89 | from bert import modeling 90 | from bert import tokenization 91 | self.bert_config = modeling.BertConfig.from_json_file(self.embvec.bert_config_path) 92 | self.bert_tokenizer = tokenization.FullTokenizer( 93 | vocab_file=self.embvec.bert_vocab_path, do_lower_case=self.embvec.bert_do_lower_case) 94 | self.bert_init_checkpoint = self.embvec.bert_init_checkpoint 95 | self.bert_max_seq_length = self.embvec.bert_max_seq_length 96 | self.bert_dim = self.embvec.bert_dim 97 | self.bert_keep_prob = config['bert_keep_prob'] 98 | self.use_bert_optimization = config['use_bert_optimization'] 99 | self.num_warmup_epoch = config['num_warmup_epoch_for_bert'] 100 | '''for KOR, CRZ 101 | self.rnn_size = 256 102 | self.starter_learning_rate = 5e-5 103 | self.num_warmup_epoch = 1 104 | self.decay_steps = 5000 105 | ''' 106 | '''for KOR(CLOVA NER) 107 | self.pos_dim = 100 108 | self.starter_learning_rate = 5e-5 109 | self.num_warmup_epoch = 3 110 | self.decay_rate = 1.0 111 | ''' 112 | 113 | def __load_config(self, args): 114 | """Load config from file. 115 | """ 116 | try: 117 | with open(args.config_path, 'r', encoding='utf-8') as f: 118 | config = json.load(f) 119 | except Exception as e: 120 | config = dict() 121 | return config 122 | 123 | def update(self, data): 124 | """Update num_train_steps, num_warmup_steps after reading training data 125 | 126 | Args: 127 | data: an instance of Input class, training data. 128 | """ 129 | if not self.is_training: return False 130 | self.num_train_steps = int((data.num_examples / self.batch_size) * self.epoch) 131 | self.num_warmup_steps = self.num_warmup_epoch * int(data.num_examples / self.batch_size) 132 | if self.num_warmup_steps == 0: self.num_warmup_steps = 1 # prevent dividing by zero 133 | return True 134 | 135 | # ----------------------------------------------------------------------------- 136 | # utility 137 | # ----------------------------------------------------------------------------- 138 | 139 | def logit_to_tags(self, logit, length): 140 | """Convert logit to tags. 141 | 142 | Args: 143 | logit: [sentence_length, class_size] 144 | length: int 145 | Returns: 146 | tag sequence(size length) 147 | """ 148 | logit = logit[0:length] 149 | # [length] 150 | pred_list = np.argmax(logit, 1).tolist() 151 | tags = [] 152 | for tid in pred_list: 153 | tag = self.embvec.get_tag(tid) 154 | tags.append(tag) 155 | return tags 156 | 157 | def logit_indices_to_tags(self, logit_indices, length): 158 | """Convert logit_indices to tags. 159 | 160 | Args: 161 | logit_indices: [sentence_length] 162 | length: int 163 | Returns: 164 | tag sequence(size length) 165 | """ 166 | pred_list = logit_indices[0:length] 167 | tags = [] 168 | for tid in pred_list: 169 | tag = self.embvec.get_tag(tid) 170 | tags.append(tag) 171 | return tags 172 | 173 | def logits_indices_to_tags_seq(self, logits_indices, lengths): 174 | """Convert logits_indices to sequence of tags. 
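        (batch version of logit_indices_to_tags(), applied row by row)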
175 | 
176 |         Args:
177 |             logits_indices: [batch_size, sentence_length]
178 |             lengths: [batch_size]
179 |         Returns:
180 |             sequence of tags
181 |         """
182 |         tags_seq = []
183 |         for logit_indices, length in zip(logits_indices, lengths):
184 |             tags = self.logit_indices_to_tags(logit_indices, length)
185 |             tags_seq.append(tags)
186 |         return tags_seq
187 | 
--------------------------------------------------------------------------------
/inference/cc/src/Etagger.cc:
--------------------------------------------------------------------------------
1 | #include "Etagger.h"
2 | 
3 | /*
4 |  * public methods
5 |  */
6 | 
7 | Etagger::Etagger(string frozen_graph_fn, string vocab_fn, int word_length, bool lowercase, bool is_memmapped, int num_threads)
8 | {
9 |   /*
10 |    * Args:
11 |    *   frozen_graph_fn: path to a file of frozen graph.
12 |    *   vocab_fn: path to a vocab file.
13 |    *   word_length: max character size of word. ex) 15
14 |    *   lowercase: true if vocab file was all lowercased, otherwise false.
15 |    *   is_memmapped: true if frozen graph was memmapped, otherwise false.
16 |    *   num_threads: number of threads for tensorflow. 0 for all cores, n for n cores.
17 |    */
18 | 
19 |   this->util = new TFUtil();
20 |   this->sess = NULL;
21 |   if( is_memmapped ) {
22 |     tensorflow::MemmappedEnv* memmapped_env = this->util->CreateMemmappedEnv(frozen_graph_fn);
23 |     this->sess = this->util->CreateSession(memmapped_env, num_threads);
24 |     TF_CHECK_OK(this->util->LoadFrozenMemmappedModel(memmapped_env, this->sess));
25 |   } else {
26 |     this->sess = this->util->CreateSession(NULL, num_threads);
27 |     TF_CHECK_OK(this->util->LoadFrozenModel(this->sess, frozen_graph_fn));
28 |   }
29 |   cerr << "Loading graph and creating session ... done" << endl;
30 | 
31 |   this->config = new Config(word_length);
32 |   cerr << "Loading Config ... done" << endl;
33 |   cerr << "Loading Vocab From " << vocab_fn;
34 |   this->vocab = new Vocab(vocab_fn, lowercase);
35 |   cerr << " ... done" << endl;
36 |   this->config->SetClassSize(this->vocab->GetTagVocabSize());
37 |   cerr << "Class size: " << this->config->GetClassSize() << endl;
38 | }
39 | 
40 | int Etagger::Analyze(vector<string>& bucket)
41 | {
42 |   /*
43 |    * Args:
44 |    *   bucket: list of 'word pos chk tag'
45 |    *
46 |    * Returns:
47 |    *   number of tokens.
48 |    *   -1 if failed.
49 |    *   analyzed results are saved to bucket itself.
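   *   ex) 'island NN O O' -> 'island NN O O O' (hypothetical output)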
50 |    *   bucket: list of 'word pos chk tag predict'
51 |    */
52 |   Input input = Input(this->config, this->vocab, bucket);
53 |   int max_sentence_length = input.GetMaxSentenceLength();
54 |   tensorflow::Tensor* sentence_word_ids = input.GetSentenceWordIds();
55 |   tensorflow::Tensor* sentence_wordchr_ids = input.GetSentenceWordChrIds();
56 |   tensorflow::Tensor* sentence_pos_ids = input.GetSentencePosIds();
57 |   tensorflow::Tensor* sentence_chk_ids = input.GetSentenceChkIds();
58 |   tensorflow::Tensor* sentence_length = input.GetSentenceLength();
59 |   tensorflow::Tensor* is_train = input.GetIsTrain();
60 | #ifdef DEBUG
61 |   cout << "[word ids]" << endl;
62 |   auto data_word_ids = sentence_word_ids->flat<int>().data();
63 |   for( int i = 0; i < max_sentence_length; i++ ) {
64 |     cout << data_word_ids[i] << " ";
65 |   }
66 |   cout << endl;
67 |   cout << "[wordchr ids]" << endl;
68 |   auto data_wordchr_ids = sentence_wordchr_ids->flat<int>().data();
69 |   int word_length = this->config->GetWordLength();
70 |   for( int i = 0; i < max_sentence_length; i++ ) {
71 |     for( int j = 0; j < word_length; j++ ) {
72 |       cout << data_wordchr_ids[i*word_length + j] << " ";
73 |     }
74 |     cout << endl;
75 |   }
76 |   cout << "[pos ids]" << endl;
77 |   auto data_pos_ids = sentence_pos_ids->flat<int>().data();
78 |   for( int i = 0; i < max_sentence_length; i++ ) {
79 |     cout << data_pos_ids[i] << " ";
80 |   }
81 |   cout << endl;
82 |   cout << "[chk ids]" << endl;
83 |   auto data_chk_ids = sentence_chk_ids->flat<int>().data();
84 |   for( int i = 0; i < max_sentence_length; i++ ) {
85 |     cout << data_chk_ids[i] << " ";
86 |   }
87 |   cout << endl;
88 |   cout << "[sentence length]" << endl;
89 |   auto data_sentence_length = sentence_length->flat<int>().data();
90 |   cout << *data_sentence_length << endl;
91 |   cout << "[is_train]" << endl;
92 |   auto data_is_train = is_train->flat<bool>().data();
93 |   cout << *data_is_train << endl;
94 | 
95 |   cout << endl;
96 | #endif
97 |   tensor_dict feed_dict = {
98 |     {"input_data_word_ids", *sentence_word_ids},
99 |     {"input_data_wordchr_ids", *sentence_wordchr_ids},
100 |     {"input_data_pos_ids", *sentence_pos_ids},
101 |     {"input_data_chk_ids", *sentence_chk_ids},
102 |     {"sentence_length", *sentence_length},
103 |     {"is_train", *is_train},
104 |   };
105 |   std::vector<tensorflow::Tensor> outputs;
106 |   tensorflow::Status run_status = this->sess->Run(feed_dict, {"logits_indices"}, {}, &outputs);
107 |   if( !run_status.ok() ) {
108 |     cerr << run_status.error_message() << endl;
109 |     return -1;
110 |   }
111 |   /*
112 |   cout << "logits_indices " << outputs[0].DebugString() << endl;
113 |   */
114 |   int class_size = this->config->GetClassSize();
115 |   tensorflow::TTypes<int>::Flat logits_indices_flat = outputs[0].flat<int>();
116 |   for( int i = 0; i < max_sentence_length; i++ ) {
117 |     int max_idx = logits_indices_flat(i);
118 |     string tag = this->vocab->GetTag(max_idx);
119 |     bucket[i] = bucket[i] + " " + tag;
120 |   }
121 |   return max_sentence_length;
122 | }
123 | 
124 | Etagger::~Etagger()
125 | {
126 |   delete this->vocab;
127 |   delete this->config;
128 |   this->util->DestroySession(this->sess);
129 |   delete this->util;
130 | }
131 | 
132 | /*
133 |  * public methods for C
134 |  */
135 | 
136 | extern "C" {
137 | 
138 | Etagger* initialize(const char* frozen_graph_fn,
139 |                     const char* vocab_fn,
140 |                     int word_length,
141 |                     int lowercase,
142 |                     int is_memmapped,
143 |                     int num_threads)
144 | {
145 |   /*
146 |    * Args:
147 |    *   frozen_graph_fn: path to a file of frozen graph.
148 |    *   vocab_fn: path to a vocab file.
149 |    *   word_length: max character size of word. ex) 15
150 |    *   lowercase: 1 if vocab file was all lowercased, otherwise 0.
151 |    *   is_memmapped: 1 if frozen graph was memmapped, otherwise 0.
152 |    *   num_threads: number of threads for tensorflow. 0 for all cores, n for n cores.
153 |    *
154 |    * Python:
155 |    *   import ctypes as c
156 |    *   so_path = 'path-to/lib' + '/' + 'libetagger.so'
157 |    *   libetagger = c.cdll.LoadLibrary(so_path)
158 |    *
159 |    *   frozen_graph_fn = c.c_char_p(b'path-to/ner_frozen.pb')
160 |    *   vocab_fn = c.c_char_p(b'path-to/vocab.txt')
161 |    *   word_length = c.c_int(15)
162 |    *   lowercase = c.c_int(1)
163 |    *   is_memmapped = c.c_int(1)
164 |    *   num_threads = c.c_int(0)
165 |    *   etagger = libetagger.initialize(frozen_graph_fn, vocab_fn, word_length, lowercase, is_memmapped, num_threads)
166 |    */
167 |   bool b_lowercase = false;
168 |   if( lowercase ) b_lowercase = true;
169 |   bool b_is_memmapped = false;
170 |   if( is_memmapped ) b_is_memmapped = true;
171 |   return new Etagger(frozen_graph_fn, vocab_fn, word_length, b_lowercase, b_is_memmapped, num_threads);
172 | }
173 | 
174 | static void split(string s, vector<string>& tokens)
175 | {
176 |   istringstream iss(s);
177 |   for( string ts; iss >> ts; )
178 |     tokens.push_back(ts);
179 | }
180 | 
181 | int analyze(Etagger* etagger, struct result_obj* robj, int max)
182 | {
183 |   /*
184 |    * Args:
185 |    *   etagger: an instance of Etagger, i.e., a handler.
186 |    *   robj: list of result_obj.
187 |    *   max: max size of robj.
188 |    *
189 |    * Python:
190 |    *   class Result( c.Structure ):
191 |    *       _fields_ = [('word', c.c_char * MAX_WORD ),
192 |    *                   ('pos', c.c_char * MAX_POS ),
193 |    *                   ('chk', c.c_char * MAX_CHK ),
194 |    *                   ('tag', c.c_char * MAX_TAG ),
195 |    *                   ('predict', c.c_char * MAX_TAG )]
196 |    *
197 |    *   bucket = build_bucket(nlp, line)
198 |    *   # ex) bucket
199 |    *   #   word pos chk tag
200 |    *   #   ...
201 |    *   #   jeju NNP O B-GPE
202 |    *   #   island NN O O
203 |    *   #   ...
204 |    *   max_sentence_length = len(bucket)
205 |    *   robj = (Result * max_sentence_length)()
206 |    *   # fill robj from bucket.
207 |    *   for i in range(max_sentence_length):
208 |    *       tokens = bucket[i].split()
209 |    *       robj[i].word = tokens[0].encode('utf-8')
210 |    *       robj[i].pos = tokens[1].encode('utf-8')
211 |    *       robj[i].chk = tokens[2].encode('utf-8')
212 |    *       robj[i].tag = tokens[3].encode('utf-8')
213 |    *       robj[i].predict = b'O'
214 |    *   c_max_sentence_length = c.c_int(max_sentence_length)
215 |    *   ret = libetagger.analyze(etagger, c.byref(robj), c_max_sentence_length)
216 |    *   out = []
217 |    *   for r in robj:
218 |    *       out.append([r.word.decode('utf-8'),
219 |    *                   r.pos.decode('utf-8'),
220 |    *                   r.chk.decode('utf-8'),
221 |    *                   r.tag.decode('utf-8'),
222 |    *                   r.predict.decode('utf-8')])
223 |    *
224 |    * Returns:
225 |    *   number of tokens.
226 |    *   -1 if failed.
227 |    *   analyzed results are saved to robj itself.
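   *   (robj must be pre-allocated by the caller with `max` entries; only
   *    the `predict` field of each entry is overwritten here.)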
228 |    */
229 |   vector<string> bucket;
230 | 
231 |   // build bucket from robj
232 |   for( int i = 0; i < max; i++ ) {
233 |     string s = string(robj[i].word) + " " +
234 |                string(robj[i].pos) + " " +
235 |                string(robj[i].chk) + " " +
236 |                string(robj[i].tag);
237 |     bucket.push_back(s);
238 |   }
239 |   // bucket: list of 'word pos chk tag'
240 | 
241 |   int ret = etagger->Analyze(bucket);
242 |   if( ret < 0 ) return -1;
243 |   // bucket: list of 'word pos chk tag predict'
244 | 
245 |   // assign predict to robj
246 |   for( int i = 0; i < max; i++ ) {
247 |     vector<string> tokens;
248 |     split(bucket[i], tokens);
249 |     string predict = tokens[4]; // last one
250 |     strncpy(robj[i].predict, predict.c_str(), MAX_TAG);
251 |   }
252 | 
253 |   return ret;
254 | }
255 | 
256 | void finalize(Etagger* etagger)
257 | {
258 |   /*
259 |    * Args:
260 |    *   etagger: an instance of Etagger, i.e., a handler.
261 |    * Python:
262 |    *   libetagger.finalize(etagger)
263 |    */
264 |   if( etagger ) {
265 |     delete etagger;
266 |   }
267 | }
268 | 
269 | }
270 | 
--------------------------------------------------------------------------------