├── LICENSE
├── README.md
├── __pycache__
│   └── utils.cpython-37.pyc
├── bert_weights
│   └── put_your_bert_weights_in_this_dir.txt
├── datasets
│   ├── datasets.z01
│   ├── datasets.zip
│   └── just_unzip_the_zip_file.txt
├── images
│   ├── LCM.png
│   ├── curve.png
│   └── paper_title.png
├── lcm_exp_on_bert.py
├── lcm_exp_on_lstm.py
├── models
│   ├── __pycache__
│   │   └── lstm.cpython-37.pyc
│   ├── bert.py
│   └── lstm.py
├── output
│   └── placeholder.txt
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 beyond_guo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Description:
2 |
3 | This is the official implementation of our AAAI-21 accepted paper *Label Confusion Learning to Enhance Text Classification Models*.
4 |
5 | ![paper title](images/paper_title.png)
6 |
7 |
8 |
9 | The structure of LCM looks like this:
10 |
11 | ![LCM structure](images/LCM.png)
12 |
13 |
14 |
15 | Here we provide some demo experimental code & datasets.
16 |
17 | ### Environment:
18 |
19 | - python 3.6
20 | - tensorflow 2.2.0
21 | - keras 2.3.1
22 |
23 | ### Run a Demo:
24 |
25 | **LCM-based LSTM:**
26 |
27 | Run `python lcm_exp_on_lstm.py` to compare the performance of LSTM, LSTM with label smoothing (LS), and LSTM with LCM.
28 |
29 | **LCM-based BERT:**
30 |
31 | Run `python lcm_exp_on_bert.py` to compare the performance of BERT, BERT with label smoothing (LS), and BERT with LCM.
32 |
33 |
34 |
35 | The final results will be written to the `output/` directory.
36 |
37 |
38 |
39 | The curve below shows our results on 20NG with LSTM as the basic predictor. By changing α, we can control the influence of LCM on the original model (see the short sketch at the end of this README for how α enters the training target).
40 |
41 | ![results curve on 20NG](images/curve.png)
42 |
43 |
44 |
45 |
46 |
47 |
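48 | ### How α enters the LCM loss:
49 |
50 | Besides the python/tensorflow/keras versions listed above, the demo scripts also import `bert4keras`, `gensim`, `scikit-learn`, `pandas`, `matplotlib` and `jieba`, so you may need to install those as well.
51 |
52 | The snippet below is a minimal numpy sketch (the helper name `simulated_target` and the toy numbers are ours, for illustration only) of how α enters `lcm_loss` in `models/lstm.py` / `models/bert.py`: the label encoder's similarity distribution is added to α·y_true and re-normalized with a softmax to form a simulated target; the model is then trained with `CE(simulated, predicted) - H(simulated)`, i.e. the KL divergence KL(simulated ‖ predicted). A larger α pushes the simulated target back towards the one-hot label, which weakens LCM's influence.
53 |
54 | ```python
55 | import numpy as np
56 |
57 | def simulated_target(label_sim_dist, y_true_onehot, alpha):
58 |     # softmax(label_sim_dist + alpha * y_true), as in lcm_loss
59 |     logits = label_sim_dist + alpha * y_true_onehot
60 |     e = np.exp(logits - logits.max())
61 |     return e / e.sum()
62 |
63 | # toy example with 4 classes (illustrative values only):
64 | y_true = np.array([0., 0., 1., 0.])
65 | label_sim = np.array([0.1, 0.2, 0.5, 0.2])  # a possible output of the label-similarity layer
66 | for alpha in [1, 3, 5]:
67 |     print(alpha, simulated_target(label_sim, y_true, alpha).round(3))
68 | # as alpha grows, the simulated target approaches the one-hot label
69 | ```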
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_weights/put_your_bert_weights_in_this_dir.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/bert_weights/put_your_bert_weights_in_this_dir.txt
--------------------------------------------------------------------------------
/datasets/datasets.z01:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/datasets/datasets.z01
--------------------------------------------------------------------------------
/datasets/datasets.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/datasets/datasets.zip
--------------------------------------------------------------------------------
/datasets/just_unzip_the_zip_file.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/datasets/just_unzip_the_zip_file.txt
--------------------------------------------------------------------------------
/images/LCM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/images/LCM.png
--------------------------------------------------------------------------------
/images/curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/images/curve.png
--------------------------------------------------------------------------------
/images/paper_title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/images/paper_title.png
--------------------------------------------------------------------------------
/lcm_exp_on_bert.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import numpy as np
3 | from utils import load_dataset, create_asy_noise_labels
4 | from sklearn.utils import shuffle
5 | from bert4keras.tokenizers import Tokenizer
6 | from bert4keras.snippets import sequence_padding
7 | from models import bert
8 | import matplotlib.pyplot as plt
9 | import time
10 |
11 |
12 | #%%
13 | # ========== parameters: ==========
14 | maxlen = 128
15 | hidden_size = 64
16 | batch_size = 512
17 | epochs = 150
18 |
19 | # ========== bert config: ==========
20 | # for English, use bert_tiny:
21 | bert_type = 'bert'
22 | config_path = 'bert_weights/bert_tiny_uncased_L-2_H-128_A-2/bert_config.json'
23 | checkpoint_path = 'bert_weights/bert_tiny_uncased_L-2_H-128_A-2/bert_model.ckpt'
24 | vocab_path = 'bert_weights/bert_tiny_uncased_L-2_H-128_A-2/vocab.txt'
25 |
26 | # for Chinese, use albert_tiny:
27 | # bert_type = 'albert'
28 | # config_path = '../bert_weights/albert_tiny_google_zh_489k/albert_config.json'
29 | # checkpoint_path = '../bert_weights/albert_tiny_google_zh_489k/albert_model.ckpt'
30 | # vocab_path = '../bert_weights/albert_tiny_google_zh_489k/vocab.txt'
31 |
32 | tokenizer = Tokenizer(vocab_path, do_lower_case=True)
33 |
34 | # ========== dataset: ==========
35 | dataset_name = '20NG'
36 | group_noise_rate = 0.3
37 | df,num_classes,label_groups = load_dataset(dataset_name)
38 | # define log file name:
39 | log_txt_name = '%s_BERT_log(group_noise=%s,comp+rec+talk)' % (dataset_name,group_noise_rate)
40 |
41 | df = df.dropna(axis=0,how='any')
42 | df = shuffle(df)[:50000]
43 | print('data size:',len(df))
44 |
45 | #%%
46 | # ========== data preparation: ==========
47 | labels = sorted(list(set(df.label)))
48 | assert len(labels) == num_classes,'wrong num of classes!'
49 | label2idx = {name:i for name,i in zip(labels,range(num_classes))}
50 | #%%
51 | print('start tokenizing...')
52 | t = time.time()
53 | X_token = []
54 | X_seg = []
55 | y = []
56 | i = 0
57 | for content,label in zip(list(df.content),list(df.label)):
58 | i += 1
59 | if i%1000 == 0:
60 | print(i)
61 | token_ids, seg_ids = tokenizer.encode(content, maxlen=maxlen)
62 | X_token.append(token_ids)
63 | X_seg.append(seg_ids)
64 | y.append(label2idx[label])
65 |
66 | # the sequences obtained above may have different lengths, so pad them:
67 | X_token = sequence_padding(X_token)
68 | X_seg = sequence_padding(X_seg)
69 | y = np.array(y)
70 | print('tokenizing time cost:',time.time()-t,'s.')
71 |
72 | #%%
73 | # ========== model training: ==========
74 | old_list = []
75 | ls_list = []
76 | lcm_list = []
77 | N = 5
78 | for n in range(N):
79 | # randomly split train and test each time:
80 |     np.random.seed(n) # keep the seed of each round consistent across experiments
81 | random_indexs = np.random.permutation(range(len(X_token)))
82 | train_size = int(len(X_token)*0.6)
83 | val_size = int(len(X_token)*0.15)
84 | X_token_train = X_token[random_indexs][:train_size]
85 | X_token_val = X_token[random_indexs][train_size:train_size+val_size]
86 | X_token_test = X_token[random_indexs][train_size+val_size:]
87 | X_seg_train = X_seg[random_indexs][:train_size]
88 | X_seg_val = X_seg[random_indexs][train_size:train_size + val_size]
89 | X_seg_test = X_seg[random_indexs][train_size + val_size:]
90 | y_train = y[random_indexs][:train_size]
91 | y_val = y[random_indexs][train_size:train_size+val_size]
92 | y_test = y[random_indexs][train_size+val_size:]
93 | data_package = [X_token_train, X_seg_train, y_train, X_token_val, X_seg_val, y_val, X_token_test, X_seg_test, y_test]
94 |
95 | # apply noise only on train set:
96 | if group_noise_rate>0:
97 | _, overall_noise_rate, y_train = create_asy_noise_labels(y_train,label_groups,label2idx,group_noise_rate)
98 | data_package = [X_token_train, X_seg_train, y_train, X_token_val, X_seg_val, y_val, X_token_test, X_seg_test,
99 | y_test]
100 | with open('output/%s.txt' % log_txt_name, 'a') as f:
101 |             print('-'*30,'\nNOTICE: overall_noise_rate=%s'%round(overall_noise_rate,2), file=f)
102 |
103 |
104 | with open('output/%s.txt'%log_txt_name,'a') as f:
105 | print('\n',str(datetime.datetime.now()),file=f)
106 | print('\n ROUND & SEED = ',n,'-'*20,file=f)
107 |
108 | model_to_run = [""]
109 | print('====Original:============')
110 | model = bert.BERT_Basic(config_path,checkpoint_path,hidden_size,num_classes,bert_type)
111 | train_score_list, val_socre_list, best_val_score, test_score = model.train_val(data_package,batch_size,epochs)
112 | plt.plot(train_score_list, label='train')
113 | plt.plot(val_socre_list, label='val')
114 | plt.title('BERT')
115 | plt.legend()
116 | plt.show()
117 | old_list.append(test_score)
118 | with open('output/%s.txt'%log_txt_name,'a') as f:
119 | print('\n*** Orig BERT ***:',file=f)
120 | print('test acc:', str(test_score), file=f)
121 | print('best val acc:',str(best_val_score),file=f)
122 | print('train acc list:\n',str(train_score_list),file=f)
123 | print('val acc list:\n',str(val_socre_list),'\n',file=f)
124 |
125 | print('====Label Smooth:============')
126 | ls_e = 0.1
127 | model = bert.BERT_LS(config_path,checkpoint_path,hidden_size,num_classes,ls_e,bert_type)
128 | train_score_list, val_socre_list, best_val_score, test_score = model.train_val(data_package,batch_size,epochs)
129 | plt.plot(train_score_list, label='train')
130 | plt.plot(val_socre_list, label='val')
131 | plt.title('BERT with LS')
132 | plt.legend()
133 | plt.show()
134 |     ls_list.append(test_score)
135 | with open('output/%s.txt'%log_txt_name,'a') as f:
136 | print('\n*** Orig BERT with LS (e=%s) ***:'%ls_e, file=f)
137 | print('test acc:', str(test_score), file=f)
138 | print('best val acc:', str(best_val_score), file=f)
139 | print('train acc list:\n', str(train_score_list), file=f)
140 | print('val acc list:\n', str(val_socre_list), '\n', file=f)
141 |
142 | print('====LCM:============')
143 | # alpha = 3
144 | for alpha in [3,4,5]:
145 | wvdim = 256
146 | lcm_stop = 100
147 | params_str = 'a=%s, wvdim=%s, lcm_stop=%s'%(alpha,wvdim,lcm_stop)
148 | model = bert.BERT_LCM(config_path,checkpoint_path,hidden_size,num_classes,alpha,wvdim,bert_type)
149 | train_score_list, val_socre_list, best_val_score, test_score = model.train_val(data_package, batch_size,epochs,lcm_stop)
150 | plt.plot(train_score_list, label='train')
151 | plt.plot(val_socre_list, label='val')
152 | plt.title('BERT with LCM')
153 | plt.legend()
154 | plt.show()
155 |         lcm_list.append(test_score)
156 | with open('output/%s.txt'%log_txt_name,'a') as f:
157 | print('\n*** Orig BERT with LCM (%s) ***:'%params_str,file=f)
158 | print('test acc:', str(test_score), file=f)
159 | print('best val acc:', str(best_val_score), file=f)
160 | print('train acc list:\n', str(train_score_list), file=f)
161 | print('val acc list:\n', str(val_socre_list), '\n', file=f)
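162 |
163 | # ---- Optional summary (a sketch, not part of the original script): average the collected
164 | # ---- test accuracies over all rounds (and, for LCM, over all alpha values) and append
165 | # ---- them to the same log file.
166 | with open('output/%s.txt' % log_txt_name, 'a') as f:
167 |     print('\n===== Summary over %s rounds =====' % N, file=f)
168 |     print('BERT      mean test acc:', round(float(np.mean(old_list)), 5), file=f)
169 |     print('BERT+LS   mean test acc:', round(float(np.mean(ls_list)), 5), file=f)
170 |     print('BERT+LCM  mean test acc:', round(float(np.mean(lcm_list)), 5), file=f)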
--------------------------------------------------------------------------------
/lcm_exp_on_lstm.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import sys
3 | import pandas as pd
4 | import numpy as np
5 | from utils import *
6 | from sklearn.utils import shuffle
7 | from models import lstm
8 |
9 | # ========== parameters & dataset: ==========
10 | vocab_size = 20000
11 | maxlen = 100
12 | wvdim = 64
13 | hidden_size = 64
14 | # num_filters = 100
15 | # filter_sizes = [3,10,25] # must not exceed maxlen
16 | alpha = 4 # the alpha used in the LCM loss
17 | batch_size = 512
18 | epochs = 40
19 | emb_type = ''
20 |
21 | log_txt_name = '20NG_log'
22 | num_classes = 20
23 | df1 = pd.read_csv('datasets/20NG/20ng-train-all-terms.csv')
24 | df2 = pd.read_csv('datasets/20NG/20ng-test-all-terms.csv')
25 | df = pd.concat([df1,df2])
26 |
27 | # log_txt_name = 'DBPedia' # AG_log
28 | # num_classes = 14
29 | # df1 = pd.read_csv('../datasets/DBPedia/train.csv',header=None,index_col=None)
30 | # df2 = pd.read_csv('../datasets/DBPedia/test.csv',header=None,index_col=None)
31 | # df1.columns = ['label','title','content']
32 | # df2.columns = ['label','title','content']
33 | # df = pd.concat([df1,df2])
34 |
35 | # log_txt_name = 'Fudan_log'
36 | # num_classes = 20
37 | # # df = pd.read_csv('../datasets/thucnews_subset.csv')
38 | # df = pd.read_csv('../datasets/fudan_news.csv')
39 |
40 |
41 | df = df.dropna(axis=0,how='any')
42 | df = shuffle(df)[:50000]
43 | print('data size:',len(df))
44 |
45 |
46 | # ========== data pre-processing: ==========
47 | labels = sorted(list(set(df.label)))
48 | assert len(labels) == num_classes,'wrong num of classes'
49 | label2idx = {name:i for name,i in zip(labels,range(num_classes))}
50 | num_classes = len(label2idx)
51 |
52 | corpus = []
53 | X_words = []
54 | Y = []
55 | i = 0
56 | for content,y in zip(list(df.content),list(df.label)):
57 | i += 1
58 | if i%1000 == 0:
59 | print(i)
60 | # English:
61 | content_words = content.split(' ')
62 | # Chinese:
63 | # content_words = jieba.lcut(content)
64 | corpus += content_words
65 | X_words.append(content_words)
66 | Y.append(label2idx[y])
67 |
68 | tokenizer, word_index, freq_word_index = fit_corpus(corpus,vocab_size=vocab_size)
69 | X = text2idx(tokenizer,X_words,maxlen=maxlen)
70 | y = np.array(Y)
71 |
72 |
73 |
74 | # ========== model training: ==========
75 | old_list = []
76 | ls_list = []
77 | lcm_list = []
78 | N = 10
79 | for n in range(N):
80 | # randomly split train and test each time:
81 |     np.random.seed(n) # keep the seed of each round consistent across experiments
82 | random_indexs = np.random.permutation(range(len(X)))
83 | train_size = int(len(X)*0.6)
84 | val_size = int(len(X)*0.15)
85 | X_train = X[random_indexs][:train_size]
86 | X_val = X[random_indexs][train_size:train_size+val_size]
87 | X_test = X[random_indexs][train_size+val_size:]
88 | y_train = y[random_indexs][:train_size]
89 | y_val = y[random_indexs][train_size:train_size+val_size]
90 | y_test = y[random_indexs][train_size+val_size:]
91 | data_package = [X_train,y_train,X_val,y_val,X_test,y_test]
92 |
93 |
94 | with open('output/%s.txt'%log_txt_name,'a') as f:
95 | print(str(datetime.datetime.now()),file=f)
96 | print('--Round',n+1,file=f)
97 |
98 | print('====Original:============')
99 | basic_model = lstm.LSTM_Basic(maxlen,vocab_size,wvdim,hidden_size,num_classes,None)
100 | best_val_score,val_score_list,final_test_score,final_train_score = basic_model.train_val(data_package,batch_size,epochs)
101 | old_list.append(final_test_score)
102 | with open('output/%s.txt'%log_txt_name,'a') as f:
103 | print(n,'old:',final_train_score,best_val_score,final_test_score,file=f)
104 |
105 | print('====LS:============')
106 | ls_model = lstm.LSTM_LS(maxlen,vocab_size,wvdim,hidden_size,num_classes,None)
107 | best_val_score,val_score_list,final_test_score,final_train_score = ls_model.train_val(data_package,batch_size,epochs)
108 | ls_list.append(final_test_score)
109 | with open('output/%s.txt'%log_txt_name,'a') as f:
110 | print(n,'ls:',final_train_score,best_val_score,final_test_score,file=f)
111 |
112 | print('====LCM:============')
113 | dy_lcm_model = lstm.LSTM_LCM_dynamic(maxlen,vocab_size,wvdim,hidden_size,num_classes,alpha,None,None)
114 | best_val_score,val_score_list,final_test_score,final_train_score = dy_lcm_model.train_val(data_package,batch_size,epochs,lcm_stop=30)
115 | lcm_list.append(final_test_score)
116 | with open('output/%s.txt'%log_txt_name,'a') as f:
117 | print(n,'lcm:',final_train_score,best_val_score,final_test_score,file=f)
118 |
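119 | # ---- Optional summary (a sketch, not part of the original script): average the collected
120 | # ---- test accuracies over all N rounds and append them to the same log file.
121 | with open('output/%s.txt' % log_txt_name, 'a') as f:
122 |     print('\n===== Summary over %s rounds =====' % N, file=f)
123 |     print('LSTM      mean test acc:', round(float(np.mean(old_list)), 5), file=f)
124 |     print('LSTM+LS   mean test acc:', round(float(np.mean(ls_list)), 5), file=f)
125 |     print('LSTM+LCM  mean test acc:', round(float(np.mean(lcm_list)), 5), file=f)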
--------------------------------------------------------------------------------
/models/__pycache__/lstm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/models/__pycache__/lstm.cpython-37.pyc
--------------------------------------------------------------------------------
/models/bert.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import keras
3 | import tensorflow as tf
4 | import time
5 | from keras.models import Sequential,Model
6 | from keras.layers import Input,Dense,LSTM,Embedding,Conv1D,MaxPooling1D
7 | from keras.layers import Flatten,Dropout,Concatenate,Lambda,Multiply,Reshape,Dot,Bidirectional
8 | import keras.backend as K
9 | from keras.utils import to_categorical
10 | from gensim.models import Word2Vec
11 | from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score
12 | from bert4keras.models import build_transformer_model
13 | from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
14 |
15 |
16 | class BERT_Basic:
17 |
18 | def __init__(self,config_path,checkpoint_path,hidden_size,num_classes,model_type='bert'):
19 | self.num_classes = num_classes
20 | bert = build_transformer_model(config_path=config_path,checkpoint_path=checkpoint_path,
21 | model=model_type,return_keras_model=False)
22 | text_emb = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
23 | text_emb = Dense(hidden_size,activation='tanh')(text_emb)
24 | output = Dense(num_classes,activation='softmax')(text_emb)
25 | self.model = Model(bert.model.input,output)
26 | AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')
27 | self.model.compile(loss='categorical_crossentropy',optimizer=AdamLR(learning_rate=1e-4, lr_schedule={1000: 1,2000: 0.1}))
28 |
29 | def train_val(self,data_package,batch_size,epochs,save_best=False):
30 | X_token_train, X_seg_train, y_train, X_token_val, X_seg_val, y_val, X_token_test, X_seg_test, y_test = data_package
31 | best_val_score = 0
32 | test_score = 0
33 | train_score_list = []
34 | val_socre_list = []
35 |         """Experiment protocol:
36 |         after every training epoch, evaluate on the val set and record its accuracy;
37 |         whenever the val accuracy reaches a new high, immediately evaluate on the test set;
38 |         the test accuracy kept at the end is therefore that of the model that did best on val.
39 |         """
40 | learning_curve = []
41 | for i in range(epochs):
42 | t1 = time.time()
43 | self.model.fit([X_token_train,X_seg_train],to_categorical(y_train),batch_size=batch_size,verbose=0,epochs=1)
44 | # record train set result:
45 | pred_probs = self.model.predict([X_token_train, X_seg_train])
46 | predictions = np.argmax(pred_probs, axis=1)
47 | train_score = round(accuracy_score(y_train, predictions), 5)
48 | train_score_list.append(train_score)
49 | # validation:
50 | pred_probs = self.model.predict([X_token_val,X_seg_val])
51 | predictions = np.argmax(pred_probs,axis=1)
52 | val_score = round(accuracy_score(y_val, predictions), 5)
53 | val_socre_list.append(val_score)
54 | t2 = time.time()
55 | print('(Orig)Epoch', i + 1, '| time: %.3f s' % (t2 - t1), ' | train acc:', train_score, ' | val acc:', val_score)
56 | # save best model according to validation & test result:
57 | if val_score>best_val_score:
58 | best_val_score = val_score
59 | print('Current Best model!','current epoch:',i+1)
60 | # test on best model:
61 | pred_probs = self.model.predict([X_token_test,X_seg_test])
62 | predictions = np.argmax(pred_probs, axis=1)
63 | test_score = round(accuracy_score(y_test, predictions), 5)
64 | print(' Current Best model! Test score:', test_score)
65 | if save_best:
66 | self.model.save('best_model_bert.h5')
67 | print(' best model saved!')
68 | return train_score_list, val_socre_list, best_val_score, test_score
69 |
70 |
71 | class BERT_LS:
72 | def __init__(self, config_path, checkpoint_path, hidden_size, num_classes, ls_e=0.1, model_type='bert'):
73 |
74 | def ls_loss(y_true, y_pred, e=ls_e):
75 | loss1 = K.categorical_crossentropy(y_true, y_pred)
76 | loss2 = K.categorical_crossentropy(K.ones_like(y_pred) / num_classes, y_pred)
77 | return (1 - e) * loss1 + e * loss2
78 |
79 | self.num_classes = num_classes
80 | bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path,
81 | model=model_type, return_keras_model=False)
82 | text_emb = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
83 | text_emb = Dense(hidden_size, activation='tanh')(text_emb)
84 | output = Dense(num_classes, activation='softmax')(text_emb)
85 | self.model = Model(bert.model.input, output)
86 | AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')
87 | self.model.compile(loss=ls_loss, optimizer=AdamLR(learning_rate=1e-4, lr_schedule={1000: 1, 2000: 0.1}))
88 |
89 | def train_val(self, data_package, batch_size, epochs, save_best=False):
90 | X_token_train, X_seg_train, y_train, X_token_val, X_seg_val, y_val, X_token_test, X_seg_test, y_test = data_package
91 | best_val_score = 0
92 | test_score = 0
93 | train_score_list = []
94 | val_socre_list = []
95 |         """Experiment protocol:
96 |         after every training epoch, evaluate on the val set and record its accuracy;
97 |         whenever the val accuracy reaches a new high, immediately evaluate on the test set;
98 |         the test accuracy kept at the end is therefore that of the model that did best on val.
99 |         """
100 | for i in range(epochs):
101 | t1 = time.time()
102 | self.model.fit([X_token_train, X_seg_train], to_categorical(y_train), batch_size=batch_size, verbose=0,
103 | epochs=1)
104 | # record train set result:
105 | pred_probs = self.model.predict([X_token_train, X_seg_train])
106 | predictions = np.argmax(pred_probs, axis=1)
107 | train_score = round(accuracy_score(y_train, predictions), 5)
108 | train_score_list.append(train_score)
109 | # validation:
110 | pred_probs = self.model.predict([X_token_val, X_seg_val])
111 | predictions = np.argmax(pred_probs, axis=1)
112 | val_score = round(accuracy_score(y_val, predictions), 5)
113 | val_socre_list.append(val_score)
114 | t2 = time.time()
115 | print('(LS)Epoch', i + 1, '| time: %.3f s' % (t2 - t1), ' | train acc:', train_score, ' | val acc:',
116 | val_score)
117 | # save best model according to validation & test result:
118 | if val_score > best_val_score:
119 | best_val_score = val_score
120 | print('Current Best model!', 'current epoch:', i + 1)
121 | # test on best model:
122 | pred_probs = self.model.predict([X_token_test, X_seg_test])
123 | predictions = np.argmax(pred_probs, axis=1)
124 | test_score = round(accuracy_score(y_test, predictions), 5)
125 | print(' Current Best model! Test score:', test_score)
126 | if save_best:
127 | self.model.save('best_model_bert_ls.h5')
128 | print(' best model saved!')
129 | return train_score_list, val_socre_list, best_val_score, test_score
130 |
131 |
132 | class BERT_LCM:
133 | def __init__(self,config_path,checkpoint_path,hidden_size,num_classes,alpha,wvdim=768,model_type='bert',label_embedding_matrix=None):
134 | self.num_classes = num_classes
135 |
136 | def lcm_loss(y_true,y_pred,alpha=alpha):
137 | pred_probs = y_pred[:,:num_classes]
138 | label_sim_dist = y_pred[:,num_classes:]
139 | simulated_y_true = K.softmax(label_sim_dist+alpha*y_true)
140 | loss1 = -K.categorical_crossentropy(simulated_y_true,simulated_y_true)
141 | loss2 = K.categorical_crossentropy(simulated_y_true,pred_probs)
142 | return loss1+loss2
143 |
144 | def ls_loss(y_true, y_pred, e=0.1):
145 | loss1 = K.categorical_crossentropy(y_true, y_pred)
146 | loss2 = K.categorical_crossentropy(K.ones_like(y_pred)/num_classes, y_pred)
147 | return (1-e)*loss1 + e*loss2
148 |
149 | # text_encoder:
150 | bert = build_transformer_model(config_path=config_path,checkpoint_path=checkpoint_path,
151 | model=model_type,return_keras_model=False)
152 | text_emb = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
153 | text_emb = Dense(hidden_size,activation='tanh')(text_emb)
154 | pred_probs = Dense(num_classes,activation='softmax')(text_emb)
155 | self.basic_predictor = Model(bert.model.input,pred_probs)
156 | AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')
157 | self.basic_predictor.compile(loss='categorical_crossentropy',optimizer=AdamLR(learning_rate=1e-4, lr_schedule={1000: 1,2000: 0.1}))
158 |
159 |
160 | # label_encoder:
161 | label_input = Input(shape=(num_classes,),name='label_input')
162 |         if label_embedding_matrix is None: # no pretrained label embedding provided
163 | label_emb = Embedding(num_classes,wvdim,input_length=num_classes,name='label_emb1')(label_input) # (n,wvdim)
164 | else:
165 | label_emb = Embedding(num_classes,wvdim,input_length=num_classes,weights=[label_embedding_matrix],name='label_emb1')(label_input)
166 | # label_emb = Bidirectional(LSTM(hidden_size,return_sequences=True),merge_mode='ave')(label_emb) # (n,d)
167 | label_emb = Dense(hidden_size,activation='tanh',name='label_emb2')(label_emb)
168 |
169 | # similarity part:
170 | doc_product = Dot(axes=(2,1))([label_emb,text_emb]) # (n,d) dot (d,1) --> (n,1)
171 | label_sim_dict = Dense(num_classes,activation='softmax',name='label_sim_dict')(doc_product)
172 | # concat output:
173 | concat_output = Concatenate()([pred_probs,label_sim_dict])
174 | # compile;
175 | AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')
176 | self.model = Model(bert.model.input+[label_input],concat_output)
177 | self.model.compile(loss=lcm_loss, optimizer=AdamLR(learning_rate=1e-4, lr_schedule={1000: 1,2000: 0.1}))
178 |
179 |
180 | def lcm_evaluate(self,model,inputs,y_true):
181 | outputs = model.predict(inputs)
182 | pred_probs = outputs[:,:self.num_classes]
183 | predictions = np.argmax(pred_probs,axis=1)
184 | acc = round(accuracy_score(y_true,predictions),5)
185 | return acc
186 |
187 | def train_val(self, data_package, batch_size,epochs,lcm_stop=50,save_best=False):
188 | X_token_train, X_seg_train, y_train, X_token_val, X_seg_val, y_val, X_token_test, X_seg_test, y_test = data_package
189 | best_val_score = 0
190 | test_score = 0
191 | train_score_list = []
192 | val_socre_list = []
193 |         """Experiment protocol:
194 |         after every training epoch, evaluate on the val set and record its accuracy;
195 |         whenever the val accuracy reaches a new high, immediately evaluate on the test set;
196 |         the test accuracy kept at the end is therefore that of the model that did best on val.
197 |         """
198 |
199 | L_train = np.array([np.array(range(self.num_classes)) for i in range(len(X_token_train))])
200 | L_val = np.array([np.array(range(self.num_classes)) for i in range(len(X_token_val))])
201 | L_test = np.array([np.array(range(self.num_classes)) for i in range(len(X_token_test))])
202 |
203 | for i in range(epochs):
204 | t1 = time.time()
205 | if i < lcm_stop:
206 | self.model.fit([X_token_train,X_seg_train,L_train],to_categorical(y_train),batch_size=batch_size,verbose=0,epochs=1)
207 | # record train set result:
208 | train_score = self.lcm_evaluate(self.model,[X_token_train,X_seg_train,L_train],y_train)
209 | train_score_list.append(train_score)
210 | # validation:
211 | val_score = self.lcm_evaluate(self.model,[X_token_val,X_seg_val,L_val],y_val)
212 | val_socre_list.append(val_score)
213 | t2 = time.time()
214 | print('(LCM)Epoch', i + 1, '| time: %.3f s' % (t2 - t1), ' | train acc:', train_score, ' | val acc:',
215 | val_score)
216 | # save best model according to validation & test result:
217 | if val_score > best_val_score:
218 | best_val_score = val_score
219 | print('Current Best model!', 'current epoch:', i + 1)
220 | # test on best model:
221 | test_score = self.lcm_evaluate(self.model,[X_token_test,X_seg_test,L_test],y_test)
222 | print(' Current Best model! Test score:', test_score)
223 | if save_best:
224 | self.model.save('best_model_bert_lcm.h5')
225 | print(' best model saved!')
226 | else:
227 | self.basic_predictor.fit([X_token_train,X_seg_train],to_categorical(y_train),batch_size=batch_size,verbose=0,epochs=1)
228 | # record train set result:
229 | pred_probs = self.basic_predictor.predict([X_token_train, X_seg_train])
230 | predictions = np.argmax(pred_probs, axis=1)
231 | train_score = round(accuracy_score(y_train, predictions),5)
232 | train_score_list.append(train_score)
233 | # validation:
234 | pred_probs = self.basic_predictor.predict([X_token_val, X_seg_val])
235 | predictions = np.argmax(pred_probs, axis=1)
236 | val_score = round(accuracy_score(y_val, predictions),5)
237 | val_socre_list.append(val_score)
238 | t2 = time.time()
239 | print('(LCM_stopped)Epoch', i + 1, '| time: %.3f s' % (t2 - t1), ' | train acc:', train_score, ' | val acc:',
240 | val_score)
241 | # save best model according to validation & test result:
242 | if val_score > best_val_score:
243 | best_val_score = val_score
244 | print('Current Best model!', 'current epoch:', i + 1)
245 | # test on best model:
246 | pred_probs = self.basic_predictor.predict([X_token_test, X_seg_test])
247 | predictions = np.argmax(pred_probs, axis=1)
248 | test_score = round(accuracy_score(y_test, predictions),5)
249 | print(' Current Best model! Test score:', test_score)
250 | if save_best:
251 | self.model.save('best_model_bert_lcm.h5')
252 | print(' best model saved!')
253 | return train_score_list, val_socre_list, best_val_score, test_score
254 |
255 |
256 |
257 | class BERT_LCM_Att:
258 | """
259 |     LCM whose label encoder is built with self-attention.
260 | """
261 | def __init__(self, config_path, checkpoint_path, hidden_size, num_classes, alpha, wvdim=768, model_type='bert',
262 | label_embedding_matrix=None):
263 | self.num_classes = num_classes
264 |
265 | def lcm_loss(y_true, y_pred, alpha=alpha):
266 | pred_probs = y_pred[:, :num_classes]
267 | label_sim_dist = y_pred[:, num_classes:]
268 | simulated_y_true = K.softmax(label_sim_dist + alpha * y_true)
269 | loss1 = -K.categorical_crossentropy(simulated_y_true, simulated_y_true)
270 | loss2 = K.categorical_crossentropy(simulated_y_true, pred_probs)
271 | return loss1 + loss2
272 |
273 | def ls_loss(y_true, y_pred, e=0.1):
274 | loss1 = K.categorical_crossentropy(y_true, y_pred)
275 | loss2 = K.categorical_crossentropy(K.ones_like(y_pred) / num_classes, y_pred)
276 | return (1 - e) * loss1 + e * loss2
277 |
278 | # text_encoder:
279 | bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path,
280 | model=model_type, return_keras_model=False)
281 | text_emb = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
282 | text_emb = Dense(hidden_size, activation='tanh')(text_emb)
283 | pred_probs = Dense(num_classes, activation='softmax')(text_emb)
284 | self.basic_predictor = Model(bert.model.input, pred_probs)
285 | AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')
286 | self.basic_predictor.compile(loss='categorical_crossentropy',
287 | optimizer=AdamLR(learning_rate=1e-4, lr_schedule={1000: 1, 2000: 0.1}))
288 |
289 | # label_encoder by attention:
290 | label_input = Input(shape=(num_classes,), name='label_input')
291 |         if label_embedding_matrix is None: # no pretrained label embedding provided
292 | label_emb = Embedding(num_classes, wvdim, input_length=num_classes, name='label_emb1')(
293 | label_input) # (n,wvdim)
294 | else:
295 | label_emb = Embedding(num_classes, wvdim, input_length=num_classes, weights=[label_embedding_matrix],
296 | name='label_emb1')(label_input)
297 | # self-attention part:
298 | doc_attention = Lambda(lambda pair: tf.matmul(pair[0],pair[1],transpose_b=True))([label_emb,label_emb]) # (n,d) * (d,n) -> (n,n)
299 | attention_scores = Lambda(lambda x:tf.nn.softmax(x))(doc_attention) # (n,n)
300 | self_attention_seq = Lambda(lambda pair: tf.matmul(pair[0],pair[1]))([attention_scores, label_emb]) # (n,n)*(n,d)->(n,d)
301 |
302 | # similarity part:
303 | doc_product = Dot(axes=(2, 1))([self_attention_seq, text_emb]) # (n,d) dot (d,1) --> (n,1)
304 | label_sim_dist = Lambda(lambda x: tf.nn.softmax(x), name='label_sim_dict')(doc_product)
305 | # concat output:
306 | concat_output = Concatenate()([pred_probs, label_sim_dist])
307 | # compile;
308 | AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')
309 | self.model = Model(bert.model.input + [label_input], concat_output)
310 | self.model.compile(loss=lcm_loss, optimizer=AdamLR(learning_rate=1e-4, lr_schedule={1000: 1, 2000: 0.1}))
311 |
312 | def lcm_evaluate(self, model, inputs, y_true):
313 | outputs = model.predict(inputs)
314 | pred_probs = outputs[:, :self.num_classes]
315 | predictions = np.argmax(pred_probs, axis=1)
316 | acc = round(accuracy_score(y_true, predictions), 5)
317 | return acc
318 |
319 | def train_val(self, data_package, batch_size, epochs, lcm_stop=50, save_best=False):
320 | X_token_train, X_seg_train, y_train, X_token_val, X_seg_val, y_val, X_token_test, X_seg_test, y_test = data_package
321 | best_val_score = 0
322 | test_score = 0
323 | train_score_list = []
324 | val_socre_list = []
325 |         """Experiment protocol:
326 |         after every training epoch, evaluate on the val set and record its accuracy;
327 |         whenever the val accuracy reaches a new high, immediately evaluate on the test set;
328 |         the test accuracy kept at the end is therefore that of the model that did best on val.
329 |         """
330 |
331 | L_train = np.array([np.array(range(self.num_classes)) for i in range(len(X_token_train))])
332 | L_val = np.array([np.array(range(self.num_classes)) for i in range(len(X_token_val))])
333 | L_test = np.array([np.array(range(self.num_classes)) for i in range(len(X_token_test))])
334 |
335 | for i in range(epochs):
336 | t1 = time.time()
337 | if i < lcm_stop:
338 | self.model.fit([X_token_train, X_seg_train, L_train], to_categorical(y_train), batch_size=batch_size,
339 | verbose=0, epochs=1)
340 | # record train set result:
341 | train_score = self.lcm_evaluate(self.model, [X_token_train, X_seg_train, L_train], y_train)
342 | train_score_list.append(train_score)
343 | # validation:
344 | val_score = self.lcm_evaluate(self.model, [X_token_val, X_seg_val, L_val], y_val)
345 | val_socre_list.append(val_score)
346 | t2 = time.time()
347 | print('(LCM)Epoch', i + 1, '| time: %.3f s' % (t2 - t1), ' | train acc:', train_score, ' | val acc:',
348 | val_score)
349 | # save best model according to validation & test result:
350 | if val_score > best_val_score:
351 | best_val_score = val_score
352 | print('Current Best model!', 'current epoch:', i + 1)
353 | # test on best model:
354 | test_score = self.lcm_evaluate(self.model, [X_token_test, X_seg_test, L_test], y_test)
355 | print(' Current Best model! Test score:', test_score)
356 | if save_best:
357 | self.model.save('best_model_bert_lcm.h5')
358 | print(' best model saved!')
359 | else:
360 | self.basic_predictor.fit([X_token_train, X_seg_train], to_categorical(y_train), batch_size=batch_size,
361 | verbose=0, epochs=1)
362 | # record train set result:
363 | pred_probs = self.basic_predictor.predict([X_token_train, X_seg_train])
364 | predictions = np.argmax(pred_probs, axis=1)
365 | train_score = round(accuracy_score(y_train, predictions), 5)
366 | train_score_list.append(train_score)
367 | # validation:
368 | pred_probs = self.basic_predictor.predict([X_token_val, X_seg_val])
369 | predictions = np.argmax(pred_probs, axis=1)
370 | val_score = round(accuracy_score(y_val, predictions), 5)
371 | val_socre_list.append(val_score)
372 | t2 = time.time()
373 | print('(LCM_stopped)Epoch', i + 1, '| time: %.3f s' % (t2 - t1), ' | train acc:', train_score,
374 | ' | val acc:',
375 | val_score)
376 | # save best model according to validation & test result:
377 | if val_score > best_val_score:
378 | best_val_score = val_score
379 | print('Current Best model!', 'current epoch:', i + 1)
380 | # test on best model:
381 | pred_probs = self.basic_predictor.predict([X_token_test, X_seg_test])
382 | predictions = np.argmax(pred_probs, axis=1)
383 | test_score = round(accuracy_score(y_test, predictions), 5)
384 | print(' Current Best model! Test score:', test_score)
385 | if save_best:
386 | self.model.save('best_model_bert_lcm.h5')
387 | print(' best model saved!')
388 | return train_score_list, val_socre_list, best_val_score, test_score
--------------------------------------------------------------------------------
/models/lstm.py:
--------------------------------------------------------------------------------
1 | # checked at 2020.9.14
2 | import numpy as np
3 | import time
4 | import keras
5 | import tensorflow as tf
6 | from keras.models import Sequential,Model
7 | from keras.layers import Input,Dense,LSTM,Embedding,Conv1D,MaxPooling1D
8 | from keras.layers import Flatten,Dropout,Concatenate,Lambda,Multiply,Reshape,Dot,Bidirectional
9 | import keras.backend as K
10 | from keras.utils import to_categorical
11 | from gensim.models import Word2Vec
12 | from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score
13 |
14 |
15 | class LSTM_Basic:
16 | """
17 | input->embedding->lstm->softmax_dense
18 | """
19 | def __init__(self,maxlen,vocab_size,wvdim,hidden_size,num_classes,embedding_matrix=None):
20 | text_input = Input(shape=(maxlen,),name='text_input')
21 |         if embedding_matrix is None: # no pretrained embedding provided
22 | input_emb = Embedding(vocab_size+1,wvdim,input_length=maxlen,name='text_emb')(text_input) #(V,wvdim)
23 | else:
24 | input_emb = Embedding(vocab_size+1,wvdim,input_length=maxlen,weights=[embedding_matrix],name='text_emb')(text_input) #(V,wvdim)
25 | input_vec = LSTM(hidden_size)(input_emb)
26 | pred_probs = Dense(num_classes,activation='softmax',name='pred_probs')(input_vec)
27 | self.model = Model(inputs=text_input,outputs=pred_probs)
28 | self.model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
29 |
30 | def train_val(self,data_package,batch_size,epochs,save_best=False):
31 | X_train,y_train,X_val,y_val,X_test,y_test = data_package
32 | best_val_score = 0
33 | final_test_score = 0
34 | final_train_score = 0
35 | val_socre_list = []
36 |         """Experiment protocol:
37 |         after every training epoch, evaluate on the val set and record its accuracy;
38 |         whenever the val accuracy reaches a new high, immediately evaluate on the test set;
39 |         the test accuracy kept at the end is therefore that of the model that did best on val.
40 |         """
41 | for i in range(epochs):
42 | t1 = time.time()
43 | self.model.fit(X_train,to_categorical(y_train),batch_size=batch_size,verbose=0,epochs=1)
44 | pred_probs = self.model.predict(X_val)
45 | predictions = np.argmax(pred_probs,axis=1)
46 | val_score = round(accuracy_score(y_val,predictions),5)
47 | t2 = time.time()
48 | print('(Orig)Epoch',i+1,'| time: %.3f s'%(t2-t1),'| current val accuracy:',val_score)
49 | if val_score>best_val_score:
50 | best_val_score = val_score
51 |                 # evaluate on test with the model that is currently best on val:
52 | pred_probs = self.model.predict(X_test)
53 | predictions = np.argmax(pred_probs,axis=1)
54 | final_test_score = round(accuracy_score(y_test,predictions),5)
55 | print(' Current Best model! Test score:',final_test_score)
56 |                 # also record the score on the train set:
57 | pred_probs = self.model.predict(X_train)
58 | predictions = np.argmax(pred_probs, axis=1)
59 | final_train_score = round(accuracy_score(y_train, predictions),5)
60 | print(' Current Best model! Train score:', final_train_score)
61 | if save_best:
62 | self.model.save('best_model_lstm.h5')
63 | print(' best model saved!')
64 | val_socre_list.append(val_score)
65 | return best_val_score,val_socre_list,final_test_score,final_train_score
66 |
67 |
68 | class LSTM_LS:
69 | """
70 | input->embedding->lstm->softmax_dense
71 | """
72 | def __init__(self,maxlen,vocab_size,wvdim,hidden_size,num_classes,embedding_matrix=None):
73 | def ls_loss(y_true, y_pred, e=0.1):
74 | loss1 = K.categorical_crossentropy(y_true, y_pred)
75 | loss2 = K.categorical_crossentropy(K.ones_like(y_pred)/num_classes, y_pred)
76 | return (1-e)*loss1 + e*loss2
77 |
78 | text_input = Input(shape=(maxlen,),name='text_input')
79 |         if embedding_matrix is None: # no pretrained embedding provided
80 | input_emb = Embedding(vocab_size+1,wvdim,input_length=maxlen,name='text_emb')(text_input) #(V,wvdim)
81 | else:
82 | input_emb = Embedding(vocab_size+1,wvdim,input_length=maxlen,weights=[embedding_matrix],name='text_emb')(text_input) #(V,wvdim)
83 | input_vec = LSTM(hidden_size)(input_emb)
84 | pred_probs = Dense(num_classes,activation='softmax',name='pred_probs')(input_vec)
85 | self.model = Model(inputs=text_input,outputs=pred_probs)
86 | self.model.compile(loss=ls_loss, optimizer='adam', metrics=['accuracy'])
87 |
88 |
89 | def train_val(self,data_package,batch_size,epochs,save_best=False):
90 | X_train,y_train,X_val,y_val,X_test,y_test = data_package
91 | best_val_score = 0
92 | final_test_score = 0
93 | final_train_score = 0
94 | val_score_list = []
95 | for i in range(epochs):
96 | t1 = time.time()
97 | self.model.fit(X_train,to_categorical(y_train),batch_size=batch_size,verbose=0,epochs=1)
98 | # val:
99 | pred_probs = self.model.predict(X_val)
100 | predictions = np.argmax(pred_probs,axis=1)
101 | val_score = round(accuracy_score(y_val,predictions),5)
102 | t2 = time.time()
103 | print('(LS)Epoch', i + 1, '| time: %.3f s' % (t2 - t1), '| current val accuracy:', val_score)
104 | if val_score>best_val_score:
105 | best_val_score = val_score
106 |                 # evaluate on test with the model that is currently best on val:
107 | pred_probs = self.model.predict(X_test)
108 | predictions = np.argmax(pred_probs, axis=1)
109 | final_test_score = round(accuracy_score(y_test, predictions),5)
110 | print(' Current Best model! Test score:', final_test_score)
111 |                 # also record the score on the train set:
112 | pred_probs = self.model.predict(X_train)
113 | predictions = np.argmax(pred_probs, axis=1)
114 | final_train_score = round(accuracy_score(y_train, predictions),5)
115 | print(' Current Best model! Train score:', final_train_score)
116 | if save_best:
117 | self.model.save('best_model_ls.h5')
118 | print(' best model saved!')
119 | val_score_list.append(val_score)
120 | return best_val_score,val_score_list,final_test_score,final_train_score
121 |
122 |
123 | class LSTM_LCM_dynamic:
124 |     """
125 |     Dynamic LCM. The main differences from plain LCM are:
126 |     1. an early stop can be set, i.e. LCM stops taking effect after a chosen epoch;
127 |     2. after LCM is stopped, label smoothing can optionally be used to compute the loss.
128 |     """
129 | def __init__(self,maxlen,vocab_size,wvdim,hidden_size,num_classes,alpha,default_loss='ls',text_embedding_matrix=None,label_embedding_matrix=None):
130 | self.num_classes = num_classes
131 |
132 | def lcm_loss(y_true,y_pred,alpha=alpha):
133 | pred_probs = y_pred[:,:num_classes]
134 | label_sim_dist = y_pred[:,num_classes:]
135 | simulated_y_true = K.softmax(label_sim_dist+alpha*y_true)
136 | loss1 = -K.categorical_crossentropy(simulated_y_true,simulated_y_true)
137 | loss2 = K.categorical_crossentropy(simulated_y_true,pred_probs)
138 | return loss1+loss2
139 |
140 | def ls_loss(y_true, y_pred, e=0.1):
141 | loss1 = K.categorical_crossentropy(y_true, y_pred)
142 | loss2 = K.categorical_crossentropy(K.ones_like(y_pred)/num_classes, y_pred)
143 | return (1-e)*loss1 + e*loss2
144 |
145 | # basic_predictor:
146 | text_input = Input(shape=(maxlen,),name='text_input')
147 | if text_embedding_matrix is None:
148 | input_emb = Embedding(vocab_size+1,wvdim,input_length=maxlen,name='text_emb')(text_input) #(V,wvdim)
149 | else:
150 | input_emb = Embedding(vocab_size+1,wvdim,input_length=maxlen,weights=[text_embedding_matrix],name='text_emb')(text_input) #(V,wvdim)
151 | input_vec = LSTM(hidden_size)(input_emb)
152 | pred_probs = Dense(num_classes,activation='softmax',name='pred_probs')(input_vec)
153 |         self.basic_predictor = Model(inputs=text_input,outputs=pred_probs)
154 | if default_loss == 'ls':
155 | self.basic_predictor.compile(loss=ls_loss, optimizer='adam')
156 | else:
157 | self.basic_predictor.compile(loss='categorical_crossentropy', optimizer='adam')
158 |
159 | # LCM:
160 | label_input = Input(shape=(num_classes,),name='label_input')
161 | label_emb = Embedding(num_classes,wvdim,input_length=num_classes,name='label_emb1')(label_input) # (n,wvdim)
162 | label_emb = Dense(hidden_size,activation='tanh',name='label_emb2')(label_emb)
163 | # similarity part:
164 | doc_product = Dot(axes=(2,1))([label_emb,input_vec]) # (n,d) dot (d,1) --> (n,1)
165 | label_sim_dict = Dense(num_classes,activation='softmax',name='label_sim_dict')(doc_product)
166 | # concat output:
167 | concat_output = Concatenate()([pred_probs,label_sim_dict])
168 | # compile;
169 | self.model = Model(inputs=[text_input,label_input],outputs=concat_output)
170 | self.model.compile(loss=lcm_loss, optimizer='adam')
171 |
172 | def my_evaluator(self,model,inputs,label_list):
173 | outputs = model.predict(inputs)
174 | pred_probs = outputs[:,:self.num_classes]
175 | predictions = np.argmax(pred_probs,axis=1)
176 | acc = round(accuracy_score(label_list,predictions),5)
177 | # recall = recall_score(label_list,predictions,average='weighted')
178 | # f1 = f1_score(label_list,predictions,average='weighted')
179 | return acc
180 |
181 | def train_val(self,data_package,batch_size,epochs,lcm_stop=50,save_best=False):
182 | X_train,y_train,X_val,y_val,X_test,y_test = data_package
183 | L_train = np.array([np.array(range(self.num_classes)) for i in range(len(X_train))])
184 | L_val = np.array([np.array(range(self.num_classes)) for i in range(len(X_val))])
185 | L_test = np.array([np.array(range(self.num_classes)) for i in range(len(X_test))])
186 | best_val_score = 0
187 | final_test_score = 0
188 | final_train_score = 0
189 | val_score_list = []
190 | for i in range(epochs):
191 | if i < lcm_stop:
192 | t1 = time.time()
193 | self.model.fit([X_train,L_train],to_categorical(y_train),batch_size=batch_size,verbose=0,epochs=1)
194 | val_score = self.my_evaluator(self.model,[X_val,L_val],y_val)
195 | t2 = time.time()
196 | print('(LCM)Epoch', i + 1, '| time: %.3f s' % (t2 - t1), '| current val accuracy:', val_score)
197 | if val_score>best_val_score:
198 | best_val_score = val_score
199 | # test:
200 | final_test_score = self.my_evaluator(self.model,[X_test,L_test],y_test)
201 | print(' Current Best model! Test score:',final_test_score)
202 | # train:
203 | final_train_score = self.my_evaluator(self.model, [X_train, L_train], y_train)
204 | print(' Current Best model! Train score:', final_train_score)
205 | if save_best:
206 | self.model.save('best_model.h5')
207 | print('best model saved!')
208 | val_score_list.append(val_score)
209 |             else: # stop applying LCM
210 | t1 = time.time()
211 | self.basic_predictor.fit(X_train,to_categorical(y_train),batch_size=batch_size,verbose=0,epochs=1)
212 | pred_probs = self.basic_predictor.predict(X_val)
213 | predictions = np.argmax(pred_probs,axis=1)
214 | val_score = round(accuracy_score(y_val,predictions),5)
215 | t2 = time.time()
216 | print('(LCM-stop)Epoch', i + 1, '| time: %.3f s' % (t2 - t1), '| current val accuracy:', val_score)
217 | if val_score>best_val_score:
218 | best_val_score = val_score
219 | # test:
220 | final_test_score = self.my_evaluator(self.model, [X_test, L_test], y_test)
221 | print(' Current Best model! Test score:', final_test_score)
222 | # train:
223 | final_train_score = self.my_evaluator(self.model, [X_train, L_train], y_train)
224 | print(' Current Best model! Train score:', final_train_score)
225 | if save_best:
226 | self.model.save('best_model_lcm.h5')
227 | print(' best model saved!')
228 | val_score_list.append(val_score)
229 | return best_val_score,val_score_list,final_test_score,final_train_score
230 |
231 |
232 |
233 |
234 | # =================================
235 | # class LSTM_LCM:
236 | #
237 | # def __init__(self, maxlen, vocab_size, wvdim, hidden_size, num_classes, alpha, text_embedding_matrix=None,
238 | # label_embedding_matrix=None):
239 | # self.num_classes = num_classes
240 | #
241 | # def lcm_loss(y_true, y_pred, alpha=alpha):
242 | # pred_probs = y_pred[:, :num_classes]
243 | # label_sim_dist = y_pred[:, num_classes:]
244 | # simulated_y_true = K.softmax(label_sim_dist + alpha * y_true)
245 | # loss1 = -K.categorical_crossentropy(simulated_y_true, simulated_y_true)
246 | # loss2 = K.categorical_crossentropy(simulated_y_true, pred_probs)
247 | # return loss1 + loss2
248 | #
249 | # # text_encoder:
250 | # text_input = Input(shape=(maxlen,), name='text_input')
251 | # if text_embedding_matrix is None:
252 | # input_emb = Embedding(vocab_size + 1, wvdim, input_length=maxlen, name='text_emb')(text_input) # (V,wvdim)
253 | # else:
254 | # input_emb = Embedding(vocab_size + 1, wvdim, input_length=maxlen, weights=[text_embedding_matrix],
255 | # name='text_emb')(text_input) # (V,wvdim)
256 | # input_vec = LSTM(hidden_size)(input_emb)
257 | # pred_probs = Dense(num_classes, activation='softmax', name='pred_probs')(input_vec)
258 | # # label_encoder:
259 | # label_input = Input(shape=(num_classes,), name='label_input')
260 | #         if label_embedding_matrix is None: # no pretrained embedding provided
261 | # label_emb = Embedding(num_classes, wvdim, input_length=num_classes, name='label_emb')(
262 | # label_input) # (n,wvdim)
263 | # else:
264 | # label_emb = Embedding(num_classes, wvdim, input_length=num_classes, weights=[label_embedding_matrix],
265 | # name='label_emb')(label_input)
266 | # label_emb = Bidirectional(LSTM(hidden_size, return_sequences=True), merge_mode='ave')(label_emb) # (n,d)
267 | # # similarity part:
268 | # doc_product = Dot(axes=(2, 1))([label_emb, input_vec]) # (n,d) dot (d,1) --> (n,1)
269 | # label_sim_dict = Dense(num_classes, activation='softmax', name='label_sim_dict')(doc_product)
270 | # # concat output:
271 | # concat_output = Concatenate()([pred_probs, label_sim_dict])
272 | # # compile;
273 | # self.model = Model(inputs=[text_input, label_input], outputs=concat_output)
274 | # self.model.compile(loss=lcm_loss, optimizer='adam')
275 | #
276 | # def my_evaluator(self, model, inputs, label_list):
277 | # outputs = model.predict(inputs)
278 | # pred_probs = outputs[:, :self.num_classes]
279 | # predictions = np.argmax(pred_probs, axis=1)
280 | # acc = accuracy_score(label_list, predictions)
281 | # recall = recall_score(label_list, predictions, average='weighted')
282 | # f1 = f1_score(label_list, predictions, average='weighted')
283 | # return acc, recall, f1
284 | #
285 | # def train_val(self, data_package, batch_size, epochs, metric='accuracy', save_best=False):
286 | # X_train, y_train, X_val, y_val, X_test, y_test = data_package
287 | # L_train = np.array([np.array(range(self.num_classes)) for i in range(len(X_train))])
288 | # L_val = np.array([np.array(range(self.num_classes)) for i in range(len(X_val))])
289 | # L_test = np.array([np.array(range(self.num_classes)) for i in range(len(X_test))])
290 | # best_val_score = 0
291 | # learning_curve = []
292 | # for i in range(epochs):
293 | # self.model.fit([X_train, L_train], to_categorical(y_train), batch_size=batch_size, verbose=1, epochs=1)
294 | # acc, recall, f1 = self.my_evaluator(self.model, [X_val, L_val], y_val)
295 | # if metric == 'accuracy':
296 | # score = acc
297 | # print('Epoch', i + 1, '| current val %s:' % metric, score)
298 | # if score > best_val_score:
299 | # best_val_score = score
300 | # # test:
301 | # test_score, test_recall, test_f1 = self.my_evaluator(self.model, [X_test, L_test], y_test)
302 | # print('Current Best model! Test score:', test_score, 'current epoch:', i + 1)
303 | # if save_best:
304 | # self.model.save('best_model.h5')
305 | # print('best model saved!')
306 | # learning_curve.append(score)
307 | # return best_val_score, learning_curve, test_score
308 |
--------------------------------------------------------------------------------
/output/placeholder.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/label_confusion_learning/89a4eebfd5d947ea5ff1e87854209a43c7b8fde5/output/placeholder.txt
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from keras.preprocessing.text import text_to_word_sequence,Tokenizer
2 | from keras.preprocessing.sequence import pad_sequences
3 | import numpy as np
4 | from numpy.random import shuffle
5 | import pandas as pd
6 | import jieba
7 | import re
8 | import time
9 | import copy
10 |
11 | # ===================preprocessing:==============================
12 | def load_dataset(name):
13 | assert name in ['20NG','AG','DBP','FDU','THU'], "name only supports '20NG','AG','DBP','FDU','THU', but your input is %s"%name
14 |
15 | if name == '20NG':
16 | num_classes = 20
17 | df1 = pd.read_csv('datasets/20NG/20ng-train-all-terms.csv')
18 | df2 = pd.read_csv('datasets/20NG/20ng-test-all-terms.csv')
19 | df = pd.concat([df1,df2])
20 | comp_group = ['comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware',
21 | 'comp.windows.x']
22 | rec_group = ['rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey']
23 | talk_group = ['talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']
24 | sci_group = ['sci.crypt', 'sci.electronics', 'sci.med', 'sci.space']
25 | other_group = ['alt.atheism', 'misc.forsale', 'soc.religion.christian']
26 | label_groups = [comp_group,rec_group,talk_group]
27 | return df, num_classes,label_groups
28 | if name == 'AG':
29 | num_classes = 4 # DBP:14 AG:4
30 | df1 = pd.read_csv('../datasets/AG_news/train.csv', header=None, index_col=None)
31 | df2 = pd.read_csv('../datasets/AG_news/test.csv', header=None, index_col=None)
32 | df1.columns = ['label', 'title', 'content']
33 | df2.columns = ['label', 'title', 'content']
34 | df = pd.concat([df1, df2])
35 | label_groups = []
36 | return df,num_classes,label_groups
37 | if name == 'DBP':
38 | num_classes = 14
39 | df1 = pd.read_csv('../datasets/DBPedia/train.csv', header=None, index_col=None)
40 | df2 = pd.read_csv('../datasets/DBPedia/test.csv', header=None, index_col=None)
41 | df1.columns = ['label', 'title', 'content']
42 | df2.columns = ['label', 'title', 'content']
43 | df = pd.concat([df1, df2])
44 | label_groups = []
45 | return df,num_classes,label_groups
46 | if name == 'FDU':
47 | num_classes = 20
48 | df = pd.read_csv('../datasets/fudan_news.csv')
49 | label_groups = []
50 | return df, num_classes,label_groups
51 | if name == 'THU':
52 | num_classes = 13
53 | df = pd.read_csv('../datasets/thucnews_subset.csv')
54 | label_groups = []
55 | return df, num_classes,label_groups
56 |
57 | def create_shuffled_labels(y, label_groups, label2idx, rate):
58 |     """
59 |     y: the list of label indices
60 |     label_groups: each group is a list of label names to shuffle
61 |     rate: noise rate; note that this isn't the error rate, since some labels may remain unchanged after shuffling
62 |     """
63 |     np.random.seed(0)
64 | y_orig = list(y)[:]
65 | y = np.array(y)
66 | count = 0
67 | for labels in label_groups:
68 | label_ids = [label2idx[l] for l in labels]
69 | # find out the indexes of your target labels to add noise
70 | indexes = [i for i in range(len(y)) if y[i] in label_ids]
71 | shuffle(indexes)
72 | partial_indexes = indexes[:int(rate*len(indexes))]
73 | count += len(partial_indexes)
74 | shuffled_partial_indexes = partial_indexes[:]
75 | shuffle(shuffled_partial_indexes)
76 | y[partial_indexes] = y[shuffled_partial_indexes]
77 |     errors = int(np.sum(np.array(y_orig) != np.array(y))) # number of labels that actually changed
78 | shuffle_rate = count/len(y)
79 | error_rate = errors/len(y)
80 | return shuffle_rate,error_rate,y
81 |
82 |
83 | def create_asy_noise_labels(y, label_groups, label2idx, rate):
84 |     """
85 |     y: the list of label indices
86 |     label_groups: each group is a list of label names; labels are exchanged within a group
87 |     rate: noise rate; in this mode, each selected label is changed to another random label in the same group
88 |     """
89 |     np.random.seed(0)
90 | y_orig = list(y)[:]
91 | y = np.array(y)
92 | count = 0
93 | for labels in label_groups:
94 | label_ids = [label2idx[l] for l in labels]
95 | # find out the indexes of your target labels to add noise
96 | indexes = [i for i in range(len(y)) if y[i] in label_ids]
97 | shuffle(indexes)
98 | partial_indexes = indexes[:int(rate * len(indexes))]
99 | count += len(partial_indexes)
100 |         # change each selected label to a different label within the same group:
101 | for idx in partial_indexes:
102 | if y[idx] in label_ids:
103 | # randomly change to another label in the same group
104 | other_label_ids = label_ids[:]
105 | other_label_ids.remove(y[idx])
106 | y[idx] = np.random.choice(other_label_ids)
107 | errors = len(y) - list(np.array(y_orig) - np.array(y)).count(0)
108 | shuffle_rate = count / len(y)
109 | error_rate = errors / len(y)
110 | return shuffle_rate, error_rate, y
111 |
112 | def remove_punctuations(text):
113 | return re.sub('[,。:;’‘“”?!、,.!?\'\"\n\t]','',text)
114 |
115 |
116 | def fit_corpus(corpus,vocab_size=None):
117 | """
118 |     corpus is the word-segmented (tokenized) corpus
119 |     """
120 |     print("Start fitting the corpus......")
121 |     t = Tokenizer(vocab_size) # set this parameter so low-frequency words are dropped when vectorizing
122 |     tik = time.time()
123 |     t.fit_on_texts(corpus) # fit on the whole corpus to collect the statistics
124 |     tok = time.time()
125 |     word_index = t.word_index # not affected by vocab_size
126 | print('all_vocab_size',len(word_index))
127 | print("Fitting time: ",(tok-tik),'s')
128 | freq_word_index = {}
129 | if vocab_size is not None:
130 | print("Creating freq-word_index...")
131 | x = list(t.word_counts.items())
132 | s = sorted(x,key=lambda p:p[1],reverse=True)
133 |         freq_word_index = copy.deepcopy(word_index) # deep copy so the original dict is not modified
134 | for item in s[vocab_size:]:
135 | freq_word_index.pop(item[0])
136 | print("Finished!")
137 | return t,word_index,freq_word_index
138 |
139 |
140 | def text2idx(tokenizer,text,maxlen):
141 | """
142 |     text is a list in which each element is the word segmentation of one document
143 |     """
144 |     print("Start vectorizing the sentences.......")
145 |     X = tokenizer.texts_to_sequences(text) # affected by vocab_size
146 | print("Start padding......")
147 | pad_X = pad_sequences(X,maxlen=maxlen,padding='post')
148 | print("Finished!")
149 | return pad_X
150 |
151 |
152 | def create_embedding_matrix(wvmodel,vocab_size,emb_dim,word_index):
153 | """
154 |     vocab_size is the vocabulary size, usually the vocabulary size of the word-vector model
155 |     emb_dim is the word-vector dimension
156 |     word_index is the lookup dict mapping each word to its index
157 |     """
158 |     embedding_matrix = np.random.uniform(size=(vocab_size+1,emb_dim)) # +1 reserves a row for index=0
159 | print("Transfering to the embedding matrix......")
160 | # sorted_small_index = sorted(list(small_word_index.items()),key=lambda x:x[1])
161 | for word,index in word_index.items():
162 | try:
163 | word_vector = wvmodel[word]
164 | embedding_matrix[index] = word_vector
165 | except Exception as e:
166 | print(e,"Use random embedding instead.")
167 | print("Finished!")
168 | print("Embedding matrix shape:\n",embedding_matrix.shape)
169 | return embedding_matrix
170 |
171 |
172 | def label2idx(label_list):
173 | label_dict = {}
174 | unique_labels = list(set(label_list))
175 | for i,each in enumerate(unique_labels):
176 | label_dict[each] = i
177 | new_label_list = []
178 | for label in label_list:
179 | new_label_list.append(label_dict[label])
180 | return new_label_list,label_dict
--------------------------------------------------------------------------------