├── .gitignore ├── .gitmodules ├── README.md ├── data └── ch_auto.csv ├── env.yml ├── finetuning.py ├── helper.py ├── predicting.py ├── preprocessing.py ├── pretraining.py ├── train_classifier.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | tmp.zip 3 | tmp 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "fastai"] 2 | path = fastai 3 | url = https://github.com/fastai/fastai.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Chinese ULMFiT 3 | [Universal Language Model Fine-tuning for Text Classification 4 | ](https://arxiv.org/abs/1801.06146) 5 | 6 | [Download the pretrained model](https://drive.google.com/open?id=1Z9b1gVqfFjPaEEuU0Y-XfgsnmHr9yB_m) 7 | 8 | 9 | Create the virtual environment ([the Tsinghua conda mirror can be configured](https://mirror.tuna.tsinghua.edu.cn/help/anaconda/)) 10 | ```bash 11 | conda env create -f env.yml 12 | ``` 13 | 14 | Extract the Chinese Wikipedia corpus 15 | ```bash 16 | python -m gensim.scripts.segment_wiki -i -f /data/zhwiki-latest-pages-articles.xml.bz2 -o tmp/wiki2018-11-14.json.gz 17 | ``` 18 | 19 | Segment the Wikipedia corpus 20 | ```bash 21 | python preprocessing.py segment-wiki --input_file=tmp/wiki2018-11-14.json.gz --output_file=tmp/wiki2018-11-14.words.pkl 22 | ``` 23 | 24 | Segment the domain corpus 25 | ```bash 26 | python preprocessing.py segment-csv --input_file=data/ch_auto.csv --output_file=tmp/ch_auto.words.pkl --label_file=tmp/ch_auto.labels.npy 27 | ``` 28 | 29 | Tokenize the Wikipedia corpus 30 | ```bash 31 | python preprocessing.py tokenize --input_file=tmp/wiki2018-11-14.words.pkl --output_file=tmp/wiki2018-11-14.ids.npy --mapping_file=tmp/wiki2018-11-14.mapping.pkl 32 | ``` 33 | 34 | Tokenize the domain corpus 35 | ```bash 36 | python preprocessing.py tokenize --input_file=tmp/ch_auto.words.pkl --output_file=tmp/ch_auto.ids.npy --mapping_file=tmp/ch_auto.mapping.pkl 37 | ``` 38 | 39 | Pretrain the language model 40 | ```bash 41 | python pretraining.py --input_file=tmp/wiki2018-11-14.ids.npy --mapping_file=tmp/wiki2018-11-14.mapping.pkl --dir_path=tmp 42 | ``` 43 | 44 | Fine-tune the language model 45 | ```bash 46 | python finetuning.py --input_file=tmp/ch_auto.ids.npy --mapping_file=tmp/ch_auto.mapping.pkl --pretrain_model_file=tmp/models/wiki2018-11-14.h5 --pretrain_mapping_file=tmp/wiki2018-11-14.mapping.pkl --dir_path=tmp --model_id=ch_auto 47 | ``` 48 | 49 | Train the classifier 50 | ```bash 51 | python3 train_classifier.py --id_file=tmp/ch_auto.ids.npy --label_file=tmp/ch_auto.labels.npy --mapping_file=tmp/ch_auto.mapping.pkl --encoder_file=ch_auto_enc 52 | ``` 53 | 54 | Test (interactive prediction) 55 | ```bash 56 | python3 predicting.py --mapping_file=tmp/ch_auto.mapping.pkl --classifier_filename=tmp/models/classifier_1.h5 --num_class=2 57 | ``` 58 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: fastai 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/peterjc123/ 7 | - http://mirrors.ustc.edu.cn/anaconda/pkgs/free/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 10 | - 
https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 11 | dependencies: 12 | - cuda92=1.0=0 13 | - pytorch=0.3.0=py36_cuda8.0.61_cudnn7.0.3h37a80b5_4 14 | - blas=1.0=mkl 15 | - bleach=1.5.0=py36_0 16 | - certifi=2016.2.28=py36_0 17 | - cffi=1.10.0=py36_0 18 | - cudatoolkit=8.0=3 19 | - dbus=1.10.20=0 20 | - decorator=4.1.2=py36_0 21 | - entrypoints=0.2.3=py36_0 22 | - expat=2.1.0=0 23 | - fontconfig=2.12.1=3 24 | - freetype=2.5.5=2 25 | - glib=2.50.2=1 26 | - gst-plugins-base=1.8.0=0 27 | - gstreamer=1.8.0=0 28 | - html5lib=0.9999999=py36_0 29 | - icu=54.1=0 30 | - ipykernel=4.6.1=py36_0 31 | - ipython=6.1.0=py36_0 32 | - ipython_genutils=0.2.0=py36_0 33 | - ipywidgets=6.0.0=py36_0 34 | - jedi=0.10.2=py36_2 35 | - jinja2=2.9.6=py36_0 36 | - jpeg=9b=0 37 | - jsonschema=2.6.0=py36_0 38 | - jupyter_client=5.1.0=py36_0 39 | - jupyter_console=5.2.0=py36_0 40 | - jupyter_core=4.3.0=py36_0 41 | - libffi=3.2.1=1 42 | - libgcc=5.2.0=0 43 | - libiconv=1.14=0 44 | - libpng=1.6.30=1 45 | - libsodium=1.0.10=0 46 | - libxcb=1.12=1 47 | - libxml2=2.9.4=0 48 | - markupsafe=1.0=py36_0 49 | - mistune=0.7.4=py36_0 50 | - nbconvert=5.2.1=py36_0 51 | - nbformat=4.4.0=py36_0 52 | - notebook=5.0.0=py36_0 53 | - openssl=1.0.2l=0 54 | - pandocfilters=1.4.2=py36_0 55 | - path.py=10.3.1=py36_0 56 | - pcre=8.39=1 57 | - pexpect=4.2.1=py36_0 58 | - pickleshare=0.7.4=py36_0 59 | - pip=9.0.1=py36_1 60 | - prompt_toolkit=1.0.15=py36_0 61 | - ptyprocess=0.5.2=py36_0 62 | - pycparser=2.18=py36_0 63 | - pygments=2.2.0=py36_0 64 | - pyqt=5.6.0=py36_2 65 | - python=3.6.2=0 66 | - python-dateutil=2.6.1=py36_0 67 | - pyzmq=16.0.2=py36_0 68 | - qt=5.6.2=5 69 | - qtconsole=4.3.1=py36_0 70 | - readline=6.2=2 71 | - setuptools=36.4.0=py36_1 72 | - simplegeneric=0.8.1=py36_1 73 | - sip=4.18=py36_0 74 | - six=1.10.0=py36_0 75 | - sqlite=3.13.0=0 76 | - terminado=0.6=py36_0 77 | - testpath=0.3.1=py36_0 78 | - tk=8.5.18=0 79 | - tornado=4.5.2=py36_0 80 | - traitlets=4.3.2=py36_0 81 | - wcwidth=0.1.7=py36_0 82 | - wheel=0.29.0=py36_0 83 | - widgetsnbextension=3.0.2=py36_0 84 | - xz=5.2.3=0 85 | - zeromq=4.1.5=0 86 | - zlib=1.2.11=0 87 | - intel-openmp=2019.0=118 88 | - libgcc-ng=8.2.0=hdf63c60_1 89 | - libgfortran-ng=7.3.0=hdf63c60_0 90 | - libstdcxx-ng=8.2.0=hdf63c60_1 91 | - mkl=2018.0.3=1 92 | - mkl_fft=1.0.6=py36h7dd41cf_0 93 | - mkl_random=1.0.1=py36h4414c95_1 94 | - numpy=1.15.4=py36h1d66e8a_0 95 | - numpy-base=1.15.4=py36h81de0dd_0 96 | - pip: 97 | - backcall==0.1.0 98 | - bcolz==1.2.1 99 | - boto==2.49.0 100 | - boto3==1.9.38 101 | - botocore==1.12.38 102 | - bz2file==0.98 103 | - chardet==3.0.4 104 | - click==7.0 105 | - cycler==0.10.0 106 | - cymem==2.0.2 107 | - cytoolz==0.9.0.1 108 | - defusedxml==0.5.0 109 | - descartes==1.1.0 110 | - dill==0.2.8.2 111 | - docutils==0.14 112 | - fastai==0.7.0 113 | - feather-format==0.4.0 114 | - fire==0.1.3 115 | - gensim==3.6.0 116 | - graphviz==0.10.1 117 | - idna==2.7 118 | - isoweek==1.3.3 119 | - jieba==0.39 120 | - jmespath==0.9.3 121 | - jupyter==1.0.0 122 | - kiwisolver==1.0.1 123 | - matplotlib==3.0.1 124 | - mizani==0.5.2 125 | - msgpack==0.5.6 126 | - msgpack-numpy==0.4.3.2 127 | - murmurhash==1.0.1 128 | - opencc-python-reimplemented==0.1.4 129 | - opencv-python==3.4.3.18 130 | - palettable==3.1.1 131 | - pandas==0.23.4 132 | - pandas-summary==0.0.5 133 | - parso==0.3.1 134 | - patsy==0.5.1 135 | - pillow==5.3.0 136 | - plac==0.9.6 137 | - plotnine==0.5.1 138 | - preshed==2.0.1 139 | - prometheus-client==0.4.2 140 | - pyarrow==0.11.1 141 | - pyparsing==2.3.0 142 | - 
pytz==2018.7 143 | - pyyaml==3.13 144 | - regex==2018.1.10 145 | - requests==2.20.0 146 | - s3transfer==0.1.13 147 | - scikit-learn==0.20.0 148 | - scipy==1.1.0 149 | - seaborn==0.9.0 150 | - send2trash==1.5.0 151 | - sklearn==0.0 152 | - sklearn-pandas==1.7.0 153 | - smart-open==1.7.1 154 | - spacy==2.0.16 155 | - statsmodels==0.9.0 156 | - thinc==6.12.0 157 | - toolz==0.9.0 158 | - torch==0.3.0.post4 159 | - torchtext==0.2.3 160 | - torchvision==0.2.1 161 | - tqdm==4.19.6 162 | - ujson==1.35 163 | - urllib3==1.24.1 164 | - webencodings==0.5.1 165 | - wrapt==1.10.11 166 | prefix: /data/conda/fastai 167 | 168 | -------------------------------------------------------------------------------- /finetuning.py: -------------------------------------------------------------------------------- 1 | import fire 2 | from utils import * 3 | from fastai.text import * 4 | from fastai.lm_rnn import * 5 | 6 | 7 | class EarlyStopping(Callback): 8 | def __init__(self, learner, model_path, encoder_path, patience=5): 9 | super().__init__() 10 | self.learner = learner 11 | self.model_path = model_path 12 | self.encoder_path = encoder_path 13 | self.patience = patience 14 | 15 | def on_train_begin(self): 16 | self.best_validation_loss = 100 17 | self.num_epochs_no_improvement = 0 18 | 19 | def on_epoch_end(self, metrics): 20 | validation_loss = metrics[0] 21 | if validation_loss < self.best_validation_loss: 22 | self.best_validation_loss = validation_loss 23 | self.learner.save(self.model_path) 24 | self.learner.save_encoder(self.encoder_path) 25 | print('\nSaving best model') 26 | self.num_epochs_no_improvement = 0 27 | else: 28 | self.num_epochs_no_improvement += 1 29 | if self.num_epochs_no_improvement > self.patience: 30 | print(f'\nStopping - no improvement after {self.patience+1} epochs') 31 | return True 32 | 33 | def on_train_end(self): 34 | pass 35 | 36 | def finetune_language_model(input_file, mapping_file, dir_path, pretrain_model_file, pretrain_mapping_file, model_id, 37 | cuda_id=1, cycle_len=25, batch_size=64, 38 | dropout_multiply=1.0, learning_rate=4e-3): 39 | torch.cuda.set_device(cuda_id) 40 | 41 | bptt = 70 42 | embedding_size, n_hidden, n_layer = 400, 1150, 3 43 | opt_func = partial(optim.Adam, betas=(0.8, 0.99)) 44 | 45 | data = np.load(input_file) 46 | train_data = data[:-len(data) // 10] 47 | validation_data = data[-len(data) // 10:] 48 | 49 | train_data = np.concatenate(train_data) 50 | validation_data = np.concatenate(validation_data) 51 | 52 | itos = load_pickle(mapping_file) 53 | vocabulary_size = len(itos) 54 | 55 | train_data_loader = LanguageModelLoader(train_data, batch_size, bptt) 56 | validation_data_loader = LanguageModelLoader(validation_data, batch_size, bptt) 57 | model_data = LanguageModelData(Path(dir_path), 1, vocabulary_size, train_data_loader, validation_data_loader, bs=batch_size, bptt=bptt) 58 | 59 | dropouts = np.array([0.25, 0.1, 0.2, 0.02, 0.15]) * dropout_multiply 60 | 61 | learner = model_data.get_model(opt_func, embedding_size, n_hidden, n_layer, 62 | dropouti=dropouts[0], dropout=dropouts[1], wdrop=dropouts[2], dropoute=dropouts[3], dropouth=dropouts[4]) 63 | learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1) 64 | learner.clip = 0.3 65 | learner.metrics = [accuracy] 66 | weight_decay = 1e-7 67 | 68 | learning_rates = np.array([learning_rate / 6, learning_rate / 3, learning_rate, learning_rate / 2]) 69 | weights = torch.load(pretrain_model_file, map_location=lambda storage, loc: storage) 70 | encoder_weights = to_np(weights['0.encoder.weight']) 71 | row_mean 
= encoder_weights.mean(0) 72 | 73 | pretrain_itos = load_pickle(pretrain_mapping_file) 74 | pretrain_stoi = collections.defaultdict(lambda: -1, {v: k for k, v in enumerate(pretrain_itos)}) 75 | new_weights = np.zeros((vocabulary_size, embedding_size), dtype=np.float32) 76 | for i, word in enumerate(itos): 77 | _id = pretrain_stoi[word] 78 | if _id >= 0: 79 | new_weights[i] = encoder_weights[_id] 80 | else: 81 | new_weights[i] = row_mean 82 | weights['0.encoder.weight'] = T(new_weights) 83 | weights['0.encoder_with_dropout.embed.weight'] = T(np.copy(new_weights)) 84 | weights['1.decoder.weight'] = T(np.copy(new_weights)) 85 | learner.model.load_state_dict(weights) 86 | n_cycle = 1 87 | callbacks = [EarlyStopping(learner, f'{model_id}', f'{model_id}_enc', patience=5)] 88 | learner.fit(learning_rates, n_cycle, wds=weight_decay, use_clr=(32, 10), cycle_len=cycle_len, 89 | callbacks=callbacks) 90 | 91 | if __name__ == '__main__': fire.Fire(finetune_language_model) 92 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | from fastai.learner import * 2 | from fastai.text import * 3 | 4 | 5 | def get_probs(ids, vocabulary_size): 6 | counter = Counter(ids) 7 | counter = np.array([counter[i] for i in range(vocabulary_size)]) 8 | return counter / counter.sum() 9 | 10 | 11 | class LinearDecoder(nn.Module): 12 | init_range = 0.1 13 | 14 | def __init__(self, n_out, n_hid, dropout, tie_encoder=None, decode_train=True): 15 | super().__init__() 16 | self.decode_train = decode_train 17 | self.decoder = nn.Linear(n_hid, n_out, bias=False) 18 | self.decoder.weight.data.uniform_(-self.init_range, self.init_range) 19 | self.dropout = LockedDropout(dropout) 20 | if tie_encoder: 21 | self.decoder.weight = tie_encoder.weight 22 | 23 | def forward(self, inputs): 24 | raw_outputs, outputs = inputs 25 | output = self.dropout(outputs[-1]) 26 | output = output.view(output.size(0) * output.size(1), output.size(2)) 27 | if self.decode_train or not self.training: 28 | decoded = self.decoder(output) 29 | output = decoded.view(-1, decoded.size(1)) 30 | return output, raw_outputs, outputs 31 | 32 | 33 | def get_language_model(n_token, embedding_size, n_hid, n_layer, padding_token, decode_train=True, dropouts=None): 34 | if dropouts is None: 35 | dropouts = [0.5, 0.4, 0.5, 0.05, 0.3] 36 | enc = RNN_Encoder(n_token, embedding_size, nhid=n_hid, nlayers=n_layer, pad_token=padding_token, 37 | dropouti=dropouts[0], wdrop=dropouts[2], dropoute=dropouts[3], dropouth=dropouts[4]) 38 | dec = LinearDecoder(n_token, embedding_size, dropouts[1], decode_train=decode_train, 39 | tie_encoder=enc.encoder) 40 | return SequentialRNN(enc, dec) 41 | 42 | 43 | def pt_sample(probs, n): 44 | w = -torch.log(cuda.FloatTensor(len(probs)).uniform_()) / (probs + 1e-10) 45 | return torch.topk(w, n, largest=False)[1] 46 | 47 | 48 | class CrossEntropyDecoder(nn.Module): 49 | init_range = 0.1 50 | 51 | def __init__(self, probs, decoder, n_neg=4000, sampled=True): 52 | super().__init__() 53 | self.probs, self.decoder, self.sampled = T(probs).cuda(), decoder, sampled 54 | self.set_n_neg(n_neg) 55 | 56 | def set_n_neg(self, n_neg): 57 | self.n_neg = n_neg 58 | 59 | def get_random_indexes(self): 60 | return pt_sample(self.probs, self.n_neg) 61 | 62 | def sampled_softmax(self, input, target): 63 | idxs = V(self.get_random_indexes()) 64 | dw = self.decoder.weight 65 | output = input @ dw[idxs].t() 66 | max_output = output.max() 67 | output 
= output - max_output 68 | num = (dw[target] * input).sum(1) - max_output 69 | negs = torch.exp(num) + (torch.exp(output) * 2).sum(1) 70 | return (torch.log(negs) - num).mean() 71 | 72 | def forward(self, input, target): 73 | if self.decoder.training: 74 | if self.sampled: 75 | return self.sampled_softmax(input, target) 76 | else: 77 | input = self.decoder(input) 78 | return F.cross_entropy(input, target) 79 | 80 | 81 | def get_learner(dropouts, n_neg, sampled, model_data, embedding_size, n_hidden, n_layer, opt_func, probs): 82 | model = to_gpu(get_language_model(model_data.n_tok, embedding_size, n_hidden, n_layer, model_data.pad_idx, decode_train=False, dropouts=dropouts)) 83 | criterion = CrossEntropyDecoder(probs, model[1].decoder, n_neg=n_neg, sampled=sampled).cuda() 84 | learner = RNN_Learner(model_data, LanguageModel(model), opt_fn=opt_func) 85 | criterion.dw = learner.model[0].encoder.weight 86 | learner.crit = criterion 87 | learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1) 88 | learner.clip = 0.3 89 | return learner, criterion 90 | -------------------------------------------------------------------------------- /predicting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fastai.text import * 3 | from utils import * 4 | import fire 5 | from preprocessing import * 6 | 7 | 8 | def load_model(itos_filename, classifier_filename, num_class): 9 | itos = load_pickle(itos_filename) 10 | stoi = collections.defaultdict(lambda: 0, {str(v): int(k) for k, v in enumerate(itos)}) 11 | bptt, embedding_size, n_hidden, n_layer = 70, 400, 1150, 3 12 | dropouts = np.array([0.4, 0.5, 0.05, 0.3, 0.4]) * 0.5 13 | vocabulary_size = len(itos) 14 | 15 | model = get_rnn_classifer(bptt, 20 * 70, num_class, vocabulary_size, emb_sz=embedding_size, n_hid=n_hidden, 16 | n_layers=n_layer, 17 | pad_token=1, 18 | layers=[embedding_size * 3, 50, num_class], drops=[dropouts[4], 0.1], 19 | dropouti=dropouts[0], wdrop=dropouts[1], dropoute=dropouts[2], dropouth=dropouts[3]) 20 | 21 | model.load_state_dict(torch.load(classifier_filename, map_location=lambda storage, loc: storage)) 22 | model.reset() 23 | model.eval() 24 | 25 | return stoi, model 26 | 27 | 28 | def softmax(x): 29 | if x.ndim == 1: 30 | x = x.reshape((1, -1)) 31 | max_x = np.max(x, axis=1).reshape((-1, 1)) 32 | exp_x = np.exp(x - max_x) 33 | return exp_x / np.sum(exp_x, axis=1).reshape((-1, 1)) 34 | 35 | 36 | def predict_text(stoi, model, text): 37 | words = segment_line(text) 38 | ids = tokenize_words(stoi, words) 39 | array = np.reshape(np.array(ids), (-1, 1)) 40 | tensor = torch.from_numpy(array) 41 | variable = Variable(tensor) 42 | predictions = model(variable) 43 | numpy_prediction = predictions[0].data.numpy() 44 | return softmax(numpy_prediction[0])[0] 45 | 46 | 47 | def predict_input(mapping_file, classifier_filename, num_class=2): 48 | stoi, model = load_model(mapping_file, classifier_filename, num_class) 49 | while True: 50 | text = input("Text: ") 51 | scores = predict_text(stoi, model, text) 52 | print("Scores: {0}".format(scores)) 53 | 54 | 55 | if __name__ == '__main__': 56 | fire.Fire(predict_input) 57 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | from smart_open import smart_open 2 | from tqdm import tqdm 3 | from collections import Counter 4 | from opencc import OpenCC 5 | import json 6 | import pickle 7 | import jieba 8 | 
import click 9 | import re 10 | import numpy as np 11 | from pandas import read_csv 12 | import pandas as pd 13 | 14 | jieba.initialize() 15 | 16 | CC = OpenCC('t2s') 17 | REGEX = re.compile(r'[^\u4e00-\u9fa5aA-Za-z0-9]') 18 | 19 | UNK_ID = 0 20 | PAD_ID = 1 21 | BOS_ID = 2 22 | UNK = '_unk_' 23 | PAD = '_pad_' 24 | BOS = '_bos_' 25 | 26 | 27 | def segment_line(line): 28 | line = CC.convert(REGEX.sub(' ', line)) 29 | return list(filter(lambda x: x.strip(), jieba.cut(line))) 30 | 31 | 32 | def tokenize_words(stoi, words): 33 | return [BOS_ID] + [stoi.get(word, UNK_ID) for word in words] 34 | 35 | 36 | @click.command() 37 | @click.option('--input_file') 38 | @click.option('--output_file') 39 | def segment_wiki(input_file, output_file): 40 | with smart_open(input_file) as fin: 41 | with smart_open(output_file, 'wb') as fout: 42 | words = [] 43 | for line in tqdm(fin): 44 | article = json.loads(line) 45 | words.append(segment_line(article['title'])) 46 | for section_title, section_text in zip(article['section_titles'], article['section_texts']): 47 | words.append(segment_line(section_title)) 48 | for text in section_text.splitlines(): 49 | words.append(segment_line(text)) 50 | pickle.dump(words, fout) 51 | 52 | 53 | @click.command() 54 | @click.option('--input_file') 55 | @click.option('--output_file') 56 | @click.option('--label_file') 57 | def segment_csv(input_file, output_file, label_file): 58 | with smart_open(output_file, 'wb') as fout: 59 | df = pd.read_csv(input_file) 60 | np.save(label_file, df['label'].values) 61 | words = [] 62 | for line in tqdm(df['text']): 63 | words.append(segment_line(line)) 64 | pickle.dump(words, fout) 65 | 66 | 67 | @click.command() 68 | @click.option('--input_file') 69 | @click.option('--mapping_file') 70 | @click.option('--output_file') 71 | @click.option('--vocabulary_size', default=100000) 72 | @click.option('--min_word_count', default=2) 73 | def tokenize(input_file, mapping_file, output_file, vocabulary_size, min_word_count): 74 | counter = Counter() 75 | with smart_open(input_file) as fin: 76 | with smart_open(mapping_file, 'wb') as fmapping: 77 | total_words = pickle.load(fin) 78 | for words in tqdm(total_words): 79 | counter.update(words) 80 | stoi = {**{UNK: UNK_ID, PAD: PAD_ID, BOS: BOS_ID}, 81 | **{word: token + 3 for token, (word, count) in enumerate(counter.most_common(vocabulary_size)) if count > min_word_count}} 82 | itos = [UNK, PAD, BOS] + [word for word, _ in counter.most_common(vocabulary_size)] 83 | pickle.dump(itos, fmapping) 84 | total_ids = [] 85 | for words in tqdm(total_words): 86 | total_ids.append(tokenize_words(stoi, words)) 87 | np.save(output_file, np.array(total_ids)) 88 | 89 | 90 | @click.group() 91 | def entry_point(): 92 | pass 93 | 94 | 95 | entry_point.add_command(segment_wiki) 96 | entry_point.add_command(segment_csv) 97 | entry_point.add_command(tokenize) 98 | 99 | if __name__ == '__main__': 100 | entry_point() 101 | -------------------------------------------------------------------------------- /pretraining.py: -------------------------------------------------------------------------------- 1 | import fire 2 | from fastai.text import * 3 | from helper import * 4 | from utils import * 5 | 6 | 7 | def train_language_model(input_file, mapping_file, dir_path, model_id='wiki2018-11-14', cuda_id=1, cycle_len=12, batch_size=64, learning_rate=3e-4, 8 | sampled=True): 9 | torch.cuda.set_device(cuda_id) 10 | bptt = 70 11 | embedding_size, n_hidden, n_layer = 400, 1150, 3 12 | opt_func = partial(optim.Adam, betas=(0.8, 0.99)) 13 
| 14 | data = np.load(input_file) 15 | train_data = data[:-len(data) // 10] 16 | validation_data = data[-len(data) // 10:] 17 | 18 | train_data = np.concatenate(train_data) 19 | validation_data = np.concatenate(validation_data) 20 | 21 | itos = load_pickle(mapping_file) 22 | vocabulary_size = len(itos) 23 | 24 | train_data_loader = LanguageModelLoader(train_data, batch_size, bptt) 25 | validation_data_loader = LanguageModelLoader(validation_data, batch_size // 5 if sampled else batch_size, bptt) 26 | model_data = LanguageModelData(Path(dir_path), 1, vocabulary_size, train_data_loader, validation_data_loader, bs=batch_size, bptt=bptt) 27 | probs = get_probs(train_data, vocabulary_size) 28 | drops = np.array([0.25, 0.1, 0.2, 0.02, 0.15]) * 0.5 29 | learner, _ = get_learner(drops, 15000, sampled, model_data, embedding_size, n_hidden, n_layer, opt_func, probs) 30 | weight_decay = 1e-7 31 | learner.metrics = [accuracy] 32 | learning_rates = np.array([learning_rate / 6, learning_rate / 3, learning_rate, learning_rate]) 33 | learner.fit(learning_rates, 1, wds=weight_decay, use_clr=(32, 10), cycle_len=cycle_len) 34 | learner.save(f'{model_id}') 35 | learner.save_encoder(f'{model_id}_enc') 36 | 37 | 38 | if __name__ == '__main__': fire.Fire(train_language_model) 39 | -------------------------------------------------------------------------------- /train_classifier.py: -------------------------------------------------------------------------------- 1 | import fire 2 | from utils import * 3 | from fastai.text import * 4 | from fastai.lm_rnn import * 5 | from sklearn.model_selection import StratifiedShuffleSplit 6 | 7 | 8 | def freeze_all_but(learner, n): 9 | layer_groups = learner.get_layer_groups() 10 | for group in layer_groups: 11 | set_trainable(group, False) 12 | set_trainable(layer_groups[n], True) 13 | 14 | 15 | def train_classifier(id_file, label_file, mapping_file, encoder_file, dir_path='tmp', cuda_id=1, batch_size=64, 16 | cycle_len=15, 17 | learning_rate=0.01, dropout_multiply=1.0): 18 | torch.cuda.set_device(cuda_id) 19 | 20 | dir_path = Path(dir_path) 21 | intermediate_classifier_file = 'classifier_0' 22 | final_classifier_file = 'classifier_1' 23 | 24 | bptt, embedding_size, n_hidden, n_layer = 70, 400, 1150, 3 25 | opt_func = partial(optim.Adam, betas=(0.8, 0.99)) 26 | ids = np.load(id_file) 27 | labels = np.load(label_file) 28 | split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42) 29 | for train_index, test_index in split.split(ids, labels): 30 | train_ids, train_labels = ids[train_index], labels[train_index] 31 | validation_ids, validation_labels = ids[test_index], labels[test_index] 32 | 33 | train_labels = train_labels.flatten() 34 | validation_labels = validation_labels.flatten() 35 | train_labels -= train_labels.min() 36 | validation_labels -= validation_labels.min() 37 | label_count = int(train_labels.max()) + 1 38 | 39 | itos = load_pickle(mapping_file) 40 | vocabulary_size = len(itos) 41 | 42 | train_data_set = TextDataset(train_ids, train_labels) 43 | validation_data_set = TextDataset(validation_ids, validation_labels) 44 | train_sampler = SortishSampler(train_ids, key=lambda x: len(train_ids[x]), bs=batch_size // 2) 45 | validation_sampler = SortSampler(validation_ids, key=lambda x: len(validation_ids[x])) 46 | train_data_loader = DataLoader(train_data_set, batch_size // 2, transpose=True, num_workers=1, pad_idx=1, 47 | sampler=train_sampler) 48 | validation_data_loader = DataLoader(validation_data_set, batch_size, transpose=True, num_workers=1, 
pad_idx=1, 49 | sampler=validation_sampler) 50 | model_data = ModelData(dir_path, train_data_loader, validation_data_loader) 51 | 52 | dropouts = np.array([0.4, 0.5, 0.05, 0.3, 0.4]) * dropout_multiply 53 | 54 | model = get_rnn_classifer(bptt, 20 * bptt, label_count, vocabulary_size, emb_sz=embedding_size, n_hid=n_hidden, 55 | n_layers=n_layer, 56 | pad_token=1, 57 | layers=[embedding_size * 3, 50, label_count], drops=[dropouts[4], 0.1], 58 | dropouti=dropouts[0], wdrop=dropouts[1], dropoute=dropouts[2], dropouth=dropouts[3]) 59 | 60 | learn = RNN_Learner(model_data, TextModel(to_gpu(model)), opt_fn=opt_func) 61 | learn.reg_fn = partial(seq2seq_reg, alpha=2, beta=1) 62 | learn.clip = 25. 63 | learn.metrics = [accuracy] 64 | 65 | ratio = 2.6 66 | learning_rates = np.array([ 67 | learning_rate / (ratio ** 4), 68 | learning_rate / (ratio ** 3), 69 | learning_rate / (ratio ** 2), 70 | learning_rate / ratio, 71 | learning_rate]) 72 | 73 | weight_decay = 1e-6 74 | learn.load_encoder(encoder_file) 75 | 76 | learn.freeze_to(-1) 77 | learn.fit(learning_rates, 1, wds=weight_decay, cycle_len=1, use_clr=(8, 3)) 78 | learn.freeze_to(-2) 79 | learn.fit(learning_rates, 1, wds=weight_decay, cycle_len=1, use_clr=(8, 3)) 80 | learn.save(intermediate_classifier_file) 81 | 82 | learn.unfreeze() 83 | n_cycle = 1 84 | learn.fit(learning_rates, n_cycle, wds=weight_decay, cycle_len=cycle_len, use_clr=(8, 8)) 85 | learn.save(final_classifier_file) 86 | 87 | 88 | if __name__ == '__main__': fire.Fire(train_classifier) 89 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | def load_pickle(path): 5 | with open(path, 'rb') as f: 6 | return pickle.load(f) 7 | --------------------------------------------------------------------------------
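
The README pipeline above is driven entirely from the command line, and `predicting.py` only exposes an interactive loop. For batch scoring it can be convenient to call the same functions from Python. The following is a minimal sketch of that idea — it is not a file in this repository; the file paths (`data/ch_auto.csv`, `tmp/ch_auto.mapping.pkl`, `tmp/models/classifier_1.h5`) are simply the artifacts produced by the README steps, so adjust them if yours live elsewhere.

```python
# batch_predict.py -- hypothetical helper, not part of this repository.
# It reuses load_model / predict_text from predicting.py to score a whole CSV,
# assuming the mapping file and classifier weights from the README pipeline exist.
import numpy as np
import pandas as pd

from predicting import load_model, predict_text


def predict_csv(input_csv='data/ch_auto.csv',
                mapping_file='tmp/ch_auto.mapping.pkl',
                classifier_file='tmp/models/classifier_1.h5',
                num_class=2):
    # load_model returns the word-to-id mapping and the RNN classifier in eval mode
    stoi, model = load_model(mapping_file, classifier_file, num_class)
    df = pd.read_csv(input_csv)
    # predict_text segments, tokenizes, and scores one string, returning a
    # softmax distribution over the num_class labels
    scores = np.stack([predict_text(stoi, model, text) for text in df['text']])
    df['prediction'] = scores.argmax(axis=1)
    return df


if __name__ == '__main__':
    print(predict_csv().head())
```

If a decision threshold other than the argmax is needed, the per-class probabilities can be kept as well (for example `df['score_1'] = scores[:, 1]` for the positive class).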