├── .gitignore ├── README.md ├── dataset_dstc2.py ├── dataset_sim.py ├── eval_pred.sh ├── main.py ├── metric_bert_dst.py ├── storage ├── dstc2-clean.zip └── woz_2.0.zip ├── train.sh └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python 3 | # Edit at https://www.gitignore.io/?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # End of https://www.gitignore.io/api/python 131 | 132 | 133 | # Created by https://www.gitignore.io/api/pycharm+all 134 | # Edit at https://www.gitignore.io/?templates=pycharm+all 135 | 136 | ### PyCharm+all ### 137 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 138 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 139 | 140 | # User-specific stuff 141 | .idea/**/workspace.xml 142 | .idea/**/tasks.xml 143 | .idea/**/usage.statistics.xml 144 | .idea/**/dictionaries 145 | .idea/**/shelf 146 | 147 | # Generated files 148 | .idea/**/contentModel.xml 149 | 150 | # Sensitive or high-churn files 151 | .idea/**/dataSources/ 152 | .idea/**/dataSources.ids 153 | .idea/**/dataSources.local.xml 154 | .idea/**/sqlDataSources.xml 155 | .idea/**/dynamic.xml 156 | .idea/**/uiDesigner.xml 157 | .idea/**/dbnavigator.xml 158 | 159 | # Gradle 160 | .idea/**/gradle.xml 161 | .idea/**/libraries 162 | 163 | # Gradle and Maven with auto-import 164 | # When using Gradle or Maven with auto-import, you should exclude module files, 165 | # since they will be recreated, and may cause churn. Uncomment if using 166 | # auto-import. 167 | # .idea/modules.xml 168 | # .idea/*.iml 169 | # .idea/modules 170 | 171 | # CMake 172 | cmake-build-*/ 173 | 174 | # Mongo Explorer plugin 175 | .idea/**/mongoSettings.xml 176 | 177 | # File-based project format 178 | *.iws 179 | 180 | # IntelliJ 181 | out/ 182 | 183 | # mpeltonen/sbt-idea plugin 184 | .idea_modules/ 185 | 186 | # JIRA plugin 187 | atlassian-ide-plugin.xml 188 | 189 | # Cursive Clojure plugin 190 | .idea/replstate.xml 191 | 192 | # Crashlytics plugin (for Android Studio and IntelliJ) 193 | com_crashlytics_export_strings.xml 194 | crashlytics.properties 195 | crashlytics-build.properties 196 | fabric.properties 197 | 198 | # Editor-based Rest Client 199 | .idea/httpRequests 200 | 201 | # Android studio 3.1+ serialized cache file 202 | .idea/caches/build_file_checksums.ser 203 | 204 | # JetBrains templates 205 | **___jb_tmp___ 206 | 207 | ### PyCharm+all Patch ### 208 | # Ignores the whole .idea folder and all .iml files 209 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 210 | 211 | .idea/ 212 | 213 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 214 | 215 | *.iml 216 | modules.xml 217 | .idea/misc.xml 218 | *.ipr 219 | 220 | # Sonarlint plugin 221 | .idea/sonarlint 222 | 223 | # End of https://www.gitignore.io/api/pycharm+all -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERT-DST 2 | 3 | Contact: Guan-Lin Chao (guanlinchao@cmu.edu) 4 | 5 | Source code of our paper [BERT-DST: Scalable End-to-End Dialogue State Tracking with Bidirectional Encoder Representations from Transformer](https://arxiv.org/abs/1907.03040) (Interspeech 2019). 6 | ``` 7 | @inproceedings{chao2019bert, 8 | title={{BERT-DST}: Scalable End-to-End Dialogue State Tracking with Bidirectional Encoder Representations from Transformer}, 9 | author={Chao, Guan-Lin and Lane, Ian}, 10 | booktitle={INTERSPEECH}, 11 | year={2019} 12 | } 13 | ``` 14 | 15 | Tested on Python 3.6, Tensorflow==1.13.0rc0 16 | 17 | ## Required packages (no need to install, just provide the paths in code): 18 | 1. [bert](https://github.com/google-research/bert) 19 | 2. uncased_L-12_H-768_A-12: pretrained [BERT-Base, Uncased] model checkpoint. Download link in [bert](https://github.com/google-research/bert). 20 | 21 | ## Datasets: 22 | [dstc2-clean](https://github.com/guanlinchao/bert-dst/blob/master/storage/dstc2-clean.zip), [woz_2.0](https://github.com/guanlinchao/bert-dst/blob/master/storage/woz_2.0.zip), [sim-M and sim-R](https://github.com/google-research-datasets/simulated-dialogue) 23 | -------------------------------------------------------------------------------- /dataset_dstc2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import sys 4 | 5 | import tensorflow as tf 6 | 7 | import util 8 | 9 | # Directory of bert, cloned from github.com/google-research/bert 10 | sys.path.append("/path/to/bert") 11 | import tokenization 12 | 13 | 14 | SEMANTIC_DICT = { 15 | 'center': ['centre', 'downtown', 'central', 'down town', 'middle'], 16 | 'south': ['southern', 'southside'], 17 | 'north': ['northern', 'uptown', 'northside'], 18 | 'west': ['western', 'westside'], 19 | 'east': ['eastern', 'eastside'], 20 | 'east side': ['eastern', 'eastside'], 21 | 22 | 'cheap': ['low price', 'inexpensive', 'cheaper', 'low priced', 'affordable', 23 | 'nothing too expensive', 'without costing a fortune', 'cheapest', 24 | 'good deals', 'low prices', 'afford', 'on a budget', 'fair prices', 25 | 'less expensive', 'cheapeast', 'not cost an arm and a leg'], 26 | 'moderate': ['moderately', 'medium priced', 'medium price', 'fair price', 27 | 'fair prices', 'reasonable', 'reasonably priced', 'mid price', 28 | 'fairly priced', 'not outrageous','not too expensive', 29 | 'on a budget', 'mid range', 'reasonable priced', 'less expensive', 30 | 'not too pricey', 'nothing too expensive', 'nothing cheap', 31 | 'not overpriced', 'medium', 'inexpensive'], 32 | 'expensive': ['high priced', 'high end', 'high class', 'high quality', 33 | 'fancy', 'upscale', 'nice', 'fine dining', 'expensively priced'], 34 | 35 | 'afghan': ['afghanistan'], 36 | 'african': ['africa'], 37 | 'asian oriental': ['asian', 'oriental'], 38 | 'australasian': ['australian asian', 'austral asian'], 39 | 'australian': ['aussie'], 40 | 'barbeque': ['barbecue', 'bbq'], 41 | 'basque': ['bask'], 42 | 'belgian': ['belgium'], 43 | 'british': ['cotto'], 44 | 'canapes': ['canopy', 'canape', 'canap'], 45 | 'catalan': ['catalonian'], 46 | 'corsican': ['corsica'], 47 | 'crossover': ['cross over', 'over'], 48 | 'gastropub': ['gastro pub', 'gastro', 'gastropubs'], 49 | 'hungarian': ['goulash'], 50 | 'indian': ['india', 'indians', 'nirala'], 51 | 'international': ['all types of food'], 52 | 'italian': ['prezzo'], 53 | 'jamaican': ['jamaica'], 54 | 'japanese': ['sushi', 'beni hana'], 55 | 'korean': ['korea'], 56 | 'lebanese': ['lebanse'], 57 | 'north american': ['american', 'hamburger'], 58 | 'portuguese': ['portugese'], 59 | 'seafood': ['sea food', 'shellfish', 'fish'], 60 | 'singaporean': ['singapore'], 61 | 'steakhouse': ['steak house', 'steak'], 62 | 'thai': ['thailand', 'bangkok'], 63 | 'traditional': ['old fashioned', 'plain'], 64 | 'turkish': ['turkey'], 65 | 'unusual': ['unique and strange'], 66 | 'venetian': ['vanessa'], 67 | 'vietnamese': ['vietnam', 'thanh binh'], 68 | } 69 | 70 | FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number'} 71 | 72 | 73 | def get_token_pos(tok_list, label): 74 | find_pos = [] 75 | found = False 76 | label_list = [item for item in map(str.strip, re.split("(\W+)", label)) if len(item) > 0] 77 | len_label = len(label_list) 78 | for i in range(len(tok_list) + 1 - len_label): 79 | if tok_list[i:i+len_label] == label_list: 80 | find_pos.append((i,i+len_label)) # start, exclusive_end 81 | found = True 82 | return found, find_pos 83 | 84 | 85 | def check_label_existence(label, usr_utt_tok, sys_utt_tok): 86 | in_usr, usr_pos = get_token_pos(usr_utt_tok, label) 87 | in_sys, sys_pos = get_token_pos(sys_utt_tok, label) 88 | 89 | if not in_usr and not in_sys and label in SEMANTIC_DICT: 90 | for tmp_label in SEMANTIC_DICT[label]: 91 | in_usr, usr_pos = get_token_pos(usr_utt_tok, tmp_label) 92 | in_sys, sys_pos = get_token_pos(sys_utt_tok, tmp_label) 93 | if in_usr or in_sys: 94 | label = tmp_label 95 | break 96 | return label, in_usr, usr_pos, in_sys, sys_pos 97 | 98 | 99 | def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence): 100 | sys_utt_tok_label = [0 for _ in sys_utt_tok] 101 | usr_utt_tok_label = [0 for _ in usr_utt_tok] 102 | if label == 'none' or label == 'dontcare': 103 | class_type = label 104 | else: 105 | label, in_usr, usr_pos, in_sys, sys_pos = check_label_existence(label, usr_utt_tok, sys_utt_tok) 106 | if in_usr or in_sys: 107 | class_type = 'copy_value' 108 | if slot_last_occurrence: 109 | if in_usr: 110 | (s, e) = usr_pos[-1] 111 | for i in range(s, e): 112 | usr_utt_tok_label[i] = 1 113 | else: 114 | (s, e) = sys_pos[-1] 115 | for i in range(s, e): 116 | sys_utt_tok_label[i] = 1 117 | else: 118 | for (s, e) in usr_pos: 119 | for i in range(s, e): 120 | usr_utt_tok_label[i] = 1 121 | for (s, e) in sys_pos: 122 | for i in range(s, e): 123 | sys_utt_tok_label[i] = 1 124 | else: 125 | class_type = 'unpointable' 126 | return sys_utt_tok_label, usr_utt_tok_label, class_type 127 | 128 | 129 | def tokenize(utt): 130 | utt_lower = tokenization.convert_to_unicode(utt).lower() 131 | utt_tok = [tok for tok in map(str.strip, re.split("(\W+)", utt_lower)) if 132 | len(tok) > 0] 133 | return utt_tok 134 | 135 | 136 | def create_examples(dialog_filename, slot_list, set_type, use_asr_hyp=0, 137 | exclude_unpointable=True): 138 | examples = [] 139 | with open(dialog_filename) as f: 140 | dst_set = json.load(f) 141 | for dial in dst_set: 142 | for turn in dial['dialogue']: 143 | guid = '%s-%s-%s' % (set_type, 144 | str(dial['dialogue_idx']), 145 | str(turn['turn_idx'])) 146 | 147 | sys_utt_tok = tokenize(turn['system_transcript']) 148 | 149 | usr_utt_tok_list = [] 150 | if use_asr_hyp == 0: 151 | usr_utt_tok_list.append(tokenize(turn['transcript'])) 152 | else: 153 | for asr_hyp, _ in turn['asr'][:use_asr_hyp]: 154 | usr_utt_tok_list.append(tokenize(asr_hyp)) 155 | 156 | turn_label = [[FIX.get(s.strip(), s.strip()), FIX.get(v.strip(), v.strip())] for s, v in turn['turn_label']] 157 | 158 | for usr_utt_tok in usr_utt_tok_list: 159 | sys_utt_tok_label_dict = {} 160 | usr_utt_tok_label_dict = {} 161 | class_type_dict = {} 162 | for slot in slot_list: 163 | label = 'none' 164 | for [s, v] in turn_label: 165 | if s == slot: 166 | label = v 167 | break 168 | sys_utt_tok_label, usr_utt_tok_label, class_type = get_turn_label( 169 | label, sys_utt_tok, usr_utt_tok, 170 | slot_last_occurrence=True) 171 | sys_utt_tok_label_dict[slot] = sys_utt_tok_label 172 | usr_utt_tok_label_dict[slot] = usr_utt_tok_label 173 | class_type_dict[slot] = class_type 174 | if class_type == 'unpointable': 175 | tf.logging.info( 176 | 'Unpointable: guid=%s, slot=%s, label=%s, usr_utt=%s, sys_utt=%s' % ( 177 | guid, slot, label, usr_utt_tok, sys_utt_tok)) 178 | if 'unpointable' not in class_type_dict.values() or not exclude_unpointable: 179 | examples.append(util.InputExample( 180 | guid=guid, 181 | text_a=sys_utt_tok, 182 | text_b=usr_utt_tok, 183 | text_a_label=sys_utt_tok_label_dict, 184 | text_b_label=usr_utt_tok_label_dict, 185 | class_label=class_type_dict)) 186 | return examples 187 | 188 | def create_examples_with_history(dialog_filename, slot_list, set_type, use_asr_hyp=0, exclude_unpointable=True): 189 | examples = [] 190 | with open(dialog_filename) as f: 191 | dst_set = json.load(f) 192 | for dial in dst_set: 193 | if use_asr_hyp == 0: 194 | his_utt_list = [[]] 195 | else: 196 | his_utt_list = [[] for _ in range(use_asr_hyp)] 197 | 198 | for turn in dial['dialogue']: 199 | guid = '%s-%s-%s' % (set_type, str(dial['dialogue_idx']), str(turn['turn_idx'])) 200 | 201 | sys_utt_tok = tokenize(turn['system_transcript']) 202 | 203 | for his_utt in his_utt_list: 204 | his_utt.append(sys_utt_tok) 205 | 206 | usr_utt_tok_list = [] 207 | if use_asr_hyp == 0: 208 | usr_utt_tok_list.append(tokenize(turn['transcript'])) 209 | else: 210 | for asr_hyp in turn.asr[:use_asr_hyp]: 211 | usr_utt_tok_list.append(tokenize(asr_hyp)) 212 | 213 | turn_label = [[FIX.get(s.strip(), s.strip()), FIX.get(v.strip(), v.strip())] for s, v in turn['turn_label']] 214 | 215 | for his_utt, usr_utt_tok in zip(his_utt_list, usr_utt_tok_list): 216 | his_utt_tok = [] 217 | for utt_tok in his_utt[-5:]: 218 | his_utt_tok.extend(utt_tok) 219 | his_utt_tok_label_dict = {} 220 | usr_utt_tok_label_dict = {} 221 | class_type_dict = {} 222 | for slot in slot_list: 223 | label = 'none' 224 | for [s, v] in turn_label: 225 | if s == slot: 226 | label = v 227 | break 228 | his_utt_tok_label, usr_utt_tok_label, class_type = get_turn_label( 229 | label, his_utt_tok, usr_utt_tok, slot_last_occurrence=True) 230 | his_utt_tok_label_dict[slot] = his_utt_tok_label 231 | usr_utt_tok_label_dict[slot] = usr_utt_tok_label 232 | class_type_dict[slot] = class_type 233 | if class_type == 'unpointable': 234 | tf.logging.info('Unpointable: guid=%s, slot=%s, label=%s, his_utt=%s, usr_utt=%s' % (guid, slot, label, his_utt_tok, usr_utt_tok)) 235 | if 'unpointable' not in class_type_dict.values() or not exclude_unpointable: 236 | examples.append(util.InputExample(guid=guid, 237 | text_a=his_utt_tok, 238 | text_b=usr_utt_tok, 239 | text_a_label=his_utt_tok_label_dict, 240 | text_b_label=usr_utt_tok_label_dict, 241 | class_label=class_type_dict)) 242 | 243 | for his_utt, usr_utt_tok in zip(his_utt_list, usr_utt_tok_list): 244 | his_utt.append(usr_utt_tok) 245 | return examples 246 | 247 | 248 | -------------------------------------------------------------------------------- /dataset_sim.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import util 4 | 5 | 6 | def dialogue_state_to_sv_dict(sv_list): 7 | sv_dict = {} 8 | for d in sv_list: 9 | sv_dict[d['slot']] = d['value'] 10 | return sv_dict 11 | 12 | 13 | def get_token_and_slot_label(turn): 14 | if 'system_utterance' in turn: 15 | sys_utt_tok = turn['system_utterance']['tokens'] 16 | sys_slot_label = turn['system_utterance']['slots'] 17 | else: 18 | sys_utt_tok = [] 19 | sys_slot_label = [] 20 | 21 | usr_utt_tok = turn['user_utterance']['tokens'] 22 | usr_slot_label = turn['user_utterance']['slots'] 23 | return sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label 24 | 25 | 26 | def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, 27 | sys_slot_label, usr_utt_tok, usr_slot_label, dial_id, 28 | turn_id, slot_last_occurrence=True): 29 | """The position of the last occurrence of the slot value will be used.""" 30 | sys_utt_tok_label = [0 for _ in sys_utt_tok] 31 | usr_utt_tok_label = [0 for _ in usr_utt_tok] 32 | if slot_type not in cur_ds_dict: 33 | class_type = 'none' 34 | else: 35 | value = cur_ds_dict[slot_type] 36 | if value == 'dontcare' and (slot_type not in prev_ds_dict or 37 | prev_ds_dict[slot_type] != 'dontcare'): 38 | # Only label dontcare at its first occurrence in the dialog 39 | class_type = 'dontcare' 40 | else: # If not none or dontcare, we have to identify the position 41 | class_type = 'copy_value' 42 | found_pos = False 43 | for label_d in usr_slot_label: 44 | if label_d['slot'] == slot_type and value == ' '.join( 45 | usr_utt_tok[label_d['start']:label_d['exclusive_end']]): 46 | 47 | for idx in range(label_d['start'], label_d['exclusive_end']): 48 | usr_utt_tok_label[idx] = 1 49 | found_pos = True 50 | if slot_last_occurrence: 51 | break 52 | if not found_pos or not slot_last_occurrence: 53 | for label_d in sys_slot_label: 54 | if label_d['slot'] == slot_type and value == ' '.join( 55 | sys_utt_tok[label_d['start']:label_d['exclusive_end']]): 56 | for idx in range(label_d['start'], label_d['exclusive_end']): 57 | sys_utt_tok_label[idx] = 1 58 | found_pos = True 59 | if slot_last_occurrence: 60 | break 61 | if not found_pos: 62 | assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0 63 | if (slot_type not in prev_ds_dict or value != prev_ds_dict[slot_type]): 64 | raise ValueError('Copy value cannot found in Dial %s Turn %s' % 65 | (str(dial_id), str(turn_id))) 66 | else: 67 | class_type = 'none' 68 | else: 69 | assert sum(usr_utt_tok_label + sys_utt_tok_label) > 0 70 | return sys_utt_tok_label, usr_utt_tok_label, class_type 71 | 72 | 73 | def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, 74 | slot_last_occurrence=True): 75 | """Make turn_label a dictionary of slot with value positions or being dontcare / none: 76 | Turn label contains: 77 | (1) the updates from previous to current dialogue state, 78 | (2) values in current dialogue state explicitly mentioned in system or user utterance.""" 79 | prev_ds_dict = dialogue_state_to_sv_dict(prev_dialogue_state) 80 | cur_ds_dict = dialogue_state_to_sv_dict(turn['dialogue_state']) 81 | 82 | (sys_utt_tok, sys_slot_label, 83 | usr_utt_tok, usr_slot_label) = get_token_and_slot_label(turn) 84 | 85 | sys_utt_tok_label_dict = {} 86 | usr_utt_tok_label_dict = {} 87 | class_type_dict = {} 88 | 89 | for slot_type in slot_list: 90 | sys_utt_tok_label, usr_utt_tok_label, class_type = get_tok_label( 91 | prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, sys_slot_label, 92 | usr_utt_tok, usr_slot_label, dial_id, turn_id, 93 | slot_last_occurrence=slot_last_occurrence) 94 | sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label 95 | usr_utt_tok_label_dict[slot_type] = usr_utt_tok_label 96 | class_type_dict[slot_type] = class_type 97 | return (sys_utt_tok, sys_utt_tok_label_dict, 98 | usr_utt_tok, usr_utt_tok_label_dict, class_type_dict) 99 | 100 | 101 | def create_examples(dialog_filename, slot_list, set_type): 102 | examples = [] 103 | with open(dialog_filename) as f: 104 | dst_set = json.load(f) 105 | for dial in dst_set: 106 | dial_id = dial['dialogue_id'] 107 | prev_ds = [] 108 | for turn_id, turn in enumerate(dial['turns']): 109 | guid = '%s-%s-%s' % (set_type, dial_id, str(turn_id)) 110 | (sys_utt_tok, 111 | sys_utt_tok_label_dict, 112 | usr_utt_tok, 113 | usr_utt_tok_label_dict, 114 | class_type_dict) = get_turn_label(turn, 115 | prev_ds, 116 | slot_list, 117 | dial_id, 118 | turn_id, 119 | slot_last_occurrence=True) 120 | examples.append(util.InputExample( 121 | guid=guid, 122 | text_a=sys_utt_tok, 123 | text_b=usr_utt_tok, 124 | text_a_label=sys_utt_tok_label_dict, 125 | text_b_label=usr_utt_tok_label_dict, 126 | class_label=class_type_dict)) 127 | prev_ds = turn['dialogue_state'] 128 | return examples -------------------------------------------------------------------------------- /eval_pred.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # TASK can be "dstc2_clean", "woz2", "sim-m", or "sim-r" 4 | TASK="sim-m" 5 | # Directory for dstc2-clean, woz_2.0, sim-M, or sim-R, which contains json files 6 | DATA_DIR=/path/to/sim-M 7 | # Directory of the pre-trained [BERT-Base, Uncased] model 8 | PRETRAINED_BERT=/path/to/uncased_L-12_H-768_A-12 9 | # Output directory of trained checkpoints, evaluation and prediction outputs 10 | OUTPUT_DIR=/path/to/output 11 | # DSET can be "dev" or "test" 12 | DSET="dev" 13 | 14 | # Comma separated list of checkpoint steps to be evaluated. 15 | for num in {0..12000..1000}; do 16 | CKPT_NUM="$CKPT_NUM,$num" 17 | done 18 | 19 | python main.py \ 20 | --task_name=${TASK} \ 21 | --do_eval=true \ 22 | --do_predict=true \ 23 | --max_seq_length=180 \ 24 | --eval_set=$DSET \ 25 | --eval_ckpt=$CKPT_NUM \ 26 | --data_dir=$DATA_DIR \ 27 | --vocab_file=${PRETRAINED_BERT}/vocab.txt \ 28 | --bert_config_file=${PRETRAINED_BERT}/bert_config.json \ 29 | --init_checkpoint=${PRETRAINED_BERT}/bert_model.ckpt \ 30 | --output_dir=$OUTPUT_DIR \ 31 | 2>&1 | tee -a $OUTPUT_DIR/eval.log 32 | 33 | 34 | python metric_bert_dst.py \ 35 | ${TASK} \ 36 | "$OUTPUT_DIR/pred_res.${DSET}*json" 37 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Model definition for BERT-DST. Modified from bert/run_classifier.py""" 2 | import collections 3 | import json 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | import tensorflow as tf 9 | 10 | import dataset_dstc2 11 | import dataset_sim 12 | import util 13 | 14 | flags = tf.flags 15 | FLAGS = flags.FLAGS 16 | 17 | # Directory of bert, cloned from github.com/google-research/bert 18 | sys.path.append("/path/to/bert") 19 | 20 | import modeling 21 | import optimization 22 | import run_classifier 23 | import tokenization 24 | 25 | ## BERT-DST params 26 | 27 | flags.DEFINE_string( 28 | "eval_set", "dev", 29 | "Which set to be evaluated: dev or test.") 30 | 31 | flags.DEFINE_string( 32 | "eval_ckpt", "", 33 | "comma seperated ckpt numbers to be evaluated.") 34 | 35 | flags.DEFINE_integer( 36 | "num_class_hidden_layer", 0, 37 | "Number of prediction layers in class prediction.") 38 | 39 | flags.DEFINE_integer( 40 | "num_token_hidden_layer", 0, 41 | "Number of prediction layers in class prediction.") 42 | 43 | flags.DEFINE_float("dropout_rate", 0.3, "Dropout rate for BERT representations.") 44 | 45 | flags.DEFINE_bool( 46 | "location_loss_for_nonpointable", False, 47 | "Whether the location loss for none or dontcare is contributed towards total loss.") 48 | 49 | flags.DEFINE_float("class_loss_ratio", 0.8, 50 | "The ratio applied on class loss in total loss calculation." 51 | "Should be a value in [0.0, 1.0]" 52 | "The ratio applied on token loss is (1-class_loss_ratio).") 53 | 54 | flags.DEFINE_float("slot_value_dropout", 0.0, 55 | "The rate that targeted slot value was replaced by [UNK].") 56 | 57 | 58 | class InputFeatures(object): 59 | """A single set of features of data.""" 60 | 61 | def __init__(self, 62 | input_ids, 63 | input_mask, 64 | segment_ids, 65 | start_pos, 66 | end_pos, 67 | class_label_id, 68 | is_real_example=True, 69 | guid="NONE"): 70 | self.guid = guid 71 | self.input_ids = input_ids 72 | self.input_mask = input_mask 73 | self.segment_ids = segment_ids 74 | self.start_pos = start_pos 75 | self.end_pos = end_pos 76 | self.class_label_id = class_label_id 77 | self.is_real_example = is_real_example 78 | 79 | 80 | class Dstc2Processor(object): 81 | class_types = ['none', 'dontcare', 'copy_value', 'unpointable'] 82 | slot_list = ['area', 'food', 'price range'] 83 | 84 | def get_train_examples(self, data_dir): 85 | return dataset_dstc2.create_examples( 86 | os.path.join(data_dir, 'dstc2_train_en.json'), self.slot_list, 'train') 87 | 88 | def get_dev_examples(self, data_dir): 89 | return dataset_dstc2.create_examples( 90 | os.path.join(data_dir, 'dstc2_validate_en.json'), self.slot_list, 'dev', 91 | use_asr_hyp=1, exclude_unpointable=False) 92 | 93 | def get_test_examples(self, data_dir): 94 | return dataset_dstc2.create_examples( 95 | os.path.join(data_dir, 'dstc2_test_en.json'), self.slot_list, 'test', 96 | use_asr_hyp=1, exclude_unpointable=False) 97 | 98 | 99 | class Woz2Processor(object): 100 | class_types = ['none', 'dontcare', 'copy_value', 'unpointable'] 101 | slot_list = ['area', 'food', 'price range'] 102 | 103 | def get_train_examples(self, data_dir): 104 | return dataset_dstc2.create_examples_with_history( 105 | os.path.join(data_dir, 'woz_train_en.json'), self.slot_list, 'train') 106 | 107 | def get_dev_examples(self, data_dir): 108 | return dataset_dstc2.create_examples_with_history( 109 | os.path.join(data_dir, 'woz_validate_en.json'), self.slot_list, 'dev', 110 | use_asr_hyp=0, exclude_unpointable=False) 111 | 112 | def get_test_examples(self, data_dir): 113 | return dataset_dstc2.create_examples_with_history( 114 | os.path.join(data_dir, 'woz_test_en.json'), self.slot_list, 'test', 115 | use_asr_hyp=0, exclude_unpointable=False) 116 | 117 | 118 | class SimMProcessor(object): 119 | class_types = ['none', 'dontcare', 'copy_value'] 120 | slot_list = ['date', 'movie', 'time', 'num_tickets', 'theatre_name'] 121 | 122 | def get_train_examples(self, data_dir): 123 | return dataset_sim.create_examples( 124 | os.path.join(data_dir, 'train.json'), self.slot_list, 'train') 125 | 126 | def get_dev_examples(self, data_dir): 127 | return dataset_sim.create_examples( 128 | os.path.join(data_dir, 'dev.json'), self.slot_list, 'dev') 129 | 130 | def get_test_examples(self, data_dir): 131 | return dataset_sim.create_examples( 132 | os.path.join(data_dir, 'test.json'), self.slot_list, 'test') 133 | 134 | 135 | class SimRProcessor(SimMProcessor): 136 | slot_list = ['category', 'rating', 'num_people', 'location', 137 | 'restaurant_name', 'time', 'date', 'price_range', 'meal'] 138 | 139 | 140 | def tokenize_text_and_label(text, text_label_dict, slot, tokenizer): 141 | joint_text_label = [0 for _ in text_label_dict[slot]] # joint all slots' label 142 | for slot_text_label in text_label_dict.values(): 143 | for idx, label in enumerate(slot_text_label): 144 | if label == 1: 145 | joint_text_label[idx] = 1 146 | 147 | text_label = text_label_dict[slot] 148 | tokens = [] 149 | token_labels = [] 150 | for token, token_label, joint_label in zip(text, text_label, joint_text_label): 151 | token = tokenization.convert_to_unicode(token) 152 | sub_tokens = tokenizer.tokenize(token) 153 | if FLAGS.slot_value_dropout == 0.0 or joint_label == 0: 154 | tokens.extend(sub_tokens) 155 | else: 156 | rn_list = np.random.random_sample((len(sub_tokens),)) 157 | for rn, sub_token in zip(rn_list, sub_tokens): 158 | if rn > FLAGS.slot_value_dropout: 159 | tokens.append(sub_token) 160 | else: 161 | tokens.append('[UNK]') 162 | 163 | token_labels.extend([token_label for _ in sub_tokens]) 164 | assert len(tokens) == len(token_labels) 165 | return tokens, token_labels 166 | 167 | 168 | def convert_single_example(ex_index, example, slot_list, class_types, max_seq_length, 169 | tokenizer): 170 | """Converts a single `InputExample` into a single `InputFeatures`.""" 171 | 172 | if isinstance(example, run_classifier.PaddingInputExample): 173 | return InputFeatures( 174 | input_ids=[0] * max_seq_length, 175 | input_mask=[0] * max_seq_length, 176 | segment_ids=[0] * max_seq_length, 177 | start_pos={slot: 0 for slot in slot_list}, 178 | end_pos={slot: 0 for slot in slot_list}, 179 | class_label_id={slot: 0 for slot in slot_list}, 180 | is_real_example=False, 181 | guid="NONE") 182 | 183 | class_label_id_dict = {} 184 | start_pos_dict = {} 185 | end_pos_dict = {} 186 | for slot in slot_list: 187 | tokens_a, token_labels_a = tokenize_text_and_label( 188 | example.text_a, example.text_a_label, slot, tokenizer) 189 | tokens_b, token_labels_b = tokenize_text_and_label( 190 | example.text_b, example.text_b_label, slot, tokenizer) 191 | 192 | input_text_too_long = util.truncate_length_and_warn( 193 | tokens_a, tokens_b, max_seq_length, example.guid) 194 | 195 | if input_text_too_long: 196 | if ex_index < 10: 197 | if len(token_labels_a) > len(tokens_a): 198 | tf.logging.info(' tokens_a truncated labels: %s' % str(token_labels_a[len(tokens_a):])) 199 | if len(token_labels_b) > len(tokens_b): 200 | tf.logging.info(' tokens_b truncated labels: %s' % str(token_labels_b[len(tokens_b):])) 201 | 202 | token_labels_a = token_labels_a[:len(tokens_a)] 203 | token_labels_b = token_labels_b[:len(tokens_b)] 204 | 205 | assert len(token_labels_a) == len(tokens_a) 206 | assert len(token_labels_b) == len(tokens_b) 207 | token_label_ids = util.get_token_label_ids( 208 | token_labels_a, token_labels_b, max_seq_length) 209 | 210 | class_label_id_dict[slot] = class_types.index(example.class_label[slot]) 211 | start_pos_dict[slot], end_pos_dict[ 212 | slot] = util.get_start_end_pos( 213 | example.class_label[slot], token_label_ids, 214 | max_seq_length) 215 | 216 | tokens, input_ids, input_mask, segment_ids = util.get_bert_input(tokens_a, 217 | tokens_b, 218 | max_seq_length, 219 | tokenizer) 220 | 221 | if ex_index < 10: 222 | tf.logging.info("*** Example ***") 223 | tf.logging.info("guid: %s" % (example.guid)) 224 | tf.logging.info("tokens: %s" % " ".join( 225 | [tokenization.printable_text(x) for x in tokens])) 226 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 227 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 228 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 229 | tf.logging.info("start_pos: %s" % str(start_pos_dict)) 230 | tf.logging.info("end_pos: %s" % str(end_pos_dict)) 231 | tf.logging.info("class_label_id: %s" % str(class_label_id_dict)) 232 | 233 | 234 | feature = InputFeatures( 235 | input_ids=input_ids, 236 | input_mask=input_mask, 237 | segment_ids=segment_ids, 238 | start_pos=start_pos_dict, 239 | end_pos=end_pos_dict, 240 | class_label_id=class_label_id_dict, 241 | is_real_example=True, 242 | guid=example.guid) 243 | return feature, input_text_too_long 244 | 245 | 246 | def file_based_convert_examples_to_features( 247 | examples, slot_list, class_types, max_seq_length, tokenizer, output_file): 248 | """Convert a set of `InputExample`s to a TFRecord file.""" 249 | 250 | writer = tf.python_io.TFRecordWriter(output_file) 251 | 252 | total_cnt = 0 253 | too_long_cnt = 0 254 | 255 | for (ex_index, example) in enumerate(examples): 256 | if ex_index % 10000 == 0: 257 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 258 | 259 | feature, input_text_too_long = convert_single_example(ex_index, example, slot_list, class_types, 260 | max_seq_length, tokenizer) 261 | total_cnt += 1 262 | if input_text_too_long: 263 | too_long_cnt += 1 264 | 265 | def create_int_feature(values): 266 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 267 | return f 268 | 269 | features = collections.OrderedDict() 270 | features["input_ids"] = create_int_feature(feature.input_ids) 271 | features["input_mask"] = create_int_feature(feature.input_mask) 272 | features["segment_ids"] = create_int_feature(feature.segment_ids) 273 | for slot in slot_list: 274 | features["start_pos_%s" % slot] = create_int_feature([feature.start_pos[slot]]) 275 | features["end_pos_%s" % slot] = create_int_feature([feature.end_pos[slot]]) 276 | features["class_label_id_%s" % slot] = create_int_feature([feature.class_label_id[slot]]) 277 | features["is_real_example"] = create_int_feature([int(feature.is_real_example)]) 278 | features["guid"] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.guid.encode('utf-8')])) 279 | 280 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 281 | writer.write(tf_example.SerializeToString()) 282 | tf.logging.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt)) 283 | writer.close() 284 | 285 | 286 | def file_based_input_fn_builder(input_file, seq_length, is_training, 287 | drop_remainder, slot_list): 288 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 289 | 290 | name_to_features = { 291 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 292 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 293 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 294 | "is_real_example": tf.FixedLenFeature([], tf.int64), 295 | "guid": tf.FixedLenFeature([], tf.string) 296 | } 297 | for slot in slot_list: 298 | name_to_features["start_pos_%s" % slot] = tf.FixedLenFeature([], tf.int64) 299 | name_to_features["end_pos_%s" % slot] = tf.FixedLenFeature([], tf.int64) 300 | name_to_features["class_label_id_%s" % slot] = tf.FixedLenFeature([], tf.int64) 301 | 302 | def _decode_record(record, name_to_features): 303 | """Decodes a record to a TensorFlow example.""" 304 | example = tf.parse_single_example(record, name_to_features) 305 | 306 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 307 | # So cast all int64 to int32. 308 | for name in list(example.keys()): 309 | t = example[name] 310 | if t.dtype == tf.int64: 311 | t = tf.to_int32(t) 312 | example[name] = t 313 | 314 | return example 315 | 316 | def input_fn(params): 317 | """The actual input function.""" 318 | batch_size = params["batch_size"] 319 | 320 | # For training, we want a lot of parallel reading and shuffling. 321 | # For eval, we want no shuffling and parallel reading doesn't matter. 322 | d = tf.data.TFRecordDataset(input_file) 323 | if is_training: 324 | d = d.repeat() 325 | d = d.shuffle(buffer_size=100) 326 | 327 | d = d.apply( 328 | tf.contrib.data.map_and_batch( 329 | lambda record: _decode_record(record, name_to_features), 330 | batch_size=batch_size, 331 | drop_remainder=drop_remainder)) 332 | 333 | return d 334 | 335 | return input_fn 336 | 337 | 338 | def create_model(bert_config, is_training, slot_list, features, num_class_labels, use_one_hot_embeddings): 339 | """Creates a classification model.""" 340 | input_ids = features["input_ids"] 341 | input_mask = features["input_mask"] 342 | segment_ids = features["segment_ids"] 343 | 344 | model = modeling.BertModel( 345 | config=bert_config, 346 | is_training=is_training, 347 | input_ids=input_ids, 348 | input_mask=input_mask, 349 | token_type_ids=segment_ids, 350 | use_one_hot_embeddings=use_one_hot_embeddings) 351 | 352 | # In the demo, we are doing a simple classification task on the entire 353 | # segment. 354 | # 355 | # If you want to use the token-level output, use model.get_sequence_output() 356 | # instead. 357 | class_output_layer = model.get_pooled_output() 358 | token_output_layer = model.get_sequence_output() 359 | 360 | token_output_shape = modeling.get_shape_list(token_output_layer, expected_rank=3) 361 | batch_size = token_output_shape[0] 362 | seq_length = token_output_shape[1] 363 | hidden_size = token_output_shape[2] 364 | 365 | # Define prediction variables 366 | class_proj_layer_dim = [hidden_size] 367 | for idx in range(FLAGS.num_class_hidden_layer): 368 | class_proj_layer_dim.append(64) 369 | class_proj_layer_dim.append(num_class_labels) 370 | 371 | token_proj_layer_dim = [hidden_size] 372 | for idx in range(FLAGS.num_token_hidden_layer): 373 | token_proj_layer_dim.append(64) 374 | token_proj_layer_dim.append(2) 375 | 376 | if is_training: 377 | # I.e., 0.1 dropout 378 | class_output_layer = tf.nn.dropout(class_output_layer, 379 | keep_prob=(1 - FLAGS.dropout_rate)) 380 | token_output_layer = tf.nn.dropout(token_output_layer, 381 | keep_prob=(1 - FLAGS.dropout_rate)) 382 | total_loss = 0 383 | per_slot_per_example_loss = {} 384 | per_slot_class_logits = {} 385 | per_slot_start_logits = {} 386 | per_slot_end_logits = {} 387 | for slot in slot_list: 388 | start_pos = features["start_pos_%s" % slot] 389 | end_pos = features["end_pos_%s" % slot] 390 | class_label_id = features["class_label_id_%s" % slot] 391 | slot_scope_name = "slot_%s" % slot 392 | if slot == 'price range': 393 | slot_scope_name = "slot_price" 394 | with tf.variable_scope(slot_scope_name): 395 | class_list_output_weights = [] 396 | class_list_output_bias = [] 397 | 398 | for l_idx in range(len(class_proj_layer_dim) - 1): 399 | dim_in = class_proj_layer_dim[l_idx] 400 | dim_out = class_proj_layer_dim[l_idx + 1] 401 | class_list_output_weights.append(tf.get_variable( 402 | "class/output_weights_%d" % l_idx, [dim_in, dim_out], 403 | initializer=tf.truncated_normal_initializer(stddev=0.02))) 404 | class_list_output_bias.append(tf.get_variable( 405 | "class/output_bias_%d" % l_idx, [dim_out], 406 | initializer=tf.zeros_initializer())) 407 | 408 | token_list_output_weights = [] 409 | token_list_output_bias = [] 410 | 411 | for l_idx in range(len(token_proj_layer_dim) - 1): 412 | dim_in = token_proj_layer_dim[l_idx] 413 | dim_out = token_proj_layer_dim[l_idx + 1] 414 | token_list_output_weights.append(tf.get_variable( 415 | "token/output_weights_%d" % l_idx, [dim_in, dim_out], 416 | initializer=tf.truncated_normal_initializer(stddev=0.02))) 417 | token_list_output_bias.append(tf.get_variable( 418 | "token/output_bias_%d" % l_idx, [dim_out], 419 | initializer=tf.zeros_initializer())) 420 | 421 | with tf.variable_scope("loss"): 422 | class_logits = util.fully_connect_layers(class_output_layer, 423 | class_list_output_weights, 424 | class_list_output_bias) 425 | one_hot_class_labels = tf.one_hot(class_label_id, 426 | depth=num_class_labels, 427 | dtype=tf.float32) 428 | class_loss = tf.losses.softmax_cross_entropy( 429 | one_hot_class_labels, class_logits, reduction=tf.losses.Reduction.NONE) 430 | 431 | token_is_pointable = tf.cast(tf.equal(class_label_id, 2), dtype=tf.float32) 432 | 433 | token_output_layer = tf.reshape(token_output_layer, 434 | [batch_size * seq_length, hidden_size]) 435 | token_logits = util.fully_connect_layers(token_output_layer, 436 | token_list_output_weights, 437 | token_list_output_bias) 438 | token_logits = tf.reshape(token_logits, [batch_size, seq_length, 2]) 439 | token_logits = tf.transpose(token_logits, [2, 0, 1]) 440 | unstacked_token_logits = tf.unstack(token_logits, axis=0) 441 | (start_logits, end_logits) = ( 442 | unstacked_token_logits[0], unstacked_token_logits[1]) 443 | 444 | def compute_loss(logits, positions): 445 | one_hot_positions = tf.one_hot( 446 | positions, depth=seq_length, dtype=tf.float32) 447 | log_probs = tf.nn.log_softmax(logits, axis=1) 448 | loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=1) 449 | return loss 450 | 451 | token_loss = (compute_loss(start_logits, start_pos) + compute_loss(end_logits, end_pos)) / 2.0 # per example 452 | if not FLAGS.location_loss_for_nonpointable: 453 | token_loss *= token_is_pointable 454 | 455 | per_example_loss = FLAGS.class_loss_ratio * class_loss + (1-FLAGS.class_loss_ratio) * token_loss 456 | 457 | total_loss += tf.reduce_sum(per_example_loss) 458 | per_slot_per_example_loss[slot] = per_example_loss 459 | per_slot_class_logits[slot] = class_logits 460 | per_slot_start_logits[slot] = start_logits 461 | per_slot_end_logits[slot] = end_logits 462 | return (total_loss, per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits) 463 | 464 | 465 | def model_fn_builder(bert_config, slot_list, num_class_labels, init_checkpoint, 466 | learning_rate, 467 | num_train_steps, num_warmup_steps, use_tpu, 468 | use_one_hot_embeddings): 469 | """Returns `model_fn` closure for TPUEstimator.""" 470 | 471 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 472 | """The `model_fn` for TPUEstimator.""" 473 | 474 | tf.logging.info("*** Features ***") 475 | for name in sorted(features.keys()): 476 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 477 | 478 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 479 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 480 | 481 | (total_loss, per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits) = create_model( 482 | bert_config, is_training, slot_list, features, num_class_labels, use_one_hot_embeddings) 483 | 484 | tvars = tf.trainable_variables() 485 | initialized_variable_names = {} 486 | scaffold_fn = None 487 | if init_checkpoint: 488 | (assignment_map, initialized_variable_names 489 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 490 | if use_tpu: 491 | 492 | def tpu_scaffold(): 493 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 494 | return tf.train.Scaffold() 495 | 496 | scaffold_fn = tpu_scaffold 497 | else: 498 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 499 | 500 | tf.logging.info("**** Trainable Variables ****") 501 | for var in tvars: 502 | init_string = "" 503 | if var.name in initialized_variable_names: 504 | init_string = ", *INIT_FROM_CKPT*" 505 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 506 | init_string) 507 | 508 | output_spec = None 509 | if mode == tf.estimator.ModeKeys.TRAIN: 510 | train_op = optimization.create_optimizer( 511 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 512 | 513 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 514 | mode=mode, 515 | loss=total_loss, 516 | train_op=train_op, 517 | scaffold_fn=scaffold_fn) 518 | elif mode == tf.estimator.ModeKeys.EVAL: 519 | def metric_fn(per_slot_per_example_loss, features, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, is_real_example): 520 | metric_dict = {} 521 | per_slot_correctness = {} 522 | for slot in slot_list: 523 | per_example_loss = per_slot_per_example_loss[slot] 524 | class_logits = per_slot_class_logits[slot] 525 | start_logits = per_slot_start_logits[slot] 526 | end_logits = per_slot_end_logits[slot] 527 | 528 | class_label_id = features['class_label_id_%s' % slot] 529 | start_pos = features['start_pos_%s' % slot] 530 | end_pos = features['end_pos_%s' % slot] 531 | 532 | class_prediction = tf.cast(tf.argmax(class_logits, axis=1), tf.int32) 533 | class_correctness = tf.cast(tf.equal(class_prediction, class_label_id), dtype=tf.float32) 534 | class_accuracy = tf.metrics.mean( 535 | tf.reduce_sum( 536 | class_correctness * is_real_example) / tf.reduce_sum( 537 | is_real_example)) 538 | 539 | token_is_pointable = tf.cast(tf.equal(class_label_id, 2), 540 | dtype=tf.float32) 541 | start_prediction = tf.cast(tf.argmax(start_logits, axis=1), tf.int32) 542 | start_correctness = tf.cast(tf.equal(start_prediction, start_pos), 543 | dtype=tf.float32) 544 | end_prediction = tf.cast(tf.argmax(end_logits, axis=1), tf.int32) 545 | end_correctness = tf.cast(tf.equal(end_prediction, end_pos), 546 | dtype=tf.float32) 547 | token_correctness = start_correctness * end_correctness 548 | token_accuracy = tf.metrics.mean( 549 | tf.reduce_sum( 550 | token_correctness * token_is_pointable) / tf.reduce_sum( 551 | token_is_pointable)) 552 | 553 | total_corretness = class_correctness * (token_is_pointable * token_correctness + (1-token_is_pointable)) 554 | total_accuracy = tf.metrics.mean( 555 | tf.reduce_sum(total_corretness) * is_real_example / tf.reduce_sum(is_real_example)) 556 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 557 | metric_dict['eval_accuracy_class_%s' % slot] = class_accuracy 558 | metric_dict['eval_accuracy_token_%s' % slot] = token_accuracy 559 | metric_dict['eval_accuracy_%s' % slot] = total_accuracy 560 | metric_dict['eval_loss_%s' % slot] = loss 561 | per_slot_correctness[slot] = total_corretness 562 | goal_correctness = tf.reduce_prod( 563 | tf.stack( 564 | [correctness for correctness in 565 | per_slot_correctness.values()], 566 | axis=1), 567 | axis=1) 568 | goal_accuracy = tf.metrics.mean(tf.reduce_sum(goal_correctness * is_real_example) / tf.reduce_sum( 569 | is_real_example)) 570 | metric_dict['eval_accuracy_goal'] = goal_accuracy 571 | return metric_dict 572 | 573 | eval_metrics = (metric_fn, 574 | [per_slot_per_example_loss, features, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, is_real_example]) 575 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 576 | mode=mode, 577 | loss=total_loss, 578 | eval_metrics=eval_metrics, 579 | scaffold_fn=scaffold_fn) 580 | else: 581 | predictions_dict = {"guid": features["guid"]} 582 | for slot in slot_list: 583 | slot_scope_name = "slot_%s" % slot 584 | if slot == 'price range': 585 | slot_scope_name = "slot_price" 586 | with tf.variable_scope(slot_scope_name): 587 | class_prediction = tf.argmax(per_slot_class_logits[slot], axis=1) 588 | start_prediction = tf.argmax(per_slot_start_logits[slot], axis=1) 589 | end_prediction = tf.argmax(per_slot_end_logits[slot], axis=1) 590 | 591 | predictions_dict["class_prediction_%s" % slot] = class_prediction 592 | predictions_dict["class_label_id_%s" % slot] = features["class_label_id_%s" % slot] 593 | predictions_dict["start_prediction_%s" % slot] = start_prediction 594 | predictions_dict["start_pos_%s" % slot] = features["start_pos_%s" % slot] 595 | predictions_dict["end_prediction_%s" % slot] = end_prediction 596 | predictions_dict["end_pos_%s" % slot] = features["end_pos_%s" % slot] 597 | predictions_dict["input_ids_%s" % slot] = features["input_ids"] 598 | 599 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 600 | mode=mode, 601 | predictions=predictions_dict, 602 | scaffold_fn=scaffold_fn) 603 | return output_spec 604 | 605 | return model_fn 606 | 607 | 608 | def main(_): 609 | tf.logging.set_verbosity(tf.logging.INFO) 610 | 611 | processors = { 612 | "dstc2_clean": Dstc2Processor, 613 | "woz2": Woz2Processor, 614 | "sim-m": SimMProcessor, 615 | "sim-r": SimRProcessor, 616 | } 617 | 618 | tokenization.validate_case_matches_checkpoint( 619 | do_lower_case=True, init_checkpoint=FLAGS.init_checkpoint) 620 | 621 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 622 | raise ValueError( 623 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 624 | 625 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 626 | 627 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 628 | raise ValueError( 629 | "Cannot use sequence length %d because the BERT model " 630 | "was only trained up to sequence length %d" % 631 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 632 | 633 | tf.gfile.MakeDirs(FLAGS.output_dir) 634 | 635 | task_name = FLAGS.task_name.lower() 636 | 637 | if task_name not in processors: 638 | raise ValueError("Task not found: %s" % (task_name)) 639 | 640 | processor = processors[task_name]() 641 | 642 | slot_list = processor.slot_list 643 | class_types = processor.class_types 644 | num_class_labels = len(class_types) 645 | if task_name in ['woz2', 'dstc2_clean']: 646 | num_class_labels -= 1 647 | 648 | tokenizer = tokenization.FullTokenizer( 649 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 650 | 651 | tpu_cluster_resolver = None 652 | if FLAGS.use_tpu and FLAGS.tpu_name: 653 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 654 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 655 | 656 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 657 | run_config = tf.contrib.tpu.RunConfig( 658 | cluster=tpu_cluster_resolver, 659 | master=FLAGS.master, 660 | model_dir=FLAGS.output_dir, 661 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 662 | keep_checkpoint_max=None, 663 | tpu_config=tf.contrib.tpu.TPUConfig( 664 | iterations_per_loop=FLAGS.iterations_per_loop, 665 | num_shards=FLAGS.num_tpu_cores, 666 | per_host_input_for_training=is_per_host)) 667 | 668 | train_examples = None 669 | num_train_steps = None 670 | num_warmup_steps = None 671 | if FLAGS.do_train: 672 | train_examples = processor.get_train_examples(FLAGS.data_dir) 673 | num_train_steps = int( 674 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 675 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 676 | 677 | model_fn = model_fn_builder( 678 | bert_config=bert_config, 679 | slot_list=slot_list, 680 | num_class_labels=num_class_labels, 681 | init_checkpoint=FLAGS.init_checkpoint, 682 | learning_rate=FLAGS.learning_rate, 683 | num_train_steps=num_train_steps, 684 | num_warmup_steps=num_warmup_steps, 685 | use_tpu=FLAGS.use_tpu, 686 | use_one_hot_embeddings=FLAGS.use_tpu) 687 | 688 | # If TPU is not available, this will fall back to normal Estimator on CPU 689 | # or GPU. 690 | estimator = tf.contrib.tpu.TPUEstimator( 691 | use_tpu=FLAGS.use_tpu, 692 | model_fn=model_fn, 693 | config=run_config, 694 | train_batch_size=FLAGS.train_batch_size, 695 | eval_batch_size=FLAGS.eval_batch_size, 696 | predict_batch_size=FLAGS.predict_batch_size) 697 | 698 | if FLAGS.do_train: 699 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 700 | file_based_convert_examples_to_features( 701 | train_examples, slot_list, class_types, FLAGS.max_seq_length, tokenizer, train_file) 702 | tf.logging.info("***** Running training *****") 703 | tf.logging.info(" Num examples = %d", len(train_examples)) 704 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 705 | tf.logging.info(" Num steps = %d", num_train_steps) 706 | train_input_fn = file_based_input_fn_builder( 707 | input_file=train_file, 708 | seq_length=FLAGS.max_seq_length, 709 | is_training=True, 710 | drop_remainder=True, 711 | slot_list=slot_list) 712 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 713 | 714 | if FLAGS.do_eval: 715 | if FLAGS.eval_set == 'dev': 716 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 717 | else: 718 | eval_examples = processor.get_test_examples(FLAGS.data_dir) 719 | num_actual_eval_examples = len(eval_examples) 720 | if FLAGS.use_tpu: 721 | # TPU requires a fixed batch size for all batches, therefore the number 722 | # of examples must be a multiple of the batch size, or else examples 723 | # will get dropped. So we pad with fake examples which are ignored 724 | # later on. These do NOT count towards the metric (all tf.metrics 725 | # support a per-instance weight, and these get a weight of 0.0). 726 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 727 | eval_examples.append(run_classifier.PaddingInputExample()) 728 | 729 | eval_file = os.path.join(FLAGS.output_dir, "eval.%s.tf_record" % FLAGS.eval_set) 730 | file_based_convert_examples_to_features( 731 | eval_examples, slot_list, class_types, FLAGS.max_seq_length, tokenizer, eval_file) 732 | 733 | tf.logging.info("***** Running evaluation *****") 734 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 735 | len(eval_examples), num_actual_eval_examples, 736 | len(eval_examples) - num_actual_eval_examples) 737 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 738 | 739 | # This tells the estimator to run through the entire set. 740 | eval_steps = None 741 | # However, if running eval on the TPU, you will need to specify the 742 | # number of steps. 743 | if FLAGS.use_tpu: 744 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 745 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 746 | 747 | eval_drop_remainder = True if FLAGS.use_tpu else False 748 | eval_input_fn = file_based_input_fn_builder( 749 | input_file=eval_file, 750 | seq_length=FLAGS.max_seq_length, 751 | is_training=False, 752 | drop_remainder=eval_drop_remainder, 753 | slot_list=slot_list) 754 | output_eval_file = os.path.join(FLAGS.output_dir, 755 | "eval_res.%s.json" % FLAGS.eval_set) 756 | if tf.gfile.Exists(output_eval_file): 757 | with tf.gfile.GFile(output_eval_file) as f: 758 | eval_result = json.load(f) 759 | else: 760 | eval_result = [] 761 | 762 | ckpt_nums = [num.strip() for num in FLAGS.eval_ckpt.split(',') if num.strip() != ""] 763 | for ckpt_num in ckpt_nums: 764 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps, 765 | checkpoint_path=os.path.join(FLAGS.output_dir, 766 | "model.ckpt-%s" % ckpt_num)) 767 | result_dict = {k: float(v) for k, v in result.items()} 768 | eval_result.append(result_dict) 769 | tf.logging.info("***** Eval results for %s set *****", FLAGS.eval_set) 770 | for key in sorted(result.keys()): 771 | tf.logging.info("%s = %s", key, str(result[key])) 772 | if len(eval_result) > 0: 773 | with tf.gfile.GFile(output_eval_file, "w") as f: 774 | json.dump(eval_result, f, indent=2) 775 | 776 | if FLAGS.do_predict: 777 | if FLAGS.eval_set == 'dev': 778 | predict_examples = processor.get_dev_examples(FLAGS.data_dir) 779 | else: 780 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 781 | num_actual_predict_examples = len(predict_examples) 782 | if FLAGS.use_tpu: 783 | # TPU requires a fixed batch size for all batches, therefore the number 784 | # of examples must be a multiple of the batch size, or else examples 785 | # will get dropped. So we pad with fake examples which are ignored 786 | # later on. 787 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 788 | predict_examples.append(run_classifier.PaddingInputExample()) 789 | 790 | predict_file = os.path.join(FLAGS.output_dir, "pred.%s.tf_record" % FLAGS.eval_set) 791 | file_based_convert_examples_to_features(predict_examples, slot_list, class_types, 792 | FLAGS.max_seq_length, tokenizer, 793 | predict_file) 794 | 795 | tf.logging.info("***** Running prediction *****") 796 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 797 | len(predict_examples), num_actual_predict_examples, 798 | len(predict_examples) - num_actual_predict_examples) 799 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 800 | 801 | predict_drop_remainder = True if FLAGS.use_tpu else False 802 | predict_input_fn = file_based_input_fn_builder( 803 | input_file=predict_file, 804 | seq_length=FLAGS.max_seq_length, 805 | is_training=False, 806 | drop_remainder=predict_drop_remainder, 807 | slot_list=slot_list) 808 | 809 | ckpt_nums = [num for num in FLAGS.eval_ckpt.split(',') if num != ""] 810 | for ckpt_num in ckpt_nums: 811 | result = estimator.predict(input_fn=predict_input_fn, 812 | checkpoint_path=os.path.join(FLAGS.output_dir, 813 | "model.ckpt-%s" % ckpt_num)) 814 | 815 | output_predict_file = os.path.join(FLAGS.output_dir, 816 | "pred_res.%s.%08d.json" % (FLAGS.eval_set, int(ckpt_num))) 817 | with tf.gfile.GFile(output_predict_file, "w") as f: 818 | num_written_ex = 0 819 | tf.logging.info("***** Predict results for %s set *****", FLAGS.eval_set) 820 | list_prediction = [] 821 | for (i, prediction) in enumerate(result): 822 | # Str feature is encoded as bytes, which is not JSON serializable. 823 | # Hence convert to str. 824 | prediction["guid"] = prediction["guid"].decode("utf-8").split("-") 825 | for slot in slot_list: 826 | start_pd = prediction['start_prediction_%s' % slot] 827 | start_gt = prediction['start_pos_%s' % slot] 828 | end_pd = prediction['end_prediction_%s' % slot] 829 | end_gt = prediction['end_pos_%s' % slot] 830 | # TF uses int64, which is not JSON serializable. 831 | # Hence convert to int. 832 | prediction['class_prediction_%s' % slot] = int(prediction['class_prediction_%s' % slot]) 833 | prediction['class_label_id_%s' % slot] = int(prediction['class_label_id_%s' % slot]) 834 | prediction['start_prediction_%s' % slot] = int(start_pd) 835 | prediction['start_pos_%s' % slot] = int(start_gt) 836 | prediction['end_prediction_%s' % slot] = int(end_pd) 837 | prediction['end_pos_%s' % slot] = int(end_gt) 838 | prediction["input_ids_%s" % slot] = list(map(int, prediction["input_ids_%s" % slot].tolist())) 839 | input_tokens = tokenizer.convert_ids_to_tokens(prediction["input_ids_%s" % slot]) 840 | prediction["slot_prediction_%s" % slot] = ' '.join(input_tokens[start_pd:end_pd+1]) 841 | prediction["slot_groundtruth_%s" % slot] = ' '.join(input_tokens[start_gt:end_gt + 1]) 842 | list_prediction.append(prediction) 843 | if i >= num_actual_predict_examples: 844 | break 845 | num_written_ex += 1 846 | json.dump(list_prediction, f, indent=2) 847 | assert num_written_ex == num_actual_predict_examples 848 | 849 | 850 | if __name__ == "__main__": 851 | flags.mark_flag_as_required("data_dir") 852 | flags.mark_flag_as_required("task_name") 853 | flags.mark_flag_as_required("vocab_file") 854 | flags.mark_flag_as_required("bert_config_file") 855 | flags.mark_flag_as_required("output_dir") 856 | tf.app.run() 857 | -------------------------------------------------------------------------------- /metric_bert_dst.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import sys 4 | import numpy as np 5 | 6 | 7 | def get_joint_slot_correctness(fp, 8 | key_class_label_id='class_label_id', 9 | key_class_prediction='class_prediction', 10 | key_start_pos='start_pos', 11 | key_start_prediction='start_prediction', 12 | key_end_pos='end_pos', 13 | key_end_prediction='end_prediction'): 14 | with open(fp) as f: 15 | preds = json.load(f) 16 | class_correctness = [] 17 | pos_correctness = [] 18 | total_correctness = [] 19 | 20 | for pred in preds: 21 | guid = pred['guid'] # List: set_type, dialogue_idx, turn_idx 22 | turn_gt_class = pred[key_class_label_id] 23 | turn_pd_class = pred[key_class_prediction] 24 | gt_start_pos = pred[key_start_pos] 25 | pd_start_pos = pred[key_start_prediction] 26 | gt_end_pos = pred[key_end_pos] 27 | pd_end_pos = pred[key_end_prediction] 28 | 29 | if guid[-1] == '0': # First turn, reset the slots 30 | joint_gt_class = turn_gt_class 31 | joint_gt_start_pos = gt_start_pos 32 | joint_gt_end_pos = gt_end_pos 33 | joint_pd_class = turn_pd_class 34 | joint_pd_start_pos = pd_start_pos 35 | joint_pd_end_pos = pd_end_pos 36 | else: 37 | if turn_gt_class > 0: 38 | joint_gt_class = turn_gt_class 39 | joint_gt_start_pos = gt_start_pos 40 | joint_gt_end_pos = gt_end_pos 41 | if turn_pd_class > 0: 42 | joint_pd_class = turn_pd_class 43 | joint_pd_start_pos = pd_start_pos 44 | joint_pd_end_pos = pd_end_pos 45 | 46 | total_correct = True 47 | if joint_gt_class == joint_pd_class: 48 | class_correctness.append(1.0) 49 | if joint_gt_class == 2: 50 | if joint_gt_start_pos == joint_pd_start_pos and joint_gt_end_pos == joint_pd_end_pos: 51 | pos_correctness.append(1.0) 52 | else: 53 | pos_correctness.append(0.0) 54 | total_correct = False 55 | else: 56 | class_correctness.append(0.0) 57 | total_correct = False 58 | if total_correct: 59 | total_correctness.append(1.0) 60 | else: 61 | total_correctness.append(0.0) 62 | 63 | return np.asarray(total_correctness), np.asarray(class_correctness), np.asarray(pos_correctness) 64 | 65 | 66 | if __name__ == "__main__": 67 | acc_list = [] 68 | key_class_label_id = 'class_label_id_%s' 69 | key_class_prediction = 'class_prediction_%s' 70 | key_start_pos = 'start_pos_%s' 71 | key_start_prediction = 'start_prediction_%s' 72 | key_end_pos = 'end_pos_%s' 73 | key_end_prediction = 'end_prediction_%s' 74 | 75 | 76 | for fp in sorted(glob.glob(sys.argv[2])): 77 | print(fp) 78 | goal_correctness = 1.0 79 | dataset = sys.argv[1].lower() 80 | if dataset in ['woz2', 'dstc2_clean']: 81 | slots = ['area', 'food', 'price range'] 82 | elif dataset == 'sim-m': 83 | slots = ['date', 'movie', 'time', 'num_tickets', 'theatre_name'] 84 | elif dataset == 'sim-r': 85 | slots = ['category', 'rating', 'num_people', 'location', 'restaurant_name', 86 | 'time', 'date', 'price_range', 'meal'] 87 | for slot in slots: 88 | tot_cor, cls_cor, pos_cor = get_joint_slot_correctness(fp, 89 | key_class_label_id=(key_class_label_id % slot), 90 | key_class_prediction=(key_class_prediction % slot), 91 | key_start_pos=(key_start_pos % slot), 92 | key_start_prediction=(key_start_prediction % slot), 93 | key_end_pos=(key_end_pos % slot), 94 | key_end_prediction=(key_end_prediction % slot) 95 | ) 96 | print('%s: joint slot acc: %g, class acc: %g, position acc: %g' % (slot, np.mean(tot_cor), np.mean(cls_cor), np.mean(pos_cor))) 97 | goal_correctness *= tot_cor 98 | 99 | acc = np.mean(goal_correctness) 100 | acc_list.append((fp, acc)) 101 | acc_list_s = sorted(acc_list, key=lambda tup: tup[1], reverse=True) 102 | for (fp, acc) in acc_list_s: 103 | print('Joint goal acc: %g, %s' % (acc, fp)) -------------------------------------------------------------------------------- /storage/dstc2-clean.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanlinchao/bert-dst/5295b2d7e74c57aa32a02aadeb3ab70d22af033b/storage/dstc2-clean.zip -------------------------------------------------------------------------------- /storage/woz_2.0.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanlinchao/bert-dst/5295b2d7e74c57aa32a02aadeb3ab70d22af033b/storage/woz_2.0.zip -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # TASK can be "dstc2_clean", "woz2", "sim-m", or "sim-r" 4 | TASK="dstc2_clean" 5 | # Directory for dstc2-clean, woz_2.0, sim-M, or sim-R, which contains json files 6 | DATA_DIR=/path/to/dstc2-clean 7 | # Directory of the pre-trained [BERT-Base, Uncased] model 8 | PRETRAINED_BERT=/path/to/uncased_L-12_H-768_A-12 9 | # Output directory of trained checkpoints 10 | OUTPUT_DIR=/path/to/output 11 | 12 | mkdir -p $OUTPUT_DIR 13 | python main.py \ 14 | --task_name=${TASK} \ 15 | --do_train=true \ 16 | --train_batch_size=16 \ 17 | --slot_value_dropout=0.0 \ 18 | --max_seq_length=180 \ 19 | --data_dir=$DATA_DIR \ 20 | --vocab_file=${PRETRAINED_BERT}/vocab.txt \ 21 | --bert_config_file=${PRETRAINED_BERT}/bert_config.json \ 22 | --init_checkpoint=${PRETRAINED_BERT}/bert_model.ckpt \ 23 | --learning_rate=2e-5 \ 24 | --num_train_epochs=100 \ 25 | --output_dir=$OUTPUT_DIR \ 26 | 2>&1 | tee -a $OUTPUT_DIR/train.log 27 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class InputExample(object): 5 | """A single training/test example for simple sequence classification.""" 6 | 7 | def __init__(self, guid, text_a, text_b, text_a_label=None, 8 | text_b_label=None, class_label=None): 9 | """Constructs a InputExample. 10 | """ 11 | self.guid = guid 12 | self.text_a = text_a 13 | self.text_b = text_b 14 | self.text_a_label = text_a_label 15 | self.text_b_label = text_b_label 16 | self.class_label = class_label 17 | 18 | 19 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 20 | """Truncates a sequence pair in place to the maximum length. 21 | 22 | Copied from bert/run_classifier.py 23 | """ 24 | 25 | # This is a simple heuristic which will always truncate the longer sequence 26 | # one token at a time. This makes more sense than truncating an equal percent 27 | # of tokens from each, since if one sequence is very short then each token 28 | # that's truncated likely contains more information than a longer sequence. 29 | while True: 30 | total_length = len(tokens_a) + len(tokens_b) 31 | if total_length <= max_length: 32 | break 33 | if len(tokens_a) > len(tokens_b): 34 | tokens_a.pop() 35 | else: 36 | tokens_b.pop() 37 | 38 | 39 | def truncate_length_and_warn(tokens_a, tokens_b, max_seq_length, guid): 40 | # Modifies `tokens_a` and `tokens_b` in place so that the total 41 | # length is less than the specified length. 42 | # Account for [CLS], [SEP], [SEP] with "- 3" 43 | if len(tokens_a) + len(tokens_b) > max_seq_length - 3: 44 | tf.logging.info("Truncate Example %s. Total len=%d." % (guid, len(tokens_a) + len(tokens_b))) 45 | input_text_too_long = True 46 | else: 47 | input_text_too_long = False 48 | 49 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 50 | 51 | return input_text_too_long 52 | 53 | 54 | def get_token_label_ids(token_labels_a, token_labels_b, max_seq_length): 55 | token_label_ids = [] 56 | token_label_ids.append(0) 57 | 58 | for token_label in token_labels_a: 59 | token_label_ids.append(token_label) 60 | 61 | token_label_ids.append(0) 62 | 63 | for token_label in token_labels_b: 64 | token_label_ids.append(token_label) 65 | 66 | token_label_ids.append(0) 67 | 68 | while len(token_label_ids) < max_seq_length: 69 | token_label_ids.append(0) 70 | 71 | assert len(token_label_ids) == max_seq_length 72 | return token_label_ids 73 | 74 | 75 | def get_start_end_pos(class_type, token_label_ids, max_seq_length): 76 | if class_type == 'copy_value' and 1 not in token_label_ids: 77 | raise ValueError('Copy value but token_label not detected.') 78 | if class_type != 'copy_value': 79 | start_pos = 0 80 | end_pos = 0 81 | else: 82 | start_pos = token_label_ids.index(1) 83 | end_pos = max_seq_length - 1 - token_label_ids[::-1].index(1) 84 | # tf.logging.info('token_label_ids: %s' % str(token_label_ids)) 85 | # tf.logging.info('start_pos: %d' % start_pos) 86 | # tf.logging.info('end_pos: %d' % end_pos) 87 | for i in range(max_seq_length): 88 | if i >= start_pos and i <= end_pos: 89 | assert token_label_ids[i] == 1 90 | else: 91 | assert token_label_ids[i] == 0 92 | return start_pos, end_pos 93 | 94 | 95 | def get_bert_input(tokens_a, tokens_b, max_seq_length, tokenizer): 96 | # The convention in BERT is: 97 | # (a) For sequence pairs: 98 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 99 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 100 | # (b) For single sequences: 101 | # tokens: [CLS] the dog is hairy . [SEP] 102 | # type_ids: 0 0 0 0 0 0 0 103 | # 104 | # Where "type_ids" are used to indicate whether this is the first 105 | # sequence or the second sequence. The embedding vectors for `type=0` and 106 | # `type=1` were learned during pre-training and are added to the wordpiece 107 | # embedding vector (and position vector). This is not *strictly* necessary 108 | # since the [SEP] token unambiguously separates the sequences, but it makes 109 | # it easier for the model to learn the concept of sequences. 110 | # 111 | # For classification tasks, the first vector (corresponding to [CLS]) is 112 | # used as the "sentence vector". Note that this only makes sense because 113 | # the entire model is fine-tuned. 114 | tokens = [] 115 | segment_ids = [] 116 | 117 | tokens.append("[CLS]") 118 | segment_ids.append(0) 119 | 120 | for token in tokens_a: 121 | tokens.append(token) 122 | segment_ids.append(0) 123 | 124 | tokens.append("[SEP]") 125 | segment_ids.append(0) 126 | 127 | for token in tokens_b: 128 | tokens.append(token) 129 | segment_ids.append(1) 130 | 131 | tokens.append("[SEP]") 132 | segment_ids.append(1) 133 | 134 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 135 | 136 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 137 | # tokens are attended to. 138 | input_mask = [1] * len(input_ids) 139 | 140 | # Zero-pad up to the sequence length. 141 | while len(input_ids) < max_seq_length: 142 | input_ids.append(0) 143 | input_mask.append(0) 144 | segment_ids.append(0) 145 | 146 | assert len(input_ids) == max_seq_length 147 | assert len(input_mask) == max_seq_length 148 | assert len(segment_ids) == max_seq_length 149 | return tokens, input_ids, input_mask, segment_ids 150 | 151 | 152 | def fully_connect(logits, weights, bias=None, activation=None): 153 | out_logits = tf.matmul(logits, weights) 154 | if bias is not None: 155 | out_logits = tf.nn.bias_add(out_logits, bias) 156 | if activation == 'relu': 157 | out_logits = tf.nn.relu(out_logits) 158 | return out_logits 159 | 160 | 161 | def fully_connect_layers(input_layer, list_weights, list_bias): 162 | """Fully conntect multiple layers, with 163 | (1) input layer unchanged. 164 | (2) all layers have relu activation except for the last layer. 165 | """ 166 | if len(list_weights) == 1: 167 | logits = fully_connect(input_layer, list_weights[0], list_bias[0]) 168 | else: 169 | logits = fully_connect(input_layer, list_weights[0], list_bias[0], activation='relu') 170 | if len(list_weights) > 2: 171 | for l_idx in range(1, len(list_weights) - 1): 172 | logits = fully_connect(logits, list_weights[l_idx], list_bias[l_idx], activation='relu') 173 | logits = fully_connect(logits, list_weights[-1], list_bias[-1]) 174 | return logits 175 | 176 | --------------------------------------------------------------------------------