├── __init__.py
├── longformer
│   ├── __init__.py
│   ├── lib
│   │   └── lib_diagonaled_mm_float32_cuda.so
│   ├── sliding_chunks.py
│   ├── longformer.py
│   └── diagonaled_mm_tvm.py
├── requirements.txt
├── .gitignore
├── predict.py
├── README.md
└── classification.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/longformer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
2 | transformers==3.4.0
3 | tensorboardX
4 | pytorch-lightning==1.0.3
5 |
--------------------------------------------------------------------------------
/longformer/lib/lib_diagonaled_mm_float32_cuda.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCHENLIU/longformer-chinese/HEAD/longformer/lib/lib_diagonaled_mm_float32_cuda.so
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | longformer-chinese-base-4096/
7 | data/
8 | models/
9 | *.txt
10 | # Distribution / packaging
11 | longformer-chinese-base-4096/
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from classification import LongformerClassifier, ClassificationDataset
2 | from collections import namedtuple
3 | from transformers import BertTokenizer
4 | import torch, time, sys
5 | Argument = namedtuple("Argument", ['test_checkpoint', 'num_labels', "test_file", "sequence_length"])
6 | args = Argument(test_checkpoint='models/version_0/checkpoints/ep-epoch=0_acc-acc=0.915.ckpt',
7 | num_labels=18,
8 | test_file = "data/train.txt",
9 | sequence_length = 4096)
10 |
11 | class LongformerClassify():
12 | def __init__(self, mask_padding_with_zero=True, map_generate_from_file=True):
13 | self.data = []
14 | self._tokenizer = BertTokenizer.from_pretrained('longformer-chinese-base-4096/')
15 | self._tokenizer.model_max_length = args.sequence_length
16 | self.mask_padding_with_zero = mask_padding_with_zero
17 | self.seqlen = args.sequence_length
18 | if map_generate_from_file:
19 | self.produce_label_map()
20 |         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 | self.model = LongformerClassifier.load_from_checkpoint(args.test_checkpoint, num_labels=args.num_labels)
22 |         self.model.to(self.device).eval()  # inference only: keep dropout disabled
23 |
24 | def produce_label_map(self):
25 | data = []
26 | with open(args.test_file, encoding='UTF-8') as fin:
27 | for i, line in enumerate(fin):
28 | items = line.strip().split('\tSEP\t')
29 | if len(items) != 10: continue
30 | data.append({
31 | "text": items[0]+items[1],
32 | "label": items[5]
33 | })
34 | all_labels = list(set([e["label"] for e in data]))
35 | self.label_to_idx = {e: i for i, e in enumerate(sorted(all_labels))}
36 | self.idx_to_label = {v: k for k, v in self.label_to_idx.items()}
37 | print(self.label_to_idx)
38 |
39 | def _convert_to_tensors(self, instance):
40 | def tok(s):
41 | return self._tokenizer.tokenize(s)
42 | tokens = [self._tokenizer.cls_token] + tok(instance["text"])
43 | token_ids = self._tokenizer.convert_tokens_to_ids(tokens)
44 | token_ids = token_ids[:self.seqlen-1] +[self._tokenizer.sep_token_id]
45 | input_len = len(token_ids)
46 | attention_mask = [1 if self.mask_padding_with_zero else 0] * input_len
47 | padding_length = self.seqlen - input_len
48 | token_ids = token_ids + ([self._tokenizer.pad_token_id] * padding_length)
49 | attention_mask = attention_mask + ([0 if self.mask_padding_with_zero else 1] * padding_length)
50 | assert len(token_ids) == self.seqlen, "Error with input length {} vs {}".format(
51 | len(token_ids), self.seqlen
52 | )
53 | assert len(attention_mask) == self.seqlen, "Error with input length {} vs {}".format(
54 | len(attention_mask), self.seqlen
55 | )
56 | label = self.label_to_idx[instance["label"]] if instance["label"] else 0
57 | return (torch.tensor([token_ids]).to(self.device), torch.tensor([attention_mask]).to(self.device), torch.tensor([label]).to(self.device))
58 |
59 | def predict(self, text):
60 | instance = {"text": text, "label": None}
61 | token_ids, attention_mask, label = self._convert_to_tensors(instance=instance)
62 | logits = self.model(token_ids, attention_mask, label)[0]
63 | softmax = torch.nn.Softmax(dim=0)
64 | probabilities = softmax(logits.squeeze())
65 | res_list = sorted(zip(range(args.num_labels), probabilities.tolist()), key=lambda a: a[1], reverse=True)
66 | res_list = [(self.idx_to_label[i],str(j)) for i,j in res_list]
67 | return res_list[:5]
68 |
69 |
70 |
71 |
72 |
73 | if __name__ == '__main__':
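    # Usage: python predict.py <input_file>
    # Each input line must hold 10 fields separated by "\t|SEP|\t"; the top-5 predicted
    # (label, probability) pairs are appended to the record and written to result.txt.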
74 | classifier = LongformerClassify()
75 | input_file = sys.argv[1]
76 | fw = open("result.txt", "w")
77 | with open(input_file, "r") as fr:
78 | for item in fr:
79 | item = item.strip().split("\t|SEP|\t")
80 | if len(item) != 10:
81 | continue
82 | res = []
83 | for r in classifier.predict(item[8]+item[9]):
84 | res.extend(r)
85 | fw.write('\t|SEP|\t'.join(item+res)+'\n')
86 | fw.flush()
87 | fw.close()
88 | # for i in range(100):
89 | # begin = time.time()
90 | # print(classifier.predict("泰安市政协十三届三十四次 主席会议召开"))
91 | # print(time.time() - begin)
92 |
93 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `Longformer-chinese`
2 | All work is based on `Longformer` (https://github.com/allenai/longformer).
3 |
4 | `Longformer-chinese` provides: a Chinese pretrained model based on BERT, and an implementation for classification tasks.
5 |
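For a quick look at the classification side, `classification.py` holds the training code and `predict.py` a small inference wrapper. The sketch below is hypothetical in its paths: it assumes the fine-tuned checkpoint and the label-map file hard-coded in `predict.py` exist locally.

```python
from predict import LongformerClassify

# loads the tokenizer, the label map and the fine-tuned checkpoint configured in predict.py
classifier = LongformerClassify()
# returns the top-5 (label, probability) pairs for a piece of Chinese text
print(classifier.predict("泰安市政协十三届三十四次主席会议召开"))
```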
6 | ### WHAT'S DIFFERENT
7 |
8 | `Longformer-chinese` is modified from the BERT framework, so its embedding layer differs slightly from the original Longformer. Load it with `longformer.longformer`:
9 |
10 | ```python
11 | from longformer.longformer import *
12 | config = LongformerConfig.from_pretrained('schen/longformer-chinese-base-4096')
13 | model = Longformer.from_pretrained('schen/longformer-chinese-base-4096', config=config)
14 | ```
15 |
16 |
17 | Using `schen/longformer-chinese-base-4096` will automatically download the pretrained model through transformers; you can also download it yourself and point to the local directory instead:
18 | https://huggingface.co/schen/longformer-chinese-base-4096
19 |
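For example, a minimal sketch assuming the checkpoint has been downloaded into a local `longformer-chinese-base-4096/` directory (the directory name is only an example):

```python
from longformer.longformer import Longformer, LongformerConfig
from transformers import BertTokenizer

model_dir = 'longformer-chinese-base-4096/'  # local copy of config.json, pytorch_model.bin, vocab.txt
config = LongformerConfig.from_pretrained(model_dir)
model = Longformer.from_pretrained(model_dir, config=config)
tokenizer = BertTokenizer.from_pretrained(model_dir)  # the Chinese model uses a BERT tokenizer
```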
20 | ### How to use
21 |
22 | 1. Download pretrained model
23 | * [`longformer-base-4096`](https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-base-4096.tar.gz)
24 | * [`longformer-large-4096`](https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-large-4096.tar.gz)
25 |
26 | 2. Install environment and code
27 |
28 | ```bash
29 | conda create --name longformer python=3.7
30 | conda activate longformer
31 | conda install cudatoolkit=10.0
32 | pip install git+https://github.com/allenai/longformer.git
33 | ```
34 |
35 | 3. Run the model
36 |
37 | ```python
38 | import torch
39 | from longformer.longformer import Longformer, LongformerConfig
40 | from longformer.sliding_chunks import pad_to_window_size
41 | from transformers import RobertaTokenizer
42 |
43 | config = LongformerConfig.from_pretrained('longformer-base-4096/')
44 | # choose the attention mode 'n2', 'tvm' or 'sliding_chunks'
45 | # 'n2': for regular n2 attention
46 | # 'tvm': a custom CUDA kernel implementation of our sliding window attention
47 | # 'sliding_chunks': a PyTorch implementation of our sliding window attention
48 | config.attention_mode = 'sliding_chunks'
49 |
50 | model = Longformer.from_pretrained('longformer-base-4096/', config=config)
51 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
52 | tokenizer.model_max_length = model.config.max_position_embeddings
53 |
54 | SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document
55 |
56 | input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1
57 |
58 | # TVM code doesn't work on CPU. Uncomment this if `config.attention_mode = 'tvm'`
59 | # model = model.cuda(); input_ids = input_ids.cuda()
60 |
61 | # Attention mask values -- 0: no attention, 1: local attention, 2: global attention
62 | attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
63 | attention_mask[:, [1, 4, 21,]] = 2 # Set global attention based on the task. For example,
64 |                                      # classification: the <s> token
65 | # QA: question tokens
66 |
67 | # padding seqlen to the nearest multiple of 512. Needed for the 'sliding_chunks' attention
68 | input_ids, attention_mask = pad_to_window_size(
69 | input_ids, attention_mask, config.attention_window[0], tokenizer.pad_token_id)
70 |
71 | output = model(input_ids, attention_mask=attention_mask)[0]
72 | ```
73 |
74 | ### Model pretraining
75 |
76 | [This notebook](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) demonstrates our procedure for training Longformer starting from the RoBERTa checkpoint. The same procedure can be followed to get a long-version of other existing pretrained models.
77 |
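As a rough illustration of the core step in that notebook (this is not the notebook itself), the learned position embeddings are extended by copying. The sketch below assumes a BERT-style checkpoint with 512 positions being stretched to 4096, and omits the attention-layer changes and the continued MLM pretraining:

```python
import torch
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-chinese')  # example starting checkpoint
new_max_pos = 4096
old_pos = model.embeddings.position_embeddings.weight.data        # (512, hidden_size)
new_pos = old_pos.new_zeros(new_max_pos, old_pos.size(1))
for start in range(0, new_max_pos, old_pos.size(0)):              # tile the 512 learned positions
    new_pos[start:start + old_pos.size(0)] = old_pos
model.embeddings.position_embeddings = torch.nn.Embedding.from_pretrained(new_pos, freeze=False)
# keep the position_ids buffer (if the installed transformers version registers one) in sync
model.embeddings.register_buffer("position_ids", torch.arange(new_max_pos).unsqueeze(0))
model.config.max_position_embeddings = new_max_pos
```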
78 | ### TriviaQA
79 |
80 | * Training scripts: `scripts/triviaqa.py`
81 | * Pretrained large model: [`here`](https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/triviaqa-longformer-large.tar.gz) (replicates leaderboard results)
82 | * Instructions: `scripts/cheatsheet.txt`
83 |
84 |
85 | ### CUDA kernel
86 |
87 | Our custom CUDA kernel is implemented in TVM. For now, the kernel only works on GPUs and Linux. We tested it on Ubuntu, Python 3.7, CUDA10, PyTorch >= 1.2.0. If it doesn't work for your environment, please create a new issue.
88 |
89 | **Compiling the kernel**: We already include the compiled binaries of the CUDA kernel, so most users won't need to compile it, but if you are interested, check `scripts/cheatsheet.txt` for instructions.
90 |
91 |
92 | ### Known issues
93 |
94 | Please check the repo [issues](https://github.com/allenai/longformer/issues) for a list of known issues that we are planning to address soon. If your issue is not discussed, please create a new one.
95 |
96 |
97 | ### Citing
98 |
99 | If you use `Longformer` in your research, please cite [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150).
100 | ```
101 | @article{Beltagy2020Longformer,
102 | title={Longformer: The Long-Document Transformer},
103 | author={Iz Beltagy and Matthew E. Peters and Arman Cohan},
104 | journal={arXiv:2004.05150},
105 | year={2020},
106 | }
107 | ```
108 |
109 | `Longformer` is an open-source project developed by [the Allen Institute for Artificial Intelligence (AI2)](http://www.allenai.org).
110 | AI2 is a non-profit institute with the mission to contribute to humanity through high-impact AI research and engineering.
111 |
--------------------------------------------------------------------------------
/longformer/sliding_chunks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from longformer.diagonaled_mm_tvm import mask_invalid_locations
4 |
5 |
6 | def _skew(x, direction, padding_value):
7 |     '''Convert diagonals into columns (or columns into diagonals, depending on `direction`)'''
8 | x_padded = F.pad(x, direction, value=padding_value)
9 | x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
10 | return x_padded
11 |
12 |
13 | def _skew2(x, padding_value):
14 |     '''shift every row 1 step to the right, converting columns into diagonals'''
15 | # X = B x C x M x L
16 | B, C, M, L = x.size()
17 | x = F.pad(x, (0, M + 1), value=padding_value) # B x C x M x (L+M+1)
18 | x = x.view(B, C, -1) # B x C x ML+MM+M
19 | x = x[:, :, :-M] # B x C x ML+MM
20 |     x = x.view(B, C, M, M + L) # B x C x M x (M+L)
21 | x = x[:, :, :, :-1]
22 | return x
23 |
24 |
25 | def _chunk(x, w):
26 | '''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''
27 |
28 | # non-overlapping chunks of size = 2w
29 | x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))
30 |
31 | # use `as_strided` to make the chunks overlap with an overlap size = w
32 | chunk_size = list(x.size())
33 | chunk_size[1] = chunk_size[1] * 2 - 1
34 |
35 | chunk_stride = list(x.stride())
36 | chunk_stride[1] = chunk_stride[1] // 2
37 | return x.as_strided(size=chunk_size, stride=chunk_stride)
38 |
39 |
40 | def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
41 |     '''Matrix multiplication of query and key tensors using a sliding window attention pattern.
42 | This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
43 | with an overlap of size w'''
44 | bsz, seqlen, num_heads, head_dim = q.size()
45 | assert seqlen % (w * 2) == 0
46 | assert q.size() == k.size()
47 |
48 | chunks_count = seqlen // w - 1
49 |
50 | # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
51 | q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
52 | k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
53 |
54 | chunk_q = _chunk(q, w)
55 | chunk_k = _chunk(k, w)
56 |
57 |     # matrix multiplication
58 | # bcxd: bsz*num_heads x chunks x 2w x head_dim
59 | # bcyd: bsz*num_heads x chunks x 2w x head_dim
60 | # bcxy: bsz*num_heads x chunks x 2w x 2w
61 | chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k)) # multiply
62 |
63 | # convert diagonals into columns
64 | diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value)
65 |
66 |     # allocate space for the overall attention matrix where the chunks are combined. The last dimension
67 | # has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to
68 | # w previous words). The following column is attention score from each word to itself, then
69 | # followed by w columns for the upper triangle.
70 |
71 | diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1))
72 |
73 |     # copy parts from diagonal_chunk_attn into the combined matrix of attentions
74 | # - copying the main diagonal and the upper triangle
75 | diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1]
76 | diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1]
77 | # - copying the lower triangle
78 | diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:]
79 | diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:]
80 |
81 | # separate bsz and num_heads dimensions again
82 | diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1).transpose(2, 1)
83 |
84 | mask_invalid_locations(diagonal_attn, w, 1, False)
85 | return diagonal_attn
86 |
87 |
88 | def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
89 | '''Same as sliding_chunks_matmul_qk but for prob and value tensors. It is expecting the same output
90 | format from sliding_chunks_matmul_qk'''
91 | bsz, seqlen, num_heads, head_dim = v.size()
92 | assert seqlen % (w * 2) == 0
93 | assert prob.size()[:3] == v.size()[:3]
94 | assert prob.size(3) == 2 * w + 1
95 | chunks_count = seqlen // w - 1
96 | # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size 2w
97 | chunk_prob = prob.transpose(1, 2).reshape(bsz * num_heads, seqlen // w, w, 2 * w + 1)
98 |
99 | # group bsz and num_heads dimensions into one
100 | v = v.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
101 |
102 | # pad seqlen with w at the beginning of the sequence and another w at the end
103 | padded_v = F.pad(v, (0, 0, w, w), value=-1)
104 |
105 | # chunk padded_v into chunks of size 3w and an overlap of size w
106 | chunk_v_size = (bsz * num_heads, chunks_count + 1, 3 * w, head_dim)
107 | chunk_v_stride = padded_v.stride()
108 | chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
109 | chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)
110 |
111 | skewed_prob = _skew2(chunk_prob, padding_value=0)
112 |
113 | context = torch.einsum('bcwd,bcdh->bcwh', (skewed_prob, chunk_v))
114 | return context.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2)
115 |
116 |
117 | def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor,
118 | one_sided_window_size: int, pad_token_id: int):
119 |     '''A helper function to pad tokens and mask to work with the sliding_chunks implementation of Longformer self-attention.
120 | Input:
121 | input_ids = torch.Tensor(bsz x seqlen): ids of wordpieces
122 | attention_mask = torch.Tensor(bsz x seqlen): attention mask
123 | one_sided_window_size = int: window size on one side of each token
124 | pad_token_id = int: tokenizer.pad_token_id
125 | Returns
126 | (input_ids, attention_mask) padded to length divisible by 2 * one_sided_window_size
127 | '''
128 | w = 2 * one_sided_window_size
129 | seqlen = input_ids.size(1)
130 | padding_len = (w - seqlen % w) % w
131 | input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
132 | attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
133 | return input_ids, attention_mask
134 |
135 |
136 | # ========= "sliding_chunks_no_overlap": alternative implementation of the sliding window attention =========
137 | # This implementation uses non-overlapping chunks (or blocks) of size `w`, with a total local attention span of 3 x w
138 | # To make this implementation comparable to "sliding_chunks", set w such that
139 | # w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3
140 | # For example,
141 | # w_of_sliding_chunks = 256 (this is one sided. Total attention size = 512)
142 | # w_of_sliding_chunks_no_overlap = 170 (Total attention size = 510)
143 | # Performance:
144 | # - Speed: 30% faster than "sliding_chunks"
145 | # - Memory: 95% of the memory usage of "sliding_chunks"
146 | # The windows are asymmetric: the number of attended positions on each side of a token ranges between w and 2w,
147 | # while "sliding_chunks" has a symmetric window around each token.
148 |
149 |
150 | def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
151 | bsz, seqlen, num_heads, head_dim = q.size()
152 | assert seqlen % w == 0
153 | assert q.size() == k.size()
154 | # chunk seqlen into non-overlapping chunks of size w
155 | chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim)
156 | chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim)
157 | chunk_k_expanded = torch.stack((
158 | F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
159 | chunk_k,
160 | F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
161 | ), dim=-1)
162 | diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded)) # multiply
163 | return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w)
164 |
165 |
166 | def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
167 | bsz, seqlen, num_heads, head_dim = v.size()
168 | chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w)
169 | chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim)
170 | chunk_v_extended = torch.stack((
171 | F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
172 | chunk_v,
173 | F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
174 | ), dim=-1)
175 | context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended))
176 | return context.reshape(bsz, seqlen, num_heads, head_dim)
177 |
--------------------------------------------------------------------------------
/longformer/longformer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import math
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
7 | from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv
8 | from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
9 | from transformers.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM
10 | from transformers.modeling_bert import BertConfig, BertModel, BertForMaskedLM
11 |
12 |
13 | class Longformer(BertModel):
14 | def __init__(self, config):
15 | super(Longformer, self).__init__(config)
16 | if config.attention_mode == 'n2':
17 | pass # do nothing, use BertSelfAttention instead
18 | else:
19 | for i, layer in enumerate(self.encoder.layer):
20 | layer.attention.self = LongformerSelfAttention(config, layer_id=i)
21 |
22 |
23 | class LongformerForMaskedLM(BertForMaskedLM):
24 | def __init__(self, config):
25 | super(LongformerForMaskedLM, self).__init__(config)
26 | if config.attention_mode == 'n2':
27 | pass # do nothing, use BertSelfAttention instead
28 | else:
29 | for i, layer in enumerate(self.bert.encoder.layer):
30 | layer.attention.self = LongformerSelfAttention(config, layer_id=i)
31 |
32 |
33 | class LongformerConfig(BertConfig):
34 | def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
35 | autoregressive: bool = False, attention_mode: str = 'sliding_chunks', **kwargs):
36 | """
37 | Args:
38 | attention_window: list of attention window sizes of length = number of layers.
39 | window size = number of attention locations on each side.
40 |                 For an effective window size of 512, use `attention_window=[256]*num_layers`
41 | which is 256 on each side.
42 | attention_dilation: list of attention dilation of length = number of layers.
43 | attention dilation of `1` means no dilation.
44 |             autoregressive: do autoregressive attention or have attention on both sides
45 |             attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of Longformer
46 |                 self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
47 | """
48 | super().__init__(**kwargs)
49 | self.attention_window = attention_window
50 | self.attention_dilation = attention_dilation
51 | self.autoregressive = autoregressive
52 | self.attention_mode = attention_mode
53 | assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap']
54 |
55 |
56 | class LongformerSelfAttention(nn.Module):
57 | def __init__(self, config, layer_id):
58 | super(LongformerSelfAttention, self).__init__()
59 | if config.hidden_size % config.num_attention_heads != 0:
60 | raise ValueError(
61 | "The hidden size (%d) is not a multiple of the number of attention "
62 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
63 | self.output_attentions = config.output_attentions
64 | self.num_heads = config.num_attention_heads
65 | self.head_dim = int(config.hidden_size / config.num_attention_heads)
66 | self.embed_dim = config.hidden_size
67 |
68 | self.query = nn.Linear(config.hidden_size, self.embed_dim)
69 | self.key = nn.Linear(config.hidden_size, self.embed_dim)
70 | self.value = nn.Linear(config.hidden_size, self.embed_dim)
71 |
72 | self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
73 | self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
74 | self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
75 |
76 | self.dropout = config.attention_probs_dropout_prob
77 |
78 | self.layer_id = layer_id
79 | self.attention_window = config.attention_window[self.layer_id]
80 | # self.attention_dilation = config.attention_dilation[self.layer_id]
81 | self.attention_dilation = 1
82 | self.attention_mode = config.attention_mode
83 | self.autoregressive = config.autoregressive
84 | assert self.attention_window > 0
85 | assert self.attention_dilation > 0
86 | assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap']
87 | if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']:
88 | assert not self.autoregressive # not supported
89 | assert self.attention_dilation == 1 # dilation is not supported
90 |
91 | def forward(
92 | self,
93 | hidden_states,
94 | attention_mask=None,
95 | head_mask=None,
96 | encoder_hidden_states=None,
97 | encoder_attention_mask=None,
98 | output_attentions=False,
99 | ):
100 | '''
101 | The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
102 | -ve: no attention
103 | 0: local attention
104 | +ve: global attention
105 | '''
106 | assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None"
107 |         assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and should be None"
108 |
109 | if attention_mask is not None:
110 | attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
111 | key_padding_mask = attention_mask < 0
112 | extra_attention_mask = attention_mask > 0
113 | remove_from_windowed_attention_mask = attention_mask != 0
114 |
115 | num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
116 | max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
117 | if max_num_extra_indices_per_batch <= 0:
118 | extra_attention_mask = None
119 | else:
120 | # To support the case of variable number of global attention in the rows of a batch,
121 | # we use the following three selection masks to select global attention embeddings
122 | # in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
123 | # 1) selecting embeddings that correspond to global attention
124 | extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
125 | zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch,
126 | device=num_extra_indices_per_batch.device)
127 | # mask indicating which values are actually going to be padding
128 | selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
129 | # 2) location of the non-padding values in the selected global attention
130 | selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
131 | # 3) location of the padding values in the selected global attention
132 | selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
133 | else:
134 | remove_from_windowed_attention_mask = None
135 | extra_attention_mask = None
136 | key_padding_mask = None
137 |
138 | hidden_states = hidden_states.transpose(0, 1)
139 | seq_len, bsz, embed_dim = hidden_states.size()
140 | assert embed_dim == self.embed_dim
141 | q = self.query(hidden_states)
142 | k = self.key(hidden_states)
143 | v = self.value(hidden_states)
144 | q /= math.sqrt(self.head_dim)
145 |
146 | q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
147 | k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
148 | # attn_weights = (bsz, seq_len, num_heads, window*2+1)
149 | if self.attention_mode == 'tvm':
150 | q = q.float().contiguous()
151 | k = k.float().contiguous()
152 | attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
153 | elif self.attention_mode == "sliding_chunks":
154 | attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
155 | elif self.attention_mode == "sliding_chunks_no_overlap":
156 | attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
157 | else:
158 |             raise ValueError(f"Unsupported attention mode: {self.attention_mode}")
159 | mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
160 | if remove_from_windowed_attention_mask is not None:
161 | # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
162 | # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
163 | remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
164 | # cast to float/half then replace 1's with -inf
165 | float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
166 | repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
167 | float_mask = float_mask.repeat(1, 1, repeat_size, 1)
168 | ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones
169 | # diagonal mask with zeros everywhere and -inf inplace of padding
170 | if self.attention_mode == 'tvm':
171 | d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False)
172 | elif self.attention_mode == "sliding_chunks":
173 | d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
174 | elif self.attention_mode == "sliding_chunks_no_overlap":
175 | d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
176 |
177 | attn_weights += d_mask
178 | assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
179 | assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]
180 |
181 | # the extra attention
182 | if extra_attention_mask is not None:
183 | selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
184 | selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
185 | # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
186 | selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
187 | selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
188 | # concat to attn_weights
189 | # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
190 | attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
191 |
192 | attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
193 | if key_padding_mask is not None:
194 | # softmax sometimes inserts NaN if all positions are masked, replace them with 0
195 | attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
196 |
197 | attn_weights = attn_weights_float.type_as(attn_weights)
198 | attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
199 | v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
200 | attn = 0
201 | if extra_attention_mask is not None:
202 | selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
203 | selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
204 | selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
205 | # use `matmul` because `einsum` crashes sometimes with fp16
206 | # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
207 | attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
208 | attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()
209 |
210 | if self.attention_mode == 'tvm':
211 | v = v.float().contiguous()
212 | attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
213 | elif self.attention_mode == "sliding_chunks":
214 | attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
215 | elif self.attention_mode == "sliding_chunks_no_overlap":
216 | attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
217 | else:
218 |             raise ValueError(f"Unsupported attention mode: {self.attention_mode}")
219 |
220 | attn = attn.type_as(hidden_states)
221 | assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
222 | attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()
223 |
224 | # For this case, we'll just recompute the attention for these indices
225 | # and overwrite the attn tensor. TODO: remove the redundant computation
226 | if extra_attention_mask is not None:
227 | selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim)
228 | selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]]
229 |
230 | q = self.query_global(selected_hidden_states)
231 | k = self.key_global(hidden_states)
232 | v = self.value_global(hidden_states)
233 | q /= math.sqrt(self.head_dim)
234 |
235 | q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1) # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
236 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
237 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
238 | attn_weights = torch.bmm(q, k.transpose(1, 2))
239 | assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]
240 |
241 | attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
242 | attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
243 | if key_padding_mask is not None:
244 | attn_weights = attn_weights.masked_fill(
245 | key_padding_mask.unsqueeze(1).unsqueeze(2),
246 | -10000.0,
247 | )
248 | attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
249 | attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
250 | attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
251 | selected_attn = torch.bmm(attn_probs, v)
252 | assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim]
253 |
254 | selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim)
255 | nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]]
256 | attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states)
257 |
258 | context_layer = attn.transpose(0, 1)
259 | if output_attentions:
260 | if extra_attention_mask is not None:
261 | # With global attention, return global attention probabilities only
262 | # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
263 | # which is the attention weights from tokens with global attention to all tokens
264 |                 # It does not return local attention
265 |                 # In case of a variable number of global attention tokens in the rows of a batch,
266 | # attn_weights are padded with -10000.0 attention scores
267 | attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
268 | else:
269 | # without global attention, return local attention probabilities
270 | # batch_size x num_heads x sequence_length x window_size
271 | # which is the attention weights of every token attending to its neighbours
272 | attn_weights = attn_weights.permute(0, 2, 1, 3)
273 | outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
274 | return outputs
275 |
--------------------------------------------------------------------------------
/longformer/diagonaled_mm_tvm.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from functools import lru_cache
3 |
4 | import torch
5 | import os.path
6 |
7 |
8 | class DiagonaledMM(torch.autograd.Function):
9 | '''Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
10 | this function from PyTorch
11 | '''
12 |
13 | function_dict = {} # save a list of functions, each has a different set of parameters
14 |
15 | @staticmethod
16 | def _compile_function(dtype: str, device: str, b0: int = 4, b1: int = 4, b2: int = 16):
17 | '''Compiles a tvm function that computes diagonal_mm
18 | args:
19 | dtype: str in ['float64', 'float32', 'float16']
20 | device: str in ['cpu' or 'cuda']
21 | b0, b1, b2: size of tensor tiles. Very important for good performance
22 |
23 | '''
24 | import tvm # import the full tvm library here for compilation. Don't import at the top of the file in case we don't need to compile
25 | from tvm.contrib import nvcc
26 | @tvm.register_func
27 | def tvm_callback_cuda_compile(code):
28 | """Use nvcc compiler for better perf."""
29 | ptx = nvcc.compile_cuda(code, target="ptx", arch='sm_52') # use old arch for this to work on old GPUs
30 | return ptx
31 |
32 | assert dtype in ['float16', 'float32', 'float64']
33 | assert device in ['cpu', 'cuda']
34 | device = None if device == 'cpu' else device
35 | tgt_host="llvm"
36 |
37 | b = tvm.var('b') # batch size
38 | n = tvm.var('n') # sequence length
39 | h = tvm.var('h') # number of heads
40 | m = tvm.var('m') # hidden dimension
41 | w = tvm.var('w') # window size
42 | w_upper = tvm.var('w_upper') # window size to the right of the word. Should be `0` or `w`
43 | padding = tvm.var('padding') # padding
44 | transpose_t1 = tvm.var('transpose_t1') # t1 should be transposed
45 | t1d3 = tvm.var('t1d3') # last dimension of t1
46 | t3d3 = tvm.var('t3d3') # last dimension of t3 (the result tensor)
47 | X = tvm.placeholder((b, n, h, t1d3), name='X', dtype=dtype) # first tensor
48 | Y = tvm.placeholder((b, n, h, m), name='Y', dtype=dtype) # second tensor
49 | k = tvm.reduce_axis((0, t1d3), name='k') # dimension to sum over
50 | D = tvm.placeholder((h), name='D', dtype='int') # dilation per head
51 | output_shape = (b, n, h, t3d3) # shape of the result tensor
52 | algorithm = lambda l, i, q, j: tvm.sum(
53 | tvm.if_then_else(
54 | t3d3 == m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3)
55 | tvm.if_then_else(
56 | transpose_t1 == 0,
57 | tvm.if_then_else(
58 | tvm.all(
59 | i + D[q] * (k - w) >= 0,
60 | i + D[q] * (k - w) < n,
61 | ),
62 | X[l, i, q, k] * Y[l, i + D[q] * (k - w), q, j], # t1 is diagonaled
63 | padding
64 | ),
65 | tvm.if_then_else(
66 | tvm.all(
67 | i + D[q] * (k - w_upper) >= 0, # `w_upper` to handle the case `autoregressive=True`
68 | i + D[q] * (k - w_upper) < n,
69 | ),
70 | X[l, i + D[q] * (k - w_upper), q, (w_upper + w) - k] * Y[l, i + D[q] * (k - w_upper), q, j], # # t1 is diagonaled and should be transposed
71 | padding
72 | ),
73 | ),
74 | tvm.if_then_else(
75 | tvm.all(
76 | i + D[q] * (j - w) >= 0,
77 | i + D[q] * (j - w) < n,
78 | ),
79 | X[l, i, q, k] * Y[l, i + D[q] * (j - w), q, k], # t1 is not diagonaled, but the output tensor is going to be
80 | padding
81 | )
82 | ), axis=k)
83 |
84 | Z = tvm.compute(output_shape, algorithm, name='Z') # automatically generate cuda code
85 | s = tvm.create_schedule(Z.op)
86 |
87 | print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
88 |
89 |         # split long axis into smaller chunks and assign each one to a separate GPU thread/block
90 | ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0)
91 | ZF = s.rfactor(Z, ki)
92 |
93 | j_outer, j_inner = s[Z].split(s[Z].op.axis[-1], factor=b1)
94 | i_outer, i_inner = s[Z].split(s[Z].op.axis[1], factor=b2)
95 |
96 | s[Z].bind(j_outer, tvm.thread_axis("blockIdx.x"))
97 | s[Z].bind(j_inner, tvm.thread_axis("threadIdx.y"))
98 |
99 | s[Z].bind(i_outer, tvm.thread_axis("blockIdx.y"))
100 | s[Z].bind(i_inner, tvm.thread_axis("threadIdx.z"))
101 |
102 | tx = tvm.thread_axis("threadIdx.x")
103 | s[Z].bind(s[Z].op.reduce_axis[0], tx)
104 | s[ZF].compute_at(s[Z], s[Z].op.reduce_axis[0])
105 | s[Z].set_store_predicate(tx.var.equal(0))
106 |
107 | print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
108 |
109 | # compiling the automatically generated cuda code
110 | diagonaled_mm = tvm.build(s, [X, Y, Z, D, w, w_upper, padding, transpose_t1, t3d3], target=device, target_host=tgt_host, name='diagonaled_mm')
111 | return diagonaled_mm
112 |
113 | @staticmethod
114 | def _get_lib_filename(dtype: str, device: str):
115 | base_filename = 'longformer/lib/lib_diagonaled_mm'
116 | return '{}_{}_{}.so'.format(base_filename, dtype, device)
117 |
118 | @staticmethod
119 | def _save_compiled_function(f, dtype: str, device: str):
120 | if not os.path.exists('longformer/lib/'):
121 | os.makedirs('longformer/lib/')
122 | f.export_library(DiagonaledMM._get_lib_filename(dtype, device))
123 |
124 | @staticmethod
125 | def _load_compiled_function(dtype: str, device: str):
126 | from tvm.module import load # this can be the small runtime python library, and doesn't need to be the whole thing
127 | filename = DiagonaledMM._get_lib_filename(dtype, device)
128 | current_dir = os.path.dirname(os.path.abspath(__file__))
129 | potential_dirs = ['../../', '../', './', f'{current_dir}/', f'{current_dir}/../']
130 | for potential_dir in potential_dirs:
131 | filepath = '{}{}'.format(potential_dir, filename)
132 | if os.path.isfile(filepath):
133 | print('Loading tvm binary from: {}'.format(filepath))
134 | return load(filepath)
135 | return None
136 |
137 | @staticmethod
138 | def _get_function(dtype: str, device: str):
139 | '''Loads the function from the disk or compile it'''
140 | # A list of arguments that define the function
141 | args = (dtype, device)
142 | if args not in DiagonaledMM.function_dict:
143 | diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device) # try to load from disk
144 | if not diagonaled_mm:
145 | print('Tvm binary not found. Compiling ...')
146 | diagonaled_mm = DiagonaledMM._compile_function(dtype, device) # compile
147 | DiagonaledMM._save_compiled_function(diagonaled_mm, dtype, device) # save to disk
148 | # convert the tvm function into a pytorch function
149 | from tvm.contrib import dlpack
150 | diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm) # wrap it as a pytorch function
151 | # save the function into a dictionary to be reused
152 | DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch # save it in a dictionary for next time
153 | return DiagonaledMM.function_dict[args]
154 |
155 | @staticmethod
156 | def _diagonaled_mm(t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int],
157 | is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0,
158 | autoregressive: bool = False):
159 | '''Calls the compiled function after checking the input format. This function is called in three different modes.
160 | t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores
161 |         t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compute attention_scores x value = context
162 | t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of
163 | the calculations in the backward pass.
164 | '''
165 | dtype = str(t1.dtype).split('.')[1]
166 | device = t1.device.type
167 | assert len(t1.shape) == 4
168 | assert len(t1.shape) == len(t2.shape)
169 | assert t1.shape[:3] == t2.shape[:3]
170 | if isinstance(d, int): # if d is an integer, replace it with a tensor of the same length
171 | # as number of heads, and it is filled with the same dilation value
172 | d = t1.new_full(size=(t1.shape[2],), fill_value=d, dtype=torch.int, requires_grad=False)
173 |
174 | assert len(d.shape) == 1
175 | assert d.shape[0] == t1.shape[2] # number of dilation scores should match number of heads
176 | b = t1.shape[0] # batch size
177 | n = t1.shape[1] # sequence length
178 | h = t1.shape[2] # number of heads
179 | m = t2.shape[3] # hidden dimension
180 | w_upper = 0 if autoregressive else w
181 | c = w_upper + w + 1 # number of diagonals
182 | if is_t1_diagonaled:
183 | assert t1.shape[3] == c
184 |             r = t1.new_empty(b, n, h, m) # allocate space for the result tensor
185 | else:
186 | assert not transpose_t1
187 | assert t1.shape[3] == m
188 |             r = t1.new_empty(b, n, h, c) # allocate space for the result tensor
189 |
190 | # gets function from memory, from disk or compiles it from scratch
191 | _diagonaled_mm_function = DiagonaledMM._get_function(dtype=dtype, device=device)
192 |
193 | # The last argument to this function is a little hacky. It is the size of the last dimension of the result tensor
194 | # We use it as a proxy to tell if t1_is_diagonaled or not (if t1 is diagonaled, result is not, and vice versa).
195 | # The second reason is that the lambda expression in `_compile_function` is easier to express when the shape
196 | # of the output is known
197 | # This functions computes diagonal_mm then saves the result in `r`
198 | if m == c:
199 | # FIXME
200 |             print(f'Error: the hidden dimension {m} shouldn\'t match the number of diagonals {c}')
201 | assert False
202 | _diagonaled_mm_function(t1, t2, r, d, w, w_upper, padding, transpose_t1, m if is_t1_diagonaled else c)
203 | return r
204 |
205 | @staticmethod
206 | def _prepare_tensors(t):
207 | '''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
208 | For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter.
209 |         TVM expects this value to be the `product(t.size()[1:])` but PyTorch sometimes sets it to `t.stride()[1]`.
210 |         Here's an example to reproduce this issue:
211 | import torch
212 | print(torch.randn(1, 10).stride())
213 | > (10, 1)
214 | print(torch.randn(10, 1).t().contiguous().stride())
215 | > (1, 1) # expected it to be (10, 1) as above
216 | print(torch.randn(10, 2).t().contiguous().stride())
217 | > (10, 1) # but gets the expected stride if the first dimension is > 1
218 | '''
219 | assert t.is_contiguous()
220 | t_stride = list(t.stride())
221 | t_size = list(t.size())
222 |         # Fix wrong stride information for the first dimension. This occurs when batch_size=1
223 | if t_size[0] == 1 and t_stride[0] == t_stride[1]:
224 | # In this case, the stride of the first dimension should be the product
225 | # of the sizes of all other dimensions
226 | t_stride[0] = t_size[1] * t_size[2] * t_size[3]
227 | t = t.as_strided(size=t_size, stride=t_stride)
228 | return t
229 |
230 | min_seq_len = 16 # unexpected output if seq_len < 16
231 |
232 | @staticmethod
233 | def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int], is_t1_diagonaled: bool = False, padding: int = 0, autoregressive: bool = False) -> torch.Tensor:
234 |         '''Computes diagonal_mm of t1 and t2.
235 | args:
236 | t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals).
237 | t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`)
238 | t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled
239 | tensor, e.g. `key_layer` or `value_layer`
240 | w: int = window size; number of attentions on each side of the word
241 | d: torch.Tensor or int = dilation of attentions per attention head. If int, the same dilation value will be used for all
242 |                 heads. If torch.Tensor, it should be 1D of length=number of attention heads
243 | is_t1_diagonaled: is t1 a diagonaled or a regular tensor
244 | padding: the padding value to use when accessing invalid locations. This is mainly useful when the padding
245 | needs to be a very large negative value (to compute softmax of attentions). For other usecases,
246 | please use zero padding.
247 | autoregressive: if true, return only the lower triangle
248 | returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals)
249 |                 if t1 is diagonaled, result is non-diagonaled, and vice versa
250 | '''
251 | batch_size, seq_len, num_attention_heads, hidden_size = t1.size()
252 | assert seq_len >= DiagonaledMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(DiagonaledMM.min_seq_len) # FIXME
253 | ctx.save_for_backward(t1, t2)
254 | ctx.w = w
255 | ctx.d = d
256 | ctx.is_t1_diagonaled = is_t1_diagonaled
257 | ctx.autoregressive = autoregressive
258 | t1 = DiagonaledMM._prepare_tensors(t1)
259 | t2 = DiagonaledMM._prepare_tensors(t2)
260 | # output = t1.mm(t2) # what would have been called if this was a regular matmul
261 | output = DiagonaledMM._diagonaled_mm(t1, t2, w, d, is_t1_diagonaled=is_t1_diagonaled, padding=padding, autoregressive=autoregressive)
262 | return output
263 |
264 | @staticmethod
265 | def backward(ctx, grad_output):
266 | t1, t2 = ctx.saved_tensors
267 | w = ctx.w
268 | d = ctx.d
269 | is_t1_diagonaled = ctx.is_t1_diagonaled
270 | autoregressive = ctx.autoregressive
271 | if not grad_output.is_contiguous():
272 | grad_output = grad_output.contiguous() # tvm requires all input tensors to be contiguous
273 | grad_output = DiagonaledMM._prepare_tensors(grad_output)
274 | t1 = DiagonaledMM._prepare_tensors(t1)
275 | t2 = DiagonaledMM._prepare_tensors(t2)
276 | # http://cs231n.github.io/optimization-2/
277 | # https://pytorch.org/docs/master/notes/extending.html
278 | # grad_t1 = grad_output.mm(t2) # what would have been called if this was a regular matmul
279 | grad_t1 = DiagonaledMM._diagonaled_mm(grad_output, t2, w, d, is_t1_diagonaled=not is_t1_diagonaled, autoregressive=autoregressive)
280 | # grad_t2 = grad_output.t().mm(t1) # or `grad_t2 = t1.t().mm(grad_output).t()` because `(AB)^T = B^TA^T`
281 | if is_t1_diagonaled:
282 | grad_t2 = DiagonaledMM._diagonaled_mm(t1, grad_output, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
283 | else:
284 | grad_t2 = DiagonaledMM._diagonaled_mm(grad_output, t1, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
285 | return grad_t1, grad_t2, None, None, None, None, None
286 |
287 |
288 | def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
289 | diagonals_list = []
290 | for j in range(-d * w, d, d):
291 | diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
292 | diagonal_mask[:-j] = 1
293 | diagonals_list.append(diagonal_mask)
294 | return torch.stack(diagonals_list, dim=-1)
295 |
296 | @lru_cache()
297 | def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
298 | if isinstance(d, int):
299 | affected_seq_len = w * d
300 | mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
301 | mask = mask[None, :, None, :]
302 | else:
303 | affected_seq_len = w * d.max()
304 | head_masks = []
305 | d_list = d.cpu().numpy().tolist()
306 | for d in d_list:
307 | one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
308 | head_masks.append(one_head_mask)
309 | mask = torch.stack(head_masks, dim=-2)
310 | mask = mask[None, :, :, :]
311 |
312 | ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device)
313 | return affected_seq_len, mask.bool().to(device), ending_mask
314 |
315 | def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> torch.Tensor:
316 | affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
317 | seq_len = input_tensor.size(1)
318 | beginning_input = input_tensor[:, :affected_seq_len, :, :w+1]
319 | beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size())
320 | beginning_input.masked_fill_(beginning_mask, -float('inf'))
321 | if not autoregressive:
322 | ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):]
323 | ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size())
324 | ending_input.masked_fill_(ending_mask, -float('inf'))
325 |
326 |
327 | diagonaled_mm = DiagonaledMM.apply
328 |
329 | # The non-tvm implementation is the default, so we don't need to load the kernel at import time.
330 | # DiagonaledMM._get_function('float32', 'cuda')
331 |
--------------------------------------------------------------------------------
/classification.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import json
4 | import os
5 | import random
6 | import argparse
7 | from argparse import Namespace
8 | import numpy as np
9 | import glob
10 | import gzip
11 |
12 | import torch
13 | from torch import nn
14 | from torch.nn import CrossEntropyLoss, MSELoss
15 | import torch.nn.functional as F
16 | from torch.utils.data import DataLoader, Dataset, random_split
17 | import pytorch_lightning as pl
18 | from pytorch_lightning.loggers import TensorBoardLogger
19 | from pytorch_lightning.callbacks import ModelCheckpoint
20 |
21 | import torch.distributed as dist
22 |
23 | from longformer.longformer import Longformer, LongformerConfig
24 | from longformer.sliding_chunks import pad_to_window_size
25 | from transformers import BertTokenizer, AdamW
26 |
27 | from torch.utils.data.dataset import IterableDataset
28 | from tqdm.auto import tqdm
29 |
30 | import logging
31 | logger = logging.getLogger(__name__)
32 |
33 | from transformers.optimization import (
34 | Adafactor,
35 | get_cosine_schedule_with_warmup,
36 | get_cosine_with_hard_restarts_schedule_with_warmup,
37 | get_linear_schedule_with_warmup,
38 | get_polynomial_decay_schedule_with_warmup,
39 | )
40 |
41 |
42 | TEXT_FIELD_NAME = 'text'
43 | LABEL_FIELD_NAME = 'label'
44 |
45 | arg_to_scheduler = {
46 | "linear": get_linear_schedule_with_warmup,
47 | "cosine": get_cosine_schedule_with_warmup,
48 | "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
49 | "polynomial": get_polynomial_decay_schedule_with_warmup,
50 | # '': get_constant_schedule, # not supported for now
51 | # '': get_constant_schedule_with_warmup, # not supported for now
52 | }
53 | arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
54 | arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
55 |
56 |
57 | def calc_f1(y_pred:torch.Tensor, y_true:torch.Tensor) -> torch.Tensor:
58 | tp = (y_true * y_pred).sum().to(torch.float32)
59 | tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
60 | fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
61 | fn = (y_true * (1 - y_pred)).sum().to(torch.float32)
62 | epsilon = 1e-7
63 | precision = tp / (tp + fp + epsilon)
64 | recall = tp / (tp + fn + epsilon)
65 |
66 | f1 = 2 * (precision * recall) / (precision + recall + epsilon)
67 | f1 = f1.clamp(min=epsilon, max=1 - epsilon)
68 | return f1
69 |
70 |
71 | class ClassificationDataset(Dataset):
72 | def __init__(self, file_path, tokenizer, seqlen, num_samples=None, mask_padding_with_zero=True):
73 | self.data = []
74 | with (gzip.open(file_path, 'rt') if file_path.endswith('.gz') else open(file_path)) as fin:
75 | for i, line in enumerate(tqdm(fin, desc=f'loading input file {file_path.split("/")[-1]}', unit_scale=1)):
76 | items = line.strip().split('\tSEP\t')  # fields are separated by a literal '\tSEP\t' marker
77 | if len(items) != 10: continue  # skip malformed lines that don't contain exactly 10 fields
78 | self.data.append({
79 | "text": items[0]+items[1],  # concatenate the first two fields as the input text
80 | "label": items[5]  # the sixth field is the label
81 | })
82 | if num_samples and len(self.data) > num_samples:
83 | break
84 | self.seqlen = seqlen
85 | self._tokenizer = tokenizer
86 | all_labels = list(set([e[LABEL_FIELD_NAME] for e in self.data]))
87 | self.label_to_idx = {e: i for i, e in enumerate(sorted(all_labels))}
88 | self.idx_to_label = {v: k for k, v in self.label_to_idx.items()}
89 | self.mask_padding_with_zero = mask_padding_with_zero
90 |
91 | def __len__(self):
92 | return len(self.data)
93 |
94 | def __getitem__(self, idx):
95 | return self._convert_to_tensors(self.data[idx])
96 |
97 | def _convert_to_tensors(self, instance):
98 | def tok(s):
99 | return self._tokenizer.tokenize(s)
100 | tokens = [self._tokenizer.cls_token] + tok(instance[TEXT_FIELD_NAME])
101 | token_ids = self._tokenizer.convert_tokens_to_ids(tokens)
102 | token_ids = token_ids[:self.seqlen-1] + [self._tokenizer.sep_token_id]
103 | input_len = len(token_ids)
104 | attention_mask = [1 if self.mask_padding_with_zero else 0] * input_len
105 | padding_length = self.seqlen - input_len
106 | token_ids = token_ids + ([self._tokenizer.pad_token_id] * padding_length)
107 |
108 | attention_mask = attention_mask + ([0 if self.mask_padding_with_zero else 1] * padding_length)
109 |
110 | assert len(token_ids) == self.seqlen, "Error with input length {} vs {}".format(
111 | len(token_ids), self.seqlen
112 | )
113 | assert len(attention_mask) == self.seqlen, "Error with input length {} vs {}".format(
114 | len(attention_mask), self.seqlen
115 | )
116 |
117 | label = self.label_to_idx[instance[LABEL_FIELD_NAME]]
118 |
119 | return (torch.tensor(token_ids), torch.tensor(attention_mask), torch.tensor(label))
120 |
121 |
122 | class LongformerClassifier(pl.LightningModule):
123 |
124 | def __init__(self, init_args):
125 | super().__init__()
126 | if isinstance(init_args, dict):
127 | # when loading from a checkpoint, PyTorch Lightning passes hparams as a dict (they are saved as a dict)
128 | init_args = Namespace(**init_args)
129 | config_path = init_args.config_path or init_args.model_dir
130 | checkpoint_path = init_args.checkpoint_path or init_args.model_dir
131 | logger.info(f'loading model from config: {config_path}, checkpoint: {checkpoint_path}')
132 | config = LongformerConfig.from_pretrained(config_path)
133 | config.attention_mode = init_args.attention_mode
134 | logger.info(f'attention mode set to {config.attention_mode}')
135 | self.model_config = config
136 | self.model = Longformer.from_pretrained(checkpoint_path, config=config)
137 | self.tokenizer = BertTokenizer.from_pretrained(init_args.tokenizer)
138 | self.tokenizer.model_max_length = self.model.config.max_position_embeddings
139 | self.hparams = init_args
140 | self.hparams.seqlen = self.model.config.max_position_embeddings
141 | self.classifier = nn.Linear(config.hidden_size, init_args.num_labels)
142 |
143 | def forward(self, input_ids, attention_mask, labels=None):
144 | input_ids, attention_mask = pad_to_window_size(
145 | input_ids, attention_mask, self.model_config.attention_window[0], self.tokenizer.pad_token_id)
146 | attention_mask[:, 0] = 2  # mask values: 0 = padding, 1 = local attention, 2 = global attention; give [CLS] global attention
147 | # use BERT's pooler output (the transformed [CLS] representation)
148 | output = self.model(input_ids, attention_mask=attention_mask)[1]
149 | # alternative: take the [CLS] token's hidden state from the sequence output directly:
150 | # output = output[:, 0, :]
151 | logits = self.classifier(output)
152 |
153 | loss = None
154 | if labels is not None:
155 | loss_fct = CrossEntropyLoss()
156 |
157 | loss = loss_fct(logits.view(-1, self.hparams.num_labels), labels.view(-1))
158 |
159 | return logits, loss
160 |
161 | def _get_loader(self, split, shuffle=True):
162 | if split == 'train':
163 | fname = self.hparams.train_file
164 | elif split == 'dev':
165 | fname = self.hparams.dev_file
166 | elif split == 'test':
167 | fname = self.hparams.test_file
168 | else:
169 | assert False, f'unknown split: {split}'
170 | is_train = split == 'train'
171 |
172 | dataset = ClassificationDataset(
173 | fname, tokenizer=self.tokenizer, seqlen=self.hparams.seqlen, num_samples=self.hparams.num_samples
174 | )
175 |
176 | loader = DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, shuffle=(shuffle and is_train))
177 | return loader
178 |
179 | def setup(self, mode):
180 | self.train_loader = self._get_loader("train")
181 |
182 | def train_dataloader(self):
183 | return self.train_loader
184 |
185 | def val_dataloader(self):
186 | self.val_dataloader_obj = self._get_loader('dev')
187 | return self.val_dataloader_obj
188 |
189 | def test_dataloader(self):
190 | return self._get_loader('test')
191 |
192 | @property
193 | def total_steps(self) -> int:
194 | """The number of total training steps that will be run. Used for lr scheduler purposes."""
195 | num_devices = max(1, self.hparams.total_gpus) # TODO: consider num_tpu_cores
196 | effective_batch_size = self.hparams.batch_size * self.hparams.grad_accum * num_devices
197 | dataset_size = len(self.train_loader.dataset)
198 | return int((dataset_size / effective_batch_size) * self.hparams.num_epochs)
199 |
200 | def get_lr_scheduler(self):
201 | get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
202 | scheduler = get_schedule_func(
203 | self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
204 | )
205 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
206 | return scheduler
207 |
208 | def configure_optimizers(self):
209 | """Prepare optimizer and schedule (linear warmup and decay)"""
210 | model = self.model
211 | no_decay = ["bias", "LayerNorm.weight"]
212 | optimizer_grouped_parameters = [
213 | {
214 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
215 | "weight_decay": self.hparams.weight_decay,
216 | },
217 | {
218 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
219 | "weight_decay": 0.0,
220 | },
221 | ]
222 | if self.hparams.adafactor:
223 | optimizer = Adafactor(
224 | optimizer_grouped_parameters, lr=self.hparams.lr, scale_parameter=False, relative_step=False
225 | )
226 |
227 | else:
228 | optimizer = AdamW(
229 | optimizer_grouped_parameters, lr=self.hparams.lr, eps=self.hparams.adam_epsilon
230 | )
231 | self.opt = optimizer
232 |
233 | scheduler = self.get_lr_scheduler()
234 |
235 | return [optimizer], [scheduler]
236 |
237 | def training_step(self, batch, batch_idx):
238 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[2]}
239 |
240 | outputs = self(**inputs)
241 | loss = outputs[1]
242 |
243 | lr_scheduler = self.trainer.lr_schedulers[0]["scheduler"]
244 | tensorboard_logs = {"loss": loss, "rate": lr_scheduler.get_last_lr()[-1]}
245 | return {"loss": loss, "log": tensorboard_logs}
246 |
247 |
248 | def validation_step(self, batch, batch_idx):
249 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[2]}
250 |
251 | outputs = self(**inputs)
252 | logits, tmp_eval_loss = outputs
253 | preds = logits
254 | out_label_ids = inputs["labels"]
255 | return {"val_loss": tmp_eval_loss, "pred": preds, "target": out_label_ids}
256 |
257 | def _eval_end(self, outputs) -> tuple:
258 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
259 | preds = torch.cat([x["pred"] for x in outputs], dim=0)
260 | labels = torch.cat([x["target"] for x in outputs], dim=0)
261 | preds = torch.argmax(preds, dim=-1)
262 | accuracy = (preds == labels).float().mean()
263 | f1 = calc_f1(preds, labels)
264 | if self.trainer.use_ddp:
265 | torch.distributed.all_reduce(avg_loss, op=torch.distributed.ReduceOp.SUM)
266 | avg_loss /= self.trainer.world_size
267 | torch.distributed.all_reduce(accuracy, op=torch.distributed.ReduceOp.SUM)
268 | accuracy /= self.trainer.world_size
269 | torch.distributed.all_reduce(f1, op=torch.distributed.ReduceOp.SUM)
270 | f1 /= self.trainer.world_size
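    | # NOTE: averaging the per-process metrics assumes each process sees roughly the same number
    | # of validation examples; with uneven shards this is only an approximation.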
271 | # accuracy = (preds == out_label_ids).int().sum() / float(torch.tensor(preds.shape[0], dtype=torch.float32, device=out_label_ids.device))
272 | results = {"val_loss": avg_loss, "f1": f1, "acc": accuracy}
273 |
274 | ret = {k: v for k, v in results.items()}
275 | ret["log"] = results
276 | return ret
277 |
278 | def validation_epoch_end(self, outputs: list) -> dict:
279 | ret = self._eval_end(outputs)
280 | logs = ret["log"]
281 | return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
282 |
283 | def test_epoch_end(self, outputs) -> dict:
284 | ret = self._eval_end(outputs)
285 | logs = ret["log"]
286 | results = {}
287 | for k, v in logs.items():
288 | if isinstance(v, torch.Tensor):
289 | results[k] = v.detach().cpu().item()
290 | # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
291 | return {"avg_test_loss": logs["val_loss"].detach().cpu().item(), "log": results, "progress_bar": results}
292 |
293 | def test_step(self, batch, batch_nb):
294 | return self.validation_step(batch, batch_nb)
295 |
296 |
297 | def parse_args():
298 | parser = argparse.ArgumentParser()
299 | parser.add_argument('--model_dir', dest='model_dir', default='longformer-chinese-base-4096/', help='path to the model')
300 | parser.add_argument('--config_path', default=None, help='path to the config (if not using --model_dir)')
301 | parser.add_argument('--checkpoint_path', default=None, help='path to the model checkpoint (if not using --model_dir)')
302 | parser.add_argument('--attention_mode', default='sliding_chunks', help='longformer attention mode (default: sliding_chunks)')
303 | parser.add_argument('--tokenizer', default='longformer-chinese-base-4096/')
304 | parser.add_argument('--train_file')
305 | parser.add_argument('--dev_file')
306 | parser.add_argument('--test_file')
307 | parser.add_argument('--input_dir', default=None, help='optionally provide a data directory; train/dev/test files will be detected automatically')
308 | parser.add_argument('--batch_size', default=1, type=int)
309 | parser.add_argument('--grad_accum', default=1, type=int)
310 | parser.add_argument('--gpus', default=1)
311 | parser.add_argument('--seed', default=1918, type=int)
312 | parser.add_argument('--fp16', default=False, action='store_true')
313 | parser.add_argument('--test_only', default=False, action='store_true')
314 | parser.add_argument('--test_checkpoint', default=None)
315 | parser.add_argument('--test_percent_check', default=1.0, type=float)
316 | parser.add_argument('--limit_val_batches', default=1.0, type=float)
317 | parser.add_argument('--val_check_interval', default=1.0, type=float)
318 | parser.add_argument('--num_epochs', default=1, type=int)
319 | parser.add_argument('--do_predict', default=False, action='store_true')
320 | parser.add_argument("--lr", type=float, default=2e-5)
321 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
322 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
323 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
324 | parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
325 | parser.add_argument("--adafactor", action="store_true")
326 | parser.add_argument('--save_dir', required=True)
327 | parser.add_argument('--num_labels', default=-1, type=int,
328 | help='if -1, the number of labels is inferred from the training data; '
329 | 'for larger datasets, precompute this and set it manually')
330 | parser.add_argument('--num_samples', default=None, type=int)
331 | parser.add_argument("--lr_scheduler",
332 | default="linear",
333 | choices=arg_to_scheduler_choices,
334 | metavar=arg_to_scheduler_metavar,
335 | type=str,
336 | help="Learning rate scheduler")
337 | args = parser.parse_args()
338 |
339 | if args.input_dir is not None:
340 | files = glob.glob(args.input_dir + '/*')
341 | for f in files:
342 | fname = f.split('/')[-1]
343 | if 'train' in fname:
344 | args.train_file = f
345 | elif 'dev' in fname or 'val' in fname:
346 | args.dev_file = f
347 | elif 'test' in fname:
348 | args.test_file = f
349 | return args
350 |
351 | def get_train_params(args):
352 | train_params = {}
353 | train_params["precision"] = 16 if args.fp16 else 32
354 | if (isinstance(args.gpus, int) and args.gpus > 1) or (isinstance(args.gpus, list) and len(args.gpus) > 1):
355 | train_params["distributed_backend"] = "ddp"
356 | else:
357 | train_params["distributed_backend"] = None
358 | train_params["accumulate_grad_batches"] = args.grad_accum
359 | train_params['track_grad_norm'] = -1
360 | train_params['limit_val_batches'] = args.limit_val_batches
361 | train_params['val_check_interval'] = args.val_check_interval
362 | train_params['gpus'] = args.gpus
363 | train_params['max_epochs'] = args.num_epochs
364 | return train_params
365 |
366 | def main():
367 | args = parse_args()
368 | random.seed(args.seed)
369 | np.random.seed(args.seed)
370 | torch.manual_seed(args.seed)
371 | if torch.cuda.is_available():
372 | torch.cuda.manual_seed_all(args.seed)
373 | if isinstance(args.gpus, str) and ',' in args.gpus:
374 | args.gpus = list(map(int, args.gpus.split(',')))
375 | args.total_gpus = len(args.gpus)
376 | else:
377 | args.gpus = int(args.gpus)
378 | args.total_gpus = args.gpus
379 |
380 | def infer_num_labels(args):
381 | # The dataset is also constructed inside the model; here we only need to read the labels, so the tokenizer (passed as a path string) and seqlen are never actually used
382 | ds = ClassificationDataset(args.train_file, tokenizer=args.tokenizer, seqlen=4096)
383 | num_labels = len(ds.label_to_idx)
384 | return num_labels
385 |
386 | if args.test_only:
387 | print('loading model...')
388 | if args.num_labels == -1:
389 | args.num_labels = infer_num_labels(args)
390 | model = LongformerClassifier.load_from_checkpoint(args.test_checkpoint, num_labels=args.num_labels)
391 | trainer = pl.Trainer(gpus=args.gpus, limit_test_batches=args.test_percent_check)
392 | trainer.test(model)
393 |
394 | else:
395 | if args.num_labels == -1:
396 | args.num_labels = infer_num_labels(args)
397 | model = LongformerClassifier(args)
398 |
399 | # default logger used by trainer
400 | logger = TensorBoardLogger(
401 | save_dir=args.save_dir,
402 | version=0,
403 | name='pl-logs'
404 | )
405 |
406 | # the checkpoint filename template must not be an f-string so that ModelCheckpoint can fill in {epoch} and {acc}
407 | filepath = f'{args.save_dir}/version_{logger.version}/checkpoints/' + 'ep-{epoch}_acc-{acc:.3f}'
408 | checkpoint_callback = ModelCheckpoint(
409 | filepath=filepath,
410 | save_top_k=1,
411 | verbose=True,
412 | monitor='acc',
413 | mode='max',
414 | prefix=''
415 | )
416 |
417 | extra_train_params = get_train_params(args)
418 |
419 | trainer = pl.Trainer(logger=logger,
420 | checkpoint_callback=checkpoint_callback,
421 | **extra_train_params)
422 |
423 | trainer.fit(model)
424 |
425 | if args.do_predict:
426 | # Optionally, predict and write to output_dir
427 | fpath = glob.glob(checkpoint_callback.dirpath + '/*.ckpt')[0]
428 | model = LongformerClassifier.load_from_checkpoint(fpath)
429 | model.hparams = args  # assign args first so the single-gpu overrides below are not clobbered
430 | model.hparams.num_gpus = 1
431 | model.hparams.total_gpus = 1
432 | model.hparams.dev_file = args.dev_file
433 | model.hparams.test_file = args.test_file
434 | model.hparams.train_file = args.dev_file  # the model won't be trained; use the dev file so the (unused) train loader builds faster
435 | trainer = pl.Trainer(gpus=1, limit_test_batches=1.0, limit_train_batches=0.01, limit_val_batches=0.01, precision=extra_train_params['precision'])
436 | trainer.test(model)
437 |
438 | if __name__ == '__main__':
439 | main()
440 |
--------------------------------------------------------------------------------