├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── tianchi-multi-task-nlp.iml
│   └── vcs.xml
├── README.md
├── __init__.py
├── bert_pretrain_model
│   └── empty.txt
├── calculate_loss.py
├── data_generator.py
├── generate_data.py
├── inference.py
├── net.py
├── submission
│   ├── Dockerfile
│   ├── empty.txt
│   └── run.sh
├── tianchi_datasets
│   ├── OCEMOTION
│   │   └── empty.txt
│   ├── OCNLI
│   │   └── empty.txt
│   ├── TNEWS
│   │   └── empty.txt
│   └── empty.txt
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.tar
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Default ignored files
3 | /workspace.xml
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/tianchi-multi-task-nlp.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # End-to-End Competition Walkthrough
2 | NLP Chinese Pre-trained Model Generalization Ability Challenge (Tianchi)
3 |
4 | ## Training Environment
5 |
6 | ```
7 | Machine: NVIDIA-SMI 440.33.01 Driver Version: 440.33.01 CUDA Version: 10.2
8 | PyTorch version 1.6.0
9 |
10 | Machine: NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2
11 | PyTorch version 1.7.1
12 | ```
13 |
14 | Python dependencies:
15 | ```
16 | pip install transformers
17 | ```
18 |
19 | ## Docker Installation (Ubuntu)
20 |
21 | Install from the command line:
22 | ```
23 | sudo apt install docker.io
24 | ```
25 |
26 | Verify the installation:
27 | ```
28 | docker info
29 | ```
30 |
31 |
32 |
33 | ## Workflow
34 |
35 | 1. Download the full BERT weights: from https://huggingface.co/bert-base-chinese/tree/main fetch config.json, vocab.txt and pytorch_model.bin, and put these three files into the tianchi-multi-task-nlp/bert_pretrain_model folder (a download sketch follows).
36 |
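A minimal download sketch, assuming network access to the Hugging Face hub and the `transformers` versions above; `save_pretrained` produces the three files named in step 1 (the tokenizer also drops a couple of small config files next to vocab.txt):

```
# Fetch bert-base-chinese and mirror it into the folder the code expects.
from transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained('bert-base-chinese')
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model.save_pretrained('./bert_pretrain_model')      # config.json + pytorch_model.bin
tokenizer.save_pretrained('./bert_pretrain_model')  # vocab.txt + tokenizer configs
```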
37 | 2. Download the competition datasets and put each of the three datasets under `tianchi-multi-task-nlp/tianchi_datasets/<dataset name>/`:
38 | - OCEMOTION/total.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/OCEMOTION_train1128.csv
39 | - OCEMOTION/test.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/b/ocemotion_test_B.csv
40 | - TNEWS/total.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/TNEWS_train1128.csv
41 | - TNEWS/test.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/b/tnews_test_B.csv
42 | - OCNLI/total.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/OCNLI_train1128.csv
43 | - OCNLI/test.csv: http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/b/ocnli_test_B.csv
44 |
45 | Example file layout:
46 | ```
47 | tianchi-multi-task-nlp/tianchi_datasets/OCNLI/total.csv
48 | tianchi-multi-task-nlp/tianchi_datasets/OCNLI/test.csv
49 | ```
50 |
51 | 3. Split off the training and dev sets; by default each task's dev set holds 3,000 examples, and the parameter can be changed (see the sketch after the command):
52 | ```
53 | python ./generate_data.py
54 | ```
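The same split can be driven from Python with a different dev size; `split_dataset` and `generate_data` are the two functions generate_data.py's `__main__` calls:

```
from generate_data import split_dataset, generate_data

split_dataset(dev_data_cnt=5000)  # hold out 5,000 examples per task instead of 3,000
generate_data()                   # rebuild the json files and label weights
```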
55 | 4. Train the model for one epoch:
56 | ```
57 | python ./train.py
58 | ```
59 | This saves the model with the highest average F1 score on the dev set to ./saved_best.pt (a sketch of a custom invocation follows below).
60 |
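Training can also be launched from Python with explicit hyperparameters; this mirrors the call in train.py's `__main__` (every argument shown is that call's own value):

```
from train import train

train(batchSize=16, device='cuda:0', lr=0.0001, use_dtp=True,
      pretrained_model='./bert_pretrain_model',
      tokenizer_model='./bert_pretrain_model', weighted_loss=True)
```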
61 | 5. Generate the results with the trained model ./saved_best.pt:
62 | ```
63 | python ./inference.py
64 | ```
65 |
66 | 6. Package the predictions (inference.py writes the *.json files into ./submission, so run this from inside that folder).
67 | ```
68 | zip -r ./result.zip ./*.json
69 | ```
70 | 7. Build the Docker image and submit it; see: https://tianchi.aliyun.com/competition/entrance/231759/tab/174
71 | - Create a cloud image registry: https://cr.console.aliyun.com/
72 | - Create a namespace and an image repository;
73 | - Then switch into the `submission` folder and run the commands below.
74 |
75 | ```
76 | # The login username is your full Aliyun account name; the password is the one set when the service was activated.
77 | sudo docker login --username=xxx@mail.com registry.cn-hangzhou.aliyuncs.com
78 |
79 | # Build with the local Dockerfile, using the repository's public endpoint,
80 | # e.g. docker build -t registry.cn-shenzhen.aliyuncs.com/test_for_tianchi/test_for_tianchi_submit:1.0 .
81 | docker build -t registry.cn-shenzhen.aliyuncs.com/test_for_tianchi/test_for_tianchi_submit:1.0 .
82 | ```
83 |
84 | Sample build output:
85 | ```
86 | Sending build context to Docker daemon 18.94kB
87 | Step 1/4 : FROM registry.cn-shanghai.aliyuncs.com/tcc-public/python:3
88 | ---> a4cc999cf2aa
89 | Step 2/4 : ADD . /
90 | ---> Using cache
91 | ---> b18fbb4425ef
92 | Step 3/4 : WORKDIR /
93 | ---> Using cache
94 | ---> f5fcc4ca5eca
95 | Step 4/4 : CMD ["sh", "run.sh"]
96 | ---> Using cache
97 | ---> ed0c4b0e545f
98 | Successfully built ed0c4b0e545f
99 | ```
100 |
101 | ```
102 | # ed0c4b0e545f is the image id, taken from the last line of the build output above
103 | sudo docker tag ed0c4b0e545f registry.cn-shenzhen.aliyuncs.com/test_for_tianchi/test_for_tianchi_submit:1.0
104 |
105 | # Push the image to the cloud registry
106 | docker push registry.cn-shenzhen.aliyuncs.com/test_for_tianchi/test_for_tianchi_submit:1.0
107 | ```
108 |
109 | 8. On the [competition submission page](https://tianchi.aliyun.com/competition/entrance/531865/submission/723), fill in the image path plus version tag together with your username and password to complete the submission.
110 |
111 |
112 | ## Ideas for Improving the Score
113 |
114 | 1. Modify calculate_loss.py to change how the loss is computed, starting from balancing sub-task difficulty and the uneven class distribution within each sub-task;
115 | 2. Modify net.py to change the model structure, e.g. add an attention layer or other layers;
116 | 3. Clean the training text with tools such as cleanlab;
117 | 4. Apply text data augmentation, or pretrain on other datasets first;
118 | 5. Train the finished model for one more epoch on the full dataset (train set plus dev set) with a small learning rate;
119 | 6. Adjust batchSize and a_step to change the amount of gradient accumulation; currently batchSize=16 and a_step=16;
120 | 7. Use chinese-roberta-wwm-ext as the pre-trained model (see the sketch below).
121 |
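For idea 7, a minimal sketch: `hfl/chinese-roberta-wwm-ext` is the model's Hugging Face hub id, and the model is BERT-compatible, so it loads with the same `Bert*` classes utils.py already uses; saving it into `bert_pretrain_model` means nothing else in the code has to change.

```
from transformers import BertModel, BertTokenizer

backbone = BertModel.from_pretrained('hfl/chinese-roberta-wwm-ext')
tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
backbone.save_pretrained('./bert_pretrain_model')
tokenizer.save_pretrained('./bert_pretrain_model')
```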
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Dec 5 12:39:45 2020
5 |
6 | @author: luokai
7 | """
8 |
9 |
10 |
--------------------------------------------------------------------------------
/bert_pretrain_model/empty.txt:
--------------------------------------------------------------------------------
1 | 1
--------------------------------------------------------------------------------
/calculate_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Dec 5 17:23:01 2020
5 |
6 | @author: luokai
7 | """
8 |
9 | import torch
10 | from torch import nn
11 | import numpy as np
12 | from math import log
13 |
14 |
15 | class Calculate_loss():
16 | def __init__(self, label_dict, weighted=False, tnews_weights=None, ocnli_weights=None, ocemotion_weights=None):
17 | self.weighted = weighted
18 | if weighted:
19 | self.tnews_loss = nn.CrossEntropyLoss(tnews_weights)
20 | self.ocnli_loss = nn.CrossEntropyLoss(ocnli_weights)
21 | self.ocemotion_loss = nn.CrossEntropyLoss(ocemotion_weights)
22 | else:
23 | self.loss = nn.CrossEntropyLoss()
24 | self.label2idx = dict()
25 | self.idx2label = dict()
26 | for key in ['TNEWS', 'OCNLI', 'OCEMOTION']:
27 | self.label2idx[key] = dict()
28 | self.idx2label[key] = dict()
29 | for i, e in enumerate(label_dict[key]):
30 | self.label2idx[key][e] = i
31 | self.idx2label[key][i] = e
32 |
33 | def idxToLabel(self, key, idx):
34 | return self.idx2label[key][idx]
35 |
36 | def labelToIdx(self, key, label):
37 | return self.label2idx[key][label]
38 |
39 | def compute(self, tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold):
40 | res = 0
41 | if tnews_pred is not None:
42 | res += self.tnews_loss(tnews_pred, tnews_gold) if self.weighted else self.loss(tnews_pred, tnews_gold)
43 | if ocnli_pred is not None:
44 | res += self.ocnli_loss(ocnli_pred, ocnli_gold) if self.weighted else self.loss(ocnli_pred, ocnli_gold)
45 | if ocemotion_pred is not None:
46 | res += self.ocemotion_loss(ocemotion_pred, ocemotion_gold) if self.weighted else self.loss(ocemotion_pred, ocemotion_gold)
47 | return res
48 |
49 | def compute_dtp(self, tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold, tnews_kpi=0.1, ocnli_kpi=0.1, ocemotion_kpi=0.1, y=0.5):
50 | res = 0
51 | if tnews_pred is not None:
52 | res += self.tnews_loss(tnews_pred, tnews_gold) * self._calculate_weight(tnews_kpi, y) if self.weighted else self.loss(tnews_pred, tnews_gold) * self._calculate_weight(tnews_kpi, y)
53 | if ocnli_pred is not None:
54 | res += self.ocnli_loss(ocnli_pred, ocnli_gold) * self._calculate_weight(ocnli_kpi, y) if self.weighted else self.loss(ocnli_pred, ocnli_gold) * self._calculate_weight(ocnli_kpi, y)
55 | if ocemotion_pred is not None:
56 | res += self.ocemotion_loss(ocemotion_pred, ocemotion_gold) * self._calculate_weight(ocemotion_kpi, y) if self.weighted else self.loss(ocemotion_pred, ocemotion_gold) * self._calculate_weight(ocemotion_kpi, y)
57 | return res
58 |
59 |
60 | def correct_cnt(self, tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold):
61 | good_nb = 0
62 | total_nb = 0
63 | if tnews_pred is not None:
64 | tnews_val = torch.argmax(tnews_pred, axis=1)
65 | for i, e in enumerate(tnews_gold):
66 | if e == tnews_val[i]:
67 | good_nb += 1
68 | total_nb += 1
69 | if ocnli_pred is not None:
70 | ocnli_val = torch.argmax(ocnli_pred, axis=1)
71 | for i, e in enumerate(ocnli_gold):
72 | if e == ocnli_val[i]:
73 | good_nb += 1
74 | total_nb += 1
75 | if ocemotion_pred is not None:
76 | ocemotion_val = torch.argmax(ocemotion_pred, axis=1)
77 | for i, e in enumerate(ocemotion_gold):
78 | if e == ocemotion_val[i]:
79 | good_nb += 1
80 | total_nb += 1
81 | return good_nb, total_nb
82 |
83 | def correct_cnt_each(self, tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold):
84 | good_ocnli_nb = 0
85 | good_ocemotion_nb = 0
86 | good_tnews_nb = 0
87 | total_ocnli_nb = 0
88 | total_ocemotion_nb = 0
89 | total_tnews_nb = 0
90 | if tnews_pred is not None:
91 | tnews_val = torch.argmax(tnews_pred, axis=1)
92 | for i, e in enumerate(tnews_gold):
93 | if e == tnews_val[i]:
94 | good_tnews_nb += 1
95 | total_tnews_nb += 1
96 | if ocnli_pred is not None:
97 | ocnli_val = torch.argmax(ocnli_pred, axis=1)
98 | for i, e in enumerate(ocnli_gold):
99 | if e == ocnli_val[i]:
100 | good_ocnli_nb += 1
101 | total_ocnli_nb += 1
102 | if ocemotion_pred is not None:
103 | ocemotion_val = torch.argmax(ocemotion_pred, axis=1)
104 | for i, e in enumerate(ocemotion_gold):
105 | if e == ocemotion_val[i]:
106 | good_ocemotion_nb += 1
107 | total_ocemotion_nb += 1
108 | return good_tnews_nb, good_ocnli_nb, good_ocemotion_nb, total_tnews_nb, total_ocnli_nb, total_ocemotion_nb
109 |
110 | def collect_pred_and_gold(self, pred, gold):
111 | if pred is None or gold is None:
112 | p, g = [], []
113 | else:
114 | p, g = np.array(torch.argmax(pred, axis=1).cpu()).tolist(), np.array(gold.cpu()).tolist()
115 | return p, g
116 |
117 | def _calculate_weight(self, kpi, y):
118 | kpi = max(0.1, kpi)
119 | kpi = min(0.99, kpi)
120 | w = -1 * ((1 - kpi) ** y) * log(kpi)
121 | return w
--------------------------------------------------------------------------------
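A note on `compute_dtp` above: it implements dynamic task prioritization, scaling each task's loss by `_calculate_weight(kpi, y) = -(1 - kpi)^y * log(kpi)`, where the KPI is that task's running accuracy. A standalone sanity check of the weight curve (illustrative, not part of the repo):

```
from math import log

def dtp_weight(kpi, y=0.5):
    # mirrors Calculate_loss._calculate_weight: clamp the KPI, then weight it focal-style
    kpi = min(0.99, max(0.1, kpi))
    return -((1 - kpi) ** y) * log(kpi)

for kpi in (0.1, 0.5, 0.9):
    print(f'kpi={kpi:.2f} -> weight={dtp_weight(kpi):.3f}')
# kpi=0.10 -> weight=2.184, kpi=0.50 -> weight=0.490, kpi=0.90 -> weight=0.033:
# the worse a task is currently doing, the more its loss counts.
```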
/data_generator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Dec 5 17:29:08 2020
5 |
6 | @author: luokai
7 | """
8 | import random
9 | import torch
10 | from transformers import BertTokenizer
11 |
12 | class Data_generator():
13 | def __init__(self, ocnli_dict, ocemotion_dict, tnews_dict, label_dict, device, tokenizer, max_len=512):
14 | self.max_len = max_len
15 | self.tokenizer = tokenizer
16 | self.device = device
17 | self.label2idx = dict()
18 | self.idx2label = dict()
19 | for key in ['TNEWS', 'OCNLI', 'OCEMOTION']:
20 | self.label2idx[key] = dict()
21 | self.idx2label[key] = dict()
22 | for i, e in enumerate(label_dict[key]):
23 | self.label2idx[key][e] = i
24 | self.idx2label[key][i] = e
25 | self.ocnli_data = dict()
26 | self.ocnli_data['s1'] = []
27 | self.ocnli_data['s2'] = []
28 | self.ocnli_data['label'] = []
29 | for k, v in ocnli_dict.items():
30 | self.ocnli_data['s1'].append(v['s1'])
31 | self.ocnli_data['s2'].append(v['s2'])
32 | self.ocnli_data['label'].append(self.label2idx['OCNLI'][v['label']])
33 | self.ocemotion_data = dict()
34 | self.ocemotion_data['s1'] = []
35 | self.ocemotion_data['label'] = []
36 | for k, v in ocemotion_dict.items():
37 | self.ocemotion_data['s1'].append(v['s1'])
38 | self.ocemotion_data['label'].append(self.label2idx['OCEMOTION'][v['label']])
39 | self.tnews_data = dict()
40 | self.tnews_data['s1'] = []
41 | self.tnews_data['label'] = []
42 | for k, v in tnews_dict.items():
43 | self.tnews_data['s1'].append(v['s1'])
44 | self.tnews_data['label'].append(self.label2idx['TNEWS'][v['label']])
45 | self.reset()
46 | def reset(self):
47 | self.ocnli_ids = list(range(len(self.ocnli_data['s1'])))
48 | self.ocemotion_ids = list(range(len(self.ocemotion_data['s1'])))
49 | self.tnews_ids = list(range(len(self.tnews_data['s1'])))
50 | random.shuffle(self.ocnli_ids)
51 | random.shuffle(self.ocemotion_ids)
52 | random.shuffle(self.tnews_ids)
53 | def get_next_batch(self, batchSize=64):
54 | ocnli_len = len(self.ocnli_ids)
55 | ocemotion_len = len(self.ocemotion_ids)
56 | tnews_len = len(self.tnews_ids)
57 | total_len = ocnli_len + ocemotion_len + tnews_len
58 | if total_len == 0:
59 | return None
60 | elif total_len > batchSize:
61 | # slicing an empty id list just yields [], so every task can be handled
62 | # uniformly here with no per-task emptiness guards; tnews absorbs the
63 | # integer-rounding remainder of the batch below
64 | ocnli_tmp_len = int((ocnli_len / total_len) * batchSize)
65 | ocnli_cur = self.ocnli_ids[:ocnli_tmp_len]
66 | self.ocnli_ids = self.ocnli_ids[ocnli_tmp_len:]
67 | ocemotion_tmp_len = int((ocemotion_len / total_len) * batchSize)
68 | ocemotion_cur = self.ocemotion_ids[:ocemotion_tmp_len]
69 | self.ocemotion_ids = self.ocemotion_ids[ocemotion_tmp_len:]
70 | tnews_tmp_len = batchSize - len(ocnli_cur) - len(ocemotion_cur)
71 | tnews_cur = self.tnews_ids[:tnews_tmp_len]
72 | self.tnews_ids = self.tnews_ids[tnews_tmp_len:]
73 | else:
74 | ocnli_cur = self.ocnli_ids
75 | self.ocnli_ids = []
76 | ocemotion_cur = self.ocemotion_ids
77 | self.ocemotion_ids = []
78 | tnews_cur = self.tnews_ids
79 | self.tnews_ids = []
80 | max_len = self._get_max_total_len(ocnli_cur, ocemotion_cur, tnews_cur)
81 | input_ids = []
82 | token_type_ids = []
83 | attention_mask = []
84 | ocnli_gold = None
85 | ocemotion_gold = None
86 | tnews_gold = None
87 | if len(ocnli_cur) > 0:
88 | flower = self.tokenizer([self.ocnli_data['s1'][idx] for idx in ocnli_cur], [self.ocnli_data['s2'][idx] for idx in ocnli_cur], add_special_tokens=True, max_length=max_len, padding='max_length', return_tensors='pt', truncation=True)
89 | input_ids.append(flower['input_ids'])
90 | token_type_ids.append(flower['token_type_ids'])
91 | attention_mask.append(flower['attention_mask'])
92 | ocnli_gold = torch.tensor([self.ocnli_data['label'][idx] for idx in ocnli_cur]).to(self.device)
93 | if len(ocemotion_cur) > 0:
94 | flower = self.tokenizer([self.ocemotion_data['s1'][idx] for idx in ocemotion_cur], add_special_tokens=True, max_length=max_len, padding='max_length', return_tensors='pt', truncation=True)
95 | input_ids.append(flower['input_ids'])
96 | token_type_ids.append(flower['token_type_ids'])
97 | attention_mask.append(flower['attention_mask'])
98 | ocemotion_gold = torch.tensor([self.ocemotion_data['label'][idx] for idx in ocemotion_cur]).to(self.device)
99 | if len(tnews_cur) > 0:
100 | flower = self.tokenizer([self.tnews_data['s1'][idx] for idx in tnews_cur], add_special_tokens=True, max_length=max_len, padding='max_length', return_tensors='pt', truncation=True)
101 | input_ids.append(flower['input_ids'])
102 | token_type_ids.append(flower['token_type_ids'])
103 | attention_mask.append(flower['attention_mask'])
104 | tnews_gold = torch.tensor([self.tnews_data['label'][idx] for idx in tnews_cur]).to(self.device)
105 | st = 0
106 | ed = len(ocnli_cur)
107 | ocnli_tensor = torch.tensor([i for i in range(st, ed)]).to(self.device)
108 | st += len(ocnli_cur)
109 | ed += len(ocemotion_cur)
110 | ocemotion_tensor = torch.tensor([i for i in range(st, ed)]).to(self.device)
111 | st += len(ocemotion_cur)
112 | ed += len(tnews_cur)
113 | tnews_tensor = torch.tensor([i for i in range(st, ed)]).to(self.device)
114 | input_ids = torch.cat(input_ids, axis=0).to(self.device)
115 | token_type_ids = torch.cat(token_type_ids, axis=0).to(self.device)
116 | attention_mask = torch.cat(attention_mask, axis=0).to(self.device)
117 | res = dict()
118 | res['input_ids'] = input_ids
119 | res['token_type_ids'] = token_type_ids
120 | res['attention_mask'] = attention_mask
121 | res['ocnli_ids'] = ocnli_tensor
122 | res['ocemotion_ids'] = ocemotion_tensor
123 | res['tnews_ids'] = tnews_tensor
124 | res['ocnli_gold'] = ocnli_gold
125 | res['ocemotion_gold'] = ocemotion_gold
126 | res['tnews_gold'] = tnews_gold
127 | return res
128 |
129 | def _get_max_total_len(self, ocnli_cur, ocemotion_cur, tnews_cur):
130 | res = 1
131 | for idx in ocnli_cur:
132 | res = max(res, 3 + len(self.ocnli_data['s1'][idx]) + len(self.ocnli_data['s2'][idx]))
133 | for idx in ocemotion_cur:
134 | res = max(res, 2 + len(self.ocemotion_data['s1'][idx]))
135 | for idx in tnews_cur:
136 | res = max(res, 2 + len(self.tnews_data['s1'][idx]))
137 | return min(res, self.max_len)
--------------------------------------------------------------------------------
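`Data_generator.get_next_batch` above apportions each batch across the three tasks in proportion to how many example ids each task still has, with TNEWS absorbing the integer-rounding remainder. The quota arithmetic, worked with hypothetical remaining counts:

```
# Hypothetical remaining id counts; batchSize matches the generator's default.
ocnli_len, ocemotion_len, tnews_len, batchSize = 500, 300, 200, 64
total_len = ocnli_len + ocemotion_len + tnews_len

ocnli_quota = int((ocnli_len / total_len) * batchSize)          # 32
ocemotion_quota = int((ocemotion_len / total_len) * batchSize)  # 19
tnews_quota = batchSize - ocnli_quota - ocemotion_quota         # 13, absorbs rounding
print(ocnli_quota, ocemotion_quota, tnews_quota)                # 32 19 13
```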
/generate_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Dec 5 12:46:03 2020
5 |
6 | @author: luokai
7 | """
8 |
9 |
10 | import json
11 | from collections import defaultdict
12 | from math import log
13 |
14 | def split_dataset(dev_data_cnt=5000):
15 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']:
16 | cnt = 0
17 | with open('./tianchi_datasets/' + e + '/total.csv') as f:
18 | with open('./tianchi_datasets/' + e + '/train.csv', 'w') as f_train:
19 | with open('./tianchi_datasets/' + e + '/dev.csv', 'w') as f_dev:
20 | for line in f:
21 | cnt += 1
22 | if cnt <= dev_data_cnt:
23 | f_dev.write(line)
24 | else:
25 | f_train.write(line)
26 |
27 | def print_one_data(path, name, print_content=False):
28 | data_cnt = 0
29 | with open(path) as f:
30 | for line in f:
31 | tmp = json.loads(line)
32 | for _, v in tmp.items():
33 | data_cnt += 1
34 | if print_content:
35 | print(v)
36 | print(name, 'contains', data_cnt, 'records')
37 |
38 | def generate_data():
39 | label_set = dict()
40 | label_cnt_set = dict()
41 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']:
42 | label_set[e] = set()
43 | label_cnt_set[e] = defaultdict(int)
44 | with open('./tianchi_datasets/' + e + '/total.csv') as f:
45 | for line in f:
46 | label = line.strip().split('\t')[-1]
47 | label_set[e].add(label)
48 | label_cnt_set[e][label] += 1
49 | for k in label_set:
50 | label_set[k] = sorted(list(label_set[k]))
51 | for k, v in label_set.items():
52 | print(k, v)
53 | with open('./tianchi_datasets/label.json', 'w') as fw:
54 | fw.write(json.dumps(label_set))
55 | label_weight_set = dict()
56 | for k in label_set:
57 | label_weight_set[k] = [label_cnt_set[k][e] for e in label_set[k]]
58 | total_weight = sum(label_weight_set[k])
59 | label_weight_set[k] = [log(total_weight / e) for e in label_weight_set[k]]
60 | for k, v in label_weight_set.items():
61 | print(k, v)
62 | with open('./tianchi_datasets/label_weights.json', 'w') as fw:
63 | fw.write(json.dumps(label_weight_set))
64 |
65 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']:
66 | for name in ['dev', 'train']:
67 | with open('./tianchi_datasets/' + e + '/' + name + '.csv') as fr:
68 | with open('./tianchi_datasets/' + e + '/' + name + '.json', 'w') as fw:
69 | json_dict = dict()
70 | for line in fr:
71 | tmp_list = line.strip().split('\t')
72 | json_dict[tmp_list[0]] = dict()
73 | json_dict[tmp_list[0]]['s1'] = tmp_list[1]
74 | if e == 'OCNLI':
75 | json_dict[tmp_list[0]]['s2'] = tmp_list[2]
76 | json_dict[tmp_list[0]]['label'] = tmp_list[3]
77 | else:
78 | json_dict[tmp_list[0]]['label'] = tmp_list[2]
79 | fw.write(json.dumps(json_dict))
80 |
81 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']:
82 | for name in ['dev', 'train']:
83 | cur_path = './tianchi_datasets/' + e + '/' + name + '.json'
84 | data_name = e + '_' + name
85 | print_one_data(cur_path, data_name)
86 |
87 | print_one_data('./tianchi_datasets/label.json', 'label_set')
88 |
89 | if __name__ == '__main__':
90 | print('-------------------------------start-----------------------------------')
91 | split_dataset(dev_data_cnt=3000)
92 | generate_data()
93 | print('-------------------------------finish-----------------------------------')
--------------------------------------------------------------------------------
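`generate_data` above writes per-class weights of the form log(total / count), i.e. log-inverse-frequency, which train.py later passes to the weighted `nn.CrossEntropyLoss`. A small illustration with made-up class counts:

```
from math import log

counts = {'happy': 8000, 'sad': 1500, 'fear': 500}  # hypothetical label counts
total = sum(counts.values())
weights = {k: round(log(total / v), 3) for k, v in counts.items()}
print(weights)  # {'happy': 0.223, 'sad': 1.897, 'fear': 2.996} - rare classes weigh more
```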
/inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Dec 5 17:47:24 2020
5 |
6 | @author: luokai
7 | """
8 |
9 |
10 | from net import Net
11 | import json
12 | import torch
13 | import numpy as np
14 | from transformers import BertModel, BertTokenizer
15 | from utils import get_task_chinese
16 |
17 |
18 | def test_csv_to_json():
19 | for e in ['TNEWS', 'OCNLI', 'OCEMOTION']:
20 | with open('./tianchi_datasets/' + e + '/test.csv') as fr:
21 | with open('./tianchi_datasets/' + e + '/test.json', 'w') as fw:
22 | json_dict = dict()
23 | for line in fr:
24 | tmp_list = line.strip().split('\t')
25 | json_dict[tmp_list[0]] = dict()
26 | json_dict[tmp_list[0]]['s1'] = tmp_list[1]
27 | if e == 'OCNLI':
28 | json_dict[tmp_list[0]]['s2'] = tmp_list[2]
29 | fw.write(json.dumps(json_dict))
30 |
31 | def inference_wrapper(tokenizer_model):
32 | ocnli_test = dict()
33 | with open('./tianchi_datasets/OCNLI/test.json') as f:
34 | for line in f:
35 | ocnli_test = json.loads(line)
36 | break
37 |
38 | ocemotion_test = dict()
39 | with open('./tianchi_datasets/OCEMOTION/test.json') as f:
40 | for line in f:
41 | ocemotion_test = json.loads(line)
42 | break
43 |
44 | tnews_test = dict()
45 | with open('./tianchi_datasets/TNEWS/test.json') as f:
46 | for line in f:
47 | tnews_test = json.loads(line)
48 | break
49 |
50 | label_dict = dict()
51 | with open('./tianchi_datasets/label.json') as f:
52 | for line in f:
53 | label_dict = json.loads(line)
54 | break
55 |
56 | model = torch.load('./saved_best.pt')
57 | tokenizer = BertTokenizer.from_pretrained(tokenizer_model)
58 | inference('./submission/ocnli_predict.json', ocnli_test, model, tokenizer, label_dict['OCNLI'], 'ocnli', 'cuda:0', 64, True)
59 | inference('./submission/ocemotion_predict.json', ocemotion_test, model, tokenizer, label_dict['OCEMOTION'], 'ocemotion', 'cuda:0', 64, True)
60 | inference('./submission/tnews_predict.json', tnews_test, model, tokenizer, label_dict['TNEWS'], 'tnews', 'cuda:0', 64, True)
61 |
62 | def inference(path, data_dict, model, tokenizer, idx2label, task_type, device='cuda:0', batchSize=64, print_result=True):
63 | if task_type not in ('ocnli', 'ocemotion', 'tnews'):
64 | print('task_type is incorrect!')
65 | return
66 | model.to(device, non_blocking=True)
67 | model.eval()
68 | ids_list = [k for k, _ in data_dict.items()]
69 | next_start_ids = 0
70 | with torch.no_grad():
71 | with open(path, 'w') as f:
72 | while next_start_ids < len(ids_list):
73 | cur_ids_list = ids_list[next_start_ids: next_start_ids + batchSize]
74 | next_start_ids += batchSize
75 | if task_type == 'ocnli':
76 | flower = tokenizer([data_dict[idx]['s1'] for idx in cur_ids_list], [data_dict[idx]['s2'] for idx in cur_ids_list], add_special_tokens=True, padding=True, return_tensors='pt')
77 | else:
78 | flower = tokenizer([data_dict[idx]['s1'] for idx in cur_ids_list], add_special_tokens=True, padding=True, return_tensors='pt')
79 | input_ids = flower['input_ids'].to(device, non_blocking=True)
80 | token_type_ids = flower['token_type_ids'].to(device, non_blocking=True)
81 | attention_mask = flower['attention_mask'].to(device, non_blocking=True)
82 | ocnli_ids = torch.tensor([]).to(device, non_blocking=True)
83 | ocemotion_ids = torch.tensor([]).to(device, non_blocking=True)
84 | tnews_ids = torch.tensor([]).to(device, non_blocking=True)
85 | if task_type == 'ocnli':
86 | ocnli_ids = torch.tensor([i for i in range(len(cur_ids_list))]).to(device, non_blocking=True)
87 | elif task_type == 'ocemotion':
88 | ocemotion_ids = torch.tensor([i for i in range(len(cur_ids_list))]).to(device, non_blocking=True)
89 | else:
90 | tnews_ids = torch.tensor([i for i in range(len(cur_ids_list))]).to(device, non_blocking=True)
91 | ocnli_out, ocemotion_out, tnews_out = model(input_ids, ocnli_ids, ocemotion_ids, tnews_ids, token_type_ids, attention_mask)
92 | if task_type == 'ocnli':
93 | pred = torch.argmax(ocnli_out, axis=1)
94 | elif task_type == 'ocemotion':
95 | pred = torch.argmax(ocemotion_out, axis=1)
96 | else:
97 | pred = torch.argmax(tnews_out, axis=1)
98 | pred_final = [idx2label[e] for e in np.array(pred.cpu()).tolist()]
99 | #torch.cuda.empty_cache()
100 | for i, idx in enumerate(cur_ids_list):
101 | if print_result:
102 | print_str = '[ ' + task_type + ' : ' + 'sentence one: ' + data_dict[idx]['s1']
103 | if task_type == 'ocnli':
104 | print_str += '; sentence two: ' + data_dict[idx]['s2']
105 | print_str += '; result: ' + pred_final[i] + ' ]'
106 | print(print_str)
107 | single_result_dict = dict()
108 | single_result_dict['id'] = idx
109 | single_result_dict['label'] = pred_final[i]
110 | f.write(json.dumps(single_result_dict, ensure_ascii=False))
111 | if not (next_start_ids >= len(ids_list) and i == len(cur_ids_list) - 1):
112 | f.write('\n')
113 |
114 | if __name__ == '__main__':
115 | test_csv_to_json()
116 | print('---------------------------------start inference-----------------------------')
117 | inference_wrapper(tokenizer_model='./bert_pretrain_model')
118 |
119 |
120 |
121 |
--------------------------------------------------------------------------------
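`inference` above writes one JSON object per line (JSON Lines), which is what the zip in README step 6 packages. A minimal sketch for reading a prediction file back, assuming inference has already been run:

```
import json

with open('./submission/tnews_predict.json') as f:
    predictions = [json.loads(line) for line in f if line.strip()]
print(predictions[0])  # e.g. {'id': '0', 'label': '108'} (illustrative values)
```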
/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Dec 5 17:20:35 2020
5 |
6 | @author: luokai
7 | """
8 |
9 | import torch
10 | from torch import nn
11 | from transformers import BertModel
12 |
13 |
14 | class Net(nn.Module):
15 | def __init__(self, bert_model):
16 | super(Net, self).__init__()
17 | self.bert = bert_model
18 | self.atten_layer = nn.Linear(768, 16)
19 | self.softmax_d1 = nn.Softmax(dim=1)
20 | self.dropout = nn.Dropout(0.2)
21 | self.OCNLI_layer = nn.Linear(768, 16 * 3)
22 | self.OCEMOTION_layer = nn.Linear(768, 16 * 7)
23 | self.TNEWS_layer = nn.Linear(768, 16 * 15)
24 |
25 | def forward(self, input_ids, ocnli_ids, ocemotion_ids, tnews_ids, token_type_ids=None, attention_mask=None):
26 | cls_emb = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0][:, 0, :].squeeze(1)
27 | if ocnli_ids.size()[0] > 0:
28 | attention_score = self.atten_layer(cls_emb[ocnli_ids, :])
29 | attention_score = self.dropout(self.softmax_d1(attention_score).unsqueeze(1))
30 | ocnli_value = self.OCNLI_layer(cls_emb[ocnli_ids, :]).contiguous().view(-1, 16, 3)
31 | ocnli_out = torch.matmul(attention_score, ocnli_value).squeeze(1)
32 | else:
33 | ocnli_out = None
34 | if ocemotion_ids.size()[0] > 0:
35 | attention_score = self.atten_layer(cls_emb[ocemotion_ids, :])
36 | attention_score = self.dropout(self.softmax_d1(attention_score).unsqueeze(1))
37 | ocemotion_value = self.OCEMOTION_layer(cls_emb[ocemotion_ids, :]).contiguous().view(-1, 16, 7)
38 | ocemotion_out = torch.matmul(attention_score, ocemotion_value).squeeze(1)
39 | else:
40 | ocemotion_out = None
41 | if tnews_ids.size()[0] > 0:
42 | attention_score = self.atten_layer(cls_emb[tnews_ids, :])
43 | attention_score = self.dropout(self.softmax_d1(attention_score).unsqueeze(1))
44 | tnews_value = self.TNEWS_layer(cls_emb[tnews_ids, :]).contiguous().view(-1, 16, 15)
45 | tnews_out = torch.matmul(attention_score, tnews_value).squeeze(1)
46 | else:
47 | tnews_out = None
48 | return ocnli_out, ocemotion_out, tnews_out
--------------------------------------------------------------------------------
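Each task head in `Net` above is an attention-pooled mixture: the shared `atten_layer` turns the [CLS] embedding into 16 softmax scores, the task-specific layer produces 16 candidate logit vectors, and the head outputs their weighted sum. A shape walk-through with random, untrained tensors:

```
import torch
from torch import nn

cls_emb = torch.randn(4, 768)                  # 4 OCNLI rows of [CLS] embeddings
atten_layer = nn.Linear(768, 16)
ocnli_layer = nn.Linear(768, 16 * 3)

score = torch.softmax(atten_layer(cls_emb), dim=1).unsqueeze(1)  # (4, 1, 16)
value = ocnli_layer(cls_emb).view(-1, 16, 3)                     # (4, 16, 3)
logits = torch.matmul(score, value).squeeze(1)                   # (4, 3)
print(logits.shape)  # torch.Size([4, 3])
```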
/submission/Dockerfile:
--------------------------------------------------------------------------------
1 | # Base Images
2 | ## 从天池基础镜像构建
3 | FROM registry.cn-shanghai.aliyuncs.com/tcc-public/python:3
4 |
5 | ## 把当前文件夹里的文件构建到镜像的根目录下
6 | ADD . /
7 |
8 | ## 指定默认工作目录为根目录(需要把run.sh和生成的结果文件都放在该文件夹下,提交后才能运行)
9 | WORKDIR /
10 |
11 | ## 镜像启动后统一执行 sh run.sh
12 | CMD ["sh", "run.sh"]
--------------------------------------------------------------------------------
/submission/empty.txt:
--------------------------------------------------------------------------------
1 | 1
--------------------------------------------------------------------------------
/submission/run.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/finlay-liu/tianchi-multi-task-nlp/261ec1b2fa611112cc313c048681a9a141141f04/submission/run.sh
--------------------------------------------------------------------------------
/tianchi_datasets/OCEMOTION/empty.txt:
--------------------------------------------------------------------------------
1 | 1
--------------------------------------------------------------------------------
/tianchi_datasets/OCNLI/empty.txt:
--------------------------------------------------------------------------------
1 | 1
--------------------------------------------------------------------------------
/tianchi_datasets/TNEWS/empty.txt:
--------------------------------------------------------------------------------
1 | 1
--------------------------------------------------------------------------------
/tianchi_datasets/empty.txt:
--------------------------------------------------------------------------------
1 | 1
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Dec 5 17:34:27 2020
5 |
6 | @author: luokai
7 | """
8 |
9 | import torch
10 | from transformers import BertModel, BertTokenizer
11 | import json
12 | from utils import get_f1, print_result, load_pretrained_model, load_tokenizer
13 | from net import Net
14 | from data_generator import Data_generator
15 | from calculate_loss import Calculate_loss
16 |
17 |
18 | def train(epochs=20, batchSize=64, lr=0.0001, device='cuda:0', accumulate=True, a_step=16, load_saved=False, file_path='./saved_best.pt', use_dtp=False, pretrained_model='./bert_pretrain_model', tokenizer_model='bert-base-chinese', weighted_loss=False):
19 |
20 | tokenizer = load_tokenizer(tokenizer_model)
21 | my_net = torch.load(file_path) if load_saved else Net(load_pretrained_model(pretrained_model))
22 | my_net.to(device, non_blocking=True)
23 | label_dict = dict()
24 | with open('./tianchi_datasets/label.json') as f:
25 | for line in f:
26 | label_dict = json.loads(line)
27 | break
28 | label_weights_dict = dict()
29 | with open('./tianchi_datasets/label_weights.json') as f:
30 | for line in f:
31 | label_weights_dict = json.loads(line)
32 | break
33 | ocnli_train = dict()
34 | with open('./tianchi_datasets/OCNLI/train.json') as f:
35 | for line in f:
36 | ocnli_train = json.loads(line)
37 | break
38 | ocnli_dev = dict()
39 | with open('./tianchi_datasets/OCNLI/dev.json') as f:
40 | for line in f:
41 | ocnli_dev = json.loads(line)
42 | break
43 | ocemotion_train = dict()
44 | with open('./tianchi_datasets/OCEMOTION/train.json') as f:
45 | for line in f:
46 | ocemotion_train = json.loads(line)
47 | break
48 | ocemotion_dev = dict()
49 | with open('./tianchi_datasets/OCEMOTION/dev.json') as f:
50 | for line in f:
51 | ocemotion_dev = json.loads(line)
52 | break
53 | tnews_train = dict()
54 | with open('./tianchi_datasets/TNEWS/train.json') as f:
55 | for line in f:
56 | tnews_train = json.loads(line)
57 | break
58 | tnews_dev = dict()
59 | with open('./tianchi_datasets/TNEWS/dev.json') as f:
60 | for line in f:
61 | tnews_dev = json.loads(line)
62 | break
63 | train_data_generator = Data_generator(ocnli_train, ocemotion_train, tnews_train, label_dict, device, tokenizer)
64 | dev_data_generator = Data_generator(ocnli_dev, ocemotion_dev, tnews_dev, label_dict, device, tokenizer)
65 | tnews_weights = torch.tensor(label_weights_dict['TNEWS']).to(device, non_blocking=True)
66 | ocnli_weights = torch.tensor(label_weights_dict['OCNLI']).to(device, non_blocking=True)
67 | ocemotion_weights = torch.tensor(label_weights_dict['OCEMOTION']).to(device, non_blocking=True)
68 | loss_object = Calculate_loss(label_dict, weighted=weighted_loss, tnews_weights=tnews_weights, ocnli_weights=ocnli_weights, ocemotion_weights=ocemotion_weights)
69 | optimizer=torch.optim.Adam(my_net.parameters(), lr=lr)
70 | best_dev_f1 = 0.0
71 | best_epoch = -1
72 | for epoch in range(epochs):
73 | my_net.train()
74 | train_loss = 0.0
75 | train_total = 0
76 | train_correct = 0
77 | train_ocnli_correct = 0
78 | train_ocemotion_correct = 0
79 | train_tnews_correct = 0
80 | train_ocnli_pred_list = []
81 | train_ocnli_gold_list = []
82 | train_ocemotion_pred_list = []
83 | train_ocemotion_gold_list = []
84 | train_tnews_pred_list = []
85 | train_tnews_gold_list = []
86 | cnt_train = 0
87 | while True:
88 | raw_data = train_data_generator.get_next_batch(batchSize)
89 | if raw_data is None:
90 | break
91 | data = dict()
92 | data['input_ids'] = raw_data['input_ids']
93 | data['token_type_ids'] = raw_data['token_type_ids']
94 | data['attention_mask'] = raw_data['attention_mask']
95 | data['ocnli_ids'] = raw_data['ocnli_ids']
96 | data['ocemotion_ids'] = raw_data['ocemotion_ids']
97 | data['tnews_ids'] = raw_data['tnews_ids']
98 | tnews_gold = raw_data['tnews_gold']
99 | ocnli_gold = raw_data['ocnli_gold']
100 | ocemotion_gold = raw_data['ocemotion_gold']
101 | if not accumulate:
102 | optimizer.zero_grad()
103 | ocnli_pred, ocemotion_pred, tnews_pred = my_net(**data)
104 | if use_dtp:
105 | tnews_kpi = 0.1 if len(train_tnews_pred_list) == 0 else train_tnews_correct / len(train_tnews_pred_list)
106 | ocnli_kpi = 0.1 if len(train_ocnli_pred_list) == 0 else train_ocnli_correct / len(train_ocnli_pred_list)
107 | ocemotion_kpi = 0.1 if len(train_ocemotion_pred_list) == 0 else train_ocemotion_correct / len(train_ocemotion_pred_list)
108 | current_loss = loss_object.compute_dtp(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold,
109 | ocemotion_gold, tnews_kpi, ocnli_kpi, ocemotion_kpi)
110 | else:
111 | current_loss = loss_object.compute(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
112 | train_loss += current_loss.item()
113 | current_loss.backward()
114 | if accumulate and (cnt_train + 1) % a_step == 0:
115 | optimizer.step()
116 | optimizer.zero_grad()
117 | if not accumulate:
118 | optimizer.step()
119 | if use_dtp:
120 | good_tnews_nb, good_ocnli_nb, good_ocemotion_nb, total_tnews_nb, total_ocnli_nb, total_ocemotion_nb = loss_object.correct_cnt_each(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
121 | tmp_good = sum([good_tnews_nb, good_ocnli_nb, good_ocemotion_nb])
122 | tmp_total = sum([total_tnews_nb, total_ocnli_nb, total_ocemotion_nb])
123 | train_ocemotion_correct += good_ocemotion_nb
124 | train_ocnli_correct += good_ocnli_nb
125 | train_tnews_correct += good_tnews_nb
126 | else:
127 | tmp_good, tmp_total = loss_object.correct_cnt(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
128 | train_correct += tmp_good
129 | train_total += tmp_total
130 | p, g = loss_object.collect_pred_and_gold(ocnli_pred, ocnli_gold)
131 | train_ocnli_pred_list += p
132 | train_ocnli_gold_list += g
133 | p, g = loss_object.collect_pred_and_gold(ocemotion_pred, ocemotion_gold)
134 | train_ocemotion_pred_list += p
135 | train_ocemotion_gold_list += g
136 | p, g = loss_object.collect_pred_and_gold(tnews_pred, tnews_gold)
137 | train_tnews_pred_list += p
138 | train_tnews_gold_list += g
139 | cnt_train += 1
140 | #torch.cuda.empty_cache()
141 | if (cnt_train + 1) % 1000 == 0:
142 | print('[', cnt_train + 1, '- th batch : train acc is:', train_correct / train_total, '; train loss is:', train_loss / cnt_train, ']')
143 | if accumulate:
144 | optimizer.step()
145 | optimizer.zero_grad()
146 | train_ocnli_f1 = get_f1(train_ocnli_gold_list, train_ocnli_pred_list)
147 | train_ocemotion_f1 = get_f1(train_ocemotion_gold_list, train_ocemotion_pred_list)
148 | train_tnews_f1 = get_f1(train_tnews_gold_list, train_tnews_pred_list)
149 | train_avg_f1 = (train_ocnli_f1 + train_ocemotion_f1 + train_tnews_f1) / 3
150 | print(epoch, 'th epoch train average f1 is:', train_avg_f1)
151 | print(epoch, 'th epoch train ocnli is below:')
152 | print_result(train_ocnli_gold_list, train_ocnli_pred_list)
153 | print(epoch, 'th epoch train ocemotion is below:')
154 | print_result(train_ocemotion_gold_list, train_ocemotion_pred_list)
155 | print(epoch, 'th epoch train tnews is below:')
156 | print_result(train_tnews_gold_list, train_tnews_pred_list)
157 |
158 | train_data_generator.reset()
159 |
160 | my_net.eval()
161 | dev_loss = 0.0
162 | dev_total = 0
163 | dev_correct = 0
164 | dev_ocnli_correct = 0
165 | dev_ocemotion_correct = 0
166 | dev_tnews_correct = 0
167 | dev_ocnli_pred_list = []
168 | dev_ocnli_gold_list = []
169 | dev_ocemotion_pred_list = []
170 | dev_ocemotion_gold_list = []
171 | dev_tnews_pred_list = []
172 | dev_tnews_gold_list = []
173 | cnt_dev = 0
174 | with torch.no_grad():
175 | while True:
176 | raw_data = dev_data_generator.get_next_batch(batchSize)
177 | if raw_data is None:
178 | break
179 | data = dict()
180 | data['input_ids'] = raw_data['input_ids']
181 | data['token_type_ids'] = raw_data['token_type_ids']
182 | data['attention_mask'] = raw_data['attention_mask']
183 | data['ocnli_ids'] = raw_data['ocnli_ids']
184 | data['ocemotion_ids'] = raw_data['ocemotion_ids']
185 | data['tnews_ids'] = raw_data['tnews_ids']
186 | tnews_gold = raw_data['tnews_gold']
187 | ocnli_gold = raw_data['ocnli_gold']
188 | ocemotion_gold = raw_data['ocemotion_gold']
189 | ocnli_pred, ocemotion_pred, tnews_pred = my_net(**data)
190 | if use_dtp:
191 | tnews_kpi = 0.1 if len(dev_tnews_pred_list) == 0 else dev_tnews_correct / len(
192 | dev_tnews_pred_list)
193 | ocnli_kpi = 0.1 if len(dev_ocnli_pred_list) == 0 else dev_ocnli_correct / len(
194 | dev_ocnli_pred_list)
195 | ocemotion_kpi = 0.1 if len(dev_ocemotion_pred_list) == 0 else dev_ocemotion_correct / len(
196 | dev_ocemotion_pred_list)
197 | current_loss = loss_object.compute_dtp(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold,
198 | ocnli_gold,
199 | ocemotion_gold, tnews_kpi, ocnli_kpi, ocemotion_kpi)
200 | else:
201 | current_loss = loss_object.compute(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
202 | dev_loss += current_loss.item()
203 | if use_dtp:
204 | good_tnews_nb, good_ocnli_nb, good_ocemotion_nb, total_tnews_nb, total_ocnli_nb, total_ocemotion_nb = loss_object.correct_cnt_each(
205 | tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
206 | tmp_good = sum([good_tnews_nb, good_ocnli_nb, good_ocemotion_nb])
207 | tmp_total = sum([total_tnews_nb, total_ocnli_nb, total_ocemotion_nb])
208 | dev_ocemotion_correct += good_ocemotion_nb
209 | dev_ocnli_correct += good_ocnli_nb
210 | dev_tnews_correct += good_tnews_nb
211 | else:
212 | tmp_good, tmp_total = loss_object.correct_cnt(tnews_pred, ocnli_pred, ocemotion_pred, tnews_gold, ocnli_gold, ocemotion_gold)
213 | dev_correct += tmp_good
214 | dev_total += tmp_total
215 | p, g = loss_object.collect_pred_and_gold(ocnli_pred, ocnli_gold)
216 | dev_ocnli_pred_list += p
217 | dev_ocnli_gold_list += g
218 | p, g = loss_object.collect_pred_and_gold(ocemotion_pred, ocemotion_gold)
219 | dev_ocemotion_pred_list += p
220 | dev_ocemotion_gold_list += g
221 | p, g = loss_object.collect_pred_and_gold(tnews_pred, tnews_gold)
222 | dev_tnews_pred_list += p
223 | dev_tnews_gold_list += g
224 | cnt_dev += 1
225 | #torch.cuda.empty_cache()
226 | #if (cnt_dev + 1) % 1000 == 0:
227 | # print('[', cnt_dev + 1, '- th batch : dev acc is:', dev_correct / dev_total, '; dev loss is:', dev_loss / cnt_dev, ']')
228 | dev_ocnli_f1 = get_f1(dev_ocnli_gold_list, dev_ocnli_pred_list)
229 | dev_ocemotion_f1 = get_f1(dev_ocemotion_gold_list, dev_ocemotion_pred_list)
230 | dev_tnews_f1 = get_f1(dev_tnews_gold_list, dev_tnews_pred_list)
231 | dev_avg_f1 = (dev_ocnli_f1 + dev_ocemotion_f1 + dev_tnews_f1) / 3
232 | print(epoch, 'th epoch dev average f1 is:', dev_avg_f1)
233 | print(epoch, 'th epoch dev ocnli is below:')
234 | print_result(dev_ocnli_gold_list, dev_ocnli_pred_list)
235 | print(epoch, 'th epoch dev ocemotion is below:')
236 | print_result(dev_ocemotion_gold_list, dev_ocemotion_pred_list)
237 | print(epoch, 'th epoch dev tnews is below:')
238 | print_result(dev_tnews_gold_list, dev_tnews_pred_list)
239 |
240 | dev_data_generator.reset()
241 |
242 | if dev_avg_f1 > best_dev_f1:
243 | best_dev_f1 = dev_avg_f1
244 | best_epoch = epoch
245 | torch.save(my_net, file_path)
246 | print('best epoch is:', best_epoch, '; with best f1 is:', best_dev_f1)
247 |
248 | if __name__ == '__main__':
249 | print('---------------------start training-----------------------')
250 | pretrained_model = './bert_pretrain_model'
251 | tokenizer_model = './bert_pretrain_model'
252 | train(batchSize=16, device='cuda:0', lr=0.0001, use_dtp=True, pretrained_model=pretrained_model, tokenizer_model=tokenizer_model, weighted_loss=True)
--------------------------------------------------------------------------------
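With `accumulate=True`, `train` above steps the optimizer only every `a_step` mini-batches, so gradients from batchSize=16 accumulate into an effective batch of 16 × 16 = 256. The schedule, reduced to a self-contained sketch with a toy model (all names here are illustrative stand-ins, not repo code):

```
import torch
from torch import nn

model = nn.Linear(4, 2)                            # toy stand-in for Net
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
a_step = 16
batches = [torch.randn(16, 4) for _ in range(40)]  # 40 hypothetical mini-batches

for cnt, batch in enumerate(batches):
    loss = model(batch).pow(2).mean()              # hypothetical loss
    loss.backward()                                # grads sum into .grad across batches
    if (cnt + 1) % a_step == 0:
        optimizer.step()                           # one update per a_step mini-batches
        optimizer.zero_grad()
optimizer.step()                                   # flush the remainder, as train() does
optimizer.zero_grad()
```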
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Dec 5 17:31:42 2020
5 |
6 | @author: luokai
7 | """
8 |
9 | from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, classification_report, f1_score
10 | from transformers import BertModel, BertTokenizer
11 |
12 |
13 | def get_f1(l_t, l_p):
14 | marco_f1_score = f1_score(l_t, l_p, average='macro')
15 | return marco_f1_score
16 |
17 | def print_result(l_t, l_p):
18 | marco_f1_score = f1_score(l_t, l_p, average='macro')
19 | print(marco_f1_score)
20 | print(f"{'confusion_matrix':*^80}")
21 | print(confusion_matrix(l_t, l_p, ))
22 | print(f"{'classification_report':*^80}")
23 | print(classification_report(l_t, l_p, ))
24 |
25 | def load_tokenizer(path_or_name):
26 | return BertTokenizer.from_pretrained(path_or_name)
27 |
28 | def load_pretrained_model(path_or_name):
29 | return BertModel.from_pretrained(path_or_name)
30 |
31 | def get_task_chinese(task_type):
32 | if task_type == 'ocnli':
33 | return '(中文原版自然语言推理)'
34 | elif task_type == 'ocemotion':
35 | return '(中文情感分类)'
36 | else:
37 | return '(今日头条新闻标题分类)'
--------------------------------------------------------------------------------
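`get_f1` above uses macro F1, the unweighted mean of per-class F1 scores, so minority classes count as much as majority ones; the leaderboard metric is the average of the three tasks' macro F1. A tiny check of what `average='macro'` means:

```
from sklearn.metrics import f1_score

gold = ['a', 'a', 'b', 'c']
pred = ['a', 'b', 'b', 'c']
# per-class F1: a=2/3, b=2/3, c=1 -> macro mean = 7/9
print(f1_score(gold, pred, average='macro'))  # 0.777...
```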