├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── clue_small_wwm_data ├── dataset.arrow ├── dataset_info.json ├── place_holder.txt └── state.json ├── data ├── clue_corpus_small_14g.jsonl ├── refids.txt └── reftext.txt ├── demo.ipynb ├── figure └── tnews.jpg ├── flash ├── __init__.py ├── flash.py ├── flash_lucidrains.py └── gau.py ├── mlm_trainer.py ├── pretrain.sh ├── requirements.txt ├── run_chinese_ref.py ├── run_mlm_wwm.py └── trans_to_json.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=python 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FLASHQuad_pytorch & FLASH_pytorch 2 | PyTorch implementation of FLASHQuad and FLASH 3 | 4 | # Description 5 | A personal `pytorch` implementation of [Transformer Quality in Linear Time](https://arxiv.org/abs/2202.10447). 6 | 7 | # Known issues 8 | - `A = square(relu(qk / seq_len + bias))` looks suspicious: if training always runs at seq_len=512, then at inference time with seq_len=16 the values of A change drastically. 9 | - For the `embedding` part and the `MLM head`, it is unclear whether `ScaleNorm` should be used and whether `dropout` should be applied. 10 | - Models trained with the current code give unsatisfactory results: the outputs of layer `n-1` and layer `n` differ very little. 11 | 12 | 13 | # Updates 14 | - 2022/04/01 Added [lucidrains' FLASH-pytorch implementation](https://github.com/lucidrains/FLASH-pytorch), together with the trained [flash_small_wwm_cluecorpussmall weights](https://huggingface.co/junnyu/flash_small_wwm_cluecorpussmall) and the [training log](https://wandb.ai/junyu/huggingface/runs/1jg2jlgt). 15 | - 2022/03/01 Switched to a `Trainer` with `mlm_acc`, so the MLM accuracy on the training set can be monitored every `logging_steps` during training. 16 | - 2022/02/28 Added the [small weights at checkpoint-170000](https://huggingface.co/junnyu/flashquad_small_wwm_cluecorpussmall), [training log 1](https://wandb.ai/junyu/huggingface/runs/ofdc74wr) and [training log 2](https://wandb.ai/junyu/huggingface/runs/2ep6cl14); the results feel unsatisfactory. 17 | - 2022/02/26 Rewrote the `rel_pos_bias` code: the previous version produced abnormal outputs (training was done at length 512, and testing at other lengths broke the model's outputs). 18 | ```python 19 | # Previous code. 20 | bias = self.rel_pos_bias(seq_len) 21 | kernel = torch.square(torch.relu(qk / seq_len + bias)) 22 | # Updated code.
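# Rationale: training always runs at length 512, so both the qk scaling factor and the
# relative-position bias are computed against max_position_embeddings and then sliced to
# the current seq_len, which keeps outputs consistent across different inference lengths.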
23 | self.max_position_embeddings = 512 24 | bias = self.rel_pos_bias(self.max_position_embeddings)[:, :seq_len, :seq_len] 25 | kernel = torch.square(torch.relu(qk / self.max_position_embeddings + bias)) 26 | ``` 27 | 28 | # Usage 29 | ```python 30 | # flashquad 31 | from flash import FLASHQuadConfig, FLASHQuadModel 32 | import torch 33 | config = FLASHQuadConfig() 34 | model = FLASHQuadModel(config) 35 | model.eval() 36 | input_ids = torch.randint(0,12000,(4,128)) 37 | with torch.no_grad(): 38 | outputs = model(input_ids=input_ids, output_attentions=True, output_hidden_states=True) 39 | print(outputs) 40 | 41 | # flash 42 | from flash import FLASHConfig, FLASHModel 43 | import torch 44 | config = FLASHConfig() 45 | model = FLASHModel(config) 46 | model.eval() 47 | input_ids = torch.randint(0, 12000, (4, 128)) 48 | with torch.no_grad(): 49 | outputs = model( 50 | input_ids=input_ids, output_attentions=True, output_hidden_states=True 51 | ) 52 | print(outputs) 53 | ``` 54 | 55 | # Pretrain 56 | ## 准备数据 57 | ## CLUECorpusSmall 数据集处理教程(摘抄自[paddlenlp](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/examples/language_model/data_tools/README.md)) 58 | **数据集简介**:可用于语言建模、预训练或生成型任务等,数据量超过14G,近4000个定义良好的txt文件、50亿个字。主要部分来自于nlp_chinese_corpus项目 59 | 包含如下子语料库(总共14G语料):新闻语料[news2016zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/6bac09db4e6d4857b6d680d34447457490cb2dbdd8b8462ea1780a407f38e12b?responseContentDisposition=attachment%3B%20filename%3Dnews2016zh_corpus.zip), 社区互动语料[webText2019zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/83da03f7b4974871a52348b41c16c7e3b34a26d5ca644f558df8435be4de51c3?responseContentDisposition=attachment%3B%20filename%3DwebText2019zh_corpus.zip),维基百科语料[wiki2019zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/d7a166408d8b4ffdaf4de9cfca09f6ee1e2340260f26440a92f78134d068b28f?responseContentDisposition=attachment%3B%20filename%3Dwiki2019zh_corpus.zip),评论数据语料[comment2019zh_corpus.zip](https://bj.bcebos.com/v1/ai-studio-online/b66ddd445735408383c42322850ac4bb82faf9cc611447c2affb925443de7a6d?responseContentDisposition=attachment%3B%20filename%3Dcomment2019zh_corpus.zip)。 60 | 61 | **数据集下载**: 62 | 用户可以通过官方github网页下载,https://github.com/CLUEbenchmark/CLUECorpus2020 。同时,为方便用户,我们也提供了aistudio数据集下载地址。[part1](https://aistudio.baidu.com/aistudio/datasetdetail/60598),[part2](https://aistudio.baidu.com/aistudio/datasetdetail/124357)。使用aistudio版本的数据,下载好后,可以核对md5值: 63 | ```shell 64 | > md5sum ./* 65 | 8a8be341ebce39cfe9524fb0b46b08c5 ./comment2019zh_corpus.zip 66 | 4bdc2c941a7adb4a061caf273fea42b8 ./news2016zh_corpus.zip 67 | fc582409f078b10d717caf233cc58ddd ./webText2019zh_corpus.zip 68 | 157dacde91dcbd2e52a60af49f710fa5 ./wiki2019zh_corpus.zip 69 | ``` 70 | (1) 解压文件 71 | ```shell 72 | unzip comment2019zh_corpus.zip -d clue_corpus_small_14g/comment2019zh_corpus 73 | unzip news2016zh_corpus.zip -d clue_corpus_small_14g/news2016zh_corpus 74 | unzip webText2019zh_corpus.zip -d clue_corpus_small_14g/webText2019zh_corpus 75 | unzip wiki2019zh_corpus.zip -d clue_corpus_small_14g/wiki2019zh_corpus 76 | ``` 77 | (2) 将txt文件转换为jsonl格式 78 | ```shell 79 | python trans_to_json.py --input_path ./clue_corpus_small_14g --output_path clue_corpus_small_14g.jsonl 80 | mkdir data #创建data文件夹 81 | mv clue_corpus_small_14g.jsonl ./data #将jsonl放进该目录 82 | ``` 83 | (3) 使用rjieba进行中文分词,会得到`data/refids.txt`和`data/reftext.txt`两个文件,并组合`data/refids.txt`和`data/reftext.txt`两个文件保存成`huggingface`的`dataset` 84 | ```shell 85 | python run_chinese_ref.py --model_name 
junnyu/roformer_chinese_char_base --input_path ./data/clue_corpus_small_14g.jsonl 86 | ``` 87 | 88 | ## 开始训练(small版本模型) 89 | ```bash 90 | TRAIN_DIR=./clue_small_wwm_data 91 | OUTPUT_DIR=./wwm_flash_small/ 92 | BATCH_SIZE=32 93 | ACCUMULATION=4 94 | LR=1e-4 95 | python run_mlm_wwm.py \ 96 | --do_train \ 97 | --tokenizer_name junnyu/roformer_chinese_char_base \ 98 | --train_dir $TRAIN_DIR \ 99 | --output_dir $OUTPUT_DIR \ 100 | --logging_dir $OUTPUT_DIR/logs \ 101 | --per_device_train_batch_size $BATCH_SIZE \ 102 | --gradient_accumulation_steps $ACCUMULATION \ 103 | --learning_rate $LR \ 104 | --weight_decay 0.01 \ 105 | --adam_epsilon 1e-6 \ 106 | --max_steps 250000 \ 107 | --warmup_steps 5000 \ 108 | --logging_steps 100 \ 109 | --save_steps 5000 \ 110 | --seed 2022 \ 111 | --max_grad_norm 3.0 \ 112 | --dataloader_num_workers 6 \ 113 | --fp16 114 | 115 | ``` 116 | 117 | # MLM测试 118 | ```python 119 | # flashquad 120 | import torch 121 | from flash import FLASHQuadForMaskedLM 122 | from transformers import BertTokenizerFast 123 | tokenizer = BertTokenizerFast.from_pretrained("junnyu/flashquad_small_wwm_cluecorpussmall") 124 | model = FLASHQuadForMaskedLM.from_pretrained("junnyu/flashquad_small_wwm_cluecorpussmall") 125 | model.eval() 126 | text = "天气预报说今天的天[MASK]很好,那么我[MASK]一起去公园玩吧!" 127 | inputs = tokenizer(text, return_tensors="pt") 128 | with torch.no_grad(): 129 | pt_outputs = model(**inputs).logits[0] 130 | 131 | pt_outputs_sentence = "pytorch: " 132 | for i, id in enumerate(tokenizer.encode(text)): 133 | if id == tokenizer.mask_token_id: 134 | val,idx = pt_outputs[i].softmax(-1).topk(k=5) 135 | tokens = tokenizer.convert_ids_to_tokens(idx) 136 | new_tokens = [] 137 | for v,t in zip(val.cpu(),tokens): 138 | new_tokens.append(f"{t}+{round(v.item(),4)}") 139 | pt_outputs_sentence += "[" + "||".join(new_tokens) + "]" 140 | else: 141 | pt_outputs_sentence += "".join( 142 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)) 143 | print(pt_outputs_sentence) 144 | # pytorch: 天气预报说今天的天[气+0.9948||空+0.0011||色+0.0007||候+0.0004||势+0.0003]很好,那么我[就+0.4915||们+0.4186||也+0.0753||还+0.0021||都+0.0016]一起去公园玩吧! 145 | 146 | # flash 147 | import torch 148 | from flash import FLASHForMaskedLM 149 | from transformers import BertTokenizerFast 150 | tokenizer = BertTokenizerFast.from_pretrained("junnyu/flash_small_wwm_cluecorpussmall") 151 | model = FLASHForMaskedLM.from_pretrained("junnyu/flash_small_wwm_cluecorpussmall") 152 | model.eval() 153 | text = "天气预报说今天的天[MASK]很好,那么我[MASK]一起去公园玩吧!" 154 | inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=512, return_token_type_ids=False) #这里必须是512,不然结果可能不对。 155 | with torch.no_grad(): 156 | pt_outputs = model(**inputs).logits[0] 157 | 158 | pt_outputs_sentence = "pytorch: " 159 | for i, id in enumerate(tokenizer.encode(text)): 160 | if id == tokenizer.mask_token_id: 161 | val,idx = pt_outputs[i].softmax(-1).topk(k=5) 162 | tokens = tokenizer.convert_ids_to_tokens(idx) 163 | new_tokens = [] 164 | for v,t in zip(val.cpu(),tokens): 165 | new_tokens.append(f"{t}+{round(v.item(),4)}") 166 | pt_outputs_sentence += "[" + "||".join(new_tokens) + "]" 167 | else: 168 | pt_outputs_sentence += "".join( 169 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)) 170 | print(pt_outputs_sentence) 171 | # pytorch: 天气预报说今天的天[气+0.9938||天+0.0017||空+0.0011||晴+0.0007||阳+0.0002]很好,那么我[们+0.9367||就+0.0554||也+0.0041||俩+0.0005||还+0.0004]一起去公园玩吧! 172 | ``` 173 | 174 | # Tnews分类 175 |
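The `figure/tnews.jpg` plot below summarizes the TNEWS fine-tuning results. The fine-tuning script itself is not part of this repository, so the snippet below is only a minimal sketch of how `FLASHQuadForSequenceClassification` plugs into the usual `transformers` classification workflow; the checkpoint name is reused from the MLM example above, while `num_labels=15` (the TNEWS label count) and the sample headlines are placeholder assumptions.

```python
import torch
from flash import FLASHQuadForSequenceClassification
from transformers import BertTokenizerFast

# Assumptions: reuse the pretrained MLM checkpoint from above and set num_labels=15;
# adjust both to your own setup.
name = "junnyu/flashquad_small_wwm_cluecorpussmall"
tokenizer = BertTokenizerFast.from_pretrained(name)
model = FLASHQuadForSequenceClassification.from_pretrained(name, num_labels=15)

texts = ["新闻标题示例一", "新闻标题示例二"]  # placeholder headlines, not real TNEWS samples
labels = torch.tensor([0, 1])                 # placeholder label ids
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128)

outputs = model(**inputs, labels=labels)      # CrossEntropyLoss over the [CLS]-position logits
outputs.loss.backward()                       # or hand the model to transformers.Trainer
print(outputs.loss.item(), outputs.logits.shape)
```

`FLASHForSequenceClassification` from `flash_lucidrains.py` can be swapped in the same way; note the earlier caveat that the FLASH variant expects inputs padded to length 512.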

![tnews results](figure/tnews.jpg)
178 | 179 | # Tips 180 | 不怎么确定实现的对不对,如果代码有错误的话,请帮我指出来,谢谢~ 181 | -------------------------------------------------------------------------------- /clue_small_wwm_data/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunnYu/FLASHQuad_pytorch/e5902617f4573c9edd967313eba8f01234b5cebf/clue_small_wwm_data/dataset.arrow -------------------------------------------------------------------------------- /clue_small_wwm_data/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "text": { 11 | "dtype": "string", 12 | "id": null, 13 | "_type": "Value" 14 | }, 15 | "chinese_ref": { 16 | "feature": { 17 | "dtype": "int64", 18 | "id": null, 19 | "_type": "Value" 20 | }, 21 | "length": -1, 22 | "id": null, 23 | "_type": "Sequence" 24 | } 25 | }, 26 | "homepage": "", 27 | "license": "", 28 | "post_processed": null, 29 | "post_processing_size": null, 30 | "size_in_bytes": null, 31 | "splits": null, 32 | "supervised_keys": null, 33 | "task_templates": null, 34 | "version": null 35 | } -------------------------------------------------------------------------------- /clue_small_wwm_data/place_holder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunnYu/FLASHQuad_pytorch/e5902617f4573c9edd967313eba8f01234b5cebf/clue_small_wwm_data/place_holder.txt -------------------------------------------------------------------------------- /clue_small_wwm_data/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "602b282613955525", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /data/clue_corpus_small_14g.jsonl: -------------------------------------------------------------------------------- 1 | {"text":"如何评价中企为打捞世越号赔本11亿?韩国媒体报道,上海打捞局和韩国政府签定的合同为916亿韩元,打捞局现已花费2800亿韩元。首先声明,此为本人拙见,仅作为猜测,欢迎大家拍砖!中国有实力进行海上大型船舶打捞作业企业一共有三家,都隶属于交通运输部。按北中南分布,分别是烟台打捞局上海打捞局广州打捞局,都是根正苗红的国企(国家意志)。合同签署的时间点很重要,2015年5月起,韩国对世越号沉船打捞项目公开招标。这个时间点中韩关系处在“蜜月期”。小弟出事了,大哥搭把手。再一个合同的价格,2015年8月,由中国交通运输部下属上海打捞"} 2 | {"text":"味道很不错,就是地方有点小。总体来说是很不错的,值得一吃哦。喜欢吃西餐的朋友可以去一下,反正我是听喜欢这种氛围的。"} 3 | {"text":"有哪些法律术语容易被大众误解?比如纳税人,挤了30种出来,大脑短路搞忘了几个点。以下全是干货。1.法人和法人的法定代表人。法人是法律拟制之人,自然人是一个与之并列的概念,一个自然人是万万成为不了法人的,只能成为法定代表人。刑法中对妇女的规定是已满十四周岁的女性(生日的第二天起算),大家通常理解为非处女。不是关两年再执行,而是给犯人一个不死的机会。被判死缓的人类似于坐在达摩克利斯之剑下,两年不故意犯罪就可减刑为无期,有"} 4 | {"text":"这家粥铺在上海春城商业街上也算开的时间较久了。个人觉得还是有存在价值。粥还是比较营养的,品种一般。就是小菜有点贵。"} 5 | {"text":"非常喜欢这本书封面设计很特别字体也很清楚。内容就更不用说了与电影相比原著更值得一读让人为之动容。"} 6 | {"text":"佩兰首秀遭球迷调侃“世界杯I组,中国队首胜”(组图)主教练佩兰小试牛刀 新华社发 主教练佩兰小试牛刀新华社发 热身赛 中国 2:0 马其顿 热身赛中国2:0马其顿0618AP08于汉超庆祝进球新华社发昨晚,中国男足时隔13年后重回福地沈阳,依靠于汉超和高迪在下半场的两粒进球,最终2比0战胜对手。日期“撞车”世界杯,昨天沈阳下了一场雷暴雨,原本想大赚一笔的黄牛,却在赛前把100元一张的门票贱卖至5块也无人问津,幽怨地对记者说:“回家没法跟老婆交待,因为赔惨了。”而在世界杯这个热闹的大背景之下,国足自然也难逃被调侃的命运。法国人佩兰执掌国足的首秀,被不少球迷称为中国队在“世界杯I组(世界杯只有8个小组,从A-H组)的首胜”。3天之后,双方将进行第二场的较量。特约记者 
程玲林备战揭秘国脚聊侃世界杯奥乔亚被赞“来自星星”因为晚上有比赛,昨天是国足集训以来唯一白天没安排训练的一天。不过国脚们也没人“敢看”世界杯的直播。可一觉醒来,奥乔亚竟然成为了国脚们的谈资,就连主教练佩兰的翻译也私下里透露,“佩兰虽然没看(世界杯直播)比赛,但他却认为\"是奥乔亚打败了巴西队\"。”奥乔亚是谁?这在巴西与墨西哥赛前可能没有中国球迷能认得。"} 7 | {"text":"可几乎在一夜之间,奥乔亚这个名字就传遍了大江南北。回到家乡沈阳的国脚刘建业,更是用“那是来自星星的你”,这句纯正的东北话让队友们笑翻了。因为佩兰在国足集训首日就已经公布了球队的作息时间表,因此尽管在沈阳这几天相对时间宽松,不过国脚们还是遵守纪律,没人敢“偷看世界杯直播”。可早上吃饭时,巴西队被墨西哥逼平,而且功臣是奥乔亚,也很快成为大家的谈资。“墨西哥就不缺这样的(守门员)。”于汉超说,“之前不有个什么花蝴蝶么(坎波斯)……”“不是常说,\"好的守门员等于半支球队\"么,这家伙(奥乔亚)真是\"神兽\"。”张呈栋冷不丁来上这一句。“是来自星星的奥乔亚!”刘建业用他那东北话说。向群世界杯·链接国足不参赛照样抢钱中国足协稳赚75万美元听说过躺着中枪的,现在有了躺着赚钱的:由于国际足联提高了巴西世界杯的总奖金,即使没有资格出现在巴西,中国足协也能从国际足联那里分到75万美元的分红。2010年南非世界杯,32支参赛队共获得了4.2亿美元奖金,这个数值在今年将达到5.76亿。巴西世界杯,国际足联将破纪录地获得45亿美元收入,除了给参赛队5.76亿之外,国际足联还将提供2亿美元作为分红发放给各国足协,所有209个成员国将在7月底收到25万美元,明年年初收到后续50万美元。"} 8 | {"text":"这意味着,就算没有资格在巴西亮相,中国足协也可以收到75万美元分红,尽管比起参赛球队至少可以获得900万美元的数字少了不少,但躺着赚钱的中国足协也应该满足了。陈甘露 佩帅点将首发起用了5名新球员巴西世界杯激战正酣,而在地球的另一端,已连续几届无法打进世界杯预选赛亚洲区决赛阶段的中国队与欧洲鱼腩马其顿队进行了一场友谊赛,而这也是二月份挂帅国足帅印的法籍主帅佩兰,上任后所执教的首场比赛。此役,新官上任三把火的佩兰,一改中国男足此前重用老球员的模式,首发名单中有多达5名球员第一次为国足先发上场。面对欧洲弱旅马其顿,汇集目前国内最好球员的中国男足并未能在上半场展现出优势,反倒给马其顿队几次反击打得狼狈十足。下半场开始后,佩兰对场上人员进行了大幅度调整。第56分钟,刚刚转会广州恒大的于汗超依靠一次单刀机会为中国队破门得分,此后上海绿地的高迪在临近比赛结束时一脚远射,将比分改写为2比0,最终中国队依靠这两粒进球轻取对手,而佩兰也在执掌此役后,向国人交出了一份及格答卷。值得一提的是,在佩兰阵容中,任航、吴曦分别坐镇左边后卫和后腰位置。从昨天的表现来看,两人发挥可谓中规中距,尽到自己的责任。特别是作为新人,第一次代表中国男足先发出场的任航,在面对马其顿锋线球员的冲击,表现出其在俱乐部的真实水平,虽然在上半场临近结束时有一次失位,险些让对方投机得手,但整场比赛数据显示,任航在与对方前锋一对一的抢断成功率高达82.3%,这是全场防守球员数据最好的。"} 9 | {"text":"观战花絮天气不佳,上座率不足千人昨晚,国足时隔13年后重回福地沈阳进行比赛。13年前的沈阳五里河,对于中国球迷来说,是唯一残存在脑海里的幸福瞬间。而今,沈阳五里河早已拆为平地建成了商品房,中国男足却在遥望巴西的同时,与世界杯渐行渐远。或许是沈阳昨天下了一整天雷暴雨的原因,又或者是大家在世界杯赛期间,无心关注国足的比赛。截至到开赛前,沈阳奥林匹克体育中心的上座率不足千人,而场外的黄牛则叫苦连连。原本期待国足13年后重回沈阳,能重展当年一票难求的球市,但实际上,黄牛在临开场把自购100元一张的球票贱卖5元都无人问津,记者现场采访了一位持数张球票在苦寻买主的黄牛,得到的答案是:“别提了,赔惨了!回家都不晓得怎么跟老婆交待……”现场的场景有点凄惨,而国足在网络上也成为了球迷调侃的苦主。昨天,国足的比赛用球也是本届世界杯的官方用球—桑巴荣耀,有球迷在网络留言中大玩转折,称“中国队终于踢上了世界杯……用球”。翻看FIFA比赛系统,6-7月在案的国际比赛并不多,只有几场印尼、越南和尼泊尔的热身赛。当32强在巴西厮杀正酣的时候,国足却只能拉来马其顿二队进行热身赛,这样的情形确实虐心。因此这场热身赛,也被众多网友戏谑地称为“世界杯I组比赛”,并恭喜国足取得比赛胜利。"} 10 | {"text":"去杭州玩随便找了这家吃的,味道还行,糖醋小排挺好吃的,还有蚝油生菜,本来不知道什么是片儿川,现在总算知道了-。- 牛蛙太辣了,没吃几口。鸭舌就算了,冷菜而已。"} 11 | -------------------------------------------------------------------------------- /data/refids.txt: -------------------------------------------------------------------------------- 1 | [2, 4, 6, 9, 11, 14, 19, 21, 22, 23, 26, 28, 32, 33, 34, 36, 39, 44, 45, 48, 51, 53, 56, 57, 60, 62, 67, 69, 73, 75, 78, 80, 82, 85, 88, 90, 92, 94, 96, 98, 100, 102, 104, 107, 111, 112, 114, 115, 116, 120, 122, 124, 127, 130, 132, 135, 137, 140, 142, 148, 149, 150, 153, 156, 158, 162, 164, 167, 171, 180, 182, 183, 186, 188, 190, 192, 193, 194, 197, 199, 202, 203, 204, 206, 209, 214, 216, 220, 223, 227, 229, 232, 241, 243, 244, 245, 247, 250, 252] 2 | [2, 5, 8, 10, 12, 16, 18, 22, 26, 28, 32, 35, 38, 40, 43, 46, 51, 53, 55] 3 | [3, 5, 7, 9, 12, 14, 17, 19, 20, 27, 30, 32, 37, 41, 43, 45, 50, 53, 56, 57, 58, 62, 65, 67, 72, 73, 76, 80, 83, 86, 88, 89, 92, 94, 96, 98, 102, 104, 106, 107, 108, 112, 116, 119, 124, 125, 126, 129, 132, 135, 136, 142, 144, 146, 150, 153, 156, 159, 162, 165, 167, 169, 172, 175, 177, 181, 184, 186, 188, 189, 191, 192, 195, 198, 199, 200, 204, 207] 4 | [2, 7, 9, 11, 12, 16, 19, 21, 25, 27, 29, 32, 34, 38, 40, 42, 46, 48, 51, 53, 55] 5 | [2, 4, 7, 9, 10, 11, 14, 16, 20, 23, 27, 28, 32, 34, 36, 39, 41, 45, 46, 47] 6 | [2, 4, 7, 9, 12, 13, 18, 19, 21, 25, 28, 29, 31, 33, 34, 35, 40, 41, 44, 45, 47, 49, 50, 51, 53, 54, 57, 58, 63, 74, 75, 77, 78, 80, 85, 86, 92, 93, 95, 97, 99, 100, 103, 106, 107, 108, 110, 115, 117, 119, 122, 124, 125, 128, 131, 132, 135, 137, 140, 145, 147, 150, 153, 156, 157, 160, 162, 
166, 168, 169, 172, 174, 177, 180, 185, 190, 193, 195, 201, 202, 203, 206, 210, 215, 217, 220, 222, 225, 227, 234, 235, 237, 239, 243, 245, 248, 250, 253, 256, 259, 262, 263, 265, 267, 269, 272, 276, 278, 280, 282, 283, 287, 288, 293, 294, 296, 300, 310, 316, 319, 322, 324, 325, 328, 331, 332, 333, 335, 336, 338, 340, 342, 344, 346, 347, 349, 350, 355, 357, 360, 362, 365, 368, 371, 373, 375, 377, 379, 382, 384, 387, 390, 392, 396, 402, 403, 406, 409, 411, 415, 416, 418, 420, 423, 427, 432, 433, 435, 438, 441, 442, 444, 448, 450, 455, 456, 458, 461, 464, 465, 467, 471, 472, 474, 477, 478, 483, 484, 491, 494, 495, 497, 499, 501, 503, 505, 508] 7 | [3, 6, 7, 8, 11, 12, 14, 16, 19, 22, 23, 24, 27, 29, 31, 34, 36, 37, 40, 46, 48, 54, 56, 59, 63, 70, 72, 75, 77, 79, 82, 84, 87, 90, 91, 92, 96, 98, 101, 104, 106, 108, 110, 113, 115, 118, 120, 121, 122, 125, 126, 129, 131, 132, 134, 139, 141, 145, 146, 149, 150, 152, 155, 157, 160, 161, 165, 167, 169, 172, 176, 177, 180, 182, 186, 187, 192, 193, 198, 200, 203, 205, 206, 211, 218, 220, 226, 227, 229, 231, 233, 239, 242, 243, 246, 249, 254, 255, 257, 258, 263, 268, 270, 273, 274, 278, 279, 284, 291, 292, 295, 297, 300, 302, 304, 306, 307, 308, 310, 313, 314, 316, 321, 325, 331, 335, 337, 338, 339, 341, 344, 346, 347, 351, 354, 356, 358, 360, 363, 366, 367, 368, 373, 374, 375, 377, 379, 382, 383, 386, 391, 393, 394, 399, 400, 403, 409, 410, 412, 415, 417, 420, 423, 430, 432, 433, 436, 437, 438, 441, 442, 445, 448, 449, 451, 454, 457, 458, 464, 467, 468, 469, 473, 476, 477, 479, 481, 483, 486, 488, 491, 495, 496, 501, 503, 506, 507, 510, 512, 514, 516, 519, 520] 8 | [3, 4, 7, 9, 11, 14, 16, 19, 20, 21, 24, 26, 29, 30, 32, 35, 37, 39, 41, 43, 45, 47, 50, 51, 54, 58, 64, 67, 68, 69, 72, 74, 79, 81, 83, 85, 87, 91, 93, 95, 97, 98, 100, 102, 107, 110, 111, 115, 117, 119, 123, 124, 126, 127, 129, 130, 132, 134, 137, 138, 141, 143, 145, 146, 149, 152, 154, 155, 162, 163, 165, 169, 172, 174, 176, 179, 183, 186, 188, 191, 194, 195, 196, 197, 198, 199, 202, 205, 207, 208, 209, 211, 213, 216, 219, 222, 224, 226, 228, 232, 234, 235, 237, 238, 240, 242, 245, 247, 249, 251, 252, 255, 257, 259, 261, 263, 266, 267, 268, 270, 274, 275, 277, 278, 280, 283, 286, 287, 290, 292, 296, 298, 301, 302, 304, 308, 311, 313, 315, 318, 319, 321, 326, 329, 331, 333, 335, 341, 343, 345, 347, 350, 351, 353, 355, 358, 360, 362, 365, 368, 370, 372, 375, 377, 381, 383, 390, 392, 393, 395, 398, 400, 402, 404, 408, 412, 414, 419, 421, 424, 426, 428, 431, 432, 433, 434, 435, 439, 441, 445, 448, 450, 452, 454, 455, 456, 459, 461, 465, 468, 470, 473, 475, 477, 479, 481, 484, 486, 489, 492, 495, 497, 500, 501, 503, 505, 506, 507, 509, 514, 518, 520, 521, 523, 525, 528, 531, 533, 536, 537, 540, 542, 545, 548, 549, 551, 553, 555, 557, 559, 562, 565, 567, 569, 573, 575, 577, 579, 582, 586, 588, 590, 591, 594, 596, 597, 598, 606, 608, 610, 612, 614, 616] 9 | [2, 4, 6, 8, 11, 12, 14, 16, 18, 21, 23, 28, 30, 32, 34, 36, 40, 43, 45, 46, 49, 51, 53, 55, 59, 61, 64, 68, 70, 73, 76, 78, 79, 81, 83, 85, 87, 90, 91, 94, 95, 96, 100, 102, 105, 109, 110, 112, 114, 117, 120, 122, 126, 127, 129, 130, 133, 137, 140, 143, 144, 145, 147, 150, 152, 154, 157, 160, 163, 167, 169, 170, 171, 173, 174, 175, 178, 179, 181, 183, 187, 190, 193, 195, 198, 200, 202, 207, 209, 213, 215, 217, 219, 222, 226, 227, 230, 234, 237, 241, 244, 246, 251, 252, 253, 256, 258, 259, 260, 263, 266, 268, 273, 276, 279, 282, 287, 291, 295, 299, 301, 304, 306, 311, 314, 316, 318, 321, 322, 325, 329, 332, 334, 337, 340, 343, 346, 348, 
352, 354, 355, 358, 360, 363, 365, 369, 372, 374, 376, 377, 379, 384, 385, 387, 392, 393, 397, 401, 404, 406, 413, 416, 418, 420, 421, 424, 426, 428, 431, 434, 435, 438, 439, 446, 448, 450, 453, 456, 459, 463, 464, 466, 468, 470, 471, 474, 477, 479, 481, 484, 486, 488, 489, 494, 496, 498, 501, 504, 505, 509, 514, 516, 518, 520, 522] 10 | [3, 6, 10, 15, 20, 22, 24, 25, 29, 31, 33, 36, 39, 41, 44, 48, 50, 52, 58, 60, 66, 69, 71, 75, 77] -------------------------------------------------------------------------------- /data/reftext.txt: -------------------------------------------------------------------------------- 1 | 如何评价中企为打捞世越号赔本11亿?韩国媒体报道,上海打捞局和韩国政府签定的合同为916亿韩元,打捞局现已花费2800亿韩元。首先声明,此为本人拙见,仅作为猜测,欢迎大家拍砖!中国有实力进行海上大型船舶打捞作业企业一共有三家,都隶属于交通运输部。按北中南分布,分别是烟台打捞局上海打捞局广州打捞局,都是根正苗红的国企(国家意志)。合同签署的时间点很重要,2015年5月起,韩国对世越号沉船打捞项目公开招标。这个时间点中韩关系处在“蜜月期”。小弟出事了,大哥搭把手。再一个合同的价格,2015年8月,由中国交通运输部下属上海打捞 2 | 味道很不错,就是地方有点小。总体来说是很不错的,值得一吃哦。喜欢吃西餐的朋友可以去一下,反正我是听喜欢这种氛围的。 3 | 有哪些法律术语容易被大众误解?比如纳税人,挤了30种出来,大脑短路搞忘了几个点。以下全是干货。1.法人和法人的法定代表人。法人是法律拟制之人,自然人是一个与之并列的概念,一个自然人是万万成为不了法人的,只能成为法定代表人。刑法中对妇女的规定是已满十四周岁的女性(生日的第二天起算),大家通常理解为非处女。不是关两年再执行,而是给犯人一个不死的机会。被判死缓的人类似于坐在达摩克利斯之剑下,两年不故意犯罪就可减刑为无期,有 4 | 这家粥铺在上海春城商业街上也算开的时间较久了。个人觉得还是有存在价值。粥还是比较营养的,品种一般。就是小菜有点贵。 5 | 非常喜欢这本书封面设计很特别字体也很清楚。内容就更不用说了与电影相比原著更值得一读让人为之动容。 6 | 佩兰首秀遭球迷调侃“世界杯I组,中国队首胜”(组图)主教练佩兰小试牛刀 新华社发 主教练佩兰小试牛刀新华社发 热身赛 中国 2:0 马其顿 热身赛中国2:0马其顿0618AP08于汉超庆祝进球新华社发昨晚,中国男足时隔13年后重回福地沈阳,依靠于汉超和高迪在下半场的两粒进球,最终2比0战胜对手。日期“撞车”世界杯,昨天沈阳下了一场雷暴雨,原本想大赚一笔的黄牛,却在赛前把100元一张的门票贱卖至5块也无人问津,幽怨地对记者说:“回家没法跟老婆交待,因为赔惨了。”而在世界杯这个热闹的大背景之下,国足自然也难逃被调侃的命运。法国人佩兰执掌国足的首秀,被不少球迷称为中国队在“世界杯I组(世界杯只有8个小组,从A-H组)的首胜”。3天之后,双方将进行第二场的较量。特约记者 程玲林备战揭秘国脚聊侃世界杯奥乔亚被赞“来自星星”因为晚上有比赛,昨天是国足集训以来唯一白天没安排训练的一天。不过国脚们也没人“敢看”世界杯的直播。可一觉醒来,奥乔亚竟然成为了国脚们的谈资,就连主教练佩兰的翻译也私下里透露,“佩兰虽然没看(世界杯直播)比赛,但他却认为"是奥乔亚打败了巴西队"。”奥乔亚是谁?这在巴西与墨西哥赛前可能没有中国球迷能认得。 7 | 可几乎在一夜之间,奥乔亚这个名字就传遍了大江南北。回到家乡沈阳的国脚刘建业,更是用“那是来自星星的你”,这句纯正的东北话让队友们笑翻了。因为佩兰在国足集训首日就已经公布了球队的作息时间表,因此尽管在沈阳这几天相对时间宽松,不过国脚们还是遵守纪律,没人敢“偷看世界杯直播”。可早上吃饭时,巴西队被墨西哥逼平,而且功臣是奥乔亚,也很快成为大家的谈资。“墨西哥就不缺这样的(守门员)。”于汉超说,“之前不有个什么花蝴蝶么(坎波斯)……”“不是常说,"好的守门员等于半支球队"么,这家伙(奥乔亚)真是"神兽"。”张呈栋冷不丁来上这一句。“是来自星星的奥乔亚!”刘建业用他那东北话说。向群世界杯·链接国足不参赛照样抢钱中国足协稳赚75万美元听说过躺着中枪的,现在有了躺着赚钱的:由于国际足联提高了巴西世界杯的总奖金,即使没有资格出现在巴西,中国足协也能从国际足联那里分到75万美元的分红。2010年南非世界杯,32支参赛队共获得了4.2亿美元奖金,这个数值在今年将达到5.76亿。巴西世界杯,国际足联将破纪录地获得45亿美元收入,除了给参赛队5.76亿之外,国际足联还将提供2亿美元作为分红发放给各国足协,所有209个成员国将在7月底收到25万美元,明年年初收到后续50万美元。 8 | 这意味着,就算没有资格在巴西亮相,中国足协也可以收到75万美元分红,尽管比起参赛球队至少可以获得900万美元的数字少了不少,但躺着赚钱的中国足协也应该满足了。陈甘露 佩帅点将首发起用了5名新球员巴西世界杯激战正酣,而在地球的另一端,已连续几届无法打进世界杯预选赛亚洲区决赛阶段的中国队与欧洲鱼腩马其顿队进行了一场友谊赛,而这也是二月份挂帅国足帅印的法籍主帅佩兰,上任后所执教的首场比赛。此役,新官上任三把火的佩兰,一改中国男足此前重用老球员的模式,首发名单中有多达5名球员第一次为国足先发上场。面对欧洲弱旅马其顿,汇集目前国内最好球员的中国男足并未能在上半场展现出优势,反倒给马其顿队几次反击打得狼狈十足。下半场开始后,佩兰对场上人员进行了大幅度调整。第56分钟,刚刚转会广州恒大的于汗超依靠一次单刀机会为中国队破门得分,此后上海绿地的高迪在临近比赛结束时一脚远射,将比分改写为2比0,最终中国队依靠这两粒进球轻取对手,而佩兰也在执掌此役后,向国人交出了一份及格答卷。值得一提的是,在佩兰阵容中,任航、吴曦分别坐镇左边后卫和后腰位置。从昨天的表现来看,两人发挥可谓中规中距,尽到自己的责任。特别是作为新人,第一次代表中国男足先发出场的任航,在面对马其顿锋线球员的冲击,表现出其在俱乐部的真实水平,虽然在上半场临近结束时有一次失位,险些让对方投机得手,但整场比赛数据显示,任航在与对方前锋一对一的抢断成功率高达82.3%,这是全场防守球员数据最好的。 9 | 观战花絮天气不佳,上座率不足千人昨晚,国足时隔13年后重回福地沈阳进行比赛。13年前的沈阳五里河,对于中国球迷来说,是唯一残存在脑海里的幸福瞬间。而今,沈阳五里河早已拆为平地建成了商品房,中国男足却在遥望巴西的同时,与世界杯渐行渐远。或许是沈阳昨天下了一整天雷暴雨的原因,又或者是大家在世界杯赛期间,无心关注国足的比赛。截至到开赛前,沈阳奥林匹克体育中心的上座率不足千人,而场外的黄牛则叫苦连连。原本期待国足13年后重回沈阳,能重展当年一票难求的球市,但实际上,黄牛在临开场把自购100元一张的球票贱卖5元都无人问津,记者现场采访了一位持数张球票在苦寻买主的黄牛,得到的答案是:“别提了,赔惨了!回家都不晓得怎么跟老婆交待……”现场的场景有点凄惨,而国足在网络上也成为了球迷调侃的苦主。昨天,国足的比赛用球也是本届世界杯的官方用球—桑巴荣耀,有球迷在网络留言中大玩转折,称“中国队终于踢上了世界杯……用球”。翻看FIFA比赛系统,6-7月在案的国际比赛并不多,只有几场印尼、越南和尼泊尔的热身赛。当32强在巴西厮杀正酣的时候,国足却只能拉来马其顿二队进行热身赛,这样的情形确实虐心。因此这场热身赛,也被众多网友戏谑地称为“世界杯I组比赛”,并恭喜国足取得比赛胜利。 10 | 
去杭州玩随便找了这家吃的,味道还行,糖醋小排挺好吃的,还有蚝油生菜,本来不知道什么是片儿川,现在总算知道了-。- 牛蛙太辣了,没吃几口。鸭舌就算了,冷菜而已。 -------------------------------------------------------------------------------- /figure/tnews.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunnYu/FLASHQuad_pytorch/e5902617f4573c9edd967313eba8f01234b5cebf/figure/tnews.jpg -------------------------------------------------------------------------------- /flash/__init__.py: -------------------------------------------------------------------------------- 1 | from flash.flash import FLASHQuadConfig, FLASHQuadModel, FLASHQuadForMaskedLM, FLASHQuadForSequenceClassification 2 | from flash.flash_lucidrains import FLASHConfig, FLASHModel, FLASHForMaskedLM, FLASHForSequenceClassification, FLASHForMultipleChoice, FLASHForTokenClassification, FLASHForQuestionAnswering -------------------------------------------------------------------------------- /flash/flash.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.modeling_outputs import ( 5 | BaseModelOutput, 6 | MaskedLMOutput, 7 | SequenceClassifierOutput, 8 | ) 9 | from transformers.modeling_utils import PreTrainedModel 10 | from transformers.models.bert.modeling_bert import ( 11 | BertOnlyMLMHead as FLASHQuadOnlyMLMHead, 12 | ) 13 | from transformers.utils import logging 14 | 15 | from flash.gau import GAU, ScaleNorm 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | class FLASHQuadConfig(PretrainedConfig): 21 | model_type = "flash_quad" 22 | 23 | def __init__( 24 | self, 25 | vocab_size=12000, 26 | hidden_size=768, 27 | num_hidden_layers=24, # base 28 | max_position_embeddings=512, 29 | type_vocab_size=2, 30 | initializer_range=0.02, 31 | layer_norm_eps=1e-5, 32 | pad_token_id=0, 33 | expansion_factor=2, 34 | s=128, 35 | norm_type="scale_norm", 36 | gradient_checkpointing=False, 37 | dropout=0.0, 38 | hidden_act="silu", 39 | classifier_dropout=0.1, 40 | **kwargs 41 | ): 42 | super().__init__(pad_token_id=pad_token_id, **kwargs) 43 | 44 | self.vocab_size = vocab_size 45 | self.hidden_size = hidden_size 46 | self.num_hidden_layers = num_hidden_layers 47 | self.max_position_embeddings = max_position_embeddings 48 | self.type_vocab_size = type_vocab_size 49 | self.initializer_range = initializer_range 50 | self.layer_norm_eps = layer_norm_eps 51 | self.expansion_factor = expansion_factor 52 | self.s = s 53 | self.norm_type = norm_type 54 | self.dropout = dropout 55 | self.hidden_act = hidden_act 56 | self.gradient_checkpointing = gradient_checkpointing 57 | self.classifier_dropout = classifier_dropout 58 | 59 | 60 | class FLASHQuadPreTrainedModel(PreTrainedModel): 61 | config_class = FLASHQuadConfig 62 | base_model_prefix = "flash_quad" 63 | 64 | def _init_weights(self, module): 65 | """Initialize the weights""" 66 | if isinstance(module, nn.Linear): 67 | module.weight.data.normal_( 68 | mean=0.0, std=self.config.initializer_range) 69 | if module.bias is not None: 70 | module.bias.data.zero_() 71 | elif isinstance(module, nn.Embedding): 72 | module.weight.data.normal_( 73 | mean=0.0, std=self.config.initializer_range) 74 | if module.padding_idx is not None: 75 | module.weight.data[module.padding_idx].zero_() 76 | elif isinstance(module, nn.LayerNorm): 77 | module.bias.data.zero_() 78 | module.weight.data.fill_(1.0) 79 | 80 | 81 | class FLASHQuadEmbeddings(nn.Module): 82 | 
"""Construct the embeddings from word, position and token_type embeddings.""" 83 | 84 | def __init__(self, config): 85 | super().__init__() 86 | self.config = config 87 | self.word_embeddings = nn.Embedding( 88 | config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id 89 | ) 90 | self.token_type_embeddings = nn.Embedding( 91 | config.type_vocab_size, config.hidden_size 92 | ) 93 | self.LayerNorm = ( 94 | nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 95 | if config.norm_type == "layer_norm" 96 | else ScaleNorm(eps=config.layer_norm_eps) 97 | ) 98 | self.dropout = nn.Dropout(config.dropout) 99 | self.register_buffer( 100 | "position_ids", torch.arange( 101 | config.max_position_embeddings).expand((1, -1)) 102 | ) 103 | self.scaledsin_scalar = nn.Parameter( 104 | torch.ones(1) / (config.hidden_size ** 0.5) 105 | ) 106 | self.register_buffer("scaledsin_embeddings", self.get_scaledsin()) 107 | 108 | def get_scaledsin(self): 109 | """Create sinusoidal position embedding with a scaling factor.""" 110 | seqlen, hidden_size = ( 111 | self.config.max_position_embeddings, 112 | self.config.hidden_size, 113 | ) 114 | pos = torch.arange(seqlen, dtype=torch.float32) 115 | half_d = hidden_size // 2 116 | 117 | freq_seq = -torch.arange(half_d, dtype=torch.float32) / float(half_d) 118 | inv_freq = 10000 ** freq_seq 119 | sinusoid = torch.einsum("s,d->sd", pos, inv_freq) 120 | scaledsin = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1) 121 | # scalar = 1 / hidden_size ** 0.5 122 | # scaledsin *= scalar 123 | return scaledsin 124 | 125 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None): 126 | input_shape = input_ids.shape 127 | seq_length = input_shape[1] 128 | 129 | if position_ids is None: 130 | position_ids = self.position_ids[:, :seq_length] 131 | 132 | if token_type_ids is None: 133 | token_type_ids = torch.zeros_like(input_ids) 134 | 135 | inputs_embeds = self.word_embeddings(input_ids) 136 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 137 | position_embeddings = ( 138 | self.scaledsin_embeddings[position_ids] * self.scaledsin_scalar 139 | ) 140 | embeddings = inputs_embeds + token_type_embeddings + position_embeddings 141 | embeddings = self.LayerNorm(embeddings) 142 | embeddings = self.dropout(embeddings) 143 | return embeddings 144 | 145 | 146 | class FLASHQuadEncoder(nn.Module): 147 | def __init__(self, config): 148 | super().__init__() 149 | self.config = config 150 | self.layer = nn.ModuleList( 151 | [ 152 | GAU( 153 | config.hidden_size, 154 | config.expansion_factor, 155 | config.s, 156 | config.norm_type, 157 | config.layer_norm_eps, 158 | config.hidden_act, 159 | config.max_position_embeddings, 160 | ) 161 | for _ in range(config.num_hidden_layers) 162 | ] 163 | ) 164 | 165 | def forward( 166 | self, 167 | hidden_states, 168 | attention_mask=None, 169 | output_attentions=False, 170 | output_hidden_states=False, 171 | return_dict=True, 172 | ): 173 | all_hidden_states = () if output_hidden_states else None 174 | all_self_attentions = () if output_attentions else None 175 | 176 | for i, layer_module in enumerate(self.layer): 177 | if output_hidden_states: 178 | all_hidden_states = all_hidden_states + (hidden_states,) 179 | 180 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 181 | 182 | def create_custom_forward(module): 183 | def custom_forward(*inputs): 184 | return module(*inputs, output_attentions) 185 | 186 | return custom_forward 187 | 188 | layer_outputs = 
torch.utils.checkpoint.checkpoint( 189 | create_custom_forward(layer_module), 190 | hidden_states, 191 | attention_mask, 192 | ) 193 | else: 194 | layer_outputs = layer_module( 195 | hidden_states, 196 | attention_mask, 197 | output_attentions, 198 | ) 199 | 200 | hidden_states = layer_outputs[0] 201 | 202 | if output_attentions: 203 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 204 | 205 | if output_hidden_states: 206 | all_hidden_states = all_hidden_states + (hidden_states,) 207 | 208 | if not return_dict: 209 | return tuple( 210 | v 211 | for v in [ 212 | hidden_states, 213 | all_hidden_states, 214 | all_self_attentions, 215 | ] 216 | if v is not None 217 | ) 218 | return BaseModelOutput( 219 | last_hidden_state=hidden_states, 220 | hidden_states=all_hidden_states, 221 | attentions=all_self_attentions, 222 | ) 223 | 224 | 225 | class FLASHQuadModel(FLASHQuadPreTrainedModel): 226 | def __init__(self, config): 227 | super().__init__(config) 228 | self.config = config 229 | 230 | self.embeddings = FLASHQuadEmbeddings(config) 231 | self.encoder = FLASHQuadEncoder(config) 232 | 233 | self.post_init() 234 | 235 | def get_input_embeddings(self): 236 | return self.embeddings.word_embeddings 237 | 238 | def set_input_embeddings(self, value): 239 | self.embeddings.word_embeddings = value 240 | 241 | def forward( 242 | self, 243 | input_ids=None, 244 | attention_mask=None, 245 | token_type_ids=None, 246 | position_ids=None, 247 | output_attentions=None, 248 | output_hidden_states=None, 249 | return_dict=None, 250 | ): 251 | output_attentions = ( 252 | output_attentions 253 | if output_attentions is not None 254 | else self.config.output_attentions 255 | ) 256 | output_hidden_states = ( 257 | output_hidden_states 258 | if output_hidden_states is not None 259 | else self.config.output_hidden_states 260 | ) 261 | return_dict = ( 262 | return_dict if return_dict is not None else self.config.use_return_dict 263 | ) 264 | 265 | if attention_mask is None: 266 | attention_mask = (input_ids != self.config.pad_token_id).type_as( 267 | self.embeddings.word_embeddings.weight 268 | ) 269 | 270 | if token_type_ids is None: 271 | token_type_ids = torch.zeros_like(input_ids) 272 | 273 | embedding_output = self.embeddings( 274 | input_ids=input_ids, 275 | position_ids=position_ids, 276 | token_type_ids=token_type_ids, 277 | ) 278 | 279 | encoder_outputs = self.encoder( 280 | embedding_output, 281 | attention_mask=attention_mask, 282 | output_attentions=output_attentions, 283 | output_hidden_states=output_hidden_states, 284 | return_dict=return_dict, 285 | ) 286 | sequence_output = encoder_outputs[0] 287 | 288 | if not return_dict: 289 | return (sequence_output,) + encoder_outputs[1:] 290 | 291 | return BaseModelOutput( 292 | last_hidden_state=sequence_output, 293 | hidden_states=encoder_outputs.hidden_states, 294 | attentions=encoder_outputs.attentions, 295 | ) 296 | 297 | 298 | class FLASHQuadForMaskedLM(FLASHQuadPreTrainedModel): 299 | def __init__(self, config): 300 | super().__init__(config) 301 | 302 | self.flash_quad = FLASHQuadModel(config) 303 | self.cls = FLASHQuadOnlyMLMHead(config) 304 | self.cls.predictions.transform.LayerNorm = ( 305 | nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 306 | if config.norm_type == "layer_norm" 307 | else ScaleNorm(eps=config.layer_norm_eps) 308 | ) 309 | self.loss_fn = nn.CrossEntropyLoss() 310 | self.post_init() 311 | 312 | def get_output_embeddings(self): 313 | return self.cls.predictions.decoder 314 | 315 | def 
set_output_embeddings(self, new_embeddings): 316 | self.cls.predictions.decoder = new_embeddings 317 | 318 | def forward( 319 | self, 320 | input_ids=None, 321 | attention_mask=None, 322 | token_type_ids=None, 323 | position_ids=None, 324 | labels=None, 325 | output_attentions=None, 326 | output_hidden_states=None, 327 | return_dict=None, 328 | ): 329 | 330 | return_dict = ( 331 | return_dict if return_dict is not None else self.config.use_return_dict 332 | ) 333 | 334 | outputs = self.flash_quad( 335 | input_ids, 336 | attention_mask=attention_mask, 337 | token_type_ids=token_type_ids, 338 | position_ids=position_ids, 339 | output_attentions=output_attentions, 340 | output_hidden_states=output_hidden_states, 341 | return_dict=return_dict, 342 | ) 343 | 344 | sequence_output = outputs[0] 345 | 346 | prediction_scores = self.cls(sequence_output) 347 | 348 | masked_lm_loss = None 349 | if labels is not None: 350 | masked_lm_loss = self.loss_fn( 351 | prediction_scores.reshape(-1, self.config.vocab_size), 352 | labels.reshape(-1), 353 | ) 354 | 355 | if not return_dict: 356 | output = (prediction_scores,) + outputs[1:] 357 | return ( 358 | ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 359 | ) 360 | 361 | return MaskedLMOutput( 362 | loss=masked_lm_loss, 363 | logits=prediction_scores, 364 | hidden_states=outputs.hidden_states, 365 | attentions=outputs.attentions, 366 | ) 367 | 368 | 369 | class FLASHQuadForSequenceClassification(FLASHQuadPreTrainedModel): 370 | def __init__(self, config): 371 | super().__init__(config) 372 | self.num_labels = config.num_labels 373 | self.config = config 374 | 375 | self.flash_quad = FLASHQuadModel(config) 376 | classifier_dropout = ( 377 | config.classifier_dropout 378 | if config.classifier_dropout is not None 379 | else config.dropout 380 | ) 381 | self.dropout = nn.Dropout(classifier_dropout) 382 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 383 | 384 | # Initialize weights and apply final processing 385 | self.post_init() 386 | 387 | def forward( 388 | self, 389 | input_ids=None, 390 | attention_mask=None, 391 | token_type_ids=None, 392 | position_ids=None, 393 | labels=None, 394 | output_attentions=None, 395 | output_hidden_states=None, 396 | return_dict=None, 397 | ): 398 | 399 | return_dict = ( 400 | return_dict if return_dict is not None else self.config.use_return_dict 401 | ) 402 | 403 | outputs = self.flash_quad( 404 | input_ids, 405 | attention_mask=attention_mask, 406 | token_type_ids=token_type_ids, 407 | position_ids=position_ids, 408 | output_attentions=output_attentions, 409 | output_hidden_states=output_hidden_states, 410 | return_dict=return_dict, 411 | ) 412 | pooled_output = outputs[0][:, 0] 413 | logits = self.classifier(pooled_output) 414 | 415 | loss = None 416 | if labels is not None: 417 | if self.config.problem_type is None: 418 | if self.num_labels == 1: 419 | self.config.problem_type = "regression" 420 | elif self.num_labels > 1 and ( 421 | labels.dtype == torch.long or labels.dtype == torch.int 422 | ): 423 | self.config.problem_type = "single_label_classification" 424 | else: 425 | self.config.problem_type = "multi_label_classification" 426 | 427 | if self.config.problem_type == "regression": 428 | loss_fct = nn.MSELoss() 429 | if self.num_labels == 1: 430 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 431 | else: 432 | loss = loss_fct(logits, labels) 433 | elif self.config.problem_type == "single_label_classification": 434 | loss_fct = nn.CrossEntropyLoss() 435 | loss = 
loss_fct( 436 | logits.reshape(-1, self.num_labels), labels.reshape(-1)) 437 | elif self.config.problem_type == "multi_label_classification": 438 | loss_fct = nn.BCEWithLogitsLoss() 439 | loss = loss_fct(logits, labels) 440 | if not return_dict: 441 | output = (logits,) + outputs[2:] 442 | return ((loss,) + output) if loss is not None else output 443 | 444 | return SequenceClassifierOutput( 445 | loss=loss, 446 | logits=logits, 447 | hidden_states=outputs.hidden_states, 448 | attentions=outputs.attentions, 449 | ) 450 | -------------------------------------------------------------------------------- /flash/flash_lucidrains.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.modeling_outputs import ( 5 | BaseModelOutput, 6 | MaskedLMOutput, 7 | SequenceClassifierOutput, 8 | MultipleChoiceModelOutput, 9 | TokenClassifierOutput, 10 | QuestionAnsweringModelOutput, 11 | ) 12 | from transformers.modeling_utils import PreTrainedModel, SequenceSummary 13 | 14 | from transformers.utils import logging 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | ###################################################copied from https://github.com/lucidrains/FLASH-pytorch 19 | import math 20 | import torch.nn.functional as F 21 | from torch import nn, einsum 22 | from einops import rearrange 23 | from rotary_embedding_torch import RotaryEmbedding 24 | 25 | # helper functions 26 | 27 | 28 | def exists(val): 29 | return val is not None 30 | 31 | 32 | def default(val, d): 33 | return val if exists(val) else d 34 | 35 | 36 | def padding_to_multiple_of(n, mult): 37 | remainder = n % mult 38 | if remainder == 0: 39 | return 0 40 | return mult - remainder 41 | 42 | 43 | # scalenorm 44 | 45 | 46 | class ScaleNorm(nn.Module): 47 | def __init__(self, dim, eps=1e-5): 48 | super().__init__() 49 | self.scale = dim ** -0.5 50 | self.eps = eps 51 | self.g = nn.Parameter(torch.ones(1)) 52 | 53 | def forward(self, x): 54 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 55 | return x / norm.clamp(min=self.eps) * self.g 56 | 57 | 58 | # T5 relative positional bias 59 | 60 | 61 | class T5RelativePositionBias(nn.Module): 62 | def __init__(self, scale, causal=False, num_buckets=32, max_distance=128): 63 | super().__init__() 64 | self.scale = scale 65 | self.causal = causal 66 | self.num_buckets = num_buckets 67 | self.max_distance = max_distance 68 | self.relative_attention_bias = nn.Embedding(num_buckets, 1) 69 | 70 | @staticmethod 71 | def _relative_position_bucket( 72 | relative_position, causal=True, num_buckets=32, max_distance=128 73 | ): 74 | ret = 0 75 | n = -relative_position 76 | if not causal: 77 | num_buckets //= 2 78 | ret += (n < 0).long() * num_buckets 79 | n = torch.abs(n) 80 | else: 81 | n = torch.max(n, torch.zeros_like(n)) 82 | 83 | max_exact = num_buckets // 2 84 | is_small = n < max_exact 85 | 86 | val_if_large = ( 87 | max_exact 88 | + ( 89 | torch.log(n.float() / max_exact) 90 | / math.log(max_distance / max_exact) 91 | * (num_buckets - max_exact) 92 | ).long() 93 | ) 94 | val_if_large = torch.min( 95 | val_if_large, torch.full_like(val_if_large, num_buckets - 1) 96 | ) 97 | 98 | ret += torch.where(is_small, n, val_if_large) 99 | return ret 100 | 101 | def forward(self, x): 102 | i, j, device = *x.shape[-2:], x.device 103 | q_pos = torch.arange(i, dtype=torch.long, device=device) 104 | k_pos = torch.arange(j, dtype=torch.long, 
device=device) 105 | rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1") 106 | rp_bucket = self._relative_position_bucket( 107 | rel_pos, 108 | causal=self.causal, 109 | num_buckets=self.num_buckets, 110 | max_distance=self.max_distance, 111 | ) 112 | values = self.relative_attention_bias(rp_bucket) 113 | bias = rearrange(values, "i j 1 -> i j") 114 | return bias * self.scale 115 | 116 | 117 | # class 118 | 119 | 120 | class OffsetScale(nn.Module): 121 | def __init__(self, dim, heads=1): 122 | super().__init__() 123 | self.weight = nn.Parameter(torch.ones(heads, dim)) 124 | self.bias = nn.Parameter(torch.zeros(heads, dim)) 125 | nn.init.normal_(self.weight, std=0.02) 126 | 127 | def forward(self, x): 128 | out = einsum("... d, h d -> ... h d", x, self.weight) + self.bias 129 | return out.unbind(dim=-2) 130 | 131 | 132 | # FLASH 133 | 134 | 135 | class FLASH(nn.Module): 136 | def __init__( 137 | self, 138 | *, 139 | dim, 140 | group_size=256, 141 | query_key_dim=128, 142 | expansion_factor=2.0, 143 | causal=False, 144 | dropout=0.0, 145 | rotary_pos_emb=None, 146 | norm_klass=nn.LayerNorm, 147 | shift_tokens=False 148 | ): 149 | super().__init__() 150 | hidden_dim = int(dim * expansion_factor) 151 | self.group_size = group_size 152 | self.causal = causal 153 | self.shift_tokens = shift_tokens 154 | 155 | # positional embeddings 156 | 157 | self.rotary_pos_emb = rotary_pos_emb 158 | self.rel_pos_bias = T5RelativePositionBias(query_key_dim ** 0.5, causal=causal) 159 | 160 | # norm 161 | 162 | self.norm = norm_klass(dim) 163 | self.dropout = nn.Dropout(dropout) 164 | 165 | # projections 166 | 167 | self.to_hidden = nn.Sequential(nn.Linear(dim, hidden_dim * 2), nn.SiLU()) 168 | 169 | self.to_qk = nn.Sequential(nn.Linear(dim, query_key_dim), nn.SiLU()) 170 | 171 | self.qk_offset_scale = OffsetScale(query_key_dim, heads=4) 172 | self.to_out = nn.Linear(hidden_dim, dim) 173 | 174 | def forward(self, x, attention_mask=None, output_attentions=False): 175 | """ 176 | b - batch 177 | n - sequence length (within groups) 178 | g - group dimension 179 | d - feature dimension (keys) 180 | e - feature dimension (values) 181 | i - sequence dimension (source) 182 | j - sequence dimension (target) 183 | """ 184 | mask = attention_mask 185 | b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size 186 | 187 | # prenorm 188 | 189 | normed_x = self.norm(x) 190 | 191 | # do token shift - a great, costless trick from an independent AI researcher in Shenzhen 192 | 193 | if self.shift_tokens: 194 | x_shift, x_pass = normed_x.chunk(2, dim=-1) 195 | x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.0) 196 | normed_x = torch.cat((x_shift, x_pass), dim=-1) 197 | 198 | # initial projections 199 | 200 | v, gate = self.to_hidden(normed_x).chunk(2, dim=-1) 201 | qk = self.to_qk(normed_x) 202 | 203 | # offset and scale 204 | 205 | quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk) 206 | 207 | # mask out linear attention keys 208 | 209 | if exists(mask): 210 | mask = mask.bool() 211 | lin_k = lin_k.masked_fill(~mask[..., None], 0.0) 212 | 213 | # rotate queries and keys 214 | 215 | if exists(self.rotary_pos_emb): 216 | quad_q, lin_q, quad_k, lin_k = map( 217 | self.rotary_pos_emb.rotate_queries_or_keys, 218 | (quad_q, lin_q, quad_k, lin_k), 219 | ) 220 | 221 | # padding for groups 222 | 223 | padding = padding_to_multiple_of(n, g) 224 | 225 | if padding > 0: 226 | quad_q, quad_k, lin_q, lin_k, v = map( 227 | lambda t: F.pad(t, (0, 0, 0, padding), value=0.0), 228 | (quad_q, quad_k, lin_q, 
lin_k, v), 229 | ) 230 | 231 | mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool)) 232 | mask = F.pad(mask, (0, padding), value=False) 233 | 234 | # group along sequence 235 | 236 | quad_q, quad_k, lin_q, lin_k, v = map( 237 | lambda t: rearrange(t, "b (g n) d -> b g n d", n=self.group_size), 238 | (quad_q, quad_k, lin_q, lin_k, v), 239 | ) 240 | 241 | if exists(mask): 242 | mask = rearrange(mask, "b (g j) -> b g 1 j", j=g) 243 | 244 | # calculate quadratic attention output 245 | 246 | sim = einsum("... i d, ... j d -> ... i j", quad_q, quad_k) / g 247 | 248 | sim = sim + self.rel_pos_bias(sim) 249 | 250 | attn = F.relu(sim) ** 2 251 | attn = self.dropout(attn) 252 | 253 | if exists(mask): 254 | attn = attn.masked_fill(~mask, 0.0) 255 | 256 | if self.causal: 257 | causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1) 258 | attn = attn.masked_fill(causal_mask, 0.0) 259 | 260 | quad_out = einsum("... i j, ... j d -> ... i d", attn, v) 261 | 262 | # calculate linear attention output 263 | 264 | if self.causal: 265 | lin_kv = einsum("b g n d, b g n e -> b g d e", lin_k, v) / g 266 | 267 | # exclusive cumulative sum along group dimension 268 | 269 | lin_kv = lin_kv.cumsum(dim=1) 270 | lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.0) 271 | 272 | lin_out = einsum("b g d e, b g n d -> b g n e", lin_kv, lin_q) 273 | else: 274 | lin_kv = einsum("b g n d, b g n e -> b d e", lin_k, v) / n 275 | lin_out = einsum("b g n d, b d e -> b g n e", lin_q, lin_kv) 276 | 277 | # fold back groups into full sequence, and excise out padding 278 | 279 | quad_attn_out, lin_attn_out = map( 280 | lambda t: rearrange(t, "b g n d -> b (g n) d")[:, :n], (quad_out, lin_out) 281 | ) 282 | 283 | # gate 284 | 285 | out = gate * (quad_attn_out + lin_attn_out) 286 | 287 | # projection out and residual 288 | out = self.to_out(out) + x 289 | return (out, attn) if output_attentions else (out,) 290 | 291 | 292 | ################################################### 293 | 294 | 295 | class FLASHConfig(PretrainedConfig): 296 | model_type = "flash" 297 | 298 | def __init__( 299 | self, 300 | vocab_size=12000, 301 | hidden_size=768, 302 | num_hidden_layers=12, # base 303 | max_position_embeddings=512, 304 | group_size=256, 305 | initializer_range=0.02, 306 | layer_norm_eps=1e-5, 307 | pad_token_id=0, 308 | expansion_factor=2, 309 | query_key_dim=128, 310 | norm_type="scalenorm", 311 | gradient_checkpointing=False, 312 | dropout=0.0, 313 | classifier_dropout=0.1, 314 | **kwargs 315 | ): 316 | super().__init__(pad_token_id=pad_token_id, **kwargs) 317 | 318 | self.vocab_size = vocab_size 319 | self.hidden_size = hidden_size 320 | self.num_hidden_layers = num_hidden_layers 321 | self.max_position_embeddings = max_position_embeddings 322 | self.group_size = group_size 323 | self.initializer_range = initializer_range 324 | self.layer_norm_eps = layer_norm_eps 325 | self.expansion_factor = expansion_factor 326 | self.query_key_dim = query_key_dim 327 | self.norm_type = norm_type 328 | self.dropout = dropout 329 | self.gradient_checkpointing = gradient_checkpointing 330 | self.classifier_dropout = classifier_dropout 331 | 332 | 333 | class FLASHPreTrainedModel(PreTrainedModel): 334 | config_class = FLASHConfig 335 | base_model_prefix = "flash" 336 | 337 | def _init_weights(self, module): 338 | """Initialize the weights""" 339 | if isinstance(module, nn.Linear): 340 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 341 | if module.bias is not None: 342 | 
module.bias.data.zero_() 343 | elif isinstance(module, nn.Embedding): 344 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 345 | if module.padding_idx is not None: 346 | module.weight.data[module.padding_idx].zero_() 347 | elif isinstance(module, nn.LayerNorm): 348 | module.bias.data.zero_() 349 | module.weight.data.fill_(1.0) 350 | 351 | 352 | class FLASHLMPredictionHead(nn.Module): 353 | def __init__(self, config): 354 | super().__init__() 355 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 356 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size) 357 | 358 | def forward(self, hidden_states): 359 | return self.decoder(self.LayerNorm(hidden_states)) 360 | 361 | 362 | class FLASHOnlyMLMHead(nn.Module): 363 | def __init__(self, config): 364 | super().__init__() 365 | self.predictions = FLASHLMPredictionHead(config) 366 | 367 | def forward(self, sequence_output): 368 | prediction_scores = self.predictions(sequence_output) 369 | return prediction_scores 370 | 371 | 372 | class FLASHEmbeddings(nn.Module): 373 | """Construct the embeddings from word, position embeddings.""" 374 | 375 | def __init__(self, config): 376 | super().__init__() 377 | self.config = config 378 | self.word_embeddings = nn.Embedding( 379 | config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id 380 | ) 381 | self.register_buffer( 382 | "position_ids", 383 | torch.arange(config.max_position_embeddings).expand((1, -1)), 384 | persistent=False, 385 | ) 386 | self.scale = nn.Parameter(torch.ones(1)) 387 | self.register_buffer( 388 | "scaledsin_embeddings", self.get_scaledsin(), persistent=False 389 | ) 390 | 391 | def get_scaledsin(self): 392 | """Create sinusoidal position embedding with a scaling factor.""" 393 | seqlen, hidden_size = ( 394 | self.config.max_position_embeddings, 395 | self.config.hidden_size, 396 | ) 397 | pos = torch.arange(seqlen, dtype=torch.float32) 398 | half_d = hidden_size // 2 399 | 400 | freq_seq = -torch.arange(half_d, dtype=torch.float32) / float(half_d) 401 | inv_freq = 10000 ** freq_seq 402 | sinusoid = torch.einsum("s,d->sd", pos, inv_freq) 403 | scaledsin = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1) 404 | return scaledsin 405 | 406 | def forward(self, input_ids=None, position_ids=None): 407 | input_shape = input_ids.shape 408 | seq_length = input_shape[1] 409 | 410 | if position_ids is None: 411 | position_ids = self.position_ids[:, :seq_length] 412 | inputs_embeds = self.word_embeddings(input_ids) 413 | position_embeddings = self.scaledsin_embeddings[position_ids] * self.scale 414 | embeddings = inputs_embeds + position_embeddings 415 | return embeddings 416 | 417 | 418 | class FLASHEncoder(nn.Module): 419 | def __init__(self, config): 420 | super().__init__() 421 | self.config = config 422 | if config.norm_type == "scalenorm": 423 | norm_klass = ScaleNorm 424 | elif config.norm_type == "layernorm": 425 | norm_klass = nn.LayerNorm 426 | rotary_pos_emb = RotaryEmbedding(dim=min(32, config.query_key_dim)) 427 | self.layers = nn.ModuleList( 428 | [ 429 | FLASH( 430 | dim=config.hidden_size, 431 | group_size=config.group_size, 432 | query_key_dim=config.query_key_dim, 433 | expansion_factor=config.expansion_factor, 434 | causal=False, 435 | dropout=config.dropout, 436 | rotary_pos_emb=rotary_pos_emb, 437 | norm_klass=norm_klass, 438 | shift_tokens=False, 439 | ) 440 | for _ in range(config.num_hidden_layers) 441 | ] 442 | ) 443 | 444 | def forward( 445 | self, 446 | hidden_states, 447 | 
attention_mask=None, 448 | output_attentions=False, 449 | output_hidden_states=False, 450 | return_dict=True, 451 | ): 452 | all_hidden_states = () if output_hidden_states else None 453 | all_self_attentions = () if output_attentions else None 454 | 455 | for i, layer_module in enumerate(self.layers): 456 | if output_hidden_states: 457 | all_hidden_states = all_hidden_states + (hidden_states,) 458 | 459 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 460 | 461 | def create_custom_forward(module): 462 | def custom_forward(*inputs): 463 | return module(*inputs, output_attentions) 464 | 465 | return custom_forward 466 | 467 | layer_outputs = torch.utils.checkpoint.checkpoint( 468 | create_custom_forward(layer_module), 469 | hidden_states, 470 | attention_mask, 471 | output_attentions, 472 | ) 473 | else: 474 | layer_outputs = layer_module( 475 | hidden_states, attention_mask, output_attentions 476 | ) 477 | 478 | hidden_states = layer_outputs[0] 479 | 480 | if output_attentions: 481 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 482 | 483 | if output_hidden_states: 484 | all_hidden_states = all_hidden_states + (hidden_states,) 485 | 486 | if not return_dict: 487 | return tuple( 488 | v 489 | for v in [ 490 | hidden_states, 491 | all_hidden_states, 492 | all_self_attentions, 493 | ] 494 | if v is not None 495 | ) 496 | return BaseModelOutput( 497 | last_hidden_state=hidden_states, 498 | hidden_states=all_hidden_states, 499 | attentions=all_self_attentions, 500 | ) 501 | 502 | 503 | class FLASHModel(FLASHPreTrainedModel): 504 | def __init__(self, config): 505 | super().__init__(config) 506 | self.config = config 507 | 508 | self.embeddings = FLASHEmbeddings(config) 509 | self.encoder = FLASHEncoder(config) 510 | 511 | self.post_init() 512 | 513 | def get_input_embeddings(self): 514 | return self.embeddings.word_embeddings 515 | 516 | def set_input_embeddings(self, value): 517 | self.embeddings.word_embeddings = value 518 | 519 | def forward( 520 | self, 521 | input_ids=None, 522 | attention_mask=None, 523 | position_ids=None, 524 | output_attentions=None, 525 | output_hidden_states=None, 526 | return_dict=None, 527 | ): 528 | output_attentions = ( 529 | output_attentions 530 | if output_attentions is not None 531 | else self.config.output_attentions 532 | ) 533 | output_hidden_states = ( 534 | output_hidden_states 535 | if output_hidden_states is not None 536 | else self.config.output_hidden_states 537 | ) 538 | return_dict = ( 539 | return_dict if return_dict is not None else self.config.use_return_dict 540 | ) 541 | 542 | embedding_output = self.embeddings( 543 | input_ids=input_ids, position_ids=position_ids 544 | ) 545 | 546 | encoder_outputs = self.encoder( 547 | embedding_output, 548 | attention_mask=attention_mask, 549 | output_attentions=output_attentions, 550 | output_hidden_states=output_hidden_states, 551 | return_dict=return_dict, 552 | ) 553 | sequence_output = encoder_outputs[0] 554 | 555 | if not return_dict: 556 | return (sequence_output,) + encoder_outputs[1:] 557 | 558 | return BaseModelOutput( 559 | last_hidden_state=sequence_output, 560 | hidden_states=encoder_outputs.hidden_states, 561 | attentions=encoder_outputs.attentions, 562 | ) 563 | 564 | 565 | class FLASHForMaskedLM(FLASHPreTrainedModel): 566 | def __init__(self, config): 567 | super().__init__(config) 568 | 569 | self.flash = FLASHModel(config) 570 | self.cls = FLASHOnlyMLMHead(config) 571 | 572 | self.post_init() 573 | 574 | def get_output_embeddings(self): 575 | 
return self.cls.predictions.decoder 576 | 577 | def set_output_embeddings(self, new_embeddings): 578 | self.cls.predictions.decoder = new_embeddings 579 | 580 | def forward( 581 | self, 582 | input_ids=None, 583 | attention_mask=None, 584 | position_ids=None, 585 | labels=None, 586 | output_attentions=None, 587 | output_hidden_states=None, 588 | return_dict=None, 589 | ): 590 | 591 | return_dict = ( 592 | return_dict if return_dict is not None else self.config.use_return_dict 593 | ) 594 | 595 | outputs = self.flash( 596 | input_ids, 597 | attention_mask=attention_mask, 598 | position_ids=position_ids, 599 | output_attentions=output_attentions, 600 | output_hidden_states=output_hidden_states, 601 | return_dict=return_dict, 602 | ) 603 | 604 | sequence_output = outputs[0] 605 | 606 | prediction_scores = self.cls(sequence_output) 607 | 608 | masked_lm_loss = None 609 | if labels is not None: 610 | loss_fct = nn.CrossEntropyLoss() 611 | masked_lm_loss = loss_fct( 612 | prediction_scores.reshape(-1, self.config.vocab_size), 613 | labels.reshape(-1), 614 | ) 615 | 616 | if not return_dict: 617 | output = (prediction_scores,) + outputs[1:] 618 | return ( 619 | ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 620 | ) 621 | 622 | return MaskedLMOutput( 623 | loss=masked_lm_loss, 624 | logits=prediction_scores, 625 | hidden_states=outputs.hidden_states, 626 | attentions=outputs.attentions, 627 | ) 628 | 629 | 630 | class FLASHForSequenceClassification(FLASHPreTrainedModel): 631 | def __init__(self, config): 632 | super().__init__(config) 633 | self.num_labels = config.num_labels 634 | self.flash = FLASHModel(config) 635 | self.sequence_summary = SequenceSummary(config) 636 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 637 | 638 | # Initialize weights and apply final processing 639 | self.post_init() 640 | 641 | def forward( 642 | self, 643 | input_ids=None, 644 | attention_mask=None, 645 | position_ids=None, 646 | labels=None, 647 | output_attentions=None, 648 | output_hidden_states=None, 649 | return_dict=None, 650 | ): 651 | 652 | return_dict = ( 653 | return_dict if return_dict is not None else self.config.use_return_dict 654 | ) 655 | 656 | outputs = self.flash( 657 | input_ids, 658 | attention_mask=attention_mask, 659 | position_ids=position_ids, 660 | output_attentions=output_attentions, 661 | output_hidden_states=output_hidden_states, 662 | return_dict=return_dict, 663 | ) 664 | pooled_output = self.sequence_summary(outputs[0]) 665 | logits = self.classifier(pooled_output) 666 | 667 | loss = None 668 | if labels is not None: 669 | if self.config.problem_type is None: 670 | if self.num_labels == 1: 671 | self.config.problem_type = "regression" 672 | elif self.num_labels > 1 and ( 673 | labels.dtype == torch.long or labels.dtype == torch.int 674 | ): 675 | self.config.problem_type = "single_label_classification" 676 | else: 677 | self.config.problem_type = "multi_label_classification" 678 | 679 | if self.config.problem_type == "regression": 680 | loss_fct = nn.MSELoss() 681 | if self.num_labels == 1: 682 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 683 | else: 684 | loss = loss_fct(logits, labels) 685 | elif self.config.problem_type == "single_label_classification": 686 | loss_fct = nn.CrossEntropyLoss() 687 | loss = loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1)) 688 | elif self.config.problem_type == "multi_label_classification": 689 | loss_fct = nn.BCEWithLogitsLoss() 690 | loss = loss_fct(logits, labels) 691 | if not 
return_dict: 692 | output = (logits,) + outputs[1:] 693 | return ((loss,) + output) if loss is not None else output 694 | 695 | return SequenceClassifierOutput( 696 | loss=loss, 697 | logits=logits, 698 | hidden_states=outputs.hidden_states, 699 | attentions=outputs.attentions, 700 | ) 701 | 702 | 703 | class FLASHForMultipleChoice(FLASHPreTrainedModel): 704 | def __init__(self, config): 705 | super().__init__(config) 706 | 707 | self.flash = FLASHModel(config) 708 | self.sequence_summary = SequenceSummary(config) 709 | self.classifier = nn.Linear(config.hidden_size, 1) 710 | 711 | # Initialize weights and apply final processing 712 | self.post_init() 713 | 714 | def forward( 715 | self, 716 | input_ids=None, 717 | attention_mask=None, 718 | position_ids=None, 719 | labels=None, 720 | output_attentions=None, 721 | output_hidden_states=None, 722 | return_dict=None, 723 | ): 724 | 725 | return_dict = ( 726 | return_dict if return_dict is not None else self.config.use_return_dict 727 | ) 728 | num_choices = input_ids.shape[1] 729 | input_ids = ( 730 | input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None 731 | ) 732 | attention_mask = ( 733 | attention_mask.reshape(-1, attention_mask.size(-1)) 734 | if attention_mask is not None 735 | else None 736 | ) 737 | position_ids = ( 738 | position_ids.reshape(-1, position_ids.size(-1)) 739 | if position_ids is not None 740 | else None 741 | ) 742 | 743 | outputs = self.flash( 744 | input_ids, 745 | attention_mask=attention_mask, 746 | position_ids=position_ids, 747 | output_attentions=output_attentions, 748 | output_hidden_states=output_hidden_states, 749 | return_dict=return_dict, 750 | ) 751 | 752 | pooled_output = self.sequence_summary(outputs[0]) 753 | logits = self.classifier(pooled_output) 754 | reshaped_logits = logits.reshape(-1, num_choices) 755 | 756 | loss = None 757 | if labels is not None: 758 | loss_fct = nn.CrossEntropyLoss() 759 | loss = loss_fct(reshaped_logits, labels) 760 | 761 | if not return_dict: 762 | output = (reshaped_logits,) + outputs[1:] 763 | return ((loss,) + output) if loss is not None else output 764 | 765 | return MultipleChoiceModelOutput( 766 | loss=loss, 767 | logits=reshaped_logits, 768 | hidden_states=outputs.hidden_states, 769 | attentions=outputs.attentions, 770 | ) 771 | 772 | 773 | class FLASHForTokenClassification(FLASHPreTrainedModel): 774 | def __init__(self, config): 775 | super().__init__(config) 776 | self.num_labels = config.num_labels 777 | self.flash = FLASHModel(config) 778 | classifier_dropout = ( 779 | config.classifier_dropout 780 | if config.classifier_dropout is not None 781 | else config.dropout 782 | ) 783 | self.dropout = nn.Dropout(classifier_dropout) 784 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 785 | 786 | # Initialize weights and apply final processing 787 | self.post_init() 788 | 789 | def forward( 790 | self, 791 | input_ids=None, 792 | attention_mask=None, 793 | position_ids=None, 794 | labels=None, 795 | output_attentions=None, 796 | output_hidden_states=None, 797 | return_dict=None, 798 | ): 799 | return_dict = ( 800 | return_dict if return_dict is not None else self.config.use_return_dict 801 | ) 802 | 803 | outputs = self.flash( 804 | input_ids, 805 | attention_mask=attention_mask, 806 | position_ids=position_ids, 807 | output_attentions=output_attentions, 808 | output_hidden_states=output_hidden_states, 809 | return_dict=return_dict, 810 | ) 811 | 812 | sequence_output = outputs[0] 813 | sequence_output = 
self.dropout(sequence_output) 814 | logits = self.classifier(sequence_output) 815 | 816 | loss = None 817 | if labels is not None: 818 | loss_fct = nn.CrossEntropyLoss() 819 | loss = loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1)) 820 | 821 | if not return_dict: 822 | output = (logits,) + outputs[1:] 823 | return ((loss,) + output) if loss is not None else output 824 | 825 | return TokenClassifierOutput( 826 | loss=loss, 827 | logits=logits, 828 | hidden_states=outputs.hidden_states, 829 | attentions=outputs.attentions, 830 | ) 831 | 832 | 833 | class FLASHForQuestionAnswering(FLASHPreTrainedModel): 834 | def __init__(self, config): 835 | super().__init__(config) 836 | self.num_labels = config.num_labels 837 | 838 | self.flash = FLASHModel(config) 839 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 840 | 841 | # Initialize weights and apply final processing 842 | self.post_init() 843 | 844 | def forward( 845 | self, 846 | input_ids=None, 847 | attention_mask=None, 848 | position_ids=None, 849 | start_positions=None, 850 | end_positions=None, 851 | output_attentions=None, 852 | output_hidden_states=None, 853 | return_dict=None, 854 | ): 855 | 856 | return_dict = ( 857 | return_dict if return_dict is not None else self.config.use_return_dict 858 | ) 859 | 860 | outputs = self.flash( 861 | input_ids, 862 | attention_mask=attention_mask, 863 | position_ids=position_ids, 864 | output_attentions=output_attentions, 865 | output_hidden_states=output_hidden_states, 866 | return_dict=return_dict, 867 | ) 868 | 869 | sequence_output = outputs[0] 870 | 871 | logits = self.qa_outputs(sequence_output) 872 | start_logits, end_logits = logits.split(1, dim=-1) 873 | start_logits = start_logits.squeeze(-1).contiguous() 874 | end_logits = end_logits.squeeze(-1).contiguous() 875 | 876 | total_loss = None 877 | if start_positions is not None and end_positions is not None: 878 | # If we are on multi-GPU, splitting adds a dimension 879 | if start_positions.ndim > 1: 880 | start_positions = start_positions.squeeze(-1) 881 | if end_positions.ndim > 1: 882 | end_positions = end_positions.squeeze(-1) 883 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 884 | ignored_index = start_logits.size(1) 885 | start_positions = start_positions.clamp(0, ignored_index) 886 | end_positions = end_positions.clamp(0, ignored_index) 887 | 888 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) 889 | start_loss = loss_fct(start_logits, start_positions) 890 | end_loss = loss_fct(end_logits, end_positions) 891 | total_loss = (start_loss + end_loss) / 2 892 | 893 | if not return_dict: 894 | output = (start_logits, end_logits) + outputs[1:] 895 | return ((total_loss,) + output) if total_loss is not None else output 896 | 897 | return QuestionAnsweringModelOutput( 898 | loss=total_loss, 899 | start_logits=start_logits, 900 | end_logits=end_logits, 901 | hidden_states=outputs.hidden_states, 902 | attentions=outputs.attentions, 903 | ) 904 | -------------------------------------------------------------------------------- /flash/gau.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers.activations import ACT2FN 5 | 6 | 7 | def rope(x, dim): 8 | """RoPE position embedding.""" 9 | shape = x.shape 10 | if isinstance(dim, int): 11 | dim = [dim] 12 | spatial_shape = [shape[i] for i in dim] 13 | total_len = 1 14 | for i in 
spatial_shape: 15 | total_len *= i 16 | position = torch.reshape( 17 | torch.arange(total_len, dtype=x.dtype, 18 | device=x.device), spatial_shape 19 | ) 20 | for i in range(dim[-1] + 1, len(shape) - 1, 1): 21 | position = position.unsqueeze(-1) 22 | half_size = shape[-1] // 2 23 | freq_seq = -torch.arange(half_size, dtype=x.dtype, device=x.device) / float( 24 | half_size 25 | ) 26 | inv_freq = 10000 ** freq_seq 27 | sinusoid = torch.einsum("...,d->...d", position, inv_freq) 28 | sin = sinusoid.sin() 29 | cos = sinusoid.cos() 30 | x1, x2 = torch.chunk(x, 2, dim=-1) 31 | 32 | return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) 33 | 34 | 35 | class ScaleNorm(nn.Module): 36 | def __init__(self, eps=1e-5): 37 | super().__init__() 38 | self.eps = eps 39 | self.scala = nn.Parameter(torch.ones(1)) 40 | 41 | def forward(self, x): 42 | mean_square = (x ** 2).mean(dim=-1, keepdim=True) 43 | x = x * torch.rsqrt(mean_square + self.eps) * self.scala 44 | return x 45 | 46 | 47 | class GAU(nn.Module): 48 | """GAU block. 49 | Input shape: batch size x sequence length x model size 50 | """ 51 | 52 | def __init__( 53 | self, 54 | hidden_size=768, 55 | expansion_factor=2, 56 | s=128, 57 | norm_type="layer_norm", 58 | eps=1e-5, 59 | hidden_act="silu", 60 | max_position_embeddings=512, 61 | ): 62 | super().__init__() 63 | self.s = s 64 | self.e = int(hidden_size * expansion_factor) 65 | self.uv = nn.Linear(hidden_size, 2 * self.e + self.s) 66 | self.weight = nn.Parameter(torch.randn(2, self.s)) 67 | self.bias = nn.Parameter(torch.zeros(2, self.s)) 68 | self.o = nn.Linear(self.e, hidden_size) 69 | self.LayerNorm = ( 70 | nn.LayerNorm(hidden_size, eps=eps) 71 | if norm_type == "layer_norm" 72 | else ScaleNorm(eps=eps) 73 | ) 74 | self.w = nn.Parameter(torch.randn(2 * max_position_embeddings - 1)) 75 | self.a = nn.Parameter(torch.randn(1, self.s)) 76 | self.b = nn.Parameter(torch.randn(1, self.s)) 77 | self.act_fn = ACT2FN[hidden_act] 78 | self.max_position_embeddings = max_position_embeddings 79 | 80 | nn.init.normal_(self.weight, std=0.02) 81 | nn.init.normal_(self.w, std=0.02) 82 | nn.init.normal_(self.a, std=0.02) 83 | nn.init.normal_(self.b, std=0.02) 84 | 85 | def rel_pos_bias(self, seq_len): 86 | """Relative position bias.""" 87 | if seq_len <= 512: 88 | # Construct Toeplitz matrix directly when the sequence length is less than 512 89 | t = F.pad(self.w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len) 90 | t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2) 91 | r = (2 * seq_len - 1) // 2 92 | t = t[..., r:-r] 93 | else: 94 | # Construct Toeplitz matrix using RoPE when the sequence length is over 512. 95 | a = rope(self.a.repeat(seq_len, 1), dim=0) 96 | b = rope(self.b.repeat(seq_len, 1), dim=0) 97 | t = torch.einsum("mk,nk ->mn", a, b) 98 | 99 | return t 100 | 101 | def forward(self, x, attention_mask=None, output_attentions=False, causal=False): 102 | seq_len = x.shape[1] 103 | shortcut, x = x, self.LayerNorm(x) 104 | uv = self.uv(x) 105 | u, v, base = torch.split(self.act_fn( 106 | uv), [self.e, self.e, self.s], dim=-1) 107 | # Generate Query (q) and Key (k) from base. 108 | base = torch.einsum("...r,hr->...hr", base, self.weight) + self.bias 109 | base = rope(base, dim=1) 110 | q, k = torch.unbind(base, dim=-2) 111 | # Calculate the quadratic attention. 
112 | qk = torch.einsum("bnd,bmd->bnm", q, k) 113 | 114 | bias = self.rel_pos_bias(self.max_position_embeddings)[ 115 | :, :seq_len, :seq_len] 116 | kernel = torch.square(torch.relu( 117 | qk / self.max_position_embeddings + bias)) 118 | # attention_mask 119 | if attention_mask is not None: 120 | assert attention_mask.ndim == 2 121 | attn_mask = ( 122 | attention_mask[:, None, :] * attention_mask[:, :, None] 123 | ).type_as(x) 124 | kernel *= attn_mask 125 | 126 | if causal: 127 | causal_mask = torch.tril(torch.ones(seq_len, seq_len), diagonal=0) 128 | kernel *= causal_mask 129 | 130 | x = u * torch.einsum("bnm,bme->bne", kernel, v) 131 | x = self.o(x) 132 | if output_attentions: 133 | return x + shortcut, kernel 134 | return (x + shortcut,) 135 | -------------------------------------------------------------------------------- /pretrain.sh: -------------------------------------------------------------------------------- 1 | TRAIN_DIR=./clue_small_wwm_data 2 | OUTPUT_DIR=./wwm_flash_small/ 3 | BATCH_SIZE=32 4 | ACCUMULATION=4 5 | LR=1e-4 6 | python run_mlm_wwm.py \ 7 | --do_train \ 8 | --tokenizer_name junnyu/roformer_chinese_char_base \ 9 | --train_dir $TRAIN_DIR \ 10 | --output_dir $OUTPUT_DIR \ 11 | --logging_dir $OUTPUT_DIR/logs \ 12 | --per_device_train_batch_size $BATCH_SIZE \ 13 | --gradient_accumulation_steps $ACCUMULATION \ 14 | --learning_rate $LR \ 15 | --weight_decay 0.01 \ 16 | --adam_epsilon 1e-6 \ 17 | --max_steps 250000 \ 18 | --warmup_steps 5000 \ 19 | --logging_steps 100 \ 20 | --save_steps 5000 \ 21 | --seed 2022 \ 22 | --max_grad_norm 3.0 \ 23 | --dataloader_num_workers 6 \ 24 | --fp16 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.16.2 2 | rjieba 3 | datasets 4 | rotary_embedding_torch -------------------------------------------------------------------------------- /run_chinese_ref.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import multiprocessing 4 | import os 5 | import sys 6 | import time 7 | from typing import List 8 | 9 | import rjieba 10 | from datasets import Dataset, concatenate_datasets, load_dataset 11 | from tqdm import tqdm 12 | from transformers import BertTokenizerFast 13 | 14 | 15 | def _is_chinese_char(cp): 16 | """Checks whether CP is the codepoint of a CJK character.""" 17 | # This defines a "chinese character" as anything in the CJK Unicode block: 18 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 19 | # 20 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 21 | # despite its name. The modern Korean Hangul alphabet is a different block, 22 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 23 | # space-separated words, so they are not treated specially and handled 24 | # like the all of the other languages. 
25 | if ( 26 | (cp >= 0x4E00 and cp <= 0x9FFF) 27 | or (cp >= 0x3400 and cp <= 0x4DBF) # 28 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 29 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 30 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 31 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 32 | or (cp >= 0xF900 and cp <= 0xFAFF) 33 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 34 | ): # 35 | return True 36 | 37 | return False 38 | 39 | 40 | def is_chinese(word: str): 41 | # word like '180' or '身高' or '神' 42 | for char in word: 43 | char = ord(char) 44 | if not _is_chinese_char(char): 45 | return False 46 | return True 47 | 48 | 49 | def get_chinese_word(tokens: List[str]): 50 | word_set = set() 51 | 52 | for token in tokens: 53 | chinese_word = len(token) > 1 and is_chinese(token) 54 | if chinese_word: 55 | word_set.add(token) 56 | word_list = list(word_set) 57 | return word_list 58 | 59 | 60 | def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set()): 61 | if not chinese_word_set: 62 | return bert_tokens 63 | max_word_len = max([len(w) for w in chinese_word_set]) 64 | 65 | bert_word = bert_tokens 66 | start, end = 0, len(bert_word) 67 | while start < end: 68 | single_word = True 69 | if is_chinese(bert_word[start]): 70 | l = min(end - start, max_word_len) 71 | for i in range(l, 1, -1): 72 | whole_word = "".join(bert_word[start : start + i]) 73 | if whole_word in chinese_word_set: 74 | for j in range(start + 1, start + i): 75 | bert_word[j] = "##" + bert_word[j] 76 | start = start + i 77 | single_word = False 78 | break 79 | if single_word: 80 | start += 1 81 | return bert_word 82 | 83 | 84 | block_size = 512 85 | 86 | 87 | class BlockSizeSplitter: 88 | def tokenize(self, text): 89 | tstr = "" 90 | all_ts = [] 91 | for txt in text.split("\n"): 92 | if len(tstr) > block_size: 93 | all_ts.append(tstr) 94 | tstr = "" 95 | tstr += txt 96 | if len(tstr) > 0: 97 | all_ts.append(tstr) 98 | return all_ts 99 | 100 | 101 | def jieba_segmentation_fn(): 102 | def process(line): 103 | words = rjieba.cut(line) 104 | return words 105 | 106 | return process 107 | 108 | 109 | class Converter: 110 | def __init__(self, args): 111 | self.args = args 112 | 113 | def initializer(self): 114 | Converter.tokenizer = BertTokenizerFast.from_pretrained(self.args.model_name) 115 | 116 | # Split document to sentence. 
117 | Converter.splitter = BlockSizeSplitter() 118 | Converter.segment_func = jieba_segmentation_fn() 119 | 120 | def process(text): 121 | words = Converter.segment_func(text) 122 | new_text = "".join(words).replace("\n", "") 123 | chinese_word = get_chinese_word(words) 124 | input_tokens = ( 125 | [Converter.tokenizer.cls_token] 126 | + Converter.tokenizer.tokenize(new_text) 127 | + [Converter.tokenizer.sep_token] 128 | ) 129 | 130 | input_tokens = add_sub_symbol(input_tokens, chinese_word) 131 | ref_id = [] 132 | for i, token in enumerate(input_tokens): 133 | if token[:2] == "##": 134 | clean_token = token[2:] 135 | # save chinese tokens' pos 136 | if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)): 137 | ref_id.append(i) 138 | 139 | return ref_id, new_text 140 | 141 | Converter.process = process 142 | 143 | def encode(self, json_line): 144 | text = json.loads(json_line)[self.args.json_key] 145 | ref_ids = [] 146 | all_texts = [] 147 | for sentence in Converter.splitter.tokenize(text): 148 | ref_id, new_text = Converter.process(sentence.strip()) 149 | if len(new_text) < 20: 150 | continue 151 | if len(ref_id) > 0 and len(new_text) > 0: 152 | ref_ids.append(ref_id) 153 | all_texts.append(new_text) 154 | 155 | return ref_ids, all_texts, len(text.encode("utf-8")) 156 | 157 | 158 | def main(args): 159 | 160 | file_paths = [] 161 | if os.path.isfile(args.input_path): 162 | file_paths.append(args.input_path) 163 | else: 164 | for root, _, fs in os.walk(args.input_path): 165 | for f in fs: 166 | file_paths.append(os.path.join(root, f)) 167 | convert = Converter(args) 168 | pool = multiprocessing.Pool(args.workers, initializer=convert.initializer) 169 | step = 0 170 | total_bytes_processed = 0 171 | startup_start = time.time() 172 | with open("data/refids.txt", "w+", encoding="utf8") as w1: 173 | with open("data/reftext.txt", "w+", encoding="utf8") as w2: 174 | for file_path in tqdm(file_paths): 175 | if file_path.endswith(".jsonl"): 176 | text = open(file_path, "r", encoding="utf-8") 177 | else: 178 | print("Unexpected data format, skiped %s" % file_path) 179 | continue 180 | 181 | encoded_docs = pool.imap(convert.encode, text, 256) 182 | print("Processing %s" % file_path) 183 | for rid, alltxt, bytes_processed in encoded_docs: 184 | step += 1 185 | total_bytes_processed += bytes_processed 186 | if len(rid) == 0: 187 | continue 188 | 189 | for sentence in rid: 190 | sentence_len = len(sentence) 191 | if sentence_len == 0: 192 | continue 193 | w1.write(str(sentence) + "\n") 194 | for txt in alltxt: 195 | txt_len = len(txt) 196 | if txt_len == 0: 197 | continue 198 | w2.write(txt + "\n") 199 | 200 | if step % args.log_interval == 0: 201 | current = time.time() 202 | elapsed = current - startup_start 203 | mbs = total_bytes_processed / elapsed / 1024 / 1024 204 | print( 205 | f"Processed {step} documents", 206 | f"({step/elapsed:.2f} docs/s, {mbs:.4f} MB/s).", 207 | file=sys.stderr, 208 | ) 209 | pool.close() 210 | print("Saving tokens to files...") 211 | 212 | # concatenate_datasets 213 | print("concatenate_datasets...") 214 | reftext = load_dataset("text", data_files="data/reftext.txt")["train"] 215 | refids = load_dataset("text", data_files="data/refids.txt")["train"] 216 | refids = refids.rename_column("text", "chinese_ref") 217 | refids = refids.map(lambda example: {"chinese_ref": eval(example["chinese_ref"])}) 218 | concat_ds = concatenate_datasets([reftext, refids], axis=1) 219 | concat_ds.save_to_disk("./clue_small_wwm_data") 220 | 221 | 222 | if __name__ == "__main__": 223 | 
parser = argparse.ArgumentParser() 224 | parser.add_argument( 225 | "--model_name", 226 | type=str, 227 | default="junnyu/roformer_chinese_char_base", 228 | help="What model to use.", 229 | ) 230 | 231 | group = parser.add_argument_group(title="data input/output") 232 | group.add_argument( 233 | "--input_path", 234 | type=str, 235 | default="data/clue_corpus_small_14g.jsonl", 236 | help="Path to input JSON files.", 237 | ) 238 | 239 | group.add_argument( 240 | "--json_key", 241 | type=str, 242 | default="text", 243 | help="For JSON format. Space separate listed of keys to extract from json", 244 | ) 245 | 246 | group = parser.add_argument_group(title="common config") 247 | 248 | group.add_argument( 249 | "--log_interval", 250 | type=int, 251 | default=100, 252 | help="Interval between progress updates", 253 | ) 254 | group.add_argument( 255 | "--workers", type=int, default=12, help="Number of worker processes to launch" 256 | ) 257 | 258 | args = parser.parse_args() 259 | main(args) 260 | -------------------------------------------------------------------------------- /run_mlm_wwm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Team All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. 16 | 17 | import logging 18 | import os 19 | import sys 20 | import time 21 | from dataclasses import dataclass, field 22 | from typing import Optional 23 | 24 | import transformers 25 | from datasets import Dataset 26 | from transformers import ( 27 | BertTokenizerFast, 28 | DataCollatorForWholeWordMask, 29 | HfArgumentParser, 30 | TrainingArguments, 31 | set_seed, 32 | ) 33 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 34 | 35 | from flash import FLASHQuadConfig, FLASHQuadForMaskedLM, FLASHConfig, FLASHForMaskedLM 36 | from mlm_trainer import Trainer 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | name2cls = { 42 | "flash": (FLASHConfig, FLASHForMaskedLM ), 43 | "flashquad" : (FLASHQuadConfig, FLASHQuadForMaskedLM ), 44 | } 45 | 46 | @dataclass 47 | class ModelArguments: 48 | """ 49 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 50 | """ 51 | 52 | tokenizer_name: Optional[str] = field( 53 | default="junnyu/roformer_chinese_char_base", 54 | metadata={ 55 | "help": "Pretrained tokenizer name or path if not the same as model_name" 56 | }, 57 | ) 58 | model_name: Optional[str] = field( 59 | default="flash", 60 | metadata={ 61 | "help": "model_name" 62 | }, 63 | ) 64 | 65 | @dataclass 66 | class DataTrainingArguments: 67 | """ 68 | Arguments pertaining to what data we are going to input our model for training and eval. 
69 | """ 70 | 71 | train_dir: Optional[str] = field( 72 | default="./clue_small_wwm_data", 73 | metadata={"help": "The input training data file."}, 74 | ) 75 | overwrite_cache: bool = field( 76 | default=False, 77 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 78 | ) 79 | max_seq_length: Optional[int] = field( 80 | default=512, 81 | metadata={ 82 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 83 | "than this will be truncated. Default to the max input length of the model." 84 | }, 85 | ) 86 | preprocessing_num_workers: Optional[int] = field( 87 | default=12, 88 | metadata={"help": "The number of processes to use for the preprocessing."}, 89 | ) 90 | mlm_probability: float = field( 91 | default=0.15, 92 | metadata={ 93 | "help": "Ratio of tokens to mask for masked language modeling loss"}, 94 | ) 95 | pad_to_max_length: bool = field( 96 | default=False, 97 | metadata={ 98 | "help": "Whether to pad all samples to `max_seq_length`. " 99 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 100 | }, 101 | ) 102 | 103 | 104 | def main(): 105 | # See all possible arguments in src/transformers/training_args.py 106 | # or by passing the --help flag to this script. 107 | # We now keep distinct sets of args, for a cleaner separation of concerns. 108 | 109 | parser = HfArgumentParser( 110 | (ModelArguments, DataTrainingArguments, TrainingArguments) 111 | ) 112 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 113 | # If we pass only one argument to the script and it's the path to a json file, 114 | # let's parse it to get our arguments. 115 | model_args, data_args, training_args = parser.parse_json_file( 116 | json_file=os.path.abspath(sys.argv[1]) 117 | ) 118 | else: 119 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 120 | 121 | # Detecting last checkpoint. 122 | last_checkpoint = None 123 | if ( 124 | os.path.isdir(training_args.output_dir) 125 | and not training_args.overwrite_output_dir 126 | ): 127 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 128 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 129 | raise ValueError( 130 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 131 | "Use --overwrite_output_dir to overcome." 132 | ) 133 | elif last_checkpoint is not None: 134 | logger.info( 135 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 136 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
137 | ) 138 | 139 | # Setup logging 140 | logging.basicConfig( 141 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 142 | datefmt="%m/%d/%Y %H:%M:%S", 143 | handlers=[logging.StreamHandler(sys.stdout)], 144 | ) 145 | logger.setLevel( 146 | logging.INFO if is_main_process( 147 | training_args.local_rank) else logging.WARN 148 | ) 149 | 150 | # Log on each process the small summary: 151 | logger.warning( 152 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 153 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 154 | ) 155 | # Set the verbosity to info of the Transformers logger (on main process only): 156 | if is_main_process(training_args.local_rank): 157 | transformers.utils.logging.set_verbosity_info() 158 | transformers.utils.logging.enable_default_handler() 159 | transformers.utils.logging.enable_explicit_format() 160 | logger.info("Training parameters %s", training_args) 161 | 162 | # Set seed before initializing model. 163 | set_seed(training_args.seed) 164 | 165 | # download the dataset. 166 | # 加载clue_wwm_13g数据集 167 | datasets = Dataset.load_from_disk(data_args.train_dir) 168 | 169 | config_cls, model_cls = name2cls[model_args.model_name] 170 | config = config_cls(num_hidden_layers=12) # small 171 | # tokenizer使用了roformer_chinese_char_base 172 | tokenizer = BertTokenizerFast.from_pretrained(model_args.tokenizer_name) 173 | model = model_cls(config) 174 | model.resize_token_embeddings(len(tokenizer)) 175 | 176 | # Preprocessing the datasets. 177 | # First we tokenize all the texts. 178 | column_names = datasets.column_names 179 | text_column_name = "text" if "text" in column_names else column_names[0] 180 | 181 | padding = "max_length" if data_args.pad_to_max_length else False 182 | 183 | def tokenize_function(examples): 184 | # Remove empty lines 185 | texts = [] 186 | chinese_ref = [] 187 | for text, ref in zip(examples["text"], examples["chinese_ref"]): 188 | if len(text) > 0 and not text.isspace(): 189 | texts.append(text.strip()) 190 | chinese_ref.append(ref) 191 | examples["text"] = texts 192 | examples["chinese_ref"] = chinese_ref 193 | data = tokenizer( 194 | examples["text"], 195 | padding=padding, 196 | truncation=True, 197 | max_length=data_args.max_seq_length, 198 | return_token_type_ids=False, 199 | return_attention_mask=False, 200 | ) 201 | data["text"] = texts 202 | data["chinese_ref"] = chinese_ref 203 | return data 204 | 205 | tokenized_datasets = datasets.map( 206 | tokenize_function, 207 | batched=True, 208 | num_proc=data_args.preprocessing_num_workers, 209 | remove_columns=[text_column_name], 210 | load_from_cache_file=not data_args.overwrite_cache, 211 | new_fingerprint="clue_13g_small_roformer_wwm", 212 | ) 213 | 214 | training_args.remove_unused_columns = False 215 | 216 | # Data collator 217 | # This one will take care of randomly masking the tokens. 
218 | data_collator = DataCollatorForWholeWordMask( 219 | tokenizer=tokenizer, 220 | mlm_probability=data_args.mlm_probability, 221 | pad_to_multiple_of=8 if training_args.fp16 else None, 222 | ) 223 | 224 | # Initialize our Trainer 225 | trainer = Trainer( 226 | model=model, 227 | args=training_args, 228 | train_dataset=tokenized_datasets, 229 | eval_dataset=None, 230 | tokenizer=tokenizer, 231 | data_collator=data_collator, 232 | ) 233 | # trainer.add_callback(LoggingCallback(save_interval=training_args.save_interval)) 234 | # Training 235 | if last_checkpoint is not None: 236 | checkpoint = last_checkpoint 237 | else: 238 | checkpoint = None 239 | 240 | logger.info("Training a model...") 241 | start_time = time.time() 242 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 243 | train_time = time.time() - start_time 244 | logger.info(f"Training time: {train_time}") 245 | trainer.save_model() # Saves the tokenizer too for easy upload 246 | 247 | output_train_file = os.path.join( 248 | training_args.output_dir, "train_results.txt") 249 | if trainer.is_world_process_zero(): 250 | with open(output_train_file, "w") as writer: 251 | logger.info("***** Train results *****") 252 | for key, value in sorted(train_result.metrics.items()): 253 | logger.info(f" {key} = {value}") 254 | writer.write(f"{key} = {value}\n") 255 | 256 | # Need to save the state, since Trainer.save_model saves only the tokenizer with the model 257 | trainer.state.save_to_json( 258 | os.path.join(training_args.output_dir, "trainer_state.json") 259 | ) 260 | 261 | 262 | if __name__ == "__main__": 263 | main() 264 | -------------------------------------------------------------------------------- /trans_to_json.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import json 17 | import multiprocessing 18 | import os 19 | import re 20 | import shutil 21 | import sys 22 | import time 23 | from functools import partial 24 | 25 | 26 | def get_args(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--input_path", 30 | type=str, 31 | required=True, 32 | help="Path to you raw files. Folder or file path.", 33 | ) 34 | parser.add_argument( 35 | "--output_path", 36 | type=str, 37 | required=True, 38 | help="Path to save the output json files.", 39 | ) 40 | parser.add_argument( 41 | "--json_key", type=str, default="text", help="The content key of json file." 42 | ) 43 | parser.add_argument( 44 | "--doc_spliter", 45 | type=str, 46 | default="", 47 | help="Spliter between documents. We will strip the line, if you use blank line to split doc, leave it blank.", 48 | ) 49 | parser.add_argument( 50 | "--min_doc_length", type=int, default=10, help="Minimal char of a documment." 
51 | ) 52 | parser.add_argument( 53 | "--workers", type=int, default=1, help="Number of worker processes to launch" 54 | ) 55 | parser.add_argument( 56 | "--log_interval", type=int, default=1, help="Interval between progress updates." 57 | ) 58 | parser.add_argument("--no-merge", action="store_true", help="Don't merge the file.") 59 | parser.add_argument( 60 | "--no-shuffle", action="store_true", help="Don't shuffle the file." 61 | ) 62 | args = parser.parse_args() 63 | return args 64 | 65 | 66 | def raw_text_to_json(path, doc_spliter="", json_key="text", min_doc_length=10): 67 | path = os.path.abspath(path) 68 | if not os.path.exists(path): 69 | print("No found file %s" % path) 70 | return 0, None 71 | 72 | out_filepath = path + ".jsonl" 73 | fout = open(out_filepath, "w", encoding="utf-8") 74 | len_files = 0 75 | with open(path, "r") as f: 76 | doc = "" 77 | line = f.readline() 78 | while line: 79 | len_files += len(line) 80 | if line.strip() == doc_spliter: 81 | if len(doc) > min_doc_length: 82 | fout.write(json.dumps({json_key: doc}, ensure_ascii=False) + "\n") 83 | doc = "" 84 | else: 85 | doc += line 86 | line = f.readline() 87 | 88 | if len(doc) > min_doc_length: 89 | fout.write(json.dumps({json_key: doc}, ensure_ascii=False) + "\n") 90 | doc = "" 91 | 92 | return len_files, out_filepath 93 | 94 | 95 | def merge_file(file_paths, output_path): 96 | if not output_path.endswith(".jsonl"): 97 | output_path = output_path + ".jsonl" 98 | print("Merging files into %s" % output_path) 99 | with open(output_path, "wb") as wfd: 100 | for f in file_paths: 101 | if f is not None and os.path.exists(f): 102 | with open(f, "rb") as fd: 103 | shutil.copyfileobj(fd, wfd) 104 | os.remove(f) 105 | print("File save in %s" % output_path) 106 | return output_path 107 | 108 | 109 | def shuffle_file(output_path): 110 | print("Shuffling the jsonl file...") 111 | if os.path.exists(output_path): 112 | os.system("shuf %s -o %s" % (output_path, output_path)) 113 | print("File shuffled!!!") 114 | else: 115 | raise ValueError("File not found: %s" % output_path) 116 | 117 | 118 | def main(): 119 | args = get_args() 120 | startup_start = time.time() 121 | 122 | file_paths = [] 123 | if os.path.isfile(args.input_path): 124 | file_paths.append(args.input_path) 125 | else: 126 | for root, _, fs in os.walk(args.input_path): 127 | for f in fs: 128 | file_paths.append(os.path.join(root, f)) 129 | 130 | pool = multiprocessing.Pool(args.workers) 131 | 132 | startup_end = time.time() 133 | proc_start = time.time() 134 | total_bytes_processed = 0 135 | print("Time to startup:", startup_end - startup_start) 136 | 137 | trans_json = partial( 138 | raw_text_to_json, 139 | doc_spliter=args.doc_spliter, 140 | json_key=args.json_key, 141 | min_doc_length=args.min_doc_length, 142 | ) 143 | encoded_files = pool.imap(trans_json, file_paths, 1) 144 | 145 | out_paths = [] 146 | for i, (bytes_processed, out_path) in enumerate(encoded_files, start=1): 147 | total_bytes_processed += bytes_processed 148 | out_paths.append(out_path) 149 | if i % args.log_interval == 0: 150 | current = time.time() 151 | elapsed = current - proc_start 152 | mbs = total_bytes_processed / elapsed / 1024 / 1024 153 | print( 154 | f"Processed {i} files", 155 | f"({i/elapsed} files/s, {mbs} MB/s).", 156 | file=sys.stderr, 157 | ) 158 | 159 | if not args.no_merge: 160 | output_path = merge_file(out_paths, args.output_path) 161 | if not args.no_shuffle: 162 | shuffle_file(output_path) 163 | 164 | 165 | if __name__ == "__main__": 166 | main() 167 | 
--------------------------------------------------------------------------------
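For quick sanity checks of the modeling code above, the following minimal sketch (not part of the repository) wires the pieces together the same way `run_mlm_wwm.py` does: `FLASHConfig(num_hidden_layers=12)` for the "small" model, the `junnyu/roformer_chinese_char_base` tokenizer, and `resize_token_embeddings` to match the vocabulary. The weights here are randomly initialized, so the predicted token only illustrates the API, not the quality of a trained model; the example sentence and variable names are placeholders.

```python
# Minimal usage sketch (assumes the `flash` package from this repo is importable).
import torch
from transformers import BertTokenizerFast

from flash import FLASHConfig, FLASHForMaskedLM

tokenizer = BertTokenizerFast.from_pretrained("junnyu/roformer_chinese_char_base")
config = FLASHConfig(num_hidden_layers=12)      # "small" setting, as in run_mlm_wwm.py
model = FLASHForMaskedLM(config)                # randomly initialized weights
model.resize_token_embeddings(len(tokenizer))
model.eval()

# FLASHForMaskedLM.forward does not accept token_type_ids, so do not return them.
text = "北京是中国的[MASK]都。"  # illustrative sentence
inputs = tokenizer(text, return_tensors="pt", return_token_type_ids=False)

with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch, seq_len, vocab_size)

# Take the highest-scoring token at each [MASK] position.
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
pred_ids = logits[mask_pos].argmax(dim=-1)
print(tokenizer.decode(pred_ids))
```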