├── .gitignore
├── LICENSE
├── README.md
├── exec_data_process.sh
├── exec_main.sh
├── requirements.txt
└── src
├── custom_dataset.py
├── data_process.py
├── entity_bert.py
├── layers.py
└── main.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Large size files
132 | data/
133 | saved_models/
134 |
135 | # Output logs & notebooks
136 | nohup.out
137 | *.ipynb
138 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jaewoo Song
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # bert-crf-entity-extraction-pytorch
2 | This repository implements entity extraction using the pre-trained **BERT**[[1]](#1) with an additional **CRF** (Conditional Random Field)[[2]](#2) layer.
3 | 
4 | Originally, this project was conducted on dialogue datasets, so it supports both a single-turn setting and a multi-turn setting.
5 | 
6 | The single-turn setting is the same as the basic entity extraction task, but the multi-turn setting is a little different since it considers the dialogue contexts (previous histories) when extracting entities from the current utterance.
7 | 
8 | The multi-turn context application is based on the **ReCoSa** (Relevant Contexts with Self-attention)[[3]](#3) structure.
9 | 
10 | You can see the details of each model in the descriptions below.
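
As a rough sketch of how the multi-turn inputs are laid out (the shape conventions follow `src/custom_dataset.py` and `src/entity_bert.py`; the sizes and tensors below are purely illustrative):

```python
import torch

# Illustrative defaults: B=batch size, T=max_turns (5), L=max_len (128).
B, T, L = 8, 5, 128

input_ids = torch.zeros(B, T, L, dtype=torch.long) # up to T utterances per sample, each padded to L token ids
labels = torch.zeros(B, L, dtype=torch.long)       # BIE tag ids for the current (USER) utterance only
turns = torch.zeros(B, dtype=torch.long)           # index of the current utterance within its history

# EntityBert flattens (B, T, L) -> (B*T, L) to encode every utterance with BERT,
# pools each encoded utterance into a single vector, and runs the ReCoSa-style
# context encoder over the T pooled vectors to build the context for tagging.
```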
11 |
12 |
13 |
14 |
15 |
16 | ---
17 |
18 | ### Arguments
19 |
20 | **Arguments for data pre-processing**
21 |
22 | | Argument | Type | Description | Default |
23 | | --------------------- | -------- | ------------------------------------------------------------ | ---------------------- |
24 | | `seed` | `int` | The random seed. | `0` |
25 | | `data_dir` | `str` | The parent data directory. | `"data"` |
26 | | `raw_dir` | `str` | The directory which contains the raw data json files. | `"raw"` |
27 | | `save_dir` | `str` | The directory which will contain the parsed data pickle files. | `"processed"` |
28 | | `bert_type` | `str` | The BERT type to load. | `"bert-base-uncased"` |
29 | | `train_ratio` | `float` | The ratio of train set to the total number of dialogues in each file. | `0.8` |
30 |
31 |
32 |
33 | **Arguments for training/evaluating**
34 |
35 | | Argument | Type | Description | Default |
36 | | -------------------- | ------- | ------------------------------------------------------------ | --------------------- |
37 | | `seed` | `int` | The random seed. | `0` |
38 | | `turn_type` | `str` | The turn type setting. (`"single"` or `"multi"`) | *YOU SHOULD SPECIFY* |
39 | | `bert_type` | `str` | The BERT type to load. | `"bert-base-uncased"` |
40 | | `pooling` | `str` | The pooling policy when using the multi-turn setting. | `"cls"` |
41 | | `data_dir` | `str` | The parent data directory. | `"data"` |
42 | | `processed_dir` | `str` | The directory which contains the parsed data pickle files. | `"processed"` |
43 | | `ckpt_dir` | `str` | The path for saved checkpoints. | `"saved_models"` |
44 | | `gpu` | `int` | The index of a GPU to use. | `0` |
45 | | `sp1_token` | `str` | The speaker1(USER) token. | `"[USR]"` |
46 | | `sp2_token` | `str` | The speaker2(SYSTEM) token. | `"[SYS]"` |
47 | | `max_len` | `int` | The max length of each utterance. | `128` |
48 | | `max_turns` | `int` | The maximum number of the dialogue history to be attended in the multi-turn setting. | `5` |
49 | | `dropout` | `float` | The dropout rate. | `0.1` |
50 | | `context_d_ff` | `int` | The size of intermediate hidden states in the feed-forward layer. | `2048` |
51 | | `context_num_heads` | `int` | The number of heads for the multi-head attention. | `8` |
52 | | `context_dropout` | `float` | The dropout rate for the context encoder. | `0.1` |
53 | | `context_num_layers` | `int` | The number of layers in the context encoder. | `2` |
54 | | `learning_rate` | `float` | The initial learning rate. | `5e-5` |
55 | | `warmup_ratio` | `float` | The ratio of warmup steps to the total training steps. | `0.1` |
56 | | `batch_size` | `int` | The batch size. | `8` |
57 | | `num_workers` | `int` | The number of sub-processes for data loading. | `4` |
58 | | `num_epochs` | `int` | The number of training epochs. | `10` |
59 |
60 |
61 |
62 |
63 |
64 | ### Dataset
65 |
66 | This repository uses Google's Taskmaster-2[[4]](#4) dataset for the entity extraction task.
67 | 
68 | You should first download the data (`"TM-2-2020"`) and collect all the json files in the `"TM-2-2020/data"` directory to run this project properly.
69 | 
70 | You can see the details for using the Taskmaster-2 dataset in the next section.
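
For reference, the fields that `src/data_process.py` actually reads from each dialogue look roughly like this (a minimal sketch with invented values):

```python
# A minimal sketch of one Taskmaster-2 dialogue with invented values;
# only the fields read by src/data_process.py are shown.
dialogue = {
    "utterances": [
        {
            "speaker": "USER",
            "text": "book a table for two in seattle",
            "segments": [ # present only when the utterance contains entities
                {
                    "text": "seattle",
                    "annotations": [{"name": "location"}], # the entity tag name
                }
            ],
        },
        {"speaker": "ASSISTANT", "text": "Sure, which restaurant?"},
    ]
}
```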
71 |
72 |
73 |
74 |
75 |
76 | ### How to run
77 |
78 | 1. Install all required packages.
79 |
80 | ```shell
81 | pip install -r requirements.txt
82 | ```
83 |
84 |
85 |
86 | 2. Make the directory `{data_dir}/{raw_dir}` and put the json files there, as mentioned in the previous section.
87 | 
88 |    In the default setting, the structure of the whole data directory should look like below.
89 |
90 | ```
91 | data
92 | └--raw
93 | └--flight.json
94 | └--food-ordering.json
95 | └--hotels.json
96 | └--movies.json
97 | └--music.json
98 | └--restaurant-search.json
99 | └--sports.json
100 | ```
101 |
102 |
103 |
104 | 3. Run the data processing script.
105 |
106 | ```shell
107 | sh exec_data_process.sh
108 | ```
109 |
110 |    After running it, you will get the following processed files in the default setting.
111 |
112 | ```
113 | data
114 | └--raw
115 | └--flight.json
116 | └--food-ordering.json
117 | └--hotels.json
118 | └--movies.json
119 | └--music.json
120 | └--restaurant-search.json
121 | └--sports.json
122 | └--processed
123 | └--class_dict.json
124 | └--train_tokens.pkl
125 | └--train_tags.pkl
126 | └--valid_tokens.pkl
127 | └--valid_tags.pkl
128 | └--test_tokens.pkl
129 | └--test_tags.pkl
130 | ```
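
You can sanity-check the outputs with a few lines of Python (a minimal sketch, assuming the default paths):

```python
import json
import pickle

with open("data/processed/class_dict.json") as f:
    class_dict = json.load(f) # maps each tag ("O", "B-...", "I-...", "E-...") to its id

with open("data/processed/train_tokens.pkl", "rb") as f:
    tokens = pickle.load(f) # dialogues -> utterances; each utterance starts with its speaker

with open("data/processed/train_tags.pkl", "rb") as f:
    tags = pickle.load(f) # BIE tags aligned with the tokens (minus the leading speaker marker)

print(tokens[0][0]) # first utterance of the first dialogue, e.g. ["USER", ...]
print(tags[0][0])
```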
131 |
132 |
133 |
134 | 4. Run the main script and check the results.
135 |
136 | ```shell
137 | sh exec_main.sh
138 | ```
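
   Note that you should replace `TURN_TYPE` in `exec_main.sh` with either `"single"` or `"multi"` before running it.

   The best checkpoint, selected by the validation F1 score, is saved in `{ckpt_dir}` with a name like `ckpt_epoch={epoch}_train_f1={train_f1}_valid_f1={valid_f1}`.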
139 |
140 |
141 |
142 | ---
143 |
144 | ### Results
145 |
146 | | Turn type | Pooling | Validation F1 | Test F1 |
147 | | --------- | ------- | ------------- | ---------- |
148 | | Single | - | 0.6719 | 0.6755 |
149 | | Multi | CLS | **0.7148** | **0.7118** |
150 | | Multi | Mean | 0.7132 | 0.7095 |
151 | | Multi | Max | 0.7116 | 0.7104 |
152 |
153 |
154 |
155 | ---
156 |
157 | ### References
158 |
159 | [1] Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). Bert: Pre-training of deep bidirectional transformers for language understanding. *arXiv preprint arXiv:1810.04805*. ([https://arxiv.org/abs/1810.04805](https://arxiv.org/abs/1810.04805))
160 |
161 | [2] Lafferty, J., McCallum, A., & Pereira, F. C. (2001). Conditional random fields: Probabilistic models for segmenting and labeling sequence data. ([https://repository.upenn.edu/cis_papers/159/](https://repository.upenn.edu/cis_papers/159/))
162 |
163 | [3] Zhang, H., Lan, Y., Pang, L., Guo, J., & Cheng, X. (2019). Recosa: Detecting the relevant contexts with self-attention for multi-turn dialogue generation. *arXiv preprint arXiv:1907.05339*. ([https://arxiv.org/abs/1907.05339](https://arxiv.org/abs/1907.05339))
164 |
165 | [4] Taskmaster-2. (2020). ([https://research.google/tools/datasets/taskmaster-2/](https://research.google/tools/datasets/taskmaster-2/))
166 |
--------------------------------------------------------------------------------
/exec_data_process.sh:
--------------------------------------------------------------------------------
1 | python src/data_process.py \
2 | --seed=0 \
3 | --data_dir="data" \
4 | --raw_dir="raw" \
5 | --save_dir="processed" \
6 | --bert_type="bert-base-uncased" \
7 | --train_ratio=0.8
8 |
--------------------------------------------------------------------------------
/exec_main.sh:
--------------------------------------------------------------------------------
1 | python src/main.py \
2 | --seed=0 \
3 | --turn_type=TURN_TYPE \
4 | --bert_type="bert-base-uncased" \
5 | --pooling="max" \
6 | --data_dir="data" \
7 | --processed_dir="processed" \
8 | --ckpt_dir="saved_models" \
9 | --gpu="0" \
10 | --sp1_token="[USR]" \
11 | --sp2_token="[SYS]" \
12 | --max_len=128 \
13 | --max_turns=5 \
14 | --dropout=0.1 \
15 | --context_d_ff=2048 \
16 | --context_num_heads=8 \
17 | --context_dropout=0.1 \
18 | --context_num_layers=2 \
19 | --learning_rate=5e-5 \
20 | --warmup_ratio=0.1 \
21 | --batch_size=8 \
22 | --num_workers=4 \
23 | --num_epochs=10
24 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | transformers==4.11.3
3 | scikit-learn==1.0.1
4 | pytorch-crf==0.7.2
5 | seqeval==1.2.2
6 |
--------------------------------------------------------------------------------
/src/custom_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from tqdm import tqdm
3 |
4 | import torch
5 | import numpy as np
6 | import pickle
7 |
8 |
9 | class CustomDataset(Dataset):
10 | def __init__(self, args, vocab, class_dict, prefix):
11 | with open(f"{args.data_dir}/{args.processed_dir}/{prefix}_tokens.pkl", 'rb') as f:
12 | tokens = pickle.load(f)
13 |
14 | with open(f"{args.data_dir}/{args.processed_dir}/{prefix}_tags.pkl", 'rb') as f:
15 | tags = pickle.load(f)
16 |
17 | self.input_ids = []
18 | self.labels = []
19 | self.valid_lens = []
20 | self.turns = []
21 |
22 | if args.turn_type == 'single':
23 | self.process_single_turn(args, vocab, class_dict, tokens, tags)
24 | elif args.turn_type == 'multi':
25 | self.process_multi_turns(args, vocab, class_dict, tokens, tags)
26 |
27 | assert len(self.input_ids) == len(self.labels)
28 | assert len(self.input_ids) == len(self.valid_lens)
29 | assert len(self.input_ids) == len(self.turns)
30 |
31 |         print(f"{len(self.input_ids)} samples are prepared for the {prefix} set.")
32 |
33 | self.input_ids = torch.LongTensor(self.input_ids) # (N, L) or (N, T, L)
34 | self.labels = torch.LongTensor(self.labels) # (N, L)
35 | self.valid_lens = torch.LongTensor(self.valid_lens) # (N)
36 | self.turns = torch.LongTensor(self.turns) # (N)
37 |
38 | def process_single_turn(self, args, vocab, class_dict, tokens, tags):
39 | assert len(tokens) == len(tags)
40 |
41 | for d in tqdm(range(len(tokens))):
42 | dial_tokens, dial_tags = tokens[d], tags[d]
43 | assert len(dial_tokens) == len(dial_tags)
44 |
45 | for u in range(len(dial_tokens)):
46 | utter_tokens, utter_tags = dial_tokens[u], dial_tags[u]
47 | sp, utter_tokens = utter_tokens[0], utter_tokens[1:]
48 | assert len(utter_tokens) == len(utter_tags)
49 |
50 | token_ids = [vocab[token] for token in utter_tokens]
51 | tag_ids = [class_dict[tag] for tag in utter_tags]
52 |
53 | if sp == "USER": # Speaker1: USER
54 | sp_id = args.sp1_id
55 | token_ids, tag_ids, valid_len = self.pad_or_truncate(args, sp_id, token_ids, tag_ids)
56 |
57 | self.input_ids.append(token_ids)
58 | self.labels.append(tag_ids)
59 | self.valid_lens.append(valid_len)
60 | self.turns.append(0)
61 |
62 | def process_multi_turns(self, args, vocab, class_dict, tokens, tags):
63 | assert len(tokens) == len(tags)
64 |
65 | for d in tqdm(range(len(tokens))):
66 | dial_tokens, dial_tags = tokens[d], tags[d]
67 | assert len(dial_tokens) == len(dial_tags)
68 |
69 | token_hists, tag_hists, len_hists = [], [], []
70 | for u in range(len(dial_tokens)):
71 | utter_tokens, utter_tags = dial_tokens[u], dial_tags[u]
72 | sp, utter_tokens = utter_tokens[0], utter_tokens[1:]
73 | assert len(utter_tokens) == len(utter_tags)
74 |
75 | token_ids = [vocab[token] for token in utter_tokens]
76 | tag_ids = [class_dict[tag] for tag in utter_tags]
77 |
78 | if sp == "USER": # Speaker1: USER
79 | sp_id = args.sp1_id
80 | token_ids, tag_ids, valid_len = self.pad_or_truncate(args, sp_id, token_ids, tag_ids)
81 | elif sp == "ASSISTANT": # Speaker2: SYSTEM
82 | sp_id = args.sp2_id
83 | token_ids, tag_ids, valid_len = self.pad_or_truncate(args, sp_id, token_ids)
84 |
85 | token_hists.append(token_ids)
86 | tag_hists.append(tag_ids)
87 | len_hists.append(valid_len)
88 |
89 | assert len(token_hists) == len(tag_hists)
90 | assert len(tag_hists) == len(len_hists)
91 |
92 | init_ids = [args.cls_id] + [args.pad_id] * (args.max_len-2) + [args.sep_id]
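        # Illustrative windowing (invented indices): with max_turns=5, the current
        # utterance at u=3 gets the history [utt_0, utt_1, utt_2, utt_3] plus one
        # dummy init_ids row, and turns=3 marks its position inside the (T, L) block.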
93 | for u in range(len(token_hists)):
94 | token_ids, tag_ids, valid_len = token_hists[u], tag_hists[u], len_hists[u]
95 | if token_ids[1] == args.sp1_id:
96 | token_hist = token_hists[max(u+1-args.max_turns, 0):u+1]
97 | assert len(token_hist[-1]) == len(tag_ids)
98 | assert len(token_hist) <= args.max_turns
99 | self.turns.append(len(token_hist)-1)
100 | token_hist += [init_ids] * (args.max_turns-len(token_hist))
101 | assert len(token_hist) == args.max_turns
102 | self.input_ids.append(token_hist)
103 | self.labels.append(tag_ids)
104 | self.valid_lens.append(valid_len)
105 |
106 | def pad_or_truncate(self, args, sp_id, token_ids, tag_ids=None):
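        # A worked example (invented tokens): with max_len=8 and token_ids=[t1, t2, t3],
        # the padded result is [CLS, SP, t1, t2, t3, SEP, PAD, PAD] and valid_len=6,
        # i.e. the unpadded length counting the CLS/speaker/SEP positions.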
107 | token_ids = [args.cls_id, sp_id] + token_ids + [args.sep_id]
108 | if len(token_ids) <= args.max_len:
109 | pad_len = args.max_len - len(token_ids)
110 | token_ids += ([args.pad_id] * pad_len)
111 |
112 | valid_len = -1
113 | if tag_ids is not None:
114 | tag_ids = [args.o_id, args.o_id] + tag_ids + [args.o_id]
115 | valid_len = len(tag_ids)
116 | tag_ids += ([args.o_id] * pad_len)
117 | else:
118 | token_ids = token_ids[:args.max_len]
119 | token_ids[-1] = args.sep_id
120 |
121 | valid_len = -1
122 | if tag_ids is not None:
123 | tag_ids = [args.o_id, args.o_id] + tag_ids + [args.o_id]
124 | tag_ids = tag_ids[:args.max_len]
125 | tag_ids[-1] = args.o_id
126 | valid_len = args.max_len
127 |
128 | assert len(token_ids) == args.max_len
129 | if tag_ids is not None:
130 | assert len(token_ids) == len(tag_ids)
131 |
132 | return token_ids, tag_ids, valid_len
133 |
134 | def __len__(self):
135 | return self.input_ids.shape[0]
136 |
137 | def __getitem__(self, idx):
138 | return self.input_ids[idx], self.labels[idx], self.valid_lens[idx], self.turns[idx]
139 |
--------------------------------------------------------------------------------
/src/data_process.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from glob import glob
3 | from transformers import BertTokenizer
4 |
5 | import argparse
6 | import os
7 | import random
8 | import json
9 | import pickle
10 |
11 |
12 | def load_file(file, tokenizer):
13 | total_tokens = []
14 | total_tags = []
15 |
16 | with open(file, 'r') as f:
17 | data = json.load(f)
18 |
19 | for dial in tqdm(data):
20 | dial_tokens, dial_tags = [], []
21 | turns = dial['utterances']
22 | for turn in turns:
23 | sp = turn['speaker']
24 | text = turn['text']
25 |
26 | tokens = tokenizer.tokenize(text)
27 | entity_tags = ['O'] * len(tokens)
28 |
29 | if 'segments' in turn:
30 | segs = turn['segments']
31 | tokens, entity_tags = find_entities(tokens, segs, entity_tags, tokenizer)
32 |
33 | assert len(tokens) == len(entity_tags)
34 |
35 | dial_tokens.append([sp] + tokens)
36 | dial_tags.append(entity_tags)
37 |
38 | assert len(dial_tokens) == len(dial_tags)
39 |
40 | total_tokens.append(dial_tokens)
41 | total_tags.append(dial_tags)
42 |
43 | assert len(total_tokens) == len(total_tags)
44 |
45 | return total_tokens, total_tags # (N, T, L), (N, T, L)
46 |
47 |
48 | def find_entities(tokens, segs, entity_tags, tokenizer):
49 | entity_list = [(seg['text'], seg['annotations'][0]['name']) for seg in segs]
50 | checked = [False] * len(tokens)
51 |
52 | for entity in entity_list:
53 | value, tag = entity
54 | entity_tokens = tokenizer.tokenize(value)
55 |
56 | entity_tags, checked = find_sublist(tokens, entity_tokens, tag, entity_tags, checked)
57 |
58 | return tokens, entity_tags
59 |
60 |
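# Tagging scheme sketch (invented example): an entity tokenized as ["san", "francisco"]
# gets ["B-<tag>", "E-<tag>"]; entities spanning three or more tokens additionally get
# "I-<tag>" on every token strictly between the B- and E- positions.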
61 | def find_sublist(full, sub, tag, entity_tags, checked):
62 | for i, e in enumerate(full):
63 | if e == sub[0] and not checked[i]:
64 | cand = full[i:i+len(sub)]
65 |
66 | if cand == sub:
67 | checked[i] = True
68 | entity_tags[i] = f'B-{tag}'
69 |
70 | if f'B-{tag}' not in class_dict:
71 | class_dict[f'B-{tag}'] = len(class_dict)
72 | class_dict[f'I-{tag}'] = len(class_dict)
73 | class_dict[f'E-{tag}'] = len(class_dict)
74 |
75 | if len(sub) > 1:
76 | entity_tags[i+len(sub)-1] = f'E-{tag}'
77 |                     entity_tags = [f'I-{tag}' if cur_tag == 'O' and (j>i and j<i+len(sub)-1) else cur_tag for j, cur_tag in enumerate(entity_tags)]
78 | 
79 |     return entity_tags, checked
80 | 
--------------------------------------------------------------------------------
/src/entity_bert.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from transformers import BertModel
3 | from torchcrf import CRF
4 | from layers import *
5 | 
6 | import torch
7 | 
8 | 
9 | class EntityBert(nn.Module):
10 |     def __init__(self, args):
11 |         super().__init__()
12 | 
13 |         self.turn_type = args.turn_type
14 |         self.pooling = args.pooling
15 |         self.max_turns = args.max_turns
16 |         self.hidden_size = args.hidden_size
17 | 
18 |         self.bert = BertModel.from_pretrained(args.bert_type)
19 |         self.bert.resize_token_embeddings(args.vocab_size)
20 | 
21 |         if self.turn_type == 'multi':
22 |             self.context_encoder = ContextEncoder(
23 |                 args.hidden_size, args.context_d_ff, args.context_num_heads,
24 |                 args.context_dropout, args.context_num_layers,
25 |                 args.max_turns, args.p_dim, args.device
26 |             )
27 |             # The token states of the current utterance are concatenated with the
28 |             # context vector, so the emission layer takes a 2*d_h input.
29 |             self.position_wise_ff = nn.Linear(2*args.hidden_size, args.num_classes)
30 |         else:
31 |             self.position_wise_ff = nn.Linear(args.hidden_size, args.num_classes)
32 | 
33 |         self.dropout = nn.Dropout(args.dropout)
34 |         self.crf = CRF(args.num_classes, batch_first=True)
35 | 
36 |     def forward(self, x, tags, pad_id, turns=None):
37 |         if turns is not None: # Multi-turn setting: x is (B, T, L)
38 |             batch_size = x.shape[0]
39 | 
40 |             # Build the attention masks over turns (for the context encoder)
41 |             # and over tokens (for BERT itself).
42 |             e_masks = self.make_encoder_mask(turns, self.max_turns) # (B, T)
43 |             bert_masks = self.make_bert_mask(x, pad_id) # (B, T, L)
44 | 
45 |             # Merge the batch and turn dimensions so that every utterance in
46 |             # every history is encoded by BERT in one pass.
47 |             x_flattened = x.view(batch_size * self.max_turns, -1) # (B, T, L) => (B*T, L)
48 | bert_masks_flattened = bert_masks.view(batch_size * self.max_turns, -1) # (B, T, L) => (B*T, L)
49 |
50 | output = self.bert(input_ids=x_flattened.long(), attention_mask=bert_masks_flattened)[0] # (B*T, L, d_h)
51 | output = output.view(batch_size, self.max_turns, -1, self.hidden_size) # (B*T, L, d_h) => (B, T, L, d_h)
52 |
53 | history_embs = self.embed_context(output) # (B, T, d_h)
54 | encoder_output = self.context_encoder(history_embs, e_masks.unsqueeze(1)) # (B, T, d_h)
55 |
56 | context_vec = encoder_output[torch.arange(encoder_output.shape[0]), turns] # (B, d_h)
57 | output = output[torch.arange(output.shape[0]), turns] # (B, L, d_h)
58 | seq_len = output.shape[1]
59 | output = torch.cat((output, context_vec.unsqueeze(1).repeat(1, seq_len,1)), dim=-1) # (B, L, 2*d_h)
60 |
61 | x_masks = bert_masks[torch.arange(bert_masks.shape[0]), turns] # (B, L)
62 | else:
63 | x_masks = self.make_bert_mask(x, pad_id) # (B, L)
64 |
65 | output = self.bert(input_ids=x, attention_mask=x_masks)[0] # (B, L, d_h)
66 |
67 | emissions = self.position_wise_ff(output) # (B, L, C)
68 |
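        # pytorch-crf: calling the CRF module scores the given tag sequence and
        # returns its log-likelihood (batch-averaged with reduction='mean'), while
        # decode() runs Viterbi decoding to find the best tag sequence per sample.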
69 | log_likelihood, sequence_of_tags = self.crf(emissions, tags, mask=x_masks.bool(), reduction='mean'), self.crf.decode(emissions, mask=x_masks.bool())
70 | return log_likelihood, sequence_of_tags # (), (B, L)
71 |
72 | def init_model(self):
73 | init_list = [self.dropout, self.position_wise_ff, self.crf]
74 | for module in init_list:
75 | for param in module.parameters():
76 | if param.dim() > 1:
77 | nn.init.xavier_uniform_(param)
78 |
79 | def embed_context(self, bert_output):
80 | if self.pooling == 'cls':
81 | return bert_output[:, :, 0] # (B, T, d_h)
82 | elif self.pooling == 'mean':
83 | return torch.mean(bert_output, dim=2)
84 | elif self.pooling == 'max':
85 | return torch.max(bert_output, dim=2).values
86 |
87 | def make_bert_mask(self, x, pad_id):
88 | bert_masks = (x != pad_id).float() # (B, L)
89 | return bert_masks
90 |
91 | def make_encoder_mask(self, turns, num_contexts):
92 | batch_size = turns.shape[0]
93 |         masks = torch.zeros((batch_size, num_contexts), device=turns.device)
94 | masks[torch.arange(num_contexts, device=masks.device) < turns[..., None]] = 1.0
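        # e.g. turns=tensor([2]) with num_contexts=5 yields [[1., 1., 0., 0., 0.]]:
        # only the utterances strictly before the current one are attended to.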
95 |
96 | return masks
97 |
98 |
99 | class ContextEncoder(nn.Module):
100 | def __init__(self, d_model, d_ff, num_heads, dropout, num_layers, max_turns, p_dim, device):
101 | super().__init__()
102 | self.d_model = d_model
103 | self.d_ff = d_ff
104 | self.num_heads = num_heads
105 | self.dropout = dropout
106 | self.num_layers = num_layers
107 | self.max_turns = max_turns
108 | self.p_dim = p_dim
109 | self.device = device
110 |
111 | self.positional_encoder = PositionalEncoder(self.max_turns, self.p_dim, self.device)
112 | self.linear = nn.Linear(self.d_model+self.p_dim, self.d_model)
113 | self.layers = nn.ModuleList([EncoderLayer(self.d_model, self.d_ff, self.num_heads, self.dropout) for i in range(self.num_layers)])
114 | self.layer_norm = LayerNormalization(self.d_model)
115 |
116 | def forward(self, x, e_masks):
117 | x = self.positional_encoder(x, cal='concat') # (B, T, d_h)
118 | x = self.linear(x) # (B, T, d_h)
119 | for i in range(self.num_layers):
120 | x = self.layers[i](x, e_masks)
121 |
122 | return self.layer_norm(x)
123 |
--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | import torch
4 | import math
5 |
6 |
7 | class EncoderLayer(nn.Module):
8 | def __init__(self, d_model, d_ff, num_heads, dropout):
9 | super().__init__()
10 | self.d_model = d_model
11 | self.d_ff = d_ff
12 | self.num_heads = num_heads
13 | self.dropout = dropout
14 |
15 | self.layer_norm_1 = LayerNormalization(self.d_model)
16 | self.multihead_attention = MultiheadAttention(self.d_model, self.num_heads, self.dropout)
17 | self.drop_out_1 = nn.Dropout(self.dropout)
18 |
19 | self.layer_norm_2 = LayerNormalization(self.d_model)
20 | self.feed_forward = FeedFowardLayer(self.d_model, self.d_ff, self.dropout)
21 | self.drop_out_2 = nn.Dropout(self.dropout)
22 |
23 | def forward(self, x, e_mask):
24 | x_1 = self.layer_norm_1(x) # (B, L, d_model)
25 | x = x + self.drop_out_1(
26 | self.multihead_attention(x_1, x_1, x_1, mask=e_mask)
27 | ) # (B, L, d_model)
28 | x_2 = self.layer_norm_2(x) # (B, L, d_model)
29 | x = x + self.drop_out_2(self.feed_forward(x_2)) # (B, L, d_model)
30 |
31 | return x # (B, L, d_model)
32 |
33 |
34 | class MultiheadAttention(nn.Module):
35 | def __init__(self, d_model, num_heads, dropout):
36 | super().__init__()
37 | self.inf = 1e9
38 | self.d_model = d_model
39 | self.num_heads = num_heads
40 | self.d_k = d_model // num_heads
41 |
42 | # W^Q, W^K, W^V in the paper
43 | self.w_q = nn.Linear(d_model, d_model)
44 | self.w_k = nn.Linear(d_model, d_model)
45 | self.w_v = nn.Linear(d_model, d_model)
46 |
47 | self.dropout = nn.Dropout(dropout)
48 | self.attn_softmax = nn.Softmax(dim=-1)
49 |
50 | # Final output linear transformation
51 | self.w_0 = nn.Linear(d_model, d_model)
52 |
53 | def forward(self, q, k, v, mask=None):
54 | input_shape = q.shape
55 |
56 | # Linear calculation + split into num_heads
57 | q = self.w_q(q).view(input_shape[0], -1, self.num_heads, self.d_k) # (B, L, H, d_k)
58 | k = self.w_k(k).view(input_shape[0], -1, self.num_heads, self.d_k) # (B, L, H, d_k)
59 | v = self.w_v(v).view(input_shape[0], -1, self.num_heads, self.d_k) # (B, L, H, d_k)
60 |
61 |         # For convenience, convert all tensors to shape (B, H, L, d_k)
62 | q = q.transpose(1, 2)
63 | k = k.transpose(1, 2)
64 | v = v.transpose(1, 2)
65 |
66 | # Conduct self-attention
67 | attn_values = self.self_attention(q, k, v, mask=mask) # (B, H, L, d_k)
68 | concat_output = attn_values.transpose(1, 2)\
69 | .contiguous().view(input_shape[0], -1, self.d_model) # (B, L, d_model)
70 |
71 | return self.w_0(concat_output)
72 |
73 | def self_attention(self, q, k, v, mask=None):
74 | # Calculate attention scores with scaled dot-product attention
75 | attn_scores = torch.matmul(q, k.transpose(-2, -1)) # (B, H, L, L)
76 | attn_scores = attn_scores / math.sqrt(self.d_k)
77 |
78 | # If there is a mask, make masked spots -INF
79 | if mask is not None:
80 | mask = mask.unsqueeze(1) # (B, 1, L) => (B, 1, 1, L) or (B, L, L) => (B, 1, L, L)
81 | attn_scores = attn_scores.masked_fill_(mask == 0, -1 * self.inf)
82 |
83 |         # Softmax and multiply with V to calculate the attention values
84 | attn_distribs = self.attn_softmax(attn_scores)
85 |
86 | attn_distribs = self.dropout(attn_distribs)
87 | attn_values = torch.matmul(attn_distribs, v) # (B, H, L, d_k)
88 |
89 | return attn_values
90 |
91 |
92 | class FeedFowardLayer(nn.Module):
93 | def __init__(self, d_model, d_ff, dropout):
94 | super().__init__()
95 | self.d_model = d_model
96 | self.d_ff = d_ff
97 | self.dropout = dropout
98 |
99 | self.linear_1 = nn.Linear(self.d_model, self.d_ff, bias=True)
100 | self.relu = nn.ReLU()
101 | self.linear_2 = nn.Linear(self.d_ff, self.d_model, bias=True)
102 | self.dropout = nn.Dropout(self.dropout)
103 |
104 | def forward(self, x):
105 | x = self.relu(self.linear_1(x)) # (B, L, d_ff)
106 | x = self.dropout(x)
107 | x = self.linear_2(x) # (B, L, d_model)
108 |
109 | return x
110 |
111 |
112 | class LayerNormalization(nn.Module):
113 | def __init__(self, d_model, eps=1e-6):
114 | super().__init__()
115 | self.d_model = d_model
116 | self.eps = eps
117 | self.layer = nn.LayerNorm([self.d_model], elementwise_affine=True, eps=self.eps)
118 |
119 | def forward(self, x):
120 | x = self.layer(x)
121 |
122 | return x
123 |
124 |
125 | class PositionalEncoder(nn.Module):
126 | def __init__(self, max_len, p_dim, device):
127 | super().__init__()
128 | self.device = device
129 | self.max_len = max_len
130 | self.p_dim = p_dim
131 |
132 | # Make initial positional encoding matrix with 0
133 |         pe_matrix = torch.zeros(self.max_len, self.p_dim) # (L, p_dim)
134 |
135 | # Calculating position encoding values
136 | for pos in range(self.max_len):
137 | for i in range(self.p_dim):
138 | if i % 2 == 0:
139 | pe_matrix[pos, i] = math.sin(pos / (10000 ** (2 * i / self.p_dim)))
140 | elif i % 2 == 1:
141 | pe_matrix[pos, i] = math.cos(pos / (10000 ** (2 * i / self.p_dim)))
142 |
143 | pe_matrix = pe_matrix.unsqueeze(0) # (1, L, p_dim)
144 | self.positional_encoding = pe_matrix.to(self.device).requires_grad_(False)
145 |
146 | def forward(self, x, cal='add'):
147 | assert cal == 'add' or cal == 'concat', "Please specify the calculation method, either 'add' or 'concat'."
148 |
149 | if cal == 'add':
150 | x = x * math.sqrt(self.p_dim) # (B, L, d_model)
151 | x = x + self.positional_encoding # (B, L, d_model)
152 | elif cal == 'concat':
153 | x = torch.cat((x, self.positional_encoding.repeat(x.shape[0],1,1)), dim=-1) # (B, T, d_model+p_dim)
154 |
155 | return x
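
# Note: EntityBert's ContextEncoder calls PositionalEncoder with cal='concat',
# appending the p_dim positional features to each pooled utterance vector and
# projecting the result back to d_model with a Linear layer (see ContextEncoder.forward).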
156 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from torch.utils.data import DataLoader
3 | from torch.nn import functional as F
4 | from entity_bert import EntityBert
5 | from custom_dataset import CustomDataset
6 | from transformers import BertConfig, BertTokenizer, get_polynomial_decay_schedule_with_warmup
7 | from seqeval.metrics import accuracy_score, f1_score
8 | from itertools import chain
9 |
10 | import torch
11 | import random
12 | import os, sys
13 | import numpy as np
14 | import argparse
15 | import time
16 | import json
17 |
18 |
19 | def run(args):
20 | # Device setting
21 | if torch.cuda.is_available():
22 | args.device = torch.device(f'cuda:{args.gpu}')
23 | else:
24 | print("CUDA is unavailable. Starting with CPU.")
25 | args.device = torch.device('cpu')
26 |
27 |     print(f"The {args.turn_type}-turn setting is selected.")
28 | if args.turn_type == 'multi':
29 | print(f"Pooling policy is {args.pooling}.")
30 |
31 | # Load class dictionary
32 | print("Loading the class dictionary...")
33 | with open(f"{args.data_dir}/{args.processed_dir}/class_dict.json", 'r') as f:
34 | class_dict = json.load(f)
35 | args.num_classes = len(class_dict)
36 | idx2class = {v:k for k, v in class_dict.items()}
37 |
38 | # Adding arguments
39 |     bert_config = BertConfig.from_pretrained(args.bert_type)
40 | args.hidden_size = bert_config.hidden_size
41 | args.p_dim = args.hidden_size
42 | args.max_len = min(args.max_len, bert_config.max_position_embeddings)
43 |
44 | # Tokenizer
45 | print("Loading the tokenizer...")
46 | tokenizer = BertTokenizer.from_pretrained(args.bert_type)
47 | num_new_tokens = tokenizer.add_special_tokens(
48 | {
49 | 'additional_special_tokens': [args.sp1_token, args.sp2_token]
50 | }
51 | )
52 | vocab = tokenizer.get_vocab()
53 | args.vocab_size = len(vocab)
54 |
55 | args.cls_token = tokenizer.cls_token
56 | args.sep_token = tokenizer.sep_token
57 | args.pad_token = tokenizer.pad_token
58 |
59 | args.cls_id = vocab[args.cls_token]
60 | args.sep_id = vocab[args.sep_token]
61 | args.pad_id = vocab[args.pad_token]
62 | args.sp1_id = vocab[args.sp1_token]
63 | args.sp2_id = vocab[args.sp2_token]
64 | args.o_id = class_dict['O']
65 |
66 | # Load model & optimizer
67 | print("Loading the model and optimizer...")
68 | set_seed(args.seed)
69 | model = EntityBert(args).to(args.device)
70 | model.init_model()
71 | optim = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
72 |
73 | if not os.path.exists(args.ckpt_dir):
74 | os.mkdir(args.ckpt_dir)
75 |
76 | # Loading datasets & dataloaders
77 | print(f"Loading {args.turn_type}-turn data...")
78 | train_set = CustomDataset(args, vocab, class_dict, prefix='train')
79 | valid_set = CustomDataset(args, vocab, class_dict, prefix='valid')
80 | test_set = CustomDataset(args, vocab, class_dict, prefix='test')
81 | train_loader = DataLoader(train_set, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
82 | valid_loader = DataLoader(valid_set, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
83 | test_loader = DataLoader(test_set, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
84 |
85 | # Setting scheduler
86 | num_batches = len(train_loader)
87 | args.total_train_steps = args.num_epochs * num_batches
88 | args.warmup_steps = int(args.warmup_ratio * args.total_train_steps)
89 | sched = get_polynomial_decay_schedule_with_warmup(
90 | optimizer=optim,
91 | num_warmup_steps=args.warmup_steps,
92 | num_training_steps=args.total_train_steps,
93 | power=2.0,
94 | )
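    # e.g. with num_epochs=10 and 1,000 batches per epoch, total_train_steps is
    # 10,000 and (at warmup_ratio=0.1) the learning rate warms up over the first
    # 1,000 steps, then decays polynomially (power=2.0) for the remainder.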
95 |
96 | # Training
97 | set_seed(args.seed)
98 | best_ckpt_path = train(args, model, optim, sched, train_loader, valid_loader, idx2class)
99 |
100 | # Testing
101 | print("Testing the model...")
102 | _, test_acc, test_f1 = evaluate(args, model, test_loader, idx2class, ckpt_path=best_ckpt_path)
103 |
104 | print("")
105 | print(f"Test accuracy: {test_acc} || Test F1 score: {test_f1}")
106 | print("GOOD BYE.")
107 |
108 |
109 | def train(args, model, optim, sched, train_loader, valid_loader, idx2class):
110 | print("Training starts.")
111 | best_f1 = 0.0
112 | patience, threshold = 0, 1e-4
113 | best_ckpt_path = None
114 |
115 | for epoch in range(1, args.num_epochs+1):
116 | model.train()
117 |
118 | print("#"*50 + f" Epoch: {epoch} " + "#"*50)
119 | train_losses, train_ys, train_outputs, train_lens = [], [], [], []
120 | for i, batch in enumerate(tqdm(train_loader)):
121 | batch_x, batch_y, batch_lens, batch_turns = batch
122 |
123 | if args.turn_type == 'single':
124 | batch_x, batch_y = batch_x.to(args.device), batch_y.to(args.device)
125 | log_likelihood, outputs = model(batch_x, batch_y, args.pad_id) # (), (B, L)
126 | elif args.turn_type == 'multi':
127 | batch_x, batch_y, batch_turns = \
128 | batch_x.to(args.device), batch_y.to(args.device), batch_turns.to(args.device)
129 | log_likelihood, outputs = model(batch_x, batch_y, args.pad_id, turns=batch_turns) # (), (B, L)
130 |
131 | loss = -1 * log_likelihood
132 |
133 | model.zero_grad()
134 | optim.zero_grad()
135 |
136 | loss.backward()
137 | optim.step()
138 | sched.step()
139 |
140 | train_losses.append(loss.detach())
141 | train_ys.append(batch_y.detach())
142 | train_outputs.append(outputs)
143 | train_lens.append(batch_lens)
144 |
145 | train_losses = [loss.item() for loss in train_losses]
146 | train_loss = np.mean(train_losses)
147 | train_preds, train_trues = [], []
148 | for i in range(len(train_ys)):
149 | pred_batch, true_batch, batch_lens = train_outputs[i], train_ys[i], train_lens[i]
150 |
151 | batch_lens = batch_lens.tolist() # (B)
152 | true_batch = [batch[:batch_lens[b]] for b, batch in enumerate(true_batch.tolist())]
153 |
154 | assert len(pred_batch) == len(true_batch)
155 | train_preds += pred_batch
156 | train_trues += true_batch
157 |
158 | assert len(train_preds) == len(train_trues)
159 | for i in range(len(train_preds)):
160 | train_pred, train_true = train_preds[i], train_trues[i]
161 | train_pred = [idx2class[class_id] for class_id in train_pred]
162 | train_true = [idx2class[class_id] for class_id in train_true]
163 |
164 | train_preds[i] = train_pred
165 | train_trues[i] = train_true
166 |
167 | train_acc = accuracy_score(train_trues, train_preds)
168 | train_f1 = f1_score(train_trues, train_preds)
169 |
170 | print(f"Train loss: {train_loss} || Train accuracy: {train_acc} || Train F1 score: {train_f1}")
171 |
172 | print("Validation processing...")
173 | valid_loss, valid_acc, valid_f1 = evaluate(args, model, valid_loader, idx2class)
174 |
175 | if valid_f1 >= best_f1 + threshold:
176 | best_f1 = valid_f1
177 | patience = 0
178 | best_ckpt_path = f"{args.ckpt_dir}/ckpt_epoch={epoch}_train_f1={round(train_f1, 4)}_valid_f1={round(valid_f1, 4)}"
179 | torch.save(model.state_dict(), best_ckpt_path)
180 | print(f"***** Current best checkpoint is saved. *****")
181 | else:
182 | patience += 1
183 |             print(f"The F1 score did not improve by at least {threshold}. Patience: {patience}")
184 |
185 |         print(f"Best validation F1 score: {best_f1}")
186 | print(f"Validation loss: {valid_loss} || Validation accuracy: {valid_acc} || Current validation F1 score: {valid_f1}")
187 |
188 | if patience == 3:
189 | print("Run out of patience. Abort!")
190 | break
191 |
192 | print("Training finished!")
193 |
194 | return best_ckpt_path
195 |
196 |
197 | def evaluate(args, model, eval_loader, idx2class, ckpt_path=None):
198 | if ckpt_path is not None:
199 | model.load_state_dict(torch.load(ckpt_path))
200 |
201 | model.eval()
202 |
203 | eval_losses, eval_ys, eval_outputs, eval_lens = [], [], [], []
204 | with torch.no_grad():
205 | for i, batch in enumerate(tqdm(eval_loader)):
206 | batch_x, batch_y, batch_lens, batch_turns = batch
207 |
208 | if args.turn_type == 'single':
209 | batch_x, batch_y = batch_x.to(args.device), batch_y.to(args.device)
210 | log_likelihood, outputs = model(batch_x, batch_y, args.pad_id) # (), (B, L)
211 | elif args.turn_type == 'multi':
212 | batch_x, batch_y, batch_turns = \
213 | batch_x.to(args.device), batch_y.to(args.device), batch_turns.to(args.device)
214 | log_likelihood, outputs = model(batch_x, batch_y, args.pad_id, turns=batch_turns) # (), (B, L)
215 |
216 | loss = -1 * log_likelihood
217 |
218 | eval_losses.append(loss.detach())
219 | eval_ys.append(batch_y.detach())
220 | eval_outputs.append(outputs)
221 | eval_lens.append(batch_lens)
222 |
223 | eval_losses = [loss.item() for loss in eval_losses]
224 | eval_loss = np.mean(eval_losses)
225 | eval_preds, eval_trues = [], []
226 | for i in range(len(eval_ys)):
227 | pred_batch, true_batch, batch_lens = eval_outputs[i], eval_ys[i], eval_lens[i]
228 |
229 | batch_lens = batch_lens.tolist() # (B)
230 | true_batch = [batch[:batch_lens[b]] for b, batch in enumerate(true_batch.tolist())]
231 |
232 | assert len(pred_batch) == len(true_batch)
233 | eval_preds += pred_batch
234 | eval_trues += true_batch
235 |
236 | assert len(eval_preds) == len(eval_trues)
237 | for i in range(len(eval_preds)):
238 | eval_pred, eval_true = eval_preds[i], eval_trues[i]
239 | eval_pred = [idx2class[class_id] for class_id in eval_pred]
240 | eval_true = [idx2class[class_id] for class_id in eval_true]
241 |
242 | eval_preds[i] = eval_pred
243 | eval_trues[i] = eval_true
244 |
245 | eval_acc = accuracy_score(eval_trues, eval_preds)
246 | eval_f1 = f1_score(eval_trues, eval_preds)
247 |
248 | return eval_loss, eval_acc, eval_f1
249 |
250 |
251 | def set_seed(seed):
252 | np.random.seed(seed)
253 | torch.manual_seed(seed)
254 | torch.cuda.manual_seed_all(seed)
255 | random.seed(seed)
256 |
257 |
258 | if __name__=='__main__':
259 | parser = argparse.ArgumentParser()
260 | parser.add_argument('--seed', default=0, type=int, help="The random seed.")
261 | parser.add_argument('--turn_type', required=True, type=str, help="The turn type setting. (single or multi)")
262 | parser.add_argument('--bert_type', default="bert-base-uncased", type=str, help="The BERT type to load.")
263 | parser.add_argument('--pooling', default="cls", type=str, help="The pooling policy when using the multi-turn setting.")
264 | parser.add_argument('--data_dir', default="data", type=str, help="The parent data directory.")
265 | parser.add_argument('--processed_dir', default="processed", type=str, help="The directory which contains the parsed data pickle files.")
266 | parser.add_argument('--ckpt_dir', default="saved_models", type=str, help="The path for saved checkpoints.")
267 | parser.add_argument('--gpu', default=0, type=int, help="The index of a GPU to use.")
268 | parser.add_argument('--sp1_token', default="[USR]", type=str, help="The speaker1(USER) token.")
269 | parser.add_argument('--sp2_token', default="[SYS]", type=str, help="The speaker2(SYSTEM) token.")
270 | parser.add_argument('--max_len', default=128, type=int, help="The max length of each utterance.")
271 | parser.add_argument('--max_turns', default=5, type=int, help="The maximum number of the dialogue history to be attended in the multi-turn setting.")
272 | parser.add_argument('--dropout', default=0.1, type=float, help="The dropout rate.")
273 | parser.add_argument('--context_d_ff', default=2048, type=int, help="The size of intermediate hidden states in the feed-forward layer.")
274 | parser.add_argument('--context_num_heads', default=8, type=int, help="The number of heads for the multi-head attention.")
275 | parser.add_argument('--context_dropout', default=0.1, type=float, help="The dropout rate for the context encoder.")
276 | parser.add_argument('--context_num_layers', default=2, type=int, help="The number of layers in the context encoder.")
277 | parser.add_argument('--learning_rate', default=5e-5, type=float, help="The initial learning rate.")
278 | parser.add_argument('--warmup_ratio', default=0.1, type=float, help="The ratio of warmup steps to the total training steps.")
279 | parser.add_argument('--batch_size', default=8, type=int, help="The batch size.")
280 | parser.add_argument('--num_workers', default=4, type=int, help="The number of sub-processes for data loading.")
281 | parser.add_argument('--num_epochs', default=10, type=int, help="The number of training epochs.")
282 |
283 | args = parser.parse_args()
284 |
285 |     assert args.turn_type in ['single', 'multi'], "Please specify a correct turn type, either 'single' or 'multi'."
286 | assert args.bert_type in [
287 | "bert-base-uncased",
288 | "bert-base-cased",
289 | "bert-large-uncased",
290 | "bert-large-cased"
291 | ]
292 | assert args.pooling in ["cls", "mean", "max"]
293 |
294 | run(args)
295 |
--------------------------------------------------------------------------------