├── img ├── model.png ├── framework.pdf ├── framework.png └── pseudo-labeling.png ├── requirements.txt ├── layers ├── losses │ ├── __init__.py │ └── focal_loss.py ├── __init__.py ├── masked_softmax.py └── attention.py ├── data_loader ├── __init__.py ├── base_loader.py └── multitask_classify_loader.py ├── callbacks ├── __init__.py ├── metric.py ├── ensemble.py └── lr_scheduler.py ├── raw_data └── embeddings │ └── bert-base-uncased │ └── bert-base-uncased-config.json ├── utils ├── __init__.py ├── other.py ├── metrics.py ├── embedding.py ├── nn.py ├── transformers.py └── io.py ├── models ├── __init__.py └── base_model.py ├── .gitignore ├── ensemble.py ├── ensemble_pseudo_label.py ├── README.md ├── train.py ├── train_pseudo_label.py ├── two_level_ensembler.py ├── preprocess.py ├── config.py └── trainer.py /img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexYangLi/iswc2020_prodcls/HEAD/img/model.png -------------------------------------------------------------------------------- /img/framework.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexYangLi/iswc2020_prodcls/HEAD/img/framework.pdf -------------------------------------------------------------------------------- /img/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexYangLi/iswc2020_prodcls/HEAD/img/framework.png -------------------------------------------------------------------------------- /img/pseudo-labeling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexYangLi/iswc2020_prodcls/HEAD/img/pseudo-labeling.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.1.0 2 | gensim==3.8.3 3 | nltk==3.5 4 | scikit-learn==0.23.1 5 | matplotlib==3.2.1 6 | torch==1.5.0 7 | xgboost==1.1.1 8 | googletrans==3.0.0 9 | pandas==1.0.5 -------------------------------------------------------------------------------- /layers/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: __init__.py.py 10 | 11 | @time: 2020/5/29 22:18 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from .focal_loss import multi_category_focal_loss2 18 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: __init__.py.py 10 | 11 | @time: 2020/5/21 21:01 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from .base_loader import load_data 18 | from .multitask_classify_loader import MultiTaskClsDataGenerator 19 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: __init__.py.py 10 | 11 | @time: 2020/5/21 14:42 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from .masked_softmax 
import MaskedSoftmax 18 | from .attention import HierarchicalAttentionRecurrent, HierarchicalAttention 19 | -------------------------------------------------------------------------------- /callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: __init__.py.py 10 | 11 | @time: 2020/4/26 9:07 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from .ensemble import SWA, SWAWithCLR, SnapshotEnsemble, HorizontalEnsemble, FGE, PolyakAverage 18 | from .lr_scheduler import LRRangeTest, CyclicLR, SGDR, SGDRScheduler, CyclicLR_1, CyclicLR_2, WarmUp 19 | from .metric import MultiTaskMetric, MaskedMultiTaskMetric 20 | -------------------------------------------------------------------------------- /raw_data/embeddings/bert-base-uncased/bert-base-uncased-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522 19 | } 20 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: __init__.py.py 10 | 11 | @time: 2020/5/17 16:02 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from .io import format_filename, pickle_load, pickle_dump, write_log, writer_md, ensure_dir, save_prob_to_file, \ 18 | save_diff_to_file, submit_result 19 | from .embedding import train_w2v, train_fasttext, load_pre_trained 20 | from .other import analyze_len, pad_sequences_1d 21 | from .nn import get_optimizer 22 | from .transformers import get_bert_tokenizer, get_bert_config, get_transformer 23 | from .metrics import precision_recall_fscore 24 | -------------------------------------------------------------------------------- /layers/masked_softmax.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: masked_softmax.py 10 | 11 | @time: 2020/5/30 22:18 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import tensorflow.keras.backend as K 18 | from tensorflow.keras.layers import Layer 19 | 20 | 21 | class MaskedSoftmax(Layer): 22 | def __init__(self, **kwargs): 23 | super(MaskedSoftmax, self).__init__(**kwargs) 24 | 25 | def call(self, inputs, mask=None): 26 | prob = inputs[0] 27 | mask = inputs[1] 28 | 29 | exp_prob = K.exp(prob) * mask + 1e-10 30 | return exp_prob / K.sum(exp_prob, axis=1, keepdims=True) 31 | 32 | def compute_output_shape(self, input_shape): 33 | return input_shape 34 | -------------------------------------------------------------------------------- /utils/other.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: 
other.py 10 | 11 | @time: 2020/5/17 16:02 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import numpy as np 18 | from tensorflow.keras.preprocessing.sequence import pad_sequences 19 | 20 | 21 | def analyze_len(len_list): 22 | def frange(x, y, jump): 23 | while x < y: 24 | yield x 25 | x += jump 26 | 27 | sort_len_list = sorted(len_list) 28 | print(f'max : {sort_len_list[-1]}') 29 | print(f'min : {sort_len_list[0]}') 30 | print(f'mean : {np.mean(sort_len_list)}') 31 | print(f'median : {np.median(sort_len_list)}') 32 | for i in frange(0.5, 1, 0.1): 33 | print(f'{i:.2f} : {sort_len_list[int(len(sort_len_list) * i)]}') 34 | 35 | 36 | def pad_sequences_1d(sequences, max_len=None, padding='post', truncating='post', value=0.): 37 | """pad sequence for [[a, b, c, ...]]""" 38 | return pad_sequences(sequences, maxlen=max_len, padding=padding, truncating=truncating, value=value) 39 | -------------------------------------------------------------------------------- /data_loader/base_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: base_loader.py 10 | 11 | @time: 2020/5/21 21:01 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from config import PROCESSED_DATA_DIR, TRAIN_DATA_FILENAME, DEV_DATA_FILENAME, TEST_DATA_FILENAME, \ 18 | TRAIN_CV_DATA_TEMPLATE, DEV_CV_DATA_TEMPLATE 19 | from utils import pickle_load, format_filename 20 | 21 | 22 | def load_data(data_type, train_on_cv=False, cv_random_state=42, cv_fold=5, cv_index=0): 23 | if data_type == 'train': 24 | if train_on_cv: 25 | data = pickle_load(format_filename(PROCESSED_DATA_DIR, TRAIN_CV_DATA_TEMPLATE, 26 | random=cv_random_state, fold=cv_fold, index=cv_index)) 27 | else: 28 | data = pickle_load(format_filename(PROCESSED_DATA_DIR, TRAIN_DATA_FILENAME)) 29 | elif data_type == 'dev': 30 | if train_on_cv: 31 | data = pickle_load(format_filename(PROCESSED_DATA_DIR, DEV_CV_DATA_TEMPLATE, 32 | random=cv_random_state, fold=cv_fold, index=cv_index)) 33 | else: 34 | data = pickle_load(format_filename(PROCESSED_DATA_DIR, DEV_DATA_FILENAME)) 35 | elif data_type == 'test': 36 | data = pickle_load(format_filename(PROCESSED_DATA_DIR, TEST_DATA_FILENAME)) 37 | else: 38 | raise ValueError('data type not understood: {}'.format(data_type)) 39 | return data 40 | -------------------------------------------------------------------------------- /layers/losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: focal_loss.py 10 | 11 | @time: 2020/5/29 22:19 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import tensorflow as tf 18 | 19 | 20 | def multi_category_focal_loss2(gamma=2., alpha=.25): 21 | """ 22 | focal loss for multi-class or multi-label problems 23 | suitable for multi-class or multi-label classification 24 | alpha controls the weights applied when the ground truth y_true is 1 or 0: 25 | the weight for 1 is alpha, the weight for 0 is 1 - alpha 26 | if your model underfits and has trouble learning, try using this function as the loss 27 | if the model is too aggressive (it always tends to predict 1), try lowering alpha 28 | if the model is too passive (it always tends to predict 0 or a fixed constant, meaning it has not learned useful features), 29 | try raising alpha to encourage the model to predict 1 30 | Usage: 31 | model.compile(loss=[multi_category_focal_loss2(alpha=0.25, gamma=2)], metrics=["accuracy"], optimizer=adam) 32 | """ 33 | epsilon = 1.e-7 34 | gamma = float(gamma) 35 | alpha = tf.constant(alpha, dtype=tf.float32) 36 | 37 | def focal_loss(y_true, y_pred): 38 | y_true = tf.cast(y_true, tf.float32) 39 | y_pred = tf.clip_by_value(y_pred, epsilon, 1.
- epsilon) 40 | 41 | alpha_t = y_true * alpha + (tf.ones_like(y_true) - y_true) * (1 - alpha) 42 | y_t = tf.multiply(y_true, y_pred) + tf.multiply(1 - y_true, 1 - y_pred) 43 | ce = -tf.math.log(y_t) 44 | weight = tf.pow(tf.subtract(1., y_t), gamma) 45 | fl = tf.multiply(tf.multiply(weight, ce), alpha_t) 46 | loss = tf.reduce_mean(fl) 47 | return loss 48 | 49 | return focal_loss 50 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: metrics.py 10 | 11 | @time: 2020/6/13 22:19 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import numpy as np 18 | from sklearn.metrics import precision_recall_fscore_support 19 | 20 | 21 | def precision_recall_fscore(pred_cate1_list, true_cate1_list, 22 | pred_cate2_list, true_cate2_list, 23 | pred_cate3_list, true_cate3_list): 24 | cate1_p, cate1_r, cate1_f1, _ = precision_recall_fscore_support(true_cate1_list, 25 | pred_cate1_list, 26 | average='weighted') 27 | print(f'Logging Info - Level 1 Category: ({cate1_p}, {cate1_r}, {cate1_f1})') 28 | 29 | cate2_p, cate2_r, cate2_f1, _ = precision_recall_fscore_support(true_cate2_list, 30 | pred_cate2_list, 31 | average='weighted') 32 | print(f'Logging Info - Level 2 Category: ({cate2_p}, {cate2_r}, {cate2_f1})') 33 | 34 | cate3_p, cate3_r, cate3_f1, _ = precision_recall_fscore_support(true_cate3_list, 35 | pred_cate3_list, 36 | average='weighted') 37 | print(f'Logging Info - Level 3 Category: ({cate3_p}, {cate3_r}, {cate3_f1})') 38 | 39 | val_p = np.mean([cate1_p, cate2_p, cate3_p]) 40 | val_r = np.mean([cate1_r, cate2_r, cate3_r]) 41 | val_f1 = np.mean([cate1_f1, cate2_f1, cate3_f1]) 42 | print(f'Logging Info - ALL Level Category: ({val_p}, {val_r}, {val_f1})') 43 | 44 | eval_results = { 45 | 'cate1_p': cate1_p, 'cate1_r': cate1_r, 'cate1_f1': cate1_f1, 46 | 'cate2_p': cate2_p, 'cate2_r': cate2_r, 'cate2_f1': cate2_f1, 47 | 'cate3_p': cate3_p, 'cate3_r': cate3_r, 'cate3_f1': cate3_f1, 48 | 'val_p': val_p, 'val_r': val_r, 'val_f1': val_f1 49 | } 50 | return eval_results 51 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: __init__.py.py 10 | 11 | @time: 2020/5/21 7:24 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from .multitask_classify_model import BiLSTM, CNNRNN, DPCNN, MultiTextCNN, RCNN, RNNCNN, TextCNN, BertClsModel, \ 18 | RobertaClsModel, XLNetClsModel, GPT2ClsModel, TransfoXLClsModel, DistllBertClsModel 19 | from .sklearn_base_model import BernoulliNBModel, DecisionTreeModel, ExtraTreeModel, EnsembleExtraTreeModel, \ 20 | GaussianNBModel, KNeighborsModel, LDAModel, LinearSVCModel, LRCVModel, LRModel, MLPModel, RandomForestModel, \ 21 | GBDTModel, XGBoostModel 22 | 23 | MultiTaskClsModel = { 24 | 'bilstm': BiLSTM, 25 | 'cnnrnn': CNNRNN, 26 | 'dpcnn': DPCNN, 27 | 'multicnn': MultiTextCNN, 28 | 'rnncnn': RNNCNN, 29 | 'cnn': TextCNN, 30 | 'bert-base-uncased': BertClsModel, 31 | 'bert-base-cased': BertClsModel, 32 | 'bert-large-uncased': BertClsModel, 33 | 'bert-large-uncased-whole-word-masking': BertClsModel, 34 | 'bert-large-uncased-whole-word-masking-finetuned-squad': BertClsModel, 35 | 'roberta-base': 
RobertaClsModel, 36 | 'prod-bert-base-uncased': BertClsModel, 37 | 'prod-roberta-base-cased': RobertaClsModel, 38 | 'roberta-large': RobertaClsModel, 39 | 'roberta-large-mnli': RobertaClsModel, 40 | 'distilroberta-base': RobertaClsModel, 41 | 'tune_bert-base-uncased_nsp': BertClsModel, 42 | 'xlnet-base-cased': XLNetClsModel, 43 | 'albert-base-v1': BertClsModel, 44 | 'albert-large-v1': BertClsModel, 45 | 'albert-xlarge-v1': BertClsModel, 46 | 'albert-xxlarge-v1': BertClsModel, 47 | 'gpt2': GPT2ClsModel, 48 | 'gpt2-medium': GPT2ClsModel, 49 | 'transfo-xl': TransfoXLClsModel, 50 | 'distilbert-base-uncased': DistllBertClsModel, 51 | 'distilbert-base-uncased-distilled-squad': DistllBertClsModel, 52 | } 53 | 54 | 55 | SklearnEnsembleModel = { 56 | 'bnb': BernoulliNBModel, # outputs probabilities 57 | 'dt': DecisionTreeModel, # outputs class labels 58 | 'et': ExtraTreeModel, # outputs class labels 59 | 'eet': EnsembleExtraTreeModel, # outputs class labels 60 | 'gnb': GaussianNBModel, # outputs class labels 61 | 'kn': KNeighborsModel, # outputs class labels 62 | 'lda': LDAModel, 63 | 'svc': LinearSVCModel, 64 | 'lr': LRModel, 65 | 'mlp': MLPModel, 66 | 'rf': RandomForestModel, 67 | 'gbdt': GBDTModel, 68 | 'xgboost': XGBoostModel 69 | } 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea/ 132 | -------------------------------------------------------------------------------- /ensemble.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: ensemble.py 10 | 11 | @time: 2020/6/8 22:01 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | 18 | from two_level_ensembler import voting_of_averaging 19 | 20 | 21 | if __name__ == '__main__': 22 | cv_model_list = [ 23 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_5_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 24 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_4_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 25 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_3_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 26 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_2_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 27 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_1_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 28 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_5_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 29 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_4_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 30 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_3_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 31 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_2_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 32 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_1_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 33 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_1_harl_100_1.0_1.0_1.0_adam_2e-05_32_50', 34 | 'bert-base-uncased_name_desc_pair_len_200_bert_lstm_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 35 | 'bert-base-uncased_name_desc_pair_len_200_bert_gru_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 36 | 'bert-base-uncased_name_desc_pair_len_200_bert_gru_lstm_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 37 | 'bert-base-uncased_name_desc_pair_len_200_bert_lstm_gru_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 38 | 'bert-base-uncased_name_desc_pair_len_200_bert_cnn_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 39 | 'bert-base-uncased_name_desc_pair_len_200_bert_lstm_cnn_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 40 | 'bert-base-uncased_name_desc_pair_len_200_bert_pooler_tune_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50' 41 | ] 42 | voting_of_averaging(prefix_model_name_list=cv_model_list, 43 | submit_file_prefix='cross_validation', 44 | cv_random_state=42, 45 | 
cv_fold=5, 46 | use_ex_pair=False, 47 | use_pseudo=False) 48 | 49 | -------------------------------------------------------------------------------- /ensemble_pseudo_label.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: ensemble_pseudo_label.py 10 | 11 | @time: 2020/6/8 22:01 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | 18 | from two_level_ensembler import voting_of_averaging 19 | 20 | 21 | if __name__ == '__main__': 22 | '''4. Ensemble of the cross-validation models trained with pseudo labels''' 23 | cv_model_list = [ 24 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_5_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 25 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_4_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 26 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_3_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 27 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_2_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 28 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_pooler_tune_hid_1_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 29 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_5_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 30 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_4_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 31 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_3_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 32 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_2_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 33 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_1_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 34 | 'bert-base-uncased_name_desc_pair_len_200_bert_hidden_tune_hid_1_harl_100_1.0_1.0_1.0_adam_2e-05_32_50', 35 | 'bert-base-uncased_name_desc_pair_len_200_bert_lstm_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 36 | 'bert-base-uncased_name_desc_pair_len_200_bert_gru_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 37 | 'bert-base-uncased_name_desc_pair_len_200_bert_gru_lstm_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 38 | 'bert-base-uncased_name_desc_pair_len_200_bert_lstm_gru_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 39 | 'bert-base-uncased_name_desc_pair_len_200_bert_cnn_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 40 | 'bert-base-uncased_name_desc_pair_len_200_bert_lstm_cnn_tune_dense_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50', 41 | 'bert-base-uncased_name_desc_pair_len_200_bert_pooler_tune_1.0_1.0_1.0_mask_cate3_with_cate1_adam_2e-05_32_50' 42 | ] 43 | voting_of_averaging(prefix_model_name_list=cv_model_list, 44 | submit_file_prefix='cross_validation', 45 | cv_random_state=42, 46 | cv_fold=5, 47 | use_ex_pair=False, 48 | use_pseudo=True, 49 | pseudo_random_state=42, 50 | pseudo_rate=5) 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # iswc2020_prodcls 2 | ISWC2020 Semantic Web Challenge - Product Classification Top1 Solution 3 | 4 | ## Solution 5 | 6 | ### Overall Framework 7 | ![framework](./img/framework.png) 8 | 9 | 1. Train 17 different BERT base models with dynamic masked softmax 10 | 2. Adopt a two-level ensemble strategy to combine the single models 11 | 3. Utilize pseudo labeling for data augmentation 12 | 13 | ### Base Model Construction 14 | ![model](./img/model.png) 15 | 16 | - **BERT-PO** uses the pooler output of BERT as the product representation. 17 | - **BERT-*K*-hidden** concatenates the first hidden state from the last $K$ hidden layers of BERT as the product representation (see the sketch after this list). 18 | - **BERT-*K*-hidden-PO** concatenates the first hidden state from the last $K$ hidden layers as well as the pooler output of BERT as the product representation. 19 | - **BERT-seq** feeds the hidden states from the last hidden layer of BERT into another sequence layer, and then concatenates the pooler output of BERT, the last hidden output of the sequence layer, and the max-pooling and mean-pooling over its hidden states as the final product representation. 20 |
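Below is a minimal sketch of how the **BERT-*K*-hidden** representation can be assembled with the `transformers` TF API. It is only an illustration: the helper name `bert_k_hidden` is made up, weights are fetched by hub name here (the repository itself loads local config and checkpoint files via `utils/transformers.py`), and the real models add the multi-task classification heads on top of this representation.

```python
import tensorflow as tf
from transformers import BertConfig, TFBertModel

# Illustrative only: load by hub name; the repository loads local config/checkpoint files instead.
config = BertConfig.from_pretrained('bert-base-uncased', output_hidden_states=True)
bert = TFBertModel.from_pretrained('bert-base-uncased', config=config)


def bert_k_hidden(input_ids, attention_mask, k=3):
    """BERT-K-hidden: concatenate the first-position hidden state of the last k layers."""
    hidden_states = bert(input_ids, attention_mask=attention_mask)[2]  # embedding output + one tensor per layer
    cls_states = [h[:, 0] for h in hidden_states[-k:]]                 # first hidden state of each of the last k layers
    return tf.concat(cls_states, axis=-1)                              # shape: (batch_size, k * hidden_size)
```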
21 | ### Dynamic Masked Softmax 22 | 23 | 1. Devise a mask matrix for each sub-level based on the category hierarchy 24 | $$ 25 | M^{l} \in\{0,1\}^{N^{l-1} \times N^{l}} 26 | $$ 27 | 2. Adopt dynamic masked softmax to filter out the unrelated categories (a code sketch follows the formula) 28 | $$ 29 | P\left(y_{v}^{l} \mid s, \theta\right)=\frac{\exp \left(O_{v}^{l}\right) \cdot M_{u, v}^{l}+\exp (-8)}{\sum_{v^{\prime}=1}^{N^{l}} \exp \left(O_{v^{\prime}}^{l}\right) \cdot M_{u, v^{\prime}}^{l}+\exp (-8)} 30 | $$ 31 |
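A minimal TensorFlow sketch of this masked softmax is shown below. The function name and the toy values are illustrative; like the repository's `MaskedSoftmax` layer in `layers/masked_softmax.py`, it adds the small constant to every term before normalizing, so categories that are not children of the predicted parent category receive near-zero probability.

```python
import tensorflow as tf


def dynamic_masked_softmax(logits, mask):
    """Masked softmax over the level-l logits O^l; `mask` is the row of M^l
    selected by the parent-level prediction u."""
    eps = tf.exp(-8.0)                            # small constant, as in the formula above
    exp_logits = tf.exp(logits) * mask + eps      # unrelated categories keep only the epsilon
    return exp_logits / tf.reduce_sum(exp_logits, axis=-1, keepdims=True)


# toy check: three level-2 categories, only the first two are children of the
# predicted level-1 category, so the third probability collapses towards zero
logits = tf.constant([[2.0, 0.5, 3.0]])
mask = tf.constant([[1.0, 1.0, 0.0]])
print(dynamic_masked_softmax(logits, mask).numpy())
```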
32 | ### Model Ensemble 33 | 34 | 1. Apply an averaging ensemble to single models that share the same architecture but are trained on different folds of the data 35 | 2. Apply a voting ensemble to the 17 different single models 36 | 37 | ### Pseudo Labeling 38 | ![pseudo-labeling](./img/pseudo-labeling.png) 39 | 40 | ## Experiment 41 | 42 | ### Environment Preparation 43 | ```shell script 44 | pip install virtualenv 45 | virtualenv tf2 46 | source tf2/bin/activate 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | ### Data Preparation 51 | 52 | 1. Dataset 53 | Download the dataset files below and put them in the `raw_data` dir: 54 | - [train.json](https://drive.google.com/open?id=1WirDfqGvBYgly27egMx6Om9QeXO6B2UX) 55 | - [validation.json](https://drive.google.com/open?id=1WirDfqGvBYgly27egMx6Om9QeXO6B2UX) 56 | - [test_public.json](https://bit.ly/2Yr0dkb) 57 | - [task2_testset_with_labels.json](https://drive.google.com/file/d/1RI27LIp_s-LP10eKKNWz914bapfRhOJl/view?usp=sharing) 58 | 59 | 60 | 2. Pre-trained BERT model 61 | Download the BERT model files below and put them in the `raw_data/embeddings/bert-base-uncased` dir: 62 | - [bert-base-uncased-vocab.txt](https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt) 63 | - [bert-base-uncased-config.json](https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json) 64 | - [bert-base-uncased-tf_model.h5](https://cdn.huggingface.co/bert-base-uncased-tf_model.h5) 65 | 66 | 67 | ### Pre-processing 68 | ```shell script 69 | python preprocess.py 70 | ``` 71 | 72 | ### (1st) Training 73 | ```shell script 74 | python train.py 75 | ``` 76 | 77 | ### (1st) Ensemble and Pseudo Labeling 78 | ```shell script 79 | python ensemble.py 80 | ``` 81 | 82 | ### (2nd) Re-training with Pseudo Labels 83 | ```shell script 84 | python train_pseudo_label.py 85 | ``` 86 | 87 | ### (2nd) Re-ensemble 88 | ```shell script 89 | python ensemble_pseudo_label.py 90 | ``` -------------------------------------------------------------------------------- /utils/embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: embedding.py 10 | 11 | @time: 2020/4/26 13:34 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import numpy as np 18 | from gensim.models import Word2Vec 19 | from gensim.models import KeyedVectors 20 | from gensim.models import FastText 21 | 22 | 23 | def load_glove_format(filename): 24 | word_vectors = {} 25 | embeddings_dim = -1 26 | with open(filename, 'r') as f: 27 | for line in f: 28 | line = line.strip().split() 29 | 30 | try: 31 | word = line[0] 32 | word_vector = np.array([float(v) for v in line[1:]]) 33 | except ValueError: 34 | continue 35 | 36 | if embeddings_dim == -1: 37 | embeddings_dim = len(word_vector) 38 | 39 | if len(word_vector) != embeddings_dim: 40 | continue 41 | 42 | word_vectors[word] = word_vector 43 | 44 | assert all(len(vw) == embeddings_dim for vw in word_vectors.values()) 45 | 46 | return word_vectors, embeddings_dim 47 | 48 | 49 | def load_pre_trained(load_filename, vocabulary=None): 50 | word_vectors = {} 51 | try: 52 | model = KeyedVectors.load_word2vec_format(load_filename) 53 | weights = model.wv.syn0 54 | embedding_dim = weights.shape[1] 55 | for k, v in model.wv.vocab.items(): 56 | word_vectors[k] = weights[v.index, :] 57 | except ValueError: 58 | word_vectors, embedding_dim = load_glove_format(load_filename) 59 | 60 | if vocabulary is not None: 61 | emb = np.zeros(shape=(len(vocabulary) + 2, embedding_dim), dtype='float32') 62 | emb[1] = np.random.normal(0, 0.05, embedding_dim) 63 | 64 | nb_unk = 0 65 | for w, i in vocabulary.items(): 66 | if w not in word_vectors: 67 | nb_unk += 1 68 | emb[i, :] = np.random.normal(0, 0.05, embedding_dim) 69 | else: 70 | emb[i, :] = word_vectors[w] 71 | print('Logging Info - From {} Embedding matrix created : {}, unknown tokens: {}'.format(load_filename, emb.shape, 72 | nb_unk)) 73 | return emb 74 | else: 75 | print('Logging Info - Loading {} Embedding : {}'.format(load_filename, (len(word_vectors), embedding_dim))) 76 | return word_vectors 77 | 78 | 79 | def train_w2v(corpus, vocabulary, embedding_dim=300): 80 | model = Word2Vec(corpus, size=embedding_dim, min_count=1, window=5, sg=1, iter=10) 81 | weights = model.wv.syn0 82 | d = dict([(k, v.index) for k, v in model.wv.vocab.items()]) 83 | emb = np.zeros(shape=(len(vocabulary) + 2, embedding_dim), dtype='float32') # 0 for mask, 1 for unknown
token 84 | emb[1] = np.random.normal(0, 0.05, embedding_dim) 85 | 86 | nb_unk = 0 87 | for w, i in vocabulary.items(): 88 | if w not in d: 89 | nb_unk += 1 90 | emb[i, :] = np.random.normal(0, 0.05, embedding_dim) 91 | else: 92 | emb[i, :] = weights[d[w], :] 93 | print('Logging Info - Word2Vec Embedding matrix created: {}, unknown tokens: {}'.format(emb.shape, nb_unk)) 94 | return emb 95 | 96 | 97 | def train_fasttext(corpus, vocabulary, embedding_dim=300): 98 | model = FastText(size=embedding_dim, min_count=1, window=5, sg=1, word_ngrams=1) 99 | model.build_vocab(sentences=corpus) 100 | model.train(sentences=corpus, total_examples=len(corpus), epochs=10) 101 | 102 | emb = np.zeros(shape=(len(vocabulary) + 2, embedding_dim), dtype='float32') # 0 for mask, 1 for unknown token 103 | emb[1] = np.random.normal(0, 0.05, embedding_dim) 104 | 105 | for w, i in vocabulary.items(): 106 | emb[i, :] = model.wv[w] # note that oov words can still have word vectors 107 | 108 | print('Logging Info - Fasttext Embedding matrix created: {}'.format(emb.shape)) 109 | return emb 110 | -------------------------------------------------------------------------------- /utils/nn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: nn.py 10 | 11 | @time: 2020/4/26 15:46 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from tensorflow.keras import optimizers 18 | from transformers import AdamWeightDecay 19 | 20 | 21 | import tensorflow as tf 22 | 23 | 24 | class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): 25 | """Applies a warmup schedule on a given learning rate decay schedule.""" 26 | 27 | def __init__( 28 | self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None, 29 | ): 30 | super().__init__() 31 | self.initial_learning_rate = initial_learning_rate 32 | self.warmup_steps = warmup_steps 33 | self.power = power 34 | self.decay_schedule_fn = decay_schedule_fn 35 | self.name = name 36 | 37 | def __call__(self, step): 38 | with tf.name_scope(self.name or "WarmUp") as name: 39 | # Implements polynomial warmup. i.e., if global_step < warmup_steps, the 40 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 41 | global_step_float = tf.cast(step, tf.float32) 42 | warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) 43 | warmup_percent_done = global_step_float / warmup_steps_float 44 | warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power) 45 | return tf.cond( 46 | global_step_float < warmup_steps_float, 47 | lambda: warmup_learning_rate, 48 | lambda: self.decay_schedule_fn(step), 49 | name=name, 50 | ) 51 | 52 | def get_config(self): 53 | return { 54 | "initial_learning_rate": self.initial_learning_rate, 55 | "decay_schedule_fn": self.decay_schedule_fn, 56 | "warmup_steps": self.warmup_steps, 57 | "power": self.power, 58 | "name": self.name, 59 | } 60 | 61 | 62 | def create_optimizer(init_lr, num_train_steps, num_warmup_steps, end_lr=0.0, optimizer_type="adamw"): 63 | """Creates an optimizer with learning rate schedule.""" 64 | # Implements linear decay of the learning rate. 
65 | lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( 66 | initial_learning_rate=init_lr, decay_steps=num_train_steps, end_learning_rate=end_lr, 67 | ) 68 | if num_warmup_steps: 69 | lr_schedule = WarmUp( 70 | initial_learning_rate=init_lr, decay_schedule_fn=lr_schedule, warmup_steps=num_warmup_steps, 71 | ) 72 | 73 | optimizer = AdamWeightDecay( 74 | learning_rate=lr_schedule, 75 | weight_decay_rate=0.01, 76 | beta_1=0.9, 77 | beta_2=0.999, 78 | epsilon=1e-6, 79 | exclude_from_weight_decay=["layer_norm", "bias"], 80 | ) 81 | 82 | return optimizer 83 | 84 | 85 | def get_optimizer(op_type, learning_rate): 86 | if op_type == 'sgd': 87 | return optimizers.SGD(learning_rate) 88 | elif op_type == 'rmsprop': 89 | return optimizers.RMSprop(learning_rate) 90 | elif op_type == 'adagrad': 91 | return optimizers.Adagrad(learning_rate) 92 | elif op_type == 'adadelta': 93 | return optimizers.Adadelta(learning_rate) 94 | elif op_type == 'adam': 95 | return optimizers.Adam(learning_rate, clipnorm=5) 96 | elif op_type == 'adamw': 97 | return AdamWeightDecay(learning_rate=learning_rate, 98 | weight_decay_rate=0.01, 99 | beta_1=0.9, 100 | beta_2=0.999, 101 | epsilon=1e-6, 102 | exclude_from_weight_decay=["layer_norm", "bias"]) 103 | elif op_type == 'adamw_2': 104 | return create_optimizer(init_lr=learning_rate, num_train_steps=9000, num_warmup_steps=0) 105 | elif op_type == 'adamw_3': 106 | return create_optimizer(init_lr=learning_rate, num_train_steps=9000, num_warmup_steps=100) 107 | else: 108 | raise ValueError('Optimizer Not Understood: {}'.format(op_type)) 109 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: train.py 10 | 11 | @time: 2020/5/21 22:06 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from trainer import main 18 | 19 | if __name__ == '__main__': 20 | 21 | # train without pseudo labeling 22 | for n_hidden in range(1, 6): 23 | for cv_index in range(5): 24 | main(model_type='bert-base-uncased', 25 | input_type='name_desc', 26 | use_word_input=False, 27 | use_bert_input=True, 28 | bert_trainable=True, 29 | batch_size=32, 30 | predict_batch_size=32, 31 | use_pair_input=True, 32 | use_bert_type='hidden', 33 | n_last_hidden_layer=n_hidden, 34 | dense_after_bert=True, 35 | learning_rate=2e-5, 36 | use_multi_task=True, 37 | use_harl=False, 38 | use_mask_for_cate2=False, 39 | use_mask_for_cate3=True, 40 | cate3_mask_type='cate1', 41 | train_on_cv=True, 42 | cv_random_state=42, 43 | cv_fold=5, 44 | cv_index=cv_index, 45 | exchange_pair=True, 46 | use_pseudo_label=False, 47 | use_gpu_id=7) 48 | 49 | for n_hidden in range(1, 6): 50 | for cv_index in range(5): 51 | main(model_type='bert-base-uncased', 52 | input_type='name_desc', 53 | use_word_input=False, 54 | use_bert_input=True, 55 | bert_trainable=True, 56 | batch_size=32, 57 | predict_batch_size=32, 58 | use_pair_input=True, 59 | use_bert_type='hidden_pooler', 60 | n_last_hidden_layer=n_hidden, 61 | dense_after_bert=True, 62 | learning_rate=2e-5, 63 | use_multi_task=True, 64 | use_harl=False, 65 | use_mask_for_cate2=False, 66 | use_mask_for_cate3=True, 67 | cate3_mask_type='cate1', 68 | train_on_cv=True, 69 | cv_random_state=42, 70 | cv_fold=5, 71 | cv_index=cv_index, 72 | exchange_pair=True, 73 | use_pseudo_label=False, 74 | use_gpu_id=6) 75 | 76 | for use_bert_type in ['lstm', 'gru', 'lstm_gru', 
'gru_lstm', 'cnn', 'lstm_cnn']: 77 | for cv_index in range(5): 78 | main(model_type='bert-base-uncased', 79 | input_type='name_desc', 80 | use_word_input=False, 81 | use_bert_input=True, 82 | bert_trainable=True, 83 | batch_size=32, 84 | predict_batch_size=32, 85 | use_pair_input=True, 86 | use_bert_type=use_bert_type, 87 | n_last_hidden_layer=0, 88 | dense_after_bert=True, 89 | learning_rate=2e-5, 90 | use_multi_task=True, 91 | use_harl=False, 92 | use_mask_for_cate2=False, 93 | use_mask_for_cate3=True, 94 | cate3_mask_type='cate1', 95 | train_on_cv=True, 96 | cv_random_state=42, 97 | cv_fold=5, 98 | cv_index=cv_index, 99 | exchange_pair=True, 100 | use_pseudo_label=False, 101 | use_gpu_id=5) 102 | 103 | for cv_index in range(5): 104 | main(model_type='bert-base-uncased', 105 | input_type='name_desc', 106 | use_word_input=False, 107 | use_bert_input=True, 108 | bert_trainable=True, 109 | batch_size=32, 110 | predict_batch_size=32, 111 | use_pair_input=True, 112 | use_bert_type='hidden', 113 | n_last_hidden_layer=1, 114 | dense_after_bert=False, 115 | learning_rate=2e-5, 116 | use_multi_task=True, 117 | use_harl=True, 118 | use_mask_for_cate2=False, 119 | use_mask_for_cate3=True, 120 | cate3_mask_type='cate1', 121 | train_on_cv=True, 122 | cv_random_state=42, 123 | cv_fold=5, 124 | cv_index=cv_index, 125 | exchange_pair=True, 126 | use_pseudo_label=False, 127 | use_gpu_id=4) 128 | 129 | for cv_index in range(5): 130 | main(model_type='bert-base-uncased', 131 | input_type='name_desc', 132 | use_word_input=False, 133 | use_bert_input=True, 134 | bert_trainable=True, 135 | batch_size=32, 136 | predict_batch_size=32, 137 | use_pair_input=True, 138 | use_bert_type='pooler', 139 | n_last_hidden_layer=0, 140 | dense_after_bert=False, 141 | learning_rate=2e-5, 142 | use_multi_task=True, 143 | use_mask_for_cate2=False, 144 | use_mask_for_cate3=True, 145 | cate3_mask_type='cate1', 146 | train_on_cv=True, 147 | cv_random_state=42, 148 | cv_fold=5, 149 | cv_index=cv_index, 150 | exchange_pair=True, 151 | use_pseudo_label=False, 152 | use_gpu_id=3) 153 | -------------------------------------------------------------------------------- /train_pseudo_label.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: train.py 10 | 11 | @time: 2020/5/21 22:06 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from trainer import main 18 | 19 | if __name__ == '__main__': 20 | 21 | # with pseudo labeling 22 | for n_hidden in range(1, 6): 23 | for cv_index in range(5): 24 | main(model_type='bert-base-uncased', 25 | input_type='name_desc', 26 | use_word_input=False, 27 | use_bert_input=True, 28 | bert_trainable=True, 29 | batch_size=32, 30 | predict_batch_size=32, 31 | use_pair_input=True, 32 | use_bert_type='hidden', 33 | n_last_hidden_layer=n_hidden, 34 | dense_after_bert=True, 35 | learning_rate=2e-5, 36 | use_multi_task=True, 37 | use_harl=False, 38 | use_mask_for_cate2=False, 39 | use_mask_for_cate3=True, 40 | cate3_mask_type='cate1', 41 | train_on_cv=True, 42 | cv_random_state=42, 43 | cv_fold=5, 44 | cv_index=cv_index, 45 | exchange_pair=True, 46 | use_pseudo_label=True, 47 | pseudo_path='./submit/cross_validation_test_vote_ensemble.csv', 48 | pseudo_random_state=42, 49 | pseudo_rate=5, 50 | pseudo_index=cv_index, 51 | pseudo_name='cv_test_vote', 52 | use_gpu_id=7) 53 | 54 | for n_hidden in range(1, 6): 55 | for cv_index in range(5): 56 | 
main(model_type='bert-base-uncased', 57 | input_type='name_desc', 58 | use_word_input=False, 59 | use_bert_input=True, 60 | bert_trainable=True, 61 | batch_size=32, 62 | predict_batch_size=32, 63 | use_pair_input=True, 64 | use_bert_type='hidden_pooler', 65 | n_last_hidden_layer=n_hidden, 66 | dense_after_bert=True, 67 | learning_rate=2e-5, 68 | use_multi_task=True, 69 | use_harl=False, 70 | use_mask_for_cate2=False, 71 | use_mask_for_cate3=True, 72 | cate3_mask_type='cate1', 73 | train_on_cv=True, 74 | cv_random_state=42, 75 | cv_fold=5, 76 | cv_index=cv_index, 77 | exchange_pair=True, 78 | use_pseudo_label=True, 79 | pseudo_path='./submit/cross_validation_test_vote_ensemble.csv', 80 | pseudo_random_state=42, 81 | pseudo_rate=5, 82 | pseudo_index=cv_index, 83 | pseudo_name='cv_test_vote', 84 | use_gpu_id=6) 85 | 86 | for use_bert_type in ['lstm', 'gru', 'lstm_gru', 'gru_lstm', 'cnn', 'lstm_cnn']: 87 | for cv_index in range(5): 88 | main(model_type='bert-base-uncased', 89 | input_type='name_desc', 90 | use_word_input=False, 91 | use_bert_input=True, 92 | bert_trainable=True, 93 | batch_size=32, 94 | predict_batch_size=32, 95 | use_pair_input=True, 96 | use_bert_type=use_bert_type, 97 | n_last_hidden_layer=0, 98 | dense_after_bert=True, 99 | learning_rate=2e-5, 100 | use_multi_task=True, 101 | use_harl=False, 102 | use_mask_for_cate2=False, 103 | use_mask_for_cate3=True, 104 | cate3_mask_type='cate1', 105 | train_on_cv=True, 106 | cv_random_state=42, 107 | cv_fold=5, 108 | cv_index=cv_index, 109 | exchange_pair=True, 110 | use_pseudo_label=True, 111 | pseudo_path='./submit/cross_validation_test_vote_ensemble.csv', 112 | pseudo_random_state=42, 113 | pseudo_rate=5, 114 | pseudo_index=cv_index, 115 | pseudo_name='cv_test_vote', 116 | use_gpu_id=5) 117 | 118 | for cv_index in range(5): 119 | main(model_type='bert-base-uncased', 120 | input_type='name_desc', 121 | use_word_input=False, 122 | use_bert_input=True, 123 | bert_trainable=True, 124 | batch_size=32, 125 | predict_batch_size=32, 126 | use_pair_input=True, 127 | use_bert_type='hidden', 128 | n_last_hidden_layer=1, 129 | dense_after_bert=False, 130 | learning_rate=2e-5, 131 | use_multi_task=True, 132 | use_harl=True, 133 | use_mask_for_cate2=False, 134 | use_mask_for_cate3=True, 135 | cate3_mask_type='cate1', 136 | train_on_cv=True, 137 | cv_random_state=42, 138 | cv_fold=5, 139 | cv_index=cv_index, 140 | exchange_pair=True, 141 | use_pseudo_label=True, 142 | pseudo_path='./submit/cross_validation_test_vote_ensemble.csv', 143 | pseudo_random_state=42, 144 | pseudo_rate=5, 145 | pseudo_index=cv_index, 146 | pseudo_name='cv_test_vote', 147 | use_gpu_id=4) 148 | 149 | for cv_index in range(5): 150 | main(model_type='bert-base-uncased', 151 | input_type='name_desc', 152 | use_word_input=False, 153 | use_bert_input=True, 154 | bert_trainable=True, 155 | batch_size=32, 156 | predict_batch_size=32, 157 | use_pair_input=True, 158 | use_bert_type='pooler', 159 | n_last_hidden_layer=0, 160 | dense_after_bert=False, 161 | learning_rate=2e-5, 162 | use_multi_task=True, 163 | use_mask_for_cate2=False, 164 | use_mask_for_cate3=True, 165 | cate3_mask_type='cate1', 166 | train_on_cv=True, 167 | cv_random_state=42, 168 | cv_fold=5, 169 | cv_index=cv_index, 170 | exchange_pair=True, 171 | use_pseudo_label=True, 172 | pseudo_path='./submit/cross_validation_test_vote_ensemble.csv', 173 | pseudo_random_state=42, 174 | pseudo_rate=5, 175 | pseudo_index=cv_index, 176 | pseudo_name='cv_test_vote', 177 | use_gpu_id=4) 178 | 
-------------------------------------------------------------------------------- /callbacks/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: metric.py 10 | 11 | @time: 2020/5/21 7:33 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import numpy as np 18 | from tensorflow.keras.callbacks import Callback 19 | from sklearn.metrics import precision_recall_fscore_support 20 | 21 | from data_loader import MultiTaskClsDataGenerator 22 | from config import ModelConfig 23 | from utils import precision_recall_fscore 24 | 25 | 26 | class MultiTaskMetric(Callback): 27 | def __init__(self, valid_generator: MultiTaskClsDataGenerator, config: ModelConfig): 28 | self.valid_generator = valid_generator 29 | self.config = config 30 | super(MultiTaskMetric, self).__init__() 31 | 32 | def predict_multitask(self, x): 33 | pred_cate1, pred_cate2, pred_cate3 = self.model.predict(x) 34 | pred_cate1 = np.argmax(pred_cate1, axis=-1).tolist() 35 | pred_cate2 = np.argmax(pred_cate2, axis=-1).tolist() 36 | pred_cate3 = np.argmax(pred_cate3, axis=-1).tolist() 37 | return pred_cate1, pred_cate2, pred_cate3 38 | 39 | def predict_one_task(self, x): 40 | pred_all_cate = np.argmax(self.model.predict(x), axis=-1).tolist() 41 | pred_all_cate = map(self.config.idx2all_cate.get, pred_all_cate) 42 | pred_all_cate_split = map(lambda c: c.split('|'), pred_all_cate) 43 | pred_cate1, pred_cate2, pred_cate3 = zip(*list(pred_all_cate_split)) 44 | 45 | pred_cate1 = list(map(self.config.cate1_vocab.get, pred_cate1)) 46 | pred_cate2 = list(map(self.config.cate2_vocab.get, pred_cate2)) 47 | pred_cate3 = list(map(self.config.cate3_vocab.get, pred_cate3)) 48 | return pred_cate1, pred_cate2, pred_cate3 49 | 50 | def on_epoch_end(self, epoch, logs={}): 51 | pred_cate1_list, true_cate1_list = [], [] 52 | pred_cate2_list, true_cate2_list = [], [] 53 | pred_cate3_list, true_cate3_list = [], [] 54 | for x_valid, y_valid in self.valid_generator: 55 | if self.config.use_multi_task: 56 | pred_cate1, pred_cate2, pred_cate3 = self.predict_multitask(x_valid) 57 | else: 58 | pred_cate1, pred_cate2, pred_cate3 = self.predict_one_task(x_valid) 59 | pred_cate1_list.extend(pred_cate1) 60 | pred_cate2_list.extend(pred_cate2) 61 | pred_cate3_list.extend(pred_cate3) 62 | 63 | true_cate1_list.extend(np.argmax(y_valid[0], axis=-1).tolist()) 64 | true_cate2_list.extend(np.argmax(y_valid[1], axis=-1).tolist()) 65 | true_cate3_list.extend(np.argmax(y_valid[2], axis=-1).tolist()) 66 | 67 | print(f'Logging Info - Epoch: {epoch+1} evaluation:') 68 | eval_results = precision_recall_fscore(pred_cate1_list=pred_cate1_list, true_cate1_list=true_cate1_list, 69 | pred_cate2_list=pred_cate2_list, true_cate2_list=true_cate2_list, 70 | pred_cate3_list=pred_cate3_list, true_cate3_list=true_cate3_list) 71 | logs.update(eval_results) 72 | 73 | 74 | class MaskedMultiTaskMetric(Callback): 75 | def __init__(self, 76 | valid_generator: MultiTaskClsDataGenerator, 77 | cate1_model, 78 | cate2_model, 79 | cate3_model, 80 | config: ModelConfig): 81 | self.valid_generator = valid_generator 82 | self.cate1_model = cate1_model 83 | self.cate2_model = cate2_model 84 | self.cate3_model = cate3_model 85 | self.config = config 86 | assert self.config.use_mask_for_cate2 or self.config.use_mask_for_cate3 87 | if self.config.use_mask_for_cate2: 88 | assert self.valid_generator.cate1_to_cate2_matrix is not None 89 | if 
self.config.use_mask_for_cate3: 90 | assert self.valid_generator.cate_to_cate3_matrix is not None 91 | super(MaskedMultiTaskMetric, self).__init__() 92 | 93 | def predict_masked_multitask(self, x): 94 | pred_cate1 = self.cate1_model.predict(x) 95 | pred_cate1 = np.argmax(pred_cate1, axis=-1) 96 | 97 | inputs_for_cate2_model = x 98 | if self.config.use_mask_for_cate2: 99 | input_cate2_mask = self.valid_generator.cate1_to_cate2_matrix[pred_cate1] 100 | if isinstance(inputs_for_cate2_model, list): 101 | inputs_for_cate2_model = inputs_for_cate2_model + [input_cate2_mask] 102 | else: 103 | inputs_for_cate2_model = [inputs_for_cate2_model, input_cate2_mask] 104 | pred_cate2 = self.cate2_model.predict(inputs_for_cate2_model) 105 | pred_cate2 = np.argmax(pred_cate2, axis=-1) 106 | 107 | inputs_for_cate3_model = x 108 | if self.config.use_mask_for_cate3: 109 | if self.config.cate3_mask_type == 'cate1': 110 | input_cate3_mask = self.valid_generator.cate_to_cate3_matrix[pred_cate1] 111 | elif self.config.cate3_mask_type == 'cate2': 112 | input_cate3_mask = self.valid_generator.cate_to_cate3_matrix[pred_cate2] 113 | else: 114 | raise ValueError(f'`cate3_mask_type` not understood: {self.config.cate3_mask_type}') 115 | if isinstance(inputs_for_cate3_model, list): 116 | inputs_for_cate3_model = inputs_for_cate3_model + [input_cate3_mask] 117 | else: 118 | inputs_for_cate3_model = [inputs_for_cate3_model, input_cate3_mask] 119 | pred_cate3 = self.cate3_model.predict(inputs_for_cate3_model) 120 | pred_cate3 = np.argmax(pred_cate3, axis=-1) 121 | 122 | return pred_cate1.tolist(), pred_cate2.tolist(), pred_cate3.tolist() 123 | 124 | def on_epoch_end(self, epoch, logs={}): 125 | pred_cate1_list, true_cate1_list = [], [] 126 | pred_cate2_list, true_cate2_list = [], [] 127 | pred_cate3_list, true_cate3_list = [], [] 128 | for x_valid, y_valid in self.valid_generator: 129 | pred_cate1, pred_cate2, pred_cate3 = self.predict_masked_multitask(x_valid) 130 | 131 | pred_cate1_list.extend(pred_cate1) 132 | pred_cate2_list.extend(pred_cate2) 133 | pred_cate3_list.extend(pred_cate3) 134 | 135 | true_cate1_list.extend(np.argmax(y_valid[0], axis=-1).tolist()) 136 | true_cate2_list.extend(np.argmax(y_valid[1], axis=-1).tolist()) 137 | true_cate3_list.extend(np.argmax(y_valid[2], axis=-1).tolist()) 138 | 139 | print(f'Logging Info - Epoch: {epoch + 1} evaluation:') 140 | eval_results = precision_recall_fscore(pred_cate1_list=pred_cate1_list, true_cate1_list=true_cate1_list, 141 | pred_cate2_list=pred_cate2_list, true_cate2_list=true_cate2_list, 142 | pred_cate3_list=pred_cate3_list, true_cate3_list=true_cate3_list) 143 | logs.update(eval_results) 144 | -------------------------------------------------------------------------------- /utils/transformers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: transformers.py 10 | 11 | @time: 2020/5/23 20:34 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from transformers import BertTokenizerFast, RobertaTokenizerFast, XLNetTokenizer, AlbertTokenizer, GPT2TokenizerFast, \ 18 | TransfoXLTokenizerFast, DistilBertTokenizerFast 19 | from transformers import BertConfig, RobertaConfig, XLNetConfig, AlbertConfig, GPT2Config, TransfoXLConfig, \ 20 | DistilBertConfig 21 | from transformers import TFBertModel, TFRobertaModel, TFXLNetModel, TFAlbertModel, TFGPT2Model, TFTransfoXLModel, \ 22 | TFDistilBertModel 23 | 24 | from config import 
BERT_VOCAB_FILE, BERT_MERGE_FILE, BERT_CONFIG_FILE, BERT_MODEL_FILE 25 | 26 | 27 | def get_bert_tokenizer(bert_model_type): 28 | if bert_model_type in ['bert-base-uncased', 'prod-bert-base-uncased', 'bert-base-cased', 'bert-large-uncased', 29 | 'tune_bert-base-uncased_nsp', 'bert-large-uncased-whole-word-masking', 30 | 'bert-large-uncased-whole-word-masking-finetuned-squad']: 31 | if '-cased' in bert_model_type: 32 | do_lower_case = False 33 | else: 34 | do_lower_case = True # default 35 | return BertTokenizerFast(vocab_file=BERT_VOCAB_FILE[bert_model_type], do_lower_case=do_lower_case) 36 | elif bert_model_type in ['roberta-base', 'prod-roberta-base-cased', 'roberta-large', 'roberta-large-mnli', 37 | 'distilroberta-base']: 38 | return RobertaTokenizerFast(vocab_file=BERT_VOCAB_FILE[bert_model_type], 39 | merges_file=BERT_MERGE_FILE[bert_model_type], 40 | add_prefix_space=True) 41 | elif bert_model_type in ['xlnet-base-cased']: 42 | if '-uncased' in bert_model_type: 43 | do_lower_case = True 44 | else: 45 | do_lower_case = False # default 46 | return XLNetTokenizer(vocab_file=BERT_VOCAB_FILE[bert_model_type], do_lower_case=do_lower_case) 47 | elif bert_model_type in ['albert-base-v1', 'albert-large-v1', 'albert-xlarge-v1', 'albert-xxlarge-v1']: 48 | return AlbertTokenizer(vocab_file=BERT_VOCAB_FILE[bert_model_type]) 49 | elif bert_model_type in ['gpt2', 'gpt2-medium']: 50 | tokenizer = GPT2TokenizerFast(vocab_file=BERT_VOCAB_FILE[bert_model_type], 51 | merges_file=BERT_MERGE_FILE[bert_model_type], 52 | add_prefix_space=True) 53 | # https://github.com/huggingface/transformers/issues/3859 54 | tokenizer.pad_token = tokenizer.eos_token 55 | return tokenizer 56 | elif bert_model_type in ['transfo-xl']: 57 | return TransfoXLTokenizerFast(vocab_file=BERT_VOCAB_FILE[bert_model_type]) 58 | elif bert_model_type in ['distilbert-base-uncased', 'distilbert-base-uncased-distilled-squad']: 59 | if '-cased' in bert_model_type: 60 | do_lower_case = False 61 | else: 62 | do_lower_case = True # default 63 | return DistilBertTokenizerFast(vocab_file=BERT_VOCAB_FILE[bert_model_type], do_lower_case=do_lower_case) 64 | else: 65 | raise ValueError(f'`bert_model_type` not understood: {bert_model_type}') 66 | 67 | 68 | def get_bert_config(bert_model_type, output_hidden_states=False): 69 | if bert_model_type in ['bert-base-uncased', 'prod-bert-base-uncased', 'bert-base-cased', 'bert-large-uncased', 70 | 'tune_bert-base-uncased_nsp', 'bert-large-uncased-whole-word-masking', 71 | 'bert-large-uncased-whole-word-masking-finetuned-squad']: 72 | bert_config = BertConfig.from_pretrained(BERT_CONFIG_FILE[bert_model_type]) 73 | elif bert_model_type in ['roberta-base', 'prod-roberta-base-cased', 'roberta-large', 'roberta-large-mnli', 74 | 'distilroberta-base']: 75 | bert_config = RobertaConfig.from_pretrained(BERT_CONFIG_FILE[bert_model_type]) 76 | elif bert_model_type in ['xlnet-base-cased']: 77 | bert_config = XLNetConfig.from_pretrained(BERT_CONFIG_FILE[bert_model_type]) 78 | elif bert_model_type in ['albert-base-v1', 'albert-large-v1', 'albert-xlarge-v1', 'albert-xxlarge-v1']: 79 | bert_config = AlbertConfig.from_pretrained(BERT_CONFIG_FILE[bert_model_type]) 80 | elif bert_model_type in ['gpt2', 'gpt2-medium']: 81 | bert_config = GPT2Config.from_pretrained(BERT_CONFIG_FILE[bert_model_type]) 82 | elif bert_model_type in ['transfo-xl']: 83 | bert_config = TransfoXLConfig.from_pretrained(BERT_CONFIG_FILE[bert_model_type]) 84 | elif bert_model_type in ['distilbert-base-uncased', 'distilbert-base-uncased-distilled-squad']: 85 
| bert_config = DistilBertConfig.from_pretrained(BERT_CONFIG_FILE[bert_model_type]) 86 | else: 87 | raise ValueError(f'`bert_model_type` not understood: {bert_model_type}') 88 | 89 | bert_config.output_hidden_states = output_hidden_states 90 | return bert_config 91 | 92 | 93 | def get_transformer(bert_model_type, output_hidden_states=False): 94 | config = get_bert_config(bert_model_type, output_hidden_states) 95 | if bert_model_type in ['bert-base-uncased', 'bert-base-cased', 'bert-large-uncased', 96 | 'bert-large-uncased-whole-word-masking', 97 | 'bert-large-uncased-whole-word-masking-finetuned-squad']: 98 | return TFBertModel.from_pretrained(BERT_MODEL_FILE[bert_model_type], config=config) 99 | elif bert_model_type in ['prod-bert-base-uncased', 'tune_bert-base-uncased_nsp']: 100 | return TFBertModel.from_pretrained(BERT_MODEL_FILE[bert_model_type], config=config, from_pt=True) 101 | elif bert_model_type in ['roberta-base', 'roberta-large', 'roberta-large-mnli', 'distilroberta-base']: 102 | return TFRobertaModel.from_pretrained(BERT_MODEL_FILE[bert_model_type], config=config) 103 | elif bert_model_type in ['prod-roberta-base-cased']: 104 | return TFRobertaModel.from_pretrained(BERT_MODEL_FILE[bert_model_type], config=config, from_pt=True) 105 | elif bert_model_type in ['xlnet-base-cased']: 106 | return TFXLNetModel.from_pretrained(BERT_MODEL_FILE[bert_model_type], config=config) 107 | elif bert_model_type in ['albert-base-v1', 'albert-large-v1', 'albert-xlarge-v1', 'albert-xxlarge-v1']: 108 | return TFAlbertModel.from_pretrained(BERT_MODEL_FILE[bert_model_type], config=config) 109 | elif bert_model_type in ['gpt2', 'gpt2-medium']: 110 | return TFGPT2Model.from_pretrained(BERT_MODEL_FILE[bert_model_type], config=config) 111 | elif bert_model_type in ['transfo-xl']: 112 | return TFTransfoXLModel.from_pretrained(BERT_MODEL_FILE[bert_model_type], config=config) 113 | elif bert_model_type in ['distilbert-base-uncased', 'distilbert-base-uncased-distilled-squad']: 114 | return TFDistilBertModel.from_pretrained(BERT_MODEL_FILE[bert_model_type], config=config) 115 | else: 116 | raise ValueError(f'`bert_model_type` not understood: {bert_model_type}') 117 | -------------------------------------------------------------------------------- /two_level_ensembler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: ensemble.py 10 | 11 | @time: 2020/6/8 22:01 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import os 18 | import numpy as np 19 | from collections import Counter 20 | 21 | from config import PREDICT_DIR, PROCESSED_DATA_DIR, TEST_DATA_FILENAME, IDX2TOKEN_TEMPLATE 22 | from utils import pickle_load, format_filename, submit_result 23 | 24 | 25 | def load_pred_prob_list(model_name_list, data_type): 26 | pred_prob_cate1_list = [] 27 | pred_prob_cate2_list = [] 28 | pred_prob_cate3_list = [] 29 | for model_name in model_name_list: 30 | pred_probs = pickle_load(os.path.join(PREDICT_DIR, f'{model_name}_{data_type}_prob.pkl')) 31 | pred_prob_cate1_list.append(pred_probs[0]) 32 | pred_prob_cate2_list.append(pred_probs[1]) 33 | pred_prob_cate3_list.append(pred_probs[2]) 34 | return pred_prob_cate1_list, pred_prob_cate2_list, pred_prob_cate3_list 35 | 36 | 37 | def load_idx2cate(): 38 | idx2cate1 = pickle_load(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate1')) 39 | idx2cate2 = pickle_load(format_filename(PROCESSED_DATA_DIR, 
IDX2TOKEN_TEMPLATE, level='cate2')) 40 | idx2cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate3')) 41 | return idx2cate1, idx2cate2, idx2cate3 42 | 43 | 44 | def avg_ensemble(pred_prob_list, return_prob=False, weights=None): 45 | model_weights = weights or [1.] * len(pred_prob_list) 46 | assert np.sum(model_weights) == len(pred_prob_list) 47 | weighted_pred_prob_list = [pred_prob * weight for weight, pred_prob in zip(model_weights, pred_prob_list)] 48 | 49 | avg_pred_prob = np.mean(np.stack(weighted_pred_prob_list, axis=2), axis=2) 50 | if return_prob: 51 | return np.argmax(avg_pred_prob, axis=1).tolist(), avg_pred_prob 52 | else: 53 | return np.argmax(avg_pred_prob, axis=1).tolist() 54 | 55 | 56 | def vote_ensemble(pred_prob_list, weights=None): 57 | model_weights = weights or [1] * len(pred_prob_list) 58 | pred_label_list = [np.argmax(pred_prob, axis=1) for pred_prob in pred_prob_list] 59 | ensemble_pred_label = [] 60 | for sample_pred in zip(*pred_label_list): 61 | label_counter = Counter() 62 | for idx, pred in enumerate(sample_pred): 63 | label_counter[pred] += 1 * model_weights[idx] 64 | majority_label, majority_count = label_counter.most_common(1)[0] 65 | ensemble_pred_label.append(majority_label) 66 | return ensemble_pred_label 67 | 68 | 69 | def voting_of_averaging(prefix_model_name_list, 70 | submit_file_prefix, 71 | cv_random_state=42, 72 | cv_fold=5, 73 | use_ex_pair=False, 74 | ex_threshold=0.1, 75 | use_pseudo=False, 76 | pseudo_random_state=42, 77 | pseudo_rate=5): 78 | use_pseudo_str = 'pseudo_' if use_pseudo else '' 79 | use_ex_pair_str = 'ex_pair_' if use_ex_pair else '' 80 | cv_pred_prob_cate1_list, cv_pred_prob_cate2_list, cv_pred_prob_cate3_list = [], [], [] 81 | test_data = pickle_load(format_filename(PROCESSED_DATA_DIR, TEST_DATA_FILENAME)) 82 | idx2cate1, idx2cate2, idx2cate3 = load_idx2cate() 83 | for model_name in prefix_model_name_list: 84 | ex_pair_str = f'_ex_pair_{ex_threshold}' if use_ex_pair else '' 85 | cv_model_name_list = [f'{model_name}_{cv_random_state}_{cv_fold}_{fold}{ex_pair_str}' 86 | for fold in range(cv_fold)] 87 | if use_pseudo: 88 | cv_model_name_list = [f'{cv_model_name}_cv_test_vote_pseudo' \ 89 | f'_{pseudo_random_state}_{pseudo_rate}_{fold}' 90 | for fold, cv_model_name in enumerate(cv_model_name_list)] 91 | pred_prob_cate1_list, pred_prob_cate2_list, pred_prob_cate3_list = load_pred_prob_list( 92 | cv_model_name_list, data_type='test') 93 | 94 | print(f'Logging Info - average ensemble for {model_name}...') 95 | cv_pred_cate1, cv_pred_prob_cate1 = avg_ensemble(pred_prob_cate1_list, return_prob=True) 96 | cv_pred_cate2, cv_pred_prob_cate2 = avg_ensemble(pred_prob_cate2_list, return_prob=True) 97 | cv_pred_cate3, cv_pred_prob_cate3 = avg_ensemble(pred_prob_cate3_list, return_prob=True) 98 | submit_result(test_data=test_data, 99 | pred_cate1_list=cv_pred_cate1, 100 | pred_cate2_list=cv_pred_cate2, 101 | pred_cate3_list=cv_pred_cate3, 102 | idx2cate1=idx2cate1, 103 | idx2cate2=idx2cate2, 104 | idx2cate3=idx2cate3, 105 | submit_file=f'{use_ex_pair_str}{use_pseudo_str}{submit_file_prefix}_{model_name}_test_avg_ensemble.csv', 106 | submit_with_text=True) 107 | cv_pred_prob_cate1_list.append(cv_pred_prob_cate1) 108 | cv_pred_prob_cate2_list.append(cv_pred_prob_cate2) 109 | cv_pred_prob_cate3_list.append(cv_pred_prob_cate3) 110 | 111 | print(f'Logging Info - voting ensemble for {model_name}...') 112 | cv_pred_cate1 = vote_ensemble(pred_prob_cate1_list) 113 | cv_pred_cate2 = vote_ensemble(pred_prob_cate2_list) 114 
| cv_pred_cate3 = vote_ensemble(pred_prob_cate3_list) 115 | submit_result(test_data=test_data, 116 | pred_cate1_list=cv_pred_cate1, 117 | pred_cate2_list=cv_pred_cate2, 118 | pred_cate3_list=cv_pred_cate3, 119 | idx2cate1=idx2cate1, 120 | idx2cate2=idx2cate2, 121 | idx2cate3=idx2cate3, 122 | submit_file=f'{use_ex_pair_str}{use_pseudo_str}{submit_file_prefix}_{model_name}_test_vote_ensemble.csv', 123 | submit_with_text=True) 124 | 125 | print(f'Logging Info - average ensemble for all cross validation model...') 126 | cv_pred_cate1 = avg_ensemble(cv_pred_prob_cate1_list) 127 | cv_pred_cate2 = avg_ensemble(cv_pred_prob_cate2_list) 128 | cv_pred_cate3 = avg_ensemble(cv_pred_prob_cate3_list) 129 | submit_result(test_data=test_data, 130 | pred_cate1_list=cv_pred_cate1, 131 | pred_cate2_list=cv_pred_cate2, 132 | pred_cate3_list=cv_pred_cate3, 133 | idx2cate1=idx2cate1, 134 | idx2cate2=idx2cate2, 135 | idx2cate3=idx2cate3, 136 | submit_file=f'{use_ex_pair_str}{use_pseudo_str}{submit_file_prefix}_test_avg_ensemble.csv', 137 | submit_with_text=True) 138 | 139 | print(f'Logging Info - voting ensemble for all cross validation model...') 140 | cv_pred_cate1 = vote_ensemble(cv_pred_prob_cate1_list) 141 | cv_pred_cate2 = vote_ensemble(cv_pred_prob_cate2_list) 142 | cv_pred_cate3 = vote_ensemble(cv_pred_prob_cate3_list) 143 | submit_result(test_data=test_data, 144 | pred_cate1_list=cv_pred_cate1, 145 | pred_cate2_list=cv_pred_cate2, 146 | pred_cate3_list=cv_pred_cate3, 147 | idx2cate1=idx2cate1, 148 | idx2cate2=idx2cate2, 149 | idx2cate3=idx2cate3, 150 | submit_file=f'{use_ex_pair_str}{use_pseudo_str}{submit_file_prefix}_test_vote_ensemble.csv', 151 | submit_with_text=True) 152 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: io.py 10 | 11 | @time: 2020/5/8 9:52 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import os 18 | import json 19 | import pickle 20 | 21 | import numpy as np 22 | 23 | from config import ModelConfig, PREDICT_DIR, SUBMIT_DIR, LOG_DIR 24 | 25 | 26 | def pickle_load(filename): 27 | try: 28 | with open(filename, 'rb') as f: 29 | obj = pickle.load(f) 30 | 31 | print('Logging Info - Loaded:', filename) 32 | except EOFError: 33 | print('Logging Error - Cannot load:', filename) 34 | obj = None 35 | 36 | return obj 37 | 38 | 39 | def pickle_dump(filename, obj): 40 | with open(filename, 'wb') as f: 41 | pickle.dump(obj, f) 42 | 43 | print('Logging Info - Saved:', filename) 44 | 45 | 46 | def write_log(filename, log, mode='w'): 47 | with open(filename, mode) as writer: 48 | writer.write('\n') 49 | json.dump(log, writer, indent=4, ensure_ascii=False) 50 | 51 | 52 | def writer_md(filename, config: ModelConfig, trainer_logger, mode='a'): 53 | with open(os.path.join(LOG_DIR, filename), mode) as writer: 54 | parameters = f"|{config.exp_name:^100}|{trainer_logger['epoch']:^2}" 55 | writer.write(parameters) 56 | performance = '|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|'.format( 57 | trainer_logger['eval_result']['cate1_p'], 58 | trainer_logger['eval_result']['cate1_r'], 59 | trainer_logger['eval_result']['cate1_f1'], 60 | trainer_logger['eval_result']['cate2_p'], 61 | trainer_logger['eval_result']['cate2_r'], 62 | trainer_logger['eval_result']['cate2_f1'], 63 | trainer_logger['eval_result']['cate3_p'], 64 | 
trainer_logger['eval_result']['cate3_r'], 65 | trainer_logger['eval_result']['cate3_f1'], 66 | trainer_logger['eval_result']['val_p'], 67 | trainer_logger['eval_result']['val_r'], 68 | trainer_logger['eval_result']['val_f1'] 69 | ) 70 | performance += '{:^8}|{:^20}|'.format( 71 | trainer_logger['train_time'], 72 | trainer_logger['timestamp'] 73 | ) 74 | writer.write(performance) 75 | writer.write('\n') 76 | 77 | if 'swa_result' in trainer_logger: 78 | exp_name = config.exp_name + '_swa' 79 | parameters = f"|{exp_name:^100}|{trainer_logger['epoch']:^2}" 80 | writer.write(parameters) 81 | performance = '|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|{:.4f}|'.format( 82 | trainer_logger['swa_result']['cate1_p'], 83 | trainer_logger['swa_result']['cate1_r'], 84 | trainer_logger['swa_result']['cate1_f1'], 85 | trainer_logger['swa_result']['cate2_p'], 86 | trainer_logger['swa_result']['cate2_r'], 87 | trainer_logger['swa_result']['cate2_f1'], 88 | trainer_logger['swa_result']['cate3_p'], 89 | trainer_logger['swa_result']['cate3_r'], 90 | trainer_logger['swa_result']['cate3_f1'], 91 | trainer_logger['swa_result']['val_p'], 92 | trainer_logger['swa_result']['val_r'], 93 | trainer_logger['swa_result']['val_f1'] 94 | ) 95 | performance += '{:^8}|{:^20}|'.format( 96 | trainer_logger['train_time'], 97 | trainer_logger['timestamp'] 98 | ) 99 | 100 | writer.write(performance) 101 | writer.write('\n') 102 | 103 | 104 | def ensure_dir(path): 105 | if not os.path.exists(path): 106 | os.makedirs(path) 107 | 108 | 109 | def format_filename(_dir, filename_template, **kwargs): 110 | """Obtain the filename of data base on the provided template and parameters""" 111 | filename = os.path.join(_dir, filename_template.format(**kwargs)) 112 | return filename 113 | 114 | 115 | def submit_result(test_data, pred_cate1_list, pred_cate2_list, pred_cate3_list, idx2cate1, idx2cate2, idx2cate3, 116 | submit_file, submit_with_text=False): 117 | test_id_list = test_data['id'] 118 | assert len(test_id_list) == len(pred_cate1_list) == len(pred_cate2_list) == len(pred_cate3_list) 119 | with open(os.path.join(SUBMIT_DIR, submit_file), 'w', encoding='utf8') as writer: 120 | for idx, (test_id, pred_cate1_id, pred_cate2_id, pred_cate3_id) in enumerate(zip(test_id_list, 121 | pred_cate1_list, 122 | pred_cate2_list, 123 | pred_cate3_list)): 124 | 125 | pred_cate1 = idx2cate1[pred_cate1_id] 126 | pred_cate2 = idx2cate2[pred_cate2_id] 127 | pred_cate3 = idx2cate3[pred_cate3_id] 128 | if submit_with_text: 129 | test_name = test_data['name'][idx] 130 | test_desc = test_data['desc'][idx] 131 | writer.write(f'{test_id}##{test_name}##{test_desc}##{pred_cate1}##{pred_cate2}##{pred_cate3}\n') 132 | else: 133 | writer.write(f'{test_id},{pred_cate1},{pred_cate2},{pred_cate3}\n') 134 | 135 | 136 | def save_prob_to_file(pred_prob_cate1_list, pred_prob_cate2_list, pred_prob_cate3_list, pred_prob_all_cate_list, 137 | prob_file, use_multi_task): 138 | if not use_multi_task: 139 | pickle_dump(os.path.join(prob_file), np.vstack(pred_prob_all_cate_list)) 140 | else: 141 | pickle_dump(os.path.join(PREDICT_DIR, prob_file), 142 | [np.vstack(pred_prob_cate1_list), 143 | np.vstack(pred_prob_cate2_list), 144 | np.vstack(pred_prob_cate3_list)]) 145 | 146 | 147 | def save_diff_to_file(data, pred_cate1_list, pred_cate2_list, pred_cate3_list, true_cate1_list, true_cate2_list, 148 | true_cate3_list, cate1_to_cate2, cate2_to_cate3, cate1_to_cate3, idx2cate1, idx2cate2, idx2cate3, 149 | cate1_count_dict, cate2_count_dict, 
cate3_count_dict, 150 | diff_file): 151 | with open(os.path.join(LOG_DIR, diff_file), 'w') as writer: 152 | for i, data_id in enumerate(data['id']): 153 | t_cate1, p_cate1 = true_cate1_list[i], pred_cate1_list[i] 154 | t_cate2, p_cate2 = true_cate2_list[i], pred_cate2_list[i] 155 | t_cate3, p_cate3 = true_cate3_list[i], pred_cate3_list[i] 156 | 157 | cate1_diff = t_cate1 != p_cate1 158 | cate2_diff = t_cate2 != p_cate2 159 | cate3_diff = t_cate3 != p_cate3 160 | 161 | cate2_in_cate1 = p_cate2 in cate1_to_cate2[p_cate1] 162 | cate3_in_cate2 = p_cate3 in cate2_to_cate3[p_cate2] 163 | cate3_in_cate1 = p_cate3 in cate1_to_cate3[p_cate1] 164 | 165 | cate1_count = cate1_count_dict[idx2cate1[t_cate1]] 166 | cate2_count = cate2_count_dict[idx2cate2[t_cate2]] 167 | cate3_count = cate3_count_dict[idx2cate3[t_cate3]] 168 | 169 | if cate1_diff or cate2_diff or cate3_diff: 170 | writer.write(f'{data_id}\t') 171 | writer.write(f'{not cate1_diff}\t{not cate2_diff}\t{not cate3_diff}\t') 172 | writer.write(f'{cate2_in_cate1}\t{cate3_in_cate2}\t{cate3_in_cate1}\t') 173 | writer.write(f'{cate1_count:.6f}\t{cate2_count:.6f}\t{cate3_count:.6f}\t') 174 | writer.write(f'{idx2cate1[t_cate1]}|{idx2cate1[p_cate1]}\t') 175 | writer.write(f'{idx2cate2[t_cate2]}|{idx2cate2[p_cate2]}\t') 176 | writer.write(f'{idx2cate3[t_cate3]}|' 177 | f'{idx2cate3[p_cate3]}') 178 | writer.write('\n') 179 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: preprocess.py 10 | 11 | @time: 2020/5/19 21:55 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import json 18 | 19 | import numpy as np 20 | from sklearn.model_selection import StratifiedKFold, KFold 21 | 22 | from config import PROCESSED_DATA_DIR, LOG_DIR, MODEL_SAVED_DIR, SUBMIT_DIR, RAW_TRAIN_FILENAME, RAW_DEV_FILENAME, \ 23 | TRAIN_DATA_FILENAME, DEV_DATA_FILENAME, VOCABULARY_TEMPLATE, IDX2TOKEN_TEMPLATE, CATE1_TO_CATE2_DICT, \ 24 | CATE1_TO_CATE3_DICT, CATE2_TO_CATE3_DICT, RAW_TEST_FILENAME, TEST_DATA_FILENAME, TRAIN_CV_DATA_TEMPLATE, \ 25 | DEV_CV_DATA_TEMPLATE, PREDICT_DIR, CATE1_COUNT_DICT, CATE2_COUNT_DICT, CATE3_COUNT_DICT 26 | from utils import ensure_dir, format_filename, pickle_dump 27 | 28 | 29 | def load_raw_data(file_path, data_type='train'): 30 | data = { 31 | 'id': [], 32 | 'name': [], 33 | 'desc': [], 34 | 'cate1': [], 35 | 'cate2': [], 36 | 'cate3': [] 37 | } 38 | 39 | cate1_to_cate2 = dict() 40 | cate2_to_cate3 = dict() 41 | cate1_to_cate3 = dict() 42 | 43 | with open(file_path, 'r') as reader: 44 | for line in reader: 45 | product_info = json.loads(line.strip()) 46 | data['id'].append(product_info['ID']) 47 | data['name'].append(product_info['Name']) 48 | data['desc'].append(product_info['Description']) 49 | 50 | if data_type != 'test': 51 | cate1 = product_info['lvl1'] 52 | cate2 = product_info['lvl2'] 53 | cate3 = product_info['lvl3'] 54 | data['cate1'].append(cate1) 55 | data['cate2'].append(cate2) 56 | data['cate3'].append(cate3) 57 | 58 | if cate1 not in cate1_to_cate2: 59 | cate1_to_cate2[cate1] = set() 60 | cate1_to_cate2[cate1].add(cate2) 61 | 62 | if cate2 not in cate2_to_cate3: 63 | cate2_to_cate3[cate2] = set() 64 | cate2_to_cate3[cate2].add(cate3) 65 | 66 | if cate1 not in cate1_to_cate3: 67 | cate1_to_cate3[cate1] = set() 68 | cate1_to_cate3[cate1].add(cate3) 69 | 70 | print('Logging Info - Data Size:', 
len(data['id'])) 71 | if data_type == 'train': 72 | assert len(data['id']) == len(data['name']) == len(data['desc']) == len(data['cate1']) == len( 73 | data['cate2']) == len(data['cate3']) 74 | return data, cate1_to_cate2, cate2_to_cate3, cate1_to_cate3 75 | else: 76 | return data 77 | 78 | 79 | def load_cate_count(data, cate_key): 80 | cate_count = {} 81 | for cate in data[cate_key]: 82 | if cate not in cate_count: 83 | cate_count[cate] = 0 84 | cate_count[cate] += 1 85 | for cate in cate_count: 86 | cate_count[cate] /= len(data[cate_key]) 87 | return cate_count 88 | 89 | 90 | def load_vocab_and_corpus(data, cut_func, min_count=1): 91 | print('Logging Info: Constructing vocabulary and corpus...') 92 | tokens = dict() 93 | corpus = [] 94 | for name, desc in zip(data['name'], data['desc']): 95 | if desc is None or desc == '': 96 | text = name 97 | else: 98 | text = f'{name} {desc}' 99 | text_cut = cut_func(text) 100 | for token in text_cut: 101 | tokens[token] = tokens.get(token, 0) + 1 102 | corpus.append(text_cut) 103 | tokens = [token for token, token_count in tokens.items() if token_count >= min_count] 104 | idx2token = {idx + 2: token for idx, token in enumerate(tokens)} # 0: mask, 1: padding 105 | token2idx = {token: idx for idx, token in idx2token.items()} 106 | print(f'Logging Info - Token Vocabulary: {len(token2idx)}') 107 | return token2idx, idx2token, corpus 108 | 109 | 110 | def load_label_vocab(data): 111 | cate1_vocab = {} 112 | cate2_vocab = {} 113 | cate3_vocab = {} 114 | all_cate_vocab = {} 115 | for cate1, cate2, cate3 in zip(data['cate1'], data['cate2'], data['cate3']): 116 | if cate1 not in cate1_vocab: 117 | cate1_vocab[cate1] = len(cate1_vocab) 118 | if cate2 not in cate2_vocab: 119 | cate2_vocab[cate2] = len(cate2_vocab) 120 | if cate3 not in cate3_vocab: 121 | cate3_vocab[cate3] = len(cate3_vocab) 122 | all_cate = f"{cate1}|{cate2}|{cate3}" 123 | if all_cate not in all_cate_vocab: 124 | all_cate_vocab[all_cate] = len(all_cate_vocab) 125 | 126 | print(f'Logging Info - Level 1 Category: {len(cate1_vocab)}') 127 | print(f'Logging Info - Level 2 Category: {len(cate2_vocab)}') 128 | print(f'Logging Info - Level 3 Category: {len(cate3_vocab)}') 129 | print(f'Logging Info - All Level Category: {len(all_cate_vocab)}') 130 | 131 | return cate1_vocab, cate2_vocab, cate3_vocab, all_cate_vocab 132 | 133 | 134 | def convert_to_id(cate1_to_cate2, cate1_vocab, cate2_vocab): 135 | cate1_to_cate2_id = dict() 136 | for cate1 in cate1_to_cate2: 137 | cate1_id = cate1_vocab[cate1] 138 | cate1_to_cate2_id[cate1_id] = set(map(cate2_vocab.get, cate1_to_cate2[cate1])) 139 | return cate1_to_cate2_id 140 | 141 | 142 | def cv_split(train_data, dev_data, cate3_vocab, fold=5, balanced=True, random_state=42): 143 | def indexing_data(data, indices): 144 | part_data = {} 145 | for k in data.keys(): 146 | part_data[k] = [data[k][i] for i in indices] 147 | return part_data 148 | 149 | all_data = {} 150 | for key in train_data.keys(): 151 | all_data[key] = train_data[key] + dev_data[key] 152 | 153 | # some category in validation set is not in cate3_vocab 154 | cate3_id_list = [cate3_vocab.get(cate3, 0) for cate3 in all_data['cate3']] 155 | index_range = np.arange(len(all_data['id'])) 156 | 157 | if balanced: 158 | kf = StratifiedKFold(n_splits=fold, shuffle=True, random_state=random_state) 159 | else: 160 | kf = KFold(n_splits=fold, shuffle=True, random_state=random_state) 161 | 162 | for idx, (train_index, dev_index) in enumerate(kf.split(index_range, cate3_id_list)): 163 | train_data_fold = 
indexing_data(all_data, train_index) 164 | dev_data_fold = indexing_data(all_data, dev_index) 165 | 166 | pickle_dump(format_filename(PROCESSED_DATA_DIR, TRAIN_CV_DATA_TEMPLATE, random=random_state, 167 | fold=fold, index=idx), train_data_fold) 168 | pickle_dump(format_filename(PROCESSED_DATA_DIR, DEV_CV_DATA_TEMPLATE, random=random_state, 169 | fold=fold, index=idx), dev_data_fold) 170 | 171 | 172 | if __name__ == '__main__': 173 | # create directory 174 | ensure_dir(PROCESSED_DATA_DIR) 175 | ensure_dir(LOG_DIR) 176 | ensure_dir(MODEL_SAVED_DIR) 177 | ensure_dir(SUBMIT_DIR) 178 | ensure_dir(PREDICT_DIR) 179 | 180 | # read dataset 181 | train_data, cate1_to_cate2, cate2_to_cate3, cate1_to_cate3 = load_raw_data(RAW_TRAIN_FILENAME, 182 | data_type='train') 183 | dev_data = load_raw_data(RAW_DEV_FILENAME, data_type='dev') 184 | test_data = load_raw_data(RAW_TEST_FILENAME, data_type='test') 185 | pickle_dump(format_filename(PROCESSED_DATA_DIR, TRAIN_DATA_FILENAME), train_data) 186 | pickle_dump(format_filename(PROCESSED_DATA_DIR, DEV_DATA_FILENAME), dev_data) 187 | pickle_dump(format_filename(PROCESSED_DATA_DIR, TEST_DATA_FILENAME), test_data) 188 | 189 | cate1_count = load_cate_count(train_data, 'cate1') 190 | cate2_count = load_cate_count(train_data, 'cate2') 191 | cate3_count = load_cate_count(train_data, 'cate3') 192 | pickle_dump(format_filename(PROCESSED_DATA_DIR, CATE1_COUNT_DICT), cate1_count) 193 | pickle_dump(format_filename(PROCESSED_DATA_DIR, CATE2_COUNT_DICT), cate2_count) 194 | pickle_dump(format_filename(PROCESSED_DATA_DIR, CATE3_COUNT_DICT), cate3_count) 195 | 196 | cate1_vocab, cate2_vocab, cate3_vocab, all_cate_vocab = load_label_vocab(train_data) 197 | idx2cate1 = dict((idx, cate) for cate, idx in cate1_vocab.items()) 198 | idx2cate2 = dict((idx, cate) for cate, idx in cate2_vocab.items()) 199 | idx2cate3 = dict((idx, cate) for cate, idx in cate3_vocab.items()) 200 | idx2all_cate = dict((idx, cate) for cate, idx in all_cate_vocab.items()) 201 | pickle_dump(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate1'), cate1_vocab) 202 | pickle_dump(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate1'), idx2cate1) 203 | pickle_dump(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate2'), cate2_vocab) 204 | pickle_dump(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate2'), idx2cate2) 205 | pickle_dump(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate3'), cate3_vocab) 206 | pickle_dump(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate3'), idx2cate3) 207 | pickle_dump(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='all_cate'), all_cate_vocab) 208 | pickle_dump(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='all_cate'), idx2all_cate) 209 | 210 | cate1_to_cate2_id = convert_to_id(cate1_to_cate2, cate1_vocab, cate2_vocab) 211 | cate2_to_cate3_id = convert_to_id(cate2_to_cate3, cate2_vocab, cate3_vocab) 212 | cate1_to_cate3_id = convert_to_id(cate1_to_cate3, cate1_vocab, cate3_vocab) 213 | pickle_dump(format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE2_DICT), cate1_to_cate2_id) 214 | pickle_dump(format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT), cate2_to_cate3_id) 215 | pickle_dump(format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT), cate1_to_cate3_id) 216 | 217 | cv_split(train_data, dev_data, cate3_vocab, fold=5, balanced=True, random_state=42) 218 | -------------------------------------------------------------------------------- 
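The preprocessing script above writes every artifact through format_filename, so downstream code only needs the same templates to read them back. A minimal, illustrative sketch (not part of the repository) of loading the label vocabularies and one cross-validation fold, assuming preprocess.py has been run with its defaults (fold=5, random_state=42):

# illustrative usage sketch -- assumes preprocess.py has already been run with its defaults
from config import PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, IDX2TOKEN_TEMPLATE, \
    TRAIN_CV_DATA_TEMPLATE, DEV_CV_DATA_TEMPLATE
from utils import format_filename, pickle_load

# label vocabularies ('{level}_vocab.pkl') and their inverse mappings ('idx2{level}.pkl')
cate1_vocab = pickle_load(format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate1'))
idx2cate1 = pickle_load(format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate1'))

# fold 0 of the 5-fold split: 'train_{random}_{fold}_{index}.pkl' / 'dev_{random}_{fold}_{index}.pkl'
train_fold = pickle_load(format_filename(PROCESSED_DATA_DIR, TRAIN_CV_DATA_TEMPLATE,
                                         random=42, fold=5, index=0))
dev_fold = pickle_load(format_filename(PROCESSED_DATA_DIR, DEV_CV_DATA_TEMPLATE,
                                       random=42, fold=5, index=0))
print('fold 0 sizes:', len(train_fold['id']), len(dev_fold['id']))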
/models/base_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: base_model.py 10 | 11 | @time: 2020/5/21 7:27 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import os 18 | import math 19 | 20 | from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint 21 | 22 | from callbacks import * 23 | 24 | 25 | class BaseModel(object): 26 | def __init__(self, config): 27 | self.config = config 28 | self.callbacks = [] 29 | self.model = None 30 | self.swa_model = None 31 | self.polyak_model, self.temp_polyak_model = None, None 32 | self.earlystop_callback = None 33 | 34 | def add_model_checkpoint(self): 35 | self.callbacks.append(ModelCheckpoint( 36 | filepath=os.path.join(self.config.checkpoint_dir, '{}.hdf5'.format(self.config.exp_name)), 37 | monitor=self.config.checkpoint_monitor, 38 | save_best_only=self.config.checkpoint_save_best_only, 39 | save_weights_only=self.config.checkpoint_save_weights_only, 40 | mode=self.config.checkpoint_save_weights_mode, 41 | verbose=self.config.checkpoint_verbose 42 | )) 43 | print('Logging Info - Callback Added: ModelCheckPoint...') 44 | 45 | def add_early_stopping(self): 46 | self.earlystop_callback = EarlyStopping( 47 | monitor=self.config.early_stopping_monitor, 48 | mode=self.config.early_stopping_mode, 49 | patience=self.config.early_stopping_patience, 50 | verbose=self.config.early_stopping_verbose 51 | ) 52 | self.callbacks.append(self.earlystop_callback) 53 | print('Logging Info - Callback Added: EarlyStopping...') 54 | 55 | def add_clr(self, kind, min_lr, max_lr, cycle_length): 56 | """ 57 | add cyclic learning rate schedule callback 58 | :param kind: add what kind of clr, 0: the original cyclic lr, 1: the one introduced in FGE, 2: the one 59 | introduced in swa 60 | """ 61 | if kind == 0: 62 | self.callbacks.append(CyclicLR(base_lr=min_lr, max_lr=max_lr, step_size=cycle_length/2, mode='triangular2', 63 | plot=True, save_plot_prefix=self.config.exp_name)) 64 | elif kind == 1: 65 | self.callbacks.append(CyclicLR_1(min_lr=min_lr, max_lr=max_lr, cycle_length=cycle_length, plot=True, 66 | save_plot_prefix=self.config.exp_name)) 67 | elif kind == 2: 68 | self.callbacks.append(CyclicLR_2(min_lr=min_lr, max_lr=max_lr, cycle_length=cycle_length, plot=True, 69 | save_plot_prefix=self.config.exp_name)) 70 | else: 71 | raise ValueError('param `kind` not understood : {}'.format(kind)) 72 | print('Logging Info - Callback Added: CLR_{}...'.format(kind)) 73 | 74 | def add_sgdr(self, min_lr, max_lr, cycle_length): 75 | self.callbacks.append(SGDR(min_lr=min_lr, max_lr=max_lr, cycle_length=cycle_length, 76 | save_plot_prefix=self.config.exp_name)) 77 | print('Logging Info - Callback Added: SGDR...') 78 | 79 | def add_swa(self, with_clr, min_lr=None, max_lr=None, cycle_length=None, swa_start=5): 80 | if with_clr: 81 | self.callbacks.append(SWAWithCLR(self.swa_model, self.config.checkpoint_dir, self.config.exp_name, 82 | min_lr=min_lr, max_lr=max_lr, cycle_length=cycle_length, 83 | swa_start=swa_start)) 84 | else: 85 | self.callbacks.append(SWA(self.swa_model, self.config.checkpoint_dir, self.config.exp_name, 86 | swa_start=swa_start)) 87 | print('Logging Info - Callback Added: SWA with {}...'.format('clr' if with_clr else 'constant lr')) 88 | 89 | def add_polyak(self, avg_type, polyak_start=5): 90 | self.callbacks.append(PolyakAverage(self.polyak_model, self.temp_polyak_model, self.config.checkpoint_dir, 91 
| self.config.exp_name, avg_type, self.config.early_stopping_patience, 92 | polyak_start)) 93 | print('Logging Info - Callback Added: Polyak Average Ensemble...') 94 | 95 | def add_hor(self, hor_start=5): 96 | self.callbacks.append(HorizontalEnsemble(self.config.checkpoint_dir, self.config.exp_name, hor_start)) 97 | print('Logging Info - Callback Added: Horizontal Ensemble...') 98 | 99 | def add_sse(self, max_lr, cycle_length, sse_start): 100 | self.callbacks.append(SnapshotEnsemble(self.config.checkpoint_dir, self.config.exp_name, 101 | max_lr=max_lr, cycle_length=cycle_length, snapshot_start=sse_start)) 102 | print('Logging Info - Callback Added: Snapshot Ensemble...') 103 | 104 | def add_fge(self, min_lr, max_lr, cycle_length, fge_start): 105 | self.callbacks.append(FGE(self.config.checkpoint_dir, self.config.exp_name, min_lr=min_lr, max_lr=max_lr, 106 | cycle_length=cycle_length, fge_start=fge_start)) 107 | print('Logging Info - Callback Added: Fast Geometric Ensemble...') 108 | 109 | def add_warmup(self, lr=5e-5, min_lr=1e-5): 110 | self.callbacks.append(WarmUp(lr, min_lr)) 111 | print('Logging Info - Callback Added: WarmUp....') 112 | 113 | def init_callbacks(self, data_size): 114 | cycle_length = 6 * math.floor(data_size / self.config.batch_size) 115 | 116 | if 'modelcheckpoint' in self.config.callbacks_to_add: 117 | self.add_model_checkpoint() 118 | if 'earlystopping' in self.config.callbacks_to_add: 119 | self.add_early_stopping() 120 | if 'clr' in self.config.callbacks_to_add: 121 | self.add_clr(kind=0, min_lr=self.config.min_lr, max_lr=self.config.max_lr, cycle_length=cycle_length) 122 | if 'sgdr' in self.config.callbacks_to_add: 123 | self.add_sgdr(min_lr=self.config.min_lr, max_lr=self.config.max_lr, cycle_length=cycle_length) 124 | if 'clr_1' in self.config.callbacks_to_add: 125 | self.add_clr(kind=1, min_lr=self.config.min_lr, max_lr=self.config.max_lr, cycle_length=cycle_length) 126 | if 'clr_2' in self.config.callbacks_to_add: 127 | self.add_clr(kind=2, min_lr=self.config.min_lr, max_lr=self.config.max_lr, cycle_length=cycle_length) 128 | if 'swa' in self.config.callbacks_to_add: 129 | self.add_swa(with_clr=False, swa_start=self.config.swa_start) 130 | if 'swa_clr' in self.config.callbacks_to_add: 131 | self.add_swa(with_clr=True, min_lr=self.config.min_lr, max_lr=self.config.max_lr, cycle_length=cycle_length, 132 | swa_start=self.config.swq_clr_start) 133 | if 'sse' in self.config.callbacks_to_add: 134 | self.add_sse(max_lr=self.config.max_lr, cycle_length=cycle_length, sse_start=self.config.sse_start) 135 | if 'fge' in self.config.callbacks_to_add: 136 | self.add_fge(min_lr=self.config.min_lr, max_lr=self.config.max_lr, cycle_length=cycle_length, 137 | fge_start=self.config.fge_start) 138 | if 'hor' in self.config.callbacks_to_add: 139 | self.add_hor(hor_start=self.config.swa_start) 140 | if 'polyak_avg' in self.config.callbacks_to_add: 141 | self.add_polyak('avg', self.config.polyak_start) 142 | if 'polyak_linear' in self.config.callbacks_to_add: 143 | self.add_polyak('linear', self.config.polyak_start) 144 | if 'polay_exp' in self.config.callbacks_to_add: 145 | self.add_polyak('exp', self.config.polyak_start) 146 | if 'warmup' in self.config.callbacks_to_add: 147 | self.add_warmup(lr=self.config.max_lr, min_lr=self.config.min_lr) 148 | 149 | def load_weights(self, filename): 150 | self.model.load_weights(filename) 151 | 152 | def load_model(self, filename): 153 | # we only save model's weight instead of the whole model 154 | self.model.load_weights(filename) 155 | 156 
| def load_best_model(self): 157 | print('Logging Info - Loading model checkpoint: %s.hdf5\n' % self.config.exp_name) 158 | self.load_model(os.path.join(self.config.checkpoint_dir, '{}.hdf5'.format(self.config.exp_name))) 159 | print('Logging Info - Model loaded') 160 | 161 | def load_swa_model(self, swa_type='swa'): 162 | print('Logging Info - Loading SWA model checkpoint: %s_%s.hdf5\n' % (self.config.exp_name, swa_type)) 163 | self.load_model(os.path.join(self.config.checkpoint_dir, '%s_%s.hdf5' % (self.config.exp_name, swa_type))) 164 | print('Logging Info - SWA Model loaded') 165 | 166 | def load_polyak_model(self, polyak_type='swa'): 167 | print('Logging Info - Loading Polyak Average model checkpoint: %s_polyak_%s.hdf5\n' % (self.config.exp_name, polyak_type)) 168 | self.load_model(os.path.join(self.config.checkpoint_dir, '%s_polyak_%s.hdf5' % (self.config.exp_name, polyak_type))) 169 | print('Logging Info - Polyak Model loaded') 170 | 171 | def add_metric_callback(self, valid_generator): 172 | raise NotImplementedError 173 | 174 | def train(self, train_generator, valid_generator): 175 | self.callbacks = [] 176 | self.add_metric_callback(valid_generator=valid_generator) 177 | self.init_callbacks(train_generator.data_size) 178 | 179 | print('Logging Info - Start training...') 180 | self.model.fit(x=train_generator, epochs=self.config.n_epoch, callbacks=self.callbacks) 181 | print('Logging Info - Training end...') 182 | 183 | def return_trained_epoch(self): 184 | # https://stackoverflow.com/questions/49852241/return-number-of-epochs-for-earlystopping-callback-in-keras 185 | if self.earlystop_callback: 186 | stopped_epoch = self.earlystop_callback.stopped_epoch + 1 187 | if stopped_epoch == 1: 188 | return self.config.n_epoch 189 | else: 190 | return stopped_epoch 191 | else: 192 | return self.config.n_epoch 193 | 194 | def summary(self): 195 | self.model.summary() 196 | -------------------------------------------------------------------------------- /layers/attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: attention.py 10 | 11 | @time: 2020/6/5 21:57 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import tensorflow as tf 18 | import tensorflow.keras.backend as K 19 | from tensorflow.keras import initializers 20 | from tensorflow.keras.initializers import zeros 21 | from tensorflow.keras.layers import Activation, Layer 22 | 23 | 24 | class HierarchicalAttentionRecurrent(Layer): 25 | """Hierarchical Attention-based Recurrent Layer""" 26 | def __init__(self, cate_hierarchy, cate_embed_dim=100, kernel_initializer=None, **kwargs): 27 | super(HierarchicalAttentionRecurrent, self).__init__(**kwargs) 28 | self.cate_hierarchy = cate_hierarchy 29 | self.n_hierarchy = len(self.cate_hierarchy) 30 | self.cate_embed_dim = cate_embed_dim 31 | self.kernel_initializer = initializers.get(kernel_initializer) 32 | 33 | def build(self, input_shape): 34 | assert len(input_shape) == 3 35 | self.category_embeddings = [] 36 | self.attend_weights = [] 37 | self.cls_dense_weights = [] 38 | self.cls_dense_biases = [] 39 | self.cls_pred_weights = [] 40 | self.cls_pred_biases = [] 41 | for h in range(self.n_hierarchy): 42 | # category embedding 43 | self.category_embeddings.append(self.add_weight(name=f'category_embed_{h}', 44 | shape=(self.cate_hierarchy[h], self.cate_embed_dim), 45 | initializer=self.kernel_initializer)) 46 | # attention weights to 
perform attention 47 | self.attend_weights.append(self.add_weight(name=f'attend_weight_{h}', 48 | shape=(input_shape[2], self.cate_embed_dim), 49 | initializer=self.kernel_initializer)) 50 | # class prediction weights and biases 51 | self.cls_dense_weights.append(self.add_weight(name=f'cls_dense_weight_{h}', 52 | shape=(input_shape[2] * 2, input_shape[2]), 53 | initializer=self.kernel_initializer)) 54 | self.cls_dense_biases.append(self.add_weight(name=f'cls_dense_bias_{h}', 55 | shape=(input_shape[2]), 56 | initializer=zeros)) 57 | self.cls_pred_weights.append(self.add_weight(name=f'cls_pred_weight_{h}', 58 | shape=(input_shape[2], self.cate_hierarchy[h]), 59 | initializer=self.kernel_initializer)) 60 | self.cls_pred_biases.append(self.add_weight(name=f'cls_pred_bias_{h}', 61 | shape=(self.cate_hierarchy[h]), 62 | initializer=zeros)) 63 | 64 | super(HierarchicalAttentionRecurrent, self).build(input_shape) 65 | 66 | def call(self, inputs, mask=None): 67 | text_seq_embed = inputs # hidden states of text sequence 68 | batch_size = K.shape(text_seq_embed)[0] 69 | max_len = K.shape(text_seq_embed)[1] 70 | prev_cate_info = None # information of previous category level 71 | 72 | prob_outputs = [] 73 | for h in range(self.n_hierarchy): 74 | '''1. Text-Category Attention TCA''' 75 | if h == 0: 76 | text_with_prev_cate_embed = text_seq_embed 77 | else: 78 | text_with_prev_cate_embed = tf.multiply(text_seq_embed, K.expand_dims(prev_cate_info,axis=2)) 79 | attend_matrix = K.dot(K.tanh(K.dot(text_with_prev_cate_embed, self.attend_weights[h])), 80 | K.transpose(self.category_embeddings[h])) 81 | attend_matrix = K.softmax(attend_matrix, axis=1) # attention score between text and categories at h level 82 | 83 | # associated text-category representation 84 | text_cate_attend_embed = K.mean(K.batch_dot(attend_matrix, text_with_prev_cate_embed, axes=1), axis=1) 85 | 86 | '''2. Class Prediction Module CPM''' 87 | cls_dense_embed = K.dot(K.concatenate([K.mean(text_seq_embed, axis=1), text_cate_attend_embed]), 88 | self.cls_dense_weights[h]) + self.cls_dense_biases[h] 89 | cls_dense_embed = K.relu(cls_dense_embed) 90 | 91 | cls_pred_embed = K.dot(cls_dense_embed, self.cls_pred_weights[h]) + self.cls_pred_biases[h] 92 | cls_pred_embed = K.softmax(cls_pred_embed, axis=1) 93 | prob_outputs.append(cls_pred_embed) 94 | 95 | '''3. 
Class Dependency Module CPM''' 96 | prev_cate_info = K.mean(attend_matrix * K.expand_dims(cls_pred_embed, axis=1), axis=2) 97 | 98 | return prob_outputs 99 | 100 | def compute_output_shape(self, input_shape): 101 | output_shape = [] 102 | for h in range(self.n_hierarchy): 103 | output_shape.append((input_shape[0], self.cate_hierarchy[h])) 104 | return output_shape 105 | 106 | 107 | class HierarchicalAttention(Layer): 108 | """Hierarchical Attention-based Layer""" 109 | def __init__(self, cate_hierarchy, cate_embed_dim=100, kernel_initializer=None, **kwargs): 110 | super(HierarchicalAttention, self).__init__(**kwargs) 111 | self.cate_hierarchy = cate_hierarchy 112 | self.n_hierarchy = len(self.cate_hierarchy) 113 | self.cate_embed_dim = cate_embed_dim 114 | self.kernel_initializer = initializers.get(kernel_initializer) 115 | 116 | def build(self, input_shape): 117 | assert len(input_shape) == 3 118 | self.category_embeddings = [] 119 | self.attend_weights_1 = [] 120 | self.attend_weights_2 = [] 121 | self.cls_dense_weights = [] 122 | self.cls_dense_biases = [] 123 | self.cls_pred_weights = [] 124 | self.cls_pred_biases = [] 125 | 126 | self.category_embeddings.append(self.add_weight(name=f'category_embed_0', 127 | shape=(1, self.cate_embed_dim), 128 | initializer=self.kernel_initializer)) 129 | for h in range(self.n_hierarchy): 130 | # category embedding 131 | self.category_embeddings.append(self.add_weight(name=f'category_embed_{h+1}', 132 | shape=(self.cate_hierarchy[h], self.cate_embed_dim), 133 | initializer=self.kernel_initializer)) 134 | # attention weights to perform attention 135 | self.attend_weights_1.append(self.add_weight(name=f'attend_weight_1_{h}', 136 | shape=(input_shape[2] + self.cate_embed_dim, 137 | input_shape[2] + self.cate_embed_dim), 138 | initializer=self.kernel_initializer)) 139 | self.attend_weights_2.append(self.add_weight(name=f'attend_weight_2_{h}', 140 | shape=(input_shape[2] + self.cate_embed_dim, 1), 141 | initializer=self.kernel_initializer)) 142 | # class prediction weights and biases 143 | self.cls_dense_weights.append(self.add_weight(name=f'cls_dense_weight_{h}', 144 | shape=(input_shape[2] + self.cate_embed_dim, 145 | input_shape[2] + self.cate_embed_dim), 146 | initializer=self.kernel_initializer)) 147 | self.cls_dense_biases.append(self.add_weight(name=f'cls_dense_bias_{h}', 148 | shape=(input_shape[2] + self.cate_embed_dim), 149 | initializer=zeros)) 150 | self.cls_pred_weights.append(self.add_weight(name=f'cls_pred_weight_{h}', 151 | shape=(input_shape[2] + self.cate_embed_dim, 152 | self.cate_hierarchy[h]), 153 | initializer=self.kernel_initializer)) 154 | self.cls_pred_biases.append(self.add_weight(name=f'cls_pred_bias_{h}', 155 | shape=(self.cate_hierarchy[h]), 156 | initializer=zeros)) 157 | 158 | super(HierarchicalAttention, self).build(input_shape) 159 | 160 | def call(self, inputs, mask=None): 161 | 162 | text_seq_embed = inputs # hidden states of text sequence 163 | batch_size = K.shape(text_seq_embed)[0] 164 | max_len = K.shape(text_seq_embed)[1] 165 | prev_cate_embed = self.category_embeddings[0] 166 | prev_cate_embed = K.tile(K.expand_dims(prev_cate_embed, axis=0), [batch_size, 1, 1]) 167 | 168 | prob_outputs = [] 169 | for h in range(self.n_hierarchy): 170 | text_with_prev_cate_embed = K.concatenate([text_seq_embed, K.tile(prev_cate_embed, [1, max_len, 1])]) 171 | 172 | attend_matrix = K.dot(K.tanh(K.dot(text_with_prev_cate_embed, self.attend_weights_1[h])), 173 | self.attend_weights_2[h]) 174 | attend_matrix = K.softmax(attend_matrix, 
axis=1) 175 | 176 | text_cate_attend_embed = K.sum(attend_matrix * text_with_prev_cate_embed, axis=1) 177 | 178 | cls_dense_embed = K.dot(text_cate_attend_embed, self.cls_dense_weights[h]) + self.cls_dense_biases[h] 179 | cls_dense_embed = K.relu(cls_dense_embed) 180 | 181 | cls_pred_embed = K.dot(cls_dense_embed, self.cls_pred_weights[h]) + self.cls_pred_biases[h] 182 | prob_outputs.append(cls_pred_embed) 183 | 184 | prev_cate_embed = K.expand_dims(K.dot(cls_pred_embed, self.category_embeddings[h+1]), axis=1) 185 | 186 | return prob_outputs 187 | 188 | def compute_output_shape(self, input_shape): 189 | output_shape = [] 190 | for h in range(self.n_hierarchy): 191 | output_shape.append((input_shape[0], self.cate_hierarchy[h])) 192 | return output_shape 193 | -------------------------------------------------------------------------------- /data_loader/multitask_classify_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: multitask_classify_loader.py 10 | 11 | @time: 2020/5/21 21:12 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | import os 18 | import nltk 19 | import numpy as np 20 | from tensorflow.keras.utils import Sequence, to_categorical 21 | 22 | from config import NLTK_DATA 23 | from utils import pad_sequences_1d, get_bert_tokenizer 24 | from .base_loader import load_data 25 | 26 | 27 | class MultiTaskClsDataGenerator(Sequence): 28 | def __init__(self, 29 | data_type, 30 | batch_size, 31 | use_multi_task=True, 32 | input_type='name_desc', 33 | use_word_input=True, 34 | word_vocab=None, 35 | use_bert_input=False, 36 | use_pair_input=False, 37 | bert_model_type=None, 38 | max_len=None, 39 | cate1_vocab=None, 40 | cate2_vocab=None, 41 | cate3_vocab=None, 42 | all_cate_vocab=None, 43 | use_mask_for_cate2=False, 44 | use_mask_for_cate3=False, 45 | cate3_mask_type=None, 46 | cate1_to_cate2=None, 47 | cate_to_cate3=None, 48 | train_on_cv=False, 49 | cv_random_state=42, 50 | cv_fold=5, 51 | cv_index=0, 52 | exchange_pair=False, 53 | exchange_threshold=0.1, 54 | cate3_count_dict=None, 55 | use_pseudo_label=False, 56 | pseudo_path=None, 57 | pseudo_random_state=42, 58 | pseudo_rate=0.1, 59 | pseudo_index=0): 60 | self.data_type = data_type 61 | self.train_on_cv = train_on_cv 62 | self.data = load_data(data_type, train_on_cv=train_on_cv, cv_random_state=cv_random_state, 63 | cv_fold=cv_fold, cv_index=cv_index) 64 | # data augmentation, only for training set 65 | if self.data_type == 'train': 66 | if use_pseudo_label: 67 | self.add_pseudo_label(pseudo_path, pseudo_random_state, pseudo_rate, pseudo_index) 68 | if exchange_pair: 69 | self.exchange_pair_data(cate3_count_dict, exchange_threshold) 70 | 71 | self.data_size = len(self.data['name']) 72 | self.indices = np.arange(self.data_size) 73 | if self.data_type == 'train': 74 | np.random.shuffle(self.indices) # only shuffle for training set, we can't shuffle validation and test set!! 
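        # Note: __getitem__ below slices self.indices batch by batch, so shuffling this
        # index array (rather than the underlying data lists) is what randomizes the
        # composition of training batches; on_epoch_end() reshuffles it between epochs.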
75 | self.batch_size = batch_size 76 | self.steps = int(np.ceil(self.data_size / self.batch_size)) 77 | self.use_multi_task = use_multi_task 78 | self.input_type = input_type 79 | self.use_word_input = use_word_input 80 | self.word_vocab = word_vocab 81 | self.use_bert_input = use_bert_input 82 | 83 | if use_word_input: 84 | nltk.data.path.append(NLTK_DATA) 85 | assert word_vocab is not None 86 | self.word_vocab = word_vocab 87 | elif use_bert_input: 88 | assert bert_model_type is not None 89 | assert max_len is not None 90 | self.bert_model_type = bert_model_type 91 | self.bert_tokenizer = get_bert_tokenizer(bert_model_type) 92 | 93 | if input_type != 'name_desc': 94 | assert not use_pair_input 95 | self.use_pair_input = use_pair_input 96 | self.max_len = max_len 97 | 98 | self.cate1_vocab = cate1_vocab 99 | self.cate2_vocab = cate2_vocab 100 | self.cate3_vocab = cate3_vocab 101 | self.all_cate_vocab = all_cate_vocab 102 | if not use_multi_task: 103 | assert self.all_cate_vocab is not None 104 | else: 105 | assert self.cate1_vocab is not None and self.cate2_vocab is not None and self.cate3_vocab is not None 106 | 107 | self.use_mask_for_cate2 = use_mask_for_cate2 108 | self.use_mask_for_cate3 = use_mask_for_cate3 109 | self.cate3_mask_type = cate3_mask_type 110 | if self.use_mask_for_cate2: 111 | assert self.use_multi_task 112 | assert cate1_to_cate2 is not None 113 | self.cate1_to_cate2_matrix = self.create_mask_matrix(cate1_to_cate2, len(cate1_vocab), len(cate2_vocab)) 114 | else: 115 | self.cate1_to_cate2_matrix = None 116 | if self.use_mask_for_cate3: 117 | assert self.use_multi_task 118 | assert self.cate3_mask_type in ['cate1', 'cate2'] 119 | self.cate_to_cate3_matrix = self.create_mask_matrix( 120 | cate_to_cate3, 121 | len(cate1_vocab) if self.cate3_mask_type == 'cate1' else len(cate2_vocab), 122 | len(cate3_vocab) 123 | ) 124 | else: 125 | self.cate_to_cate3_matrix = None 126 | 127 | def exchange_pair_data(self, cate3_count_dict, exchange_threshold): 128 | added_data = { 129 | 'id': [], 'name': [], 'desc': [], 'cate1': [], 'cate2': [], 'cate3': [] 130 | } 131 | for i in range(len(self.data['id'])): 132 | cate3 = self.data['cate3'][i] 133 | if cate3 in cate3_count_dict and cate3_count_dict[cate3] <= exchange_threshold: 134 | added_data['id'].append(self.data['id'][i]) 135 | # exchange name and desc 136 | added_data['name'].append(self.data['desc'][i]) 137 | added_data['desc'].append(self.data['name'][i]) 138 | # keep the labels 139 | added_data['cate1'].append(self.data['cate1'][i]) 140 | added_data['cate2'].append(self.data['cate2'][i]) 141 | added_data['cate3'].append(self.data['cate3'][i]) 142 | 143 | for key in added_data: 144 | self.data[key].extend(added_data[key]) 145 | 146 | def add_pseudo_label(self, pseudo_path, pseudo_random_state=42, pseudo_rate=0.1, pseudo_index=0): 147 | pseudo_label_data = { 148 | 'id': [], 'name': [], 'desc': [], 'cate1': [], 'cate2': [], 'cate3': [] 149 | } 150 | with open(pseudo_path, 'r', encoding='utf8') as reader: 151 | lines = reader.readlines() 152 | pseudo_data_size = len(lines) 153 | if pseudo_rate < 1: 154 | np.random.seed(pseudo_random_state) 155 | sample_pseudo_size = int(pseudo_data_size * pseudo_rate) 156 | sample_indices = np.random.choice(pseudo_data_size, sample_pseudo_size, 157 | replace=False) 158 | elif pseudo_rate == 1: 159 | sample_indices = range(pseudo_data_size) 160 | else: 161 | sample_pseudo_size = int(pseudo_data_size / pseudo_rate) 162 | start = pseudo_index * sample_pseudo_size 163 | end = (pseudo_index + 1) * 
sample_pseudo_size 164 | np.random.seed(pseudo_random_state) 165 | indices = np.random.permutation(pseudo_data_size) 166 | sample_indices = indices[start: end] 167 | for idx in sample_indices: 168 | line = lines[idx] 169 | text_id, name, desc, cate1, cate2, cate3 = line.strip().split('##') 170 | pseudo_label_data['id'].append(text_id) 171 | pseudo_label_data['name'].append(name) 172 | pseudo_label_data['desc'].append(desc) 173 | pseudo_label_data['cate1'].append(cate1) 174 | pseudo_label_data['cate2'].append(cate2) 175 | pseudo_label_data['cate3'].append(cate3) 176 | 177 | for key in pseudo_label_data: 178 | self.data[key].extend(pseudo_label_data[key]) 179 | 180 | def __len__(self): 181 | return self.steps 182 | 183 | def on_epoch_end(self): 184 | np.random.shuffle(self.indices) 185 | 186 | def __getitem__(self, index): 187 | batch_index = self.indices[index * self.batch_size: (index + 1) * self.batch_size] 188 | 189 | batch_input_ids, batch_input_masks, batch_input_types = [], [], [] 190 | batch_cate1_ids, batch_cate2_ids, batch_cate3_ids = [], [], [] # labels of multi task taining 191 | batch_all_cate_ids = [] # labels of single task taining 192 | for i in batch_index: 193 | text = self.prepare_text(self.data['name'][i], self.data['desc'][i]) 194 | 195 | # prepare input 196 | if self.use_word_input: 197 | word_ids = [self.word_vocab.get(w, 1) for w in nltk.tokenize.word_tokenize(text)] 198 | batch_input_ids.append(word_ids) 199 | elif self.use_bert_input: 200 | if self.use_pair_input: 201 | try: 202 | inputs = self.bert_tokenizer.encode_plus(text=text[0], text_pair=text[1], 203 | max_length=self.max_len, 204 | pad_to_max_length=True, 205 | truncation_strategy='only_second') 206 | except Exception: 207 | inputs = self.bert_tokenizer.encode_plus(text=text[0], text_pair=text[1], 208 | max_length=self.max_len, 209 | pad_to_max_length=True, 210 | truncation_strategy='longest_first') 211 | else: 212 | inputs = self.bert_tokenizer.encode_plus(text=text, max_length=self.max_len, pad_to_max_length=True) 213 | batch_input_ids.append(inputs['input_ids']) 214 | batch_input_masks.append(inputs['attention_mask']) 215 | if 'token_type_ids' in inputs: 216 | batch_input_types.append(inputs['token_type_ids']) 217 | else: 218 | raise ValueError('must use word or bert as input') 219 | 220 | if self.data_type == 'test': # no labels for test data 221 | continue 222 | # prepare label for training or validation set 223 | if not self.use_multi_task: 224 | all_cate = f"{self.data['cate1'][i]}|{self.data['cate2'][i]}|{self.data['cate3'][i]}" 225 | if (self.data_type == 'dev' or self.train_on_cv) and all_cate not in self.all_cate_vocab: 226 | batch_all_cate_ids.append(0) 227 | else: 228 | batch_all_cate_ids.append(self.all_cate_vocab[all_cate]) 229 | else: 230 | batch_cate1_ids.append(self.cate1_vocab[self.data['cate1'][i]]) 231 | if (self.data_type == 'dev' or self.train_on_cv) and self.data['cate2'][i] not in self.cate2_vocab: 232 | batch_cate2_ids.append(0) 233 | else: 234 | batch_cate2_ids.append(self.cate2_vocab[self.data['cate2'][i]]) 235 | if (self.data_type == 'dev' or self.train_on_cv) and self.data['cate3'][i] not in self.cate3_vocab: 236 | batch_cate3_ids.append(0) 237 | else: 238 | batch_cate3_ids.append(self.cate3_vocab[self.data['cate3'][i]]) 239 | 240 | # feature input 241 | if self.use_word_input: 242 | batch_inputs = pad_sequences_1d(batch_input_ids, max_len=self.max_len) 243 | else: 244 | batch_inputs = [np.array(batch_input_ids), np.array(batch_input_masks)] 245 | if batch_input_types: 246 | 
batch_inputs.append(np.array(batch_input_types)) 247 | 248 | if self.data_type == 'test': # no labels for test data 249 | return batch_inputs 250 | 251 | # label masking (only for training dataset) 252 | if self.use_multi_task and self.data_type == 'train': 253 | if self.use_mask_for_cate2: 254 | if not isinstance(batch_inputs, list): 255 | batch_inputs = [batch_inputs] 256 | batch_inputs.append(self.cate1_to_cate2_matrix[np.array(batch_cate1_ids)]) 257 | if self.use_mask_for_cate3: 258 | if not isinstance(batch_inputs, list): 259 | batch_inputs = [batch_inputs] 260 | if self.cate3_mask_type == 'cate1': 261 | batch_inputs.append(self.cate_to_cate3_matrix[np.array(batch_cate1_ids)]) 262 | elif self.cate3_mask_type == 'cate2': 263 | batch_inputs.append(self.cate_to_cate3_matrix[np.array(batch_cate2_ids)]) 264 | else: 265 | raise ValueError(f'`cate3_mask_type` not understood') 266 | 267 | # ground truth labels 268 | if not self.use_multi_task: 269 | batch_labels = to_categorical(batch_all_cate_ids, num_classes=len(self.all_cate_vocab)) 270 | else: 271 | batch_labels = [ 272 | to_categorical(batch_cate1_ids, num_classes=len(self.cate1_vocab)), 273 | to_categorical(batch_cate2_ids, num_classes=len(self.cate2_vocab)), 274 | to_categorical(batch_cate3_ids, num_classes=len(self.cate3_vocab)) 275 | ] 276 | 277 | return batch_inputs, batch_labels 278 | 279 | def prepare_text(self, name, desc): 280 | if self.input_type == 'name': 281 | return name 282 | elif self.input_type == 'desc': 283 | if not desc: 284 | return name 285 | else: 286 | return desc 287 | elif self.input_type == 'name_desc': 288 | if desc: 289 | if self.use_pair_input: 290 | return name, desc 291 | else: 292 | return f"{name} {desc}" 293 | else: 294 | if self.use_pair_input: 295 | return name, name 296 | else: 297 | return name 298 | else: 299 | raise ValueError(f'`input_type` not understood: {self.input_type}') 300 | 301 | @staticmethod 302 | def create_mask_matrix(cate1_to_cate2, n_cate1, n_cate2): 303 | mask_matrix = np.zeros(shape=(n_cate1, n_cate2)) 304 | for cate1 in cate1_to_cate2: 305 | for cate2 in cate1_to_cate2[cate1]: 306 | mask_matrix[cate1][cate2] = 1 307 | return mask_matrix 308 | 309 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: config.py 10 | 11 | @time: 2020/5/17 15:33 12 | 13 | @desc: 14 | 15 | """ 16 | 17 | from os import path 18 | 19 | RAW_DATA_DIR = './raw_data' 20 | PROCESSED_DATA_DIR = './data' 21 | LOG_DIR = './log' 22 | MODEL_SAVED_DIR = './ckpt' 23 | SUBMIT_DIR = './submit' 24 | IMG_DIR = './img' 25 | PREDICT_DIR = './predict' 26 | 27 | NLTK_DATA = path.join(RAW_DATA_DIR, 'nltk_data') 28 | RAW_TRAIN_FILENAME = path.join(RAW_DATA_DIR, 'train.json') 29 | RAW_DEV_FILENAME = path.join(RAW_DATA_DIR, 'validation.json') 30 | RAW_TEST_FILENAME = path.join(RAW_DATA_DIR, 'test_public.json') 31 | 32 | TRAIN_DATA_FILENAME = 'train.pkl' 33 | TRAIN_CV_DATA_TEMPLATE = 'train_{random}_{fold}_{index}.pkl' 34 | DEV_CV_DATA_TEMPLATE = 'dev_{random}_{fold}_{index}.pkl' 35 | DEV_DATA_FILENAME = 'dev.pkl' 36 | 37 | TEST_DATA_FILENAME = 'test.pkl' 38 | 39 | VOCABULARY_TEMPLATE = '{level}_vocab.pkl' 40 | IDX2TOKEN_TEMPLATE = 'idx2{level}.pkl' 41 | EMBEDDING_MATRIX_TEMPLATE = '{type}_embeddings.npy' 42 | PERFORMANCE_MD = 'performance.md' 43 | 44 | CATE1_TO_CATE2_DICT = 
'cate1_to_cate2.pkk' 45 | CATE1_TO_CATE3_DICT = 'cate1_to_cate3.pkl' 46 | CATE2_TO_CATE3_DICT = 'cate2_to_cate3.pkl' 47 | 48 | CATE1_COUNT_DICT = 'cate1_count_dict.pkl' 49 | CATE2_COUNT_DICT = 'cate2_count_dict.pkl' 50 | CATE3_COUNT_DICT = 'cate3_count_dict.pkl' 51 | 52 | RANDOM_SEED = 2020 53 | 54 | EXTERNAL_EMBEDDINGS_DIR = path.join(RAW_DATA_DIR, 'embeddings') 55 | 56 | BERT_VOCAB_FILE = { 57 | 'bert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-base-uncased', 'bert-base-uncased-vocab.txt'), 58 | 'bert-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-base-cased', 'bert-base-cased-vocab.txt'), 59 | 'bert-large-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-large-uncased', 'bert-large-uncased-vocab.txt'), 60 | 'bert-large-uncased-whole-word-masking': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-large-uncased-whole-word-masking', 'bert-large-uncased-whole-word-masking-vocab.txt'), 61 | 'bert-large-uncased-whole-word-masking-finetuned-squad': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-large-uncased-whole-word-masking-finetuned-squad', 'bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt'), 62 | 'roberta-base': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-base', 'roberta-base-vocab.json'), 63 | 'roberta-large': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-large', 'roberta-large-vocab.json'), 64 | 'roberta-large-mnli': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-large-mnli', 'roberta-large-mnli-vocab.json'), 65 | 'distilroberta-base': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilroberta-base', 'distilroberta-base-vocab.json'), 66 | 'prod-bert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'prod-bert-base-uncased', 'vocab.txt'), 67 | 'prod-roberta-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'prod-roberta-base-cased', 'vocab.json'), 68 | 'tune_bert-base-uncased_nsp': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-base-uncased', 'bert-base-uncased-vocab.txt'), 69 | 'xlnet-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'xlnet-base-cased', 'xlnet-base-cased-spiece.model'), 70 | 'albert-base-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-base-v1', 'albert-base-v1-spiece.model'), 71 | 'albert-large-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-large-v1', 'albert-large-v1-spiece.model'), 72 | 'albert-xlarge-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-xlarge-v1', 'albert-xlarge-v1-spiece.model'), 73 | 'albert-xxlarge-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-xxlarge-v1', 'albert-xxlarge-v1-spiece.model'), 74 | 'gpt2': path.join(EXTERNAL_EMBEDDINGS_DIR, 'gpt2', 'gpt2-vocab.json'), 75 | 'gpt2-medium': path.join(EXTERNAL_EMBEDDINGS_DIR, 'gpt2-medium', 'gpt2-medium-vocab.json'), 76 | 'transfo-xl': path.join(EXTERNAL_EMBEDDINGS_DIR, 'transfo-xl', 'transfo-xl-wt103-vocab.json'), 77 | 'distilbert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilbert-base-uncased', 'bert-base-uncased-vocab.txt'), 78 | 'distilbert-base-uncased-distilled-squad': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilbert-base-uncased-distilled-squad', 'bert-large-uncased-vocab.txt'), 79 | } 80 | 81 | BERT_CONFIG_FILE = { 82 | 'bert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-base-uncased', 'bert-base-uncased-config.json'), 83 | 'bert-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-base-cased', 'bert-base-cased-config.json'), 84 | 'bert-large-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-large-uncased', 'bert-large-uncased-config.json'), 85 | 'bert-large-uncased-whole-word-masking': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-large-uncased-whole-word-masking', 
'bert-large-uncased-whole-word-masking-config.json'), 86 | 'bert-large-uncased-whole-word-masking-finetuned-squad': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-large-uncased-whole-word-masking-finetuned-squad', 'bert-large-uncased-whole-word-masking-finetuned-squad-config.json'), 87 | 'roberta-base': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-base', 'roberta-base-config.json'), 88 | 'roberta-large': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-large', 'roberta-large-config.json'), 89 | 'roberta-large-mnli': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-large-mnli', 'roberta-large-mnli-config.json'), 90 | 'distilroberta-base': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilroberta-base', 'distilroberta-base-config.json'), 91 | 'prod-bert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'prod-bert-base-uncased', 'config.json'), 92 | 'prod-roberta-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'prod-roberta-base-cased', 'config.json'), 93 | 'tune_bert-base-uncased_nsp': path.join(EXTERNAL_EMBEDDINGS_DIR, 'tune_bert-base-uncased_nsp_neg_1', 'config.json'), 94 | 'xlnet-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'xlnet-base-cased', 'xlnet-base-cased-config.json'), 95 | 'albert-base-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-base-v1', 'albert-base-v1-config.json'), 96 | 'albert-large-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-large-v1', 'albert-large-v1-config.json'), 97 | 'albert-xlarge-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-xlarge-v1', 'albert-xlarge-v1-config.json'), 98 | 'albert-xxlarge-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-xxlarge-v1', 'albert-xxlarge-v1-config.json'), 99 | 'gpt2': path.join(EXTERNAL_EMBEDDINGS_DIR, 'gpt2', 'gpt2-config.json'), 100 | 'gpt2-medium': path.join(EXTERNAL_EMBEDDINGS_DIR, 'gpt2-medium', 'gpt2-medium-config.json'), 101 | 'transfo-xl': path.join(EXTERNAL_EMBEDDINGS_DIR, 'transfo-xl', 'transfo-xl-wt103-config.json'), 102 | 'distilbert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilbert-base-uncased', 'distilbert-base-uncased-config.json'), 103 | 'distilbert-base-uncased-distilled-squad': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilbert-base-uncased-distilled-squad', 'distilbert-base-uncased-distilled-squad-config.json'), 104 | } 105 | 106 | BERT_MODEL_FILE = { 107 | 'bert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-base-uncased', 'bert-base-uncased-tf_model.h5'), 108 | 'bert-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-base-cased', 'bert-base-cased-tf_model.h5'), 109 | 'bert-large-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-large-uncased', 'bert-large-uncased-tf_model.h5'), 110 | 'bert-large-uncased-whole-word-masking': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-large-uncased-whole-word-masking', 'bert-large-uncased-whole-word-masking-tf_model.h5'), 111 | 'bert-large-uncased-whole-word-masking-finetuned-squad': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-large-uncased-whole-word-masking-finetuned-squad', 'bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5'), 112 | 'roberta-base': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-base', 'roberta-base-tf_model.h5'), 113 | 'roberta-large': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-large', 'roberta-large-tf_model.h5'), 114 | 'roberta-large-mnli': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-large-mnli', 'roberta-large-mnli-tf_model.h5'), 115 | 'distilroberta-base': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilroberta-base', 'distilroberta-base-tf_model.h5'), 116 | 'prod-bert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'prod-bert-base-uncased', 
'pytorch_model.bin'), 117 | 'prod-roberta-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'prod-roberta-base-cased', 'pytorch_model.bin'), 118 | 'tune_bert-base-uncased_nsp': path.join(EXTERNAL_EMBEDDINGS_DIR, 'tune_bert-base-uncased_nsp_neg_1', 'pytorch_model.bin'), 119 | 'xlnet-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'xlnet-base-cased', 'xlnet-base-cased-tf_model.h5'), 120 | 'albert-base-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-base-v1', 'albert-base-v1-with-prefix-tf_model.h5'), 121 | 'albert-large-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-large-v1', 'albert-large-v1-with-prefix-tf_model.h5'), 122 | 'albert-xlarge-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-xlarge-v1', 'albert-xlarge-v1-with-prefix-tf_model.h5'), 123 | 'albert-xxlarge-v1': path.join(EXTERNAL_EMBEDDINGS_DIR, 'albert-xxlarge-v1', 'albert-xxlarge-v1-with-prefix-tf_model.h5'), 124 | 'gpt2': path.join(EXTERNAL_EMBEDDINGS_DIR, 'gpt2', 'gpt2-tf_model.h5'), 125 | 'gpt2-medium': path.join(EXTERNAL_EMBEDDINGS_DIR, 'gpt2-medium', 'gpt2-medium-tf_model.h5'), 126 | 'transfo-xl': path.join(EXTERNAL_EMBEDDINGS_DIR, 'transfo-xl', 'transfo-xl-wt103-tf_mode.h5'), 127 | 'distilbert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilbert-base-uncased', 'distilbert-base-uncased-tf_model.h5'), 128 | 'distilbert-base-uncased-distilled-squad': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilbert-base-uncased-distilled-squad', 'distilbert-base-uncased-distilled-squad-tf_model.h5'), 129 | } 130 | 131 | BERT_TORCH_MODEL_FILE = { 132 | 'bert-base-uncased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'bert-base-uncased', 'bert-base-uncased-pytorch_model.bin'), 133 | } 134 | 135 | BERT_MERGE_FILE = { 136 | 'roberta-base': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-base', 'roberta-base-merges.txt'), 137 | 'roberta-large': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-large', 'roberta-large-merges.txt'), 138 | 'roberta-large-mnli': path.join(EXTERNAL_EMBEDDINGS_DIR, 'roberta-large-mnli', 'roberta-large-mnli-merges.txt'), 139 | 'distilroberta-base': path.join(EXTERNAL_EMBEDDINGS_DIR, 'distilroberta-base', 'distilroberta-base-merges.txt'), 140 | 'prod-roberta-base-cased': path.join(EXTERNAL_EMBEDDINGS_DIR, 'prod-roberta-base-cased', 'merges.txt'), 141 | 'gpt2': path.join(EXTERNAL_EMBEDDINGS_DIR, 'gpt2', 'gpt2-merges.txt'), 142 | 'gpt2-medium': path.join(EXTERNAL_EMBEDDINGS_DIR, 'gpt2-medium', 'gpt2-medium-merges.txt'), 143 | } 144 | 145 | TEXT_CORPUS_DIR = path.join(RAW_DATA_DIR, 'text_corpus') 146 | TEST_TEXT_COPRUS_DIR = path.join(RAW_DATA_DIR, 'test_text_corpus') 147 | 148 | MAX_LEN = { 149 | 'name': 20, 150 | 'desc': 200, 151 | 'name_desc': 200 152 | } 153 | 154 | 155 | class ModelConfig(object): 156 | def __init__(self): 157 | # model general configuration 158 | self.exp_name = None 159 | self.model_type = 'bert-base-uncased' 160 | 161 | # input general configuration 162 | self.input_type = 'name_desc' 163 | self.use_multi_task = True 164 | self.use_harl = False 165 | self.use_hal = False 166 | self.cate_embed_dim = 100 167 | 168 | self.use_word_input = False 169 | self.word_embed_trainable = True 170 | self.word_embed_type = None 171 | self.word_embeddings = None 172 | self.word_embed_dim = 300 173 | self.word_vocab = None 174 | self.word_vocab_size = -1 175 | 176 | self.use_bert_input = True 177 | self.bert_trainable = True 178 | self.use_bert_type = 'pooler' 179 | self.output_hidden_state = False 180 | self.n_last_hidden_layer = 0 181 | self.dense_after_bert = False 182 | self.use_pair_input = True 183 | 184 | self.max_len = None 
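# category vocabulary / index-mapping configuration (populated by prepare_config in trainer.py from the pickled vocab files)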
185 | self.cate1_vocab = None 186 | self.cate2_vocab = None 187 | self.cate3_vocab = None 188 | self.all_cate_vocab = None 189 | self.idx2cate1 = None 190 | self.idx2cate2 = None 191 | self.idx2cate3 = None 192 | self.idx2all_cate = None 193 | self.n_cate1 = -1 194 | self.n_cate2 = -1 195 | self.n_cate3 = -1 196 | self.n_all_cate = -1 197 | 198 | # model training configuration 199 | self.share_father_pred = 'no' 200 | self.use_mask_for_cate2 = False 201 | self.use_mask_for_cate3 = False 202 | self.cate3_mask_type = 'cate1' 203 | self.cate1_to_cate2 = None 204 | self.cate1_to_cate2 = None 205 | self.cate2_to_cate3 = None 206 | self.cate1_to_cate3 = None 207 | self.cate_to_cate3 = None 208 | self.cate1_loss_weight = 1. 209 | self.cate2_loss_weight = 1. 210 | self.cate3_loss_weight = 1. 211 | self.cate1_count_dict = None 212 | self.cate2_count_dict = None 213 | self.cate3_count_dict = None 214 | 215 | self.batch_size = 32 216 | self.predict_batch_size = 32 217 | self.n_epoch = 50 218 | self.learning_rate = 2e-5 219 | self.optimizer = 'adam' 220 | self.use_focal_loss = False 221 | self.callbacks_to_add = None 222 | 223 | # checkpoint configuration 224 | self.checkpoint_dir = MODEL_SAVED_DIR 225 | self.checkpoint_monitor = 'val_f1' 226 | self.checkpoint_save_best_only = True 227 | self.checkpoint_save_weights_only = True 228 | self.checkpoint_save_weights_mode = 'max' 229 | self.checkpoint_verbose = 1 230 | 231 | # early_stopping configuration 232 | self.early_stopping_monitor = 'val_f1' 233 | self.early_stopping_mode = 'max' 234 | self.early_stopping_patience = 5 235 | self.early_stopping_verbose = 1 236 | 237 | # ensembler configuration 238 | self.swa_start = 10 239 | 240 | # lr scheduler configuration 241 | self.max_lr = 5e-5 242 | self.min_lr = 1e-5 243 | 244 | # cross validation 245 | self.train_on_cv = False 246 | self.cv_random_state = 42 247 | self.cv_fold = 5 248 | self.cv_index = 0 249 | 250 | # data augmentation 251 | self.exchange_pair = False 252 | self.exchange_threshold = 0.1 253 | 254 | self.use_pseudo_label = False 255 | self.pseudo_path = None 256 | self.pseudo_random_state = 42 257 | self.pseudo_rate = 0.1 258 | self.pseudo_index = 0 259 | 260 | 261 | class LanguageModelConfig: 262 | def __init__(self): 263 | self.model_name = None 264 | 265 | self.fine_tune = True 266 | self.model_type = 'bert-base-uncased' 267 | 268 | self.do_lm = True # train with language modeling 269 | self.do_mlm = True # train with masked language modeling 270 | self.lm_with_cls_corpus = True # other than product corpus, also use product classification dataset for LM 271 | 272 | self.do_nsp = False # train next sentence prediction with product classification dataset 273 | self.num_neg_sample = 1 274 | 275 | self.model_save_dir = None 276 | 277 | self.tokenizer_type = 'word_piece' 278 | self.lowercase = True 279 | self.vocab_size = 30000 280 | 281 | -------------------------------------------------------------------------------- /callbacks/ensemble.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: alexyang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: ensembler.py 10 | 11 | @time: 2019/4/13 20:39 12 | 13 | @desc: ensemble during training process 14 | 15 | """ 16 | import os 17 | import math 18 | 19 | import tensorflow.keras.backend as K 20 | from tensorflow.keras.callbacks import Callback 21 | 22 | 23 | class SWA(Callback): 24 | """ 25 | This callback implements a stochastic weight averaging (SWA) 
method with constant lr as presented in the paper - 26 | "Izmailov et al. Averaging Weights Leads to Wider Optima and Better Generalization" 27 | (https://arxiv.org/abs/1803.05407) 28 | Author's implementation: https://github.com/timgaripov/swa 29 | """ 30 | def __init__(self, swa_model, checkpoint_dir, model_name, swa_start=1): 31 | """ 32 | :param swa_model: the model that we will use to store the average of the weights once SWA begins 33 | :param checkpoint_dir: the directory where the model will be saved in 34 | :param model_name: the name of model we're training 35 | :param swa_start: the epoch when averaging begins. We generally pre-train the network for a certain amount of 36 | epochs to start (swa_start > 1), as opposed to starting to track the average from the 37 | very beginning. 38 | """ 39 | super(SWA, self).__init__() 40 | self.checkpoint_dir = checkpoint_dir 41 | self.model_name = model_name 42 | self.swa_start = swa_start 43 | self.swa_model = swa_model # the model that we will use to store the average of the weights once SWA begins 44 | 45 | def on_train_begin(self, logs=None): 46 | self.epoch = 0 47 | self.swa_n = 0 48 | '''Note: I found deep copy of a model with customized layer would give errors''' 49 | # self.swa_model = copy.deepcopy(self.model) # make a copy of the model we're training 50 | 51 | '''Note: Something wired still happen even though i use keras.models.clone_model method, so I build swa_model 52 | outside this callback and pass it as an argument. It's not fancy, but the best I can do :) 53 | ''' 54 | # self.swa_model = keras.models.clone_model(self.model) 55 | self.swa_model.set_weights(self.model.get_weights()) # see: https://github.com/keras-team/keras/issues/1765 56 | 57 | def on_epoch_end(self, epoch, logs=None): 58 | if (self.epoch + 1) >= self.swa_start: 59 | self.update_average_model() 60 | self.swa_n += 1 61 | 62 | self.epoch += 1 63 | 64 | def update_average_model(self): 65 | # update running average of parameters 66 | alpha = 1. / (self.swa_n + 1) 67 | for layer, swa_layer in zip(self.model.layers, self.swa_model.layers): 68 | weights = [] 69 | for w1, w2 in zip(swa_layer.get_weights(), layer.get_weights()): 70 | weights.append((1 - alpha) * w1 + alpha * w2) 71 | swa_layer.set_weights(weights) 72 | 73 | def on_train_end(self, logs=None): 74 | print('Logging Info - Saving SWA model checkpoint: %s_swa.hdf5\n' % self.model_name) 75 | self.swa_model.save_weights(os.path.join(self.checkpoint_dir, '{}_swa.hdf5'.format(self.model_name))) 76 | print('Logging Info - SWA model Saved') 77 | 78 | 79 | class SWAWithCLR(Callback): 80 | """ 81 | SWA with cyclical learning rate, collect model 82 | """ 83 | 84 | def __init__(self, swa_model, checkpoint_dir, model_name, min_lr, max_lr, cycle_length, swa_start=1): 85 | super(SWAWithCLR, self).__init__() 86 | 87 | self.checkpoint_dir = checkpoint_dir 88 | self.model_name = model_name 89 | self.swa_start = swa_start 90 | self.swa_model = swa_model # the model that we will use to store the average of the weights once SWA begins 91 | 92 | self.min_lr = min_lr 93 | self.max_lr = max_lr 94 | self.cycle_length = cycle_length 95 | self.trn_iteration = 0. 96 | self.cycle = 0. 
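# swa_n counts how many weight snapshots have been folded into the running average; history records lr and metrics per batch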
97 | self.swa_n = 0 98 | self.history = {} 99 | 100 | def on_train_begin(self, logs={}): 101 | K.set_value(self.model.optimizer.lr, self.max_lr) 102 | # self.swa_model = keras.models.clone_model(self.model) 103 | self.swa_model.set_weights(self.model.get_weights()) # see: https://github.com/keras-team/keras/issues/1765 104 | 105 | def on_batch_end(self, epoch, logs=None): 106 | logs = logs or {} 107 | self.trn_iteration += 1 108 | 109 | self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr)) 110 | self.history.setdefault('iterations', []).append(self.trn_iteration) 111 | for k, v in logs.items(): 112 | self.history.setdefault(k, []).append(v) 113 | 114 | self.swa() 115 | 116 | def swa(self): 117 | t = (self.trn_iteration % self.cycle_length) / self.cycle_length 118 | lr = (1 - t) * self.max_lr + t * self.min_lr 119 | K.set_value(self.model.optimizer.lr, lr) 120 | 121 | if t == 0: 122 | self.cycle += 1 123 | if self.cycle >= self.swa_start: 124 | self.update_average_model() 125 | self.swa_n += 1 126 | 127 | def update_average_model(self): 128 | # update running average of parameters 129 | alpha = 1. / (self.swa_n + 1) 130 | for layer, swa_layer in zip(self.model.layers, self.swa_model.layers): 131 | weights = [] 132 | for w1, w2 in zip(swa_layer.get_weights(), layer.get_weights()): 133 | weights.append((1 - alpha) * w1 + alpha * w2) 134 | swa_layer.set_weights(weights) 135 | 136 | def on_train_end(self, logs=None): 137 | print('Logging Info - Saving SWA model checkpoint: %s_swa_clr.hdf5\n' % self.model_name) 138 | self.swa_model.save_weights(os.path.join(self.checkpoint_dir, '{}_swa_clr.hdf5'.format(self.model_name))) 139 | print('Logging Info - SWA model Saved') 140 | 141 | 142 | class SnapshotEnsemble(Callback): 143 | """ 144 | This Callback implements the `Snapshot Ensemble` method, which can produce an ensemble 145 | of accurate and diverse models from a single training process. 146 | 147 | Snapshot Ensemble uses a cosine annealing learning rate schedule where the learning rate starts high and is dropped 148 | relatively rapidly to a minimum value near zero before being increased again to the maximum. This lr schedule 149 | is similar to SGDR (https://arxiv.org/pdf/1608.03983.pdf). 150 | 151 | Snapshot Ensemble is proposed by Huang et al.
"Snapshot Ensebles: Train 1, Get M for free" 152 | (https://arxiv.org/pdf/1704.00109.pdf) 153 | """ 154 | def __init__(self, checkpoint_dir, model_name, max_lr, cycle_length, snapshot_start=1): 155 | """ 156 | :param checkpoint_dir: where to save the snapshot model 157 | :param model_name: the name of model we're training 158 | :param max_lr: upper bound learning rate 159 | :param cycle_length: the number of iterations (1 mini-batch is 1 iteration) in an cycle, generally we set it to 160 | np.ceil(train_iterations / n_cycle) 161 | :param snapshot_start: epoch to start snapshot ensemble 162 | """ 163 | self.checkpoint_dir = checkpoint_dir 164 | self.model_name = model_name 165 | self.max_lr = max_lr 166 | self.cycle_length = cycle_length 167 | self.snapshot_start = snapshot_start 168 | self.n_cycle = 0 169 | self.trn_iteration = 0 170 | self.history = {} 171 | super(SnapshotEnsemble, self).__init__() 172 | 173 | def on_train_begin(self, logs=None): 174 | K.set_value(self.model.optimizer.lr, self.max_lr) 175 | 176 | def on_batch_end(self, batch, logs=None): 177 | logs = logs or {} 178 | 179 | self.trn_iteration += 1 180 | self.history.setdefault('iteration', []).append(self.trn_iteration) 181 | self.history.setdefault('lr', []).append(self.model.optimizer.lr) 182 | for k, v in logs.items(): 183 | self.history.setdefault(k, []).append(v) 184 | 185 | self.snapshot() 186 | 187 | def snapshot(self): 188 | # update lr 189 | fraction_to_restart = self.trn_iteration % self.cycle_length / self.cycle_length 190 | lr = 0.5 * self.max_lr * (math.cos(fraction_to_restart * math.pi) + 1) 191 | K.set_value(self.model.optimizer.lr, lr) 192 | 193 | '''Check for the end of cycle''' 194 | if fraction_to_restart == 0: 195 | self.n_cycle += 1 196 | if self.n_cycle >= self.snapshot_start: 197 | snapshot_id = self.n_cycle - self.snapshot_start 198 | # print('Logging Info - Iteration %s : Saving Snapshot Ensemble model checkpoint: %s_snapshot_%d.hdf5\n' 199 | # % (self.trn_iteration, self.model_name, snapshot_id)) 200 | self.model.save_weights(os.path.join(self.checkpoint_dir, '{}_sse_{}.hdf5'.format(self.model_name, 201 | snapshot_id))) 202 | # print('Logging Info - Snapshot Ensemble model Saved') 203 | 204 | 205 | class HorizontalEnsemble(Callback): 206 | """ 207 | This Callback implements `Horizontal Ensemble` method, which is proposed by 208 | Xie et al. 
"Horizontal and Vertical Ensemble with Deep Representation for Classification" 209 | (https://arxiv.org/pdf/1306.2759.pdf) 210 | """ 211 | def __init__(self, checkpoint_dir, model_name, horizontal_start=1): 212 | """ 213 | :param checkpoint_dir: where to save the snapshot model 214 | :param model_name: the name of model we're training 215 | :param horizontal_start: epoch to start horizontal ensemble 216 | """ 217 | self.checkpoint_dir = checkpoint_dir 218 | self.model_name = model_name 219 | self.horizontal_start = horizontal_start 220 | super(HorizontalEnsemble, self).__init__() 221 | 222 | def on_epoch_end(self, epoch, logs=None): 223 | if epoch + 1 >= self.horizontal_start: 224 | model_id = epoch + 1 - self.horizontal_start 225 | self.model.save_weights(os.path.join(self.checkpoint_dir, 226 | '{}_horizontal_{}.hdf5'.format(self.model_name, model_id))) 227 | 228 | 229 | class FGE(Callback): 230 | """ 231 | This Callback implement `Fast Geometruc Ensembling` (FGE) 232 | """ 233 | def __init__(self, checkpoint_dir, model_name, min_lr, max_lr, cycle_length, fge_start=1): 234 | """ 235 | :param checkpoint_dir: where to save the snapshot model 236 | :param model_name: the name of model we're training 237 | :param max_lr: upper bound learning rate 238 | :param min_lr: lower bound lr 239 | :param cycle_length: the number of iterations (1 mini-batch is 1 iteration) in an cycle, generally we set it to 240 | np.ceil(train_iterations / n_cycle) 241 | :param fge_start: epoch to start fge ensemble 242 | """ 243 | self.checkpoint_dir = checkpoint_dir 244 | self.model_name = model_name 245 | self.max_lr = max_lr 246 | self.min_lr = min_lr 247 | self.cycle_length = cycle_length 248 | self.fge_start = fge_start 249 | self.n_cycle = 0 250 | self.trn_iteration = 0 251 | self.history = {} 252 | super(FGE, self).__init__() 253 | 254 | def on_train_begin(self, logs=None): 255 | K.set_value(self.model.optimizer.lr, self.max_lr) 256 | 257 | def on_batch_end(self, batch, logs=None): 258 | logs = logs or {} 259 | 260 | self.trn_iteration += 1 261 | self.history.setdefault('iteration', []).append(self.trn_iteration) 262 | self.history.setdefault('lr', []).append(self.model.optimizer.lr) 263 | for k, v in logs.items(): 264 | self.history.setdefault(k, []).append(v) 265 | 266 | self.fge() 267 | 268 | def fge(self): 269 | # update lr 270 | t = self.trn_iteration % self.cycle_length / self.cycle_length 271 | if t <= 0.5: 272 | lr = (1 - 2 * t) * self.max_lr + 2 * t * self.min_lr 273 | else: 274 | lr = (2 - 2 * t) * self.max_lr + (2 * t - 1) * self.min_lr 275 | K.set_value(self.model.optimizer.lr, lr) 276 | 277 | # when the learning rate reaches its minimum value, collect model 278 | if t == 0.5: 279 | self.n_cycle += 1 280 | if self.n_cycle >= self.fge_start: 281 | fge_id = self.n_cycle - self.fge_start 282 | self.model.save_weights(os.path.join(self.checkpoint_dir, '{}_fge_{}.hdf5'.format(self.model_name, 283 | fge_id))) 284 | 285 | 286 | class PolyakAverage(Callback): 287 | """ 288 | Same as swa (https://arxiv.org/abs/1803.05407), this callback combine the weights from multiple different models 289 | into a single model for making predictions by using an average of the weights from models seen toward the end of 290 | the training run, which can be further improved by using linearly or exponentially decreasing weighted average. 
291 | 292 | Reference: https://machinelearningmastery.com/ensemble-methods-for-deep-learning-neural-networks/ 293 | """ 294 | def __init__(self, polyak_model, temp_polyak_model, checkpoint_dir, model_name, avg_type, early_stop_patience, 295 | polyak_start=1): 296 | """ 297 | :param polyak_model: the model that we will use to store the average of the weights 298 | :param checkpoint_dir: the directory where the model will be saved in 299 | :param model_name: the name of model we're training 300 | :param polyak_start: the epoch when averaging begins. We generally pre-train the network for a certain amount of 301 | epochs to start (polyak_start > 1), as opposed to starting to track the average from the 302 | very beginning. 303 | """ 304 | super(PolyakAverage, self).__init__() 305 | self.polyak_model = polyak_model 306 | self.temp_polyak_model = temp_polyak_model 307 | self.checkpoint_dir = checkpoint_dir 308 | self.model_name = model_name 309 | self.avg_type = avg_type 310 | assert avg_type in ['avg', 'linear', 'exp'] 311 | self.early_stop_patience = early_stop_patience 312 | self.polyak_start = polyak_start 313 | 314 | def on_train_begin(self, logs=None): 315 | self.epoch = 0 316 | self.polyak_n = 0 317 | 318 | def on_epoch_end(self, epoch, logs=None): 319 | if (self.epoch + 1) >= self.polyak_start: 320 | self.model.save_weights(os.path.join(self.checkpoint_dir, '{}_polyak_{}_{}.hdf5'.format(self.model_name, self.avg_type, self.polyak_n))) 321 | self.polyak_n += 1 322 | 323 | self.epoch += 1 324 | 325 | def on_train_end(self, logs=None): 326 | self.polyak_n -= self.early_stop_patience 327 | if self.avg_type == 'avg': 328 | contrib = [1.0 / self.polyak_n for _ in range(self.polyak_n)] 329 | elif self.avg_type == 'linear': 330 | contrib = [i / self.polyak_n for i in range(1, self.polyak_n+1)] 331 | norm = sum(contrib) 332 | contrib = [i / norm for i in contrib] 333 | elif self.avg_type == 'exp': 334 | alpha = 2.0 335 | contrib = [math.exp(-i/alpha) for i in range(self.polyak_n, 0, -1)] 336 | norm = sum(contrib) 337 | contrib = [i / norm for i in contrib] 338 | else: 339 | raise ValueError('avg_type not understood: {}'.format(self.avg_type)) 340 | for polyak_id in range(self.polyak_n): 341 | print('Logging Info - Loading Polyak model checkpoint: {}_polyak_{}_{}.hdf5'.format(self.model_name, self.avg_type, polyak_id)) 342 | self.temp_polyak_model.load_weights(os.path.join(self.checkpoint_dir, '{}_polyak_{}_{}.hdf5'.format(self.model_name, self.avg_type, polyak_id))) 343 | for layer, temp_layer in zip(self.polyak_model.layers, self.temp_polyak_model.layers): 344 | weights = [] 345 | for w1, w2 in zip(layer.get_weights(), temp_layer.get_weights()): 346 | if polyak_id == 0: 347 | weights.append(contrib[polyak_id] * w2) 348 | else: 349 | weights.append(w1 + contrib[polyak_id] * w2) 350 | layer.set_weights(weights) 351 | os.remove(os.path.join(self.checkpoint_dir, '{}_polyak_{}_{}.hdf5'.format(self.model_name, self.avg_type, polyak_id))) 352 | 353 | for polyak_id in range(self.polyak_n, self.polyak_n+self.early_stop_patience): 354 | os.remove(os.path.join(self.checkpoint_dir, '{}_polyak_{}_{}.hdf5'.format(self.model_name, self.avg_type, polyak_id))) 355 | 356 | print('Logging Info - Saving Polyak model checkpoint: %s_polayk.hdf5\n' % self.model_name) 357 | self.polyak_model.save_weights(os.path.join(self.checkpoint_dir, '{}_polyak_{}.hdf5'.format(self.model_name, self.avg_type))) 358 | print('Logging Info - Polyak model Saved') 359 | 
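The ensembling callbacks above are normally wired into training through `config.callbacks_to_add` in trainer.py, but they can also be attached directly to a Keras `fit` call. Below is a minimal usage sketch for SWA; because (as noted in `SWA.on_train_begin`) cloning a model with custom layers inside the callback can fail, the averaged model is built outside and passed in. `build_model()`, `train_gen` and `valid_gen` are hypothetical stand-ins for the repo's model factory and data generators, and the model name is illustrative only.

```python
from callbacks import SWA
from config import MODEL_SAVED_DIR

# build two copies of the same architecture: one to train, one to hold the running weight average
model = build_model()      # hypothetical helper that builds and compiles the classification model
swa_model = build_model()

swa = SWA(swa_model=swa_model,
          checkpoint_dir=MODEL_SAVED_DIR,
          model_name='bert-base-uncased_name_desc',  # hypothetical experiment name
          swa_start=10)                              # matches the ModelConfig.swa_start default

model.fit(train_gen, validation_data=valid_gen, epochs=50, callbacks=[swa])
# on_train_end saves the averaged weights to '<model_name>_swa.hdf5' under checkpoint_dir
```

The same pattern applies to `SWAWithCLR` and `PolyakAverage`, which likewise expect pre-built model copies to be passed in as arguments.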
-------------------------------------------------------------------------------- /callbacks/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: alexyang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: lr_scheduler.py 10 | 11 | @time: 2019/4/13 20:28 12 | 13 | @desc: different learning rate scheduler 14 | 15 | """ 16 | import os 17 | import math 18 | import numpy as np 19 | import tensorflow.keras.backend as K 20 | from tensorflow.keras.callbacks import Callback 21 | 22 | from config import IMG_DIR 23 | 24 | # Force matplotlib to not use any Xwindows backend. 25 | # See: https://stackoverflow.com/questions/2801882/generating-a-png-with-matplotlib-when-display-is-undefined 26 | import matplotlib 27 | matplotlib.use('Agg') 28 | import matplotlib.pyplot as plt 29 | 30 | 31 | class LRScheduler(Callback): 32 | def __init__(self, alpha=0.2, sma=20, plot=False, save_plot_prefix=None): 33 | """ 34 | :param alpha: parameter for exponential moving average 35 | :param sma: number of batches for simple moving average 36 | :param plot: whether to plot figures to show the results of the experiment 37 | :param save_plot_prefix: where to save the figure 38 | """ 39 | super(LRScheduler, self).__init__() 40 | self.history = {} 41 | self.alpha = alpha 42 | self.sma = sma 43 | self.plot = plot 44 | self.save_plot_prefix = save_plot_prefix 45 | 46 | def update_lr(self): 47 | raise NotImplementedError 48 | 49 | def on_train_end(self, logs=None): 50 | if self.plot: 51 | self.plot_lr() 52 | self.plot_loss(valid_loss=False) # plot train loss 53 | self.plot_loss(valid_loss=True) # plot valid loss 54 | self.plot_acc(valid_acc=False) # plot train acc 55 | self.plot_acc(valid_acc=True) # plot valid acc 56 | 57 | def plot_lr(self): 58 | """ 59 | plot the learning rate w.r.t iteration 60 | """ 61 | if 'iteration' in self.history: 62 | iteration = self.history['iteration'] 63 | lr = self.history['lr'] 64 | if self.save_plot_prefix: 65 | lr_save_path = os.path.join(IMG_DIR, self.save_plot_prefix+'_lr.png') 66 | else: 67 | lr_save_path = None 68 | self.plot_figure(iteration, lr, 'iteration', 'learning rate', yscale='log', save_path=lr_save_path) 69 | 70 | def plot_loss(self, valid_loss: bool): 71 | """ 72 | plot the loss w.r.t learning rate 73 | """ 74 | loss_key = 'val_loss' if valid_loss else 'loss' 75 | if loss_key not in self.history: 76 | return 77 | lr = self.history['lr'] 78 | loss = self.history[loss_key] 79 | avg_loss = self.exp_move_avg(loss, self.alpha) 80 | 81 | loss_derivates = [0] * self.sma 82 | for i in range(self.sma, len(loss)): 83 | loss_derivates.append((loss[i] - loss[i - self.sma]) / self.sma) 84 | 85 | if self.save_plot_prefix: 86 | loss_save_path = os.path.join(IMG_DIR, self.save_plot_prefix+'_%s_lr.png' % loss_key) 87 | avg_loss_save_path = os.path.join(IMG_DIR, self.save_plot_prefix+'_avg_%s_lr.png' % loss_key) 88 | loss_derivate_save_path = os.path.join(IMG_DIR, self.save_plot_prefix+'_%s_derivate_lr.png' % loss_key) 89 | else: 90 | loss_save_path, avg_loss_save_path, loss_derivate_save_path = None, None, None 91 | 92 | self.plot_figure(lr, loss, 'learning rate', 'loss', xscale='log', save_path=loss_save_path) 93 | self.plot_figure(lr, avg_loss, 'learning rate', 'exp_move_avg loss', xscale='log', 94 | save_path=avg_loss_save_path) 95 | self.plot_figure(lr, loss_derivates, 'learning rate', 'rate of loss change', xscale='log', 96 | save_path=loss_derivate_save_path) 97 | 98 | def 
plot_acc(self, valid_acc: bool): 99 | """ 100 | plot the accuracy with respect to learning rate 101 | """ 102 | acc_key = 'val_acc' if valid_acc else 'acc' 103 | if acc_key in self.history: 104 | lr = self.history['lr'] 105 | acc = self.history[acc_key] 106 | avg_acc = self.exp_move_avg(acc, self.alpha) 107 | if self.save_plot_prefix: 108 | acc_save_path = os.path.join(IMG_DIR, self.save_plot_prefix+'_%s_lr.png' % acc_key) 109 | avg_acc_save_path = os.path.join(IMG_DIR, self.save_plot_prefix+'_avg_%s_lr.png' % acc_key) 110 | else: 111 | acc_save_path, avg_acc_save_path = None, None 112 | self.plot_figure(lr, acc, 'learning rate', 'acc', xscale='log', save_path=acc_save_path) 113 | self.plot_figure(lr, avg_acc, 'learing_rate', 'exp_move_avg acc', xscale='log', save_path=avg_acc_save_path) 114 | 115 | @staticmethod 116 | def plot_figure(x, y, xlabel, ylabel, xscale=None, yscale=None, show=False, save_path=None): 117 | plt.clf() 118 | plt.ylabel(ylabel) 119 | plt.xlabel(xlabel) 120 | plt.plot(x, y) 121 | 122 | if xscale is not None: 123 | plt.xscale(xscale) 124 | if yscale is not None: 125 | plt.yscale(yscale) 126 | 127 | if show: 128 | plt.show() 129 | if save_path: 130 | plt.savefig(save_path) 131 | print('Logging Info - Plot Figure has save to:', save_path) 132 | 133 | @staticmethod 134 | def exp_move_avg(_list, alpha): 135 | """exponential moving average""" 136 | cur_value = 0 137 | avg_list = [] 138 | for i in _list: 139 | cur_value = (1-alpha)*cur_value + alpha*i 140 | avg_list.append(cur_value) 141 | avg_list[0] = avg_list[1] 142 | return avg_list 143 | 144 | 145 | # Note: Modify based on https://github.com/surmenok/keras_lr_finder/blob/master/keras_lr_finder/lr_finder.py 146 | class LRRangeTest(LRScheduler): 147 | def __init__(self, num_batches, start_lr=1e-5, end_lr=1., alpha=0.2, sma=20, plot=True, save_plot_prefix=None): 148 | """ 149 | This callback implements LR Range Test as presented in section 3.3 of the 2015 paper - 150 | "Cyclical Learning Rates for Training Neural Networks" (https://arxiv.org/abs/1506.01186), which is used to 151 | estimate an optimal learning rate for dnn. 152 | 153 | This idea behind is to train a network starting from a low learning rate and increase the learning rate 154 | exponentially for every mini-batch. Record the learning rate and training loss for every batch. Then, plot the 155 | loss (also the rate of change of the loss) and the learning rate. Then select a lr range with the fastest 156 | decrease in the loss. 157 | 158 | For more detail, please see the paper and this blog: 159 | https://towardsdatascience.com/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0. 
160 | 161 | # Example 162 | ```python 163 | lrtest = LRRangeTest(num_batches=128) 164 | model.fit(X_train, Y_train, callbacks=[lrtest]) 165 | ``` 166 | :param num_batches: num of iterations to run lr range test 167 | :param start_lr: the lower bound of the lr range 168 | :param end_lr: the upper bpound of the lr range 169 | :param alpha: parameter for exponential moving average 170 | :param sma: number of batches for simple moving average 171 | :param plot: whether to plot figures to show the results of the experiment 172 | :param save_plot_prefix: where to save the figure 173 | """ 174 | self.start_lr = start_lr 175 | self.end_lr = end_lr 176 | self.lr_mult = (float(end_lr) / float(start_lr)) ** (float(1) / float(num_batches)) # exponentially increase 177 | self.best_loss = 1e9 178 | self.trn_iterations = 0 179 | super(LRRangeTest, self).__init__(alpha, sma, plot, 180 | save_plot_prefix + '_lr_range_test' if save_plot_prefix else None) 181 | 182 | def on_train_begin(self, logs=None): 183 | # set initial learning rate 184 | K.set_value(self.model.optimizer.lr, self.start_lr) 185 | 186 | def on_batch_end(self, epoch, logs=None): 187 | logs = logs or {} 188 | self.trn_iterations += 1 189 | 190 | self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr)) 191 | self.history.setdefault('iteration', []).append(self.trn_iterations) 192 | for k, v in logs.items(): 193 | self.history.setdefault(k, []).append(v) 194 | 195 | loss = logs['loss'] 196 | if math.isnan(loss) or loss > self.best_loss * 4: 197 | # stop training whe the loss gets a lot higher than the previously observed best value 198 | print('Loggin Info - LR Range Test training end!') 199 | self.model.stop_training = True 200 | return 201 | 202 | if loss < self.best_loss: 203 | self.best_loss = loss 204 | self.update_lr() 205 | 206 | def update_lr(self): 207 | # increse the learning rate exponentially for next batch 208 | lr = K.get_value(self.model.optimizer.lr) 209 | lr *= self.lr_mult 210 | K.set_value(self.model.optimizer.lr, lr) 211 | 212 | 213 | # Note: Copy from https://github.com/bckenstler/CLR/blob/master/clr_callback.py 214 | class CyclicLR(LRScheduler): 215 | """This callback implements a cyclical learning rate policy (CLR). 216 | The method cycles the learning rate between two boundaries with 217 | some constant frequency, as detailed in this paper (https://arxiv.org/abs/1506.01186). 218 | The amplitude of the cycle can be scaled on a per-iteration or 219 | per-cycle basis. 220 | This class has three built-in policies, as put forth in the paper. 221 | "triangular": 222 | A basic triangular cycle w/ no amplitude scaling. 223 | "triangular2": 224 | A basic triangular cycle that scales initial amplitude by half each cycle. 225 | "exp_range": 226 | A cycle that scales initial amplitude by gamma**(cycle iterations) at each 227 | cycle iteration. 228 | For more detail, please see paper. 229 | 230 | # Example 231 | ```python 232 | clr = CyclicLR(base_lr=0.001, max_lr=0.006, 233 | step_size=2000., mode='triangular') 234 | model.fit(X_train, Y_train, callbacks=[clr]) 235 | ``` 236 | 237 | Class also supports custom scaling functions: 238 | ```python 239 | clr_fn = lambda x: 0.5*(1+np.sin(x*np.pi/2.)) 240 | clr = CyclicLR(base_lr=0.001, max_lr=0.006, 241 | step_size=2000., scale_fn=clr_fn, 242 | scale_mode='cycle') 243 | model.fit(X_train, Y_train, callbacks=[clr]) 244 | ``` 245 | # Arguments 246 | base_lr: initial learning rate which is the 247 | lower boundary in the cycle. 
248 | max_lr: upper boundary in the cycle. Functionally, 249 | it defines the cycle amplitude (max_lr - base_lr). 250 | The lr at any cycle is the sum of base_lr 251 | and some scaling of the amplitude; therefore 252 | max_lr may not actually be reached depending on 253 | scaling function. 254 | step_size: number of training iterations per 255 | half cycle. Authors suggest setting step_size 256 | 2-8 x training iterations in epoch. 257 | mode: one of {triangular, triangular2, exp_range}. 258 | Default 'triangular'. 259 | Values correspond to policies detailed above. 260 | If scale_fn is not None, this argument is ignored. 261 | gamma: constant in 'exp_range' scaling function: 262 | gamma**(cycle iterations) 263 | scale_fn: Custom scaling policy defined by a single 264 | argument lambda function, where 265 | 0 <= scale_fn(x) <= 1 for all x >= 0. 266 | mode paramater is ignored 267 | scale_mode: {'cycle', 'iterations'}. 268 | Defines whether scale_fn is evaluated on 269 | cycle number or cycle iterations (training 270 | iterations since start of cycle). Default is 'cycle'. 271 | alpha: parameter for exponential moving average 272 | sma: number of batches for simple moving average 273 | """ 274 | 275 | def __init__(self, base_lr=0.001, max_lr=0.006, step_size=2000., mode='triangular', 276 | gamma=1., scale_fn=None, scale_mode='cycle', alpha=0.2, sma=20, plot=False, save_plot_prefix=None): 277 | self.base_lr = base_lr 278 | self.max_lr = max_lr 279 | self.step_size = step_size 280 | self.mode = mode 281 | self.gamma = gamma 282 | if scale_fn is None: 283 | if self.mode == 'triangular': 284 | self.scale_fn = lambda x: 1. 285 | self.scale_mode = 'cycle' 286 | elif self.mode == 'triangular2': 287 | self.scale_fn = lambda x: 1 / (2. ** (x - 1)) 288 | self.scale_mode = 'cycle' 289 | elif self.mode == 'exp_range': 290 | self.scale_fn = lambda x: gamma ** (x - 1) 291 | self.scale_mode = 'iterations' 292 | else: 293 | self.scale_fn = scale_fn 294 | self.scale_mode = scale_mode 295 | self.clr_iterations = 0. 296 | super(CyclicLR, self).__init__(alpha, sma, plot, 297 | save_plot_prefix + '_cyclic' if save_plot_prefix else None) 298 | 299 | def on_train_begin(self, logs={}): 300 | if self.clr_iterations == 0: 301 | K.set_value(self.model.optimizer.lr, self.base_lr) 302 | else: 303 | K.set_value(self.model.optimizer.lr, self.update_lr()) 304 | 305 | def on_batch_end(self, epoch, logs=None): 306 | logs = logs or {} 307 | self.clr_iterations += 1 308 | 309 | self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr)) 310 | self.history.setdefault('iteration', []).append(self.clr_iterations) 311 | 312 | for k, v in logs.items(): 313 | self.history.setdefault(k, []).append(v) 314 | self.update_lr() 315 | 316 | def update_lr(self): 317 | cycle = np.floor(1 + self.clr_iterations / (2 * self.step_size)) 318 | x = np.abs(self.clr_iterations / self.step_size - 2 * cycle + 1) 319 | if self.scale_mode == 'cycle': 320 | lr = self.base_lr + (self.max_lr - self.base_lr) * np.maximum(0, (1 - x)) * self.scale_fn(cycle) 321 | else: 322 | lr = self.base_lr + (self.max_lr - self.base_lr) * np.maximum(0, (1 - x)) * self.scale_fn(self.clr_iterations) 323 | K.set_value(self.model.optimizer.lr, lr) 324 | 325 | 326 | class SGDR(LRScheduler): 327 | """ 328 | This Callback implements "stochastic gradient descent with warm restarts", a similar approach cyclic approach where 329 | an cosine annealing schedule is combined with periodic "restarts" to the original starting learning rate. 
330 | See "https://arxiv.org/pdf/1608.03983.pdf" for more details 331 | """ 332 | def __init__(self, min_lr, max_lr, cycle_length, lr_decay=1, mult_factor=2, alpha=0.2, sma=20, plot=False, 333 | save_plot_prefix=None): 334 | """ 335 | :param min_lr: 336 | :param max_lr: max_lr and min_lr define the range of desired learning rates 337 | :param cycle_length: the number of iterations (1 mini-batch is 1 iteration) in an cycle 338 | :param lr_decay: reduce the max_lr after the completion of each cycle 339 | Ex. To reduce the max_lr by 20% after each cycle, set this value to 0.8. 340 | :param mult_factor: scale cycle_length after each full cycle completion 341 | :param alpha: parameter for exponential moving average 342 | :param sma: number of batches for simple moving average 343 | :param plot: whether to plot figures to show the results of the experiment 344 | :param save_plot_prefix: where to save the figure 345 | """ 346 | self.min_lr = min_lr 347 | self.max_lr = max_lr 348 | self.cycle_length = cycle_length 349 | self.lr_decay = lr_decay 350 | self.mult_factor = mult_factor 351 | self.trn_iterations = 0 # the number of iteration since training 352 | self.cur_iterations = 0 # the number of iterations since the last restart 353 | 354 | super(SGDR, self).__init__(alpha, sma, plot, save_plot_prefix + '_sgdr' if save_plot_prefix else None) 355 | 356 | def on_train_begin(self, logs=None): 357 | K.set_value(self.model.optimizer.lr, self.max_lr) 358 | 359 | def on_batch_end(self, batch, logs=None): 360 | logs = logs or {} 361 | self.trn_iterations += 1 362 | self.cur_iterations += 1 363 | 364 | self.history.setdefault('iteration', []).append(self.trn_iterations) 365 | self.history.setdefault('lr', []).append(self.model.optimizer.lr) 366 | for k, v in logs.items(): 367 | self.history.setdefault(k, []).append(v) 368 | 369 | self.update_lr() 370 | 371 | def update_lr(self): 372 | fraction_to_restart = self.cur_iterations / self.cycle_length 373 | lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(fraction_to_restart * math.pi)) 374 | K.set_value(self.model.optimizer.lr, lr) 375 | 376 | '''Check for end of current cycle''' 377 | if fraction_to_restart == 1.: 378 | self.cur_iterations = 0 379 | self.max_lr = self.lr_decay * self.max_lr 380 | self.cycle_length = math.ceil(self.cycle_length * self.mult_factor) 381 | 382 | 383 | # copy from https://gist.github.com/jeremyjordan/5a222e04bb78c242f5763ad40626c452#file-sgdr-py 384 | class SGDRScheduler(LRScheduler): 385 | """Cosine annealing learning rate scheduler with periodic restarts. 386 | # Usage 387 | ```python 388 | schedule = SGDRScheduler(min_lr=1e-5, 389 | max_lr=1e-2, 390 | steps_per_epoch=np.ceil(epoch_size/batch_size), 391 | lr_decay=0.9, 392 | cycle_length=5, 393 | mult_factor=1.5) 394 | model.fit(X_train, Y_train, epochs=100, callbacks=[schedule]) 395 | ``` 396 | # Arguments 397 | min_lr: The lower bound of the learning rate range for the experiment. 398 | max_lr: The upper bound of the learning rate range for the experiment. 399 | steps_per_epoch: Number of mini-batches in the dataset. Calculated as `np.ceil(epoch_size/batch_size)`. 400 | lr_decay: Reduce the max_lr after the completion of each cycle. 401 | Ex. To reduce the max_lr by 20% after each cycle, set this value to 0.8. 402 | cycle_length: Initial number of epochs in a cycle. 403 | mult_factor: Scale epochs_to_restart after each full cycle completion. 
404 | # References 405 | Blog post: jeremyjordan.me/nn-learning-rate 406 | Original paper: http://arxiv.org/abs/1608.03983 407 | """ 408 | def __init__(self, min_lr, max_lr, steps_per_epoch, lr_decay=1, cycle_length=10, mult_factor=2, 409 | alpha=0.2, sma=20, plot=False, save_plot_prefix=None): 410 | self.min_lr = min_lr 411 | self.max_lr = max_lr 412 | self.lr_decay = lr_decay 413 | 414 | self.batch_since_restart = 0 415 | self.next_restart = cycle_length 416 | 417 | self.steps_per_epoch = steps_per_epoch 418 | 419 | self.cycle_length = cycle_length 420 | self.mult_factor = mult_factor 421 | 422 | super(SGDRScheduler, self).__init__(alpha, sma, plot, save_plot_prefix + '_sgdr2' if save_plot_prefix else None) 423 | 424 | def on_train_begin(self, logs={}): 425 | # Initialize the learning rate to the minimum value at the start of training. 426 | K.set_value(self.model.optimizer.lr, self.max_lr) 427 | 428 | def on_batch_end(self, batch, logs={}): 429 | # Record previous batch statistics and update the learning rate. 430 | logs = logs or {} 431 | self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr)) 432 | for k, v in logs.items(): 433 | self.history.setdefault(k, []).append(v) 434 | 435 | self.batch_since_restart += 1 436 | 437 | def on_epoch_end(self, epoch, logs={}): 438 | # Check for end of current cycle, apply restarts when necessary. 439 | if epoch + 1 == self.next_restart: 440 | self.batch_since_restart = 0 441 | self.cycle_length = np.ceil(self.cycle_length * self.mult_factor) 442 | self.next_restart += self.cycle_length 443 | self.max_lr *= self.lr_decay 444 | 445 | def update_lr(self): 446 | # Calculate the learning rate. 447 | fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length) 448 | lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi)) 449 | K.set_value(self.model.optimizer.lr, lr) 450 | 451 | 452 | class CyclicLR_1(LRScheduler): 453 | """ 454 | This Callback implements a cyclic learning schedule introduced in the paper - 455 | "Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs" (https://arxiv.org/pdf/1802.10026.pdf) 456 | """ 457 | def __init__(self, min_lr, max_lr, cycle_length, alpha=0.2, sma=20, plot=False, save_plot_prefix=None): 458 | self.min_lr = min_lr 459 | self.max_lr = max_lr 460 | self.cycle_length = cycle_length 461 | self.trn_iteration = 0 462 | super(CyclicLR_1, self).__init__(alpha, sma, plot, 463 | save_plot_prefix + '_cyclic_1' if save_plot_prefix else None) 464 | 465 | def on_train_begin(self, logs=None): 466 | K.set_value(self.model.optimizer.lr, self.max_lr) 467 | 468 | def on_batch_end(self, batch, logs=None): 469 | logs = logs or {} 470 | self.trn_iteration += 1 471 | 472 | self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr)) 473 | self.history.setdefault('iteration', []).append(self.trn_iteration) 474 | 475 | for k, v in logs.items(): 476 | self.history.setdefault(k, []).append(v) 477 | self.update_lr() 478 | 479 | def update_lr(self): 480 | t = (self.trn_iteration % self.cycle_length) / self.cycle_length 481 | if t <= 0.5: 482 | lr = (1 - 2 * t) * self.max_lr + 2 * t * self.min_lr 483 | else: 484 | lr = (2 - 2 * t) * self.max_lr + (2 * t - 1) * self.min_lr 485 | K.set_value(self.model.optimizer.lr, lr) 486 | 487 | 488 | class CyclicLR_2(LRScheduler): 489 | """ 490 | This Callback implements a cyclic learning schedule introduced in the paper - 491 | "AveragingWeights Leads to Wider Optima and Better Generalization" 
(https://arxiv.org/pdf/1803.05407.pdf) 492 | """ 493 | def __init__(self, min_lr, max_lr, cycle_length, alpha=0.2, sma=20, plot=False, save_plot_prefix=None): 494 | self.min_lr = min_lr 495 | self.max_lr = max_lr 496 | self.cycle_length = cycle_length 497 | self.trn_iteration = 0 498 | super(CyclicLR_2, self).__init__(alpha, sma, plot, 499 | save_plot_prefix + '_cyclic_1' if save_plot_prefix else None) 500 | 501 | def on_train_begin(self, logs=None): 502 | K.set_value(self.model.optimizer.lr, self.max_lr) 503 | 504 | def on_batch_end(self, batch, logs=None): 505 | logs = logs or {} 506 | self.trn_iteration += 1 507 | 508 | self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr)) 509 | self.history.setdefault('iteration', []).append(self.trn_iteration) 510 | 511 | for k, v in logs.items(): 512 | self.history.setdefault(k, []).append(v) 513 | self.update_lr() 514 | 515 | def update_lr(self): 516 | t = (self.trn_iteration % self.cycle_length) / self.cycle_length 517 | lr = (1 - t) * self.max_lr + t * self.min_lr 518 | K.set_value(self.model.optimizer.lr, lr) 519 | 520 | 521 | class WarmUp(Callback): 522 | def __init__(self, learning_rate=5e-5, min_learning_rate=2e-5): 523 | super(WarmUp, self).__init__() 524 | self.passed = 0. 525 | self.learning_rate = learning_rate 526 | self.min_learning_rate = min_learning_rate 527 | 528 | def on_batch_begin(self, batch, logs=None): 529 | """第一个epoch用来warmup,第二个epoch把学习率降到最低 530 | """ 531 | if self.passed < self.params['steps']: 532 | lr = (self.passed + 1.) / self.params['steps'] * self.learning_rate 533 | K.set_value(self.model.optimizer.lr, lr) 534 | self.passed += 1 535 | elif self.params['steps'] <= self.passed < self.params['steps'] * 2: 536 | lr = (2 - (self.passed + 1.) / self.params['steps']) * ( 537 | self.learning_rate - self.min_learning_rate) 538 | lr += self.min_learning_rate 539 | K.set_value(self.model.optimizer.lr, lr) 540 | self.passed += 1 541 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | 5 | @author: Alex Yang 6 | 7 | @contact: alex.yang0326@gmail.com 8 | 9 | @file: train.py 10 | 11 | @time: 2020/5/21 22:06 12 | 13 | @desc: 14 | 15 | """ 16 | import os 17 | import time 18 | import gc 19 | 20 | import numpy as np 21 | import tensorflow as tf 22 | import tensorflow.keras.backend as K 23 | 24 | from config import ModelConfig, PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, VOCABULARY_TEMPLATE, \ 25 | PERFORMANCE_MD, \ 26 | MAX_LEN, IDX2TOKEN_TEMPLATE, CATE1_TO_CATE2_DICT, CATE2_TO_CATE3_DICT, CATE1_TO_CATE3_DICT, \ 27 | CATE1_COUNT_DICT, \ 28 | CATE2_COUNT_DICT, CATE3_COUNT_DICT 29 | from utils import format_filename, pickle_load, writer_md 30 | from models import MultiTaskClsModel 31 | from data_loader import MultiTaskClsDataGenerator 32 | 33 | 34 | def prepare_config(model_type='bert-base-uncased', 35 | input_type='name_desc', 36 | use_multi_task=True, 37 | use_harl=False, 38 | use_hal=False, 39 | cate_embed_dim=100, 40 | use_word_input=False, 41 | word_embed_type='w2v', 42 | word_embed_trainable=True, 43 | word_embed_dim=300, 44 | use_bert_input=True, 45 | bert_trainable=True, 46 | use_bert_type='pooler', 47 | n_last_hidden_layer=0, 48 | dense_after_bert=True, 49 | use_pair_input=True, 50 | max_len=None, 51 | share_father_pred='no', 52 | use_mask_for_cate2=False, 53 | use_mask_for_cate3=True, 54 | cate3_mask_type='cate1', 55 | 
cate1_loss_weight=1., 56 | cate2_loss_weight=1., 57 | cate3_loss_weight=1., 58 | batch_size=32, 59 | predict_batch_size=32, 60 | n_epoch=50, 61 | learning_rate=2e-5, 62 | optimizer='adam', 63 | use_focal_loss=False, 64 | callbacks_to_add=None, 65 | swa_start=15, 66 | early_stopping_patience=5, 67 | max_lr=6e-5, 68 | min_lr=1e-5, 69 | train_on_cv=False, 70 | cv_random_state=42, 71 | cv_fold=5, 72 | cv_index=0, 73 | exchange_pair=False, 74 | exchange_threshold=0.1, 75 | use_pseudo_label=False, 76 | pseudo_path=None, 77 | pseudo_random_state=42, 78 | pseudo_rate=0.1, 79 | pseudo_index=0, 80 | pseudo_name=None, 81 | exp_name=None): 82 | config = ModelConfig() 83 | config.model_type = model_type 84 | config.input_type = input_type 85 | config.use_multi_task = use_multi_task 86 | config.use_harl = use_harl 87 | config.use_hal = use_hal 88 | assert not (config.use_harl and config.use_hal) 89 | config.cate_embed_dim = cate_embed_dim 90 | 91 | config.use_word_input = use_word_input 92 | config.word_embed_type = word_embed_type 93 | if config.use_word_input: 94 | if word_embed_type: 95 | config.word_embeddings = np.load( 96 | format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE, 97 | type=word_embed_type)) 98 | config.word_embed_trainable = word_embed_trainable 99 | config.word_embed_dim = config.word_embeddings.shape[1] 100 | else: 101 | config.word_embeddings = None 102 | config.word_embed_trainable = True 103 | config.word_embed_dim = word_embed_dim 104 | config.word_vocab = pickle_load( 105 | format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='word')) 106 | config.word_vocab_size = len(config.word_vocab) + 2 # 0: mask, 1: padding 107 | else: 108 | config.word_vocab = None 109 | 110 | config.use_bert_input = use_bert_input 111 | config.bert_trainable = bert_trainable 112 | if config.use_bert_input: 113 | config.use_bert_type = use_bert_type 114 | config.dense_after_bert = dense_after_bert 115 | if config.use_bert_type in ['hidden', 'hidden_pooler'] or \ 116 | (config.use_multi_task and (config.use_harl or config.use_hal)): 117 | config.output_hidden_state = True 118 | config.n_last_hidden_layer = n_last_hidden_layer 119 | else: 120 | config.output_hidden_state = False 121 | config.n_last_hidden_layer = 0 122 | if config.input_type == 'name_desc': 123 | config.use_pair_input = use_pair_input 124 | else: 125 | config.use_pair_input = False 126 | 127 | if config.use_bert_input and max_len is None: 128 | config.max_len = MAX_LEN[input_type] 129 | else: 130 | config.max_len = max_len 131 | 132 | config.cate1_vocab = pickle_load( 133 | format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate1')) 134 | config.cate2_vocab = pickle_load( 135 | format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate2')) 136 | config.cate3_vocab = pickle_load( 137 | format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='cate3')) 138 | config.all_cate_vocab = pickle_load( 139 | format_filename(PROCESSED_DATA_DIR, VOCABULARY_TEMPLATE, level='all_cate')) 140 | config.idx2cate1 = pickle_load( 141 | format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate1')) 142 | config.idx2cate2 = pickle_load( 143 | format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate2')) 144 | config.idx2cate3 = pickle_load( 145 | format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='cate3')) 146 | config.idx2all_cate = pickle_load( 147 | format_filename(PROCESSED_DATA_DIR, IDX2TOKEN_TEMPLATE, level='all_cate')) 148 | config.cate1_to_cate2 = 
pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE2_DICT)) 149 | config.cate2_to_cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT)) 150 | config.cate1_to_cate3 = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT)) 151 | config.cate1_count_dict = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE1_COUNT_DICT)) 152 | config.cate2_count_dict = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE2_COUNT_DICT)) 153 | config.cate3_count_dict = pickle_load(format_filename(PROCESSED_DATA_DIR, CATE3_COUNT_DICT)) 154 | config.n_cate1 = len(config.cate1_vocab) 155 | config.n_cate2 = len(config.cate2_vocab) 156 | config.n_cate3 = len(config.cate3_vocab) 157 | config.n_all_cate = len(config.all_cate_vocab) 158 | 159 | if config.use_multi_task and (config.use_harl or config.use_hal): 160 | config.share_father_pred = 'no' 161 | config.use_mask_for_cate2 = False 162 | config.use_mask_for_cate3 = False 163 | config.cate3_mask_type = None 164 | else: 165 | config.share_father_pred = share_father_pred 166 | config.use_mask_for_cate2 = use_mask_for_cate2 167 | config.use_mask_for_cate3 = use_mask_for_cate3 168 | config.cate3_mask_type = cate3_mask_type 169 | # if config.use_mask_for_cate2: 170 | if config.use_mask_for_cate3: 171 | if config.cate3_mask_type == 'cate1': 172 | config.cate_to_cate3 = pickle_load( 173 | format_filename(PROCESSED_DATA_DIR, CATE1_TO_CATE3_DICT)) 174 | elif config.cate3_mask_type == 'cate2': 175 | config.cate_to_cate3 = pickle_load( 176 | format_filename(PROCESSED_DATA_DIR, CATE2_TO_CATE3_DICT)) 177 | config.cate1_loss_weight = cate1_loss_weight 178 | config.cate2_loss_weight = cate2_loss_weight 179 | config.cate3_loss_weight = cate3_loss_weight 180 | 181 | config.batch_size = batch_size 182 | config.predict_batch_size = predict_batch_size 183 | config.n_epoch = n_epoch 184 | config.learning_rate = learning_rate 185 | config.optimizer = optimizer 186 | config.learning_rate = learning_rate 187 | config.use_focal_loss = use_focal_loss 188 | config.callbacks_to_add = callbacks_to_add or ['modelcheckpoint', 'earlystopping'] 189 | if 'swa' in config.callbacks_to_add: 190 | config.swa_start = swa_start 191 | config.early_stopping_patience = early_stopping_patience 192 | for lr_scheduler in ['clr', 'sgdr', 'clr_1', 'clr_2', 'warm_up', 'swa_clr']: 193 | if lr_scheduler in config.callbacks_to_add: 194 | config.max_lr = max_lr 195 | config.min_lr = min_lr 196 | 197 | config.train_on_cv = train_on_cv 198 | if config.train_on_cv: 199 | config.cv_random_state = cv_random_state 200 | config.cv_fold = cv_fold 201 | config.cv_index = cv_index 202 | 203 | config.exchange_pair = exchange_pair 204 | if config.exchange_pair: 205 | config.exchange_threshold = exchange_threshold 206 | 207 | config.use_pseudo_label = use_pseudo_label 208 | if config.use_pseudo_label: 209 | config.pseudo_path = pseudo_path 210 | config.pseudo_random_state = pseudo_random_state 211 | config.pseudo_rate = pseudo_rate 212 | config.pseudo_index = pseudo_index 213 | 214 | # build experiment name from parameter configuration 215 | config.exp_name = f'{config.model_type}_{config.input_type}' 216 | if config.use_pair_input: 217 | config.exp_name += '_pair' 218 | config.exp_name += f'_len_{config.max_len}' 219 | if config.use_word_input: 220 | config.exp_name += f"_word_{config.word_embed_type}_{'tune' if config.word_embed_trainable else 'fix'}" 221 | if config.use_bert_input: 222 | config.exp_name += f"_bert_{config.use_bert_type}_{'tune' if config.bert_trainable else 
'fix'}"
223 |     if config.output_hidden_state:
224 |         config.exp_name += f'_hid_{config.n_last_hidden_layer}'
225 |     if config.dense_after_bert:
226 |         config.exp_name += '_dense'
227 |     if config.use_multi_task:
228 |         if config.use_harl:
229 |             config.exp_name += f'_harl_{config.cate_embed_dim}'
230 |         elif config.use_hal:
231 |             config.exp_name += f'_hal_{config.cate_embed_dim}'
232 |         config.exp_name += f'_{config.cate1_loss_weight}_{config.cate2_loss_weight}_{config.cate3_loss_weight}'
233 |     else:
234 |         config.exp_name += f'_not_multi_task'
235 |     if config.share_father_pred in ['after', 'before']:
236 |         config.exp_name += f'_{config.share_father_pred}'
237 |     if config.use_mask_for_cate2:
238 |         config.exp_name += f'_mask_cate2'
239 |     if config.use_mask_for_cate3:
240 |         config.exp_name += f'_mask_cate3_with_{config.cate3_mask_type}'
241 |     if config.use_focal_loss:
242 |         config.exp_name += f'_focal'
243 |     config.exp_name += f'_{config.optimizer}_{config.learning_rate}_{config.batch_size}_{config.n_epoch}'
244 |     callback_str = '_' + '_'.join(config.callbacks_to_add)
245 |     callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
246 |     config.exp_name += callback_str
247 |     if config.train_on_cv:
248 |         config.exp_name += f'_{config.cv_random_state}_{config.cv_fold}_{config.cv_index}'
249 |     if config.exchange_pair:
250 |         config.exp_name += f"_ex_pair_{config.exchange_threshold}"
251 |     if config.use_pseudo_label:
252 |         if pseudo_name:
253 |             config.exp_name += f"_{pseudo_name}_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}"
254 |         elif 'dev' in config.pseudo_path:
255 |             config.exp_name += f"_dev_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}"
256 |         else:
257 |             config.exp_name += f"_test_pseudo_{pseudo_random_state}_{pseudo_rate}_{pseudo_index}"
258 |
259 |     if exp_name:
260 |         config.exp_name = exp_name
261 |
262 |     return config
263 |
264 |
265 | def train(config: ModelConfig,
266 |           use_gpu_id=5):
267 |     # see: https://www.bookstack.cn/read/TensorFlow2.0/spilt.6.3b87bc87b85cbe5d.md
268 |     gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
269 |     tf.config.experimental.set_visible_devices(devices=gpus[use_gpu_id], device_type='GPU')
270 |     tf.config.experimental.set_memory_growth(device=gpus[use_gpu_id], enable=True)
271 |
272 |     print('Logging Info - Experiment: %s' % config.exp_name)
273 |     model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
274 |     model = MultiTaskClsModel[config.model_type](config)
275 |     model.summary()
276 |
277 |     train_generator = MultiTaskClsDataGenerator(data_type='train',
278 |                                                 batch_size=config.batch_size,
279 |                                                 use_multi_task=config.use_multi_task,
280 |                                                 input_type=config.input_type,
281 |                                                 use_word_input=config.use_word_input,
282 |                                                 word_vocab=config.word_vocab,
283 |                                                 use_bert_input=config.use_bert_input,
284 |                                                 use_pair_input=config.use_pair_input,
285 |                                                 bert_model_type=config.model_type,
286 |                                                 max_len=config.max_len,
287 |                                                 cate1_vocab=config.cate1_vocab,
288 |                                                 cate2_vocab=config.cate2_vocab,
289 |                                                 cate3_vocab=config.cate3_vocab,
290 |                                                 all_cate_vocab=config.all_cate_vocab,
291 |                                                 use_mask_for_cate2=config.use_mask_for_cate2,
292 |                                                 use_mask_for_cate3=config.use_mask_for_cate3,
293 |                                                 cate3_mask_type=config.cate3_mask_type,
294 |                                                 cate1_to_cate2=config.cate1_to_cate2,
295 |                                                 cate_to_cate3=config.cate_to_cate3,
296 |                                                 train_on_cv=config.train_on_cv,
297 |                                                 cv_random_state=config.cv_random_state,
298 |                                                 cv_fold=config.cv_fold,
299 |                                                 cv_index=config.cv_index,
300 |                                                 exchange_pair=config.exchange_pair,
301 |                                                 exchange_threshold=config.exchange_threshold,
302 |                                                 cate3_count_dict=config.cate3_count_dict,
303 |                                                 use_pseudo_label=config.use_pseudo_label,
304 |                                                 pseudo_path=config.pseudo_path,
305 |                                                 pseudo_random_state=config.pseudo_random_state,
306 |                                                 pseudo_rate=config.pseudo_rate,
307 |                                                 pseudo_index=config.pseudo_index
308 |                                                 )
309 |     valid_generator = MultiTaskClsDataGenerator(data_type='dev',
310 |                                                 batch_size=config.predict_batch_size,
311 |                                                 use_multi_task=True,
312 |                                                 input_type=config.input_type,
313 |                                                 use_word_input=config.use_word_input,
314 |                                                 word_vocab=config.word_vocab,
315 |                                                 use_bert_input=config.use_bert_input,
316 |                                                 use_pair_input=config.use_pair_input,
317 |                                                 bert_model_type=config.model_type,
318 |                                                 max_len=config.max_len,
319 |                                                 cate1_vocab=config.cate1_vocab,
320 |                                                 cate2_vocab=config.cate2_vocab,
321 |                                                 cate3_vocab=config.cate3_vocab,
322 |                                                 all_cate_vocab=config.all_cate_vocab,
323 |                                                 use_mask_for_cate2=config.use_mask_for_cate2,
324 |                                                 use_mask_for_cate3=config.use_mask_for_cate3,
325 |                                                 cate3_mask_type=config.cate3_mask_type,
326 |                                                 cate1_to_cate2=config.cate1_to_cate2,
327 |                                                 cate_to_cate3=config.cate_to_cate3,
328 |                                                 train_on_cv=config.train_on_cv,
329 |                                                 cv_random_state=config.cv_random_state,
330 |                                                 cv_fold=config.cv_fold,
331 |                                                 cv_index=config.cv_index
332 |                                                 )
333 |     test_generator = MultiTaskClsDataGenerator(data_type='test',
334 |                                                batch_size=config.predict_batch_size,
335 |                                                use_multi_task=True,
336 |                                                input_type=config.input_type,
337 |                                                use_word_input=config.use_word_input,
338 |                                                word_vocab=config.word_vocab,
339 |                                                use_bert_input=config.use_bert_input,
340 |                                                use_pair_input=config.use_pair_input,
341 |                                                bert_model_type=config.model_type,
342 |                                                max_len=config.max_len,
343 |                                                cate1_vocab=config.cate1_vocab,
344 |                                                cate2_vocab=config.cate2_vocab,
345 |                                                cate3_vocab=config.cate3_vocab,
346 |                                                all_cate_vocab=config.all_cate_vocab,
347 |                                                use_mask_for_cate2=config.use_mask_for_cate2,
348 |                                                use_mask_for_cate3=config.use_mask_for_cate3,
349 |                                                cate3_mask_type=config.cate3_mask_type,
350 |                                                cate1_to_cate2=config.cate1_to_cate2,
351 |                                                cate_to_cate3=config.cate_to_cate3)
352 |
353 |     train_logger = {}
354 |     if not os.path.exists(model_save_path):
355 |         start_time = time.time()
356 |         model.train(train_generator, valid_generator)
357 |         elapsed_time = time.time() - start_time
358 |         print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S",
359 |                                                                   time.gmtime(elapsed_time)))
360 |         train_logger['epoch'] = model.return_trained_epoch()
361 |         train_logger['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
362 |         train_logger['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
363 |
364 |     print('Logging Info - Loading best model...')
365 |     model.load_best_model()
366 |     print('Logging Info - Evaluating valid set...')
367 |     eval_results = model.evaluate(valid_generator,
368 |                                   save_diff=True,
369 |                                   save_prob=True,
370 |                                   prob_file=f'{config.exp_name}_dev_prob.pkl',
371 |                                   diff_file=f'{config.exp_name}_diff.txt')
372 |
373 |     print('Logging Info - Predicting test set...')
374 |     model.predict(test_generator,
375 |                   save_prob=True,
376 |                   prob_file=f'{config.exp_name}_test_prob.pkl',
377 |                   submit=True,
378 |                   submit_file=f'{config.exp_name}_submit.csv',
379 |                   submit_with_text=True)
380 |     if train_logger:
381 |         train_logger['eval_result'] = eval_results
382 |
383 |     swa_type = None
384 |     if 'swa' in config.callbacks_to_add:
385 |         swa_type = 'swa'
386 |     elif 'swa_clr' in config.callbacks_to_add:
387 |         swa_type = 'swa_clr'
388 |     if swa_type:
389 |         print('Logging Info - Loading swa model...')
390 |         model.load_swa_model(swa_type)
391 |         print('Logging Info - Evaluating valid set...')
392 |         swa_results = model.evaluate(valid_generator,
393 |                                      save_prob=True,
394 |                                      prob_file=f'{config.exp_name}_{swa_type}_dev_prob.pkl',
395 |                                      save_diff=True,
396 |                                      diff_file=f'{config.exp_name}_{swa_type}_diff.txt')
397 |         print('Logging Info - Predicting test set...')
398 |         model.predict(test_generator,
399 |                       save_prob=True,
400 |                       prob_file=f'{config.exp_name}_{swa_type}_test_prob.pkl',
401 |                       submit=True,
402 |                       submit_file=f'{config.exp_name}_{swa_type}_submit.csv',
403 |                       submit_with_text=True)
404 |         if train_logger:
405 |             train_logger['swa_result'] = swa_results
406 |
407 |     if train_logger:
408 |         writer_md(filename=PERFORMANCE_MD, config=config, trainer_logger=train_logger)
409 |
410 |     del model
411 |     gc.collect()
412 |     K.clear_session()
413 |
414 |
415 | def main(model_type='bert-base-uncased',
416 |          input_type='name_desc',
417 |          use_multi_task=True,
418 |          use_harl=False,
419 |          use_hal=False,
420 |          cate_embed_dim=100,
421 |          use_word_input=False,
422 |          word_embed_type='w2v',
423 |          word_embed_trainable=True,
424 |          word_embed_dim=300,
425 |          use_bert_input=True,
426 |          bert_trainable=True,
427 |          use_bert_type='pooler',
428 |          n_last_hidden_layer=0,
429 |          dense_after_bert=True,
430 |          use_pair_input=True,
431 |          max_len=None,
432 |          share_father_pred='no',
433 |          use_mask_for_cate2=False,
434 |          use_mask_for_cate3=True,
435 |          cate3_mask_type='cate1',
436 |          cate1_loss_weight=1.,
437 |          cate2_loss_weight=1.,
438 |          cate3_loss_weight=1.,
439 |          batch_size=32,
440 |          predict_batch_size=32,
441 |          n_epoch=50,
442 |          learning_rate=2e-5,
443 |          optimizer='adam',
444 |          use_focal_loss=False,
445 |          callbacks_to_add=None,
446 |          swa_start=15,
447 |          early_stopping_patience=5,
448 |          max_lr=6e-5,
449 |          min_lr=1e-5,
450 |          train_on_cv=False,
451 |          cv_random_state=42,
452 |          cv_fold=5,
453 |          cv_index=0,
454 |          exchange_pair=False,
455 |          exchange_threshold=0.1,
456 |          use_pseudo_label=False,
457 |          pseudo_path=None,
458 |          pseudo_random_state=42,
459 |          pseudo_rate=0.1,
460 |          pseudo_index=0,
461 |          pseudo_name=None,
462 |          exp_name=None,
463 |          use_gpu_id=5):
464 |     model_config = prepare_config(model_type=model_type,
465 |                                   input_type=input_type,
466 |                                   use_multi_task=use_multi_task,
467 |                                   use_harl=use_harl,
468 |                                   use_hal=use_hal,
469 |                                   cate_embed_dim=cate_embed_dim,
470 |                                   use_word_input=use_word_input,
471 |                                   word_embed_type=word_embed_type,
472 |                                   word_embed_trainable=word_embed_trainable,
473 |                                   word_embed_dim=word_embed_dim,
474 |                                   use_bert_input=use_bert_input,
475 |                                   bert_trainable=bert_trainable,
476 |                                   use_bert_type=use_bert_type,
477 |                                   n_last_hidden_layer=n_last_hidden_layer,
478 |                                   dense_after_bert=dense_after_bert,
479 |                                   use_pair_input=use_pair_input,
480 |                                   max_len=max_len,
481 |                                   share_father_pred=share_father_pred,
482 |                                   use_mask_for_cate2=use_mask_for_cate2,
483 |                                   use_mask_for_cate3=use_mask_for_cate3,
484 |                                   cate3_mask_type=cate3_mask_type,
485 |                                   cate1_loss_weight=cate1_loss_weight,
486 |                                   cate2_loss_weight=cate2_loss_weight,
487 |                                   cate3_loss_weight=cate3_loss_weight,
488 |                                   batch_size=batch_size,
489 |                                   predict_batch_size=predict_batch_size,
490 |                                   n_epoch=n_epoch,
491 |                                   learning_rate=learning_rate,
492 |                                   optimizer=optimizer,
493 |                                   use_focal_loss=use_focal_loss,
494 |                                   callbacks_to_add=callbacks_to_add,
495 |                                   swa_start=swa_start,
496 |                                   early_stopping_patience=early_stopping_patience,
497 |                                   max_lr=max_lr,
498 |                                   min_lr=min_lr,
499 |                                   train_on_cv=train_on_cv,
500 |                                   cv_random_state=cv_random_state,
501 |                                   cv_fold=cv_fold,
502 |                                   cv_index=cv_index,
503 |                                   exchange_pair=exchange_pair,
504 |                                   exchange_threshold=exchange_threshold,
505 |                                   use_pseudo_label=use_pseudo_label,
506 |                                   pseudo_path=pseudo_path,
507 |                                   pseudo_random_state=pseudo_random_state,
508 |                                   pseudo_rate=pseudo_rate,
509 |                                   pseudo_index=pseudo_index,
510 |                                   pseudo_name=pseudo_name,
511 |                                   exp_name=exp_name)
512 |     train(model_config, use_gpu_id=use_gpu_id)
513 |
--------------------------------------------------------------------------------
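Usage sketch (not part of the repository files above): based on the main() signature in train.py, a single BERT-based multi-task run with SWA could be launched from a small driver script roughly as shown below. The callback identifiers 'modelcheckpoint', 'earlystopping' and 'swa' are assumptions inferred from the string handling in prepare_config and the swa_type check in train(); note that callbacks_to_add must be a list, since prepare_config joins it into the experiment name, so the default None cannot be passed through as-is. The GPU index must point at a device that actually exists on the machine.

    # minimal sketch, assuming this module is train.py at the repository root
    from train import main

    if __name__ == '__main__':
        main(model_type='bert-base-uncased',   # key into MultiTaskClsModel
             input_type='name_desc',
             use_multi_task=True,
             use_mask_for_cate3=True,
             cate3_mask_type='cate1',
             callbacks_to_add=['modelcheckpoint', 'earlystopping', 'swa'],  # assumed callback names
             swa_start=15,
             batch_size=32,
             n_epoch=50,
             learning_rate=2e-5,
             use_gpu_id=0)                     # choose a valid GPU index

With these arguments the script builds the experiment name, trains on the 'train' split with early stopping against the 'dev' split, then evaluates the best checkpoint and the SWA-averaged weights and writes submission files for the test set, as implemented in train() above.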