├── .gitignore ├── LICENSE ├── README.md ├── exec_data_process.sh ├── exec_main.sh ├── requirements.txt └── src ├── custom_dataset.py ├── data_process.py ├── entity_bert.py ├── layers.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Large size files 132 | data/ 133 | saved_models/ 134 | 135 | # Jupyter notebook 136 | nohup.out 137 | *.ipynb 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jaewoo Song 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bert-crf-entity-extraction-pytorch 2 | This repository implements the entity extraction task using pre-trained **BERT**[[1]](#1) with an additional **CRF** (Conditional Random Field)[[2]](#2) layer. 3 | 4 | This project was originally developed for dialogue datasets, so it supports both a single-turn setting and a multi-turn setting. 5 | 6 | The single-turn setting is the same as the basic entity extraction task, but the multi-turn setting is slightly different, since it considers the dialogue context (the previous utterances) when extracting entities from the current utterance. 7 | 8 | The multi-turn context handling is based on the **ReCoSa** (Relevant Contexts with Self-attention)[[3]](#3) structure. 9 | 10 | You can see the details of each model in the descriptions below. 11 | 12 | The structure of the BERT+CRF entity extraction model in the single-turn / multi-turn setting: 13 | 14 |
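To make the flow concrete, here is a minimal single-turn sketch of the emission-then-CRF pipeline this model follows. It is an illustrative simplification of `src/entity_bert.py`, not the exact module: the input sentence, the tag count, and the dummy labels are placeholders.

```python
import torch
from torch import nn
from torchcrf import CRF  # from the pytorch-crf package in requirements.txt
from transformers import BertModel, BertTokenizer

bert_type = "bert-base-uncased"
num_classes = 5  # placeholder for len(class_dict): 'O' plus the B-/I-/E- entity tags

tokenizer = BertTokenizer.from_pretrained(bert_type)
bert = BertModel.from_pretrained(bert_type)
emission_layer = nn.Linear(bert.config.hidden_size, num_classes)
crf = CRF(num_classes, batch_first=True)

encoded = tokenizer("book a table for two at seven", return_tensors="pt")
hidden = bert(**encoded).last_hidden_state  # (B, L, d_h)
emissions = emission_layer(hidden)          # (B, L, C): per-token tag scores

mask = encoded["attention_mask"].bool()              # (B, L)
dummy_tags = torch.zeros_like(encoded["input_ids"])  # placeholder gold tag ids

# Training objective: the negative mean log-likelihood of the gold tag sequence.
loss = -crf(emissions, dummy_tags, mask=mask, reduction="mean")

# Inference: Viterbi decoding of the most likely tag sequence per sentence.
best_paths = crf.decode(emissions, mask=mask)  # a list of tag-id lists
print(loss.item(), best_paths[0])
```

The multi-turn model differs only before the emission layer: a ReCoSa-style context encoder summarizes up to `max_turns` previous utterances, and the resulting context vector is concatenated to each token representation (see `src/entity_bert.py`).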
15 | 16 | --- 17 | 18 | ### Arguments 19 | 20 | **Arguments for data pre-processing** 21 | 22 | | Argument | Type | Description | Default | 23 | | --------------------- | -------- | ------------------------------------------------------------ | ---------------------- | 24 | | `seed` | `int` | The random seed. | `0` | 25 | | `data_dir` | `str` | The parent data directory. | `"data"` | 26 | | `raw_dir` | `str` | The directory which contains the raw data json files. | `"raw"` | 27 | | `save_dir` | `str` | The directory which will contain the parsed data pickle files. | `"processed"` | 28 | | `bert_type` | `str` | The BERT type to load. | `"bert-base-uncased"` | 29 | | `train_ratio` | `float` | The ratio of the train set to the total number of dialogues in each file. | `0.8` | 30 | 31 |
32 | 33 | **Arguments for training/evaluating** 34 | 35 | | Argument | Type | Description | Default | 36 | | -------------------- | ------- | ------------------------------------------------------------ | --------------------- | 37 | | `seed` | `int` | The random seed. | `0` | 38 | | `turn_type` | `str` | The turn type setting. (`"single"` or `"multi"`) | *YOU SHOULD SPECIFY* | 39 | | `bert_type` | `str` | The BERT type to load. | `"bert-base-uncased"` | 40 | | `pooling` | `str` | The pooling policy when using the multi-turn setting. | `"cls"` | 41 | | `data_dir` | `str` | The parent data directory. | `"data"` | 42 | | `processed_dir` | `str` | The directory which contains the parsed data pickle files. | `"processed"` | 43 | | `ckpt_dir` | `str` | The path for saved checkpoints. | `"saved_models"` | 44 | | `gpu` | `int` | The index of a GPU to use. | `0` | 45 | | `sp1_token` | `str` | The speaker1(USER) token. | `"[USR]"` | 46 | | `sp2_token` | `str` | The speaker2(SYSTEM) token. | `"[SYS]"` | 47 | | `max_len` | `int` | The max length of each utterance. | `128` | 48 | | `max_turns` | `int` | The maximum number of dialogue history turns to attend to in the multi-turn setting. | `5` | 49 | | `dropout` | `float` | The dropout rate. | `0.1` | 50 | | `context_d_ff` | `int` | The size of intermediate hidden states in the feed-forward layer. | `2048` | 51 | | `context_num_heads` | `int` | The number of heads for the multi-head attention. | `8` | 52 | | `context_dropout` | `float` | The dropout rate for the context encoder. | `0.1` | 53 | | `context_num_layers` | `int` | The number of layers in the context encoder. | `2` | 54 | | `learning_rate` | `float` | The initial learning rate. | `5e-5` | 55 | | `warmup_ratio` | `float` | The ratio of warmup steps to the total training steps. (See the worked example below.) | `0.1` | 56 | | `batch_size` | `int` | The batch size. | `8` | 57 | | `num_workers` | `int` | The number of sub-processes for data loading. | `4` | 58 | | `num_epochs` | `int` | The number of training epochs. | `10` | 59 | 60 |
61 | 62 |
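As a worked example of how the scheduler-related arguments combine, the snippet below mirrors the step computation in `src/main.py`; the number of training samples is made up for illustration.

```python
# Warmup length as derived in src/main.py; the dataset size here is illustrative.
num_train_samples = 4000         # hypothetical size of the train set
batch_size, num_epochs = 8, 10   # the defaults from the table above
warmup_ratio = 0.1

num_batches = num_train_samples // batch_size          # 500 optimizer steps per epoch
total_train_steps = num_epochs * num_batches           # 5000 steps in total
warmup_steps = int(warmup_ratio * total_train_steps)   # 500 warmup steps
print(total_train_steps, warmup_steps)                 # 5000 500
```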
63 | 64 | ### Dataset 65 | 66 | This repository uses Google's Taskmaster-2[[4]](#4) dataset for the entity extraction task. 67 | 68 | You should first download the data (`"TM-2-2020"`) and collect all json files from the `"TM-2-2020/data"` directory to properly run this project. 69 | 70 | You can see the details of how to use the Taskmaster-2 dataset in the next section. 71 |
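For reference, the snippet below shows only the fields this project actually reads from each Taskmaster-2 json file (the file name is one example of the seven domain files; the full parsing logic lives in `src/data_process.py`).

```python
import json

# One of the seven Taskmaster-2 domain files; all of them have the same layout.
with open("TM-2-2020/data/restaurant-search.json", "r") as f:
    dialogues = json.load(f)  # the top level is a list of dialogues

for turn in dialogues[0]["utterances"]:
    print(turn["speaker"], "=>", turn["text"])
    # 'segments' exists only for turns that contain annotated entities.
    for seg in turn.get("segments", []):
        print("  entity:", seg["text"], "| tag:", seg["annotations"][0]["name"])
```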
73 | 74 |
75 | 76 | ### How to run 77 | 78 | 1. Install all required packages. 79 | 80 | ```shell 81 | pip install -r requirements.txt 82 | ``` 83 | 84 |
85 | 86 | 2. Make the directory `{data_dir}/{raw_dir}` and put the json files there, as mentioned in the previous section. 87 | 88 | In the default setting, the structure of the whole data directory should look like below. 89 | 90 | ``` 91 | data 92 | └--raw 93 | └--flight.json 94 | └--food-ordering.json 95 | └--hotels.json 96 | └--movies.json 97 | └--music.json 98 | └--restaurant-search.json 99 | └--sports.json 100 | ``` 101 | 102 |
103 | 104 | 3. Run the data processing script. 105 | 106 | ```shell 107 | sh exec_data_process.sh 108 | ``` 109 | 110 | After running it, you will get the processed files like below in the default setting. 111 | 112 | ``` 113 | data 114 | └--raw 115 | └--flight.json 116 | └--food-ordering.json 117 | └--hotels.json 118 | └--movies.json 119 | └--music.json 120 | └--restaurant-search.json 121 | └--sports.json 122 | └--processed 123 | └--class_dict.json 124 | └--train_tokens.pkl 125 | └--train_tags.pkl 126 | └--valid_tokens.pkl 127 | └--valid_tags.pkl 128 | └--test_tokens.pkl 129 | └--test_tags.pkl 130 | ``` 131 | 132 |
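If you want to sanity-check these outputs, a small sketch like the following works in the default setting; the paths assume `data_dir="data"` and `save_dir="processed"`.

```python
import json
import pickle

with open("data/processed/class_dict.json", "r") as f:
    class_dict = json.load(f)  # maps each tag name (including 'O') to a class id
print(len(class_dict), "classes, e.g.", list(class_dict)[:5])

with open("data/processed/train_tokens.pkl", "rb") as f:
    tokens = pickle.load(f)  # per dialogue, per utterance: [speaker, subword, subword, ...]
with open("data/processed/train_tags.pkl", "rb") as f:
    tags = pickle.load(f)    # parallel tag strings: 'O', 'B-...', 'I-...', 'E-...'

first_utterance = tokens[0][0]
print(first_utterance[0], first_utterance[1:6])  # the speaker, then BERT subwords
print(tags[0][0][:5])  # tags align with the subwords (the speaker slot excluded)
```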
133 | 134 | 4. Run the main script and check the results. Note that `exec_main.sh` uses the placeholder `--turn_type=TURN_TYPE`, so replace `TURN_TYPE` with either `single` or `multi` before running. 135 | 136 | ```shell 137 | sh exec_main.sh 138 | ``` 139 | 140 |
141 | 142 | --- 143 | 144 | ### Results 145 | 146 | | Turn type | Pooling | Validation F1 | Test F1 | 147 | | --------- | ------- | ------------- | ---------- | 148 | | Single | - | 0.6719 | 0.6755 | 149 | | Multi | CLS | **0.7148** | **0.7118** | 150 | | Multi | Mean | 0.7132 | 0.7095 | 151 | | Multi | Max | 0.7116 | 0.7104 | 152 | 153 |
154 | 155 | --- 156 | 157 | ### References 158 | 159 | [1] Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). Bert: Pre-training of deep bidirectional transformers for language understanding. *arXiv preprint arXiv:1810.04805*. ([https://arxiv.org/abs/1810.04805](https://arxiv.org/abs/1810.04805)) 160 | 161 | [2] Lafferty, J., McCallum, A., & Pereira, F. C. (2001). Conditional random fields: Probabilistic models for segmenting and labeling sequence data. ([https://repository.upenn.edu/cis_papers/159/](https://repository.upenn.edu/cis_papers/159/)) 162 | 163 | [3] Zhang, H., Lan, Y., Pang, L., Guo, J., & Cheng, X. (2019). Recosa: Detecting the relevant contexts with self-attention for multi-turn dialogue generation. *arXiv preprint arXiv:1907.05339*. ([https://arxiv.org/abs/1907.05339](https://arxiv.org/abs/1907.05339)) 164 | 165 | [4] Taskmaster-2 . (2020). ([https://research.google/tools/datasets/taskmaster-2/](https://research.google/tools/datasets/taskmaster-2/)) 166 | -------------------------------------------------------------------------------- /exec_data_process.sh: -------------------------------------------------------------------------------- 1 | python src/data_process.py \ 2 | --seed=0 \ 3 | --data_dir="data" \ 4 | --raw_dir="raw" \ 5 | --save_dir="processed" \ 6 | --bert_type="bert-base-uncased" \ 7 | --train_ratio=0.8 8 | -------------------------------------------------------------------------------- /exec_main.sh: -------------------------------------------------------------------------------- 1 | python src/main.py \ 2 | --seed=0 \ 3 | --turn_type=TURN_TYPE \ 4 | --bert_type="bert-base-uncased" \ 5 | --pooling="max" \ 6 | --data_dir="data" \ 7 | --processed_dir="processed" \ 8 | --ckpt_dir="saved_models" \ 9 | --gpu="0" \ 10 | --sp1_token="[USR]" \ 11 | --sp2_token="[SYS]" \ 12 | --max_len=128 \ 13 | --max_turns=5 \ 14 | --dropout=0.1 \ 15 | --context_d_ff=2048 \ 16 | --context_num_heads=8 \ 17 | --context_dropout=0.1 \ 18 | --context_num_layers=2 \ 19 | --learning_rate=5e-5 \ 20 | --warmup_ratio=0.1 \ 21 | --batch_size=8 \ 22 | --num_workers=4 \ 23 | --num_epochs=10 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | transformers==4.11.3 3 | scikit-learn==1.0.1 4 | pytorch-crf==0.7.2 5 | seqeval==1.2.2 6 | -------------------------------------------------------------------------------- /src/custom_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from tqdm import tqdm 3 | 4 | import torch 5 | import numpy as np 6 | import pickle 7 | 8 | 9 | class CustomDataset(Dataset): 10 | def __init__(self, args, vocab, class_dict, prefix): 11 | with open(f"{args.data_dir}/{args.processed_dir}/{prefix}_tokens.pkl", 'rb') as f: 12 | tokens = pickle.load(f) 13 | 14 | with open(f"{args.data_dir}/{args.processed_dir}/{prefix}_tags.pkl", 'rb') as f: 15 | tags = pickle.load(f) 16 | 17 | self.input_ids = [] 18 | self.labels = [] 19 | self.valid_lens = [] 20 | self.turns = [] 21 | 22 | if args.turn_type == 'single': 23 | self.process_single_turn(args, vocab, class_dict, tokens, tags) 24 | elif args.turn_type == 'multi': 25 | self.process_multi_turns(args, vocab, class_dict, tokens, tags) 26 | 27 | assert len(self.input_ids) == len(self.labels) 28 | assert len(self.input_ids) == len(self.valid_lens) 29 | assert len(self.input_ids) == 
len(self.turns) 30 | 31 | print(f"{len(self.input_ids)} samples are prepared for {prefix} set.") 32 | 33 | self.input_ids = torch.LongTensor(self.input_ids) # (N, L) or (N, T, L) 34 | self.labels = torch.LongTensor(self.labels) # (N, L) 35 | self.valid_lens = torch.LongTensor(self.valid_lens) # (N) 36 | self.turns = torch.LongTensor(self.turns) # (N) 37 | 38 | def process_single_turn(self, args, vocab, class_dict, tokens, tags): 39 | assert len(tokens) == len(tags) 40 | 41 | for d in tqdm(range(len(tokens))): 42 | dial_tokens, dial_tags = tokens[d], tags[d] 43 | assert len(dial_tokens) == len(dial_tags) 44 | 45 | for u in range(len(dial_tokens)): 46 | utter_tokens, utter_tags = dial_tokens[u], dial_tags[u] 47 | sp, utter_tokens = utter_tokens[0], utter_tokens[1:] 48 | assert len(utter_tokens) == len(utter_tags) 49 | 50 | token_ids = [vocab[token] for token in utter_tokens] 51 | tag_ids = [class_dict[tag] for tag in utter_tags] 52 | 53 | if sp == "USER": # Speaker1: USER 54 | sp_id = args.sp1_id 55 | token_ids, tag_ids, valid_len = self.pad_or_truncate(args, sp_id, token_ids, tag_ids) 56 | 57 | self.input_ids.append(token_ids) 58 | self.labels.append(tag_ids) 59 | self.valid_lens.append(valid_len) 60 | self.turns.append(0) 61 | 62 | def process_multi_turns(self, args, vocab, class_dict, tokens, tags): 63 | assert len(tokens) == len(tags) 64 | 65 | for d in tqdm(range(len(tokens))): 66 | dial_tokens, dial_tags = tokens[d], tags[d] 67 | assert len(dial_tokens) == len(dial_tags) 68 | 69 | token_hists, tag_hists, len_hists = [], [], [] 70 | for u in range(len(dial_tokens)): 71 | utter_tokens, utter_tags = dial_tokens[u], dial_tags[u] 72 | sp, utter_tokens = utter_tokens[0], utter_tokens[1:] 73 | assert len(utter_tokens) == len(utter_tags) 74 | 75 | token_ids = [vocab[token] for token in utter_tokens] 76 | tag_ids = [class_dict[tag] for tag in utter_tags] 77 | 78 | if sp == "USER": # Speaker1: USER 79 | sp_id = args.sp1_id 80 | token_ids, tag_ids, valid_len = self.pad_or_truncate(args, sp_id, token_ids, tag_ids) 81 | elif sp == "ASSISTANT": # Speaker2: SYSTEM 82 | sp_id = args.sp2_id 83 | token_ids, tag_ids, valid_len = self.pad_or_truncate(args, sp_id, token_ids) 84 | 85 | token_hists.append(token_ids) 86 | tag_hists.append(tag_ids) 87 | len_hists.append(valid_len) 88 | 89 | assert len(token_hists) == len(tag_hists) 90 | assert len(tag_hists) == len(len_hists) 91 | 92 | init_ids = [args.cls_id] + [args.pad_id] * (args.max_len-2) + [args.sep_id] 93 | for u in range(len(token_hists)): 94 | token_ids, tag_ids, valid_len = token_hists[u], tag_hists[u], len_hists[u] 95 | if token_ids[1] == args.sp1_id: 96 | token_hist = token_hists[max(u+1-args.max_turns, 0):u+1] 97 | assert len(token_hist[-1]) == len(tag_ids) 98 | assert len(token_hist) <= args.max_turns 99 | self.turns.append(len(token_hist)-1) 100 | token_hist += [init_ids] * (args.max_turns-len(token_hist)) 101 | assert len(token_hist) == args.max_turns 102 | self.input_ids.append(token_hist) 103 | self.labels.append(tag_ids) 104 | self.valid_lens.append(valid_len) 105 | 106 | def pad_or_truncate(self, args, sp_id, token_ids, tag_ids=None): 107 | token_ids = [args.cls_id, sp_id] + token_ids + [args.sep_id] 108 | if len(token_ids) <= args.max_len: 109 | pad_len = args.max_len - len(token_ids) 110 | token_ids += ([args.pad_id] * pad_len) 111 | 112 | valid_len = -1 113 | if tag_ids is not None: 114 | tag_ids = [args.o_id, args.o_id] + tag_ids + [args.o_id] 115 | valid_len = len(tag_ids) 116 | tag_ids += ([args.o_id] * pad_len) 117 | else: 118 | 
token_ids = token_ids[:args.max_len] 119 | token_ids[-1] = args.sep_id 120 | 121 | valid_len = -1 122 | if tag_ids is not None: 123 | tag_ids = [args.o_id, args.o_id] + tag_ids + [args.o_id] 124 | tag_ids = tag_ids[:args.max_len] 125 | tag_ids[-1] = args.o_id 126 | valid_len = args.max_len 127 | 128 | assert len(token_ids) == args.max_len 129 | if tag_ids is not None: 130 | assert len(token_ids) == len(tag_ids) 131 | 132 | return token_ids, tag_ids, valid_len 133 | 134 | def __len__(self): 135 | return self.input_ids.shape[0] 136 | 137 | def __getitem__(self, idx): 138 | return self.input_ids[idx], self.labels[idx], self.valid_lens[idx], self.turns[idx] 139 | -------------------------------------------------------------------------------- /src/data_process.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from glob import glob 3 | from transformers import BertTokenizer 4 | 5 | import argparse 6 | import os 7 | import random 8 | import json 9 | import pickle 10 | 11 | 12 | def load_file(file, tokenizer): 13 | total_tokens = [] 14 | total_tags = [] 15 | 16 | with open(file, 'r') as f: 17 | data = json.load(f) 18 | 19 | for dial in tqdm(data): 20 | dial_tokens, dial_tags = [], [] 21 | turns = dial['utterances'] 22 | for turn in turns: 23 | sp = turn['speaker'] 24 | text = turn['text'] 25 | 26 | tokens = tokenizer.tokenize(text) 27 | entity_tags = ['O'] * len(tokens) 28 | 29 | if 'segments' in turn: 30 | segs = turn['segments'] 31 | tokens, entity_tags = find_entities(tokens, segs, entity_tags, tokenizer) 32 | 33 | assert len(tokens) == len(entity_tags) 34 | 35 | dial_tokens.append([sp] + tokens) 36 | dial_tags.append(entity_tags) 37 | 38 | assert len(dial_tokens) == len(dial_tags) 39 | 40 | total_tokens.append(dial_tokens) 41 | total_tags.append(dial_tags) 42 | 43 | assert len(total_tokens) == len(total_tags) 44 | 45 | return total_tokens, total_tags # (N, T, L), (N, T, L) 46 | 47 | 48 | def find_entities(tokens, segs, entity_tags, tokenizer): 49 | entity_list = [(seg['text'], seg['annotations'][0]['name']) for seg in segs] 50 | checked = [False] * len(tokens) 51 | 52 | for entity in entity_list: 53 | value, tag = entity 54 | entity_tokens = tokenizer.tokenize(value) 55 | 56 | entity_tags, checked = find_sublist(tokens, entity_tokens, tag, entity_tags, checked) 57 | 58 | return tokens, entity_tags 59 | 60 | 61 | def find_sublist(full, sub, tag, entity_tags, checked): 62 | for i, e in enumerate(full): 63 | if e == sub[0] and not checked[i]: 64 | cand = full[i:i+len(sub)] 65 | 66 | if cand == sub: 67 | checked[i] = True 68 | entity_tags[i] = f'B-{tag}' 69 | 70 | if f'B-{tag}' not in class_dict: 71 | class_dict[f'B-{tag}'] = len(class_dict) 72 | class_dict[f'I-{tag}'] = len(class_dict) 73 | class_dict[f'E-{tag}'] = len(class_dict) 74 | 75 | if len(sub) > 1: 76 | entity_tags[i+len(sub)-1] = f'E-{tag}' 77 | entity_tags = [f'I-{tag}' if cur_tag == 'O' and (j>i and j (B*T, L) 48 | bert_masks_flattened = bert_masks.view(batch_size * self.max_turns, -1) # (B, T, L) => (B*T, L) 49 | 50 | output = self.bert(input_ids=x_flattened.long(), attention_mask=bert_masks_flattened)[0] # (B*T, L, d_h) 51 | output = output.view(batch_size, self.max_turns, -1, self.hidden_size) # (B*T, L, d_h) => (B, T, L, d_h) 52 | 53 | history_embs = self.embed_context(output) # (B, T, d_h) 54 | encoder_output = self.context_encoder(history_embs, e_masks.unsqueeze(1)) # (B, T, d_h) 55 | 56 | context_vec = encoder_output[torch.arange(encoder_output.shape[0]), turns] 
# (B, d_h) 57 | output = output[torch.arange(output.shape[0]), turns] # (B, L, d_h) 58 | seq_len = output.shape[1] 59 | output = torch.cat((output, context_vec.unsqueeze(1).repeat(1, seq_len,1)), dim=-1) # (B, L, 2*d_h) 60 | 61 | x_masks = bert_masks[torch.arange(bert_masks.shape[0]), turns] # (B, L) 62 | else: 63 | x_masks = self.make_bert_mask(x, pad_id) # (B, L) 64 | 65 | output = self.bert(input_ids=x, attention_mask=x_masks)[0] # (B, L, d_h) 66 | 67 | emissions = self.position_wise_ff(output) # (B, L, C) 68 | 69 | log_likelihood, sequence_of_tags = self.crf(emissions, tags, mask=x_masks.bool(), reduction='mean'), self.crf.decode(emissions, mask=x_masks.bool()) 70 | return log_likelihood, sequence_of_tags # (), (B, L) 71 | 72 | def init_model(self): 73 | init_list = [self.dropout, self.position_wise_ff, self.crf] 74 | for module in init_list: 75 | for param in module.parameters(): 76 | if param.dim() > 1: 77 | nn.init.xavier_uniform_(param) 78 | 79 | def embed_context(self, bert_output): 80 | if self.pooling == 'cls': 81 | return bert_output[:, :, 0] # (B, T, d_h) 82 | elif self.pooling == 'mean': 83 | return torch.mean(bert_output, dim=2) 84 | elif self.pooling == 'max': 85 | return torch.max(bert_output, dim=2).values 86 | 87 | def make_bert_mask(self, x, pad_id): 88 | bert_masks = (x != pad_id).float() # (B, L) 89 | return bert_masks 90 | 91 | def make_encoder_mask(self, turns, num_contexts): 92 | batch_size = turns.shape[0] 93 | masks = torch.zeros((turns.shape[0], num_contexts), device=turns.device) 94 | masks[torch.arange(num_contexts, device=masks.device) < turns[..., None]] = 1.0 95 | 96 | return masks 97 | 98 | 99 | class ContextEncoder(nn.Module): 100 | def __init__(self, d_model, d_ff, num_heads, dropout, num_layers, max_turns, p_dim, device): 101 | super().__init__() 102 | self.d_model = d_model 103 | self.d_ff = d_ff 104 | self.num_heads = num_heads 105 | self.dropout = dropout 106 | self.num_layers = num_layers 107 | self.max_turns = max_turns 108 | self.p_dim = p_dim 109 | self.device = device 110 | 111 | self.positional_encoder = PositionalEncoder(self.max_turns, self.p_dim, self.device) 112 | self.linear = nn.Linear(self.d_model+self.p_dim, self.d_model) 113 | self.layers = nn.ModuleList([EncoderLayer(self.d_model, self.d_ff, self.num_heads, self.dropout) for i in range(self.num_layers)]) 114 | self.layer_norm = LayerNormalization(self.d_model) 115 | 116 | def forward(self, x, e_masks): 117 | x = self.positional_encoder(x, cal='concat') # (B, T, d_h) 118 | x = self.linear(x) # (B, T, d_h) 119 | for i in range(self.num_layers): 120 | x = self.layers[i](x, e_masks) 121 | 122 | return self.layer_norm(x) 123 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | import torch 4 | import math 5 | 6 | 7 | class EncoderLayer(nn.Module): 8 | def __init__(self, d_model, d_ff, num_heads, dropout): 9 | super().__init__() 10 | self.d_model = d_model 11 | self.d_ff = d_ff 12 | self.num_heads = num_heads 13 | self.dropout = dropout 14 | 15 | self.layer_norm_1 = LayerNormalization(self.d_model) 16 | self.multihead_attention = MultiheadAttention(self.d_model, self.num_heads, self.dropout) 17 | self.drop_out_1 = nn.Dropout(self.dropout) 18 | 19 | self.layer_norm_2 = LayerNormalization(self.d_model) 20 | self.feed_forward = FeedFowardLayer(self.d_model, self.d_ff, self.dropout) 21 | self.drop_out_2 = nn.Dropout(self.dropout) 22 | 23 
| def forward(self, x, e_mask): 24 | x_1 = self.layer_norm_1(x) # (B, L, d_model) 25 | x = x + self.drop_out_1( 26 | self.multihead_attention(x_1, x_1, x_1, mask=e_mask) 27 | ) # (B, L, d_model) 28 | x_2 = self.layer_norm_2(x) # (B, L, d_model) 29 | x = x + self.drop_out_2(self.feed_forward(x_2)) # (B, L, d_model) 30 | 31 | return x # (B, L, d_model) 32 | 33 | 34 | class MultiheadAttention(nn.Module): 35 | def __init__(self, d_model, num_heads, dropout): 36 | super().__init__() 37 | self.inf = 1e9 38 | self.d_model = d_model 39 | self.num_heads = num_heads 40 | self.d_k = d_model // num_heads 41 | 42 | # W^Q, W^K, W^V in the paper 43 | self.w_q = nn.Linear(d_model, d_model) 44 | self.w_k = nn.Linear(d_model, d_model) 45 | self.w_v = nn.Linear(d_model, d_model) 46 | 47 | self.dropout = nn.Dropout(dropout) 48 | self.attn_softmax = nn.Softmax(dim=-1) 49 | 50 | # Final output linear transformation 51 | self.w_0 = nn.Linear(d_model, d_model) 52 | 53 | def forward(self, q, k, v, mask=None): 54 | input_shape = q.shape 55 | 56 | # Linear calculation + split into num_heads 57 | q = self.w_q(q).view(input_shape[0], -1, self.num_heads, self.d_k) # (B, L, H, d_k) 58 | k = self.w_k(k).view(input_shape[0], -1, self.num_heads, self.d_k) # (B, L, H, d_k) 59 | v = self.w_v(v).view(input_shape[0], -1, self.num_heads, self.d_k) # (B, L, H, d_k) 60 | 61 | # For convenience, convert all tensors in size (B, H, L, d_k) 62 | q = q.transpose(1, 2) 63 | k = k.transpose(1, 2) 64 | v = v.transpose(1, 2) 65 | 66 | # Conduct self-attention 67 | attn_values = self.self_attention(q, k, v, mask=mask) # (B, H, L, d_k) 68 | concat_output = attn_values.transpose(1, 2)\ 69 | .contiguous().view(input_shape[0], -1, self.d_model) # (B, L, d_model) 70 | 71 | return self.w_0(concat_output) 72 | 73 | def self_attention(self, q, k, v, mask=None): 74 | # Calculate attention scores with scaled dot-product attention 75 | attn_scores = torch.matmul(q, k.transpose(-2, -1)) # (B, H, L, L) 76 | attn_scores = attn_scores / math.sqrt(self.d_k) 77 | 78 | # If there is a mask, make masked spots -INF 79 | if mask is not None: 80 | mask = mask.unsqueeze(1) # (B, 1, L) => (B, 1, 1, L) or (B, L, L) => (B, 1, L, L) 81 | attn_scores = attn_scores.masked_fill_(mask == 0, -1 * self.inf) 82 | 83 | # Softmax and multiplying K to calculate attention value 84 | attn_distribs = self.attn_softmax(attn_scores) 85 | 86 | attn_distribs = self.dropout(attn_distribs) 87 | attn_values = torch.matmul(attn_distribs, v) # (B, H, L, d_k) 88 | 89 | return attn_values 90 | 91 | 92 | class FeedFowardLayer(nn.Module): 93 | def __init__(self, d_model, d_ff, dropout): 94 | super().__init__() 95 | self.d_model = d_model 96 | self.d_ff = d_ff 97 | self.dropout = dropout 98 | 99 | self.linear_1 = nn.Linear(self.d_model, self.d_ff, bias=True) 100 | self.relu = nn.ReLU() 101 | self.linear_2 = nn.Linear(self.d_ff, self.d_model, bias=True) 102 | self.dropout = nn.Dropout(self.dropout) 103 | 104 | def forward(self, x): 105 | x = self.relu(self.linear_1(x)) # (B, L, d_ff) 106 | x = self.dropout(x) 107 | x = self.linear_2(x) # (B, L, d_model) 108 | 109 | return x 110 | 111 | 112 | class LayerNormalization(nn.Module): 113 | def __init__(self, d_model, eps=1e-6): 114 | super().__init__() 115 | self.d_model = d_model 116 | self.eps = eps 117 | self.layer = nn.LayerNorm([self.d_model], elementwise_affine=True, eps=self.eps) 118 | 119 | def forward(self, x): 120 | x = self.layer(x) 121 | 122 | return x 123 | 124 | 125 | class PositionalEncoder(nn.Module): 126 | def __init__(self, 
max_len, p_dim, device): 127 | super().__init__() 128 | self.device = device 129 | self.max_len = max_len 130 | self.p_dim = p_dim 131 | 132 | # Make initial positional encoding matrix with 0 133 | pe_matrix= torch.zeros(self.max_len, self.p_dim) # (L, d_model) 134 | 135 | # Calculating position encoding values 136 | for pos in range(self.max_len): 137 | for i in range(self.p_dim): 138 | if i % 2 == 0: 139 | pe_matrix[pos, i] = math.sin(pos / (10000 ** (2 * i / self.p_dim))) 140 | elif i % 2 == 1: 141 | pe_matrix[pos, i] = math.cos(pos / (10000 ** (2 * i / self.p_dim))) 142 | 143 | pe_matrix = pe_matrix.unsqueeze(0) # (1, L, p_dim) 144 | self.positional_encoding = pe_matrix.to(self.device).requires_grad_(False) 145 | 146 | def forward(self, x, cal='add'): 147 | assert cal == 'add' or cal == 'concat', "Please specify the calculation method, either 'add' or 'concat'." 148 | 149 | if cal == 'add': 150 | x = x * math.sqrt(self.p_dim) # (B, L, d_model) 151 | x = x + self.positional_encoding # (B, L, d_model) 152 | elif cal == 'concat': 153 | x = torch.cat((x, self.positional_encoding.repeat(x.shape[0],1,1)), dim=-1) # (B, T, d_model+p_dim) 154 | 155 | return x 156 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from torch.utils.data import DataLoader 3 | from torch.nn import functional as F 4 | from entity_bert import EntityBert 5 | from custom_dataset import CustomDataset 6 | from transformers import BertConfig, BertTokenizer, get_polynomial_decay_schedule_with_warmup 7 | from seqeval.metrics import accuracy_score, f1_score 8 | from itertools import chain 9 | 10 | import torch 11 | import random 12 | import os, sys 13 | import numpy as np 14 | import argparse 15 | import time 16 | import json 17 | 18 | 19 | def run(args): 20 | # Device setting 21 | if torch.cuda.is_available(): 22 | args.device = torch.device(f'cuda:{args.gpu}') 23 | else: 24 | print("CUDA is unavailable. 
Starting with CPU.") 25 | args.device = torch.device('cpu') 26 | 27 | print(f"{args.turn_type}-turn setting fixed.") 28 | if args.turn_type == 'multi': 29 | print(f"Pooling policy is {args.pooling}.") 30 | 31 | # Load class dictionary 32 | print("Loading the class dictionary...") 33 | with open(f"{args.data_dir}/{args.processed_dir}/class_dict.json", 'r') as f: 34 | class_dict = json.load(f) 35 | args.num_classes = len(class_dict) 36 | idx2class = {v:k for k, v in class_dict.items()} 37 | 38 | # Adding arguments 39 | bert_config = BertConfig().from_pretrained(args.bert_type) 40 | args.hidden_size = bert_config.hidden_size 41 | args.p_dim = args.hidden_size 42 | args.max_len = min(args.max_len, bert_config.max_position_embeddings) 43 | 44 | # Tokenizer 45 | print("Loading the tokenizer...") 46 | tokenizer = BertTokenizer.from_pretrained(args.bert_type) 47 | num_new_tokens = tokenizer.add_special_tokens( 48 | { 49 | 'additional_special_tokens': [args.sp1_token, args.sp2_token] 50 | } 51 | ) 52 | vocab = tokenizer.get_vocab() 53 | args.vocab_size = len(vocab) 54 | 55 | args.cls_token = tokenizer.cls_token 56 | args.sep_token = tokenizer.sep_token 57 | args.pad_token = tokenizer.pad_token 58 | 59 | args.cls_id = vocab[args.cls_token] 60 | args.sep_id = vocab[args.sep_token] 61 | args.pad_id = vocab[args.pad_token] 62 | args.sp1_id = vocab[args.sp1_token] 63 | args.sp2_id = vocab[args.sp2_token] 64 | args.o_id = class_dict['O'] 65 | 66 | # Load model & optimizer 67 | print("Loading the model and optimizer...") 68 | set_seed(args.seed) 69 | model = EntityBert(args).to(args.device) 70 | model.init_model() 71 | optim = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) 72 | 73 | if not os.path.exists(args.ckpt_dir): 74 | os.mkdir(args.ckpt_dir) 75 | 76 | # Loading datasets & dataloaders 77 | print(f"Loading {args.turn_type}-turn data...") 78 | train_set = CustomDataset(args, vocab, class_dict, prefix='train') 79 | valid_set = CustomDataset(args, vocab, class_dict, prefix='valid') 80 | test_set = CustomDataset(args, vocab, class_dict, prefix='test') 81 | train_loader = DataLoader(train_set, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True) 82 | valid_loader = DataLoader(valid_set, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True) 83 | test_loader = DataLoader(test_set, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True) 84 | 85 | # Setting scheduler 86 | num_batches = len(train_loader) 87 | args.total_train_steps = args.num_epochs * num_batches 88 | args.warmup_steps = int(args.warmup_ratio * args.total_train_steps) 89 | sched = get_polynomial_decay_schedule_with_warmup( 90 | optimizer=optim, 91 | num_warmup_steps=args.warmup_steps, 92 | num_training_steps=args.total_train_steps, 93 | power=2.0, 94 | ) 95 | 96 | # Training 97 | set_seed(args.seed) 98 | best_ckpt_path = train(args, model, optim, sched, train_loader, valid_loader, idx2class) 99 | 100 | # Testing 101 | print("Testing the model...") 102 | _, test_acc, test_f1 = evaluate(args, model, test_loader, idx2class, ckpt_path=best_ckpt_path) 103 | 104 | print("") 105 | print(f"Test accuracy: {test_acc} || Test F1 score: {test_f1}") 106 | print("GOOD BYE.") 107 | 108 | 109 | def train(args, model, optim, sched, train_loader, valid_loader, idx2class): 110 | print("Training starts.") 111 | best_f1 = 0.0 112 | patience, threshold = 0, 1e-4 113 | best_ckpt_path = None 114 | 115 | for epoch in range(1, args.num_epochs+1): 116 | model.train() 117 | 118 | 
print("#"*50 + f" Epoch: {epoch} " + "#"*50) 119 | train_losses, train_ys, train_outputs, train_lens = [], [], [], [] 120 | for i, batch in enumerate(tqdm(train_loader)): 121 | batch_x, batch_y, batch_lens, batch_turns = batch 122 | 123 | if args.turn_type == 'single': 124 | batch_x, batch_y = batch_x.to(args.device), batch_y.to(args.device) 125 | log_likelihood, outputs = model(batch_x, batch_y, args.pad_id) # (), (B, L) 126 | elif args.turn_type == 'multi': 127 | batch_x, batch_y, batch_turns = \ 128 | batch_x.to(args.device), batch_y.to(args.device), batch_turns.to(args.device) 129 | log_likelihood, outputs = model(batch_x, batch_y, args.pad_id, turns=batch_turns) # (), (B, L) 130 | 131 | loss = -1 * log_likelihood 132 | 133 | model.zero_grad() 134 | optim.zero_grad() 135 | 136 | loss.backward() 137 | optim.step() 138 | sched.step() 139 | 140 | train_losses.append(loss.detach()) 141 | train_ys.append(batch_y.detach()) 142 | train_outputs.append(outputs) 143 | train_lens.append(batch_lens) 144 | 145 | train_losses = [loss.item() for loss in train_losses] 146 | train_loss = np.mean(train_losses) 147 | train_preds, train_trues = [], [] 148 | for i in range(len(train_ys)): 149 | pred_batch, true_batch, batch_lens = train_outputs[i], train_ys[i], train_lens[i] 150 | 151 | batch_lens = batch_lens.tolist() # (B) 152 | true_batch = [batch[:batch_lens[b]] for b, batch in enumerate(true_batch.tolist())] 153 | 154 | assert len(pred_batch) == len(true_batch) 155 | train_preds += pred_batch 156 | train_trues += true_batch 157 | 158 | assert len(train_preds) == len(train_trues) 159 | for i in range(len(train_preds)): 160 | train_pred, train_true = train_preds[i], train_trues[i] 161 | train_pred = [idx2class[class_id] for class_id in train_pred] 162 | train_true = [idx2class[class_id] for class_id in train_true] 163 | 164 | train_preds[i] = train_pred 165 | train_trues[i] = train_true 166 | 167 | train_acc = accuracy_score(train_trues, train_preds) 168 | train_f1 = f1_score(train_trues, train_preds) 169 | 170 | print(f"Train loss: {train_loss} || Train accuracy: {train_acc} || Train F1 score: {train_f1}") 171 | 172 | print("Validation processing...") 173 | valid_loss, valid_acc, valid_f1 = evaluate(args, model, valid_loader, idx2class) 174 | 175 | if valid_f1 >= best_f1 + threshold: 176 | best_f1 = valid_f1 177 | patience = 0 178 | best_ckpt_path = f"{args.ckpt_dir}/ckpt_epoch={epoch}_train_f1={round(train_f1, 4)}_valid_f1={round(valid_f1, 4)}" 179 | torch.save(model.state_dict(), best_ckpt_path) 180 | print(f"***** Current best checkpoint is saved. *****") 181 | else: 182 | patience += 1 183 | print(f"The f1 score did not improve by {threshold}. Patience: {patience}") 184 | 185 | print(f"Best validtion f1 score: {best_f1}") 186 | print(f"Validation loss: {valid_loss} || Validation accuracy: {valid_acc} || Current validation F1 score: {valid_f1}") 187 | 188 | if patience == 3: 189 | print("Run out of patience. 
Abort!") 190 | break 191 | 192 | print("Training finished!") 193 | 194 | return best_ckpt_path 195 | 196 | 197 | def evaluate(args, model, eval_loader, idx2class, ckpt_path=None): 198 | if ckpt_path is not None: 199 | model.load_state_dict(torch.load(ckpt_path)) 200 | 201 | model.eval() 202 | 203 | eval_losses, eval_ys, eval_outputs, eval_lens = [], [], [], [] 204 | with torch.no_grad(): 205 | for i, batch in enumerate(tqdm(eval_loader)): 206 | batch_x, batch_y, batch_lens, batch_turns = batch 207 | 208 | if args.turn_type == 'single': 209 | batch_x, batch_y = batch_x.to(args.device), batch_y.to(args.device) 210 | log_likelihood, outputs = model(batch_x, batch_y, args.pad_id) # (), (B, L) 211 | elif args.turn_type == 'multi': 212 | batch_x, batch_y, batch_turns = \ 213 | batch_x.to(args.device), batch_y.to(args.device), batch_turns.to(args.device) 214 | log_likelihood, outputs = model(batch_x, batch_y, args.pad_id, turns=batch_turns) # (), (B, L) 215 | 216 | loss = -1 * log_likelihood 217 | 218 | eval_losses.append(loss.detach()) 219 | eval_ys.append(batch_y.detach()) 220 | eval_outputs.append(outputs) 221 | eval_lens.append(batch_lens) 222 | 223 | eval_losses = [loss.item() for loss in eval_losses] 224 | eval_loss = np.mean(eval_losses) 225 | eval_preds, eval_trues = [], [] 226 | for i in range(len(eval_ys)): 227 | pred_batch, true_batch, batch_lens = eval_outputs[i], eval_ys[i], eval_lens[i] 228 | 229 | batch_lens = batch_lens.tolist() # (B) 230 | true_batch = [batch[:batch_lens[b]] for b, batch in enumerate(true_batch.tolist())] 231 | 232 | assert len(pred_batch) == len(true_batch) 233 | eval_preds += pred_batch 234 | eval_trues += true_batch 235 | 236 | assert len(eval_preds) == len(eval_trues) 237 | for i in range(len(eval_preds)): 238 | eval_pred, eval_true = eval_preds[i], eval_trues[i] 239 | eval_pred = [idx2class[class_id] for class_id in eval_pred] 240 | eval_true = [idx2class[class_id] for class_id in eval_true] 241 | 242 | eval_preds[i] = eval_pred 243 | eval_trues[i] = eval_true 244 | 245 | eval_acc = accuracy_score(eval_trues, eval_preds) 246 | eval_f1 = f1_score(eval_trues, eval_preds) 247 | 248 | return eval_loss, eval_acc, eval_f1 249 | 250 | 251 | def set_seed(seed): 252 | np.random.seed(seed) 253 | torch.manual_seed(seed) 254 | torch.cuda.manual_seed_all(seed) 255 | random.seed(seed) 256 | 257 | 258 | if __name__=='__main__': 259 | parser = argparse.ArgumentParser() 260 | parser.add_argument('--seed', default=0, type=int, help="The random seed.") 261 | parser.add_argument('--turn_type', required=True, type=str, help="The turn type setting. 
(single or multi)") 262 | parser.add_argument('--bert_type', default="bert-base-uncased", type=str, help="The BERT type to load.") 263 | parser.add_argument('--pooling', default="cls", type=str, help="The pooling policy when using the multi-turn setting.") 264 | parser.add_argument('--data_dir', default="data", type=str, help="The parent data directory.") 265 | parser.add_argument('--processed_dir', default="processed", type=str, help="The directory which contains the parsed data pickle files.") 266 | parser.add_argument('--ckpt_dir', default="saved_models", type=str, help="The path for saved checkpoints.") 267 | parser.add_argument('--gpu', default=0, type=int, help="The index of a GPU to use.") 268 | parser.add_argument('--sp1_token', default="[USR]", type=str, help="The speaker1(USER) token.") 269 | parser.add_argument('--sp2_token', default="[SYS]", type=str, help="The speaker2(SYSTEM) token.") 270 | parser.add_argument('--max_len', default=128, type=int, help="The max length of each utterance.") 271 | parser.add_argument('--max_turns', default=5, type=int, help="The maximum number of the dialogue history to be attended in the multi-turn setting.") 272 | parser.add_argument('--dropout', default=0.1, type=float, help="The dropout rate.") 273 | parser.add_argument('--context_d_ff', default=2048, type=int, help="The size of intermediate hidden states in the feed-forward layer.") 274 | parser.add_argument('--context_num_heads', default=8, type=int, help="The number of heads for the multi-head attention.") 275 | parser.add_argument('--context_dropout', default=0.1, type=float, help="The dropout rate for the context encoder.") 276 | parser.add_argument('--context_num_layers', default=2, type=int, help="The number of layers in the context encoder.") 277 | parser.add_argument('--learning_rate', default=5e-5, type=float, help="The initial learning rate.") 278 | parser.add_argument('--warmup_ratio', default=0.1, type=float, help="The ratio of warmup steps to the total training steps.") 279 | parser.add_argument('--batch_size', default=8, type=int, help="The batch size.") 280 | parser.add_argument('--num_workers', default=4, type=int, help="The number of sub-processes for data loading.") 281 | parser.add_argument('--num_epochs', default=10, type=int, help="The number of training epochs.") 282 | 283 | args = parser.parse_args() 284 | 285 | assert args.turn_type == 'single' or args.turn_type == 'multi', print("Please specify a correct turn type, either 'single' or 'multi'.") 286 | assert args.bert_type in [ 287 | "bert-base-uncased", 288 | "bert-base-cased", 289 | "bert-large-uncased", 290 | "bert-large-cased" 291 | ] 292 | assert args.pooling in ["cls", "mean", "max"] 293 | 294 | run(args) 295 | --------------------------------------------------------------------------------