├── .gitignore ├── LICENSE ├── README.md ├── corenlp.py ├── data_preprocessing ├── README.md ├── data_preprocessing.py ├── getdata.sh ├── install_BNP.py ├── langconv.py ├── t2s.py └── zh_wiki.py ├── get_syninfo.py ├── models └── README.md ├── pytorch_pretrained_bert ├── __init__.py ├── __main__.py ├── convert_gpt2_checkpoint_to_pytorch.py ├── convert_openai_checkpoint_to_pytorch.py ├── convert_tf_checkpoint_to_pytorch.py ├── convert_transfo_xl_checkpoint_to_pytorch.py ├── crf.py ├── crf2.py ├── file_utils.py ├── modeling.py ├── modeling_convert.py ├── modeling_gpt2.py ├── modeling_openai.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── modeling_word_boundary_crf.py ├── optimization.py ├── optimization_openai.py ├── tokenization.py ├── tokenization_chinese_word.py ├── tokenization_gpt2.py ├── tokenization_openai.py └── tokenization_transfo_xl.py ├── pytorch_pretrained_zen ├── __init__.py ├── file_utils.py ├── modeling.py ├── ngram_utils.py ├── optimization.py └── tokenization.py ├── requirements.txt ├── run.sh ├── run_sample.sh ├── sample_data ├── dev.stanford.json ├── dev.tsv ├── label2id ├── sentence.txt ├── test.stanford.json ├── test.tsv ├── train.stanford.json └── train.tsv ├── twasp_eval.py ├── twasp_helper.py ├── twasp_main.py ├── twasp_model.py └── updates.md /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /data_preprocessing/ud-treebanks-v2.4 3 | /data_preprocessing/LDC2016T13 4 | /data_preprocessing/LDC10T07 5 | /data_preprocessing/LDC07T36 6 | /data_preprocessing/LDC05T01 7 | /data_preprocessing/LDC05T01/ 8 | /data_preprocessing/LDC07T36/ 9 | /data_preprocessing/LDC10T07/ 10 | /tmp/ 11 | twasp.convert.sh 12 | twasp_convert.py 13 | test.sh 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 SVAIGBA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TwASP 2 | 3 | This is the implementation of [Joint Chinese Word Segmentation and Part-of-speech Tagging via Two-way Attentions of Auto-analyzed Knowledge](https://www.aclweb.org/anthology/2020.acl-main.735/) at ACL2020. 
4 | 5 | Please contact us at `yhtian@uw.edu` if you have any questions. 6 | 7 | **Visit our [homepage](https://github.com/synlp/.github) to find more of our recent research and software for NLP (e.g., pre-trained LM, POS tagging, NER, sentiment analysis, relation extraction, datasets, etc.).** 8 | 9 | ## Upgrades of TwASP 10 | 11 | We are improving TwASP. For updates, please visit [HERE](https://github.com/synlp/TwASP). 12 | 13 | ## Citation 14 | 15 | If you use or extend our work, please cite our paper at ACL2020. 16 | 17 | ``` 18 | @inproceedings{tian-etal-2020-joint, 19 | title = "Joint Chinese Word Segmentation and Part-of-speech Tagging via Two-way Attentions of Auto-analyzed Knowledge", 20 | author = "Tian, Yuanhe and Song, Yan and Ao, Xiang and Xia, Fei and Quan, Xiaojun and Zhang, Tong and Wang, Yonggang", 21 | booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics", 22 | month = jul, 23 | year = "2020", 24 | address = "Online", 25 | pages = "8286--8296", 26 | } 27 | ``` 28 | 29 | ## Requirements 30 | 31 | Our code works with the following environment. 32 | * `python=3.6` 33 | * `pytorch=1.1` 34 | 35 | To run [Stanford CoreNLP Toolkit](https://stanfordnlp.github.io/CoreNLP/cmdline.html), you need 36 | * `Java 8` 37 | 38 | To run [Berkeley Neural Parser](https://github.com/nikitakit/self-attentive-parser), you need 39 | * `tensorflow==1.13.1` 40 | * `benepar[cpu]` 41 | * `cython` 42 | 43 | Note that Berkeley Neural Parser does not support `TensorFlow 2.0`. 44 | 45 | You can refer to their websites for more information. 46 | 47 | ## Downloading BERT, ZEN and TwASP 48 | 49 | In our paper, we use BERT ([paper](https://www.aclweb.org/anthology/N19-1423/)) and ZEN ([paper](https://arxiv.org/abs/1911.00720)) as the encoder. 50 | 51 | For BERT, please download pre-trained BERT-Base Chinese from [Google](https://github.com/google-research/bert) or from [HuggingFace](https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz). If you download it from Google, you need to convert the model from its TensorFlow version to the PyTorch version. 52 | 53 | For ZEN, you can download the pre-trained model from [here](https://github.com/sinovation/ZEN). 54 | 55 | For TwASP, you can download the models we trained in our experiments from [here](https://github.com/SVAIGBA/TwASP/tree/master/models). **Note**: since we have improved the code, to reproduce the results in our paper, please check out [this version](https://github.com/SVAIGBA/TwASP/tree/cc7a4dfc2a78db98d7414755babd82c6ca98fba0) of our code. 56 | 57 | ## Run on Sample Data 58 | 59 | Run `run_sample.sh` to train a model on the small sample data under the `sample_data` folder. 60 | 61 | ## Datasets 62 | 63 | We use [CTB5](https://catalog.ldc.upenn.edu/LDC2005T01), [CTB6](https://catalog.ldc.upenn.edu/LDC2007T36), [CTB7](https://catalog.ldc.upenn.edu/LDC2010T07), [CTB9](https://catalog.ldc.upenn.edu/LDC2016T13), and [Universal Dependencies 2.4](https://lindat.mff.cuni.cz/repository/xmlui/handle/11234/1-2988) (UD) in our paper. 64 | 65 | To obtain and pre-process the data, you can go to the `data_preprocessing` directory and run `getdata.sh`. This script will download and process the official data from UD. For CTB5 (LDC05T01), CTB6 (LDC07T36), CTB7 (LDC10T07), and CTB9 (LDC2016T13), you need to obtain the official data yourself, and then put the raw data folders under the `data_preprocessing` directory.
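As a concrete sketch of this workflow (the LDC folder names are those listed in `data_preprocessing/README.md`; everything else about your local setup is a placeholder):

```bash
# pre-process the data (sketch): run from the repository root
cd data_preprocessing
# for the CTB datasets, first copy the raw LDC releases here, e.g.
#   LDC05T01/ (CTB5), LDC07T36/ (CTB6), LDC10T07/ (CTB7), LDC2016T13/ (CTB9)
sh getdata.sh    # downloads UD 2.4, SCT, and BNP, then processes all datasets
```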
66 | 67 | The script will also download the [Stanford CoreNLP Toolkit v3.9.2](https://stanfordnlp.github.io/CoreNLP/history.html) (SCT) and [Berkeley Neural Parser](https://github.com/nikitakit/self-attentive-parser) (BNP) to obtain the auto-analyzed syntactic knowledge. You can refer to their websites for more information. 68 | 69 | All processed data will appear in the `data` directory, organized by dataset; each dataset folder contains files with the same names as those under the `sample_data` directory. 70 | 71 | ## Training and Testing 72 | 73 | You can find the command lines to train and test a model on a specific dataset with the part-of-speech (POS) knowledge from Stanford CoreNLP Toolkit v3.9.2 (SCT) in `run.sh`. 74 | 75 | Here are some important parameters: 76 | 77 | * `--do_train`: train the model 78 | * `--do_test`: test the model 79 | * `--use_bert`: use BERT as the encoder 80 | * `--use_zen`: use ZEN as the encoder 81 | * `--bert_model`: the directory of the pre-trained BERT/ZEN model 82 | * `--use_attention`: use two-way attention 83 | * `--source`: the toolkit to be used (`stanford` or `berkeley`) 84 | * `--feature_flag`: use `pos`, `chunk`, or `dep` knowledge 85 | * `--model_name`: the name of the model to save 86 | 87 | ## Predicting 88 | 89 | `run_sample.sh` contains the command line to segment and tag the sentences in an input file ([./sample_data/sentence.txt](./sample_data/sentence.txt)). 90 | 91 | Here are some important parameters: 92 | 93 | * `--do_predict`: segment and tag the sentences using a pre-trained TwASP model. 94 | * `--input_file`: the file containing the sentences to be segmented and tagged. Each line contains one sentence; you can refer to [a sample input file](./sample_data/sentence.txt) for the input format. 95 | * `--output_file`: the path of the output file. Words are separated by spaces; the POS label is attached to each resulting word by an underscore ("_"). 96 | * `--eval_model`: the pre-trained TwASP model to be used to segment and tag the sentences in the input file. 97 | 98 | To run a pre-trained TwASP model, you need to install SCT and BNP to obtain the auto-analyzed syntactic knowledge. See [data_preprocessing](./data_preprocessing) for more information on downloading the two toolkits. 99 | 100 | ## To-do List 101 | 102 | * Regular maintenance 103 | 104 | You can leave comments in the `Issues` section if you want us to implement any new functions. 105 | 106 | You can check our updates at [updates.md](./updates.md).
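## Example Commands

The following sketch shows how the flags documented above are typically combined, assuming `twasp_main.py` is the entry point invoked by `run.sh` and `run_sample.sh`; paths and model names are placeholders, and the actual scripts may pass additional arguments (e.g., data directories and training hyper-parameters) that are not shown here.

```bash
# train with BERT as the encoder and POS knowledge from Stanford CoreNLP (sketch; see run.sh)
python twasp_main.py --do_train --use_bert --use_attention \
    --bert_model=/path/to/bert-base-chinese \
    --source=stanford --feature_flag=pos \
    --model_name=twasp_bert_sct_pos

# segment and tag raw sentences with a trained model (sketch; see run_sample.sh)
python twasp_main.py --do_predict \
    --input_file=./sample_data/sentence.txt \
    --output_file=./sample_data/sentence.out.txt \
    --eval_model=/path/to/trained/TwASP/model
```

Each line of the output file then contains one segmented sentence, with words separated by spaces and the POS label attached to each word by an underscore (e.g., `银行_NN`).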
107 | 108 | -------------------------------------------------------------------------------- /corenlp.py: -------------------------------------------------------------------------------- 1 | # _*_coding:utf-8_*_ 2 | from __future__ import print_function 3 | 4 | import glob 5 | import json 6 | import logging 7 | import os 8 | import re 9 | import socket 10 | import subprocess 11 | import sys 12 | import time 13 | 14 | import psutil 15 | 16 | try: 17 | from urlparse import urlparse 18 | except ImportError: 19 | from urllib.parse import urlparse 20 | 21 | import requests 22 | 23 | 24 | class StanfordCoreNLP: 25 | def __init__(self, path_or_host, port=None, memory='4g', lang='en', timeout=1500, quiet=True, 26 | logging_level=logging.WARNING, max_retries=5): 27 | self.path_or_host = path_or_host 28 | self.port = port 29 | self.memory = memory 30 | self.lang = lang 31 | self.timeout = timeout 32 | self.quiet = quiet 33 | self.logging_level = logging_level 34 | 35 | logging.basicConfig(level=self.logging_level) 36 | 37 | # Check args 38 | self._check_args() 39 | 40 | if path_or_host.startswith('http'): 41 | self.url = path_or_host + ':' + str(port) 42 | logging.info('Using an existing server {}'.format(self.url)) 43 | else: 44 | 45 | # Check Java 46 | if not subprocess.call(['java', '-version'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) == 0: 47 | raise RuntimeError('Java not found.') 48 | 49 | # Check if the dir exists 50 | if not os.path.isdir(self.path_or_host): 51 | raise IOError(str(self.path_or_host) + ' is not a directory.') 52 | directory = os.path.normpath(self.path_or_host) + os.sep 53 | self.class_path_dir = directory 54 | 55 | # Check if the language specific model file exists 56 | switcher = { 57 | 'en': 'stanford-corenlp-[0-9].[0-9].[0-9]-models.jar', 58 | 'zh': 'stanford-chinese-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar', 59 | 'ar': 'stanford-arabic-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar', 60 | 'fr': 'stanford-french-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar', 61 | 'de': 'stanford-german-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar', 62 | 'es': 'stanford-spanish-corenlp-[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]-models.jar' 63 | } 64 | jars = { 65 | 'en': 'stanford-corenlp-x.x.x-models.jar', 66 | 'zh': 'stanford-chinese-corenlp-yyyy-MM-dd-models.jar', 67 | 'ar': 'stanford-arabic-corenlp-yyyy-MM-dd-models.jar', 68 | 'fr': 'stanford-french-corenlp-yyyy-MM-dd-models.jar', 69 | 'de': 'stanford-german-corenlp-yyyy-MM-dd-models.jar', 70 | 'es': 'stanford-spanish-corenlp-yyyy-MM-dd-models.jar' 71 | } 72 | if len(glob.glob(directory + switcher.get(self.lang))) <= 0: 73 | raise IOError(jars.get( 74 | self.lang) + ' not exists. 
You should download and place it in the ' + directory + ' first.') 75 | 76 | # If port not set, auto select 77 | if self.port is None: 78 | for port_candidate in range(9000, 65535): 79 | if port_candidate not in [conn.laddr[1] for conn in psutil.net_connections()]: 80 | self.port = port_candidate 81 | break 82 | 83 | # Check if the port is in use 84 | if self.port in [conn.laddr[1] for conn in psutil.net_connections()]: 85 | raise IOError('Port ' + str(self.port) + ' is already in use.') 86 | 87 | # Start native server 88 | logging.info('Initializing native server...') 89 | cmd = "java" 90 | java_args = "-Xmx{}".format(self.memory) 91 | java_class = "edu.stanford.nlp.pipeline.StanfordCoreNLPServer" 92 | class_path = '"{}*"'.format(directory) 93 | 94 | args = [cmd, java_args, '-cp', class_path, java_class, '-port', str(self.port)] 95 | 96 | args = ' '.join(args) 97 | 98 | logging.info(args) 99 | 100 | # Silence 101 | with open(os.devnull, 'w') as null_file: 102 | out_file = None 103 | if self.quiet: 104 | out_file = null_file 105 | 106 | self.p = subprocess.Popen(args, shell=True, stdout=out_file, stderr=subprocess.STDOUT) 107 | logging.info('Server shell PID: {}'.format(self.p.pid)) 108 | 109 | self.url = 'http://localhost:' + str(self.port) 110 | 111 | # Wait until server starts 112 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 113 | host_name = urlparse(self.url).hostname 114 | time.sleep(1) # OSX, not tested 115 | trial = 1 116 | while sock.connect_ex((host_name, self.port)): 117 | if trial > max_retries: 118 | raise ValueError('Corenlp server is not available') 119 | logging.info('Waiting until the server is available.') 120 | trial += 1 121 | time.sleep(1) 122 | logging.info('The server is available.') 123 | 124 | def __enter__(self): 125 | return self 126 | 127 | def __exit__(self, exc_type, exc_val, exc_tb): 128 | self.close() 129 | 130 | def close(self): 131 | logging.info('Cleanup...') 132 | if hasattr(self, 'p'): 133 | try: 134 | parent = psutil.Process(self.p.pid) 135 | except psutil.NoSuchProcess: 136 | logging.info('No process: {}'.format(self.p.pid)) 137 | return 138 | 139 | if self.class_path_dir not in ' '.join(parent.cmdline()): 140 | logging.info('Process not in: {}'.format(parent.cmdline())) 141 | return 142 | 143 | children = parent.children(recursive=True) 144 | for process in children: 145 | logging.info('Killing pid: {}, cmdline: {}'.format(process.pid, process.cmdline())) 146 | # process.send_signal(signal.SIGTERM) 147 | process.kill() 148 | 149 | logging.info('Killing shell pid: {}, cmdline: {}'.format(parent.pid, parent.cmdline())) 150 | # parent.send_signal(signal.SIGTERM) 151 | parent.kill() 152 | 153 | def annotate(self, text, properties=None): 154 | if sys.version_info.major >= 3: 155 | text = text.encode('utf-8') 156 | 157 | r = requests.post(self.url, params={'properties': str(properties)}, data=text, 158 | headers={'Connection': 'close'}) 159 | return r.text 160 | 161 | def tregex(self, sentence, pattern): 162 | tregex_url = self.url + '/tregex' 163 | r_dict = self._request(tregex_url, "tokenize,ssplit,depparse,parse", sentence, pattern=pattern) 164 | return r_dict 165 | 166 | def tokensregex(self, sentence, pattern): 167 | tokensregex_url = self.url + '/tokensregex' 168 | r_dict = self._request(tokensregex_url, "tokenize,ssplit,depparse", sentence, pattern=pattern) 169 | return r_dict 170 | 171 | def semgrex(self, sentence, pattern): 172 | semgrex_url = self.url + '/semgrex' 173 | r_dict = self._request(semgrex_url, "tokenize,ssplit,depparse", 
sentence, pattern=pattern) 174 | return r_dict 175 | 176 | def word_tokenize(self, sentence, span=False): 177 | r_dict = self._request('ssplit,tokenize', sentence) 178 | tokens = [token['originalText'] for s in r_dict['sentences'] for token in s['tokens']] 179 | 180 | # Whether return token span 181 | if span: 182 | spans = [(token['characterOffsetBegin'], token['characterOffsetEnd']) for s in r_dict['sentences'] for token 183 | in s['tokens']] 184 | return tokens, spans 185 | else: 186 | return tokens 187 | 188 | def pos_tag(self, sentence): 189 | r_dict = self._request(self.url, 'pos', sentence) 190 | words = [] 191 | tags = [] 192 | for s in r_dict['sentences']: 193 | for token in s['tokens']: 194 | words.append(token['originalText']) 195 | tags.append(token['pos']) 196 | return list(zip(words, tags)) 197 | 198 | def ner(self, sentence): 199 | r_dict = self._request(self.url, 'ner', sentence) 200 | words = [] 201 | ner_tags = [] 202 | for s in r_dict['sentences']: 203 | for token in s['tokens']: 204 | words.append(token['originalText']) 205 | ner_tags.append(token['ner']) 206 | return list(zip(words, ner_tags)) 207 | 208 | def parse(self, sentence): 209 | r_dict = self._request(self.url, 'pos,parse', sentence) 210 | return [s['parse'] for s in r_dict['sentences']] 211 | 212 | def dependency_parse(self, sentence): 213 | r_dict = self._request(self.url, 'depparse', sentence) 214 | return [(dep['dep'], dep['governor'], dep['dependent']) for s in r_dict['sentences'] for dep in 215 | s['basicDependencies']] 216 | 217 | def coref(self, text): 218 | r_dict = self._request('coref', text) 219 | 220 | corefs = [] 221 | for k, mentions in r_dict['corefs'].items(): 222 | simplified_mentions = [] 223 | for m in mentions: 224 | simplified_mentions.append((m['sentNum'], m['startIndex'], m['endIndex'], m['text'])) 225 | corefs.append(simplified_mentions) 226 | return corefs 227 | 228 | def switch_language(self, language="en"): 229 | self._check_language(language) 230 | self.lang = language 231 | 232 | # def _request(self, url, annotators=None, data=None, *args, **kwargs): 233 | # if sys.version_info.major >= 3: 234 | # data = data.encode('utf-8') 235 | # 236 | # properties = {'annotators': annotators, 'outputFormat': 'json'} 237 | # params = {'properties': str(properties), 'pipelineLanguage': self.lang} 238 | # if 'pattern' in kwargs: 239 | # params = {"pattern": kwargs['pattern'], 'properties': str(properties), 'pipelineLanguage': self.lang} 240 | # 241 | # logging.info(params) 242 | # r = requests.post(url, params=params, data=data, headers={'Connection': 'close'}) 243 | # r_dict = json.loads(r.text) 244 | # 245 | # return r_dict 246 | 247 | def request(self, annotators=None, data=None, *args, **kwargs): 248 | # if sys.version_info.major >= 3: 249 | data = data.encode('utf-8') 250 | 251 | properties = {'annotators': annotators, 'outputFormat': 'json'} 252 | params = {'properties': str(properties), 'pipelineLanguage': self.lang, 253 | 'parse.model': 'edu/stanford/nlp/models/lexparser/chinesePCFG.ser.gz', 254 | 'parse.kbest': 3} 255 | if 'pattern' in kwargs: 256 | params = {"pattern": kwargs['pattern'], 'properties': str(properties), 'pipelineLanguage': self.lang} 257 | 258 | logging.info(params) 259 | r = requests.post(self.url, params=params, data=data, headers={'Connection': 'close'}) 260 | r_dict = json.loads(r.text) 261 | 262 | return r_dict 263 | 264 | def _check_args(self): 265 | self._check_language(self.lang) 266 | if not re.match('\dg', self.memory): 267 | raise ValueError('memory=' + 
self.memory + ' not supported. Use 4g, 6g, 8g, etc. ') 268 | 269 | def _check_language(self, lang): 270 | if lang not in ['en', 'zh', 'ar', 'fr', 'de', 'es']: 271 | raise ValueError('lang=' + self.lang + ' not supported. Use English(en), Chinese(zh), Arabic(ar), ' 272 | 'French(fr), German(de), Spanish(es).') 273 | -------------------------------------------------------------------------------- /data_preprocessing/README.md: -------------------------------------------------------------------------------- 1 | # Data Pre-processing 2 | 3 | Run `getdata.sh` under this directory to obtain and pre-process the data. This script will download and process the official data from UD. For CTB5, CTB6, CTB7, and CTB9, you need to obtain the official data yourself, and then put the raw data folders under the `data_preprocessing` directory. The folder names for the CTB datasets should be: 4 | 5 | * CTB5: LDC05T01 6 | * CTB6: LDC07T36 7 | * CTB7: LDC10T07 8 | * CTB9: LDC2016T13 9 | 10 | This script will also download the [Stanford CoreNLP Toolkit v3.9.2](https://stanfordnlp.github.io/CoreNLP/history.html) (SCT) and [Berkeley Neural Parser](https://github.com/nikitakit/self-attentive-parser) (BNP) from their official websites, which are used to obtain the auto-analyzed syntactic knowledge. If you only want to use the knowledge from SCT, you can comment out the commands that download BNP in `getdata.sh`. If you want to use the auto-analyzed knowledge from BNP, you need to download both SCT and BNP, because BNP relies on the segmentation results from SCT. 11 | 12 | To run SCT, you need `Java 8`; to run BNP, you need `tensorflow==1.13.1`. 13 | 14 | You can refer to their websites for more information. 15 | 16 | All processed data will appear in the `data` directory, organized by dataset; each dataset folder contains files with the same names as those in the `sample_data` folder. 17 | -------------------------------------------------------------------------------- /data_preprocessing/getdata.sh: -------------------------------------------------------------------------------- 1 | ############## process data ############## 2 | 3 | # download Universal Dependencies 2.4 4 | # If this step fails, you can manually download the file and put it under this directory 5 | wget https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-2988/ud-treebanks-v2.4.tgz 6 | 7 | tar zxvf ud-treebanks-v2.4.tgz 8 | rm ud-treebanks-v2.4.tgz 9 | 10 | ## process UD 11 | python data_preprocessing.py --dataset=ud --translate 12 | 13 | ## process CTB5 14 | python data_preprocessing.py --dataset=ctb5 15 | 16 | ## process CTB6 17 | python data_preprocessing.py --dataset=ctb6 18 | 19 | ## process CTB7 20 | python data_preprocessing.py --dataset=ctb7 21 | 22 | ## process CTB9 23 | python data_preprocessing.py --dataset=ctb9 24 | 25 | ############### process data ############## 26 | 27 | 28 | ############### download Stanford CoreNLP ################# 29 | # download StanfordCoreNLP v3.9.2 30 | wget http://nlp.stanford.edu/software/stanford-corenlp-full-2018-10-05.zip 31 | 32 | unzip stanford-corenlp-full-2018-10-05.zip 33 | rm stanford-corenlp-full-2018-10-05.zip 34 | cd stanford-corenlp-full-2018-10-05 || exit 35 | # download the Chinese models 36 | wget http://nlp.stanford.edu/software/stanford-chinese-corenlp-2018-10-05-models.jar 37 | 38 | cd ..
39 | 40 | ############## download Stanford CoreNLP ################# 41 | 42 | 43 | ############## download Berkeley Neural Parser (BNP) ################# 44 | # You can see https://github.com/nikitakit/self-attentive-parser for details 45 | 46 | pip install cython numpy 47 | pip install benepar[cpu] 48 | pip install nltk 49 | 50 | # download the BNP model 51 | python install_BNP.py 52 | 53 | ############## download Berkeley Neural Parser ################# 54 | 55 | cd .. 56 | 57 | ############## obtain auto-processed data from the Stanford CoreNLP Toolkit (SCT) ################# 58 | 59 | python get_syninfo.py --dataset=UD1 --toolkit=SCT --overwrite 60 | 61 | python get_syninfo.py --dataset=UD2 --toolkit=SCT --overwrite 62 | 63 | python get_syninfo.py --dataset=CTB5 --toolkit=SCT --overwrite 64 | 65 | python get_syninfo.py --dataset=CTB6 --toolkit=SCT --overwrite 66 | 67 | python get_syninfo.py --dataset=CTB7 --toolkit=SCT --overwrite 68 | 69 | python get_syninfo.py --dataset=CTB9 --toolkit=SCT --overwrite 70 | 71 | ############## obtain auto-processed data from the Stanford CoreNLP Toolkit (SCT) ################# 72 | 73 | 74 | 75 | ############## obtain auto-processed data from Berkeley Neural Parser (BNP) ################# 76 | 77 | # BNP requires the TensorFlow framework (tensorflow-gpu==1.13.1). It is recommended to use GPUs when obtaining syntactic knowledge from BNP 78 | 79 | # uncomment the following line if you want to install TensorFlow to run BNP 80 | # pip install tensorflow-gpu==1.13.1 81 | 82 | python get_syninfo.py --dataset=UD1 --toolkit=BNP --overwrite 83 | 84 | python get_syninfo.py --dataset=UD2 --toolkit=BNP --overwrite 85 | 86 | python get_syninfo.py --dataset=CTB5 --toolkit=BNP --overwrite 87 | 88 | python get_syninfo.py --dataset=CTB6 --toolkit=BNP --overwrite 89 | 90 | python get_syninfo.py --dataset=CTB7 --toolkit=BNP --overwrite 91 | 92 | python get_syninfo.py --dataset=CTB9 --toolkit=BNP --overwrite 93 | 94 | ############## obtain auto-processed data from Berkeley Neural Parser (BNP) ################# 95 | -------------------------------------------------------------------------------- /data_preprocessing/install_BNP.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import benepar 3 | 4 | nltk.download('punkt') 5 | benepar.download('benepar_zh') 6 | -------------------------------------------------------------------------------- /data_preprocessing/langconv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from copy import deepcopy 5 | import re 6 | 7 | try: 8 | import psyco 9 | psyco.full() 10 | except: 11 | pass 12 | 13 | try: 14 | from zh_wiki import zh2Hant, zh2Hans 15 | except ImportError: 16 | from zhtools.zh_wiki import zh2Hant, zh2Hans 17 | 18 | import sys 19 | py3k = sys.version_info >= (3, 0, 0) 20 | 21 | if py3k: 22 | UEMPTY = '' 23 | else: 24 | _zh2Hant, _zh2Hans = {}, {} 25 | for old, new in ((zh2Hant, _zh2Hant), (zh2Hans, _zh2Hans)): 26 | for k, v in old.items(): 27 | new[k.decode('utf8')] = v.decode('utf8') 28 | zh2Hant = _zh2Hant 29 | zh2Hans = _zh2Hans 30 | UEMPTY = ''.decode('utf8') 31 | 32 | # states 33 | (START, END, FAIL, WAIT_TAIL) = list(range(4)) 34 | # conditions 35 | (TAIL, ERROR, MATCHED_SWITCH, UNMATCHED_SWITCH, CONNECTOR) = list(range(5)) 36 | 37 | MAPS = {} 38 | 39 | class Node(object): 40 | def __init__(self, from_word, to_word=None, is_tail=True, 41 | have_child=False): 42 | self.from_word =
from_word 43 | if to_word is None: 44 | self.to_word = from_word 45 | self.data = (is_tail, have_child, from_word) 46 | self.is_original = True 47 | else: 48 | self.to_word = to_word or from_word 49 | self.data = (is_tail, have_child, to_word) 50 | self.is_original = False 51 | self.is_tail = is_tail 52 | self.have_child = have_child 53 | 54 | def is_original_long_word(self): 55 | return self.is_original and len(self.from_word)>1 56 | 57 | def is_follow(self, chars): 58 | return chars != self.from_word[:-1] 59 | 60 | def __str__(self): 61 | return '' % (repr(self.from_word), 62 | repr(self.to_word), self.is_tail, self.have_child) 63 | 64 | __repr__ = __str__ 65 | 66 | class ConvertMap(object): 67 | def __init__(self, name, mapping=None): 68 | self.name = name 69 | self._map = {} 70 | if mapping: 71 | self.set_convert_map(mapping) 72 | 73 | def set_convert_map(self, mapping): 74 | convert_map = {} 75 | have_child = {} 76 | max_key_length = 0 77 | for key in sorted(mapping.keys()): 78 | if len(key)>1: 79 | for i in range(1, len(key)): 80 | parent_key = key[:i] 81 | have_child[parent_key] = True 82 | have_child[key] = False 83 | max_key_length = max(max_key_length, len(key)) 84 | for key in sorted(have_child.keys()): 85 | convert_map[key] = (key in mapping, have_child[key], 86 | mapping.get(key, UEMPTY)) 87 | self._map = convert_map 88 | self.max_key_length = max_key_length 89 | 90 | def __getitem__(self, k): 91 | try: 92 | is_tail, have_child, to_word = self._map[k] 93 | return Node(k, to_word, is_tail, have_child) 94 | except: 95 | return Node(k) 96 | 97 | def __contains__(self, k): 98 | return k in self._map 99 | 100 | def __len__(self): 101 | return len(self._map) 102 | 103 | class StatesMachineException(Exception): pass 104 | 105 | class StatesMachine(object): 106 | def __init__(self): 107 | self.state = START 108 | self.final = UEMPTY 109 | self.len = 0 110 | self.pool = UEMPTY 111 | 112 | def clone(self, pool): 113 | new = deepcopy(self) 114 | new.state = WAIT_TAIL 115 | new.pool = pool 116 | return new 117 | 118 | def feed(self, char, map): 119 | node = map[self.pool+char] 120 | 121 | if node.have_child: 122 | if node.is_tail: 123 | if node.is_original: 124 | cond = UNMATCHED_SWITCH 125 | else: 126 | cond = MATCHED_SWITCH 127 | else: 128 | cond = CONNECTOR 129 | else: 130 | if node.is_tail: 131 | cond = TAIL 132 | else: 133 | cond = ERROR 134 | 135 | new = None 136 | if cond == ERROR: 137 | self.state = FAIL 138 | elif cond == TAIL: 139 | if self.state == WAIT_TAIL and node.is_original_long_word(): 140 | self.state = FAIL 141 | else: 142 | self.final += node.to_word 143 | self.len += 1 144 | self.pool = UEMPTY 145 | self.state = END 146 | elif self.state == START or self.state == WAIT_TAIL: 147 | if cond == MATCHED_SWITCH: 148 | new = self.clone(node.from_word) 149 | self.final += node.to_word 150 | self.len += 1 151 | self.state = END 152 | self.pool = UEMPTY 153 | elif cond == UNMATCHED_SWITCH or cond == CONNECTOR: 154 | if self.state == START: 155 | new = self.clone(node.from_word) 156 | self.final += node.to_word 157 | self.len += 1 158 | self.state = END 159 | else: 160 | if node.is_follow(self.pool): 161 | self.state = FAIL 162 | else: 163 | self.pool = node.from_word 164 | elif self.state == END: 165 | # END is a new START 166 | self.state = START 167 | new = self.feed(char, map) 168 | elif self.state == FAIL: 169 | raise StatesMachineException('Translate States Machine ' 170 | 'have error with input data %s' % node) 171 | return new 172 | 173 | def __len__(self): 174 | return 
self.len + 1 175 | 176 | def __str__(self): 177 | return '' % ( 178 | id(self), self.pool, self.state, self.final) 179 | __repr__ = __str__ 180 | 181 | class Converter(object): 182 | def __init__(self, to_encoding): 183 | self.to_encoding = to_encoding 184 | self.map = MAPS[to_encoding] 185 | self.start() 186 | 187 | def feed(self, char): 188 | branches = [] 189 | for fsm in self.machines: 190 | new = fsm.feed(char, self.map) 191 | if new: 192 | branches.append(new) 193 | if branches: 194 | self.machines.extend(branches) 195 | self.machines = [fsm for fsm in self.machines if fsm.state != FAIL] 196 | all_ok = True 197 | for fsm in self.machines: 198 | if fsm.state != END: 199 | all_ok = False 200 | if all_ok: 201 | self._clean() 202 | return self.get_result() 203 | 204 | def _clean(self): 205 | if len(self.machines): 206 | self.machines.sort(key=lambda x: len(x)) 207 | # self.machines.sort(cmp=lambda x,y: cmp(len(x), len(y))) 208 | self.final += self.machines[0].final 209 | self.machines = [StatesMachine()] 210 | 211 | def start(self): 212 | self.machines = [StatesMachine()] 213 | self.final = UEMPTY 214 | 215 | def end(self): 216 | self.machines = [fsm for fsm in self.machines 217 | if fsm.state == FAIL or fsm.state == END] 218 | self._clean() 219 | 220 | def convert(self, string): 221 | self.start() 222 | for char in string: 223 | self.feed(char) 224 | self.end() 225 | return self.get_result() 226 | 227 | def get_result(self): 228 | return self.final 229 | 230 | 231 | def registery(name, mapping): 232 | global MAPS 233 | MAPS[name] = ConvertMap(name, mapping) 234 | 235 | registery('zh-hant', zh2Hant) 236 | registery('zh-hans', zh2Hans) 237 | del zh2Hant, zh2Hans 238 | 239 | 240 | def run(): 241 | import sys 242 | from optparse import OptionParser 243 | parser = OptionParser() 244 | parser.add_option('-e', type='string', dest='encoding', 245 | help='encoding') 246 | parser.add_option('-f', type='string', dest='file_in', 247 | help='input file (- for stdin)') 248 | parser.add_option('-t', type='string', dest='file_out', 249 | help='output file') 250 | (options, args) = parser.parse_args() 251 | if not options.encoding: 252 | parser.error('encoding must be set') 253 | if options.file_in: 254 | if options.file_in == '-': 255 | file_in = sys.stdin 256 | else: 257 | file_in = open(options.file_in) 258 | else: 259 | file_in = sys.stdin 260 | if options.file_out: 261 | if options.file_out == '-': 262 | file_out = sys.stdout 263 | else: 264 | file_out = open(options.file_out, 'wb') 265 | else: 266 | file_out = sys.stdout 267 | 268 | c = Converter(options.encoding) 269 | for line in file_in: 270 | # print >> file_out, c.convert(line.rstrip('\n').decode( 271 | file_out.write(c.convert(line.rstrip('\n').decode( 272 | 'utf8')).encode('utf8')) 273 | 274 | 275 | if __name__ == '__main__': 276 | run() 277 | 278 | -------------------------------------------------------------------------------- /data_preprocessing/t2s.py: -------------------------------------------------------------------------------- 1 | from langconv import * 2 | from os import path 3 | import os 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | input_dir = './traditional_data/' 8 | output_dir = './processed/' 9 | 10 | def read_file(file_path): 11 | sentence_list = [] 12 | label_list = [] 13 | with open(file_path, 'r', encoding='utf8') as f: 14 | lines = f.readlines() 15 | sentence = [] 16 | labels = [] 17 | for line in lines: 18 | line = line.strip() 19 | if line == '': 20 | sentence_list.append(sentence) 21 | 
label_list.append(labels) 22 | sentence = [] 23 | labels = [] 24 | continue 25 | items = re.split('\\s+', line) 26 | character = items[0] 27 | label = items[-1] 28 | sentence.append(character) 29 | labels.append(label) 30 | 31 | return sentence_list, label_list 32 | 33 | 34 | def write_file(file_path, sentence_list, label_list): 35 | with open(file_path, 'w', encoding='utf8') as f: 36 | for sentence, label in zip(sentence_list, label_list): 37 | for s, l in zip(sentence, label): 38 | f.write('%s\t%s\n' % (s, l)) 39 | f.write('\n') 40 | 41 | 42 | def Traditional2Simplified(sentence): 43 | sentence = Converter('zh-hans').convert(sentence) 44 | return sentence 45 | 46 | 47 | def traditional2simplified(input_file_path, output_file_path): 48 | simp_sentence_list = [] 49 | 50 | sentence_list, label_list = read_file(input_file_path) 51 | for sentence in tqdm(sentence_list): 52 | sentence_str = ''.join(sentence) 53 | simp_sentence_str = Traditional2Simplified(sentence_str) 54 | assert len(simp_sentence_str) == len(sentence) 55 | simp_sentence_list.append(simp_sentence_str) 56 | 57 | write_file(output_file_path, simp_sentence_list, label_list) 58 | -------------------------------------------------------------------------------- /get_syninfo.py: -------------------------------------------------------------------------------- 1 | from twasp_helper import request_features_from_stanford, request_features_from_berkeley 2 | import argparse 3 | import os 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument("--dataset", 9 | default=None, 10 | type=str, 11 | required=True, 12 | help="The dataset. Should be one of \'CTB5\', \'CTB6\', \'CTB7\', \'CTB9\', " 13 | "\'UD1\' and \'UD2\'.") 14 | 15 | parser.add_argument("--toolkit", 16 | default=None, 17 | type=str, 18 | required=True, 19 | help="The toolkit to be used. Should be one of \'SCT\' and \'BNP\'.") 20 | 21 | parser.add_argument("--overwrite", 22 | action='store_true', 23 | help="Whether to overwrite existing data.") 24 | 25 | args = parser.parse_args() 26 | 27 | print(vars(args)) 28 | 29 | input_dir = os.path.join('./data/', args.dataset) 30 | 31 | if args.overwrite: 32 | print('All existing files will be overwrote') 33 | 34 | for flag in ['train', 'dev', 'test']: 35 | input_file = os.path.join(input_dir, flag + '.tsv') 36 | if not os.path.exists(input_file): 37 | print('File does not exits: %s' % str(input_file)) 38 | continue 39 | 40 | if args.toolkit == 'SCT': 41 | out_file = os.path.join(input_dir, flag + '.stanford.json') 42 | if os.path.exists(out_file) and not args.overwrite: 43 | print('File already exists: %s' % str(out_file)) 44 | continue 45 | request_features_from_stanford(input_file) 46 | 47 | elif args.toolkit == 'BNP': 48 | out_file = os.path.join(input_dir, flag + '.berkeley.json') 49 | if os.path.exists(out_file) and not args.overwrite: 50 | print('File already exists: %s' % str(out_file)) 51 | continue 52 | request_features_from_berkeley(input_file) 53 | else: 54 | raise ValueError('Invalid type of toolkit name: %s. Should be one of \'SCT\' and \'BNP\'.' % args.toolkit) 55 | 56 | 57 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | ## TwASP models 2 | 3 | The pre-trained TwASP models on different datasets with BERT and ZEN encoder. You can download the models and run them on the corresponding datasets to obtain the results reported in our paper. 
4 | 5 | 6 | | Section | BaiduNetDisk | Description | 7 | |-|-|-| 8 | |[TwASP.ZEN.CTB5]| [download](https://pan.baidu.com/s/1ebTaNXrw3P0E7wQaqZofkQ) (password: 2kul)| TwASP model trained on **CTB5** dataset using **ZEN** as encoder | 9 | |[TwASP.ZEN.CTB6]| [download](https://pan.baidu.com/s/1rgLd2QT2txSTiYOT-fVRNg) (password: 1q59)| TwASP model trained on **CTB6** dataset using **ZEN** as encoder | 10 | |[TwASP.ZEN.CTB7]| [download](https://pan.baidu.com/s/1d9SZD9soyv02J3BZr3NhBg) (password: krwr)| TwASP model trained on **CTB7** dataset using **ZEN** as encoder | 11 | |[TwASP.ZEN.CTB9]| [download](https://pan.baidu.com/s/1Yw2-J_E7p3yaa5PleeGkpg) (password: 6jq2)| TwASP model trained on **CTB9** dataset using **ZEN** as encoder | 12 | |[TwASP.ZEN.UD1]| [download](https://pan.baidu.com/s/1gvQx8g8SXJwTK0jqhrTAfg) (password: lrwu)| TwASP model trained on **UD1** dataset using **ZEN** as encoder | 13 | |[TwASP.ZEN.UD2]| [download](https://pan.baidu.com/s/1roGP76Cggef8DvcheRGWjg) (password: mku5)| TwASP model trained on **UD2** dataset using **ZEN** as encoder | 14 | |[TwASP.BERT.CTB5]| [download](https://pan.baidu.com/s/1g4rYCMulEW_nCtwQryU90Q) (password: mldm)| TwASP model trained on **CTB5** dataset using **BERT** as encoder | 15 | |[TwASP.BERT.CTB6]| [download](https://pan.baidu.com/s/1XJd0Tr7KDnaXuDYIPYuoWg) (password: ro0m)| TwASP model trained on **CTB6** dataset using **BERT** as encoder | 16 | |[TwASP.BERT.CTB7]| [download](https://pan.baidu.com/s/1FSZxe2cnkyRHEjLVK6Uejw) (password: wagb)| TwASP model trained on **CTB7** dataset using **BERT** as encoder | 17 | |[TwASP.BERT.CTB9]| [download](https://pan.baidu.com/s/1s_8z20Ud6YohbG70wgheWA) (password: klw9)| TwASP model trained on **CTB9** dataset using **BERT** as encoder | 18 | |[TwASP.BERT.UD1]| [download](https://pan.baidu.com/s/1wrN3unMWQGh1P3vxsqGvpw) (password: tnbh)| TwASP model trained on **UD1** dataset using **BERT** as encoder | 19 | |[TwASP.BERT.UD2]| [download](https://pan.baidu.com/s/1gvPcyLCiAyhA_hkmijJmPw) (password: gs71)| TwASP model trained on **UD2** dataset using **BERT** as encoder | 20 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.2" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_openai import OpenAIGPTTokenizer 4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | from .tokenization_gpt2 import GPT2Tokenizer 6 | 7 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 8 | BertForMaskedLM, BertForNextSentencePrediction, 9 | BertForSequenceClassification, BertForMultipleChoice, 10 | BertForTokenClassification, BertForQuestionAnswering, 11 | load_tf_weights_in_bert) 12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, 13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 14 | load_tf_weights_in_openai_gpt) 15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 16 | load_tf_weights_in_transfo_xl) 17 | from .modeling_gpt2 import (GPT2Config, GPT2Model, 18 | GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2MultipleChoiceHead, 19 | load_tf_weights_in_gpt2) 20 | 21 | from .optimization import BertAdam 22 | from .optimization_openai import OpenAIAdam 23 | 24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME 25 | 
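# Usage sketch: the classes re-exported above follow the standard
# pytorch_pretrained_bert 0.6.2 interface; for example (the model path below is a placeholder):
#
#     from pytorch_pretrained_bert import BertTokenizer, BertModel
#     tokenizer = BertTokenizer.from_pretrained('/path/to/bert-base-chinese')
#     model = BertModel.from_pretrained('/path/to/bert-base-chinese')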
-------------------------------------------------------------------------------- /pytorch_pretrained_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. 
Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | # from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | # from modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 30 | from modeling_convert import BertConfig, BertForPreTraining, load_tf_weights_in_bert 31 | 32 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 33 | # Initialise PyTorch model 34 | config = BertConfig.from_json_file(bert_config_file) 35 | print("Building PyTorch model from configuration: {}".format(str(config))) 36 | model = BertForPreTraining(config) 37 | 38 | # Load weights from tf checkpoint 39 | load_tf_weights_in_bert(model, tf_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 43 | torch.save(model.state_dict(), pytorch_dump_path) 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | ## Required parameters 49 | parser.add_argument("--tf_checkpoint_path", 50 | default = None, 51 | type = str, 52 | required = True, 53 | help = "Path the TensorFlow checkpoint path.") 54 | parser.add_argument("--bert_config_file", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "The config json file corresponding to the pre-trained BERT model. \n" 59 | "This specifies the model architecture.") 60 | parser.add_argument("--pytorch_dump_path", 61 | default = None, 62 | type = str, 63 | required = True, 64 | help = "Path to the output PyTorch model.") 65 | args = parser.parse_args() 66 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 67 | args.bert_config_file, 68 | args.pytorch_dump_path) 69 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder 
to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. \n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from torch.hub import _get_torch_home 27 | torch_cache_home = _get_torch_home() 28 | except ImportError: 29 | torch_cache_home = os.path.expanduser( 30 | os.getenv('TORCH_HOME', os.path.join( 31 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 32 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert') 33 | 34 | try: 35 | from urllib.parse import urlparse 36 | except ImportError: 37 | from urlparse import urlparse 38 | 39 | try: 40 | from pathlib import Path 41 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 42 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) 43 | except (AttributeError, ImportError): 44 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 45 | default_cache_path) 46 | 47 | CONFIG_NAME = "config.json" 48 | WEIGHTS_NAME = "pytorch_model.bin" 49 | 50 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 51 | 52 | 53 | def url_to_filename(url, etag=None): 54 | """ 55 | Convert `url` into a hashed filename in a repeatable way. 56 | If `etag` is specified, append its hash to the url's, delimited 57 | by a period. 58 | """ 59 | url_bytes = url.encode('utf-8') 60 | url_hash = sha256(url_bytes) 61 | filename = url_hash.hexdigest() 62 | 63 | if etag: 64 | etag_bytes = etag.encode('utf-8') 65 | etag_hash = sha256(etag_bytes) 66 | filename += '.' + etag_hash.hexdigest() 67 | 68 | return filename 69 | 70 | 71 | def filename_to_url(filename, cache_dir=None): 72 | """ 73 | Return the url and etag (which may be ``None``) stored for `filename`. 74 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
75 | """ 76 | if cache_dir is None: 77 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 78 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 79 | cache_dir = str(cache_dir) 80 | 81 | cache_path = os.path.join(cache_dir, filename) 82 | if not os.path.exists(cache_path): 83 | raise EnvironmentError("file {} not found".format(cache_path)) 84 | 85 | meta_path = cache_path + '.json' 86 | if not os.path.exists(meta_path): 87 | raise EnvironmentError("file {} not found".format(meta_path)) 88 | 89 | with open(meta_path, encoding="utf-8") as meta_file: 90 | metadata = json.load(meta_file) 91 | url = metadata['url'] 92 | etag = metadata['etag'] 93 | 94 | return url, etag 95 | 96 | 97 | def cached_path(url_or_filename, cache_dir=None): 98 | """ 99 | Given something that might be a URL (or might be a local path), 100 | determine which. If it's a URL, download the file and cache it, and 101 | return the path to the cached file. If it's already a local path, 102 | make sure the file exists and then return the path. 103 | """ 104 | if cache_dir is None: 105 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 106 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 107 | url_or_filename = str(url_or_filename) 108 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 109 | cache_dir = str(cache_dir) 110 | 111 | parsed = urlparse(url_or_filename) 112 | 113 | if parsed.scheme in ('http', 'https', 's3'): 114 | # URL, so get it from the cache (downloading if necessary) 115 | return get_from_cache(url_or_filename, cache_dir) 116 | elif os.path.exists(url_or_filename): 117 | # File, and it exists. 118 | return url_or_filename 119 | elif parsed.scheme == '': 120 | # File, but it doesn't exist. 121 | raise EnvironmentError("file {} not found".format(url_or_filename)) 122 | else: 123 | # Something unknown 124 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 125 | 126 | 127 | def split_s3_path(url): 128 | """Split a full s3 path into the bucket name and path.""" 129 | parsed = urlparse(url) 130 | if not parsed.netloc or not parsed.path: 131 | raise ValueError("bad s3 path {}".format(url)) 132 | bucket_name = parsed.netloc 133 | s3_path = parsed.path 134 | # Remove '/' at beginning of path. 135 | if s3_path.startswith("/"): 136 | s3_path = s3_path[1:] 137 | return bucket_name, s3_path 138 | 139 | 140 | def s3_request(func): 141 | """ 142 | Wrapper function for s3 requests in order to create more helpful error 143 | messages. 
144 | """ 145 | 146 | @wraps(func) 147 | def wrapper(url, *args, **kwargs): 148 | try: 149 | return func(url, *args, **kwargs) 150 | except ClientError as exc: 151 | if int(exc.response["Error"]["Code"]) == 404: 152 | raise EnvironmentError("file {} not found".format(url)) 153 | else: 154 | raise 155 | 156 | return wrapper 157 | 158 | 159 | @s3_request 160 | def s3_etag(url): 161 | """Check ETag on S3 object.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_object = s3_resource.Object(bucket_name, s3_path) 165 | return s3_object.e_tag 166 | 167 | 168 | @s3_request 169 | def s3_get(url, temp_file): 170 | """Pull a file directly from S3.""" 171 | s3_resource = boto3.resource("s3") 172 | bucket_name, s3_path = split_s3_path(url) 173 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 174 | 175 | 176 | def http_get(url, temp_file): 177 | req = requests.get(url, stream=True) 178 | content_length = req.headers.get('Content-Length') 179 | total = int(content_length) if content_length is not None else None 180 | progress = tqdm(unit="B", total=total) 181 | for chunk in req.iter_content(chunk_size=1024): 182 | if chunk: # filter out keep-alive new chunks 183 | progress.update(len(chunk)) 184 | temp_file.write(chunk) 185 | progress.close() 186 | 187 | 188 | def get_from_cache(url, cache_dir=None): 189 | """ 190 | Given a URL, look for the corresponding dataset in the local cache. 191 | If it's not there, download it. Then return the path to the cached file. 192 | """ 193 | if cache_dir is None: 194 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 195 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 196 | cache_dir = str(cache_dir) 197 | 198 | if not os.path.exists(cache_dir): 199 | os.makedirs(cache_dir) 200 | 201 | # Get eTag to add to filename, if it exists. 202 | if url.startswith("s3://"): 203 | etag = s3_etag(url) 204 | else: 205 | try: 206 | response = requests.head(url, allow_redirects=True) 207 | if response.status_code != 200: 208 | etag = None 209 | else: 210 | etag = response.headers.get("ETag") 211 | except EnvironmentError: 212 | etag = None 213 | 214 | if sys.version_info[0] == 2 and etag is not None: 215 | etag = etag.decode('utf-8') 216 | filename = url_to_filename(url, etag) 217 | 218 | # get cache path to put the file 219 | cache_path = os.path.join(cache_dir, filename) 220 | 221 | # If we don't have a connection (etag is None) and can't identify the file 222 | # try to get the last downloaded one 223 | if not os.path.exists(cache_path) and etag is None: 224 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 225 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 226 | if matching_files: 227 | cache_path = os.path.join(cache_dir, matching_files[-1]) 228 | 229 | if not os.path.exists(cache_path): 230 | # Download to temporary file, then copy to cache dir once finished. 231 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
232 | with tempfile.NamedTemporaryFile() as temp_file: 233 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 234 | 235 | # GET file object 236 | if url.startswith("s3://"): 237 | s3_get(url, temp_file) 238 | else: 239 | http_get(url, temp_file) 240 | 241 | # we are copying the file before closing it, so flush to avoid truncation 242 | temp_file.flush() 243 | # shutil.copyfileobj() starts at the current position, so go to the start 244 | temp_file.seek(0) 245 | 246 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 247 | with open(cache_path, 'wb') as cache_file: 248 | shutil.copyfileobj(temp_file, cache_file) 249 | 250 | logger.info("creating metadata file for %s", cache_path) 251 | meta = {'url': url, 'etag': etag} 252 | meta_path = cache_path + '.json' 253 | with open(meta_path, 'w') as meta_file: 254 | output_string = json.dumps(meta) 255 | if sys.version_info[0] == 2 and isinstance(output_string, str): 256 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 257 | meta_file.write(output_string) 258 | 259 | logger.info("removing temp file %s", temp_file.name) 260 | 261 | return cache_path 262 | 263 | 264 | def read_set_from_file(filename): 265 | ''' 266 | Extract a de-duped collection (set) of text from a file. 267 | Expected file format is one item per line. 268 | ''' 269 | collection = set() 270 | with open(filename, 'r', encoding='utf-8') as file_: 271 | for line in file_: 272 | collection.add(line.rstrip()) 273 | return collection 274 | 275 | 276 | def get_file_extension(path, dot=True, lower=True): 277 | ext = os.path.splitext(path)[1] 278 | ext = ext if dot else ext[1:] 279 | return ext.lower() if lower else ext 280 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. 
""" 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | def __init__(self, warmup=0.002, t_total=-1, **kw): 39 | """ 40 | :param warmup: what fraction of t_total steps will be used for linear warmup 41 | :param t_total: how many training steps (updates) are planned 42 | :param kw: 43 | """ 44 | super(_LRSchedule, self).__init__(**kw) 45 | if t_total < 0: 46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 47 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 49 | warmup = max(warmup, 0.) 50 | self.warmup, self.t_total = float(warmup), float(t_total) 51 | self.warned_for_t_total_at_progress = -1 52 | 53 | def get_lr(self, step, nowarn=False): 54 | """ 55 | :param step: which of t_total steps we're on 56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 57 | :return: learning rate multiplier for current update 58 | """ 59 | if self.t_total < 0: 60 | return 1. 61 | progress = float(step) / self.t_total 62 | ret = self.get_lr_(progress) 63 | # warning for exceeding t_total (only active with warmup_linear 64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 65 | logger.warning( 66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 67 | .format(ret, self.__class__.__name__)) 68 | self.warned_for_t_total_at_progress = progress 69 | # end warning 70 | return ret 71 | 72 | @abc.abstractmethod 73 | def get_lr_(self, progress): 74 | """ 75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 76 | :return: learning rate multiplier for current update 77 | """ 78 | return 1. 79 | 80 | 81 | class ConstantLR(_LRSchedule): 82 | def get_lr_(self, progress): 83 | return 1. 84 | 85 | 86 | class WarmupCosineSchedule(_LRSchedule): 87 | """ 88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 91 | """ 92 | warn_t_total = True 93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 94 | """ 95 | :param warmup: see LRSchedule 96 | :param t_total: see LRSchedule 97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 98 | :param kw: 99 | """ 100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 101 | self.cycles = cycles 102 | 103 | def get_lr_(self, progress): 104 | if progress < self.warmup: 105 | return progress / self.warmup 106 | else: 107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 109 | 110 | 111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 112 | """ 113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 115 | learning rate (with hard restarts). 
116 | """ 117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 119 | assert(cycles >= 1.) 120 | 121 | def get_lr_(self, progress): 122 | if progress < self.warmup: 123 | return progress / self.warmup 124 | else: 125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 127 | return ret 128 | 129 | 130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 131 | """ 132 | All training progress is divided in `cycles` (default=1.) parts of equal length. 133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 135 | """ 136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 137 | assert(warmup * cycles < 1.) 138 | warmup = warmup * cycles if warmup >= 0 else warmup 139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 140 | 141 | def get_lr_(self, progress): 142 | progress = progress * self.cycles % 1. 143 | if progress < self.warmup: 144 | return progress / self.warmup 145 | else: 146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 147 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 148 | return ret 149 | 150 | 151 | class WarmupConstantSchedule(_LRSchedule): 152 | """ 153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 154 | Keeps learning rate equal to 1. after warmup. 155 | """ 156 | def get_lr_(self, progress): 157 | if progress < self.warmup: 158 | return progress / self.warmup 159 | return 1. 160 | 161 | 162 | class WarmupLinearSchedule(_LRSchedule): 163 | """ 164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 166 | """ 167 | warn_t_total = True 168 | def get_lr_(self, progress): 169 | if progress < self.warmup: 170 | return progress / self.warmup 171 | return max((progress - 1.) / (self.warmup - 1.), 0.) 172 | 173 | 174 | SCHEDULES = { 175 | None: ConstantLR, 176 | "none": ConstantLR, 177 | "warmup_cosine": WarmupCosineSchedule, 178 | "warmup_constant": WarmupConstantSchedule, 179 | "warmup_linear": WarmupLinearSchedule 180 | } 181 | 182 | 183 | class BertAdam(Optimizer): 184 | """Implements BERT version of Adam algorithm with weight decay fix. 185 | Params: 186 | lr: learning rate 187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 188 | t_total: total number of training steps for the learning 189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 190 | schedule: schedule to use for the warmup (see above). 191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 192 | If `None` or `'none'`, learning rate is always kept constant. 193 | Default : `'warmup_linear'` 194 | betas: Adams betas. Default: (0.9, 0.999) 195 | e: Adams epsilon. Default: 1e-6 196 | weight_decay: Weight decay. Default: 0.01 197 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 198 | """ 199 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 200 | betas=(0.9, 0.999), e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 201 | if lr is not required and lr < 0.0: 202 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 203 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 204 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 205 | if not 0.0 <= betas[0] < 1.0: 206 | raise ValueError("Invalid beta parameter at index 0: {} - should be in [0.0, 1.0[".format(betas[0])) 207 | if not 0.0 <= betas[1] < 1.0: 208 | raise ValueError("Invalid beta parameter at index 1: {} - should be in [0.0, 1.0[".format(betas[1])) 209 | if not e >= 0.0: 210 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 211 | # initialize schedule object 212 | if not isinstance(schedule, _LRSchedule): 213 | schedule_type = SCHEDULES[schedule] 214 | schedule = schedule_type(warmup=warmup, t_total=t_total) 215 | else: 216 | if warmup != -1 or t_total != -1: 217 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 218 | "Please specify custom warmup and t_total in _LRSchedule object.") 219 | defaults = dict(lr=lr, schedule=schedule, 220 | betas=betas, e=e, weight_decay=weight_decay, 221 | max_grad_norm=max_grad_norm) 222 | super(BertAdam, self).__init__(params, defaults) 223 | 224 | def get_lr(self): 225 | lr = [] 226 | for group in self.param_groups: 227 | for p in group['params']: 228 | state = self.state[p] 229 | if len(state) == 0: 230 | return [0] 231 | lr_scheduled = group['lr'] 232 | lr_scheduled *= group['schedule'].get_lr(state['step']) 233 | lr.append(lr_scheduled) 234 | return lr 235 | 236 | def step(self, closure=None): 237 | """Performs a single optimization step. 238 | 239 | Arguments: 240 | closure (callable, optional): A closure that reevaluates the model 241 | and returns the loss. 242 | """ 243 | loss = None 244 | if closure is not None: 245 | loss = closure() 246 | 247 | for group in self.param_groups: 248 | for p in group['params']: 249 | if p.grad is None: 250 | continue 251 | grad = p.grad.data 252 | if grad.is_sparse: 253 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 254 | 255 | state = self.state[p] 256 | 257 | # State initialization 258 | if len(state) == 0: 259 | state['step'] = 0 260 | # Exponential moving average of gradient values 261 | state['next_m'] = torch.zeros_like(p.data) 262 | # Exponential moving average of squared gradient values 263 | state['next_v'] = torch.zeros_like(p.data) 264 | 265 | next_m, next_v = state['next_m'], state['next_v'] 266 | beta1, beta2 = group['betas'] 267 | 268 | # Add grad clipping 269 | if group['max_grad_norm'] > 0: 270 | clip_grad_norm_(p, group['max_grad_norm']) 271 | 272 | # Decay the first and second moment running average coefficient 273 | # In-place operations to update the averages at the same time 274 | next_m.mul_(beta1).add_(1 - beta1, grad) 275 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 276 | update = next_m / (next_v.sqrt() + group['e']) 277 | 278 | # Just adding the square of the weights to the loss function is *not* 279 | # the correct way of using L2 regularization/weight decay with Adam, 280 | # since that will interact with the m and v parameters in strange ways. 
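                # (Concretely, the update applied below is
                #  p <- p - lr_scheduled * (next_m / (next_v.sqrt() + e) + weight_decay * p),
                #  so the decay term never passes through the Adam moment estimates.)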
281 | # 282 | # Instead we want to decay the weights in a manner that doesn't interact 283 | # with the m/v parameters. This is equivalent to adding the square 284 | # of the weights to the loss with plain (non-momentum) SGD. 285 | if group['weight_decay'] > 0.0: 286 | update += group['weight_decay'] * p.data 287 | 288 | lr_scheduled = group['lr'] 289 | lr_scheduled *= group['schedule'].get_lr(state['step']) 290 | 291 | update_with_lr = lr_scheduled * update 292 | p.data.add_(-update_with_lr) 293 | 294 | state['step'] += 1 295 | 296 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 297 | # No bias correction 298 | # bias_correction1 = 1 - beta1 ** state['step'] 299 | # bias_correction2 = 1 - beta2 ** state['step'] 300 | 301 | return loss 302 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ 24 | WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class OpenAIAdam(Optimizer): 30 | """Implements Open AI version of Adam algorithm with weight decay fix. 31 | """ 32 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 33 | betas=(0.9, 0.999), e=1e-8, weight_decay=0, 34 | vector_l2=False, max_grad_norm=-1, **kwargs): 35 | if lr is not required and lr < 0.0: 36 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 37 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 38 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 39 | if not 0.0 <= betas[0] < 1.0: 40 | raise ValueError("Invalid beta parameter at index 0: {} - should be in [0.0, 1.0[".format(betas[0])) 41 | if not 0.0 <= betas[1] < 1.0: 42 | raise ValueError("Invalid beta parameter at index 1: {} - should be in [0.0, 1.0[".format(betas[1])) 43 | if not e >= 0.0: 44 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 45 | # initialize schedule object 46 | if not isinstance(schedule, _LRSchedule): 47 | schedule_type = SCHEDULES[schedule] 48 | schedule = schedule_type(warmup=warmup, t_total=t_total) 49 | else: 50 | if warmup != -1 or t_total != -1: 51 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. 
" 52 | "Please specify custom warmup and t_total in _LRSchedule object.") 53 | defaults = dict(lr=lr, schedule=schedule, 54 | betas=betas, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 55 | max_grad_norm=max_grad_norm) 56 | super(OpenAIAdam, self).__init__(params, defaults) 57 | 58 | def get_lr(self): 59 | lr = [] 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | state = self.state[p] 63 | if len(state) == 0: 64 | return [0] 65 | lr_scheduled = group['lr'] 66 | lr_scheduled *= group['schedule'].get_lr(state['step']) 67 | lr.append(lr_scheduled) 68 | return lr 69 | 70 | def step(self, closure=None): 71 | """Performs a single optimization step. 72 | 73 | Arguments: 74 | closure (callable, optional): A closure that reevaluates the model 75 | and returns the loss. 76 | """ 77 | loss = None 78 | if closure is not None: 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | grad = p.grad.data 86 | if grad.is_sparse: 87 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 88 | 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | # Exponential moving average of gradient values 95 | state['exp_avg'] = torch.zeros_like(p.data) 96 | # Exponential moving average of squared gradient values 97 | state['exp_avg_sq'] = torch.zeros_like(p.data) 98 | 99 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 100 | beta1, beta2 = group['betas'] 101 | 102 | state['step'] += 1 103 | 104 | # Add grad clipping 105 | if group['max_grad_norm'] > 0: 106 | clip_grad_norm_(p, group['max_grad_norm']) 107 | 108 | # Decay the first and second moment running average coefficient 109 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 110 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 111 | denom = exp_avg_sq.sqrt().add_(group['e']) 112 | 113 | bias_correction1 = 1 - beta1 ** state['step'] 114 | bias_correction2 = 1 - beta2 ** state['step'] 115 | 116 | lr_scheduled = group['lr'] 117 | lr_scheduled *= group['schedule'].get_lr(state['step']) 118 | 119 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 120 | 121 | p.data.addcdiv_(-step_size, exp_avg, denom) 122 | 123 | # Add weight decay at the end (fixed version) 124 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 125 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 126 | 127 | return loss 128 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization_chinese_word.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | 25 | def convert_to_unicode(text): 26 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 27 | if six.PY3: 28 | if isinstance(text, str): 29 | return text 30 | elif isinstance(text, bytes): 31 | return text.decode("utf-8", "ignore") 32 | else: 33 | raise ValueError("Unsupported string type: %s" % (type(text))) 34 | elif six.PY2: 35 | if isinstance(text, str): 36 | return text.decode("utf-8", "ignore") 37 | elif isinstance(text, unicode): 38 | return text 39 | else: 40 | raise ValueError("Unsupported string type: %s" % (type(text))) 41 | else: 42 | raise ValueError("Not running on Python2 or Python 3?") 43 | 44 | 45 | def printable_text(text): 46 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 47 | 48 | # These functions want `str` for both Python2 and Python3, but in one case 49 | # it's a Unicode string and in the other it's a byte string. 50 | if six.PY3: 51 | if isinstance(text, str): 52 | return text 53 | elif isinstance(text, bytes): 54 | return text.decode("utf-8", "ignore") 55 | else: 56 | raise ValueError("Unsupported string type: %s" % (type(text))) 57 | elif six.PY2: 58 | if isinstance(text, str): 59 | return text 60 | elif isinstance(text, unicode): 61 | return text.encode("utf-8") 62 | else: 63 | raise ValueError("Unsupported string type: %s" % (type(text))) 64 | else: 65 | raise ValueError("Not running on Python2 or Python 3?") 66 | 67 | 68 | def load_vocab(vocab_file): 69 | """Loads a vocabulary file into a dictionary.""" 70 | vocab = collections.OrderedDict() 71 | 72 | index_vocab = collections.OrderedDict() 73 | index = 0 74 | with open(vocab_file, "rb") as reader: 75 | while True: 76 | tmp = reader.readline() 77 | token = convert_to_unicode(tmp) 78 | 79 | 80 | if not token: 81 | break 82 | 83 | #file_out.write("%d\t%s\n" %(index,token)) 84 | token = token.strip() 85 | vocab[token] = index 86 | index_vocab[index]=token 87 | index += 1 88 | 89 | 90 | return vocab,index_vocab 91 | 92 | 93 | def convert_tokens_to_ids(vocab, tokens): 94 | """Converts a sequence of tokens into ids using the vocab.""" 95 | ids = [] 96 | for token in tokens: 97 | ids.append(vocab[token]) 98 | return ids 99 | 100 | 101 | def whitespace_tokenize(text): 102 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 103 | text = text.strip() 104 | if not text: 105 | return [] 106 | tokens = text.split() 107 | return tokens 108 | 109 | 110 | class FullTokenizer(object): 111 | """Runs end-to-end tokenziation.""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | self.vocab,self.index_vocab = load_vocab(vocab_file) 115 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 116 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 117 | 118 | def tokenize(self, text): 119 | split_tokens = [] 120 | for token in self.basic_tokenizer.tokenize(text): 121 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 122 | split_tokens.append(sub_token) 123 | 124 | return split_tokens 125 | 126 | def convert_tokens_to_ids(self, tokens): 127 | return convert_tokens_to_ids(self.vocab, tokens) 128 | 129 | 130 | class BasicTokenizer(object): 131 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 132 | 133 | def __init__(self, do_lower_case=True): 
134 | """Constructs a BasicTokenizer. 135 | 136 | Args: 137 | do_lower_case: Whether to lower case the input. 138 | """ 139 | self.do_lower_case = do_lower_case 140 | 141 | def tokenize(self, text): 142 | """Tokenizes a piece of text.""" 143 | text = convert_to_unicode(text) 144 | text = self._clean_text(text) 145 | # This was added on November 1st, 2018 for the multilingual and Chinese 146 | # models. This is also applied to the English models now, but it doesn't 147 | # matter since the English models were not trained on any Chinese data 148 | # and generally don't have any Chinese data in them (there are Chinese 149 | # characters in the vocabulary because Wikipedia does have some Chinese 150 | # words in the English Wikipedia.). 151 | text = self._tokenize_chinese_chars(text) 152 | orig_tokens = whitespace_tokenize(text) 153 | split_tokens = [] 154 | for token in orig_tokens: 155 | if self.do_lower_case: 156 | token = token.lower() 157 | token = self._run_strip_accents(token) 158 | split_tokens.extend(self._run_split_on_punc(token)) 159 | 160 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 161 | return output_tokens 162 | 163 | def _run_strip_accents(self, text): 164 | """Strips accents from a piece of text.""" 165 | text = unicodedata.normalize("NFD", text) 166 | output = [] 167 | for char in text: 168 | cat = unicodedata.category(char) 169 | if cat == "Mn": 170 | continue 171 | output.append(char) 172 | return "".join(output) 173 | 174 | def _run_split_on_punc(self, text): 175 | """Splits punctuation on a piece of text.""" 176 | chars = list(text) 177 | i = 0 178 | start_new_word = True 179 | output = [] 180 | while i < len(chars): 181 | char = chars[i] 182 | if _is_punctuation(char): 183 | output.append([char]) 184 | start_new_word = True 185 | else: 186 | if start_new_word: 187 | output.append([]) 188 | start_new_word = False 189 | output[-1].append(char) 190 | i += 1 191 | 192 | return ["".join(x) for x in output] 193 | 194 | def _tokenize_chinese_chars(self, text): 195 | """Adds whitespace around any CJK character.""" 196 | output = [] 197 | for char in text: 198 | cp = ord(char) 199 | if self._is_chinese_char(cp): 200 | output.append(" ") 201 | output.append(char) 202 | output.append(" ") 203 | else: 204 | output.append(char) 205 | return "".join(output) 206 | 207 | def _is_chinese_char(self, cp): 208 | """Checks whether CP is the codepoint of a CJK character.""" 209 | # This defines a "chinese character" as anything in the CJK Unicode block: 210 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 211 | # 212 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 213 | # despite its name. The modern Korean Hangul alphabet is a different block, 214 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 215 | # space-separated words, so they are not treated specially and handled 216 | # like the all of the other languages. 
217 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 218 | (cp >= 0x3400 and cp <= 0x4DBF) or # 219 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 220 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 221 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 222 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 223 | (cp >= 0xF900 and cp <= 0xFAFF) or # 224 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 225 | return True 226 | 227 | return False 228 | 229 | def _clean_text(self, text): 230 | """Performs invalid character removal and whitespace cleanup on text.""" 231 | output = [] 232 | for char in text: 233 | cp = ord(char) 234 | if cp == 0 or cp == 0xfffd or _is_control(char): 235 | continue 236 | if _is_whitespace(char): 237 | output.append(" ") 238 | else: 239 | output.append(char) 240 | return "".join(output) 241 | 242 | 243 | class WordpieceTokenizer(object): 244 | """Runs WordPiece tokenization.""" 245 | 246 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 247 | self.vocab = vocab 248 | self.unk_token = unk_token 249 | self.max_input_chars_per_word = max_input_chars_per_word 250 | 251 | def tokenize(self, text): 252 | """Tokenizes a piece of text into its word pieces. 253 | 254 | This uses a greedy longest-match-first algorithm to perform tokenization 255 | using the given vocabulary. 256 | 257 | For example: 258 | input = "unaffable" 259 | output = ["un", "##aff", "##able"] 260 | 261 | Args: 262 | text: A single token or whitespace separated tokens. This should have 263 | already been passed through `BasicTokenizer. 264 | 265 | Returns: 266 | A list of wordpiece tokens. 267 | """ 268 | 269 | text = convert_to_unicode(text) 270 | 271 | output_tokens = [] 272 | for token in whitespace_tokenize(text): 273 | chars = list(token) 274 | if len(chars) > self.max_input_chars_per_word: 275 | output_tokens.append(self.unk_token) 276 | continue 277 | 278 | is_bad = False 279 | start = 0 280 | sub_tokens = [] 281 | while start < len(chars): 282 | end = len(chars) 283 | cur_substr = None 284 | while start < end: 285 | substr = "".join(chars[start:end]) 286 | if start > 0: 287 | substr = "##" + substr 288 | if substr in self.vocab: 289 | cur_substr = substr 290 | break 291 | end -= 1 292 | if cur_substr is None: 293 | is_bad = True 294 | break 295 | sub_tokens.append(cur_substr) 296 | start = end 297 | 298 | if is_bad: 299 | output_tokens.append(self.unk_token) 300 | else: 301 | output_tokens.extend(sub_tokens) 302 | return output_tokens 303 | 304 | 305 | def _is_whitespace(char): 306 | """Checks whether `chars` is a whitespace character.""" 307 | # \t, \n, and \r are technically contorl characters but we treat them 308 | # as whitespace since they are generally considered as such. 309 | if char == " " or char == "\t" or char == "\n" or char == "\r": 310 | return True 311 | cat = unicodedata.category(char) 312 | if cat == "Zs": 313 | return True 314 | return False 315 | 316 | 317 | def _is_control(char): 318 | """Checks whether `chars` is a control character.""" 319 | # These are technically control characters but we count them as whitespace 320 | # characters. 321 | if char == "\t" or char == "\n" or char == "\r": 322 | return False 323 | cat = unicodedata.category(char) 324 | if cat.startswith("C"): 325 | return True 326 | return False 327 | 328 | 329 | def _is_punctuation(char): 330 | """Checks whether `chars` is a punctuation character.""" 331 | cp = ord(char) 332 | # We treat all non-letter/number ASCII as punctuation. 
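    # (The four ASCII ranges tested below, 33-47, 58-64, 91-96 and 123-126, cover
    # every printable ASCII character that is neither a letter, a digit, nor a space.)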
333 | # Characters such as "^", "$", and "`" are not in the Unicode 334 | # Punctuation class but we treat them as punctuation anyways, for 335 | # consistency. 336 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 337 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 338 | return True 339 | cat = unicodedata.category(char) 340 | if cat.startswith("P"): 341 | return True 342 | return False 343 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | try: 27 | from functools import lru_cache 28 | except ImportError: 29 | # Just a dummy decorator to get the checks to run on python2 30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 31 | def lru_cache(): 32 | return lambda func: func 33 | 34 | from .file_utils import cached_path 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 39 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 40 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", 41 | } 42 | PRETRAINED_MERGES_ARCHIVE_MAP = { 43 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 44 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", 45 | } 46 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 47 | 'gpt2': 1024, 48 | } 49 | VOCAB_NAME = 'vocab.json' 50 | MERGES_NAME = 'merges.txt' 51 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 52 | 53 | @lru_cache() 54 | def bytes_to_unicode(): 55 | """ 56 | Returns list of utf-8 byte and a corresponding list of unicode strings. 57 | The reversible bpe codes work on unicode strings. 58 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 59 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 60 | This is a signficant percentage of your normal, say, 32K bpe vocab. 61 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 62 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
63 | """ 64 | _chr = unichr if sys.version_info[0] == 2 else chr 65 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 66 | cs = bs[:] 67 | n = 0 68 | for b in range(2**8): 69 | if b not in bs: 70 | bs.append(b) 71 | cs.append(2**8+n) 72 | n += 1 73 | cs = [_chr(n) for n in cs] 74 | return dict(zip(bs, cs)) 75 | 76 | def get_pairs(word): 77 | """Return set of symbol pairs in a word. 78 | 79 | Word is represented as tuple of symbols (symbols being variable-length strings). 80 | """ 81 | pairs = set() 82 | prev_char = word[0] 83 | for char in word[1:]: 84 | pairs.add((prev_char, char)) 85 | prev_char = char 86 | return pairs 87 | 88 | class GPT2Tokenizer(object): 89 | """ 90 | GPT-2 BPE tokenizer. Peculiarities: 91 | - Byte-level BPE 92 | """ 93 | @classmethod 94 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 95 | """ 96 | Instantiate a GPT2Tokenizer from a pre-trained model file. 97 | Download and cache the pre-trained model file if needed. 98 | """ 99 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 100 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 101 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 102 | special_tokens_file = None 103 | else: 104 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 105 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 106 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 107 | if not os.path.exists(special_tokens_file): 108 | special_tokens_file = None 109 | else: 110 | logger.info("loading special tokens file {}".format(special_tokens_file)) 111 | # redirect to the cache, if necessary 112 | try: 113 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 114 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 115 | except EnvironmentError: 116 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 117 | logger.error( 118 | "Couldn't reach server at '{}' to download vocabulary.".format( 119 | vocab_file)) 120 | else: 121 | logger.error( 122 | "Model name '{}' was not found in model name list ({}). " 123 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 124 | "at this path or url.".format( 125 | pretrained_model_name_or_path, 126 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 127 | pretrained_model_name_or_path, 128 | vocab_file, merges_file)) 129 | return None 130 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 131 | logger.info("loading vocabulary file {}".format(vocab_file)) 132 | logger.info("loading merges file {}".format(merges_file)) 133 | else: 134 | logger.info("loading vocabulary file {} from cache at {}".format( 135 | vocab_file, resolved_vocab_file)) 136 | logger.info("loading merges file {} from cache at {}".format( 137 | merges_file, resolved_merges_file)) 138 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 139 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 140 | # than the number of positional embeddings 141 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 142 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 143 | # Instantiate tokenizer. 
144 | if special_tokens_file and 'special_tokens' not in kwargs: 145 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 146 | else: 147 | special_tokens = kwargs.pop('special_tokens', []) 148 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 149 | return tokenizer 150 | 151 | def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): 152 | self.max_len = max_len if max_len is not None else int(1e12) 153 | self.encoder = json.load(open(vocab_file)) 154 | self.decoder = {v:k for k,v in self.encoder.items()} 155 | self.errors = errors # how to handle errors in decoding 156 | self.byte_encoder = bytes_to_unicode() 157 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 158 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 159 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 160 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 161 | self.cache = {} 162 | 163 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 164 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 165 | 166 | self.special_tokens = {} 167 | self.special_tokens_decoder = {} 168 | self.set_special_tokens(special_tokens) 169 | 170 | def __len__(self): 171 | return len(self.encoder) + len(self.special_tokens) 172 | 173 | def set_special_tokens(self, special_tokens): 174 | """ Add a list of additional tokens to the encoder. 175 | The additional tokens are indexed starting from the last index of the 176 | current vocabulary in the order of the `special_tokens` list. 177 | """ 178 | if not special_tokens: 179 | self.special_tokens = {} 180 | self.special_tokens_decoder = {} 181 | return 182 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 183 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 184 | logger.info("Special tokens {}".format(self.special_tokens)) 185 | 186 | def bpe(self, token): 187 | if token in self.cache: 188 | return self.cache[token] 189 | word = tuple(token) 190 | pairs = get_pairs(word) 191 | 192 | if not pairs: 193 | return token 194 | 195 | while True: 196 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 197 | if bigram not in self.bpe_ranks: 198 | break 199 | first, second = bigram 200 | new_word = [] 201 | i = 0 202 | while i < len(word): 203 | try: 204 | j = word.index(first, i) 205 | new_word.extend(word[i:j]) 206 | i = j 207 | except: 208 | new_word.extend(word[i:]) 209 | break 210 | 211 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 212 | new_word.append(first+second) 213 | i += 2 214 | else: 215 | new_word.append(word[i]) 216 | i += 1 217 | new_word = tuple(new_word) 218 | word = new_word 219 | if len(word) == 1: 220 | break 221 | else: 222 | pairs = get_pairs(word) 223 | word = ' '.join(word) 224 | self.cache[token] = word 225 | return word 226 | 227 | def tokenize(self, text): 228 | """ Tokenize a string. 
""" 229 | bpe_tokens = [] 230 | for token in re.findall(self.pat, text): 231 | if sys.version_info[0] == 2: 232 | token = ''.join(self.byte_encoder[ord(b)] for b in token) 233 | else: 234 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 235 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 236 | return bpe_tokens 237 | 238 | def convert_tokens_to_ids(self, tokens): 239 | """ Converts a sequence of tokens into ids using the vocab. """ 240 | ids = [] 241 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 242 | if tokens in self.special_tokens: 243 | return self.special_tokens[tokens] 244 | else: 245 | return self.encoder.get(tokens, 0) 246 | for token in tokens: 247 | if token in self.special_tokens: 248 | ids.append(self.special_tokens[token]) 249 | else: 250 | ids.append(self.encoder.get(token, 0)) 251 | if len(ids) > self.max_len: 252 | logger.warning( 253 | "Token indices sequence length is longer than the specified maximum " 254 | " sequence length for this OpenAI GPT model ({} > {}). Running this" 255 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 256 | ) 257 | return ids 258 | 259 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 260 | """Converts a sequence of ids in BPE tokens using the vocab.""" 261 | tokens = [] 262 | for i in ids: 263 | if i in self.special_tokens_decoder: 264 | if not skip_special_tokens: 265 | tokens.append(self.special_tokens_decoder[i]) 266 | else: 267 | tokens.append(self.decoder[i]) 268 | return tokens 269 | 270 | def encode(self, text): 271 | return self.convert_tokens_to_ids(self.tokenize(text)) 272 | 273 | def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True): 274 | text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens)) 275 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 276 | if clean_up_tokenization_spaces: 277 | text = text.replace('', '') 278 | text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' 279 | ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" 280 | ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") 281 | return text 282 | 283 | def save_vocabulary(self, vocab_path): 284 | """Save the tokenizer vocabulary and merge files to a directory.""" 285 | if not os.path.isdir(vocab_path): 286 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 287 | return 288 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 289 | merge_file = os.path.join(vocab_path, MERGES_NAME) 290 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 291 | 292 | with open(vocab_file, 'w', encoding='utf-8') as f: 293 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 294 | 295 | index = 0 296 | with open(merge_file, "w", encoding="utf-8") as writer: 297 | writer.write(u'#version: 0.2\n') 298 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 299 | if index != token_index: 300 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 
301 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 302 | index = token_index 303 | writer.write(' '.join(bpe_tokens) + u'\n') 304 | index += 1 305 | 306 | index = len(self.encoder) 307 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 308 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 309 | if index != token_index: 310 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 311 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 312 | index = token_index 313 | writer.write(token + u'\n') 314 | index += 1 315 | 316 | return vocab_file, merge_file, special_tokens_file 317 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | import sys 24 | from io import open 25 | 26 | from tqdm import tqdm 27 | 28 | from .file_utils import cached_path 29 | from .tokenization import BasicTokenizer 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 35 | } 36 | PRETRAINED_MERGES_ARCHIVE_MAP = { 37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'openai-gpt': 512, 41 | } 42 | VOCAB_NAME = 'vocab.json' 43 | MERGES_NAME = 'merges.txt' 44 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 45 | 46 | def get_pairs(word): 47 | """ 48 | Return set of symbol pairs in a word. 49 | word is represented as tuple of symbols (symbols being variable-length strings) 50 | """ 51 | pairs = set() 52 | prev_char = word[0] 53 | for char in word[1:]: 54 | pairs.add((prev_char, char)) 55 | prev_char = char 56 | return pairs 57 | 58 | def text_standardize(text): 59 | """ 60 | fixes some issues the spacy tokenizer had on books corpus 61 | also does some whitespace standardization 62 | """ 63 | text = text.replace('—', '-') 64 | text = text.replace('–', '-') 65 | text = text.replace('―', '-') 66 | text = text.replace('…', '...') 67 | text = text.replace('´', "'") 68 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 69 | text = re.sub(r'\s*\n\s*', ' \n ', text) 70 | text = re.sub(r'[^\S\n]+', ' ', text) 71 | return text.strip() 72 | 73 | class OpenAIGPTTokenizer(object): 74 | """ 75 | BPE tokenizer. 
Peculiarities: 76 | - lower case all inputs 77 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 78 | - argument special_tokens and function set_special_tokens: 79 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 80 | """ 81 | @classmethod 82 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 83 | """ 84 | Instantiate a PreTrainedBertModel from a pre-trained model file. 85 | Download and cache the pre-trained model file if needed. 86 | """ 87 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 88 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 90 | special_tokens_file = None 91 | else: 92 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 93 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 94 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 95 | if not os.path.exists(special_tokens_file): 96 | special_tokens_file = None 97 | else: 98 | logger.info("loading special tokens file {}".format(special_tokens_file)) 99 | # redirect to the cache, if necessary 100 | try: 101 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 102 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 103 | except EnvironmentError: 104 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 105 | logger.error( 106 | "Couldn't reach server at '{}' to download vocabulary.".format( 107 | vocab_file)) 108 | else: 109 | logger.error( 110 | "Model name '{}' was not found in model name list ({}). " 111 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 112 | "at this path or url.".format( 113 | pretrained_model_name_or_path, 114 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 115 | pretrained_model_name_or_path, 116 | vocab_file, merges_file)) 117 | return None 118 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 119 | logger.info("loading vocabulary file {}".format(vocab_file)) 120 | logger.info("loading merges file {}".format(merges_file)) 121 | else: 122 | logger.info("loading vocabulary file {} from cache at {}".format( 123 | vocab_file, resolved_vocab_file)) 124 | logger.info("loading merges file {} from cache at {}".format( 125 | merges_file, resolved_merges_file)) 126 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 127 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 128 | # than the number of positional embeddings 129 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 130 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 131 | # Instantiate tokenizer. 
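        # As with the GPT-2 tokenizer, entries of special_tokens.txt (one per line) are
        # registered as extra vocabulary items unless `special_tokens` was passed explicitly.
        # Note that, unlike GPT-2, this tokenizer lower-cases its input and falls back to
        # BERT's BasicTokenizer when ftfy/spacy are not installed (see __init__ below).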
132 | if special_tokens_file and 'special_tokens' not in kwargs: 133 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 134 | else: 135 | special_tokens = kwargs.pop('special_tokens', []) 136 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 137 | return tokenizer 138 | 139 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): 140 | try: 141 | import ftfy 142 | import spacy 143 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 144 | self.fix_text = ftfy.fix_text 145 | except ImportError: 146 | logger.warning("ftfy or spacy is not installed, using BERT BasicTokenizer instead of SpaCy & ftfy.") 147 | self.nlp = BasicTokenizer(do_lower_case=True, 148 | never_split=special_tokens if special_tokens is not None else []) 149 | self.fix_text = None 150 | 151 | self.max_len = max_len if max_len is not None else int(1e12) 152 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 153 | self.decoder = {v:k for k,v in self.encoder.items()} 154 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 155 | merges = [tuple(merge.split()) for merge in merges] 156 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 157 | self.cache = {} 158 | self.special_tokens = {} 159 | self.special_tokens_decoder = {} 160 | self.set_special_tokens(special_tokens) 161 | 162 | def __len__(self): 163 | return len(self.encoder) + len(self.special_tokens) 164 | 165 | def set_special_tokens(self, special_tokens): 166 | """ Add a list of additional tokens to the encoder. 167 | The additional tokens are indexed starting from the last index of the 168 | current vocabulary in the order of the `special_tokens` list. 169 | """ 170 | if not special_tokens: 171 | self.special_tokens = {} 172 | self.special_tokens_decoder = {} 173 | return 174 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 175 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 176 | if self.fix_text is None: 177 | # Using BERT's BasicTokenizer: we can update the tokenizer 178 | self.nlp.never_split = special_tokens 179 | logger.info("Special tokens {}".format(self.special_tokens)) 180 | 181 | def bpe(self, token): 182 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 183 | if token in self.cache: 184 | return self.cache[token] 185 | pairs = get_pairs(word) 186 | 187 | if not pairs: 188 | return token+'</w>' 189 | 190 | while True: 191 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 192 | if bigram not in self.bpe_ranks: 193 | break 194 | first, second = bigram 195 | new_word = [] 196 | i = 0 197 | while i < len(word): 198 | try: 199 | j = word.index(first, i) 200 | new_word.extend(word[i:j]) 201 | i = j 202 | except: 203 | new_word.extend(word[i:]) 204 | break 205 | 206 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 207 | new_word.append(first+second) 208 | i += 2 209 | else: 210 | new_word.append(word[i]) 211 | i += 1 212 | new_word = tuple(new_word) 213 | word = new_word 214 | if len(word) == 1: 215 | break 216 | else: 217 | pairs = get_pairs(word) 218 | word = ' '.join(word) 219 | if word == '\n  </w>': 220 | word = '\n</w>' 221 | self.cache[token] = word 222 | return word 223 | 224 | def tokenize(self, text): 225 | """ Tokenize a string. 
""" 226 | split_tokens = [] 227 | if self.fix_text is None: 228 | # Using BERT's BasicTokenizer 229 | text = self.nlp.tokenize(text) 230 | for token in text: 231 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 232 | else: 233 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 234 | text = self.nlp(text_standardize(self.fix_text(text))) 235 | for token in text: 236 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 237 | return split_tokens 238 | 239 | def convert_tokens_to_ids(self, tokens): 240 | """ Converts a sequence of tokens into ids using the vocab. """ 241 | ids = [] 242 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 243 | if tokens in self.special_tokens: 244 | return self.special_tokens[tokens] 245 | else: 246 | return self.encoder.get(tokens, 0) 247 | for token in tokens: 248 | if token in self.special_tokens: 249 | ids.append(self.special_tokens[token]) 250 | else: 251 | ids.append(self.encoder.get(token, 0)) 252 | if len(ids) > self.max_len: 253 | logger.warning( 254 | "Token indices sequence length is longer than the specified maximum " 255 | " sequence length for this OpenAI GPT model ({} > {}). Running this" 256 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 257 | ) 258 | return ids 259 | 260 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 261 | """Converts a sequence of ids in BPE tokens using the vocab.""" 262 | tokens = [] 263 | for i in ids: 264 | if i in self.special_tokens_decoder: 265 | if not skip_special_tokens: 266 | tokens.append(self.special_tokens_decoder[i]) 267 | else: 268 | tokens.append(self.decoder[i]) 269 | return tokens 270 | 271 | def encode(self, text): 272 | return self.convert_tokens_to_ids(self.tokenize(text)) 273 | 274 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): 275 | """Converts a sequence of ids in a string.""" 276 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) 277 | out_string = ''.join(tokens).replace('', ' ').strip() 278 | if clean_up_tokenization_spaces: 279 | out_string = out_string.replace('', '') 280 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' 281 | ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" 282 | ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") 283 | return out_string 284 | 285 | def save_vocabulary(self, vocab_path): 286 | """Save the tokenizer vocabulary and merge files to a directory.""" 287 | if not os.path.isdir(vocab_path): 288 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 289 | return 290 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 291 | merge_file = os.path.join(vocab_path, MERGES_NAME) 292 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 293 | 294 | with open(vocab_file, 'w', encoding='utf-8') as f: 295 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 296 | 297 | index = 0 298 | with open(merge_file, "w", encoding="utf-8") as writer: 299 | writer.write(u'#version: 0.2\n') 300 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 301 | if index != token_index: 302 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 
303 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 304 | index = token_index 305 | writer.write(' '.join(bpe_tokens) + u'\n') 306 | index += 1 307 | 308 | index = len(self.encoder) 309 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 310 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 311 | if index != token_index: 312 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 313 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 314 | index = token_index 315 | writer.write(token + u'\n') 316 | index += 1 317 | 318 | return vocab_file, merge_file, special_tokens_file 319 | -------------------------------------------------------------------------------- /pytorch_pretrained_zen/__init__.py: -------------------------------------------------------------------------------- 1 | version = "0.1.0" 2 | 3 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 4 | from .optimization import BertAdam, WarmupLinearSchedule 5 | from .modeling import ZenConfig, ZenForPreTraining, ZenForTokenClassification, ZenForSequenceClassification 6 | from .file_utils import WEIGHTS_NAME, CONFIG_NAME, PYTORCH_PRETRAINED_BERT_CACHE 7 | from .ngram_utils import ZenNgramDict, NGRAM_DICT_NAME 8 | 9 | -------------------------------------------------------------------------------- /pytorch_pretrained_zen/file_utils.py: -------------------------------------------------------------------------------- 1 | # This file is derived from the code at 2 | # https://github.com/huggingface/transformers/blob/master/transformers/file_utils.py 3 | # and the code at 4 | # https://github.com/allenai/allennlp/blob/master/allennlp/common/file_utils.py. 5 | # 6 | # Original copyright notice: 7 | # 8 | # This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 9 | # Copyright by the AllenNLP authors. 
10 | """Utilities for working with the local dataset cache.""" 11 | 12 | from __future__ import (absolute_import, division, print_function, unicode_literals) 13 | 14 | import sys 15 | import json 16 | import logging 17 | import os 18 | import shutil 19 | import tempfile 20 | import fnmatch 21 | from functools import wraps 22 | from hashlib import sha256 23 | import sys 24 | from io import open 25 | 26 | import boto3 27 | import requests 28 | from botocore.exceptions import ClientError 29 | from tqdm import tqdm 30 | 31 | try: 32 | from torch.hub import _get_torch_home 33 | 34 | torch_cache_home = _get_torch_home() 35 | except ImportError: 36 | torch_cache_home = os.path.expanduser( 37 | os.getenv('TORCH_HOME', os.path.join( 38 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 39 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert') 40 | 41 | try: 42 | from urllib.parse import urlparse 43 | except ImportError: 44 | from urlparse import urlparse 45 | 46 | try: 47 | from pathlib import Path 48 | 49 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 50 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) 51 | except (AttributeError, ImportError): 52 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 53 | default_cache_path) 54 | 55 | CONFIG_NAME = "config.json" 56 | WEIGHTS_NAME = "pytorch_model.bin" 57 | 58 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 59 | 60 | 61 | def url_to_filename(url, etag=None): 62 | """ 63 | Convert `url` into a hashed filename in a repeatable way. 64 | If `etag` is specified, append its hash to the url's, delimited 65 | by a period. 66 | """ 67 | url_bytes = url.encode('utf-8') 68 | url_hash = sha256(url_bytes) 69 | filename = url_hash.hexdigest() 70 | 71 | if etag: 72 | etag_bytes = etag.encode('utf-8') 73 | etag_hash = sha256(etag_bytes) 74 | filename += '.' + etag_hash.hexdigest() 75 | 76 | return filename 77 | 78 | 79 | def filename_to_url(filename, cache_dir=None): 80 | """ 81 | Return the url and etag (which may be ``None``) stored for `filename`. 82 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 83 | """ 84 | if cache_dir is None: 85 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 86 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 87 | cache_dir = str(cache_dir) 88 | 89 | cache_path = os.path.join(cache_dir, filename) 90 | if not os.path.exists(cache_path): 91 | raise EnvironmentError("file {} not found".format(cache_path)) 92 | 93 | meta_path = cache_path + '.json' 94 | if not os.path.exists(meta_path): 95 | raise EnvironmentError("file {} not found".format(meta_path)) 96 | 97 | with open(meta_path, encoding="utf-8") as meta_file: 98 | metadata = json.load(meta_file) 99 | url = metadata['url'] 100 | etag = metadata['etag'] 101 | 102 | return url, etag 103 | 104 | 105 | def cached_path(url_or_filename, cache_dir=None): 106 | """ 107 | Given something that might be a URL (or might be a local path), 108 | determine which. If it's a URL, download the file and cache it, and 109 | return the path to the cached file. If it's already a local path, 110 | make sure the file exists and then return the path. 
111 | """ 112 | if cache_dir is None: 113 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 114 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 115 | url_or_filename = str(url_or_filename) 116 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 117 | cache_dir = str(cache_dir) 118 | 119 | parsed = urlparse(url_or_filename) 120 | 121 | if parsed.scheme in ('http', 'https', 's3'): 122 | # URL, so get it from the cache (downloading if necessary) 123 | return get_from_cache(url_or_filename, cache_dir) 124 | elif os.path.exists(url_or_filename): 125 | # File, and it exists. 126 | return url_or_filename 127 | elif parsed.scheme == '': 128 | # File, but it doesn't exist. 129 | raise EnvironmentError("file {} not found".format(url_or_filename)) 130 | else: 131 | # Something unknown 132 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 133 | 134 | 135 | def split_s3_path(url): 136 | """Split a full s3 path into the bucket name and path.""" 137 | parsed = urlparse(url) 138 | if not parsed.netloc or not parsed.path: 139 | raise ValueError("bad s3 path {}".format(url)) 140 | bucket_name = parsed.netloc 141 | s3_path = parsed.path 142 | # Remove '/' at beginning of path. 143 | if s3_path.startswith("/"): 144 | s3_path = s3_path[1:] 145 | return bucket_name, s3_path 146 | 147 | 148 | def s3_request(func): 149 | """ 150 | Wrapper function for s3 requests in order to create more helpful error 151 | messages. 152 | """ 153 | 154 | @wraps(func) 155 | def wrapper(url, *args, **kwargs): 156 | try: 157 | return func(url, *args, **kwargs) 158 | except ClientError as exc: 159 | if int(exc.response["Error"]["Code"]) == 404: 160 | raise EnvironmentError("file {} not found".format(url)) 161 | else: 162 | raise 163 | 164 | return wrapper 165 | 166 | 167 | @s3_request 168 | def s3_etag(url): 169 | """Check ETag on S3 object.""" 170 | s3_resource = boto3.resource("s3") 171 | bucket_name, s3_path = split_s3_path(url) 172 | s3_object = s3_resource.Object(bucket_name, s3_path) 173 | return s3_object.e_tag 174 | 175 | 176 | @s3_request 177 | def s3_get(url, temp_file): 178 | """Pull a file directly from S3.""" 179 | s3_resource = boto3.resource("s3") 180 | bucket_name, s3_path = split_s3_path(url) 181 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 182 | 183 | 184 | def http_get(url, temp_file): 185 | req = requests.get(url, stream=True) 186 | content_length = req.headers.get('Content-Length') 187 | total = int(content_length) if content_length is not None else None 188 | progress = tqdm(unit="B", total=total) 189 | for chunk in req.iter_content(chunk_size=1024): 190 | if chunk: # filter out keep-alive new chunks 191 | progress.update(len(chunk)) 192 | temp_file.write(chunk) 193 | progress.close() 194 | 195 | 196 | def get_from_cache(url, cache_dir=None): 197 | """ 198 | Given a URL, look for the corresponding dataset in the local cache. 199 | If it's not there, download it. Then return the path to the cached file. 200 | """ 201 | if cache_dir is None: 202 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 203 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 204 | cache_dir = str(cache_dir) 205 | 206 | if not os.path.exists(cache_dir): 207 | os.makedirs(cache_dir) 208 | 209 | # Get eTag to add to filename, if it exists. 
210 | if url.startswith("s3://"): 211 | etag = s3_etag(url) 212 | else: 213 | try: 214 | response = requests.head(url, allow_redirects=True) 215 | if response.status_code != 200: 216 | etag = None 217 | else: 218 | etag = response.headers.get("ETag") 219 | except EnvironmentError: 220 | etag = None 221 | 222 | if sys.version_info[0] == 2 and etag is not None: 223 | etag = etag.decode('utf-8') 224 | filename = url_to_filename(url, etag) 225 | 226 | # get cache path to put the file 227 | cache_path = os.path.join(cache_dir, filename) 228 | 229 | # If we don't have a connection (etag is None) and can't identify the file 230 | # try to get the last downloaded one 231 | if not os.path.exists(cache_path) and etag is None: 232 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 233 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 234 | if matching_files: 235 | cache_path = os.path.join(cache_dir, matching_files[-1]) 236 | 237 | if not os.path.exists(cache_path): 238 | # Download to temporary file, then copy to cache dir once finished. 239 | # Otherwise you get corrupt cache entries if the download gets interrupted. 240 | with tempfile.NamedTemporaryFile() as temp_file: 241 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 242 | 243 | # GET file object 244 | if url.startswith("s3://"): 245 | s3_get(url, temp_file) 246 | else: 247 | http_get(url, temp_file) 248 | 249 | # we are copying the file before closing it, so flush to avoid truncation 250 | temp_file.flush() 251 | # shutil.copyfileobj() starts at the current position, so go to the start 252 | temp_file.seek(0) 253 | 254 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 255 | with open(cache_path, 'wb') as cache_file: 256 | shutil.copyfileobj(temp_file, cache_file) 257 | 258 | logger.info("creating metadata file for %s", cache_path) 259 | meta = {'url': url, 'etag': etag} 260 | meta_path = cache_path + '.json' 261 | with open(meta_path, 'w') as meta_file: 262 | output_string = json.dumps(meta) 263 | if sys.version_info[0] == 2 and isinstance(output_string, str): 264 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 265 | meta_file.write(output_string) 266 | 267 | logger.info("removing temp file %s", temp_file.name) 268 | 269 | return cache_path 270 | 271 | 272 | def read_set_from_file(filename): 273 | ''' 274 | Extract a de-duped collection (set) of text from a file. 275 | Expected file format is one item per line. 276 | ''' 277 | collection = set() 278 | with open(filename, 'r', encoding='utf-8') as file_: 279 | for line in file_: 280 | collection.add(line.rstrip()) 281 | return collection 282 | 283 | 284 | def get_file_extension(path, dot=True, lower=True): 285 | ext = os.path.splitext(path)[1] 286 | ext = ext if dot else ext[1:] 287 | return ext.lower() if lower else ext 288 | -------------------------------------------------------------------------------- /pytorch_pretrained_zen/ngram_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright 2019 Sinovation Ventures AI Institute 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """utils for ngram for ZEN model.""" 16 | 17 | import os 18 | import logging 19 | 20 | NGRAM_DICT_NAME = 'ngram.txt' 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | class ZenNgramDict(object): 25 | """ 26 | Dict class to store the ngram 27 | """ 28 | def __init__(self, ngram_freq_path, tokenizer, max_ngram_in_seq=128): 29 | """Constructs ZenNgramDict 30 | 31 | :param ngram_freq_path: ngrams with frequency 32 | """ 33 | if os.path.isdir(ngram_freq_path): 34 | ngram_freq_path = os.path.join(ngram_freq_path, NGRAM_DICT_NAME) 35 | self.ngram_freq_path = ngram_freq_path 36 | self.max_ngram_in_seq = max_ngram_in_seq 37 | self.id_to_ngram_list = ["[pad]"] 38 | self.ngram_to_id_dict = {"[pad]": 0} 39 | self.ngram_to_freq_dict = {} 40 | 41 | logger.info("loading ngram frequency file {}".format(ngram_freq_path)) 42 | with open(ngram_freq_path, "r", encoding="utf-8") as fin: 43 | for i, line in enumerate(fin): 44 | ngram,freq = line.split(",") 45 | tokens = tuple(tokenizer.tokenize(ngram)) 46 | self.ngram_to_freq_dict[ngram] = freq 47 | self.id_to_ngram_list.append(tokens) 48 | self.ngram_to_id_dict[tokens] = i + 1 49 | 50 | def save(self, ngram_freq_path): 51 | with open(ngram_freq_path, "w", encoding="utf-8") as fout: 52 | for ngram,freq in self.ngram_to_freq_dict.items(): 53 | fout.write("{},{}\n".format(ngram, freq)) -------------------------------------------------------------------------------- /pytorch_pretrained_zen/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file is derived from the code at 3 | # https://github.com/huggingface/transformers/blob/master/transformers/optimization.py 4 | # 5 | # Original copyright notice: 6 | # 7 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """PyTorch optimization for BERT model.""" 21 | 22 | import math 23 | import torch 24 | from torch.optim import Optimizer 25 | from torch.optim.optimizer import required 26 | from torch.nn.utils import clip_grad_norm_ 27 | import logging 28 | import abc 29 | import sys 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | if sys.version_info >= (3, 4): 34 | ABC = abc.ABC 35 | else: 36 | ABC = abc.ABCMeta('ABC', (), {}) 37 | 38 | 39 | class _LRSchedule(ABC): 40 | """ Parent of all LRSchedules here. 
""" 41 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 42 | 43 | def __init__(self, warmup=0.002, t_total=-1, **kw): 44 | """ 45 | :param warmup: what fraction of t_total steps will be used for linear warmup 46 | :param t_total: how many training steps (updates) are planned 47 | :param kw: 48 | """ 49 | super(_LRSchedule, self).__init__(**kw) 50 | if t_total < 0: 51 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 52 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 53 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 54 | warmup = max(warmup, 0.) 55 | self.warmup, self.t_total = float(warmup), float(t_total) 56 | self.warned_for_t_total_at_progress = -1 57 | 58 | def get_lr(self, step, nowarn=False): 59 | """ 60 | :param step: which of t_total steps we're on 61 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 62 | :return: learning rate multiplier for current update 63 | """ 64 | if self.t_total < 0: 65 | return 1. 66 | progress = float(step) / self.t_total 67 | ret = self.get_lr_(progress) 68 | # warning for exceeding t_total (only active with warmup_linear 69 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 70 | logger.warning( 71 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 72 | .format(ret, self.__class__.__name__)) 73 | self.warned_for_t_total_at_progress = progress 74 | # end warning 75 | return ret 76 | 77 | @abc.abstractmethod 78 | def get_lr_(self, progress): 79 | """ 80 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 81 | :return: learning rate multiplier for current update 82 | """ 83 | return 1. 84 | 85 | 86 | class ConstantLR(_LRSchedule): 87 | def get_lr_(self, progress): 88 | return 1. 89 | 90 | 91 | class WarmupCosineSchedule(_LRSchedule): 92 | """ 93 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 94 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 95 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 96 | """ 97 | warn_t_total = True 98 | 99 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 100 | """ 101 | :param warmup: see LRSchedule 102 | :param t_total: see LRSchedule 103 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 104 | :param kw: 105 | """ 106 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 107 | self.cycles = cycles 108 | 109 | def get_lr_(self, progress): 110 | if progress < self.warmup: 111 | return progress / self.warmup 112 | else: 113 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 114 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 115 | 116 | 117 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 118 | """ 119 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 120 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 121 | learning rate (with hard restarts). 
122 | """ 123 | 124 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 125 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 126 | assert (cycles >= 1.) 127 | 128 | def get_lr_(self, progress): 129 | if progress < self.warmup: 130 | return progress / self.warmup 131 | else: 132 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 133 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 134 | return ret 135 | 136 | 137 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 138 | """ 139 | All training progress is divided in `cycles` (default=1.) parts of equal length. 140 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 141 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 142 | """ 143 | 144 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 145 | assert (warmup * cycles < 1.) 146 | warmup = warmup * cycles if warmup >= 0 else warmup 147 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, 148 | **kw) 149 | 150 | def get_lr_(self, progress): 151 | progress = progress * self.cycles % 1. 152 | if progress < self.warmup: 153 | return progress / self.warmup 154 | else: 155 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 156 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 157 | return ret 158 | 159 | 160 | class WarmupConstantSchedule(_LRSchedule): 161 | """ 162 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 163 | Keeps learning rate equal to 1. after warmup. 164 | """ 165 | 166 | def get_lr_(self, progress): 167 | if progress < self.warmup: 168 | return progress / self.warmup 169 | return 1. 170 | 171 | 172 | class WarmupLinearSchedule(_LRSchedule): 173 | """ 174 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 175 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 176 | """ 177 | warn_t_total = True 178 | 179 | def get_lr_(self, progress): 180 | if progress < self.warmup: 181 | return progress / self.warmup 182 | return max((progress - 1.) / (self.warmup - 1.), 0.) 183 | 184 | 185 | SCHEDULES = { 186 | None: ConstantLR, 187 | "none": ConstantLR, 188 | "warmup_cosine": WarmupCosineSchedule, 189 | "warmup_constant": WarmupConstantSchedule, 190 | "warmup_linear": WarmupLinearSchedule 191 | } 192 | 193 | 194 | class BertAdam(Optimizer): 195 | """Implements BERT version of Adam algorithm with weight decay fix. 196 | Params: 197 | lr: learning rate 198 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 199 | t_total: total number of training steps for the learning 200 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 201 | schedule: schedule to use for the warmup (see above). 202 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 203 | If `None` or `'none'`, learning rate is always kept constant. 204 | Default : `'warmup_linear'` 205 | b1: Adams b1. Default: 0.9 206 | b2: Adams b2. Default: 0.999 207 | e: Adams epsilon. Default: 1e-6 208 | weight_decay: Weight decay. Default: 0.01 209 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 210 | """ 211 | 212 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 213 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 214 | if lr is not required and lr < 0.0: 215 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 216 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 217 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 218 | if not 0.0 <= b1 < 1.0: 219 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 220 | if not 0.0 <= b2 < 1.0: 221 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 222 | if not e >= 0.0: 223 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 224 | # initialize schedule object 225 | if not isinstance(schedule, _LRSchedule): 226 | schedule_type = SCHEDULES[schedule] 227 | schedule = schedule_type(warmup=warmup, t_total=t_total) 228 | else: 229 | if warmup != -1 or t_total != -1: 230 | logger.warning( 231 | "warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 232 | "Please specify custom warmup and t_total in _LRSchedule object.") 233 | defaults = dict(lr=lr, schedule=schedule, 234 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 235 | max_grad_norm=max_grad_norm) 236 | super(BertAdam, self).__init__(params, defaults) 237 | 238 | def get_lr(self): 239 | lr = [] 240 | for group in self.param_groups: 241 | for p in group['params']: 242 | state = self.state[p] 243 | if len(state) == 0: 244 | return [0] 245 | lr_scheduled = group['lr'] 246 | lr_scheduled *= group['schedule'].get_lr(state['step']) 247 | lr.append(lr_scheduled) 248 | return lr 249 | 250 | def step(self, closure=None): 251 | """Performs a single optimization step. 252 | 253 | Arguments: 254 | closure (callable, optional): A closure that reevaluates the model 255 | and returns the loss. 256 | """ 257 | loss = None 258 | if closure is not None: 259 | loss = closure() 260 | 261 | for group in self.param_groups: 262 | for p in group['params']: 263 | if p.grad is None: 264 | continue 265 | grad = p.grad.data 266 | if grad.is_sparse: 267 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 268 | 269 | state = self.state[p] 270 | 271 | # State initialization 272 | if len(state) == 0: 273 | state['step'] = 0 274 | # Exponential moving average of gradient values 275 | state['next_m'] = torch.zeros_like(p.data) 276 | # Exponential moving average of squared gradient values 277 | state['next_v'] = torch.zeros_like(p.data) 278 | 279 | next_m, next_v = state['next_m'], state['next_v'] 280 | beta1, beta2 = group['b1'], group['b2'] 281 | 282 | # Add grad clipping 283 | if group['max_grad_norm'] > 0: 284 | clip_grad_norm_(p, group['max_grad_norm']) 285 | 286 | # Decay the first and second moment running average coefficient 287 | # In-place operations to update the averages at the same time 288 | next_m.mul_(beta1).add_(1 - beta1, grad) 289 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 290 | update = next_m / (next_v.sqrt() + group['e']) 291 | 292 | # Just adding the square of the weights to the loss function is *not* 293 | # the correct way of using L2 regularization/weight decay with Adam, 294 | # since that will interact with the m and v parameters in strange ways. 
295 | # 296 | # Instead we want to decay the weights in a manner that doesn't interact 297 | # with the m/v parameters. This is equivalent to adding the square 298 | # of the weights to the loss with plain (non-momentum) SGD. 299 | if group['weight_decay'] > 0.0: 300 | update += group['weight_decay'] * p.data 301 | 302 | lr_scheduled = group['lr'] 303 | lr_scheduled *= group['schedule'].get_lr(state['step']) 304 | 305 | update_with_lr = lr_scheduled * update 306 | p.data.add_(-update_with_lr) 307 | 308 | state['step'] += 1 309 | 310 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 311 | # No bias correction 312 | # bias_correction1 = 1 - beta1 ** state['step'] 313 | # bias_correction2 = 1 - beta2 ** state['step'] 314 | 315 | return loss 316 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch == 1.1.0 2 | tensorflow-gpu==1.13.1 3 | tqdm 4 | nltk 5 | pandas 6 | boto3 7 | requests 8 | regex 9 | seqeval 10 | psutil 11 | cython 12 | benepar[cpu] 13 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | mkdir logs 2 | 3 | # important parameters 4 | # do_train: train the model 5 | # do_test: test the model 6 | # use_bert: use BERT as encoder 7 | # use_zen: use ZEN as encoder 8 | # bert_model: the directory of BERT/ZEN model 9 | # use_attention: use two-way attention 10 | # source: the toolkit to be use (stanford or berkeley) 11 | # feature_flag: use pos, chunk, or dep knowledge 12 | # model_name: the name of model to save 13 | 14 | # training 15 | 16 | # Command lines to train our model with POS knowledge from SCT 17 | # use BERT 18 | 19 | # CTB5 20 | python twasp_main.py --do_train --train_data_path=./data/CTB5/train.tsv --eval_data_path=./data/CTB5/dev.tsv --use_bert --bert_model=/path/to/bert/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=CTB5_bert_two_way_stanford_pos 21 | 22 | # CTB6 23 | python twasp_main.py --do_train --train_data_path=./data/CTB6/train.tsv --eval_data_path=./data/CTB6/dev.tsv --use_bert --bert_model=/path/to/bert/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=CTB6_bert_two_way_stanford_pos 24 | 25 | # CTB7 26 | python twasp_main.py --do_train --train_data_path=./data/CTB7/train.tsv --eval_data_path=./data/CTB7/dev.tsv --use_bert --bert_model=/path/to/bert/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=CTB7_bert_two_way_stanford_pos 27 | 28 | # CTB9 29 | python twasp_main.py --do_train --train_data_path=./data/CTB9/train.tsv --eval_data_path=./data/CTB9/dev.tsv --use_bert --bert_model=/path/to/bert/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos 
--model_name=CTB9_bert_two_way_stanford_pos 30 | 31 | # UD1 32 | python twasp_main.py --do_train --train_data_path=./data/UD1/train.tsv --eval_data_path=./data/UD1/dev.tsv --use_bert --bert_model=/path/to/bert/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=UD1_bert_two_way_stanford_pos 33 | 34 | # UD2 35 | python twasp_main.py --do_train --train_data_path=./data/UD2/train.tsv --eval_data_path=./data/UD2/dev.tsv --use_bert --bert_model=/path/to/bert/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=UD2_bert_two_way_stanford_pos 36 | 37 | 38 | # use ZEN 39 | 40 | # CTB5 41 | python twasp_main.py --do_train --train_data_path=./data/CTB5/train.tsv --eval_data_path=./data/CTB5/dev.tsv --use_zen --bert_model=/path/to/zen/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=CTB5_zen_two_way_stanford_pos 42 | 43 | # CTB6 44 | python twasp_main.py --do_train --train_data_path=./data/CTB6/train.tsv --eval_data_path=./data/CTB6/dev.tsv --use_zen --bert_model=/path/to/zen/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=CTB6_zen_two_way_stanford_pos 45 | 46 | # CTB7 47 | python twasp_main.py --do_train --train_data_path=./data/CTB7/train.tsv --eval_data_path=./data/CTB7/dev.tsv --use_zen --bert_model=/path/to/zen/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=CTB7_zen_two_way_stanford_pos 48 | 49 | # CTB9 50 | python twasp_main.py --do_train --train_data_path=./data/CTB9/train.tsv --eval_data_path=./data/CTB9/dev.tsv --use_zen --bert_model=/path/to/zen/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=CTB9_zen_two_way_stanford_pos 51 | 52 | # UD1 53 | python twasp_main.py --do_train --train_data_path=./data/UD1/train.tsv --eval_data_path=./data/UD1/dev.tsv --use_zen --bert_model=/path/to/zen/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=UD1_zen_two_way_stanford_pos 54 | 55 | # UD2 56 | python twasp_main.py --do_train --train_data_path=./data/UD2/train.tsv --eval_data_path=./data/UD2/dev.tsv --use_zen --bert_model=/path/to/zen/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=16 --eval_batch_size=16 --num_train_epochs=100 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=UD2_zen_two_way_stanford_pos 57 | 58 | 59 | 
# testing 60 | 61 | python twasp_main.py --do_test --eval_data_path=./data/dataset_name/test.tsv --eval_model=./models/model_name/model.pt 62 | 63 | -------------------------------------------------------------------------------- /run_sample.sh: -------------------------------------------------------------------------------- 1 | mkdir logs 2 | 3 | # train 4 | python twasp_main.py --do_train --train_data_path=./sample_data/train.tsv --eval_data_path=./sample_data/dev.tsv --use_bert --bert_model=/path/to/bert/model --use_attention --max_seq_length=300 --max_ngram_size=300 --train_batch_size=2 --eval_batch_size=2 --num_train_epochs=3 --warmup_proportion=0.1 --learning_rate=1e-5 --patient=15 --source=stanford --feature_flag=pos --model_name=sample_model 5 | 6 | # test 7 | python twasp_main.py --do_test --eval_data_path=./sample_data/test.tsv --eval_model=./models/model_name/model.pt 8 | 9 | # predict 10 | python twasp_main.py --do_predict --input_file=./sample_data/sentence.txt --output_file=./sample_data/sentece.txt.out --eval_model=./models/model_name/model.pt 11 | -------------------------------------------------------------------------------- /sample_data/dev.stanford.json: -------------------------------------------------------------------------------- 1 | {"index": 0, "parse": "(ROOT (ROOT\n (IP\n (NP (NR 中) (NR 美))\n (VP\n (PP (P 在)\n (NP (NR 沪)))\n (VP (VV 签订)\n (NP (NN 高科技) (NN 合作) (NN 协议)))))) )", "basicDependencies": [{"dep": "ROOT", "governor": 0, "governorGloss": "ROOT", "dependent": 5, "dependentGloss": "签订"}, {"dep": "name", "governor": 2, "governorGloss": "美", "dependent": 1, "dependentGloss": "中"}, {"dep": "nsubj", "governor": 5, "governorGloss": "签订", "dependent": 2, "dependentGloss": "美"}, {"dep": "case", "governor": 4, "governorGloss": "沪", "dependent": 3, "dependentGloss": "在"}, {"dep": "nmod:prep", "governor": 5, "governorGloss": "签订", "dependent": 4, "dependentGloss": "沪"}, {"dep": "compound:nn", "governor": 8, "governorGloss": "协议", "dependent": 6, "dependentGloss": "高科技"}, {"dep": "compound:nn", "governor": 8, "governorGloss": "协议", "dependent": 7, "dependentGloss": "合作"}, {"dep": "dobj", "governor": 5, "governorGloss": "签订", "dependent": 8, "dependentGloss": "协议"}], "tokens": [{"index": 1, "word": "中", "originalText": "中", "characterOffsetBegin": 0, "characterOffsetEnd": 1, "pos": "NR"}, {"index": 2, "word": "美", "originalText": "美", "characterOffsetBegin": 1, "characterOffsetEnd": 2, "pos": "NR"}, {"index": 3, "word": "在", "originalText": "在", "characterOffsetBegin": 2, "characterOffsetEnd": 3, "pos": "P"}, {"index": 4, "word": "沪", "originalText": "沪", "characterOffsetBegin": 3, "characterOffsetEnd": 4, "pos": "NR"}, {"index": 5, "word": "签订", "originalText": "签订", "characterOffsetBegin": 4, "characterOffsetEnd": 6, "pos": "VV"}, {"index": 6, "word": "高科技", "originalText": "高科技", "characterOffsetBegin": 6, "characterOffsetEnd": 9, "pos": "NN"}, {"index": 7, "word": "合作", "originalText": "合作", "characterOffsetBegin": 9, "characterOffsetEnd": 11, "pos": "NN"}, {"index": 8, "word": "协议", "originalText": "协议", "characterOffsetBegin": 11, "characterOffsetEnd": 13, "pos": "NN"}]} 2 | {"index": 0, "parse": "(ROOT (ROOT\n (FRAG (NR 新华社) (NR 上海) (NT 八月) (NT 三十一日) (NN 电) (PU () (NN 记者) (NR 白国良) (PU 、) (NR 夏儒阁) (PU )))) )", "basicDependencies": [{"dep": "ROOT", "governor": 0, "governorGloss": "ROOT", "dependent": 7, "dependentGloss": "记者"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 1, "dependentGloss": "新华社"}, {"dep": "dep", "governor": 7, 
"governorGloss": "记者", "dependent": 2, "dependentGloss": "上海"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 3, "dependentGloss": "八月"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 4, "dependentGloss": "三十一日"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 5, "dependentGloss": "电"}, {"dep": "punct", "governor": 7, "governorGloss": "记者", "dependent": 6, "dependentGloss": "("}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 8, "dependentGloss": "白国良"}, {"dep": "punct", "governor": 7, "governorGloss": "记者", "dependent": 9, "dependentGloss": "、"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 10, "dependentGloss": "夏儒阁"}, {"dep": "punct", "governor": 7, "governorGloss": "记者", "dependent": 11, "dependentGloss": ")"}], "tokens": [{"index": 1, "word": "新华社", "originalText": "新华社", "characterOffsetBegin": 0, "characterOffsetEnd": 3, "pos": "NR"}, {"index": 2, "word": "上海", "originalText": "上海", "characterOffsetBegin": 3, "characterOffsetEnd": 5, "pos": "NR"}, {"index": 3, "word": "八月", "originalText": "八月", "characterOffsetBegin": 5, "characterOffsetEnd": 7, "pos": "NT"}, {"index": 4, "word": "三十一日", "originalText": "三十一日", "characterOffsetBegin": 7, "characterOffsetEnd": 11, "pos": "NT"}, {"index": 5, "word": "电", "originalText": "电", "characterOffsetBegin": 11, "characterOffsetEnd": 12, "pos": "NN"}, {"index": 6, "word": "(", "originalText": "(", "characterOffsetBegin": 12, "characterOffsetEnd": 13, "pos": "PU"}, {"index": 7, "word": "记者", "originalText": "记者", "characterOffsetBegin": 13, "characterOffsetEnd": 15, "pos": "NN"}, {"index": 8, "word": "白国良", "originalText": "白国良", "characterOffsetBegin": 15, "characterOffsetEnd": 18, "pos": "NR"}, {"index": 9, "word": "、", "originalText": "、", "characterOffsetBegin": 18, "characterOffsetEnd": 19, "pos": "PU"}, {"index": 10, "word": "夏儒阁", "originalText": "夏儒阁", "characterOffsetBegin": 19, "characterOffsetEnd": 22, "pos": "NR"}, {"index": 11, "word": ")", "originalText": ")", "characterOffsetBegin": 22, "characterOffsetEnd": 23, "pos": "PU"}]} 3 | {"index": 0, "parse": "(ROOT (ROOT\n (IP (PU “)\n (NP\n (NP (NR 中) (NR 美))\n (VV 合作)\n (NP\n (ADJP (NN 高科技))\n (NP (NN 项目)))\n (NP (NN 签字) (NN 仪式))\n (PU ”))\n (VP\n (NP (NT 今天))\n (PP (P 在)\n (NP (NR 上海)))\n (VP (VV 举行)))\n (PU 。))) )", "basicDependencies": [{"dep": "ROOT", "governor": 0, "governorGloss": "ROOT", "dependent": 13, "dependentGloss": "举行"}, {"dep": "punct", "governor": 8, "governorGloss": "仪式", "dependent": 1, "dependentGloss": "“"}, {"dep": "name", "governor": 3, "governorGloss": "美", "dependent": 2, "dependentGloss": "中"}, {"dep": "nsubj", "governor": 4, "governorGloss": "合作", "dependent": 3, "dependentGloss": "美"}, {"dep": "acl", "governor": 6, "governorGloss": "项目", "dependent": 4, "dependentGloss": "合作"}, {"dep": "compound:nn", "governor": 6, "governorGloss": "项目", "dependent": 5, "dependentGloss": "高科技"}, {"dep": "compound:nn", "governor": 8, "governorGloss": "仪式", "dependent": 6, "dependentGloss": "项目"}, {"dep": "compound:nn", "governor": 8, "governorGloss": "仪式", "dependent": 7, "dependentGloss": "签字"}, {"dep": "nsubj", "governor": 13, "governorGloss": "举行", "dependent": 8, "dependentGloss": "仪式"}, {"dep": "punct", "governor": 8, "governorGloss": "仪式", "dependent": 9, "dependentGloss": "”"}, {"dep": "nmod:tmod", "governor": 13, "governorGloss": "举行", "dependent": 10, "dependentGloss": "今天"}, {"dep": "case", "governor": 12, "governorGloss": "上海", "dependent": 11, "dependentGloss": 
"在"}, {"dep": "nmod:prep", "governor": 13, "governorGloss": "举行", "dependent": 12, "dependentGloss": "上海"}, {"dep": "punct", "governor": 13, "governorGloss": "举行", "dependent": 14, "dependentGloss": "。"}], "tokens": [{"index": 1, "word": "“", "originalText": "“", "characterOffsetBegin": 0, "characterOffsetEnd": 1, "pos": "PU"}, {"index": 2, "word": "中", "originalText": "中", "characterOffsetBegin": 1, "characterOffsetEnd": 2, "pos": "NR"}, {"index": 3, "word": "美", "originalText": "美", "characterOffsetBegin": 2, "characterOffsetEnd": 3, "pos": "NR"}, {"index": 4, "word": "合作", "originalText": "合作", "characterOffsetBegin": 3, "characterOffsetEnd": 5, "pos": "VV"}, {"index": 5, "word": "高科技", "originalText": "高科技", "characterOffsetBegin": 5, "characterOffsetEnd": 8, "pos": "NN"}, {"index": 6, "word": "项目", "originalText": "项目", "characterOffsetBegin": 8, "characterOffsetEnd": 10, "pos": "NN"}, {"index": 7, "word": "签字", "originalText": "签字", "characterOffsetBegin": 10, "characterOffsetEnd": 12, "pos": "NN"}, {"index": 8, "word": "仪式", "originalText": "仪式", "characterOffsetBegin": 12, "characterOffsetEnd": 14, "pos": "NN"}, {"index": 9, "word": "”", "originalText": "”", "characterOffsetBegin": 14, "characterOffsetEnd": 15, "pos": "PU"}, {"index": 10, "word": "今天", "originalText": "今天", "characterOffsetBegin": 15, "characterOffsetEnd": 17, "pos": "NT"}, {"index": 11, "word": "在", "originalText": "在", "characterOffsetBegin": 17, "characterOffsetEnd": 18, "pos": "P"}, {"index": 12, "word": "上海", "originalText": "上海", "characterOffsetBegin": 18, "characterOffsetEnd": 20, "pos": "NR"}, {"index": 13, "word": "举行", "originalText": "举行", "characterOffsetBegin": 20, "characterOffsetEnd": 22, "pos": "VV"}, {"index": 14, "word": "。", "originalText": "。", "characterOffsetBegin": 22, "characterOffsetEnd": 23, "pos": "PU"}]} 4 | {"index": 0, "parse": "(ROOT (ROOT\n (IP\n (IP\n (NP\n (CP\n (IP\n (VP\n (NP (NT 上午))\n (PP (P 在)\n (NP (PN 这里)))\n (VP (VV 签字))))\n (DEC 的)))\n (VP (VC 是)\n (NP\n (NP (NN 知识) (NN 信息) (NN 网络) (NN 通讯) (NN 技术)\n (CC 和)\n (NN 脱氧核糖核酸) (NN 生物) (NN 技术))\n (QP (CD 两)\n (CLP (M 个)))\n (NP (NN 项目)))))\n (PU ,)\n (IP\n (VP\n (ADVP (AD 同时))\n (ADVP (AD 还))\n (VP (VV 签订) (AS 了)\n (NP (NN 语言) (NN 教学) (NN 交流) (NN 合作) (NN 协议)))))\n (PU 。))) )", "basicDependencies": [{"dep": "ROOT", "governor": 0, "governorGloss": "ROOT", "dependent": 18, "dependentGloss": "项目"}, {"dep": "nmod:tmod", "governor": 4, "governorGloss": "签字", "dependent": 1, "dependentGloss": "上午"}, {"dep": "case", "governor": 3, "governorGloss": "这里", "dependent": 2, "dependentGloss": "在"}, {"dep": "nmod:prep", "governor": 4, "governorGloss": "签字", "dependent": 3, "dependentGloss": "这里"}, {"dep": "nsubj", "governor": 18, "governorGloss": "项目", "dependent": 4, "dependentGloss": "签字"}, {"dep": "mark", "governor": 4, "governorGloss": "签字", "dependent": 5, "dependentGloss": "的"}, {"dep": "cop", "governor": 18, "governorGloss": "项目", "dependent": 6, "dependentGloss": "是"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 7, "dependentGloss": "知识"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 8, "dependentGloss": "信息"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 9, "dependentGloss": "网络"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 10, "dependentGloss": "通讯"}, {"dep": "conj", "governor": 15, "governorGloss": "技术", "dependent": 11, "dependentGloss": "技术"}, {"dep": "cc", "governor": 15, "governorGloss": "技术", "dependent": 
12, "dependentGloss": "和"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 13, "dependentGloss": "脱氧核糖核酸"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 14, "dependentGloss": "生物"}, {"dep": "compound:nn", "governor": 18, "governorGloss": "项目", "dependent": 15, "dependentGloss": "技术"}, {"dep": "nummod", "governor": 18, "governorGloss": "项目", "dependent": 16, "dependentGloss": "两"}, {"dep": "mark:clf", "governor": 16, "governorGloss": "两", "dependent": 17, "dependentGloss": "个"}, {"dep": "punct", "governor": 18, "governorGloss": "项目", "dependent": 19, "dependentGloss": ","}, {"dep": "advmod", "governor": 22, "governorGloss": "签订", "dependent": 20, "dependentGloss": "同时"}, {"dep": "advmod", "governor": 22, "governorGloss": "签订", "dependent": 21, "dependentGloss": "还"}, {"dep": "conj", "governor": 18, "governorGloss": "项目", "dependent": 22, "dependentGloss": "签订"}, {"dep": "aux:asp", "governor": 22, "governorGloss": "签订", "dependent": 23, "dependentGloss": "了"}, {"dep": "compound:nn", "governor": 28, "governorGloss": "协议", "dependent": 24, "dependentGloss": "语言"}, {"dep": "compound:nn", "governor": 28, "governorGloss": "协议", "dependent": 25, "dependentGloss": "教学"}, {"dep": "compound:nn", "governor": 28, "governorGloss": "协议", "dependent": 26, "dependentGloss": "交流"}, {"dep": "compound:nn", "governor": 28, "governorGloss": "协议", "dependent": 27, "dependentGloss": "合作"}, {"dep": "dobj", "governor": 22, "governorGloss": "签订", "dependent": 28, "dependentGloss": "协议"}, {"dep": "punct", "governor": 18, "governorGloss": "项目", "dependent": 29, "dependentGloss": "。"}], "tokens": [{"index": 1, "word": "上午", "originalText": "上午", "characterOffsetBegin": 0, "characterOffsetEnd": 2, "pos": "NT"}, {"index": 2, "word": "在", "originalText": "在", "characterOffsetBegin": 2, "characterOffsetEnd": 3, "pos": "P"}, {"index": 3, "word": "这里", "originalText": "这里", "characterOffsetBegin": 3, "characterOffsetEnd": 5, "pos": "PN"}, {"index": 4, "word": "签字", "originalText": "签字", "characterOffsetBegin": 5, "characterOffsetEnd": 7, "pos": "VV"}, {"index": 5, "word": "的", "originalText": "的", "characterOffsetBegin": 7, "characterOffsetEnd": 8, "pos": "DEC"}, {"index": 6, "word": "是", "originalText": "是", "characterOffsetBegin": 8, "characterOffsetEnd": 9, "pos": "VC"}, {"index": 7, "word": "知识", "originalText": "知识", "characterOffsetBegin": 9, "characterOffsetEnd": 11, "pos": "NN"}, {"index": 8, "word": "信息", "originalText": "信息", "characterOffsetBegin": 11, "characterOffsetEnd": 13, "pos": "NN"}, {"index": 9, "word": "网络", "originalText": "网络", "characterOffsetBegin": 13, "characterOffsetEnd": 15, "pos": "NN"}, {"index": 10, "word": "通讯", "originalText": "通讯", "characterOffsetBegin": 15, "characterOffsetEnd": 17, "pos": "NN"}, {"index": 11, "word": "技术", "originalText": "技术", "characterOffsetBegin": 17, "characterOffsetEnd": 19, "pos": "NN"}, {"index": 12, "word": "和", "originalText": "和", "characterOffsetBegin": 19, "characterOffsetEnd": 20, "pos": "CC"}, {"index": 13, "word": "脱氧核糖核酸", "originalText": "脱氧核糖核酸", "characterOffsetBegin": 20, "characterOffsetEnd": 26, "pos": "NN"}, {"index": 14, "word": "生物", "originalText": "生物", "characterOffsetBegin": 26, "characterOffsetEnd": 28, "pos": "NN"}, {"index": 15, "word": "技术", "originalText": "技术", "characterOffsetBegin": 28, "characterOffsetEnd": 30, "pos": "NN"}, {"index": 16, "word": "两", "originalText": "两", "characterOffsetBegin": 30, "characterOffsetEnd": 31, "pos": "CD"}, {"index": 17, "word": "个", 
"originalText": "个", "characterOffsetBegin": 31, "characterOffsetEnd": 32, "pos": "M"}, {"index": 18, "word": "项目", "originalText": "项目", "characterOffsetBegin": 32, "characterOffsetEnd": 34, "pos": "NN"}, {"index": 19, "word": ",", "originalText": ",", "characterOffsetBegin": 34, "characterOffsetEnd": 35, "pos": "PU"}, {"index": 20, "word": "同时", "originalText": "同时", "characterOffsetBegin": 35, "characterOffsetEnd": 37, "pos": "AD"}, {"index": 21, "word": "还", "originalText": "还", "characterOffsetBegin": 37, "characterOffsetEnd": 38, "pos": "AD"}, {"index": 22, "word": "签订", "originalText": "签订", "characterOffsetBegin": 38, "characterOffsetEnd": 40, "pos": "VV"}, {"index": 23, "word": "了", "originalText": "了", "characterOffsetBegin": 40, "characterOffsetEnd": 41, "pos": "AS"}, {"index": 24, "word": "语言", "originalText": "语言", "characterOffsetBegin": 41, "characterOffsetEnd": 43, "pos": "NN"}, {"index": 25, "word": "教学", "originalText": "教学", "characterOffsetBegin": 43, "characterOffsetEnd": 45, "pos": "NN"}, {"index": 26, "word": "交流", "originalText": "交流", "characterOffsetBegin": 45, "characterOffsetEnd": 47, "pos": "NN"}, {"index": 27, "word": "合作", "originalText": "合作", "characterOffsetBegin": 47, "characterOffsetEnd": 49, "pos": "NN"}, {"index": 28, "word": "协议", "originalText": "协议", "characterOffsetBegin": 49, "characterOffsetEnd": 51, "pos": "NN"}, {"index": 29, "word": "。", "originalText": "。", "characterOffsetBegin": 51, "characterOffsetEnd": 52, "pos": "PU"}]} 5 | -------------------------------------------------------------------------------- /sample_data/dev.tsv: -------------------------------------------------------------------------------- 1 | 中 S-NR 2 | 美 S-NR 3 | 在 S-P 4 | 沪 S-NR 5 | 签 B-VV 6 | 订 E-VV 7 | 高 S-JJ 8 | 科 B-NN 9 | 技 E-NN 10 | 合 B-NN 11 | 作 E-NN 12 | 协 B-NN 13 | 议 E-NN 14 | 15 | 16 | 新 B-NR 17 | 华 I-NR 18 | 社 E-NR 19 | 上 B-NR 20 | 海 E-NR 21 | 八 B-NT 22 | 月 E-NT 23 | 三 B-NT 24 | 十 I-NT 25 | 一 I-NT 26 | 日 E-NT 27 | 电 S-NN 28 | ( S-PU 29 | 记 B-NN 30 | 者 E-NN 31 | 白 B-NR 32 | 国 I-NR 33 | 良 E-NR 34 | 、 S-PU 35 | 夏 B-NR 36 | 儒 I-NR 37 | 阁 E-NR 38 | ) S-PU 39 | 40 | 41 | “ S-PU 42 | 中 S-NR 43 | 美 S-NR 44 | 合 B-NN 45 | 作 E-NN 46 | 高 S-JJ 47 | 科 B-NN 48 | 技 E-NN 49 | 项 B-NN 50 | 目 E-NN 51 | 签 B-NN 52 | 字 E-NN 53 | 仪 B-NN 54 | 式 E-NN 55 | ” S-PU 56 | 今 B-NT 57 | 天 E-NT 58 | 在 S-P 59 | 上 B-NR 60 | 海 E-NR 61 | 举 B-VV 62 | 行 E-VV 63 | 。 S-PU 64 | 65 | 66 | 上 B-NT 67 | 午 E-NT 68 | 在 S-P 69 | 这 B-PN 70 | 里 E-PN 71 | 签 B-VV 72 | 字 E-VV 73 | 的 S-DEC 74 | 是 S-VC 75 | 知 B-NN 76 | 识 E-NN 77 | 信 B-NN 78 | 息 E-NN 79 | 网 B-NN 80 | 络 E-NN 81 | 通 B-NN 82 | 讯 E-NN 83 | 技 B-NN 84 | 术 E-NN 85 | 和 S-CC 86 | 脱 B-NN 87 | 氧 I-NN 88 | 核 I-NN 89 | 糖 I-NN 90 | 核 I-NN 91 | 酸 E-NN 92 | 生 B-NN 93 | 物 E-NN 94 | 技 B-NN 95 | 术 E-NN 96 | 两 S-CD 97 | 个 S-M 98 | 项 B-NN 99 | 目 E-NN 100 | , S-PU 101 | 同 B-AD 102 | 时 E-AD 103 | 还 S-AD 104 | 签 B-VV 105 | 订 E-VV 106 | 了 S-AS 107 | 语 B-NN 108 | 言 E-NN 109 | 教 B-NN 110 | 学 E-NN 111 | 交 B-NN 112 | 流 E-NN 113 | 合 B-NN 114 | 作 E-NN 115 | 协 B-NN 116 | 议 E-NN 117 | 。 S-PU 118 | 119 | -------------------------------------------------------------------------------- /sample_data/label2id: -------------------------------------------------------------------------------- 1 | B-NR 2 | E-NR 3 | B-NN 4 | E-NN 5 | S-CC 6 | B-VV 7 | E-VV 8 | I-NN 9 | B-NT 10 | E-NT 11 | S-NN 12 | S-PU 13 | I-NR 14 | S-LC 15 | S-AS 16 | S-ETC 17 | S-DEC 18 | B-CD 19 | I-CD 20 | E-CD 21 | S-M 22 | S-DEG 23 | B-JJ 24 | E-JJ 25 | S-VC 26 | S-CD 27 | I-JJ 28 | B-AD 29 | E-AD 30 | S-AD 31 | S-JJ 
32 | S-P 33 | S-PN 34 | B-VA 35 | E-VA 36 | S-DEV 37 | S-VV 38 | B-LC 39 | E-LC 40 | B-DT 41 | E-DT 42 | S-SB 43 | B-OD 44 | E-OD 45 | B-P 46 | E-P 47 | S-VE 48 | S-DT 49 | B-M 50 | E-M 51 | B-CS 52 | E-CS 53 | B-PN 54 | E-PN 55 | S-VA 56 | I-NT 57 | S-NR 58 | -------------------------------------------------------------------------------- /sample_data/sentence.txt: -------------------------------------------------------------------------------- 1 | 共同创造美好的新世纪——二○○一年新年贺词 2 | (二○○○年十二月三十一日)(附图片1张) 3 | 女士们,先生们,同志们,朋友们: 4 | 2001年新年钟声即将敲响。人类社会前进的航船就要驶入21世纪的新航程。中国人民进入了向现代化建设第三步战略目标迈进的新征程。 5 | 在这个激动人心的时刻,我很高兴通过中国国际广播电台、中央人民广播电台和中央电视台,向全国各族人民,向香港特别行政区同胞、澳门特别行政区同胞和台湾同胞、海外侨胞,向世界各国的朋友们,致以新世纪第一个新年的祝贺! 6 | -------------------------------------------------------------------------------- /sample_data/test.stanford.json: -------------------------------------------------------------------------------- 1 | {"index": 0, "parse": "(ROOT (ROOT\n (IP\n (NP (NR 中) (NR 美))\n (VP\n (PP (P 在)\n (NP (NR 沪)))\n (VP (VV 签订)\n (NP (NN 高科技) (NN 合作) (NN 协议)))))) )", "basicDependencies": [{"dep": "ROOT", "governor": 0, "governorGloss": "ROOT", "dependent": 5, "dependentGloss": "签订"}, {"dep": "name", "governor": 2, "governorGloss": "美", "dependent": 1, "dependentGloss": "中"}, {"dep": "nsubj", "governor": 5, "governorGloss": "签订", "dependent": 2, "dependentGloss": "美"}, {"dep": "case", "governor": 4, "governorGloss": "沪", "dependent": 3, "dependentGloss": "在"}, {"dep": "nmod:prep", "governor": 5, "governorGloss": "签订", "dependent": 4, "dependentGloss": "沪"}, {"dep": "compound:nn", "governor": 8, "governorGloss": "协议", "dependent": 6, "dependentGloss": "高科技"}, {"dep": "compound:nn", "governor": 8, "governorGloss": "协议", "dependent": 7, "dependentGloss": "合作"}, {"dep": "dobj", "governor": 5, "governorGloss": "签订", "dependent": 8, "dependentGloss": "协议"}], "tokens": [{"index": 1, "word": "中", "originalText": "中", "characterOffsetBegin": 0, "characterOffsetEnd": 1, "pos": "NR"}, {"index": 2, "word": "美", "originalText": "美", "characterOffsetBegin": 1, "characterOffsetEnd": 2, "pos": "NR"}, {"index": 3, "word": "在", "originalText": "在", "characterOffsetBegin": 2, "characterOffsetEnd": 3, "pos": "P"}, {"index": 4, "word": "沪", "originalText": "沪", "characterOffsetBegin": 3, "characterOffsetEnd": 4, "pos": "NR"}, {"index": 5, "word": "签订", "originalText": "签订", "characterOffsetBegin": 4, "characterOffsetEnd": 6, "pos": "VV"}, {"index": 6, "word": "高科技", "originalText": "高科技", "characterOffsetBegin": 6, "characterOffsetEnd": 9, "pos": "NN"}, {"index": 7, "word": "合作", "originalText": "合作", "characterOffsetBegin": 9, "characterOffsetEnd": 11, "pos": "NN"}, {"index": 8, "word": "协议", "originalText": "协议", "characterOffsetBegin": 11, "characterOffsetEnd": 13, "pos": "NN"}]} 2 | {"index": 0, "parse": "(ROOT (ROOT\n (FRAG (NR 新华社) (NR 上海) (NT 八月) (NT 三十一日) (NN 电) (PU () (NN 记者) (NR 白国良) (PU 、) (NR 夏儒阁) (PU )))) )", "basicDependencies": [{"dep": "ROOT", "governor": 0, "governorGloss": "ROOT", "dependent": 7, "dependentGloss": "记者"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 1, "dependentGloss": "新华社"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 2, "dependentGloss": "上海"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 3, "dependentGloss": "八月"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 4, "dependentGloss": "三十一日"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 5, "dependentGloss": "电"}, {"dep": "punct", "governor": 7, 
"governorGloss": "记者", "dependent": 6, "dependentGloss": "("}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 8, "dependentGloss": "白国良"}, {"dep": "punct", "governor": 7, "governorGloss": "记者", "dependent": 9, "dependentGloss": "、"}, {"dep": "dep", "governor": 7, "governorGloss": "记者", "dependent": 10, "dependentGloss": "夏儒阁"}, {"dep": "punct", "governor": 7, "governorGloss": "记者", "dependent": 11, "dependentGloss": ")"}], "tokens": [{"index": 1, "word": "新华社", "originalText": "新华社", "characterOffsetBegin": 0, "characterOffsetEnd": 3, "pos": "NR"}, {"index": 2, "word": "上海", "originalText": "上海", "characterOffsetBegin": 3, "characterOffsetEnd": 5, "pos": "NR"}, {"index": 3, "word": "八月", "originalText": "八月", "characterOffsetBegin": 5, "characterOffsetEnd": 7, "pos": "NT"}, {"index": 4, "word": "三十一日", "originalText": "三十一日", "characterOffsetBegin": 7, "characterOffsetEnd": 11, "pos": "NT"}, {"index": 5, "word": "电", "originalText": "电", "characterOffsetBegin": 11, "characterOffsetEnd": 12, "pos": "NN"}, {"index": 6, "word": "(", "originalText": "(", "characterOffsetBegin": 12, "characterOffsetEnd": 13, "pos": "PU"}, {"index": 7, "word": "记者", "originalText": "记者", "characterOffsetBegin": 13, "characterOffsetEnd": 15, "pos": "NN"}, {"index": 8, "word": "白国良", "originalText": "白国良", "characterOffsetBegin": 15, "characterOffsetEnd": 18, "pos": "NR"}, {"index": 9, "word": "、", "originalText": "、", "characterOffsetBegin": 18, "characterOffsetEnd": 19, "pos": "PU"}, {"index": 10, "word": "夏儒阁", "originalText": "夏儒阁", "characterOffsetBegin": 19, "characterOffsetEnd": 22, "pos": "NR"}, {"index": 11, "word": ")", "originalText": ")", "characterOffsetBegin": 22, "characterOffsetEnd": 23, "pos": "PU"}]} 3 | {"index": 0, "parse": "(ROOT (ROOT\n (IP (PU “)\n (NP\n (NP (NR 中) (NR 美))\n (VV 合作)\n (NP\n (ADJP (NN 高科技))\n (NP (NN 项目)))\n (NP (NN 签字) (NN 仪式))\n (PU ”))\n (VP\n (NP (NT 今天))\n (PP (P 在)\n (NP (NR 上海)))\n (VP (VV 举行)))\n (PU 。))) )", "basicDependencies": [{"dep": "ROOT", "governor": 0, "governorGloss": "ROOT", "dependent": 13, "dependentGloss": "举行"}, {"dep": "punct", "governor": 8, "governorGloss": "仪式", "dependent": 1, "dependentGloss": "“"}, {"dep": "name", "governor": 3, "governorGloss": "美", "dependent": 2, "dependentGloss": "中"}, {"dep": "nsubj", "governor": 4, "governorGloss": "合作", "dependent": 3, "dependentGloss": "美"}, {"dep": "acl", "governor": 6, "governorGloss": "项目", "dependent": 4, "dependentGloss": "合作"}, {"dep": "compound:nn", "governor": 6, "governorGloss": "项目", "dependent": 5, "dependentGloss": "高科技"}, {"dep": "compound:nn", "governor": 8, "governorGloss": "仪式", "dependent": 6, "dependentGloss": "项目"}, {"dep": "compound:nn", "governor": 8, "governorGloss": "仪式", "dependent": 7, "dependentGloss": "签字"}, {"dep": "nsubj", "governor": 13, "governorGloss": "举行", "dependent": 8, "dependentGloss": "仪式"}, {"dep": "punct", "governor": 8, "governorGloss": "仪式", "dependent": 9, "dependentGloss": "”"}, {"dep": "nmod:tmod", "governor": 13, "governorGloss": "举行", "dependent": 10, "dependentGloss": "今天"}, {"dep": "case", "governor": 12, "governorGloss": "上海", "dependent": 11, "dependentGloss": "在"}, {"dep": "nmod:prep", "governor": 13, "governorGloss": "举行", "dependent": 12, "dependentGloss": "上海"}, {"dep": "punct", "governor": 13, "governorGloss": "举行", "dependent": 14, "dependentGloss": "。"}], "tokens": [{"index": 1, "word": "“", "originalText": "“", "characterOffsetBegin": 0, "characterOffsetEnd": 1, "pos": "PU"}, {"index": 2, "word": "中", "originalText": "中", 
"characterOffsetBegin": 1, "characterOffsetEnd": 2, "pos": "NR"}, {"index": 3, "word": "美", "originalText": "美", "characterOffsetBegin": 2, "characterOffsetEnd": 3, "pos": "NR"}, {"index": 4, "word": "合作", "originalText": "合作", "characterOffsetBegin": 3, "characterOffsetEnd": 5, "pos": "VV"}, {"index": 5, "word": "高科技", "originalText": "高科技", "characterOffsetBegin": 5, "characterOffsetEnd": 8, "pos": "NN"}, {"index": 6, "word": "项目", "originalText": "项目", "characterOffsetBegin": 8, "characterOffsetEnd": 10, "pos": "NN"}, {"index": 7, "word": "签字", "originalText": "签字", "characterOffsetBegin": 10, "characterOffsetEnd": 12, "pos": "NN"}, {"index": 8, "word": "仪式", "originalText": "仪式", "characterOffsetBegin": 12, "characterOffsetEnd": 14, "pos": "NN"}, {"index": 9, "word": "”", "originalText": "”", "characterOffsetBegin": 14, "characterOffsetEnd": 15, "pos": "PU"}, {"index": 10, "word": "今天", "originalText": "今天", "characterOffsetBegin": 15, "characterOffsetEnd": 17, "pos": "NT"}, {"index": 11, "word": "在", "originalText": "在", "characterOffsetBegin": 17, "characterOffsetEnd": 18, "pos": "P"}, {"index": 12, "word": "上海", "originalText": "上海", "characterOffsetBegin": 18, "characterOffsetEnd": 20, "pos": "NR"}, {"index": 13, "word": "举行", "originalText": "举行", "characterOffsetBegin": 20, "characterOffsetEnd": 22, "pos": "VV"}, {"index": 14, "word": "。", "originalText": "。", "characterOffsetBegin": 22, "characterOffsetEnd": 23, "pos": "PU"}]} 4 | {"index": 0, "parse": "(ROOT (ROOT\n (IP\n (IP\n (NP\n (CP\n (IP\n (VP\n (NP (NT 上午))\n (PP (P 在)\n (NP (PN 这里)))\n (VP (VV 签字))))\n (DEC 的)))\n (VP (VC 是)\n (NP\n (NP (NN 知识) (NN 信息) (NN 网络) (NN 通讯) (NN 技术)\n (CC 和)\n (NN 脱氧核糖核酸) (NN 生物) (NN 技术))\n (QP (CD 两)\n (CLP (M 个)))\n (NP (NN 项目)))))\n (PU ,)\n (IP\n (VP\n (ADVP (AD 同时))\n (ADVP (AD 还))\n (VP (VV 签订) (AS 了)\n (NP (NN 语言) (NN 教学) (NN 交流) (NN 合作) (NN 协议)))))\n (PU 。))) )", "basicDependencies": [{"dep": "ROOT", "governor": 0, "governorGloss": "ROOT", "dependent": 18, "dependentGloss": "项目"}, {"dep": "nmod:tmod", "governor": 4, "governorGloss": "签字", "dependent": 1, "dependentGloss": "上午"}, {"dep": "case", "governor": 3, "governorGloss": "这里", "dependent": 2, "dependentGloss": "在"}, {"dep": "nmod:prep", "governor": 4, "governorGloss": "签字", "dependent": 3, "dependentGloss": "这里"}, {"dep": "nsubj", "governor": 18, "governorGloss": "项目", "dependent": 4, "dependentGloss": "签字"}, {"dep": "mark", "governor": 4, "governorGloss": "签字", "dependent": 5, "dependentGloss": "的"}, {"dep": "cop", "governor": 18, "governorGloss": "项目", "dependent": 6, "dependentGloss": "是"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 7, "dependentGloss": "知识"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 8, "dependentGloss": "信息"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 9, "dependentGloss": "网络"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 10, "dependentGloss": "通讯"}, {"dep": "conj", "governor": 15, "governorGloss": "技术", "dependent": 11, "dependentGloss": "技术"}, {"dep": "cc", "governor": 15, "governorGloss": "技术", "dependent": 12, "dependentGloss": "和"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 13, "dependentGloss": "脱氧核糖核酸"}, {"dep": "compound:nn", "governor": 15, "governorGloss": "技术", "dependent": 14, "dependentGloss": "生物"}, {"dep": "compound:nn", "governor": 18, "governorGloss": "项目", "dependent": 15, "dependentGloss": "技术"}, {"dep": "nummod", "governor": 18, 
"governorGloss": "项目", "dependent": 16, "dependentGloss": "两"}, {"dep": "mark:clf", "governor": 16, "governorGloss": "两", "dependent": 17, "dependentGloss": "个"}, {"dep": "punct", "governor": 18, "governorGloss": "项目", "dependent": 19, "dependentGloss": ","}, {"dep": "advmod", "governor": 22, "governorGloss": "签订", "dependent": 20, "dependentGloss": "同时"}, {"dep": "advmod", "governor": 22, "governorGloss": "签订", "dependent": 21, "dependentGloss": "还"}, {"dep": "conj", "governor": 18, "governorGloss": "项目", "dependent": 22, "dependentGloss": "签订"}, {"dep": "aux:asp", "governor": 22, "governorGloss": "签订", "dependent": 23, "dependentGloss": "了"}, {"dep": "compound:nn", "governor": 28, "governorGloss": "协议", "dependent": 24, "dependentGloss": "语言"}, {"dep": "compound:nn", "governor": 28, "governorGloss": "协议", "dependent": 25, "dependentGloss": "教学"}, {"dep": "compound:nn", "governor": 28, "governorGloss": "协议", "dependent": 26, "dependentGloss": "交流"}, {"dep": "compound:nn", "governor": 28, "governorGloss": "协议", "dependent": 27, "dependentGloss": "合作"}, {"dep": "dobj", "governor": 22, "governorGloss": "签订", "dependent": 28, "dependentGloss": "协议"}, {"dep": "punct", "governor": 18, "governorGloss": "项目", "dependent": 29, "dependentGloss": "。"}], "tokens": [{"index": 1, "word": "上午", "originalText": "上午", "characterOffsetBegin": 0, "characterOffsetEnd": 2, "pos": "NT"}, {"index": 2, "word": "在", "originalText": "在", "characterOffsetBegin": 2, "characterOffsetEnd": 3, "pos": "P"}, {"index": 3, "word": "这里", "originalText": "这里", "characterOffsetBegin": 3, "characterOffsetEnd": 5, "pos": "PN"}, {"index": 4, "word": "签字", "originalText": "签字", "characterOffsetBegin": 5, "characterOffsetEnd": 7, "pos": "VV"}, {"index": 5, "word": "的", "originalText": "的", "characterOffsetBegin": 7, "characterOffsetEnd": 8, "pos": "DEC"}, {"index": 6, "word": "是", "originalText": "是", "characterOffsetBegin": 8, "characterOffsetEnd": 9, "pos": "VC"}, {"index": 7, "word": "知识", "originalText": "知识", "characterOffsetBegin": 9, "characterOffsetEnd": 11, "pos": "NN"}, {"index": 8, "word": "信息", "originalText": "信息", "characterOffsetBegin": 11, "characterOffsetEnd": 13, "pos": "NN"}, {"index": 9, "word": "网络", "originalText": "网络", "characterOffsetBegin": 13, "characterOffsetEnd": 15, "pos": "NN"}, {"index": 10, "word": "通讯", "originalText": "通讯", "characterOffsetBegin": 15, "characterOffsetEnd": 17, "pos": "NN"}, {"index": 11, "word": "技术", "originalText": "技术", "characterOffsetBegin": 17, "characterOffsetEnd": 19, "pos": "NN"}, {"index": 12, "word": "和", "originalText": "和", "characterOffsetBegin": 19, "characterOffsetEnd": 20, "pos": "CC"}, {"index": 13, "word": "脱氧核糖核酸", "originalText": "脱氧核糖核酸", "characterOffsetBegin": 20, "characterOffsetEnd": 26, "pos": "NN"}, {"index": 14, "word": "生物", "originalText": "生物", "characterOffsetBegin": 26, "characterOffsetEnd": 28, "pos": "NN"}, {"index": 15, "word": "技术", "originalText": "技术", "characterOffsetBegin": 28, "characterOffsetEnd": 30, "pos": "NN"}, {"index": 16, "word": "两", "originalText": "两", "characterOffsetBegin": 30, "characterOffsetEnd": 31, "pos": "CD"}, {"index": 17, "word": "个", "originalText": "个", "characterOffsetBegin": 31, "characterOffsetEnd": 32, "pos": "M"}, {"index": 18, "word": "项目", "originalText": "项目", "characterOffsetBegin": 32, "characterOffsetEnd": 34, "pos": "NN"}, {"index": 19, "word": ",", "originalText": ",", "characterOffsetBegin": 34, "characterOffsetEnd": 35, "pos": "PU"}, {"index": 20, "word": "同时", "originalText": "同时", 
"characterOffsetBegin": 35, "characterOffsetEnd": 37, "pos": "AD"}, {"index": 21, "word": "还", "originalText": "还", "characterOffsetBegin": 37, "characterOffsetEnd": 38, "pos": "AD"}, {"index": 22, "word": "签订", "originalText": "签订", "characterOffsetBegin": 38, "characterOffsetEnd": 40, "pos": "VV"}, {"index": 23, "word": "了", "originalText": "了", "characterOffsetBegin": 40, "characterOffsetEnd": 41, "pos": "AS"}, {"index": 24, "word": "语言", "originalText": "语言", "characterOffsetBegin": 41, "characterOffsetEnd": 43, "pos": "NN"}, {"index": 25, "word": "教学", "originalText": "教学", "characterOffsetBegin": 43, "characterOffsetEnd": 45, "pos": "NN"}, {"index": 26, "word": "交流", "originalText": "交流", "characterOffsetBegin": 45, "characterOffsetEnd": 47, "pos": "NN"}, {"index": 27, "word": "合作", "originalText": "合作", "characterOffsetBegin": 47, "characterOffsetEnd": 49, "pos": "NN"}, {"index": 28, "word": "协议", "originalText": "协议", "characterOffsetBegin": 49, "characterOffsetEnd": 51, "pos": "NN"}, {"index": 29, "word": "。", "originalText": "。", "characterOffsetBegin": 51, "characterOffsetEnd": 52, "pos": "PU"}]} 5 | -------------------------------------------------------------------------------- /sample_data/test.tsv: -------------------------------------------------------------------------------- 1 | 中 S-NR 2 | 美 S-NR 3 | 在 S-P 4 | 沪 S-NR 5 | 签 B-VV 6 | 订 E-VV 7 | 高 S-JJ 8 | 科 B-NN 9 | 技 E-NN 10 | 合 B-NN 11 | 作 E-NN 12 | 协 B-NN 13 | 议 E-NN 14 | 15 | 16 | 新 B-NR 17 | 华 I-NR 18 | 社 E-NR 19 | 上 B-NR 20 | 海 E-NR 21 | 八 B-NT 22 | 月 E-NT 23 | 三 B-NT 24 | 十 I-NT 25 | 一 I-NT 26 | 日 E-NT 27 | 电 S-NN 28 | ( S-PU 29 | 记 B-NN 30 | 者 E-NN 31 | 白 B-NR 32 | 国 I-NR 33 | 良 E-NR 34 | 、 S-PU 35 | 夏 B-NR 36 | 儒 I-NR 37 | 阁 E-NR 38 | ) S-PU 39 | 40 | 41 | “ S-PU 42 | 中 S-NR 43 | 美 S-NR 44 | 合 B-NN 45 | 作 E-NN 46 | 高 S-JJ 47 | 科 B-NN 48 | 技 E-NN 49 | 项 B-NN 50 | 目 E-NN 51 | 签 B-NN 52 | 字 E-NN 53 | 仪 B-NN 54 | 式 E-NN 55 | ” S-PU 56 | 今 B-NT 57 | 天 E-NT 58 | 在 S-P 59 | 上 B-NR 60 | 海 E-NR 61 | 举 B-VV 62 | 行 E-VV 63 | 。 S-PU 64 | 65 | 66 | 上 B-NT 67 | 午 E-NT 68 | 在 S-P 69 | 这 B-PN 70 | 里 E-PN 71 | 签 B-VV 72 | 字 E-VV 73 | 的 S-DEC 74 | 是 S-VC 75 | 知 B-NN 76 | 识 E-NN 77 | 信 B-NN 78 | 息 E-NN 79 | 网 B-NN 80 | 络 E-NN 81 | 通 B-NN 82 | 讯 E-NN 83 | 技 B-NN 84 | 术 E-NN 85 | 和 S-CC 86 | 脱 B-NN 87 | 氧 I-NN 88 | 核 I-NN 89 | 糖 I-NN 90 | 核 I-NN 91 | 酸 E-NN 92 | 生 B-NN 93 | 物 E-NN 94 | 技 B-NN 95 | 术 E-NN 96 | 两 S-CD 97 | 个 S-M 98 | 项 B-NN 99 | 目 E-NN 100 | , S-PU 101 | 同 B-AD 102 | 时 E-AD 103 | 还 S-AD 104 | 签 B-VV 105 | 订 E-VV 106 | 了 S-AS 107 | 语 B-NN 108 | 言 E-NN 109 | 教 B-NN 110 | 学 E-NN 111 | 交 B-NN 112 | 流 E-NN 113 | 合 B-NN 114 | 作 E-NN 115 | 协 B-NN 116 | 议 E-NN 117 | 。 S-PU 118 | 119 | -------------------------------------------------------------------------------- /sample_data/train.tsv: -------------------------------------------------------------------------------- 1 | 上 B-NR 2 | 海 E-NR 3 | 浦 B-NR 4 | 东 E-NR 5 | 开 B-NN 6 | 发 E-NN 7 | 与 S-CC 8 | 法 B-NN 9 | 制 E-NN 10 | 建 B-NN 11 | 设 E-NN 12 | 同 B-VV 13 | 步 E-VV 14 | 15 | 16 | 新 B-NN 17 | 华 I-NN 18 | 社 E-NN 19 | 上 B-NR 20 | 海 E-NR 21 | 二 B-NT 22 | 月 E-NT 23 | 十 B-NT 24 | 日 E-NT 25 | 电 S-NN 26 | ( S-PU 27 | 记 B-NN 28 | 者 E-NN 29 | 谢 B-NR 30 | 金 I-NR 31 | 虎 E-NR 32 | 、 S-PU 33 | 张 B-NR 34 | 持 I-NR 35 | 坚 E-NR 36 | ) S-PU 37 | 38 | 39 | 上 B-NR 40 | 海 E-NR 41 | 浦 B-NR 42 | 东 E-NR 43 | 近 B-NT 44 | 年 E-NT 45 | 来 S-LC 46 | 颁 B-VV 47 | 布 E-VV 48 | 实 B-VV 49 | 行 E-VV 50 | 了 S-AS 51 | 涉 B-VV 52 | 及 E-VV 53 | 经 B-NN 54 | 济 E-NN 55 | 、 S-PU 56 | 贸 B-NN 57 | 易 E-NN 58 | 、 S-PU 59 | 建 B-NN 60 
| 设 E-NN 61 | 、 S-PU 62 | 规 B-NN 63 | 划 E-NN 64 | 、 S-PU 65 | 科 B-NN 66 | 技 E-NN 67 | 、 S-PU 68 | 文 B-NN 69 | 教 E-NN 70 | 等 S-ETC 71 | 领 B-NN 72 | 域 E-NN 73 | 的 S-DEC 74 | 七 B-CD 75 | 十 I-CD 76 | 一 E-CD 77 | 件 S-M 78 | 法 B-NN 79 | 规 I-NN 80 | 性 E-NN 81 | 文 B-NN 82 | 件 E-NN 83 | , S-PU 84 | 确 B-VV 85 | 保 E-VV 86 | 了 S-AS 87 | 浦 B-NR 88 | 东 E-NR 89 | 开 B-NN 90 | 发 E-NN 91 | 的 S-DEG 92 | 有 B-JJ 93 | 序 E-JJ 94 | 进 B-NN 95 | 行 E-NN 96 | 。 S-PU 97 | 98 | 99 | 浦 B-NR 100 | 东 E-NR 101 | 开 B-NN 102 | 发 E-NN 103 | 开 B-NN 104 | 放 E-NN 105 | 是 S-VC 106 | 一 S-CD 107 | 项 S-M 108 | 振 B-VV 109 | 兴 E-VV 110 | 上 B-NR 111 | 海 E-NR 112 | , S-PU 113 | 建 B-VV 114 | 设 E-VV 115 | 现 B-NN 116 | 代 I-NN 117 | 化 E-NN 118 | 经 B-NN 119 | 济 E-NN 120 | 、 S-PU 121 | 贸 B-NN 122 | 易 E-NN 123 | 、 S-PU 124 | 金 B-NN 125 | 融 E-NN 126 | 中 B-NN 127 | 心 E-NN 128 | 的 S-DEC 129 | 跨 B-JJ 130 | 世 I-JJ 131 | 纪 E-JJ 132 | 工 B-NN 133 | 程 E-NN 134 | , S-PU 135 | 因 B-AD 136 | 此 E-AD 137 | 大 B-AD 138 | 量 E-AD 139 | 出 B-VV 140 | 现 E-VV 141 | 的 S-DEC 142 | 是 S-VC 143 | 以 B-NT 144 | 前 E-NT 145 | 不 S-AD 146 | 曾 S-AD 147 | 遇 B-VV 148 | 到 E-VV 149 | 过 S-AS 150 | 的 S-DEC 151 | 新 S-JJ 152 | 情 B-NN 153 | 况 E-NN 154 | 、 S-PU 155 | 新 S-JJ 156 | 问 B-NN 157 | 题 E-NN 158 | 。 S-PU 159 | 160 | 161 | 对 S-P 162 | 此 S-PN 163 | , S-PU 164 | 浦 B-NR 165 | 东 E-NR 166 | 不 S-AD 167 | 是 S-VC 168 | 简 B-VA 169 | 单 E-VA 170 | 的 S-DEV 171 | 采 B-VV 172 | 取 E-VV 173 | “ S-PU 174 | 干 S-VV 175 | 一 S-CD 176 | 段 S-M 177 | 时 B-NN 178 | 间 E-NN 179 | , S-PU 180 | 等 S-P 181 | 积 B-VV 182 | 累 E-VV 183 | 了 S-AS 184 | 经 B-NN 185 | 验 E-NN 186 | 以 B-LC 187 | 后 E-LC 188 | 再 S-AD 189 | 制 B-VV 190 | 定 E-VV 191 | 法 B-NN 192 | 规 E-NN 193 | 条 B-NN 194 | 例 E-NN 195 | ” S-PU 196 | 的 S-DEC 197 | 做 B-NN 198 | 法 E-NN 199 | , S-PU 200 | 而 S-CC 201 | 是 S-VC 202 | 借 B-VV 203 | 鉴 E-VV 204 | 发 B-JJ 205 | 达 E-JJ 206 | 国 B-NN 207 | 家 E-NN 208 | 和 S-CC 209 | 深 B-NR 210 | 圳 E-NR 211 | 等 S-ETC 212 | 特 B-NN 213 | 区 E-NN 214 | 的 S-DEG 215 | 经 B-NN 216 | 验 E-NN 217 | 教 B-NN 218 | 训 E-NN 219 | , S-PU 220 | 聘 B-VV 221 | 请 E-VV 222 | 国 B-NN 223 | 内 I-NN 224 | 外 E-NN 225 | 有 B-JJ 226 | 关 E-JJ 227 | 专 B-NN 228 | 家 E-NN 229 | 学 B-NN 230 | 者 E-NN 231 | , S-PU 232 | 积 B-AD 233 | 极 E-AD 234 | 、 S-PU 235 | 及 B-AD 236 | 时 E-AD 237 | 地 S-DEV 238 | 制 B-VV 239 | 定 E-VV 240 | 和 S-CC 241 | 推 B-VV 242 | 出 E-VV 243 | 法 B-NN 244 | 规 I-NN 245 | 性 E-NN 246 | 文 B-NN 247 | 件 E-NN 248 | , S-PU 249 | 使 S-VV 250 | 这 B-DT 251 | 些 E-DT 252 | 经 B-NN 253 | 济 E-NN 254 | 活 B-NN 255 | 动 E-NN 256 | 一 S-AD 257 | 出 B-VV 258 | 现 E-VV 259 | 就 S-AD 260 | 被 S-SB 261 | 纳 B-VV 262 | 入 E-VV 263 | 法 B-NN 264 | 制 E-NN 265 | 轨 B-NN 266 | 道 E-NN 267 | 。 S-PU 268 | 269 | 270 | 去 B-NT 271 | 年 E-NT 272 | 初 S-LC 273 | 浦 B-NR 274 | 东 E-NR 275 | 新 B-NN 276 | 区 E-NN 277 | 诞 B-VV 278 | 生 E-VV 279 | 的 S-DEC 280 | 中 B-NR 281 | 国 E-NR 282 | 第 B-OD 283 | 一 E-OD 284 | 家 S-M 285 | 医 B-NN 286 | 疗 E-NN 287 | 机 B-NN 288 | 构 E-NN 289 | 药 B-NN 290 | 品 E-NN 291 | 采 B-NN 292 | 购 E-NN 293 | 服 B-NN 294 | 务 E-NN 295 | 中 B-NN 296 | 心 E-NN 297 | , S-PU 298 | 正 S-AD 299 | 因 B-P 300 | 为 E-P 301 | 一 S-AD 302 | 开 B-VV 303 | 始 E-VV 304 | 就 S-AD 305 | 比 B-AD 306 | 较 E-AD 307 | 规 B-VA 308 | 范 E-VA 309 | , S-PU 310 | 运 B-VV 311 | 转 E-VV 312 | 至 B-VV 313 | 今 E-VV 314 | , S-PU 315 | 成 B-VV 316 | 交 E-VV 317 | 药 B-NN 318 | 品 E-NN 319 | 一 B-CD 320 | 亿 I-CD 321 | 多 E-CD 322 | 元 S-M 323 | , S-PU 324 | 没 B-VV 325 | 有 E-VV 326 | 发 B-VV 327 | 现 E-VV 328 | 一 S-CD 329 | 例 S-M 330 | 回 B-NN 331 | 扣 E-NN 332 | 。 S-PU 333 | 334 | 335 | 建 B-NN 336 | 筑 E-NN 337 | 是 S-VC 338 | 开 B-VV 339 | 发 E-VV 340 | 浦 
B-NR 341 | 东 E-NR 342 | 的 S-DEC 343 | 一 S-CD 344 | 项 S-M 345 | 主 B-JJ 346 | 要 E-JJ 347 | 经 B-NN 348 | 济 E-NN 349 | 活 B-NN 350 | 动 E-NN 351 | , S-PU 352 | 这 B-DT 353 | 些 E-DT 354 | 年 S-M 355 | 有 S-VE 356 | 数 B-CD 357 | 百 E-CD 358 | 家 S-M 359 | 建 B-NN 360 | 筑 E-NN 361 | 公 B-NN 362 | 司 E-NN 363 | 、 S-PU 364 | 四 B-CD 365 | 千 I-CD 366 | 余 E-CD 367 | 个 S-M 368 | 建 B-NN 369 | 筑 E-NN 370 | 工 B-NN 371 | 地 E-NN 372 | 遍 B-VV 373 | 布 E-VV 374 | 在 S-P 375 | 这 S-DT 376 | 片 S-M 377 | 热 B-NN 378 | 土 E-NN 379 | 上 S-LC 380 | 。 S-PU 381 | 382 | 383 | 为 S-P 384 | 规 B-VV 385 | 范 E-VV 386 | 建 B-NN 387 | 筑 E-NN 388 | 行 B-NN 389 | 为 E-NN 390 | , S-PU 391 | 防 B-VV 392 | 止 E-VV 393 | 出 B-VV 394 | 现 E-VV 395 | 无 B-JJ 396 | 序 E-JJ 397 | 现 B-NN 398 | 象 E-NN 399 | , S-PU 400 | 新 B-NN 401 | 区 E-NN 402 | 管 B-NN 403 | 委 I-NN 404 | 会 E-NN 405 | 根 B-P 406 | 据 E-P 407 | 国 B-NN 408 | 家 E-NN 409 | 和 S-CC 410 | 上 B-NR 411 | 海 I-NR 412 | 市 E-NR 413 | 的 S-DEG 414 | 有 B-JJ 415 | 关 E-JJ 416 | 规 B-NN 417 | 定 E-NN 418 | , S-PU 419 | 结 B-VV 420 | 合 E-VV 421 | 浦 B-NR 422 | 东 E-NR 423 | 开 B-NN 424 | 发 E-NN 425 | 实 B-NN 426 | 际 E-NN 427 | , S-PU 428 | 及 B-AD 429 | 时 E-AD 430 | 出 B-VV 431 | 台 E-VV 432 | 了 S-AS 433 | 一 S-CD 434 | 系 B-M 435 | 列 E-M 436 | 规 B-VV 437 | 范 E-VV 438 | 建 B-NN 439 | 设 E-NN 440 | 市 B-NN 441 | 场 E-NN 442 | 的 S-DEC 443 | 文 B-NN 444 | 件 E-NN 445 | , S-PU 446 | 其 B-NN 447 | 中 E-NN 448 | 包 B-VV 449 | 括 E-VV 450 | 工 B-NN 451 | 程 E-NN 452 | 施 B-NN 453 | 工 E-NN 454 | 招 B-NN 455 | 投 I-NN 456 | 标 E-NN 457 | 管 B-NN 458 | 理 E-NN 459 | 办 B-NN 460 | 法 E-NN 461 | 、 S-PU 462 | 拆 B-NN 463 | 迁 E-NN 464 | 工 B-NN 465 | 作 E-NN 466 | 若 B-CD 467 | 干 E-CD 468 | 规 B-NN 469 | 定 E-NN 470 | 、 S-PU 471 | 整 B-VV 472 | 治 E-VV 473 | 违 B-JJ 474 | 章 E-JJ 475 | 建 B-NN 476 | 筑 E-NN 477 | 实 B-NN 478 | 施 E-NN 479 | 办 B-NN 480 | 法 E-NN 481 | 、 S-PU 482 | 通 B-NN 483 | 信 E-NN 484 | 设 B-NN 485 | 施 E-NN 486 | 及 S-CC 487 | 管 B-NN 488 | 线 E-NN 489 | 配 B-NN 490 | 套 E-NN 491 | 建 B-NN 492 | 设 E-NN 493 | 意 B-NN 494 | 见 E-NN 495 | 、 S-PU 496 | 建 B-NN 497 | 设 E-NN 498 | 工 B-NN 499 | 地 E-NN 500 | 施 B-NN 501 | 工 E-NN 502 | 环 B-NN 503 | 境 E-NN 504 | 管 B-NN 505 | 理 E-NN 506 | 暂 B-JJ 507 | 行 E-JJ 508 | 办 B-NN 509 | 法 E-NN 510 | 等 S-ETC 511 | , S-PU 512 | 基 B-AD 513 | 本 E-AD 514 | 做 B-VV 515 | 到 E-VV 516 | 了 S-AS 517 | 每 S-DT 518 | 个 S-M 519 | 环 B-NN 520 | 节 E-NN 521 | 都 S-AD 522 | 有 S-VE 523 | 明 B-VA 524 | 确 E-VA 525 | 而 S-CC 526 | 又 S-AD 527 | 具 B-VA 528 | 体 E-VA 529 | 的 S-DEC 530 | 规 B-NN 531 | 定 E-NN 532 | 。 S-PU 533 | 534 | 535 | 建 B-NN 536 | 筑 E-NN 537 | 公 B-NN 538 | 司 E-NN 539 | 进 S-VV 540 | 区 S-NN 541 | , S-PU 542 | 有 B-JJ 543 | 关 E-JJ 544 | 部 B-NN 545 | 门 E-NN 546 | 先 S-AD 547 | 送 B-VV 548 | 上 E-VV 549 | 这 B-DT 550 | 些 E-DT 551 | 法 B-NN 552 | 规 I-NN 553 | 性 E-NN 554 | 文 B-NN 555 | 件 E-NN 556 | , S-PU 557 | 然 B-AD 558 | 后 E-AD 559 | 有 S-VE 560 | 专 B-JJ 561 | 门 E-JJ 562 | 队 B-NN 563 | 伍 E-NN 564 | 进 B-VV 565 | 行 E-VV 566 | 监 B-NN 567 | 督 E-NN 568 | 检 B-NN 569 | 查 E-NN 570 | 。 S-PU 571 | 572 | 573 | 尽 B-CS 574 | 管 E-CS 575 | 浦 B-NR 576 | 东 E-NR 577 | 新 B-NN 578 | 区 E-NN 579 | 制 B-VV 580 | 定 E-VV 581 | 的 S-DEC 582 | 法 B-NN 583 | 规 I-NN 584 | 性 E-NN 585 | 文 B-NN 586 | 件 E-NN 587 | 有 B-PN 588 | 些 E-PN 589 | 比 B-AD 590 | 较 E-AD 591 | “ S-PU 592 | 粗 S-VA 593 | ” S-PU 594 | , S-PU 595 | 有 B-PN 596 | 些 E-PN 597 | 还 S-AD 598 | 只 S-AD 599 | 是 S-VC 600 | 暂 B-JJ 601 | 行 E-JJ 602 | 规 B-NN 603 | 定 E-NN 604 | , S-PU 605 | 有 B-VV 606 | 待 E-VV 607 | 在 S-P 608 | 实 B-NN 609 | 践 E-NN 610 | 中 S-LC 611 | 逐 B-AD 612 | 步 E-AD 613 | 完 B-VV 614 | 善 E-VV 615 | , S-PU 616 | 但 S-AD 617 | 这 
S-DT 618 | 种 S-M 619 | 法 B-NN 620 | 制 E-NN 621 | 紧 B-VV 622 | 跟 E-VV 623 | 经 B-NN 624 | 济 E-NN 625 | 和 S-CC 626 | 社 B-NN 627 | 会 E-NN 628 | 活 B-NN 629 | 动 E-NN 630 | 的 S-DEC 631 | 做 B-NN 632 | 法 E-NN 633 | , S-PU 634 | 受 B-VV 635 | 到 E-VV 636 | 了 S-AS 637 | 国 B-NN 638 | 内 I-NN 639 | 外 E-NN 640 | 投 B-NN 641 | 资 I-NN 642 | 者 E-NN 643 | 的 S-DEG 644 | 好 S-JJ 645 | 评 S-NN 646 | , S-PU 647 | 他 B-PN 648 | 们 E-PN 649 | 认 B-VV 650 | 为 E-VV 651 | , S-PU 652 | 到 S-VV 653 | 浦 B-NR 654 | 东 E-NR 655 | 新 B-NN 656 | 区 E-NN 657 | 投 B-VV 658 | 资 E-VV 659 | 办 B-NN 660 | 事 E-NN 661 | 有 S-VE 662 | 章 B-NN 663 | 法 E-NN 664 | , S-PU 665 | 讲 S-VV 666 | 规 B-NN 667 | 矩 E-NN 668 | , S-PU 669 | 利 B-NN 670 | 益 E-NN 671 | 能 S-VV 672 | 得 B-VV 673 | 到 E-VV 674 | 保 B-NN 675 | 障 E-NN 676 | 。 S-PU 677 | 678 | 679 | 外 B-NN 680 | 商 E-NN 681 | 投 B-NN 682 | 资 E-NN 683 | 企 B-NN 684 | 业 E-NN 685 | 成 B-VV 686 | 为 E-VV 687 | 中 B-NR 688 | 国 E-NR 689 | 外 B-NN 690 | 贸 E-NN 691 | 重 B-JJ 692 | 要 E-JJ 693 | 增 B-NN 694 | 长 I-NN 695 | 点 E-NN 696 | 697 | 698 | 新 B-NN 699 | 华 I-NN 700 | 社 E-NN 701 | 北 B-NR 702 | 京 E-NR 703 | 二 B-NT 704 | 月 E-NT 705 | 十 B-NT 706 | 一 I-NT 707 | 日 E-NT 708 | 电 S-NN 709 | ( S-PU 710 | 记 B-NN 711 | 者 E-NN 712 | 唐 B-NR 713 | 虹 E-NR 714 | ) S-PU 715 | 716 | 717 | 海 B-NN 718 | 关 E-NN 719 | 统 B-NN 720 | 计 E-NN 721 | 表 B-VV 722 | 明 E-VV 723 | , S-PU 724 | “ S-PU 725 | 八 B-NT 726 | 五 E-NT 727 | ” S-PU 728 | 期 B-NN 729 | 间 E-NN 730 | ( S-PU 731 | 一 B-NT 732 | 九 I-NT 733 | 九 I-NT 734 | 0 I-NT 735 | 年 E-NT 736 | ― S-PU 737 | 一 B-NT 738 | 九 I-NT 739 | 九 I-NT 740 | 五 I-NT 741 | 年 E-NT 742 | ) S-PU 743 | , S-PU 744 | 中 B-NR 745 | 国 E-NR 746 | 外 B-NN 747 | 商 E-NN 748 | 投 B-NN 749 | 资 E-NN 750 | 企 B-NN 751 | 业 E-NN 752 | 的 S-DEG 753 | 进 B-NN 754 | 出 I-NN 755 | 口 E-NN 756 | 呈 S-VV 757 | 直 B-AD 758 | 线 E-AD 759 | 上 B-VV 760 | 升 E-VV 761 | 之 S-DEC 762 | 势 S-NN 763 | , S-PU 764 | 出 B-NN 765 | 口 E-NN 766 | 年 B-AD 767 | 均 E-AD 768 | 增 B-VV 769 | 长 E-VV 770 | 百 B-CD 771 | 分 I-CD 772 | 之 I-CD 773 | 四 I-CD 774 | 十 I-CD 775 | 三 I-CD 776 | 点 I-CD 777 | 二 E-CD 778 | , S-PU 779 | 进 B-NN 780 | 口 E-NN 781 | 年 B-AD 782 | 均 E-AD 783 | 增 B-VV 784 | 长 E-VV 785 | 百 B-CD 786 | 分 I-CD 787 | 之 I-CD 788 | 三 I-CD 789 | 十 I-CD 790 | 八 I-CD 791 | 点 I-CD 792 | 六 E-CD 793 | 。 S-PU 794 | 795 | 796 | 去 B-NT 797 | 年 E-NT 798 | 实 B-VV 799 | 现 E-VV 800 | 进 B-NN 801 | 出 I-NN 802 | 口 E-NN 803 | 总 B-NN 804 | 值 E-NN 805 | 达 S-VV 806 | 一 B-CD 807 | 千 I-CD 808 | 零 I-CD 809 | 九 I-CD 810 | 十 I-CD 811 | 八 I-CD 812 | 点 I-CD 813 | 二 I-CD 814 | 亿 E-CD 815 | 美 B-M 816 | 元 E-M 817 | , S-PU 818 | 占 S-VV 819 | 全 S-DT 820 | 国 S-NN 821 | 进 B-NN 822 | 出 I-NN 823 | 口 E-NN 824 | 总 B-NN 825 | 值 E-NN 826 | 的 S-DEC 827 | 比 B-NN 828 | 重 E-NN 829 | 由 S-P 830 | 上 S-DT 831 | 年 S-M 832 | 的 S-DEG 833 | 百 B-CD 834 | 分 I-CD 835 | 之 I-CD 836 | 三 I-CD 837 | 十 I-CD 838 | 七 E-CD 839 | 提 B-VV 840 | 高 E-VV 841 | 到 S-VV 842 | 百 B-CD 843 | 分 I-CD 844 | 之 I-CD 845 | 三 I-CD 846 | 十 I-CD 847 | 九 E-CD 848 | 。 S-PU 849 | 850 | 851 | 外 B-NN 852 | 商 E-NN 853 | 投 B-NN 854 | 资 E-NN 855 | 企 B-NN 856 | 业 E-NN 857 | 在 S-P 858 | 改 B-VV 859 | 善 E-VV 860 | 中 B-NR 861 | 国 E-NR 862 | 出 B-NN 863 | 口 E-NN 864 | 商 B-NN 865 | 品 E-NN 866 | 结 B-NN 867 | 构 E-NN 868 | 中 S-LC 869 | 发 B-VV 870 | 挥 E-VV 871 | 了 S-AS 872 | 显 B-JJ 873 | 著 E-JJ 874 | 作 B-NN 875 | 用 E-NN 876 | 。 S-PU 877 | 878 | 879 | 去 B-NT 880 | 年 E-NT 881 | 外 B-NN 882 | 商 E-NN 883 | 投 B-NN 884 | 资 E-NN 885 | 企 B-NN 886 | 业 E-NN 887 | 出 B-NN 888 | 口 E-NN 889 | 商 B-NN 890 | 品 E-NN 891 | 中 S-LC 892 | , S-PU 893 | 工 B-NN 894 | 业 E-NN 895 | 制 B-NN 896 | 成 I-NN 897 | 
品 E-NN 898 | 占 S-VV 899 | 九 B-CD 900 | 成 E-CD 901 | 以 B-LC 902 | 上 E-LC 903 | , S-PU 904 | 达 S-VV 905 | 四 B-CD 906 | 百 I-CD 907 | 三 I-CD 908 | 十 I-CD 909 | 八 I-CD 910 | 点 I-CD 911 | 八 I-CD 912 | 亿 E-CD 913 | 美 B-M 914 | 元 E-M 915 | , S-PU 916 | 比 S-P 917 | 上 S-DT 918 | 年 S-M 919 | 增 B-VV 920 | 长 E-VV 921 | 了 S-AS 922 | 百 B-CD 923 | 分 I-CD 924 | 之 I-CD 925 | 三 I-CD 926 | 十 I-CD 927 | 六 I-CD 928 | 点 I-CD 929 | 七 E-CD 930 | , S-PU 931 | 明 B-AD 932 | 显 E-AD 933 | 高 B-VV 934 | 于 E-VV 935 | 全 S-DT 936 | 国 S-NN 937 | 平 B-JJ 938 | 均 E-JJ 939 | 水 B-NN 940 | 平 E-NN 941 | 。 S-PU 942 | 943 | -------------------------------------------------------------------------------- /twasp_eval.py: -------------------------------------------------------------------------------- 1 | from seqeval.metrics import f1_score, precision_score, recall_score 2 | 3 | 4 | def eval_sentence(y_pred, y, sentence, word2id): 5 | words = sentence.split(' ') 6 | 7 | if y is not None: 8 | seg_true = [] 9 | word_true = '' 10 | y_word = [] 11 | y_pos = [] 12 | for y_label in y: 13 | y_word.append(y_label[0]) 14 | y_pos.append(y_label[2:]) 15 | 16 | for i in range(len(y_word)): 17 | word_true += words[i] 18 | if y_word[i] in ['S', 'E']: 19 | pos_tag_true = y_pos[i] 20 | word_pos_true = word_true + '_' + pos_tag_true 21 | if word_true not in word2id: 22 | word_pos_true = '*' + word_pos_true + '*' 23 | seg_true.append(word_pos_true) 24 | word_true = '' 25 | 26 | seg_true_str = ' '.join(seg_true) 27 | else: 28 | seg_true_str = None 29 | 30 | seg_pred = [] 31 | word_pred = '' 32 | 33 | y_pred_word = [] 34 | y_pred_pos = [] 35 | for y_pred_label in y_pred: 36 | y_pred_word.append(y_pred_label[0]) 37 | y_pred_pos.append(y_pred_label[2:]) 38 | 39 | for i in range(len(y_pred_word)): 40 | word_pred += words[i] 41 | if y_pred_word[i] in ['S', 'E']: 42 | pos_tag_pred = y_pred_pos[i] 43 | word_pos_pred = word_pred + '_' + pos_tag_pred 44 | seg_pred.append(word_pos_pred) 45 | word_pred = '' 46 | 47 | seg_pred_str = ' '.join(seg_pred) 48 | return seg_true_str, seg_pred_str 49 | 50 | 51 | def input2file(save_path): 52 | s = input("Please input demo sentence: \n") 53 | f = open(save_path, "w") 54 | for item in s: 55 | I = 'O' 56 | f.write(item + '' + I + '\n') 57 | f.write('\n ') 58 | 59 | 60 | def pos_evaluate_word_PRF(y_pred, y): 61 | #dict = {'E': 2, 'S': 3, 'B':0, 'I':1} 62 | y_word = [] 63 | y_pos = [] 64 | y_pred_word = [] 65 | y_pred_pos = [] 66 | for y_label, y_pred_label in zip(y, y_pred): 67 | y_word.append(y_label[0]) 68 | y_pos.append(y_label[2:]) 69 | y_pred_word.append(y_pred_label[0]) 70 | y_pred_pos.append(y_pred_label[2:]) 71 | 72 | word_cor_num = 0 73 | pos_cor_num = 0 74 | yp_wordnum = y_pred_word.count('E')+y_pred_word.count('S') 75 | yt_wordnum = y_word.count('E')+y_word.count('S') 76 | start = 0 77 | for i in range(len(y_word)): 78 | if y_word[i] == 'E' or y_word[i] == 'S': 79 | word_flag = True 80 | pos_flag = True 81 | for j in range(start, i+1): 82 | if y_word[j] != y_pred_word[j]: 83 | word_flag = False 84 | pos_flag = False 85 | break 86 | if y_pos[j] != y_pred_pos[j]: 87 | pos_flag = False 88 | if word_flag: 89 | word_cor_num += 1 90 | if pos_flag: 91 | pos_cor_num += 1 92 | start = i+1 93 | 94 | wP = word_cor_num / float(yp_wordnum) if yp_wordnum > 0 else -1 95 | wR = word_cor_num / float(yt_wordnum) if yt_wordnum > 0 else -1 96 | wF = 2 * wP * wR / (wP + wR) 97 | 98 | # pP = pos_cor_num / float(yp_wordnum) if yp_wordnum > 0 else -1 99 | # pR = pos_cor_num / float(yt_wordnum) if yt_wordnum > 0 else -1 100 | # pF = 2 * pP * 
pR / (pP + pR) 101 | 102 | pP = precision_score(y, y_pred) 103 | pR = recall_score(y, y_pred) 104 | pF = f1_score(y, y_pred) 105 | 106 | return (wP, wR, wF), (pP, pR, pF) 107 | 108 | 109 | def pos_evaluate_OOV(y_pred_list, y_list, sentence_list, word2id): 110 | word_cor_num = 0 111 | pos_cor_num = 0 112 | yt_wordnum = 0 113 | 114 | y_word_list = [] 115 | y_pos_list = [] 116 | y_pred_word_list = [] 117 | y_pred_pos_list = [] 118 | for y_label, y_pred_label in zip(y_list, y_pred_list): 119 | y_word = [] 120 | y_pos = [] 121 | y_pred_word = [] 122 | y_pred_pos = [] 123 | for y_l in y_label: 124 | y_word.append(y_l[0]) 125 | y_pos.append(y_l[2:]) 126 | for y_pred_l in y_pred_label: 127 | y_pred_word.append(y_pred_l[0]) 128 | y_pred_pos.append(y_pred_l[2:]) 129 | y_word_list.append(y_word) 130 | y_pos_list.append(y_pos) 131 | y_pred_word_list.append(y_pred_word) 132 | y_pred_pos_list.append(y_pred_pos) 133 | 134 | for y_w, y_p, y_p_w, y_p_p, sentence in zip(y_word_list, y_pos_list, y_pred_word_list, y_pred_pos_list, sentence_list): 135 | start = 0 136 | for i in range(len(y_w)): 137 | if y_w[i] == 'E' or y_w[i] == 'S': 138 | word = ''.join(sentence[start:i+1]) 139 | if word in word2id: 140 | start = i + 1 141 | continue 142 | word_flag = True 143 | pos_flag = True 144 | yt_wordnum += 1 145 | for j in range(start, i+1): 146 | if y_w[j] != y_p_w[j]: 147 | word_flag = False 148 | pos_flag = False 149 | break 150 | if y_p[j] != y_p_p[j]: 151 | pos_flag = False 152 | if word_flag: 153 | word_cor_num += 1 154 | if pos_flag: 155 | pos_cor_num += 1 156 | start = i + 1 157 | 158 | word_OOV = word_cor_num / float(yt_wordnum) if yt_wordnum > 0 else -1 159 | pos_OOV = pos_cor_num / float(yt_wordnum) if yt_wordnum > 0 else -1 160 | 161 | return word_OOV, pos_OOV 162 | -------------------------------------------------------------------------------- /updates.md: -------------------------------------------------------------------------------- 1 | # Important Updates 2 | 3 | * July 14, 2020: Implemented the `predict` function in `twasp_main.py`. You can use that function to segment and tag the sentences in an input file with a pre-trained TwASP model. See [run_sample.sh](./run_sample.sh) for the usage, and [./sample_data/sentence.txt](./sample_data/sentence.txt) for the input format. If you run pre-trained TwASP models using Stanford CoreNLP Toolkit v3.9.2 or Berkeley Neural Parser, you need to download these toolkits before running. See [data_preprocessing](./data_preprocessing) for more information on installing the toolkits. 4 | * July 7, 2020: Released the [pre-trained TwASP models](./models). 5 | --------------------------------------------------------------------------------
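
For reference, below is a minimal usage sketch (not part of the repository) of how the evaluation helpers in `twasp_eval.py` above might be called on BIES-style joint segmentation/POS labels such as those in `sample_data/*.tsv`. The toy label lists, the example sentence, and the `word2id` vocabulary are invented for illustration.

```python
# A hypothetical, self-contained example; the labels and word2id below are toy values.
from twasp_eval import eval_sentence, pos_evaluate_word_PRF

# One label per character for the sentence "中 美 在 沪 签 订"
# (characters are space-separated, as expected by eval_sentence).
y_true = ['S-NR', 'S-NR', 'S-P', 'S-NR', 'B-VV', 'E-VV']
y_pred = ['S-NR', 'S-NR', 'S-P', 'S-NR', 'B-VV', 'E-VV']

# Word-level and POS-level precision/recall/F1 over the flat label sequences.
(wP, wR, wF), (pP, pR, pF) = pos_evaluate_word_PRF(y_pred, y_true)
print('word F1: %.4f, POS F1: %.4f' % (wF, pF))

# Render gold and predicted output as "word_POS" strings; gold words missing
# from word2id are wrapped in '*' so OOV items are easy to spot.
word2id = {'中': 0, '美': 1, '在': 2, '沪': 3}  # toy vocabulary
sentence = '中 美 在 沪 签 订'
gold_str, pred_str = eval_sentence(y_pred, y_true, sentence, word2id)
print(gold_str)  # 中_NR 美_NR 在_P 沪_NR *签订_VV*
print(pred_str)  # 中_NR 美_NR 在_P 沪_NR 签订_VV
```

Note that `pos_evaluate_word_PRF` scores word boundaries directly and delegates the POS-level scores to `seqeval`, while `pos_evaluate_OOV` applies the same word-by-word check but only to gold words that are absent from `word2id`.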