├── data
│   ├── raw_data
│   │   └── .keep
│   ├── prediction_result
│   │   └── .keep
│   ├── user_data
│   │   └── user_data.txt
│   ├── __init__.py
│   └── code
│       ├── __init__.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── JointBertModel.py
│       │   ├── InteractModel_1.py
│       │   ├── InteractModel_3.py
│       │   └── torchcrf.py
│       ├── predict
│       │   ├── __init__.py
│       │   ├── post_process.py
│       │   ├── test_dataset.py
│       │   ├── test_utils.py
│       │   ├── nest.txt
│       │   ├── integration.py
│       │   ├── run_JointBert.py
│       │   ├── run_interact1.py
│       │   └── run_interact3.py
│       ├── preprocess
│       │   ├── __init__.py
│       │   ├── generate_intent.py
│       │   ├── extract_intent_sample.py
│       │   ├── process_other.py
│       │   ├── slot_sorted.py
│       │   ├── split_train_dev.py
│       │   ├── extend_tv_sample.py
│       │   ├── analysis.py
│       │   ├── rectify.py
│       │   ├── extend_audio_sample.py
│       │   └── process_rawdata.py
│       └── scripts
│           ├── __init__.py
│           ├── config_jointBert.py
│           ├── config_Interact3.py
│           ├── config_Interact1.py
│           ├── build_vocab.py
│           ├── dataset.py
│           ├── train_interact1.py
│           ├── train_interact3.py
│           ├── train_jointBert.py
│           └── utils.py
├── image
│   ├── ccir-image.txt
│   ├── readme_images
│   │   ├── image-20210929212611984.png
│   │   ├── image-20210929212740433.png
│   │   ├── image-20210929212830068.png
│   │   └── image-20210929213002909.png
│   ├── run_infer.sh
│   ├── run.sh
│   └── README.md
├── 【智能人机交互自然语言理解】SCU-JJkinging-说明论文.docx
├── 【智能人机交互自然语言理解】SCU-JJkinging-分享PPT.pptx
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── .gitignore
│   ├── modules.xml
│   ├── deployment.xml
│   └── CCIR-Cup.iml
└── README.md

/data/raw_data/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/prediction_result/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/image/ccir-image.txt:
--------------------------------------------------------------------------------
Link: https://pan.baidu.com/s/1oUpFVgHvqcvsYYD4w6yHnQ
Extraction code: 78pw
--------------------------------------------------------------------------------
/data/user_data/user_data.txt:
--------------------------------------------------------------------------------
Link: https://pan.baidu.com/s/1uwy8sO3KKEp-OdrI4C1H5Q
Extraction code: 0i1x
--------------------------------------------------------------------------------
/【智能人机交互自然语言理解】SCU-JJkinging-说明论文.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/【智能人机交互自然语言理解】SCU-JJkinging-说明论文.docx
--------------------------------------------------------------------------------
/【智能人机交互自然语言理解】SCU-JJkinging-分享PPT.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/【智能人机交互自然语言理解】SCU-JJkinging-分享PPT.pptx
--------------------------------------------------------------------------------
/image/readme_images/image-20210929212611984.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/image/readme_images/image-20210929212611984.png
--------------------------------------------------------------------------------
/image/readme_images/image-20210929212740433.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/image/readme_images/image-20210929212740433.png
-------------------------------------------------------------------------------- /image/readme_images/image-20210929212830068.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/image/readme_images/image-20210929212830068.png -------------------------------------------------------------------------------- /image/readme_images/image-20210929213002909.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCU-JJkinging/CCIR-Cup/HEAD/image/readme_images/image-20210929213002909.png -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/27 14:50 4 | # @Author : JJkinging 5 | # @File : __init__.py 6 | -------------------------------------------------------------------------------- /data/code/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/27 14:50 4 | # @Author : JJkinging 5 | # @File : __init__.py 6 | -------------------------------------------------------------------------------- /data/code/model/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/27 14:45 4 | # @Author : JJkinging 5 | # @File : __init__.py.py 6 | -------------------------------------------------------------------------------- /data/code/predict/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/9 16:00 4 | # @Author : JJkinging 5 | # @File : __init__.py.py 6 | -------------------------------------------------------------------------------- /data/code/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/30 11:58 4 | # @Author : JJkinging 5 | # @File : __init__.py.py 6 | -------------------------------------------------------------------------------- /data/code/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/27 14:37 4 | # @Author : JJkinging 5 | # @File : __init__.py.py 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | 
/../../../../:\python_project\CCIR-Cup\.idea/dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /image/run_infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /ccir/data/code/predict 4 | python run_trained_JointBert.py 5 | python run_trained_interact1.py 6 | python run_trained_interact3.py 7 | python integration.py 8 | echo '模型已全部推理完成,结果result.json已保存在prediction_result文件夹下' 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /image/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'start train the first model--JointBert' 4 | cd /ccir/data/code/scripts 5 | python train_jointBert.py 6 | 7 | echo 'start train the second model--InteractModel1' 8 | python train_interact1.py 9 | 10 | echo 'start train the third model--InteractModel3' 11 | python train_interact3.py 12 | 13 | cd ../predict 14 | python run_JointBert.py 15 | python run_interact1.py 16 | python run_interact3.py 17 | python integration.py 18 | echo '模型已全部推理完成,结果result.json已保存在prediction_result文件夹下' 19 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /data/code/preprocess/generate_intent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/16 12:28 4 | # @Author : JJkinging 5 | # @File : generate_intent.py 6 | import json 7 | 8 | with open('../small_sample/new_B/train_final_clear_del.json', 'r', encoding='utf-8') as fp: 9 | raw_data = json.load(fp) 10 | 11 | intent_write = open('../small_sample/new_B/train_intent_label.txt', 'w+', encoding='utf-8') 12 | for filename, single_data in raw_data.items(): 13 | intent = single_data['intent'] 14 | intent_write.write(intent+'\n') 15 | 16 | intent_write.close() 17 | -------------------------------------------------------------------------------- /.idea/CCIR-Cup.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /data/code/preprocess/extract_intent_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/15 21:07 4 | # @Author : JJkinging 5 | # @File : extract_small_sample.py 6 | import json 7 | 8 | with open('../raw_data/train_8.json', 'r', encoding='utf-8') as fp: 9 | raw_data = json.load(fp) 10 | 11 | res = {} 12 | for filename, single_data in raw_data.items(): 13 | intent = single_data['intent'] 14 | if intent == 'TVProgram-Play': 15 | res[filename] = single_data 16 | 17 | with open('../extend_data/same_intent/ori_data/TVProgram-Play.json', 'w', encoding='utf-8') as fp: 18 | json.dump(res, fp, ensure_ascii=False) 19 | 
-------------------------------------------------------------------------------- /data/code/preprocess/process_other.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/10 22:22 4 | # @Author : JJkinging 5 | # @File : process_other.py 6 | import json 7 | 8 | with open('../../dataset/final_data/other_2.txt', 'r', encoding='utf-8') as fp: 9 | data = fp.readlines() 10 | data = [item.split(' ')[0].strip('\n') for item in data] 11 | 12 | 13 | res = {} 14 | for i, sen in enumerate(data): 15 | tem_dict = {} 16 | tem_dict['text'] = sen 17 | tem_dict['intent'] = 'Other' 18 | tem_dict['slots'] = {} 19 | 20 | lens = len(str(i)) 21 | o_nums = 5 - lens 22 | filename = 'NLU' + '0' * o_nums + str(i) 23 | res[filename] = tem_dict 24 | 25 | 26 | with open('../../dataset/small_sample/data/other_2.json', 'w', encoding='utf-8') as fp: 27 | json.dump(res, fp, ensure_ascii=False) 28 | -------------------------------------------------------------------------------- /data/code/preprocess/slot_sorted.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/12 18:52 4 | # @Author : JJkinging 5 | # @File : slot_sorted.py 6 | 7 | '''对train.json的数据按slot的value的长度从大到小排序''' 8 | import json 9 | 10 | with open('../small_sample/new_B/train_final_clear.json', 'r', encoding='utf-8') as fp: 11 | raw_data = json.load(fp) 12 | res = {} 13 | for filename, single_data in raw_data.items(): 14 | tmp_dict = {} 15 | tmp_dict['text'] = single_data['text'] 16 | tmp_dict['intent'] = single_data['intent'] 17 | slots = single_data['slots'] 18 | dic = {} 19 | tmp_tuple = sorted(slots.items(), key=lambda x: len(x[1]), reverse=True) 20 | for tuple_data in tmp_tuple: 21 | dic[tuple_data[0]] = tuple_data[1] 22 | tmp_dict['slots'] = dic 23 | res[filename] = tmp_dict 24 | 25 | with open('../small_sample/new_B/train_final_clear_sorted.json', 'w', encoding='utf-8') as fp: 26 | json.dump(res, fp, ensure_ascii=False) 27 | -------------------------------------------------------------------------------- /data/code/preprocess/split_train_dev.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 16:10 4 | # @Author : JJkinging 5 | # @File : split_train_dev.py 6 | import json 7 | 8 | 9 | def split_train_dev(ratio=0.8): 10 | with open('../raw_data/train_slot_sorted.json', 'r', encoding='utf-8') as fp: 11 | raw_data = json.load(fp) 12 | 13 | length = len(raw_data) 14 | 15 | train_data = {} 16 | dev_data = {} 17 | 18 | count = 0 19 | for key, value in raw_data.items(): 20 | if count < length*ratio: 21 | train_data[key] = value 22 | else: 23 | dev_data[key] = value 24 | count += 1 25 | 26 | # 写训练集 27 | with open('../raw_data/train_8.json', 'w', encoding='utf-8') as fp: 28 | json.dump(train_data, fp, ensure_ascii=False) 29 | # 写验证集 30 | with open('../raw_data/dev_2.json', 'w', encoding='utf-8') as fp: 31 | json.dump(dev_data, fp, ensure_ascii=False) 32 | 33 | 34 | if __name__ == "__main__": 35 | split_train_dev(0.8) 36 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 22 | -------------------------------------------------------------------------------- /image/README.md: 
--------------------------------------------------------------------------------
Local training and online training can differ slightly, so two ways of reproducing the results are provided:
* Online training + online inference: run.sh
* Local training + online inference: run_infer.sh

Preparation:
First, download user_data.zip from the Baidu Netdisk link in ccir/data/user_data/user_data.txt, unzip it, and replace the user_data directory with the extracted one.
Then download the Docker image file ccir-image.tar from the Baidu Netdisk link in ccir/image/ccir-image.txt and put it in the ccir/image directory.

Go into the ccir/image directory and load the ccir-image.tar image:
```
sudo docker load -i ccir-image.tar
```
#### 1. Online training + online inference

My project is named ccir.

This option retrains the models online and then runs inference. Three models are trained in total, which takes roughly 7 hours on a T4 GPU.
Run the image with the project directory mounted:

```
nvidia-docker run -v /home/seatrend/jinxiang/ccir/:/ccir ccir-image sh /ccir/image/run.sh
```

where

```
/home/seatrend/jinxiang/ccir/
```

is the absolute path of my local ccir project. When reproducing online, replace it with the absolute path of ccir on that machine; it is mounted at /ccir inside the image, and the image name is ccir-image.

#### 2. Local training + online inference

To reproduce directly from the locally trained models, run:

```
nvidia-docker run -v /home/seatrend/jinxiang/ccir/:/ccir ccir-image sh /ccir/image/run_infer.sh
```
where

```
/home/seatrend/jinxiang/ccir/
```

is the absolute path of my local ccir project. When reproducing online, replace it with the absolute path of ccir on that machine; it is mounted at /ccir inside the image, and the image name is ccir-image.

The three locally trained models are saved in:

```
ccir/data/user_data/output_model/JointBert/trained_model
```

```
ccir/data/user_data/output_model/InteractModel_1/trained_model
```

```
ccir/data/user_data/output_model/InteractModel_3/trained_model
```

**Note:** run.sh and run_infer.sh may end up with Windows line endings if they are opened on Windows. If running them fails, open the script in vim, enter :set ff=unix in command mode to convert it to Unix format, then save, quit, and run again.
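The same line-ending conversion can also be done from a shell. A minimal sketch, assuming sed (or dos2unix) is available on the host machine; neither tool is guaranteed to be installed in the provided image:

```
# Strip Windows carriage returns (CRLF -> LF) from the entry scripts
sed -i 's/\r$//' run.sh run_infer.sh
# Or, if dos2unix is installed:
# dos2unix run.sh run_infer.sh
```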
62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /data/code/predict/post_process.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/14 19:58 4 | # @Author : JJkinging 5 | # @File : post_process.py 6 | import json 7 | 8 | 9 | def process(source_path, target_path): 10 | ''' 11 | 该函数作用:对模型预测结果进行纠正,即把不属于某一类intent的槽值删除,举例来说: 12 | 比如我在一条测试数据中预测出其intent = FilmTele-Play, 然后其槽值预测中出现了"notes"这个槽标签, 13 | 这与我之前统计的哪些槽标签只出现在哪些意图中不符合(即训练数据中FileTele-Play这个意图不可能出现"notes"这个槽标签), 14 | 所以该函数就把"notes"这个槽位和槽值删除掉。 15 | :param source_path: 16 | :param target_path: 17 | :return: 18 | ''' 19 | with open(source_path, 'r', encoding='utf-8') as fp: 20 | ori_data = json.load(fp) 21 | 22 | with open('../../user_data/common_data/intent_slot_mapping.json', 'r', encoding='utf-8') as fp: 23 | intent_slot_mapping = json.load(fp) 24 | 25 | mapping_keys = list(intent_slot_mapping.keys()) 26 | for filename, single_data in ori_data.items(): 27 | intent = single_data['intent'] 28 | slots = single_data['slots'] 29 | slot_keys = list(single_data['slots'].keys()) 30 | for item in mapping_keys: 31 | if intent == item: 32 | all_tag = intent_slot_mapping[item] # ["name", "tag", "artist", "region", "play_setting", "age"] 33 | for key in slot_keys: # 检查slots结果的每一个槽位是否合理 34 | if key not in all_tag: 35 | ori_data[filename]['slots'].pop(key) # 如果不合理,则删除该条槽 36 | 37 | for key, values in slots.items(): 38 | if isinstance(values, list): 39 | tmp = list(set(values)) 40 | if len(tmp) == 1: 41 | tmp = tmp[0] 42 | ori_data[filename]['slots'][key] = tmp 43 | with open(target_path, 'w', encoding='utf-8') as fp: 44 | json.dump(ori_data, fp, ensure_ascii=False) 45 | -------------------------------------------------------------------------------- /data/code/predict/test_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 12:39 4 | # @Author : JJkinging 5 | # @File : utils.py 6 | import json 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class CCFDataset(Dataset): 11 | def __init__(self, filename, vocab, intent_dict, slot_none_dict, slot_dict, max_length=512): 12 | ''' 13 | :param filename:读取数据文件名,例如:train_seq_in.txt 14 | :param slot_none_dict: slot_none的字典 15 | :param slot_dict: slot_label的字典 16 | :param vocab: 词表,例如:bert的vocab.txt 17 | :param intent_dict: intent2id的字典 18 | :param max_length: 单句最大长度 19 | ''' 20 | self.filename = filename 21 | self.vocab = vocab 22 | self.intent_dict = intent_dict 23 | self.slot_none_dict = slot_none_dict 24 | self.slot_dict = slot_dict 25 | self.max_length = max_length 26 | 27 | self.result = [] 28 | 29 | # 读取数据 30 | with open(self.filename, 'r', encoding='utf-8') as fp: 31 | sen_data = fp.readlines() 32 | sen_data = [item.strip('\n') for item in sen_data] # 删除句子结尾的换行符('\n') 33 | 34 | for utterance in sen_data: 35 | utterance = utterance.split(' ') # str变list 36 | # 最大长度检验 37 | if len(utterance) > self.max_length-2: 38 | utterance = utterance[:max_length] 39 | 40 | # input_ids 41 | utterance = ['[CLS]'] + utterance + ['[SEP]'] 42 | input_ids = [int(self.vocab[i]) for i in utterance] 43 | 44 | length = len(input_ids) 45 | 46 | # input_mask 47 | input_mask = [1] * len(input_ids) 48 | 49 | self.result.append((input_ids, input_mask)) 50 | 51 | def __len__(self): 52 | return len(self.result) 53 | 54 | def __getitem__(self, index): 55 | 
input_ids, input_mask = self.result[index] 56 | 57 | return input_ids, input_mask 58 | -------------------------------------------------------------------------------- /data/code/predict/test_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/9 15:57 4 | # @Author : JJkinging 5 | # @File : utils.py 6 | import torch 7 | 8 | def load_vocab(vocab_file): 9 | '''construct word2id''' 10 | vocab = {} 11 | index = 0 12 | with open(vocab_file, 'r', encoding='utf-8') as fp: 13 | while True: 14 | token = fp.readline() 15 | if not token: 16 | break 17 | token = token.strip() # 删除空白符 18 | vocab[token] = index 19 | index += 1 20 | return vocab 21 | 22 | 23 | def load_reverse_vocab(vocab_file): 24 | '''construct id2word''' 25 | vocab = {} 26 | index = 0 27 | with open(vocab_file, 'r', encoding='utf-8') as fp: 28 | while True: 29 | token = fp.readline() 30 | if not token: 31 | break 32 | token = token.strip() # 删除空白符 33 | vocab[index] = token 34 | index += 1 35 | return vocab 36 | 37 | 38 | def collate_to_max_length(batch): 39 | # input_ids, input_mask 40 | batch_size = len(batch) 41 | input_ids_list = [] 42 | input_mask_list = [] 43 | for single_data in batch: 44 | input_ids_list.append(single_data[0]) 45 | input_mask_list.append(single_data[1]) 46 | 47 | max_length = max([len(item) for item in input_ids_list]) 48 | 49 | output = [torch.full([batch_size, max_length], 50 | fill_value=0, 51 | dtype=torch.long), 52 | torch.full([batch_size, max_length], 53 | fill_value=0, 54 | dtype=torch.long) 55 | ] 56 | 57 | for i in range(batch_size): 58 | output[0][i][0:len(input_ids_list[i])] = torch.LongTensor(input_ids_list[i]) 59 | output[1][i][0:len(input_mask_list[i])] = torch.LongTensor(input_mask_list[i]) 60 | 61 | return output # (input_ids, input_mask) 62 | -------------------------------------------------------------------------------- /data/code/scripts/config_jointBert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 15:13 4 | # @Author : JJkinging 5 | # @File : config.py 6 | class Config(object): 7 | '''配置类''' 8 | 9 | def __init__(self): 10 | self.train_file = '../../user_data/train_data/train_seq_in.txt' 11 | self.dev_file = '../../user_data/dev_data/dev_seq_in.txt' 12 | self.test_file = '../../user_data/test_data/test_seq_in.txt' 13 | self.intent_label_file = '../../user_data/common_data/intent_label.txt' 14 | self.vocab_file = '../../user_data/pretrained_model/ernie/vocab.txt' 15 | self.train_intent_file = '../../user_data/train_data/train_intent_label.txt' 16 | self.dev_intent_file = '../../user_data/dev_data/dev_intent_label.txt' 17 | self.max_length = 512 18 | self.batch_size = 16 19 | self.test_batch_size = 32 20 | self.bert_model_path = '../../user_data/pretrained_model/ernie' 21 | self.checkpoint = None # '../../user_data/output_model/JointBert/model_14.pth.tar' 22 | self.use_gpu = True 23 | self.cuda = "cuda:0" 24 | self.attention_dropout = 0.1 25 | self.bert_hidden_size = 768 26 | self.embedding_dim = 300 27 | self.lr = 5e-5 # 5e-5 28 | self.crf_lr = 5e-2 # 5e-2 29 | self.weight_decay = 0.0 30 | self.epochs = 60 31 | self.max_grad_norm = 4 32 | self.patience = 60 33 | self.target_dir = '../../user_data/output_model/JointBert' 34 | self.slot_none_vocab = '../../user_data/common_data/slot_none_vocab.txt' 35 | self.slot_label = '../../user_data/common_data/slot_label.txt' 
36 | self.train_slot_filename = '../../user_data/train_data/train_seq_out.txt' 37 | self.dev_slot_filename = '../../user_data/dev_data/dev_seq_out.txt' 38 | self.train_slot_none_filename = '../../user_data/train_data/train_slot_none.txt' 39 | self.dev_slot_none_filename = '../../user_data/dev_data/dev_slot_none.txt' 40 | 41 | def update(self, **kwargs): 42 | for k, v in kwargs.items(): 43 | setattr(self, k, v) 44 | 45 | def __str__(self): 46 | return '\n'.join(['%s:%s' % item for item in self.__dict__.items()]) 47 | 48 | 49 | if __name__ == '__main__': 50 | con = Config() 51 | con.update(gpu=8) 52 | print(con.gpu) 53 | print(con) 54 | -------------------------------------------------------------------------------- /data/code/scripts/config_Interact3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 15:13 4 | # @Author : JJkinging 5 | # @File : config.py 6 | class Config(object): 7 | '''配置类''' 8 | 9 | def __init__(self): 10 | self.train_file = '../../user_data/train_data/train_seq_in.txt' 11 | self.dev_file = '../../user_data/dev_data/dev_seq_in.txt' 12 | self.test_file = '../../user_data/test_data/test_seq_in.txt' 13 | self.intent_label_file = '../../user_data/common_data/intent_label.txt' 14 | self.vocab_file = '../../user_data/pretrained_model/ernie/vocab.txt' 15 | self.train_intent_file = '../../user_data/train_data/train_intent_label.txt' 16 | self.dev_intent_file = '../../user_data/dev_data/dev_intent_label.txt' 17 | self.max_length = 512 18 | self.batch_size = 16 19 | self.test_batch_size = 32 20 | self.bert_model_path = '../../user_data/pretrained_model/ernie' 21 | self.checkpoint = None # '../../user_data/output_model/InteractModel_3/model_27.pth.tar' 22 | self.use_gpu = True 23 | self.cuda = "cuda:0" 24 | self.attention_dropout = 0.1 25 | self.bert_hidden_size = 768 26 | self.embedding_dim = 300 27 | self.lr = 5e-5 # 5e-5 28 | self.crf_lr = 5e-2 # 5e-2 29 | self.weight_decay = 0.0 30 | self.epochs = 60 31 | self.max_grad_norm = 4 32 | self.patience = 60 33 | self.target_dir = '../../user_data/output_model/InteractModel_3' 34 | self.slot_none_vocab = '../../user_data/common_data/slot_none_vocab.txt' 35 | self.slot_label = '../../user_data/common_data/slot_label.txt' 36 | self.train_slot_filename = '../../user_data/train_data/train_seq_out.txt' 37 | self.dev_slot_filename = '../../user_data/dev_data/dev_seq_out.txt' 38 | self.train_slot_none_filename = '../../user_data/train_data/train_slot_none.txt' 39 | self.dev_slot_none_filename = '../../user_data/dev_data/dev_slot_none.txt' 40 | 41 | def update(self, **kwargs): 42 | for k, v in kwargs.items(): 43 | setattr(self, k, v) 44 | 45 | def __str__(self): 46 | return '\n'.join(['%s:%s' % item for item in self.__dict__.items()]) 47 | 48 | 49 | if __name__ == '__main__': 50 | con = Config() 51 | con.update(gpu=8) 52 | print(con.gpu) 53 | print(con) 54 | -------------------------------------------------------------------------------- /data/code/scripts/config_Interact1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 15:13 4 | # @Author : JJkinging 5 | # @File : config.py 6 | class Config(object): 7 | '''配置类''' 8 | 9 | def __init__(self): 10 | self.train_file = '../../user_data/train_data/train_seq_in.txt' 11 | self.dev_file = '../../user_data/dev_data/dev_seq_in.txt' 12 | self.test_file = 
'../../user_data/test_data/test_seq_in.txt' 13 | self.intent_label_file = '../../user_data/common_data/intent_label.txt' 14 | self.vocab_file = '../../user_data/pretrained_model/ernie/vocab.txt' 15 | self.train_intent_file = '../../user_data/train_data/train_intent_label.txt' 16 | self.dev_intent_file = '../../user_data/dev_data/dev_intent_label.txt' 17 | self.max_length = 512 18 | self.batch_size = 16 19 | self.test_batch_size = 32 20 | self.bert_model_path = '../../user_data/pretrained_model/ernie' 21 | self.checkpoint = None # '../../user_data/output_model/InteractModel_1/trained_model/model_27.pth.tar' 22 | self.use_gpu = True 23 | self.cuda = "cuda:0" 24 | self.attention_dropout = 0.1 25 | self.bert_hidden_size = 768 26 | self.embedding_dim = 300 27 | self.lr = 5e-5 # 5e-5 28 | self.crf_lr = 5e-2 # 5e-2 29 | self.weight_decay = 0.0 30 | self.epochs = 60 31 | self.max_grad_norm = 4 32 | self.patience = 60 33 | self.target_dir = '../../user_data/output_model/InteractModel_1' 34 | self.slot_none_vocab = '../../user_data/common_data/slot_none_vocab.txt' 35 | self.slot_label = '../../user_data/common_data/slot_label.txt' 36 | self.train_slot_filename = '../../user_data/train_data/train_seq_out.txt' 37 | self.dev_slot_filename = '../../user_data/dev_data/dev_seq_out.txt' 38 | self.train_slot_none_filename = '../../user_data/train_data/train_slot_none.txt' 39 | self.dev_slot_none_filename = '../../user_data/dev_data/dev_slot_none.txt' 40 | 41 | def update(self, **kwargs): 42 | for k, v in kwargs.items(): 43 | setattr(self, k, v) 44 | 45 | def __str__(self): 46 | return '\n'.join(['%s:%s' % item for item in self.__dict__.items()]) 47 | 48 | 49 | if __name__ == '__main__': 50 | con = Config() 51 | con.update(gpu=8) 52 | print(con.gpu) 53 | print(con) 54 | -------------------------------------------------------------------------------- /data/code/preprocess/extend_tv_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/16 11:18 4 | # @Author : JJkinging 5 | # @File : extend_tv_sample.py 6 | '''数据增强: 扩充TVProgram-Play小样本数据集''' 7 | import json 8 | import random 9 | random.seed(1000) 10 | 11 | with open('../../small_sample/same_intent/TVProgram-Play.json', 'r', encoding='utf-8') as fp: 12 | ori_data = json.load(fp) 13 | 14 | with open('../../small_sample/tv_entity_dic.json', 'r', encoding='utf-8') as fp: 15 | tv_dic = json.load(fp) 16 | 17 | res = {} 18 | idx = 0 19 | for filename, single_data in ori_data.items(): 20 | text = single_data['text'] 21 | slots = single_data['slots'] 22 | for _ in range(30): # 每条扩充99条数据 23 | tmp_dic = {} 24 | slot_dic = {} 25 | length = len(str(idx)) 26 | o_num = 5 - length 27 | new_filename = 'NLU' + '0'*o_num + str(idx) 28 | idx += 1 29 | 30 | new_text = text 31 | for prefix in tv_dic['prefix']: 32 | if text[:-3].find(prefix) != -1: 33 | ran_i = random.randint(0, len(tv_dic['prefix'])-1) 34 | new_text = text.replace(prefix, tv_dic['prefix'][ran_i]) # 随机找一个prefix替换 35 | break 36 | else: 37 | ran_i = random.randint(0, len(tv_dic['prefix'])-1) 38 | new_text = tv_dic['prefix'][ran_i] + text 39 | if 'name' in slots.keys(): 40 | ran_i = random.randint(0, len(tv_dic['name'])-1) 41 | new_text = new_text.replace(slots['name'], tv_dic['name'][ran_i]) 42 | slot_dic['name'] = tv_dic['name'][ran_i] # 随机替换name 43 | if 'channel' in slots.keys(): 44 | ran_i = random.randint(0, len(tv_dic['channel']) - 1) 45 | new_text = new_text.replace(slots['channel'], 
tv_dic['channel'][ran_i]) 46 | slot_dic['channel'] = tv_dic['channel'][ran_i] # 随机替换name 47 | if 'datetime_date' in slots.keys(): 48 | ran_i = random.randint(0, len(tv_dic['datetime_date']) - 1) 49 | new_text = new_text.replace(slots['datetime_date'], tv_dic['datetime_date'][ran_i]) 50 | slot_dic['datetime_date'] = tv_dic['datetime_date'][ran_i] # 随机替换language 51 | if 'datetime_time' in slots.keys(): 52 | ran_i = random.randint(0, len(tv_dic['datetime_time']) - 1) 53 | new_text = new_text.replace(slots['datetime_time'], tv_dic['datetime_time'][ran_i]) 54 | slot_dic['datetime_time'] = tv_dic['datetime_time'][ran_i] # 随机替换language 55 | tmp_dic['text'] = new_text 56 | tmp_dic['intent'] = "TVProgram-Play" 57 | tmp_dic['slots'] = slot_dic 58 | res[new_filename] = tmp_dic 59 | 60 | 61 | print(res) 62 | print(len(res)) 63 | 64 | with open('../../extend_data/data/extend_tv.json', 'w', encoding='utf-8') as fp: 65 | json.dump(res, fp, ensure_ascii=False) 66 | 67 | -------------------------------------------------------------------------------- /data/code/preprocess/analysis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 17:07 4 | # @Author : JJkinging 5 | # @File : analysis.py 6 | import json 7 | 8 | with open('../raw_data/train.json', 'r', encoding='utf-8') as fp: 9 | data = json.load(fp) 10 | 11 | FilmTele_Play = set() 12 | Audio_Play = set() 13 | Radio_Listen = set() 14 | TVProgram_Play = set() 15 | Travel_Query = set() 16 | Music_Play = set() 17 | HomeAppliance_Control = set() 18 | Calendar_Query = set() 19 | Alarm_Update = set() 20 | Video_Play = set() 21 | Weather_Query = set() 22 | 23 | for filename, single_data in data.items(): 24 | intent = single_data['intent'] 25 | slots = single_data['slots'] 26 | if intent == 'FilmTele-Play': 27 | FilmTele_Play.update(slots.keys()) 28 | elif intent == 'Audio-Play': 29 | Audio_Play.update(slots.keys()) 30 | elif intent == 'Radio-Listen': 31 | Radio_Listen.update(slots.keys()) 32 | elif intent == 'TVProgram-Play': 33 | TVProgram_Play.update(slots.keys()) 34 | elif intent == 'Travel-Query': 35 | Travel_Query.update(slots.keys()) 36 | elif intent == 'Music-Play': 37 | Music_Play.update(slots.keys()) 38 | elif intent == 'HomeAppliance-Control': 39 | HomeAppliance_Control.update(slots.keys()) 40 | elif intent == 'Calendar-Query': 41 | Calendar_Query.update(slots.keys()) 42 | elif intent == 'Alarm-Update': 43 | Alarm_Update.update(slots.keys()) 44 | elif intent == 'Video-Play': 45 | Video_Play.update(slots.keys()) 46 | elif intent == 'Weather-Query': 47 | Weather_Query.update(slots.keys()) 48 | 49 | with open('../intent_classify/FilmTele_Play.txt', 'w', encoding='utf-8') as fp: 50 | fp.write(str(list(FilmTele_Play))) 51 | with open('../intent_classify/Audio_Play.txt', 'w', encoding='utf-8') as fp: 52 | fp.write(str(list(Audio_Play))) 53 | with open('../intent_classify/Radio_Listen.txt', 'w', encoding='utf-8') as fp: 54 | fp.write(str(list(Radio_Listen))) 55 | with open('../intent_classify/TVProgram_Play.txt', 'w', encoding='utf-8') as fp: 56 | fp.write(str(list(TVProgram_Play))) 57 | with open('../intent_classify/Travel_Query.txt', 'w', encoding='utf-8') as fp: 58 | fp.write(str(list(Travel_Query))) 59 | with open('../intent_classify/Music_Play.txt', 'w', encoding='utf-8') as fp: 60 | fp.write(str(list(Music_Play))) 61 | with open('../intent_classify/HomeAppliance_Control.txt', 'w', encoding='utf-8') as fp: 62 | 
fp.write(str(list(HomeAppliance_Control))) 63 | with open('../intent_classify/Calendar_Query.txt', 'w', encoding='utf-8') as fp: 64 | fp.write(str(list(Calendar_Query))) 65 | with open('../intent_classify/Alarm_Update.txt', 'w', encoding='utf-8') as fp: 66 | fp.write(str(list(Alarm_Update))) 67 | with open('../intent_classify/Video_Play.txt', 'w', encoding='utf-8') as fp: 68 | fp.write(str(list(Video_Play))) 69 | with open('../intent_classify/Weather_Query.txt', 'w', encoding='utf-8') as fp: 70 | fp.write(str(list(Weather_Query))) 71 | -------------------------------------------------------------------------------- /data/code/model/JointBertModel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 14:40 4 | # @Author : JJkinging 5 | # @File : model.py 6 | import torch.nn as nn 7 | from transformers import BertModel 8 | from data.code.model.torchcrf import CRF 9 | import torch 10 | 11 | 12 | class JointBertModel(nn.Module): 13 | def __init__(self, bert_model_path, bert_hidden_size, intent_tag_size, slot_none_tag_size, slot_tag_size, device): 14 | super(JointBertModel, self).__init__() 15 | self.bert_model_path = bert_model_path 16 | self.bert_hidden_size = bert_hidden_size 17 | self.intent_tag_size = intent_tag_size 18 | self.slot_none_tag_size = slot_none_tag_size 19 | self.slot_tag_size = slot_tag_size 20 | self.device = device 21 | self.bert = BertModel.from_pretrained(self.bert_model_path) 22 | self.CRF = CRF(num_tags=self.slot_tag_size, batch_first=True) 23 | 24 | self.intent_classification = nn.Linear(self.bert_hidden_size, self.intent_tag_size) 25 | self.slot_none_classification = nn.Linear(self.bert_hidden_size, self.slot_none_tag_size) 26 | self.slot_classification = nn.Linear(self.bert_hidden_size, self.slot_tag_size) 27 | self.Dropout = nn.Dropout(p=0.5) 28 | 29 | def forward(self, input_ids, input_mask): 30 | batch_size = input_ids.size(0) 31 | seq_len = input_ids.size(1) 32 | utter_encoding = self.bert(input_ids, input_mask) 33 | sequence_output = utter_encoding[0] 34 | pooled_output = utter_encoding[1] 35 | 36 | pooled_output = self.Dropout(pooled_output) 37 | sequence_output = self.Dropout(sequence_output) 38 | 39 | intent_logits = self.intent_classification(pooled_output) # [batch_size, slot_tag_size] 40 | slot_logits = self.slot_classification(sequence_output) 41 | slot_none_logits = self.slot_none_classification(pooled_output) 42 | 43 | return intent_logits, slot_none_logits, slot_logits 44 | 45 | def slot_loss(self, feats, slot_ids, mask): 46 | ''' 做训练时用 47 | :param feats: the output of BiLSTM and Liner 48 | :param slot_ids: 49 | :param mask: 50 | :return: 51 | ''' 52 | feats = feats.to(self.device) 53 | slot_ids = slot_ids.to(self.device) 54 | mask = mask.to(self.device) 55 | loss_value = self.CRF(emissions=feats, 56 | tags=slot_ids, 57 | mask=mask, 58 | reduction='mean') 59 | return -loss_value 60 | 61 | def slot_predict(self, feats, mask, id2slot): 62 | feats = feats.to(self.device) 63 | mask = mask.to(self.device) 64 | slot2id = {value: key for key, value in id2slot.items()} 65 | # 做验证和测试时用 66 | out_path = self.CRF.decode(emissions=feats, mask=mask) 67 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path] 68 | for out in out_path: 69 | for i, tag in enumerate(out): # tag为O、B-*、I-* 等等 70 | if tag.startswith('I-'): # 当前tag为I-开头 71 | if i == 0: # 0位置应该是[START] 72 | out[i] = '[START]' 73 | elif out[i-1] == 'O' or out[i-1] == '[START]': # 
但是前一个tag不是以B-开头的 74 | out[i] = id2slot[slot2id[tag]-1] # 将其纠正为对应的B-开头的tag 75 | 76 | out_path = [[slot2id[idx] for idx in one_data] for one_data in out_path] 77 | 78 | return out_path 79 | -------------------------------------------------------------------------------- /data/code/preprocess/rectify.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/19 10:39 4 | # @Author : JJkinging 5 | # @File : rectify.py 6 | import json 7 | 8 | 9 | def rectify(source_path, target_path): 10 | with open(source_path, 'r', encoding='utf-8') as fp: 11 | ori_data = json.load(fp) 12 | 13 | for filename, single_data in ori_data.items(): 14 | text = single_data['text'] 15 | slots = single_data['slots'] 16 | intent = single_data['intent'] 17 | if intent == 'Alarm-Update': 18 | if 'datetime_time' in slots.keys(): 19 | idx_1 = text.find(':') 20 | idx_2 = text[idx_1+1:].find(':') 21 | 22 | if idx_1 != -1: 23 | start_idx = idx_1 24 | end_idx = idx_1 25 | for i in range(1, 3): # 找到第一个时间的起始index 26 | if text[start_idx - 1] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']: 27 | start_idx -= 1 28 | if text[end_idx + 1] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']: 29 | end_idx += 1 30 | if idx_2 != -1: # 说明有两个时间 31 | start_idx_2 = idx_2 + idx_1 + 1 # 修正一下 32 | end_idx_2 = start_idx_2 33 | for i in range(1, 3): # 找到第二个时间的起始index 34 | if text[start_idx_2 - 1] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']: 35 | start_idx_2 -= 1 36 | if text[end_idx_2 + 1] in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']: 37 | end_idx_2 += 1 38 | for item in slots['datetime_time']: 39 | if item in text[start_idx: end_idx+1]: 40 | slots['datetime_time'].remove(item) 41 | slots['datetime_time'].append(text[start_idx: end_idx+1]) 42 | elif item in text[start_idx_2: end_idx_2+1]: 43 | slots['datetime_time'].remove(item) 44 | slots['datetime_time'].append(text[start_idx_2: end_idx_2 + 1]) 45 | else: 46 | if isinstance(slots['datetime_time'], list): 47 | for item in slots['datetime_time']: 48 | if item in text[start_idx: end_idx + 1]: 49 | slots['datetime_time'].remove(item) 50 | slots['datetime_time'].append(text[start_idx: end_idx+1]) 51 | else: 52 | index = slots['datetime_time'].find(':') 53 | if index != -1: 54 | slots['datetime_time'] = slots['datetime_time'][:index] + text[idx_1: end_idx+1] 55 | else: 56 | slots['datetime_time'] = slots['datetime_time'] + text[idx_1: end_idx + 1] 57 | 58 | with open(target_path, 'w', encoding='utf-8') as fp: 59 | json.dump(ori_data, fp, ensure_ascii=False) 60 | 61 | 62 | if __name__ == "__main__": 63 | # source = '../small_sample/data/train_8_with_other.json' 64 | # target = '../small_sample/data/train_8_with_other_clear.json' 65 | source = '../small_sample/new_B/train_final.json' 66 | target = '../small_sample/new_B/train_final_clear.json' 67 | rectify(source, target) 68 | 69 | '''手动改'datetime_time': []''' 70 | 71 | '''先删除 72 | NLU12276 73 | "NLU11929": { 74 | "text": "能不能建一个今天晚上6:30~7:00提醒我钉钉直播会议", 75 | "intent": "Alarm-Update", 76 | "slots": { 77 | "datetime_date": "今天", 78 | "datetime_time": "晚上6:30~7:00", 79 | "notes": "钉钉直播会议" 80 | } 81 | }, 82 | 再添加 83 | ''' 84 | -------------------------------------------------------------------------------- /data/code/predict/nest.txt: -------------------------------------------------------------------------------- 1 | 放一个讲述美食的美国纪录片来看 {'name': '美食的美国纪录片', 'region': '美国'} 2 | O O O O O B-name I-name I-name B-region I-region I-name 
I-name I-name O O 3 | 王力宏在重庆开演唱会的视频放来看一下 {'name': '王力宏在重庆开演唱会的视频', 'region': '重庆'} 4 | B-name I-name I-name I-name B-region I-region I-name I-name I-name I-name I-name I-name I-name O O O O O 5 | 放一下钟汉良在2019年央视春晚上唱歌的视频 {'name': '钟汉良在2019年央视春晚上唱歌的视频', 'datetime_date': '2019年'} 6 | O O O B-name I-name I-name I-name B-datetime_date I-datetime_date I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name 7 | 帮我找下有没有关于北京的纪录片 {'region': '北京', 'name': '北京的纪录片'} 8 | O O O O O O O O O B-region I-region I-name I-name I-name I-name 9 | 给我播放2021中国女子足球锦标赛 {'datetime_date': '2021', 'region': '中国', 'name': '中国女子足球锦标赛'} 10 | O O O O B-datetime_date B-region I-region I-name I-name I-name I-name I-name I-name I-name 11 | 请播放钢琴奏鸣曲 {'instrument': '钢琴', 'song': '钢琴奏鸣曲'} 12 | O O O B-instrument I-instrument I-song I-song I-song 13 | 给我播放2012年刘翔参加钻石联赛上海站的视频 {'datetime_date': '2012年', 'name': '刘翔参加钻石联赛上海站的视频', 'region': '上海'} 14 | O O O O B-datetime_date I-datetime_date B-name I-name I-name I-name I-name I-name I-name I-name B-region I-region I-name I-name I-name I-name 15 | 有没有介绍日本地理的纪录片,播来看看 {'region': '日本', 'name': '日本地理的纪录片'} 16 | O O O O O B-region I-region I-name I-name I-name I-name I-name I-name O O O O O 17 | 给我播放阿飞今年讲解游戏的系列视频吧 {'datetime_date': '今年', 'name': '今年讲解游戏的系列视频'} 18 | O O O O O O B-datetime_date I-datetime_date I-name I-name I-name I-name I-name I-name I-name I-name I-name O 19 | 超级想看辣目洋子今年在超新星运动会上表演的艺术体操视频 {'name': '辣目洋子今年在超新星运动会上表演的艺术体操视频', 'datetime_date': '今年'} 20 | O O O O B-name I-name I-name I-name B-datetime_date I-datetime_date I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name 21 | 给我找一找吃货俱乐部在意大利拍的那一期 {'name': '吃货俱乐部在意大利拍的那一期', 'region': '意大利'} 22 | O O O O O B-name I-name I-name I-name I-name I-name B-region I-region I-region I-name I-name I-name I-name I-name 23 | 4月16日的斯里兰卡青睐中国这个视频能找到吗 {'datetime_date': '4月16日', 'region': '斯里兰卡', 'name': '斯里兰卡青睐中国这个视频'} 24 | B-datetime_date I-datetime_date I-datetime_date I-datetime_date O B-region I-region I-region I-region I-name I-name I-name I-name I-name I-name I-name I-name O O O O 25 | 周末特供2021-316的节目回看一下 {'datetime_date': '周末', 'name': '周末特供2021-316'} 26 | B-datetime_date I-datetime_date I-name I-name I-name I-name I-name O O O O O O O 27 | 我想看童蕾在6月15日拍摄都市丽人的花絮视频 {'name': '童蕾在6月15日拍摄都市丽人的花絮视频', 'datetime_date': '6月15日'} 28 | O O O B-name I-name I-name B-datetime_date I-datetime_date I-datetime_date I-datetime_date I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name 29 | 推荐一部催泪的美国纪录片来看 {'name': '催泪的美国纪录片', 'region': '美国'} 30 | O O O O B-name I-name I-name B-region I-region I-name I-name I-name O O 31 | 一场关于爱情的谎言在4月26播放的花絮有么 {'name': ['一场关于爱情的谎言', '4月26播放的花絮'], 'datetime_date': '4月26'} 32 | B-name I-name I-name I-name I-name I-name I-name I-name I-name O B-datetime_date I-datetime_date I-datetime_date I-name I-name I-name I-name I-name O O 33 | 我想看2012年麦当娜在巴黎举行演唱会的那个视频 {'datetime_date': '2012年', 'name': '麦当娜在巴黎举行演唱会的那个视频', 'region': '巴黎'} 34 | O O O B-datetime_date I-datetime_date B-name I-name I-name I-name B-region I-region I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name 35 | 播放一个豆瓣评分最高的日本纪录片来看 {'name': '豆瓣评分最高的日本纪录片', 'region': '日本'} 36 | O O O O B-name I-name I-name I-name I-name I-name I-name B-region I-region I-name I-name I-name O O 37 | 给我放一个高桥留美子的日本动漫片剪辑版视频 {'name': '高桥留美子的日本动漫片剪辑版视频', 'region': '日本'} 38 | O O O O O B-name I-name I-name I-name I-name I-name B-region I-region I-name I-name I-name I-name 
I-name I-name I-name I-name 39 | 找几个美国动漫视频小短片来看看 {'region': '美国', 'name': '美国动漫视频小短片'} 40 | O O O B-region I-region I-name I-name I-name I-name I-name I-name I-name O O O 41 | 吸血鬼侦探这个韩国电视剧的花絮给我播放一下 {'name': '吸血鬼侦探这个韩国电视剧的花絮', 'region': '韩国'} 42 | B-name I-name I-name I-name I-name I-name I-name B-region I-region I-name I-name I-name I-name I-name I-name O O O O O O 43 | 回放2008年的北京奥运会 {'datetime_date': '2008年', 'region': '北京', 'name': '北京奥运会'} 44 | O O B-datetime_date I-datetime_date O B-region I-region I-name I-name I-name 45 | 昨天法国报告“罕见” 变异新冠病毒的新闻视频再给我放下吧 {'datetime_date': '昨天', 'region': '法国', 'name': '法国报告“罕见” 变异新冠病毒的新闻视频'} 46 | B-datetime_date I-datetime_date B-region I-region I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name I-name O O O O O O 47 | -------------------------------------------------------------------------------- /data/code/scripts/build_vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/5 15:21 4 | # @Author : JJkinging 5 | # @File : build_vocab.py 6 | import json 7 | import pickle 8 | 9 | import numpy as np 10 | from collections import Counter 11 | 12 | import torch 13 | 14 | 15 | def build_worddict(train_path, dev_path, test_path): 16 | ''' 17 | function:构建词典 18 | :param data: read_data返回的数据 19 | :return: 20 | ''' 21 | with open(train_path, 'r', encoding='utf-8') as fp: 22 | train_data = fp.readlines() 23 | train_data = [item.strip('\n').split(' ') for item in train_data] 24 | with open(dev_path, 'r', encoding='utf-8') as fp: 25 | dev_data = fp.readlines() 26 | dev_data = [item.strip('\n').split(' ') for item in dev_data] 27 | with open(test_path, 'r', encoding='utf-8') as fp: 28 | test_data = fp.readlines() 29 | test_data = [item.strip('\n').split(' ') for item in test_data] 30 | 31 | word_dict = {} 32 | words = [] 33 | lengths = [] 34 | for single_data in train_data: 35 | words.extend(single_data) 36 | lengths.append(len(single_data)) 37 | for single_data in dev_data: 38 | words.extend(single_data) 39 | lengths.append(len(single_data)) 40 | for single_data in test_data: 41 | words.extend(single_data) 42 | lengths.append(len(single_data)) 43 | print(len(words)) 44 | print(lengths) 45 | x = 0 46 | for len1 in lengths: 47 | if len1 >= 100: 48 | x += 1 49 | print('最长:', max(lengths)) 50 | print('最短:', min(lengths)) 51 | print('长度超过100(含)的句子:', x) 52 | 53 | counts = Counter(words) 54 | print(counts) 55 | num_words = len(counts) 56 | word_dict['[PAD]'] = 0 57 | word_dict['[CLS]'] = 1 58 | word_dict['[SEP]'] = 2 59 | 60 | offset = 3 61 | 62 | for i, word in enumerate(counts.most_common(num_words)): 63 | word_dict[word[0]] = i + offset 64 | # print(word_dict) 65 | return word_dict 66 | 67 | 68 | def build_embedding_matrix(embeddings_file, worddict): 69 | embeddings = {} 70 | with open(embeddings_file, "r", encoding="utf8") as input_data: 71 | for line in input_data: 72 | line = line.split() 73 | try: 74 | float(line[1]) 75 | word = line[0] 76 | if word in worddict: 77 | embeddings[word] = line[1:] 78 | 79 | # Ignore lines corresponding to multiple words separated 80 | # by spaces. 81 | except ValueError: 82 | continue 83 | 84 | num_words = len(worddict) 85 | embedding_dim = len(list(embeddings.values())[0]) 86 | embedding_matrix = np.zeros((num_words, embedding_dim)) 87 | 88 | # Actual building of the embedding matrix. 
89 | missed = 0 90 | for word, i in worddict.items(): 91 | if word in embeddings: 92 | embedding_matrix[i] = np.array(embeddings[word], dtype=float) 93 | else: 94 | if word == "[PAD]": 95 | continue 96 | missed += 1 97 | # Out of vocabulary words are initialised with random gaussian 98 | # samples. 99 | embedding_matrix[i] = np.random.normal(size=(embedding_dim)) 100 | print("Missed words: ", missed) 101 | 102 | return embedding_matrix 103 | 104 | 105 | def word_to_indices(self, sentence): 106 | indices = [] 107 | for word in sentence: 108 | if word in self.word_dict: 109 | indices.append(self.word_dict[word]) 110 | else: 111 | indices.append(self.word_dict['UNK']) 112 | return indices 113 | 114 | 115 | def load_vocab(label_file): 116 | '''construct word2id or label2id''' 117 | vocab = {} 118 | index = 0 119 | with open(label_file, 'r', encoding='utf-8') as fp: 120 | while True: 121 | token = fp.readline() 122 | if not token: 123 | break 124 | token = token.strip() # 删除空白符 125 | vocab[token] = index 126 | index += 1 127 | return vocab 128 | 129 | 130 | if __name__ == "__main__": 131 | train_path = '../dataset/small_sample/data/train_seq_in.txt' 132 | dev_path = '../dataset/final_data/dev_seq_in.txt' 133 | test_path = '../script/test/dataset/test_seq_in.txt' 134 | word_dict = build_worddict(train_path, dev_path, test_path) 135 | print(word_dict) 136 | print(len(word_dict)) 137 | 138 | # label_file = '../dataset/new_data/tag.txt' 139 | # label_dict = load_vocab(label_file) 140 | # print(label_dict) 141 | # 保存embedding_file 142 | embedding_file = '../dataset/embeddings/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5' 143 | embed_matrix = build_embedding_matrix(embedding_file, word_dict) 144 | with open("../dataset/embeddings/embedding.pkl", "wb") as pkl_file: 145 | pickle.dump(embed_matrix, pkl_file) 146 | print(torch.tensor(embed_matrix).shape) -------------------------------------------------------------------------------- /data/code/scripts/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 12:39 4 | # @Author : JJkinging 5 | # @File : utils.py 6 | from torch.utils.data import Dataset, DataLoader 7 | from data.code.predict.test_utils import load_vocab, collate_to_max_length 8 | 9 | 10 | class CCFDataset(Dataset): 11 | def __init__(self, filename, intent_filename, slot_filename, slot_none_filename, vocab, intent_dict, 12 | slot_none_dict, slot_dict, max_length=512): 13 | ''' 14 | :param filename:读取数据文件名,例如:train_seq_in.txt 15 | :param intent_filename: train_intent_label.txt or dev_intent_label.txt 16 | :param slot_filename: train_seq_out.txt 17 | :param slot_none_filename: train_slot_none.txt or dev_slot_none.txt 18 | :param slot_none_dict: slot_none的字典 19 | :param slot_dict: slot_label的字典 20 | :param vocab: 词表,例如:bert的vocab.txt 21 | :param intent_dict: intent2id的字典 22 | :param max_length: 单句最大长度 23 | ''' 24 | self.filename = filename 25 | self.intent_filename = intent_filename 26 | self.slot_filename = slot_filename 27 | self.slot_none_filename = slot_none_filename 28 | self.vocab = vocab 29 | self.intent_dict = intent_dict 30 | self.slot_none_dict = slot_none_dict 31 | self.slot_dict = slot_dict 32 | self.max_length = max_length 33 | 34 | self.result = [] 35 | 36 | # 读取数据 37 | with open(self.filename, 'r', encoding='utf-8') as fp: 38 | sen_data = fp.readlines() 39 | sen_data = [item.strip('\n') for item in sen_data] # 删除句子结尾的换行符('\n') 40 | 41 | # 
读取intent 42 | with open(self.intent_filename, 'r', encoding='utf-8') as fp: 43 | intent_data = fp.readlines() 44 | intent_data = [item.strip('\n') for item in intent_data] # 删除结尾的换行符('\n') 45 | intent_ids = [intent_dict[item] for item in intent_data] 46 | 47 | # 读取slot_none 48 | with open(self.slot_none_filename, 'r', encoding='utf-8') as fp: 49 | slot_none_data = fp.readlines() 50 | # 删除结尾的空格和换行符('\n') 51 | slot_none_data = [item.strip('\n').strip(' ').split(' ') for item in slot_none_data] 52 | # 下面列表表达式把slot_none转为id 53 | slot_none_ids = [[self.slot_none_dict[ite] for ite in item] for item in slot_none_data] 54 | 55 | # 读取slot 56 | with open(self.slot_filename, 'r', encoding='utf-8') as fp: 57 | slot_data = fp.readlines() 58 | slot_data = [item.strip('\n') for item in slot_data] # 删除句子结尾的换行符('\n') 59 | # slot_ids = [self.slot_dict[item] for item in slot_data] 60 | 61 | idx = 0 62 | for utterance in sen_data: 63 | utterance = utterance.split(' ') # str变list 64 | slot_utterence = slot_data[idx].split(' ') 65 | # 最大长度检验 66 | if len(utterance) > self.max_length-2: 67 | utterance = utterance[:max_length] 68 | slot_utterence = slot_utterence[:max_length] 69 | 70 | # input_ids 71 | utterance = ['[CLS]'] + utterance + ['[SEP]'] 72 | input_ids = [int(self.vocab[i]) for i in utterance] 73 | 74 | 75 | length = len(input_ids) 76 | 77 | # slot_ids 78 | slot_utterence = ['[START]'] + slot_utterence + ['[EOS]'] 79 | slot_ids = [int(self.slot_dict[i]) for i in slot_utterence] 80 | 81 | # input_mask 82 | input_mask = [1] * len(input_ids) 83 | 84 | # intent_ids 85 | intent_id = intent_ids[idx] 86 | 87 | # slot_none_ids 88 | slot_none_id = slot_none_ids[idx] # slot_none_id 为 int or list 89 | 90 | idx += 1 91 | 92 | self.result.append((input_ids, slot_ids, input_mask, intent_id, slot_none_id)) 93 | 94 | def __len__(self): 95 | return len(self.result) 96 | 97 | def __getitem__(self, index): 98 | input_ids, slot_ids, input_mask, intent_id, slot_none_id = self.result[index] 99 | 100 | return input_ids, slot_ids, input_mask, intent_id, slot_none_id 101 | 102 | 103 | if __name__ == "__main__": 104 | filename = '../dataset/final_data/train_seq_in.txt' 105 | vocab_file = '../dataset/pretrained_model/erine/vocab.txt' 106 | intent_filename = '../dataset/final_data/train_intent_label.txt' 107 | slot_filename = '../dataset/final_data/train_seq_out.txt' 108 | slot_none_filename = '../dataset/final_data/train_slot_none.txt' 109 | intent_label = '../dataset/final_data/intent_label.txt' 110 | slot_label = '../dataset/final_data/slot_label.txt' 111 | slot_none_vocab = '../dataset/final_data/slot_none_vocab.txt' 112 | intent_dict = load_vocab(intent_label) 113 | slot_dict = load_vocab(slot_label) 114 | slot_none_dict = load_vocab(slot_none_vocab) 115 | vocab = load_vocab(vocab_file) 116 | 117 | dataset = CCFDataset(filename, intent_filename, slot_filename, slot_none_filename, vocab, intent_dict, 118 | slot_none_dict, slot_dict) 119 | 120 | dataloader = DataLoader(dataset, shuffle=False, batch_size=8, collate_fn=collate_to_max_length) 121 | 122 | for batch in dataloader: 123 | print(batch) 124 | break 125 | -------------------------------------------------------------------------------- /data/code/preprocess/extend_audio_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/15 21:08 4 | # @Author : JJkinging 5 | # @File : extend_small_sample.py 6 | '''数据增强: 扩充Audio-Play小样本数据集''' 7 | import json 8 | import random 
9 | random.seed(1000) 10 | 11 | with open('../../extend_data/same_intent/ori_data/Audio-Play.json', 'r', encoding='utf-8') as fp: 12 | ori_data = json.load(fp) 13 | 14 | with open('../../small_sample/audio_entity_dic.json', 'r', encoding='utf-8') as fp: 15 | audio_dic = json.load(fp) 16 | 17 | res = {} 18 | idx = 0 19 | for filename, single_data in ori_data.items(): 20 | text = single_data['text'] 21 | slots = single_data['slots'] 22 | for _ in range(20): # 每条扩充20条数据 23 | tmp_dic = {} 24 | slot_dic = {} 25 | length = len(str(idx)) 26 | o_num = 5 - length 27 | new_filename = 'NLU' + '0'*o_num + str(idx) 28 | idx += 1 29 | 30 | new_text = text 31 | for prefix in audio_dic['prefix']: 32 | if text.find(prefix) != -1: 33 | ran_i = random.randint(0, len(audio_dic['prefix'])-1) 34 | new_text = text.replace(prefix, audio_dic['prefix'][ran_i]) # 随机找一个prefix替换 35 | break 36 | else: 37 | ran_i = random.randint(0, len(audio_dic['prefix'])-1) 38 | new_text = audio_dic['prefix'][ran_i] + text 39 | if 'name' in slots.keys(): 40 | ran_i = random.randint(0, len(audio_dic['name'])-1) 41 | new_text = new_text.replace(slots['name'], audio_dic['name'][ran_i]) 42 | slot_dic['name'] = audio_dic['name'][ran_i] # 随机替换name 43 | if 'artist' in slots.keys(): 44 | ran_i = random.randint(0, len(audio_dic['artist']) - 1) 45 | if isinstance(slots['artist'], list): 46 | tmp_list = [] 47 | for art in slots['artist']: 48 | ran_j = random.randint(0, len(audio_dic['artist']) - 1) 49 | new_text = new_text.replace(art, audio_dic['artist'][ran_j]) 50 | tmp_list.append(audio_dic['artist'][ran_j]) 51 | slot_dic['artist'] = tmp_list 52 | else: 53 | new_text = new_text.replace(slots['artist'], audio_dic['artist'][ran_i]) 54 | slot_dic['artist'] = audio_dic['artist'][ran_i] # 随机替换artist 55 | if 'play_setting' in slots.keys(): 56 | ran_i = random.randint(0, len(audio_dic['play_setting']) - 1) 57 | if isinstance(slots['play_setting'], list): 58 | tmp_list = [] 59 | for set1 in slots['play_setting']: 60 | ran_j = random.randint(0, len(audio_dic['play_setting']) - 1) 61 | if set1 in audio_dic['play_setting']: 62 | new_text = new_text.replace(set1, audio_dic['play_setting'][ran_j]) 63 | tmp_list.append(audio_dic['play_setting'][ran_j]) 64 | else: 65 | tmp_list.append(set1) 66 | slot_dic['play_setting'] = tmp_list 67 | else: 68 | if slots['play_setting'] in audio_dic['play_setting']: 69 | new_text = new_text.replace(slots['play_setting'], audio_dic['play_setting'][ran_i]) 70 | slot_dic['play_setting'] = audio_dic['play_setting'][ran_i] # 随机替换play_setting 71 | if 'language' in slots.keys(): 72 | ran_i = random.randint(0, len(audio_dic['language']) - 1) 73 | if slots['language'] == "俄语" and audio_dic['language'][ran_i] != "俄语": 74 | new_text = new_text.replace(slots['language'], audio_dic['language'][ran_i][:-1]+'文') 75 | slot_dic['language'] = audio_dic['language'][ran_i] # 随机替换language 76 | elif slots['language'] == "俄语" and audio_dic['language'][ran_i] == "俄语": 77 | slot_dic['language'] = audio_dic['language'][ran_i] # 随机替换language 78 | elif slots['language'] == "华语" and audio_dic['language'][ran_i] != "华语": 79 | new_text = new_text.replace("中文", audio_dic['language'][ran_i][:-1] + '文') 80 | slot_dic['language'] = audio_dic['language'][ran_i] # 随机替换language 81 | elif slots['language'] == "华语" and audio_dic['language'][ran_i] == "华语": 82 | slot_dic['language'] = audio_dic['language'][ran_i] # 随机替换language 83 | else: 84 | new_text = new_text.replace(slots['language'][:-1]+'文', audio_dic['language'][ran_i][:-1] + '文') 85 | 
slot_dic['language'] = audio_dic['language'][ran_i] # 随机替换language 86 | if 'tag' in slots.keys(): 87 | ran_i = random.randint(0, len(audio_dic['tag']) - 1) 88 | new_text = new_text.replace(slots['tag'], audio_dic['tag'][ran_i]) 89 | slot_dic['tag'] = audio_dic['tag'][ran_i] # 随机替换language 90 | tmp_dic['text'] = new_text 91 | tmp_dic['intent'] = "Audio-Play" 92 | tmp_dic['slots'] = slot_dic 93 | res[new_filename] = tmp_dic 94 | 95 | 96 | print(res) 97 | print(len(res)) 98 | 99 | with open('../../extend_data/data/extend_audio.json', 'w', encoding='utf-8') as fp: 100 | json.dump(res, fp, ensure_ascii=False) 101 | 102 | # 手动替换 华语|华文 ——> 中文 西班牙文 ——> 西班牙语 103 | -------------------------------------------------------------------------------- /data/code/predict/integration.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/20 18:35 4 | # @Author : JJkinging 5 | # @File : integration.py 6 | import json 7 | import os 8 | 9 | from data.code.predict.post_process import process 10 | 11 | with open('../../user_data/tmp_result/result_interact_1_post.json', 'r', encoding='utf-8') as fp: 12 | best_data = json.load(fp) 13 | with open('../../user_data/tmp_result/result_interact_3_post.json', 'r', encoding='utf-8') as fp: 14 | second_data = json.load(fp) 15 | with open('../../user_data/tmp_result/result_bert_post.json', 'r', encoding='utf-8') as fp: 16 | third_data = json.load(fp) 17 | 18 | idx_list = [] 19 | res = {} 20 | for i in range(len(best_data)): 21 | lengths = len(str(i)) 22 | o_nums = 5 - lengths 23 | idx = 'NLU' + '0'*o_nums + str(i) 24 | tmp = {} 25 | 26 | best_dic = best_data[idx] 27 | second_dic = second_data[idx] 28 | third_dic = third_data[idx] 29 | # 调整intent 30 | if best_dic['intent'] == second_dic['intent'] == third_dic['intent'] or \ 31 | best_dic['intent'] == second_dic['intent'] or best_dic['intent'] == third_dic['intent'] or \ 32 | (best_dic['intent'] != second_dic['intent'] and best_dic['intent'] != third_dic['intent'] and 33 | second_dic['intent'] != third_dic['intent']): 34 | intent = best_dic['intent'] 35 | 36 | elif second_dic['intent'] == third_dic['intent']: 37 | intent = second_dic['intent'] 38 | 39 | slot_dic = {} 40 | # 调整slot 41 | best_slots = best_dic['slots'] 42 | second_slots = second_dic['slots'] 43 | third_slots = third_dic['slots'] 44 | 45 | total_slot_keys = set() 46 | total_slot_keys.update(list(best_slots.keys())) 47 | total_slot_keys.update(list(second_slots.keys())) 48 | total_slot_keys.update(list(third_slots.keys())) 49 | 50 | second_key = [] 51 | third_key = [] 52 | for key in total_slot_keys: 53 | if key in best_slots.keys() and key in second_slots.keys() and key in third_slots.keys(): # 若三者都有 54 | if isinstance(best_slots[key], list) and isinstance(second_slots[key], list) and \ 55 | isinstance(third_slots[key], list): 56 | best_slots[key] = set(best_slots[key]) 57 | second_slots[key] = set(second_slots[key]) 58 | third_slots[key] = set(third_slots[key]) 59 | 60 | if best_slots[key] == second_slots[key] == third_slots[key] or \ 61 | best_slots[key] == second_slots[key] or best_slots[key] == third_slots[key] or \ 62 | (best_slots[key] != second_slots[key] and best_slots[key] != third_slots[key] and 63 | second_slots[key] != third_slots[key]): 64 | if isinstance(best_slots[key], set): 65 | slot_dic[key] = list(best_slots[key]) 66 | else: 67 | slot_dic[key] = best_slots[key] 68 | elif second_slots[key] == third_slots[key]: 69 | if 
isinstance(second_slots[key], set): 70 | slot_dic[key] = list(second_slots[key]) 71 | else: 72 | slot_dic[key] = second_slots[key] 73 | elif key in best_slots.keys() and key in second_slots.keys(): # 只有best 和 second 有 74 | if isinstance(best_slots[key], list) and isinstance(second_slots[key], list): 75 | best_slots[key] = set(best_slots[key]) 76 | second_slots[key] = set(second_slots[key]) 77 | 78 | if isinstance(best_slots[key], set): 79 | slot_dic[key] = list(best_slots[key]) 80 | else: 81 | slot_dic[key] = best_slots[key] 82 | elif key in best_slots.keys() and key in third_slots.keys(): 83 | if isinstance(best_slots[key], list) and isinstance(third_slots[key], list): 84 | best_slots[key] = set(best_slots[key]) 85 | third_slots[key] = set(third_slots[key]) 86 | if isinstance(best_slots[key], set): 87 | slot_dic[key] = list(best_slots[key]) 88 | else: 89 | slot_dic[key] = best_slots[key] 90 | elif key in second_slots.keys() and key in third_slots.keys(): 91 | if isinstance(second_slots[key], list) and isinstance(third_slots[key], list): 92 | second_slots[key] = set(second_slots[key]) 93 | third_slots[key] = set(third_slots[key]) 94 | if isinstance(second_slots[key], set): 95 | slot_dic[key] = list(second_slots[key]) 96 | else: 97 | slot_dic[key] = second_slots[key] 98 | elif key in best_slots.keys(): 99 | slot_dic[key] = best_slots[key] 100 | elif key in second_slots.keys(): 101 | second_key.append(key) 102 | elif key in third_slots.keys(): 103 | third_key.append(key) 104 | 105 | for key in second_key: 106 | if second_slots[key] not in slot_dic.values(): 107 | slot_dic[key] = second_slots[key] 108 | for key in third_key: 109 | if third_slots[key] not in slot_dic.values(): 110 | slot_dic[key] = third_slots[key] 111 | 112 | tmp['intent'] = intent 113 | tmp['slots'] = slot_dic 114 | res[idx] = tmp 115 | 116 | with open('../../prediction_result/result_tmp.json', 'w', encoding='utf-8') as fp: 117 | json.dump(res, fp, ensure_ascii=False) 118 | 119 | source_path = '../../prediction_result/result_tmp.json' 120 | target_path = '../../prediction_result/result.json' 121 | process(source_path, target_path) 122 | os.remove('../../prediction_result/result_tmp.json') 123 | -------------------------------------------------------------------------------- /data/code/scripts/train_interact1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 14:41 4 | # @Author : JJkinging 5 | # @File : main.py 6 | 7 | import os 8 | import torch 9 | import random 10 | import numpy as np 11 | import warnings 12 | from transformers import get_linear_schedule_with_warmup 13 | from data.code.scripts.dataset import CCFDataset 14 | from data.code.scripts.config_Interact1 import Config 15 | from torch.utils.data import DataLoader 16 | from data.code.model.InteractModel_1 import InteractModel 17 | from data.code.scripts.utils import train, valid, collate_to_max_length, load_vocab 18 | 19 | 20 | def torch_seed(seed): 21 | os.environ['PYTHONHASHSEED'] = str(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | np.random.seed(seed) # Numpy module. 26 | random.seed(seed) # Python random module. 
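    # (Note) torch.manual_seed is called once more below; the repetition is harmless.
    # The cuDNN flags that follow trade speed for reproducibility: benchmarking is
    # turned off, deterministic kernels are requested, and cuDNN is disabled outright
    # as the strictest setting.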
27 | torch.manual_seed(seed) 28 | torch.backends.cudnn.benchmark = False 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.enabled = False 31 | 32 | 33 | def main(): 34 | warnings.filterwarnings("ignore") 35 | torch_seed(1000) 36 | # 设置GPU数目 37 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0' 38 | config = Config() 39 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu") 40 | print('loading corpus') 41 | vocab = load_vocab(config.vocab_file) 42 | intent_dict = load_vocab(config.intent_label_file) 43 | slot_none_dict = load_vocab(config.slot_none_vocab) 44 | slot_dict = load_vocab(config.slot_label) 45 | intent_tagset_size = len(intent_dict) 46 | slot_none_tag_size = len(slot_none_dict) 47 | slot_tag_size = len(slot_dict) 48 | train_dataset = CCFDataset(config.train_file, config.train_intent_file, config.train_slot_filename, 49 | config.train_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict, 50 | config.max_length) 51 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size, 52 | collate_fn=collate_to_max_length) 53 | 54 | dev_dataset = CCFDataset(config.dev_file, config.dev_intent_file, config.dev_slot_filename, 55 | config.dev_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict, 56 | config.max_length) 57 | dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=config.batch_size, 58 | collate_fn=collate_to_max_length) 59 | 60 | model = InteractModel(config.bert_model_path, 61 | config.bert_hidden_size, 62 | intent_tagset_size, 63 | slot_none_tag_size, 64 | slot_tag_size, 65 | device).to(device) 66 | # model = torch.nn.DataParallel(model, device_ids=[0, 1], output_device=[0]) 67 | I_criterion = torch.nn.CrossEntropyLoss() 68 | N_criterion = torch.nn.BCEWithLogitsLoss() 69 | 70 | crf_params = list(map(id, model.CRF.parameters())) # 把CRF层的参数映射为id 71 | other_params = filter(lambda x: id(x) not in crf_params, model.parameters()) # 在整个模型的参数中将CRF层的参数过滤掉(filter) 72 | 73 | optimizer = torch.optim.AdamW([{'params': model.CRF.parameters(), 'lr': config.crf_lr}, 74 | {'params': other_params, 'lr': config.lr}], weight_decay=config.weight_decay) 75 | # optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0.0) 76 | total_step = len(train_loader) // config.batch_size * config.epochs 77 | scheduler = get_linear_schedule_with_warmup(optimizer, 78 | num_warmup_steps=0, 79 | num_training_steps=total_step) 80 | 81 | best_score = 0.0 82 | start_epoch = 1 83 | # Data for loss curves plot. 84 | epochs_count = [] 85 | train_losses = [] 86 | valid_losses = [] 87 | 88 | # Continuing training from a checkpoint if one was given as argument. 89 | if config.checkpoint: 90 | checkpoint = torch.load(config.checkpoint) 91 | start_epoch = checkpoint["epoch"] + 1 92 | best_score = checkpoint["best_score"] 93 | 94 | print("\t* Training will continue on existing model from epoch {}..." 95 | .format(start_epoch)) 96 | 97 | model.load_state_dict(checkpoint["model"]) 98 | optimizer.load_state_dict(checkpoint["optimizer"]) 99 | epochs_count = checkpoint["epochs_count"] 100 | train_losses = checkpoint["train_losses"] 101 | valid_losses = checkpoint["valid_losses"] 102 | 103 | # Compute loss and accuracy before starting (or resuming) training. 
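    # valid() (see data/code/scripts/utils.py) returns: wall-clock time, mean loss,
    # intent accuracy, a (precision, recall, F1) tuple for the slot_none head, a
    # (precision, recall, F1) tuple for slot filling, and whole-sentence accuracy.
    # The prints below report only the F1 entry (index 2) of the two tuples.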
104 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model, 105 | dev_loader, 106 | I_criterion, 107 | N_criterion) 108 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}" 109 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc)) 110 | 111 | # -------------------- Training epochs ------------------- # 112 | print("\n", 113 | 20 * "=", 114 | "Training Model model on device: {}".format(device), 115 | 20 * "=") 116 | for epoch in range(start_epoch, config.epochs+1): 117 | epochs_count.append(epoch) 118 | print("* Training epoch {}:".format(epoch)) 119 | train_time, train_loss = train(model, 120 | train_loader, 121 | optimizer, 122 | I_criterion, 123 | N_criterion, 124 | config.max_grad_norm) 125 | train_losses.append(train_loss) 126 | print("-> Training time: {:.4f}s loss = {:.4f}" 127 | .format(train_time, train_loss)) 128 | with open('../../user_data/output_model/InteractModel_1/metric.txt', 'a', encoding='utf-8') as fp: 129 | fp.write('Epoch:' + str(epoch) + '\t' + 'Loss:' + str(round(train_loss, 4)) + '\t') 130 | 131 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model, 132 | dev_loader, 133 | I_criterion, 134 | N_criterion, 135 | ) 136 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}" 137 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc)) 138 | with open('../../user_data/output_model/InteractModel_1/metric.txt', 'a', encoding='utf-8') as fp: 139 | fp.write('Loss:' + str(round(valid_loss, 4)) + '\t' + 'Intent_acc:' + 140 | str(round(intent_accuracy, 4)) + '\t' + 'slot_none:' + 141 | str(round(slot_none[2], 4)) + '\t''slot_F1:' + str(round(slot[2], 4)) + '\t' 142 | + 'SEN_ACC:' + str(round(sen_acc, 4)) + '\n') 143 | 144 | valid_losses.append(valid_losses) 145 | # Update the optimizer's learning rate with the scheduler. 146 | scheduler.step() 147 | if sen_acc >= best_score: 148 | best_score = sen_acc 149 | torch.save({"epoch": epoch, 150 | "model": model.state_dict(), 151 | "best_score": best_score, 152 | "optimizer": optimizer.state_dict(), 153 | "epochs_count": epochs_count, 154 | "train_losses": train_losses, 155 | "valid_losses": valid_losses}, 156 | os.path.join(config.target_dir, "Interact1_model_best.pth.tar")) 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /data/code/scripts/train_interact3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 14:41 4 | # @Author : JJkinging 5 | # @File : main.py 6 | 7 | import os 8 | import warnings 9 | 10 | import torch 11 | import random 12 | import numpy as np 13 | from transformers import get_linear_schedule_with_warmup 14 | from data.code.scripts.dataset import CCFDataset 15 | from data.code.scripts.config_Interact3 import Config 16 | from torch.utils.data import DataLoader 17 | from data.code.model.InteractModel_3 import InteractModel 18 | from data.code.scripts.utils import train, valid, collate_to_max_length, load_vocab 19 | 20 | 21 | def torch_seed(seed): 22 | os.environ['PYTHONHASHSEED'] = str(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | np.random.seed(seed) # Numpy module. 27 | random.seed(seed) # Python random module. 
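    # (Note on the seeding above) assigning PYTHONHASHSEED from inside a running
    # interpreter only affects child processes; it does not change hash randomization
    # for the current process, which is fixed when the interpreter starts.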
28 | torch.manual_seed(seed) 29 | torch.backends.cudnn.benchmark = False 30 | torch.backends.cudnn.deterministic = True 31 | torch.backends.cudnn.enabled = False 32 | 33 | 34 | def main(): 35 | torch_seed(1000) 36 | warnings.filterwarnings("ignore") 37 | # 设置GPU数目 38 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0' 39 | config = Config() 40 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu") 41 | print('loading corpus') 42 | vocab = load_vocab(config.vocab_file) 43 | intent_dict = load_vocab(config.intent_label_file) 44 | slot_none_dict = load_vocab(config.slot_none_vocab) 45 | slot_dict = load_vocab(config.slot_label) 46 | intent_tagset_size = len(intent_dict) 47 | slot_none_tag_size = len(slot_none_dict) 48 | slot_tag_size = len(slot_dict) 49 | train_dataset = CCFDataset(config.train_file, config.train_intent_file, config.train_slot_filename, 50 | config.train_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict, 51 | config.max_length) 52 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size, 53 | collate_fn=collate_to_max_length) 54 | 55 | dev_dataset = CCFDataset(config.dev_file, config.dev_intent_file, config.dev_slot_filename, 56 | config.dev_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict, 57 | config.max_length) 58 | dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=config.batch_size, 59 | collate_fn=collate_to_max_length) 60 | 61 | model = InteractModel(config.bert_model_path, 62 | config.bert_hidden_size, 63 | intent_tagset_size, 64 | slot_none_tag_size, 65 | slot_tag_size, 66 | device).to(device) 67 | # model = torch.nn.DataParallel(model, device_ids=[0, 1], output_device=[0]) 68 | I_criterion = torch.nn.CrossEntropyLoss() 69 | N_criterion = torch.nn.BCEWithLogitsLoss() 70 | 71 | crf_params = list(map(id, model.CRF.parameters())) # 把CRF层的参数映射为id 72 | other_params = filter(lambda x: id(x) not in crf_params, model.parameters()) # 在整个模型的参数中将CRF层的参数过滤掉(filter) 73 | 74 | optimizer = torch.optim.AdamW([{'params': model.CRF.parameters(), 'lr': config.crf_lr}, 75 | {'params': other_params, 'lr': config.lr}], weight_decay=config.weight_decay) 76 | total_step = len(train_loader) // config.batch_size * config.epochs 77 | scheduler = get_linear_schedule_with_warmup(optimizer, 78 | num_warmup_steps=0, 79 | num_training_steps=total_step) 80 | 81 | best_score = 0.0 82 | start_epoch = 1 83 | # Data for loss curves plot. 84 | epochs_count = [] 85 | train_losses = [] 86 | valid_losses = [] 87 | 88 | # Continuing training from a checkpoint if one was given as argument. 89 | if config.checkpoint: 90 | checkpoint = torch.load(config.checkpoint) 91 | start_epoch = checkpoint["epoch"] + 1 92 | best_score = checkpoint["best_score"] 93 | 94 | print("\t* Training will continue on existing model from epoch {}..." 95 | .format(start_epoch)) 96 | 97 | model.load_state_dict(checkpoint["model"]) 98 | optimizer.load_state_dict(checkpoint["optimizer"]) 99 | epochs_count = checkpoint["epochs_count"] 100 | train_losses = checkpoint["train_losses"] 101 | valid_losses = checkpoint["valid_losses"] 102 | 103 | # Compute loss and accuracy before starting (or resuming) training. 
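    # This initial validation pass (run whether or not a checkpoint is being resumed)
    # only prints baseline metrics; its loss is not appended to valid_losses.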
104 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model, 105 | dev_loader, 106 | I_criterion, 107 | N_criterion) 108 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}" 109 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc)) 110 | 111 | # -------------------- Training epochs ------------------- # 112 | print("\n", 113 | 20 * "=", 114 | "Training Model model on device: {}".format(device), 115 | 20 * "=") 116 | for epoch in range(start_epoch, config.epochs+1): 117 | epochs_count.append(epoch) 118 | print("* Training epoch {}:".format(epoch)) 119 | train_time, train_loss = train(model, 120 | train_loader, 121 | optimizer, 122 | I_criterion, 123 | N_criterion, 124 | config.max_grad_norm) 125 | train_losses.append(train_loss) 126 | print("-> Training time: {:.4f}s loss = {:.4f}" 127 | .format(train_time, train_loss)) 128 | with open('../../user_data/output_model/InteractModel_3/trained_model/metric.txt', 'a', encoding='utf-8') as fp: 129 | fp.write('Epoch:' + str(epoch) + '\t' + 'Loss:' + str(round(train_loss, 4)) + '\t') 130 | 131 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model, 132 | dev_loader, 133 | I_criterion, 134 | N_criterion, 135 | ) 136 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}" 137 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc)) 138 | with open('../../user_data/output_model/InteractModel_3/trained_model/metric.txt', 'a', encoding='utf-8') as fp: 139 | fp.write('Loss:' + str(round(valid_loss, 4)) + '\t' + 'Intent_acc:' + 140 | str(round(intent_accuracy, 4)) + '\t' + 'slot_none:' + 141 | str(round(slot_none[2], 4)) + '\t''slot_F1:' + str(round(slot[2], 4)) + '\t' 142 | + 'SEN_ACC:' + str(round(sen_acc, 4)) + '\n') 143 | 144 | valid_losses.append(valid_losses) 145 | # Update the optimizer's learning rate with the scheduler. 146 | scheduler.step() 147 | 148 | # Early stopping on validation accuracy. 
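        # A sketch of how the checkpoint written below can be restored later, mirroring
        # the keys it saves (illustrative; not part of the training loop):
        #   ckpt = torch.load(os.path.join(config.target_dir, "Interact3_model_best.pth.tar"))
        #   model.load_state_dict(ckpt["model"])
        #   optimizer.load_state_dict(ckpt["optimizer"])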
149 | if sen_acc >= best_score: 150 | best_score = sen_acc 151 | torch.save({"epoch": epoch, 152 | "model": model.state_dict(), 153 | "best_score": best_score, 154 | "optimizer": optimizer.state_dict(), 155 | "epochs_count": epochs_count, 156 | "train_losses": train_losses, 157 | "valid_losses": valid_losses}, 158 | os.path.join(config.target_dir, "Interact3_model_best.pth.tar")) 159 | 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /data/code/scripts/train_jointBert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 14:41 4 | # @Author : JJkinging 5 | # @File : main.py 6 | 7 | import os 8 | import warnings 9 | 10 | import torch 11 | import random 12 | import numpy as np 13 | from torch.utils.data import DataLoader 14 | from transformers import get_linear_schedule_with_warmup 15 | from data.code.scripts.config_jointBert import Config 16 | from data.code.scripts.dataset import CCFDataset 17 | from data.code.model.JointBertModel import JointBertModel 18 | from data.code.scripts.utils import train, valid, collate_to_max_length, load_vocab 19 | 20 | 21 | def torch_seed(seed): 22 | os.environ['PYTHONHASHSEED'] = str(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | np.random.seed(seed) # Numpy module. 27 | random.seed(seed) # Python random module. 28 | torch.manual_seed(seed) 29 | torch.backends.cudnn.benchmark = False 30 | torch.backends.cudnn.deterministic = True 31 | torch.backends.cudnn.enabled = False 32 | 33 | 34 | def main(): 35 | warnings.filterwarnings("ignore") 36 | torch_seed(1000) 37 | # 设置GPU数目 38 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0' 39 | config = Config() 40 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu") 41 | print('loading corpus') 42 | vocab = load_vocab(config.vocab_file) 43 | intent_dict = load_vocab(config.intent_label_file) 44 | slot_none_dict = load_vocab(config.slot_none_vocab) 45 | slot_dict = load_vocab(config.slot_label) 46 | intent_tagset_size = len(intent_dict) 47 | slot_none_tag_size = len(slot_none_dict) 48 | slot_tag_size = len(slot_dict) 49 | train_dataset = CCFDataset(config.train_file, config.train_intent_file, config.train_slot_filename, 50 | config.train_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict, 51 | config.max_length) 52 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size, 53 | collate_fn=collate_to_max_length) 54 | 55 | dev_dataset = CCFDataset(config.dev_file, config.dev_intent_file, config.dev_slot_filename, 56 | config.dev_slot_none_filename, vocab, intent_dict, slot_none_dict, slot_dict, 57 | config.max_length) 58 | dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=config.batch_size, 59 | collate_fn=collate_to_max_length) 60 | 61 | model = JointBertModel(config.bert_model_path, 62 | config.bert_hidden_size, 63 | intent_tagset_size, 64 | slot_none_tag_size, 65 | slot_tag_size, 66 | device).to(device) 67 | # model = torch.nn.DataParallel(model, device_ids=[0, 1], output_device=[0]) 68 | 69 | I_criterion = torch.nn.CrossEntropyLoss() 70 | N_criterion = torch.nn.BCEWithLogitsLoss() 71 | 72 | crf_params = list(map(id, model.CRF.parameters())) # 把CRF层的参数映射为id 73 | other_params = filter(lambda x: id(x) not in crf_params, model.parameters()) # 在整个模型的参数中将CRF层的参数过滤掉(filter) 74 | 75 | 
optimizer = torch.optim.AdamW([{'params': model.CRF.parameters(), 'lr': config.crf_lr}, 76 | {'params': other_params, 'lr': config.lr}], weight_decay=config.weight_decay) 77 | # optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0.0) 78 | total_step = len(train_loader) // config.batch_size * config.epochs 79 | scheduler = get_linear_schedule_with_warmup(optimizer, 80 | num_warmup_steps=0, 81 | num_training_steps=total_step) 82 | 83 | best_score = 0.0 84 | start_epoch = 1 85 | # Data for loss curves plot. 86 | epochs_count = [] 87 | train_losses = [] 88 | valid_losses = [] 89 | 90 | # Continuing training from a checkpoint if one was given as argument. 91 | if config.checkpoint: 92 | checkpoint = torch.load(config.checkpoint) 93 | start_epoch = checkpoint["epoch"] + 1 94 | best_score = checkpoint["best_score"] 95 | 96 | print("\t* Training will continue on existing model from epoch {}..." 97 | .format(start_epoch)) 98 | 99 | model.load_state_dict(checkpoint["model"]) 100 | optimizer.load_state_dict(checkpoint["optimizer"]) 101 | epochs_count = checkpoint["epochs_count"] 102 | train_losses = checkpoint["train_losses"] 103 | valid_losses = checkpoint["valid_losses"] 104 | 105 | # Compute loss and accuracy before starting (or resuming) training. 106 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model, 107 | dev_loader, 108 | I_criterion, 109 | N_criterion) 110 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}" 111 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc)) 112 | 113 | # -------------------- Training epochs ------------------- # 114 | print("\n", 115 | 20 * "=", 116 | "Training Model model on device: {}".format(device), 117 | 20 * "=") 118 | for epoch in range(start_epoch, config.epochs+1): 119 | epochs_count.append(epoch) 120 | print("* Training epoch {}:".format(epoch)) 121 | train_time, train_loss = train(model, 122 | train_loader, 123 | optimizer, 124 | I_criterion, 125 | N_criterion, 126 | config.max_grad_norm) 127 | train_losses.append(train_loss) 128 | print("-> Training time: {:.4f}s loss = {:.4f}" 129 | .format(train_time, train_loss)) 130 | with open('../../user_data/output_model/JointBert/metric.txt', 'a', encoding='utf-8') as fp: 131 | fp.write('Epoch:' + str(epoch) + '\t' + 'Loss:' + str(round(train_loss, 4)) + '\t') 132 | 133 | valid_time, valid_loss, intent_accuracy, slot_none, slot, sen_acc = valid(model, 134 | dev_loader, 135 | I_criterion, 136 | N_criterion, 137 | ) 138 | print("-> Valid time: {:.4f}s loss = {:.4f} intentAcc: {:.4f} slot_none: {:.4f} slot_F1: {:.4f} SEN_ACC: {:.4f}" 139 | .format(valid_time, valid_loss, intent_accuracy, slot_none[2], slot[2], sen_acc)) 140 | with open('../../user_data/output_model/JointBert/metric.txt', 'a', encoding='utf-8') as fp: 141 | fp.write('Loss:' + str(round(valid_loss, 4)) + '\t' + 'Intent_acc:' + 142 | str(round(intent_accuracy, 4)) + '\t' + 'slot_none:' + 143 | str(round(slot_none[2], 4)) + '\t''slot_F1:' + str(round(slot[2], 4)) + '\t' 144 | + 'SEN_ACC:' + str(round(sen_acc, 4)) + '\n') 145 | 146 | valid_losses.append(valid_losses) 147 | # Update the optimizer's learning rate with the scheduler. 148 | scheduler.step() 149 | 150 | # Early stopping on validation accuracy. 
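        # Despite the comment above, training is not stopped early: the loop always runs
        # through config.epochs, and the block below only refreshes the saved checkpoint
        # whenever sentence accuracy (sen_acc) matches or exceeds the best seen so far.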
151 | if sen_acc >= best_score: 152 | best_score = sen_acc 153 | torch.save({"epoch": epoch, 154 | "model": model.state_dict(), 155 | "best_score": best_score, 156 | "optimizer": optimizer.state_dict(), 157 | "epochs_count": epochs_count, 158 | "train_losses": train_losses, 159 | "valid_losses": valid_losses}, 160 | os.path.join(config.target_dir, "bert_model_best.pth.tar")) 161 | 162 | 163 | if __name__ == "__main__": 164 | main() 165 | -------------------------------------------------------------------------------- /data/code/scripts/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 13:09 4 | # @Author : JJkinging 5 | # @File : utils.py 6 | import time 7 | import torch 8 | from tqdm import tqdm 9 | from sklearn.metrics import precision_score, recall_score, f1_score 10 | from seqeval.metrics import precision_score as slot_precision_score, recall_score as slot_recall_score, \ 11 | f1_score as slot_F1_score 12 | from data.code.scripts.config_jointBert import Config 13 | 14 | 15 | def load_vocab(vocab_file): 16 | '''construct word2id''' 17 | vocab = {} 18 | index = 0 19 | with open(vocab_file, 'r', encoding='utf-8') as fp: 20 | while True: 21 | token = fp.readline() 22 | if not token: 23 | break 24 | token = token.strip() # 删除空白符 25 | vocab[token] = index 26 | index += 1 27 | return vocab 28 | 29 | 30 | def collate_to_max_length(batch): 31 | # input_ids, slot_ids, input_mask, intent_id, slot_none_id 32 | batch_size = len(batch) 33 | input_ids_list = [] 34 | slot_ids_list = [] 35 | input_mask_list = [] 36 | intent_ids_list = [] 37 | slot_none_ids_list = [] 38 | for single_data in batch: 39 | input_ids_list.append(single_data[0]) 40 | slot_ids_list.append(single_data[1]) 41 | input_mask_list.append(single_data[2]) 42 | intent_ids_list.append(single_data[3]) 43 | slot_none_ids_list.append(single_data[4]) 44 | 45 | max_length = max([len(item) for item in input_ids_list]) 46 | 47 | output = [torch.full([batch_size, max_length], 48 | fill_value=0, 49 | dtype=torch.long), 50 | torch.full([batch_size, max_length], 51 | fill_value=0, 52 | dtype=torch.long), 53 | torch.full([batch_size, max_length], 54 | fill_value=0, 55 | dtype=torch.long) 56 | ] 57 | 58 | for i in range(batch_size): 59 | output[0][i][0:len(input_ids_list[i])] = torch.LongTensor(input_ids_list[i]) 60 | output[1][i][0:len(slot_ids_list[i])] = torch.LongTensor(slot_ids_list[i]) 61 | output[2][i][0:len(input_mask_list[i])] = torch.LongTensor(input_mask_list[i]) 62 | 63 | intent_ids_list = torch.LongTensor(intent_ids_list) 64 | output.append(intent_ids_list) 65 | 66 | slot_none_ids = torch.zeros([batch_size, 29], dtype=torch.long) 67 | for i, slot_none_id in enumerate(slot_none_ids_list): 68 | for idx in slot_none_id: 69 | slot_none_ids[i][idx] = 1 70 | 71 | output.append(slot_none_ids) 72 | 73 | return output # (input_ids, slot_ids, input_mask, intent_id, slot_none_id) 74 | 75 | 76 | def train(model, 77 | dataloader, 78 | optimizer, 79 | I_criterion, 80 | N_criterion, 81 | max_gradient_norm): 82 | model.train() 83 | # device = model.module.device 84 | device = model.device 85 | epoch_start = time.time() 86 | batch_time_avg = 0.0 87 | running_loss = 0.0 88 | preds_mounts = 0 89 | 90 | tqdm_batch_iterator = tqdm(dataloader) 91 | for batch_index, batch in enumerate(tqdm_batch_iterator): 92 | batch_start = time.time() 93 | input_ids, slot_ids, input_mask, intent_id, slot_none_id = batch 94 | 95 | input_id = 
input_ids.to(device) 96 | slot_ids = slot_ids.to(device) 97 | input_mask = input_mask.byte().to(device) 98 | intent_id = intent_id.to(device) 99 | slot_none_id = slot_none_id.to(device) 100 | 101 | optimizer.zero_grad() 102 | intent_logits, slot_none_logits, slot_logits = model(input_id, input_mask) 103 | # intent_loss = CE_criterion(intent_logits, intent_id) 104 | intent_loss = I_criterion(intent_logits, intent_id) 105 | # slot_none_loss = BCE_criterion(slot_none_logits, slot_none_id.float()) 106 | slot_none_loss = N_criterion(slot_none_logits, slot_none_id.float()) 107 | slot_loss = model.slot_loss(slot_logits, slot_ids, input_mask) 108 | loss = intent_loss + slot_none_loss + slot_loss 109 | loss.backward() 110 | 111 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm) 112 | 113 | optimizer.step() 114 | 115 | batch_time_avg += time.time() - batch_start 116 | running_loss += loss.item() 117 | 118 | description = "Avg. batch proc. time: {:.4f}s, loss: {:.4f}" \ 119 | .format(batch_time_avg / (batch_index + 1), 120 | running_loss / (batch_index + 1)) 121 | tqdm_batch_iterator.set_description(description) 122 | 123 | epoch_time = time.time() - epoch_start 124 | epoch_loss = running_loss / len(dataloader) 125 | 126 | return epoch_time, epoch_loss 127 | 128 | 129 | def valid(model, 130 | dataloader, 131 | I_criterion, 132 | N_criterion,): 133 | config = Config() 134 | model.eval() 135 | # device = model.module.device 136 | device = model.device 137 | epoch_start = time.time() 138 | running_loss = 0.0 139 | preds_mounts = 0 140 | sen_counts = 0 141 | 142 | intent_true = [] 143 | intent_pred = [] 144 | slot_pre_output = [] 145 | slot_true_output = [] 146 | slotNone_pre_output = [] 147 | slotNone_true_output = [] 148 | 149 | slot_dict = load_vocab(config.slot_label) 150 | id2slot = {value: key for key, value in slot_dict.items()} 151 | 152 | with torch.no_grad(): 153 | tqdm_batch_iterator = tqdm(dataloader) 154 | for _, batch in enumerate(tqdm_batch_iterator): 155 | input_ids, slot_ids, input_mask, intent_id, slot_none_id = batch 156 | batch_size = len(input_ids) 157 | input_ids = input_ids.to(device) 158 | slot_ids = slot_ids.to(device) 159 | input_mask = input_mask.byte().to(device) 160 | intent_id = intent_id.to(device) 161 | slot_none_id = slot_none_id.to(device) 162 | 163 | real_length = torch.sum(input_mask, dim=1) 164 | tmp = [] 165 | i = 0 166 | for line in slot_ids.cpu().numpy().tolist(): 167 | line = [id2slot[idx] for idx in line[1: real_length[i]-1]] 168 | tmp.append(line) 169 | i += 1 170 | 171 | slot_true_output.extend(tmp) 172 | 173 | intent_logits, slot_none_logits, slot_logits = model(input_ids, input_mask) 174 | # intent_loss = CE_criterion(intent_logits, intent_id) 175 | intent_loss = I_criterion(intent_logits, intent_id) 176 | 177 | # slot_none_loss = BCE_criterion(slot_none_logits, slot_none_id.float()) 178 | slot_none_loss = N_criterion(slot_none_logits, slot_none_id.float()) 179 | slot_none_probs = torch.sigmoid(slot_none_logits) 180 | slot_none_probs = slot_none_probs > 0.5 181 | slot_none_probs = slot_none_probs.cpu().numpy() 182 | slot_none_probs = slot_none_probs.astype(int) 183 | slot_none_probs = slot_none_probs.tolist() 184 | slotNone_pre_output.extend(slot_none_probs) 185 | 186 | slot_none_id = slot_none_id.cpu().numpy().tolist() 187 | slotNone_true_output.extend(slot_none_id) 188 | 189 | slot_loss = model.slot_loss(slot_logits, slot_ids, input_mask) 190 | out_path = model.slot_predict(slot_logits, input_mask, id2slot) 191 | out_path = 
[[id2slot[idx] for idx in one_data[1:-1]] for one_data in out_path] # 去掉'[START]'和'[EOS]'标记 192 | slot_pre_output.extend(out_path) 193 | loss = intent_loss + slot_none_loss + slot_loss 194 | running_loss += loss.item() 195 | 196 | # intent acc 197 | intent_probs = torch.softmax(intent_logits, dim=-1) 198 | predict_labels = torch.argmax(intent_probs, dim=-1) 199 | correct_preds = (predict_labels == intent_id).sum().item() 200 | preds_mounts += correct_preds 201 | 202 | predict_labels = predict_labels.cpu().numpy().tolist() 203 | intent_pred.extend(predict_labels) 204 | intent_id = intent_id.cpu().numpy().tolist() 205 | intent_true.extend(intent_id) 206 | 207 | # 计算slotNone classification 准确率、召回率、F1值 208 | slotNone_micro_acc = precision_score(slotNone_true_output, slotNone_pre_output, average='micro') 209 | slotNone_micro_recall = recall_score(slotNone_true_output, slotNone_pre_output, average='micro') 210 | slotNone_micro_f1 = f1_score(slotNone_true_output, slotNone_pre_output, average='micro') 211 | 212 | # 计算slot filling 准确率、召回率、F1值 213 | slot_acc = slot_precision_score(slot_true_output, slot_pre_output) 214 | slot_recall = slot_recall_score(slot_true_output, slot_pre_output) 215 | slot_f1 = slot_F1_score(slot_true_output, slot_pre_output) 216 | 217 | # 计算整句正确率 218 | for i in range(len(intent_true)): 219 | if intent_true[i] == intent_pred[i] and slot_true_output[i] == slot_pre_output[i] and \ 220 | slotNone_pre_output[i] == slotNone_true_output[i]: 221 | sen_counts += 1 222 | # fp = open('../error_list1', 'w', encoding='utf-8') 223 | # for i in range(len(intent_true)): 224 | # if intent_true[i] != intent_pred[i] or slot_true_output[i] != slot_pre_output[i] or \ 225 | # slotNone_true_output[i] != slotNone_pre_output[i]: 226 | # fp.write(str(i)+'\n') 227 | 228 | slot_none = (slotNone_micro_acc, slotNone_micro_recall, slotNone_micro_f1) 229 | slot = (slot_acc, slot_recall, slot_f1) 230 | 231 | epoch_time = time.time() - epoch_start 232 | epoch_loss = running_loss / len(dataloader) 233 | intent_accuracy = preds_mounts / len(dataloader.dataset) 234 | sen_accuracy = sen_counts / len(dataloader.dataset) 235 | 236 | return epoch_time, epoch_loss, intent_accuracy, slot_none, slot, sen_accuracy 237 | -------------------------------------------------------------------------------- /data/code/preprocess/process_rawdata.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/8/30 23:32 4 | # @Author : JJkinging 5 | # @File : process_rawdata.py 6 | import json 7 | from transformers import AutoTokenizer 8 | 9 | 10 | def find_idx(text, value, dic, mode): 11 | ''' 12 | 返回value在text中的起始位置 13 | :param text: list 14 | :param value: list 15 | :param dic: 同义词词表, 只有mode=='mohu'时才用到 16 | :param mode: 如果mode=='mohu',则表示是模糊搜索 17 | :return: 18 | ''' 19 | 20 | if mode == 'mohu': 21 | value = ''.join(value) # 变成字符串 22 | if value not in dic.keys(): 23 | return -1 24 | candidate = dic[value] 25 | flag = False 26 | index = -1 27 | for ca_value in candidate: 28 | for idx, item in enumerate(text): 29 | if item == ca_value[0]: # 匹配第一个字符 30 | index = idx 31 | for i in range(len(ca_value)): 32 | if text[idx + i] == ca_value[i]: 33 | flag = True 34 | if i == len(ca_value) - 1: 35 | return index 36 | else: 37 | index = -1 38 | flag = False 39 | break 40 | else: 41 | flag = False 42 | index = -1 43 | for idx, item in enumerate(text): 44 | if item == value[0]: # 匹配第一个 45 | index = idx 46 | for i in range(len(value)): 47 | if 
text[idx + i] == value[i]: 48 | flag = True 49 | if i == len(value) - 1: 50 | return index 51 | else: 52 | index = -1 53 | flag = False 54 | break 55 | if flag: 56 | return index 57 | else: 58 | return -1 59 | 60 | 61 | def load_reverse_vocab(label_file): 62 | '''construct id2word''' 63 | vocab = {} 64 | index = 0 65 | with open(label_file, 'r', encoding='utf-8') as fp: 66 | while True: 67 | token = fp.readline() 68 | if not token: 69 | break 70 | token = token.strip() # 删除空白符 71 | vocab[index] = token 72 | index += 1 73 | return vocab 74 | 75 | 76 | def fun(filename, 77 | seq_in_target_file, 78 | seq_out_target_file, 79 | slot_none_target_file, 80 | region_dict_file, 81 | vocab): 82 | # tokenizer = AutoTokenizer.from_pretrained('nghuyong/ernie-1.0') 83 | tokenizer = AutoTokenizer.from_pretrained('nghuyong/ernie-1.0') 84 | 85 | with open(filename, 'r', encoding='utf-8') as fp: 86 | raw_data = json.load(fp) 87 | 88 | with open(region_dict_file, 'r', encoding='utf-8') as fp: 89 | region_dic = json.load(fp) 90 | 91 | seq_in = open(seq_in_target_file, 'a', encoding='utf-8') 92 | seq_out = open(seq_out_target_file, 'a', encoding='utf-8') 93 | slot_none_out = open(slot_none_target_file, 'a', encoding='utf-8') 94 | label = open('../new_data/label', 'a', encoding='utf-8') 95 | query = open('../new_data/query', 'a', encoding='utf-8') 96 | command = open('../new_data/command', 'a', encoding='utf-8') 97 | play_mode = open('../new_data/play_mode', 'a', encoding='utf-8') 98 | index = open('../new_data/index', 'a', encoding='utf-8') 99 | 100 | intent_label = set() 101 | slot_label = set() 102 | command_type = set() 103 | index_type = set() 104 | play_mode_type = set() 105 | query_type = set() 106 | 107 | count = 0 108 | raw_data_tmp = raw_data.copy() 109 | for dataname, single_data in raw_data.items(): 110 | flag = True 111 | text = single_data['text'] 112 | if text[-1] == '。': # 去掉末尾的句号 113 | text = text[:-1] 114 | text_id = tokenizer.encode(text) # 包含1和2 115 | text = [vocab[idx] for idx in text_id[1:-1]] 116 | intent = single_data['intent'] 117 | slots = single_data['slots'] 118 | slot_tags_str = ('O ' * len(text)).strip(' ') 119 | slot_tags = [item for item in slot_tags_str] # 初始化句子槽标签全为'O' 120 | # label.write(intent+'\n') 121 | # intent_label.add(intent) 122 | 123 | slot_none_str = '' 124 | 125 | for slot, value in slots.items(): 126 | if slot == 'command': 127 | command_type.add(str(value)) 128 | if isinstance(value, list): 129 | for item in value: # value是个列表,item是字符串槽值 130 | slot_none_str += item + ' ' 131 | else: 132 | slot_none_str += value + ' ' 133 | # command.write(str(value)+'\n') 134 | elif slot == 'index': 135 | index_type.add(str(value)) 136 | if isinstance(value, list): 137 | for item in value: # value是个列表,item是字符串槽值 138 | slot_none_str += item + ' ' 139 | else: 140 | slot_none_str += value + ' ' 141 | # index.write(str(value)+'\n') 142 | elif slot == 'play_mode': 143 | play_mode_type.add(str(value)) 144 | if isinstance(value, list): 145 | for item in value: # value是个列表,item是字符串槽值 146 | slot_none_str += item + ' ' 147 | else: 148 | slot_none_str += value + ' ' 149 | # play_mode.write(str(value)+'\n') 150 | elif slot == 'query_type': 151 | query_type.add(str(value)) 152 | if isinstance(value, list): 153 | for item in value: # value是个列表,item是字符串槽值 154 | slot_none_str += item + ' ' 155 | else: 156 | slot_none_str += value + ' ' 157 | # query.write(str(value)+'\n') 158 | else: 159 | # query.write('None'+'\n') 160 | # command.write('None'+'\n') 161 | # play_mode.write('None'+'\n') 162 | # 
index.write('None'+'\n') 163 | slot_label.add(slot) 164 | if isinstance(value, list): # 槽不止一个值 例如:"artist": ["陈博","嘉洋"] 165 | for v in value: 166 | v = [vocab[idx] for idx in tokenizer.encode(v)[1:-1]] 167 | start_idx = find_idx(text, v, region_dic, mode='normal') 168 | if start_idx == -1: 169 | start_idx = find_idx(text, value, region_dic, mode='mohu') # 二次检测 170 | if start_idx == -1: 171 | flag = False 172 | print("无此槽值:", slot, v, dataname) 173 | raw_data_tmp.pop(dataname) 174 | count += 1 175 | break 176 | tag_len = len(v) 177 | if tag_len == 1: 178 | slot_tags[start_idx * 2] = 'B-' + slot 179 | else: 180 | for i in range(tag_len): 181 | if i == 0: 182 | slot_tags[(start_idx + i) * 2] = 'B-' + slot 183 | else: 184 | slot_tags[(start_idx + i) * 2] = 'I-' + slot 185 | 186 | else: # 槽值为单个字符串 187 | value = [vocab[idx] for idx in tokenizer.encode(value)[1:-1]] 188 | if not value: # value 为空 189 | break 190 | start_idx = find_idx(text, value, region_dic, mode='normal') # 槽标签起始位置 191 | if start_idx == -1: # 句子无此槽值时 192 | start_idx = find_idx(text, value, region_dic, mode='mohu') # 二次检测 193 | if start_idx == -1: 194 | flag = False 195 | print("无此槽值:", slot, value, dataname) 196 | raw_data_tmp.pop(dataname) 197 | count += 1 198 | break 199 | tag_len = len(value) # 槽标签长度 200 | if tag_len == 1: 201 | slot_tags[start_idx * 2] = 'B-' + slot 202 | else: 203 | for i in range(tag_len): 204 | if i == 0: 205 | slot_tags[(start_idx + i) * 2] = 'B-' + slot 206 | else: 207 | slot_tags[(start_idx + i) * 2] = 'I-' + slot 208 | 209 | if flag: 210 | if slot_none_str == '': # slot_none_str为空,说明此句话无slot_none槽类型 211 | slot_none_out.write('None' + '\n') 212 | else: 213 | slot_none_out.write(slot_none_str + '\n') 214 | 215 | slot_tags = ''.join(slot_tags) # 把slot_tags变为str类型 216 | seq_out.write(slot_tags + '\n') 217 | seq_in.write(' '.join(text) + '\n') 218 | 219 | intent_label = ' '.join(intent_label) 220 | slot_label = ' '.join(slot_label) 221 | 222 | # with open('../new_data/intent_label.txt', 'a', encoding='utf-8') as fp: 223 | # fp.write(intent_label+'\n') 224 | # with open('../new_data/slot_label.txt', 'a', encoding='utf-8') as fp: 225 | # fp.write(slot_label+'\n') 226 | # with open('../new_data/command_type.txt', 'a', encoding='utf-8') as fp: 227 | # fp.write(str(command_type)+'\n') 228 | # with open('../new_data/index_type.txt', 'a', encoding='utf-8') as fp: 229 | # fp.write(str(index_type)+'\n') 230 | # with open('../new_data/play_mode_type.txt', 'a', encoding='utf-8') as fp: 231 | # fp.write(str(play_mode_type)+'\n') 232 | # with open('../new_data/query_type.txt', 'a', encoding='utf-8') as fp: 233 | # fp.write(str(query_type)+'\n') 234 | 235 | seq_in.close() 236 | seq_out.close() 237 | slot_none_out.close() 238 | with open('../small_sample/new_B/train_final_clear_del.json', 'w', encoding='utf-8') as fp: 239 | json.dump(raw_data_tmp, fp, ensure_ascii=False) 240 | print('缺失数目:', count) 241 | 242 | 243 | if __name__ == "__main__": 244 | vocab_file = '../pretrained_model/erine/vocab.txt' 245 | train_filename = '../small_sample/new_B/train_final_clear_sorted.json' 246 | dev_filename = '../extend_data/last_data/dev_sorted.json' 247 | test_filename = '../raw_data/test_B_final_text.json' 248 | seq_in_target_file = '../small_sample/new_B/train_seq_in.txt' 249 | seq_out_target_file = '../small_sample/new_B/train_seq_out.txt' 250 | slot_none_target_file = '../small_sample/new_B/train_slot_none.txt' 251 | region_dict_file = '../final_data/region_dic.json' 252 | vocab = load_reverse_vocab(vocab_file) 253 | 
fun(filename=train_filename, 254 | seq_in_target_file=seq_in_target_file, 255 | seq_out_target_file=seq_out_target_file, 256 | slot_none_target_file=slot_none_target_file, 257 | region_dict_file=region_dict_file, 258 | vocab=vocab) 259 | # text = ['下', '周', '六', '我', '爷', '爷', '让', '我', '去', '买', '茶', '叶', ',', '记', '得', '提', '醒', '我'] 260 | # value = ['我', '爷', '爷', '让', '我', '去', '买', '茶', '叶'] 261 | # idx = find_idx(text, value) 262 | -------------------------------------------------------------------------------- /data/code/model/InteractModel_1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 14:40 4 | # @Author : JJkinging 5 | # @File : model.py 6 | import math 7 | 8 | import torch 9 | import torch.nn as nn 10 | from transformers import BertModel 11 | from data.code.model.torchcrf import CRF 12 | import torch.nn.functional as F 13 | from data.code.scripts.config_Interact1 import Config 14 | 15 | 16 | class InteractModel(nn.Module): 17 | def __init__(self, bert_model_path, bert_hidden_size, intent_tag_size, slot_none_tag_size, 18 | slot_tag_size, device): 19 | super(InteractModel, self).__init__() 20 | self.bert_model_path = bert_model_path 21 | self.bert_hidden_size = bert_hidden_size 22 | self.intent_tag_size = intent_tag_size 23 | self.slot_none_tag_size = slot_none_tag_size 24 | self.slot_tag_size = slot_tag_size 25 | self.device = device 26 | self.bert = BertModel.from_pretrained(self.bert_model_path) 27 | self.CRF = CRF(num_tags=self.slot_tag_size, batch_first=True) 28 | self.IntentClassify = nn.Linear(self.bert_hidden_size, self.intent_tag_size) 29 | self.SlotNoneClassify = nn.Linear(self.bert_hidden_size, self.slot_none_tag_size) 30 | self.SlotClassify = nn.Linear(self.bert_hidden_size, self.slot_tag_size) 31 | self.Dropout = nn.Dropout(p=0.5) 32 | 33 | self.I_S_Emb = Label_Attention(self.IntentClassify, self.SlotClassify) 34 | self.T_block1 = I_S_Block(self.IntentClassify, self.SlotClassify, self.bert_hidden_size) 35 | 36 | def forward(self, input_ids, input_mask): 37 | batch_size = input_ids.size(0) 38 | seq_len = input_ids.size(1) 39 | utter_encoding = self.bert(input_ids, input_mask) 40 | H = utter_encoding[0] 41 | pooler_out = utter_encoding[1] 42 | H = self.Dropout(H) 43 | pooler_out = self.Dropout(pooler_out) 44 | 45 | # 1. 
Label Attention 46 | H_I, H_S = self.I_S_Emb(H, H, input_mask) 47 | # Co-Interactive Attention Layer 48 | H_I, H_S = self.T_block1(H_I + H, H_S + H, input_mask) 49 | 50 | intent_input = F.max_pool1d((H_I + H).transpose(1, 2), H_I.size(1)).squeeze(2) 51 | 52 | intent_logits = self.IntentClassify(intent_input) 53 | slot_none_logits = self.SlotNoneClassify(pooler_out) 54 | slot_logits = self.SlotClassify(H_S + H) 55 | 56 | return intent_logits, slot_none_logits, slot_logits 57 | 58 | def slot_loss(self, feats, slot_ids, mask): 59 | ''' 做训练时用 60 | :param feats: the output of BiLSTM and Liner 61 | :param slot_ids: 62 | :param mask: 63 | :return: 64 | ''' 65 | feats = feats.to(self.device) 66 | slot_ids = slot_ids.to(self.device) 67 | mask = mask.to(self.device) 68 | loss_value = self.CRF(emissions=feats, 69 | tags=slot_ids, 70 | mask=mask, 71 | reduction='mean') 72 | return -loss_value 73 | 74 | def slot_predict(self, feats, mask, id2slot): 75 | feats = feats.to(self.device) 76 | mask = mask.to(self.device) 77 | slot2id = {value: key for key, value in id2slot.items()} 78 | # 做验证和测试时用 79 | out_path = self.CRF.decode(emissions=feats, mask=mask) 80 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path] 81 | for out in out_path: 82 | for i, tag in enumerate(out): # tag为O、B-*、I-* 等等 83 | if tag.startswith('I-'): # 当前tag为I-开头 84 | if i == 0: # 0位置应该是[START] 85 | out[i] = '[START]' 86 | elif out[i-1] == 'O' or out[i-1] == '[START]': # 但是前一个tag不是以B-开头的 87 | out[i] = id2slot[slot2id[tag]-1] # 将其纠正为对应的B-开头的tag 88 | 89 | out_path = [[slot2id[idx] for idx in one_data] for one_data in out_path] 90 | 91 | return out_path 92 | 93 | 94 | class Label_Attention(nn.Module): 95 | def __init__(self, intent_emb, slot_emb): 96 | super(Label_Attention, self).__init__() 97 | 98 | self.W_intent_emb = intent_emb.weight # [num_class, hidden_dize] 99 | self.W_slot_emb = slot_emb.weight 100 | 101 | def forward(self, input_intent, input_slot, mask): 102 | intent_score = torch.matmul(input_intent, self.W_intent_emb.t()) 103 | slot_score = torch.matmul(input_slot, self.W_slot_emb.t()) 104 | intent_probs = nn.Softmax(dim=-1)(intent_score) 105 | slot_probs = nn.Softmax(dim=-1)(slot_score) 106 | intent_res = torch.matmul(intent_probs, self.W_intent_emb) # [bs, seq_len, hidden_size] 107 | slot_res = torch.matmul(slot_probs, self.W_slot_emb) 108 | 109 | return intent_res, slot_res 110 | 111 | 112 | class I_S_Block(nn.Module): 113 | def __init__(self, intent_emb, slot_emb, hidden_size): 114 | super(I_S_Block, self).__init__() 115 | config = Config() 116 | self.I_S_Attention = I_S_SelfAttention(hidden_size, 2 * hidden_size, hidden_size) 117 | self.I_Out = SelfOutput(hidden_size, config.attention_dropout) 118 | self.S_Out = SelfOutput(hidden_size, config.attention_dropout) 119 | self.I_S_Feed_forward = Intermediate_I_S(hidden_size, hidden_size) 120 | 121 | def forward(self, H_intent_input, H_slot_input, mask): 122 | H_slot, H_intent = self.I_S_Attention(H_intent_input, H_slot_input, mask) 123 | H_slot = self.S_Out(H_slot, H_slot_input) # H_slot_input: label attention的输出 124 | H_intent = self.I_Out(H_intent, H_intent_input) 125 | H_intent, H_slot = self.I_S_Feed_forward(H_intent, H_slot) 126 | 127 | return H_intent, H_slot 128 | 129 | 130 | class I_S_SelfAttention(nn.Module): 131 | def __init__(self, input_size, hidden_size, out_size): 132 | super(I_S_SelfAttention, self).__init__() 133 | config = Config() 134 | 135 | self.num_attention_heads = 12 136 | self.attention_head_size = int(hidden_size / 
self.num_attention_heads) 137 | 138 | self.all_head_size = self.num_attention_heads * self.attention_head_size 139 | self.out_size = out_size 140 | self.query = nn.Linear(input_size, self.all_head_size) 141 | self.query_slot = nn.Linear(input_size, self.all_head_size) 142 | self.key = nn.Linear(input_size, self.all_head_size) 143 | self.key_slot = nn.Linear(input_size, self.all_head_size) 144 | self.value = nn.Linear(input_size, self.out_size) 145 | self.value_slot = nn.Linear(input_size, self.out_size) 146 | self.dropout = nn.Dropout(config.attention_dropout) 147 | 148 | def transpose_for_scores(self, x): 149 | last_dim = int(x.size()[-1] / self.num_attention_heads) 150 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, last_dim) 151 | x = x.view(*new_x_shape) 152 | return x.permute(0, 2, 1, 3) 153 | 154 | def forward(self, intent, slot, mask): 155 | extended_attention_mask = mask.unsqueeze(1).unsqueeze(2) 156 | 157 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 158 | attention_mask = (1.0 - extended_attention_mask) * -10000.0 159 | 160 | mixed_query_layer = self.query(intent) 161 | mixed_key_layer = self.key(slot) 162 | mixed_value_layer = self.value(slot) 163 | 164 | mixed_query_layer_slot = self.query_slot(slot) 165 | mixed_key_layer_slot = self.key_slot(intent) 166 | mixed_value_layer_slot = self.value_slot(intent) 167 | 168 | query_layer = self.transpose_for_scores(mixed_query_layer) 169 | query_layer_slot = self.transpose_for_scores(mixed_query_layer_slot) 170 | key_layer = self.transpose_for_scores(mixed_key_layer) 171 | key_layer_slot = self.transpose_for_scores(mixed_key_layer_slot) 172 | value_layer = self.transpose_for_scores(mixed_value_layer) 173 | value_layer_slot = self.transpose_for_scores(mixed_value_layer_slot) 174 | 175 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 176 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 177 | # attention_scores_slot = torch.matmul(query_slot, key_slot.transpose(1,0)) 178 | attention_scores_slot = torch.matmul(query_layer_slot, key_layer_slot.transpose(-1, -2)) 179 | attention_scores_slot = attention_scores_slot / math.sqrt(self.attention_head_size) 180 | attention_scores_intent = attention_scores + attention_mask 181 | 182 | attention_scores_slot = attention_scores_slot + attention_mask 183 | 184 | # Normalize the attention scores to probabilities. 
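        # Shapes at this point (assuming batch size B, sequence length L and the 12 heads
        # configured above): attention_scores_intent and attention_scores_slot are both
        # [B, 12, L, L]; the softmax below is taken over the last (key) dimension.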
185 | attention_probs_slot = nn.Softmax(dim=-1)(attention_scores_slot) 186 | attention_probs_intent = nn.Softmax(dim=-1)(attention_scores_intent) 187 | 188 | attention_probs_slot = self.dropout(attention_probs_slot) 189 | attention_probs_intent = self.dropout(attention_probs_intent) 190 | 191 | context_layer_slot = torch.matmul(attention_probs_slot, value_layer_slot) 192 | context_layer_intent = torch.matmul(attention_probs_intent, value_layer) 193 | 194 | context_layer = context_layer_slot.permute(0, 2, 1, 3).contiguous() 195 | context_layer_intent = context_layer_intent.permute(0, 2, 1, 3).contiguous() 196 | new_context_layer_shape = context_layer.size()[:-2] + (self.out_size,) 197 | new_context_layer_shape_intent = context_layer_intent.size()[:-2] + (self.out_size,) 198 | 199 | context_layer = context_layer.view(*new_context_layer_shape) 200 | context_layer_intent = context_layer_intent.view(*new_context_layer_shape_intent) 201 | return context_layer, context_layer_intent 202 | 203 | 204 | class SelfOutput(nn.Module): 205 | def __init__(self, hidden_size, hidden_dropout_prob): 206 | super(SelfOutput, self).__init__() 207 | self.dense = nn.Linear(hidden_size, hidden_size) 208 | self.LayerNorm = LayerNorm(hidden_size, eps=1e-12) 209 | self.dropout = nn.Dropout(hidden_dropout_prob) 210 | 211 | def forward(self, hidden_states, input_tensor): 212 | hidden_states = self.dense(hidden_states) 213 | hidden_states = self.dropout(hidden_states) 214 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 215 | return hidden_states 216 | 217 | 218 | class LayerNorm(nn.Module): 219 | def __init__(self, hidden_size, eps=1e-12): 220 | """Construct a layernorm module in the TF style (epsilon inside the square root). 221 | """ 222 | super(LayerNorm, self).__init__() 223 | self.weight = nn.Parameter(torch.ones(hidden_size)) 224 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 225 | self.variance_epsilon = eps 226 | 227 | def forward(self, x): 228 | u = x.mean(-1, keepdim=True) 229 | s = (x - u).pow(2).mean(-1, keepdim=True) 230 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 231 | return self.weight * x + self.bias 232 | 233 | 234 | class Intermediate_I_S(nn.Module): 235 | def __init__(self, intermediate_size, hidden_size): 236 | super(Intermediate_I_S, self).__init__() 237 | self.config = Config() 238 | self.dense_in = nn.Linear(hidden_size * 6, intermediate_size) 239 | self.intermediate_act_fn = nn.ReLU() 240 | self.dense_out = nn.Linear(intermediate_size, hidden_size) 241 | self.LayerNorm_I = LayerNorm(hidden_size, eps=1e-12) 242 | self.LayerNorm_S = LayerNorm(hidden_size, eps=1e-12) 243 | self.dropout = nn.Dropout(self.config.attention_dropout) 244 | 245 | def forward(self, hidden_states_I, hidden_states_S): 246 | hidden_states_in = torch.cat([hidden_states_I, hidden_states_S], dim=2) 247 | batch_size, max_length, hidden_size = hidden_states_in.size() 248 | h_pad = torch.zeros(batch_size, 1, hidden_size) 249 | if self.config.use_gpu and torch.cuda.is_available(): 250 | h_pad = h_pad.cuda() 251 | h_left = torch.cat([h_pad, hidden_states_in[:, :max_length - 1, :]], dim=1) 252 | h_right = torch.cat([hidden_states_in[:, 1:, :], h_pad], dim=1) 253 | hidden_states_in = torch.cat([hidden_states_in, h_left, h_right], dim=2) 254 | 255 | hidden_states = self.dense_in(hidden_states_in) 256 | hidden_states = self.intermediate_act_fn(hidden_states) 257 | hidden_states = self.dense_out(hidden_states) 258 | hidden_states = self.dropout(hidden_states) 259 | hidden_states_I_NEW = 
self.LayerNorm_I(hidden_states + hidden_states_I) 260 | hidden_states_S_NEW = self.LayerNorm_S(hidden_states + hidden_states_S) 261 | return hidden_states_I_NEW, hidden_states_S_NEW -------------------------------------------------------------------------------- /data/code/model/InteractModel_3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/8 14:40 4 | # @Author : JJkinging 5 | # @File : model.py 6 | import math 7 | 8 | import torch 9 | import torch.nn as nn 10 | from transformers import BertModel 11 | from data.code.model.torchcrf import CRF 12 | import torch.nn.functional as F 13 | from data.code.scripts.config_Interact3 import Config 14 | 15 | 16 | class InteractModel(nn.Module): 17 | def __init__(self, bert_model_path, bert_hidden_size, intent_tag_size, slot_none_tag_size, 18 | slot_tag_size, device): 19 | super(InteractModel, self).__init__() 20 | self.bert_model_path = bert_model_path 21 | self.bert_hidden_size = bert_hidden_size 22 | self.intent_tag_size = intent_tag_size 23 | self.slot_none_tag_size = slot_none_tag_size 24 | self.slot_tag_size = slot_tag_size 25 | self.device = device 26 | self.bert = BertModel.from_pretrained(self.bert_model_path) 27 | self.CRF = CRF(num_tags=self.slot_tag_size, batch_first=True) 28 | self.IntentClassify = nn.Linear(self.bert_hidden_size, self.intent_tag_size) 29 | self.SlotNoneClassify = nn.Linear(self.bert_hidden_size, self.slot_none_tag_size) 30 | self.SlotClassify = nn.Linear(self.bert_hidden_size, self.slot_tag_size) 31 | self.Dropout = nn.Dropout(p=0.5) 32 | 33 | self.I_S_Emb = Label_Attention(self.IntentClassify, self.SlotClassify) 34 | self.T_block1 = I_S_Block(self.IntentClassify, self.SlotClassify, self.bert_hidden_size) 35 | self.T_block2 = I_S_Block(self.IntentClassify, self.SlotClassify, self.bert_hidden_size) 36 | self.T_block3 = I_S_Block(self.IntentClassify, self.SlotClassify, self.bert_hidden_size) 37 | 38 | def forward(self, input_ids, input_mask): 39 | batch_size = input_ids.size(0) 40 | seq_len = input_ids.size(1) 41 | utter_encoding = self.bert(input_ids, input_mask) 42 | H = utter_encoding[0] 43 | pooler_out = utter_encoding[1] 44 | H = self.Dropout(H) 45 | pooler_out = self.Dropout(pooler_out) 46 | 47 | # 1. Label Attention 48 | H_I, H_S = self.I_S_Emb(H, H, input_mask) 49 | # Co-Interactive Attention Layer 50 | H_I, H_S = self.T_block1(H_I + H, H_S + H, input_mask) 51 | 52 | # 2. Label Attention 53 | H_I_1, H_S_1 = self.I_S_Emb(H_I, H_S, input_mask) 54 | # # # Co-Interactive Attention Layer 55 | H_I, H_S = self.T_block2(H_I + H_I_1, H_S + H_S_1, input_mask) 56 | 57 | # 3. 
Label Attention 58 | H_I_2, H_S_2 = self.I_S_Emb(H_I, H_S, input_mask) 59 | # # # Co-Interactive Attention Layer 60 | H_I, H_S = self.T_block3(H_I + H_I_2, H_S + H_S_2, input_mask) 61 | 62 | intent_input = F.max_pool1d((H_I + H).transpose(1, 2), H_I.size(1)).squeeze(2) 63 | 64 | intent_logits = self.IntentClassify(intent_input) 65 | slot_none_logits = self.SlotNoneClassify(pooler_out) 66 | slot_logits = self.SlotClassify(H_S + H) 67 | 68 | return intent_logits, slot_none_logits, slot_logits 69 | 70 | def slot_loss(self, feats, slot_ids, mask): 71 | ''' 做训练时用 72 | :param feats: the output of BiLSTM and Liner 73 | :param slot_ids: 74 | :param mask: 75 | :return: 76 | ''' 77 | feats = feats.to(self.device) 78 | slot_ids = slot_ids.to(self.device) 79 | mask = mask.to(self.device) 80 | loss_value = self.CRF(emissions=feats, 81 | tags=slot_ids, 82 | mask=mask, 83 | reduction='mean') 84 | return -loss_value 85 | 86 | def slot_predict(self, feats, mask, id2slot): 87 | feats = feats.to(self.device) 88 | mask = mask.to(self.device) 89 | slot2id = {value: key for key, value in id2slot.items()} 90 | # 做验证和测试时用 91 | out_path = self.CRF.decode(emissions=feats, mask=mask) 92 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path] 93 | for out in out_path: 94 | for i, tag in enumerate(out): # tag为O、B-*、I-* 等等 95 | if tag.startswith('I-'): # 当前tag为I-开头 96 | if i == 0: # 0位置应该是[START] 97 | out[i] = '[START]' 98 | elif out[i-1] == 'O' or out[i-1] == '[START]': # 但是前一个tag不是以B-开头的 99 | out[i] = id2slot[slot2id[tag]-1] # 将其纠正为对应的B-开头的tag 100 | 101 | out_path = [[slot2id[idx] for idx in one_data] for one_data in out_path] 102 | 103 | return out_path 104 | 105 | class Label_Attention(nn.Module): 106 | def __init__(self, intent_emb, slot_emb): 107 | super(Label_Attention, self).__init__() 108 | 109 | self.W_intent_emb = intent_emb.weight # [num_class, hidden_dize] 110 | self.W_slot_emb = slot_emb.weight 111 | 112 | def forward(self, input_intent, input_slot, mask): 113 | intent_score = torch.matmul(input_intent, self.W_intent_emb.t()) 114 | slot_score = torch.matmul(input_slot, self.W_slot_emb.t()) 115 | intent_probs = nn.Softmax(dim=-1)(intent_score) 116 | slot_probs = nn.Softmax(dim=-1)(slot_score) 117 | intent_res = torch.matmul(intent_probs, self.W_intent_emb) # [bs, seq_len, hidden_size] 118 | slot_res = torch.matmul(slot_probs, self.W_slot_emb) 119 | 120 | return intent_res, slot_res 121 | 122 | 123 | class I_S_Block(nn.Module): 124 | def __init__(self, intent_emb, slot_emb, hidden_size): 125 | super(I_S_Block, self).__init__() 126 | config = Config() 127 | self.I_S_Attention = I_S_SelfAttention(hidden_size, 2 * hidden_size, hidden_size) 128 | self.I_Out = SelfOutput(hidden_size, config.attention_dropout) 129 | self.S_Out = SelfOutput(hidden_size, config.attention_dropout) 130 | self.I_S_Feed_forward = Intermediate_I_S(hidden_size, hidden_size) 131 | 132 | def forward(self, H_intent_input, H_slot_input, mask): 133 | H_slot, H_intent = self.I_S_Attention(H_intent_input, H_slot_input, mask) 134 | H_slot = self.S_Out(H_slot, H_slot_input) # H_slot_input: label attention的输出 135 | H_intent = self.I_Out(H_intent, H_intent_input) 136 | H_intent, H_slot = self.I_S_Feed_forward(H_intent, H_slot) 137 | 138 | return H_intent, H_slot 139 | 140 | 141 | class I_S_SelfAttention(nn.Module): 142 | def __init__(self, input_size, hidden_size, out_size): 143 | super(I_S_SelfAttention, self).__init__() 144 | config = Config() 145 | 146 | self.num_attention_heads = 12 147 | self.attention_head_size = 
int(hidden_size / self.num_attention_heads) 148 | 149 | self.all_head_size = self.num_attention_heads * self.attention_head_size 150 | self.out_size = out_size 151 | self.query = nn.Linear(input_size, self.all_head_size) 152 | self.query_slot = nn.Linear(input_size, self.all_head_size) 153 | self.key = nn.Linear(input_size, self.all_head_size) 154 | self.key_slot = nn.Linear(input_size, self.all_head_size) 155 | self.value = nn.Linear(input_size, self.out_size) 156 | self.value_slot = nn.Linear(input_size, self.out_size) 157 | self.dropout = nn.Dropout(config.attention_dropout) 158 | 159 | def transpose_for_scores(self, x): 160 | last_dim = int(x.size()[-1] / self.num_attention_heads) 161 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, last_dim) 162 | x = x.view(*new_x_shape) 163 | return x.permute(0, 2, 1, 3) 164 | 165 | def forward(self, intent, slot, mask): 166 | extended_attention_mask = mask.unsqueeze(1).unsqueeze(2) 167 | 168 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 169 | attention_mask = (1.0 - extended_attention_mask) * -10000.0 170 | 171 | mixed_query_layer = self.query(intent) 172 | mixed_key_layer = self.key(slot) 173 | mixed_value_layer = self.value(slot) 174 | 175 | mixed_query_layer_slot = self.query_slot(slot) 176 | mixed_key_layer_slot = self.key_slot(intent) 177 | mixed_value_layer_slot = self.value_slot(intent) 178 | 179 | query_layer = self.transpose_for_scores(mixed_query_layer) 180 | query_layer_slot = self.transpose_for_scores(mixed_query_layer_slot) 181 | key_layer = self.transpose_for_scores(mixed_key_layer) 182 | key_layer_slot = self.transpose_for_scores(mixed_key_layer_slot) 183 | value_layer = self.transpose_for_scores(mixed_value_layer) 184 | value_layer_slot = self.transpose_for_scores(mixed_value_layer_slot) 185 | 186 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 187 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 188 | # attention_scores_slot = torch.matmul(query_slot, key_slot.transpose(1,0)) 189 | attention_scores_slot = torch.matmul(query_layer_slot, key_layer_slot.transpose(-1, -2)) 190 | attention_scores_slot = attention_scores_slot / math.sqrt(self.attention_head_size) 191 | attention_scores_intent = attention_scores + attention_mask 192 | 193 | attention_scores_slot = attention_scores_slot + attention_mask 194 | 195 | # Normalize the attention scores to probabilities. 
196 | attention_probs_slot = nn.Softmax(dim=-1)(attention_scores_slot) 197 | attention_probs_intent = nn.Softmax(dim=-1)(attention_scores_intent) 198 | 199 | attention_probs_slot = self.dropout(attention_probs_slot) 200 | attention_probs_intent = self.dropout(attention_probs_intent) 201 | 202 | context_layer_slot = torch.matmul(attention_probs_slot, value_layer_slot) 203 | context_layer_intent = torch.matmul(attention_probs_intent, value_layer) 204 | 205 | context_layer = context_layer_slot.permute(0, 2, 1, 3).contiguous() 206 | context_layer_intent = context_layer_intent.permute(0, 2, 1, 3).contiguous() 207 | new_context_layer_shape = context_layer.size()[:-2] + (self.out_size,) 208 | new_context_layer_shape_intent = context_layer_intent.size()[:-2] + (self.out_size,) 209 | 210 | context_layer = context_layer.view(*new_context_layer_shape) 211 | context_layer_intent = context_layer_intent.view(*new_context_layer_shape_intent) 212 | return context_layer, context_layer_intent 213 | 214 | 215 | class SelfOutput(nn.Module): 216 | def __init__(self, hidden_size, hidden_dropout_prob): 217 | super(SelfOutput, self).__init__() 218 | self.dense = nn.Linear(hidden_size, hidden_size) 219 | self.LayerNorm = LayerNorm(hidden_size, eps=1e-12) 220 | self.dropout = nn.Dropout(hidden_dropout_prob) 221 | 222 | def forward(self, hidden_states, input_tensor): 223 | hidden_states = self.dense(hidden_states) 224 | hidden_states = self.dropout(hidden_states) 225 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 226 | return hidden_states 227 | 228 | 229 | class LayerNorm(nn.Module): 230 | def __init__(self, hidden_size, eps=1e-12): 231 | """Construct a layernorm module in the TF style (epsilon inside the square root). 232 | """ 233 | super(LayerNorm, self).__init__() 234 | self.weight = nn.Parameter(torch.ones(hidden_size)) 235 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 236 | self.variance_epsilon = eps 237 | 238 | def forward(self, x): 239 | u = x.mean(-1, keepdim=True) 240 | s = (x - u).pow(2).mean(-1, keepdim=True) 241 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 242 | return self.weight * x + self.bias 243 | 244 | 245 | class Intermediate_I_S(nn.Module): 246 | def __init__(self, intermediate_size, hidden_size): 247 | super(Intermediate_I_S, self).__init__() 248 | self.config = Config() 249 | self.dense_in = nn.Linear(hidden_size * 6, intermediate_size) 250 | self.intermediate_act_fn = nn.ReLU() 251 | self.dense_out = nn.Linear(intermediate_size, hidden_size) 252 | self.LayerNorm_I = LayerNorm(hidden_size, eps=1e-12) 253 | self.LayerNorm_S = LayerNorm(hidden_size, eps=1e-12) 254 | self.dropout = nn.Dropout(self.config.attention_dropout) 255 | 256 | def forward(self, hidden_states_I, hidden_states_S): 257 | hidden_states_in = torch.cat([hidden_states_I, hidden_states_S], dim=2) 258 | batch_size, max_length, hidden_size = hidden_states_in.size() 259 | h_pad = torch.zeros(batch_size, 1, hidden_size) 260 | if self.config.use_gpu and torch.cuda.is_available(): 261 | h_pad = h_pad.cuda() 262 | h_left = torch.cat([h_pad, hidden_states_in[:, :max_length - 1, :]], dim=1) 263 | h_right = torch.cat([hidden_states_in[:, 1:, :], h_pad], dim=1) 264 | hidden_states_in = torch.cat([hidden_states_in, h_left, h_right], dim=2) 265 | 266 | hidden_states = self.dense_in(hidden_states_in) 267 | hidden_states = self.intermediate_act_fn(hidden_states) 268 | hidden_states = self.dense_out(hidden_states) 269 | hidden_states = self.dropout(hidden_states) 270 | hidden_states_I_NEW = 
self.LayerNorm_I(hidden_states + hidden_states_I) 271 | hidden_states_S_NEW = self.LayerNorm_S(hidden_states + hidden_states_S) 272 | return hidden_states_I_NEW, hidden_states_S_NEW -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 2021 CCF BDCI 全国信息检索挑战杯(CCIR-Cup) 智能人机交互自然语言理解赛道第二名解决方案(意图识别&槽填充)
2 | 比赛网址: CCIR-Cup-智能人机交互自然语言理解 3 | ![@}77TWD9UCSZ5~5~5NQ5RAP](https://user-images.githubusercontent.com/59439162/139065196-254292ae-81d9-44f4-945b-8cc17178e41d.png) 4 | ## 1.依赖环境: 5 | 6 | * python==3.8 7 | * torch==1.7.1+cu110 8 | * numpy==1.19.2 9 | * transformers==4.5.1 10 | * scikit_learn==1.0 11 | * seqeval==1.2.2 12 | * tqdm==4.50.2 13 | * CUDA==11.0 14 | 15 | ## 2.解决方案 16 | 17 | ### 1.数据预处理部分 18 | 19 | 1.首先对任务进行明确,赛题任务是意图识别与槽填充,在观察训练数据之后,发现槽填充任务可分解为两类任务来做:一类是标准槽填充任务,即槽值可在当前对话句子当中完全匹配到;另一类是非标准槽填充(自拟的名字),即槽值不可在当前对话句子中找到或者完全匹配。对于非标准槽填充任务,把它当作另一种分类任务来解决。所以,我就把比赛任务当作三个子任务来进行,分别是意图识别、标准槽填充和非标准槽填充(分类任务)。 20 | 21 | 2.对于标准槽填充而言,有些槽标签在大部分训练数据中都是可完全匹配的,但是仍然存在少量不完全匹配槽标签,例如:句中出现的是港片、韩剧、美剧、内地、英文、中文等词汇时,对应的槽值标注却是香港、韩国、美国、大陆、英语、华语等。对于这一类数据,若将其当作非标准槽填充任务显得不太合理,解决方案是:提前准备好一个特殊词汇的映射字典 region_dic.json,在处理训练数据的时候,如果碰到了有的槽值出现在特殊词汇字典中,则对其进行槽标注的时候需要先进行转换。 22 | 23 | 3.然后用 ernie 的 Tokenizer 将对话句子按字切分,可能存在 ##后缀 和 [UNK] 的情况,将切分好的句子作为原始输入,对于标准槽填充任务,按 BIO 的方式进行标注。观察训练数据后发现,标准槽填充任务中存在少量的嵌套命名实体,例如: 24 | 25 | ```json 26 | { 27 | "NLU07301": { 28 | "text": "西安WE比赛的视频帮我找一下", 29 | "intent": "Video-Play", 30 | "slots": { 31 | "name": "西安WE比赛的视频", 32 | "region": "西安" 33 | } 34 | } 35 | } 36 | ``` 37 | 38 | 对于这种少量嵌套的例子,我并没有涉及特殊的网络结构来解决,而是使用了一种简单的方式:首先将每条训练数据的 slots 下的所有槽值的长度按从大到小排列,然后在对其进行序列标注的时候按槽的先后顺序进行标注,比如上面的例子的标注方式为: 39 | 40 | 首先是 "name" 这个槽标签: 41 | 42 | ``` 43 | B-name I-name I-name I-name I-name I-name I-name I-name I-name O O O O O 44 | ``` 45 | 46 | 然后是 "region" 这个槽标签: 47 | 48 | ``` 49 | B-region I-region I-name I-name I-name I-name I-name I-name I-name O O O O O 50 | ``` 51 | 52 | "region" 标注完成后将 B-name I-name 这两个tag给覆盖掉了。 53 | 54 | 最后,对于这种嵌套实体,模型就按照这种标注方式去训练;在解码的时候,按照一定的匹配规则识别出"name" 和 "region"这两个槽标签,后面的实验中表明使用这种标注方式能够有效的识别出测试集数据中存在的嵌套实体。 55 | 56 | 4.手动纠正部分,训练数据中存在一些数据有着明显的标注错误。例如:NLU00400 的 "artist": "银临,河图",正确的标注应该为 "artist": ["银临",""河图"];"query_type":"汽车票查询" 和 "query_type":"汽车票",这两个明显就是一样的,故将其统一,再比如: 57 | 58 | ```json 59 | { 60 | "NLU04386": { 61 | "text": "明天早上7:20你会不会通知我带洗衣液", 62 | "intent": "Alarm-Update", 63 | "slots": { 64 | "datetime_date": "明天", 65 | "notes": "带洗衣液", 66 | "datetime_time": "早上7" 67 | } 68 | } 69 | } 70 | ``` 71 | 72 | "datetime_time":"早上7" 标注错误,正确标注应该为"datetime_time": "早上7:20",且类似这种的错误在"Alarm-Update"这个意图中大量存在。如果不进行纠正,则对模型的训练会造成很大的影响。 73 | 74 | 5.对于非标准槽填充部分,统计出了四种槽标签:command、index、play_mode 和 query_type,将这四类槽标签的槽值当作类别,一共有29种类别,如: 75 | 76 | ``` 77 | 音量调节 78 | 查询状态 79 | 汽车票查询 80 | 穿衣指数 81 | 紫外线指数 82 | ... 83 | None 84 | ``` 85 | 86 | None表示不存在这一类值。然后对其进行类似意图识别那样做分类。 87 | 88 | 6.对于域外检测任务:在 a 榜阶段,我在 LCQMC 数据集中选择了1000条左右数据作为 Other 数据的来源;在 b 榜阶段,我把 a 榜阶段预测出的 intent 为 Other 的数据加上 LCQMC数据集中选择出500条数据一起作为训练集的 intent 为 Other 类进行训练。 89 | 90 | 7.对于小样本检测任务:发现意图为 Audio-Play 和 TVProgram_Play 的这两个意图是小样本数据,在原始训练集中分别为50条。解决方法:对小样本意图数据进行数据增强,使其数量接近基本任务数据的1000条左右,具体做法:分别对 Audio-Play 和TVProgram_Play这两个意图进行增强,举例来说,对于 Audio-Play 而言,其可能的槽标签有 91 | 92 | 8.在模型输出后,进行一步后处理操作:对模型预测结果进行纠正,即把不属于某一类intent的槽值删除,首先统计出训练数据中的意图和槽标签之间的关系:
93 | ```json 94 | { 95 | "FilmTele-Play": ["name", "tag", "artist", "region", "play_setting", "age"], 96 | "Audio-Play": ["language", "artist", "tag", "name", "play_setting"], 97 | "Radio-Listen": ["name", "channel", "frequency", "artist"], 98 | "TVProgram-Play": ["name", "channel", "datetime_date", "datetime_time"], 99 | "Travel-Query": ["query_type", "datetime_date", "departure", "destination", "datetime_time"], 100 | "Music-Play": ["language", "artist", "album", "instrument", "song", "play_mode", "age"], 101 | "HomeAppliance-Control": ["command", "appliance", "details"], 102 | "Calendar-Query": ["datetime_date"], 103 | "Alarm-Update": ["notes", "datetime_date", "datetime_time"], 104 | "Video-Play": ["name", "datetime_date", "region", "datetime_time"], 105 | "Weather-Query": ["datetime_date", "type", "city", "index", "datetime_time"], 106 | "Other": [] 107 | } 108 | ``` 109 | 就可以发现意图只能包含特定的标签,某些标签不可能出现在其它意图中
110 | 举例来说:比如某条测试数据的 intent 被预测为 FilmTele-Play,而其槽预测中出现了"notes"这个槽标签,这与上面统计出的意图—槽标签对应关系不符(训练数据中 FilmTele-Play 这个意图不会出现"notes"这个槽标签),所以后处理函数就把"notes"这个槽位和槽值删除掉(过滤逻辑的示意代码见下文)。 111 | 112 | 接上文第 7 点的小样本扩充:以 "Audio-Play": ["language", "artist", "tag", "name", "play_setting"] 为例,一共五类槽标签,分别在训练数据中统计出各个槽标签可能出现的槽值,比如 "language": ["日语", "英语", "法语", "俄语", "西班牙语", "华语", "韩语", "德语", "藏语"],当然也可以自己再添加几个合适的候选项。得到这样一个字典后,对每一条意图为 Audio-Play 的原训练数据进行扩充,每条扩充出 20 条新数据,具体方式是:对原数据中出现的每一个槽标签,在其对应的候选项中随机选择一个值替换原槽值,这样就得到一条"新数据"(扩充方式的示意代码见下文)。经过实验验证,小样本数据扩充前后线上得分提升了两个点,说明这种扩充方式效果不错。
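下面用一段简化的示意代码说明上述"随机替换槽值"的扩充方式(仅为思路示意:候选槽值字典、函数名均为示意写法,并非仓库中 extend_audio_sample.py 的实际实现):

```python
import random

# 假设 candidates 为在训练数据中统计得到的"槽标签 -> 候选槽值"字典(这里只列出一部分作示意)
candidates = {
    "language": ["日语", "英语", "法语", "俄语", "西班牙语", "华语", "韩语", "德语", "藏语"],
    # "artist": [...]、"tag": [...]、"name": [...]、"play_setting": [...] 同理
}

def augment(sample, n_new=20):
    """对一条意图为 Audio-Play 的样本生成 n_new 条新数据:
    把原句中出现的每个槽值,替换为同一槽标签下随机选取的候选值。"""
    new_samples = []
    for _ in range(n_new):
        text, slots = sample["text"], dict(sample["slots"])
        for slot, value in sample["slots"].items():
            if slot in candidates and isinstance(value, str) and value in text:
                new_value = random.choice(candidates[slot])
                text = text.replace(value, new_value)
                slots[slot] = new_value
        new_samples.append({"text": text, "intent": sample["intent"], "slots": slots})
    return new_samples

# 用法示意:
# sample = {"text": "放一首日语的歌", "intent": "Audio-Play", "slots": {"language": "日语"}}
# new_data = augment(sample)
```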
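上面第 8 点所述的后处理(删除与预测意图不兼容的槽位)大致可以写成如下形式(简化示意,函数名与读取路径均为示意,并非仓库中 post_process.py 的实际实现;假设上文列出的映射已保存为 intent_slot_mapping.json):

```python
import json

# 即上文列出的"意图 -> 允许出现的槽标签"映射
with open('intent_slot_mapping.json', 'r', encoding='utf-8') as fp:
    intent_slot_mapping = json.load(fp)

def filter_slots(pred):
    """删除与预测意图不兼容的槽位。
    pred 形如 {"intent": "FilmTele-Play", "slots": {"name": "...", "notes": "..."}}"""
    allowed = set(intent_slot_mapping.get(pred["intent"], []))
    pred["slots"] = {k: v for k, v in pred["slots"].items() if k in allowed}
    return pred

# 用法示意:对每条预测结果逐条过滤
# result = {nlu_id: filter_slots(item) for nlu_id, item in result.items()}
```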
113 | 9.在a榜阶段,训练数据包括三部分:原始训练数据 + 小样本意图扩充数据 + LCQMC数据集(1000条)当作域外数据;在b榜阶段,训练数据除了和a榜相同的部分外,还把模型在a榜测试集上的输出当作a榜测试集的标注,再将这些数据也添加到训练数据中进行训练,然后再去预测b榜测试集,所以b榜阶段的最终使用的训练集有13119条。
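第 9 点中"把模型在 a 榜测试集上的输出当作标注加入训练集"即常见的伪标签(self-training)做法。下面是拼接伪标签数据的一个简化示意(文件名与字段组织均为示意,并非仓库脚本的实际实现):

```python
import json

# 假设:test_A_final_text.json 为 a 榜测试集原文、result_A.json 为模型在 a 榜上的预测结果
# (两个文件名均为示意),二者都以 NLUxxxxx 编号作为键
with open('test_A_final_text.json', 'r', encoding='utf-8') as fp:
    test_a_text = json.load(fp)
with open('result_A.json', 'r', encoding='utf-8') as fp:
    test_a_pred = json.load(fp)

pseudo_labeled = {}
for nlu_id, pred in test_a_pred.items():
    pseudo_labeled[nlu_id] = {
        "text": test_a_text[nlu_id]["text"],  # a 榜原句
        "intent": pred["intent"],             # 预测意图作为伪标签
        "slots": pred["slots"],               # 预测槽作为伪标签
    }

# 之后将 pseudo_labeled 与原始训练数据、小样本扩充数据、LCQMC 域外数据合并成 b 榜训练集;
# 合并时需重新编号,避免与原训练数据的 NLU 编号冲突(此处从略)。
```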
114 | 115 | ### 2.模型算法部分 116 | 117 | 此次比赛我一共使用了三个模型进行训练,最后的结果 result.json 由三个模型的投票表决产生(投票方式见本节末尾的示意代码)。 118 | 119 | #### 1.JointErnie 模型 120 | 121 | 思路参照:[BERT for Joint Intent Classification and Slot Filling](http://arxiv.org/abs/1902.10909v1) 122 | 123 | 预训练模型没有使用中文版 bert base,而是使用百度的中文版 ernie-1.0 base,三个任务进行联合训练。意图识别和非标准槽填充使用 ernie 模型的输出分别连接一个全连接层进行分类;标准槽填充得到 ernie 的输出后,再将其输入到 CRF 层。ernie 预训练模型与 CRF 层采用不同的学习率:ernie 的学习率是 5e-5,CRF 层的学习率是 5e-2(可通过优化器参数分组实现,见本节末尾的示意代码)。 124 | 125 | #### 2.InteractModel_1 模型 126 | 127 | 思路参照:[A CO-INTERACTIVE TRANSFORMER FOR JOINT SLOT FILLING AND INTENT DETECTION](http://arxiv.org/abs/2010.03880v3) 128 | 129 | 仍然是三个任务联合训练,不同的是,意图识别和标准槽填充部分进行了一层交互;非标准槽填充未与上述两个任务进行交互,而是把预训练模型的输出 pooled_output 直接输入到全连接层中进行分类。 130 | 131 | ***交互层部分:*** 132 | 133 | 首先使用中文版 ernie-1.0 base 预训练模型作为主体部分:对于输入的序列 $\{x_1,x_2,...,x_n\}$($n$ 是 token 的数量),将其输入中文版 ernie-1.0 模型之后得到输出 134 | $$ 135 | H = \{h_1, h_2, ..., h_n\} 136 | $$ 137 | 然后是意图与标准槽填充的交互部分,这个模型中使用了一层交互层: 138 | 139 | **标签注意力层** 140 | 141 | ![image-20210929212611984](image/readme_images/image-20210929212611984.png) 142 | 143 | **协同交互注意力层** 144 | 145 | ![image-20210929212740433](image/readme_images/image-20210929212740433.png) 146 | 147 | **前馈神经网络层** 148 | 149 | ![image-20210929212830068](image/readme_images/image-20210929212830068.png) 150 | 151 | ***解码器层部分*** 152 | 153 | ![image-20210929213002909](image/readme_images/image-20210929213002909.png) 154 | 155 | #### 3.InteractModel_3 模型 156 | 157 | InteractModel_3 的模型结构与上述 InteractModel_1 类似,唯一的区别是 InteractModel_3 的意图识别和标准槽填充部分使用了三层交互。
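下面先给出"ernie 主体与 CRF 层使用不同学习率"的一个简化示意。在 PyTorch 中这通常通过优化器参数分组实现;以下代码假设模型像上文 InteractModel 一样以 self.bert、self.CRF 命名子模块,仅为思路示意,并非仓库 train_jointBert.py 的实际代码:

```python
import torch.optim as optim

def build_optimizer(model, ernie_lr=5e-5, crf_lr=5e-2):
    """按参数分组构建优化器:ernie 主体用较小学习率,CRF 层用较大学习率。
    假设 model 以 self.bert / self.CRF 命名对应子模块(与上文 InteractModel 一致)。"""
    return optim.AdamW([
        {"params": model.bert.parameters(), "lr": ernie_lr},
        {"params": model.CRF.parameters(), "lr": crf_lr},
        # 其余分类层(IntentClassify、SlotClassify 等)此处并入 ernie 学习率分组
        {"params": [p for n, p in model.named_parameters()
                    if not n.startswith(("bert.", "CRF."))], "lr": ernie_lr},
    ])
```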
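三个模型结果的投票表决由仓库中的 integration.py 完成,具体规则以该文件为准。下面给出一种常见做法的简化示意:对每条样本,若至少两个模型给出完全相同的预测则取多数,否则回退到指定模型;其中文件名仅为示意:

```python
import json
from collections import Counter

def vote(paths, fallback=0):
    """对多个模型的结果文件做逐条投票(示意实现):
    若至少两个模型给出完全相同的 {intent, slots},取多数;否则回退到第 fallback 个模型的结果。"""
    results = []
    for path in paths:
        with open(path, 'r', encoding='utf-8') as fp:
            results.append(json.load(fp))
    final = {}
    for nlu_id in results[0]:
        # 序列化为规范化字符串,便于比较各模型对同一条数据的完整预测是否一致
        preds = [json.dumps(r[nlu_id], ensure_ascii=False, sort_keys=True) for r in results]
        most_common, count = Counter(preds).most_common(1)[0]
        final[nlu_id] = json.loads(most_common) if count >= 2 else results[fallback][nlu_id]
    return final

# 用法示意(文件名对应 tmp_result 下各模型后处理后的结果,仅为示意):
# final = vote(['result_bert_post.json', 'result_interact_1_post.json', 'result_interact_3_post.json'])
# with open('result.json', 'w', encoding='utf-8') as fp:
#     json.dump(final, fp, ensure_ascii=False)
```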
158 | **镜像复现说明请查看 ccir/image/README.md** 159 | 160 | ## 3.项目目录结构 161 | ``` 162 | ccir 163 | |-- data —— 数据文件夹 164 | | |-- code —— 包含所有代码文件夹 165 | | | |-- __init__.py 166 | | | |-- model —— 与网络模型相关文件夹 167 | | | | |-- __init__.py 168 | | | | |-- InteractModel_1.py —— InteractModel_1 网络模型 169 | | | | |-- InteractModel_3.py —— InteractModel_3 网络模型 170 | | | | |-- JointBertModel.py —— JointBertModel 网络模型 171 | | | | |-- __pycache__ 172 | | | | | |-- __init__.cpython-38.pyc 173 | | | | | |-- InteractModel_1.cpython-38.pyc 174 | | | | | |-- InteractModel_3.cpython-38.pyc 175 | | | | | |-- JointBertModel.cpython-38.pyc 176 | | | | | `-- torchcrf.cpython-38.pyc 177 | | | | `-- torchcrf.py —— CRF层网络模型 178 | | | |-- predict —— 包含推理代码的文件夹 179 | | | | |-- __init__.py 180 | | | | |-- integration.py —— 负责将三个模型的结果进行投票输出成最后的结果result.json 181 | | | | |-- post_process.py —— 对模型的结果进行后处理的代码 182 | | | | |-- __pycache__ 183 | | | | | |-- __init__.cpython-38.pyc 184 | | | | | |-- post_process.cpython-38.pyc 185 | | | | | |-- test_dataset.cpython-38.pyc 186 | | | | | `-- test_utils.cpython-38.pyc 187 | | | | |-- run_interact1.py —— 对线上训练的InteractModel_1模型进行推理 188 | | | | |-- run_interact3.py —— 对线上训练的InteractModel_3模型进行推理 189 | | | | |-- run_JointBert.py —— 对线上训练的JointBert模型进行推理 190 | | | | |-- run_trained_interact1.py —— 对本地训练的InteractModel_1模型进行推理 191 | | | | |-- run_trained_interact3.py —— 对本地训练的InteractModel_3模型进行推理 192 | | | | |-- run_trained_JointBert.py —— 对本地训练的JointBert模型进行推理 193 | | | | |-- test_dataset.py —— 构建测试集dataset 194 | | | | `-- test_utils.py —— 测试集的工具类函数 195 | | | |-- __pycache__ 196 | | | | `-- __init__.cpython-38.pyc 197 | | | |-- preprocess —— 数据预处理代码 198 | | | | |-- __init__.py 199 | | | | |-- analysis.py —— 分析训练数据,哪些意图包含哪些标签 200 | | | | |-- extend_audio_sample.py —— 用于扩充意图为”Audio-Play“的小样本数据 201 | | | | |-- extend_tv_sample.py —— 用于扩充意图为”TVProgram-Play“的小样本数据 202 | | | | |-- extract_intent_sample.py —— 在训练数据中提取特定意图的数据 203 | | | | |-- generate_intent.py —— 对训练数据和验证集数据提取意图 204 | | | | |-- process_other.py —— 处理域外数据 205 | | | | |-- process_rawdata.py —— 处理原始训练数据和验证集数据 206 | | | | |-- rectify.py —— 对意图为”Alarm-Update“的训练数据进行纠正 207 | | | | |-- slot_sorted.py —— 对槽填充的标注按槽值从大到小进行排序 208 | | | | |-- split_train_dev.py —— 将原始训练数据按8: 2划分 209 | | | `-- scripts 210 | | | |-- build_vocab.py —— 构建词典,加载词汇表 211 | | | |-- config_Interact1.py —— InteractModel_1的配置文件(参数设置) 212 | | | |-- config_Interact3.py —— InteractModel_3的配置文件(参数设置) 213 | | | |-- config_jointBert.py —— JointBert的配置文件(参数设置) 214 | | | |-- dataset.py —— 构建训练集dataset 215 | | | |-- __init__.py 216 | | | |-- __pycache__ 217 | | | | |-- build_vocab.cpython-38.pyc 218 | | | | |-- config_Interact1.cpython-38.pyc 219 | | | | |-- config_Interact3.cpython-38.pyc 220 | | | | |-- config_jointBert.cpython-38.pyc 221 | | | | |-- dataset.cpython-38.pyc 222 | | | | |-- __init__.cpython-38.pyc 223 | | | | `-- utils.cpython-38.pyc 224 | | | |-- train_interact1.py —— InteractModel_1的主函数训练代码 225 | | | |-- train_interact3.py —— InteractModel_3的主函数训练代码 226 | | | |-- train_jointBert.py —— JointBert的主函数训练代码 227 | | | `-- utils.py —— 训练阶段的工具类函数(train、valid等) 228 | | |-- __init__.py 229 | | |-- prediction_result —— 模型推理结果result.json的保存文件夹 230 | | |-- __pycache__ 231 | | | `-- __init__.cpython-38.pyc 232 | | |-- raw_data —— 原始训练集文件夹 233 | | `-- user_data —— 用户数据文件夹 234 | | |-- common_data —— 训练和验证时使用的公共文件夹 235 | | | |-- intent_label.txt —— 包含11类意图的txt文件 236 | | | |-- intent_slot_mapping.json —— 意图与槽标签对应关系的字典 237 | | | |-- region_dic.json —— 用于帮助训练集槽标注的映射字典 238 | | | 
|-- slot_label.txt —— 标准槽标签 239 | | | `-- slot_none_vocab.txt —— 非标准槽标签 240 | | |-- dev_data 241 | | | |-- dev_intent_label.txt —— 验证集数据的意图 242 | | | |-- dev_seq_in.txt —— 验证集数据的原始句子分词输入 243 | | | |-- dev_seq_out.txt —— 验证集数据的序列标注 244 | | | `-- dev_slot_none.txt —— 验证集数据的非标准槽填充分类标签 245 | | |-- output_model —— 保存训练过程中间的模型文件夹 246 | | | |-- InteractModel_1 —— InteractModel_1文件夹(线上训练的模型保存在该文件夹下) 247 | | | | `-- trained_model —— 保存本地训练好的模型的文件夹 248 | | | | `-- Interact1_model_best.pth.tar —— 本地训练好的InteractModel_1模型 249 | | | |-- InteractModel_3 —— InteractModel_3文件夹(线上训练的模型保存在该文件夹下) 250 | | | | `-- trained_model 251 | | | | `-- Interact3_model_best.pth.tar —— 本地训练好的InteractModel_3模型 252 | | | `-- JointBert —— JointBert文件夹(线上训练的模型保存在该文件夹下) 253 | | | `-- trained_model 254 | | | `-- bert_model_best.pth.tar —— 本地训练好的JointBert模型 255 | | |-- pretrained_model —— 预训练模型文件夹 256 | | | `-- ernie —— ernie 257 | | | |-- config.json 258 | | | |-- pytorch_model.bin 259 | | | |-- special_tokens_map.json 260 | | | |-- tokenizer_config.json 261 | | | `-- vocab.txt 262 | | |-- test_data —— 测试集文件夹 263 | | | |-- test_B_final_text.json —— 测试集原始json文件 264 | | | `-- test_seq_in_B.txt —— 经过Tokenizer分词后的测试集文件 265 | | |-- tmp_result —— 保存单个模型输出结果 266 | | `-- train_data —— 训练集文件夹 267 | | |-- train_intent_label.txt —— 训练集数据的意图 268 | | |-- train_seq_in.txt —— 验证集数据的原始句子分词输入 269 | | |-- train_seq_out.txt —— 验证集数据的序列标注 270 | | `-- train_slot_none.txt —— 验证集数据的非标准槽填充分类标签 271 | |-- image —— 镜像相关文件夹 272 | | |-- readme_images —— 存放赛题解决方案和算法介绍的README文档的图片 273 | | |-- ccir-image.tar —— 镜像文件 274 | | |-- README.md —— 关于复现具体操作的REDAME文档 275 | | |-- run_infer.sh —— 使用本地训练的模型进行线上推理的脚本 276 | | `-- run.sh —— 进行线上训练和推理的脚本 277 | `-- README.md —— 赛题整体的解决方案和算法介绍的README文档 278 | ``` 279 | -------------------------------------------------------------------------------- /data/code/model/torchcrf.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.7.2' 2 | 3 | from typing import List, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class CRF(nn.Module): 10 | """Conditional random field. 11 | 12 | This module implements a conditional random field [LMP01]_. The forward computation 13 | of this class computes the log likelihood of the given sequence of tags and 14 | emission score tensor. This class also has `~CRF.decode` method which finds 15 | the best tag sequence given an emission score tensor using `Viterbi algorithm`_. 16 | 17 | Args: 18 | num_tags: Number of tags. 19 | batch_first: Whether the first dimension corresponds to the size of a minibatch. 20 | 21 | Attributes: 22 | start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size 23 | ``(num_tags,)``. 24 | end_transitions (`~torch.nn.Parameter`): End transition score tensor of size 25 | ``(num_tags,)``. 26 | transitions (`~torch.nn.Parameter`): Transition score tensor of size 27 | ``(num_tags, num_tags)``. 28 | 29 | 30 | .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). 31 | "Conditional random fields: Probabilistic models for segmenting and 32 | labeling sequence data". *Proc. 18th International Conf. on Machine 33 | Learning*. Morgan Kaufmann. pp. 282–289. 34 | 35 | .. 
_Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm 36 | """ 37 | 38 | def __init__(self, num_tags: int, batch_first: bool = False) -> None: 39 | if num_tags <= 0: 40 | raise ValueError(f'invalid number of tags: {num_tags}') 41 | super().__init__() 42 | self.num_tags = num_tags 43 | self.batch_first = batch_first 44 | self.start_transitions = nn.Parameter(torch.empty(num_tags)) 45 | self.end_transitions = nn.Parameter(torch.empty(num_tags)) 46 | self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) 47 | 48 | self.reset_parameters() 49 | 50 | def reset_parameters(self) -> None: 51 | """Initialize the transition parameters. 52 | 53 | The parameters will be initialized randomly from a uniform distribution 54 | between -0.1 and 0.1. 55 | """ 56 | nn.init.uniform_(self.start_transitions, -0.1, 0.1) 57 | nn.init.uniform_(self.end_transitions, -0.1, 0.1) 58 | nn.init.uniform_(self.transitions, -0.1, 0.1) 59 | 60 | def __repr__(self) -> str: 61 | return f'{self.__class__.__name__}(num_tags={self.num_tags})' 62 | 63 | def forward( 64 | self, 65 | emissions: torch.Tensor, 66 | tags: torch.LongTensor, 67 | mask: Optional[torch.ByteTensor] = None, 68 | reduction: str = 'sum', 69 | ) -> torch.Tensor: 70 | """Compute the conditional log likelihood of a sequence of tags given emission scores. 71 | 72 | Args: 73 | emissions (`~torch.Tensor`): Emission score tensor of size 74 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 75 | ``(batch_size, seq_length, num_tags)`` otherwise. 76 | tags (`~torch.LongTensor`): Sequence of tags tensor of size 77 | ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, 78 | ``(batch_size, seq_length)`` otherwise. 79 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 80 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 81 | reduction: Specifies the reduction to apply to the output: 82 | ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. 83 | ``sum``: the output will be summed over batches. ``mean``: the output will be 84 | averaged over batches. ``token_mean``: the output will be averaged over tokens. 85 | 86 | Returns: 87 | `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if 88 | reduction is ``none``, ``()`` otherwise. 89 | """ 90 | self._validate(emissions, tags=tags, mask=mask) 91 | if reduction not in ('none', 'sum', 'mean', 'token_mean'): 92 | raise ValueError(f'invalid reduction: {reduction}') 93 | if mask is None: 94 | mask = torch.ones_like(tags, dtype=torch.uint8) 95 | 96 | if self.batch_first: 97 | emissions = emissions.transpose(0, 1) 98 | tags = tags.transpose(0, 1) 99 | mask = mask.transpose(0, 1) 100 | 101 | # shape: (batch_size,) 102 | numerator = self._compute_score(emissions, tags, mask) 103 | # shape: (batch_size,) 104 | denominator = self._compute_normalizer(emissions, mask) 105 | # shape: (batch_size,) 106 | llh = numerator - denominator 107 | 108 | if reduction == 'none': 109 | return llh 110 | if reduction == 'sum': 111 | return llh.sum() 112 | if reduction == 'mean': 113 | return llh.mean() 114 | assert reduction == 'token_mean' 115 | return llh.sum() / mask.float().sum() 116 | 117 | def decode(self, emissions: torch.Tensor, 118 | mask: Optional[torch.ByteTensor] = None) -> List[List[int]]: 119 | """Find the most likely tag sequence using Viterbi algorithm. 
120 | 121 | Args: 122 | emissions (`~torch.Tensor`): Emission score tensor of size 123 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 124 | ``(batch_size, seq_length, num_tags)`` otherwise. 125 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 126 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 127 | 128 | Returns: 129 | List of list containing the best tag sequence for each batch. 130 | """ 131 | self._validate(emissions, mask=mask) 132 | if mask is None: 133 | mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8) 134 | 135 | if self.batch_first: 136 | emissions = emissions.transpose(0, 1) 137 | mask = mask.transpose(0, 1) 138 | 139 | return self._viterbi_decode(emissions, mask) 140 | 141 | def _validate( 142 | self, 143 | emissions: torch.Tensor, 144 | tags: Optional[torch.LongTensor] = None, 145 | mask: Optional[torch.ByteTensor] = None) -> None: 146 | if emissions.dim() != 3: 147 | raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}') 148 | if emissions.size(2) != self.num_tags: 149 | raise ValueError( 150 | f'expected last dimension of emissions is {self.num_tags}, ' 151 | f'got {emissions.size(2)}') 152 | 153 | if tags is not None: 154 | if emissions.shape[:2] != tags.shape: 155 | raise ValueError( 156 | 'the first two dimensions of emissions and tags must match, ' 157 | f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}') 158 | 159 | if mask is not None: 160 | if emissions.shape[:2] != mask.shape: 161 | raise ValueError( 162 | 'the first two dimensions of emissions and mask must match, ' 163 | f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}') 164 | no_empty_seq = not self.batch_first and mask[0].all() 165 | no_empty_seq_bf = self.batch_first and mask[:, 0].all() 166 | if not no_empty_seq and not no_empty_seq_bf: 167 | raise ValueError('mask of the first timestep must all be on') 168 | 169 | def _compute_score( 170 | self, emissions: torch.Tensor, tags: torch.LongTensor, 171 | mask: torch.ByteTensor) -> torch.Tensor: 172 | # emissions: (seq_length, batch_size, num_tags) 173 | # tags: (seq_length, batch_size) 174 | # mask: (seq_length, batch_size) 175 | assert emissions.dim() == 3 and tags.dim() == 2 176 | assert emissions.shape[:2] == tags.shape 177 | assert emissions.size(2) == self.num_tags 178 | assert mask.shape == tags.shape 179 | assert mask[0].all() 180 | 181 | seq_length, batch_size = tags.shape 182 | mask = mask.float() 183 | 184 | # Start transition score and first emission 185 | # shape: (batch_size,) 186 | score = self.start_transitions[tags[0]] 187 | score += emissions[0, torch.arange(batch_size), tags[0]] 188 | 189 | for i in range(1, seq_length): 190 | # Transition score to next tag, only added if next timestep is valid (mask == 1) 191 | # shape: (batch_size,) 192 | score += self.transitions[tags[i - 1], tags[i]] * mask[i] 193 | 194 | # Emission score for next tag, only added if next timestep is valid (mask == 1) 195 | # shape: (batch_size,) 196 | score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] 197 | 198 | # End transition score 199 | # shape: (batch_size,) 200 | seq_ends = mask.long().sum(dim=0) - 1 201 | # shape: (batch_size,) 202 | last_tags = tags[seq_ends, torch.arange(batch_size)] 203 | # shape: (batch_size,) 204 | score += self.end_transitions[last_tags] 205 | 206 | return score 207 | 208 | def _compute_normalizer( 209 | self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor: 210 | # 
emissions: (seq_length, batch_size, num_tags) 211 | # mask: (seq_length, batch_size) 212 | assert emissions.dim() == 3 and mask.dim() == 2 213 | assert emissions.shape[:2] == mask.shape 214 | assert emissions.size(2) == self.num_tags 215 | assert mask[0].all() 216 | 217 | seq_length = emissions.size(0) 218 | 219 | # Start transition score and first emission; score has size of 220 | # (batch_size, num_tags) where for each batch, the j-th column stores 221 | # the score that the first timestep has tag j 222 | # shape: (batch_size, num_tags) 223 | score = self.start_transitions + emissions[0] 224 | 225 | for i in range(1, seq_length): 226 | # Broadcast score for every possible next tag 227 | # shape: (batch_size, num_tags, 1) 228 | broadcast_score = score.unsqueeze(2) 229 | 230 | # Broadcast emission score for every possible current tag 231 | # shape: (batch_size, 1, num_tags) 232 | broadcast_emissions = emissions[i].unsqueeze(1) 233 | 234 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where 235 | # for each sample, entry at row i and column j stores the sum of scores of all 236 | # possible tag sequences so far that end with transitioning from tag i to tag j 237 | # and emitting 238 | # shape: (batch_size, num_tags, num_tags) 239 | next_score = broadcast_score + self.transitions + broadcast_emissions 240 | 241 | # Sum over all possible current tags, but we're in score space, so a sum 242 | # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of 243 | # all possible tag sequences so far, that end in tag i 244 | # shape: (batch_size, num_tags) 245 | next_score = torch.logsumexp(next_score, dim=1) 246 | 247 | # Set score to the next score if this timestep is valid (mask == 1) 248 | # shape: (batch_size, num_tags) 249 | score = torch.where(mask[i].unsqueeze(1), next_score, score) 250 | 251 | # End transition score 252 | # shape: (batch_size, num_tags) 253 | score += self.end_transitions 254 | 255 | # Sum (log-sum-exp) over all possible tags 256 | # shape: (batch_size,) 257 | return torch.logsumexp(score, dim=1) 258 | 259 | def _viterbi_decode(self, emissions: torch.FloatTensor, 260 | mask: torch.ByteTensor) -> List[List[int]]: 261 | # emissions: (seq_length, batch_size, num_tags) 262 | # mask: (seq_length, batch_size) 263 | assert emissions.dim() == 3 and mask.dim() == 2 264 | assert emissions.shape[:2] == mask.shape 265 | assert emissions.size(2) == self.num_tags 266 | assert mask[0].all() 267 | 268 | seq_length, batch_size = mask.shape 269 | 270 | # Start transition and first emission 271 | # shape: (batch_size, num_tags) 272 | score = self.start_transitions + emissions[0] 273 | history = [] 274 | 275 | # score is a tensor of size (batch_size, num_tags) where for every batch, 276 | # value at column j stores the score of the best tag sequence so far that ends 277 | # with tag j 278 | # history saves where the best tags candidate transitioned from; this is used 279 | # when we trace back the best tag sequence 280 | 281 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence 282 | # for every possible next tag 283 | for i in range(1, seq_length): 284 | # Broadcast viterbi score for every possible next tag 285 | # shape: (batch_size, num_tags, 1) 286 | broadcast_score = score.unsqueeze(2) 287 | 288 | # Broadcast emission score for every possible current tag 289 | # shape: (batch_size, 1, num_tags) 290 | broadcast_emission = emissions[i].unsqueeze(1) 291 | 292 | # Compute the score tensor of size (batch_size, num_tags, 
num_tags) where 293 | # for each sample, entry at row i and column j stores the score of the best 294 | # tag sequence so far that ends with transitioning from tag i to tag j and emitting 295 | # shape: (batch_size, num_tags, num_tags) 296 | next_score = broadcast_score + self.transitions + broadcast_emission 297 | 298 | # Find the maximum score over all possible current tag 299 | # shape: (batch_size, num_tags) 300 | next_score, indices = next_score.max(dim=1) 301 | 302 | # Set score to the next score if this timestep is valid (mask == 1) 303 | # and save the index that produces the next score 304 | # shape: (batch_size, num_tags) 305 | score = torch.where(mask[i].unsqueeze(1), next_score, score) 306 | history.append(indices) 307 | 308 | # End transition score 309 | # shape: (batch_size, num_tags) 310 | score += self.end_transitions 311 | 312 | # Now, compute the best path for each sample 313 | 314 | # shape: (batch_size,) 315 | seq_ends = mask.long().sum(dim=0) - 1 316 | best_tags_list = [] 317 | 318 | for idx in range(batch_size): 319 | # Find the tag which maximizes the score at the last timestep; this is our best tag 320 | # for the last timestep 321 | _, best_last_tag = score[idx].max(dim=0) 322 | best_tags = [best_last_tag.item()] 323 | 324 | # We trace back where the best last tag comes from, append that to our best tag 325 | # sequence, and trace it back again, and so on 326 | for hist in reversed(history[:seq_ends[idx]]): 327 | best_last_tag = hist[idx][best_tags[-1]] 328 | best_tags.append(best_last_tag.item()) 329 | 330 | # Reverse the order because we start from the last timestep 331 | best_tags.reverse() 332 | best_tags_list.append(best_tags) 333 | 334 | return best_tags_list 335 | -------------------------------------------------------------------------------- /data/code/predict/run_JointBert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/9 15:22 4 | # @Author : JJkinging 5 | # @File : run_predict.py 6 | import warnings 7 | 8 | import torch 9 | import json 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from data.code.scripts.config_jointBert import Config 13 | from data.code.predict.test_utils import load_reverse_vocab, load_vocab, collate_to_max_length 14 | from data.code.predict.test_dataset import CCFDataset 15 | from data.code.model.JointBertModel import JointBertModel 16 | from transformers import AutoTokenizer 17 | from data.code.predict.post_process import process 18 | 19 | 20 | def run_test(model, dataloader, slot_dict): 21 | 22 | model.eval() 23 | device = model.device 24 | intent_probs_list = [] 25 | 26 | intent_pred = [] 27 | slotNone_pred_output = [] 28 | slot_pre_output = [] 29 | input_ids_list = [] 30 | 31 | id2slot = {value: key for key, value in slot_dict.items()} 32 | with torch.no_grad(): 33 | tqdm_batch_iterator = tqdm(dataloader) 34 | for _, batch in enumerate(tqdm_batch_iterator): 35 | input_ids, input_mask = batch 36 | 37 | input_ids = input_ids.to(device) 38 | input_mask = input_mask.byte().to(device) 39 | real_length = torch.sum(input_mask, dim=1) 40 | 41 | intent_logits, slot_none_logits, slot_logits = model(input_ids, input_mask) 42 | 43 | # intent 44 | intent_probs = torch.softmax(intent_logits, dim=-1) 45 | predict_labels = torch.argmax(intent_probs, dim=-1) 46 | predict_labels = predict_labels.cpu().numpy().tolist() 47 | intent_pred.extend(predict_labels) 48 | 49 | # slot_none 50 | slot_none_probs = 
torch.sigmoid(slot_none_logits) 51 | slot_none_probs = slot_none_probs > 0.5 52 | slot_none_probs = slot_none_probs.cpu().numpy() 53 | slot_none_probs = slot_none_probs.astype(int) 54 | slot_none_probs = slot_none_probs.tolist() 55 | 56 | for slot_none_id in slot_none_probs: # 遍历这个batch的slot_none_id 57 | tmp = [] 58 | if 1 in slot_none_id: 59 | for idx, none_id in enumerate(slot_none_id): 60 | if none_id == 1: 61 | tmp.append(idx) 62 | else: 63 | tmp.append(28) # 28在slot_none_vocab中表示none 64 | slotNone_pred_output.append(tmp) # [[4, 5], [], [], ...] 65 | 66 | # slot 67 | out_path = model.slot_predict(slot_logits, input_mask, id2slot) 68 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path] # 不去掉'[START]'和'[EOS]'标记 69 | slot_pre_output.extend(out_path) 70 | 71 | # input_ids 72 | input_ids = input_ids.cpu().numpy().tolist() 73 | input_ids_list.extend(input_ids) 74 | 75 | return intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list 76 | 77 | 78 | def get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start, end, key): 79 | value = tokenizer.decode(input_ids[start:end + 1]) 80 | value = ''.join(value.split(' ')).replace('[UNK]', '$') 81 | s = -1 82 | e = -1 83 | ori_sen_lower = ori_sen.lower() # 把其中的英文字母全变为小写 84 | is_unk = False 85 | if len(value) > 0: 86 | if value[0] == '$': # 如果第一个就是[UNK] 87 | value = value[1:] # 查找时先忽略,最后再把开始位置往前移一位 88 | is_unk = True 89 | i = 0 90 | j = 0 91 | while i < len(ori_sen_lower) and j < len(value): 92 | if ori_sen_lower[i] == value[j]: 93 | if s == -1: 94 | s = i 95 | e = i 96 | else: 97 | e += 1 98 | i += 1 99 | j += 1 100 | elif ori_sen_lower[i] == ' ': 101 | e += 1 102 | i += 1 103 | elif value[j] == '$': 104 | e += 1 105 | i += 1 106 | j += 1 107 | elif ori_sen_lower[i] != value[j]: 108 | i -= j - 1 109 | j = 0 110 | s = -1 111 | e = -1 112 | if is_unk: 113 | s = s-1 114 | final_value = ori_sen[s:e + 1] 115 | if key in tmp_dict.keys(): 116 | if tmp_dict[key] != '' and not isinstance(tmp_dict[key], list): # 如果该key已经有值,且不为list 117 | tmp_list = [tmp_dict[key], final_value] 118 | tmp_dict[key] = tmp_list 119 | elif tmp_dict[key] != '' and isinstance(tmp_dict[key], list): 120 | tmp_dict[key].append(final_value) 121 | else: 122 | tmp_dict[key] = final_value 123 | return tmp_dict 124 | 125 | 126 | def is_nest(slot_pre): 127 | ''' 128 | 判断该条数据的 slot filling 是否存在嵌套 129 | :param pred_slot: 例如:['O', 'O', 'O', 'B-age', 'I-age', 'O', 'O', 'O'] 130 | :return: 存在返回True, 否则返回False 131 | ''' 132 | for j, cur_tag in enumerate(slot_pre): 133 | if cur_tag.startswith('I-'): 134 | cur_tag_name = cur_tag[2:] 135 | if j < len(slot_pre)-1: 136 | if slot_pre[j+1].startswith('I-'): # 如果下一个也是'I-'开头 137 | post_tag_name = slot_pre[j+1][2:] 138 | if cur_tag_name != post_tag_name: # 但二者不等 139 | return True 140 | return False 141 | 142 | 143 | def process_nest(tokenizer, input_ids, slot_pre, ori_sen): # 处理嵌套 144 | ''' 145 | :param input_ids: 原句的id形式(wordpiece过) 146 | :param slot_pre: 该句的预测slot 147 | :param ori_sen: 原句(字符) 148 | :return: 149 | ''' 150 | start_outer = -1 151 | end_outer = -1 152 | tmp_dict = {} 153 | pre_end = -1 154 | for i, tag in enumerate(slot_pre): 155 | if i <= pre_end: 156 | continue 157 | if tag.startswith('B-') and start_outer == -1: # 第一个'B-' 158 | start_outer = i # 临时起始位置 159 | end_outer = i 160 | key_outer = tag[2:] 161 | 162 | elif tag.startswith('I-') and tag[2:] == slot_pre[i-1][2:]: # 'I-' 且标签与前一个slot相同 163 | end_outer = i 164 | if i == len(slot_pre)-1: # 到了最后 165 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, 
start_outer, end_outer, key_outer) 166 | 167 | elif tag.startswith('O') and start_outer == -1: 168 | continue 169 | elif tag.startswith('O') and start_outer != -1: 170 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer) 171 | start_outer = -1 172 | end_outer = -1 173 | # 第一种嵌套:B-region I-region I-name I-name 174 | elif tag.startswith('I-') and tag[2:] != slot_pre[i-1][2:]: # 'I-' 且标签与前一个slot不同 175 | start_inner = start_outer 176 | end_inner = end_outer 177 | key_inner = key_outer 178 | # 处理内层 179 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_inner, end_inner, key_inner) 180 | # start_outer不变 181 | end_outer = i # end_outer为当前slot下标 182 | key_outer = slot_pre[i][2:] # 需修改key_outer 183 | # 第二种嵌套:B-name I-name B-datetime_date I-datetime_date I-name B-region I-region I-name I-name 184 | elif tag.startswith('B-'): 185 | flag = False 186 | pre_start = start_outer 187 | pre_end = end_outer 188 | pre_key = key_outer 189 | start_outer = i 190 | end_outer = i 191 | key_outer = slot_pre[i][2:] 192 | # ************************ 193 | for j in range(i, len(slot_pre)): 194 | if slot_pre[j] != 'O' and slot_pre[j][2:] == pre_key: 195 | flag = True 196 | pre_end = j 197 | slot_pre[j] = 'O' # 置为'O' 198 | # ************************* 199 | if flag: # 上一个slot是嵌套 200 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) #先处理外层 201 | # 以后遇到完整slot的就加入(假设只有二级嵌套) 202 | for k in range(i+1, pre_end+1): 203 | if slot_pre[k] != 'O': 204 | if slot_pre[k].startswith('I-'): 205 | end_outer = k 206 | elif slot_pre[k].startswith('B-') and start_outer == -1: 207 | start_outer = k 208 | end_outer = k 209 | key_outer = slot_pre[k][2:] 210 | elif slot_pre[k].startswith('B-') and start_outer != -1: 211 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, 212 | key_outer) # 处理内层 213 | start_outer = k 214 | end_outer = k 215 | key_outer = slot_pre[k][2:] 216 | elif slot_pre[k] == 'O' and start_outer != -1: 217 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, 218 | key_outer) # 处理内层 219 | start_outer = -1 220 | end_outer = -1 221 | key_outer = None 222 | elif slot_pre[k] == 'O' and start_outer == -1: 223 | continue 224 | else: # 上一个不是嵌套 225 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) #先处理上一个非嵌套 226 | 227 | return tmp_dict 228 | 229 | 230 | def find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list): 231 | fp = open('nest.txt', 'w+', encoding='utf-8') 232 | slot_list = [] 233 | count = 0 234 | # ddd = 0 235 | for i, slot_ids in enumerate(slot_pre_output): # 遍历每条数据 236 | # if ddd < 1111: 237 | # ddd += 1 238 | # continue 239 | tmp_dict = {} 240 | start = 0 241 | end = 0 242 | if is_nest(slot_ids[1:-1]): # 如果确实存在嵌套行为 243 | tmp_dict = process_nest(tokenizer, input_ids_list[i][1:-1], slot_pre_output[i][1:-1], ori_sen_list[i]) 244 | slot_list.append(tmp_dict) 245 | fp.write(str(ori_sen_list[i])+'\t') 246 | fp.write(str(tmp_dict)+'\n') 247 | fp.write(' '.join(slot_pre_output[i][1:-1])+'\n') 248 | count += 1 249 | else: 250 | for j, slot in enumerate(slot_ids): # 遍历每个字 251 | if slot != 'O' and slot != '[START]' and slot != '[EOS]': 252 | if slot.startswith('B-') and start == 0: 253 | start = j # 槽值起始位置 254 | end = j 255 | key = slot[2:] 256 | if j == len(slot_ids)-2: # # 如果'B-'是最后一个字符(即倒数第二个,真倒数第一是EOS) 257 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key) 258 | 
break 259 | elif slot.startswith('B-') and start != 0: # 说明上一个是槽 260 | # tokenizer, tmp_dict, input_ids, ori_sen, start, end, key 261 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key) # 将槽值写入 262 | start = j 263 | end = j 264 | key = slot[2:] 265 | else: # 'I-'开头 266 | end += 1 267 | if j == len(slot_ids)-2: # 如果'I-'是最后一个字符(即倒数第二个,真倒数第一是EOS) 268 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key) 269 | break 270 | else: 271 | if end == 0: # 说明没找到槽 272 | continue 273 | else: 274 | if slot != '[EOS]': 275 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key) 276 | start = 0 277 | end = 0 278 | slot_list.append(tmp_dict) 279 | print(str(count)+'——JointBert 推理完成!') 280 | fp.close() 281 | return slot_list 282 | 283 | 284 | if __name__ == "__main__": 285 | warnings.filterwarnings("ignore") 286 | config = Config() 287 | device = torch.device(config.cuda if torch.cuda.is_available() else "cpu") 288 | with open('../../user_data/common_data/region_dic.json', 'r', encoding='utf-8') as fp: 289 | region_dict = json.load(fp) 290 | tokenizer = AutoTokenizer.from_pretrained('../../user_data/pretrained_model/ernie') 291 | print('JointBert 开始推理!') 292 | vocab = load_vocab('../../user_data/pretrained_model/ernie/vocab.txt') 293 | id2intent = load_reverse_vocab('../../user_data/common_data/intent_label.txt') 294 | id2slotNone = load_reverse_vocab('../../user_data/common_data/slot_none_vocab.txt') 295 | id2slot = load_reverse_vocab('../../user_data/common_data/slot_label.txt') 296 | 297 | intent_dict = load_vocab('../../user_data/common_data/intent_label.txt') 298 | slot_none_dict = load_vocab('../../user_data/common_data/slot_none_vocab.txt') 299 | slot_dict = load_vocab('../../user_data/common_data/slot_label.txt') 300 | 301 | intent_tagset_size = len(intent_dict) 302 | slot_none_tag_size = len(slot_none_dict) 303 | slot_tag_size = len(slot_dict) 304 | test_filename = '../../user_data/test_data/test_seq_in_B.txt' 305 | test_dataset = CCFDataset(test_filename, vocab, intent_dict, slot_none_dict, slot_dict, config.max_length) 306 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=config.batch_size, 307 | collate_fn=collate_to_max_length) 308 | 309 | model = JointBertModel('../../user_data/pretrained_model/ernie', 310 | config.bert_hidden_size, 311 | intent_tagset_size, 312 | slot_none_tag_size, 313 | slot_tag_size, 314 | device).to(device) 315 | checkpoint = torch.load('../../user_data/output_model/JointBert/bert_model_best.pth.tar', map_location='cpu') 316 | model.load_state_dict(checkpoint["model"]) 317 | 318 | # -------------------- Testing ------------------- # 319 | print("\n", 320 | 20 * "=", 321 | "Test model on device: {}".format(device), 322 | 20 * "=") 323 | 324 | intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list = run_test(model, test_loader, slot_dict) 325 | 326 | ori_sen_list = [] 327 | with open('../../user_data/test_data/test_B_final_text.json', 'r', encoding='utf-8') as fp: 328 | raw_data = json.load(fp) 329 | for filename, single_data in raw_data.items(): 330 | text = single_data['text'] 331 | ori_sen_list.append(text) 332 | slot_list = find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list) 333 | 334 | # 用region_dict对slot_list中需要替换的词进行替换 (这一步不能忽略!!!!) 
335 | for i, slot_dict in enumerate(slot_list): 336 | for slot, slot_value in slot_dict.items(): 337 | for region_key, region_list in region_dict.items(): 338 | if slot_value in region_list: 339 | slot_list[i][slot] = region_key 340 | res = {} 341 | for i in range(len(intent_pred)): 342 | big_tmp = {} 343 | slot_tmp_dict = {} 344 | # intent 345 | intent = id2intent[intent_pred[i]] 346 | # slot_none 347 | for slot_none_id in slotNone_pred_output[i]: 348 | if 0 <= slot_none_id <= 9: 349 | slot_tmp_dict['command'] = id2slotNone[slot_none_id] 350 | elif 10 <= slot_none_id <= 18: 351 | slot_tmp_dict['index'] = id2slotNone[slot_none_id] 352 | elif 19 <= slot_none_id <= 22: 353 | slot_tmp_dict['play_mode'] = id2slotNone[slot_none_id] 354 | elif 23 <= slot_none_id <= 27: 355 | slot_tmp_dict['query_type'] = id2slotNone[slot_none_id] 356 | # slot 357 | slot_tmp_dict.update(slot_list[i]) 358 | 359 | big_tmp['intent'] = intent 360 | big_tmp['slots'] = slot_tmp_dict 361 | 362 | length = len(str(i)) 363 | o_num = 5 - length 364 | index = 'NLU' + '0'*o_num + str(i) 365 | res[index] = big_tmp 366 | 367 | with open('../../user_data/tmp_result/result_bert.json', 'w', encoding='utf-8') as fp: 368 | json.dump(res, fp, ensure_ascii=False) 369 | 370 | source_path = '../../user_data/tmp_result/result_bert.json' 371 | target_path = '../../user_data/tmp_result/result_bert_post.json' 372 | process(source_path, target_path) 373 | -------------------------------------------------------------------------------- /data/code/predict/run_interact1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2021/9/9 15:22 4 | # @Author : JJkinging 5 | # @File : run_predict.py 6 | import warnings 7 | 8 | import torch 9 | import json 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from data.code.scripts.config_Interact1 import Config 13 | from data.code.predict.test_utils import load_reverse_vocab, load_vocab, collate_to_max_length 14 | from data.code.predict.test_dataset import CCFDataset 15 | from data.code.model.InteractModel_1 import InteractModel 16 | from transformers import AutoTokenizer 17 | from data.code.predict.post_process import process 18 | 19 | 20 | def run_test(model, dataloader, slot_dict): 21 | 22 | model.eval() 23 | device = model.device 24 | intent_probs_list = [] 25 | 26 | intent_pred = [] 27 | slotNone_pred_output = [] 28 | slot_pre_output = [] 29 | input_ids_list = [] 30 | 31 | id2slot = {value: key for key, value in slot_dict.items()} 32 | with torch.no_grad(): 33 | tqdm_batch_iterator = tqdm(dataloader) 34 | for _, batch in enumerate(tqdm_batch_iterator): 35 | input_ids, input_mask = batch 36 | 37 | input_ids = input_ids.to(device) 38 | input_mask = input_mask.byte().to(device) 39 | real_length = torch.sum(input_mask, dim=1) 40 | 41 | intent_logits, slot_none_logits, slot_logits = model(input_ids, input_mask) 42 | 43 | # intent 44 | intent_probs = torch.softmax(intent_logits, dim=-1) 45 | predict_labels = torch.argmax(intent_probs, dim=-1) 46 | predict_labels = predict_labels.cpu().numpy().tolist() 47 | intent_pred.extend(predict_labels) 48 | 49 | # slot_none 50 | slot_none_probs = torch.sigmoid(slot_none_logits) 51 | slot_none_probs = slot_none_probs > 0.5 52 | slot_none_probs = slot_none_probs.cpu().numpy() 53 | slot_none_probs = slot_none_probs.astype(int) 54 | slot_none_probs = slot_none_probs.tolist() 55 | 56 | for slot_none_id in slot_none_probs: # 遍历这个batch的slot_none_id 57 
| tmp = [] 58 | if 1 in slot_none_id: 59 | for idx, none_id in enumerate(slot_none_id): 60 | if none_id == 1: 61 | tmp.append(idx) 62 | else: 63 | tmp.append(28) # 28在slot_none_vocab中表示none 64 | slotNone_pred_output.append(tmp) # [[4, 5], [], [], ...] 65 | 66 | # slot 67 | out_path = model.slot_predict(slot_logits, input_mask, id2slot) 68 | out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path] # 不去掉'[START]'和'[EOS]'标记 69 | slot_pre_output.extend(out_path) 70 | 71 | # input_ids 72 | input_ids = input_ids.cpu().numpy().tolist() 73 | input_ids_list.extend(input_ids) 74 | 75 | return intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list 76 | 77 | 78 | def get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start, end, key): 79 | value = tokenizer.decode(input_ids[start:end + 1]) 80 | value = ''.join(value.split(' ')).replace('[UNK]', '$') 81 | s = -1 82 | e = -1 83 | ori_sen_lower = ori_sen.lower() # 把其中的英文字母全变为小写 84 | is_unk = False 85 | if len(value) > 0: 86 | if value[0] == '$': # 如果第一个就是[UNK] 87 | value = value[1:] # 查找时先忽略,最后再把开始位置往前移一位 88 | is_unk = True 89 | i = 0 90 | j = 0 91 | while i < len(ori_sen_lower) and j < len(value): 92 | if ori_sen_lower[i] == value[j]: 93 | if s == -1: 94 | s = i 95 | e = i 96 | else: 97 | e += 1 98 | i += 1 99 | j += 1 100 | elif ori_sen_lower[i] == ' ': 101 | e += 1 102 | i += 1 103 | elif value[j] == '$': 104 | e += 1 105 | i += 1 106 | j += 1 107 | elif ori_sen_lower[i] != value[j]: 108 | i -= j - 1 109 | j = 0 110 | s = -1 111 | e = -1 112 | if is_unk: 113 | s = s-1 114 | final_value = ori_sen[s:e + 1] 115 | if key in tmp_dict.keys(): 116 | if tmp_dict[key] != '' and not isinstance(tmp_dict[key], list): # 如果该key已经有值,且不为list 117 | tmp_list = [tmp_dict[key], final_value] 118 | tmp_dict[key] = tmp_list 119 | elif tmp_dict[key] != '' and isinstance(tmp_dict[key], list): 120 | tmp_dict[key].append(final_value) 121 | else: 122 | tmp_dict[key] = final_value 123 | return tmp_dict 124 | 125 | 126 | def is_nest(slot_pre): 127 | ''' 128 | 判断该条数据的 slot filling 是否存在嵌套 129 | :param pred_slot: 例如:['O', 'O', 'O', 'B-age', 'I-age', 'O', 'O', 'O'] 130 | :return: 存在返回True, 否则返回False 131 | ''' 132 | for j, cur_tag in enumerate(slot_pre): 133 | if cur_tag.startswith('I-'): 134 | cur_tag_name = cur_tag[2:] 135 | if j < len(slot_pre)-1: 136 | if slot_pre[j+1].startswith('I-'): # 如果下一个也是'I-'开头 137 | post_tag_name = slot_pre[j+1][2:] 138 | if cur_tag_name != post_tag_name: # 但二者不等 139 | return True 140 | return False 141 | 142 | 143 | def process_nest(tokenizer, input_ids, slot_pre, ori_sen): # 处理嵌套 144 | ''' 145 | :param input_ids: 原句的id形式(wordpiece过) 146 | :param slot_pre: 该句的预测slot 147 | :param ori_sen: 原句(字符) 148 | :return: 149 | ''' 150 | start_outer = -1 151 | end_outer = -1 152 | tmp_dict = {} 153 | pre_end = -1 154 | for i, tag in enumerate(slot_pre): 155 | if i <= pre_end: 156 | continue 157 | if tag.startswith('B-') and start_outer == -1: # 第一个'B-' 158 | start_outer = i # 临时起始位置 159 | end_outer = i 160 | key_outer = tag[2:] 161 | 162 | elif tag.startswith('I-') and tag[2:] == slot_pre[i-1][2:]: # 'I-' 且标签与前一个slot相同 163 | end_outer = i 164 | if i == len(slot_pre)-1: # 到了最后 165 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer) 166 | 167 | elif tag.startswith('O') and start_outer == -1: 168 | continue 169 | elif tag.startswith('O') and start_outer != -1: 170 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer) 171 | start_outer = -1 172 | end_outer 
= -1 173 | # 第一种嵌套:B-region I-region I-name I-name 174 | elif tag.startswith('I-') and tag[2:] != slot_pre[i-1][2:]: # 'I-' 且标签与前一个slot不同 175 | start_inner = start_outer 176 | end_inner = end_outer 177 | key_inner = key_outer 178 | # 处理内层 179 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_inner, end_inner, key_inner) 180 | # start_outer不变 181 | end_outer = i # end_outer为当前slot下标 182 | key_outer = slot_pre[i][2:] # 需修改key_outer 183 | # 第二种嵌套:B-name I-name B-datetime_date I-datetime_date I-name B-region I-region I-name I-name 184 | elif tag.startswith('B-'): 185 | flag = False 186 | pre_start = start_outer 187 | pre_end = end_outer 188 | pre_key = key_outer 189 | start_outer = i 190 | end_outer = i 191 | key_outer = slot_pre[i][2:] 192 | # ************************ 193 | for j in range(i, len(slot_pre)): 194 | if slot_pre[j] != 'O' and slot_pre[j][2:] == pre_key: 195 | flag = True 196 | pre_end = j 197 | slot_pre[j] = 'O' # 置为'O' 198 | # ************************* 199 | if flag: # 上一个slot是嵌套 200 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) #先处理外层 201 | # 以后遇到完整slot的就加入(假设只有二级嵌套) 202 | for k in range(i+1, pre_end+1): 203 | if slot_pre[k] != 'O': 204 | if slot_pre[k].startswith('I-'): 205 | end_outer = k 206 | elif slot_pre[k].startswith('B-') and start_outer == -1: 207 | start_outer = k 208 | end_outer = k 209 | key_outer = slot_pre[k][2:] 210 | elif slot_pre[k].startswith('B-') and start_outer != -1: 211 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, 212 | key_outer) # 处理内层 213 | start_outer = k 214 | end_outer = k 215 | key_outer = slot_pre[k][2:] 216 | elif slot_pre[k] == 'O' and start_outer != -1: 217 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, 218 | key_outer) # 处理内层 219 | start_outer = -1 220 | end_outer = -1 221 | key_outer = None 222 | elif slot_pre[k] == 'O' and start_outer == -1: 223 | continue 224 | else: # 上一个不是嵌套 225 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key) #先处理上一个非嵌套 226 | 227 | return tmp_dict 228 | 229 | 230 | def find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list): 231 | fp = open('nest.txt', 'w+', encoding='utf-8') 232 | slot_list = [] 233 | count = 0 234 | # ddd = 0 235 | for i, slot_ids in enumerate(slot_pre_output): # 遍历每条数据 236 | # if ddd < 1111: 237 | # ddd += 1 238 | # continue 239 | tmp_dict = {} 240 | start = 0 241 | end = 0 242 | if is_nest(slot_ids[1:-1]): # 如果确实存在嵌套行为 243 | tmp_dict = process_nest(tokenizer, input_ids_list[i][1:-1], slot_pre_output[i][1:-1], ori_sen_list[i]) 244 | slot_list.append(tmp_dict) 245 | fp.write(str(ori_sen_list[i])+'\t') 246 | fp.write(str(tmp_dict)+'\n') 247 | fp.write(' '.join(slot_pre_output[i][1:-1])+'\n') 248 | count += 1 249 | else: 250 | for j, slot in enumerate(slot_ids): # 遍历每个字 251 | if slot != 'O' and slot != '[START]' and slot != '[EOS]': 252 | if slot.startswith('B-') and start == 0: 253 | start = j # 槽值起始位置 254 | end = j 255 | key = slot[2:] 256 | if j == len(slot_ids)-2: # # 如果'B-'是最后一个字符(即倒数第二个,真倒数第一是EOS) 257 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key) 258 | break 259 | elif slot.startswith('B-') and start != 0: # 说明上一个是槽 260 | # tokenizer, tmp_dict, input_ids, ori_sen, start, end, key 261 | tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key) # 将槽值写入 262 | start = j 263 | end = j 264 | key = slot[2:] 265 | else: # 
266 |                         end += 1
267 |                         if j == len(slot_ids)-2:  # this 'I-' is the last character (second to last position; the very last is EOS)
268 |                             tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
269 |                             break
270 |                 else:
271 |                     if end == 0:  # no slot found yet
272 |                         continue
273 |                     else:
274 |                         if slot != '[EOS]':
275 |                             tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
276 |                         start = 0
277 |                         end = 0
278 |             slot_list.append(tmp_dict)
279 |     print(str(count)+'——interact1 inference done!')
280 |     fp.close()
281 |     return slot_list
282 | 
283 | 
284 | if __name__ == "__main__":
285 |     warnings.filterwarnings("ignore")
286 |     config = Config()
287 |     device = torch.device(config.cuda if torch.cuda.is_available() else "cpu")
288 |     with open('../../user_data/common_data/region_dic.json', 'r', encoding='utf-8') as fp:
289 |         region_dict = json.load(fp)
290 |     tokenizer = AutoTokenizer.from_pretrained('../../user_data/pretrained_model/ernie')
291 |     print('interact1 inference started!')
292 |     vocab = load_vocab('../../user_data/pretrained_model/ernie/vocab.txt')
293 |     id2intent = load_reverse_vocab('../../user_data/common_data/intent_label.txt')
294 |     id2slotNone = load_reverse_vocab('../../user_data/common_data/slot_none_vocab.txt')
295 |     id2slot = load_reverse_vocab('../../user_data/common_data/slot_label.txt')
296 | 
297 |     intent_dict = load_vocab('../../user_data/common_data/intent_label.txt')
298 |     slot_none_dict = load_vocab('../../user_data/common_data/slot_none_vocab.txt')
299 |     slot_dict = load_vocab('../../user_data/common_data/slot_label.txt')
300 | 
301 |     intent_tagset_size = len(intent_dict)
302 |     slot_none_tag_size = len(slot_none_dict)
303 |     slot_tag_size = len(slot_dict)
304 |     test_filename = '../../user_data/test_data/test_seq_in_B.txt'
305 |     test_dataset = CCFDataset(test_filename, vocab, intent_dict, slot_none_dict, slot_dict, config.max_length)
306 |     test_loader = DataLoader(test_dataset, shuffle=False, batch_size=config.batch_size,
307 |                              collate_fn=collate_to_max_length)
308 | 
309 |     model = InteractModel('../../user_data/pretrained_model/ernie',
310 |                           config.bert_hidden_size,
311 |                           intent_tagset_size,
312 |                           slot_none_tag_size,
313 |                           slot_tag_size,
314 |                           device).to(device)
315 |     checkpoint = torch.load('../../user_data/output_model/InteractModel_1/Interact1_model_best.pth.tar')
316 |     model.load_state_dict(checkpoint["model"])
317 | 
318 |     # -------------------- Testing ------------------- #
319 |     print("\n",
320 |           20 * "=",
321 |           "Test model on device: {}".format(device),
322 |           20 * "=")
323 | 
324 |     intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list = run_test(model, test_loader, slot_dict)
325 | 
326 |     ori_sen_list = []
327 |     with open('../../user_data/test_data/test_B_final_text.json', 'r', encoding='utf-8') as fp:
328 |         raw_data = json.load(fp)
329 |         for filename, single_data in raw_data.items():
330 |             text = single_data['text']
331 |             ori_sen_list.append(text)
332 |     slot_list = find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list)
333 | 
334 |     # replace values in slot_list with their canonical keys from region_dict (this step must not be skipped!)
335 |     for i, slot_dict in enumerate(slot_list):
336 |         for slot, slot_value in slot_dict.items():
337 |             for region_key, region_list in region_dict.items():
338 |                 if slot_value in region_list:
339 |                     slot_list[i][slot] = region_key
340 |     res = {}
341 |     for i in range(len(intent_pred)):
342 |         big_tmp = {}
343 |         slot_tmp_dict = {}
344 |         # intent
345 |         intent = id2intent[intent_pred[i]]
346 |         # slot_none
347 |         for slot_none_id in slotNone_pred_output[i]:
348 |             if 0 <= slot_none_id <= 9:
349 |                 slot_tmp_dict['command'] = id2slotNone[slot_none_id]
350 |             elif 10 <= slot_none_id <= 18:
351 |                 slot_tmp_dict['index'] = id2slotNone[slot_none_id]
352 |             elif 19 <= slot_none_id <= 22:
353 |                 slot_tmp_dict['play_mode'] = id2slotNone[slot_none_id]
354 |             elif 23 <= slot_none_id <= 27:
355 |                 slot_tmp_dict['query_type'] = id2slotNone[slot_none_id]
356 |         # slot
357 |         slot_tmp_dict.update(slot_list[i])
358 | 
359 |         big_tmp['intent'] = intent
360 |         big_tmp['slots'] = slot_tmp_dict
361 | 
362 |         length = len(str(i))
363 |         o_num = 5 - length
364 |         index = 'NLU' + '0'*o_num + str(i)
365 |         res[index] = big_tmp
366 | 
367 |     with open('../../user_data/tmp_result/result_interact_1.json', 'w', encoding='utf-8') as fp:
368 |         json.dump(res, fp, ensure_ascii=False)
369 | 
370 |     source_path = '../../user_data/tmp_result/result_interact_1.json'
371 |     target_path = '../../user_data/tmp_result/result_interact_1_post.json'
372 |     process(source_path, target_path)
373 | 
--------------------------------------------------------------------------------
/data/code/predict/run_interact3.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2021/9/9 15:22
4 | # @Author : JJkinging
5 | # @File : run_predict.py
6 | import warnings
7 | 
8 | import torch
9 | import json
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | from data.code.scripts.config_Interact3 import Config
13 | from data.code.predict.test_utils import load_reverse_vocab, load_vocab, collate_to_max_length
14 | from data.code.predict.test_dataset import CCFDataset
15 | from data.code.model.InteractModel_3 import InteractModel
16 | from transformers import AutoTokenizer
17 | from data.code.predict.post_process import process
18 | 
19 | 
20 | def run_test(model, dataloader, slot_dict):
21 | 
22 |     model.eval()
23 |     device = model.device
24 |     intent_probs_list = []
25 | 
26 |     intent_pred = []
27 |     slotNone_pred_output = []
28 |     slot_pre_output = []
29 |     input_ids_list = []
30 | 
31 |     id2slot = {value: key for key, value in slot_dict.items()}
32 |     with torch.no_grad():
33 |         tqdm_batch_iterator = tqdm(dataloader)
34 |         for _, batch in enumerate(tqdm_batch_iterator):
35 |             input_ids, input_mask = batch
36 | 
37 |             input_ids = input_ids.to(device)
38 |             input_mask = input_mask.byte().to(device)
39 |             real_length = torch.sum(input_mask, dim=1)
40 | 
41 |             intent_logits, slot_none_logits, slot_logits = model(input_ids, input_mask)
42 | 
43 |             # intent
44 |             intent_probs = torch.softmax(intent_logits, dim=-1)
45 |             predict_labels = torch.argmax(intent_probs, dim=-1)
46 |             predict_labels = predict_labels.cpu().numpy().tolist()
47 |             intent_pred.extend(predict_labels)
48 | 
49 |             # slot_none
50 |             slot_none_probs = torch.sigmoid(slot_none_logits)
51 |             slot_none_probs = slot_none_probs > 0.5
52 |             slot_none_probs = slot_none_probs.cpu().numpy()
53 |             slot_none_probs = slot_none_probs.astype(int)
54 |             slot_none_probs = slot_none_probs.tolist()
55 | 
56 |             for slot_none_id in slot_none_probs:  # iterate over the slot_none predictions of this batch
57 |                 tmp = []
58 |                 if 1 in slot_none_id:
59 |                     for idx, none_id in enumerate(slot_none_id):
60 |                         if none_id == 1:
61 |                             tmp.append(idx)
62 |                 else:
63 |                     tmp.append(28)  # 28 denotes "none" in slot_none_vocab
64 |                 slotNone_pred_output.append(tmp)  # [[4, 5], [], [], ...]
65 | 
66 |             # slot
67 |             out_path = model.slot_predict(slot_logits, input_mask, id2slot)
68 |             out_path = [[id2slot[idx] for idx in one_data] for one_data in out_path]  # keep the '[START]' and '[EOS]' tags
69 |             slot_pre_output.extend(out_path)
70 | 
71 |             # input_ids
72 |             input_ids = input_ids.cpu().numpy().tolist()
73 |             input_ids_list.extend(input_ids)
74 | 
75 |     return intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list
76 | 
77 | 
78 | def get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start, end, key):
79 |     value = tokenizer.decode(input_ids[start:end + 1])
80 |     value = ''.join(value.split(' ')).replace('[UNK]', '$')
81 |     s = -1
82 |     e = -1
83 |     ori_sen_lower = ori_sen.lower()  # lower-case all English letters in the sentence
84 |     is_unk = False
85 |     if len(value) > 0:
86 |         if value[0] == '$':  # the very first token is [UNK]
87 |             value = value[1:]  # ignore it while matching, then move the start position back by one at the end
88 |             is_unk = True
89 |     i = 0
90 |     j = 0
91 |     while i < len(ori_sen_lower) and j < len(value):
92 |         if ori_sen_lower[i] == value[j]:
93 |             if s == -1:
94 |                 s = i
95 |                 e = i
96 |             else:
97 |                 e += 1
98 |             i += 1
99 |             j += 1
100 |         elif ori_sen_lower[i] == ' ':
101 |             e += 1
102 |             i += 1
103 |         elif value[j] == '$':
104 |             e += 1
105 |             i += 1
106 |             j += 1
107 |         elif ori_sen_lower[i] != value[j]:
108 |             i -= j - 1
109 |             j = 0
110 |             s = -1
111 |             e = -1
112 |     if is_unk:
113 |         s = s-1
114 |     final_value = ori_sen[s:e + 1]
115 |     if key in tmp_dict.keys():
116 |         if tmp_dict[key] != '' and not isinstance(tmp_dict[key], list):  # the key already has a value and it is not a list
117 |             tmp_list = [tmp_dict[key], final_value]
118 |             tmp_dict[key] = tmp_list
119 |         elif tmp_dict[key] != '' and isinstance(tmp_dict[key], list):
120 |             tmp_dict[key].append(final_value)
121 |     else:
122 |         tmp_dict[key] = final_value
123 |     return tmp_dict
124 | 
125 | 
126 | def is_nest(slot_pre):
127 |     '''
128 |     Check whether the slot filling of this sample contains nested slots.
129 |     :param slot_pre: e.g. ['O', 'O', 'O', 'B-age', 'I-age', 'O', 'O', 'O']
130 |     :return: True if nesting exists, otherwise False
131 |     '''
132 |     for j, cur_tag in enumerate(slot_pre):
133 |         if cur_tag.startswith('I-'):
134 |             cur_tag_name = cur_tag[2:]
135 |             if j < len(slot_pre)-1:
136 |                 if slot_pre[j+1].startswith('I-'):  # the next tag also starts with 'I-'
137 |                     post_tag_name = slot_pre[j+1][2:]
138 |                     if cur_tag_name != post_tag_name:  # but the two labels differ
139 |                         return True
140 |     return False
141 | 
142 | 
143 | def process_nest(tokenizer, input_ids, slot_pre, ori_sen):  # handle nested slots
144 |     '''
145 |     :param input_ids: the sentence as token ids (already wordpiece-tokenized)
146 |     :param slot_pre: the predicted slot tags of this sentence
147 |     :param ori_sen: the original sentence (characters)
148 |     :return:
149 |     '''
150 |     start_outer = -1
151 |     end_outer = -1
152 |     tmp_dict = {}
153 |     pre_end = -1
154 |     for i, tag in enumerate(slot_pre):
155 |         if i <= pre_end:
156 |             continue
157 |         if tag.startswith('B-') and start_outer == -1:  # the first 'B-'
158 |             start_outer = i  # tentative start position
159 |             end_outer = i
160 |             key_outer = tag[2:]
161 | 
162 |         elif tag.startswith('I-') and tag[2:] == slot_pre[i-1][2:]:  # 'I-' with the same label as the previous slot
163 |             end_outer = i
164 |             if i == len(slot_pre)-1:  # reached the end
165 |                 tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer)
166 | 
167 |         elif tag.startswith('O') and start_outer == -1:
168 |             continue
169 |         elif tag.startswith('O') and start_outer != -1:
170 |             tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer, key_outer)
171 |             start_outer = -1
172 |             end_outer = -1
173 |         # nesting type 1: B-region I-region I-name I-name
174 |         elif tag.startswith('I-') and tag[2:] != slot_pre[i-1][2:]:  # 'I-' but with a different label than the previous slot
175 |             start_inner = start_outer
176 |             end_inner = end_outer
177 |             key_inner = key_outer
178 |             # handle the inner slot
179 |             tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_inner, end_inner, key_inner)
180 |             # start_outer stays unchanged
181 |             end_outer = i  # end_outer becomes the current index
182 |             key_outer = slot_pre[i][2:]  # key_outer must be updated
183 |         # nesting type 2: B-name I-name B-datetime_date I-datetime_date I-name B-region I-region I-name I-name
184 |         elif tag.startswith('B-'):
185 |             flag = False
186 |             pre_start = start_outer
187 |             pre_end = end_outer
188 |             pre_key = key_outer
189 |             start_outer = i
190 |             end_outer = i
191 |             key_outer = slot_pre[i][2:]
192 |             # ************************
193 |             for j in range(i, len(slot_pre)):
194 |                 if slot_pre[j] != 'O' and slot_pre[j][2:] == pre_key:
195 |                     flag = True
196 |                     pre_end = j
197 |                     slot_pre[j] = 'O'  # reset to 'O'
198 |             # *************************
199 |             if flag:  # the previous slot is nested
200 |                 tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key)  # handle the outer slot first
201 |                 # from here on, add every complete slot that appears (assuming at most two levels of nesting)
202 |                 for k in range(i+1, pre_end+1):
203 |                     if slot_pre[k] != 'O':
204 |                         if slot_pre[k].startswith('I-'):
205 |                             end_outer = k
206 |                         elif slot_pre[k].startswith('B-') and start_outer == -1:
207 |                             start_outer = k
208 |                             end_outer = k
209 |                             key_outer = slot_pre[k][2:]
210 |                         elif slot_pre[k].startswith('B-') and start_outer != -1:
211 |                             tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer,
212 |                                                 key_outer)  # handle the inner slot
213 |                             start_outer = k
214 |                             end_outer = k
215 |                             key_outer = slot_pre[k][2:]
216 |                     elif slot_pre[k] == 'O' and start_outer != -1:
217 |                         tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, start_outer, end_outer,
218 |                                             key_outer)  # handle the inner slot
219 |                         start_outer = -1
220 |                         end_outer = -1
221 |                         key_outer = None
222 |                     elif slot_pre[k] == 'O' and start_outer == -1:
223 |                         continue
224 |             else:  # the previous slot is not nested
225 |                 tmp_dict = get_slot(tokenizer, tmp_dict, input_ids, ori_sen, pre_start, pre_end, pre_key)  # handle the previous, non-nested slot first
226 | 
227 |     return tmp_dict
228 | 
229 | 
230 | def find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list):
231 |     fp = open('nest.txt', 'w+', encoding='utf-8')
232 |     slot_list = []
233 |     count = 0
234 |     # ddd = 0
235 |     for i, slot_ids in enumerate(slot_pre_output):  # iterate over every sample
236 |         # if ddd < 1111:
237 |         #     ddd += 1
238 |         #     continue
239 |         tmp_dict = {}
240 |         start = 0
241 |         end = 0
242 |         if is_nest(slot_ids[1:-1]):  # nesting really occurs in this sample
243 |             tmp_dict = process_nest(tokenizer, input_ids_list[i][1:-1], slot_pre_output[i][1:-1], ori_sen_list[i])
244 |             slot_list.append(tmp_dict)
245 |             fp.write(str(ori_sen_list[i])+'\t')
246 |             fp.write(str(tmp_dict)+'\n')
247 |             fp.write(' '.join(slot_pre_output[i][1:-1])+'\n')
248 |             count += 1
249 |         else:
250 |             for j, slot in enumerate(slot_ids):  # iterate over every character
251 |                 if slot != 'O' and slot != '[START]' and slot != '[EOS]':
252 |                     if slot.startswith('B-') and start == 0:
253 |                         start = j  # start position of the slot value
254 |                         end = j
255 |                         key = slot[2:]
256 |                         if j == len(slot_ids)-2:  # this 'B-' is the last character (second to last position; the very last is EOS)
257 |                             tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
258 |                             break
259 |                     elif slot.startswith('B-') and start != 0:  # the previous span is a slot
260 |                         # tokenizer, tmp_dict, input_ids, ori_sen, start, end, key
261 |                         tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)  # write the slot value
262 |                         start = j
263 |                         end = j
264 |                         key = slot[2:]
265 |                     else:  # starts with 'I-'
266 |                         end += 1
267 |                         if j == len(slot_ids)-2:  # this 'I-' is the last character (second to last position; the very last is EOS)
268 |                             tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
269 |                             break
270 |                 else:
271 |                     if end == 0:  # no slot found yet
272 |                         continue
273 |                     else:
274 |                         if slot != '[EOS]':
275 |                             tmp_dict = get_slot(tokenizer, tmp_dict, input_ids_list[i], ori_sen_list[i], start, end, key)
276 |                         start = 0
277 |                         end = 0
278 |             slot_list.append(tmp_dict)
279 |     print(str(count)+'——interact3 inference done!')
280 |     fp.close()
281 |     return slot_list
282 | 
283 | 
284 | if __name__ == "__main__":
285 |     warnings.filterwarnings("ignore")
286 |     config = Config()
287 |     device = torch.device(config.cuda if torch.cuda.is_available() else "cpu")
288 |     with open('../../user_data/common_data/region_dic.json', 'r', encoding='utf-8') as fp:
289 |         region_dict = json.load(fp)
290 |     tokenizer = AutoTokenizer.from_pretrained('../../user_data/pretrained_model/ernie')
291 |     print('interact3 inference started!')
292 |     vocab = load_vocab('../../user_data/pretrained_model/ernie/vocab.txt')
293 |     id2intent = load_reverse_vocab('../../user_data/common_data/intent_label.txt')
294 |     id2slotNone = load_reverse_vocab('../../user_data/common_data/slot_none_vocab.txt')
295 |     id2slot = load_reverse_vocab('../../user_data/common_data/slot_label.txt')
296 | 
297 |     intent_dict = load_vocab('../../user_data/common_data/intent_label.txt')
298 |     slot_none_dict = load_vocab('../../user_data/common_data/slot_none_vocab.txt')
299 |     slot_dict = load_vocab('../../user_data/common_data/slot_label.txt')
300 | 
301 |     intent_tagset_size = len(intent_dict)
302 |     slot_none_tag_size = len(slot_none_dict)
303 |     slot_tag_size = len(slot_dict)
304 |     test_filename = '../../user_data/test_data/test_seq_in_B.txt'
305 |     test_dataset = CCFDataset(test_filename, vocab, intent_dict, slot_none_dict, slot_dict, config.max_length)
306 |     test_loader = DataLoader(test_dataset, shuffle=False, batch_size=config.batch_size,
307 |                              collate_fn=collate_to_max_length)
308 | 
309 |     model = InteractModel('../../user_data/pretrained_model/ernie',
310 |                           config.bert_hidden_size,
311 |                           intent_tagset_size,
312 |                           slot_none_tag_size,
313 |                           slot_tag_size,
314 |                           device).to(device)
315 |     checkpoint = torch.load('../../user_data/output_model/InteractModel_3/Interact3_model_best.pth.tar')
316 |     model.load_state_dict(checkpoint["model"])
317 | 
318 |     # -------------------- Testing ------------------- #
319 |     print("\n",
320 |           20 * "=",
321 |           "Test model on device: {}".format(device),
322 |           20 * "=")
323 | 
324 |     intent_pred, slotNone_pred_output, slot_pre_output, input_ids_list = run_test(model, test_loader, slot_dict)
325 | 
326 |     ori_sen_list = []
327 |     with open('../../user_data/test_data/test_B_final_text.json', 'r', encoding='utf-8') as fp:
328 |         raw_data = json.load(fp)
329 |         for filename, single_data in raw_data.items():
330 |             text = single_data['text']
331 |             ori_sen_list.append(text)
332 |     slot_list = find_slot(tokenizer, input_ids_list, slot_pre_output, ori_sen_list)
333 | 
334 |     # replace values in slot_list with their canonical keys from region_dict (this step must not be skipped!)
335 |     for i, slot_dict in enumerate(slot_list):
336 |         for slot, slot_value in slot_dict.items():
337 |             for region_key, region_list in region_dict.items():
338 |                 if slot_value in region_list:
339 |                     slot_list[i][slot] = region_key
340 |     res = {}
341 |     for i in range(len(intent_pred)):
342 |         big_tmp = {}
343 |         slot_tmp_dict = {}
344 |         # intent
345 |         intent = id2intent[intent_pred[i]]
346 |         # slot_none
347 |         for slot_none_id in slotNone_pred_output[i]:
348 |             if 0 <= slot_none_id <= 9:
349 |                 slot_tmp_dict['command'] = id2slotNone[slot_none_id]
350 |             elif 10 <= slot_none_id <= 18:
351 |                 slot_tmp_dict['index'] = id2slotNone[slot_none_id]
352 |             elif 19 <= slot_none_id <= 22:
353 |                 slot_tmp_dict['play_mode'] = id2slotNone[slot_none_id]
354 |             elif 23 <= slot_none_id <= 27:
355 |                 slot_tmp_dict['query_type'] = id2slotNone[slot_none_id]
356 |         # slot
357 |         slot_tmp_dict.update(slot_list[i])
358 | 
359 |         big_tmp['intent'] = intent
360 |         big_tmp['slots'] = slot_tmp_dict
361 | 
362 |         length = len(str(i))
363 |         o_num = 5 - length
364 |         index = 'NLU' + '0'*o_num + str(i)
365 |         res[index] = big_tmp
366 | 
367 |     with open('../../user_data/tmp_result/result_interact_3.json', 'w', encoding='utf-8') as fp:
368 |         json.dump(res, fp, ensure_ascii=False)
369 | 
370 |     source_path = '../../user_data/tmp_result/result_interact_3.json'
371 |     target_path = '../../user_data/tmp_result/result_interact_3_post.json'
372 |     process(source_path, target_path)
373 | 
--------------------------------------------------------------------------------
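Usage note: both prediction scripts above serialize their raw results to ../../user_data/tmp_result/ as a JSON object keyed by zero-padded "NLU" ids, where each entry holds the predicted "intent" and a "slots" dictionary, and then pass the file to process() for post-processing. The snippet below is a minimal sketch (not part of the repository) showing how such a result file could be loaded and inspected after running run_interact3.py from data/code/predict/; the relative path comes from the script above, and the example entry in the comment is illustrative only.

import json

# Minimal sketch, assuming run_interact3.py has already been executed from
# data/code/predict/ so that the result file below exists.
with open('../../user_data/tmp_result/result_interact_3.json', 'r', encoding='utf-8') as fp:
    result = json.load(fp)

# Each value is shaped like {"intent": "...", "slots": {"command": "...", ...}}.
for nlu_id, entry in list(result.items())[:3]:
    print(nlu_id, entry.get('intent'), entry.get('slots'))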