├── README.md
├── config.py
├── data
│   ├── raw.txt
│   └── 中文品牌_适用季节.pkl
├── img
│   ├── 1.png
│   ├── 2.png
│   └── 3.png
├── main.py
├── models
│   ├── LSTM_CRF.py
│   ├── OpenTag_2019.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── LSTM_CRF.cpython-36.pyc
│   │   ├── OpenTag_2019.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── basic_module.cpython-36.pyc
│   │   └── squeeze_embedding.cpython-36.pyc
│   ├── basic_module.py
│   └── squeeze_embedding.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   └── dataset.cpython-36.pyc
    ├── data_process.py
    └── dataset.py

/README.md:
--------------------------------------------------------------------------------
# OpenTag_2019
[Scaling Up Open Tagging from Tens to Thousands: Comprehension Empowered Attribute Value Extraction from Product Title](https://www.aclweb.org/anthology/P19-1514)
> This work builds on [OpenTag: Open Attribute Value Extraction from Product Profiles](https://arxiv.org/pdf/1806.01264.pdf). The model architecture is shown below:

![model architecture](/img/1.png)

### requirements
> 1. pytorch
> 2. pytorch-transformers
> 3. sklearn
> 4. seqeval
> 5. tqdm
> 6. torchcrf

### data
1. raw.txt under the data directory is the full dataset; 中文品牌_适用季节.pkl is a small subset extracted from it for the experiments.
2. data_process.py under utils offers two ways to build the experiment data, with and without BERT tokenization; running python data_process.py produces 中文品牌_适用季节.pkl.
3. To build data for the full dataset, read data_process.py; it is straightforward to extend.
4. dataset.py wraps the Dataset and DataLoader.

### model
1. Two models are provided; LSTM_CRF.py serves as a baseline.
2. OpenTag_2019.py reproduces the model architecture of the paper.

### run
1. Run python main.py train --batch_size=128 to start training.
2. The corresponding settings can be changed in config.py.

### result
> 1. The hyperparameters were not tuned carefully, so treat these numbers as a rough reference only. Note that when using BERT, lr should be around 2e-5 or 3e-5; BERT is very sensitive to the learning rate.
> 2. Results on the small dataset (中文品牌_适用季节.pkl):

+ LSTM_CRF
![lstm_crf](/img/2.png)

+ OpenTag_2019
![opentag_2019](/img/3.png)
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# coding:utf8
import warnings
import torch as t

class DefaultConfig(object):
    env = 'default'  # visdom environment
    vis_port = 8097  # visdom port
    model = 'OpenTag2019'  # model to use; the name must match one exported in models/__init__.py
    pretrained_bert_name = 'bert-base-chinese'

    pickle_path = './data/中文品牌_适用季节.pkl'
    load_model_path = None  # path of a trained model to load; None means train from scratch

    batch_size = 32  # batch size
    embedding_dim = 768
    hidden_dim = 1024
    tagset_size = 4
    use_gpu = True  # use GPU or not
    num_workers = 4  # how many workers for loading data
    print_freq = 100  # print info every N batches

    max_epoch = 20
    lr = 2e-5  # initial learning rate
    lr_decay = 0.5  # when val_loss increases, lr = lr*lr_decay
    weight_decay = 0e-5  # L2 regularization
    dropout = 0.2
    seed = 1234
    device = 'cuda'


    def _parse(self, kwargs):
        """
        Update config attributes from the kwargs dict
        """
        for k, v in kwargs.items():
            if not hasattr(self, k):
                warnings.warn("Warning: opt has no attribute %s" % k)
            setattr(self, k, v)

        opt.device = t.device('cuda') if opt.use_gpu else t.device('cpu')


        print('user config:')
        for k, v in self.__class__.__dict__.items():
            if not k.startswith('_'):
                print(k, getattr(self, k))

opt = DefaultConfig()
--------------------------------------------------------------------------------
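For orientation before the data files below: each line of data/raw.txt holds one `title<$$$>attribute<$$$>value` record, and data_process.py turns the value span inside the title into B/I/O tags over the title's tokens (characters in the no-BERT variant). A minimal sketch of that step, using a made-up record (the title, attribute and value here are hypothetical, not taken from the dataset):

```python
# Hypothetical raw.txt record: title <$$$> attribute <$$$> value
line = "南极人冬季新款男士外套<$$$>品牌<$$$>南极人"
title, attribute, value = line.split('<$$$>')

# Character-level B/I/O tagging, mirroring the matching loop in data_process.py
tags = ['O'] * len(title)
for i in range(len(title) - len(value) + 1):
    if title[i:i + len(value)] == value:
        tags[i] = 'B'
        for j in range(1, len(value)):
            tags[i + j] = 'I'

print(list(zip(title, tags)))  # [('南', 'B'), ('极', 'I'), ('人', 'I'), ('冬', 'O'), ...]
```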
/data/中文品牌_适用季节.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/data/中文品牌_适用季节.pkl
--------------------------------------------------------------------------------
/img/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/img/1.png
--------------------------------------------------------------------------------
/img/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/img/2.png
--------------------------------------------------------------------------------
/img/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/img/3.png
--------------------------------------------------------------------------------
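The train entry point in main.py below is exposed through fire, so the README command is just keyword arguments forwarded into train(**kwargs) and then into opt._parse. A small sketch of what the override amounts to (the values here are arbitrary):

```python
# Rough equivalent of `python main.py train --batch_size=128 --lr=3e-5`:
# fire.Fire() maps the CLI onto module-level functions, so the command calls
# train(batch_size=128, lr=3e-5), and train() hands those kwargs to opt._parse,
# which overwrites the matching DefaultConfig attributes.
from config import opt

opt._parse({'batch_size': 128, 'lr': 3e-5})  # unknown keys only trigger a warning
print(opt.batch_size, opt.lr, opt.device)    # 128 3e-05 cuda (cpu when use_gpu=False)
```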
/main.py:
--------------------------------------------------------------------------------
import os
import sys
from time import strftime, localtime
from collections import Counter
from config import opt
from pytorch_transformers import BertTokenizer
import random
import numpy as np
import torch
import models
from utils import get_dataloader
# from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
from seqeval.metrics import f1_score, accuracy_score, classification_report
from tqdm import tqdm
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))

def get_attributes(path='./data/raw.txt'):
    atts = []
    with open(path, 'r') as f:
        for line in f.readlines():
            line = line.strip('\n')
            if line:
                title, attribute, value = line.split('<$$$>')
                atts.append(attribute)
    return [item[0] for item in Counter(atts).most_common()]

def train(**kwargs):
    log_file = '{}-{}.log'.format(opt.model, strftime("%y%m%d-%H%M", localtime()))
    logger.addHandler(logging.FileHandler(log_file))

    att_list = get_attributes()

    tokenizer = BertTokenizer.from_pretrained(opt.pretrained_bert_name)
    tags2id = {'':0,'B':1,'I':2,'O':3}
    id2tags = {v:k for k,v in tags2id.items()}

    opt._parse(kwargs)

    if opt.seed is not None:
        random.seed(opt.seed)
        np.random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed(opt.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # step1: configure model
    model = getattr(models, opt.model)(opt)
    if opt.load_model_path:
        model.load(opt.load_model_path)
    model.to(opt.device)

    # step2: data
    train_dataloader,valid_dataloader,test_dataloader = get_dataloader(opt)

    # step3: criterion and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
    lr = opt.lr
    optimizer = model.get_optimizer(lr, opt.weight_decay)

    # step4 train
    for epoch in range(opt.max_epoch):
        model.train()
        for ii,batch in tqdm(enumerate(train_dataloader)):
            # train model
            optimizer.zero_grad()
            x = batch['x'].to(opt.device)
            y = batch['y'].to(opt.device)
            att = batch['att'].to(opt.device)
            inputs = [x, att, y]
            loss = model.log_likelihood(inputs)
            loss.backward()
            #CRF
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=3)
            optimizer.step()
            if ii % opt.print_freq == 0:
                print('epoch:%04d,------------loss:%f'%(epoch,loss.item()))

        model.save()

        preds, labels = [], []
        for index, batch in enumerate(valid_dataloader):
            model.eval()
            x = batch['x'].to(opt.device)
            y = batch['y'].to(opt.device)
            att = batch['att'].to(opt.device)
            inputs = [x, att, y]
            predict = model(inputs)

            if index % 5 == 0:
                print(tokenizer.convert_ids_to_tokens([i.item() for i in x[0].cpu() if i.item()>0]))
                length = [id2tags[i.item()] for i in y[0].cpu() if i.item()>0]
                print(length)
                print([id2tags[i] for i in predict[0][:len(length)]])

            # collect the non-zero label ids, i.e. the true (unpadded) length of each tag sequence
            leng = []
            for i in y.cpu():
                tmp = []
                for j in i:
                    if j.item()>0:
                        tmp.append(j.item())
                leng.append(tmp)

            for index, i in enumerate(predict):
                preds.append([id2tags[k] if k>0 else id2tags[3] for k in i[:len(leng[index])]])
                # preds += i[:len(leng[index])]

            for index, i in enumerate(y.tolist()):
                labels.append([id2tags[k] if k>0 else id2tags[3] for k in i[:len(leng[index])]])
                # labels += i[:len(leng[index])]
        # precision = precision_score(labels, preds, average='macro')
        # recall = recall_score(labels, preds, average='macro')
        # f1 = f1_score(labels, preds, average='macro')
        report = classification_report(labels, preds)
        print(report)
        logger.info(report)


if __name__=='__main__':
    import fire
    fire.Fire()
--------------------------------------------------------------------------------
/models/LSTM_CRF.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torchcrf import CRF
from .basic_module import BasicModule
from .squeeze_embedding import SqueezeEmbedding


class NERLSTM_CRF(BasicModule):
    def __init__(self, opt, embedding_dim=200, hidden_dim=300, dropout=0.2, word2id=100000, tag2id=4):
        super(NERLSTM_CRF, self).__init__()

        self.opt = opt
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = word2id + 1
        self.tag_to_ix = tag2id
        self.tagset_size = tag2id

        self.word_embeds = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.dropout = nn.Dropout(dropout)

        #CRF
        self.lstm = nn.LSTM(self.embedding_dim, self.hidden_dim // 2, num_layers=1, bidirectional=True, batch_first=False)

        self.hidden2tag = nn.Linear(self.hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size)

    def forward(self, inputs):
        x, att, tags = inputs
        #CRF
        x = x.transpose(0,1)

        embedding = self.word_embeds(x)
        outputs, hidden = self.lstm(embedding)
        outputs = self.dropout(outputs)
        outputs = self.hidden2tag(outputs)
        #CRF
        outputs = self.crf.decode(outputs)
        return outputs

    def log_likelihood(self, inputs):
        x, att, tags = inputs
        x = x.transpose(0,1)
        tags = tags.transpose(0,1)
        embedding = self.word_embeds(x)
        outputs, hidden = self.lstm(embedding)
        outputs = self.dropout(outputs)
        outputs = self.hidden2tag(outputs)
        return - self.crf(outputs, tags)
--------------------------------------------------------------------------------
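A minimal sketch of the torchcrf calling convention the baseline above relies on; with the default batch_first=False the emissions and tags are time-first, which is why NERLSTM_CRF transposes its inputs (the sizes below are arbitrary):

```python
import torch
from torchcrf import CRF

num_tags, seq_len, batch = 4, 6, 2
crf = CRF(num_tags)                                  # time-first by default
emissions = torch.randn(seq_len, batch, num_tags)    # (seq_len, batch, num_tags)
tags = torch.randint(0, num_tags, (seq_len, batch))  # (seq_len, batch)

loss = -crf(emissions, tags)   # negative log-likelihood, as in log_likelihood()
paths = crf.decode(emissions)  # best tag sequence per batch element, as in forward()
print(loss.item(), paths)
```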
/models/OpenTag_2019.py:
--------------------------------------------------------------------------------
from .basic_module import BasicModule
from pytorch_transformers import BertModel
import torch
from torchcrf import CRF
from .squeeze_embedding import SqueezeEmbedding

class OpenTag2019(BasicModule):
    def __init__(self, opt, *args, **kwargs):
        super(OpenTag2019, self).__init__(*args, **kwargs)

        self.model_name = 'opentag2019'
        self.opt = opt
        self.embedding_dim = opt.embedding_dim
        self.hidden_dim = opt.hidden_dim
        self.tagset_size = opt.tagset_size

        self.bert = BertModel.from_pretrained(opt.pretrained_bert_name)
        self.word_embeds = torch.nn.Embedding(30000, self.opt.embedding_dim)

        self.dropout = torch.nn.Dropout(opt.dropout)

        self.squeeze_embedding = SqueezeEmbedding()

        #CRF
        self.lstm = torch.nn.LSTM(self.embedding_dim, self.hidden_dim // 2, num_layers=1, bidirectional=True, batch_first=True)

        self.hidden2tag = torch.nn.Linear(self.hidden_dim*2, self.tagset_size)
        self.crf = CRF(self.tagset_size, batch_first=True)

    def calculate_cosin(self, context_output, att_hidden):
        '''
        context_output (batchsize, seqlen, hidden_dim)
        att_hidden (batchsize, hidden_dim)
        '''
        batchsize,seqlen,hidden_dim = context_output.size()
        att_hidden = att_hidden.unsqueeze(1).repeat(1,seqlen,1)

        context_output = context_output.float()
        att_hidden = att_hidden.float()

        cos = torch.sum(context_output*att_hidden, dim=-1)/(torch.norm(context_output, dim=-1)*torch.norm(att_hidden, dim=-1))
        cos = cos.unsqueeze(-1)
        cos_output = context_output*cos
        outputs = torch.cat([context_output, cos_output], dim=-1)

        return outputs

    def forward(self, inputs):
        context, att, target = inputs[0], inputs[1], inputs[2]
        context_len = torch.sum(context != 0, dim=-1)
        att_len = torch.sum(att != 0, dim=-1)

        context = self.squeeze_embedding(context, context_len)
        context, _ = self.bert(context)
        # context = self.word_embeds(context)
        context_output, _ = self.lstm(context)

        att = self.squeeze_embedding(att, att_len)
        att, _ = self.bert(att)
        # att = self.word_embeds(att)
        _, att_hidden = self.lstm(att)
        att_hidden = torch.cat([att_hidden[0][-2],att_hidden[0][-1]], dim=-1)

        outputs = self.calculate_cosin(context_output, att_hidden)
        outputs = self.dropout(outputs)

        outputs = self.hidden2tag(outputs)
        #CRF
        # outputs = outputs.transpose(0,1).contiguous()
        outputs = self.crf.decode(outputs)
        return outputs



    def log_likelihood(self, inputs):
        context, att, target = inputs[0], inputs[1], inputs[2]
        context_len = torch.sum(context != 0, dim=-1)
        att_len = torch.sum(att != 0, dim=-1)
        target_len = torch.sum(target != 0, dim=-1)

        target = self.squeeze_embedding(target, target_len)
        # target = target.transpose(0,1).contiguous()

        context = self.squeeze_embedding(context, context_len)
        context, _ = self.bert(context)
        # context = self.word_embeds(context)
        context_output, _ = self.lstm(context)

        att = self.squeeze_embedding(att, att_len)
        att, _ = self.bert(att)
        # att = self.word_embeds(att)
        _, att_hidden = self.lstm(att)
        att_hidden = torch.cat([att_hidden[0][-2],att_hidden[0][-1]], dim=-1)

        outputs = self.calculate_cosin(context_output, att_hidden)
        outputs = self.dropout(outputs)

        outputs = self.hidden2tag(outputs)
        #CRF
        # outputs = outputs.transpose(0,1).contiguous()

        return - self.crf(outputs, target)

--------------------------------------------------------------------------------
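The attribute-conditioning step above (calculate_cosin) rescales every title position by its cosine similarity to the attribute vector and concatenates the result with the original BiLSTM output. A toy shape check (sizes chosen arbitrarily, values random):

```python
import torch

batch, seqlen, hidden = 2, 5, 8
context_output = torch.randn(batch, seqlen, hidden)  # BiLSTM output over the title
att_hidden = torch.randn(batch, hidden)              # concatenated fwd/bwd state of the attribute

att_rep = att_hidden.unsqueeze(1).repeat(1, seqlen, 1)                 # (batch, seqlen, hidden)
cos = torch.sum(context_output * att_rep, dim=-1) / (
    torch.norm(context_output, dim=-1) * torch.norm(att_rep, dim=-1))  # (batch, seqlen)
fused = torch.cat([context_output, context_output * cos.unsqueeze(-1)], dim=-1)
print(fused.shape)  # torch.Size([2, 5, 16]) -> matches hidden2tag's in_features of hidden_dim*2
```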
/models/__init__.py:
--------------------------------------------------------------------------------
from .OpenTag_2019 import OpenTag2019
from .LSTM_CRF import NERLSTM_CRF
--------------------------------------------------------------------------------
/models/__pycache__/LSTM_CRF.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/models/__pycache__/LSTM_CRF.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/OpenTag_2019.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/models/__pycache__/OpenTag_2019.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/basic_module.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/models/__pycache__/basic_module.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/squeeze_embedding.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/models/__pycache__/squeeze_embedding.cpython-36.pyc
--------------------------------------------------------------------------------
/models/basic_module.py:
--------------------------------------------------------------------------------
#coding:utf8
import torch as t
import time


class BasicModule(t.nn.Module):
    """
    Wraps nn.Module and mainly adds two convenience methods: save and load
    """

    def __init__(self):
        super(BasicModule,self).__init__()
        self.model_name=str(type(self))  # default name

    def load(self, path):
        """
        Load a model from the given path
        """
        self.load_state_dict(t.load(path))

    def save(self, name=None):
        """
        Save the model; by default the file name is "model name + timestamp"
        """
        if name is None:
            prefix = './checkpoints/' + self.model_name + '_'
            name = time.strftime(prefix + '%m%d_%H:%M:%S.pth')
        t.save(self.state_dict(), name)
        return name

    def get_optimizer(self, lr, weight_decay):
        return t.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)


class Flat(t.nn.Module):
    """
    Reshape the input to (batch_size, dim_length)
    """

    def __init__(self):
        super(Flat, self).__init__()
        #self.size = size

    def forward(self, x):
        return x.view(x.size(0), -1)
--------------------------------------------------------------------------------
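One practical note on BasicModule.save above: it writes to ./checkpoints/ but never creates that directory, so torch.save will typically fail with FileNotFoundError at the end of the first epoch unless the folder already exists. A one-liner before training covers it:

```python
import os
os.makedirs('./checkpoints', exist_ok=True)  # BasicModule.save expects this directory to exist
```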
/models/squeeze_embedding.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: squeeze_embedding.py
# author: songyouwei
# Copyright (C) 2018. All Rights Reserved.


import torch
import torch.nn as nn
import numpy as np

class SqueezeEmbedding(nn.Module):
    """
    Squeeze sequence embedding length to the longest one in the batch
    """
    def __init__(self, batch_first=True):
        super(SqueezeEmbedding, self).__init__()
        self.batch_first = batch_first

    def forward(self, x, x_len):
        """
        sequence -> sort -> pad and pack -> unpack -> unsort
        :param x: sequence embedding vectors
        :param x_len: numpy/tensor list
        :return:
        """
        """sort"""
        x_sort_idx = torch.sort(-x_len)[1].long()
        x_unsort_idx = torch.sort(x_sort_idx)[1].long()
        x_len = x_len[x_sort_idx]
        x = x[x_sort_idx]
        """pack"""
        x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=self.batch_first)
        """unpack: out"""
        out = torch.nn.utils.rnn.pad_packed_sequence(x_emb_p, batch_first=self.batch_first)  # (sequence, lengths)
        out = out[0]  #
        """unsort"""
        out = out[x_unsort_idx]
        return out
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from .dataset import get_dataloader
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hackerxiaobai/OpenTag_2019/c441ea325a1be9eaaaf251078646f2544152584b/utils/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
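SqueezeEmbedding above is used in OpenTag_2019.py on padded id tensors to trim a batch down to its longest real sequence before it is fed to BERT. A small sketch of the effect (toy ids, assuming the snippet is run from the repository root so the models package is importable):

```python
import torch
from models.squeeze_embedding import SqueezeEmbedding

x = torch.tensor([[5, 6, 7, 0, 0, 0],
                  [8, 9, 0, 0, 0, 0]])  # batch padded to length 6
x_len = torch.sum(x != 0, dim=-1)       # tensor([3, 2]), as computed in OpenTag2019
out = SqueezeEmbedding()(x, x_len)
print(out.shape)                        # torch.Size([2, 3]): trimmed to the longest real length
```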
/utils/data_process.py:
--------------------------------------------------------------------------------
import os
from pytorch_transformers import BertTokenizer
from tqdm import tqdm
import pandas as pd
import pickle
import random
import numpy as np
import collections
from collections import Counter
import sys

def _is_chinese_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and handled
    # like all of the other languages.
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
            (cp >= 0x3400 and cp <= 0x4DBF) or  #
            (cp >= 0x20000 and cp <= 0x2A6DF) or  #
            (cp >= 0x2A700 and cp <= 0x2B73F) or  #
            (cp >= 0x2B740 and cp <= 0x2B81F) or  #
            (cp >= 0x2B820 and cp <= 0x2CEAF) or
            (cp >= 0xF900 and cp <= 0xFAFF) or  #
            (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
        return True

    return False

def bert4token(tokenizer, title, attribute, value):
    title = tokenizer.tokenize(title)
    attribute = tokenizer.tokenize(attribute)
    value = tokenizer.tokenize(value)
    tag = ['O']*len(title)

    # +1 so that a value ending at the last title token can still be matched
    for i in range(0, len(title)-len(value)+1):
        if title[i:i+len(value)] == value:
            for j in range(len(value)):
                if j==0:
                    tag[i+j] = 'B'
                else:
                    tag[i+j] = 'I'
    title_id = tokenizer.convert_tokens_to_ids(title)
    attribute_id = tokenizer.convert_tokens_to_ids(attribute)
    value_id = tokenizer.convert_tokens_to_ids(value)
    tag_id = [TAGS[_] for _ in tag]
    return title_id, attribute_id, value_id, tag_id

def nobert4token(tokenizer, title, attribute, value):

    def get_char(sent):
        tmp = []
        s = ''
        for char in sent.strip():
            if char.strip():
                cp = ord(char)
                if _is_chinese_char(cp):
                    if s:
                        tmp.append(s)
                    tmp.append(char)
                    s = ''
                else:
                    s += char
            elif s:
                tmp.append(s)
                s = ''
        if s:
            tmp.append(s)
        return tmp

    title_list = get_char(title)
    attribute_list = get_char(attribute)
    value_list = get_char(value)

    tag_list = ['O']*len(title_list)
    # +1 so that a value ending at the last character can still be matched
    for i in range(0, len(title_list)-len(value_list)+1):
        if title_list[i:i+len(value_list)] == value_list:
            for j in range(len(value_list)):
                if j==0:
                    tag_list[i+j] = 'B'
                else:
                    tag_list[i+j] = 'I'

    title_list = tokenizer.convert_tokens_to_ids(title_list)
    attribute_list = tokenizer.convert_tokens_to_ids(attribute_list)
    value_list = tokenizer.convert_tokens_to_ids(value_list)
    tag_list = [TAGS[i] for i in tag_list]

    return title_list, attribute_list, value_list, tag_list


max_len = 40
def X_padding(ids):
    if len(ids) >= max_len:
        return ids[:max_len]
    ids.extend([0]*(max_len-len(ids)))
    return ids

tag_max_len = 6
def tag_padding(ids):
    if len(ids) >= tag_max_len:
        return ids[:tag_max_len]
    ids.extend([0]*(tag_max_len-len(ids)))
    return ids

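# Illustration of the two padding helpers above (hypothetical id lists):
#   X_padding([5, 6, 7])       -> [5, 6, 7, 0, ..., 0]   padded to length 40 (max_len)
#   X_padding(list(range(50))) -> [0, 1, ..., 39]        truncated to length 40
#   tag_padding([12, 13])      -> [12, 13, 0, 0, 0, 0]   padded to length 6 (tag_max_len)
# Titles and their tag sequences go through X_padding, attribute ids through
# tag_padding, and 0 serves both as the padding id and as the '' tag, which is
# why main.py computes its metrics only over positions whose label id is > 0.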
def rawdata2pkl4nobert(path):
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    titles = []
    attributes = []
    values = []
    tags = []
    with open(path, 'r') as f:
        for index, line in enumerate(tqdm(f.readlines())):
            line = line.strip('\n')
            if line:
                title, attribute, value = line.split('<$$$>')
                if attribute in ['适用季节','品牌'] and value in title and _is_chinese_char(ord(value[0])):
                    title, attribute, value, tag = nobert4token(tokenizer, title, attribute, value)
                    titles.append(title)
                    attributes.append(attribute)
                    values.append(value)
                    tags.append(tag)
    print([tokenizer.convert_ids_to_tokens(i) for i in titles[:3]])
    print([[id2tags[j] for j in i] for i in tags[:3]])
    print([tokenizer.convert_ids_to_tokens(i) for i in attributes[:3]])
    print([tokenizer.convert_ids_to_tokens(i) for i in values[:3]])

    df = pd.DataFrame({'titles': titles, 'attributes': attributes, 'values': values, 'tags': tags},
                      index=range(len(titles)))
    print(df.shape)
    df['x'] = df['titles'].apply(X_padding)
    df['y'] = df['tags'].apply(X_padding)
    df['att'] = df['attributes'].apply(tag_padding)

    index = list(range(len(titles)))
    random.shuffle(index)
    train_index = index[:int(0.9 * len(index))]
    valid_index = index[int(0.9 * len(index)):int(0.96 * len(index))]
    test_index = index[int(0.96 * len(index)):]

    train = df.loc[train_index, :]
    valid = df.loc[valid_index, :]
    test = df.loc[test_index, :]

    train_x = np.asarray(list(train['x'].values))
    train_att = np.asarray(list(train['att'].values))
    train_y = np.asarray(list(train['y'].values))

    valid_x = np.asarray(list(valid['x'].values))
    valid_att = np.asarray(list(valid['att'].values))
    valid_y = np.asarray(list(valid['y'].values))

    test_x = np.asarray(list(test['x'].values))
    test_att = np.asarray(list(test['att'].values))
    test_value = np.asarray(list(test['values'].values))
    test_y = np.asarray(list(test['y'].values))

    with open('../data/中文品牌_适用季节.pkl', 'wb') as outp:  # the file that config.pickle_path expects
        pickle.dump(train_x, outp)
        pickle.dump(train_att, outp)
        pickle.dump(train_y, outp)
        pickle.dump(valid_x, outp)
        pickle.dump(valid_att, outp)
        pickle.dump(valid_y, outp)
        pickle.dump(test_x, outp)
        pickle.dump(test_att, outp)
        pickle.dump(test_value, outp)
        pickle.dump(test_y, outp)

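# Layout of the pickle written above (and of the per-attribute pickles written by
# rawdata2pkl4bert below): ten objects dumped back-to-back in a fixed order, which
# utils/dataset.py reads back with ten pickle.load calls in exactly the same order:
#   train_x, train_att, train_y,
#   valid_x, valid_att, valid_y,
#   test_x, test_att, test_value, test_y
# If this order changes, get_dataloader() has to change with it.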
def rawdata2pkl4bert(path, att_list):
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    with open(path, 'r') as f:
        lines = f.readlines()
        for att_name in tqdm(att_list):
            print('#'*20+att_name+'#'*20)
            titles = []
            attributes = []
            values = []
            tags = []
            for index, line in enumerate(lines):
                line = line.strip('\n')
                if line:
                    title, attribute, value = line.split('<$$$>')
                    if attribute in [att_name] and value in title: #and _is_chinese_char(ord(value[0])):
                        title, attribute, value, tag = bert4token(tokenizer, title, attribute, value)
                        titles.append(title)
                        attributes.append(attribute)
                        values.append(value)
                        tags.append(tag)
            if titles:
                print([tokenizer.convert_ids_to_tokens(i) for i in titles[:3]])
                print([[id2tags[j] for j in i] for i in tags[:3]])
                print([tokenizer.convert_ids_to_tokens(i) for i in attributes[:3]])
                print([tokenizer.convert_ids_to_tokens(i) for i in values[:3]])
                df = pd.DataFrame({'titles':titles,'attributes':attributes,'values':values,'tags':tags}, index=range(len(titles)))
                print(df.shape)
                df['x'] = df['titles'].apply(X_padding)
                df['y'] = df['tags'].apply(X_padding)
                df['att'] = df['attributes'].apply(tag_padding)

                index = list(range(len(titles)))
                random.shuffle(index)
                train_index = index[:int(0.85*len(index))]
                valid_index = index[int(0.85*len(index)):int(0.95*len(index))]
                test_index = index[int(0.95*len(index)):]

                train = df.loc[train_index,:]
                valid = df.loc[valid_index,:]
                test = df.loc[test_index,:]

                train_x = np.asarray(list(train['x'].values))
                train_att = np.asarray(list(train['att'].values))
                train_y = np.asarray(list(train['y'].values))

                valid_x = np.asarray(list(valid['x'].values))
                valid_att = np.asarray(list(valid['att'].values))
                valid_y = np.asarray(list(valid['y'].values))

                test_x = np.asarray(list(test['x'].values))
                test_att = np.asarray(list(test['att'].values))
                test_value = np.asarray(list(test['values'].values))
                test_y = np.asarray(list(test['y'].values))

                att_name = att_name.replace('/','_')
                with open('../data/all/{}.pkl'.format(att_name), 'wb') as outp:
                # with open('../data/top105_att.pkl', 'wb') as outp:
                    pickle.dump(train_x, outp)
                    pickle.dump(train_att, outp)
                    pickle.dump(train_y, outp)
                    pickle.dump(valid_x, outp)
                    pickle.dump(valid_att, outp)
                    pickle.dump(valid_y, outp)
                    pickle.dump(test_x, outp)
                    pickle.dump(test_att, outp)
                    pickle.dump(test_value, outp)
                    pickle.dump(test_y, outp)

def get_attributes(path):
    atts = []
    with open(path, 'r') as f:
        for line in f.readlines():
            line = line.strip('\n')
            if line:
                title, attribute, value = line.split('<$$$>')
                atts.append(attribute)
    return [item[0] for item in Counter(atts).most_common()]


if __name__=='__main__':
    TAGS = {'':0,'B':1,'I':2,'O':3}
    id2tags = {v:k for k,v in TAGS.items()}
    path = '../data/raw.txt'
    att_list = get_attributes(path)
    # rawdata2pkl4bert(path, att_list)
    rawdata2pkl4nobert(path)
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
import pickle
import torch
from torch.utils.data import Dataset,DataLoader

class MyDataset(Dataset):
    def __init__(self, X, Y, att):
        self.data = [{'x':X[i],'y':Y[i],'att':att[i]} for i in range(X.shape[0])]

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


def get_dataloader(opt):
    with open(opt.pickle_path, 'rb') as inp:
        train_x = pickle.load(inp)
        train_att = pickle.load(inp)
        train_y = pickle.load(inp)
        valid_x = pickle.load(inp)
        valid_att = pickle.load(inp)
        valid_y = pickle.load(inp)
        test_x = pickle.load(inp)
        test_att = pickle.load(inp)
        test_value = pickle.load(inp)
        test_y = pickle.load(inp)
    print("train len:",train_x.shape)
    print("test len:",test_x.shape)
    print("valid len", valid_x.shape)

    train_dataset = MyDataset(train_x, train_y, train_att)
    valid_dataset = MyDataset(valid_x, valid_y, valid_att)
    test_dataset = MyDataset(test_x, test_y, test_att)

    try:
        train_dataloader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
        valid_dataloader = DataLoader(valid_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
        test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
        return train_dataloader,valid_dataloader,test_dataloader
    except:
        pass
    return None

# if __name__=='__main__':
#     train_dataloader, valid_dataloader, test_dataloader = get_dataloader()
#     for batch in train_dataloader:
#         print(batch['x'].shape)
#         print(batch['y'].shape)
#         print(batch['att'].shape)
--------------------------------------------------------------------------------
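Putting the pieces together, a typical end-to-end run of the repository looks like this (paths follow the defaults in config.py and in the `__main__` block of data_process.py):

1. `cd utils && python data_process.py` builds `data/中文品牌_适用季节.pkl` from `data/raw.txt`.
2. `mkdir checkpoints` so that `BasicModule.save` has somewhere to write the per-epoch weights (see the note next to basic_module.py).
3. `python main.py train --batch_size=128 --lr=2e-5` trains OpenTag2019 (pass `--model=NERLSTM_CRF` for the baseline); after every epoch a seqeval classification report is printed and appended to the run's log file.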