├── image └── train.png ├── requirements.txt ├── output └── shopee │ ├── supervise │ └── bsz-64-lr-3e-05-dropout-0.1-threshold-0.3 │ │ └── events.out.tfevents.1643381019.4bd8c2e5-6aaf-420a-8f64-4423fdb9501b.34765.0 │ └── unsupervise │ ├── bsz-64-lr-3e-05-dropout-0.1-threshold-0.3 │ └── events.out.tfevents.1643388993.f6f34af9-8f20-4a33-8780-ba675caea906.4917.0 │ ├── bsz-64-lr-3e-05-dropout-0.2-threshold-0.3 │ └── events.out.tfevents.1643389051.f6f34af9-8f20-4a33-8780-ba675caea906.5130.0 │ └── bsz-64-lr-3e-05-dropout-0.3-threshold-0.3 │ └── events.out.tfevents.1643448238.f6f34af9-8f20-4a33-8780-ba675caea906.90177.0 ├── dataset.py ├── script ├── run_sup_train.sh └── run_unsup_train.sh ├── preprocess_data.py ├── model.py ├── README.md └── train.py /image/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangjianxin1/Shopee-Price-Match-Guarantee/HEAD/image/train.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8 2 | transformers==3.1.0 3 | loguru 4 | numpy 5 | sklearn 6 | pandas 7 | tensorboard -------------------------------------------------------------------------------- /output/shopee/supervise/bsz-64-lr-3e-05-dropout-0.1-threshold-0.3/events.out.tfevents.1643381019.4bd8c2e5-6aaf-420a-8f64-4423fdb9501b.34765.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangjianxin1/Shopee-Price-Match-Guarantee/HEAD/output/shopee/supervise/bsz-64-lr-3e-05-dropout-0.1-threshold-0.3/events.out.tfevents.1643381019.4bd8c2e5-6aaf-420a-8f64-4423fdb9501b.34765.0 -------------------------------------------------------------------------------- /output/shopee/unsupervise/bsz-64-lr-3e-05-dropout-0.1-threshold-0.3/events.out.tfevents.1643388993.f6f34af9-8f20-4a33-8780-ba675caea906.4917.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangjianxin1/Shopee-Price-Match-Guarantee/HEAD/output/shopee/unsupervise/bsz-64-lr-3e-05-dropout-0.1-threshold-0.3/events.out.tfevents.1643388993.f6f34af9-8f20-4a33-8780-ba675caea906.4917.0 -------------------------------------------------------------------------------- /output/shopee/unsupervise/bsz-64-lr-3e-05-dropout-0.2-threshold-0.3/events.out.tfevents.1643389051.f6f34af9-8f20-4a33-8780-ba675caea906.5130.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangjianxin1/Shopee-Price-Match-Guarantee/HEAD/output/shopee/unsupervise/bsz-64-lr-3e-05-dropout-0.2-threshold-0.3/events.out.tfevents.1643389051.f6f34af9-8f20-4a33-8780-ba675caea906.5130.0 -------------------------------------------------------------------------------- /output/shopee/unsupervise/bsz-64-lr-3e-05-dropout-0.3-threshold-0.3/events.out.tfevents.1643448238.f6f34af9-8f20-4a33-8780-ba675caea906.90177.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangjianxin1/Shopee-Price-Match-Guarantee/HEAD/output/shopee/unsupervise/bsz-64-lr-3e-05-dropout-0.3-threshold-0.3/events.out.tfevents.1643448238.f6f34af9-8f20-4a33-8780-ba675caea906.90177.0 -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data 
import Dataset, DataLoader 2 | 3 | 4 | class TrainDataset(Dataset): 5 | def __init__(self, data, tokenizer, max_len=256): 6 | self.data = data 7 | 8 | def __len__(self): 9 | return len(self.data) 10 | 11 | def __getitem__(self, index): 12 | return self.data[index] 13 | 14 | 15 | class TestDataset(Dataset): 16 | def __init__(self, data, tokenizer, max_len=256): 17 | self.data = data 18 | 19 | def __len__(self): 20 | return len(self.data) 21 | 22 | def __getitem__(self, index): 23 | return self.data[index] 24 | -------------------------------------------------------------------------------- /script/run_sup_train.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --device gpu \ 3 | --output_path output \ 4 | --lr 3e-5 \ 5 | --dropout 0.1 \ 6 | --epochs 10 \ 7 | --batch_size_train 64 \ 8 | --batch_size_eval 256 \ 9 | --num_workers 0 \ 10 | --eval_step 50 \ 11 | --max_len 150 \ 12 | --seed 42 \ 13 | --train_file data/shopee/train.csv \ 14 | --dev_file data/shopee/dev.csv \ 15 | --test_file data/shopee/test.csv \ 16 | --pretrain_model_path pretrain_model/bert-base-uncased \ 17 | --pooler cls \ 18 | --train_mode supervise \ 19 | --overwrite_cache \ 20 | --do_train \ 21 | --do_predict -------------------------------------------------------------------------------- /script/run_unsup_train.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --device gpu \ 3 | --output_path output \ 4 | --lr 3e-5 \ 5 | --dropout 0.2 \ 6 | --epochs 10 \ 7 | --batch_size_train 64 \ 8 | --batch_size_eval 256 \ 9 | --num_workers 0 \ 10 | --eval_step 50 \ 11 | --max_len 150 \ 12 | --seed 42 \ 13 | --train_file data/shopee/train.csv \ 14 | --dev_file data/shopee/dev.csv \ 15 | --test_file data/shopee/test.csv \ 16 | --pretrain_model_path pretrain_model/bert-base-uncased \ 17 | --pooler cls \ 18 | --train_mode unsupervise \ 19 | --overwrite_cache \ 20 | --do_train \ 21 | --do_predict -------------------------------------------------------------------------------- /preprocess_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from collections import defaultdict 3 | from tqdm import tqdm 4 | import random 5 | from loguru import logger 6 | 7 | 8 | def split_data(input_file, train_file, dev_file, test_file): 9 | """ 10 | 将训练数据,切分为train/dev/test三份数据 11 | """ 12 | dev_size = 500 13 | test_size = 1000 14 | df = pd.read_csv(input_file, sep=',') 15 | logger.info("len of input data:{}".format(len(df))) 16 | label2rows = defaultdict(list) 17 | rows = df.to_dict('records') 18 | 19 | # 收集每个label_group下的数据集合 20 | for row in tqdm(rows): 21 | label = row['label_group'] 22 | label2rows[label].append(row) 23 | label2rows = list(label2rows.items()) 24 | random.shuffle(label2rows) 25 | 26 | # 保存切分后的数据 27 | dev_rows = [] 28 | for label, rows in label2rows[:dev_size]: 29 | dev_rows += rows 30 | df_dev = pd.DataFrame(dev_rows) 31 | df_dev.to_csv(dev_file) 32 | 33 | test_rows = [] 34 | for label, rows in label2rows[dev_size: dev_size + test_size]: 35 | test_rows += rows 36 | df_test = pd.DataFrame(test_rows) 37 | df_test.to_csv(test_file) 38 | 39 | train_rows = [] 40 | for label, rows in label2rows[dev_size + test_size:]: 41 | train_rows += rows 42 | df_train = pd.DataFrame(train_rows) 43 | df_train.to_csv(train_file) 44 | print('dev len:{}'.format(len(dev_rows))) 45 | print('test len:{}'.format(len(test_rows))) 46 | print('train len:{}'.format(len(train_rows))) 47 

if __name__ == '__main__':
    input_file = 'data/shopee-product-matching/train.csv'  # original competition training set
    train_file = 'data/shopee/train.csv'  # training split
    dev_file = 'data/shopee/dev.csv'  # dev split
    test_file = 'data/shopee/test.csv'  # test split
    split_data(input_file, train_file, dev_file, test_file)

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertModel, BertConfig, BertTokenizer


class SimcseModel(nn.Module):
    """SimCSE model definition"""

    def __init__(self, pretrained_model, pooling, dropout=0.3):
        super(SimcseModel, self).__init__()
        config = BertConfig.from_pretrained(pretrained_model)
        config.attention_probs_dropout_prob = dropout  # override the dropout rates of the config
        config.hidden_dropout_prob = dropout
        self.bert = BertModel.from_pretrained(pretrained_model, config=config)
        self.pooling = pooling

    def forward(self, input_ids, attention_mask, token_type_ids):
        out = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True, return_dict=True)
        if self.pooling == 'cls':
            return out.last_hidden_state[:, 0]  # [batch, 768]
        if self.pooling == 'pooler':
            return out.pooler_output  # [batch, 768]
        if self.pooling == 'last-avg':
            last = out.last_hidden_state.transpose(1, 2)  # [batch, 768, seqlen]
            return torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [batch, 768]
        if self.pooling == 'first-last-avg':
            first = out.hidden_states[1].transpose(1, 2)  # [batch, 768, seqlen]
            last = out.hidden_states[-1].transpose(1, 2)  # [batch, 768, seqlen]
            first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1)  # [batch, 768]
            last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [batch, 768]
            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)  # [batch, 2, 768]
            return torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)  # [batch, 768]


def simcse_unsup_loss(y_pred, device, temp=0.05):
    """Unsupervised SimCSE loss.
    y_pred (tensor): BERT output, [batch_size * 2, 768]
    """
    # Build the labels for y_pred: [1, 0, 3, 2, ..., batch_size-1, batch_size-2],
    # i.e. each sample's positive is the copy of itself sitting right next to it.
    y_true = torch.arange(y_pred.shape[0], device=device)
    y_true = (y_true - y_true % 2 * 2) + 1
    # Pairwise cosine similarities within the batch -> similarity matrix
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    # Push the diagonal down to a very small value so that a sample never matches itself
    sim = sim - torch.eye(y_pred.shape[0], device=device) * 1e12
    # Divide the similarity matrix by the temperature
    sim = sim / temp
    # Cross entropy between the similarity matrix and y_true: every sample is scored
    # against all other samples, and the goal is that its positive gets the highest
    # score and its negatives the lowest.
    loss = F.cross_entropy(sim, y_true)
    return torch.mean(loss)


def simcse_sup_loss(y_pred, device, lamda=0.05):
    """
    Supervised SimCSE loss.
    """
    similarities = F.cosine_similarity(y_pred.unsqueeze(0), y_pred.unsqueeze(1), dim=2)
    row = torch.arange(0, y_pred.shape[0], 3)
    col = torch.arange(0, y_pred.shape[0])
    col = col[col % 3 != 0]

    similarities = similarities[row, :]
    similarities = similarities[:, col]
    similarities = similarities / lamda

    y_true = torch.arange(0, len(col), 2, device=device)
    loss = F.cross_entropy(similarities, y_true)
    return loss

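# Illustration of the indexing above for a toy batch of 6 rows laid out as
# (anchor, positive, negative) triplets (the concrete numbers are only an example):
#   row    = [0, 3]        -> the anchor rows
#   col    = [1, 2, 4, 5]  -> all positives and negatives, anchors excluded
#   y_true = [0, 2]        -> anchor 0 should score column 1 (its positive) highest,
#                             anchor 3 should score column 4 (its positive) highest.
# Note that train.py builds pairs rather than triplets and therefore calls
# simcse_unsup_loss for both training modes; simcse_sup_loss is only exercised
# by the quick test below.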
if __name__ == '__main__':
    y_pred = torch.rand((30, 16))
    loss = simcse_sup_loss(y_pred, 'cpu', lamda=0.05)
    print(loss)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Shopee - Price Match Guarantee (matching identical Shopee products)

## Competition overview
[Shopee - Price Match Guarantee](https://www.kaggle.com/c/shopee-product-matching/overview) is a Kaggle competition hosted by Shopee on matching listings of the same product,
useful for duplicate-product matching, price comparison and similar scenarios.

### Dataset
The training set (34,250 rows) has the following format; every row contains 5 fields:
- posting_id: ID of the listing
- image: image file name of the listing
- image_phash: perceptual hash of the listing image
- title: title of the listing
- label_group: product group; listings with the same label_group are the same product

The test set (3 sample rows) has the following format:
- posting_id: ID of the listing
- image: image file name of the listing
- image_phash: perceptual hash of the listing image
- title: title of the listing

### Evaluation metric
For every listing in the test set, predict the list of posting_ids of the same product, capped at 50 entries. The per-listing F1 scores are averaged to obtain the final F1 score.


## Project overview
To verify the effectiveness of contrastive learning, this project models title similarity with contrastive learning. Only the title text is used; the product images are not used.

Training objective: for every sample in a batch, pull it closer to its positive samples and push it away from its negative samples, so that the model learns the similarity relations between texts.

The hard part of contrastive learning is constructing the positive samples. This project compares a supervised and an unsupervised setting:
- Supervised contrastive training: samples with the same label_group are treated as positives of each other
- Unsupervised contrastive training: for every title, dropout is used as augmentation to create its positive sample


## Data processing
The competition only releases the training set; the test set is a hidden, black-box test set. To validate the approach locally, the original training set is split into three parts: train, dev and test (these names are used in the rest of this document).

The original training set contains 34,250 rows and 11,014 label_groups. We randomly take the data of 500 label_groups as the dev set and the data of 1,000 label_groups as the test set, and keep the remaining 9,514 label_groups as the final training set.
Resulting data distribution:
- train: 29,693 rows belonging to 9,514 label_groups
- dev: 1,499 rows belonging to 500 label_groups
- test: 3,058 rows belonging to 1,000 label_groups

You can run preprocess_data.py to split the data yourself (the split is random, so every run produces a different split), or directly use the pre-split data under data/shopee; the experimental results of this project are based on the data in that directory.


## Environment
python==3.6, transformers==3.1.0, torch==1.8.0

Install the dependencies with:
```
pip install -r requirements.txt
```


## Project structure
- data: training data
  - shopee: the Shopee dataset
    - dev.csv: dev set
    - test.csv: test set
    - train.csv: training set
- output: output directory, containing training logs, tensorboard event files and saved checkpoints
- pretrain_model: location of the pretrained model
- script: training scripts
- dataset.py: dataset definitions
- model.py: model code
- train.py: training code
- preprocess_data.py: splits the dataset


## Usage
### Quick Start

Unsupervised training:
```
bash script/run_unsup_train.sh
```
Supervised training:
```
bash script/run_sup_train.sh
```

## Experiment summary
Title similarity is measured with cosine distance. During training the threshold is set to 0.3, i.e. two listings are considered the same product when the cosine distance between their titles is below 0.3. Dev-set F1, Precision and Recall are computed with this rule and used to decide when to save a checkpoint.
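A minimal sketch of this evaluation rule (the function name and arguments are illustrative: `embeddings` is assumed to be the matrix of encoded title vectors and `labels` the corresponding label_group values; the full logic lives in evaluate() in train.py):
```python
import numpy as np
from sklearn.metrics.pairwise import cosine_distances

def mean_f1(embeddings, labels, threshold=0.3):
    """Predict, for every title, all titles within the cosine-distance threshold,
    then average the per-title F1 scores."""
    labels = np.asarray(labels)
    distances = cosine_distances(embeddings, embeddings)
    f1_scores = []
    for k in range(len(labels)):
        pred_idx = np.where(distances[k] < threshold)[0]        # predicted matches (always contains k itself)
        num_right = int(np.sum(labels[pred_idx] == labels[k]))  # correctly matched listings
        num_label = int(np.sum(labels == labels[k]))            # true size of this label_group
        precision = num_right / len(pred_idx)
        recall = num_right / num_label
        f1_scores.append(2 * precision * recall / (precision + recall))
    return float(np.mean(f1_scores))
```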
### Supervised vs. unsupervised training
The table below shows the results of each model with threshold=0.3. From it we can draw the following conclusions:
- Supervised training models title similarity much better; its results are far better than those of unsupervised training.
- For unsupervised training, dropout=0.3 works better than 0.1 and 0.2. (Tuning this is somewhat empirical: in the [SimCSE reproduction](https://github.com/yangjianxin1/SimCSE) project we concluded that dropout=0.2 was best.)

| Training method | learning rate | batch size | dropout | save step | dev F1 | dev Precision | dev Recall | test F1 | test Precision | test Recall |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| supervised | 3e-5 | 64 | 0.1 | 7150 | __0.847__ | 0.901 | 0.860 | __0.791__ | 0.809 | 0.876 |
| unsupervised | 3e-5 | 64 | 0.1 | 50 | 0.638 | 0.967 | 0.532 | 0.628 | 0.919 | 0.554 |
| unsupervised | 3e-5 | 64 | 0.2 | 300 | 0.670 | 0.976 | 0.566 | 0.674 | 0.954 | 0.586 |
| unsupervised | 3e-5 | 64 | 0.3 | 350 | 0.689 | 0.970 | 0.594 | 0.680 | 0.939 | 0.604 |


### Threshold analysis
For supervised training, the effect of the threshold on the test-set results is as follows:

| threshold | test F1 | test Precision | test Recall |
| ---- | ---- | ---- | ---- |
| 0.4 | 0.704 | 0.642 | 0.932 |
| 0.3 | 0.791 | 0.809 | 0.876 |
| 0.27 | 0.793 | 0.837 | 0.849 |
| 0.26 | 0.796 | 0.849 | 0.842 |
| __0.25__ | __0.797__ | 0.859 | 0.834 |
| 0.24 | 0.796 | 0.868 | 0.824 |
| 0.23 | 0.794 | 0.876 | 0.813 |
| 0.2 | 0.781 | 0.900 | 0.770 |
| 0.1 | 0.689 | 0.968 | 0.597 |


### Training curves
The figure below shows how the dev-set F1 of each model evolves during training; supervised training clearly outperforms unsupervised training:

| Color | Training method | dropout |
| ---- | ---- | ---- |
| orange | supervised | 0.1 |
| green | unsupervised | 0.3 |
| pink | unsupervised | 0.2 |
| blue | unsupervised | 0.1 |

![avatar](./image/train.png)


## REFERENCE
- https://github.com/yangjianxin1/SimCSE
- https://arxiv.org/pdf/2104.08821.pdf
- https://github.com/princeton-nlp/SimCSE
- https://kexue.fm/archives/8348
- https://github.com/bojone/SimCSE

## TODO
- Evaluate the model on the hidden (black-box) test set
- Build a simple same-product retrieval demo with Faiss

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
from tqdm import tqdm
from loguru import logger

import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

from dataset import TrainDataset, TestDataset
from model import SimcseModel, simcse_unsup_loss, simcse_sup_loss
from transformers import BertModel, BertConfig, BertTokenizer
import os
from os.path import join
from torch.utils.tensorboard import SummaryWriter
import random
import pickle
import pandas as pd
import time
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_distances


def seed_everything(seed=42):
    '''
    Set the random seed for the whole environment.
    :param seed:
    :return:
    '''
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # some cudnn methods can be random even after fixing the seed
    # unless you tell it to be deterministic
    torch.backends.cudnn.deterministic = True


def train(model, train_loader, dev_loader, dev_df, optimizer, args):
    logger.info("start training")
    model.train()
    device = args.device
    best = 0
    for epoch in range(args.epochs):
        for batch_idx, data in enumerate(tqdm(train_loader)):
            step = epoch * len(train_loader) + batch_idx
            # [batch, n, seq_len] -> [batch * n, seq_len]
            sql_len = data['input_ids'].shape[-1]
            input_ids = data['input_ids'].view(-1, sql_len).to(device)
            attention_mask = data['attention_mask'].view(-1, sql_len).to(device)
            token_type_ids = data['token_type_ids'].view(-1, sql_len).to(device)

            out = model(input_ids, attention_mask, token_type_ids)
            # Both modes feed the model pairs (two encodings per sample), so the
            # pair-based SimCSE loss is used in either case.
            if args.train_mode == 'unsupervise':
                loss = simcse_unsup_loss(out, device)
            else:
                loss = simcse_unsup_loss(out, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1

            if step % args.eval_step == 0:
                precision, recall, f1, predictions = evaluate(model, dev_loader, dev_df, device,
                                                              threshold=args.threshold)
                logger.info('loss:{}, f1:{}, precision: {}, recall:{} in step {} epoch {}'.format(
                    loss, f1, precision, recall, step, epoch)
                )
                writer.add_scalar('loss', loss, step)
                writer.add_scalar('f1', f1, step)
                writer.add_scalar('precision', precision, step)
                writer.add_scalar('recall', recall, step)

                model.train()
                if best < f1:
best = f1 80 | torch.save(model.state_dict(), join(args.output_path, 'simcse.pt')) 81 | logger.info('higher f1: {} in step {} epoch {}, save model'.format(best, step, epoch)) 82 | 83 | 84 | def evaluate(model, dataloader, df, device, threshold=0.5): 85 | model.eval() 86 | 87 | embeddings = torch.tensor([], device=device) 88 | with torch.no_grad(): 89 | for source in tqdm(dataloader): 90 | # source [batch, 1, seq_len] -> [batch, seq_len] 91 | sql_len = source['input_ids'].shape[-1] 92 | source_input_ids = source.get('input_ids').view(-1, sql_len).to(device) 93 | source_attention_mask = source.get('attention_mask').view(-1, sql_len).to(device) 94 | source_token_type_ids = source.get('token_type_ids').view(-1, sql_len).to(device) 95 | # pdb.set_trace() 96 | source_pred = model(source_input_ids, source_attention_mask, source_token_type_ids) 97 | 98 | # embeddings = np.append(embeddings, source_pred.numpy()) 99 | embeddings = torch.cat((embeddings, source_pred), dim=0) 100 | 101 | distances = cosine_distances(embeddings.cpu(), embeddings.cpu()) 102 | distances = torch.from_numpy(distances).to(device) 103 | distances, indices = torch.sort(distances, dim=1) 104 | distances = distances.cpu().numpy() 105 | indices = indices.cpu().numpy() 106 | 107 | # get predictions 108 | predictions = [] 109 | precision_lst = [] 110 | recall_lst = [] 111 | f1_lst = [] 112 | for k in range(embeddings.shape[0]): 113 | # pdb.set_trace() 114 | # 第k个数据的label 115 | label = df['label_group'].iloc[k] 116 | idx = np.where(distances[k,] < threshold)[0] 117 | ids = indices[k, idx] 118 | # 模型认为相似的title的label 119 | predict_lst = df['label_group'].iloc[ids].values 120 | # 预测正确的数量 121 | num_right = len([x for x in predict_lst if x == label]) 122 | # 该label实际上存在的case的数量 123 | num_label = len(df.loc[df['label_group'] == label]) 124 | 125 | precision = num_right / len(predict_lst) 126 | recall = num_right / num_label 127 | f1 = 2 * (precision * recall) / (precision + recall) 128 | precision_lst.append(precision) 129 | recall_lst.append(recall) 130 | f1_lst.append(f1) 131 | 132 | posting_ids = np.unique(df['posting_id'].iloc[ids].values) 133 | predictions.append(posting_ids) 134 | 135 | precision = sum(precision_lst) / len(precision_lst) 136 | recall = sum(recall_lst) / len(recall_lst) 137 | f1 = sum(f1_lst) / len(f1_lst) 138 | return precision, recall, f1, predictions 139 | 140 | 141 | def load_train_data_unsupervised(tokenizer, args): 142 | """ 143 | 获取无监督训练语料,对于每个title,复制一份作为正样本 144 | """ 145 | logger.info('loading unsupervised train data') 146 | output_path = os.path.dirname(args.output_path) 147 | train_file_cache = join(output_path, 'train-unsupervise.pkl') 148 | if os.path.exists(train_file_cache) and not args.overwrite_cache: 149 | with open(train_file_cache, 'rb') as f: 150 | feature_list = pickle.load(f) 151 | logger.info("len of train data:{}".format(len(feature_list))) 152 | return feature_list 153 | feature_list = [] 154 | df = pd.read_csv(args.train_file, sep=',') 155 | rows = df.to_dict('records') 156 | for row in tqdm(rows): 157 | title = row['title'] 158 | title_ids = tokenizer([title, title], max_length=args.max_len, truncation=True, padding='max_length', 159 | return_tensors='pt') 160 | feature_list.append(title_ids) 161 | 162 | logger.info("len of train data:{}".format(len(feature_list))) 163 | with open(train_file_cache, 'wb') as f: 164 | pickle.dump(feature_list, f) 165 | return feature_list 166 | 167 | 168 | def load_train_data_supervised(tokenizer, args): 169 | """ 170 | 获取有监督训练数据,同一个类别下,两两title组成一条训练数据 171 
| """ 172 | # 加载缓存数据 173 | logger.info('loading supervised train data') 174 | output_path = os.path.dirname(args.output_path) 175 | train_file_cache = join(output_path, 'train-supervised.pkl') 176 | if os.path.exists(train_file_cache) and not args.overwrite_cache: 177 | with open(train_file_cache, 'rb') as f: 178 | feature_list = pickle.load(f) 179 | logger.info("len of train data:{}".format(len(feature_list))) 180 | return feature_list 181 | feature_list = [] 182 | df = pd.read_csv(args.train_file, sep=',') 183 | logger.info("len of train data:{}".format(len(df))) 184 | label2titles = defaultdict(list) 185 | rows = df.to_dict('records') 186 | # rows = rows[:10000] 187 | 188 | # 收集每个label_group下的title集合 189 | for row in tqdm(rows): 190 | title = row['title'] 191 | label = row['label_group'] 192 | label2titles[label].append(title) 193 | 194 | # todo 195 | for label, titles in tqdm(label2titles.items()): 196 | # 同一类别下,两两组成一条训练数据 197 | titles_tokens = tokenizer(titles, max_length=args.max_len, truncation=True, padding='max_length', 198 | return_tensors='pt') 199 | for i in range(len(titles)): 200 | for j in range(len(titles)): 201 | if i >= j: 202 | continue 203 | input_ids = torch.cat( 204 | [titles_tokens['input_ids'][i].unsqueeze(0), titles_tokens['input_ids'][j].unsqueeze(0)], dim=0) 205 | token_type_ids = torch.cat( 206 | [titles_tokens['token_type_ids'][i].unsqueeze(0), titles_tokens['token_type_ids'][j].unsqueeze(0)], 207 | dim=0) 208 | attention_mask = torch.cat( 209 | [titles_tokens['attention_mask'][i].unsqueeze(0), titles_tokens['attention_mask'][j].unsqueeze(0)], 210 | dim=0) 211 | feature_list.append( 212 | {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask}) 213 | 214 | logger.info("len of train data:{}".format(len(feature_list))) 215 | with open(train_file_cache, 'wb') as f: 216 | pickle.dump(feature_list, f) 217 | return feature_list 218 | 219 | 220 | def load_eval_data(tokenizer, args, mode): 221 | """ 222 | 加载验证集或者测试集 223 | """ 224 | assert mode in ['dev', 'test'], 'mode should in ["dev", "test"]' 225 | logger.info('loading {} data'.format(mode)) 226 | output_path = os.path.dirname(args.output_path) 227 | eval_file_cache = join(output_path, '{}.pkl'.format(mode)) 228 | if os.path.exists(eval_file_cache) and not args.overwrite_cache: 229 | with open(eval_file_cache, 'rb') as f: 230 | feature_list = pickle.load(f) 231 | logger.info("len of {} data:{}".format(mode, len(feature_list))) 232 | return feature_list 233 | 234 | if mode == 'dev': 235 | eval_file = args.dev_file 236 | else: 237 | eval_file = args.test_file 238 | 239 | df = pd.read_csv(eval_file, sep=',') 240 | logger.info("len of {} data:{}".format(mode, len(df))) 241 | feature_list = [] 242 | rows = df.to_dict('records') 243 | for index, row in enumerate(tqdm(rows)): 244 | title = row['title'] 245 | feature = tokenizer(title, max_length=args.max_len, truncation=True, padding='max_length', 246 | return_tensors='pt') 247 | feature_list.append(feature) 248 | 249 | res = {'df': df, 'feature_list': feature_list} 250 | with open(eval_file_cache, 'wb') as f: 251 | pickle.dump(res, f) 252 | return res 253 | 254 | 255 | def main(args): 256 | # 加载模型 257 | config = BertConfig.from_pretrained(args.pretrain_model_path) 258 | tokenizer = BertTokenizer.from_pretrained(args.pretrain_model_path) 259 | assert args.pooler in ['cls', "pooler", "last-avg", "first-last-avg"], \ 260 | 'pooler should in ["cls", "pooler", "last-avg", "first-last-avg"]' 261 | model = 
SimcseModel(pretrained_model=args.pretrain_model_path, pooling=args.pooler, dropout=args.dropout).to(args.device)
    # pdb.set_trace()
    if args.do_train:
        # load the training and dev data
        assert args.train_mode in ['supervise', 'unsupervise'], \
            "train_mode should in ['supervise', 'unsupervise']"
        if args.train_mode == 'supervise':
            train_data = load_train_data_supervised(tokenizer, args)
        elif args.train_mode == 'unsupervise':
            train_data = load_train_data_unsupervised(tokenizer, args)
        train_dataset = TrainDataset(train_data, tokenizer, max_len=args.max_len)
        # train_dataset = train_dataset[:32]
        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size_train, shuffle=True,
                                      num_workers=args.num_workers)
        dev_data = load_eval_data(tokenizer, args, 'dev')
        dev_df = dev_data['df']
        dev_dataset = TestDataset(dev_data['feature_list'], tokenizer, max_len=args.max_len)
        # dev_dataset = dev_dataset[:8]
        dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size_eval, shuffle=False,
                                    num_workers=args.num_workers)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
        train(model, train_dataloader, dev_dataloader, dev_df, optimizer, args)
    if args.do_predict:
        test_data = load_eval_data(tokenizer, args, 'test')
        test_df = test_data['df']
        test_dataset = TestDataset(test_data['feature_list'], tokenizer, max_len=args.max_len)
        # test_dataset = test_dataset[:8]
        test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size_eval, shuffle=False,
                                     num_workers=args.num_workers)
        path = join(args.output_path, 'simcse.pt')
        logger.info(path)
        model.load_state_dict(torch.load(path))
        model.eval()
        print(int(model.bert.embeddings.position_ids[0, -1]))  # sanity check: largest position id of the loaded encoder
        precision, recall, f1, predictions = evaluate(model, test_dataloader, test_df, args.device,
                                                      threshold=args.threshold)
        logger.info('testset precision:{}, recall:{}, f1:{}'.format(precision, recall, f1))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default='gpu', choices=['gpu', 'cpu'], help="gpu or cpu")
    parser.add_argument("--output_path", type=str, default='output/shopee')
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size_train", type=int, default=64)
    parser.add_argument("--batch_size_eval", type=int, default=256)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--eval_step", type=int, default=50, help="every eval_step to evaluate model")
    parser.add_argument("--max_len", type=int, default=150, help="max length of input")
    parser.add_argument("--threshold", type=float, default=0.3, help="threshold")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--train_file", type=str, default="data/shopee/train.csv")
    parser.add_argument("--dev_file", type=str, default="data/shopee/dev.csv")
    parser.add_argument("--test_file", type=str, default="data/shopee/test.csv")
    parser.add_argument("--pretrain_model_path", type=str,
                        default="pretrain_model/bert-base-uncased")
    parser.add_argument("--pooler", type=str, choices=['cls', "pooler",
"last-avg", "first-last-avg"], 322 | default='cls', help='pooler to use') 323 | parser.add_argument("--train_mode", type=str, default='supervise', choices=['unsupervise', 'supervise'], 324 | help="unsupervise or supervise") 325 | parser.add_argument("--overwrite_cache", action='store_true', default=True, help="overwrite cache") 326 | parser.add_argument("--do_train", action='store_true', default=True) 327 | parser.add_argument("--do_predict", action='store_true', default=True) 328 | 329 | args = parser.parse_args() 330 | seed_everything(args.seed) 331 | args.device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == 'gpu' else "cpu") 332 | args.output_path = join(args.output_path, args.train_mode, 333 | 'bsz-{}-lr-{}-dropout-{}-threshold-{}'.format(args.batch_size_train, args.lr, args.dropout, 334 | args.threshold)) 335 | 336 | if not os.path.exists(args.output_path): 337 | os.makedirs(args.output_path) 338 | 339 | if args.do_train: 340 | cur_time = time.strftime("%Y%m%d%H%M%S", time.localtime()) 341 | logger.add(join(args.output_path, 'train-{}.log'.format(cur_time))) 342 | logger.info(args) 343 | writer = SummaryWriter(args.output_path) 344 | main(args) 345 | 346 | 347 | --------------------------------------------------------------------------------