├── UNF ├── __init__.py ├── models │ ├── __init__.py │ ├── model_util.py │ ├── model.py │ ├── lstm_crf_predictor.py │ ├── model_loader.py │ ├── model_trace.py │ ├── fasttext.py │ ├── textcnn.py │ ├── self_attention.py │ ├── predictor.py │ ├── lstm_crf.py │ ├── leam.py │ └── dpcnn.py ├── modules │ ├── __init__.py │ ├── embedding │ │ ├── __init__.py │ │ └── embedding.py │ ├── encoder │ │ ├── __init__.py │ │ ├── full_connect.py │ │ ├── lstm_encoder.py │ │ ├── self_attention_encoder.py │ │ └── cnn_maxpool.py │ ├── base_type.py │ ├── module_util.py │ └── decoder │ │ └── crf.py ├── training │ ├── __init__.py │ ├── optimizer.py │ ├── loss.py │ ├── learner_loader.py │ └── metric.py ├── web_server │ ├── static │ │ ├── css │ │ │ ├── style.css │ │ │ ├── jquery.fileupload-ui-noscript.css │ │ │ ├── demo-ie8.css │ │ │ ├── jquery.fileupload-noscript.css │ │ │ ├── jquery.fileupload.css │ │ │ ├── jquery.fileupload-ui.css │ │ │ └── demo.css │ │ ├── img │ │ │ ├── 123.png │ │ │ ├── loading.gif │ │ │ └── progressbar.gif │ │ ├── fonts │ │ │ ├── glyphicons-halflings-regular.eot │ │ │ ├── glyphicons-halflings-regular.ttf │ │ │ ├── glyphicons-halflings-regular.woff │ │ │ └── glyphicons-halflings-regular.woff2 │ │ └── js │ │ │ ├── main.js │ │ │ ├── cors │ │ │ ├── jquery.xdr-transport.js │ │ │ └── jquery.postmessage-transport.js │ │ │ ├── jquery.fileupload-audio.js │ │ │ ├── jquery.fileupload-video.js │ │ │ ├── app.js │ │ │ ├── jquery.fileupload-validate.js │ │ │ ├── jquery.fileupload-jquery-ui.js │ │ │ ├── jquery.fileupload-process.js │ │ │ ├── jquery.iframe-transport.js │ │ │ ├── jquery.fileupload-image.js │ │ │ └── vendor │ │ │ └── jquery.ui.widget.js │ ├── run.py │ ├── server.py │ └── templates │ │ └── web.html ├── trace │ ├── parse_vocab.py │ ├── CMakeLists.txt │ ├── lib │ │ ├── tokenizer.h │ │ ├── tokenizer.cc │ │ ├── string_util.h │ │ └── string_util.cc │ └── predict.cc ├── test │ └── data │ │ ├── test_tokenizer.py │ │ └── test_data_loader.py ├── data │ ├── __init__.py │ ├── field.py 
│ ├── tokenizer.py │ └── data_loader.py ├── train_flow.py ├── conf │ ├── fasttext_conf.py │ ├── leam_conf.py │ ├── dpcnn_conf.py │ ├── textcnn_conf.py │ ├── selfattention_conf.py │ └── lstm_crf_conf.py ├── trace.py ├── score_flow.py └── common_util │ └── ner_p_r_f_cal.py ├── pic ├── make.png ├── cmake.png ├── module.png ├── framework.png ├── web_demo.png ├── tensorboard1.png └── tensorboard2.png ├── requirements.txt ├── .gitignore ├── README.md └── LICENSE /UNF/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /UNF/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /UNF/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /UNF/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /UNF/modules/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /UNF/modules/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pic/make.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/pic/make.png -------------------------------------------------------------------------------- /pic/cmake.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/waterzxj/UNF/HEAD/pic/cmake.png -------------------------------------------------------------------------------- /pic/module.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/pic/module.png -------------------------------------------------------------------------------- /pic/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/pic/framework.png -------------------------------------------------------------------------------- /pic/web_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/pic/web_demo.png -------------------------------------------------------------------------------- /pic/tensorboard1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/pic/tensorboard1.png -------------------------------------------------------------------------------- /pic/tensorboard2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/pic/tensorboard2.png -------------------------------------------------------------------------------- /UNF/training/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Adam, SGD, SparseAdam, Adagrad, Adadelta, RMSprop -------------------------------------------------------------------------------- /UNF/web_server/static/css/style.css: -------------------------------------------------------------------------------- 1 | @charset "UTF-8"; 2 | /* 3 | body { 4 | padding-top: 60px; 5 | } 6 | */ -------------------------------------------------------------------------------- /UNF/web_server/static/img/123.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/UNF/web_server/static/img/123.png -------------------------------------------------------------------------------- /UNF/training/loss.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | from torch.nn import CrossEntropyLoss 4 | from torch.nn import BCEWithLogitsLoss -------------------------------------------------------------------------------- /UNF/web_server/static/img/loading.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/UNF/web_server/static/img/loading.gif -------------------------------------------------------------------------------- /UNF/web_server/static/img/progressbar.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/UNF/web_server/static/img/progressbar.gif -------------------------------------------------------------------------------- /UNF/web_server/static/fonts/glyphicons-halflings-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/UNF/web_server/static/fonts/glyphicons-halflings-regular.eot -------------------------------------------------------------------------------- /UNF/web_server/static/fonts/glyphicons-halflings-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/UNF/web_server/static/fonts/glyphicons-halflings-regular.ttf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | apex.egg==info 2 | torchtext==0.5.0 3 | torch==1.0.0 4 | apex==0.9.10dev 5 | 
Flask==1.1.1 6 | spacy==2.2.3 7 | tensorboardX==2.0 8 | tornado==6.0.3 9 | -------------------------------------------------------------------------------- /UNF/web_server/static/fonts/glyphicons-halflings-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/UNF/web_server/static/fonts/glyphicons-halflings-regular.woff -------------------------------------------------------------------------------- /UNF/web_server/static/fonts/glyphicons-halflings-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waterzxj/UNF/HEAD/UNF/web_server/static/fonts/glyphicons-halflings-regular.woff2 -------------------------------------------------------------------------------- /UNF/trace/parse_vocab.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import sys 4 | import json 5 | 6 | vocab = json.load(open("vocab.txt")) 7 | for item in vocab: 8 | print(item) 9 | 10 | -------------------------------------------------------------------------------- /UNF/test/data/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from UNF.data.tokenizer import WhitespaceTokenizer, SpacyTokenizer 4 | 5 | class TestToken(TestCase): 6 | 7 | def test_whitespacetokenizer(self): 8 | self.assertEqual(WhitespaceTokenizer()("a b c"), ["a", "b", "c"]) 9 | -------------------------------------------------------------------------------- /UNF/trace/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR) 2 | project(custom_ops) 3 | 4 | find_package(Torch REQUIRED) 5 | 6 | add_executable(predict predict.cc lib/tokenizer.cc lib/string_util.cc) 7 | target_include_directories(predict PUBLIC "lib") 8 | 
target_link_libraries(predict "${TORCH_LIBRARIES}") 9 | set_property(TARGET predict PROPERTY CXX_STANDARD 11) 10 | -------------------------------------------------------------------------------- /UNF/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_loader import DataLoader 2 | from .field import WordField, CharField, SiteField 3 | from .tokenizer import BaseTokenizer, WhitespaceTokenizer, SpacyTokenizer 4 | 5 | 6 | __version__ = "0.1.0" 7 | 8 | __all__ = ["DataLoader", 9 | "WordField", "CharField", 10 | "SiteField", "BaseTokenizer", 11 | "WhitespaceTokenizer", "SpacyTokenizer"] 12 | -------------------------------------------------------------------------------- /UNF/web_server/static/css/jquery.fileupload-ui-noscript.css: -------------------------------------------------------------------------------- 1 | @charset "UTF-8"; 2 | /* 3 | * jQuery File Upload UI Plugin NoScript CSS 8.8.5 4 | * https://github.com/blueimp/jQuery-File-Upload 5 | * 6 | * Copyright 2012, Sebastian Tschan 7 | * https://blueimp.net 8 | * 9 | * Licensed under the MIT license: 10 | * http://www.opensource.org/licenses/MIT 11 | */ 12 | 13 | .fileinput-button i, 14 | .fileupload-buttonbar .delete, 15 | .fileupload-buttonbar .toggle { 16 | display: none; 17 | } 18 | -------------------------------------------------------------------------------- /UNF/trace/lib/tokenizer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | 8 | class Tokenizer { 9 | public: 10 | Tokenizer(const std::string &vocab_path); 11 | void token2id(const std::vector &segment, std::vector &t2id); 12 | void tokenize(const std::string &text, std::vector& segment); 13 | uint32_t get_pad_index(); 14 | 15 | private: 16 | std::unordered_map vocab_; 17 | }; 18 | -------------------------------------------------------------------------------- /UNF/models/model_util.py: 
-------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import json 3 | 4 | class Config(object): 5 | @classmethod 6 | def from_dict(cls, json_object): 7 | config = cls() 8 | for key, value in json_object.items(): 9 | config.__dict__[key] = value 10 | return config 11 | 12 | @classmethod 13 | def from_json_file(cls, json_file): 14 | with open(json_file, "r", encoding='utf-8') as reader: 15 | text = reader.read() 16 | return cls.from_dict(json.loads(text)) 17 | -------------------------------------------------------------------------------- /UNF/web_server/static/css/demo-ie8.css: -------------------------------------------------------------------------------- 1 | @charset "UTF-8"; 2 | /* 3 | * jQuery File Upload Demo CSS Fixes for IE<9 1.0.0 4 | * https://github.com/blueimp/jQuery-File-Upload 5 | * 6 | * Copyright 2013, Sebastian Tschan 7 | * https://blueimp.net 8 | * 9 | * Licensed under the MIT license: 10 | * http://www.opensource.org/licenses/MIT 11 | */ 12 | 13 | .navigation { 14 | list-style: none; 15 | padding: 0; 16 | margin: 1em 0; 17 | } 18 | .navigation li { 19 | display: inline; 20 | margin-right: 10px; 21 | } 22 | -------------------------------------------------------------------------------- /UNF/web_server/static/css/jquery.fileupload-noscript.css: -------------------------------------------------------------------------------- 1 | @charset "UTF-8"; 2 | /* 3 | * jQuery File Upload Plugin NoScript CSS 1.2.0 4 | * https://github.com/blueimp/jQuery-File-Upload 5 | * 6 | * Copyright 2013, Sebastian Tschan 7 | * https://blueimp.net 8 | * 9 | * Licensed under the MIT license: 10 | * http://www.opensource.org/licenses/MIT 11 | */ 12 | 13 | .fileinput-button input { 14 | position: static; 15 | opacity: 1; 16 | filter: none; 17 | font-size: inherit; 18 | direction: inherit; 19 | } 20 | .fileinput-button span { 21 | display: none; 22 | } 23 | 
-------------------------------------------------------------------------------- /UNF/data/field.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | 对处理数据域的抽象 4 | """ 5 | from torchtext.data.field import RawField, Field, LabelField 6 | 7 | 8 | class WordField(Field): 9 | """ 10 | 数据词域的抽象 11 | """ 12 | def __init__(self, **kwarg): 13 | print(kwarg) 14 | super(WordField, self).__init__(**kwarg) 15 | 16 | class CharField(Field): 17 | """ 18 | 数据字符域的抽象 19 | """ 20 | def __init__(self, **kwarg): 21 | super(CharField, self).__init__(**kwarg) 22 | 23 | class SiteField(Field): 24 | """ 25 | 站点域的抽象 26 | """ 27 | def __init__(self, **kwarg): 28 | super(SiteField, self).__init__(**kwarg) 29 | -------------------------------------------------------------------------------- /UNF/train_flow.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | 训练一个文本分类模型的工作流 4 | 加载数据 -> 加载模型 -> 训练保存 5 | 6 | 以下提供了各个分类模型的conf文件 7 | """ 8 | import sys 9 | from conf.textcnn_conf import data_loader_conf,model_conf,learner_conf 10 | from data.data_loader import DataLoader 11 | from models.model_loader import ModelLoader 12 | from training.learner_loader import LearnerLoader 13 | 14 | 15 | data_loader = DataLoader(data_loader_conf) 16 | train_iter, dev_iter, test_iter = data_loader.generate_dataset() 17 | 18 | model, model_conf = ModelLoader.from_params(model_conf, data_loader.fields) 19 | learner = LearnerLoader.from_params(model, train_iter, dev_iter, learner_conf, test_iter=test_iter, fields=data_loader.fields, model_conf=model_conf) 20 | 21 | learner.learn() 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /UNF/web_server/run.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from flask import Flask 4 | from server import debug 5 | from 
tornado.wsgi import WSGIContainer 6 | from tornado.httpserver import HTTPServer 7 | from tornado.ioloop import IOLoop 8 | 9 | #app = Flask(__name__, static_url_path='/sim/static') 10 | app = Flask(__name__) 11 | 12 | # Register blueprint 13 | app.register_blueprint(debug, url_prefix='/debug') 14 | 15 | 16 | 17 | if __name__ == '__main__': 18 | 19 | print('start run_app...') 20 | # asynchronous load data 21 | # load_data() 22 | http_server = HTTPServer(WSGIContainer(app)) 23 | http_server.listen(4444) 24 | IOLoop.instance().start() 25 | # port = 4444 26 | # app.run(host='0.0.0.0', port=port, debug=True) 27 | -------------------------------------------------------------------------------- /UNF/web_server/static/css/jquery.fileupload.css: -------------------------------------------------------------------------------- 1 | @charset "UTF-8"; 2 | /* 3 | * jQuery File Upload Plugin CSS 1.3.0 4 | * https://github.com/blueimp/jQuery-File-Upload 5 | * 6 | * Copyright 2013, Sebastian Tschan 7 | * https://blueimp.net 8 | * 9 | * Licensed under the MIT license: 10 | * http://www.opensource.org/licenses/MIT 11 | */ 12 | 13 | .fileinput-button { 14 | position: relative; 15 | overflow: hidden; 16 | } 17 | .fileinput-button input { 18 | position: absolute; 19 | top: 0; 20 | right: 0; 21 | margin: 0; 22 | opacity: 0; 23 | -ms-filter: 'alpha(opacity=0)'; 24 | font-size: 200px; 25 | direction: ltr; 26 | cursor: pointer; 27 | } 28 | 29 | /* Fixes for IE < 8 */ 30 | @media screen\9 { 31 | .fileinput-button input { 32 | filter: alpha(opacity=0); 33 | font-size: 100%; 34 | height: 100%; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /UNF/models/model.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from modules.embedding.embedding import TokenEmbedding 7 | 8 | 9 | class Model(nn.Module): 10 | def 
__init__(self): 11 | super(Model, self).__init__() 12 | 13 | def forward(self, *arg, **kwarg): 14 | """ 15 | 模型前向过程 16 | """ 17 | raise Exception("Not implemented!!") 18 | 19 | def predict(self): 20 | """ 21 | 模型预测过程 22 | """ 23 | raise Exception("Not implemented!!") 24 | 25 | def get_parameter_names(self): 26 | return [name for name, _ in self.named_parameters()] 27 | 28 | def load_state_dict(self, state_dict, strict=True): 29 | true_state_dict = {} 30 | for k,v in state_dict.items(): 31 | if k.startswith("model."): 32 | k = k.split(".", 1)[1] #去掉名字中的第一个model. 33 | true_state_dict[k] = v 34 | 35 | self.model.load_state_dict(true_state_dict, strict) 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /UNF/modules/base_type.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import sys 4 | 5 | class InitType(): 6 | """ 7 | 各种矩阵初始化的方法 8 | """ 9 | UNIFORM = 'uniform' 10 | NORMAL = "normal" 11 | XAVIER_UNIFORM = 'xavier_uniform' 12 | XAVIER_NORMAL = 'xavier_normal' 13 | KAIMING_UNIFORM = 'kaiming_uniform' 14 | KAIMING_NORMAL = 'kaiming_normal' 15 | ORTHOGONAL = 'orthogonal' 16 | 17 | def __str__(self): 18 | return ",".join( 19 | [self.UNIFORM, self.NORMAL, self.XAVIER_UNIFORM, self.XAVIER_NORMAL, 20 | self.KAIMING_UNIFORM, self.KAIMING_NORMAL, self.ORTHOGONAL]) 21 | 22 | 23 | class FAN_MODE(): 24 | FAN_IN = 'FAN_IN' 25 | FAN_OUT = "FAN_OUT" 26 | 27 | def __str__(self): 28 | return ",".join([self.FAN_IN, self.FAN_OUT]) 29 | 30 | 31 | class ActivationType(): 32 | SIGMOID = 'sigmoid' 33 | TANH = "tanh" 34 | RELU = 'relu' 35 | LEAKY_RELU = 'leaky_relu' 36 | NONE = 'linear' 37 | 38 | def __str__(self): 39 | return ",".join( 40 | [self.SIGMOID, self.TANH, self.RELU, self.LEAKY_RELU, self.NONE]) 41 | -------------------------------------------------------------------------------- /UNF/data/tokenizer.py: 
-------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | 对分词的抽象 4 | """ 5 | 6 | class BaseTokenizer(object): 7 | """ 8 | Tokenize的基类 9 | """ 10 | def __init__(self): 11 | pass 12 | 13 | def __call__(self, x): 14 | return self.tokenize(x) 15 | 16 | def tokenize(self, x): 17 | raise Exception("Not Implemented!") 18 | 19 | 20 | class WhitespaceTokenizer(BaseTokenizer): 21 | """ 22 | 空格切分 23 | """ 24 | def __init__(self): 25 | super(WhitespaceTokenizer, self).__init__() 26 | 27 | def tokenize(self, x): 28 | return x.split() 29 | 30 | 31 | class SpacyTokenizer(BaseTokenizer): 32 | """ 33 | spacy切分 34 | """ 35 | def __init__(self, language): 36 | super(SpacyTokenizer, self).__init__() 37 | self.language = language 38 | self.init() 39 | 40 | def init(self): 41 | import spacy 42 | self.spacy = spacy.load(self.language) 43 | 44 | def tokenize(self, x): 45 | return [tok.text for tok in self.spacy.tokenizer(x)] 46 | 47 | 48 | 49 | if __name__ == "__main__": 50 | print(WhitespaceTokenizer("a b c")) 51 | -------------------------------------------------------------------------------- /UNF/modules/encoder/full_connect.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import sys 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | class FullConnectLayer(nn.Module): 9 | def __init__(self, in_features, out_features, 10 | dropout=0.0, act="relu"): 11 | """ 12 | 封装torch.nn.Linear(),加入droput和activate 13 | """ 14 | super(FullConnectLayer, self).__init__() 15 | self.fc = nn.Linear(in_features, out_features) 16 | self.dropout = nn.Dropout(p=dropout) 17 | self.act = act 18 | 19 | if self.act != None: 20 | if self.act == "relu": 21 | self.act_func = F.relu 22 | elif self.act == "sigmoid": 23 | self.act_func = F.sigmoid 24 | elif self.act == "tanh": 25 | self.act_func = F.tanh 26 | else: 27 | raise Exception("%s activation not support" 
% act) 28 | 29 | def forward(self, input): 30 | tmp = self.dropout(self.fc(input)) 31 | if self.act: 32 | tmp = self.act_func(tmp) 33 | 34 | return tmp 35 | 36 | 37 | -------------------------------------------------------------------------------- /UNF/trace/lib/tokenizer.cc: -------------------------------------------------------------------------------- 1 | #include "tokenizer.h" 2 | #include "string_util.h" 3 | 4 | #include 5 | 6 | Tokenizer::Tokenizer(const std::string &vocab_path) { 7 | std::ifstream ifs(vocab_path); 8 | if (!ifs) { 9 | std::cerr << "Load vocab fail!!" << std::endl; 10 | return; 11 | } 12 | 13 | std::string line; 14 | uint32_t index = 0; 15 | while(getline(ifs, line)) { 16 | std::string tmp = string_util::trim(line); 17 | vocab_[tmp] = index; 18 | index += 1; 19 | } 20 | } 21 | 22 | void Tokenizer::tokenize(const std::string &text, std::vector& segment) { 23 | string_util::split(text, " ", segment); 24 | return; 25 | } 26 | 27 | void Tokenizer::token2id(const std::vector &segment, std::vector &t2id) { 28 | for (uint32_t index=0; index < segment.size(); index++) { 29 | if (vocab_.find(segment[index]) != vocab_.end()) { 30 | t2id.push_back(vocab_[segment[index]]); 31 | } 32 | else { 33 | //hard code WARNING 34 | t2id.push_back(vocab_[""]); 35 | } 36 | } 37 | return; 38 | } 39 | 40 | uint32_t Tokenizer::get_pad_index() { 41 | return vocab_[""]; 42 | } 43 | -------------------------------------------------------------------------------- /UNF/models/lstm_crf_predictor.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from models.predictor import Predictor 9 | from models.lstm_crf import LstmCrfTagger 10 | 11 | class LstmCrfPredictor(Predictor): 12 | def __init__(self, model_save_path, device=None): 13 | super(LstmCrfPredictor, self).__init__(model_save_path, device) 14 | 15 | def 
model_loader(self, conf): 16 | model = LstmCrfTagger(**conf.__dict__) 17 | return model 18 | 19 | def predict(self, input, **kwargs): 20 | input = input.split() 21 | input_ids = [self.vocab.get(item, 0) for item in input] 22 | input_ids = torch.LongTensor(input_ids) 23 | mask = torch.ones(1, input_ids.size(0)) 24 | input_seq_length = torch.tensor([input_ids.size(0)]).long() 25 | if self.device is not None: 26 | input_ids = input_ids.to(self.device) 27 | mask = mask.to(self.device) 28 | 29 | res = self.model.predict(input_ids, input_seq_length, 30 | mask) 31 | 32 | t_res = [] 33 | for item in res.detach().cpu().tolist()[0]: 34 | t_res.append(self.target[item]) 35 | 36 | return t_res 37 | 38 | 39 | -------------------------------------------------------------------------------- /UNF/web_server/server.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import sys 5 | sys.path.append("..") 6 | from flask import render_template, redirect, url_for, request, Blueprint 7 | from collections import defaultdict 8 | import random 9 | import json 10 | 11 | from models.predictor import Predictor 12 | 13 | model_path = "../sex_textcnn3" 14 | model_type = "TEXTCNN" 15 | 16 | predictor = Predictor(model_path, model_type=model_type) 17 | 18 | model_info = json.load(open("%s/conf.json" % model_path)) 19 | model_info["model_name"] = model_type 20 | #just for beautiful 21 | tmp = "" 22 | for k,v in model_info.items(): 23 | tmp += k 24 | tmp += ":" 25 | tmp += str(v) 26 | tmp += " " 27 | 28 | tmp = tmp.rstrip() 29 | 30 | 31 | 32 | static_location = 'static' 33 | debug = Blueprint('debug', __name__, static_folder=static_location) 34 | app = debug 35 | 36 | 37 | @app.route('/', methods=['GET', 'POST']) 38 | @app.route('/index', methods=['GET', 'POST']) 39 | def similar(): 40 | if request.method == 'GET': 41 | return render_template('web.html') 42 | else: 43 | title = 
request.form.get('title') 44 | score = predictor.predict(title) 45 | return render_template('web.html', score=score, input_value=title, model_info=tmp) 46 | 47 | -------------------------------------------------------------------------------- /UNF/test/data/test_data_loader.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from UNF.data.data_loader import DataLoader 4 | 5 | class TestDataLoader(TestCase): 6 | 7 | def test_data_loader(self): 8 | config = { 9 | "dataset": { 10 | "fields": [{ 11 | "name": "", 12 | "cls": "WordField" 13 | "attrs": { 14 | "tokenize": "WhitespaceTokenizer", 15 | } 16 | }, 17 | { 18 | "name": "", 19 | "cls": "LabelField", 20 | }], 21 | "dataset": { 22 | "path": "", 23 | "train": "", 24 | "validation": "", 25 | "test": "", 26 | "format":"" 27 | }, 28 | "iteration": { 29 | "batch_size": 64, 30 | "device": "cpu" 31 | } 32 | } 33 | } 34 | data_loader = DataLoader(config) 35 | print(data_loader) 36 | 37 | 38 | -------------------------------------------------------------------------------- /UNF/web_server/static/css/jquery.fileupload-ui.css: -------------------------------------------------------------------------------- 1 | @charset "UTF-8"; 2 | /* 3 | * jQuery File Upload UI Plugin CSS 9.0.0 4 | * https://github.com/blueimp/jQuery-File-Upload 5 | * 6 | * Copyright 2010, Sebastian Tschan 7 | * https://blueimp.net 8 | * 9 | * Licensed under the MIT license: 10 | * http://www.opensource.org/licenses/MIT 11 | */ 12 | 13 | .fileupload-buttonbar .btn, 14 | .fileupload-buttonbar .toggle { 15 | margin-bottom: 5px; 16 | } 17 | .progress-animated .progress-bar, 18 | .progress-animated .bar { 19 | background: url("../img/progressbar.gif") !important; 20 | filter: none; 21 | } 22 | .fileupload-process { 23 | float: right; 24 | display: none; 25 | } 26 | .fileupload-processing .fileupload-process, 27 | .files .processing .preview { 28 | display: block; 29 | width: 32px; 30 | 
height: 32px; 31 | background: url("../img/loading.gif") center no-repeat; 32 | background-size: contain; 33 | } 34 | .files audio, 35 | .files video { 36 | max-width: 300px; 37 | } 38 | 39 | @media (max-width: 767px) { 40 | .fileupload-buttonbar .toggle, 41 | .files .toggle, 42 | .files .btn span { 43 | display: none; 44 | } 45 | .files .name { 46 | width: 80px; 47 | word-wrap: break-word; 48 | } 49 | .files audio, 50 | .files video { 51 | max-width: 80px; 52 | } 53 | .files img, 54 | .files canvas { 55 | max-width: 100%; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /UNF/trace/lib/string_util.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | 12 | namespace string_util { 13 | 14 | const static std::unordered_set CN_PUNCS = {",", "。", "...", 15 | ";", "?", "|", "!", "_", ":", "“", "”", "《", "》", 16 | "】", "【", "(", ")", "^"}; 17 | /* 18 | * @desc: 把vector的内容转换成string 19 | */ 20 | void vector2str(const std::vector& vec, std::string& desc, const std::string& seq); 21 | 22 | /* 23 | * @desc: 把text按照seq分隔符切分成vec,存进desc 24 | */ 25 | void split(const std::string& text, const std::string& seq, std::vector& desc); 26 | 27 | /* 28 | * @desc: string trim todo:空间复杂度o(n)太高 29 | */ 30 | std::string trim(const std::string& src, const char seq=' '); 31 | 32 | /* 33 | * @desc: 返回char的长度,eg digit->1; chines->3 34 | */ 35 | int parse_char(const char *str); 36 | 37 | /* 38 | * @desc: 处理一个字符串中的标点符号 39 | * 40 | * param[in]: 输入字符串 41 | * param[in|put]: 处理后的字符串 42 | * param[in]: 把标点符号替换成指定的字符,默认为空 43 | * 44 | * return -1:failed, 0:success 45 | */ 46 | int punct_process(const std::string &raw_str, std::string &norm_str, 47 | const std::string &replacer=""); 48 | 49 | /* 50 | * @desc: 判断一个字符是不是中文标点 51 | */ 52 | bool is_cn_punct(const std::string &word); 53 | 54 | } //end namesapce 
string_util 55 | -------------------------------------------------------------------------------- /UNF/web_server/static/css/demo.css: -------------------------------------------------------------------------------- 1 | @charset "UTF-8"; 2 | /* 3 | * jQuery File Upload Demo CSS 1.1.0 4 | * https://github.com/blueimp/jQuery-File-Upload 5 | * 6 | * Copyright 2013, Sebastian Tschan 7 | * https://blueimp.net 8 | * 9 | * Licensed under the MIT license: 10 | * http://www.opensource.org/licenses/MIT 11 | */ 12 | 13 | body { 14 | max-width: 750px; 15 | margin: 0 auto; 16 | padding: 1em; 17 | font-family: "Lucida Grande", "Lucida Sans Unicode", Arial, sans-serif; 18 | font-size: 1em; 19 | line-height: 1.4em; 20 | background: #222; 21 | color: #fff; 22 | -webkit-text-size-adjust: 100%; 23 | -ms-text-size-adjust: 100%; 24 | } 25 | a { 26 | color: orange; 27 | text-decoration: none; 28 | } 29 | img { 30 | border: 0; 31 | vertical-align: middle; 32 | } 33 | h1 { 34 | line-height: 1em; 35 | } 36 | blockquote { 37 | padding: 0 0 0 15px; 38 | margin: 0 0 20px; 39 | border-left: 5px solid #eee; 40 | } 41 | table { 42 | width: 100%; 43 | margin: 10px 0; 44 | } 45 | 46 | .fileupload-progress { 47 | margin: 10px 0; 48 | } 49 | .fileupload-progress .progress-extended { 50 | margin-top: 5px; 51 | } 52 | .error { 53 | color: red; 54 | } 55 | 56 | @media (min-width: 481px) { 57 | .navigation { 58 | list-style: none; 59 | padding: 0; 60 | } 61 | .navigation li { 62 | display: inline-block; 63 | } 64 | .navigation li:not(:first-child):before { 65 | content: "| "; 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /UNF/conf/fasttext_conf.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | data_loader conf解释: 4 | dataset: 配置训练数据和训练数据的格式,提供自动加载的功能 5 | field: 针对每一个域提供一个相应的配置,每一个域的tokenzie,最大长度,是否返回padding长度(LSTM用)等 6 | iterator: 提供迭代的配置,包括每个batch大小,device是cpu还是gpu 7 | """ 8 | 
#data_loader相关 9 | data_loader_conf = { 10 | "dataset":{ 11 | "path": "test/test_data/data", 12 | "train": "train_sample", 13 | "validation": "val_sample", 14 | "test": "test", 15 | "format": "json" 16 | }, 17 | "fields":[{ 18 | "name":"TEXT", 19 | "name_cls":"WordField", 20 | "attrs":{ 21 | "tokenize":"WhitespaceTokenizer", 22 | } 23 | }, 24 | { 25 | "name":"LABEL", 26 | "name_cls":"LabelField", 27 | }], 28 | "iterator":{ 29 | "batch_size":512, 30 | "shuffle": True, 31 | } 32 | } 33 | 34 | #模型相关 35 | model_conf = [ 36 | { 37 | "name": "TEXT", 38 | "encoder_cls": "FastText", 39 | "encoder_params": { 40 | "input_dim": 100, 41 | "hidden_dim": 200, 42 | } 43 | } 44 | ] 45 | 46 | #learner相关的 47 | learner_conf = { 48 | "num_epochs": 6, 49 | "optimizer": "Adam", 50 | "optimizer_parmas": { 51 | "lr": 1e-4 52 | }, 53 | "device": "cuda:0", 54 | "loss": "CrossEntropyLoss", 55 | "serialization_dir": "sex_fasttext", 56 | "label_tag": "1", 57 | "use_fp16": True, 58 | "multi_gpu": True 59 | } 60 | -------------------------------------------------------------------------------- /UNF/conf/leam_conf.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | data_loader conf解释: 4 | dataset: 配置训练数据和训练数据的格式,提供自动加载的功能 5 | field: 针对每一个域提供一个相应的配置,每一个域的tokenzie,最大长度,是否返回padding长度(LSTM用)等 6 | iterator: 提供迭代的配置,包括每个batch大小,device是cpu还是gpu 7 | """ 8 | #data_loader相关 9 | data_loader_conf = { 10 | "dataset":{ 11 | "path": "test/test_data/tiktok_music", 12 | "train": "music_train", 13 | "validation": "music_valid", 14 | "test": "music_test", 15 | "format": "json" 16 | }, 17 | "fields":[{ 18 | "name":"TEXT", 19 | "name_cls":"WordField", 20 | "attrs":{ 21 | "tokenize":"WhitespaceTokenizer", 22 | } 23 | }, 24 | { 25 | "name":"LABEL", 26 | "name_cls":"LabelField", 27 | }], 28 | "iterator":{ 29 | "batch_size":64, 30 | "shuffle": True, 31 | } 32 | } 33 | 34 | #模型相关 35 | model_conf = [ 36 | { 37 | "name": "TEXT", 38 | "encoder_cls": "LEAM", 39 | 
"encoder_params": { 40 | "input_dim": 100, 41 | "ngram": 6, 42 | "dropout": 0.1, 43 | "pretrained": False, 44 | } 45 | } 46 | ] 47 | 48 | #learner相关的 49 | learner_conf = { 50 | "num_epochs": 10, 51 | "optimizer": "Adam", 52 | "optimizer_parmas": { 53 | "lr": 1e-4 54 | }, 55 | "device": "cuda:2", 56 | "loss": "CrossEntropyLoss", 57 | "serialization_dir": "tiktok_music_leam", 58 | "label_tag": "__label__1" 59 | } 60 | 61 | -------------------------------------------------------------------------------- /UNF/models/model_loader.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | 从配置文件反射到对应的模型 4 | """ 5 | import sys 6 | 7 | from models.fasttext import FastText 8 | from models.textcnn import TextCnn 9 | from models.lstm_crf import LstmCrfTagger 10 | from models.dpcnn import DpCnn 11 | from models.self_attention import SelfAttention 12 | from models.leam import LEAM 13 | 14 | 15 | class ModelLoader(object): 16 | 17 | @classmethod 18 | def from_params(cls, model_conf, fields): 19 | if len(model_conf) == 1: 20 | model_conf = model_conf[0] 21 | extra = {} 22 | name = model_conf["name"] 23 | #hardcode label的field_name 24 | if "label_num" not in model_conf: 25 | label_num = len(fields["LABEL"][1].vocab.stoi) 26 | else: 27 | label_num = model_conf["label_num"] 28 | 29 | model_conf["encoder_params"]["label_nums"] = label_num 30 | 31 | vocab_size = len(fields[name][1].vocab.stoi) 32 | model_conf["encoder_params"]["vocab_size"] = vocab_size 33 | 34 | encoder_params = model_conf["encoder_params"] 35 | if "pretrained" in encoder_params and encoder_params["pretrained"]: 36 | extra["vectors"] = fields[name].vocab.vectors 37 | 38 | 39 | return globals()[model_conf["encoder_cls"]](**model_conf["encoder_params"], **extra), \ 40 | model_conf["encoder_params"] 41 | else: 42 | #多域模型 43 | pass 44 | -------------------------------------------------------------------------------- /UNF/trace.py: 
-------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import sys 4 | import argparse 5 | import json 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from models.textcnn import TextCnnTrace 11 | from models.fasttext import FastTextTrace 12 | from models.dpcnn import DpCnnTrace 13 | from models.model_util import Config 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--model_path", 18 | type=str, 19 | required=True, 20 | ) 21 | parser.add_argument("--model_cls", 22 | type=str, 23 | default="TextCnnTrace", 24 | ) 25 | parser.add_argument("--save_path", 26 | type=str, 27 | default="trace.pt", 28 | ) 29 | args = parser.parse_args() 30 | model_path = args.model_path 31 | model_cls = args.model_cls 32 | save_path = args.save_path 33 | 34 | config = Config.from_json_file("%s/conf.json" % (model_path)) 35 | net = globals()[model_cls](**config.__dict__) 36 | net.load_state_dict_trace(torch.load("%s/best.th" % model_path)) 37 | net.eval() 38 | 39 | mock_input = net.mock_input_data() 40 | tr = torch.jit.trace(net, mock_input) 41 | print(tr.code) 42 | 43 | #move vocab 44 | os.system("mv %s/vocab.txt trace/" % model_path) 45 | 46 | #save trace model 47 | tr.save("trace/%s" % save_path) 48 | -------------------------------------------------------------------------------- /UNF/models/model_trace.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from modules.embedding.embedding import TokenEmbedding 7 | 8 | 9 | class ModelTrace(nn.Module): 10 | def __init__(self, input_dim=None, vocab_size=None, **kwargs): 11 | """ 12 | trace类的抽象类 13 | """ 14 | super(ModelTrace, self).__init__() 15 | if input_dim is not None and vocab_size is not None: 16 | self.embedding = TokenEmbedding(input_dim, vocab_size) 17 | #加载预训练的词向量 18 | if 
"pretrain" in kwargs: 19 | if kwargs["pretrain"]: 20 | self.embedding.from_pretrained(kwargs['vectors']) 21 | 22 | def forward(self, *arg, **kwarg): 23 | """ 24 | 模型前向过程 25 | """ 26 | raise Exception("Not implemented!!") 27 | 28 | def mock_input_data(self): 29 | """ 30 | mock trace输入 31 | """ 32 | return torch.ones((1, 128), dtype=torch.long), torch.ones((1, 128), dtype=torch.long) 33 | 34 | def load_state_dict_trace(self, state_dict, strict=True): 35 | true_state_dict = {} 36 | for k,v in state_dict.items(): 37 | if k.startswith("model."): 38 | k = k.split(".", 1)[1] #去掉名字中的第一个model. 39 | true_state_dict[k] = v 40 | 41 | self.load_state_dict(true_state_dict, strict) 42 | 43 | def get_parameter_names(self): 44 | return [name for name, _ in self.named_parameters()] 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /UNF/conf/dpcnn_conf.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | data_loader conf解释: 4 | dataset: 配置训练数据和训练数据的格式,提供自动加载的功能 5 | field: 针对每一个域提供一个相应的配置,每一个域的tokenzie,最大长度,是否返回padding长度(LSTM用)等 6 | iterator: 提供迭代的配置,包括每个batch大小,device是cpu还是gpu 7 | """ 8 | #data_loader相关 9 | data_loader_conf = { 10 | "dataset":{ 11 | "path": "test/test_data/aclImdb", 12 | "train": "train", 13 | "test": "test", 14 | "format": "json" 15 | }, 16 | "fields":[{ 17 | "name":"TEXT", 18 | "name_cls":"WordField", 19 | "attrs":{ 20 | "tokenize":"WhitespaceTokenizer", 21 | } 22 | }, 23 | { 24 | "name":"LABEL", 25 | "name_cls":"LabelField", 26 | }], 27 | "iterator":{ 28 | "batch_size":64, 29 | "shuffle": True, 30 | } 31 | } 32 | 33 | #模型相关 34 | model_conf = [ 35 | { 36 | "name": "TEXT", 37 | "encoder_cls": "DpCnn", 38 | "encoder_params": { 39 | "input_dim": 100, 40 | "filter_num": 100, 41 | "filter_size": 3, 42 | "stride": 2, 43 | "block_size": 3, 44 | "dropout": 0.1, 45 | "pretrained": False, 46 | } 47 | } 48 | ] 49 | 50 | #learner相关的 51 | learner_conf = { 52 | 
"num_epochs": 20, 53 | "optimizer": "Adam", 54 | "optimizer_parmas": { 55 | "lr": 1e-4 56 | }, 57 | "device": "cuda:1", 58 | "loss": "CrossEntropyLoss", 59 | "serialization_dir": "imdb_dpcnn", 60 | "label_tag": "pos" 61 | } 62 | 63 | -------------------------------------------------------------------------------- /UNF/conf/textcnn_conf.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | data_loader conf解释: 4 | dataset: 配置训练数据和训练数据的格式,提供自动加载的功能 5 | field: 针对每一个域提供一个相应的配置,每一个域的tokenzie,最大长度,是否返回padding长度(LSTM用)等 6 | iterator: 提供迭代的配置,包括每个batch大小,device是cpu还是gpu 7 | """ 8 | #data_loader相关 9 | data_loader_conf = { 10 | "dataset":{ 11 | "path": "test/test_data/data", 12 | "train": "train", 13 | "validation": "valid", 14 | "test": "test", 15 | "format": "json" 16 | }, 17 | "fields":[{ 18 | "name":"TEXT", 19 | "name_cls":"WordField", 20 | "attrs":{ 21 | "tokenize":"WhitespaceTokenizer", 22 | "min_count": 3 23 | } 24 | }, 25 | { 26 | "name":"LABEL", 27 | "name_cls":"LabelField", 28 | }], 29 | "iterator":{ 30 | "batch_size":512, 31 | "shuffle": True, 32 | } 33 | } 34 | 35 | #模型相关 36 | model_conf = [ 37 | { 38 | "name": "TEXT", 39 | "encoder_cls": "TextCnn", 40 | "encoder_params": { 41 | "input_dim": 100, 42 | "filter_num": 100, 43 | "filter_size": [1,2,3,4], 44 | "pretrained": False, 45 | } 46 | } 47 | ] 48 | 49 | 50 | #learner相关的 51 | learner_conf = { 52 | "num_epochs": 6, 53 | "optimizer": "Adam", 54 | "optimizer_parmas": { 55 | "lr": 1e-4 56 | }, 57 | "device": "cuda:0", 58 | "loss": "CrossEntropyLoss", 59 | "serialization_dir": "sample_dir", 60 | "label_tag": "1", 61 | "use_fp16": False, 62 | "multi_gpu": False 63 | } 64 | 65 | -------------------------------------------------------------------------------- /UNF/training/learner_loader.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | 从配置文件反射拿到learner 4 | """ 5 | import os 6 | import sys 7 | 
import logging 8 | sys.path.append("training") 9 | from learner import Trainer 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | class LearnerLoader(object): 14 | 15 | @classmethod 16 | def from_params(cls, model, train_iter, dev_iter, learner_conf, test_iter=None, fields=None, 17 | model_conf=None): 18 | if fields is not None: 19 | if "label_tag" in learner_conf: 20 | label_index = fields["LABEL"][1].vocab.stoi[learner_conf["label_tag"]] 21 | logger.info("Load label index") 22 | logger.info("Label index: %s" % label_index) 23 | logger.info("Label vocab: %s" % fields["LABEL"][1].vocab.stoi) 24 | return Trainer(model, train_iter, dev_iter, 25 | **learner_conf, test_iter=test_iter, label_index=label_index, model_conf=model_conf, fields=fields) 26 | if "sequence_model" in learner_conf and learner_conf["sequence_model"]: 27 | assert fields is not None, "sequence model need target vocab" 28 | vocab = fields["LABEL"][1].vocab.itos 29 | 30 | return Trainer(model, train_iter, dev_iter, 31 | **learner_conf, test_iter=test_iter, fields=fields, model_conf=model_conf, label_vocab=vocab) 32 | 33 | 34 | return Trainer(model, train_iter, dev_iter, 35 | **learner_conf, test_iter=test_iter, fields=fields, model_conf=model_conf) 36 | 37 | 38 | -------------------------------------------------------------------------------- /UNF/conf/selfattention_conf.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | data_loader conf解释: 4 | dataset: 配置训练数据和训练数据的格式,提供自动加载的功能 5 | field: 针对每一个域提供一个相应的配置,每一个域的tokenzie,最大长度,是否返回padding长度(LSTM用)等 6 | iterator: 提供迭代的配置,包括每个batch大小,device是cpu还是gpu 7 | """ 8 | #data_loader相关 9 | data_loader_conf = { 10 | "dataset":{ 11 | "path": "test/test_data/data", 12 | "train": "train_sample", 13 | "validation": "val_sample", 14 | "test": "test", 15 | "format": "json" 16 | }, 17 | "fields":[{ 18 | "name":"TEXT", 19 | "name_cls":"WordField", 20 | "attrs":{ 21 | "tokenize":"WhitespaceTokenizer", 22 | 
"include_lengths": True 23 | } 24 | }, 25 | { 26 | "name":"LABEL", 27 | "name_cls":"LabelField", 28 | }], 29 | "iterator":{ 30 | "batch_size":512, 31 | "shuffle": True, 32 | } 33 | } 34 | 35 | #模型相关 36 | model_conf = [ 37 | { 38 | "name": "TEXT", 39 | "encoder_cls": "SelfAttention", 40 | "encoder_params": { 41 | "input_dim": 100, 42 | "hidden_size": 100, 43 | "layer_num": 2, 44 | "attention_num": 1, 45 | } 46 | } 47 | ] 48 | 49 | 50 | #learner相关的 51 | learner_conf = { 52 | "num_epochs": 6, 53 | "optimizer": "Adam", 54 | "optimizer_parmas": { 55 | "lr": 1e-4 56 | }, 57 | "device": "cuda:0", 58 | "loss": "CrossEntropyLoss", 59 | "serialization_dir": "sex_selfattention4", 60 | "label_tag": "1", 61 | "use_fp16":False, 62 | "multi_gpu": True 63 | } 64 | -------------------------------------------------------------------------------- /UNF/models/fasttext.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import sys 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from models.model import Model 8 | from models.model_trace import ModelTrace 9 | 10 | class FastTextTrace(ModelTrace): 11 | """ 12 | implementation: Bag of Tricks for Efficient Text Classification 13 | 14 | """ 15 | def __init__(self, input_dim, vocab_size, 16 | hidden_dim, label_nums, **kwargs): 17 | super(FastTextTrace, self).__init__(input_dim, vocab_size, **kwargs) 18 | self.hidden_fc = nn.Linear(input_dim, hidden_dim) 19 | self.final_fc = nn.Linear(hidden_dim, label_nums) 20 | 21 | def forward(self, input, mask=None, label=None): 22 | embedding_o = self.embedding(input) #b * s * dim 23 | tmp = F.avg_pool2d(embedding_o, (embedding_o.size(1), 1)).squeeze(1)#b * dim 24 | 25 | hidden_fc_o = self.hidden_fc(tmp) 26 | final_fc_o = self.final_fc(hidden_fc_o) 27 | 28 | return final_fc_o 29 | 30 | 31 | 32 | class FastText(Model): 33 | def __init__(self, input_dim, vocab_size, 34 | hidden_dim, label_nums, **kwargs): 35 | 
super(FastText, self).__init__() 36 | self.model = FastTextTrace(input_dim, vocab_size, hidden_dim, 37 | label_nums, **kwargs) 38 | 39 | def forward(self, input, mask=None, label=None): 40 | logits = self.model(input, mask, label) 41 | return {"logits": logits} 42 | 43 | def predict(self, input, mask=None, label=None): 44 | return self.forward(input, mask, label)["logits"] 45 | -------------------------------------------------------------------------------- /UNF/conf/lstm_crf_conf.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | data_loader conf解释: 4 | dataset: 配置训练数据和训练数据的格式,提供自动加载的功能 5 | field: 针对每一个域提供一个相应的配置,每一个域的tokenzie,最大长度,是否返回padding长度(LSTM用)等 6 | iterator: 提供迭代的配置,包括每个batch大小,device是cpu还是gpu 7 | """ 8 | #data_loader相关 9 | data_loader_conf = { 10 | "dataset":{ 11 | "path":"test/test_data/ner", 12 | "train":"train_ner_sample", 13 | "test":"test_ner_sample", 14 | "format":"json" 15 | }, 16 | "fields":[{ 17 | "name":"TEXT", 18 | "name_cls":"WordField", 19 | "attrs":{ 20 | "tokenize":"WhitespaceTokenizer", 21 | "include_lengths": True 22 | } 23 | }, 24 | { 25 | "name":"LABEL", 26 | "name_cls":"Field", 27 | "attrs":{ 28 | "tokenize":"WhitespaceTokenizer", 29 | "sequential": True, 30 | "unk_token": None 31 | } 32 | }], 33 | "iterator":{ 34 | "batch_size":64, 35 | "shuffle": True, 36 | } 37 | } 38 | 39 | #模型相关 40 | model_conf = [ 41 | { 42 | "name": "TEXT", 43 | "encoder_cls": "LstmCrfTagger", 44 | "encoder_params": { 45 | "input_dim": 100, 46 | "hidden_size": 200, 47 | "num_layers": 3, 48 | "device": "cuda:0" 49 | } 50 | } 51 | ] 52 | 53 | #learner相关的 54 | learner_conf = { 55 | "num_epochs": 10, 56 | "optimizer": "Adam", 57 | "optimizer_parmas": { 58 | "lr": 1e-5 59 | }, 60 | "device": "cuda:0", 61 | "serialization_dir": "model_lstm_example1", 62 | "sequence_model": True, 63 | "metric": "NerF1Measure" 64 | } 65 | 66 | -------------------------------------------------------------------------------- 
/UNF/score_flow.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import sys 4 | import argparse 5 | import json 6 | 7 | from models.lstm_crf_predictor import LstmCrfPredictor 8 | from models.predictor import Predictor 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--model_path", 14 | type=str, 15 | required=True, 16 | ) 17 | parser.add_argument("--device", 18 | type=str, 19 | default=None, 20 | ) 21 | parser.add_argument("--test_path", 22 | type=str, 23 | required=True, 24 | ) 25 | parser.add_argument("--model_type", 26 | type=str, 27 | default="textcnn", 28 | ) 29 | parser.add_argument("--save_path", 30 | type=str, 31 | default="test_save.dat", 32 | required=True, 33 | ) 34 | 35 | args = parser.parse_args() 36 | model_path = args.model_path 37 | device = args.device 38 | model_type = args.model_type 39 | #step1 初始化predictor 40 | if model_type == "lstm-crf": 41 | predictor = LstmCrfPredictor(model_path, device) 42 | else: 43 | predictor = Predictor(model_path, device, model_type) 44 | 45 | #step2 开始预测 46 | save_path = open(os.path.join(model_path, args.save_path), "w") 47 | for line in open(args.test_path): 48 | line = json.loads(line.rstrip()) 49 | pred = predictor.predict(line["TEXT"]) 50 | save_path.write("%s\t%s\t%s\n" % (line["TEXT"], line["LABEL"], " ".join(map(str, pred)))) 51 | 52 | save_path.close() 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /UNF/trace/predict.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "tokenizer.h" 8 | 9 | 10 | int max_seq_length = 128; //hard code, comfortable with the tracing process 11 | 12 | int main(int argc, const char* argv[]) 13 | { 14 | if (argc != 3) { 15 | std::cerr << "usage predict "; 16 | 
return -1; 17 | } 18 | //step1 load model 19 | torch::jit::script::Module module; 20 | try { 21 | module = torch::jit::load(argv[1]); 22 | } 23 | catch (const c10::Error& e) { 24 | std::cerr << "error loading the model\n"; 25 | return -1; 26 | } 27 | 28 | //step2 init tokenizer 29 | Tokenizer tokenizer(argv[2]); 30 | uint32_t pad_index = tokenizer.get_pad_index(); 31 | 32 | //step3 loop predict 33 | std::string text; 34 | while (true) 35 | { 36 | std::cout << "\n" << "Input ->"; 37 | getline(std::cin, text); 38 | if (text == "break") break; 39 | //segment 40 | std::vector segment; 41 | tokenizer.tokenize(text, segment); 42 | 43 | //to id 44 | std::vector t2id; 45 | tokenizer.token2id(segment, t2id); 46 | 47 | //padding 48 | int seg_len = segment.size(); 49 | torch::Tensor mask = torch::ones({1, max_seq_length}); 50 | 51 | while(seg_len < max_seq_length) { 52 | t2id.push_back(pad_index); 53 | mask[0][seg_len] = 0; 54 | seg_len ++; 55 | } 56 | 57 | std::vector inputs; 58 | inputs.push_back(torch::from_blob(t2id.data(), {1, max_seq_length}).to(torch::kLong)); 59 | inputs.push_back(mask.to(torch::kLong)); 60 | torch::Tensor logits = module.forward(inputs).toTensor(); 61 | //torch::Tensor logits = module.forward({input}).toTensor(); 62 | std::cout << logits << std::endl; 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /UNF/models/textcnn.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | from __future__ import absolute_import 3 | import os 4 | import sys 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from models.model import Model 10 | from models.model_trace import ModelTrace 11 | from modules.embedding.embedding import TokenEmbedding 12 | from modules.encoder.cnn_maxpool import CnnMaxpoolLayer 13 | 14 | class TextCnnTrace(ModelTrace): 15 | """ 16 | implemantation:Convolutional Neural Networks for Sentence Classification 17 
| """ 18 | def __init__(self, 19 | input_dim, 20 | vocab_size, 21 | filter_size, 22 | filter_num, 23 | label_nums, 24 | dropout=0.0, **kwargs): 25 | super(TextCnnTrace, self).__init__(input_dim, vocab_size, **kwargs) 26 | if not isinstance(filter_size, (tuple, list)): 27 | filter_size = [filter_size] 28 | 29 | if not isinstance(filter_num, (tuple, list)): 30 | filter_num = len(filter_size) * [filter_num] 31 | 32 | self.encoder = CnnMaxpoolLayer(input_dim, 33 | filter_num, filter_size, **kwargs) 34 | 35 | self.dropout = nn.Dropout(p=dropout) 36 | self.fc = nn.Linear(sum(filter_num), label_nums) 37 | 38 | def forward(self, input, mask=None, label=None): 39 | if len(input.size()) == 1: 40 | input = input.unsqueeze(0) 41 | 42 | x = self.embedding(input) 43 | output = self.encoder(x, mask) #[b * l] 44 | output = self.dropout(output) 45 | logits = self.fc(output) #[b, label_num] 46 | return logits 47 | 48 | 49 | class TextCnn(Model): 50 | def __init__(self, 51 | input_dim, 52 | vocab_size, 53 | filter_size, 54 | filter_num, 55 | label_nums, 56 | dropout=0.0, **kwargs): 57 | super(TextCnn, self).__init__() 58 | self.model = TextCnnTrace(input_dim, vocab_size, filter_size, filter_num, 59 | label_nums, dropout, **kwargs) 60 | 61 | def forward(self, input, mask=None, label=None): 62 | logits = self.model(input, mask, label) 63 | return {"logits": logits} 64 | 65 | def predict(self, input, mask=None, label=None): 66 | return self.forward(input, mask, label)["logits"] 67 | -------------------------------------------------------------------------------- /UNF/modules/encoder/lstm_encoder.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 7 | 8 | 9 | class LstmEncoderLayer(nn.Module): 10 | def __init__(self, input_size, hidden_size, num_layers, label_nums=None, 11 | 
batch_first=True, bidirectional=False, dropout=0.0): 12 | """" 13 | Lstm编码器的封装 14 | 15 | :params input_size 输入维度 16 | :params hidden_size hidden_state的维度 17 | :params num_layers lstm的层数 18 | :params batch_first 输入的第一个维度是否表示batch大小 19 | :params bidirectional 是否使用双向LSTM 20 | """ 21 | super(LstmEncoderLayer, self).__init__() 22 | if bidirectional: 23 | lstm_size = hidden_size // 2 24 | else: 25 | lstm_size = hidden_size 26 | 27 | self.lstm = nn.LSTM(input_size=input_size, hidden_size=lstm_size, 28 | num_layers=num_layers, batch_first=batch_first, 29 | bidirectional=bidirectional) 30 | 31 | self.lstm_dropout = nn.Dropout(dropout) 32 | self.label_nums = label_nums 33 | 34 | if self.label_nums: 35 | self.hidden2tags = nn.Linear(hidden_size, label_nums) 36 | 37 | def forward(self, input, input_seq_lengths, batch_first=True, is_sort=False): 38 | """ 39 | :params input 输入矩阵,按序列长度降序排列 40 | :params input_seq_length 输入序列的长度 41 | :params batch_first 输入矩阵的维度第一维是否是batch大小 42 | """ 43 | if not is_sort: 44 | #对输入tensor按长度序排列 45 | word_seq_lengths, word_perm_idx = input_seq_lengths.sort(0, descending=True) 46 | input = input[word_perm_idx] 47 | 48 | packed_words = pack_padded_sequence(input, word_seq_lengths.cpu().numpy(), True) 49 | hidden = None 50 | lstm_out, hidden = self.lstm(packed_words, hidden) 51 | lstm_out, _ = pad_packed_sequence(lstm_out) 52 | lstm_out = lstm_out.transpose(1, 0) 53 | if not is_sort: 54 | _, word_seq_recover = word_perm_idx.sort(0, descending=False) 55 | lstm_out = lstm_out[word_seq_recover] 56 | 57 | outputs = self.lstm_dropout(lstm_out) #batch * seq_len * (hidden_dim*directions) 58 | if self.label_nums: 59 | outputs = self.hidden2tags(outputs) 60 | 61 | return outputs 62 | -------------------------------------------------------------------------------- /UNF/common_util/ner_p_r_f_cal.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import sys 4 | 5 | sys.path.append("../") 6 | 7 | 
from training.learner_util import get_ner_BIO 8 | 9 | 10 | gold_loc_tp = 0.0 11 | gold_loc_fp = 0.0 12 | gold_loc_tn = 0.0 13 | gold_loc_fn = 0.0 14 | 15 | gold_per_tp = 0.0 16 | gold_per_fp = 0.0 17 | gold_per_tn = 0.0 18 | gold_per_fn = 0.0 19 | 20 | gold_org_tp = 0.0 21 | gold_org_fp = 0.0 22 | gold_org_tn = 0.0 23 | gold_org_fn = 0.0 24 | 25 | 26 | for line in open("../model_lstm/test_ner"): 27 | line = line.rstrip() 28 | parts = line.split("\t") 29 | query = parts[0] 30 | labels = get_ner_BIO(parts[1].split()) 31 | preds = get_ner_BIO(parts[2].split()) 32 | if labels or preds: 33 | for item in labels: 34 | if item in preds: 35 | if "LOC" in item: 36 | gold_loc_tp += 1 37 | elif "ORG" in item: 38 | gold_org_tp += 1 39 | elif "PER" in item: 40 | gold_per_tp += 1 41 | 42 | else: 43 | if "LOC" in item: 44 | gold_loc_fn += 1 45 | elif "ORG" in item: 46 | gold_org_fn += 1 47 | elif "PER" in item: 48 | gold_per_fn += 1 49 | 50 | for item in preds: 51 | if item not in labels: 52 | if "LOC" in item: 53 | gold_loc_fp += 1 54 | elif "ORG" in item: 55 | gold_org_fp += 1 56 | elif "PER" in item: 57 | gold_per_fp += 1 58 | 59 | 60 | loc_pre = gold_loc_tp * 1.0 / (gold_loc_tp + gold_loc_fp) 61 | loc_rec = gold_loc_tp * 1.0 / (gold_loc_tp + gold_loc_fn) 62 | loc_f = 2 * loc_pre * loc_rec / (loc_pre + loc_rec) 63 | print("Location precision:%s recall:%s f:%s" % (loc_pre, loc_rec, loc_f)) 64 | 65 | 66 | per_pre = gold_per_tp * 1.0 / (gold_per_tp + gold_per_fp) 67 | per_rec = gold_per_tp * 1.0 / (gold_per_tp + gold_per_fn) 68 | per_f = 2 * per_pre * per_rec / (per_pre + per_rec) 69 | print("per precision:%s recall:%s f:%s" % (per_pre, per_rec, per_f)) 70 | 71 | org_pre = gold_org_tp * 1.0 / (gold_org_tp + gold_org_fp) 72 | org_rec = gold_org_tp * 1.0 / (gold_org_tp + gold_org_fn) 73 | org_f = 2 * org_pre * org_rec / (org_pre + org_rec) 74 | print("org precision:%s recall:%s f:%s" % (org_pre, org_rec, org_f)) 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /UNF/modules/encoder/self_attention_encoder.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import sys 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from modules.module_util import mask_softmax 10 | 11 | 12 | class SelfAttentionEncoder(nn.Module): 13 | def __init__(self, head_num, input_dim, 14 | attention_dim, coefficient=None): 15 | """ 16 | imple: A STRUCTURED SELF-ATTENTIVE SENTENCE EMBEDDING 17 | 18 | :params head_num int attention的个数 19 | :params input_dim int 输入数据的维数 20 | :params attention_dim attention的维度大小,超参,attention的一个中间变量 21 | """ 22 | super(SelfAttentionEncoder, self).__init__() 23 | self.att_s1 = nn.Linear(input_dim, attention_dim, bias=False) 24 | self.att_s2 = nn.Linear(attention_dim, head_num, bias=False) 25 | self.coefficient = coefficient 26 | 27 | def forward(self, input, mask=None): 28 | inner_out = self.att_s1(input) #batch_size * seq_len * attention_dim 29 | final_out = self.att_s2(inner_out) #batch_size * seq_len * head_num 30 | 31 | #if mask is not None: 32 | # mask = mask.unsqueeze(2) #batch_size * seq_len * 1 33 | 34 | att_weight = mask_softmax(final_out, 1, mask) #batch_size * seq_len * head_num 35 | 36 | 
output = {} 37 | output["attention"] = att_weight 38 | 39 | #H = A*M 40 | H = att_weight.transpose(1,2)@input #batch_size * atten_num * input_dim 41 | batch_size = input.size(0) 42 | output["encoder"] = H.view(batch_size, -1) 43 | 44 | if self.coefficient: 45 | output["regulariration_loss"] = self.frobenius_regularization_penalty(att_weight) / batch_size 46 | 47 | return output 48 | 49 | def frobenius_regularization_penalty(self, attention): 50 | """ 51 | 实现论文中PENALIZATION TERM,||AAT − I|| 52 | """ 53 | num_timesteps = attention.size(1) 54 | batch_size = attention.size(0) 55 | #(batch_size, num_attention_heads, timesteps) 56 | attention_transpose = attention.transpose(1, 2) 57 | 58 | identity = torch.eye(num_timesteps, device=attention.device) 59 | 60 | #(batch_size, timesteps, timesteps) 61 | identity = identity.unsqueeze(0).expand(batch_size, num_timesteps, num_timesteps) 62 | 63 | #(batch_size, timesteps, timesteps) 64 | delta = attention @ attention_transpose - identity 65 | 66 | return torch.sum(torch.sum(torch.sum(delta ** 2, 1), 1) ** 0.5) 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /UNF/models/self_attention.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import sys 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from models.model_trace import ModelTrace 10 | from modules.embedding.embedding import TokenEmbedding 11 | from modules.encoder.lstm_encoder import LstmEncoderLayer 12 | from modules.encoder.self_attention_encoder import SelfAttentionEncoder 13 | 14 | from training.learner_util import generate_mask 15 | 16 | class SelfAttention(ModelTrace): 17 | 18 | def __init__(self, label_nums, vocab_size, input_dim, 19 | hidden_size, layer_num, attention_num, coefficient=0.0, bidirection=True, 20 | batch_first=True, device=None, 21 | dropout=0.0, averge_batch_loss=True, 
**kwargs): 22 | """ 23 | Implemention: A STRUCTURED SELF-ATTENTIVE SENTENCE EMBEDDING 24 | 25 | :params label_nums int 输出label的数量 26 | :params vocab_size int 词表大小 27 | :params input_dim int 输入词表维度 28 | :params hidden_size int 隐藏层的维度 29 | :params layer_num int 隐藏层的层数 30 | :params attention_num int attention的个数 31 | :params coefficient float 正则化系数 32 | """ 33 | super(SelfAttention, self).__init__(input_dim, vocab_size, **kwargs) 34 | self.coefficient = coefficient 35 | 36 | self.encoder = LstmEncoderLayer(input_dim, hidden_size, layer_num, 37 | bidirectional=bidirection, batch_first=batch_first, dropout=dropout) 38 | 39 | self.averge_batch_loss = averge_batch_loss 40 | 41 | self.att_encoder = SelfAttentionEncoder(attention_num, hidden_size, int(hidden_size/4), 42 | coefficient ) 43 | self.fc = nn.Linear(hidden_size * attention_num, label_nums) 44 | 45 | def forward(self, input, input_seq_length, mask=None, label=None): 46 | embedding = self.embedding(input) #batch_size * seq_len * input_dim 47 | encoder_res = self.encoder(embedding, input_seq_length) #batch * seq_len * (hidden_size) 48 | 49 | #attention的实现 50 | att_res = self.att_encoder(encoder_res, mask) 51 | att_encoder = att_res["encoder"] 52 | logits = self.fc(att_encoder) 53 | output = {} 54 | output["logits"] = logits 55 | 56 | if self.coefficient != 0: 57 | output["regulariration_loss"] = att_res["regulariration_loss"] 58 | output["coefficient"] = self.coefficient 59 | 60 | return output 61 | 62 | def predict(self, input, input_seq_length, mask=None, label=None): 63 | return self.forward(input, input_seq_length, label, mask)["logits"] 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /UNF/models/predictor.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import json 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from models.model_util import 
Config 10 | from models.dpcnn import DpCnn 11 | from models.fasttext import FastText 12 | from models.leam import LEAM 13 | from models.self_attention import SelfAttention 14 | from models.textcnn import TextCnn 15 | 16 | 17 | class Predictor(nn.Module): 18 | def __init__(self, model_save_path, device=None, model_type=None): 19 | super(Predictor, self).__init__() 20 | model_conf = os.path.join(model_save_path, "conf.json") 21 | vocab_path = os.path.join(model_save_path, "vocab.txt") 22 | target_path = os.path.join(model_save_path, "target.txt") 23 | 24 | self.model_type = model_type 25 | self.model = self.model_loader(Config.from_json_file(model_conf)) 26 | self.model.load_state_dict(torch.load(os.path.join(model_save_path, "best.th"))) 27 | self.model.eval() 28 | 29 | self.device = device 30 | if self.device is not None: 31 | self.model.to(device) 32 | 33 | self.vocab = self.load_vocab(vocab_path) 34 | self.target = self.load_vocab(target_path, reverse=True) 35 | 36 | def model_loader(self, conf): 37 | name = self.model_type.lower() 38 | if name == "textcnn": 39 | model = TextCnn(**conf.__dict__) 40 | elif name == "fastext": 41 | model = FastText(**conf.__dict__) 42 | elif name == "dpcnn": 43 | model = DpCnn(**conf.__dict__) 44 | elif name == "leam": 45 | model = LEAM(**conf.__dict__) 46 | elif name == "self-attention": 47 | model = SelfAttention(**conf.__dict__) 48 | else: 49 | raise Exception("name:%s model not implemented!" 
% (name)) 50 | 51 | return model 52 | 53 | def predict(self, input, **kwargs): 54 | input = input.split() 55 | input_ids = [self.vocab.get(item, 0) for item in input] 56 | 57 | input_ids = torch.LongTensor(input_ids) 58 | if self.device is not None: 59 | input_ids = input_ids.to(self.device) 60 | 61 | mask = (input_ids != 1).long() 62 | 63 | res = self.model.predict(input_ids, mask) 64 | res = res.detach().cpu().tolist()[0] 65 | return res 66 | 67 | def load_vocab(self, path, reverse=False): 68 | res = {} 69 | tmp = json.load(open(path)) 70 | for index, word in enumerate(tmp): 71 | if reverse: 72 | res[index] = word 73 | else: 74 | res[word] = index 75 | return res 76 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/main.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery File Upload Plugin JS Example 8.9.1 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2010, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* global $, window */ 13 | 14 | $(function () { 15 | 'use strict'; 16 | 17 | // Initialize the jQuery File Upload widget: 18 | $('#fileupload').fileupload({ 19 | // Uncomment the following to send cross-domain cookies: 20 | //xhrFields: {withCredentials: true}, 21 | url: 'upload' 22 | }); 23 | 24 | // Enable iframe cross-domain access via redirect option: 25 | $('#fileupload').fileupload( 26 | 'option', 27 | 'redirect', 28 | window.location.href.replace( 29 | /\/[^\/]*$/, 30 | '/cors/result.html?%s' 31 | ) 32 | ); 33 | 34 | if (window.location.hostname === 'blueimp.github.io') { 35 | // Demo settings: 36 | $('#fileupload').fileupload('option', { 37 | url: '//jquery-file-upload.appspot.com/', 38 | // Enable image resizing, except for Android and Opera, 39 | // which actually support image resizing, but fail to 40 | // send 
Blob objects via XHR requests: 41 | disableImageResize: /Android(?!.*Chrome)|Opera/ 42 | .test(window.navigator.userAgent), 43 | maxFileSize: 5000000, 44 | acceptFileTypes: /(\.|\/)(gif|jpe?g|png)$/i 45 | }); 46 | // Upload server status check for browsers with CORS support: 47 | if ($.support.cors) { 48 | $.ajax({ 49 | url: '//jquery-file-upload.appspot.com/', 50 | type: 'HEAD' 51 | }).fail(function () { 52 | $('
') 53 | .text('Upload server currently unavailable - ' + 54 | new Date()) 55 | .appendTo('#fileupload'); 56 | }); 57 | } 58 | } else { 59 | // Load existing files: 60 | $('#fileupload').addClass('fileupload-processing'); 61 | $.ajax({ 62 | // Uncomment the following to send cross-domain cookies: 63 | //xhrFields: {withCredentials: true}, 64 | url: $('#fileupload').fileupload('option', 'url'), 65 | dataType: 'json', 66 | context: $('#fileupload')[0] 67 | }).always(function () { 68 | $(this).removeClass('fileupload-processing'); 69 | }).done(function (result) { 70 | $(this).fileupload('option', 'done') 71 | .call(this, $.Event('done'), {result: result}); 72 | }); 73 | } 74 | 75 | }); 76 | -------------------------------------------------------------------------------- /UNF/modules/encoder/cnn_maxpool.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os 3 | import sys 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from modules.module_util import initial_parameter 9 | 10 | 11 | class CnnMaxpoolLayer(nn.Module): 12 | def __init__(self, input_num, output_num, 13 | filter_size, stride=1, padding=0, 14 | activation='relu', initial_method=None, **kwargs): 15 | """ 16 | cnn+maxpooling的结构做encoder 17 | :params input_num int 输入的维度 18 | :params output_num int|list 每个卷积核输出的维度 19 | :params filter_size list 卷积核的大小 20 | :params init_method str 网络参数初始化的方法,默认为xavier_uniform 21 | """ 22 | super(CnnMaxpoolLayer, self).__init__() 23 | if not isinstance(filter_size, (tuple, list)): 24 | filter_size = [filter_size] 25 | 26 | if not isinstance(output_num, (tuple, list)): 27 | output_num = [output_num] * len(filter_size) 28 | 29 | assert len(filter_size) == len(output_num), \ 30 | "Filter size len is not equal output_num len" 31 | 32 | print("stride", stride) 33 | self.convs = nn.ModuleList( 34 | [nn.Conv1d(in_channels=input_num, out_channels=on, 35 | kernel_size=ks, stride=stride, 
padding=padding, bias=False) 36 | for on, ks in zip(output_num, filter_size)] 37 | ) 38 | 39 | if activation == "relu": 40 | self.activation = F.relu 41 | elif activation == "sigmoid": 42 | self.activation = F.sigmoid 43 | elif activation == "tanh": 44 | self.activation = F.tanh 45 | else: 46 | raise Exception("%s activation not support" % activation) 47 | 48 | #使用默认初始化 49 | #initial_parameter(self, initial_method) 50 | 51 | def forward(self, input, mask=None): 52 | """ 53 | :params: input torch.Tensor [batch_size, length, dim] 54 | :params: mask torch.Tensor [batch_size, length] 55 | """ 56 | if mask is not None: 57 | input = input * mask.unsqueeze(-1).float() 58 | 59 | #[b, l, d] -> [b, d, l] 60 | input = torch.transpose(input, 1, 2) 61 | conv_res = [self.activation(conv(input)) for conv in self.convs] #[b, o, lout] 62 | 63 | #import pdb;pdb.set_trace() 64 | tmp = [] 65 | for i in range(len(conv_res)): 66 | dim = conv_res[i].size(2) 67 | if isinstance(dim, torch.Tensor): 68 | #trace 无法识别tuple的操作,会转成tensor 69 | dim = dim.tolist() 70 | max_out = F.max_pool1d(conv_res[i], kernel_size=dim) 71 | tmp.append(max_out.squeeze(2)) 72 | 73 | return torch.cat(tmp, dim=-1) 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /UNF/models/lstm_crf.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | import sys 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from models.model import Model 10 | from modules.embedding.embedding import TokenEmbedding 11 | from modules.encoder.lstm_encoder import LstmEncoderLayer 12 | from modules.decoder.crf import CRF 13 | 14 | class LstmCrfTagger(Model): 15 | def __init__(self, label_nums, vocab_size, input_dim, 16 | hidden_size, num_layers, use_crf=True, 17 | bidirection=True, batch_first=True, device=None, 18 | dropout=0.0, averge_batch_loss=True, **kwargs): 19 | """ 20 
| ref: Neural Architectures for Named Entity Recognition 21 | 模型结构是word_embedding + bilstm + crf 22 | 23 | :params 24 | """ 25 | super(LstmCrfTagger, self).__init__(input_dim, vocab_size, **kwargs) 26 | 27 | self.encoder = LstmEncoderLayer(input_dim, hidden_size, num_layers, label_nums=label_nums+2, 28 | bidirectional=bidirection, batch_first=batch_first, dropout=dropout) 29 | 30 | self.averge_batch_loss = averge_batch_loss 31 | self.use_crf = use_crf 32 | if self.use_crf: 33 | self.decoder = CRF(label_nums, device) 34 | 35 | def forward(self, input, input_seq_length, 36 | mask=None, batch_label=None): 37 | 38 | embedding = self.embedding(input) #batch_size * seq_len * input_dim 39 | encoder_res = self.encoder(embedding, input_seq_length) #batch * seq_len * (hidden_dim*directions) 40 | batch_size = encoder_res.size(0) 41 | seq_len = encoder_res.size(1) 42 | 43 | if self.use_crf: 44 | _, tag_seq = self.decoder._viterbi_decode(encoder_res, mask) 45 | if batch_label is not None: 46 | total_loss = self.decoder.neg_log_likelihood_loss(encoder_res, mask, batch_label) 47 | 48 | else: 49 | outs = encoder_res.view(batch_size * seq_len, -1) 50 | _, tag_seq = torch.max(outs, 1) 51 | tag_seq = tag_seq.view(batch_size, seq_len) 52 | tag_seq = mask.long() * tag_seq 53 | if batch_label is not None: 54 | loss_function = nn.NLLLoss(ignore_index=0, size_average=False)#mask的token不对最后的loss产生影响, 固定mask的label id为0 55 | score = F.log_softmax(outs, 1) 56 | total_loss = loss_function(score, batch_label.view(batch_size * seq_len)) 57 | 58 | if batch_label is not None: 59 | if self.averge_batch_loss: 60 | total_loss = total_loss / batch_size 61 | return {"loss":total_loss, "logits": tag_seq} 62 | else: 63 | return {"logits": tag_seq} 64 | 65 | def predict(self, input, input_seq_length, mask=None): 66 | input = input.unsqueeze(0) 67 | res = self.forward(input, input_seq_length, mask) 68 | return res["logits"] 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 
-------------------------------------------------------------------------------- /UNF/modules/embedding/embedding.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | Embedding类的抽象 4 | """ 5 | import os 6 | import sys 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | from modules.module_util import init_tensor 13 | from modules.base_type import InitType, FAN_MODE, ActivationType 14 | 15 | 16 | class BaseEmbedding(nn.Module): 17 | """ 18 | Emebdding类的基类 19 | :params dim int类型,embedding的维度大小 20 | :params vocab_size int类型 21 | :params device string or [string1, string2],计算的后端,默认是cpu 22 | :params init_type string, 初始化的计算方式 ,默认采用uniform初始化 23 | :params dropout float 24 | """ 25 | 26 | def __init__(self, dim, vocab_size, 27 | device=None, dropout=0.0): 28 | 29 | super(BaseEmbedding, self).__init__() 30 | self.dim = dim 31 | self.vocab_size = vocab_size 32 | self.device = device 33 | self.dropout = nn.Dropout(p=dropout) 34 | 35 | 36 | 37 | @classmethod 38 | def from_dict(cls, params): 39 | return cls(**params) 40 | 41 | def forward(self, input): 42 | raise Exception("BaseEmbedding forward method not implemented!") 43 | 44 | 45 | class TokenEmbedding(BaseEmbedding): 46 | def __init__(self, dim, vocab_size, device=None, 47 | dropout=0.0, 48 | init_type=InitType.XAVIER_NORMAL, 49 | low=0, high=1, mean=0, std=1, 50 | activation_type=ActivationType.NONE, 51 | fan_mode=FAN_MODE.FAN_IN, negative_slope=0 52 | ): 53 | """ 54 | Embedding类的基础类 55 | 56 | :params dim int类型,embedding的维度大小 57 | :params vocab_size int类型 58 | :params device string or [string1, string2],计算的后端,默认是cpu 59 | :params init_type string, 初始化的计算方式 ,默认采用uniform初始化 60 | :params dropout float 61 | """ 62 | super(TokenEmbedding, self).__init__(dim, vocab_size, device, 63 | dropout) 64 | 65 | self.embeddings = nn.Embedding(vocab_size, dim) 66 | embedding_lookup_table = init_tensor(tensor=torch.empty(vocab_size, dim), 
67 | init_type=init_type, low=low, high=high, mean=mean, std=std, 68 | activation_type=activation_type, fan_mode=fan_mode, 69 | negative_slope=negative_slope) 70 | 71 | self.embeddings.weight.data.copy_(embedding_lookup_table) 72 | 73 | def forward(self, input): 74 | embedding = self.embeddings(input) 75 | return self.dropout(embedding) 76 | 77 | @classmethod 78 | def from_pretrained(cls, vectors, vocab_map=None): 79 | """ 80 | copy从dataloader每个域加载好的预训练的词向量 81 | 82 | :params vectors Vector类型 83 | """ 84 | if isinstance(path, (str)): 85 | raise Exception("Load embedding from path not implemented!") 86 | 87 | self.embeddings.weight.data.copy_(vectors) 88 | 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /UNF/data/data_loader.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | """ 3 | 对数据加载的抽象,从磁盘数据load成模型可训练的数据格式 4 | """ 5 | import random 6 | 7 | from torchtext.data import Field, Dataset, LabelField, TabularDataset 8 | from torchtext.data import Iterator, BucketIterator 9 | 10 | from data.field import WordField, CharField, SiteField 11 | from data.tokenizer import BaseTokenizer, WhitespaceTokenizer, SpacyTokenizer 12 | 13 | 14 | class DataLoader(object): 15 | """ 16 | 流程步骤: 17 | 1)创建Field,每个类型的字段生成一个Field;生成Field的主要参数包括 tokenize等 18 | 2)根据上一步创建的Field和数据路径,创建Dataset对象,Dataset就是Example对象的集合 19 | 3)根据上一步创建的Dataset对象,创建Iterator对象,Iterator就是提供一个迭代方法,懒加载的返回 20 | 一个Batch对象,生成Batch对象的时候会调用Field对象的process方法完成string2id的映射转换和pad过程; 21 | Batch对象主要是包含一个batch_size的Example对象 22 | 23 | """ 24 | 25 | def __init__(self, config): 26 | self.config = config 27 | self.SEED = 1441 #magic data 28 | self.fields = None 29 | 30 | def generate_dataset(self): 31 | fields = self.config["fields"] 32 | #step1: Field生成 33 | inner_fields = {} 34 | inner_info = {} 35 | 36 | for item in fields: 37 | f_name = item["name"] 38 | f_cls = item["name_cls"] 39 | if "attrs" in item: 40 | if 
"tokenize" in item["attrs"]: 41 | #初始化field的tokenizer对象 42 | if "language" in item["attrs"]: 43 | item["attrs"]["tokenize"] = globals()[item["attrs"]["tokenize"]](item["attrs"]["language"]) 44 | else: 45 | item["attrs"]["tokenize"] = globals()[item["attrs"]["tokenize"]]() 46 | 47 | #初始化field建词表的信息 48 | if "min_count" in item["attrs"]: 49 | inner_info[f_name] = item["attrs"]["min_count"] 50 | del item["attrs"]["min_count"] 51 | 52 | 53 | inner_fields[f_name] = (f_name, globals()[f_cls](**item["attrs"])) 54 | else: 55 | inner_fields[f_name] = (f_name, globals()[f_cls]()) 56 | 57 | self.fields = inner_fields 58 | 59 | #step2: Dataset生成 60 | datasets = TabularDataset.splits(**self.config["dataset"] , fields=inner_fields) 61 | if len(datasets) == 2: 62 | #训练集、验证集的划分 63 | train_datasets, test_datasets = datasets 64 | train_datasets, valid_datasets = train_datasets.split(random_state=random.seed(self.SEED)) 65 | elif len(datasets) == 3: 66 | train_datasets, valid_datasets, test_datasets = datasets 67 | 68 | #step3: 根据Field生成对应的词表 69 | for item in inner_fields.values(): 70 | name, obj = item 71 | if obj.use_vocab: 72 | if name in inner_info: 73 | obj.build_vocab(datasets[0], min_freq=inner_info[name]) 74 | else: 75 | obj.build_vocab(datasets[0]) 76 | 77 | #step4: Iterator对象生成 78 | data_iterator = BucketIterator.splits((train_datasets, valid_datasets, test_datasets), sort=False, **self.config["iterator"]) 79 | 80 | return data_iterator 81 | -------------------------------------------------------------------------------- /UNF/web_server/templates/web.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | UNF web demo 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |


15 |
16 |

UNF web demo

17 |
18 |
19 |
20 |

Text

21 |
22 |
23 |
24 |
25 | 26 |
27 | {% if input_value %} 28 | 29 | {%else%} 30 | 31 | 32 | {%endif%} 33 | 34 |
35 |
36 |
37 |
38 | 42 |
43 |
44 |
45 |
46 |
47 | 48 |
49 |
50 |

Result

51 |
52 |
53 | 54 | {% if model_info %} 55 | 56 | 57 | 58 | 59 | 60 | 61 | {% endif %} 62 |
Model info
{{ model_info }}
63 | 64 | 65 | {% if score %} 66 | 67 | 68 | 69 | 70 | 71 | 72 | {% endif %} 73 |
Score
{{ score }}
74 |
75 |
76 |
77 | 78 | -------------------------------------------------------------------------------- /UNF/models/leam.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import sys 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.nn.functional import normalize 8 | from torch.nn import CrossEntropyLoss 9 | 10 | from models.model import Model 11 | from modules.embedding.embedding import TokenEmbedding 12 | from modules.module_util import mask_softmax 13 | from modules.encoder.full_connect import FullConnectLayer 14 | 15 | 16 | class LEAM(Model): 17 | def __init__(self, input_dim, vocab_size, label_nums, 18 | hidden_dim=256, ngrams=3, active=True, 19 | norm=True, coefficient=1, dropout=0.0, **kwargs): 20 | """ 21 | implementation: Joint Embedding of Words and Labels for Text Classification 22 | 23 | params: 24 | """ 25 | super(LEAM, self).__init__(input_dim, vocab_size, **kwargs) 26 | self.label_embedding = TokenEmbedding(input_dim, label_nums) # e * c 27 | self.label_nums = label_nums 28 | padding = int(ngrams/2) 29 | self.active = active 30 | self.norm = norm 31 | self.coefficient = coefficient 32 | 33 | assert ngrams % 2==1, "ngram should be the odd" 34 | 35 | if self.active: 36 | self.conv = torch.nn.Sequential( 37 | torch.nn.Conv1d( 38 | self.label_nums, self.label_nums, 39 | ngrams, padding=padding), 40 | torch.nn.ReLU() 41 | ) 42 | else: 43 | self.conv = torch.nn.Conv1d( 44 | self.label_nums, self.label_nums, 45 | ngrams, padding=padding) 46 | 47 | 48 | self.fc1 = FullConnectLayer(input_dim, hidden_dim, dropout, "relu") 49 | self.fc2 = nn.Linear(hidden_dim, label_nums) 50 | 51 | def forward(self, input, mask=None, label=None): 52 | #import pdb;pdb.set_trace() 53 | word_embedding = self.embedding(input) #b * s * dim 54 | embedding_label = self.label_embedding(label) 55 | 56 | label_embedding = self.label_embedding.embeddings.weight 57 | 58 | if mask is not None: 59 | 
word_embedding = word_embedding * mask.unsqueeze(-1).float() 60 | 61 | if self.norm: 62 | word_embedding = normalize(word_embedding, dim=2) 63 | label_embedding = normalize(label_embedding, dim=1) # l * dim 64 | 65 | #Attention操作 66 | G = word_embedding @ label_embedding.transpose(0, 1) # b * s * l 67 | att_v = self.conv(G.transpose(1,2)) #b * l * s 68 | att_v = att_v.transpose(1,2) #b * s * l 69 | att_v = F.max_pool1d(att_v, att_v.size(2)) #b * s * 1 70 | 71 | att_v = mask_softmax(att_v, 1, mask) #b *s * 1 72 | H_enc = att_v * word_embedding #b * s * dim 73 | att_out = torch.sum(H_enc, 1) # b * dim 74 | 75 | #全连接操作 76 | tmp = self.fc1(att_out) #b * s * hidden_dim 77 | logits = self.fc2(tmp) #b * s * label_nums 78 | 79 | output = {} 80 | output["logits"] = logits 81 | 82 | if self.coefficient != 0: 83 | output["coefficient"] = self.coefficient 84 | tmp = self.fc1(embedding_label) # b * l * hidden_dim 85 | logits = self.fc2(tmp) 86 | reg_loss = F.cross_entropy(logits, label) 87 | output["regulariration_loss"] = reg_loss 88 | 89 | return output 90 | 91 | def predict(self, input, mask=None, label=None): 92 | return self.forward(input, label, mask)["logits"] 93 | -------------------------------------------------------------------------------- /UNF/models/dpcnn.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | from __future__ import absolute_import 3 | import os 4 | import sys 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from models.model import Model 10 | from models.model_trace import ModelTrace 11 | 12 | 13 | class DpCnnTrace(ModelTrace): 14 | def __init__(self, input_dim, vocab_size, label_nums, 15 | block_size=16, filter_size=3, filter_num=250, 16 | stride=2, dropout=0.0, **kwargs): 17 | """ 18 | implementation: Deep Pyramid Convolutional Neural Networks for Text Categorization 19 | 20 | :params block_size dpcnn里block的数量,每个block包括两个卷积层和一个max_pooling层 21 | """ 22 | 
super(DpCnnTrace, self).__init__(input_dim, vocab_size, **kwargs) 23 | self.filter_num = filter_num 24 | self.filter_size = filter_size 25 | self.stride = stride 26 | self.padding = int(self.filter_size/2) 27 | assert self.filter_size % 2 == 1, "filter size should be odd" 28 | self.block_size = block_size 29 | 30 | self.region_embedding = torch.nn.Sequential( 31 | torch.nn.Conv1d( 32 | input_dim, self.filter_num, 33 | self.filter_size, padding=self.padding) 34 | ) 35 | 36 | self.blocks = torch.nn.ModuleList([torch.nn.Sequential( 37 | torch.nn.ReLU(), 38 | torch.nn.Conv1d( 39 | self.filter_num, self.filter_num, 40 | self.filter_size, padding=self.padding), 41 | torch.nn.ReLU(), 42 | torch.nn.Conv1d( 43 | self.filter_num, self.filter_num, 44 | self.filter_size, padding=self.padding) 45 | ) for _ in range(self.block_size + 1)]) 46 | 47 | self.linear = nn.Linear(self.filter_num, label_nums) 48 | self.dropout = nn.Dropout(p=dropout) 49 | 50 | def forward(self, input, mask=None, label=None): 51 | input = self.embedding(input) 52 | 53 | if mask is not None: 54 | input = input * mask.unsqueeze(-1).float() 55 | 56 | #region embedding 57 | input = input.transpose(1, 2) 58 | region_out = self.region_embedding(input) 59 | block_out = self.blocks[0](region_out) 60 | #short cut 61 | block_out = block_out + region_out 62 | for index in range(1, self.block_size+1): 63 | block_features = F.max_pool1d( 64 | block_out, self.filter_size, self.stride) 65 | block_out = self.blocks[index](block_features) 66 | block_features = block_features + block_out 67 | doc_embedding = F.max_pool1d( 68 | block_features, block_features.size(2)).squeeze() 69 | 70 | logits = self.dropout(self.linear(doc_embedding)) 71 | return logits 72 | 73 | 74 | class DpCnn(Model): 75 | def __init__(self, input_dim, vocab_size, label_nums, 76 | block_size=16, filter_size=3, filter_num=250, 77 | stride=2, dropout=0.0, **kwargs): 78 | super(DpCnn, self).__init__() 79 | self.model = DpCnnTrace(input_dim, vocab_size, 
label_nums, block_size, 80 | filter_size, filter_num, stride, dropout, **kwargs) 81 | 82 | def forward(self, input, mask=None, label=None): 83 | logits = self.model(input, mask, label) 84 | return {"logits": logits} 85 | 86 | def predict(self, input, mask=None, label=None): 87 | return self.forward(input, label, mask)["logits"] 88 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/cors/jquery.xdr-transport.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery XDomainRequest Transport Plugin 1.1.3 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2011, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | * 11 | * Based on Julian Aubourg's ajaxHooks xdr.js: 12 | * https://github.com/jaubourg/ajaxHooks/ 13 | */ 14 | 15 | /* global define, window, XDomainRequest */ 16 | 17 | (function (factory) { 18 | 'use strict'; 19 | if (typeof define === 'function' && define.amd) { 20 | // Register as an anonymous AMD module: 21 | define(['jquery'], factory); 22 | } else { 23 | // Browser globals: 24 | factory(window.jQuery); 25 | } 26 | }(function ($) { 27 | 'use strict'; 28 | if (window.XDomainRequest && !$.support.cors) { 29 | $.ajaxTransport(function (s) { 30 | if (s.crossDomain && s.async) { 31 | if (s.timeout) { 32 | s.xdrTimeout = s.timeout; 33 | delete s.timeout; 34 | } 35 | var xdr; 36 | return { 37 | send: function (headers, completeCallback) { 38 | var addParamChar = /\?/.test(s.url) ? 
'&' : '?'; 39 | function callback(status, statusText, responses, responseHeaders) { 40 | xdr.onload = xdr.onerror = xdr.ontimeout = $.noop; 41 | xdr = null; 42 | completeCallback(status, statusText, responses, responseHeaders); 43 | } 44 | xdr = new XDomainRequest(); 45 | // XDomainRequest only supports GET and POST: 46 | if (s.type === 'DELETE') { 47 | s.url = s.url + addParamChar + '_method=DELETE'; 48 | s.type = 'POST'; 49 | } else if (s.type === 'PUT') { 50 | s.url = s.url + addParamChar + '_method=PUT'; 51 | s.type = 'POST'; 52 | } else if (s.type === 'PATCH') { 53 | s.url = s.url + addParamChar + '_method=PATCH'; 54 | s.type = 'POST'; 55 | } 56 | xdr.open(s.type, s.url); 57 | xdr.onload = function () { 58 | callback( 59 | 200, 60 | 'OK', 61 | {text: xdr.responseText}, 62 | 'Content-Type: ' + xdr.contentType 63 | ); 64 | }; 65 | xdr.onerror = function () { 66 | callback(404, 'Not Found'); 67 | }; 68 | if (s.xdrTimeout) { 69 | xdr.ontimeout = function () { 70 | callback(0, 'timeout'); 71 | }; 72 | xdr.timeout = s.xdrTimeout; 73 | } 74 | xdr.send((s.hasContent && s.data) || null); 75 | }, 76 | abort: function () { 77 | if (xdr) { 78 | xdr.onerror = $.noop(); 79 | xdr.abort(); 80 | } 81 | } 82 | }; 83 | } 84 | }); 85 | } 86 | })); 87 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/jquery.fileupload-audio.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery File Upload Audio Preview Plugin 1.0.3 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2013, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* jshint nomen:false */ 13 | /* global define, window, document */ 14 | 15 | (function (factory) { 16 | 'use strict'; 17 | if (typeof define === 'function' && define.amd) { 18 | // Register as an anonymous AMD module: 19 | define([ 20 
| 'jquery', 21 | 'load-image', 22 | './jquery.fileupload-process' 23 | ], factory); 24 | } else { 25 | // Browser globals: 26 | factory( 27 | window.jQuery, 28 | window.loadImage 29 | ); 30 | } 31 | }(function ($, loadImage) { 32 | 'use strict'; 33 | 34 | // Prepend to the default processQueue: 35 | $.blueimp.fileupload.prototype.options.processQueue.unshift( 36 | { 37 | action: 'loadAudio', 38 | // Use the action as prefix for the "@" options: 39 | prefix: true, 40 | fileTypes: '@', 41 | maxFileSize: '@', 42 | disabled: '@disableAudioPreview' 43 | }, 44 | { 45 | action: 'setAudio', 46 | name: '@audioPreviewName', 47 | disabled: '@disableAudioPreview' 48 | } 49 | ); 50 | 51 | // The File Upload Audio Preview plugin extends the fileupload widget 52 | // with audio preview functionality: 53 | $.widget('blueimp.fileupload', $.blueimp.fileupload, { 54 | 55 | options: { 56 | // The regular expression for the types of audio files to load, 57 | // matched against the file type: 58 | loadAudioFileTypes: /^audio\/.*$/ 59 | }, 60 | 61 | _audioElement: document.createElement('audio'), 62 | 63 | processActions: { 64 | 65 | // Loads the audio file given via data.files and data.index 66 | // as audio element if the browser supports playing it. 
67 | // Accepts the options fileTypes (regular expression) 68 | // and maxFileSize (integer) to limit the files to load: 69 | loadAudio: function (data, options) { 70 | if (options.disabled) { 71 | return data; 72 | } 73 | var file = data.files[data.index], 74 | url, 75 | audio; 76 | if (this._audioElement.canPlayType && 77 | this._audioElement.canPlayType(file.type) && 78 | ($.type(options.maxFileSize) !== 'number' || 79 | file.size <= options.maxFileSize) && 80 | (!options.fileTypes || 81 | options.fileTypes.test(file.type))) { 82 | url = loadImage.createObjectURL(file); 83 | if (url) { 84 | audio = this._audioElement.cloneNode(false); 85 | audio.src = url; 86 | audio.controls = true; 87 | data.audio = audio; 88 | return data; 89 | } 90 | } 91 | return data; 92 | }, 93 | 94 | // Sets the audio element as a property of the file object: 95 | setAudio: function (data, options) { 96 | if (data.audio && !options.disabled) { 97 | data.files[data.index][options.name || 'preview'] = data.audio; 98 | } 99 | return data; 100 | } 101 | 102 | } 103 | 104 | }); 105 | 106 | })); 107 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/jquery.fileupload-video.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery File Upload Video Preview Plugin 1.0.3 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2013, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* jshint nomen:false */ 13 | /* global define, window, document */ 14 | 15 | (function (factory) { 16 | 'use strict'; 17 | if (typeof define === 'function' && define.amd) { 18 | // Register as an anonymous AMD module: 19 | define([ 20 | 'jquery', 21 | 'load-image', 22 | './jquery.fileupload-process' 23 | ], factory); 24 | } else { 25 | // Browser globals: 26 | factory( 27 | window.jQuery, 28 | 
window.loadImage 29 | ); 30 | } 31 | }(function ($, loadImage) { 32 | 'use strict'; 33 | 34 | // Prepend to the default processQueue: 35 | $.blueimp.fileupload.prototype.options.processQueue.unshift( 36 | { 37 | action: 'loadVideo', 38 | // Use the action as prefix for the "@" options: 39 | prefix: true, 40 | fileTypes: '@', 41 | maxFileSize: '@', 42 | disabled: '@disableVideoPreview' 43 | }, 44 | { 45 | action: 'setVideo', 46 | name: '@videoPreviewName', 47 | disabled: '@disableVideoPreview' 48 | } 49 | ); 50 | 51 | // The File Upload Video Preview plugin extends the fileupload widget 52 | // with video preview functionality: 53 | $.widget('blueimp.fileupload', $.blueimp.fileupload, { 54 | 55 | options: { 56 | // The regular expression for the types of video files to load, 57 | // matched against the file type: 58 | loadVideoFileTypes: /^video\/.*$/ 59 | }, 60 | 61 | _videoElement: document.createElement('video'), 62 | 63 | processActions: { 64 | 65 | // Loads the video file given via data.files and data.index 66 | // as video element if the browser supports playing it. 
67 | // Accepts the options fileTypes (regular expression) 68 | // and maxFileSize (integer) to limit the files to load: 69 | loadVideo: function (data, options) { 70 | if (options.disabled) { 71 | return data; 72 | } 73 | var file = data.files[data.index], 74 | url, 75 | video; 76 | if (this._videoElement.canPlayType && 77 | this._videoElement.canPlayType(file.type) && 78 | ($.type(options.maxFileSize) !== 'number' || 79 | file.size <= options.maxFileSize) && 80 | (!options.fileTypes || 81 | options.fileTypes.test(file.type))) { 82 | url = loadImage.createObjectURL(file); 83 | if (url) { 84 | video = this._videoElement.cloneNode(false); 85 | video.src = url; 86 | video.controls = true; 87 | data.video = video; 88 | return data; 89 | } 90 | } 91 | return data; 92 | }, 93 | 94 | // Sets the video element as a property of the file object: 95 | setVideo: function (data, options) { 96 | if (data.video && !options.disabled) { 97 | data.files[data.index][options.name || 'preview'] = data.video; 98 | } 99 | return data; 100 | } 101 | 102 | } 103 | 104 | }); 105 | 106 | })); 107 | -------------------------------------------------------------------------------- /UNF/trace/lib/string_util.cc: -------------------------------------------------------------------------------- 1 | #include "string_util.h" 2 | #include 3 | 4 | void string_util::vector2str(const std::vector& vec, std::string& desc, const std::string& seq) { 5 | if (vec.size() <= 0) return; 6 | 7 | auto iter = vec.begin(); 8 | if (iter != vec.end()) { 9 | desc.append(*iter); 10 | ++iter; 11 | } 12 | 13 | while (iter != vec.end()) { 14 | desc.append(seq); 15 | desc.append(*iter); 16 | ++iter; 17 | } 18 | } 19 | 20 | void string_util::split(const std::string& text, const std::string& seq, std::vector& desc) { 21 | if (text.size() == 0) return; 22 | 23 | size_t seq_s = seq.size(); 24 | size_t text_s = text.size(); 25 | 26 | size_t start = 0; 27 | 28 | std::string tmp; 29 | while (start < text_s) { 30 | size_t pos = 
text.find_first_of(seq, start); 31 | 32 | if (pos != std::string::npos) { 33 | tmp = text.substr(start, pos - start); 34 | if (!tmp.empty()) { 35 | desc.push_back(tmp); 36 | } 37 | start = pos + seq_s; 38 | } else { 39 | break; 40 | } 41 | } 42 | 43 | if (start < text_s) { 44 | tmp = text.substr(start); 45 | if (!tmp.empty()) { 46 | desc.push_back(tmp); 47 | } 48 | } 49 | return; 50 | } 51 | 52 | std::string string_util::trim(const std::string& src, const char seq) { 53 | if (src.empty()) return src; 54 | std::string tmp = src; 55 | 56 | size_t len = tmp.size(), pos = 0; 57 | 58 | for (size_t i = 0; i < len; i++) { 59 | if (seq == src[i]) pos += 1; 60 | else { 61 | break; 62 | } 63 | } 64 | 65 | if (pos > 0) { 66 | tmp.erase(0, pos); 67 | } 68 | 69 | pos = 0; 70 | len = tmp.size(); 71 | int i = int(len - 1); 72 | for (; i >= 0; i--) { 73 | if (seq == src[i]) { 74 | pos += 1; 75 | } else { 76 | break; 77 | } 78 | } 79 | 80 | if (pos > 0) { 81 | tmp.erase(len - pos); 82 | } 83 | 84 | return tmp; 85 | } 86 | 87 | 88 | int string_util::parse_char(const char *str) { 89 | if (str == NULL) { 90 | return 0; 91 | } 92 | 93 | unsigned char p = (unsigned char)(*str); 94 | int n = 0; 95 | while (p & 0x80) { 96 | ++n; 97 | p = p << 1; 98 | } 99 | 100 | if (n == 0) { 101 | ++n; 102 | } else if (n > 4) { 103 | n = 1; 104 | } 105 | 106 | return n; 107 | } 108 | 109 | int string_util::punct_process(const std::string &raw_str, std::string &norm_str, const std::string &replacer) { 110 | if (raw_str.empty()) { 111 | return -1; 112 | } 113 | 114 | norm_str.clear(); 115 | char *p = (char *)raw_str.c_str(); 116 | size_t offset = 0; 117 | int len = 0; 118 | while (*p) { 119 | len = parse_char(p); 120 | if (len == 1) { 121 | if (ispunct(*p)) { 122 | norm_str += replacer; 123 | } else { 124 | norm_str.push_back(*p); 125 | } 126 | } else if (len > 1) { 127 | std::string c = raw_str.substr(offset, len); 128 | if (is_cn_punct(c)) { 129 | norm_str += replacer; 130 | } else { 131 | norm_str += 
c; 132 | } 133 | } 134 | p += len; 135 | offset += len; 136 | } 137 | 138 | norm_str = trim(norm_str); 139 | return 0; 140 | } 141 | 142 | bool string_util::is_cn_punct(const std::string &word) { 143 | if (word.empty()) { 144 | return false; 145 | } 146 | 147 | if (CN_PUNCS.find(word) != CN_PUNCS.end()) { 148 | return true; 149 | } else { 150 | return false; 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/app.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery File Upload Plugin Angular JS Example 1.2.1 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2013, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* jshint nomen:false */ 13 | /* global window, angular */ 14 | 15 | (function () { 16 | 'use strict'; 17 | 18 | var isOnGitHub = window.location.hostname === 'blueimp.github.io', 19 | url = isOnGitHub ? 
'//jquery-file-upload.appspot.com/' : 'server/php/'; 20 | 21 | angular.module('demo', [ 22 | 'blueimp.fileupload' 23 | ]) 24 | .config([ 25 | '$httpProvider', 'fileUploadProvider', 26 | function ($httpProvider, fileUploadProvider) { 27 | delete $httpProvider.defaults.headers.common['X-Requested-With']; 28 | fileUploadProvider.defaults.redirect = window.location.href.replace( 29 | /\/[^\/]*$/, 30 | '/cors/result.html?%s' 31 | ); 32 | if (isOnGitHub) { 33 | // Demo settings: 34 | angular.extend(fileUploadProvider.defaults, { 35 | // Enable image resizing, except for Android and Opera, 36 | // which actually support image resizing, but fail to 37 | // send Blob objects via XHR requests: 38 | disableImageResize: /Android(?!.*Chrome)|Opera/ 39 | .test(window.navigator.userAgent), 40 | maxFileSize: 5000000, 41 | acceptFileTypes: /(\.|\/)(gif|jpe?g|png)$/i 42 | }); 43 | } 44 | } 45 | ]) 46 | 47 | .controller('DemoFileUploadController', [ 48 | '$scope', '$http', '$filter', '$window', 49 | function ($scope, $http) { 50 | $scope.options = { 51 | url: url 52 | }; 53 | if (!isOnGitHub) { 54 | $scope.loadingFiles = true; 55 | $http.get(url) 56 | .then( 57 | function (response) { 58 | $scope.loadingFiles = false; 59 | $scope.queue = response.data.files || []; 60 | }, 61 | function () { 62 | $scope.loadingFiles = false; 63 | } 64 | ); 65 | } 66 | } 67 | ]) 68 | 69 | .controller('FileDestroyController', [ 70 | '$scope', '$http', 71 | function ($scope, $http) { 72 | var file = $scope.file, 73 | state; 74 | if (file.url) { 75 | file.$state = function () { 76 | return state; 77 | }; 78 | file.$destroy = function () { 79 | state = 'pending'; 80 | return $http({ 81 | url: file.deleteUrl, 82 | method: file.deleteType 83 | }).then( 84 | function () { 85 | state = 'resolved'; 86 | $scope.clear(file); 87 | }, 88 | function () { 89 | state = 'rejected'; 90 | } 91 | ); 92 | }; 93 | } else if (!file.$cancel && !file._index) { 94 | file.$cancel = function () { 95 | $scope.clear(file); 96 | }; 
97 | } 98 | } 99 | ]); 100 | 101 | }()); 102 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/cors/jquery.postmessage-transport.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery postMessage Transport Plugin 1.1.1 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2011, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* global define, window, document */ 13 | 14 | (function (factory) { 15 | 'use strict'; 16 | if (typeof define === 'function' && define.amd) { 17 | // Register as an anonymous AMD module: 18 | define(['jquery'], factory); 19 | } else { 20 | // Browser globals: 21 | factory(window.jQuery); 22 | } 23 | }(function ($) { 24 | 'use strict'; 25 | 26 | var counter = 0, 27 | names = [ 28 | 'accepts', 29 | 'cache', 30 | 'contents', 31 | 'contentType', 32 | 'crossDomain', 33 | 'data', 34 | 'dataType', 35 | 'headers', 36 | 'ifModified', 37 | 'mimeType', 38 | 'password', 39 | 'processData', 40 | 'timeout', 41 | 'traditional', 42 | 'type', 43 | 'url', 44 | 'username' 45 | ], 46 | convert = function (p) { 47 | return p; 48 | }; 49 | 50 | $.ajaxSetup({ 51 | converters: { 52 | 'postmessage text': convert, 53 | 'postmessage json': convert, 54 | 'postmessage html': convert 55 | } 56 | }); 57 | 58 | $.ajaxTransport('postmessage', function (options) { 59 | if (options.postMessage && window.postMessage) { 60 | var iframe, 61 | loc = $('').prop('href', options.postMessage)[0], 62 | target = loc.protocol + '//' + loc.host, 63 | xhrUpload = options.xhr().upload; 64 | return { 65 | send: function (_, completeCallback) { 66 | counter += 1; 67 | var message = { 68 | id: 'postmessage-transport-' + counter 69 | }, 70 | eventName = 'message.' 
+ message.id; 71 | iframe = $( 72 | '' 75 | ).bind('load', function () { 76 | $.each(names, function (i, name) { 77 | message[name] = options[name]; 78 | }); 79 | message.dataType = message.dataType.replace('postmessage ', ''); 80 | $(window).bind(eventName, function (e) { 81 | e = e.originalEvent; 82 | var data = e.data, 83 | ev; 84 | if (e.origin === target && data.id === message.id) { 85 | if (data.type === 'progress') { 86 | ev = document.createEvent('Event'); 87 | ev.initEvent(data.type, false, true); 88 | $.extend(ev, data); 89 | xhrUpload.dispatchEvent(ev); 90 | } else { 91 | completeCallback( 92 | data.status, 93 | data.statusText, 94 | {postmessage: data.result}, 95 | data.headers 96 | ); 97 | iframe.remove(); 98 | $(window).unbind(eventName); 99 | } 100 | } 101 | }); 102 | iframe[0].contentWindow.postMessage( 103 | message, 104 | target 105 | ); 106 | }).appendTo(document.body); 107 | }, 108 | abort: function () { 109 | if (iframe) { 110 | iframe.remove(); 111 | } 112 | } 113 | }; 114 | } 115 | }); 116 | 117 | })); 118 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/jquery.fileupload-validate.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery File Upload Validation Plugin 1.1.2 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2013, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* global define, window */ 13 | 14 | (function (factory) { 15 | 'use strict'; 16 | if (typeof define === 'function' && define.amd) { 17 | // Register as an anonymous AMD module: 18 | define([ 19 | 'jquery', 20 | './jquery.fileupload-process' 21 | ], factory); 22 | } else { 23 | // Browser globals: 24 | factory( 25 | window.jQuery 26 | ); 27 | } 28 | }(function ($) { 29 | 'use strict'; 30 | 31 | // Append to the default processQueue: 32 | 
$.blueimp.fileupload.prototype.options.processQueue.push( 33 | { 34 | action: 'validate', 35 | // Always trigger this action, 36 | // even if the previous action was rejected: 37 | always: true, 38 | // Options taken from the global options map: 39 | acceptFileTypes: '@', 40 | maxFileSize: '@', 41 | minFileSize: '@', 42 | maxNumberOfFiles: '@', 43 | disabled: '@disableValidation' 44 | } 45 | ); 46 | 47 | // The File Upload Validation plugin extends the fileupload widget 48 | // with file validation functionality: 49 | $.widget('blueimp.fileupload', $.blueimp.fileupload, { 50 | 51 | options: { 52 | /* 53 | // The regular expression for allowed file types, matches 54 | // against either file type or file name: 55 | acceptFileTypes: /(\.|\/)(gif|jpe?g|png)$/i, 56 | // The maximum allowed file size in bytes: 57 | maxFileSize: 10000000, // 10 MB 58 | // The minimum allowed file size in bytes: 59 | minFileSize: undefined, // No minimal file size 60 | // The limit of files to be uploaded: 61 | maxNumberOfFiles: 10, 62 | */ 63 | 64 | // Function returning the current number of files, 65 | // has to be overriden for maxNumberOfFiles validation: 66 | getNumberOfFiles: $.noop, 67 | 68 | // Error and info messages: 69 | messages: { 70 | maxNumberOfFiles: 'Maximum number of files exceeded', 71 | acceptFileTypes: 'File type not allowed', 72 | maxFileSize: 'File is too large', 73 | minFileSize: 'File is too small' 74 | } 75 | }, 76 | 77 | processActions: { 78 | 79 | validate: function (data, options) { 80 | if (options.disabled) { 81 | return data; 82 | } 83 | var dfd = $.Deferred(), 84 | settings = this.options, 85 | file = data.files[data.index], 86 | fileSize; 87 | if (options.minFileSize || options.maxFileSize) { 88 | fileSize = file.size; 89 | } 90 | if ($.type(options.maxNumberOfFiles) === 'number' && 91 | (settings.getNumberOfFiles() || 0) + data.files.length > 92 | options.maxNumberOfFiles) { 93 | file.error = settings.i18n('maxNumberOfFiles'); 94 | } else if 
(options.acceptFileTypes && 95 | !(options.acceptFileTypes.test(file.type) || 96 | options.acceptFileTypes.test(file.name))) { 97 | file.error = settings.i18n('acceptFileTypes'); 98 | } else if (fileSize > options.maxFileSize) { 99 | file.error = settings.i18n('maxFileSize'); 100 | } else if ($.type(fileSize) === 'number' && 101 | fileSize < options.minFileSize) { 102 | file.error = settings.i18n('minFileSize'); 103 | } else { 104 | delete file.error; 105 | } 106 | if (file.error || data.files.error) { 107 | data.files.error = true; 108 | dfd.rejectWith(this, [data]); 109 | } else { 110 | dfd.resolveWith(this, [data]); 111 | } 112 | return dfd.promise(); 113 | } 114 | 115 | } 116 | 117 | }); 118 | 119 | })); 120 | -------------------------------------------------------------------------------- /UNF/modules/module_util.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch.nn.init as init 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from modules.base_type import InitType, FAN_MODE, ActivationType 8 | 9 | def init_tensor(tensor, init_type=InitType.XAVIER_UNIFORM, low=0, high=1, 10 | mean=0, std=1, activation_type=ActivationType.NONE, 11 | fan_mode=FAN_MODE.FAN_IN, negative_slope=0): 12 | """ 13 | 各种标准的tensor参数初始化方法 14 | """ 15 | if init_type == InitType.UNIFORM: 16 | return torch.nn.init.uniform_(tensor, a=low, b=high) 17 | elif init_type == InitType.NORMAL: 18 | return torch.nn.init.normal_(tensor, mean=mean, std=std) 19 | elif init_type == InitType.XAVIER_UNIFORM: 20 | return torch.nn.init.xavier_uniform_( 21 | tensor, gain=torch.nn.init.calculate_gain(activation_type)) 22 | elif init_type == InitType.XAVIER_NORMAL: 23 | return torch.nn.init.xavier_normal_( 24 | tensor, gain=torch.nn.init.calculate_gain(activation_type)) 25 | elif init_type == InitType.KAIMING_UNIFORM: 26 | return torch.nn.init.kaiming_uniform_( 27 | tensor, a=negative_slope, mode=fan_mode, 28 
| nonlinearity=activation_type) 29 | elif init_type == InitType.KAIMING_NORMAL: 30 | return torch.nn.init.kaiming_normal_( 31 | tensor, a=negative_slope, mode=fan_mode, 32 | nonlinearity=activation_type) 33 | elif init_type == InitType.ORTHOGONAL: 34 | return torch.nn.init.orthogonal_( 35 | tensor, gain=torch.nn.init.calculate_gain(activation_type)) 36 | else: 37 | raise TypeError( 38 | "Unsupported tensor init type: %s. Supported init type is: %s" % ( 39 | init_type, InitType.str())) 40 | 41 | def initial_parameter(net, initial_method=None): 42 | """A method used to initialize the weights of PyTorch models. 43 | :param net: a PyTorch model 44 | :param str initial_method: one of the following initializations. 45 | - xavier_uniform 46 | - xavier_normal (default) 47 | - kaiming_normal, or msra 48 | - kaiming_uniform 49 | - orthogonal 50 | - sparse 51 | - normal 52 | - uniform 53 | """ 54 | if initial_method == 'xavier_uniform': 55 | init_method = init.xavier_uniform_ 56 | elif initial_method == 'xavier_normal': 57 | init_method = init.xavier_normal_ 58 | elif initial_method == 'kaiming_normal' or initial_method == 'msra': 59 | init_method = init.kaiming_normal_ 60 | elif initial_method == 'kaiming_uniform': 61 | init_method = init.kaiming_uniform_ 62 | elif initial_method == 'orthogonal': 63 | init_method = init.orthogonal_ 64 | elif initial_method == 'sparse': 65 | init_method = init.sparse_ 66 | elif initial_method == 'normal': 67 | init_method = init.normal_ 68 | elif initial_method == 'uniform': 69 | init_method = init.uniform_ 70 | else: 71 | init_method = init.xavier_normal_ 72 | 73 | def weights_init(m): 74 | # classname = m.__class__.__name__ 75 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d): # for all the cnn 76 | if initial_method is not None: 77 | init_method(m.weight.data) 78 | else: 79 | init.xavier_normal_(m.weight.data) 80 | init.normal_(m.bias.data) 81 | elif isinstance(m, nn.LSTM): 82 | for w in m.parameters(): 
83 | if len(w.data.size()) > 1: 84 | init_method(w.data) # weight 85 | else: 86 | init.normal_(w.data) # bias 87 | elif m is not None and hasattr(m, 'weight') and \ 88 | hasattr(m.weight, "requires_grad"): 89 | init_method(m.weight.data) 90 | else: 91 | for w in m.parameters(): 92 | if w.requires_grad: 93 | if len(w.data.size()) > 1: 94 | init_method(w.data) # weight 95 | else: 96 | init.normal_(w.data) # bias 97 | # print("init else") 98 | 99 | net.apply(weights_init) 100 | 101 | 102 | def mask_softmax(input, dim, mask=None): 103 | """ 104 | 根据dim和mask,对input做softmax的操作 105 | """ 106 | if mask is None: 107 | return F.softmax(input, dim=dim) 108 | 109 | else: 110 | masked_input = input.masked_fill((1 - mask.unsqueeze(-1)).byte(), -1e32) 111 | return F.softmax(masked_input, dim=dim) 112 | 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | UNF(Universal NLP Framework) is built on pytorch and torchtext. Its design philosophy is: 4 | - ***modularity***: specifically, on the one hand, it is convenient to quickly run some nlp-related tasks; on the other hand, it is convenient for secondary development and research to implement some new models or technologies. 5 | - ***efficiency***: supports **distributed training** and **half-precision** training, which is convenient for quickly training the model, although the current support is relatively crude 6 | - ***comprehensive***: support pytorch **trace** into static graph, support **c ++** server, provide web-server for **debugging tools** 7 | 8 | # Support Tasks 9 | Now, support ***text classification*** and ***sequence labeling*** related tasks. 
10 | 11 | # Related Papers 12 | - Convolutional Neural Networks for Sentence Classification [2014](https://arxiv.org/abs/1408.5882) 13 | - Bag of Tricks for Efficient Text Classification [2016](https://arxiv.org/pdf/1607.01759.pdf) 14 | - Deep Pyramid Convolutional Neural Networks for Text Categorization [2017, ACL](https://www.aclweb.org/anthology/P17-1052) 15 | - Hierarchical Attention Networks for Document Classification [2016, NAACL](https://www.cs.cmu.edu/~./hovy/papers/16HLT-hierarchical-attention-networks.pdf) 16 | - A Structured Self-Attentive Sentence Embedding [2017, ICLR](https://arxiv.org/abs/1703.03130) 17 | - Joint Embedding of Words and Labels for Text Classification [2018, ACL](https://www.aclweb.org/anthology/P18-1216/) 18 | - Neural Architectures for Named Entity Recognition [2016, NAACL](https://www.aclweb.org/anthology/N16-1030/) 19 | - Semi-supervised Multitask Learning for Sequence Labeling [2017, ACL](https://arxiv.org/abs/1704.07156) 20 | 21 | 22 | # Framework 23 | ![image](https://github.com/waterzxj/UNF/blob/master/pic/framework.png) 24 | 25 | 26 | # Module relation 27 | 28 | Module name | Module function 29 | ---|--- 30 | UNF.data | Loads data from disk into RAM, including batching, padding, and numericalization 31 | UNF.modules | Neural network layers (encoders, decoders, embeddings) used to build the models 32 | UNF.models | Neural network model architectures, including DpCnn, SelfAttention, Lstm-crf, etc., plus Python predictors for those models 33 | UNF.training | Model training, including early stopping, model saving and reloading, and metric visualization through TensorBoard 34 | UNF.trace | Traces the PyTorch dynamic graph into a static graph and provides C++ serving 35 | UNF.web_server | Web server tools 36 | 37 | 38 | # Requirements 39 | python3 40 | 41 | pip3 install -r requirement.txt 42 | 43 | # Training 44 | 45 | ``` 46 | #quick start 47 | python3 train_flow.py 48 | ``` 49 | ***Only* 5 lines of code needed** 50 | ``` 51 | #data loader 52 | data_loader = DataLoader(data_loader_conf) 53 | 
train_iter, dev_iter, test_iter = data_loader.generate_dataset() 54 | 55 | #model loader 56 | model, model_conf = ModelLoader.from_params(model_conf, data_loader.fields) 57 | 58 | #learner loader 59 | learner = LearnerLoader.from_params(model, train_iter, dev_iter, learner_conf, test_iter=test_iter, fields=data_loader.fields, model_conf=model_conf) 60 | 61 | #learning 62 | learner.learn() 63 | ``` 64 | ### train_flow.py is the demo training script; run it directly! 65 | 66 | ### Config for multi-GPU and mixed-precision training 67 | ``` 68 | "use_fp16": False, 69 | "multi_gpu": False 70 | ``` 71 | 72 | ### TensorBoard demo 73 | ![image](https://github.com/waterzxj/UNF/blob/master/pic/tensorboard1.png) 74 | ![image](https://github.com/waterzxj/UNF/blob/master/pic/tensorboard2.png) 75 | 76 | # Python inference 77 | 78 | ``` 79 | #quick start 80 | python3 score_flow.py 81 | ``` 82 | 83 | ``` 84 | #core code 85 | from models.predictor import Predictor 86 | 87 | predictor = Predictor(model_path, device, model_type) 88 | logits = predictor.predict(input) 89 | 90 | #output: (0.18, -0.67) 91 | ``` 92 | 93 | # C++ inference 94 | 95 | ### step1: Trace the dynamic graph into a static graph 96 | 97 | 98 | ``` 99 | #quick start 100 | python3 trace.py 101 | ``` 102 | 103 | ``` 104 | #core code 105 | net = globals()[model_cls](**config.__dict__) 106 | net.load_state_dict_trace(torch.load("%s/best.th" % model_path)) 107 | net.eval() 108 | 109 | mock_input = net.mock_input_data() 110 | tr = torch.jit.trace(net, mock_input) 111 | tr.save("trace/%s" % save_path) 112 | ``` 113 | 114 | ### step2: C++ serving 115 | - install CMake 116 | - download [libtorch](https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.2.0.zip) and unzip it into the trace folder 117 | 118 | ``` 119 | cmake -DCMAKE_PREFIX_PATH=libtorch . 
121 | ``` 122 | ![image](https://github.com/waterzxj/UNF/blob/master/pic/cmake.png) 123 | 124 | ``` 125 | make 126 | ``` 127 | ![image](https://github.com/waterzxj/UNF/blob/master/pic/make.png) 128 | 129 | ``` 130 | ./predict trace.pt predict_vocab.txt 131 | output: 2.2128 -2.3287 132 | ``` 133 | 134 | # RESTFUL-API web demo 135 | 136 | ``` 137 | cd web_server 138 | python run.py 139 | ``` 140 | 141 | ![image](https://github.com/waterzxj/UNF/blob/master/pic/web_demo.png) 142 | 143 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/jquery.fileupload-jquery-ui.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery File Upload jQuery UI Plugin 8.7.1 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2013, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* jshint nomen:false */ 13 | /* global define, window */ 14 | 15 | (function (factory) { 16 | 'use strict'; 17 | if (typeof define === 'function' && define.amd) { 18 | // Register as an anonymous AMD module: 19 | define(['jquery', './jquery.fileupload-ui'], factory); 20 | } else { 21 | // Browser globals: 22 | factory(window.jQuery); 23 | } 24 | }(function ($) { 25 | 'use strict'; 26 | 27 | $.widget('blueimp.fileupload', $.blueimp.fileupload, { 28 | 29 | options: { 30 | processdone: function (e, data) { 31 | data.context.find('.start').button('enable'); 32 | }, 33 | progress: function (e, data) { 34 | if (data.context) { 35 | data.context.find('.progress').progressbar( 36 | 'option', 37 | 'value', 38 | parseInt(data.loaded / data.total * 100, 10) 39 | ); 40 | } 41 | }, 42 | progressall: function (e, data) { 43 | var $this = $(this); 44 | $this.find('.fileupload-progress') 45 | .find('.progress').progressbar( 46 | 'option', 47 | 'value', 48 | parseInt(data.loaded / data.total * 100, 10) 49 | 
).end() 50 | .find('.progress-extended').each(function () { 51 | $(this).html( 52 | ($this.data('blueimp-fileupload') || 53 | $this.data('fileupload')) 54 | ._renderExtendedProgress(data) 55 | ); 56 | }); 57 | } 58 | }, 59 | 60 | _renderUpload: function (func, files) { 61 | var node = this._super(func, files), 62 | showIconText = $(window).width() > 480; 63 | node.find('.progress').empty().progressbar(); 64 | node.find('.start').button({ 65 | icons: {primary: 'ui-icon-circle-arrow-e'}, 66 | text: showIconText 67 | }); 68 | node.find('.cancel').button({ 69 | icons: {primary: 'ui-icon-cancel'}, 70 | text: showIconText 71 | }); 72 | if (node.hasClass('fade')) { 73 | node.hide(); 74 | } 75 | return node; 76 | }, 77 | 78 | _renderDownload: function (func, files) { 79 | var node = this._super(func, files), 80 | showIconText = $(window).width() > 480; 81 | node.find('.delete').button({ 82 | icons: {primary: 'ui-icon-trash'}, 83 | text: showIconText 84 | }); 85 | if (node.hasClass('fade')) { 86 | node.hide(); 87 | } 88 | return node; 89 | }, 90 | 91 | _startHandler: function (e) { 92 | $(e.currentTarget).button('disable'); 93 | this._super(e); 94 | }, 95 | 96 | _transition: function (node) { 97 | var deferred = $.Deferred(); 98 | if (node.hasClass('fade')) { 99 | node.fadeToggle( 100 | this.options.transitionDuration, 101 | this.options.transitionEasing, 102 | function () { 103 | deferred.resolveWith(node); 104 | } 105 | ); 106 | } else { 107 | deferred.resolveWith(node); 108 | } 109 | return deferred; 110 | }, 111 | 112 | _create: function () { 113 | this._super(); 114 | this.element 115 | .find('.fileupload-buttonbar') 116 | .find('.fileinput-button').each(function () { 117 | var input = $(this).find('input:file').detach(); 118 | $(this) 119 | .button({icons: {primary: 'ui-icon-plusthick'}}) 120 | .append(input); 121 | }) 122 | .end().find('.start') 123 | .button({icons: {primary: 'ui-icon-circle-arrow-e'}}) 124 | .end().find('.cancel') 125 | .button({icons: {primary: 
'ui-icon-cancel'}}) 126 | .end().find('.delete') 127 | .button({icons: {primary: 'ui-icon-trash'}}) 128 | .end().find('.progress').progressbar(); 129 | }, 130 | 131 | _destroy: function () { 132 | this.element 133 | .find('.fileupload-buttonbar') 134 | .find('.fileinput-button').each(function () { 135 | var input = $(this).find('input:file').detach(); 136 | $(this) 137 | .button('destroy') 138 | .append(input); 139 | }) 140 | .end().find('.start') 141 | .button('destroy') 142 | .end().find('.cancel') 143 | .button('destroy') 144 | .end().find('.delete') 145 | .button('destroy') 146 | .end().find('.progress').progressbar('destroy'); 147 | this._super(); 148 | } 149 | 150 | }); 151 | 152 | })); 153 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/jquery.fileupload-process.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery File Upload Processing Plugin 1.3.0 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2012, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* jshint nomen:false */ 13 | /* global define, window */ 14 | 15 | (function (factory) { 16 | 'use strict'; 17 | if (typeof define === 'function' && define.amd) { 18 | // Register as an anonymous AMD module: 19 | define([ 20 | 'jquery', 21 | './jquery.fileupload' 22 | ], factory); 23 | } else { 24 | // Browser globals: 25 | factory( 26 | window.jQuery 27 | ); 28 | } 29 | }(function ($) { 30 | 'use strict'; 31 | 32 | var originalAdd = $.blueimp.fileupload.prototype.options.add; 33 | 34 | // The File Upload Processing plugin extends the fileupload widget 35 | // with file processing functionality: 36 | $.widget('blueimp.fileupload', $.blueimp.fileupload, { 37 | 38 | options: { 39 | // The list of processing actions: 40 | processQueue: [ 41 | /* 42 | { 43 | action: 'log', 44 | 
type: 'debug' 45 | } 46 | */ 47 | ], 48 | add: function (e, data) { 49 | var $this = $(this); 50 | data.process(function () { 51 | return $this.fileupload('process', data); 52 | }); 53 | originalAdd.call(this, e, data); 54 | } 55 | }, 56 | 57 | processActions: { 58 | /* 59 | log: function (data, options) { 60 | console[options.type]( 61 | 'Processing "' + data.files[data.index].name + '"' 62 | ); 63 | } 64 | */ 65 | }, 66 | 67 | _processFile: function (data, originalData) { 68 | var that = this, 69 | dfd = $.Deferred().resolveWith(that, [data]), 70 | chain = dfd.promise(); 71 | this._trigger('process', null, data); 72 | $.each(data.processQueue, function (i, settings) { 73 | var func = function (data) { 74 | if (originalData.errorThrown) { 75 | return $.Deferred() 76 | .rejectWith(that, [originalData]).promise(); 77 | } 78 | return that.processActions[settings.action].call( 79 | that, 80 | data, 81 | settings 82 | ); 83 | }; 84 | chain = chain.pipe(func, settings.always && func); 85 | }); 86 | chain 87 | .done(function () { 88 | that._trigger('processdone', null, data); 89 | that._trigger('processalways', null, data); 90 | }) 91 | .fail(function () { 92 | that._trigger('processfail', null, data); 93 | that._trigger('processalways', null, data); 94 | }); 95 | return chain; 96 | }, 97 | 98 | // Replaces the settings of each processQueue item that 99 | // are strings starting with an "@", using the remaining 100 | // substring as key for the option map, 101 | // e.g. "@autoUpload" is replaced with options.autoUpload: 102 | _transformProcessQueue: function (options) { 103 | var processQueue = []; 104 | $.each(options.processQueue, function () { 105 | var settings = {}, 106 | action = this.action, 107 | prefix = this.prefix === true ? action : this.prefix; 108 | $.each(this, function (key, value) { 109 | if ($.type(value) === 'string' && 110 | value.charAt(0) === '@') { 111 | settings[key] = options[ 112 | value.slice(1) || (prefix ? 
prefix + 113 | key.charAt(0).toUpperCase() + key.slice(1) : key) 114 | ]; 115 | } else { 116 | settings[key] = value; 117 | } 118 | 119 | }); 120 | processQueue.push(settings); 121 | }); 122 | options.processQueue = processQueue; 123 | }, 124 | 125 | // Returns the number of files currently in the processsing queue: 126 | processing: function () { 127 | return this._processing; 128 | }, 129 | 130 | // Processes the files given as files property of the data parameter, 131 | // returns a Promise object that allows to bind callbacks: 132 | process: function (data) { 133 | var that = this, 134 | options = $.extend({}, this.options, data); 135 | if (options.processQueue && options.processQueue.length) { 136 | this._transformProcessQueue(options); 137 | if (this._processing === 0) { 138 | this._trigger('processstart'); 139 | } 140 | $.each(data.files, function (index) { 141 | var opts = index ? $.extend({}, options) : options, 142 | func = function () { 143 | if (data.errorThrown) { 144 | return $.Deferred() 145 | .rejectWith(that, [data]).promise(); 146 | } 147 | return that._processFile(opts, data); 148 | }; 149 | opts.index = index; 150 | that._processing += 1; 151 | that._processingQueue = that._processingQueue.pipe(func, func) 152 | .always(function () { 153 | that._processing -= 1; 154 | if (that._processing === 0) { 155 | that._trigger('processstop'); 156 | } 157 | }); 158 | }); 159 | } 160 | return this._processingQueue; 161 | }, 162 | 163 | _create: function () { 164 | this._super(); 165 | this._processing = 0; 166 | this._processingQueue = $.Deferred().resolveWith(this) 167 | .promise(); 168 | } 169 | 170 | }); 171 | 172 | })); 173 | -------------------------------------------------------------------------------- /UNF/training/metric.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | import torch 4 | 5 | from learner_util import get_ner_BIO 6 | 7 | 8 | class Metric(object): 9 | def __call__(self, 10 | 
predictions, 11 | gold_labels, 12 | mask=None): 13 | """ 14 | metric的抽象类 15 | 16 | :params predictions 预测结果的tensor 17 | :params gold_labels 实际结果的tensor 18 | :mask mask 19 | """ 20 | raise NotImplementedError 21 | 22 | def get_metric(self, reset=False): 23 | """ 24 | 返回metric的指标 25 | """ 26 | raise NotImplementedError 27 | 28 | def reset(self): 29 | """ 30 | 重置内部状态 31 | """ 32 | raise NotImplementedError 33 | 34 | @staticmethod 35 | def unwrap_to_tensors(*tensors): 36 | """ 37 | 把tensor安全的copy到cpu进行操作,避免gpu的oom 38 | """ 39 | return (x.detach().cpu() if isinstance(x, torch.Tensor) else x for x in tensors) 40 | 41 | @classmethod 42 | def from_option(cls, conf): 43 | return cls(**conf) 44 | 45 | 46 | class F1Measure(Metric): 47 | def __init__(self, positive_label): 48 | """ 49 | 准确率、召回率、F值的评价指标 50 | """ 51 | super(F1Measure, self).__init__() 52 | self._positive_label = positive_label 53 | self._true_positives = 0.0 54 | self._true_negatives = 0.0 55 | self._false_positives = 0.0 56 | self._false_negatives = 0.0 57 | 58 | def __call__(self, 59 | predictions, 60 | gold_labels, 61 | mask=None): 62 | predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask) 63 | num_classes = predictions.size(-1) 64 | if (gold_labels >= num_classes).any(): 65 | raise Exception("A gold label passed to F1Measure contains an id >= {}, " 66 | "the number of classes.".format(num_classes)) 67 | if mask is None: 68 | mask = torch.ones_like(gold_labels) 69 | mask = mask.float() 70 | gold_labels = gold_labels.float() 71 | 72 | self.update(predictions, gold_labels, mask) 73 | 74 | def update(self, predictions, gold_labels, mask): 75 | positive_label_mask = gold_labels.eq(self._positive_label).float() 76 | negative_label_mask = 1.0 - positive_label_mask 77 | 78 | argmax_predictions = predictions.max(-1)[1].float().squeeze(-1) 79 | 80 | # True Negatives: correct non-positive predictions. 
81 | correct_null_predictions = (argmax_predictions != 82 | self._positive_label).float() * negative_label_mask 83 | self._true_negatives += (correct_null_predictions.float() * mask).sum() 84 | 85 | # True Positives: correct positively labeled predictions. 86 | correct_non_null_predictions = (argmax_predictions == 87 | self._positive_label).float() * positive_label_mask 88 | self._true_positives += (correct_non_null_predictions * mask).sum() 89 | 90 | # False Negatives: incorrect negatively labeled predictions. 91 | incorrect_null_predictions = (argmax_predictions != 92 | self._positive_label).float() * positive_label_mask 93 | self._false_negatives += (incorrect_null_predictions * mask).sum() 94 | 95 | # False Positives: incorrect positively labeled predictions 96 | incorrect_non_null_predictions = (argmax_predictions == 97 | self._positive_label).float() * negative_label_mask 98 | self._false_positives += (incorrect_non_null_predictions * mask).sum() 99 | 100 | def get_metric(self, reset=False): 101 | """ 102 | 返回准确率、召回率、F值评价指标 103 | """ 104 | # print('TP',self._true_positives,'TN',self._true_negatives,'FP',self._false_positives,'FN',self._false_negatives) 105 | 106 | precision = float(self._true_positives) / float(self._true_positives + self._false_positives + 1e-13) 107 | recall = float(self._true_positives) / float(self._true_positives + self._false_negatives + 1e-13) 108 | f1_measure = 2. 
* ((precision * recall) / (precision + recall + 1e-13)) 109 | if reset: 110 | self.reset() 111 | return {"precision":precision, "recall": recall, "f1_measure":f1_measure} 112 | 113 | def reset(self): 114 | self._true_positives = 0.0 115 | self._true_negatives = 0.0 116 | self._false_positives = 0.0 117 | self._false_negatives = 0.0 118 | 119 | 120 | class NerF1Measure(Metric): 121 | def __init__(self, label_vocab): 122 | self.golden_num = 0.0 123 | self.predict_num = 0.0 124 | self.right_num = 0.0 125 | self.label_vocab = label_vocab 126 | 127 | def reset(self): 128 | """ 129 | 重置内部状态 130 | """ 131 | self.golden_num = 0.0 132 | self.predict_num = 0.0 133 | self.right_num = 0.0 134 | 135 | def get_metric(self, reset=False): 136 | """ 137 | 返回metric的指标 138 | """ 139 | if self.predict_num == 0.0: 140 | precision = -1 141 | else: 142 | precision = (self.right_num+0.0)/self.predict_num 143 | 144 | if self.golden_num == 0.0: 145 | recall = -1 146 | else: 147 | recall = (self.right_num+0.0)/self.golden_num 148 | 149 | if (precision == -1) or (recall == -1) or (precision+recall) <= 0.: 150 | f_measure = -1 151 | else: 152 | f_measure = 2*precision*recall/(precision+recall) 153 | 154 | if reset: 155 | self.reset() 156 | 157 | return {"precision":precision, "recall": recall, "f1_measure":f_measure} 158 | 159 | def update(self, gold_matrix, pred_matrix): 160 | right_ner = list(set(gold_matrix).intersection(set(pred_matrix))) 161 | self.golden_num += len(gold_matrix) 162 | self.predict_num += len(pred_matrix) 163 | self.right_num += len(right_ner) 164 | 165 | def __call__(self, 166 | predictions, 167 | gold_labels, 168 | mask=None): 169 | """ 170 | metric的抽象类 171 | 172 | :params predictions 预测结果的tensor 173 | :params gold_labels 实际结果的tensor 174 | :mask mask 175 | """ 176 | batch_size = gold_labels.size(0) 177 | seq_len = gold_labels.size(1) 178 | predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, 179 | mask) 180 | 181 | predictions = 
predictions.tolist() 182 | gold_labels = gold_labels.tolist() 183 | mask = mask.tolist() 184 | 185 | for idx in range(batch_size): 186 | pred = [self.label_vocab[predictions[idx][idy]] for idy in range(seq_len) if mask[idx][idy] != 0] 187 | gold = [self.label_vocab[gold_labels[idx][idy]] for idy in range(seq_len) if mask[idx][idy] != 0] 188 | 189 | 190 | gold_matrix = get_ner_BIO(gold) 191 | pred_matrix = get_ner_BIO(pred) 192 | self.update(gold_matrix, pred_matrix) 193 | 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/jquery.iframe-transport.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery Iframe Transport Plugin 1.8.2 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2011, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* global define, window, document */ 13 | 14 | (function (factory) { 15 | 'use strict'; 16 | if (typeof define === 'function' && define.amd) { 17 | // Register as an anonymous AMD module: 18 | define(['jquery'], factory); 19 | } else { 20 | // Browser globals: 21 | factory(window.jQuery); 22 | } 23 | }(function ($) { 24 | 'use strict'; 25 | 26 | // Helper variable to create unique names for the transport iframes: 27 | var counter = 0; 28 | 29 | // The iframe transport accepts four additional options: 30 | // options.fileInput: a jQuery collection of file input fields 31 | // options.paramName: the parameter name for the file form data, 32 | // overrides the name property of the file input field(s), 33 | // can be a string or an array of strings. 
34 | // options.formData: an array of objects with name and value properties, 35 | // equivalent to the return data of .serializeArray(), e.g.: 36 | // [{name: 'a', value: 1}, {name: 'b', value: 2}] 37 | // options.initialIframeSrc: the URL of the initial iframe src, 38 | // by default set to "javascript:false;" 39 | $.ajaxTransport('iframe', function (options) { 40 | if (options.async) { 41 | // javascript:false as initial iframe src 42 | // prevents warning popups on HTTPS in IE6: 43 | /*jshint scripturl: true */ 44 | var initialIframeSrc = options.initialIframeSrc || 'javascript:false;', 45 | /*jshint scripturl: false */ 46 | form, 47 | iframe, 48 | addParamChar; 49 | return { 50 | send: function (_, completeCallback) { 51 | form = $('
'); 52 | form.attr('accept-charset', options.formAcceptCharset); 53 | addParamChar = /\?/.test(options.url) ? '&' : '?'; 54 | // XDomainRequest only supports GET and POST: 55 | if (options.type === 'DELETE') { 56 | options.url = options.url + addParamChar + '_method=DELETE'; 57 | options.type = 'POST'; 58 | } else if (options.type === 'PUT') { 59 | options.url = options.url + addParamChar + '_method=PUT'; 60 | options.type = 'POST'; 61 | } else if (options.type === 'PATCH') { 62 | options.url = options.url + addParamChar + '_method=PATCH'; 63 | options.type = 'POST'; 64 | } 65 | // IE versions below IE8 cannot set the name property of 66 | // elements that have already been added to the DOM, 67 | // so we set the name along with the iframe HTML markup: 68 | counter += 1; 69 | iframe = $( 70 | '' 72 | ).bind('load', function () { 73 | var fileInputClones, 74 | paramNames = $.isArray(options.paramName) ? 75 | options.paramName : [options.paramName]; 76 | iframe 77 | .unbind('load') 78 | .bind('load', function () { 79 | var response; 80 | // Wrap in a try/catch block to catch exceptions thrown 81 | // when trying to access cross-domain iframe contents: 82 | try { 83 | response = iframe.contents(); 84 | // Google Chrome and Firefox do not throw an 85 | // exception when calling iframe.contents() on 86 | // cross-domain requests, so we unify the response: 87 | if (!response.length || !response[0].firstChild) { 88 | throw new Error(); 89 | } 90 | } catch (e) { 91 | response = undefined; 92 | } 93 | // The complete callback returns the 94 | // iframe content document as response object: 95 | completeCallback( 96 | 200, 97 | 'success', 98 | {'iframe': response} 99 | ); 100 | // Fix for IE endless progress bar activity bug 101 | // (happens on form submits to iframe targets): 102 | $('') 103 | .appendTo(form); 104 | window.setTimeout(function () { 105 | // Removing the form in a setTimeout call 106 | // allows Chrome's developer tools to display 107 | // the response result 
108 | form.remove(); 109 | }, 0); 110 | }); 111 | form 112 | .prop('target', iframe.prop('name')) 113 | .prop('action', options.url) 114 | .prop('method', options.type); 115 | if (options.formData) { 116 | $.each(options.formData, function (index, field) { 117 | $('') 118 | .prop('name', field.name) 119 | .val(field.value) 120 | .appendTo(form); 121 | }); 122 | } 123 | if (options.fileInput && options.fileInput.length && 124 | options.type === 'POST') { 125 | fileInputClones = options.fileInput.clone(); 126 | // Insert a clone for each file input field: 127 | options.fileInput.after(function (index) { 128 | return fileInputClones[index]; 129 | }); 130 | if (options.paramName) { 131 | options.fileInput.each(function (index) { 132 | $(this).prop( 133 | 'name', 134 | paramNames[index] || options.paramName 135 | ); 136 | }); 137 | } 138 | // Appending the file input fields to the hidden form 139 | // removes them from their original location: 140 | form 141 | .append(options.fileInput) 142 | .prop('enctype', 'multipart/form-data') 143 | // enctype must be set as encoding for IE: 144 | .prop('encoding', 'multipart/form-data'); 145 | // Remove the HTML5 form attribute from the input(s): 146 | options.fileInput.removeAttr('form'); 147 | } 148 | form.submit(); 149 | // Insert the file input fields at their original location 150 | // by replacing the clones with the originals: 151 | if (fileInputClones && fileInputClones.length) { 152 | options.fileInput.each(function (index, input) { 153 | var clone = $(fileInputClones[index]); 154 | // Restore the original name and form properties: 155 | $(input) 156 | .prop('name', clone.prop('name')) 157 | .attr('form', clone.attr('form')); 158 | clone.replaceWith(input); 159 | }); 160 | } 161 | }); 162 | form.append(iframe).appendTo(document.body); 163 | }, 164 | abort: function () { 165 | if (iframe) { 166 | // javascript:false as iframe src aborts the request 167 | // and prevents warning popups on HTTPS in IE6. 
168 | // concat is used to avoid the "Script URL" JSLint error: 169 | iframe 170 | .unbind('load') 171 | .prop('src', initialIframeSrc); 172 | } 173 | if (form) { 174 | form.remove(); 175 | } 176 | } 177 | }; 178 | } 179 | }); 180 | 181 | // The iframe transport returns the iframe content document as response. 182 | // The following adds converters from iframe to text, json, html, xml 183 | // and script. 184 | // Please note that the Content-Type for JSON responses has to be text/plain 185 | // or text/html, if the browser doesn't include application/json in the 186 | // Accept header, else IE will show a download dialog. 187 | // The Content-Type for XML responses on the other hand has to be always 188 | // application/xml or text/xml, so IE properly parses the XML response. 189 | // See also 190 | // https://github.com/blueimp/jQuery-File-Upload/wiki/Setup#content-type-negotiation 191 | $.ajaxSetup({ 192 | converters: { 193 | 'iframe text': function (iframe) { 194 | return iframe && $(iframe[0].body).text(); 195 | }, 196 | 'iframe json': function (iframe) { 197 | return iframe && $.parseJSON($(iframe[0].body).text()); 198 | }, 199 | 'iframe html': function (iframe) { 200 | return iframe && $(iframe[0].body).html(); 201 | }, 202 | 'iframe xml': function (iframe) { 203 | var xmlDoc = iframe && iframe[0]; 204 | return xmlDoc && $.isXMLDoc(xmlDoc) ? xmlDoc : 205 | $.parseXML((xmlDoc.XMLDocument && xmlDoc.XMLDocument.xml) || 206 | $(xmlDoc.body).html()); 207 | }, 208 | 'iframe script': function (iframe) { 209 | return iframe && $.globalEval($(iframe[0].body).text()); 210 | } 211 | } 212 | }); 213 | 214 | })); 215 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/jquery.fileupload-image.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery File Upload Image Preview & Resize Plugin 1.7.2 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2013, Sebastian Tschan 6 | * https://blueimp.net 7 | * 8 | * Licensed under the MIT license: 9 | * http://www.opensource.org/licenses/MIT 10 | */ 11 | 12 | /* jshint nomen:false */ 13 | /* global define, window, Blob */ 14 | 15 | (function (factory) { 16 | 'use strict'; 17 | if (typeof define === 'function' && define.amd) { 18 | // Register as an anonymous AMD module: 19 | define([ 20 | 'jquery', 21 | 'load-image', 22 | 'load-image-meta', 23 | 'load-image-exif', 24 | 'load-image-ios', 25 | 'canvas-to-blob', 26 | './jquery.fileupload-process' 27 | ], factory); 28 | } else { 29 | // Browser globals: 30 | factory( 31 | window.jQuery, 32 | window.loadImage 33 | ); 34 | } 35 | }(function ($, loadImage) { 36 | 'use strict'; 37 | 38 | // Prepend to the default processQueue: 39 | 
$.blueimp.fileupload.prototype.options.processQueue.unshift( 40 | { 41 | action: 'loadImageMetaData', 42 | disableImageHead: '@', 43 | disableExif: '@', 44 | disableExifThumbnail: '@', 45 | disableExifSub: '@', 46 | disableExifGps: '@', 47 | disabled: '@disableImageMetaDataLoad' 48 | }, 49 | { 50 | action: 'loadImage', 51 | // Use the action as prefix for the "@" options: 52 | prefix: true, 53 | fileTypes: '@', 54 | maxFileSize: '@', 55 | noRevoke: '@', 56 | disabled: '@disableImageLoad' 57 | }, 58 | { 59 | action: 'resizeImage', 60 | // Use "image" as prefix for the "@" options: 61 | prefix: 'image', 62 | maxWidth: '@', 63 | maxHeight: '@', 64 | minWidth: '@', 65 | minHeight: '@', 66 | crop: '@', 67 | orientation: '@', 68 | forceResize: '@', 69 | disabled: '@disableImageResize' 70 | }, 71 | { 72 | action: 'saveImage', 73 | quality: '@imageQuality', 74 | type: '@imageType', 75 | disabled: '@disableImageResize' 76 | }, 77 | { 78 | action: 'saveImageMetaData', 79 | disabled: '@disableImageMetaDataSave' 80 | }, 81 | { 82 | action: 'resizeImage', 83 | // Use "preview" as prefix for the "@" options: 84 | prefix: 'preview', 85 | maxWidth: '@', 86 | maxHeight: '@', 87 | minWidth: '@', 88 | minHeight: '@', 89 | crop: '@', 90 | orientation: '@', 91 | thumbnail: '@', 92 | canvas: '@', 93 | disabled: '@disableImagePreview' 94 | }, 95 | { 96 | action: 'setImage', 97 | name: '@imagePreviewName', 98 | disabled: '@disableImagePreview' 99 | }, 100 | { 101 | action: 'deleteImageReferences', 102 | disabled: '@disableImageReferencesDeletion' 103 | } 104 | ); 105 | 106 | // The File Upload Resize plugin extends the fileupload widget 107 | // with image resize functionality: 108 | $.widget('blueimp.fileupload', $.blueimp.fileupload, { 109 | 110 | options: { 111 | // The regular expression for the types of images to load: 112 | // matched against the file type: 113 | loadImageFileTypes: /^image\/(gif|jpeg|png|svg\+xml)$/, 114 | // The maximum file size of images to load: 115 | 
loadImageMaxFileSize: 10000000, // 10MB 116 | // The maximum width of resized images: 117 | imageMaxWidth: 1920, 118 | // The maximum height of resized images: 119 | imageMaxHeight: 1080, 120 | // Defines the image orientation (1-8) or takes the orientation 121 | // value from Exif data if set to true: 122 | imageOrientation: false, 123 | // Define if resized images should be cropped or only scaled: 124 | imageCrop: false, 125 | // Disable the resize image functionality by default: 126 | disableImageResize: true, 127 | // The maximum width of the preview images: 128 | previewMaxWidth: 80, 129 | // The maximum height of the preview images: 130 | previewMaxHeight: 80, 131 | // Defines the preview orientation (1-8) or takes the orientation 132 | // value from Exif data if set to true: 133 | previewOrientation: true, 134 | // Create the preview using the Exif data thumbnail: 135 | previewThumbnail: true, 136 | // Define if preview images should be cropped or only scaled: 137 | previewCrop: false, 138 | // Define if preview images should be resized as canvas elements: 139 | previewCanvas: true 140 | }, 141 | 142 | processActions: { 143 | 144 | // Loads the image given via data.files and data.index 145 | // as img element, if the browser supports the File API. 
146 | // Accepts the options fileTypes (regular expression) 147 | // and maxFileSize (integer) to limit the files to load: 148 | loadImage: function (data, options) { 149 | if (options.disabled) { 150 | return data; 151 | } 152 | var that = this, 153 | file = data.files[data.index], 154 | dfd = $.Deferred(); 155 | if (($.type(options.maxFileSize) === 'number' && 156 | file.size > options.maxFileSize) || 157 | (options.fileTypes && 158 | !options.fileTypes.test(file.type)) || 159 | !loadImage( 160 | file, 161 | function (img) { 162 | if (img.src) { 163 | data.img = img; 164 | } 165 | dfd.resolveWith(that, [data]); 166 | }, 167 | options 168 | )) { 169 | return data; 170 | } 171 | return dfd.promise(); 172 | }, 173 | 174 | // Resizes the image given as data.canvas or data.img 175 | // and updates data.canvas or data.img with the resized image. 176 | // Also stores the resized image as preview property. 177 | // Accepts the options maxWidth, maxHeight, minWidth, 178 | // minHeight, canvas and crop: 179 | resizeImage: function (data, options) { 180 | if (options.disabled || !(data.canvas || data.img)) { 181 | return data; 182 | } 183 | options = $.extend({canvas: true}, options); 184 | var that = this, 185 | dfd = $.Deferred(), 186 | img = (options.canvas && data.canvas) || data.img, 187 | resolve = function (newImg) { 188 | if (newImg && (newImg.width !== img.width || 189 | newImg.height !== img.height || 190 | options.forceResize)) { 191 | data[newImg.getContext ? 
'canvas' : 'img'] = newImg; 192 | } 193 | data.preview = newImg; 194 | dfd.resolveWith(that, [data]); 195 | }, 196 | thumbnail; 197 | if (data.exif) { 198 | if (options.orientation === true) { 199 | options.orientation = data.exif.get('Orientation'); 200 | } 201 | if (options.thumbnail) { 202 | thumbnail = data.exif.get('Thumbnail'); 203 | if (thumbnail) { 204 | loadImage(thumbnail, resolve, options); 205 | return dfd.promise(); 206 | } 207 | } 208 | // Prevent orienting the same image twice: 209 | if (data.orientation) { 210 | delete options.orientation; 211 | } else { 212 | data.orientation = options.orientation; 213 | } 214 | } 215 | if (img) { 216 | resolve(loadImage.scale(img, options)); 217 | return dfd.promise(); 218 | } 219 | return data; 220 | }, 221 | 222 | // Saves the processed image given as data.canvas 223 | // inplace at data.index of data.files: 224 | saveImage: function (data, options) { 225 | if (!data.canvas || options.disabled) { 226 | return data; 227 | } 228 | var that = this, 229 | file = data.files[data.index], 230 | dfd = $.Deferred(); 231 | if (data.canvas.toBlob) { 232 | data.canvas.toBlob( 233 | function (blob) { 234 | if (!blob.name) { 235 | if (file.type === blob.type) { 236 | blob.name = file.name; 237 | } else if (file.name) { 238 | blob.name = file.name.replace( 239 | /\..+$/, 240 | '.' 
+ blob.type.substr(6) 241 | ); 242 | } 243 | } 244 | // Don't restore invalid meta data: 245 | if (file.type !== blob.type) { 246 | delete data.imageHead; 247 | } 248 | // Store the created blob at the position 249 | // of the original file in the files list: 250 | data.files[data.index] = blob; 251 | dfd.resolveWith(that, [data]); 252 | }, 253 | options.type || file.type, 254 | options.quality 255 | ); 256 | } else { 257 | return data; 258 | } 259 | return dfd.promise(); 260 | }, 261 | 262 | loadImageMetaData: function (data, options) { 263 | if (options.disabled) { 264 | return data; 265 | } 266 | var that = this, 267 | dfd = $.Deferred(); 268 | loadImage.parseMetaData(data.files[data.index], function (result) { 269 | $.extend(data, result); 270 | dfd.resolveWith(that, [data]); 271 | }, options); 272 | return dfd.promise(); 273 | }, 274 | 275 | saveImageMetaData: function (data, options) { 276 | if (!(data.imageHead && data.canvas && 277 | data.canvas.toBlob && !options.disabled)) { 278 | return data; 279 | } 280 | var file = data.files[data.index], 281 | blob = new Blob([ 282 | data.imageHead, 283 | // Resized images always have a head size of 20 bytes, 284 | // including the JPEG marker and a minimal JFIF header: 285 | this._blobSlice.call(file, 20) 286 | ], {type: file.type}); 287 | blob.name = file.name; 288 | data.files[data.index] = blob; 289 | return data; 290 | }, 291 | 292 | // Sets the resized version of the image as a property of the 293 | // file object, must be called after "saveImage": 294 | setImage: function (data, options) { 295 | if (data.preview && !options.disabled) { 296 | data.files[data.index][options.name || 'preview'] = data.preview; 297 | } 298 | return data; 299 | }, 300 | 301 | deleteImageReferences: function (data, options) { 302 | if (!options.disabled) { 303 | delete data.img; 304 | delete data.canvas; 305 | delete data.preview; 306 | delete data.imageHead; 307 | } 308 | return data; 309 | } 310 | 311 | } 312 | 313 | }); 314 | 315 | 
})); 316 | -------------------------------------------------------------------------------- /UNF/modules/decoder/crf.py: -------------------------------------------------------------------------------- 1 | #codig:utf-8 2 | from __future__ import print_function 3 | import torch 4 | import torch.autograd as autograd 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | START_TAG = -2 8 | STOP_TAG = -1 9 | 10 | #ref from https://github.com/jiesutd/NCRFpp 11 | 12 | # Compute log sum exp in a numerically stable way for the forward algorithm 13 | def log_sum_exp(vec, m_size): 14 | """ 15 | calculate log of exp sum 16 | args: 17 | vec (batch_size, vanishing_dim, hidden_dim) : input tensor 18 | m_size : hidden_dim 19 | return: 20 | batch_size, hidden_dim 21 | """ 22 | _, idx = torch.max(vec, 1) # B * 1 * M 23 | max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M 24 | return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) # B * M 25 | 26 | class CRF(nn.Module): 27 | 28 | def __init__(self, tagset_size, gpu=None): 29 | super(CRF, self).__init__() 30 | print("build CRF...") 31 | self.gpu = gpu 32 | # Matrix of transition parameters. Entry i,j is the score of transitioning from i to j. 
33 | self.tagset_size = tagset_size 34 | # # We add 2 here, because of START_TAG and STOP_TAG 35 | # # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag 36 | init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2) 37 | init_transitions[:,START_TAG] = -10000.0 38 | init_transitions[STOP_TAG,:] = -10000.0 39 | init_transitions[:,0] = -10000.0 40 | init_transitions[0,:] = -10000.0 41 | if self.gpu: 42 | init_transitions = init_transitions.to(self.gpu) 43 | self.transitions = nn.Parameter(init_transitions) 44 | 45 | # self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2)) 46 | # self.transitions.data.zero_() 47 | 48 | def _calculate_PZ(self, feats, mask): 49 | """ 50 | input: 51 | feats: (batch, seq_len, self.tag_size+2) 52 | masks: (batch, seq_len) 53 | """ 54 | batch_size = feats.size(0) 55 | seq_len = feats.size(1) 56 | tag_size = feats.size(2) 57 | # print feats.view(seq_len, tag_size) 58 | assert(tag_size == self.tagset_size+2) 59 | mask = mask.transpose(1,0).contiguous() 60 | ins_num = seq_len * batch_size 61 | ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) 62 | feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size) 63 | ## need to consider start 64 | scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) 65 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 66 | # build iter 67 | seq_iter = enumerate(scores) 68 | _, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size 69 | # only need start from start_tag 70 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size 71 | 72 | ## add start score (from start to all tag, duplicate to batch_size) 73 | # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1) 74 | # iter over 
last scores 75 | for idx, cur_values in seq_iter: 76 | # previous to_target is current from_target 77 | # partition: previous results log(exp(from_target)), #(batch_size * from_target) 78 | # cur_values: bat_size * from_target * to_target 79 | 80 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 81 | cur_partition = log_sum_exp(cur_values, tag_size) 82 | # print cur_partition.data 83 | 84 | # (bat_size * from_target * to_target) -> (bat_size * to_target) 85 | # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1) 86 | mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size) 87 | 88 | ## effective updated partition part, only keep the partition value of mask value = 1 89 | masked_cur_partition = cur_partition.masked_select(mask_idx) 90 | ## let mask_idx broadcastable, to disable warning 91 | mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) 92 | 93 | ## replace the partition where the maskvalue=1, other partition value keeps the same 94 | partition.masked_scatter_(mask_idx, masked_cur_partition) 95 | # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG 96 | cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 97 | cur_partition = log_sum_exp(cur_values, tag_size) 98 | final_partition = cur_partition[:, STOP_TAG] 99 | return final_partition.sum(), scores 100 | 101 | 102 | def _viterbi_decode(self, feats, mask): 103 | """ 104 | input: 105 | feats: (batch, seq_len, self.tag_size+2) 106 | mask: (batch, seq_len) 107 | output: 108 | decode_idx: (batch, seq_len) decoded sequence 109 | path_score: (batch, 1) corresponding score for each sequence (to be implementated) 110 | """ 111 | batch_size = 
feats.size(0) 112 | seq_len = feats.size(1) 113 | tag_size = feats.size(2) 114 | assert(tag_size == self.tagset_size+2) 115 | ## calculate sentence length for each sentence 116 | length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long() 117 | ## mask to (seq_len, batch_size) 118 | mask = mask.transpose(1,0).contiguous() 119 | ins_num = seq_len * batch_size 120 | ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) 121 | feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 122 | ## need to consider start 123 | scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) 124 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 125 | 126 | # build iter 127 | seq_iter = enumerate(scores) 128 | ## record the position of best score 129 | back_points = list() 130 | partition_history = list() 131 | ## reverse mask (bug for mask = 1- mask, use this as alternative choice) 132 | # mask = 1 + (-1)*mask 133 | mask = (1 - mask.long()).byte() 134 | _, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size 135 | # only need start from start_tag 136 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size) # bat_size * to_target_size 137 | # print "init part:",partition.size() 138 | partition_history.append(partition) 139 | # iter over last scores 140 | for idx, cur_values in seq_iter: 141 | # previous to_target is current from_target 142 | # partition: previous results log(exp(from_target)), #(batch_size * from_target) 143 | # cur_values: batch_size * from_target * to_target 144 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 145 | ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG 146 | # print "cur value:", cur_values.size() 147 | partition, cur_bp = torch.max(cur_values, 
1) 148 | # print "partsize:",partition.size() 149 | # exit(0) 150 | # print partition 151 | # print cur_bp 152 | # print "one best, ",idx 153 | partition_history.append(partition) 154 | ## cur_bp: (batch_size, tag_size) max source score position in current tag 155 | ## set padded label as 0, which will be filtered in post processing 156 | cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) 157 | back_points.append(cur_bp) 158 | # exit(0) 159 | ### add score to final STOP_TAG 160 | partition_history = torch.cat(partition_history, 0).view(seq_len, batch_size, -1).transpose(1,0).contiguous() ## (batch_size, seq_len. tag_size) 161 | ### get the last position for each setences, and select the last partitions using gather() 162 | last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1 163 | last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1) 164 | ### calculate the score from last partition to end state (and then select the STOP_TAG from it) 165 | last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) 166 | _, last_bp = torch.max(last_values, 1) 167 | pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long() 168 | if self.gpu: 169 | pad_zero = pad_zero.to(self.gpu) 170 | back_points.append(pad_zero) 171 | back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) 172 | 173 | ## select end ids in STOP_TAG 174 | pointer = last_bp[:, STOP_TAG] 175 | insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size) 176 | back_points = back_points.transpose(1,0).contiguous() 177 | ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values 178 | # print "lp:",last_position 179 | # print "il:",insert_last 180 | back_points.scatter_(1, last_position, insert_last) 181 | # print 
"bp:",back_points 182 | # exit(0) 183 | back_points = back_points.transpose(1,0).contiguous() 184 | ## decode from the end, padded position ids are 0, which will be filtered if following evaluation 185 | decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size)) 186 | if self.gpu: 187 | decode_idx = decode_idx.to(self.gpu) 188 | decode_idx[-1] = pointer.detach() 189 | for idx in range(len(back_points)-2, -1, -1): 190 | pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) 191 | decode_idx[idx] = pointer.detach().view(batch_size) 192 | path_score = None 193 | decode_idx = decode_idx.transpose(1,0) 194 | return path_score, decode_idx 195 | 196 | 197 | def forward(self, feats): 198 | path_score, best_path = self._viterbi_decode(feats) 199 | return path_score, best_path 200 | 201 | 202 | def _score_sentence(self, scores, mask, tags): 203 | """ 204 | input: 205 | scores: variable (seq_len, batch, tag_size, tag_size) 206 | mask: (batch, seq_len) 207 | tags: tensor (batch, seq_len) 208 | output: 209 | score: sum of score for gold sequences within whole batch 210 | """ 211 | # Gives the score of a provided tag sequence 212 | batch_size = scores.size(1) 213 | seq_len = scores.size(0) 214 | tag_size = scores.size(2) 215 | ## convert tag value into a new format, recorded label bigram information to index 216 | new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len)) 217 | if self.gpu: 218 | new_tags = new_tags.to(self.gpu) 219 | for idx in range(seq_len): 220 | if idx == 0: 221 | ## start -> first score 222 | new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0] 223 | 224 | else: 225 | new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx] 226 | 227 | ## transition for label to STOP_TAG 228 | end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size) 229 | ## length for batch, last word position = length - 1 230 | length_mask = torch.sum(mask.long(), dim = 
1).view(batch_size,1).long() 231 | ## index the label id of last word 232 | end_ids = torch.gather(tags, 1, length_mask - 1) 233 | 234 | ## index the transition score for end_id to STOP_TAG 235 | end_energy = torch.gather(end_transition, 1, end_ids) 236 | 237 | ## convert tag as (seq_len, batch_size, 1) 238 | new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1) 239 | ### need convert tags id to search from 400 positions of scores 240 | tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) # seq_len * bat_size 241 | ## mask transpose to (seq_len, batch_size) 242 | tg_energy = tg_energy.masked_select(mask.transpose(1,0)) 243 | 244 | # ## calculate the score from START_TAG to first label 245 | # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size) 246 | # start_energy = torch.gather(start_transition, 1, tags[0,:]) 247 | 248 | ## add all score together 249 | # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum() 250 | gold_score = tg_energy.sum() + end_energy.sum() 251 | return gold_score 252 | 253 | def neg_log_likelihood_loss(self, feats, mask, tags): 254 | # nonegative log likelihood 255 | batch_size = feats.size(0) 256 | forward_score, scores = self._calculate_PZ(feats, mask) 257 | gold_score = self._score_sentence(scores, mask, tags) 258 | # print "batch, f:", forward_score.data[0], " g:", gold_score.data[0], " dis:", forward_score.data[0] - gold_score.data[0] 259 | # exit(0) 260 | return forward_score - gold_score 261 | 262 | 263 | 264 | -------------------------------------------------------------------------------- /UNF/web_server/static/js/vendor/jquery.ui.widget.js: -------------------------------------------------------------------------------- 1 | /*! 
2 | * jQuery UI Widget 1.10.4+amd 3 | * https://github.com/blueimp/jQuery-File-Upload 4 | * 5 | * Copyright 2014 jQuery Foundation and other contributors 6 | * Released under the MIT license. 7 | * http://jquery.org/license 8 | * 9 | * http://api.jqueryui.com/jQuery.widget/ 10 | */ 11 | 12 | (function (factory) { 13 | if (typeof define === "function" && define.amd) { 14 | // Register as an anonymous AMD module: 15 | define(["jquery"], factory); 16 | } else { 17 | // Browser globals: 18 | factory(jQuery); 19 | } 20 | }(function( $, undefined ) { 21 | 22 | var uuid = 0, 23 | slice = Array.prototype.slice, 24 | _cleanData = $.cleanData; 25 | $.cleanData = function( elems ) { 26 | for ( var i = 0, elem; (elem = elems[i]) != null; i++ ) { 27 | try { 28 | $( elem ).triggerHandler( "remove" ); 29 | // http://bugs.jquery.com/ticket/8235 30 | } catch( e ) {} 31 | } 32 | _cleanData( elems ); 33 | }; 34 | 35 | $.widget = function( name, base, prototype ) { 36 | var fullName, existingConstructor, constructor, basePrototype, 37 | // proxiedPrototype allows the provided prototype to remain unmodified 38 | // so that it can be used as a mixin for multiple widgets (#8876) 39 | proxiedPrototype = {}, 40 | namespace = name.split( "." )[ 0 ]; 41 | 42 | name = name.split( "." 
)[ 1 ]; 43 | fullName = namespace + "-" + name; 44 | 45 | if ( !prototype ) { 46 | prototype = base; 47 | base = $.Widget; 48 | } 49 | 50 | // create selector for plugin 51 | $.expr[ ":" ][ fullName.toLowerCase() ] = function( elem ) { 52 | return !!$.data( elem, fullName ); 53 | }; 54 | 55 | $[ namespace ] = $[ namespace ] || {}; 56 | existingConstructor = $[ namespace ][ name ]; 57 | constructor = $[ namespace ][ name ] = function( options, element ) { 58 | // allow instantiation without "new" keyword 59 | if ( !this._createWidget ) { 60 | return new constructor( options, element ); 61 | } 62 | 63 | // allow instantiation without initializing for simple inheritance 64 | // must use "new" keyword (the code above always passes args) 65 | if ( arguments.length ) { 66 | this._createWidget( options, element ); 67 | } 68 | }; 69 | // extend with the existing constructor to carry over any static properties 70 | $.extend( constructor, existingConstructor, { 71 | version: prototype.version, 72 | // copy the object used to create the prototype in case we need to 73 | // redefine the widget later 74 | _proto: $.extend( {}, prototype ), 75 | // track widgets that inherit from this widget in case this widget is 76 | // redefined after a widget inherits from it 77 | _childConstructors: [] 78 | }); 79 | 80 | basePrototype = new base(); 81 | // we need to make the options hash a property directly on the new instance 82 | // otherwise we'll modify the options hash on the prototype that we're 83 | // inheriting from 84 | basePrototype.options = $.widget.extend( {}, basePrototype.options ); 85 | $.each( prototype, function( prop, value ) { 86 | if ( !$.isFunction( value ) ) { 87 | proxiedPrototype[ prop ] = value; 88 | return; 89 | } 90 | proxiedPrototype[ prop ] = (function() { 91 | var _super = function() { 92 | return base.prototype[ prop ].apply( this, arguments ); 93 | }, 94 | _superApply = function( args ) { 95 | return base.prototype[ prop ].apply( this, args ); 96 | }; 97 | 
return function() { 98 | var __super = this._super, 99 | __superApply = this._superApply, 100 | returnValue; 101 | 102 | this._super = _super; 103 | this._superApply = _superApply; 104 | 105 | returnValue = value.apply( this, arguments ); 106 | 107 | this._super = __super; 108 | this._superApply = __superApply; 109 | 110 | return returnValue; 111 | }; 112 | })(); 113 | }); 114 | constructor.prototype = $.widget.extend( basePrototype, { 115 | // TODO: remove support for widgetEventPrefix 116 | // always use the name + a colon as the prefix, e.g., draggable:start 117 | // don't prefix for widgets that aren't DOM-based 118 | widgetEventPrefix: existingConstructor ? (basePrototype.widgetEventPrefix || name) : name 119 | }, proxiedPrototype, { 120 | constructor: constructor, 121 | namespace: namespace, 122 | widgetName: name, 123 | widgetFullName: fullName 124 | }); 125 | 126 | // If this widget is being redefined then we need to find all widgets that 127 | // are inheriting from it and redefine all of them so that they inherit from 128 | // the new version of this widget. We're essentially trying to replace one 129 | // level in the prototype chain. 130 | if ( existingConstructor ) { 131 | $.each( existingConstructor._childConstructors, function( i, child ) { 132 | var childPrototype = child.prototype; 133 | 134 | // redefine the child widget using the same prototype that was 135 | // originally used, but inherit from the new version of the base 136 | $.widget( childPrototype.namespace + "." 
+ childPrototype.widgetName, constructor, child._proto ); 137 | }); 138 | // remove the list of existing child constructors from the old constructor 139 | // so the old child constructors can be garbage collected 140 | delete existingConstructor._childConstructors; 141 | } else { 142 | base._childConstructors.push( constructor ); 143 | } 144 | 145 | $.widget.bridge( name, constructor ); 146 | }; 147 | 148 | $.widget.extend = function( target ) { 149 | var input = slice.call( arguments, 1 ), 150 | inputIndex = 0, 151 | inputLength = input.length, 152 | key, 153 | value; 154 | for ( ; inputIndex < inputLength; inputIndex++ ) { 155 | for ( key in input[ inputIndex ] ) { 156 | value = input[ inputIndex ][ key ]; 157 | if ( input[ inputIndex ].hasOwnProperty( key ) && value !== undefined ) { 158 | // Clone objects 159 | if ( $.isPlainObject( value ) ) { 160 | target[ key ] = $.isPlainObject( target[ key ] ) ? 161 | $.widget.extend( {}, target[ key ], value ) : 162 | // Don't extend strings, arrays, etc. with objects 163 | $.widget.extend( {}, value ); 164 | // Copy everything else by reference 165 | } else { 166 | target[ key ] = value; 167 | } 168 | } 169 | } 170 | } 171 | return target; 172 | }; 173 | 174 | $.widget.bridge = function( name, object ) { 175 | var fullName = object.prototype.widgetFullName || name; 176 | $.fn[ name ] = function( options ) { 177 | var isMethodCall = typeof options === "string", 178 | args = slice.call( arguments, 1 ), 179 | returnValue = this; 180 | 181 | // allow multiple hashes to be passed on init 182 | options = !isMethodCall && args.length ? 
183 | $.widget.extend.apply( null, [ options ].concat(args) ) : 184 | options; 185 | 186 | if ( isMethodCall ) { 187 | this.each(function() { 188 | var methodValue, 189 | instance = $.data( this, fullName ); 190 | if ( !instance ) { 191 | return $.error( "cannot call methods on " + name + " prior to initialization; " + 192 | "attempted to call method '" + options + "'" ); 193 | } 194 | if ( !$.isFunction( instance[options] ) || options.charAt( 0 ) === "_" ) { 195 | return $.error( "no such method '" + options + "' for " + name + " widget instance" ); 196 | } 197 | methodValue = instance[ options ].apply( instance, args ); 198 | if ( methodValue !== instance && methodValue !== undefined ) { 199 | returnValue = methodValue && methodValue.jquery ? 200 | returnValue.pushStack( methodValue.get() ) : 201 | methodValue; 202 | return false; 203 | } 204 | }); 205 | } else { 206 | this.each(function() { 207 | var instance = $.data( this, fullName ); 208 | if ( instance ) { 209 | instance.option( options || {} )._init(); 210 | } else { 211 | $.data( this, fullName, new object( options, this ) ); 212 | } 213 | }); 214 | } 215 | 216 | return returnValue; 217 | }; 218 | }; 219 | 220 | $.Widget = function( /* options, element */ ) {}; 221 | $.Widget._childConstructors = []; 222 | 223 | $.Widget.prototype = { 224 | widgetName: "widget", 225 | widgetEventPrefix: "", 226 | defaultElement: "
", 227 | options: { 228 | disabled: false, 229 | 230 | // callbacks 231 | create: null 232 | }, 233 | _createWidget: function( options, element ) { 234 | element = $( element || this.defaultElement || this )[ 0 ]; 235 | this.element = $( element ); 236 | this.uuid = uuid++; 237 | this.eventNamespace = "." + this.widgetName + this.uuid; 238 | this.options = $.widget.extend( {}, 239 | this.options, 240 | this._getCreateOptions(), 241 | options ); 242 | 243 | this.bindings = $(); 244 | this.hoverable = $(); 245 | this.focusable = $(); 246 | 247 | if ( element !== this ) { 248 | $.data( element, this.widgetFullName, this ); 249 | this._on( true, this.element, { 250 | remove: function( event ) { 251 | if ( event.target === element ) { 252 | this.destroy(); 253 | } 254 | } 255 | }); 256 | this.document = $( element.style ? 257 | // element within the document 258 | element.ownerDocument : 259 | // element is window or document 260 | element.document || element ); 261 | this.window = $( this.document[0].defaultView || this.document[0].parentWindow ); 262 | } 263 | 264 | this._create(); 265 | this._trigger( "create", null, this._getCreateEventData() ); 266 | this._init(); 267 | }, 268 | _getCreateOptions: $.noop, 269 | _getCreateEventData: $.noop, 270 | _create: $.noop, 271 | _init: $.noop, 272 | 273 | destroy: function() { 274 | this._destroy(); 275 | // we can probably remove the unbind calls in 2.0 276 | // all event bindings should go through this._on() 277 | this.element 278 | .unbind( this.eventNamespace ) 279 | // 1.9 BC for #7810 280 | // TODO remove dual storage 281 | .removeData( this.widgetName ) 282 | .removeData( this.widgetFullName ) 283 | // support: jquery <1.6.3 284 | // http://bugs.jquery.com/ticket/9413 285 | .removeData( $.camelCase( this.widgetFullName ) ); 286 | this.widget() 287 | .unbind( this.eventNamespace ) 288 | .removeAttr( "aria-disabled" ) 289 | .removeClass( 290 | this.widgetFullName + "-disabled " + 291 | "ui-state-disabled" ); 292 | 293 | 
// clean up events and states 294 | this.bindings.unbind( this.eventNamespace ); 295 | this.hoverable.removeClass( "ui-state-hover" ); 296 | this.focusable.removeClass( "ui-state-focus" ); 297 | }, 298 | _destroy: $.noop, 299 | 300 | widget: function() { 301 | return this.element; 302 | }, 303 | 304 | option: function( key, value ) { 305 | var options = key, 306 | parts, 307 | curOption, 308 | i; 309 | 310 | if ( arguments.length === 0 ) { 311 | // don't return a reference to the internal hash 312 | return $.widget.extend( {}, this.options ); 313 | } 314 | 315 | if ( typeof key === "string" ) { 316 | // handle nested keys, e.g., "foo.bar" => { foo: { bar: ___ } } 317 | options = {}; 318 | parts = key.split( "." ); 319 | key = parts.shift(); 320 | if ( parts.length ) { 321 | curOption = options[ key ] = $.widget.extend( {}, this.options[ key ] ); 322 | for ( i = 0; i < parts.length - 1; i++ ) { 323 | curOption[ parts[ i ] ] = curOption[ parts[ i ] ] || {}; 324 | curOption = curOption[ parts[ i ] ]; 325 | } 326 | key = parts.pop(); 327 | if ( arguments.length === 1 ) { 328 | return curOption[ key ] === undefined ? null : curOption[ key ]; 329 | } 330 | curOption[ key ] = value; 331 | } else { 332 | if ( arguments.length === 1 ) { 333 | return this.options[ key ] === undefined ? 
null : this.options[ key ]; 334 | } 335 | options[ key ] = value; 336 | } 337 | } 338 | 339 | this._setOptions( options ); 340 | 341 | return this; 342 | }, 343 | _setOptions: function( options ) { 344 | var key; 345 | 346 | for ( key in options ) { 347 | this._setOption( key, options[ key ] ); 348 | } 349 | 350 | return this; 351 | }, 352 | _setOption: function( key, value ) { 353 | this.options[ key ] = value; 354 | 355 | if ( key === "disabled" ) { 356 | this.widget() 357 | .toggleClass( this.widgetFullName + "-disabled ui-state-disabled", !!value ) 358 | .attr( "aria-disabled", value ); 359 | this.hoverable.removeClass( "ui-state-hover" ); 360 | this.focusable.removeClass( "ui-state-focus" ); 361 | } 362 | 363 | return this; 364 | }, 365 | 366 | enable: function() { 367 | return this._setOption( "disabled", false ); 368 | }, 369 | disable: function() { 370 | return this._setOption( "disabled", true ); 371 | }, 372 | 373 | _on: function( suppressDisabledCheck, element, handlers ) { 374 | var delegateElement, 375 | instance = this; 376 | 377 | // no suppressDisabledCheck flag, shuffle arguments 378 | if ( typeof suppressDisabledCheck !== "boolean" ) { 379 | handlers = element; 380 | element = suppressDisabledCheck; 381 | suppressDisabledCheck = false; 382 | } 383 | 384 | // no element argument, shuffle and use this.element 385 | if ( !handlers ) { 386 | handlers = element; 387 | element = this.element; 388 | delegateElement = this.widget(); 389 | } else { 390 | // accept selectors, DOM elements 391 | element = delegateElement = $( element ); 392 | this.bindings = this.bindings.add( element ); 393 | } 394 | 395 | $.each( handlers, function( event, handler ) { 396 | function handlerProxy() { 397 | // allow widgets to customize the disabled handling 398 | // - disabled as an array instead of boolean 399 | // - disabled class as method for disabling individual parts 400 | if ( !suppressDisabledCheck && 401 | ( instance.options.disabled === true || 402 | $( this 
).hasClass( "ui-state-disabled" ) ) ) { 403 | return; 404 | } 405 | return ( typeof handler === "string" ? instance[ handler ] : handler ) 406 | .apply( instance, arguments ); 407 | } 408 | 409 | // copy the guid so direct unbinding works 410 | if ( typeof handler !== "string" ) { 411 | handlerProxy.guid = handler.guid = 412 | handler.guid || handlerProxy.guid || $.guid++; 413 | } 414 | 415 | var match = event.match( /^(\w+)\s*(.*)$/ ), 416 | eventName = match[1] + instance.eventNamespace, 417 | selector = match[2]; 418 | if ( selector ) { 419 | delegateElement.delegate( selector, eventName, handlerProxy ); 420 | } else { 421 | element.bind( eventName, handlerProxy ); 422 | } 423 | }); 424 | }, 425 | 426 | _off: function( element, eventName ) { 427 | eventName = (eventName || "").split( " " ).join( this.eventNamespace + " " ) + this.eventNamespace; 428 | element.unbind( eventName ).undelegate( eventName ); 429 | }, 430 | 431 | _delay: function( handler, delay ) { 432 | function handlerProxy() { 433 | return ( typeof handler === "string" ? 
instance[ handler ] : handler ) 434 | .apply( instance, arguments ); 435 | } 436 | var instance = this; 437 | return setTimeout( handlerProxy, delay || 0 ); 438 | }, 439 | 440 | _hoverable: function( element ) { 441 | this.hoverable = this.hoverable.add( element ); 442 | this._on( element, { 443 | mouseenter: function( event ) { 444 | $( event.currentTarget ).addClass( "ui-state-hover" ); 445 | }, 446 | mouseleave: function( event ) { 447 | $( event.currentTarget ).removeClass( "ui-state-hover" ); 448 | } 449 | }); 450 | }, 451 | 452 | _focusable: function( element ) { 453 | this.focusable = this.focusable.add( element ); 454 | this._on( element, { 455 | focusin: function( event ) { 456 | $( event.currentTarget ).addClass( "ui-state-focus" ); 457 | }, 458 | focusout: function( event ) { 459 | $( event.currentTarget ).removeClass( "ui-state-focus" ); 460 | } 461 | }); 462 | }, 463 | 464 | _trigger: function( type, event, data ) { 465 | var prop, orig, 466 | callback = this.options[ type ]; 467 | 468 | data = data || {}; 469 | event = $.Event( event ); 470 | event.type = ( type === this.widgetEventPrefix ? 
471 | type : 472 | this.widgetEventPrefix + type ).toLowerCase(); 473 | // the original event may come from any element 474 | // so we need to reset the target on the new event 475 | event.target = this.element[ 0 ]; 476 | 477 | // copy original event properties over to the new event 478 | orig = event.originalEvent; 479 | if ( orig ) { 480 | for ( prop in orig ) { 481 | if ( !( prop in event ) ) { 482 | event[ prop ] = orig[ prop ]; 483 | } 484 | } 485 | } 486 | 487 | this.element.trigger( event, data ); 488 | return !( $.isFunction( callback ) && 489 | callback.apply( this.element[0], [ event ].concat( data ) ) === false || 490 | event.isDefaultPrevented() ); 491 | } 492 | }; 493 | 494 | $.each( { show: "fadeIn", hide: "fadeOut" }, function( method, defaultEffect ) { 495 | $.Widget.prototype[ "_" + method ] = function( element, options, callback ) { 496 | if ( typeof options === "string" ) { 497 | options = { effect: options }; 498 | } 499 | var hasOptions, 500 | effectName = !options ? 501 | method : 502 | options === true || typeof options === "number" ? 503 | defaultEffect : 504 | options.effect || defaultEffect; 505 | options = options || {}; 506 | if ( typeof options === "number" ) { 507 | options = { duration: options }; 508 | } 509 | hasOptions = !$.isEmptyObject( options ); 510 | options.complete = callback; 511 | if ( options.delay ) { 512 | element.delay( options.delay ); 513 | } 514 | if ( hasOptions && $.effects && $.effects.effect[ effectName ] ) { 515 | element[ method ]( options ); 516 | } else if ( effectName !== method && element[ effectName ] ) { 517 | element[ effectName ]( options.duration, options.easing, callback ); 518 | } else { 519 | element.queue(function( next ) { 520 | $( this )[ method ](); 521 | if ( callback ) { 522 | callback.call( element[ 0 ] ); 523 | } 524 | next(); 525 | }); 526 | } 527 | }; 528 | }); 529 | 530 | })); 531 | --------------------------------------------------------------------------------