├── .idea
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   ├── transformer-xl.iml
│   ├── vcs.xml
│   ├── webServers.xml
│   └── workspace.xml
├── README.md
├── code_for_Classfy
│   ├── data_utils.py
│   ├── eval.py
│   ├── mem_transformer.py
│   ├── myattention.py
│   ├── test.py
│   ├── train.py
│   ├── utils
│   │   ├── adaptive_softmax.py
│   │   ├── data_parallel.py
│   │   ├── exp_utils.py
│   │   ├── log_uniform_sampler.py
│   │   ├── proj_adaptive_softmax.py
│   │   └── vocabulary.py
│   └── 运行命令
├── code_for_LM
│   ├── .DS_Store
│   ├── data_utils.py
│   ├── eval.py
│   ├── mem_transformer.py
│   ├── train.py
│   ├── utils
│   │   ├── adaptive_softmax.py
│   │   ├── data_parallel.py
│   │   ├── exp_utils.py
│   │   ├── log_uniform_sampler.py
│   │   ├── proj_adaptive_softmax.py
│   │   └── vocabulary.py
│   └── 运行命令
├── image
│   └── model.jpg
├── results
│   └── LM
│       └── vocab.txt
└── 说明

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## Introduction
Purpose: a long-document text classifier built on Transformer-XL. Classification is done at the document (article) level; each document is 3,000 tokens long (already word-segmented).
Training scheme: pre-training followed by fine-tuning.
Framework: PyTorch 1.0.

## Classifier architecture
![image](https://github.com/xuhaiming1996/Transformer-xl-for-Dcument-Classify/blob/master/image/model.jpg)

## File descriptions
### data
#### /data/LM
train.txt holds the word-segmented pre-training corpus, one document per line.
#### /data/CLASSIFY
train.label: one document label per line
train.txt: one word-segmented document per line
valid.label: one document label per line
valid.txt: one word-segmented document per line

### code_for_LM
#### /code_for_LM/运行命令
The commands I used for pre-training; adjust them to your own setup.
Tip: use 4-GPU data parallelism if possible.

### code_for_Classfy
#### /code_for_Classfy/运行命令
The commands for fine-tuning; adjust them to your own setup.
Tip: use 4-GPU data parallelism if possible.

### results
#### /results/LM
Where pre-training outputs are saved: model parameters, the vocabulary, and so on.
#### /results/CLASSIFY
Where the trained classifier is saved: model parameters, etc.
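The exact pre-training and fine-tuning command lines live in the two 运行命令 files and are not reproduced here. For orientation only, the sketch below shows how the fine-tuning data pipeline in `code_for_Classfy/data_utils.py` is driven (this is what `train.py` does internally); the relative paths and batch size are assumptions based on the directory layout above.

```python
# Minimal sketch (assumed paths); run from inside code_for_Classfy/ so that
# data_utils and the utils/ package are importable.
from data_utils import get_lm_corpus

corpus = get_lm_corpus('../data/CLASSIFY',          # expects train.txt/train.label, valid.txt/valid.label
                       '../results/LM/vocab.txt',   # vocabulary written during pre-training
                       3000)                        # alinlen: fixed, word-segmented document length
train_iter = corpus.get_batch_iterator('train', 16, alianlen=3000, device='cpu')

for data, labels, bsz_len in train_iter:
    # data:   LongTensor [bsz_len, 3000] of token ids for one mini-batch of documents
    # labels: LongTensor [bsz_len] with one class label per document
    pass
```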
--------------------------------------------------------------------------------
/code_for_Classfy/data_utils.py:
--------------------------------------------------------------------------------
 1 | import os, sys
 2 | import glob
 3 | from collections import Counter, OrderedDict
 4 | import numpy as np
 5 | import torch
 6 | 
 7 | from utils.vocabulary import Vocab
 8 | 
 9 | 
10 | class BatchIteratorHelper:
11 |     def __init__(self, data, lables, bsz, alianlen=3000, device='cpu'):
12 |         '''
13 | 
14 |         :param data: [num_examples, 3000]
15 |         :param lables: [num_examples]
16 |         :param bsz: batch size
17 |         :param alianlen: 3000 (fixed document length)
18 |         :param device:
19 |         '''
20 | 
21 |         self.bsz = bsz
22 |         self.num_example = data.size(0)  # number of examples
23 |         self.device = device
24 |         self.data = data.contiguous().to(device)
25 |         # Number of mini-batches
26 |         self.n_batch = (self.num_example + self.bsz - 1) // self.bsz  # number of batches
27 |         self.lables = lables.contiguous().to(device)
28 | 
29 |         print(self.data.size())
30 |         # sanity checks
31 |         assert self.data.size(1) == alianlen
32 |         assert self.data.size(0) == self.lables.size(0)
33 | 
34 | 
35 |     def get_batch(self, i, bsz=None):
36 |         if bsz is None:
37 |             bsz = self.bsz
38 |         bsz_len = min(bsz, self.data.size(0) - 1 - i)  # the "- 1" is inherited from the LM iterator; it means the final example is never yielded
39 |         end_idx = i + bsz_len
40 |         beg_idx = i
41 |         data = self.data[beg_idx:end_idx]
42 | 
43 |         labels = self.lables[beg_idx:end_idx]
44 | 
45 |         return data, labels, bsz_len
46 | 
47 |     def get_fixlen_iter(self, start=0):
48 |         for i in range(start, self.data.size(0) - 1, self.bsz):
49 |             yield self.get_batch(i)
50 | 
51 | 
52 |     def __iter__(self):
53 |         return self.get_fixlen_iter()
54 | 
55 | 
56 | 
57 | 
58 | class LMOrderedIterator(object):
59 |     def __init__(self, data, bsz, bptt, alianlen=3000, device='cpu', ext_len=None):
60 |         '''
61 |         :param data:
62 |         :param bsz: batch_size
63 |         :param bptt: tgt_len
64 |         :param device:
65 |         :param ext_len: 0
66 |         '''
67 |         """
68 |             data -- LongTensor -- the LongTensor is strictly ordered
69 |         """
70 |         self.bsz = bsz
71 |         self.bptt = bptt
72 |         self.ext_len = ext_len if ext_len is not None else 0
73 | 
74 |         self.device = device
75 | 
76 |         # Work out how cleanly we can divide the dataset into bsz parts.
77 |         # self.n_step = data.size(0) // bsz
78 | 
79 |         self.n_step = alianlen
80 | 
81 | 
82 |         # # Trim off any extra elements that wouldn't cleanly fit (remainders).
83 |         # data = data.narrow(0, 0, self.n_step * bsz)
84 | 
85 |         # Evenly divide the data across the bsz batches.
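        # In this classifier variant `data` is one mini-batch of documents shaped [bsz, alianlen];
        # data.t() below flips it to [alianlen, bsz] so that dim 0 is the 3000-token sequence axis,
        # which get_batch() then slices into bptt-sized segments that are fed to the Transformer-XL
        # memory one after another.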
86 | self.data = data.t().contiguous().to(device) 87 | # print(self.data.size()) 88 | # print(self.bsz) 89 | # print(self.n_step) 90 | # self.data = data.view(bsz, -1).t().contiguous().to(device) 91 | # Number of mini-batches 92 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt # 这里不是一步一步的 93 | 94 | # 下面开始验证 95 | 96 | assert self.data.size(0)== self.n_step 97 | assert self.data.size(1) == self.bsz 98 | 99 | 100 | def get_batch(self, i, bptt=None): 101 | if bptt is None: 102 | bptt = self.bptt 103 | seq_len = min(bptt, self.data.size(0) - 1 - i) 104 | end_idx = i + seq_len 105 | beg_idx = max(0, i - self.ext_len) 106 | data = self.data[beg_idx:end_idx] 107 | 108 | return data, seq_len 109 | 110 | def get_fixlen_iter(self, start=0): 111 | for i in range(start, self.data.size(0) - 1, self.bptt): 112 | yield self.get_batch(i) 113 | 114 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 115 | max_len = self.bptt + max_deviation * std 116 | i = start 117 | while True: 118 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. 119 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 120 | data, target, seq_len = self.get_batch(i, bptt) 121 | i += seq_len 122 | yield data, target, seq_len 123 | if i >= self.data.size(0) - 2: 124 | break 125 | 126 | def __iter__(self): 127 | return self.get_fixlen_iter() 128 | 129 | 130 | 131 | 132 | # 许海明 133 | class Corpus(object): 134 | def __init__(self, path, *args, **kwargs): 135 | self.vocab = Vocab(*args, **kwargs) 136 | 137 | # 从单词表里面加载单词 138 | self.vocab.build_vocab() 139 | 140 | # 训练集 141 | self.train = self.vocab.encode_file( os.path.join(path, 'train.txt'), verbose=True) 142 | self.train_label = self.vocab.encode_file_only_for_lables(os.path.join(path, 'train.label'), verbose=True) 143 | 144 | # 验证集 145 | self.valid = self.vocab.encode_file(os.path.join(path, 'valid.txt'), verbose=True) 146 | self.valid_label = self.vocab.encode_file_only_for_lables(os.path.join(path, 'valid.label'), verbose=True) 147 | 148 | 149 | # self.test = self.vocab.encode_file( 150 | # os.path.join(path, 'test.txt'), ordered=True) 151 | 152 | # 许海明 153 | def get_batch_iterator(self, split, *args, **kwargs): 154 | ''' 155 | 156 | :param split: 157 | :param args: 158 | :param kwargs: 159 | :return: 160 | ''' 161 | if split == 'train': 162 | # data_iter = LMOrderedIterator(self.train, *args, **kwargs) 163 | batch_iter = BatchIteratorHelper(self.train,self.train_label, *args, **kwargs) 164 | 165 | elif split == 'valid': 166 | batch_iter = BatchIteratorHelper(self.valid, self.valid_label, *args, **kwargs) 167 | 168 | return batch_iter 169 | 170 | 171 | def get_lm_corpus(datadir,vocab_file,alinlen): 172 | fn = os.path.join(datadir, 'cache.pt') 173 | if os.path.exists(fn): 174 | print('Loading cached dataset...') 175 | corpus = torch.load(fn) 176 | else: 177 | print('Producing dataset {}...'.format(datadir)) 178 | kwargs = {} 179 | 180 | kwargs['special'] = ['','', '', ''] 181 | kwargs['lower_case'] = False 182 | kwargs['vocab_file'] = vocab_file 183 | 184 | 185 | 186 | corpus = Corpus(datadir,alinlen ,**kwargs) 187 | torch.save(corpus, fn) # 这里保存的是一个类的对象 188 | 189 | return corpus 190 | 191 | if __name__ == '__main__': 192 | import argparse 193 | parser = argparse.ArgumentParser(description='unit test') 194 | parser.add_argument('--datadir', type=str, default='../data/enwik8', 195 | help='location of the data corpus') 196 | parser.add_argument('--dataset', type=str, default='enwik8', 197 | choices=['ptb', 'wt2', 'wt103', 
'lm1b', 'enwik8', 'text8'], 198 | help='dataset name') 199 | args = parser.parse_args() 200 | 201 | corpus = get_lm_corpus(args.datadir, args.dataset) 202 | print('Vocab size : {}'.format(len(corpus.vocab.idx2sym))) 203 | tr_iter = corpus.get_iterator('train', 22, 512) 204 | ## 205 | # 许海明 许海明 许海明 许海明 许海明 206 | # 许海明 207 | # 208 | # 209 | # ## 210 | -------------------------------------------------------------------------------- /code_for_Classfy/eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import time 4 | import math 5 | import os, sys 6 | 7 | import torch 8 | 9 | from data_utils import get_lm_corpus 10 | from mem_transformer import MemTransformerLM 11 | from utils.exp_utils import get_logger 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 14 | parser.add_argument('--data', type=str, default='../data/wikitext-103', 15 | help='location of the data corpus') 16 | parser.add_argument('--dataset', type=str, default='wt103', 17 | choices=['wt103', 'lm1b', 'enwik8', 'text8'], 18 | help='dataset name') 19 | parser.add_argument('--split', type=str, default='all', 20 | choices=['all', 'valid', 'test'], 21 | help='which split to evaluate') 22 | parser.add_argument('--batch_size', type=int, default=10, 23 | help='batch size') 24 | parser.add_argument('--tgt_len', type=int, default=5, 25 | help='number of tokens to predict') 26 | parser.add_argument('--ext_len', type=int, default=0, 27 | help='length of the extended context') 28 | parser.add_argument('--mem_len', type=int, default=0, 29 | help='length of the retained previous heads') 30 | parser.add_argument('--clamp_len', type=int, default=-1, 31 | help='max positional embedding index') 32 | parser.add_argument('--cuda', action='store_true', 33 | help='use CUDA') 34 | parser.add_argument('--work_dir', type=str, required=True, 35 | help='path to the work_dir') 36 | parser.add_argument('--no_log', action='store_true', 37 | help='do not log the eval result') 38 | parser.add_argument('--same_length', action='store_true', 39 | help='set same length attention with masking') 40 | args = parser.parse_args() 41 | assert args.ext_len >= 0, 'extended context length must be non-negative' 42 | 43 | device = torch.device("cuda" if args.cuda else "cpu") 44 | 45 | # Get logger 46 | logging = get_logger(os.path.join(args.work_dir, 'log.txt'), 47 | log_=not args.no_log) 48 | 49 | # Load dataset 50 | corpus = get_lm_corpus(args.data, args.dataset) 51 | ntokens = len(corpus.vocab) 52 | 53 | va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len, 54 | device=device, ext_len=args.ext_len) 55 | te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len, 56 | device=device, ext_len=args.ext_len) 57 | 58 | # Load the best saved model. 
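# torch.load below restores the entire pickled MemTransformerLM checkpoint (model.pt);
# backward_compatible() simply resets sample_softmax = -1 for checkpoints saved by older code.
# Note that this script still targets the LM-side corpus API (get_lm_corpus(data, dataset) and
# corpus.get_iterator), not the classifier's get_lm_corpus(datadir, vocab_file, alinlen) /
# get_batch_iterator interface defined in data_utils.py above.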
59 | with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: 60 | model = torch.load(f) 61 | model.backward_compatible() 62 | model = model.to(device) 63 | 64 | logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( 65 | args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len)) 66 | 67 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 68 | if args.clamp_len > 0: 69 | model.clamp_len = args.clamp_len 70 | if args.same_length: 71 | model.same_length = True 72 | 73 | ############################################################################### 74 | # Evaluation code 75 | ############################################################################### 76 | def evaluate(eval_iter): 77 | # Turn on evaluation mode which disables dropout. 78 | model.eval() 79 | total_len, total_loss = 0, 0. 80 | start_time = time.time() 81 | with torch.no_grad(): 82 | mems = tuple() 83 | for idx, (data, target, seq_len) in enumerate(eval_iter): 84 | ret = model(data, target, *mems) 85 | loss, mems = ret[0], ret[1:] 86 | loss = loss.mean() 87 | total_loss += seq_len * loss.item() 88 | total_len += seq_len 89 | total_time = time.time() - start_time 90 | logging('Time : {:.2f}s, {:.2f}ms/segment'.format( 91 | total_time, 1000 * total_time / (idx+1))) 92 | return total_loss / total_len 93 | 94 | # Run on test data. 95 | if args.split == 'all': 96 | test_loss = evaluate(te_iter) 97 | valid_loss = evaluate(va_iter) 98 | elif args.split == 'valid': 99 | valid_loss = evaluate(va_iter) 100 | test_loss = None 101 | elif args.split == 'test': 102 | test_loss = evaluate(te_iter) 103 | valid_loss = None 104 | 105 | def format_log(loss, split): 106 | if args.dataset in ['enwik8', 'text8']: 107 | log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format( 108 | split, loss, loss / math.log(2)) 109 | else: 110 | log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format( 111 | split, loss, math.exp(loss)) 112 | return log_str 113 | 114 | log_str = '' 115 | if valid_loss is not None: 116 | log_str += format_log(valid_loss, 'valid') 117 | if test_loss is not None: 118 | log_str += format_log(test_loss, 'test') 119 | 120 | logging('=' * 100) 121 | logging(log_str) 122 | logging('=' * 100) 123 | -------------------------------------------------------------------------------- /code_for_Classfy/mem_transformer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import functools 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from data_utils import LMOrderedIterator 12 | 13 | sys.path.append('utils') 14 | from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax 15 | from log_uniform_sampler import LogUniformSampler, sample_logits 16 | 17 | class PositionalEmbedding(nn.Module): 18 | def __init__(self, demb): 19 | super(PositionalEmbedding, self).__init__() 20 | 21 | self.demb = demb 22 | 23 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 24 | self.register_buffer('inv_freq', inv_freq) 25 | 26 | def forward(self, pos_seq, bsz=None): 27 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 28 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 29 | 30 | if bsz is not None: 31 | return pos_emb[:,None,:].expand(-1, bsz, -1) 32 | else: 33 | return pos_emb[:,None,:] 34 | 35 | 36 | class PositionwiseFF(nn.Module): 37 | def __init__(self, d_model, d_inner, dropout, pre_lnorm = False): 38 | 
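        # pre_lnorm=True applies LayerNorm to the sublayer input ("pre-norm"); the default
        # pre_lnorm=False keeps the original post-norm arrangement LayerNorm(inp + sublayer(inp)).
        # The same switch is reused by every attention module in this file.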
super(PositionwiseFF, self).__init__() 39 | 40 | self.d_model = d_model 41 | self.d_inner = d_inner 42 | self.dropout = dropout 43 | 44 | self.CoreNet = nn.Sequential( 45 | nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), 46 | nn.Dropout(dropout), 47 | nn.Linear(d_inner, d_model), 48 | nn.Dropout(dropout), 49 | ) 50 | 51 | self.layer_norm = nn.LayerNorm(d_model) 52 | 53 | self.pre_lnorm = pre_lnorm 54 | 55 | def forward(self, inp): 56 | if self.pre_lnorm: 57 | ##### layer normalization + positionwise feed-forward 58 | core_out = self.CoreNet(self.layer_norm(inp)) 59 | 60 | ##### residual connection 61 | output = core_out + inp 62 | else: 63 | ##### positionwise feed-forward 64 | core_out = self.CoreNet(inp) 65 | 66 | ##### residual connection + layer normalization 67 | output = self.layer_norm(inp + core_out) 68 | 69 | return output 70 | 71 | class MultiHeadAttn(nn.Module): 72 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, 73 | pre_lnorm=False): 74 | super(MultiHeadAttn, self).__init__() 75 | 76 | self.n_head = n_head 77 | self.d_model = d_model 78 | self.d_head = d_head 79 | self.dropout = dropout 80 | 81 | self.q_net = nn.Linear(d_model, n_head * d_head, bias=False) 82 | self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False) 83 | 84 | self.drop = nn.Dropout(dropout) 85 | self.dropatt = nn.Dropout(dropatt) 86 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) 87 | 88 | self.layer_norm = nn.LayerNorm(d_model) 89 | 90 | self.scale = 1 / (d_head ** 0.5) 91 | 92 | self.pre_lnorm = pre_lnorm 93 | 94 | def forward(self, h, attn_mask=None, mems=None): 95 | ##### multihead attention 96 | # [hlen x bsz x n_head x d_head] 97 | 98 | if mems is not None: 99 | c = torch.cat([mems, h], 0) 100 | else: 101 | c = h 102 | 103 | if self.pre_lnorm: 104 | ##### layer normalization 105 | c = self.layer_norm(c) 106 | 107 | head_q = self.q_net(h) 108 | head_k, head_v = torch.chunk(self.kv_net(c), 2, -1) 109 | 110 | head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head) 111 | head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head) 112 | head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head) 113 | 114 | # [qlen x klen x bsz x n_head] 115 | attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k)) 116 | attn_score.mul_(self.scale) 117 | if attn_mask is not None and attn_mask.any().item(): 118 | if attn_mask.dim() == 2: 119 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) 120 | elif attn_mask.dim() == 3: 121 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) 122 | 123 | # [qlen x klen x bsz x n_head] 124 | attn_prob = F.softmax(attn_score, dim=1) 125 | attn_prob = self.dropatt(attn_prob) 126 | 127 | # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head] 128 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v)) 129 | attn_vec = attn_vec.contiguous().view( 130 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 131 | 132 | ##### linear projection 133 | attn_out = self.o_net(attn_vec) 134 | attn_out = self.drop(attn_out) 135 | 136 | if self.pre_lnorm: 137 | ##### residual connection 138 | output = h + attn_out 139 | else: 140 | ##### residual connection + layer normalization 141 | output = self.layer_norm(h + attn_out) 142 | 143 | return output 144 | 145 | 146 | 147 | class RelMultiHeadAttn(nn.Module): 148 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, 149 | tgt_len=None, ext_len=None, mem_len=None, 
pre_lnorm=False): 150 | super(RelMultiHeadAttn, self).__init__() 151 | 152 | self.n_head = n_head 153 | self.d_model = d_model 154 | self.d_head = d_head 155 | self.dropout = dropout 156 | 157 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) 158 | 159 | self.drop = nn.Dropout(dropout) 160 | self.dropatt = nn.Dropout(dropatt) 161 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) 162 | 163 | self.layer_norm = nn.LayerNorm(d_model) 164 | 165 | self.scale = 1 / (d_head ** 0.5) 166 | 167 | self.pre_lnorm = pre_lnorm 168 | 169 | def _parallelogram_mask(self, h, w, left=False): 170 | mask = torch.ones((h, w)).byte() 171 | m = min(h, w) 172 | mask[:m,:m] = torch.triu(mask[:m,:m]) 173 | mask[-m:,-m:] = torch.tril(mask[-m:,-m:]) 174 | 175 | if left: 176 | return mask 177 | else: 178 | return mask.flip(0) 179 | 180 | def _shift(self, x, qlen, klen, mask, left=False): 181 | if qlen > 1: 182 | zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)), 183 | device=x.device, dtype=x.dtype) 184 | else: 185 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) 186 | 187 | if left: 188 | mask = mask.flip(1) 189 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1) 190 | else: 191 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1) 192 | 193 | x = x_padded.masked_select(mask[:,:,None,None]) \ 194 | .view(qlen, klen, x.size(2), x.size(3)) 195 | 196 | return x 197 | 198 | def _rel_shift(self, x, zero_triu=False): 199 | zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), 200 | device=x.device, dtype=x.dtype) 201 | x_padded = torch.cat([zero_pad, x], dim=1) 202 | 203 | x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) 204 | 205 | x = x_padded[1:].view_as(x) 206 | 207 | if zero_triu: 208 | ones = torch.ones((x.size(0), x.size(1))) 209 | x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None] 210 | 211 | return x 212 | 213 | def forward(self, w, r, attn_mask=None, mems=None): 214 | raise NotImplementedError 215 | 216 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): 217 | def __init__(self, *args, **kwargs): 218 | super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs) 219 | 220 | self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) 221 | 222 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): 223 | qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) 224 | 225 | if mems is not None: 226 | cat = torch.cat([mems, w], 0) 227 | if self.pre_lnorm: 228 | w_heads = self.qkv_net(self.layer_norm(cat)) 229 | else: 230 | w_heads = self.qkv_net(cat) 231 | r_head_k = self.r_net(r) 232 | 233 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 234 | w_head_q = w_head_q[-qlen:] 235 | else: 236 | if self.pre_lnorm: 237 | w_heads = self.qkv_net(self.layer_norm(w)) 238 | else: 239 | w_heads = self.qkv_net(w) 240 | r_head_k = self.r_net(r) 241 | 242 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 243 | 244 | klen = w_head_k.size(0) 245 | 246 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head 247 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head 248 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head 249 | 250 | r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # rlen x n_head x d_head 251 | 252 | #### compute attention score 253 | rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head 254 
| AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head 255 | 256 | rr_head_q = w_head_q + r_r_bias 257 | BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head 258 | BD = self._rel_shift(BD) 259 | 260 | # [qlen x klen x bsz x n_head] 261 | attn_score = AC + BD 262 | attn_score.mul_(self.scale) 263 | 264 | #### compute attention probability 265 | if attn_mask is not None and attn_mask.any().item(): 266 | if attn_mask.dim() == 2: 267 | attn_score = attn_score.float().masked_fill( 268 | attn_mask[None,:,:,None], -float('inf')).type_as(attn_score) 269 | elif attn_mask.dim() == 3: 270 | attn_score = attn_score.float().masked_fill( 271 | attn_mask[:,:,:,None], -float('inf')).type_as(attn_score) 272 | 273 | # [qlen x klen x bsz x n_head] 274 | attn_prob = F.softmax(attn_score, dim=1) 275 | attn_prob = self.dropatt(attn_prob) 276 | 277 | #### compute attention vector 278 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) 279 | 280 | # [qlen x bsz x n_head x d_head] 281 | attn_vec = attn_vec.contiguous().view( 282 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 283 | 284 | ##### linear projection 285 | attn_out = self.o_net(attn_vec) 286 | attn_out = self.drop(attn_out) 287 | 288 | if self.pre_lnorm: 289 | ##### residual connection 290 | output = w + attn_out 291 | else: 292 | ##### residual connection + layer normalization 293 | output = self.layer_norm(w + attn_out) 294 | 295 | return output 296 | 297 | class RelLearnableMultiHeadAttn(RelMultiHeadAttn): 298 | def __init__(self, *args, **kwargs): 299 | super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs) 300 | 301 | def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None): 302 | # r_emb: [klen, n_head, d_head], used for term B 303 | # r_w_bias: [n_head, d_head], used for term C 304 | # r_bias: [klen, n_head], used for term D 305 | 306 | qlen, bsz = w.size(0), w.size(1) 307 | 308 | if mems is not None: 309 | cat = torch.cat([mems, w], 0) 310 | if self.pre_lnorm: 311 | w_heads = self.qkv_net(self.layer_norm(cat)) 312 | else: 313 | w_heads = self.qkv_net(cat) 314 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 315 | 316 | w_head_q = w_head_q[-qlen:] 317 | else: 318 | if self.pre_lnorm: 319 | w_heads = self.qkv_net(self.layer_norm(w)) 320 | else: 321 | w_heads = self.qkv_net(w) 322 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 323 | 324 | klen = w_head_k.size(0) 325 | 326 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) 327 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) 328 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) 329 | 330 | if klen > r_emb.size(0): 331 | r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1) 332 | r_emb = torch.cat([r_emb_pad, r_emb], 0) 333 | r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1) 334 | r_bias = torch.cat([r_bias_pad, r_bias], 0) 335 | else: 336 | r_emb = r_emb[-klen:] 337 | r_bias = r_bias[-klen:] 338 | 339 | #### compute attention score 340 | rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head 341 | 342 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head 343 | B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head 344 | D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head 345 | BD = self._rel_shift(B_ + D_) 346 | 347 | # [qlen x klen x bsz x n_head] 348 | attn_score = AC + BD 349 | attn_score.mul_(self.scale) 
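        # At this point attn_score = (AC + BD) * scale, i.e. divided by sqrt(d_head): AC holds the
        # content terms (query plus the shared r_w_bias against the keys), while BD holds the position
        # terms built from the learnable per-layer r_emb / r_bias and realigned by _rel_shift to the
        # query-key relative offsets.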
350 | 351 | #### compute attention probability 352 | if attn_mask is not None and attn_mask.any().item(): 353 | if attn_mask.dim() == 2: 354 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) 355 | elif attn_mask.dim() == 3: 356 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) 357 | 358 | # [qlen x klen x bsz x n_head] 359 | attn_prob = F.softmax(attn_score, dim=1) 360 | attn_prob = self.dropatt(attn_prob) 361 | 362 | #### compute attention vector 363 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) 364 | 365 | # [qlen x bsz x n_head x d_head] 366 | attn_vec = attn_vec.contiguous().view( 367 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 368 | 369 | ##### linear projection 370 | attn_out = self.o_net(attn_vec) 371 | attn_out = self.drop(attn_out) 372 | 373 | if self.pre_lnorm: 374 | ##### residual connection 375 | output = w + attn_out 376 | else: 377 | ##### residual connection + layer normalization 378 | output = self.layer_norm(w + attn_out) 379 | 380 | return output 381 | 382 | 383 | 384 | 385 | class DecoderLayer(nn.Module): 386 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): 387 | super(DecoderLayer, self).__init__() 388 | 389 | self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) 390 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 391 | pre_lnorm=kwargs.get('pre_lnorm')) 392 | 393 | def forward(self, dec_inp, dec_attn_mask=None, mems=None): 394 | 395 | output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, 396 | mems=mems) 397 | output = self.pos_ff(output) 398 | 399 | return output 400 | 401 | class RelLearnableDecoderLayer(nn.Module): 402 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, 403 | **kwargs): 404 | super(RelLearnableDecoderLayer, self).__init__() 405 | 406 | self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, 407 | **kwargs) 408 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 409 | pre_lnorm=kwargs.get('pre_lnorm')) 410 | 411 | def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None): 412 | 413 | output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias, 414 | attn_mask=dec_attn_mask, 415 | mems=mems) 416 | output = self.pos_ff(output) 417 | 418 | return output 419 | 420 | 421 | 422 | 423 | class RelPartialLearnableDecoderLayer(nn.Module): 424 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, 425 | **kwargs): 426 | super(RelPartialLearnableDecoderLayer, self).__init__() 427 | 428 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, 429 | d_head, dropout, **kwargs) 430 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 431 | pre_lnorm=kwargs.get('pre_lnorm')) 432 | 433 | def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): 434 | 435 | output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias, 436 | attn_mask=dec_attn_mask, 437 | mems=mems) 438 | output = self.pos_ff(output) 439 | 440 | return output 441 | 442 | 443 | 444 | class MyAttention(nn.Module): 445 | def __init__(self,hidden_size,attention_size): 446 | super(MyAttention, self).__init__() 447 | self.hidden_size=hidden_size 448 | self.attention_size = attention_size 449 | 450 | self.W_omega = nn.Parameter(torch.Tensor(self.attention_size,self.hidden_size)) 451 | self.b_omega = nn.Parameter(torch.Tensor(self.attention_size)) 452 | self.u_omega = nn.Parameter(torch.Tensor(1,self.attention_size)) 453 | 454 | def forward(self, inputs, time_major=True, 
return_alphas=False): 455 | ''' 456 | 457 | :param inputs: 这里的inputs 的shape必须是[batch,len,hididen] 458 | :param time_major: 459 | :param return_alphas: 460 | :return: 461 | ''' 462 | # v = torch.tanh(torch.matmul(torch.reshape(inputs, [-1, hidden_size]), W_omega) + tf.reshape(b_omega, [1, -1])) 463 | 464 | v = torch.tanh(F.linear(inputs, self.W_omega, self.b_omega)) #[B, L, attention_size] 465 | vu = F.linear(v, self.u_omega) #[B,L,1] 466 | vu = torch.squeeze(vu) 467 | exps = torch.exp(vu) 468 | 469 | alphas = exps /(torch.sum(input=exps,dim=1, keepdim=True)+0.00001) 470 | 471 | # Output of Bi-RNN is reduced with attention vector 472 | output = torch.sum(inputs * torch.unsqueeze(alphas, 2), 1) 473 | return output 474 | 475 | 476 | # 许海明 477 | class AdaptiveEmbedding(nn.Module): 478 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 479 | sample_softmax=False): 480 | 481 | ''' 482 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, div_val=div_val) 483 | :param n_token: 484 | :param d_embed: 485 | :param d_proj: 486 | :param cutoffs: 487 | :param div_val: 488 | :param sample_softmax: 489 | ''' 490 | super(AdaptiveEmbedding, self).__init__() 491 | 492 | self.n_token = n_token 493 | self.d_embed = d_embed 494 | 495 | self.cutoffs = cutoffs + [n_token] 496 | self.div_val = div_val 497 | self.d_proj = d_proj 498 | 499 | self.emb_scale = d_proj ** 0.5 500 | 501 | self.cutoff_ends = [0] + self.cutoffs 502 | 503 | self.emb_layers = nn.ModuleList() 504 | self.emb_projs = nn.ParameterList() 505 | if div_val == 1: 506 | self.emb_layers.append( 507 | nn.Embedding(n_token, d_embed, sparse=sample_softmax>0) 508 | ) 509 | if d_proj != d_embed: 510 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed))) 511 | else: 512 | for i in range(len(self.cutoffs)): 513 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 514 | d_emb_i = d_embed // (div_val ** i) 515 | self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i)) 516 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i))) 517 | 518 | 519 | 520 | def forward(self, inp): 521 | if self.div_val == 1: 522 | embed = self.emb_layers[0](inp) 523 | if self.d_proj != self.d_embed: 524 | embed = F.linear(embed, self.emb_projs[0]) 525 | else: 526 | param = next(self.parameters()) 527 | inp_flat = inp.view(-1) 528 | emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], 529 | dtype=param.dtype, device=param.device) 530 | for i in range(len(self.cutoffs)): 531 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 532 | 533 | mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) 534 | indices_i = mask_i.nonzero().squeeze() 535 | 536 | if indices_i.numel() == 0: 537 | continue 538 | 539 | inp_i = inp_flat.index_select(0, indices_i) - l_idx 540 | emb_i = self.emb_layers[i](inp_i) 541 | emb_i = F.linear(emb_i, self.emb_projs[i]) 542 | 543 | emb_flat.index_copy_(0, indices_i, emb_i) 544 | 545 | embed = emb_flat.view(*inp.size(), self.d_proj) 546 | 547 | embed.mul_(self.emb_scale) 548 | 549 | return embed 550 | 551 | 552 | 553 | 554 | class MemTransformerLM(nn.Module): 555 | def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner, 556 | dropout, dropatt, tie_weight=True, d_embed=None, 557 | div_val=1, tie_projs=[False], pre_lnorm=False, 558 | tgt_len=None, ext_len=None, mem_len=None, 559 | cutoffs=[], adapt_inp=False, 560 | same_length=False, attn_type=0, clamp_len=-1, 561 | sample_softmax=-1): 562 | ''' 563 | 564 | :param n_token: 单词表的大小 565 | :param n_layer: 16 566 | :param n_head: 10 
567 | :param d_model: 410 568 | :param d_head: 41 569 | :param d_inner: 2100 570 | :param dropout: 0.1 571 | :param dropatt: 0.0 572 | :param tie_weight: True 573 | :param d_embed: 410 574 | :param div_val: 1 575 | :param tie_projs: [F,T,T,T] 576 | :param pre_lnorm: False 577 | :param tgt_len: 150 578 | :param ext_len: 0 579 | :param mem_len: 150 580 | :param cutoffs: [20000, 40000, 200000] 581 | :param adapt_inp: 582 | :param same_length: 训练的时候是False 583 | :param attn_type: 0 584 | :param clamp_len: -1 585 | :param sample_softmax: -1 586 | ''' 587 | super(MemTransformerLM, self).__init__() 588 | self.n_token = n_token 589 | 590 | d_embed = d_model if d_embed is None else d_embed 591 | self.d_embed = d_embed 592 | self.d_model = d_model 593 | self.n_head = n_head 594 | self.d_head = d_head 595 | 596 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, 597 | div_val=div_val) 598 | 599 | self.drop = nn.Dropout(dropout) 600 | self.n_layer = n_layer 601 | self.tgt_len = tgt_len 602 | self.mem_len = mem_len 603 | self.ext_len = ext_len 604 | self.max_klen = tgt_len + ext_len + mem_len 605 | self.attn_type = attn_type 606 | # 创建 attention 607 | self.attention_layer_low = MyAttention(self.d_model, self.d_model//2) 608 | self.attention_layer_high = MyAttention(self.d_model, self.d_model // 2) 609 | 610 | self.fc = nn.Sequential( 611 | nn.Linear(self.d_model,self.d_model), 612 | nn.BatchNorm1d(self.d_model), 613 | nn.ReLU(inplace=True), 614 | nn.Linear(self.d_model, 19) 615 | ) 616 | 617 | self.layers = nn.ModuleList() 618 | if attn_type == 0: # the default attention 619 | for i in range(n_layer): 620 | self.layers.append( 621 | RelPartialLearnableDecoderLayer( 622 | n_head, d_model, d_head, d_inner, dropout, 623 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 624 | dropatt=dropatt, pre_lnorm=pre_lnorm) 625 | ) 626 | elif attn_type == 1: # learnable embeddings 627 | for i in range(n_layer): 628 | self.layers.append( 629 | RelLearnableDecoderLayer( 630 | n_head, d_model, d_head, d_inner, dropout, 631 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 632 | dropatt=dropatt, pre_lnorm=pre_lnorm) 633 | ) 634 | elif attn_type in [2, 3]: # absolute embeddings 635 | for i in range(n_layer): 636 | self.layers.append( 637 | DecoderLayer( 638 | n_head, d_model, d_head, d_inner, dropout, 639 | dropatt=dropatt, pre_lnorm=pre_lnorm) 640 | ) 641 | 642 | self.sample_softmax = sample_softmax 643 | # use sampled softmax 644 | if sample_softmax > 0: 645 | self.out_layer = nn.Linear(d_model, n_token) 646 | if tie_weight: 647 | self.out_layer.weight = self.word_emb.weight 648 | self.tie_weight = tie_weight 649 | self.sampler = LogUniformSampler(n_token, sample_softmax) 650 | 651 | # use adaptive softmax (including standard softmax) 652 | else: 653 | self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, 654 | cutoffs, div_val=div_val) 655 | 656 | if tie_weight: 657 | for i in range(len(self.crit.out_layers)): 658 | self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight 659 | 660 | if tie_projs: 661 | for i, tie_proj in enumerate(tie_projs): 662 | if tie_proj and div_val == 1 and d_model != d_embed: 663 | self.crit.out_projs[i] = self.word_emb.emb_projs[0] 664 | elif tie_proj and div_val != 1: 665 | self.crit.out_projs[i] = self.word_emb.emb_projs[i] 666 | 667 | self.same_length = same_length 668 | self.clamp_len = clamp_len 669 | 670 | self._create_params() 671 | 672 | def backward_compatible(self): 673 | self.sample_softmax = -1 674 | 675 | def _create_params(self): 
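        # attn_type 0 (the Transformer-XL default) uses the shared sinusoidal pos_emb plus the global
        # content/position biases r_w_bias and r_r_bias; attn_type 1 allocates learnable per-layer
        # relative embeddings and biases; attn_types 2 and 3 fall back to absolute position embeddings.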
676 | if self.attn_type == 0: # default attention 677 | self.pos_emb = PositionalEmbedding(self.d_model) 678 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 679 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 680 | 681 | elif self.attn_type == 1: # learnable 682 | self.r_emb = nn.Parameter(torch.Tensor( 683 | self.n_layer, self.max_klen, self.n_head, self.d_head)) 684 | self.r_w_bias = nn.Parameter(torch.Tensor( 685 | self.n_layer, self.n_head, self.d_head)) 686 | self.r_bias = nn.Parameter(torch.Tensor( 687 | self.n_layer, self.max_klen, self.n_head)) 688 | elif self.attn_type == 2: # absolute standard 689 | self.pos_emb = PositionalEmbedding(self.d_model) 690 | elif self.attn_type == 3: # absolute deeper SA 691 | self.r_emb = nn.Parameter(torch.Tensor( 692 | self.n_layer, self.max_klen, self.n_head, self.d_head)) 693 | 694 | def reset_length(self, tgt_len, ext_len, mem_len): 695 | self.tgt_len = tgt_len 696 | self.mem_len = mem_len 697 | self.ext_len = ext_len 698 | 699 | def init_mems(self): 700 | if self.mem_len > 0: 701 | mems = [] 702 | param = next(self.parameters()) 703 | for i in range(self.n_layer+1): 704 | empty = torch.empty(0, dtype=param.dtype, device=param.device) 705 | mems.append(empty) 706 | 707 | return mems 708 | else: 709 | return None 710 | 711 | def _update_mems(self, hids, mems, qlen, mlen): 712 | # does not deal with None 713 | if mems is None: return None 714 | 715 | # mems is not None 716 | assert len(hids) == len(mems), 'len(hids) != len(mems)' 717 | 718 | # There are `mlen + qlen` steps that can be cached into mems 719 | # For the next step, the last `ext_len` of the `qlen` tokens 720 | # will be used as the extended context. Hence, we only cache 721 | # the tokens from `mlen + qlen - self.ext_len - self.mem_len` 722 | # to `mlen + qlen - self.ext_len`. 
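        # Worked example with qlen = mlen = mem_len = 150 and ext_len = 0: end_idx = 150 + 150 = 300
        # and beg_idx = 300 - 150 = 150, so of the 300 cached steps per layer only the newest 150
        # hidden states are carried over as memory for the next segment.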
723 | with torch.no_grad(): 724 | new_mems = [] 725 | end_idx = mlen + max(0, qlen - 0 - self.ext_len) 726 | beg_idx = max(0, end_idx - self.mem_len) 727 | for i in range(len(hids)): 728 | cat = torch.cat([mems[i], hids[i]], dim=0) 729 | new_mems.append(cat[beg_idx:end_idx].detach()) 730 | 731 | return new_mems 732 | 733 | def _forward(self, dec_inp, mems=None): 734 | qlen, bsz = dec_inp.size() 735 | # print(qlen,bsz) 736 | word_emb = self.word_emb(dec_inp) 737 | # print(word_emb) 738 | mlen = mems[0].size(0) if mems is not None else 0 739 | klen = mlen + qlen 740 | if self.same_length: 741 | all_ones = word_emb.new_ones(qlen, klen) 742 | mask_len = klen - self.mem_len 743 | if mask_len > 0: 744 | mask_shift_len = qlen - mask_len 745 | else: 746 | mask_shift_len = qlen 747 | dec_attn_mask = (torch.triu(all_ones, mlen) 748 | + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1 749 | else: 750 | # dec_attn_mask = torch.triu( 751 | # word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] 752 | # 753 | dec_attn_mask = torch.triu( 754 | word_emb.new_ones(qlen, klen), diagonal=mlen).byte()[:,:,None] 755 | 756 | 757 | hids = [] 758 | if self.attn_type == 0: # default 759 | pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, 760 | dtype=word_emb.dtype) 761 | if self.clamp_len > 0: 762 | pos_seq.clamp_(max=self.clamp_len) 763 | pos_emb = self.pos_emb(pos_seq) 764 | 765 | core_out = self.drop(word_emb) 766 | pos_emb = self.drop(pos_emb) 767 | 768 | hids.append(core_out) 769 | for i, layer in enumerate(self.layers): 770 | mems_i = None if mems is None else mems[i] 771 | core_out = layer(core_out, pos_emb, self.r_w_bias, 772 | self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) 773 | hids.append(core_out) 774 | 775 | 776 | elif self.attn_type == 1: # learnable 777 | core_out = self.drop(word_emb) 778 | hids.append(core_out) 779 | for i, layer in enumerate(self.layers): 780 | if self.clamp_len > 0: 781 | r_emb = self.r_emb[i][-self.clamp_len :] 782 | r_bias = self.r_bias[i][-self.clamp_len :] 783 | else: 784 | r_emb, r_bias = self.r_emb[i], self.r_bias[i] 785 | 786 | mems_i = None if mems is None else mems[i] 787 | core_out = layer(core_out, r_emb, self.r_w_bias[i], 788 | r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) 789 | hids.append(core_out) 790 | elif self.attn_type == 2: # absolute 791 | pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, 792 | dtype=word_emb.dtype) 793 | if self.clamp_len > 0: 794 | pos_seq.clamp_(max=self.clamp_len) 795 | pos_emb = self.pos_emb(pos_seq) 796 | 797 | core_out = self.drop(word_emb + pos_emb[-qlen:]) 798 | 799 | hids.append(core_out) 800 | for i, layer in enumerate(self.layers): 801 | mems_i = None if mems is None else mems[i] 802 | if mems_i is not None and i == 0: 803 | mems_i += pos_emb[:mlen] 804 | core_out = layer(core_out, dec_attn_mask=dec_attn_mask, 805 | mems=mems_i) 806 | hids.append(core_out) 807 | elif self.attn_type == 3: 808 | core_out = self.drop(word_emb) 809 | 810 | hids.append(core_out) 811 | for i, layer in enumerate(self.layers): 812 | mems_i = None if mems is None else mems[i] 813 | if mems_i is not None and mlen > 0: 814 | cur_emb = self.r_emb[i][:-qlen] 815 | cur_size = cur_emb.size(0) 816 | if cur_size < mlen: 817 | cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1) 818 | cur_emb = torch.cat([cur_emb_pad, cur_emb], 0) 819 | else: 820 | cur_emb = cur_emb[-mlen:] 821 | mems_i += cur_emb.view(mlen, 1, -1) 822 | core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) 823 | 824 | 
core_out = layer(core_out, dec_attn_mask=dec_attn_mask, 825 | mems=mems_i) 826 | hids.append(core_out) 827 | 828 | core_out = self.drop(core_out) 829 | 830 | new_mems = self._update_mems(hids, mems, mlen, qlen) 831 | 832 | return core_out, new_mems 833 | 834 | ''' 835 | data, labels,bsz_len, args.tgt_len, device = 'cpu', ext_len = args.ext_len, *mems) 836 | ''' 837 | def forward(self, data, labels, *mems, tgt_len=None, device=None, ext_len=None): 838 | # nn.DataParallel does not allow size(0) tensors to be broadcasted. 839 | # So, have to initialize size(0) mems inside the model forward. 840 | # Moreover, have to return new_mems to allow nn.DataParallel to piece 841 | # them together. 842 | if not mems: 843 | mems = self.init_mems() 844 | 845 | bsz_len = data.size(0) 846 | 847 | train_iter_one_batch = LMOrderedIterator(data, bsz_len, tgt_len, device=device, ext_len=ext_len) 848 | 849 | memory_low = [] 850 | for index, (data, seq_len) in enumerate(train_iter_one_batch.get_fixlen_iter()): 851 | # tgt_len = target.size(0) 852 | # print(data) 853 | hidden, mems = self._forward(data, mems=mems) 854 | # print(mems) 855 | memory_low.append(self.attention_layer_low(hidden.permute(1,0,2))) 856 | # 857 | # print(memory_low) 858 | shapes=memory_low[0].size() 859 | 860 | # print(len(memory_low)) 861 | # print(shapes) 862 | memory_high = torch.cat(memory_low, 0) 863 | memory_high = torch.reshape(memory_high, (-1, shapes[0], shapes[1])) 864 | # print(memory_high.size()) 865 | res=self.attention_layer_high(memory_high.permute(1,0,2)) 866 | #########下面添加映射层 计算损失函数############## 867 | # print(res) 868 | logits=self.fc(res) 869 | # print(logits.size()) 870 | return logits,labels 871 | 872 | 873 | 874 | 875 | 876 | 877 | 878 | # # pred_hid = hidden[-tgt_len:] 879 | # # if self.sample_softmax > 0 and self.training: 880 | # # assert self.tie_weight 881 | # # logit = sample_logits(self.word_emb, 882 | # # self.out_layer.bias, target, pred_hid, self.sampler) 883 | # # loss = -F.log_softmax(logit, -1)[:, :, 0] 884 | # # else: 885 | # # loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1)) 886 | # # loss = loss.view(tgt_len, -1) 887 | # 888 | # return loss 889 | 890 | 891 | 892 | if __name__ == '__main__': 893 | import argparse 894 | 895 | parser = argparse.ArgumentParser(description='unit test') 896 | 897 | parser.add_argument('--n_layer', type=int, default=4, help='') 898 | parser.add_argument('--n_rel_layer', type=int, default=4, help='') 899 | parser.add_argument('--n_head', type=int, default=2, help='') 900 | parser.add_argument('--d_head', type=int, default=2, help='') 901 | parser.add_argument('--d_model', type=int, default=200, help='') 902 | parser.add_argument('--d_embed', type=int, default=200, help='') 903 | parser.add_argument('--d_inner', type=int, default=200, help='') 904 | parser.add_argument('--dropout', type=float, default=0.0, help='') 905 | parser.add_argument('--cuda', action='store_true', help='') 906 | parser.add_argument('--seed', type=int, default=1111, help='') 907 | parser.add_argument('--multi_gpu', action='store_true', help='') 908 | 909 | args = parser.parse_args() 910 | 911 | device = torch.device("cuda" if args.cuda else "cpu") 912 | 913 | B = 4 914 | tgt_len, mem_len, ext_len = 36, 36, 0 915 | data_len = tgt_len * 20 916 | args.n_token = 10000 917 | 918 | import data_utils 919 | 920 | data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device) 921 | diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len) 922 | 923 | 
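    # The rest of this block is the smoke test inherited from the original Transformer-XL code: it
    # builds a flat LM-style token stream and calls model(inp, tgt, *mems), which no longer matches
    # the classifier's forward(data, labels, *mems, tgt_len=..., ...) signature or the
    # [num_docs, alianlen] iterator defined in data_utils.py, so it will not run as-is.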
cutoffs = [args.n_token // 2] 924 | tie_projs = [False] + [True] * len(cutoffs) 925 | 926 | for div_val in [1, 2]: 927 | for d_embed in [200, 100]: 928 | model = MemTransformerLM(args.n_token, args.n_layer, args.n_head, 929 | args.d_model, args.d_head, args.d_inner, args.dropout, 930 | dropatt=args.dropout, tie_weight=True, 931 | d_embed=d_embed, div_val=div_val, 932 | tie_projs=tie_projs, pre_lnorm=True, 933 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 934 | cutoffs=cutoffs, attn_type=0).to(device) 935 | 936 | print(sum(p.numel() for p in model.parameters())) 937 | 938 | mems = tuple() 939 | for idx, (inp, tgt, seqlen) in enumerate(diter): 940 | print('batch {}'.format(idx)) 941 | out = model(inp, tgt, *mems) 942 | mems = out[1:] 943 | -------------------------------------------------------------------------------- /code_for_Classfy/myattention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | class MyAttention(nn.Module): 6 | def __init__(self,hidden_size,attention_size): 7 | super(MyAttention, self).__init__() 8 | self.hidden_size=hidden_size 9 | self.attention_size = attention_size 10 | 11 | self.W_omega = nn.Parameter(torch.Tensor(self.attention_size,self.hidden_size)) 12 | self.b_omega = nn.Parameter(torch.Tensor(self.attention_size)) 13 | self.u_omega = nn.Parameter(torch.Tensor(1,self.attention_size)) 14 | 15 | def forward(self, inputs, time_major=True, return_alphas=False): 16 | ''' 17 | 18 | :param inputs: 这里的inputs 的shape必须是[batch,len,hididen] 19 | :param time_major: 20 | :param return_alphas: 21 | :return: 22 | ''' 23 | # v = torch.tanh(torch.matmul(torch.reshape(inputs, [-1, hidden_size]), W_omega) + tf.reshape(b_omega, [1, -1])) 24 | 25 | v = torch.tanh(F.linear(inputs, self.W_omega, self.b_omega)) #[B, L, attention_size] 26 | print(v.size()) 27 | vu = F.linear(v, self.u_omega) #[B,L,1] 28 | print(vu.size()) 29 | vu = torch.squeeze(vu) 30 | print(vu.size()) 31 | exps = torch.exp(vu) 32 | print(exps.size()) 33 | 34 | alphas = exps /torch.sum(input=exps,dim=1, keepdim=True) 35 | print(alphas.size()) 36 | 37 | # Output of Bi-RNN is reduced with attention vector 38 | output = torch.sum(inputs * torch.unsqueeze(alphas, 2), 1) 39 | return output 40 | # 41 | # myAttention = MyAttention(100,50) 42 | # 43 | # x = torch.randn(10,64,100) 44 | # y=x.permute(1,0,2) 45 | # res=myAttention(y) 46 | # print(res.size()) 47 | 48 | x = torch.randn(2,2) 49 | print(x) 50 | y = torch.randn(2,2) 51 | z = torch.randn(2,2) 52 | 53 | a = [x,y,z] 54 | 55 | res= torch.cat(a, 0) 56 | res=torch.reshape(res, (-1, 2, 2)) 57 | print(res.size()) 58 | print(res[0]) -------------------------------------------------------------------------------- /code_for_Classfy/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | word_emb=torch.randn(5,10) 3 | 4 | dec_attn_mask = torch.triu( 5 | word_emb.new_ones(5,10), diagonal=5+1) 6 | print(dec_attn_mask) -------------------------------------------------------------------------------- /code_for_Classfy/train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import time 4 | import math 5 | import os, sys 6 | import itertools 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | 15 | from data_utils import 
get_lm_corpus 16 | from mem_transformer import MemTransformerLM 17 | from utils.exp_utils import create_exp_dir 18 | from utils.data_parallel import BalancedDataParallel 19 | from sklearn import metrics 20 | 21 | 22 | 23 | 24 | os.environ["CUDA_VISIBLE_DEVICES"] = "5,7" 25 | 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 28 | parser.add_argument('--alinlen', type=int, default=3000, help='xhm') 29 | parser.add_argument('--debug', action='store_true', 30 | help='run in debug mode (do not create exp dir)') 31 | 32 | parser.add_argument('--work_dir', default='LM-TFM', type=str, 33 | help='experiment directory.') 34 | 35 | parser.add_argument('--init_model_LM', type=str, help='这是预训练语言模型的路径.') 36 | 37 | parser.add_argument('--not_tied', action='store_true', 38 | help='do not tie the word embedding and softmax weights') 39 | parser.add_argument('--data', type=str, 40 | help='location of the data corpus') 41 | 42 | parser.add_argument('--vocab_file', type=str, help='单词表的路径') 43 | 44 | 45 | 46 | parser.add_argument('--dataset', type=str, default='dccl', 47 | choices=['dccl'], 48 | help='dataset name') 49 | 50 | parser.add_argument('--seed', type=int, default=1111, 51 | help='random seed') 52 | parser.add_argument('--ext_len', type=int, default=0, 53 | help='length of the extended context') 54 | 55 | 56 | parser.add_argument('--n_layer', type=int, default=6, 57 | help='number of total layers') 58 | parser.add_argument('--d_model', type=int, default=410, 59 | help='model dimension') 60 | parser.add_argument('--n_head', type=int, default=10, 61 | help='number of heads') 62 | parser.add_argument('--d_head', type=int, default=41, 63 | help='head dimension') 64 | parser.add_argument('--d_embed', type=int, default=-1, 65 | help='embedding dimension') 66 | 67 | parser.add_argument('--cuda', action='store_true', 68 | help='use CUDA') 69 | 70 | 71 | parser.add_argument('--fp16', action='store_true', 72 | help='Run in pseudo-fp16 mode (fp16 storage fp32 math).') 73 | 74 | parser.add_argument('--batch_size', type=int, default=60, 75 | help='batch size') 76 | 77 | 78 | parser.add_argument('--tgt_len', type=int, default=150, 79 | help='number of tokens to predict') 80 | parser.add_argument('--eval_tgt_len', type=int, default=150, 81 | help='number of tokens to predict for evaluation') 82 | 83 | parser.add_argument('--adaptive', action='store_true', 84 | help='use adaptive softmax') 85 | 86 | 87 | 88 | parser.add_argument('--d_inner', type=int, default=1000, 89 | help='inner dimension in FF') 90 | parser.add_argument('--dropout', type=float, default=0.0, 91 | help='global dropout rate') 92 | parser.add_argument('--dropatt', type=float, default=0.0, 93 | help='attention probability dropout rate') 94 | parser.add_argument('--init', default='normal', type=str, 95 | help='parameter initializer to use.') 96 | parser.add_argument('--emb_init', default='normal', type=str, 97 | help='parameter initializer to use.') 98 | parser.add_argument('--init_range', type=float, default=0.1, 99 | help='parameters initialized by U(-init_range, init_range)') 100 | parser.add_argument('--emb_init_range', type=float, default=0.01, 101 | help='parameters initialized by U(-init_range, init_range)') 102 | parser.add_argument('--init_std', type=float, default=0.02, 103 | help='parameters initialized by N(0, init_std)') 104 | parser.add_argument('--proj_init_std', type=float, default=0.01, 105 | help='parameters initialized by N(0, init_std)') 106 | parser.add_argument('--optim', default='adam', 
type=str, 107 | choices=['adam', 'sgd', 'adagrad'], 108 | help='optimizer to use.') 109 | parser.add_argument('--lr', type=float, default=0.00025, 110 | help='initial learning rate (0.00025|5 for adam|sgd)') 111 | parser.add_argument('--mom', type=float, default=0.0, 112 | help='momentum for sgd') 113 | parser.add_argument('--scheduler', default='cosine', type=str, 114 | choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'], 115 | help='lr scheduler to use.') 116 | parser.add_argument('--warmup_step', type=int, default=0, 117 | help='upper epoch limit') 118 | parser.add_argument('--decay_rate', type=float, default=0.5, 119 | help='decay factor when ReduceLROnPlateau is used') 120 | parser.add_argument('--lr_min', type=float, default=0.0, 121 | help='minimum learning rate during annealing') 122 | parser.add_argument('--clip', type=float, default=0.25, 123 | help='gradient clipping') 124 | parser.add_argument('--clip_nonemb', action='store_true', 125 | help='only clip the gradient of non-embedding params') 126 | parser.add_argument('--max_step', type=int, default=100000, 127 | help='upper epoch limit') 128 | 129 | parser.add_argument('--batch_chunk', type=int, default=1, 130 | help='split batch into chunks to save memory') 131 | 132 | parser.add_argument('--mem_len', type=int, default=0, 133 | help='length of the retained previous heads') 134 | parser.add_argument('--restart_dir', type=str, default='', 135 | help='restart dir') 136 | 137 | parser.add_argument('--div_val', type=int, default=1, 138 | help='divident value for adapative input and softmax') 139 | parser.add_argument('--pre_lnorm', action='store_true', 140 | help='apply LayerNorm to the input instead of the output') 141 | parser.add_argument('--varlen', action='store_true', 142 | help='use variable length') 143 | parser.add_argument('--multi_gpu', action='store_true', 144 | help='use multiple GPU') 145 | parser.add_argument('--log-interval', type=int, default=200, 146 | help='report interval') 147 | parser.add_argument('--eval-interval', type=int, default=4000, 148 | help='evaluation interval') 149 | 150 | parser.add_argument('--restart', action='store_true', 151 | help='restart training from the saved checkpoint') 152 | 153 | 154 | parser.add_argument('--same_length', action='store_true', 155 | help='use the same attn length for all tokens') 156 | parser.add_argument('--attn_type', type=int, default=0, 157 | help='attention type. 
0 for ours, 1 for Shaw et al,' 158 | '2 for Vaswani et al, 3 for Al Rfou et al.') 159 | parser.add_argument('--clamp_len', type=int, default=-1, 160 | help='use the same pos embeddings after clamp_len') 161 | parser.add_argument('--eta_min', type=float, default=0.0, 162 | help='min learning rate for cosine scheduler') 163 | parser.add_argument('--gpu0_bsz', type=int, default=-1, 164 | help='batch size on gpu 0') 165 | parser.add_argument('--max_eval_steps', type=int, default=-1, 166 | help='max eval steps') 167 | parser.add_argument('--sample_softmax', type=int, default=-1, 168 | help='number of samples in sampled softmax') 169 | parser.add_argument('--patience', type=int, default=0, 170 | help='patience') 171 | parser.add_argument('--finetune_v2', action='store_true', 172 | help='finetune v2') 173 | parser.add_argument('--finetune_v3', action='store_true', 174 | help='finetune v3') 175 | 176 | parser.add_argument('--static-loss-scale', type=float, default=1, 177 | help='Static loss scale, positive power of 2 values can ' 178 | 'improve fp16 convergence.') 179 | parser.add_argument('--dynamic-loss-scale', action='store_true', 180 | help='Use dynamic loss scaling. If supplied, this argument' 181 | ' supersedes --static-loss-scale.') 182 | 183 | 184 | #####################################程序开始######################################################################################################## 185 | args = parser.parse_args() 186 | args.tied = not args.not_tied # code True 187 | 188 | if args.d_embed < 0: 189 | args.d_embed = args.d_model # code 410 190 | 191 | assert args.ext_len >= 0, 'extended context length must be non-negative' # code 0 192 | assert args.batch_size % args.batch_chunk == 0 193 | 194 | 195 | logging = create_exp_dir(args.work_dir, 196 | scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug) 197 | 198 | 199 | 200 | # Set the random seed manually for reproducibility. 
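# (Seeding numpy and torch below fixes data order and parameter initialisation, but cuDNN kernels
# remain non-deterministic unless torch.backends.cudnn.deterministic is also enabled, which this
# script does not do.)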
201 | np.random.seed(args.seed) 202 | torch.manual_seed(args.seed) 203 | if torch.cuda.is_available(): 204 | if not args.cuda: 205 | print('WARNING: You have a CUDA device, so you should probably run with --cuda') 206 | else: 207 | torch.cuda.manual_seed_all(args.seed) 208 | 209 | # Validate `--fp16` option 210 | if args.fp16: 211 | if not args.cuda: 212 | print('WARNING: --fp16 requires --cuda, ignoring --fp16 option') 213 | args.fp16 = False 214 | else: 215 | try: 216 | from apex.fp16_utils import FP16_Optimizer 217 | except: 218 | print('WARNING: apex not installed, ignoring --fp16 option') 219 | args.fp16 = False 220 | 221 | device = torch.device('cuda' if args.cuda else 'cpu') 222 | 223 | ############################################################################### 224 | # Load data 225 | ############################################################################### 226 | 227 | assert args.alinlen == 3000 228 | corpus = get_lm_corpus(args.data,args.vocab_file,args.alinlen) 229 | print("数据加载成功") 230 | ntokens = len(corpus.vocab) # code 单词表的大小 231 | args.n_token = ntokens 232 | 233 | 234 | eval_batch_size = 10 235 | tr_batch_iter = corpus.get_batch_iterator('train', args.batch_size, alianlen=3000, device='cpu') 236 | va_batch_iter = corpus.get_batch_iterator('valid', eval_batch_size, alianlen=3000, device=device) 237 | 238 | print("批次迭代器加载成功") 239 | # te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len, 240 | # device=device, ext_len=args.ext_len) 241 | 242 | # adaptive softmax / embedding 243 | cutoffs, tie_projs = [], [False] 244 | if args.adaptive: 245 | cutoffs = [70000, 150000, 250000] 246 | tie_projs += [False] * len(cutoffs) 247 | 248 | 249 | ############################################################################### 250 | # Build the model 251 | ############################################################################### 252 | def init_weight(weight): 253 | if args.init == 'uniform': 254 | nn.init.uniform_(weight, -args.init_range, args.init_range) 255 | elif args.init == 'normal': 256 | nn.init.normal_(weight, 0.0, args.init_std) 257 | 258 | def init_bias(bias): 259 | nn.init.constant_(bias, 0.0) 260 | 261 | # 许海明 262 | def weights_init(m): 263 | classname = m.__class__.__name__ 264 | if classname.find('Linear') != -1: 265 | if hasattr(m, 'weight') and m.weight is not None: 266 | init_weight(m.weight) 267 | if hasattr(m, 'bias') and m.bias is not None: 268 | init_bias(m.bias) 269 | elif classname.find('AdaptiveEmbedding') != -1: 270 | if hasattr(m, 'emb_projs'): 271 | for i in range(len(m.emb_projs)): 272 | if m.emb_projs[i] is not None: 273 | nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std) 274 | elif classname.find('Embedding') != -1: 275 | if hasattr(m, 'weight'): 276 | init_weight(m.weight) 277 | elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: 278 | if hasattr(m, 'cluster_weight') and m.cluster_weight is not None: 279 | init_weight(m.cluster_weight) 280 | if hasattr(m, 'cluster_bias') and m.cluster_bias is not None: 281 | init_bias(m.cluster_bias) 282 | if hasattr(m, 'out_projs'): 283 | for i in range(len(m.out_projs)): 284 | if m.out_projs[i] is not None: 285 | nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std) 286 | elif classname.find('LayerNorm') != -1: 287 | if hasattr(m, 'weight'): 288 | nn.init.normal_(m.weight, 1.0, args.init_std) 289 | if hasattr(m, 'bias') and m.bias is not None: 290 | init_bias(m.bias) 291 | elif classname.find('TransformerLM') != -1: 292 | if hasattr(m, 'r_emb'): 293 | init_weight(m.r_emb) 
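        # Added note: r_emb holds learnable relative position embeddings used by the
        # non-default attention types, while r_w_bias and r_r_bias checked below are
        # the global content/position bias vectors (the "u" and "v" terms of
        # Transformer-XL's relative attention). They are ordinary parameters, so the
        # same uniform/normal initialisation is applied to them.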
294 | if hasattr(m, 'r_w_bias'): 295 | init_weight(m.r_w_bias) 296 | if hasattr(m, 'r_r_bias'): 297 | init_weight(m.r_r_bias) 298 | if hasattr(m, 'r_bias'): 299 | init_bias(m.r_bias) 300 | 301 | def update_dropout(m): 302 | classname = m.__class__.__name__ 303 | if classname.find('Dropout') != -1: 304 | if hasattr(m, 'p'): 305 | m.p = args.dropout 306 | 307 | def update_dropatt(m): 308 | if hasattr(m, 'dropatt'): 309 | m.dropatt.p = args.dropatt 310 | 311 | if args.restart: 312 | with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f: 313 | model = torch.load(f) 314 | if not args.fp16: 315 | model = model.float() 316 | model.apply(update_dropout) 317 | model.apply(update_dropatt) 318 | else: 319 | model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model, 320 | args.d_head, args.d_inner, args.dropout, args.dropatt, 321 | tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val, 322 | tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len, 323 | ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs, 324 | same_length=args.same_length, attn_type=args.attn_type, 325 | clamp_len=args.clamp_len, sample_softmax=args.sample_softmax) 326 | model.apply(weights_init) 327 | model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing 328 | args.n_all_param = sum([p.nelement() for p in model.parameters()]) 329 | args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()]) 330 | 331 | if args.fp16: 332 | model = model.half() 333 | 334 | 335 | 336 | 337 | 338 | if args.init_model_LM is not None: 339 | logging('=' * 100) 340 | logging('=' * 100) 341 | logging('=' * 100) 342 | logging('=' * 100) 343 | logging("开始加载预训练的语言模型!!!!!!!!!!!!") 344 | model_dict = model.state_dict() 345 | model_LM = {k: v for k, v in torch.load(args.init_model_LM)['state_dict'].items() if k in model_dict} 346 | # print(model_LM) 347 | model_dict.update(model_LM) 348 | model.load_state_dict(model_dict) 349 | 350 | logging('=' * 100) 351 | logging('=' * 100) 352 | logging('=' * 100) 353 | logging('=' * 100) 354 | logging("完成加载预训练的语言模型!!!!!!!!!!!!") 355 | 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | if args.multi_gpu: 366 | model = model.to(device) 367 | if args.gpu0_bsz >= 0: 368 | para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk, 369 | model, dim=0).to(device) 370 | else: 371 | para_model = nn.DataParallel(model, dim=0).to(device) 372 | else: 373 | para_model = model.to(device) 374 | 375 | print(" para_model = nn.DataParallel(model, dim=0).to(device) ") 376 | 377 | 378 | #### optimizer 379 | if args.optim.lower() == 'sgd': 380 | if args.sample_softmax > 0: 381 | dense_params, sparse_params = [], [] 382 | for param in model.parameters(): 383 | if param.size() == model.word_emb.weight.size(): 384 | sparse_params.append(param) 385 | else: 386 | dense_params.append(param) 387 | optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2) 388 | optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom) 389 | else: 390 | optimizer = optim.SGD(model.parameters(), lr=args.lr, 391 | momentum=args.mom) 392 | elif args.optim.lower() == 'adam': 393 | if args.sample_softmax > 0: 394 | dense_params, sparse_params = [], [] 395 | for param in model.parameters(): 396 | if param.size() == model.word_emb.weight.size(): 397 | sparse_params.append(param) 398 | else: 399 | dense_params.append(param) 400 | optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr) 401 | optimizer = 
optim.Adam(dense_params, lr=args.lr) 402 | else: 403 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 404 | elif args.optim.lower() == 'adagrad': 405 | optimizer = optim.Adagrad(model.parameters(), lr=args.lr) 406 | 407 | #### scheduler 408 | if args.scheduler == 'cosine': 409 | # here we do not set eta_min to lr_min to be backward compatible 410 | # because in previous versions eta_min is default to 0 411 | # rather than the default value of lr_min 1e-6 412 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 413 | args.max_step, eta_min=args.eta_min) # should use eta_min arg 414 | if args.sample_softmax > 0: 415 | scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse, 416 | args.max_step, eta_min=args.eta_min) # should use eta_min arg 417 | elif args.scheduler == 'inv_sqrt': 418 | # originally used for Transformer (in Attention is all you need) 419 | def lr_lambda(step): 420 | # return a multiplier instead of a learning rate 421 | if step == 0 and args.warmup_step == 0: 422 | return 1. 423 | else: 424 | return 1. / (step ** 0.5) if step > args.warmup_step \ 425 | else step / (args.warmup_step ** 1.5) 426 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 427 | elif args.scheduler == 'dev_perf': 428 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 429 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) 430 | if args.sample_softmax > 0: 431 | scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse, 432 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) 433 | elif args.scheduler == 'constant': 434 | pass 435 | 436 | if args.cuda and args.fp16: 437 | # If args.dynamic_loss_scale is False, static_loss_scale will be used. 438 | # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale. 439 | optimizer = FP16_Optimizer(optimizer, 440 | static_loss_scale = args.static_loss_scale, 441 | dynamic_loss_scale = args.dynamic_loss_scale, 442 | dynamic_loss_args = {'init_scale': 2 ** 16}) 443 | 444 | if args.restart: 445 | if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')): 446 | with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f: 447 | opt_state_dict = torch.load(f) 448 | optimizer.load_state_dict(opt_state_dict) 449 | else: 450 | print('Optimizer was not saved. Start from scratch.') 451 | 452 | # logging('=' * 100) 453 | # for k, v in args.__dict__.items(): 454 | # logging(' - {} : {}'.format(k, v)) 455 | logging('=' * 100) 456 | logging('#params = {}'.format(args.n_all_param)) 457 | logging('#non emb params = {}'.format(args.n_nonemb_param)) 458 | 459 | ############################################################################### 460 | # Training code 461 | ############################################################################### 462 | 463 | def evaluate(eval_iter): 464 | # Turn on evaluation mode which disables dropout. 465 | model.eval() 466 | 467 | acc_n = 0 468 | val_n = 0 469 | predict = np.zeros((0,), dtype=np.int32) 470 | gt = np.zeros((0,), dtype=np.int32) 471 | 472 | # If the model does not use memory at all, make the ext_len longer. 473 | # Otherwise, make the mem_len longer and keep the ext_len the same. 
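    # Worked example of the adjustment below (illustrative values, not taken from the
    # repo's run commands): with tgt_len=150, eval_tgt_len=50 and mem_len=150 the
    # total context is kept constant by growing the memory,
    #
    #     model.reset_length(50, ext_len, 150 + 150 - 50)   # mem_len becomes 250
    #
    # whereas with mem_len=0 the slack goes into the extended context instead,
    #
    #     model.reset_length(50, ext_len + 150 - 50, 0)     # ext_len grows by 100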
474 | if args.mem_len == 0: 475 | model.reset_length(args.eval_tgt_len, 476 | args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len) 477 | else: 478 | model.reset_length(args.eval_tgt_len, 479 | args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len) 480 | 481 | with torch.no_grad(): 482 | mems = tuple() 483 | for i, (data, labels_all, bsz_len) in enumerate(eval_iter): 484 | if args.max_eval_steps > 0 and i >= args.max_eval_steps: 485 | break 486 | outputs,labels = model(data, labels_all, *mems, tgt_len=args.tgt_len, device=device, ext_len=args.ext_len) 487 | pred = outputs.max(1)[1] 488 | acc_n += (pred == labels).sum().item() 489 | val_n += labels.size(0) 490 | predict = np.hstack((predict, pred.cpu().numpy())) 491 | gt = np.hstack((gt, labels.cpu().numpy())) 492 | 493 | 494 | 495 | acc = 100. * acc_n / val_n 496 | f1score = np.mean(metrics.f1_score(predict, gt, average=None)) 497 | print('* Test Acc: {:.3f}%({}/{}), F1 Score: {}'.format(acc, acc_n, val_n, f1score)) 498 | 499 | # Switch back to the training mode 500 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 501 | model.train() 502 | return f1score 503 | 504 | 505 | def train(): 506 | # Turn on training mode which enables dropout. 507 | global train_step, log_start_time, best_score 508 | train_loss=0.0 509 | correct = 0 510 | total = 0 511 | 512 | 513 | model.train() 514 | if args.batch_chunk > 1: 515 | mems = [tuple() for _ in range(args.batch_chunk)] 516 | else: 517 | mems = tuple() 518 | batch_iter = tr_batch_iter.get_fixlen_iter() 519 | 520 | criterion = F.cross_entropy 521 | 522 | for batch, (data, labels_all, bsz_len) in enumerate(batch_iter): 523 | model.zero_grad() 524 | if args.batch_chunk > 1: 525 | data_chunks = torch.chunk(data, args.batch_chunk, 1) 526 | target_chunks = torch.chunk(labels, args.batch_chunk, 1) 527 | for i in range(args.batch_chunk): 528 | data_i = data_chunks[i].contiguous() 529 | target_i = target_chunks[i].contiguous() 530 | ret = para_model(data_i, target_i, *mems[i]) 531 | loss, mems[i] = ret[0], ret[1:] 532 | loss = loss.float().mean().type_as(loss) / args.batch_chunk 533 | if args.fp16: 534 | optimizer.backward(loss) 535 | else: 536 | loss.backward() 537 | train_loss += loss.float().item() 538 | else: 539 | pred,labels = para_model(data, labels_all, *mems, tgt_len=args.tgt_len, device=device, ext_len=args.ext_len) 540 | 541 | # print(pred) 542 | 543 | loss = criterion(pred, labels) 544 | 545 | if args.fp16: 546 | optimizer.backward(loss) 547 | else: 548 | loss.backward() 549 | 550 | # 更新指标 551 | train_loss += loss.float().item() 552 | predicted = pred.max(1)[1] 553 | total += labels.size(0) 554 | correct += predicted.eq(labels).sum().item() 555 | 556 | 557 | 558 | if args.fp16: 559 | optimizer.clip_master_grads(args.clip) 560 | else: 561 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 562 | 563 | optimizer.step() 564 | 565 | 566 | 567 | if args.sample_softmax > 0: 568 | optimizer_sparse.step() 569 | 570 | # step-wise learning rate annealing 571 | train_step += 1 572 | if args.scheduler in ['cosine', 'constant', 'dev_perf']: 573 | # linear warmup stage 574 | if train_step < args.warmup_step: 575 | curr_lr = args.lr * train_step / args.warmup_step 576 | optimizer.param_groups[0]['lr'] = curr_lr 577 | if args.sample_softmax > 0: 578 | optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2 579 | else: 580 | if args.scheduler == 'cosine': 581 | scheduler.step(train_step) 582 | if args.sample_softmax > 0: 583 | scheduler_sparse.step(train_step) 584 | elif 
args.scheduler == 'inv_sqrt': 585 | scheduler.step(train_step) 586 | 587 | if train_step % args.log_interval == 0: 588 | cur_loss = train_loss / args.log_interval 589 | elapsed = time.time() - log_start_time 590 | log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \ 591 | '| ms/batch {:5.2f} | loss {:5.2f} | Acc: {:.3f}%({}/{})'.format( 592 | epoch, train_step, batch+1, optimizer.param_groups[0]['lr'], 593 | elapsed * 1000 / args.log_interval, cur_loss, 100. * correct / total, correct, total) 594 | 595 | logging(log_str) 596 | train_loss = 0.0 597 | log_start_time = time.time() 598 | 599 | if train_step % args.eval_interval == 0: 600 | f1score = evaluate(va_batch_iter) 601 | logging('-' * 100) 602 | 603 | # Save the model if the validation loss is the best we've seen so far. 604 | 605 | if f1score > best_score: 606 | save_path=os.path.join(args.work_dir, 'model.pt') 607 | best_score = f1score 608 | checkpoint = { 609 | 'state_dict': model.state_dict(), 610 | 'config': args 611 | } 612 | torch.save(checkpoint, save_path) 613 | print('Best tmp model f1score: {}'.format(best_score)) 614 | 615 | # # dev-performance based learning rate annealing 616 | # if args.scheduler == 'dev_perf': 617 | # scheduler.step(val_loss) 618 | # if args.sample_softmax > 0: 619 | # scheduler_sparse.step(val_loss) 620 | 621 | 622 | if train_step == args.max_step: 623 | break 624 | 625 | # Loop over epochs. 626 | train_step = 0 627 | best_score=0.0 628 | 629 | 630 | log_start_time = time.time() 631 | 632 | 633 | 634 | 635 | 636 | 637 | 638 | 639 | # At any point you can hit Ctrl + C to break out of training early. 640 | try: 641 | for epoch in itertools.count(start=1): 642 | print("许海明提醒你:开始",epoch) 643 | train() 644 | if train_step == args.max_step: 645 | logging('-' * 100) 646 | logging('End of training') 647 | break 648 | except KeyboardInterrupt: 649 | logging('-' * 100) 650 | logging('Exiting from training early') 651 | 652 | -------------------------------------------------------------------------------- /code_for_Classfy/utils/adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class AdaptiveLogSoftmax(nn.Module): 10 | def __init__(self, in_features, n_classes, cutoffs, keep_order=False): 11 | super(AdaptiveLogSoftmax, self).__init__() 12 | 13 | cutoffs = list(cutoffs) 14 | 15 | if (cutoffs != sorted(cutoffs)) \ 16 | or (min(cutoffs) <= 0) \ 17 | or (max(cutoffs) >= (n_classes - 1)) \ 18 | or (len(set(cutoffs)) != len(cutoffs)) \ 19 | or any([int(c) != c for c in cutoffs]): 20 | 21 | raise ValueError("cutoffs should be a sequence of unique, positive " 22 | "integers sorted in an increasing order, where " 23 | "each value is between 1 and n_classes-1") 24 | 25 | self.in_features = in_features 26 | self.n_classes = n_classes 27 | self.cutoffs = cutoffs + [n_classes] 28 | 29 | self.shortlist_size = self.cutoffs[0] 30 | self.n_clusters = len(self.cutoffs) - 1 31 | self.head_size = self.shortlist_size + self.n_clusters 32 | 33 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features)) 34 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 35 | 36 | self.keep_order = keep_order 37 | 38 | 39 | def forward(self, hidden, target, weight, bias, keep_order=False): 40 | if hidden.size(0) != target.size(0): 41 | raise RuntimeError('Input and target should have the 
same size ' 42 | 'in the batch dimension.') 43 | 44 | head_weight = torch.cat( 45 | [weight[:self.shortlist_size], self.cluster_weight], dim=0) 46 | head_bias = torch.cat( 47 | [bias[:self.shortlist_size], self.cluster_bias], dim=0) 48 | 49 | head_logit = F.linear(hidden, head_weight, bias=head_bias) 50 | head_logprob = F.log_softmax(head_logit, dim=1) 51 | 52 | nll = torch.zeros_like(target, 53 | dtype=hidden.dtype, device=hidden.device) 54 | 55 | offset = 0 56 | cutoff_values = [0] + self.cutoffs 57 | for i in range(len(cutoff_values) - 1): 58 | l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1] 59 | 60 | mask_i = (target >= l_idx) & (target < h_idx) 61 | indices_i = mask_i.nonzero().squeeze() 62 | 63 | if indices_i.numel() == 0: 64 | continue 65 | 66 | target_i = target.index_select(0, indices_i) - l_idx 67 | head_logprob_i = head_logprob.index_select(0, indices_i) 68 | 69 | if i == 0: 70 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 71 | else: 72 | weight_i = weight[l_idx:h_idx] 73 | bias_i = bias[l_idx:h_idx] 74 | 75 | hidden_i = hidden.index_select(0, indices_i) 76 | 77 | tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i) 78 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 79 | 80 | logprob_i = head_logprob_i[:, -i] \ 81 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 82 | 83 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 84 | nll.index_copy_(0, indices_i, -logprob_i) 85 | else: 86 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 87 | 88 | offset += logprob_i.size(0) 89 | 90 | return nll 91 | -------------------------------------------------------------------------------- /code_for_Classfy/utils/data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.parallel import DataParallel 3 | import torch 4 | from torch.nn.parallel._functions import Scatter 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, torch.Tensor): 15 | try: 16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 17 | except: 18 | print('obj', obj.size()) 19 | print('dim', dim) 20 | print('chunk_sizes', chunk_sizes) 21 | quit() 22 | if isinstance(obj, tuple) and len(obj) > 0: 23 | return list(zip(*map(scatter_map, obj))) 24 | if isinstance(obj, list) and len(obj) > 0: 25 | return list(map(list, zip(*map(scatter_map, obj)))) 26 | if isinstance(obj, dict) and len(obj) > 0: 27 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 28 | return [obj for targets in target_gpus] 29 | 30 | # After scatter_map is called, a scatter_map cell will exist. This cell 31 | # has a reference to the actual function scatter_map, which has references 32 | # to a closure that has a reference to the scatter_map cell (because the 33 | # fn is recursive). 
To avoid this reference cycle, we set the function to 34 | # None, clearing the cell 35 | try: 36 | return scatter_map(inputs) 37 | finally: 38 | scatter_map = None 39 | 40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 41 | r"""Scatter with support for kwargs dictionary""" 42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 44 | if len(inputs) < len(kwargs): 45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 46 | elif len(kwargs) < len(inputs): 47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 48 | inputs = tuple(inputs) 49 | kwargs = tuple(kwargs) 50 | return inputs, kwargs 51 | 52 | class BalancedDataParallel(DataParallel): 53 | def __init__(self, gpu0_bsz, *args, **kwargs): 54 | self.gpu0_bsz = gpu0_bsz 55 | super().__init__(*args, **kwargs) 56 | 57 | def forward(self, *inputs, **kwargs): 58 | if not self.device_ids: 59 | return self.module(*inputs, **kwargs) 60 | if self.gpu0_bsz == 0: 61 | device_ids = self.device_ids[1:] 62 | else: 63 | device_ids = self.device_ids 64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 65 | if len(self.device_ids) == 1: 66 | return self.module(*inputs[0], **kwargs[0]) 67 | replicas = self.replicate(self.module, self.device_ids) 68 | if self.gpu0_bsz == 0: 69 | replicas = replicas[1:] 70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 71 | return self.gather(outputs, self.output_device) 72 | 73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 74 | return parallel_apply(replicas, inputs, kwargs, device_ids) 75 | 76 | def scatter(self, inputs, kwargs, device_ids): 77 | bsz = inputs[0].size(self.dim) 78 | num_dev = len(self.device_ids) 79 | gpu0_bsz = self.gpu0_bsz 80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 81 | if gpu0_bsz < bsz_unit: 82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 83 | delta = bsz - sum(chunk_sizes) 84 | for i in range(delta): 85 | chunk_sizes[i + 1] += 1 86 | if gpu0_bsz == 0: 87 | chunk_sizes = chunk_sizes[1:] 88 | else: 89 | return super().scatter(inputs, kwargs, device_ids) 90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 91 | 92 | -------------------------------------------------------------------------------- /code_for_Classfy/utils/exp_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os, shutil 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | def logging(s, log_path, print_=True, log_=True): 10 | if print_: 11 | print(s) 12 | if log_: 13 | with open(log_path, 'a+') as f_log: 14 | f_log.write(s + '\n') 15 | 16 | def get_logger(log_path, **kwargs): 17 | return functools.partial(logging, log_path=log_path, **kwargs) 18 | 19 | 20 | # 许海明 21 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False): 22 | if debug: 23 | print('Debug Mode : no experiment dir created') 24 | return functools.partial(logging, log_path=None, log_=False) 25 | 26 | if not os.path.exists(dir_path): 27 | os.makedirs(dir_path) 28 | 29 | print('Experiment dir : {}'.format(dir_path)) 30 | if scripts_to_save is not None: 31 | script_path = os.path.join(dir_path, 'scripts') 32 | if not os.path.exists(script_path): 33 | os.makedirs(script_path) 34 | for script in scripts_to_save: 35 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) 36 | shutil.copyfile(script, dst_file) 37 | 38 | return 
get_logger(log_path=os.path.join(dir_path, 'log.txt')) 39 | 40 | def save_checkpoint(model, optimizer, path, epoch): 41 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch))) 42 | torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch))) 43 | -------------------------------------------------------------------------------- /code_for_Classfy/utils/log_uniform_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | class LogUniformSampler(object): 6 | def __init__(self, range_max, n_sample): 7 | """ 8 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 9 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 10 | 11 | expected count can be approximated by 1 - (1 - p)^n 12 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 13 | 14 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 15 | """ 16 | with torch.no_grad(): 17 | self.range_max = range_max 18 | log_indices = torch.arange(1., range_max+2., 1.).log_() 19 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 20 | # print('P', self.dist.numpy().tolist()[-30:]) 21 | 22 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 23 | 24 | self.n_sample = n_sample 25 | 26 | def sample(self, labels): 27 | """ 28 | labels: [b1, b2] 29 | Return 30 | true_log_probs: [b1, b2] 31 | samp_log_probs: [n_sample] 32 | neg_samples: [n_sample] 33 | """ 34 | 35 | # neg_samples = torch.empty(0).long() 36 | n_sample = self.n_sample 37 | n_tries = 2 * n_sample 38 | 39 | with torch.no_grad(): 40 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 41 | device = labels.device 42 | neg_samples = neg_samples.to(device) 43 | true_log_probs = self.log_q[labels].to(device) 44 | samp_log_probs = self.log_q[neg_samples].to(device) 45 | return true_log_probs, samp_log_probs, neg_samples 46 | 47 | def sample_logits(embedding, bias, labels, inputs, sampler): 48 | """ 49 | embedding: an nn.Embedding layer 50 | bias: [n_vocab] 51 | labels: [b1, b2] 52 | inputs: [b1, b2, n_emb] 53 | sampler: you may use a LogUniformSampler 54 | Return 55 | logits: [b1, b2, 1 + n_sample] 56 | """ 57 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 58 | n_sample = neg_samples.size(0) 59 | b1, b2 = labels.size(0), labels.size(1) 60 | all_ids = torch.cat([labels.view(-1), neg_samples]) 61 | all_w = embedding(all_ids) 62 | true_w = all_w[: -n_sample].view(b1, b2, -1) 63 | sample_w = all_w[- n_sample:].view(n_sample, -1) 64 | 65 | all_b = bias[all_ids] 66 | true_b = all_b[: -n_sample].view(b1, b2) 67 | sample_b = all_b[- n_sample:] 68 | 69 | hit = (labels[:, :, None] == neg_samples).detach() 70 | 71 | true_logits = torch.einsum('ijk,ijk->ij', 72 | [true_w, inputs]) + true_b - true_log_probs 73 | sample_logits = torch.einsum('lk,ijk->ijl', 74 | [sample_w, inputs]) + sample_b - samp_log_probs 75 | sample_logits.masked_fill_(hit, -1e30) 76 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 77 | 78 | return logits 79 | 80 | 81 | # class LogUniformSampler(object): 82 | # def __init__(self, range_max, unique=False): 83 | # """ 84 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 85 | # `P(class) = (log(class + 2) - log(class + 1)) 
/ log(range_max + 1)` 86 | # """ 87 | # self.range_max = range_max 88 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 89 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 90 | 91 | # self.unique = unique 92 | 93 | # if self.unique: 94 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 95 | 96 | # def sample(self, n_sample, labels): 97 | # pos_sample, new_labels = labels.unique(return_inverse=True) 98 | # n_pos_sample = pos_sample.size(0) 99 | # n_neg_sample = n_sample - n_pos_sample 100 | 101 | # if self.unique: 102 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 103 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 104 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 105 | # else: 106 | # sample_dist = self.dist 107 | 108 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 109 | 110 | # sample = torch.cat([pos_sample, neg_sample]) 111 | # sample_prob = self.dist[sample] 112 | 113 | # return new_labels, sample, sample_prob 114 | 115 | 116 | if __name__ == '__main__': 117 | S, B = 3, 4 118 | n_vocab = 10000 119 | n_sample = 5 120 | H = 32 121 | 122 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 123 | 124 | # sampler = LogUniformSampler(n_vocab, unique=False) 125 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 126 | 127 | sampler = LogUniformSampler(n_vocab, unique=True) 128 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 129 | 130 | # print('true_probs', true_probs.numpy().tolist()) 131 | # print('samp_probs', samp_probs.numpy().tolist()) 132 | # print('neg_samples', neg_samples.numpy().tolist()) 133 | 134 | # print('sum', torch.sum(sampler.dist).item()) 135 | 136 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 137 | 138 | embedding = nn.Embedding(n_vocab, H) 139 | bias = torch.zeros(n_vocab) 140 | inputs = torch.Tensor(S, B, H).normal_() 141 | 142 | logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) 143 | print('logits', logits.detach().numpy().tolist()) 144 | print('logits shape', logits.size()) 145 | print('out_labels', out_labels.detach().numpy().tolist()) 146 | print('out_labels shape', out_labels.size()) 147 | 148 | -------------------------------------------------------------------------------- /code_for_Classfy/utils/proj_adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 10 | CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 11 | 12 | class ProjectedAdaptiveLogSoftmax(nn.Module): 13 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 14 | keep_order=False): 15 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 16 | 17 | self.n_token = n_token 18 | self.d_embed = d_embed 19 | self.d_proj = d_proj 20 | 21 | self.cutoffs = cutoffs + [n_token] 22 | self.cutoff_ends = [0] + self.cutoffs 23 | self.div_val = div_val 24 | 25 | self.shortlist_size = self.cutoffs[0] 26 | self.n_clusters = len(self.cutoffs) - 1 27 | self.head_size = self.shortlist_size + self.n_clusters 28 | 29 | if self.n_clusters > 0: 30 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 31 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 32 | 33 | self.out_layers = nn.ModuleList() 34 | 
self.out_projs = nn.ParameterList() 35 | 36 | if div_val == 1: 37 | for i in range(len(self.cutoffs)): 38 | if d_proj != d_embed: 39 | self.out_projs.append( 40 | nn.Parameter(torch.Tensor(d_proj, d_embed)) 41 | ) 42 | else: 43 | self.out_projs.append(None) 44 | 45 | self.out_layers.append(nn.Linear(d_embed, n_token)) 46 | else: 47 | for i in range(len(self.cutoffs)): 48 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 49 | d_emb_i = d_embed // (div_val ** i) 50 | 51 | self.out_projs.append( 52 | nn.Parameter(torch.Tensor(d_proj, d_emb_i)) 53 | ) 54 | 55 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 56 | 57 | self.keep_order = keep_order 58 | 59 | def _compute_logit(self, hidden, weight, bias, proj): 60 | if proj is None: 61 | logit = F.linear(hidden, weight, bias=bias) 62 | else: 63 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 64 | proj_hid = F.linear(hidden, proj.t().contiguous()) 65 | logit = F.linear(proj_hid, weight, bias=bias) 66 | # else: 67 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 68 | # if bias is not None: 69 | # logit = logit + bias 70 | 71 | return logit 72 | 73 | def forward(self, hidden, target, keep_order=False): 74 | ''' 75 | hidden :: [len*bsz x d_proj] 76 | target :: [len*bsz] 77 | ''' 78 | 79 | if hidden.size(0) != target.size(0): 80 | raise RuntimeError('Input and target should have the same size ' 81 | 'in the batch dimension.') 82 | 83 | if self.n_clusters == 0: 84 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 85 | self.out_layers[0].bias, self.out_projs[0]) 86 | nll = -F.log_softmax(logit, dim=-1) \ 87 | .gather(1, target.unsqueeze(1)).squeeze(1) 88 | else: 89 | # construct weights and biases 90 | weights, biases = [], [] 91 | for i in range(len(self.cutoffs)): 92 | if self.div_val == 1: 93 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 94 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 95 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 96 | else: 97 | weight_i = self.out_layers[i].weight 98 | bias_i = self.out_layers[i].bias 99 | 100 | if i == 0: 101 | weight_i = torch.cat( 102 | [weight_i, self.cluster_weight], dim=0) 103 | bias_i = torch.cat( 104 | [bias_i, self.cluster_bias], dim=0) 105 | 106 | weights.append(weight_i) 107 | biases.append(bias_i) 108 | 109 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 110 | 111 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 112 | head_logprob = F.log_softmax(head_logit, dim=1) 113 | 114 | nll = torch.zeros_like(target, 115 | dtype=hidden.dtype, device=hidden.device) 116 | 117 | offset = 0 118 | cutoff_values = [0] + self.cutoffs 119 | for i in range(len(cutoff_values) - 1): 120 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 121 | 122 | mask_i = (target >= l_idx) & (target < r_idx) 123 | indices_i = mask_i.nonzero().squeeze() 124 | 125 | if indices_i.numel() == 0: 126 | continue 127 | 128 | target_i = target.index_select(0, indices_i) - l_idx 129 | head_logprob_i = head_logprob.index_select(0, indices_i) 130 | 131 | if i == 0: 132 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 133 | else: 134 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 135 | 136 | hidden_i = hidden.index_select(0, indices_i) 137 | 138 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 139 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 140 | 141 | logprob_i = head_logprob_i[:, -i] \ 142 | + tail_logprob_i.gather(1, 
target_i[:,None]).squeeze(1) 143 | 144 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 145 | nll.index_copy_(0, indices_i, -logprob_i) 146 | else: 147 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 148 | 149 | offset += logprob_i.size(0) 150 | 151 | return nll 152 | -------------------------------------------------------------------------------- /code_for_Classfy/utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter, OrderedDict 3 | 4 | import torch 5 | 6 | 7 | 8 | # 这是词典类 9 | class Vocab(object): 10 | def __init__(self, alinlen=3000,special=[], min_freq=10, max_size=None, lower_case=True, 11 | delimiter=None, vocab_file=None): 12 | ''' 13 | code 14 | :param special: 15 | :param min_freq: 16 | :param max_size: 17 | :param lower_case: 18 | :param delimiter: 19 | :param vocab_file: 20 | ''' 21 | self.counter = Counter() 22 | self.special = special 23 | self.min_freq = min_freq 24 | self.max_size = max_size 25 | self.lower_case = lower_case 26 | self.delimiter = delimiter 27 | self.vocab_file = vocab_file 28 | self.alinlen = alinlen 29 | 30 | 31 | # 许海明 32 | def tokenize(self, line): 33 | line = line.strip() 34 | # convert to lower case 35 | if self.lower_case: 36 | line = line.lower() 37 | 38 | # empty delimiter '' will evaluate False 39 | if self.delimiter == '': 40 | symbols = line 41 | else: 42 | symbols = line.split(self.delimiter) 43 | 44 | 45 | if len(symbols) > self.alinlen-2: 46 | symbols = symbols[:self.alinlen-2] 47 | 48 | symbols = [''] + symbols + [''] 49 | 50 | len_pre = len(symbols) 51 | if len_pre == self.alinlen: 52 | return symbols 53 | else: 54 | assert len_pre']*(self.alinlen-len_pre) 56 | new_symbols.extend(symbols) 57 | assert len(new_symbols)==self.alinlen 58 | return new_symbols 59 | 60 | 61 | 62 | ''' 63 | 统计单词 64 | 并返回分好词的数据 每一行元 65 | ''' 66 | def count_file(self, path, verbose=False): 67 | if verbose: 68 | print('counting file {} ...'.format(path)) 69 | assert os.path.exists(path) 70 | 71 | sents = [] 72 | with open(path, 'r', encoding='utf-8') as f: 73 | for idx, line in enumerate(f): 74 | if verbose and idx > 0 and idx % 500000 == 0: 75 | print(' line {}'.format(idx)) 76 | symbols = self.tokenize(line) 77 | self.counter.update(symbols) 78 | sents.append(symbols) 79 | 80 | return sents 81 | 82 | def count_sents(self, sents, verbose=False): 83 | """ 84 | sents : a list of sentences, each a list of tokenized symbols 85 | """ 86 | if verbose: print('counting {} sents ...'.format(len(sents))) 87 | for idx, symbols in enumerate(sents): 88 | if verbose and idx > 0 and idx % 500000 == 0: 89 | print(' line {}'.format(idx)) 90 | self.counter.update(symbols) 91 | 92 | # 许海明 93 | def _build_from_file(self, vocab_file): 94 | self.idx2sym = [] 95 | self.sym2idx = OrderedDict() 96 | 97 | with open(vocab_file, 'r', encoding='utf-8') as f: 98 | for line in f: 99 | symb = line.strip().split()[0] 100 | self.add_symbol(symb) 101 | self.pad_idx = self.sym2idx[''] 102 | self.s_idx = self.sym2idx[''] 103 | self.unk_idx = self.sym2idx[''] 104 | setattr(self, '{}_idx'.format("".strip('<>')), self.sym2idx['']) 105 | 106 | 107 | # 许海明 108 | def build_vocab(self): 109 | if self.vocab_file: 110 | print('building vocab from {}'.format(self.vocab_file)) 111 | self._build_from_file(self.vocab_file) 112 | print('final vocab size {}'.format(len(self))) 113 | else: 114 | 115 | raise ("单词应该从单词文件中提取,不应该重新创建") 116 | print('building vocab with min_freq={}, 
max_size={}'.format(self.min_freq, self.max_size)) 117 | self.idx2sym = [] 118 | self.sym2idx = OrderedDict() 119 | for sym in self.special: 120 | self.add_special(sym) 121 | 122 | for sym, cnt in self.counter.most_common(self.max_size): 123 | if cnt < self.min_freq: 124 | break 125 | self.add_symbol(sym) 126 | 127 | print('final vocab size {} from {} unique tokens'.format( 128 | len(self), len(self.counter))) 129 | 130 | 131 | # 许海明 对一个文件进行编码 132 | 133 | def encode_file_only_for_lables(self, path, verbose=True): 134 | 135 | if verbose: 136 | print('正在编码标签 {} ...'.format(path)) 137 | assert os.path.exists(path) 138 | encoded = [] 139 | with open(path, 'r', encoding='utf-8') as f: 140 | for idx, line in enumerate(f): 141 | line = line.strip() 142 | if line is None or line == "": 143 | continue 144 | label = int(line)-1 ##################################### 注意这里减去一个 1 ############################################# 145 | if verbose and idx > 0 and idx % 500000 == 0: 146 | print(' line {}'.format(idx)) 147 | 148 | encoded.append(label) 149 | 150 | 151 | encoded = torch.LongTensor(encoded) 152 | 153 | return encoded 154 | 155 | def encode_file(self, path, verbose=False): 156 | ''' 157 | ordered=True 158 | :param path: 159 | :param ordered: 160 | :param verbose: 161 | :param add_eos: 162 | :param add_double_eos: 163 | :return: 164 | ''' 165 | if verbose: 166 | print('encoding file {} ...'.format(path)) 167 | assert os.path.exists(path) 168 | encoded = [] 169 | with open(path, 'r', encoding='utf-8') as f: 170 | for idx, line in enumerate(f): 171 | line = line.strip() 172 | if line is None or line == "": 173 | continue 174 | if verbose and idx > 0 and idx % 500000 == 0: 175 | print(' line {}'.format(idx)) 176 | symbols = self.tokenize(line) # code 每一行加上结束 177 | 178 | encoded.append(self.convert_to_tensor(symbols)) 179 | 180 | 181 | encoded = torch.reshape(torch.cat(encoded,0), (-1, 3000)) 182 | print(encoded.size()) 183 | 184 | return encoded 185 | 186 | # 许海明 这是对多个句子进行编码 187 | def encode_sents(self, sents, ordered=False, verbose=False): 188 | if verbose: 189 | print('encoding {} sents ...'.format(len(sents))) 190 | encoded = [] 191 | for idx, symbols in enumerate(sents): 192 | if verbose and idx > 0 and idx % 500000 == 0: 193 | print(' line {}'.format(idx)) 194 | encoded.append(self.convert_to_tensor(symbols)) 195 | 196 | if ordered: 197 | encoded = torch.cat(encoded) 198 | 199 | return encoded 200 | 201 | # 许海明 202 | def add_special(self, sym): 203 | if sym not in self.sym2idx: 204 | self.idx2sym.append(sym) 205 | self.sym2idx[sym] = len(self.idx2sym) - 1 206 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 207 | 208 | # 许海明 209 | def add_symbol(self, sym): 210 | if sym not in self.sym2idx: 211 | self.idx2sym.append(sym) 212 | self.sym2idx[sym] = len(self.idx2sym) - 1 213 | 214 | def get_sym(self, idx): 215 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) 216 | return self.idx2sym[idx] 217 | 218 | def get_idx(self, sym): 219 | if sym in self.sym2idx: 220 | return self.sym2idx[sym] 221 | else: 222 | # print('encounter unk {}'.format(sym)) 223 | assert '' not in sym 224 | assert hasattr(self, 'unk_idx') 225 | return self.sym2idx.get(sym, self.unk_idx) 226 | 227 | def get_symbols(self, indices): 228 | return [self.get_sym(idx) for idx in indices] 229 | 230 | def get_indices(self, symbols): 231 | return [self.get_idx(sym) for sym in symbols] 232 | 233 | # 许海明 234 | def convert_to_tensor(self, symbols): 235 | return torch.LongTensor(self.get_indices(symbols)) 236 
| 237 | def convert_to_sent(self, indices, exclude=None): 238 | if exclude is None: 239 | return ' '.join([self.get_sym(idx) for idx in indices]) 240 | else: 241 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 242 | 243 | def __len__(self): 244 | return len(self.idx2sym) 245 | -------------------------------------------------------------------------------- /code_for_Classfy/运行命令: -------------------------------------------------------------------------------- 1 | python train.py 2 | --cuda 3 | --data ../data/LM/ 4 | --work_dir ../results/LM/ 5 | --dataset dccl 6 | --n_layer 6 7 | --d_model 410 8 | --n_head 10 9 | --d_head 41 10 | --tgt_len 150 11 | --mem_len 150 12 | --eval_tgt_len 150 13 | --d_inner 2100 14 | --dropout 0.1 15 | --dropatt 0.0 16 | --optim adam 17 | --lr 0.00025 18 | --warmup_step 0 19 | --max_step 600000 20 | --batch_size 65 21 | --multi_gpu 22 | --gpu0_bsz 15 23 | -------------------------------------------------------------------------------- /code_for_LM/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuhaiming1996/Transformer-xl-classify/7696d0b328aac4665d851aac5f84e49000919aae/code_for_LM/.DS_Store -------------------------------------------------------------------------------- /code_for_LM/data_utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import glob 3 | from collections import Counter, OrderedDict 4 | import numpy as np 5 | import torch 6 | 7 | from utils.vocabulary import Vocab 8 | 9 | class LMOrderedIterator(object): 10 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None): 11 | ''' 12 | :param data: 13 | :param bsz: batch_size 14 | :param bptt: tgt_len 15 | :param device: 16 | :param ext_len: 0 17 | ''' 18 | """ 19 | data -- LongTensor -- the LongTensor is strictly ordered 20 | """ 21 | self.bsz = bsz 22 | self.bptt = bptt 23 | self.ext_len = ext_len if ext_len is not None else 0 24 | 25 | self.device = device 26 | 27 | # Work out how cleanly we can divide the dataset into bsz parts. 28 | self.n_step = data.size(0) // bsz 29 | 30 | 31 | 32 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 33 | data = data.narrow(0, 0, self.n_step * bsz) 34 | 35 | # Evenly divide the data across the bsz batches. 36 | self.data = data.view(bsz, -1).t().contiguous().to(device) 37 | # self.data = data.view(bsz, -1).t().contiguous().to(device) 38 | 39 | 40 | # Number of mini-batches 41 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt # 这里不是一步一步的 42 | 43 | def get_batch(self, i, bptt=None): 44 | if bptt is None: 45 | bptt = self.bptt 46 | seq_len = min(bptt, self.data.size(0) - 1 - i) 47 | 48 | end_idx = i + seq_len 49 | beg_idx = max(0, i - self.ext_len) 50 | 51 | data = self.data[beg_idx:end_idx] 52 | target = self.data[i+1:i+1+seq_len] 53 | 54 | return data, target, seq_len 55 | 56 | def get_fixlen_iter(self, start=0): 57 | for i in range(start, self.data.size(0) - 1, self.bptt): 58 | yield self.get_batch(i) 59 | 60 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): 61 | max_len = self.bptt + max_deviation * std 62 | i = start 63 | while True: 64 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. 
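            # Added note: with 5% probability the mean segment length is halved, and
            # the actual length is then drawn from N(bptt, std) and clipped to
            # [min_len, max_len] on the next line; this produces the variable-length
            # training segments used when variable-length iteration is enabled.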
65 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 66 | data, target, seq_len = self.get_batch(i, bptt) 67 | i += seq_len 68 | yield data, target, seq_len 69 | if i >= self.data.size(0) - 2: 70 | break 71 | 72 | def __iter__(self): 73 | return self.get_fixlen_iter() 74 | 75 | 76 | 77 | 78 | # 许海明 79 | class Corpus(object): 80 | def __init__(self, path, *args, **kwargs): 81 | self.vocab = Vocab(*args, **kwargs) 82 | 83 | self.vocab.count_file(os.path.join(path, 'train.txt'),verbose=True) 84 | self.vocab.count_file(os.path.join(path, 'valid.txt'),verbose=True) 85 | 86 | self.vocab.build_vocab() 87 | 88 | 89 | self.train = self.vocab.encode_file( 90 | os.path.join(path, 'train.txt'), ordered=True, verbose=True) 91 | self.valid = self.vocab.encode_file( 92 | os.path.join(path, 'valid.txt'), ordered=True, verbose=True) 93 | # self.test = self.vocab.encode_file( 94 | # os.path.join(path, 'test.txt'), ordered=True) 95 | # 许海明 96 | def get_iterator(self, split, *args, **kwargs): 97 | ''' 98 | 99 | :param split: 100 | :param args: 101 | :param kwargs: 102 | :return: 103 | ''' 104 | if split == 'train': 105 | data_iter = LMOrderedIterator(self.train, *args, **kwargs) 106 | 107 | elif split in ['valid', 'test']: 108 | data = self.valid if split == 'valid' else self.test 109 | data_iter = LMOrderedIterator(data, *args, **kwargs) 110 | 111 | return data_iter 112 | 113 | 114 | def get_lm_corpus(datadir, alinlen): 115 | fn = os.path.join(datadir, 'cache.pt') 116 | if os.path.exists(fn): 117 | print('Loading cached dataset...') 118 | corpus = torch.load(fn) 119 | 120 | else: 121 | print('Producing dataset {}...'.format(datadir)) 122 | kwargs = {} 123 | 124 | kwargs['special'] = ['','', '', ''] 125 | kwargs['lower_case'] = False 126 | 127 | 128 | corpus = Corpus(datadir,alinlen ,**kwargs) 129 | torch.save(corpus, fn) # 这里保存的是一个类的对象 130 | 131 | return corpus 132 | 133 | if __name__ == '__main__': 134 | import argparse 135 | parser = argparse.ArgumentParser(description='unit test') 136 | parser.add_argument('--datadir', type=str, default='../data/enwik8', 137 | help='location of the data corpus') 138 | parser.add_argument('--dataset', type=str, default='enwik8', 139 | choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'], 140 | help='dataset name') 141 | args = parser.parse_args() 142 | 143 | corpus = get_lm_corpus(args.datadir, args.dataset) 144 | print('Vocab size : {}'.format(len(corpus.vocab.idx2sym))) 145 | tr_iter = corpus.get_iterator('train', 22, 512) 146 | ## 147 | # 许海明 许海明 许海明 许海明 许海明 148 | # 许海明 149 | # 150 | # 151 | # ## 152 | -------------------------------------------------------------------------------- /code_for_LM/eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import time 4 | import math 5 | import os, sys 6 | 7 | import torch 8 | 9 | from data_utils import get_lm_corpus 10 | from mem_transformer import MemTransformerLM 11 | from utils.exp_utils import get_logger 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 14 | parser.add_argument('--data', type=str, default='../data/wikitext-103', 15 | help='location of the data corpus') 16 | parser.add_argument('--dataset', type=str, default='wt103', 17 | choices=['wt103', 'lm1b', 'enwik8', 'text8'], 18 | help='dataset name') 19 | parser.add_argument('--split', type=str, default='all', 20 | choices=['all', 'valid', 'test'], 21 | help='which split to evaluate') 22 | 
parser.add_argument('--batch_size', type=int, default=10, 23 | help='batch size') 24 | parser.add_argument('--tgt_len', type=int, default=5, 25 | help='number of tokens to predict') 26 | parser.add_argument('--ext_len', type=int, default=0, 27 | help='length of the extended context') 28 | parser.add_argument('--mem_len', type=int, default=0, 29 | help='length of the retained previous heads') 30 | parser.add_argument('--clamp_len', type=int, default=-1, 31 | help='max positional embedding index') 32 | parser.add_argument('--cuda', action='store_true', 33 | help='use CUDA') 34 | parser.add_argument('--work_dir', type=str, required=True, 35 | help='path to the work_dir') 36 | parser.add_argument('--no_log', action='store_true', 37 | help='do not log the eval result') 38 | parser.add_argument('--same_length', action='store_true', 39 | help='set same length attention with masking') 40 | args = parser.parse_args() 41 | assert args.ext_len >= 0, 'extended context length must be non-negative' 42 | 43 | device = torch.device("cuda" if args.cuda else "cpu") 44 | 45 | # Get logger 46 | logging = get_logger(os.path.join(args.work_dir, 'log.txt'), 47 | log_=not args.no_log) 48 | 49 | # Load dataset 50 | corpus = get_lm_corpus(args.data, args.dataset) 51 | ntokens = len(corpus.vocab) 52 | 53 | va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len, 54 | device=device, ext_len=args.ext_len) 55 | te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len, 56 | device=device, ext_len=args.ext_len) 57 | 58 | # Load the best saved model. 59 | with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: 60 | model = torch.load(f) 61 | model.backward_compatible() 62 | model = model.to(device) 63 | 64 | logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( 65 | args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len)) 66 | 67 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 68 | if args.clamp_len > 0: 69 | model.clamp_len = args.clamp_len 70 | if args.same_length: 71 | model.same_length = True 72 | 73 | ############################################################################### 74 | # Evaluation code 75 | ############################################################################### 76 | def evaluate(eval_iter): 77 | # Turn on evaluation mode which disables dropout. 78 | model.eval() 79 | total_len, total_loss = 0, 0. 80 | start_time = time.time() 81 | with torch.no_grad(): 82 | mems = tuple() 83 | for idx, (data, target, seq_len) in enumerate(eval_iter): 84 | ret = model(data, target, *mems) 85 | loss, mems = ret[0], ret[1:] 86 | loss = loss.mean() 87 | total_loss += seq_len * loss.item() 88 | total_len += seq_len 89 | total_time = time.time() - start_time 90 | logging('Time : {:.2f}s, {:.2f}ms/segment'.format( 91 | total_time, 1000 * total_time / (idx+1))) 92 | return total_loss / total_len 93 | 94 | # Run on test data. 
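# Added note: --split chooses which evaluations run below ('all' computes both the
# test and the validation loss). format_log further down reports bits-per-character
# (loss / ln 2) for the character-level datasets enwik8/text8 and perplexity
# (exp(loss)) for the word-level ones.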
95 | if args.split == 'all': 96 | test_loss = evaluate(te_iter) 97 | valid_loss = evaluate(va_iter) 98 | elif args.split == 'valid': 99 | valid_loss = evaluate(va_iter) 100 | test_loss = None 101 | elif args.split == 'test': 102 | test_loss = evaluate(te_iter) 103 | valid_loss = None 104 | 105 | def format_log(loss, split): 106 | if args.dataset in ['enwik8', 'text8']: 107 | log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format( 108 | split, loss, loss / math.log(2)) 109 | else: 110 | log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format( 111 | split, loss, math.exp(loss)) 112 | return log_str 113 | 114 | log_str = '' 115 | if valid_loss is not None: 116 | log_str += format_log(valid_loss, 'valid') 117 | if test_loss is not None: 118 | log_str += format_log(test_loss, 'test') 119 | 120 | logging('=' * 100) 121 | logging(log_str) 122 | logging('=' * 100) 123 | -------------------------------------------------------------------------------- /code_for_LM/mem_transformer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import functools 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | sys.path.append('utils') 12 | from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax 13 | from log_uniform_sampler import LogUniformSampler, sample_logits 14 | 15 | class PositionalEmbedding(nn.Module): 16 | def __init__(self, demb): 17 | super(PositionalEmbedding, self).__init__() 18 | 19 | self.demb = demb 20 | 21 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 22 | self.register_buffer('inv_freq', inv_freq) 23 | 24 | def forward(self, pos_seq, bsz=None): 25 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 26 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 27 | 28 | if bsz is not None: 29 | return pos_emb[:,None,:].expand(-1, bsz, -1) 30 | else: 31 | return pos_emb[:,None,:] 32 | 33 | 34 | class PositionwiseFF(nn.Module): 35 | def __init__(self, d_model, d_inner, dropout, pre_lnorm = False): 36 | super(PositionwiseFF, self).__init__() 37 | 38 | self.d_model = d_model 39 | self.d_inner = d_inner 40 | self.dropout = dropout 41 | 42 | self.CoreNet = nn.Sequential( 43 | nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), 44 | nn.Dropout(dropout), 45 | nn.Linear(d_inner, d_model), 46 | nn.Dropout(dropout), 47 | ) 48 | 49 | self.layer_norm = nn.LayerNorm(d_model) 50 | 51 | self.pre_lnorm = pre_lnorm 52 | 53 | def forward(self, inp): 54 | if self.pre_lnorm: 55 | ##### layer normalization + positionwise feed-forward 56 | core_out = self.CoreNet(self.layer_norm(inp)) 57 | 58 | ##### residual connection 59 | output = core_out + inp 60 | else: 61 | ##### positionwise feed-forward 62 | core_out = self.CoreNet(inp) 63 | 64 | ##### residual connection + layer normalization 65 | output = self.layer_norm(inp + core_out) 66 | 67 | return output 68 | 69 | class MultiHeadAttn(nn.Module): 70 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, 71 | pre_lnorm=False): 72 | super(MultiHeadAttn, self).__init__() 73 | 74 | self.n_head = n_head 75 | self.d_model = d_model 76 | self.d_head = d_head 77 | self.dropout = dropout 78 | 79 | self.q_net = nn.Linear(d_model, n_head * d_head, bias=False) 80 | self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False) 81 | 82 | self.drop = nn.Dropout(dropout) 83 | self.dropatt = nn.Dropout(dropatt) 84 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) 85 | 86 | 
self.layer_norm = nn.LayerNorm(d_model) 87 | 88 | self.scale = 1 / (d_head ** 0.5) 89 | 90 | self.pre_lnorm = pre_lnorm 91 | 92 | def forward(self, h, attn_mask=None, mems=None): 93 | ##### multihead attention 94 | # [hlen x bsz x n_head x d_head] 95 | 96 | if mems is not None: 97 | c = torch.cat([mems, h], 0) 98 | else: 99 | c = h 100 | 101 | if self.pre_lnorm: 102 | ##### layer normalization 103 | c = self.layer_norm(c) 104 | 105 | head_q = self.q_net(h) 106 | head_k, head_v = torch.chunk(self.kv_net(c), 2, -1) 107 | 108 | head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head) 109 | head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head) 110 | head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head) 111 | 112 | # [qlen x klen x bsz x n_head] 113 | attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k)) 114 | attn_score.mul_(self.scale) 115 | if attn_mask is not None and attn_mask.any().item(): 116 | if attn_mask.dim() == 2: 117 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) 118 | elif attn_mask.dim() == 3: 119 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) 120 | 121 | # [qlen x klen x bsz x n_head] 122 | attn_prob = F.softmax(attn_score, dim=1) 123 | attn_prob = self.dropatt(attn_prob) 124 | 125 | # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head] 126 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v)) 127 | attn_vec = attn_vec.contiguous().view( 128 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 129 | 130 | ##### linear projection 131 | attn_out = self.o_net(attn_vec) 132 | attn_out = self.drop(attn_out) 133 | 134 | if self.pre_lnorm: 135 | ##### residual connection 136 | output = h + attn_out 137 | else: 138 | ##### residual connection + layer normalization 139 | output = self.layer_norm(h + attn_out) 140 | 141 | return output 142 | 143 | 144 | 145 | class RelMultiHeadAttn(nn.Module): 146 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, 147 | tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False): 148 | super(RelMultiHeadAttn, self).__init__() 149 | 150 | self.n_head = n_head 151 | self.d_model = d_model 152 | self.d_head = d_head 153 | self.dropout = dropout 154 | 155 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) 156 | 157 | self.drop = nn.Dropout(dropout) 158 | self.dropatt = nn.Dropout(dropatt) 159 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) 160 | 161 | self.layer_norm = nn.LayerNorm(d_model) 162 | 163 | self.scale = 1 / (d_head ** 0.5) 164 | 165 | self.pre_lnorm = pre_lnorm 166 | 167 | def _parallelogram_mask(self, h, w, left=False): 168 | mask = torch.ones((h, w)).byte() 169 | m = min(h, w) 170 | mask[:m,:m] = torch.triu(mask[:m,:m]) 171 | mask[-m:,-m:] = torch.tril(mask[-m:,-m:]) 172 | 173 | if left: 174 | return mask 175 | else: 176 | return mask.flip(0) 177 | 178 | def _shift(self, x, qlen, klen, mask, left=False): 179 | if qlen > 1: 180 | zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)), 181 | device=x.device, dtype=x.dtype) 182 | else: 183 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) 184 | 185 | if left: 186 | mask = mask.flip(1) 187 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1) 188 | else: 189 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1) 190 | 191 | x = x_padded.masked_select(mask[:,:,None,None]) \ 192 | .view(qlen, klen, x.size(2), x.size(3)) 193 | 194 | return x 195 | 196 | 
def _rel_shift(self, x, zero_triu=False): 197 | zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), 198 | device=x.device, dtype=x.dtype) 199 | x_padded = torch.cat([zero_pad, x], dim=1) 200 | 201 | x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) 202 | 203 | x = x_padded[1:].view_as(x) 204 | 205 | if zero_triu: 206 | ones = torch.ones((x.size(0), x.size(1))) 207 | x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None] 208 | 209 | return x 210 | 211 | def forward(self, w, r, attn_mask=None, mems=None): 212 | raise NotImplementedError 213 | 214 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): 215 | def __init__(self, *args, **kwargs): 216 | super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs) 217 | 218 | self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) 219 | 220 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): 221 | qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) 222 | 223 | if mems is not None: 224 | cat = torch.cat([mems, w], 0) 225 | if self.pre_lnorm: 226 | w_heads = self.qkv_net(self.layer_norm(cat)) 227 | else: 228 | w_heads = self.qkv_net(cat) 229 | r_head_k = self.r_net(r) 230 | 231 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 232 | w_head_q = w_head_q[-qlen:] 233 | else: 234 | if self.pre_lnorm: 235 | w_heads = self.qkv_net(self.layer_norm(w)) 236 | else: 237 | w_heads = self.qkv_net(w) 238 | r_head_k = self.r_net(r) 239 | 240 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 241 | 242 | klen = w_head_k.size(0) 243 | 244 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head 245 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head 246 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head 247 | 248 | r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # rlen x n_head x d_head 249 | 250 | #### compute attention score 251 | rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head 252 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head 253 | 254 | rr_head_q = w_head_q + r_r_bias 255 | BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head 256 | BD = self._rel_shift(BD) 257 | 258 | # [qlen x klen x bsz x n_head] 259 | attn_score = AC + BD 260 | attn_score.mul_(self.scale) 261 | 262 | #### compute attention probability 263 | if attn_mask is not None and attn_mask.any().item(): 264 | if attn_mask.dim() == 2: 265 | attn_score = attn_score.float().masked_fill( 266 | attn_mask[None,:,:,None], -float('inf')).type_as(attn_score) 267 | elif attn_mask.dim() == 3: 268 | attn_score = attn_score.float().masked_fill( 269 | attn_mask[:,:,:,None], -float('inf')).type_as(attn_score) 270 | 271 | # [qlen x klen x bsz x n_head] 272 | attn_prob = F.softmax(attn_score, dim=1) 273 | attn_prob = self.dropatt(attn_prob) 274 | 275 | #### compute attention vector 276 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) 277 | 278 | # [qlen x bsz x n_head x d_head] 279 | attn_vec = attn_vec.contiguous().view( 280 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 281 | 282 | ##### linear projection 283 | attn_out = self.o_net(attn_vec) 284 | attn_out = self.drop(attn_out) 285 | 286 | if self.pre_lnorm: 287 | ##### residual connection 288 | output = w + attn_out 289 | else: 290 | ##### residual connection + layer normalization 291 | output = 
self.layer_norm(w + attn_out) 292 | 293 | return output 294 | 295 | class RelLearnableMultiHeadAttn(RelMultiHeadAttn): 296 | def __init__(self, *args, **kwargs): 297 | super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs) 298 | 299 | def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None): 300 | # r_emb: [klen, n_head, d_head], used for term B 301 | # r_w_bias: [n_head, d_head], used for term C 302 | # r_bias: [klen, n_head], used for term D 303 | 304 | qlen, bsz = w.size(0), w.size(1) 305 | 306 | if mems is not None: 307 | cat = torch.cat([mems, w], 0) 308 | if self.pre_lnorm: 309 | w_heads = self.qkv_net(self.layer_norm(cat)) 310 | else: 311 | w_heads = self.qkv_net(cat) 312 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 313 | 314 | w_head_q = w_head_q[-qlen:] 315 | else: 316 | if self.pre_lnorm: 317 | w_heads = self.qkv_net(self.layer_norm(w)) 318 | else: 319 | w_heads = self.qkv_net(w) 320 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 321 | 322 | klen = w_head_k.size(0) 323 | 324 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) 325 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) 326 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) 327 | 328 | if klen > r_emb.size(0): 329 | r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1) 330 | r_emb = torch.cat([r_emb_pad, r_emb], 0) 331 | r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1) 332 | r_bias = torch.cat([r_bias_pad, r_bias], 0) 333 | else: 334 | r_emb = r_emb[-klen:] 335 | r_bias = r_bias[-klen:] 336 | 337 | #### compute attention score 338 | rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head 339 | 340 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head 341 | B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head 342 | D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head 343 | BD = self._rel_shift(B_ + D_) 344 | 345 | # [qlen x klen x bsz x n_head] 346 | attn_score = AC + BD 347 | attn_score.mul_(self.scale) 348 | 349 | #### compute attention probability 350 | if attn_mask is not None and attn_mask.any().item(): 351 | if attn_mask.dim() == 2: 352 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) 353 | elif attn_mask.dim() == 3: 354 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) 355 | 356 | # [qlen x klen x bsz x n_head] 357 | attn_prob = F.softmax(attn_score, dim=1) 358 | attn_prob = self.dropatt(attn_prob) 359 | 360 | #### compute attention vector 361 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) 362 | 363 | # [qlen x bsz x n_head x d_head] 364 | attn_vec = attn_vec.contiguous().view( 365 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 366 | 367 | ##### linear projection 368 | attn_out = self.o_net(attn_vec) 369 | attn_out = self.drop(attn_out) 370 | 371 | if self.pre_lnorm: 372 | ##### residual connection 373 | output = w + attn_out 374 | else: 375 | ##### residual connection + layer normalization 376 | output = self.layer_norm(w + attn_out) 377 | 378 | return output 379 | 380 | 381 | 382 | 383 | class DecoderLayer(nn.Module): 384 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): 385 | super(DecoderLayer, self).__init__() 386 | 387 | self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) 388 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 389 | pre_lnorm=kwargs.get('pre_lnorm')) 390 | 391 | def forward(self, 
dec_inp, dec_attn_mask=None, mems=None): 392 | 393 | output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, 394 | mems=mems) 395 | output = self.pos_ff(output) 396 | 397 | return output 398 | 399 | class RelLearnableDecoderLayer(nn.Module): 400 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, 401 | **kwargs): 402 | super(RelLearnableDecoderLayer, self).__init__() 403 | 404 | self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, 405 | **kwargs) 406 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 407 | pre_lnorm=kwargs.get('pre_lnorm')) 408 | 409 | def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None): 410 | 411 | output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias, 412 | attn_mask=dec_attn_mask, 413 | mems=mems) 414 | output = self.pos_ff(output) 415 | 416 | return output 417 | 418 | 419 | 420 | 421 | class RelPartialLearnableDecoderLayer(nn.Module): 422 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, 423 | **kwargs): 424 | super(RelPartialLearnableDecoderLayer, self).__init__() 425 | 426 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, 427 | d_head, dropout, **kwargs) 428 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 429 | pre_lnorm=kwargs.get('pre_lnorm')) 430 | 431 | def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): 432 | 433 | output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias, 434 | attn_mask=dec_attn_mask, 435 | mems=mems) 436 | output = self.pos_ff(output) 437 | 438 | return output 439 | 440 | 441 | 442 | 443 | # 许海明 444 | class AdaptiveEmbedding(nn.Module): 445 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 446 | sample_softmax=False): 447 | 448 | ''' 449 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, div_val=div_val) 450 | :param n_token: 451 | :param d_embed: 452 | :param d_proj: 453 | :param cutoffs: 454 | :param div_val: 455 | :param sample_softmax: 456 | ''' 457 | super(AdaptiveEmbedding, self).__init__() 458 | 459 | self.n_token = n_token 460 | self.d_embed = d_embed 461 | 462 | self.cutoffs = cutoffs + [n_token] 463 | self.div_val = div_val 464 | self.d_proj = d_proj 465 | 466 | self.emb_scale = d_proj ** 0.5 467 | 468 | self.cutoff_ends = [0] + self.cutoffs 469 | 470 | self.emb_layers = nn.ModuleList() 471 | self.emb_projs = nn.ParameterList() 472 | if div_val == 1: 473 | self.emb_layers.append( 474 | nn.Embedding(n_token, d_embed, sparse=sample_softmax>0) 475 | ) 476 | if d_proj != d_embed: 477 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed))) 478 | else: 479 | for i in range(len(self.cutoffs)): 480 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 481 | d_emb_i = d_embed // (div_val ** i) 482 | self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i)) 483 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i))) 484 | 485 | 486 | 487 | def forward(self, inp): 488 | if self.div_val == 1: 489 | embed = self.emb_layers[0](inp) 490 | if self.d_proj != self.d_embed: 491 | embed = F.linear(embed, self.emb_projs[0]) 492 | else: 493 | param = next(self.parameters()) 494 | inp_flat = inp.view(-1) 495 | emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], 496 | dtype=param.dtype, device=param.device) 497 | for i in range(len(self.cutoffs)): 498 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 499 | 500 | mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) 501 | indices_i = mask_i.nonzero().squeeze() 502 | 503 | if 
indices_i.numel() == 0: 504 | continue 505 | 506 | inp_i = inp_flat.index_select(0, indices_i) - l_idx 507 | emb_i = self.emb_layers[i](inp_i) 508 | emb_i = F.linear(emb_i, self.emb_projs[i]) 509 | 510 | emb_flat.index_copy_(0, indices_i, emb_i) 511 | 512 | embed = emb_flat.view(*inp.size(), self.d_proj) 513 | 514 | embed.mul_(self.emb_scale) 515 | 516 | return embed 517 | 518 | 519 | 520 | 521 | class MemTransformerLM(nn.Module): 522 | def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner, 523 | dropout, dropatt, tie_weight=True, d_embed=None, 524 | div_val=1, tie_projs=[False], pre_lnorm=False, 525 | tgt_len=None, ext_len=None, mem_len=None, 526 | cutoffs=[], adapt_inp=False, 527 | same_length=False, attn_type=0, clamp_len=-1, 528 | sample_softmax=-1): 529 | ''' 530 | 531 | :param n_token: 单词表的大小 532 | :param n_layer: 16 533 | :param n_head: 10 534 | :param d_model: 410 535 | :param d_head: 41 536 | :param d_inner: 2100 537 | :param dropout: 0.1 538 | :param dropatt: 0.0 539 | :param tie_weight: True 540 | :param d_embed: 410 541 | :param div_val: 1 542 | :param tie_projs: [F,T,T,T] 543 | :param pre_lnorm: False 544 | :param tgt_len: 150 545 | :param ext_len: 0 546 | :param mem_len: 150 547 | :param cutoffs: [20000, 40000, 200000] 548 | :param adapt_inp: 549 | :param same_length: 训练的时候是False 550 | :param attn_type: 0 551 | :param clamp_len: -1 552 | :param sample_softmax: -1 553 | ''' 554 | super(MemTransformerLM, self).__init__() 555 | self.n_token = n_token 556 | 557 | d_embed = d_model if d_embed is None else d_embed 558 | self.d_embed = d_embed 559 | self.d_model = d_model 560 | self.n_head = n_head 561 | self.d_head = d_head 562 | 563 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, 564 | div_val=div_val) 565 | 566 | self.drop = nn.Dropout(dropout) 567 | self.n_layer = n_layer 568 | self.tgt_len = tgt_len 569 | self.mem_len = mem_len 570 | self.ext_len = ext_len 571 | self.max_klen = tgt_len + ext_len + mem_len 572 | self.attn_type = attn_type 573 | self.layers = nn.ModuleList() 574 | if attn_type == 0: # the default attention 575 | for i in range(n_layer): 576 | self.layers.append( 577 | RelPartialLearnableDecoderLayer( 578 | n_head, d_model, d_head, d_inner, dropout, 579 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 580 | dropatt=dropatt, pre_lnorm=pre_lnorm) 581 | ) 582 | elif attn_type == 1: # learnable embeddings 583 | for i in range(n_layer): 584 | self.layers.append( 585 | RelLearnableDecoderLayer( 586 | n_head, d_model, d_head, d_inner, dropout, 587 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 588 | dropatt=dropatt, pre_lnorm=pre_lnorm) 589 | ) 590 | elif attn_type in [2, 3]: # absolute embeddings 591 | for i in range(n_layer): 592 | self.layers.append( 593 | DecoderLayer( 594 | n_head, d_model, d_head, d_inner, dropout, 595 | dropatt=dropatt, pre_lnorm=pre_lnorm) 596 | ) 597 | 598 | self.sample_softmax = sample_softmax 599 | # use sampled softmax 600 | if sample_softmax > 0: 601 | self.out_layer = nn.Linear(d_model, n_token) 602 | if tie_weight: 603 | self.out_layer.weight = self.word_emb.weight 604 | self.tie_weight = tie_weight 605 | self.sampler = LogUniformSampler(n_token, sample_softmax) 606 | 607 | # use adaptive softmax (including standard softmax) 608 | else: 609 | self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, 610 | cutoffs, div_val=div_val) 611 | 612 | if tie_weight: 613 | for i in range(len(self.crit.out_layers)): 614 | self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight 615 
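                # Weight tying (tie_weight above / tie_projs below): each adaptive-softmax
                # output layer reuses the weight matrix of the corresponding embedding
                # cluster, and tie_projs optionally shares the d_proj -> d_embed projection
                # matrices as well. The projection sharing only has an effect when
                # d_model != d_embed (for div_val == 1) or when div_val > 1.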
| 616 | if tie_projs: 617 | for i, tie_proj in enumerate(tie_projs): 618 | if tie_proj and div_val == 1 and d_model != d_embed: 619 | self.crit.out_projs[i] = self.word_emb.emb_projs[0] 620 | elif tie_proj and div_val != 1: 621 | self.crit.out_projs[i] = self.word_emb.emb_projs[i] 622 | 623 | self.same_length = same_length 624 | self.clamp_len = clamp_len 625 | 626 | self._create_params() 627 | 628 | def backward_compatible(self): 629 | self.sample_softmax = -1 630 | 631 | def _create_params(self): 632 | if self.attn_type == 0: # default attention 633 | self.pos_emb = PositionalEmbedding(self.d_model) 634 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 635 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 636 | 637 | elif self.attn_type == 1: # learnable 638 | self.r_emb = nn.Parameter(torch.Tensor( 639 | self.n_layer, self.max_klen, self.n_head, self.d_head)) 640 | self.r_w_bias = nn.Parameter(torch.Tensor( 641 | self.n_layer, self.n_head, self.d_head)) 642 | self.r_bias = nn.Parameter(torch.Tensor( 643 | self.n_layer, self.max_klen, self.n_head)) 644 | elif self.attn_type == 2: # absolute standard 645 | self.pos_emb = PositionalEmbedding(self.d_model) 646 | elif self.attn_type == 3: # absolute deeper SA 647 | self.r_emb = nn.Parameter(torch.Tensor( 648 | self.n_layer, self.max_klen, self.n_head, self.d_head)) 649 | 650 | def reset_length(self, tgt_len, ext_len, mem_len): 651 | self.tgt_len = tgt_len 652 | self.mem_len = mem_len 653 | self.ext_len = ext_len 654 | 655 | def init_mems(self): 656 | if self.mem_len > 0: 657 | mems = [] 658 | param = next(self.parameters()) 659 | for i in range(self.n_layer+1): 660 | empty = torch.empty(0, dtype=param.dtype, device=param.device) 661 | mems.append(empty) 662 | 663 | return mems 664 | else: 665 | return None 666 | 667 | def _update_mems(self, hids, mems, qlen, mlen): 668 | # does not deal with None 669 | if mems is None: return None 670 | 671 | # mems is not None 672 | assert len(hids) == len(mems), 'len(hids) != len(mems)' 673 | 674 | # There are `mlen + qlen` steps that can be cached into mems 675 | # For the next step, the last `ext_len` of the `qlen` tokens 676 | # will be used as the extended context. Hence, we only cache 677 | # the tokens from `mlen + qlen - self.ext_len - self.mem_len` 678 | # to `mlen + qlen - self.ext_len`. 
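        # Worked example with the lengths quoted in this model's docstring
        # (tgt_len = qlen = 150, mem_len = 150, ext_len = 0), at a step where
        # mlen = 150 tokens are already cached:
        #   end_idx = mlen + max(0, qlen - ext_len) = 150 + 150 = 300
        #   beg_idx = max(0, end_idx - mem_len)     = 300 - 150 = 150
        # i.e. the newest mem_len hidden states of [mems; hids] are kept for reuse.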
679 | with torch.no_grad(): 680 | new_mems = [] 681 | end_idx = mlen + max(0, qlen - 0 - self.ext_len) 682 | beg_idx = max(0, end_idx - self.mem_len) 683 | for i in range(len(hids)): 684 | cat = torch.cat([mems[i], hids[i]], dim=0) 685 | new_mems.append(cat[beg_idx:end_idx].detach()) 686 | 687 | return new_mems 688 | 689 | def _forward(self, dec_inp, mems=None): 690 | qlen, bsz = dec_inp.size() 691 | 692 | word_emb = self.word_emb(dec_inp) 693 | 694 | mlen = mems[0].size(0) if mems is not None else 0 695 | klen = mlen + qlen 696 | if self.same_length: 697 | all_ones = word_emb.new_ones(qlen, klen) 698 | mask_len = klen - self.mem_len 699 | if mask_len > 0: 700 | mask_shift_len = qlen - mask_len 701 | else: 702 | mask_shift_len = qlen 703 | dec_attn_mask = (torch.triu(all_ones, 1+mlen) 704 | + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1 705 | else: 706 | dec_attn_mask = torch.triu( 707 | word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] 708 | 709 | hids = [] 710 | if self.attn_type == 0: # default 711 | pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, 712 | dtype=word_emb.dtype) 713 | if self.clamp_len > 0: 714 | pos_seq.clamp_(max=self.clamp_len) 715 | pos_emb = self.pos_emb(pos_seq) 716 | 717 | core_out = self.drop(word_emb) 718 | pos_emb = self.drop(pos_emb) 719 | 720 | hids.append(core_out) 721 | for i, layer in enumerate(self.layers): 722 | mems_i = None if mems is None else mems[i] 723 | core_out = layer(core_out, pos_emb, self.r_w_bias, 724 | self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) 725 | hids.append(core_out) 726 | 727 | 728 | elif self.attn_type == 1: # learnable 729 | core_out = self.drop(word_emb) 730 | hids.append(core_out) 731 | for i, layer in enumerate(self.layers): 732 | if self.clamp_len > 0: 733 | r_emb = self.r_emb[i][-self.clamp_len :] 734 | r_bias = self.r_bias[i][-self.clamp_len :] 735 | else: 736 | r_emb, r_bias = self.r_emb[i], self.r_bias[i] 737 | 738 | mems_i = None if mems is None else mems[i] 739 | core_out = layer(core_out, r_emb, self.r_w_bias[i], 740 | r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) 741 | hids.append(core_out) 742 | elif self.attn_type == 2: # absolute 743 | pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, 744 | dtype=word_emb.dtype) 745 | if self.clamp_len > 0: 746 | pos_seq.clamp_(max=self.clamp_len) 747 | pos_emb = self.pos_emb(pos_seq) 748 | 749 | core_out = self.drop(word_emb + pos_emb[-qlen:]) 750 | 751 | hids.append(core_out) 752 | for i, layer in enumerate(self.layers): 753 | mems_i = None if mems is None else mems[i] 754 | if mems_i is not None and i == 0: 755 | mems_i += pos_emb[:mlen] 756 | core_out = layer(core_out, dec_attn_mask=dec_attn_mask, 757 | mems=mems_i) 758 | hids.append(core_out) 759 | elif self.attn_type == 3: 760 | core_out = self.drop(word_emb) 761 | 762 | hids.append(core_out) 763 | for i, layer in enumerate(self.layers): 764 | mems_i = None if mems is None else mems[i] 765 | if mems_i is not None and mlen > 0: 766 | cur_emb = self.r_emb[i][:-qlen] 767 | cur_size = cur_emb.size(0) 768 | if cur_size < mlen: 769 | cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1) 770 | cur_emb = torch.cat([cur_emb_pad, cur_emb], 0) 771 | else: 772 | cur_emb = cur_emb[-mlen:] 773 | mems_i += cur_emb.view(mlen, 1, -1) 774 | core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) 775 | 776 | core_out = layer(core_out, dec_attn_mask=dec_attn_mask, 777 | mems=mems_i) 778 | hids.append(core_out) 779 | 780 | core_out = self.drop(core_out) 781 | 782 | 
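        # Note: the call below passes (mlen, qlen) although the signature above is
        # _update_mems(hids, mems, qlen, mlen). With ext_len == 0 the two orderings
        # produce the same caching window, because end_idx reduces to mlen + qlen
        # either way; the distinction only matters when ext_len > 0.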
new_mems = self._update_mems(hids, mems, mlen, qlen) 783 | 784 | return core_out, new_mems 785 | 786 | def forward(self, data, target, *mems): 787 | # nn.DataParallel does not allow size(0) tensors to be broadcasted. 788 | # So, have to initialize size(0) mems inside the model forward. 789 | # Moreover, have to return new_mems to allow nn.DataParallel to piece 790 | # them together. 791 | if not mems: 792 | mems = self.init_mems() 793 | 794 | tgt_len = target.size(0) 795 | hidden, new_mems = self._forward(data, mems=mems) 796 | 797 | pred_hid = hidden[-tgt_len:] 798 | if self.sample_softmax > 0 and self.training: 799 | assert self.tie_weight 800 | logit = sample_logits(self.word_emb, 801 | self.out_layer.bias, target, pred_hid, self.sampler) 802 | loss = -F.log_softmax(logit, -1)[:, :, 0] 803 | else: 804 | loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1)) 805 | loss = loss.view(tgt_len, -1) 806 | 807 | if new_mems is None: 808 | return [loss] 809 | else: 810 | return [loss] + new_mems 811 | 812 | 813 | 814 | if __name__ == '__main__': 815 | import argparse 816 | 817 | parser = argparse.ArgumentParser(description='unit test') 818 | 819 | parser.add_argument('--n_layer', type=int, default=4, help='') 820 | parser.add_argument('--n_rel_layer', type=int, default=4, help='') 821 | parser.add_argument('--n_head', type=int, default=2, help='') 822 | parser.add_argument('--d_head', type=int, default=2, help='') 823 | parser.add_argument('--d_model', type=int, default=200, help='') 824 | parser.add_argument('--d_embed', type=int, default=200, help='') 825 | parser.add_argument('--d_inner', type=int, default=200, help='') 826 | parser.add_argument('--dropout', type=float, default=0.0, help='') 827 | parser.add_argument('--cuda', action='store_true', help='') 828 | parser.add_argument('--seed', type=int, default=1111, help='') 829 | parser.add_argument('--multi_gpu', action='store_true', help='') 830 | 831 | args = parser.parse_args() 832 | 833 | device = torch.device("cuda" if args.cuda else "cpu") 834 | 835 | B = 4 836 | tgt_len, mem_len, ext_len = 36, 36, 0 837 | data_len = tgt_len * 20 838 | args.n_token = 10000 839 | 840 | import data_utils 841 | 842 | data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device) 843 | diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len) 844 | 845 | cutoffs = [args.n_token // 2] 846 | tie_projs = [False] + [True] * len(cutoffs) 847 | 848 | for div_val in [1, 2]: 849 | for d_embed in [200, 100]: 850 | model = MemTransformerLM(args.n_token, args.n_layer, args.n_head, 851 | args.d_model, args.d_head, args.d_inner, args.dropout, 852 | dropatt=args.dropout, tie_weight=True, 853 | d_embed=d_embed, div_val=div_val, 854 | tie_projs=tie_projs, pre_lnorm=True, 855 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 856 | cutoffs=cutoffs, attn_type=0).to(device) 857 | 858 | print(sum(p.numel() for p in model.parameters())) 859 | 860 | mems = tuple() 861 | for idx, (inp, tgt, seqlen) in enumerate(diter): 862 | print('batch {}'.format(idx)) 863 | out = model(inp, tgt, *mems) 864 | mems = out[1:] 865 | -------------------------------------------------------------------------------- /code_for_LM/train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import time 4 | import math 5 | import os, sys 6 | import itertools 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | 14 | 
from data_utils import get_lm_corpus 15 | from mem_transformer import MemTransformerLM 16 | from utils.exp_utils import create_exp_dir 17 | from utils.data_parallel import BalancedDataParallel 18 | 19 | 20 | 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,5,7" 22 | 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 25 | parser.add_argument('--alinlen', type=int, default=3000, help='xhm') 26 | parser.add_argument('--debug', action='store_true', 27 | help='run in debug mode (do not create exp dir)') 28 | 29 | parser.add_argument('--work_dir', default='LM-TFM', type=str, 30 | help='experiment directory.') 31 | parser.add_argument('--not_tied', action='store_true', 32 | help='do not tie the word embedding and softmax weights') 33 | parser.add_argument('--data', type=str, 34 | help='location of the data corpus') 35 | parser.add_argument('--dataset', type=str, default='dccl', 36 | choices=['dccl'], 37 | help='dataset name') 38 | 39 | parser.add_argument('--seed', type=int, default=1111, 40 | help='random seed') 41 | parser.add_argument('--ext_len', type=int, default=0, 42 | help='length of the extended context') 43 | 44 | 45 | parser.add_argument('--n_layer', type=int, default=6, 46 | help='number of total layers') 47 | parser.add_argument('--d_model', type=int, default=410, 48 | help='model dimension') 49 | parser.add_argument('--n_head', type=int, default=10, 50 | help='number of heads') 51 | parser.add_argument('--d_head', type=int, default=41, 52 | help='head dimension') 53 | parser.add_argument('--d_embed', type=int, default=-1, 54 | help='embedding dimension') 55 | 56 | parser.add_argument('--cuda', action='store_true', 57 | help='use CUDA') 58 | 59 | 60 | parser.add_argument('--fp16', action='store_true', 61 | help='Run in pseudo-fp16 mode (fp16 storage fp32 math).') 62 | 63 | parser.add_argument('--batch_size', type=int, default=60, 64 | help='batch size') 65 | 66 | 67 | parser.add_argument('--tgt_len', type=int, default=150, 68 | help='number of tokens to predict') 69 | parser.add_argument('--eval_tgt_len', type=int, default=150, 70 | help='number of tokens to predict for evaluation') 71 | 72 | parser.add_argument('--adaptive', action='store_true', 73 | help='use adaptive softmax') 74 | 75 | 76 | 77 | parser.add_argument('--d_inner', type=int, default=1000, 78 | help='inner dimension in FF') 79 | parser.add_argument('--dropout', type=float, default=0.0, 80 | help='global dropout rate') 81 | parser.add_argument('--dropatt', type=float, default=0.0, 82 | help='attention probability dropout rate') 83 | parser.add_argument('--init', default='normal', type=str, 84 | help='parameter initializer to use.') 85 | parser.add_argument('--emb_init', default='normal', type=str, 86 | help='parameter initializer to use.') 87 | parser.add_argument('--init_range', type=float, default=0.1, 88 | help='parameters initialized by U(-init_range, init_range)') 89 | parser.add_argument('--emb_init_range', type=float, default=0.01, 90 | help='parameters initialized by U(-init_range, init_range)') 91 | parser.add_argument('--init_std', type=float, default=0.02, 92 | help='parameters initialized by N(0, init_std)') 93 | parser.add_argument('--proj_init_std', type=float, default=0.01, 94 | help='parameters initialized by N(0, init_std)') 95 | parser.add_argument('--optim', default='adam', type=str, 96 | choices=['adam', 'sgd', 'adagrad'], 97 | help='optimizer to use.') 98 | parser.add_argument('--lr', type=float, default=0.00025, 99 | help='initial learning rate (0.00025|5 for 
adam|sgd)') 100 | parser.add_argument('--mom', type=float, default=0.0, 101 | help='momentum for sgd') 102 | parser.add_argument('--scheduler', default='cosine', type=str, 103 | choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'], 104 | help='lr scheduler to use.') 105 | parser.add_argument('--warmup_step', type=int, default=0, 106 | help='upper epoch limit') 107 | parser.add_argument('--decay_rate', type=float, default=0.5, 108 | help='decay factor when ReduceLROnPlateau is used') 109 | parser.add_argument('--lr_min', type=float, default=0.0, 110 | help='minimum learning rate during annealing') 111 | parser.add_argument('--clip', type=float, default=0.25, 112 | help='gradient clipping') 113 | parser.add_argument('--clip_nonemb', action='store_true', 114 | help='only clip the gradient of non-embedding params') 115 | parser.add_argument('--max_step', type=int, default=100000, 116 | help='upper epoch limit') 117 | 118 | parser.add_argument('--batch_chunk', type=int, default=1, 119 | help='split batch into chunks to save memory') 120 | 121 | parser.add_argument('--mem_len', type=int, default=0, 122 | help='length of the retained previous heads') 123 | parser.add_argument('--restart_dir', type=str, default='', 124 | help='restart dir') 125 | 126 | parser.add_argument('--div_val', type=int, default=1, 127 | help='divident value for adapative input and softmax') 128 | parser.add_argument('--pre_lnorm', action='store_true', 129 | help='apply LayerNorm to the input instead of the output') 130 | parser.add_argument('--varlen', action='store_true', 131 | help='use variable length') 132 | parser.add_argument('--multi_gpu', action='store_true', 133 | help='use multiple GPU') 134 | parser.add_argument('--log-interval', type=int, default=200, 135 | help='report interval') 136 | parser.add_argument('--eval-interval', type=int, default=2000, 137 | help='evaluation interval') 138 | 139 | parser.add_argument('--restart', action='store_true', 140 | help='restart training from the saved checkpoint') 141 | 142 | 143 | parser.add_argument('--same_length', action='store_true', 144 | help='use the same attn length for all tokens') 145 | parser.add_argument('--attn_type', type=int, default=0, 146 | help='attention type. 0 for ours, 1 for Shaw et al,' 147 | '2 for Vaswani et al, 3 for Al Rfou et al.') 148 | parser.add_argument('--clamp_len', type=int, default=-1, 149 | help='use the same pos embeddings after clamp_len') 150 | parser.add_argument('--eta_min', type=float, default=0.0, 151 | help='min learning rate for cosine scheduler') 152 | parser.add_argument('--gpu0_bsz', type=int, default=-1, 153 | help='batch size on gpu 0') 154 | parser.add_argument('--max_eval_steps', type=int, default=-1, 155 | help='max eval steps') 156 | parser.add_argument('--sample_softmax', type=int, default=-1, 157 | help='number of samples in sampled softmax') 158 | parser.add_argument('--patience', type=int, default=0, 159 | help='patience') 160 | parser.add_argument('--finetune_v2', action='store_true', 161 | help='finetune v2') 162 | parser.add_argument('--finetune_v3', action='store_true', 163 | help='finetune v3') 164 | 165 | parser.add_argument('--static-loss-scale', type=float, default=1, 166 | help='Static loss scale, positive power of 2 values can ' 167 | 'improve fp16 convergence.') 168 | parser.add_argument('--dynamic-loss-scale', action='store_true', 169 | help='Use dynamic loss scaling. 
If supplied, this argument' 170 | ' supersedes --static-loss-scale.') 171 | 172 | 173 | #####################################程序开始######################################################################################################## 174 | args = parser.parse_args() 175 | args.tied = not args.not_tied # code True 176 | 177 | if args.d_embed < 0: 178 | args.d_embed = args.d_model # code 410 179 | 180 | assert args.ext_len >= 0, 'extended context length must be non-negative' # code 0 181 | assert args.batch_size % args.batch_chunk == 0 182 | 183 | 184 | logging = create_exp_dir(args.work_dir, 185 | scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug) 186 | 187 | 188 | 189 | # Set the random seed manually for reproducibility. 190 | np.random.seed(args.seed) 191 | torch.manual_seed(args.seed) 192 | if torch.cuda.is_available(): 193 | if not args.cuda: 194 | print('WARNING: You have a CUDA device, so you should probably run with --cuda') 195 | else: 196 | torch.cuda.manual_seed_all(args.seed) 197 | 198 | # Validate `--fp16` option 199 | if args.fp16: 200 | if not args.cuda: 201 | print('WARNING: --fp16 requires --cuda, ignoring --fp16 option') 202 | args.fp16 = False 203 | else: 204 | try: 205 | from apex.fp16_utils import FP16_Optimizer 206 | except: 207 | print('WARNING: apex not installed, ignoring --fp16 option') 208 | args.fp16 = False 209 | 210 | device = torch.device('cuda' if args.cuda else 'cpu') 211 | 212 | ############################################################################### 213 | # Load data 214 | ############################################################################### 215 | 216 | assert args.alinlen == 3000 217 | corpus = get_lm_corpus(args.data, args.alinlen) 218 | print("数据加载成功") 219 | 220 | print("保存单词表") 221 | corpus.vocab.save_symbol(os.path.join(args.work_dir,"vocab.txt")) 222 | 223 | ntokens = len(corpus.vocab) # code 单词表的大小 224 | args.n_token = ntokens 225 | 226 | 227 | eval_batch_size = 10 228 | tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len, 229 | device='cpu', ext_len=args.ext_len) 230 | va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len, 231 | device=device, ext_len=args.ext_len) 232 | 233 | print("迭代器加载成功") 234 | # te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len, 235 | # device=device, ext_len=args.ext_len) 236 | 237 | # adaptive softmax / embedding 238 | cutoffs, tie_projs = [], [False] 239 | if args.adaptive: 240 | cutoffs = [70000, 150000, 250000] 241 | tie_projs += [False] * len(cutoffs) 242 | 243 | 244 | ############################################################################### 245 | # Build the model 246 | ############################################################################### 247 | def init_weight(weight): 248 | if args.init == 'uniform': 249 | nn.init.uniform_(weight, -args.init_range, args.init_range) 250 | elif args.init == 'normal': 251 | nn.init.normal_(weight, 0.0, args.init_std) 252 | 253 | def init_bias(bias): 254 | nn.init.constant_(bias, 0.0) 255 | 256 | # 许海明 257 | def weights_init(m): 258 | classname = m.__class__.__name__ 259 | if classname.find('Linear') != -1: 260 | if hasattr(m, 'weight') and m.weight is not None: 261 | init_weight(m.weight) 262 | if hasattr(m, 'bias') and m.bias is not None: 263 | init_bias(m.bias) 264 | elif classname.find('AdaptiveEmbedding') != -1: 265 | if hasattr(m, 'emb_projs'): 266 | for i in range(len(m.emb_projs)): 267 | if m.emb_projs[i] is not None: 268 | nn.init.normal_(m.emb_projs[i], 0.0, 
args.proj_init_std) 269 | elif classname.find('Embedding') != -1: 270 | if hasattr(m, 'weight'): 271 | init_weight(m.weight) 272 | elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: 273 | if hasattr(m, 'cluster_weight') and m.cluster_weight is not None: 274 | init_weight(m.cluster_weight) 275 | if hasattr(m, 'cluster_bias') and m.cluster_bias is not None: 276 | init_bias(m.cluster_bias) 277 | if hasattr(m, 'out_projs'): 278 | for i in range(len(m.out_projs)): 279 | if m.out_projs[i] is not None: 280 | nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std) 281 | elif classname.find('LayerNorm') != -1: 282 | if hasattr(m, 'weight'): 283 | nn.init.normal_(m.weight, 1.0, args.init_std) 284 | if hasattr(m, 'bias') and m.bias is not None: 285 | init_bias(m.bias) 286 | elif classname.find('TransformerLM') != -1: 287 | if hasattr(m, 'r_emb'): 288 | init_weight(m.r_emb) 289 | if hasattr(m, 'r_w_bias'): 290 | init_weight(m.r_w_bias) 291 | if hasattr(m, 'r_r_bias'): 292 | init_weight(m.r_r_bias) 293 | if hasattr(m, 'r_bias'): 294 | init_bias(m.r_bias) 295 | 296 | def update_dropout(m): 297 | classname = m.__class__.__name__ 298 | if classname.find('Dropout') != -1: 299 | if hasattr(m, 'p'): 300 | m.p = args.dropout 301 | 302 | def update_dropatt(m): 303 | if hasattr(m, 'dropatt'): 304 | m.dropatt.p = args.dropatt 305 | 306 | if args.restart: 307 | with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f: 308 | model = torch.load(f) 309 | if not args.fp16: 310 | model = model.float() 311 | model.apply(update_dropout) 312 | model.apply(update_dropatt) 313 | else: 314 | model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model, 315 | args.d_head, args.d_inner, args.dropout, args.dropatt, 316 | tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val, 317 | tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len, 318 | ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs, 319 | same_length=args.same_length, attn_type=args.attn_type, 320 | clamp_len=args.clamp_len, sample_softmax=args.sample_softmax) 321 | model.apply(weights_init) 322 | model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing 323 | args.n_all_param = sum([p.nelement() for p in model.parameters()]) 324 | args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()]) 325 | 326 | if args.fp16: 327 | model = model.half() 328 | 329 | 330 | 331 | 332 | if args.multi_gpu: 333 | model = model.to(device) 334 | if args.gpu0_bsz >= 0: 335 | para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk, 336 | model, dim=1).to(device) 337 | else: 338 | para_model = nn.DataParallel(model, dim=1).to(device) 339 | else: 340 | para_model = model.to(device) 341 | 342 | #### optimizer 343 | if args.optim.lower() == 'sgd': 344 | if args.sample_softmax > 0: 345 | dense_params, sparse_params = [], [] 346 | for param in model.parameters(): 347 | if param.size() == model.word_emb.weight.size(): 348 | sparse_params.append(param) 349 | else: 350 | dense_params.append(param) 351 | optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2) 352 | optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom) 353 | else: 354 | optimizer = optim.SGD(model.parameters(), lr=args.lr, 355 | momentum=args.mom) 356 | elif args.optim.lower() == 'adam': 357 | if args.sample_softmax > 0: 358 | dense_params, sparse_params = [], [] 359 | for param in model.parameters(): 360 | if param.size() == model.word_emb.weight.size(): 361 | 
sparse_params.append(param) 362 | else: 363 | dense_params.append(param) 364 | optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr) 365 | optimizer = optim.Adam(dense_params, lr=args.lr) 366 | else: 367 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 368 | elif args.optim.lower() == 'adagrad': 369 | optimizer = optim.Adagrad(model.parameters(), lr=args.lr) 370 | 371 | #### scheduler 372 | if args.scheduler == 'cosine': 373 | # here we do not set eta_min to lr_min to be backward compatible 374 | # because in previous versions eta_min is default to 0 375 | # rather than the default value of lr_min 1e-6 376 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 377 | args.max_step, eta_min=args.eta_min) # should use eta_min arg 378 | if args.sample_softmax > 0: 379 | scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse, 380 | args.max_step, eta_min=args.eta_min) # should use eta_min arg 381 | elif args.scheduler == 'inv_sqrt': 382 | # originally used for Transformer (in Attention is all you need) 383 | def lr_lambda(step): 384 | # return a multiplier instead of a learning rate 385 | if step == 0 and args.warmup_step == 0: 386 | return 1. 387 | else: 388 | return 1. / (step ** 0.5) if step > args.warmup_step \ 389 | else step / (args.warmup_step ** 1.5) 390 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 391 | elif args.scheduler == 'dev_perf': 392 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 393 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) 394 | if args.sample_softmax > 0: 395 | scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse, 396 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) 397 | elif args.scheduler == 'constant': 398 | pass 399 | 400 | if args.cuda and args.fp16: 401 | # If args.dynamic_loss_scale is False, static_loss_scale will be used. 402 | # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale. 403 | optimizer = FP16_Optimizer(optimizer, 404 | static_loss_scale = args.static_loss_scale, 405 | dynamic_loss_scale = args.dynamic_loss_scale, 406 | dynamic_loss_args = {'init_scale': 2 ** 16}) 407 | 408 | if args.restart: 409 | if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')): 410 | with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f: 411 | opt_state_dict = torch.load(f) 412 | optimizer.load_state_dict(opt_state_dict) 413 | else: 414 | print('Optimizer was not saved. Start from scratch.') 415 | 416 | logging('=' * 100) 417 | for k, v in args.__dict__.items(): 418 | logging(' - {} : {}'.format(k, v)) 419 | logging('=' * 100) 420 | logging('#params = {}'.format(args.n_all_param)) 421 | logging('#non emb params = {}'.format(args.n_nonemb_param)) 422 | 423 | ############################################################################### 424 | # Training code 425 | ############################################################################### 426 | 427 | def evaluate(eval_iter): 428 | # Turn on evaluation mode which disables dropout. 429 | model.eval() 430 | 431 | # If the model does not use memory at all, make the ext_len longer. 432 | # Otherwise, make the mem_len longer and keep the ext_len the same. 
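    # For example, with tgt_len = 150, ext_len = 0 and eval_tgt_len = 64:
    #   mem_len == 0   -> model.reset_length(64, 0 + 150 - 64, 0)   = (64, 86, 0)
    #   mem_len == 150 -> model.reset_length(64, 0, 150 + 150 - 64) = (64, 0, 236)
    # Either way, the total context tgt_len + ext_len + mem_len that the model was
    # trained with is preserved at evaluation time.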
433 | if args.mem_len == 0: 434 | model.reset_length(args.eval_tgt_len, 435 | args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len) 436 | else: 437 | model.reset_length(args.eval_tgt_len, 438 | args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len) 439 | 440 | # Evaluation 441 | total_len, total_loss = 0, 0. 442 | with torch.no_grad(): 443 | mems = tuple() 444 | for i, (data, target, seq_len) in enumerate(eval_iter): 445 | if args.max_eval_steps > 0 and i >= args.max_eval_steps: 446 | break 447 | ret = model(data, target, *mems) 448 | loss, mems = ret[0], ret[1:] 449 | loss = loss.mean() 450 | total_loss += seq_len * loss.float().item() 451 | total_len += seq_len 452 | 453 | # Switch back to the training mode 454 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 455 | model.train() 456 | 457 | return total_loss / total_len 458 | 459 | 460 | def train(): 461 | # Turn on training mode which enables dropout. 462 | global train_step, train_loss, best_val_loss, eval_start_time, log_start_time 463 | model.train() 464 | if args.batch_chunk > 1: 465 | mems = [tuple() for _ in range(args.batch_chunk)] 466 | else: 467 | mems = tuple() 468 | train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter 469 | for batch, (data, target, seq_len) in enumerate(train_iter): 470 | model.zero_grad() 471 | if args.batch_chunk > 1: 472 | data_chunks = torch.chunk(data, args.batch_chunk, 1) 473 | target_chunks = torch.chunk(target, args.batch_chunk, 1) 474 | for i in range(args.batch_chunk): 475 | data_i = data_chunks[i].contiguous() 476 | target_i = target_chunks[i].contiguous() 477 | ret = para_model(data_i, target_i, *mems[i]) 478 | loss, mems[i] = ret[0], ret[1:] 479 | loss = loss.float().mean().type_as(loss) / args.batch_chunk 480 | if args.fp16: 481 | optimizer.backward(loss) 482 | else: 483 | loss.backward() 484 | train_loss += loss.float().item() 485 | else: 486 | ret = para_model(data, target, *mems) 487 | loss, mems = ret[0], ret[1:] 488 | loss = loss.float().mean().type_as(loss) 489 | if args.fp16: 490 | optimizer.backward(loss) 491 | else: 492 | loss.backward() 493 | 494 | train_loss += loss.float().item() 495 | 496 | if args.fp16: 497 | optimizer.clip_master_grads(args.clip) 498 | else: 499 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 500 | 501 | optimizer.step() 502 | if args.sample_softmax > 0: 503 | optimizer_sparse.step() 504 | 505 | # step-wise learning rate annealing 506 | train_step += 1 507 | if args.scheduler in ['cosine', 'constant', 'dev_perf']: 508 | # linear warmup stage 509 | if train_step < args.warmup_step: 510 | curr_lr = args.lr * train_step / args.warmup_step 511 | optimizer.param_groups[0]['lr'] = curr_lr 512 | if args.sample_softmax > 0: 513 | optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2 514 | else: 515 | if args.scheduler == 'cosine': 516 | scheduler.step(train_step) 517 | if args.sample_softmax > 0: 518 | scheduler_sparse.step(train_step) 519 | elif args.scheduler == 'inv_sqrt': 520 | scheduler.step(train_step) 521 | 522 | if train_step % args.log_interval == 0: 523 | cur_loss = train_loss / args.log_interval 524 | elapsed = time.time() - log_start_time 525 | log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \ 526 | '| ms/batch {:5.2f} | loss {:5.2f}'.format( 527 | epoch, train_step, batch+1, optimizer.param_groups[0]['lr'], 528 | elapsed * 1000 / args.log_interval, cur_loss) 529 | if args.dataset in ['enwik8', 'text8']: 530 | log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2)) 531 | else: 532 | 
log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss)) 533 | logging(log_str) 534 | train_loss = 0 535 | log_start_time = time.time() 536 | 537 | if train_step % args.eval_interval == 0: 538 | val_loss = evaluate(va_iter) 539 | logging('-' * 100) 540 | log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \ 541 | '| valid loss {:5.2f}'.format( 542 | train_step // args.eval_interval, train_step, 543 | (time.time() - eval_start_time), val_loss) 544 | if args.dataset in ['enwik8', 'text8']: 545 | log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2)) 546 | else: 547 | log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss)) 548 | logging(log_str) 549 | logging('-' * 100) 550 | # Save the model if the validation loss is the best we've seen so far. 551 | if not best_val_loss or val_loss < best_val_loss: 552 | if not args.debug: 553 | save_path = os.path.join(args.work_dir, 'model.pt') 554 | checkpoint = { 555 | 'state_dict': model.state_dict(), 556 | 'config': args 557 | } 558 | torch.save(checkpoint, save_path) 559 | 560 | # print('Best tmp model f1score: {}'.format(best_score)) 561 | # 562 | # with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f: 563 | # torch.save(model, f) 564 | # with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f: 565 | # torch.save(optimizer.state_dict(), f) 566 | 567 | 568 | best_val_loss = val_loss 569 | 570 | # dev-performance based learning rate annealing 571 | if args.scheduler == 'dev_perf': 572 | scheduler.step(val_loss) 573 | if args.sample_softmax > 0: 574 | scheduler_sparse.step(val_loss) 575 | 576 | eval_start_time = time.time() 577 | 578 | if train_step == args.max_step: 579 | break 580 | 581 | # Loop over epochs. 582 | train_step = 0 583 | train_loss = 0 584 | best_val_loss = None 585 | 586 | log_start_time = time.time() 587 | eval_start_time = time.time() 588 | 589 | # At any point you can hit Ctrl + C to break out of training early. 590 | try: 591 | for epoch in itertools.count(start=1): 592 | print("许海明提醒你:开始",epoch) 593 | train() 594 | if train_step == args.max_step: 595 | logging('-' * 100) 596 | logging('End of training') 597 | break 598 | except KeyboardInterrupt: 599 | logging('-' * 100) 600 | logging('Exiting from training early') 601 | 602 | # # Load the best saved model. 603 | # with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: 604 | # model = torch.load(f) 605 | # para_model = model.to(device) 606 | # 607 | # # Run on test data. 
608 | # test_loss = evaluate(te_iter) 609 | # logging('=' * 100) 610 | # if args.dataset in ['enwik8', 'text8']: 611 | # logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format( 612 | # test_loss, test_loss / math.log(2))) 613 | # else: 614 | # logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format( 615 | # test_loss, math.exp(test_loss))) 616 | # logging('=' * 100) 617 | -------------------------------------------------------------------------------- /code_for_LM/utils/adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class AdaptiveLogSoftmax(nn.Module): 10 | def __init__(self, in_features, n_classes, cutoffs, keep_order=False): 11 | super(AdaptiveLogSoftmax, self).__init__() 12 | 13 | cutoffs = list(cutoffs) 14 | 15 | if (cutoffs != sorted(cutoffs)) \ 16 | or (min(cutoffs) <= 0) \ 17 | or (max(cutoffs) >= (n_classes - 1)) \ 18 | or (len(set(cutoffs)) != len(cutoffs)) \ 19 | or any([int(c) != c for c in cutoffs]): 20 | 21 | raise ValueError("cutoffs should be a sequence of unique, positive " 22 | "integers sorted in an increasing order, where " 23 | "each value is between 1 and n_classes-1") 24 | 25 | self.in_features = in_features 26 | self.n_classes = n_classes 27 | self.cutoffs = cutoffs + [n_classes] 28 | 29 | self.shortlist_size = self.cutoffs[0] 30 | self.n_clusters = len(self.cutoffs) - 1 31 | self.head_size = self.shortlist_size + self.n_clusters 32 | 33 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features)) 34 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 35 | 36 | self.keep_order = keep_order 37 | 38 | 39 | def forward(self, hidden, target, weight, bias, keep_order=False): 40 | if hidden.size(0) != target.size(0): 41 | raise RuntimeError('Input and target should have the same size ' 42 | 'in the batch dimension.') 43 | 44 | head_weight = torch.cat( 45 | [weight[:self.shortlist_size], self.cluster_weight], dim=0) 46 | head_bias = torch.cat( 47 | [bias[:self.shortlist_size], self.cluster_bias], dim=0) 48 | 49 | head_logit = F.linear(hidden, head_weight, bias=head_bias) 50 | head_logprob = F.log_softmax(head_logit, dim=1) 51 | 52 | nll = torch.zeros_like(target, 53 | dtype=hidden.dtype, device=hidden.device) 54 | 55 | offset = 0 56 | cutoff_values = [0] + self.cutoffs 57 | for i in range(len(cutoff_values) - 1): 58 | l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1] 59 | 60 | mask_i = (target >= l_idx) & (target < h_idx) 61 | indices_i = mask_i.nonzero().squeeze() 62 | 63 | if indices_i.numel() == 0: 64 | continue 65 | 66 | target_i = target.index_select(0, indices_i) - l_idx 67 | head_logprob_i = head_logprob.index_select(0, indices_i) 68 | 69 | if i == 0: 70 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 71 | else: 72 | weight_i = weight[l_idx:h_idx] 73 | bias_i = bias[l_idx:h_idx] 74 | 75 | hidden_i = hidden.index_select(0, indices_i) 76 | 77 | tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i) 78 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 79 | 80 | logprob_i = head_logprob_i[:, -i] \ 81 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 82 | 83 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 84 | nll.index_copy_(0, indices_i, -logprob_i) 85 | else: 86 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 87 | 88 | 
offset += logprob_i.size(0) 89 | 90 | return nll 91 | -------------------------------------------------------------------------------- /code_for_LM/utils/data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.parallel import DataParallel 3 | import torch 4 | from torch.nn.parallel._functions import Scatter 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, torch.Tensor): 15 | try: 16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 17 | except: 18 | print('obj', obj.size()) 19 | print('dim', dim) 20 | print('chunk_sizes', chunk_sizes) 21 | quit() 22 | if isinstance(obj, tuple) and len(obj) > 0: 23 | return list(zip(*map(scatter_map, obj))) 24 | if isinstance(obj, list) and len(obj) > 0: 25 | return list(map(list, zip(*map(scatter_map, obj)))) 26 | if isinstance(obj, dict) and len(obj) > 0: 27 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 28 | return [obj for targets in target_gpus] 29 | 30 | # After scatter_map is called, a scatter_map cell will exist. This cell 31 | # has a reference to the actual function scatter_map, which has references 32 | # to a closure that has a reference to the scatter_map cell (because the 33 | # fn is recursive). To avoid this reference cycle, we set the function to 34 | # None, clearing the cell 35 | try: 36 | return scatter_map(inputs) 37 | finally: 38 | scatter_map = None 39 | 40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 41 | r"""Scatter with support for kwargs dictionary""" 42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 44 | if len(inputs) < len(kwargs): 45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 46 | elif len(kwargs) < len(inputs): 47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 48 | inputs = tuple(inputs) 49 | kwargs = tuple(kwargs) 50 | return inputs, kwargs 51 | 52 | class BalancedDataParallel(DataParallel): 53 | def __init__(self, gpu0_bsz, *args, **kwargs): 54 | self.gpu0_bsz = gpu0_bsz 55 | super().__init__(*args, **kwargs) 56 | 57 | def forward(self, *inputs, **kwargs): 58 | if not self.device_ids: 59 | return self.module(*inputs, **kwargs) 60 | if self.gpu0_bsz == 0: 61 | device_ids = self.device_ids[1:] 62 | else: 63 | device_ids = self.device_ids 64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 65 | if len(self.device_ids) == 1: 66 | return self.module(*inputs[0], **kwargs[0]) 67 | replicas = self.replicate(self.module, self.device_ids) 68 | if self.gpu0_bsz == 0: 69 | replicas = replicas[1:] 70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 71 | return self.gather(outputs, self.output_device) 72 | 73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 74 | return parallel_apply(replicas, inputs, kwargs, device_ids) 75 | 76 | def scatter(self, inputs, kwargs, device_ids): 77 | bsz = inputs[0].size(self.dim) 78 | num_dev = len(self.device_ids) 79 | gpu0_bsz = self.gpu0_bsz 80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 81 | if gpu0_bsz < bsz_unit: 82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 83 | delta = bsz - 
sum(chunk_sizes) 84 | for i in range(delta): 85 | chunk_sizes[i + 1] += 1 86 | if gpu0_bsz == 0: 87 | chunk_sizes = chunk_sizes[1:] 88 | else: 89 | return super().scatter(inputs, kwargs, device_ids) 90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 91 | 92 | -------------------------------------------------------------------------------- /code_for_LM/utils/exp_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os, shutil 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | def logging(s, log_path, print_=True, log_=True): 10 | if print_: 11 | print(s) 12 | if log_: 13 | with open(log_path, 'a+') as f_log: 14 | f_log.write(s + '\n') 15 | 16 | def get_logger(log_path, **kwargs): 17 | return functools.partial(logging, log_path=log_path, **kwargs) 18 | 19 | 20 | # 许海明 21 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False): 22 | if debug: 23 | print('Debug Mode : no experiment dir created') 24 | return functools.partial(logging, log_path=None, log_=False) 25 | 26 | if not os.path.exists(dir_path): 27 | os.makedirs(dir_path) 28 | 29 | print('Experiment dir : {}'.format(dir_path)) 30 | if scripts_to_save is not None: 31 | script_path = os.path.join(dir_path, 'scripts') 32 | if not os.path.exists(script_path): 33 | os.makedirs(script_path) 34 | for script in scripts_to_save: 35 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) 36 | shutil.copyfile(script, dst_file) 37 | 38 | return get_logger(log_path=os.path.join(dir_path, 'log.txt')) 39 | 40 | def save_checkpoint(model, optimizer, path, epoch): 41 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch))) 42 | torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch))) 43 | -------------------------------------------------------------------------------- /code_for_LM/utils/log_uniform_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | class LogUniformSampler(object): 6 | def __init__(self, range_max, n_sample): 7 | """ 8 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 9 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 10 | 11 | expected count can be approximated by 1 - (1 - p)^n 12 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 13 | 14 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 15 | """ 16 | with torch.no_grad(): 17 | self.range_max = range_max 18 | log_indices = torch.arange(1., range_max+2., 1.).log_() 19 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 20 | # print('P', self.dist.numpy().tolist()[-30:]) 21 | 22 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 23 | 24 | self.n_sample = n_sample 25 | 26 | def sample(self, labels): 27 | """ 28 | labels: [b1, b2] 29 | Return 30 | true_log_probs: [b1, b2] 31 | samp_log_probs: [n_sample] 32 | neg_samples: [n_sample] 33 | """ 34 | 35 | # neg_samples = torch.empty(0).long() 36 | n_sample = self.n_sample 37 | n_tries = 2 * n_sample 38 | 39 | with torch.no_grad(): 40 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 41 | device = labels.device 42 | neg_samples = neg_samples.to(device) 43 | true_log_probs = self.log_q[labels].to(device) 44 | 
samp_log_probs = self.log_q[neg_samples].to(device) 45 | return true_log_probs, samp_log_probs, neg_samples 46 | 47 | def sample_logits(embedding, bias, labels, inputs, sampler): 48 | """ 49 | embedding: an nn.Embedding layer 50 | bias: [n_vocab] 51 | labels: [b1, b2] 52 | inputs: [b1, b2, n_emb] 53 | sampler: you may use a LogUniformSampler 54 | Return 55 | logits: [b1, b2, 1 + n_sample] 56 | """ 57 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 58 | n_sample = neg_samples.size(0) 59 | b1, b2 = labels.size(0), labels.size(1) 60 | all_ids = torch.cat([labels.view(-1), neg_samples]) 61 | all_w = embedding(all_ids) 62 | true_w = all_w[: -n_sample].view(b1, b2, -1) 63 | sample_w = all_w[- n_sample:].view(n_sample, -1) 64 | 65 | all_b = bias[all_ids] 66 | true_b = all_b[: -n_sample].view(b1, b2) 67 | sample_b = all_b[- n_sample:] 68 | 69 | hit = (labels[:, :, None] == neg_samples).detach() 70 | 71 | true_logits = torch.einsum('ijk,ijk->ij', 72 | [true_w, inputs]) + true_b - true_log_probs 73 | sample_logits = torch.einsum('lk,ijk->ijl', 74 | [sample_w, inputs]) + sample_b - samp_log_probs 75 | sample_logits.masked_fill_(hit, -1e30) 76 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 77 | 78 | return logits 79 | 80 | 81 | # class LogUniformSampler(object): 82 | # def __init__(self, range_max, unique=False): 83 | # """ 84 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 85 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 86 | # """ 87 | # self.range_max = range_max 88 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 89 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 90 | 91 | # self.unique = unique 92 | 93 | # if self.unique: 94 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 95 | 96 | # def sample(self, n_sample, labels): 97 | # pos_sample, new_labels = labels.unique(return_inverse=True) 98 | # n_pos_sample = pos_sample.size(0) 99 | # n_neg_sample = n_sample - n_pos_sample 100 | 101 | # if self.unique: 102 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 103 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 104 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 105 | # else: 106 | # sample_dist = self.dist 107 | 108 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 109 | 110 | # sample = torch.cat([pos_sample, neg_sample]) 111 | # sample_prob = self.dist[sample] 112 | 113 | # return new_labels, sample, sample_prob 114 | 115 | 116 | if __name__ == '__main__': 117 | S, B = 3, 4 118 | n_vocab = 10000 119 | n_sample = 5 120 | H = 32 121 | 122 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 123 | 124 | # sampler = LogUniformSampler(n_vocab, unique=False) 125 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 126 | 127 | sampler = LogUniformSampler(n_vocab, unique=True) 128 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 129 | 130 | # print('true_probs', true_probs.numpy().tolist()) 131 | # print('samp_probs', samp_probs.numpy().tolist()) 132 | # print('neg_samples', neg_samples.numpy().tolist()) 133 | 134 | # print('sum', torch.sum(sampler.dist).item()) 135 | 136 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() 137 | 138 | embedding = nn.Embedding(n_vocab, H) 139 | bias = torch.zeros(n_vocab) 140 | inputs = torch.Tensor(S, B, H).normal_() 141 | 142 | logits, out_labels = sample_logits(embedding, 
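# Note on this __main__ smoke test: as defined earlier in this file,
# LogUniformSampler takes (range_max, n_sample) rather than a `unique`
# flag, and sample_logits(embedding, bias, labels, inputs, sampler)
# returns a single logits tensor of shape [b1, b2, 1 + n_sample]
# (column 0 is the true class, the remaining columns are sampled negatives).
# A call consistent with those signatures would look roughly like:
#
#     sampler = LogUniformSampler(n_vocab, n_sample)
#     logits = sample_logits(embedding, bias, labels, inputs, sampler)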
bias, labels, inputs, sampler, n_sample) 143 | print('logits', logits.detach().numpy().tolist()) 144 | print('logits shape', logits.size()) 145 | print('out_labels', out_labels.detach().numpy().tolist()) 146 | print('out_labels shape', out_labels.size()) 147 | 148 | -------------------------------------------------------------------------------- /code_for_LM/utils/proj_adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 10 | CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 11 | 12 | class ProjectedAdaptiveLogSoftmax(nn.Module): 13 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 14 | keep_order=False): 15 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 16 | 17 | self.n_token = n_token 18 | self.d_embed = d_embed 19 | self.d_proj = d_proj 20 | 21 | self.cutoffs = cutoffs + [n_token] 22 | self.cutoff_ends = [0] + self.cutoffs 23 | self.div_val = div_val 24 | 25 | self.shortlist_size = self.cutoffs[0] 26 | self.n_clusters = len(self.cutoffs) - 1 27 | self.head_size = self.shortlist_size + self.n_clusters 28 | 29 | if self.n_clusters > 0: 30 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 31 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 32 | 33 | self.out_layers = nn.ModuleList() 34 | self.out_projs = nn.ParameterList() 35 | 36 | if div_val == 1: 37 | for i in range(len(self.cutoffs)): 38 | if d_proj != d_embed: 39 | self.out_projs.append( 40 | nn.Parameter(torch.Tensor(d_proj, d_embed)) 41 | ) 42 | else: 43 | self.out_projs.append(None) 44 | 45 | self.out_layers.append(nn.Linear(d_embed, n_token)) 46 | else: 47 | for i in range(len(self.cutoffs)): 48 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 49 | d_emb_i = d_embed // (div_val ** i) 50 | 51 | self.out_projs.append( 52 | nn.Parameter(torch.Tensor(d_proj, d_emb_i)) 53 | ) 54 | 55 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 56 | 57 | self.keep_order = keep_order 58 | 59 | def _compute_logit(self, hidden, weight, bias, proj): 60 | if proj is None: 61 | logit = F.linear(hidden, weight, bias=bias) 62 | else: 63 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 64 | proj_hid = F.linear(hidden, proj.t().contiguous()) 65 | logit = F.linear(proj_hid, weight, bias=bias) 66 | # else: 67 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 68 | # if bias is not None: 69 | # logit = logit + bias 70 | 71 | return logit 72 | 73 | def forward(self, hidden, target, keep_order=False): 74 | ''' 75 | hidden :: [len*bsz x d_proj] 76 | target :: [len*bsz] 77 | ''' 78 | 79 | if hidden.size(0) != target.size(0): 80 | raise RuntimeError('Input and target should have the same size ' 81 | 'in the batch dimension.') 82 | 83 | if self.n_clusters == 0: 84 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 85 | self.out_layers[0].bias, self.out_projs[0]) 86 | nll = -F.log_softmax(logit, dim=-1) \ 87 | .gather(1, target.unsqueeze(1)).squeeze(1) 88 | else: 89 | # construct weights and biases 90 | weights, biases = [], [] 91 | for i in range(len(self.cutoffs)): 92 | if self.div_val == 1: 93 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 94 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 95 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 96 | else: 97 | weight_i = 
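# Worked illustration (hypothetical numbers, not repository code) of the
# bookkeeping set up in __init__ above: with n_token = 100000 and
# cutoffs = [20000, 60000],
#
#     self.cutoffs        -> [20000, 60000, 100000]
#     self.cutoff_ends    -> [0, 20000, 60000, 100000]
#     self.shortlist_size -> 20000    # ids [0, 20000) are scored by the head
#     self.n_clusters     -> 2        # ids [20000, 60000) and [60000, 100000)
#     self.head_size      -> 20002    # shortlist + one logit per tail cluster
#
# The loop here collects the weight/bias slice for cluster i; for i == 0 the
# cluster_weight/cluster_bias rows are appended so the head softmax can also
# route probability mass to the two tail clusters.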
self.out_layers[i].weight 98 | bias_i = self.out_layers[i].bias 99 | 100 | if i == 0: 101 | weight_i = torch.cat( 102 | [weight_i, self.cluster_weight], dim=0) 103 | bias_i = torch.cat( 104 | [bias_i, self.cluster_bias], dim=0) 105 | 106 | weights.append(weight_i) 107 | biases.append(bias_i) 108 | 109 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 110 | 111 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 112 | head_logprob = F.log_softmax(head_logit, dim=1) 113 | 114 | nll = torch.zeros_like(target, 115 | dtype=hidden.dtype, device=hidden.device) 116 | 117 | offset = 0 118 | cutoff_values = [0] + self.cutoffs 119 | for i in range(len(cutoff_values) - 1): 120 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 121 | 122 | mask_i = (target >= l_idx) & (target < r_idx) 123 | indices_i = mask_i.nonzero().squeeze() 124 | 125 | if indices_i.numel() == 0: 126 | continue 127 | 128 | target_i = target.index_select(0, indices_i) - l_idx 129 | head_logprob_i = head_logprob.index_select(0, indices_i) 130 | 131 | if i == 0: 132 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 133 | else: 134 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 135 | 136 | hidden_i = hidden.index_select(0, indices_i) 137 | 138 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 139 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 140 | 141 | logprob_i = head_logprob_i[:, -i] \ 142 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 143 | 144 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 145 | nll.index_copy_(0, indices_i, -logprob_i) 146 | else: 147 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 148 | 149 | offset += logprob_i.size(0) 150 | 151 | return nll 152 | -------------------------------------------------------------------------------- /code_for_LM/utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter, OrderedDict 3 | 4 | import torch 5 | 6 | 7 | 8 | # 这是词典类 9 | class Vocab(object): 10 | def __init__(self, alinlen=3000,special=[], min_freq=10, max_size=100000, lower_case=True, 11 | delimiter=None, vocab_file=None): 12 | ''' 13 | code 14 | :param special: 15 | :param min_freq: 16 | :param max_size: 17 | :param lower_case: 18 | :param delimiter: 19 | :param vocab_file: 20 | ''' 21 | self.counter = Counter() 22 | self.special = special 23 | self.min_freq = min_freq 24 | self.max_size = max_size 25 | self.lower_case = lower_case 26 | self.delimiter = delimiter 27 | self.vocab_file = vocab_file 28 | self.alinlen = alinlen 29 | 30 | 31 | # 许海明 32 | def tokenize(self, line): 33 | line = line.strip() 34 | # convert to lower case 35 | if self.lower_case: 36 | line = line.lower() 37 | 38 | # empty delimiter '' will evaluate False 39 | if self.delimiter == '': 40 | symbols = line 41 | else: 42 | symbols = line.split(self.delimiter) 43 | 44 | 45 | if len(symbols) > self.alinlen-2: 46 | symbols = symbols[:self.alinlen-2] 47 | 48 | symbols = [''] + symbols + [''] 49 | 50 | len_pre = len(symbols) 51 | if len_pre == self.alinlen: 52 | return symbols 53 | else: 54 | assert len_pre']*(self.alinlen-len_pre) 56 | new_symbols.extend(symbols) 57 | assert len(new_symbols)==self.alinlen 58 | return new_symbols 59 | 60 | 61 | 62 | ''' 63 | 统计单词 64 | 并返回分好词的数据 每一行元 65 | ''' 66 | def count_file(self, path, verbose=False): 67 | if verbose: 68 | print('counting file {} 
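# Sketch (with hypothetical special-token names <S>, <E> and <pad> standing in
# for the project's actual special symbols) of the alignment that tokenize()
# above performs on every document: truncate to alinlen - 2 tokens, wrap with
# begin/end markers, then left-pad to exactly alinlen so every sample has the
# same length (3000 by default):
#
#     if len(symbols) > self.alinlen - 2:
#         symbols = symbols[:self.alinlen - 2]
#     symbols = ['<S>'] + symbols + ['<E>']
#     len_pre = len(symbols)
#     if len_pre == self.alinlen:
#         return symbols
#     assert len_pre < self.alinlen
#     new_symbols = ['<pad>'] * (self.alinlen - len_pre)   # pad on the left
#     new_symbols.extend(symbols)
#     assert len(new_symbols) == self.alinlen
#     return new_symbols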
...'.format(path)) 69 | assert os.path.exists(path) 70 | 71 | sents = [] 72 | with open(path, 'r', encoding='utf-8') as f: 73 | for idx, line in enumerate(f): 74 | line = line.strip() 75 | if line is None or line=="": 76 | continue 77 | if verbose and idx > 0 and idx % 500000 == 0: 78 | print(' line {}'.format(idx)) 79 | symbols = self.tokenize(line) 80 | self.counter.update(symbols) 81 | sents.append(symbols) 82 | 83 | return sents 84 | 85 | def count_sents(self, sents, verbose=False): 86 | """ 87 | sents : a list of sentences, each a list of tokenized symbols 88 | """ 89 | if verbose: print('counting {} sents ...'.format(len(sents))) 90 | for idx, symbols in enumerate(sents): 91 | if verbose and idx > 0 and idx % 500000 == 0: 92 | print(' line {}'.format(idx)) 93 | self.counter.update(symbols) 94 | 95 | # 许海明 96 | def _build_from_file(self, vocab_file): 97 | self.idx2sym = [] 98 | self.sym2idx = OrderedDict() 99 | 100 | with open(vocab_file, 'r', encoding='utf-8') as f: 101 | for line in f: 102 | symb = line.strip().split()[0] 103 | self.add_symbol(symb) 104 | self.unk_idx = self.sym2idx[''] 105 | 106 | 107 | # 许海明 108 | def build_vocab(self): 109 | if self.vocab_file: 110 | print('building vocab from {}'.format(self.vocab_file)) 111 | self._build_from_file(self.vocab_file) 112 | print('final vocab size {}'.format(len(self))) 113 | else: 114 | print('building vocab with min_freq={}, max_size={}'.format( 115 | self.min_freq, self.max_size)) 116 | self.idx2sym = [] 117 | self.sym2idx = OrderedDict() 118 | 119 | for sym in self.special: 120 | self.add_special(sym) 121 | 122 | for sym, cnt in self.counter.most_common(self.max_size): 123 | if cnt < self.min_freq: 124 | break 125 | self.add_symbol(sym) 126 | 127 | print('final vocab size {} from {} unique tokens'.format( 128 | len(self), len(self.counter))) 129 | 130 | 131 | # 许海明 对一个文件进行编码 132 | ''' 133 | tensor([1, 2, 3, 1, 2, 3, 1, 2, 3]) 134 | 135 | a=torch.LongTensor([1,2,3]) 136 | a1=torch.LongTensor([1,2,3]) 137 | a2=torch.LongTensor([1,2,3]) 138 | 139 | print(torch.cat([a,a1,a2])) 140 | ''' 141 | def encode_file(self, path, ordered=False, verbose=False): 142 | ''' 143 | ordered=True 144 | :param path: 145 | :param ordered: 146 | :param verbose: 147 | :param add_eos: 148 | :param add_double_eos: 149 | :return: 150 | ''' 151 | if verbose: 152 | print('encoding file {} ...'.format(path)) 153 | assert os.path.exists(path) 154 | encoded = [] 155 | with open(path, 'r', encoding='utf-8') as f: 156 | for idx, line in enumerate(f): 157 | line = line.strip() 158 | if line is None or line == "": 159 | continue 160 | if verbose and idx > 0 and idx % 500000 == 0: 161 | print(' line {}'.format(idx)) 162 | symbols = self.tokenize(line) # code 每一行加上结束 163 | 164 | encoded.append(self.convert_to_tensor(symbols)) 165 | 166 | if ordered: 167 | encoded = torch.cat(encoded) 168 | 169 | return encoded 170 | 171 | # 许海明 这是对多个句子进行编码 172 | def encode_sents(self, sents, ordered=False, verbose=False): 173 | if verbose: 174 | print('encoding {} sents ...'.format(len(sents))) 175 | encoded = [] 176 | for idx, symbols in enumerate(sents): 177 | if verbose and idx > 0 and idx % 500000 == 0: 178 | print(' line {}'.format(idx)) 179 | encoded.append(self.convert_to_tensor(symbols)) 180 | 181 | if ordered: 182 | encoded = torch.cat(encoded) 183 | 184 | return encoded 185 | 186 | # 许海明 187 | def add_special(self, sym): 188 | if sym not in self.sym2idx: 189 | self.idx2sym.append(sym) 190 | self.sym2idx[sym] = len(self.idx2sym) - 1 191 | setattr(self, 
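# add_special() registers a symbol and also exposes its id as an attribute:
# adding a hypothetical '<unk>' token would set self.unk_idx (the attribute
# name is the symbol with the angle brackets stripped). Typical use of Vocab
# for the pretraining corpus, sketched with assumed paths and token names:
#
#     vocab = Vocab(alinlen=3000, special=['<pad>', '<unk>'], min_freq=10)
#     vocab.count_file('../data/LM/train.txt')         # fill the counter
#     vocab.build_vocab()                               # build sym2idx / idx2sym
#     vocab.save_symbol('../results/LM/vocab.txt')      # persist the vocabulary
#     data = vocab.encode_file('../data/LM/train.txt', ordered=True)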
'{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 192 | 193 | # 许海明 194 | def add_symbol(self, sym): 195 | if sym not in self.sym2idx: 196 | self.idx2sym.append(sym) 197 | self.sym2idx[sym] = len(self.idx2sym) - 1 198 | 199 | 200 | def save_symbol(self, vocabfile): 201 | with open(vocabfile, mode="w", encoding="utf-8") as fw: 202 | for word in self.sym2idx: 203 | fw.write(word+"\n") 204 | 205 | 206 | 207 | 208 | 209 | def get_sym(self, idx): 210 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) 211 | return self.idx2sym[idx] 212 | 213 | def get_idx(self, sym): 214 | if sym in self.sym2idx: 215 | return self.sym2idx[sym] 216 | else: 217 | # print('encounter unk {}'.format(sym)) 218 | assert '' not in sym 219 | assert hasattr(self, 'unk_idx') 220 | return self.sym2idx.get(sym, self.unk_idx) 221 | 222 | def get_symbols(self, indices): 223 | return [self.get_sym(idx) for idx in indices] 224 | 225 | def get_indices(self, symbols): 226 | return [self.get_idx(sym) for sym in symbols] 227 | 228 | # 许海明 229 | def convert_to_tensor(self, symbols): 230 | return torch.LongTensor(self.get_indices(symbols)) 231 | 232 | def convert_to_sent(self, indices, exclude=None): 233 | if exclude is None: 234 | return ' '.join([self.get_sym(idx) for idx in indices]) 235 | else: 236 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 237 | 238 | def __len__(self): 239 | return len(self.idx2sym) 240 | -------------------------------------------------------------------------------- /code_for_LM/运行命令: -------------------------------------------------------------------------------- 1 | # 这是训练的目标 2 | 3 | python train.py \ 4 | --cuda \ 5 | --data ../data/LM/ \ 6 | --work_dir ../results/LM/ \ 7 | --dataset dccl \ 8 | --n_layer 6 \ 9 | --d_model 410 \ 10 | --n_head 10 \ 11 | --d_head 41 \ 12 | --tgt_len 150 \ 13 | --mem_len 150 \ 14 | --eval_tgt_len 150 \ 15 | --d_inner 2100 \ 16 | --dropout 0.1 \ 17 | --dropatt 0.0 \ 18 | --optim adam \ 19 | --lr 0.00025 \ 20 | --warmup_step 0 \ 21 | --max_step 200000 \ 22 | --batch_size 32 \ 23 | --multi_gpu \ 24 | 25 | -------------------------------------------------------------------------------- /image/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuhaiming1996/Transformer-xl-classify/7696d0b328aac4665d851aac5f84e49000919aae/image/model.jpg -------------------------------------------------------------------------------- /说明: -------------------------------------------------------------------------------- 1 | 这是一个使用多卡机训练的例子 是一个非常值得学习的代码 --------------------------------------------------------------------------------
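The checkpoints written by save_checkpoint() in utils/exp_utils.py store the whole model object and the optimizer state_dict as two separate files, so resuming roughly looks like the sketch below (the epoch number and directory are placeholders, and the script must be run where mem_transformer.py is importable because the full module object was pickled):

import os
import torch

work_dir, epoch = '../results/LM/', 10   # adjust to the actual experiment directory / epoch
model = torch.load(os.path.join(work_dir, 'model_{}.pt'.format(epoch)))
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)   # optim/lr as in 运行命令
optimizer.load_state_dict(
    torch.load(os.path.join(work_dir, 'optimizer_{}.pt'.format(epoch))))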