├── .idea
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   ├── transformer-xl.iml
│   ├── vcs.xml
│   ├── webServers.xml
│   └── workspace.xml
├── README.md
├── code_for_Classfy
│   ├── data_utils.py
│   ├── eval.py
│   ├── mem_transformer.py
│   ├── myattention.py
│   ├── test.py
│   ├── train.py
│   ├── utils
│   │   ├── adaptive_softmax.py
│   │   ├── data_parallel.py
│   │   ├── exp_utils.py
│   │   ├── log_uniform_sampler.py
│   │   ├── proj_adaptive_softmax.py
│   │   └── vocabulary.py
│   └── 运行命令
├── code_for_LM
│   ├── .DS_Store
│   ├── data_utils.py
│   ├── eval.py
│   ├── mem_transformer.py
│   ├── train.py
│   ├── utils
│   │   ├── adaptive_softmax.py
│   │   ├── data_parallel.py
│   │   ├── exp_utils.py
│   │   ├── log_uniform_sampler.py
│   │   ├── proj_adaptive_softmax.py
│   │   └── vocabulary.py
│   └── 运行命令
├── image
│   └── model.jpg
└── results
    ├── LM
    │   └── vocab.txt
    └── 说明
/README.md:
--------------------------------------------------------------------------------
## Introduction
Function: a long-text classifier built on Transformer-XL. Classification is done at the document level; each document is 3000 tokens long (already word-segmented).
Training scheme: pre-training followed by fine-tuning.
Framework: PyTorch 1.0

## Classifier architecture
![model](image/model.jpg)

## File description
### data
#### /data/LM
train.txt holds the word-segmented corpus, one document per line.
#### /data/CLASSIFY
train.label: one document label per line
train.txt: one word-segmented document per line
valid.label: one document label per line
valid.txt: one word-segmented document per line
A sketch of this layout is shown below.

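A minimal sketch of the expected data layout, using placeholder tokens and labels purely for illustration; the exact label format is whatever `encode_file_only_for_lables` in `utils/vocabulary.py` expects.

```
/data/CLASSIFY/train.txt      (one word-segmented document per line; tokens are placeholders)
词1 词2 词3 ... 词3000
词1 词2 词3 ... 词3000

/data/CLASSIFY/train.label    (label of the document on the same line of train.txt)
<label of document 1>
<label of document 2>
```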
### code_for_LM
#### /code_for_LM/运行命令
This file holds the command I used for pre-training; adjust it to your own setup.
Tip: use 4-GPU data-parallel training if you can.


### code_for_Classfy
#### /code_for_Classfy/运行命令
This file holds the fine-tuning command; adjust it to your own setup. A hypothetical invocation is sketched below.
Tip: use 4-GPU data-parallel training if you can.

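A hypothetical fine-tuning invocation, assembled only from arguments that `code_for_Classfy/train.py` defines; all paths and hyperparameter values are placeholders loosely based on the defaults in `train.py`, and the actual command kept in 运行命令 may differ.

```bash
python train.py \
    --cuda --multi_gpu \
    --data ../data/CLASSIFY \
    --vocab_file ../results/LM/vocab.txt \
    --init_model_LM ../results/LM/model.pt \
    --work_dir ../results/CLASSIFY \
    --n_layer 6 --d_model 410 --n_head 10 --d_head 41 --d_inner 1000 \
    --tgt_len 150 --mem_len 150 \
    --batch_size 60 --lr 0.00025 --max_step 100000
```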
### results
#### /results/LM
Save directory for pre-training: model parameters, the vocabulary, and so on.
#### /results/CLASSIFY
Save directory for classifier training: model parameters, etc.
--------------------------------------------------------------------------------
/code_for_Classfy/data_utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import glob
3 | from collections import Counter, OrderedDict
4 | import numpy as np
5 | import torch
6 |
7 | from utils.vocabulary import Vocab
8 |
9 |
10 | class BatchIteratorHelper:
11 | def __init__(self, data, lables, bsz, alianlen=3000, device='cpu'):
12 | '''
13 |
14 | :param data: [num_samples, 3000]
15 | :param lables: [num_samples]
16 | :param bsz: batch size
17 | :param alianlen: aligned document length (3000)
18 | :param device:
19 | '''
20 |
21 | self.bsz = bsz
22 | self.num_example = data.size(0) # number of samples
23 | self.device = device
24 | self.data = data.contiguous().to(device)
25 | # Number of mini-batches
26 | self.n_batch = (self.num_example + self.bsz - 1) // self.bsz
27 | self.lables = lables.contiguous().to(device)
28 |
29 | print(self.data.size())
30 | # sanity checks
31 | assert self.data.size(1) == alianlen
32 | assert self.data.size(0) == self.lables.size(0)
33 |
34 |
35 | def get_batch(self, i, bsz=None):
36 | if bsz is None:
37 | bsz = self.bsz
38 | bsz_len = min(bsz, self.data.size(0) - i) # no "-1" here: subtracting 1 would silently drop the last example
39 | end_idx = i + bsz_len
40 | beg_idx = i
41 | data = self.data[beg_idx:end_idx]
42 |
43 | labels = self.lables[beg_idx:end_idx]
44 |
45 | return data, labels, bsz_len
46 |
47 | def get_fixlen_iter(self, start=0):
48 | for i in range(start, self.data.size(0), self.bsz): # cover every example, no "-1"
49 | yield self.get_batch(i)
50 |
51 |
52 | def __iter__(self):
53 | return self.get_fixlen_iter()
54 |
55 |
56 |
57 |
58 | class LMOrderedIterator(object):
59 | def __init__(self, data, bsz, bptt, alianlen=3000, device='cpu', ext_len=None):
60 | '''
61 | :param data:
62 | :param bsz: batch_size
63 | :param bptt: tgt_len
64 | :param device:
65 | :param ext_len: 0
66 | '''
67 | """
68 | data -- LongTensor -- the LongTensor is strictly ordered
69 | """
70 | self.bsz = bsz
71 | self.bptt = bptt
72 | self.ext_len = ext_len if ext_len is not None else 0
73 |
74 | self.device = device
75 |
76 | # Work out how cleanly we can divide the dataset into bsz parts.
77 | # self.n_step = data.size(0) // bsz
78 |
79 | self.n_step = alianlen
80 |
81 |
82 | # # Trim off any extra elements that wouldn't cleanly fit (remainders).
83 | # data = data.narrow(0, 0, self.n_step * bsz)
84 |
85 | # Evenly divide the data across the bsz batches.
86 | self.data = data.t().contiguous().to(device)
87 | # print(self.data.size())
88 | # print(self.bsz)
89 | # print(self.n_step)
90 | # self.data = data.view(bsz, -1).t().contiguous().to(device)
91 | # Number of mini-batches
92 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt # the iterator advances bptt tokens per batch, not one token at a time
93 |
94 | # sanity checks
95 |
96 | assert self.data.size(0)== self.n_step
97 | assert self.data.size(1) == self.bsz
98 |
99 |
100 | def get_batch(self, i, bptt=None):
101 | if bptt is None:
102 | bptt = self.bptt
103 | seq_len = min(bptt, self.data.size(0) - 1 - i)
104 | end_idx = i + seq_len
105 | beg_idx = max(0, i - self.ext_len)
106 | data = self.data[beg_idx:end_idx]
107 |
108 | return data, seq_len
109 |
110 | def get_fixlen_iter(self, start=0):
111 | for i in range(start, self.data.size(0) - 1, self.bptt):
112 | yield self.get_batch(i)
113 |
114 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
115 | max_len = self.bptt + max_deviation * std
116 | i = start
117 | while True:
118 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
119 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
120 | data, seq_len = self.get_batch(i, bptt)
121 | i += seq_len
122 | yield data, seq_len
123 | if i >= self.data.size(0) - 2:
124 | break
125 |
126 | def __iter__(self):
127 | return self.get_fixlen_iter()
128 |
129 |
130 |
131 |
132 | # 许海明
133 | class Corpus(object):
134 | def __init__(self, path, *args, **kwargs):
135 | self.vocab = Vocab(*args, **kwargs)
136 |
137 | # load the vocabulary from the vocab file
138 | self.vocab.build_vocab()
139 |
140 | # training set
141 | self.train = self.vocab.encode_file(os.path.join(path, 'train.txt'), verbose=True)
142 | self.train_label = self.vocab.encode_file_only_for_lables(os.path.join(path, 'train.label'), verbose=True)
143 |
144 | # validation set
145 | self.valid = self.vocab.encode_file(os.path.join(path, 'valid.txt'), verbose=True)
146 | self.valid_label = self.vocab.encode_file_only_for_lables(os.path.join(path, 'valid.label'), verbose=True)
147 |
148 |
149 | # self.test = self.vocab.encode_file(
150 | # os.path.join(path, 'test.txt'), ordered=True)
151 |
152 | # 许海明
153 | def get_batch_iterator(self, split, *args, **kwargs):
154 | '''
155 |
156 | :param split:
157 | :param args:
158 | :param kwargs:
159 | :return:
160 | '''
161 | if split == 'train':
162 | # data_iter = LMOrderedIterator(self.train, *args, **kwargs)
163 | batch_iter = BatchIteratorHelper(self.train,self.train_label, *args, **kwargs)
164 |
165 | elif split == 'valid':
166 | batch_iter = BatchIteratorHelper(self.valid, self.valid_label, *args, **kwargs)
167 |
168 | return batch_iter
169 |
170 |
171 | def get_lm_corpus(datadir,vocab_file,alinlen):
172 | fn = os.path.join(datadir, 'cache.pt')
173 | if os.path.exists(fn):
174 | print('Loading cached dataset...')
175 | corpus = torch.load(fn)
176 | else:
177 | print('Producing dataset {}...'.format(datadir))
178 | kwargs = {}
179 |
180 | kwargs['special'] = ['','', '', '']
181 | kwargs['lower_case'] = False
182 | kwargs['vocab_file'] = vocab_file
183 |
184 |
185 |
186 | corpus = Corpus(datadir,alinlen ,**kwargs)
187 | torch.save(corpus, fn) # what is saved here is the Corpus object itself
188 |
189 | return corpus
190 |
191 | if __name__ == '__main__':
192 | import argparse
193 | parser = argparse.ArgumentParser(description='unit test')
194 | parser.add_argument('--datadir', type=str, default='../data/CLASSIFY',
195 | help='location of the data corpus (placeholder default)')
196 | parser.add_argument('--vocab_file', type=str, default='../results/LM/vocab.txt',
197 | help='path to the vocabulary file (placeholder default)')
198 | parser.add_argument('--alinlen', type=int, default=3000, help='aligned document length')
199 | args = parser.parse_args()
200 |
201 | corpus = get_lm_corpus(args.datadir, args.vocab_file, args.alinlen)
202 | print('Vocab size : {}'.format(len(corpus.vocab.idx2sym)))
203 | tr_iter = corpus.get_batch_iterator('train', 22, alianlen=args.alinlen)
210 |
--------------------------------------------------------------------------------
/code_for_Classfy/eval.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import argparse
3 | import time
4 | import math
5 | import os, sys
6 |
7 | import torch
8 |
9 | from data_utils import get_lm_corpus
10 | from mem_transformer import MemTransformerLM
11 | from utils.exp_utils import get_logger
12 |
13 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
14 | parser.add_argument('--data', type=str, default='../data/wikitext-103',
15 | help='location of the data corpus')
16 | parser.add_argument('--dataset', type=str, default='wt103',
17 | choices=['wt103', 'lm1b', 'enwik8', 'text8'],
18 | help='dataset name')
19 | parser.add_argument('--split', type=str, default='all',
20 | choices=['all', 'valid', 'test'],
21 | help='which split to evaluate')
22 | parser.add_argument('--batch_size', type=int, default=10,
23 | help='batch size')
24 | parser.add_argument('--tgt_len', type=int, default=5,
25 | help='number of tokens to predict')
26 | parser.add_argument('--ext_len', type=int, default=0,
27 | help='length of the extended context')
28 | parser.add_argument('--mem_len', type=int, default=0,
29 | help='length of the retained previous heads')
30 | parser.add_argument('--clamp_len', type=int, default=-1,
31 | help='max positional embedding index')
32 | parser.add_argument('--cuda', action='store_true',
33 | help='use CUDA')
34 | parser.add_argument('--work_dir', type=str, required=True,
35 | help='path to the work_dir')
36 | parser.add_argument('--no_log', action='store_true',
37 | help='do not log the eval result')
38 | parser.add_argument('--same_length', action='store_true',
39 | help='set same length attention with masking')
40 | args = parser.parse_args()
41 | assert args.ext_len >= 0, 'extended context length must be non-negative'
42 |
43 | device = torch.device("cuda" if args.cuda else "cpu")
44 |
45 | # Get logger
46 | logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
47 | log_=not args.no_log)
48 |
49 | # Load dataset
50 | corpus = get_lm_corpus(args.data, args.dataset)
51 | ntokens = len(corpus.vocab)
52 |
53 | va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
54 | device=device, ext_len=args.ext_len)
55 | te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
56 | device=device, ext_len=args.ext_len)
57 |
58 | # Load the best saved model.
59 | with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
60 | model = torch.load(f)
61 | model.backward_compatible()
62 | model = model.to(device)
63 |
64 | logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
65 | args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
66 |
67 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
68 | if args.clamp_len > 0:
69 | model.clamp_len = args.clamp_len
70 | if args.same_length:
71 | model.same_length = True
72 |
73 | ###############################################################################
74 | # Evaluation code
75 | ###############################################################################
76 | def evaluate(eval_iter):
77 | # Turn on evaluation mode which disables dropout.
78 | model.eval()
79 | total_len, total_loss = 0, 0.
80 | start_time = time.time()
81 | with torch.no_grad():
82 | mems = tuple()
83 | for idx, (data, target, seq_len) in enumerate(eval_iter):
84 | ret = model(data, target, *mems)
85 | loss, mems = ret[0], ret[1:]
86 | loss = loss.mean()
87 | total_loss += seq_len * loss.item()
88 | total_len += seq_len
89 | total_time = time.time() - start_time
90 | logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
91 | total_time, 1000 * total_time / (idx+1)))
92 | return total_loss / total_len
93 |
94 | # Run on test data.
95 | if args.split == 'all':
96 | test_loss = evaluate(te_iter)
97 | valid_loss = evaluate(va_iter)
98 | elif args.split == 'valid':
99 | valid_loss = evaluate(va_iter)
100 | test_loss = None
101 | elif args.split == 'test':
102 | test_loss = evaluate(te_iter)
103 | valid_loss = None
104 |
105 | def format_log(loss, split):
106 | if args.dataset in ['enwik8', 'text8']:
107 | log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
108 | split, loss, loss / math.log(2))
109 | else:
110 | log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
111 | split, loss, math.exp(loss))
112 | return log_str
113 |
114 | log_str = ''
115 | if valid_loss is not None:
116 | log_str += format_log(valid_loss, 'valid')
117 | if test_loss is not None:
118 | log_str += format_log(test_loss, 'test')
119 |
120 | logging('=' * 100)
121 | logging(log_str)
122 | logging('=' * 100)
123 |
--------------------------------------------------------------------------------
/code_for_Classfy/mem_transformer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import functools
4 |
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from data_utils import LMOrderedIterator
12 |
13 | sys.path.append('utils')
14 | from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
15 | from log_uniform_sampler import LogUniformSampler, sample_logits
16 |
17 | class PositionalEmbedding(nn.Module):
18 | def __init__(self, demb):
19 | super(PositionalEmbedding, self).__init__()
20 |
21 | self.demb = demb
22 |
23 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
24 | self.register_buffer('inv_freq', inv_freq)
25 |
26 | def forward(self, pos_seq, bsz=None):
27 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
28 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
29 |
30 | if bsz is not None:
31 | return pos_emb[:,None,:].expand(-1, bsz, -1)
32 | else:
33 | return pos_emb[:,None,:]
34 |
35 |
36 | class PositionwiseFF(nn.Module):
37 | def __init__(self, d_model, d_inner, dropout, pre_lnorm = False):
38 | super(PositionwiseFF, self).__init__()
39 |
40 | self.d_model = d_model
41 | self.d_inner = d_inner
42 | self.dropout = dropout
43 |
44 | self.CoreNet = nn.Sequential(
45 | nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
46 | nn.Dropout(dropout),
47 | nn.Linear(d_inner, d_model),
48 | nn.Dropout(dropout),
49 | )
50 |
51 | self.layer_norm = nn.LayerNorm(d_model)
52 |
53 | self.pre_lnorm = pre_lnorm
54 |
55 | def forward(self, inp):
56 | if self.pre_lnorm:
57 | ##### layer normalization + positionwise feed-forward
58 | core_out = self.CoreNet(self.layer_norm(inp))
59 |
60 | ##### residual connection
61 | output = core_out + inp
62 | else:
63 | ##### positionwise feed-forward
64 | core_out = self.CoreNet(inp)
65 |
66 | ##### residual connection + layer normalization
67 | output = self.layer_norm(inp + core_out)
68 |
69 | return output
70 |
71 | class MultiHeadAttn(nn.Module):
72 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
73 | pre_lnorm=False):
74 | super(MultiHeadAttn, self).__init__()
75 |
76 | self.n_head = n_head
77 | self.d_model = d_model
78 | self.d_head = d_head
79 | self.dropout = dropout
80 |
81 | self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
82 | self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
83 |
84 | self.drop = nn.Dropout(dropout)
85 | self.dropatt = nn.Dropout(dropatt)
86 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
87 |
88 | self.layer_norm = nn.LayerNorm(d_model)
89 |
90 | self.scale = 1 / (d_head ** 0.5)
91 |
92 | self.pre_lnorm = pre_lnorm
93 |
94 | def forward(self, h, attn_mask=None, mems=None):
95 | ##### multihead attention
96 | # [hlen x bsz x n_head x d_head]
97 |
98 | if mems is not None:
99 | c = torch.cat([mems, h], 0)
100 | else:
101 | c = h
102 |
103 | if self.pre_lnorm:
104 | ##### layer normalization
105 | c = self.layer_norm(c)
106 |
107 | head_q = self.q_net(h)
108 | head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
109 |
110 | head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
111 | head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
112 | head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
113 |
114 | # [qlen x klen x bsz x n_head]
115 | attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
116 | attn_score.mul_(self.scale)
117 | if attn_mask is not None and attn_mask.any().item():
118 | if attn_mask.dim() == 2:
119 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
120 | elif attn_mask.dim() == 3:
121 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
122 |
123 | # [qlen x klen x bsz x n_head]
124 | attn_prob = F.softmax(attn_score, dim=1)
125 | attn_prob = self.dropatt(attn_prob)
126 |
127 | # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
128 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
129 | attn_vec = attn_vec.contiguous().view(
130 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
131 |
132 | ##### linear projection
133 | attn_out = self.o_net(attn_vec)
134 | attn_out = self.drop(attn_out)
135 |
136 | if self.pre_lnorm:
137 | ##### residual connection
138 | output = h + attn_out
139 | else:
140 | ##### residual connection + layer normalization
141 | output = self.layer_norm(h + attn_out)
142 |
143 | return output
144 |
145 |
146 |
147 | class RelMultiHeadAttn(nn.Module):
148 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
149 | tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
150 | super(RelMultiHeadAttn, self).__init__()
151 |
152 | self.n_head = n_head
153 | self.d_model = d_model
154 | self.d_head = d_head
155 | self.dropout = dropout
156 |
157 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
158 |
159 | self.drop = nn.Dropout(dropout)
160 | self.dropatt = nn.Dropout(dropatt)
161 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
162 |
163 | self.layer_norm = nn.LayerNorm(d_model)
164 |
165 | self.scale = 1 / (d_head ** 0.5)
166 |
167 | self.pre_lnorm = pre_lnorm
168 |
169 | def _parallelogram_mask(self, h, w, left=False):
170 | mask = torch.ones((h, w)).byte()
171 | m = min(h, w)
172 | mask[:m,:m] = torch.triu(mask[:m,:m])
173 | mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
174 |
175 | if left:
176 | return mask
177 | else:
178 | return mask.flip(0)
179 |
180 | def _shift(self, x, qlen, klen, mask, left=False):
181 | if qlen > 1:
182 | zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
183 | device=x.device, dtype=x.dtype)
184 | else:
185 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
186 |
187 | if left:
188 | mask = mask.flip(1)
189 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
190 | else:
191 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
192 |
193 | x = x_padded.masked_select(mask[:,:,None,None]) \
194 | .view(qlen, klen, x.size(2), x.size(3))
195 |
196 | return x
197 |
198 | def _rel_shift(self, x, zero_triu=False):
199 | zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
200 | device=x.device, dtype=x.dtype)
201 | x_padded = torch.cat([zero_pad, x], dim=1)
202 |
203 | x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
204 |
205 | x = x_padded[1:].view_as(x)
206 |
207 | if zero_triu:
208 | ones = torch.ones((x.size(0), x.size(1)))
209 | x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
210 |
211 | return x
212 |
213 | def forward(self, w, r, attn_mask=None, mems=None):
214 | raise NotImplementedError
215 |
216 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
217 | def __init__(self, *args, **kwargs):
218 | super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
219 |
220 | self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
221 |
222 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
223 | qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
224 |
225 | if mems is not None:
226 | cat = torch.cat([mems, w], 0)
227 | if self.pre_lnorm:
228 | w_heads = self.qkv_net(self.layer_norm(cat))
229 | else:
230 | w_heads = self.qkv_net(cat)
231 | r_head_k = self.r_net(r)
232 |
233 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
234 | w_head_q = w_head_q[-qlen:]
235 | else:
236 | if self.pre_lnorm:
237 | w_heads = self.qkv_net(self.layer_norm(w))
238 | else:
239 | w_heads = self.qkv_net(w)
240 | r_head_k = self.r_net(r)
241 |
242 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
243 |
244 | klen = w_head_k.size(0)
245 |
246 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
247 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
248 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
249 |
250 | r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # rlen x n_head x d_head
251 |
252 | #### compute attention score
253 | rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
254 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
255 |
256 | rr_head_q = w_head_q + r_r_bias
257 | BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head
258 | BD = self._rel_shift(BD)
259 |
260 | # [qlen x klen x bsz x n_head]
261 | attn_score = AC + BD
262 | attn_score.mul_(self.scale)
263 |
264 | #### compute attention probability
265 | if attn_mask is not None and attn_mask.any().item():
266 | if attn_mask.dim() == 2:
267 | attn_score = attn_score.float().masked_fill(
268 | attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
269 | elif attn_mask.dim() == 3:
270 | attn_score = attn_score.float().masked_fill(
271 | attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)
272 |
273 | # [qlen x klen x bsz x n_head]
274 | attn_prob = F.softmax(attn_score, dim=1)
275 | attn_prob = self.dropatt(attn_prob)
276 |
277 | #### compute attention vector
278 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
279 |
280 | # [qlen x bsz x n_head x d_head]
281 | attn_vec = attn_vec.contiguous().view(
282 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
283 |
284 | ##### linear projection
285 | attn_out = self.o_net(attn_vec)
286 | attn_out = self.drop(attn_out)
287 |
288 | if self.pre_lnorm:
289 | ##### residual connection
290 | output = w + attn_out
291 | else:
292 | ##### residual connection + layer normalization
293 | output = self.layer_norm(w + attn_out)
294 |
295 | return output
296 |
297 | class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
298 | def __init__(self, *args, **kwargs):
299 | super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
300 |
301 | def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
302 | # r_emb: [klen, n_head, d_head], used for term B
303 | # r_w_bias: [n_head, d_head], used for term C
304 | # r_bias: [klen, n_head], used for term D
305 |
306 | qlen, bsz = w.size(0), w.size(1)
307 |
308 | if mems is not None:
309 | cat = torch.cat([mems, w], 0)
310 | if self.pre_lnorm:
311 | w_heads = self.qkv_net(self.layer_norm(cat))
312 | else:
313 | w_heads = self.qkv_net(cat)
314 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
315 |
316 | w_head_q = w_head_q[-qlen:]
317 | else:
318 | if self.pre_lnorm:
319 | w_heads = self.qkv_net(self.layer_norm(w))
320 | else:
321 | w_heads = self.qkv_net(w)
322 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
323 |
324 | klen = w_head_k.size(0)
325 |
326 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
327 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
328 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
329 |
330 | if klen > r_emb.size(0):
331 | r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
332 | r_emb = torch.cat([r_emb_pad, r_emb], 0)
333 | r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
334 | r_bias = torch.cat([r_bias_pad, r_bias], 0)
335 | else:
336 | r_emb = r_emb[-klen:]
337 | r_bias = r_bias[-klen:]
338 |
339 | #### compute attention score
340 | rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
341 |
342 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
343 | B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head
344 | D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head
345 | BD = self._rel_shift(B_ + D_)
346 |
347 | # [qlen x klen x bsz x n_head]
348 | attn_score = AC + BD
349 | attn_score.mul_(self.scale)
350 |
351 | #### compute attention probability
352 | if attn_mask is not None and attn_mask.any().item():
353 | if attn_mask.dim() == 2:
354 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
355 | elif attn_mask.dim() == 3:
356 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
357 |
358 | # [qlen x klen x bsz x n_head]
359 | attn_prob = F.softmax(attn_score, dim=1)
360 | attn_prob = self.dropatt(attn_prob)
361 |
362 | #### compute attention vector
363 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
364 |
365 | # [qlen x bsz x n_head x d_head]
366 | attn_vec = attn_vec.contiguous().view(
367 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
368 |
369 | ##### linear projection
370 | attn_out = self.o_net(attn_vec)
371 | attn_out = self.drop(attn_out)
372 |
373 | if self.pre_lnorm:
374 | ##### residual connection
375 | output = w + attn_out
376 | else:
377 | ##### residual connection + layer normalization
378 | output = self.layer_norm(w + attn_out)
379 |
380 | return output
381 |
382 |
383 |
384 |
385 | class DecoderLayer(nn.Module):
386 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
387 | super(DecoderLayer, self).__init__()
388 |
389 | self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
390 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
391 | pre_lnorm=kwargs.get('pre_lnorm'))
392 |
393 | def forward(self, dec_inp, dec_attn_mask=None, mems=None):
394 |
395 | output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
396 | mems=mems)
397 | output = self.pos_ff(output)
398 |
399 | return output
400 |
401 | class RelLearnableDecoderLayer(nn.Module):
402 | def __init__(self, n_head, d_model, d_head, d_inner, dropout,
403 | **kwargs):
404 | super(RelLearnableDecoderLayer, self).__init__()
405 |
406 | self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
407 | **kwargs)
408 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
409 | pre_lnorm=kwargs.get('pre_lnorm'))
410 |
411 | def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
412 |
413 | output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
414 | attn_mask=dec_attn_mask,
415 | mems=mems)
416 | output = self.pos_ff(output)
417 |
418 | return output
419 |
420 |
421 |
422 |
423 | class RelPartialLearnableDecoderLayer(nn.Module):
424 | def __init__(self, n_head, d_model, d_head, d_inner, dropout,
425 | **kwargs):
426 | super(RelPartialLearnableDecoderLayer, self).__init__()
427 |
428 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
429 | d_head, dropout, **kwargs)
430 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
431 | pre_lnorm=kwargs.get('pre_lnorm'))
432 |
433 | def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
434 |
435 | output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
436 | attn_mask=dec_attn_mask,
437 | mems=mems)
438 | output = self.pos_ff(output)
439 |
440 | return output
441 |
442 |
443 |
444 | class MyAttention(nn.Module):
445 | def __init__(self,hidden_size,attention_size):
446 | super(MyAttention, self).__init__()
447 | self.hidden_size=hidden_size
448 | self.attention_size = attention_size
449 |
450 | self.W_omega = nn.Parameter(torch.Tensor(self.attention_size,self.hidden_size))
451 | self.b_omega = nn.Parameter(torch.Tensor(self.attention_size))
452 | self.u_omega = nn.Parameter(torch.Tensor(1,self.attention_size))
453 |
454 | def forward(self, inputs, time_major=True, return_alphas=False):
455 | '''
456 |
457 | :param inputs: the shape of inputs must be [batch, len, hidden]
458 | :param time_major:
459 | :param return_alphas:
460 | :return:
461 | '''
462 | # v = torch.tanh(torch.matmul(torch.reshape(inputs, [-1, hidden_size]), W_omega) + tf.reshape(b_omega, [1, -1]))
463 |
464 | v = torch.tanh(F.linear(inputs, self.W_omega, self.b_omega)) #[B, L, attention_size]
465 | vu = F.linear(v, self.u_omega) #[B,L,1]
466 | vu = torch.squeeze(vu)
467 | exps = torch.exp(vu)
468 |
469 | alphas = exps /(torch.sum(input=exps,dim=1, keepdim=True)+0.00001)
470 |
471 | # Output of Bi-RNN is reduced with attention vector
472 | output = torch.sum(inputs * torch.unsqueeze(alphas, 2), 1)
473 | return output
474 |
475 |
476 | # 许海明
477 | class AdaptiveEmbedding(nn.Module):
478 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
479 | sample_softmax=False):
480 |
481 | '''
482 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, div_val=div_val)
483 | :param n_token:
484 | :param d_embed:
485 | :param d_proj:
486 | :param cutoffs:
487 | :param div_val:
488 | :param sample_softmax:
489 | '''
490 | super(AdaptiveEmbedding, self).__init__()
491 |
492 | self.n_token = n_token
493 | self.d_embed = d_embed
494 |
495 | self.cutoffs = cutoffs + [n_token]
496 | self.div_val = div_val
497 | self.d_proj = d_proj
498 |
499 | self.emb_scale = d_proj ** 0.5
500 |
501 | self.cutoff_ends = [0] + self.cutoffs
502 |
503 | self.emb_layers = nn.ModuleList()
504 | self.emb_projs = nn.ParameterList()
505 | if div_val == 1:
506 | self.emb_layers.append(
507 | nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
508 | )
509 | if d_proj != d_embed:
510 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
511 | else:
512 | for i in range(len(self.cutoffs)):
513 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
514 | d_emb_i = d_embed // (div_val ** i)
515 | self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
516 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
517 |
518 |
519 |
520 | def forward(self, inp):
521 | if self.div_val == 1:
522 | embed = self.emb_layers[0](inp)
523 | if self.d_proj != self.d_embed:
524 | embed = F.linear(embed, self.emb_projs[0])
525 | else:
526 | param = next(self.parameters())
527 | inp_flat = inp.view(-1)
528 | emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
529 | dtype=param.dtype, device=param.device)
530 | for i in range(len(self.cutoffs)):
531 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
532 |
533 | mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
534 | indices_i = mask_i.nonzero().squeeze()
535 |
536 | if indices_i.numel() == 0:
537 | continue
538 |
539 | inp_i = inp_flat.index_select(0, indices_i) - l_idx
540 | emb_i = self.emb_layers[i](inp_i)
541 | emb_i = F.linear(emb_i, self.emb_projs[i])
542 |
543 | emb_flat.index_copy_(0, indices_i, emb_i)
544 |
545 | embed = emb_flat.view(*inp.size(), self.d_proj)
546 |
547 | embed.mul_(self.emb_scale)
548 |
549 | return embed
550 |
551 |
552 |
553 |
554 | class MemTransformerLM(nn.Module):
555 | def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
556 | dropout, dropatt, tie_weight=True, d_embed=None,
557 | div_val=1, tie_projs=[False], pre_lnorm=False,
558 | tgt_len=None, ext_len=None, mem_len=None,
559 | cutoffs=[], adapt_inp=False,
560 | same_length=False, attn_type=0, clamp_len=-1,
561 | sample_softmax=-1):
562 | '''
563 |
564 | :param n_token: vocabulary size
565 | :param n_layer: 16
566 | :param n_head: 10
567 | :param d_model: 410
568 | :param d_head: 41
569 | :param d_inner: 2100
570 | :param dropout: 0.1
571 | :param dropatt: 0.0
572 | :param tie_weight: True
573 | :param d_embed: 410
574 | :param div_val: 1
575 | :param tie_projs: [F,T,T,T]
576 | :param pre_lnorm: False
577 | :param tgt_len: 150
578 | :param ext_len: 0
579 | :param mem_len: 150
580 | :param cutoffs: [20000, 40000, 200000]
581 | :param adapt_inp:
582 | :param same_length: False during training
583 | :param attn_type: 0
584 | :param clamp_len: -1
585 | :param sample_softmax: -1
586 | '''
587 | super(MemTransformerLM, self).__init__()
588 | self.n_token = n_token
589 |
590 | d_embed = d_model if d_embed is None else d_embed
591 | self.d_embed = d_embed
592 | self.d_model = d_model
593 | self.n_head = n_head
594 | self.d_head = d_head
595 |
596 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
597 | div_val=div_val)
598 |
599 | self.drop = nn.Dropout(dropout)
600 | self.n_layer = n_layer
601 | self.tgt_len = tgt_len
602 | self.mem_len = mem_len
603 | self.ext_len = ext_len
604 | self.max_klen = tgt_len + ext_len + mem_len
605 | self.attn_type = attn_type
606 | # attention pooling layers: "low" pools tokens within a segment, "high" pools the segment vectors
607 | self.attention_layer_low = MyAttention(self.d_model, self.d_model//2)
608 | self.attention_layer_high = MyAttention(self.d_model, self.d_model // 2)
609 |
610 | self.fc = nn.Sequential(
611 | nn.Linear(self.d_model,self.d_model),
612 | nn.BatchNorm1d(self.d_model),
613 | nn.ReLU(inplace=True),
614 | nn.Linear(self.d_model, 19)
615 | )
616 |
617 | self.layers = nn.ModuleList()
618 | if attn_type == 0: # the default attention
619 | for i in range(n_layer):
620 | self.layers.append(
621 | RelPartialLearnableDecoderLayer(
622 | n_head, d_model, d_head, d_inner, dropout,
623 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
624 | dropatt=dropatt, pre_lnorm=pre_lnorm)
625 | )
626 | elif attn_type == 1: # learnable embeddings
627 | for i in range(n_layer):
628 | self.layers.append(
629 | RelLearnableDecoderLayer(
630 | n_head, d_model, d_head, d_inner, dropout,
631 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
632 | dropatt=dropatt, pre_lnorm=pre_lnorm)
633 | )
634 | elif attn_type in [2, 3]: # absolute embeddings
635 | for i in range(n_layer):
636 | self.layers.append(
637 | DecoderLayer(
638 | n_head, d_model, d_head, d_inner, dropout,
639 | dropatt=dropatt, pre_lnorm=pre_lnorm)
640 | )
641 |
642 | self.sample_softmax = sample_softmax
643 | # use sampled softmax
644 | if sample_softmax > 0:
645 | self.out_layer = nn.Linear(d_model, n_token)
646 | if tie_weight:
647 | self.out_layer.weight = self.word_emb.weight
648 | self.tie_weight = tie_weight
649 | self.sampler = LogUniformSampler(n_token, sample_softmax)
650 |
651 | # use adaptive softmax (including standard softmax)
652 | else:
653 | self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
654 | cutoffs, div_val=div_val)
655 |
656 | if tie_weight:
657 | for i in range(len(self.crit.out_layers)):
658 | self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
659 |
660 | if tie_projs:
661 | for i, tie_proj in enumerate(tie_projs):
662 | if tie_proj and div_val == 1 and d_model != d_embed:
663 | self.crit.out_projs[i] = self.word_emb.emb_projs[0]
664 | elif tie_proj and div_val != 1:
665 | self.crit.out_projs[i] = self.word_emb.emb_projs[i]
666 |
667 | self.same_length = same_length
668 | self.clamp_len = clamp_len
669 |
670 | self._create_params()
671 |
672 | def backward_compatible(self):
673 | self.sample_softmax = -1
674 |
675 | def _create_params(self):
676 | if self.attn_type == 0: # default attention
677 | self.pos_emb = PositionalEmbedding(self.d_model)
678 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
679 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
680 |
681 | elif self.attn_type == 1: # learnable
682 | self.r_emb = nn.Parameter(torch.Tensor(
683 | self.n_layer, self.max_klen, self.n_head, self.d_head))
684 | self.r_w_bias = nn.Parameter(torch.Tensor(
685 | self.n_layer, self.n_head, self.d_head))
686 | self.r_bias = nn.Parameter(torch.Tensor(
687 | self.n_layer, self.max_klen, self.n_head))
688 | elif self.attn_type == 2: # absolute standard
689 | self.pos_emb = PositionalEmbedding(self.d_model)
690 | elif self.attn_type == 3: # absolute deeper SA
691 | self.r_emb = nn.Parameter(torch.Tensor(
692 | self.n_layer, self.max_klen, self.n_head, self.d_head))
693 |
694 | def reset_length(self, tgt_len, ext_len, mem_len):
695 | self.tgt_len = tgt_len
696 | self.mem_len = mem_len
697 | self.ext_len = ext_len
698 |
699 | def init_mems(self):
700 | if self.mem_len > 0:
701 | mems = []
702 | param = next(self.parameters())
703 | for i in range(self.n_layer+1):
704 | empty = torch.empty(0, dtype=param.dtype, device=param.device)
705 | mems.append(empty)
706 |
707 | return mems
708 | else:
709 | return None
710 |
711 | def _update_mems(self, hids, mems, qlen, mlen):
712 | # does not deal with None
713 | if mems is None: return None
714 |
715 | # mems is not None
716 | assert len(hids) == len(mems), 'len(hids) != len(mems)'
717 |
718 | # There are `mlen + qlen` steps that can be cached into mems
719 | # For the next step, the last `ext_len` of the `qlen` tokens
720 | # will be used as the extended context. Hence, we only cache
721 | # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
722 | # to `mlen + qlen - self.ext_len`.
723 | with torch.no_grad():
724 | new_mems = []
725 | end_idx = mlen + max(0, qlen - 0 - self.ext_len)
726 | beg_idx = max(0, end_idx - self.mem_len)
727 | for i in range(len(hids)):
728 | cat = torch.cat([mems[i], hids[i]], dim=0)
729 | new_mems.append(cat[beg_idx:end_idx].detach())
730 |
731 | return new_mems
732 |
733 | def _forward(self, dec_inp, mems=None):
734 | qlen, bsz = dec_inp.size()
735 | # print(qlen,bsz)
736 | word_emb = self.word_emb(dec_inp)
737 | # print(word_emb)
738 | mlen = mems[0].size(0) if mems is not None else 0
739 | klen = mlen + qlen
740 | if self.same_length:
741 | all_ones = word_emb.new_ones(qlen, klen)
742 | mask_len = klen - self.mem_len
743 | if mask_len > 0:
744 | mask_shift_len = qlen - mask_len
745 | else:
746 | mask_shift_len = qlen
747 | dec_attn_mask = (torch.triu(all_ones, mlen)
748 | + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
749 | else:
750 | # dec_attn_mask = torch.triu(
751 | # word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
752 | #
753 | dec_attn_mask = torch.triu(
754 | word_emb.new_ones(qlen, klen), diagonal=mlen).byte()[:,:,None]
755 |
756 |
757 | hids = []
758 | if self.attn_type == 0: # default
759 | pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
760 | dtype=word_emb.dtype)
761 | if self.clamp_len > 0:
762 | pos_seq.clamp_(max=self.clamp_len)
763 | pos_emb = self.pos_emb(pos_seq)
764 |
765 | core_out = self.drop(word_emb)
766 | pos_emb = self.drop(pos_emb)
767 |
768 | hids.append(core_out)
769 | for i, layer in enumerate(self.layers):
770 | mems_i = None if mems is None else mems[i]
771 | core_out = layer(core_out, pos_emb, self.r_w_bias,
772 | self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
773 | hids.append(core_out)
774 |
775 |
776 | elif self.attn_type == 1: # learnable
777 | core_out = self.drop(word_emb)
778 | hids.append(core_out)
779 | for i, layer in enumerate(self.layers):
780 | if self.clamp_len > 0:
781 | r_emb = self.r_emb[i][-self.clamp_len :]
782 | r_bias = self.r_bias[i][-self.clamp_len :]
783 | else:
784 | r_emb, r_bias = self.r_emb[i], self.r_bias[i]
785 |
786 | mems_i = None if mems is None else mems[i]
787 | core_out = layer(core_out, r_emb, self.r_w_bias[i],
788 | r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
789 | hids.append(core_out)
790 | elif self.attn_type == 2: # absolute
791 | pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
792 | dtype=word_emb.dtype)
793 | if self.clamp_len > 0:
794 | pos_seq.clamp_(max=self.clamp_len)
795 | pos_emb = self.pos_emb(pos_seq)
796 |
797 | core_out = self.drop(word_emb + pos_emb[-qlen:])
798 |
799 | hids.append(core_out)
800 | for i, layer in enumerate(self.layers):
801 | mems_i = None if mems is None else mems[i]
802 | if mems_i is not None and i == 0:
803 | mems_i += pos_emb[:mlen]
804 | core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
805 | mems=mems_i)
806 | hids.append(core_out)
807 | elif self.attn_type == 3:
808 | core_out = self.drop(word_emb)
809 |
810 | hids.append(core_out)
811 | for i, layer in enumerate(self.layers):
812 | mems_i = None if mems is None else mems[i]
813 | if mems_i is not None and mlen > 0:
814 | cur_emb = self.r_emb[i][:-qlen]
815 | cur_size = cur_emb.size(0)
816 | if cur_size < mlen:
817 | cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
818 | cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
819 | else:
820 | cur_emb = cur_emb[-mlen:]
821 | mems_i += cur_emb.view(mlen, 1, -1)
822 | core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
823 |
824 | core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
825 | mems=mems_i)
826 | hids.append(core_out)
827 |
828 | core_out = self.drop(core_out)
829 |
830 | new_mems = self._update_mems(hids, mems, mlen, qlen)
831 |
832 | return core_out, new_mems
833 |
834 | '''
835 | called as: model(data, labels, *mems, tgt_len=args.tgt_len, device=..., ext_len=args.ext_len)
836 | '''
837 | def forward(self, data, labels, *mems, tgt_len=None, device=None, ext_len=None):
838 | # nn.DataParallel does not allow size(0) tensors to be broadcasted.
839 | # So, have to initialize size(0) mems inside the model forward.
840 | # Moreover, have to return new_mems to allow nn.DataParallel to piece
841 | # them together.
842 | if not mems:
843 | mems = self.init_mems()
844 |
845 | bsz_len = data.size(0)
846 |
847 | train_iter_one_batch = LMOrderedIterator(data, bsz_len, tgt_len, device=device, ext_len=ext_len)
848 |
849 | memory_low = []
850 | for index, (data, seq_len) in enumerate(train_iter_one_batch.get_fixlen_iter()):
851 | # tgt_len = target.size(0)
852 | # print(data)
853 | hidden, mems = self._forward(data, mems=mems)
854 | # print(mems)
855 | memory_low.append(self.attention_layer_low(hidden.permute(1,0,2)))
856 | #
857 | # print(memory_low)
858 | shapes=memory_low[0].size()
859 |
860 | # print(len(memory_low))
861 | # print(shapes)
862 | memory_high = torch.cat(memory_low, 0)
863 | memory_high = torch.reshape(memory_high, (-1, shapes[0], shapes[1]))
864 | # print(memory_high.size())
865 | res=self.attention_layer_high(memory_high.permute(1,0,2))
866 | ######### projection layer below: map the pooled document vector to class logits ##############
867 | # print(res)
868 | logits=self.fc(res)
869 | # print(logits.size())
870 | return logits,labels
871 |
872 |
873 |
874 |
875 |
876 |
877 |
878 | # # pred_hid = hidden[-tgt_len:]
879 | # # if self.sample_softmax > 0 and self.training:
880 | # # assert self.tie_weight
881 | # # logit = sample_logits(self.word_emb,
882 | # # self.out_layer.bias, target, pred_hid, self.sampler)
883 | # # loss = -F.log_softmax(logit, -1)[:, :, 0]
884 | # # else:
885 | # # loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
886 | # # loss = loss.view(tgt_len, -1)
887 | #
888 | # return loss
889 |
890 |
891 |
892 | if __name__ == '__main__':
893 | import argparse
894 |
895 | parser = argparse.ArgumentParser(description='unit test')
896 |
897 | parser.add_argument('--n_layer', type=int, default=4, help='')
898 | parser.add_argument('--n_rel_layer', type=int, default=4, help='')
899 | parser.add_argument('--n_head', type=int, default=2, help='')
900 | parser.add_argument('--d_head', type=int, default=2, help='')
901 | parser.add_argument('--d_model', type=int, default=200, help='')
902 | parser.add_argument('--d_embed', type=int, default=200, help='')
903 | parser.add_argument('--d_inner', type=int, default=200, help='')
904 | parser.add_argument('--dropout', type=float, default=0.0, help='')
905 | parser.add_argument('--cuda', action='store_true', help='')
906 | parser.add_argument('--seed', type=int, default=1111, help='')
907 | parser.add_argument('--multi_gpu', action='store_true', help='')
908 |
909 | args = parser.parse_args()
910 |
911 | device = torch.device("cuda" if args.cuda else "cpu")
912 |
913 | B = 4
914 | tgt_len, mem_len, ext_len = 36, 36, 0
915 | data_len = tgt_len * 20
916 | args.n_token = 10000
917 |
918 | import data_utils
919 |
920 | data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
921 | diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)
922 |
923 | cutoffs = [args.n_token // 2]
924 | tie_projs = [False] + [True] * len(cutoffs)
925 |
926 | for div_val in [1, 2]:
927 | for d_embed in [200, 100]:
928 | model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
929 | args.d_model, args.d_head, args.d_inner, args.dropout,
930 | dropatt=args.dropout, tie_weight=True,
931 | d_embed=d_embed, div_val=div_val,
932 | tie_projs=tie_projs, pre_lnorm=True,
933 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
934 | cutoffs=cutoffs, attn_type=0).to(device)
935 |
936 | print(sum(p.numel() for p in model.parameters()))
937 |
938 | mems = tuple()
939 | for idx, (inp, tgt, seqlen) in enumerate(diter):
940 | print('batch {}'.format(idx))
941 | out = model(inp, tgt, *mems)
942 | mems = out[1:]
943 |
--------------------------------------------------------------------------------
/code_for_Classfy/myattention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 | class MyAttention(nn.Module):
6 | def __init__(self,hidden_size,attention_size):
7 | super(MyAttention, self).__init__()
8 | self.hidden_size=hidden_size
9 | self.attention_size = attention_size
10 |
11 | self.W_omega = nn.Parameter(torch.Tensor(self.attention_size,self.hidden_size))
12 | self.b_omega = nn.Parameter(torch.Tensor(self.attention_size))
13 | self.u_omega = nn.Parameter(torch.Tensor(1,self.attention_size))
14 |
15 | def forward(self, inputs, time_major=True, return_alphas=False):
16 | '''
17 |
18 | :param inputs: the shape of inputs must be [batch, len, hidden]
19 | :param time_major:
20 | :param return_alphas:
21 | :return:
22 | '''
23 | # v = torch.tanh(torch.matmul(torch.reshape(inputs, [-1, hidden_size]), W_omega) + tf.reshape(b_omega, [1, -1]))
24 |
25 | v = torch.tanh(F.linear(inputs, self.W_omega, self.b_omega)) #[B, L, attention_size]
26 | print(v.size())
27 | vu = F.linear(v, self.u_omega) #[B,L,1]
28 | print(vu.size())
29 | vu = torch.squeeze(vu)
30 | print(vu.size())
31 | exps = torch.exp(vu)
32 | print(exps.size())
33 |
34 | alphas = exps /torch.sum(input=exps,dim=1, keepdim=True)
35 | print(alphas.size())
36 |
37 | # Output of Bi-RNN is reduced with attention vector
38 | output = torch.sum(inputs * torch.unsqueeze(alphas, 2), 1)
39 | return output
40 | #
41 | # myAttention = MyAttention(100,50)
42 | #
43 | # x = torch.randn(10,64,100)
44 | # y=x.permute(1,0,2)
45 | # res=myAttention(y)
46 | # print(res.size())
47 |
48 | x = torch.randn(2,2)
49 | print(x)
50 | y = torch.randn(2,2)
51 | z = torch.randn(2,2)
52 |
53 | a = [x,y,z]
54 |
55 | res= torch.cat(a, 0)
56 | res=torch.reshape(res, (-1, 2, 2))
57 | print(res.size())
58 | print(res[0])
--------------------------------------------------------------------------------
/code_for_Classfy/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | word_emb=torch.randn(5,10)
3 |
4 | dec_attn_mask = torch.triu(
5 | word_emb.new_ones(5,10), diagonal=5+1)
6 | print(dec_attn_mask)
--------------------------------------------------------------------------------
/code_for_Classfy/train.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import argparse
3 | import time
4 | import math
5 | import os, sys
6 | import itertools
7 |
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | import torch.nn.functional as F
14 |
15 | from data_utils import get_lm_corpus
16 | from mem_transformer import MemTransformerLM
17 | from utils.exp_utils import create_exp_dir
18 | from utils.data_parallel import BalancedDataParallel
19 | from sklearn import metrics
20 |
21 |
22 |
23 |
24 | os.environ["CUDA_VISIBLE_DEVICES"] = "5,7"
25 |
26 |
27 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
28 | parser.add_argument('--alinlen', type=int, default=3000, help='xhm')
29 | parser.add_argument('--debug', action='store_true',
30 | help='run in debug mode (do not create exp dir)')
31 |
32 | parser.add_argument('--work_dir', default='LM-TFM', type=str,
33 | help='experiment directory.')
34 |
35 | parser.add_argument('--init_model_LM', type=str, help='path to the pre-trained language model.')
36 |
37 | parser.add_argument('--not_tied', action='store_true',
38 | help='do not tie the word embedding and softmax weights')
39 | parser.add_argument('--data', type=str,
40 | help='location of the data corpus')
41 |
42 | parser.add_argument('--vocab_file', type=str, help='path to the vocabulary file')
43 |
44 |
45 |
46 | parser.add_argument('--dataset', type=str, default='dccl',
47 | choices=['dccl'],
48 | help='dataset name')
49 |
50 | parser.add_argument('--seed', type=int, default=1111,
51 | help='random seed')
52 | parser.add_argument('--ext_len', type=int, default=0,
53 | help='length of the extended context')
54 |
55 |
56 | parser.add_argument('--n_layer', type=int, default=6,
57 | help='number of total layers')
58 | parser.add_argument('--d_model', type=int, default=410,
59 | help='model dimension')
60 | parser.add_argument('--n_head', type=int, default=10,
61 | help='number of heads')
62 | parser.add_argument('--d_head', type=int, default=41,
63 | help='head dimension')
64 | parser.add_argument('--d_embed', type=int, default=-1,
65 | help='embedding dimension')
66 |
67 | parser.add_argument('--cuda', action='store_true',
68 | help='use CUDA')
69 |
70 |
71 | parser.add_argument('--fp16', action='store_true',
72 | help='Run in pseudo-fp16 mode (fp16 storage fp32 math).')
73 |
74 | parser.add_argument('--batch_size', type=int, default=60,
75 | help='batch size')
76 |
77 |
78 | parser.add_argument('--tgt_len', type=int, default=150,
79 | help='number of tokens to predict')
80 | parser.add_argument('--eval_tgt_len', type=int, default=150,
81 | help='number of tokens to predict for evaluation')
82 |
83 | parser.add_argument('--adaptive', action='store_true',
84 | help='use adaptive softmax')
85 |
86 |
87 |
88 | parser.add_argument('--d_inner', type=int, default=1000,
89 | help='inner dimension in FF')
90 | parser.add_argument('--dropout', type=float, default=0.0,
91 | help='global dropout rate')
92 | parser.add_argument('--dropatt', type=float, default=0.0,
93 | help='attention probability dropout rate')
94 | parser.add_argument('--init', default='normal', type=str,
95 | help='parameter initializer to use.')
96 | parser.add_argument('--emb_init', default='normal', type=str,
97 | help='parameter initializer to use.')
98 | parser.add_argument('--init_range', type=float, default=0.1,
99 | help='parameters initialized by U(-init_range, init_range)')
100 | parser.add_argument('--emb_init_range', type=float, default=0.01,
101 | help='parameters initialized by U(-init_range, init_range)')
102 | parser.add_argument('--init_std', type=float, default=0.02,
103 | help='parameters initialized by N(0, init_std)')
104 | parser.add_argument('--proj_init_std', type=float, default=0.01,
105 | help='parameters initialized by N(0, init_std)')
106 | parser.add_argument('--optim', default='adam', type=str,
107 | choices=['adam', 'sgd', 'adagrad'],
108 | help='optimizer to use.')
109 | parser.add_argument('--lr', type=float, default=0.00025,
110 | help='initial learning rate (0.00025|5 for adam|sgd)')
111 | parser.add_argument('--mom', type=float, default=0.0,
112 | help='momentum for sgd')
113 | parser.add_argument('--scheduler', default='cosine', type=str,
114 | choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
115 | help='lr scheduler to use.')
116 | parser.add_argument('--warmup_step', type=int, default=0,
117 | help='upper epoch limit')
118 | parser.add_argument('--decay_rate', type=float, default=0.5,
119 | help='decay factor when ReduceLROnPlateau is used')
120 | parser.add_argument('--lr_min', type=float, default=0.0,
121 | help='minimum learning rate during annealing')
122 | parser.add_argument('--clip', type=float, default=0.25,
123 | help='gradient clipping')
124 | parser.add_argument('--clip_nonemb', action='store_true',
125 | help='only clip the gradient of non-embedding params')
126 | parser.add_argument('--max_step', type=int, default=100000,
127 | help='upper epoch limit')
128 |
129 | parser.add_argument('--batch_chunk', type=int, default=1,
130 | help='split batch into chunks to save memory')
131 |
132 | parser.add_argument('--mem_len', type=int, default=0,
133 | help='length of the retained previous heads')
134 | parser.add_argument('--restart_dir', type=str, default='',
135 | help='restart dir')
136 |
137 | parser.add_argument('--div_val', type=int, default=1,
138 | help='divident value for adapative input and softmax')
139 | parser.add_argument('--pre_lnorm', action='store_true',
140 | help='apply LayerNorm to the input instead of the output')
141 | parser.add_argument('--varlen', action='store_true',
142 | help='use variable length')
143 | parser.add_argument('--multi_gpu', action='store_true',
144 | help='use multiple GPU')
145 | parser.add_argument('--log-interval', type=int, default=200,
146 | help='report interval')
147 | parser.add_argument('--eval-interval', type=int, default=4000,
148 | help='evaluation interval')
149 |
150 | parser.add_argument('--restart', action='store_true',
151 | help='restart training from the saved checkpoint')
152 |
153 |
154 | parser.add_argument('--same_length', action='store_true',
155 | help='use the same attn length for all tokens')
156 | parser.add_argument('--attn_type', type=int, default=0,
157 | help='attention type. 0 for ours, 1 for Shaw et al,'
158 | '2 for Vaswani et al, 3 for Al Rfou et al.')
159 | parser.add_argument('--clamp_len', type=int, default=-1,
160 | help='use the same pos embeddings after clamp_len')
161 | parser.add_argument('--eta_min', type=float, default=0.0,
162 | help='min learning rate for cosine scheduler')
163 | parser.add_argument('--gpu0_bsz', type=int, default=-1,
164 | help='batch size on gpu 0')
165 | parser.add_argument('--max_eval_steps', type=int, default=-1,
166 | help='max eval steps')
167 | parser.add_argument('--sample_softmax', type=int, default=-1,
168 | help='number of samples in sampled softmax')
169 | parser.add_argument('--patience', type=int, default=0,
170 | help='patience')
171 | parser.add_argument('--finetune_v2', action='store_true',
172 | help='finetune v2')
173 | parser.add_argument('--finetune_v3', action='store_true',
174 | help='finetune v3')
175 |
176 | parser.add_argument('--static-loss-scale', type=float, default=1,
177 | help='Static loss scale, positive power of 2 values can '
178 | 'improve fp16 convergence.')
179 | parser.add_argument('--dynamic-loss-scale', action='store_true',
180 | help='Use dynamic loss scaling. If supplied, this argument'
181 | ' supersedes --static-loss-scale.')
182 |
183 |
184 | ##################################### Program start ########################################################################################################
185 | args = parser.parse_args()
186 | args.tied = not args.not_tied # True unless --not_tied is passed
187 |
188 | if args.d_embed < 0:
189 | args.d_embed = args.d_model # defaults to d_model (410 with the provided run command)
190 |
191 | assert args.ext_len >= 0, 'extended context length must be non-negative' # ext_len defaults to 0
192 | assert args.batch_size % args.batch_chunk == 0
193 |
194 |
195 | logging = create_exp_dir(args.work_dir,
196 | scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug)
197 |
198 |
199 |
200 | # Set the random seed manually for reproducibility.
201 | np.random.seed(args.seed)
202 | torch.manual_seed(args.seed)
203 | if torch.cuda.is_available():
204 | if not args.cuda:
205 | print('WARNING: You have a CUDA device, so you should probably run with --cuda')
206 | else:
207 | torch.cuda.manual_seed_all(args.seed)
208 |
209 | # Validate `--fp16` option
210 | if args.fp16:
211 | if not args.cuda:
212 | print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')
213 | args.fp16 = False
214 | else:
215 | try:
216 | from apex.fp16_utils import FP16_Optimizer
217 | except ImportError:
218 | print('WARNING: apex not installed, ignoring --fp16 option')
219 | args.fp16 = False
220 |
221 | device = torch.device('cuda' if args.cuda else 'cpu')
222 |
223 | ###############################################################################
224 | # Load data
225 | ###############################################################################
226 |
227 | assert args.alinlen == 3000
228 | corpus = get_lm_corpus(args.data,args.vocab_file,args.alinlen)
229 | print("数据加载成功")
230 | ntokens = len(corpus.vocab) # vocabulary size
231 | args.n_token = ntokens
232 |
233 |
234 | eval_batch_size = 10
235 | tr_batch_iter = corpus.get_batch_iterator('train', args.batch_size, alianlen=3000, device='cpu')
236 | va_batch_iter = corpus.get_batch_iterator('valid', eval_batch_size, alianlen=3000, device=device)
237 |
238 | print("批次迭代器加载成功")
239 | # te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len,
240 | # device=device, ext_len=args.ext_len)
241 |
242 | # adaptive softmax / embedding
243 | cutoffs, tie_projs = [], [False]
244 | if args.adaptive:
245 | cutoffs = [70000, 150000, 250000]
246 | tie_projs += [False] * len(cutoffs)
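# Note on the cutoffs above: with --adaptive (and assuming a vocabulary larger
# than 250k tokens), ids [0, 70000) form the frequent "head", while
# [70000, 150000), [150000, 250000) and [250000, n_token) become three tail
# clusters of the adaptive embedding/softmax; tie_projs records, per cluster,
# whether its projection is tied to the input embedding (all False here).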
247 |
248 |
249 | ###############################################################################
250 | # Build the model
251 | ###############################################################################
252 | def init_weight(weight):
253 | if args.init == 'uniform':
254 | nn.init.uniform_(weight, -args.init_range, args.init_range)
255 | elif args.init == 'normal':
256 | nn.init.normal_(weight, 0.0, args.init_std)
257 |
258 | def init_bias(bias):
259 | nn.init.constant_(bias, 0.0)
260 |
261 | # 许海明
262 | def weights_init(m):
263 | classname = m.__class__.__name__
264 | if classname.find('Linear') != -1:
265 | if hasattr(m, 'weight') and m.weight is not None:
266 | init_weight(m.weight)
267 | if hasattr(m, 'bias') and m.bias is not None:
268 | init_bias(m.bias)
269 | elif classname.find('AdaptiveEmbedding') != -1:
270 | if hasattr(m, 'emb_projs'):
271 | for i in range(len(m.emb_projs)):
272 | if m.emb_projs[i] is not None:
273 | nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
274 | elif classname.find('Embedding') != -1:
275 | if hasattr(m, 'weight'):
276 | init_weight(m.weight)
277 | elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
278 | if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
279 | init_weight(m.cluster_weight)
280 | if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
281 | init_bias(m.cluster_bias)
282 | if hasattr(m, 'out_projs'):
283 | for i in range(len(m.out_projs)):
284 | if m.out_projs[i] is not None:
285 | nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
286 | elif classname.find('LayerNorm') != -1:
287 | if hasattr(m, 'weight'):
288 | nn.init.normal_(m.weight, 1.0, args.init_std)
289 | if hasattr(m, 'bias') and m.bias is not None:
290 | init_bias(m.bias)
291 | elif classname.find('TransformerLM') != -1:
292 | if hasattr(m, 'r_emb'):
293 | init_weight(m.r_emb)
294 | if hasattr(m, 'r_w_bias'):
295 | init_weight(m.r_w_bias)
296 | if hasattr(m, 'r_r_bias'):
297 | init_weight(m.r_r_bias)
298 | if hasattr(m, 'r_bias'):
299 | init_bias(m.r_bias)
300 |
301 | def update_dropout(m):
302 | classname = m.__class__.__name__
303 | if classname.find('Dropout') != -1:
304 | if hasattr(m, 'p'):
305 | m.p = args.dropout
306 |
307 | def update_dropatt(m):
308 | if hasattr(m, 'dropatt'):
309 | m.dropatt.p = args.dropatt
310 |
311 | if args.restart:
312 | with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
313 | model = torch.load(f)
314 | if not args.fp16:
315 | model = model.float()
316 | model.apply(update_dropout)
317 | model.apply(update_dropatt)
318 | else:
319 | model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
320 | args.d_head, args.d_inner, args.dropout, args.dropatt,
321 | tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
322 | tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
323 | ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
324 | same_length=args.same_length, attn_type=args.attn_type,
325 | clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
326 | model.apply(weights_init)
327 | model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
328 | args.n_all_param = sum([p.nelement() for p in model.parameters()])
329 | args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
330 |
331 | if args.fp16:
332 | model = model.half()
333 |
334 |
335 |
336 |
337 |
338 | if args.init_model_LM is not None:
339 | logging('=' * 100)
340 | logging('=' * 100)
341 | logging('=' * 100)
342 | logging('=' * 100)
343 | logging("开始加载预训练的语言模型!!!!!!!!!!!!")
344 | model_dict = model.state_dict()
345 | model_LM = {k: v for k, v in torch.load(args.init_model_LM)['state_dict'].items() if k in model_dict}
346 | # print(model_LM)
347 | model_dict.update(model_LM)
348 | model.load_state_dict(model_dict)
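# Only tensors whose names also exist in the classification model's state_dict
# are copied from the LM checkpoint, so the pretrained Transformer-XL trunk is
# reused while any classification-specific layers keep the fresh initialization
# from weights_init above.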
349 |
350 | logging('=' * 100)
351 | logging('=' * 100)
352 | logging('=' * 100)
353 | logging('=' * 100)
354 | logging("完成加载预训练的语言模型!!!!!!!!!!!!")
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 | if args.multi_gpu:
366 | model = model.to(device)
367 | if args.gpu0_bsz >= 0:
368 | para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
369 | model, dim=0).to(device)
370 | else:
371 | para_model = nn.DataParallel(model, dim=0).to(device)
372 | else:
373 | para_model = model.to(device)
374 |
375 | print(" para_model = nn.DataParallel(model, dim=0).to(device) ")
376 |
377 |
378 | #### optimizer
379 | if args.optim.lower() == 'sgd':
380 | if args.sample_softmax > 0:
381 | dense_params, sparse_params = [], []
382 | for param in model.parameters():
383 | if param.size() == model.word_emb.weight.size():
384 | sparse_params.append(param)
385 | else:
386 | dense_params.append(param)
387 | optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
388 | optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
389 | else:
390 | optimizer = optim.SGD(model.parameters(), lr=args.lr,
391 | momentum=args.mom)
392 | elif args.optim.lower() == 'adam':
393 | if args.sample_softmax > 0:
394 | dense_params, sparse_params = [], []
395 | for param in model.parameters():
396 | if param.size() == model.word_emb.weight.size():
397 | sparse_params.append(param)
398 | else:
399 | dense_params.append(param)
400 | optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
401 | optimizer = optim.Adam(dense_params, lr=args.lr)
402 | else:
403 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
404 | elif args.optim.lower() == 'adagrad':
405 | optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
406 |
407 | #### scheduler
408 | if args.scheduler == 'cosine':
409 | # here we do not set eta_min to lr_min to be backward compatible
410 | # because in previous versions eta_min defaulted to 0
411 | # rather than to lr_min's default value of 1e-6
412 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
413 | args.max_step, eta_min=args.eta_min) # should use eta_min arg
414 | if args.sample_softmax > 0:
415 | scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse,
416 | args.max_step, eta_min=args.eta_min) # should use eta_min arg
417 | elif args.scheduler == 'inv_sqrt':
418 | # originally used for Transformer (in Attention is all you need)
419 | def lr_lambda(step):
420 | # return a multiplier instead of a learning rate
421 | if step == 0 and args.warmup_step == 0:
422 | return 1.
423 | else:
424 | return 1. / (step ** 0.5) if step > args.warmup_step \
425 | else step / (args.warmup_step ** 1.5)
426 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
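# The two branches of lr_lambda meet at step == warmup_step, where both return
# 1 / sqrt(warmup_step), so the linear warmup joins the inverse-sqrt decay
# continuously.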
427 | elif args.scheduler == 'dev_perf':
428 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
429 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
430 | if args.sample_softmax > 0:
431 | scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse,
432 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
433 | elif args.scheduler == 'constant':
434 | pass
435 |
436 | if args.cuda and args.fp16:
437 | # If args.dynamic_loss_scale is False, static_loss_scale will be used.
438 | # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale.
439 | optimizer = FP16_Optimizer(optimizer,
440 | static_loss_scale = args.static_loss_scale,
441 | dynamic_loss_scale = args.dynamic_loss_scale,
442 | dynamic_loss_args = {'init_scale': 2 ** 16})
443 |
444 | if args.restart:
445 | if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')):
446 | with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f:
447 | opt_state_dict = torch.load(f)
448 | optimizer.load_state_dict(opt_state_dict)
449 | else:
450 | print('Optimizer was not saved. Start from scratch.')
451 |
452 | # logging('=' * 100)
453 | # for k, v in args.__dict__.items():
454 | # logging(' - {} : {}'.format(k, v))
455 | logging('=' * 100)
456 | logging('#params = {}'.format(args.n_all_param))
457 | logging('#non emb params = {}'.format(args.n_nonemb_param))
458 |
459 | ###############################################################################
460 | # Training code
461 | ###############################################################################
462 |
463 | def evaluate(eval_iter):
464 | # Turn on evaluation mode which disables dropout.
465 | model.eval()
466 |
467 | acc_n = 0
468 | val_n = 0
469 | predict = np.zeros((0,), dtype=np.int32)
470 | gt = np.zeros((0,), dtype=np.int32)
471 |
472 | # If the model does not use memory at all, make the ext_len longer.
473 | # Otherwise, make the mem_len longer and keep the ext_len the same.
474 | if args.mem_len == 0:
475 | model.reset_length(args.eval_tgt_len,
476 | args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len)
477 | else:
478 | model.reset_length(args.eval_tgt_len,
479 | args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len)
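# Worked example (illustrative, with the values from 运行命令): tgt_len =
# mem_len = eval_tgt_len = 150 and ext_len = 0, so the calls above leave the
# lengths unchanged; if eval_tgt_len were 100 instead, mem_len would grow to
# 150 + 150 - 100 = 200 so the total attended context at evaluation stays the same.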
480 |
481 | with torch.no_grad():
482 | mems = tuple()
483 | for i, (data, labels_all, bsz_len) in enumerate(eval_iter):
484 | if args.max_eval_steps > 0 and i >= args.max_eval_steps:
485 | break
486 | outputs,labels = model(data, labels_all, *mems, tgt_len=args.tgt_len, device=device, ext_len=args.ext_len)
487 | pred = outputs.max(1)[1]
488 | acc_n += (pred == labels).sum().item()
489 | val_n += labels.size(0)
490 | predict = np.hstack((predict, pred.cpu().numpy()))
491 | gt = np.hstack((gt, labels.cpu().numpy()))
492 |
493 |
494 |
495 | acc = 100. * acc_n / val_n
496 | f1score = np.mean(metrics.f1_score(predict, gt, average=None))
497 | print('* Test Acc: {:.3f}%({}/{}), F1 Score: {}'.format(acc, acc_n, val_n, f1score))
498 |
499 | # Switch back to the training mode
500 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
501 | model.train()
502 | return f1score
503 |
504 |
505 | def train():
506 | # Turn on training mode which enables dropout.
507 | global train_step, log_start_time, best_score
508 | train_loss=0.0
509 | correct = 0
510 | total = 0
511 |
512 |
513 | model.train()
514 | if args.batch_chunk > 1:
515 | mems = [tuple() for _ in range(args.batch_chunk)]
516 | else:
517 | mems = tuple()
518 | batch_iter = tr_batch_iter.get_fixlen_iter()
519 |
520 | criterion = F.cross_entropy
521 |
522 | for batch, (data, labels_all, bsz_len) in enumerate(batch_iter):
523 | model.zero_grad()
524 | if args.batch_chunk > 1:
525 | data_chunks = torch.chunk(data, args.batch_chunk, 1)
526 | target_chunks = torch.chunk(labels_all, args.batch_chunk, 1)
527 | for i in range(args.batch_chunk):
528 | data_i = data_chunks[i].contiguous()
529 | target_i = target_chunks[i].contiguous()
530 | ret = para_model(data_i, target_i, *mems[i])
531 | loss, mems[i] = ret[0], ret[1:]
532 | loss = loss.float().mean().type_as(loss) / args.batch_chunk
533 | if args.fp16:
534 | optimizer.backward(loss)
535 | else:
536 | loss.backward()
537 | train_loss += loss.float().item()
538 | else:
539 | pred,labels = para_model(data, labels_all, *mems, tgt_len=args.tgt_len, device=device, ext_len=args.ext_len)
540 |
541 | # print(pred)
542 |
543 | loss = criterion(pred, labels)
544 |
545 | if args.fp16:
546 | optimizer.backward(loss)
547 | else:
548 | loss.backward()
549 |
550 | # update running metrics
551 | train_loss += loss.float().item()
552 | predicted = pred.max(1)[1]
553 | total += labels.size(0)
554 | correct += predicted.eq(labels).sum().item()
555 |
556 |
557 |
558 | if args.fp16:
559 | optimizer.clip_master_grads(args.clip)
560 | else:
561 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
562 |
563 | optimizer.step()
564 |
565 |
566 |
567 | if args.sample_softmax > 0:
568 | optimizer_sparse.step()
569 |
570 | # step-wise learning rate annealing
571 | train_step += 1
572 | if args.scheduler in ['cosine', 'constant', 'dev_perf']:
573 | # linear warmup stage
574 | if train_step < args.warmup_step:
575 | curr_lr = args.lr * train_step / args.warmup_step
576 | optimizer.param_groups[0]['lr'] = curr_lr
577 | if args.sample_softmax > 0:
578 | optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
579 | else:
580 | if args.scheduler == 'cosine':
581 | scheduler.step(train_step)
582 | if args.sample_softmax > 0:
583 | scheduler_sparse.step(train_step)
584 | elif args.scheduler == 'inv_sqrt':
585 | scheduler.step(train_step)
586 |
587 | if train_step % args.log_interval == 0:
588 | cur_loss = train_loss / args.log_interval
589 | elapsed = time.time() - log_start_time
590 | log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
591 | '| ms/batch {:5.2f} | loss {:5.2f} | Acc: {:.3f}%({}/{})'.format(
592 | epoch, train_step, batch+1, optimizer.param_groups[0]['lr'],
593 | elapsed * 1000 / args.log_interval, cur_loss, 100. * correct / total, correct, total)
594 |
595 | logging(log_str)
596 | train_loss = 0.0
597 | log_start_time = time.time()
598 |
599 | if train_step % args.eval_interval == 0:
600 | f1score = evaluate(va_batch_iter)
601 | logging('-' * 100)
602 |
603 | # Save the model if the validation F1 score is the best we've seen so far.
604 |
605 | if f1score > best_score:
606 | save_path=os.path.join(args.work_dir, 'model.pt')
607 | best_score = f1score
608 | checkpoint = {
609 | 'state_dict': model.state_dict(),
610 | 'config': args
611 | }
612 | torch.save(checkpoint, save_path)
613 | print('Best tmp model f1score: {}'.format(best_score))
614 |
615 | # # dev-performance based learning rate annealing
616 | # if args.scheduler == 'dev_perf':
617 | # scheduler.step(val_loss)
618 | # if args.sample_softmax > 0:
619 | # scheduler_sparse.step(val_loss)
620 |
621 |
622 | if train_step == args.max_step:
623 | break
624 |
625 | # Loop over epochs.
626 | train_step = 0
627 | best_score=0.0
628 |
629 |
630 | log_start_time = time.time()
631 |
632 |
633 |
634 |
635 |
636 |
637 |
638 |
639 | # At any point you can hit Ctrl + C to break out of training early.
640 | try:
641 | for epoch in itertools.count(start=1):
642 | print("许海明提醒你:开始",epoch)
643 | train()
644 | if train_step == args.max_step:
645 | logging('-' * 100)
646 | logging('End of training')
647 | break
648 | except KeyboardInterrupt:
649 | logging('-' * 100)
650 | logging('Exiting from training early')
651 |
652 |
--------------------------------------------------------------------------------
/code_for_Classfy/utils/adaptive_softmax.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | class AdaptiveLogSoftmax(nn.Module):
10 | def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
11 | super(AdaptiveLogSoftmax, self).__init__()
12 |
13 | cutoffs = list(cutoffs)
14 |
15 | if (cutoffs != sorted(cutoffs)) \
16 | or (min(cutoffs) <= 0) \
17 | or (max(cutoffs) >= (n_classes - 1)) \
18 | or (len(set(cutoffs)) != len(cutoffs)) \
19 | or any([int(c) != c for c in cutoffs]):
20 |
21 | raise ValueError("cutoffs should be a sequence of unique, positive "
22 | "integers sorted in an increasing order, where "
23 | "each value is between 1 and n_classes-1")
24 |
25 | self.in_features = in_features
26 | self.n_classes = n_classes
27 | self.cutoffs = cutoffs + [n_classes]
28 |
29 | self.shortlist_size = self.cutoffs[0]
30 | self.n_clusters = len(self.cutoffs) - 1
31 | self.head_size = self.shortlist_size + self.n_clusters
32 |
33 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
34 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
35 |
36 | self.keep_order = keep_order
37 |
38 |
39 | def forward(self, hidden, target, weight, bias, keep_order=False):
40 | if hidden.size(0) != target.size(0):
41 | raise RuntimeError('Input and target should have the same size '
42 | 'in the batch dimension.')
43 |
44 | head_weight = torch.cat(
45 | [weight[:self.shortlist_size], self.cluster_weight], dim=0)
46 | head_bias = torch.cat(
47 | [bias[:self.shortlist_size], self.cluster_bias], dim=0)
48 |
49 | head_logit = F.linear(hidden, head_weight, bias=head_bias)
50 | head_logprob = F.log_softmax(head_logit, dim=1)
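# The "head" softmax scores the shortlist tokens plus one logit per tail
# cluster; the full softmax of a tail cluster is only computed in the loop
# below, and only for targets that actually fall inside that cluster.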
51 |
52 | nll = torch.zeros_like(target,
53 | dtype=hidden.dtype, device=hidden.device)
54 |
55 | offset = 0
56 | cutoff_values = [0] + self.cutoffs
57 | for i in range(len(cutoff_values) - 1):
58 | l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]
59 |
60 | mask_i = (target >= l_idx) & (target < h_idx)
61 | indices_i = mask_i.nonzero().squeeze()
62 |
63 | if indices_i.numel() == 0:
64 | continue
65 |
66 | target_i = target.index_select(0, indices_i) - l_idx
67 | head_logprob_i = head_logprob.index_select(0, indices_i)
68 |
69 | if i == 0:
70 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
71 | else:
72 | weight_i = weight[l_idx:h_idx]
73 | bias_i = bias[l_idx:h_idx]
74 |
75 | hidden_i = hidden.index_select(0, indices_i)
76 |
77 | tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
78 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
79 |
80 | logprob_i = head_logprob_i[:, -i] \
81 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
82 |
83 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
84 | nll.index_copy_(0, indices_i, -logprob_i)
85 | else:
86 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
87 |
88 | offset += logprob_i.size(0)
89 |
90 | return nll
91 |
--------------------------------------------------------------------------------
/code_for_Classfy/utils/data_parallel.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.nn.parallel import DataParallel
3 | import torch
4 | from torch.nn.parallel._functions import Scatter
5 | from torch.nn.parallel.parallel_apply import parallel_apply
6 |
7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0):
8 | r"""
9 | Slices tensors into approximately equal chunks and
10 | distributes them across given GPUs. Duplicates
11 | references to objects that are not tensors.
12 | """
13 | def scatter_map(obj):
14 | if isinstance(obj, torch.Tensor):
15 | try:
16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
17 | except:
18 | print('obj', obj.size())
19 | print('dim', dim)
20 | print('chunk_sizes', chunk_sizes)
21 | quit()
22 | if isinstance(obj, tuple) and len(obj) > 0:
23 | return list(zip(*map(scatter_map, obj)))
24 | if isinstance(obj, list) and len(obj) > 0:
25 | return list(map(list, zip(*map(scatter_map, obj))))
26 | if isinstance(obj, dict) and len(obj) > 0:
27 | return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
28 | return [obj for targets in target_gpus]
29 |
30 | # After scatter_map is called, a scatter_map cell will exist. This cell
31 | # has a reference to the actual function scatter_map, which has references
32 | # to a closure that has a reference to the scatter_map cell (because the
33 | # fn is recursive). To avoid this reference cycle, we set the function to
34 | # None, clearing the cell
35 | try:
36 | return scatter_map(inputs)
37 | finally:
38 | scatter_map = None
39 |
40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
41 | r"""Scatter with support for kwargs dictionary"""
42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
44 | if len(inputs) < len(kwargs):
45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
46 | elif len(kwargs) < len(inputs):
47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
48 | inputs = tuple(inputs)
49 | kwargs = tuple(kwargs)
50 | return inputs, kwargs
51 |
52 | class BalancedDataParallel(DataParallel):
53 | def __init__(self, gpu0_bsz, *args, **kwargs):
54 | self.gpu0_bsz = gpu0_bsz
55 | super().__init__(*args, **kwargs)
56 |
57 | def forward(self, *inputs, **kwargs):
58 | if not self.device_ids:
59 | return self.module(*inputs, **kwargs)
60 | if self.gpu0_bsz == 0:
61 | device_ids = self.device_ids[1:]
62 | else:
63 | device_ids = self.device_ids
64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
65 | if len(self.device_ids) == 1:
66 | return self.module(*inputs[0], **kwargs[0])
67 | replicas = self.replicate(self.module, self.device_ids)
68 | if self.gpu0_bsz == 0:
69 | replicas = replicas[1:]
70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
71 | return self.gather(outputs, self.output_device)
72 |
73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs):
74 | return parallel_apply(replicas, inputs, kwargs, device_ids)
75 |
76 | def scatter(self, inputs, kwargs, device_ids):
77 | bsz = inputs[0].size(self.dim)
78 | num_dev = len(self.device_ids)
79 | gpu0_bsz = self.gpu0_bsz
80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
81 | if gpu0_bsz < bsz_unit:
82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
83 | delta = bsz - sum(chunk_sizes)
84 | for i in range(delta):
85 | chunk_sizes[i + 1] += 1
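# Worked example (illustrative, matching 运行命令: --batch_size 65 --gpu0_bsz 15
# on 4 GPUs): bsz_unit = (65 - 15) // 3 = 16, so chunk_sizes starts as
# [15, 16, 16, 16]; the remainder delta = 65 - 63 = 2 is spread over the
# non-gpu0 chunks, giving [15, 17, 17, 16].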
86 | if gpu0_bsz == 0:
87 | chunk_sizes = chunk_sizes[1:]
88 | else:
89 | return super().scatter(inputs, kwargs, device_ids)
90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
91 |
92 |
--------------------------------------------------------------------------------
/code_for_Classfy/utils/exp_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os, shutil
3 |
4 | import numpy as np
5 |
6 | import torch
7 |
8 |
9 | def logging(s, log_path, print_=True, log_=True):
10 | if print_:
11 | print(s)
12 | if log_:
13 | with open(log_path, 'a+') as f_log:
14 | f_log.write(s + '\n')
15 |
16 | def get_logger(log_path, **kwargs):
17 | return functools.partial(logging, log_path=log_path, **kwargs)
18 |
19 |
20 | # 许海明
21 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
22 | if debug:
23 | print('Debug Mode : no experiment dir created')
24 | return functools.partial(logging, log_path=None, log_=False)
25 |
26 | if not os.path.exists(dir_path):
27 | os.makedirs(dir_path)
28 |
29 | print('Experiment dir : {}'.format(dir_path))
30 | if scripts_to_save is not None:
31 | script_path = os.path.join(dir_path, 'scripts')
32 | if not os.path.exists(script_path):
33 | os.makedirs(script_path)
34 | for script in scripts_to_save:
35 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
36 | shutil.copyfile(script, dst_file)
37 |
38 | return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
39 |
40 | def save_checkpoint(model, optimizer, path, epoch):
41 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
42 | torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
43 |
--------------------------------------------------------------------------------
/code_for_Classfy/utils/log_uniform_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 | class LogUniformSampler(object):
6 | def __init__(self, range_max, n_sample):
7 | """
8 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
9 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
10 |
11 | expected count can be approximated by 1 - (1 - p)^n
12 | and we use a numerically stable version -expm1(num_tries * log1p(-p))
13 |
14 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
15 | """
16 | with torch.no_grad():
17 | self.range_max = range_max
18 | log_indices = torch.arange(1., range_max+2., 1.).log_()
19 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
20 | # print('P', self.dist.numpy().tolist()[-30:])
21 |
22 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
23 |
24 | self.n_sample = n_sample
25 |
26 | def sample(self, labels):
27 | """
28 | labels: [b1, b2]
29 | Return
30 | true_log_probs: [b1, b2]
31 | samp_log_probs: [n_sample]
32 | neg_samples: [n_sample]
33 | """
34 |
35 | # neg_samples = torch.empty(0).long()
36 | n_sample = self.n_sample
37 | n_tries = 2 * n_sample
38 |
39 | with torch.no_grad():
40 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
41 | device = labels.device
42 | neg_samples = neg_samples.to(device)
43 | true_log_probs = self.log_q[labels].to(device)
44 | samp_log_probs = self.log_q[neg_samples].to(device)
45 | return true_log_probs, samp_log_probs, neg_samples
46 |
47 | def sample_logits(embedding, bias, labels, inputs, sampler):
48 | """
49 | embedding: an nn.Embedding layer
50 | bias: [n_vocab]
51 | labels: [b1, b2]
52 | inputs: [b1, b2, n_emb]
53 | sampler: you may use a LogUniformSampler
54 | Return
55 | logits: [b1, b2, 1 + n_sample]
56 | """
57 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
58 | n_sample = neg_samples.size(0)
59 | b1, b2 = labels.size(0), labels.size(1)
60 | all_ids = torch.cat([labels.view(-1), neg_samples])
61 | all_w = embedding(all_ids)
62 | true_w = all_w[: -n_sample].view(b1, b2, -1)
63 | sample_w = all_w[- n_sample:].view(n_sample, -1)
64 |
65 | all_b = bias[all_ids]
66 | true_b = all_b[: -n_sample].view(b1, b2)
67 | sample_b = all_b[- n_sample:]
68 |
69 | hit = (labels[:, :, None] == neg_samples).detach()
70 |
71 | true_logits = torch.einsum('ijk,ijk->ij',
72 | [true_w, inputs]) + true_b - true_log_probs
73 | sample_logits = torch.einsum('lk,ijk->ijl',
74 | [sample_w, inputs]) + sample_b - samp_log_probs
75 | sample_logits.masked_fill_(hit, -1e30)
76 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
77 |
78 | return logits
79 |
80 |
81 | # class LogUniformSampler(object):
82 | # def __init__(self, range_max, unique=False):
83 | # """
84 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
85 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
86 | # """
87 | # self.range_max = range_max
88 | # log_indices = torch.arange(1., range_max+2., 1.).log_()
89 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
90 |
91 | # self.unique = unique
92 |
93 | # if self.unique:
94 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
95 |
96 | # def sample(self, n_sample, labels):
97 | # pos_sample, new_labels = labels.unique(return_inverse=True)
98 | # n_pos_sample = pos_sample.size(0)
99 | # n_neg_sample = n_sample - n_pos_sample
100 |
101 | # if self.unique:
102 | # self.exclude_mask.index_fill_(0, pos_sample, 1)
103 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
104 | # self.exclude_mask.index_fill_(0, pos_sample, 0)
105 | # else:
106 | # sample_dist = self.dist
107 |
108 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample)
109 |
110 | # sample = torch.cat([pos_sample, neg_sample])
111 | # sample_prob = self.dist[sample]
112 |
113 | # return new_labels, sample, sample_prob
114 |
115 |
116 | if __name__ == '__main__':
117 | S, B = 3, 4
118 | n_vocab = 10000
119 | n_sample = 5
120 | H = 32
121 |
122 | labels = torch.LongTensor(S, B).random_(0, n_vocab)
123 |
124 | # sampler = LogUniformSampler(n_vocab, unique=False)
125 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
126 |
127 | sampler = LogUniformSampler(n_vocab, n_sample)
128 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
129 |
130 | # print('true_probs', true_probs.numpy().tolist())
131 | # print('samp_probs', samp_probs.numpy().tolist())
132 | # print('neg_samples', neg_samples.numpy().tolist())
133 |
134 | # print('sum', torch.sum(sampler.dist).item())
135 |
136 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
137 |
138 | embedding = nn.Embedding(n_vocab, H)
139 | bias = torch.zeros(n_vocab)
140 | inputs = torch.Tensor(S, B, H).normal_()
141 |
142 | logits = sample_logits(embedding, bias, labels, inputs, sampler)
143 | print('logits', logits.detach().numpy().tolist())
144 | print('logits shape', logits.size())
145 | # logits[:, :, 0] holds the true-class logits; the remaining columns are the sampled negatives
146 | print('num sampled negatives', logits.size(-1) - 1)
147 |
148 |
--------------------------------------------------------------------------------
/code_for_Classfy/utils/proj_adaptive_softmax.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
10 | CUDA_MINOR = int(torch.version.cuda.split('.')[1])
11 |
12 | class ProjectedAdaptiveLogSoftmax(nn.Module):
13 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
14 | keep_order=False):
15 | super(ProjectedAdaptiveLogSoftmax, self).__init__()
16 |
17 | self.n_token = n_token
18 | self.d_embed = d_embed
19 | self.d_proj = d_proj
20 |
21 | self.cutoffs = cutoffs + [n_token]
22 | self.cutoff_ends = [0] + self.cutoffs
23 | self.div_val = div_val
24 |
25 | self.shortlist_size = self.cutoffs[0]
26 | self.n_clusters = len(self.cutoffs) - 1
27 | self.head_size = self.shortlist_size + self.n_clusters
28 |
29 | if self.n_clusters > 0:
30 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
31 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
32 |
33 | self.out_layers = nn.ModuleList()
34 | self.out_projs = nn.ParameterList()
35 |
36 | if div_val == 1:
37 | for i in range(len(self.cutoffs)):
38 | if d_proj != d_embed:
39 | self.out_projs.append(
40 | nn.Parameter(torch.Tensor(d_proj, d_embed))
41 | )
42 | else:
43 | self.out_projs.append(None)
44 |
45 | self.out_layers.append(nn.Linear(d_embed, n_token))
46 | else:
47 | for i in range(len(self.cutoffs)):
48 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
49 | d_emb_i = d_embed // (div_val ** i)
50 |
51 | self.out_projs.append(
52 | nn.Parameter(torch.Tensor(d_proj, d_emb_i))
53 | )
54 |
55 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
56 |
57 | self.keep_order = keep_order
58 |
59 | def _compute_logit(self, hidden, weight, bias, proj):
60 | if proj is None:
61 | logit = F.linear(hidden, weight, bias=bias)
62 | else:
63 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
64 | proj_hid = F.linear(hidden, proj.t().contiguous())
65 | logit = F.linear(proj_hid, weight, bias=bias)
66 | # else:
67 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
68 | # if bias is not None:
69 | # logit = logit + bias
70 |
71 | return logit
72 |
73 | def forward(self, hidden, target, keep_order=False):
74 | '''
75 | hidden :: [len*bsz x d_proj]
76 | target :: [len*bsz]
77 | '''
78 |
79 | if hidden.size(0) != target.size(0):
80 | raise RuntimeError('Input and target should have the same size '
81 | 'in the batch dimension.')
82 |
83 | if self.n_clusters == 0:
84 | logit = self._compute_logit(hidden, self.out_layers[0].weight,
85 | self.out_layers[0].bias, self.out_projs[0])
86 | nll = -F.log_softmax(logit, dim=-1) \
87 | .gather(1, target.unsqueeze(1)).squeeze(1)
88 | else:
89 | # construct weights and biases
90 | weights, biases = [], []
91 | for i in range(len(self.cutoffs)):
92 | if self.div_val == 1:
93 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
94 | weight_i = self.out_layers[0].weight[l_idx:r_idx]
95 | bias_i = self.out_layers[0].bias[l_idx:r_idx]
96 | else:
97 | weight_i = self.out_layers[i].weight
98 | bias_i = self.out_layers[i].bias
99 |
100 | if i == 0:
101 | weight_i = torch.cat(
102 | [weight_i, self.cluster_weight], dim=0)
103 | bias_i = torch.cat(
104 | [bias_i, self.cluster_bias], dim=0)
105 |
106 | weights.append(weight_i)
107 | biases.append(bias_i)
108 |
109 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
110 |
111 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
112 | head_logprob = F.log_softmax(head_logit, dim=1)
113 |
114 | nll = torch.zeros_like(target,
115 | dtype=hidden.dtype, device=hidden.device)
116 |
117 | offset = 0
118 | cutoff_values = [0] + self.cutoffs
119 | for i in range(len(cutoff_values) - 1):
120 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
121 |
122 | mask_i = (target >= l_idx) & (target < r_idx)
123 | indices_i = mask_i.nonzero().squeeze()
124 |
125 | if indices_i.numel() == 0:
126 | continue
127 |
128 | target_i = target.index_select(0, indices_i) - l_idx
129 | head_logprob_i = head_logprob.index_select(0, indices_i)
130 |
131 | if i == 0:
132 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
133 | else:
134 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
135 |
136 | hidden_i = hidden.index_select(0, indices_i)
137 |
138 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
139 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
140 |
141 | logprob_i = head_logprob_i[:, -i] \
142 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
143 |
144 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
145 | nll.index_copy_(0, indices_i, -logprob_i)
146 | else:
147 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
148 |
149 | offset += logprob_i.size(0)
150 |
151 | return nll
152 |
--------------------------------------------------------------------------------
/code_for_Classfy/utils/vocabulary.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import Counter, OrderedDict
3 |
4 | import torch
5 |
6 |
7 |
8 | # Vocabulary class
9 | class Vocab(object):
10 | def __init__(self, alinlen=3000,special=[], min_freq=10, max_size=None, lower_case=True,
11 | delimiter=None, vocab_file=None):
12 | '''
13 | code
14 | :param special:
15 | :param min_freq:
16 | :param max_size:
17 | :param lower_case:
18 | :param delimiter:
19 | :param vocab_file:
20 | '''
21 | self.counter = Counter()
22 | self.special = special
23 | self.min_freq = min_freq
24 | self.max_size = max_size
25 | self.lower_case = lower_case
26 | self.delimiter = delimiter
27 | self.vocab_file = vocab_file
28 | self.alinlen = alinlen
29 |
30 |
31 | # 许海明
32 | def tokenize(self, line):
33 | line = line.strip()
34 | # convert to lower case
35 | if self.lower_case:
36 | line = line.lower()
37 |
38 | # empty delimiter '' will evaluate False
39 | if self.delimiter == '':
40 | symbols = line
41 | else:
42 | symbols = line.split(self.delimiter)
43 |
44 |
45 | if len(symbols) > self.alinlen-2:
46 | symbols = symbols[:self.alinlen-2]
47 |
48 | symbols = ['<s>'] + symbols + ['<eos>']
49 |
50 | len_pre = len(symbols)
51 | if len_pre == self.alinlen:
52 | return symbols
53 | else:
54 | assert len_pre < self.alinlen
55 | new_symbols = ['<pad>'] * (self.alinlen - len_pre)
56 | new_symbols.extend(symbols)
57 | assert len(new_symbols)==self.alinlen
58 | return new_symbols
59 |
60 |
61 |
62 | '''
63 | Count token frequencies in the file
64 | and return the tokenized data, one article per line.
65 | '''
66 | def count_file(self, path, verbose=False):
67 | if verbose:
68 | print('counting file {} ...'.format(path))
69 | assert os.path.exists(path)
70 |
71 | sents = []
72 | with open(path, 'r', encoding='utf-8') as f:
73 | for idx, line in enumerate(f):
74 | if verbose and idx > 0 and idx % 500000 == 0:
75 | print(' line {}'.format(idx))
76 | symbols = self.tokenize(line)
77 | self.counter.update(symbols)
78 | sents.append(symbols)
79 |
80 | return sents
81 |
82 | def count_sents(self, sents, verbose=False):
83 | """
84 | sents : a list of sentences, each a list of tokenized symbols
85 | """
86 | if verbose: print('counting {} sents ...'.format(len(sents)))
87 | for idx, symbols in enumerate(sents):
88 | if verbose and idx > 0 and idx % 500000 == 0:
89 | print(' line {}'.format(idx))
90 | self.counter.update(symbols)
91 |
92 | # 许海明
93 | def _build_from_file(self, vocab_file):
94 | self.idx2sym = []
95 | self.sym2idx = OrderedDict()
96 |
97 | with open(vocab_file, 'r', encoding='utf-8') as f:
98 | for line in f:
99 | symb = line.strip().split()[0]
100 | self.add_symbol(symb)
101 | self.pad_idx = self.sym2idx['<pad>']
102 | self.s_idx = self.sym2idx['<s>']
103 | self.unk_idx = self.sym2idx['<unk>']
104 | setattr(self, '{}_idx'.format("<eos>".strip('<>')), self.sym2idx['<eos>'])
105 |
106 |
107 | # 许海明
108 | def build_vocab(self):
109 | if self.vocab_file:
110 | print('building vocab from {}'.format(self.vocab_file))
111 | self._build_from_file(self.vocab_file)
112 | print('final vocab size {}'.format(len(self)))
113 | else:
114 |
115 | raise ValueError("The vocabulary should be loaded from the vocab file, not rebuilt from scratch")
116 | print('building vocab with min_freq={}, max_size={}'.format(self.min_freq, self.max_size))
117 | self.idx2sym = []
118 | self.sym2idx = OrderedDict()
119 | for sym in self.special:
120 | self.add_special(sym)
121 |
122 | for sym, cnt in self.counter.most_common(self.max_size):
123 | if cnt < self.min_freq:
124 | break
125 | self.add_symbol(sym)
126 |
127 | print('final vocab size {} from {} unique tokens'.format(
128 | len(self), len(self.counter)))
129 |
130 |
131 | # 许海明: encode a file
132 |
133 | def encode_file_only_for_lables(self, path, verbose=True):
134 |
135 | if verbose:
136 | print('encoding labels {} ...'.format(path))
137 | assert os.path.exists(path)
138 | encoded = []
139 | with open(path, 'r', encoding='utf-8') as f:
140 | for idx, line in enumerate(f):
141 | line = line.strip()
142 | if line is None or line == "":
143 | continue
144 | label = int(line)-1 # NOTE: labels in the file are 1-indexed, so subtract 1 here
145 | if verbose and idx > 0 and idx % 500000 == 0:
146 | print(' line {}'.format(idx))
147 |
148 | encoded.append(label)
149 |
150 |
151 | encoded = torch.LongTensor(encoded)
152 |
153 | return encoded
154 |
155 | def encode_file(self, path, verbose=False):
156 | '''
157 | ordered=True
158 | :param path:
159 | :param ordered:
160 | :param verbose:
161 | :param add_eos:
162 | :param add_double_eos:
163 | :return:
164 | '''
165 | if verbose:
166 | print('encoding file {} ...'.format(path))
167 | assert os.path.exists(path)
168 | encoded = []
169 | with open(path, 'r', encoding='utf-8') as f:
170 | for idx, line in enumerate(f):
171 | line = line.strip()
172 | if line is None or line == "":
173 | continue
174 | if verbose and idx > 0 and idx % 500000 == 0:
175 | print(' line {}'.format(idx))
176 | symbols = self.tokenize(line) # each line is truncated/padded and wrapped with start/end tokens
177 |
178 | encoded.append(self.convert_to_tensor(symbols))
179 |
180 |
181 | encoded = torch.reshape(torch.cat(encoded,0), (-1, 3000))
182 | print(encoded.size())
183 |
184 | return encoded
185 |
186 | # 许海明: encode a list of sentences
187 | def encode_sents(self, sents, ordered=False, verbose=False):
188 | if verbose:
189 | print('encoding {} sents ...'.format(len(sents)))
190 | encoded = []
191 | for idx, symbols in enumerate(sents):
192 | if verbose and idx > 0 and idx % 500000 == 0:
193 | print(' line {}'.format(idx))
194 | encoded.append(self.convert_to_tensor(symbols))
195 |
196 | if ordered:
197 | encoded = torch.cat(encoded)
198 |
199 | return encoded
200 |
201 | # 许海明
202 | def add_special(self, sym):
203 | if sym not in self.sym2idx:
204 | self.idx2sym.append(sym)
205 | self.sym2idx[sym] = len(self.idx2sym) - 1
206 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
207 |
208 | # 许海明
209 | def add_symbol(self, sym):
210 | if sym not in self.sym2idx:
211 | self.idx2sym.append(sym)
212 | self.sym2idx[sym] = len(self.idx2sym) - 1
213 |
214 | def get_sym(self, idx):
215 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
216 | return self.idx2sym[idx]
217 |
218 | def get_idx(self, sym):
219 | if sym in self.sym2idx:
220 | return self.sym2idx[sym]
221 | else:
222 | # print('encounter unk {}'.format(sym))
223 | assert '<eos>' not in sym
224 | assert hasattr(self, 'unk_idx')
225 | return self.sym2idx.get(sym, self.unk_idx)
226 |
227 | def get_symbols(self, indices):
228 | return [self.get_sym(idx) for idx in indices]
229 |
230 | def get_indices(self, symbols):
231 | return [self.get_idx(sym) for sym in symbols]
232 |
233 | # 许海明
234 | def convert_to_tensor(self, symbols):
235 | return torch.LongTensor(self.get_indices(symbols))
236 |
237 | def convert_to_sent(self, indices, exclude=None):
238 | if exclude is None:
239 | return ' '.join([self.get_sym(idx) for idx in indices])
240 | else:
241 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
242 |
243 | def __len__(self):
244 | return len(self.idx2sym)
245 |
--------------------------------------------------------------------------------
/code_for_Classfy/运行命令:
--------------------------------------------------------------------------------
1 | python train.py
2 | --cuda
3 | --data ../data/LM/
4 | --work_dir ../results/LM/
5 | --dataset dccl
6 | --n_layer 6
7 | --d_model 410
8 | --n_head 10
9 | --d_head 41
10 | --tgt_len 150
11 | --mem_len 150
12 | --eval_tgt_len 150
13 | --d_inner 2100
14 | --dropout 0.1
15 | --dropatt 0.0
16 | --optim adam
17 | --lr 0.00025
18 | --warmup_step 0
19 | --max_step 600000
20 | --batch_size 65
21 | --multi_gpu
22 | --gpu0_bsz 15
23 |
--------------------------------------------------------------------------------
/code_for_LM/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuhaiming1996/Transformer-xl-classify/7696d0b328aac4665d851aac5f84e49000919aae/code_for_LM/.DS_Store
--------------------------------------------------------------------------------
/code_for_LM/data_utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import glob
3 | from collections import Counter, OrderedDict
4 | import numpy as np
5 | import torch
6 |
7 | from utils.vocabulary import Vocab
8 |
9 | class LMOrderedIterator(object):
10 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
11 | '''
12 | :param data:
13 | :param bsz: batch_size
14 | :param bptt: tgt_len
15 | :param device:
16 | :param ext_len: 0
17 | '''
18 | """
19 | data -- LongTensor -- the LongTensor is strictly ordered
20 | """
21 | self.bsz = bsz
22 | self.bptt = bptt
23 | self.ext_len = ext_len if ext_len is not None else 0
24 |
25 | self.device = device
26 |
27 | # Work out how cleanly we can divide the dataset into bsz parts.
28 | self.n_step = data.size(0) // bsz
29 |
30 |
31 |
32 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
33 | data = data.narrow(0, 0, self.n_step * bsz)
34 |
35 | # Evenly divide the data across the bsz batches.
36 | self.data = data.view(bsz, -1).t().contiguous().to(device)
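# The resulting layout is [n_step, bsz]: each column is one contiguous stream
# of the corpus, so successive batches along dim 0 simply continue every stream
# where the previous batch stopped.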
37 | # self.data = data.view(bsz, -1).t().contiguous().to(device)
38 |
39 |
40 | # Number of mini-batches
41 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt # batches advance bptt tokens at a time, not one token at a time
42 |
43 | def get_batch(self, i, bptt=None):
44 | if bptt is None:
45 | bptt = self.bptt
46 | seq_len = min(bptt, self.data.size(0) - 1 - i)
47 |
48 | end_idx = i + seq_len
49 | beg_idx = max(0, i - self.ext_len)
50 |
51 | data = self.data[beg_idx:end_idx]
52 | target = self.data[i+1:i+1+seq_len]
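# data spans [beg_idx, end_idx), i.e. up to ext_len extra context tokens before
# position i, while target spans [i+1, i+1+seq_len): the next token for each of
# the seq_len positions being predicted.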
53 |
54 | return data, target, seq_len
55 |
56 | def get_fixlen_iter(self, start=0):
57 | for i in range(start, self.data.size(0) - 1, self.bptt):
58 | yield self.get_batch(i)
59 |
60 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
61 | max_len = self.bptt + max_deviation * std
62 | i = start
63 | while True:
64 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
65 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
66 | data, target, seq_len = self.get_batch(i, bptt)
67 | i += seq_len
68 | yield data, target, seq_len
69 | if i >= self.data.size(0) - 2:
70 | break
71 |
72 | def __iter__(self):
73 | return self.get_fixlen_iter()
74 |
75 |
76 |
77 |
78 | # 许海明
79 | class Corpus(object):
80 | def __init__(self, path, *args, **kwargs):
81 | self.vocab = Vocab(*args, **kwargs)
82 |
83 | self.vocab.count_file(os.path.join(path, 'train.txt'),verbose=True)
84 | self.vocab.count_file(os.path.join(path, 'valid.txt'),verbose=True)
85 |
86 | self.vocab.build_vocab()
87 |
88 |
89 | self.train = self.vocab.encode_file(
90 | os.path.join(path, 'train.txt'), ordered=True, verbose=True)
91 | self.valid = self.vocab.encode_file(
92 | os.path.join(path, 'valid.txt'), ordered=True, verbose=True)
93 | # self.test = self.vocab.encode_file(
94 | # os.path.join(path, 'test.txt'), ordered=True)
95 | # 许海明
96 | def get_iterator(self, split, *args, **kwargs):
97 | '''
98 |
99 | :param split:
100 | :param args:
101 | :param kwargs:
102 | :return:
103 | '''
104 | if split == 'train':
105 | data_iter = LMOrderedIterator(self.train, *args, **kwargs)
106 |
107 | elif split in ['valid', 'test']:
108 | data = self.valid if split == 'valid' else self.test
109 | data_iter = LMOrderedIterator(data, *args, **kwargs)
110 |
111 | return data_iter
112 |
113 |
114 | def get_lm_corpus(datadir, alinlen):
115 | fn = os.path.join(datadir, 'cache.pt')
116 | if os.path.exists(fn):
117 | print('Loading cached dataset...')
118 | corpus = torch.load(fn)
119 |
120 | else:
121 | print('Producing dataset {}...'.format(datadir))
122 | kwargs = {}
123 |
124 | kwargs['special'] = ['','', '', '']
125 | kwargs['lower_case'] = False
126 |
127 |
128 | corpus = Corpus(datadir,alinlen ,**kwargs)
129 | torch.save(corpus, fn) # the whole Corpus object is cached here
130 |
131 | return corpus
132 |
133 | if __name__ == '__main__':
134 | import argparse
135 | parser = argparse.ArgumentParser(description='unit test')
136 | parser.add_argument('--datadir', type=str, default='../data/enwik8',
137 | help='location of the data corpus')
138 | parser.add_argument('--dataset', type=str, default='enwik8',
139 | choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'],
140 | help='dataset name')
141 | args = parser.parse_args()
142 |
143 | corpus = get_lm_corpus(args.datadir, args.dataset)
144 | print('Vocab size : {}'.format(len(corpus.vocab.idx2sym)))
145 | tr_iter = corpus.get_iterator('train', 22, 512)
146 | ##
147 | # 许海明 许海明 许海明 许海明 许海明
148 | # 许海明
149 | #
150 | #
151 | # ##
152 |
--------------------------------------------------------------------------------
/code_for_LM/eval.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import argparse
3 | import time
4 | import math
5 | import os, sys
6 |
7 | import torch
8 |
9 | from data_utils import get_lm_corpus
10 | from mem_transformer import MemTransformerLM
11 | from utils.exp_utils import get_logger
12 |
13 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
14 | parser.add_argument('--data', type=str, default='../data/wikitext-103',
15 | help='location of the data corpus')
16 | parser.add_argument('--dataset', type=str, default='wt103',
17 | choices=['wt103', 'lm1b', 'enwik8', 'text8'],
18 | help='dataset name')
19 | parser.add_argument('--split', type=str, default='all',
20 | choices=['all', 'valid', 'test'],
21 | help='which split to evaluate')
22 | parser.add_argument('--batch_size', type=int, default=10,
23 | help='batch size')
24 | parser.add_argument('--tgt_len', type=int, default=5,
25 | help='number of tokens to predict')
26 | parser.add_argument('--ext_len', type=int, default=0,
27 | help='length of the extended context')
28 | parser.add_argument('--mem_len', type=int, default=0,
29 | help='length of the retained previous heads')
30 | parser.add_argument('--clamp_len', type=int, default=-1,
31 | help='max positional embedding index')
32 | parser.add_argument('--cuda', action='store_true',
33 | help='use CUDA')
34 | parser.add_argument('--work_dir', type=str, required=True,
35 | help='path to the work_dir')
36 | parser.add_argument('--no_log', action='store_true',
37 | help='do not log the eval result')
38 | parser.add_argument('--same_length', action='store_true',
39 | help='set same length attention with masking')
40 | args = parser.parse_args()
41 | assert args.ext_len >= 0, 'extended context length must be non-negative'
42 |
43 | device = torch.device("cuda" if args.cuda else "cpu")
44 |
45 | # Get logger
46 | logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
47 | log_=not args.no_log)
48 |
49 | # Load dataset
50 | corpus = get_lm_corpus(args.data, args.dataset)
51 | ntokens = len(corpus.vocab)
52 |
53 | va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
54 | device=device, ext_len=args.ext_len)
55 | te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
56 | device=device, ext_len=args.ext_len)
57 |
58 | # Load the best saved model.
59 | with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
60 | model = torch.load(f)
61 | model.backward_compatible()
62 | model = model.to(device)
63 |
64 | logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
65 | args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
66 |
67 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
68 | if args.clamp_len > 0:
69 | model.clamp_len = args.clamp_len
70 | if args.same_length:
71 | model.same_length = True
72 |
73 | ###############################################################################
74 | # Evaluation code
75 | ###############################################################################
76 | def evaluate(eval_iter):
77 | # Turn on evaluation mode which disables dropout.
78 | model.eval()
79 | total_len, total_loss = 0, 0.
80 | start_time = time.time()
81 | with torch.no_grad():
82 | mems = tuple()
83 | for idx, (data, target, seq_len) in enumerate(eval_iter):
84 | ret = model(data, target, *mems)
85 | loss, mems = ret[0], ret[1:]
86 | loss = loss.mean()
87 | total_loss += seq_len * loss.item()
88 | total_len += seq_len
89 | total_time = time.time() - start_time
90 | logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
91 | total_time, 1000 * total_time / (idx+1)))
92 | return total_loss / total_len
93 |
94 | # Run on test data.
95 | if args.split == 'all':
96 | test_loss = evaluate(te_iter)
97 | valid_loss = evaluate(va_iter)
98 | elif args.split == 'valid':
99 | valid_loss = evaluate(va_iter)
100 | test_loss = None
101 | elif args.split == 'test':
102 | test_loss = evaluate(te_iter)
103 | valid_loss = None
104 |
105 | def format_log(loss, split):
106 | if args.dataset in ['enwik8', 'text8']:
107 | log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
108 | split, loss, loss / math.log(2))
109 | else:
110 | log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
111 | split, loss, math.exp(loss))
112 | return log_str
113 |
114 | log_str = ''
115 | if valid_loss is not None:
116 | log_str += format_log(valid_loss, 'valid')
117 | if test_loss is not None:
118 | log_str += format_log(test_loss, 'test')
119 |
120 | logging('=' * 100)
121 | logging(log_str)
122 | logging('=' * 100)
123 |
--------------------------------------------------------------------------------
/code_for_LM/mem_transformer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import functools
4 |
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | sys.path.append('utils')
12 | from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
13 | from log_uniform_sampler import LogUniformSampler, sample_logits
14 |
15 | class PositionalEmbedding(nn.Module):
16 | def __init__(self, demb):
17 | super(PositionalEmbedding, self).__init__()
18 |
19 | self.demb = demb
20 |
21 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
22 | self.register_buffer('inv_freq', inv_freq)
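# inv_freq[k] = 1 / 10000^(2k / d_emb); forward() takes the outer product with
# the position sequence and concatenates sin and cos, giving the standard
# sinusoidal positional embedding of width d_emb.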
23 |
24 | def forward(self, pos_seq, bsz=None):
25 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
26 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
27 |
28 | if bsz is not None:
29 | return pos_emb[:,None,:].expand(-1, bsz, -1)
30 | else:
31 | return pos_emb[:,None,:]
32 |
33 |
34 | class PositionwiseFF(nn.Module):
35 | def __init__(self, d_model, d_inner, dropout, pre_lnorm = False):
36 | super(PositionwiseFF, self).__init__()
37 |
38 | self.d_model = d_model
39 | self.d_inner = d_inner
40 | self.dropout = dropout
41 |
42 | self.CoreNet = nn.Sequential(
43 | nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
44 | nn.Dropout(dropout),
45 | nn.Linear(d_inner, d_model),
46 | nn.Dropout(dropout),
47 | )
48 |
49 | self.layer_norm = nn.LayerNorm(d_model)
50 |
51 | self.pre_lnorm = pre_lnorm
52 |
53 | def forward(self, inp):
54 | if self.pre_lnorm:
55 | ##### layer normalization + positionwise feed-forward
56 | core_out = self.CoreNet(self.layer_norm(inp))
57 |
58 | ##### residual connection
59 | output = core_out + inp
60 | else:
61 | ##### positionwise feed-forward
62 | core_out = self.CoreNet(inp)
63 |
64 | ##### residual connection + layer normalization
65 | output = self.layer_norm(inp + core_out)
66 |
67 | return output
68 |
69 | class MultiHeadAttn(nn.Module):
70 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
71 | pre_lnorm=False):
72 | super(MultiHeadAttn, self).__init__()
73 |
74 | self.n_head = n_head
75 | self.d_model = d_model
76 | self.d_head = d_head
77 | self.dropout = dropout
78 |
79 | self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
80 | self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
81 |
82 | self.drop = nn.Dropout(dropout)
83 | self.dropatt = nn.Dropout(dropatt)
84 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
85 |
86 | self.layer_norm = nn.LayerNorm(d_model)
87 |
88 | self.scale = 1 / (d_head ** 0.5)
89 |
90 | self.pre_lnorm = pre_lnorm
91 |
92 | def forward(self, h, attn_mask=None, mems=None):
93 | ##### multihead attention
94 | # [hlen x bsz x n_head x d_head]
95 |
96 | if mems is not None:
97 | c = torch.cat([mems, h], 0)
98 | else:
99 | c = h
100 |
101 | if self.pre_lnorm:
102 | ##### layer normalization
103 | c = self.layer_norm(c)
104 |
105 | head_q = self.q_net(h)
106 | head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
107 |
108 | head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
109 | head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
110 | head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
111 |
112 | # [qlen x klen x bsz x n_head]
113 | attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
114 | attn_score.mul_(self.scale)
115 | if attn_mask is not None and attn_mask.any().item():
116 | if attn_mask.dim() == 2:
117 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
118 | elif attn_mask.dim() == 3:
119 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
120 |
121 | # [qlen x klen x bsz x n_head]
122 | attn_prob = F.softmax(attn_score, dim=1)
123 | attn_prob = self.dropatt(attn_prob)
124 |
125 | # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
126 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
127 | attn_vec = attn_vec.contiguous().view(
128 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
129 |
130 | ##### linear projection
131 | attn_out = self.o_net(attn_vec)
132 | attn_out = self.drop(attn_out)
133 |
134 | if self.pre_lnorm:
135 | ##### residual connection
136 | output = h + attn_out
137 | else:
138 | ##### residual connection + layer normalization
139 | output = self.layer_norm(h + attn_out)
140 |
141 | return output
142 |
143 |
144 |
145 | class RelMultiHeadAttn(nn.Module):
146 | def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
147 | tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
148 | super(RelMultiHeadAttn, self).__init__()
149 |
150 | self.n_head = n_head
151 | self.d_model = d_model
152 | self.d_head = d_head
153 | self.dropout = dropout
154 |
155 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
156 |
157 | self.drop = nn.Dropout(dropout)
158 | self.dropatt = nn.Dropout(dropatt)
159 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
160 |
161 | self.layer_norm = nn.LayerNorm(d_model)
162 |
163 | self.scale = 1 / (d_head ** 0.5)
164 |
165 | self.pre_lnorm = pre_lnorm
166 |
167 | def _parallelogram_mask(self, h, w, left=False):
168 | mask = torch.ones((h, w)).byte()
169 | m = min(h, w)
170 | mask[:m,:m] = torch.triu(mask[:m,:m])
171 | mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
172 |
173 | if left:
174 | return mask
175 | else:
176 | return mask.flip(0)
177 |
178 | def _shift(self, x, qlen, klen, mask, left=False):
179 | if qlen > 1:
180 | zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
181 | device=x.device, dtype=x.dtype)
182 | else:
183 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
184 |
185 | if left:
186 | mask = mask.flip(1)
187 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
188 | else:
189 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
190 |
191 | x = x_padded.masked_select(mask[:,:,None,None]) \
192 | .view(qlen, klen, x.size(2), x.size(3))
193 |
194 | return x
195 |
196 | def _rel_shift(self, x, zero_triu=False):
197 | zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
198 | device=x.device, dtype=x.dtype)
199 | x_padded = torch.cat([zero_pad, x], dim=1)
200 |
201 | x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
202 |
203 | x = x_padded[1:].view_as(x)
204 |
205 | if zero_triu:
206 | ones = torch.ones((x.size(0), x.size(1)))
207 | x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
208 |
209 | return x
210 |
211 | def forward(self, w, r, attn_mask=None, mems=None):
212 | raise NotImplementedError
213 |
214 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
215 | def __init__(self, *args, **kwargs):
216 | super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
217 |
218 | self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
219 |
220 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
221 | qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
222 |
223 | if mems is not None:
224 | cat = torch.cat([mems, w], 0)
225 | if self.pre_lnorm:
226 | w_heads = self.qkv_net(self.layer_norm(cat))
227 | else:
228 | w_heads = self.qkv_net(cat)
229 | r_head_k = self.r_net(r)
230 |
231 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
232 | w_head_q = w_head_q[-qlen:]
233 | else:
234 | if self.pre_lnorm:
235 | w_heads = self.qkv_net(self.layer_norm(w))
236 | else:
237 | w_heads = self.qkv_net(w)
238 | r_head_k = self.r_net(r)
239 |
240 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
241 |
242 | klen = w_head_k.size(0)
243 |
244 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
245 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
246 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
247 |
248 | r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # rlen x n_head x d_head
249 |
250 | #### compute attention score
251 | rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
252 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
253 |
254 | rr_head_q = w_head_q + r_r_bias
255 | BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head
256 | BD = self._rel_shift(BD)
257 |
258 | # [qlen x klen x bsz x n_head]
259 | attn_score = AC + BD
260 | attn_score.mul_(self.scale)
261 |
262 | #### compute attention probability
263 | if attn_mask is not None and attn_mask.any().item():
264 | if attn_mask.dim() == 2:
265 | attn_score = attn_score.float().masked_fill(
266 | attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
267 | elif attn_mask.dim() == 3:
268 | attn_score = attn_score.float().masked_fill(
269 | attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)
270 |
271 | # [qlen x klen x bsz x n_head]
272 | attn_prob = F.softmax(attn_score, dim=1)
273 | attn_prob = self.dropatt(attn_prob)
274 |
275 | #### compute attention vector
276 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
277 |
278 | # [qlen x bsz x n_head x d_head]
279 | attn_vec = attn_vec.contiguous().view(
280 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
281 |
282 | ##### linear projection
283 | attn_out = self.o_net(attn_vec)
284 | attn_out = self.drop(attn_out)
285 |
286 | if self.pre_lnorm:
287 | ##### residual connection
288 | output = w + attn_out
289 | else:
290 | ##### residual connection + layer normalization
291 | output = self.layer_norm(w + attn_out)
292 |
293 | return output
294 |
295 | class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
296 | def __init__(self, *args, **kwargs):
297 | super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
298 |
299 | def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
300 | # r_emb: [klen, n_head, d_head], used for term B
301 | # r_w_bias: [n_head, d_head], used for term C
302 | # r_bias: [klen, n_head], used for term D
303 |
304 | qlen, bsz = w.size(0), w.size(1)
305 |
306 | if mems is not None:
307 | cat = torch.cat([mems, w], 0)
308 | if self.pre_lnorm:
309 | w_heads = self.qkv_net(self.layer_norm(cat))
310 | else:
311 | w_heads = self.qkv_net(cat)
312 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
313 |
314 | w_head_q = w_head_q[-qlen:]
315 | else:
316 | if self.pre_lnorm:
317 | w_heads = self.qkv_net(self.layer_norm(w))
318 | else:
319 | w_heads = self.qkv_net(w)
320 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
321 |
322 | klen = w_head_k.size(0)
323 |
324 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
325 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
326 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
327 |
328 | if klen > r_emb.size(0):
329 | r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
330 | r_emb = torch.cat([r_emb_pad, r_emb], 0)
331 | r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
332 | r_bias = torch.cat([r_bias_pad, r_bias], 0)
333 | else:
334 | r_emb = r_emb[-klen:]
335 | r_bias = r_bias[-klen:]
336 |
337 | #### compute attention score
338 | rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
339 |
340 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
341 | B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head
342 | D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head
343 | BD = self._rel_shift(B_ + D_)
344 |
345 | # [qlen x klen x bsz x n_head]
346 | attn_score = AC + BD
347 | attn_score.mul_(self.scale)
348 |
349 | #### compute attention probability
350 | if attn_mask is not None and attn_mask.any().item():
351 | if attn_mask.dim() == 2:
352 | attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
353 | elif attn_mask.dim() == 3:
354 | attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
355 |
356 | # [qlen x klen x bsz x n_head]
357 | attn_prob = F.softmax(attn_score, dim=1)
358 | attn_prob = self.dropatt(attn_prob)
359 |
360 | #### compute attention vector
361 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
362 |
363 | # [qlen x bsz x n_head x d_head]
364 | attn_vec = attn_vec.contiguous().view(
365 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
366 |
367 | ##### linear projection
368 | attn_out = self.o_net(attn_vec)
369 | attn_out = self.drop(attn_out)
370 |
371 | if self.pre_lnorm:
372 | ##### residual connection
373 | output = w + attn_out
374 | else:
375 | ##### residual connection + layer normalization
376 | output = self.layer_norm(w + attn_out)
377 |
378 | return output
379 |
380 |
381 |
382 |
383 | class DecoderLayer(nn.Module):
384 | def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
385 | super(DecoderLayer, self).__init__()
386 |
387 | self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
388 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
389 | pre_lnorm=kwargs.get('pre_lnorm'))
390 |
391 | def forward(self, dec_inp, dec_attn_mask=None, mems=None):
392 |
393 | output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
394 | mems=mems)
395 | output = self.pos_ff(output)
396 |
397 | return output
398 |
399 | class RelLearnableDecoderLayer(nn.Module):
400 | def __init__(self, n_head, d_model, d_head, d_inner, dropout,
401 | **kwargs):
402 | super(RelLearnableDecoderLayer, self).__init__()
403 |
404 | self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
405 | **kwargs)
406 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
407 | pre_lnorm=kwargs.get('pre_lnorm'))
408 |
409 | def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
410 |
411 | output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
412 | attn_mask=dec_attn_mask,
413 | mems=mems)
414 | output = self.pos_ff(output)
415 |
416 | return output
417 |
418 |
419 |
420 |
421 | class RelPartialLearnableDecoderLayer(nn.Module):
422 | def __init__(self, n_head, d_model, d_head, d_inner, dropout,
423 | **kwargs):
424 | super(RelPartialLearnableDecoderLayer, self).__init__()
425 |
426 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
427 | d_head, dropout, **kwargs)
428 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
429 | pre_lnorm=kwargs.get('pre_lnorm'))
430 |
431 | def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
432 |
433 | output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
434 | attn_mask=dec_attn_mask,
435 | mems=mems)
436 | output = self.pos_ff(output)
437 |
438 | return output
439 |
440 |
441 |
442 |
443 | # 许海明
444 | class AdaptiveEmbedding(nn.Module):
445 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
446 | sample_softmax=False):
447 |
448 | '''
449 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, div_val=div_val)
450 |         :param n_token: vocabulary size
451 |         :param d_embed: embedding dimension
452 |         :param d_proj: output projection dimension (the model dimension d_model)
453 |         :param cutoffs: vocabulary cutoffs for the adaptive buckets
454 |         :param div_val: factor by which the embedding size shrinks for each bucket
455 |         :param sample_softmax: >0 enables sampled softmax (uses a sparse embedding)
456 | '''
457 | super(AdaptiveEmbedding, self).__init__()
458 |
459 | self.n_token = n_token
460 | self.d_embed = d_embed
461 |
462 | self.cutoffs = cutoffs + [n_token]
463 | self.div_val = div_val
464 | self.d_proj = d_proj
465 |
466 | self.emb_scale = d_proj ** 0.5
467 |
468 | self.cutoff_ends = [0] + self.cutoffs
469 |
470 | self.emb_layers = nn.ModuleList()
471 | self.emb_projs = nn.ParameterList()
472 | if div_val == 1:
473 | self.emb_layers.append(
474 | nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
475 | )
476 | if d_proj != d_embed:
477 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
478 | else:
479 | for i in range(len(self.cutoffs)):
480 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
481 | d_emb_i = d_embed // (div_val ** i)
482 | self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
483 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
484 |
485 |
486 |
487 | def forward(self, inp):
488 | if self.div_val == 1:
489 | embed = self.emb_layers[0](inp)
490 | if self.d_proj != self.d_embed:
491 | embed = F.linear(embed, self.emb_projs[0])
492 | else:
493 | param = next(self.parameters())
494 | inp_flat = inp.view(-1)
495 | emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
496 | dtype=param.dtype, device=param.device)
497 | for i in range(len(self.cutoffs)):
498 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
499 |
500 | mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
501 | indices_i = mask_i.nonzero().squeeze()
502 |
503 | if indices_i.numel() == 0:
504 | continue
505 |
506 | inp_i = inp_flat.index_select(0, indices_i) - l_idx
507 | emb_i = self.emb_layers[i](inp_i)
508 | emb_i = F.linear(emb_i, self.emb_projs[i])
509 |
510 | emb_flat.index_copy_(0, indices_i, emb_i)
511 |
512 | embed = emb_flat.view(*inp.size(), self.d_proj)
513 |
514 | embed.mul_(self.emb_scale)
515 |
516 | return embed
517 |
518 |
519 |
520 |
521 | class MemTransformerLM(nn.Module):
522 | def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
523 | dropout, dropatt, tie_weight=True, d_embed=None,
524 | div_val=1, tie_projs=[False], pre_lnorm=False,
525 | tgt_len=None, ext_len=None, mem_len=None,
526 | cutoffs=[], adapt_inp=False,
527 | same_length=False, attn_type=0, clamp_len=-1,
528 | sample_softmax=-1):
529 | '''
530 |
531 |         :param n_token: vocabulary size
532 | :param n_layer: 16
533 | :param n_head: 10
534 | :param d_model: 410
535 | :param d_head: 41
536 | :param d_inner: 2100
537 | :param dropout: 0.1
538 | :param dropatt: 0.0
539 | :param tie_weight: True
540 | :param d_embed: 410
541 | :param div_val: 1
542 | :param tie_projs: [F,T,T,T]
543 | :param pre_lnorm: False
544 | :param tgt_len: 150
545 | :param ext_len: 0
546 | :param mem_len: 150
547 | :param cutoffs: [20000, 40000, 200000]
548 | :param adapt_inp:
549 |         :param same_length: False during training
550 | :param attn_type: 0
551 | :param clamp_len: -1
552 | :param sample_softmax: -1
553 | '''
554 | super(MemTransformerLM, self).__init__()
555 | self.n_token = n_token
556 |
557 | d_embed = d_model if d_embed is None else d_embed
558 | self.d_embed = d_embed
559 | self.d_model = d_model
560 | self.n_head = n_head
561 | self.d_head = d_head
562 |
563 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
564 | div_val=div_val)
565 |
566 | self.drop = nn.Dropout(dropout)
567 | self.n_layer = n_layer
568 | self.tgt_len = tgt_len
569 | self.mem_len = mem_len
570 | self.ext_len = ext_len
571 | self.max_klen = tgt_len + ext_len + mem_len
572 | self.attn_type = attn_type
573 | self.layers = nn.ModuleList()
574 | if attn_type == 0: # the default attention
575 | for i in range(n_layer):
576 | self.layers.append(
577 | RelPartialLearnableDecoderLayer(
578 | n_head, d_model, d_head, d_inner, dropout,
579 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
580 | dropatt=dropatt, pre_lnorm=pre_lnorm)
581 | )
582 | elif attn_type == 1: # learnable embeddings
583 | for i in range(n_layer):
584 | self.layers.append(
585 | RelLearnableDecoderLayer(
586 | n_head, d_model, d_head, d_inner, dropout,
587 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
588 | dropatt=dropatt, pre_lnorm=pre_lnorm)
589 | )
590 | elif attn_type in [2, 3]: # absolute embeddings
591 | for i in range(n_layer):
592 | self.layers.append(
593 | DecoderLayer(
594 | n_head, d_model, d_head, d_inner, dropout,
595 | dropatt=dropatt, pre_lnorm=pre_lnorm)
596 | )
597 |
598 | self.sample_softmax = sample_softmax
599 | # use sampled softmax
600 | if sample_softmax > 0:
601 | self.out_layer = nn.Linear(d_model, n_token)
602 | if tie_weight:
603 | self.out_layer.weight = self.word_emb.weight
604 | self.tie_weight = tie_weight
605 | self.sampler = LogUniformSampler(n_token, sample_softmax)
606 |
607 | # use adaptive softmax (including standard softmax)
608 | else:
609 | self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
610 | cutoffs, div_val=div_val)
611 |
612 | if tie_weight:
613 | for i in range(len(self.crit.out_layers)):
614 | self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
615 |
616 | if tie_projs:
617 | for i, tie_proj in enumerate(tie_projs):
618 | if tie_proj and div_val == 1 and d_model != d_embed:
619 | self.crit.out_projs[i] = self.word_emb.emb_projs[0]
620 | elif tie_proj and div_val != 1:
621 | self.crit.out_projs[i] = self.word_emb.emb_projs[i]
622 |
623 | self.same_length = same_length
624 | self.clamp_len = clamp_len
625 |
626 | self._create_params()
627 |
628 | def backward_compatible(self):
629 | self.sample_softmax = -1
630 |
631 | def _create_params(self):
632 | if self.attn_type == 0: # default attention
633 | self.pos_emb = PositionalEmbedding(self.d_model)
634 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
635 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
636 |
637 | elif self.attn_type == 1: # learnable
638 | self.r_emb = nn.Parameter(torch.Tensor(
639 | self.n_layer, self.max_klen, self.n_head, self.d_head))
640 | self.r_w_bias = nn.Parameter(torch.Tensor(
641 | self.n_layer, self.n_head, self.d_head))
642 | self.r_bias = nn.Parameter(torch.Tensor(
643 | self.n_layer, self.max_klen, self.n_head))
644 | elif self.attn_type == 2: # absolute standard
645 | self.pos_emb = PositionalEmbedding(self.d_model)
646 | elif self.attn_type == 3: # absolute deeper SA
647 | self.r_emb = nn.Parameter(torch.Tensor(
648 | self.n_layer, self.max_klen, self.n_head, self.d_head))
649 |
650 | def reset_length(self, tgt_len, ext_len, mem_len):
651 | self.tgt_len = tgt_len
652 | self.mem_len = mem_len
653 | self.ext_len = ext_len
654 |
655 | def init_mems(self):
656 | if self.mem_len > 0:
657 | mems = []
658 | param = next(self.parameters())
659 | for i in range(self.n_layer+1):
660 | empty = torch.empty(0, dtype=param.dtype, device=param.device)
661 | mems.append(empty)
662 |
663 | return mems
664 | else:
665 | return None
666 |
667 | def _update_mems(self, hids, mems, qlen, mlen):
668 | # does not deal with None
669 | if mems is None: return None
670 |
671 | # mems is not None
672 | assert len(hids) == len(mems), 'len(hids) != len(mems)'
673 |
674 | # There are `mlen + qlen` steps that can be cached into mems
675 | # For the next step, the last `ext_len` of the `qlen` tokens
676 | # will be used as the extended context. Hence, we only cache
677 | # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
678 | # to `mlen + qlen - self.ext_len`.
679 | with torch.no_grad():
680 | new_mems = []
681 | end_idx = mlen + max(0, qlen - 0 - self.ext_len)
682 | beg_idx = max(0, end_idx - self.mem_len)
683 | for i in range(len(hids)):
684 | cat = torch.cat([mems[i], hids[i]], dim=0)
685 | new_mems.append(cat[beg_idx:end_idx].detach())
686 |
687 | return new_mems
688 |
689 | def _forward(self, dec_inp, mems=None):
690 | qlen, bsz = dec_inp.size()
691 |
692 | word_emb = self.word_emb(dec_inp)
693 |
694 | mlen = mems[0].size(0) if mems is not None else 0
695 | klen = mlen + qlen
696 | if self.same_length:
697 | all_ones = word_emb.new_ones(qlen, klen)
698 | mask_len = klen - self.mem_len
699 | if mask_len > 0:
700 | mask_shift_len = qlen - mask_len
701 | else:
702 | mask_shift_len = qlen
703 | dec_attn_mask = (torch.triu(all_ones, 1+mlen)
704 | + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
705 | else:
706 | dec_attn_mask = torch.triu(
707 | word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
708 |
709 | hids = []
710 | if self.attn_type == 0: # default
711 | pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
712 | dtype=word_emb.dtype)
713 | if self.clamp_len > 0:
714 | pos_seq.clamp_(max=self.clamp_len)
715 | pos_emb = self.pos_emb(pos_seq)
716 |
717 | core_out = self.drop(word_emb)
718 | pos_emb = self.drop(pos_emb)
719 |
720 | hids.append(core_out)
721 | for i, layer in enumerate(self.layers):
722 | mems_i = None if mems is None else mems[i]
723 | core_out = layer(core_out, pos_emb, self.r_w_bias,
724 | self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
725 | hids.append(core_out)
726 |
727 |
728 | elif self.attn_type == 1: # learnable
729 | core_out = self.drop(word_emb)
730 | hids.append(core_out)
731 | for i, layer in enumerate(self.layers):
732 | if self.clamp_len > 0:
733 | r_emb = self.r_emb[i][-self.clamp_len :]
734 | r_bias = self.r_bias[i][-self.clamp_len :]
735 | else:
736 | r_emb, r_bias = self.r_emb[i], self.r_bias[i]
737 |
738 | mems_i = None if mems is None else mems[i]
739 | core_out = layer(core_out, r_emb, self.r_w_bias[i],
740 | r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
741 | hids.append(core_out)
742 | elif self.attn_type == 2: # absolute
743 | pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
744 | dtype=word_emb.dtype)
745 | if self.clamp_len > 0:
746 | pos_seq.clamp_(max=self.clamp_len)
747 | pos_emb = self.pos_emb(pos_seq)
748 |
749 | core_out = self.drop(word_emb + pos_emb[-qlen:])
750 |
751 | hids.append(core_out)
752 | for i, layer in enumerate(self.layers):
753 | mems_i = None if mems is None else mems[i]
754 | if mems_i is not None and i == 0:
755 | mems_i += pos_emb[:mlen]
756 | core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
757 | mems=mems_i)
758 | hids.append(core_out)
759 | elif self.attn_type == 3:
760 | core_out = self.drop(word_emb)
761 |
762 | hids.append(core_out)
763 | for i, layer in enumerate(self.layers):
764 | mems_i = None if mems is None else mems[i]
765 | if mems_i is not None and mlen > 0:
766 | cur_emb = self.r_emb[i][:-qlen]
767 | cur_size = cur_emb.size(0)
768 | if cur_size < mlen:
769 | cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
770 | cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
771 | else:
772 | cur_emb = cur_emb[-mlen:]
773 | mems_i += cur_emb.view(mlen, 1, -1)
774 | core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
775 |
776 | core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
777 | mems=mems_i)
778 | hids.append(core_out)
779 |
780 | core_out = self.drop(core_out)
781 |
782 |         new_mems = self._update_mems(hids, mems, qlen, mlen)  # match the (qlen, mlen) argument order of _update_mems
783 |
784 | return core_out, new_mems
785 |
786 | def forward(self, data, target, *mems):
787 | # nn.DataParallel does not allow size(0) tensors to be broadcasted.
788 | # So, have to initialize size(0) mems inside the model forward.
789 | # Moreover, have to return new_mems to allow nn.DataParallel to piece
790 | # them together.
791 | if not mems:
792 | mems = self.init_mems()
793 |
794 | tgt_len = target.size(0)
795 | hidden, new_mems = self._forward(data, mems=mems)
796 |
797 | pred_hid = hidden[-tgt_len:]
798 | if self.sample_softmax > 0 and self.training:
799 | assert self.tie_weight
800 | logit = sample_logits(self.word_emb,
801 | self.out_layer.bias, target, pred_hid, self.sampler)
802 | loss = -F.log_softmax(logit, -1)[:, :, 0]
803 | else:
804 | loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
805 | loss = loss.view(tgt_len, -1)
806 |
807 | if new_mems is None:
808 | return [loss]
809 | else:
810 | return [loss] + new_mems
811 |
812 |
813 |
814 | if __name__ == '__main__':
815 | import argparse
816 |
817 | parser = argparse.ArgumentParser(description='unit test')
818 |
819 | parser.add_argument('--n_layer', type=int, default=4, help='')
820 | parser.add_argument('--n_rel_layer', type=int, default=4, help='')
821 | parser.add_argument('--n_head', type=int, default=2, help='')
822 | parser.add_argument('--d_head', type=int, default=2, help='')
823 | parser.add_argument('--d_model', type=int, default=200, help='')
824 | parser.add_argument('--d_embed', type=int, default=200, help='')
825 | parser.add_argument('--d_inner', type=int, default=200, help='')
826 | parser.add_argument('--dropout', type=float, default=0.0, help='')
827 | parser.add_argument('--cuda', action='store_true', help='')
828 | parser.add_argument('--seed', type=int, default=1111, help='')
829 | parser.add_argument('--multi_gpu', action='store_true', help='')
830 |
831 | args = parser.parse_args()
832 |
833 | device = torch.device("cuda" if args.cuda else "cpu")
834 |
835 | B = 4
836 | tgt_len, mem_len, ext_len = 36, 36, 0
837 | data_len = tgt_len * 20
838 | args.n_token = 10000
839 |
840 | import data_utils
841 |
842 | data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
843 | diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)
844 |
845 | cutoffs = [args.n_token // 2]
846 | tie_projs = [False] + [True] * len(cutoffs)
847 |
848 | for div_val in [1, 2]:
849 | for d_embed in [200, 100]:
850 | model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
851 | args.d_model, args.d_head, args.d_inner, args.dropout,
852 | dropatt=args.dropout, tie_weight=True,
853 | d_embed=d_embed, div_val=div_val,
854 | tie_projs=tie_projs, pre_lnorm=True,
855 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
856 | cutoffs=cutoffs, attn_type=0).to(device)
857 |
858 | print(sum(p.numel() for p in model.parameters()))
859 |
860 | mems = tuple()
861 | for idx, (inp, tgt, seqlen) in enumerate(diter):
862 | print('batch {}'.format(idx))
863 | out = model(inp, tgt, *mems)
864 | mems = out[1:]
865 |
--------------------------------------------------------------------------------
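A minimal standalone sketch (not one of the repository files, with made-up shapes) illustrating two mechanisms from the model above: the relative-shift trick in `RelMultiHeadAttn._rel_shift`, and the slice of hidden states that `_update_mems` retains when `ext_len = 0`.

import torch

# Toy check of _rel_shift: row i of a [qlen x klen] score matrix is shifted left
# by (qlen - 1 - i) positions, so that after the shift entry (i, j) holds the score
# computed with the relative embedding for distance mlen + i - j.
# (Entries with negative distance are garbage, but the causal mask removes them.)
def rel_shift(x):
    zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)
    x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
    return x_padded[1:].view_as(x)

qlen, klen = 3, 5
scores = torch.arange(qlen * klen, dtype=torch.float).view(qlen, klen, 1, 1)
print(rel_shift(scores)[..., 0, 0])
# tensor([[ 2.,  3.,  4.,  0.,  5.],
#         [ 6.,  7.,  8.,  9.,  0.],
#         [10., 11., 12., 13., 14.]])

# Memory window kept by _update_mems with ext_len = 0: the last mem_len of the
# mlen + qlen cached positions.
mlen, mem_len = 4, 4
end_idx = mlen + qlen                 # 7
beg_idx = max(0, end_idx - mem_len)   # 3
print(beg_idx, end_idx)               # 3 7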
/code_for_LM/train.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import argparse
3 | import time
4 | import math
5 | import os, sys
6 | import itertools
7 |
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 |
14 | from data_utils import get_lm_corpus
15 | from mem_transformer import MemTransformerLM
16 | from utils.exp_utils import create_exp_dir
17 | from utils.data_parallel import BalancedDataParallel
18 |
19 |
20 |
21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,5,7"
22 |
23 |
24 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
25 | parser.add_argument('--alinlen', type=int, default=3000, help='article length in tokens (fixed to 3000)')
26 | parser.add_argument('--debug', action='store_true',
27 | help='run in debug mode (do not create exp dir)')
28 |
29 | parser.add_argument('--work_dir', default='LM-TFM', type=str,
30 | help='experiment directory.')
31 | parser.add_argument('--not_tied', action='store_true',
32 | help='do not tie the word embedding and softmax weights')
33 | parser.add_argument('--data', type=str,
34 | help='location of the data corpus')
35 | parser.add_argument('--dataset', type=str, default='dccl',
36 | choices=['dccl'],
37 | help='dataset name')
38 |
39 | parser.add_argument('--seed', type=int, default=1111,
40 | help='random seed')
41 | parser.add_argument('--ext_len', type=int, default=0,
42 | help='length of the extended context')
43 |
44 |
45 | parser.add_argument('--n_layer', type=int, default=6,
46 | help='number of total layers')
47 | parser.add_argument('--d_model', type=int, default=410,
48 | help='model dimension')
49 | parser.add_argument('--n_head', type=int, default=10,
50 | help='number of heads')
51 | parser.add_argument('--d_head', type=int, default=41,
52 | help='head dimension')
53 | parser.add_argument('--d_embed', type=int, default=-1,
54 | help='embedding dimension')
55 |
56 | parser.add_argument('--cuda', action='store_true',
57 | help='use CUDA')
58 |
59 |
60 | parser.add_argument('--fp16', action='store_true',
61 | help='Run in pseudo-fp16 mode (fp16 storage fp32 math).')
62 |
63 | parser.add_argument('--batch_size', type=int, default=60,
64 | help='batch size')
65 |
66 |
67 | parser.add_argument('--tgt_len', type=int, default=150,
68 | help='number of tokens to predict')
69 | parser.add_argument('--eval_tgt_len', type=int, default=150,
70 | help='number of tokens to predict for evaluation')
71 |
72 | parser.add_argument('--adaptive', action='store_true',
73 | help='use adaptive softmax')
74 |
75 |
76 |
77 | parser.add_argument('--d_inner', type=int, default=1000,
78 | help='inner dimension in FF')
79 | parser.add_argument('--dropout', type=float, default=0.0,
80 | help='global dropout rate')
81 | parser.add_argument('--dropatt', type=float, default=0.0,
82 | help='attention probability dropout rate')
83 | parser.add_argument('--init', default='normal', type=str,
84 | help='parameter initializer to use.')
85 | parser.add_argument('--emb_init', default='normal', type=str,
86 | help='parameter initializer to use.')
87 | parser.add_argument('--init_range', type=float, default=0.1,
88 | help='parameters initialized by U(-init_range, init_range)')
89 | parser.add_argument('--emb_init_range', type=float, default=0.01,
90 |                     help='embedding parameters initialized by U(-emb_init_range, emb_init_range)')
91 | parser.add_argument('--init_std', type=float, default=0.02,
92 | help='parameters initialized by N(0, init_std)')
93 | parser.add_argument('--proj_init_std', type=float, default=0.01,
94 |                     help='projection parameters initialized by N(0, proj_init_std)')
95 | parser.add_argument('--optim', default='adam', type=str,
96 | choices=['adam', 'sgd', 'adagrad'],
97 | help='optimizer to use.')
98 | parser.add_argument('--lr', type=float, default=0.00025,
99 | help='initial learning rate (0.00025|5 for adam|sgd)')
100 | parser.add_argument('--mom', type=float, default=0.0,
101 | help='momentum for sgd')
102 | parser.add_argument('--scheduler', default='cosine', type=str,
103 | choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
104 | help='lr scheduler to use.')
105 | parser.add_argument('--warmup_step', type=int, default=0,
106 |                     help='number of learning-rate warmup steps')
107 | parser.add_argument('--decay_rate', type=float, default=0.5,
108 | help='decay factor when ReduceLROnPlateau is used')
109 | parser.add_argument('--lr_min', type=float, default=0.0,
110 | help='minimum learning rate during annealing')
111 | parser.add_argument('--clip', type=float, default=0.25,
112 | help='gradient clipping')
113 | parser.add_argument('--clip_nonemb', action='store_true',
114 | help='only clip the gradient of non-embedding params')
115 | parser.add_argument('--max_step', type=int, default=100000,
116 |                     help='maximum number of training steps')
117 |
118 | parser.add_argument('--batch_chunk', type=int, default=1,
119 | help='split batch into chunks to save memory')
120 |
121 | parser.add_argument('--mem_len', type=int, default=0,
122 |                     help='length of the retained memory (previous hidden states)')
123 | parser.add_argument('--restart_dir', type=str, default='',
124 | help='restart dir')
125 |
126 | parser.add_argument('--div_val', type=int, default=1,
127 |                     help='division factor for adaptive input and softmax')
128 | parser.add_argument('--pre_lnorm', action='store_true',
129 | help='apply LayerNorm to the input instead of the output')
130 | parser.add_argument('--varlen', action='store_true',
131 | help='use variable length')
132 | parser.add_argument('--multi_gpu', action='store_true',
133 | help='use multiple GPU')
134 | parser.add_argument('--log-interval', type=int, default=200,
135 | help='report interval')
136 | parser.add_argument('--eval-interval', type=int, default=2000,
137 | help='evaluation interval')
138 |
139 | parser.add_argument('--restart', action='store_true',
140 | help='restart training from the saved checkpoint')
141 |
142 |
143 | parser.add_argument('--same_length', action='store_true',
144 | help='use the same attn length for all tokens')
145 | parser.add_argument('--attn_type', type=int, default=0,
146 | help='attention type. 0 for ours, 1 for Shaw et al,'
147 |                     ' 2 for Vaswani et al, 3 for Al Rfou et al.')
148 | parser.add_argument('--clamp_len', type=int, default=-1,
149 | help='use the same pos embeddings after clamp_len')
150 | parser.add_argument('--eta_min', type=float, default=0.0,
151 | help='min learning rate for cosine scheduler')
152 | parser.add_argument('--gpu0_bsz', type=int, default=-1,
153 | help='batch size on gpu 0')
154 | parser.add_argument('--max_eval_steps', type=int, default=-1,
155 | help='max eval steps')
156 | parser.add_argument('--sample_softmax', type=int, default=-1,
157 | help='number of samples in sampled softmax')
158 | parser.add_argument('--patience', type=int, default=0,
159 | help='patience')
160 | parser.add_argument('--finetune_v2', action='store_true',
161 | help='finetune v2')
162 | parser.add_argument('--finetune_v3', action='store_true',
163 | help='finetune v3')
164 |
165 | parser.add_argument('--static-loss-scale', type=float, default=1,
166 | help='Static loss scale, positive power of 2 values can '
167 | 'improve fp16 convergence.')
168 | parser.add_argument('--dynamic-loss-scale', action='store_true',
169 | help='Use dynamic loss scaling. If supplied, this argument'
170 | ' supersedes --static-loss-scale.')
171 |
172 |
173 | ##################################### Program starts here #####################################################################################################
174 | args = parser.parse_args()
175 | args.tied = not args.not_tied  # True unless --not_tied is given
176 |
177 | if args.d_embed < 0:
178 |     args.d_embed = args.d_model  # fall back to the model dimension
179 |
180 | assert args.ext_len >= 0, 'extended context length must be non-negative'  # ext_len defaults to 0
181 | assert args.batch_size % args.batch_chunk == 0
182 |
183 |
184 | logging = create_exp_dir(args.work_dir,
185 | scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug)
186 |
187 |
188 |
189 | # Set the random seed manually for reproducibility.
190 | np.random.seed(args.seed)
191 | torch.manual_seed(args.seed)
192 | if torch.cuda.is_available():
193 | if not args.cuda:
194 | print('WARNING: You have a CUDA device, so you should probably run with --cuda')
195 | else:
196 | torch.cuda.manual_seed_all(args.seed)
197 |
198 | # Validate `--fp16` option
199 | if args.fp16:
200 | if not args.cuda:
201 | print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')
202 | args.fp16 = False
203 | else:
204 | try:
205 | from apex.fp16_utils import FP16_Optimizer
206 | except:
207 | print('WARNING: apex not installed, ignoring --fp16 option')
208 | args.fp16 = False
209 |
210 | device = torch.device('cuda' if args.cuda else 'cpu')
211 |
212 | ###############################################################################
213 | # Load data
214 | ###############################################################################
215 |
216 | assert args.alinlen == 3000
217 | corpus = get_lm_corpus(args.data, args.alinlen)
218 | print("Data loaded successfully")
219 |
220 | print("Saving the vocabulary")
221 | corpus.vocab.save_symbol(os.path.join(args.work_dir,"vocab.txt"))
222 |
223 | ntokens = len(corpus.vocab)  # vocabulary size
224 | args.n_token = ntokens
225 |
226 |
227 | eval_batch_size = 10
228 | tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
229 | device='cpu', ext_len=args.ext_len)
230 | va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len,
231 | device=device, ext_len=args.ext_len)
232 |
233 | print("Iterators loaded successfully")
234 | # te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len,
235 | # device=device, ext_len=args.ext_len)
236 |
237 | # adaptive softmax / embedding
238 | cutoffs, tie_projs = [], [False]
239 | if args.adaptive:
240 | cutoffs = [70000, 150000, 250000]
241 | tie_projs += [False] * len(cutoffs)
242 |
243 |
244 | ###############################################################################
245 | # Build the model
246 | ###############################################################################
247 | def init_weight(weight):
248 | if args.init == 'uniform':
249 | nn.init.uniform_(weight, -args.init_range, args.init_range)
250 | elif args.init == 'normal':
251 | nn.init.normal_(weight, 0.0, args.init_std)
252 |
253 | def init_bias(bias):
254 | nn.init.constant_(bias, 0.0)
255 |
256 | # 许海明
257 | def weights_init(m):
258 | classname = m.__class__.__name__
259 | if classname.find('Linear') != -1:
260 | if hasattr(m, 'weight') and m.weight is not None:
261 | init_weight(m.weight)
262 | if hasattr(m, 'bias') and m.bias is not None:
263 | init_bias(m.bias)
264 | elif classname.find('AdaptiveEmbedding') != -1:
265 | if hasattr(m, 'emb_projs'):
266 | for i in range(len(m.emb_projs)):
267 | if m.emb_projs[i] is not None:
268 | nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
269 | elif classname.find('Embedding') != -1:
270 | if hasattr(m, 'weight'):
271 | init_weight(m.weight)
272 | elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
273 | if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
274 | init_weight(m.cluster_weight)
275 | if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
276 | init_bias(m.cluster_bias)
277 | if hasattr(m, 'out_projs'):
278 | for i in range(len(m.out_projs)):
279 | if m.out_projs[i] is not None:
280 | nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
281 | elif classname.find('LayerNorm') != -1:
282 | if hasattr(m, 'weight'):
283 | nn.init.normal_(m.weight, 1.0, args.init_std)
284 | if hasattr(m, 'bias') and m.bias is not None:
285 | init_bias(m.bias)
286 | elif classname.find('TransformerLM') != -1:
287 | if hasattr(m, 'r_emb'):
288 | init_weight(m.r_emb)
289 | if hasattr(m, 'r_w_bias'):
290 | init_weight(m.r_w_bias)
291 | if hasattr(m, 'r_r_bias'):
292 | init_weight(m.r_r_bias)
293 | if hasattr(m, 'r_bias'):
294 | init_bias(m.r_bias)
295 |
296 | def update_dropout(m):
297 | classname = m.__class__.__name__
298 | if classname.find('Dropout') != -1:
299 | if hasattr(m, 'p'):
300 | m.p = args.dropout
301 |
302 | def update_dropatt(m):
303 | if hasattr(m, 'dropatt'):
304 | m.dropatt.p = args.dropatt
305 |
306 | if args.restart:
307 | with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
308 | model = torch.load(f)
309 | if not args.fp16:
310 | model = model.float()
311 | model.apply(update_dropout)
312 | model.apply(update_dropatt)
313 | else:
314 | model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
315 | args.d_head, args.d_inner, args.dropout, args.dropatt,
316 | tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
317 | tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
318 | ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
319 | same_length=args.same_length, attn_type=args.attn_type,
320 | clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
321 | model.apply(weights_init)
322 | model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
323 | args.n_all_param = sum([p.nelement() for p in model.parameters()])
324 | args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
325 |
326 | if args.fp16:
327 | model = model.half()
328 |
329 |
330 |
331 |
332 | if args.multi_gpu:
333 | model = model.to(device)
334 | if args.gpu0_bsz >= 0:
335 | para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
336 | model, dim=1).to(device)
337 | else:
338 | para_model = nn.DataParallel(model, dim=1).to(device)
339 | else:
340 | para_model = model.to(device)
341 |
342 | #### optimizer
343 | if args.optim.lower() == 'sgd':
344 | if args.sample_softmax > 0:
345 | dense_params, sparse_params = [], []
346 | for param in model.parameters():
347 | if param.size() == model.word_emb.weight.size():
348 | sparse_params.append(param)
349 | else:
350 | dense_params.append(param)
351 | optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
352 | optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
353 | else:
354 | optimizer = optim.SGD(model.parameters(), lr=args.lr,
355 | momentum=args.mom)
356 | elif args.optim.lower() == 'adam':
357 | if args.sample_softmax > 0:
358 | dense_params, sparse_params = [], []
359 | for param in model.parameters():
360 | if param.size() == model.word_emb.weight.size():
361 | sparse_params.append(param)
362 | else:
363 | dense_params.append(param)
364 | optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
365 | optimizer = optim.Adam(dense_params, lr=args.lr)
366 | else:
367 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
368 | elif args.optim.lower() == 'adagrad':
369 | optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
370 |
371 | #### scheduler
372 | if args.scheduler == 'cosine':
373 | # here we do not set eta_min to lr_min to be backward compatible
374 | # because in previous versions eta_min is default to 0
375 | # rather than the default value of lr_min 1e-6
376 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
377 | args.max_step, eta_min=args.eta_min) # should use eta_min arg
378 | if args.sample_softmax > 0:
379 | scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse,
380 | args.max_step, eta_min=args.eta_min) # should use eta_min arg
381 | elif args.scheduler == 'inv_sqrt':
382 | # originally used for Transformer (in Attention is all you need)
383 | def lr_lambda(step):
384 | # return a multiplier instead of a learning rate
385 | if step == 0 and args.warmup_step == 0:
386 | return 1.
387 | else:
388 | return 1. / (step ** 0.5) if step > args.warmup_step \
389 | else step / (args.warmup_step ** 1.5)
390 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
391 | elif args.scheduler == 'dev_perf':
392 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
393 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
394 | if args.sample_softmax > 0:
395 | scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse,
396 | factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
397 | elif args.scheduler == 'constant':
398 | pass
399 |
400 | if args.cuda and args.fp16:
401 | # If args.dynamic_loss_scale is False, static_loss_scale will be used.
402 | # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale.
403 | optimizer = FP16_Optimizer(optimizer,
404 | static_loss_scale = args.static_loss_scale,
405 | dynamic_loss_scale = args.dynamic_loss_scale,
406 | dynamic_loss_args = {'init_scale': 2 ** 16})
407 |
408 | if args.restart:
409 | if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')):
410 | with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f:
411 | opt_state_dict = torch.load(f)
412 | optimizer.load_state_dict(opt_state_dict)
413 | else:
414 | print('Optimizer was not saved. Start from scratch.')
415 |
416 | logging('=' * 100)
417 | for k, v in args.__dict__.items():
418 | logging(' - {} : {}'.format(k, v))
419 | logging('=' * 100)
420 | logging('#params = {}'.format(args.n_all_param))
421 | logging('#non emb params = {}'.format(args.n_nonemb_param))
422 |
423 | ###############################################################################
424 | # Training code
425 | ###############################################################################
426 |
427 | def evaluate(eval_iter):
428 | # Turn on evaluation mode which disables dropout.
429 | model.eval()
430 |
431 | # If the model does not use memory at all, make the ext_len longer.
432 | # Otherwise, make the mem_len longer and keep the ext_len the same.
433 | if args.mem_len == 0:
434 | model.reset_length(args.eval_tgt_len,
435 | args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len)
436 | else:
437 | model.reset_length(args.eval_tgt_len,
438 | args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len)
439 |
440 | # Evaluation
441 | total_len, total_loss = 0, 0.
442 | with torch.no_grad():
443 | mems = tuple()
444 | for i, (data, target, seq_len) in enumerate(eval_iter):
445 | if args.max_eval_steps > 0 and i >= args.max_eval_steps:
446 | break
447 | ret = model(data, target, *mems)
448 | loss, mems = ret[0], ret[1:]
449 | loss = loss.mean()
450 | total_loss += seq_len * loss.float().item()
451 | total_len += seq_len
452 |
453 | # Switch back to the training mode
454 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
455 | model.train()
456 |
457 | return total_loss / total_len
458 |
459 |
460 | def train():
461 | # Turn on training mode which enables dropout.
462 | global train_step, train_loss, best_val_loss, eval_start_time, log_start_time
463 | model.train()
464 | if args.batch_chunk > 1:
465 | mems = [tuple() for _ in range(args.batch_chunk)]
466 | else:
467 | mems = tuple()
468 | train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
469 | for batch, (data, target, seq_len) in enumerate(train_iter):
470 | model.zero_grad()
471 | if args.batch_chunk > 1:
472 | data_chunks = torch.chunk(data, args.batch_chunk, 1)
473 | target_chunks = torch.chunk(target, args.batch_chunk, 1)
474 | for i in range(args.batch_chunk):
475 | data_i = data_chunks[i].contiguous()
476 | target_i = target_chunks[i].contiguous()
477 | ret = para_model(data_i, target_i, *mems[i])
478 | loss, mems[i] = ret[0], ret[1:]
479 | loss = loss.float().mean().type_as(loss) / args.batch_chunk
480 | if args.fp16:
481 | optimizer.backward(loss)
482 | else:
483 | loss.backward()
484 | train_loss += loss.float().item()
485 | else:
486 | ret = para_model(data, target, *mems)
487 | loss, mems = ret[0], ret[1:]
488 | loss = loss.float().mean().type_as(loss)
489 | if args.fp16:
490 | optimizer.backward(loss)
491 | else:
492 | loss.backward()
493 |
494 | train_loss += loss.float().item()
495 |
496 | if args.fp16:
497 | optimizer.clip_master_grads(args.clip)
498 | else:
499 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
500 |
501 | optimizer.step()
502 | if args.sample_softmax > 0:
503 | optimizer_sparse.step()
504 |
505 | # step-wise learning rate annealing
506 | train_step += 1
507 | if args.scheduler in ['cosine', 'constant', 'dev_perf']:
508 | # linear warmup stage
509 | if train_step < args.warmup_step:
510 | curr_lr = args.lr * train_step / args.warmup_step
511 | optimizer.param_groups[0]['lr'] = curr_lr
512 | if args.sample_softmax > 0:
513 | optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
514 | else:
515 | if args.scheduler == 'cosine':
516 | scheduler.step(train_step)
517 | if args.sample_softmax > 0:
518 | scheduler_sparse.step(train_step)
519 | elif args.scheduler == 'inv_sqrt':
520 | scheduler.step(train_step)
521 |
522 | if train_step % args.log_interval == 0:
523 | cur_loss = train_loss / args.log_interval
524 | elapsed = time.time() - log_start_time
525 | log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
526 | '| ms/batch {:5.2f} | loss {:5.2f}'.format(
527 | epoch, train_step, batch+1, optimizer.param_groups[0]['lr'],
528 | elapsed * 1000 / args.log_interval, cur_loss)
529 | if args.dataset in ['enwik8', 'text8']:
530 | log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
531 | else:
532 | log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
533 | logging(log_str)
534 | train_loss = 0
535 | log_start_time = time.time()
536 |
537 | if train_step % args.eval_interval == 0:
538 | val_loss = evaluate(va_iter)
539 | logging('-' * 100)
540 | log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
541 | '| valid loss {:5.2f}'.format(
542 | train_step // args.eval_interval, train_step,
543 | (time.time() - eval_start_time), val_loss)
544 | if args.dataset in ['enwik8', 'text8']:
545 | log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
546 | else:
547 | log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
548 | logging(log_str)
549 | logging('-' * 100)
550 | # Save the model if the validation loss is the best we've seen so far.
551 | if not best_val_loss or val_loss < best_val_loss:
552 | if not args.debug:
553 | save_path = os.path.join(args.work_dir, 'model.pt')
554 | checkpoint = {
555 | 'state_dict': model.state_dict(),
556 | 'config': args
557 | }
558 | torch.save(checkpoint, save_path)
559 |
560 | # print('Best tmp model f1score: {}'.format(best_score))
561 | #
562 | # with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f:
563 | # torch.save(model, f)
564 | # with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f:
565 | # torch.save(optimizer.state_dict(), f)
566 |
567 |
568 | best_val_loss = val_loss
569 |
570 | # dev-performance based learning rate annealing
571 | if args.scheduler == 'dev_perf':
572 | scheduler.step(val_loss)
573 | if args.sample_softmax > 0:
574 | scheduler_sparse.step(val_loss)
575 |
576 | eval_start_time = time.time()
577 |
578 | if train_step == args.max_step:
579 | break
580 |
581 | # Loop over epochs.
582 | train_step = 0
583 | train_loss = 0
584 | best_val_loss = None
585 |
586 | log_start_time = time.time()
587 | eval_start_time = time.time()
588 |
589 | # At any point you can hit Ctrl + C to break out of training early.
590 | try:
591 | for epoch in itertools.count(start=1):
592 |         print("许海明 reminds you: starting epoch", epoch)
593 | train()
594 | if train_step == args.max_step:
595 | logging('-' * 100)
596 | logging('End of training')
597 | break
598 | except KeyboardInterrupt:
599 | logging('-' * 100)
600 | logging('Exiting from training early')
601 |
602 | # # Load the best saved model.
603 | # with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
604 | # model = torch.load(f)
605 | # para_model = model.to(device)
606 | #
607 | # # Run on test data.
608 | # test_loss = evaluate(te_iter)
609 | # logging('=' * 100)
610 | # if args.dataset in ['enwik8', 'text8']:
611 | # logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
612 | # test_loss, test_loss / math.log(2)))
613 | # else:
614 | # logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
615 | # test_loss, math.exp(test_loss)))
616 | # logging('=' * 100)
617 |
--------------------------------------------------------------------------------
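A minimal sketch (an assumption, not a file from this repository) of how the checkpoint dict written by train.py ({'state_dict': ..., 'config': args}) could be reloaded. Note that the --restart branch above instead expects a fully pickled model, and that cutoffs/tie_projs are not stored in the saved config, so they must be reconstructed to match the training run; a non-adaptive run and a hypothetical work directory are assumed here.

import os
import torch
from mem_transformer import MemTransformerLM

work_dir = 'LM-TFM'  # hypothetical path; matches the default --work_dir
checkpoint = torch.load(os.path.join(work_dir, 'model.pt'), map_location='cpu')
cfg = checkpoint['config']  # the argparse.Namespace saved as 'config'

# cutoffs / tie_projs are not saved in cfg; they must match training.
# Here a non-adaptive run (cutoffs=[], tie_projs=[False]) is assumed.
model = MemTransformerLM(cfg.n_token, cfg.n_layer, cfg.n_head, cfg.d_model,
                         cfg.d_head, cfg.d_inner, cfg.dropout, cfg.dropatt,
                         tie_weight=cfg.tied, d_embed=cfg.d_embed,
                         div_val=cfg.div_val, tie_projs=[False],
                         pre_lnorm=cfg.pre_lnorm, tgt_len=cfg.tgt_len,
                         ext_len=cfg.ext_len, mem_len=cfg.mem_len, cutoffs=[],
                         same_length=cfg.same_length, attn_type=cfg.attn_type,
                         clamp_len=cfg.clamp_len,
                         sample_softmax=cfg.sample_softmax)
model.load_state_dict(checkpoint['state_dict'])
model.eval()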
/code_for_LM/utils/adaptive_softmax.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | class AdaptiveLogSoftmax(nn.Module):
10 | def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
11 | super(AdaptiveLogSoftmax, self).__init__()
12 |
13 | cutoffs = list(cutoffs)
14 |
15 | if (cutoffs != sorted(cutoffs)) \
16 | or (min(cutoffs) <= 0) \
17 | or (max(cutoffs) >= (n_classes - 1)) \
18 | or (len(set(cutoffs)) != len(cutoffs)) \
19 | or any([int(c) != c for c in cutoffs]):
20 |
21 | raise ValueError("cutoffs should be a sequence of unique, positive "
22 | "integers sorted in an increasing order, where "
23 | "each value is between 1 and n_classes-1")
24 |
25 | self.in_features = in_features
26 | self.n_classes = n_classes
27 | self.cutoffs = cutoffs + [n_classes]
28 |
29 | self.shortlist_size = self.cutoffs[0]
30 | self.n_clusters = len(self.cutoffs) - 1
31 | self.head_size = self.shortlist_size + self.n_clusters
32 |
33 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
34 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
35 |
36 | self.keep_order = keep_order
37 |
38 |
39 | def forward(self, hidden, target, weight, bias, keep_order=False):
40 | if hidden.size(0) != target.size(0):
41 | raise RuntimeError('Input and target should have the same size '
42 | 'in the batch dimension.')
43 |
44 | head_weight = torch.cat(
45 | [weight[:self.shortlist_size], self.cluster_weight], dim=0)
46 | head_bias = torch.cat(
47 | [bias[:self.shortlist_size], self.cluster_bias], dim=0)
48 |
49 | head_logit = F.linear(hidden, head_weight, bias=head_bias)
50 | head_logprob = F.log_softmax(head_logit, dim=1)
51 |
52 | nll = torch.zeros_like(target,
53 | dtype=hidden.dtype, device=hidden.device)
54 |
55 | offset = 0
56 | cutoff_values = [0] + self.cutoffs
57 | for i in range(len(cutoff_values) - 1):
58 | l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]
59 |
60 | mask_i = (target >= l_idx) & (target < h_idx)
61 | indices_i = mask_i.nonzero().squeeze()
62 |
63 | if indices_i.numel() == 0:
64 | continue
65 |
66 | target_i = target.index_select(0, indices_i) - l_idx
67 | head_logprob_i = head_logprob.index_select(0, indices_i)
68 |
69 | if i == 0:
70 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
71 | else:
72 | weight_i = weight[l_idx:h_idx]
73 | bias_i = bias[l_idx:h_idx]
74 |
75 | hidden_i = hidden.index_select(0, indices_i)
76 |
77 | tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
78 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
79 |
80 |                 logprob_i = head_logprob_i[:, self.shortlist_size + i - 1] \
81 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
82 |
83 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
84 | nll.index_copy_(0, indices_i, -logprob_i)
85 | else:
86 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
87 |
88 | offset += logprob_i.size(0)
89 |
90 | return nll
91 |
--------------------------------------------------------------------------------
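A small worked illustration (made-up vocabulary size, not from the repository) of how the cutoffs partition the vocabulary for the adaptive softmax above: the head scores the shortlist plus one logit per tail cluster, and a tail token's log-probability combines the head's cluster log-prob with its in-cluster log-prob.

n_classes = 260000                      # hypothetical vocabulary size
cutoffs = [70000, 150000, 250000]       # values train.py passes when --adaptive is set
bounds = [0] + cutoffs + [n_classes]

shortlist_size = cutoffs[0]             # ids 0..69999 are scored directly in the head
n_clusters = len(cutoffs)               # three tail clusters
head_size = shortlist_size + n_clusters # head outputs shortlist logits + one logit per cluster

for i in range(len(bounds) - 1):
    print('bucket', i, 'covers ids', bounds[i], '..', bounds[i + 1] - 1)

# For a tail token t that falls in cluster i >= 1:
#   log p(t) = head_logprob[:, shortlist_size + i - 1] + tail_logprob_i[:, t - bounds[i]]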
/code_for_LM/utils/data_parallel.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.nn.parallel import DataParallel
3 | import torch
4 | from torch.nn.parallel._functions import Scatter
5 | from torch.nn.parallel.parallel_apply import parallel_apply
6 |
7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0):
8 | r"""
9 | Slices tensors into approximately equal chunks and
10 | distributes them across given GPUs. Duplicates
11 | references to objects that are not tensors.
12 | """
13 | def scatter_map(obj):
14 | if isinstance(obj, torch.Tensor):
15 | try:
16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
17 |             except Exception:
18 | print('obj', obj.size())
19 | print('dim', dim)
20 | print('chunk_sizes', chunk_sizes)
21 | quit()
22 | if isinstance(obj, tuple) and len(obj) > 0:
23 | return list(zip(*map(scatter_map, obj)))
24 | if isinstance(obj, list) and len(obj) > 0:
25 | return list(map(list, zip(*map(scatter_map, obj))))
26 | if isinstance(obj, dict) and len(obj) > 0:
27 | return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
28 |         return [obj for _ in target_gpus]
29 |
30 | # After scatter_map is called, a scatter_map cell will exist. This cell
31 | # has a reference to the actual function scatter_map, which has references
32 | # to a closure that has a reference to the scatter_map cell (because the
33 | # fn is recursive). To avoid this reference cycle, we set the function to
34 | # None, clearing the cell
35 | try:
36 | return scatter_map(inputs)
37 | finally:
38 | scatter_map = None
39 |
40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
41 | r"""Scatter with support for kwargs dictionary"""
42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
44 | if len(inputs) < len(kwargs):
45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
46 | elif len(kwargs) < len(inputs):
47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
48 | inputs = tuple(inputs)
49 | kwargs = tuple(kwargs)
50 | return inputs, kwargs
51 |
52 | class BalancedDataParallel(DataParallel):
53 | def __init__(self, gpu0_bsz, *args, **kwargs):
54 | self.gpu0_bsz = gpu0_bsz
55 | super().__init__(*args, **kwargs)
56 |
57 | def forward(self, *inputs, **kwargs):
58 | if not self.device_ids:
59 | return self.module(*inputs, **kwargs)
60 | if self.gpu0_bsz == 0:
61 | device_ids = self.device_ids[1:]
62 | else:
63 | device_ids = self.device_ids
64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
65 | if len(self.device_ids) == 1:
66 | return self.module(*inputs[0], **kwargs[0])
67 | replicas = self.replicate(self.module, self.device_ids)
68 | if self.gpu0_bsz == 0:
69 | replicas = replicas[1:]
70 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
71 | return self.gather(outputs, self.output_device)
72 |
73 | def parallel_apply(self, replicas, device_ids, inputs, kwargs):
74 | return parallel_apply(replicas, inputs, kwargs, device_ids)
75 |
76 | def scatter(self, inputs, kwargs, device_ids):
77 | bsz = inputs[0].size(self.dim)
78 | num_dev = len(self.device_ids)
79 | gpu0_bsz = self.gpu0_bsz
80 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
81 | if gpu0_bsz < bsz_unit:
82 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
83 | delta = bsz - sum(chunk_sizes)
84 | for i in range(delta):
85 | chunk_sizes[i + 1] += 1
86 | if gpu0_bsz == 0:
87 | chunk_sizes = chunk_sizes[1:]
88 | else:
89 | return super().scatter(inputs, kwargs, device_ids)
90 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
91 |
92 |
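# --- Editor's usage sketch (not part of the original file) -------------------------------
# BalancedDataParallel behaves like torch.nn.DataParallel except that GPU 0 receives only
# `gpu0_bsz` samples per step, leaving it memory headroom for the gathered outputs and the
# optimizer state.  The toy nn.Linear model and the batch sizes below are assumptions made
# purely for illustration; they are not taken from this repository's train.py.
if __name__ == '__main__':
    import torch.nn as nn

    if torch.cuda.device_count() >= 2:                       # the class assumes >= 2 GPUs
        model = nn.Linear(128, 10).cuda()
        para_model = BalancedDataParallel(8, model, dim=0)   # GPU 0 gets 8 of the 64 samples
        x = torch.randn(64, 128).cuda()
        print(para_model(x).shape)                           # torch.Size([64, 10])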
--------------------------------------------------------------------------------
/code_for_LM/utils/exp_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os, shutil
3 |
4 | import numpy as np
5 |
6 | import torch
7 |
8 |
9 | def logging(s, log_path, print_=True, log_=True):
10 | if print_:
11 | print(s)
12 | if log_:
13 | with open(log_path, 'a+') as f_log:
14 | f_log.write(s + '\n')
15 |
16 | def get_logger(log_path, **kwargs):
17 | return functools.partial(logging, log_path=log_path, **kwargs)
18 |
19 |
20 | # 许海明
21 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
22 | if debug:
23 | print('Debug Mode : no experiment dir created')
24 | return functools.partial(logging, log_path=None, log_=False)
25 |
26 | if not os.path.exists(dir_path):
27 | os.makedirs(dir_path)
28 |
29 | print('Experiment dir : {}'.format(dir_path))
30 | if scripts_to_save is not None:
31 | script_path = os.path.join(dir_path, 'scripts')
32 | if not os.path.exists(script_path):
33 | os.makedirs(script_path)
34 | for script in scripts_to_save:
35 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
36 | shutil.copyfile(script, dst_file)
37 |
38 | return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
39 |
40 | def save_checkpoint(model, optimizer, path, epoch):
41 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
42 | torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
43 |
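# --- Editor's usage sketch (not part of the original file) -------------------------------
# create_exp_dir() prepares a working directory and returns a logging function bound to
# <dir>/log.txt; save_checkpoint() writes model_<epoch>.pt / optimizer_<epoch>.pt into it.
# The directory name and the toy model below are placeholders chosen for illustration.
if __name__ == '__main__':
    work_dir = './tmp_exp'
    logger = create_exp_dir(work_dir)        # prints "Experiment dir : ./tmp_exp"
    logger('experiment started')             # printed and appended to ./tmp_exp/log.txt

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    save_checkpoint(model, optimizer, work_dir, epoch=0)   # -> model_0.pt, optimizer_0.pt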
--------------------------------------------------------------------------------
/code_for_LM/utils/log_uniform_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 | class LogUniformSampler(object):
6 | def __init__(self, range_max, n_sample):
7 | """
8 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
9 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
10 |
11 | expected count can be approximated by 1 - (1 - p)^n
12 | and we use a numerically stable version -expm1(num_tries * log1p(-p))
13 |
14 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
15 | """
16 | with torch.no_grad():
17 | self.range_max = range_max
18 | log_indices = torch.arange(1., range_max+2., 1.).log_()
19 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
20 | # print('P', self.dist.numpy().tolist()[-30:])
21 |
22 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
23 |
24 | self.n_sample = n_sample
25 |
26 | def sample(self, labels):
27 | """
28 | labels: [b1, b2]
29 | Return
30 | true_log_probs: [b1, b2]
31 | samp_log_probs: [n_sample]
32 | neg_samples: [n_sample]
33 | """
34 |
35 | # neg_samples = torch.empty(0).long()
36 | n_sample = self.n_sample
37 | n_tries = 2 * n_sample
38 |
39 | with torch.no_grad():
40 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
41 | device = labels.device
42 | neg_samples = neg_samples.to(device)
43 | true_log_probs = self.log_q[labels].to(device)
44 | samp_log_probs = self.log_q[neg_samples].to(device)
45 | return true_log_probs, samp_log_probs, neg_samples
46 |
47 | def sample_logits(embedding, bias, labels, inputs, sampler):
48 | """
49 | embedding: an nn.Embedding layer
50 | bias: [n_vocab]
51 | labels: [b1, b2]
52 | inputs: [b1, b2, n_emb]
53 | sampler: you may use a LogUniformSampler
54 | Return
55 | logits: [b1, b2, 1 + n_sample]
56 | """
57 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
58 | n_sample = neg_samples.size(0)
59 | b1, b2 = labels.size(0), labels.size(1)
60 | all_ids = torch.cat([labels.view(-1), neg_samples])
61 | all_w = embedding(all_ids)
62 | true_w = all_w[: -n_sample].view(b1, b2, -1)
63 | sample_w = all_w[- n_sample:].view(n_sample, -1)
64 |
65 | all_b = bias[all_ids]
66 | true_b = all_b[: -n_sample].view(b1, b2)
67 | sample_b = all_b[- n_sample:]
68 |
69 | hit = (labels[:, :, None] == neg_samples).detach()
70 |
71 | true_logits = torch.einsum('ijk,ijk->ij',
72 | [true_w, inputs]) + true_b - true_log_probs
73 | sample_logits = torch.einsum('lk,ijk->ijl',
74 | [sample_w, inputs]) + sample_b - samp_log_probs
75 | sample_logits.masked_fill_(hit, -1e30)
76 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
77 |
78 | return logits
79 |
80 |
81 | # class LogUniformSampler(object):
82 | # def __init__(self, range_max, unique=False):
83 | # """
84 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
85 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
86 | # """
87 | # self.range_max = range_max
88 | # log_indices = torch.arange(1., range_max+2., 1.).log_()
89 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
90 |
91 | # self.unique = unique
92 |
93 | # if self.unique:
94 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
95 |
96 | # def sample(self, n_sample, labels):
97 | # pos_sample, new_labels = labels.unique(return_inverse=True)
98 | # n_pos_sample = pos_sample.size(0)
99 | # n_neg_sample = n_sample - n_pos_sample
100 |
101 | # if self.unique:
102 | # self.exclude_mask.index_fill_(0, pos_sample, 1)
103 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
104 | # self.exclude_mask.index_fill_(0, pos_sample, 0)
105 | # else:
106 | # sample_dist = self.dist
107 |
108 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample)
109 |
110 | # sample = torch.cat([pos_sample, neg_sample])
111 | # sample_prob = self.dist[sample]
112 |
113 | # return new_labels, sample, sample_prob
114 |
115 |
116 | if __name__ == '__main__':
117 | S, B = 3, 4
118 | n_vocab = 10000
119 | n_sample = 5
120 | H = 32
121 |
122 | labels = torch.LongTensor(S, B).random_(0, n_vocab)
123 |
124 | # sampler = LogUniformSampler(n_vocab, unique=False)
125 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
126 |
127 |     sampler = LogUniformSampler(n_vocab, n_sample)
128 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
129 |
130 | # print('true_probs', true_probs.numpy().tolist())
131 | # print('samp_probs', samp_probs.numpy().tolist())
132 | # print('neg_samples', neg_samples.numpy().tolist())
133 |
134 | # print('sum', torch.sum(sampler.dist).item())
135 |
136 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
137 |
138 | embedding = nn.Embedding(n_vocab, H)
139 | bias = torch.zeros(n_vocab)
140 | inputs = torch.Tensor(S, B, H).normal_()
141 |
142 |     logits = sample_logits(embedding, bias, labels, inputs, sampler)
143 |     print('logits', logits.detach().numpy().tolist())
144 |     print('logits shape', logits.size())  # [S, B, 1 + number of unique negative samples]
147 |
148 |
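# --- Editor's note (sketch, not part of the original file) -------------------------------
# The distribution built in LogUniformSampler.__init__ implements the docstring formula
#     P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1),
# which decays roughly like 1/class and telescopes to a total probability of exactly 1.
# The numbers below are arbitrary and only serve to check that relationship.
if __name__ == '__main__':
    import math
    range_max, k = 10000, 7
    s = LogUniformSampler(range_max, n_sample=5)
    expected = (math.log(k + 2) - math.log(k + 1)) / math.log(range_max + 1)
    assert abs(s.dist[k].item() - expected) < 1e-6
    assert abs(s.dist.sum().item() - 1.0) < 1e-4
    print('log-uniform distribution checks out')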
--------------------------------------------------------------------------------
/code_for_LM/utils/proj_adaptive_softmax.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
10 | CUDA_MINOR = int(torch.version.cuda.split('.')[1])
11 |
12 | class ProjectedAdaptiveLogSoftmax(nn.Module):
13 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
14 | keep_order=False):
15 | super(ProjectedAdaptiveLogSoftmax, self).__init__()
16 |
17 | self.n_token = n_token
18 | self.d_embed = d_embed
19 | self.d_proj = d_proj
20 |
21 | self.cutoffs = cutoffs + [n_token]
22 | self.cutoff_ends = [0] + self.cutoffs
23 | self.div_val = div_val
24 |
25 | self.shortlist_size = self.cutoffs[0]
26 | self.n_clusters = len(self.cutoffs) - 1
27 | self.head_size = self.shortlist_size + self.n_clusters
28 |
29 | if self.n_clusters > 0:
30 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
31 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
32 |
33 | self.out_layers = nn.ModuleList()
34 | self.out_projs = nn.ParameterList()
35 |
36 | if div_val == 1:
37 | for i in range(len(self.cutoffs)):
38 | if d_proj != d_embed:
39 | self.out_projs.append(
40 | nn.Parameter(torch.Tensor(d_proj, d_embed))
41 | )
42 | else:
43 | self.out_projs.append(None)
44 |
45 | self.out_layers.append(nn.Linear(d_embed, n_token))
46 | else:
47 | for i in range(len(self.cutoffs)):
48 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
49 | d_emb_i = d_embed // (div_val ** i)
50 |
51 | self.out_projs.append(
52 | nn.Parameter(torch.Tensor(d_proj, d_emb_i))
53 | )
54 |
55 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
56 |
57 | self.keep_order = keep_order
58 |
59 | def _compute_logit(self, hidden, weight, bias, proj):
60 | if proj is None:
61 | logit = F.linear(hidden, weight, bias=bias)
62 | else:
63 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
64 | proj_hid = F.linear(hidden, proj.t().contiguous())
65 | logit = F.linear(proj_hid, weight, bias=bias)
66 | # else:
67 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
68 | # if bias is not None:
69 | # logit = logit + bias
70 |
71 | return logit
72 |
73 | def forward(self, hidden, target, keep_order=False):
74 | '''
75 | hidden :: [len*bsz x d_proj]
76 | target :: [len*bsz]
77 | '''
78 |
79 | if hidden.size(0) != target.size(0):
80 | raise RuntimeError('Input and target should have the same size '
81 | 'in the batch dimension.')
82 |
83 | if self.n_clusters == 0:
84 | logit = self._compute_logit(hidden, self.out_layers[0].weight,
85 | self.out_layers[0].bias, self.out_projs[0])
86 | nll = -F.log_softmax(logit, dim=-1) \
87 | .gather(1, target.unsqueeze(1)).squeeze(1)
88 | else:
89 | # construct weights and biases
90 | weights, biases = [], []
91 | for i in range(len(self.cutoffs)):
92 | if self.div_val == 1:
93 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
94 | weight_i = self.out_layers[0].weight[l_idx:r_idx]
95 | bias_i = self.out_layers[0].bias[l_idx:r_idx]
96 | else:
97 | weight_i = self.out_layers[i].weight
98 | bias_i = self.out_layers[i].bias
99 |
100 | if i == 0:
101 | weight_i = torch.cat(
102 | [weight_i, self.cluster_weight], dim=0)
103 | bias_i = torch.cat(
104 | [bias_i, self.cluster_bias], dim=0)
105 |
106 | weights.append(weight_i)
107 | biases.append(bias_i)
108 |
109 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
110 |
111 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
112 | head_logprob = F.log_softmax(head_logit, dim=1)
113 |
114 | nll = torch.zeros_like(target,
115 | dtype=hidden.dtype, device=hidden.device)
116 |
117 | offset = 0
118 | cutoff_values = [0] + self.cutoffs
119 | for i in range(len(cutoff_values) - 1):
120 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
121 |
122 | mask_i = (target >= l_idx) & (target < r_idx)
123 | indices_i = mask_i.nonzero().squeeze()
124 |
125 | if indices_i.numel() == 0:
126 | continue
127 |
128 | target_i = target.index_select(0, indices_i) - l_idx
129 | head_logprob_i = head_logprob.index_select(0, indices_i)
130 |
131 | if i == 0:
132 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
133 | else:
134 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
135 |
136 | hidden_i = hidden.index_select(0, indices_i)
137 |
138 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
139 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
140 |
141 |                     # cluster i's log-probability sits at head index shortlist_size + i - 1 (clusters are appended after the shortlist)
142 |                     logprob_i = head_logprob_i[:, self.shortlist_size + i - 1] + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
143 |
144 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
145 | nll.index_copy_(0, indices_i, -logprob_i)
146 | else:
147 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
148 |
149 | offset += logprob_i.size(0)
150 |
151 | return nll
152 |
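# --- Editor's usage sketch (not part of the original file) -------------------------------
# With div_val=1 and d_proj == d_embed the module creates no projection matrices, so the
# nn.Linear default initialisation is enough to run it directly.  The vocabulary size,
# cutoffs and batch below are made up for illustration; the real values come from the
# training script's arguments.  (The module-level CUDA_MAJOR/CUDA_MINOR lines above assume
# a CUDA build of torch.)
if __name__ == '__main__':
    n_token, d_embed, d_proj = 10000, 64, 64
    crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj, cutoffs=[1000, 4000], div_val=1)

    # draw targets from every partition so the head and both tail clusters are exercised
    target = torch.cat([torch.randint(0, 1000, (12,)),
                        torch.randint(1000, 4000, (10,)),
                        torch.randint(4000, 10000, (10,))])
    hidden = torch.randn(target.size(0), d_proj)   # [len*bsz, d_proj]
    nll = crit(hidden, target)                     # per-token negative log-likelihood
    print(nll.shape)                               # torch.Size([32])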
--------------------------------------------------------------------------------
/code_for_LM/utils/vocabulary.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import Counter, OrderedDict
3 |
4 | import torch
5 |
6 |
7 |
8 | # Vocabulary class
9 | class Vocab(object):
10 | def __init__(self, alinlen=3000,special=[], min_freq=10, max_size=100000, lower_case=True,
11 | delimiter=None, vocab_file=None):
12 | '''
13 |         :param alinlen: fixed article length; tokenized lines are padded/truncated to this many symbols
14 |         :param special: special symbols inserted at the front of the vocabulary
15 |         :param min_freq: minimum count for a symbol to enter the vocabulary
16 |         :param max_size: maximum vocabulary size
17 |         :param lower_case: lower-case every line before tokenizing
18 |         :param delimiter: delimiter passed to str.split (None splits on whitespace)
19 |         :param vocab_file: optional existing vocabulary file to load instead of counting
20 | '''
21 | self.counter = Counter()
22 | self.special = special
23 | self.min_freq = min_freq
24 | self.max_size = max_size
25 | self.lower_case = lower_case
26 | self.delimiter = delimiter
27 | self.vocab_file = vocab_file
28 | self.alinlen = alinlen
29 |
30 |
31 | # 许海明
32 | def tokenize(self, line):
33 | line = line.strip()
34 | # convert to lower case
35 | if self.lower_case:
36 | line = line.lower()
37 |
38 | # empty delimiter '' will evaluate False
39 | if self.delimiter == '':
40 | symbols = line
41 | else:
42 | symbols = line.split(self.delimiter)
43 |
44 |
45 | if len(symbols) > self.alinlen-2:
46 | symbols = symbols[:self.alinlen-2]
47 |
48 |         symbols = ['<s>'] + symbols + ['<eos>']  # NOTE: the special-token literals were stripped in export; '<s>'/'<eos>' are reconstructed
49 |
50 | len_pre = len(symbols)
51 | if len_pre == self.alinlen:
52 | return symbols
53 | else:
54 |             assert len_pre < self.alinlen
55 |             new_symbols = ['<pad>'] * (self.alinlen - len_pre)  # '<pad>' literal likewise reconstructed (lost in export)
56 | new_symbols.extend(symbols)
57 | assert len(new_symbols)==self.alinlen
58 | return new_symbols
59 |
60 |
61 |
62 |     '''
63 |     Count the tokens in a file and return the tokenized data,
64 |     one tokenized document per line.
65 |     '''
66 | def count_file(self, path, verbose=False):
67 | if verbose:
68 | print('counting file {} ...'.format(path))
69 | assert os.path.exists(path)
70 |
71 | sents = []
72 | with open(path, 'r', encoding='utf-8') as f:
73 | for idx, line in enumerate(f):
74 | line = line.strip()
75 | if line is None or line=="":
76 | continue
77 | if verbose and idx > 0 and idx % 500000 == 0:
78 | print(' line {}'.format(idx))
79 | symbols = self.tokenize(line)
80 | self.counter.update(symbols)
81 | sents.append(symbols)
82 |
83 | return sents
84 |
85 | def count_sents(self, sents, verbose=False):
86 | """
87 | sents : a list of sentences, each a list of tokenized symbols
88 | """
89 | if verbose: print('counting {} sents ...'.format(len(sents)))
90 | for idx, symbols in enumerate(sents):
91 | if verbose and idx > 0 and idx % 500000 == 0:
92 | print(' line {}'.format(idx))
93 | self.counter.update(symbols)
94 |
95 | # 许海明
96 | def _build_from_file(self, vocab_file):
97 | self.idx2sym = []
98 | self.sym2idx = OrderedDict()
99 |
100 | with open(vocab_file, 'r', encoding='utf-8') as f:
101 | for line in f:
102 | symb = line.strip().split()[0]
103 | self.add_symbol(symb)
104 |         self.unk_idx = self.sym2idx['<unk>']  # token literal lost in export; '<unk>' assumed
105 |
106 |
107 | # 许海明
108 | def build_vocab(self):
109 | if self.vocab_file:
110 | print('building vocab from {}'.format(self.vocab_file))
111 | self._build_from_file(self.vocab_file)
112 | print('final vocab size {}'.format(len(self)))
113 | else:
114 | print('building vocab with min_freq={}, max_size={}'.format(
115 | self.min_freq, self.max_size))
116 | self.idx2sym = []
117 | self.sym2idx = OrderedDict()
118 |
119 | for sym in self.special:
120 | self.add_special(sym)
121 |
122 | for sym, cnt in self.counter.most_common(self.max_size):
123 | if cnt < self.min_freq:
124 | break
125 | self.add_symbol(sym)
126 |
127 | print('final vocab size {} from {} unique tokens'.format(
128 | len(self), len(self.counter)))
129 |
130 |
131 |     # 许海明 — encode an entire file (one document per line)
132 | '''
133 | tensor([1, 2, 3, 1, 2, 3, 1, 2, 3])
134 |
135 | a=torch.LongTensor([1,2,3])
136 | a1=torch.LongTensor([1,2,3])
137 | a2=torch.LongTensor([1,2,3])
138 |
139 | print(torch.cat([a,a1,a2]))
140 | '''
141 | def encode_file(self, path, ordered=False, verbose=False):
142 | '''
143 | ordered=True
144 | :param path:
145 | :param ordered:
146 | :param verbose:
147 | :param add_eos:
148 | :param add_double_eos:
149 | :return:
150 | '''
151 | if verbose:
152 | print('encoding file {} ...'.format(path))
153 | assert os.path.exists(path)
154 | encoded = []
155 | with open(path, 'r', encoding='utf-8') as f:
156 | for idx, line in enumerate(f):
157 | line = line.strip()
158 | if line is None or line == "":
159 | continue
160 | if verbose and idx > 0 and idx % 500000 == 0:
161 | print(' line {}'.format(idx))
162 |                 symbols = self.tokenize(line)  # each line gets boundary symbols added and is padded to alinlen
163 |
164 | encoded.append(self.convert_to_tensor(symbols))
165 |
166 | if ordered:
167 | encoded = torch.cat(encoded)
168 |
169 | return encoded
170 |
171 |     # 许海明 — encode a list of already-tokenized sentences
172 | def encode_sents(self, sents, ordered=False, verbose=False):
173 | if verbose:
174 | print('encoding {} sents ...'.format(len(sents)))
175 | encoded = []
176 | for idx, symbols in enumerate(sents):
177 | if verbose and idx > 0 and idx % 500000 == 0:
178 | print(' line {}'.format(idx))
179 | encoded.append(self.convert_to_tensor(symbols))
180 |
181 | if ordered:
182 | encoded = torch.cat(encoded)
183 |
184 | return encoded
185 |
186 | # 许海明
187 | def add_special(self, sym):
188 | if sym not in self.sym2idx:
189 | self.idx2sym.append(sym)
190 | self.sym2idx[sym] = len(self.idx2sym) - 1
191 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
192 |
193 | # 许海明
194 | def add_symbol(self, sym):
195 | if sym not in self.sym2idx:
196 | self.idx2sym.append(sym)
197 | self.sym2idx[sym] = len(self.idx2sym) - 1
198 |
199 |
200 | def save_symbol(self, vocabfile):
201 | with open(vocabfile, mode="w", encoding="utf-8") as fw:
202 | for word in self.sym2idx:
203 | fw.write(word+"\n")
204 |
205 |
206 |
207 |
208 |
209 | def get_sym(self, idx):
210 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
211 | return self.idx2sym[idx]
212 |
213 | def get_idx(self, sym):
214 | if sym in self.sym2idx:
215 | return self.sym2idx[sym]
216 | else:
217 | # print('encounter unk {}'.format(sym))
218 |             assert '<eos>' not in sym  # token literal lost in export; '<eos>' assumed (as in the upstream Transformer-XL vocabulary)
219 | assert hasattr(self, 'unk_idx')
220 | return self.sym2idx.get(sym, self.unk_idx)
221 |
222 | def get_symbols(self, indices):
223 | return [self.get_sym(idx) for idx in indices]
224 |
225 | def get_indices(self, symbols):
226 | return [self.get_idx(sym) for sym in symbols]
227 |
228 | # 许海明
229 | def convert_to_tensor(self, symbols):
230 | return torch.LongTensor(self.get_indices(symbols))
231 |
232 | def convert_to_sent(self, indices, exclude=None):
233 | if exclude is None:
234 | return ' '.join([self.get_sym(idx) for idx in indices])
235 | else:
236 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
237 |
238 | def __len__(self):
239 | return len(self.idx2sym)
240 |
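# --- Editor's usage sketch (not part of the original file) -------------------------------
# Builds a tiny vocabulary from two in-memory lines and encodes one of them into a
# fixed-length id tensor.  The special-token literals ('<pad>', '<s>', '<eos>', '<unk>')
# follow the editor's reconstruction noted in tokenize()/_build_from_file(); the upstream
# repository may use different literals.
if __name__ == '__main__':
    vocab = Vocab(alinlen=8, special=['<pad>', '<s>', '<eos>', '<unk>'], min_freq=1)

    docs = ['今天 天气 很 好', 'transformer xl 处理 长 文本']   # whitespace-tokenized lines
    for line in docs:
        vocab.counter.update(vocab.tokenize(line))
    vocab.build_vocab()

    ids = vocab.convert_to_tensor(vocab.tokenize(docs[0]))
    print(len(vocab), ids.size())                 # vocabulary size, torch.Size([8])
    print(vocab.convert_to_sent(ids.tolist()))    # '<pad> <pad> <s> 今天 天气 很 好 <eos>'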
--------------------------------------------------------------------------------
/code_for_LM/运行命令:
--------------------------------------------------------------------------------
1 | # Pre-training command for the language model
2 |
3 | python train.py \
4 | --cuda \
5 | --data ../data/LM/ \
6 | --work_dir ../results/LM/ \
7 | --dataset dccl \
8 | --n_layer 6 \
9 | --d_model 410 \
10 | --n_head 10 \
11 | --d_head 41 \
12 | --tgt_len 150 \
13 | --mem_len 150 \
14 | --eval_tgt_len 150 \
15 | --d_inner 2100 \
16 | --dropout 0.1 \
17 | --dropatt 0.0 \
18 | --optim adam \
19 | --lr 0.00025 \
20 | --warmup_step 0 \
21 | --max_step 200000 \
22 | --batch_size 32 \
23 |         --multi_gpu
24 |
25 |
--------------------------------------------------------------------------------
/image/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuhaiming1996/Transformer-xl-classify/7696d0b328aac4665d851aac5f84e49000919aae/image/model.jpg
--------------------------------------------------------------------------------
/说明:
--------------------------------------------------------------------------------
1 | This is an example of training on a multi-GPU machine; the code is well worth studying.
--------------------------------------------------------------------------------