├── data
│   ├── raw_data
│   │   └── .keep
│   ├── prediction_result
│   │   └── .keep
│   ├── user_data
│   │   └── user_data.txt
│   ├── __init__.py
│   └── code
│       ├── __init__.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── JointBertModel.py
│       │   ├── InteractModel_1.py
│       │   ├── InteractModel_3.py
│       │   └── torchcrf.py
│       ├── predict
│       │   ├── __init__.py
│       │   ├── post_process.py
│       │   ├── test_dataset.py
│       │   ├── test_utils.py
│       │   ├── nest.txt
│       │   ├── integration.py
│       │   ├── run_JointBert.py
│       │   ├── run_interact1.py
│       │   └── run_interact3.py
│       ├── preprocess
│       │   ├── __init__.py
│       │   ├── generate_intent.py
│       │   ├── extract_intent_sample.py
│       │   ├── process_other.py
│       │   ├── slot_sorted.py
│       │   ├── split_train_dev.py
│       │   ├── extend_tv_sample.py
│       │   ├── analysis.py
│       │   ├── rectify.py
│       │   ├── extend_audio_sample.py
│       │   └── process_rawdata.py
│       └── scripts
│           ├── __init__.py
│           ├── config_jointBert.py
│           ├── config_Interact3.py
│           ├── config_Interact1.py
│           ├── build_vocab.py
│           ├── dataset.py
│           ├── train_interact1.py
│           ├── train_interact3.py
│           ├── train_jointBert.py
│           └── utils.py
├── image
│   ├── ccir-image.txt
│   ├── readme_images
│   │   ├── image-20210929212611984.png
│   │   ├── image-20210929212740433.png
│   │   ├── image-20210929212830068.png
│   │   └── image-20210929213002909.png
│   ├── run_infer.sh
│   ├── run.sh
│   └── README.md
├── 【智能人机交互自然语言理解】SCU-JJkinging-说明论文.docx
├── 【智能人机交互自然语言理解】SCU-JJkinging-分享PPT.pptx
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── .gitignore
│   ├── modules.xml
│   ├── deployment.xml
│   └── CCIR-Cup.iml
└── README.md
/data/raw_data/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/prediction_result/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/image/ccir-image.txt:
--------------------------------------------------------------------------------
1 | Link: https://pan.baidu.com/s/1oUpFVgHvqcvsYYD4w6yHnQ
2 | Extraction code: 78pw
--------------------------------------------------------------------------------
/data/user_data/user_data.txt:
--------------------------------------------------------------------------------
1 | Link: https://pan.baidu.com/s/1uwy8sO3KKEp-OdrI4C1H5Q
2 | Extraction code: 0i1x
--------------------------------------------------------------------------------
/【智能人机交互自然语言理解】SCU-JJkinging-说明论文.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/【智能人机交互自然语言理解】SCU-JJkinging-说明论文.docx
--------------------------------------------------------------------------------
/【智能人机交互自然语言理解】SCU-JJkinging-分享PPT.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/【智能人机交互自然语言理解】SCU-JJkinging-分享PPT.pptx
--------------------------------------------------------------------------------
/image/readme_images/image-20210929212611984.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/image/readme_images/image-20210929212611984.png
--------------------------------------------------------------------------------
/image/readme_images/image-20210929212740433.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/image/readme_images/image-20210929212740433.png
--------------------------------------------------------------------------------
/image/readme_images/image-20210929212830068.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/image/readme_images/image-20210929212830068.png
--------------------------------------------------------------------------------
/image/readme_images/image-20210929213002909.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/image/readme_images/image-20210929213002909.png
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/27 14:50
4 | # @Author : JJkinging
5 | # @File : __init__.py
6 |
--------------------------------------------------------------------------------
/data/code/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/27 14:50
4 | # @Author : JJkinging
5 | # @File : __init__.py
6 |
--------------------------------------------------------------------------------
/data/code/model/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/27 14:45
4 | # @Author : JJkinging
5 | # @File : __init__.py.py
6 |
--------------------------------------------------------------------------------
/data/code/predict/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/9 16:00
4 | # @Author : JJkinging
5 | # @File : __init__.py.py
6 |
--------------------------------------------------------------------------------
/data/code/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/30 11:58
4 | # @Author : JJkinging
5 | # @File : __init__.py.py
6 |
--------------------------------------------------------------------------------
/data/code/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/27 14:37
4 | # @Author : JJkinging
5 | # @File : __init__.py.py
6 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../:\python_project\CCIR-Cup\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/image/run_infer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd /ccir/data/code/predict
4 | python run_trained_JointBert.py
5 | python run_trained_interact1.py
6 | python run_trained_interact3.py
7 | python integration.py
8 | echo 'All models have finished inference; result.json has been saved in the prediction_result folder'
9 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/image/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo 'start training the first model--JointBert'
4 | cd /ccir/data/code/scripts
5 | python train_jointBert.py
6 |
7 | echo 'start training the second model--InteractModel1'
8 | python train_interact1.py
9 |
10 | echo 'start training the third model--InteractModel3'
11 | python train_interact3.py
12 |
13 | cd ../predict
14 | python run_JointBert.py
15 | python run_interact1.py
16 | python run_interact3.py
17 | python integration.py
18 | echo 'All models have finished inference; result.json has been saved in the prediction_result folder'
19 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/data/code/preprocess/generate_intent.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/16 12:28
4 | # @Author : JJkinging
5 | # @File : generate_intent.py
6 | import json
7 |
8 | with open('../small_sample/new_B/train_final_clear_del.json', 'r', encoding='utf-8') as fp:
9 | raw_data = json.load(fp)
10 |
11 | intent_write = open('../small_sample/new_B/train_intent_label.txt', 'w+', encoding='utf-8')
12 | for filename, single_data in raw_data.items():
13 | intent = single_data['intent']
14 | intent_write.write(intent+'\n')
15 |
16 | intent_write.close()
17 |
--------------------------------------------------------------------------------
/.idea/CCIR-Cup.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/data/code/preprocess/extract_intent_sample.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/15 21:07
4 | # @Author : JJkinging
5 | # @File : extract_small_sample.py
6 | import json
7 |
8 | with open('../raw_data/train_8.json', 'r', encoding='utf-8') as fp:
9 | raw_data = json.load(fp)
10 |
11 | res = {}
12 | for filename, single_data in raw_data.items():
13 | intent = single_data['intent']
14 | if intent == 'TVProgram-Play':
15 | res[filename] = single_data
16 |
17 | with open('../extend_data/same_intent/ori_data/TVProgram-Play.json', 'w', encoding='utf-8') as fp:
18 | json.dump(res, fp, ensure_ascii=False)
19 |
--------------------------------------------------------------------------------
/data/code/preprocess/process_other.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/10 22:22
4 | # @Author : JJkinging
5 | # @File : process_other.py
6 | import json
7 |
8 | with open('../../dataset/final_data/other_2.txt', 'r', encoding='utf-8') as fp:
9 | data = fp.readlines()
10 | data = [item.split(' ')[0].strip('\n') for item in data]
11 |
12 |
13 | res = {}
14 | for i, sen in enumerate(data):
15 | tem_dict = {}
16 | tem_dict['text'] = sen
17 | tem_dict['intent'] = 'Other'
18 | tem_dict['slots'] = {}
19 |
20 | lens = len(str(i))
21 | o_nums = 5 - lens
22 | filename = 'NLU' + '0' * o_nums + str(i)
23 | res[filename] = tem_dict
24 |
25 |
26 | with open('../../dataset/small_sample/data/other_2.json', 'w', encoding='utf-8') as fp:
27 | json.dump(res, fp, ensure_ascii=False)
28 |
--------------------------------------------------------------------------------
/data/code/preprocess/slot_sorted.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/12 18:52
4 | # @Author : JJkinging
5 | # @File : slot_sorted.py
6 |
7 | '''Sort each sample's slots in train.json by the length of the slot value, longest first'''
8 | import json
9 |
10 | with open('../small_sample/new_B/train_final_clear.json', 'r', encoding='utf-8') as fp:
11 | raw_data = json.load(fp)
12 | res = {}
13 | for filename, single_data in raw_data.items():
14 | tmp_dict = {}
15 | tmp_dict['text'] = single_data['text']
16 | tmp_dict['intent'] = single_data['intent']
17 | slots = single_data['slots']
18 | dic = {}
19 | tmp_tuple = sorted(slots.items(), key=lambda x: len(x[1]), reverse=True)
20 | for tuple_data in tmp_tuple:
21 | dic[tuple_data[0]] = tuple_data[1]
22 | tmp_dict['slots'] = dic
23 | res[filename] = tmp_dict
24 |
25 | with open('../small_sample/new_B/train_final_clear_sorted.json', 'w', encoding='utf-8') as fp:
26 | json.dump(res, fp, ensure_ascii=False)
27 |
--------------------------------------------------------------------------------
/data/code/preprocess/split_train_dev.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 16:10
4 | # @Author : JJkinging
5 | # @File : split_train_dev.py
6 | import json
7 |
8 |
9 | def split_train_dev(ratio=0.8):
10 | with open('../raw_data/train_slot_sorted.json', 'r', encoding='utf-8') as fp:
11 | raw_data = json.load(fp)
12 |
13 | length = len(raw_data)
14 |
15 | train_data = {}
16 | dev_data = {}
17 |
18 | count = 0
19 | for key, value in raw_data.items():
20 | if count < length*ratio:
21 | train_data[key] = value
22 | else:
23 | dev_data[key] = value
24 | count += 1
25 |
26 | # write the training set
27 | with open('../raw_data/train_8.json', 'w', encoding='utf-8') as fp:
28 | json.dump(train_data, fp, ensure_ascii=False)
29 | # write the dev set
30 | with open('../raw_data/dev_2.json', 'w', encoding='utf-8') as fp:
31 | json.dump(dev_data, fp, ensure_ascii=False)
32 |
33 |
34 | if __name__ == "__main__":
35 | split_train_dev(0.8)
36 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/image/README.md:
--------------------------------------------------------------------------------
1 | Since local training and training on the online evaluation environment can differ slightly, two ways to reproduce the results are provided:
2 | * Online training + online inference: run.sh
3 | * Local training + online inference: run_infer.sh
4 |
5 | Preparation:
6 | First download user_data.zip from the Baidu Netdisk link in ccir/data/user_data/user_data.txt, unpack it and replace the user_data directory with it;
7 | then download the image file ccir-image.tar from the Baidu Netdisk link in ccir/image/ccir-image.txt and put it in the ccir/image directory.
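For reference, the preparation can be scripted roughly as follows (a sketch only; it assumes both downloads sit in the ccir/ project root and that user_data.zip unpacks to a top-level user_data/ directory):
```
cd ccir
# replace data/user_data with the contents of the downloaded archive
unzip -o user_data.zip -d data/
# place the downloaded image file in the ccir/image directory
mv ccir-image.tar image/
```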
8 |
9 | Go into the ccir/image directory and load the ccir-image.tar image file:
10 | ```
11 | sudo docker load -i ccir-image.tar
12 | ```
13 | #### 1. Online training + online inference
14 |
15 | My project is named ccir.
16 |
17 | In this mode the models are retrained online and then used for inference. Three models have to be trained online, which takes about 7 hours in total on a T4 GPU.
18 | Run the image with the project mounted:
19 |
20 | ```
21 | nvidia-docker run -v /home/seatrend/jinxiang/ccir/:/ccir ccir-image sh /ccir/image/run.sh
22 | ```
23 |
24 | where:
25 |
26 | ```
27 | /home/seatrend/jinxiang/ccir/
28 | ```
29 |
30 | is the absolute path of my local ccir project. When reproducing online, replace it with the absolute path of ccir on the evaluation machine; it is mounted into the image at /ccir, and the image name is ccir-image.
31 |
32 | #### 2. Local training + online inference
33 |
34 | To reproduce directly from the locally trained models, run:
35 |
36 | ```
37 | nvidia-docker run -v /home/seatrend/jinxiang/ccir/:/ccir ccir-image sh /ccir/image/run_infer.sh
38 | ```
39 | where:
40 |
41 | ```
42 | /home/seatrend/jinxiang/ccir/
43 | ```
44 |
45 | is the absolute path of my local ccir project. When reproducing online, replace it with the absolute path of ccir on the evaluation machine; it is mounted into the image at /ccir, and the image name is ccir-image.
46 |
47 | The three locally trained models are saved in:
48 |
49 | ```
50 | ccir/data/user_data/output_model/JointBert/trained_model
51 | ```
52 |
53 | ```
54 | ccir/data/user_data/output_model/InteractModel_1/trained_model
55 | ```
56 |
57 | ```
58 | ccir/data/user_data/output_model/InteractModel_3/trained_model
59 | ```
60 |
61 | **Note:** run.sh and run_infer.sh may pick up Windows line endings if they are opened under Windows. If running them fails, open the file in vim, type :set ff=unix in command mode to convert it to Unix format, then save, quit and run again.
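A command-line alternative to the vim conversion (a sketch assuming GNU sed is available, or dos2unix if installed):
```
# strip Windows carriage returns from the two entry scripts (same effect as :set ff=unix)
sed -i 's/\r$//' /ccir/image/run.sh /ccir/image/run_infer.sh
# or: dos2unix /ccir/image/run.sh /ccir/image/run_infer.sh
```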
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/data/code/predict/post_process.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/14 19:58
4 | # @Author : JJkinging
5 | # @File : post_process.py
6 | import json
7 |
8 |
9 | def process(source_path, target_path):
10 | '''
11 | Purpose: correct the model's predictions by deleting slot values that cannot belong to the predicted intent. For example,
12 | if a test sample is predicted with intent = FilmTele-Play and its predicted slots contain the label "notes",
13 | this contradicts the intent-to-slot statistics gathered from the training data (the FilmTele-Play intent never carries a "notes" slot),
14 | so this function removes the "notes" slot and its value.
15 | :param source_path:
16 | :param target_path:
17 | :return:
18 | '''
19 | with open(source_path, 'r', encoding='utf-8') as fp:
20 | ori_data = json.load(fp)
21 |
22 | with open('../../user_data/common_data/intent_slot_mapping.json', 'r', encoding='utf-8') as fp:
23 | intent_slot_mapping = json.load(fp)
24 |
25 | mapping_keys = list(intent_slot_mapping.keys())
26 | for filename, single_data in ori_data.items():
27 | intent = single_data['intent']
28 | slots = single_data['slots']
29 | slot_keys = list(single_data['slots'].keys())
30 | for item in mapping_keys:
31 | if intent == item:
32 | all_tag = intent_slot_mapping[item] # ["name", "tag", "artist", "region", "play_setting", "age"]
33 | for key in slot_keys: # check whether each predicted slot is valid for this intent
34 | if key not in all_tag:
35 | ori_data[filename]['slots'].pop(key) # if not, delete that slot
36 |
37 | for key, values in slots.items():
38 | if isinstance(values, list):
39 | tmp = list(set(values))
40 | if len(tmp) == 1:
41 | tmp = tmp[0]
42 | ori_data[filename]['slots'][key] = tmp
43 | with open(target_path, 'w', encoding='utf-8') as fp:
44 | json.dump(ori_data, fp, ensure_ascii=False)
45 |
--------------------------------------------------------------------------------
/data/code/predict/test_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 12:39
4 | # @Author : JJkinging
5 | # @File : utils.py
6 | import json
7 | from torch.utils.data import Dataset
8 |
9 |
10 | class CCFDataset(Dataset):
11 | def __init__(self, filename, vocab, intent_dict, slot_none_dict, slot_dict, max_length=512):
12 | '''
13 | :param filename: data file to read, e.g. train_seq_in.txt
14 | :param slot_none_dict: slot_none-to-id dictionary
15 | :param slot_dict: slot_label-to-id dictionary
16 | :param vocab: vocabulary, e.g. BERT's vocab.txt
17 | :param intent_dict: intent-to-id dictionary
18 | :param max_length: maximum sentence length
19 | '''
20 | self.filename = filename
21 | self.vocab = vocab
22 | self.intent_dict = intent_dict
23 | self.slot_none_dict = slot_none_dict
24 | self.slot_dict = slot_dict
25 | self.max_length = max_length
26 |
27 | self.result = []
28 |
29 | # read the data
30 | with open(self.filename, 'r', encoding='utf-8') as fp:
31 | sen_data = fp.readlines()
32 | sen_data = [item.strip('\n') for item in sen_data] # strip the trailing newline ('\n')
33 |
34 | for utterance in sen_data:
35 | utterance = utterance.split(' ') # turn the string into a token list
36 | # enforce the maximum length (leave room for [CLS] and [SEP])
37 | if len(utterance) > self.max_length-2:
38 | utterance = utterance[:self.max_length-2]
39 |
40 | # input_ids
41 | utterance = ['[CLS]'] + utterance + ['[SEP]']
42 | input_ids = [int(self.vocab[i]) for i in utterance]
43 |
44 | length = len(input_ids)
45 |
46 | # input_mask
47 | input_mask = [1] * len(input_ids)
48 |
49 | self.result.append((input_ids, input_mask))
50 |
51 | def __len__(self):
52 | return len(self.result)
53 |
54 | def __getitem__(self, index):
55 | input_ids, input_mask = self.result[index]
56 |
57 | return input_ids, input_mask
58 |
--------------------------------------------------------------------------------
/data/code/predict/test_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/9 15:57
4 | # @Author : JJkinging
5 | # @File : utils.py
6 | import torch
7 |
8 | def load_vocab(vocab_file):
9 | '''construct word2id'''
10 | vocab = {}
11 | index = 0
12 | with open(vocab_file, 'r', encoding='utf-8') as fp:
13 | while True:
14 | token = fp.readline()
15 | if not token:
16 | break
17 | token = token.strip() # strip whitespace
18 | vocab[token] = index
19 | index += 1
20 | return vocab
21 |
22 |
23 | def load_reverse_vocab(vocab_file):
24 | '''construct id2word'''
25 | vocab = {}
26 | index = 0
27 | with open(vocab_file, 'r', encoding='utf-8') as fp:
28 | while True:
29 | token = fp.readline()
30 | if not token:
31 | break
32 | token = token.strip() # strip whitespace
33 | vocab[index] = token
34 | index += 1
35 | return vocab
36 |
37 |
38 | def collate_to_max_length(batch):
39 | # input_ids, input_mask
40 | batch_size = len(batch)
41 | input_ids_list = []
42 | input_mask_list = []
43 | for single_data in batch:
44 | input_ids_list.append(single_data[0])
45 | input_mask_list.append(single_data[1])
46 |
47 | max_length = max([len(item) for item in input_ids_list])
48 |
49 | output = [torch.full([batch_size, max_length],
50 | fill_value=0,
51 | dtype=torch.long),
52 | torch.full([batch_size, max_length],
53 | fill_value=0,
54 | dtype=torch.long)
55 | ]
56 |
57 | for i in range(batch_size):
58 | output[0][i][0:len(input_ids_list[i])] = torch.LongTensor(input_ids_list[i])
59 | output[1][i][0:len(input_mask_list[i])] = torch.LongTensor(input_mask_list[i])
60 |
61 | return output # (input_ids, input_mask)
62 |
--------------------------------------------------------------------------------
/data/code/scripts/config_jointBert.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 15:13
4 | # @Author : JJkinging
5 | # @File : config.py
6 | class Config(object):
7 | '''Configuration class'''
8 |
9 | def __init__(self):
10 | self.train_file = '../../user_data/train_data/train_seq_in.txt'
11 | self.dev_file = '../../user_data/dev_data/dev_seq_in.txt'
12 | self.test_file = '../../user_data/test_data/test_seq_in.txt'
13 | self.intent_label_file = '../../user_data/common_data/intent_label.txt'
14 | self.vocab_file = '../../user_data/pretrained_model/ernie/vocab.txt'
15 | self.train_intent_file = '../../user_data/train_data/train_intent_label.txt'
16 | self.dev_intent_file = '../../user_data/dev_data/dev_intent_label.txt'
17 | self.max_length = 512
18 | self.batch_size = 16
19 | self.test_batch_size = 32
20 | self.bert_model_path = '../../user_data/pretrained_model/ernie'
21 | self.checkpoint = None # '../../user_data/output_model/JointBert/model_14.pth.tar'
22 | self.use_gpu = True
23 | self.cuda = "cuda:0"
24 | self.attention_dropout = 0.1
25 | self.bert_hidden_size = 768
26 | self.embedding_dim = 300
27 | self.lr = 5e-5 # 5e-5
28 | self.crf_lr = 5e-2 # 5e-2
29 | self.weight_decay = 0.0
30 | self.epochs = 60
31 | self.max_grad_norm = 4
32 | self.patience = 60
33 | self.target_dir = '../../user_data/output_model/JointBert'
34 | self.slot_none_vocab = '../../user_data/common_data/slot_none_vocab.txt'
35 | self.slot_label = '../../user_data/common_data/slot_label.txt'
36 | self.train_slot_filename = '../../user_data/train_data/train_seq_out.txt'
37 | self.dev_slot_filename = '../../user_data/dev_data/dev_seq_out.txt'
38 | self.train_slot_none_filename = '../../user_data/train_data/train_slot_none.txt'
39 | self.dev_slot_none_filename = '../../user_data/dev_data/dev_slot_none.txt'
40 |
41 | def update(self, **kwargs):
42 | for k, v in kwargs.items():
43 | setattr(self, k, v)
44 |
45 | def __str__(self):
46 | return '\n'.join(['%s:%s' % item for item in self.__dict__.items()])
47 |
48 |
49 | if __name__ == '__main__':
50 | con = Config()
51 | con.update(gpu=8)
52 | print(con.gpu)
53 | print(con)
54 |
--------------------------------------------------------------------------------
/data/code/scripts/config_Interact3.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 15:13
4 | # @Author : JJkinging
5 | # @File : config.py
6 | class Config(object):
7 | '''Configuration class'''
8 |
9 | def __init__(self):
10 | self.train_file = '../../user_data/train_data/train_seq_in.txt'
11 | self.dev_file = '../../user_data/dev_data/dev_seq_in.txt'
12 | self.test_file = '../../user_data/test_data/test_seq_in.txt'
13 | self.intent_label_file = '../../user_data/common_data/intent_label.txt'
14 | self.vocab_file = '../../user_data/pretrained_model/ernie/vocab.txt'
15 | self.train_intent_file = '../../user_data/train_data/train_intent_label.txt'
16 | self.dev_intent_file = '../../user_data/dev_data/dev_intent_label.txt'
17 | self.max_length = 512
18 | self.batch_size = 16
19 | self.test_batch_size = 32
20 | self.bert_model_path = '../../user_data/pretrained_model/ernie'
21 | self.checkpoint = None # '../../user_data/output_model/InteractModel_3/model_27.pth.tar'
22 | self.use_gpu = True
23 | self.cuda = "cuda:0"
24 | self.attention_dropout = 0.1
25 | self.bert_hidden_size = 768
26 | self.embedding_dim = 300
27 | self.lr = 5e-5 # 5e-5
28 | self.crf_lr = 5e-2 # 5e-2
29 | self.weight_decay = 0.0
30 | self.epochs = 60
31 | self.max_grad_norm = 4
32 | self.patience = 60
33 | self.target_dir = '../../user_data/output_model/InteractModel_3'
34 | self.slot_none_vocab = '../../user_data/common_data/slot_none_vocab.txt'
35 | self.slot_label = '../../user_data/common_data/slot_label.txt'
36 | self.train_slot_filename = '../../user_data/train_data/train_seq_out.txt'
37 | self.dev_slot_filename = '../../user_data/dev_data/dev_seq_out.txt'
38 | self.train_slot_none_filename = '../../user_data/train_data/train_slot_none.txt'
39 | self.dev_slot_none_filename = '../../user_data/dev_data/dev_slot_none.txt'
40 |
41 | def update(self, **kwargs):
42 | for k, v in kwargs.items():
43 | setattr(self, k, v)
44 |
45 | def __str__(self):
46 | return '\n'.join(['%s:%s' % item for item in self.__dict__.items()])
47 |
48 |
49 | if __name__ == '__main__':
50 | con = Config()
51 | con.update(gpu=8)
52 | print(con.gpu)
53 | print(con)
54 |
--------------------------------------------------------------------------------
/data/code/scripts/config_Interact1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 15:13
4 | # @Author : JJkinging
5 | # @File : config.py
6 | class Config(object):
7 | '''Configuration class'''
8 |
9 | def __init__(self):
10 | self.train_file = '../../user_data/train_data/train_seq_in.txt'
11 | self.dev_file = '../../user_data/dev_data/dev_seq_in.txt'
12 | self.test_file = '../../user_data/test_data/test_seq_in.txt'
13 | self.intent_label_file = '../../user_data/common_data/intent_label.txt'
14 | self.vocab_file = '../../user_data/pretrained_model/ernie/vocab.txt'
15 | self.train_intent_file = '../../user_data/train_data/train_intent_label.txt'
16 | self.dev_intent_file = '../../user_data/dev_data/dev_intent_label.txt'
17 | self.max_length = 512
18 | self.batch_size = 16
19 | self.test_batch_size = 32
20 | self.bert_model_path = '../../user_data/pretrained_model/ernie'
21 | self.checkpoint = None # '../../user_data/output_model/InteractModel_1/trained_model/model_27.pth.tar'
22 | self.use_gpu = True
23 | self.cuda = "cuda:0"
24 | self.attention_dropout = 0.1
25 | self.bert_hidden_size = 768
26 | self.embedding_dim = 300
27 | self.lr = 5e-5 # 5e-5
28 | self.crf_lr = 5e-2 # 5e-2
29 | self.weight_decay = 0.0
30 | self.epochs = 60
31 | self.max_grad_norm = 4
32 | self.patience = 60
33 | self.target_dir = '../../user_data/output_model/InteractModel_1'
34 | self.slot_none_vocab = '../../user_data/common_data/slot_none_vocab.txt'
35 | self.slot_label = '../../user_data/common_data/slot_label.txt'
36 | self.train_slot_filename = '../../user_data/train_data/train_seq_out.txt'
37 | self.dev_slot_filename = '../../user_data/dev_data/dev_seq_out.txt'
38 | self.train_slot_none_filename = '../../user_data/train_data/train_slot_none.txt'
39 | self.dev_slot_none_filename = '../../user_data/dev_data/dev_slot_none.txt'
40 |
41 | def update(self, **kwargs):
42 | for k, v in kwargs.items():
43 | setattr(self, k, v)
44 |
45 | def __str__(self):
46 | return '\n'.join(['%s:%s' % item for item in self.__dict__.items()])
47 |
48 |
49 | if __name__ == '__main__':
50 | con = Config()
51 | con.update(gpu=8)
52 | print(con.gpu)
53 | print(con)
54 |
--------------------------------------------------------------------------------
/data/code/preprocess/extend_tv_sample.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/16 11:18
4 | # @Author : JJkinging
5 | # @File : extend_tv_sample.py
6 | '''Data augmentation: expand the small TVProgram-Play sample set'''
7 | import json
8 | import random
9 | random.seed(1000)
10 |
11 | with open('../../small_sample/same_intent/TVProgram-Play.json', 'r', encoding='utf-8') as fp:
12 | ori_data = json.load(fp)
13 |
14 | with open('../../small_sample/tv_entity_dic.json', 'r', encoding='utf-8') as fp:
15 | tv_dic = json.load(fp)
16 |
17 | res = {}
18 | idx = 0
19 | for filename, single_data in ori_data.items():
20 | text = single_data['text']
21 | slots = single_data['slots']
22 | for _ in range(30): # generate 30 augmented samples per original sample
23 | tmp_dic = {}
24 | slot_dic = {}
25 | length = len(str(idx))
26 | o_num = 5 - length
27 | new_filename = 'NLU' + '0'*o_num + str(idx)
28 | idx += 1
29 |
30 | new_text = text
31 | for prefix in tv_dic['prefix']:
32 | if text[:-3].find(prefix) != -1:
33 | ran_i = random.randint(0, len(tv_dic['prefix'])-1)
34 | new_text = text.replace(prefix, tv_dic['prefix'][ran_i]) # replace with a randomly chosen prefix
35 | break
36 | else:
37 | ran_i = random.randint(0, len(tv_dic['prefix'])-1)
38 | new_text = tv_dic['prefix'][ran_i] + text
39 | if 'name' in slots.keys():
40 | ran_i = random.randint(0, len(tv_dic['name'])-1)
41 | new_text = new_text.replace(slots['name'], tv_dic['name'][ran_i])
42 | slot_dic['name'] = tv_dic['name'][ran_i] # randomly replace name
43 | if 'channel' in slots.keys():
44 | ran_i = random.randint(0, len(tv_dic['channel']) - 1)
45 | new_text = new_text.replace(slots['channel'], tv_dic['channel'][ran_i])
46 | slot_dic['channel'] = tv_dic['channel'][ran_i] # randomly replace channel
47 | if 'datetime_date' in slots.keys():
48 | ran_i = random.randint(0, len(tv_dic['datetime_date']) - 1)
49 | new_text = new_text.replace(slots['datetime_date'], tv_dic['datetime_date'][ran_i])
50 | slot_dic['datetime_date'] = tv_dic['datetime_date'][ran_i] # randomly replace datetime_date
51 | if 'datetime_time' in slots.keys():
52 | ran_i = random.randint(0, len(tv_dic['datetime_time']) - 1)
53 | new_text = new_text.replace(slots['datetime_time'], tv_dic['datetime_time'][ran_i])
54 | slot_dic['datetime_time'] = tv_dic['datetime_time'][ran_i] # randomly replace datetime_time
55 | tmp_dic['text'] = new_text
56 | tmp_dic['intent'] = "TVProgram-Play"
57 | tmp_dic['slots'] = slot_dic
58 | res[new_filename] = tmp_dic
59 |
60 |
61 | print(res)
62 | print(len(res))
63 |
64 | with open('../../extend_data/data/extend_tv.json', 'w', encoding='utf-8') as fp:
65 | json.dump(res, fp, ensure_ascii=False)
66 |
67 |
--------------------------------------------------------------------------------
/data/code/preprocess/analysis.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 17:07
4 | # @Author : JJkinging
5 | # @File : analysis.py
6 | import json
7 |
8 | with open('../raw_data/train.json', 'r', encoding='utf-8') as fp:
9 | data = json.load(fp)
10 |
11 | FilmTele_Play = set()
12 | Audio_Play = set()
13 | Radio_Listen = set()
14 | TVProgram_Play = set()
15 | Travel_Query = set()
16 | Music_Play = set()
17 | HomeAppliance_Control = set()
18 | Calendar_Query = set()
19 | Alarm_Update = set()
20 | Video_Play = set()
21 | Weather_Query = set()
22 |
23 | for filename, single_data in data.items():
24 | intent = single_data['intent']
25 | slots = single_data['slots']
26 | if intent == 'FilmTele-Play':
27 | FilmTele_Play.update(slots.keys())
28 | elif intent == 'Audio-Play':
29 | Audio_Play.update(slots.keys())
30 | elif intent == 'Radio-Listen':
31 | Radio_Listen.update(slots.keys())
32 | elif intent == 'TVProgram-Play':
33 | TVProgram_Play.update(slots.keys())
34 | elif intent == 'Travel-Query':
35 | Travel_Query.update(slots.keys())
36 | elif intent == 'Music-Play':
37 | Music_Play.update(slots.keys())
38 | elif intent == 'HomeAppliance-Control':
39 | HomeAppliance_Control.update(slots.keys())
40 | elif intent == 'Calendar-Query':
41 | Calendar_Query.update(slots.keys())
42 | elif intent == 'Alarm-Update':
43 | Alarm_Update.update(slots.keys())
44 | elif intent == 'Video-Play':
45 | Video_Play.update(slots.keys())
46 | elif intent == 'Weather-Query':
47 | Weather_Query.update(slots.keys())
48 |
49 | with open('../intent_classify/FilmTele_Play.txt', 'w', encoding='utf-8') as fp:
50 | fp.write(str(list(FilmTele_Play)))
51 | with open('../intent_classify/Audio_Play.txt', 'w', encoding='utf-8') as fp:
52 | fp.write(str(list(Audio_Play)))
53 | with open('../intent_classify/Radio_Listen.txt', 'w', encoding='utf-8') as fp:
54 | fp.write(str(list(Radio_Listen)))
55 | with open('../intent_classify/TVProgram_Play.txt', 'w', encoding='utf-8') as fp:
56 | fp.write(str(list(TVProgram_Play)))
57 | with open('../intent_classify/Travel_Query.txt', 'w', encoding='utf-8') as fp:
58 | fp.write(str(list(Travel_Query)))
59 | with open('../intent_classify/Music_Play.txt', 'w', encoding='utf-8') as fp:
60 | fp.write(str(list(Music_Play)))
61 | with open('../intent_classify/HomeAppliance_Control.txt', 'w', encoding='utf-8') as fp:
62 | fp.write(str(list(HomeAppliance_Control)))
63 | with open('../intent_classify/Calendar_Query.txt', 'w', encoding='utf-8') as fp:
64 | fp.write(str(list(Calendar_Query)))
65 | with open('../intent_classify/Alarm_Update.txt', 'w', encoding='utf-8') as fp:
66 | fp.write(str(list(Alarm_Update)))
67 | with open('../intent_classify/Video_Play.txt', 'w', encoding='utf-8') as fp:
68 | fp.write(str(list(Video_Play)))
69 | with open('../intent_classify/Weather_Query.txt', 'w', encoding='utf-8') as fp:
70 | fp.write(str(list(Weather_Query)))
71 |
--------------------------------------------------------------------------------
/data/code/model/JointBertModel.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 14:40
4 | # @Author : JJkinging
5 | # @File : model.py
6 | import torch.nn as nn
7 | from transformers import BertModel
8 | from data.code.model.torchcrf import CRF
9 | import torch
10 |
11 |
12 | class JointBertModel(nn.Module):
13 | def __init__(self, bert_model_path, bert_hidden_size, intent_tag_size, slot_none_tag_size, slot_tag_size, device):
14 | super(JointBertModel, self).__init__()
15 | self.bert_model_path = bert_model_path
16 | self.bert_hidden_size = bert_hidden_size
17 | self.intent_tag_size = intent_tag_size
18 | self.slot_none_tag_size = slot_none_tag_size
19 | self.slot_tag_size = slot_tag_size
20 | self.device = device
21 | self.bert = BertModel.from_pretrained(self.bert_model_path)
22 | self.CRF = CRF(num_tags=self.slot_tag_size, batch_first=True)
23 |
24 | self.intent_classification = nn.Linear(self.bert_hidden_size, self.intent_tag_size)
25 | self.slot_none_classification = nn.Linear(self.bert_hidden_size, self.slot_none_tag_size)
26 | self.slot_classification = nn.Linear(self.bert_hidden_size, self.slot_tag_size)
27 | self.Dropout = nn.Dropout(p=0.5)
28 |
29 | def forward(self, input_ids, input_mask):
30 | batch_size = input_ids.size(0)
31 | seq_len = input_ids.size(1)
32 | utter_encoding = self.bert(input_ids, input_mask)
33 | sequence_output = utter_encoding[0]
34 | pooled_output = utter_encoding[1]
35 |
36 | pooled_output = self.Dropout(pooled_output)
37 | sequence_output = self.Dropout(sequence_output)
38 |
39 | intent_logits = self.intent_classification(pooled_output) # [batch_size, intent_tag_size]
40 | slot_logits = self.slot_classification(sequence_output)
41 | slot_none_logits = self.slot_none_classification(pooled_output)
42 |
43 | return intent_logits, slot_none_logits, slot_logits
44 |
45 | def slot_loss(self, feats, slot_ids, mask):
46 | ''' Used during training
47 | :param feats: emission scores produced by the slot classification layer
48 | :param slot_ids:
49 | :param mask:
50 | :return:
51 | '''
52 | feats = feats.to(self.device)
53 | slot_ids = slot_ids.to(self.device)
54 | mask = mask.to(self.device)
55 | loss_value = self.CRF(emissions=feats,
56 | tags=slot_ids,
57 | mask=mask,
58 | reduction='mean')
59 | return -loss_value
60 |
61 | def slot_predict(self, feats, mask, id2slot):
62 | feats = feats.to(self.device)
63 | mask = mask.to(self.device)
64 | slot2id = {value: key for key, value in id2slot.items()}
65 | # used during validation and testing
66 | out_path = self.CRF.decode(emissions=feats, mask=mask)
67 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path]
68 | for out in out_path:
69 | for i, tag in enumerate(out): # tag is O, B-*, I-*, etc.
70 | if tag.startswith('I-'): # the current tag starts with I-
71 | if i == 0: # position 0 should be [START]
72 | out[i] = '[START]'
73 | elif out[i-1] == 'O' or out[i-1] == '[START]': # but the previous tag is not the matching B- tag
74 | out[i] = id2slot[slot2id[tag]-1] # correct it to the corresponding B- tag
75 |
76 | out_path = [[slot2id[idx] for idx in one_data] for one_data in out_path]
77 |
78 | return out_path
79 |
--------------------------------------------------------------------------------
/data/code/preprocess/rectify.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/19 10:39
4 | # @Author : JJkinging
5 | # @File : rectify.py
6 | import json
7 |
8 |
9 | def rectify(source_path, target_path):
10 | with open(source_path, 'r', encoding='utf-8') as fp:
11 | ori_data = json.load(fp)
12 |
13 | for filename, single_data in ori_data.items():
14 | text = single_data['text']
15 | slots = single_data['slots']
16 | intent = single_data['intent']
17 | if intent == 'Alarm-Update':
18 | if 'datetime_time' in slots.keys():
19 | idx_1 = text.find(':')
20 | idx_2 = text[idx_1+1:].find(':')
21 |
22 | if idx_1 != -1:
23 | start_idx = idx_1
24 | end_idx = idx_1
25 | for i in range(1, 3): # find the start/end indices of the first time expression
26 | if text[start_idx - 1] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
27 | start_idx -= 1
28 | if text[end_idx + 1] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
29 | end_idx += 1
30 | if idx_2 != -1: # there are two time expressions
31 | start_idx_2 = idx_2 + idx_1 + 1 # convert to an index into the full text
32 | end_idx_2 = start_idx_2
33 | for i in range(1, 3): # find the start/end indices of the second time expression
34 | if text[start_idx_2 - 1] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
35 | start_idx_2 -= 1
36 | if text[end_idx_2 + 1] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
37 | end_idx_2 += 1
38 | for item in slots['datetime_time']:
39 | if item in text[start_idx: end_idx+1]:
40 | slots['datetime_time'].remove(item)
41 | slots['datetime_time'].append(text[start_idx: end_idx+1])
42 | elif item in text[start_idx_2: end_idx_2+1]:
43 | slots['datetime_time'].remove(item)
44 | slots['datetime_time'].append(text[start_idx_2: end_idx_2 + 1])
45 | else:
46 | if isinstance(slots['datetime_time'], list):
47 | for item in slots['datetime_time']:
48 | if item in text[start_idx: end_idx + 1]:
49 | slots['datetime_time'].remove(item)
50 | slots['datetime_time'].append(text[start_idx: end_idx+1])
51 | else:
52 | index = slots['datetime_time'].find(':')
53 | if index != -1:
54 | slots['datetime_time'] = slots['datetime_time'][:index] + text[idx_1: end_idx+1]
55 | else:
56 | slots['datetime_time'] = slots['datetime_time'] + text[idx_1: end_idx + 1]
57 |
58 | with open(target_path, 'w', encoding='utf-8') as fp:
59 | json.dump(ori_data, fp, ensure_ascii=False)
60 |
61 |
62 | if __name__ == "__main__":
63 | # source = '../small_sample/data/train_8_with_other.json'
64 | # target = '../small_sample/data/train_8_with_other_clear.json'
65 | source = '../small_sample/new_B/train_final.json'
66 | target = '../small_sample/new_B/train_final_clear.json'
67 | rectify(source, target)
68 |
69 | '''Manually fix the 'datetime_time': [] cases'''
70 |
71 | '''First delete
72 | NLU12276
73 | "NLU11929": {
74 | "text": "能不能建一个今天晚上6:30~7:00提醒我钉钉直播会议",
75 | "intent": "Alarm-Update",
76 | "slots": {
77 | "datetime_date": "今天",
78 | "datetime_time": "晚上6:30~7:00",
79 | "notes": "钉钉直播会议"
80 | }
81 | },
82 | then add it back
83 | '''
84 |
--------------------------------------------------------------------------------
/data/code/predict/nest.txt:
--------------------------------------------------------------------------------
1 | 放一个讲述美食的美国纪录片来看 {'name': '美食的美国纪录片', 'region': '美国'}
2 | O O O O O B-name I-name I-name B-region I-region I-name I-name I-name O O
3 | 王力宏在重庆开演唱会的视频放来看一下 {'name': '王力宏在重庆开演唱会的视频', 'region': '重庆'}
4 | B-name I-name I-name I-name B-region I-region I-name I-name I-name I-name I-name I-name I-name O O O O O
5 | 放一下钟汉良在2019年央视春晚上唱歌的视频 {'name': '钟汉良在2019年央视春晚上唱歌的视频', 'datetime_date': '2019年'}
6 | O O O B-name I-name I-name I-name B-datetime_date I-datetime_date I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name
7 | 帮我找下有没有关于北京的纪录片 {'region': '北京', 'name': '北京的纪录片'}
8 | O O O O O O O O O B-region I-region I-name I-name I-name I-name
9 | 给我播放2021中国女子足球锦标赛 {'datetime_date': '2021', 'region': '中国', 'name': '中国女子足球锦标赛'}
10 | O O O O B-datetime_date B-region I-region I-name I-name I-name I-name I-name I-name I-name
11 | 请播放钢琴奏鸣曲 {'instrument': '钢琴', 'song': '钢琴奏鸣曲'}
12 | O O O B-instrument I-instrument I-song I-song I-song
13 | 给我播放2012年刘翔参加钻石联赛上海站的视频 {'datetime_date': '2012年', 'name': '刘翔参加钻石联赛上海站的视频', 'region': '上海'}
14 | O O O O B-datetime_date I-datetime_date B-name I-name I-name I-name I-name I-name I-name I-name B-region I-region I-name I-name I-name I-name
15 | 有没有介绍日本地理的纪录片,播来看看 {'region': '日本', 'name': '日本地理的纪录片'}
16 | O O O O O B-region I-region I-name I-name I-name I-name I-name I-name O O O O O
17 | 给我播放阿飞今年讲解游戏的系列视频吧 {'datetime_date': '今年', 'name': '今年讲解游戏的系列视频'}
18 | O O O O O O B-datetime_date I-datetime_date I-name I-name I-name I-name I-name I-name I-name I-name I-name O
19 | 超级想看辣目洋子今年在超新星运动会上表演的艺术体操视频 {'name': '辣目洋子今年在超新星运动会上表演的艺术体操视频', 'datetime_date': '今年'}
20 | O O O O B-name I-name I-name I-name B-datetime_date I-datetime_date I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name
21 | 给我找一找吃货俱乐部在意大利拍的那一期 {'name': '吃货俱乐部在意大利拍的那一期', 'region': '意大利'}
22 | O O O O O B-name I-name I-name I-name I-name I-name B-region I-region I-region I-name I-name I-name I-name I-name
23 | 4月16日的斯里兰卡青睐中国这个视频能找到吗 {'datetime_date': '4月16日', 'region': '斯里兰卡', 'name': '斯里兰卡青睐中国这个视频'}
24 | B-datetime_date I-datetime_date I-datetime_date I-datetime_date O B-region I-region I-region I-region I-name I-name I-name I-name I-name I-name I-name I-name O O O O
25 | 周末特供2021-316的节目回看一下 {'datetime_date': '周末', 'name': '周末特供2021-316'}
26 | B-datetime_date I-datetime_date I-name I-name I-name I-name I-name O O O O O O O
27 | 我想看童蕾在6月15日拍摄都市丽人的花絮视频 {'name': '童蕾在6月15日拍摄都市丽人的花絮视频', 'datetime_date': '6月15日'}
28 | O O O B-name I-name I-name B-datetime_date I-datetime_date I-datetime_date I-datetime_date I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name
29 | 推荐一部催泪的美国纪录片来看 {'name': '催泪的美国纪录片', 'region': '美国'}
30 | O O O O B-name I-name I-name B-region I-region I-name I-name I-name O O
31 | 一场关于爱情的谎言在4月26播放的花絮有么 {'name': ['一场关于爱情的谎言', '4月26播放的花絮'], 'datetime_date': '4月26'}
32 | B-name I-name I-name I-name I-name I-name I-name I-name I-name O B-datetime_date I-datetime_date I-datetime_date I-name I-name I-name I-name I-name O O
33 | 我想看2012年麦当娜在巴黎举行演唱会的那个视频 {'datetime_date': '2012年', 'name': '麦当娜在巴黎举行演唱会的那个视频', 'region': '巴黎'}
34 | O O O B-datetime_date I-datetime_date B-name I-name I-name I-name B-region I-region I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name
35 | 播放一个豆瓣评分最高的日本纪录片来看 {'name': '豆瓣评分最高的日本纪录片', 'region': '日本'}
36 | O O O O B-name I-name I-name I-name I-name I-name I-name B-region I-region I-name I-name I-name O O
37 | 给我放一个高桥留美子的日本动漫片剪辑版视频 {'name': '高桥留美子的日本动漫片剪辑版视频', 'region': '日本'}
38 | O O O O O B-name I-name I-name I-name I-name I-name B-region I-region I-name I-name I-name I-name I-name I-name I-name I-name
39 | 找几个美国动漫视频小短片来看看 {'region': '美国', 'name': '美国动漫视频小短片'}
40 | O O O B-region I-region I-name I-name I-name I-name I-name I-name I-name O O O
41 | 吸血鬼侦探这个韩国电视剧的花絮给我播放一下 {'name': '吸血鬼侦探这个韩国电视剧的花絮', 'region': '韩国'}
42 | B-name I-name I-name I-name I-name I-name I-name B-region I-region I-name I-name I-name I-name I-name I-name O O O O O O
43 | 回放2008年的北京奥运会 {'datetime_date': '2008年', 'region': '北京', 'name': '北京奥运会'}
44 | O O B-datetime_date I-datetime_date O B-region I-region I-name I-name I-name
45 | 昨天法国报告“罕见” 变异新冠病毒的新闻视频再给我放下吧 {'datetime_date': '昨天', 'region': '法国', 'name': '法国报告“罕见” 变异新冠病毒的新闻视频'}
46 | B-datetime_date I-datetime_date B-region I-region I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name O O O O O O
47 |
--------------------------------------------------------------------------------
/data/code/scripts/build_vocab.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/5 15:21
4 | # @Author : JJkinging
5 | # @File : build_vocab.py
6 | import json
7 | import pickle
8 |
9 | import numpy as np
10 | from collections import Counter
11 |
12 | import torch
13 |
14 |
15 | def build_worddict(train_path, dev_path, test_path):
16 | '''
17 | function: build the word dictionary
18 | :param train_path, dev_path, test_path: paths to the tokenized sentence files
19 | :return: the word-to-id dictionary
20 | '''
21 | with open(train_path, 'r', encoding='utf-8') as fp:
22 | train_data = fp.readlines()
23 | train_data = [item.strip('\n').split(' ') for item in train_data]
24 | with open(dev_path, 'r', encoding='utf-8') as fp:
25 | dev_data = fp.readlines()
26 | dev_data = [item.strip('\n').split(' ') for item in dev_data]
27 | with open(test_path, 'r', encoding='utf-8') as fp:
28 | test_data = fp.readlines()
29 | test_data = [item.strip('\n').split(' ') for item in test_data]
30 |
31 | word_dict = {}
32 | words = []
33 | lengths = []
34 | for single_data in train_data:
35 | words.extend(single_data)
36 | lengths.append(len(single_data))
37 | for single_data in dev_data:
38 | words.extend(single_data)
39 | lengths.append(len(single_data))
40 | for single_data in test_data:
41 | words.extend(single_data)
42 | lengths.append(len(single_data))
43 | print(len(words))
44 | print(lengths)
45 | x = 0
46 | for len1 in lengths:
47 | if len1 >= 100:
48 | x += 1
49 | print('longest:', max(lengths))
50 | print('shortest:', min(lengths))
51 | print('sentences with length >= 100:', x)
52 |
53 | counts = Counter(words)
54 | print(counts)
55 | num_words = len(counts)
56 | word_dict['[PAD]'] = 0
57 | word_dict['[CLS]'] = 1
58 | word_dict['[SEP]'] = 2
59 |
60 | offset = 3
61 |
62 | for i, word in enumerate(counts.most_common(num_words)):
63 | word_dict[word[0]] = i + offset
64 | # print(word_dict)
65 | return word_dict
66 |
67 |
68 | def build_embedding_matrix(embeddings_file, worddict):
69 | embeddings = {}
70 | with open(embeddings_file, "r", encoding="utf8") as input_data:
71 | for line in input_data:
72 | line = line.split()
73 | try:
74 | float(line[1])
75 | word = line[0]
76 | if word in worddict:
77 | embeddings[word] = line[1:]
78 |
79 | # Ignore lines corresponding to multiple words separated
80 | # by spaces.
81 | except ValueError:
82 | continue
83 |
84 | num_words = len(worddict)
85 | embedding_dim = len(list(embeddings.values())[0])
86 | embedding_matrix = np.zeros((num_words, embedding_dim))
87 |
88 | # Actual building of the embedding matrix.
89 | missed = 0
90 | for word, i in worddict.items():
91 | if word in embeddings:
92 | embedding_matrix[i] = np.array(embeddings[word], dtype=float)
93 | else:
94 | if word == "[PAD]":
95 | continue
96 | missed += 1
97 | # Out of vocabulary words are initialised with random gaussian
98 | # samples.
99 | embedding_matrix[i] = np.random.normal(size=(embedding_dim))
100 | print("Missed words: ", missed)
101 |
102 | return embedding_matrix
103 |
104 |
105 | def word_to_indices(self, sentence):
106 | indices = []
107 | for word in sentence:
108 | if word in self.word_dict:
109 | indices.append(self.word_dict[word])
110 | else:
111 | indices.append(self.word_dict['UNK'])
112 | return indices
113 |
114 |
115 | def load_vocab(label_file):
116 | '''construct word2id or label2id'''
117 | vocab = {}
118 | index = 0
119 | with open(label_file, 'r', encoding='utf-8') as fp:
120 | while True:
121 | token = fp.readline()
122 | if not token:
123 | break
124 | token = token.strip() # strip whitespace
125 | vocab[token] = index
126 | index += 1
127 | return vocab
128 |
129 |
130 | if __name__ == "__main__":
131 | train_path = '../dataset/small_sample/data/train_seq_in.txt'
132 | dev_path = '../dataset/final_data/dev_seq_in.txt'
133 | test_path = '../script/test/dataset/test_seq_in.txt'
134 | word_dict = build_worddict(train_path, dev_path, test_path)
135 | print(word_dict)
136 | print(len(word_dict))
137 |
138 | # label_file = '../dataset/new_data/tag.txt'
139 | # label_dict = load_vocab(label_file)
140 | # print(label_dict)
141 | # save the embedding file
142 | embedding_file = '../dataset/embeddings/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5'
143 | embed_matrix = build_embedding_matrix(embedding_file, word_dict)
144 | with open("../dataset/embeddings/embedding.pkl", "wb") as pkl_file:
145 | pickle.dump(embed_matrix, pkl_file)
146 | print(torch.tensor(embed_matrix).shape)
--------------------------------------------------------------------------------
/data/code/scripts/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 12:39
4 | # @Author : JJkinging
5 | # @File : utils.py
6 | from torch.utils.data import Dataset, DataLoader
7 | from data.code.predict.test_utils import load_vocab, collate_to_max_length
8 |
9 |
10 | class CCFDataset(Dataset):
11 | def __init__(self, filename, intent_filename, slot_filename, slot_none_filename, vocab, intent_dict,
12 | slot_none_dict, slot_dict, max_length=512):
13 | '''
14 | :param filename: data file to read, e.g. train_seq_in.txt
15 | :param intent_filename: train_intent_label.txt or dev_intent_label.txt
16 | :param slot_filename: train_seq_out.txt
17 | :param slot_none_filename: train_slot_none.txt or dev_slot_none.txt
18 | :param slot_none_dict: slot_none-to-id dictionary
19 | :param slot_dict: slot_label-to-id dictionary
20 | :param vocab: vocabulary, e.g. BERT's vocab.txt
21 | :param intent_dict: intent-to-id dictionary
22 | :param max_length: maximum sentence length
23 | '''
24 | self.filename = filename
25 | self.intent_filename = intent_filename
26 | self.slot_filename = slot_filename
27 | self.slot_none_filename = slot_none_filename
28 | self.vocab = vocab
29 | self.intent_dict = intent_dict
30 | self.slot_none_dict = slot_none_dict
31 | self.slot_dict = slot_dict
32 | self.max_length = max_length
33 |
34 | self.result = []
35 |
36 | # read the data
37 | with open(self.filename, 'r', encoding='utf-8') as fp:
38 | sen_data = fp.readlines()
39 | sen_data = [item.strip('\n') for item in sen_data] # strip the trailing newline ('\n')
40 |
41 | # read the intent labels
42 | with open(self.intent_filename, 'r', encoding='utf-8') as fp:
43 | intent_data = fp.readlines()
44 | intent_data = [item.strip('\n') for item in intent_data] # strip the trailing newline ('\n')
45 | intent_ids = [intent_dict[item] for item in intent_data]
46 |
47 | # read the slot_none labels
48 | with open(self.slot_none_filename, 'r', encoding='utf-8') as fp:
49 | slot_none_data = fp.readlines()
50 | # strip trailing spaces and the newline ('\n')
51 | slot_none_data = [item.strip('\n').strip(' ').split(' ') for item in slot_none_data]
52 | # the list comprehension below converts the slot_none labels to ids
53 | slot_none_ids = [[self.slot_none_dict[ite] for ite in item] for item in slot_none_data]
54 |
55 | # read the slot label sequences
56 | with open(self.slot_filename, 'r', encoding='utf-8') as fp:
57 | slot_data = fp.readlines()
58 | slot_data = [item.strip('\n') for item in slot_data] # strip the trailing newline ('\n')
59 | # slot_ids = [self.slot_dict[item] for item in slot_data]
60 |
61 | idx = 0
62 | for utterance in sen_data:
63 | utterance = utterance.split(' ') # turn the string into a token list
64 | slot_utterence = slot_data[idx].split(' ')
65 | # enforce the maximum length (leave room for [CLS] and [SEP])
66 | if len(utterance) > self.max_length-2:
67 | utterance = utterance[:self.max_length-2]
68 | slot_utterence = slot_utterence[:self.max_length-2]
69 |
70 | # input_ids
71 | utterance = ['[CLS]'] + utterance + ['[SEP]']
72 | input_ids = [int(self.vocab[i]) for i in utterance]
73 |
74 |
75 | length = len(input_ids)
76 |
77 | # slot_ids
78 | slot_utterence = ['[START]'] + slot_utterence + ['[EOS]']
79 | slot_ids = [int(self.slot_dict[i]) for i in slot_utterence]
80 |
81 | # input_mask
82 | input_mask = [1] * len(input_ids)
83 |
84 | # intent_ids
85 | intent_id = intent_ids[idx]
86 |
87 | # slot_none_ids
88 | slot_none_id = slot_none_ids[idx] # slot_none_id is an int or a list
89 |
90 | idx += 1
91 |
92 | self.result.append((input_ids, slot_ids, input_mask, intent_id, slot_none_id))
93 |
94 | def __len__(self):
95 | return len(self.result)
96 |
97 | def __getitem__(self, index):
98 | input_ids, slot_ids, input_mask, intent_id, slot_none_id = self.result[index]
99 |
100 | return input_ids, slot_ids, input_mask, intent_id, slot_none_id
101 |
102 |
103 | if __name__ == "__main__":
104 | filename = '../dataset/final_data/train_seq_in.txt'
105 | vocab_file = '../dataset/pretrained_model/erine/vocab.txt'
106 | intent_filename = '../dataset/final_data/train_intent_label.txt'
107 | slot_filename = '../dataset/final_data/train_seq_out.txt'
108 | slot_none_filename = '../dataset/final_data/train_slot_none.txt'
109 | intent_label = '../dataset/final_data/intent_label.txt'
110 | slot_label = '../dataset/final_data/slot_label.txt'
111 | slot_none_vocab = '../dataset/final_data/slot_none_vocab.txt'
112 | intent_dict = load_vocab(intent_label)
113 | slot_dict = load_vocab(slot_label)
114 | slot_none_dict = load_vocab(slot_none_vocab)
115 | vocab = load_vocab(vocab_file)
116 |
117 | dataset = CCFDataset(filename, intent_filename, slot_filename, slot_none_filename, vocab, intent_dict,
118 | slot_none_dict, slot_dict)
119 |
120 | dataloader = DataLoader(dataset, shuffle=False, batch_size=8, collate_fn=collate_to_max_length)
121 |
122 | for batch in dataloader:
123 | print(batch)
124 | break
125 |
--------------------------------------------------------------------------------
/data/code/preprocess/extend_audio_sample.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/15 21:08
4 | # @Author : JJkinging
5 | # @File : extend_small_sample.py
6 | '''Data augmentation: expand the small Audio-Play sample set'''
7 | import json
8 | import random
9 | random.seed(1000)
10 |
11 | with open('../../extend_data/same_intent/ori_data/Audio-Play.json', 'r', encoding='utf-8') as fp:
12 | ori_data = json.load(fp)
13 |
14 | with open('../../small_sample/audio_entity_dic.json', 'r', encoding='utf-8') as fp:
15 | audio_dic = json.load(fp)
16 |
17 | res = {}
18 | idx = 0
19 | for filename, single_data in ori_data.items():
20 | text = single_data['text']
21 | slots = single_data['slots']
22 | for _ in range(20): # generate 20 augmented samples per original sample
23 | tmp_dic = {}
24 | slot_dic = {}
25 | length = len(str(idx))
26 | o_num = 5 - length
27 | new_filename = 'NLU' + '0'*o_num + str(idx)
28 | idx += 1
29 |
30 | new_text = text
31 | for prefix in audio_dic['prefix']:
32 | if text.find(prefix) != -1:
33 | ran_i = random.randint(0, len(audio_dic['prefix'])-1)
34 | new_text = text.replace(prefix, audio_dic['prefix'][ran_i]) # replace with a randomly chosen prefix
35 | break
36 | else:
37 | ran_i = random.randint(0, len(audio_dic['prefix'])-1)
38 | new_text = audio_dic['prefix'][ran_i] + text
39 | if 'name' in slots.keys():
40 | ran_i = random.randint(0, len(audio_dic['name'])-1)
41 | new_text = new_text.replace(slots['name'], audio_dic['name'][ran_i])
42 | slot_dic['name'] = audio_dic['name'][ran_i] # randomly replace name
43 | if 'artist' in slots.keys():
44 | ran_i = random.randint(0, len(audio_dic['artist']) - 1)
45 | if isinstance(slots['artist'], list):
46 | tmp_list = []
47 | for art in slots['artist']:
48 | ran_j = random.randint(0, len(audio_dic['artist']) - 1)
49 | new_text = new_text.replace(art, audio_dic['artist'][ran_j])
50 | tmp_list.append(audio_dic['artist'][ran_j])
51 | slot_dic['artist'] = tmp_list
52 | else:
53 | new_text = new_text.replace(slots['artist'], audio_dic['artist'][ran_i])
54 | slot_dic['artist'] = audio_dic['artist'][ran_i] # randomly replace artist
55 | if 'play_setting' in slots.keys():
56 | ran_i = random.randint(0, len(audio_dic['play_setting']) - 1)
57 | if isinstance(slots['play_setting'], list):
58 | tmp_list = []
59 | for set1 in slots['play_setting']:
60 | ran_j = random.randint(0, len(audio_dic['play_setting']) - 1)
61 | if set1 in audio_dic['play_setting']:
62 | new_text = new_text.replace(set1, audio_dic['play_setting'][ran_j])
63 | tmp_list.append(audio_dic['play_setting'][ran_j])
64 | else:
65 | tmp_list.append(set1)
66 | slot_dic['play_setting'] = tmp_list
67 | else:
68 | if slots['play_setting'] in audio_dic['play_setting']:
69 | new_text = new_text.replace(slots['play_setting'], audio_dic['play_setting'][ran_i])
70 |                     slot_dic['play_setting'] = audio_dic['play_setting'][ran_i] # randomly replace play_setting
71 | if 'language' in slots.keys():
72 | ran_i = random.randint(0, len(audio_dic['language']) - 1)
73 | if slots['language'] == "俄语" and audio_dic['language'][ran_i] != "俄语":
74 | new_text = new_text.replace(slots['language'], audio_dic['language'][ran_i][:-1]+'文')
75 |                 slot_dic['language'] = audio_dic['language'][ran_i] # randomly replace language
76 |             elif slots['language'] == "俄语" and audio_dic['language'][ran_i] == "俄语":
77 |                 slot_dic['language'] = audio_dic['language'][ran_i] # randomly replace language
78 | elif slots['language'] == "华语" and audio_dic['language'][ran_i] != "华语":
79 | new_text = new_text.replace("中文", audio_dic['language'][ran_i][:-1] + '文')
80 |                 slot_dic['language'] = audio_dic['language'][ran_i] # randomly replace language
81 |             elif slots['language'] == "华语" and audio_dic['language'][ran_i] == "华语":
82 |                 slot_dic['language'] = audio_dic['language'][ran_i] # randomly replace language
83 | else:
84 | new_text = new_text.replace(slots['language'][:-1]+'文', audio_dic['language'][ran_i][:-1] + '文')
85 |                 slot_dic['language'] = audio_dic['language'][ran_i] # randomly replace language
86 | if 'tag' in slots.keys():
87 | ran_i = random.randint(0, len(audio_dic['tag']) - 1)
88 | new_text = new_text.replace(slots['tag'], audio_dic['tag'][ran_i])
89 |             slot_dic['tag'] = audio_dic['tag'][ran_i] # randomly replace tag
90 | tmp_dic['text'] = new_text
91 | tmp_dic['intent'] = "Audio-Play"
92 | tmp_dic['slots'] = slot_dic
93 | res[new_filename] = tmp_dic
94 |
95 |
96 | print(res)
97 | print(len(res))
98 |
99 | with open('../../extend_data/data/extend_audio.json', 'w', encoding='utf-8') as fp:
100 | json.dump(res, fp, ensure_ascii=False)
101 |
102 | # manual post-processing: replace 华语|华文 -> 中文 and 西班牙文 -> 西班牙语
103 |
--------------------------------------------------------------------------------
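
A note on the augmentation above: each new sample is built by swapping every slot value in the original text for a random candidate from a hand-built entity dictionary and writing the chosen value back into the slot dict, with a zero-padded NLU id as the new key. The snippet below is a minimal, self-contained sketch of that substitution step; the helper name, the toy candidates and the example utterance are illustrative, not taken from the repository.

    import random

    def substitute_slot(text, slots, slot_name, candidates, rng=random):
        # Replace the current value of `slot_name` in the text with a random candidate
        # and return the updated text together with the new value (toy helper).
        if slot_name not in slots:
            return text, None
        new_value = rng.choice(candidates)
        return text.replace(slots[slot_name], new_value), new_value

    rng = random.Random(1000)
    text = "播放周杰伦的稻香"
    slots = {"artist": "周杰伦", "name": "稻香"}
    new_text, new_artist = substitute_slot(text, slots, "artist", ["林俊杰", "邓紫棋"], rng)
    sample_id = "NLU" + str(7).zfill(5)  # 'NLU00007', same effect as the manual zero padding above
    print(sample_id, new_text, new_artist)
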
/data/code/predict/integration.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/20 18:35
4 | # @Author : JJkinging
5 | # @File : integration.py
6 | import json
7 | import os
8 |
9 | from data.code.predict.post_process import process
10 |
11 | with open('../../user_data/tmp_result/result_interact_1_post.json', 'r', encoding='utf-8') as fp:
12 | best_data = json.load(fp)
13 | with open('../../user_data/tmp_result/result_interact_3_post.json', 'r', encoding='utf-8') as fp:
14 | second_data = json.load(fp)
15 | with open('../../user_data/tmp_result/result_bert_post.json', 'r', encoding='utf-8') as fp:
16 | third_data = json.load(fp)
17 |
18 | idx_list = []
19 | res = {}
20 | for i in range(len(best_data)):
21 | lengths = len(str(i))
22 | o_nums = 5 - lengths
23 | idx = 'NLU' + '0'*o_nums + str(i)
24 | tmp = {}
25 |
26 | best_dic = best_data[idx]
27 | second_dic = second_data[idx]
28 | third_dic = third_data[idx]
29 |     # reconcile the intent across the three models
30 | if best_dic['intent'] == second_dic['intent'] == third_dic['intent'] or \
31 | best_dic['intent'] == second_dic['intent'] or best_dic['intent'] == third_dic['intent'] or \
32 | (best_dic['intent'] != second_dic['intent'] and best_dic['intent'] != third_dic['intent'] and
33 | second_dic['intent'] != third_dic['intent']):
34 | intent = best_dic['intent']
35 |
36 | elif second_dic['intent'] == third_dic['intent']:
37 | intent = second_dic['intent']
38 |
39 | slot_dic = {}
40 |     # reconcile the slots
41 | best_slots = best_dic['slots']
42 | second_slots = second_dic['slots']
43 | third_slots = third_dic['slots']
44 |
45 | total_slot_keys = set()
46 | total_slot_keys.update(list(best_slots.keys()))
47 | total_slot_keys.update(list(second_slots.keys()))
48 | total_slot_keys.update(list(third_slots.keys()))
49 |
50 | second_key = []
51 | third_key = []
52 | for key in total_slot_keys:
53 |         if key in best_slots.keys() and key in second_slots.keys() and key in third_slots.keys(): # all three models predicted this slot
54 | if isinstance(best_slots[key], list) and isinstance(second_slots[key], list) and \
55 | isinstance(third_slots[key], list):
56 | best_slots[key] = set(best_slots[key])
57 | second_slots[key] = set(second_slots[key])
58 | third_slots[key] = set(third_slots[key])
59 |
60 | if best_slots[key] == second_slots[key] == third_slots[key] or \
61 | best_slots[key] == second_slots[key] or best_slots[key] == third_slots[key] or \
62 | (best_slots[key] != second_slots[key] and best_slots[key] != third_slots[key] and
63 | second_slots[key] != third_slots[key]):
64 | if isinstance(best_slots[key], set):
65 | slot_dic[key] = list(best_slots[key])
66 | else:
67 | slot_dic[key] = best_slots[key]
68 | elif second_slots[key] == third_slots[key]:
69 | if isinstance(second_slots[key], set):
70 | slot_dic[key] = list(second_slots[key])
71 | else:
72 | slot_dic[key] = second_slots[key]
73 |         elif key in best_slots.keys() and key in second_slots.keys(): # only best and second predicted it
74 | if isinstance(best_slots[key], list) and isinstance(second_slots[key], list):
75 | best_slots[key] = set(best_slots[key])
76 | second_slots[key] = set(second_slots[key])
77 |
78 | if isinstance(best_slots[key], set):
79 | slot_dic[key] = list(best_slots[key])
80 | else:
81 | slot_dic[key] = best_slots[key]
82 | elif key in best_slots.keys() and key in third_slots.keys():
83 | if isinstance(best_slots[key], list) and isinstance(third_slots[key], list):
84 | best_slots[key] = set(best_slots[key])
85 | third_slots[key] = set(third_slots[key])
86 | if isinstance(best_slots[key], set):
87 | slot_dic[key] = list(best_slots[key])
88 | else:
89 | slot_dic[key] = best_slots[key]
90 | elif key in second_slots.keys() and key in third_slots.keys():
91 | if isinstance(second_slots[key], list) and isinstance(third_slots[key], list):
92 | second_slots[key] = set(second_slots[key])
93 | third_slots[key] = set(third_slots[key])
94 | if isinstance(second_slots[key], set):
95 | slot_dic[key] = list(second_slots[key])
96 | else:
97 | slot_dic[key] = second_slots[key]
98 | elif key in best_slots.keys():
99 | slot_dic[key] = best_slots[key]
100 | elif key in second_slots.keys():
101 | second_key.append(key)
102 | elif key in third_slots.keys():
103 | third_key.append(key)
104 |
105 | for key in second_key:
106 | if second_slots[key] not in slot_dic.values():
107 | slot_dic[key] = second_slots[key]
108 | for key in third_key:
109 | if third_slots[key] not in slot_dic.values():
110 | slot_dic[key] = third_slots[key]
111 |
112 | tmp['intent'] = intent
113 | tmp['slots'] = slot_dic
114 | res[idx] = tmp
115 |
116 | with open('../../prediction_result/result_tmp.json', 'w', encoding='utf-8') as fp:
117 | json.dump(res, fp, ensure_ascii=False)
118 |
119 | source_path = '../../prediction_result/result_tmp.json'
120 | target_path = '../../prediction_result/result.json'
121 | process(source_path, target_path)
122 | os.remove('../../prediction_result/result_tmp.json')
123 |
--------------------------------------------------------------------------------
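
The intent branch in integration.py encodes a three-way vote biased toward the strongest single model: the best model's prediction is kept unless the other two models agree on something else, and a full three-way disagreement also falls back to the best model. The slot reconciliation below it applies the same rule key by key. Restated as a tiny standalone function (the name `vote` is ours, not the repository's):

    def vote(best, second, third):
        # Keep the strongest model's prediction unless the two weaker models agree against it.
        if best == second or best == third:
            return best
        if second == third:
            return second
        return best  # full three-way disagreement: fall back to the strongest model

    assert vote("Music-Play", "Music-Play", "Audio-Play") == "Music-Play"
    assert vote("Audio-Play", "Music-Play", "Music-Play") == "Music-Play"
    assert vote("A", "B", "C") == "A"
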
/data/code/scripts/train_interact1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 14:41
4 | # @Author : JJkinging
5 | # @File : main.py
6 |
7 | import os
8 | import torch
9 | import random
10 | import numpy as np
11 | import warnings
12 | from transformers import get_linear_schedule_with_warmup
13 | from data.code.scripts.dataset import CCFDataset
14 | from data.code.scripts.config_Interact1 import Config
15 | from torch.utils.data import DataLoader
16 | from data.code.model.InteractModel_1 import InteractModel
17 | from data.code.scripts.utils import train, valid, collate_to_max_length, load_vocab
18 |
19 |
20 | def torch_seed(seed):
21 | os.environ['PYTHONHASHSEED'] = str(seed)
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 | np.random.seed(seed) # Numpy module.
26 | random.seed(seed) # Python random module.
27 | torch.manual_seed(seed)
28 | torch.backends.cudnn.benchmark = False
29 | torch.backends.cudnn.deterministic = True
30 | torch.backends.cudnn.enabled = False
31 |
32 |
33 | def main():
34 | warnings.filterwarnings("ignore")
35 | torch_seed(1000)
36 |     # select the visible GPUs
37 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
38 | config = Config()
39 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu")
40 | print('loading corpus')
41 | vocab = load_vocab(config.vocab_file)
42 | intent_dict = load_vocab(config.intent_label_file)
43 | slot_none_dict = load_vocab(config.slot_none_vocab)
44 | slot_dict = load_vocab(config.slot_label)
45 | intent_tagset_size = len(intent_dict)
46 | slot_none_tag_size = len(slot_none_dict)
47 | slot_tag_size = len(slot_dict)
48 | train_dataset = CCFDataset(config.train_file, config.train_intent_file, config.train_slot_filename,
49 | config.train_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict,
50 | config.max_length)
51 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size,
52 | collate_fn=collate_to_max_length)
53 |
54 | dev_dataset = CCFDataset(config.dev_file, config.dev_intent_file, config.dev_slot_filename,
55 | config.dev_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict,
56 | config.max_length)
57 | dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=config.batch_size,
58 | collate_fn=collate_to_max_length)
59 |
60 | model = InteractModel(config.bert_model_path,
61 | config.bert_hidden_size,
62 | intent_tagset_size,
63 | slot_none_tag_size,
64 | slot_tag_size,
65 | device).to(device)
66 | # model = torch.nn.DataParallel(model, device_ids=[0, 1], output_device=[0])
67 | I_criterion = torch.nn.CrossEntropyLoss()
68 | N_criterion = torch.nn.BCEWithLogitsLoss()
69 |
70 |     crf_params = list(map(id, model.CRF.parameters())) # collect the ids of the CRF-layer parameters
71 |     other_params = filter(lambda x: id(x) not in crf_params, model.parameters()) # keep every parameter that is not part of the CRF layer
72 |
73 | optimizer = torch.optim.AdamW([{'params': model.CRF.parameters(), 'lr': config.crf_lr},
74 | {'params': other_params, 'lr': config.lr}], weight_decay=config.weight_decay)
75 | # optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0.0)
76 | total_step = len(train_loader) // config.batch_size * config.epochs
77 | scheduler = get_linear_schedule_with_warmup(optimizer,
78 | num_warmup_steps=0,
79 | num_training_steps=total_step)
80 |
81 | best_score = 0.0
82 | start_epoch = 1
83 | # Data for loss curves plot.
84 | epochs_count = []
85 | train_losses = []
86 | valid_losses = []
87 |
88 | # Continuing training from a checkpoint if one was given as argument.
89 | if config.checkpoint:
90 | checkpoint = torch.load(config.checkpoint)
91 | start_epoch = checkpoint["epoch"] + 1
92 | best_score = checkpoint["best_score"]
93 |
94 | print("\t* Training will continue on existing model from epoch {}..."
95 | .format(start_epoch))
96 |
97 | model.load_state_dict(checkpoint["model"])
98 | optimizer.load_state_dict(checkpoint["optimizer"])
99 | epochs_count = checkpoint["epochs_count"]
100 | train_losses = checkpoint["train_losses"]
101 | valid_losses = checkpoint["valid_losses"]
102 |
103 | # Compute loss and accuracy before starting (or resuming) training.
104 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model,
105 | dev_loader,
106 | I_criterion,
107 | N_criterion)
108 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}"
109 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc))
110 |
111 | # -------------------- Training epochs ------------------- #
112 | print("\n",
113 | 20 * "=",
114 |           "Training model on device: {}".format(device),
115 | 20 * "=")
116 | for epoch in range(start_epoch, config.epochs+1):
117 | epochs_count.append(epoch)
118 | print("* Training epoch {}:".format(epoch))
119 | train_time, train_loss = train(model,
120 | train_loader,
121 | optimizer,
122 | I_criterion,
123 | N_criterion,
124 | config.max_grad_norm)
125 | train_losses.append(train_loss)
126 | print("-> Training time: {:.4f}s loss = {:.4f}"
127 | .format(train_time, train_loss))
128 | with open('../../user_data/output_model/InteractModel_1/metric.txt', 'a', encoding='utf-8') as fp:
129 | fp.write('Epoch:' + str(epoch) + '\t' + 'Loss:' + str(round(train_loss, 4)) + '\t')
130 |
131 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model,
132 | dev_loader,
133 | I_criterion,
134 | N_criterion,
135 | )
136 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}"
137 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc))
138 | with open('../../user_data/output_model/InteractModel_1/metric.txt', 'a', encoding='utf-8') as fp:
139 | fp.write('Loss:' + str(round(valid_loss, 4)) + '\t' + 'Intent_acc:' +
140 | str(round(intent_accuracy, 4)) + '\t' + 'slot_none:' +
141 | str(round(slot_none[2], 4)) + '\t''slot_F1:' + str(round(slot[2], 4)) + '\t'
142 | + 'SEN_ACC:' + str(round(sen_acc, 4)) + '\n')
143 |
144 |         valid_losses.append(valid_loss)
145 | # Update the optimizer's learning rate with the scheduler.
146 | scheduler.step()
147 | if sen_acc >= best_score:
148 | best_score = sen_acc
149 | torch.save({"epoch": epoch,
150 | "model": model.state_dict(),
151 | "best_score": best_score,
152 | "optimizer": optimizer.state_dict(),
153 | "epochs_count": epochs_count,
154 | "train_losses": train_losses,
155 | "valid_losses": valid_losses},
156 | os.path.join(config.target_dir, "Interact1_model_best.pth.tar"))
157 |
158 | if __name__ == "__main__":
159 | main()
160 |
--------------------------------------------------------------------------------
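
One detail worth noting in the training script is the optimizer setup: the CRF parameters are pulled into their own parameter group so they can be trained with a larger learning rate (config.crf_lr) than the BERT encoder (config.lr). A hedged sketch of that grouping, with illustrative default values rather than the ones in config_Interact1.py:

    import torch

    def build_optimizer(model, crf_lr=1e-3, lr=2e-5, weight_decay=0.01):
        # Assumes the model exposes its CRF layer as `model.CRF`, as InteractModel does.
        crf_param_ids = {id(p) for p in model.CRF.parameters()}
        other_params = [p for p in model.parameters() if id(p) not in crf_param_ids]
        return torch.optim.AdamW(
            [{"params": model.CRF.parameters(), "lr": crf_lr},
             {"params": other_params, "lr": lr}],
            weight_decay=weight_decay)
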
/data/code/scripts/train_interact3.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 14:41
4 | # @Author : JJkinging
5 | # @File : main.py
6 |
7 | import os
8 | import warnings
9 |
10 | import torch
11 | import random
12 | import numpy as np
13 | from transformers import get_linear_schedule_with_warmup
14 | from data.code.scripts.dataset import CCFDataset
15 | from data.code.scripts.config_Interact3 import Config
16 | from torch.utils.data import DataLoader
17 | from data.code.model.InteractModel_3 import InteractModel
18 | from data.code.scripts.utils import train, valid, collate_to_max_length, load_vocab
19 |
20 |
21 | def torch_seed(seed):
22 | os.environ['PYTHONHASHSEED'] = str(seed)
23 | torch.manual_seed(seed)
24 | torch.cuda.manual_seed(seed)
25 | torch.cuda.manual_seed_all(seed)
26 | np.random.seed(seed) # Numpy module.
27 | random.seed(seed) # Python random module.
28 | torch.manual_seed(seed)
29 | torch.backends.cudnn.benchmark = False
30 | torch.backends.cudnn.deterministic = True
31 | torch.backends.cudnn.enabled = False
32 |
33 |
34 | def main():
35 | torch_seed(1000)
36 | warnings.filterwarnings("ignore")
37 |     # select the visible GPUs
38 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
39 | config = Config()
40 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu")
41 | print('loading corpus')
42 | vocab = load_vocab(config.vocab_file)
43 | intent_dict = load_vocab(config.intent_label_file)
44 | slot_none_dict = load_vocab(config.slot_none_vocab)
45 | slot_dict = load_vocab(config.slot_label)
46 | intent_tagset_size = len(intent_dict)
47 | slot_none_tag_size = len(slot_none_dict)
48 | slot_tag_size = len(slot_dict)
49 | train_dataset = CCFDataset(config.train_file, config.train_intent_file, config.train_slot_filename,
50 | config.train_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict,
51 | config.max_length)
52 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size,
53 | collate_fn=collate_to_max_length)
54 |
55 | dev_dataset = CCFDataset(config.dev_file, config.dev_intent_file, config.dev_slot_filename,
56 | config.dev_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict,
57 | config.max_length)
58 | dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=config.batch_size,
59 | collate_fn=collate_to_max_length)
60 |
61 | model = InteractModel(config.bert_model_path,
62 | config.bert_hidden_size,
63 | intent_tagset_size,
64 | slot_none_tag_size,
65 | slot_tag_size,
66 | device).to(device)
67 | # model = torch.nn.DataParallel(model, device_ids=[0, 1], output_device=[0])
68 | I_criterion = torch.nn.CrossEntropyLoss()
69 | N_criterion = torch.nn.BCEWithLogitsLoss()
70 |
71 |     crf_params = list(map(id, model.CRF.parameters())) # collect the ids of the CRF-layer parameters
72 |     other_params = filter(lambda x: id(x) not in crf_params, model.parameters()) # keep every parameter that is not part of the CRF layer
73 |
74 | optimizer = torch.optim.AdamW([{'params': model.CRF.parameters(), 'lr': config.crf_lr},
75 | {'params': other_params, 'lr': config.lr}], weight_decay=config.weight_decay)
76 | total_step = len(train_loader) // config.batch_size * config.epochs
77 | scheduler = get_linear_schedule_with_warmup(optimizer,
78 | num_warmup_steps=0,
79 | num_training_steps=total_step)
80 |
81 | best_score = 0.0
82 | start_epoch = 1
83 | # Data for loss curves plot.
84 | epochs_count = []
85 | train_losses = []
86 | valid_losses = []
87 |
88 | # Continuing training from a checkpoint if one was given as argument.
89 | if config.checkpoint:
90 | checkpoint = torch.load(config.checkpoint)
91 | start_epoch = checkpoint["epoch"] + 1
92 | best_score = checkpoint["best_score"]
93 |
94 | print("\t* Training will continue on existing model from epoch {}..."
95 | .format(start_epoch))
96 |
97 | model.load_state_dict(checkpoint["model"])
98 | optimizer.load_state_dict(checkpoint["optimizer"])
99 | epochs_count = checkpoint["epochs_count"]
100 | train_losses = checkpoint["train_losses"]
101 | valid_losses = checkpoint["valid_losses"]
102 |
103 | # Compute loss and accuracy before starting (or resuming) training.
104 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model,
105 | dev_loader,
106 | I_criterion,
107 | N_criterion)
108 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}"
109 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc))
110 |
111 | # -------------------- Training epochs ------------------- #
112 | print("\n",
113 | 20 * "=",
114 |           "Training model on device: {}".format(device),
115 | 20 * "=")
116 | for epoch in range(start_epoch, config.epochs+1):
117 | epochs_count.append(epoch)
118 | print("* Training epoch {}:".format(epoch))
119 | train_time, train_loss = train(model,
120 | train_loader,
121 | optimizer,
122 | I_criterion,
123 | N_criterion,
124 | config.max_grad_norm)
125 | train_losses.append(train_loss)
126 | print("-> Training time: {:.4f}s loss = {:.4f}"
127 | .format(train_time, train_loss))
128 | with open('../../user_data/output_model/InteractModel_3/trained_model/metric.txt', 'a', encoding='utf-8') as fp:
129 | fp.write('Epoch:' + str(epoch) + '\t' + 'Loss:' + str(round(train_loss, 4)) + '\t')
130 |
131 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model,
132 | dev_loader,
133 | I_criterion,
134 | N_criterion,
135 | )
136 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}"
137 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc))
138 | with open('../../user_data/output_model/InteractModel_3/trained_model/metric.txt', 'a', encoding='utf-8') as fp:
139 | fp.write('Loss:' + str(round(valid_loss, 4)) + '\t' + 'Intent_acc:' +
140 | str(round(intent_accuracy, 4)) + '\t' + 'slot_none:' +
141 | str(round(slot_none[2], 4)) + '\t''slot_F1:' + str(round(slot[2], 4)) + '\t'
142 | + 'SEN_ACC:' + str(round(sen_acc, 4)) + '\n')
143 |
144 |         valid_losses.append(valid_loss)
145 | # Update the optimizer's learning rate with the scheduler.
146 | scheduler.step()
147 |
148 |         # Save the checkpoint with the best validation sentence accuracy.
149 | if sen_acc >= best_score:
150 | best_score = sen_acc
151 | torch.save({"epoch": epoch,
152 | "model": model.state_dict(),
153 | "best_score": best_score,
154 | "optimizer": optimizer.state_dict(),
155 | "epochs_count": epochs_count,
156 | "train_losses": train_losses,
157 | "valid_losses": valid_losses},
158 | os.path.join(config.target_dir, "Interact3_model_best.pth.tar"))
159 |
160 |
161 |
162 | if __name__ == "__main__":
163 | main()
164 |
--------------------------------------------------------------------------------
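
All three training scripts share the same checkpoint contract: the saved dict carries the epoch, the model and optimizer state dicts, the best sentence accuracy and the loss histories, and resuming restarts from epoch + 1. A minimal sketch of that contract (field names follow the scripts; the function names are ours and the loss histories are omitted):

    import torch

    def save_checkpoint(path, epoch, model, optimizer, best_score):
        torch.save({"epoch": epoch,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "best_score": best_score}, path)

    def resume(path, model, optimizer):
        ckpt = torch.load(path, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        return ckpt["epoch"] + 1, ckpt["best_score"]  # next epoch to run, best score so far
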
/data/code/scripts/train_jointBert.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 14:41
4 | # @Author : JJkinging
5 | # @File : main.py
6 |
7 | import os
8 | import warnings
9 |
10 | import torch
11 | import random
12 | import numpy as np
13 | from torch.utils.data import DataLoader
14 | from transformers import get_linear_schedule_with_warmup
15 | from data.code.scripts.config_jointBert import Config
16 | from data.code.scripts.dataset import CCFDataset
17 | from data.code.model.JointBertModel import JointBertModel
18 | from data.code.scripts.utils import train, valid, collate_to_max_length, load_vocab
19 |
20 |
21 | def torch_seed(seed):
22 | os.environ['PYTHONHASHSEED'] = str(seed)
23 | torch.manual_seed(seed)
24 | torch.cuda.manual_seed(seed)
25 | torch.cuda.manual_seed_all(seed)
26 | np.random.seed(seed) # Numpy module.
27 | random.seed(seed) # Python random module.
28 | torch.manual_seed(seed)
29 | torch.backends.cudnn.benchmark = False
30 | torch.backends.cudnn.deterministic = True
31 | torch.backends.cudnn.enabled = False
32 |
33 |
34 | def main():
35 | warnings.filterwarnings("ignore")
36 | torch_seed(1000)
37 |     # select the visible GPUs
38 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
39 | config = Config()
40 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu")
41 | print('loading corpus')
42 | vocab = load_vocab(config.vocab_file)
43 | intent_dict = load_vocab(config.intent_label_file)
44 | slot_none_dict = load_vocab(config.slot_none_vocab)
45 | slot_dict = load_vocab(config.slot_label)
46 | intent_tagset_size = len(intent_dict)
47 | slot_none_tag_size = len(slot_none_dict)
48 | slot_tag_size = len(slot_dict)
49 | train_dataset = CCFDataset(config.train_file, config.train_intent_file, config.train_slot_filename,
50 | config.train_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict,
51 | config.max_length)
52 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size,
53 | collate_fn=collate_to_max_length)
54 |
55 | dev_dataset = CCFDataset(config.dev_file, config.dev_intent_file, config.dev_slot_filename,
56 | config.dev_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict,
57 | config.max_length)
58 | dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=config.batch_size,
59 | collate_fn=collate_to_max_length)
60 |
61 | model = JointBertModel(config.bert_model_path,
62 | config.bert_hidden_size,
63 | intent_tagset_size,
64 | slot_none_tag_size,
65 | slot_tag_size,
66 | device).to(device)
67 | # model = torch.nn.DataParallel(model, device_ids=[0, 1], output_device=[0])
68 |
69 | I_criterion = torch.nn.CrossEntropyLoss()
70 | N_criterion = torch.nn.BCEWithLogitsLoss()
71 |
72 |     crf_params = list(map(id, model.CRF.parameters())) # collect the ids of the CRF-layer parameters
73 |     other_params = filter(lambda x: id(x) not in crf_params, model.parameters()) # keep every parameter that is not part of the CRF layer
74 |
75 | optimizer = torch.optim.AdamW([{'params': model.CRF.parameters(), 'lr': config.crf_lr},
76 | {'params': other_params, 'lr': config.lr}], weight_decay=config.weight_decay)
77 | # optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0.0)
78 | total_step = len(train_loader) // config.batch_size * config.epochs
79 | scheduler = get_linear_schedule_with_warmup(optimizer,
80 | num_warmup_steps=0,
81 | num_training_steps=total_step)
82 |
83 | best_score = 0.0
84 | start_epoch = 1
85 | # Data for loss curves plot.
86 | epochs_count = []
87 | train_losses = []
88 | valid_losses = []
89 |
90 | # Continuing training from a checkpoint if one was given as argument.
91 | if config.checkpoint:
92 | checkpoint = torch.load(config.checkpoint)
93 | start_epoch = checkpoint["epoch"] + 1
94 | best_score = checkpoint["best_score"]
95 |
96 | print("\t* Training will continue on existing model from epoch {}..."
97 | .format(start_epoch))
98 |
99 | model.load_state_dict(checkpoint["model"])
100 | optimizer.load_state_dict(checkpoint["optimizer"])
101 | epochs_count = checkpoint["epochs_count"]
102 | train_losses = checkpoint["train_losses"]
103 | valid_losses = checkpoint["valid_losses"]
104 |
105 | # Compute loss and accuracy before starting (or resuming) training.
106 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model,
107 | dev_loader,
108 | I_criterion,
109 | N_criterion)
110 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}"
111 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc))
112 |
113 | # -------------------- Training epochs ------------------- #
114 | print("\n",
115 | 20 * "=",
116 |           "Training model on device: {}".format(device),
117 | 20 * "=")
118 | for epoch in range(start_epoch, config.epochs+1):
119 | epochs_count.append(epoch)
120 | print("* Training epoch {}:".format(epoch))
121 | train_time, train_loss = train(model,
122 | train_loader,
123 | optimizer,
124 | I_criterion,
125 | N_criterion,
126 | config.max_grad_norm)
127 | train_losses.append(train_loss)
128 | print("-> Training time: {:.4f}s loss = {:.4f}"
129 | .format(train_time, train_loss))
130 | with open('../../user_data/output_model/JointBert/metric.txt', 'a', encoding='utf-8') as fp:
131 | fp.write('Epoch:' + str(epoch) + '\t' + 'Loss:' + str(round(train_loss, 4)) + '\t')
132 |
133 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model,
134 | dev_loader,
135 | I_criterion,
136 | N_criterion,
137 | )
138 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}"
139 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc))
140 | with open('../../user_data/output_model/JointBert/metric.txt', 'a', encoding='utf-8') as fp:
141 | fp.write('Loss:' + str(round(valid_loss, 4)) + '\t' + 'Intent_acc:' +
142 | str(round(intent_accuracy, 4)) + '\t' + 'slot_none:' +
143 | str(round(slot_none[2], 4)) + '\t''slot_F1:' + str(round(slot[2], 4)) + '\t'
144 | + 'SEN_ACC:' + str(round(sen_acc, 4)) + '\n')
145 |
146 |         valid_losses.append(valid_loss)
147 | # Update the optimizer's learning rate with the scheduler.
148 | scheduler.step()
149 |
150 |         # Save the checkpoint with the best validation sentence accuracy.
151 | if sen_acc >= best_score:
152 | best_score = sen_acc
153 | torch.save({"epoch": epoch,
154 | "model": model.state_dict(),
155 | "best_score": best_score,
156 | "optimizer": optimizer.state_dict(),
157 | "epochs_count": epochs_count,
158 | "train_losses": train_losses,
159 | "valid_losses": valid_losses},
160 | os.path.join(config.target_dir, "bert_model_best.pth.tar"))
161 |
162 |
163 | if __name__ == "__main__":
164 | main()
165 |
--------------------------------------------------------------------------------
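
The objective optimised by every training script is the unweighted sum of three terms: cross-entropy over intents, BCE-with-logits over the multi-label slot-none head, and the negated CRF log-likelihood returned by the model's slot_loss. A toy sketch with random tensors, just to make the shapes and the composition explicit (the intent count of 22 is illustrative; 29 is the slot-none label count used in utils.py):

    import torch

    intent_logits = torch.randn(4, 22)      # [batch, num_intents]
    intent_id = torch.randint(0, 22, (4,))
    slot_none_logits = torch.randn(4, 29)   # [batch, num_slot_none_labels]
    slot_none_id = torch.randint(0, 2, (4, 29)).float()

    intent_loss = torch.nn.CrossEntropyLoss()(intent_logits, intent_id)
    slot_none_loss = torch.nn.BCEWithLogitsLoss()(slot_none_logits, slot_none_id)
    crf_loss = torch.tensor(3.2)            # stands in for model.slot_loss(slot_logits, slot_ids, mask)
    loss = intent_loss + slot_none_loss + crf_loss
    print(loss.item())
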
/data/code/scripts/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 13:09
4 | # @Author : JJkinging
5 | # @File : utils.py
6 | import time
7 | import torch
8 | from tqdm import tqdm
9 | from sklearn.metrics import precision_score, recall_score, f1_score
10 | from seqeval.metrics import precision_score as slot_precision_score, recall_score as slot_recall_score, \
11 | f1_score as slot_F1_score
12 | from data.code.scripts.config_jointBert import Config
13 |
14 |
15 | def load_vocab(vocab_file):
16 | '''construct word2id'''
17 | vocab = {}
18 | index = 0
19 | with open(vocab_file, 'r', encoding='utf-8') as fp:
20 | while True:
21 | token = fp.readline()
22 | if not token:
23 | break
24 |             token = token.strip() # strip whitespace
25 | vocab[token] = index
26 | index += 1
27 | return vocab
28 |
29 |
30 | def collate_to_max_length(batch):
31 | # input_ids, slot_ids, input_mask, intent_id, slot_none_id
32 | batch_size = len(batch)
33 | input_ids_list = []
34 | slot_ids_list = []
35 | input_mask_list = []
36 | intent_ids_list = []
37 | slot_none_ids_list = []
38 | for single_data in batch:
39 | input_ids_list.append(single_data[0])
40 | slot_ids_list.append(single_data[1])
41 | input_mask_list.append(single_data[2])
42 | intent_ids_list.append(single_data[3])
43 | slot_none_ids_list.append(single_data[4])
44 |
45 | max_length = max([len(item) for item in input_ids_list])
46 |
47 | output = [torch.full([batch_size, max_length],
48 | fill_value=0,
49 | dtype=torch.long),
50 | torch.full([batch_size, max_length],
51 | fill_value=0,
52 | dtype=torch.long),
53 | torch.full([batch_size, max_length],
54 | fill_value=0,
55 | dtype=torch.long)
56 | ]
57 |
58 | for i in range(batch_size):
59 | output[0][i][0:len(input_ids_list[i])] = torch.LongTensor(input_ids_list[i])
60 | output[1][i][0:len(slot_ids_list[i])] = torch.LongTensor(slot_ids_list[i])
61 | output[2][i][0:len(input_mask_list[i])] = torch.LongTensor(input_mask_list[i])
62 |
63 | intent_ids_list = torch.LongTensor(intent_ids_list)
64 | output.append(intent_ids_list)
65 |
66 | slot_none_ids = torch.zeros([batch_size, 29], dtype=torch.long)
67 | for i, slot_none_id in enumerate(slot_none_ids_list):
68 | for idx in slot_none_id:
69 | slot_none_ids[i][idx] = 1
70 |
71 | output.append(slot_none_ids)
72 |
73 | return output # (input_ids, slot_ids, input_mask, intent_id, slot_none_id)
74 |
75 |
76 | def train(model,
77 | dataloader,
78 | optimizer,
79 | I_criterion,
80 | N_criterion,
81 | max_gradient_norm):
82 | model.train()
83 | # device = model.module.device
84 | device = model.device
85 | epoch_start = time.time()
86 | batch_time_avg = 0.0
87 | running_loss = 0.0
88 | preds_mounts = 0
89 |
90 | tqdm_batch_iterator = tqdm(dataloader)
91 | for batch_index, batch in enumerate(tqdm_batch_iterator):
92 | batch_start = time.time()
93 | input_ids, slot_ids, input_mask, intent_id, slot_none_id = batch
94 |
95 | input_id = input_ids.to(device)
96 | slot_ids = slot_ids.to(device)
97 | input_mask = input_mask.byte().to(device)
98 | intent_id = intent_id.to(device)
99 | slot_none_id = slot_none_id.to(device)
100 |
101 | optimizer.zero_grad()
102 | intent_logits, slot_none_logits, slot_logits = model(input_id, input_mask)
103 | # intent_loss = CE_criterion(intent_logits, intent_id)
104 | intent_loss = I_criterion(intent_logits, intent_id)
105 | # slot_none_loss = BCE_criterion(slot_none_logits, slot_none_id.float())
106 | slot_none_loss = N_criterion(slot_none_logits, slot_none_id.float())
107 | slot_loss = model.slot_loss(slot_logits, slot_ids, input_mask)
108 | loss = intent_loss + slot_none_loss + slot_loss
109 | loss.backward()
110 |
111 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm)
112 |
113 | optimizer.step()
114 |
115 | batch_time_avg += time.time() - batch_start
116 | running_loss += loss.item()
117 |
118 | description = "Avg. batch proc. time: {:.4f}s, loss: {:.4f}" \
119 | .format(batch_time_avg / (batch_index + 1),
120 | running_loss / (batch_index + 1))
121 | tqdm_batch_iterator.set_description(description)
122 |
123 | epoch_time = time.time() - epoch_start
124 | epoch_loss = running_loss / len(dataloader)
125 |
126 | return epoch_time, epoch_loss
127 |
128 |
129 | def valid(model,
130 | dataloader,
131 | I_criterion,
132 | N_criterion,):
133 | config = Config()
134 | model.eval()
135 | # device = model.module.device
136 | device = model.device
137 | epoch_start = time.time()
138 | running_loss = 0.0
139 | preds_mounts = 0
140 | sen_counts = 0
141 |
142 | intent_true = []
143 | intent_pred = []
144 | slot_pre_output = []
145 | slot_true_output = []
146 | slotNone_pre_output = []
147 | slotNone_true_output = []
148 |
149 | slot_dict = load_vocab(config.slot_label)
150 | id2slot = {value: key for key, value in slot_dict.items()}
151 |
152 | with torch.no_grad():
153 | tqdm_batch_iterator = tqdm(dataloader)
154 | for _, batch in enumerate(tqdm_batch_iterator):
155 | input_ids, slot_ids, input_mask, intent_id, slot_none_id = batch
156 | batch_size = len(input_ids)
157 | input_ids = input_ids.to(device)
158 | slot_ids = slot_ids.to(device)
159 | input_mask = input_mask.byte().to(device)
160 | intent_id = intent_id.to(device)
161 | slot_none_id = slot_none_id.to(device)
162 |
163 | real_length = torch.sum(input_mask, dim=1)
164 | tmp = []
165 | i = 0
166 | for line in slot_ids.cpu().numpy().tolist():
167 | line = [id2slot[idx] for idx in line[1: real_length[i]-1]]
168 | tmp.append(line)
169 | i += 1
170 |
171 | slot_true_output.extend(tmp)
172 |
173 | intent_logits, slot_none_logits, slot_logits = model(input_ids, input_mask)
174 | # intent_loss = CE_criterion(intent_logits, intent_id)
175 | intent_loss = I_criterion(intent_logits, intent_id)
176 |
177 | # slot_none_loss = BCE_criterion(slot_none_logits, slot_none_id.float())
178 | slot_none_loss = N_criterion(slot_none_logits, slot_none_id.float())
179 | slot_none_probs = torch.sigmoid(slot_none_logits)
180 | slot_none_probs = slot_none_probs > 0.5
181 | slot_none_probs = slot_none_probs.cpu().numpy()
182 | slot_none_probs = slot_none_probs.astype(int)
183 | slot_none_probs = slot_none_probs.tolist()
184 | slotNone_pre_output.extend(slot_none_probs)
185 |
186 | slot_none_id = slot_none_id.cpu().numpy().tolist()
187 | slotNone_true_output.extend(slot_none_id)
188 |
189 | slot_loss = model.slot_loss(slot_logits, slot_ids, input_mask)
190 | out_path = model.slot_predict(slot_logits, input_mask, id2slot)
191 |             out_path = [[id2slot[idx] for idx in one_data[1:-1]] for one_data in out_path] # drop the '[START]' and '[EOS]' tags
192 | slot_pre_output.extend(out_path)
193 | loss = intent_loss + slot_none_loss + slot_loss
194 | running_loss += loss.item()
195 |
196 | # intent acc
197 | intent_probs = torch.softmax(intent_logits, dim=-1)
198 | predict_labels = torch.argmax(intent_probs, dim=-1)
199 | correct_preds = (predict_labels == intent_id).sum().item()
200 | preds_mounts += correct_preds
201 |
202 | predict_labels = predict_labels.cpu().numpy().tolist()
203 | intent_pred.extend(predict_labels)
204 | intent_id = intent_id.cpu().numpy().tolist()
205 | intent_true.extend(intent_id)
206 |
207 |     # slot-none classification precision, recall and F1
208 | slotNone_micro_acc = precision_score(slotNone_true_output, slotNone_pre_output, average='micro')
209 | slotNone_micro_recall = recall_score(slotNone_true_output, slotNone_pre_output, average='micro')
210 | slotNone_micro_f1 = f1_score(slotNone_true_output, slotNone_pre_output, average='micro')
211 |
212 |     # slot-filling precision, recall and F1
213 | slot_acc = slot_precision_score(slot_true_output, slot_pre_output)
214 | slot_recall = slot_recall_score(slot_true_output, slot_pre_output)
215 | slot_f1 = slot_F1_score(slot_true_output, slot_pre_output)
216 |
217 |     # whole-sentence (exact match) accuracy
218 | for i in range(len(intent_true)):
219 | if intent_true[i] == intent_pred[i] and slot_true_output[i] == slot_pre_output[i] and \
220 | slotNone_pre_output[i] == slotNone_true_output[i]:
221 | sen_counts += 1
222 | # fp = open('../error_list1', 'w', encoding='utf-8')
223 | # for i in range(len(intent_true)):
224 | # if intent_true[i] != intent_pred[i] or slot_true_output[i] != slot_pre_output[i] or \
225 | # slotNone_true_output[i] != slotNone_pre_output[i]:
226 | # fp.write(str(i)+'\n')
227 |
228 | slot_none = (slotNone_micro_acc, slotNone_micro_recall, slotNone_micro_f1)
229 | slot = (slot_acc, slot_recall, slot_f1)
230 |
231 | epoch_time = time.time() - epoch_start
232 | epoch_loss = running_loss / len(dataloader)
233 | intent_accuracy = preds_mounts / len(dataloader.dataset)
234 | sen_accuracy = sen_counts / len(dataloader.dataset)
235 |
236 | return epoch_time, epoch_loss, intent_accuracy, slot_none, slot, sen_accuracy
237 |
--------------------------------------------------------------------------------
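
collate_to_max_length pads every sequence in the batch to the longest example rather than to a fixed max_length, and builds a multi-hot matrix for the slot-none labels. The padding part reduces to the helper below (a simplified sketch; the real collate function also stacks the intent ids and the slot-none matrix):

    import torch

    def pad_batch(sequences, pad_value=0):
        # Pad variable-length id lists to the batch maximum, as collate_to_max_length
        # does for the input ids, the slot ids and the attention mask.
        max_len = max(len(s) for s in sequences)
        out = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
        for i, s in enumerate(sequences):
            out[i, :len(s)] = torch.tensor(s, dtype=torch.long)
        return out

    print(pad_batch([[1, 2, 3], [4, 5]]))
    # tensor([[1, 2, 3],
    #         [4, 5, 0]])
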
/data/code/preprocess/process_rawdata.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/8/30 23:32
4 | # @Author : JJkinging
5 | # @File : process_rawdata.py
6 | import json
7 | from transformers import AutoTokenizer
8 |
9 |
10 | def find_idx(text, value, dic, mode):
11 | '''
12 |     Return the start index of value inside text.
13 |     :param text: list of tokens
14 |     :param value: list of tokens
15 |     :param dic: synonym dictionary, only used when mode=='mohu'
16 |     :param mode: 'mohu' switches to fuzzy matching against the synonym dictionary
17 | :return:
18 | '''
19 |
20 | if mode == 'mohu':
21 |         value = ''.join(value) # join the token list into a string
22 | if value not in dic.keys():
23 | return -1
24 | candidate = dic[value]
25 | flag = False
26 | index = -1
27 | for ca_value in candidate:
28 | for idx, item in enumerate(text):
29 |                 if item == ca_value[0]: # match on the first character
30 | index = idx
31 | for i in range(len(ca_value)):
32 | if text[idx + i] == ca_value[i]:
33 | flag = True
34 | if i == len(ca_value) - 1:
35 | return index
36 | else:
37 | index = -1
38 | flag = False
39 | break
40 | else:
41 | flag = False
42 | index = -1
43 | for idx, item in enumerate(text):
44 |         if item == value[0]: # match on the first character
45 | index = idx
46 | for i in range(len(value)):
47 | if text[idx + i] == value[i]:
48 | flag = True
49 | if i == len(value) - 1:
50 | return index
51 | else:
52 | index = -1
53 | flag = False
54 | break
55 | if flag:
56 | return index
57 | else:
58 | return -1
59 |
60 |
61 | def load_reverse_vocab(label_file):
62 | '''construct id2word'''
63 | vocab = {}
64 | index = 0
65 | with open(label_file, 'r', encoding='utf-8') as fp:
66 | while True:
67 | token = fp.readline()
68 | if not token:
69 | break
70 |             token = token.strip() # strip whitespace
71 | vocab[index] = token
72 | index += 1
73 | return vocab
74 |
75 |
76 | def fun(filename,
77 | seq_in_target_file,
78 | seq_out_target_file,
79 | slot_none_target_file,
80 | region_dict_file,
81 | vocab):
82 | # tokenizer = AutoTokenizer.from_pretrained('nghuyong/ernie-1.0')
83 | tokenizer = AutoTokenizer.from_pretrained('nghuyong/ernie-1.0')
84 |
85 | with open(filename, 'r', encoding='utf-8') as fp:
86 | raw_data = json.load(fp)
87 |
88 | with open(region_dict_file, 'r', encoding='utf-8') as fp:
89 | region_dic = json.load(fp)
90 |
91 | seq_in = open(seq_in_target_file, 'a', encoding='utf-8')
92 | seq_out = open(seq_out_target_file, 'a', encoding='utf-8')
93 | slot_none_out = open(slot_none_target_file, 'a', encoding='utf-8')
94 | label = open('../new_data/label', 'a', encoding='utf-8')
95 | query = open('../new_data/query', 'a', encoding='utf-8')
96 | command = open('../new_data/command', 'a', encoding='utf-8')
97 | play_mode = open('../new_data/play_mode', 'a', encoding='utf-8')
98 | index = open('../new_data/index', 'a', encoding='utf-8')
99 |
100 | intent_label = set()
101 | slot_label = set()
102 | command_type = set()
103 | index_type = set()
104 | play_mode_type = set()
105 | query_type = set()
106 |
107 | count = 0
108 | raw_data_tmp = raw_data.copy()
109 | for dataname, single_data in raw_data.items():
110 | flag = True
111 | text = single_data['text']
112 |         if text[-1] == '。': # drop the trailing full stop
113 | text = text[:-1]
114 |         text_id = tokenizer.encode(text) # includes the special token ids 1 and 2
115 | text = [vocab[idx] for idx in text_id[1:-1]]
116 | intent = single_data['intent']
117 | slots = single_data['slots']
118 | slot_tags_str = ('O ' * len(text)).strip(' ')
119 |         slot_tags = [item for item in slot_tags_str] # initialise every slot tag to 'O' (character list, spaces included)
120 | # label.write(intent+'\n')
121 | # intent_label.add(intent)
122 |
123 | slot_none_str = ''
124 |
125 | for slot, value in slots.items():
126 | if slot == 'command':
127 | command_type.add(str(value))
128 | if isinstance(value, list):
129 |                     for item in value: # value is a list; item is a string slot value
130 | slot_none_str += item + ' '
131 | else:
132 | slot_none_str += value + ' '
133 | # command.write(str(value)+'\n')
134 | elif slot == 'index':
135 | index_type.add(str(value))
136 | if isinstance(value, list):
137 |                     for item in value: # value is a list; item is a string slot value
138 | slot_none_str += item + ' '
139 | else:
140 | slot_none_str += value + ' '
141 | # index.write(str(value)+'\n')
142 | elif slot == 'play_mode':
143 | play_mode_type.add(str(value))
144 | if isinstance(value, list):
145 |                     for item in value: # value is a list; item is a string slot value
146 | slot_none_str += item + ' '
147 | else:
148 | slot_none_str += value + ' '
149 | # play_mode.write(str(value)+'\n')
150 | elif slot == 'query_type':
151 | query_type.add(str(value))
152 | if isinstance(value, list):
153 |                     for item in value: # value is a list; item is a string slot value
154 | slot_none_str += item + ' '
155 | else:
156 | slot_none_str += value + ' '
157 | # query.write(str(value)+'\n')
158 | else:
159 | # query.write('None'+'\n')
160 | # command.write('None'+'\n')
161 | # play_mode.write('None'+'\n')
162 | # index.write('None'+'\n')
163 | slot_label.add(slot)
164 |                 if isinstance(value, list): # the slot has more than one value, e.g. "artist": ["陈博","嘉洋"]
165 | for v in value:
166 | v = [vocab[idx] for idx in tokenizer.encode(v)[1:-1]]
167 | start_idx = find_idx(text, v, region_dic, mode='normal')
168 | if start_idx == -1:
169 |                             start_idx = find_idx(text, v, region_dic, mode='mohu') # second, fuzzy pass
170 | if start_idx == -1:
171 | flag = False
172 |                                 print("slot value not found:", slot, v, dataname)
173 | raw_data_tmp.pop(dataname)
174 | count += 1
175 | break
176 | tag_len = len(v)
177 | if tag_len == 1:
178 | slot_tags[start_idx * 2] = 'B-' + slot
179 | else:
180 | for i in range(tag_len):
181 | if i == 0:
182 | slot_tags[(start_idx + i) * 2] = 'B-' + slot
183 | else:
184 | slot_tags[(start_idx + i) * 2] = 'I-' + slot
185 |
186 |                 else: # the slot value is a single string
187 | value = [vocab[idx] for idx in tokenizer.encode(value)[1:-1]]
188 |                     if not value: # value is empty
189 | break
190 |                     start_idx = find_idx(text, value, region_dic, mode='normal') # start index of the slot span
191 |                     if start_idx == -1: # value not found in the sentence
192 |                         start_idx = find_idx(text, value, region_dic, mode='mohu') # second, fuzzy pass
193 | if start_idx == -1:
194 | flag = False
195 |                             print("slot value not found:", slot, value, dataname)
196 | raw_data_tmp.pop(dataname)
197 | count += 1
198 | break
199 |                     tag_len = len(value) # length of the slot span
200 | if tag_len == 1:
201 | slot_tags[start_idx * 2] = 'B-' + slot
202 | else:
203 | for i in range(tag_len):
204 | if i == 0:
205 | slot_tags[(start_idx + i) * 2] = 'B-' + slot
206 | else:
207 | slot_tags[(start_idx + i) * 2] = 'I-' + slot
208 |
209 | if flag:
210 |             if slot_none_str == '': # an empty slot_none_str means the sentence has no slot_none slots
211 | slot_none_out.write('None' + '\n')
212 | else:
213 | slot_none_out.write(slot_none_str + '\n')
214 |
215 |             slot_tags = ''.join(slot_tags) # join slot_tags back into a str
216 | seq_out.write(slot_tags + '\n')
217 | seq_in.write(' '.join(text) + '\n')
218 |
219 | intent_label = ' '.join(intent_label)
220 | slot_label = ' '.join(slot_label)
221 |
222 | # with open('../new_data/intent_label.txt', 'a', encoding='utf-8') as fp:
223 | # fp.write(intent_label+'\n')
224 | # with open('../new_data/slot_label.txt', 'a', encoding='utf-8') as fp:
225 | # fp.write(slot_label+'\n')
226 | # with open('../new_data/command_type.txt', 'a', encoding='utf-8') as fp:
227 | # fp.write(str(command_type)+'\n')
228 | # with open('../new_data/index_type.txt', 'a', encoding='utf-8') as fp:
229 | # fp.write(str(index_type)+'\n')
230 | # with open('../new_data/play_mode_type.txt', 'a', encoding='utf-8') as fp:
231 | # fp.write(str(play_mode_type)+'\n')
232 | # with open('../new_data/query_type.txt', 'a', encoding='utf-8') as fp:
233 | # fp.write(str(query_type)+'\n')
234 |
235 | seq_in.close()
236 | seq_out.close()
237 | slot_none_out.close()
238 | with open('../small_sample/new_B/train_final_clear_del.json', 'w', encoding='utf-8') as fp:
239 | json.dump(raw_data_tmp, fp, ensure_ascii=False)
240 |     print('missing count:', count)
241 |
242 |
243 | if __name__ == "__main__":
244 | vocab_file = '../pretrained_model/erine/vocab.txt'
245 | train_filename = '../small_sample/new_B/train_final_clear_sorted.json'
246 | dev_filename = '../extend_data/last_data/dev_sorted.json'
247 | test_filename = '../raw_data/test_B_final_text.json'
248 | seq_in_target_file = '../small_sample/new_B/train_seq_in.txt'
249 | seq_out_target_file = '../small_sample/new_B/train_seq_out.txt'
250 | slot_none_target_file = '../small_sample/new_B/train_slot_none.txt'
251 | region_dict_file = '../final_data/region_dic.json'
252 | vocab = load_reverse_vocab(vocab_file)
253 | fun(filename=train_filename,
254 | seq_in_target_file=seq_in_target_file,
255 | seq_out_target_file=seq_out_target_file,
256 | slot_none_target_file=slot_none_target_file,
257 | region_dict_file=region_dict_file,
258 | vocab=vocab)
259 | # text = ['下', '周', '六', '我', '爷', '爷', '让', '我', '去', '买', '茶', '叶', ',', '记', '得', '提', '醒', '我']
260 | # value = ['我', '爷', '爷', '让', '我', '去', '买', '茶', '叶']
261 | # idx = find_idx(text, value)
262 |
--------------------------------------------------------------------------------
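
Once find_idx has located a slot value, process_rawdata.py writes B-/I- tags into a character list that interleaves tags with spaces, which is why the indices are multiplied by 2. Ignoring that interleaving, the tagging rule itself is the standard BIO scheme, sketched here on plain token positions (function name and example are illustrative):

    def bio_tags(num_tokens, spans):
        # spans maps slot name -> (start index, span length) in token positions.
        tags = ["O"] * num_tokens
        for slot, (start, span_len) in spans.items():
            tags[start] = "B-" + slot
            for i in range(1, span_len):
                tags[start + i] = "I-" + slot
        return tags

    print(bio_tags(6, {"artist": (2, 3)}))
    # ['O', 'O', 'B-artist', 'I-artist', 'I-artist', 'O']
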
/data/code/model/InteractModel_1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 14:40
4 | # @Author : JJkinging
5 | # @File : model.py
6 | import math
7 |
8 | import torch
9 | import torch.nn as nn
10 | from transformers import BertModel
11 | from data.code.model.torchcrf import CRF
12 | import torch.nn.functional as F
13 | from data.code.scripts.config_Interact1 import Config
14 |
15 |
16 | class InteractModel(nn.Module):
17 | def __init__(self, bert_model_path, bert_hidden_size, intent_tag_size, slot_none_tag_size,
18 | slot_tag_size, device):
19 | super(InteractModel, self).__init__()
20 | self.bert_model_path = bert_model_path
21 | self.bert_hidden_size = bert_hidden_size
22 | self.intent_tag_size = intent_tag_size
23 | self.slot_none_tag_size = slot_none_tag_size
24 | self.slot_tag_size = slot_tag_size
25 | self.device = device
26 | self.bert = BertModel.from_pretrained(self.bert_model_path)
27 | self.CRF = CRF(num_tags=self.slot_tag_size, batch_first=True)
28 | self.IntentClassify = nn.Linear(self.bert_hidden_size, self.intent_tag_size)
29 | self.SlotNoneClassify = nn.Linear(self.bert_hidden_size, self.slot_none_tag_size)
30 | self.SlotClassify = nn.Linear(self.bert_hidden_size, self.slot_tag_size)
31 | self.Dropout = nn.Dropout(p=0.5)
32 |
33 | self.I_S_Emb = Label_Attention(self.IntentClassify, self.SlotClassify)
34 | self.T_block1 = I_S_Block(self.IntentClassify, self.SlotClassify, self.bert_hidden_size)
35 |
36 | def forward(self, input_ids, input_mask):
37 | batch_size = input_ids.size(0)
38 | seq_len = input_ids.size(1)
39 | utter_encoding = self.bert(input_ids, input_mask)
40 | H = utter_encoding[0]
41 | pooler_out = utter_encoding[1]
42 | H = self.Dropout(H)
43 | pooler_out = self.Dropout(pooler_out)
44 |
45 | # 1. Label Attention
46 | H_I, H_S = self.I_S_Emb(H, H, input_mask)
47 | # Co-Interactive Attention Layer
48 | H_I, H_S = self.T_block1(H_I + H, H_S + H, input_mask)
49 |
50 | intent_input = F.max_pool1d((H_I + H).transpose(1, 2), H_I.size(1)).squeeze(2)
51 |
52 | intent_logits = self.IntentClassify(intent_input)
53 | slot_none_logits = self.SlotNoneClassify(pooler_out)
54 | slot_logits = self.SlotClassify(H_S + H)
55 |
56 | return intent_logits, slot_none_logits, slot_logits
57 |
58 | def slot_loss(self, feats, slot_ids, mask):
59 |         ''' Used during training.
60 |         :param feats: emission scores from the slot classification layer
61 | :param slot_ids:
62 | :param mask:
63 | :return:
64 | '''
65 | feats = feats.to(self.device)
66 | slot_ids = slot_ids.to(self.device)
67 | mask = mask.to(self.device)
68 | loss_value = self.CRF(emissions=feats,
69 | tags=slot_ids,
70 | mask=mask,
71 | reduction='mean')
72 | return -loss_value
73 |
74 | def slot_predict(self, feats, mask, id2slot):
75 | feats = feats.to(self.device)
76 | mask = mask.to(self.device)
77 | slot2id = {value: key for key, value in id2slot.items()}
78 |         # used for validation and testing
79 | out_path = self.CRF.decode(emissions=feats, mask=mask)
80 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path]
81 | for out in out_path:
82 |             for i, tag in enumerate(out): # tag is O, B-*, I-*, etc.
83 |                 if tag.startswith('I-'): # the current tag starts with I-
84 |                     if i == 0: # position 0 should be [START]
85 |                         out[i] = '[START]'
86 |                     elif out[i-1] == 'O' or out[i-1] == '[START]': # but the previous tag does not open a span
87 |                         out[i] = id2slot[slot2id[tag]-1] # correct it to the corresponding B- tag
88 |
89 | out_path = [[slot2id[idx] for idx in one_data] for one_data in out_path]
90 |
91 | return out_path
92 |
93 |
94 | class Label_Attention(nn.Module):
95 | def __init__(self, intent_emb, slot_emb):
96 | super(Label_Attention, self).__init__()
97 |
98 |         self.W_intent_emb = intent_emb.weight # [num_class, hidden_size]
99 | self.W_slot_emb = slot_emb.weight
100 |
101 | def forward(self, input_intent, input_slot, mask):
102 | intent_score = torch.matmul(input_intent, self.W_intent_emb.t())
103 | slot_score = torch.matmul(input_slot, self.W_slot_emb.t())
104 | intent_probs = nn.Softmax(dim=-1)(intent_score)
105 | slot_probs = nn.Softmax(dim=-1)(slot_score)
106 | intent_res = torch.matmul(intent_probs, self.W_intent_emb) # [bs, seq_len, hidden_size]
107 | slot_res = torch.matmul(slot_probs, self.W_slot_emb)
108 |
109 | return intent_res, slot_res
110 |
111 |
112 | class I_S_Block(nn.Module):
113 | def __init__(self, intent_emb, slot_emb, hidden_size):
114 | super(I_S_Block, self).__init__()
115 | config = Config()
116 | self.I_S_Attention = I_S_SelfAttention(hidden_size, 2 * hidden_size, hidden_size)
117 | self.I_Out = SelfOutput(hidden_size, config.attention_dropout)
118 | self.S_Out = SelfOutput(hidden_size, config.attention_dropout)
119 | self.I_S_Feed_forward = Intermediate_I_S(hidden_size, hidden_size)
120 |
121 | def forward(self, H_intent_input, H_slot_input, mask):
122 | H_slot, H_intent = self.I_S_Attention(H_intent_input, H_slot_input, mask)
123 |         H_slot = self.S_Out(H_slot, H_slot_input) # H_slot_input: output of the label attention
124 | H_intent = self.I_Out(H_intent, H_intent_input)
125 | H_intent, H_slot = self.I_S_Feed_forward(H_intent, H_slot)
126 |
127 | return H_intent, H_slot
128 |
129 |
130 | class I_S_SelfAttention(nn.Module):
131 | def __init__(self, input_size, hidden_size, out_size):
132 | super(I_S_SelfAttention, self).__init__()
133 | config = Config()
134 |
135 | self.num_attention_heads = 12
136 | self.attention_head_size = int(hidden_size / self.num_attention_heads)
137 |
138 | self.all_head_size = self.num_attention_heads * self.attention_head_size
139 | self.out_size = out_size
140 | self.query = nn.Linear(input_size, self.all_head_size)
141 | self.query_slot = nn.Linear(input_size, self.all_head_size)
142 | self.key = nn.Linear(input_size, self.all_head_size)
143 | self.key_slot = nn.Linear(input_size, self.all_head_size)
144 | self.value = nn.Linear(input_size, self.out_size)
145 | self.value_slot = nn.Linear(input_size, self.out_size)
146 | self.dropout = nn.Dropout(config.attention_dropout)
147 |
148 | def transpose_for_scores(self, x):
149 | last_dim = int(x.size()[-1] / self.num_attention_heads)
150 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, last_dim)
151 | x = x.view(*new_x_shape)
152 | return x.permute(0, 2, 1, 3)
153 |
154 | def forward(self, intent, slot, mask):
155 | extended_attention_mask = mask.unsqueeze(1).unsqueeze(2)
156 |
157 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
158 | attention_mask = (1.0 - extended_attention_mask) * -10000.0
159 |
160 | mixed_query_layer = self.query(intent)
161 | mixed_key_layer = self.key(slot)
162 | mixed_value_layer = self.value(slot)
163 |
164 | mixed_query_layer_slot = self.query_slot(slot)
165 | mixed_key_layer_slot = self.key_slot(intent)
166 | mixed_value_layer_slot = self.value_slot(intent)
167 |
168 | query_layer = self.transpose_for_scores(mixed_query_layer)
169 | query_layer_slot = self.transpose_for_scores(mixed_query_layer_slot)
170 | key_layer = self.transpose_for_scores(mixed_key_layer)
171 | key_layer_slot = self.transpose_for_scores(mixed_key_layer_slot)
172 | value_layer = self.transpose_for_scores(mixed_value_layer)
173 | value_layer_slot = self.transpose_for_scores(mixed_value_layer_slot)
174 |
175 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
176 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
177 | # attention_scores_slot = torch.matmul(query_slot, key_slot.transpose(1,0))
178 | attention_scores_slot = torch.matmul(query_layer_slot, key_layer_slot.transpose(-1, -2))
179 | attention_scores_slot = attention_scores_slot / math.sqrt(self.attention_head_size)
180 | attention_scores_intent = attention_scores + attention_mask
181 |
182 | attention_scores_slot = attention_scores_slot + attention_mask
183 |
184 | # Normalize the attention scores to probabilities.
185 | attention_probs_slot = nn.Softmax(dim=-1)(attention_scores_slot)
186 | attention_probs_intent = nn.Softmax(dim=-1)(attention_scores_intent)
187 |
188 | attention_probs_slot = self.dropout(attention_probs_slot)
189 | attention_probs_intent = self.dropout(attention_probs_intent)
190 |
191 | context_layer_slot = torch.matmul(attention_probs_slot, value_layer_slot)
192 | context_layer_intent = torch.matmul(attention_probs_intent, value_layer)
193 |
194 | context_layer = context_layer_slot.permute(0, 2, 1, 3).contiguous()
195 | context_layer_intent = context_layer_intent.permute(0, 2, 1, 3).contiguous()
196 | new_context_layer_shape = context_layer.size()[:-2] + (self.out_size,)
197 | new_context_layer_shape_intent = context_layer_intent.size()[:-2] + (self.out_size,)
198 |
199 | context_layer = context_layer.view(*new_context_layer_shape)
200 | context_layer_intent = context_layer_intent.view(*new_context_layer_shape_intent)
201 | return context_layer, context_layer_intent
202 |
203 |
204 | class SelfOutput(nn.Module):
205 | def __init__(self, hidden_size, hidden_dropout_prob):
206 | super(SelfOutput, self).__init__()
207 | self.dense = nn.Linear(hidden_size, hidden_size)
208 | self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
209 | self.dropout = nn.Dropout(hidden_dropout_prob)
210 |
211 | def forward(self, hidden_states, input_tensor):
212 | hidden_states = self.dense(hidden_states)
213 | hidden_states = self.dropout(hidden_states)
214 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
215 | return hidden_states
216 |
217 |
218 | class LayerNorm(nn.Module):
219 | def __init__(self, hidden_size, eps=1e-12):
220 | """Construct a layernorm module in the TF style (epsilon inside the square root).
221 | """
222 | super(LayerNorm, self).__init__()
223 | self.weight = nn.Parameter(torch.ones(hidden_size))
224 | self.bias = nn.Parameter(torch.zeros(hidden_size))
225 | self.variance_epsilon = eps
226 |
227 | def forward(self, x):
228 | u = x.mean(-1, keepdim=True)
229 | s = (x - u).pow(2).mean(-1, keepdim=True)
230 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
231 | return self.weight * x + self.bias
232 |
233 |
234 | class Intermediate_I_S(nn.Module):
235 | def __init__(self, intermediate_size, hidden_size):
236 | super(Intermediate_I_S, self).__init__()
237 | self.config = Config()
238 | self.dense_in = nn.Linear(hidden_size * 6, intermediate_size)
239 | self.intermediate_act_fn = nn.ReLU()
240 | self.dense_out = nn.Linear(intermediate_size, hidden_size)
241 | self.LayerNorm_I = LayerNorm(hidden_size, eps=1e-12)
242 | self.LayerNorm_S = LayerNorm(hidden_size, eps=1e-12)
243 | self.dropout = nn.Dropout(self.config.attention_dropout)
244 |
245 | def forward(self, hidden_states_I, hidden_states_S):
246 | hidden_states_in = torch.cat([hidden_states_I, hidden_states_S], dim=2)
247 | batch_size, max_length, hidden_size = hidden_states_in.size()
248 | h_pad = torch.zeros(batch_size, 1, hidden_size)
249 | if self.config.use_gpu and torch.cuda.is_available():
250 | h_pad = h_pad.cuda()
251 | h_left = torch.cat([h_pad, hidden_states_in[:, :max_length - 1, :]], dim=1)
252 | h_right = torch.cat([hidden_states_in[:, 1:, :], h_pad], dim=1)
253 | hidden_states_in = torch.cat([hidden_states_in, h_left, h_right], dim=2)
254 |
255 | hidden_states = self.dense_in(hidden_states_in)
256 | hidden_states = self.intermediate_act_fn(hidden_states)
257 | hidden_states = self.dense_out(hidden_states)
258 | hidden_states = self.dropout(hidden_states)
259 | hidden_states_I_NEW = self.LayerNorm_I(hidden_states + hidden_states_I)
260 | hidden_states_S_NEW = self.LayerNorm_S(hidden_states + hidden_states_S)
261 | return hidden_states_I_NEW, hidden_states_S_NEW
--------------------------------------------------------------------------------
/data/code/model/InteractModel_3.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/8 14:40
4 | # @Author : JJkinging
5 | # @File : model.py
6 | import math
7 |
8 | import torch
9 | import torch.nn as nn
10 | from transformers import BertModel
11 | from data.code.model.torchcrf import CRF
12 | import torch.nn.functional as F
13 | from data.code.scripts.config_Interact3 import Config
14 |
15 |
16 | class InteractModel(nn.Module):
17 | def __init__(self, bert_model_path, bert_hidden_size, intent_tag_size, slot_none_tag_size,
18 | slot_tag_size, device):
19 | super(InteractModel, self).__init__()
20 | self.bert_model_path = bert_model_path
21 | self.bert_hidden_size = bert_hidden_size
22 | self.intent_tag_size = intent_tag_size
23 | self.slot_none_tag_size = slot_none_tag_size
24 | self.slot_tag_size = slot_tag_size
25 | self.device = device
26 | self.bert = BertModel.from_pretrained(self.bert_model_path)
27 | self.CRF = CRF(num_tags=self.slot_tag_size, batch_first=True)
28 | self.IntentClassify = nn.Linear(self.bert_hidden_size, self.intent_tag_size)
29 | self.SlotNoneClassify = nn.Linear(self.bert_hidden_size, self.slot_none_tag_size)
30 | self.SlotClassify = nn.Linear(self.bert_hidden_size, self.slot_tag_size)
31 | self.Dropout = nn.Dropout(p=0.5)
32 |
33 | self.I_S_Emb = Label_Attention(self.IntentClassify, self.SlotClassify)
34 | self.T_block1 = I_S_Block(self.IntentClassify, self.SlotClassify, self.bert_hidden_size)
35 | self.T_block2 = I_S_Block(self.IntentClassify, self.SlotClassify, self.bert_hidden_size)
36 | self.T_block3 = I_S_Block(self.IntentClassify, self.SlotClassify, self.bert_hidden_size)
37 |
38 | def forward(self, input_ids, input_mask):
39 | batch_size = input_ids.size(0)
40 | seq_len = input_ids.size(1)
41 | utter_encoding = self.bert(input_ids, input_mask)
42 | H = utter_encoding[0]
43 | pooler_out = utter_encoding[1]
44 | H = self.Dropout(H)
45 | pooler_out = self.Dropout(pooler_out)
46 |
47 | # 1. Label Attention
48 | H_I, H_S = self.I_S_Emb(H, H, input_mask)
49 | # Co-Interactive Attention Layer
50 | H_I, H_S = self.T_block1(H_I + H, H_S + H, input_mask)
51 |
52 | # 2. Label Attention
53 | H_I_1, H_S_1 = self.I_S_Emb(H_I, H_S, input_mask)
54 | # # # Co-Interactive Attention Layer
55 | H_I, H_S = self.T_block2(H_I + H_I_1, H_S + H_S_1, input_mask)
56 |
57 | # 3. Label Attention
58 | H_I_2, H_S_2 = self.I_S_Emb(H_I, H_S, input_mask)
59 | # # # Co-Interactive Attention Layer
60 | H_I, H_S = self.T_block3(H_I + H_I_2, H_S + H_S_2, input_mask)
61 |
62 | intent_input = F.max_pool1d((H_I + H).transpose(1, 2), H_I.size(1)).squeeze(2)
63 |
64 | intent_logits = self.IntentClassify(intent_input)
65 | slot_none_logits = self.SlotNoneClassify(pooler_out)
66 | slot_logits = self.SlotClassify(H_S + H)
67 |
68 | return intent_logits, slot_none_logits, slot_logits
69 |
70 | def slot_loss(self, feats, slot_ids, mask):
71 | ''' Used during training.
72 | :param feats: emission scores from the slot classification layer
73 | :param slot_ids: gold slot tag ids
74 | :param mask: mask over the valid tokens
75 | :return: negative CRF log-likelihood
76 | '''
77 | feats = feats.to(self.device)
78 | slot_ids = slot_ids.to(self.device)
79 | mask = mask.to(self.device)
80 | loss_value = self.CRF(emissions=feats,
81 | tags=slot_ids,
82 | mask=mask,
83 | reduction='mean')
84 | return -loss_value
85 |
86 | def slot_predict(self, feats, mask, id2slot):
87 | feats = feats.to(self.device)
88 | mask = mask.to(self.device)
89 | slot2id = {value: key for key, value in id2slot.items()}
90 | # Used during validation and testing
91 | out_path = self.CRF.decode(emissions=feats, mask=mask)
92 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path]
93 | for out in out_path:
94 | for i, tag in enumerate(out): # tag is 'O', 'B-*', 'I-*', etc.
95 | if tag.startswith('I-'): # the current tag starts with 'I-'
96 | if i == 0: # position 0 should be [START]
97 | out[i] = '[START]'
98 | elif out[i-1] == 'O' or out[i-1] == '[START]': # but the previous tag is not the matching 'B-' tag
99 | out[i] = id2slot[slot2id[tag]-1] # correct it to the corresponding 'B-' tag
100 |
101 | out_path = [[slot2id[idx] for idx in one_data] for one_data in out_path]
102 |
103 | return out_path
104 |
105 | class Label_Attention(nn.Module):
106 | def __init__(self, intent_emb, slot_emb):
107 | super(Label_Attention, self).__init__()
108 |
109 | self.W_intent_emb = intent_emb.weight # [num_class, hidden_size]
110 | self.W_slot_emb = slot_emb.weight
111 |
112 | def forward(self, input_intent, input_slot, mask):
113 | intent_score = torch.matmul(input_intent, self.W_intent_emb.t())
114 | slot_score = torch.matmul(input_slot, self.W_slot_emb.t())
115 | intent_probs = nn.Softmax(dim=-1)(intent_score)
116 | slot_probs = nn.Softmax(dim=-1)(slot_score)
117 | intent_res = torch.matmul(intent_probs, self.W_intent_emb) # [bs, seq_len, hidden_size]
118 | slot_res = torch.matmul(slot_probs, self.W_slot_emb)
119 |
120 | return intent_res, slot_res
121 |
122 |
123 | class I_S_Block(nn.Module):
124 | def __init__(self, intent_emb, slot_emb, hidden_size):
125 | super(I_S_Block, self).__init__()
126 | config = Config()
127 | self.I_S_Attention = I_S_SelfAttention(hidden_size, 2 * hidden_size, hidden_size)
128 | self.I_Out = SelfOutput(hidden_size, config.attention_dropout)
129 | self.S_Out = SelfOutput(hidden_size, config.attention_dropout)
130 | self.I_S_Feed_forward = Intermediate_I_S(hidden_size, hidden_size)
131 |
132 | def forward(self, H_intent_input, H_slot_input, mask):
133 | H_slot, H_intent = self.I_S_Attention(H_intent_input, H_slot_input, mask)
134 | H_slot = self.S_Out(H_slot, H_slot_input) # H_slot_input: output of the label attention layer
135 | H_intent = self.I_Out(H_intent, H_intent_input)
136 | H_intent, H_slot = self.I_S_Feed_forward(H_intent, H_slot)
137 |
138 | return H_intent, H_slot
139 |
140 |
141 | class I_S_SelfAttention(nn.Module):
142 | def __init__(self, input_size, hidden_size, out_size):
143 | super(I_S_SelfAttention, self).__init__()
144 | config = Config()
145 |
146 | self.num_attention_heads = 12
147 | self.attention_head_size = int(hidden_size / self.num_attention_heads)
148 |
149 | self.all_head_size = self.num_attention_heads * self.attention_head_size
150 | self.out_size = out_size
151 | self.query = nn.Linear(input_size, self.all_head_size)
152 | self.query_slot = nn.Linear(input_size, self.all_head_size)
153 | self.key = nn.Linear(input_size, self.all_head_size)
154 | self.key_slot = nn.Linear(input_size, self.all_head_size)
155 | self.value = nn.Linear(input_size, self.out_size)
156 | self.value_slot = nn.Linear(input_size, self.out_size)
157 | self.dropout = nn.Dropout(config.attention_dropout)
158 |
159 | def transpose_for_scores(self, x):
160 | last_dim = int(x.size()[-1] / self.num_attention_heads)
161 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, last_dim)
162 | x = x.view(*new_x_shape)
163 | return x.permute(0, 2, 1, 3)
164 |
165 | def forward(self, intent, slot, mask):
166 | extended_attention_mask = mask.unsqueeze(1).unsqueeze(2)
167 |
168 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
169 | attention_mask = (1.0 - extended_attention_mask) * -10000.0
170 |
171 | mixed_query_layer = self.query(intent)
172 | mixed_key_layer = self.key(slot)
173 | mixed_value_layer = self.value(slot)
174 |
175 | mixed_query_layer_slot = self.query_slot(slot)
176 | mixed_key_layer_slot = self.key_slot(intent)
177 | mixed_value_layer_slot = self.value_slot(intent)
178 |
179 | query_layer = self.transpose_for_scores(mixed_query_layer)
180 | query_layer_slot = self.transpose_for_scores(mixed_query_layer_slot)
181 | key_layer = self.transpose_for_scores(mixed_key_layer)
182 | key_layer_slot = self.transpose_for_scores(mixed_key_layer_slot)
183 | value_layer = self.transpose_for_scores(mixed_value_layer)
184 | value_layer_slot = self.transpose_for_scores(mixed_value_layer_slot)
185 |
186 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
187 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
188 | # attention_scores_slot = torch.matmul(query_slot, key_slot.transpose(1,0))
189 | attention_scores_slot = torch.matmul(query_layer_slot, key_layer_slot.transpose(-1, -2))
190 | attention_scores_slot = attention_scores_slot / math.sqrt(self.attention_head_size)
191 | attention_scores_intent = attention_scores + attention_mask
192 |
193 | attention_scores_slot = attention_scores_slot + attention_mask
194 |
195 | # Normalize the attention scores to probabilities.
196 | attention_probs_slot = nn.Softmax(dim=-1)(attention_scores_slot)
197 | attention_probs_intent = nn.Softmax(dim=-1)(attention_scores_intent)
198 |
199 | attention_probs_slot = self.dropout(attention_probs_slot)
200 | attention_probs_intent = self.dropout(attention_probs_intent)
201 |
202 | context_layer_slot = torch.matmul(attention_probs_slot, value_layer_slot)
203 | context_layer_intent = torch.matmul(attention_probs_intent, value_layer)
204 |
205 | context_layer = context_layer_slot.permute(0, 2, 1, 3).contiguous()
206 | context_layer_intent = context_layer_intent.permute(0, 2, 1, 3).contiguous()
207 | new_context_layer_shape = context_layer.size()[:-2] + (self.out_size,)
208 | new_context_layer_shape_intent = context_layer_intent.size()[:-2] + (self.out_size,)
209 |
210 | context_layer = context_layer.view(*new_context_layer_shape)
211 | context_layer_intent = context_layer_intent.view(*new_context_layer_shape_intent)
212 | return context_layer, context_layer_intent
213 |
214 |
215 | class SelfOutput(nn.Module):
216 | def __init__(self, hidden_size, hidden_dropout_prob):
217 | super(SelfOutput, self).__init__()
218 | self.dense = nn.Linear(hidden_size, hidden_size)
219 | self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
220 | self.dropout = nn.Dropout(hidden_dropout_prob)
221 |
222 | def forward(self, hidden_states, input_tensor):
223 | hidden_states = self.dense(hidden_states)
224 | hidden_states = self.dropout(hidden_states)
225 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
226 | return hidden_states
227 |
228 |
229 | class LayerNorm(nn.Module):
230 | def __init__(self, hidden_size, eps=1e-12):
231 | """Construct a layernorm module in the TF style (epsilon inside the square root).
232 | """
233 | super(LayerNorm, self).__init__()
234 | self.weight = nn.Parameter(torch.ones(hidden_size))
235 | self.bias = nn.Parameter(torch.zeros(hidden_size))
236 | self.variance_epsilon = eps
237 |
238 | def forward(self, x):
239 | u = x.mean(-1, keepdim=True)
240 | s = (x - u).pow(2).mean(-1, keepdim=True)
241 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
242 | return self.weight * x + self.bias
243 |
244 |
245 | class Intermediate_I_S(nn.Module):
246 | def __init__(self, intermediate_size, hidden_size):
247 | super(Intermediate_I_S, self).__init__()
248 | self.config = Config()
249 | self.dense_in = nn.Linear(hidden_size * 6, intermediate_size)
250 | self.intermediate_act_fn = nn.ReLU()
251 | self.dense_out = nn.Linear(intermediate_size, hidden_size)
252 | self.LayerNorm_I = LayerNorm(hidden_size, eps=1e-12)
253 | self.LayerNorm_S = LayerNorm(hidden_size, eps=1e-12)
254 | self.dropout = nn.Dropout(self.config.attention_dropout)
255 |
256 | def forward(self, hidden_states_I, hidden_states_S):
257 | hidden_states_in = torch.cat([hidden_states_I, hidden_states_S], dim=2)
258 | batch_size, max_length, hidden_size = hidden_states_in.size()
259 | h_pad = torch.zeros(batch_size, 1, hidden_size)
260 | if self.config.use_gpu and torch.cuda.is_available():
261 | h_pad = h_pad.cuda()
262 | h_left = torch.cat([h_pad, hidden_states_in[:, :max_length - 1, :]], dim=1)
263 | h_right = torch.cat([hidden_states_in[:, 1:, :], h_pad], dim=1)
264 | hidden_states_in = torch.cat([hidden_states_in, h_left, h_right], dim=2)
265 |
266 | hidden_states = self.dense_in(hidden_states_in)
267 | hidden_states = self.intermediate_act_fn(hidden_states)
268 | hidden_states = self.dense_out(hidden_states)
269 | hidden_states = self.dropout(hidden_states)
270 | hidden_states_I_NEW = self.LayerNorm_I(hidden_states + hidden_states_I)
271 | hidden_states_S_NEW = self.LayerNorm_S(hidden_states + hidden_states_S)
272 | return hidden_states_I_NEW, hidden_states_S_NEW
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 2021 CCF BDCI CCIR-Cup (National Information Retrieval Challenge): 2nd-Place Solution for the Intelligent Human-Machine Interaction NLU Track (Intent Recognition & Slot Filling)
2 | Competition page: CCIR-Cup, Intelligent Human-Machine Interaction Natural Language Understanding
3 | 
4 | ## 1. Dependencies
5 |
6 | * python==3.8
7 | * torch==1.7.1+cu110
8 | * numpy==1.19.2
9 | * transformers==4.5.1
10 | * scikit_learn==1.0
11 | * seqeval==1.2.2
12 | * tqdm==4.50.2
13 | * CUDA==11.0
14 |
15 | ## 2. Solution
16 |
17 | ### 1. Data preprocessing
18 |
19 | 1. First, clarify the task. The competition task is intent recognition and slot filling. After inspecting the training data, I found that slot filling can be decomposed into two kinds of sub-task: standard slot filling, where the slot value can be matched exactly inside the current utterance, and non-standard slot filling (my own term), where the slot value cannot be found in, or exactly matched against, the current utterance. Non-standard slot filling is treated as another classification task. The competition is therefore handled as three sub-tasks: intent recognition, standard slot filling, and non-standard slot filling (classification).
20 |
21 | 2. For standard slot filling, most slot values in the training data can be matched exactly, but a small number cannot. For example, when the utterance contains words such as 港片, 韩剧, 美剧, 内地, 英文 or 中文, the annotated slot values are 香港, 韩国, 美国, 大陆, 英语 or 华语 respectively. Treating these cases as non-standard slot filling would be unreasonable. The solution: prepare a mapping dictionary of such special words in advance, region_dic.json; when a slot value that appears in this dictionary is encountered while processing the training data, it is converted before the slot is annotated.
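
Below is a minimal sketch of how such a mapping can be used, assuming the canonical-value-to-surface-forms layout that the prediction scripts in this repository read from region_dic.json; the helper names and the example entries are illustrative only:

```python
def to_canonical(predicted_value, region_dict):
    """Inference time: map an extracted surface form (e.g. 港片) back to the annotated value (e.g. 香港)."""
    for canonical, surface_forms in region_dict.items():
        if predicted_value in surface_forms:
            return canonical
    return predicted_value

def to_surface_form(text, annotated_value, region_dict):
    """Training time: if the annotated value is not in the text, look for one of its surface forms instead."""
    if annotated_value in text:
        return annotated_value
    for surface in region_dict.get(annotated_value, []):
        if surface in text:
            return surface
    return annotated_value

region_dict = {"香港": ["港片", "港剧"], "韩国": ["韩剧"]}      # illustrative entries only
print(to_canonical("港片", region_dict))                        # -> 香港
print(to_surface_form("帮我找一部港片", "香港", region_dict))    # -> 港片
```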
22 |
23 | 3. Then the utterance is split character by character with the ernie Tokenizer (so ## word pieces and [UNK] may appear), and the tokenized sentence is used as the raw input. Standard slot filling is annotated in BIO format. Inspecting the training data shows that the standard slot-filling task contains a small number of nested named entities, for example:
24 |
25 | ```json
26 | {
27 | "NLU07301": {
28 | "text": "西安WE比赛的视频帮我找一下",
29 | "intent": "Video-Play",
30 | "slots": {
31 | "name": "西安WE比赛的视频",
32 | "region": "西安"
33 | }
34 | }
35 | }
36 | ```
37 |
38 | For these few nested cases I did not design a special network structure. Instead I used a simple trick: first sort all slot values under each sample's slots by length from longest to shortest, then annotate the sequence in that order. The example above is annotated as follows:
39 |
40 | First the "name" slot label:
41 |
42 | ```
43 | B-name I-name I-name I-name I-name I-name I-name I-name I-name O O O O O
44 | ```
45 |
46 | Then the "region" slot label:
47 |
48 | ```
49 | B-region I-region I-name I-name I-name I-name I-name I-name I-name O O O O O
50 | ```
51 |
52 | "region" 标注完成后将 B-name I-name 这两个tag给覆盖掉了。
53 |
54 | 最后,对于这种嵌套实体,模型就按照这种标注方式去训练;在解码的时候,按照一定的匹配规则识别出"name" 和 "region"这两个槽标签,后面的实验中表明使用这种标注方式能够有效的识别出测试集数据中存在的嵌套实体。
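
Below is a minimal, character-level sketch of this overwrite-style BIO labeling (illustrative only; the repository's own logic lives in preprocess/slot_sorted.py and process_rawdata.py and may differ in detail):

```python
def bio_annotate(text, slots):
    """Label longer slot values first so that shorter (nested) slots overwrite them."""
    tags = ["O"] * len(text)
    for key, value in sorted(slots.items(), key=lambda kv: len(kv[1]), reverse=True):
        start = text.find(value)
        if start == -1:
            continue  # value not literally present (non-standard slot), skipped in this sketch
        tags[start] = "B-" + key
        for i in range(start + 1, start + len(value)):
            tags[i] = "I-" + key
    return tags

print(bio_annotate("西安WE比赛的视频帮我找一下",
                   {"name": "西安WE比赛的视频", "region": "西安"}))
# -> ['B-region', 'I-region', 'I-name', ..., 'O', 'O']
```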
55 |
56 | 4. Manual correction: some training samples carry obvious annotation errors. For example, NLU00400 has "artist": "银临,河图", whereas the correct annotation should be "artist": ["银临", "河图"]; "query_type": "汽车票查询" and "query_type": "汽车票" are clearly the same thing, so they are unified. Another example:
57 |
58 | ```json
59 | {
60 | "NLU04386": {
61 | "text": "明天早上7:20你会不会通知我带洗衣液",
62 | "intent": "Alarm-Update",
63 | "slots": {
64 | "datetime_date": "明天",
65 | "notes": "带洗衣液",
66 | "datetime_time": "早上7"
67 | }
68 | }
69 | }
70 | ```
71 |
72 | "datetime_time":"早上7" 标注错误,正确标注应该为"datetime_time": "早上7:20",且类似这种的错误在"Alarm-Update"这个意图中大量存在。如果不进行纠正,则对模型的训练会造成很大的影响。
73 |
74 | 5. For the non-standard slot-filling part, four slot labels were identified: command, index, play_mode and query_type. Their slot values are treated as classes, 29 classes in total, e.g.:
75 |
76 | ```
77 | 音量调节
78 | 查询状态
79 | 汽车票查询
80 | 穿衣指数
81 | 紫外线指数
82 | ...
83 | None
84 | ```
85 |
86 | None means that no value of this kind is present. These classes are then predicted with a classifier, much like intent recognition.
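
Below is a minimal sketch of how these 29 classes can be decoded at inference time; the sigmoid-plus-threshold decoding mirrors what the run_*.py scripts in data/code/predict do, while the example indices (other than 28 = None) are illustrative:

```python
import torch

def decode_slot_none(slot_none_logits, none_index=28):
    """Multi-label decoding: every class whose probability exceeds 0.5 fires; otherwise None."""
    probs = torch.sigmoid(slot_none_logits)
    hits = (probs > 0.5).nonzero(as_tuple=True)[0]
    return hits.tolist() if len(hits) > 0 else [none_index]

logits = torch.full((29,), -2.0)
logits[4], logits[5] = 3.0, 2.5          # e.g. two command-type values fire
print(decode_slot_none(logits))          # -> [4, 5]
```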
87 |
88 | 6. Out-of-domain (Other) detection: in the A-leaderboard stage I sampled roughly 1000 sentences from the LCQMC dataset as the source of Other data; in the B-leaderboard stage I took the A-stage test samples predicted as Other, plus another 500 LCQMC sentences, and used them together as the Other class of the training set.
89 |
90 | 7. Few-shot intents: Audio-Play and TVProgram-Play are few-shot intents, with only 50 samples each in the original training set. Solution: augment the few-shot intent data until its size approaches the roughly 1000 samples of the base-task intents, augmenting Audio-Play and TVProgram-Play separately. Taking Audio-Play as an example, its possible slot labels are listed in the continuation below, after the intent-to-slot mapping.
91 |
92 | 8. After the model produces its output, a post-processing step corrects the predictions by deleting slot values that cannot belong to the predicted intent. First, the relation between intents and slot labels is collected from the training data:
93 | ```json
94 | {
95 | "FilmTele-Play": ["name", "tag", "artist", "region", "play_setting", "age"],
96 | "Audio-Play": ["language", "artist", "tag", "name", "play_setting"],
97 | "Radio-Listen": ["name", "channel", "frequency", "artist"],
98 | "TVProgram-Play": ["name", "channel", "datetime_date", "datetime_time"],
99 | "Travel-Query": ["query_type", "datetime_date", "departure", "destination", "datetime_time"],
100 | "Music-Play": ["language", "artist", "album", "instrument", "song", "play_mode", "age"],
101 | "HomeAppliance-Control": ["command", "appliance", "details"],
102 | "Calendar-Query": ["datetime_date"],
103 | "Alarm-Update": ["notes", "datetime_date", "datetime_time"],
104 | "Video-Play": ["name", "datetime_date", "region", "datetime_time"],
105 | "Weather-Query": ["datetime_date", "type", "city", "index", "datetime_time"],
106 | "Other": []
107 | }
108 | ```
109 | This shows that each intent can only contain a specific set of slot labels, and that certain labels never occur under other intents.
110 | For example, suppose a test sample is predicted with intent = FilmTele-Play, but its slot prediction contains the "notes" label. That contradicts the statistics above (in the training data the FilmTele-Play intent never contains the "notes" slot label), so the post-processing function deletes the "notes" slot and its value.
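
Below is a minimal sketch of this filtering rule (the repository's actual implementation is data/code/predict/post_process.py; the function and variable names here are illustrative):

```python
def filter_slots(prediction, intent_slot_mapping):
    """Drop any predicted slot label that never co-occurs with the predicted intent in training."""
    allowed = set(intent_slot_mapping.get(prediction["intent"], []))
    prediction["slots"] = {k: v for k, v in prediction["slots"].items() if k in allowed}
    return prediction

mapping = {"FilmTele-Play": ["name", "tag", "artist", "region", "play_setting", "age"]}
pred = {"intent": "FilmTele-Play", "slots": {"name": "长津湖", "notes": "带洗衣液"}}
print(filter_slots(pred, mapping))  # the "notes" slot is removed
```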
111 |
112 | "Audio-Play": ["language", "artist", "tag", "name", "play_setting"], 一共五类,分别在训练数据中统计出各个槽标签可能出现的槽值有哪些,比如 "language": ["日语", "英语", "法语", "俄语", "西班牙语", "华语", "韩语", "德语", "藏语"] ,language 有这些可选项,当然也可以自己随便添加几个合适的选项。得到了一个这种字典后,对于每一条原训练数据中意图为 Audio-Play 的数据进行扩充,每一条扩充20条新数据,具体扩充方式:对于原数据而言,如果某一个槽标签出现,则在该槽标签对应的标签候选项中随机选择一个替换它,对原数据存在的每一个槽标签都进行这种操作,这样就增加了一条”新数据“,经过实验验证,小样本数据扩充前后,线上得分提升了两个点,证明这种方式还是效果不错的。
113 | 9. In the A-leaderboard stage the training data consisted of three parts: the original training data, the augmented few-shot intent data, and about 1000 LCQMC sentences used as out-of-domain data. In the B-leaderboard stage, in addition to the same parts as in stage A, the model's predictions on the A-stage test set were used as pseudo-labels for that set and added to the training data before predicting the B-stage test set, so the final training set in the B stage contained 13119 samples.
114 |
115 | ### 2. Models and algorithms
116 |
117 | I trained three models in total for this competition; the final result.json is produced by a vote over the three models' outputs.
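
Below is a minimal sketch of a majority vote over the three models' predictions (the repository's actual logic is data/code/predict/integration.py; this simplified version keeps a slot only when at least two models agree on it):

```python
import json
from collections import Counter

def vote(predictions):
    """predictions: the same sample's {"intent": ..., "slots": {...}} dict from each of the three models."""
    intent = Counter(p["intent"] for p in predictions).most_common(1)[0][0]
    slot_counts = Counter()
    for p in predictions:
        for key, value in p["slots"].items():
            slot_counts[(key, json.dumps(value, ensure_ascii=False))] += 1
    slots = {key: json.loads(value) for (key, value), n in slot_counts.items() if n >= 2}
    return {"intent": intent, "slots": slots}
```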
118 |
119 | #### 1. JointErnie model
120 |
121 | Approach based on: [BERT for Joint Intent Classification and Slot Filling](http://arxiv.org/abs/1902.10909v1)
122 |
123 | The pre-trained model is not the Chinese bert base but Baidu's Chinese ernie-1.0 base, and the three tasks are trained jointly. For intent recognition and non-standard slot filling, the ernie output is fed into one fully connected layer each for classification; for standard slot filling, the ernie output is passed on to a CRF layer. The ernie encoder and the CRF layer use different learning rates: 5e-5 for ernie and 5e-2 for the CRF layer.
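
Below is a minimal sketch of the two-learning-rate setup, assuming an AdamW optimizer (the optimizer choice and the helper name are assumptions; the CRF attribute name matches the models in data/code/model):

```python
import torch

def build_optimizer(model):
    """5e-2 for the CRF transition parameters, 5e-5 for the ernie encoder and the classification heads."""
    crf_params = list(model.CRF.parameters())
    crf_ids = {id(p) for p in crf_params}
    other_params = [p for p in model.parameters() if id(p) not in crf_ids]
    return torch.optim.AdamW([
        {"params": other_params, "lr": 5e-5},
        {"params": crf_params, "lr": 5e-2},
    ])
```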
124 |
125 | #### 2. InteractModel_1 model
126 |
127 | Approach based on: [A CO-INTERACTIVE TRANSFORMER FOR JOINT SLOT FILLING AND INTENT DETECTION](http://arxiv.org/abs/2010.03880v3)
128 |
129 | The three tasks are still trained jointly; the difference is that intent recognition and standard slot filling interact through one interaction layer, while non-standard slot filling does not interact with the other two tasks: the pre-trained model's pooled_output is fed directly into a fully connected layer for classification.
130 |
131 | ***Interaction layer:***
132 |
133 | The Chinese ernie-1.0 base pre-trained model is used as the backbone:
134 |
135 | given the input token sequence $\{x_1, x_2, \dots, x_n\}$ (where $n$ is the number of tokens), feeding it into Chinese ernie-1.0 yields the hidden states $H = \{h_1, h_2, \dots, h_n\}$.
136 |
137 | Then comes the interaction between intent detection and standard slot filling; this model uses a single interaction layer:
138 |
139 | **Label attention layer**
140 |
141 | 
142 |
143 | **Co-interactive attention layer**
144 |
145 | 
146 |
147 | **Feed-forward network layer**
148 |
149 | 
150 |
151 | ***Decoder layer***
152 |
153 | 
154 |
155 | #### 3. InteractModel_3 model
156 |
157 | The structure of InteractModel_3 is similar to InteractModel_1 described above; the only difference is that InteractModel_3 uses three interaction layers between intent recognition and standard slot filling.
158 | **For instructions on reproducing from the Docker image, see ccir/image/README.md**
159 |
160 | ## 3. Project directory structure
161 | ```
162 | ccir
163 | |-- data —— data folder
164 | | |-- code —— folder containing all the code
165 | | | |-- __init__.py
166 | | | |-- model —— network model code
167 | | | | |-- __init__.py
168 | | | | |-- InteractModel_1.py —— InteractModel_1 network model
169 | | | | |-- InteractModel_3.py —— InteractModel_3 network model
170 | | | | |-- JointBertModel.py —— JointBertModel network model
171 | | | | |-- __pycache__
172 | | | | | |-- __init__.cpython-38.pyc
173 | | | | | |-- InteractModel_1.cpython-38.pyc
174 | | | | | |-- InteractModel_3.cpython-38.pyc
175 | | | | | |-- JointBertModel.cpython-38.pyc
176 | | | | | `-- torchcrf.cpython-38.pyc
177 | | | | `-- torchcrf.py —— CRF layer implementation
178 | | | |-- predict —— inference code
179 | | | | |-- __init__.py
180 | | | | |-- integration.py —— votes over the three models' outputs to produce the final result.json
181 | | | | |-- post_process.py —— post-processing of the model predictions
182 | | | | |-- __pycache__
183 | | | | | |-- __init__.cpython-38.pyc
184 | | | | | |-- post_process.cpython-38.pyc
185 | | | | | |-- test_dataset.cpython-38.pyc
186 | | | | | `-- test_utils.cpython-38.pyc
187 | | | | |-- run_interact1.py —— inference with the online-trained InteractModel_1
188 | | | | |-- run_interact3.py —— inference with the online-trained InteractModel_3
189 | | | | |-- run_JointBert.py —— inference with the online-trained JointBert
190 | | | | |-- run_trained_interact1.py —— inference with the locally trained InteractModel_1
191 | | | | |-- run_trained_interact3.py —— inference with the locally trained InteractModel_3
192 | | | | |-- run_trained_JointBert.py —— inference with the locally trained JointBert
193 | | | | |-- test_dataset.py —— builds the test-set dataset
194 | | | | `-- test_utils.py —— utility functions for the test set
195 | | | |-- __pycache__
196 | | | | `-- __init__.cpython-38.pyc
197 | | | |-- preprocess —— data preprocessing code
198 | | | | |-- __init__.py
199 | | | | |-- analysis.py —— analyzes the training data: which intents contain which slot labels
200 | | | | |-- extend_audio_sample.py —— augments the few-shot "Audio-Play" intent data
201 | | | | |-- extend_tv_sample.py —— augments the few-shot "TVProgram-Play" intent data
202 | | | | |-- extract_intent_sample.py —— extracts samples of a given intent from the training data
203 | | | | |-- generate_intent.py —— extracts the intents of the training and dev data
204 | | | | |-- process_other.py —— processes the out-of-domain data
205 | | | | |-- process_rawdata.py —— processes the raw training and dev data
206 | | | | |-- rectify.py —— corrects the "Alarm-Update" training data
207 | | | | |-- slot_sorted.py —— sorts the slot annotations by slot-value length in descending order
208 | | | | |-- split_train_dev.py —— splits the raw training data 8:2
209 | | | `-- scripts
210 | | | |-- build_vocab.py —— builds and loads vocabularies
211 | | | |-- config_Interact1.py —— configuration (hyper-parameters) for InteractModel_1
212 | | | |-- config_Interact3.py —— configuration (hyper-parameters) for InteractModel_3
213 | | | |-- config_jointBert.py —— configuration (hyper-parameters) for JointBert
214 | | | |-- dataset.py —— builds the training-set dataset
215 | | | |-- __init__.py
216 | | | |-- __pycache__
217 | | | | |-- build_vocab.cpython-38.pyc
218 | | | | |-- config_Interact1.cpython-38.pyc
219 | | | | |-- config_Interact3.cpython-38.pyc
220 | | | | |-- config_jointBert.cpython-38.pyc
221 | | | | |-- dataset.cpython-38.pyc
222 | | | | |-- __init__.cpython-38.pyc
223 | | | | `-- utils.cpython-38.pyc
224 | | | |-- train_interact1.py —— main training script for InteractModel_1
225 | | | |-- train_interact3.py —— main training script for InteractModel_3
226 | | | |-- train_jointBert.py —— main training script for JointBert
227 | | | `-- utils.py —— training-stage utility functions (train, valid, etc.)
228 | | |-- __init__.py
229 | | |-- prediction_result —— folder where the inference output result.json is saved
230 | | |-- __pycache__
231 | | | `-- __init__.cpython-38.pyc
232 | | |-- raw_data —— raw training data folder
233 | | `-- user_data —— user data folder
234 | | |-- common_data —— shared files used for training and validation
235 | | | |-- intent_label.txt —— txt file containing the 11 intent classes
236 | | | |-- intent_slot_mapping.json —— dictionary mapping each intent to its slot labels
237 | | | |-- region_dic.json —— mapping dictionary used to assist slot annotation of the training set
238 | | | |-- slot_label.txt —— standard slot labels
239 | | | `-- slot_none_vocab.txt —— non-standard slot labels
240 | | |-- dev_data
241 | | | |-- dev_intent_label.txt —— dev-set intents
242 | | | |-- dev_seq_in.txt —— dev-set tokenized input sentences
243 | | | |-- dev_seq_out.txt —— dev-set sequence labels
244 | | | `-- dev_slot_none.txt —— dev-set non-standard slot classification labels
245 | | |-- output_model —— folder for models saved during training
246 | | | |-- InteractModel_1 —— InteractModel_1 folder (models trained online are saved here)
247 | | | | `-- trained_model —— folder holding locally trained models
248 | | | | `-- Interact1_model_best.pth.tar —— locally trained InteractModel_1 model
249 | | | |-- InteractModel_3 —— InteractModel_3 folder (models trained online are saved here)
250 | | | | `-- trained_model
251 | | | | `-- Interact3_model_best.pth.tar —— locally trained InteractModel_3 model
252 | | | `-- JointBert —— JointBert folder (models trained online are saved here)
253 | | | `-- trained_model
254 | | | `-- bert_model_best.pth.tar —— locally trained JointBert model
255 | | |-- pretrained_model —— pre-trained model folder
256 | | | `-- ernie —— ernie
257 | | | |-- config.json
258 | | | |-- pytorch_model.bin
259 | | | |-- special_tokens_map.json
260 | | | |-- tokenizer_config.json
261 | | | `-- vocab.txt
262 | | |-- test_data —— test data folder
263 | | | |-- test_B_final_text.json —— raw test-set json file
264 | | | `-- test_seq_in_B.txt —— test-set file after Tokenizer segmentation
265 | | |-- tmp_result —— stores each individual model's output
266 | | `-- train_data —— training data folder
267 | | |-- train_intent_label.txt —— training-set intents
268 | | |-- train_seq_in.txt —— training-set tokenized input sentences
269 | | |-- train_seq_out.txt —— training-set sequence labels
270 | | `-- train_slot_none.txt —— training-set non-standard slot classification labels
271 | |-- image —— Docker image related folder
272 | | |-- readme_images —— images used by the solution README
273 | | |-- ccir-image.tar —— Docker image file
274 | | |-- README.md —— README with step-by-step reproduction instructions
275 | | |-- run_infer.sh —— script for online inference with the locally trained models
276 | | `-- run.sh —— script for online training and inference
277 | `-- README.md —— README describing the overall solution and algorithms
278 | ```
279 |
--------------------------------------------------------------------------------
/data/code/model/torchcrf.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.7.2'
2 |
3 | from typing import List, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class CRF(nn.Module):
10 | """Conditional random field.
11 |
12 | This module implements a conditional random field [LMP01]_. The forward computation
13 | of this class computes the log likelihood of the given sequence of tags and
14 | emission score tensor. This class also has `~CRF.decode` method which finds
15 | the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
16 |
17 | Args:
18 | num_tags: Number of tags.
19 | batch_first: Whether the first dimension corresponds to the size of a minibatch.
20 |
21 | Attributes:
22 | start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
23 | ``(num_tags,)``.
24 | end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
25 | ``(num_tags,)``.
26 | transitions (`~torch.nn.Parameter`): Transition score tensor of size
27 | ``(num_tags, num_tags)``.
28 |
29 |
30 | .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
31 | "Conditional random fields: Probabilistic models for segmenting and
32 | labeling sequence data". *Proc. 18th International Conf. on Machine
33 | Learning*. Morgan Kaufmann. pp. 282–289.
34 |
35 | .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
36 | """
37 |
38 | def __init__(self, num_tags: int, batch_first: bool = False) -> None:
39 | if num_tags <= 0:
40 | raise ValueError(f'invalid number of tags: {num_tags}')
41 | super().__init__()
42 | self.num_tags = num_tags
43 | self.batch_first = batch_first
44 | self.start_transitions = nn.Parameter(torch.empty(num_tags))
45 | self.end_transitions = nn.Parameter(torch.empty(num_tags))
46 | self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
47 |
48 | self.reset_parameters()
49 |
50 | def reset_parameters(self) -> None:
51 | """Initialize the transition parameters.
52 |
53 | The parameters will be initialized randomly from a uniform distribution
54 | between -0.1 and 0.1.
55 | """
56 | nn.init.uniform_(self.start_transitions, -0.1, 0.1)
57 | nn.init.uniform_(self.end_transitions, -0.1, 0.1)
58 | nn.init.uniform_(self.transitions, -0.1, 0.1)
59 |
60 | def __repr__(self) -> str:
61 | return f'{self.__class__.__name__}(num_tags={self.num_tags})'
62 |
63 | def forward(
64 | self,
65 | emissions: torch.Tensor,
66 | tags: torch.LongTensor,
67 | mask: Optional[torch.ByteTensor] = None,
68 | reduction: str = 'sum',
69 | ) -> torch.Tensor:
70 | """Compute the conditional log likelihood of a sequence of tags given emission scores.
71 |
72 | Args:
73 | emissions (`~torch.Tensor`): Emission score tensor of size
74 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
75 | ``(batch_size, seq_length, num_tags)`` otherwise.
76 | tags (`~torch.LongTensor`): Sequence of tags tensor of size
77 | ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
78 | ``(batch_size, seq_length)`` otherwise.
79 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
80 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
81 | reduction: Specifies the reduction to apply to the output:
82 | ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
83 | ``sum``: the output will be summed over batches. ``mean``: the output will be
84 | averaged over batches. ``token_mean``: the output will be averaged over tokens.
85 |
86 | Returns:
87 | `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
88 | reduction is ``none``, ``()`` otherwise.
89 | """
90 | self._validate(emissions, tags=tags, mask=mask)
91 | if reduction not in ('none', 'sum', 'mean', 'token_mean'):
92 | raise ValueError(f'invalid reduction: {reduction}')
93 | if mask is None:
94 | mask = torch.ones_like(tags, dtype=torch.uint8)
95 |
96 | if self.batch_first:
97 | emissions = emissions.transpose(0, 1)
98 | tags = tags.transpose(0, 1)
99 | mask = mask.transpose(0, 1)
100 |
101 | # shape: (batch_size,)
102 | numerator = self._compute_score(emissions, tags, mask)
103 | # shape: (batch_size,)
104 | denominator = self._compute_normalizer(emissions, mask)
105 | # shape: (batch_size,)
106 | llh = numerator - denominator
107 |
108 | if reduction == 'none':
109 | return llh
110 | if reduction == 'sum':
111 | return llh.sum()
112 | if reduction == 'mean':
113 | return llh.mean()
114 | assert reduction == 'token_mean'
115 | return llh.sum() / mask.float().sum()
116 |
117 | def decode(self, emissions: torch.Tensor,
118 | mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
119 | """Find the most likely tag sequence using Viterbi algorithm.
120 |
121 | Args:
122 | emissions (`~torch.Tensor`): Emission score tensor of size
123 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
124 | ``(batch_size, seq_length, num_tags)`` otherwise.
125 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
126 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
127 |
128 | Returns:
129 | List of list containing the best tag sequence for each batch.
130 | """
131 | self._validate(emissions, mask=mask)
132 | if mask is None:
133 | mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
134 |
135 | if self.batch_first:
136 | emissions = emissions.transpose(0, 1)
137 | mask = mask.transpose(0, 1)
138 |
139 | return self._viterbi_decode(emissions, mask)
140 |
141 | def _validate(
142 | self,
143 | emissions: torch.Tensor,
144 | tags: Optional[torch.LongTensor] = None,
145 | mask: Optional[torch.ByteTensor] = None) -> None:
146 | if emissions.dim() != 3:
147 | raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
148 | if emissions.size(2) != self.num_tags:
149 | raise ValueError(
150 | f'expected last dimension of emissions is {self.num_tags}, '
151 | f'got {emissions.size(2)}')
152 |
153 | if tags is not None:
154 | if emissions.shape[:2] != tags.shape:
155 | raise ValueError(
156 | 'the first two dimensions of emissions and tags must match, '
157 | f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
158 |
159 | if mask is not None:
160 | if emissions.shape[:2] != mask.shape:
161 | raise ValueError(
162 | 'the first two dimensions of emissions and mask must match, '
163 | f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
164 | no_empty_seq = not self.batch_first and mask[0].all()
165 | no_empty_seq_bf = self.batch_first and mask[:, 0].all()
166 | if not no_empty_seq and not no_empty_seq_bf:
167 | raise ValueError('mask of the first timestep must all be on')
168 |
169 | def _compute_score(
170 | self, emissions: torch.Tensor, tags: torch.LongTensor,
171 | mask: torch.ByteTensor) -> torch.Tensor:
172 | # emissions: (seq_length, batch_size, num_tags)
173 | # tags: (seq_length, batch_size)
174 | # mask: (seq_length, batch_size)
175 | assert emissions.dim() == 3 and tags.dim() == 2
176 | assert emissions.shape[:2] == tags.shape
177 | assert emissions.size(2) == self.num_tags
178 | assert mask.shape == tags.shape
179 | assert mask[0].all()
180 |
181 | seq_length, batch_size = tags.shape
182 | mask = mask.float()
183 |
184 | # Start transition score and first emission
185 | # shape: (batch_size,)
186 | score = self.start_transitions[tags[0]]
187 | score += emissions[0, torch.arange(batch_size), tags[0]]
188 |
189 | for i in range(1, seq_length):
190 | # Transition score to next tag, only added if next timestep is valid (mask == 1)
191 | # shape: (batch_size,)
192 | score += self.transitions[tags[i - 1], tags[i]] * mask[i]
193 |
194 | # Emission score for next tag, only added if next timestep is valid (mask == 1)
195 | # shape: (batch_size,)
196 | score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
197 |
198 | # End transition score
199 | # shape: (batch_size,)
200 | seq_ends = mask.long().sum(dim=0) - 1
201 | # shape: (batch_size,)
202 | last_tags = tags[seq_ends, torch.arange(batch_size)]
203 | # shape: (batch_size,)
204 | score += self.end_transitions[last_tags]
205 |
206 | return score
207 |
208 | def _compute_normalizer(
209 | self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
210 | # emissions: (seq_length, batch_size, num_tags)
211 | # mask: (seq_length, batch_size)
212 | assert emissions.dim() == 3 and mask.dim() == 2
213 | assert emissions.shape[:2] == mask.shape
214 | assert emissions.size(2) == self.num_tags
215 | assert mask[0].all()
216 |
217 | seq_length = emissions.size(0)
218 |
219 | # Start transition score and first emission; score has size of
220 | # (batch_size, num_tags) where for each batch, the j-th column stores
221 | # the score that the first timestep has tag j
222 | # shape: (batch_size, num_tags)
223 | score = self.start_transitions + emissions[0]
224 |
225 | for i in range(1, seq_length):
226 | # Broadcast score for every possible next tag
227 | # shape: (batch_size, num_tags, 1)
228 | broadcast_score = score.unsqueeze(2)
229 |
230 | # Broadcast emission score for every possible current tag
231 | # shape: (batch_size, 1, num_tags)
232 | broadcast_emissions = emissions[i].unsqueeze(1)
233 |
234 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where
235 | # for each sample, entry at row i and column j stores the sum of scores of all
236 | # possible tag sequences so far that end with transitioning from tag i to tag j
237 | # and emitting
238 | # shape: (batch_size, num_tags, num_tags)
239 | next_score = broadcast_score + self.transitions + broadcast_emissions
240 |
241 | # Sum over all possible current tags, but we're in score space, so a sum
242 | # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
243 | # all possible tag sequences so far, that end in tag i
244 | # shape: (batch_size, num_tags)
245 | next_score = torch.logsumexp(next_score, dim=1)
246 |
247 | # Set score to the next score if this timestep is valid (mask == 1)
248 | # shape: (batch_size, num_tags)
249 | score = torch.where(mask[i].unsqueeze(1), next_score, score)
250 |
251 | # End transition score
252 | # shape: (batch_size, num_tags)
253 | score += self.end_transitions
254 |
255 | # Sum (log-sum-exp) over all possible tags
256 | # shape: (batch_size,)
257 | return torch.logsumexp(score, dim=1)
258 |
259 | def _viterbi_decode(self, emissions: torch.FloatTensor,
260 | mask: torch.ByteTensor) -> List[List[int]]:
261 | # emissions: (seq_length, batch_size, num_tags)
262 | # mask: (seq_length, batch_size)
263 | assert emissions.dim() == 3 and mask.dim() == 2
264 | assert emissions.shape[:2] == mask.shape
265 | assert emissions.size(2) == self.num_tags
266 | assert mask[0].all()
267 |
268 | seq_length, batch_size = mask.shape
269 |
270 | # Start transition and first emission
271 | # shape: (batch_size, num_tags)
272 | score = self.start_transitions + emissions[0]
273 | history = []
274 |
275 | # score is a tensor of size (batch_size, num_tags) where for every batch,
276 | # value at column j stores the score of the best tag sequence so far that ends
277 | # with tag j
278 | # history saves where the best tags candidate transitioned from; this is used
279 | # when we trace back the best tag sequence
280 |
281 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence
282 | # for every possible next tag
283 | for i in range(1, seq_length):
284 | # Broadcast viterbi score for every possible next tag
285 | # shape: (batch_size, num_tags, 1)
286 | broadcast_score = score.unsqueeze(2)
287 |
288 | # Broadcast emission score for every possible current tag
289 | # shape: (batch_size, 1, num_tags)
290 | broadcast_emission = emissions[i].unsqueeze(1)
291 |
292 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where
293 | # for each sample, entry at row i and column j stores the score of the best
294 | # tag sequence so far that ends with transitioning from tag i to tag j and emitting
295 | # shape: (batch_size, num_tags, num_tags)
296 | next_score = broadcast_score + self.transitions + broadcast_emission
297 |
298 | # Find the maximum score over all possible current tag
299 | # shape: (batch_size, num_tags)
300 | next_score, indices = next_score.max(dim=1)
301 |
302 | # Set score to the next score if this timestep is valid (mask == 1)
303 | # and save the index that produces the next score
304 | # shape: (batch_size, num_tags)
305 | score = torch.where(mask[i].unsqueeze(1), next_score, score)
306 | history.append(indices)
307 |
308 | # End transition score
309 | # shape: (batch_size, num_tags)
310 | score += self.end_transitions
311 |
312 | # Now, compute the best path for each sample
313 |
314 | # shape: (batch_size,)
315 | seq_ends = mask.long().sum(dim=0) - 1
316 | best_tags_list = []
317 |
318 | for idx in range(batch_size):
319 | # Find the tag which maximizes the score at the last timestep; this is our best tag
320 | # for the last timestep
321 | _, best_last_tag = score[idx].max(dim=0)
322 | best_tags = [best_last_tag.item()]
323 |
324 | # We trace back where the best last tag comes from, append that to our best tag
325 | # sequence, and trace it back again, and so on
326 | for hist in reversed(history[:seq_ends[idx]]):
327 | best_last_tag = hist[idx][best_tags[-1]]
328 | best_tags.append(best_last_tag.item())
329 |
330 | # Reverse the order because we start from the last timestep
331 | best_tags.reverse()
332 | best_tags_list.append(best_tags)
333 |
334 | return best_tags_list
335 |
--------------------------------------------------------------------------------
/data/code/predict/run_JointBert.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/9 15:22
4 | # @Author : JJkinging
5 | # @File : run_predict.py
6 | import warnings
7 |
8 | import torch
9 | import json
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | from data.code.scripts.config_jointBert import Config
13 | from data.code.predict.test_utils import load_reverse_vocab, load_vocab, collate_to_max_length
14 | from data.code.predict.test_dataset import CCFDataset
15 | from data.code.model.JointBertModel import JointBertModel
16 | from transformers import AutoTokenizer
17 | from data.code.predict.post_process import process
18 |
19 |
20 | def run_test(model, dataloader, slot_dict):
21 |
22 | model.eval()
23 | device = model.device
24 | intent_probs_list = []
25 |
26 | intent_pred = []
27 | slotNone_pred_output = []
28 | slot_pre_output = []
29 | input_ids_list = []
30 |
31 | id2slot = {value: key for key, value in slot_dict.items()}
32 | with torch.no_grad():
33 | tqdm_batch_iterator = tqdm(dataloader)
34 | for _, batch in enumerate(tqdm_batch_iterator):
35 | input_ids, input_mask = batch
36 |
37 | input_ids = input_ids.to(device)
38 | input_mask = input_mask.byte().to(device)
39 | real_length = torch.sum(input_mask, dim=1)
40 |
41 | intent_logits, slot_none_logits, slot_logits = model(input_ids, input_mask)
42 |
43 | # intent
44 | intent_probs = torch.softmax(intent_logits, dim=-1)
45 | predict_labels = torch.argmax(intent_probs, dim=-1)
46 | predict_labels = predict_labels.cpu().numpy().tolist()
47 | intent_pred.extend(predict_labels)
48 |
49 | # slot_none
50 | slot_none_probs = torch.sigmoid(slot_none_logits)
51 | slot_none_probs = slot_none_probs > 0.5
52 | slot_none_probs = slot_none_probs.cpu().numpy()
53 | slot_none_probs = slot_none_probs.astype(int)
54 | slot_none_probs = slot_none_probs.tolist()
55 |
56 | for slot_none_id in slot_none_probs: # iterate over this batch's slot_none predictions
57 | tmp = []
58 | if 1 in slot_none_id:
59 | for idx, none_id in enumerate(slot_none_id):
60 | if none_id == 1:
61 | tmp.append(idx)
62 | else:
63 | tmp.append(28) # index 28 in slot_none_vocab means none
64 | slotNone_pred_output.append(tmp) # [[4, 5], [], [], ...]
65 |
66 | # slot
67 | out_path = model.slot_predict(slot_logits, input_mask, id2slot)
68 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path] # keep the '[START]' and '[EOS]' markers
69 | slot_pre_output.extend(out_path)
70 |
71 | # input_ids
72 | input_ids = input_ids.cpu().numpy().tolist()
73 | input_ids_list.extend(input_ids)
74 |
75 | return intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list
76 |
77 |
78 | def get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start, end, key):
79 | value = tokenizer.decode(input_ids[start:end + 1])
80 | value = ''.join(value.split(' ')).replace('[UNK]', '$')
81 | s = -1
82 | e = -1
83 | ori_sen_lower = ori_sen.lower() # lower-case any English letters
84 | is_unk = False
85 | if len(value) > 0:
86 | if value[0] == '$': # the first token is [UNK]
87 | value = value[1:] # ignore it while matching, then shift the start position back by one at the end
88 | is_unk = True
89 | i = 0
90 | j = 0
91 | while i < len(ori_sen_lower) and j < len(value):
92 | if ori_sen_lower[i] == value[j]:
93 | if s == -1:
94 | s = i
95 | e = i
96 | else:
97 | e += 1
98 | i += 1
99 | j += 1
100 | elif ori_sen_lower[i] == ' ':
101 | e += 1
102 | i += 1
103 | elif value[j] == '$':
104 | e += 1
105 | i += 1
106 | j += 1
107 | elif ori_sen_lower[i] != value[j]:
108 | i -= j - 1
109 | j = 0
110 | s = -1
111 | e = -1
112 | if is_unk:
113 | s = s-1
114 | final_value = ori_sen[s:e + 1]
115 | if key in tmp_dict.keys():
116 | if tmp_dict[key] != '' and not isinstance(tmp_dict[key], list): # 如果该key已经有值,且不为list
117 | tmp_list = [tmp_dict[key], final_value]
118 | tmp_dict[key] = tmp_list
119 | elif tmp_dict[key] != '' and isinstance(tmp_dict[key], list):
120 | tmp_dict[key].append(final_value)
121 | else:
122 | tmp_dict[key] = final_value
123 | return tmp_dict
124 |
125 |
126 | def is_nest(slot_pre):
127 | '''
128 | Check whether this sample's slot filling contains nested entities.
129 | :param slot_pre: e.g. ['O', 'O', 'O', 'B-age', 'I-age', 'O', 'O', 'O']
130 | :return: True if nesting exists, otherwise False
131 | '''
132 | for j, cur_tag in enumerate(slot_pre):
133 | if cur_tag.startswith('I-'):
134 | cur_tag_name = cur_tag[2:]
135 | if j < len(slot_pre)-1:
136 | if slot_pre[j+1].startswith('I-'): # the next tag also starts with 'I-'
137 | post_tag_name = slot_pre[j+1][2:]
138 | if cur_tag_name != post_tag_name: # but the two tag names differ
139 | return True
140 | return False
141 |
142 |
143 | def process_nest(tokenizer, input_ids, slot_pre, ori_sen): # handle nested slots
144 | '''
145 | :param input_ids: the original sentence as ids (after wordpiece)
146 | :param slot_pre: the predicted slot tags for this sentence
147 | :param ori_sen: the original sentence (characters)
148 | :return: dict of recovered slot values
149 | '''
150 | start_outer = -1
151 | end_outer = -1
152 | tmp_dict = {}
153 | pre_end = -1
154 | for i, tag in enumerate(slot_pre):
155 | if i <= pre_end:
156 | continue
157 | if tag.startswith('B-') and start_outer == -1: # 第一个'B-'
158 | start_outer = i # 临时起始位置
159 | end_outer = i
160 | key_outer = tag[2:]
161 |
162 | elif tag.startswith('I-') and tag[2:] == slot_pre[i-1][2:]: # 'I-' 且标签与前一个slot相同
163 | end_outer = i
164 | if i == len(slot_pre)-1: # 到了最后
165 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer)
166 |
167 | elif tag.startswith('O') and start_outer == -1:
168 | continue
169 | elif tag.startswith('O') and start_outer != -1:
170 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer)
171 | start_outer = -1
172 | end_outer = -1
173 | # First kind of nesting: B-region I-region I-name I-name
174 | elif tag.startswith('I-') and tag[2:] != slot_pre[i-1][2:]: # 'I-' 且标签与前一个slot不同
175 | start_inner = start_outer
176 | end_inner = end_outer
177 | key_inner = key_outer
178 | # 处理内层
179 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_inner, end_inner, key_inner)
180 | # start_outer不变
181 | end_outer = i # end_outer为当前slot下标
182 | key_outer = slot_pre[i][2:] # 需修改key_outer
183 | # Second kind of nesting: B-name I-name B-datetime_date I-datetime_date I-name B-region I-region I-name I-name
184 | elif tag.startswith('B-'):
185 | flag = False
186 | pre_start = start_outer
187 | pre_end = end_outer
188 | pre_key = key_outer
189 | start_outer = i
190 | end_outer = i
191 | key_outer = slot_pre[i][2:]
192 | # ************************
193 | for j in range(i, len(slot_pre)):
194 | if slot_pre[j] != 'O' and slot_pre[j][2:] == pre_key:
195 | flag = True
196 | pre_end = j
197 | slot_pre[j] = 'O' # 置为'O'
198 | # *************************
199 | if flag: # 上一个slot是嵌套
200 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) #先处理外层
201 | # 以后遇到完整slot的就加入(假设只有二级嵌套)
202 | for k in range(i+1, pre_end+1):
203 | if slot_pre[k] != 'O':
204 | if slot_pre[k].startswith('I-'):
205 | end_outer = k
206 | elif slot_pre[k].startswith('B-') and start_outer == -1:
207 | start_outer = k
208 | end_outer = k
209 | key_outer = slot_pre[k][2:]
210 | elif slot_pre[k].startswith('B-') and start_outer != -1:
211 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer,
212 | key_outer) # 处理内层
213 | start_outer = k
214 | end_outer = k
215 | key_outer = slot_pre[k][2:]
216 | elif slot_pre[k] == 'O' and start_outer != -1:
217 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer,
218 | key_outer) # 处理内层
219 | start_outer = -1
220 | end_outer = -1
221 | key_outer = None
222 | elif slot_pre[k] == 'O' and start_outer == -1:
223 | continue
224 | else: # 上一个不是嵌套
225 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) #先处理上一个非嵌套
226 |
227 | return tmp_dict
228 |
229 |
230 | def find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list):
231 | fp = open('nest.txt', 'w+', encoding='utf-8')
232 | slot_list = []
233 | count = 0
234 | # ddd = 0
235 | for i, slot_ids in enumerate(slot_pre_output): # 遍历每条数据
236 | # if ddd < 1111:
237 | # ddd += 1
238 | # continue
239 | tmp_dict = {}
240 | start = 0
241 | end = 0
242 | if is_nest(slot_ids[1:-1]): # 如果确实存在嵌套行为
243 | tmp_dict = process_nest(tokenizer, input_ids_list[i][1:-1], slot_pre_output[i][1:-1], ori_sen_list[i])
244 | slot_list.append(tmp_dict)
245 | fp.write(str(ori_sen_list[i])+'\t')
246 | fp.write(str(tmp_dict)+'\n')
247 | fp.write(' '.join(slot_pre_output[i][1:-1])+'\n')
248 | count += 1
249 | else:
250 | for j, slot in enumerate(slot_ids): # 遍历每个字
251 | if slot != 'O' and slot != '[START]' and slot != '[EOS]':
252 | if slot.startswith('B-') and start == 0:
253 | start = j # 槽值起始位置
254 | end = j
255 | key = slot[2:]
256 | if j == len(slot_ids)-2: # # 如果'B-'是最后一个字符(即倒数第二个,真倒数第一是EOS)
257 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
258 | break
259 | elif slot.startswith('B-') and start != 0: # 说明上一个是槽
260 | # tokenizer, tmp_dict, input_ids, ori_sen, start, end, key
261 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key) # 将槽值写入
262 | start = j
263 | end = j
264 | key = slot[2:]
265 | else: # 'I-'开头
266 | end += 1
267 | if j == len(slot_ids)-2: # 如果'I-'是最后一个字符(即倒数第二个,真倒数第一是EOS)
268 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
269 | break
270 | else:
271 | if end == 0: # 说明没找到槽
272 | continue
273 | else:
274 | if slot != '[EOS]':
275 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
276 | start = 0
277 | end = 0
278 | slot_list.append(tmp_dict)
279 | print(str(count)+'——JointBert 推理完成!')
280 | fp.close()
281 | return slot_list
282 |
283 |
284 | if __name__ == "__main__":
285 | warnings.filterwarnings("ignore")
286 | config = Config()
287 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu")
288 | with open('../../user_data/common_data/region_dic.json', 'r', encoding='utf-8') as fp:
289 | region_dict = json.load(fp)
290 | tokenizer = AutoTokenizer.from_pretrained('../../user_data/pretrained_model/ernie')
291 | print('JointBert 开始推理!')
292 | vocab = load_vocab('../../user_data/pretrained_model/ernie/vocab.txt')
293 | id2intent = load_reverse_vocab('../../user_data/common_data/intent_label.txt')
294 | id2slotNone = load_reverse_vocab('../../user_data/common_data/slot_none_vocab.txt')
295 | id2slot = load_reverse_vocab('../../user_data/common_data/slot_label.txt')
296 |
297 | intent_dict = load_vocab('../../user_data/common_data/intent_label.txt')
298 | slot_none_dict = load_vocab('../../user_data/common_data/slot_none_vocab.txt')
299 | slot_dict = load_vocab('../../user_data/common_data/slot_label.txt')
300 |
301 | intent_tagset_size = len(intent_dict)
302 | slot_none_tag_size = len(slot_none_dict)
303 | slot_tag_size = len(slot_dict)
304 | test_filename = '../../user_data/test_data/test_seq_in_B.txt'
305 | test_dataset = CCFDataset(test_filename, vocab, intent_dict, slot_none_dict, slot_dict, config.max_length)
306 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=config.batch_size,
307 | collate_fn=collate_to_max_length)
308 |
309 | model = JointBertModel('../../user_data/pretrained_model/ernie',
310 | config.bert_hidden_size,
311 | intent_tagset_size,
312 | slot_none_tag_size,
313 | slot_tag_size,
314 | device).to(device)
315 | checkpoint = torch.load('../../user_data/output_model/JointBert/bert_model_best.pth.tar', map_location='cpu')
316 | model.load_state_dict(checkpoint["model"])
317 |
318 | # -------------------- Testing ------------------- #
319 | print("\n",
320 | 20 * "=",
321 | "Test model on device: {}".format(device),
322 | 20 * "=")
323 |
324 | intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list = run_test(model, test_loader, slot_dict)
325 |
326 | ori_sen_list = []
327 | with open('../../user_data/test_data/test_B_final_text.json', 'r', encoding='utf-8') as fp:
328 | raw_data = json.load(fp)
329 | for filename, single_data in raw_data.items():
330 | text = single_data['text']
331 | ori_sen_list.append(text)
332 | slot_list = find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list)
333 |
334 | # Use region_dict to map the surface forms in slot_list back to their annotated values (do not skip this step!)
335 | for i, slot_dict in enumerate(slot_list):
336 | for slot, slot_value in slot_dict.items():
337 | for region_key, region_list in region_dict.items():
338 | if slot_value in region_list:
339 | slot_list[i][slot] = region_key
340 | res = {}
341 | for i in range(len(intent_pred)):
342 | big_tmp = {}
343 | slot_tmp_dict = {}
344 | # intent
345 | intent = id2intent[intent_pred[i]]
346 | # slot_none
347 | for slot_none_id in slotNone_pred_output[i]:
348 | if 0 <= slot_none_id <= 9:
349 | slot_tmp_dict['command'] = id2slotNone[slot_none_id]
350 | elif 10 <= slot_none_id <= 18:
351 | slot_tmp_dict['index'] = id2slotNone[slot_none_id]
352 | elif 19 <= slot_none_id <= 22:
353 | slot_tmp_dict['play_mode'] = id2slotNone[slot_none_id]
354 | elif 23 <= slot_none_id <= 27:
355 | slot_tmp_dict['query_type'] = id2slotNone[slot_none_id]
356 | # slot
357 | slot_tmp_dict.update(slot_list[i])
358 |
359 | big_tmp['intent'] = intent
360 | big_tmp['slots'] = slot_tmp_dict
361 |
362 | length = len(str(i))
363 | o_num = 5 - length
364 | index = 'NLU' + '0'*o_num + str(i)
365 | res[index] = big_tmp
366 |
367 | with open('../../user_data/tmp_result/result_bert.json', 'w', encoding='utf-8') as fp:
368 | json.dump(res, fp, ensure_ascii=False)
369 |
370 | source_path = '../../user_data/tmp_result/result_bert.json'
371 | target_path = '../../user_data/tmp_result/result_bert_post.json'
372 | process(source_path, target_path)
373 |
--------------------------------------------------------------------------------
/data/code/predict/run_interact1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/9 15:22
4 | # @Author : JJkinging
5 | # @File : run_predict.py
6 | import warnings
7 |
8 | import torch
9 | import json
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | from data.code.scripts.config_Interact1 import Config
13 | from data.code.predict.test_utils import load_reverse_vocab, load_vocab, collate_to_max_length
14 | from data.code.predict.test_dataset import CCFDataset
15 | from data.code.model.InteractModel_1 import InteractModel
16 | from transformers import AutoTokenizer
17 | from data.code.predict.post_process import process
18 |
19 |
20 | def run_test(model, dataloader, slot_dict):
21 |
22 | model.eval()
23 | device = model.device
24 | intent_probs_list = []
25 |
26 | intent_pred = []
27 | slotNone_pred_output = []
28 | slot_pre_output = []
29 | input_ids_list = []
30 |
31 | id2slot = {value: key for key, value in slot_dict.items()}
32 | with torch.no_grad():
33 | tqdm_batch_iterator = tqdm(dataloader)
34 | for _, batch in enumerate(tqdm_batch_iterator):
35 | input_ids, input_mask = batch
36 |
37 | input_ids = input_ids.to(device)
38 | input_mask = input_mask.byte().to(device)
39 | real_length = torch.sum(input_mask, dim=1)
40 |
41 | intent_logits, slot_none_logits, slot_logits = model(input_ids, input_mask)
42 |
43 | # intent
44 | intent_probs = torch.softmax(intent_logits, dim=-1)
45 | predict_labels = torch.argmax(intent_probs, dim=-1)
46 | predict_labels = predict_labels.cpu().numpy().tolist()
47 | intent_pred.extend(predict_labels)
48 |
49 | # slot_none
50 | slot_none_probs = torch.sigmoid(slot_none_logits)
51 | slot_none_probs = slot_none_probs > 0.5
52 | slot_none_probs = slot_none_probs.cpu().numpy()
53 | slot_none_probs = slot_none_probs.astype(int)
54 | slot_none_probs = slot_none_probs.tolist()
55 |
56 | for slot_none_id in slot_none_probs: # iterate over this batch's slot_none predictions
57 | tmp = []
58 | if 1 in slot_none_id:
59 | for idx, none_id in enumerate(slot_none_id):
60 | if none_id == 1:
61 | tmp.append(idx)
62 | else:
63 | tmp.append(28) # index 28 in slot_none_vocab means none
64 | slotNone_pred_output.append(tmp) # [[4, 5], [], [], ...]
65 |
66 | # slot
67 | out_path = model.slot_predict(slot_logits, input_mask, id2slot)
68 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path] # keep the '[START]' and '[EOS]' markers
69 | slot_pre_output.extend(out_path)
70 |
71 | # input_ids
72 | input_ids = input_ids.cpu().numpy().tolist()
73 | input_ids_list.extend(input_ids)
74 |
75 | return intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list
76 |
77 |
78 | def get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start, end, key):
79 | value = tokenizer.decode(input_ids[start:end + 1])
80 | value = ''.join(value.split(' ')).replace('[UNK]', '$')
81 | s = -1
82 | e = -1
83 | ori_sen_lower = ori_sen.lower() # lower-case any English letters
84 | is_unk = False
85 | if len(value) > 0:
86 | if value[0] == '$': # the first token is [UNK]
87 | value = value[1:] # ignore it while matching, then shift the start position back by one at the end
88 | is_unk = True
89 | i = 0
90 | j = 0
91 | while i < len(ori_sen_lower) and j < len(value):
92 | if ori_sen_lower[i] == value[j]:
93 | if s == -1:
94 | s = i
95 | e = i
96 | else:
97 | e += 1
98 | i += 1
99 | j += 1
100 | elif ori_sen_lower[i] == ' ':
101 | e += 1
102 | i += 1
103 | elif value[j] == '$':
104 | e += 1
105 | i += 1
106 | j += 1
107 | elif ori_sen_lower[i] != value[j]:
108 | i -= j - 1
109 | j = 0
110 | s = -1
111 | e = -1
112 | if is_unk:
113 | s = s-1
114 | final_value = ori_sen[s:e + 1]
115 | if key in tmp_dict.keys():
116 | if tmp_dict[key] != '' and not isinstance(tmp_dict[key], list): # 如果该key已经有值,且不为list
117 | tmp_list = [tmp_dict[key], final_value]
118 | tmp_dict[key] = tmp_list
119 | elif tmp_dict[key] != '' and isinstance(tmp_dict[key], list):
120 | tmp_dict[key].append(final_value)
121 | else:
122 | tmp_dict[key] = final_value
123 | return tmp_dict
124 |
125 |
126 | def is_nest(slot_pre):
127 | '''
128 | Check whether this sample's slot filling contains nested entities.
129 | :param slot_pre: e.g. ['O', 'O', 'O', 'B-age', 'I-age', 'O', 'O', 'O']
130 | :return: True if nesting exists, otherwise False
131 | '''
132 | for j, cur_tag in enumerate(slot_pre):
133 | if cur_tag.startswith('I-'):
134 | cur_tag_name = cur_tag[2:]
135 | if j < len(slot_pre)-1:
136 | if slot_pre[j+1].startswith('I-'): # the next tag also starts with 'I-'
137 | post_tag_name = slot_pre[j+1][2:]
138 | if cur_tag_name != post_tag_name: # but the two tag names differ
139 | return True
140 | return False
141 |
142 |
143 | def process_nest(tokenizer, input_ids, slot_pre, ori_sen): # handle nested slots
144 | '''
145 | :param input_ids: the original sentence as ids (after wordpiece)
146 | :param slot_pre: the predicted slot tags for this sentence
147 | :param ori_sen: the original sentence (characters)
148 | :return: dict of recovered slot values
149 | '''
150 | start_outer = -1
151 | end_outer = -1
152 | tmp_dict = {}
153 | pre_end = -1
154 | for i, tag in enumerate(slot_pre):
155 | if i <= pre_end:
156 | continue
157 | if tag.startswith('B-') and start_outer == -1: # 第一个'B-'
158 | start_outer = i # 临时起始位置
159 | end_outer = i
160 | key_outer = tag[2:]
161 |
162 | elif tag.startswith('I-') and tag[2:] == slot_pre[i-1][2:]: # 'I-' 且标签与前一个slot相同
163 | end_outer = i
164 | if i == len(slot_pre)-1: # 到了最后
165 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer)
166 |
167 | elif tag.startswith('O') and start_outer == -1:
168 | continue
169 | elif tag.startswith('O') and start_outer != -1:
170 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer)
171 | start_outer = -1
172 | end_outer = -1
173 | # First kind of nesting: B-region I-region I-name I-name
174 | elif tag.startswith('I-') and tag[2:] != slot_pre[i-1][2:]: # 'I-' 且标签与前一个slot不同
175 | start_inner = start_outer
176 | end_inner = end_outer
177 | key_inner = key_outer
178 | # handle the inner slot
179 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_inner, end_inner, key_inner)
180 | # start_outer stays the same
181 | end_outer = i # end_outer becomes the current tag index
182 | key_outer = slot_pre[i][2:] # key_outer must be updated
183 | # Nesting case 2: B-name I-name B-datetime_date I-datetime_date I-name B-region I-region I-name I-name
184 | elif tag.startswith('B-'):
185 | flag = False
186 | pre_start = start_outer
187 | pre_end = end_outer
188 | pre_key = key_outer
189 | start_outer = i
190 | end_outer = i
191 | key_outer = slot_pre[i][2:]
192 | # ************************
193 | for j in range(i, len(slot_pre)):
194 | if slot_pre[j] != 'O' and slot_pre[j][2:] == pre_key:
195 | flag = True
196 | pre_end = j
197 | slot_pre[j] = 'O' # reset to 'O'
198 | # *************************
199 | if flag: # the previous slot is nested
200 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) # handle the outer slot first
201 | # from here on, add each complete slot encountered (assuming only two-level nesting)
202 | for k in range(i+1, pre_end+1):
203 | if slot_pre[k] != 'O':
204 | if slot_pre[k].startswith('I-'):
205 | end_outer = k
206 | elif slot_pre[k].startswith('B-') and start_outer == -1:
207 | start_outer = k
208 | end_outer = k
209 | key_outer = slot_pre[k][2:]
210 | elif slot_pre[k].startswith('B-') and start_outer != -1:
211 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer,
212 | key_outer) # handle the inner slot
213 | start_outer = k
214 | end_outer = k
215 | key_outer = slot_pre[k][2:]
216 | elif slot_pre[k] == 'O' and start_outer != -1:
217 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer,
218 | key_outer) # handle the inner slot
219 | start_outer = -1
220 | end_outer = -1
221 | key_outer = None
222 | elif slot_pre[k] == 'O' and start_outer == -1:
223 | continue
224 | else: # the previous slot is not nested
225 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) # handle the previous non-nested slot first
226 |
227 | return tmp_dict
228 |
229 |
230 | def find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list):
231 | fp = open('nest.txt', 'w+', encoding='utf-8')
232 | slot_list = []
233 | count = 0
234 | # ddd = 0
235 | for i, slot_ids in enumerate(slot_pre_output): # iterate over each sample
236 | # if ddd < 1111:
237 | # ddd += 1
238 | # continue
239 | tmp_dict = {}
240 | start = 0
241 | end = 0
242 | if is_nest(slot_ids[1:-1]): # nesting is actually present
243 | tmp_dict = process_nest(tokenizer, input_ids_list[i][1:-1], slot_pre_output[i][1:-1], ori_sen_list[i])
244 | slot_list.append(tmp_dict)
245 | fp.write(str(ori_sen_list[i])+'\t')
246 | fp.write(str(tmp_dict)+'\n')
247 | fp.write(' '.join(slot_pre_output[i][1:-1])+'\n')
248 | count += 1
249 | else:
250 | for j, slot in enumerate(slot_ids): # iterate over each token
251 | if slot != 'O' and slot != '[START]' and slot != '[EOS]':
252 | if slot.startswith('B-') and start == 0:
253 | start = j # start position of the slot value
254 | end = j
255 | key = slot[2:]
256 | if j == len(slot_ids)-2: # the 'B-' tag is on the last character (second to last position; the very last is EOS)
257 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
258 | break
259 | elif slot.startswith('B-') and start != 0: # the previous span was a slot
260 | # tokenizer, tmp_dict, input_ids, ori_sen, start, end, key
261 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key) # write the previous slot value
262 | start = j
263 | end = j
264 | key = slot[2:]
265 | else: # starts with 'I-'
266 | end += 1
267 | if j == len(slot_ids)-2: # the 'I-' tag is on the last character (second to last position; the very last is EOS)
268 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
269 | break
270 | else:
271 | if end == 0: # no slot found yet
272 | continue
273 | else:
274 | if slot != '[EOS]':
275 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
276 | start = 0
277 | end = 0
278 | slot_list.append(tmp_dict)
279 | print(str(count) + ' nested samples -- interact1 inference finished!')
280 | fp.close()
281 | return slot_list
282 |
283 |
284 | if __name__ == "__main__":
285 | warnings.filterwarnings("ignore")
286 | config = Config()
287 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu")
288 | with open('../../user_data/common_data/region_dic.json', 'r', encoding='utf-8') as fp:
289 | region_dict = json.load(fp)
290 | tokenizer = AutoTokenizer.from_pretrained('../../user_data/pretrained_model/ernie')
291 | print('interact1 inference started!')
292 | vocab = load_vocab('../../user_data/pretrained_model/ernie/vocab.txt')
293 | id2intent = load_reverse_vocab('../../user_data/common_data/intent_label.txt')
294 | id2slotNone = load_reverse_vocab('../../user_data/common_data/slot_none_vocab.txt')
295 | id2slot = load_reverse_vocab('../../user_data/common_data/slot_label.txt')
296 |
297 | intent_dict = load_vocab('../../user_data/common_data/intent_label.txt')
298 | slot_none_dict = load_vocab('../../user_data/common_data/slot_none_vocab.txt')
299 | slot_dict = load_vocab('../../user_data/common_data/slot_label.txt')
300 |
301 | intent_tagset_size = len(intent_dict)
302 | slot_none_tag_size = len(slot_none_dict)
303 | slot_tag_size = len(slot_dict)
304 | test_filename = '../../user_data/test_data/test_seq_in_B.txt'
305 | test_dataset = CCFDataset(test_filename, vocab, intent_dict, slot_none_dict, slot_dict, config.max_length)
306 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=config.batch_size,
307 | collate_fn=collate_to_max_length)
308 |
309 | model = InteractModel('../../user_data/pretrained_model/ernie',
310 | config.bert_hidden_size,
311 | intent_tagset_size,
312 | slot_none_tag_size,
313 | slot_tag_size,
314 | device).to(device)
315 | checkpoint = torch.load('../../user_data/output_model/InteractModel_1/Interact1_model_best.pth.tar')
316 | model.load_state_dict(checkpoint["model"])
317 |
318 | # -------------------- Testing ------------------- #
319 | print("\n",
320 | 20 * "=",
321 | "Test model on device: {}".format(device),
322 | 20 * "=")
323 |
324 | intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list = run_test(model, test_loader, slot_dict)
325 |
326 | ori_sen_list = []
327 | with open('../../user_data/test_data/test_B_final_text.json', 'r', encoding='utf-8') as fp:
328 | raw_data = json.load(fp)
329 | for filename, single_data in raw_data.items():
330 | text = single_data['text']
331 | ori_sen_list.append(text)
332 | slot_list = find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list)
333 |
334 | # Use region_dict to normalize the slot values in slot_list that need replacing (do not skip this step!)
335 | for i, slot_dict in enumerate(slot_list):
336 | for slot, slot_value in slot_dict.items():
337 | for region_key, region_list in region_dict.items():
338 | if slot_value in region_list:
339 | slot_list[i][slot] = region_key
340 | res = {}
341 | for i in range(len(intent_pred)):
342 | big_tmp = {}
343 | slot_tmp_dict = {}
344 | # intent
345 | intent = id2intent[intent_pred[i]]
346 | # slot_none
347 | for slot_none_id in slotNone_pred_output[i]:
348 | if 0 <= slot_none_id <= 9:
349 | slot_tmp_dict['command'] = id2slotNone[slot_none_id]
350 | elif 10 <= slot_none_id <= 18:
351 | slot_tmp_dict['index'] = id2slotNone[slot_none_id]
352 | elif 19 <= slot_none_id <= 22:
353 | slot_tmp_dict['play_mode'] = id2slotNone[slot_none_id]
354 | elif 23 <= slot_none_id <= 27:
355 | slot_tmp_dict['query_type'] = id2slotNone[slot_none_id]
356 | # slot
357 | slot_tmp_dict.update(slot_list[i])
358 |
359 | big_tmp['intent'] = intent
360 | big_tmp['slots'] = slot_tmp_dict
361 |
362 | length = len(str(i))
363 | o_num = 5 - length
364 | index = 'NLU' + '0'*o_num + str(i)
365 | res[index] = big_tmp
366 |
367 | with open('../../user_data/tmp_result/result_interact_1.json', 'w', encoding='utf-8') as fp:
368 | json.dump(res, fp, ensure_ascii=False)
369 |
370 | source_path = '../../user_data/tmp_result/result_interact_1.json'
371 | target_path = '../../user_data/tmp_result/result_interact_1_post.json'
372 | process(source_path, target_path)
373 |
--------------------------------------------------------------------------------
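Note: in both run_interact1.py and run_interact3.py the result-assembly loop maps predicted slot_none ids to keys through hard-coded ranges (0-9 -> 'command', 10-18 -> 'index', 19-22 -> 'play_mode', 23-27 -> 'query_type'; id 28 stands for none and is skipped). A minimal, self-contained sketch of that same mapping as a lookup table is shown below; it is only an illustration, not code used by the pipeline, and the range boundaries are copied from the scripts.

    # Illustrative only: table-driven version of the slot_none id-range mapping used above.
    SLOT_NONE_RANGES = [
        (0, 9, 'command'),
        (10, 18, 'index'),
        (19, 22, 'play_mode'),
        (23, 27, 'query_type'),
    ]

    def map_slot_none(slot_none_ids, id2slotNone):
        """Return {key: label} for the predicted ids; id 28 ('none') matches no range."""
        out = {}
        for sid in slot_none_ids:
            for low, high, key in SLOT_NONE_RANGES:
                if low <= sid <= high:
                    out[key] = id2slotNone[sid]
        return out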
/data/code/predict/run_interact3.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/9 15:22
4 | # @Author : JJkinging
5 | # @File    : run_interact3.py
6 | import warnings
7 |
8 | import torch
9 | import json
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | from data.code.scripts.config_Interact3 import Config
13 | from data.code.predict.test_utils import load_reverse_vocab, load_vocab, collate_to_max_length
14 | from data.code.predict.test_dataset import CCFDataset
15 | from data.code.model.InteractModel_3 import InteractModel
16 | from transformers import AutoTokenizer
17 | from data.code.predict.post_process import process
18 |
19 |
20 | def run_test(model, dataloader, slot_dict):
21 |
22 | model.eval()
23 | device = model.device
24 | intent_probs_list = []
25 |
26 | intent_pred = []
27 | slotNone_pred_output = []
28 | slot_pre_output = []
29 | input_ids_list = []
30 |
31 | id2slot = {value: key for key, value in slot_dict.items()}
32 | with torch.no_grad():
33 | tqdm_batch_iterator = tqdm(dataloader)
34 | for _, batch in enumerate(tqdm_batch_iterator):
35 | input_ids, input_mask = batch
36 |
37 | input_ids = input_ids.to(device)
38 | input_mask = input_mask.byte().to(device)
39 | real_length = torch.sum(input_mask, dim=1)
40 |
41 | intent_logits, slot_none_logits, slot_logits = model(input_ids, input_mask)
42 |
43 | # intent
44 | intent_probs = torch.softmax(intent_logits, dim=-1)
45 | predict_labels = torch.argmax(intent_probs, dim=-1)
46 | predict_labels = predict_labels.cpu().numpy().tolist()
47 | intent_pred.extend(predict_labels)
48 |
49 | # slot_none
50 | slot_none_probs = torch.sigmoid(slot_none_logits)
51 | slot_none_probs = slot_none_probs > 0.5
52 | slot_none_probs = slot_none_probs.cpu().numpy()
53 | slot_none_probs = slot_none_probs.astype(int)
54 | slot_none_probs = slot_none_probs.tolist()
55 |
56 | for slot_none_id in slot_none_probs: # iterate over slot_none predictions in this batch
57 | tmp = []
58 | if 1 in slot_none_id:
59 | for idx, none_id in enumerate(slot_none_id):
60 | if none_id == 1:
61 | tmp.append(idx)
62 | else:
63 | tmp.append(28) # 28 represents 'none' in slot_none_vocab
64 | slotNone_pred_output.append(tmp) # [[4, 5], [], [], ...]
65 |
66 | # slot
67 | out_path = model.slot_predict(slot_logits, input_mask, id2slot)
68 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path] # keep the '[START]' and '[EOS]' tags
69 | slot_pre_output.extend(out_path)
70 |
71 | # input_ids
72 | input_ids = input_ids.cpu().numpy().tolist()
73 | input_ids_list.extend(input_ids)
74 |
75 | return intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list
76 |
77 |
78 | def get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start, end, key):
79 | value = tokenizer.decode(input_ids[start:end + 1])
80 | value = ''.join(value.split(' ')).replace('[UNK]', '$')
81 | s = -1
82 | e = -1
83 | ori_sen_lower = ori_sen.lower() # lowercase any English letters in the sentence
84 | is_unk = False
85 | if len(value) > 0:
86 | if value[0] == '$': # the first token is [UNK]
87 | value = value[1:] # skip it while matching, then shift the start position back by one at the end
88 | is_unk = True
89 | i = 0
90 | j = 0
91 | while i < len(ori_sen_lower) and j < len(value):
92 | if ori_sen_lower[i] == value[j]:
93 | if s == -1:
94 | s = i
95 | e = i
96 | else:
97 | e += 1
98 | i += 1
99 | j += 1
100 | elif ori_sen_lower[i] == ' ':
101 | e += 1
102 | i += 1
103 | elif value[j] == '$':
104 | e += 1
105 | i += 1
106 | j += 1
107 | elif ori_sen_lower[i] != value[j]:
108 | i -= j - 1
109 | j = 0
110 | s = -1
111 | e = -1
112 | if is_unk:
113 | s = s-1
114 | final_value = ori_sen[s:e + 1]
115 | if key in tmp_dict.keys():
116 | if tmp_dict[key] != '' and not isinstance(tmp_dict[key], list): # the key already has a value and it is not a list
117 | tmp_list = [tmp_dict[key], final_value]
118 | tmp_dict[key] = tmp_list
119 | elif tmp_dict[key] != '' and isinstance(tmp_dict[key], list):
120 | tmp_dict[key].append(final_value)
121 | else:
122 | tmp_dict[key] = final_value
123 | return tmp_dict
124 |
125 |
126 | def is_nest(slot_pre):
127 | '''
128 | Check whether the slot filling of this sample contains nested slots.
129 | :param slot_pre: e.g. ['O', 'O', 'O', 'B-age', 'I-age', 'O', 'O', 'O']
130 | :return: True if nesting is present, otherwise False
131 | '''
132 | for j, cur_tag in enumerate(slot_pre):
133 | if cur_tag.startswith('I-'):
134 | cur_tag_name = cur_tag[2:]
135 | if j < len(slot_pre)-1:
136 | if slot_pre[j+1].startswith('I-'): # the next tag also starts with 'I-'
137 | post_tag_name = slot_pre[j+1][2:]
138 | if cur_tag_name != post_tag_name: # but the two tag names differ
139 | return True
140 | return False
141 |
142 |
143 | def process_nest(tokenizer, input_ids, slot_pre, ori_sen): # handle nested slots
144 | '''
145 | :param input_ids: the original sentence as ids (wordpiece-tokenized)
146 | :param slot_pre: the predicted slot tags for this sentence
147 | :param ori_sen: the original sentence (characters)
148 | :return: dict mapping slot keys to slot values
149 | '''
150 | start_outer = -1
151 | end_outer = -1
152 | tmp_dict = {}
153 | pre_end = -1
154 | for i, tag in enumerate(slot_pre):
155 | if i <= pre_end:
156 | continue
157 | if tag.startswith('B-') and start_outer == -1: # the first 'B-'
158 | start_outer = i # tentative start position
159 | end_outer = i
160 | key_outer = tag[2:]
161 |
162 | elif tag.startswith('I-') and tag[2:] == slot_pre[i-1][2:]: # 'I-' and the tag name matches the previous slot
163 | end_outer = i
164 | if i == len(slot_pre)-1: # reached the last tag
165 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer)
166 |
167 | elif tag.startswith('O') and start_outer == -1:
168 | continue
169 | elif tag.startswith('O') and start_outer != -1:
170 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer)
171 | start_outer = -1
172 | end_outer = -1
173 | # Nesting case 1: B-region I-region I-name I-name
174 | elif tag.startswith('I-') and tag[2:] != slot_pre[i-1][2:]: # 'I-' but the tag name differs from the previous slot
175 | start_inner = start_outer
176 | end_inner = end_outer
177 | key_inner = key_outer
178 | # handle the inner slot
179 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_inner, end_inner, key_inner)
180 | # start_outer stays the same
181 | end_outer = i # end_outer becomes the current tag index
182 | key_outer = slot_pre[i][2:] # key_outer must be updated
183 | # Nesting case 2: B-name I-name B-datetime_date I-datetime_date I-name B-region I-region I-name I-name
184 | elif tag.startswith('B-'):
185 | flag = False
186 | pre_start = start_outer
187 | pre_end = end_outer
188 | pre_key = key_outer
189 | start_outer = i
190 | end_outer = i
191 | key_outer = slot_pre[i][2:]
192 | # ************************
193 | for j in range(i, len(slot_pre)):
194 | if slot_pre[j] != 'O' and slot_pre[j][2:] == pre_key:
195 | flag = True
196 | pre_end = j
197 | slot_pre[j] = 'O' # reset to 'O'
198 | # *************************
199 | if flag: # the previous slot is nested
200 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) # handle the outer slot first
201 | # from here on, add each complete slot encountered (assuming only two-level nesting)
202 | for k in range(i+1, pre_end+1):
203 | if slot_pre[k] != 'O':
204 | if slot_pre[k].startswith('I-'):
205 | end_outer = k
206 | elif slot_pre[k].startswith('B-') and start_outer == -1:
207 | start_outer = k
208 | end_outer = k
209 | key_outer = slot_pre[k][2:]
210 | elif slot_pre[k].startswith('B-') and start_outer != -1:
211 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer,
212 | key_outer) # handle the inner slot
213 | start_outer = k
214 | end_outer = k
215 | key_outer = slot_pre[k][2:]
216 | elif slot_pre[k] == 'O' and start_outer != -1:
217 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer,
218 | key_outer) # handle the inner slot
219 | start_outer = -1
220 | end_outer = -1
221 | key_outer = None
222 | elif slot_pre[k] == 'O' and start_outer == -1:
223 | continue
224 | else: # the previous slot is not nested
225 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) # handle the previous non-nested slot first
226 |
227 | return tmp_dict
228 |
229 |
230 | def find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list):
231 | fp = open('nest.txt', 'w+', encoding='utf-8')
232 | slot_list = []
233 | count = 0
234 | # ddd = 0
235 | for i, slot_ids in enumerate(slot_pre_output): # iterate over each sample
236 | # if ddd < 1111:
237 | # ddd += 1
238 | # continue
239 | tmp_dict = {}
240 | start = 0
241 | end = 0
242 | if is_nest(slot_ids[1:-1]): # nesting is actually present
243 | tmp_dict = process_nest(tokenizer, input_ids_list[i][1:-1], slot_pre_output[i][1:-1], ori_sen_list[i])
244 | slot_list.append(tmp_dict)
245 | fp.write(str(ori_sen_list[i])+'\t')
246 | fp.write(str(tmp_dict)+'\n')
247 | fp.write(' '.join(slot_pre_output[i][1:-1])+'\n')
248 | count += 1
249 | else:
250 | for j, slot in enumerate(slot_ids): # iterate over each token
251 | if slot != 'O' and slot != '[START]' and slot != '[EOS]':
252 | if slot.startswith('B-') and start == 0:
253 | start = j # start position of the slot value
254 | end = j
255 | key = slot[2:]
256 | if j == len(slot_ids)-2: # the 'B-' tag is on the last character (second to last position; the very last is EOS)
257 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
258 | break
259 | elif slot.startswith('B-') and start != 0: # the previous span was a slot
260 | # tokenizer, tmp_dict, input_ids, ori_sen, start, end, key
261 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key) # write the previous slot value
262 | start = j
263 | end = j
264 | key = slot[2:]
265 | else: # starts with 'I-'
266 | end += 1
267 | if j == len(slot_ids)-2: # the 'I-' tag is on the last character (second to last position; the very last is EOS)
268 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
269 | break
270 | else:
271 | if end == 0: # no slot found yet
272 | continue
273 | else:
274 | if slot != '[EOS]':
275 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
276 | start = 0
277 | end = 0
278 | slot_list.append(tmp_dict)
279 | print(str(count) + ' nested samples -- interact3 inference finished!')
280 | fp.close()
281 | return slot_list
282 |
283 |
284 | if __name__ == "__main__":
285 | warnings.filterwarnings("ignore")
286 | config = Config()
287 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu")
288 | with open('../../user_data/common_data/region_dic.json', 'r', encoding='utf-8') as fp:
289 | region_dict = json.load(fp)
290 | tokenizer = AutoTokenizer.from_pretrained('../../user_data/pretrained_model/ernie')
291 | print('interact3 inference started!')
292 | vocab = load_vocab('../../user_data/pretrained_model/ernie/vocab.txt')
293 | id2intent = load_reverse_vocab('../../user_data/common_data/intent_label.txt')
294 | id2slotNone = load_reverse_vocab('../../user_data/common_data/slot_none_vocab.txt')
295 | id2slot = load_reverse_vocab('../../user_data/common_data/slot_label.txt')
296 |
297 | intent_dict = load_vocab('../../user_data/common_data/intent_label.txt')
298 | slot_none_dict = load_vocab('../../user_data/common_data/slot_none_vocab.txt')
299 | slot_dict = load_vocab('../../user_data/common_data/slot_label.txt')
300 |
301 | intent_tagset_size = len(intent_dict)
302 | slot_none_tag_size = len(slot_none_dict)
303 | slot_tag_size = len(slot_dict)
304 | test_filename = '../../user_data/test_data/test_seq_in_B.txt'
305 | test_dataset = CCFDataset(test_filename, vocab, intent_dict, slot_none_dict, slot_dict, config.max_length)
306 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=config.batch_size,
307 | collate_fn=collate_to_max_length)
308 |
309 | model = InteractModel('../../user_data/pretrained_model/ernie',
310 | config.bert_hidden_size,
311 | intent_tagset_size,
312 | slot_none_tag_size,
313 | slot_tag_size,
314 | device).to(device)
315 | checkpoint = torch.load('../../user_data/output_model/InteractModel_3/Interact3_model_best.pth.tar')
316 | model.load_state_dict(checkpoint["model"])
317 |
318 | # -------------------- Testing ------------------- #
319 | print("\n",
320 | 20 * "=",
321 | "Test model on device: {}".format(device),
322 | 20 * "=")
323 |
324 | intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list = run_test(model, test_loader, slot_dict)
325 |
326 | ori_sen_list = []
327 | with open('../../user_data/test_data/test_B_final_text.json', 'r', encoding='utf-8') as fp:
328 | raw_data = json.load(fp)
329 | for filename, single_data in raw_data.items():
330 | text = single_data['text']
331 | ori_sen_list.append(text)
332 | slot_list = find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list)
333 |
334 | # Use region_dict to normalize the slot values in slot_list that need replacing (do not skip this step!)
335 | for i, slot_dict in enumerate(slot_list):
336 | for slot, slot_value in slot_dict.items():
337 | for region_key, region_list in region_dict.items():
338 | if slot_value in region_list:
339 | slot_list[i][slot] = region_key
340 | res = {}
341 | for i in range(len(intent_pred)):
342 | big_tmp = {}
343 | slot_tmp_dict = {}
344 | # intent
345 | intent = id2intent[intent_pred[i]]
346 | # slot_none
347 | for slot_none_id in slotNone_pred_output[i]:
348 | if 0 <= slot_none_id <= 9:
349 | slot_tmp_dict['command'] = id2slotNone[slot_none_id]
350 | elif 10 <= slot_none_id <= 18:
351 | slot_tmp_dict['index'] = id2slotNone[slot_none_id]
352 | elif 19 <= slot_none_id <= 22:
353 | slot_tmp_dict['play_mode'] = id2slotNone[slot_none_id]
354 | elif 23 <= slot_none_id <= 27:
355 | slot_tmp_dict['query_type'] = id2slotNone[slot_none_id]
356 | # slot
357 | slot_tmp_dict.update(slot_list[i])
358 |
359 | big_tmp['intent'] = intent
360 | big_tmp['slots'] = slot_tmp_dict
361 |
362 | length = len(str(i))
363 | o_num = 5 - length
364 | index = 'NLU' + '0'*o_num + str(i)
365 | res[index] = big_tmp
366 |
367 | with open('../../user_data/tmp_result/result_interact_3.json', 'w', encoding='utf-8') as fp:
368 | json.dump(res, fp, ensure_ascii=False)
369 |
370 | source_path = '../../user_data/tmp_result/result_interact_3.json'
371 | target_path = '../../user_data/tmp_result/result_interact_3_post.json'
372 | process(source_path, target_path)
373 |
--------------------------------------------------------------------------------
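Note: both prediction scripts build the output keys ('NLU00000', 'NLU00001', ...) by manually padding the running index with zeros (length = len(str(i)); o_num = 5 - length; index = 'NLU' + '0'*o_num + str(i)). A minimal equivalent using standard string formatting is sketched below; it assumes the same five-digit width and is shown only for illustration, not as part of the pipeline.

    def make_index(i):
        # Equivalent to: 'NLU' + '0' * (5 - len(str(i))) + str(i)  for non-negative i
        return 'NLU{:05d}'.format(i)

    assert make_index(0) == 'NLU00000'
    assert make_index(123) == 'NLU00123'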