├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── clue_small_wwm_data
├── dataset.arrow
├── dataset_info.json
├── place_holder.txt
└── state.json
├── data
├── clue_corpus_small_14g.jsonl
├── refids.txt
└── reftext.txt
├── demo.ipynb
├── figure
└── tnews.jpg
├── flash
├── __init__.py
├── flash.py
├── flash_lucidrains.py
└── gau.py
├── mlm_trainer.py
├── pretrain.sh
├── requirements.txt
├── run_chinese_ref.py
├── run_mlm_wwm.py
└── trans_to_json.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-language=python
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FLASHQuad_pytorch & FLASH_pytorch
2 | pytorch implement of FLASHQuad and FLASH
3 |
4 | # Describtion
5 | 个人实现`pytorch`版本的[《Transformer Quality in Linear Time》](https://arxiv.org/abs/2202.10447)
6 |
7 | # 存在的问题
8 | - `A = square(relu(qk / seq_len + bias))`感觉不对劲,假设训练的时候都是在seq_len=512的长度上进行的,如果预测的时候seq_len=16时,A的结果会发生很大的变化。
9 | - `embedding`部分和`MLM head`不确定是否使用的是`ScaleNorm`,不确定是否使用到`dropout`。
10 | - 发现当前代码训练出来的模型结果不理想,`n-1`层的输出和`n`层的输出差距不大。
11 |
12 |
13 | # 更新
14 | - 2022/04/01 添加[lucidrains实现的FLASH-pytorch](https://github.com/lucidrains/FLASH-pytorch),并添加训练好的权重[flash_small_wwm_cluecorpussmall权重](https://huggingface.co/junnyu/flash_small_wwm_cluecorpussmall) 和 [训练日志](https://wandb.ai/junyu/huggingface/runs/1jg2jlgt)。
15 | - 2022/03/01 使用带有`mlm_acc`的`Trainer`,训练过程中可以监控训练集`每logging_steps`的MLM准确率。
16 | - 2022/02/28 添加[checkpoint-170000的small权重](https://huggingface.co/junnyu/flashquad_small_wwm_cluecorpussmall),[训练日志1](https://wandb.ai/junyu/huggingface/runs/ofdc74wr)和[训练日志2](https://wandb.ai/junyu/huggingface/runs/2ep6cl14),感觉结果不理想。
17 | - 2022/02/26 修改了`rel_pos_bias`部分的代码,发现之前的代码会出现输出异常(训练是在512长度进行的,在别的长度进行测试,模型的输出会出问题. )
18 | ```python
19 | # 之前的代码.
20 | bias = self.rel_pos_bias(seq_len)
21 | kernel = torch.square(torch.relu(qk / seq_len + bias))
22 | # 更新后的代码.
23 | self.max_position_embeddings = 512
24 | bias = self.rel_pos_bias(self.max_position_embeddings)[:, :seq_len, :seq_len]
25 | kernel = torch.square(torch.relu(qk / self.max_position_embeddings + bias))
26 | ```
27 |
28 | # Usage
29 | ```python
30 | # flashquad
31 | from flash import FLASHQuadConfig, FLASHQuadModel
32 | import torch
33 | config = FLASHQuadConfig()
34 | model = FLASHQuadModel(config)
35 | model.eval()
36 | input_ids = torch.randint(0,12000,(4,128))
37 | with torch.no_grad():
38 | outputs = model(input_ids=input_ids, output_attentions=True, output_hidden_states=True)
39 | print(outputs)
40 |
41 | # flash
42 | from flash import FLASHConfig, FLASHModel
43 | import torch
44 | config = FLASHConfig()
45 | model = FLASHModel(config)
46 | model.eval()
47 | input_ids = torch.randint(0, 12000, (4, 128))
48 | with torch.no_grad():
49 | outputs = model(
50 | input_ids=input_ids, output_attentions=True, output_hidden_states=True
51 | )
52 | print(outputs)
53 | ```
54 |
55 | # Pretrain
56 | ## 准备数据
57 | ## CLUECorpusSmall 数据集处理教程(摘抄自[paddlenlp](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/examples/language_model/data_tools/README.md))
58 | **数据集简介**:可用于语言建模、预训练或生成型任务等,数据量超过14G,近4000个定义良好的txt文件、50亿个字。主要部分来自于nlp_chinese_corpus项目
59 | 包含如下子语料库(总共14G语料):新闻语料[news2016zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/6bac09db4e6d4857b6d680d34447457490cb2dbdd8b8462ea1780a407f38e12b?responseContentDisposition=attachment%3B%20filename%3Dnews2016zh_corpus.zip), 社区互动语料[webText2019zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/83da03f7b4974871a52348b41c16c7e3b34a26d5ca644f558df8435be4de51c3?responseContentDisposition=attachment%3B%20filename%3DwebText2019zh_corpus.zip),维基百科语料[wiki2019zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/d7a166408d8b4ffdaf4de9cfca09f6ee1e2340260f26440a92f78134d068b28f?responseContentDisposition=attachment%3B%20filename%3Dwiki2019zh_corpus.zip),评论数据语料[comment2019zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/b66ddd445735408383c42322850ac4bb82faf9cc611447c2affb925443de7a6d?responseContentDisposition=attachment%3B%20filename%3Dcomment2019zh_corpus.zip)。
60 |
61 | **数据集下载**:
62 | 用户可以通过官方github网页下载,https://github.com/CLUEbenchmark/CLUECorpus2020 。同时,为方便用户,我们也提供了aistudio数据集下载地址。[part1](https://aistudio.baidu.com/aistudio/datasetdetail/60598),[part2](https://aistudio.baidu.com/aistudio/datasetdetail/124357)。使用aistudio版本的数据,下载好后,可以核对md5值:
63 | ```shell
64 | > md5sum ./*
65 | 8a8be341ebce39cfe9524fb0b46b08c5 ./comment2019zh_corpus.zip
66 | 4bdc2c941a7adb4a061caf273fea42b8 ./news2016zh_corpus.zip
67 | fc582409f078b10d717caf233cc58ddd ./webText2019zh_corpus.zip
68 | 157dacde91dcbd2e52a60af49f710fa5 ./wiki2019zh_corpus.zip
69 | ```
70 | (1) 解压文件
71 | ```shell
72 | unzip comment2019zh_corpus.zip -d clue_corpus_small_14g/comment2019zh_corpus
73 | unzip news2016zh_corpus.zip -d clue_corpus_small_14g/news2016zh_corpus
74 | unzip webText2019zh_corpus.zip -d clue_corpus_small_14g/webText2019zh_corpus
75 | unzip wiki2019zh_corpus.zip -d clue_corpus_small_14g/wiki2019zh_corpus
76 | ```
77 | (2) 将txt文件转换为jsonl格式
78 | ```shell
79 | python trans_to_json.py --input_path ./clue_corpus_small_14g --output_path clue_corpus_small_14g.jsonl
80 | mkdir data #创建data文件夹
81 | mv clue_corpus_small_14g.jsonl ./data #将jsonl放进该目录
82 | ```
83 | (3) 使用rjieba进行中文分词,会得到`data/refids.txt`和`data/reftext.txt`两个文件,并组合`data/refids.txt`和`data/reftext.txt`两个文件保存成`huggingface`的`dataset`
84 | ```shell
85 | python run_chinese_ref.py --model_name junnyu/roformer_chinese_char_base --input_path ./data/clue_corpus_small_14g.jsonl
86 | ```
87 |
88 | ## 开始训练(small版本模型)
89 | ```bash
90 | TRAIN_DIR=./clue_small_wwm_data
91 | OUTPUT_DIR=./wwm_flash_small/
92 | BATCH_SIZE=32
93 | ACCUMULATION=4
94 | LR=1e-4
95 | python run_mlm_wwm.py \
96 | --do_train \
97 | --tokenizer_name junnyu/roformer_chinese_char_base \
98 | --train_dir $TRAIN_DIR \
99 | --output_dir $OUTPUT_DIR \
100 | --logging_dir $OUTPUT_DIR/logs \
101 | --per_device_train_batch_size $BATCH_SIZE \
102 | --gradient_accumulation_steps $ACCUMULATION \
103 | --learning_rate $LR \
104 | --weight_decay 0.01 \
105 | --adam_epsilon 1e-6 \
106 | --max_steps 250000 \
107 | --warmup_steps 5000 \
108 | --logging_steps 100 \
109 | --save_steps 5000 \
110 | --seed 2022 \
111 | --max_grad_norm 3.0 \
112 | --dataloader_num_workers 6 \
113 | --fp16
114 |
115 | ```
116 |
117 | # MLM测试
118 | ```python
119 | # flashquad
120 | import torch
121 | from flash import FLASHQuadForMaskedLM
122 | from transformers import BertTokenizerFast
123 | tokenizer = BertTokenizerFast.from_pretrained("junnyu/flashquad_small_wwm_cluecorpussmall")
124 | model = FLASHQuadForMaskedLM.from_pretrained("junnyu/flashquad_small_wwm_cluecorpussmall")
125 | model.eval()
126 | text = "天气预报说今天的天[MASK]很好,那么我[MASK]一起去公园玩吧!"
127 | inputs = tokenizer(text, return_tensors="pt")
128 | with torch.no_grad():
129 | pt_outputs = model(**inputs).logits[0]
130 |
131 | pt_outputs_sentence = "pytorch: "
132 | for i, id in enumerate(tokenizer.encode(text)):
133 | if id == tokenizer.mask_token_id:
134 | val,idx = pt_outputs[i].softmax(-1).topk(k=5)
135 | tokens = tokenizer.convert_ids_to_tokens(idx)
136 | new_tokens = []
137 | for v,t in zip(val.cpu(),tokens):
138 | new_tokens.append(f"{t}+{round(v.item(),4)}")
139 | pt_outputs_sentence += "[" + "||".join(new_tokens) + "]"
140 | else:
141 | pt_outputs_sentence += "".join(
142 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
143 | print(pt_outputs_sentence)
144 | # pytorch: 天气预报说今天的天[气+0.9948||空+0.0011||色+0.0007||候+0.0004||势+0.0003]很好,那么我[就+0.4915||们+0.4186||也+0.0753||还+0.0021||都+0.0016]一起去公园玩吧!
145 |
146 | # flash
147 | import torch
148 | from flash import FLASHForMaskedLM
149 | from transformers import BertTokenizerFast
150 | tokenizer = BertTokenizerFast.from_pretrained("junnyu/flash_small_wwm_cluecorpussmall")
151 | model = FLASHForMaskedLM.from_pretrained("junnyu/flash_small_wwm_cluecorpussmall")
152 | model.eval()
153 | text = "天气预报说今天的天[MASK]很好,那么我[MASK]一起去公园玩吧!"
154 | inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=512, return_token_type_ids=False) #这里必须是512,不然结果可能不对。
155 | with torch.no_grad():
156 | pt_outputs = model(**inputs).logits[0]
157 |
158 | pt_outputs_sentence = "pytorch: "
159 | for i, id in enumerate(tokenizer.encode(text)):
160 | if id == tokenizer.mask_token_id:
161 | val,idx = pt_outputs[i].softmax(-1).topk(k=5)
162 | tokens = tokenizer.convert_ids_to_tokens(idx)
163 | new_tokens = []
164 | for v,t in zip(val.cpu(),tokens):
165 | new_tokens.append(f"{t}+{round(v.item(),4)}")
166 | pt_outputs_sentence += "[" + "||".join(new_tokens) + "]"
167 | else:
168 | pt_outputs_sentence += "".join(
169 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
170 | print(pt_outputs_sentence)
171 | # pytorch: 天气预报说今天的天[气+0.9938||天+0.0017||空+0.0011||晴+0.0007||阳+0.0002]很好,那么我[们+0.9367||就+0.0554||也+0.0041||俩+0.0005||还+0.0004]一起去公园玩吧!
172 | ```
173 |
174 | # Tnews分类
175 |
176 |
177 |
178 |
179 | # Tips
180 | 不怎么确定实现的对不对,如果代码有错误的话,请帮我指出来,谢谢~
181 |
--------------------------------------------------------------------------------
/clue_small_wwm_data/dataset.arrow:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunnYu/FLASHQuad_pytorch/e5902617f4573c9edd967313eba8f01234b5cebf/clue_small_wwm_data/dataset.arrow
--------------------------------------------------------------------------------
/clue_small_wwm_data/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "builder_name": null,
3 | "citation": "",
4 | "config_name": null,
5 | "dataset_size": null,
6 | "description": "",
7 | "download_checksums": null,
8 | "download_size": null,
9 | "features": {
10 | "text": {
11 | "dtype": "string",
12 | "id": null,
13 | "_type": "Value"
14 | },
15 | "chinese_ref": {
16 | "feature": {
17 | "dtype": "int64",
18 | "id": null,
19 | "_type": "Value"
20 | },
21 | "length": -1,
22 | "id": null,
23 | "_type": "Sequence"
24 | }
25 | },
26 | "homepage": "",
27 | "license": "",
28 | "post_processed": null,
29 | "post_processing_size": null,
30 | "size_in_bytes": null,
31 | "splits": null,
32 | "supervised_keys": null,
33 | "task_templates": null,
34 | "version": null
35 | }
--------------------------------------------------------------------------------
/clue_small_wwm_data/place_holder.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunnYu/FLASHQuad_pytorch/e5902617f4573c9edd967313eba8f01234b5cebf/clue_small_wwm_data/place_holder.txt
--------------------------------------------------------------------------------
/clue_small_wwm_data/state.json:
--------------------------------------------------------------------------------
1 | {
2 | "_data_files": [
3 | {
4 | "filename": "dataset.arrow"
5 | }
6 | ],
7 | "_fingerprint": "602b282613955525",
8 | "_format_columns": null,
9 | "_format_kwargs": {},
10 | "_format_type": null,
11 | "_indexes": {},
12 | "_output_all_columns": false,
13 | "_split": null
14 | }
--------------------------------------------------------------------------------
/data/clue_corpus_small_14g.jsonl:
--------------------------------------------------------------------------------
1 | {"text":"如何评价中企为打捞世越号赔本11亿?韩国媒体报道,上海打捞局和韩国政府签定的合同为916亿韩元,打捞局现已花费2800亿韩元。首先声明,此为本人拙见,仅作为猜测,欢迎大家拍砖!中国有实力进行海上大型船舶打捞作业企业一共有三家,都隶属于交通运输部。按北中南分布,分别是烟台打捞局上海打捞局广州打捞局,都是根正苗红的国企(国家意志)。合同签署的时间点很重要,2015年5月起,韩国对世越号沉船打捞项目公开招标。这个时间点中韩关系处在“蜜月期”。小弟出事了,大哥搭把手。再一个合同的价格,2015年8月,由中国交通运输部下属上海打捞"}
2 | {"text":"味道很不错,就是地方有点小。总体来说是很不错的,值得一吃哦。喜欢吃西餐的朋友可以去一下,反正我是听喜欢这种氛围的。"}
3 | {"text":"有哪些法律术语容易被大众误解?比如纳税人,挤了30种出来,大脑短路搞忘了几个点。以下全是干货。1.法人和法人的法定代表人。法人是法律拟制之人,自然人是一个与之并列的概念,一个自然人是万万成为不了法人的,只能成为法定代表人。刑法中对妇女的规定是已满十四周岁的女性(生日的第二天起算),大家通常理解为非处女。不是关两年再执行,而是给犯人一个不死的机会。被判死缓的人类似于坐在达摩克利斯之剑下,两年不故意犯罪就可减刑为无期,有"}
4 | {"text":"这家粥铺在上海春城商业街上也算开的时间较久了。个人觉得还是有存在价值。粥还是比较营养的,品种一般。就是小菜有点贵。"}
5 | {"text":"非常喜欢这本书封面设计很特别字体也很清楚。内容就更不用说了与电影相比原著更值得一读让人为之动容。"}
6 | {"text":"佩兰首秀遭球迷调侃“世界杯I组,中国队首胜”(组图)主教练佩兰小试牛刀 新华社发 主教练佩兰小试牛刀新华社发 热身赛 中国 2:0 马其顿 热身赛中国2:0马其顿0618AP08于汉超庆祝进球新华社发昨晚,中国男足时隔13年后重回福地沈阳,依靠于汉超和高迪在下半场的两粒进球,最终2比0战胜对手。日期“撞车”世界杯,昨天沈阳下了一场雷暴雨,原本想大赚一笔的黄牛,却在赛前把100元一张的门票贱卖至5块也无人问津,幽怨地对记者说:“回家没法跟老婆交待,因为赔惨了。”而在世界杯这个热闹的大背景之下,国足自然也难逃被调侃的命运。法国人佩兰执掌国足的首秀,被不少球迷称为中国队在“世界杯I组(世界杯只有8个小组,从A-H组)的首胜”。3天之后,双方将进行第二场的较量。特约记者 程玲林备战揭秘国脚聊侃世界杯奥乔亚被赞“来自星星”因为晚上有比赛,昨天是国足集训以来唯一白天没安排训练的一天。不过国脚们也没人“敢看”世界杯的直播。可一觉醒来,奥乔亚竟然成为了国脚们的谈资,就连主教练佩兰的翻译也私下里透露,“佩兰虽然没看(世界杯直播)比赛,但他却认为\"是奥乔亚打败了巴西队\"。”奥乔亚是谁?这在巴西与墨西哥赛前可能没有中国球迷能认得。"}
7 | {"text":"可几乎在一夜之间,奥乔亚这个名字就传遍了大江南北。回到家乡沈阳的国脚刘建业,更是用“那是来自星星的你”,这句纯正的东北话让队友们笑翻了。因为佩兰在国足集训首日就已经公布了球队的作息时间表,因此尽管在沈阳这几天相对时间宽松,不过国脚们还是遵守纪律,没人敢“偷看世界杯直播”。可早上吃饭时,巴西队被墨西哥逼平,而且功臣是奥乔亚,也很快成为大家的谈资。“墨西哥就不缺这样的(守门员)。”于汉超说,“之前不有个什么花蝴蝶么(坎波斯)……”“不是常说,\"好的守门员等于半支球队\"么,这家伙(奥乔亚)真是\"神兽\"。”张呈栋冷不丁来上这一句。“是来自星星的奥乔亚!”刘建业用他那东北话说。向群世界杯·链接国足不参赛照样抢钱中国足协稳赚75万美元听说过躺着中枪的,现在有了躺着赚钱的:由于国际足联提高了巴西世界杯的总奖金,即使没有资格出现在巴西,中国足协也能从国际足联那里分到75万美元的分红。2010年南非世界杯,32支参赛队共获得了4.2亿美元奖金,这个数值在今年将达到5.76亿。巴西世界杯,国际足联将破纪录地获得45亿美元收入,除了给参赛队5.76亿之外,国际足联还将提供2亿美元作为分红发放给各国足协,所有209个成员国将在7月底收到25万美元,明年年初收到后续50万美元。"}
8 | {"text":"这意味着,就算没有资格在巴西亮相,中国足协也可以收到75万美元分红,尽管比起参赛球队至少可以获得900万美元的数字少了不少,但躺着赚钱的中国足协也应该满足了。陈甘露 佩帅点将首发起用了5名新球员巴西世界杯激战正酣,而在地球的另一端,已连续几届无法打进世界杯预选赛亚洲区决赛阶段的中国队与欧洲鱼腩马其顿队进行了一场友谊赛,而这也是二月份挂帅国足帅印的法籍主帅佩兰,上任后所执教的首场比赛。此役,新官上任三把火的佩兰,一改中国男足此前重用老球员的模式,首发名单中有多达5名球员第一次为国足先发上场。面对欧洲弱旅马其顿,汇集目前国内最好球员的中国男足并未能在上半场展现出优势,反倒给马其顿队几次反击打得狼狈十足。下半场开始后,佩兰对场上人员进行了大幅度调整。第56分钟,刚刚转会广州恒大的于汗超依靠一次单刀机会为中国队破门得分,此后上海绿地的高迪在临近比赛结束时一脚远射,将比分改写为2比0,最终中国队依靠这两粒进球轻取对手,而佩兰也在执掌此役后,向国人交出了一份及格答卷。值得一提的是,在佩兰阵容中,任航、吴曦分别坐镇左边后卫和后腰位置。从昨天的表现来看,两人发挥可谓中规中距,尽到自己的责任。特别是作为新人,第一次代表中国男足先发出场的任航,在面对马其顿锋线球员的冲击,表现出其在俱乐部的真实水平,虽然在上半场临近结束时有一次失位,险些让对方投机得手,但整场比赛数据显示,任航在与对方前锋一对一的抢断成功率高达82.3%,这是全场防守球员数据最好的。"}
9 | {"text":"观战花絮天气不佳,上座率不足千人昨晚,国足时隔13年后重回福地沈阳进行比赛。13年前的沈阳五里河,对于中国球迷来说,是唯一残存在脑海里的幸福瞬间。而今,沈阳五里河早已拆为平地建成了商品房,中国男足却在遥望巴西的同时,与世界杯渐行渐远。或许是沈阳昨天下了一整天雷暴雨的原因,又或者是大家在世界杯赛期间,无心关注国足的比赛。截至到开赛前,沈阳奥林匹克体育中心的上座率不足千人,而场外的黄牛则叫苦连连。原本期待国足13年后重回沈阳,能重展当年一票难求的球市,但实际上,黄牛在临开场把自购100元一张的球票贱卖5元都无人问津,记者现场采访了一位持数张球票在苦寻买主的黄牛,得到的答案是:“别提了,赔惨了!回家都不晓得怎么跟老婆交待……”现场的场景有点凄惨,而国足在网络上也成为了球迷调侃的苦主。昨天,国足的比赛用球也是本届世界杯的官方用球—桑巴荣耀,有球迷在网络留言中大玩转折,称“中国队终于踢上了世界杯……用球”。翻看FIFA比赛系统,6-7月在案的国际比赛并不多,只有几场印尼、越南和尼泊尔的热身赛。当32强在巴西厮杀正酣的时候,国足却只能拉来马其顿二队进行热身赛,这样的情形确实虐心。因此这场热身赛,也被众多网友戏谑地称为“世界杯I组比赛”,并恭喜国足取得比赛胜利。"}
10 | {"text":"去杭州玩随便找了这家吃的,味道还行,糖醋小排挺好吃的,还有蚝油生菜,本来不知道什么是片儿川,现在总算知道了-。- 牛蛙太辣了,没吃几口。鸭舌就算了,冷菜而已。"}
11 |
--------------------------------------------------------------------------------
/data/refids.txt:
--------------------------------------------------------------------------------
1 | [2, 4, 6, 9, 11, 14, 19, 21, 22, 23, 26, 28, 32, 33, 34, 36, 39, 44, 45, 48, 51, 53, 56, 57, 60, 62, 67, 69, 73, 75, 78, 80, 82, 85, 88, 90, 92, 94, 96, 98, 100, 102, 104, 107, 111, 112, 114, 115, 116, 120, 122, 124, 127, 130, 132, 135, 137, 140, 142, 148, 149, 150, 153, 156, 158, 162, 164, 167, 171, 180, 182, 183, 186, 188, 190, 192, 193, 194, 197, 199, 202, 203, 204, 206, 209, 214, 216, 220, 223, 227, 229, 232, 241, 243, 244, 245, 247, 250, 252]
2 | [2, 5, 8, 10, 12, 16, 18, 22, 26, 28, 32, 35, 38, 40, 43, 46, 51, 53, 55]
3 | [3, 5, 7, 9, 12, 14, 17, 19, 20, 27, 30, 32, 37, 41, 43, 45, 50, 53, 56, 57, 58, 62, 65, 67, 72, 73, 76, 80, 83, 86, 88, 89, 92, 94, 96, 98, 102, 104, 106, 107, 108, 112, 116, 119, 124, 125, 126, 129, 132, 135, 136, 142, 144, 146, 150, 153, 156, 159, 162, 165, 167, 169, 172, 175, 177, 181, 184, 186, 188, 189, 191, 192, 195, 198, 199, 200, 204, 207]
4 | [2, 7, 9, 11, 12, 16, 19, 21, 25, 27, 29, 32, 34, 38, 40, 42, 46, 48, 51, 53, 55]
5 | [2, 4, 7, 9, 10, 11, 14, 16, 20, 23, 27, 28, 32, 34, 36, 39, 41, 45, 46, 47]
6 | [2, 4, 7, 9, 12, 13, 18, 19, 21, 25, 28, 29, 31, 33, 34, 35, 40, 41, 44, 45, 47, 49, 50, 51, 53, 54, 57, 58, 63, 74, 75, 77, 78, 80, 85, 86, 92, 93, 95, 97, 99, 100, 103, 106, 107, 108, 110, 115, 117, 119, 122, 124, 125, 128, 131, 132, 135, 137, 140, 145, 147, 150, 153, 156, 157, 160, 162, 166, 168, 169, 172, 174, 177, 180, 185, 190, 193, 195, 201, 202, 203, 206, 210, 215, 217, 220, 222, 225, 227, 234, 235, 237, 239, 243, 245, 248, 250, 253, 256, 259, 262, 263, 265, 267, 269, 272, 276, 278, 280, 282, 283, 287, 288, 293, 294, 296, 300, 310, 316, 319, 322, 324, 325, 328, 331, 332, 333, 335, 336, 338, 340, 342, 344, 346, 347, 349, 350, 355, 357, 360, 362, 365, 368, 371, 373, 375, 377, 379, 382, 384, 387, 390, 392, 396, 402, 403, 406, 409, 411, 415, 416, 418, 420, 423, 427, 432, 433, 435, 438, 441, 442, 444, 448, 450, 455, 456, 458, 461, 464, 465, 467, 471, 472, 474, 477, 478, 483, 484, 491, 494, 495, 497, 499, 501, 503, 505, 508]
7 | [3, 6, 7, 8, 11, 12, 14, 16, 19, 22, 23, 24, 27, 29, 31, 34, 36, 37, 40, 46, 48, 54, 56, 59, 63, 70, 72, 75, 77, 79, 82, 84, 87, 90, 91, 92, 96, 98, 101, 104, 106, 108, 110, 113, 115, 118, 120, 121, 122, 125, 126, 129, 131, 132, 134, 139, 141, 145, 146, 149, 150, 152, 155, 157, 160, 161, 165, 167, 169, 172, 176, 177, 180, 182, 186, 187, 192, 193, 198, 200, 203, 205, 206, 211, 218, 220, 226, 227, 229, 231, 233, 239, 242, 243, 246, 249, 254, 255, 257, 258, 263, 268, 270, 273, 274, 278, 279, 284, 291, 292, 295, 297, 300, 302, 304, 306, 307, 308, 310, 313, 314, 316, 321, 325, 331, 335, 337, 338, 339, 341, 344, 346, 347, 351, 354, 356, 358, 360, 363, 366, 367, 368, 373, 374, 375, 377, 379, 382, 383, 386, 391, 393, 394, 399, 400, 403, 409, 410, 412, 415, 417, 420, 423, 430, 432, 433, 436, 437, 438, 441, 442, 445, 448, 449, 451, 454, 457, 458, 464, 467, 468, 469, 473, 476, 477, 479, 481, 483, 486, 488, 491, 495, 496, 501, 503, 506, 507, 510, 512, 514, 516, 519, 520]
8 | [3, 4, 7, 9, 11, 14, 16, 19, 20, 21, 24, 26, 29, 30, 32, 35, 37, 39, 41, 43, 45, 47, 50, 51, 54, 58, 64, 67, 68, 69, 72, 74, 79, 81, 83, 85, 87, 91, 93, 95, 97, 98, 100, 102, 107, 110, 111, 115, 117, 119, 123, 124, 126, 127, 129, 130, 132, 134, 137, 138, 141, 143, 145, 146, 149, 152, 154, 155, 162, 163, 165, 169, 172, 174, 176, 179, 183, 186, 188, 191, 194, 195, 196, 197, 198, 199, 202, 205, 207, 208, 209, 211, 213, 216, 219, 222, 224, 226, 228, 232, 234, 235, 237, 238, 240, 242, 245, 247, 249, 251, 252, 255, 257, 259, 261, 263, 266, 267, 268, 270, 274, 275, 277, 278, 280, 283, 286, 287, 290, 292, 296, 298, 301, 302, 304, 308, 311, 313, 315, 318, 319, 321, 326, 329, 331, 333, 335, 341, 343, 345, 347, 350, 351, 353, 355, 358, 360, 362, 365, 368, 370, 372, 375, 377, 381, 383, 390, 392, 393, 395, 398, 400, 402, 404, 408, 412, 414, 419, 421, 424, 426, 428, 431, 432, 433, 434, 435, 439, 441, 445, 448, 450, 452, 454, 455, 456, 459, 461, 465, 468, 470, 473, 475, 477, 479, 481, 484, 486, 489, 492, 495, 497, 500, 501, 503, 505, 506, 507, 509, 514, 518, 520, 521, 523, 525, 528, 531, 533, 536, 537, 540, 542, 545, 548, 549, 551, 553, 555, 557, 559, 562, 565, 567, 569, 573, 575, 577, 579, 582, 586, 588, 590, 591, 594, 596, 597, 598, 606, 608, 610, 612, 614, 616]
9 | [2, 4, 6, 8, 11, 12, 14, 16, 18, 21, 23, 28, 30, 32, 34, 36, 40, 43, 45, 46, 49, 51, 53, 55, 59, 61, 64, 68, 70, 73, 76, 78, 79, 81, 83, 85, 87, 90, 91, 94, 95, 96, 100, 102, 105, 109, 110, 112, 114, 117, 120, 122, 126, 127, 129, 130, 133, 137, 140, 143, 144, 145, 147, 150, 152, 154, 157, 160, 163, 167, 169, 170, 171, 173, 174, 175, 178, 179, 181, 183, 187, 190, 193, 195, 198, 200, 202, 207, 209, 213, 215, 217, 219, 222, 226, 227, 230, 234, 237, 241, 244, 246, 251, 252, 253, 256, 258, 259, 260, 263, 266, 268, 273, 276, 279, 282, 287, 291, 295, 299, 301, 304, 306, 311, 314, 316, 318, 321, 322, 325, 329, 332, 334, 337, 340, 343, 346, 348, 352, 354, 355, 358, 360, 363, 365, 369, 372, 374, 376, 377, 379, 384, 385, 387, 392, 393, 397, 401, 404, 406, 413, 416, 418, 420, 421, 424, 426, 428, 431, 434, 435, 438, 439, 446, 448, 450, 453, 456, 459, 463, 464, 466, 468, 470, 471, 474, 477, 479, 481, 484, 486, 488, 489, 494, 496, 498, 501, 504, 505, 509, 514, 516, 518, 520, 522]
10 | [3, 6, 10, 15, 20, 22, 24, 25, 29, 31, 33, 36, 39, 41, 44, 48, 50, 52, 58, 60, 66, 69, 71, 75, 77]
--------------------------------------------------------------------------------
/data/reftext.txt:
--------------------------------------------------------------------------------
1 | 如何评价中企为打捞世越号赔本11亿?韩国媒体报道,上海打捞局和韩国政府签定的合同为916亿韩元,打捞局现已花费2800亿韩元。首先声明,此为本人拙见,仅作为猜测,欢迎大家拍砖!中国有实力进行海上大型船舶打捞作业企业一共有三家,都隶属于交通运输部。按北中南分布,分别是烟台打捞局上海打捞局广州打捞局,都是根正苗红的国企(国家意志)。合同签署的时间点很重要,2015年5月起,韩国对世越号沉船打捞项目公开招标。这个时间点中韩关系处在“蜜月期”。小弟出事了,大哥搭把手。再一个合同的价格,2015年8月,由中国交通运输部下属上海打捞
2 | 味道很不错,就是地方有点小。总体来说是很不错的,值得一吃哦。喜欢吃西餐的朋友可以去一下,反正我是听喜欢这种氛围的。
3 | 有哪些法律术语容易被大众误解?比如纳税人,挤了30种出来,大脑短路搞忘了几个点。以下全是干货。1.法人和法人的法定代表人。法人是法律拟制之人,自然人是一个与之并列的概念,一个自然人是万万成为不了法人的,只能成为法定代表人。刑法中对妇女的规定是已满十四周岁的女性(生日的第二天起算),大家通常理解为非处女。不是关两年再执行,而是给犯人一个不死的机会。被判死缓的人类似于坐在达摩克利斯之剑下,两年不故意犯罪就可减刑为无期,有
4 | 这家粥铺在上海春城商业街上也算开的时间较久了。个人觉得还是有存在价值。粥还是比较营养的,品种一般。就是小菜有点贵。
5 | 非常喜欢这本书封面设计很特别字体也很清楚。内容就更不用说了与电影相比原著更值得一读让人为之动容。
6 | 佩兰首秀遭球迷调侃“世界杯I组,中国队首胜”(组图)主教练佩兰小试牛刀 新华社发 主教练佩兰小试牛刀新华社发 热身赛 中国 2:0 马其顿 热身赛中国2:0马其顿0618AP08于汉超庆祝进球新华社发昨晚,中国男足时隔13年后重回福地沈阳,依靠于汉超和高迪在下半场的两粒进球,最终2比0战胜对手。日期“撞车”世界杯,昨天沈阳下了一场雷暴雨,原本想大赚一笔的黄牛,却在赛前把100元一张的门票贱卖至5块也无人问津,幽怨地对记者说:“回家没法跟老婆交待,因为赔惨了。”而在世界杯这个热闹的大背景之下,国足自然也难逃被调侃的命运。法国人佩兰执掌国足的首秀,被不少球迷称为中国队在“世界杯I组(世界杯只有8个小组,从A-H组)的首胜”。3天之后,双方将进行第二场的较量。特约记者 程玲林备战揭秘国脚聊侃世界杯奥乔亚被赞“来自星星”因为晚上有比赛,昨天是国足集训以来唯一白天没安排训练的一天。不过国脚们也没人“敢看”世界杯的直播。可一觉醒来,奥乔亚竟然成为了国脚们的谈资,就连主教练佩兰的翻译也私下里透露,“佩兰虽然没看(世界杯直播)比赛,但他却认为"是奥乔亚打败了巴西队"。”奥乔亚是谁?这在巴西与墨西哥赛前可能没有中国球迷能认得。
7 | 可几乎在一夜之间,奥乔亚这个名字就传遍了大江南北。回到家乡沈阳的国脚刘建业,更是用“那是来自星星的你”,这句纯正的东北话让队友们笑翻了。因为佩兰在国足集训首日就已经公布了球队的作息时间表,因此尽管在沈阳这几天相对时间宽松,不过国脚们还是遵守纪律,没人敢“偷看世界杯直播”。可早上吃饭时,巴西队被墨西哥逼平,而且功臣是奥乔亚,也很快成为大家的谈资。“墨西哥就不缺这样的(守门员)。”于汉超说,“之前不有个什么花蝴蝶么(坎波斯)……”“不是常说,"好的守门员等于半支球队"么,这家伙(奥乔亚)真是"神兽"。”张呈栋冷不丁来上这一句。“是来自星星的奥乔亚!”刘建业用他那东北话说。向群世界杯·链接国足不参赛照样抢钱中国足协稳赚75万美元听说过躺着中枪的,现在有了躺着赚钱的:由于国际足联提高了巴西世界杯的总奖金,即使没有资格出现在巴西,中国足协也能从国际足联那里分到75万美元的分红。2010年南非世界杯,32支参赛队共获得了4.2亿美元奖金,这个数值在今年将达到5.76亿。巴西世界杯,国际足联将破纪录地获得45亿美元收入,除了给参赛队5.76亿之外,国际足联还将提供2亿美元作为分红发放给各国足协,所有209个成员国将在7月底收到25万美元,明年年初收到后续50万美元。
8 | 这意味着,就算没有资格在巴西亮相,中国足协也可以收到75万美元分红,尽管比起参赛球队至少可以获得900万美元的数字少了不少,但躺着赚钱的中国足协也应该满足了。陈甘露 佩帅点将首发起用了5名新球员巴西世界杯激战正酣,而在地球的另一端,已连续几届无法打进世界杯预选赛亚洲区决赛阶段的中国队与欧洲鱼腩马其顿队进行了一场友谊赛,而这也是二月份挂帅国足帅印的法籍主帅佩兰,上任后所执教的首场比赛。此役,新官上任三把火的佩兰,一改中国男足此前重用老球员的模式,首发名单中有多达5名球员第一次为国足先发上场。面对欧洲弱旅马其顿,汇集目前国内最好球员的中国男足并未能在上半场展现出优势,反倒给马其顿队几次反击打得狼狈十足。下半场开始后,佩兰对场上人员进行了大幅度调整。第56分钟,刚刚转会广州恒大的于汗超依靠一次单刀机会为中国队破门得分,此后上海绿地的高迪在临近比赛结束时一脚远射,将比分改写为2比0,最终中国队依靠这两粒进球轻取对手,而佩兰也在执掌此役后,向国人交出了一份及格答卷。值得一提的是,在佩兰阵容中,任航、吴曦分别坐镇左边后卫和后腰位置。从昨天的表现来看,两人发挥可谓中规中距,尽到自己的责任。特别是作为新人,第一次代表中国男足先发出场的任航,在面对马其顿锋线球员的冲击,表现出其在俱乐部的真实水平,虽然在上半场临近结束时有一次失位,险些让对方投机得手,但整场比赛数据显示,任航在与对方前锋一对一的抢断成功率高达82.3%,这是全场防守球员数据最好的。
9 | 观战花絮天气不佳,上座率不足千人昨晚,国足时隔13年后重回福地沈阳进行比赛。13年前的沈阳五里河,对于中国球迷来说,是唯一残存在脑海里的幸福瞬间。而今,沈阳五里河早已拆为平地建成了商品房,中国男足却在遥望巴西的同时,与世界杯渐行渐远。或许是沈阳昨天下了一整天雷暴雨的原因,又或者是大家在世界杯赛期间,无心关注国足的比赛。截至到开赛前,沈阳奥林匹克体育中心的上座率不足千人,而场外的黄牛则叫苦连连。原本期待国足13年后重回沈阳,能重展当年一票难求的球市,但实际上,黄牛在临开场把自购100元一张的球票贱卖5元都无人问津,记者现场采访了一位持数张球票在苦寻买主的黄牛,得到的答案是:“别提了,赔惨了!回家都不晓得怎么跟老婆交待……”现场的场景有点凄惨,而国足在网络上也成为了球迷调侃的苦主。昨天,国足的比赛用球也是本届世界杯的官方用球—桑巴荣耀,有球迷在网络留言中大玩转折,称“中国队终于踢上了世界杯……用球”。翻看FIFA比赛系统,6-7月在案的国际比赛并不多,只有几场印尼、越南和尼泊尔的热身赛。当32强在巴西厮杀正酣的时候,国足却只能拉来马其顿二队进行热身赛,这样的情形确实虐心。因此这场热身赛,也被众多网友戏谑地称为“世界杯I组比赛”,并恭喜国足取得比赛胜利。
10 | 去杭州玩随便找了这家吃的,味道还行,糖醋小排挺好吃的,还有蚝油生菜,本来不知道什么是片儿川,现在总算知道了-。- 牛蛙太辣了,没吃几口。鸭舌就算了,冷菜而已。
--------------------------------------------------------------------------------
/figure/tnews.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunnYu/FLASHQuad_pytorch/e5902617f4573c9edd967313eba8f01234b5cebf/figure/tnews.jpg
--------------------------------------------------------------------------------
/flash/__init__.py:
--------------------------------------------------------------------------------
1 | from flash.flash import FLASHQuadConfig, FLASHQuadModel, FLASHQuadForMaskedLM, FLASHQuadForSequenceClassification
2 | from flash.flash_lucidrains import FLASHConfig, FLASHModel, FLASHForMaskedLM, FLASHForSequenceClassification, FLASHForMultipleChoice, FLASHForTokenClassification, FLASHForQuestionAnswering
--------------------------------------------------------------------------------
/flash/flash.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers.configuration_utils import PretrainedConfig
4 | from transformers.modeling_outputs import (
5 | BaseModelOutput,
6 | MaskedLMOutput,
7 | SequenceClassifierOutput,
8 | )
9 | from transformers.modeling_utils import PreTrainedModel
10 | from transformers.models.bert.modeling_bert import (
11 | BertOnlyMLMHead as FLASHQuadOnlyMLMHead,
12 | )
13 | from transformers.utils import logging
14 |
15 | from flash.gau import GAU, ScaleNorm
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 |
20 | class FLASHQuadConfig(PretrainedConfig):
21 | model_type = "flash_quad"
22 |
23 | def __init__(
24 | self,
25 | vocab_size=12000,
26 | hidden_size=768,
27 | num_hidden_layers=24, # base
28 | max_position_embeddings=512,
29 | type_vocab_size=2,
30 | initializer_range=0.02,
31 | layer_norm_eps=1e-5,
32 | pad_token_id=0,
33 | expansion_factor=2,
34 | s=128,
35 | norm_type="scale_norm",
36 | gradient_checkpointing=False,
37 | dropout=0.0,
38 | hidden_act="silu",
39 | classifier_dropout=0.1,
40 | **kwargs
41 | ):
42 | super().__init__(pad_token_id=pad_token_id, **kwargs)
43 |
44 | self.vocab_size = vocab_size
45 | self.hidden_size = hidden_size
46 | self.num_hidden_layers = num_hidden_layers
47 | self.max_position_embeddings = max_position_embeddings
48 | self.type_vocab_size = type_vocab_size
49 | self.initializer_range = initializer_range
50 | self.layer_norm_eps = layer_norm_eps
51 | self.expansion_factor = expansion_factor
52 | self.s = s
53 | self.norm_type = norm_type
54 | self.dropout = dropout
55 | self.hidden_act = hidden_act
56 | self.gradient_checkpointing = gradient_checkpointing
57 | self.classifier_dropout = classifier_dropout
58 |
59 |
60 | class FLASHQuadPreTrainedModel(PreTrainedModel):
61 | config_class = FLASHQuadConfig
62 | base_model_prefix = "flash_quad"
63 |
64 | def _init_weights(self, module):
65 | """Initialize the weights"""
66 | if isinstance(module, nn.Linear):
67 | module.weight.data.normal_(
68 | mean=0.0, std=self.config.initializer_range)
69 | if module.bias is not None:
70 | module.bias.data.zero_()
71 | elif isinstance(module, nn.Embedding):
72 | module.weight.data.normal_(
73 | mean=0.0, std=self.config.initializer_range)
74 | if module.padding_idx is not None:
75 | module.weight.data[module.padding_idx].zero_()
76 | elif isinstance(module, nn.LayerNorm):
77 | module.bias.data.zero_()
78 | module.weight.data.fill_(1.0)
79 |
80 |
81 | class FLASHQuadEmbeddings(nn.Module):
82 | """Construct the embeddings from word, position and token_type embeddings."""
83 |
84 | def __init__(self, config):
85 | super().__init__()
86 | self.config = config
87 | self.word_embeddings = nn.Embedding(
88 | config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
89 | )
90 | self.token_type_embeddings = nn.Embedding(
91 | config.type_vocab_size, config.hidden_size
92 | )
93 | self.LayerNorm = (
94 | nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
95 | if config.norm_type == "layer_norm"
96 | else ScaleNorm(eps=config.layer_norm_eps)
97 | )
98 | self.dropout = nn.Dropout(config.dropout)
99 | self.register_buffer(
100 | "position_ids", torch.arange(
101 | config.max_position_embeddings).expand((1, -1))
102 | )
103 | self.scaledsin_scalar = nn.Parameter(
104 | torch.ones(1) / (config.hidden_size ** 0.5)
105 | )
106 | self.register_buffer("scaledsin_embeddings", self.get_scaledsin())
107 |
108 | def get_scaledsin(self):
109 | """Create sinusoidal position embedding with a scaling factor."""
110 | seqlen, hidden_size = (
111 | self.config.max_position_embeddings,
112 | self.config.hidden_size,
113 | )
114 | pos = torch.arange(seqlen, dtype=torch.float32)
115 | half_d = hidden_size // 2
116 |
117 | freq_seq = -torch.arange(half_d, dtype=torch.float32) / float(half_d)
118 | inv_freq = 10000 ** freq_seq
119 | sinusoid = torch.einsum("s,d->sd", pos, inv_freq)
120 | scaledsin = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1)
121 | # scalar = 1 / hidden_size ** 0.5
122 | # scaledsin *= scalar
123 | return scaledsin
124 |
125 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
126 | input_shape = input_ids.shape
127 | seq_length = input_shape[1]
128 |
129 | if position_ids is None:
130 | position_ids = self.position_ids[:, :seq_length]
131 |
132 | if token_type_ids is None:
133 | token_type_ids = torch.zeros_like(input_ids)
134 |
135 | inputs_embeds = self.word_embeddings(input_ids)
136 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
137 | position_embeddings = (
138 | self.scaledsin_embeddings[position_ids] * self.scaledsin_scalar
139 | )
140 | embeddings = inputs_embeds + token_type_embeddings + position_embeddings
141 | embeddings = self.LayerNorm(embeddings)
142 | embeddings = self.dropout(embeddings)
143 | return embeddings
144 |
145 |
146 | class FLASHQuadEncoder(nn.Module):
147 | def __init__(self, config):
148 | super().__init__()
149 | self.config = config
150 | self.layer = nn.ModuleList(
151 | [
152 | GAU(
153 | config.hidden_size,
154 | config.expansion_factor,
155 | config.s,
156 | config.norm_type,
157 | config.layer_norm_eps,
158 | config.hidden_act,
159 | config.max_position_embeddings,
160 | )
161 | for _ in range(config.num_hidden_layers)
162 | ]
163 | )
164 |
165 | def forward(
166 | self,
167 | hidden_states,
168 | attention_mask=None,
169 | output_attentions=False,
170 | output_hidden_states=False,
171 | return_dict=True,
172 | ):
173 | all_hidden_states = () if output_hidden_states else None
174 | all_self_attentions = () if output_attentions else None
175 |
176 | for i, layer_module in enumerate(self.layer):
177 | if output_hidden_states:
178 | all_hidden_states = all_hidden_states + (hidden_states,)
179 |
180 | if getattr(self.config, "gradient_checkpointing", False) and self.training:
181 |
182 | def create_custom_forward(module):
183 | def custom_forward(*inputs):
184 | return module(*inputs, output_attentions)
185 |
186 | return custom_forward
187 |
188 | layer_outputs = torch.utils.checkpoint.checkpoint(
189 | create_custom_forward(layer_module),
190 | hidden_states,
191 | attention_mask,
192 | )
193 | else:
194 | layer_outputs = layer_module(
195 | hidden_states,
196 | attention_mask,
197 | output_attentions,
198 | )
199 |
200 | hidden_states = layer_outputs[0]
201 |
202 | if output_attentions:
203 | all_self_attentions = all_self_attentions + (layer_outputs[1],)
204 |
205 | if output_hidden_states:
206 | all_hidden_states = all_hidden_states + (hidden_states,)
207 |
208 | if not return_dict:
209 | return tuple(
210 | v
211 | for v in [
212 | hidden_states,
213 | all_hidden_states,
214 | all_self_attentions,
215 | ]
216 | if v is not None
217 | )
218 | return BaseModelOutput(
219 | last_hidden_state=hidden_states,
220 | hidden_states=all_hidden_states,
221 | attentions=all_self_attentions,
222 | )
223 |
224 |
225 | class FLASHQuadModel(FLASHQuadPreTrainedModel):
226 | def __init__(self, config):
227 | super().__init__(config)
228 | self.config = config
229 |
230 | self.embeddings = FLASHQuadEmbeddings(config)
231 | self.encoder = FLASHQuadEncoder(config)
232 |
233 | self.post_init()
234 |
235 | def get_input_embeddings(self):
236 | return self.embeddings.word_embeddings
237 |
238 | def set_input_embeddings(self, value):
239 | self.embeddings.word_embeddings = value
240 |
241 | def forward(
242 | self,
243 | input_ids=None,
244 | attention_mask=None,
245 | token_type_ids=None,
246 | position_ids=None,
247 | output_attentions=None,
248 | output_hidden_states=None,
249 | return_dict=None,
250 | ):
251 | output_attentions = (
252 | output_attentions
253 | if output_attentions is not None
254 | else self.config.output_attentions
255 | )
256 | output_hidden_states = (
257 | output_hidden_states
258 | if output_hidden_states is not None
259 | else self.config.output_hidden_states
260 | )
261 | return_dict = (
262 | return_dict if return_dict is not None else self.config.use_return_dict
263 | )
264 |
265 | if attention_mask is None:
266 | attention_mask = (input_ids != self.config.pad_token_id).type_as(
267 | self.embeddings.word_embeddings.weight
268 | )
269 |
270 | if token_type_ids is None:
271 | token_type_ids = torch.zeros_like(input_ids)
272 |
273 | embedding_output = self.embeddings(
274 | input_ids=input_ids,
275 | position_ids=position_ids,
276 | token_type_ids=token_type_ids,
277 | )
278 |
279 | encoder_outputs = self.encoder(
280 | embedding_output,
281 | attention_mask=attention_mask,
282 | output_attentions=output_attentions,
283 | output_hidden_states=output_hidden_states,
284 | return_dict=return_dict,
285 | )
286 | sequence_output = encoder_outputs[0]
287 |
288 | if not return_dict:
289 | return (sequence_output,) + encoder_outputs[1:]
290 |
291 | return BaseModelOutput(
292 | last_hidden_state=sequence_output,
293 | hidden_states=encoder_outputs.hidden_states,
294 | attentions=encoder_outputs.attentions,
295 | )
296 |
297 |
298 | class FLASHQuadForMaskedLM(FLASHQuadPreTrainedModel):
299 | def __init__(self, config):
300 | super().__init__(config)
301 |
302 | self.flash_quad = FLASHQuadModel(config)
303 | self.cls = FLASHQuadOnlyMLMHead(config)
304 | self.cls.predictions.transform.LayerNorm = (
305 | nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
306 | if config.norm_type == "layer_norm"
307 | else ScaleNorm(eps=config.layer_norm_eps)
308 | )
309 | self.loss_fn = nn.CrossEntropyLoss()
310 | self.post_init()
311 |
312 | def get_output_embeddings(self):
313 | return self.cls.predictions.decoder
314 |
315 | def set_output_embeddings(self, new_embeddings):
316 | self.cls.predictions.decoder = new_embeddings
317 |
318 | def forward(
319 | self,
320 | input_ids=None,
321 | attention_mask=None,
322 | token_type_ids=None,
323 | position_ids=None,
324 | labels=None,
325 | output_attentions=None,
326 | output_hidden_states=None,
327 | return_dict=None,
328 | ):
329 |
330 | return_dict = (
331 | return_dict if return_dict is not None else self.config.use_return_dict
332 | )
333 |
334 | outputs = self.flash_quad(
335 | input_ids,
336 | attention_mask=attention_mask,
337 | token_type_ids=token_type_ids,
338 | position_ids=position_ids,
339 | output_attentions=output_attentions,
340 | output_hidden_states=output_hidden_states,
341 | return_dict=return_dict,
342 | )
343 |
344 | sequence_output = outputs[0]
345 |
346 | prediction_scores = self.cls(sequence_output)
347 |
348 | masked_lm_loss = None
349 | if labels is not None:
350 | masked_lm_loss = self.loss_fn(
351 | prediction_scores.reshape(-1, self.config.vocab_size),
352 | labels.reshape(-1),
353 | )
354 |
355 | if not return_dict:
356 | output = (prediction_scores,) + outputs[1:]
357 | return (
358 | ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
359 | )
360 |
361 | return MaskedLMOutput(
362 | loss=masked_lm_loss,
363 | logits=prediction_scores,
364 | hidden_states=outputs.hidden_states,
365 | attentions=outputs.attentions,
366 | )
367 |
368 |
369 | class FLASHQuadForSequenceClassification(FLASHQuadPreTrainedModel):
370 | def __init__(self, config):
371 | super().__init__(config)
372 | self.num_labels = config.num_labels
373 | self.config = config
374 |
375 | self.flash_quad = FLASHQuadModel(config)
376 | classifier_dropout = (
377 | config.classifier_dropout
378 | if config.classifier_dropout is not None
379 | else config.dropout
380 | )
381 | self.dropout = nn.Dropout(classifier_dropout)
382 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
383 |
384 | # Initialize weights and apply final processing
385 | self.post_init()
386 |
387 | def forward(
388 | self,
389 | input_ids=None,
390 | attention_mask=None,
391 | token_type_ids=None,
392 | position_ids=None,
393 | labels=None,
394 | output_attentions=None,
395 | output_hidden_states=None,
396 | return_dict=None,
397 | ):
398 |
399 | return_dict = (
400 | return_dict if return_dict is not None else self.config.use_return_dict
401 | )
402 |
403 | outputs = self.flash_quad(
404 | input_ids,
405 | attention_mask=attention_mask,
406 | token_type_ids=token_type_ids,
407 | position_ids=position_ids,
408 | output_attentions=output_attentions,
409 | output_hidden_states=output_hidden_states,
410 | return_dict=return_dict,
411 | )
412 | pooled_output = outputs[0][:, 0]
413 | logits = self.classifier(pooled_output)
414 |
415 | loss = None
416 | if labels is not None:
417 | if self.config.problem_type is None:
418 | if self.num_labels == 1:
419 | self.config.problem_type = "regression"
420 | elif self.num_labels > 1 and (
421 | labels.dtype == torch.long or labels.dtype == torch.int
422 | ):
423 | self.config.problem_type = "single_label_classification"
424 | else:
425 | self.config.problem_type = "multi_label_classification"
426 |
427 | if self.config.problem_type == "regression":
428 | loss_fct = nn.MSELoss()
429 | if self.num_labels == 1:
430 | loss = loss_fct(logits.squeeze(), labels.squeeze())
431 | else:
432 | loss = loss_fct(logits, labels)
433 | elif self.config.problem_type == "single_label_classification":
434 | loss_fct = nn.CrossEntropyLoss()
435 | loss = loss_fct(
436 | logits.reshape(-1, self.num_labels), labels.reshape(-1))
437 | elif self.config.problem_type == "multi_label_classification":
438 | loss_fct = nn.BCEWithLogitsLoss()
439 | loss = loss_fct(logits, labels)
440 | if not return_dict:
441 | output = (logits,) + outputs[2:]
442 | return ((loss,) + output) if loss is not None else output
443 |
444 | return SequenceClassifierOutput(
445 | loss=loss,
446 | logits=logits,
447 | hidden_states=outputs.hidden_states,
448 | attentions=outputs.attentions,
449 | )
450 |
--------------------------------------------------------------------------------
/flash/flash_lucidrains.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers.configuration_utils import PretrainedConfig
4 | from transformers.modeling_outputs import (
5 | BaseModelOutput,
6 | MaskedLMOutput,
7 | SequenceClassifierOutput,
8 | MultipleChoiceModelOutput,
9 | TokenClassifierOutput,
10 | QuestionAnsweringModelOutput,
11 | )
12 | from transformers.modeling_utils import PreTrainedModel, SequenceSummary
13 |
14 | from transformers.utils import logging
15 |
16 | logger = logging.get_logger(__name__)
17 |
18 | ###################################################copied from https://github.com/lucidrains/FLASH-pytorch
19 | import math
20 | import torch.nn.functional as F
21 | from torch import nn, einsum
22 | from einops import rearrange
23 | from rotary_embedding_torch import RotaryEmbedding
24 |
25 | # helper functions
26 |
27 |
28 | def exists(val):
29 | return val is not None
30 |
31 |
32 | def default(val, d):
33 | return val if exists(val) else d
34 |
35 |
36 | def padding_to_multiple_of(n, mult):
37 | remainder = n % mult
38 | if remainder == 0:
39 | return 0
40 | return mult - remainder
41 |
42 |
43 | # scalenorm
44 |
45 |
46 | class ScaleNorm(nn.Module):
47 | def __init__(self, dim, eps=1e-5):
48 | super().__init__()
49 | self.scale = dim ** -0.5
50 | self.eps = eps
51 | self.g = nn.Parameter(torch.ones(1))
52 |
53 | def forward(self, x):
54 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
55 | return x / norm.clamp(min=self.eps) * self.g
56 |
57 |
58 | # T5 relative positional bias
59 |
60 |
61 | class T5RelativePositionBias(nn.Module):
62 | def __init__(self, scale, causal=False, num_buckets=32, max_distance=128):
63 | super().__init__()
64 | self.scale = scale
65 | self.causal = causal
66 | self.num_buckets = num_buckets
67 | self.max_distance = max_distance
68 | self.relative_attention_bias = nn.Embedding(num_buckets, 1)
69 |
70 | @staticmethod
71 | def _relative_position_bucket(
72 | relative_position, causal=True, num_buckets=32, max_distance=128
73 | ):
74 | ret = 0
75 | n = -relative_position
76 | if not causal:
77 | num_buckets //= 2
78 | ret += (n < 0).long() * num_buckets
79 | n = torch.abs(n)
80 | else:
81 | n = torch.max(n, torch.zeros_like(n))
82 |
83 | max_exact = num_buckets // 2
84 | is_small = n < max_exact
85 |
86 | val_if_large = (
87 | max_exact
88 | + (
89 | torch.log(n.float() / max_exact)
90 | / math.log(max_distance / max_exact)
91 | * (num_buckets - max_exact)
92 | ).long()
93 | )
94 | val_if_large = torch.min(
95 | val_if_large, torch.full_like(val_if_large, num_buckets - 1)
96 | )
97 |
98 | ret += torch.where(is_small, n, val_if_large)
99 | return ret
100 |
101 | def forward(self, x):
102 | i, j, device = *x.shape[-2:], x.device
103 | q_pos = torch.arange(i, dtype=torch.long, device=device)
104 | k_pos = torch.arange(j, dtype=torch.long, device=device)
105 | rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
106 | rp_bucket = self._relative_position_bucket(
107 | rel_pos,
108 | causal=self.causal,
109 | num_buckets=self.num_buckets,
110 | max_distance=self.max_distance,
111 | )
112 | values = self.relative_attention_bias(rp_bucket)
113 | bias = rearrange(values, "i j 1 -> i j")
114 | return bias * self.scale
115 |
116 |
117 | # class
118 |
119 |
120 | class OffsetScale(nn.Module):
121 | def __init__(self, dim, heads=1):
122 | super().__init__()
123 | self.weight = nn.Parameter(torch.ones(heads, dim))
124 | self.bias = nn.Parameter(torch.zeros(heads, dim))
125 | nn.init.normal_(self.weight, std=0.02)
126 |
127 | def forward(self, x):
128 | out = einsum("... d, h d -> ... h d", x, self.weight) + self.bias
129 | return out.unbind(dim=-2)
130 |
131 |
132 | # FLASH
133 |
134 |
135 | class FLASH(nn.Module):
136 | def __init__(
137 | self,
138 | *,
139 | dim,
140 | group_size=256,
141 | query_key_dim=128,
142 | expansion_factor=2.0,
143 | causal=False,
144 | dropout=0.0,
145 | rotary_pos_emb=None,
146 | norm_klass=nn.LayerNorm,
147 | shift_tokens=False
148 | ):
149 | super().__init__()
150 | hidden_dim = int(dim * expansion_factor)
151 | self.group_size = group_size
152 | self.causal = causal
153 | self.shift_tokens = shift_tokens
154 |
155 | # positional embeddings
156 |
157 | self.rotary_pos_emb = rotary_pos_emb
158 | self.rel_pos_bias = T5RelativePositionBias(query_key_dim ** 0.5, causal=causal)
159 |
160 | # norm
161 |
162 | self.norm = norm_klass(dim)
163 | self.dropout = nn.Dropout(dropout)
164 |
165 | # projections
166 |
167 | self.to_hidden = nn.Sequential(nn.Linear(dim, hidden_dim * 2), nn.SiLU())
168 |
169 | self.to_qk = nn.Sequential(nn.Linear(dim, query_key_dim), nn.SiLU())
170 |
171 | self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
172 | self.to_out = nn.Linear(hidden_dim, dim)
173 |
174 | def forward(self, x, attention_mask=None, output_attentions=False):
175 | """
176 | b - batch
177 | n - sequence length (within groups)
178 | g - group dimension
179 | d - feature dimension (keys)
180 | e - feature dimension (values)
181 | i - sequence dimension (source)
182 | j - sequence dimension (target)
183 | """
184 | mask = attention_mask
185 | b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
186 |
187 | # prenorm
188 |
189 | normed_x = self.norm(x)
190 |
191 | # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
192 |
193 | if self.shift_tokens:
194 | x_shift, x_pass = normed_x.chunk(2, dim=-1)
195 | x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.0)
196 | normed_x = torch.cat((x_shift, x_pass), dim=-1)
197 |
198 | # initial projections
199 |
200 | v, gate = self.to_hidden(normed_x).chunk(2, dim=-1)
201 | qk = self.to_qk(normed_x)
202 |
203 | # offset and scale
204 |
205 | quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
206 |
207 | # mask out linear attention keys
208 |
209 | if exists(mask):
210 | mask = mask.bool()
211 | lin_k = lin_k.masked_fill(~mask[..., None], 0.0)
212 |
213 | # rotate queries and keys
214 |
215 | if exists(self.rotary_pos_emb):
216 | quad_q, lin_q, quad_k, lin_k = map(
217 | self.rotary_pos_emb.rotate_queries_or_keys,
218 | (quad_q, lin_q, quad_k, lin_k),
219 | )
220 |
221 | # padding for groups
222 |
223 | padding = padding_to_multiple_of(n, g)
224 |
225 | if padding > 0:
226 | quad_q, quad_k, lin_q, lin_k, v = map(
227 | lambda t: F.pad(t, (0, 0, 0, padding), value=0.0),
228 | (quad_q, quad_k, lin_q, lin_k, v),
229 | )
230 |
231 | mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
232 | mask = F.pad(mask, (0, padding), value=False)
233 |
234 | # group along sequence
235 |
236 | quad_q, quad_k, lin_q, lin_k, v = map(
237 | lambda t: rearrange(t, "b (g n) d -> b g n d", n=self.group_size),
238 | (quad_q, quad_k, lin_q, lin_k, v),
239 | )
240 |
241 | if exists(mask):
242 | mask = rearrange(mask, "b (g j) -> b g 1 j", j=g)
243 |
244 | # calculate quadratic attention output
245 |
246 | sim = einsum("... i d, ... j d -> ... i j", quad_q, quad_k) / g
247 |
248 | sim = sim + self.rel_pos_bias(sim)
249 |
250 | attn = F.relu(sim) ** 2
251 | attn = self.dropout(attn)
252 |
253 | if exists(mask):
254 | attn = attn.masked_fill(~mask, 0.0)
255 |
256 | if self.causal:
257 | causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1)
258 | attn = attn.masked_fill(causal_mask, 0.0)
259 |
260 | quad_out = einsum("... i j, ... j d -> ... i d", attn, v)
261 |
262 | # calculate linear attention output
263 |
264 | if self.causal:
265 | lin_kv = einsum("b g n d, b g n e -> b g d e", lin_k, v) / g
266 |
267 | # exclusive cumulative sum along group dimension
268 |
269 | lin_kv = lin_kv.cumsum(dim=1)
270 | lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.0)
271 |
272 | lin_out = einsum("b g d e, b g n d -> b g n e", lin_kv, lin_q)
273 | else:
274 | lin_kv = einsum("b g n d, b g n e -> b d e", lin_k, v) / n
275 | lin_out = einsum("b g n d, b d e -> b g n e", lin_q, lin_kv)
276 |
277 | # fold back groups into full sequence, and excise out padding
278 |
279 | quad_attn_out, lin_attn_out = map(
280 | lambda t: rearrange(t, "b g n d -> b (g n) d")[:, :n], (quad_out, lin_out)
281 | )
282 |
283 | # gate
284 |
285 | out = gate * (quad_attn_out + lin_attn_out)
286 |
287 | # projection out and residual
288 | out = self.to_out(out) + x
289 | return (out, attn) if output_attentions else (out,)
290 |
291 |
292 | ###################################################
293 |
294 |
295 | class FLASHConfig(PretrainedConfig):
296 | model_type = "flash"
297 |
298 | def __init__(
299 | self,
300 | vocab_size=12000,
301 | hidden_size=768,
302 | num_hidden_layers=12, # base
303 | max_position_embeddings=512,
304 | group_size=256,
305 | initializer_range=0.02,
306 | layer_norm_eps=1e-5,
307 | pad_token_id=0,
308 | expansion_factor=2,
309 | query_key_dim=128,
310 | norm_type="scalenorm",
311 | gradient_checkpointing=False,
312 | dropout=0.0,
313 | classifier_dropout=0.1,
314 | **kwargs
315 | ):
316 | super().__init__(pad_token_id=pad_token_id, **kwargs)
317 |
318 | self.vocab_size = vocab_size
319 | self.hidden_size = hidden_size
320 | self.num_hidden_layers = num_hidden_layers
321 | self.max_position_embeddings = max_position_embeddings
322 | self.group_size = group_size
323 | self.initializer_range = initializer_range
324 | self.layer_norm_eps = layer_norm_eps
325 | self.expansion_factor = expansion_factor
326 | self.query_key_dim = query_key_dim
327 | self.norm_type = norm_type
328 | self.dropout = dropout
329 | self.gradient_checkpointing = gradient_checkpointing
330 | self.classifier_dropout = classifier_dropout
331 |
332 |
333 | class FLASHPreTrainedModel(PreTrainedModel):
334 | config_class = FLASHConfig
335 | base_model_prefix = "flash"
336 |
337 | def _init_weights(self, module):
338 | """Initialize the weights"""
339 | if isinstance(module, nn.Linear):
340 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
341 | if module.bias is not None:
342 | module.bias.data.zero_()
343 | elif isinstance(module, nn.Embedding):
344 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
345 | if module.padding_idx is not None:
346 | module.weight.data[module.padding_idx].zero_()
347 | elif isinstance(module, nn.LayerNorm):
348 | module.bias.data.zero_()
349 | module.weight.data.fill_(1.0)
350 |
351 |
352 | class FLASHLMPredictionHead(nn.Module):
353 | def __init__(self, config):
354 | super().__init__()
355 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
356 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
357 |
358 | def forward(self, hidden_states):
359 | return self.decoder(self.LayerNorm(hidden_states))
360 |
361 |
362 | class FLASHOnlyMLMHead(nn.Module):
363 | def __init__(self, config):
364 | super().__init__()
365 | self.predictions = FLASHLMPredictionHead(config)
366 |
367 | def forward(self, sequence_output):
368 | prediction_scores = self.predictions(sequence_output)
369 | return prediction_scores
370 |
371 |
372 | class FLASHEmbeddings(nn.Module):
373 | """Construct the embeddings from word, position embeddings."""
374 |
375 | def __init__(self, config):
376 | super().__init__()
377 | self.config = config
378 | self.word_embeddings = nn.Embedding(
379 | config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
380 | )
381 | self.register_buffer(
382 | "position_ids",
383 | torch.arange(config.max_position_embeddings).expand((1, -1)),
384 | persistent=False,
385 | )
386 | self.scale = nn.Parameter(torch.ones(1))
387 | self.register_buffer(
388 | "scaledsin_embeddings", self.get_scaledsin(), persistent=False
389 | )
390 |
391 | def get_scaledsin(self):
392 | """Create sinusoidal position embedding with a scaling factor."""
393 | seqlen, hidden_size = (
394 | self.config.max_position_embeddings,
395 | self.config.hidden_size,
396 | )
397 | pos = torch.arange(seqlen, dtype=torch.float32)
398 | half_d = hidden_size // 2
399 |
400 | freq_seq = -torch.arange(half_d, dtype=torch.float32) / float(half_d)
401 | inv_freq = 10000 ** freq_seq
402 | sinusoid = torch.einsum("s,d->sd", pos, inv_freq)
403 | scaledsin = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1)
404 | return scaledsin
405 |
406 | def forward(self, input_ids=None, position_ids=None):
407 | input_shape = input_ids.shape
408 | seq_length = input_shape[1]
409 |
410 | if position_ids is None:
411 | position_ids = self.position_ids[:, :seq_length]
412 | inputs_embeds = self.word_embeddings(input_ids)
413 | position_embeddings = self.scaledsin_embeddings[position_ids] * self.scale
414 | embeddings = inputs_embeds + position_embeddings
415 | return embeddings
416 |
417 |
418 | class FLASHEncoder(nn.Module):
419 | def __init__(self, config):
420 | super().__init__()
421 | self.config = config
422 | if config.norm_type == "scalenorm":
423 | norm_klass = ScaleNorm
424 | elif config.norm_type == "layernorm":
425 | norm_klass = nn.LayerNorm
426 | rotary_pos_emb = RotaryEmbedding(dim=min(32, config.query_key_dim))
427 | self.layers = nn.ModuleList(
428 | [
429 | FLASH(
430 | dim=config.hidden_size,
431 | group_size=config.group_size,
432 | query_key_dim=config.query_key_dim,
433 | expansion_factor=config.expansion_factor,
434 | causal=False,
435 | dropout=config.dropout,
436 | rotary_pos_emb=rotary_pos_emb,
437 | norm_klass=norm_klass,
438 | shift_tokens=False,
439 | )
440 | for _ in range(config.num_hidden_layers)
441 | ]
442 | )
443 |
444 | def forward(
445 | self,
446 | hidden_states,
447 | attention_mask=None,
448 | output_attentions=False,
449 | output_hidden_states=False,
450 | return_dict=True,
451 | ):
452 | all_hidden_states = () if output_hidden_states else None
453 | all_self_attentions = () if output_attentions else None
454 |
455 | for i, layer_module in enumerate(self.layers):
456 | if output_hidden_states:
457 | all_hidden_states = all_hidden_states + (hidden_states,)
458 |
459 | if getattr(self.config, "gradient_checkpointing", False) and self.training:
460 |
461 | def create_custom_forward(module):
462 | def custom_forward(*inputs):
463 | return module(*inputs, output_attentions)
464 |
465 | return custom_forward
466 |
467 | layer_outputs = torch.utils.checkpoint.checkpoint(
468 | create_custom_forward(layer_module),
469 | hidden_states,
470 | attention_mask,
471 | output_attentions,
472 | )
473 | else:
474 | layer_outputs = layer_module(
475 | hidden_states, attention_mask, output_attentions
476 | )
477 |
478 | hidden_states = layer_outputs[0]
479 |
480 | if output_attentions:
481 | all_self_attentions = all_self_attentions + (layer_outputs[1],)
482 |
483 | if output_hidden_states:
484 | all_hidden_states = all_hidden_states + (hidden_states,)
485 |
486 | if not return_dict:
487 | return tuple(
488 | v
489 | for v in [
490 | hidden_states,
491 | all_hidden_states,
492 | all_self_attentions,
493 | ]
494 | if v is not None
495 | )
496 | return BaseModelOutput(
497 | last_hidden_state=hidden_states,
498 | hidden_states=all_hidden_states,
499 | attentions=all_self_attentions,
500 | )
501 |
502 |
503 | class FLASHModel(FLASHPreTrainedModel):
504 | def __init__(self, config):
505 | super().__init__(config)
506 | self.config = config
507 |
508 | self.embeddings = FLASHEmbeddings(config)
509 | self.encoder = FLASHEncoder(config)
510 |
511 | self.post_init()
512 |
513 | def get_input_embeddings(self):
514 | return self.embeddings.word_embeddings
515 |
516 | def set_input_embeddings(self, value):
517 | self.embeddings.word_embeddings = value
518 |
519 | def forward(
520 | self,
521 | input_ids=None,
522 | attention_mask=None,
523 | position_ids=None,
524 | output_attentions=None,
525 | output_hidden_states=None,
526 | return_dict=None,
527 | ):
528 | output_attentions = (
529 | output_attentions
530 | if output_attentions is not None
531 | else self.config.output_attentions
532 | )
533 | output_hidden_states = (
534 | output_hidden_states
535 | if output_hidden_states is not None
536 | else self.config.output_hidden_states
537 | )
538 | return_dict = (
539 | return_dict if return_dict is not None else self.config.use_return_dict
540 | )
541 |
542 | embedding_output = self.embeddings(
543 | input_ids=input_ids, position_ids=position_ids
544 | )
545 |
546 | encoder_outputs = self.encoder(
547 | embedding_output,
548 | attention_mask=attention_mask,
549 | output_attentions=output_attentions,
550 | output_hidden_states=output_hidden_states,
551 | return_dict=return_dict,
552 | )
553 | sequence_output = encoder_outputs[0]
554 |
555 | if not return_dict:
556 | return (sequence_output,) + encoder_outputs[1:]
557 |
558 | return BaseModelOutput(
559 | last_hidden_state=sequence_output,
560 | hidden_states=encoder_outputs.hidden_states,
561 | attentions=encoder_outputs.attentions,
562 | )
563 |
564 |
565 | class FLASHForMaskedLM(FLASHPreTrainedModel):
566 | def __init__(self, config):
567 | super().__init__(config)
568 |
569 | self.flash = FLASHModel(config)
570 | self.cls = FLASHOnlyMLMHead(config)
571 |
572 | self.post_init()
573 |
574 | def get_output_embeddings(self):
575 | return self.cls.predictions.decoder
576 |
577 | def set_output_embeddings(self, new_embeddings):
578 | self.cls.predictions.decoder = new_embeddings
579 |
580 | def forward(
581 | self,
582 | input_ids=None,
583 | attention_mask=None,
584 | position_ids=None,
585 | labels=None,
586 | output_attentions=None,
587 | output_hidden_states=None,
588 | return_dict=None,
589 | ):
590 |
591 | return_dict = (
592 | return_dict if return_dict is not None else self.config.use_return_dict
593 | )
594 |
595 | outputs = self.flash(
596 | input_ids,
597 | attention_mask=attention_mask,
598 | position_ids=position_ids,
599 | output_attentions=output_attentions,
600 | output_hidden_states=output_hidden_states,
601 | return_dict=return_dict,
602 | )
603 |
604 | sequence_output = outputs[0]
605 |
606 | prediction_scores = self.cls(sequence_output)
607 |
608 | masked_lm_loss = None
609 | if labels is not None:
610 | loss_fct = nn.CrossEntropyLoss()
611 | masked_lm_loss = loss_fct(
612 | prediction_scores.reshape(-1, self.config.vocab_size),
613 | labels.reshape(-1),
614 | )
615 |
616 | if not return_dict:
617 | output = (prediction_scores,) + outputs[1:]
618 | return (
619 | ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
620 | )
621 |
622 | return MaskedLMOutput(
623 | loss=masked_lm_loss,
624 | logits=prediction_scores,
625 | hidden_states=outputs.hidden_states,
626 | attentions=outputs.attentions,
627 | )
628 |
629 |
630 | class FLASHForSequenceClassification(FLASHPreTrainedModel):
631 | def __init__(self, config):
632 | super().__init__(config)
633 | self.num_labels = config.num_labels
634 | self.flash = FLASHModel(config)
635 | self.sequence_summary = SequenceSummary(config)
636 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
637 |
638 | # Initialize weights and apply final processing
639 | self.post_init()
640 |
641 | def forward(
642 | self,
643 | input_ids=None,
644 | attention_mask=None,
645 | position_ids=None,
646 | labels=None,
647 | output_attentions=None,
648 | output_hidden_states=None,
649 | return_dict=None,
650 | ):
651 |
652 | return_dict = (
653 | return_dict if return_dict is not None else self.config.use_return_dict
654 | )
655 |
656 | outputs = self.flash(
657 | input_ids,
658 | attention_mask=attention_mask,
659 | position_ids=position_ids,
660 | output_attentions=output_attentions,
661 | output_hidden_states=output_hidden_states,
662 | return_dict=return_dict,
663 | )
664 | pooled_output = self.sequence_summary(outputs[0])
665 | logits = self.classifier(pooled_output)
666 |
667 | loss = None
668 | if labels is not None:
669 | if self.config.problem_type is None:
670 | if self.num_labels == 1:
671 | self.config.problem_type = "regression"
672 | elif self.num_labels > 1 and (
673 | labels.dtype == torch.long or labels.dtype == torch.int
674 | ):
675 | self.config.problem_type = "single_label_classification"
676 | else:
677 | self.config.problem_type = "multi_label_classification"
678 |
679 | if self.config.problem_type == "regression":
680 | loss_fct = nn.MSELoss()
681 | if self.num_labels == 1:
682 | loss = loss_fct(logits.squeeze(), labels.squeeze())
683 | else:
684 | loss = loss_fct(logits, labels)
685 | elif self.config.problem_type == "single_label_classification":
686 | loss_fct = nn.CrossEntropyLoss()
687 | loss = loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1))
688 | elif self.config.problem_type == "multi_label_classification":
689 | loss_fct = nn.BCEWithLogitsLoss()
690 | loss = loss_fct(logits, labels)
691 | if not return_dict:
692 | output = (logits,) + outputs[1:]
693 | return ((loss,) + output) if loss is not None else output
694 |
695 | return SequenceClassifierOutput(
696 | loss=loss,
697 | logits=logits,
698 | hidden_states=outputs.hidden_states,
699 | attentions=outputs.attentions,
700 | )
701 |
702 |
703 | class FLASHForMultipleChoice(FLASHPreTrainedModel):
704 | def __init__(self, config):
705 | super().__init__(config)
706 |
707 | self.flash = FLASHModel(config)
708 | self.sequence_summary = SequenceSummary(config)
709 | self.classifier = nn.Linear(config.hidden_size, 1)
710 |
711 | # Initialize weights and apply final processing
712 | self.post_init()
713 |
714 | def forward(
715 | self,
716 | input_ids=None,
717 | attention_mask=None,
718 | position_ids=None,
719 | labels=None,
720 | output_attentions=None,
721 | output_hidden_states=None,
722 | return_dict=None,
723 | ):
724 |
725 | return_dict = (
726 | return_dict if return_dict is not None else self.config.use_return_dict
727 | )
728 | num_choices = input_ids.shape[1]
729 | input_ids = (
730 | input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
731 | )
732 | attention_mask = (
733 | attention_mask.reshape(-1, attention_mask.size(-1))
734 | if attention_mask is not None
735 | else None
736 | )
737 | position_ids = (
738 | position_ids.reshape(-1, position_ids.size(-1))
739 | if position_ids is not None
740 | else None
741 | )
742 |
743 | outputs = self.flash(
744 | input_ids,
745 | attention_mask=attention_mask,
746 | position_ids=position_ids,
747 | output_attentions=output_attentions,
748 | output_hidden_states=output_hidden_states,
749 | return_dict=return_dict,
750 | )
751 |
752 | pooled_output = self.sequence_summary(outputs[0])
753 | logits = self.classifier(pooled_output)
754 | reshaped_logits = logits.reshape(-1, num_choices)
755 |
756 | loss = None
757 | if labels is not None:
758 | loss_fct = nn.CrossEntropyLoss()
759 | loss = loss_fct(reshaped_logits, labels)
760 |
761 | if not return_dict:
762 | output = (reshaped_logits,) + outputs[1:]
763 | return ((loss,) + output) if loss is not None else output
764 |
765 | return MultipleChoiceModelOutput(
766 | loss=loss,
767 | logits=reshaped_logits,
768 | hidden_states=outputs.hidden_states,
769 | attentions=outputs.attentions,
770 | )
771 |
772 |
773 | class FLASHForTokenClassification(FLASHPreTrainedModel):
774 | def __init__(self, config):
775 | super().__init__(config)
776 | self.num_labels = config.num_labels
777 | self.flash = FLASHModel(config)
778 | classifier_dropout = (
779 | config.classifier_dropout
780 | if config.classifier_dropout is not None
781 | else config.dropout
782 | )
783 | self.dropout = nn.Dropout(classifier_dropout)
784 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
785 |
786 | # Initialize weights and apply final processing
787 | self.post_init()
788 |
789 | def forward(
790 | self,
791 | input_ids=None,
792 | attention_mask=None,
793 | position_ids=None,
794 | labels=None,
795 | output_attentions=None,
796 | output_hidden_states=None,
797 | return_dict=None,
798 | ):
799 | return_dict = (
800 | return_dict if return_dict is not None else self.config.use_return_dict
801 | )
802 |
803 | outputs = self.flash(
804 | input_ids,
805 | attention_mask=attention_mask,
806 | position_ids=position_ids,
807 | output_attentions=output_attentions,
808 | output_hidden_states=output_hidden_states,
809 | return_dict=return_dict,
810 | )
811 |
812 | sequence_output = outputs[0]
813 | sequence_output = self.dropout(sequence_output)
814 | logits = self.classifier(sequence_output)
815 |
816 | loss = None
817 | if labels is not None:
818 | loss_fct = nn.CrossEntropyLoss()
819 | loss = loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1))
820 |
821 | if not return_dict:
822 | output = (logits,) + outputs[1:]
823 | return ((loss,) + output) if loss is not None else output
824 |
825 | return TokenClassifierOutput(
826 | loss=loss,
827 | logits=logits,
828 | hidden_states=outputs.hidden_states,
829 | attentions=outputs.attentions,
830 | )
831 |
832 |
833 | class FLASHForQuestionAnswering(FLASHPreTrainedModel):
834 | def __init__(self, config):
835 | super().__init__(config)
836 | self.num_labels = config.num_labels
837 |
838 | self.flash = FLASHModel(config)
839 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
840 |
841 | # Initialize weights and apply final processing
842 | self.post_init()
843 |
844 | def forward(
845 | self,
846 | input_ids=None,
847 | attention_mask=None,
848 | position_ids=None,
849 | start_positions=None,
850 | end_positions=None,
851 | output_attentions=None,
852 | output_hidden_states=None,
853 | return_dict=None,
854 | ):
855 |
856 | return_dict = (
857 | return_dict if return_dict is not None else self.config.use_return_dict
858 | )
859 |
860 | outputs = self.flash(
861 | input_ids,
862 | attention_mask=attention_mask,
863 | position_ids=position_ids,
864 | output_attentions=output_attentions,
865 | output_hidden_states=output_hidden_states,
866 | return_dict=return_dict,
867 | )
868 |
869 | sequence_output = outputs[0]
870 |
871 | logits = self.qa_outputs(sequence_output)
872 | start_logits, end_logits = logits.split(1, dim=-1)
873 | start_logits = start_logits.squeeze(-1).contiguous()
874 | end_logits = end_logits.squeeze(-1).contiguous()
875 |
876 | total_loss = None
877 | if start_positions is not None and end_positions is not None:
878 | # If we are on multi-GPU, split add a dimension
879 | if start_positions.ndim > 1:
880 | start_positions = start_positions.squeeze(-1)
881 | if start_positions.ndim > 1:
882 | end_positions = end_positions.squeeze(-1)
883 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
884 | ignored_index = start_logits.size(1)
885 | start_positions = start_positions.clamp(0, ignored_index)
886 | end_positions = end_positions.clamp(0, ignored_index)
887 |
888 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
889 | start_loss = loss_fct(start_logits, start_positions)
890 | end_loss = loss_fct(end_logits, end_positions)
891 | total_loss = (start_loss + end_loss) / 2
892 |
893 | if not return_dict:
894 | output = (start_logits, end_logits) + outputs[1:]
895 | return ((total_loss,) + output) if total_loss is not None else output
896 |
897 | return QuestionAnsweringModelOutput(
898 | loss=total_loss,
899 | start_logits=start_logits,
900 | end_logits=end_logits,
901 | hidden_states=outputs.hidden_states,
902 | attentions=outputs.attentions,
903 | )
904 |
--------------------------------------------------------------------------------
/flash/gau.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from transformers.activations import ACT2FN
5 |
6 |
7 | def rope(x, dim):
8 | """RoPE position embedding."""
9 | shape = x.shape
10 | if isinstance(dim, int):
11 | dim = [dim]
12 | spatial_shape = [shape[i] for i in dim]
13 | total_len = 1
14 | for i in spatial_shape:
15 | total_len *= i
16 | position = torch.reshape(
17 | torch.arange(total_len, dtype=x.dtype,
18 | device=x.device), spatial_shape
19 | )
20 | for i in range(dim[-1] + 1, len(shape) - 1, 1):
21 | position = position.unsqueeze(-1)
22 | half_size = shape[-1] // 2
23 | freq_seq = -torch.arange(half_size, dtype=x.dtype, device=x.device) / float(
24 | half_size
25 | )
26 | inv_freq = 10000 ** freq_seq
27 | sinusoid = torch.einsum("...,d->...d", position, inv_freq)
28 | sin = sinusoid.sin()
29 | cos = sinusoid.cos()
30 | x1, x2 = torch.chunk(x, 2, dim=-1)
31 |
32 | return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
33 |
34 |
35 | class ScaleNorm(nn.Module):
36 | def __init__(self, eps=1e-5):
37 | super().__init__()
38 | self.eps = eps
39 | self.scala = nn.Parameter(torch.ones(1))
40 |
41 | def forward(self, x):
42 | mean_square = (x ** 2).mean(dim=-1, keepdim=True)
43 | x = x * torch.rsqrt(mean_square + self.eps) * self.scala
44 | return x
45 |
46 |
47 | class GAU(nn.Module):
48 | """GAU block.
49 | Input shape: batch size x sequence length x model size
50 | """
51 |
52 | def __init__(
53 | self,
54 | hidden_size=768,
55 | expansion_factor=2,
56 | s=128,
57 | norm_type="layer_norm",
58 | eps=1e-5,
59 | hidden_act="silu",
60 | max_position_embeddings=512,
61 | ):
62 | super().__init__()
63 | self.s = s
64 | self.e = int(hidden_size * expansion_factor)
65 | self.uv = nn.Linear(hidden_size, 2 * self.e + self.s)
66 | self.weight = nn.Parameter(torch.randn(2, self.s))
67 | self.bias = nn.Parameter(torch.zeros(2, self.s))
68 | self.o = nn.Linear(self.e, hidden_size)
69 | self.LayerNorm = (
70 | nn.LayerNorm(hidden_size, eps=eps)
71 | if norm_type == "layer_norm"
72 | else ScaleNorm(eps=eps)
73 | )
74 | self.w = nn.Parameter(torch.randn(2 * max_position_embeddings - 1))
75 | self.a = nn.Parameter(torch.randn(1, self.s))
76 | self.b = nn.Parameter(torch.randn(1, self.s))
77 | self.act_fn = ACT2FN[hidden_act]
78 | self.max_position_embeddings = max_position_embeddings
79 |
80 | nn.init.normal_(self.weight, std=0.02)
81 | nn.init.normal_(self.w, std=0.02)
82 | nn.init.normal_(self.a, std=0.02)
83 | nn.init.normal_(self.b, std=0.02)
84 |
85 | def rel_pos_bias(self, seq_len):
86 | """Relative position bias."""
87 | if seq_len <= 512:
88 | # Construct Toeplitz matrix directly when the sequence length is less than 512
89 | t = F.pad(self.w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
90 | t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
91 | r = (2 * seq_len - 1) // 2
92 | t = t[..., r:-r]
93 | else:
94 | # Construct Toeplitz matrix using RoPE when the sequence length is over 512.
95 | a = rope(self.a.repeat(seq_len, 1), dim=0)
96 | b = rope(self.b.repeat(seq_len, 1), dim=0)
97 | t = torch.einsum("mk,nk ->mn", a, b)
98 |
99 | return t
100 |
101 | def forward(self, x, attention_mask=None, output_attentions=False, causal=False):
102 | seq_len = x.shape[1]
103 | shortcut, x = x, self.LayerNorm(x)
104 | uv = self.uv(x)
105 | u, v, base = torch.split(self.act_fn(
106 | uv), [self.e, self.e, self.s], dim=-1)
107 | # Generate Query (q) and Key (k) from base.
108 | base = torch.einsum("...r,hr->...hr", base, self.weight) + self.bias
109 | base = rope(base, dim=1)
110 | q, k = torch.unbind(base, dim=-2)
111 | # Calculate the quadratic attention.
112 | qk = torch.einsum("bnd,bmd->bnm", q, k)
113 |
114 | bias = self.rel_pos_bias(self.max_position_embeddings)[
115 | :, :seq_len, :seq_len]
116 | kernel = torch.square(torch.relu(
117 | qk / self.max_position_embeddings + bias))
118 | # attention_mask
119 | if attention_mask is not None:
120 | assert attention_mask.ndim == 2
121 | attn_mask = (
122 | attention_mask[:, None, :] * attention_mask[:, :, None]
123 | ).type_as(x)
124 | kernel *= attn_mask
125 |
126 | if causal:
127 | causal_mask = torch.tril(torch.ones(seq_len, seq_len), diagonal=0)
128 | kernel *= causal_mask
129 |
130 | x = u * torch.einsum("bnm,bme->bne", kernel, v)
131 | x = self.o(x)
132 | if output_attentions:
133 | return x + shortcut, kernel
134 | return (x + shortcut,)
135 |
--------------------------------------------------------------------------------
/pretrain.sh:
--------------------------------------------------------------------------------
1 | TRAIN_DIR=./clue_small_wwm_data
2 | OUTPUT_DIR=./wwm_flash_small/
3 | BATCH_SIZE=32
4 | ACCUMULATION=4
5 | LR=1e-4
6 | python run_mlm_wwm.py \
7 | --do_train \
8 | --tokenizer_name junnyu/roformer_chinese_char_base \
9 | --train_dir $TRAIN_DIR \
10 | --output_dir $OUTPUT_DIR \
11 | --logging_dir $OUTPUT_DIR/logs \
12 | --per_device_train_batch_size $BATCH_SIZE \
13 | --gradient_accumulation_steps $ACCUMULATION \
14 | --learning_rate $LR \
15 | --weight_decay 0.01 \
16 | --adam_epsilon 1e-6 \
17 | --max_steps 250000 \
18 | --warmup_steps 5000 \
19 | --logging_steps 100 \
20 | --save_steps 5000 \
21 | --seed 2022 \
22 | --max_grad_norm 3.0 \
23 | --dataloader_num_workers 6 \
24 | --fp16
25 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers>=4.16.2
2 | rjieba
3 | datasets
4 | rotary_embedding_torch
--------------------------------------------------------------------------------
/run_chinese_ref.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import multiprocessing
4 | import os
5 | import sys
6 | import time
7 | from typing import List
8 |
9 | import rjieba
10 | from datasets import Dataset, concatenate_datasets, load_dataset
11 | from tqdm import tqdm
12 | from transformers import BertTokenizerFast
13 |
14 |
15 | def _is_chinese_char(cp):
16 | """Checks whether CP is the codepoint of a CJK character."""
17 | # This defines a "chinese character" as anything in the CJK Unicode block:
18 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
19 | #
20 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
21 | # despite its name. The modern Korean Hangul alphabet is a different block,
22 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
23 | # space-separated words, so they are not treated specially and handled
24 | # like the all of the other languages.
25 | if (
26 | (cp >= 0x4E00 and cp <= 0x9FFF)
27 | or (cp >= 0x3400 and cp <= 0x4DBF) #
28 | or (cp >= 0x20000 and cp <= 0x2A6DF) #
29 | or (cp >= 0x2A700 and cp <= 0x2B73F) #
30 | or (cp >= 0x2B740 and cp <= 0x2B81F) #
31 | or (cp >= 0x2B820 and cp <= 0x2CEAF) #
32 | or (cp >= 0xF900 and cp <= 0xFAFF)
33 | or (cp >= 0x2F800 and cp <= 0x2FA1F) #
34 | ): #
35 | return True
36 |
37 | return False
38 |
39 |
40 | def is_chinese(word: str):
41 | # word like '180' or '身高' or '神'
42 | for char in word:
43 | char = ord(char)
44 | if not _is_chinese_char(char):
45 | return False
46 | return True
47 |
48 |
49 | def get_chinese_word(tokens: List[str]):
50 | word_set = set()
51 |
52 | for token in tokens:
53 | chinese_word = len(token) > 1 and is_chinese(token)
54 | if chinese_word:
55 | word_set.add(token)
56 | word_list = list(word_set)
57 | return word_list
58 |
59 |
60 | def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set()):
61 | if not chinese_word_set:
62 | return bert_tokens
63 | max_word_len = max([len(w) for w in chinese_word_set])
64 |
65 | bert_word = bert_tokens
66 | start, end = 0, len(bert_word)
67 | while start < end:
68 | single_word = True
69 | if is_chinese(bert_word[start]):
70 | l = min(end - start, max_word_len)
71 | for i in range(l, 1, -1):
72 | whole_word = "".join(bert_word[start : start + i])
73 | if whole_word in chinese_word_set:
74 | for j in range(start + 1, start + i):
75 | bert_word[j] = "##" + bert_word[j]
76 | start = start + i
77 | single_word = False
78 | break
79 | if single_word:
80 | start += 1
81 | return bert_word
82 |
83 |
84 | block_size = 512
85 |
86 |
87 | class BlockSizeSplitter:
88 | def tokenize(self, text):
89 | tstr = ""
90 | all_ts = []
91 | for txt in text.split("\n"):
92 | if len(tstr) > block_size:
93 | all_ts.append(tstr)
94 | tstr = ""
95 | tstr += txt
96 | if len(tstr) > 0:
97 | all_ts.append(tstr)
98 | return all_ts
99 |
100 |
101 | def jieba_segmentation_fn():
102 | def process(line):
103 | words = rjieba.cut(line)
104 | return words
105 |
106 | return process
107 |
108 |
109 | class Converter:
110 | def __init__(self, args):
111 | self.args = args
112 |
113 | def initializer(self):
114 | Converter.tokenizer = BertTokenizerFast.from_pretrained(self.args.model_name)
115 |
116 | # Split document to sentence.
117 | Converter.splitter = BlockSizeSplitter()
118 | Converter.segment_func = jieba_segmentation_fn()
119 |
120 | def process(text):
121 | words = Converter.segment_func(text)
122 | new_text = "".join(words).replace("\n", "")
123 | chinese_word = get_chinese_word(words)
124 | input_tokens = (
125 | [Converter.tokenizer.cls_token]
126 | + Converter.tokenizer.tokenize(new_text)
127 | + [Converter.tokenizer.sep_token]
128 | )
129 |
130 | input_tokens = add_sub_symbol(input_tokens, chinese_word)
131 | ref_id = []
132 | for i, token in enumerate(input_tokens):
133 | if token[:2] == "##":
134 | clean_token = token[2:]
135 | # save chinese tokens' pos
136 | if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
137 | ref_id.append(i)
138 |
139 | return ref_id, new_text
140 |
141 | Converter.process = process
142 |
143 | def encode(self, json_line):
144 | text = json.loads(json_line)[self.args.json_key]
145 | ref_ids = []
146 | all_texts = []
147 | for sentence in Converter.splitter.tokenize(text):
148 | ref_id, new_text = Converter.process(sentence.strip())
149 | if len(new_text) < 20:
150 | continue
151 | if len(ref_id) > 0 and len(new_text) > 0:
152 | ref_ids.append(ref_id)
153 | all_texts.append(new_text)
154 |
155 | return ref_ids, all_texts, len(text.encode("utf-8"))
156 |
157 |
158 | def main(args):
159 |
160 | file_paths = []
161 | if os.path.isfile(args.input_path):
162 | file_paths.append(args.input_path)
163 | else:
164 | for root, _, fs in os.walk(args.input_path):
165 | for f in fs:
166 | file_paths.append(os.path.join(root, f))
167 | convert = Converter(args)
168 | pool = multiprocessing.Pool(args.workers, initializer=convert.initializer)
169 | step = 0
170 | total_bytes_processed = 0
171 | startup_start = time.time()
172 | with open("data/refids.txt", "w+", encoding="utf8") as w1:
173 | with open("data/reftext.txt", "w+", encoding="utf8") as w2:
174 | for file_path in tqdm(file_paths):
175 | if file_path.endswith(".jsonl"):
176 | text = open(file_path, "r", encoding="utf-8")
177 | else:
178 | print("Unexpected data format, skiped %s" % file_path)
179 | continue
180 |
181 | encoded_docs = pool.imap(convert.encode, text, 256)
182 | print("Processing %s" % file_path)
183 | for rid, alltxt, bytes_processed in encoded_docs:
184 | step += 1
185 | total_bytes_processed += bytes_processed
186 | if len(rid) == 0:
187 | continue
188 |
189 | for sentence in rid:
190 | sentence_len = len(sentence)
191 | if sentence_len == 0:
192 | continue
193 | w1.write(str(sentence) + "\n")
194 | for txt in alltxt:
195 | txt_len = len(txt)
196 | if txt_len == 0:
197 | continue
198 | w2.write(txt + "\n")
199 |
200 | if step % args.log_interval == 0:
201 | current = time.time()
202 | elapsed = current - startup_start
203 | mbs = total_bytes_processed / elapsed / 1024 / 1024
204 | print(
205 | f"Processed {step} documents",
206 | f"({step/elapsed:.2f} docs/s, {mbs:.4f} MB/s).",
207 | file=sys.stderr,
208 | )
209 | pool.close()
210 | print("Saving tokens to files...")
211 |
212 | # concatenate_datasets
213 | print("concatenate_datasets...")
214 | reftext = load_dataset("text", data_files="data/reftext.txt")["train"]
215 | refids = load_dataset("text", data_files="data/refids.txt")["train"]
216 | refids = refids.rename_column("text", "chinese_ref")
217 | refids = refids.map(lambda example: {"chinese_ref": eval(example["chinese_ref"])})
218 | concat_ds = concatenate_datasets([reftext, refids], axis=1)
219 | concat_ds.save_to_disk("./clue_small_wwm_data")
220 |
221 |
222 | if __name__ == "__main__":
223 | parser = argparse.ArgumentParser()
224 | parser.add_argument(
225 | "--model_name",
226 | type=str,
227 | default="junnyu/roformer_chinese_char_base",
228 | help="What model to use.",
229 | )
230 |
231 | group = parser.add_argument_group(title="data input/output")
232 | group.add_argument(
233 | "--input_path",
234 | type=str,
235 | default="data/clue_corpus_small_14g.jsonl",
236 | help="Path to input JSON files.",
237 | )
238 |
239 | group.add_argument(
240 | "--json_key",
241 | type=str,
242 | default="text",
243 | help="For JSON format. Space separate listed of keys to extract from json",
244 | )
245 |
246 | group = parser.add_argument_group(title="common config")
247 |
248 | group.add_argument(
249 | "--log_interval",
250 | type=int,
251 | default=100,
252 | help="Interval between progress updates",
253 | )
254 | group.add_argument(
255 | "--workers", type=int, default=12, help="Number of worker processes to launch"
256 | )
257 |
258 | args = parser.parse_args()
259 | main(args)
260 |
--------------------------------------------------------------------------------
/run_mlm_wwm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The HuggingFace Team All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
16 |
17 | import logging
18 | import os
19 | import sys
20 | import time
21 | from dataclasses import dataclass, field
22 | from typing import Optional
23 |
24 | import transformers
25 | from datasets import Dataset
26 | from transformers import (
27 | BertTokenizerFast,
28 | DataCollatorForWholeWordMask,
29 | HfArgumentParser,
30 | TrainingArguments,
31 | set_seed,
32 | )
33 | from transformers.trainer_utils import get_last_checkpoint, is_main_process
34 |
35 | from flash import FLASHQuadConfig, FLASHQuadForMaskedLM, FLASHConfig, FLASHForMaskedLM
36 | from mlm_trainer import Trainer
37 |
38 | logger = logging.getLogger(__name__)
39 |
40 |
41 | name2cls = {
42 | "flash": (FLASHConfig, FLASHForMaskedLM ),
43 | "flashquad" : (FLASHQuadConfig, FLASHQuadForMaskedLM ),
44 | }
45 |
46 | @dataclass
47 | class ModelArguments:
48 | """
49 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
50 | """
51 |
52 | tokenizer_name: Optional[str] = field(
53 | default="junnyu/roformer_chinese_char_base",
54 | metadata={
55 | "help": "Pretrained tokenizer name or path if not the same as model_name"
56 | },
57 | )
58 | model_name: Optional[str] = field(
59 | default="flash",
60 | metadata={
61 | "help": "model_name"
62 | },
63 | )
64 |
65 | @dataclass
66 | class DataTrainingArguments:
67 | """
68 | Arguments pertaining to what data we are going to input our model for training and eval.
69 | """
70 |
71 | train_dir: Optional[str] = field(
72 | default="./clue_small_wwm_data",
73 | metadata={"help": "The input training data file."},
74 | )
75 | overwrite_cache: bool = field(
76 | default=False,
77 | metadata={"help": "Overwrite the cached training and evaluation sets"},
78 | )
79 | max_seq_length: Optional[int] = field(
80 | default=512,
81 | metadata={
82 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
83 | "than this will be truncated. Default to the max input length of the model."
84 | },
85 | )
86 | preprocessing_num_workers: Optional[int] = field(
87 | default=12,
88 | metadata={"help": "The number of processes to use for the preprocessing."},
89 | )
90 | mlm_probability: float = field(
91 | default=0.15,
92 | metadata={
93 | "help": "Ratio of tokens to mask for masked language modeling loss"},
94 | )
95 | pad_to_max_length: bool = field(
96 | default=False,
97 | metadata={
98 | "help": "Whether to pad all samples to `max_seq_length`. "
99 | "If False, will pad the samples dynamically when batching to the maximum length in the batch."
100 | },
101 | )
102 |
103 |
104 | def main():
105 | # See all possible arguments in src/transformers/training_args.py
106 | # or by passing the --help flag to this script.
107 | # We now keep distinct sets of args, for a cleaner separation of concerns.
108 |
109 | parser = HfArgumentParser(
110 | (ModelArguments, DataTrainingArguments, TrainingArguments)
111 | )
112 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
113 | # If we pass only one argument to the script and it's the path to a json file,
114 | # let's parse it to get our arguments.
115 | model_args, data_args, training_args = parser.parse_json_file(
116 | json_file=os.path.abspath(sys.argv[1])
117 | )
118 | else:
119 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
120 |
121 | # Detecting last checkpoint.
122 | last_checkpoint = None
123 | if (
124 | os.path.isdir(training_args.output_dir)
125 | and not training_args.overwrite_output_dir
126 | ):
127 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
128 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
129 | raise ValueError(
130 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
131 | "Use --overwrite_output_dir to overcome."
132 | )
133 | elif last_checkpoint is not None:
134 | logger.info(
135 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
136 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
137 | )
138 |
139 | # Setup logging
140 | logging.basicConfig(
141 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
142 | datefmt="%m/%d/%Y %H:%M:%S",
143 | handlers=[logging.StreamHandler(sys.stdout)],
144 | )
145 | logger.setLevel(
146 | logging.INFO if is_main_process(
147 | training_args.local_rank) else logging.WARN
148 | )
149 |
150 | # Log on each process the small summary:
151 | logger.warning(
152 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
153 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
154 | )
155 | # Set the verbosity to info of the Transformers logger (on main process only):
156 | if is_main_process(training_args.local_rank):
157 | transformers.utils.logging.set_verbosity_info()
158 | transformers.utils.logging.enable_default_handler()
159 | transformers.utils.logging.enable_explicit_format()
160 | logger.info("Training parameters %s", training_args)
161 |
162 | # Set seed before initializing model.
163 | set_seed(training_args.seed)
164 |
165 | # download the dataset.
166 | # 加载clue_wwm_13g数据集
167 | datasets = Dataset.load_from_disk(data_args.train_dir)
168 |
169 | config_cls, model_cls = name2cls[model_args.model_name]
170 | config = config_cls(num_hidden_layers=12) # small
171 | # tokenizer使用了roformer_chinese_char_base
172 | tokenizer = BertTokenizerFast.from_pretrained(model_args.tokenizer_name)
173 | model = model_cls(config)
174 | model.resize_token_embeddings(len(tokenizer))
175 |
176 | # Preprocessing the datasets.
177 | # First we tokenize all the texts.
178 | column_names = datasets.column_names
179 | text_column_name = "text" if "text" in column_names else column_names[0]
180 |
181 | padding = "max_length" if data_args.pad_to_max_length else False
182 |
183 | def tokenize_function(examples):
184 | # Remove empty lines
185 | texts = []
186 | chinese_ref = []
187 | for text, ref in zip(examples["text"], examples["chinese_ref"]):
188 | if len(text) > 0 and not text.isspace():
189 | texts.append(text.strip())
190 | chinese_ref.append(ref)
191 | examples["text"] = texts
192 | examples["chinese_ref"] = chinese_ref
193 | data = tokenizer(
194 | examples["text"],
195 | padding=padding,
196 | truncation=True,
197 | max_length=data_args.max_seq_length,
198 | return_token_type_ids=False,
199 | return_attention_mask=False,
200 | )
201 | data["text"] = texts
202 | data["chinese_ref"] = chinese_ref
203 | return data
204 |
205 | tokenized_datasets = datasets.map(
206 | tokenize_function,
207 | batched=True,
208 | num_proc=data_args.preprocessing_num_workers,
209 | remove_columns=[text_column_name],
210 | load_from_cache_file=not data_args.overwrite_cache,
211 | new_fingerprint="clue_13g_small_roformer_wwm",
212 | )
213 |
214 | training_args.remove_unused_columns = False
215 |
216 | # Data collator
217 | # This one will take care of randomly masking the tokens.
218 | data_collator = DataCollatorForWholeWordMask(
219 | tokenizer=tokenizer,
220 | mlm_probability=data_args.mlm_probability,
221 | pad_to_multiple_of=8 if training_args.fp16 else None,
222 | )
223 |
224 | # Initialize our Trainer
225 | trainer = Trainer(
226 | model=model,
227 | args=training_args,
228 | train_dataset=tokenized_datasets,
229 | eval_dataset=None,
230 | tokenizer=tokenizer,
231 | data_collator=data_collator,
232 | )
233 | # trainer.add_callback(LoggingCallback(save_interval=training_args.save_interval))
234 | # Training
235 | if last_checkpoint is not None:
236 | checkpoint = last_checkpoint
237 | else:
238 | checkpoint = None
239 |
240 | logger.info("Training a model...")
241 | start_time = time.time()
242 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
243 | train_time = time.time() - start_time
244 | logger.info(f"Training time: {train_time}")
245 | trainer.save_model() # Saves the tokenizer too for easy upload
246 |
247 | output_train_file = os.path.join(
248 | training_args.output_dir, "train_results.txt")
249 | if trainer.is_world_process_zero():
250 | with open(output_train_file, "w") as writer:
251 | logger.info("***** Train results *****")
252 | for key, value in sorted(train_result.metrics.items()):
253 | logger.info(f" {key} = {value}")
254 | writer.write(f"{key} = {value}\n")
255 |
256 | # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
257 | trainer.state.save_to_json(
258 | os.path.join(training_args.output_dir, "trainer_state.json")
259 | )
260 |
261 |
262 | if __name__ == "__main__":
263 | main()
264 |
--------------------------------------------------------------------------------
/trans_to_json.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import json
17 | import multiprocessing
18 | import os
19 | import re
20 | import shutil
21 | import sys
22 | import time
23 | from functools import partial
24 |
25 |
26 | def get_args():
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument(
29 | "--input_path",
30 | type=str,
31 | required=True,
32 | help="Path to you raw files. Folder or file path.",
33 | )
34 | parser.add_argument(
35 | "--output_path",
36 | type=str,
37 | required=True,
38 | help="Path to save the output json files.",
39 | )
40 | parser.add_argument(
41 | "--json_key", type=str, default="text", help="The content key of json file."
42 | )
43 | parser.add_argument(
44 | "--doc_spliter",
45 | type=str,
46 | default="",
47 | help="Spliter between documents. We will strip the line, if you use blank line to split doc, leave it blank.",
48 | )
49 | parser.add_argument(
50 | "--min_doc_length", type=int, default=10, help="Minimal char of a documment."
51 | )
52 | parser.add_argument(
53 | "--workers", type=int, default=1, help="Number of worker processes to launch"
54 | )
55 | parser.add_argument(
56 | "--log_interval", type=int, default=1, help="Interval between progress updates."
57 | )
58 | parser.add_argument("--no-merge", action="store_true", help="Don't merge the file.")
59 | parser.add_argument(
60 | "--no-shuffle", action="store_true", help="Don't shuffle the file."
61 | )
62 | args = parser.parse_args()
63 | return args
64 |
65 |
66 | def raw_text_to_json(path, doc_spliter="", json_key="text", min_doc_length=10):
67 | path = os.path.abspath(path)
68 | if not os.path.exists(path):
69 | print("No found file %s" % path)
70 | return 0, None
71 |
72 | out_filepath = path + ".jsonl"
73 | fout = open(out_filepath, "w", encoding="utf-8")
74 | len_files = 0
75 | with open(path, "r") as f:
76 | doc = ""
77 | line = f.readline()
78 | while line:
79 | len_files += len(line)
80 | if line.strip() == doc_spliter:
81 | if len(doc) > min_doc_length:
82 | fout.write(json.dumps({json_key: doc}, ensure_ascii=False) + "\n")
83 | doc = ""
84 | else:
85 | doc += line
86 | line = f.readline()
87 |
88 | if len(doc) > min_doc_length:
89 | fout.write(json.dumps({json_key: doc}, ensure_ascii=False) + "\n")
90 | doc = ""
91 |
92 | return len_files, out_filepath
93 |
94 |
95 | def merge_file(file_paths, output_path):
96 | if not output_path.endswith(".jsonl"):
97 | output_path = output_path + ".jsonl"
98 | print("Merging files into %s" % output_path)
99 | with open(output_path, "wb") as wfd:
100 | for f in file_paths:
101 | if f is not None and os.path.exists(f):
102 | with open(f, "rb") as fd:
103 | shutil.copyfileobj(fd, wfd)
104 | os.remove(f)
105 | print("File save in %s" % output_path)
106 | return output_path
107 |
108 |
109 | def shuffle_file(output_path):
110 | print("Shuffling the jsonl file...")
111 | if os.path.exists(output_path):
112 | os.system("shuf %s -o %s" % (output_path, output_path))
113 | print("File shuffled!!!")
114 | else:
115 | raise ValueError("File not found: %s" % output_path)
116 |
117 |
118 | def main():
119 | args = get_args()
120 | startup_start = time.time()
121 |
122 | file_paths = []
123 | if os.path.isfile(args.input_path):
124 | file_paths.append(args.input_path)
125 | else:
126 | for root, _, fs in os.walk(args.input_path):
127 | for f in fs:
128 | file_paths.append(os.path.join(root, f))
129 |
130 | pool = multiprocessing.Pool(args.workers)
131 |
132 | startup_end = time.time()
133 | proc_start = time.time()
134 | total_bytes_processed = 0
135 | print("Time to startup:", startup_end - startup_start)
136 |
137 | trans_json = partial(
138 | raw_text_to_json,
139 | doc_spliter=args.doc_spliter,
140 | json_key=args.json_key,
141 | min_doc_length=args.min_doc_length,
142 | )
143 | encoded_files = pool.imap(trans_json, file_paths, 1)
144 |
145 | out_paths = []
146 | for i, (bytes_processed, out_path) in enumerate(encoded_files, start=1):
147 | total_bytes_processed += bytes_processed
148 | out_paths.append(out_path)
149 | if i % args.log_interval == 0:
150 | current = time.time()
151 | elapsed = current - proc_start
152 | mbs = total_bytes_processed / elapsed / 1024 / 1024
153 | print(
154 | f"Processed {i} files",
155 | f"({i/elapsed} files/s, {mbs} MB/s).",
156 | file=sys.stderr,
157 | )
158 |
159 | if not args.no_merge:
160 | output_path = merge_file(out_paths, args.output_path)
161 | if not args.no_shuffle:
162 | shuffle_file(output_path)
163 |
164 |
165 | if __name__ == "__main__":
166 | main()
167 |
--------------------------------------------------------------------------------