├── README.md ├── losses └── mtloss.py ├── modules ├── state.py ├── feature_map.py ├── encoder.py ├── linear_attention.py ├── cross_attention.py ├── self_attention.py ├── causal_linear_attention.py └── decoder.py ├── models ├── classification.py ├── classfication_vanilla.py └── seq2seq_model.py ├── LinearKernel └── LinearKernel │ └── causal_product │ ├── __init__.py │ ├── causal_product_cpu.cpp │ └── causal_product_cuda.cu ├── data ├── data_loader.py └── iwslt17_loader.py ├── metrics └── bleu.py ├── train.py ├── train_vanilla.py ├── test └── test_transformer.py └── train_seq2seq.py /README.md: -------------------------------------------------------------------------------- 1 | # Linear-Transformer 2 | Transformer are RNNs: Fast Autoregressive Transformer with Linear Attention 3 | -------------------------------------------------------------------------------- /losses/mtloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from fastNLP import LossBase 4 | from fastNLP import seq_len_to_mask 5 | 6 | class MTLoss(LossBase): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def get_loss(self, pred, tgt_tokens, tgt_seq_len): 11 | tgt_seq_len = tgt_seq_len - 1 12 | mask = seq_len_to_mask(tgt_seq_len).eq(0) 13 | tgt_tokens = tgt_tokens[:, 1:].masked_fill(mask, -100) 14 | loss = F.cross_entropy(target=tgt_tokens, input=pred[:, :-1].transpose(1, 2)) 15 | return loss 16 | -------------------------------------------------------------------------------- /modules/state.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fastNLP.modules.decoder.seq2seq_state import TransformerState, State 3 | 4 | class LinearTransformerState(State): 5 | def __init__(self, encoder_output, encoder_mask, num_decoder_layer): 6 | """ 7 | 与TransformerSeq2SeqDecoder对应的State, 8 | 9 | :param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, encoder的输出 10 | :param torch.ByteTensor encoder_mask: bsz x encode_max_len 为1的地方需要attend 11 | :param int num_decoder_layer: decode有多少层 12 | """ 13 | super().__init__(encoder_output, encoder_mask) 14 | self.encoder_key = [None] * num_decoder_layer 15 | self.encoder_value = [None] * num_decoder_layer 16 | self.decoder_k_sum = [0] * num_decoder_layer 17 | self.decoder_kv_sum = [0] * num_decoder_layer 18 | self.decode_length = 0 19 | 20 | def reorder_state(self, indices: torch.LongTensor): 21 | super().reorder_state(indices) 22 | self.encoder_key = self._reorder_state(self.encoder_key, indices) 23 | self.encoder_value = self._reorder_state(self.encoder_value, indices) 24 | self.decoder_k_sum = self._reorder_state(self.decoder_k_sum, indices) 25 | self.decoder_kv_sum = self._reorder_state(self.decoder_kv_sum, indices) 26 | -------------------------------------------------------------------------------- /models/classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Union, Tuple 4 | import torch.nn.functional as F 5 | from fastNLP.embeddings import StaticEmbedding 6 | from fastNLP.core.utils import seq_len_to_mask 7 | from modules.encoder import LinearTransformerEncoder 8 | from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder 9 | 10 | 11 | class TextClassification(nn.Module): 12 | def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed=None, 13 | 
num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1, class_num=2): 14 | super(TextClassification, self).__init__() 15 | self.encoder = LinearTransformerEncoder(embed=embed, 16 | pos_embed=pos_embed, 17 | num_layers=num_layers, 18 | d_model=d_model, 19 | n_head=n_head, 20 | dim_ff=dim_ff, 21 | dropout=dropout) 22 | 23 | self.linear = nn.Linear(d_model, class_num) 24 | 25 | def forward(self, words, seq_len): 26 | x, _ = self.encoder(words, seq_len) 27 | feats, _ = torch.max(x, dim=1) 28 | logits = self.linear(feats) 29 | return {'pred': logits} -------------------------------------------------------------------------------- /models/classfication_vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Union, Tuple 4 | import torch.nn.functional as F 5 | from fastNLP.embeddings import StaticEmbedding 6 | from fastNLP.core.utils import seq_len_to_mask 7 | from modules.encoder import LinearTransformerEncoder 8 | from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder 9 | 10 | 11 | class TextClassification(nn.Module): 12 | def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed=None, 13 | num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1, class_num=2): 14 | super(TextClassification, self).__init__() 15 | self.encoder = TransformerSeq2SeqEncoder(embed=embed, 16 | pos_embed=pos_embed, 17 | num_layers=num_layers, 18 | d_model=d_model, 19 | n_head=n_head, 20 | dim_ff=dim_ff, 21 | dropout=dropout) 22 | 23 | self.linear = nn.Linear(d_model, class_num) 24 | 25 | def forward(self, words, seq_len): 26 | x, _ = self.encoder(words, seq_len) 27 | feats, _ = torch.max(x, dim=1) 28 | logits = self.linear(feats) 29 | return {'pred': logits} -------------------------------------------------------------------------------- /modules/feature_map.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch.nn import Module 5 | 6 | 7 | class FeatureMap(Module): 8 | """Define the FeatureMap interface.""" 9 | def __init__(self, query_dims): 10 | super().__init__() 11 | self.query_dims = query_dims 12 | 13 | def new_feature_map(self): 14 | """Create a new instance of this feature map. In particular, if it is a 15 | random feature map sample new parameters.""" 16 | raise NotImplementedError() 17 | 18 | def forward_queries(self, x): 19 | """Encode the queries `x` using this feature map.""" 20 | return self(x) 21 | 22 | def forward_keys(self, x): 23 | """Encode the keys `x` using this feature map.""" 24 | return self(x) 25 | 26 | def forward(self, x): 27 | """Encode x using this feature map. For symmetric feature maps it 28 | suffices to define this function, but for asymmetric feature maps one 29 | needs to define the `forward_queries` and `forward_keys` functions.""" 30 | raise NotImplementedError() 31 | 32 | @classmethod 33 | def factory(cls, *args, **kwargs): 34 | """Return a function that when called with the query dimensions returns 35 | an instance of this feature map. 36 | It is inherited by the subclasses so it is available in all feature 37 | maps. 
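        For example, `elu_feature_map` at the bottom of this file is built as
        `ActivationFunctionFeatureMap.factory(lambda x: elu(x) + 1)` and is later called
        with the query dimensionality to instantiate the module.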
38 | """ 39 | def inner(query_dims): 40 | return cls(query_dims, *args, **kwargs) 41 | return inner 42 | 43 | 44 | class ActivationFunctionFeatureMap(FeatureMap): 45 | """Define a feature map that is simply an element-wise activation 46 | function.""" 47 | def __init__(self, query_dims, activation_function): 48 | super().__init__(query_dims) 49 | self.activation_function = activation_function 50 | 51 | def new_feature_map(self): 52 | return 53 | 54 | def forward(self, x): 55 | return self.activation_function(x) 56 | 57 | 58 | elu_feature_map = ActivationFunctionFeatureMap.factory( 59 | lambda x: torch.nn.functional.elu(x) + 1 60 | ) -------------------------------------------------------------------------------- /LinearKernel/LinearKernel/causal_product/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .causal_product_cpu import causal_dot_product as causal_dot_product_cpu, \ 4 | causal_dot_backward as causal_dot_backward_cpu 5 | 6 | try: 7 | from .causal_product_cuda import \ 8 | causal_dot_product as causal_dot_product_cuda, \ 9 | causal_dot_backward as causal_dot_backward_cuda 10 | except ImportError: 11 | causal_dot_product_cuda = causal_dot_backward_cuda = None 12 | 13 | 14 | class CausalDotProduct(torch.autograd.Function): 15 | """Compute the weighted sum of values but attending only to previous 16 | values.""" 17 | dot = { 18 | "cpu": causal_dot_product_cpu, 19 | "cuda": causal_dot_product_cuda 20 | } 21 | dot_backward = { 22 | "cpu": causal_dot_backward_cpu, 23 | "cuda": causal_dot_backward_cuda 24 | } 25 | 26 | @staticmethod 27 | def forward(ctx, Q, K, V): 28 | # Save the inputs for the gradient computation 29 | ctx.save_for_backward(Q, K, V) 30 | 31 | # Create the output tensor 32 | device = Q.device 33 | N, H, L, _ = Q.shape 34 | _, _, _, M = V.shape 35 | product = torch.zeros((N, H, L, M), device=device) 36 | 37 | # Actually perform the dot product 38 | CausalDotProduct.dot[device.type]( 39 | Q.data, 40 | K.data, 41 | V.data, 42 | product 43 | ) 44 | 45 | return product 46 | 47 | @staticmethod 48 | def backward(ctx, grad_out): 49 | # Extract the saved tensors 50 | Q, K, V = ctx.saved_tensors 51 | 52 | # Allocate memory for the gradients 53 | grad_Q = torch.zeros_like(Q) 54 | grad_K = torch.zeros_like(K) 55 | grad_V = torch.zeros_like(V) 56 | 57 | # Actually compute the gradients 58 | CausalDotProduct.dot_backward[Q.device.type]( 59 | Q.data, 60 | K.data, 61 | V.data, 62 | grad_out, 63 | grad_Q, 64 | grad_K, 65 | grad_V 66 | ) 67 | 68 | return grad_Q, grad_K, grad_V 69 | 70 | 71 | # Alias the autograd functions to python style snake case naming 72 | causal_dot_product = CausalDotProduct.apply -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from numpy import * 4 | import fastNLP 5 | from fastNLP import DataSet 6 | from nltk.corpus import stopwords 7 | from fastNLP import Vocabulary 8 | from fastNLP.io import IMDBLoader, IMDBPipe, DataBundle 9 | from fastNLP.io import SST2Loader, SST2Pipe 10 | 11 | stop_words = stopwords.words('english') 12 | 13 | def load_data(data_path): 14 | if os.path.exists(data_path): 15 | print("===已经存在处理好的数据,正在加载===") 16 | data_bundle = torch.load(data_path) 17 | else: 18 | print("===正在处理数据并保存到指定文件===") 19 | loader = SST2Loader() 20 | data_dir = loader.download() 21 | data_bundle = loader.load(data_dir) 22 | 
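        # SST2Pipe is expected to lower-case and tokenize the raw text, build the `words`
        # and `target` vocabularies and add a `seq_len` field; these are the fields the
        # classification scripts below rely on.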
data_bundle = SST2Pipe(lower=True).process(data_bundle) 23 | # loader = IMDBLoader() 24 | # data_dir = loader.download() 25 | # data_bundle = loader.load(data_dir) 26 | # data_bundle = IMDBPipe(lower=True).process(data_bundle) 27 | torch.save(data_bundle, data_path) 28 | # data_bundle.get_dataset('train').drop(lambda ins: len(ins['words']) > 1000) 29 | # data_bundle.get_dataset('test').drop(lambda ins: len(ins['words']) > 1000) 30 | return data_bundle 31 | 32 | if __name__ == '__main__': 33 | data_path = './imdb_data_bundle.pt' 34 | data_bundle = load_data(data_path) 35 | print(data_bundle) 36 | max_length_train = max([seq_len for seq_len in data_bundle.get_dataset('train')['seq_len']]) 37 | # max_length_dev = max([seq_len for seq_len in data_bundle.get_dataset('dev')['seq_len']]) 38 | max_length_test = max([seq_len for seq_len in data_bundle.get_dataset('test')['seq_len']]) 39 | max_length = max(max_length_train, max_length_test) 40 | print("数据集样本最大长度max_length =", max_length) 41 | 42 | len_train = [seq_len for seq_len in data_bundle.get_dataset('train')['seq_len']] 43 | mean_len_train = mean(len_train) 44 | # len_dev = [seq_len for seq_len in data_bundle.get_dataset('dev')['seq_len']] 45 | # mean_len_dev = mean(len_dev) 46 | len_test = [seq_len for seq_len in data_bundle.get_dataset('test')['seq_len']] 47 | mean_len_test = mean(len_test) 48 | print("The average length of train data is", int(mean_len_train)) 49 | # print("The average length of dev data is", int(mean_len_dev)) 50 | print("The average length of test data is", int(mean_len_test)) 51 | print(mean([mean_len_train, mean_len_test])) -------------------------------------------------------------------------------- /metrics/bleu.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # from fastNLP import MetricBase 3 | # from fastNLP import Vocabulary 4 | # from torchtext.data.metrics import bleu_score 5 | # 6 | # class BLUEMetric(MetricBase): 7 | # def __init__(self, tgt_vocab: Vocabulary): 8 | # super().__init__() 9 | # self.tgt_vocab = tgt_vocab 10 | # self.pred_tgt = [] 11 | # self.gold_tgt = [] 12 | # 13 | # def evaluate(self, tgt_tokens: torch.Tensor, pred: torch.Tensor, tgt_seq_len: torch.Tensor): 14 | # if pred.dim == 3: 15 | # pred = pred.argmax(dim=-1) 16 | # 17 | # batch_size, _ = pred.size() 18 | 19 | # assert batch_size == tgt_tokens.size(0) 20 | # for b in range(batch_size): 21 | # pred_tgt = [self.tgt_vocab.idx2word[index.item()] for index in pred[b]] 22 | # gold_tgt = [self.tgt_vocab.idx2word[index.item()] for index in tgt_tokens[b]] 23 | # gold_tgt = gold_tgt[1:tgt_seq_len[b].item()] 24 | # self.pred_tgt.append(pred_tgt) 25 | # self.gold_tgt.append([gold_tgt]) 26 | # 27 | # def get_metric(self, reset=True): 28 | # res = {"BLEU Score": bleu_score(self.pred_tgt, self.gold_tgt)} 29 | # if reset: 30 | # self.pred_tgt = [] 31 | # self.gold_tgt = [] 32 | # return res 33 | 34 | from fastNLP import MetricBase 35 | import sacrebleu 36 | 37 | 38 | class BLEUMetric(MetricBase): 39 | def __init__(self, vocab, eos_index, bpe_indicator='@@'): 40 | super().__init__() 41 | self.vocab = vocab 42 | self.eos_index = eos_index 43 | self.bpe_indicator = bpe_indicator 44 | self.goldens = [] 45 | self.preds = [] 46 | self.get_golden = True 47 | 48 | def evaluate(self, tgt_tokens, tgt_seq_len, pred): 49 | """ 50 | 51 | :param tgt_tokens: bsz x max_len (构成为[] + [tokens] + []) 52 | :param tgt_seq_len: bsz 53 | :param pred: bsz x max_len' (构成为[] + [tokens] + []) 54 | :return: 55 | """ 56 | 
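        # Detokenize both sides: the gold reference drops its leading/trailing special tokens,
        # the prediction is cut at the first eos, and BPE is undone by removing the '@@ '
        # continuation markers, so that sacrebleu.corpus_bleu in get_metric() scores whole words.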
for i in range(tgt_tokens.size(0)): 57 | self.goldens.append(' '.join(map(self.vocab.to_word, tgt_tokens[i, 1:tgt_seq_len[i]-1].tolist())).replace(f'{self.bpe_indicator} ', '')) 58 | 59 | for i in range(pred.size(0)): 60 | words = [] 61 | for idx in pred[i, 1:].tolist(): 62 | if idx==self.eos_index: 63 | break 64 | words.append(self.vocab.to_word(idx)) 65 | self.preds.append(' '.join(words).replace(f'{self.bpe_indicator} ', '')) 66 | 67 | def get_metric(self, reset=True): 68 | bleu = sacrebleu.corpus_bleu(self.preds, [self.goldens], force=True) 69 | if reset: 70 | self.preds = [] 71 | self.goldens = [] 72 | return {'bleu': bleu.score} 73 | 74 | 75 | -------------------------------------------------------------------------------- /models/seq2seq_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from fastNLP.embeddings import get_embeddings 5 | from fastNLP.models.seq2seq_model import Seq2SeqModel 6 | from fastNLP.embeddings.utils import get_sinusoid_encoding_table 7 | from modules.encoder import LinearTransformerEncoder 8 | from modules.decoder import LinearTransformerDecoder 9 | 10 | 11 | class LinearTransformerSeq2SeqModel(Seq2SeqModel): 12 | def __init__(self, encoder, decoder): 13 | super().__init__(encoder, decoder) 14 | 15 | @classmethod 16 | def build_model(cls, src_embed, tgt_embed=None, pos_embed='sin', 17 | max_position=1024, num_layers=6, d_model=512, 18 | n_head=8, dim_ff=2048, dropout=0.1, 19 | bind_encoder_decoder_embed=False, 20 | bind_decoder_input_output_embed=True): 21 | 22 | if bind_encoder_decoder_embed and tgt_embed is not None: 23 | raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") 24 | 25 | src_embed = get_embeddings(src_embed) 26 | 27 | if bind_encoder_decoder_embed: 28 | tgt_embed = src_embed 29 | else: 30 | assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" 31 | tgt_embed = get_embeddings(tgt_embed) 32 | 33 | if pos_embed == 'sin': 34 | encoder_pos_embed = nn.Embedding.from_pretrained( 35 | get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0), 36 | freeze=True) # 这里规定0是padding 37 | deocder_pos_embed = nn.Embedding.from_pretrained( 38 | get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0), 39 | freeze=True) # 这里规定0是padding 40 | elif pos_embed == 'learned': 41 | encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0) 42 | deocder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=1) 43 | else: 44 | raise ValueError("pos_embed only supports sin or learned.") 45 | 46 | encoder = LinearTransformerEncoder(embed=src_embed, pos_embed=encoder_pos_embed, 47 | num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff, 48 | dropout=dropout) 49 | decoder = LinearTransformerDecoder(embed=tgt_embed, pos_embed=deocder_pos_embed, 50 | d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff, 51 | dropout=dropout, 52 | bind_decoder_input_output_embed=bind_decoder_input_output_embed) 53 | 54 | return cls(encoder, decoder) 55 | 56 | 57 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from torch import nn 3 | from torch import optim 4 | from data.data_loader import load_data 5 | from fastNLP 
import Trainer 6 | from fastNLP import AccuracyMetric 7 | from fastNLP import CrossEntropyLoss 8 | from fastNLP.embeddings import BertEmbedding, StaticEmbedding 9 | from fastNLP.embeddings.utils import get_embeddings 10 | from fastNLP.embeddings.utils import get_sinusoid_encoding_table 11 | from fastNLP import BucketSampler, GradientClipCallback, WarmupCallback 12 | from model.classification import TextClassification 13 | 14 | lr = 1e-5 15 | n_epochs = 10 16 | batch_size = 2 17 | n_head = 4 18 | d_model = 256 19 | dim_ff = 512 20 | dropout = 0.3 21 | num_layers = 6 22 | pos_embed = 'sin' 23 | data_path = './data/imdb_data_bundle.pt' 24 | bind_decoder_input_output_embed = False 25 | 26 | data_bundle = load_data(data_path) 27 | print(data_bundle) 28 | 29 | src_embed = StaticEmbedding(vocab=data_bundle.get_vocab('words'), model_dir_or_name='en-glove-840B-300d') 30 | 31 | max_length_train = max([seq_len for seq_len in data_bundle.get_dataset('train')['seq_len']]) 32 | max_length_test = max([seq_len for seq_len in data_bundle.get_dataset('test')['seq_len']]) 33 | max_length = max(max_length_train, max_length_test) 34 | print("数据集样本最大长度max_length =", max_length) 35 | 36 | if pos_embed == 'sin': 37 | encoder_pos_embed = nn.Embedding.from_pretrained( 38 | get_sinusoid_encoding_table(max_length + 1, src_embed.embedding_dim, padding_idx=0), freeze=True 39 | ) 40 | elif pos_embed == 'learned': 41 | encoder_pos_embed = get_embeddings((max_length + 1, src_embed.embedding_dim), padding_idx=0) 42 | 43 | model = TextClassification(embed=src_embed, 44 | pos_embed=encoder_pos_embed, 45 | num_layers=num_layers, 46 | d_model=d_model, 47 | n_head=n_head, 48 | dim_ff=dim_ff, 49 | dropout=dropout, 50 | class_num=2) 51 | 52 | parametrs = [] 53 | params = {'lr': lr} 54 | params['params'] = [param for param in model.parameters() if param.requires_grad] 55 | parametrs.append(params) 56 | 57 | optimizer = optim.Adam(parametrs) 58 | 59 | callbacks = [] 60 | callbacks.append(GradientClipCallback(clip_value=1, clip_type='value')) 61 | callbacks.append(WarmupCallback(warmup=0.01, schedule='linear')) 62 | 63 | sampler = BucketSampler(seq_len_field_name='seq_len') 64 | trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, optimizer=optimizer, 65 | loss=CrossEntropyLoss(), batch_size=batch_size, sampler=sampler, 66 | drop_last=False, update_every=1, num_workers=2, n_epochs=n_epochs, 67 | print_every=1, dev_data=data_bundle.get_dataset('test'), 68 | metrics=AccuracyMetric(), metric_key=None, validate_every=-1, 69 | save_path=None, use_tqdm=True, device=0) 70 | 71 | 72 | start = time.time() 73 | trainer.train(load_best_model=False) 74 | end = time.time() 75 | print(end - start) -------------------------------------------------------------------------------- /train_vanilla.py: -------------------------------------------------------------------------------- 1 | import time 2 | from torch import nn 3 | from torch import optim 4 | from data.data_loader import load_data 5 | from fastNLP import Trainer 6 | from fastNLP import AccuracyMetric 7 | from fastNLP import CrossEntropyLoss 8 | from fastNLP.embeddings import BertEmbedding, StaticEmbedding 9 | from fastNLP.embeddings.utils import get_embeddings 10 | from fastNLP.embeddings.utils import get_sinusoid_encoding_table 11 | from fastNLP import BucketSampler, GradientClipCallback, WarmupCallback 12 | from model.classfication_vanilla import TextClassification 13 | 14 | lr = 1e-5 15 | n_epochs = 10 16 | batch_size = 2 17 | n_head = 4 18 | d_model = 256 19 | 
dim_ff = 512 20 | dropout = 0.3 21 | num_layers = 6 22 | pos_embed = 'sin' 23 | data_path = './data/imdb_data_bundle.pt' 24 | bind_decoder_input_output_embed = False 25 | 26 | data_bundle = load_data(data_path) 27 | print(data_bundle) 28 | 29 | src_embed = StaticEmbedding(vocab=data_bundle.get_vocab('words'), model_dir_or_name='en-glove-840B-300d') 30 | 31 | max_length_train = max([seq_len for seq_len in data_bundle.get_dataset('train')['seq_len']]) 32 | max_length_test = max([seq_len for seq_len in data_bundle.get_dataset('test')['seq_len']]) 33 | max_length = max(max_length_train, max_length_test) 34 | print("数据集样本最大长度max_length =", max_length) 35 | 36 | if pos_embed == 'sin': 37 | encoder_pos_embed = nn.Embedding.from_pretrained( 38 | get_sinusoid_encoding_table(max_length + 1, src_embed.embedding_dim, padding_idx=0), freeze=True 39 | ) 40 | elif pos_embed == 'learned': 41 | encoder_pos_embed = get_embeddings((max_length + 1, src_embed.embedding_dim), padding_idx=0) 42 | 43 | model = TextClassification(embed=src_embed, 44 | pos_embed=encoder_pos_embed, 45 | num_layers=num_layers, 46 | d_model=d_model, 47 | n_head=n_head, 48 | dim_ff=dim_ff, 49 | dropout=dropout, 50 | class_num=2) 51 | 52 | parametrs = [] 53 | params = {'lr': lr} 54 | params['params'] = [param for param in model.parameters() if param.requires_grad] 55 | parametrs.append(params) 56 | 57 | optimizer = optim.Adam(parametrs) 58 | 59 | callbacks = [] 60 | callbacks.append(GradientClipCallback(clip_value=1, clip_type='value')) 61 | callbacks.append(WarmupCallback(warmup=0.01, schedule='linear')) 62 | 63 | sampler = BucketSampler(seq_len_field_name='seq_len') 64 | trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, optimizer=optimizer, 65 | loss=CrossEntropyLoss(), batch_size=batch_size, sampler=sampler, 66 | drop_last=False, update_every=1, num_workers=2, n_epochs=n_epochs, 67 | print_every=1, dev_data=data_bundle.get_dataset('test'), 68 | metrics=AccuracyMetric(), metric_key=None, validate_every=-1, 69 | save_path=None, use_tqdm=True, device=1) 70 | 71 | 72 | start = time.time() 73 | trainer.train(load_best_model=False) 74 | end = time.time() 75 | print(end - start) -------------------------------------------------------------------------------- /test/test_transformer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from tqdm import tqdm 3 | from data.iwslt17_loader import IWSLT2017Pipe 4 | from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel 5 | from model.seq2seq_model import LinearTransformerSeq2SeqModel 6 | from fastNLP import Vocabulary 7 | from fastNLP.embeddings import StaticEmbedding 8 | from fastNLP.models import SequenceGeneratorModel 9 | import torch 10 | from torch import optim 11 | import torch.nn.functional as F 12 | from fastNLP import seq_len_to_mask 13 | 14 | def prepare_env(seq_len, batch_size): 15 | vocab = Vocabulary().add_word_lst("This is a test .".split()) 16 | vocab.add_word_lst("Another test !".split()) 17 | embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) 18 | 19 | src_words_idx = torch.arange(1, seq_len + 1) 20 | src_words_idx = src_words_idx.expand([batch_size, seq_len]) 21 | 22 | mask = [] 23 | lens_src = [] 24 | for i in range(batch_size): 25 | length = seq_len - i * 15 26 | lens_src.append(length) 27 | maski = torch.randint(1, 2, (1, length)).view(1, length) 28 | zero_extend = torch.zeros((1, seq_len-length), dtype=torch.long).view(1, seq_len - length) 29 | maski = torch.cat((maski, 
zero_extend), dim=1) 30 | mask.append(maski) 31 | mask = torch.cat(mask, dim=0).bool() 32 | src_words_idx = src_words_idx.masked_fill_(mask, 0).long() 33 | batch_src_words_idx = src_words_idx.expand((100, batch_size, seq_len)) 34 | 35 | tgt_words_idx = torch.arange(1, seq_len + 1) 36 | tgt_words_idx = tgt_words_idx.expand((batch_size, seq_len)) 37 | 38 | mask = [] 39 | lens_tgt = [] 40 | for i in range(batch_size): 41 | length = seq_len - i * 20 42 | lens_tgt.append(length) 43 | maski = torch.randint(1, 2, (1, length)).view(1, length) 44 | zero_extend = torch.zeros((1, seq_len - length), dtype=torch.long).view(1, seq_len - length) 45 | maski = torch.cat((maski, zero_extend), dim=1) 46 | mask.append(maski) 47 | 48 | mask = torch.cat(mask, dim=0).bool() 49 | tgt_words_idx = tgt_words_idx.masked_fill_(mask, 0).long() 50 | batch_tgt_words_idx = tgt_words_idx.expand((100, batch_size, seq_len)) 51 | 52 | src_seq_len = torch.tensor(lens_src, dtype=torch.long) 53 | batch_src_seq_len = src_seq_len.expand([100, batch_size]) 54 | tgt_seq_len = torch.tensor(lens_tgt, dtype=torch.long) 55 | batch_tgt_seq_len = tgt_seq_len.expand([100, batch_size]) 56 | 57 | return embed, batch_src_words_idx, batch_tgt_words_idx, batch_src_seq_len, batch_tgt_seq_len 58 | 59 | 60 | def train_model(model, batch_src_words_idx, batch_tgt_words_idx, batch_tgt_seq_len, batch_src_seq_len, device): 61 | print("===开始训练===") 62 | optimizer = optim.Adam(model.parameters(), lr=1e-2) 63 | 64 | for i in tqdm(range(20)): 65 | src_words_idx = batch_src_words_idx[i].to(device) 66 | tgt_words_idx = batch_tgt_words_idx[i].to(device) 67 | src_seq_len = batch_src_seq_len[i].to(device) 68 | tgt_seq_len = batch_tgt_seq_len[i].to(device) 69 | mask = seq_len_to_mask(tgt_seq_len).eq(0).to(device) 70 | target = tgt_words_idx.masked_fill(mask, 1e-5).to(device) 71 | 72 | optimizer.zero_grad() 73 | pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size 74 | loss = F.cross_entropy(pred.transpose(1, 2), target) 75 | loss.backward() 76 | optimizer.step() 77 | 78 | print("===训练结束===") 79 | 80 | 81 | # 测试能否train到overfit 82 | batch_size = 1 83 | seq_len = 65536 84 | device = torch.device('cuda') 85 | embed, batch_src_words_idx, batch_tgt_words_idx, batch_src_seq_len, \ 86 | batch_tgt_seq_len = prepare_env(seq_len, batch_size) 87 | 88 | model = LinearTransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, 89 | pos_embed='sin', max_position=80000, num_layers=3, d_model=256, n_head=4, 90 | dim_ff=512, dropout=0.1, bind_encoder_decoder_embed=True, 91 | bind_decoder_input_output_embed=True) 92 | 93 | model = model.to(device) 94 | start = time.clock() 95 | train_model(model, batch_src_words_idx, batch_tgt_words_idx, batch_tgt_seq_len, batch_src_seq_len, device) 96 | end = time.clock() 97 | print("训练时间为: ", end - start) 98 | 99 | -------------------------------------------------------------------------------- /modules/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from typing import Union, Tuple 5 | import torch.nn.functional as F 6 | from fastNLP.embeddings import StaticEmbedding 7 | from fastNLP.core.utils import seq_len_to_mask 8 | from fastNLP.embeddings.utils import get_embeddings 9 | from fastNLP.modules.encoder import Seq2SeqEncoder 10 | from .linear_attention import LinearMultiHeadAttention 11 | 12 | 13 | class LinearTransformerEncoderLayer(nn.Module): 14 | def __init__(self, d_model: int = 512, 
n_head: int = 8, 15 | dim_ff: int = 2048, dropout: float = 0.1): 16 | """ 17 | Transformer Encoder Block 18 | :param d_model: input和output的输出维度 19 | :param n_head: 多少个head, 每个head的维度为d_model/n_head 20 | :param dim_ff: FFN的维度大小 21 | :param dropout: Self-attention和FFN的dropout大小,0表示不drop 22 | """ 23 | super(LinearTransformerEncoderLayer, self).__init__() 24 | self.d_model = d_model 25 | self.n_head = n_head 26 | self.dim_ff = dim_ff 27 | self.dropout = dropout 28 | 29 | self.attention = LinearMultiHeadAttention(d_model, n_head, dropout) 30 | self.attn_layer_norm = nn.LayerNorm(d_model) 31 | self.ffn_layer_norm = nn.LayerNorm(d_model) 32 | 33 | self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), 34 | nn.ReLU(), 35 | nn.Linear(self.dim_ff, self.d_model), 36 | nn.Dropout(p=self.dropout) 37 | ) 38 | 39 | def forward(self, x, mask): 40 | """ 41 | 42 | :param x: [batch_size, seq_len, d_model] 43 | :param mask: [batch_size, seq_len] 44 | :return: 45 | """ 46 | residual = x 47 | x = self.attn_layer_norm(x) 48 | x = self.attention(query=x, 49 | key=x, 50 | value=x, 51 | key_mask=mask 52 | ) 53 | x = F.dropout(x, p=self.dropout, training=self.training) 54 | x = residual + x 55 | 56 | residual = x 57 | x = self.ffn_layer_norm(x) 58 | x = self.ffn(x) 59 | x = residual + x 60 | 61 | return x 62 | 63 | class LinearTransformerEncoder(Seq2SeqEncoder): 64 | def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed=None, 65 | num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): 66 | """ 67 | 基于Transformer的Encoder 68 | 69 | :param embed: encoder输入token的embedding 70 | :param nn.Module pos_embed: position embedding 71 | :param int num_layers: 多少层的encoder 72 | :param int d_model: 输入输出的维度 73 | :param int n_head: 多少个head 74 | :param int dim_ff: FFN中间的维度大小 75 | :param float dropout: Attention和FFN的dropout大小 76 | """ 77 | super().__init__() 78 | self.embed = get_embeddings(embed) 79 | self.embed_scale = math.sqrt(d_model) 80 | self.pos_embed = pos_embed 81 | self.num_layers = num_layers 82 | self.d_model = d_model 83 | self.n_head = n_head 84 | self.dim_ff = dim_ff 85 | self.dropout = dropout 86 | 87 | self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) 88 | self.layer_stacks = nn.ModuleList([LinearTransformerEncoderLayer(d_model, n_head, dim_ff, dropout) 89 | for _ in range(num_layers)]) 90 | self.layer_norm = nn.LayerNorm(d_model) 91 | 92 | def forward(self, tokens, seq_len): 93 | """ 94 | 95 | :param tokens: batch x max_len 96 | :param seq_len: [batch] 97 | :return: bsz x max_len x d_model, bsz x max_len(为0的地方为padding) 98 | """ 99 | x = self.embed(tokens) * self.embed_scale 100 | batch_size, max_src_len, _ = x.size() 101 | device = x.device 102 | if self.pos_embed is not None: 103 | position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) 104 | x += self.pos_embed(position) 105 | 106 | x = self.input_fc(x) 107 | x = F.dropout(x, p=self.dropout, training=self.training) 108 | 109 | encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len) 110 | encoder_mask = encoder_mask.to(device) 111 | 112 | for layer in self.layer_stacks: 113 | x = layer(x, encoder_mask) 114 | 115 | x = self.layer_norm(x) 116 | 117 | return x, encoder_mask -------------------------------------------------------------------------------- /modules/linear_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from .feature_map import elu_feature_map 5 | from 
modules.state import LinearTransformerState 6 | from fastNLP.modules.decoder.seq2seq_state import TransformerState 7 | 8 | def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, eps): 9 | Q = q.contiguous().permute(0, 2, 1, 3) 10 | K = k.contiguous().permute(0, 2, 1, 3) 11 | V = v.contiguous().permute(0, 2, 1, 3) 12 | KV = torch.einsum('...sd,...se->...de', K, V) 13 | Z = 1.0 / torch.einsum('...sd,...d->...s', Q, K.sum(dim=-2)+eps) 14 | V_new = torch.einsum('...de,...sd,...s->...se', KV, Q, Z) 15 | return V_new.contiguous().permute(0, 2, 1, 3) 16 | 17 | class LinearMultiHeadAttention(nn.Module): 18 | """ 19 | Implement unmasked attention using dot product of feature maps in 20 | O(N D^2) complexity. 21 | Given the queries, keys and values as Q, K, V instead of computing 22 | V' = softmax(Q.mm(K.t()), dim=-1).mm(V), 23 | we make use of a feature map function Φ(.) and perform the following 24 | computation 25 | V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V). 26 | The above can be computed in O(N D^2) complexity where D is the 27 | dimensionality of Q, K and V and N is the sequence length. Depending on the 28 | feature map, however, the complexity of the attention might be limited. 29 | Arguments 30 | --------- 31 | feature_map: callable, a callable that applies the feature map to the 32 | last dimension of a tensor (default: elu(x)+1) 33 | eps: float, a small number to ensure the numerical stability of the 34 | denominator (default: 1e-6) 35 | """ 36 | def __init__(self, d_model: int, n_head: int, dropout: float, 37 | layer_idx=None, feature_map=None, eps=1e-6): 38 | super(LinearMultiHeadAttention, self).__init__() 39 | self.d_model = d_model 40 | self.n_head = n_head 41 | self.dropout = dropout 42 | self.eps = eps 43 | self.layer_idx = layer_idx 44 | assert d_model % n_head == 0, "d_model should be divisible by n_head" 45 | self.head_dim = d_model // n_head 46 | 47 | self.q_proj = nn.Linear(d_model, d_model) 48 | self.k_proj = nn.Linear(d_model, d_model) 49 | self.v_proj = nn.Linear(d_model, d_model) 50 | self.o_proj = nn.Linear(d_model, d_model) 51 | 52 | self.feature_map = ( 53 | feature_map(d_model) if feature_map else 54 | elu_feature_map(d_model) 55 | ) 56 | 57 | def reset_params(self): 58 | nn.init.xavier_uniform_(self.q_proj.weight) 59 | nn.init.xavier_uniform_(self.k_proj.weight) 60 | nn.init.xavier_uniform_(self.v_proj.weight) 61 | nn.init.xavier_uniform_(self.o_proj.weight) 62 | 63 | def forward(self, query, key, value, key_mask, state=None): 64 | """ 65 | 66 | :param query: [batch_size, seq_len, d_model] 67 | :param key: [batch_size, seq_len, d_model] 68 | :param value: [batch_size, seq_len, d_model] 69 | :param key_mask: [batch_size, seq_len] 用于指示哪些key不要被attend到;注意到mask为1的地方是要attend的 70 | :param attn_mask: [seq_len, seq_len] 用于mask掉attention map。主要是用在训练时decoder端的self-attention,下三角为1 71 | """ 72 | q = self.q_proj(query) 73 | k = v = None 74 | 75 | if self.layer_idx is not None: 76 | if isinstance(state, LinearTransformerState): 77 | k = state.encoder_key[self.layer_idx] 78 | v = state.encoder_value[self.layer_idx] 79 | 80 | if k is None: 81 | k = self.k_proj(key) 82 | v = self.v_proj(value) 83 | 84 | if self.layer_idx is not None: 85 | if isinstance(state, LinearTransformerState): 86 | state.encoder_key[self.layer_idx] = k 87 | state.encoder_value[self.layer_idx] = v 88 | 89 | batch_size, q_len, d_model = q.size() 90 | k_len, v_len = k.size(1), v.size(1) 91 | 92 | q = q.reshape(batch_size, q_len, self.n_head, self.head_dim) 93 | k = k.reshape(batch_size, k_len, 
self.n_head, self.head_dim) 94 | v = v.reshape(batch_size, v_len, self.n_head, self.head_dim) 95 | 96 | self.feature_map.new_feature_map() 97 | q = self.feature_map.forward_queries(q) 98 | k = self.feature_map.forward_keys(k) 99 | 100 | if key_mask is not None: 101 | _key_mask = ~key_mask[:, :, None, None].bool() 102 | k = k.masked_fill(_key_mask, 0.0) 103 | 104 | # KV = torch.einsum("bsnd,bsnm->bnmd", K, V) 105 | # 106 | # Z = 1 / (torch.einsum("bsnd,bnd->bsn", Q, K.sum(dim=1))+self.eps) 107 | # 108 | # V_new = torch.einsum("bsnd,bnmd,bsn->bsnm", Q, KV, Z) 109 | V_new = linear_attention(q, k, v, self.eps) 110 | 111 | output = V_new.contiguous().reshape(batch_size, q_len, -1) 112 | output = self.o_proj(output) 113 | 114 | return output 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /modules/cross_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from modules.feature_map import elu_feature_map 4 | from fastNLP.modules.decoder.seq2seq_state import TransformerState 5 | 6 | class RecurrentCrossLinearAttention(nn.Module): 7 | """Implement autoregressive linear cross attention as a recurrent 8 | module. 9 | See fast_transformers.attention.linear_attention.LinearAttention . 10 | Arguments 11 | --------- 12 | feature_map: callable, a callable that applies the feature map to the 13 | last dimension of a tensor (default: elu(x)+1) 14 | eps: float, a small number to ensure the numerical stability of the 15 | denominator (default: 1e-6) 16 | """ 17 | def __init__(self, d_query, feature_map=None, eps=1e-6): 18 | super(RecurrentCrossLinearAttention, self).__init__() 19 | self.feature_map = ( 20 | feature_map(d_query) if feature_map else 21 | elu_feature_map(d_query) 22 | ) 23 | self.eps = eps 24 | 25 | def forward(self, query, key, value, key_mask=None, info=None): 26 | if info is None: 27 | self.feature_map.new_feature_map() 28 | 29 | Q = self.feature_map.forward_queries(query) 30 | 31 | if info is None: 32 | K = self.feature_map.forward_keys(key) 33 | if key_mask is not None: 34 | _key_mask = key_mask[:, :, None, None].bool() 35 | K = K.masked_fill(_key_mask, 0.0) 36 | S = torch.einsum("bsnd,bsnm->bnmd", K, value) 37 | Z = K.sum(dim=1) 38 | else: 39 | S, Z = info 40 | 41 | QZ = 1 / (torch.einsum("bnd,bnd->bn", Q, Z) + self.eps) 42 | V = torch.einsum("bnd,bnmd,bn->bnm", Q, S, QZ) 43 | 44 | return V.contiguous(), [S, Z] 45 | 46 | 47 | class RecurrentCrossAttentionLayer(nn.Module): 48 | """See fast_transformers.attention.attention_layer.AttentionLayer . 49 | The differences with the aforementioned module as well as the 50 | RecurrentAttentionLayer are that this module projects the query every time 51 | and the keys and values only the first time they are provided. 52 | Arguments 53 | --------- 54 | attention: Specific inner attention implementation that just computes a 55 | weighted average of values given a similarity of queries and 56 | keys. 
57 | d_model: The input feature dimensionality 58 | n_heads: The number of heads for the multi head attention 59 | d_keys: The dimensionality of the keys/queries 60 | (default: d_model/n_heads) 61 | d_values: The dimensionality of the values (default: d_model/n_heads) 62 | """ 63 | def __init__(self, d_model, n_head, layer_idx): 64 | super(RecurrentCrossAttentionLayer, self).__init__() 65 | self.d_model = d_model 66 | self.n_head = n_head 67 | self.layer_idx = layer_idx 68 | assert d_model % n_head == 0, "d_model should be divisible by n_head" 69 | self.head_dim = d_model // n_head 70 | self.scaling = self.head_dim ** -0.5 71 | 72 | self.cross_attention = RecurrentCrossLinearAttention(self.head_dim) 73 | 74 | self.q_proj = nn.Linear(d_model, d_model) 75 | self.k_proj = nn.Linear(d_model, d_model) 76 | self.v_proj = nn.Linear(d_model, d_model) 77 | self.o_proj = nn.Linear(d_model, d_model) 78 | 79 | def reset_params(self): 80 | nn.init.xavier_uniform_(self.q_proj.weight) 81 | nn.init.xavier_uniform_(self.k_proj.weight) 82 | nn.init.xavier_uniform_(self.v_proj.weight) 83 | nn.init.xavier_uniform_(self.o_proj.weight) 84 | 85 | def forward(self, querys, keys, values, key_mask=None, info=None, state=None): 86 | assert keys.size() == values.size() 87 | if state is not None: 88 | assert self.layer_idx is not None 89 | 90 | q = self.q_proj(querys) 91 | q *= self.scaling 92 | k = v = None 93 | # k = self.k_proj(keys) 94 | # v = self.v_proj(values) 95 | 96 | if isinstance(state, TransformerState): 97 | k = state.encoder_key[self.layer_idx] 98 | v = state.encoder_value[self.layer_idx] 99 | 100 | if k is None: 101 | k = self.k_proj(keys) 102 | v = self.v_proj(values) 103 | 104 | if isinstance(state, TransformerState): 105 | state.encoder_key[self.layer_idx] = k 106 | state.encoder_value[self.layer_idx] = v 107 | 108 | batch_size, length_q, d_model = querys.size() 109 | length_k, length_v = keys.size(1), values.size(1) 110 | 111 | Q = q.reshape(batch_size, length_q, self.n_head, self.head_dim) 112 | K = k.reshape(batch_size, length_k, self.n_head, self.head_dim) 113 | V = v.reshape(batch_size, length_v, self.n_head, self.head_dim) 114 | 115 | output = [] 116 | for i in range(length_q): 117 | Qi = Q[:, i, :, :] 118 | Vi, info = self.cross_attention(Qi, K, V, key_mask, info) 119 | Vi = Vi.reshape(batch_size, 1, self.n_head, self.head_dim) 120 | output.append(Vi) 121 | 122 | output = torch.cat(output, dim=1) 123 | output = output.reshape(batch_size, length_q, -1) 124 | output = self.o_proj(output) 125 | 126 | return output 127 | 128 | -------------------------------------------------------------------------------- /modules/self_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils import check_state 4 | from .feature_map import elu_feature_map 5 | from fastNLP.modules.decoder.seq2seq_state import TransformerState 6 | 7 | class RecurrentLinearAttention(nn.Module): 8 | """Implement fast_transformers.attention.causal_linear_attention as a 9 | fixed-dimensional state recurrent model. 10 | See fast_transformers.attention.linear_attention and 11 | fast_transformers.attention.causal_linear_attention for the general concept 12 | of replacing the softmax with feature maps. 
13 | Arguments 14 | --------- 15 | feature_map: callable, a callable that applies the feature map to the 16 | last dimension of a tensor (default: elu(x)+1) 17 | eps: float, a small number to ensure the numerical stability of the 18 | denominator (default: 1e-6) 19 | """ 20 | def __init__(self, d_query, feature_map=None, eps=1e-6): 21 | super(RecurrentLinearAttention, self).__init__() 22 | self.feature_map = ( 23 | feature_map(d_query) if feature_map else 24 | elu_feature_map(d_query) 25 | ) 26 | self.eps = eps 27 | 28 | def forward(self, query, key, value, info=None): 29 | # state = check_state(state, memroy) 30 | 31 | if info is None: 32 | self.feature_map.new_feature_map() 33 | 34 | Q = self.feature_map.forward_queries(query) 35 | K = self.feature_map.forward_keys(key) 36 | 37 | B, N, D = Q.shape 38 | _, _, M = value.shape 39 | 40 | if info is None: 41 | Si = query.new_zeros((B, N, D, M)) 42 | Zi = query.new_zeros((B, N, D)) 43 | else: 44 | Si, Zi = info 45 | 46 | if len(Si) != B: 47 | raise ValueError("The batch size changed during iteration") 48 | 49 | if K.grad_fn is not None or value.grad_fn is not None: 50 | Zi = Zi + K 51 | Si = Si + torch.einsum("bnd,bnm->bndm", K, value) 52 | else: 53 | Zi += K 54 | Si += torch.einsum("bnd,bnm->bndm", K, value) 55 | 56 | Z = 1. / (torch.einsum("bnd,bnd->bn", Q, Zi) + self.eps) 57 | V = torch.einsum("bnd,bndm,bn->bnm", Q, Si, Z) 58 | 59 | return V, [Si, Zi] 60 | 61 | 62 | class RecurrentSelfAttentionLayer(nn.Module): 63 | """See fast_transformers.attention.attention_layer.AttentionLayer. 64 | 65 | The only difference with the corresponding module is that this projects 66 | only one input and then calls the inner attention with the provided 67 | previous state. 68 | Arguments 69 | --------- 70 | attention: Specific inner attention implementation that just computes a 71 | weighted average of values given a similarity of queries and 72 | keys. 
73 | d_model: The input feature dimensionality 74 | n_heads: The number of heads for the multi head attention 75 | d_keys: The dimensionality of the keys/queries 76 | (default: d_model/n_heads) 77 | d_values: The dimensionality of the values (default: d_model/n_heads) 78 | 79 | """ 80 | def __init__(self, d_model, n_heads, layer_idx): 81 | super(RecurrentSelfAttentionLayer, self).__init__() 82 | self.d_model = d_model 83 | self.n_head = n_heads 84 | self.layer_idx = layer_idx 85 | assert d_model % n_heads == 0, "d_model should be divisible by n_head" 86 | self.head_dim = d_model // n_heads 87 | self.scaling = self.head_dim ** -0.5 88 | 89 | self.self_attention = RecurrentLinearAttention(self.head_dim) 90 | 91 | self.q_proj = nn.Linear(d_model, d_model) 92 | self.k_proj = nn.Linear(d_model, d_model) 93 | self.v_proj = nn.Linear(d_model, d_model) 94 | self.o_proj = nn.Linear(d_model, d_model) 95 | 96 | def reset_params(self): 97 | nn.init.xavier_uniform_(self.q_proj.weight) 98 | nn.init.xavier_uniform_(self.k_proj.weight) 99 | nn.init.xavier_uniform_(self.v_proj.weight) 100 | nn.init.xavier_uniform_(self.o_proj.weight) 101 | 102 | def forward(self, querys, keys, values, info=None, state=None): 103 | assert keys.size() == values.size() 104 | if state is not None: 105 | assert self.layer_idx is not None 106 | 107 | # info = check_state(info, memory) 108 | 109 | q = self.q_proj(querys) 110 | k = self.k_proj(keys) 111 | v = self.v_proj(values) 112 | q *= self.scaling 113 | prev_k = prev_v = None 114 | 115 | if isinstance(state, TransformerState): 116 | prev_k = state.decoder_prev_key[self.layer_idx] 117 | prev_v = state.decoder_prev_value[self.layer_idx] 118 | 119 | if prev_k is not None: 120 | k = torch.cat((prev_k, k), dim=1) 121 | v = torch.cat((prev_v, v), dim=1) 122 | 123 | if isinstance(state, TransformerState): 124 | state.decoder_prev_key[self.layer_idx] = k 125 | state.decoder_prev_value[self.layer_idx] = v 126 | 127 | batch_size, length_q, d_model = querys.size() 128 | length_k, length_v = keys.size(1), values.size(1) 129 | 130 | assert length_q == length_k == length_v, "The length of Q, K and V are not the same." 
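        # The loop below decodes one query position at a time; RecurrentLinearAttention keeps
        # the running state info = [S_i, Z_i] with
        #     S_i = S_{i-1} + phi(k_i) v_i^T,    Z_i = Z_{i-1} + phi(k_i),
        #     out_i = phi(q_i)^T S_i / (phi(q_i) . Z_i + eps),
        # so each new token costs O(head_dim^2) per head instead of re-attending to the whole prefix.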
131 | 132 | Q = q.reshape(batch_size, length_q, self.n_head, self.head_dim) 133 | K = k.reshape(batch_size, length_k, self.n_head, self.head_dim) 134 | V = v.reshape(batch_size, length_v, self.n_head, self.head_dim) 135 | 136 | output = [] 137 | for i in range(length_q): 138 | Qi = Q[:, i, :, :] 139 | Ki = K[:, i, :, :] 140 | Vi = V[:, i, :, :] 141 | Vi, info = self.self_attention(Qi, Ki, Vi, info) 142 | Vi = Vi.reshape(batch_size, 1, self.n_head, self.head_dim) 143 | output.append(Vi) 144 | 145 | output = torch.cat(output, dim=1) 146 | output = output.reshape(batch_size, length_q, -1) 147 | output = self.o_proj(output) 148 | 149 | return output 150 | -------------------------------------------------------------------------------- /train_seq2seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import optim 4 | import numpy as np 5 | import random 6 | 7 | from data.wmt16_loader import MTPipe 8 | from losses.mtloss import MTLoss 9 | from metrics.bleu import BLUEMetric 10 | from model.seq2seq_model import LinearTransformerSeq2SeqModel 11 | from fastNLP import Trainer 12 | from fastNLP.embeddings import StaticEmbedding 13 | from fastNLP.models import SequenceGeneratorModel 14 | from fastNLP import BucketSampler, GradientClipCallback, cache_results 15 | from fastNLP import SortedSampler, WarmupCallback, FitlogCallback 16 | 17 | 18 | SEED = 1234 19 | 20 | random.seed(SEED) 21 | np.random.seed(SEED) 22 | torch.manual_seed(SEED) 23 | torch.cuda.manual_seed(SEED) 24 | torch.backends.cudnn.deterministic = True 25 | torch.cuda.set_device(0) 26 | 27 | #######hyper 28 | lr = 5e-4 29 | n_epochs = 20 30 | batch_size = 64 31 | n_heads = 8 32 | d_model = 256 33 | dim_ff = 512 34 | num_layers = 3 35 | bind_decoder_input_output_embed = False 36 | #######hyper 37 | 38 | # @cache_results('caches/data.pkl') 39 | def get_data(): 40 | data_bundle = MTPipe().process_from_file(paths='./wmt16') 41 | return data_bundle 42 | 43 | data_bundle = get_data() 44 | max_len_train = max([seq_len for seq_len in data_bundle.get_dataset('train')['tgt_seq_len']]) 45 | max_len_dev = max([seq_len for seq_len in data_bundle.get_dataset('dev')['tgt_seq_len']]) 46 | max_len_test = max([seq_len for seq_len in data_bundle.get_dataset('test')['tgt_seq_len']]) 47 | max_len = max([max_len_train, max_len_dev, max_len_test]) 48 | print(data_bundle) 49 | print("The maximal length of target is ", max_len) 50 | src_vocab = data_bundle.get_vocab('src_tokens') 51 | tgt_vocab = data_bundle.get_vocab('tgt_tokens') 52 | 53 | src_embed = StaticEmbedding(data_bundle.get_vocab('src_tokens'), embedding_dim=d_model, model_dir_or_name=None) 54 | tgt_embed = StaticEmbedding(data_bundle.get_vocab('tgt_tokens'), embedding_dim=d_model, model_dir_or_name=None) 55 | 56 | model = LinearTransformerSeq2SeqModel.build_model(src_embed=src_embed, tgt_embed=tgt_embed, 57 | pos_embed='sin', max_position=1024, num_layers=6, d_model=256, n_head=8, dim_ff=512, dropout=0.1, 58 | bind_encoder_decoder_embed=False, bind_decoder_input_output_embed=bind_decoder_input_output_embed) 59 | 60 | model = SequenceGeneratorModel(model, bos_token_id=tgt_vocab.to_index(''), 61 | eos_token_id=tgt_vocab.to_index(''), max_length=max_len, 62 | num_beams=4, do_sample=False, temperature=1.0, top_k=20, top_p=1.0, 63 | repetition_penalty=1, length_penalty=1.0, pad_token_id=0) 64 | 65 | optimizer = optim.AdamW(model.parameters(), lr=lr) 66 | 67 | callbacks = [] 68 | 
callbacks.append(GradientClipCallback(clip_value=1, clip_type='value')) 69 | callbacks.append(WarmupCallback(warmup=0.01, schedule='linear')) 70 | # callbacks.append(FitlogCallback(data_bundle.get_dataset('test'))) 71 | sampler = BucketSampler(seq_len_field_name='src_seq_len') 72 | trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, optimizer=optimizer, 73 | loss=MTLoss(), batch_size=batch_size, sampler=sampler, drop_last=False, 74 | update_every=1, num_workers=2, n_epochs=n_epochs, print_every=1, device=0, 75 | use_tqdm=True, dev_data=data_bundle.get_dataset('dev'), 76 | metrics=BLUEMetric(tgt_vocab), metric_key=None, 77 | validate_every=-1, save_path=None, callbacks=callbacks, 78 | # check_code_level=0, test_use_tqdm=False, 79 | test_sampler=SortedSampler('src_seq_len') 80 | ) 81 | 82 | trainer.train(load_best_model=False) 83 | 84 | 85 | 86 | # model = model.to(device) 87 | # 88 | # LEARNING_RATE = 0.0005 89 | # 90 | # optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE) 91 | # 92 | # criterion = nn.CrossEntropyLoss(ignore_index = tgt_PAD_IDX) 93 | # 94 | # def train(model, iterator, optimizer, criterion, clip): 95 | # model.train() 96 | # epoch_loss = 0.0 97 | # for i, batch in tqdm(enumerate(train_iterator)): 98 | # src_info = batch.src 99 | # tgt_info = batch.trg 100 | # 101 | # src = src_info[0] 102 | # src_len = src_info[1].cpu() 103 | # src_len = src_len.view(1, len(src_len)) 104 | # 105 | # tgt = tgt_info[0] 106 | # tgt_len = tgt_info[1].cpu() 107 | # tgt_len = tgt_len.view(1, len(tgt_len)) 108 | # 109 | # optimizer.zero_grad() 110 | # 111 | # output = model(src, tgt, src_len[0], tgt_len[0]) 112 | # pred = output['pred'] 113 | # output_dim = pred.shape[-1] 114 | # 115 | # pred = pred.contiguous().view(-1, output_dim) 116 | # tgt = tgt.view(-1) 117 | # 118 | # loss = criterion(pred, tgt) 119 | # 120 | # loss.backward() 121 | # 122 | # torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 123 | # 124 | # optimizer.step() 125 | # 126 | # epoch_loss += loss.item() 127 | # return epoch_loss / len(iterator) 128 | # 129 | # def evaluate(model, iterator, criterion): 130 | # model.eval() 131 | # epoch_loss = 0.0 132 | # with torch.no_grad(): 133 | # for i, batch in enumerate(iterator): 134 | # src_info = batch.src 135 | # tgt_info = batch.trg 136 | # 137 | # src = src_info[0] 138 | # src_len = src_info[1].cpu() 139 | # src_len = src_len.view(1, len(src_len)) 140 | # 141 | # tgt = tgt_info[0] 142 | # tgt_len = tgt_info[1].cpu() 143 | # tgt_len = tgt_len.view(1, len(tgt_len)) 144 | # 145 | # output = model(src, tgt, src_len[0], tgt_len[0]) 146 | # pred = output['pred'] 147 | # output_dim = pred.shape[-1] 148 | # 149 | # pred = pred.contiguous().view(-1, output_dim) 150 | # tgt = tgt.view(-1) 151 | # 152 | # loss = criterion(pred, tgt) 153 | # 154 | # epoch_loss += loss.item() 155 | # return epoch_loss / len(iterator) 156 | # 157 | # EPOCHS = 10 158 | # CLIP = 1 159 | 160 | -------------------------------------------------------------------------------- /LinearKernel/LinearKernel/causal_product/causal_product_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | /** 5 | * Compute a*b^T and save it into out. 
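 * (causal_dot_product below uses this to accumulate the running sum of k_l * v_l^T.)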
6 | * 7 | * a \in R^A 8 | * b \in R^B 9 | */ 10 | inline void vvt_dot(float *a, float *b, float *out, int A, int B) { 11 | for (int i=0; i(); 91 | auto ka = keys.accessor(); 92 | auto va = values.accessor(); 93 | auto pa = product.accessor(); 94 | 95 | #pragma omp parallel for collapse(2) 96 | for (int n=0; n(); 100 | for (int l=0; l(); 145 | auto ka = keys.accessor(); 146 | auto va = values.accessor(); 147 | auto ga = grad_out.accessor(); 148 | auto gqa = grad_queries.accessor(); 149 | auto gka = grad_keys.accessor(); 150 | auto gva = grad_values.accessor(); 151 | 152 | #pragma omp parallel for collapse(2) 153 | for (int n=0; n(); 157 | 158 | // Compute the gradient wrt the queries 159 | for (int l=0; l=0; l--) { 179 | vvt_dot( 180 | &qa[n][h][l][0], 181 | &ga[n][h][l][0], 182 | kvp, 183 | E, 184 | M 185 | ); 186 | vmt_dot( 187 | &va[n][h][l][0], 188 | kvp, 189 | &gka[n][h][l][0], 190 | E, 191 | M 192 | ); 193 | vm_dot( 194 | &ka[n][h][l][0], 195 | kvp, 196 | &gva[n][h][l][0], 197 | E, 198 | M 199 | ); 200 | } 201 | } 202 | } 203 | } 204 | 205 | 206 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 207 | m.def( 208 | "causal_dot_product", 209 | &causal_dot_product, 210 | "Compute the weighted sum of values but attending only to previous " 211 | "values." 212 | ); 213 | m.def( 214 | "causal_dot_backward", 215 | &causal_dot_backward, 216 | "Compute the gradient of queries, keys and values given the gradient " 217 | "of causal_dot_product." 218 | ); 219 | } 220 | -------------------------------------------------------------------------------- /modules/causal_linear_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from modules.feature_map import elu_feature_map 4 | from LinearKernel.causal_product import causal_dot_product 5 | from fastNLP.modules.decoder.seq2seq_state import TransformerState 6 | from modules.state import LinearTransformerState 7 | 8 | def permute_for_matrix(matrix: torch.Tensor): 9 | assert matrix.dim() == 4 10 | return matrix.contiguous().permute(0, 2, 1, 3) 11 | 12 | def causal_linear(Q, K, V): 13 | Q = permute_for_matrix(Q) 14 | K = permute_for_matrix(K) 15 | V = permute_for_matrix(V) 16 | V_new = causal_dot_product(Q, K, V) 17 | return permute_for_matrix(V_new) 18 | 19 | 20 | class CausalLinearAttention(nn.Module): 21 | """Implement causally masked attention using dot product of feature maps in 22 | O(N D^2) complexity. 23 | See fast_transformers.attention.linear_attention.LinearAttention for the 24 | general concept of replacing the softmax with feature maps. In addition to 25 | that, we also make use of the fact that causal masking is a triangular mask 26 | which allows us to apply the masking and still compute the attention in O(N 27 | D^2) complexity. 
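    Concretely, the forward pass below computes, per head and position i,
        out_i = phi(q_i)^T ( sum_{j<=i} phi(k_j) v_j^T ) / ( phi(q_i) . sum_{j<=i} phi(k_j) + eps ),
    where the numerator is evaluated by the custom causal_dot_product kernel and the
    denominator by a cumulative sum over the feature-mapped keys.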
28 | Arguments 29 | --------- 30 | feature_map: callable, a callable that applies the feature map to the 31 | last dimension of a tensor (default: elu(x)+1) 32 | eps: float, a small number to ensure the numerical stability of the 33 | denominator (default: 1e-6) 34 | """ 35 | def __init__(self, query_dimensions, feature_map=None, eps=1e-6): 36 | super(CausalLinearAttention, self).__init__() 37 | self.feature_map = ( 38 | feature_map(query_dimensions) if feature_map else 39 | elu_feature_map(query_dimensions) 40 | ) 41 | self.eps = eps 42 | # self.cnt = 0 43 | 44 | def _make_sizes_compatible(self, Q, K): 45 | """Either slice or pad K in case that the sizes do not match between Q 46 | and K.""" 47 | N, L, H, E = Q.shape 48 | _, S, _, _ = K.shape 49 | if L == S: 50 | return Q, K 51 | 52 | if L < S: 53 | return Q, K[:, :L, :, :] 54 | 55 | if L > S: 56 | temp = K.new_zeros(N, L-S, H, E) 57 | K = torch.cat([K, temp], dim=1) 58 | return Q, K 59 | 60 | def forward(self, queries, keys, values, key_mask): 61 | # Apply the feature map to the queries and keys 62 | self.feature_map.new_feature_map() 63 | Q = self.feature_map.forward_queries(queries) 64 | K = self.feature_map.forward_keys(keys) 65 | 66 | # if key_mask is not None: 67 | # _key_mask = ~key_mask[:, :, None, None].bool() 68 | # K = K.masked_fill(_key_mask, 0.0) 69 | 70 | # Ensure that Q and K have compatible sizes for the following 71 | # computations, namely L == S 72 | # Q, K = self._make_sizes_compatible(Q, K) 73 | 74 | # Compute the normalizers 75 | Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps) 76 | 77 | # Compute the unnormalized result 78 | V = causal_linear( 79 | Q, 80 | K, 81 | values 82 | ) 83 | V = V * Z[:, :, :, None] 84 | return V 85 | 86 | class CausalLinearAttentionLayer(nn.Module): 87 | def __init__(self, d_model, n_head, layer_idx, feature_map=None, eps=1e-6): 88 | super(CausalLinearAttentionLayer, self).__init__() 89 | self.d_model = d_model 90 | self.n_head = n_head 91 | self.layer_idx = layer_idx 92 | assert d_model % n_head == 0, "d_model should be divisible by n_head" 93 | self.head_dim = d_model // n_head 94 | self.scaling = self.head_dim ** -0.5 95 | self.eps = eps 96 | 97 | self.attention = CausalLinearAttention(self.head_dim) 98 | 99 | self.q_proj = nn.Linear(d_model, d_model) 100 | self.k_proj = nn.Linear(d_model, d_model) 101 | self.v_proj = nn.Linear(d_model, d_model) 102 | self.o_proj = nn.Linear(d_model, d_model) 103 | 104 | self.feature_map = ( 105 | feature_map(self.head_dim) if feature_map else 106 | elu_feature_map(self.head_dim) 107 | ) 108 | 109 | def reset_params(self): 110 | nn.init.xavier_uniform_(self.q_proj.weight) 111 | nn.init.xavier_uniform_(self.k_proj.weight) 112 | nn.init.xavier_uniform_(self.v_proj.weight) 113 | nn.init.xavier_uniform_(self.o_proj.weight) 114 | 115 | def forward(self, querys, keys, values, key_mask=None, state=None): 116 | assert keys.size() == values.size() 117 | if state is not None: 118 | assert self.layer_idx is not None 119 | # qkv_same = querys.data_ptr() == keys.data_ptr() == values.data_ptr() 120 | 121 | q = self.q_proj(querys) 122 | q *= self.scaling 123 | k = self.k_proj(keys) 124 | v = self.v_proj(values) 125 | prev_kv_sum = prev_k_sum = None 126 | 127 | batch_size, length_q, d_model = querys.size() 128 | length_k, length_v = keys.size(1), values.size(1) 129 | 130 | if isinstance(state, LinearTransformerState) and self.training is False: 131 | prev_k_sum = state.decoder_k_sum[self.layer_idx] 132 | prev_kv_sum = 
state.decoder_kv_sum[self.layer_idx] 133 | 134 | Q = q.reshape(batch_size, length_q, self.n_head, self.head_dim) 135 | K = k.reshape(batch_size, length_k, self.n_head, self.head_dim) 136 | V = v.reshape(batch_size, length_v, self.n_head, self.head_dim) 137 | 138 | self.feature_map.new_feature_map() 139 | Q = self.feature_map.forward_queries(Q) 140 | K = self.feature_map.forward_keys(K) 141 | 142 | Q = permute_for_matrix(Q) 143 | K = permute_for_matrix(K) 144 | V = permute_for_matrix(V) 145 | 146 | k_sum = prev_k_sum + K.contiguous().reshape(batch_size, self.n_head, -1) 147 | kv = torch.einsum('...sd,...se->...de', K, V) 148 | kv_sum = prev_kv_sum + kv 149 | 150 | state.decoder_k_sum[self.layer_idx] = k_sum 151 | state.decoder_kv_sum[self.layer_idx] = kv_sum 152 | 153 | Z = 1.0 / torch.einsum('...sd,...d->...s', Q, k_sum + self.eps) 154 | V_new = torch.einsum('...de,...sd,...s->...se', kv_sum, Q, Z) 155 | 156 | output = permute_for_matrix(V_new) 157 | output = output.reshape(batch_size, length_q, -1) 158 | output = self.o_proj(output) 159 | else: 160 | Q = q.contiguous().reshape(batch_size, length_q, self.n_head, self.head_dim) 161 | K = k.contiguous().reshape(batch_size, length_k, self.n_head, self.head_dim) 162 | V = v.contiguous().reshape(batch_size, length_v, self.n_head, self.head_dim) 163 | 164 | output = self.attention(Q, K, V, key_mask) 165 | output = output.contiguous().reshape(batch_size, length_q, -1) 166 | output = self.o_proj(output) 167 | 168 | return output -------------------------------------------------------------------------------- /modules/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from typing import Union, Tuple 5 | import torch.nn.functional as F 6 | from fastNLP.embeddings import StaticEmbedding 7 | from fastNLP.core.utils import seq_len_to_mask 8 | from fastNLP.embeddings.utils import get_embeddings 9 | from modules.state import LinearTransformerState 10 | from modules.self_attention import RecurrentSelfAttentionLayer 11 | from modules.cross_attention import RecurrentCrossAttentionLayer 12 | from modules.linear_attention import LinearMultiHeadAttention 13 | from modules.causal_linear_attention import CausalLinearAttentionLayer 14 | from fastNLP.modules.decoder.seq2seq_decoder import Seq2SeqDecoder 15 | from fastNLP.modules.decoder.seq2seq_state import TransformerState 16 | 17 | class LinearTransformerDecoderLayer(nn.Module): 18 | def __init__(self, d_model=512, n_head=8, dim_ff=2048, dropout=0.1, layer_idx=None): 19 | super().__init__() 20 | self.d_model = d_model 21 | self.n_head = n_head 22 | self.dim_ff = dim_ff 23 | self.dropout = dropout 24 | self.layer_idx = layer_idx 25 | 26 | # self.self_attention = RecurrentSelfAttentionLayer(d_model, n_head, layer_idx) 27 | self.self_attention = CausalLinearAttentionLayer(d_model, n_head, layer_idx) 28 | self.self_attn_layer_norm = nn.LayerNorm(d_model) 29 | 30 | self.cross_attention = LinearMultiHeadAttention(d_model, n_head, dropout, layer_idx) 31 | self.cross_attn_layer_norm = nn.LayerNorm(d_model) 32 | 33 | self.ffn = nn.Sequential(nn.Linear(d_model, dim_ff), 34 | nn.ReLU(), 35 | nn.Dropout(p=dropout), 36 | nn.Linear(dim_ff, d_model), 37 | nn.Dropout(p=dropout)) 38 | 39 | self.final_layer_norm = nn.LayerNorm(d_model) 40 | 41 | def forward(self, x, encoder_output, encoder_mask=None, state=None): 42 | residual = x 43 | x = self.self_attn_layer_norm(x) 44 | x = self.self_attention(querys=x, 45 | keys=x, 46 | 
values=x,
47 |                                 state=state)
48 | 
49 |         x = F.dropout(x, p=self.dropout, training=self.training)
50 |         x = residual + x
51 | 
52 |         residual = x
53 |         x = self.cross_attn_layer_norm(x)
54 |         x = self.cross_attention(query=x,
55 |                                  key=encoder_output,
56 |                                  value=encoder_output,
57 |                                  key_mask=encoder_mask,
58 |                                  state=state)
59 | 
60 |         x = F.dropout(x, p=self.dropout, training=self.training)
61 |         x = residual + x
62 | 
63 |         residual = x
64 |         x = self.final_layer_norm(x)
65 |         x = self.ffn(x)
66 |         x = residual + x
67 | 
68 |         return x
69 | 
70 | class TiedEmbedding(nn.Module):
71 |     """
72 |     Ties the output projection weight to the given embedding weight.
73 | 
74 |     """
75 |     def __init__(self, weight):
76 |         super().__init__()
77 |         self.weight = weight  # vocab_size x embed_size
78 | 
79 |     def forward(self, x):
80 |         """
81 | 
82 |         :param torch.FloatTensor x: bsz x * x embed_size
83 |         :return: torch.FloatTensor bsz x * x vocab_size
84 |         """
85 |         return torch.matmul(x, self.weight.t())
86 | 
87 | def get_binded_decoder_output_embed(embed):
88 |     """
89 |     Given an embedding, return a TiedEmbedding bound to its weight.
90 | 
91 |     :param embed:
92 |     :return:
93 |     """
94 |     if isinstance(embed, StaticEmbedding):
95 |         for idx, map2idx in enumerate(embed.words_to_words):
96 |             assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check: (1) whether the vocabulary " \
97 |                                    "includes words added with `no_create_entry=True`; (2) the StaticEmbedding must not be " \
98 |                                    "initialized with `lower=True` or `min_freq!=1`."
99 |     elif not isinstance(embed, nn.Embedding):
100 |         raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.")
101 | 
102 |     return TiedEmbedding(embed.weight)
103 | 
104 | 
105 | class LinearTransformerDecoder(Seq2SeqDecoder):
106 |     def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]],
107 |                  pos_embed: nn.Module=None, d_model=512, num_layers=6, n_head=8,
108 |                  dim_ff=2048, dropout=0.1, bind_decoder_input_output_embed=True):
109 |         super().__init__()
110 |         self.embed = get_embeddings(embed)
111 |         self.pos_embed = pos_embed
112 | 
113 |         if bind_decoder_input_output_embed:
114 |             self.output_layer = get_binded_decoder_output_embed(self.embed)
115 |         else:
116 |             self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim))
117 |             self.output_layer = TiedEmbedding(self.output_embed.weight)
118 | 
119 |         self.num_layers = num_layers
120 |         self.d_model = d_model
121 |         self.n_head = n_head
122 |         self.dim_ff = dim_ff
123 |         self.dropout = dropout
124 | 
125 |         self.input_fc = nn.Linear(self.embed.embedding_dim, d_model)
126 |         self.layer_stacks = nn.ModuleList([LinearTransformerDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx)
127 |                                            for layer_idx in range(num_layers)])
128 | 
129 |         self.embed_scale = math.sqrt(d_model)
130 |         self.layer_norm = nn.LayerNorm(d_model)
131 |         self.output_fc = nn.Linear(d_model, self.embed.embedding_dim)
132 | 
133 |     def forward(self, tokens, state, return_attention=False):
134 |         encoder_output = state.encoder_output
135 |         encoder_mask = state.encoder_mask
136 | 
137 |         assert state.decode_length < tokens.size(1), "The number of already decoded steps in State must be smaller than the number of input tokens."
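        # Incremental decoding: `state.decode_length` counts how many target positions
        # have already been fed through the decoder. During training it stays 0, so the
        # whole target prefix is processed in one pass; during step-wise generation only
        # the not-yet-consumed suffix of `tokens` is embedded below, and the causal
        # linear-attention layers resume from the running key/value sums stored in `state`.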
138 |         tokens = tokens[:, state.decode_length:]
139 |         device = tokens.device
140 | 
141 |         # x = []
142 |         # if self.training:
143 |         #     tgt_seq_len = tokens.size(1) - 1
144 |         # else:
145 |         #     tgt_seq_len = tokens.size(1)
146 |         # for i in range(tgt_seq_len):
147 |         #     # this part still needs a better design
148 |         #     embed = self.embed(tokens[:, i])
149 |         #     x.append(embed)
150 |         #
151 |         # x = torch.stack(x, dim=1)
152 |         # x = self.embed_scale * x
153 | 
154 |         x = self.embed_scale * self.embed(tokens)
155 |         if self.pos_embed is not None:
156 |             position = torch.arange(state.decode_length, state.decode_length+tokens.size(1)).long().to(device)[None]
157 |             x += self.pos_embed(position)
158 |         x = self.input_fc(x)
159 |         x = F.dropout(x, p=self.dropout, training=self.training)
160 | 
161 |         for layer in self.layer_stacks:
162 |             x = layer(x=x,
163 |                       encoder_output=encoder_output,
164 |                       encoder_mask=encoder_mask,
165 |                       state=state
166 |                       )
167 | 
168 |         state.decode_length += 1
169 | 
170 |         x = self.layer_norm(x)
171 |         x = self.output_fc(x)
172 |         feats = self.output_layer(x)
173 | 
174 |         return feats
175 | 
176 |     def init_state(self, encoder_output, encoder_mask):
177 |         if isinstance(encoder_output, torch.Tensor):
178 |             encoder_output = encoder_output
179 |         elif isinstance(encoder_output, (list, tuple)):
180 |             encoder_output = encoder_output[0]  # in case this is the output of an LSTMEncoder
181 |         else:
182 |             raise TypeError("Unsupported `encoder_output` for LinearTransformerDecoder")
183 |         state = LinearTransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers)
184 |         return state
185 | 
186 | 
187 | 
--------------------------------------------------------------------------------
/data/iwslt17_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import spacy
4 | from numpy import *
5 | from fastNLP.io.loader import Loader
6 | from fastNLP.io.pipe import Pipe
7 | from fastNLP.io import DataBundle
8 | from fastNLP.core.vocabulary import Vocabulary
9 | from fastNLP.core.dataset import DataSet
10 | from fastNLP.core.instance import Instance
11 | 
12 | spacy_de = spacy.load('de')
13 | spacy_en = spacy.load('en')
14 | 
15 | class IWSLT2017Loader(Loader):
16 |     def __init__(self):
17 |         super().__init__()
18 | 
19 |     def load_train_data(self, path_en, path_de) -> DataSet:
20 |         dataset = DataSet()
21 |         with open(path_en, encoding='utf-8', mode='r') as en_corpus:
22 |             with open(path_de, encoding='utf-8', mode='r') as de_corpus:
23 |                 cnt = 0
24 |                 en_sentences, de_sentences = '', ''
25 |                 for line_en, line_de in zip(en_corpus, de_corpus):
26 |                     line_en = line_en.strip()
27 |                     line_de = line_de.strip()
28 |                     if line_en == '' or line_de == '':
29 |                         continue
30 |                     if line_en.startswith('<') or line_de.startswith('<'):
31 |                         if line_en.startswith(''
34 |                             dataset.append(instance=Instance(src_raw=de_sentences, tgt_raw=en_sentences))
35 |                             en_sentences = ''
36 |                             de_sentences = ''
37 |                         continue
38 |                     cnt += 1
39 |                     en_sentences += (' ' + line_en)
40 |                     de_sentences += (' ' + line_de)
41 |                     if cnt == 60:
42 |                         cnt = 0
43 |                         en_sentences = ' ' + en_sentences + ' '
44 |                         dataset.append(instance=Instance(src_raw=de_sentences, tgt_raw=en_sentences))
45 |                         en_sentences = ''
46 |                         de_sentences = ''
47 |         return dataset
48 | 
49 |     def load_dev_test(self, path_en, path_de) -> DataSet:
50 |         dataset = DataSet()
51 |         with open(path_en, encoding='utf-8', mode='r') as en_corpus:
52 |             with open(path_de, encoding='utf-8', mode='r') as de_corpus:
53 |                 seg_id, cnt = 0, 0
54 |                 en_sentences, de_sentences = '', ''
55 |                 for line_en, line_de
in zip(en_corpus, de_corpus): 56 | line_en = line_en.strip() 57 | line_de = line_de.strip() 58 | # if line_en == '' or line_de == '': 59 | # continue 60 | if line_en.startswith('') and line_de.startswith(''): 61 | if en_sentences != '': 62 | seg_id, cnt = 0, 0 63 | en_sentences = ' ' + en_sentences + ' ' 64 | dataset.append(instance=Instance(src_raw=de_sentences, tgt_raw=en_sentences)) 65 | en_sentences = '' 66 | de_sentences = '' 67 | 68 | if line_en.startswith('' 72 | line_en = line_en.lstrip(strip_str) 73 | line_en = line_en.rstrip('') 74 | line_en = line_en.strip() 75 | 76 | line_de = line_de.lstrip(strip_str) 77 | line_de = line_de.rstrip('') 78 | line_de = line_de.strip() 79 | 80 | en_sentences += (' ' + line_en) 81 | de_sentences += (' ' + line_de) 82 | 83 | if cnt == 60: 84 | cnt = 0 85 | en_sentences = ' ' + en_sentences + ' ' 86 | dataset.append(instance=Instance(src_raw=de_sentences, tgt_raw=en_sentences)) 87 | en_sentences = '' 88 | de_sentences = '' 89 | return dataset 90 | 91 | def load(self, paths=None) -> DataBundle: 92 | datasets = {} 93 | path_en = '../IWSLT2017/train.tags.de-en.en' 94 | path_de = '../IWSLT2017/train.tags.de-en.de' 95 | train_dataset = self.load_train_data(path_en, path_de) 96 | datasets['train'] = train_dataset 97 | 98 | path_en = '../IWSLT2017/IWSLT17.TED.dev2010.de-en.en.xml' 99 | path_de = '../IWSLT2017/IWSLT17.TED.dev2010.de-en.de.xml' 100 | dev_dataset = self.load_dev_test(path_en, path_de) 101 | datasets['dev'] = dev_dataset 102 | 103 | path_en = '../IWSLT2017/IWSLT17.TED.tst2010.de-en.en.xml' 104 | path_de = '../IWSLT2017/IWSLT17.TED.tst2010.de-en.de.xml' 105 | test_dataset = self.load_dev_test(path_en, path_de) 106 | datasets['test'] = test_dataset 107 | 108 | data_bundle = DataBundle(datasets=datasets) 109 | 110 | return data_bundle 111 | 112 | 113 | class IWSLT2017Pipe(Pipe): 114 | def __init__(self): 115 | super().__init__() 116 | 117 | def _tokenize(self, data_bundle): 118 | for name, dataset in data_bundle.datasets.items(): 119 | 120 | def tokenize_de(raw_text): 121 | output = [(token.text).lower() for token in spacy_de.tokenizer(raw_text)] 122 | return output 123 | dataset.apply_field(tokenize_de, field_name='src_raw', 124 | new_field_name='src_tokens') 125 | def tokenize_en(raw_text): 126 | output = [(token.text).lower() for token in spacy_en.tokenizer(raw_text)] 127 | return output 128 | dataset.apply_field(tokenize_en, field_name='tgt_raw', 129 | new_field_name='tgt_tokens') 130 | 131 | dataset.apply_field(lambda x: len(x), field_name='src_tokens', 132 | new_field_name='src_seq_len') 133 | dataset.apply_field(lambda x: len(x), field_name='tgt_tokens', 134 | new_field_name='tgt_seq_len') 135 | 136 | return data_bundle 137 | 138 | 139 | def process(self, data_bundle: DataBundle) -> DataBundle: 140 | self._tokenize(data_bundle) 141 | 142 | fields = ['src_tokens', 'tgt_tokens'] 143 | for field in fields: 144 | if field == 'src_tokens': 145 | vocab = Vocabulary() 146 | else: 147 | vocab = Vocabulary(unknown='') 148 | vocab.from_dataset(data_bundle.get_dataset('train'), 149 | field_name=field, 150 | no_create_entry_dataset=[data_bundle.get_dataset('dev'), 151 | data_bundle.get_dataset('test')] 152 | ) 153 | vocab.index_dataset(*data_bundle.datasets.values(), field_name=field) 154 | data_bundle.set_vocab(vocab, field) 155 | 156 | # def padding(seq): 157 | # length = len(seq) 158 | # seq = seq + [0] * (2000 - length) 159 | # return seq 160 | # 161 | # data_bundle.apply_field(padding, field_name='src_tokens', new_field_name='src_tokens') 162 | 
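        # `set_input`/`set_target` below tell fastNLP which fields are passed to the
        # model's forward() and which ones the loss/metric receives as targets;
        # src_seq_len and tgt_seq_len are kept because the encoder, decoder and MTLoss
        # build their padding masks from them.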
163 | data_bundle.set_input('src_tokens', 'tgt_tokens', 'src_seq_len', 'tgt_seq_len') 164 | data_bundle.set_target('tgt_tokens', 'tgt_seq_len') 165 | 166 | return data_bundle 167 | 168 | def process_from_file(self, paths=None) -> DataBundle: 169 | if os.path.exists(paths): 170 | data_bundle = torch.load(paths) 171 | else: 172 | data_bundle = IWSLT2017Loader().load(paths) 173 | data_bundle = self.process(data_bundle) 174 | torch.save(data_bundle, paths) 175 | return data_bundle 176 | 177 | if __name__ == '__main__': 178 | save_path = '../IWSLT2017/processed_data_bundle_large.pt' 179 | data_bundle = IWSLT2017Pipe().process_from_file(save_path) 180 | print(data_bundle) 181 | len_train = [seq_len for seq_len in data_bundle.get_dataset('train')['tgt_seq_len']] 182 | mean_len_train = mean(len_train) 183 | len_dev = [seq_len for seq_len in data_bundle.get_dataset('dev')['tgt_seq_len']] 184 | mean_len_dev = mean(len_dev) 185 | len_test = [seq_len for seq_len in data_bundle.get_dataset('test')['tgt_seq_len']] 186 | mean_len_test = mean(len_test) 187 | print("The average length of train data is", int(mean_len_train)) 188 | print("The average length of dev data is", int(mean_len_dev)) 189 | print("The average length of test data is", int(mean_len_test)) 190 | print(mean([mean_len_train, mean_len_dev, mean_len_test])) -------------------------------------------------------------------------------- /LinearKernel/LinearKernel/causal_product/causal_product_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | typedef torch::PackedTensorAccessor32 float_accessor; 4 | 5 | __device__ void get_result( 6 | const float_accessor queries, 7 | const float_accessor keys, 8 | const float_accessor values, 9 | float_accessor kv, 10 | float_accessor result, 11 | const int n, 12 | const int h, 13 | const int e, 14 | const int m, 15 | const int L 16 | ) { 17 | for (int l=0; l= N) || (e >= E)) { 86 | return; 87 | } 88 | shared_kv[threadIdx.x] = kv[n][h][e][m]; 89 | for (int t=0; t>>( 145 | queries.packed_accessor32(), 146 | keys.packed_accessor32(), 147 | values.packed_accessor32(), 148 | kv.packed_accessor32(), 149 | product.packed_accessor32(), 150 | N, H, L, E, M, E_per_block, blocks_per_sequence, T, l_offset 151 | ); 152 | } 153 | } 154 | 155 | 156 | 157 | // we need shared memory to store 158 | // Forward direction 159 | // keys, values, gradout 160 | // kv, results 161 | // Backward direction 162 | // queries, gradout, values 163 | // kv_backwards, results 164 | // Shared memory usage 165 | // Forward 166 | // keys: E*T, (values, gradout): M_per_block*T, kv:E*M_per_block, results:E 167 | // Backward 168 | // queries: E*T, (values, gradout): M_per_block*T, kv:E*M_per_block, results:E 169 | // Total memory: 170 | __global__ void causal_dot_backward_query_key_kernel( 171 | const float_accessor queries, 172 | const float_accessor keys, 173 | const float_accessor values, 174 | const float_accessor grad_out, 175 | float_accessor kv, 176 | float_accessor kv_backwards, 177 | float_accessor grad_queries, 178 | float_accessor grad_keys, 179 | int N, 180 | int H, 181 | int L, 182 | int E, 183 | int M, 184 | const int M_per_block, 185 | const int blocks_per_sequence, 186 | const int T, 187 | const int l_offset 188 | ) { 189 | 190 | const int sequence_index = blockIdx.x / blocks_per_sequence; 191 | int n = sequence_index / H; 192 | int h = sequence_index % H; 193 | 194 | int m_local = threadIdx.x / E; 195 | int m_start = ((blockIdx.x % blocks_per_sequence)*M_per_block); 
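    // Thread/block layout: each CUDA block handles one (sequence n, head h) pair and a
    // slice of M_per_block value dimensions; within the block, threadIdx.x enumerates
    // (m_local, e) pairs, so every thread owns a single entry of the running E x M
    // key-value state for both the forward and the reversed (backward-in-time) pass.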
196 | int m = m_start + m_local; 197 | int e = threadIdx.x % E; 198 | 199 | // Load the shared memory 200 | // Forward memory 201 | // keys: E*T, (values, gradout): M_per_block*T, kv:E*M_per_block, results:E 202 | // Backward memory 203 | // queries: E*T, (values, gradout): M_per_block*T, kv:E*M_per_block, results:E 204 | // Load the shared memory for KV 205 | extern __shared__ float shared_mem[]; 206 | const int shared_kv_size = M_per_block * E; 207 | float* shared_kv = shared_mem; 208 | float* shared_kv_bw = shared_mem + shared_kv_size; 209 | float* shared_results = shared_kv_bw + shared_kv_size; 210 | float* shared_results_bw = shared_results + E; 211 | float* shared_keys = shared_results_bw + E; 212 | float* shared_values = shared_keys + E*T; 213 | float* shared_gradout = shared_values + M_per_block*T; 214 | float* shared_queries_bw = shared_gradout + M_per_block*T; 215 | float* shared_values_bw = shared_queries_bw + E*T; 216 | float* shared_gradout_bw = shared_values_bw + M_per_block*T; 217 | 218 | if (threadIdx.x < E) { 219 | shared_results[threadIdx.x] = 0.0; 220 | shared_results_bw[threadIdx.x] = 0.0; 221 | } 222 | 223 | int t_end = (T + l_offset) <= L ? T : (L - l_offset); 224 | for (int i = threadIdx.x; i < (t_end*M_per_block); i += blockDim.x) 225 | { 226 | int t = int(i / M_per_block) + l_offset; 227 | int t_bw = L - t - 1; 228 | int d = (i % M_per_block) + m_start; 229 | if (d < M) { 230 | shared_values[i] = values[n][h][t][d]; 231 | shared_gradout[i] = grad_out[n][h][t][d]; 232 | shared_values_bw[i] = values[n][h][t_bw][d]; 233 | shared_gradout_bw[i] = grad_out[n][h][t_bw][d]; 234 | } 235 | } 236 | for (int i = threadIdx.x; i < (t_end*E); i += blockDim.x) 237 | { 238 | int t = int(i / E) + l_offset; 239 | int t_bw = L - t - 1; 240 | int d = (i % E); 241 | shared_keys[i] = keys[n][h][t][d]; 242 | shared_queries_bw[i] = queries[n][h][t_bw][d]; 243 | } 244 | __syncthreads(); 245 | 246 | if ((n >= N) || (m >= M)) { 247 | return; 248 | } 249 | 250 | shared_kv[threadIdx.x] = kv[n][h][e][m]; 251 | shared_kv_bw[threadIdx.x] = kv_backwards[n][h][e][m]; 252 | 253 | for (int t=0; t= N) || (e >= E)){ 353 | return; 354 | } 355 | 356 | shared_kv[threadIdx.x] = kv[n][h][e][m]; 357 | for (int t=0; t>>( 421 | queries.packed_accessor32(), 422 | keys.packed_accessor32(), 423 | values.packed_accessor32(), 424 | grad_out.packed_accessor32(), 425 | kv.packed_accessor32(), 426 | kv_backward.packed_accessor32(), 427 | grad_queries.packed_accessor32(), 428 | grad_keys.packed_accessor32(), 429 | N, H, L, E, M, M_per_block, blocks_per_sequence, T, l_offset 430 | ); 431 | } 432 | 433 | int MPB = min(threads, E*M); 434 | // make sure that MUL_PER_BLOCK is divisible by M; 435 | MPB = int(MPB / M) * M; 436 | const int blocks_per_sequence_value = ((E*M) + MPB - 1)/ MPB; 437 | const int E_per_block = MPB / M; 438 | const int blocks_value = N*H*blocks_per_sequence_value; 439 | 440 | shared_mem_const = (E_per_block + 1)*M; 441 | shared_mem_per_time = (M + 2*E_per_block); 442 | T = int(((12 * 1024) - shared_mem_const) / shared_mem_per_time); 443 | const int shared_mem_v_backward = ((T*shared_mem_per_time) + shared_mem_const) * sizeof(float); 444 | kv.zero_(); 445 | for (int l_offset=0; l_offset < L; l_offset += T) { 446 | causal_dot_backward_value_kernel 447 | <<>>( 448 | queries.packed_accessor32(), 449 | keys.packed_accessor32(), 450 | values.packed_accessor32(), 451 | grad_out.packed_accessor32(), 452 | kv.packed_accessor32(), 453 | grad_keys.packed_accessor32(), 454 | grad_values.packed_accessor32(), 455 | 
N, H, L, E, M, E_per_block, blocks_per_sequence_value, T, l_offset 456 | ); 457 | } 458 | } 459 | 460 | 461 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 462 | m.def( 463 | "causal_dot_product", 464 | &causal_dot_product, 465 | "Compute the weighted sum of values but attending only to previous " 466 | "values." 467 | ); 468 | m.def( 469 | "causal_dot_backward", 470 | &causal_dot_backward, 471 | "Compute the gradients for the causal dot product." 472 | ); 473 | } --------------------------------------------------------------------------------
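The CPU and CUDA extensions above accumulate the same quantity: for every query position l they compute sum over t <= l of (q_l . k_t) * v_t, using a running E x M key-value state. The snippet below is a minimal pure-PyTorch sketch of that computation for checking the compiled kernels on small inputs; it is O(L^2) in time and memory, so it is only a test reference. The file name `reference_causal_product.py` and the helper name are illustrative, not part of the repository, and the commented check at the end assumes the compiled `LinearKernel.causal_product` extension is importable.

# reference_causal_product.py -- slow correctness reference, not part of the repository.
import torch


def causal_dot_product_reference(Q, K, V):
    """Return R with R[n, h, l] = sum_{t <= l} (Q[n, h, l] . K[n, h, t]) * V[n, h, t].

    Q, K: (N, H, L, E); V: (N, H, L, M). Quadratic in L, kept simple on purpose.
    """
    scores = torch.einsum('nhle,nhte->nhlt', Q, K)           # all pairwise dot products
    causal = torch.tril(torch.ones(Q.shape[2], Q.shape[2],   # keep only t <= l
                                   device=Q.device, dtype=Q.dtype))
    return torch.einsum('nhlt,nhtm->nhlm', scores * causal, V)


if __name__ == '__main__':
    torch.manual_seed(0)
    N, H, L, E, M = 2, 4, 16, 8, 8
    Q = torch.rand(N, H, L, E)
    K = torch.rand(N, H, L, E)
    V = torch.rand(N, H, L, M)
    ref = causal_dot_product_reference(Q, K, V)
    print(ref.shape)  # torch.Size([2, 4, 16, 8])
    # If the compiled extension is available, the two should match closely:
    # from LinearKernel.causal_product import causal_dot_product
    # assert torch.allclose(causal_dot_product(Q, K, V), ref, atol=1e-5)

Because this reference avoids the prefix-sum trick entirely, it can also serve as a baseline for gradient-checking the custom backward kernels, e.g. with `torch.autograd.gradcheck` on small double-precision inputs.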