├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── tianchi-multi-task-nlp.iml └── vcs.xml ├── README.md ├── __init__.py ├── bert_pretrain_model └── empty.txt ├── calculate_loss.py ├── data_generator.py ├── generate_data.py ├── inference.py ├── net.py ├── submission ├── Dockerfile ├── empty.txt └── run.sh ├── tianchi_datasets ├── OCEMOTION │ └── empty.txt ├── OCNLI │ └── empty.txt ├── TNEWS │ └── empty.txt └── empty.txt ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.tar 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Default ignored files 3 | /workspace.xml -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/tianchi-multi-task-nlp.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # 比赛全流程体验 2 | NLP中文预训练模型泛化能力挑战赛 3 | 4 | ## 训练环境介绍 5 | 6 | ``` 7 | 机器信息:NVIDIA-SMI 440.33.01 Driver Version: 440.33.01 CUDA Version: 10.2 8 | pytorch 版本 1.6.0 9 | 10 | 机器信息:NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 11 | pytorch 版本 1.7.1 12 | ``` 13 | 14 | python依赖: 15 | ``` 16 | pip install transformers 17 | ``` 18 | 19 | ## Docker安装(Ubuntu) 20 | 21 | 命令行安装: 22 | ``` 23 | sudo apt install docker.io 24 | ``` 25 | 26 | 验证: 27 | ``` 28 | docker info 29 | ``` 30 | ![](https://tianchi-public.oss-cn-hangzhou.aliyuncs.com/public/files/forum/160658242933332501606582428585.png) 31 | 32 | 33 | ## 运行过程 34 | 35 | 1. 下载BERT预训练权重:从 https://huggingface.co/bert-base-chinese/tree/main 下载 config.json、vocab.txt、pytorch_model.bin 三个文件,放进tianchi-multi-task-nlp/bert_pretrain_model文件夹下。 36 | 37 | 2. 下载比赛数据集,把三个数据集分别放进 `tianchi-multi-task-nlp/tianchi_datasets/数据集名字/` 下面: 38 | - OCEMOTION/total.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/OCEMOTION_train1128.csv 39 | - OCEMOTION/test.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/b/ocemotion_test_B.csv 40 | - TNEWS/total.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/TNEWS_train1128.csv 41 | - TNEWS/test.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/b/tnews_test_B.csv 42 | - OCNLI/total.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/OCNLI_train1128.csv 43 | - OCNLI/test.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/b/ocnli_test_B.csv 44 | 45 | 文件目录样例: 46 | ``` 47 | tianchi-multi-task-nlp/tianchi_datasets/OCNLI/total.csv 48 | tianchi-multi-task-nlp/tianchi_datasets/OCNLI/test.csv 49 | ``` 50 | 51 | 3. 划分训练集和验证集,默认取每个数据集的前3000条作为验证集,其余作为训练集,参数可以自己修改: 52 | ``` 53 | python ./generate_data.py 54 | ``` 55 | 4.
训练模型,一个epoch: 56 | ``` 57 | python ./train.py 58 | ``` 59 | 会保存验证集上平均f1分数最高的模型到 ./saved_best.pt 60 | 61 | 5. 用训练好的模型 ./saved_best.pt 生成结果: 62 | ``` 63 | python ./inference.py 64 | ``` 65 | 66 | 6. 打包预测结果(inference.py 会把预测文件写到 ./submission/ 目录下,请在该目录中执行下面的命令): 67 | ``` 68 | zip -r ./result.zip ./*.json 69 | ``` 70 | 7. 生成Docker并进行提交,参考:https://tianchi.aliyun.com/competition/entrance/231759/tab/174 71 | - 创建云端镜像仓库:https://cr.console.aliyun.com/ 72 | - 创建命名空间和镜像仓库; 73 | - 然后切换到`submission`文件夹下,执行下面命令; 74 | 75 | ``` 76 | # 用于登录的用户名为阿里云账号全名,密码为开通服务时设置的密码。 77 | sudo docker login --username=xxx@mail.com registry.cn-hangzhou.aliyuncs.com 78 | 79 | # 使用本地Dockerfile进行构建,使用创建仓库的【公网地址】 80 | # 如 docker build -t registry.cn-shenzhen.aliyuncs.com/test_for_tianchi/test_for_tianchi_submit:1.0 . 81 | docker build -t registry.cn-shenzhen.aliyuncs.com/test_for_tianchi/test_for_tianchi_submit:1.0 . 82 | ``` 83 | 84 | 输出构建过程: 85 | ``` 86 | Sending build context to Docker daemon 18.94kB 87 | Step 1/4 : FROM registry.cn-shanghai.aliyuncs.com/tcc-public/python:3 88 | ---> a4cc999cf2aa 89 | Step 2/4 : ADD . / 90 | ---> Using cache 91 | ---> b18fbb4425ef 92 | Step 3/4 : WORKDIR / 93 | ---> Using cache 94 | ---> f5fcc4ca5eca 95 | Step 4/4 : CMD ["sh", "run.sh"] 96 | ---> Using cache 97 | ---> ed0c4b0e545f 98 | Successfully built ed0c4b0e545f 99 | ``` 100 | 101 | ``` 102 | # ed0c4b0e545f 为镜像id,上面构建过程最后一行 103 | sudo docker tag ed0c4b0e545f registry.cn-shenzhen.aliyuncs.com/test_for_tianchi/test_for_tianchi_submit:1.0 104 | 105 | # 提交镜像到云端 106 | docker push registry.cn-shenzhen.aliyuncs.com/test_for_tianchi/test_for_tianchi_submit:1.0 107 | ``` 108 | 109 | 8. [比赛提交页面](https://tianchi.aliyun.com/competition/entrance/531865/submission/723),填写镜像路径+版本号,以及用户名和密码则可以完成提交。 110 | 111 | 112 | ## 比赛改进思路 113 | 114 | 1. 修改 calculate_loss.py 改变loss的计算方式,从平衡子任务难度以及各子任务类别样本不均匀入手; 115 | 2. 修改 net.py 改变模型的结构,加入attention层,或者其他层; 116 | 3. 使用 cleanlab 等工具对训练文本进行清洗; 117 | 4. 做文本数据增强,或者在预训练时候用其他数据集pretrain; 118 | 5. 对训练好的模型再在完整数据集(包括验证集和训练集)上用小的学习率训练一个epoch; 119 | 6.
调整bathSize和a_step,变更梯度累计的程度,当前是batchSize=16,a_step=16; 120 | 7. 用 chinese-roberta-wwm-ext 作为预训练模型; 121 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Dec 5 12:39:45 2020 5 | 6 | @author: luokai 7 | """ 8 | 9 | 10 | -------------------------------------------------------------------------------- /bert_pretrain_model/empty.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /calculate_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Dec 5 17:23:01 2020 5 | 6 | @author: luokai 7 | """ 8 | 9 | import torch 10 | from torch import nn 11 | import numpy as np 12 | from math import exp, log 13 | 14 | 15 | class Calculate_loss(): 16 | def __init__(self, label_dict, weighted=False, tnews_weights=None, ocnli_weights=None, ocemotion_weights=None): 17 | self.weighted = weighted 18 | if weighted: 19 | self.tnews_loss = nn.CrossEntropyLoss(tnews_weights) 20 | self.ocnli_loss = nn.CrossEntropyLoss(ocnli_weights) 21 | self.ocemotion_loss = nn.CrossEntropyLoss(ocemotion_weights) 22 | else: 23 | self.loss = nn.CrossEntropyLoss() 24 | self.label2idx = dict() 25 | self.idx2label = dict() 26 | for key in ['TNEWS', 'OCNLI', 'OCEMOTION']: 27 | self.label2idx[key] = dict() 28 | self.idx2label[key] = dict() 29 | for i, e in enumerate(label_dict[key]): 30 | self.label2idx[key][e] = i 31 | self.idx2label[key][i] = e 32 | 33 | def idxToLabel(self, key, idx): 34 | return self.idx2Label[key][idx] 35 | 36 | def labelToIdx(self, key, label): 37 | return self.label2idx[key][label] 38 | 39 | def compute(self, tnews_pred, ocnli_pred, ocemotion_pred, 
tnews_gold, ocnli_gold, ocemotion_gold): 40 | res = 0 41 | if tnews_pred != None: 42 | res += self.tnews_loss(tnews_pred, tnews_gold) if self.weighted else self.loss(tnews_pred, tnews_gold) 43 | if ocnli_pred != None: 44 | res += self.ocnli_loss(ocnli_pred, ocnli_gold) if self.weighted else self.loss(ocnli_pred, ocnli_gold) 45 | if ocemotion_pred != None: 46 | res += self.ocemotion_loss(ocemotion_pred, ocemotion_gold) if self.weighted else self.loss(ocemotion_pred, ocemotion_gold) 47 | return res 48 | 49 | def compute_dtp(self, tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold, tnews_kpi=0.1, ocnli_kpi=0.1, ocemotion_kpi=0.1, y=0.5): 50 | res = 0 51 | if tnews_pred != None: 52 | res += self.tnews_loss(tnews_pred, tnews_gold) * self._calculate_weight(tnews_kpi, y) if self.weighted else self.loss(tnews_pred, tnews_gold) * self._calculate_weight(tnews_kpi, y) 53 | if ocnli_pred != None: 54 | res += self.ocnli_loss(ocnli_pred, ocnli_gold) * self._calculate_weight(ocnli_kpi, y) if self.weighted else self.loss(ocnli_pred, ocnli_gold) * self._calculate_weight(ocnli_kpi, y) 55 | if ocemotion_pred != None: 56 | res += self.ocemotion_loss(ocemotion_pred, ocemotion_gold) * self._calculate_weight(ocemotion_kpi, y) if self.weighted else self.loss(ocemotion_pred, ocemotion_gold) * self._calculate_weight(ocemotion_kpi, y) 57 | return res 58 | 59 | 60 | def correct_cnt(self, tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold): 61 | good_nb = 0 62 | total_nb = 0 63 | if tnews_pred != None: 64 | tnews_val = torch.argmax(tnews_pred, axis=1) 65 | for i, e in enumerate(tnews_gold): 66 | if e == tnews_val[i]: 67 | good_nb += 1 68 | total_nb += 1 69 | if ocnli_pred != None: 70 | ocnli_val = torch.argmax(ocnli_pred, axis=1) 71 | for i, e in enumerate(ocnli_gold): 72 | if e == ocnli_val[i]: 73 | good_nb += 1 74 | total_nb += 1 75 | if ocemotion_pred != None: 76 | ocemotion_val = torch.argmax(ocemotion_pred, axis=1) 77 | for i, e in 
enumerate(ocemotion_gold): 78 | if e == ocemotion_val[i]: 79 | good_nb += 1 80 | total_nb += 1 81 | return good_nb, total_nb 82 | 83 | def correct_cnt_each(self, tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold): 84 | good_ocnli_nb = 0 85 | good_ocemotion_nb = 0 86 | good_tnews_nb = 0 87 | total_ocnli_nb = 0 88 | total_ocemotion_nb = 0 89 | total_tnews_nb = 0 90 | if tnews_pred != None: 91 | tnews_val = torch.argmax(tnews_pred, axis=1) 92 | for i, e in enumerate(tnews_gold): 93 | if e == tnews_val[i]: 94 | good_tnews_nb += 1 95 | total_tnews_nb += 1 96 | if ocnli_pred != None: 97 | ocnli_val = torch.argmax(ocnli_pred, axis=1) 98 | for i, e in enumerate(ocnli_gold): 99 | if e == ocnli_val[i]: 100 | good_ocnli_nb += 1 101 | total_ocnli_nb += 1 102 | if ocemotion_pred != None: 103 | ocemotion_val = torch.argmax(ocemotion_pred, axis=1) 104 | for i, e in enumerate(ocemotion_gold): 105 | if e == ocemotion_val[i]: 106 | good_ocemotion_nb += 1 107 | total_ocemotion_nb += 1 108 | return good_tnews_nb, good_ocnli_nb, good_ocemotion_nb, total_tnews_nb, total_ocnli_nb, total_ocemotion_nb 109 | 110 | def collect_pred_and_gold(self, pred, gold): 111 | if pred == None or gold == None: 112 | p, g = [], [] 113 | else: 114 | p, g = np.array(torch.argmax(pred, axis=1).cpu()).tolist(), np.array(gold.cpu()).tolist() 115 | return p, g 116 | 117 | def _calculate_weight(self, kpi, y): 118 | kpi = max(0.1, kpi) 119 | kpi = min(0.99, kpi) 120 | w = -1 * ((1 - kpi) ** y) * log(kpi) 121 | return w -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Dec 5 17:29:08 2020 5 | 6 | @author: luokai 7 | """ 8 | import random 9 | import torch 10 | from transformers import BertTokenizer 11 | 12 | class Data_generator(): 13 | def __init__(self, ocnli_dict, 
ocemotion_dict, tnews_dict, label_dict, device, tokenizer, max_len=512): 14 | self.max_len = max_len 15 | self.tokenizer = tokenizer 16 | self.device = device 17 | self.label2idx = dict() 18 | self.idx2label = dict() 19 | for key in ['TNEWS', 'OCNLI', 'OCEMOTION']: 20 | self.label2idx[key] = dict() 21 | self.idx2label[key] = dict() 22 | for i, e in enumerate(label_dict[key]): 23 | self.label2idx[key][e] = i 24 | self.idx2label[key][i] = e 25 | self.ocnli_data = dict() 26 | self.ocnli_data['s1'] = [] 27 | self.ocnli_data['s2'] = [] 28 | self.ocnli_data['label'] = [] 29 | for k, v in ocnli_dict.items(): 30 | self.ocnli_data['s1'].append(v['s1']) 31 | self.ocnli_data['s2'].append(v['s2']) 32 | self.ocnli_data['label'].append(self.label2idx['OCNLI'][v['label']]) 33 | self.ocemotion_data = dict() 34 | self.ocemotion_data['s1'] = [] 35 | self.ocemotion_data['label'] = [] 36 | for k, v in ocemotion_dict.items(): 37 | self.ocemotion_data['s1'].append(v['s1']) 38 | self.ocemotion_data['label'].append(self.label2idx['OCEMOTION'][v['label']]) 39 | self.tnews_data = dict() 40 | self.tnews_data['s1'] = [] 41 | self.tnews_data['label'] = [] 42 | for k, v in tnews_dict.items(): 43 | self.tnews_data['s1'].append(v['s1']) 44 | self.tnews_data['label'].append(self.label2idx['TNEWS'][v['label']]) 45 | self.reset() 46 | def reset(self): 47 | self.ocnli_ids = list(range(len(self.ocnli_data['s1']))) 48 | self.ocemotion_ids = list(range(len(self.ocemotion_data['s1']))) 49 | self.tnews_ids = list(range(len(self.tnews_data['s1']))) 50 | random.shuffle(self.ocnli_ids) 51 | random.shuffle(self.ocemotion_ids) 52 | random.shuffle(self.tnews_ids) 53 | def get_next_batch(self, batchSize=64): 54 | ocnli_len = len(self.ocnli_ids) 55 | ocemotion_len = len(self.ocemotion_ids) 56 | tnews_len = len(self.tnews_ids) 57 | total_len = ocnli_len + ocemotion_len + tnews_len 58 | if total_len == 0: 59 | return None 60 | elif total_len > batchSize: 61 | if ocnli_len > 0: 62 | ocnli_tmp_len = int((ocnli_len / 
total_len) * batchSize) 63 | ocnli_cur = self.ocnli_ids[:ocnli_tmp_len] 64 | self.ocnli_ids = self.ocnli_ids[ocnli_tmp_len:] 65 | if ocemotion_len > 0: 66 | ocemotion_tmp_len = int((ocemotion_len / total_len) * batchSize) 67 | ocemotion_cur = self.ocemotion_ids[:ocemotion_tmp_len] 68 | self.ocemotion_ids = self.ocemotion_ids[ocemotion_tmp_len:] 69 | if tnews_len > 0: 70 | tnews_tmp_len = batchSize - len(ocnli_cur) - len(ocemotion_cur) 71 | tnews_cur = self.tnews_ids[:tnews_tmp_len] 72 | self.tnews_ids = self.tnews_ids[tnews_tmp_len:] 73 | else: 74 | ocnli_cur = self.ocnli_ids 75 | self.ocnli_ids = [] 76 | ocemotion_cur = self.ocemotion_ids 77 | self.ocemotion_ids = [] 78 | tnews_cur = self.tnews_ids 79 | self.tnews_ids = [] 80 | max_len = self._get_max_total_len(ocnli_cur, ocemotion_cur, tnews_cur) 81 | input_ids = [] 82 | token_type_ids = [] 83 | attention_mask = [] 84 | ocnli_gold = None 85 | ocemotion_gold = None 86 | tnews_gold = None 87 | if len(ocnli_cur) > 0: 88 | flower = self.tokenizer([self.ocnli_data['s1'][idx] for idx in ocnli_cur], [self.ocnli_data['s2'][idx] for idx in ocnli_cur], add_special_tokens=True, max_length=max_len, padding='max_length', return_tensors='pt', truncation=True) 89 | input_ids.append(flower['input_ids']) 90 | token_type_ids.append(flower['token_type_ids']) 91 | attention_mask.append(flower['attention_mask']) 92 | ocnli_gold = torch.tensor([self.ocnli_data['label'][idx] for idx in ocnli_cur]).to(self.device) 93 | if len(ocemotion_cur) > 0: 94 | flower = self.tokenizer([self.ocemotion_data['s1'][idx] for idx in ocemotion_cur], add_special_tokens=True, max_length=max_len, padding='max_length', return_tensors='pt', truncation=True) 95 | input_ids.append(flower['input_ids']) 96 | token_type_ids.append(flower['token_type_ids']) 97 | attention_mask.append(flower['attention_mask']) 98 | ocemotion_gold = torch.tensor([self.ocemotion_data['label'][idx] for idx in ocemotion_cur]).to(self.device) 99 | if len(tnews_cur) > 0: 100 | flower = 
self.tokenizer([self.tnews_data['s1'][idx] for idx in tnews_cur], add_special_tokens=True, max_length=max_len, padding='max_length', return_tensors='pt', truncation=True) 101 | input_ids.append(flower['input_ids']) 102 | token_type_ids.append(flower['token_type_ids']) 103 | attention_mask.append(flower['attention_mask']) 104 | tnews_gold = torch.tensor([self.tnews_data['label'][idx] for idx in tnews_cur]).to(self.device) 105 | st = 0 106 | ed = len(ocnli_cur) 107 | ocnli_tensor = torch.tensor([i for i in range(st, ed)]).to(self.device) 108 | st += len(ocnli_cur) 109 | ed += len(ocemotion_cur) 110 | ocemotion_tensor = torch.tensor([i for i in range(st, ed)]).to(self.device) 111 | st += len(ocemotion_cur) 112 | ed += len(tnews_cur) 113 | tnews_tensor = torch.tensor([i for i in range(st, ed)]).to(self.device) 114 | input_ids = torch.cat(input_ids, axis=0).to(self.device) 115 | token_type_ids = torch.cat(token_type_ids, axis=0).to(self.device) 116 | attention_mask = torch.cat(attention_mask, axis=0).to(self.device) 117 | res = dict() 118 | res['input_ids'] = input_ids 119 | res['token_type_ids'] = token_type_ids 120 | res['attention_mask'] = attention_mask 121 | res['ocnli_ids'] = ocnli_tensor 122 | res['ocemotion_ids'] = ocemotion_tensor 123 | res['tnews_ids'] = tnews_tensor 124 | res['ocnli_gold'] = ocnli_gold 125 | res['ocemotion_gold'] = ocemotion_gold 126 | res['tnews_gold'] = tnews_gold 127 | return res 128 | 129 | def _get_max_total_len(self, ocnli_cur, ocemotion_cur, tnews_cur): 130 | res = 1 131 | for idx in ocnli_cur: 132 | res = max(res, 3 + len(self.ocnli_data['s1'][idx]) + len(self.ocnli_data['s2'][idx])) 133 | for idx in ocemotion_cur: 134 | res = max(res, 2 + len(self.ocemotion_data['s1'][idx])) 135 | for idx in tnews_cur: 136 | res = max(res, 2 + len(self.tnews_data['s1'][idx])) 137 | return min(res, self.max_len) -------------------------------------------------------------------------------- /generate_data.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Dec 5 12:46:03 2020 5 | 6 | @author: luokai 7 | """ 8 | 9 | 10 | import json 11 | from collections import defaultdict 12 | from math import log 13 | 14 | def split_dataset(dev_data_cnt=5000): 15 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']: 16 | cnt = 0 17 | with open('./tianchi_datasets/' + e + '/total.csv') as f: 18 | with open('./tianchi_datasets/' + e + '/train.csv', 'w') as f_train: 19 | with open('./tianchi_datasets/' + e + '/dev.csv', 'w') as f_dev: 20 | for line in f: 21 | cnt += 1 22 | if cnt <= dev_data_cnt: 23 | f_dev.write(line) 24 | else: 25 | f_train.write(line) 26 | 27 | def print_one_data(path, name, print_content=False): 28 | data_cnt = 0 29 | with open(path) as f: 30 | for line in f: 31 | tmp = json.loads(line) 32 | for _, v in tmp.items(): 33 | data_cnt += 1 34 | if print_content: 35 | print(v) 36 | print(name, 'contains:', data_cnt, 'numbers of data') 37 | 38 | def generate_data(): 39 | label_set = dict() 40 | label_cnt_set = dict() 41 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']: 42 | label_set[e] = set() 43 | label_cnt_set[e] = defaultdict(int) 44 | with open('./tianchi_datasets/' + e + '/total.csv') as f: 45 | for line in f: 46 | label = line.strip().split('\t')[-1] 47 | label_set[e].add(label) 48 | label_cnt_set[e][label] += 1 49 | for k in label_set: 50 | label_set[k] = sorted(list(label_set[k])) 51 | for k, v in label_set.items(): 52 | print(k, v) 53 | with open('./tianchi_datasets/label.json', 'w') as fw: 54 | fw.write(json.dumps(label_set)) 55 | label_weight_set = dict() 56 | for k in label_set: 57 | label_weight_set[k] = [label_cnt_set[k][e] for e in label_set[k]] 58 | total_weight = sum(label_weight_set[k]) 59 | label_weight_set[k] = [log(total_weight / e) for e in label_weight_set[k]] 60 | for k, v in label_weight_set.items(): 61 | print(k, v) 62 | with 
open('./tianchi_datasets/label_weights.json', 'w') as fw: 63 | fw.write(json.dumps(label_weight_set)) 64 | 65 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']: 66 | for name in ['dev', 'train']: 67 | with open('./tianchi_datasets/' + e + '/' + name + '.csv') as fr: 68 | with open('./tianchi_datasets/' + e + '/' + name + '.json', 'w') as fw: 69 | json_dict = dict() 70 | for line in fr: 71 | tmp_list = line.strip().split('\t') 72 | json_dict[tmp_list[0]] = dict() 73 | json_dict[tmp_list[0]]['s1'] = tmp_list[1] 74 | if e == 'OCNLI': 75 | json_dict[tmp_list[0]]['s2'] = tmp_list[2] 76 | json_dict[tmp_list[0]]['label'] = tmp_list[3] 77 | else: 78 | json_dict[tmp_list[0]]['label'] = tmp_list[2] 79 | fw.write(json.dumps(json_dict)) 80 | 81 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']: 82 | for name in ['dev', 'train']: 83 | cur_path = './tianchi_datasets/' + e + '/' + name + '.json' 84 | data_name = e + '_' + name 85 | print_one_data(cur_path, data_name) 86 | 87 | print_one_data('./tianchi_datasets/label.json', 'label_set') 88 | 89 | if __name__ == '__main__': 90 | print('-------------------------------start-----------------------------------') 91 | split_dataset(dev_data_cnt=3000) 92 | generate_data() 93 | print('-------------------------------finish-----------------------------------') -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Dec 5 17:47:24 2020 5 | 6 | @author: luokai 7 | """ 8 | 9 | 10 | from net import Net 11 | import json 12 | import torch 13 | import numpy as np 14 | from transformers import BertModel, BertTokenizer 15 | from utils import get_task_chinese 16 | 17 | 18 | def test_csv_to_json(): 19 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']: 20 | with open('./tianchi_datasets/' + e + '/test.csv') as fr: 21 | with open('./tianchi_datasets/' + e + '/test.json', 
'w') as fw: 22 | json_dict = dict() 23 | for line in fr: 24 | tmp_list = line.strip().split('\t') 25 | json_dict[tmp_list[0]] = dict() 26 | json_dict[tmp_list[0]]['s1'] = tmp_list[1] 27 | if e == 'OCNLI': 28 | json_dict[tmp_list[0]]['s2'] = tmp_list[2] 29 | fw.write(json.dumps(json_dict)) 30 | 31 | def inference_warpper(tokenizer_model): 32 | ocnli_test = dict() 33 | with open('./tianchi_datasets/OCNLI/test.json') as f: 34 | for line in f: 35 | ocnli_test = json.loads(line) 36 | break 37 | 38 | ocemotion_test = dict() 39 | with open('./tianchi_datasets/OCEMOTION/test.json') as f: 40 | for line in f: 41 | ocemotion_test = json.loads(line) 42 | break 43 | 44 | tnews_test = dict() 45 | with open('./tianchi_datasets/TNEWS/test.json') as f: 46 | for line in f: 47 | tnews_test = json.loads(line) 48 | break 49 | 50 | label_dict = dict() 51 | with open('./tianchi_datasets/label.json') as f: 52 | for line in f: 53 | label_dict = json.loads(line) 54 | break 55 | 56 | model = torch.load('./saved_best.pt') 57 | tokenizer = BertTokenizer.from_pretrained(tokenizer_model) 58 | inference('./submission/ocnli_predict.json', ocnli_test, model, tokenizer, label_dict['OCNLI'], 'ocnli', 'cuda:0', 64, True) 59 | inference('./submission/ocemotion_predict.json', ocemotion_test, model, tokenizer, label_dict['OCEMOTION'], 'ocemotion', 'cuda:0', 64, True) 60 | inference('./submission/tnews_predict.json', tnews_test, model, tokenizer, label_dict['TNEWS'], 'tnews', 'cuda:0', 64, True) 61 | 62 | def inference(path, data_dict, model, tokenizer, idx2label, task_type, device='cuda:0', batchSize=64, print_result=True): 63 | if task_type != 'ocnli' and task_type != 'ocemotion' and task_type != 'tnews': 64 | print('task_type is incorrect!') 65 | return 66 | model.to(device, non_blocking=True) 67 | model.eval() 68 | ids_list = [k for k, _ in data_dict.items()] 69 | next_start_ids = 0 70 | with torch.no_grad(): 71 | with open(path, 'w') as f: 72 | while next_start_ids < len(ids_list): 73 | cur_ids_list 
= ids_list[next_start_ids: next_start_ids + batchSize] 74 | next_start_ids += batchSize 75 | if task_type == 'ocnli': 76 | flower = tokenizer([data_dict[idx]['s1'] for idx in cur_ids_list], [data_dict[idx]['s2'] for idx in cur_ids_list], add_special_tokens=True, padding=True, return_tensors='pt') 77 | else: 78 | flower = tokenizer([data_dict[idx]['s1'] for idx in cur_ids_list], add_special_tokens=True, padding=True, return_tensors='pt') 79 | input_ids = flower['input_ids'].to(device, non_blocking=True) 80 | token_type_ids = flower['token_type_ids'].to(device, non_blocking=True) 81 | attention_mask = flower['attention_mask'].to(device, non_blocking=True) 82 | ocnli_ids = torch.tensor([]).to(device, non_blocking=True) 83 | ocemotion_ids = torch.tensor([]).to(device, non_blocking=True) 84 | tnews_ids = torch.tensor([]).to(device, non_blocking=True) 85 | if task_type == 'ocnli': 86 | ocnli_ids = torch.tensor([i for i in range(len(cur_ids_list))]).to(device, non_blocking=True) 87 | elif task_type == 'ocemotion': 88 | ocemotion_ids = torch.tensor([i for i in range(len(cur_ids_list))]).to(device, non_blocking=True) 89 | else: 90 | tnews_ids = torch.tensor([i for i in range(len(cur_ids_list))]).to(device, non_blocking=True) 91 | ocnli_out, ocemotion_out, tnews_out = model(input_ids, ocnli_ids, ocemotion_ids, tnews_ids, token_type_ids, attention_mask) 92 | if task_type == 'ocnli': 93 | pred = torch.argmax(ocnli_out, axis=1) 94 | elif task_type == 'ocemotion': 95 | pred = torch.argmax(ocemotion_out, axis=1) 96 | else: 97 | pred = torch.argmax(tnews_out, axis=1) 98 | pred_final = [idx2label[e] for e in np.array(pred.cpu()).tolist()] 99 | #torch.cuda.empty_cache() 100 | for i, idx in enumerate(cur_ids_list): 101 | if print_result: 102 | print_str = '[ ' + task_type + ' : ' + 'sentence one: ' + data_dict[idx]['s1'] 103 | if task_type == 'ocnli': 104 | print_str += '; sentence two: ' + data_dict[idx]['s2'] 105 | print_str += '; result: ' + pred_final[i] + ' ]' 106 | 
print(print_str) 107 | single_result_dict = dict() 108 | single_result_dict['id'] = idx 109 | single_result_dict['label'] = pred_final[i] 110 | f.write(json.dumps(single_result_dict, ensure_ascii=False)) 111 | if not (next_start_ids >= len(ids_list) and i == len(cur_ids_list) - 1): 112 | f.write('\n') 113 | 114 | if __name__ == '__main__': 115 | test_csv_to_json() 116 | print('---------------------------------start inference-----------------------------') 117 | inference_warpper(tokenizer_model='./bert_pretrain_model') 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Dec 5 17:20:35 2020 5 | 6 | @author: luokai 7 | """ 8 | 9 | import torch 10 | from torch import nn 11 | from transformers import BertModel 12 | 13 | 14 | class Net(nn.Module): 15 | def __init__(self, bert_model): 16 | super(Net, self).__init__() 17 | self.bert = bert_model 18 | self.atten_layer = nn.Linear(768, 16) 19 | self.softmax_d1 = nn.Softmax(dim=1) 20 | self.dropout = nn.Dropout(0.2) 21 | self.OCNLI_layer = nn.Linear(768, 16 * 3) 22 | self.OCEMOTION_layer = nn.Linear(768, 16 * 7) 23 | self.TNEWS_layer = nn.Linear(768, 16 * 15) 24 | 25 | def forward(self, input_ids, ocnli_ids, ocemotion_ids, tnews_ids, token_type_ids=None, attention_mask=None): 26 | cls_emb = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0][:, 0, :].squeeze(1) 27 | if ocnli_ids.size()[0] > 0: 28 | attention_score = self.atten_layer(cls_emb[ocnli_ids, :]) 29 | attention_score = self.dropout(self.softmax_d1(attention_score).unsqueeze(1)) 30 | ocnli_value = self.OCNLI_layer(cls_emb[ocnli_ids, :]).contiguous().view(-1, 16, 3) 31 | ocnli_out = torch.matmul(attention_score, ocnli_value).squeeze(1) 32 | else: 33 | ocnli_out = None 34 | if ocemotion_ids.size()[0] > 0: 
35 | attention_score = self.atten_layer(cls_emb[ocemotion_ids, :]) 36 | attention_score = self.dropout(self.softmax_d1(attention_score).unsqueeze(1)) 37 | ocemotion_value = self.OCEMOTION_layer(cls_emb[ocemotion_ids, :]).contiguous().view(-1, 16, 7) 38 | ocemotion_out = torch.matmul(attention_score, ocemotion_value).squeeze(1) 39 | else: 40 | ocemotion_out = None 41 | if tnews_ids.size()[0] > 0: 42 | attention_score = self.atten_layer(cls_emb[tnews_ids, :]) 43 | attention_score = self.dropout(self.softmax_d1(attention_score).unsqueeze(1)) 44 | tnews_value = self.TNEWS_layer(cls_emb[tnews_ids, :]).contiguous().view(-1, 16, 15) 45 | tnews_out = torch.matmul(attention_score, tnews_value).squeeze(1) 46 | else: 47 | tnews_out = None 48 | return ocnli_out, ocemotion_out, tnews_out -------------------------------------------------------------------------------- /submission/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base Images 2 | ## 从天池基础镜像构建 3 | FROM registry.cn-shanghai.aliyuncs.com/tcc-public/python:3 4 | 5 | ## 把当前文件夹里的文件构建到镜像的根目录下 6 | ADD . 
/ 7 | 8 | ## 指定默认工作目录为根目录(需要把run.sh和生成的结果文件都放在该文件夹下,提交后才能运行) 9 | WORKDIR / 10 | 11 | ## 镜像启动后统一执行 sh run.sh 12 | CMD ["sh", "run.sh"] -------------------------------------------------------------------------------- /submission/empty.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /submission/run.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/finlay-liu/tianchi-multi-task-nlp/261ec1b2fa611112cc313c048681a9a141141f04/submission/run.sh -------------------------------------------------------------------------------- /tianchi_datasets/OCEMOTION/empty.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /tianchi_datasets/OCNLI/empty.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /tianchi_datasets/TNEWS/empty.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /tianchi_datasets/empty.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Dec 5 17:34:27 2020 5 | 6 | @author: luokai 7 | """ 8 | 9 | import torch 10 | from transformers import BertModel, BertTokenizer 11 | import json 12 | from utils import get_f1, print_result, load_pretrained_model, load_tokenizer 13 | from net import Net 14 | from data_generator import Data_generator 
from calculate_loss import Calculate_loss

# Keys of a Data_generator batch that are passed straight through to Net.forward.
_MODEL_INPUT_KEYS = ('input_ids', 'token_type_ids', 'attention_mask', 'ocnli_ids', 'ocemotion_ids', 'tnews_ids')


def _load_json_first_line(path):
    """Return the JSON value encoded on the first line of ``path`` (an empty dict if the file is empty)."""
    with open(path, encoding='utf-8') as f:
        for line in f:
            return json.loads(line)
    return dict()


def _unpack_batch(raw_data):
    """Split one generator batch into (model kwargs, tnews_gold, ocnli_gold, ocemotion_gold)."""
    data = {key: raw_data[key] for key in _MODEL_INPUT_KEYS}
    return data, raw_data['tnews_gold'], raw_data['ocnli_gold'], raw_data['ocemotion_gold']


def _running_accuracy(correct, preds_so_far):
    """Return one task's accuracy so far (used as its KPI for compute_dtp); 0.1 before any prediction."""
    if not preds_so_far:
        return 0.1
    return correct / len(preds_so_far)


def _extend_pred_and_gold(loss_object, pred, gold, pred_list, gold_list):
    """Append one task's decoded predictions and gold labels for a batch to the running lists (in place)."""
    p, g = loss_object.collect_pred_and_gold(pred, gold)
    pred_list += p
    gold_list += g


def _report_split(epoch, split, ocnli, ocemotion, tnews):
    """Print per-task results for one data split and return the mean macro-F1 over the three tasks.

    Each of ``ocnli``/``ocemotion``/``tnews`` is a ``(gold_list, pred_list)`` pair.
    """
    ocnli_f1 = get_f1(*ocnli)
    ocemotion_f1 = get_f1(*ocemotion)
    tnews_f1 = get_f1(*tnews)
    avg_f1 = (ocnli_f1 + ocemotion_f1 + tnews_f1) / 3
    print(epoch, 'th epoch ' + split + ' average f1 is:', avg_f1)
    for task, (gold, pred) in (('ocnli', ocnli), ('ocemotion', ocemotion), ('tnews', tnews)):
        print(epoch, 'th epoch ' + split + ' ' + task + ' is below:')
        print_result(gold, pred)
    return avg_f1


def train(epochs=20, batchSize=64, lr=0.0001, device='cuda:0', accumulate=True, a_step=16, load_saved=False,
          file_path='./saved_best.pt', use_dtp=False, pretrained_model='./bert_pretrain_model',
          tokenizer_model='bert-base-chinese', weighted_loss=False):
    """Train the multi-task Net on OCNLI / OCEMOTION / TNEWS and keep the best dev checkpoint.

    Args:
        epochs: number of passes over the training data.
        batchSize: batch size requested from each Data_generator.
        lr: Adam learning rate.
        device: torch device the model and the per-task label weights are moved to.
        accumulate: if True, gradients are accumulated and the optimizer steps every
            ``a_step`` batches, plus once at the end of an epoch for any leftover batches;
            if False, it zeroes gradients and steps on every batch.
        a_step: accumulation interval used when ``accumulate`` is True.
        load_saved: resume from the whole-model object saved at ``file_path`` instead of
            building a fresh ``Net``.
        file_path: where the model with the best dev average macro-F1 is saved (``torch.save``).
        use_dtp: use ``Calculate_loss.compute_dtp`` with each task's running accuracy as its
            KPI instead of ``Calculate_loss.compute``.
        pretrained_model: path/name given to ``load_pretrained_model``.
        tokenizer_model: path/name given to ``load_tokenizer``.
        weighted_loss: forwarded as ``Calculate_loss(weighted=...)`` together with the
            per-task weights read from ``label_weights.json``.
    """
    tokenizer = load_tokenizer(tokenizer_model)
    # NOTE(review): loads a whole pickled module; newer torch versions default torch.load to
    # weights_only=True, which may reject this file — confirm against the torch version in use.
    my_net = torch.load(file_path) if load_saved else Net(load_pretrained_model(pretrained_model))
    my_net.to(device, non_blocking=True)

    label_dict = _load_json_first_line('./tianchi_datasets/label.json')
    label_weights_dict = _load_json_first_line('./tianchi_datasets/label_weights.json')
    ocnli_train = _load_json_first_line('./tianchi_datasets/OCNLI/train.json')
    ocnli_dev = _load_json_first_line('./tianchi_datasets/OCNLI/dev.json')
    ocemotion_train = _load_json_first_line('./tianchi_datasets/OCEMOTION/train.json')
    ocemotion_dev = _load_json_first_line('./tianchi_datasets/OCEMOTION/dev.json')
    tnews_train = _load_json_first_line('./tianchi_datasets/TNEWS/train.json')
    tnews_dev = _load_json_first_line('./tianchi_datasets/TNEWS/dev.json')

    train_data_generator = Data_generator(ocnli_train, ocemotion_train, tnews_train, label_dict, device, tokenizer)
    dev_data_generator = Data_generator(ocnli_dev, ocemotion_dev, tnews_dev, label_dict, device, tokenizer)
    tnews_weights = torch.tensor(label_weights_dict['TNEWS']).to(device, non_blocking=True)
    ocnli_weights = torch.tensor(label_weights_dict['OCNLI']).to(device, non_blocking=True)
    ocemotion_weights = torch.tensor(label_weights_dict['OCEMOTION']).to(device, non_blocking=True)
    loss_object = Calculate_loss(label_dict, weighted=weighted_loss, tnews_weights=tnews_weights,
                                 ocnli_weights=ocnli_weights, ocemotion_weights=ocemotion_weights)
    optimizer = torch.optim.Adam(my_net.parameters(), lr=lr)
    best_dev_f1 = 0.0
    best_epoch = -1
    for epoch in range(epochs):
        # ------------------------------ training pass ------------------------------
        my_net.train()
        train_loss = 0.0
        train_total = 0
        train_correct = 0
        train_ocnli_correct = 0
        train_ocemotion_correct = 0
        train_tnews_correct = 0
        train_ocnli_pred_list = []
        train_ocnli_gold_list = []
        train_ocemotion_pred_list = []
        train_ocemotion_gold_list = []
        train_tnews_pred_list = []
        train_tnews_gold_list = []
        cnt_train = 0
        while True:
            raw_data = train_data_generator.get_next_batch(batchSize)
            if raw_data is None:
                break
            data, tnews_gold, ocnli_gold, ocemotion_gold = _unpack_batch(raw_data)
            if not accumulate:
                optimizer.zero_grad()
            ocnli_pred, ocemotion_pred, tnews_pred = my_net(**data)
            if use_dtp:
                # Each task's running training accuracy in this epoch is passed as its KPI.
                tnews_kpi = _running_accuracy(train_tnews_correct, train_tnews_pred_list)
                ocnli_kpi = _running_accuracy(train_ocnli_correct, train_ocnli_pred_list)
                ocemotion_kpi = _running_accuracy(train_ocemotion_correct, train_ocemotion_pred_list)
                current_loss = loss_object.compute_dtp(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold,
                                                       ocemotion_gold, tnews_kpi, ocnli_kpi, ocemotion_kpi)
            else:
                current_loss = loss_object.compute(tnews_pred, ocnli_pred, ocemotion_pred,
                                                   tnews_gold, ocnli_gold, ocemotion_gold)
            train_loss += current_loss.item()
            current_loss.backward()
            if accumulate and (cnt_train + 1) % a_step == 0:
                optimizer.step()
                optimizer.zero_grad()
            if not accumulate:
                optimizer.step()
            if use_dtp:
                (good_tnews_nb, good_ocnli_nb, good_ocemotion_nb,
                 total_tnews_nb, total_ocnli_nb, total_ocemotion_nb) = loss_object.correct_cnt_each(
                    tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
                tmp_good = sum([good_tnews_nb, good_ocnli_nb, good_ocemotion_nb])
                tmp_total = sum([total_tnews_nb, total_ocnli_nb, total_ocemotion_nb])
                train_ocemotion_correct += good_ocemotion_nb
                train_ocnli_correct += good_ocnli_nb
                train_tnews_correct += good_tnews_nb
            else:
                tmp_good, tmp_total = loss_object.correct_cnt(tnews_pred, ocnli_pred, ocemotion_pred,
                                                              tnews_gold, ocnli_gold, ocemotion_gold)
            train_correct += tmp_good
            train_total += tmp_total
            _extend_pred_and_gold(loss_object, ocnli_pred, ocnli_gold, train_ocnli_pred_list, train_ocnli_gold_list)
            _extend_pred_and_gold(loss_object, ocemotion_pred, ocemotion_gold,
                                  train_ocemotion_pred_list, train_ocemotion_gold_list)
            _extend_pred_and_gold(loss_object, tnews_pred, tnews_gold, train_tnews_pred_list, train_tnews_gold_list)
            cnt_train += 1
            if (cnt_train + 1) % 1000 == 0:
                print('[', cnt_train + 1, '- th batch : train acc is:', train_correct / train_total,
                      '; train loss is:', train_loss / cnt_train, ']')
        # Flush gradients of an incomplete accumulation window; skip when the last batch
        # already triggered a step, so no extra step is taken without pending gradients.
        if accumulate and cnt_train % a_step != 0:
            optimizer.step()
            optimizer.zero_grad()
        _report_split(epoch, 'train',
                      (train_ocnli_gold_list, train_ocnli_pred_list),
                      (train_ocemotion_gold_list, train_ocemotion_pred_list),
                      (train_tnews_gold_list, train_tnews_pred_list))

        train_data_generator.reset()

        # ------------------------------ dev evaluation ------------------------------
        my_net.eval()
        # dev_loss / dev_correct / dev_total / cnt_dev are tracked but not currently reported.
        dev_loss = 0.0
        dev_total = 0
        dev_correct = 0
        dev_ocnli_correct = 0
        dev_ocemotion_correct = 0
        dev_tnews_correct = 0
        dev_ocnli_pred_list = []
        dev_ocnli_gold_list = []
        dev_ocemotion_pred_list = []
        dev_ocemotion_gold_list = []
        dev_tnews_pred_list = []
        dev_tnews_gold_list = []
        cnt_dev = 0
        with torch.no_grad():
            while True:
                raw_data = dev_data_generator.get_next_batch(batchSize)
                if raw_data is None:
                    break
                data, tnews_gold, ocnli_gold, ocemotion_gold = _unpack_batch(raw_data)
                ocnli_pred, ocemotion_pred, tnews_pred = my_net(**data)
                if use_dtp:
                    tnews_kpi = _running_accuracy(dev_tnews_correct, dev_tnews_pred_list)
                    ocnli_kpi = _running_accuracy(dev_ocnli_correct, dev_ocnli_pred_list)
                    ocemotion_kpi = _running_accuracy(dev_ocemotion_correct, dev_ocemotion_pred_list)
                    current_loss = loss_object.compute_dtp(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold,
                                                           ocnli_gold, ocemotion_gold,
                                                           tnews_kpi, ocnli_kpi, ocemotion_kpi)
                else:
                    current_loss = loss_object.compute(tnews_pred, ocnli_pred, ocemotion_pred,
                                                       tnews_gold, ocnli_gold, ocemotion_gold)
                dev_loss += current_loss.item()
                if use_dtp:
                    (good_tnews_nb, good_ocnli_nb, good_ocemotion_nb,
                     total_tnews_nb, total_ocnli_nb, total_ocemotion_nb) = loss_object.correct_cnt_each(
                        tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
                    # Per-batch counts: assign (not +=) so dev_correct/dev_total are plain sums
                    # and nothing leaks in from the training loop's last batch.
                    tmp_good = sum([good_tnews_nb, good_ocnli_nb, good_ocemotion_nb])
                    tmp_total = sum([total_tnews_nb, total_ocnli_nb, total_ocemotion_nb])
                    dev_ocemotion_correct += good_ocemotion_nb
                    dev_ocnli_correct += good_ocnli_nb
                    dev_tnews_correct += good_tnews_nb
                else:
                    tmp_good, tmp_total = loss_object.correct_cnt(tnews_pred, ocnli_pred, ocemotion_pred,
                                                                  tnews_gold, ocnli_gold, ocemotion_gold)
                dev_correct += tmp_good
                dev_total += tmp_total
                _extend_pred_and_gold(loss_object, ocnli_pred, ocnli_gold, dev_ocnli_pred_list, dev_ocnli_gold_list)
                _extend_pred_and_gold(loss_object, ocemotion_pred, ocemotion_gold,
                                      dev_ocemotion_pred_list, dev_ocemotion_gold_list)
                _extend_pred_and_gold(loss_object, tnews_pred, tnews_gold, dev_tnews_pred_list, dev_tnews_gold_list)
                cnt_dev += 1
        dev_avg_f1 = _report_split(epoch, 'dev',
                                   (dev_ocnli_gold_list, dev_ocnli_pred_list),
                                   (dev_ocemotion_gold_list, dev_ocemotion_pred_list),
                                   (dev_tnews_gold_list, dev_tnews_pred_list))

        dev_data_generator.reset()

        if dev_avg_f1 > best_dev_f1:
            best_dev_f1 = dev_avg_f1
            best_epoch = epoch
            torch.save(my_net, file_path)
        print('best epoch is:', best_epoch, '; with best f1 is:', best_dev_f1)


if __name__ == '__main__':
    print('---------------------start training-----------------------')
    pretrained_model = './bert_pretrain_model'
    tokenizer_model = './bert_pretrain_model'
    train(batchSize=16, device='cuda:0', lr=0.0001, use_dtp=True, pretrained_model=pretrained_model,
          tokenizer_model=tokenizer_model, weighted_loss=True)

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 5 17:31:42 2020

@author: luokai
"""

from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, classification_report, f1_score
from transformers import BertModel, BertTokenizer


def get_f1(l_t, l_p):
    """Return the macro-averaged F1 of predictions ``l_p`` against gold labels ``l_t``."""
    return f1_score(l_t, l_p, average='macro')


def print_result(l_t, l_p):
    """Print macro-F1, the confusion matrix and sklearn's classification report for ``l_p`` vs ``l_t``."""
    print(get_f1(l_t, l_p))
    print(f"{'confusion_matrix':*^80}")
    print(confusion_matrix(l_t, l_p))
    print(f"{'classification_report':*^80}")
    print(classification_report(l_t, l_p))


def load_tokenizer(path_or_name):
    """Load a ``BertTokenizer`` from a local directory or a model name."""
    return BertTokenizer.from_pretrained(path_or_name)


def load_pretrained_model(path_or_name):
    """Load a ``BertModel`` from a local directory or a model name."""
    return BertModel.from_pretrained(path_or_name)


def get_task_chinese(task_type):
    """Return the Chinese description of a task.

    'ocnli' -> original Chinese natural language inference,
    'ocemotion' -> Chinese emotion classification,
    anything else -> Toutiao news headline classification (TNEWS).
    """
    if task_type == 'ocnli':
        return '(中文原版自然语言推理)'
    elif task_type == 'ocemotion':
        return '(中文情感分类)'
    else:
        return '(今日头条新闻标题分类)'
# --------------------------------------------------------------------------------