├── .gitignore
├── README.md
├── README_zh.md
├── customization_cn.md
├── examples
│   ├── classification
│   │   ├── README.md
│   │   ├── download.py
│   │   ├── evaluate.py
│   │   └── run.py
│   ├── matching
│   │   ├── README.md
│   │   ├── download.py
│   │   ├── evaluate.py
│   │   ├── process.py
│   │   └── run.py
│   ├── mrc
│   │   ├── README.md
│   │   ├── download.py
│   │   ├── evaluate.py
│   │   └── run.py
│   ├── multi-task
│   │   ├── README.md
│   │   ├── download.py
│   │   ├── evaluate_intent.py
│   │   ├── evaluate_slot.py
│   │   ├── joint_predict.py
│   │   ├── predict_intent.py
│   │   ├── predict_slot.py
│   │   ├── process.py
│   │   └── run.py
│   ├── predict
│   │   ├── README.md
│   │   ├── download.py
│   │   ├── evaluate.py
│   │   └── run.py
│   ├── tagging
│   │   ├── README.md
│   │   ├── download.py
│   │   ├── evaluate.py
│   │   └── run.py
│   └── train_with_eval
│       ├── README.md
│       ├── download.py
│       ├── evaluate.py
│       └── run.py
├── paddlepalm.egg-info
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   ├── not-zip-safe
│   ├── requires.txt
│   └── top_level.txt
├── paddlepalm
│   ├── __init__.py
│   ├── _downloader.py
│   ├── backbone
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base_backbone.py
│   │   ├── bert.py
│   │   ├── ernie.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── transformer.py
│   ├── distribute
│   │   ├── __init__.py
│   │   └── reader.py
│   ├── downloader.py
│   ├── head
│   │   ├── __init__.py
│   │   ├── base_head.py
│   │   ├── cls.py
│   │   ├── match.py
│   │   ├── mlm.py
│   │   ├── mrc.py
│   │   └── ner.py
│   ├── lr_sched
│   │   ├── __init__.py
│   │   ├── base_schedualer.py
│   │   ├── slanted_triangular_schedualer.py
│   │   └── warmup_schedualer.py
│   ├── multihead_trainer.py
│   ├── optimizer
│   │   ├── __init__.py
│   │   ├── adam.py
│   │   └── base_optimizer.py
│   ├── reader
│   │   ├── __init__.py
│   │   ├── base_reader.py
│   │   ├── cls.py
│   │   ├── match.py
│   │   ├── mlm.py
│   │   ├── mrc.py
│   │   ├── seq_label.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── batching4bert.py
│   │       ├── batching4ernie.py
│   │       ├── mlm_batching.py
│   │       ├── mrqa_helper.py
│   │       └── reader4ernie.py
│   ├── tokenizer
│   │   ├── __init__.py
│   │   ├── bert_tokenizer.py
│   │   └── ernie_tokenizer.py
│   ├── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── basic_helper.py
│       ├── config_helper.py
│       ├── plot_helper.py
│       ├── print_helper.py
│       ├── reader_helper.py
│       ├── saver.py
│       └── textprocess_helper.py
├── setup.cfg
├── setup.py
└── test
    ├── test2
    │   ├── config.yaml
    │   ├── data
    │   │   ├── cls4mrqa
    │   │   │   ├── dev.tsv
    │   │   │   └── train.tsv
    │   │   ├── match4mrqa
    │   │   │   └── train.tsv
    │   │   ├── mlm4mrqa
    │   │   │   └── train.tsv
    │   │   └── mrqa
    │   │       ├── dev.json
    │   │       └── train.json
    │   ├── paddlepalm
    │   ├── run.py
    │   └── run.sh
    └── test3
        ├── config.yaml
        ├── data
        │   └── cls4mrqa
        │       ├── dev.tsv
        │       └── train.tsv
        ├── paddlepalm
        ├── run.py
        └── run.sh

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
paddlepalm.egg-info
data
__pycache__
*egg-info
pretrain_model
pretrain
output*
output_model
build
dist
paddle_palm.egg-info
mrqa_output
*.log
--------------------------------------------------------------------------------
/examples/classification/README.md:
--------------------------------------------------------------------------------
## Example 1: Classification

This example covers a sentiment analysis task. The following sections detail model preparation, dataset preparation, and how to run the task.

### Step 1: Prepare Pre-trained Model & Dataset

#### Pre-trained Model

The pre-trained model for this task is [ERNIE-v1-zh-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api).

Make sure you have downloaded the required pre-trained model to the current folder.
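If you prefer to script this step, the snippet below sketches the same fetch-and-extract flow as the `download.py` helpers in this repo. The archive URL is a placeholder (take the real link from the model card above); `run.py` expects the extracted files under `./pretrain/ERNIE-v1-zh-base`.

```python
# A sketch of scripting the model download, mirroring the download()/tarfile
# pattern used by the download.py scripts in this repo.
import os
import sys
import tarfile
import urllib

URLLIB = urllib
if sys.version_info >= (3, 0):
    import urllib.request
    URLLIB = urllib.request

MODEL_URL = "https://example.com/ERNIE-v1-zh-base.tar.gz"  # placeholder, see the model card
archive = "ERNIE-v1-zh-base.tar.gz"

URLLIB.urlretrieve(MODEL_URL, archive)    # fetch the archive
with tarfile.open(archive) as tar:
    tar.extractall("./pretrain")          # run.py expects ./pretrain/ERNIE-v1-zh-base
os.remove(archive)
```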
#### Dataset

This example uses [ChnSentiCorp](https://github.com/SophonPlus/ChineseNlpCorpus/tree/master/datasets/ChnSentiCorp_htl_all), a Chinese sentiment analysis dataset.

Download the dataset:

```shell
python download.py
```

If everything goes well, a folder named `data/` will be created with all the data files in it.

The training file must have two fields, `text_a` and `label`, stored in [tsv](https://en.wikipedia.org/wiki/Tab-separated_values) format. Here is an example:

```
label	text_a
0	当当网名不符实,订货多日不见送货,询问客服只会推托,只会要求用户再下订单。如此服务留不住顾客的。去别的网站买书服务更好。
0	XP的驱动不好找!我的17号提的货,现在就降价了100元,而且还送杀毒软件!
1	<荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道当年我听说这本书的时候花很长时间去图书馆找和借都没能如愿,所以这次一看到当当有,马上买了,红迷们也要记得备货哦!
```

### Step 2: Train & Predict

The code used to perform this task is in `run.py`. Once you have prepared the pre-trained model and the dataset required for the task, run:

```shell
python run.py
```

If you want to train on a specific GPU, or on multiple GPUs, use **`CUDA_VISIBLE_DEVICES`**, for example:

```shell
CUDA_VISIBLE_DEVICES=0,1 python run.py
```

Note: in multi-GPU mode, PaddlePALM automatically splits each batch across the visible cards. For example, if `batch_size` is set to 64 and 4 cards are visible to PaddlePALM, the batch size on each card is actually 64/4=16. If you want to change the `batch_size` or the number of GPUs used in the example, **make sure the configured batch_size is divisible by the number of cards.** A quick sanity check for this constraint is sketched below.
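```python
# A standalone sketch of the divisibility rule stated above; it assumes the
# cards are selected via CUDA_VISIBLE_DEVICES, as in the command shown.
import os

batch_size = 64
num_gpus = len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(","))

# PaddlePALM splits each batch evenly across the visible cards, so the
# global batch size must be divisible by the card count.
assert batch_size % num_gpus == 0, "batch_size must be divisible by the number of cards"
print("per-card batch size:", batch_size // num_gpus)   # 64 / 4 = 16 with 4 cards
```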
Some logs are shown below:

```
step 1/154 (epoch 0), loss: 5.512, speed: 0.51 steps/s
step 2/154 (epoch 0), loss: 2.595, speed: 3.36 steps/s
step 3/154 (epoch 0), loss: 1.798, speed: 3.48 steps/s
```

After the run, you can view the saved models in the `outputs/` folder and the predictions in the `outputs/predict` folder. Here are some examples of predictions:

```
{"index": 0, "logits": [-0.2014336884021759, 0.6799028515815735], "probs": [0.29290086030960083, 0.7070990800857544], "label": 1}
{"index": 1, "logits": [0.8593899011611938, -0.29743513464927673], "probs": [0.7607553601264954, 0.23924466967582703], "label": 0}
{"index": 2, "logits": [0.7462944388389587, -0.7083730101585388], "probs": [0.8107157349586487, 0.18928426504135132], "label": 0}
```

### Step 3: Evaluate

Once you have the predictions, you can run the evaluation script to evaluate the model:

```shell
python evaluate.py
```

The evaluation results are as follows:

```
data num: 1200
accuracy: 0.9575, precision: 0.9634, recall: 0.9523, f1: 0.9578
```
--------------------------------------------------------------------------------
/examples/classification/download.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import tarfile
import shutil
import sys
import urllib

URLLIB = urllib
if sys.version_info >= (3, 0):
    import urllib.request
    URLLIB = urllib.request


def download(src, url):
    def _reporthook(count, chunk_size, total_size):
        bytes_so_far = count * chunk_size
        percent = float(bytes_so_far) / float(total_size)
        if percent > 1:
            percent = 1
        print('\r>> Downloading... {:.1%}'.format(percent), end="")

    URLLIB.urlretrieve(url, src, reporthook=_reporthook)


abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
download_path = os.path.join(os.path.dirname(abs_path), "task_data_zh.tgz")
target_dir = os.path.dirname(abs_path)
download(download_path, download_url)

tar = tarfile.open(download_path)
tar.extractall(target_dir)
os.remove(download_path)

dst_dir = os.path.join(target_dir, "data")
if not os.path.isdir(dst_dir):
    os.makedirs(dst_dir)

# keep only the ChnSentiCorp files from the downloaded archive
for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')):
    shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir)

shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
--------------------------------------------------------------------------------
/examples/classification/evaluate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import json
import numpy as np


def accuracy(preds, labels):
    preds = np.array(preds)
    labels = np.array(labels)
    return (preds == labels).mean()


def pre_recall_f1(preds, labels):
    preds = np.array(preds)
    labels = np.array(labels)
    # recall = TP / (TP + FN)
    tp = np.sum((labels == '1') & (preds == '1'))
    fp = np.sum((labels == '0') & (preds == '1'))
    fn = np.sum((labels == '1') & (preds == '0'))
    r = tp * 1.0 / (tp + fn)
    # precision = TP / (TP + FP)
    p = tp * 1.0 / (tp + fp)
    epsilon = 1e-31
    f1 = 2 * p * r / (p + r + epsilon)
    return p, r, f1


def res_evaluate(res_dir="./outputs/predict/predictions.json", eval_phase='test'):
    if eval_phase == 'test':
        data_dir = "./data/test.tsv"
    elif eval_phase == 'dev':
        data_dir = "./data/dev.tsv"
    else:
        assert eval_phase in ['dev', 'test'], 'eval_phase should be dev or test'

    labels = []
    with open(data_dir, "r") as file:
        for line in file:
            line = line.split("\t")
            label = line[0]
            if label == 'label':   # skip the tsv header
                continue
            labels.append(str(label))

    preds = []
    with open(res_dir, "r") as file:
        for line in file.readlines():
            line = json.loads(line)
            pred = line['label']
            preds.append(str(pred))

    assert len(labels) == len(preds), "prediction result doesn't match to labels"
    print('data num: {}'.format(len(labels)))
    p, r, f1 = pre_recall_f1(preds, labels)
    print("accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(accuracy(preds, labels), p, r, f1))


res_evaluate()
--------------------------------------------------------------------------------
/examples/classification/run.py:
--------------------------------------------------------------------------------
# coding=utf-8
import paddlepalm as palm
import json


if __name__ == '__main__':

    # configs
    max_seqlen = 256
    batch_size = 8
    num_epochs = 10
    lr = 5e-5
    weight_decay = 0.01
    vocab_path = './pretrain/ERNIE-v1-zh-base/vocab.txt'

    train_file = './data/train.tsv'
    predict_file = './data/test.tsv'
    config = json.load(open('./pretrain/ERNIE-v1-zh-base/ernie_config.json'))
    input_dim = config['hidden_size']
    num_classes = 2
    dropout_prob = 0.1
    random_seed = 1
    task_name = 'chnsenticorp'
    save_path = './outputs/'
    pred_output = './outputs/predict/'
    save_type = 'ckpt'
    print_steps = 20
    pre_params = './pretrain/ERNIE-v1-zh-base/params'

    # ----------------------- for training -----------------------

    # step 1-1: create readers for training
    cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed)
    # step 1-2: load the training data
    cls_reader.load_data(train_file, batch_size, num_epochs=num_epochs)

    # step 2: create a backbone of the model to extract text features
    ernie = palm.backbone.ERNIE.from_config(config)

    # step 3: register the backbone in reader
    cls_reader.register_with(ernie)

    # step 4: create the task output head
    cls_head = palm.head.Classify(num_classes, input_dim, dropout_prob)

    # step 5-1: create a task trainer
    trainer = palm.Trainer(task_name)
    # step 5-2: build forward graph with backbone and task head
    loss_var = trainer.build_forward(ernie, cls_head)

    # step 6-1*: use warmup
    n_steps = cls_reader.num_examples * num_epochs // batch_size
    warmup_steps = int(0.1 * n_steps)
    sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
    # step 6-2: create an optimizer
    adam = palm.optimizer.Adam(loss_var, lr, sched)
    # step 6-3: build backward
    trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

    # step 7: fit prepared reader and data
    trainer.fit_reader(cls_reader)

    # step 8-1*: load pretrained parameters
    trainer.load_pretrain(pre_params)
    # step 8-2*: set saver to save model
    # save_steps = n_steps
    save_steps = 2396
    trainer.set_saver(save_steps=save_steps, save_path=save_path, save_type=save_type)
    # step 8-3: start training
    trainer.train(print_steps=print_steps)

    # ----------------------- for prediction -----------------------
    # step 1-1: create readers for prediction
    print('prepare to predict...')
    predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed, phase='predict')
    # step 1-2: load the prediction data
    predict_cls_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_cls_reader.register_with(pred_ernie)

    # step 4: create the task output head
    cls_pred_head = palm.head.Classify(num_classes, input_dim, phase='predict')

    # step 5: build forward graph with backbone and task head
    trainer.build_predict_forward(pred_ernie, cls_pred_head)

    # step 6: load checkpoint
    # model_path = './outputs/ckpt.step' + str(save_steps)
    model_path = './outputs/ckpt.step' + str(11980)
    trainer.load_ckpt(model_path)

    # step 7: fit prepared reader and data
    trainer.fit_reader(predict_cls_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer.predict(print_steps=print_steps, output_dir=pred_output)
--------------------------------------------------------------------------------
/examples/matching/README.md:
--------------------------------------------------------------------------------
## Example 2: Matching

This example covers a sentence pair matching task. The following sections detail model preparation, dataset preparation, and how to run the task with PaddlePALM.

### Step 1: Prepare Pre-trained Models & Datasets

#### Download Pre-trained Model

The pre-trained model for this task is [ERNIE-v2-en-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api).

Make sure you have downloaded the required pre-trained model to the current folder.


#### Dataset

This example takes the [Quora Question Pairs](https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs) dataset as the testbed for matching.

Download the dataset:

```shell
python download.py
```

After the dataset is downloaded, convert the data format for training:

```shell
python process.py data/quora_duplicate_questions.tsv data/train.tsv data/test.tsv
```

If everything goes well, a folder named `data/` will be created with all the converted data in it.

The training file must have three fields, `text_a`, `text_b` and `label`, stored in [tsv](https://en.wikipedia.org/wiki/Tab-separated_values) format. Here is an example:

```
text_a	text_b	label
How can the arrangement of corynebacterium xerosis be described?	How would you describe waves?	0
How do you fix a Google Play Store account that isn't working?	What can cause the Google Play store to not open? How are such probelms fixed?	1
Which is the best earphone under 1000?	What are the best earphones under 1k?	1
What are the differences between the Dell Inspiron 3000, 5000, and 7000 series laptops?	"Should I buy an Apple MacBook Pro 15"" or a Dell Inspiron 17 5000 series?"	0
```

A quick way to sanity-check the converted files before training is sketched below.
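```python
# A standalone sketch (not part of run.py) that only validates the
# three-column layout shown above; the file path assumes you ran process.py
# with the arguments from the command above.
import csv

with open("data/train.tsv", encoding="utf-8") as f:
    reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
    assert reader.fieldnames == ["text_a", "text_b", "label"]
    for i, row in enumerate(reader):
        assert row["label"] in ("0", "1"), "Quora Question Pairs labels are binary"
        if i >= 4:   # peek at the first few rows only
            break
print("train.tsv layout looks good")
```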
### Step 2: Train & Predict

The code used to perform this task is in `run.py`. Once you have prepared the pre-trained model and the dataset required for the task, run:

```shell
python run.py
```

If you want to train on a specific GPU, or on multiple GPUs, use **`CUDA_VISIBLE_DEVICES`**, for example:

```shell
CUDA_VISIBLE_DEVICES=0,1 python run.py
```

Note: in multi-GPU mode, PaddlePALM automatically splits each batch across the visible cards. For example, if `batch_size` is set to 64 and 4 cards are visible to PaddlePALM, the batch size on each card is actually 64/4=16. If you want to change the `batch_size` or the number of GPUs used in the example, **make sure the configured batch_size is divisible by the number of cards.**

Some logs are shown below:

```
step 20/49087 (epoch 0), loss: 1.079, speed: 3.48 steps/s
step 40/49087 (epoch 0), loss: 1.251, speed: 5.18 steps/s
step 60/49087 (epoch 0), loss: 1.193, speed: 5.04 steps/s
```

After the run, you can view the saved models in the `outputs/` folder and the predictions in the `outputs/predict` folder. Here are some examples of predictions:

```
{"index": 0, "logits": [-0.32688724994659424, -0.8568955063819885], "probs": [0.629485011100769, 0.3705149292945862], "label": 0}
{"index": 1, "logits": [-0.2735646963119507, -0.7983021140098572], "probs": [0.6282548904418945, 0.37174513936042786], "label": 0}
{"index": 2, "logits": [-0.3381381630897522, -0.8614270091056824], "probs": [0.6279165148735046, 0.37208351492881775], "label": 0}
```

### Step 3: Evaluate

Once you have the predictions, you can run the evaluation script to evaluate the model:

```shell
python evaluate.py
```

The evaluation results are as follows:

```
data num: 4300
accuracy: 0.8619, precision: 0.8061, recall: 0.8377, f1: 0.8216
```
--------------------------------------------------------------------------------
/examples/matching/download.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import sys
import urllib

URLLIB = urllib
if sys.version_info >= (3, 0):
    import urllib.request
    URLLIB = urllib.request


def download(src, url):
    def _reporthook(count, chunk_size, total_size):
        bytes_so_far = count * chunk_size
        percent = float(bytes_so_far) / float(total_size)
        if percent > 1:
            percent = 1
        print('\r>> Downloading... {:.1%}'.format(percent), end="")

    URLLIB.urlretrieve(url, src, reporthook=_reporthook)


abs_path = os.path.abspath(__file__)
data_dir = os.path.join(os.path.dirname(abs_path), "data")
if not os.path.isdir(data_dir):
    os.makedirs(data_dir)

download_url = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv"
download_path = os.path.join(data_dir, "quora_duplicate_questions.tsv")
download(download_path, download_url)
print(" done!")
--------------------------------------------------------------------------------
/examples/matching/evaluate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import json
import numpy as np


def accuracy(preds, labels):
    preds = np.array(preds)
    labels = np.array(labels)
    return (preds == labels).mean()


def pre_recall_f1(preds, labels):
    preds = np.array(preds)
    labels = np.array(labels)
    # recall = TP / (TP + FN)
    tp = np.sum((labels == '1') & (preds == '1'))
    fp = np.sum((labels == '0') & (preds == '1'))
    fn = np.sum((labels == '1') & (preds == '0'))
    r = tp * 1.0 / (tp + fn)
    # precision = TP / (TP + FP)
    p = tp * 1.0 / (tp + fp)
    epsilon = 1e-31
    f1 = 2 * p * r / (p + r + epsilon)
    return p, r, f1


def res_evaluate(res_dir="./outputs/predict/predictions.json", eval_phase='test'):
    if eval_phase == 'test':
        data_dir = "./data/test.tsv"
    elif eval_phase == 'dev':
        data_dir = "./data/dev.tsv"
    else:
        assert eval_phase in ['dev', 'test'], 'eval_phase should be dev or test'

    labels = []
    with open(data_dir, "r") as file:
        for line in file:
            line = line.split("\t")
            label = line[2][:-1]   # third column, trailing newline stripped
            if label == 'label':   # skip the tsv header
                continue
            labels.append(str(label))

    preds = []
    with open(res_dir, "r") as file:
        for line in file.readlines():
            line = json.loads(line)
            pred = line['label']
            preds.append(str(pred))

    assert len(labels) == len(preds), "prediction result({}) doesn't match to labels({})".format(len(preds), len(labels))
    print('data num: {}'.format(len(labels)))
    p, r, f1 = pre_recall_f1(preds, labels)
    print("accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(accuracy(preds, labels), p, r, f1))


res_evaluate()
--------------------------------------------------------------------------------
/examples/matching/process.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import sys
import os

if len(sys.argv) != 4:
    print("usage: python process.py <quora_tsv> <train_tsv> <test_tsv>")
    exit(0)

data_dir = sys.argv[1]
if not os.path.exists(data_dir):
    print("%s not exists" % data_dir)
    exit(0)

train_dir = sys.argv[2]
train_file = open(train_dir, "w")
train_file.write("text_a\ttext_b\tlabel\n")

test_dir = sys.argv[3]
test_file = open(test_dir, "w")
test_file.write("text_a\ttext_b\tlabel\n")

with open(data_dir, "r") as file:
    cnt = 0
    flag = 0
    for line in file:
        line = line.strip("\n")
        line_t = line.split("\t")
        if len(line_t) < 6:
            # records that span multiple physical lines are stitched back together
            if flag:
                flag = 0
                out_line = "{}{}\n".format(out_line, line)
            else:
                flag = 1
                out_line = "{}".format(line)
            continue
        else:
            out_line = "{}\t{}\t{}\n".format(line_t[3], line_t[4], line_t[5])
            cnt += 1

        if 2 <= cnt <= 4301:
            test_file.write(out_line)
        if 4301 <= cnt <= 104301:
            train_file.write(out_line)

train_file.close()
test_file.close()
"{}\t{}\t{}\n".format(line_t[3], line_t[4], line_t[5]) 38 | cnt += 1 39 | 40 | if 2 <= cnt <= 4301: 41 | test_file.write(out_line) 42 | if 4301 <= cnt <= 104301: 43 | train_file.write(out_line) 44 | 45 | train_file.close() 46 | test_file.close() 47 | -------------------------------------------------------------------------------- /examples/matching/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import paddlepalm as palm 3 | import json 4 | 5 | if __name__ == '__main__': 6 | 7 | # configs 8 | max_seqlen = 128 9 | batch_size = 16 10 | num_epochs = 3 11 | lr = 3e-5 12 | weight_decay = 0.0 13 | num_classes = 2 14 | random_seed = 1 15 | dropout_prob = 0.1 16 | save_path = './outputs/' 17 | save_type = 'ckpt' 18 | pred_model_path = './outputs/ckpt.step'+str(18732) 19 | print_steps = 50 20 | pred_output = './outputs/predict/' 21 | pre_params = './pretrain/ERNIE-v2-en-base/params' 22 | task_name = 'Quora Question Pairs matching' 23 | 24 | vocab_path = './pretrain/ERNIE-v2-en-base/vocab.txt' 25 | train_file = './data/train.tsv' 26 | predict_file = './data/test.tsv' 27 | config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json')) 28 | input_dim = config['hidden_size'] 29 | 30 | # ----------------------- for training ----------------------- 31 | 32 | # step 1-1: create readers for training 33 | match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed) 34 | # step 1-2: load the training data 35 | match_reader.load_data(train_file, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size) 36 | 37 | # step 2: create a backbone of the model to extract text features 38 | ernie = palm.backbone.ERNIE.from_config(config) 39 | 40 | # step 3: register the backbone in reader 41 | match_reader.register_with(ernie) 42 | 43 | # step 4: create the task output head 44 | match_head = palm.head.Match(num_classes, input_dim, dropout_prob) 45 | 46 | # step 5-1: create a task trainer 47 | trainer = palm.Trainer(task_name) 48 | # step 5-2: build forward graph with backbone and task head 49 | loss_var = trainer.build_forward(ernie, match_head) 50 | 51 | # step 6-1*: use warmup 52 | n_steps = match_reader.num_examples * num_epochs // batch_size 53 | warmup_steps = int(0.1 * n_steps) 54 | print('total_steps: {}'.format(n_steps)) 55 | print('warmup_steps: {}'.format(warmup_steps)) 56 | sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) 57 | 58 | # step 6-2: create a optimizer 59 | adam = palm.optimizer.Adam(loss_var, lr, sched) 60 | # step 6-3: build backward 61 | trainer.build_backward(optimizer=adam, weight_decay=weight_decay) 62 | 63 | # step 7: fit prepared reader and data 64 | trainer.fit_reader(match_reader) 65 | 66 | # step 8-1*: load pretrained parameters 67 | trainer.load_pretrain(pre_params, False) 68 | # step 8-2*: set saver to save model 69 | # save_steps = n_steps-16 70 | save_steps = 6244 71 | trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type) 72 | # step 8-3: start training 73 | trainer.train(print_steps=print_steps) 74 | 75 | # ----------------------- for prediction ----------------------- 76 | 77 | # step 1-1: create readers for prediction 78 | print('prepare to predict...') 79 | predict_match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed, phase='predict') 80 | # step 1-2: load the training data 81 | predict_match_reader.load_data(predict_file, batch_size) 82 | 83 | # step 2: create a backbone of the model to extract text 
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_match_reader.register_with(pred_ernie)

    # step 4: create the task output head
    match_pred_head = palm.head.Match(num_classes, input_dim, phase='predict')

    # step 5: build forward graph with backbone and task head
    trainer.build_predict_forward(pred_ernie, match_pred_head)

    # step 6: load checkpoint
    trainer.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer.fit_reader(predict_match_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer.predict(print_steps=print_steps, output_dir=pred_output)
--------------------------------------------------------------------------------
/examples/mrc/README.md:
--------------------------------------------------------------------------------
## Example 4: Machine Reading Comprehension

This example covers a machine reading comprehension task. The following sections detail model preparation, dataset preparation, and how to run the task.

### Step 1: Prepare Pre-trained Models & Datasets

#### Pre-trained Model

The pre-trained model for this task is [ERNIE-v1-zh-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api).

Make sure you have downloaded the required pre-trained model to the current folder.


#### Dataset

This task uses the `CMRC2018` dataset, a span-extraction reading comprehension evaluation organized by the Chinese Information Processing Society of China: for each question, the answer is a span extracted from a given passage.

Download the dataset:

```shell
python download.py
```

If everything goes well, a folder named `data/` will be created with all the data in it.

Here is some example data:

```json
"paragraphs": [
    {
        "id": "TRAIN_36",
        "context": "NGC 6231是一个位于天蝎座的疏散星团,天球座标为赤经16时54分,赤纬-41度48分,视觉观测大小约45角分,亮度约2.6视星等,距地球5900光年。NGC 6231年龄约为三百二十万年,是一个非常年轻的星团,星团内的最亮星是5等的天蝎座 ζ1星。用双筒望远镜或小型望远镜就能看到个别的行星。NGC 6231在1654年被意大利天文学家乔瓦尼·巴蒂斯特·霍迪尔纳(Giovanni Battista Hodierna)以Luminosae的名字首次纪录在星表中,但是未见记载于夏尔·梅西耶的天体列表和威廉·赫歇尔的深空天体目录。这个天体在1678年被爱德蒙·哈雷(I.7)、1745年被夏西亚科斯(Jean-Phillippe Loys de Cheseaux)(9)、1751年被尼可拉·路易·拉卡伊(II.13)分别再次独立发现。",
        "qas": [
            {
                "question": "NGC 6231的经纬度是多少?",
                "id": "TRAIN_36_QUERY_0",
                "answers": [
                    {
                        "text": "赤经16时54分,赤纬-41度48分",
                        "answer_start": 27
                    }
                ]
            }
        ]
    }
]
```
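The nesting mirrors SQuAD. A small standalone sketch that walks `data/train.json` along the same `data -> paragraphs -> qas -> answers` path that `evaluate.py` uses, checking that `answer_start` is a character offset into the context:

```python
import json

with open("data/train.json", encoding="utf-8") as f:
    dataset = json.load(f)

for article in dataset["data"]:
    for paragraph in article["paragraphs"]:
        context = paragraph["context"]
        for qa in paragraph["qas"]:
            ans = qa["answers"][0]
            # answer_start indexes into the context string
            span = context[ans["answer_start"]:ans["answer_start"] + len(ans["text"])]
            if span != ans["text"]:
                print("offset mismatch for", qa["id"])
```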
### Step 2: Train & Predict

The code used to perform this task is in `run.py`. Once you have prepared the pre-trained model and the dataset required for the task, run:

```shell
python run.py
```

If you want to train on a specific GPU, or on multiple GPUs, use **`CUDA_VISIBLE_DEVICES`**, for example:

```shell
CUDA_VISIBLE_DEVICES=0,1 python run.py
```

Note: in multi-GPU mode, PaddlePALM automatically splits each batch across the visible cards. For example, if `batch_size` is set to 64 and 4 cards are visible to PaddlePALM, the batch size on each card is actually 64/4=16. If you want to change the `batch_size` or the number of GPUs used in the example, **make sure the configured batch_size is divisible by the number of cards.**

Some logs are shown below:

```
step 1/1515 (epoch 0), loss: 6.251, speed: 0.31 steps/s
step 2/1515 (epoch 0), loss: 6.206, speed: 0.80 steps/s
step 3/1515 (epoch 0), loss: 6.172, speed: 0.86 steps/s
```

After the run, you can view the saved models in the `outputs/` folder and the predictions in the `outputs/predict` folder. Here are some examples of predictions:

```json
{
    "DEV_0_QUERY_0": "光 荣 和 ω-force 开 发",
    "DEV_0_QUERY_1": "任 天 堂 游 戏 谜 之 村 雨 城",
    "DEV_0_QUERY_2": "战 史 演 武 」&「 争 霸 演 武 」。",
    "DEV_1_QUERY_0": "大 陆 传 统 器 乐 及 戏 曲 里 面 常 用 的 打 击 乐 记 谱 方 法 , 以 中 文 字 的 声 音 模 拟 敲 击 乐 的 声 音 , 纪 录 打 击 乐 的 各 种 不 同 的 演 奏 方 法 。",
    "DEV_1_QUERY_1": "「 锣 鼓 点",
    "DEV_1_QUERY_2": "锣 鼓 的 运 用 有 约 定 俗 成 的 程 式 , 依 照 角 色 行 当 的 身 份 、 性 格 、 情 绪 以 及 环 境 , 配 合 相 应 的 锣 鼓 点",
    "DEV_1_QUERY_3": "鼓 、 锣 、 钹 和 板 四 类 型",
    "DEV_2_QUERY_0": "364.6 公 里"
}
```

### Step 3: Evaluate

#### Library Dependencies

Before the evaluation, you need to install `nltk` and download the `punkt` tokenizer for nltk:

```shell
pip install nltk
python -m nltk.downloader punkt
```

#### Evaluate

You can run the evaluation script to evaluate the model:

```shell
python evaluate.py
```

The evaluation results are as follows:

```
data_num: 3219
em_score: 0.6434, f1: 0.8518
```
--------------------------------------------------------------------------------
/examples/mrc/download.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import tarfile
import shutil
import sys
import urllib

URLLIB = urllib
if sys.version_info >= (3, 0):
    import urllib.request
    URLLIB = urllib.request


def download(src, url):
    def _reporthook(count, chunk_size, total_size):
        bytes_so_far = count * chunk_size
        percent = float(bytes_so_far) / float(total_size)
        if percent > 1:
            percent = 1
        print('\r>> Downloading... {:.1%}'.format(percent), end="")

    URLLIB.urlretrieve(url, src, reporthook=_reporthook)


abs_path = os.path.abspath(__file__)
download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
download_path = os.path.join(os.path.dirname(abs_path), "task_data_zh.tgz")
target_dir = os.path.dirname(abs_path)
download(download_path, download_url)

tar = tarfile.open(download_path)
tar.extractall(target_dir)
os.remove(download_path)

dst_dir = os.path.join(target_dir, "data")
if not os.path.isdir(dst_dir):
    os.makedirs(dst_dir)

# keep only the CMRC2018 files from the downloaded archive
for file in os.listdir(os.path.join(target_dir, 'task_data', 'cmrc2018')):
    shutil.move(os.path.join(target_dir, 'task_data', 'cmrc2018', file), dst_dir)

shutil.rmtree(os.path.join(target_dir, 'task_data'))
print(" done!")
--------------------------------------------------------------------------------
/examples/mrc/evaluate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Evaluation script for CMRC 2018
version: v5
Note:
v5: formatted output, added usage description
v4: fixed segmentation issues
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import Counter, OrderedDict
import string
import re
import argparse
import json
import sys
import nltk


# split Chinese with English
def mixed_segmentation(in_str, rm_punc=False):
    in_str = in_str.lower().strip()
    segs_out = []
    temp_str = ""
    sp_char = [
        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
        '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
        ')', '-', '~', '『', '』', ' '
    ]
    for char in in_str:
        if rm_punc and char in sp_char:
            continue
        if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
            if temp_str != "":
                ss = nltk.word_tokenize(temp_str)
                segs_out.extend(ss)
                temp_str = ""
            segs_out.append(char)
        else:
            temp_str += char

    # handle the last part
    if temp_str != "":
        ss = nltk.word_tokenize(temp_str)
        segs_out.extend(ss)

    return segs_out


# remove punctuation
def remove_punctuation(in_str):
    in_str = in_str.lower().strip()
    sp_char = [
        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
        '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
        ')', '-', '~', '『', '』', ' '
    ]
    out_segs = []
    for char in in_str:
        if char in sp_char:
            continue
        else:
            out_segs.append(char)
    return ''.join(out_segs)


# find the longest common subsequence
def find_lcs(s1, s2):
    m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
    mmax = 0
    p = 0
    for i in range(len(s1)):
        for j in range(len(s2)):
            if s1[i] == s2[j]:
                m[i + 1][j + 1] = m[i][j] + 1
                if m[i + 1][j + 1] > mmax:
                    mmax = m[i + 1][j + 1]
                    p = i + 1
    return s1[p - mmax:p], mmax


def evaluate(ground_truth_file, prediction_file):
    f1 = 0
    em = 0
    total_count = 0
    skip_count = 0
    for instances in ground_truth_file["data"]:
        for instance in instances["paragraphs"]:
            context_text = instance['context'].strip()
            for qas in instance['qas']:
                total_count += 1
                query_id = qas['id'].strip()
                query_text = qas['question'].strip()
                answers = [ans["text"] for ans in qas["answers"]]

                if query_id not in prediction_file:
                    print('Unanswered question: {}\n'.format(query_id))
                    skip_count += 1
                    continue

                prediction = prediction_file[query_id]
                f1 += calc_f1_score(answers, prediction)
                em += calc_em_score(answers, prediction)

    f1_score = f1 / total_count
    em_score = em / total_count
    return f1_score, em_score, total_count, skip_count


def calc_f1_score(answers, prediction):
    f1_scores = []
    for ans in answers:
        ans_segs = mixed_segmentation(ans, rm_punc=True)
        prediction_segs = mixed_segmentation(prediction, rm_punc=True)
        lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
        if lcs_len == 0:
            f1_scores.append(0)
            continue
        precision = 1.0 * lcs_len / len(prediction_segs)
        recall = 1.0 * lcs_len / len(ans_segs)
        f1 = (2 * precision * recall) / (precision + recall)
        f1_scores.append(f1)
    return max(f1_scores)
def calc_em_score(answers, prediction):
    em = 0
    for ans in answers:
        ans_ = remove_punctuation(ans)
        prediction_ = remove_punctuation(prediction)
        if ans_ == prediction_:
            em = 1
            break
    return em


def eval_file(dataset_file, prediction_file):
    ground_truth_file = json.load(open(dataset_file, 'r'))
    prediction_file = json.load(open(prediction_file, 'r'))
    F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
    AVG = (EM + F1) * 0.5
    return EM, F1, AVG, TOTAL


if __name__ == '__main__':
    EM, F1, AVG, TOTAL = eval_file("data/dev.json", "outputs/predict/predictions.json")
    print('data_num: {}'.format(TOTAL))
    print('em_score: {:.4f}, f1: {:.4f}'.format(EM, F1))
--------------------------------------------------------------------------------
/examples/mrc/run.py:
--------------------------------------------------------------------------------
# coding=utf-8
import paddlepalm as palm
import json


if __name__ == '__main__':

    # configs
    max_seqlen = 512
    batch_size = 8
    num_epochs = 2
    lr = 3e-5
    doc_stride = 128
    max_query_len = 64
    max_ans_len = 128
    weight_decay = 0.01
    print_steps = 20
    vocab_path = './pretrain/ERNIE-v1-zh-base/vocab.txt'
    do_lower_case = True

    train_file = './data/train.json'
    predict_file = './data/dev.json'
    save_path = './outputs/'
    pred_output = './outputs/predict/'
    save_type = 'ckpt'
    task_name = 'cmrc2018'
    pre_params = './pretrain/ERNIE-v1-zh-base/params'
    config = json.load(open('./pretrain/ERNIE-v1-zh-base/ernie_config.json'))

    # ----------------------- for training -----------------------

    # step 1-1: create readers for training
    mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case)
    # step 1-2: load the training data
    mrc_reader.load_data(train_file, file_format='json', num_epochs=num_epochs, batch_size=batch_size)

    # step 2: create a backbone of the model to extract text features
    ernie = palm.backbone.ERNIE.from_config(config)

    # step 3: register the backbone in reader
    mrc_reader.register_with(ernie)

    # step 4: create the task output head
    mrc_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len)

    # step 5-1: create a task trainer
    trainer = palm.Trainer(task_name)
    # step 5-2: build forward graph with backbone and task head
    loss_var = trainer.build_forward(ernie, mrc_head)

    # step 6-1*: use warmup
    n_steps = mrc_reader.num_examples * num_epochs // batch_size
    warmup_steps = int(0.1 * n_steps)
    sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
    # step 6-2: create an optimizer
    adam = palm.optimizer.Adam(loss_var, lr, sched)
    # step 6-3: build backward
    trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

    # step 7: fit prepared reader and data
    trainer.fit_reader(mrc_reader)

    # step 8-1*: load pretrained parameters
    trainer.load_pretrain(pre_params)
    # step 8-2*: set saver to save model
    save_steps = 3040
    trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
    # step 8-3: start training
    trainer.train(print_steps=print_steps)

    # ----------------------- for prediction -----------------------

    # step 1-1: create readers for prediction
    predict_mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case, phase='predict')
    # step 1-2: load the prediction data
    predict_mrc_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_mrc_reader.register_with(pred_ernie)

    # step 4: create the task output head
    mrc_pred_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len, phase='predict')

    # step 5: build forward graph with backbone and task head
    trainer.build_predict_forward(pred_ernie, mrc_pred_head)

    # step 6: load checkpoint
    pred_model_path = './outputs/ckpt.step' + str(3040)
    trainer.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer.fit_reader(predict_mrc_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer.predict(print_steps=print_steps, output_dir="outputs/predict")
--------------------------------------------------------------------------------
/examples/multi-task/download.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import tarfile
import shutil
import sys
import urllib

URLLIB = urllib
if sys.version_info >= (3, 0):
    import urllib.request
    URLLIB = urllib.request


def download(src, url):
    def _reporthook(count, chunk_size, total_size):
        bytes_so_far = count * chunk_size
        percent = float(bytes_so_far) / float(total_size)
        if percent > 1:
            percent = 1
        print('\r>> Downloading... {:.1%}'.format(percent), end="")

    URLLIB.urlretrieve(url, src, reporthook=_reporthook)


abs_path = os.path.abspath(__file__)
download_url = "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz"
download_path = os.path.join(os.path.dirname(abs_path), "dmtk_data_1.0.0.tar.gz")
target_dir = os.path.dirname(abs_path)
download(download_path, download_url)

tar = tarfile.open(download_path)
tar.extractall(target_dir)
os.remove(download_path)

# keep only the ATIS portion of the archive
shutil.rmtree(os.path.join(target_dir, 'data/dstc2/'))
shutil.rmtree(os.path.join(target_dir, 'data/mrda/'))
shutil.rmtree(os.path.join(target_dir, 'data/multi-woz/'))
shutil.rmtree(os.path.join(target_dir, 'data/swda/'))
shutil.rmtree(os.path.join(target_dir, 'data/udc/'))
print(" done!")
--------------------------------------------------------------------------------
/examples/multi-task/evaluate_intent.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import json
import numpy as np


def accuracy(preds, labels):
    preds = np.array(preds)
    labels = np.array(labels)
    return (preds == labels).mean()


def pre_recall_f1(preds, labels):
    preds = np.array(preds)
    labels = np.array(labels)
    # recall = TP / (TP + FN)
    tp = np.sum((labels == '1') & (preds == '1'))
    fp = np.sum((labels == '0') & (preds == '1'))
    fn = np.sum((labels == '1') & (preds == '0'))
    r = tp * 1.0 / (tp + fn)
    # precision = TP / (TP + FP)
    p = tp * 1.0 / (tp + fp)
    epsilon = 1e-31
    f1 = 2 * p * r / (p + r + epsilon)
    return p, r, f1


def res_evaluate(res_dir="./outputs/predict-intent/predictions.json", eval_phase='test'):
    if eval_phase == 'test':
        data_dir = "./data/atis/atis_intent/test.tsv"
    elif eval_phase == 'dev':
        data_dir = "./data/atis/atis_intent/dev.tsv"
    else:
        assert eval_phase in ['dev', 'test'], 'eval_phase should be dev or test'

    labels = []
    with open(data_dir, "r") as file:
        for line in file:
            line = line.split("\t")
            label = line[0]
            if label == 'label':   # skip the tsv header
                continue
            labels.append(str(label))

    preds = []
    with open(res_dir, "r") as file:
        for line in file.readlines():
            line = json.loads(line)
            pred = line['label']
            preds.append(str(pred))

    assert len(labels) == len(preds), "prediction result doesn't match to labels"
    print('data num: {}'.format(len(labels)))
    p, r, f1 = pre_recall_f1(preds, labels)
    print("accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(accuracy(preds, labels), p, r, f1))


res_evaluate()
--------------------------------------------------------------------------------
/examples/multi-task/evaluate_slot.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import json


def load_label_map(map_dir="./data/atis/atis_slot/label_map.json"):
    """
    :param map_dir: path of the json file mapping tag names to ids
    :return: the label map as a dict
    """
    return json.load(open(map_dir, "r"))


def cal_chunk(pred_label, refer_label):
    # token-level micro-F1 over the flattened label sequences
    tp = dict()
    fn = dict()
    fp = dict()
    for i in range(len(refer_label)):
        if refer_label[i] == pred_label[i]:
            if refer_label[i] not in tp:
                tp[refer_label[i]] = 0
            tp[refer_label[i]] += 1
        else:
            if pred_label[i] not in fp:
                fp[pred_label[i]] = 0
            fp[pred_label[i]] += 1
            if refer_label[i] not in fn:
                fn[refer_label[i]] = 0
            fn[refer_label[i]] += 1

    tp_total = sum(tp.values())
    fn_total = sum(fn.values())
    fp_total = sum(fp.values())
    p_total = float(tp_total) / (tp_total + fp_total)
    r_total = float(tp_total) / (tp_total + fn_total)
    f_micro = 2 * p_total * r_total / (p_total + r_total)

    return f_micro


def res_evaluate(res_dir="./outputs/predict-slot/predictions.json", data_dir="./data/atis/atis_slot/test.tsv"):
    label_map = load_label_map()

    total_label = []
    with open(data_dir, "r") as file:
        first_flag = True
        for line in file:
            if first_flag:   # skip the tsv header
                first_flag = False
                continue
            line = line.strip("\n")
            if len(line) == 0:
                continue
            line = line.split("\t")
            if len(line) < 2:
                continue
            labels = line[1][:-1].split("\x02")   # labels are '\x02'-separated
            total_label.append(labels)
    total_label = [[label_map[j] for j in i] for i in total_label]

    total_res = []
    with open(res_dir, "r") as file:
        cnt = 0
        for line in file:
            line = line.strip("\n")
            if len(line) == 0:
                continue
            try:
                res_arr = json.loads(line)
                # align prediction and reference lengths
                if len(total_label[cnt]) < len(res_arr):
                    total_res.append(res_arr[1: 1 + len(total_label[cnt])])
                elif len(total_label[cnt]) == len(res_arr):
                    total_res.append(res_arr)
                else:
                    total_res.append(res_arr)
                    total_label[cnt] = total_label[cnt][: len(res_arr)]
            except:
                print("json format error: {}".format(cnt))
                print(line)
            cnt += 1

    total_res_equal = []
    total_label_equal = []
    assert len(total_label) == len(total_res), "prediction result doesn't match to labels"
    for i in range(len(total_label)):
        num = len(total_label[i])
        total_label_equal.extend(total_label[i])
        total_res[i] = total_res[i][:num]
        total_res_equal.extend(total_res[i])

    f1 = cal_chunk(total_res_equal, total_label_equal)
    print('data num: {}'.format(len(total_label)))
    print("f1: {:.4f}".format(f1))


res_evaluate()
--------------------------------------------------------------------------------
/examples/multi-task/joint_predict.py:
--------------------------------------------------------------------------------
# coding=utf-8
import paddlepalm as palm
import json


if __name__ == '__main__':

    # configs
    max_seqlen = 128
    batch_size = 128
    num_epochs = 20
    print_steps = 5
    lr = 2e-5
    num_classes = 130
    weight_decay = 0.01
    num_classes_intent = 26
    dropout_prob = 0.1
    random_seed = 0
    label_map = './data/atis/atis_slot/label_map.json'
    vocab_path = './pretrain/ERNIE-v2-en-base/vocab.txt'

    train_slot = './data/atis/atis_slot/train.tsv'
    train_intent = './data/atis/atis_intent/train.tsv'

    config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json'))
    input_dim = config['hidden_size']

    # ----------------------- for prediction -----------------------

    # step 1-1: create readers
    slot_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed, phase='predict')
    intent_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed, phase='predict')

    # step 1-2: load the data to predict on (the training split is reused here)
    slot_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
    intent_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None)

    # step 2: create a backbone of the model to extract text features
    ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register readers with ernie backbone
    slot_reader.register_with(ernie)
    intent_reader.register_with(ernie)

    # step 4: create task output heads
    slot_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob, phase='predict')
    intent_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob, phase='predict')

    # step 5-1: create task trainers and multiHeadTrainer
    trainer_slot = palm.Trainer("slot", mix_ratio=1.0)
    trainer_intent = palm.Trainer("intent", mix_ratio=1.0)
    trainer = palm.MultiHeadTrainer([trainer_slot, trainer_intent])
    # step 5-2: build forward graphs with backbone and task heads
    vars = trainer_intent.build_predict_forward(ernie, intent_head)
    vars = trainer_slot.build_predict_forward(ernie, slot_head)
    pred_vars = trainer.build_predict_forward()

    # load checkpoint
    trainer.load_ckpt('outputs/ckpt.step300')

    # merge inference readers
    joint_iterator = trainer.merge_inference_readers([slot_reader, intent_reader])

    # for test
    # batch = next(joint_iterator('slot'))
    # results = trainer.predict_one_batch('slot', batch)
    # batch = next(joint_iterator('intent'))
    # results = trainer.predict_one_batch('intent', batch)

    # predict slot filling
    print('processing slot filling examples...')
    print('num examples: ' + str(slot_reader.num_examples))
    cnt = 0
    for batch in joint_iterator('slot'):
        cnt += len(trainer.predict_one_batch('slot', batch)['logits'])
        if cnt % 1000 <= 128:
            print(str(cnt) + 'th example processed.')
    print(str(cnt) + 'th example processed.')

    # predict intent recognition
    print('processing intent recognition examples...')
    print('num examples: ' + str(intent_reader.num_examples))
    cnt = 0
    for batch in joint_iterator('intent'):
        cnt += len(trainer.predict_one_batch('intent', batch)['logits'])
        if cnt % 1000 <= 128:
            print(str(cnt) + 'th example processed.')
    print(str(cnt) + 'th example processed.')
--------------------------------------------------------------------------------
/examples/multi-task/predict_intent.py:
--------------------------------------------------------------------------------
# coding=utf-8
import paddlepalm as palm
import json


if __name__ == '__main__':

    # configs
    max_seqlen = 256
    batch_size = 16
    num_epochs = 6
    print_steps = 5
    num_classes = 26
    vocab_path = './pretrain/ERNIE-v2-en-base/vocab.txt'
    predict_file = './data/atis/atis_intent/test.tsv'
    save_path = './outputs/'
    pred_output = './outputs/predict-intent/'
    save_type = 'ckpt'
    random_seed = 0
    config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json'))
    input_dim = config['hidden_size']

    # ----------------------- for prediction -----------------------

    # step 1-1: create readers for prediction
    print('prepare to predict...')
    predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed, phase='predict')
    # step 1-2: load the prediction data
    predict_cls_reader.load_data(predict_file, batch_size)
    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_cls_reader.register_with(pred_ernie)

    # step 4: create the task output head
    cls_pred_head = palm.head.Classify(num_classes, input_dim, phase='predict')

    # step 5-1: create a task trainer
    trainer = palm.Trainer("intent")
    # step 5-2: build forward graph with backbone and task head
    trainer.build_predict_forward(pred_ernie, cls_pred_head)

    # step 6: load checkpoint
    pred_model_path = './outputs/ckpt.step4641'
    trainer.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer.fit_reader(predict_cls_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer.predict(print_steps=print_steps, output_dir=pred_output)
--------------------------------------------------------------------------------
/examples/multi-task/predict_slot.py:
--------------------------------------------------------------------------------
# coding=utf-8
import paddlepalm as palm
import json


if __name__ == '__main__':

    # configs
    max_seqlen = 256
    batch_size = 16
    num_epochs = 6
    print_steps = 5
    num_classes = 130
    label_map = './data/atis/atis_slot/label_map.json'
    vocab_path = './pretrain/ERNIE-v2-en-base/vocab.txt'
    predict_file = './data/atis/atis_slot/test.tsv'
    save_path = './outputs/'
    pred_output = './outputs/predict-slot/'
    save_type = 'ckpt'
    random_seed = 0
    config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json'))
    input_dim = config['hidden_size']

    # ----------------------- for prediction -----------------------

    # step 1-1: create readers for prediction
    print('prepare to predict...')
    predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed, phase='predict')
    # step 1-2: load the prediction data
    predict_seq_label_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_seq_label_reader.register_with(pred_ernie)

    # step 4: create the task output head
    seq_label_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')

    # step 5-1: create a task trainer
    trainer_seq_label = palm.Trainer("slot")
    # step 5-2: build forward graph with backbone and task head
    trainer_seq_label.build_predict_forward(pred_ernie, seq_label_pred_head)

    # step 6: load checkpoint
    pred_model_path = './outputs/ckpt.step4641'
    trainer_seq_label.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer_seq_label.fit_reader(predict_seq_label_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer_seq_label.predict(print_steps=print_steps, output_dir=pred_output)
--------------------------------------------------------------------------------
/examples/multi-task/process.py:
--------------------------------------------------------------------------------
import os
import json

label_new = "data/atis/atis_slot/label_map.json"
label_old = "data/atis/atis_slot/map_tag_slot_id.txt"
train_old = "data/atis/atis_slot/train.txt"
train_new = "data/atis/atis_slot/train.tsv"
dev_old = "data/atis/atis_slot/dev.txt"
dev_new = "data/atis/atis_slot/dev.tsv"
test_old = "data/atis/atis_slot/test.txt"
test_new = "data/atis/atis_slot/test.tsv"

# the intent files only need a .tsv extension and a header line
intent_test = "data/atis/atis_intent/test.tsv"
os.rename("data/atis/atis_intent/test.txt", intent_test)
intent_train = "data/atis/atis_intent/train.tsv"
os.rename("data/atis/atis_intent/train.txt", intent_train)
intent_dev = "data/atis/atis_intent/dev.tsv"
os.rename("data/atis/atis_intent/dev.txt", intent_dev)

for path in (intent_dev, intent_test, intent_train):
    with open(path, 'r+') as f:
        content = f.read()
        f.seek(0, 0)
        f.write("label\ttext_a\n" + content)

# build the label map (tag name -> integer id) and an id -> tag lookup
tags = []
ids = []
id2tag = {}
with open(label_old, "r") as f, open(label_new, "w") as f2:
    for line in f.readlines():
        line = line.split('\t')
        tags.append(line[0])
        ids.append(int(line[1][:-1]))
        id2tag[line[1][:-1]] = line[0]
    f2.write(json.dumps({tags[i]: ids[i] for i in range(len(tags))}))


def convert(src, dst):
    # rewrite "<space-separated tokens>\t<space-separated label ids>" lines
    # as "text_a\tlabel" tsv rows with '\x02'-separated tokens and tag names
    with open(src, "r") as f, open(dst, "w") as f2:
        f2.write("text_a\tlabel\n")
        for line in f.readlines():
            line = line.split('\t')
            text = line[0].split(' ')
            label = line[1].split(' ')
            for t in text:
                f2.write(t)
                f2.write('\2')
            f2.write('\t')
            for t in label:
                if t.endswith('\n'):
                    t = t[:-1]
                f2.write(id2tag[t])
                f2.write('\2')
            f2.write('\n')


convert(train_old, train_new)
convert(test_old, test_new)
convert(dev_old, dev_new)

os.remove(label_old)
os.remove(train_old)
os.remove(test_old)
os.remove(dev_old)
"data/atis/atis_slot/label_map.json" 5 | label_old = "data/atis/atis_slot/map_tag_slot_id.txt" 6 | train_old = "data/atis/atis_slot/train.txt" 7 | train_new = "data/atis/atis_slot/train.tsv" 8 | dev_old = "data/atis/atis_slot/dev.txt" 9 | dev_new = "data/atis/atis_slot/dev.tsv" 10 | test_old = "data/atis/atis_slot/test.txt" 11 | test_new = "data/atis/atis_slot/test.tsv" 12 | 13 | 14 | intent_test = "data/atis/atis_intent/test.tsv" 15 | os.rename("data/atis/atis_intent/test.txt", intent_test) 16 | intent_train = "data/atis/atis_intent/train.tsv" 17 | os.rename("data/atis/atis_intent/train.txt", intent_train) 18 | intent_dev = "data/atis/atis_intent/dev.tsv" 19 | os.rename("data/atis/atis_intent/dev.txt", intent_dev) 20 | 21 | with open(intent_dev, 'r+') as f: 22 | content = f.read() 23 | f.seek(0, 0) 24 | f.write("label\ttext_a\n"+content) 25 | f.close() 26 | 27 | with open(intent_test, 'r+') as f: 28 | content = f.read() 29 | f.seek(0, 0) 30 | f.write("label\ttext_a\n"+content) 31 | f.close() 32 | 33 | with open(intent_train, 'r+') as f: 34 | content = f.read() 35 | f.seek(0, 0) 36 | f.write("label\ttext_a\n"+content) 37 | f.close() 38 | 39 | os.mknod(label_new) 40 | os.mknod(train_new) 41 | os.mknod(dev_new) 42 | os.mknod(test_new) 43 | 44 | 45 | tag = [] 46 | id = [] 47 | map = {} 48 | with open(label_old, "r") as f: 49 | with open(label_new, "w") as f2: 50 | for line in f.readlines(): 51 | line = line.split('\t') 52 | tag.append(line[0]) 53 | id.append(int(line[1][:-1])) 54 | map[line[1][:-1]] = line[0] 55 | 56 | re = {tag[i]:id[i] for i in range(len(tag))} 57 | re = json.dumps(re) 58 | f2.write(re) 59 | f2.close() 60 | f.close() 61 | 62 | 63 | with open(train_old, "r") as f: 64 | with open(train_new, "w") as f2: 65 | f2.write("text_a\tlabel\n") 66 | for line in f.readlines(): 67 | line = line.split('\t') 68 | text = line[0].split(' ') 69 | label = line[1].split(' ') 70 | for t in text: 71 | f2.write(t) 72 | f2.write('\2') 73 | f2.write('\t') 74 | for t in label: 75 | if t.endswith('\n'): 76 | t = t[:-1] 77 | f2.write(map[t]) 78 | f2.write('\2') 79 | f2.write('\n') 80 | f2.close() 81 | f.close() 82 | 83 | with open(test_old, "r") as f: 84 | with open(test_new, "w") as f2: 85 | f2.write("text_a\tlabel\n") 86 | for line in f.readlines(): 87 | line = line.split('\t') 88 | text = line[0].split(' ') 89 | label = line[1].split(' ') 90 | for t in text: 91 | f2.write(t) 92 | f2.write('\2') 93 | f2.write('\t') 94 | for t in label: 95 | if t.endswith('\n'): 96 | t = t[:-1] 97 | f2.write(map[t]) 98 | f2.write('\2') 99 | f2.write('\n') 100 | f2.close() 101 | f.close() 102 | 103 | with open(dev_old, "r") as f: 104 | with open(dev_new, "w") as f2: 105 | f2.write("text_a\tlabel\n") 106 | for line in f.readlines(): 107 | line = line.split('\t') 108 | text = line[0].split(' ') 109 | label = line[1].split(' ') 110 | for t in text: 111 | f2.write(t) 112 | f2.write('\2') 113 | f2.write('\t') 114 | for t in label: 115 | if t.endswith('\n'): 116 | t = t[:-1] 117 | f2.write(map[t]) 118 | f2.write('\2') 119 | f2.write('\n') 120 | f2.close() 121 | f.close() 122 | 123 | os.remove(label_old) 124 | os.remove(train_old) 125 | os.remove(test_old) 126 | os.remove(dev_old) -------------------------------------------------------------------------------- /examples/multi-task/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import paddlepalm as palm 3 | import json 4 | 5 | 6 | if __name__ == '__main__': 7 | 8 | # configs 9 | max_seqlen = 128 10 | batch_size = 
-------------------------------------------------------------------------------- /examples/multi-task/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import paddlepalm as palm 3 | import json 4 | 5 | 6 | if __name__ == '__main__': 7 | 8 |     # configs 9 |     max_seqlen = 128 10 |     batch_size = 16 11 |     num_epochs = 20 12 |     print_steps = 5 13 |     lr = 2e-5 14 |     num_classes = 130 15 |     weight_decay = 0.01 16 |     num_classes_intent = 26 17 |     dropout_prob = 0.1 18 |     random_seed = 0 19 |     label_map = './data/atis/atis_slot/label_map.json' 20 |     vocab_path = './pretrain/ERNIE-v2-en-base/vocab.txt' 21 | 22 |     train_slot = './data/atis/atis_slot/train.tsv' 23 |     train_intent = './data/atis/atis_intent/train.tsv' 24 | 25 |     config = json.load(open('./pretrain/ERNIE-v2-en-base/ernie_config.json')) 26 |     input_dim = config['hidden_size'] 27 | 28 |     # ----------------------- for training ----------------------- 29 | 30 |     # step 1-1: create readers 31 |     seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed) 32 |     cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed) 33 | 34 |     # step 1-2: load train data 35 |     seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size) 36 |     cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None) 37 | 38 |     # step 2: create a backbone of the model to extract text features 39 |     ernie = palm.backbone.ERNIE.from_config(config) 40 | 41 |     # step 3: register readers with ernie backbone 42 |     seq_label_reader.register_with(ernie) 43 |     cls_reader.register_with(ernie) 44 | 45 |     # step 4: create task output heads 46 |     seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob) 47 |     cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob) 48 | 49 |     # step 5-1: create task trainers and the MultiHeadTrainer 50 |     trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0) 51 |     trainer_cls = palm.Trainer("intent", mix_ratio=1.0) 52 |     trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls]) 53 |     # step 5-2: build forward graph with backbone and task heads 54 |     loss1 = trainer_cls.build_forward(ernie, cls_head) 55 |     loss2 = trainer_seq_label.build_forward(ernie, seq_label_head) 56 |     loss_var = trainer.build_forward() 57 | 58 |     # step 6-1*: enable warmup for better fine-tuning 59 |     n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size 60 |     warmup_steps = int(0.1 * n_steps) 61 |     sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) 62 |     # step 6-2: build an optimizer 63 |     adam = palm.optimizer.Adam(loss_var, lr, sched) 64 |     # step 6-3: build backward graph 65 |     trainer.build_backward(optimizer=adam, weight_decay=weight_decay) 66 | 67 |     # step 7: fit readers to trainer 68 |     trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs) 69 | 70 |     # step 8-1*: load pretrained model 71 |     trainer.load_pretrain('./pretrain/ERNIE-v2-en-base') 72 |     # step 8-2*: set saver to save models during training 73 |     trainer.set_saver(save_path='./outputs/', save_steps=300) 74 |     # step 8-3: start training 75 |     trainer.train(print_steps=10) -------------------------------------------------------------------------------- /examples/predict/README.md: -------------------------------------------------------------------------------- 1 | ## Example 5: Prediction 2 | This example demonstrates how to do prediction directly with PaddlePALM. You can initialize the model from a checkpoint, from a pre-trained model, or randomly. Here we reuse the task and data of example 1, so repeat step 1 of example 1 to prepare the pre-trained model and dataset. The sketch below shows where each initialization option hooks in.
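The three options correspond to different loader calls on the trainer; a minimal sketch (paths are placeholders, and `load_ckpt` and `load_predict_model` are the calls used by the tagging and prediction examples respectively):

```python
trainer.build_predict_forward(pred_ernie, cls_pred_head)

# (a) a checkpoint saved by a previous fine-tuning run
trainer.load_ckpt('./outputs/ckpt.step2396')

# (b) downloaded pre-trained (not fine-tuned) parameters -- what run.py below does
trainer.load_predict_model('./pretrain/ERNIE-v1-zh-base/params')

# (c) random initialization: skip both calls and predict directly
```

Only one of (a)/(b) should be used for a given run.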
3 | 4 | After you have prepared the pre-training model and the data set required for the task, run: 5 | 6 | ```shell 7 | python run.py 8 | ``` 9 | 10 | If you want to specify a specific gpu or use multiple gpus for predict, please use **`CUDA_VISIBLE_DEVICES`**, for example: 11 | 12 | ```shell 13 | CUDA_VISIBLE_DEVICES=0,1 python run.py 14 | ``` 15 | 16 | Note: On multi-gpu mode, PaddlePALM will automatically split each batch onto the available cards. For example, if the `batch_size` is set 64, and there are 4 cards visible for PaddlePALM, then the batch_size in each card is actually 64/4=16. If you want to change the `batch_size` or the number of gpus used in the example, **you need to ensure that the set batch_size can be divided by the number of cards.** 17 | 18 | 19 | Some logs will be shown below: 20 | 21 | ``` 22 | step 1/154, speed: 0.51 steps/s 23 | step 2/154, speed: 3.36 steps/s 24 | step 3/154, speed: 3.48 steps/s 25 | ``` 26 | 27 | 28 | After the run, you can view the predictions in the `outputs/predict` folder. Here are some examples of predictions: 29 | 30 | 31 | ``` 32 | {"index": 0, "logits": [-0.2014336884021759, 0.6799028515815735], "probs": [0.29290086030960083, 0.7070990800857544], "label": 1} 33 | {"index": 1, "logits": [0.8593899011611938, -0.29743513464927673], "probs": [0.7607553601264954, 0.23924466967582703], "label": 0} 34 | {"index": 2, "logits": [0.7462944388389587, -0.7083730101585388], "probs": [0.8107157349586487, 0.18928426504135132], "label": 0} 35 | ``` 36 | 37 | ### Step 3: Evaluate 38 | 39 | Once you have the prediction, you can run the evaluation script to evaluate the model: 40 | 41 | ```shell 42 | python evaluate.py 43 | ``` 44 | 45 | The evaluation results are as follows: 46 | 47 | ``` 48 | data num: 1200 49 | accuracy: 0.4758, precision: 0.4730, recall: 0.3026, f1: 0.3691 50 | ``` 51 | -------------------------------------------------------------------------------- /examples/predict/download.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import os 4 | import tarfile 5 | import shutil 6 | import sys 7 | import urllib 8 | URLLIB=urllib 9 | if sys.version_info >= (3, 0): 10 | import urllib.request 11 | URLLIB=urllib.request 12 | 13 | def download(src, url): 14 | def _reporthook(count, chunk_size, total_size): 15 | bytes_so_far = count * chunk_size 16 | percent = float(bytes_so_far) / float(total_size) 17 | if percent > 1: 18 | percent = 1 19 | print('\r>> Downloading... 
{:.1%}'.format(percent), end="") 20 | 21 | URLLIB.urlretrieve(url, src, reporthook=_reporthook) 22 | 23 | abs_path = os.path.abspath(__file__) 24 | download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz" 25 | downlaod_path = os.path.join(os.path.dirname(abs_path), "task_data_zh.tgz") 26 | target_dir = os.path.dirname(abs_path) 27 | download(downlaod_path, download_url) 28 | 29 | tar = tarfile.open(downlaod_path) 30 | tar.extractall(target_dir) 31 | os.remove(downlaod_path) 32 | 33 | abs_path = os.path.abspath(__file__) 34 | dst_dir = os.path.join(os.path.dirname(abs_path), "data") 35 | if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir): 36 | os.makedirs(dst_dir) 37 | 38 | for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')): 39 | shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir) 40 | 41 | shutil.rmtree(os.path.join(target_dir, 'task_data')) 42 | print(" done!") 43 | -------------------------------------------------------------------------------- /examples/predict/evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | import numpy as np 5 | 6 | def accuracy(preds, labels): 7 | preds = np.array(preds) 8 | labels = np.array(labels) 9 | return (preds == labels).mean() 10 | 11 | def pre_recall_f1(preds, labels): 12 | preds = np.array(preds) 13 | labels = np.array(labels) 14 | # recall=TP/(TP+FN) 15 | tp = np.sum((labels == '1') & (preds == '1')) 16 | fp = np.sum((labels == '0') & (preds == '1')) 17 | fn = np.sum((labels == '1') & (preds == '0')) 18 | r = tp * 1.0 / (tp + fn) 19 | # Precision=TP/(TP+FP) 20 | p = tp * 1.0 / (tp + fp) 21 | epsilon = 1e-31 22 | f1 = 2 * p * r / (p+r+epsilon) 23 | return p, r, f1 24 | 25 | 26 | def res_evaluate(res_dir="./outputs/predict/predictions.json", eval_phase='test'): 27 | if eval_phase == 'test': 28 | data_dir="./data/test.tsv" 29 | elif eval_phase == 'dev': 30 | data_dir="./data/dev.tsv" 31 | else: 32 | assert eval_phase in ['dev', 'test'], 'eval_phase should be dev or test' 33 | 34 | labels = [] 35 | with open(data_dir, "r") as file: 36 | first_flag = True 37 | for line in file: 38 | line = line.split("\t") 39 | label = line[0] 40 | if label=='label': 41 | continue 42 | labels.append(str(label)) 43 | file.close() 44 | 45 | preds = [] 46 | with open(res_dir, "r") as file: 47 | for line in file.readlines(): 48 | line = json.loads(line) 49 | pred = line['label'] 50 | preds.append(str(pred)) 51 | file.close() 52 | assert len(labels) == len(preds), "prediction result doesn't match to labels" 53 | print('data num: {}'.format(len(labels))) 54 | p, r, f1 = pre_recall_f1(preds, labels) 55 | print("accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(accuracy(preds, labels), p, r, f1)) 56 | 57 | res_evaluate() 58 | -------------------------------------------------------------------------------- /examples/predict/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import paddlepalm as palm 3 | import json 4 | 5 | 6 | if __name__ == '__main__': 7 | 8 | # configs 9 | max_seqlen = 256 10 | batch_size = 8 11 | vocab_path = './pretrain/ERNIE-v1-zh-base/vocab.txt' 12 | predict_file = './data/test.tsv' 13 | random_seed = 1 14 | config = json.load(open('./pretrain/ERNIE-v1-zh-base/ernie_config.json')) 15 | input_dim = config['hidden_size'] 16 | num_classes = 2 17 | task_name = 'chnsenticorp' 18 | pred_output = './outputs/predict/' 19 | 
print_steps = 20 20 |     pre_params = './pretrain/ERNIE-v1-zh-base/params' 21 | 22 |     # ----------------------- for prediction ----------------------- 23 | 24 |     # step 1-1: create readers for prediction 25 |     print('prepare to predict...') 26 |     predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed, phase='predict') 27 |     # step 1-2: load the prediction data 28 |     predict_cls_reader.load_data(predict_file, batch_size) 29 | 30 |     # step 2: create a backbone of the model to extract text features 31 |     pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict') 32 | 33 |     # step 3: register the backbone in reader 34 |     predict_cls_reader.register_with(pred_ernie) 35 | 36 |     # step 4: create the task output head 37 |     cls_pred_head = palm.head.Classify(num_classes, input_dim, phase='predict') 38 | 39 |     # step 5-1: create a task trainer 40 |     trainer = palm.Trainer(task_name) 41 |     # step 5-2: build forward graph with backbone and task head 42 |     trainer.build_predict_forward(pred_ernie, cls_pred_head) 43 | 44 |     # step 6: load the pre-trained parameters 45 |     trainer.load_predict_model(pre_params) 46 | 47 |     # step 7: fit prepared reader and data 48 |     trainer.fit_reader(predict_cls_reader, phase='predict') 49 | 50 |     # step 8: predict 51 |     print('predicting..') 52 |     trainer.predict(print_steps=print_steps, output_dir=pred_output) 53 | -------------------------------------------------------------------------------- /examples/tagging/README.md: -------------------------------------------------------------------------------- 1 | ## Example 3: Tagging 2 | This task is a named entity recognition task. The following sections detail model preparation, dataset preparation, and how to run the task. 3 | 4 | ### Step 1: Prepare Pre-trained Models & Datasets 5 | 6 | #### Pre-trained Model 7 | 8 | The pre-training model of this mission is: [ERNIE-v1-zh-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api). 9 | 10 | Make sure you have downloaded the required pre-training model in the current folder. 11 | 12 | 13 | #### Dataset 14 | 15 | This task uses the `MSRA-NER(SIGHAN2006)` dataset. 16 | 17 | Download dataset: 18 | ```shell 19 | python download.py 20 | ``` 21 | 22 | If everything goes well, there will be a folder named `data/` created with all the data files in it. 23 | 24 | The dataset files should have 2 fields, `text_a` and `label`, stored with [tsv](https://en.wikipedia.org/wiki/Tab-separated_values) format. Here shows an example: 25 | 26 | ``` 27 | text_a label 28 | 在 这 里 恕 弟 不 恭 之 罪 , 敢 在 尊 前 一 诤 : 前 人 论 书 , 每 曰 “ 字 字 有 来 历 , 笔 笔 有 出 处 ” , 细 读 公 字 , 何 尝 跳 出 前 人 藩 篱 , 自 隶 变 而 后 , 直 至 明 季 , 兄 有 何 新 出 ? O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O 29 | 相 比 之 下 , 青 岛 海 牛 队 和 广 州 松 日 队 的 雨 中 之 战 虽 然 也 是 0 ∶ 0 , 但 乏 善 可 陈 。 O O O O O B-ORG I-ORG I-ORG I-ORG I-ORG O B-ORG I-ORG I-ORG I-ORG I-ORG O O O O O O O O O O O O O O O O O O O 30 | 理 由 多 多 , 最 无 奈 的 却 是 : 5 月 恰 逢 双 重 考 试 , 她 攻 读 的 博 士 学 位 论 文 要 通 考 ; 她 任 教 的 两 所 学 校 , 也 要 在 这 段 时 日 大 考 。 O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O 31 | ``` 32 | 33 | 34 | 35 | ### Step 2: Train & Predict 36 | 37 | The code used to perform this task is in `run.py`; its core wiring is sketched below.
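The script follows the reader/backbone/head pattern shared by all examples. This condensed sketch mirrors the full `run.py` (configuration variables as defined there):

```python
import paddlepalm as palm

seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
seq_label_reader.load_data(train_file, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)
ernie = palm.backbone.ERNIE.from_config(config)          # backbone: text feature extractor
seq_label_reader.register_with(ernie)
seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
trainer = palm.Trainer('msra_ner')
loss_var = trainer.build_forward(ernie, seq_label_head)  # forward graph
sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
adam = palm.optimizer.Adam(loss_var, lr, sched)
trainer.build_backward(optimizer=adam, weight_decay=weight_decay)
trainer.fit_reader(seq_label_reader)
trainer.load_pretrain(pre_params)
trainer.train(print_steps=train_print_steps)
```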
If you have prepared the pre-training model and the data set required for the task, run: 38 | 39 | ```shell 40 | python run.py 41 | ``` 42 | 43 | If you want to specify a specific gpu or use multiple gpus for training, please use **`CUDA_VISIBLE_DEVICES`**, for example: 44 | 45 | ```shell 46 | CUDA_VISIBLE_DEVICES=0,1 python run.py 47 | ``` 48 | 49 | Note: On multi-gpu mode, PaddlePALM will automatically split each batch onto the available cards. For example, if the `batch_size` is set 64, and there are 4 cards visible for PaddlePALM, then the batch_size in each card is actually 64/4=16. If you want to change the `batch_size` or the number of gpus used in the example, **you need to ensure that the set batch_size can be divided by the number of cards.** 50 | 51 | Some logs will be shown below: 52 | 53 | ``` 54 | step 1/652 (epoch 0), loss: 216.002, speed: 0.32 steps/s 55 | step 2/652 (epoch 0), loss: 202.567, speed: 1.28 steps/s 56 | step 3/652 (epoch 0), loss: 170.677, speed: 1.05 steps/s 57 | ``` 58 | 59 | After the run, you can view the saved models in the `outputs/` folder and the predictions in the `outputs/predict` folder. Here are some examples of predictions: 60 | 61 | 62 | ``` 63 | [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6] 64 | [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6] 65 | [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6] 66 | ``` 67 | 68 | ### Step 3: Evaluate 69 | 70 | Once you have the prediction, you can run the evaluation script to evaluate the model: 71 | 72 | ```python 73 | python evaluate.py 74 | ``` 75 | 76 | The evaluation results are as follows: 77 | 78 | ``` 79 | data num: 4636 80 | f1: 0.9918 81 | ``` 82 | -------------------------------------------------------------------------------- /examples/tagging/download.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import os 4 | import tarfile 5 | import shutil 6 | import sys 7 | import urllib 8 | URLLIB=urllib 9 | if sys.version_info >= (3, 0): 10 | import urllib.request 11 | URLLIB=urllib.request 12 | 13 | def download(src, url): 14 | def _reporthook(count, chunk_size, total_size): 15 | bytes_so_far = count * chunk_size 16 | percent = float(bytes_so_far) / float(total_size) 17 | if percent > 1: 18 | percent = 1 19 | print('\r>> Downloading... 
{:.1%}'.format(percent), end="") 20 | 21 | URLLIB.urlretrieve(url, src, reporthook=_reporthook) 22 | 23 | abs_path = os.path.abspath(__file__) 24 | download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz" 25 | download_path = os.path.join(os.path.dirname(abs_path), "task_data_zh.tgz") 26 | target_dir = os.path.dirname(abs_path) 27 | download(download_path, download_url) 28 | 29 | tar = tarfile.open(download_path) 30 | tar.extractall(target_dir) 31 | os.remove(download_path) 32 | 33 | abs_path = os.path.abspath(__file__) 34 | dst_dir = os.path.join(os.path.dirname(abs_path), "data") 35 | if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir): 36 |     os.makedirs(dst_dir) 37 | 38 | for file in os.listdir(os.path.join(target_dir, 'task_data', 'msra_ner')): 39 |     shutil.move(os.path.join(target_dir, 'task_data', 'msra_ner', file), dst_dir) 40 | 41 | shutil.rmtree(os.path.join(target_dir, 'task_data')) 42 | print(" done!") 43 | -------------------------------------------------------------------------------- /examples/tagging/evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | 5 | 6 | def load_label_map(map_dir="./data/label_map.json"): 7 |     """ 8 |     :param map_dir: path of the json dict indicating the chunk types 9 |     :return: 10 |     """ 11 |     return json.load(open(map_dir, "r")) 12 | 13 | 14 | def cal_chunk(pred_label, refer_label): 15 |     tp = dict() 16 |     fn = dict() 17 |     fp = dict() 18 |     for i in range(len(refer_label)): 19 |         if refer_label[i] == pred_label[i]: 20 |             if refer_label[i] not in tp: 21 |                 tp[refer_label[i]] = 0 22 |             tp[refer_label[i]] += 1 23 |         else: 24 |             if pred_label[i] not in fp: 25 |                 fp[pred_label[i]] = 0 26 |             fp[pred_label[i]] += 1 27 |             if refer_label[i] not in fn: 28 |                 fn[refer_label[i]] = 0 29 |             fn[refer_label[i]] += 1 30 | 31 |     tp_total = sum(tp.values()) 32 |     fn_total = sum(fn.values()) 33 |     fp_total = sum(fp.values()) 34 |     p_total = float(tp_total) / (tp_total + fp_total) 35 |     r_total = float(tp_total) / (tp_total + fn_total) 36 |     f_micro = 2 * p_total * r_total / (p_total + r_total) 37 | 38 |     return f_micro 39 | 40 | 41 | def res_evaluate(res_dir="./outputs/predict/predictions.json", data_dir="./data/test.tsv"): 42 |     label_map = load_label_map() 43 | 44 |     total_label = [] 45 |     with open(data_dir, "r") as file: 46 |         first_flag = True 47 |         for line in file: 48 |             if first_flag: 49 |                 first_flag = False 50 |                 continue 51 |             line = line.strip("\n") 52 |             if len(line) == 0: 53 |                 continue 54 |             line = line.split("\t") 55 |             if len(line) < 2: 56 |                 continue 57 |             labels = line[1].split("\x02") 58 |             total_label.append(labels) 59 |     total_label = [[label_map[j] for j in i] for i in total_label] 60 | 61 |     total_res = [] 62 |     with open(res_dir, "r") as file: 63 |         cnt = 0 64 |         for line in file: 65 |             line = line.strip("\n") 66 |             if len(line) == 0: 67 |                 continue 68 |             try: 69 |                 res_arr = json.loads(line) 70 | 71 |                 if len(total_label[cnt]) < len(res_arr): 72 |                     total_res.append(res_arr[1: 1 + len(total_label[cnt])]) 73 |                 elif len(total_label[cnt]) == len(res_arr): 74 |                     total_res.append(res_arr) 75 |                 else: 76 |                     total_res.append(res_arr) 77 |                     total_label[cnt] = total_label[cnt][: len(res_arr)] 78 |             except (ValueError, IndexError): 79 |                 print("format error: {}".format(cnt)) 80 |                 print(line) 81 | 82 |             cnt += 1 83 | 84 |     total_res_equal = [] 85 |     total_label_equal = [] 86 |     assert len(total_label) == len(total_res), "prediction result doesn't match to labels" 87 |     for i in range(len(total_label)): 88 |         num = len(total_label[i]) 89 | 
total_label_equal.extend(total_label[i]) 90 | total_res[i] = total_res[i][:num] 91 | total_res_equal.extend(total_res[i]) 92 | 93 | f1 = cal_chunk(total_res_equal, total_label_equal) 94 | print('data num: {}'.format(len(total_label))) 95 | print("f1: {:.4f}".format(f1)) 96 | 97 | res_evaluate() 98 | -------------------------------------------------------------------------------- /examples/tagging/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import paddlepalm as palm 3 | import json 4 | 5 | if __name__ == '__main__': 6 | 7 | # configs 8 | max_seqlen = 256 9 | batch_size = 16 10 | num_epochs = 6 11 | lr = 5e-5 12 | num_classes = 7 13 | weight_decay = 0.01 14 | dropout_prob = 0.1 15 | vocab_path = './pretrain/ERNIE-v1-zh-base/vocab.txt' 16 | label_map = './data/label_map.json' 17 | random_seed = 1 18 | train_file = './data/train.tsv' 19 | predict_file = './data/test.tsv' 20 | 21 | save_path='./outputs/' 22 | save_type='ckpt' 23 | pre_params = './pretrain/ERNIE-v1-zh-base/params' 24 | config = json.load(open('./pretrain/ERNIE-v1-zh-base/ernie_config.json')) 25 | input_dim = config['hidden_size'] 26 | task_name = 'msra_ner' 27 | pred_output = './outputs/predict/' 28 | train_print_steps = 10 29 | pred_print_steps = 20 30 | 31 | # ----------------------- for training ----------------------- 32 | 33 | # step 1-1: create readers for training 34 | seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed) 35 | # step 1-2: load the training data 36 | seq_label_reader.load_data(train_file, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size) 37 | 38 | # step 2: create a backbone of the model to extract text features 39 | ernie = palm.backbone.ERNIE.from_config(config) 40 | 41 | # step 3: register the backbone in reader 42 | seq_label_reader.register_with(ernie) 43 | 44 | # step 4: create the task output head 45 | seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob) 46 | 47 | # step 5-1: create a task trainer 48 | trainer = palm.Trainer(task_name) 49 | # step 5-2: build forward graph with backbone and task head 50 | loss_var = trainer.build_forward(ernie, seq_label_head) 51 | 52 | # step 6-1*: use warmup 53 | n_steps = seq_label_reader.num_examples * num_epochs // batch_size 54 | warmup_steps = int(0.1 * n_steps) 55 | print('total_steps: {}'.format(n_steps)) 56 | print('warmup_steps: {}'.format(warmup_steps)) 57 | sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) 58 | # step 6-2: create a optimizer 59 | adam = palm.optimizer.Adam(loss_var, lr, sched) 60 | # step 6-3: build backward 61 | trainer.build_backward(optimizer=adam, weight_decay=weight_decay) 62 | 63 | # step 7: fit prepared reader and data 64 | trainer.fit_reader(seq_label_reader) 65 | 66 | # step 8-1*: load pretrained parameters 67 | trainer.load_pretrain(pre_params) 68 | # step 8-2*: set saver to save model 69 | save_steps = 1951 70 | # print('save_steps: {}'.format(save_steps)) 71 | trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type) 72 | # # step 8-3: start training 73 | trainer.train(print_steps=train_print_steps) 74 | 75 | # ----------------------- for prediction ----------------------- 76 | 77 | # step 1-1: create readers for prediction 78 | print('prepare to predict...') 79 | predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed, phase='predict') 80 | # step 1-2: load the training data 
81 | predict_seq_label_reader.load_data(predict_file, batch_size) 82 | 83 | # step 2: create a backbone of the model to extract text features 84 | pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict') 85 | 86 | # step 3: register the backbone in reader 87 | predict_seq_label_reader.register_with(pred_ernie) 88 | 89 | # step 4: create the task output head 90 | seq_label_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict') 91 | 92 | # step 5: build forward graph with backbone and task head 93 | trainer.build_predict_forward(pred_ernie, seq_label_pred_head) 94 | 95 | # step 6: load checkpoint 96 | pred_model_path = './outputs/ckpt.step' + str(save_steps) 97 | trainer.load_ckpt(pred_model_path) 98 | 99 | # step 7: fit prepared reader and data 100 | trainer.fit_reader(predict_seq_label_reader, phase='predict') 101 | 102 | # step 8: predict 103 | print('predicting..') 104 | trainer.predict(print_steps=pred_print_steps, output_dir=pred_output) 105 | -------------------------------------------------------------------------------- /examples/train_with_eval/README.md: -------------------------------------------------------------------------------- 1 | ## Train with Evaluation version of Example 1: Classification 2 | This task is a sentiment analysis task. The following sections detail model preparation, dataset preparation, and how to run the task. Here to demonstrate how to do evaluation during training in PaddlePALM. 3 | 4 | ### Step 1: Prepare Pre-trained Model & Dataset 5 | 6 | #### Pre-trained Model 7 | 8 | The pre-training model of this mission is: [ERNIE-v1-zh-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api). 9 | 10 | Make sure you have downloaded the required pre-training model in the current folder. 11 | 12 | 13 | #### Dataset 14 | 15 | This example demonstrates with [ChnSentiCorp](https://github.com/SophonPlus/ChineseNlpCorpus/tree/master/datasets/ChnSentiCorp_htl_all), a Chinese sentiment analysis dataset. 16 | 17 | Download dataset: 18 | ```shell 19 | python download.py 20 | ``` 21 | 22 | If everything goes well, there will be a folder named `data/` created with all the data files in it. 23 | 24 | The dataset file (for training) should have 2 fields, `text_a` and `label`, stored with [tsv](https://en.wikipedia.org/wiki/Tab-separated_values) format. Here shows an example: 25 | 26 | ``` 27 | label text_a 28 | 0 当当网名不符实,订货多日不见送货,询问客服只会推托,只会要求用户再下订单。如此服务留不住顾客的。去别的网站买书服务更好。 29 | 0 XP的驱动不好找!我的17号提的货,现在就降价了100元,而且还送杀毒软件! 30 | 1 <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道当年我听说这本书的时候花很长时间去图书馆找和借都没能如愿,所以这次一看到当当有,马上买了,红迷们也要记得备货哦! 31 | ``` 32 | 33 | ### Step 2: Train & Predict 34 | 35 | The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run: 36 | 37 | ```shell 38 | python run.py 39 | ``` 40 | 41 | If you want to specify a specific gpu or use multiple gpus for training, please use **`CUDA_VISIBLE_DEVICES`**, for example: 42 | 43 | ```shell 44 | CUDA_VISIBLE_DEVICES=0,1 python run.py 45 | ``` 46 | 47 | Note: On multi-gpu mode, PaddlePALM will automatically split each batch onto the available cards. For example, if the `batch_size` is set 64, and there are 4 cards visible for PaddlePALM, then the batch_size in each card is actually 64/4=16. 
If you want to change the `batch_size` or the number of gpus used in the example, **you need to ensure that the set batch_size can be divided by the number of cards.** 48 | 49 | 50 | Some logs will be shown below: 51 | 52 | ``` 53 | step 1/154 (epoch 0), loss: 5.512, speed: 0.51 steps/s 54 | step 2/154 (epoch 0), loss: 2.595, speed: 3.36 steps/s 55 | step 3/154 (epoch 0), loss: 1.798, speed: 3.48 steps/s 56 | ``` 57 | 58 | 59 | After the run, you can view the saved models in the `outputs/` folder and the predictions in the `outputs/predict` folder. Here are some examples of predictions: 60 | 61 | 62 | ``` 63 | {"index": 0, "logits": [-0.2014336884021759, 0.6799028515815735], "probs": [0.29290086030960083, 0.7070990800857544], "label": 1} 64 | {"index": 1, "logits": [0.8593899011611938, -0.29743513464927673], "probs": [0.7607553601264954, 0.23924466967582703], "label": 0} 65 | {"index": 2, "logits": [0.7462944388389587, -0.7083730101585388], "probs": [0.8107157349586487, 0.18928426504135132], "label": 0} 66 | ``` 67 | 68 | ### Step 3: Evaluate 69 | 70 | Once you have the prediction, you can run the evaluation script to evaluate the model: 71 | 72 | ```shell 73 | python evaluate.py 74 | ``` 75 | 76 | The evaluation results are as follows: 77 | 78 | ``` 79 | data num: 1200 80 | accuracy: 0.9575, precision: 0.9634, recall: 0.9523, f1: 0.9578 81 | ``` 82 | -------------------------------------------------------------------------------- /examples/train_with_eval/download.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import os 4 | import tarfile 5 | import shutil 6 | import sys 7 | import urllib 8 | URLLIB=urllib 9 | if sys.version_info >= (3, 0): 10 | import urllib.request 11 | URLLIB=urllib.request 12 | 13 | def download(src, url): 14 | def _reporthook(count, chunk_size, total_size): 15 | bytes_so_far = count * chunk_size 16 | percent = float(bytes_so_far) / float(total_size) 17 | if percent > 1: 18 | percent = 1 19 | print('\r>> Downloading... 
{:.1%}'.format(percent), end="") 20 | 21 | URLLIB.urlretrieve(url, src, reporthook=_reporthook) 22 | 23 | abs_path = os.path.abspath(__file__) 24 | download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz" 25 | download_path = os.path.join(os.path.dirname(abs_path), "task_data_zh.tgz") 26 | target_dir = os.path.dirname(abs_path) 27 | download(download_path, download_url) 28 | 29 | tar = tarfile.open(download_path) 30 | tar.extractall(target_dir) 31 | os.remove(download_path) 32 | 33 | abs_path = os.path.abspath(__file__) 34 | dst_dir = os.path.join(os.path.dirname(abs_path), "data") 35 | if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir): 36 |     os.makedirs(dst_dir) 37 | 38 | for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')): 39 |     shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir) 40 | 41 | shutil.rmtree(os.path.join(target_dir, 'task_data')) 42 | print(" done!") 43 | -------------------------------------------------------------------------------- /examples/train_with_eval/evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | import numpy as np 5 | 6 | def accuracy(preds, labels): 7 |     preds = np.array(preds) 8 |     labels = np.array(labels) 9 |     return (preds == labels).mean() 10 | 11 | def pre_recall_f1(preds, labels): 12 |     preds = np.array(preds) 13 |     labels = np.array(labels) 14 |     # recall=TP/(TP+FN) 15 |     tp = np.sum((labels == '1') & (preds == '1')) 16 |     fp = np.sum((labels == '0') & (preds == '1')) 17 |     fn = np.sum((labels == '1') & (preds == '0')) 18 |     r = tp * 1.0 / (tp + fn) 19 |     # Precision=TP/(TP+FP) 20 |     p = tp * 1.0 / (tp + fp) 21 |     epsilon = 1e-31 22 |     f1 = 2 * p * r / (p+r+epsilon) 23 |     return p, r, f1 24 | 25 | 26 | def res_evaluate(res_dir="./outputs/predict/predictions.json", eval_phase='test'): 27 |     if eval_phase == 'test': 28 |         data_dir="./data/test.tsv" 29 |     elif eval_phase == 'dev': 30 |         data_dir="./data/dev.tsv" 31 |     else: 32 |         assert eval_phase in ['dev', 'test'], 'eval_phase should be dev or test' 33 | 34 |     labels = [] 35 |     with open(data_dir, "r") as file: 36 |         first_flag = True 37 |         for line in file: 38 |             line = line.split("\t") 39 |             label = line[0] 40 |             if label=='label': 41 |                 continue 42 |             labels.append(str(label)) 43 |     file.close() 44 | 45 |     preds = [] 46 |     with open(res_dir, "r") as file: 47 |         for line in file.readlines(): 48 |             line = json.loads(line) 49 |             pred = line['label'] 50 |             preds.append(str(pred)) 51 |     file.close() 52 |     assert len(labels) == len(preds), "prediction result doesn't match to labels" 53 |     print('data num: {}'.format(len(labels))) 54 |     p, r, f1 = pre_recall_f1(preds, labels) 55 |     print("accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(accuracy(preds, labels), p, r, f1)) 56 | 57 | res_evaluate() 58 | -------------------------------------------------------------------------------- /examples/train_with_eval/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import paddlepalm as palm 3 | import json 4 | 5 | 6 | if __name__ == '__main__': 7 | 8 |     # configs 9 |     max_seqlen = 256 10 |     batch_size = 8 11 |     num_epochs = 10 12 |     lr = 5e-5 13 |     weight_decay = 0.01 14 |     vocab_path = './pretrain/ERNIE-v1-zh-base/vocab.txt' 15 | 16 |     train_file = './data/train.tsv' 17 |     predict_file = './data/test.tsv' 18 |     config = json.load(open('./pretrain/ERNIE-v1-zh-base/ernie_config.json')) 19 |     input_dim = config['hidden_size'] 20 |     num_classes = 2 21 |     dropout_prob = 0.1 22 |     random_seed = 1 23 |     task_name = 'chnsenticorp' 24 |     save_path = './outputs/' 25 |     pred_output = './outputs/predict/' 26 |     save_type = 'ckpt' 27 |     print_steps = 20 28 |     pre_params = './pretrain/ERNIE-v1-zh-base/params' 29 | 30 |     # ----------------------- for training ----------------------- 31 | 32 |     # step 1-1: create readers for training 33 |     cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed) 34 |     # step 1-2: load the training data 35 |     cls_reader.load_data(train_file, batch_size, num_epochs=num_epochs) 36 | 37 |     # step 2: create a backbone of the model to extract text features 38 |     ernie = palm.backbone.ERNIE.from_config(config) 39 | 40 |     # step 3: register the backbone in reader 41 |     cls_reader.register_with(ernie) 42 | 43 |     # step 4: create the task output head 44 |     cls_head = palm.head.Classify(num_classes, input_dim, dropout_prob) 45 | 46 |     # step 5-1: create a task trainer 47 |     trainer = palm.Trainer(task_name) 48 |     # step 5-2: build forward graph with backbone and task head 49 |     loss_var = trainer.build_forward(ernie, cls_head) 50 | 51 |     # step 6-1*: use warmup 52 |     n_steps = cls_reader.num_examples * num_epochs // batch_size 53 |     warmup_steps = int(0.1 * n_steps) 54 |     sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) 55 |     # step 6-2: create an optimizer 56 |     adam = palm.optimizer.Adam(loss_var, lr, sched) 57 |     # step 6-3: build backward 58 |     trainer.build_backward(optimizer=adam, weight_decay=weight_decay) 59 | 60 |     # step 7: fit prepared reader and data 61 |     iterator = trainer.fit_reader(cls_reader) 62 | 63 |     # step 8-1*: load pretrained parameters 64 |     trainer.load_pretrain(pre_params) 65 |     # step 8-2*: set saver to save model 66 |     # save_steps = n_steps 67 |     save_steps = 2396 68 |     trainer.set_saver(save_steps=save_steps, save_path=save_path, save_type=save_type) 69 | 70 |     # step 8-3: start training 71 |     # you can repeatedly fetch a single training batch with trainer.get_one_batch() 72 |     # batch = trainer.get_one_batch() 73 |     for step, batch in enumerate(iterator, start=1): 74 |         trainer.train_one_step(batch) 75 |         if step % 100 == 0: 76 |             print('do evaluation.') 77 |             # insert evaluation code here (one possible pattern is sketched below) 78 | 79 | 
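One way to fill in the evaluation hook above is to run a predict-phase pass over the dev set every few hundred steps, mirroring `examples/predict/run.py`. The sketch below is an assumption-labeled starting point, not an official PALM pattern: it presumes a second, predict-phase reader/backbone/head/trainer pair (`pred_trainer`, `dev_reader`, both hypothetical names) was built before the loop, and that interleaving prediction inside the training loop is supported by the framework.

```python
# a sketch under the assumptions stated above; dev.tsv ships with the dataset
for step, batch in enumerate(iterator, start=1):
    trainer.train_one_step(batch)
    if step % 100 == 0:
        print('do evaluation.')
        pred_trainer.fit_reader(dev_reader, phase='predict')
        # writes ./outputs/eval/predictions.json; score it with the same
        # accuracy/precision/recall/f1 helpers as this example's evaluate.py
        pred_trainer.predict(print_steps=print_steps, output_dir='./outputs/eval/')
```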
-------------------------------------------------------------------------------- /paddlepalm.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.cfg 3 | setup.py 4 | ./paddlepalm/__init__.py 5 | ./paddlepalm/_downloader.py 6 | ./paddlepalm/downloader.py 7 | ./paddlepalm/multihead_trainer.py 8 | ./paddlepalm/trainer.py 9 | ./paddlepalm/backbone/__init__.py 10 | ./paddlepalm/backbone/base_backbone.py 11 | ./paddlepalm/backbone/bert.py 12 | ./paddlepalm/backbone/ernie.py 13 | ./paddlepalm/backbone/utils/__init__.py 14 | ./paddlepalm/backbone/utils/transformer.py 15 | ./paddlepalm/distribute/__init__.py 16 | ./paddlepalm/distribute/reader.py 17 | ./paddlepalm/head/__init__.py 18 | ./paddlepalm/head/base_head.py 19 | ./paddlepalm/head/cls.py 20 | ./paddlepalm/head/match.py 21 | ./paddlepalm/head/mlm.py 22 | ./paddlepalm/head/mrc.py 23 | ./paddlepalm/head/ner.py 24 | ./paddlepalm/lr_sched/__init__.py 25 | ./paddlepalm/lr_sched/base_schedualer.py 26 | ./paddlepalm/lr_sched/slanted_triangular_schedualer.py 27 | ./paddlepalm/lr_sched/warmup_schedualer.py 28 | ./paddlepalm/optimizer/__init__.py 29 | ./paddlepalm/optimizer/adam.py 30 | ./paddlepalm/optimizer/base_optimizer.py 31 | ./paddlepalm/reader/__init__.py 32 | 
./paddlepalm/reader/base_reader.py 33 | ./paddlepalm/reader/cls.py 34 | ./paddlepalm/reader/match.py 35 | ./paddlepalm/reader/mlm.py 36 | ./paddlepalm/reader/mrc.py 37 | ./paddlepalm/reader/seq_label.py 38 | ./paddlepalm/reader/utils/__init__.py 39 | ./paddlepalm/reader/utils/batching4bert.py 40 | ./paddlepalm/reader/utils/batching4ernie.py 41 | ./paddlepalm/reader/utils/mlm_batching.py 42 | ./paddlepalm/reader/utils/mrqa_helper.py 43 | ./paddlepalm/reader/utils/reader4ernie.py 44 | ./paddlepalm/tokenizer/__init__.py 45 | ./paddlepalm/tokenizer/bert_tokenizer.py 46 | ./paddlepalm/tokenizer/ernie_tokenizer.py 47 | ./paddlepalm/utils/__init__.py 48 | ./paddlepalm/utils/basic_helper.py 49 | ./paddlepalm/utils/config_helper.py 50 | ./paddlepalm/utils/plot_helper.py 51 | ./paddlepalm/utils/print_helper.py 52 | ./paddlepalm/utils/reader_helper.py 53 | ./paddlepalm/utils/saver.py 54 | ./paddlepalm/utils/textprocess_helper.py 55 | paddlepalm.egg-info/PKG-INFO 56 | paddlepalm.egg-info/SOURCES.txt 57 | paddlepalm.egg-info/dependency_links.txt 58 | paddlepalm.egg-info/not-zip-safe 59 | paddlepalm.egg-info/requires.txt 60 | paddlepalm.egg-info/top_level.txt -------------------------------------------------------------------------------- /paddlepalm.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /paddlepalm.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /paddlepalm.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | paddlepaddle-gpu>=1.7.0 2 | -------------------------------------------------------------------------------- /paddlepalm.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | paddlepalm 2 | -------------------------------------------------------------------------------- /paddlepalm/__init__.py: -------------------------------------------------------------------------------- 1 | from . import downloader 2 | # from mtl_controller import Controller 3 | #import controller 4 | from . import optimizer 5 | from . import lr_sched 6 | from . import backbone 7 | from . import reader 8 | from . import head 9 | 10 | 11 | from .trainer import Trainer 12 | from .multihead_trainer import MultiHeadTrainer 13 | 14 | #del interface 15 | #del task_instance 16 | #del default_settings 17 | #del utils 18 | -------------------------------------------------------------------------------- /paddlepalm/_downloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from __future__ import print_function 17 | import os 18 | import tarfile 19 | import shutil 20 | from collections import OrderedDict 21 | import sys 22 | import urllib 23 | URLLIB=urllib 24 | if sys.version_info >= (3, 0): 25 | import urllib.request 26 | URLLIB=urllib.request 27 | 28 | __all__ = ["download", "ls"] 29 | 30 | _pretrain = (('RoBERTa-zh-base', 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_ext_L-12_H-768_A-12.tar.gz'), 31 | ('RoBERTa-zh-large', 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_large_ext_L-24_H-1024_A-16.tar.gz'), 32 | ('ERNIE-v2-en-base', 'https://ernie.bj.bcebos.com/ERNIE_Base_en_stable-2.0.0.tar.gz'), 33 | ('ERNIE-v2-en-large', 'https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz'), 34 | ('XLNet-cased-base','https://xlnet.bj.bcebos.com/xlnet_cased_L-12_H-768_A-12.tgz'), 35 | ('XLNet-cased-large','https://xlnet.bj.bcebos.com/xlnet_cased_L-24_H-1024_A-16.tgz'), 36 | ('ERNIE-v1-zh-base','https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz'), 37 | ('ERNIE-v1-zh-base-max-len-512','https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz'), 38 | ('BERT-en-uncased-large-whole-word-masking','https://bert-models.bj.bcebos.com/wwm_uncased_L-24_H-1024_A-16.tar.gz'), 39 | ('BERT-en-cased-large-whole-word-masking','https://bert-models.bj.bcebos.com/wwm_cased_L-24_H-1024_A-16.tar.gz'), 40 | ('BERT-en-uncased-base', 'https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz'), 41 | ('BERT-en-uncased-large', 'https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz'), 42 | ('BERT-en-cased-base','https://bert-models.bj.bcebos.com/cased_L-12_H-768_A-12.tar.gz'), 43 | ('BERT-en-cased-large','https://bert-models.bj.bcebos.com/cased_L-24_H-1024_A-16.tar.gz'), 44 | ('BERT-multilingual-uncased-base','https://bert-models.bj.bcebos.com/multilingual_L-12_H-768_A-12.tar.gz'), 45 | ('BERT-multilingual-cased-base','https://bert-models.bj.bcebos.com/multi_cased_L-12_H-768_A-12.tar.gz'), 46 | ('BERT-zh-base','https://bert-models.bj.bcebos.com/chinese_L-12_H-768_A-12.tar.gz'), 47 | ('utils', None)) 48 | _vocab = (('utils', None),('utils', None)) 49 | _backbone =(('utils', None),('utils', None)) 50 | _head = (('utils', None),('utils', None)) 51 | _reader = (('utils', None),('utils', None)) 52 | 53 | _items = (('pretrain', OrderedDict(_pretrain)), 54 | ('vocab', OrderedDict(_vocab)), 55 | ('backbone', OrderedDict(_backbone)), 56 | ('head', OrderedDict(_head)), 57 | ('reader', OrderedDict(_reader)) 58 | ) 59 | _items = OrderedDict(_items) 60 | 61 | def _download(item, scope, path, silent=False, convert=False): 62 | data_url = _items[item][scope] 63 | if data_url == None: 64 | return 65 | if not silent: 66 | print('Downloading {}: {} from {}...'.format(item, scope, data_url)) 67 | data_dir = path + '/' + item + '/' + scope 68 | if not os.path.exists(data_dir): 69 | os.makedirs(os.path.join(data_dir)) 70 | data_name = data_url.split('/')[-1] 71 | filename = data_dir + '/' + data_name 72 | 73 | # print process 74 | def _reporthook(count, chunk_size, total_size): 75 | bytes_so_far = count * chunk_size 76 | percent = float(bytes_so_far) / float(total_size) 77 | if percent > 1: 78 | percent = 1 79 | if not silent: 80 | print('\r>> Downloading... 
{:.1%}'.format(percent), end = "") 81 | 82 | URLLIB.urlretrieve(data_url, filename, reporthook=_reporthook) 83 | if not silent: 84 | print(' done!') 85 | 86 | if item == 'pretrain': 87 | if not silent: 88 | print ('Extracting {}...'.format(data_name), end=" ") 89 | if os.path.exists(filename): 90 | tar = tarfile.open(filename, 'r') 91 | tar.extractall(path = data_dir) 92 | tar.close() 93 | os.remove(filename) 94 | if len(os.listdir(data_dir))==1: 95 | source_path = data_dir + '/' + data_name.split('.')[0] 96 | fileList = os.listdir(source_path) 97 | for file in fileList: 98 | filePath = os.path.join(source_path, file) 99 | shutil.move(filePath, data_dir) 100 | os.removedirs(source_path) 101 | if not silent: 102 | print ('done!') 103 | if convert: 104 | if not silent: 105 | print ('Converting params...', end=" ") 106 | _convert(data_dir, silent) 107 | if not silent: 108 | print ('done!') 109 | 110 | 111 | def _convert(path, silent=False): 112 | if os.path.isfile(path + '/params/__palminfo__'): 113 | if not silent: 114 | print ('already converted.') 115 | else: 116 | if os.path.exists(path + '/params/'): 117 | os.rename(path + '/params/', path + '/params1/') 118 | os.mkdir(path + '/params/') 119 | tar_model = tarfile.open(path + '/params/' + '__palmmodel__', 'w') 120 | tar_info = open(path + '/params/'+ '__palminfo__', 'w') 121 | for root, dirs, files in os.walk(path + '/params1/'): 122 | for file in files: 123 | src_file = os.path.join(root, file) 124 | tar_model.add(src_file, '__paddlepalm_' + file) 125 | tar_info.write('__paddlepalm_' + file) 126 | os.remove(src_file) 127 | tar_model.close() 128 | tar_info.close() 129 | os.removedirs(path + '/params1/') 130 | 131 | def download(item, scope='all', path='.'): 132 | """download an item. The available scopes and contained items can be showed with `paddlepalm.downloader.ls`. 133 | 134 | Args: 135 | item: the item to download. 136 | scope: the scope of the item to download. 137 | path: the target dir to download to. Default is `.`, means current dir. 138 | """ 139 | # item = item.lower() 140 | # scope = scope.lower() 141 | assert item in _items, '{} is not found. Support list: {}'.format(item, list(_items.keys())) 142 | 143 | if _items[item]['utils'] is not None: 144 | _download(item, 'utils', path, silent=True) 145 | 146 | if scope != 'all': 147 | assert scope in _items[item], '{} is not found. Support scopes: {}'.format(scope, list(_items[item].keys())) 148 | _download(item, scope, path) 149 | else: 150 | for s in _items[item].keys(): 151 | _download(item, s, path) 152 | 153 | 154 | def _ls(item, scope, l = 10): 155 | if scope != 'all': 156 | assert scope in _items[item], '{} is not found. Support scopes: {}'.format(scope, list(_items[item].keys())) 157 | print ('{}'.format(scope)) 158 | else: 159 | for s in _items[item].keys(): 160 | if s == 'utils': 161 | continue 162 | print (' => '+s) 163 | 164 | def ls(item='all', scope='all'): 165 | 166 | if scope == 'utils': 167 | return 168 | if item != 'all': 169 | assert item in _items, '{} is not found. 
Support items: {}'.format(item, list(_items.keys())) 170 |         print ('Available {} items:'.format(item)) 171 |         _ls(item, scope) 172 |     else: 173 |         l = max(map(len, _items.keys())) 174 |         for i in _items.keys(): 175 |             print ('Available {} items: '.format(i)) 176 |             _ls(i, scope, l) 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /paddlepalm/backbone/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/PALM/2555c0e2a5fab1b702ae8d1c7612bef48c65af38/paddlepalm/backbone/README.md -------------------------------------------------------------------------------- /paddlepalm/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .ernie import ERNIE 3 | from .bert import BERT 4 | 5 | -------------------------------------------------------------------------------- /paddlepalm/backbone/base_backbone.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | #     http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | class Backbone(object): 18 |     """interface of backbone model.""" 19 | 20 |     def __init__(self, phase): 21 |         """Constructs the backbone. The constructor must accept at least a `phase` argument, and subclass constructors must call this base constructor so the framework-internal member variables get created. 22 |         Args: 23 |             phase: str. The running stage the backbone is invoked in; the training stage 'train' and the prediction stage 'predict' are currently supported. 24 |         """ 25 | 26 |     @property 27 |     def inputs_attr(self): 28 |         """Describes the attributes of the input objects the backbone needs from the reader: the name, shape and dtype of each object. When an object has a scalar data type (str, int, float, etc.), its shape is the empty list []; when some dimension of an object has variable length, that dimension is set to -1. 29 | 30 |         Return: 31 |             dict. The attribute description of each input object. For example, for text classification and matching tasks, the reader objects the BERT backbone relies on mainly include 32 |             {"token_ids": ([-1, max_len], 'int64'), 33 |             "input_ids": ([-1, max_len], 'int64'), 34 |             "segment_ids": ([-1, max_len], 'int64'), 35 |             "input_mask": ([-1, max_len], 'float32')}""" 36 |         raise NotImplementedError() 37 | 38 |     @property 39 |     def outputs_attr(self): 40 |         """Describes the attributes of the backbone's output objects: the name, shape and dtype of each object. When an object has a scalar data type (str, int, float, etc.), its shape is the empty list []; when some dimension of an object has variable length, that dimension is set to -1. 41 | 42 |         Return: 43 |             dict. The attribute description of each output object. For example, for text classification and matching tasks, the outputs of the BERT backbone may include 44 |             {"word_emb": ([-1, max_seqlen, word_emb_size], 'float32'), 45 |             "sentence_emb": ([-1, hidden_size], 'float32'), 46 |             "sim_vec": ([-1, hidden_size], 'float32')}""" 47 |         raise NotImplementedError() 48 | 49 |     def build(self, inputs): 50 |         """Builds the computation graph of the backbone: maps static-graph input Variables matching inputs_attr to static-graph output Variables matching outputs_attr. 51 |         Args: 52 |             inputs: dict. Maps the object names in inputs_attr to computation-graph Variables; inputs contains at least the objects declared in inputs_attr. 53 |         Return: 54 |             The graph variables to output. The output objects are added to the fetch_list, so their runtime values are available at every training/inference step; those values are then passed to the postprocess method for user-side handling. 55 |         """ 56 |         raise NotImplementedError() 57 | 
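To make the contract concrete, here is a minimal hypothetical subclass: a toy bag-of-embeddings backbone. Every name and size in it is illustrative only, not part of the library; the output keys `word_embedding`/`sentence_embedding` follow what the built-in heads (e.g. `Classify`) read from the backbone.

```python
import paddle.fluid as fluid
from paddlepalm.backbone.base_backbone import Backbone

class ToyBackbone(Backbone):
    """a toy bag-of-embeddings backbone, for illustration only."""
    def __init__(self, vocab_size, emb_size=128, max_len=128, phase='train'):
        super(ToyBackbone, self).__init__(phase)
        self._vocab_size = vocab_size
        self._emb_size = emb_size
        self._max_len = max_len

    @property
    def inputs_attr(self):
        return {"token_ids": [[-1, self._max_len], 'int64']}

    @property
    def outputs_attr(self):
        return {"word_embedding": [[-1, self._max_len, self._emb_size], 'float32'],
                "sentence_embedding": [[-1, self._emb_size], 'float32']}

    def build(self, inputs):
        # embed tokens, then mean-pool over the sequence dimension
        tok_emb = fluid.embedding(inputs['token_ids'], size=[self._vocab_size, self._emb_size])
        sent_emb = fluid.layers.reduce_mean(tok_emb, dim=1)
        return {"word_embedding": tok_emb, "sentence_embedding": sent_emb}
```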
-------------------------------------------------------------------------------- /paddlepalm/backbone/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/PALM/2555c0e2a5fab1b702ae8d1c7612bef48c65af38/paddlepalm/backbone/utils/__init__.py -------------------------------------------------------------------------------- /paddlepalm/distribute/__init__.py: -------------------------------------------------------------------------------- 1 | from paddle import fluid 2 | import os 3 | import multiprocessing 4 | 5 | gpu_dev_count = int(fluid.core.get_cuda_device_count()) 6 | cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) 7 | 8 | from .reader import yield_pieces, data_feeder, decode_fake 9 | 10 | -------------------------------------------------------------------------------- /paddlepalm/distribute/reader.py: -------------------------------------------------------------------------------- 1 | 2 | from . import gpu_dev_count, cpu_dev_count 3 | try: 4 | import queue as Queue 5 | except ImportError: 6 | import Queue 7 | from threading import Thread 8 | 9 | dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count 10 | 11 | def yield_pieces(data, distribute_strategy, batch_size): 12 | """ 13 | Args: 14 | distribute_strategy: support s=split, c=copy, u=unstack, 15 | """ 16 | assert batch_size % dev_count == 0, "batch_size need to be integer times larger than dev_count." 17 | # print('data in yield pieces') 18 | # print(len(data)) 19 | 20 | assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)] 21 | assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)] 22 | if isinstance(data, dict): 23 | keys = list(data.keys()) 24 | data_list = [data[i] for i in keys] 25 | ds_list = [distribute_strategy[i] for i in keys] 26 | else: 27 | assert isinstance(data, list), "the input data must be a list or dict, and contained with multiple tensors." 28 | data_list = data 29 | ds_list = distribute_strategy 30 | stride = batch_size // dev_count 31 | p = stride 32 | # while p < len(data_list) + stride: 33 | while p <= batch_size: 34 | temp = [] 35 | for d, s in zip(data_list, ds_list): 36 | s = s.strip().lower() 37 | if s == 's' or s == 'split': 38 | if p - stride >= len(d): 39 | # print('WARNING: no more examples to feed empty devices') 40 | temp = [] 41 | return 42 | temp.append(d[p-stride:p]) 43 | elif s == 'u' or s == 'unstack': 44 | assert len(d) <= dev_count, 'Tensor size on dim 0 must be less equal to dev_count when unstack is applied.' 
45 |                 if p//stride > len(d): 46 |                     # print('WARNING: no more examples to feed empty devices') 47 |                     return 48 |                 temp.append(d[p//stride-1]) 49 |             elif s == 'c' or s == 'copy': 50 |                 temp.append(d) 51 |             else: 52 |                 raise NotImplementedError() 53 | 54 |         p += stride 55 |         if type(data) == dict: 56 |             yield dict(zip(*[keys, temp])) 57 |         else: 58 |             # print('yielded pieces') 59 |             # print(len(temp)) 60 |             yield temp 61 | 62 | 63 | def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train', is_multi=False): 64 |     if postprocess_fn is None: 65 |         def postprocess_fn(batch, id=-1, phase='train', is_multi=False): 66 |             return batch 67 | 68 |     def worker(reader, dev_count, queue): 69 |         dev_batches = [] 70 |         for index, data in enumerate(reader()): 71 |             if len(dev_batches) < dev_count: 72 |                 dev_batches.append(data) 73 |             if len(dev_batches) == dev_count: 74 |                 queue.put((dev_batches, 0)) 75 |                 dev_batches = [] 76 |         # For the prediction of the remained batches, pad more batches to 77 |         # the number of devices and the padded samples would be removed in 78 |         # prediction outputs. 79 |         if len(dev_batches) > 0: 80 |             num_pad = dev_count - len(dev_batches) 81 |             for i in range(len(dev_batches), dev_count): 82 |                 dev_batches.append(dev_batches[-1]) 83 |             queue.put((dev_batches, num_pad)) 84 |         queue.put(None) 85 | 86 |     queue = Queue.Queue(dev_count*prefetch_steps) 87 |     p = Thread( 88 |         target=worker, args=(reader, dev_count, queue)) 89 |     p.daemon = True 90 |     p.start() 91 |     while True: 92 |         ret = queue.get() 93 |         queue.task_done() 94 |         if ret is not None: 95 |             batches, num_pad = ret 96 |             if dev_count > 1 and phase == 'train' and is_multi: 97 |                 id = batches[0]['__task_id'][0] 98 |             else: 99 |                 id = -1 100 |             batch_buf = [] 101 |             flag_buf = [] 102 |             for idx, batch in enumerate(batches): 103 |                 # flag = num_pad == 0 104 |                 flag = idx-len(batches) < -num_pad 105 |                 # if num_pad > 0: 106 |                 #     num_pad -= 1 107 |                 batch = postprocess_fn(batch, id, phase, is_multi=is_multi) 108 |                 # batch = postprocess_fn(batch) 109 |                 batch_buf.append(batch) 110 |                 flag_buf.append(flag) 111 |             yield batch_buf, flag_buf 112 |         else: 113 |             break 114 |     queue.join() 115 | 116 | 117 | 118 | def decode_fake(nums, mask, bs): 119 |     """Return how many padded (fake) results to strip from a fetch of `nums` predictions, given the per-device realness flags in `mask` and the global batch size `bs`.""" 120 |     bs //= dev_count 121 |     n_t = 0 122 |     for flag in mask: 123 |         if not flag: 124 |             break 125 |         n_t = n_t + 1 126 | 127 |     n_f = len(mask) - n_t 128 |     p1 = nums - (n_t-1) * bs 129 |     assert p1 % (n_f+1) == 0 130 |     each_f = p1 // (n_f+1) 131 |     return each_f * n_f 132 | 133 | 
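For intuition, a worked example of `decode_fake` (assuming the module-level `dev_count` is 4 at import time):

```python
# suppose 4 devices and global batch size 32 (8 per device), and the final
# round had only 2 real batches, so data_feeder padded it with 2 copies of
# the last real batch: mask = [True, True, False, False].
# if that last real batch held 5 examples, the fetch returns
# nums = 8 + 5 + 5 + 5 = 23 results.
#   n_t = 2 (leading True flags), n_f = 2
#   p1  = 23 - (2-1)*8 = 15      # results not covered by full real batches
#   each_f = 15 // (2+1) = 5     # size shared by the last real batch and each fake
# -> 5 * 2 = 10 fake results to drop, leaving 13 real predictions
print(decode_fake(nums=23, mask=[True, True, False, False], bs=32))  # 10
```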
-------------------------------------------------------------------------------- /paddlepalm/downloader.py: -------------------------------------------------------------------------------- 1 | from ._downloader import * 2 | -------------------------------------------------------------------------------- /paddlepalm/head/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .cls import Classify 3 | from .match import Match 4 | from .ner import SequenceLabel 5 | from .mrc import MRC 6 | from .mlm import MaskLM 7 | -------------------------------------------------------------------------------- /paddlepalm/head/base_head.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | #     http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import json 18 | import copy 19 | 20 | class Head(object): 21 | 22 |     def __init__(self, phase='train'): 23 |         """Constructs the task head. The constructor must accept at least a `phase` argument, and implementations must call this base constructor so the framework-internal member variables get created. 24 |         Args: 25 |             phase: str. The running stage the head is invoked in; the training stage 'train' and the prediction stage 'predict' are currently supported. 26 |         """ 27 |         self._stop_gradient = {} 28 |         self._phase = phase 29 |         self._prog = None 30 |         self._results_buffer = [] 31 | 32 |     @property 33 |     def inputs_attrs(self): 34 |         """Step-level declaration of the task inputs. 35 | 36 |         Describes the output objects of the reader, the backbone and other task heads that this head depends on (fetched once per step). It is a dict whose keys are the components the objects come from (e.g. 'reader', 'backbone') and whose values are the object sets required from those components. Each object set is itself a dict that maps an object name (which must exist in the output set of the component) to its shape and dtype; when some dimension of an object has variable length, that dimension is set to -1. 37 | 38 |         Return: 39 |             dict. The step-level inputs this head depends on, i.e. the outputs of the other components.""" 40 |         raise NotImplementedError() 41 | 42 |     @property 43 |     def outputs_attr(self): 44 |         """Step-level declaration of the task outputs. 45 | 46 |         Describes the objects this head produces (once per step): the name, shape and dtype of each. The output objects are added to the fetch_list, so their realtime values are available at every training/inference step and can be passed to batch_postprocess for per-step postprocessing. When an object has a scalar data type (str, int, float, etc.), its shape is the empty list []; when some dimension has variable length, that dimension is set to -1. 47 | 48 |         Return: 49 |             dict. The output objects this head produces. Note that during training an output named loss is required. 50 |         """ 51 | 52 |         raise NotImplementedError() 53 | 54 |     @property 55 |     def epoch_inputs_attrs(self): 56 |         """Epoch-level declaration of the task inputs. 57 | 58 |         Describes the output objects of the reader, the backbone and other task heads that this task depends on (produced once after each epoch), such as the complete sample set or the number of valid samples. It is a dict whose keys are the components the objects come from (e.g. 'reader', 'backbone') and whose values are the object sets required from those components; each object set maps an object name (which must exist in the output set of the component) to its shape and dtype, with variable-length dimensions set to -1. 59 | 60 |         Return: 61 |             dict. The epoch-level inputs this head depends on. 62 |         """ 63 |         return {} 64 | 65 |     def build(self, inputs, scope_name=""): 66 |         """Builds the computation graph of the head. 67 | 68 |         Maps the static-graph Variables from the object sets matching inputs_attrs to static-graph output Variables matching outputs_attr. 69 | 70 |         Args: 71 |             inputs: dict. Maps the object names in inputs_attrs to computation-graph Variables; inputs contains at least the objects declared in inputs_attrs. 72 |         Return: 73 |             The graph variables to output. The output objects are added to the fetch_list, so their runtime values are available at every training/inference step; those values are then passed to the postprocess methods for user-side handling. 74 |         """ 75 |         raise NotImplementedError() 76 | 77 |     def batch_postprocess(self, rt_outputs): 78 |         """Batch/step-level postprocessing. 79 | 80 |         Handles the realtime values of this head's outputs on the current batch after each training or inference step. 81 |         By default the results are stored into the buffer self._results_buffer.""" 82 |         if isinstance(rt_outputs, dict): 83 |             keys = rt_outputs.keys() 84 |             vals = [rt_outputs[k] for k in keys] 85 |             lens = [len(v) for v in vals] 86 |             if len(set(lens)) == 1: 87 |                 results = [dict(zip(*[keys, i])) for i in zip(*vals)] 88 |                 self._results_buffer.extend(results) 89 |                 return results 90 |             else: 91 |                 print('WARNING: irregular output results. visualize failed.') 92 |                 self._results_buffer.append(rt_outputs) 93 |                 return None 94 | 95 |     def reset(self): 96 |         """Clears this head's buffer (the results accumulated during training or inference).""" 97 |         self._results_buffer = [] 98 | 99 |     def get_results(self): 100 |         """Returns the results this head has accumulated so far.""" 101 |         return copy.deepcopy(self._results_buffer) 102 | 103 |     def epoch_postprocess(self, post_inputs=None, output_dir=None): 104 |         """Epoch-level postprocessing. 105 | 106 |         After each training or inference epoch, postprocesses the accumulated per-sample results. By default, when output_dir is None the accumulated results are returned directly; when output_dir is given, the results are written into that directory (one json per line), with the head's phase used as the file name. 107 | 108 |         Args: 109 |             post_inputs: carries the contents of the declared epoch-level inputs when epoch_inputs_attrs is non-empty. 110 |             output_dir: the directory to save the accumulated results to. 111 |         """ 112 |         if output_dir is not None: 113 |             if not os.path.exists(output_dir): 114 |                 os.makedirs(output_dir) 115 |             with open(os.path.join(output_dir, self._phase), 'w') as writer: 116 |                 for i in self._results_buffer: 117 |                     writer.write(json.dumps(i)+'\n') 118 |         else: 119 |             return self._results_buffer 120 | 
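A minimal hypothetical subclass to illustrate the contract: a one-dimensional regression head. Names and shapes are illustrative, not a library component, and it assumes a reader that supplies float `label_ids`, which the built-in readers may not provide.

```python
import paddle.fluid as fluid
from paddlepalm.head.base_head import Head

class ToyRegression(Head):
    """a toy regression head, for illustration only."""
    def __init__(self, input_dim, phase='train'):
        super(ToyRegression, self).__init__(phase=phase)
        self._is_training = phase == 'train'
        self._input_dim = input_dim

    @property
    def inputs_attrs(self):
        reader = {'label_ids': [[-1], 'float32']} if self._is_training else {}
        bb = {'sentence_embedding': [[-1, self._input_dim], 'float32']}
        return {'reader': reader, 'backbone': bb}

    @property
    def outputs_attr(self):
        if self._is_training:
            return {'loss': [[1], 'float32']}   # training must expose 'loss'
        return {'score': [[-1, 1], 'float32']}

    def build(self, inputs, scope_name=''):
        sent_emb = inputs['backbone']['sentence_embedding']
        score = fluid.layers.fc(input=sent_emb, size=1)
        if self._is_training:
            label = fluid.layers.reshape(inputs['reader']['label_ids'], [-1, 1])
            loss = fluid.layers.mean(fluid.layers.square_error_cost(score, label))
            return {'loss': loss}
        return {'score': score}
```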
103 |         self._results_buffer.append(rt_outputs)
104 |         return None
105 | 
106 |     def reset(self):
107 |         """Clear this head's buffer (the results accumulated during training or inference)."""
108 |         self._results_buffer = []
109 | 
110 |     def get_results(self):
111 |         """Return a copy of the results accumulated by this head."""
112 |         return copy.deepcopy(self._results_buffer)
113 | 
114 |     def epoch_postprocess(self, post_inputs=None, output_dir=None):
115 |         """Epoch-level postprocessing.
116 | 
117 |         Called at the end of every training or inference epoch to postprocess the accumulated per-example results. By default, when
118 |         output_dir is None, the accumulated results are returned directly; when output_dir is given, the results are written into that directory, with the head's phase used as the file name.
119 | 
120 |         Args:
121 |             post_inputs: when the declared epoch_inputs_attrs is non-empty, this argument carries the contents of the corresponding inputs.
122 |             output_dir: the directory to save the accumulated results to.
123 |         """
124 |         if output_dir is not None:
125 |             if not os.path.exists(output_dir):
126 |                 os.makedirs(output_dir)
127 |             with open(os.path.join(output_dir, self._phase), 'w') as writer:
128 |                 for i in self._results_buffer:
129 |                     writer.write(json.dumps(i)+'\n')
130 |         else:
131 |             return self._results_buffer
132 | 
133 | 
-------------------------------------------------------------------------------- /paddlepalm/head/cls.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
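# NOTE (added commentary): Classify below is a concrete Head (see base_head.py
# above). inputs_attrs declares what it needs from the reader (label_ids when
# training) and from the backbone (sentence_embedding); build() maps those
# Variables to a loss (train) or to logits/probs (predict); batch_postprocess()
# and epoch_postprocess() collect and dump the per-step fetch results.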
15 | 
16 | import paddle.fluid as fluid
17 | from paddle.fluid import layers
18 | from paddlepalm.head.base_head import Head
19 | import numpy as np
20 | import os
21 | import json
22 | 
23 | 
24 | class Classify(Head):
25 |     """
26 |     classification
27 |     """
28 |     def __init__(self, num_classes, input_dim, dropout_prob=0.0, \
29 |                  param_initializer_range=0.02, phase='train'):
30 | 
31 |         self._is_training = phase == 'train'
32 |         self._hidden_size = input_dim
33 | 
34 |         self.num_classes = num_classes
35 | 
36 |         self._dropout_prob = dropout_prob if phase == 'train' else 0.0
37 |         self._param_initializer = fluid.initializer.TruncatedNormal(
38 |             scale=param_initializer_range)
39 |         self._preds = []
40 |         self._probs = []
41 | 
42 |     @property
43 |     def inputs_attrs(self):
44 |         reader = {}
45 |         bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']}
46 |         if self._is_training:
47 |             reader["label_ids"] = [[-1], 'int64']
48 |         return {'reader': reader, 'backbone': bb}
49 | 
50 |     @property
51 |     def outputs_attrs(self):
52 |         if self._is_training:
53 |             return {'loss': [[1], 'float32']}
54 |         else:
55 |             return {'logits': [[-1, self.num_classes], 'float32'],
56 |                     'probs': [[-1, self.num_classes], 'float32']}
57 | 
58 | 
59 |     def build(self, inputs, scope_name=''):
60 |         sent_emb = inputs['backbone']['sentence_embedding']
61 |         if self._is_training:
62 |             label_ids = inputs['reader']['label_ids']
63 |         cls_feats = fluid.layers.dropout(
64 |             x=sent_emb,
65 |             dropout_prob=self._dropout_prob,
66 |             dropout_implementation="upscale_in_train")
67 | 
68 |         logits = fluid.layers.fc(
69 |             input=cls_feats,  # feed the dropout-regularized features, not the raw sentence embedding
70 |             size=self.num_classes,
71 |             param_attr=fluid.ParamAttr(
72 |                 name=scope_name+"cls_out_w",
73 |                 initializer=self._param_initializer),
74 |             bias_attr=fluid.ParamAttr(
75 |                 name=scope_name+"cls_out_b", initializer=fluid.initializer.Constant(0.)))
76 |         probs = fluid.layers.softmax(logits)
77 |         if self._is_training:
78 |             loss = fluid.layers.cross_entropy(
79 |                 input=probs, label=label_ids)
80 |             loss = layers.mean(loss)
81 |             return {"loss": loss}
82 |         else:
83 |             return {"logits":logits,
84 |                     "probs":probs}
85 | 
86 |     def batch_postprocess(self, rt_outputs):
87 |         if not self._is_training:
88 |             logits = rt_outputs['logits']
89 |             probs = rt_outputs['probs']
90 |             self._preds.extend(logits.tolist())
91 |             self._probs.extend(probs.tolist())
92 | 
93 | 
94 |     def epoch_postprocess(self, post_inputs, output_dir=None):
95 |         # post_inputs is unused here: no epoch-level inputs are declared in epoch_inputs_attrs
96 |         if not self._is_training:
97 |             results = []
98 |             for i in range(len(self._preds)):
99 |                 label = int(np.argmax(np.array(self._preds[i])))
100 |                 result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
101 |                 results.append(result)
102 |             if output_dir is not None:
103 |                 with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
104 |                     for result in results:
105 |                         result = json.dumps(result)
106 |                         writer.write(result+'\n')
107 |                 print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
108 |             return results
109 | 
110 | 
111 | 
-------------------------------------------------------------------------------- /paddlepalm/head/mlm.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | import paddle.fluid as fluid
17 | from paddlepalm.head.base_head import Head
18 | from paddle.fluid import layers
19 | import numpy as np
20 | import os, json  # json is needed by epoch_postprocess below
21 | from paddlepalm.backbone.utils.transformer import pre_process_layer
22 | 
23 | class MaskLM(Head):
24 |     '''
25 |     mlm
26 |     '''
27 |     def __init__(self, input_dim, vocab_size, hidden_act, dropout_prob=0.0, \
28 |                  param_initializer_range=0.02, phase='train'):
29 |         self._is_training = phase == 'train'
30 |         self._emb_size = input_dim
31 |         self._hidden_size = input_dim
32 |         self._dropout_prob = dropout_prob if phase == 'train' else 0.0
33 |         self._preds = []
34 | 
35 |         self._vocab_size = vocab_size
36 |         self._hidden_act = hidden_act
37 |         self._initializer_range = param_initializer_range
38 | 
39 |     @property
40 |     def inputs_attrs(self):
41 |         reader = {
42 |             "mask_label": [[-1], 'int64'],
43 |             "mask_pos": [[-1], 'int64'],
44 |         }
45 |         if not self._is_training:
46 |             del reader['mask_label']
47 |         bb = {
48 |             "encoder_outputs": [[-1, -1, self._hidden_size], 'float32'],
49 |             "embedding_table": [[-1, self._vocab_size, self._emb_size], 'float32']}
50 |         return {'reader': reader, 'backbone': bb}
51 | 
52 |     @property
53 |     def outputs_attrs(self):
54 |         if self._is_training:
55 |             return {"loss": [[1], 'float32']}
56 |         else:
57 |             return {"logits": [[-1], 'float32']}
58 | 
59 |     def build(self, inputs, scope_name=""):
60 |         mask_pos = inputs["reader"]["mask_pos"]
61 | 
62 |         word_emb = inputs["backbone"]["embedding_table"]
63 |         enc_out = inputs["backbone"]["encoder_outputs"]
64 | 
65 |         if self._is_training:
66 |             mask_label = inputs["reader"]["mask_label"]
67 |             l1 = enc_out.shape[0]
68 |             l2 = enc_out.shape[1]
69 |             bxs = fluid.layers.fill_constant(shape=[1], value=l1*l2, dtype='int64')
70 |             max_position = bxs - 1
71 |             mask_pos = fluid.layers.elementwise_min(mask_pos, max_position)  # clamp to the last valid flattened position (batch_size*seq_len - 1)
72 |             mask_pos.stop_gradient = True
73 | 
74 |         emb_size = word_emb.shape[-1]
75 | 
76 |         _param_initializer = fluid.initializer.TruncatedNormal(
77 |             scale=self._initializer_range)
78 | 
79 |         reshaped_emb_out = fluid.layers.reshape(
80 |             x=enc_out, shape=[-1, emb_size])
81 | 
82 |         # extract the features of the masked tokens
83 |         mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
84 | 
85 |         # transform: fc
86 |         mask_trans_feat = fluid.layers.fc(
87 |             input=mask_feat,
88 |             size=emb_size,
89 |             act=self._hidden_act,
90 |             param_attr=fluid.ParamAttr(
91 |                 name=scope_name+'mask_lm_trans_fc.w_0',
92 |                 initializer=_param_initializer),
93 |             bias_attr=fluid.ParamAttr(name=scope_name+'mask_lm_trans_fc.b_0'))
94 |         # transform: layer norm
95 |         mask_trans_feat = pre_process_layer(
96 |             mask_trans_feat, 'n', name=scope_name+'mask_lm_trans')
97 | 
98 |         mask_lm_out_bias_attr = fluid.ParamAttr(
99 |             name=scope_name+"mask_lm_out_fc.b_0",
100 |             initializer=fluid.initializer.Constant(value=0.0))
101 | 
102 |         fc_out = fluid.layers.matmul(
103 |             x=mask_trans_feat,
104 |             y=word_emb,
105 |             transpose_y=True)
106 |         fc_out += fluid.layers.create_parameter(
107 |             shape=[self._vocab_size],
108 |             dtype='float32',
109 |             attr=mask_lm_out_bias_attr,
110 |             is_bias=True)
111 | 
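        # NOTE (added commentary): the matmul above ties the output projection to
        # the input embedding table (weight tying): vocabulary logits are computed
        # by comparing each masked-position feature against every word embedding,
        # plus the per-word output bias added right above.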
112 |         if self._is_training:
113 |             probs = fluid.layers.softmax(fc_out)
114 |             mask_lm_loss = fluid.layers.cross_entropy(
115 |                 input=probs, label=mask_label)
116 |             loss = fluid.layers.mean(mask_lm_loss)
117 |             return {'loss': loss}
118 |         else:
119 |             return {'logits': fc_out}
120 | 
121 |     def batch_postprocess(self, rt_outputs):
122 |         if not self._is_training:
123 |             logits = rt_outputs['logits']
124 |             preds = np.argmax(logits, -1)
125 |             self._preds.extend(preds.tolist())
126 |             return preds
127 | 
128 |     def epoch_postprocess(self, post_inputs, output_dir=None):
129 |         # post_inputs is unused here: no epoch-level inputs are declared in epoch_inputs_attrs
130 |         if not self._is_training:
131 |             results = []
132 |             for i in range(len(self._preds)):
133 |                 result = {'index': i, 'word_id': self._preds[i]}
134 |                 results.append(result)
135 |             if output_dir is not None:
136 |                 with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
137 |                     for result in results:
138 |                         result = json.dumps(result)
139 |                         writer.write(result+'\n')
140 |                 print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))
141 |             return results
142 | 
143 | 
-------------------------------------------------------------------------------- /paddlepalm/head/ner.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | import paddle.fluid as fluid
17 | from paddle.fluid import layers
18 | from paddlepalm.head.base_head import Head
19 | import numpy as np
20 | import os
21 | import math
22 | 
23 | class SequenceLabel(Head):
24 |     '''
25 |     Sequence label
26 |     '''
27 |     def __init__(self, num_classes, input_dim, dropout_prob=0.0, learning_rate=1e-3, \
28 |                  param_initializer_range=0.02, phase='train'):
29 | 
30 |         """
31 |         Args:
32 |             num_classes: the number of sequence label classes.
33 |             phase: the running phase, train or predict.
34 | """ 35 | 36 | self._is_training = phase == 'train' 37 | self._hidden_size = input_dim 38 | 39 | self.num_classes = num_classes 40 | 41 | self._dropout_prob = dropout_prob if phase == 'train' else 0.0 42 | self._param_initializer = fluid.initializer.TruncatedNormal( 43 | scale=param_initializer_range) 44 | 45 | self.learning_rate = learning_rate 46 | self._preds = [] 47 | 48 | 49 | @property 50 | def inputs_attrs(self): 51 | reader = {} 52 | bb = {"encoder_outputs": [[-1, -1, -1], 'float32']} 53 | if self._is_training: 54 | reader["label_ids"] = [[-1, -1], 'int64'] 55 | reader["seq_lens"] = [[-1], 'int64'] 56 | return {'reader': reader, 'backbone': bb} 57 | 58 | @property 59 | def outputs_attrs(self): 60 | if self._is_training: 61 | return {'loss': [[1], 'float32']} 62 | else: 63 | return {'logits': [[-1, -1, self.num_classes], 'float32']} 64 | 65 | def build(self, inputs, scope_name=''): 66 | token_emb = inputs['backbone']['encoder_outputs'] 67 | if self._is_training: 68 | label_ids = inputs['reader']['label_ids'] 69 | seq_lens = inputs['reader']['seq_lens'] 70 | 71 | emission = fluid.layers.fc( 72 | size=self.num_classes, 73 | input=token_emb, 74 | param_attr=fluid.ParamAttr( 75 | initializer=self._param_initializer, 76 | regularizer=fluid.regularizer.L2DecayRegularizer( 77 | regularization_coeff=1e-4)), 78 | bias_attr=fluid.ParamAttr( 79 | name=scope_name+"cls_out_b", initializer=fluid.initializer.Constant(0.)), 80 | num_flatten_dims=2) 81 | 82 | if self._is_training: 83 | 84 | # compute loss 85 | crf_cost = fluid.layers.linear_chain_crf( 86 | input=emission, 87 | label=label_ids, 88 | param_attr=fluid.ParamAttr( 89 | name=scope_name+'crfw', learning_rate=self.learning_rate), 90 | length=seq_lens) 91 | 92 | avg_cost = fluid.layers.mean(x=crf_cost) 93 | crf_decode = fluid.layers.crf_decoding( 94 | input=emission, 95 | param_attr=fluid.ParamAttr(name=scope_name+'crfw'), 96 | length=seq_lens) 97 | 98 | (precision, recall, f1_score, num_infer_chunks, num_label_chunks, 99 | num_correct_chunks) = fluid.layers.chunk_eval( 100 | input=crf_decode, 101 | label=label_ids, 102 | chunk_scheme="IOB", 103 | num_chunk_types=int(math.ceil((self.num_classes - 1) / 2.0)), 104 | seq_length=seq_lens) 105 | chunk_evaluator = fluid.metrics.ChunkEvaluator() 106 | chunk_evaluator.reset() 107 | 108 | return {"loss": avg_cost} 109 | else: 110 | return {"logits": emission} 111 | 112 | def batch_postprocess(self, rt_outputs): 113 | if not self._is_training: 114 | emission = rt_outputs['emission'] 115 | preds = np.argmax(emission, -1) 116 | self._preds.extend(preds.tolist()) 117 | 118 | def epoch_postprocess(self, post_inputs, output_dir=None): 119 | # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs 120 | if not self._is_training: 121 | if output_dir is not None: 122 | with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: 123 | for p in self._preds: 124 | writer.write(str(p)+'\n') 125 | print('Predictions saved at '+os.path.join(output_dir, 'predictions.json')) 126 | return self._preds 127 | -------------------------------------------------------------------------------- /paddlepalm/lr_sched/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .slanted_triangular_schedualer import TriangularSchedualer 3 | from .warmup_schedualer import WarmupSchedualer 4 | 5 | -------------------------------------------------------------------------------- /paddlepalm/lr_sched/base_schedualer.py: 
--------------------------------------------------------------------------------
1 | 
2 | class Schedualer():
3 | 
4 |     def __init__(self):
5 |         self._prog = None
6 | 
7 |     def _set_prog(self, prog):
8 |         self._prog = prog
9 | 
10 |     def _build(self, learning_rate):
11 |         raise NotImplementedError()
12 | 
13 | 
-------------------------------------------------------------------------------- /paddlepalm/lr_sched/slanted_triangular_schedualer.py: --------------------------------------------------------------------------------
1 | from paddlepalm.lr_sched.base_schedualer import Schedualer
2 | from paddle import fluid
3 | 
4 | class TriangularSchedualer(Schedualer):
5 | 
6 |     """ Implementation of the slanted triangular learning rate schedule; see https://arxiv.org/pdf/1801.06146.pdf for details. Warms the learning rate up linearly from 0 to learning_rate over warmup_steps steps, then decays it linearly to 0 by num_train_steps."""
7 | 
8 |     def __init__(self, warmup_steps, num_train_steps):
9 |         """Create a new TriangularSchedualer object.
10 | 
11 |         Args:
12 |             warmup_steps: the learning rate will grow from 0 to the base learning rate over `warmup_steps` steps.
13 |             num_train_steps: the total number of training steps.
14 | 
15 |         """
16 |         Schedualer.__init__(self)
17 |         assert num_train_steps > warmup_steps > 0
18 |         self.warmup_steps = warmup_steps
19 |         self.num_train_steps = num_train_steps
20 | 
21 | 
22 |     def _build(self, learning_rate):
23 |         with self._prog._lr_schedule_guard():
24 |             lr = fluid.layers.tensor.create_global_var(
25 |                 shape=[1],
26 |                 value=0.0,
27 |                 dtype='float32',
28 |                 persistable=True,
29 |                 name="scheduled_learning_rate")
30 | 
31 |             global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
32 | 
33 |             with fluid.layers.control_flow.Switch() as switch:
34 |                 with switch.case(global_step < self.warmup_steps):
35 |                     warmup_lr = learning_rate * (global_step / self.warmup_steps)
36 |                     fluid.layers.tensor.assign(warmup_lr, lr)
37 |                 with switch.default():
38 |                     decayed_lr = fluid.layers.learning_rate_scheduler.polynomial_decay(
39 |                         learning_rate=learning_rate,
40 |                         decay_steps=self.num_train_steps,
41 |                         end_learning_rate=0.0,
42 |                         power=1.0,
43 |                         cycle=False)
44 |                     fluid.layers.tensor.assign(decayed_lr, lr)
45 | 
46 |         return lr
47 | 
48 | 
-------------------------------------------------------------------------------- /paddlepalm/lr_sched/warmup_schedualer.py: --------------------------------------------------------------------------------
1 | 
2 | from paddlepalm.lr_sched.base_schedualer import Schedualer
3 | import paddle.fluid as fluid
4 | 
5 | class WarmupSchedualer(Schedualer):
6 |     """ Applies linear warmup of the learning rate from 0 to learning_rate over warmup_steps steps, then keeps it constant at learning_rate."""
7 | 
8 |     def __init__(self, warmup_steps):
9 |         Schedualer.__init__(self)
10 |         self.warmup_steps = warmup_steps
11 | 
12 |     def _build(self, learning_rate):
13 | 
14 |         with self._prog._lr_schedule_guard():
15 |             lr = fluid.layers.tensor.create_global_var(
16 |                 shape=[1],
17 |                 value=0.0,
18 |                 dtype='float32',
19 |                 persistable=True,
20 |                 name="scheduled_learning_rate")
21 | 
22 |             global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
23 | 
24 |             with fluid.layers.control_flow.Switch() as switch:
25 |                 with switch.case(global_step < self.warmup_steps):
26 |                     warmup_lr = learning_rate * (global_step / self.warmup_steps)
27 |                     fluid.layers.tensor.assign(warmup_lr, lr)
28 |                 with switch.default():
29 |                     fluid.layers.tensor.assign(learning_rate, lr)
30 | 
31 |         return lr
32 | 
33 | 
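# A minimal usage sketch of the schedulers above combined with the Adam wrapper
# defined in paddlepalm/optimizer/adam.py below. `loss_var` is assumed to come
# from a trainer's forward build, and the framework normally calls _set_prog()
# and _build() itself, so this only illustrates how the pieces fit together:
#
#     import paddlepalm as palm
#
#     sched = palm.lr_sched.TriangularSchedualer(warmup_steps=1000, num_train_steps=10000)
#     adam = palm.optimizer.Adam(loss_var, lr=5e-5, lr_schedualer=sched)
#     # after the framework attaches the program via adam._set_prog(...):
#     param_grads = adam._build()  # builds the scheduled LR variable and the minimize op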
-------------------------------------------------------------------------------- /paddlepalm/optimizer/__init__.py: --------------------------------------------------------------------------------
1 | 
2 | from .adam import Adam
3 | 
-------------------------------------------------------------------------------- /paddlepalm/optimizer/adam.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Optimization and learning rate scheduling."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import numpy as np
22 | import paddle.fluid as fluid
23 | from paddlepalm.optimizer.base_optimizer import Optimizer
24 | 
25 | class Adam(Optimizer):
26 | 
27 |     def __init__(self, loss_var, lr, lr_schedualer=None):
28 | 
29 |         Optimizer.__init__(self, loss_var, lr, lr_schedualer=lr_schedualer)
30 | 
31 |         self._loss = loss_var
32 |         self._lr = lr
33 |         self._lr_schedualer = lr_schedualer
34 | 
35 |     def _build(self, grad_clip=None):
36 | 
37 |         if self._lr_schedualer is not None:
38 |             self._lr = self._lr_schedualer._build(self._lr)
39 | 
40 |         optimizer = fluid.optimizer.Adam(learning_rate=self._lr)
41 | 
42 |         if grad_clip is not None:
43 |             clip_norm_thres = grad_clip
44 |             # clip gradients by their global norm before the minimize step,
45 |             # with the threshold given by grad_clip
46 |             fluid.clip.set_gradient_clip(
47 |                 clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm_thres))
48 | 
49 |         _, param_grads = optimizer.minimize(self._loss)
50 |         return param_grads
51 | 
52 |     def get_cur_learning_rate(self):
53 |         return self._lr
54 | 
55 | 
-------------------------------------------------------------------------------- /paddlepalm/optimizer/base_optimizer.py: --------------------------------------------------------------------------------
1 | 
2 | class Optimizer(object):
3 | 
4 |     def __init__(self, loss_var, lr, lr_schedualer=None):
5 |         self._prog = None
6 |         self._lr_schedualer = lr_schedualer
7 | 
8 |     def _build(self, grad_clip=None):
9 |         raise NotImplementedError()
10 | 
11 |     def _set_prog(self, prog, init_prog):
12 |         self._prog = prog
13 |         self._init_prog = init_prog
14 |         if self._lr_schedualer is not None:
15 |             self._lr_schedualer._set_prog(prog)
16 | 
17 |     def get_cur_learning_rate(self):
18 |         pass
19 | 
20 | 
-------------------------------------------------------------------------------- /paddlepalm/reader/__init__.py: --------------------------------------------------------------------------------
1 | 
2 | from .cls import ClassifyReader
3 | from .match import MatchReader
4 | from .seq_label import SequenceLabelReader
5 | from .mrc import MRCReader
6 | from .mlm import MaskLMReader
7 | 
-------------------------------------------------------------------------------- /paddlepalm/reader/base_reader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from copy import copy
17 | class Reader(object):
18 |     """interface of data reader."""
19 | 
20 |     def __init__(self, phase='train'):
21 |         """Construct a data reader. At minimum, a `phase` argument is required.
22 |         Note: a subclass constructor must call this base constructor to create the framework's built-in member variables.
23 |         Args:
24 |             phase: str. The running phase this reader is invoked in. Currently supported: train (training) and predict (prediction).
25 |         """
26 | 
27 |         self._phase = phase
28 |         self._batch_size = None
29 |         self._num_epochs = 1
30 |         self._register = set()
31 |         self._registered_backbone = None
32 | 
33 |     @classmethod
34 |     def create_register(cls):
35 |         return set()
36 | 
37 |     def clone(self, phase='train'):
38 |         """Return a copy of this reader, with its phase switched to `phase`."""
39 |         if phase == self._phase:
40 |             return copy(self)
41 |         else:
42 |             ret = copy(self)
43 |             ret._phase = phase
44 |             return ret
45 | 
46 |     def require_attr(self, attr_name):
47 |         """Register one more output object that this reader needs to produce.
48 | 
49 |         Args:
50 |             attr_name: the name of the object to produce, e.g. 'segment_ids'.
51 |         """
52 |         self._register.add(attr_name)
53 | 
54 |     def register_with(self, backbone):
55 |         """Register every input object that the given backbone depends on.
56 | 
57 |         Args:
58 |             backbone: the backbone this reader is to be connected with.
59 |         """
60 |         for attr in backbone.inputs_attr:
61 |             self.require_attr(attr)
62 |         self._registered_backbone = backbone
63 | 
64 |     def get_registered_backbone(self):
65 |         """Return the backbone registered with this reader."""
66 |         return self._registered_backbone
67 | 
68 |     def _get_registed_attrs(self, attrs):
69 |         ret = {}
70 |         for i in self._register:
71 |             if i not in attrs:
72 |                 raise NotImplementedError('output attr {} is not found in this reader.'.format(i))
73 |             ret[i] = attrs[i]
74 |         return ret
75 | 
76 |     def load_data(self, input_file, batch_size, num_epochs=None, \
77 |                   file_format='tsv', shuffle_train=True):
78 |         """Load the dataset from disk into the reader.
79 | 
80 |         Note: implementations of this method must also set self._batch_size and self._num_epochs.
81 | 
82 |         Args:
83 |             input_file: path of the dataset file, in the format required by the `file_format` argument.
84 |             batch_size: the number of examples yielded per iteration. Note: with multiple GPUs in the environment, batch_size must be divisible by the number of cards.
85 |             num_epochs: the number of passes over the dataset. Default is None, meaning one pass in single-task mode; in multi-task mode it is assigned automatically by the upper-level Trainer. Only effective during the train phase.
86 |             file_format: the format of the input file. Currently supported: tsv (the default).
87 |             shuffle_train: whether to shuffle the training examples. Default is True. Only effective during the train phase.
88 |         """
89 |         raise NotImplementedError()
90 | 
91 |     @property
92 |     def outputs_attr(self):
93 |         """Describes the attributes of this reader's outputs (the objects it yields): each object's name, shape and dtype. For scalar objects
94 |         (str, int, float, etc.), set the shape to an empty list []; when a dimension of an object has variable length, set that dimension to -1.
95 |         Note: with mini-batch gradient descent, regular input objects should carry a batch_size dimension (usually -1).
96 |         Return:
97 |             dict. The attribute description of each output. For example,
98 |             for text classification and matching tasks, the yielded outputs may contain the following objects (downstream backbones and tasks access them as needed):
99 |             {"token_ids": ([-1, max_len], 'int64'),
100 |              "input_ids": ([-1, max_len], 'int64'),
101 |              "segment_ids": ([-1, max_len], 'int64'),
102 |              "input_mask": ([-1, max_len], 'float32'),
103 |              "label": ([-1], 'int')}
104 |         """
105 |         raise NotImplementedError()
106 | 
107 |     def _iterator(self):
108 |         """Iterate over the dataset. Note: on reaching the end of the dataset, this method should reset the cursor automatically and start a new pass from the beginning.
109 |         Yield:
110 |             dict. The outputs of the current step, matching the description in outputs_attr.
111 |         """
112 |         raise NotImplementedError()
113 | 
114 |     def get_epoch_outputs(self):
115 |         """Return the outputs produced after each full pass (epoch) over the dataset."""
116 |         raise NotImplementedError()
117 | 
118 |     @property
119 |     def num_examples(self):
120 |         """The number of examples in the dataset, i.e. the number of examples the iterator generates per epoch. Note: with strategies such as sliding windows that may change the number of examples,
121 |         this should return the actual number of examples at runtime."""
122 |         raise NotImplementedError()
123 | 
124 |     @property
125 |     def num_epochs(self):
126 |         """The number of passes over the dataset."""
127 |         return self._num_epochs
128 | 
-------------------------------------------------------------------------------- /paddlepalm/reader/cls.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from paddlepalm.reader.base_reader import Reader
17 | from paddlepalm.reader.utils.reader4ernie import ClassifyReader as CLSReader
18 | 
19 | 
20 | class ClassifyReader(Reader):
21 |     """
22 |     The reader completes the loading and processing of text classification datasets. Supported file format: tsv.
23 | 
24 |     For the tsv format, the training dataset file should contain two fields, `label` and `text_a`; the test set only requires the `text_a` field. For example,
25 | 
26 |     ```
27 |     label [TAB] text_a
28 |     1 [TAB] Today is a good day.
29 |     0 [TAB] Such a terrible day!
30 |     1 [TAB] I feel lucky to meet you, dear.
31 |     1 [TAB] He likes sunshine and I like him :).
32 |     0 [TAB] JUST! GO! OUT!
33 |     ```
34 | 
35 |     CAUTION: the first line of the file must be the header! Fields are split by tabs (\\t).
36 | 
37 |     """
38 | 
39 |     def __init__(self, vocab_path, max_len, tokenizer='wordpiece', \
40 |                  lang='en', seed=None, do_lower_case=False, phase='train'):
41 |         """Create a new Reader for loading and processing classification task data.
42 | 
43 |         Args:
44 |             vocab_path: the vocab file path for tokenization and token_ids generation.
45 |             max_len: the maximum length of the sequence (after word segmentation). The part exceeding max_len will be truncated from the right.
46 |             tokenizer: string type. The name of the tokenizer to use. 
A tokenizer converts raw text into tokens. Available tokenizers: wordpiece.
47 |             lang: the language of the dataset. Supported languages: en (English), cn (Chinese). Default is en (English).
48 |             seed: int type. The random seed used to shuffle the dataset. Default is None, meaning no random seed is used.
49 |             do_lower_case: bool type. Whether to lowercase English text. Default is False. This argument only works on English text.
50 |             phase: the running phase of this reader. Supported phases: train, predict. Default is train.
51 | 
52 |         Return:
53 |             a Reader object for classification tasks.
54 |         """
55 | 
56 |         Reader.__init__(self, phase)
57 | 
58 |         assert lang.lower() in ['en', 'cn', 'english', 'chinese'], "supported language: en (English), cn (Chinese)."
59 |         assert phase in ['train', 'predict'], "supported phase: train, predict."
60 | 
61 |         for_cn = lang.lower() == 'cn' or lang.lower() == 'chinese'
62 | 
63 |         self._register.add('token_ids')
64 |         if phase == 'train':
65 |             self._register.add('label_ids')
66 | 
67 |         self._is_training = phase == 'train'
68 | 
69 |         cls_reader = CLSReader(vocab_path,
70 |                                max_seq_len=max_len,
71 |                                do_lower_case=do_lower_case,
72 |                                for_cn=for_cn,
73 |                                random_seed=seed)
74 |         self._reader = cls_reader
75 | 
76 |         self._phase = phase
77 | 
78 | 
79 | 
80 | 
81 |     @property
82 |     def outputs_attr(self):
83 |         """The output items (input features) produced by this reader."""
84 |         attrs = {"token_ids": [[-1, -1], 'int64'],
85 |                  "position_ids": [[-1, -1], 'int64'],
86 |                  "segment_ids": [[-1, -1], 'int64'],
87 |                  "input_mask": [[-1, -1, 1], 'float32'],
88 |                  "label_ids": [[-1], 'int64'],
89 |                  "task_ids": [[-1, -1], 'int64']
90 |                  }
91 |         return self._get_registed_attrs(attrs)
92 | 
93 | 
94 |     def load_data(self, input_file, batch_size, num_epochs=None, \
95 |                   file_format='tsv', shuffle_train=True):
96 |         """Load classification data into the reader.
97 | 
98 |         Args:
99 |             input_file: the dataset file path. The file format must be consistent with the `file_format` argument.
100 |             batch_size: the number of examples yielded per iteration. CAUTION! If multiple GPU devices exist in your environment (denote their count as dev_count), batch_size must be divisible by dev_count with no remainder!
101 |             num_epochs: the number of passes over the input examples. Default is None, meaning one pass for single-task learning; it is calculated automatically for multi-task learning. This argument only works during the train phase.
102 |             file_format: the file format of the input file. Supported format: tsv. Default is tsv.
103 |             shuffle_train: whether to shuffle the training dataset. Default is True. This argument only works during the train phase.
104 | 
105 |         """
106 |         self._batch_size = batch_size
107 |         self._num_epochs = num_epochs
108 |         self._data_generator = self._reader.data_generator( \
109 |             input_file, batch_size, num_epochs if self._phase == 'train' else 1, \
110 |             shuffle=shuffle_train if self._phase == 'train' else False, \
111 |             phase=self._phase)
112 | 
113 |     def _iterator(self):
114 | 
115 |         names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
116 |                  'label_ids', 'unique_ids']
117 |         for batch in self._data_generator():
118 |             outputs = {n: i for n,i in zip(names, batch)}
119 |             ret = {}
120 |             # TODO: move runtime shape check here
121 |             for attr in self.outputs_attr.keys():
122 |                 ret[attr] = outputs[attr]
123 |             yield ret
124 | 
125 |     def get_epoch_outputs(self):
126 |         return {'examples': self._reader.get_examples(self._phase),
127 |                 'features': self._reader.get_features(self._phase)}
128 | 
129 |     @property
130 |     def num_examples(self):
131 |         return self._reader.get_num_examples(phase=self._phase)
132 | 
133 |     @property
134 |     def num_epochs(self):
135 |         return self._num_epochs
136 | 
137 | 
-------------------------------------------------------------------------------- /paddlepalm/reader/match.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from paddlepalm.reader.base_reader import Reader
17 | from paddlepalm.reader.utils.reader4ernie import ClassifyReader as CLSReader
18 | 
19 | 
20 | class MatchReader(Reader):
21 |     """
22 |     The reader completes the loading and processing of matching-style task datasets (e.g., query-query, question-answer, text similarity, natural language inference). Supported file format: tsv.
23 | 
24 |     For the pointwise learning strategy, the training dataset file should contain three fields, i.e., `text_a`, `text_b` and `label`. For pairwise learning, it should contain the three fields `text_a`, `text_b` and `text_b_neg`. For prediction, only `text_a` and `text_b` are required.
25 | 
26 |     A pointwise learning case is shown below:
27 |     ```
28 |     label [TAB] text_a [TAB] text_b
29 |     1 [TAB] Today is a good day. [TAB] what a nice day!
30 |     0 [TAB] Such a terrible day! [TAB] There is a dog.
31 |     1 [TAB] I feel lucky to meet you, dear. [TAB] You are my lucky, darling.
32 |     1 [TAB] He likes sunshine and I like him :). [TAB] I like him. He like sunshine.
33 |     0 [TAB] JUST! GO! OUT! [TAB] Come in please.
34 |     ```
35 |     A pairwise learning case is shown below:
36 |     text_a [TAB] text_b [TAB] text_b_neg
37 |     Today is a good day. [TAB] what a nice day! [TAB] terrible day!
38 |     Such a terrible day! [TAB] So terrible today! [TAB] There is a dog.
39 |     I feel lucky to meet you, dear. [TAB] You are my lucky, darling. [TAB] Buy some bananas, okay?
40 |     He likes sunshine and I like him :). [TAB] I like him. He like sunshine. [TAB] He has a dog.
41 |     JUST! GO! OUT! 
[TAB] go out now! [TAB] Come in please.
42 | 
43 |     CAUTION: the header line is required in each dataset file! Fields (columns) are split by tabs (\\t).
44 | 
45 |     """
46 | 
47 |     def __init__(self, vocab_path, max_len, tokenizer='wordpiece', lang='en', seed=None, \
48 |                  do_lower_case=False, learning_strategy='pointwise', phase='train', dev_count=1, print_prefix=''):
49 |         """Create a new Reader for matching task data.
50 | 
51 |         Args:
52 |             vocab_path: the vocab file path for tokenization and token_ids generation.
53 |             max_len: the maximum length of the sequence (after word segmentation). The part exceeding max_len will be truncated from the right.
54 |             tokenizer: string type. The name of the tokenizer to use. A tokenizer converts raw text into tokens. Available tokenizers: wordpiece.
55 |             lang: the language of the dataset. Supported languages: en (English), cn (Chinese). Default is en (English).
56 |             seed: int type. The random seed used to shuffle the dataset. Default is None, meaning no random seed is used.
57 |             do_lower_case: bool type. Whether to lowercase English text. Default is False. This argument only works on English text.
58 |             learning_strategy: string type. Only effective during the train phase. Available strategies: pointwise, pairwise.
59 |             phase: the running phase of this reader. Supported phases: train, predict. Default is train.
60 | 
61 |         Return:
62 |             a Reader object for matching-style tasks.
63 |         """
64 | 
65 |         Reader.__init__(self, phase)
66 | 
67 |         assert lang.lower() in ['en', 'cn', 'english', 'chinese'], "supported language: en (English), cn (Chinese)."
68 |         assert phase in ['train', 'predict'], "supported phase: train, predict."
69 | 
70 |         for_cn = lang.lower() == 'cn' or lang.lower() == 'chinese'
71 | 
72 |         self._register.add('token_ids')
73 |         if phase == 'train':
74 |             if learning_strategy == 'pointwise':
75 |                 self._register.add('label_ids')
76 |             if learning_strategy == 'pairwise':
77 |                 self._register.add('token_ids_neg')
78 |                 self._register.add('position_ids_neg')
79 |                 self._register.add('segment_ids_neg')
80 |                 self._register.add('input_mask_neg')
81 |                 self._register.add('task_ids_neg')
82 | 
83 |         self._is_training = phase == 'train'
84 |         self._learning_strategy = learning_strategy
85 | 
86 | 
87 |         match_reader = CLSReader(vocab_path,
88 |                                  max_seq_len=max_len,
89 |                                  do_lower_case=do_lower_case,
90 |                                  for_cn=for_cn,
91 |                                  random_seed=seed,
92 |                                  learning_strategy=learning_strategy)
93 | 
94 |         self._reader = match_reader
95 |         self._dev_count = dev_count
96 |         self._phase = phase
97 | 
98 | 
99 |     @property
100 |     def outputs_attr(self):
101 |         attrs = {"token_ids": [[-1, -1], 'int64'],
102 |                  "position_ids": [[-1, -1], 'int64'],
103 |                  "segment_ids": [[-1, -1], 'int64'],
104 |                  "input_mask": [[-1, -1, 1], 'float32'],
105 |                  "task_ids": [[-1, -1], 'int64'],
106 |                  "label_ids": [[-1], 'int64'],
107 |                  "token_ids_neg": [[-1, -1], 'int64'],
108 |                  "position_ids_neg": [[-1, -1], 'int64'],
109 |                  "segment_ids_neg": [[-1, -1], 'int64'],
110 |                  "input_mask_neg": [[-1, -1, 1], 'float32'],
111 |                  "task_ids_neg": [[-1, -1], 'int64']
112 |                  }
113 |         return self._get_registed_attrs(attrs)
114 | 
115 | 
116 |     def load_data(self, input_file, batch_size, num_epochs=None, \
117 |                   file_format='tsv', shuffle_train=True):
118 |         """Load matching data into the reader.
119 | 
120 |         Args:
121 |             input_file: the dataset file path. The file format must be consistent with the `file_format` argument.
122 |             batch_size: the number of examples yielded per iteration. CAUTION! 
If multiple GPU devices exist in your environment (denote their count as dev_count), batch_size must be divisible by dev_count with no remainder!
123 |             num_epochs: the number of passes over the input examples. Default is None, meaning one pass for single-task learning; it is calculated automatically for multi-task learning. This argument only works during the train phase.
124 |             file_format: the file format of the input file. Supported format: tsv. Default is tsv.
125 |             shuffle_train: whether to shuffle the training dataset. Default is True. This argument only works during the train phase.
126 | 
127 |         """
128 |         self._batch_size = batch_size
129 |         self._num_epochs = num_epochs
130 |         self._data_generator = self._reader.data_generator( \
131 |             input_file, batch_size, num_epochs if self._phase == 'train' else 1, \
132 |             shuffle=shuffle_train if self._phase == 'train' else False, \
133 |             phase=self._phase)
134 | 
135 |     def _iterator(self):
136 | 
137 | 
138 |         names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask', 'label_ids', \
139 |                  'token_ids_neg', 'segment_ids_neg', 'position_ids_neg', 'task_ids_neg', 'input_mask_neg']
140 | 
141 |         if self._learning_strategy == 'pairwise':
142 |             names.remove('label_ids')
143 | 
144 | 
145 |         for batch in self._data_generator():
146 |             outputs = {n: i for n,i in zip(names, batch)}
147 |             ret = {}
148 |             # TODO: move runtime shape check here
149 |             for attr in self.outputs_attr.keys():
150 |                 ret[attr] = outputs[attr]
151 |             yield ret
152 | 
153 |     @property
154 |     def num_examples(self):
155 |         return self._reader.get_num_examples(phase=self._phase)
156 | 
157 |     @property
158 |     def num_epochs(self):
159 |         return self._num_epochs
160 | 
161 | 
-------------------------------------------------------------------------------- /paddlepalm/reader/mlm.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from paddlepalm.reader.base_reader import Reader
17 | from paddlepalm.reader.utils.reader4ernie import MaskLMReader as MLMReader
18 | import numpy as np
19 | 
20 | class MaskLMReader(Reader):
21 | 
22 |     def __init__(self, vocab_path, max_len, tokenizer='wordpiece', \
23 |                  lang='en', seed=None, do_lower_case=False, phase='train', dev_count=1, print_prefix=''):
24 |         """
25 |         Args:
26 |             phase: the running phase, train or predict.
27 |         """
28 | 
29 | 
30 |         Reader.__init__(self, phase)
31 | 
32 |         assert lang.lower() in ['en', 'cn', 'english', 'chinese'], "supported language: en (English), cn (Chinese)."
33 |         assert phase in ['train', 'predict'], "supported phase: train, predict."
34 | 
35 |         for_cn = lang.lower() == 'cn' or lang.lower() == 'chinese'
36 | 
37 |         self._register.add('mask_pos')
38 |         if phase == 'train':
39 |             self._register.add('mask_label')
40 |         self._is_training = phase == 'train'
41 | 
42 |         mlm_reader = MLMReader(vocab_path,
43 |                                max_seq_len=max_len,
44 |                                do_lower_case=do_lower_case,
45 |                                for_cn=for_cn,
46 |                                random_seed=seed)
47 |         self._reader = mlm_reader
48 | 
49 |         self._phase = phase
50 |         self._dev_count = dev_count
51 | 
52 | 
53 |     @property
54 |     def outputs_attr(self):
55 |         attrs = {"token_ids": [[-1, -1], 'int64'],
56 |                  "position_ids": [[-1, -1], 'int64'],
57 |                  "segment_ids": [[-1, -1], 'int64'],
58 |                  "input_mask": [[-1, -1, 1], 'float32'],
59 |                  "task_ids": [[-1, -1], 'int64'],
60 |                  "mask_label": [[-1], 'int64'],
61 |                  "mask_pos": [[-1], 'int64']
62 |                  }
63 | 
64 |         return self._get_registed_attrs(attrs)
65 | 
66 | 
67 |     def load_data(self, input_file, batch_size, num_epochs=None, \
68 |                   file_format='tsv', shuffle_train=True):
69 |         self._batch_size = batch_size
70 |         self._num_epochs = num_epochs
71 |         self._data_generator = self._reader.data_generator( \
72 |             input_file, batch_size, num_epochs if self._phase == 'train' else 1, \
73 |             shuffle=shuffle_train if self._phase == 'train' else False, \
74 |             phase=self._phase)
75 | 
76 |     def _iterator(self):
77 | 
78 |         names = ['token_ids', 'position_ids', 'segment_ids', 'input_mask',
79 |                  'task_ids', 'mask_label', 'mask_pos']
80 |         for batch in self._data_generator():
81 |             outputs = {n: i for n,i in zip(names, batch)}
82 |             ret = {}
83 |             # TODO: move runtime shape check here
84 |             for attr in self.outputs_attr.keys():
85 |                 ret[attr] = outputs[attr]
86 | 
87 |             yield ret
88 | 
89 |     def get_epoch_outputs(self):
90 |         return {'examples': self._reader.get_examples(self._phase),
91 |                 'features': self._reader.get_features(self._phase)}
92 | 
93 |     @property
94 |     def num_examples(self):
95 |         return self._reader.get_num_examples(phase=self._phase)
96 | 
97 |     @property
98 |     def num_epochs(self):
99 |         return self._num_epochs
100 | 
101 | 
-------------------------------------------------------------------------------- /paddlepalm/reader/seq_label.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from paddlepalm.reader.base_reader import Reader
17 | from paddlepalm.reader.utils.reader4ernie import SequenceLabelReader as SLReader
18 | 
19 | class SequenceLabelReader(Reader):
20 |     """
21 |     The reader completes the loading and processing of sequence labeling task datasets (e.g., POS tagging, named entity recognition). Supported file format: tsv.
22 |     """
23 | 
24 |     def __init__(self, vocab_path, max_len, label_map_config, tokenizer='wordpiece', \
25 |                  lang='en', seed=None, do_lower_case=False, phase='train', dev_count=1, print_prefix=''):
26 |         """
27 |         Args:
28 |             phase: the running phase, train or predict.
29 |             lang: the language of the dataset, en (English) or cn (Chinese).
30 | """ 31 | 32 | Reader.__init__(self, phase) 33 | 34 | assert lang.lower() in ['en', 'cn', 'english', 'chinese'], "supported language: en (English), cn (Chinese)." 35 | assert phase in ['train', 'predict'], "supported phase: train, predict." 36 | 37 | for_cn = lang.lower() == 'cn' or lang.lower() == 'chinese' 38 | 39 | self._register.add('token_ids') 40 | self._register.add('seq_lens') 41 | if phase == 'train': 42 | self._register.add('label_ids') 43 | 44 | self._is_training = phase == 'train' 45 | 46 | ner_reader = SLReader(vocab_path, 47 | max_seq_len=max_len, 48 | do_lower_case=do_lower_case, 49 | for_cn=for_cn, 50 | random_seed=seed, 51 | label_map_config=label_map_config) 52 | self._reader = ner_reader 53 | self._phase = phase 54 | self._dev_count = dev_count 55 | 56 | 57 | @property 58 | def outputs_attr(self): 59 | attrs = {"token_ids": [[-1, -1], 'int64'], 60 | "position_ids": [[-1, -1], 'int64'], 61 | "segment_ids": [[-1, -1], 'int64'], 62 | "task_ids": [[-1, -1], 'int64'], 63 | "input_mask": [[-1, -1, 1], 'float32'], 64 | "seq_lens": [[-1], 'int64'], 65 | "label_ids": [[-1, -1], 'int64']} 66 | return self._get_registed_attrs(attrs) 67 | 68 | 69 | def load_data(self, input_file, batch_size, num_epochs=None, \ 70 | file_format='tsv', shuffle_train=True): 71 | """Load sequence labeling data into reader. 72 | 73 | Args: 74 | input_file: the dataset file path. File format should keep consistent with `file_format` argument. 75 | batch_size: number of examples for once yield. CAUSIOUS! If your environment exists multiple GPU devices (marked as dev_count), the batch_size should be divided by dev_count with no remainder! 76 | num_epochs: the travelsal times of input examples. Default is None, means once for single-task learning and automatically calculated for multi-task learning. This argument only works on train phase. 77 | file_format: the file format of input file. Supported format: tsv. Default is tsv. 78 | shuffle_train: whether to shuffle training dataset. Default is True. This argument only works on training phase. 
79 | 
80 |         """
81 |         self._batch_size = batch_size
82 |         self._num_epochs = num_epochs
83 |         self._data_generator = self._reader.data_generator( \
84 |             input_file, batch_size, num_epochs if self._phase == 'train' else 1, \
85 |             shuffle=shuffle_train if self._phase == 'train' else False, \
86 |             phase=self._phase)
87 | 
88 |     def _iterator(self):
89 | 
90 |         names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
91 |                  'label_ids', 'seq_lens']
92 |         for batch in self._data_generator():
93 |             outputs = {n: i for n,i in zip(names, batch)}
94 |             ret = {}
95 |             # TODO: move runtime shape check here
96 |             for attr in self.outputs_attr.keys():
97 |                 ret[attr] = outputs[attr]
98 |             yield ret
99 | 
100 |     def get_epoch_outputs(self):
101 |         return {'examples': self._reader.get_examples(self._phase),
102 |                 'features': self._reader.get_features(self._phase)}
103 | 
104 |     @property
105 |     def num_examples(self):
106 |         return self._reader.get_num_examples(phase=self._phase)
107 | 
108 |     @property
109 |     def num_epochs(self):
110 |         return self._num_epochs
111 | 
-------------------------------------------------------------------------------- /paddlepalm/reader/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/PALM/2555c0e2a5fab1b702ae8d1c7612bef48c65af38/paddlepalm/reader/utils/__init__.py -------------------------------------------------------------------------------- /paddlepalm/reader/utils/batching4bert.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
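# NOTE (added commentary): the thresholds in mask() below implement the standard
# BERT masking recipe with a single uniform draw per token, prob ~ U[0, 1):
#   prob >  0.15          -> token left untouched and not predicted (~85% of tokens)
#   0.03  < prob <= 0.15  -> token replaced by [MASK]        (0.12/0.15  = 80% of selected)
#   0.015 < prob <= 0.03  -> token replaced by a random id   (0.015/0.15 = 10% of selected)
#   prob <= 0.015         -> token kept as-is but predicted  (0.015/0.15 = 10% of selected)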
15 | """Mask, padding and batching.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import numpy as np 20 | 21 | 22 | def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3): 23 | """ 24 | Add mask for batch_tokens, return out, mask_label, mask_pos; 25 | Note: mask_pos responding the batch_tokens after padded; 26 | """ 27 | max_len = max([len(sent) for sent in batch_tokens]) 28 | mask_label = [] 29 | mask_pos = [] 30 | prob_mask = np.random.rand(total_token_num) 31 | # Note: the first token is [CLS], so [low=1] 32 | replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num) 33 | pre_sent_len = 0 34 | prob_index = 0 35 | for sent_index, sent in enumerate(batch_tokens): 36 | mask_flag = False 37 | prob_index += pre_sent_len 38 | for token_index, token in enumerate(sent): 39 | prob = prob_mask[prob_index + token_index] 40 | if prob > 0.15: 41 | continue 42 | elif 0.03 < prob <= 0.15: 43 | # mask 44 | if token != SEP and token != CLS: 45 | mask_label.append(sent[token_index]) 46 | sent[token_index] = MASK 47 | mask_flag = True 48 | mask_pos.append(sent_index * max_len + token_index) 49 | elif 0.015 < prob <= 0.03: 50 | # random replace 51 | if token != SEP and token != CLS: 52 | mask_label.append(sent[token_index]) 53 | sent[token_index] = replace_ids[prob_index + token_index] 54 | mask_flag = True 55 | mask_pos.append(sent_index * max_len + token_index) 56 | else: 57 | # keep the original token 58 | if token != SEP and token != CLS: 59 | mask_label.append(sent[token_index]) 60 | mask_pos.append(sent_index * max_len + token_index) 61 | pre_sent_len = len(sent) 62 | # ensure at least mask one word in a sentence 63 | while not mask_flag: 64 | token_index = int(np.random.randint(1, high=len(sent) - 1, size=1)) 65 | if sent[token_index] != SEP and sent[token_index] != CLS: 66 | mask_label.append(sent[token_index]) 67 | sent[token_index] = MASK 68 | mask_flag = True 69 | mask_pos.append(sent_index * max_len + token_index) 70 | mask_label = np.array(mask_label).astype("int64").reshape([-1]) 71 | mask_pos = np.array(mask_pos).astype("int64").reshape([-1]) 72 | return batch_tokens, mask_label, mask_pos 73 | 74 | 75 | def prepare_batch_data(insts, 76 | total_token_num, 77 | max_len=None, 78 | voc_size=0, 79 | pad_id=None, 80 | cls_id=None, 81 | sep_id=None, 82 | mask_id=None, 83 | return_input_mask=True, 84 | return_max_len=True, 85 | return_num_token=False): 86 | """ 87 | 1. generate Tensor of data 88 | 2. generate Tensor of position 89 | 3. 
generate self attention mask, [shape: batch_size * max_len * max_len] 90 | """ 91 | batch_src_ids = [inst[0] for inst in insts] 92 | batch_sent_ids = [inst[1] for inst in insts] 93 | batch_pos_ids = [inst[2] for inst in insts] 94 | labels_list = [] 95 | # compatible with mrqa, whose example includes start/end positions, 96 | # or unique id 97 | for i in range(3, len(insts[0]), 1): 98 | labels = [inst[i] for inst in insts] 99 | labels = np.array(labels).astype("int64").reshape([-1]) 100 | labels_list.append(labels) 101 | # First step: do mask without padding 102 | if mask_id >= 0: 103 | out, mask_label, mask_pos = mask( 104 | batch_src_ids, 105 | total_token_num, 106 | vocab_size=voc_size, 107 | CLS=cls_id, 108 | SEP=sep_id, 109 | MASK=mask_id) 110 | else: 111 | out = batch_src_ids 112 | # Second step: padding 113 | src_id, self_input_mask = pad_batch_data( 114 | out, 115 | max_len=max_len, 116 | pad_idx=pad_id, return_input_mask=True) 117 | pos_id = pad_batch_data( 118 | batch_pos_ids, 119 | max_len=max_len, 120 | pad_idx=pad_id, 121 | return_pos=False, 122 | return_input_mask=False) 123 | sent_id = pad_batch_data( 124 | batch_sent_ids, 125 | max_len=max_len, 126 | pad_idx=pad_id, 127 | return_pos=False, 128 | return_input_mask=False) 129 | if mask_id >= 0: 130 | return_list = [ 131 | src_id, pos_id, sent_id, self_input_mask, mask_label, mask_pos 132 | ] + labels_list 133 | else: 134 | return_list = [src_id, pos_id, sent_id, self_input_mask] + labels_list 135 | return return_list if len(return_list) > 1 else return_list[0] 136 | 137 | 138 | def pad_batch_data(insts, 139 | max_len=None, 140 | pad_idx=0, 141 | return_pos=False, 142 | return_input_mask=False, 143 | return_max_len=False, 144 | return_num_token=False): 145 | """ 146 | Pad the instances to the max sequence length in batch, and generate the 147 | corresponding position data and input mask. 148 | """ 149 | return_list = [] 150 | if max_len is None: 151 | max_len = max(len(inst) for inst in insts) 152 | # Any token included in dict can be used to pad, since the paddings' loss 153 | # will be masked out by weights and make no effect on parameter gradients. 154 | inst_data = np.array([ 155 | list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts 156 | ]) 157 | return_list += [inst_data.astype("int64").reshape([-1, max_len])] 158 | # position data 159 | if return_pos: 160 | inst_pos = np.array([ 161 | list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst)) 162 | for inst in insts 163 | ]) 164 | return_list += [inst_pos.astype("int64").reshape([-1, max_len])] 165 | if return_input_mask: 166 | # This is used to avoid attention on paddings. 167 | input_mask_data = np.array([[1] * len(inst) + [0] * 168 | (max_len - len(inst)) for inst in insts]) 169 | input_mask_data = np.expand_dims(input_mask_data, axis=-1) 170 | return_list += [input_mask_data.astype("float32")] 171 | if return_max_len: 172 | return_list += [max_len] 173 | if return_num_token: 174 | num_token = 0 175 | for inst in insts: 176 | num_token += len(inst) 177 | return_list += [num_token] 178 | return return_list if len(return_list) > 1 else return_list[0] 179 | 180 | 181 | if __name__ == "__main__": 182 | pass 183 | 184 | 185 | -------------------------------------------------------------------------------- /paddlepalm/reader/utils/batching4ernie.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Mask, padding and batching.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | 23 | from six.moves import xrange 24 | 25 | 26 | def mask(batch_tokens, 27 | seg_labels, 28 | mask_word_tags, 29 | total_token_num, 30 | vocab_size, 31 | CLS=1, 32 | SEP=2, 33 | MASK=3): 34 | """ 35 | Add mask for batch_tokens, return out, mask_label, mask_pos; 36 | Note: mask_pos responding the batch_tokens after padded; 37 | """ 38 | max_len = max([len(sent) for sent in batch_tokens]) 39 | mask_label = [] 40 | mask_pos = [] 41 | prob_mask = np.random.rand(total_token_num) 42 | # Note: the first token is [CLS], so [low=1] 43 | replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num) 44 | pre_sent_len = 0 45 | prob_index = 0 46 | for sent_index, sent in enumerate(batch_tokens): 47 | mask_flag = False 48 | mask_word = mask_word_tags[sent_index] 49 | prob_index += pre_sent_len 50 | if mask_word: 51 | beg = 0 52 | for token_index, token in enumerate(sent): 53 | seg_label = seg_labels[sent_index][token_index] 54 | if seg_label == 1: 55 | continue 56 | if beg == 0: 57 | if seg_label != -1: 58 | beg = token_index 59 | continue 60 | 61 | prob = prob_mask[prob_index + beg] 62 | if prob > 0.15: 63 | pass 64 | else: 65 | for index in xrange(beg, token_index): 66 | prob = prob_mask[prob_index + index] 67 | base_prob = 1.0 68 | if index == beg: 69 | base_prob = 0.15 70 | if base_prob * 0.2 < prob <= base_prob: 71 | mask_label.append(sent[index]) 72 | sent[index] = MASK 73 | mask_flag = True 74 | mask_pos.append(sent_index * max_len + index) 75 | elif base_prob * 0.1 < prob <= base_prob * 0.2: 76 | mask_label.append(sent[index]) 77 | sent[index] = replace_ids[prob_index + index] 78 | mask_flag = True 79 | mask_pos.append(sent_index * max_len + index) 80 | else: 81 | mask_label.append(sent[index]) 82 | mask_pos.append(sent_index * max_len + index) 83 | 84 | if seg_label == -1: 85 | beg = 0 86 | else: 87 | beg = token_index 88 | else: 89 | for token_index, token in enumerate(sent): 90 | prob = prob_mask[prob_index + token_index] 91 | if prob > 0.15: 92 | continue 93 | elif 0.03 < prob <= 0.15: 94 | # mask 95 | if token != SEP and token != CLS: 96 | mask_label.append(sent[token_index]) 97 | sent[token_index] = MASK 98 | mask_flag = True 99 | mask_pos.append(sent_index * max_len + token_index) 100 | elif 0.015 < prob <= 0.03: 101 | # random replace 102 | if token != SEP and token != CLS: 103 | mask_label.append(sent[token_index]) 104 | sent[token_index] = replace_ids[prob_index + 105 | token_index] 106 | mask_flag = True 107 | mask_pos.append(sent_index * max_len + token_index) 108 | else: 109 | # keep the original token 110 | if token != SEP and token != CLS: 111 | mask_label.append(sent[token_index]) 112 | mask_pos.append(sent_index * max_len + token_index) 113 | 114 | pre_sent_len = len(sent) 115 | 116 | 
119 | 
120 | 
121 | def pad_batch_data(insts,
122 |                    pad_idx=0,
123 |                    return_pos=False,
124 |                    return_input_mask=False,
125 |                    return_max_len=False,
126 |                    return_num_token=False,
127 |                    return_seq_lens=False):
128 |     """
129 |     Pad the instances to the max sequence length in batch, and generate the
130 |     corresponding position data and attention bias.
131 |     """
132 |     return_list = []
133 |     max_len = max(len(inst) for inst in insts)
134 |     # Any token included in dict can be used to pad, since the paddings' loss
135 |     # will be masked out by weights and have no effect on parameter gradients.
136 | 
137 |     inst_data = np.array(
138 |         [inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
139 |     return_list += [inst_data.astype("int64").reshape([-1, max_len])]
140 | 
141 |     # position data
142 |     if return_pos:
143 |         inst_pos = np.array([
144 |             list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
145 |             for inst in insts
146 |         ])
147 | 
148 |         return_list += [inst_pos.astype("int64").reshape([-1, max_len])]
149 | 
150 |     if return_input_mask:
151 |         # This is used to avoid attention on paddings.
152 |         input_mask_data = np.array([[1] * len(inst) + [0] *
153 |                                     (max_len - len(inst)) for inst in insts])
154 |         input_mask_data = np.expand_dims(input_mask_data, axis=-1)
155 |         return_list += [input_mask_data.astype("float32")]
156 | 
157 |     if return_max_len:
158 |         return_list += [max_len]
159 | 
160 |     if return_num_token:
161 |         num_token = 0
162 |         for inst in insts:
163 |             num_token += len(inst)
164 |         return_list += [num_token]
165 | 
166 |     if return_seq_lens:
167 |         seq_lens = np.array([len(inst) for inst in insts])
168 |         return_list += [seq_lens.astype("int64").reshape([-1])]
169 | 
170 |     return return_list if len(return_list) > 1 else return_list[0]
171 | 
172 | 
173 | if __name__ == "__main__":
174 | 
175 |     pass
176 | 
--------------------------------------------------------------------------------
/paddlepalm/reader/utils/mlm_batching.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
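# Unlike the whole-word masking in batching4ernie.py above, this variant masks
# individual subword tokens and can shard one large batch across dev_count
# devices, returning per-device mask_label/mask_pos arrays.
#
# Usage sketch (hypothetical ids; the call below takes 0/1/2/3 as
# [PAD]/[CLS]/[SEP]/[MASK]); each inst is a (token_ids, sentence_ids,
# position_ids) triple:
#
#   insts = [([1, 57, 68, 2], [0, 0, 0, 0], [0, 1, 2, 3]),
#            ([1, 99, 2], [0, 0, 0], [0, 1, 2])]
#   feats = prepare_batch_data(insts, total_token_num=7, voc_size=30000,
#                              pad_id=0, cls_id=1, sep_id=2, mask_id=3)
#
# feats unpacks to [src_id, pos_id, sent_id, input_mask, task_ids, mask_label,
# mask_pos] as numpy arrays, ready to feed an ERNIE-style backbone.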
15 | """Mask, padding and batching.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import numpy as np 20 | 21 | 22 | def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3, dev_count=1): 23 | """ 24 | Add mask for batch_tokens, return out, mask_label, mask_pos; 25 | Note: mask_pos responding the batch_tokens after padded; 26 | """ 27 | max_len = max([len(sent) for sent in batch_tokens]) 28 | 29 | multidev_batch_tokens = [] 30 | multidev_mask_label = [] 31 | multidev_mask_pos = [] 32 | 33 | big_batch_tokens = batch_tokens 34 | stride = len(batch_tokens) // dev_count 35 | if stride == 0: 36 | return None, None, None 37 | p = stride 38 | 39 | for i in range(dev_count): 40 | batch_tokens = big_batch_tokens[p-stride:p] 41 | p += stride 42 | mask_label = [] 43 | mask_pos = [] 44 | prob_mask = np.random.rand(total_token_num) 45 | # Note: the first token is [CLS], so [low=1] 46 | replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num) 47 | pre_sent_len = 0 48 | prob_index = 0 49 | for sent_index, sent in enumerate(batch_tokens): 50 | mask_flag = False 51 | prob_index += pre_sent_len 52 | for token_index, token in enumerate(sent): 53 | prob = prob_mask[prob_index + token_index] 54 | if prob > 0.15: 55 | continue 56 | elif 0.03 < prob <= 0.15: 57 | # mask 58 | if token != SEP and token != CLS: 59 | mask_label.append(sent[token_index]) 60 | sent[token_index] = MASK 61 | mask_flag = True 62 | mask_pos.append(sent_index * max_len + token_index) 63 | elif 0.015 < prob <= 0.03: 64 | # random replace 65 | if token != SEP and token != CLS: 66 | mask_label.append(sent[token_index]) 67 | sent[token_index] = replace_ids[prob_index + token_index] 68 | mask_flag = True 69 | mask_pos.append(sent_index * max_len + token_index) 70 | else: 71 | # keep the original token 72 | if token != SEP and token != CLS: 73 | mask_label.append(sent[token_index]) 74 | mask_pos.append(sent_index * max_len + token_index) 75 | pre_sent_len = len(sent) 76 | # ensure at least mask one word in a sentence 77 | while not mask_flag: 78 | token_index = int(np.random.randint(1, high=len(sent) - 1, size=1)) 79 | if sent[token_index] != SEP and sent[token_index] != CLS: 80 | mask_label.append(sent[token_index]) 81 | sent[token_index] = MASK 82 | mask_flag = True 83 | mask_pos.append(sent_index * max_len + token_index) 84 | mask_label = np.array(mask_label).astype("int64").reshape([-1]) 85 | mask_pos = np.array(mask_pos).astype("int64").reshape([-1]) 86 | 87 | multidev_batch_tokens.extend(batch_tokens) 88 | multidev_mask_label.append(mask_label) 89 | multidev_mask_pos.append(mask_pos) 90 | 91 | return multidev_batch_tokens, multidev_mask_label, multidev_mask_pos 92 | 93 | 94 | def prepare_batch_data(insts, 95 | total_token_num, 96 | max_len=None, 97 | voc_size=0, 98 | pad_id=None, 99 | cls_id=None, 100 | sep_id=None, 101 | mask_id=None, 102 | task_id=0, 103 | return_input_mask=True, 104 | return_max_len=True, 105 | return_num_token=False, 106 | dev_count=1): 107 | """ 108 | 1. generate Tensor of data 109 | 2. generate Tensor of position 110 | 3. generate self attention mask, [shape: batch_size * max_len * max_len] 111 | """ 112 | batch_src_ids = [inst[0] for inst in insts] 113 | batch_sent_ids = [inst[1] for inst in insts] 114 | batch_pos_ids = [inst[2] for inst in insts] 115 | 116 | # 这里是否应该反过来???否则在task layer里展开后的word embedding是padding后的,这时候word的index是跟没有padding时的index对不上的? 
117 | # First step: do mask without padding 118 | out, mask_label, mask_pos = mask( 119 | batch_src_ids, 120 | total_token_num, 121 | vocab_size=voc_size, 122 | CLS=cls_id, 123 | SEP=sep_id, 124 | MASK=mask_id, 125 | dev_count=dev_count) 126 | # Second step: padding 127 | src_id, self_input_mask = pad_batch_data( 128 | out, 129 | max_len=max_len, 130 | pad_idx=pad_id, return_input_mask=True) 131 | 132 | pos_id = pad_batch_data( 133 | batch_pos_ids, 134 | max_len=max_len, 135 | pad_idx=pad_id, 136 | return_pos=False, 137 | return_input_mask=False) 138 | sent_id = pad_batch_data( 139 | batch_sent_ids, 140 | max_len=max_len, 141 | pad_idx=pad_id, 142 | return_pos=False, 143 | return_input_mask=False) 144 | task_ids = np.ones_like( 145 | src_id, dtype="int64") * task_id 146 | return_list = [ 147 | src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos 148 | ] 149 | return return_list 150 | 151 | 152 | def pad_batch_data(insts, 153 | max_len=None, 154 | pad_idx=0, 155 | return_pos=False, 156 | return_input_mask=False, 157 | return_max_len=False, 158 | return_num_token=False): 159 | """ 160 | Pad the instances to the max sequence length in batch, and generate the 161 | corresponding position data and input mask. 162 | """ 163 | return_list = [] 164 | if max_len is None: 165 | max_len = max(len(inst) for inst in insts) 166 | # Any token included in dict can be used to pad, since the paddings' loss 167 | # will be masked out by weights and make no effect on parameter gradients. 168 | inst_data = np.array([ 169 | list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts 170 | ]) 171 | return_list += [inst_data.astype("int64").reshape([-1, max_len])] 172 | # position data 173 | if return_pos: 174 | inst_pos = np.array([ 175 | list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst)) 176 | for inst in insts 177 | ]) 178 | return_list += [inst_pos.astype("int64").reshape([-1, max_len])] 179 | if return_input_mask: 180 | # This is used to avoid attention on paddings. 181 | input_mask_data = np.array([[1] * len(inst) + [0] * 182 | (max_len - len(inst)) for inst in insts]) 183 | input_mask_data = np.expand_dims(input_mask_data, axis=-1) 184 | return_list += [input_mask_data.astype("float32")] 185 | if return_max_len: 186 | return_list += [max_len] 187 | if return_num_token: 188 | num_token = 0 189 | for inst in insts: 190 | num_token += len(inst) 191 | return_list += [num_token] 192 | return return_list if len(return_list) > 1 else return_list[0] 193 | 194 | 195 | if __name__ == "__main__": 196 | pass 197 | 198 | 199 | -------------------------------------------------------------------------------- /paddlepalm/reader/utils/mrqa_helper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | class MRQAExample(object):
17 |     """A single training/test example for MRQA-style reading comprehension.
18 | 
19 |     For examples without an answer, the start and end position are -1.
20 |     """
21 | 
22 |     def __init__(self,
23 |                  qas_id,
24 |                  question_text,
25 |                  doc_tokens,
26 |                  orig_answer_text=None,
27 |                  start_position=None,
28 |                  end_position=None,
29 |                  is_impossible=False):
30 |         self.qas_id = qas_id
31 |         self.question_text = question_text
32 |         self.doc_tokens = doc_tokens
33 |         self.orig_answer_text = orig_answer_text
34 |         self.start_position = start_position
35 |         self.end_position = end_position
36 |         self.is_impossible = is_impossible
37 | 
38 |     def __str__(self):
39 |         return self.__repr__()
40 | 
41 |     def __repr__(self):
42 |         s = ""
43 |         s += "qas_id: %s" % (self.qas_id)
44 |         s += ", question_text: %s" % (self.question_text)
45 |         s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
46 |         if self.start_position:
47 |             s += ", start_position: %d" % (self.start_position)
48 |         if self.end_position:
49 |             s += ", end_position: %d" % (self.end_position)
50 |         if self.is_impossible:
51 |             s += ", is_impossible: %r" % (self.is_impossible)
52 |         return s
53 | 
54 | 
55 | class MRQAFeature(object):
56 |     """A single set of features of data."""
57 | 
58 |     def __init__(self,
59 |                  unique_id,
60 |                  example_index,
61 |                  doc_span_index,
62 |                  tokens,
63 |                  token_to_orig_map,
64 |                  token_is_max_context,
65 |                  input_ids,
66 |                  input_mask,
67 |                  segment_ids,
68 |                  start_position=None,
69 |                  end_position=None,
70 |                  is_impossible=None):
71 |         self.unique_id = unique_id
72 |         self.example_index = example_index
73 |         self.doc_span_index = doc_span_index
74 |         self.tokens = tokens
75 |         self.token_to_orig_map = token_to_orig_map
76 |         self.token_is_max_context = token_is_max_context
77 |         self.input_ids = input_ids
78 |         self.input_mask = input_mask
79 |         self.segment_ids = segment_ids
80 |         self.start_position = start_position
81 |         self.end_position = end_position
82 |         self.is_impossible = is_impossible
83 | 
84 | 
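# Example (hypothetical values):
#   ex = MRQAExample(qas_id="q1", question_text="Who wrote it?",
#                    doc_tokens=["John", "wrote", "it"],
#                    orig_answer_text="wrote", start_position=1,
#                    end_position=1)
#   print(ex)  # -> qas_id: q1, question_text: Who wrote it?,
#              #    doc_tokens: [John wrote it], start_position: 1,
#              #    end_position: 1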
--------------------------------------------------------------------------------
/paddlepalm/tokenizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PaddlePaddle/PALM/2555c0e2a5fab1b702ae8d1c7612bef48c65af38/paddlepalm/tokenizer/__init__.py
--------------------------------------------------------------------------------
/paddlepalm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from . import basic_helper
3 | from . import config_helper
4 | 
5 | 
--------------------------------------------------------------------------------
/paddlepalm/utils/basic_helper.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | import json
4 | import yaml
5 | from .config_helper import PDConfig
6 | import logging
7 | from paddle import fluid
8 | 
9 | def get_basename(f):
10 |     return os.path.splitext(f)[0]
11 | 
12 | 
13 | def get_suffix(f):
14 |     return os.path.splitext(f)[-1]
15 | 
16 | 
17 | def parse_yaml(f, asdict=True, support_cmd_line=False):
18 |     assert os.path.exists(f), "file {} not found.".format(f)
19 |     if support_cmd_line:
20 |         args = PDConfig(yaml_file=f, fuse_args=True)
21 |         args.build()
22 |         return args.asdict() if asdict else args
23 |     else:
24 |         if asdict:
25 |             with open(f, "r") as fin:
26 |                 yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
27 |             return yaml_config
28 |         else:
29 |             raise NotImplementedError()
30 | 
31 | 
32 | def parse_json(f, asdict=True, support_cmd_line=False):
33 |     assert os.path.exists(f), "file {} not found.".format(f)
34 |     if support_cmd_line:
35 |         args = PDConfig(json_file=f, fuse_args=support_cmd_line)
36 |         args.build()
37 |         return args.asdict() if asdict else args
38 |     else:
39 |         if asdict:
40 |             with open(f, "r") as fin:
41 |                 config = json.load(fin)
42 |             return config
43 |         else:
44 |             raise NotImplementedError()
45 | 
46 | 
47 | def parse_list(string, astype=str):
48 |     assert isinstance(string, str), "{} is not a string.".format(string)
49 |     if ',' not in string:
50 |         return [astype(string)]
51 |     string = string.replace(',', ' ')
52 |     return [astype(i) for i in string.split()]
53 | 
54 | 
55 | def try_float(s):
56 |     try:
57 |         float(s)
58 |         return float(s)
59 |     except (TypeError, ValueError):
60 |         return s
61 | 
62 | 
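# For instance, parse_list("3,4,5", astype=int) returns [3, 4, 5] and
# parse_list("ernie") returns ["ernie"]; try_float("3e-5") returns the float
# 3e-05, while try_float("adam") returns the string unchanged.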
63 | # TODO: add a "None" mechanism that allows hidden size, batch size and seqlen to be set to None
64 | def check_io(in_attr, out_attr, strict=False, in_name="left", out_name="right"):
65 |     for name, attr in in_attr.items():
66 |         assert name in out_attr, in_name+': '+name+' not found in '+out_name
67 |         if attr != out_attr[name]:
68 |             if strict:
69 |                 raise ValueError(name+': shape or dtype not consistent!')
70 |             else:
71 |                 logging.warning('{}: shape or dtype not consistent!\n{}:\n{}\n{}:\n{}'.format(name, in_name, attr, out_name, out_attr[name]))
72 | 
73 | 
74 | def encode_inputs(inputs, scope_name, sep='.', cand_set=None):
75 |     outputs = {}
76 |     for k, v in inputs.items():
77 |         if cand_set is not None:
78 |             if k in cand_set:
79 |                 outputs[k] = v
80 |             if scope_name+sep+k in cand_set:
81 |                 outputs[scope_name+sep+k] = v
82 |         else:
83 |             outputs[scope_name+sep+k] = v
84 |     return outputs
85 | 
86 | 
87 | def decode_inputs(inputs, scope_name, sep='.', keep_unk_keys=True):
88 |     outputs = {}
89 |     for name, value in inputs.items():
90 |         # vars from the backbone are also available to tasks
91 |         if keep_unk_keys and sep not in name:
92 |             outputs[name] = value
93 |         # vars belonging to this task instance
94 |         if name.startswith(scope_name+'.'):
95 |             outputs[name[len(scope_name+'.'):]] = value
96 |     return outputs
97 | 
98 | 
99 | def build_executor(on_gpu):
100 |     if on_gpu:
101 |         place = fluid.CUDAPlace(0)
102 |         # dev_count = fluid.core.get_cuda_device_count()
103 |     else:
104 |         place = fluid.CPUPlace()
105 |         # dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
106 |     # return fluid.Executor(place), dev_count
107 |     return fluid.Executor(place)
108 | 
109 | 
110 | def fit_attr(conf, fit_attr, strict=False):
111 |     for i, attr in fit_attr.items():
112 |         if i not in conf:
113 |             if strict:
114 |                 raise Exception('Argument {} is required to create a controller.'.format(i))
115 |             else:
116 |                 continue
117 |         conf[i] = attr(conf[i])
118 |     return conf
119 | 
--------------------------------------------------------------------------------
/paddlepalm/utils/plot_helper.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PaddlePaddle/PALM/2555c0e2a5fab1b702ae8d1c7612bef48c65af38/paddlepalm/utils/plot_helper.py
--------------------------------------------------------------------------------
/paddlepalm/utils/print_helper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | MAXLEN = 70
17 | def print_dict(dic, title=""):
18 | 
19 |     if title:
20 |         title = ' ' + title + ' '
21 |         left_len = (MAXLEN - len(title)) // 2
22 |         title = '-' * left_len + title
23 |         right_len = MAXLEN - len(title)
24 |         title = title + '-' * right_len
25 |     else:
26 |         title = '-' * MAXLEN
27 |     print(title)
28 |     for name in dic:
29 |         print("{: <25}\t{}".format(str(name), str(dic[name])))
30 |     print("")
31 |     # print("-" * MAXLEN + '\n')
32 | 
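# For example, print_dict({"lr": 3e-05, "batch_size": 4}, title="conf") prints
# a 70-column dashed banner with "conf" centered, then one left-aligned,
# tab-separated key/value row per entry, followed by a blank line.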
--------------------------------------------------------------------------------
/paddlepalm/utils/saver.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from __future__ import print_function
17 | 
18 | import os
19 | import six
20 | import ast
21 | import copy
22 | import tarfile
23 | import shutil
24 | 
25 | import numpy as np
26 | import paddle.fluid as fluid
27 | 
28 | def init_checkpoint(exe, init_checkpoint_path, main_program, skip_list=[]):
29 |     assert os.path.exists(
30 |         init_checkpoint_path), "[%s] can't be found." % init_checkpoint_path
31 | 
32 |     def existed_persistables(var):
33 |         if not fluid.io.is_persistable(var):
34 |             return False
35 |         if var.name in skip_list:
36 |             return False
37 |         return os.path.exists(os.path.join(init_checkpoint_path, var.name))
38 | 
39 |     fluid.io.load_vars(
40 |         exe,
41 |         init_checkpoint_path,
42 |         main_program=main_program,
43 |         predicate=existed_persistables)
44 |     print("Load model from {}".format(init_checkpoint_path))
45 | 
46 | 
47 | def init_pretraining_params(exe,
48 |                             pretraining_params_path,
49 |                             convert,
50 |                             main_program,
51 |                             strict=False):
52 | 
53 |     assert os.path.exists(pretraining_params_path
54 |                           ), "[%s] can't be found." % pretraining_params_path
55 | 
56 |     if convert:
57 |         assert os.path.exists(os.path.join(pretraining_params_path, '__palmmodel__')), "__palmmodel__ not found."
58 | 
59 |         with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r') as f:
60 |             f.extractall(os.path.join(pretraining_params_path, '.temp'))
61 | 
62 |         log_path = os.path.join(pretraining_params_path, '__palmmodel__')
63 |         pretraining_params_path = os.path.join(pretraining_params_path, '.temp')
64 | 
65 |     else:
66 |         log_path = pretraining_params_path
67 | 
68 |     print("Loading pretraining parameters from {}...".format(pretraining_params_path))
69 | 
70 |     def existed_params(var):
71 |         if not isinstance(var, fluid.framework.Parameter):
72 |             return False
73 |         if not os.path.exists(os.path.join(pretraining_params_path, var.name)):
74 |             if strict:
75 |                 raise Exception('Error: {} not found in {}.'.format(var.name, log_path))
76 |             else:
77 |                 print('Warning: {} not found in {}.'.format(var.name, log_path))
78 |         return os.path.exists(os.path.join(pretraining_params_path, var.name))
79 | 
80 |     fluid.io.load_vars(
81 |         exe,
82 |         pretraining_params_path,
83 |         main_program=main_program,
84 |         predicate=existed_params)
85 |     if convert:
86 |         shutil.rmtree(pretraining_params_path)
87 |     print('')
88 | 
89 | 
--------------------------------------------------------------------------------
/paddlepalm/utils/textprocess_helper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | def is_whitespace(c):
17 |     if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
18 |         return True
19 |     return False
20 | 
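# For example, is_whitespace(" "), is_whitespace("\t") and
# is_whitespace(u"\u202f") (the narrow no-break space, ord 0x202F) all return
# True, while is_whitespace("a") returns False.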
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | 
3 | name = paddlepalm
4 | 
5 | author = zhangyiming
6 | author_email = zhangyiming04@baidu.com
7 | 
8 | version = 2.1.0
9 | 
10 | description = PaddlePALM
11 | long_description = file: README.md
12 | long_description_content_type = text/markdown
13 | 
14 | home_page = https://github.com/PaddlePaddle/PALM
15 | license = Apache 2.0
16 | 
17 | classifier = 
18 |     Private :: Do Not Upload
19 |     Programming Language :: Python
20 |     Programming Language :: Python :: 2
21 |     Programming Language :: Python :: 2.7
22 |     Programming Language :: Python :: 3
23 |     Programming Language :: Python :: 3.5
24 |     Programming Language :: Python :: 3.6
25 |     Programming Language :: Python :: 3.7
26 | 
27 | keywords = 
28 |     paddlepaddle
29 |     paddle
30 |     nlp
31 |     pretrain
32 |     multi-task-learning
33 | 
34 | [options]
35 | 
36 | packages = find:
37 | 
38 | include_package_data = True
39 | zip_safe = False
40 | 
41 | [sdist]
42 | dist_dir = output/dist
43 | 
44 | [bdist_wheel]
45 | dist_dir = output/dist
46 | 
47 | [easy_install]
48 | index_url = http://pip.baidu.com/root/baidu/+simple/
49 | 
50 | 
51 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | ################################################################################
3 | #
4 | # Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | #     http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | ################################################################################
18 | """
19 | Setup script.
20 | Authors: zhouxiangyang(zhouxiangyang@baidu.com)
21 | Date: 2020/2/4 00:00:01
22 | """
23 | import setuptools
24 | with open("README.md", "r") as fh:
25 |     long_description = fh.read()
26 | setuptools.setup(
27 |     name="paddlepalm",
28 |     version="2.1.0",
29 |     author="PaddlePaddle",
30 |     author_email="zhangyiming04@baidu.com",
31 |     description="a flexible, general and easy-to-use NLP large-scale pretraining and multi-task learning framework.",
32 |     # long_description=long_description,
33 |     # long_description_content_type="text/markdown",
34 |     url="https://github.com/PaddlePaddle/PALM",
35 |     # packages=setuptools.find_packages(),
36 |     packages = ['paddlepalm',
37 |                 'paddlepalm.backbone',
38 |                 'paddlepalm.backbone.utils',
39 |                 'paddlepalm.optimizer',
40 |                 'paddlepalm.reader',
41 |                 'paddlepalm.reader.utils',
42 |                 'paddlepalm.head',
43 |                 'paddlepalm.distribute',
44 |                 'paddlepalm.lr_sched',
45 |                 'paddlepalm.tokenizer',
46 |                 'paddlepalm.utils'],
47 |     package_dir={'paddlepalm':'./paddlepalm',
48 |                  'paddlepalm.backbone':'./paddlepalm/backbone',
49 |                  'paddlepalm.backbone.utils':'./paddlepalm/backbone/utils',
50 |                  'paddlepalm.optimizer':'./paddlepalm/optimizer',
51 |                  'paddlepalm.lr_sched': './paddlepalm/lr_sched',
52 |                  'paddlepalm.distribute': './paddlepalm/distribute',
53 |                  'paddlepalm.reader':'./paddlepalm/reader',
54 |                  'paddlepalm.reader.utils':'./paddlepalm/reader/utils',
55 |                  'paddlepalm.head':'./paddlepalm/head',
56 |                  'paddlepalm.tokenizer':'./paddlepalm/tokenizer',
57 |                  'paddlepalm.utils':'./paddlepalm/utils'},
58 |     platforms = "any",
59 |     license='Apache 2.0',
60 |     classifiers = [
61 |         'License :: OSI Approved :: Apache Software License',
62 |         'Programming Language :: Python',
63 |         'Programming Language :: Python :: 2',
64 |         'Programming Language :: Python :: 2.7',
65 |         'Programming Language :: Python :: 3',
66 |         'Programming Language :: Python :: 3.5',
67 |         'Programming Language :: Python :: 3.6',
68 |         'Programming Language :: Python :: 3.7',
69 |     ],
70 |     install_requires = [
71 |         'paddlepaddle-gpu>=1.8.0'
72 |     ]
73 | )
74 | 
75 | 
76 | 
--------------------------------------------------------------------------------
/test/test2/config.yaml:
--------------------------------------------------------------------------------
1 | task_instance: "mrqa, mlm4mrqa, match4mrqa"
2 | target_tag: 1, 0, 0
3 | mix_ratio: 1.0, 0.5, 0.5
4 | 
5 | save_path: "output_model/secondrun"
6 | 
7 | backbone: "ernie"
8 | backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
9 | 
10 | vocab_path: "../../pretrain_model/ernie/vocab.txt"
11 | do_lower_case: True
12 | max_seq_len: 512
13 | 
14 | batch_size: 4
15 | num_epochs: 2
16 | optimizer: "adam"
17 | learning_rate: 3e-5
18 | warmup_proportion: 0.1
19 | weight_decay: 0.1
20 | 
21 | print_every_n_steps: 1
22 | 
--------------------------------------------------------------------------------
/test/test2/paddlepalm:
--------------------------------------------------------------------------------
1 | ../../paddlepalm/
--------------------------------------------------------------------------------
/test/test2/run.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import paddlepalm as palm
3 | import json
4 | 
5 | if __name__ == '__main__':
6 | 
7 |     max_seqlen = 512
8 |     batch_size = 4
9 |     num_epochs = 2
10 |     lr = 1e-3
11 |     vocab_path = './pretrain/ernie/vocab.txt'
12 | 
13 |     train_file = './data/cls4mrqa/train.tsv'
14 |     predict_file = './data/cls4mrqa/dev.tsv'
15 | 
16 |     config = json.load(open('./pretrain/ernie/ernie_config.json'))
17 |     # ernie = palm.backbone.ERNIE(...)
18 |     ernie = palm.backbone.ERNIE.from_config(config)
19 | 
20 |     # cls_reader2 = palm.reader.cls(train_file_topic, vocab_path, batch_size, max_seqlen)
21 |     # cls_reader3 = palm.reader.cls(train_file_subj, vocab_path, batch_size, max_seqlen)
22 |     # topic_trainer = palm.Trainer('topic_cls', cls_reader2, cls)
23 |     # subj_trainer = palm.Trainer('subj_cls', cls_reader3, cls)
24 | 
25 |     # Create the reader for this classification task; its arguments control the dataset format, the number of input files, preprocessing rules, and so on
26 |     cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen)
27 |     cls_reader2 = palm.reader.ClassifyReader(vocab_path, max_seqlen)
28 |     print(cls_reader.outputs_attr)
29 |     # Different backbones require different input features from a task reader. For a classification task the basic features are token_ids and label_ids, but BERT-style backbones additionally need position, segment and input_mask features extracted from the input, so after register_with the reader automatically adds whatever fields the backbone requires
30 |     cls_reader.register_with(ernie)
31 |     cls_reader2.register_with(ernie)
32 |     print(cls_reader.outputs_attr)
33 | 
34 |     print("preparing data...")
35 |     print(cls_reader.num_examples)
36 |     cls_reader.load_data(train_file, batch_size)
37 |     cls_reader2.load_data(train_file, batch_size)
38 |     print(cls_reader.num_examples)
39 |     print('done!')
40 | 
41 |     # Create the task heads (e.g. classification, matching, machine reading comprehension). Each head has task-specific required/optional arguments. Note that heads are decoupled from readers: a head is valid as long as the dataset-side fields it depends on can be supplied by the reader
42 |     cls_head = palm.head.Classify(4, 1024, 0.1)
43 |     cls_head2 = palm.head.Classify(4, 1024, 0.1)
44 | 
45 |     # Create trainers from the readers and task heads. A trainer represents one training task: it maintains the training progress and the task's key information, performs validity checks, and controls model saving/loading rules for the task
46 |     trainer = palm.Trainer('cls')
47 |     trainer2 = palm.Trainer('senti_cls')
48 |     mh_trainer = palm.MultiHeadTrainer([trainer, trainer2])
49 | 
50 |     # match4mrqa.reuse_head_with(mrc4mrqa)
51 | 
52 |     # data_vars = cls_reader.build()
53 |     # output_vars = ernie.build(data_vars)
54 |     # cls_head.build({'backbone': output_vars, 'reader': data_vars})
55 | 
56 |     loss_var = mh_trainer.build_forward(ernie, [cls_head, cls_head2])
57 | 
58 |     n_steps = cls_reader.num_examples * num_epochs // batch_size
59 |     warmup_steps = int(0.1 * n_steps)
60 |     print(warmup_steps)
61 |     sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
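    # Worked example (hypothetical numbers): with 4,000 training examples,
    # num_epochs=2 and batch_size=4, n_steps = 4000 * 2 // 4 = 2000 and
    # warmup_steps = int(0.1 * 2000) = 200, i.e. the triangular schedule ramps
    # the learning rate up over the first 200 steps and decays it over the
    # remaining 1800.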
62 | 
63 |     adam = palm.optimizer.Adam(loss_var, lr, sched)
64 | 
65 |     mh_trainer.build_backward(optimizer=adam, weight_decay=0.001)
66 | 
67 |     # mh_trainer.random_init_params()
68 |     mh_trainer.load_pretrain('pretrain/ernie/params')
69 | 
70 |     # trainer.train(iterator_fn, print_steps=1, save_steps=5, save_path='outputs', save_type='ckpt,predict')
71 |     mh_trainer.fit_readers_with_mixratio([cls_reader, cls_reader2], 'cls', 2)
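    # fit_readers_with_mixratio presumably samples batches from the two
    # readers according to the configured mix ratios, using the named task
    # ('cls') as the reference for epoch counting: training ends once 'cls'
    # has been fed for 2 epochs, regardless of how many batches of the other
    # task were mixed in.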
72 |     mh_trainer.train(print_steps=1)
73 |     # trainer.save()
74 | 
75 | 
--------------------------------------------------------------------------------
/test/test2/run.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=3
2 | python run.py
3 | 
4 | 
--------------------------------------------------------------------------------
/test/test3/config.yaml:
--------------------------------------------------------------------------------
1 | task_instance: "cls1, cls2, cls3, cls4, cls5, cls6"
2 | 
3 | task_reuse_tag: 0,0,1,1,0,2
4 | 
5 | save_path: "output_model/thirdrun"
6 | 
7 | backbone: "ernie"
8 | backbone_config_path: "../../pretrain_model/ernie/ernie_config.json"
9 | 
10 | vocab_path: "../../pretrain_model/ernie/vocab.txt"
11 | do_lower_case: True
12 | max_seq_len: 512
13 | 
14 | batch_size: 4
15 | num_epochs: 2
16 | optimizer: "adam"
17 | learning_rate: 3e-5
18 | warmup_proportion: 0.1
19 | weight_decay: 0.1
20 | 
21 | print_every_n_steps: 1
22 | 
--------------------------------------------------------------------------------
/test/test3/paddlepalm:
--------------------------------------------------------------------------------
1 | ../../paddlepalm/
--------------------------------------------------------------------------------
/test/test3/run.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import paddlepalm as palm
3 | import json
4 | 
5 | if __name__ == '__main__':
6 | 
7 |     max_seqlen = 512
8 |     batch_size = 4
9 |     num_epochs = 2
10 |     lr = 1e-3
11 |     vocab_path = './pretrain/ernie/vocab.txt'
12 | 
13 |     train_file = './data/cls4mrqa/train.tsv'
14 |     predict_file = './data/cls4mrqa/dev.tsv'
15 | 
16 |     config = json.load(open('./pretrain/ernie/ernie_config.json'))
17 |     # ernie = palm.backbone.ERNIE(...)
18 |     ernie = palm.backbone.ERNIE.from_config(config)
19 | 
20 |     # Create the reader for this classification task; its arguments control the dataset format, the number of input files, preprocessing rules, and so on
21 |     cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen)
22 |     predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, phase='predict')
23 |     print(cls_reader.outputs_attr)
24 |     print(predict_cls_reader.outputs_attr)
25 |     # Different backbones require different input features from a task reader. For a classification task the basic features are token_ids and label_ids, but BERT-style backbones additionally need position, segment and input_mask features extracted from the input, so after register_with the reader automatically adds whatever fields the backbone requires
26 |     cls_reader.register_with(ernie)
27 |     print(cls_reader.outputs_attr)
28 |     print(predict_cls_reader.outputs_attr)
29 | 
30 |     print("preparing data...")
31 |     print(cls_reader.num_examples)
32 |     cls_reader.load_data(train_file, batch_size, num_epochs=num_epochs)
33 |     print(cls_reader.num_examples)
34 |     print('done!')
35 | 
36 |     # Create the task head (e.g. classification, matching, machine reading comprehension). Each head has task-specific required/optional arguments. Note that heads are decoupled from readers: a head is valid as long as the dataset-side fields it depends on can be supplied by the reader
37 |     cls_head = palm.head.Classify(4, 1024, 0.1)
38 | 
39 |     # Create a trainer from the reader and the task head. A trainer represents one training task: it maintains the training progress and the task's key information, performs validity checks, and controls model saving/loading rules for the task
40 |     trainer = palm.Trainer('senti_cls')
41 | 
42 |     # match4mrqa.reuse_head_with(mrc4mrqa)
43 | 
44 |     # data_vars = cls_reader.build()
45 |     # output_vars = ernie.build(data_vars)
46 |     # cls_head.build({'backbone': output_vars, 'reader': data_vars})
47 | 
48 |     loss_var = trainer.build_forward(ernie, cls_head)
49 | 
50 |     # controller.build_forward()
51 |     # Error! A head/backbone can only be built once! Do NOT call build_forward again from another Trainer!
52 | 
53 |     # n_steps = cls_reader.num_examples * num_epochs // batch_size
54 |     # warmup_steps = int(0.1 * n_steps)
55 |     # print(warmup_steps)
56 |     # sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
57 |     sched = None
58 | 
59 |     adam = palm.optimizer.Adam(loss_var, lr, sched)
60 | 
61 |     trainer.build_backward(optimizer=adam, weight_decay=0.001)
62 | 
63 |     trainer.load_pretrain('pretrain/ernie/params')
64 | 
65 |     # trainer.train(iterator_fn, print_steps=1, save_steps=5, save_path='outputs', save_type='ckpt,predict')
66 |     trainer.fit_reader(cls_reader)
67 |     trainer.train(print_steps=1)
68 |     # trainer.save()
69 | 
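    # For prediction, the backbone and head below are rebuilt in phase='pred'
    # and attached to the same trainer, which appears to reuse the parameters
    # just trained; the predict reader is registered against the predict
    # backbone before being fitted.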
70 |     print('prepare to predict...')
71 |     pred_ernie = palm.backbone.ERNIE.from_config(config, phase='pred')
72 |     cls_pred_head = palm.head.Classify(4, 1024, phase='pred')
73 |     trainer.build_predict_forward(pred_ernie, cls_pred_head)
74 | 
75 |     predict_cls_reader.load_data(predict_file, 8)
76 |     print(predict_cls_reader.num_examples)
77 |     predict_cls_reader.register_with(pred_ernie)
78 |     trainer.fit_reader(predict_cls_reader, phase='predict')
79 |     print('predicting..')
80 |     trainer.predict(print_steps=20)
81 | 
82 | 
83 |     # controller = palm.Controller([mrqa, match4mrqa, mlm4mrqa])
84 | 
85 |     # loss = controller.build_forward(bb, mask_task=[])
86 | 
87 |     # n_steps = controller.estimate_train_steps(basetask=mrqa, num_epochs=2, batch_size=8, dev_count=4)
88 |     # adam = palm.optimizer.Adam(loss)
89 |     # sched = palm.schedualer.LinearWarmup(learning_rate, max_train_steps=n_steps, warmup_steps=0.1*n_steps)
90 |     #
91 |     # controller.build_backward(optimizer=adam, schedualer=sched, weight_decay=0.001, use_ema=True, ema_decay=0.999)
92 | 
93 |     # controller.random_init_params()
94 |     # controller.load_pretrain('../../pretrain_model/ernie/params')
95 |     # controller.train()
96 | 
97 | 
98 |     # controller = palm.Controller(config='config.yaml', task_dir='tasks', for_train=False)
99 |     # controller.pred('mrqa', inference_model_dir='output_model/secondrun/mrqa/infer_model')
100 | 
101 | 
--------------------------------------------------------------------------------
/test/test3/run.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=3
2 | 
3 | python run.py
4 | 
5 | 
--------------------------------------------------------------------------------