├── __init__.py ├── longformer ├── __init__.py ├── lib │ └── lib_diagonaled_mm_float32_cuda.so ├── sliding_chunks.py ├── longformer.py └── diagonaled_mm_tvm.py ├── requirements.txt ├── .gitignore ├── predict.py ├── README.md └── classification.py /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /longformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | transformers==3.4.0 3 | tensorboardX 4 | pytorch-lightning==1.0.3 5 | -------------------------------------------------------------------------------- /longformer/lib/lib_diagonaled_mm_float32_cuda.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCHENLIU/longformer-chinese/HEAD/longformer/lib/lib_diagonaled_mm_float32_cuda.so -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | longformer-chinese-base-4096/ 7 | data/ 8 | models/ 9 | *.txt 10 | # Distribution / packaging 11 | longformer-chinese-base-4096/ 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from classification import LongformerClassifier, ClassificationDataset 2 | from collections import namedtuple 3 | from transformers import BertTokenizer 4 | import torch, time, sys 5 | Argument = namedtuple("Argument", ['test_checkpoint', 'num_labels', "test_file", "sequence_length"]) 6 | args = Argument(test_checkpoint='models/version_0/checkpoints/ep-epoch=0_acc-acc=0.915.ckpt', 7 | num_labels=18, 8 | test_file = "data/train.txt", 9 | sequence_length = 4096) 10 | 11 | class LongformerClassify(): 12 | def __init__(self, mask_padding_with_zero=True, map_generate_from_file=True): 13 | self.data = [] 14 | self._tokenizer = BertTokenizer.from_pretrained('longformer-chinese-base-4096/') 15 | self._tokenizer.model_max_length = args.sequence_length 16 | self.mask_padding_with_zero = mask_padding_with_zero 17 | self.seqlen = args.sequence_length 18 | if map_generate_from_file: 19 | self.produce_label_map() 20 | self.device=torch.device("cuda:0"if torch.cuda.is_available() else "cpu") 21 | self.model = LongformerClassifier.load_from_checkpoint(args.test_checkpoint, num_labels=args.num_labels) 22 | self.model.to(self.device) 23 | 24 | def produce_label_map(self): 25 | data = [] 26 | with open(args.test_file, encoding='UTF-8') as fin: 27 | for i, line in enumerate(fin): 28 | items = line.strip().split('\tSEP\t') 29 | if len(items) != 10: continue 30 | data.append({ 31 | "text": items[0]+items[1], 32 | "label": items[5] 33 | }) 34 | all_labels = list(set([e["label"] for e in data])) 35 | self.label_to_idx = {e: i for i, e in enumerate(sorted(all_labels))} 36 | self.idx_to_label = {v: k for k, v in self.label_to_idx.items()} 37 | print(self.label_to_idx) 38 | 39 | def _convert_to_tensors(self, instance): 40 | def tok(s): 41 | return self._tokenizer.tokenize(s) 42 | tokens = [self._tokenizer.cls_token] + tok(instance["text"]) 43 | token_ids = self._tokenizer.convert_tokens_to_ids(tokens) 44 | token_ids = token_ids[:self.seqlen-1] +[self._tokenizer.sep_token_id] 45 | input_len = len(token_ids) 46 | attention_mask = [1 if self.mask_padding_with_zero else 0] * input_len 47 | padding_length = self.seqlen - input_len 48 | token_ids = token_ids + ([self._tokenizer.pad_token_id] * padding_length) 49 | attention_mask = attention_mask + ([0 if self.mask_padding_with_zero else 1] * padding_length) 50 | assert len(token_ids) == self.seqlen, "Error with input length {} vs {}".format( 51 | len(token_ids), self.seqlen 52 | ) 53 | assert len(attention_mask) == self.seqlen, "Error with input length {} vs {}".format( 54 | len(attention_mask), self.seqlen 55 | ) 56 | label = self.label_to_idx[instance["label"]] if instance["label"] else 0 57 | return (torch.tensor([token_ids]).to(self.device), torch.tensor([attention_mask]).to(self.device), 
torch.tensor([label]).to(self.device)) 58 | 59 | def predict(self, text): 60 | instance = {"text": text, "label": None} 61 | token_ids, attention_mask, label = self._convert_to_tensors(instance=instance) 62 | logits = self.model(token_ids, attention_mask, label)[0] 63 | softmax = torch.nn.Softmax(dim=0) 64 | probabilities = softmax(logits.squeeze()) 65 | res_list = sorted(zip(range(args.num_labels), probabilities.tolist()), key=lambda a: a[1], reverse=True) 66 | res_list = [(self.idx_to_label[i],str(j)) for i,j in res_list] 67 | return res_list[:5] 68 | 69 | 70 | 71 | 72 | 73 | if __name__ == '__main__': 74 | classifier = LongformerClassify() 75 | input_file = sys.argv[1] 76 | fw = open("result.txt", "w") 77 | with open(input_file, "r") as fr: 78 | for item in fr: 79 | item = item.strip().split("\t|SEP|\t") 80 | if len(item) != 10: 81 | continue 82 | res = [] 83 | for r in classifier.predict(item[8]+item[9]): 84 | res.extend(r) 85 | fw.write('\t|SEP|\t'.join(item+res)+'\n') 86 | fw.flush() 87 | fw.close() 88 | # for i in range(100): 89 | # begin = time.time() 90 | # print(classifier.predict("泰安市政协十三届三十四次 主席会议召开")) 91 | # print(time.time() - begin) 92 | 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #
`Longformer-chinese`
2 | All work is based on `Longformer`(https://github.com/allenai/longformer) 3 | 4 | `Longformer-chinese` provides: a BERT-based Chinese pretrained model and an implementation for classification tasks 5 | 6 | ### WHAT'S DIFFERENT 7 | 8 | `Longformer-chinese` is modified on top of the BERT framework, so its embedding layer differs slightly from the original version. When loading, use longformer.longformer: 9 | 10 | ``` 11 | from longformer.longformer import * 12 | config = LongformerConfig.from_pretrained('schen/longformer-chinese-base-4096') 13 | model = Longformer.from_pretrained('schen/longformer-chinese-base-4096', config=config) 14 | ``` 15 | 16 | 17 | Using `schen/longformer-chinese-base-4096` will automatically download the pretrained model through transformers; you can also download it yourself and replace the name with the directory where it is stored: 18 | https://huggingface.co/schen/longformer-chinese-base-4096 19 | 20 | ### How to use 21 | 22 | 1. Download pretrained model 23 | * [`longformer-base-4096`](https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-base-4096.tar.gz) 24 | * [`longformer-large-4096`](https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-large-4096.tar.gz) 25 | 26 | 2. Install environment and code 27 | 28 | ```bash 29 | conda create --name longformer python=3.7 30 | conda activate longformer 31 | conda install cudatoolkit=10.0 32 | pip install git+https://github.com/allenai/longformer.git 33 | ``` 34 | 35 | 3. Run the model 36 | 37 | ```python 38 | import torch 39 | from longformer.longformer import Longformer, LongformerConfig 40 | from longformer.sliding_chunks import pad_to_window_size 41 | from transformers import RobertaTokenizer 42 | 43 | config = LongformerConfig.from_pretrained('longformer-base-4096/') 44 | # choose the attention mode 'n2', 'tvm' or 'sliding_chunks' 45 | # 'n2': for regular n2 attention 46 | # 'tvm': a custom CUDA kernel implementation of our sliding window attention 47 | # 'sliding_chunks': a PyTorch implementation of our sliding window attention 48 | config.attention_mode = 'sliding_chunks' 49 | 50 | model = Longformer.from_pretrained('longformer-base-4096/', config=config) 51 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 52 | tokenizer.model_max_length = model.config.max_position_embeddings 53 | 54 | SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document 55 | 56 | input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1 57 | 58 | # TVM code doesn't work on CPU. Uncomment this if `config.attention_mode = 'tvm'` 59 | # model = model.cuda(); input_ids = input_ids.cuda() 60 | 61 | # Attention mask values -- 0: no attention, 1: local attention, 2: global attention 62 | attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention 63 | attention_mask[:, [1, 4, 21,]] = 2 # Set global attention based on the task. For example, 64 | # classification: the <s> token 65 | # QA: question tokens 66 | 67 | # padding seqlen to the nearest multiple of 512. Needed for the 'sliding_chunks' attention 68 | input_ids, attention_mask = pad_to_window_size( 69 | input_ids, attention_mask, config.attention_window[0], tokenizer.pad_token_id) 70 | 71 | output = model(input_ids, attention_mask=attention_mask)[0] 72 | ``` 73 | 74 | ### Model pretraining 75 | 76 | [This notebook](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) demonstrates our procedure for training Longformer starting from the RoBERTa checkpoint. The same procedure can be followed to get a long-version of other existing pretrained models.
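Once such a long-document checkpoint exists for Chinese (as with `schen/longformer-chinese-base-4096` above), inference follows the same pattern as the English example, with `BertTokenizer` in place of `RobertaTokenizer`. The snippet below is only a minimal sketch: it assumes the checkpoint name (Hugging Face hub id or a local directory) also provides a compatible BERT vocabulary, and it gives the first (`[CLS]`) token global attention the same way `classification.py` does.

```python
import torch
from transformers import BertTokenizer
from longformer.longformer import Longformer, LongformerConfig
from longformer.sliding_chunks import pad_to_window_size

config = LongformerConfig.from_pretrained('schen/longformer-chinese-base-4096')
config.attention_mode = 'sliding_chunks'
model = Longformer.from_pretrained('schen/longformer-chinese-base-4096', config=config)
tokenizer = BertTokenizer.from_pretrained('schen/longformer-chinese-base-4096')
tokenizer.model_max_length = model.config.max_position_embeddings

SAMPLE_TEXT = '这是一篇很长的中文文档。' * 300  # long input document (placeholder text)
input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0)  # batch of size 1

# Attention mask values -- 0: no attention, 1: local attention, 2: global attention
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
attention_mask[:, 0] = 2  # global attention on the first ([CLS]) token, as in classification.py

# pad seqlen to a multiple of the attention window, required by 'sliding_chunks'
input_ids, attention_mask = pad_to_window_size(
    input_ids, attention_mask, config.attention_window[0], tokenizer.pad_token_id)

output = model(input_ids, attention_mask=attention_mask)[0]
```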
77 | 78 | ### TriviaQA 79 | 80 | * Training scripts: `scripts/triviaqa.py` 81 | * Pretrained large model: [`here`](https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/triviaqa-longformer-large.tar.gz) (replicates leaderboard results) 82 | * Instructions: `scripts/cheatsheet.txt` 83 | 84 | 85 | ### CUDA kernel 86 | 87 | Our custom CUDA kernel is implemented in TVM. For now, the kernel only works on GPUs and Linux. We tested it on Ubuntu, Python 3.7, CUDA10, PyTorch >= 1.2.0. If it doesn't work for your environment, please create a new issue. 88 | 89 | **Compiling the kernel**: We already include the compiled binaries of the CUDA kernel, so most users won't need to compile it, but if you are intersted, check `scripts/cheatsheet.txt` for instructions. 90 | 91 | 92 | ### Known issues 93 | 94 | Please check the repo [issues](https://github.com/allenai/longformer/issues) for a list of known issues that we are planning to address soon. If your issue is not discussed, please create a new one. 95 | 96 | 97 | ### Citing 98 | 99 | If you use `Longformer` in your research, please cite [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150). 100 | ``` 101 | @article{Beltagy2020Longformer, 102 | title={Longformer: The Long-Document Transformer}, 103 | author={Iz Beltagy and Matthew E. Peters and Arman Cohan}, 104 | journal={arXiv:2004.05150}, 105 | year={2020}, 106 | } 107 | ``` 108 | 109 | `Longformer` is an open-source project developed by [the Allen Institute for Artificial Intelligence (AI2)](http://www.allenai.org). 110 | AI2 is a non-profit institute with the mission to contribute to humanity through high-impact AI research and engineering. 111 | -------------------------------------------------------------------------------- /longformer/sliding_chunks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from longformer.diagonaled_mm_tvm import mask_invalid_locations 4 | 5 | 6 | def _skew(x, direction, padding_value): 7 | '''Convert diagonals into columns (or columns into diagonals depending on `direction`''' 8 | x_padded = F.pad(x, direction, value=padding_value) 9 | x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2)) 10 | return x_padded 11 | 12 | 13 | def _skew2(x, padding_value): 14 | '''shift every row 1 step to right converting columns into diagonals''' 15 | # X = B x C x M x L 16 | B, C, M, L = x.size() 17 | x = F.pad(x, (0, M + 1), value=padding_value) # B x C x M x (L+M+1) 18 | x = x.view(B, C, -1) # B x C x ML+MM+M 19 | x = x[:, :, :-M] # B x C x ML+MM 20 | x = x.view(B, C, M, M + L) # B x C, M x L+M 21 | x = x[:, :, :, :-1] 22 | return x 23 | 24 | 25 | def _chunk(x, w): 26 | '''convert into overlapping chunkings. Chunk size = 2w, overlap size = w''' 27 | 28 | # non-overlapping chunks of size = 2w 29 | x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2)) 30 | 31 | # use `as_strided` to make the chunks overlap with an overlap size = w 32 | chunk_size = list(x.size()) 33 | chunk_size[1] = chunk_size[1] * 2 - 1 34 | 35 | chunk_stride = list(x.stride()) 36 | chunk_stride[1] = chunk_stride[1] // 2 37 | return x.as_strided(size=chunk_size, stride=chunk_stride) 38 | 39 | 40 | def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float): 41 | '''Matrix multiplicatio of query x key tensors using with a sliding window attention pattern. 
42 | This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) 43 | with an overlap of size w''' 44 | bsz, seqlen, num_heads, head_dim = q.size() 45 | assert seqlen % (w * 2) == 0 46 | assert q.size() == k.size() 47 | 48 | chunks_count = seqlen // w - 1 49 | 50 | # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2 51 | q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim) 52 | k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim) 53 | 54 | chunk_q = _chunk(q, w) 55 | chunk_k = _chunk(k, w) 56 | 57 | # matrix multipication 58 | # bcxd: bsz*num_heads x chunks x 2w x head_dim 59 | # bcyd: bsz*num_heads x chunks x 2w x head_dim 60 | # bcxy: bsz*num_heads x chunks x 2w x 2w 61 | chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k)) # multiply 62 | 63 | # convert diagonals into columns 64 | diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value) 65 | 66 | # allocate space for the overall attention matrix where the chunks are compined. The last dimension 67 | # has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to 68 | # w previous words). The following column is attention score from each word to itself, then 69 | # followed by w columns for the upper triangle. 70 | 71 | diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1)) 72 | 73 | # copy parts from diagonal_chunk_attn into the compined matrix of attentions 74 | # - copying the main diagonal and the upper triangle 75 | diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1] 76 | diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1] 77 | # - copying the lower triangle 78 | diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:] 79 | diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:] 80 | 81 | # separate bsz and num_heads dimensions again 82 | diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1).transpose(2, 1) 83 | 84 | mask_invalid_locations(diagonal_attn, w, 1, False) 85 | return diagonal_attn 86 | 87 | 88 | def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int): 89 | '''Same as sliding_chunks_matmul_qk but for prob and value tensors. 
It is expecting the same output 90 | format from sliding_chunks_matmul_qk''' 91 | bsz, seqlen, num_heads, head_dim = v.size() 92 | assert seqlen % (w * 2) == 0 93 | assert prob.size()[:3] == v.size()[:3] 94 | assert prob.size(3) == 2 * w + 1 95 | chunks_count = seqlen // w - 1 96 | # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size 2w 97 | chunk_prob = prob.transpose(1, 2).reshape(bsz * num_heads, seqlen // w, w, 2 * w + 1) 98 | 99 | # group bsz and num_heads dimensions into one 100 | v = v.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim) 101 | 102 | # pad seqlen with w at the beginning of the sequence and another w at the end 103 | padded_v = F.pad(v, (0, 0, w, w), value=-1) 104 | 105 | # chunk padded_v into chunks of size 3w and an overlap of size w 106 | chunk_v_size = (bsz * num_heads, chunks_count + 1, 3 * w, head_dim) 107 | chunk_v_stride = padded_v.stride() 108 | chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2] 109 | chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride) 110 | 111 | skewed_prob = _skew2(chunk_prob, padding_value=0) 112 | 113 | context = torch.einsum('bcwd,bcdh->bcwh', (skewed_prob, chunk_v)) 114 | return context.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2) 115 | 116 | 117 | def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor, 118 | one_sided_window_size: int, pad_token_id: int): 119 | '''A helper function to pad tokens and mask to work with the sliding_chunks implementation of Longformer selfattention. 120 | Input: 121 | input_ids = torch.Tensor(bsz x seqlen): ids of wordpieces 122 | attention_mask = torch.Tensor(bsz x seqlen): attention mask 123 | one_sided_window_size = int: window size on one side of each token 124 | pad_token_id = int: tokenizer.pad_token_id 125 | Returns 126 | (input_ids, attention_mask) padded to length divisible by 2 * one_sided_window_size 127 | ''' 128 | w = 2 * one_sided_window_size 129 | seqlen = input_ids.size(1) 130 | padding_len = (w - seqlen % w) % w 131 | input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id) 132 | attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens 133 | return input_ids, attention_mask 134 | 135 | 136 | # ========= "sliding_chunks_no_overlap": alternative implemenation of the sliding window attention ========= 137 | # This implementation uses non-overlapping chunks (or blocks) of size `w` with number of local attention = 3xw 138 | # To make this implemenation comparable to "sliding_chunks" set w such that 139 | # w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3 140 | # For example, 141 | # w_of_sliding_chunks = 256 (this is one sided. Total attention size = 512) 142 | # w_of_sliding_chunks_no_overlap = 170 (Total attention size = 510) 143 | # Performance: 144 | # - Speed: 30% faster than "sliding_chunks" 145 | # - Memory: 95% of the memory usage of "sliding_chunks" 146 | # The windows are asymmetric where number of attention on each side of a token ranges between w to 2w 147 | # while "sliding_chunks" has a symmetric window around each token. 
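# A minimal usage sketch, assuming q, k, v of shape (bsz, seqlen, num_heads, head_dim) with seqlen
# divisible by w; this mirrors how LongformerSelfAttention in longformer.py combines the two helpers below:
#   attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, w, padding_value=0)  # (bsz, seqlen, num_heads, 3 * w)
#   attn_probs = torch.nn.functional.softmax(attn_weights, dim=-1)
#   context = sliding_chunks_no_overlap_matmul_pv(attn_probs, v, w)               # (bsz, seqlen, num_heads, head_dim)
# (longformer.py additionally masks invalid locations before the softmax.)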
148 | 149 | 150 | def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float): 151 | bsz, seqlen, num_heads, head_dim = q.size() 152 | assert seqlen % w == 0 153 | assert q.size() == k.size() 154 | # chunk seqlen into non-overlapping chunks of size w 155 | chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim) 156 | chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim) 157 | chunk_k_expanded = torch.stack(( 158 | F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0), 159 | chunk_k, 160 | F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0), 161 | ), dim=-1) 162 | diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded)) # multiply 163 | return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w) 164 | 165 | 166 | def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int): 167 | bsz, seqlen, num_heads, head_dim = v.size() 168 | chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w) 169 | chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim) 170 | chunk_v_extended = torch.stack(( 171 | F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0), 172 | chunk_v, 173 | F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0), 174 | ), dim=-1) 175 | context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended)) 176 | return context.reshape(bsz, seqlen, num_heads, head_dim) 177 | -------------------------------------------------------------------------------- /longformer/longformer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import math 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations 7 | from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv 8 | from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv 9 | from transformers.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM 10 | from transformers.modeling_bert import BertConfig, BertModel, BertForMaskedLM 11 | 12 | 13 | class Longformer(BertModel): 14 | def __init__(self, config): 15 | super(Longformer, self).__init__(config) 16 | if config.attention_mode == 'n2': 17 | pass # do nothing, use BertSelfAttention instead 18 | else: 19 | for i, layer in enumerate(self.encoder.layer): 20 | layer.attention.self = LongformerSelfAttention(config, layer_id=i) 21 | 22 | 23 | class LongformerForMaskedLM(BertForMaskedLM): 24 | def __init__(self, config): 25 | super(LongformerForMaskedLM, self).__init__(config) 26 | if config.attention_mode == 'n2': 27 | pass # do nothing, use BertSelfAttention instead 28 | else: 29 | for i, layer in enumerate(self.bert.encoder.layer): 30 | layer.attention.self = LongformerSelfAttention(config, layer_id=i) 31 | 32 | 33 | class LongformerConfig(BertConfig): 34 | def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None, 35 | autoregressive: bool = False, attention_mode: str = 'sliding_chunks', **kwargs): 36 | """ 37 | Args: 38 | attention_window: list of attention window sizes of length = number of layers. 39 | window size = number of attention locations on each side. 40 | For an affective window size of 512, use `attention_window=[256]*num_layers` 41 | which is 256 on each side. 
42 | attention_dilation: list of attention dilation of length = number of layers. 43 | attention dilation of `1` means no dilation. 44 | autoregressive: do autoregressive attention or have attention of both sides 45 | attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implemenation of Longformer 46 | selfattention, 'sliding_chunks' for another implementation of Longformer selfattention 47 | """ 48 | super().__init__(**kwargs) 49 | self.attention_window = attention_window 50 | self.attention_dilation = attention_dilation 51 | self.autoregressive = autoregressive 52 | self.attention_mode = attention_mode 53 | assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap'] 54 | 55 | 56 | class LongformerSelfAttention(nn.Module): 57 | def __init__(self, config, layer_id): 58 | super(LongformerSelfAttention, self).__init__() 59 | if config.hidden_size % config.num_attention_heads != 0: 60 | raise ValueError( 61 | "The hidden size (%d) is not a multiple of the number of attention " 62 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 63 | self.output_attentions = config.output_attentions 64 | self.num_heads = config.num_attention_heads 65 | self.head_dim = int(config.hidden_size / config.num_attention_heads) 66 | self.embed_dim = config.hidden_size 67 | 68 | self.query = nn.Linear(config.hidden_size, self.embed_dim) 69 | self.key = nn.Linear(config.hidden_size, self.embed_dim) 70 | self.value = nn.Linear(config.hidden_size, self.embed_dim) 71 | 72 | self.query_global = nn.Linear(config.hidden_size, self.embed_dim) 73 | self.key_global = nn.Linear(config.hidden_size, self.embed_dim) 74 | self.value_global = nn.Linear(config.hidden_size, self.embed_dim) 75 | 76 | self.dropout = config.attention_probs_dropout_prob 77 | 78 | self.layer_id = layer_id 79 | self.attention_window = config.attention_window[self.layer_id] 80 | # self.attention_dilation = config.attention_dilation[self.layer_id] 81 | self.attention_dilation = 1 82 | self.attention_mode = config.attention_mode 83 | self.autoregressive = config.autoregressive 84 | assert self.attention_window > 0 85 | assert self.attention_dilation > 0 86 | assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap'] 87 | if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']: 88 | assert not self.autoregressive # not supported 89 | assert self.attention_dilation == 1 # dilation is not supported 90 | 91 | def forward( 92 | self, 93 | hidden_states, 94 | attention_mask=None, 95 | head_mask=None, 96 | encoder_hidden_states=None, 97 | encoder_attention_mask=None, 98 | output_attentions=False, 99 | ): 100 | ''' 101 | The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to 102 | -ve: no attention 103 | 0: local attention 104 | +ve: global attention 105 | ''' 106 | assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None" 107 | assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and shiould be None" 108 | 109 | if attention_mask is not None: 110 | attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) 111 | key_padding_mask = attention_mask < 0 112 | extra_attention_mask = attention_mask > 0 113 | remove_from_windowed_attention_mask = attention_mask != 0 114 | 115 | num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1) 116 | max_num_extra_indices_per_batch = num_extra_indices_per_batch.max() 117 | if max_num_extra_indices_per_batch <= 0: 118 | 
extra_attention_mask = None 119 | else: 120 | # To support the case of variable number of global attention in the rows of a batch, 121 | # we use the following three selection masks to select global attention embeddings 122 | # in a 3d tensor and pad it to `max_num_extra_indices_per_batch` 123 | # 1) selecting embeddings that correspond to global attention 124 | extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True) 125 | zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch, 126 | device=num_extra_indices_per_batch.device) 127 | # mask indicating which values are actually going to be padding 128 | selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1) 129 | # 2) location of the non-padding values in the selected global attention 130 | selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True) 131 | # 3) location of the padding values in the selected global attention 132 | selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True) 133 | else: 134 | remove_from_windowed_attention_mask = None 135 | extra_attention_mask = None 136 | key_padding_mask = None 137 | 138 | hidden_states = hidden_states.transpose(0, 1) 139 | seq_len, bsz, embed_dim = hidden_states.size() 140 | assert embed_dim == self.embed_dim 141 | q = self.query(hidden_states) 142 | k = self.key(hidden_states) 143 | v = self.value(hidden_states) 144 | q /= math.sqrt(self.head_dim) 145 | 146 | q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) 147 | k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) 148 | # attn_weights = (bsz, seq_len, num_heads, window*2+1) 149 | if self.attention_mode == 'tvm': 150 | q = q.float().contiguous() 151 | k = k.float().contiguous() 152 | attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False) 153 | elif self.attention_mode == "sliding_chunks": 154 | attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0) 155 | elif self.attention_mode == "sliding_chunks_no_overlap": 156 | attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0) 157 | else: 158 | raise False 159 | mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False) 160 | if remove_from_windowed_attention_mask is not None: 161 | # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 162 | # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size) 163 | remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1) 164 | # cast to float/half then replace 1's with -inf 165 | float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0) 166 | repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation) 167 | float_mask = float_mask.repeat(1, 1, repeat_size, 1) 168 | ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones 169 | # diagonal mask with zeros everywhere and -inf inplace of padding 170 | if self.attention_mode == 'tvm': 171 | d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False) 172 | elif self.attention_mode == "sliding_chunks": 173 | d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) 174 | elif self.attention_mode == 
"sliding_chunks_no_overlap": 175 | d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) 176 | 177 | attn_weights += d_mask 178 | assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads] 179 | assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3] 180 | 181 | # the extra attention 182 | if extra_attention_mask is not None: 183 | selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) 184 | selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros] 185 | # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch) 186 | selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k)) 187 | selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000 188 | # concat to attn_weights 189 | # (bsz, seq_len, num_heads, extra attention count + 2*window+1) 190 | attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) 191 | 192 | attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability 193 | if key_padding_mask is not None: 194 | # softmax sometimes inserts NaN if all positions are masked, replace them with 0 195 | attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) 196 | 197 | attn_weights = attn_weights_float.type_as(attn_weights) 198 | attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) 199 | v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) 200 | attn = 0 201 | if extra_attention_mask is not None: 202 | selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch) 203 | selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) 204 | selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros] 205 | # use `matmul` because `einsum` crashes sometimes with fp16 206 | # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) 207 | attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2) 208 | attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous() 209 | 210 | if self.attention_mode == 'tvm': 211 | v = v.float().contiguous() 212 | attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False) 213 | elif self.attention_mode == "sliding_chunks": 214 | attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window) 215 | elif self.attention_mode == "sliding_chunks_no_overlap": 216 | attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window) 217 | else: 218 | raise False 219 | 220 | attn = attn.type_as(hidden_states) 221 | assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim] 222 | attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous() 223 | 224 | # For this case, we'll just recompute the attention for these indices 225 | # and overwrite the attn tensor. 
TODO: remove the redundant computation 226 | if extra_attention_mask is not None: 227 | selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim) 228 | selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]] 229 | 230 | q = self.query_global(selected_hidden_states) 231 | k = self.key_global(hidden_states) 232 | v = self.value_global(hidden_states) 233 | q /= math.sqrt(self.head_dim) 234 | 235 | q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1) # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim) 236 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim) 237 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim) 238 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 239 | assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len] 240 | 241 | attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) 242 | attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 243 | if key_padding_mask is not None: 244 | attn_weights = attn_weights.masked_fill( 245 | key_padding_mask.unsqueeze(1).unsqueeze(2), 246 | -10000.0, 247 | ) 248 | attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len) 249 | attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability 250 | attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) 251 | selected_attn = torch.bmm(attn_probs, v) 252 | assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim] 253 | 254 | selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim) 255 | nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]] 256 | attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states) 257 | 258 | context_layer = attn.transpose(0, 1) 259 | if output_attentions: 260 | if extra_attention_mask is not None: 261 | # With global attention, return global attention probabilities only 262 | # batch_size x num_heads x max_num_global_attention_tokens x sequence_length 263 | # which is the attention weights from tokens with global attention to all tokens 264 | # It doesn't not return local attention 265 | # In case of variable number of global attantion in the rows of a batch, 266 | # attn_weights are padded with -10000.0 attention scores 267 | attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) 268 | else: 269 | # without global attention, return local attention probabilities 270 | # batch_size x num_heads x sequence_length x window_size 271 | # which is the attention weights of every token attending to its neighbours 272 | attn_weights = attn_weights.permute(0, 2, 1, 3) 273 | outputs = (context_layer, attn_weights) if output_attentions else (context_layer,) 274 | return outputs 275 | -------------------------------------------------------------------------------- /longformer/diagonaled_mm_tvm.py: 
-------------------------------------------------------------------------------- 1 | from typing import Union 2 | from functools import lru_cache 3 | 4 | import torch 5 | import os.path 6 | 7 | 8 | class DiagonaledMM(torch.autograd.Function): 9 | '''Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling 10 | this function from PyTorch 11 | ''' 12 | 13 | function_dict = {} # save a list of functions, each has a different set of parameters 14 | 15 | @staticmethod 16 | def _compile_function(dtype: str, device: str, b0: int = 4, b1: int = 4, b2: int = 16): 17 | '''Compiles a tvm function that computes diagonal_mm 18 | args: 19 | dtype: str in ['float64', 'float32', 'float16'] 20 | device: str in ['cpu' or 'cuda'] 21 | b0, b1, b2: size of tensor tiles. Very important for good performance 22 | 23 | ''' 24 | import tvm # import the full tvm library here for compilation. Don't import at the top of the file in case we don't need to compile 25 | from tvm.contrib import nvcc 26 | @tvm.register_func 27 | def tvm_callback_cuda_compile(code): 28 | """Use nvcc compiler for better perf.""" 29 | ptx = nvcc.compile_cuda(code, target="ptx", arch='sm_52') # use old arch for this to work on old GPUs 30 | return ptx 31 | 32 | assert dtype in ['float16', 'float32', 'float64'] 33 | assert device in ['cpu', 'cuda'] 34 | device = None if device == 'cpu' else device 35 | tgt_host="llvm" 36 | 37 | b = tvm.var('b') # batch size 38 | n = tvm.var('n') # sequence length 39 | h = tvm.var('h') # number of heads 40 | m = tvm.var('m') # hidden dimension 41 | w = tvm.var('w') # window size 42 | w_upper = tvm.var('w_upper') # window size to the right of the word. Should be `0` or `w` 43 | padding = tvm.var('padding') # padding 44 | transpose_t1 = tvm.var('transpose_t1') # t1 should be transposed 45 | t1d3 = tvm.var('t1d3') # last dimension of t1 46 | t3d3 = tvm.var('t3d3') # last dimension of t3 (the result tensor) 47 | X = tvm.placeholder((b, n, h, t1d3), name='X', dtype=dtype) # first tensor 48 | Y = tvm.placeholder((b, n, h, m), name='Y', dtype=dtype) # second tensor 49 | k = tvm.reduce_axis((0, t1d3), name='k') # dimension to sum over 50 | D = tvm.placeholder((h), name='D', dtype='int') # dilation per head 51 | output_shape = (b, n, h, t3d3) # shape of the result tensor 52 | algorithm = lambda l, i, q, j: tvm.sum( 53 | tvm.if_then_else( 54 | t3d3 == m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3) 55 | tvm.if_then_else( 56 | transpose_t1 == 0, 57 | tvm.if_then_else( 58 | tvm.all( 59 | i + D[q] * (k - w) >= 0, 60 | i + D[q] * (k - w) < n, 61 | ), 62 | X[l, i, q, k] * Y[l, i + D[q] * (k - w), q, j], # t1 is diagonaled 63 | padding 64 | ), 65 | tvm.if_then_else( 66 | tvm.all( 67 | i + D[q] * (k - w_upper) >= 0, # `w_upper` to handle the case `autoregressive=True` 68 | i + D[q] * (k - w_upper) < n, 69 | ), 70 | X[l, i + D[q] * (k - w_upper), q, (w_upper + w) - k] * Y[l, i + D[q] * (k - w_upper), q, j], # # t1 is diagonaled and should be transposed 71 | padding 72 | ), 73 | ), 74 | tvm.if_then_else( 75 | tvm.all( 76 | i + D[q] * (j - w) >= 0, 77 | i + D[q] * (j - w) < n, 78 | ), 79 | X[l, i, q, k] * Y[l, i + D[q] * (j - w), q, k], # t1 is not diagonaled, but the output tensor is going to be 80 | padding 81 | ) 82 | ), axis=k) 83 | 84 | Z = tvm.compute(output_shape, algorithm, name='Z') # automatically generate cuda code 85 | s = tvm.create_schedule(Z.op) 86 | 87 | print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], 
simple_mode=True))) 88 | 89 | # split long axis into smaller chunks and assing each one to a separate GPU thread/block 90 | ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0) 91 | ZF = s.rfactor(Z, ki) 92 | 93 | j_outer, j_inner = s[Z].split(s[Z].op.axis[-1], factor=b1) 94 | i_outer, i_inner = s[Z].split(s[Z].op.axis[1], factor=b2) 95 | 96 | s[Z].bind(j_outer, tvm.thread_axis("blockIdx.x")) 97 | s[Z].bind(j_inner, tvm.thread_axis("threadIdx.y")) 98 | 99 | s[Z].bind(i_outer, tvm.thread_axis("blockIdx.y")) 100 | s[Z].bind(i_inner, tvm.thread_axis("threadIdx.z")) 101 | 102 | tx = tvm.thread_axis("threadIdx.x") 103 | s[Z].bind(s[Z].op.reduce_axis[0], tx) 104 | s[ZF].compute_at(s[Z], s[Z].op.reduce_axis[0]) 105 | s[Z].set_store_predicate(tx.var.equal(0)) 106 | 107 | print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True))) 108 | 109 | # compiling the automatically generated cuda code 110 | diagonaled_mm = tvm.build(s, [X, Y, Z, D, w, w_upper, padding, transpose_t1, t3d3], target=device, target_host=tgt_host, name='diagonaled_mm') 111 | return diagonaled_mm 112 | 113 | @staticmethod 114 | def _get_lib_filename(dtype: str, device: str): 115 | base_filename = 'longformer/lib/lib_diagonaled_mm' 116 | return '{}_{}_{}.so'.format(base_filename, dtype, device) 117 | 118 | @staticmethod 119 | def _save_compiled_function(f, dtype: str, device: str): 120 | if not os.path.exists('longformer/lib/'): 121 | os.makedirs('longformer/lib/') 122 | f.export_library(DiagonaledMM._get_lib_filename(dtype, device)) 123 | 124 | @staticmethod 125 | def _load_compiled_function(dtype: str, device: str): 126 | from tvm.module import load # this can be the small runtime python library, and doesn't need to be the whole thing 127 | filename = DiagonaledMM._get_lib_filename(dtype, device) 128 | current_dir = os.path.dirname(os.path.abspath(__file__)) 129 | potential_dirs = ['../../', '../', './', f'{current_dir}/', f'{current_dir}/../'] 130 | for potential_dir in potential_dirs: 131 | filepath = '{}{}'.format(potential_dir, filename) 132 | if os.path.isfile(filepath): 133 | print('Loading tvm binary from: {}'.format(filepath)) 134 | return load(filepath) 135 | return None 136 | 137 | @staticmethod 138 | def _get_function(dtype: str, device: str): 139 | '''Loads the function from the disk or compile it''' 140 | # A list of arguments that define the function 141 | args = (dtype, device) 142 | if args not in DiagonaledMM.function_dict: 143 | diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device) # try to load from disk 144 | if not diagonaled_mm: 145 | print('Tvm binary not found. Compiling ...') 146 | diagonaled_mm = DiagonaledMM._compile_function(dtype, device) # compile 147 | DiagonaledMM._save_compiled_function(diagonaled_mm, dtype, device) # save to disk 148 | # convert the tvm function into a pytorch function 149 | from tvm.contrib import dlpack 150 | diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm) # wrap it as a pytorch function 151 | # save the function into a dictionary to be reused 152 | DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch # save it in a dictionary for next time 153 | return DiagonaledMM.function_dict[args] 154 | 155 | @staticmethod 156 | def _diagonaled_mm(t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int], 157 | is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0, 158 | autoregressive: bool = False): 159 | '''Calls the compiled function after checking the input format. 
This function is called in three different modes. 160 | t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores 161 | t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compuate attantion_scores x value = context 162 | t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of 163 | the calculations in the backward pass. 164 | ''' 165 | dtype = str(t1.dtype).split('.')[1] 166 | device = t1.device.type 167 | assert len(t1.shape) == 4 168 | assert len(t1.shape) == len(t2.shape) 169 | assert t1.shape[:3] == t2.shape[:3] 170 | if isinstance(d, int): # if d is an integer, replace it with a tensor of the same length 171 | # as number of heads, and it is filled with the same dilation value 172 | d = t1.new_full(size=(t1.shape[2],), fill_value=d, dtype=torch.int, requires_grad=False) 173 | 174 | assert len(d.shape) == 1 175 | assert d.shape[0] == t1.shape[2] # number of dilation scores should match number of heads 176 | b = t1.shape[0] # batch size 177 | n = t1.shape[1] # sequence length 178 | h = t1.shape[2] # number of heads 179 | m = t2.shape[3] # hidden dimension 180 | w_upper = 0 if autoregressive else w 181 | c = w_upper + w + 1 # number of diagonals 182 | if is_t1_diagonaled: 183 | assert t1.shape[3] == c 184 | r = t1.new_empty(b, n, h, m) # allocate spase for the result tensor 185 | else: 186 | assert not transpose_t1 187 | assert t1.shape[3] == m 188 | r = t1.new_empty(b, n, h, c) # allocate spase for the result tensor 189 | 190 | # gets function from memory, from disk or compiles it from scratch 191 | _diagonaled_mm_function = DiagonaledMM._get_function(dtype=dtype, device=device) 192 | 193 | # The last argument to this function is a little hacky. It is the size of the last dimension of the result tensor 194 | # We use it as a proxy to tell if t1_is_diagonaled or not (if t1 is diagonaled, result is not, and vice versa). 195 | # The second reason is that the lambda expression in `_compile_function` is easier to express when the shape 196 | # of the output is known 197 | # This functions computes diagonal_mm then saves the result in `r` 198 | if m == c: 199 | # FIXME 200 | print('Error: the hidden dimension {m} shouldn\'t match number of diagonals {c}') 201 | assert False 202 | _diagonaled_mm_function(t1, t2, r, d, w, w_upper, padding, transpose_t1, m if is_t1_diagonaled else c) 203 | return r 204 | 205 | @staticmethod 206 | def _prepare_tensors(t): 207 | '''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch. 208 | For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter. 209 | TVM expects this value to be the `product(t.size()[1:])` but PyTorch some times sets it to `t.stride()[1]`. 210 | Here's an example to reporduce this issue: 211 | import torch 212 | print(torch.randn(1, 10).stride()) 213 | > (10, 1) 214 | print(torch.randn(10, 1).t().contiguous().stride()) 215 | > (1, 1) # expected it to be (10, 1) as above 216 | print(torch.randn(10, 2).t().contiguous().stride()) 217 | > (10, 1) # but gets the expected stride if the first dimension is > 1 218 | ''' 219 | assert t.is_contiguous() 220 | t_stride = list(t.stride()) 221 | t_size = list(t.size()) 222 | # Fix wrong stride information for the first dimension. 
This occures when batch_size=1 223 | if t_size[0] == 1 and t_stride[0] == t_stride[1]: 224 | # In this case, the stride of the first dimension should be the product 225 | # of the sizes of all other dimensions 226 | t_stride[0] = t_size[1] * t_size[2] * t_size[3] 227 | t = t.as_strided(size=t_size, stride=t_stride) 228 | return t 229 | 230 | min_seq_len = 16 # unexpected output if seq_len < 16 231 | 232 | @staticmethod 233 | def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int], is_t1_diagonaled: bool = False, padding: int = 0, autoregressive: bool = False) -> torch.Tensor: 234 | '''Compuates diagonal_mm of t1 and t2. 235 | args: 236 | t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals). 237 | t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`) 238 | t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled 239 | tensor, e.g. `key_layer` or `value_layer` 240 | w: int = window size; number of attentions on each side of the word 241 | d: torch.Tensor or int = dilation of attentions per attention head. If int, the same dilation value will be used for all 242 | heads. If torch.Tensor, it should be 1D of lenth=number of attention heads 243 | is_t1_diagonaled: is t1 a diagonaled or a regular tensor 244 | padding: the padding value to use when accessing invalid locations. This is mainly useful when the padding 245 | needs to be a very large negative value (to compute softmax of attentions). For other usecases, 246 | please use zero padding. 247 | autoregressive: if true, return only the lower triangle 248 | returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals) 249 | if t1 is diagonaed, result is non-diagonaled, and vice versa 250 | ''' 251 | batch_size, seq_len, num_attention_heads, hidden_size = t1.size() 252 | assert seq_len >= DiagonaledMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(DiagonaledMM.min_seq_len) # FIXME 253 | ctx.save_for_backward(t1, t2) 254 | ctx.w = w 255 | ctx.d = d 256 | ctx.is_t1_diagonaled = is_t1_diagonaled 257 | ctx.autoregressive = autoregressive 258 | t1 = DiagonaledMM._prepare_tensors(t1) 259 | t2 = DiagonaledMM._prepare_tensors(t2) 260 | # output = t1.mm(t2) # what would have been called if this was a regular matmul 261 | output = DiagonaledMM._diagonaled_mm(t1, t2, w, d, is_t1_diagonaled=is_t1_diagonaled, padding=padding, autoregressive=autoregressive) 262 | return output 263 | 264 | @staticmethod 265 | def backward(ctx, grad_output): 266 | t1, t2 = ctx.saved_tensors 267 | w = ctx.w 268 | d = ctx.d 269 | is_t1_diagonaled = ctx.is_t1_diagonaled 270 | autoregressive = ctx.autoregressive 271 | if not grad_output.is_contiguous(): 272 | grad_output = grad_output.contiguous() # tvm requires all input tensors to be contiguous 273 | grad_output = DiagonaledMM._prepare_tensors(grad_output) 274 | t1 = DiagonaledMM._prepare_tensors(t1) 275 | t2 = DiagonaledMM._prepare_tensors(t2) 276 | # http://cs231n.github.io/optimization-2/ 277 | # https://pytorch.org/docs/master/notes/extending.html 278 | # grad_t1 = grad_output.mm(t2) # what would have been called if this was a regular matmul 279 | grad_t1 = DiagonaledMM._diagonaled_mm(grad_output, t2, w, d, is_t1_diagonaled=not is_t1_diagonaled, autoregressive=autoregressive) 280 | # grad_t2 = grad_output.t().mm(t1) # or `grad_t2 = t1.t().mm(grad_output).t()` because `(AB)^T = B^TA^T` 281 | if 
is_t1_diagonaled: 282 | grad_t2 = DiagonaledMM._diagonaled_mm(t1, grad_output, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive) 283 | else: 284 | grad_t2 = DiagonaledMM._diagonaled_mm(grad_output, t1, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive) 285 | return grad_t1, grad_t2, None, None, None, None, None 286 | 287 | 288 | def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int): 289 | diagonals_list = [] 290 | for j in range(-d * w, d, d): 291 | diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8) 292 | diagonal_mask[:-j] = 1 293 | diagonals_list.append(diagonal_mask) 294 | return torch.stack(diagonals_list, dim=-1) 295 | 296 | @lru_cache() 297 | def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str): 298 | if isinstance(d, int): 299 | affected_seq_len = w * d 300 | mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d) 301 | mask = mask[None, :, None, :] 302 | else: 303 | affected_seq_len = w * d.max() 304 | head_masks = [] 305 | d_list = d.cpu().numpy().tolist() 306 | for d in d_list: 307 | one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d) 308 | head_masks.append(one_head_mask) 309 | mask = torch.stack(head_masks, dim=-2) 310 | mask = mask[None, :, :, :] 311 | 312 | ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device) 313 | return affected_seq_len, mask.bool().to(device), ending_mask 314 | 315 | def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> torch.Tensor: 316 | affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device) 317 | seq_len = input_tensor.size(1) 318 | beginning_input = input_tensor[:, :affected_seq_len, :, :w+1] 319 | beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size()) 320 | beginning_input.masked_fill_(beginning_mask, -float('inf')) 321 | if not autoregressive: 322 | ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):] 323 | ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size()) 324 | ending_input.masked_fill_(ending_mask, -float('inf')) 325 | 326 | 327 | diagonaled_mm = DiagonaledMM.apply 328 | 329 | # The non-tvm implementation is the default, we don't need to load the kernel at loading time. 
330 | # DiagonaledMM._get_function('float32', 'cuda') 331 | -------------------------------------------------------------------------------- /classification.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import json 4 | import os 5 | import random 6 | import argparse 7 | from argparse import Namespace 8 | import numpy as np 9 | import glob 10 | import gzip 11 | 12 | import torch 13 | from torch import nn 14 | from torch.nn import CrossEntropyLoss, MSELoss 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader, Dataset, random_split 17 | import pytorch_lightning as pl 18 | from pytorch_lightning.loggers import TensorBoardLogger 19 | from pytorch_lightning.callbacks import ModelCheckpoint 20 | 21 | import torch.distributed as dist 22 | 23 | from longformer.longformer import Longformer, LongformerConfig 24 | from longformer.sliding_chunks import pad_to_window_size 25 | from transformers import BertTokenizer, AdamW 26 | 27 | from torch.utils.data.dataset import IterableDataset 28 | from tqdm.auto import tqdm 29 | 30 | import logging 31 | logger = logging.getLogger(__name__) 32 | 33 | from transformers.optimization import ( 34 | Adafactor, 35 | get_cosine_schedule_with_warmup, 36 | get_cosine_with_hard_restarts_schedule_with_warmup, 37 | get_linear_schedule_with_warmup, 38 | get_polynomial_decay_schedule_with_warmup, 39 | ) 40 | 41 | 42 | TEXT_FIELD_NAME = 'text' 43 | LABEL_FIELD_NAME = 'label' 44 | 45 | arg_to_scheduler = { 46 | "linear": get_linear_schedule_with_warmup, 47 | "cosine": get_cosine_schedule_with_warmup, 48 | "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, 49 | "polynomial": get_polynomial_decay_schedule_with_warmup, 50 | # '': get_constant_schedule, # not supported for now 51 | # '': get_constant_schedule_with_warmup, # not supported for now 52 | } 53 | arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) 54 | arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}" 55 | 56 | 57 | def calc_f1(y_pred:torch.Tensor, y_true:torch.Tensor) -> torch.Tensor: 58 | tp = (y_true * y_pred).sum().to(torch.float32) 59 | tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32) 60 | fp = ((1 - y_true) * y_pred).sum().to(torch.float32) 61 | fn = (y_true * (1 - y_pred)).sum().to(torch.float32) 62 | epsilon = 1e-7 63 | precision = tp / (tp + fp + epsilon) 64 | recall = tp / (tp + fn + epsilon) 65 | 66 | f1 = 2 * (precision * recall) / (precision + recall + epsilon) 67 | f1 = f1.clamp(min=epsilon, max=1 - epsilon) 68 | return f1 69 | 70 | 71 | class ClassificationDataset(Dataset): 72 | def __init__(self, file_path, tokenizer, seqlen, num_samples=None, mask_padding_with_zero=True): 73 | self.data = [] 74 | with (gzip.open(file_path, 'rt') if file_path.endswith('.gz') else open(file_path)) as fin: 75 | for i, line in enumerate(tqdm(fin, desc=f'loading input file {file_path.split("/")[-1]}', unit_scale=1)): 76 | items = line.strip().split('\tSEP\t') 77 | if len(items) != 10: continue 78 | self.data.append({ 79 | "text": items[0]+items[1], 80 | "label": items[5] 81 | }) 82 | if num_samples and len(self.data) > num_samples: 83 | break 84 | self.seqlen = seqlen 85 | self._tokenizer = tokenizer 86 | all_labels = list(set([e[LABEL_FIELD_NAME] for e in self.data])) 87 | self.label_to_idx = {e: i for i, e in enumerate(sorted(all_labels))} 88 | self.idx_to_label = {v: k for k, v in self.label_to_idx.items()} 89 | self.mask_padding_with_zero = mask_padding_with_zero 90 | 91 
| def __len__(self): 92 | return len(self.data) 93 | 94 | def __getitem__(self, idx): 95 | return self._convert_to_tensors(self.data[idx]) 96 | 97 | def _convert_to_tensors(self, instance): 98 | def tok(s): 99 | return self._tokenizer.tokenize(s) 100 | tokens = [self._tokenizer.cls_token] + tok(instance[TEXT_FIELD_NAME]) 101 | token_ids = self._tokenizer.convert_tokens_to_ids(tokens) 102 | token_ids = token_ids[:self.seqlen-1] +[self._tokenizer.sep_token_id] 103 | input_len = len(token_ids) 104 | attention_mask = [1 if self.mask_padding_with_zero else 0] * input_len 105 | padding_length = self.seqlen - input_len 106 | token_ids = token_ids + ([self._tokenizer.pad_token_id] * padding_length) 107 | 108 | attention_mask = attention_mask + ([0 if self.mask_padding_with_zero else 1] * padding_length) 109 | 110 | assert len(token_ids) == self.seqlen, "Error with input length {} vs {}".format( 111 | len(token_ids), self.seqlen 112 | ) 113 | assert len(attention_mask) == self.seqlen, "Error with input length {} vs {}".format( 114 | len(attention_mask), self.seqlen 115 | ) 116 | 117 | label = self.label_to_idx[instance[LABEL_FIELD_NAME]] 118 | 119 | return (torch.tensor(token_ids), torch.tensor(attention_mask), torch.tensor(label)) 120 | 121 | 122 | class LongformerClassifier(pl.LightningModule): 123 | 124 | def __init__(self, init_args): 125 | super().__init__() 126 | if isinstance(init_args, dict): 127 | # for loading the checkpoint, pl passes a dict (hparams are saved as dict) 128 | init_args = Namespace(**init_args) 129 | config_path = init_args.config_path or init_args.model_dir 130 | checkpoint_path = init_args.checkpoint_path or init_args.model_dir 131 | logger.info(f'loading model from config: {config_path}, checkpoint: {checkpoint_path}') 132 | config = LongformerConfig.from_pretrained(config_path) 133 | config.attention_mode = init_args.attention_mode 134 | logger.info(f'attention mode set to {config.attention_mode}') 135 | self.model_config = config 136 | self.model = Longformer.from_pretrained(checkpoint_path, config=config) 137 | self.tokenizer = BertTokenizer.from_pretrained(init_args.tokenizer) 138 | self.tokenizer.model_max_length = self.model.config.max_position_embeddings 139 | self.hparams = init_args 140 | self.hparams.seqlen = self.model.config.max_position_embeddings 141 | self.classifier = nn.Linear(config.hidden_size, init_args.num_labels) 142 | 143 | def forward(self, input_ids, attention_mask, labels=None): 144 | input_ids, attention_mask = pad_to_window_size( 145 | input_ids, attention_mask, self.model_config.attention_window[0], self.tokenizer.pad_token_id) 146 | attention_mask[:, 0] = 2 # global attention for the first token 147 | #use Bert inner Pooler 148 | output = self.model(input_ids, attention_mask=attention_mask)[1] 149 | # pool the entire sequence into one vector (CLS token) 150 | # output = output[:, 0, :] 151 | logits = self.classifier(output) 152 | 153 | loss = None 154 | if labels is not None: 155 | loss_fct = CrossEntropyLoss() 156 | 157 | loss = loss_fct(logits.view(-1, self.hparams.num_labels), labels.view(-1)) 158 | 159 | return logits, loss 160 | 161 | def _get_loader(self, split, shuffle=True): 162 | if split == 'train': 163 | fname = self.hparams.train_file 164 | elif split == 'dev': 165 | fname = self.hparams.dev_file 166 | elif split == 'test': 167 | fname = self.hparams.test_file 168 | else: 169 | assert False 170 | is_train = split == 'train' 171 | 172 | dataset = ClassificationDataset( 173 | fname, tokenizer=self.tokenizer, 
seqlen=self.hparams.seqlen, num_samples=self.hparams.num_samples 174 | ) 175 | 176 | loader = DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, shuffle=(shuffle and is_train)) 177 | return loader 178 | 179 | def setup(self, mode): 180 | self.train_loader = self._get_loader("train") 181 | 182 | def train_dataloader(self): 183 | return self.train_loader 184 | 185 | def val_dataloader(self): 186 | self.val_dataloader_obj = self._get_loader('dev') 187 | return self.val_dataloader_obj 188 | 189 | def test_dataloader(self): 190 | return self._get_loader('test') 191 | 192 | @property 193 | def total_steps(self) -> int: 194 | """The number of total training steps that will be run. Used for lr scheduler purposes.""" 195 | num_devices = max(1, self.hparams.total_gpus) # TODO: consider num_tpu_cores 196 | effective_batch_size = self.hparams.batch_size * self.hparams.grad_accum * num_devices 197 | dataset_size = len(self.train_loader.dataset) 198 | return (dataset_size / effective_batch_size) * self.hparams.num_epochs 199 | 200 | def get_lr_scheduler(self): 201 | get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] 202 | scheduler = get_schedule_func( 203 | self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps 204 | ) 205 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 206 | return scheduler 207 | 208 | def configure_optimizers(self): 209 | """Prepare optimizer and schedule (linear warmup and decay)""" 210 | model = self.model 211 | no_decay = ["bias", "LayerNorm.weight"] 212 | optimizer_grouped_parameters = [ 213 | { 214 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 215 | "weight_decay": self.hparams.weight_decay, 216 | }, 217 | { 218 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 219 | "weight_decay": 0.0, 220 | }, 221 | ] 222 | if self.hparams.adafactor: 223 | optimizer = Adafactor( 224 | optimizer_grouped_parameters, lr=self.hparams.lr, scale_parameter=False, relative_step=False 225 | ) 226 | 227 | else: 228 | optimizer = AdamW( 229 | optimizer_grouped_parameters, lr=self.hparams.lr, eps=self.hparams.adam_epsilon 230 | ) 231 | self.opt = optimizer 232 | 233 | scheduler = self.get_lr_scheduler() 234 | 235 | return [optimizer], [scheduler] 236 | 237 | def training_step(self, batch, batch_idx): 238 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[2]} 239 | 240 | outputs = self(**inputs) 241 | loss = outputs[1] 242 | 243 | lr_scheduler = self.trainer.lr_schedulers[0]["scheduler"] 244 | tensorboard_logs = {"loss": loss, "rate": lr_scheduler.get_last_lr()[-1]} 245 | return {"loss": loss, "log": tensorboard_logs} 246 | 247 | 248 | def validation_step(self, batch, batch_idx): 249 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[2]} 250 | 251 | outputs = self(**inputs) 252 | logits, tmp_eval_loss = outputs 253 | preds = logits 254 | out_label_ids = inputs["labels"] 255 | return {"val_loss": tmp_eval_loss, "pred": preds, "target": out_label_ids} 256 | 257 | def _eval_end(self, outputs) -> tuple: 258 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 259 | preds = torch.cat([x["pred"] for x in outputs], dim=0) 260 | labels = torch.cat([x["target"] for x in outputs], dim=0) 261 | preds = torch.argmax(preds, axis=-1) 262 | accuracy = (preds == labels).int().sum() / float(torch.tensor(preds.shape[-1], dtype=torch.float32, 
device=labels.device)) 263 | f1 = calc_f1(preds, labels) 264 | if self.trainer.use_ddp: 265 | torch.distributed.all_reduce(avg_loss, op=torch.distributed.ReduceOp.SUM) 266 | avg_loss /= self.trainer.world_size 267 | torch.distributed.all_reduce(accuracy, op=torch.distributed.ReduceOp.SUM) 268 | accuracy /= self.trainer.world_size 269 | torch.distributed.all_reduce(f1, op=torch.distributed.ReduceOp.SUM) 270 | f1 /= self.trainer.world_size 271 | # accuracy = (preds == out_label_ids).int().sum() / float(torch.tensor(preds.shape[0], dtype=torch.float32, device=out_label_ids.device)) 272 | results = {"val_loss": avg_loss, "f1": f1, "acc": accuracy} 273 | 274 | ret = {k: v for k, v in results.items()} 275 | ret["log"] = results 276 | return ret 277 | 278 | def validation_epoch_end(self, outputs: list) -> dict: 279 | ret = self._eval_end(outputs) 280 | logs = ret["log"] 281 | return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs} 282 | 283 | def test_epoch_end(self, outputs) -> dict: 284 | ret = self._eval_end(outputs) 285 | logs = ret["log"] 286 | results = {} 287 | for k, v in logs.items(): 288 | if isinstance(v, torch.Tensor): 289 | results[k] = v.detach().cpu().item() 290 | # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss` 291 | return {"avg_test_loss": logs["val_loss"].detach().cpu().item(), "log": results, "progress_bar": results} 292 | 293 | def test_step(self, batch, batch_nb): 294 | return self.validation_step(batch, batch_nb) 295 | 296 | 297 | def parse_args(): 298 | parser = argparse.ArgumentParser() 299 | parser.add_argument('--model_dir', dest='model_dir', default='longformer-chinese-base-4096/', help='path to the pretrained model directory') 300 | parser.add_argument('--config_path', default=None, help='path to the config (defaults to model_dir if not set)') 301 | parser.add_argument('--checkpoint_path', default=None, help='path to the model weights (defaults to model_dir if not set)') 302 | parser.add_argument('--attention_mode', default='sliding_chunks') 303 | parser.add_argument('--tokenizer', default='longformer-chinese-base-4096/') 304 | parser.add_argument('--train_file') 305 | parser.add_argument('--dev_file') 306 | parser.add_argument('--test_file') 307 | parser.add_argument('--input_dir', default=None, help='optionally provide a directory of the data and train/test/dev files will be automatically detected') 308 | parser.add_argument('--batch_size', default=1, type=int) 309 | parser.add_argument('--grad_accum', default=1, type=int) 310 | parser.add_argument('--gpus', default=1, help='a single count (e.g. 1) or a comma-separated list of gpu ids (e.g. 0,1)') 311 | parser.add_argument('--seed', default=1918, type=int) 312 | parser.add_argument('--fp16', default=False, action='store_true') 313 | parser.add_argument('--test_only', default=False, action='store_true') 314 | parser.add_argument('--test_checkpoint', default=None) 315 | parser.add_argument('--test_percent_check', default=1.0, type=float) 316 | parser.add_argument('--limit_val_batches', default=1.0, type=float) 317 | parser.add_argument('--val_check_interval', default=1.0, type=float) 318 | parser.add_argument('--num_epochs', default=1, type=int) 319 | parser.add_argument('--do_predict', default=False, action='store_true') 320 | parser.add_argument("--lr", type=float, default=2e-5) 321 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 322 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 323 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup
over warmup_steps.") 324 | parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") 325 | parser.add_argument("--adafactor", action="store_true") 326 | parser.add_argument('--save_dir', required=True) 327 | parser.add_argument('--num_labels', default=-1, type=int, 328 | help='if -1, it automatically finds the number of labels.' 329 | ' For larger datasets, precompute this and set it manually.') 330 | parser.add_argument('--num_samples', default=None, type=int) 331 | parser.add_argument("--lr_scheduler", 332 | default="linear", 333 | choices=arg_to_scheduler_choices, 334 | metavar=arg_to_scheduler_metavar, 335 | type=str, 336 | help="Learning rate scheduler") 337 | args = parser.parse_args() 338 | 339 | if args.input_dir is not None: 340 | files = glob.glob(args.input_dir + '/*') 341 | for f in files: 342 | fname = f.split('/')[-1] 343 | if 'train' in fname: 344 | args.train_file = f 345 | elif 'dev' in fname or 'val' in fname: 346 | args.dev_file = f 347 | elif 'test' in fname: 348 | args.test_file = f 349 | return args 350 | 351 | def get_train_params(args): 352 | train_params = {} 353 | train_params["precision"] = 16 if args.fp16 else 32 354 | if (isinstance(args.gpus, int) and args.gpus > 1) or (isinstance(args.gpus, list) and len(args.gpus) > 1): 355 | train_params["distributed_backend"] = "ddp" 356 | else: 357 | train_params["distributed_backend"] = None 358 | train_params["accumulate_grad_batches"] = args.grad_accum 359 | train_params['track_grad_norm'] = -1 360 | train_params['limit_val_batches'] = args.limit_val_batches 361 | train_params['val_check_interval'] = args.val_check_interval 362 | train_params['gpus'] = args.gpus 363 | train_params['max_epochs'] = args.num_epochs 364 | return train_params 365 | 366 | def main(): 367 | args = parse_args() 368 | random.seed(args.seed) 369 | np.random.seed(args.seed) 370 | torch.manual_seed(args.seed) 371 | if torch.cuda.is_available(): 372 | torch.cuda.manual_seed_all(args.seed) 373 | if isinstance(args.gpus, str) and ',' in args.gpus: 374 | args.gpus = list(map(int, args.gpus.split(','))) 375 | args.total_gpus = len(args.gpus) 376 | else: 377 | args.gpus = int(args.gpus) 378 | args.total_gpus = args.gpus 379 | 380 | def infer_num_labels(args): 381 | # Dataset will be constructed inside the model; here we just want to read the labels (seq len doesn't matter here) 382 | ds = ClassificationDataset(args.train_file, tokenizer=args.tokenizer, seqlen=4096) 383 | num_labels = len(ds.label_to_idx) 384 | return num_labels 385 | 386 | if args.test_only: 387 | print('loading model...') 388 | if args.num_labels == -1: 389 | args.num_labels = infer_num_labels(args) 390 | model = LongformerClassifier.load_from_checkpoint(args.test_checkpoint, num_labels=args.num_labels) 391 | trainer = pl.Trainer(gpus=args.gpus, limit_test_batches=args.test_percent_check) 392 | trainer.test(model) 393 | 394 | else: 395 | if args.num_labels == -1: 396 | args.num_labels = infer_num_labels(args) 397 | model = LongformerClassifier(args) 398 | 399 | # default logger used by trainer 400 | logger = TensorBoardLogger( 401 | save_dir=args.save_dir, 402 | version=0, 403 | name='pl-logs' 404 | ) 405 | 406 | # keep the '{epoch}'/'{acc}' placeholders out of the f-string so ModelCheckpoint can format them 407 | filepath = f'{args.save_dir}/version_{logger.version}/checkpoints/' + 'ep-{epoch}_acc-{acc:.3f}' 408 | checkpoint_callback = ModelCheckpoint( 409 | filepath=filepath, 410 | save_top_k=1, 411 | verbose=True, 412 | monitor='acc', 413 | mode='max', 414 | prefix='' 415 | ) 416 | 417 | extra_train_params = get_train_params(args) 418 | 419 |
trainer = pl.Trainer(logger=logger, 420 | checkpoint_callback=checkpoint_callback, 421 | **extra_train_params) 422 | 423 | trainer.fit(model) 424 | 425 | if args.do_predict: 426 | # Optionally, evaluate the best checkpoint on the test set 427 | fpath = glob.glob(checkpoint_callback.dirpath + '/*.ckpt')[0] 428 | model = LongformerClassifier.load_from_checkpoint(fpath) 429 | model.hparams.num_gpus = 1 430 | model.hparams.total_gpus = 1 431 | # keep the hparams restored from the checkpoint; only point them at the local data files 432 | model.hparams.dev_file = args.dev_file 433 | model.hparams.test_file = args.test_file 434 | model.hparams.train_file = args.dev_file # the model won't get trained, pass in the dev file instead to load faster 435 | trainer = pl.Trainer(gpus=1, limit_test_batches=1.0, limit_train_batches=0.01, limit_val_batches=0.01, precision=extra_train_params['precision']) 436 | trainer.test(model) 437 | 438 | if __name__ == '__main__': 439 | main() 440 | --------------------------------------------------------------------------------
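Usage sketch (data format): ClassificationDataset above keeps only lines that split into exactly ten '\tSEP\t'-separated fields, concatenates fields 0 and 1 as the input text, and reads field 5 as the label. The snippet below writes one such line; the field contents and the data/train.txt path are hypothetical placeholders, only the field count and positions matter.

# Write one example line in the format ClassificationDataset expects.
# Field values and the output path are hypothetical placeholders.
fields = ['标题', '正文', 'field2', 'field3', 'field4', '体育', 'field6', 'field7', 'field8', 'field9']
line = '\tSEP\t'.join(fields)                    # exactly ten fields
assert len(line.split('\tSEP\t')) == 10
with open('data/train.txt', 'w', encoding='UTF-8') as fout:
    fout.write(line + '\n')                      # text = fields[0] + fields[1], label = fields[5]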
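Usage sketch (training): the snippet below builds the same hyperparameter Namespace that parse_args() produces and launches training programmatically instead of through the command line. All paths (longformer-chinese-base-4096/, data/*.txt) and the values of num_labels, grad_accum, etc. are assumptions to replace with your own; the field names mirror the arguments consumed by LongformerClassifier and get_train_params above.

from argparse import Namespace
import pytorch_lightning as pl
from classification import LongformerClassifier, get_train_params

# Hypothetical paths and hyperparameters -- adjust to your environment.
args = Namespace(
    model_dir='longformer-chinese-base-4096/', config_path=None, checkpoint_path=None,
    attention_mode='sliding_chunks', tokenizer='longformer-chinese-base-4096/',
    train_file='data/train.txt', dev_file='data/dev.txt', test_file='data/test.txt',
    num_labels=10,  # must equal the number of distinct labels in train_file
    batch_size=1, grad_accum=8, gpus=1, total_gpus=1, fp16=False, num_epochs=1,
    num_samples=None, num_workers=4, lr=2e-5, lr_scheduler='linear', warmup_steps=0,
    weight_decay=0.0, adam_epsilon=1e-8, adafactor=False,
    limit_val_batches=1.0, val_check_interval=1.0,
)

model = LongformerClassifier(args)              # loads config, weights and tokenizer from model_dir
trainer = pl.Trainer(**get_train_params(args))  # precision / gpus / grad accumulation as in main()
trainer.fit(model)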