├── requirements.txt
├── Image
│   ├── CoBERT.png
│   └── DataConstruction.png
├── Tasks
│   ├── Generation
│   │   ├── src
│   │   │   ├── trainer
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── trainer_new.cpython-38.pyc
│   │   │   │   │   └── seq2seq_trainer.cpython-38.pyc
│   │   │   │   └── seq2seq_trainer.py
│   │   │   ├── utils.py
│   │   │   ├── metrics.py
│   │   │   ├── predict.py
│   │   │   ├── arguments.py
│   │   │   └── train.py
│   │   └── README.md
│   ├── ResponseModeling
│   │   ├── README.md
│   │   └── src
│   │       ├── dataloader.py
│   │       ├── utils.py
│   │       ├── model.py
│   │       └── train.py
│   └── AddresseeRecognition
│       ├── README.md
│       └── src
│           ├── dataloader.py
│           ├── utils.py
│           ├── train_bert.py
│           ├── model.py
│           └── train_cobert.py
├── LICENSE
├── Dataset
│   ├── Subset
│   │   ├── basic_profile.json
│   │   ├── text_profile.json
│   │   └── subset.json
│   └── README.md
└── README.md
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.12.1
2 | transformers==4.21.2
3 | numpy
4 | datasets
--------------------------------------------------------------------------------
/Image/CoBERT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaojingsheng/LiveChat/HEAD/Image/CoBERT.png
--------------------------------------------------------------------------------
/Image/DataConstruction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaojingsheng/LiveChat/HEAD/Image/DataConstruction.png
--------------------------------------------------------------------------------
/Tasks/Generation/src/trainer/__pycache__/trainer_new.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaojingsheng/LiveChat/HEAD/Tasks/Generation/src/trainer/__pycache__/trainer_new.cpython-38.pyc
--------------------------------------------------------------------------------
/Tasks/Generation/src/trainer/__pycache__/seq2seq_trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaojingsheng/LiveChat/HEAD/Tasks/Generation/src/trainer/__pycache__/seq2seq_trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/Tasks/Generation/README.md:
--------------------------------------------------------------------------------
1 | ## How to run
2 | 
3 | Preprocess the original data, then pass the processed data path to --data_dir.
4 | 
5 | The file tree of the data directory should look like:
6 | 
7 | ```
8 | .
9 | +-- LiveChat
10 | |   +-- train_data.pk
11 | |   +-- test_data.pk
12 | |   +-- dev_data.pk
13 | ```
14 | 
15 | ## Train
16 | 
17 | ```bash
18 | python train.py --model_path fnlp/bart-base-chinese --output_dir ./outputs --data_dir ./LiveChat/ --do_train --do_eval
19 | ```
20 | 
21 | ## Test
22 | 
23 | Similar to the training process, but point --model_path at a trained checkpoint:
24 | 
25 | ```bash
26 | python train.py --model_path <checkpoint_path> --output_dir ./outputs --data_dir ./LiveChat/ --do_eval
27 | ```
28 | 
--------------------------------------------------------------------------------
/Tasks/ResponseModeling/README.md:
--------------------------------------------------------------------------------
1 | ## How to run
2 | 
3 | Preprocess the original data, then pass the processed data path to --data_dir.
4 | 
5 | The file tree of the data directory should look like:
6 | 
7 | ```
8 | .
9 | +-- dataset
10 | |   +-- train_data.pk
11 | |   +-- test_data.pk
12 | |   +-- dev_data.pk
13 | ```
14 | 
15 | ## Train
16 | 
17 | ```bash
18 | python train.py --history_post --add_id --train_from_scratch --do_train --do_eval --output_dir ./outputs_history_ID --writer_dir ./outputs_history_ID/runs
19 | ```
20 | 
21 | Removing --history_post and --add_id removes the text-profile and basic-profile inputs, respectively.
22 | 
23 | ## Test
24 | 
25 | Similar to the training process:
26 | 
27 | ```bash
28 | python train.py --history_post --add_id --do_eval --load_model_path ./outputs_history_ID/epoch_i --output_dir ./outputs_history_ID --writer_dir ./outputs_history_ID/runs
29 | ```
--------------------------------------------------------------------------------
/Tasks/Generation/src/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import random
3 | import numpy as np
4 | import torch
5 | import pickle as pk
6 | from datasets import Dataset as Ddataset
7 | import pandas as pd
8 | import json
9 | 
10 | 
11 | 
12 | def set_seed(seed):
13 |     random.seed(seed)
14 |     np.random.seed(seed)
15 |     torch.manual_seed(seed)
16 |     if torch.cuda.is_available():
17 |         torch.cuda.manual_seed_all(seed)
18 | 
19 | def postprocess_text(preds, labels):
20 |     preds = [pred.strip() for pred in preds]
21 |     labels = [label.strip() for label in labels]
22 |     # replace empty predictions with '。' so downstream metrics do not fail on empty strings
23 |     while '' in preds:
24 |         idx = preds.index('')
25 |         preds[idx] = '。'
26 |     return preds, labels
27 | 
28 | def load_pk(file_path):
29 |     data = pk.load(open(file_path, 'rb'))
30 |     results = {'query':[], 'response':[]}
31 |     for sample in data:
32 |         query = sample[1]
33 |         response = sample[2]
34 |         results['query'].append(query)
35 |         results['response'].append(response)
36 |     results = Ddataset.from_dict(results)
37 |     return results
38 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Xiaobing.AI
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/Dataset/Subset/basic_profile.json:
--------------------------------------------------------------------------------
1 | {
2 |     "streamer1": {
3 |         "gender": "1",
4 |         "age": 2,
5 |         "location": "安徽",
6 |         "character": 9,
7 |         "fans_num": 1,
8 |         "live_time": 2,
9 |         "reply_barrage": 1,
10 |         "audiences": 1,
11 |         "skill": null
12 |     },
13 |     "streamer2": {
14 |         "gender": "1",
15 |         "age": 3,
16 |         "location": "安徽",
17 |         "character": 9,
18 |         "fans_num": 3,
19 |         "live_time": 3,
20 |         "reply_barrage": 1,
21 |         "audiences": 2,
22 |         "skill": 2
23 |     },
24 |     "streamer3": {
25 |         "gender": "1",
26 |         "age": 3,
27 |         "location": "北京",
28 |         "character": 34,
29 |         "fans_num": 1,
30 |         "live_time": null,
31 |         "reply_barrage": 1,
32 |         "audiences": 1,
33 |         "skill": null
34 |     },
35 |     "streamer4": {
36 |         "gender": "1",
37 |         "age": 2,
38 |         "location": "内蒙古",
39 |         "character": 34,
40 |         "fans_num": 2,
41 |         "live_time": 2,
42 |         "reply_barrage": 1,
43 |         "audiences": 1,
44 |         "skill": 2
45 |     }
46 | }
--------------------------------------------------------------------------------
/Tasks/AddresseeRecognition/README.md:
--------------------------------------------------------------------------------
1 | ## How to run
2 | 
3 | Preprocess the original data, then pass the processed data path to --data_dir.
4 | 
5 | The file tree of the data directory should look like:
6 | 
7 | ```
8 | .
9 | +-- dataset
10 | |   +-- train_data.pk
11 | |   +-- test_data.pk
12 | |   +-- dev_data.pk
13 | ```
14 | 
15 | ## Train
16 | 
17 | Train CoBERT:
18 | 
19 | ```bash
20 | python train_cobert.py --history_post --add_id --do_train --do_eval --apply_interaction --output_dir ./outputs --writer_dir ./outputs/runs
21 | ```
22 | 
23 | Train TwinBERT:
24 | 
25 | ```bash
26 | python train_cobert.py --add_id --do_train --do_eval --apply_interaction --output_dir ./outputs --writer_dir ./outputs/runs
27 | ```
28 | 
29 | Train BERT:
30 | 
31 | ```bash
32 | python train_bert.py --add_id --do_train --do_eval --output_dir ./outputs --writer_dir ./outputs/runs
33 | ```
34 | 
35 | ## Test
36 | 
37 | Test CoBERT or TwinBERT:
38 | 
39 | ```bash
40 | python train_cobert.py --history_post --add_id --do_eval --apply_interaction --output_dir ./outputs --writer_dir ./outputs/runs --load_model_path outputs/epoch_i
41 | ```
42 | 
43 | Test BERT (train_bert.py does not take the --apply_interaction flag):
44 | 
45 | ```bash
46 | python train_bert.py --add_id --do_eval --output_dir ./outputs --writer_dir ./outputs/runs --load_model_path outputs/epoch_i
47 | ```
48 | 
--------------------------------------------------------------------------------
/Dataset/Subset/text_profile.json:
--------------------------------------------------------------------------------
1 | {
2 |     "streamer1": [
3 |         "我不想再再那个了认清现实了",
4 |         "我睡觉从来不打呼但是有有好像梦梦游过",
5 |         "我今天开光了因为晚上",
6 |         "嗯我还有几年到30岁我觉得我已经提前跨到这个坎了",
7 |         "无人区在哪里啊我也想去我也要找一个没有人认识我的地方",
8 |         "是的我跟你讲我最怕石头我最不喜欢石头了",
9 |         "我余额不足了发不了红包了要命要命要命",
10 |         "莫干山我是前年去的",
11 |         "他们都讲我讲话声音像反唱的",
12 |         "我不敢睡下的和切你也不知道我做什么了你们竟然",
13 |         "我一般在45点啊56点这个时候好容易犯困",
14 |         "我要在我的瑜伽里面找到自信"
15 |     ],
16 |     "streamer2": [
17 |         "我从来不喝酒哈哈但是在直播间我可以千杯不醉",
18 |         "我想要个保时捷",
19 |         "我们两个是姐妹",
20 |         "我是卖丝瓜的纸纸",
21 |         "我是草莓来给我送个礼物来",
22 |         "哦哦哦我是一个打工人",
23 |         "我七哥说了他不上",
24 |         "我不染头发自己头发以前染过黑茶色少点粉",
25 |         "我们嫁给我们芜湖男人",
26 |         "我是一个农村人",
27 |         "我怎么知道哈哈我跟大保健不是很熟哈哈卖大闸蟹的",
28 |         "我不要吃兔子",
29 |         "我的妹妹不要唱了你看我一唱人人都跑完了",
30 |         "你觉得我唱歌好听不",
31 |         "我前女友是南宁人哈哈哈",
32 |         "你觉得我唱的好不好听啊文化人",
33 |         "没意思叶超来给我送一个大眼镜行不行",
34 |         "发现了我见石头不见你",
35 |         "我没开空调",
36 |         "我也没卖过西瓜哈哈了半年的积蓄飘洋过海的来",
"我今天回来掰玉米我那天咋整个活都熏着", 38 | "等一下我哥来看我我得哈哈农村母子爱唱歌", 39 | "我都快自闭了", 40 | "哈哈我以前是卖我餐饮", 41 | "叉五叉五你垫个晚上我半夜你看衣服飘", 42 | "两个猪脚送一个鸭爪子哈哈我不吃草莓啊草莓", 43 | "哈哈我这美颜你都能看出来我脸色不好", 44 | "我是芜湖万州人", 45 | "我明天卖玉米卖乐瓜", 46 | "我还从来没有给任何一个大哥买过西瓜", 47 | "我大嫂是世界上最美的女人", 48 | "我本科是会计毕业", 49 | "我明天就在弯指把石挂卖了", 50 | "我家有玉米", 51 | "我想要大天鹅", 52 | "我跟草莓是姐妹啊" 53 | ], 54 | "streamer3": [ 55 | "我喜欢吃包子皮一样", 56 | "我不在咖啡店", 57 | "我谈恋爱起码谈个三年我才能结婚", 58 | "我也喜欢这是新西兰产的", 59 | "对我喜欢喝拿铁", 60 | "我会做蛋糕呢我可能会卖蛋糕", 61 | "我不想找对象", 62 | "我也不认识我对酒精所有对所有酒对我来说都是酒精味", 63 | "我午睡我现在基本上起码4个小时起" 64 | ], 65 | "streamer4": [ 66 | "我们内蒙有很多美女的", 67 | "我天天早上6点就起来跳绳", 68 | "我要喝椰子水", 69 | "我是内蒙古的", 70 | "我毕业还想去深圳工作呢", 71 | "想找我住的地方大家都想找那怎的我犯法了吗", 72 | "黄焖鸡我爱吃椰子鸡", 73 | "我前男友就是靠拆迁挣了不少钱", 74 | "不想谈感情但是我这个前男友他是一个死直男", 75 | "下周我要回内蒙哎", 76 | "我在我去过哈尔滨", 77 | "我跟我爸长得很像可像了", 78 | "我的经验都是远离爱装备的老男人" 79 | ] 80 | } -------------------------------------------------------------------------------- /Dataset/README.md: -------------------------------------------------------------------------------- 1 | # Processed Dataset for LiveChat 2 | 3 | The processed dataset can be downloaded from a Google [Drive](https://drive.google.com/drive/folders/1q2GXfeNRN5bOr2Hc5aDneiBXXVfGN45V?usp=sharing) 4 | 5 | The tree structure of the process dataset is as follows: 6 | ``` 7 | . 8 | +-- ProcessedData 9 | | +-- 400kDialogueInPaper 10 | | +-- train_data.pk 11 | | +-- dev_data.pk 12 | | +-- AddresseeRecognition 13 | | +-- train_data.pk 14 | | +-- dev_data.pk 15 | | +-- test_data.pk 16 | | +-- RawDialogueData 17 | | +-- train_data.pk 18 | | +-- dev_data.pk 19 | | +-- test_data.pk 20 | | +-- basic_profile.json 21 | | +-- text_profile.json 22 | 23 | ``` 24 | ## 400kDialogueinPaper 25 | 26 | This includes the training (train_data.pk) and testing (dev_data.pk) sets of the dialogue dataset used in the paper. 27 | 28 | each dialogue sample is composed by a list: 29 | ``` 30 | [streamer_id, audience_comment, streamer_reponse] 31 | ``` 32 | 33 | ## AddresseeRecognition 34 | 35 | This includes the training (train_data.pk), deving (dev_data.pk) and testing (test_data.pk) sets of the addressee recognition dataset used in the paper. 36 | 37 | each sample is composed by a list: 38 | ``` 39 | [streamer_id, audience_comment_list, streamer_reponse] 40 | ``` 41 | The length of each audience_comment_list is 10, where the last sencetence is the target addresse of the streamer_response. 42 | 43 | ## RawDialogueData 44 | 45 | This includes the raw all training (train_data.pk), deving (dev_data.pk) and testing (test_data.pk) sets of the dialogue dataset of LiveChat. 46 | 47 | each sample is composed by a list: 48 | ``` 49 | [streamer_id, audience_comment, streamer_reponse] 50 | ``` 51 | 52 | ## basic_profile.json 53 | 54 | This includes the basic profile of the streamers, where some characteristics of the streamers are replaced by anonymous numbers. 55 | 56 | ## text_profile.json 57 | 58 | This includes the text profile (max_length is 512) of the streamers used in our paper. 59 | 60 | 61 | 62 | ## Notification 63 | 64 | The statistical quantities of the processed data may have slight discrepancies from those mentioned in the paper. This is because the statistical analysis in the paper was based on the earliest version of multi-party dialogues. 
-------------------------------------------------------------------------------- /Tasks/ResponseModeling/src/dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | class LiveDataset(Dataset): 6 | def __init__(self, total_data, tokenizer, max_len, history_post_len=512): 7 | self.tokenizer = tokenizer 8 | self.data = total_data 9 | self.max_len = max_len 10 | self.history_post_len = history_post_len 11 | 12 | def __len__(self): 13 | return len(self.data) 14 | 15 | def tokenize(self, input_text, max_len): 16 | inputs = self.tokenizer.encode_plus( 17 | input_text, 18 | None, 19 | add_special_tokens=True, 20 | max_length=max_len, 21 | pad_to_max_length=True, 22 | return_token_type_ids=True, 23 | truncation=True 24 | ) 25 | ids = inputs['input_ids'] 26 | mask = inputs['attention_mask'] 27 | token_type_ids = inputs["token_type_ids"] 28 | return ids, mask, token_type_ids 29 | 30 | def __getitem__(self, index): 31 | queries = self.data[index]['query'] 32 | reponses = self.data[index]['response'] 33 | ids1, mask1, token_type_ids1 = self.tokenize(queries, self.max_len) 34 | ids2, mask2, token_type_ids2 = self.tokenize(reponses, self.max_len) 35 | if 'history_post' in self.data[index]: 36 | # history_post_id = self.data[real_index]['history_post'] 37 | # print("history_post_id is ", history_post_id) 38 | history_posts = self.data[index]['history_post'] 39 | ids3, mask3, token_type_ids3 = self.tokenize(history_posts, self.history_post_len) 40 | return { 41 | 'ids': [torch.tensor(ids1, dtype=torch.long),torch.tensor(ids2, dtype=torch.long),torch.tensor(ids3, dtype=torch.long)], 42 | 'mask': [torch.tensor(mask1, dtype=torch.long),torch.tensor(mask2, dtype=torch.long),torch.tensor(mask3, dtype=torch.long)], 43 | 'token_type_ids': [torch.tensor(token_type_ids1, dtype=torch.long),torch.tensor(token_type_ids2, dtype=torch.long),torch.tensor(token_type_ids3, dtype=torch.long)], 44 | } 45 | else: 46 | return { 47 | 'ids': [torch.tensor(ids1, dtype=torch.long),torch.tensor(ids2, dtype=torch.long)], 48 | 'mask': [torch.tensor(mask1, dtype=torch.long),torch.tensor(mask2, dtype=torch.long)], 49 | 'token_type_ids': [torch.tensor(token_type_ids1, dtype=torch.long),torch.tensor(token_type_ids2, dtype=torch.long)], 50 | } 51 | 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LiveChat: A Large-Scale Personalized Dialogue Dataset Automatically Constructed from Live Streaming 2 | This is the official repository for the ACL 2023 paper "LiveChat: A Large-Scale Personalized Dialogue Dataset Automatically Constructed from Live Streaming" 3 | 4 | ![DataConstruction](./Image/DataConstruction.png) 5 | LiveChat is a large-scale dataset, composed of 1.33 million real-life Chinese dialogues with almost 3800 average sessions across 351 personas and fine-grained profiles for each persona. LiveChat is automatically constructed by processing numerous live videos on the Internet and naturally falls within the scope of multi-party conversations. 
6 | 
7 | This repo implements two benchmark tasks (Response Modeling and Addressee Recognition) and a generation task (Generation) for LiveChat:
8 | 
9 | - [Response Modeling](https://github.com/gaojingsheng/LiveChat/tree/master/Tasks/ResponseModeling) (Retrieval-based)
10 | - [Addressee Recognition](https://github.com/gaojingsheng/LiveChat/tree/master/Tasks/AddresseeRecognition)
11 | - [Generation](https://github.com/gaojingsheng/LiveChat/tree/master/Tasks/Generation) (BART)
12 | 
13 | Instructions for running the models on these tasks are described in their README files. Before trying them, first download the dataset and unzip it into the folder ./Dataset. The file tree should look like:
14 | 
15 | ```
16 | .
17 | +-- dataset
18 | |   +-- train.json
19 | |   +-- val.json
20 | |   +-- test.json
21 | |   +-- basic_profile.json
22 | |   +-- text_profile.json
23 | ```
24 | 
25 | ## Environment
26 | You need to clone our project:
27 | ```bash
28 | $ git clone https://github.com/gaojingsheng/LiveChat.git
29 | ```
30 | 
31 | 
32 | Create the environment and install the packages:
33 | ```bash
34 | $ conda create -n LiveChat python==3.8
35 | $ conda activate LiveChat
36 | $ pip install -r requirements.txt
37 | ```
38 | 
39 | ## Dataset
40 | ### Download
41 | Please refer to the dataset [README.md](https://github.com/gaojingsheng/LiveChat/blob/master/Dataset/README.md)
42 | 
43 | 
44 | ## Citation
45 | If you find our paper and repository useful, please cite us in your paper:
46 | ```
47 | @inproceedings{gao-etal-2023-livechat,
48 |     title = "{L}ive{C}hat: A Large-Scale Personalized Dialogue Dataset Automatically Constructed from Live Streaming",
49 |     author = "Gao, Jingsheng  and
50 |       Lian, Yixin  and
51 |       Zhou, Ziyi  and
52 |       Fu, Yuzhuo  and
53 |       Wang, Baoyuan",
54 |     booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
55 |     month = jul,
56 |     year = "2023",
57 |     address = "Toronto, Canada",
58 |     publisher = "Association for Computational Linguistics",
59 |     url = "https://aclanthology.org/2023.acl-long.858",
60 |     doi = "10.18653/v1/2023.acl-long.858",
61 |     pages = "15387--15405",
62 | }
63 | ```
64 | 
--------------------------------------------------------------------------------
/Dataset/Subset/subset.json:
--------------------------------------------------------------------------------
1 | {
2 |     "streamer1": [
3 |         [
4 |             "我的用了4年了",
5 |             "用了4年是手机吗?大家先帮主播点点赞好不好?"
6 |         ],
7 |         [
8 |             "说我又不结婚装修干嘛",
9 |             "准备什么时候装修啊?"
10 |         ],
11 |         [
12 |             "我升级了啊",
13 |             "对啊,你升到25级了。"
14 |         ],
15 |         [
16 |             "音乐大,说话声音带混响",
17 |             "那我怎么调啊?"
18 |         ],
19 |         [
20 |             "忘记了吗",
21 |             "我要看一下啊,因为看这个看不出来。"
22 |         ],
23 |         [
24 |             "被认出来是好事",
25 |             "不是啊我不想被认出来出门我就巴不得是谁都不认输哈哈哈。"
26 |         ],
27 |         [
28 |             "主要是人衬托的衣服",
29 |             "没有没有,这衣服颜色本来就是比较显得比较亮啊。"
30 |         ],
31 |         [
32 |             "斗鱼",
33 |             "对啊斗鱼你们去过吗?"
34 |         ],
35 |         [
36 |             "新出来的音乐",
37 |             "嗯,什么歌呀?"
38 |         ]
39 |     ],
40 |     "streamer2": [
41 |         [
42 |             "来电话了",
43 |             "好的呢知道了。"
44 |         ],
45 |         [
46 |             "都不叫哥哥",
47 |             "我我想问一下我认识你吗?"
48 |         ],
49 |         [
50 |             "就是剪子好",
51 |             "哈哈你怎么这么有眼光。"
52 |         ],
53 |         [
54 |             "晚上你吃饱了",
55 |             "我最低目标都没有完成。"
56 |         ],
57 |         [
58 |             "我是家里人",
59 |             "你跟谁是家里人?"
60 | ], 61 | [ 62 | "被你吓一身汗😓", 63 | "哈哈别怕别怕不做亏心事你怕什么。" 64 | ], 65 | [ 66 | "我马上也要升级了", 67 | "就在我家升吧大总管。" 68 | ], 69 | [ 70 | "在看什么啊", 71 | "我在看我跟他玩发的段子" 72 | ] 73 | ], 74 | "streamer3": [ 75 | [ 76 | "有点卡", 77 | "村里面信号卡。" 78 | ], 79 | [ 80 | "这有什么可处理的", 81 | "但是问题是我引起的呀" 82 | ], 83 | [ 84 | "我在外面吃饭啊", 85 | "没事你找个角落没事咱们这局都约好了。" 86 | ], 87 | [ 88 | "儿童节么,棒棒糖", 89 | "还有儿童节新出的礼物,谢我杰哥棒棒糖" 90 | ], 91 | [ 92 | "可爱", 93 | "必须可爱!隔壁小屋他说我闭一下兔头的麦。" 94 | ], 95 | [ 96 | "中午晚上我都在", 97 | "你是昨天晚上在吗?" 98 | ], 99 | [ 100 | "查到你说的咖啡店了", 101 | "真的很好喝," 102 | ], 103 | [ 104 | "莲雾哪里买的?", 105 | "我不知道哪买的我妈买的。" 106 | ], 107 | [ 108 | "你酒量太差", 109 | "给你挡酒的,这不就过分了吗?" 110 | ], 111 | [ 112 | "可能他想下车了", 113 | "他车都没上,他怎么下呀?" 114 | ], 115 | [ 116 | "应聘秘书?", 117 | "对啊。" 118 | ], 119 | [ 120 | "等我今天努努力", 121 | "继续努力搬砖。" 122 | ], 123 | [ 124 | "你们家有用艾草洗澡嘛", 125 | "为啥要用艾草洗澡?" 126 | ] 127 | ], 128 | "streamer4": [ 129 | [ 130 | "考试结束了吗", 131 | "考试没结束呢考试在年底呢。" 132 | ], 133 | [ 134 | "我跟你介绍一些北京老铁吧", 135 | "什么什么样的老铁啊?" 136 | ], 137 | [ 138 | "这个美女真好看", 139 | "嗯嗯,真的很好看。" 140 | ], 141 | [ 142 | "我就喜欢内蒙的羊肉", 143 | "内蒙的话苏尼特的羊肉好吃,在锡林浩特那边,苏尼特羊肉,我们乌兰察布呢,就是四子王起的羊肉。" 144 | ], 145 | [ 146 | "沙漠里的风能把人吹飞起来", 147 | "沙漠里的风哎呦城市里的风也很大,就整个内蒙古,我不知道内蒙东北边是不是也这样。" 148 | ], 149 | [ 150 | "你家有农场吗", 151 | "我家没有农场。我家只有草原。" 152 | ], 153 | [ 154 | "你对象呢?", 155 | "我对象可能会可能在七彩祥云上面,他会踩着七彩祥云来接我的,娶我回家。" 156 | ], 157 | [ 158 | "你没事吧?", 159 | "我倒是没事,还好离得远,镜子没有砸到我。" 160 | ], 161 | [ 162 | "我想关注你", 163 | "批准了,关注吧。" 164 | ], 165 | [ 166 | "美颜加美瞳", 167 | "对然后化妆都整上了。" 168 | ], 169 | [ 170 | "你没答应啊", 171 | "我为什么要答应呢?" 172 | ], 173 | [ 174 | "就我一个人?", 175 | "是的就你一个人有多苦哎" 176 | ], 177 | [ 178 | "还没睡呢?", 179 | "还没有到点呢,还有半个小时下班。" 180 | ], 181 | [ 182 | "你能开麦吗", 183 | "怎么开?" 184 | ], 185 | [ 186 | "你在学习啥", 187 | "在学母猪的产后护理。" 188 | ] 189 | ] 190 | 191 | } -------------------------------------------------------------------------------- /Tasks/AddresseeRecognition/src/dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | class LiveDataset(Dataset): 6 | def __init__(self, total_data, tokenizer, max_len, history_post_len=512): 7 | self.tokenizer = tokenizer 8 | self.data = total_data 9 | self.max_len = max_len 10 | self.history_post_len = history_post_len 11 | 12 | def __len__(self): 13 | return len(self.data) 14 | 15 | def tokenize(self, input_text, max_len): 16 | inputs = self.tokenizer.encode_plus( 17 | input_text, 18 | None, 19 | add_special_tokens=True, 20 | max_length=max_len, 21 | pad_to_max_length=True, 22 | return_token_type_ids=True, 23 | truncation=True 24 | ) 25 | ids = inputs['input_ids'] 26 | mask = inputs['attention_mask'] 27 | token_type_ids = inputs["token_type_ids"] 28 | return ids, mask, token_type_ids 29 | 30 | def __getitem__(self, index): 31 | queries = self.data[index]['query'] 32 | reponses = self.data[index]['response'] 33 | ids1, mask1, token_type_ids1 = self.tokenize(queries, self.max_len) 34 | ids2, mask2, token_type_ids2 = self.tokenize(reponses, self.max_len) 35 | if 'history_post' in self.data[index]: 36 | history_posts = self.data[index]['history_post'] 37 | ids3, mask3, token_type_ids3 = self.tokenize(history_posts, self.history_post_len) 38 | return { 39 | 'ids': [torch.tensor(ids1, dtype=torch.long),torch.tensor(ids2, dtype=torch.long),torch.tensor(ids3, dtype=torch.long)], 40 | 'mask': [torch.tensor(mask1, dtype=torch.long),torch.tensor(mask2, dtype=torch.long),torch.tensor(mask3, dtype=torch.long)], 41 | 'token_type_ids': 
[torch.tensor(token_type_ids1, dtype=torch.long),torch.tensor(token_type_ids2, dtype=torch.long),torch.tensor(token_type_ids3, dtype=torch.long)],
42 |             }
43 |         else:
44 |             return {
45 |                 'ids': [torch.tensor(ids1, dtype=torch.long),torch.tensor(ids2, dtype=torch.long)],
46 |                 'mask': [torch.tensor(mask1, dtype=torch.long),torch.tensor(mask2, dtype=torch.long)],
47 |                 'token_type_ids': [torch.tensor(token_type_ids1, dtype=torch.long),torch.tensor(token_type_ids2, dtype=torch.long)],
48 |             }
49 | 
50 | 
51 | class SingleBertDataset(Dataset):
52 |     def __init__(self, total_data, tokenizer, max_len):
53 |         self.tokenizer = tokenizer
54 |         self.data = total_data
55 |         self.max_len = max_len
56 |         self.history_post_len = 512
57 |         self.addressee_data = []
58 |         # pair each audience comment with the streamer response: the last
59 |         # comment in each list is the true addressee (label 1), the rest are negatives (label 0)
60 |         for i in range(len(self.data)):
61 |             one_dialogue = self.data[i]
62 |             for j in range(len(self.data[i]["audiences"])):
63 |                 sentence_label = []
64 |                 sentence_label.append(one_dialogue["audiences"][j] + " [SEP] " + one_dialogue["streamer"])
65 |                 if j == len(self.data[i]["audiences"])-1:
66 |                     sentence_label.append(1)
67 |                 else:
68 |                     sentence_label.append(0)
69 |                 self.addressee_data.append(sentence_label)
70 | 
71 |     def __len__(self):
72 |         return len(self.addressee_data)
73 | 
74 |     def tokenize(self, input_text, max_len):
75 |         inputs = self.tokenizer.encode_plus(
76 |             input_text,
77 |             None,
78 |             add_special_tokens=True,
79 |             max_length=max_len,
80 |             padding='max_length',
81 |             return_token_type_ids=True,
82 |             truncation=True
83 |         )
84 |         ids = inputs['input_ids']
85 |         mask = inputs['attention_mask']
86 |         token_type_ids = inputs["token_type_ids"]
87 |         return ids, mask, token_type_ids
88 | 
89 |     def __getitem__(self, index):
90 | 
91 |         sentence = self.addressee_data[index][0]
92 |         label = self.addressee_data[index][1]
93 |         ids, mask, token_type_ids = self.tokenize(sentence, self.max_len)
94 | 
95 |         return {
96 |             'ids': torch.tensor(ids, dtype=torch.long),
97 |             'mask': torch.tensor(mask, dtype=torch.long),
98 |             'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
99 |             'label': torch.tensor(label, dtype=torch.long),
100 |         }
101 | 
102 | 
--------------------------------------------------------------------------------
/Tasks/Generation/src/metrics.py:
--------------------------------------------------------------------------------
1 | from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
2 | from nltk.tokenize import word_tokenize  # word_tokenize splits on whitespace
3 | from rouge import Rouge
4 | 
5 | from collections import Counter
6 | import json
7 | import pickle as pk
8 | import numpy as np
9 | # import torch
10 | 
11 | # generate_corpus: [sentence1, sentence2, ...]
12 | # reference_corpus: [sentence1, sentence2, ...]
13 | # use eval_result to retrieve the bleu, dist and rouge results
14 | def eval_result(generate_corpus, reference_corpus):
15 | 
16 |     generate_corpus = [" ".join([s for s in gen]) for gen in generate_corpus]
17 |     reference_corpus = [" ".join([s for s in ref]) for ref in reference_corpus]
18 | 
19 |     results = {}
20 |     bleu_result = cal_bleu(generate_corpus, reference_corpus)
21 |     dist_result = cal_dist(generate_corpus, reference_corpus)
22 |     rouge_result = cal_rouge(generate_corpus, reference_corpus)
23 | 
24 |     results.update(bleu_result)
25 |     results.update(dist_result)
26 |     results.update(rouge_result)
27 | 
28 |     return results
29 | 
30 | def cal_bleu(generate_corpus, reference_corpus):
31 | 
32 |     print("generate_corpus is:", generate_corpus[:10])
33 |     print("reference_corpus is:", reference_corpus[:10])
34 |     ngrams = ['bleu-{}'.format(n) for n in range(1, 5)]
35 |     ngram_weights = []
36 |     results = {}
37 |     for ngram in ngrams:
38 |         results[ngram] = []
39 |     for n in range(1, 5):
40 |         weights = [0.] * 4
41 |         weights[:n] = [1. / n] * n
42 |         ngram_weights.append(weights)
43 | 
44 |     for gen, refs in zip(generate_corpus, reference_corpus):
45 |         gen = word_tokenize(gen.strip())
46 |         refs = word_tokenize(refs.strip())
47 |         # print(gen, refs)
48 |         try:
49 |             for ngram, weights in zip(ngrams, ngram_weights):
50 |                 score = sentence_bleu([refs], gen, weights=weights, smoothing_function=SmoothingFunction().method7)
51 |                 assert isinstance(score, (float, int))
52 |                 results[ngram].append(score * 100)
53 |         except:  # skip pairs that nltk cannot score (e.g. empty sequences)
54 |             pass
55 |     for item in results:
56 |         results[item] = sum(results[item]) / len(results[item])
57 |     return results
58 | 
59 | def cal_dist(generate_corpus, reference_corpus=None):
60 | 
61 |     results = {}
62 |     ngrams_all = [Counter() for _ in range(4)]
63 |     for gen in generate_corpus:
64 |         gen = gen.strip()
65 |         ngrams = []
66 |         for i in range(4):
67 |             ngrams.append(gen[i:])
68 |             ngram = Counter(zip(*ngrams))
69 |             ngrams_all[i].update(ngram)
70 |     for i in range(4):
71 |         results[f'distinct-{i+1}'] = (len(ngrams_all[i])+1e-12) / (sum(ngrams_all[i].values())+1e-5) * 100
72 |     return results
73 | 
74 | def cal_rouge(generate_corpus, reference_corpus):
75 | 
76 |     rouge = Rouge()
77 |     results = rouge.get_scores(generate_corpus, reference_corpus, avg=True)
78 |     for key in results:
79 |         results[key] = results[key]['f'] * 100
80 |     return results
81 | 
82 | def cal_ppl(input_list):
83 |     # placeholder: perplexity is not actually computed here
84 |     return 1
85 | 
86 | if __name__ == "__main__":
87 |     eva_generation_path = "generation.json"
88 |     total_data = json.load(open(eva_generation_path, "rb"))
89 |     # dev_data = pk.load(open("dataset_withpersona/dev_data.pk", "rb"))
90 |     generate_corpus = []
91 |     reference_corpus = []
92 |     print("total_data length is ", len(total_data))
93 |     for item in total_data:
94 |         generate_corpus.append(item["generation"])
95 |         reference_corpus.append(item["response"])
96 | 
97 |     generate_corpus = [" ".join([s for s in gen[:62]]) for gen in generate_corpus]
98 |     reference_corpus = [" ".join([s for s in ref[:62]]) for ref in reference_corpus]
99 | 
100 | 
101 |     bleu_result = cal_bleu(generate_corpus, reference_corpus)
102 |     dist_result = cal_dist(generate_corpus, reference_corpus)
103 |     rouge_result = cal_rouge(generate_corpus, reference_corpus)
104 | 
105 |     print(bleu_result)
106 |     print(dist_result)
107 |     print(rouge_result)
108 | 
109 | 
110 | 
111 | # {'bleu-1': 31.59630122915178, 'bleu-2': 20.18833827479204, 'bleu-3': 12.927942535617492, 'bleu-4': 8.251670544064492}
112 | # {'distinct-1': 0.9541625837141989, 'distinct-2': 1.978280561728691, 'distinct-3': 10.446433473700157, 'distinct-4': 18.939314718827234}
113 | # {'rouge-1': 25.18328101682556, 'rouge-2': 6.363913083144415, 'rouge-l': 23.290554644714607}
114 | 
115 | 
116 | 
117 | 
--------------------------------------------------------------------------------
/Tasks/Generation/src/predict.py:
--------------------------------------------------------------------------------
1 | 
2 | import sys
3 | import argparse
4 | import logging
5 | import numpy as np
6 | import math
7 | from transformers import (BertTokenizer, BartForConditionalGeneration, HfArgumentParser, DataCollatorForSeq2Seq,
8 |                           Seq2SeqTrainingArguments)
9 | from datasets import Dataset
10 | 
11 | from arguments import DataTrainingArguments, ModelArguments
12 | from utils import set_seed
13 | 
14 | from trainer.seq2seq_trainer import Seq2SeqTrainerNew
15 | 
16 | class Predictor:
17 |     def __init__(self):
18 |         parser = argparse.ArgumentParser()
19 |         parser.add_argument("--model_name", default="cpt", type=str)
20 |         parser.add_argument("--model_path", default="./outputs/checkpoint", type=str)
21 |         parser.add_argument("--lr", default=2e-5, type=float)
22 |         parser.add_argument("--batch_size", default='32', type=str)
23 |         parser.add_argument("--epoch", default='15', type=str)
24 |         parser.add_argument("--data_dir", default="./dataset/", type=str)
25 |         args = parser.parse_args()
26 |         arg_dict = args.__dict__
27 | 
28 |         args = [
29 |             '--model_name_or_path', arg_dict['model_path'],
30 |             '--model_name', arg_dict['model_name'],
31 |             '--output_dir', "./outputs",
32 |             '--preprocessing_num_workers=4',
33 |             '--logging_steps=100',
34 |             # '--max_train_samples=200',
35 |             # '--max_val_samples=200',
36 |             '--dataloader_num_workers=4',
37 |             '--per_device_train_batch_size', arg_dict['batch_size'],
38 |             '--per_device_eval_batch_size', arg_dict['batch_size'],
39 |             '--overwrite_output_dir',
40 |             '--max_source_length=64',
41 |             '--val_max_target_length=' + '64',
42 |             '--predict_with_generate=1',
43 |             '--seed', str(1000*1),
44 |             '--num_train_epochs', arg_dict['epoch'],
45 |             '--save_strategy', 'epoch',
46 |             '--save_total_limit', '3',
47 |             '--evaluation_strategy', 'epoch',
48 |             '--learning_rate', str(arg_dict['lr']),
49 |         ]
50 | 
51 |         parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
52 |         model_args, self.data_args, self.training_args = parser.parse_args_into_dataclasses(args)
53 |         set_seed(self.training_args.seed)
54 | 
55 |         logging.basicConfig(
56 |             format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
57 |             datefmt="%m/%d/%Y %H:%M:%S",
58 |             handlers=[logging.StreamHandler(sys.stdout)],
59 |         )
60 | 
61 |         self.tokenizer = BertTokenizer.from_pretrained(model_args.model_name_or_path)
62 |         self.model = BartForConditionalGeneration.from_pretrained(model_args.model_name_or_path)
63 | 
64 | 
65 |         self.model.config.max_length = self.data_args.val_max_target_length
66 | 
67 |         self.max_target_length = self.data_args.val_max_target_length
68 |         self.padding = False
69 | 
70 |         label_pad_token_id = -100 if self.data_args.ignore_pad_token_for_loss else self.tokenizer.pad_token_id
71 |         data_collator = DataCollatorForSeq2Seq(
72 |             self.tokenizer,
73 |             model=self.model,
74 |             label_pad_token_id=label_pad_token_id,
75 |             pad_to_multiple_of=8 if self.training_args.fp16 else None,
76 |         )
77 | 
78 |         print("Initialize our Trainer")
79 |         self.trainer = Seq2SeqTrainerNew(
80 |             model=self.model,
81 |             args=self.training_args,
82 |             train_dataset=None,
83 |             eval_dataset=None,
84 |             tokenizer=self.tokenizer,
85 |             data_collator=data_collator,
86 |             compute_metrics=None,
87 |         )
88 | 
89 |     def preprocess_function(self, examples):
90 |         queries = examples['query']
91 |         reponses = examples['response']
92 | 
93 |         # inputs = ["[SEP]".join(b) + "[SEP]" + a for a, b in zip(documents, contexts)]
94 |         model_inputs = self.tokenizer(queries, max_length=self.data_args.max_source_length, padding=self.padding, truncation=True)
95 | 
96 |         # Setup the tokenizer for targets
97 |         with self.tokenizer.as_target_tokenizer():
98 |             labels = self.tokenizer(reponses, max_length=self.max_target_length, padding=self.padding, truncation=True)
99 | 
100 |         model_inputs["labels"] = labels["input_ids"]
101 |         return model_inputs
102 | 
103 |     def predict(self, query):
104 |         results = {'query':[], 'response':[]}
105 | 
106 |         results['response'].append('')
107 |         results['query'].append(query)
108 |         results = Dataset.from_dict(results)
109 |         test_dataset = results
110 |         column_names = test_dataset.column_names
111 | 
112 |         test_dataset = test_dataset.map(
113 |             self.preprocess_function,
114 |             batched=True,
115 |             remove_columns=column_names,
116 |         )
117 | 
118 |         output, scores = self.trainer.predict(test_dataset, metric_key_prefix="predict")  # use the custom trainer's predict function
119 |         predictions = output.predictions
120 |         test_preds = self.tokenizer.batch_decode(
121 |             predictions, skip_special_tokens=True,
122 |         )
123 |         test_preds = [pred.strip() for pred in test_preds]
124 |         test_pred = "".join([x for x in test_preds[0] if x != ' '])
125 |         return test_pred, scores[0].cpu().numpy()
126 | 
127 | 
128 | if __name__ == "__main__":
129 |     pred = Predictor()
130 |     query_list = []
131 |     results = []
132 | 
133 | 
134 |     for query in query_list:
135 |         test_pred, score = pred.predict(query)
136 |         results.append([query, test_pred, math.exp(float(score))])
137 | 
138 |         print(query, test_pred)
--------------------------------------------------------------------------------
/Tasks/AddresseeRecognition/src/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import random
3 | import numpy as np
4 | import torch
5 | import pickle as pk
6 | from datasets import Dataset as Ddataset
7 | import pandas as pd
8 | import json
9 | 
10 | def set_seed(seed):
11 |     random.seed(seed)
12 |     np.random.seed(seed)
13 |     torch.manual_seed(seed)
14 |     if torch.cuda.is_available():
15 |         torch.cuda.manual_seed_all(seed)
16 | 
17 | def load_pk(file_path):
18 |     data = pk.load(open(file_path, 'rb'))
19 |     results = {'streamer':[], 'audiences':[]}
20 |     for sample in data:
21 |         audiences = sample[1]
22 |         streamer = sample[2]
23 |         results['streamer'].append(streamer)
24 |         results['audiences'].append(audiences)
25 |     results = Ddataset.from_dict(results)
26 |     return results
27 | 
28 | def load_pk_persona(file_path, personaId_list_path):
29 |     id_list = json.load(open(personaId_list_path,"rb"))
30 |     data = pk.load(open(file_path, 'rb'))
31 |     results = {'streamer':[], 'audiences':[]}
32 |     for sample in data:
33 |         id_index = sample[0]
34 |         audiences = sample[1]
35 |         streamer = str(id_list[id_index]) + " [SEP] " + sample[2]
36 |         results['streamer'].append(streamer)
37 |         results['audiences'].append(audiences)
38 |     results = Ddataset.from_dict(results)
39 |     return results
40 | 
41 | 
42 | def load_retrive_history_post(file_path, history_post_path):
43 |     history_post_list = json.load(open(history_post_path,"rb"))
44 |     data = pk.load(open(file_path, 'rb'))
45 | 
46 |     results = {'history_post':[], 'streamer':[], 'audiences':[]}
47 |     for sample in data:
48 |         scratch_id = sample[0]
49 |         audiences = sample[1]
50 |         streamer = sample[2]
51 |
results['history_post'].append(history_post_list[scratch_id]) 52 | results['streamer'].append(streamer) 53 | results['audiences'].append(audiences) 54 | results = Ddataset.from_dict(results) 55 | return results 56 | 57 | 58 | def load_retrive_history_post_and_id(file_path, history_post_path, persona_id_path): 59 | 60 | history_post_list = json.load(open(history_post_path,"rb")) 61 | id_list = json.load(open(persona_id_path,"rb")) 62 | data = pk.load(open(file_path, 'rb')) 63 | 64 | results = {'history_post':[], 'streamer':[], 'audiences':[]} 65 | for sample in data: 66 | scratch_id = sample[0] 67 | audiences = sample[1] 68 | streamer = str(id_list[scratch_id]) + " [SEP] "+ sample[2] 69 | 70 | results['history_post'].append(history_post_list[scratch_id]) 71 | results['streamer'].append(streamer) 72 | results['audiences'].append(audiences) 73 | results = Ddataset.from_dict(results) 74 | return results 75 | 76 | 77 | def compute_metrics(batch_x_emb, batch_y_emb): 78 | """ 79 | recall@k for N candidates 80 | if batch_x_emb.dim() == 2: 81 | # batch_x_emb: (batch_size, emb_size) 82 | # batch_y_emb: (batch_size, emb_size) 83 | 84 | if batch_x_emb.dim() == 3: 85 | # batch_x_emb: (batch_size, batch_size, emb_size), the 1st dim is along examples and the 2nd dim is along candidates 86 | # batch_y_emb: (batch_size, emb_size) 87 | 88 | """ 89 | batch_size = batch_x_emb.size(0) 90 | targets = torch.arange(batch_size, device=batch_x_emb.device) 91 | if batch_x_emb.dim() == 2: 92 | dot_products = batch_x_emb.mm(batch_y_emb.t()) # (batch_size, batch_size) 93 | elif batch_x_emb.dim() == 3: 94 | dot_products = torch.bmm(batch_x_emb, batch_y_emb.unsqueeze(0).repeat(batch_size, 1, 1).transpose(1,2))[:, targets, targets] 95 | 96 | # dot_products: (batch_size, batch_size) 97 | sorted_indices = dot_products.sort(descending=True)[1] 98 | targets = np.arange(batch_size).tolist() 99 | recall_k = [] 100 | if batch_size <= 10: 101 | ks = [1, max(1, round(batch_size*0.2)), max(1, round(batch_size*0.5))] 102 | elif batch_size <= 100: 103 | ks = [1, max(1, round(batch_size*0.1)), max(1, round(batch_size*0.5))] 104 | else: 105 | raise ValueError("batch_size: {0} is not proper".format(batch_size)) 106 | for k in ks: 107 | # sorted_indices[:,:k]: (batch_size, k) 108 | num_ok = 0 109 | for tgt, topk in zip(targets, sorted_indices[:,:k].tolist()): 110 | if tgt in topk: 111 | num_ok += 1 112 | recall_k.append(num_ok/batch_size) 113 | 114 | # MRR 115 | MRR = 0 116 | for tgt, topk in zip(targets, sorted_indices.tolist()): 117 | rank = topk.index(tgt)+1 118 | MRR += 1/rank 119 | MRR = MRR/batch_size 120 | return recall_k, MRR 121 | 122 | def compute_metrics_from_logits(logits, targets): 123 | """ 124 | recall@k for N candidates 125 | 126 | logits: (batch_size, num_candidates) 127 | targets: (batch_size, ) 128 | 129 | """ 130 | batch_size, num_candidates = logits.shape 131 | 132 | sorted_indices = logits.sort(descending=True)[1] 133 | targets = targets.tolist() 134 | 135 | recall_k = [] 136 | if num_candidates <= 10: 137 | ks = [1, max(1, round(num_candidates*0.2)), max(1, round(num_candidates*0.5))] 138 | elif num_candidates <= 100: 139 | ks = [1, max(1, round(num_candidates*0.1)), max(1, round(num_candidates*0.5))] 140 | else: 141 | raise ValueError("num_candidates: {0} is not proper".format(num_candidates)) 142 | for k in ks: 143 | # sorted_indices[:,:k]: (batch_size, k) 144 | num_ok = 0 145 | for tgt, topk in zip(targets, sorted_indices[:,:k].tolist()): 146 | if tgt in topk: 147 | num_ok += 1 148 | 
recall_k.append(num_ok/batch_size) 149 | 150 | # MRR 151 | MRR = 0 152 | for tgt, topk in zip(targets, sorted_indices.tolist()): 153 | rank = topk.index(tgt)+1 154 | MRR += 1/rank 155 | MRR = MRR/batch_size 156 | return recall_k, MRR 157 | -------------------------------------------------------------------------------- /Tasks/ResponseModeling/src/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import numpy as np 4 | import torch 5 | import pickle as pk 6 | from datasets import Dataset as Ddataset 7 | import json 8 | 9 | def process_index_to_str(index, name): 10 | if name == "age": 11 | if index == 1: 12 | return 16 13 | elif index == 2: 14 | return 22 15 | 16 | def set_seed(seed): 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | def postprocess_text(preds, labels): 24 | preds = [pred.strip() for pred in preds] 25 | labels = [label.strip() for label in labels] 26 | while '' in preds: 27 | idx=preds.index('') 28 | preds[idx]='。' 29 | return preds, labels 30 | 31 | def load_pk(file_path): 32 | data = pk.load(open(file_path, 'rb')) 33 | results = {'query':[], 'response':[]} 34 | for sample in data: 35 | query = sample[1] 36 | response = sample[2] 37 | results['query'].append(query) 38 | results['response'].append(response) 39 | results = Ddataset.from_dict(results) 40 | return results 41 | 42 | def load_pk_persona(file_path, personaId_list_path): 43 | id_list = json.load(open(personaId_list_path,"rb")) 44 | data = pk.load(open(file_path, 'rb')) 45 | results = {'query':[], 'response':[]} 46 | for sample in data: 47 | id_index = sample[0] 48 | query = str(id_list[id_index]) + " [SEP] "+ sample[1] 49 | response = sample[2] 50 | results['query'].append(query) 51 | results['response'].append(response) 52 | results = Ddataset.from_dict(results) 53 | return results 54 | 55 | def load_pk_history_post(file_path, history_post_path): 56 | data = pk.load(open(file_path, 'rb')) 57 | results = {'history_post':[], 'query':[], 'response':[]} 58 | for sample in data: 59 | scratch_id = sample[0] 60 | query = sample[1] 61 | response = sample[2] 62 | results['history_post'].append(int(scratch_id)) 63 | results['query'].append(query) 64 | results['response'].append(response) 65 | results = Ddataset.from_dict(results) 66 | return results 67 | 68 | def load_retrive_history_post(file_path, history_post_path): 69 | history_post_list = json.load(open(history_post_path,"rb")) 70 | data = pk.load(open(file_path, 'rb')) 71 | 72 | results = {'history_post':[], 'query':[], 'response':[]} 73 | for sample in data: 74 | scratch_id = sample[0] 75 | query = sample[1] 76 | response = sample[2] 77 | results['history_post'].append(history_post_list[scratch_id]) 78 | results['query'].append(query) 79 | results['response'].append(response) 80 | results = Ddataset.from_dict(results) 81 | return results 82 | 83 | def load_retrive_history_post_and_id(file_path, history_post_path, persona_id_path): 84 | history_post_list = json.load(open(history_post_path,"rb")) 85 | id_list = json.load(open(persona_id_path,"rb")) 86 | data = pk.load(open(file_path, 'rb')) 87 | 88 | results = {'history_post':[], 'query':[], 'response':[]} 89 | for sample in data: 90 | scratch_id = sample[0] 91 | query = str(id_list[scratch_id]) + " [SEP] "+ sample[1] 92 | response = sample[2] 93 | results['history_post'].append(history_post_list[scratch_id]) 94 | results['query'].append(query) 
95 | results['response'].append(response) 96 | results = Ddataset.from_dict(results) 97 | return results 98 | 99 | def compute_metrics(batch_x_emb, batch_y_emb): 100 | """ 101 | recall@k for N candidates 102 | if batch_x_emb.dim() == 2: 103 | # batch_x_emb: (batch_size, emb_size) 104 | # batch_y_emb: (batch_size, emb_size) 105 | 106 | if batch_x_emb.dim() == 3: 107 | # batch_x_emb: (batch_size, batch_size, emb_size), the 1st dim is along examples and the 2nd dim is along candidates 108 | # batch_y_emb: (batch_size, emb_size) 109 | 110 | """ 111 | batch_size = batch_x_emb.size(0) 112 | targets = torch.arange(batch_size, device=batch_x_emb.device) 113 | if batch_x_emb.dim() == 2: 114 | dot_products = batch_x_emb.mm(batch_y_emb.t()) # (batch_size, batch_size) 115 | elif batch_x_emb.dim() == 3: 116 | dot_products = torch.bmm(batch_x_emb, batch_y_emb.unsqueeze(0).repeat(batch_size, 1, 1).transpose(1,2))[:, targets, targets] 117 | 118 | # dot_products: (batch_size, batch_size) 119 | sorted_indices = dot_products.sort(descending=True)[1] 120 | targets = np.arange(batch_size).tolist() 121 | recall_k = [] 122 | if batch_size <= 10: 123 | ks = [1, max(1, round(batch_size*0.2)), max(1, round(batch_size*0.5))] 124 | elif batch_size <= 100: 125 | ks = [1, max(1, round(batch_size*0.1)), max(1, round(batch_size*0.5))] 126 | else: 127 | raise ValueError("batch_size: {0} is not proper".format(batch_size)) 128 | for k in ks: 129 | # sorted_indices[:,:k]: (batch_size, k) 130 | num_ok = 0 131 | for tgt, topk in zip(targets, sorted_indices[:,:k].tolist()): 132 | if tgt in topk: 133 | num_ok += 1 134 | recall_k.append(num_ok/batch_size) 135 | 136 | # MRR 137 | MRR = 0 138 | for tgt, topk in zip(targets, sorted_indices.tolist()): 139 | rank = topk.index(tgt)+1 140 | MRR += 1/rank 141 | MRR = MRR/batch_size 142 | return recall_k, MRR 143 | 144 | def compute_metrics_from_logits(logits, targets): 145 | """ 146 | recall@k for N candidates 147 | 148 | logits: (batch_size, num_candidates) 149 | targets: (batch_size, ) 150 | 151 | """ 152 | batch_size, num_candidates = logits.shape 153 | 154 | sorted_indices = logits.sort(descending=True)[1] 155 | targets = targets.tolist() 156 | 157 | recall_k = [] 158 | if num_candidates <= 10: 159 | ks = [1, max(1, round(num_candidates*0.2)), max(1, round(num_candidates*0.5))] 160 | elif num_candidates <= 100: 161 | ks = [1, max(1, round(num_candidates*0.1)), max(1, round(num_candidates*0.5))] 162 | else: 163 | raise ValueError("num_candidates: {0} is not proper".format(num_candidates)) 164 | for k in ks: 165 | # sorted_indices[:,:k]: (batch_size, k) 166 | num_ok = 0 167 | for tgt, topk in zip(targets, sorted_indices[:,:k].tolist()): 168 | if tgt in topk: 169 | num_ok += 1 170 | recall_k.append(num_ok/batch_size) 171 | 172 | # MRR 173 | MRR = 0 174 | for tgt, topk in zip(targets, sorted_indices.tolist()): 175 | rank = topk.index(tgt)+1 176 | MRR += 1/rank 177 | MRR = MRR/batch_size 178 | return recall_k, MRR 179 | 180 | -------------------------------------------------------------------------------- /Tasks/Generation/src/arguments.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: ethanlian 3 | Date: 2022-03-13 20:48:22 4 | LastEditTime: 2022-03-14 09:40:51 5 | LastEditors: ethanlian 6 | FilePath: /TopSearch/topsearch-chat-engine-closechat/DocGroundedDialogue/naturalconv/arguments.py 7 | ''' 8 | 9 | from dataclasses import dataclass, field 10 | from typing import Optional 11 | import pickle as pk 12 | import json 13 | 14 | 
@dataclass 15 | class ModelArguments: 16 | """ 17 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 18 | """ 19 | 20 | model_name_or_path: str = field( 21 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 22 | ) 23 | model_name: str = field( 24 | default=None, metadata={"help": "define which model to load"} 25 | ) 26 | config_name: Optional[str] = field( 27 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 28 | ) 29 | tokenizer_name: Optional[str] = field( 30 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 31 | ) 32 | cache_dir: Optional[str] = field( 33 | default=None, 34 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 35 | ) 36 | use_fast_tokenizer: bool = field( 37 | default=True, 38 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 39 | ) 40 | model_revision: str = field( 41 | default="main", 42 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 43 | ) 44 | use_auth_token: bool = field( 45 | default=False, 46 | metadata={ 47 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 48 | "with private models)." 49 | }, 50 | ) 51 | 52 | 53 | @dataclass 54 | class DataTrainingArguments: 55 | """ 56 | Arguments pertaining to what data we are going to input our model for training and eval. 57 | """ 58 | 59 | dataset_name: Optional[str] = field( 60 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 61 | ) 62 | dataset_config_name: Optional[str] = field( 63 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 64 | ) 65 | text_column: Optional[str] = field( 66 | default=None, 67 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 68 | ) 69 | summary_column: Optional[str] = field( 70 | default=None, 71 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 72 | ) 73 | train_file: Optional[str] = field( 74 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 75 | ) 76 | validation_file: Optional[str] = field( 77 | default=None, 78 | metadata={ 79 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 80 | "(a jsonlines or csv file)." 81 | }, 82 | ) 83 | test_file: Optional[str] = field( 84 | default=None, 85 | metadata={ 86 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 87 | }, 88 | ) 89 | overwrite_cache: bool = field( 90 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 91 | ) 92 | preprocessing_num_workers: Optional[int] = field( 93 | default=None, 94 | metadata={"help": "The number of processes to use for the preprocessing."}, 95 | ) 96 | max_source_length: Optional[int] = field( 97 | default=1024, 98 | metadata={ 99 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 100 | "than this will be truncated, sequences shorter will be padded." 
101 | }, 102 | ) 103 | max_target_length: Optional[int] = field( 104 | default=128, 105 | metadata={ 106 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 107 | "than this will be truncated, sequences shorter will be padded." 108 | }, 109 | ) 110 | val_max_target_length: Optional[int] = field( 111 | default=None, 112 | metadata={ 113 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 114 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 115 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 116 | "during ``evaluate`` and ``predict``." 117 | }, 118 | ) 119 | pad_to_max_length: bool = field( 120 | default=False, 121 | metadata={ 122 | "help": "Whether to pad all samples to model maximum sentence length. " 123 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 124 | "efficient on GPU but very bad for TPU." 125 | }, 126 | ) 127 | max_train_samples: Optional[int] = field( 128 | default=None, 129 | metadata={ 130 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 131 | "value if set." 132 | }, 133 | ) 134 | max_val_samples: Optional[int] = field( 135 | default=None, 136 | metadata={ 137 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 138 | "value if set." 139 | }, 140 | ) 141 | max_test_samples: Optional[int] = field( 142 | default=None, 143 | metadata={ 144 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this " 145 | "value if set." 146 | }, 147 | ) 148 | num_beams: Optional[int] = field( 149 | default=None, 150 | metadata={ 151 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 152 | "which is used during ``evaluate`` and ``predict``." 153 | }, 154 | ) 155 | ignore_pad_token_for_loss: bool = field( 156 | default=True, 157 | metadata={ 158 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 159 | }, 160 | ) 161 | source_prefix: Optional[str] = field( 162 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 163 | ) 164 | 165 | def __post_init__(self): 166 | # if self.dataset_name is None and self.train_file is None and self.validation_file is None: 167 | # raise ValueError("Need either a dataset name or a training/validation file.") 168 | # else: 169 | # if self.train_file is not None: 170 | # extension = self.train_file.split(".")[-1] 171 | # assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 172 | # if self.validation_file is not None: 173 | # extension = self.validation_file.split(".")[-1] 174 | # assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
175 | if self.val_max_target_length is None: 176 | self.val_max_target_length = self.max_target_length -------------------------------------------------------------------------------- /Tasks/AddresseeRecognition/src/train_bert.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import argparse 4 | import logging 5 | import os 6 | import numpy as np 7 | from transformers import BertTokenizer, BertConfig 8 | from transformers.modeling_utils import PreTrainedModel, unwrap_model 9 | 10 | from utils import set_seed, load_pk_persona, load_pk, compute_metrics_from_logits 11 | from torch.utils.data import DataLoader 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch import optim, cuda 16 | from model import SingleBert 17 | from dataloader import SingleBertDataset 18 | import torch.nn.functional as F 19 | from tqdm import tqdm 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | def freeze_params(model): 23 | """Set requires_grad=False for each of model.parameters()""" 24 | for _, para in model.named_parameters(): 25 | para.requires_grad = False 26 | 27 | def validation(model, test_dataloader): 28 | model.eval() 29 | 30 | total_acc = [] 31 | total_loss = [] 32 | total_recall = [] 33 | total_MRR = [] 34 | 35 | with torch.no_grad(): 36 | for _, data in enumerate(tqdm(test_dataloader, desc='Evaluating')): 37 | ids, mask, token_type_ids = data['ids'], data['mask'], data['token_type_ids'] 38 | ids = data['ids'].to(device, dtype=torch.long) 39 | mask = data['mask'].to(device, dtype=torch.long) 40 | token_type_ids = data['token_type_ids'].to(device, dtype=torch.long) 41 | label = data['label'].to(device, dtype=torch.long) 42 | assert label.size(0) == 10 43 | 44 | logits = model(ids, mask, token_type_ids) # 45 | loss = F.cross_entropy(logits, label) 46 | 47 | logits_softmax = F.softmax(logits, dim=1)[:,1].unsqueeze(0) 48 | label_softmax = torch.tensor([9], dtype=torch.long).to(device) 49 | 50 | acc = (label_softmax.long() == logits_softmax.float().argmax(dim=1)).sum() / label_softmax.size(0) 51 | test_recall, test_MRR = compute_metrics_from_logits(logits_softmax, label_softmax) 52 | 53 | total_loss.append(float(loss)) 54 | total_acc.append(float(acc)) 55 | total_recall.append(test_recall) 56 | total_MRR.append(test_MRR) 57 | 58 | return np.mean(total_loss), np.mean(total_acc), np.mean(total_recall, axis=0), np.mean(total_MRR) 59 | 60 | 61 | def save_model(args, output_dir, model, tokenizer=None): 62 | 63 | os.makedirs(output_dir, exist_ok=True) 64 | logger.info(f"Saving model checkpoint to {output_dir}") 65 | 66 | if not isinstance(model, PreTrainedModel): 67 | if isinstance(unwrap_model(model), PreTrainedModel): 68 | if state_dict is None: 69 | state_dict = model.state_dict() 70 | unwrap_model(model).save_pretrained(output_dir, state_dict=state_dict) 71 | else: 72 | logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") 73 | if state_dict is None: 74 | state_dict = model.state_dict() 75 | torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin")) 76 | else: 77 | model.save_pretrained(output_dir, state_dict=model.state_dict()) 78 | if tokenizer is not None: 79 | tokenizer.save_pretrained(output_dir) 80 | # Good practice: save your training arguments together with the trained model 81 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--model_path", 
default="bert-base-chinese", type=str) 87 | parser.add_argument("--load_model_path", default="", type=str) 88 | parser.add_argument("--output_dir", default="./outputs", type=str) 89 | parser.add_argument("--writer_dir", default="./outputs/runs", type=str) 90 | parser.add_argument("--lr", default=1e-5, type=float) 91 | parser.add_argument("--add_id", action='store_true', default=False) 92 | parser.add_argument("--train_batch_size", default=40, type=int) 93 | parser.add_argument("--test_batch_size", default=10, type=int) 94 | parser.add_argument("--max_length", default=64, type=int) 95 | parser.add_argument("--epoch", default=30, type=int) 96 | parser.add_argument("--logging_steps", default=100, type=int) 97 | parser.add_argument("--max_train_samples", default=400000, type=int) 98 | parser.add_argument("--max_val_samples", default=10000, type=int) 99 | parser.add_argument("--seed", default=20, type=int) 100 | parser.add_argument("--data_dir", default="ProcessedData/AddresseeRecognition", type=str) 101 | parser.add_argument("--do_train", action='store_true', default=False) 102 | parser.add_argument("--do_eval", action='store_true', default=False) 103 | parser.add_argument("--do_test", action='store_true', default=False) 104 | parser.add_argument("--freeze_plm", action='store_true', default=False) 105 | 106 | args = parser.parse_args() 107 | 108 | logger = logging.getLogger(__name__) 109 | set_seed(args.seed) 110 | 111 | datasets = {} 112 | data_files = {} 113 | 114 | if args.do_train: 115 | writer = SummaryWriter(log_dir=args.writer_dir, flush_secs=120) 116 | data_files["train"] = os.path.join(args.data_dir, "train_data.pk") 117 | if args.do_eval: 118 | data_files["validation"] = os.path.join(args.data_dir, "dev_data.pk") 119 | if args.do_test: 120 | data_files["test"] = os.path.join(args.data_dir, "test_data.pk") 121 | 122 | for key in data_files: 123 | if args.add_id: 124 | print("load ID") 125 | persona_id_path = "ProcessedData/basic_profile.json" 126 | datasets[key] = load_pk_persona(data_files[key], persona_id_path) 127 | else: 128 | print("load no persona!") 129 | datasets[key] = load_pk(data_files[key]) 130 | 131 | logging.basicConfig( 132 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 133 | datefmt="%m/%d/%Y %H:%M:%S", 134 | handlers=[logging.StreamHandler(sys.stdout)], 135 | ) 136 | 137 | device = 'cuda' if cuda.is_available() else 'cpu' 138 | config = BertConfig.from_pretrained(args.model_path) 139 | tokenizer = BertTokenizer.from_pretrained(args.model_path) 140 | 141 | model = SingleBert(config).to(device) 142 | 143 | if args.load_model_path != "": 144 | model.load_state_dict(torch.load(os.path.join(args.load_model_path, "pytorch_model.bin"))) 145 | 146 | optimizer = optim.Adam(model.parameters(), lr = args.lr ) 147 | 148 | if args.freeze_plm: 149 | print("Freeze the pretrained model.......") 150 | freeze_params(model) 151 | 152 | if args.do_train: 153 | train_dataset = datasets["train"] 154 | if args.max_train_samples is not None: 155 | train_dataset = train_dataset.select(range(args.max_train_samples)) 156 | 157 | print("train dataset length is {}".format(len(train_dataset))) 158 | training_set = SingleBertDataset(train_dataset, tokenizer, args.max_length) 159 | print("SiameseNetworkDataset length is {}".format(len(training_set))) 160 | training_loader = DataLoader(training_set, batch_size=args.train_batch_size, shuffle=False, num_workers=4) 161 | 162 | print("train dataset processed over") 163 | print("train dataset length is {}".format(len(training_loader))) 
164 | 165 | if args.do_eval: 166 | eval_dataset = datasets["validation"] 167 | if args.max_val_samples is not None: 168 | eval_dataset = eval_dataset.select(range(args.max_val_samples)) 169 | 170 | testing_set = SingleBertDataset(eval_dataset, tokenizer, args.max_length) 171 | testing_loader = DataLoader(testing_set, batch_size=args.test_batch_size, shuffle=False, num_workers=4) 172 | 173 | print("eval dataset processed over") 174 | 175 | # Training 176 | if args.do_train: 177 | print("Begin training") 178 | train_step = -1 179 | for epoch in range(args.epoch): 180 | model.train() 181 | for data in tqdm(training_loader, desc='Training'): 182 | 183 | train_step += 1 184 | ids = data['ids'].to(device, dtype=torch.long) 185 | mask = data['mask'].to(device, dtype=torch.long) 186 | token_type_ids = data['token_type_ids'].to(device, dtype=torch.long) 187 | label = data['label'].to(device, dtype=torch.long) 188 | 189 | optimizer.zero_grad() 190 | logits = model(ids, mask, token_type_ids) 191 | loss = F.cross_entropy(logits, label) 192 | 193 | if train_step % 500==0: 194 | train_acc = (label.long() == logits.float().argmax(dim=1)).sum() / label.size(0) 195 | print(f'Step:{train_step}, Epoch:{epoch}, Loss:{loss.item()}, batch_acc:{train_acc}') 196 | writer.add_scalar('Loss/train', loss.item(), train_step) 197 | loss.backward() 198 | optimizer.step() 199 | 200 | test_loss, test_acc, test_recall, test_mrr = validation(model, testing_loader) 201 | print(f'Test Epoch:{epoch}, Test loss:{test_loss}, Test accuracy:{test_acc}, Test recall:{test_recall}, Test MRR:{test_mrr}') 202 | writer.add_scalar('Test loss', test_loss, epoch) 203 | writer.add_scalar('Test accuracy', test_acc, epoch) 204 | writer.add_scalar('Test MRR', test_mrr, epoch) 205 | 206 | save_model_path = os.path.join(args.output_dir,"epoch_{}".format(epoch)) 207 | save_model(args, save_model_path, model, tokenizer) 208 | 209 | 210 | if args.do_eval: 211 | test_loss, test_acc, test_recall, test_mrr = validation(model, testing_loader) 212 | print(f'Test checkpoint{args.load_model_path}, Test loss:{test_loss}, Test accuracy:{test_acc}, Test recall:{test_recall}, Test MRR:{test_mrr}') 213 | -------------------------------------------------------------------------------- /Tasks/ResponseModeling/src/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from transformers import BertModel, BertPreTrainedModel, BertTokenizer 5 | from transformers.models.bert.modeling_bert import (BertEmbeddings, 6 | BertEncoder, BertPooler) 7 | import torch.nn.functional as F 8 | 9 | 10 | class BertModel2(BertModel): 11 | def __init__(self, config, add_pooling_layer=True): 12 | super().__init__(config) 13 | self.config = config 14 | self.embeddings = BertEmbeddings(config) 15 | self.encoder = BertEncoder(config) 16 | self.pooler = BertPooler(config) if add_pooling_layer else None 17 | 18 | # Initialize weights and apply final processing 19 | self.post_init() 20 | 21 | def get_input_embeddings(self): 22 | return self.embeddings 23 | 24 | 25 | class ThreeBert(BertPreTrainedModel): 26 | def __init__(self, config, train_from_scratch=False): 27 | super(ThreeBert, self).__init__(config) 28 | 29 | self.bert_model_config = config 30 | self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") 31 | 32 | if not train_from_scratch: 33 | self.bert_model1 = BertModel.from_pretrained("bert-base-chinese", config=self.bert_model_config) 34 | self.bert_model2 = 
BertModel2.from_pretrained("bert-base-chinese", config=self.bert_model_config) 35 | self.bert_model3 = BertModel.from_pretrained("bert-base-chinese", config=self.bert_model_config) 36 | else: 37 | self.bert_model1 = BertModel(config=self.bert_model_config) 38 | self.bert_model2 = BertModel2(config=self.bert_model_config) 39 | self.bert_model3 = BertModel(config=self.bert_model_config) 40 | 41 | def get_history(self, history): 42 | 43 | history_posts = [] 44 | for i in range(history.size(0)): 45 | history_posts.append(self.history_post_list[str(int(history[i]))]) 46 | output = self.bert_tokenizer(history_posts, max_length=512, padding=True, truncation=True, return_tensors="pt") 47 | for index in output: 48 | output[index] = output[index].to(self.device) 49 | assert output[index].device != "cpu" 50 | return output 51 | 52 | def match(self, model, x, y, x_mask, y_mask): 53 | # Multi-hop Co-Attention 54 | # x: (batch_size, m, hidden_size) 55 | # y: (batch_size, n, hidden_size) 56 | # x_mask: (batch_size, m) 57 | # y_mask: (batch_size, n) 58 | assert x.dim() == 3 and y.dim() == 3 59 | assert x_mask.dim() == 2 and y_mask.dim() == 2 60 | assert x_mask.shape == x.shape[:2] and y_mask.shape == y.shape[:2] 61 | 62 | attn_mask = torch.bmm(x_mask.unsqueeze(-1), y_mask.unsqueeze(1)) # (batch_size, m, n) 63 | attn = torch.bmm(x, y.transpose(1,2)) # (batch_size, m, n) 64 | model.attn = attn 65 | model.attn_mask = attn_mask 66 | 67 | x_to_y = torch.softmax(attn * attn_mask + (-5e4) * (1-attn_mask), dim=2) # (batch_size, m, n) 68 | y_to_x = torch.softmax(attn * attn_mask + (-5e4) * (1-attn_mask), dim=1).transpose(1,2) # # (batch_size, n, m) 69 | 70 | # x_attended, y_attended = None, None # no hop-1 71 | x_attended = torch.bmm(x_to_y, y) # (batch_size, m, hidden_size) 72 | y_attended = torch.bmm(y_to_x, x) # (batch_size, n, hidden_size) 73 | 74 | # x_attended_2hop, y_attended_2hop = None, None # no hop-2 75 | y_attn = torch.bmm(y_to_x.mean(dim=1, keepdim=True), x_to_y) # (batch_size, 1, n) # true important attention over y 76 | x_attn = torch.bmm(x_to_y.mean(dim=1, keepdim=True), y_to_x) # (batch_size, 1, m) # true important attention over x 77 | 78 | # truly attended representation 79 | x_attended_2hop = torch.bmm(x_attn, x).squeeze(1) # (batch_size, hidden_size) 80 | y_attended_2hop = torch.bmm(y_attn, y).squeeze(1) # (batch_size, hidden_size) 81 | 82 | x_attended = x_attended, x_attended_2hop 83 | y_attended = y_attended, y_attended_2hop 84 | 85 | return x_attended, y_attended 86 | 87 | def aggregate(self, aggregation_method, x, x_mask): 88 | # x: (batch_size, seq_len, emb_size) 89 | # x_mask: (batch_size, seq_len) 90 | assert x.dim() == 3 and x_mask.dim() == 2 91 | assert x.shape[:2] == x_mask.shape 92 | # batch_size, seq_len, emb_size = x.shape 93 | 94 | if aggregation_method == "mean": 95 | return (x * x_mask.unsqueeze(-1)).sum(dim=1)/x_mask.sum(dim=-1, keepdim=True).clamp(min=1) # (batch_size, emb_size) 96 | 97 | if aggregation_method == "max": 98 | return x.masked_fill(x_mask.unsqueeze(-1)==0, -5e4).max(dim=1)[0] # (batch_size, emb_size) 99 | 100 | if aggregation_method == "mean_max": 101 | return torch.cat([(x * x_mask.unsqueeze(-1)).sum(dim=1)/x_mask.sum(dim=-1, keepdim=True).clamp(min=1), \ 102 | x.masked_fill(x_mask.unsqueeze(-1)==0, -5e4).max(dim=1)[0]], dim=-1) # (batch_size, 2*emb_size) 103 | 104 | 105 | def fuse(self, model, aggregation_method, batch_x_emb, batch_y_emb, batch_persona_emb, \ 106 | batch_x_mask, batch_y_mask, batch_persona_mask, batch_size, num_candidates): 107 | 108 | 
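        # fuse() scores every (context, response) pair in the flattened candidate batch:
        # it co-attends context/response and persona/response token states via match(),
        # pools the hop-1 states with aggregate(), concatenates them with the hop-2
        # vectors, and reduces each pair to a single dot-product score, reshaped to
        # (batch_size, num_candidates).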
batch_x_emb, batch_y_emb_context = self.match(model, batch_x_emb, batch_y_emb, batch_x_mask, batch_y_mask) 109 | # batch_x_emb: ((batch_size*num_candidates, m, emb_size), (batch_size*num_candidates, emb_size)) 110 | # batch_y_emb_context: (batch_size*num_candidates, n, emb_size), (batch_size*num_candidates, emb_size) 111 | 112 | # hop 2 results 113 | batch_x_emb_2hop = batch_x_emb[1] 114 | batch_y_emb_context_2hop = batch_y_emb_context[1] 115 | 116 | # mean_max aggregation for the 1st hop result 117 | batch_x_emb = self.aggregate(aggregation_method, batch_x_emb[0], batch_x_mask) # batch_x_emb: (batch_size*num_candidates, 2*emb_size) 118 | batch_y_emb_context = self.aggregate(aggregation_method, batch_y_emb_context[0], batch_y_mask) # batch_y_emb_context: (batch_size*num_candidates, 2*emb_size) 119 | 120 | if batch_persona_emb is not None: 121 | batch_persona_emb, batch_y_emb_persona = self.match(model, batch_persona_emb, batch_y_emb, batch_persona_mask, batch_y_mask) 122 | # batch_persona_emb: (batch_size*num_candidates, m, emb_size), (batch_size*num_candidates, emb_size) 123 | # batch_y_emb_persona: (batch_size*num_candidates, n, emb_size), (batch_size*num_candidates, emb_size) 124 | 125 | batch_persona_emb_2hop = batch_persona_emb[1] 126 | batch_y_emb_persona_2hop = batch_y_emb_persona[1] 127 | 128 | # # no hop-1 129 | # return torch.bmm(torch.cat([batch_x_emb_2hop, batch_persona_emb_2hop], dim=-1).unsqueeze(1), \ 130 | # torch.cat([batch_y_emb_context_2hop, batch_y_emb_persona_2hop], dim=-1)\ 131 | # .unsqueeze(-1)).reshape(batch_size, num_candidates) 132 | 133 | batch_persona_emb = self.aggregate(aggregation_method, batch_persona_emb[0], batch_persona_mask) # batch_persona_emb: (batch_size*num_candidates, 2*emb_size) 134 | batch_y_emb_persona = self.aggregate(aggregation_method, batch_y_emb_persona[0], batch_y_mask) # batch_y_emb_persona: (batch_size*num_candidates, 2*emb_size) 135 | 136 | # # no hop-2 137 | # return torch.bmm(torch.cat([batch_x_emb, batch_persona_emb], dim=-1).unsqueeze(1), \ 138 | # torch.cat([batch_y_emb_context, batch_y_emb_persona], dim=-1)\ 139 | # .unsqueeze(-1)).reshape(batch_size, num_candidates) 140 | return torch.bmm(torch.cat([batch_x_emb, batch_x_emb_2hop, batch_persona_emb, batch_persona_emb_2hop], dim=-1).unsqueeze(1), \ 141 | torch.cat([batch_y_emb_context, batch_y_emb_context_2hop, batch_y_emb_persona, batch_y_emb_persona_2hop], dim=-1)\ 142 | .unsqueeze(-1)).reshape(batch_size, num_candidates) 143 | else: 144 | return torch.bmm(torch.cat([batch_x_emb, batch_x_emb_2hop], dim=-1).unsqueeze(1), \ 145 | torch.cat([batch_y_emb_context, batch_y_emb_context_2hop], dim=-1)\ 146 | .unsqueeze(-1)).reshape(batch_size, num_candidates) 147 | 148 | def forward(self, ids, mask, token_type_ids, apply_interaction=True, aggregation_method="max"): 149 | # version 3, fussion 150 | # Context 151 | input_ids1 = ids[0] 152 | attention_mask1 = mask[0] 153 | token_type_ids1 = token_type_ids[0] 154 | # Response 155 | input_ids2 = ids[1] 156 | attention_mask2 = mask[1] 157 | token_type_ids2 = token_type_ids[1] 158 | 159 | if len(ids)==3: 160 | # print("begin history post") 161 | input_ids3 = ids[2] 162 | attention_mask3 = mask[2] 163 | token_type_ids3 = token_type_ids[2] 164 | # last_hidden_state, bert_output = self.bert_model1(input_ids=input_ids3, attention_mask=attention_mask3, token_type_ids=token_type_ids3) 165 | outputs = self.bert_model1(input_ids=input_ids3, attention_mask=attention_mask3, token_type_ids=token_type_ids3) 166 | 167 | outputs1 = self.bert_model2( 168 | 
input_ids=input_ids1, 169 | attention_mask=attention_mask1, 170 | token_type_ids=token_type_ids1, 171 | ) 172 | outputs2 = self.bert_model3( 173 | input_ids=input_ids2, 174 | attention_mask=attention_mask2, 175 | token_type_ids=token_type_ids2, 176 | ) 177 | 178 | if apply_interaction: 179 | 180 | attention_mask1 = attention_mask1.float() 181 | attention_mask2 = attention_mask2.float() 182 | attention_mask3 = attention_mask3.float() 183 | 184 | batch_size, sent_len, emb_size = outputs2[0].shape 185 | history_output = outputs[0].repeat_interleave(batch_size, dim=0) 186 | attention_mask3 = attention_mask3.repeat_interleave(batch_size, dim=0) 187 | 188 | context_output = outputs1[0].repeat_interleave(batch_size, dim=0) 189 | attention_mask1 = attention_mask1.repeat_interleave(batch_size, dim=0) 190 | 191 | response_output = outputs2[0].unsqueeze(0).repeat(batch_size, 1, 1, 1).reshape(-1, sent_len, emb_size) 192 | attention_mask2 = attention_mask2.unsqueeze(0).repeat(batch_size, 1, 1).reshape(-1, sent_len) 193 | 194 | logits = self.fuse(self.bert_model2, aggregation_method, \ 195 | context_output, response_output, history_output, attention_mask1, attention_mask2, attention_mask3, batch_size, batch_size) 196 | targets = torch.arange(batch_size, dtype=torch.long, device=self.device) 197 | # loss = F.cross_entropy(logits, targets) 198 | # num_ok = (targets.long() == logits.float().argmax(dim=1)).sum() 199 | # return loss, num_ok 200 | return logits, targets 201 | 202 | else: 203 | history_output = outputs[0].mean(dim=1) 204 | context_output = outputs1[0].mean(dim=1) 205 | response_output = outputs2[0].mean(dim=1) 206 | history_context_output = (history_output + context_output)/2 207 | 208 | return history_context_output, response_output 209 | 210 | else: 211 | outputs1 = self.bert_model2( 212 | input_ids=input_ids1, 213 | attention_mask=attention_mask1, 214 | token_type_ids=token_type_ids1, 215 | ) 216 | outputs2 = self.bert_model3( 217 | input_ids=input_ids2, 218 | attention_mask=attention_mask2, 219 | token_type_ids=token_type_ids2, 220 | ) 221 | if apply_interaction: 222 | attention_mask1 = attention_mask1.float() 223 | attention_mask2 = attention_mask2.float() 224 | 225 | attention_mask3 = None 226 | history_output = None 227 | 228 | batch_size, sent_len, emb_size = outputs2[0].shape 229 | 230 | context_output = outputs1[0].repeat_interleave(batch_size, dim=0) 231 | attention_mask1 = attention_mask1.repeat_interleave(batch_size, dim=0) 232 | 233 | response_output = outputs2[0].unsqueeze(0).repeat(batch_size, 1, 1, 1).reshape(-1, sent_len, emb_size) 234 | attention_mask2 = attention_mask2.unsqueeze(0).repeat(batch_size, 1, 1).reshape(-1, sent_len) 235 | 236 | logits = self.fuse(self.bert_model2, aggregation_method, \ 237 | context_output, response_output, history_output, attention_mask1, attention_mask2, attention_mask3, batch_size, batch_size) 238 | targets = torch.arange(batch_size, dtype=torch.long, device=self.device) 239 | 240 | return logits, targets 241 | else: 242 | return outputs1[0].mean(dim=1), outputs2[0].mean(dim=1) 243 | -------------------------------------------------------------------------------- /Tasks/AddresseeRecognition/src/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from transformers import (BertModel, BertPreTrainedModel, BertTokenizer) 5 | from transformers.models.bert.modeling_bert import (BertEmbeddings, BertEncoder, BertPooler) 6 | import torch.nn.functional as F 7 | 8 
| class SingleBert(BertPreTrainedModel): 9 | def __init__(self, config): 10 | super(SingleBert, self).__init__(config) 11 | self.bert_model_config = config 12 | self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") 13 | self.bert_model = BertModel.from_pretrained("bert-base-chinese", config=self.bert_model_config) 14 | self.fc = nn.Linear(config.hidden_size, 2) 15 | 16 | def forward(self, ids, mask, token_type_ids): 17 | 18 | outputs = self.bert_model( 19 | input_ids=ids, 20 | attention_mask=mask, 21 | token_type_ids=token_type_ids, 22 | ) 23 | fc_outputs = self.fc(outputs[1]) 24 | 25 | return fc_outputs 26 | 27 | class BertModel2(BertModel): 28 | def __init__(self, config, add_pooling_layer=True): 29 | super().__init__(config) 30 | self.config = config 31 | self.embeddings = BertEmbeddings(config) 32 | self.encoder = BertEncoder(config) 33 | self.pooler = BertPooler(config) if add_pooling_layer else None 34 | # Initialize weights and apply final processing 35 | self.post_init() 36 | 37 | def get_input_embeddings(self): 38 | return self.embeddings 39 | 40 | 41 | class ThreeBert(BertPreTrainedModel): 42 | def __init__(self, config, train_from_scratch=False): 43 | super(ThreeBert, self).__init__(config) 44 | 45 | self.bert_model_config = config # BertConfig.from_pretrained("bert-base-chinese") 46 | self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") 47 | if not train_from_scratch: 48 | self.bert_model1 = BertModel.from_pretrained("bert-base-chinese", config=self.bert_model_config) 49 | self.bert_model2 = BertModel2.from_pretrained("bert-base-chinese", config=self.bert_model_config) 50 | self.bert_model3 = BertModel.from_pretrained("bert-base-chinese", config=self.bert_model_config) 51 | else: 52 | self.bert_model1 = BertModel(config=self.bert_model_config) 53 | self.bert_model2 = BertModel2(config=self.bert_model_config) 54 | self.bert_model3 = BertModel(config=self.bert_model_config) 55 | 56 | def get_history(self, history): 57 | 58 | history_posts = [] 59 | for i in range(history.size(0)): 60 | history_posts.append(self.history_post_list[str(int(history[i]))]) 61 | output = self.bert_tokenizer(history_posts, max_length=512, padding=True, truncation=True, return_tensors="pt") 62 | for index in output: 63 | output[index] = output[index].to(self.device) 64 | assert output[index].device != "cpu" 65 | return output 66 | 67 | def match(self, model, x, y, x_mask, y_mask): 68 | # Multi-hop Co-Attention 69 | # x: (batch_size, m, hidden_size) 70 | # y: (batch_size, n, hidden_size) 71 | # x_mask: (batch_size, m) 72 | # y_mask: (batch_size, n) 73 | assert x.dim() == 3 and y.dim() == 3 74 | assert x_mask.dim() == 2 and y_mask.dim() == 2 75 | assert x_mask.shape == x.shape[:2] and y_mask.shape == y.shape[:2] 76 | 77 | attn_mask = torch.bmm(x_mask.unsqueeze(-1), y_mask.unsqueeze(1)) # (batch_size, m, n) 78 | attn = torch.bmm(x, y.transpose(1,2)) # (batch_size, m, n) 79 | model.attn = attn 80 | model.attn_mask = attn_mask 81 | 82 | x_to_y = torch.softmax(attn * attn_mask + (-5e4) * (1-attn_mask), dim=2) # (batch_size, m, n) 83 | y_to_x = torch.softmax(attn * attn_mask + (-5e4) * (1-attn_mask), dim=1).transpose(1,2) # # (batch_size, n, m) 84 | 85 | # x_attended, y_attended = None, None # no hop-1 86 | x_attended = torch.bmm(x_to_y, y) # (batch_size, m, hidden_size) 87 | y_attended = torch.bmm(y_to_x, x) # (batch_size, n, hidden_size) 88 | 89 | # x_attended_2hop, y_attended_2hop = None, None # no hop-2 90 | y_attn = torch.bmm(y_to_x.mean(dim=1, keepdim=True), x_to_y) 
# (batch_size, 1, n) # true important attention over y 91 | x_attn = torch.bmm(x_to_y.mean(dim=1, keepdim=True), y_to_x) # (batch_size, 1, m) # true important attention over x 92 | 93 | # truly attended representation 94 | x_attended_2hop = torch.bmm(x_attn, x).squeeze(1) # (batch_size, hidden_size) 95 | y_attended_2hop = torch.bmm(y_attn, y).squeeze(1) # (batch_size, hidden_size) 96 | 97 | x_attended = x_attended, x_attended_2hop 98 | y_attended = y_attended, y_attended_2hop 99 | 100 | return x_attended, y_attended 101 | 102 | def aggregate(self, aggregation_method, x, x_mask): 103 | # x: (batch_size, seq_len, emb_size) 104 | # x_mask: (batch_size, seq_len) 105 | assert x.dim() == 3 and x_mask.dim() == 2 106 | assert x.shape[:2] == x_mask.shape 107 | # batch_size, seq_len, emb_size = x.shape 108 | 109 | if aggregation_method == "mean": 110 | return (x * x_mask.unsqueeze(-1)).sum(dim=1)/x_mask.sum(dim=-1, keepdim=True).clamp(min=1) # (batch_size, emb_size) 111 | 112 | if aggregation_method == "max": 113 | return x.masked_fill(x_mask.unsqueeze(-1)==0, -5e4).max(dim=1)[0] # (batch_size, emb_size) 114 | 115 | if aggregation_method == "mean_max": 116 | return torch.cat([(x * x_mask.unsqueeze(-1)).sum(dim=1)/x_mask.sum(dim=-1, keepdim=True).clamp(min=1), \ 117 | x.masked_fill(x_mask.unsqueeze(-1)==0, -5e4).max(dim=1)[0]], dim=-1) # (batch_size, 2*emb_size) 118 | 119 | 120 | def fuse(self, model, aggregation_method, batch_x_emb, batch_y_emb, batch_persona_emb, \ 121 | batch_x_mask, batch_y_mask, batch_persona_mask, batch_size, num_candidates): 122 | 123 | batch_x_emb, batch_y_emb_context = self.match(model, batch_x_emb, batch_y_emb, batch_x_mask, batch_y_mask) 124 | # batch_x_emb: ((batch_size*num_candidates, m, emb_size), (batch_size*num_candidates, emb_size)) 125 | # batch_y_emb_context: (batch_size*num_candidates, n, emb_size), (batch_size*num_candidates, emb_size) 126 | 127 | # hop 2 results 128 | batch_x_emb_2hop = batch_x_emb[1] 129 | batch_y_emb_context_2hop = batch_y_emb_context[1] 130 | 131 | # mean_max aggregation for the 1st hop result 132 | batch_x_emb = self.aggregate(aggregation_method, batch_x_emb[0], batch_x_mask) # batch_x_emb: (batch_size*num_candidates, 2*emb_size) 133 | batch_y_emb_context = self.aggregate(aggregation_method, batch_y_emb_context[0], batch_y_mask) # batch_y_emb_context: (batch_size*num_candidates, 2*emb_size) 134 | 135 | if batch_persona_emb is not None: 136 | batch_persona_emb, batch_y_emb_persona = self.match(model, batch_persona_emb, batch_y_emb, batch_persona_mask, batch_y_mask) 137 | # batch_persona_emb: (batch_size*num_candidates, m, emb_size), (batch_size*num_candidates, emb_size) 138 | # batch_y_emb_persona: (batch_size*num_candidates, n, emb_size), (batch_size*num_candidates, emb_size) 139 | 140 | batch_persona_emb_2hop = batch_persona_emb[1] 141 | batch_y_emb_persona_2hop = batch_y_emb_persona[1] 142 | 143 | # # no hop-1 144 | # return torch.bmm(torch.cat([batch_x_emb_2hop, batch_persona_emb_2hop], dim=-1).unsqueeze(1), \ 145 | # torch.cat([batch_y_emb_context_2hop, batch_y_emb_persona_2hop], dim=-1)\ 146 | # .unsqueeze(-1)).reshape(batch_size, num_candidates) 147 | 148 | batch_persona_emb = self.aggregate(aggregation_method, batch_persona_emb[0], batch_persona_mask) # batch_persona_emb: (batch_size*num_candidates, 2*emb_size) 149 | batch_y_emb_persona = self.aggregate(aggregation_method, batch_y_emb_persona[0], batch_y_mask) # batch_y_emb_persona: (batch_size*num_candidates, 2*emb_size) 150 | 151 | # # no hop-2 152 | # return 
torch.bmm(torch.cat([batch_x_emb, batch_persona_emb], dim=-1).unsqueeze(1), \ 153 | # torch.cat([batch_y_emb_context, batch_y_emb_persona], dim=-1)\ 154 | # .unsqueeze(-1)).reshape(batch_size, num_candidates) 155 | return torch.bmm(torch.cat([batch_x_emb, batch_x_emb_2hop, batch_persona_emb, batch_persona_emb_2hop], dim=-1).unsqueeze(1), \ 156 | torch.cat([batch_y_emb_context, batch_y_emb_context_2hop, batch_y_emb_persona, batch_y_emb_persona_2hop], dim=-1)\ 157 | .unsqueeze(-1)).reshape(batch_size, num_candidates) 158 | else: 159 | return torch.bmm(torch.cat([batch_x_emb, batch_x_emb_2hop], dim=-1).unsqueeze(1), \ 160 | torch.cat([batch_y_emb_context, batch_y_emb_context_2hop], dim=-1)\ 161 | .unsqueeze(-1)).reshape(batch_size, num_candidates) 162 | 163 | 164 | def forward(self, ids, mask, token_type_ids, apply_interaction=True, aggregation_method="max"): 165 | # version 3, fussion 166 | # Streamer 167 | input_ids1 = ids[0] 168 | attention_mask1 = mask[0] 169 | token_type_ids1 = token_type_ids[0] 170 | # Comments: batch * candiate * length, batch_size = candidate 171 | batch_size, candiate_length, max_len = ids[1].size() 172 | 173 | # assert batch_size == candiate_length 174 | input_ids2 = ids[1].reshape(batch_size*candiate_length, max_len) 175 | attention_mask2 = mask[1].reshape(batch_size*candiate_length, max_len) 176 | token_type_ids2 = token_type_ids[1].reshape(batch_size*candiate_length, max_len) 177 | 178 | if len(ids)==3: 179 | # print("begin history post") 180 | input_ids3 = ids[2] 181 | attention_mask3 = mask[2] 182 | token_type_ids3 = token_type_ids[2] 183 | # last_hidden_state, bert_output = self.bert_model1(input_ids=input_ids3, attention_mask=attention_mask3, token_type_ids=token_type_ids3) 184 | outputs = self.bert_model1(input_ids=input_ids3, attention_mask=attention_mask3, token_type_ids=token_type_ids3) 185 | 186 | outputs1 = self.bert_model2( 187 | input_ids=input_ids1, 188 | attention_mask=attention_mask1, 189 | token_type_ids=token_type_ids1, 190 | ) 191 | outputs2 = self.bert_model3( 192 | input_ids=input_ids2, 193 | attention_mask=attention_mask2, 194 | token_type_ids=token_type_ids2, 195 | ) 196 | 197 | if apply_interaction: 198 | 199 | attention_mask1 = attention_mask1.float() 200 | attention_mask2 = attention_mask2.float() 201 | attention_mask3 = attention_mask3.float() 202 | 203 | batch_size, sent_len, emb_size = outputs1[0].shape 204 | history_output = outputs[0].repeat_interleave(candiate_length, dim=0) 205 | attention_mask3 = attention_mask3.repeat_interleave(candiate_length, dim=0) 206 | 207 | context_output = outputs1[0].repeat_interleave(candiate_length, dim=0) 208 | attention_mask1 = attention_mask1.repeat_interleave(candiate_length, dim=0) 209 | 210 | response_output = outputs2[0] 211 | attention_mask2 = attention_mask2 212 | 213 | logits = self.fuse(self.bert_model2, aggregation_method, \ 214 | context_output, response_output, history_output, attention_mask1, attention_mask2, attention_mask3, batch_size, candiate_length) 215 | targets = torch.tensor([candiate_length-1 for i in range(batch_size)], dtype=torch.long, device=self.device) 216 | 217 | return logits, targets 218 | 219 | 220 | else: 221 | outputs1 = self.bert_model2( 222 | input_ids=input_ids1, 223 | attention_mask=attention_mask1, 224 | token_type_ids=token_type_ids1, 225 | ) 226 | outputs2 = self.bert_model3( 227 | input_ids=input_ids2, 228 | attention_mask=attention_mask2, 229 | token_type_ids=token_type_ids2, 230 | ) 231 | if apply_interaction: 232 | attention_mask1 = attention_mask1.float() 
233 | attention_mask2 = attention_mask2.float() 234 | 235 | attention_mask3 = None 236 | history_output = None 237 | 238 | batch_size, sent_len, emb_size = outputs1[0].shape 239 | 240 | context_output = outputs1[0].repeat_interleave(candiate_length, dim=0) 241 | attention_mask1 = attention_mask1.repeat_interleave(candiate_length, dim=0) 242 | 243 | response_output = outputs2[0] 244 | attention_mask2 = attention_mask2 245 | 246 | logits = self.fuse(self.bert_model2, aggregation_method, \ 247 | context_output, response_output, history_output, attention_mask1, attention_mask2, attention_mask3, batch_size, candiate_length) 248 | targets = torch.tensor([candiate_length-1 for i in range(batch_size)], dtype=torch.long, device=self.device) 249 | 250 | return logits, targets 251 | 252 | -------------------------------------------------------------------------------- /Tasks/Generation/src/trainer/seq2seq_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Any, Dict, List, Optional, Tuple, Union 15 | import collections 16 | 17 | from torch.utils.data.dataset import Dataset 18 | from torch.utils.data.dataloader import DataLoader 19 | 20 | import torch 21 | from packaging import version 22 | 23 | from torch import nn 24 | 25 | if version.parse(torch.__version__) >= version.parse("1.6"): 26 | from torch.cuda.amp import autocast 27 | 28 | from transformers import Seq2SeqTrainer 29 | from transformers.trainer_utils import PredictionOutput 30 | from transformers.utils import logging 31 | from transformers.trainer_pt_utils import DistributedTensorGatherer, SequentialDistributedSampler, nested_concat 32 | from transformers.file_utils import is_torch_tpu_available 33 | from transformers.trainer_utils import EvalPrediction, denumpify_detensorize, speed_metrics 34 | 35 | import time 36 | 37 | logger = logging.get_logger(__name__) 38 | 39 | if is_torch_tpu_available(): 40 | import torch_xla.core.xla_model as xm 41 | import torch_xla.debug.metrics as met 42 | import torch_xla.distributed.parallel_loader as pl 43 | 44 | class Seq2SeqTrainerNew(Seq2SeqTrainer): 45 | 46 | def predict( 47 | self, 48 | test_dataset: Dataset, 49 | ignore_keys: Optional[List[str]] = None, 50 | metric_key_prefix: str = "eval", 51 | max_length: Optional[int] = None, 52 | num_beams: Optional[int] = None, 53 | ) -> PredictionOutput: 54 | 55 | self._max_length = max_length 56 | self._num_beams = num_beams 57 | 58 | # memory metrics - must set up as early as possible 59 | self._memory_tracker.start() 60 | 61 | if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized): 62 | raise ValueError("test_dataset must implement __len__") 63 | 64 | test_dataloader = self.get_test_dataloader(test_dataset) 65 | start_time = time.time() 66 | 67 | output, scores = self.prediction_loop( 68 | test_dataloader, 
description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix 69 | ) 70 | output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset))) 71 | 72 | self._memory_tracker.stop_and_update_metrics(output.metrics) 73 | 74 | return output, scores 75 | 76 | def prediction_step( 77 | self, 78 | model: nn.Module, 79 | inputs: Dict[str, Union[torch.Tensor, Any]], 80 | prediction_loss_only: bool, 81 | ignore_keys: Optional[List[str]] = None, 82 | output_scores: bool = True, 83 | return_dict_in_generate: bool = True, 84 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 85 | """ 86 | Perform an evaluation step on :obj:`model` using obj:`inputs`. 87 | 88 | Subclass and override to inject custom behavior. 89 | 90 | Args: 91 | model (:obj:`nn.Module`): 92 | The model to evaluate. 93 | inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): 94 | The inputs and targets of the model. 95 | 96 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 97 | argument :obj:`labels`. Check your model's documentation for all accepted arguments. 98 | prediction_loss_only (:obj:`bool`): 99 | Whether or not to return the loss only. 100 | 101 | Return: 102 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 103 | labels (each being optional). 104 | """ 105 | 106 | if not self.args.predict_with_generate or prediction_loss_only: 107 | return super().prediction_step( 108 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, 109 | ) 110 | 111 | has_labels = "labels" in inputs 112 | inputs = self._prepare_inputs(inputs) 113 | 114 | gen_kwargs = { 115 | "max_length": self._max_length if self._max_length is not None else self.model.config.max_length, 116 | "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, 117 | } 118 | 119 | generated_tokens_and_scores = self.model.generate( 120 | inputs["input_ids"], 121 | attention_mask=inputs["attention_mask"], 122 | **gen_kwargs, 123 | output_scores=output_scores, 124 | return_dict_in_generate=return_dict_in_generate, 125 | ) 126 | 127 | generated_tokens = generated_tokens_and_scores['sequences'] 128 | generated_scores = generated_tokens_and_scores['sequences_scores'] 129 | 130 | 131 | # in case the batch is shorter than max length, the output should be padded 132 | if generated_tokens.shape[-1] < gen_kwargs["max_length"]: 133 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 134 | 135 | with torch.no_grad(): 136 | 137 | outputs = model(**inputs) 138 | if has_labels: 139 | if self.label_smoother is not None: 140 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 141 | else: 142 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 143 | else: 144 | loss = None 145 | 146 | if self.args.prediction_loss_only: 147 | return (loss, None, None) 148 | 149 | labels = inputs["labels"] 150 | if labels.shape[-1] < gen_kwargs["max_length"]: 151 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 152 | 153 | return (loss, generated_tokens, labels, generated_scores) 154 | 155 | def _pad_tensors_to_max_len(self, tensor, max_length): 156 | if self.tokenizer is None: 157 | raise ValueError( 158 | f"Tensor need to be padded to `max_length={max_length}` but no tokenzier was passed when creating " 159 | "this `Trainer`. 
Make sure to create your `Trainer` with the appropriate tokenizer." 160 | ) 161 | # If PAD token is not defined at least EOS token has to be defined 162 | pad_token_id = ( 163 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 164 | ) 165 | 166 | padded_tensor = pad_token_id * torch.ones( 167 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 168 | ) 169 | padded_tensor[:, : tensor.shape[-1]] = tensor 170 | return padded_tensor 171 | 172 | 173 | def prediction_loop( 174 | self, 175 | dataloader: DataLoader, 176 | description: str, 177 | prediction_loss_only: Optional[bool] = None, 178 | ignore_keys: Optional[List[str]] = None, 179 | metric_key_prefix: str = "eval", 180 | ) -> PredictionOutput: 181 | """ 182 | Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. 183 | 184 | Works both with or without labels. 185 | """ 186 | if not isinstance(dataloader.dataset, collections.abc.Sized): 187 | raise ValueError("dataset must implement __len__") 188 | prediction_loss_only = ( 189 | prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only 190 | ) 191 | 192 | if self.args.deepspeed and not self.args.do_train: 193 | # no harm, but flagging to the user that deepspeed config is ignored for eval 194 | # flagging only for when --do_train wasn't passed as only then it's redundant 195 | logger.info("Detected the deepspeed argument but it will not be used for evaluation") 196 | 197 | model = self._wrap_model(self.model, training=False) 198 | 199 | # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while 200 | # ``train`` is running, half it first and then put on device 201 | if not self.is_in_train and self.args.fp16_full_eval: 202 | model = model.half().to(self.args.device) 203 | 204 | batch_size = dataloader.batch_size 205 | num_examples = self.num_examples(dataloader) 206 | logger.info("***** Running %s *****", description) 207 | logger.info(" Num examples = %d", num_examples) 208 | logger.info(" Batch size = %d", batch_size) 209 | losses_host: torch.Tensor = None 210 | preds_host: Union[torch.Tensor, List[torch.Tensor]] = None 211 | labels_host: Union[torch.Tensor, List[torch.Tensor]] = None 212 | 213 | world_size = max(1, self.args.world_size) 214 | 215 | eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) 216 | if not prediction_loss_only: 217 | # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass 218 | # a batch size to the sampler) 219 | make_multiple_of = None 220 | if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): 221 | make_multiple_of = dataloader.sampler.batch_size 222 | preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) 223 | labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) 224 | 225 | model.eval() 226 | 227 | if is_torch_tpu_available(): 228 | dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) 229 | 230 | if self.args.past_index >= 0: 231 | self._past = None 232 | 233 | self.callback_handler.eval_dataloader = dataloader 234 | 235 | for step, inputs in enumerate(dataloader): 236 | loss, logits, labels, scores = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) 237 | 
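            # Keep this batch's loss/predictions/labels on the host; they are flushed
            # into the gatherers (and moved off the GPU) every eval_accumulation_steps.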
if loss is not None: 238 | losses = loss.repeat(batch_size) 239 | losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) 240 | if logits is not None: 241 | preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) 242 | if labels is not None: 243 | labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) 244 | self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) 245 | 246 | # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. 247 | if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: 248 | eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) 249 | if not prediction_loss_only: 250 | preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) 251 | labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) 252 | 253 | # Set back to None to begin a new accumulation 254 | losses_host, preds_host, labels_host = None, None, None 255 | 256 | if self.args.past_index and hasattr(self, "_past"): 257 | # Clean the state at the end of the evaluation loop 258 | delattr(self, "_past") 259 | 260 | # Gather all remaining tensors and put them back on the CPU 261 | eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) 262 | if not prediction_loss_only: 263 | preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) 264 | labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) 265 | 266 | eval_loss = eval_losses_gatherer.finalize() 267 | preds = preds_gatherer.finalize() if not prediction_loss_only else None 268 | label_ids = labels_gatherer.finalize() if not prediction_loss_only else None 269 | 270 | if self.compute_metrics is not None and preds is not None and label_ids is not None: 271 | metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) 272 | else: 273 | metrics = {} 274 | 275 | # To be JSON-serializable, we need to remove numpy types or zero-d tensors 276 | metrics = denumpify_detensorize(metrics) 277 | 278 | if eval_loss is not None: 279 | metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() 280 | 281 | # Prefix all keys with metric_key_prefix + '_' 282 | for key in list(metrics.keys()): 283 | if not key.startswith(f"{metric_key_prefix}_"): 284 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 285 | 286 | return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics), scores -------------------------------------------------------------------------------- /Tasks/Generation/src/train.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import argparse 4 | import logging 5 | import os 6 | import numpy as np 7 | from metrics import cal_bleu, cal_dist, cal_rouge, cal_ppl 8 | 9 | import transformers 10 | from transformers import (BertTokenizer, BartForConditionalGeneration, \ 11 | HfArgumentParser, DataCollatorForSeq2Seq,Seq2SeqTrainer, Seq2SeqTrainingArguments) 12 | from transformers.trainer_utils import is_main_process 13 | 14 | from arguments import DataTrainingArguments, ModelArguments 15 | from utils import set_seed, postprocess_text, load_pk 16 | import re 17 | 18 | def preprocess_function(examples): 19 | 20 | queries = examples['query'] 21 | reponses = 
examples['response']
22 |     # [CLS] [SEP]
23 |     model_inputs = tokenizer(queries, max_length=data_args.max_source_length, padding=padding, truncation=True)
24 |     with tokenizer.as_target_tokenizer():
25 |         labels = tokenizer(reponses, max_length=max_target_length, padding=padding, truncation=True)
26 |     model_inputs["labels"] = labels["input_ids"]
27 | 
28 |     return model_inputs
29 | 
30 | def freeze_params(model):
31 |     """Set requires_grad=False for each of model.parameters()"""
32 |     for name, para in model.named_parameters():
33 |         # print(name, para)
34 |         if name != "persona_embdding.weight":
35 |             para.requires_grad = False
36 | 
37 | def compute_metrics(eval_preds):
38 | 
39 |     preds, labels = eval_preds
40 |     if isinstance(preds, tuple):
41 |         preds = preds[0]
42 |     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
43 |     if data_args.ignore_pad_token_for_loss:
44 |         # Replace -100 in the labels as we can't decode them.
45 |         labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
46 |     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
47 | 
48 |     # Some simple post-processing
49 |     decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
50 |     # print("decoded_preds is ", decoded_preds)
51 |     # print("decoded_labels is ", decoded_labels)
52 |     dist_score = cal_dist(decoded_preds)
53 |     bleu_score = cal_bleu(decoded_preds, decoded_labels)
54 |     rouge_score = cal_rouge(decoded_preds, decoded_labels)
55 | 
56 |     result = rouge_score
57 |     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
58 |     result["gen_len"] = np.mean(prediction_lens)
59 |     result.update(dist_score)
60 |     result.update(bleu_score)
61 |     result = {k: round(v, 4) for k, v in result.items()}
62 |     return result
63 | 
64 | def de_prefix(sentence):
65 |     # For sentences output by the chit-chat module, remove the prefix that repeats the user's query
66 |     pattern = r',|\.|/|;|\'|`|\[|\]|<|>|\?|:|"|\{|\}|\~|!|@|#|\$|%|\^|&|\(|\)|-|=|\_|\+|,|?|。|、|;|‘|’|【|】|·|!| |…|(|)'
67 |     flag = True
68 |     for idx in re.finditer(pattern, sentence):
69 |         idx = idx.span()[1]
70 |         temp = sentence[idx:]
71 |         flag = False
72 |         break
73 |     if flag:
74 |         return sentence
75 |     else:
76 |         if idx >= len(sentence)-1:
77 |             return sentence
78 |         else:
79 |             return temp
80 | 
81 | def eval_prefix_result(generate_corpus, reference_corpus):
82 | 
83 |     generate_corpus = [de_prefix(gen) for gen in generate_corpus]
84 |     reference_corpus = [de_prefix(ref) for ref in reference_corpus]
85 |     # print(generate_corpus)
86 |     # print(reference_corpus)
87 |     generate_corpus = [" ".join([s for s in gen]) for gen in generate_corpus]
88 |     reference_corpus = [" ".join([s for s in ref]) for ref in reference_corpus]
89 | 
90 |     results = {}
91 |     bleu_result = cal_bleu(generate_corpus, reference_corpus)
92 |     dist_result = cal_dist(generate_corpus, reference_corpus)
93 |     rouge_result = cal_rouge(generate_corpus, reference_corpus)
94 | 
95 |     results.update(bleu_result)
96 |     results.update(dist_result)
97 |     results.update(rouge_result)
98 | 
99 |     return results
100 | 
101 | def eval_keep_result(generate_corpus, reference_corpus):
102 | 
103 |     generate_corpus = [" ".join([s for s in gen]) for gen in generate_corpus]
104 |     reference_corpus = [" ".join([s for s in ref]) for ref in reference_corpus]
105 | 
106 |     results = {}
107 |     bleu_result = cal_bleu(generate_corpus, reference_corpus)
108 |     dist_result = cal_dist(generate_corpus, reference_corpus)
109 |     rouge_result = cal_rouge(generate_corpus, reference_corpus)
110 | 
111 |     results.update(bleu_result)
112 | 
results.update(dist_result) 113 | results.update(rouge_result) 114 | 115 | return results 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument("--model_path", default="fnlp/bart-base-chinese", type=str) 120 | parser.add_argument("--output_dir", default="./outputs", type=str) 121 | parser.add_argument("--lr", default=2e-5, type=float) 122 | parser.add_argument("--with_persona", action='store_true', default=False) 123 | parser.add_argument("--persona_embedding", action='store_true', default=False) 124 | parser.add_argument("--persona_traits", action='store_true', default=False) 125 | parser.add_argument("--cls_embedding", action='store_true', default=False) 126 | parser.add_argument("--history_post", action='store_true', default=False) 127 | parser.add_argument("--batch_size", default='96', type=str) 128 | parser.add_argument("--epoch", default='30', type=str) 129 | parser.add_argument("--data_dir", default="./dataset", type=str) 130 | parser.add_argument("--do_train", action='store_true', default=False) 131 | parser.add_argument("--do_eval", action='store_true', default=False) 132 | parser.add_argument("--do_predict", action='store_true', default=False) 133 | parser.add_argument("--freeze_plm", action='store_true', default=False) 134 | 135 | args = parser.parse_args() 136 | arg_dict=args.__dict__ 137 | 138 | logger = logging.getLogger(__name__) 139 | 140 | args=[ 141 | '--model_name_or_path',arg_dict['model_path'], 142 | '--do_train={}'.format(arg_dict['do_train']), 143 | '--do_eval={}'.format(arg_dict['do_eval']), 144 | '--do_predict={}'.format(arg_dict['do_predict']), 145 | '--train_file', os.path.join(arg_dict['data_dir'], "train_data.pk"), 146 | '--validation_file', os.path.join(arg_dict['data_dir'],"dev_data.pk"), 147 | '--test_file', os.path.join(arg_dict['data_dir'],"test_data.pk"), 148 | '--output_dir', arg_dict["output_dir"], 149 | '--preprocessing_num_workers=3', 150 | '--logging_steps=100', 151 | '--max_train_samples=400000', 152 | '--max_val_samples=10000', 153 | '--dataloader_num_workers=3', 154 | '--per_device_train_batch_size', arg_dict['batch_size'], 155 | '--per_device_eval_batch_size', arg_dict['batch_size'], 156 | '--overwrite_output_dir', 157 | '--max_source_length=64', 158 | '--val_max_target_length='+'64', 159 | '--predict_with_generate=1', 160 | '--seed', str(1000*1), 161 | '--num_train_epochs', arg_dict['epoch'], 162 | '--save_strategy','epoch', 163 | '--save_total_limit', '10', 164 | '--evaluation_strategy', 'epoch', 165 | '--learning_rate', str(arg_dict['lr']), 166 | ] 167 | 168 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 169 | model_args, data_args, training_args = parser.parse_args_into_dataclasses(args) 170 | # model_args: ModelArguments; data_args: DataTrainingArguments; training_args: Seq2SeqTrainingArguments 171 | set_seed(training_args.seed) 172 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1' 173 | 174 | datasets = {} 175 | data_files = {} 176 | if data_args.train_file is not None: 177 | data_files["train"] = data_args.train_file 178 | if data_args.validation_file is not None: 179 | data_files["validation"] = data_args.validation_file 180 | if data_args.test_file is not None: 181 | data_files["test"] = data_args.test_file 182 | 183 | for key in data_files: 184 | 185 | print("load no persona!") 186 | datasets[key] = load_pk(data_files[key]) 187 | 188 | logging.basicConfig( 189 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 190 | 
datefmt="%m/%d/%Y %H:%M:%S", 191 | handlers=[logging.StreamHandler(sys.stdout)], 192 | ) 193 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 194 | 195 | # Log on each process the small summary: 196 | logger.warning( 197 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 198 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 199 | ) 200 | # Set the verbosity to info of the Transformers logger (on main process only): 201 | if is_main_process(training_args.local_rank): 202 | transformers.utils.logging.set_verbosity_info() 203 | logger.info("Training/evaluation parameters %s", training_args) 204 | tokenizer = BertTokenizer.from_pretrained(model_args.model_name_or_path) 205 | model = BartForConditionalGeneration.from_pretrained(model_args.model_name_or_path) 206 | 207 | model.config.max_length = data_args.val_max_target_length 208 | 209 | if arg_dict["freeze_plm"]: 210 | print("Freeze the pretrained model.......") 211 | freeze_params(model) 212 | 213 | column_names = datasets["train"].column_names 214 | max_target_length = data_args.val_max_target_length 215 | padding = False 216 | 217 | if training_args.do_train: 218 | train_dataset = datasets["train"] 219 | 220 | if data_args.max_train_samples is not None: 221 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 222 | print("process train dataset......................") 223 | train_dataset = train_dataset.map( 224 | preprocess_function, 225 | batched=True, 226 | num_proc=data_args.preprocessing_num_workers, 227 | remove_columns=column_names, 228 | load_from_cache_file=not data_args.overwrite_cache, 229 | ) 230 | 231 | print("train dataset processed over") 232 | 233 | if training_args.do_eval: 234 | eval_dataset = datasets["validation"] 235 | if data_args.max_val_samples is not None: 236 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) 237 | 238 | eval_dataset = eval_dataset.map( 239 | preprocess_function, 240 | batched=True, 241 | # batch_size=training_args.per_device_eval_batch_size, 242 | num_proc=data_args.preprocessing_num_workers, 243 | remove_columns=column_names, 244 | load_from_cache_file=not data_args.overwrite_cache, 245 | ) 246 | 247 | if training_args.do_predict: 248 | test_dataset = datasets["test"] 249 | if data_args.max_predict_samples is not None: 250 | eval_dataset = test_dataset.select(range(data_args.max_predict_samples)) 251 | test_dataset = test_dataset.map( 252 | preprocess_function, 253 | batched=True, 254 | # batch_size=training_args.per_device_eval_batch_size, 255 | num_proc=data_args.preprocessing_num_workers, 256 | remove_columns=column_names, 257 | load_from_cache_file=not data_args.overwrite_cache, 258 | ) 259 | 260 | print("eval dataset processed over") 261 | 262 | # Data collator 263 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 264 | data_collator = DataCollatorForSeq2Seq( 265 | tokenizer, 266 | model=model, 267 | label_pad_token_id=label_pad_token_id, 268 | pad_to_multiple_of=8 if training_args.fp16 else None, 269 | ) 270 | 271 | print("Initialize our Trainer") 272 | trainer = Seq2SeqTrainer( 273 | model=model, 274 | args=training_args, 275 | train_dataset=train_dataset if training_args.do_train else None, 276 | eval_dataset=eval_dataset if training_args.do_eval else None, 277 | tokenizer=tokenizer, 278 | data_collator=data_collator, 279 | 
compute_metrics=compute_metrics if training_args.predict_with_generate else None, 280 | ) 281 | 282 | # Training 283 | if training_args.do_train: 284 | train_result = trainer.train() 285 | trainer.save_model() 286 | metrics = train_result.metrics 287 | max_train_samples = ( 288 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 289 | ) 290 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 291 | trainer.log_metrics("train", metrics) 292 | trainer.save_metrics("train", metrics) 293 | trainer.save_state() 294 | 295 | if training_args.do_eval: 296 | 297 | eval_dataloader = trainer.get_eval_dataloader(eval_dataset) # make list into tensor 298 | 299 | print("begin evaluating") 300 | eval_result = trainer.predict(eval_dataset, metric_key_prefix="predict") 301 | metrics = eval_result.metrics 302 | 303 | loss_list = [] 304 | for index, inputs in enumerate(eval_dataloader): 305 | 306 | loss, logits, labels = trainer.prediction_step(trainer.model, inputs, prediction_loss_only=True) 307 | loss_list.append(float(loss.cpu())) 308 | 309 | ppl_result = round(cal_ppl(loss_list),4) 310 | metrics.update({"ppl": ppl_result}) 311 | trainer.log_metrics("eval", metrics) 312 | trainer.save_metrics("eval", metrics) 313 | 314 | if training_args.do_predict: 315 | 316 | test_dataloader = trainer.get_eval_dataloader(test_dataset) # make list into tensor 317 | loss_list = [] 318 | for index, inputs in enumerate(test_dataloader): 319 | loss, logits, labels = trainer.prediction_step(trainer.model, inputs, prediction_loss_only=True) 320 | loss_list.append(float(loss.cpu())) 321 | 322 | ppl_result = round(cal_ppl(loss_list),4) 323 | test_result = trainer.predict(test_dataset, metric_key_prefix="predict") 324 | metrics = test_result.metrics 325 | metrics.update({"ppl": ppl_result}) 326 | 327 | trainer.log_metrics("test", metrics) 328 | trainer.save_metrics("test", metrics) -------------------------------------------------------------------------------- /Tasks/AddresseeRecognition/src/train_cobert.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import argparse 4 | import logging 5 | import os 6 | import numpy as np 7 | from transformers import BertTokenizer, BertConfig 8 | from transformers.modeling_utils import PreTrainedModel, unwrap_model 9 | 10 | from utils import set_seed, load_retrive_history_post, load_retrive_history_post_and_id, load_pk_persona, load_pk, compute_metrics, compute_metrics_from_logits 11 | from torch.utils.data import DataLoader 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch import optim, cuda 16 | from model import ThreeBert 17 | from dataloader import LiveDataset 18 | import torch.nn.functional as F 19 | from sklearn import metrics 20 | from tqdm import tqdm 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | class CosineContrastiveLoss(nn.Module): 24 | def __init__(self, margin=0.4): 25 | super(CosineContrastiveLoss, self).__init__() 26 | self.margin = margin 27 | 28 | def forward(self, output1, output2, label): 29 | 30 | cos_sim = F.cosine_similarity(output1, output2) 31 | 32 | loss_cos_con = torch.mean((1-label) * torch.div(torch.pow((1.0-cos_sim), 2), 4) + \ 33 | (label) * torch.pow(cos_sim * torch.lt(cos_sim, self.margin), 2)) 34 | return loss_cos_con 35 | 36 | class BatchCosineContrastiveLoss(nn.Module): 37 | def __init__(self, ): 38 | super(BatchCosineContrastiveLoss, self).__init__() 39 | 40 | def forward(self, output1, output2): 41 | 
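        # In-batch contrastive loss: the i-th (output1, output2) pair is the positive
        # and the other batch_size-1 responses serve as negatives, so cross-entropy
        # over the cosine-similarity matrix targets the diagonal (y_true = [0..batch_size-1]).
        # Note the hard-coded device="cuda" below assumes GPU execution.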
batch_size = output2.size(0) 42 | y_true = torch.arange(batch_size, dtype=torch.long, device=output1.device) 43 | 44 | cos_sim_matrix = F.cosine_similarity(output1.unsqueeze(1), output2.unsqueeze(0), dim=2) 45 | cos_sim_matrix = cos_sim_matrix - torch.eye(batch_size, device="cuda") * 1e-12 46 | 47 | loss = F.cross_entropy(cos_sim_matrix, y_true) 48 | 49 | return loss 50 | 51 | def DotProDuctLoss(batch_x_emb, batch_y_emb): 52 | """ 53 | if batch_x_emb.dim() == 2: 54 | # batch_x_emb: (batch_size, emb_size) 55 | # batch_y_emb: (batch_size, emb_size) 56 | 57 | if batch_x_emb.dim() == 3: 58 | # batch_x_emb: (batch_size, batch_size, emb_size), the 1st dim is along examples and the 2nd dim is along candidates 59 | # batch_y_emb: (batch_size, emb_size) 60 | """ 61 | batch_size = batch_x_emb.size(0) 62 | targets = torch.tensor([batch_size-1 for i in range(batch_size)], device=batch_x_emb.device) 63 | 64 | if batch_x_emb.dim() == 2: 65 | dot_products = batch_x_emb.mm(batch_y_emb.t()) 66 | elif batch_x_emb.dim() == 3: 67 | dot_products = torch.bmm(batch_x_emb, batch_y_emb.unsqueeze(0).repeat(batch_size, 1, 1).transpose(1,2))[:, targets, targets] # (batch_size, batch_size) 68 | 69 | # dot_products: [batch, batch] 70 | log_prob = F.log_softmax(dot_products, dim=1) 71 | loss = F.nll_loss(log_prob, targets) 72 | nb_ok = (log_prob.max(dim=1)[1] == targets).float().sum() 73 | 74 | return loss, nb_ok/batch_size 75 | 76 | 77 | class Similarity(nn.Module): 78 | def __init__(self,): 79 | super(Similarity, self).__init__() 80 | 81 | def forward(self, output1, output2): 82 | if len(output1.size()) == 1 and len(output2.size()) == 1: 83 | output1 = output1.unsqueeze(0) 84 | output2 = output2.unsqueeze(0) 85 | cos_sim = F.cosine_similarity(output1, output2) 86 | else: 87 | cos_sim = F.cosine_similarity(output1, output2) 88 | 89 | return float(cos_sim) 90 | 91 | def freeze_params(model): 92 | """Set requires_grad=False for each of model.parameters()""" 93 | for _, para in model.named_parameters(): 94 | para.requires_grad = False 95 | 96 | def validation(model, test_dataloader): 97 | model.eval() 98 | 99 | total_acc = [] 100 | total_loss = [] 101 | total_recall = [] 102 | total_MRR = [] 103 | 104 | with torch.no_grad(): 105 | for _, data in enumerate(tqdm(test_dataloader, desc='Evaluating')): 106 | ids, mask, token_type_ids = data['ids'], data['mask'], data['token_type_ids'] 107 | if args.history_post: 108 | ids = [ids[0].to(device, dtype = torch.long), ids[1].to(device, dtype = torch.long), ids[2].to(device, dtype = torch.long)] 109 | mask = [mask[0].to(device, dtype = torch.long), mask[1].to(device, dtype = torch.long), mask[2].to(device, dtype = torch.long)] 110 | token_type_ids = [token_type_ids[0].to(device, dtype = torch.long), token_type_ids[1].to(device, dtype = torch.long), token_type_ids[2].to(device, dtype = torch.long)] 111 | else: 112 | ids = [ids[0].to(device, dtype=torch.long), ids[1].to(device, dtype=torch.long)] 113 | mask = [mask[0].to(device, dtype=torch.long), mask[1].to(device, dtype=torch.long)] 114 | token_type_ids = [token_type_ids[0].to(device, dtype=torch.long), token_type_ids[1].to(device, dtype=torch.long)] 115 | 116 | if args.apply_interaction == True: 117 | logits, targets = model(ids, mask, token_type_ids, args.apply_interaction) 118 | loss = F.cross_entropy(logits, targets) 119 | acc = (targets.long() == logits.float().argmax(dim=1)).sum() / targets.size(0) 120 | test_recall, test_MRR = compute_metrics_from_logits(logits, targets) 121 | 122 | else: 123 | output1, output2 = model(ids, 
116 |             if args.apply_interaction:
117 |                 logits, targets = model(ids, mask, token_type_ids, args.apply_interaction)
118 |                 loss = F.cross_entropy(logits, targets)
119 |                 acc = (targets.long() == logits.float().argmax(dim=1)).sum() / targets.size(0)
120 |                 test_recall, test_MRR = compute_metrics_from_logits(logits, targets)
121 | 
122 |             else:
123 |                 output1, output2 = model(ids, mask, token_type_ids, args.apply_interaction)
124 | 
125 |                 loss, acc = DotProDuctLoss(output1, output2)
126 |                 test_recall, test_MRR = compute_metrics(output1, output2)
127 | 
128 |             total_loss.append(float(loss))
129 |             total_acc.append(float(acc))
130 |             total_recall.append(test_recall)
131 |             total_MRR.append(test_MRR)
132 | 
133 |     return np.mean(total_loss), np.mean(total_acc), np.mean(total_recall, axis=0), np.mean(total_MRR)
134 | 
135 | def save_model(args, output_dir, model, tokenizer=None, state_dict=None):
136 | 
137 |     os.makedirs(output_dir, exist_ok=True)
138 |     logger.info(f"Saving model checkpoint to {output_dir}")
139 | 
140 |     if not isinstance(model, PreTrainedModel):
141 |         if isinstance(unwrap_model(model), PreTrainedModel):
142 |             if state_dict is None:
143 |                 state_dict = model.state_dict()
144 |             unwrap_model(model).save_pretrained(output_dir, state_dict=state_dict)
145 |         else:
146 |             logger.info("The model is not a `PreTrainedModel`, only saving its state dict.")
147 |             if state_dict is None:
148 |                 state_dict = model.state_dict()
149 |             torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
150 |     else:
151 |         model.save_pretrained(output_dir, state_dict=model.state_dict())
152 |     if tokenizer is not None:
153 |         tokenizer.save_pretrained(output_dir)
154 |     # Good practice: save your training arguments together with the trained model
155 |     torch.save(args, os.path.join(output_dir, "training_args.bin"))
156 | 
157 | 
158 | if __name__ == "__main__":
159 |     parser = argparse.ArgumentParser()
160 |     parser.add_argument("--model_path", default="bert-base-chinese", type=str)
161 |     parser.add_argument("--load_model_path", default="", type=str)
162 |     parser.add_argument("--history_post_path", default="ProcessedData/text_profile_512.json", type=str)
163 |     parser.add_argument("--persona_id_path", default="Dataset/ProcessedData/basic_profile.json", type=str)
164 |     parser.add_argument("--output_dir", default="./outputs", type=str)
165 |     parser.add_argument("--writer_dir", default="./outputs/runs", type=str)
166 |     parser.add_argument("--lr", default=1e-5, type=float)
167 |     parser.add_argument("--history_post", action='store_true', default=False)
168 |     parser.add_argument("--add_id", action='store_true', default=False)
169 |     parser.add_argument("--apply_interaction", action='store_true', default=False)
170 |     parser.add_argument("--train_from_scratch", action='store_true', default=False)
171 |     parser.add_argument("--batch_size", default=10, type=int)
172 |     parser.add_argument("--max_length", default=64, type=int)
173 |     parser.add_argument("--epoch", default=30, type=int)
174 |     parser.add_argument("--logging_steps", default=100, type=int)
175 |     parser.add_argument("--max_train_samples", default=400000, type=int)
176 |     parser.add_argument("--max_val_samples", default=10000, type=int)
177 |     parser.add_argument("--seed", default=20, type=int)
178 |     parser.add_argument("--data_dir", default="./dataset", type=str)
179 |     parser.add_argument("--do_train", action='store_true', default=False)
180 |     parser.add_argument("--do_eval", action='store_true', default=False)
181 |     parser.add_argument("--do_test", action='store_true', default=False)
182 |     parser.add_argument("--freeze_plm", action='store_true', default=False)
183 | 
184 |     args = parser.parse_args()
185 | 
186 |     logger = logging.getLogger(__name__)
187 |     set_seed(args.seed)
188 |     os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
189 | 
190 |     datasets = {}
191 |     data_files = {}
192 | 
193 |     args.add_id = True  # NOTE: overrides the --add_id flag, so basic-profile IDs are always loaded
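    # Illustrative launch command assembled from the flags above (paths are
    # placeholders for your local data, not checked-in files):
    #   python train_cobert.py --do_train --do_eval --history_post \
    #       --data_dir ./dataset --output_dir ./outputs --writer_dir ./outputs/runs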
194 |     print(args)
195 |     if args.do_train:
196 |         writer = SummaryWriter(log_dir=args.writer_dir, flush_secs=120)
197 | 
198 |     if args.do_train:
199 |         data_files["train"] = os.path.join(args.data_dir, "train_data.pk")
200 |     if args.do_eval:
201 |         data_files["validation"] = os.path.join(args.data_dir, "dev_data.pk")
202 |     if args.do_test:
203 |         data_files["test"] = os.path.join(args.data_dir, "test_data.pk")
204 | 
205 |     for key in data_files:
206 |         if args.history_post and args.add_id:
207 |             print("load history_post and ID")
208 |             history_post_path = args.history_post_path
209 |             persona_id_path = args.persona_id_path
210 |             datasets[key] = load_retrive_history_post_and_id(data_files[key], history_post_path, persona_id_path)
211 | 
212 |         elif args.history_post and not args.add_id:
213 |             print("load history_post")
214 |             history_post_path = args.history_post_path
215 |             datasets[key] = load_retrive_history_post(data_files[key], history_post_path)
216 | 
217 |         elif args.add_id and not args.history_post:
218 |             print("load ID")
219 |             persona_id_path = args.persona_id_path
220 |             datasets[key] = load_pk_persona(data_files[key], persona_id_path)
221 |         else:
222 |             print("load no persona!")
223 |             datasets[key] = load_pk(data_files[key])
224 | 
225 |     logging.basicConfig(
226 |         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
227 |         datefmt="%m/%d/%Y %H:%M:%S",
228 |         handlers=[logging.StreamHandler(sys.stdout)],
229 |     )
230 | 
231 |     device = 'cuda' if cuda.is_available() else 'cpu'
232 |     config = BertConfig.from_pretrained(args.model_path)
233 |     tokenizer = BertTokenizer.from_pretrained(args.model_path)
234 |     if args.train_from_scratch:
235 |         model = ThreeBert(config, train_from_scratch=True).to(device)
236 |     else:
237 |         model = ThreeBert(config, train_from_scratch=False).to(device)
238 |     if args.load_model_path != "":
239 |         model.load_state_dict(torch.load(os.path.join(args.load_model_path, "pytorch_model.bin")))
240 |     criterion = BatchCosineContrastiveLoss()  # instantiated but not used in the loop below
241 |     optimizer = optim.Adam(model.parameters(), lr=args.lr)
242 | 
243 |     if args.freeze_plm:
244 |         print("Freezing the pretrained model...")
245 |         freeze_params(model)  # note: this freezes every parameter of the model, not only the PLM backbone
246 | 
247 |     if args.do_train:
248 |         train_dataset = datasets["train"]
249 |         if args.max_train_samples is not None:
250 |             train_dataset = train_dataset.select(range(args.max_train_samples))
251 | 
252 |         print("train dataset length is {}".format(len(train_dataset)))
253 |         training_set = LiveDataset(train_dataset, tokenizer, args.max_length)
254 |         print("LiveDataset length is {}".format(len(training_set)))
255 |         training_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=False, num_workers=4)  # note: training batches are not shuffled
256 | 
257 |         print("train dataloader length is {}".format(len(training_loader)))
258 | 
259 |     if args.do_eval:
260 |         eval_dataset = datasets["validation"]
261 |         if args.max_val_samples is not None:
262 |             eval_dataset = eval_dataset.select(range(args.max_val_samples))
263 |         testing_set = LiveDataset(eval_dataset, tokenizer, args.max_length)
264 |         testing_loader = DataLoader(testing_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
265 | 
266 |         print("eval dataset processed")
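    # Training loop sketch: for each of args.epoch epochs, every batch is moved
    # to the device, scored under the selected mode, and optimized with Adam;
    # the loss is logged to TensorBoard each step, and after each epoch the
    # model is validated and checkpointed to <output_dir>/epoch_<n>.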
267 |     # Training
268 |     if args.do_train:
269 |         print("Begin training")
270 |         train_step = -1
271 |         for epoch in range(args.epoch):
272 |             model.train()
273 |             for _, data in enumerate(tqdm(training_loader, desc='Training')):
274 |                 # note: the tqdm bar may not refresh until the next epoch starts, which makes it look slow
275 |                 train_step += 1
276 |                 ids, mask, token_type_ids = data['ids'], data['mask'], data['token_type_ids']
277 |                 if args.history_post:
278 |                     ids = [ids[0].to(device, dtype=torch.long), ids[1].to(device, dtype=torch.long), ids[2].to(device, dtype=torch.long)]
279 |                     mask = [mask[0].to(device, dtype=torch.long), mask[1].to(device, dtype=torch.long), mask[2].to(device, dtype=torch.long)]
280 |                     token_type_ids = [token_type_ids[0].to(device, dtype=torch.long), token_type_ids[1].to(device, dtype=torch.long), token_type_ids[2].to(device, dtype=torch.long)]
281 |                 else:
282 |                     ids = [ids[0].to(device, dtype=torch.long), ids[1].to(device, dtype=torch.long)]
283 |                     mask = [mask[0].to(device, dtype=torch.long), mask[1].to(device, dtype=torch.long)]
284 |                     token_type_ids = [token_type_ids[0].to(device, dtype=torch.long), token_type_ids[1].to(device, dtype=torch.long)]
285 |                 optimizer.zero_grad()
286 | 
287 |                 if args.apply_interaction:
288 |                     logits, targets = model(ids, mask, token_type_ids, args.apply_interaction)
289 |                     loss = F.cross_entropy(logits, targets)
290 |                     acc = (targets.long() == logits.float().argmax(dim=1)).sum() / targets.size(0)
291 |                     if train_step % 500 == 0:
292 |                         train_recall, train_MRR = compute_metrics_from_logits(logits, targets)
293 |                         print(f'Step:{train_step}, Epoch:{epoch}, Loss:{loss.item()}, batch_acc:{acc}, batch_recall:{train_recall}, batch_MRR:{train_MRR}')
294 |                 else:
295 |                     output1, output2 = model(ids, mask, token_type_ids, args.apply_interaction)
296 |                     loss, acc = DotProDuctLoss(output1, output2)
297 |                     if train_step % 500 == 0:
298 |                         train_recall, train_MRR = compute_metrics(output1, output2)
299 |                         print(f'Step:{train_step}, Epoch:{epoch}, Loss:{loss.item()}, batch_acc:{acc}, batch_recall:{train_recall}, batch_MRR:{train_MRR}')
300 |                 writer.add_scalar('Loss/train', loss.item(), train_step)
301 |                 loss.backward()
302 |                 optimizer.step()
303 |             test_loss, test_acc, test_recall, test_mrr = validation(model, testing_loader)
304 |             print(f'Test Epoch:{epoch}, Test loss:{test_loss}, Test accuracy:{test_acc}, Test recall:{test_recall}, Test MRR:{test_mrr}')
305 |             writer.add_scalar('Test loss', test_loss, epoch)
306 |             writer.add_scalar('Test accuracy', test_acc, epoch)
307 |             writer.add_scalar('Test MRR', test_mrr, epoch)
308 | 
309 |             save_model_path = os.path.join(args.output_dir, "epoch_{}".format(epoch))
310 |             save_model(args, save_model_path, model, tokenizer)
311 | 
312 |     if args.do_eval:
313 |         test_loss, test_acc, test_recall, test_mrr = validation(model, testing_loader)
314 |         print(f'Test checkpoint: {args.load_model_path}, Test loss:{test_loss}, Test accuracy:{test_acc}, Test recall:{test_recall}, Test MRR:{test_mrr}')
315 | 
--------------------------------------------------------------------------------
/Tasks/ResponseModeling/src/train.py:
--------------------------------------------------------------------------------
1 | 
2 | import sys
3 | import argparse
4 | import logging
5 | import os
6 | import numpy as np
7 | from transformers import BertTokenizer, BertConfig
8 | from transformers.modeling_utils import PreTrainedModel, unwrap_model
9 | 
10 | from utils import set_seed, load_retrive_history_post, load_retrive_history_post_and_id, load_pk_persona, load_pk, compute_metrics, compute_metrics_from_logits
11 | from torch.utils.data import DataLoader
12 | 
13 | import torch
14 | import torch.nn as nn
15 | from torch import optim, cuda
16 | from model import ThreeBert
17 | from dataloader import LiveDataset
18 | import torch.nn.functional as F
19 | from tqdm import tqdm
20 | from torch.utils.tensorboard import SummaryWriter
21 | 
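# preprocess_function tokenizes (query, response, streamer history-post) triples
# into three parallel sets of BERT inputs; it relies on the tokenizer created in
# the __main__ block below and is not called anywhere in this script's visible
# flow (LiveDataset performs the tokenization).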
22 | def preprocess_function(examples, max_target_length=64):
23 | 
24 |     queries = examples['query']
25 |     history = examples['history_post']
26 |     responses = examples['response']
27 |     # the tokenizer adds [CLS] and [SEP] itself
28 |     queries_inputs = tokenizer(queries, max_length=max_target_length, padding=False, truncation=True)
29 |     responses_inputs = tokenizer(responses, max_length=max_target_length, padding=False, truncation=True)
30 |     history_inputs = tokenizer(history, max_length=512, padding=False, truncation=True)
31 | 
32 |     ids1, mask1, token_type_ids1 = queries_inputs['input_ids'], queries_inputs['attention_mask'], queries_inputs['token_type_ids']
33 |     ids2, mask2, token_type_ids2 = responses_inputs['input_ids'], responses_inputs['attention_mask'], responses_inputs['token_type_ids']
34 |     ids3, mask3, token_type_ids3 = history_inputs['input_ids'], history_inputs['attention_mask'], history_inputs['token_type_ids']
35 | 
36 |     return {
37 |         'ids1': ids1,
38 |         'ids2': ids2,
39 |         'ids3': ids3,
40 |         'mask1': mask1,
41 |         'mask2': mask2,
42 |         'mask3': mask3,
43 |         'token_type_ids1': token_type_ids1,
44 |         'token_type_ids2': token_type_ids2,
45 |         'token_type_ids3': token_type_ids3,
46 |     }
47 | 
48 | class CosineContrastiveLoss(nn.Module):
49 |     def __init__(self, margin=0.4):
50 |         super(CosineContrastiveLoss, self).__init__()
51 |         self.margin = margin
52 | 
53 |     def forward(self, output1, output2, label):
54 | 
55 |         cos_sim = F.cosine_similarity(output1, output2)
56 |         loss_cos_con = torch.mean((1-label) * torch.div(torch.pow((1.0-cos_sim), 2), 4) + \
57 |             (label) * torch.pow(cos_sim * torch.lt(cos_sim, self.margin), 2))
58 |         return loss_cos_con
59 | 
60 | class BatchCosineContrastiveLoss(nn.Module):
61 |     def __init__(self):
62 |         super(BatchCosineContrastiveLoss, self).__init__()
63 | 
64 |     def forward(self, output1, output2):
65 |         batch_size = output2.size(0)
66 |         y_true = torch.arange(batch_size, dtype=torch.long, device=output1.device)
67 | 
68 |         cos_sim_matrix = F.cosine_similarity(output1.unsqueeze(1), output2.unsqueeze(0), dim=2)
69 |         cos_sim_matrix = cos_sim_matrix - torch.eye(batch_size, device=output1.device) * 1e-12  # keep on the inputs' device rather than hard-coding "cuda"
70 | 
71 |         loss = F.cross_entropy(cos_sim_matrix, y_true)
72 | 
73 |         return loss
74 | 
75 | def DotProDuctLoss(batch_x_emb, batch_y_emb):
76 |     """
77 |     if batch_x_emb.dim() == 2:
78 |         # batch_x_emb: (batch_size, emb_size)
79 |         # batch_y_emb: (batch_size, emb_size)
80 | 
81 |     if batch_x_emb.dim() == 3:
82 |         # batch_x_emb: (batch_size, batch_size, emb_size), the 1st dim is along examples and the 2nd dim is along candidates
83 |         # batch_y_emb: (batch_size, emb_size)
84 |     """
85 |     batch_size = batch_x_emb.size(0)
86 |     targets = torch.arange(batch_size, device=batch_x_emb.device)
87 | 
88 |     if batch_x_emb.dim() == 2:
89 |         dot_products = batch_x_emb.mm(batch_y_emb.t())
90 |     elif batch_x_emb.dim() == 3:
91 |         dot_products = torch.bmm(batch_x_emb, batch_y_emb.unsqueeze(0).repeat(batch_size, 1, 1).transpose(1,2))[:, targets, targets] # (batch_size, batch_size)
92 | 
93 |     # dot_products: [batch, batch]
94 |     log_prob = F.log_softmax(dot_products, dim=1)
95 |     loss = F.nll_loss(log_prob, targets)
96 |     nb_ok = (log_prob.max(dim=1)[1] == targets).float().sum()
97 | 
98 |     return loss, nb_ok/batch_size
99 | 
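# Worked example for DotProDuctLoss with batch_size = 2 (2-D case):
#   batch_x_emb = [[1., 0.], [0., 1.]], batch_y_emb = [[2., 0.], [0., 3.]]
#   dot_products = x @ y.T = [[2., 0.], [0., 3.]], targets = [0, 1]
#   each row's log-softmax peaks on the diagonal, so nb_ok = 2 and the
#   returned batch accuracy nb_ok / batch_size = 1.0.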
100 | 
101 | class Similarity(nn.Module):
102 |     def __init__(self):
103 |         super(Similarity, self).__init__()
104 | 
105 |     def forward(self, output1, output2):
106 |         if len(output1.size()) == 1 and len(output2.size()) == 1:
107 |             # add a batch dimension for 1-D inputs
108 |             output1 = output1.unsqueeze(0)
109 |             output2 = output2.unsqueeze(0)
110 | 
111 |         cos_sim = F.cosine_similarity(output1, output2)
112 | 
113 |         return float(cos_sim)
114 | 
115 | def freeze_params(model):
116 |     """Set requires_grad=False for each of model.parameters()"""
117 |     for _, para in model.named_parameters():
118 |         para.requires_grad = False
119 | 
120 | def validation(model, test_dataloader):
121 |     model.eval()
122 | 
123 |     total_acc = []
124 |     total_loss = []
125 |     total_recall = []
126 |     total_MRR = []
127 | 
128 |     with torch.no_grad():
129 |         for _, data in enumerate(tqdm(test_dataloader, desc='Evaluating')):
130 |             ids, mask, token_type_ids = data['ids'], data['mask'], data['token_type_ids']
131 | 
132 |             if args.history_post:
133 |                 ids = [ids[0].to(device, dtype=torch.long), ids[1].to(device, dtype=torch.long), ids[2].to(device, dtype=torch.long)]
134 |                 mask = [mask[0].to(device, dtype=torch.long), mask[1].to(device, dtype=torch.long), mask[2].to(device, dtype=torch.long)]
135 |                 token_type_ids = [token_type_ids[0].to(device, dtype=torch.long), token_type_ids[1].to(device, dtype=torch.long), token_type_ids[2].to(device, dtype=torch.long)]
136 |             else:
137 |                 ids = [ids[0].to(device, dtype=torch.long), ids[1].to(device, dtype=torch.long)]
138 |                 mask = [mask[0].to(device, dtype=torch.long), mask[1].to(device, dtype=torch.long)]
139 |                 token_type_ids = [token_type_ids[0].to(device, dtype=torch.long), token_type_ids[1].to(device, dtype=torch.long)]
140 | 
141 |             if args.apply_interaction:
142 |                 logits, targets = model(ids, mask, token_type_ids, args.apply_interaction)
143 |                 loss = F.cross_entropy(logits, targets)
144 |                 acc = (targets.long() == logits.float().argmax(dim=1)).sum() / targets.size(0)
145 |                 test_recall, test_MRR = compute_metrics_from_logits(logits, targets)
146 | 
147 |             else:
148 |                 output1, output2 = model(ids, mask, token_type_ids, args.apply_interaction)
149 |                 loss, acc = DotProDuctLoss(output1, output2)
150 |                 test_recall, test_MRR = compute_metrics(output1, output2)
151 | 
152 |             total_loss.append(float(loss))
153 |             total_acc.append(float(acc))
154 |             total_recall.append(test_recall)
155 |             total_MRR.append(test_MRR)
156 | 
157 |     return np.mean(total_loss), np.mean(total_acc), np.mean(total_recall, axis=0), np.mean(total_MRR)
158 | 
159 | def save_model(args, output_dir, model, tokenizer=None, state_dict=None):
160 | 
161 |     os.makedirs(output_dir, exist_ok=True)
162 |     logger.info(f"Saving model checkpoint to {output_dir}")
163 | 
164 |     if not isinstance(model, PreTrainedModel):
165 |         if isinstance(unwrap_model(model), PreTrainedModel):
166 |             if state_dict is None:
167 |                 state_dict = model.state_dict()
168 |             unwrap_model(model).save_pretrained(output_dir, state_dict=state_dict)
169 |         else:
170 |             logger.info("The model is not a `PreTrainedModel`, only saving its state dict.")
171 |             if state_dict is None:
172 |                 state_dict = model.state_dict()
173 |             torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
174 |     else:
175 |         model.save_pretrained(output_dir, state_dict=model.state_dict())
176 |     if tokenizer is not None:
177 |         tokenizer.save_pretrained(output_dir)
178 |     # Good practice: save your training arguments together with the trained model
179 |     torch.save(args, os.path.join(output_dir, "training_args.bin"))
180 | 
181 | 
182 | if __name__ == "__main__":
183 |     parser = argparse.ArgumentParser()
184 |     parser.add_argument("--model_path", default="bert-base-chinese", type=str)
185 |     parser.add_argument("--history_post_path", default="/mnt/user/gaojingsheng/project/chitchat/chitchat/history_post_512.json", type=str)
186 |     parser.add_argument("--load_model_path", default="", type=str)
187 |     parser.add_argument("--output_dir", default="./outputs", type=str)
188 |     parser.add_argument("--writer_dir", default="./runs", type=str)
189 |     parser.add_argument("--lr", default=1e-5, type=float)
190 |     parser.add_argument("--history_post", action='store_true', default=False)
191 |     parser.add_argument("--add_id", action='store_true', default=False)
192 |     parser.add_argument("--apply_interaction", action='store_true', default=False)
193 |     parser.add_argument("--train_from_scratch", action='store_true', default=False)
194 |     parser.add_argument("--batch_size", default=24, type=int)
195 |     parser.add_argument("--max_length", default=64, type=int)
196 |     parser.add_argument("--epoch", default=30, type=int)
197 |     parser.add_argument("--logging_steps", default=100, type=int)
198 |     parser.add_argument("--train_id_num", default=150, type=int)
199 |     parser.add_argument("--max_train_samples", default=4000, type=int)
200 |     parser.add_argument("--max_val_samples", default=100, type=int)
201 |     parser.add_argument("--seed", default=20, type=int)
202 |     parser.add_argument("--data_dir", default="./dataset", type=str)
203 |     parser.add_argument("--do_train", action='store_true', default=False)
204 |     parser.add_argument("--do_eval", action='store_true', default=False)
205 |     parser.add_argument("--do_test", action='store_true', default=False)
206 |     parser.add_argument("--do_predict", action='store_true', default=False)
207 |     parser.add_argument("--freeze_plm", action='store_true', default=False)
208 | 
209 |     args = parser.parse_args()
210 |     # arg_dict=args.__dict__
211 | 
212 |     logger = logging.getLogger(__name__)
213 |     set_seed(args.seed)
214 |     os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
215 | 
216 |     datasets = {}
217 |     data_files = {}
218 | 
219 |     # args.do_train = True
220 |     # args.do_eval = True
221 |     # args.history_post = True
222 |     args.add_id = True  # NOTE: overrides the --add_id flag
223 |     args.apply_interaction = True  # NOTE: overrides the --apply_interaction flag
224 | 
225 |     if args.do_train:
226 |         writer = SummaryWriter(log_dir=args.writer_dir, flush_secs=120)
227 | 
228 |     if args.do_train:
229 |         if args.train_id_num == 150:
230 |             data_files["train"] = os.path.join(args.data_dir, "train_data.pk")
231 |         elif args.train_id_num == 50:
232 |             data_files["train"] = os.path.join(args.data_dir, "train50_data.pk")
233 |         elif args.train_id_num == 15:
234 |             data_files["train"] = os.path.join(args.data_dir, "train15_data.pk")
235 |         else:
236 |             raise KeyError(f"unsupported --train_id_num: {args.train_id_num}")
237 | 
238 |     if args.do_eval:
239 |         if args.train_id_num == 15:
240 |             data_files["validation"] = os.path.join(args.data_dir, "dev15_data.pk")
241 |         else:
242 |             data_files["validation"] = os.path.join(args.data_dir, "dev_data.pk")
243 |     if args.do_test:
244 |         data_files["test"] = os.path.join(args.data_dir, "test_data.pk")
245 | 
246 | 
247 |     for key in data_files:
248 |         if args.history_post and args.add_id:
249 |             print("load history_post and ID")
250 |             persona_id_path = "./long_deprefix/personaId_list.json"  # NOTE: hard-coded persona-ID list path
251 |             datasets[key] = load_retrive_history_post_and_id(data_files[key], args.history_post_path, persona_id_path)
252 | 
253 |         elif args.history_post and not args.add_id:
254 |             print("load history_post")
255 |             datasets[key] = load_retrive_history_post(data_files[key], args.history_post_path)
256 | 
257 |         elif args.add_id and not args.history_post:
258 |             print("load ID")
259 |             persona_id_path = "./long_deprefix/personaId_list.json"  # NOTE: hard-coded persona-ID list path
260 |             datasets[key] = load_pk_persona(data_files[key], persona_id_path)
261 |         else:
262 |             print("load no persona!")
263 |             datasets[key] = load_pk(data_files[key])
264 | 
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 267 | datefmt="%m/%d/%Y %H:%M:%S", 268 | handlers=[logging.StreamHandler(sys.stdout)], 269 | ) 270 | 271 | device = 'cuda' if cuda.is_available() else 'cpu' 272 | config = BertConfig.from_pretrained(args.model_path) 273 | tokenizer = BertTokenizer.from_pretrained(args.model_path) 274 | if args.train_from_scratch: 275 | model = ThreeBert(config, train_from_scratch=True).to(device) 276 | else: 277 | model = ThreeBert(config, train_from_scratch=False).to(device) 278 | if args.load_model_path != "": 279 | model.load_state_dict(torch.load(os.path.join(args.load_model_path, "pytorch_model.bin"))) 280 | criterion = BatchCosineContrastiveLoss() 281 | optimizer = optim.Adam(model.parameters(), lr = args.lr ) 282 | 283 | if args.freeze_plm: 284 | print("Freeze the pretrained model.......") 285 | freeze_params(model) 286 | 287 | 288 | if args.do_train: 289 | train_dataset = datasets["train"] 290 | if args.max_train_samples is not None: 291 | train_dataset = train_dataset.select(range(args.max_train_samples)) 292 | 293 | print("train dataset length is {}".format(len(train_dataset))) 294 | training_set = LiveDataset(train_dataset, tokenizer, args.max_length) 295 | print("Dataset length is {}".format(len(training_set))) 296 | training_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=False, num_workers=4) 297 | 298 | print("train dataset processed over") 299 | print("train dataset length is {}".format(len(training_loader))) 300 | 301 | if args.do_eval: 302 | eval_dataset = datasets["validation"] 303 | if args.max_val_samples is not None: 304 | eval_dataset = eval_dataset.select(range(args.max_val_samples)) 305 | 306 | testing_set = LiveDataset(eval_dataset, tokenizer, args.max_length) 307 | testing_loader = DataLoader(testing_set, batch_size=args.batch_size, shuffle=False, num_workers=4) 308 | 309 | print("eval dataset processed over") 310 | 311 | # Training 312 | if args.do_train: 313 | print("Begin training") 314 | train_step = -1 315 | for epoch in range(args.epoch): 316 | model.train() 317 | for _, data in enumerate(tqdm(training_loader, desc='Processing')): 318 | 319 | train_step += 1 320 | ids, mask, token_type_ids= data['ids'], data['mask'], data['token_type_ids'] 321 | if args.history_post: 322 | ids = [ids[0].to(device, dtype=torch.long), ids[1].to(device, dtype=torch.long), ids[2].to(device, dtype=torch.long)] 323 | mask = [mask[0].to(device, dtype=torch.long), mask[1].to(device, dtype=torch.long), mask[2].to(device, dtype=torch.long)] 324 | token_type_ids = [token_type_ids[0].to(device, dtype=torch.long), token_type_ids[1].to(device, dtype=torch.long), token_type_ids[2].to(device, dtype=torch.long)] 325 | else: 326 | ids = [ids[0].to(device, dtype=torch.long), ids[1].to(device, dtype=torch.long)] 327 | mask = [mask[0].to(device, dtype=torch.long), mask[1].to(device, dtype=torch.long)] 328 | token_type_ids = [token_type_ids[0].to(device, dtype=torch.long), token_type_ids[1].to(device, dtype=torch.long)] 329 | optimizer.zero_grad() 330 | 331 | if args.apply_interaction == True: 332 | logits, targets = model(ids, mask, token_type_ids, args.apply_interaction) 333 | loss = F.cross_entropy(logits, targets) 334 | acc = (targets.long() == logits.float().argmax(dim=1)).sum() / targets.size(0) 335 | if train_step % 500==0: 336 | train_recall, train_MRR = compute_metrics_from_logits(logits, targets) 337 | print(f'Step:{train_step}, Epoch:{epoch}, Loss:{loss.item()}, batch_acc:{acc}, batch_recall:{train_recall}, 
337 |                         print(f'Step:{train_step}, Epoch:{epoch}, Loss:{loss.item()}, batch_acc:{acc}, batch_recall:{train_recall}, batch_MRR:{train_MRR}')
338 |                 else:
339 |                     output1, output2 = model(ids, mask, token_type_ids, args.apply_interaction)
340 |                     loss, acc = DotProDuctLoss(output1, output2)
341 |                     if train_step % 500 == 0:
342 |                         train_recall, train_MRR = compute_metrics(output1, output2)
343 |                         print(f'Step:{train_step}, Epoch:{epoch}, Loss:{loss.item()}, batch_acc:{acc}, batch_recall:{train_recall}, batch_MRR:{train_MRR}')
344 |                 writer.add_scalar('Loss/train', loss.item(), train_step)
345 |                 loss.backward()
346 |                 optimizer.step()
347 |             test_loss, test_acc, test_recall, test_mrr = validation(model, testing_loader)
348 |             print(f'Test Epoch:{epoch}, Test loss:{test_loss}, Test accuracy:{test_acc}, Test recall:{test_recall}, Test MRR:{test_mrr}')
349 |             writer.add_scalar('Test loss', test_loss, epoch)
350 |             writer.add_scalar('Test accuracy', test_acc, epoch)
351 |             writer.add_scalar('Test MRR', test_mrr, epoch)
352 | 
353 |             save_model_path = os.path.join(args.output_dir, "epoch_{}".format(epoch))
354 |             save_model(args, save_model_path, model, tokenizer)
355 | 
356 |     if args.do_eval:
357 |         test_loss, test_acc, test_recall, test_mrr = validation(model, testing_loader)
358 |         print(f'Test checkpoint: {args.load_model_path}, Test loss:{test_loss}, Test accuracy:{test_acc}, Test recall:{test_recall}, Test MRR:{test_mrr}')
359 | 
360 | 
--------------------------------------------------------------------------------