├── data ├── marco │ └── data.txt └── ohsumed │ ├── readme.md │ └── queries.txt ├── script ├── local_count.sh ├── test.sh ├── batch_submit_genpretrain.sh ├── gen_pretrain.sh ├── count.sh ├── gendata.sh ├── large_finetune.sh ├── finetune.sh ├── pretrain_large.sh ├── run_rocketqav1.sh ├── pretrain_base.sh └── base_finetune.sh ├── base ├── bert_config.json └── ernie_config.json ├── large ├── bert_config.json └── ernie_config.json ├── ernie ├── test.py ├── __init__.py ├── run_rocketqa.py ├── make_pretrain_data.py ├── utils.py ├── file_utils.py ├── unbatch.py ├── gnn_layer.py ├── static2dynamic.py ├── count.py ├── generate_data.py ├── tokenizing_ernie.py ├── msmarco_eval.py ├── tokenization.py ├── paths.py └── modeling_ernie.py └── README.md /data/marco/data.txt: -------------------------------------------------------------------------------- 1 | to be updated -------------------------------------------------------------------------------- /script/local_count.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for ((i=0;i<16;i++)) 3 | do 4 | nohup python ernie/count.py --local_rank=${i} >> output/log/log.${i} & 5 | done 6 | 7 | -------------------------------------------------------------------------------- /data/ohsumed/readme.md: -------------------------------------------------------------------------------- 1 | Because the Ohsumed dataset is oriented toward the task of text classification, the queries and qrels used in this work are constructed from the corresponding text classification labels. The collection can be downloaded from this [link](https://pan.baidu.com/s/1bVd0x4v-HKm6dpMc6Cz9Og) with the password: ur5v 2 | 3 | 4 | -------------------------------------------------------------------------------- /script/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | queue=$1 3 | if [[ ${queue} =~ "v100" ]]; then 4 | batch_size=64 5 | fi 6 | if [[ ${queue} =~ "a100" ]]; then 7 | batch_size=96 8 | fi 9 | val=`echo "scale=5; 33/$batch_size" | bc` 10 | echo "val:${val}" 11 | val2=`echo "100 * ${val}" | bc` 12 | 13 | echo "val2:${val2}" 14 | val3=`expr $batch_size / 2` 15 | echo "val3:${val3}" 16 | 17 | -------------------------------------------------------------------------------- /base/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /large/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 1024, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 4096, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 16, 10 | "num_hidden_layers": 24, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /base/ernie_config.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "max_position_embeddings": 512, 8 | "num_attention_heads": 12, 9 | "num_hidden_layers": 12, 10 | "sent_type_vocab_size": 4, 11 | "task_type_vocab_size": 16, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /large/ernie_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 1024, 6 | "initializer_range": 0.02, 7 | "max_position_embeddings": 512, 8 | "num_attention_heads": 16, 9 | "num_hidden_layers": 24, 10 | "sent_type_vocab_size": 4, 11 | "task_type_vocab_size": 16, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /script/batch_submit_genpretrain.sh: -------------------------------------------------------------------------------- 1 | sh script/gen_pretrain.sh large 0 2 | sleep 1m 3 | sh script/gen_pretrain.sh large 1 4 | sleep 1m 5 | sh script/gen_pretrain.sh large 2 6 | sleep 1m 7 | sh script/gen_pretrain.sh large 3 8 | sleep 1m 9 | sh script/gen_pretrain.sh large 4 10 | sleep 1m 11 | sh script/gen_pretrain.sh large 5 12 | sleep 1m 13 | sh script/gen_pretrain.sh large 6 14 | sleep 1m 15 | sh script/gen_pretrain.sh large 7 16 | -------------------------------------------------------------------------------- /ernie/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import paddle 3 | from paddle.distributed import init_parallel_env 4 | 5 | paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) 6 | init_parallel_env() 7 | tensor_list = [] 8 | if paddle.distributed.ParallelEnv().local_rank == 0: 9 | np_data1 = np.array([[4, 5, 6], [4, 5, 6], [4, 5, 6]]) 10 | np_data2 = np.array([[4, 5, 6], [4, 5, 6]]) 11 | data1 = paddle.to_tensor(np_data1) 12 | data2 = paddle.to_tensor(np_data2) 13 | paddle.distributed.all_gather(tensor_list, data1) 14 | print(tensor_list) 15 | else: 16 | np_data1 = np.array([[1, 2, 3], [1, 2, 3]]) 17 | np_data2 = np.array([[1, 2, 3], [1, 2, 3]]) 18 | data1 = paddle.to_tensor(np_data1) 19 | data2 = paddle.to_tensor(np_data2) 20 | paddle.distributed.all_gather(tensor_list, data2) -------------------------------------------------------------------------------- /ernie/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import division 16 | from __future__ import absolute_import 17 | from __future__ import print_function 18 | from __future__ import unicode_literals 19 | 20 | import sys 21 | sys.path.append("..") 22 | from poly_ernie import tokenization 23 | import logging 24 | 25 | import paddle 26 | if paddle.__version__ != '0.0.0' and paddle.__version__ < '2.0.0': 27 | raise RuntimeError('propeller 0.2 requires paddle 2.0+, got %s' % 28 | paddle.__version__) 29 | 30 | from ernie.modeling_ernie import ErnieModel 31 | from ernie.modeling_ernie import ( 32 | ErnieModelForSequenceClassification, ErnieModelForTokenClassification, 33 | ErnieModelForQuestionAnswering, ErnieModelForPretraining) 34 | 35 | from ernie.tokenizing_ernie import ErnieTokenizer, ErnieTinyTokenizer 36 | 37 | log = logging.getLogger(__name__) 38 | formatter = logging.Formatter(fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]: %(message)s') 39 | stream_hdl = logging.StreamHandler(stream=sys.stderr) 40 | stream_hdl.setFormatter(formatter) 41 | log.addHandler(stream_hdl) 42 | log.propagate = False -------------------------------------------------------------------------------- /ernie/run_rocketqa.py: -------------------------------------------------------------------------------- 1 | from paddle.fluid.dygraph.parallel import ParallelEnv 2 | from tqdm import tqdm 3 | import pickle as pkl 4 | import pandas as pd 5 | import rocketqa 6 | import paddle 7 | import os 8 | 9 | def main(): 10 | batch_size = 1024 11 | local_rank = paddle.distributed.get_rank() 12 | model = rocketqa.load_model("v2_marco_ce", use_cuda=True, device_id=0, batch_size=batch_size) 13 | _nranks = ParallelEnv().nranks 14 | query = pd.read_csv('/home/user/hdfs_data/data_train/dev.query.txt',sep="\t",header=None) 15 | query.columns = ['qid','text'] 16 | query.index = query.qid 17 | query.pop('qid') 18 | collection = pd.read_csv("/home/user/hdfs_data/data_train/marco/collection.tsv",header=None,sep='\t') 19 | top1000 = pd.read_csv("/home/user/hdfs_data/data_train/run.bm25.dev.small.tsv",sep="\t",header=None) 20 | # query = pd.read_csv('data/dev.query.txt',sep="\t",header=None) 21 | # collection = pd.read_csv("data/marco/collection.tsv",header=None,sep='\t') 22 | # top1000 = pd.read_csv("data/run.bm25.dev.small.tsv",sep="\t",header=None) 23 | new_batch_size = batch_size*_nranks 24 | 25 | qrys = [] 26 | psgs = [] 27 | qids = [] 28 | pids = [] 29 | preds = [] 30 | qids_save = [] 31 | pids_save = [] 32 | for i in tqdm(range(len(top1000))): 33 | qid = top1000.iloc[i][0] 34 | qids.append(qid) 35 | pid = top1000.iloc[i][1] 36 | pids.append(pid) 37 | qrys.append(query.loc[qid].text) 38 | psgs.append(collection.iloc[pid][1]) 39 | if (i+1)%new_batch_size==0: 40 | qrys = qrys[local_rank::_nranks] 41 | psgs = psgs[local_rank::_nranks] 42 | qids = qids[local_rank::_nranks] 43 | pids = pids[local_rank::_nranks] 44 | scores = model.matching(query=qrys, para=psgs) 45 | preds+=scores 46 | qids_save+=qids 47 | pids_save+=pids 48 | qrys = [] 49 | psgs = [] 50 | qids = [] 51 | pids = [] 52 | if local_rank==0 and len(qids)!=0: 53 | scores = model.matching(query=qrys, para=psgs) 54 | preds+=scores 55 | qids_save+=qids 56 | pids_save+=pids 57 | pkl.dump(qids_save, open("output/qids.%d.pkl"%local_rank,"wb")) 58 | pkl.dump(pids_save, open("output/pids.%d.pkl"%local_rank,"wb")) 59 | pkl.dump(preds, open("output/preds.%d.pkl"%local_rank,"wb")) 60 | 61 | if __name__=="__main__": 62 | main() 
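
# ---------------------------------------------------------------------------
# A minimal post-processing sketch (not part of the original script): merging the
# per-rank pickles dumped by main() above into a single run file sorted by the
# re-ranker score. The rank count and output path below are assumptions; adjust
# them to match the actual paddle.distributed launch configuration.
def merge_rank_outputs(nranks=8, out_path="output/rerank.run.tsv"):
    from collections import defaultdict
    per_query = defaultdict(list)
    for rank in range(nranks):
        # these filenames follow the pkl.dump calls at the end of main()
        qids = pkl.load(open("output/qids.%d.pkl" % rank, "rb"))
        pids = pkl.load(open("output/pids.%d.pkl" % rank, "rb"))
        preds = pkl.load(open("output/preds.%d.pkl" % rank, "rb"))
        for qid, pid, score in zip(qids, pids, preds):
            per_query[qid].append((pid, float(score)))
    with open(out_path, "w") as f:
        for qid, cands in per_query.items():
            # candidates of each query are ranked by score, highest first
            for r, (pid, score) in enumerate(sorted(cands, key=lambda x: -x[1]), 1):
                f.write("%s\t%s\t%d\t%.6f\n" % (qid, pid, r, score))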
-------------------------------------------------------------------------------- /ernie/make_pretrain_data.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import random 3 | import numpy as np 4 | import pandas as pd 5 | from functools import reduce, partial 6 | def apply_mask(input_ids,mask_rate=0.15): 7 | x,y = np.where(input_ids==102) 8 | y_sep = y[::2] 9 | y_end = y[1::2] 10 | new_x = np.concatenate([[i]*(v-2) for i,v in enumerate(y_end)]).reshape(-1) 11 | new_y = np.concatenate([list(range(1,y_sep[i]))+list(range(y_sep[i]+1, v)) for i,v in enumerate(y_end)]).reshape(-1) 12 | mask_pos = random.choices(range(len(new_x)), k=max(1, int(len(new_x)*mask_rate))) 13 | mask_pos = new_x[mask_pos],new_y[mask_pos] 14 | mask_label = input_ids[mask_pos] 15 | rand = np.random.rand(*mask_pos[0].shape) 16 | choose_original = rand < 0.1 # 17 | choose_random_id = (0.1 < rand) & (rand < 0.2) # 18 | choose_mask_id = 0.2 < rand # 19 | random_id = np.random.randint(1, 30522, size=mask_pos[0].shape) 20 | 21 | replace_id = 103 * choose_mask_id + \ 22 | random_id * choose_random_id + \ 23 | mask_label * choose_original 24 | input_ids[mask_pos] = replace_id 25 | return input_ids, np.stack(mask_pos, -1), mask_label 26 | 27 | 28 | 29 | def make_data(cfg): 30 | collection = pd.read_csv(cfg['collection'],header=None,sep='\t') 31 | collection.columns=['pid','text'] 32 | doc_number = len(collection) 33 | nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat']) 34 | nlp.add_pipe('sentencizer') 35 | cnts = [0,0,0] # random, next, previous 36 | for i in range(doc_number): 37 | doc = nlp(collection['text'][i]) 38 | sents = list(doc.sents) 39 | sent_number = len(sents) 40 | for j,sent in enumerate(sents): 41 | if j and j1: 54 | srp_type = 0 if cnts[0] output/log/log.${i} 2>&1 & 58 | done 59 | 60 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" 61 | # upload 62 | echo "Starting uploading file to HDFS" 63 | # tar -zcvf /root/paddlejob/workspace/env_run/output.tar.gz gen_data/ 64 | ${hdfs_cmd} -mkdir /user/sasd-adv/diaoyan/user/modelzoo/${model_name} 65 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/output /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 66 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/ernie /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 67 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/script /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 68 | echo "Done uploading file to HDFS" 69 | -------------------------------------------------------------------------------- /script/gendata.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=gen_data 3 | model_size=large 4 | generate_type=$1 5 | sample_num=20 6 | batch_size=`expr 80 / $sample_num` 7 | if [[ ${model_size} =~ "base" ]]; then 8 | batch_size=`expr $batch_size \* 2` 9 | fi 10 | echo "batch size ${batch_size}" 11 | dev_batch_size=64 12 | top1000='data/marco/top1000-train' 13 | dev_input_file='data/marco/top1000-dev' 14 | warmup_proportion=0.2 15 | eval_step_proportion=0.1 16 | report_step=10 17 | epoch=5 18 | # eval_step_proportion=`echo "scale=5; 1/$epoch" | bc` 19 | ### 下面是永远不用改的 20 | min_index=25 21 | max_index=768 22 | max_seq_len=160 23 | collection='data/marco/collection.tsv' 24 | ernie_config_file=${model_size}/ernie_config.json #ernie配置文件 25 | vocab_file=${model_size}/vocab.txt #ernie配置文件 26 | warm_start_from=data/ernie_${model_size}.p #ernie参数 27 | 
qrels='data/marco/qrels.tsv' 28 | query='data/marco/train.query.txt' 29 | resource='data/concept.txt' 30 | cpnet='data/conceptnet.en.pruned.graph' 31 | pattern_path='data/matcher_patterns.json' 32 | word2vec='data/GoogleNews-vectors-negative300.bin.gz' 33 | ent_emb='data/glove.transe.sgd.ent.npy' 34 | rel_emb='data/glove.transe.sgd.rel.npy' 35 | output_dir=output 36 | log_dir=${output_dir}/log 37 | mkdir -p ${output_dir} 38 | mkdir -p ${log_dir} 39 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}==================" 40 | python -m paddle.distributed.launch \ 41 | --log_dir ${log_dir} \ 42 | ernie/generate_data.py \ 43 | --ernie_config_file=${ernie_config_file} \ 44 | --vocab_file=${vocab_file} \ 45 | --resource=${resource} \ 46 | --cpnet=${cpnet} \ 47 | --dev_input_file=${dev_input_file} \ 48 | --warm_start_from=${warm_start_from} \ 49 | --batch_size=${batch_size} \ 50 | --warmup_proportion=${warmup_proportion} \ 51 | --eval_step_proportion=${eval_step_proportion} \ 52 | --report=${report_step} \ 53 | --qrels=${qrels} \ 54 | --query=${query} \ 55 | --collection=${collection} \ 56 | --top1000=${top1000} \ 57 | --min_index=${min_index} \ 58 | --max_index=${max_index} \ 59 | --epoch=${epoch} \ 60 | --sample_num=${sample_num} \ 61 | --dev_batch_size=${dev_batch_size} \ 62 | --num_gnn_layers=3 \ 63 | --pattern_path=${pattern_path} \ 64 | --word2vec=${word2vec} \ 65 | --ent_emb=${ent_emb} \ 66 | --rel_emb=${rel_emb} \ 67 | --max_seq_len=${max_seq_len} \ 68 | --model=ErnieWithGNNv2 \ 69 | --generate_type=${generate_type} 70 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" 71 | # upload 72 | tar -zcvf ${generate_type}.concept.tar.gz gen_data/ 73 | mv ${generate_type}.concept.tar.gz data/marco/ -------------------------------------------------------------------------------- /ernie/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import six 3 | import os 4 | class HParams(object): 5 | """Hyper paramerter""" 6 | 7 | def __init__(self, **kwargs): 8 | for k, v in kwargs.items(): 9 | self.__dict__[k] = v 10 | 11 | def __contains__(self, key): 12 | return key in self.__dict__ 13 | 14 | def __getitem__(self, key): 15 | if key not in self.__dict__: 16 | raise ValueError('key(%s) not in HParams.' 
% key) 17 | return self.__dict__[key] 18 | 19 | def __repr__(self): 20 | return repr(self.to_dict()) 21 | 22 | def __setitem__(self, key, val): 23 | self.__dict__[key] = val 24 | 25 | @classmethod 26 | def from_json(cls, json_str): 27 | """doc""" 28 | d = json.loads(json_str) 29 | if type(d) != dict: 30 | raise ValueError('json object must be dict.') 31 | return HParams.from_dict(d) 32 | 33 | def get(self, key, default=None): 34 | """doc""" 35 | return self.__dict__.get(key, default) 36 | 37 | @classmethod 38 | def from_dict(cls, d): 39 | """doc""" 40 | if type(d) != dict: 41 | raise ValueError('input must be dict.') 42 | hp = HParams(**d) 43 | return hp 44 | 45 | def to_json(self): 46 | """doc""" 47 | return json.dumps(self.__dict__) 48 | 49 | def to_dict(self): 50 | """doc""" 51 | return self.__dict__ 52 | 53 | def print_config(self): 54 | for key,value in self.__dict__.items(): 55 | print(key+":",value) 56 | 57 | def join(self, other): 58 | """doc""" 59 | if not isinstance(other, HParams): 60 | raise ValueError('input must be HParams instance.') 61 | self.__dict__.update(**other.__dict__) 62 | return self 63 | 64 | def _get_dict_from_environ_or_json_or_file(args, env_name): 65 | if args == '': 66 | return None 67 | if args is None: 68 | s = os.environ.get(env_name) 69 | else: 70 | s = args 71 | if os.path.exists(s): 72 | s = open(s).read() 73 | if isinstance(s, six.string_types): 74 | try: 75 | r = eval(s) 76 | except SyntaxError as e: 77 | raise ValueError('json parse error: %s \n>Got json: %s' % 78 | (repr(e), s)) 79 | return r 80 | else: 81 | return s #None 82 | 83 | 84 | def parse_file(filename): 85 | """useless api""" 86 | d = _get_dict_from_environ_or_json_or_file(filename, None) 87 | if d is None: 88 | raise ValueError('file(%s) not found' % filename) 89 | return d -------------------------------------------------------------------------------- /script/large_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=KERMLarge 3 | model_size=large 4 | dataset=marco 5 | sample_num=20 6 | batch_size=3 7 | echo "batch size ${batch_size}" 8 | dev_batch_size=1 9 | pretrain_input_file='data/pretrain/*' 10 | train_input_file=data/${dataset}/train.concept.tar.gz # training data 11 | dev_input_file=data/${dataset}/dev.concept.tar.gz # dev data 12 | test_input_file=data/${dataset}/test.concept.tar.gz # test data 13 | instance_num=502939 #v3: 917012 14 | sample_range=20 15 | warmup_proportion=0.2 16 | eval_step_proportion=0.01 17 | report_step=10 18 | epoch=5 19 | min_index=25 20 | max_index=768 21 | max_seq_len=160 22 | collection=data/${dataset}/collection.tsv 23 | ernie_config_file=${model_size}/ernie_config.json #ernie配置文件 24 | vocab_file=${model_size}/vocab.txt #ernie配置文件 25 | warm_start_from=data/${dataset}/reranker-4gpu-5-large.p 26 | # warm_start_from=data/${dataset}/ernie_base.p 27 | qrels=data/${dataset}/qrels.tsv 28 | query=data/${dataset}/train.query.txt 29 | resource='data/concept.txt' 30 | cpnet='data/conceptnet.en.pruned.graph' 31 | word2vec='data/GoogleNews-vectors-negative300.bin.gz' 32 | ent_emb='data/glove.transe.sgd.ent.npy' 33 | rel_emb='data/glove.transe.sgd.rel.npy' 34 | books='data/books.txt' 35 | output_dir=output 36 | log_dir=${output_dir}/log 37 | mkdir -p ${output_dir} 38 | mkdir -p ${log_dir} 39 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}==================" 40 | python -m paddle.distributed.launch \ 41 | --log_dir ${log_dir} \ 42 | ernie/finetune.py \ 43 | 
--train_input_file=${train_input_file} \ 44 | --ernie_config_file=${ernie_config_file} \ 45 | --vocab_file=${vocab_file} \ 46 | --resource=${resource} \ 47 | --cpnet=${cpnet} \ 48 | --dev_input_file=${dev_input_file} \ 49 | --test_input_file=${test_input_file} \ 50 | --warm_start_from=${warm_start_from} \ 51 | --batch_size=${batch_size} \ 52 | --warmup_proportion=${warmup_proportion} \ 53 | --eval_step_proportion=${eval_step_proportion} \ 54 | --report=${report_step} \ 55 | --qrels=${qrels} \ 56 | --query=${query} \ 57 | --collection=${collection} \ 58 | --top1000=${top1000} \ 59 | --min_index=${min_index} \ 60 | --max_index=${max_index} \ 61 | --epoch=${epoch} \ 62 | --sample_num=${sample_num} \ 63 | --dev_batch_size=${dev_batch_size} \ 64 | --pretrain_input_file=${pretrain_input_file} \ 65 | --ent_emb=${ent_emb} \ 66 | --rel_emb=${rel_emb} \ 67 | --pattern_path=${pattern_path} \ 68 | --word2vec=${word2vec} \ 69 | --instance_num=${instance_num} \ 70 | --dataset=${dataset} \ 71 | --sample_range=${sample_range} \ 72 | --num_gnn_layers=3 73 | 74 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" -------------------------------------------------------------------------------- /script/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=KERMLarge 3 | model_size=large 4 | dataset=marco 5 | sample_num=20 6 | batch_size=`expr 80 / $sample_num` 7 | if [[ ${model_size} =~ "base" ]]; then 8 | batch_size=`expr $batch_size \* 2` 9 | fi 10 | echo "batch size ${batch_size}" 11 | dev_batch_size=1 12 | train_input_file=data/${dataset}/train.concept.tar.gz #训练数据 13 | dev_input_file=data/${dataset}/dev.concept.tar.gz #测试数据 14 | # test_input_file=data/${dataset}/test.concept.tar.gz 15 | instance_num=502939 #v3: 917012 16 | sample_range=20 17 | warmup_proportion=0.2 18 | eval_step_proportion=0.01 19 | report_step=10 20 | epoch=5 21 | # eval_step_proportion=`echo "scale=5; 1/$epoch" | bc` 22 | ### 下面是永远不用改的 23 | min_index=25 24 | max_index=768 25 | max_seq_len=160 26 | collection=data/${dataset}/collection.tsv 27 | ernie_config_file=${model_size}/ernie_config.json #ernie配置文件 28 | vocab_file=${model_size}/vocab.txt #ernie配置文件 29 | warm_start_from=data/${dataset}/pretrain/reranker-4gpu-5.p 30 | qrels=data/${dataset}/qrels.tsv 31 | query=data/${dataset}/train.query.txt 32 | resource='data/concept.txt' 33 | cpnet='data/conceptnet.en.pruned.graph' 34 | pattern_path='data/matcher_patterns.json' 35 | word2vec='data/GoogleNews-vectors-negative300.bin.gz' 36 | ent_emb='data/glove.transe.sgd.ent.npy' 37 | rel_emb='data/glove.transe.sgd.rel.npy' 38 | books='data/books.txt' 39 | output_dir=output 40 | log_dir=${output_dir}/log 41 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}==================" 42 | python -m paddle.distributed.launch \ 43 | --log_dir ${log_dir} \ 44 | ernie/finetune.py \ 45 | --train_input_file=${train_input_file} \ 46 | --ernie_config_file=${ernie_config_file} \ 47 | --vocab_file=${vocab_file} \ 48 | --resource=${resource} \ 49 | --cpnet=${cpnet} \ 50 | --dev_input_file=${dev_input_file} \ 51 | --test_input_file=${test_input_file} \ 52 | --warm_start_from=${warm_start_from} \ 53 | --batch_size=${batch_size} \ 54 | --warmup_proportion=${warmup_proportion} \ 55 | --eval_step_proportion=${eval_step_proportion} \ 56 | --report=${report_step} \ 57 | --qrels=${qrels} \ 58 | --query=${query} \ 59 | --collection=${collection} \ 60 | --top1000=${top1000} \ 61 | --min_index=${min_index} \ 62 | 
--max_index=${max_index} \ 63 | --epoch=${epoch} \ 64 | --sample_num=${sample_num} \ 65 | --dev_batch_size=${dev_batch_size} \ 66 | --pretrain_input_file=${pretrain_input_file} \ 67 | --ent_emb=${ent_emb} \ 68 | --rel_emb=${rel_emb} \ 69 | --pattern_path=${pattern_path} \ 70 | --word2vec=${word2vec} \ 71 | --instance_num=${instance_num} \ 72 | --dataset=${dataset} \ 73 | --sample_range=${sample_range} \ 74 | --num_gnn_layers=3 75 | 76 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" -------------------------------------------------------------------------------- /ernie/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import division 15 | from __future__ import absolute_import 16 | from __future__ import print_function 17 | from __future__ import unicode_literals 18 | 19 | import os 20 | import logging 21 | from tqdm import tqdm 22 | from pathlib import Path 23 | import six 24 | import time 25 | if six.PY2: 26 | from pathlib2 import Path 27 | else: 28 | from pathlib import Path 29 | 30 | log = logging.getLogger(__name__) 31 | 32 | 33 | def _fetch_from_remote(url, 34 | force_download=False, 35 | cached_dir='~/.paddle-ernie-cache'): 36 | import hashlib, tempfile, requests, tarfile 37 | sig = hashlib.md5(url.encode('utf8')).hexdigest() 38 | cached_dir = Path(cached_dir).expanduser() 39 | try: 40 | cached_dir.mkdir() 41 | except OSError: 42 | pass 43 | cached_dir_model = cached_dir / sig 44 | from filelock import FileLock 45 | with FileLock(str(cached_dir_model) + '.lock'): 46 | donefile = cached_dir_model / 'done' 47 | if (not force_download) and donefile.exists(): 48 | log.debug('%s cached in %s' % (url, cached_dir_model)) 49 | return cached_dir_model 50 | cached_dir_model.mkdir(exist_ok=True) 51 | tmpfile = cached_dir_model / 'tmp' 52 | with tmpfile.open('wb') as f: 53 | r = requests.get(url, stream=True) 54 | total_len = int(r.headers.get('content-length')) 55 | for chunk in tqdm( 56 | r.iter_content(chunk_size=1024), 57 | total=total_len // 1024, 58 | desc='downloading %s' % url, 59 | unit='KB'): 60 | if chunk: 61 | f.write(chunk) 62 | f.flush() 63 | log.debug('extacting... 
to %s' % tmpfile) 64 | with tarfile.open(tmpfile.as_posix()) as tf: 65 | tf.extractall(path=str(cached_dir_model)) 66 | donefile.touch() 67 | os.remove(tmpfile.as_posix()) 68 | 69 | return cached_dir_model 70 | 71 | 72 | def add_docstring(doc): 73 | def func(f): 74 | f.__doc__ += ('\n======other docs from supper class ======\n%s' % doc) 75 | return f 76 | 77 | return func -------------------------------------------------------------------------------- /script/pretrain_large.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=pretrain 3 | model_size=large 4 | dataset=marco 5 | sample_num=20 6 | batch_size=`expr 80 / $sample_num` 7 | if [[ ${model_size} =~ "base" ]]; then 8 | batch_size=`expr $batch_size \* 2` 9 | fi 10 | echo "batch size ${batch_size}" 11 | batch_size=8 12 | pretrain_batch_size=64 13 | run_func=pretrain 14 | dev_batch_size=1 15 | pretrain_input_file='data/marco/pretrain' 16 | instance_num=502939 #v3: 917012 17 | sample_range=20 18 | warmup_proportion=0.2 19 | eval_step_proportion=0.01 20 | report_step=10 21 | epoch=5 22 | # eval_step_proportion=`echo "scale=5; 1/$epoch" | bc` 23 | ### 下面是永远不用改的 24 | min_index=25 25 | max_index=768 26 | max_seq_len=80 27 | collection=data/${dataset}/collection.tsv 28 | ernie_config_file=${model_size}/ernie_config.json #ernie配置文件 29 | vocab_file=${model_size}/vocab.txt #ernie配置文件 30 | warm_start_from=data/${dataset}/ernie_large.p 31 | qrels=data/${dataset}/qrels.tsv 32 | query=data/${dataset}/train.query.txt 33 | resource='data/concept.txt' 34 | cpnet='data/conceptnet.en.pruned.graph' 35 | pattern_path='data/matcher_patterns.json' 36 | word2vec='data/GoogleNews-vectors-negative300.bin.gz' 37 | ent_emb='data/glove.transe.sgd.ent.npy' 38 | rel_emb='data/glove.transe.sgd.rel.npy' 39 | books='data/books.txt' 40 | output_dir=output 41 | log_dir=${output_dir}/log 42 | mkdir -p ${output_dir} 43 | mkdir -p ${log_dir} 44 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}==================" 45 | python -m paddle.distributed.launch \ 46 | --log_dir ${log_dir} \ 47 | ernie/finetune.py \ 48 | --train_input_file=${train_input_file} \ 49 | --ernie_config_file=${ernie_config_file} \ 50 | --vocab_file=${vocab_file} \ 51 | --resource=${resource} \ 52 | --cpnet=${cpnet} \ 53 | --dev_input_file=${dev_input_file} \ 54 | --warm_start_from=${warm_start_from} \ 55 | --batch_size=${batch_size} \ 56 | --warmup_proportion=${warmup_proportion} \ 57 | --eval_step_proportion=${eval_step_proportion} \ 58 | --report=${report_step} \ 59 | --qrels=${qrels} \ 60 | --query=${query} \ 61 | --collection=${collection} \ 62 | --top1000=${top1000} \ 63 | --min_index=${min_index} \ 64 | --max_index=${max_index} \ 65 | --epoch=${epoch} \ 66 | --sample_num=${sample_num} \ 67 | --dev_batch_size=${dev_batch_size} \ 68 | --pretrain_input_file=${pretrain_input_file} \ 69 | --ent_emb=${ent_emb} \ 70 | --rel_emb=${rel_emb} \ 71 | --pattern_path=${pattern_path} \ 72 | --word2vec=${word2vec} \ 73 | --instance_num=${instance_num} \ 74 | --dataset=${dataset} \ 75 | --sample_range=${sample_range} \ 76 | --max_seq_len=${max_seq_len} \ 77 | --run_func=${run_func} \ 78 | --pretrain_batch_size=${pretrain_batch_size} \ 79 | --num_gnn_layers=3 80 | 81 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" 82 | # upload 83 | mv output/* data/marco/pretrain/ -------------------------------------------------------------------------------- /data/ohsumed/queries.txt: 
-------------------------------------------------------------------------------- 1 | 0 Are there adverse effects on lipids when progesterone is given with estrogen replacement therapy 2 | 1 pathophysiology and treatment of disseminated intravascular coagulation 3 | 2 anticardiolipin and lupus anticoagulants, pathophysiology, epidemiology, complications 4 | 3 effectiveness of etidronate in treating hypercalcemia of malignancy 5 | 4 does estrogen replacement therapy cause breast cancer 6 | 5 t-cell lymphoma associated with autoimmune symptoms 7 | 6 lactase deficiency therapy options 8 | 7 pancytopenia in aids, workup and etiology 9 | 8 Rh isoimmunization, review topics 10 | 9 endocarditis, duration of antimicrobial therapy 11 | 10 chemotherapy advanced for advanced metastatic breast cancer 12 | 11 isolated hypoaldosteronism, syndromes where hypoaldosteronism and hypokalemia occur concurrently 13 | 12 sickle cell disease, treatment advice 14 | 13 thrombocytopenia in pregnancy, etiology and management 15 | 14 chronic pain management, review article, use of tricyclic antidepressants 16 | 15 carotid endarterectomy, when to perform 17 | 16 review article on adult respiratory syndrome 18 | 17 RISK FACTORS and TREATMENT for HEPATOCELLULAR CARCINOMA 19 | 18 FIBROMYALGIA/FIBROSITIS, DIAGNOSIS AND TREATMENT 20 | 19 DIABETIC GASTROPARESIS, TREATMENT 21 | 20 back pain, information on diagnosis and treatment 22 | 21 can radiation therapy cause a delayed pericardial effusion? 23 | 22 occult blood screening, need for routine screening 24 | 23 urinary retention, differential diagnosis 25 | 24 which peripheral neuropathies have associated edema 26 | 25 isolated systolic hypertension, shep study 27 | 26 differential diagnosis of U waves 28 | 27 indications for and success of pericardial windows and pericardectomies 29 | 28 lupus nephritis, diagnosis and management 30 | 29 angiotensin converting enzyme inhibitors, review article 31 | 30 course of anticoagulation with coumadin 32 | 31 cerebral edema secondary to infection, diagnosis and treatment 33 | 32 diagnostic and therapeutic work up of breast mass 34 | 33 hepatobiliary lesions associated with neurofibromatosis 35 | 34 evaluation for complications and management of bulimia 36 | 35 treatment of migraine headaches with beta blockers and calcium channel blockers 37 | 36 prevention, risk factors, pathophysiology of hypothermia 38 | 37 chronic inflammatory demyelinating polyneuropathy, differential diagnosis and criteria 39 | 38 outpatient management of diabetes, standard management of diabetics and any new management techniques 40 | 39 diverticulitis, differential diagnosis and management 41 | 40 cystic fibrosis and renal failure, effect of long term repeated use of aminoglycosides 42 | 41 thyrotoxicosis, diagnosis and management 43 | 42 neuroleptic malignant syndrome, differential diagnosis, treatment 44 | 43 carcinoid tumors of the liver and pancreas, research, treatments 45 | 44 radiation induced thyroiditis, differential diagnosis, management 46 | 45 heat exhaustion, management and pathophysiology 47 | 46 complications and management of anorexia and bulimia 48 | 47 aids dementia, workup 49 | 48 infections in renal transplant patients 50 | 49 theophylline uses--chronic and acute asthma 51 | 50 lung cancer, radiation therapy 52 | 51 surgery vs. 
percutaneous drainage for lung abscess 53 | 52 Catamenorrheal Anaphylaxis 54 | 53 PLASMAPHERESIS AS THERAPEUTIC OPTION for Guillain- Barre syndrome 55 | 54 Urinary Tract Infection, CRITERIA FOR TREATMENT AND ADMISSION 56 | 55 infiltrative small bowel processes, information about small bowel lymphoma and heavy alpha chain disease 57 | 56 iron deficiency anemia, which test is best 58 | 57 scheurmann's disease, treatment 59 | 58 sigmoidoscopy in preventive care, whether the recommended frequency of sigmoidoscopy is effective and sensitive in detecting cancer 60 | 59 how to best control pain and debilitation secondary to osteoporosis in never treated advanced disease 61 | 60 differential diagnosis of breakthrough vaginal bleeding while on estrogen and progesterone therapy 62 | 61 review of anemia of chronic illness 63 | 62 HIV and the GI tract, recent reviews 64 | -------------------------------------------------------------------------------- /script/run_rocketqav1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=$1 3 | queue=$2 4 | model_size=$3 5 | dataset=marco 6 | sample_num=20 7 | if [[ ${dataset} =~ "treccar" ]]; then 8 | sample_num=10 9 | fi 10 | if [[ ${queue} =~ "v100" ]]; then 11 | batch_size=`expr 40 / $sample_num` 12 | fi 13 | if [[ ${queue} =~ "a100" ]]; then 14 | batch_size=`expr 80 / $sample_num` 15 | fi 16 | if [[ ${model_size} =~ "base" ]]; then 17 | batch_size=`expr $batch_size \* 2` 18 | fi 19 | echo "batch size ${batch_size}" 20 | dev_batch_size=1 21 | pretrain_input_file='data/pretrain/*' 22 | train_input_file=data/${dataset}/train.concept.tar.gz #训练数据 23 | dev_input_file=data/${dataset}/dev.concept.dl2019.tar.gz #测试数据 24 | instance_num=502939 #v3: 917012 25 | sample_range=20 26 | if [[ ${dataset} =~ "treccar" ]]; then 27 | instance_num=2806552 28 | sample_range=10 29 | fi 30 | warmup_proportion=0.2 31 | eval_step_proportion=0.01 32 | report_step=10 33 | epoch=5 34 | # eval_step_proportion=`echo "scale=5; 1/$epoch" | bc` 35 | ### 下面是永远不用改的 36 | min_index=25 37 | max_index=768 38 | max_seq_len=160 39 | collection=data/${dataset}/collection.tsv 40 | ernie_config_file=${model_size}/ernie_config.json #ernie配置文件 41 | vocab_file=${model_size}/vocab.txt #ernie配置文件 42 | # warm_start_from=data/reranker-4gpu-5-2.p #ernie参数 43 | # warm_start_from=data/kgbest.p 44 | # warm_start_from=data/${dataset}/reranker-4gpu-5.p 45 | warm_start_from=data/${dataset}/ernie_base.p 46 | qrels=data/${dataset}/qrels.tsv 47 | query=data/${dataset}/train.query.txt 48 | resource='data/concept.txt' 49 | cpnet='data/conceptnet.en.pruned.graph' 50 | pattern_path='data/matcher_patterns.json' 51 | word2vec='data/GoogleNews-vectors-negative300.bin.gz' 52 | ent_emb='data/glove.transe.sgd.ent.npy' 53 | rel_emb='data/glove.transe.sgd.rel.npy' 54 | books='data/books.txt' 55 | output_dir=output 56 | log_dir=${output_dir}/log 57 | mkdir -p ${output_dir} 58 | mkdir -p ${log_dir} 59 | rm -rf /etc/pip.conf 60 | cp pip.conf /etc/pip.conf 61 | pip install networkx 62 | pip install pgl 63 | pip install spacy 64 | pip install nltk 65 | pip install gensim 66 | pip install rocketqa 67 | pip install data/gensim-4.1.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl 68 | python -m spacy download en_core_web_sm 69 | pip install data/en_core_web_sm-3.2.0-py3-none-any.whl 70 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}==================" 71 | python -m paddle.distributed.launch \ 72 | --log_dir ${log_dir} \ 73 | ernie/finetune.py \ 74 | 
--train_input_file=${train_input_file} \ 75 | --ernie_config_file=${ernie_config_file} \ 76 | --vocab_file=${vocab_file} \ 77 | --resource=${resource} \ 78 | --cpnet=${cpnet} \ 79 | --dev_input_file=${dev_input_file} \ 80 | --warm_start_from=${warm_start_from} \ 81 | --batch_size=${batch_size} \ 82 | --warmup_proportion=${warmup_proportion} \ 83 | --eval_step_proportion=${eval_step_proportion} \ 84 | --report=${report_step} \ 85 | --qrels=${qrels} \ 86 | --query=${query} \ 87 | --collection=${collection} \ 88 | --top1000=${top1000} \ 89 | --min_index=${min_index} \ 90 | --max_index=${max_index} \ 91 | --epoch=${epoch} \ 92 | --sample_num=${sample_num} \ 93 | --dev_batch_size=${dev_batch_size} \ 94 | --pretrain_input_file=${pretrain_input_file} \ 95 | --ent_emb=${ent_emb} \ 96 | --rel_emb=${rel_emb} \ 97 | --pattern_path=${pattern_path} \ 98 | --word2vec=${word2vec} \ 99 | --instance_num=${instance_num} \ 100 | --dataset=${dataset} \ 101 | --sample_range=${sample_range} \ 102 | --num_gnn_layers=3 103 | 104 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" 105 | # upload 106 | echo "Starting uploading file to HDFS" 107 | # tar -zcvf /root/paddlejob/workspace/env_run/output.tar.gz gen_data/ 108 | ${hdfs_cmd} -mkdir /user/sasd-adv/diaoyan/user/modelzoo/${model_name} 109 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/output /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 110 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/ernie /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 111 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/script /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 112 | echo "Done uploading file to HDFS" 113 | -------------------------------------------------------------------------------- /script/pretrain_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=$1 3 | queue=$2 4 | model_size=$3 5 | dataset=marco 6 | sample_num=20 7 | if [[ ${dataset} =~ "treccar" ]]; then 8 | sample_num=10 9 | fi 10 | if [[ ${queue} =~ "v100" ]]; then 11 | batch_size=`expr 40 / $sample_num` 12 | fi 13 | if [[ ${queue} =~ "a100" ]]; then 14 | batch_size=`expr 80 / $sample_num` 15 | fi 16 | if [[ ${model_size} =~ "base" ]]; then 17 | batch_size=`expr $batch_size \* 2` 18 | fi 19 | echo "batch size ${batch_size}" 20 | dev_batch_size=1 21 | pretrain_input_file='data/pretrain/*' 22 | kerm_version='v3' 23 | run_func='pretrain' 24 | train_input_file=data/${dataset}/train.concept.tar.gz #训练数据 25 | dev_input_file=data/${dataset}/dev.concept.dl2019.tar.gz #测试数据 26 | instance_num=502939 #v3: 917012 27 | sample_range=20 28 | if [[ ${dataset} =~ "treccar" ]]; then 29 | instance_num=2806552 30 | sample_range=10 31 | fi 32 | warmup_proportion=0.2 33 | eval_step_proportion=0.01 34 | report_step=10 35 | epoch=5 36 | # eval_step_proportion=`echo "scale=5; 1/$epoch" | bc` 37 | ### 下面是永远不用改的 38 | min_index=25 39 | max_index=768 40 | max_seq_len=160 41 | collection=data/${dataset}/collection.tsv 42 | ernie_config_file=${model_size}/ernie_config.json #ernie配置文件 43 | vocab_file=${model_size}/vocab.txt #ernie配置文件 44 | # warm_start_from=data/reranker-4gpu-5-2.p #ernie参数 45 | # warm_start_from=data/kgbest.p 46 | # warm_start_from=data/${dataset}/reranker-4gpu-5.p 47 | warm_start_from=data/${dataset}/ernie_base.p 48 | qrels=data/${dataset}/qrels.tsv 49 | query=data/${dataset}/train.query.txt 50 | resource='data/concept.txt' 51 | cpnet='data/conceptnet.en.pruned.graph' 52 | 
pattern_path='data/matcher_patterns.json' 53 | word2vec='data/GoogleNews-vectors-negative300.bin.gz' 54 | ent_emb='data/glove.transe.sgd.ent.npy' 55 | rel_emb='data/glove.transe.sgd.rel.npy' 56 | books='data/books.txt' 57 | output_dir=output 58 | log_dir=${output_dir}/log 59 | mkdir -p ${output_dir} 60 | mkdir -p ${log_dir} 61 | rm -rf /etc/pip.conf 62 | cp pip.conf /etc/pip.conf 63 | pip install networkx 64 | pip install pgl 65 | pip install spacy 66 | pip install nltk 67 | pip install gensim 68 | pip install data/gensim-4.1.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl 69 | python -m spacy download en_core_web_sm 70 | pip install data/en_core_web_sm-3.2.0-py3-none-any.whl 71 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}==================" 72 | python -m paddle.distributed.launch \ 73 | --log_dir ${log_dir} \ 74 | ernie/finetune.py \ 75 | --train_input_file=${train_input_file} \ 76 | --ernie_config_file=${ernie_config_file} \ 77 | --vocab_file=${vocab_file} \ 78 | --resource=${resource} \ 79 | --cpnet=${cpnet} \ 80 | --dev_input_file=${dev_input_file} \ 81 | --warm_start_from=${warm_start_from} \ 82 | --batch_size=${batch_size} \ 83 | --warmup_proportion=${warmup_proportion} \ 84 | --eval_step_proportion=${eval_step_proportion} \ 85 | --report=${report_step} \ 86 | --qrels=${qrels} \ 87 | --query=${query} \ 88 | --collection=${collection} \ 89 | --top1000=${top1000} \ 90 | --min_index=${min_index} \ 91 | --max_index=${max_index} \ 92 | --epoch=${epoch} \ 93 | --sample_num=${sample_num} \ 94 | --dev_batch_size=${dev_batch_size} \ 95 | --pretrain_input_file=${pretrain_input_file} \ 96 | --ent_emb=${ent_emb} \ 97 | --rel_emb=${rel_emb} \ 98 | --pattern_path=${pattern_path} \ 99 | --word2vec=${word2vec} \ 100 | --instance_num=${instance_num} \ 101 | --dataset=${dataset} \ 102 | --sample_range=${sample_range} \ 103 | --kerm_version=${kerm_version} \ 104 | --run_func=${run_func} \ 105 | --num_gnn_layers=3 106 | 107 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" 108 | # upload 109 | echo "Starting uploading file to HDFS" 110 | # tar -zcvf /root/paddlejob/workspace/env_run/output.tar.gz gen_data/ 111 | ${hdfs_cmd} -mkdir /user/sasd-adv/diaoyan/user/modelzoo/${model_name} 112 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/output /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 113 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/ernie /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 114 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/script /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 115 | echo "Done uploading file to HDFS" 116 | -------------------------------------------------------------------------------- /script/base_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=$1 3 | queue=$2 4 | model_size=$3 5 | dataset=marco 6 | sample_num=20 7 | if [[ ${dataset} =~ "treccar" ]]; then 8 | sample_num=10 9 | fi 10 | if [[ ${queue} =~ "v100" ]]; then 11 | batch_size=`expr 40 / $sample_num` 12 | fi 13 | if [[ ${queue} =~ "a100" ]]; then 14 | batch_size=`expr 80 / $sample_num` 15 | fi 16 | if [[ ${model_size} =~ "base" ]]; then 17 | batch_size=`expr $batch_size \* 2` 18 | fi 19 | echo "batch size ${batch_size}" 20 | dev_batch_size=1 21 | kerm_version='v1' 22 | run_func='finetune' 23 | pretrain_input_file='data/pretrain/*' 24 | # train_input_file=data/${dataset}/train.concept.tar.gz #训练数据 25 | # 
dev_input_file=data/${dataset}/dev.concept.tar.gz #测试数据 26 | # test_input_file=data/${dataset}/test.concept.tar.gz 27 | train_input_file=data/marco/train.concept.tar.gz #训练数据 28 | dev_input_file=data/ohsumed/dev.concept.tar.gz #测试数据 29 | test_input_file=data/marco/dev.concept.dl2019.tar.gz 30 | instance_num=502939 #v3: 917012 31 | sample_range=20 32 | warmup_proportion=0.2 33 | eval_step_proportion=0.01 34 | report_step=10 35 | epoch=5 36 | # eval_step_proportion=`echo "scale=5; 1/$epoch" | bc` 37 | ### 下面是永远不用改的 38 | min_index=25 39 | max_index=768 40 | max_seq_len=160 41 | collection=data/${dataset}/collection.tsv 42 | ernie_config_file=${model_size}/ernie_config.json #ernie配置文件 43 | vocab_file=${model_size}/vocab.txt #ernie配置文件 44 | warm_start_from=data/${dataset}/reranker-4gpu-5.p 45 | # warm_start_from=data/${dataset}/ernie_base.p 46 | qrels=data/${dataset}/qrels.tsv 47 | query=data/${dataset}/train.query.txt 48 | resource='data/concept.txt' 49 | cpnet='data/conceptnet.en.pruned.graph' 50 | pattern_path='data/matcher_patterns.json' 51 | word2vec='data/GoogleNews-vectors-negative300.bin.gz' 52 | ent_emb='data/glove.transe.sgd.ent.npy' 53 | rel_emb='data/glove.transe.sgd.rel.npy' 54 | books='data/books.txt' 55 | output_dir=output 56 | log_dir=${output_dir}/log 57 | mkdir -p ${output_dir} 58 | mkdir -p ${log_dir} 59 | rm -rf /etc/pip.conf 60 | cp pip.conf /etc/pip.conf 61 | pip install networkx 62 | pip install pgl 63 | pip install spacy 64 | pip install nltk 65 | pip install gensim 66 | pip install data/gensim-4.1.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl 67 | python -m spacy download en_core_web_sm 68 | pip install data/en_core_web_sm-3.2.0-py3-none-any.whl 69 | echo "=================start train ${OMPI_COMM_WORLD_RANK:-0}==================" 70 | python -m paddle.distributed.launch \ 71 | --log_dir ${log_dir} \ 72 | ernie/finetune.py \ 73 | --train_input_file=${train_input_file} \ 74 | --ernie_config_file=${ernie_config_file} \ 75 | --vocab_file=${vocab_file} \ 76 | --resource=${resource} \ 77 | --cpnet=${cpnet} \ 78 | --dev_input_file=${dev_input_file} \ 79 | --test_input_file=${test_input_file} \ 80 | --warm_start_from=${warm_start_from} \ 81 | --batch_size=${batch_size} \ 82 | --warmup_proportion=${warmup_proportion} \ 83 | --eval_step_proportion=${eval_step_proportion} \ 84 | --report=${report_step} \ 85 | --qrels=${qrels} \ 86 | --query=${query} \ 87 | --collection=${collection} \ 88 | --top1000=${top1000} \ 89 | --min_index=${min_index} \ 90 | --max_index=${max_index} \ 91 | --epoch=${epoch} \ 92 | --sample_num=${sample_num} \ 93 | --dev_batch_size=${dev_batch_size} \ 94 | --pretrain_input_file=${pretrain_input_file} \ 95 | --ent_emb=${ent_emb} \ 96 | --rel_emb=${rel_emb} \ 97 | --pattern_path=${pattern_path} \ 98 | --word2vec=${word2vec} \ 99 | --instance_num=${instance_num} \ 100 | --dataset=${dataset} \ 101 | --sample_range=${sample_range} \ 102 | --kerm_version=${kerm_version} \ 103 | --run_func=${run_func} \ 104 | --num_gnn_layers=0 105 | 106 | echo "=================done train ${OMPI_COMM_WORLD_RANK:-0}==================" 107 | # upload 108 | echo "Starting uploading file to HDFS" 109 | # tar -zcvf /root/paddlejob/workspace/env_run/output.tar.gz gen_data/ 110 | ${hdfs_cmd} -mkdir /user/sasd-adv/diaoyan/user/modelzoo/${model_name} 111 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/output /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 112 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/ernie /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 
113 | ${hdfs_cmd} -put /root/paddlejob/workspace/env_run/script /user/sasd-adv/diaoyan/user/modelzoo/${model_name}/ 114 | echo "Done uploading file to HDFS" 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KERM 2 | 3 | Code for KERM: Incorporating Explicit Knowledge in Pre-trained Language Models for Passage Re-ranking, accepted at SIGIR 2022. 4 | 5 | ### Dependencies 6 | > networkx==2.6.3 7 | 8 | > paddlepaddle-gpu==2.1.0 9 | 10 | > pgl==2.1.5 11 | 12 | > spacy==3.2.0 13 | 14 | > gensim==4.1.2 15 | 16 | > At least 4*Tesla A100(40GB) 17 | ### Data preparation 18 | 1. The ConceptNet-related resources used in KERM can be downloaded from [here](https://drive.google.com/drive/folders/155codqEnsKazO8-BchF3rO_cP3EyYdws), which are provided by [MHGRN](https://github.com/INK-USC/MHGRN). Place them in "data/". 19 | 2. MARCO and TREC 2019 DL can be downloaded from [here](https://microsoft.github.io/msmarco/). Place them in "data/dataset_name". 20 | 3. The bio-medical dataset Ohsumed is available [here](http://disi.unitn.it/moschitti/corpora.htm). Place it in "data/dataset_name". 21 | 4. The word2vec embedding can be downloaded from [here](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?resourcekey=0-wjGZdNAUop6WykTtMip30g). Place it in "data/". 22 | 5. The parameters of ERNIE-2.0 can be downloaded from [here](https://github.com/PaddlePaddle/ERNIE/blob/3fb0b4911d5be66ef157df3d46d046e16ffc7b36/README.eng.md#3-download-pretrained-models-optional), including both the base and large models. Place them in "data/". 23 | 6. The top1000-train passages for training queries can be retrieved with [this](https://github.com/PaddlePaddle/RocketQA/tree/main/research/RocketQA_NAACL2021#dual-encoder-inference), from which the hard negatives are sampled based on the Dense Passage Retrieval (DPR) model in [RocketQA](https://github.com/PaddlePaddle/RocketQA/tree/main/research/RocketQA_NAACL2021). Notably, the [top1000-dev](https://www.dropbox.com/s/5pqpcnlzlib2b3a/run.bm25.dev.small.tsv.gz?dl=1) for dev queries is obtained from this [repo](https://github.com/castorini/duobert), which is widely used in previous works. After obtaining the top1000 data, please convert it into the following format: (Place it in "data/dataset_name".) 24 | >top1000-train 25 | >>query_id passage_id index score 26 | >>>121352 2912791 1 131.90326 27 | 28 | >>>121352 7282917 2 131.07689 29 | 30 | >>>121352 7480161 3 130.65248 31 | 32 | >top1000-dev 33 | >>query_id passage_id query_text - passage_text label 34 | >>>188714 2133570 foods and supplements to lower blood sugar - A healthy diet is essential to reversing prediabetes. There are no foods, herbs, drinks, or supplements that lower blood sugar. Only medication and exercise can. But there are things you can eat and drink that are low on the glycemic index (GI). This means these foods won't raise your blood sugar and may help you avoid a blood sugar spike. 0 35 | 36 | >>>188714 4321742 foods and supplements to lower blood sugar - Ohio State University, researchers saw insulin levels drop 23 percent and blood sugar levels drop 29 percent in patients who took a 1,000-mg dose of the herb. Amazing! These are just a few of the natural foods and supplements that will lower your blood sugar level naturally. 
One thing that is very important is that you keep your health care provider up to date on any supplements that you will be utilizing as a natural way to lower your blood sugar. 0 37 | 38 | >>>188714 4321745 foods and supplements to lower blood sugar - Food And Supplements That Lower Blood Sugar Levels. Cinnamon: Researchers are finding that cinnamon reduces blood sugar levels naturally when taken daily. If you absolutely love cinnamon you can sprinkle the recommended six grams of cinnamon on your food throughout the day to achieve the desired effect. 1 39 | ### Data generation 40 | For training efficiency, we first generate all the data used in this work once and for all. We take MARCO and KERM-large as an example. 41 | 1. Generate the data for knowledge-enhanced pre-training: 42 | >sh script/batch_submit_genpretrain.sh 43 | 2. Generate the data for training and evaluation: 44 | >sh script/gendata.sh train 45 | 46 | >sh script/gendata.sh eval 47 | 48 | ### Knowledge-enhanced pre-training & fine-tuning 49 | To better train KERM, we first continually pre-train ERNIE-large to warm up the parameters of the GMN: 50 | >sh script/pretrain_large.sh 51 | 52 | >sh script/large_finetune.sh 53 | 54 | ### Acknowledgement 55 | Some code snippets are borrowed from [MHGRN](https://github.com/INK-USC/MHGRN), [ERNIE](https://github.com/PaddlePaddle/ERNIE) and [ERNIE-THU](https://github.com/thunlp/ERNIE). 56 | To cite this paper, use the following BibTeX: 57 | > @article{dong2022incorporating, 58 | title={Incorporating Explicit Knowledge in Pre-trained Language Models for Passage Re-ranking}, 59 | author={Dong, Qian and Liu, Yiding and Cheng, Suqi and Wang, Shuaiqiang and Cheng, Zhicong and Niu, Shuzi and Yin, Dawei}, 60 | journal={arXiv preprint arXiv:2204.11673}, 61 | year={2022} 62 | } 63 | -------------------------------------------------------------------------------- /ernie/unbatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | This package implements a heterogeneous graph structure for handling heterogeneous graph data. 16 | """ 17 | 18 | import os 19 | import json 20 | import paddle 21 | import copy 22 | import numpy as np 23 | import pickle as pkl 24 | from collections import defaultdict 25 | 26 | from pgl.graph import Graph 27 | from pgl.utils import op 28 | 29 | def unbatch(graph): 30 | """This method disjoint list of graph into a big graph. 31 | 32 | Args: 33 | 34 | graph_list (Graph List): A list of Graphs. 35 | 36 | merged_graph_index: whether to keeped the graph_id that the nodes belongs to. 37 | 38 | 39 | .. 
code-block:: python 40 | 41 | import numpy as np 42 | import pgl 43 | 44 | num_nodes = 5 45 | edges = [ (0, 1), (1, 2), (3, 4)] 46 | graph = pgl.Graph(num_nodes=num_nodes, 47 | edges=edges) 48 | joint_graph = pgl.Graph.disjoint([graph, graph], merged_graph_index=False) 49 | print(joint_graph.graph_node_id) 50 | >>> [0, 0, 0, 0, 0, 1, 1, 1, 1 ,1] 51 | print(joint_graph.num_graph) 52 | >>> 2 53 | 54 | joint_graph = pgl.Graph.disjoint([graph, graph], merged_graph_index=True) 55 | print(joint_graph.graph_node_id) 56 | >>> [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 57 | print(joint_graph.num_graph) 58 | >>> 1 59 | """ 60 | 61 | edges_list = disjoin_edges(graph) 62 | num_nodes_list = disjoin_nodes(graph) 63 | node_feat_list = disjoin_feature(graph, mode="node") 64 | edge_feat_list = disjoin_feature(graph, mode="edge") 65 | graph_list = [] 66 | for edges, num_nodes, node_feat, edge_feat in zip(edges_list, num_nodes_list, node_feat_list, edge_feat_list): 67 | graph = Graph(num_nodes=num_nodes, 68 | edges=edges, 69 | node_feat=node_feat, 70 | edge_feat=edge_feat) 71 | graph_list.append(graph) 72 | return graph_list 73 | 74 | def disjoin_edges(graph): 75 | """join edges for multiple graph""" 76 | start_offset_list = graph._graph_node_index[: -1] 77 | start_list, end_list = graph._graph_edge_index[: -1], graph._graph_edge_index[1: ] 78 | 79 | edges_list = [] 80 | for start, end, start_offset in zip(start_list, end_list, start_offset_list): 81 | edges = graph.edges[start: end] 82 | edges -= start_offset 83 | edges_list.append(edges) 84 | return edges_list 85 | 86 | def disjoin_nodes(graph): 87 | num_nodes_list = [] 88 | start_list, end_list = graph._graph_node_index[: -1], graph._graph_node_index[1: ] 89 | for start, end in zip(start_list, end_list): 90 | num_nodes_list.append(end - start) 91 | return num_nodes_list 92 | 93 | def disjoin_feature(graph, mode="node"): 94 | """join node features for multiple graph""" 95 | is_tensor = graph.is_tensor() 96 | feat_list = [] 97 | if mode == "node": 98 | start_list, end_list = graph._graph_node_index[: -1], graph._graph_node_index[1: ] 99 | for start, end in zip(start_list, end_list): 100 | feat = defaultdict(lambda: []) 101 | for key in graph.node_feat: 102 | feat[key].append(graph.node_feat[key][start: end]) 103 | feat_list.append(feat) 104 | elif mode == "edge": 105 | start_list, end_list = graph._graph_edge_index[: -1], graph._graph_edge_index[1: ] 106 | for start, end in zip(start_list, end_list): 107 | feat = defaultdict(lambda: []) 108 | for key in graph.edge_feat: 109 | feat[key].append(graph.edge_feat[key][start: end]) 110 | feat_list.append(feat) 111 | else: 112 | raise ValueError( 113 | "mode must be in ['node', 'edge']. 
But received model=%s" % 114 | mode) 115 | 116 | feat_list_temp = [] 117 | for feat in feat_list: 118 | ret_feat = {} 119 | for key in feat: 120 | if len(feat[key]) == 1: 121 | ret_feat[key] = feat[key][0] 122 | else: 123 | if is_tensor: 124 | ret_feat[key] = paddle.concat(feat[key], 0) 125 | else: 126 | ret_feat[key] = np.concatenate(feat[key], axis=0) 127 | feat_list_temp.append(ret_feat) 128 | return feat_list_temp 129 | -------------------------------------------------------------------------------- /ernie/gnn_layer.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | import numpy as np 3 | import math 4 | 5 | import paddle 6 | import paddle as P 7 | import paddle.nn as nn 8 | import paddle.nn.functional as F 9 | import paddle.distributed as dist 10 | 11 | import pgl 12 | import pgl.nn as gnn 13 | from pgl.nn import functional as GF 14 | from pgl.utils.logger import log 15 | 16 | def _build_linear(n_in, n_out, name, init): 17 | return nn.Linear( 18 | n_in, 19 | n_out, 20 | weight_attr=P.ParamAttr( 21 | name='%s.w_0' % name if name is not None else None, 22 | initializer=init), 23 | bias_attr='%s.b_0' % name if name is not None else None, ) 24 | 25 | def batch_norm_1d(num_channels): 26 | if dist.get_world_size() > 1: 27 | return nn.SyncBatchNorm.convert_sync_batchnorm(nn.BatchNorm1D(num_channels)) 28 | else: 29 | return nn.BatchNorm1D(num_channels) 30 | 31 | class LiteGEMConv(paddle.nn.Layer): 32 | def __init__(self, config, with_efeat=True): 33 | super(LiteGEMConv, self).__init__() 34 | log.info("layer_type is %s" % self.__class__.__name__) 35 | self.config = config 36 | self.with_efeat = with_efeat 37 | self.aggr = self.config['aggr'] 38 | self.eps = 1e-7 39 | self.emb_dim = self.config['gnn_hidden_size'] 40 | initializer = nn.initializer.TruncatedNormal( 41 | std=config['initializer_range']) 42 | self.f1 = _build_linear(config['gnn_hidden_size']*2, config['gnn_hidden_size'], name=None, init=initializer) 43 | self.f2 = _build_linear(config['gnn_hidden_size']*2, config['gnn_hidden_size'], name=None, init=initializer) 44 | self.f3 = _build_linear(config['gnn_hidden_size']*2, config['gnn_hidden_size'], name=None, init=initializer) 45 | self.fc_concat = Linear(self.emb_dim * 3, self.emb_dim) 46 | assert self.aggr in ['softmax_sg', 'softmax', 'power'] 47 | 48 | channels_list = [self.emb_dim] 49 | for i in range(1, self.config['mlp_layers']): 50 | channels_list.append(self.emb_dim * 2) 51 | channels_list.append(self.emb_dim) 52 | 53 | self.mlp = MLP(channels_list, 54 | norm=self.config['norm'], 55 | last_lin=True) 56 | 57 | def send_func(self, src_feat, dst_feat, edge_feat): 58 | # h = paddle.concat([dst_feat['h'], src_feat['h'], edge_feat['e']], axis=1) 59 | # h = self.fc_concat(h) 60 | # 源节点到边的转移概率,边到目标节点的转移概率 61 | src_feat = src_feat['h'] 62 | dst_feat = dst_feat['h'] 63 | edge_feat = edge_feat['e'] 64 | h = self.f1(paddle.concat([src_feat, edge_feat],axis=1))+self.f2(paddle.concat([edge_feat, dst_feat],axis=1))+self.f3(paddle.concat([src_feat, dst_feat],axis=1)) 65 | msg = {"h": F.swish(h) + self.eps} 66 | return msg 67 | 68 | def recv_func(self, msg): 69 | alpha = msg.reduce_softmax(msg["h"]) 70 | out = msg['h'] * alpha 71 | out = msg.reduce_sum(out) 72 | return out 73 | 74 | 75 | def forward(self, graph, nfeat, efeat=None): 76 | msg = graph.send(src_feat={"h": nfeat}, 77 | dst_feat={"h": nfeat}, 78 | edge_feat={"e": efeat}, 79 | message_func=self.send_func) 80 | out = graph.recv(msg=msg, reduce_func=self.recv_func) 81 | 
out = nfeat + out 82 | out = self.mlp(out) 83 | return out 84 | 85 | 86 | 87 | 88 | def Linear(input_size, hidden_size, with_bias=True): 89 | fan_in = input_size 90 | bias_bound = 1.0 / math.sqrt(fan_in) 91 | fc_bias_attr = paddle.ParamAttr(initializer=nn.initializer.Uniform( 92 | low=-bias_bound, high=bias_bound)) 93 | 94 | negative_slope = math.sqrt(5) 95 | gain = math.sqrt(2.0 / (1 + negative_slope**2)) 96 | std = gain / math.sqrt(fan_in) 97 | weight_bound = math.sqrt(3.0) * std 98 | fc_w_attr = paddle.ParamAttr(initializer=nn.initializer.Uniform( 99 | low=-weight_bound, high=weight_bound)) 100 | 101 | if not with_bias: 102 | fc_bias_attr = False 103 | 104 | return nn.Linear( 105 | input_size, hidden_size, weight_attr=fc_w_attr, bias_attr=fc_bias_attr) 106 | 107 | def norm_layer(norm_type, nc): 108 | # normalization layer 1d 109 | norm = norm_type.lower() 110 | if norm == 'batch': 111 | layer = batch_norm_1d(nc) 112 | elif norm == 'layer': 113 | layer = nn.LayerNorm(nc) 114 | else: 115 | raise NotImplementedError('normalization layer [%s] is not found' % norm) 116 | return layer 117 | 118 | def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1): 119 | # activation layer 120 | act = act_type.lower() 121 | if act == 'relu': 122 | layer = nn.ReLU() 123 | elif act == 'leakyrelu': 124 | layer = nn.LeakyReLU(neg_slope, inplace) 125 | elif act == 'prelu': 126 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 127 | elif act == 'swish': 128 | layer = nn.Swish() 129 | else: 130 | raise NotImplementedError('activation layer [%s] is not found' % act) 131 | return layer 132 | 133 | class MLP(paddle.nn.Sequential): 134 | def __init__(self, channels, act='swish', norm=None, bias=True, drop=0., last_lin=False): 135 | m = [] 136 | 137 | for i in range(1, len(channels)): 138 | m.append(Linear(channels[i - 1], channels[i], bias)) 139 | if norm is not None and norm.lower() != 'none': 140 | m.append(norm_layer(norm, channels[i])) 141 | if act is not None and act.lower() != 'none': 142 | m.append(act_layer(act)) 143 | if drop > 0: 144 | m.append(nn.Dropout(drop)) 145 | 146 | self.m = m 147 | super(MLP, self).__init__(*self.m) -------------------------------------------------------------------------------- /ernie/static2dynamic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import paddle 4 | import numpy as np 5 | def match_embedding_param(convert_parameter_name_dict): 6 | convert_parameter_name_dict[ 7 | "word_emb.weight"] = "word_embedding" 8 | convert_parameter_name_dict[ 9 | "pos_emb.weight"] = "pos_embedding" 10 | convert_parameter_name_dict[ 11 | "sent_emb.weight"] = "sent_embedding" 12 | convert_parameter_name_dict[ 13 | "ln.weight"] = "pre_encoder_layer_norm_scale" 14 | convert_parameter_name_dict[ 15 | "ln.bias"] = "pre_encoder_layer_norm_bias" 16 | return convert_parameter_name_dict 17 | 18 | 19 | def match_encoder_param(convert_parameter_name_dict, layer_num=4): 20 | dygraph_proj_names = ["q", "k", "v", "o"] 21 | static_proj_names = ["query", "key", "value", "output"] 22 | dygraph_param_names = ["weight", "bias"] 23 | static_param_names = ["w", "b"] 24 | dygraph_layer_norm_param_names = ["weight", "bias"] 25 | static_layer_norm_param_names = ["scale", "bias"] 26 | 27 | # Firstly, converts the multihead_attention to the parameter. 
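# For example, for layer i=0 and the query projection weight, the mapping built below is
# "encoder_stack.block.0.attn.q.weight" -> "encoder_layer_0_multi_head_att_query_fc.w_0";
# the same pattern repeats for the key/value/output projections and their biases.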
28 | # dygraph_format_name = "encoder.layers.{}.self_attn.{}_proj.{}" 29 | # encoder_stack.block.10.attn.o.bias 30 | dygraph_format_name = "encoder_stack.block.{}.attn.{}.{}" 31 | static_format_name = "encoder_layer_{}_multi_head_att_{}_fc.{}_0" 32 | for i in range(0, layer_num): 33 | for dygraph_proj_name, static_proj_name in zip(dygraph_proj_names, 34 | static_proj_names): 35 | for dygraph_param_name, static_param_name in zip( 36 | dygraph_param_names, static_param_names): 37 | convert_parameter_name_dict[dygraph_format_name.format(i, dygraph_proj_name, dygraph_param_name)] = \ 38 | static_format_name.format(i, static_proj_name, static_param_name) 39 | 40 | # Secondly, converts the encoder ffn parameter. 41 | # dygraph_ffn_linear_format_name = "encoder.layers.{}.linear{}.{}" 42 | #encoder_stack.block.0.ffn.i.weight 43 | dygraph_ffn_linear_format_name = "encoder_stack.block.{}.ffn.{}.{}" 44 | static_ffn_linear_format_name = "encoder_layer_{}_ffn_fc_{}.{}_0" 45 | for i in range(0, layer_num): 46 | for cnt,j in enumerate(['i','o']): 47 | for dygraph_param_name, static_param_name in zip( 48 | dygraph_param_names, static_param_names): 49 | convert_parameter_name_dict[dygraph_ffn_linear_format_name.format(i, j, dygraph_param_name)] = \ 50 | static_ffn_linear_format_name.format(i, cnt, static_param_name) 51 | 52 | # Thirdly, converts the multi_head layer_norm parameter. 53 | # dygraph_encoder_attention_layer_norm_format_name = "encoder.layers.{}.norm1.{}" 54 | dygraph_encoder_attention_layer_norm_format_name = "encoder_stack.block.{}.ln1.{}" 55 | static_encoder_attention_layer_norm_format_name = "encoder_layer_{}_post_att_layer_norm_{}" 56 | for i in range(0, layer_num): 57 | for dygraph_param_name, static_pararm_name in zip( 58 | dygraph_layer_norm_param_names, static_layer_norm_param_names): 59 | convert_parameter_name_dict[dygraph_encoder_attention_layer_norm_format_name.format(i, dygraph_param_name)] = \ 60 | static_encoder_attention_layer_norm_format_name.format(i, static_pararm_name) 61 | 62 | # dygraph_encoder_ffn_layer_norm_format_name = "encoder.layers.{}.norm2.{}" 63 | dygraph_encoder_ffn_layer_norm_format_name = "encoder_stack.block.{}.ln2.{}" 64 | static_encoder_ffn_layer_norm_format_name = "encoder_layer_{}_post_ffn_layer_norm_{}" 65 | for i in range(0, layer_num): 66 | for dygraph_param_name, static_pararm_name in zip( 67 | dygraph_layer_norm_param_names, static_layer_norm_param_names): 68 | convert_parameter_name_dict[dygraph_encoder_ffn_layer_norm_format_name.format(i, dygraph_param_name)] = \ 69 | static_encoder_ffn_layer_norm_format_name.format(i, static_pararm_name) 70 | return convert_parameter_name_dict 71 | 72 | def match_pooler_parameter(convert_parameter_name_dict): 73 | convert_parameter_name_dict["pooler.dense.weight"] = "pooled_fc.w_0" 74 | convert_parameter_name_dict["pooler.dense.bias"] = "pooled_fc.b_0" 75 | return convert_parameter_name_dict 76 | 77 | 78 | def match_mlm_parameter(convert_parameter_name_dict): 79 | # convert_parameter_name_dict["cls.predictions.decoder_weight"] = "word_embedding" 80 | convert_parameter_name_dict[ 81 | "cls.predictions.decoder_bias"] = "mask_lm_out_fc.b_0" 82 | convert_parameter_name_dict[ 83 | "cls.predictions.transform.weight"] = "mask_lm_trans_fc.w_0" 84 | convert_parameter_name_dict[ 85 | "cls.predictions.transform.bias"] = "mask_lm_trans_fc.b_0" 86 | convert_parameter_name_dict[ 87 | "cls.predictions.layer_norm.weight"] = "mask_lm_trans_layer_norm_scale" 88 | convert_parameter_name_dict[ 89 | "cls.predictions.layer_norm.bias"] 
= "mask_lm_trans_layer_norm_bias" 90 | return convert_parameter_name_dict 91 | 92 | 93 | def convert_static_to_dygraph_params(dygraph_params_save_path, 94 | static_params_dir, 95 | static_to_dygraph_param_name, 96 | model_name='static'): 97 | files = os.listdir(static_params_dir) 98 | 99 | state_dict = {} 100 | model_name = model_name 101 | for name in files: 102 | path = os.path.join(static_params_dir, name) 103 | # static_para_name = name.replace('@HUB_chinese-roberta-wwm-ext-large@', 104 | # '') # for hub module params 105 | static_para_name = name.replace('.npy', '') 106 | if static_para_name not in static_to_dygraph_param_name: 107 | print(static_para_name, "not in static_to_dygraph_param_name") 108 | continue 109 | dygraph_para_name = static_to_dygraph_param_name[static_para_name] 110 | value = np.load(path) 111 | if "cls" in dygraph_para_name: 112 | # Note: cls.predictions parameters do not need add `model_name.` prefix 113 | state_dict[dygraph_para_name] = value 114 | else: 115 | state_dict[model_name + '.' + dygraph_para_name] = value 116 | 117 | with open(dygraph_params_save_path, 'wb') as f: 118 | pickle.dump(state_dict, f) 119 | params = paddle.load(dygraph_params_save_path) 120 | 121 | for name in state_dict.keys(): 122 | if name in params: 123 | assert ((state_dict[name] == params[name].numpy()).all() == True) 124 | else: 125 | print(name, 'not in params') 126 | 127 | 128 | if __name__=="__main__": 129 | convert_parameter_name_dict = {} 130 | 131 | convert_parameter_name_dict = match_embedding_param( 132 | convert_parameter_name_dict) 133 | convert_parameter_name_dict = match_encoder_param( 134 | convert_parameter_name_dict, layer_num=24) 135 | convert_parameter_name_dict = match_pooler_parameter( 136 | convert_parameter_name_dict) 137 | convert_parameter_name_dict = match_mlm_parameter( 138 | convert_parameter_name_dict) 139 | 140 | static_to_dygraph_param_name = { 141 | value: key 142 | for key, value in convert_parameter_name_dict.items() 143 | } 144 | import paddle 145 | state_dict = paddle.load("../NAACL2021-RocketQA/checkpoint/marco_cross_encoder_large/") 146 | params = {} 147 | miss = [] 148 | for key in state_dict: 149 | try: 150 | params[static_to_dygraph_param_name[key]] = state_dict[key] 151 | except: 152 | miss.append(key) 153 | -------------------------------------------------------------------------------- /ernie/count.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import unicode_literals 4 | 5 | import os 6 | os.environ['FLAGS_eager_delete_tensor_gb'] = '0' 7 | import sys 8 | sys.path.append("..") 9 | from collections import Counter 10 | import io 11 | import os 12 | import six 13 | import argparse 14 | import logging 15 | import paddle 16 | paddle.device.set_device("gpu") 17 | import paddle.fluid as F 18 | import utils 19 | 20 | import tokenization 21 | import dataset_factory 22 | from model import ErnieWithGNN, ErnieRanker,ErnieWithGNNv2 23 | from ernie_concept import ErnieWithConcept,pretrainedErnieWithConcept 24 | import paddle.distributed as dist 25 | from multiprocessing import cpu_count 26 | import numpy as np 27 | import pickle as pkl 28 | from msmarco_eval import get_mrr 29 | from paddle.io import DistributedBatchSampler,BatchSampler,get_worker_info 30 | from msmarco_eval import Mrr 31 | from paddle.hapi.model import _all_gather 32 | from paddle.fluid.dygraph.parallel import ParallelEnv 33 | from tqdm import 
tqdm 34 | import time 35 | # if six.PY3: 36 | # sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') 37 | 38 | import json 39 | import spacy 40 | import glob 41 | import random 42 | import numpy as np 43 | import tarfile 44 | from paths import BatchedGraphBuilder,ConceptGraphBuilder 45 | import paddle 46 | from paddle.io import Dataset, DistributedBatchSampler, get_worker_info,IterableDataset 47 | import paddle.distributed as dist 48 | import pandas as pd 49 | import tokenization 50 | import pickle as pkl 51 | import glob 52 | import pgl 53 | import re 54 | import multiprocessing 55 | 56 | def define_args(): 57 | parser = argparse.ArgumentParser('kg-ERNIE-rerank model') 58 | parser.add_argument('--run', type=str, default="nce") 59 | parser.add_argument('--model', type=str, default="ErnieWithGNN") 60 | parser.add_argument('--num_gnn_layers', type=int, default=3) 61 | parser.add_argument('--ernie_config_file', type=str, default="/home/dongqian06/codes/kgreranker/base/ernie_config.json") 62 | parser.add_argument('--vocab_file', type=str, default="/home/dongqian06/codes/kgreranker/base/vocab.txt") 63 | parser.add_argument('--train_input_file', type=str,default="/home/dongqian06/hdfs_data/data_train/train.concept.tar.gz") 64 | parser.add_argument('--dev_input_file', type=str, default="/home/dongqian06/data.tar.gz") 65 | parser.add_argument('--pretrain_input_file', type=str, default="/home/dongqian06/hdfs_data/data_train/pretrain/*") 66 | parser.add_argument('--pretrain_batch_size', type=int, default=64) 67 | parser.add_argument('--batch_size', type=int, default=8) 68 | parser.add_argument('--dev_batch_size', type=int, default=64) 69 | parser.add_argument('--max_seq_len', type=int, default=160) 70 | parser.add_argument('--warm_start_from', type=str, default="/home/dongqian06/hdfs_data/data_train/ernie_base.p") 71 | parser.add_argument('--learning_rate', type=float, default=1e-4) 72 | parser.add_argument('--weight_decay', type=float, default=0.01) 73 | parser.add_argument('--warmup_proportion', type=float, default=0.1) 74 | parser.add_argument('--eval_step_proportion', type=float, default=1.0) 75 | parser.add_argument('--ernie_lr', type=float, default=1e-5) 76 | parser.add_argument('--report', type=int, default=10) 77 | parser.add_argument('--epoch', type=int, default=50) 78 | parser.add_argument('--qrels', type=str, default="/home/dongqian06/hdfs_data/data_train/qrels.train.tsv") 79 | parser.add_argument('--top1000', type=str, default="/home/dongqian06/hdfs_data/data_train/train.concept.gz") 80 | parser.add_argument('--collection', type=str, default="/home/dongqian06/hdfs_data/data_train/collection.tsv") 81 | parser.add_argument('--query', type=str, default="/home/dongqian06/hdfs_data/data_train/train.query.txt") 82 | parser.add_argument('--min_index', type=int, default=25) 83 | parser.add_argument('--max_index', type=int, default=768) 84 | parser.add_argument('--sample_num', type=int, default=10) 85 | parser.add_argument('--num_labels', type=int, default=1) 86 | parser.add_argument('--local_rank', type=int, default=1) 87 | parser.add_argument('--fp16', type=bool, default=False) 88 | 89 | 90 | # gnn config 91 | parser.add_argument('--with_efeat', type=bool, default=True) 92 | parser.add_argument('--virtual_node', type=bool, default=False) 93 | parser.add_argument('--num_conv_layers', type=int, default=3) 94 | parser.add_argument('--drop_ratio', type=float, default=0.) 
95 | parser.add_argument('--norm', type=str, default='layer') 96 | parser.add_argument('--aggr', type=str, default='softmax') 97 | parser.add_argument('--learn_t', type=bool, default=False) 98 | parser.add_argument('--learn_p', type=bool, default=False) 99 | parser.add_argument('--init_t', type=float, default=1.0) 100 | parser.add_argument('--init_p', type=float, default=1.0) 101 | parser.add_argument('--concat', type=bool, default=True) 102 | parser.add_argument('--mlp_layers', type=int, default=1) 103 | parser.add_argument('--edge_num', type=int, default=25) 104 | 105 | parser.add_argument('--resource', type=str, default="/home/dongqian06/hdfs_data/concept_net/concept.txt") 106 | parser.add_argument('--cpnet', type=str, default="/home/dongqian06/hdfs_data/concept_net/conceptnet.en.pruned.graph") 107 | parser.add_argument('--pattern_path', type=str, default="/home/dongqian06/hdfs_data/concept_net/matcher_patterns.json") 108 | parser.add_argument('--word2vec', type=str, default="/home/dongqian06/hdfs_data/data_train/GoogleNews-vectors-negative300.bin.gz") 109 | parser.add_argument('--topk_sents', type=int, default=1) 110 | parser.add_argument('--ent_emb', type=str, default="/home/dongqian06/hdfs_data/concept_net/glove.transe.sgd.ent.npy") 111 | parser.add_argument('--rel_emb', type=str, default="/home/dongqian06/hdfs_data/concept_net/glove.transe.sgd.rel.npy") 112 | parser.add_argument('--books', type=str, default="/home/dongqian06/hdfs_data/data_train/books.txt") 113 | parser.add_argument('--cnts', type=str, default="/home/dongqian06/hdfs_data/data_train/cnts.pkl") 114 | parser.add_argument('--gnn_hidden_size', type=int, default=100) 115 | parser.add_argument('--instance_num', type=int, default=109)# 502939 116 | # args = parser.parse_args(args=[]) 117 | args = parser.parse_args() 118 | return args 119 | 120 | 121 | 122 | def func(cfg): 123 | # cpu_worker_num = multiprocessing.cpu_count() 124 | cpu_worker_num = 26 125 | print(cpu_worker_num) 126 | local_rank = cfg['local_rank'] 127 | books = open(cfg['books'],'r') 128 | books = books.readlines() 129 | nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat']) 130 | nlp.add_pipe('sentencizer') 131 | def lemmatize(concept): 132 | doc = nlp(concept.replace("_", " ")) 133 | lcs = set() 134 | lcs.add("_".join([token.lemma_ for token in doc])) # all lemma 135 | return lcs 136 | def count(sent): 137 | sent = sent.lower() 138 | sent = sent.replace("-", "_") 139 | spans = [] 140 | tokens = sent.split(" ") 141 | token_num = len(tokens) 142 | for length in range(1, 5): 143 | for i in range(token_num-length+1): 144 | span = "_".join(tokens[i:i+length]) 145 | span = list(lemmatize(span))[0] 146 | if span in cnts: 147 | cnts[span]+=1 148 | # print(span) 149 | cnts = pkl.load(open(cfg['cnts'],"rb")) 150 | length = len(books) 151 | # length = 100 152 | process_args = [(i*length//cpu_worker_num, (i+1)*length//cpu_worker_num) for i in range(cpu_worker_num)] 153 | a,b = process_args[local_rank] 154 | local_start = time.time() 155 | for i in tqdm(range(a,b)): 156 | para = books[i] 157 | count(para) 158 | os.makedirs('output',exist_ok=True) 159 | pkl.dump(cnts, open("output/cnts_%d.pkl"%local_rank, "wb")) 160 | 161 | if __name__=="__main__": 162 | args = define_args() 163 | all_configs = utils.parse_file(args.ernie_config_file) 164 | all_configs.update(vars(args)) 165 | all_configs = utils.HParams(**all_configs) 166 | func(all_configs) 167 | -------------------------------------------------------------------------------- /ernie/generate_data.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append("..") 4 | from collections import Counter 5 | import io 6 | import os 7 | import six 8 | import argparse 9 | import logging 10 | import paddle 11 | 12 | paddle.device.set_device("gpu") 13 | import paddle.fluid as F 14 | import utils 15 | 16 | import tokenization 17 | import dataset_factory 18 | from model import ErnieWithGNN, ErnieRanker, ErnieWithGNNv2 19 | import paddle.distributed as dist 20 | from multiprocessing import cpu_count 21 | import numpy as np 22 | import pickle as pkl 23 | from msmarco_eval import get_mrr 24 | from paddle.io import DistributedBatchSampler, BatchSampler 25 | from msmarco_eval import Mrr 26 | from paddle.hapi.model import _all_gather 27 | from paddle.fluid.dygraph.parallel import ParallelEnv 28 | from tqdm import tqdm 29 | import time 30 | 31 | 32 | # if six.PY3: 33 | # sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') 34 | 35 | def define_args(): 36 | parser = argparse.ArgumentParser('kg-ERNIE-rerank model') 37 | parser.add_argument('--run', type=str, default="nce") 38 | parser.add_argument('--model', type=str, default="ErnieWithGNN") 39 | parser.add_argument('--num_gnn_layers', type=int, default=3) 40 | parser.add_argument('--ernie_config_file', type=str, default="/home/user/codes/kgreranker/base/ernie_config.json") 41 | parser.add_argument('--vocab_file', type=str, default="/home/user/codes/kgreranker/base/vocab.txt") 42 | parser.add_argument('--train_input_file', type=str, default="/home/user/hdfs_data/data_train/train.top2000.gz") 43 | parser.add_argument('--dev_input_file', type=str, default="/home/user/hdfs_data/data_train/dev.top2000.gz") 44 | parser.add_argument('--batch_size', type=int, default=2) 45 | parser.add_argument('--dev_batch_size', type=int, default=64) 46 | parser.add_argument('--max_seq_len', type=int, default=160) 47 | parser.add_argument('--warm_start_from', type=str, default="/home/user/hdfs_data/data_train/ernie_base.p") 48 | parser.add_argument('--learning_rate', type=float, default=1e-5) 49 | parser.add_argument('--weight_decay', type=float, default=0.01) 50 | parser.add_argument('--warmup_proportion', type=float, default=0.1) 51 | parser.add_argument('--eval_step_proportion', type=float, default=0.33) 52 | parser.add_argument('--report', type=int, default=10) 53 | parser.add_argument('--epoch', type=int, default=50) 54 | parser.add_argument('--qrels', type=str, default="/home/user/hdfs_data/data_train/qrels.train.tsv") 55 | parser.add_argument('--top1000', type=str, default="/home/user/hdfs_data/data_train/train.qidpid.gz") 56 | parser.add_argument('--collection', type=str, default="/home/user/hdfs_data/data_train/collection.tsv") 57 | parser.add_argument('--query', type=str, default="/home/user/hdfs_data/data_train/train.query.txt") 58 | parser.add_argument('--min_index', type=int, default=25) 59 | parser.add_argument('--max_index', type=int, default=768) 60 | parser.add_argument('--sample_num', type=int, default=10) 61 | parser.add_argument('--num_labels', type=int, default=1) 62 | parser.add_argument('--partial_no', type=int, default=1) 63 | parser.add_argument('--generate_type', type=str, default="pretrain") 64 | 65 | # gnn config 66 | parser.add_argument('--with_efeat', type=bool, default=True) 67 | parser.add_argument('--virtual_node', type=bool, default=False) 68 | parser.add_argument('--num_conv_layers', type=int, default=3) 69 | parser.add_argument('--drop_ratio', type=float, default=0.) 
70 | parser.add_argument('--norm', type=str, default='layer') 71 | parser.add_argument('--aggr', type=str, default='softmax') 72 | parser.add_argument('--learn_t', type=bool, default=False) 73 | parser.add_argument('--learn_p', type=bool, default=False) 74 | parser.add_argument('--init_t', type=float, default=1.0) 75 | parser.add_argument('--init_p', type=float, default=1.0) 76 | parser.add_argument('--concat', type=bool, default=True) 77 | parser.add_argument('--mlp_layers', type=int, default=1) 78 | parser.add_argument('--edge_num', type=int, default=25) 79 | 80 | parser.add_argument('--resource', type=str, default="/home/user/hdfs_data/concept_net/concept.txt") 81 | parser.add_argument('--cpnet', type=str, default="/home/user/hdfs_data/concept_net/conceptnet.en.pruned.graph") 82 | parser.add_argument('--pattern_path', type=str, default="/home/user/hdfs_data/concept_net/matcher_patterns.json") 83 | parser.add_argument('--word2vec', type=str, 84 | default="/home/user/hdfs_data/data_train/GoogleNews-vectors-negative300.bin.gz") 85 | parser.add_argument('--topk_sents', type=int, default=1) 86 | parser.add_argument('--ent_emb', type=str, default="/home/user/hdfs_data/concept_net/glove.transe.sgd.ent.npy") 87 | parser.add_argument('--rel_emb', type=str, default="/home/user/hdfs_data/concept_net/glove.transe.sgd.rel.npy") 88 | parser.add_argument('--gnn_hidden_size', type=int, default=100) 89 | # args = parser.parse_args(args=[]) 90 | args = parser.parse_args() 91 | return args 92 | 93 | 94 | def generate_eval(): 95 | args = define_args() 96 | all_configs = utils.parse_file(args.ernie_config_file) 97 | all_configs.update(vars(args)) 98 | all_configs = utils.HParams(**all_configs) 99 | cfg = all_configs 100 | all_configs.print_config() 101 | local_rank = paddle.distributed.get_rank() 102 | _nranks = ParallelEnv().nranks 103 | dataset = dataset_factory.GenEvalErnieConceptDataset(cfg) 104 | sampler = DistributedBatchSampler(dataset, batch_size=cfg['dev_batch_size'], shuffle=False) 105 | loader = paddle.io.DataLoader(dataset, batch_sampler=sampler, collate_fn=dataset._collate_fn_gen, num_workers=16) 106 | dist.init_parallel_env() 107 | os.makedirs('gen_data/', exist_ok=True) 108 | all_steps = len(loader) 109 | step = 0 110 | start = time.time() 111 | local_start = time.time() 112 | for batch in loader: 113 | step += 1 114 | batch = [b.numpy() if i < 5 else b for i, b in enumerate(batch)] 115 | pkl.dump(batch, open("gen_data/dev_sample_%d_%d.pkl" % (local_rank, step), "wb")) 116 | if step % args.report == 0 and local_rank == 0: 117 | seconds = time.time() - local_start 118 | m, s = divmod(seconds, 60) 119 | h, m = divmod(m, 60) 120 | local_start = time.time() 121 | print("step: %d/%d, " % (step, all_steps), "report used time:%02d:%02d:%02d," % (h, m, s), end=' ') 122 | seconds = time.time() - start 123 | m, s = divmod(seconds, 60) 124 | h, m = divmod(m, 60) 125 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ') 126 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime())) 127 | 128 | 129 | def generate_train(): 130 | args = define_args() 131 | all_configs = utils.parse_file(args.ernie_config_file) 132 | all_configs.update(vars(args)) 133 | all_configs = utils.HParams(**all_configs) 134 | cfg = all_configs 135 | all_configs.print_config() 136 | local_rank = paddle.distributed.get_rank() 137 | _nranks = ParallelEnv().nranks 138 | # dataset=dataset_factory.GenTrainV2ErnieConceptDataset(cfg) 139 | dataset = dataset_factory.GenTrainErnieConceptDataset(cfg) 140 | sampler = 
DistributedBatchSampler(dataset, batch_size=1, shuffle=False) 141 | loader = paddle.io.DataLoader(dataset, batch_sampler=sampler, collate_fn=dataset._collate_fn_gen, num_workers=10) 142 | dist.init_parallel_env() 143 | os.makedirs('gen_data/', exist_ok=True) 144 | all_steps = len(loader) 145 | step = 0 146 | start = time.time() 147 | local_start = time.time() 148 | for batch in loader: 149 | step += 1 150 | batch = [b.numpy() if i < 4 else b for i, b in enumerate(batch)] 151 | pkl.dump(batch, open("gen_data/train_sample_%d_%d.pkl" % (local_rank, step), "wb")) 152 | if step % args.report == 0 and local_rank == 0: 153 | seconds = time.time() - local_start 154 | m, s = divmod(seconds, 60) 155 | h, m = divmod(m, 60) 156 | local_start = time.time() 157 | print("step: %d/%d, " % (step, all_steps), "report used time:%02d:%02d:%02d," % (h, m, s), end=' ') 158 | seconds = time.time() - start 159 | m, s = divmod(seconds, 60) 160 | h, m = divmod(m, 60) 161 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ') 162 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime())) 163 | 164 | 165 | def generate_pretrain(): 166 | args = define_args() 167 | all_configs = utils.parse_file(args.ernie_config_file) 168 | all_configs.update(vars(args)) 169 | all_configs = utils.HParams(**all_configs) 170 | cfg = all_configs 171 | all_configs.print_config() 172 | local_rank = paddle.distributed.get_rank() 173 | _nranks = ParallelEnv().nranks 174 | partial_passage = 8841823 // 8 + 1 175 | partial_no = all_configs['partial_no'] 176 | dataset = dataset_factory.GenPretrainedConceptDataset(cfg, start=partial_no * partial_passage, 177 | end=partial_no * partial_passage + partial_passage) 178 | sampler = DistributedBatchSampler(dataset, batch_size=1, shuffle=False) 179 | loader = paddle.io.DataLoader(dataset, batch_sampler=sampler, collate_fn=dataset._collate_fn_gen, 180 | num_workers=cpu_count() // _nranks) 181 | dist.init_parallel_env() 182 | os.makedirs('gen_data/', exist_ok=True) 183 | all_steps = len(loader) 184 | save_id = 0 185 | start = time.time() 186 | local_start = time.time() 187 | for step, batch in enumerate(loader): 188 | batch = batch[0] 189 | tmp = list(zip(*batch)) 190 | for sample in tmp: 191 | save_id += 1 192 | sample = [s.numpy() if 0 < i < 3 else s for i, s in enumerate(sample)] 193 | pkl.dump(sample, open("gen_data/pretrain_sample_%d_%d.pkl" % (local_rank, save_id), "wb")) 194 | if step % args.report == 0 and local_rank == 0: 195 | seconds = time.time() - local_start 196 | m, s = divmod(seconds, 60) 197 | h, m = divmod(m, 60) 198 | local_start = time.time() 199 | print("step: %d/%d, " % (step, all_steps), "report used time:%02d:%02d:%02d," % (h, m, s), end=' ') 200 | seconds = time.time() - start 201 | m, s = divmod(seconds, 60) 202 | h, m = divmod(m, 60) 203 | print("total used time:%02d:%02d:%02d" % (h, m, s), end=' ') 204 | print(time.strftime("[TIME %Y-%m-%d %H:%M:%S]", time.localtime())) 205 | 206 | 207 | if __name__ == "__main__": 208 | args = define_args() 209 | all_configs = utils.parse_file(args.ernie_config_file) 210 | all_configs.update(vars(args)) 211 | all_configs = utils.HParams(**all_configs) 212 | cfg = all_configs 213 | if cfg['generate_type'] == 'pretrain': 214 | generate_pretrain() 215 | if cfg['generate_type'] == 'eval': 216 | generate_eval() 217 | if cfg['generate_type'] == 'train': 218 | generate_train() 219 | -------------------------------------------------------------------------------- /ernie/tokenizing_ernie.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import division 16 | from __future__ import absolute_import 17 | from __future__ import print_function 18 | from __future__ import unicode_literals 19 | 20 | import sys 21 | import os 22 | import six 23 | import re 24 | import logging 25 | import tempfile 26 | from pathlib import Path 27 | from functools import partial 28 | if six.PY2: 29 | from pathlib2 import Path 30 | else: 31 | from pathlib import Path 32 | 33 | from tqdm import tqdm 34 | import numpy as np 35 | 36 | from ernie.file_utils import _fetch_from_remote 37 | import io 38 | 39 | open = partial(io.open, encoding='utf8') 40 | 41 | log = logging.getLogger(__name__) 42 | 43 | _max_input_chars_per_word = 100 44 | 45 | 46 | def _wordpiece(token, vocab, unk_token, prefix='##', sentencepiece_prefix=''): 47 | """ wordpiece: helloworld => [hello, ##world] """ 48 | chars = list(token) 49 | if len(chars) > _max_input_chars_per_word: 50 | return [unk_token], [(0, len(chars))] 51 | 52 | is_bad = False 53 | start = 0 54 | sub_tokens = [] 55 | sub_pos = [] 56 | while start < len(chars): 57 | end = len(chars) 58 | cur_substr = None 59 | while start < end: 60 | substr = "".join(chars[start:end]) 61 | if start == 0: 62 | substr = sentencepiece_prefix + substr 63 | if start > 0: 64 | substr = prefix + substr 65 | if substr in vocab: 66 | cur_substr = substr 67 | break 68 | end -= 1 69 | if cur_substr is None: 70 | is_bad = True 71 | break 72 | sub_tokens.append(cur_substr) 73 | sub_pos.append((start, end)) 74 | start = end 75 | if is_bad: 76 | return [unk_token], [(0, len(chars))] 77 | else: 78 | return sub_tokens, sub_pos 79 | 80 | 81 | class ErnieTokenizer(object): 82 | bce = 'https://ernie-github.cdn.bcebos.com/' 83 | resource_map = { 84 | 'ernie-1.0': bce + 'model-ernie1.0.1.tar.gz', 85 | 'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz', 86 | 'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz', 87 | 'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz', 88 | 'ernie-gen-base-en': bce + 'model-ernie-gen-base-en.1.tar.gz', 89 | 'ernie-gen-large-en': bce + 'model-ernie-gen-large-en.1.tar.gz', 90 | 'ernie-gram-zh': bce + 'model-ernie-gram-zh.1.tar.gz', 91 | 'ernie-gram-en': bce + 'model-ernie-gram-en.1.tar.gz', 92 | } 93 | 94 | @classmethod 95 | def from_pretrained(cls, 96 | pretrain_dir_or_url, 97 | force_download=False, 98 | **kwargs): 99 | if not Path(pretrain_dir_or_url).exists() and str( 100 | pretrain_dir_or_url) in cls.resource_map: 101 | url = cls.resource_map[str(pretrain_dir_or_url)] 102 | log.info('get pretrain dir from %s' % url) 103 | pretrain_dir = _fetch_from_remote( 104 | url, force_download=force_download) 105 | else: 106 | log.info('pretrain dir %s not in %s, read from local' % 107 | (pretrain_dir_or_url, repr(cls.resource_map))) 108 | pretrain_dir = 
Path(pretrain_dir_or_url) 109 | if not pretrain_dir.exists(): 110 | raise ValueError('pretrain dir not found: %s, optional: %s' % (pretrain_dir, cls.resource_map.keys())) 111 | vocab_path = pretrain_dir / 'vocab.txt' 112 | if not vocab_path.exists(): 113 | raise ValueError('no vocab file in pretrain dir: %s' % 114 | pretrain_dir) 115 | vocab_dict = { 116 | j.strip().split('\t')[0]: i 117 | for i, j in enumerate( 118 | vocab_path.open(encoding='utf8').readlines()) 119 | } 120 | t = cls(vocab_dict, **kwargs) 121 | return t 122 | 123 | def __init__(self, 124 | vocab, 125 | unk_token='[UNK]', 126 | sep_token='[SEP]', 127 | cls_token='[CLS]', 128 | pad_token='[PAD]', 129 | mask_token='[MASK]', 130 | wordpiece_prefix='##', 131 | sentencepiece_prefix='', 132 | lower=True, 133 | encoding='utf8', 134 | special_token_list=[]): 135 | if not isinstance(vocab, dict): 136 | raise ValueError('expect `vocab` to be instance of dict, got %s' % 137 | type(vocab)) 138 | self.vocab = vocab 139 | self.lower = lower 140 | self.prefix = wordpiece_prefix 141 | self.sentencepiece_prefix = sentencepiece_prefix 142 | self.pad_id = self.vocab[pad_token] 143 | self.cls_id = cls_token and self.vocab[cls_token] 144 | self.sep_id = sep_token and self.vocab[sep_token] 145 | self.unk_id = unk_token and self.vocab[unk_token] 146 | self.mask_id = mask_token and self.vocab[mask_token] 147 | self.unk_token = unk_token 148 | special_tokens = { 149 | pad_token, cls_token, sep_token, unk_token, mask_token 150 | } | set(special_token_list) 151 | pat_str = '' 152 | for t in special_tokens: 153 | if t is None: 154 | continue 155 | pat_str += '(%s)|' % re.escape(t) 156 | pat_str += r'([a-zA-Z0-9]+|\S)' 157 | log.debug('regex: %s' % pat_str) 158 | self.pat = re.compile(pat_str) 159 | self.encoding = encoding 160 | 161 | def tokenize(self, text): 162 | if len(text) == 0: 163 | return [] 164 | if six.PY3 and not isinstance(text, six.string_types): 165 | text = text.decode(self.encoding) 166 | if six.PY2 and isinstance(text, str): 167 | text = text.decode(self.encoding) 168 | 169 | res = [] 170 | for match in self.pat.finditer(text): 171 | match_group = match.group(0) 172 | if match.groups()[-1]: 173 | if self.lower: 174 | match_group = match_group.lower() 175 | words, _ = _wordpiece( 176 | match_group, 177 | vocab=self.vocab, 178 | unk_token=self.unk_token, 179 | prefix=self.prefix, 180 | sentencepiece_prefix=self.sentencepiece_prefix) 181 | else: 182 | words = [match_group] 183 | res += words 184 | return res 185 | 186 | def convert_tokens_to_ids(self, tokens): 187 | return [self.vocab.get(t, self.unk_id) for t in tokens] 188 | 189 | def truncate(self, id1, id2, seqlen): 190 | len1 = len(id1) 191 | len2 = len(id2) 192 | half = seqlen // 2 193 | if len1 > len2: 194 | len1_truncated, len2_truncated = max(half, seqlen - len2), min( 195 | half, len2) 196 | else: 197 | len1_truncated, len2_truncated = min(half, seqlen - len1), max( 198 | half, seqlen - len1) 199 | return id1[:len1_truncated], id2[:len2_truncated] 200 | 201 | def build_for_ernie(self, text_id, pair_id=[]): 202 | """build sentence type id, add [CLS] [SEP]""" 203 | text_id_type = np.zeros_like(text_id, dtype=np.int64) 204 | ret_id = np.concatenate([[self.cls_id], text_id, [self.sep_id]], 0) 205 | ret_id_type = np.concatenate([[0], text_id_type, [0]], 0) 206 | 207 | if len(pair_id): 208 | pair_id_type = np.ones_like(pair_id, dtype=np.int64) 209 | ret_id = np.concatenate([ret_id, pair_id, [self.sep_id]], 0) 210 | ret_id_type = np.concatenate([ret_id_type, pair_id_type, [1]], 0) 
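# At this point ret_id is laid out as [CLS] text [SEP] (+ pair [SEP] when a pair is given),
# and ret_id_type holds sentence-type id 0 for the first segment and 1 for the pair segment.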
211 | return ret_id, ret_id_type 212 | 213 | def encode(self, text, pair=None, truncate_to=None): 214 | text_id = np.array( 215 | self.convert_tokens_to_ids(self.tokenize(text)), dtype=np.int64) 216 | text_id_type = np.zeros_like(text_id, dtype=np.int64) 217 | if pair is not None: 218 | pair_id = np.array( 219 | self.convert_tokens_to_ids(self.tokenize(pair)), 220 | dtype=np.int64) 221 | else: 222 | pair_id = [] 223 | if truncate_to is not None: 224 | text_id, pair_id = self.truncate(text_id, [] if pair_id is None 225 | else pair_id, truncate_to) 226 | 227 | ret_id, ret_id_type = self.build_for_ernie(text_id, pair_id) 228 | return ret_id, ret_id_type 229 | 230 | 231 | class ErnieTinyTokenizer(ErnieTokenizer): 232 | bce = 'https://ernie-github.cdn.bcebos.com/' 233 | resource_map = {'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz'} 234 | 235 | @classmethod 236 | def from_pretrained(cls, 237 | pretrain_dir_or_url, 238 | force_download=False, 239 | **kwargs): 240 | if not Path(pretrain_dir_or_url).exists() and str( 241 | pretrain_dir_or_url) in cls.resource_map: 242 | url = cls.resource_map[str(pretrain_dir_or_url)] 243 | log.info('get pretrain dir from %s' % url) 244 | pretrain_dir = _fetch_from_remote(url, force_download) 245 | else: 246 | log.info('pretrain dir %s not in %s, read from local' % 247 | (pretrain_dir_or_url, repr(cls.resource_map))) 248 | pretrain_dir = Path(pretrain_dir_or_url) 249 | if not pretrain_dir.exists(): 250 | raise ValueError('pretrain dir not found: %s' % pretrain_dir) 251 | vocab_path = pretrain_dir / 'vocab.txt' 252 | sp_model_path = pretrain_dir / 'subword/spm_cased_simp_sampled.model' 253 | 254 | if not vocab_path.exists(): 255 | raise ValueError('no vocab file in pretrain dir: %s' % 256 | pretrain_dir) 257 | vocab_dict = { 258 | j.strip().split('\t')[0]: i 259 | for i, j in enumerate( 260 | vocab_path.open(encoding='utf8').readlines()) 261 | } 262 | 263 | t = cls(vocab_dict, sp_model_path, **kwargs) 264 | return t 265 | 266 | def __init__(self, vocab, sp_model_path, **kwargs): 267 | super(ErnieTinyTokenizer, self).__init__(vocab, **kwargs) 268 | import sentencepiece as spm 269 | import jieba as jb 270 | self.sp_model = spm.SentencePieceProcessor() 271 | self.window_size = 5 272 | self.sp_model.Load(sp_model_path) 273 | self.jb = jb 274 | 275 | def cut(self, sentence): 276 | return self.jb.cut(sentence) 277 | 278 | def tokenize(self, text): 279 | if len(text) == 0: 280 | return [] 281 | if not isinstance(text, six.string_types): 282 | text = text.decode(self.encoding) 283 | if self.lower: 284 | text = text.lower() 285 | 286 | res = [] 287 | for match in self.cut(text): 288 | res += self.sp_model.EncodeAsPieces(match) 289 | return res -------------------------------------------------------------------------------- /ernie/msmarco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module computes evaluation metrics for MSMARCO dataset on the ranking task. 
3 | Command line: 4 | python msmarco_eval_ranking.py 5 | Creation Date : 06/12/2018 6 | Last Modified : 1/21/2019 7 | Authors : Daniel Campos , Rutger van Haasteren 8 | """ 9 | import sys 10 | import pandas as pd 11 | from collections import Counter 12 | import numpy as np 13 | import itertools 14 | MaxMRRRank = 10 15 | import paddle 16 | class Mrr(paddle.metric.Metric): 17 | """doc""" 18 | def __init__(self, cfg): 19 | """doc""" 20 | self.reset() 21 | self._name = "Mrr" 22 | qrels={} 23 | with open(cfg['qrels'],'rb') as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | qid,pid = line.split() 27 | qid=int(qid) 28 | pid=int(pid) 29 | x=qrels.get(qid,[]) 30 | x.append(pid) 31 | qrels[qid]=x 32 | self.qrels = qrels 33 | 34 | def reset(self): 35 | """doc""" 36 | self.qid_saver = np.array([], dtype=np.int64) 37 | self.pid_saver = np.array([], dtype=np.int64) 38 | self.label_saver = np.array([], dtype=np.int64) 39 | self.pred_saver = np.array([], dtype=np.float32) 40 | 41 | def update(self, qid, pid, label, pred): 42 | if isinstance(qid, paddle.Tensor): 43 | qid = qid.numpy() 44 | if isinstance(pid, paddle.Tensor): 45 | pid = pid.numpy() 46 | if isinstance(label, paddle.Tensor): 47 | label = label.numpy() 48 | if isinstance(pred, paddle.Tensor): 49 | pred = pred.numpy() 50 | if not (qid.shape[0] == label.shape[0] == pred.shape[0]): 51 | raise ValueError( 52 | 'Mrr dimention not match: qid[%s] label[%s], pred[%s]' % 53 | (qid.shape, label.shape, pred.shape)) 54 | idx = qid!=-100000000 55 | qid = qid[idx] 56 | pid = pid[idx] 57 | label = label[idx] 58 | pred = pred[idx] 59 | self.qid_saver = np.concatenate( 60 | [self.qid_saver, qid.reshape([-1]).astype(np.int64)]) 61 | self.pid_saver = np.concatenate( 62 | [self.pid_saver, pid.reshape([-1]).astype(np.int64)]) 63 | self.label_saver = np.concatenate( 64 | [self.label_saver, label.reshape([-1]).astype(np.int64)]) 65 | self.pred_saver = np.concatenate( 66 | [self.pred_saver, pred.reshape([-1]).astype(np.float32)]) 67 | 68 | def accumulate(self): 69 | """doc""" 70 | def _key_func(tup): 71 | return tup[0] 72 | def _calc_func(tup): 73 | ranks = [1. / (rank + 1.) for rank, (_, l, p) in enumerate(sorted(tup, key=lambda t: t[2], reverse=True)[:MaxMRRRank]) if l != 0] 74 | if len(ranks): 75 | return ranks[0] 76 | else: 77 | return 0. 78 | mrr_for_qid = [ 79 | _calc_func(tup) 80 | for _, tup in itertools.groupby( 81 | sorted( 82 | zip(self.qid_saver, self.label_saver, self.pred_saver), 83 | key=_key_func), 84 | key=_key_func) 85 | ] 86 | mrr = np.float32(sum(mrr_for_qid) / len(mrr_for_qid)) 87 | return mrr 88 | def accumulate_map(self): 89 | """doc""" 90 | def _key_func(tup): 91 | return tup[0] 92 | def _calc_func(tup): 93 | tmp = sorted(tup, key=lambda t: t[2], reverse=True) 94 | qid = tmp[0][0] 95 | ranks = [rank+1 for rank, (qid, l, p) in enumerate(tmp) if l != 0] 96 | ranks = [(i+1)/rank for i,rank in enumerate(ranks)] 97 | if len(ranks): 98 | return sum(ranks)/len(self.qrels.get(qid, [0])) 99 | else: 100 | return 0. 
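# The (qid, label, pred) triples collected in update() are grouped by qid; _calc_func computes
# average precision for a single query (precision at each relevant hit, normalised by the number
# of relevant passages recorded in the qrels), and the final score is the mean over queries (MAP).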
101 | map_for_qid = [ 102 | _calc_func(tup) 103 | for _, tup in itertools.groupby( 104 | sorted( 105 | zip(self.qid_saver, self.label_saver, self.pred_saver), 106 | key=_key_func), 107 | key=_key_func) 108 | ] 109 | map = np.float32(sum(map_for_qid) / len(map_for_qid)) 110 | return map 111 | # ndcg 112 | def get_dcg(self, y_pred, y_true, k): 113 | # Note: y_pred and y_true must be aligned element-wise, and a larger y_pred means the item is closer to label=1 (i.e. more relevant). 114 | df = pd.DataFrame({"y_pred":y_pred, "y_true":y_true}) 115 | df = df.sort_values(by="y_pred", ascending=False) # sort by y_pred in descending order; earlier-ranked items are closer to label=1 116 | df = df.iloc[0:k, :] # keep the top k 117 | dcg = (2 ** df["y_true"] - 1) / np.log2(np.arange(1, df["y_true"].count()+1) + 1) # rank positions are counted from 1 118 | dcg = np.sum(dcg) 119 | return dcg 120 | 121 | def accumulate_ndcg(self): 122 | """doc""" 123 | def _key_func(tup): 124 | return tup[0] 125 | def _calc_func(tup): 126 | ranks = [(l, p) for rank, (_, l, p) in enumerate(sorted(tup, key=lambda t: t[2], reverse=True))] 127 | dcg = self.get_dcg([r[1] for r in ranks],[r[0] for r in ranks], 10) 128 | idcg = self.get_dcg([r[0] for r in ranks],[r[0] for r in ranks], 10) 129 | if idcg==0: 130 | return 0 131 | ndcg = dcg/idcg 132 | return ndcg 133 | ndcg_for_qid = [ 134 | _calc_func(tup) 135 | for _, tup in itertools.groupby( 136 | sorted( 137 | zip(self.qid_saver, self.label_saver, self.pred_saver), 138 | key=_key_func), 139 | key=_key_func) 140 | ] 141 | ndcg = np.float32(sum(ndcg_for_qid) / len(ndcg_for_qid)) 142 | return ndcg 143 | 144 | def name(self): 145 | """ 146 | Returns metric name 147 | """ 148 | return self._name 149 | def load_reference_from_stream(f): 150 | """Load reference relevant passages 151 | Args:f (stream): stream to load. 152 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). 153 | """ 154 | qids_to_relevant_passageids = {} 155 | for l in f: 156 | try: 157 | l = l.strip().split('\t') 158 | qid = int(l[0]) 159 | if qid in qids_to_relevant_passageids: 160 | pass 161 | else: 162 | qids_to_relevant_passageids[qid] = [] 163 | qids_to_relevant_passageids[qid].append(int(l[1])) 164 | except: 165 | raise IOError('\"%s\" is not valid format' % l) 166 | return qids_to_relevant_passageids 167 | 168 | 169 | def load_reference(path_to_reference): 170 | """Load reference relevant passages 171 | Args:path_to_reference (str): path to a file to load. 172 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). 173 | """ 174 | with open(path_to_reference, 'r') as f: 175 | qids_to_relevant_passageids = load_reference_from_stream(f) 176 | return qids_to_relevant_passageids 177 | 178 | 179 | def load_candidate_from_stream(f): 180 | """Load candidate data from a stream. 181 | Args:f (stream): stream to load. 182 | Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance 183 | """ 184 | qid_to_ranked_candidate_passages = {} 185 | for l in f: 186 | try: 187 | l = l.strip().split() 188 | qid = int(l[0]) 189 | pid = int(l[2]) 190 | rank = int(l[3]) 191 | if qid in qid_to_ranked_candidate_passages: 192 | pass 193 | else: 194 | # By default, all PIDs in the list of 1000 are 0.
Only override those that are given 195 | tmp = [0] * 1000 196 | qid_to_ranked_candidate_passages[qid] = tmp 197 | qid_to_ranked_candidate_passages[qid][rank - 1] = pid 198 | except: 199 | raise IOError('\"%s\" is not valid format' % l) 200 | return qid_to_ranked_candidate_passages 201 | 202 | 203 | def load_candidate(path_to_candidate): 204 | """Load candidate data from a file. 205 | Args:path_to_candidate (str): path to file to load. 206 | Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance 207 | """ 208 | 209 | with open(path_to_candidate, 'r') as f: 210 | qid_to_ranked_candidate_passages = load_candidate_from_stream(f) 211 | return qid_to_ranked_candidate_passages 212 | 213 | 214 | def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): 215 | """Perform quality checks on the dictionaries 216 | Args: 217 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping 218 | Dict as read in with load_reference or load_reference_from_stream 219 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates 220 | Returns: 221 | bool,str: Boolean whether allowed, message to be shown in case of a problem 222 | """ 223 | message = '' 224 | allowed = True 225 | 226 | # Create sets of the QIDs for the submitted and reference queries 227 | candidate_set = set(qids_to_ranked_candidate_passages.keys()) 228 | ref_set = set(qids_to_relevant_passageids.keys()) 229 | 230 | # Check that we do not have multiple passages per query 231 | for qid in qids_to_ranked_candidate_passages: 232 | # Remove all zeros from the candidates 233 | duplicate_pids = set( 234 | [item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1]) 235 | 236 | if len(duplicate_pids - set([0])) > 0: 237 | message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format( 238 | qid=qid, pid=list(duplicate_pids)[0]) 239 | allowed = False 240 | 241 | return allowed, message 242 | 243 | 244 | def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): 245 | """Compute MRR metric 246 | Args: 247 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping 248 | Dict as read in with load_reference or load_reference_from_stream 249 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates 250 | Returns: 251 | dict: dictionary of metrics {'MRR': } 252 | """ 253 | all_scores = {} 254 | MRR = 0 255 | qids_with_relevant_passages = 0 256 | ranking = [] 257 | recall_q_top1 = set() 258 | recall_q_top50 = set() 259 | recall_q_all = set() 260 | 261 | for qid in qids_to_ranked_candidate_passages: 262 | if qid in qids_to_relevant_passageids: 263 | ranking.append(0) 264 | target_pid = qids_to_relevant_passageids[qid] 265 | candidate_pid = qids_to_ranked_candidate_passages[qid] 266 | for i in range(0, MaxMRRRank): 267 | if candidate_pid[i] in target_pid: 268 | MRR += 1.0 / (i + 1) 269 | ranking.pop() 270 | ranking.append(i + 1) 271 | break 272 | for i, pid in enumerate(candidate_pid): 273 | if pid in target_pid: 274 | recall_q_all.add(qid) 275 | if i < 50: 276 | recall_q_top50.add(qid) 277 | if i == 0: 278 | recall_q_top1.add(qid) 279 | break 280 | if len(ranking) == 0: 281 | raise IOError("No matching QIDs found. 
Are you sure you are scoring the evaluation set?") 282 | 283 | MRR = MRR / len(qids_to_ranked_candidate_passages) 284 | recall_top1 = len(recall_q_top1) * 1.0 / len(qids_to_ranked_candidate_passages) 285 | recall_top50 = len(recall_q_top50) * 1.0 / len(qids_to_ranked_candidate_passages) 286 | recall_all = len(recall_q_all) * 1.0 / len(qids_to_ranked_candidate_passages) 287 | all_scores['MRR @10'] = MRR 288 | all_scores["recall@1"] = recall_top1 289 | all_scores["recall@50"] = recall_top50 290 | all_scores["recall@all"] = recall_all 291 | all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages) 292 | return all_scores 293 | 294 | 295 | def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True): 296 | """Compute MRR metric 297 | Args: 298 | p_path_to_reference_file (str): path to reference file. 299 | Reference file should contain lines in the following format: 300 | QUERYID\tPASSAGEID 301 | Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs 302 | p_path_to_candidate_file (str): path to candidate file. 303 | Candidate file sould contain lines in the following format: 304 | QUERYID\tPASSAGEID1\tRank 305 | If a user wishes to use the TREC format please run the script with a -t flag at the end. If this flag is used the expected format is 306 | QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID 307 | Where the values are separated by tabs and ranked in order of relevance 308 | Returns: 309 | dict: dictionary of metrics {'MRR': } 310 | """ 311 | 312 | qids_to_relevant_passageids = load_reference(path_to_reference) 313 | qids_to_ranked_candidate_passages = load_candidate(path_to_candidate) 314 | if perform_checks: 315 | allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages) 316 | if message != '': print(message) 317 | 318 | return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages) 319 | 320 | 321 | def main(): 322 | """Command line: 323 | python msmarco_eval_ranking.py 324 | """ 325 | 326 | if len(sys.argv) == 3: 327 | path_to_reference = sys.argv[1] 328 | path_to_candidate = sys.argv[2] 329 | 330 | else: 331 | path_to_reference = 'metric/qp_reference.all.tsv' 332 | path_to_candidate = 'metric/ranking_res' 333 | #print('Usage: msmarco_eval_ranking.py ') 334 | #exit() 335 | 336 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate) 337 | print('#####################') 338 | for metric in sorted(metrics): 339 | print('{}: {}'.format(metric, metrics[metric])) 340 | print('#####################') 341 | 342 | def get_mrr(path_to_reference="/home/user/codes/NAACL2021-RocketQA/corpus/marco/qrels.dev.tsv", path_to_candidate="output/step_0_pred_dev_scores.txt"): 343 | all_data = pd.read_csv(path_to_candidate,sep="\t",header=None) 344 | all_data.columns = ["qid","pid","score"] 345 | all_data = all_data.groupby("qid").apply(lambda x: x.sort_values('score', ascending=False).reset_index(drop=True)) 346 | all_data.columns = ['query_id',"para_id","score"] 347 | all_data = all_data.reset_index() 348 | all_data.pop("qid") 349 | all_data.columns = ["index","qid","pid","score"] 350 | all_data = all_data.loc[:,["qid","pid","index","score"]] 351 | all_data['index']+=1 352 | path_to_candidate = path_to_candidate.replace("txt","qrels") 353 | all_data.to_csv(path_to_candidate, header=None,index=False,sep="\t") 354 | metrics = compute_metrics_from_files(path_to_reference, path_to_candidate) 355 | return metrics['MRR @10'] 356 
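# Usage note: get_mrr() expects `path_to_candidate` to be a headerless tab-separated file with one
# "qid<TAB>pid<TAB>score" row per candidate. It re-ranks the candidates per query, writes the ranked
# list to a sibling file whose "txt" suffix is replaced by "qrels", and returns MRR@10 against the
# reference qrels. Illustrative call (the paths are placeholders, not shipped files):
#   mrr_at_10 = get_mrr(path_to_reference="qrels.dev.tsv", path_to_candidate="output/scores.txt")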
| 357 | if __name__ == '__main__': 358 | main() -------------------------------------------------------------------------------- /ernie/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import sentencepiece as sp 25 | 26 | def convert_to_unicode(text): 27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 28 | if six.PY3: 29 | if isinstance(text, str): 30 | return text 31 | elif isinstance(text, bytes): 32 | return text.decode("utf-8", "ignore") 33 | else: 34 | raise ValueError("Unsupported string type: %s" % (type(text))) 35 | elif six.PY2: 36 | if isinstance(text, str): 37 | return text.decode("utf-8", "ignore") 38 | elif isinstance(text, unicode): 39 | return text 40 | else: 41 | raise ValueError("Unsupported string type: %s" % (type(text))) 42 | else: 43 | raise ValueError("Not running on Python2 or Python 3?") 44 | 45 | 46 | def printable_text(text): 47 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 48 | 49 | # These functions want `str` for both Python2 and Python3, but in one case 50 | # it's a Unicode string and in the other it's a byte string. 
51 | if six.PY3: 52 | if isinstance(text, str): 53 | return text 54 | elif isinstance(text, bytes): 55 | return text.decode("utf-8", "ignore") 56 | else: 57 | raise ValueError("Unsupported string type: %s" % (type(text))) 58 | elif six.PY2: 59 | if isinstance(text, str): 60 | return text 61 | elif isinstance(text, unicode): 62 | return text.encode("utf-8") 63 | else: 64 | raise ValueError("Unsupported string type: %s" % (type(text))) 65 | else: 66 | raise ValueError("Not running on Python2 or Python 3?") 67 | 68 | 69 | def load_vocab(vocab_file): 70 | """Loads a vocabulary file into a dictionary.""" 71 | vocab = collections.OrderedDict() 72 | fin = open(vocab_file, 'rb') 73 | for num, line in enumerate(fin): 74 | items = convert_to_unicode(line.strip()).split("\t") 75 | if len(items) > 2: 76 | break 77 | token = items[0] 78 | index = items[1] if len(items) == 2 else num 79 | token = token.strip() 80 | vocab[token] = int(index) 81 | return vocab 82 | 83 | 84 | def convert_by_vocab(vocab, items): 85 | """Converts a sequence of [tokens|ids] using the vocab.""" 86 | output = [] 87 | for item in items: 88 | output.append(vocab[item]) 89 | return output 90 | 91 | 92 | def convert_tokens_to_ids_include_unk(vocab, tokens, unk_token="[UNK]"): 93 | output = [] 94 | for token in tokens: 95 | if token in vocab: 96 | output.append(vocab[token]) 97 | else: 98 | output.append(vocab[unk_token]) 99 | return output 100 | 101 | 102 | def convert_tokens_to_ids(vocab, tokens): 103 | return convert_by_vocab(vocab, tokens) 104 | 105 | 106 | def convert_ids_to_tokens(inv_vocab, ids): 107 | return convert_by_vocab(inv_vocab, ids) 108 | 109 | 110 | def whitespace_tokenize(text): 111 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 112 | text = text.strip() 113 | if not text: 114 | return [] 115 | tokens = text.split() 116 | return tokens 117 | 118 | 119 | class FullTokenizer(object): 120 | """Runs end-to-end tokenziation.""" 121 | 122 | def __init__(self, vocab_file, do_lower_case=True): 123 | self.vocab = load_vocab(vocab_file) 124 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 125 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 126 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 127 | 128 | def tokenize(self, text): 129 | split_tokens = [] 130 | for token in self.basic_tokenizer.tokenize(text): 131 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 132 | split_tokens.append(sub_token) 133 | 134 | return split_tokens 135 | 136 | def convert_tokens_to_ids(self, tokens): 137 | return convert_by_vocab(self.vocab, tokens) 138 | 139 | def convert_ids_to_tokens(self, ids): 140 | return convert_by_vocab(self.inv_vocab, ids) 141 | 142 | 143 | class CharTokenizer(object): 144 | """Runs end-to-end tokenziation.""" 145 | 146 | def __init__(self, vocab_file, do_lower_case=True): 147 | self.vocab = load_vocab(vocab_file) 148 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 149 | self.tokenizer = WordpieceTokenizer(vocab=self.vocab) 150 | 151 | def tokenize(self, text): 152 | split_tokens = [] 153 | for token in text.lower().split(" "): 154 | for sub_token in self.tokenizer.tokenize(token): 155 | split_tokens.append(sub_token) 156 | return split_tokens 157 | 158 | def convert_tokens_to_ids(self, tokens): 159 | return convert_by_vocab(self.vocab, tokens) 160 | 161 | def convert_ids_to_tokens(self, ids): 162 | return convert_by_vocab(self.inv_vocab, ids) 163 | 164 | 165 | class BasicTokenizer(object): 166 | """Runs basic tokenization 
(punctuation splitting, lower casing, etc.).""" 167 | 168 | def __init__(self, do_lower_case=True): 169 | """Constructs a BasicTokenizer. 170 | 171 | Args: 172 | do_lower_case: Whether to lower case the input. 173 | """ 174 | self.do_lower_case = do_lower_case 175 | 176 | def tokenize(self, text): 177 | """Tokenizes a piece of text.""" 178 | text = convert_to_unicode(text) 179 | text = self._clean_text(text) 180 | 181 | # This was added on November 1st, 2018 for the multilingual and Chinese 182 | # models. This is also applied to the English models now, but it doesn't 183 | # matter since the English models were not trained on any Chinese data 184 | # and generally don't have any Chinese data in them (there are Chinese 185 | # characters in the vocabulary because Wikipedia does have some Chinese 186 | # words in the English Wikipedia.). 187 | text = self._tokenize_chinese_chars(text) 188 | 189 | orig_tokens = whitespace_tokenize(text) 190 | split_tokens = [] 191 | for token in orig_tokens: 192 | if self.do_lower_case: 193 | token = token.lower() 194 | token = self._run_strip_accents(token) 195 | split_tokens.extend(self._run_split_on_punc(token)) 196 | 197 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 198 | return output_tokens 199 | 200 | def _run_strip_accents(self, text): 201 | """Strips accents from a piece of text.""" 202 | text = unicodedata.normalize("NFD", text) 203 | output = [] 204 | for char in text: 205 | cat = unicodedata.category(char) 206 | if cat == "Mn": 207 | continue 208 | output.append(char) 209 | return "".join(output) 210 | 211 | def _run_split_on_punc(self, text): 212 | """Splits punctuation on a piece of text.""" 213 | chars = list(text) 214 | i = 0 215 | start_new_word = True 216 | output = [] 217 | while i < len(chars): 218 | char = chars[i] 219 | if _is_punctuation(char): 220 | output.append([char]) 221 | start_new_word = True 222 | else: 223 | if start_new_word: 224 | output.append([]) 225 | start_new_word = False 226 | output[-1].append(char) 227 | i += 1 228 | 229 | return ["".join(x) for x in output] 230 | 231 | def _tokenize_chinese_chars(self, text): 232 | """Adds whitespace around any CJK character.""" 233 | output = [] 234 | for char in text: 235 | cp = ord(char) 236 | if self._is_chinese_char(cp): 237 | output.append(" ") 238 | output.append(char) 239 | output.append(" ") 240 | else: 241 | output.append(char) 242 | return "".join(output) 243 | 244 | def _is_chinese_char(self, cp): 245 | """Checks whether CP is the codepoint of a CJK character.""" 246 | # This defines a "chinese character" as anything in the CJK Unicode block: 247 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 248 | # 249 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 250 | # despite its name. The modern Korean Hangul alphabet is a different block, 251 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 252 | # space-separated words, so they are not treated specially and handled 253 | # like the all of the other languages. 
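# (The codepoint ranges tested below cover CJK Unified Ideographs, Extension A, Extensions B-E, and the two CJK Compatibility Ideographs blocks.)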
254 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 255 | (cp >= 0x3400 and cp <= 0x4DBF) or # 256 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 257 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 258 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 259 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 260 | (cp >= 0xF900 and cp <= 0xFAFF) or # 261 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 262 | return True 263 | 264 | return False 265 | 266 | def _clean_text(self, text): 267 | """Performs invalid character removal and whitespace cleanup on text.""" 268 | output = [] 269 | for char in text: 270 | cp = ord(char) 271 | if cp == 0 or cp == 0xfffd or _is_control(char): 272 | continue 273 | if _is_whitespace(char): 274 | output.append(" ") 275 | else: 276 | output.append(char) 277 | return "".join(output) 278 | 279 | 280 | class SentencepieceTokenizer(object): 281 | """Runs SentencePiece tokenziation.""" 282 | 283 | def __init__(self, vocab_file, do_lower_case=True, unk_token="[UNK]"): 284 | self.vocab = load_vocab(vocab_file) 285 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 286 | self.do_lower_case = do_lower_case 287 | self.tokenizer = sp.SentencePieceProcessor() 288 | self.tokenizer.Load(vocab_file + ".model") 289 | self.sp_unk_token = "" 290 | self.unk_token = unk_token 291 | 292 | def tokenize(self, text): 293 | """Tokenizes a piece of text into its word pieces. 294 | 295 | Returns: 296 | A list of wordpiece tokens. 297 | """ 298 | text = text.lower() if self.do_lower_case else text 299 | text = convert_to_unicode(text.replace("\1", " ")) 300 | tokens = self.tokenizer.EncodeAsPieces(text) 301 | 302 | output_tokens = [] 303 | for token in tokens: 304 | if token == self.sp_unk_token: 305 | token = self.unk_token 306 | 307 | if token in self.vocab: 308 | output_tokens.append(token) 309 | else: 310 | output_tokens.append(self.unk_token) 311 | 312 | return output_tokens 313 | 314 | def convert_tokens_to_ids(self, tokens): 315 | return convert_by_vocab(self.vocab, tokens) 316 | 317 | def convert_ids_to_tokens(self, ids): 318 | return convert_by_vocab(self.inv_vocab, ids) 319 | 320 | 321 | class WordsegTokenizer(object): 322 | """Runs Wordseg tokenziation.""" 323 | 324 | def __init__(self, vocab_file, do_lower_case=True, unk_token="[UNK]", 325 | split_token="\1"): 326 | self.vocab = load_vocab(vocab_file) 327 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 328 | self.tokenizer = sp.SentencePieceProcessor() 329 | self.tokenizer.Load(vocab_file + ".model") 330 | 331 | self.do_lower_case = do_lower_case 332 | self.unk_token = unk_token 333 | self.split_token = split_token 334 | 335 | def tokenize(self, text): 336 | """Tokenizes a piece of text into its word pieces. 337 | 338 | Returns: 339 | A list of wordpiece tokens. 
340 | """ 341 | text = text.lower() if self.do_lower_case else text 342 | text = convert_to_unicode(text) 343 | 344 | output_tokens = [] 345 | for token in text.split(self.split_token): 346 | if token in self.vocab: 347 | output_tokens.append(token) 348 | else: 349 | sp_tokens = self.tokenizer.EncodeAsPieces(token) 350 | for sp_token in sp_tokens: 351 | if sp_token in self.vocab: 352 | output_tokens.append(sp_token) 353 | return output_tokens 354 | 355 | def convert_tokens_to_ids(self, tokens): 356 | return convert_by_vocab(self.vocab, tokens) 357 | 358 | def convert_ids_to_tokens(self, ids): 359 | return convert_by_vocab(self.inv_vocab, ids) 360 | 361 | 362 | class WordpieceTokenizer(object): 363 | """Runs WordPiece tokenziation.""" 364 | 365 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 366 | self.vocab = vocab 367 | self.unk_token = unk_token 368 | self.max_input_chars_per_word = max_input_chars_per_word 369 | 370 | def tokenize(self, text): 371 | """Tokenizes a piece of text into its word pieces. 372 | 373 | This uses a greedy longest-match-first algorithm to perform tokenization 374 | using the given vocabulary. 375 | 376 | For example: 377 | input = "unaffable" 378 | output = ["un", "##aff", "##able"] 379 | 380 | Args: 381 | text: A single token or whitespace separated tokens. This should have 382 | already been passed through `BasicTokenizer. 383 | 384 | Returns: 385 | A list of wordpiece tokens. 386 | """ 387 | 388 | text = convert_to_unicode(text) 389 | 390 | output_tokens = [] 391 | for token in whitespace_tokenize(text): 392 | chars = list(token) 393 | if len(chars) > self.max_input_chars_per_word: 394 | output_tokens.append(self.unk_token) 395 | continue 396 | 397 | is_bad = False 398 | start = 0 399 | sub_tokens = [] 400 | while start < len(chars): 401 | end = len(chars) 402 | cur_substr = None 403 | while start < end: 404 | substr = "".join(chars[start:end]) 405 | if start > 0: 406 | substr = "##" + substr 407 | if substr in self.vocab: 408 | cur_substr = substr 409 | break 410 | end -= 1 411 | if cur_substr is None: 412 | is_bad = True 413 | break 414 | sub_tokens.append(cur_substr) 415 | start = end 416 | 417 | if is_bad: 418 | output_tokens.append(self.unk_token) 419 | else: 420 | output_tokens.extend(sub_tokens) 421 | return output_tokens 422 | 423 | 424 | def _is_whitespace(char): 425 | """Checks whether `chars` is a whitespace character.""" 426 | # \t, \n, and \r are technically contorl characters but we treat them 427 | # as whitespace since they are generally considered as such. 428 | if char == " " or char == "\t" or char == "\n" or char == "\r": 429 | return True 430 | cat = unicodedata.category(char) 431 | if cat == "Zs": 432 | return True 433 | return False 434 | 435 | 436 | def _is_control(char): 437 | """Checks whether `chars` is a control character.""" 438 | # These are technically control characters but we count them as whitespace 439 | # characters. 440 | if char == "\t" or char == "\n" or char == "\r": 441 | return False 442 | cat = unicodedata.category(char) 443 | if cat.startswith("C"): 444 | return True 445 | return False 446 | 447 | 448 | def _is_punctuation(char): 449 | """Checks whether `chars` is a punctuation character.""" 450 | cp = ord(char) 451 | # We treat all non-letter/number ASCII as punctuation. 452 | # Characters such as "^", "$", and "`" are not in the Unicode 453 | # Punctuation class but we treat them as punctuation anyways, for 454 | # consistency. 
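# (ASCII ranges tested below: 33-47 are !"#$%&'()*+,-./ ; 58-64 are :;<=>?@ ; 91-96 are [\]^_` ; 123-126 are {|}~)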
455 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 456 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 457 | return True 458 | cat = unicodedata.category(char) 459 | if cat.startswith("P"): 460 | return True 461 | return False 462 | -------------------------------------------------------------------------------- /ernie/paths.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import numpy as np 4 | import networkx as nx 5 | import tokenization 6 | from spacy.matcher import Matcher 7 | import spacy 8 | import gensim 9 | import pgl 10 | import re 11 | class ConceptGraphBuilder: 12 | def __init__(self,cfg): 13 | self.cfg = cfg 14 | self.load_resources(cfg['resource']) 15 | self.load_cpnet(cfg['cpnet']) 16 | self.nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat']) 17 | self.nlp.add_pipe('sentencizer') 18 | self.matcher = Matcher(self.nlp.vocab) 19 | with open(cfg['pattern_path'], "r", encoding="utf8") as fin: 20 | all_patterns = json.load(fin) 21 | for concept, pattern in all_patterns.items(): 22 | self.matcher.add(concept, [pattern]) 23 | self.word2vec = gensim.models.KeyedVectors.load_word2vec_format(cfg['word2vec'], binary=True) 24 | self.stopwords=['i', 25 | 'me', 26 | 'my', 27 | 'myself', 28 | 'we', 29 | 'our', 30 | 'ours', 31 | 'ourselves', 32 | 'you', 33 | "you're", 34 | "you've", 35 | "you'll", 36 | "you'd", 37 | 'your', 38 | 'yours', 39 | 'yourself', 40 | 'yourselves', 41 | 'he', 42 | 'him', 43 | 'his', 44 | 'himself', 45 | 'she', 46 | "she's", 47 | 'her', 48 | 'hers', 49 | 'herself', 50 | 'it', 51 | "it's", 52 | 'its', 53 | 'itself', 54 | 'they', 55 | 'them', 56 | 'their', 57 | 'theirs', 58 | 'themselves', 59 | 'what', 60 | 'which', 61 | 'who', 62 | 'whom', 63 | 'this', 64 | 'that', 65 | "that'll", 66 | 'these', 67 | 'those', 68 | 'am', 69 | 'is', 70 | 'are', 71 | 'was', 72 | 'were', 73 | 'be', 74 | 'been', 75 | 'being', 76 | 'have', 77 | 'has', 78 | 'had', 79 | 'having', 80 | 'do', 81 | 'does', 82 | 'did', 83 | 'doing', 84 | 'a', 85 | 'an', 86 | 'the', 87 | 'and', 88 | 'but', 89 | 'if', 90 | 'or', 91 | 'because', 92 | 'as', 93 | 'until', 94 | 'while', 95 | 'of', 96 | 'at', 97 | 'by', 98 | 'for', 99 | 'with', 100 | 'about', 101 | 'against', 102 | 'between', 103 | 'into', 104 | 'through', 105 | 'during', 106 | 'before', 107 | 'after', 108 | 'above', 109 | 'below', 110 | 'to', 111 | 'from', 112 | 'up', 113 | 'down', 114 | 'in', 115 | 'out', 116 | 'on', 117 | 'off', 118 | 'over', 119 | 'under', 120 | 'again', 121 | 'further', 122 | 'then', 123 | 'once', 124 | 'here', 125 | 'there', 126 | 'when', 127 | 'where', 128 | 'why', 129 | 'how', 130 | 'all', 131 | 'any', 132 | 'both', 133 | 'each', 134 | 'few', 135 | 'more', 136 | 'most', 137 | 'other', 138 | 'some', 139 | 'such', 140 | 'no', 141 | 'nor', 142 | 'not', 143 | 'only', 144 | 'own', 145 | 'same', 146 | 'so', 147 | 'than', 148 | 'too', 149 | 'very', 150 | 's', 151 | 't', 152 | 'can', 153 | 'will', 154 | 'just', 155 | 'don', 156 | "don't", 157 | 'should', 158 | "should've", 159 | 'now', 160 | 'd', 161 | 'll', 162 | 'm', 163 | 'o', 164 | 're', 165 | 've', 166 | 'y', 167 | 'ain', 168 | 'aren', 169 | "aren't", 170 | 'couldn', 171 | "couldn't", 172 | 'didn', 173 | "didn't", 174 | 'doesn', 175 | "doesn't", 176 | 'hadn', 177 | "hadn't", 178 | 'hasn', 179 | "hasn't", 180 | 'haven', 181 | "haven't", 182 | 'isn', 183 | "isn't", 184 | 'ma', 185 | 'mightn', 186 | "mightn't", 187 | 'mustn', 188 | "mustn't", 189 | 'needn', 190 | "needn't", 
191 | 'shan', 192 | "shan't", 193 | 'shouldn', 194 | "shouldn't", 195 | 'wasn', 196 | "wasn't", 197 | 'weren', 198 | "weren't", 199 | 'won', 200 | "won't", 201 | 'wouldn', 202 | "wouldn't"] 203 | self.blacklist = set(["-PRON-", "actually", "likely", "possibly", "want", 204 | "make", "my", "someone", "sometimes_people", "sometimes", "would", "want_to", 205 | "one", "something", "sometimes", "everybody", "somebody", "could", "could_be" 206 | ]+self.stopwords) 207 | def get_span(self, sent): 208 | sent, shift = sent 209 | sent = sent.lower() 210 | # sent = re.sub(r'[^a-z ]+', ' ', sent) 211 | sent = sent.replace("-", "_") 212 | spans = [] 213 | tokens = sent.split(" ") 214 | token_num = len(tokens) 215 | itv = [] 216 | for length in range(1, 5): 217 | for i in range(token_num-length+1): 218 | span = "_".join(tokens[i:i+length]) 219 | span = list(self.lemmatize(span))[0] 220 | if span not in self.blacklist and span in self.concept2id and span not in spans: 221 | spans.append(span) 222 | itv.append((span,i,i+length)) 223 | itv = self.removeCoveredIntervals(itv) 224 | return [(i[0],i[1]+shift,i[2]+shift) for i in itv] 225 | def get_edge(self, src_concept, tgt_concept): 226 | rel_list = self.cpnet[src_concept][tgt_concept] # list of dicts 227 | seen = set() 228 | res = [r['rel'] for r in rel_list.values() if r['rel'] not in seen and (seen.add(r['rel']) or True)] # get unique values from rel_list 229 | return res 230 | def find_paths_qd_concept_pair(self, source: str, target: str, ifprint=False): 231 | s = self.concept2id[source] 232 | t = self.concept2id[target] 233 | if s not in self.cpnet_simple.nodes() or t not in self.cpnet_simple.nodes(): 234 | return [] 235 | all_path = [] 236 | top_paths = 10 237 | try: 238 | for p in nx.shortest_simple_paths(self.cpnet_topk, source=s, target=t): 239 | if len(p) > 4 or len(all_path) >= top_paths: # top 10 paths 240 | break 241 | if len(p) >= 2: # skip paths of length 1 242 | all_path.append(p) 243 | except nx.exception.NetworkXNoPath: 244 | return [] 245 | pf_res = [] 246 | for p in all_path: 247 | rl = [] 248 | for src in range(len(p) - 1): 249 | src_concept = p[src] 250 | tgt_concept = p[src + 1] 251 | 252 | rel_list = self.get_edge(src_concept, tgt_concept) 253 | rl.append(rel_list) 254 | if ifprint: 255 | rel_list_str = [] 256 | for rel in rel_list: 257 | if rel < len(self.id2relation): 258 | rel_list_str.append(self.id2relation[rel]) 259 | else: 260 | rel_list_str.append(self.id2relation[rel - len(self.id2relation)] + "*") 261 | print(self.id2concept[src_concept], "----[%s]---> " % ("/".join(rel_list_str)), end="") 262 | if src + 1 == len(p) - 1: 263 | print(self.id2concept[tgt_concept], end="") 264 | if ifprint: 265 | print() 266 | pf_res.append({"path": p, "rel": rl}) 267 | return pf_res 268 | def lemmatize(self, concept): 269 | doc = self.nlp(concept.replace("_", " ")) 270 | lcs = set() 271 | lcs.add("_".join([token.lemma_ for token in doc])) # all lemma 272 | return lcs 273 | def removeCoveredIntervals(self, intervals): 274 | intervals.sort(key=lambda x:(x[1], -x[2])) 275 | dp=[] 276 | for itv in intervals: 277 | if not dp or dp[-1][2] 0: 310 | mentioned_concepts.add(list(intersect)[0]) 311 | else: 312 | mentioned_concepts.add(c) 313 | 314 | exact_match = set([concept for concept in concepts_sorted if concept.replace("_", " ").lower() == span.lower()]) 315 | mentioned_concepts.update(exact_match) 316 | return mentioned_concepts 317 | def get_emb(self, sent): 318 | return np.mean(np.array([self.word2vec[w] for w in sent.split(" ") if w in 
self.word2vec]),axis=0) 319 | def get_topk_sents(self, qry, doc, k): 320 | qemb=self.get_emb(qry) 321 | d=self.nlp(doc) 322 | ret = [] 323 | shift = 0 324 | for s in d.sents: 325 | sent = re.sub(r'[^a-z ]+', ' ', s.text.lower()) 326 | semb=self.get_emb(sent) 327 | if not np.isnan(semb).all(): 328 | sim=np.dot(qemb,semb) 329 | if type(sim) is np.float32: 330 | ret.append((sim, s.text.lower(), shift)) 331 | shift+=len(s.text.split(" ")) 332 | ret.sort(key=lambda x:x[0],reverse=True) 333 | return [(r[1], r[2]) for r in ret[:k]] 334 | def get_graph(self, qry, doc, is_print=False): 335 | topk_sents_and_shift = self.get_topk_sents(qry, doc, self.cfg['topk_sents']) 336 | qry = qry.lower() 337 | doc = doc.lower() 338 | qry_concepts_and_shift = self.get_span((qry,0)) 339 | doc_concepts_and_shift = [('ssss',0,1)] 340 | for sent_and_shift in topk_sents_and_shift: 341 | doc_concepts_and_shift += self.get_span(sent_and_shift) 342 | doc_concepts_and_shift = list(set(doc_concepts_and_shift)) 343 | paths=[] 344 | for q in qry_concepts_and_shift: 345 | for d in doc_concepts_and_shift: 346 | paths+=self.find_paths_qd_concept_pair(q[0],d[0],is_print) 347 | nodes = {} 348 | qry_nodes = [] 349 | doc_nodes = [] 350 | edges = [] 351 | edges_type = [] 352 | for path in paths: 353 | for i,node in enumerate(path['path']): 354 | if node not in nodes: 355 | nodes[node] = len(nodes) 356 | if i: 357 | prev_node = nodes[path['path'][i-1]] 358 | edges.append((prev_node, nodes[node])) 359 | edges_type.append(path['rel'][i-1][0]%17) 360 | q_spans = [] 361 | d_spans = [] 362 | qry_concepts_to_shift={} 363 | for c,s,e in qry_concepts_and_shift: 364 | qry_concepts_to_shift[self.concept2id[c]]=(s,e) 365 | doc_concepts_to_shift = {} 366 | for c,s,e in doc_concepts_and_shift: 367 | doc_concepts_to_shift[self.concept2id[c]] = (s,e) 368 | for path in paths: 369 | q_node = path['path'][0] 370 | d_node = path['path'][-1] 371 | if nodes[q_node] not in qry_nodes: 372 | qry_nodes.append(nodes[q_node]) 373 | q_spans.append(qry_concepts_to_shift[q_node]) 374 | if nodes[d_node] not in doc_nodes: 375 | doc_nodes.append(nodes[d_node]) 376 | d_spans.append(doc_concepts_to_shift[d_node]) 377 | 378 | num_nodes = len(nodes) 379 | nodes_feature = self.ent_emb[list(nodes)] 380 | edges_feature = self.rel_emb[edges_type] 381 | graph = pgl.Graph(num_nodes=num_nodes, 382 | edges=edges, 383 | node_feat={"feature": nodes_feature}, 384 | edge_feat={"edge_feature":edges_feature}) 385 | return graph,q_spans,d_spans,qry_nodes,doc_nodes 386 | 387 | 388 | def load_resources(self,cpnet_vocab_path): 389 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin: 390 | self.id2concept = [w.strip() for w in fin] 391 | self.concept2id = {w: i for i, w in enumerate(self.id2concept)} 392 | self.id2relation = [ 393 | 'antonym', 394 | 'atlocation', 395 | 'capableof', 396 | 'causes', 397 | 'createdby', 398 | 'isa', 399 | 'desires', 400 | 'hassubevent', 401 | 'partof', 402 | 'hascontext', 403 | 'hasproperty', 404 | 'madeof', 405 | 'notcapableof', 406 | 'notdesires', 407 | 'receivesaction', 408 | 'relatedto', 409 | 'usedfor', 410 | ] 411 | self.relation2id = {r: i for i, r in enumerate(self.id2relation)} 412 | def load_cpnet(self,cpnet_graph_path): 413 | cpnet = nx.read_gpickle(cpnet_graph_path) 414 | cpnet.add_edge(self.concept2id['cccc'],self.concept2id['ssss'],weight=1.0,rel=15) 415 | cpnet_simple = nx.Graph() 416 | self.ent_emb = np.load(self.cfg['ent_emb']) 417 | self.ent_emb = np.concatenate([self.ent_emb, np.zeros((2,100))]).astype("float32") 418 | self.rel_emb = 
np.load(self.cfg['rel_emb']).astype("float32") 419 | for u, v, data in cpnet.edges(data=True): 420 | rid = data['rel']%17 421 | w = np.inner(self.ent_emb[u], self.ent_emb[v])+np.inner(self.ent_emb[u], self.rel_emb[rid])+np.inner(self.rel_emb[rid], self.ent_emb[v]) 422 | w = 1/abs(w) 423 | if not cpnet_simple.has_edge(u, v): 424 | cpnet_simple.add_edge(u, v, weight=w) 425 | topk={} 426 | max_neighbor=50 #消融改了这个 427 | # max_neighbor = 8000000 428 | for u in cpnet_simple: 429 | topk[u] = [] 430 | for v,data in cpnet_simple[u].items(): 431 | topk[u].append((data['weight'],v)) 432 | topk[u].sort(key=lambda x:x[0]) 433 | topk[u]=topk[u][:max_neighbor] 434 | cpnet_topk = nx.DiGraph() 435 | for u in topk: 436 | for w,v in topk[u]: 437 | if not cpnet_topk.has_edge(u,v): 438 | cpnet_topk.add_edge(u,v,weight=w) 439 | self.cpnet_topk = cpnet_topk 440 | self.cpnet_simple = cpnet_simple 441 | self.cpnet = cpnet 442 | 443 | 444 | class GraphBuilder: 445 | def __init__(self, cfg): 446 | self.load_resources(cfg['resource']) 447 | self.load_cpnet(cfg['cpnet']) 448 | self.tokenizer = tokenization.FullTokenizer(cfg['vocab_file']) 449 | self.cfg = cfg 450 | 451 | def load_resources(self,cpnet_vocab_path): 452 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin: 453 | self.id2concept = [w.strip() for w in fin] 454 | self.concept2id = {w: i for i, w in enumerate(self.id2concept)} 455 | 456 | self.id2relation = merged_relations=['atlocation', 457 | 'capableof', 458 | 'causes', 459 | 'createdby', 460 | 'desires', 461 | 'hasproperty', 462 | 'madeof', 463 | 'notcapableof', 464 | 'notdesires', 465 | 'partof', 466 | 'usedfor', 467 | 'receivesaction'] 468 | self.relation2id = {r: i for i, r in enumerate(self.id2relation)} 469 | 470 | 471 | def load_cpnet(self,cpnet_graph_path): 472 | global cpnet, cpnet_simple 473 | cpnet = nx.read_gpickle(cpnet_graph_path) 474 | cpnet_simple = nx.Graph() 475 | for u, v, data in cpnet.edges(data=True): 476 | w = data['weight'] if 'weight' in data else 1.0 477 | if cpnet_simple.has_edge(u, v): 478 | cpnet_simple[u][v]['weight'] += w 479 | else: 480 | cpnet_simple.add_edge(u, v, weight=w) 481 | self.cpnet_simple = cpnet_simple 482 | self.cpnet = cpnet 483 | 484 | 485 | 486 | def get_graph(self, query_tokens_id, doc_tokens_id, is_print=False): 487 | edges = [] 488 | nodes_num = self.cfg['q_max_seq_len']+self.cfg['p_max_seq_len'] 489 | nodes_tokens = [] 490 | edges_type = [] 491 | query_tokens = self.tokenizer.convert_ids_to_tokens(query_tokens_id) 492 | doc_tokens = self.tokenizer.convert_ids_to_tokens(doc_tokens_id) 493 | for q_tok_id, q_token in enumerate(query_tokens): 494 | if self.concept2id.get(q_token,-1)!=-1: 495 | for d_tok_id, d_token in enumerate(doc_tokens): 496 | if self.concept2id.get(d_token,-1)!=-1: 497 | paths = self.find_paths_qd_concept_pair(q_token, d_token, is_print) 498 | d_tok_id+=self.cfg['q_max_seq_len'] 499 | for p in paths: 500 | if len(p['path'])==2: 501 | edges.append((q_tok_id, d_tok_id)) 502 | edges_type.append(p['rel'][0][0]) 503 | else: 504 | edges.append((q_tok_id, nodes_num)) 505 | edges_type.append(p['rel'][0][0]) 506 | 507 | for i,node in enumerate(p['path'][1:-2]): 508 | edges.append((nodes_num, nodes_num+1)) 509 | edges_type.append(p['rel'][i+1][0]) 510 | nodes_tokens.append(node) 511 | nodes_num+=1 512 | 513 | edges.append((nodes_num, d_tok_id)) 514 | edges_type.append(p['rel'][-1][0]) 515 | 516 | nodes_tokens.append(p['path'][-2]) 517 | nodes_num+=1 518 | return edges, edges_type, nodes_tokens,nodes_num 519 | 520 | def get_edge(self, 
src_concept, tgt_concept): 521 | rel_list = self.cpnet[src_concept][tgt_concept] # list of dicts 522 | seen = set() 523 | res = [r['rel'] for r in rel_list.values() if r['rel'] not in seen and (seen.add(r['rel']) or True)] # get unique values from rel_list 524 | return res 525 | 526 | def find_paths_qd_concept_pair(self,source: str, target: str, ifprint=False): 527 | s = self.concept2id[source] 528 | t = self.concept2id[target] 529 | 530 | if s not in self.cpnet_simple.nodes() or t not in self.cpnet_simple.nodes(): 531 | return [] 532 | all_path = [] 533 | try: 534 | for p in nx.shortest_simple_paths(self.cpnet_simple, source=s, target=t): 535 | if len(p) > 4 or len(all_path) >= 20: # top 20 paths 536 | break 537 | if len(p) >= 2: # skip paths of length 1 538 | all_path.append(p) 539 | except nx.exception.NetworkXNoPath: 540 | return [] 541 | 542 | pf_res = [] 543 | for p in all_path: 544 | # print([id2concept[i] for i in p]) 545 | rl = [] 546 | for src in range(len(p) - 1): 547 | src_concept = p[src] 548 | tgt_concept = p[src + 1] 549 | 550 | rel_list = self.get_edge(src_concept, tgt_concept) 551 | rl.append(rel_list) 552 | if ifprint: 553 | rel_list_str = [] 554 | for rel in rel_list: 555 | if rel < len(self.id2relation): 556 | rel_list_str.append(self.id2relation[rel]) 557 | else: 558 | rel_list_str.append(self.id2relation[rel - len(self.id2relation)] + "*") 559 | print(self.id2concept[src_concept], "----[%s]---> " % ("/".join(rel_list_str)), end="") 560 | if src + 1 == len(p) - 1: 561 | print(self.id2concept[tgt_concept], end="") 562 | if ifprint: 563 | print() 564 | 565 | pf_res.append({"path": [self.tokenizer.vocab[self.id2concept[concept_id]] for concept_id in p], "rel": rl}) 566 | return pf_res 567 | 568 | 569 | 570 | class BatchedGraphBuilder(GraphBuilder): 571 | def __init__(self, cfg): 572 | GraphBuilder.__init__(self, cfg) 573 | 574 | def get_batched_graph(self, pair_input_ids, is_print=False): 575 | # pair_input_ids:[batch_size, seq_len] 576 | batch_size = pair_input_ids.shape[0] 577 | seq_len = pair_input_ids.shape[1] 578 | edges = [] 579 | nodes_num = batch_size*seq_len 580 | nodes_tokens = [] 581 | edges_type = [] 582 | batched_query_tokens_id = [] 583 | batched_doc_tokens_id = [] 584 | for sample_id in range(batch_size): 585 | sep_id = np.where(pair_input_ids[sample_id]==102)[0][0] 586 | batched_query_tokens_id.append(pair_input_ids[sample_id][:sep_id]) 587 | batched_doc_tokens_id.append(pair_input_ids[sample_id][sep_id:]) 588 | edges.append((seq_len*sample_id, seq_len*sample_id+sep_id)) 589 | edges_type.append(self.cfg['edge_num']-1) 590 | for batch_i,query_tokens_id in enumerate(batched_query_tokens_id): 591 | doc_tokens_id = batched_doc_tokens_id[batch_i] 592 | query_tokens = self.tokenizer.convert_ids_to_tokens(query_tokens_id.reshape(-1)) 593 | doc_tokens = self.tokenizer.convert_ids_to_tokens(doc_tokens_id.reshape(-1)) 594 | # print(query_tokens,len(query_tokens)) 595 | # print(doc_tokens,len(doc_tokens)) 596 | for q_tok_id, q_token in enumerate(query_tokens): 597 | if self.concept2id.get(q_token,-1)!=-1: 598 | q_tok_id = q_tok_id + seq_len*batch_i 599 | for d_tok_id, d_token in enumerate(doc_tokens): 600 | if self.concept2id.get(d_token,-1)!=-1: 601 | paths = self.find_paths_qd_concept_pair(q_token, d_token, False) 602 | d_tok_id = d_tok_id + seq_len*batch_i + len(query_tokens) 603 | # if paths: 604 | # print(q_tok_id,d_tok_id,q_token,d_token) 605 | for p in paths: 606 | if len(p['path'])==2: 607 | edges.append((q_tok_id, d_tok_id)) 608 | 
edges_type.append(p['rel'][0][0]) 609 | else: 610 | edges.append((q_tok_id, nodes_num)) 611 | edges_type.append(p['rel'][0][0]) 612 | 613 | for i,node in enumerate(p['path'][1:-2]): 614 | edges.append((nodes_num, nodes_num+1)) 615 | edges_type.append(p['rel'][i+1][0]) 616 | nodes_tokens.append(node) 617 | nodes_num+=1 618 | 619 | edges.append((nodes_num, d_tok_id)) 620 | edges_type.append(p['rel'][-1][0]) 621 | 622 | nodes_tokens.append(p['path'][-2]) 623 | nodes_num+=1 624 | if nodes_num>=1.8*(batch_size*seq_len): 625 | return (edges, edges_type, nodes_tokens,nodes_num) 626 | return (edges, edges_type, nodes_tokens,nodes_num) 627 | 628 | -------------------------------------------------------------------------------- /ernie/modeling_ernie.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import division 16 | from __future__ import absolute_import 17 | from __future__ import print_function 18 | from __future__ import unicode_literals 19 | 20 | import json 21 | import logging 22 | import math 23 | import six 24 | if six.PY2: 25 | from pathlib2 import Path 26 | else: 27 | from pathlib import Path 28 | import numpy as np 29 | import paddle as P 30 | from paddle import nn 31 | from paddle.nn import functional as F 32 | from file_utils import _fetch_from_remote, add_docstring 33 | 34 | log = logging.getLogger(__name__) 35 | 36 | ACT_DICT = { 37 | 'relu': nn.ReLU, 38 | 'gelu': nn.GELU, 39 | } 40 | def _get_rel_pos_bias(seq_len, max_len=128, num_buckets=32, bidirectional=True, reset=True): 41 | #max_len = 520 42 | pos = np.array(range(seq_len)) 43 | rel_pos = pos[:, None] - pos[None, :] 44 | ret = 0 45 | n = -rel_pos 46 | if bidirectional: 47 | num_buckets //= 2 48 | ret += (n < 0).astype('int32') * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets 49 | n = np.abs(n) 50 | else: 51 | n = np.max(n, np.zeros_like(n)) 52 | # now n is in the range [0, inf) 53 | 54 | # half of the buckets are for exact increments in positions 55 | max_exact = num_buckets // 2 56 | is_small = n < max_exact 57 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 58 | val_if_large = max_exact + (np.log(n.astype('float32') / max_exact) / math.log(max_len / max_exact) * (num_buckets - max_exact)).astype('int32') 59 | tmp = np.full_like(val_if_large, num_buckets-1) 60 | val_if_large = np.where(val_if_large < tmp, val_if_large, tmp) 61 | 62 | ret += np.where(is_small, n, val_if_large) 63 | if reset: 64 | num_buckets *= 2 65 | ret[:, 0] = num_buckets 66 | ret[0, :] = num_buckets // 2 67 | 68 | return np.array(ret).reshape([seq_len, seq_len]).astype("int64") 69 | 70 | def _build_linear(n_in, n_out, name, init, lr=1.0): 71 | return nn.Linear( 72 | n_in, 73 | n_out, 74 | weight_attr=P.ParamAttr( 75 | name='%s.w_0' % name if name is not None else None, 76 | 
initializer=init,learning_rate=lr), 77 | bias_attr='%s.b_0' % name if name is not None else None, 78 | ) 79 | 80 | 81 | def _build_ln(n_in, name, lr=1.0): 82 | return nn.LayerNorm( 83 | normalized_shape=n_in, 84 | weight_attr=P.ParamAttr( 85 | name='%s_layer_norm_scale' % name if name is not None else None, 86 | initializer=nn.initializer.Constant(1.), 87 | learning_rate=lr), 88 | bias_attr=P.ParamAttr( 89 | name='%s_layer_norm_bias' % name if name is not None else None, 90 | initializer=nn.initializer.Constant(0.), learning_rate=lr), ) 91 | 92 | 93 | def append_name(name, postfix): 94 | if name is None: 95 | ret = None 96 | elif name == '': 97 | ret = postfix 98 | else: 99 | ret = '%s_%s' % (name, postfix) 100 | return ret 101 | 102 | 103 | class AttentionLayer(nn.Layer): 104 | def __init__(self, cfg, name=None, lr=1.0): 105 | super(AttentionLayer, self).__init__() 106 | initializer = nn.initializer.TruncatedNormal( 107 | std=cfg['initializer_range']) 108 | d_model = cfg['hidden_size'] 109 | n_head = cfg['num_attention_heads'] 110 | assert d_model % n_head == 0 111 | d_model_q = cfg.get('query_hidden_size_per_head', 112 | d_model // n_head) * n_head 113 | d_model_v = cfg.get('value_hidden_size_per_head', 114 | d_model // n_head) * n_head 115 | self.n_head = n_head 116 | self.d_key = d_model_q // n_head 117 | self.q = _build_linear(d_model, d_model_q, 118 | append_name(name, 'query_fc'), initializer, lr) 119 | self.k = _build_linear(d_model, d_model_q, 120 | append_name(name, 'key_fc'), initializer, lr) 121 | self.v = _build_linear(d_model, d_model_v, 122 | append_name(name, 'value_fc'), initializer, lr) 123 | self.o = _build_linear(d_model_v, d_model, 124 | append_name(name, 'output_fc'), initializer, lr) 125 | self.dropout = nn.Dropout(p=cfg['attention_probs_dropout_prob']) 126 | 127 | def forward(self, queries, keys, values, attn_bias, past_cache): 128 | assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3 129 | #bsz, q_len, q_dim = queries.shape 130 | #bsz, k_len, k_dim = keys.shape 131 | #bsz, v_len, v_dim = values.shape 132 | #assert k_len == v_len 133 | 134 | q = self.q(queries) 135 | k = self.k(keys) 136 | v = self.v(values) 137 | 138 | cache = (k, v) 139 | if past_cache is not None: 140 | cached_k, cached_v = past_cache 141 | k = P.concat([cached_k, k], 1) 142 | v = P.concat([cached_v, v], 1) 143 | 144 | q = q.reshape( 145 | [0, 0, self.n_head, q.shape[-1] // self.n_head]).transpose( 146 | [0, 2, 1, 3]) #[batch, head, seq, dim] 147 | k = k.reshape( 148 | [0, 0, self.n_head, k.shape[-1] // self.n_head]).transpose( 149 | [0, 2, 1, 3]) #[batch, head, seq, dim] 150 | v = v.reshape( 151 | [0, 0, self.n_head, v.shape[-1] // self.n_head]).transpose( 152 | [0, 2, 1, 3]) #[batch, head, seq, dim] 153 | 154 | q = q.scale(self.d_key**-0.5) 155 | score = q.matmul(k, transpose_y=True) 156 | if attn_bias is not None: 157 | score += attn_bias 158 | score = F.softmax(score) 159 | score = self.dropout(score) 160 | 161 | out = score.matmul(v).transpose([0, 2, 1, 3]) 162 | out = out.reshape([0, 0, out.shape[2] * out.shape[3]]) 163 | out = self.o(out) 164 | return out, cache 165 | 166 | 167 | class PositionwiseFeedForwardLayer(nn.Layer): 168 | def __init__(self, cfg, name=None): 169 | super(PositionwiseFeedForwardLayer, self).__init__() 170 | initializer = nn.initializer.TruncatedNormal( 171 | std=cfg['initializer_range']) 172 | d_model = cfg['hidden_size'] 173 | d_ffn = cfg.get('intermediate_size', 4 * d_model) 174 | self.act = ACT_DICT[cfg['hidden_act']]() 175 | self.i = 
_build_linear( 176 | d_model, 177 | d_ffn, 178 | append_name(name, 'fc_0'), 179 | initializer, cfg['ernie_lr']) 180 | self.o = _build_linear(d_ffn, d_model, 181 | append_name(name, 'fc_1'), initializer,cfg['ernie_lr']) 182 | prob = cfg.get('intermediate_dropout_prob', 0.) 183 | self.dropout = nn.Dropout(p=prob) 184 | 185 | def forward(self, inputs): 186 | hidden = self.act(self.i(inputs)) 187 | hidden = self.dropout(hidden) 188 | out = self.o(hidden) 189 | return out 190 | 191 | 192 | class ErnieBlock(nn.Layer): 193 | def __init__(self, cfg, name=None): 194 | super(ErnieBlock, self).__init__() 195 | d_model = cfg['hidden_size'] 196 | self.attn = AttentionLayer( 197 | cfg, name=append_name(name, 'multi_head_att')) 198 | self.ln1 = _build_ln(d_model, name=append_name(name, 'post_att'), lr=cfg['ernie_lr']) 199 | self.ffn = PositionwiseFeedForwardLayer( 200 | cfg, name=append_name(name, 'ffn')) 201 | self.ln2 = _build_ln(d_model, name=append_name(name, 'post_ffn'), lr=cfg['ernie_lr']) 202 | prob = cfg.get('intermediate_dropout_prob', cfg['hidden_dropout_prob']) 203 | self.dropout = nn.Dropout(p=prob) 204 | 205 | def forward(self, inputs, attn_bias=None, past_cache=None): 206 | attn_out, cache = self.attn( 207 | inputs, inputs, inputs, attn_bias, 208 | past_cache=past_cache) #self attn 209 | attn_out = self.dropout(attn_out) 210 | hidden = attn_out + inputs 211 | hidden = self.ln1(hidden) # dropout/ add/ norm 212 | 213 | ffn_out = self.ffn(hidden) 214 | ffn_out = self.dropout(ffn_out) 215 | hidden = ffn_out + hidden 216 | hidden = self.ln2(hidden) 217 | return hidden, cache 218 | 219 | 220 | class ErnieEncoderStack(nn.Layer): 221 | def __init__(self, cfg, name=None): 222 | super(ErnieEncoderStack, self).__init__() 223 | n_layers = cfg['num_hidden_layers'] 224 | self.block = nn.LayerList([ 225 | ErnieBlock(cfg, append_name(name, 'layer_%d' % i)) 226 | for i in range(n_layers) 227 | ]) 228 | 229 | def forward(self, inputs, attn_bias=None, past_cache=None): 230 | if past_cache is not None: 231 | assert isinstance( 232 | past_cache, tuple 233 | ), 'unknown type of `past_cache`, expect tuple or list. 
got %s' % repr( 234 | type(past_cache)) 235 | past_cache = list(zip(*past_cache)) 236 | else: 237 | past_cache = [None] * len(self.block) 238 | cache_list_k, cache_list_v, hidden_list = [], [], [inputs] 239 | 240 | for b, p in zip(self.block, past_cache): 241 | inputs, cache = b(inputs, attn_bias=attn_bias, past_cache=p) 242 | cache_k, cache_v = cache 243 | cache_list_k.append(cache_k) 244 | cache_list_v.append(cache_v) 245 | hidden_list.append(inputs) 246 | 247 | return inputs, hidden_list, (cache_list_k, cache_list_v) 248 | 249 | 250 | class PretrainedModel(object): 251 | bce = 'https://ernie-github.cdn.bcebos.com/' 252 | resource_map = { 253 | 'ernie-1.0': bce + 'model-ernie1.0.1.tar.gz', 254 | 'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz', 255 | 'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz', 256 | 'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz', 257 | 'ernie-gram-zh': bce + 'model-ernie-gram-zh.1.tar.gz', 258 | 'ernie-gram-en': bce + 'model-ernie-gram-en.1.tar.gz', 259 | } 260 | 261 | @classmethod 262 | def from_pretrained(cls, 263 | pretrain_dir_or_url, 264 | force_download=False, 265 | **kwargs): 266 | if not Path(pretrain_dir_or_url).exists() and str( 267 | pretrain_dir_or_url) in cls.resource_map: 268 | url = cls.resource_map[str(pretrain_dir_or_url)] 269 | log.info('get pretrain dir from %s' % url) 270 | pretrain_dir = _fetch_from_remote(url, force_download) 271 | else: 272 | log.info('pretrain dir %s not in %s, read from local' % 273 | (pretrain_dir_or_url, repr(cls.resource_map))) 274 | pretrain_dir = Path(pretrain_dir_or_url) 275 | 276 | if not pretrain_dir.exists(): 277 | raise ValueError('pretrain dir not found: %s, optional: %s' % (pretrain_dir, cls.resource_map.keys())) 278 | state_dict_path = pretrain_dir / 'saved_weights.pdparams' 279 | config_path = pretrain_dir / 'ernie_config.json' 280 | 281 | if not config_path.exists(): 282 | raise ValueError('config path not found: %s' % config_path) 283 | name_prefix = kwargs.pop('name', None) 284 | cfg_dict = dict(json.loads(config_path.open().read()), **kwargs) 285 | model = cls(cfg_dict, name=name_prefix) 286 | 287 | log.info('loading pretrained model from %s' % pretrain_dir) 288 | 289 | #param_path = pretrain_dir / 'params' 290 | #if os.path.exists(param_path): 291 | # raise NotImplementedError() 292 | # log.debug('load pretrained weight from program state') 293 | # F.io.load_program_state(param_path) #buggy in dygraph.gurad, push paddle to fix 294 | if state_dict_path.exists(): 295 | m = P.load(state_dict_path) 296 | for k, v in model.state_dict().items(): 297 | if k not in m: 298 | log.warn('param:%s not set in pretrained model, skip' % k) 299 | m[k] = v # FIXME: no need to do this in the future 300 | model.set_state_dict(m) 301 | else: 302 | raise ValueError('weight file not found in pretrain dir: %s' % 303 | pretrain_dir) 304 | return model 305 | 306 | 307 | class ErnieModel(nn.Layer, PretrainedModel): 308 | def __init__(self, cfg, name=None): 309 | """ 310 | Fundamental pretrained Ernie model 311 | """ 312 | log.debug('init ErnieModel with config: %s' % repr(cfg)) 313 | nn.Layer.__init__(self) 314 | d_model = cfg['hidden_size'] 315 | d_emb = cfg.get('emb_size', cfg['hidden_size']) 316 | d_vocab = cfg['vocab_size'] 317 | d_pos = cfg['max_position_embeddings'] 318 | d_sent = cfg.get("sent_type_vocab_size") or cfg['type_vocab_size'] 319 | self.d_rel_pos = cfg.get('rel_pos_size', None) 320 | max_seq_len = cfg.get("max_seq_len", 512) 321 | self.n_head = cfg['num_attention_heads'] 322 | 
self.return_additional_info = cfg.get('return_additional_info', False) 323 | initializer = nn.initializer.TruncatedNormal( 324 | std=cfg['initializer_range']) 325 | if self.d_rel_pos: 326 | self.rel_pos_bias = _get_rel_pos_bias(max_seq_len) 327 | 328 | self.ln = _build_ln(d_model, name=append_name(name, 'pre_encoder')) 329 | self.word_emb = nn.Embedding( 330 | d_vocab, 331 | d_emb, 332 | weight_attr=P.ParamAttr( 333 | name=append_name(name, 'word_embedding'), 334 | initializer=initializer)) 335 | self.pos_emb = nn.Embedding( 336 | d_pos, 337 | d_emb, 338 | weight_attr=P.ParamAttr( 339 | name=append_name(name, 'pos_embedding'), 340 | initializer=initializer)) 341 | self.sent_emb = nn.Embedding( 342 | d_sent, 343 | d_emb, 344 | weight_attr=P.ParamAttr( 345 | name=append_name(name, 'sent_embedding'), 346 | initializer=initializer)) 347 | if self.d_rel_pos: 348 | self.rel_pos_bias_emb = nn.Embedding( 349 | self.d_rel_pos, 350 | self.n_head, 351 | weight_attr=P.ParamAttr( 352 | name=append_name(name, 'rel_pos_embedding'), 353 | initializer=initializer)) 354 | prob = cfg['hidden_dropout_prob'] 355 | self.dropout = nn.Dropout(p=prob) 356 | 357 | self.encoder_stack = ErnieEncoderStack(cfg, 358 | append_name(name, 'encoder')) 359 | if cfg.get('has_pooler', True): 360 | self.pooler = _build_linear( 361 | cfg['hidden_size'], 362 | cfg['hidden_size'], 363 | append_name(name, 'pooled_fc'), 364 | initializer, ) 365 | else: 366 | self.pooler = None 367 | self.train() 368 | 369 | #FIXME:remove this 370 | def eval(self): 371 | if P.in_dynamic_mode(): 372 | super(ErnieModel, self).eval() 373 | self.training = False 374 | for l in self.sublayers(): 375 | l.training = False 376 | return self 377 | 378 | def train(self): 379 | if P.in_dynamic_mode(): 380 | super(ErnieModel, self).train() 381 | self.training = True 382 | for l in self.sublayers(): 383 | l.training = True 384 | return self 385 | 386 | def forward(self, 387 | src_ids, 388 | sent_ids=None, 389 | pos_ids=None, 390 | input_mask=None, 391 | attn_bias=None, 392 | past_cache=None, 393 | use_causal_mask=False): 394 | 395 | """ 396 | Args: 397 | src_ids (`Variable` of shape `[batch_size, seq_len]`): 398 | Indices of input sequence tokens in the vocabulary. 399 | sent_ids (optional, `Variable` of shape `[batch_size, seq_len]`): 400 | aka token_type_ids, Segment token indices to indicate first and second portions of the inputs. 401 | if None, assume all tokens come from `segment_a` 402 | pos_ids(optional, `Variable` of shape `[batch_size, seq_len]`): 403 | Indices of positions of each input sequence tokens in the position embeddings. 404 | input_mask(optional `Variable` of shape `[batch_size, seq_len]`): 405 | Mask to avoid performing attention on the padding token indices of the encoder input. 406 | attn_bias(optional, `Variable` of shape `[batch_size, seq_len, seq_len] or False`): 407 | 3D version of `input_mask`, if set, overrides `input_mask`; if set not False, will not apply attention mask 408 | past_cache(optional, tuple of two lists: cached key and cached value, 409 | each is a list of `Variable`s of shape `[batch_size, seq_len, hidden_size]`): 410 | cached key/value tensor that will be concated to generated key/value when performing self attention. 411 | if set, `attn_bias` should not be None. 
412 | Returns: 413 | pooled (`Variable` of shape `[batch_size, hidden_size]`): 414 | output logits of pooler classifier 415 | encoded(`Variable` of shape `[batch_size, seq_len, hidden_size]`): 416 | output logits of transformer stack 417 | info (Dictionary): 418 | addtional middle level info, inclues: all hidden stats, k/v caches. 419 | """ 420 | assert len( 421 | src_ids. 422 | shape) == 2, 'expect src_ids.shape = [batch, sequecen], got %s' % ( 423 | repr(src_ids.shape)) 424 | assert attn_bias is not None if past_cache else True, 'if `past_cache` is specified; attn_bias should not be None' 425 | d_seqlen = P.shape(src_ids)[1] 426 | if pos_ids is None: 427 | pos_ids = P.arange( 428 | 0, d_seqlen, 1, dtype='int32').reshape([1, -1]).cast('int64') 429 | if attn_bias is None: 430 | if input_mask is None: 431 | input_mask = P.cast(src_ids != 0, 'float32') 432 | assert len(input_mask.shape) == 2 433 | input_mask = input_mask.unsqueeze(-1) 434 | attn_bias = input_mask.matmul(input_mask, transpose_y=True) 435 | if use_causal_mask: 436 | sequence = P.reshape( 437 | P.arange( 438 | 0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1]) 439 | causal_mask = (sequence.matmul( 440 | 1. / sequence, transpose_y=True) >= 1.).cast('float32') 441 | attn_bias *= causal_mask 442 | else: 443 | assert len( 444 | attn_bias.shape 445 | ) == 3, 'expect attn_bias tobe rank 3, got %r' % attn_bias.shape 446 | attn_bias = (1. - attn_bias) * -10000.0 447 | attn_bias = attn_bias.unsqueeze(1).tile( 448 | [1, self.n_head, 1, 1]) # avoid broadcast =_= 449 | attn_bias.stop_gradient=True 450 | if sent_ids is None: 451 | sent_ids = P.zeros_like(src_ids) 452 | if self.d_rel_pos: 453 | rel_pos_ids = self.rel_pos_bias[:d_seqlen, :d_seqlen] 454 | rel_pos_ids = P.to_tensor(rel_pos_ids, dtype='int64') 455 | rel_pos_bias = self.rel_pos_bias_emb(rel_pos_ids).transpose([2, 0, 1]) 456 | attn_bias += rel_pos_bias 457 | src_embedded = self.word_emb(src_ids) 458 | pos_embedded = self.pos_emb(pos_ids) 459 | sent_embedded = self.sent_emb(sent_ids) 460 | embedded = src_embedded + pos_embedded + sent_embedded 461 | 462 | 463 | embedded = self.dropout(self.ln(embedded)) 464 | 465 | encoded, hidden_list, cache_list = self.encoder_stack( 466 | embedded, attn_bias, past_cache=past_cache) 467 | if self.pooler is not None: 468 | pooled = F.tanh(self.pooler(encoded[:, 0, :])) 469 | else: 470 | pooled = None 471 | 472 | additional_info = { 473 | 'hiddens': hidden_list, 474 | 'caches': cache_list, 475 | } 476 | 477 | if self.return_additional_info: 478 | return pooled, encoded, additional_info 479 | return pooled, encoded 480 | 481 | 482 | class ErnieModelForSequenceClassification(ErnieModel): 483 | """ 484 | Ernie Model for text classfication or pointwise ranking tasks 485 | """ 486 | 487 | def __init__(self, cfg, name=None): 488 | super(ErnieModelForSequenceClassification, self).__init__( 489 | cfg, name=name) 490 | 491 | initializer = nn.initializer.TruncatedNormal( 492 | std=cfg['initializer_range']) 493 | self.classifier = _build_linear(cfg['hidden_size'], cfg['num_labels'], 494 | append_name(name, 'cls'), initializer) 495 | 496 | prob = cfg.get('classifier_dropout_prob', cfg['hidden_dropout_prob']) 497 | self.dropout = nn.Dropout(p=prob) 498 | self.train() 499 | 500 | @add_docstring(ErnieModel.forward.__doc__) 501 | def forward(self, *args, **kwargs): 502 | """ 503 | Args: 504 | labels (optional, `Variable` of shape [batch_size]): 505 | ground truth label id for each sentence 506 | Returns: 507 | loss (`Variable` of shape []): 508 | Cross entropy 
loss mean over batch 509 | if labels not set, returns None 510 | logits (`Variable` of shape [batch_size, num_labels]): 511 | output logits of classifier 512 | """ 513 | labels = kwargs.pop('labels', None) 514 | pooled, encoded = super(ErnieModelForSequenceClassification, 515 | self).forward(*args, **kwargs) 516 | hidden = self.dropout(pooled) 517 | logits = self.classifier(hidden) 518 | 519 | if labels is not None: 520 | if len(labels.shape) != 1: 521 | labels = labels.squeeze() 522 | loss = F.cross_entropy(logits, labels) 523 | else: 524 | loss = None 525 | return loss, logits 526 | 527 | 528 | class ErnieModelForTokenClassification(ErnieModel): 529 | """ 530 | Ernie Model for named entity recognition (NER) tasks 531 | """ 532 | 533 | def __init__(self, cfg, name=None): 534 | super(ErnieModelForTokenClassification, self).__init__(cfg, name=name) 535 | 536 | initializer = nn.initializer.TruncatedNormal( 537 | std=cfg['initializer_range']) 538 | self.classifier = _build_linear(cfg['hidden_size'], cfg['num_labels'], 539 | append_name(name, 'cls'), initializer) 540 | 541 | prob = cfg.get('classifier_dropout_prob', cfg['hidden_dropout_prob']) 542 | self.dropout = nn.Dropout(p=prob) 543 | self.train() 544 | 545 | @add_docstring(ErnieModel.forward.__doc__) 546 | def forward(self, *args, **kwargs): 547 | """ 548 | Args: 549 | labels (optional, `Variable` of shape [batch_size, seq_len]): 550 | ground truth label id for each token 551 | loss_weights (optional, `Variable` of shape [batch_size, seq_len]): 552 | weights of the loss for each token 553 | ignore_index (int): 554 | when label == `ignore_index`, this token will not contribute to loss 555 | Returns: 556 | loss (`Variable` of shape []): 557 | Cross entropy loss mean over batch and time, ignoring positions where label == `ignore_index` 558 | if labels not set, returns None 559 | logits (`Variable` of shape [batch_size, seq_len, num_labels]): 560 | output logits of classifier 561 | """ 562 | ignore_index = kwargs.pop('ignore_index', -100) 563 | labels = kwargs.pop('labels', None) 564 | loss_weights = kwargs.pop('loss_weights', None) 565 | pooled, encoded = super(ErnieModelForTokenClassification, 566 | self).forward(*args, **kwargs) 567 | hidden = self.dropout(encoded) # maybe not?
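# encoded is [batch_size, seq_len, hidden_size]; the classifier below yields per-token logits of shape [batch_size, seq_len, num_labels], and the unreduced cross entropy is averaged after the optional loss_weights scaling.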
568 | logits = self.classifier(hidden) 569 | 570 | if labels is not None: 571 | if len(labels.shape) != 2: 572 | labels = labels.squeeze() 573 | loss = F.cross_entropy( 574 | logits, labels, ignore_index=ignore_index, reduction='none') 575 | if loss_weights is not None: 576 | loss = loss * loss_weights 577 | loss = loss.mean() 578 | else: 579 | loss = None 580 | return loss, logits 581 | 582 | 583 | class ErnieModelForQuestionAnswering(ErnieModel): 584 | """ 585 | Ernie model for reading comprehension tasks (SQuAD) 586 | """ 587 | 588 | def __init__(self, cfg, name=None): 589 | super(ErnieModelForQuestionAnswering, self).__init__(cfg, name=name) 590 | 591 | initializer = nn.initializer.TruncatedNormal( 592 | std=cfg['initializer_range']) 593 | self.classifier = _build_linear(cfg['hidden_size'], 2, 594 | append_name(name, 'cls_mrc'), 595 | initializer) 596 | 597 | prob = cfg.get('classifier_dropout_prob', cfg['hidden_dropout_prob']) 598 | self.dropout = nn.Dropout(p=prob) 599 | self.train() 600 | 601 | @add_docstring(ErnieModel.forward.__doc__) 602 | def forward(self, *args, **kwargs): 603 | """ 604 | Args: 605 | start_pos (optional, `Variable` of shape [batch_size]): 606 | token index of start of answer span in `context` 607 | end_pos (optional, `Variable` of shape [batch_size]): 608 | token index of end of answer span in `context` 609 | Returns: 610 | loss (`Variable` of shape []): 611 | Cross entropy loss mean over batch and time, ignore positions where label == -100 612 | if labels not set, returns None 613 | start_logits (`Variable` of shape [batch_size, hidden_size]): 614 | output logits of start position, use argmax(start_logit) to get start index 615 | end_logits (`Variable` of shape [batch_size, hidden_size]): 616 | output logits of end position, use argmax(end_logit) to get end index 617 | """ 618 | 619 | start_pos = kwargs.pop('start_pos', None) 620 | end_pos = kwargs.pop('end_pos', None) 621 | pooled, encoded = super(ErnieModelForQuestionAnswering, self).forward( 622 | *args, **kwargs) 623 | encoded = self.dropout(encoded) 624 | encoded = self.classifier(encoded) 625 | start_logit, end_logits = P.unstack(encoded, axis=-1) 626 | if start_pos is not None and end_pos is not None: 627 | if len(start_pos.shape) != 1: 628 | start_pos = start_pos.squeeze() 629 | if len(end_pos.shape) != 1: 630 | end_pos = end_pos.squeeze() 631 | start_loss = F.cross_entropy(start_logit, start_pos) 632 | end_loss = F.cross_entropy(end_logits, end_pos) 633 | loss = (start_loss.mean() + end_loss.mean()) / 2. 
634 | else: 635 | loss = None 636 | return loss, start_logit, end_logits 637 | 638 | 639 | class NSPHead(nn.Layer): 640 | def __init__(self, cfg, name=None): 641 | super(NSPHead, self).__init__() 642 | initializer = nn.initializer.TruncatedNormal( 643 | std=cfg['initializer_range']) 644 | self.nsp = _build_linear(cfg['hidden_size'], 2, 645 | append_name(name, 'nsp_fc'), initializer) 646 | 647 | def forward(self, inputs, labels): 648 | """ 649 | Args: 650 | start_pos (optional, `Variable` of shape [batch_size]): 651 | token index of start of answer span in `context` 652 | end_pos (optional, `Variable` of shape [batch_size]): 653 | token index of end of answer span in `context` 654 | Returns: 655 | loss (`Variable` of shape []): 656 | Cross entropy loss mean over batch and time, ignore positions where label == -100 657 | if labels not set, returns None 658 | start_logits (`Variable` of shape [batch_size, hidden_size]): 659 | output logits of start position 660 | end_logits (`Variable` of shape [batch_size, hidden_size]): 661 | output logits of end position 662 | """ 663 | 664 | logits = self.nsp(inputs) 665 | loss = F.cross_entropy(logits, labels) 666 | return loss 667 | 668 | 669 | class ErnieModelForPretraining(ErnieModel): 670 | """ 671 | Ernie Model for Masked Languate Model pretrain 672 | """ 673 | 674 | def __init__(self, cfg, name=None): 675 | super(ErnieModelForPretraining, self).__init__(cfg, name=name) 676 | initializer = nn.initializer.TruncatedNormal( 677 | std=cfg['initializer_range']) 678 | d_model = cfg['hidden_size'] 679 | d_vocab = cfg['vocab_size'] 680 | 681 | self.pooler_heads = nn.LayerList([NSPHead(cfg, name=name)]) 682 | self.mlm = _build_linear( 683 | d_model, 684 | d_model, 685 | append_name(name, 'mask_lm_trans_fc'), 686 | initializer, ) 687 | self.act = ACT_DICT[cfg['hidden_act']]() 688 | self.mlm_ln = _build_ln( 689 | d_model, name=append_name(name, 'mask_lm_trans')) 690 | self.mlm_bias = P.create_parameter( 691 | dtype='float32', 692 | shape=[d_vocab], 693 | attr=P.ParamAttr( 694 | name=append_name(name, 'mask_lm_out_fc.b_0'), 695 | initializer=nn.initializer.Constant(value=0.0)), 696 | is_bias=True, ) 697 | self.train() 698 | 699 | @add_docstring(ErnieModel.forward.__doc__) 700 | def forward(self, *args, **kwargs): 701 | """ 702 | Args: 703 | nsp_labels (optional, `Variable` of shape [batch_size]): 704 | labels for `next sentence prediction` tasks 705 | mlm_pos (optional, `Variable` of shape [n_mask, 2]): 706 | index of mask_id in `src_ids`, can be obtained from `fluid.layers.where(src_ids==mask_id)` 707 | labels (optional, `Variable` of shape [n_mask]): 708 | labels for `mask language model` tasks, the original token indices in masked position in `src_ids` 709 | Returns: 710 | loss (`Variable` of shape []): 711 | total_loss of `next sentence prediction` and `masked language model` 712 | mlm_loss (`Variable` of shape []): 713 | loss for `masked language model` task 714 | nsp_loss (`Variable` of shape []): 715 | loss for `next sentence prediction` task 716 | """ 717 | 718 | mlm_labels = kwargs.pop('labels') 719 | mlm_pos = kwargs.pop('mlm_pos') 720 | nsp_labels = kwargs.pop('nsp_labels') 721 | pooled, encoded = super(ErnieModelForPretraining, self).forward( 722 | *args, **kwargs) 723 | if len(mlm_labels.shape) != 1: 724 | mlm_labels = mlm_labels.squeeze() 725 | if len(nsp_labels.shape) == 1: 726 | nsp_labels = nsp_labels.squeeze() 727 | 728 | nsp_loss = self.pooler_heads[0](pooled, nsp_labels) 729 | 730 | encoded_2d = encoded.gather_nd(mlm_pos) 731 | encoded_2d = 
self.act(self.mlm(encoded_2d)) 732 | encoded_2d = self.mlm_ln(encoded_2d) 733 | logits_2d = encoded_2d.matmul( 734 | self.word_emb.weight, transpose_y=True) + self.mlm_bias 735 | mlm_loss = F.cross_entropy(logits_2d, mlm_labels) 736 | total_loss = mlm_loss + nsp_loss 737 | return total_loss, mlm_loss, nsp_loss 738 | 739 | 740 | class ErnieModelForGeneration(ErnieModel): 741 | """ 742 | Ernie Model for sequence to sequence generation. 743 | """ 744 | resource_map = { 745 | 'ernie-gen-base-en': 746 | ErnieModel.bce + 'model-ernie-gen-base-en.1.tar.gz', 747 | 'ernie-gen-large-en': 748 | ErnieModel.bce + 'model-ernie-gen-large-en.1.tar.gz', 749 | 'ernie-gen-large-430g-en': 750 | ErnieModel.bce + 'model-ernie-gen-large-430g-en.1.tar.gz', 751 | 'ernie-1.0': ErnieModel.bce + 'model-ernie1.0.1.tar.gz', 752 | } 753 | 754 | def __init__(self, cfg, name=None): 755 | cfg['return_additional_info'] = True 756 | cfg['has_pooler'] = False 757 | super(ErnieModelForGeneration, self).__init__(cfg, name=name) 758 | initializer = nn.initializer.TruncatedNormal( 759 | std=cfg['initializer_range']) 760 | d_model = cfg['hidden_size'] 761 | d_vocab = cfg['vocab_size'] 762 | 763 | self.mlm = _build_linear( 764 | d_model, 765 | d_model, 766 | append_name(name, 'mask_lm_trans_fc'), 767 | initializer, ) 768 | self.act = ACT_DICT[cfg['hidden_act']]() 769 | self.mlm_ln = _build_ln( 770 | d_model, name=append_name(name, 'mask_lm_trans')) 771 | self.mlm_bias = P.create_parameter( 772 | dtype='float32', 773 | shape=[d_vocab], 774 | attr=P.ParamAttr( 775 | name=append_name(name, 'mask_lm_out_fc.b_0'), 776 | initializer=nn.initializer.Constant(value=0.0)), 777 | is_bias=True, ) 778 | self.train() 779 | 780 | @add_docstring(ErnieModel.forward.__doc__) 781 | def forward(self, *args, **kwargs): 782 | """ 783 | Args 784 | tgt_labels(`Variable` of shape [batch_size, seqlen] or [batch, seqlen, vocab_size]): 785 | ground trouth target sequence id (hard label) or distribution (soft label) 786 | tgt_pos(`Variable` of shape [n_targets, 2]): 787 | index of tgt_labels in `src_ids`, can be obtained from `fluid.layers.where(src_ids==mask_id)` 788 | encoder_only(Bool): 789 | if set, will not return loss, logits_2d 790 | Returns: 791 | loss(`Variable` of shape []): 792 | cross entropy loss mean over every target label. if `encode_only`, returns None. 793 | logits(`Variable` of shape [n_targets, vocab_size]): 794 | logits for every targets. if `encode_only`, returns None. 
795 | info(Dictionary): see `ErnieModel` 796 | """ 797 | tgt_labels = kwargs.pop('tgt_labels', None) 798 | tgt_pos = kwargs.pop('tgt_pos', None) 799 | encode_only = kwargs.pop('encode_only', False) 800 | _, encoded, info = ErnieModel.forward(self, *args, **kwargs) 801 | if encode_only: 802 | return None, None, info 803 | if tgt_labels is None or tgt_pos is None: 804 | encoded = self.act(self.mlm(encoded)) 805 | encoded = self.mlm_ln(encoded) 806 | logits = encoded.matmul( 807 | self.word_emb.weight, transpose_y=True) + self.mlm_bias 808 | output_ids = logits.cast('float32').argmax(-1) 809 | return output_ids, logits, info 810 | else: 811 | encoded_2d = encoded.gather_nd(tgt_pos) 812 | encoded_2d = self.act(self.mlm(encoded_2d)) 813 | encoded_2d = self.mlm_ln(encoded_2d) 814 | logits_2d = encoded_2d.matmul( 815 | self.word_emb.weight, transpose_y=True) + self.mlm_bias 816 | assert len( 817 | tgt_labels.shape) == 2, 'expect 2d label, got %r' % tgt_labels 818 | 819 | loss = F.cross_entropy(logits_2d, tgt_labels, soft_label=True) 820 | return loss, logits_2d, info --------------------------------------------------------------------------------
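A minimal, hypothetical smoke test for the classes in ernie/modeling_ernie.py (not part of the repo): it sketches how ErnieModelForSequenceClassification could be instantiated from a local pretrain directory and run on dummy token ids. The directory name ./pretrained, the dummy ids, and the num_labels / ernie_lr overrides are assumptions; from_pretrained above expects ernie_config.json and saved_weights.pdparams inside that directory, and the encoder blocks read cfg['ernie_lr'] while the classification head reads cfg['num_labels'], neither of which appears in base/ernie_config.json, so both are supplied as keyword overrides here.

# Hypothetical usage sketch (not part of the repo); assumes ./pretrained holds
# ernie_config.json and saved_weights.pdparams as required by from_pretrained.
import numpy as np
import paddle as P
from modeling_ernie import ErnieModelForSequenceClassification

model = ErnieModelForSequenceClassification.from_pretrained(
    './pretrained',   # assumed local dir; a key of resource_map would be downloaded instead
    name='',
    num_labels=2,     # assumed override: read by the classification head
    ernie_lr=1.0,     # assumed override: read by ErnieBlock / PositionwiseFeedForwardLayer
)
model.eval()

# [batch_size=1, seq_len=5] dummy ids; id 0 would be treated as padding by forward().
src_ids = P.to_tensor(np.array([[1, 1023, 7592, 2009, 2]], dtype='int64'))
sent_ids = P.zeros_like(src_ids)

loss, logits = model(src_ids, sent_ids=sent_ids)   # loss is None when labels are omitted
print(logits.shape)                                # [1, 2]

The same pattern applies to the other task heads above; only the extra keyword arguments (labels, start_pos/end_pos, mlm_pos, and so on) and the returned tuple differ.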