├── LICENSE ├── .gitignore ├── README.md └── nbs ├── zh2cc_translate.ipynb ├── kw_leading_poe.ipynb ├── punktuation_ner.ipynb ├── xlsearch.ipynb └── cc2zh_translate.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Raynard Jon (Zhang,Xiaochen) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 渊
2 | > 渊 (Yuan): one-stop, AI-assisted elegance with classical Chinese. New ideas, notebooks, models and data are all welcome contributions.
3 | 
4 | A family of NLP projects for classical Chinese prose and poetry. 🌼
5 | 
6 | * [Start reading the whole classical Chinese corpus here, with the translation engine attached](https://huggingface.co/spaces/raynardj/duguwen-classical-chinese-to-morden-translate)
7 | * [Search: 博古搜今](#search)
8 | * [Translation](#translation)
9 | * [Modern Chinese to classical Chinese translator](#modern-chinese-to-classical-chinese-translator)
10 | * [Classical Chinese to modern Chinese translator](#classical-chinese-to-modern-chinese-translator)
11 | * [Punctuation restoration](#punctuation-restoration)
12 | * [Resources](#resources)
13 | 
14 | ## Search
15 | ### 博古搜今: searching antiquity with modern Chinese
16 | * Fuzzy-search classical Chinese sentences with a modern Chinese query. The model can be downloaded from the [🤗 model page](https://huggingface.co/raynardj/xlsearch-cross-lang-search-zh-vs-classicical-cn).
17 | 
18 | Do you often run into problems like these?
19 | * I don't remember who said it or in which dynasty; I only remember roughly what happened, yet I would like to fuzzy-search my way back to the original passage.
20 | * I don't remember the original wording, only what it means in modern Chinese, and I would like to dig it up so I can quote it.
21 | * I am writing an essay, I have a point to make, and I want to try my luck and see whether the ancients ever said something similar.
22 | * I simply want to read classical Chinese more efficiently.
23 | 
24 | The recommended pipeline is below. Of course there are plenty of frameworks and engines for cosine-distance search; pick whichever suits you.
25 | 
26 | Install the packages
27 | ```shell
28 | pip install -Uqq unpackai
29 | pip install -Uqq sentence-transformers
30 | ```
31 | 
32 | The search function
33 | ```python
34 | from unpackai.interp import CosineSearch
35 | from sentence_transformers import SentenceTransformer
36 | import pandas as pd
37 | import numpy as np
38 | 
39 | TAG = "raynardj/xlsearch-cross-lang-search-zh-vs-classicical-cn"
40 | encoder = SentenceTransformer(TAG)
41 | 
42 | # all_lines is a list of all your sentences
43 | # it can be a single book split into sentences, or the text of many, many books
44 | all_lines = ["句子1","句子2",...]
45 | vec = encoder.encode(all_lines, batch_size=32, show_progress_bar=True)
46 | 
47 | # cosine similarity searcher
48 | cosine = CosineSearch(vec)
49 | 
50 | def search(text):
51 |     enc = encoder.encode(text)  # encode the search query
52 |     order = cosine(enc)  # indices of all_lines, ranked by similarity
53 |     sentence_df = pd.DataFrame({"sentence": np.array(all_lines)[order[:5]]})
54 |     return sentence_df
55 | ```
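If you would rather not depend on `unpackai`, the same top-k ranking can be written in a few lines of plain numpy. The sketch below is not part of the project code: it assumes the `encoder`, `vec` and `all_lines` objects from the snippet above, and the helper names `cosine_rank` / `search_np` are made up for this example.

```python
import numpy as np

def cosine_rank(corpus_vecs: np.ndarray, query_vec: np.ndarray) -> np.ndarray:
    """Indices of corpus_vecs sorted from most to least similar to query_vec."""
    corpus = corpus_vecs / np.linalg.norm(corpus_vecs, axis=1, keepdims=True)
    query = query_vec / np.linalg.norm(query_vec)
    return np.argsort(-(corpus @ query))

def search_np(text: str, top_k: int = 5):
    # encode the query with the same sentence-transformer, then rank the corpus
    order = cosine_rank(vec, encoder.encode(text))
    return [all_lines[i] for i in order[:top_k]]
```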
56 | 
57 | After splitting the Shiji (史记, Records of the Grand Historian) into sentences, the search results look like this:
58 | 
59 | ```python
60 | >>> search("他是一个很慷慨的人")
61 | ```
62 | ```
63 | sentence
64 | 0 季布者,楚人也。为气任侠,有名於楚。
65 | 1 董仲舒为人廉直。
66 | 2 大将军为人仁善退让,以和柔自媚於上,然天下未有称也。
67 | 3 勃为人木彊敦厚,高帝以为可属大事。
68 | 4 石奢者,楚昭王相也。坚直廉正,无所阿避。
69 | ```
70 | 
71 | ```python
72 | >>> search("进入军营,必须缓缓牵着马骑")
73 | ```
74 | ```
75 | sentence
76 | 0 壁门士吏谓从属车骑曰:将军约,军中不得驱驰。
77 | 1 起之为将,与士卒最下者同衣食。卧不设席,行不骑乘,亲裹赢粮,与士卒分劳苦。
78 | 2 既出,沛公留车骑,独骑一马,与樊哙等四人步从,从间道山下归走霸上军,而使张良谢项羽。
79 | 3 顷之,上行出中渭桥,有一人从穚下走出,乘舆马惊。
80 | 4 元狩四年春,上令大将军青、骠骑将军去病将各五万骑,步兵转者踵军数十万,而敢力战深入之士皆属骠骑。
81 | ```
82 | 
83 | ## Translation
84 | ### Modern Chinese to classical Chinese translator
85 | * You can try out or download the model on the [🤗 model page](https://huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient).
86 | * It was trained on this [dataset of translation sentence pairs](https://github.com/BangBOOM/Classical-Chinese).
87 | * If you are interested, see the [training notebook](nbs/zh2cc_translate.ipynb).
88 | 
89 | In Python, the following code is recommended for inference:
90 | ```python
91 | import torch
92 | from transformers import EncoderDecoderModel, AutoTokenizer
93 | 
94 | 
95 | PRETRAINED = "raynardj/wenyanwen-chinese-translate-to-ancient"
96 | tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
97 | model = EncoderDecoderModel.from_pretrained(PRETRAINED)
98 | 
99 | def inference(text):
100 |     tk_kwargs = dict(
101 |         truncation=True,
102 |         max_length=128,
103 |         padding="max_length",
104 |         return_tensors='pt')
105 | 
106 |     inputs = tokenizer([text,], **tk_kwargs)
107 |     with torch.no_grad():
108 |         return tokenizer.batch_decode(
109 |             model.generate(
110 |                 inputs.input_ids,
111 |                 attention_mask=inputs.attention_mask,
112 |                 num_beams=3,
113 |                 max_length=128,
114 |                 bos_token_id=101,
115 |                 eos_token_id=tokenizer.sep_token_id,
116 |                 pad_token_id=tokenizer.pad_token_id,
117 |             ), skip_special_tokens=True)
118 | ```
119 | #### Examples from the current version
120 | Translations from the current version, following the pipeline above:
121 | ```python
122 | >>> inference('你连一百块都不肯给我')
123 | ['不 肯 与 我 百 钱 。']
124 | >>> inference("他不能做长远的谋划")
125 | ['不 能 为 远 谋 。']
126 | >>> inference("我们要干一番大事业")
127 | ['吾 属 当 举 大 事 。']
128 | >>> inference("这感觉,已经不对,我努力,在挽回")
129 | ['此 之 谓 也 , 已 不 可 矣 , 我 勉 之 , 以 回 之 。']
130 | >>> inference("轻轻地我走了, 正如我轻轻地来, 我挥一挥衣袖,不带走一片云彩")
131 | ['轻 我 行 , 如 我 轻 来 , 挥 袂 不 携 一 片 云 。']
132 | ```
133 | 
134 | There is still plenty of room for improvement:
135 | 
136 | * [ ] The maximum sequence length is currently 128 (adjustable via the tokenizer's max_length argument), which means longer modern Chinese sentences get truncated.
137 | * [ ] The pad-token labels could stop being masked to -100 in the loss, so the model would learn to end sentences on its own and the eos token id would no longer need to be passed at generation time.
138 | * [ ] The encoder is bert-base-chinese and the decoder a GPT-2, both pretrained on modern Chinese. Using a GPT-2 pretrained on classical Chinese and poetry as the decoder would almost certainly help; see the sketch below.
139 | * [ ] Compute was limited, so most hyperparameter choices have barely been explored.
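To make the decoder suggestion above concrete, here is a minimal sketch of how a different decoder checkpoint could be wired in when the encoder-decoder model is first assembled. It follows the `EncoderDecoderModel.from_encoder_decoder_pretrained` setup used in the training notebook (nbs/zh2cc_translate.ipynb); `uer/gpt2-chinese-poem` is the poetry GPT-2 already used elsewhere in this repo, and whether it actually improves this particular translator has not been verified here.

```python
from transformers import AutoTokenizer, EncoderDecoderModel

ENCODER = "bert-base-chinese"      # encoder pretrained on modern Chinese, as in the notebook
DECODER = "uer/gpt2-chinese-poem"  # decoder pretrained on classical Chinese poetry

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=ENCODER,
    decoder_pretrained_model_name_or_path=DECODER,
)
# the training notebook keeps BERT's tokenizer for both source and target,
# so the decoder is fine-tuned onto the BERT vocabulary
tokenizer = AutoTokenizer.from_pretrained(ENCODER)
```

From here the fine-tuning loop would be unchanged; only the decoder weights it starts from differ.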
140 | 
141 | ### Classical Chinese to modern Chinese translator
142 | > Feed it classical Chinese, either **punctuated** or **unpunctuated**, and the model predicts the modern Chinese phrasing.
143 | 
144 | * Visit the [🤗 model page of the classical Chinese (guwen) to modern Chinese translator](https://huggingface.co/raynardj/wenyanwen-ancient-translate-to-modern)
145 | * The training corpus is the same 900,000+ sentence pairs, [dataset link 📚](https://github.com/BangBOOM/Classical-Chinese). During training, all punctuation was stripped from the whole source (classical Chinese) sequence with 50% probability.
146 | * If you are interested, see the [training notebook](nbs/cc2zh_translate.ipynb); there is plenty of room for improvement there.
147 | 
148 | #### Recommended inference pipeline
149 | **Note**
150 | * You must set the ```eos_token_id``` of the ```generate``` function to 102 to get a complete translation; otherwise leftover text trails the result (a side effect of setting the pad labels to -100 in the loss).
151 | The compute button on the huggingface model page currently suffers from this problem, so the code below is recommended for getting translations.
152 | * Set the ```generate``` argument ```num_beams>=3``` to get reasonably good translations.
153 | * Set the ```generate``` argument ```max_length``` to 256, otherwise the end of the sentence gets cut off.
154 | ```python
155 | import torch
156 | from transformers import EncoderDecoderModel, AutoTokenizer
157 | 
158 | 
159 | PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
160 | tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
161 | model = EncoderDecoderModel.from_pretrained(PRETRAINED)
162 | def inference(text):
163 |     tk_kwargs = dict(
164 |         truncation=True,
165 |         max_length=128,
166 |         padding="max_length",
167 |         return_tensors='pt')
168 | 
169 |     inputs = tokenizer([text,], **tk_kwargs)
170 |     with torch.no_grad():
171 |         return tokenizer.batch_decode(
172 |             model.generate(
173 |                 inputs.input_ids,
174 |                 attention_mask=inputs.attention_mask,
175 |                 num_beams=3,
176 |                 max_length=256,
177 |                 bos_token_id=101,
178 |                 eos_token_id=tokenizer.sep_token_id,
179 |                 pad_token_id=tokenizer.pad_token_id,
180 |             ), skip_special_tokens=True)
181 | ```
182 | #### Examples from the current version
183 | > Of course, feeding it well-known passages usually exposes some laughable mistakes; if you run into amusing cases while teasing the model, feedback is very welcome.
184 | ```python
185 | >>> inference('非我族类其心必异')
186 | ['不 是 我 们 的 族 类 , 他 们 的 心 思 必 然 不 同 。']
187 | >>> inference('肉食者鄙未能远谋')
188 | ['吃 肉 的 人 鄙 陋 , 不 能 长 远 谋 划 。']
189 | # Several batches of my models failed to translate the character **输** here (one version even rendered 秦皇汉武 as 秦始皇 and 汉武帝); it is probably not a very archaic usage.
190 | >>> inference('江山如此多娇引无数英雄竞折腰惜秦皇汉武略输文采唐宗宋祖稍逊风骚')
191 | ['江 山 如 此 多 , 招 引 无 数 的 英 雄 , 竞 相 折 腰 , 可 惜 秦 皇 、 汉 武 , 略 微 有 文 采 , 唐 宗 、 宋 祖 稍 稍 逊 出 风 雅 。']
192 | >>> inference("清风徐来水波不兴")
193 | ['清 风 慢 慢 吹 来 , 水 波 不 兴 。']
194 | >>> inference("无他唯手熟尔")
195 | ['没 有 别 的 事 , 只 是 手 熟 罢 了 。']
196 | >>> inference("此诚危急存亡之秋也")
197 | ['这 实 在 是 危 急 存 亡 的 时 候 。']
198 | ```
199 | 
200 | ## Punctuation restoration
201 | > Feed it a string of unpunctuated classical Chinese and it restores the punctuation; more than twenty punctuation marks are currently supported.
202 | 
203 | * The trained model [can be downloaded here](https://huggingface.co/raynardj/classical-chinese-punctuation-guwen-biaodian)
204 | * It was trained on the [Daizhige v2.0 dataset (殆知阁)](https://github.com/garychowcmu/daizhigev20)
205 | 
206 | The recommended inference function is as follows
207 | 
208 | ```python
209 | from transformers import AutoTokenizer, BertForTokenClassification
210 | from transformers import pipeline
211 | 
212 | TAG = "raynardj/classical-chinese-punctuation-guwen-biaodian"
213 | 
214 | model = BertForTokenClassification.from_pretrained(TAG)
215 | tokenizer = AutoTokenizer.from_pretrained(TAG)
216 | ner = pipeline("ner", model=model, tokenizer=tokenizer)
217 | 
218 | def mark_sentence(x: str):
219 |     outputs = ner(x)
220 |     x_list = list(x)
221 |     for i, output in enumerate(outputs):
222 |         x_list.insert(output['end']+i, output['entity'])
223 |     return "".join(x_list)
224 | ```
225 | 
226 | Example
227 | ```python
228 | >>> mark_sentence("""郡邑置夫子庙于学以嵗时释奠盖自唐贞观以来未之或改我宋有天下因其制而损益之姑苏当浙右要区规模尤大更建炎戎马荡然无遗虽修学宫于荆榛瓦砾之余独殿宇未遑议也每春秋展礼于斋庐已则置不问殆为阙典今寳文阁直学士括苍梁公来牧之明年实绍兴十有一禩也二月上丁修祀既毕乃愓然自咎揖诸生而告之曰天子不以汝嘉为不肖俾再守兹土顾治民事神皆守之职惟是夫子之祀教化所基尤宜严且谨而拜跪荐祭之地卑陋乃尔其何以掲防妥灵汝嘉不敢避其责曩常去此弥年若有所负尚安得以罢輭自恕复累后人乎他日或克就绪愿与诸君落之于是谋之僚吏搜故府得遗材千枚取赢资以给其费鸠工庀役各举其任嵗月讫工民不与知像设礼器百用具修至于堂室廊序门牖垣墙皆一新之""")
229 | 
230 | '郡邑,置夫子庙于学,以嵗时释奠。盖自唐贞观以来,未之或改。我宋有天下因其制而损益之。姑苏当浙右要区,规模尤大,更建炎戎马,荡然无遗。虽修学宫于荆榛瓦砾之余,独殿宇未遑议也。每春秋展礼于斋庐,已则置不问,殆为阙典。今寳文阁直学士括苍梁公来牧之。明年,实绍兴十有一禩也。二月,上丁修祀既毕,乃愓然自咎,揖诸生而告之曰"天子不以汝嘉为不肖,俾再守兹土,顾治民事,神皆守之职。惟是夫子之祀,教化所基,尤宜严且谨。而拜跪荐祭之地,卑陋乃尔。其何以掲防妥灵?汝嘉不敢避其责。曩常去此弥年,若有所负,尚安得以罢輭自恕,复累后人乎!他日或克就绪,愿与诸君落之。于是谋之,僚吏搜故府,得遗材千枚,取赢资以给其费。鸠工庀役,各举其任。嵗月讫,工民不与知像,设礼器,百用具修。至于堂室。廊序。门牖。垣墙,皆一新之。'
231 | ```
232 | 
233 | ### Known imperfections
234 | * Sometimes, when two punctuation marks should appear back to back, one of them gets swallowed; e.g. ```:【``` may come out as just ```【```
235 | * Some characters carry an overly strong punctuation signal and exceptions are hard to learn; 也 is particularly overbearing, and "吾生也有涯,而知也无涯。以有涯随无涯,殆已" never gets punctuated correctly
236 | 
237 | ## Resources
238 | * [Project source code 🌟, stars and PRs welcome](https://github.com/raynardj/yuan)
239 | * [Cross-language search 🔎](https://huggingface.co/raynardj/xlsearch-cross-lang-search-zh-vs-classicical-cn)
240 | * [Modern Chinese to classical Chinese translation model ⛰](https://huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient)
241 | * [Classical Chinese to modern Chinese translation model, the input can be unpunctuated
🚀](https://huggingface.co/raynardj/wenyanwen-ancient-translate-to-modern) 242 | * [断句模型 🗡](https://huggingface.co/raynardj/classical-chinese-punctuation-guwen-biaodian) 243 | * [意境关键词 和 藏头写诗🤖](https://huggingface.co/raynardj/keywords-cangtou-chinese-poetry) 244 | 245 | 欢迎联系我github的邮箱讨论,或者提交issue,我会尽力帮助你。 -------------------------------------------------------------------------------- /nbs/zh2cc_translate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Translate model\n", 8 | "\n", 9 | "We are using [this nice dataset](https://github.com/BangBOOM/Classical-Chinese)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Imports" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from forgebox.imports import *\n", 26 | "from datasets import load_dataset\n", 27 | "# from fastai.text.all import *\n", 28 | "from unpackai.nlp import *\n", 29 | "from tqdm.notebook import tqdm\n", 30 | "import pytorch_lightning as pl" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "## Config" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "data=Path(\"/some_location/data\")" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "DATA = Path(data/\"nlp\"/\"zh\"/\"cc_vs_zh\")\n", 56 | "TO_CLASSICAL = True" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Data\n", 64 | "\n", 65 | "### Combine data" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "all_file = list(DATA.rglob(\"data/*\"))\n", 75 | "\n", 76 | "\n", 77 | "def open_file_to_lines(file):\n", 78 | " with open(file) as f:\n", 79 | " lines = f.read().splitlines()\n", 80 | " return lines\n", 81 | "\n", 82 | "def pairing_the_file(files,kw):\n", 83 | " pairs = []\n", 84 | " for file in files:\n", 85 | " if kw not in file.name:\n", 86 | " file1 = file\n", 87 | " file2 = f\"{file}{kw}\"\n", 88 | " pairs.append((file1,file2))\n", 89 | " return pairs\n", 90 | "\n", 91 | "pairs = pairing_the_file(all_file,\"翻译\")\n", 92 | "\n", 93 | "def open_pairs(pairs):\n", 94 | " chunks = []\n", 95 | " for pair in tqdm(pairs, leave=False):\n", 96 | " file1,file2 = pair\n", 97 | " lines1 = open_file_to_lines(file1)\n", 98 | " lines2 = open_file_to_lines(file2)\n", 99 | " chunks.append(pd.DataFrame({\"classical\":lines1,\"modern\":lines2}))\n", 100 | " return pd.concat(chunks).sample(frac=1.).reset_index(drop=True)\n", 101 | "\n", 102 | "data_df = open_pairs(pairs)\n", 103 | "\n", 104 | "df = data_df.rename(\n", 105 | " columns = dict(\n", 106 | " zip([\"modern\",\"classical\"],\n", 107 | " [\"source\",\"target\"] if TO_CLASSICAL else [\"target\",\"source\",]))\n", 108 | ")\n", 109 | "\n", 110 | "df.head()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "### Loading tokenizer" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "from transformers import (\n", 127 | " AutoTokenizer,\n", 128 | " 
AutoModelForCausalLM,\n", 129 | " AutoModel,\n", 130 | " EncoderDecoderModel\n", 131 | " )\n", 132 | "\n", 133 | "# we find a English parsing encoder, as a pretrained bert is good at understanding english\n", 134 | "# BERT is short for Bidirectional **Encoder** Representations from Transformers, which consists fully of encoder blocks\n", 135 | "ENCODER_PRETRAINED = \"bert-base-chinese\"\n", 136 | "# we find a Chinese writing model for decoder, as decoder is the part of the model that can write stuff\n", 137 | "DECODER_PRETRAINED = \"uer/gpt2-chinese-poem\"\n", 138 | "\n", 139 | "encoder_tokenizer = AutoTokenizer.from_pretrained(ENCODER_PRETRAINED)\n", 140 | "\n", 141 | "decoder_tokenizer = AutoTokenizer.from_pretrained(\n", 142 | " ENCODER_PRETRAINED # notice we use the BERT's tokenizer here\n", 143 | ")" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "### Pytoch Dataset" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "class Seq2Seq(Dataset):\n", 160 | " def __init__(self, df, tokenizer, target_tokenizer, max_len=128):\n", 161 | " super().__init__()\n", 162 | " self.df = df\n", 163 | " self.tokenizer = tokenizer\n", 164 | " self.target_tokenizer = target_tokenizer\n", 165 | " self.max_len = max_len\n", 166 | " \n", 167 | " def __len__(self, ):\n", 168 | " return len(self.df)\n", 169 | "\n", 170 | " def __getitem__(self, idx):\n", 171 | " return dict(self.df.iloc[idx])\n", 172 | "\n", 173 | " def collate(self, batch):\n", 174 | " batch_df = pd.DataFrame(list(batch))\n", 175 | " x, y = batch_df.source, batch_df.target\n", 176 | " x_batch = self.tokenizer(\n", 177 | " list(x),\n", 178 | " max_length=self.max_len,\n", 179 | " padding='max_length',\n", 180 | " truncation=True,\n", 181 | " return_tensors='pt',\n", 182 | " )\n", 183 | " y_batch = self.target_tokenizer(\n", 184 | " list(y),\n", 185 | " max_length=self.max_len,\n", 186 | " padding='max_length',\n", 187 | " truncation=True,\n", 188 | " return_tensors='pt',\n", 189 | " )\n", 190 | " x_batch['decoder_input_ids'] = y_batch['input_ids']\n", 191 | " x_batch['labels'] = y_batch['input_ids'].clone()\n", 192 | " x_batch['labels'][x_batch['labels'] == self.tokenizer.pad_token_id] = -100\n", 193 | " return x_batch\n", 194 | "\n", 195 | " def dataloader(self, batch_size, shuffle=True):\n", 196 | " return DataLoader(\n", 197 | " self,\n", 198 | " batch_size=batch_size,\n", 199 | " shuffle=shuffle,\n", 200 | " collate_fn=self.collate,\n", 201 | " )\n", 202 | "\n", 203 | " def split_train_valid(self, valid_size=0.1):\n", 204 | " split_index = int(len(self) * (1 - valid_size))\n", 205 | " cls = type(self)\n", 206 | " shuffled = self.df.sample(frac=1).reset_index(drop=True)\n", 207 | " train_set = cls(\n", 208 | " shuffled.iloc[:split_index],\n", 209 | " tokenizer=self.tokenizer,\n", 210 | " target_tokenizer=self.target_tokenizer,\n", 211 | " max_len=self.max_len,\n", 212 | " )\n", 213 | " valid_set = cls(\n", 214 | " shuffled.iloc[split_index:],\n", 215 | " tokenizer=self.tokenizer,\n", 216 | " target_tokenizer=self.target_tokenizer,\n", 217 | " max_len=self.max_len,\n", 218 | " )\n", 219 | " return train_set, valid_set" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "### PL datamodule" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | 
"class Seq2SeqData(pl.LightningDataModule):\n", 236 | " def __init__(self, df, tokenizer, target_tokenizer, batch_size=12, max_len=128):\n", 237 | " super().__init__()\n", 238 | " self.df = df\n", 239 | " self.ds = Seq2Seq(df, tokenizer, target_tokenizer,max_len=max_len)\n", 240 | " self.tokenizer = tokenizer\n", 241 | " self.target_tokenizer = target_tokenizer\n", 242 | " self.max_len = max_len\n", 243 | " self.batch_size = batch_size\n", 244 | "\n", 245 | " def setup(self, stage=None):\n", 246 | " self.train_set, self.valid_set = self.ds.split_train_valid()\n", 247 | "\n", 248 | " def train_dataloader(self):\n", 249 | " return self.train_set.dataloader(\n", 250 | " batch_size=self.batch_size, shuffle=True)\n", 251 | "\n", 252 | " def val_dataloader(self):\n", 253 | " return self.valid_set.dataloader(\n", 254 | " batch_size=self.batch_size*2, shuffle=False)\n", 255 | "\n", 256 | "data_module = Seq2SeqData(df, encoder_tokenizer, decoder_tokenizer, batch_size=64, )\n", 257 | "data_module.setup()" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "next(iter(data_module.train_dataloader()))" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "### Load pretrained models" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "# loading pretrained model\n", 283 | "encoder_decoder = EncoderDecoderModel.from_encoder_decoder_pretrained(\n", 284 | " encoder_pretrained_model_name_or_path=ENCODER_PRETRAINED,\n", 285 | " decoder_pretrained_model_name_or_path=DECODER_PRETRAINED,\n", 286 | ")" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "class Seq2SeqTrain(pl.LightningModule):\n", 296 | " def __init__(self, encoder_decoder):\n", 297 | " super().__init__()\n", 298 | " self.encoder_decoder = encoder_decoder\n", 299 | " \n", 300 | " def forward(self, batch):\n", 301 | " return self.encoder_decoder(\n", 302 | " **batch\n", 303 | " )\n", 304 | "\n", 305 | " def training_step(self, batch, batch_idx):\n", 306 | " outputs = self(batch)\n", 307 | " self.log('loss', outputs.loss)\n", 308 | " return outputs.loss\n", 309 | "\n", 310 | " def validation_step(self, batch, batch_idx):\n", 311 | " outputs = self(batch)\n", 312 | " self.log('val_loss', outputs.loss)\n", 313 | " return outputs.loss\n", 314 | " \n", 315 | " def configure_optimizers(self):\n", 316 | " encoder_params = list(\n", 317 | " {\"params\":param,\"lr\":1e-5}\n", 318 | " for param in self.encoder_decoder.encoder.embeddings.parameters()) +\\\n", 319 | " list({\"params\":param,\"lr\":1e-5}\n", 320 | " for param in self.encoder_decoder.encoder.encoder.parameters()) +\\\n", 321 | " list({\"params\":param,\"lr\":1e-3}\n", 322 | " for param in self.encoder_decoder.encoder.pooler.parameters())\n", 323 | "\n", 324 | " decoder_params = list()\n", 325 | " for name, param in self.encoder_decoder.decoder.named_parameters():\n", 326 | " if 'ln_cross_attn' in name:\n", 327 | " decoder_params.append({\"params\":param,\"lr\":1e-3})\n", 328 | " elif 'crossattention' in name:\n", 329 | " decoder_params.append({\"params\":param,\"lr\":1e-3})\n", 330 | " elif 'lm_head' in name:\n", 331 | " decoder_params.append({\"params\":param,\"lr\":1e-4})\n", 332 | " else:\n", 333 | " 
decoder_params.append({\"params\":param,\"lr\":1e-5})\n", 334 | "\n", 335 | " return torch.optim.Adam(\n", 336 | " encoder_params + decoder_params,\n", 337 | " lr=1e-3,\n", 338 | " )" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "module = Seq2SeqTrain(encoder_decoder)" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": {}, 353 | "source": [ 354 | "## Training" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "save = pl.callbacks.ModelCheckpoint(\n", 364 | " data/'../weights/cc_to_zh',\n", 365 | " save_top_k=2,\n", 366 | " verbose=True,\n", 367 | " monitor='val_loss',\n", 368 | " mode='min',\n", 369 | ")\n", 370 | "\n", 371 | "trainer = pl.Trainer(\n", 372 | " gpus=[0],\n", 373 | " max_epochs=10,\n", 374 | " callbacks=[save],\n", 375 | ")" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "trainer.fit(module, datamodule=data_module)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "## Inference" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "best = save.best\n", 401 | "module.load_state_dict(torch.load(best, map_location=\"cpu\")['state_dict'])\n", 402 | "\n", 403 | "\n", 404 | "encoder_decoder = encoder_decoder.cpu()\n", 405 | "encoder_decoder = encoder_decoder.eval()\n", 406 | "\n", 407 | "def inference(text, starter=''):\n", 408 | " tk_kwargs = dict(truncation=True, max_length=128, padding=\"max_length\",\n", 409 | " return_tensors='pt')\n", 410 | " inputs = encoder_tokenizer([text,],**tk_kwargs)\n", 411 | " with torch.no_grad():\n", 412 | " return decoder_tokenizer.batch_decode(\n", 413 | " encoder_decoder.generate(\n", 414 | " inputs.input_ids,\n", 415 | " attention_mask=inputs.attention_mask,\n", 416 | " num_beams=3,\n", 417 | " bos_token_id=101,\n", 418 | " ),\n", 419 | " skip_special_tokens=True)" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "inference('我来跟大家说一句话')" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": null, 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "inference(\"这个翻译不是很聪明,因为训练数据不够\")" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": {}, 444 | "outputs": [], 445 | "source": [ 446 | "encoder_decoder.push_to_hub(\"raynardj/wenyanwen-chinese-translate-to-ancient\")\n", 447 | "encoder_tokenizer.push_to_hub(\"raynardj/wenyanwen-chinese-translate-to-ancient\")" 448 | ] 449 | } 450 | ], 451 | "metadata": { 452 | "kernelspec": { 453 | "display_name": "Python 3 (ipykernel)", 454 | "language": "python", 455 | "name": "python3" 456 | }, 457 | "language_info": { 458 | "codemirror_mode": { 459 | "name": "ipython", 460 | "version": 3 461 | }, 462 | "file_extension": ".py", 463 | "mimetype": "text/x-python", 464 | "name": "python", 465 | "nbconvert_exporter": "python", 466 | "pygments_lexer": "ipython3", 467 | "version": "3.7.4" 468 | }, 469 | "toc": { 470 | "base_numbering": 1, 471 | "nav_menu": {}, 472 | "number_sections": true, 473 | "sideBar": true, 474 | "skip_h1_title": false, 475 | "title_cell": 
"Table of Contents", 476 | "title_sidebar": "Contents", 477 | "toc_cell": false, 478 | "toc_position": {}, 479 | "toc_section_display": true, 480 | "toc_window_display": false 481 | } 482 | }, 483 | "nbformat": 4, 484 | "nbformat_minor": 4 485 | } 486 | -------------------------------------------------------------------------------- /nbs/kw_leading_poe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Conditional text generation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Data downloaded [here](https://github.com/chinese-poetry/chinese-poetry)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Imports" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# Forgebox Imports\n", 31 | "from forgebox.imports import *\n", 32 | "from gc_utils.env import *\n", 33 | "import pytorch_lightning as pl\n", 34 | "from transformers import (\n", 35 | " AutoTokenizer,\n", 36 | " GPT2LMHeadModel\n", 37 | ")\n", 38 | "import random\n", 39 | "from typing import List\n", 40 | "import re\n", 41 | "from jieba import cut" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "def is_jupyter():\n", 51 | " try:\n", 52 | " get_ipython()\n", 53 | " return True\n", 54 | " except NameError:\n", 55 | " return False\n", 56 | " \n", 57 | "IS_JUPYTER = is_jupyter()\n", 58 | "if IS_JUPYTER:\n", 59 | " from tqdm.notebook import tqdm\n", 60 | "else:\n", 61 | " from tqdm import tqdm" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Locations" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 4, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "['cc_vs_zh', 'cctc', 'cn_shi', 'daizhigev20']" 80 | ] 81 | }, 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "DATA = sys_loc(\"DATA\")/\"nlp\"/\"zh\"\n", 89 | "DATA.ls()" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 5, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "POET = DATA/\"cn_shi\"\n", 99 | "ALL_JSON = list(POET.rglob(\"*.json\"))" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## Read and transform data" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 6, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "def read_json(path):\n", 116 | " return json.loads(Path(path).read_text())" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 7, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "application/vnd.jupyter.widget-view+json": { 127 | "model_id": "befb7e3f30e8467793de76ee58339a3e", 128 | "version_major": 2, 129 | "version_minor": 0 130 | }, 131 | "text/plain": [ 132 | " 0%| | 0/23 [00:00= num_head:\n", 205 | " return_text += text[i+1:]\n", 206 | " break\n", 207 | " if c in puncts:\n", 208 | " last_is_break = True\n", 209 | " else:\n", 210 | " last_is_break = False\n", 211 | " return heads, return_text" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 12, 217 | "metadata": {}, 218 | 
"outputs": [ 219 | { 220 | "data": { 221 | "text/plain": [ 222 | "('間曜翩過',\n", 223 | " '[CLS]維有常度,[CLS]靈無停輈。[CLS]翩葉辭柯,[CLS]眼綠已稠。弱榦不盈尺,忽已高岑樓。念昔過庭日,朋來悉良儔。我年未成童,子少無與侔。我質本駑駘,蹇步畏阻脩。子如渥洼駒,猛氣已食牛。當時二老人,笑語懽且酬。門戶各有托,寧計才與不。登門如昨日,星紀跡再周。二老安在哉,體魄歸山丘。隔屋聞讀書,玉樹鏘琳球。呼燈使來前,秀氣炯雙眸。問之垂九齡,屬對解冥搜。感此傷我心,淚下不可收。來者日已長,逝者挽不留。其間我與子,能閲幾春秋。寧復青衿佩,與子從親游。幸子齒猶壯,有母方白頭。刷翮凌青霄,足勝負米由。而我風樹悲,耿耿何時休。四十已無聞,過是夫何求。矧復病日侵,見面良可羞。竹實不療饑,芰製非寒裘。躬耕苦勤勞,代耕多悔尤。學仙竟誰成,百年等浮漚。俛仰天地間,身世真悠悠。時雨漲綠池,好風交平疇。嚶嚶出谷鳥,汎汎川上鷗。遇景適會心,曠望聊夷猶。')" 224 | ] 225 | }, 226 | "execution_count": 12, 227 | "metadata": {}, 228 | "output_type": "execute_result" 229 | } 230 | ], 231 | "source": [ 232 | "extract(para)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "## Get tokenizer" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 13, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-chinese\")" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 14, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "def replace_punctuation(text):\n", 258 | " return re.sub(r'[^\\w\\s]', ' ', text)\n", 259 | "\n", 260 | "def cutting(text):\n", 261 | " return list(i for i in cut(replace_punctuation(text), HMM=True,)if i != ' ')" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 15, 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "name": "stderr", 271 | "output_type": "stream", 272 | "text": [ 273 | "Building prefix dict from the default dictionary ...\n", 274 | "Loading model from cache /tmp/jieba.cache\n", 275 | "Loading model cost 0.665 seconds.\n", 276 | "Prefix dict has been built successfully.\n" 277 | ] 278 | }, 279 | { 280 | "data": { 281 | "text/plain": [ 282 | "['春眠', '不觉', '晓', '处处', '闻啼鸟']" 283 | ] 284 | }, 285 | "execution_count": 15, 286 | "metadata": {}, 287 | "output_type": "execute_result" 288 | } 289 | ], 290 | "source": [ 291 | "cutting(\"春眠不觉晓, 处处闻啼鸟\")" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 20, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "def pick_and_shuffle(li, min_n:int=0, max_n:int=None):\n", 301 | " if max_n is None:\n", 302 | " max_n = int(len(li)*.7)\n", 303 | " n = min_n + random.randint(0, min(max_n - min_n,10))\n", 304 | " random.shuffle(li)\n", 305 | " return list(set(li[:n]))" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 21, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "def create_kw(text):\n", 315 | " return pick_and_shuffle(cutting(text))" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 23, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "data": { 325 | "text/plain": [ 326 | "['晓', '春眠', '不觉']" 327 | ] 328 | }, 329 | "execution_count": 23, 330 | "metadata": {}, 331 | "output_type": "execute_result" 332 | } 333 | ], 334 | "source": [ 335 | "create_kw(\"春眠不觉晓, 处处闻啼鸟\")" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 24, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "heads, headless = extract(para)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 36, 350 | "metadata": {}, 351 | "outputs": [ 352 | { 353 | "data": { 354 | "text/plain": [ 355 | "('間曜翩過', ['星紀跡', '我', '何求', '駘', '身世', '風樹悲', '雙眸', '復', '托', '與子'])" 356 | ] 357 | }, 358 | "execution_count": 36, 359 | "metadata": {}, 360 | "output_type": "execute_result" 361 
| } 362 | ], 363 | "source": [ 364 | "heads, create_kw(headless.replace('[CLS]',\"\"))" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "## Dataset" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 48, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "class PoetDataset(Dataset):\n", 381 | " def __init__(\n", 382 | " self,\n", 383 | " df,\n", 384 | " tokenizer,\n", 385 | " p_head:float=.2,\n", 386 | " ):\n", 387 | " self.df = df.sample(frac=1).reset_index(drop=True)\n", 388 | " self.tokenizer = tokenizer\n", 389 | " self.p_head = p_head\n", 390 | " self.cn_num_dict = dict((i+1,f\"『{c}』\") for i, c in enumerate(\"一二三四\"))\n", 391 | "\n", 392 | " def __len__(self):\n", 393 | " return len(self.df)\n", 394 | " \n", 395 | " def __getitem__(self, idx):\n", 396 | " row = self.df.loc[idx]\n", 397 | " paragraphs = row.paragraphs\n", 398 | " heads, headless = extract(paragraphs)\n", 399 | " kws = '-'.join(create_kw(headless.replace('[CLS]',\"\")))\n", 400 | " return f\"{kws}《{heads}》{self.cn_num_dict.get(len(heads))}{headless}\"\n", 401 | " \n", 402 | " def collate_fn(self, batch):\n", 403 | " texts = list(batch)\n", 404 | " batch = self.tokenizer(\n", 405 | " list(texts),\n", 406 | " max_length=256,\n", 407 | " padding='max_length',\n", 408 | " return_tensors='pt',\n", 409 | " truncation=True\n", 410 | " )\n", 411 | " \n", 412 | " labels = batch['input_ids'].clone()\n", 413 | " labels[labels==0] = -100\n", 414 | " batch['labels'] = labels\n", 415 | " return batch\n", 416 | " \n", 417 | " def dataloader(self, batch_size=32, shuffle=True):\n", 418 | " return DataLoader(\n", 419 | " self,\n", 420 | " batch_size=batch_size,\n", 421 | " shuffle=shuffle,\n", 422 | " collate_fn=self.collate_fn\n", 423 | " )\n", 424 | "\n", 425 | " def split(self, val_ratio=.05):\n", 426 | " df = self.df.sample(frac=1).reset_index(drop=True)\n", 427 | " train_df = df[:int(len(df)*(1-val_ratio))]\n", 428 | " val_df = df[int(len(df)*(1-val_ratio)):]\n", 429 | " return PoetDataset(train_df, tokenizer=self.tokenizer),\\\n", 430 | " PoetDataset(val_df, tokenizer=self.tokenizer)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 49, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "poet_ds = PoetDataset(all_df, tokenizer)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "markdown", 444 | "metadata": {}, 445 | "source": [ 446 | "Let's arrange the text data this way, so the casual language modeling will work it's own magic" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 51, 452 | "metadata": {}, 453 | "outputs": [ 454 | { 455 | "data": { 456 | "text/plain": [ 457 | "'忍看-窈窕-孤寝-勾带-嫩-黄昏《粉度》『二』[CLS]堞云齐,[CLS]清笳、愁入暮烟林杪。素艳透春,玉骨凄凉,勾带月痕生早。江天苍莽黄昏後,依然是、粉寒香瘦。动追感、西园嫩约,夜深人悄。记得东风窈窕。曾夜踏横斜,醉携娇小。惆怅旧欢,回首俱非,忍看绿笺红豆。香销纸帐人孤寝,相思恨、花还知否。梦回处,霜飞翠楼已晓。'" 458 | ] 459 | }, 460 | "execution_count": 51, 461 | "metadata": {}, 462 | "output_type": "execute_result" 463 | } 464 | ], 465 | "source": [ 466 | "poet_ds[1000]" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 52, 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [ 475 | "dl = poet_ds.dataloader(12)" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 53, 481 | "metadata": {}, 482 | "outputs": [], 483 | "source": [ 484 | "batch = next(iter(dl))" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 54, 490 | "metadata": {}, 491 | "outputs": [], 492 | 
"source": [ 493 | "model = GPT2LMHeadModel.from_pretrained(\"uer/gpt2-chinese-poem\")" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 55, 499 | "metadata": {}, 500 | "outputs": [], 501 | "source": [ 502 | "class DataModule(pl.LightningDataModule):\n", 503 | " def __init__(self, dataset, batch_size=32):\n", 504 | " super().__init__()\n", 505 | " self.dataset = dataset\n", 506 | " self.batch_size = batch_size\n", 507 | " \n", 508 | " def setup(self, stage=None):\n", 509 | " self.train_dataset, self.val_dataset = self.dataset.split()\n", 510 | "\n", 511 | " def train_dataloader(self):\n", 512 | " return self.train_dataset.dataloader(\n", 513 | " batch_size = self.batch_size,\n", 514 | " shuffle=True)\n", 515 | "\n", 516 | " def val_dataloader(self):\n", 517 | " return self.val_dataset.dataloader(\n", 518 | " batch_size = self.batch_size*2,\n", 519 | " shuffle=False)" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 56, 525 | "metadata": {}, 526 | "outputs": [], 527 | "source": [ 528 | "class CausalLMModule(pl.LightningModule):\n", 529 | " def __init__(self, model):\n", 530 | " super().__init__()\n", 531 | " self.model = model\n", 532 | "\n", 533 | " def forward(self, **batch):\n", 534 | " return self.model(**batch)\n", 535 | "\n", 536 | " def training_step(self, batch, batch_idx):\n", 537 | " outputs = self(\n", 538 | " input_ids=batch[\"input_ids\"],\n", 539 | " attention_mask=batch[\"attention_mask\"],\n", 540 | " labels=batch.labels,\n", 541 | " )\n", 542 | " loss = outputs.loss\n", 543 | " self.log(\"loss\", loss)\n", 544 | " return loss\n", 545 | "\n", 546 | " def validation_step(self, batch, batch_idx):\n", 547 | " outputs = self(\n", 548 | " input_ids=batch[\"input_ids\"],\n", 549 | " attention_mask=batch[\"attention_mask\"],\n", 550 | " labels=batch.labels,\n", 551 | " )\n", 552 | " loss = outputs.loss\n", 553 | " self.log(\"val_loss\", loss)\n", 554 | " return loss\n", 555 | "\n", 556 | " def configure_optimizers(self):\n", 557 | " return torch.optim.Adam(self.parameters(), lr=1e-5)" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 57, 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [ 566 | "data_module = DataModule(poet_ds, batch_size=54)" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 58, 572 | "metadata": {}, 573 | "outputs": [], 574 | "source": [ 575 | "module = CausalLMModule(model)" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 59, 581 | "metadata": {}, 582 | "outputs": [ 583 | { 584 | "name": "stderr", 585 | "output_type": "stream", 586 | "text": [ 587 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:360: UserWarning: Checkpoint directory /GCI/transformers/weights/kw_leading_po exists and is not empty.\n", 588 | " rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n", 589 | "GPU available: True, used: True\n", 590 | "TPU available: False, using: 0 TPU cores\n" 591 | ] 592 | } 593 | ], 594 | "source": [ 595 | "save = pl.callbacks.ModelCheckpoint(\n", 596 | " '/GCI/transformers/weights/kw_leading_po',\n", 597 | " save_top_k=2,\n", 598 | " verbose=True,\n", 599 | " monitor='val_loss',\n", 600 | " mode='min',\n", 601 | ")\n", 602 | "\n", 603 | "trainer = pl.Trainer(\n", 604 | " gpus=[1],\n", 605 | " max_epochs=6,\n", 606 | " callbacks=[save],\n", 607 | ")" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": 161, 613 | "metadata": {}, 
614 | "outputs": [ 615 | { 616 | "name": "stderr", 617 | "output_type": "stream", 618 | "text": [ 619 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", 620 | "\n", 621 | " | Name | Type | Params\n", 622 | "------------------------------------------\n", 623 | "0 | model | GPT2LMHeadModel | 103 M \n", 624 | "------------------------------------------\n", 625 | "103 M Trainable params\n", 626 | "0 Non-trainable params\n", 627 | "103 M Total params\n", 628 | "412.665 Total estimated model params size (MB)\n" 629 | ] 630 | }, 631 | { 632 | "data": { 633 | "application/vnd.jupyter.widget-view+json": { 634 | "model_id": "", 635 | "version_major": 2, 636 | "version_minor": 0 637 | }, 638 | "text/plain": [ 639 | "Validation sanity check: 0it [00:00, ?it/s]" 640 | ] 641 | }, 642 | "metadata": {}, 643 | "output_type": "display_data" 644 | }, 645 | { 646 | "name": "stderr", 647 | "output_type": "stream", 648 | "text": [ 649 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:103: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 650 | " f'The dataloader, {name}, does not have many workers which may be a bottleneck.'\n", 651 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:103: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 652 | " f'The dataloader, {name}, does not have many workers which may be a bottleneck.'\n" 653 | ] 654 | }, 655 | { 656 | "data": { 657 | "application/vnd.jupyter.widget-view+json": { 658 | "model_id": "73bb9a74f1ad407ca77e56df07f35b46", 659 | "version_major": 2, 660 | "version_minor": 0 661 | }, 662 | "text/plain": [ 663 | "Training: 0it [00:00, ?it/s]" 664 | ] 665 | }, 666 | "metadata": {}, 667 | "output_type": "display_data" 668 | }, 669 | { 670 | "name": "stderr", 671 | "output_type": "stream", 672 | "text": [ 673 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py:897: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n", 674 | " rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')\n" 675 | ] 676 | } 677 | ], 678 | "source": [ 679 | "trainer.fit(module, datamodule=data_module)" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": null, 685 | "metadata": {}, 686 | "outputs": [], 687 | "source": [ 688 | "module.load_state_dict(\n", 689 | " torch.load(str(save.best), map_location=\"cpu\")['state_dict'])" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 26, 695 | "metadata": {}, 696 | "outputs": [], 697 | "source": [ 698 | "model = module.model\n", 699 | "model = model.cpu()\n", 700 | "model = model.eval()" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": 27, 706 | "metadata": {}, 707 | "outputs": [], 708 | "source": [ 709 | "model.save_pretrained(hub/\"kw-lead-po\")" 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": null, 715 | "metadata": {}, 716 | "outputs": [], 717 | "source": [ 718 | "model.push_to_hub(\"raynardj/keywords-cangtou-chinese-poetry\")" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 
723 | "execution_count": 28, 724 | "metadata": {}, 725 | "outputs": [], 726 | "source": [ 727 | "def inference(lead):\n", 728 | " leading = f\"《{lead}》\"\n", 729 | " input_ids = tokenizer(leading, return_tensors='pt', ).input_ids\n", 730 | " with torch.no_grad():\n", 731 | " pred = model.generate(\n", 732 | " input_ids,\n", 733 | " max_length=256,\n", 734 | " num_beams=3,\n", 735 | "# do_sample=True,\n", 736 | "# top_p=.6,\n", 737 | " bos_token_id=tokenizer.sep_token_id,\n", 738 | " pad_token_id=tokenizer.pad_token_id,\n", 739 | " eos_token_id=tokenizer.sep_token_id,\n", 740 | " )\n", 741 | " print(pred)\n", 742 | " return tokenizer.batch_decode(pred, skip_special_tokens=True)" 743 | ] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "execution_count": null, 748 | "metadata": {}, 749 | "outputs": [], 750 | "source": [] 751 | } 752 | ], 753 | "metadata": { 754 | "kernelspec": { 755 | "display_name": "Python 3", 756 | "language": "python", 757 | "name": "python3" 758 | }, 759 | "language_info": { 760 | "codemirror_mode": { 761 | "name": "ipython", 762 | "version": 3 763 | }, 764 | "file_extension": ".py", 765 | "mimetype": "text/x-python", 766 | "name": "python", 767 | "nbconvert_exporter": "python", 768 | "pygments_lexer": "ipython3", 769 | "version": "3.7.6" 770 | }, 771 | "toc": { 772 | "base_numbering": 1, 773 | "nav_menu": {}, 774 | "number_sections": true, 775 | "sideBar": true, 776 | "skip_h1_title": false, 777 | "title_cell": "Table of Contents", 778 | "title_sidebar": "Contents", 779 | "toc_cell": false, 780 | "toc_position": {}, 781 | "toc_section_display": true, 782 | "toc_window_display": false 783 | } 784 | }, 785 | "nbformat": 4, 786 | "nbformat_minor": 4 787 | } 788 | -------------------------------------------------------------------------------- /nbs/punktuation_ner.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Punctuation NER" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# Forgebox Imports\n", 17 | "from forgebox.imports import *\n", 18 | "from forgebox.category import Category\n", 19 | "import pytorch_lightning as pl\n", 20 | "from transformers import AutoTokenizer, BertForTokenClassification\n", 21 | "from transformers import pipeline\n", 22 | "from typing import List\n", 23 | "import re" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from gc_utils.env import sys_loc\n", 33 | "DATA = sys_loc('DATA')/\"nlp\"/\"zh\"/\"daizhigev20\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Read Metadata" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "META = pd.read_csv(DATA/\"meta.csv\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 4, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "LABELS = META.query(\"charspan<15\").sample(frac=1.).reset_index(drop=True)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 6, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "punkt_regex = r'[^\\w\\s]'\n", 68 | "\n", 69 | "def position_of_all_punctuation(x):\n", 70 | " return [m.start() for m in re.finditer(punkt_regex, x)]\n", 71 | "\n", 72 | "# simplify the 
punctuation\n", 73 | "eng_punkt_to_cn_dict = {\n", 74 | " \".\": \"。\",\n", 75 | " \",\": \",\",\n", 76 | " \":\": \":\",\n", 77 | " \";\": \";\",\n", 78 | " \"?\": \"?\",\n", 79 | " \"!\": \"!\",\n", 80 | " \"“\": \"\\\"\",\n", 81 | " \"”\": \"\\\"\",\n", 82 | " \"‘\": \"\\'\",\n", 83 | " \"’\": \"\\'\",\n", 84 | " \"「\": \"(\",\n", 85 | " \"」\": \")\",\n", 86 | " \"『\": \"\\\"\",\n", 87 | " \"』\": \"\\\"\",\n", 88 | " \"(\": \"(\",\n", 89 | " \")\": \")\",\n", 90 | " \"《\": \"【\",\n", 91 | " \"》\": \"】\",\n", 92 | " \"[\": \"【\",\n", 93 | " \"]\": \"】\",\n", 94 | " }\n", 95 | "\n", 96 | "def translate_eng_punkt_to_cn(char):\n", 97 | " if char == \"O\":\n", 98 | " return char\n", 99 | " if char in eng_punkt_to_cn_dict.values():\n", 100 | " return char\n", 101 | " result = eng_punkt_to_cn_dict.get(char)\n", 102 | " if result is None:\n", 103 | " return \"。\"\n", 104 | " return result\n", 105 | "\n", 106 | "def punct_ner_pair(sentence):\n", 107 | " positions = position_of_all_punctuation(sentence)\n", 108 | " x = re.sub(punkt_regex, '', sentence)\n", 109 | " y = list(\"O\"*len(x))\n", 110 | " \n", 111 | " for i, p in enumerate(positions):\n", 112 | " y[p-i-1] = sentence[p]\n", 113 | " p_df = pd.DataFrame({\"x\":list(x), \"y\":y})\n", 114 | " p_df[\"y\"] = p_df[\"y\"].apply(translate_eng_punkt_to_cn)\n", 115 | " return p_df" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 7, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "ALL_LABELS = [\"O\",]+list(eng_punkt_to_cn_dict.values())" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 9, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "cates = Category(ALL_LABELS)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 10, 139 | "metadata": { 140 | "code_folding": [ 141 | 0 142 | ] 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "class PunctDataset(Dataset):\n", 147 | " def __init__(\n", 148 | " self,\n", 149 | " data_dir: Path,\n", 150 | " filelist: List[str],\n", 151 | " num_threads: int = 8,\n", 152 | " length: int = 1000,\n", 153 | " size: int = 540\n", 154 | " ):\n", 155 | " \"\"\"\n", 156 | " Args:\n", 157 | " - filelist: list of file names\n", 158 | " - The dataset will open ```num_threads``` files, and hold\n", 159 | " in memory simoultaneously.\n", 160 | " - num_threads: number of threads to read files,\n", 161 | " - length: number of sentences per batch\n", 162 | " - size: number of characters per sentence\n", 163 | " \"\"\"\n", 164 | " self.data_dir = Path(data_dir)\n", 165 | " self.filelist = filelist\n", 166 | " self.num_threads = num_threads\n", 167 | " self.length = length\n", 168 | " # open file strings, index is mod of num_threads\n", 169 | " self.current_files = dict(enumerate([\"\"]*length))\n", 170 | " self.string_index = dict(enumerate([0]*length))\n", 171 | " self.to_open_idx = 0\n", 172 | " self.size = size\n", 173 | " self.get_counter = 0\n", 174 | " self.return_string = False\n", 175 | "\n", 176 | " def __len__(self):\n", 177 | " return self.length\n", 178 | "\n", 179 | " def __repr__(self):\n", 180 | " return f\"PunctDataset: {len(self)}, on {len(self.filelist)} files\"\n", 181 | "\n", 182 | " def new_file(self, idx_mod):\n", 183 | " filename = self.filelist[self.to_open_idx]\n", 184 | " with open(self.data_dir/filename, \"r\", encoding=\"utf-8\") as f:\n", 185 | " self.current_files[idx_mod] = f.read()\n", 186 | "\n", 187 | " self.to_open_idx += 1\n", 188 | "\n", 189 | " # reset to open article file 
index\n", 190 | " if self.to_open_idx >= len(self.filelist):\n", 191 | " self.to_open_idx = 0\n", 192 | "\n", 193 | " # reset string_index within new article file\n", 194 | " self.string_index[idx_mod] = 0\n", 195 | "\n", 196 | " if self.to_open_idx % 500 == 0:\n", 197 | " print(f\"went through files:\\t{self.to_open_idx}\")\n", 198 | "\n", 199 | " def __getitem__(self, idx):\n", 200 | " idx_mod = self.get_counter % self.num_threads\n", 201 | "\n", 202 | " if self.string_index[idx_mod] >= len(self.current_files[idx_mod]):\n", 203 | " self.new_file(idx_mod)\n", 204 | " string_idx = self.string_index[idx_mod]\n", 205 | "\n", 206 | " # slicing a sentence\n", 207 | " sentence = self.current_files[idx_mod][string_idx:string_idx+self.size]\n", 208 | "\n", 209 | " # move the string_index within current article file\n", 210 | " self.string_index[idx_mod] += self.size\n", 211 | "\n", 212 | " # move the get_counter\n", 213 | " self.get_counter += 1\n", 214 | " p_df = punct_ner_pair(sentence)\n", 215 | " return list(p_df.x), list(p_df.y)\n", 216 | "\n", 217 | " def align_offsets(\n", 218 | " self,\n", 219 | " inputs,\n", 220 | " text_labels: List[List[str]],\n", 221 | " words: List[List[str]]\n", 222 | " ):\n", 223 | " \"\"\"\n", 224 | " inputs: output if tokenizer\n", 225 | " text_labels: labels in form of list of list of strings\n", 226 | " words: words in form of list of list of strings\n", 227 | " \"\"\"\n", 228 | " labels = torch.zeros_like(inputs.input_ids).long()\n", 229 | " labels -= 100\n", 230 | " text_lables_array = np.empty(labels.shape, dtype=object)\n", 231 | " words_array = np.empty(labels.shape, dtype=object)\n", 232 | " max_len = inputs.input_ids.shape[1]\n", 233 | "\n", 234 | " for row_id, input_ids in enumerate(inputs.input_ids):\n", 235 | " word_pos = inputs.word_ids(row_id)\n", 236 | " for idx, pos in enumerate(word_pos):\n", 237 | " if pos is None:\n", 238 | " continue\n", 239 | " if pos <= max_len:\n", 240 | " labels[row_id, idx] = self.cates.c2i[text_labels[row_id][pos]]\n", 241 | " if self.return_string:\n", 242 | " text_lables_array[row_id,\n", 243 | " idx] = text_labels[row_id][pos]\n", 244 | " words_array[row_id, idx] = words[row_id][pos]\n", 245 | "\n", 246 | " inputs['labels'] = labels\n", 247 | " if self.return_string:\n", 248 | " inputs['text_labels'] = text_lables_array.tolist()\n", 249 | " inputs['word'] = words_array.tolist()\n", 250 | " return inputs\n", 251 | "\n", 252 | " def collate_fn(self, data):\n", 253 | " \"\"\"\n", 254 | " data: list of tuple\n", 255 | " \"\"\"\n", 256 | " words, text_labels = zip(*data)\n", 257 | "\n", 258 | " inputs = self.tokenizer(\n", 259 | " list(words),\n", 260 | " return_tensors='pt',\n", 261 | " padding=True,\n", 262 | " truncation=True,\n", 263 | " max_length=self.max_len,\n", 264 | " is_split_into_words=True,\n", 265 | " return_offsets_mapping=True,\n", 266 | " add_special_tokens=False,\n", 267 | " )\n", 268 | " return self.align_offsets(inputs, text_labels, words)\n", 269 | "\n", 270 | " def dataloaders(self, tokenizer, cates, max_len: int = 512, batch_size: int = 32):\n", 271 | " self.tokenizer = tokenizer\n", 272 | " self.cates = cates\n", 273 | " self.max_len = max_len\n", 274 | " return DataLoader(\n", 275 | " self,\n", 276 | " batch_size=batch_size,\n", 277 | " shuffle=False,\n", 278 | " collate_fn=self.collate_fn\n", 279 | " )\n", 280 | "\n", 281 | " def split(self, ratio: float = 0.9):\n", 282 | " \"\"\"\n", 283 | " Split the dataset into train and valid\n", 284 | " \"\"\"\n", 285 | " np.random.shuffle(self.filelist)\n", 
286 | " split_idx = int(len(self.filelist)*ratio)\n", 287 | " train_dataset = PunctDataset(\n", 288 | " self.data_dir,\n", 289 | " self.filelist[:split_idx],\n", 290 | " num_threads=self.num_threads,\n", 291 | " length=int(self.length*ratio),\n", 292 | " size=self.size,\n", 293 | " )\n", 294 | " valid_dataset = PunctDataset(\n", 295 | " self.data_dir,\n", 296 | " self.filelist[split_idx:],\n", 297 | " num_threads=self.num_threads,\n", 298 | " length=int(self.length*(1-ratio)),\n", 299 | " size=self.size,\n", 300 | " )\n", 301 | " return train_dataset, valid_dataset" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "Create dataset object\n", 309 | "\n", 310 | "* Length is the length of the epoch\n", 311 | "* Size: is the sequence length\n", 312 | "* num_threads: num of files that is opening at the same time" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 11, 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [ 321 | "ds = PunctDataset(DATA, list(LABELS.filepath), num_threads=8, length=10000, size=512)\n", 322 | "train_ds, valid_ds = ds.split(0.9)" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "### lightning data module" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 12, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "class PunctDataModule(pl.LightningDataModule):\n", 339 | " def __init__(self, train_ds, valid_ds, tokenizer, cates, \n", 340 | " max_len=512, batch_size=32):\n", 341 | " super().__init__()\n", 342 | " self.train_ds, self.valid_ds = train_ds, valid_ds\n", 343 | " self.tokenizer = tokenizer\n", 344 | " self.cates = cates\n", 345 | " self.max_len = max_len\n", 346 | " self.batch_size = batch_size\n", 347 | "\n", 348 | " def split_data(self):\n", 349 | " \n", 350 | " return train_ds, valid_ds\n", 351 | " \n", 352 | " def train_dataloader(self):\n", 353 | " return self.train_ds.dataloaders(\n", 354 | " self.tokenizer,\n", 355 | " self.cates,\n", 356 | " self.max_len,\n", 357 | " self.batch_size,\n", 358 | " )\n", 359 | " \n", 360 | " def val_dataloader(self):\n", 361 | " return self.valid_ds.dataloaders(\n", 362 | " self.tokenizer,\n", 363 | " self.cates,\n", 364 | " self.max_len,\n", 365 | " self.batch_size*4)" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "metadata": {}, 371 | "source": [ 372 | "## Load Pretrained" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 14, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-chinese\")" 382 | ] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "metadata": {}, 387 | "source": [ 388 | "Load pretrained model with proper num of categories" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 15, 394 | "metadata": {}, 395 | "outputs": [ 396 | { 397 | "name": "stderr", 398 | "output_type": "stream", 399 | "text": [ 400 | "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n", 401 | "- This IS expected if you are initializing BertForTokenClassification from the 
checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 402 | "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 403 | "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']\n", 404 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 405 | ] 406 | } 407 | ], 408 | "source": [ 409 | "model = BertForTokenClassification.from_pretrained(\"bert-base-chinese\", num_labels=len(cates),)" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 16, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "data_module = PunctDataModule(train_ds, valid_ds, tokenizer, cates,\n", 419 | " batch_size=32,)" 420 | ] 421 | }, 422 | { 423 | "cell_type": "markdown", 424 | "metadata": {}, 425 | "source": [ 426 | "### Run data pipeline" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 17, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "inputs = next(iter(data_module.val_dataloader()))" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 18, 441 | "metadata": {}, 442 | "outputs": [ 443 | { 444 | "data": { 445 | "text/plain": [ 446 | "torch.Size([128, 464])" 447 | ] 448 | }, 449 | "execution_count": 18, 450 | "metadata": {}, 451 | "output_type": "execute_result" 452 | } 453 | ], 454 | "source": [ 455 | "inputs.input_ids.shape" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 19, 461 | "metadata": {}, 462 | "outputs": [ 463 | { 464 | "data": { 465 | "text/plain": [ 466 | "torch.Size([128, 464])" 467 | ] 468 | }, 469 | "execution_count": 19, 470 | "metadata": {}, 471 | "output_type": "execute_result" 472 | } 473 | ], 474 | "source": [ 475 | "inputs.labels.shape" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 20, 481 | "metadata": {}, 482 | "outputs": [], 483 | "source": [ 484 | "# @interact\n", 485 | "# def view_label(idx=range(0,31)):\n", 486 | "# for x,y in zip(inputs['word'][idx], inputs['text_labels'][idx]):\n", 487 | "# print(f\"{x}-{y}\", end=\"\\t\")" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "## NER tranining module" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 21, 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "from forgebox.thunder.callbacks import DataFrameMetricsCallback\n", 504 | "from forgebox.hf.train import NERModule" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 22, 510 | "metadata": {}, 511 | "outputs": [], 512 | "source": [ 513 | "module = NERModule(model)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": 23, 519 | "metadata": {}, 520 | "outputs": [ 521 | { 522 | "name": "stderr", 523 | "output_type": "stream", 524 | "text": [ 525 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:360: UserWarning: Checkpoint directory /GCI/transformers/weights/punkt_ner/ exists and is not empty.\n", 526 | " rank_zero_warn(f\"Checkpoint directory {dirpath} 
exists and is not empty.\")\n" 527 | ] 528 | } 529 | ], 530 | "source": [ 531 | "save_callback = pl.callbacks.ModelCheckpoint(\n", 532 | " dirpath=\"/GCI/transformers/weights/punkt_ner/\",\n", 533 | " save_top_k=2,\n", 534 | " verbose=True,\n", 535 | " monitor='val_loss',\n", 536 | " mode='min',\n", 537 | ")\n", 538 | "df_show = DataFrameMetricsCallback()" 539 | ] 540 | }, 541 | { 542 | "cell_type": "markdown", 543 | "metadata": {}, 544 | "source": [ 545 | "Reset the configure_optimizers function" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 24, 551 | "metadata": {}, 552 | "outputs": [], 553 | "source": [ 554 | "def configure_optimizers(self):\n", 555 | " # discriminative learning rate\n", 556 | " param_groups = [\n", 557 | " {'params': self.model.bert.parameters(), 'lr': 5e-6},\n", 558 | " {'params': self.model.classifier.parameters(), 'lr': 1e-3},\n", 559 | " ]\n", 560 | " optimizer = torch.optim.Adam(param_groups, lr=1e-3)\n", 561 | " return optimizer\n", 562 | "\n", 563 | "NERModule.configure_optimizers = configure_optimizers" 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "metadata": {}, 569 | "source": [ 570 | "Trainer" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 25, 576 | "metadata": {}, 577 | "outputs": [ 578 | { 579 | "name": "stderr", 580 | "output_type": "stream", 581 | "text": [ 582 | "GPU available: True, used: True\n", 583 | "TPU available: False, using: 0 TPU cores\n" 584 | ] 585 | } 586 | ], 587 | "source": [ 588 | "trainer = pl.Trainer(\n", 589 | " gpus=[0],\n", 590 | " max_epochs=100,\n", 591 | " callbacks=[df_show, save_callback],\n", 592 | " )" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": null, 598 | "metadata": {}, 599 | "outputs": [], 600 | "source": [ 601 | "trainer.fit(module, datamodule=data_module)" 602 | ] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "metadata": {}, 607 | "source": [ 608 | "## Load the best model" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": 29, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "module = module.load_from_checkpoint(save_callback.best_model_path, model=model)" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 28, 623 | "metadata": {}, 624 | "outputs": [], 625 | "source": [ 626 | "module.model.config.id2label = dict(enumerate(cates.i2c))\n", 627 | "module.model.config.label2id = cates.c2i.dict" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 35, 633 | "metadata": {}, 634 | "outputs": [], 635 | "source": [ 636 | "from transformers import pipeline" 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "execution_count": 40, 642 | "metadata": {}, 643 | "outputs": [], 644 | "source": [ 645 | "module.model = module.model.eval()\n", 646 | "module.model = module.model.cpu()" 647 | ] 648 | }, 649 | { 650 | "cell_type": "markdown", 651 | "metadata": {}, 652 | "source": [ 653 | "## Push to model hub" 654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": 32, 659 | "metadata": {}, 660 | "outputs": [], 661 | "source": [ 662 | "TAG = \"raynardj/classical-chinese-punctuation-guwen-biaodian\"" 663 | ] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "execution_count": 33, 668 | "metadata": {}, 669 | "outputs": [ 670 | { 671 | "data": { 672 | "application/vnd.jupyter.widget-view+json": { 673 | "model_id": "d70229344f854882bb4e83c42420b9ff", 674 | "version_major": 2, 675 | "version_minor": 0 676 | }, 
677 | "text/plain": [ 678 | "Upload file pytorch_model.bin: 0%| | 32.0k/388M [00:00 main\n", 690 | "\n" 691 | ] 692 | }, 693 | { 694 | "data": { 695 | "text/plain": [ 696 | "'https://huggingface.co/raynardj/classical-chinese-punctuation-guwen-biaodian/commit/163772b14564fa2930b1460f48be30fa7c9f8438'" 697 | ] 698 | }, 699 | "execution_count": 33, 700 | "metadata": {}, 701 | "output_type": "execute_result" 702 | } 703 | ], 704 | "source": [ 705 | "module.model.push_to_hub(TAG)" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": 34, 711 | "metadata": {}, 712 | "outputs": [ 713 | { 714 | "name": "stderr", 715 | "output_type": "stream", 716 | "text": [ 717 | "To https://user:eOwfuFZJHbcMgbzVtVPDaSGtpbpjumsgTzZtfKlrMbSECzypnCYHZGDhHVsHRsYZzvdrkcxbnnSXRROfqdNRYfMvVfaVSOTxORkEUcMnAPEWXhkWpVEDrgfUZJdmleTx@huggingface.co/raynardj/classical-chinese-punctuation-guwen-biaodian\n", 718 | " 163772b..c83256b main -> main\n", 719 | "\n" 720 | ] 721 | }, 722 | { 723 | "data": { 724 | "text/plain": [ 725 | "'https://huggingface.co/raynardj/classical-chinese-punctuation-guwen-biaodian/commit/c83256b9ba08883a91c78512cce496b3cebe27a5'" 726 | ] 727 | }, 728 | "execution_count": 34, 729 | "metadata": {}, 730 | "output_type": "execute_result" 731 | } 732 | ], 733 | "source": [ 734 | "tokenizer.push_to_hub(TAG)" 735 | ] 736 | }, 737 | { 738 | "cell_type": "code", 739 | "execution_count": 36, 740 | "metadata": {}, 741 | "outputs": [], 742 | "source": [ 743 | "ner = pipeline(\"ner\",module.model,tokenizer=tokenizer)" 744 | ] 745 | }, 746 | { 747 | "cell_type": "code", 748 | "execution_count": 37, 749 | "metadata": {}, 750 | "outputs": [], 751 | "source": [ 752 | "def mark_sentence(x: str):\n", 753 | " outputs = ner(x)\n", 754 | " x_list = list(x)\n", 755 | " for i, output in enumerate(outputs):\n", 756 | " x_list.insert(output['end']+i, output['entity'])\n", 757 | " return \"\".join(x_list)" 758 | ] 759 | }, 760 | { 761 | "cell_type": "code", 762 | "execution_count": 42, 763 | "metadata": {}, 764 | "outputs": [ 765 | { 766 | "data": { 767 | "text/plain": [ 768 | "'是书虽称文粹,实与地志相表里。东南文献多借。是以有征与范成大呉郡志相辅而行,亦如骖有靳矣。乾隆四十二年三月,恭校上。'" 769 | ] 770 | }, 771 | "execution_count": 42, 772 | "metadata": {}, 773 | "output_type": "execute_result" 774 | } 775 | ], 776 | "source": [ 777 | "mark_sentence(\"\"\"是书虽称文粹实与地志相表里东南文献多借是以有征与范成大呉郡志相辅而行亦如骖有靳矣乾隆四十二年三月恭校上\"\"\")" 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "execution_count": 47, 783 | "metadata": {}, 784 | "outputs": [ 785 | { 786 | "data": { 787 | "text/plain": [ 788 | "'郡邑,置夫子庙于学,以嵗时释奠。盖自唐贞观以来,未之或改。我宋有天下因其制而损益之。姑苏当浙右要区,规模尤大,更建炎戎马,荡然无遗。虽修学宫于荆榛瓦砾之余,独殿宇未遑议也。每春秋展礼于斋庐,已则置不问,殆为阙典。今寳文阁直学士括苍梁公来牧之。明年,实绍兴十有一禩也。二月,上丁修祀既毕,乃愓然自咎,揖诸生而告之曰\"天子不以汝嘉为不肖,俾再守兹土,顾治民事,神皆守之职。惟是夫子之祀,教化所基,尤宜严且谨。而拜跪荐祭之地,卑陋乃尔。其何以掲防妥灵?汝嘉不敢避其责。曩常去此弥年,若有所负,尚安得以罢輭自恕,复累后人乎!他日或克就绪,愿与诸君落之。于是谋之,僚吏搜故府,得遗材千枚,取赢资以给其费。鸠工庀役,各举其任。嵗月讫,工民不与知像,设礼器,百用具修。至于堂室。廊序。门牖。垣墙,皆一新之。'" 789 | ] 790 | }, 791 | "execution_count": 47, 792 | "metadata": {}, 793 | "output_type": "execute_result" 794 | } 795 | ], 796 | "source": [ 797 | "mark_sentence(\"\"\"郡邑置夫子庙于学以嵗时释奠盖自唐贞观以来未之或改我宋有天下因其制而损益之姑苏当浙右要区规模尤大更建炎戎马荡然无遗虽修学宫于荆榛瓦砾之余独殿宇未遑议也每春秋展礼于斋庐已则置不问殆为阙典今寳文阁直学士括苍梁公来牧之明年实绍兴十有一禩也二月上丁修祀既毕乃愓然自咎揖诸生而告之曰天子不以汝嘉为不肖俾再守兹土顾治民事神皆守之职惟是夫子之祀教化所基尤宜严且谨而拜跪荐祭之地卑陋乃尔其何以掲防妥灵汝嘉不敢避其责曩常去此弥年若有所负尚安得以罢輭自恕复累后人乎他日或克就绪愿与诸君落之于是谋之僚吏搜故府得遗材千枚取赢资以给其费鸠工庀役各举其任嵗月讫工民不与知像设礼器百用具修至于堂室廊序门牖垣墙皆一新之\"\"\")" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "execution_count": null, 803 | "metadata": {}, 804 | "outputs": [], 805 | "source": [] 806 | } 807 | ], 808 | 
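
The cells above wire the trained `module.model` into a `pipeline` and insert each predicted punctuation mark back into the raw string. For readers who only want inference, the same logic can be run against the published checkpoint alone; the sketch below is a minimal, self-contained variant of `mark_sentence` under that assumption (the helper name `restore_punctuation` and the example sentence are illustrative, the model tag is the one pushed above).

```python
# Minimal inference sketch: restore punctuation with the published checkpoint.
# Assumes the hub tag pushed above; `restore_punctuation` is an illustrative name.
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

TAG = "raynardj/classical-chinese-punctuation-guwen-biaodian"
tokenizer = AutoTokenizer.from_pretrained(TAG)
model = AutoModelForTokenClassification.from_pretrained(TAG)
ner = pipeline("ner", model=model, tokenizer=tokenizer)

def restore_punctuation(text: str) -> str:
    # each prediction carries the punctuation mark as its label ("entity")
    # and the character offset ("end") it should follow
    chars = list(text)
    for i, pred in enumerate(ner(text)):
        chars.insert(pred["end"] + i, pred["entity"])
    return "".join(chars)

print(restore_punctuation("是书虽称文粹实与地志相表里"))
```

The `end + i` offset accounts for the characters already inserted on earlier iterations, so later predictions still land right after the correct character.
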
"metadata": { 809 | "kernelspec": { 810 | "display_name": "Python 3", 811 | "language": "python", 812 | "name": "python3" 813 | }, 814 | "language_info": { 815 | "codemirror_mode": { 816 | "name": "ipython", 817 | "version": 3 818 | }, 819 | "file_extension": ".py", 820 | "mimetype": "text/x-python", 821 | "name": "python", 822 | "nbconvert_exporter": "python", 823 | "pygments_lexer": "ipython3", 824 | "version": "3.7.6" 825 | }, 826 | "toc": { 827 | "base_numbering": 1, 828 | "nav_menu": {}, 829 | "number_sections": true, 830 | "sideBar": true, 831 | "skip_h1_title": false, 832 | "title_cell": "Table of Contents", 833 | "title_sidebar": "Contents", 834 | "toc_cell": false, 835 | "toc_position": {}, 836 | "toc_section_display": true, 837 | "toc_window_display": true 838 | } 839 | }, 840 | "nbformat": 4, 841 | "nbformat_minor": 4 842 | } 843 | -------------------------------------------------------------------------------- /nbs/xlsearch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "9EtssVJyGkmU" 7 | }, 8 | "source": [ 9 | "# Cross Language Search" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "We are using [this nice dataset](https://github.com/BangBOOM/Classical-Chinese)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "## Imports" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# !pip install -Uqq git+https://github.com/raynardj/forgebox" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from forgebox.imports import *\n", 42 | "from forgebox.thunder.callbacks import DataFrameMetricsCallback\n", 43 | "from forgebox.multiproc import DataFrameRowling\n", 44 | "from gc_utils.env import *\n", 45 | "from datasets import load_dataset\n", 46 | "# from fastai.text.all import *\n", 47 | "from unpackai.nlp import *\n", 48 | "from tqdm.notebook import tqdm\n", 49 | "import random" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "import pytorch_lightning as pl" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "import re\n", 68 | "\n", 69 | "def remove_all_punkt(text):\n", 70 | " \"\"\"\n", 71 | " Removes all punctuation from Chinese text.\n", 72 | "\n", 73 | " :param text: text to remove punctuation from\n", 74 | " :return: text with no punctuation\n", 75 | " \"\"\"\n", 76 | " return re.sub(r'[^\\w\\s]', '', text)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 5, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "'亳州水军千户胡进等领骑兵渡淝水逾荆山与宋兵战杀获甚众赏钞币有差'" 88 | ] 89 | }, 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "remove_all_punkt(\"亳州水军千户胡进等领骑兵渡淝水,逾荆山,与宋兵战,杀获甚众,赏钞币有差。\")" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "## Config" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "DATA = Path(sys_loc('DATA')/\"nlp\"/\"zh\"/\"cc_vs_zh\")\n", 113 | 
"TO_CLASSICAL = False" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "## Download data" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": { 126 | "id": "ZbXuwqr0KEr8" 127 | }, 128 | "source": [ 129 | "## Data" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "### Combine data" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 7, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "all_file = list(DATA.rglob(\"data/*\"))" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 8, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "def open_file_to_lines(file):\n", 155 | " with open(file) as f:\n", 156 | " lines = f.read().splitlines()\n", 157 | " return lines\n", 158 | "\n", 159 | "def pairing_the_file(files,kw):\n", 160 | " pairs = []\n", 161 | " for file in files:\n", 162 | " if kw not in file.name:\n", 163 | " file1 = file\n", 164 | " file2 = f\"{file}{kw}\"\n", 165 | " pairs.append((file1,file2))\n", 166 | " return pairs" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 9, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "pairs = pairing_the_file(all_file,\"翻译\")" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 10, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "def open_pairs(pairs):\n", 185 | " chunks = []\n", 186 | " for pair in tqdm(pairs, leave=False):\n", 187 | " file1,file2 = pair\n", 188 | " lines1 = open_file_to_lines(file1)\n", 189 | " lines2 = open_file_to_lines(file2)\n", 190 | " chunks.append(pd.DataFrame({\"classical\":lines1,\"modern\":lines2}))\n", 191 | " return pd.concat(chunks).sample(frac=1.).reset_index(drop=True)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 11, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "data": { 201 | "application/vnd.jupyter.widget-view+json": { 202 | "model_id": "", 203 | "version_major": 2, 204 | "version_minor": 0 205 | }, 206 | "text/plain": [ 207 | " 0%| | 0/27 [00:00\n", 240 | "\n", 253 | "\n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | "
|   | source | target |
|---|--------|--------|
| 0 | 下也。 | 因为在下面。 |
| 1 | 长乐王尉粲甚礼之。 | 垦銮王幽垩很礼待他。 |
| 2 | 太师王舜自莽篡位后,病悸剧,死。 | 太师王舜自王莽篡夺皇位后,得了心悸病,渐渐加剧,终于病故。 |
| 3 | 秋七月丙寅,以旱,亲录京城囚徒。 | 秋七月二十九日,因为干旱,皇上亲自审查并记录囚徒罪状。 |
| 4 | 乙亥,齐仪同三司元旭坐事赐死。 | 乙亥,北齐国仪同三司元旭因犯罪被赐死。 |
\n", 289 | "" 290 | ], 291 | "text/plain": [ 292 | " source target\n", 293 | "0 下也。 因为在下面。\n", 294 | "1 长乐王尉粲甚礼之。 垦銮王幽垩很礼待他。\n", 295 | "2 太师王舜自莽篡位后,病悸剧,死。 太师王舜自王莽篡夺皇位后,得了心悸病,渐渐加剧,终于病故。\n", 296 | "3 秋七月丙寅,以旱,亲录京城囚徒。 秋七月二十九日,因为干旱,皇上亲自审查并记录囚徒罪状。\n", 297 | "4 乙亥,齐仪同三司元旭坐事赐死。 乙亥,北齐国仪同三司元旭因犯罪被赐死。" 298 | ] 299 | }, 300 | "execution_count": 13, 301 | "metadata": {}, 302 | "output_type": "execute_result" 303 | } 304 | ], 305 | "source": [ 306 | "df.head()" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "### Loading tokenizer" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 14, 319 | "metadata": { 320 | "id": "ukyVGg8HmSd-" 321 | }, 322 | "outputs": [], 323 | "source": [ 324 | "from transformers import (\n", 325 | " AutoTokenizer,\n", 326 | " AutoModelForMaskedLM,\n", 327 | " AutoModel,\n", 328 | " EncoderDecoderModel\n", 329 | " )\n", 330 | "PRETRAINED = \"bert-base-chinese\"\n", 331 | "\n", 332 | "tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "### Pytoch Dataset" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 15, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "import random\n", 349 | "\n", 350 | "def combine_randomly(data):\n", 351 | " if random.random()>.5:\n", 352 | " a,b = data['source'],data['target']\n", 353 | " else:\n", 354 | " a,b = data['target'],data['source']\n", 355 | " return f\"{a}{b}\"\n", 356 | "\n", 357 | "def pick_randomly(data):\n", 358 | " return list(data.values())[int(random.random()>.5)]\n", 359 | "\n", 360 | "def mixup(data):\n", 361 | " if len(data['target'])> 70:\n", 362 | " th = .7\n", 363 | " else:\n", 364 | " th = .3\n", 365 | " if random.random()>th:\n", 366 | " return combine_randomly(data)\n", 367 | " else:\n", 368 | " return pick_randomly(data)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 16, 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "class XLSearch(Dataset):\n", 378 | " def __init__(\n", 379 | " self, df, tokenizer,\n", 380 | " max_len=128,\n", 381 | " no_punkt:bool = False,\n", 382 | " mlm_probability:float = .15,\n", 383 | " ):\n", 384 | " \"\"\"\n", 385 | " no_punkt, do we ramdomly remove punctuation\n", 386 | " from source sentence\n", 387 | " \"\"\"\n", 388 | " super().__init__()\n", 389 | " self.df = df.reset_index(drop=True)\n", 390 | " self.tokenizer = tokenizer\n", 391 | " self.max_len = max_len\n", 392 | " self.mlm_probability = mlm_probability\n", 393 | " \n", 394 | " def __len__(self, ):\n", 395 | " return len(self.df)\n", 396 | "\n", 397 | " def __getitem__(self, idx):\n", 398 | " return mixup(dict(self.df.loc[idx]))\n", 399 | "\n", 400 | " def collate(self, data):\n", 401 | " inputs = self.tokenizer(\n", 402 | " list(data),\n", 403 | " max_length=self.max_len,\n", 404 | " padding='max_length',\n", 405 | " truncation=True,\n", 406 | " return_tensors='pt',\n", 407 | " )\n", 408 | " return self.mlm_masking(inputs)\n", 409 | " \n", 410 | " def mlm_masking(self,inputs):\n", 411 | " \"\"\"\n", 412 | " convert inputs for masked language modeling\n", 413 | " \"\"\"\n", 414 | " if self.mlm_probability is None:\n", 415 | " return inputs\n", 416 | " input_ids = inputs.input_ids\n", 417 | " token_type_ids = inputs.token_type_ids\n", 418 | " \n", 419 | " # masking input_ids\n", 420 | " masked = input_ids.clone()\n", 421 | " masked[\n", 422 | " 
torch.rand(input_ids.shape).to(input_ids.device) < self.mlm_probability\n", 423 | " ] = self.tokenizer.mask_token_id\n", 424 | " \n", 425 | " labels = input_ids.clone()\n", 426 | " labels[token_type_ids == 1] = -100\n", 427 | " labels[labels==0] = -100\n", 428 | " token_type_ids[masked==self.tokenizer.mask_token_id] = 1\n", 429 | " labels[token_type_ids == 0] = -100\n", 430 | " \n", 431 | " inputs['input_ids'] = masked\n", 432 | " inputs['labels'] = labels\n", 433 | " inputs['token_type_ids'] = token_type_ids\n", 434 | " return inputs\n", 435 | "\n", 436 | " def dataloader(self, batch_size, shuffle=True):\n", 437 | " return DataLoader(\n", 438 | " self,\n", 439 | " batch_size=batch_size,\n", 440 | " shuffle=shuffle,\n", 441 | " collate_fn=self.collate,\n", 442 | " )\n", 443 | "\n", 444 | " def split_train_valid(self, valid_size=0.1):\n", 445 | " split_index = int(len(self) * (1 - valid_size))\n", 446 | " cls = type(self)\n", 447 | " shuffled = self.df.sample(frac=1).reset_index(drop=True)\n", 448 | " train_set = cls(\n", 449 | " shuffled.iloc[:split_index],\n", 450 | " tokenizer=self.tokenizer,\n", 451 | " max_len=self.max_len,\n", 452 | " )\n", 453 | " valid_set = cls(\n", 454 | " shuffled.iloc[split_index:],\n", 455 | " tokenizer=self.tokenizer,\n", 456 | " max_len=self.max_len,\n", 457 | " )\n", 458 | " return train_set, valid_set" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 17, 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "ds = XLSearch(df, tokenizer, )" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 18, 473 | "metadata": {}, 474 | "outputs": [ 475 | { 476 | "data": { 477 | "text/plain": [ 478 | "'又将御史王金,主事马思聪、金山,参议黄宏、许效廉,布政使胡廉,参政陈杲、刘非木,佥事赖凤,指挥许金、白昂等人逮捕下狱。执御史王金,主事马思聪、金山,参议黄宏、许效廉,布政使胡廉,参政陈杲、刘棐,佥事赖凤,指挥许金、白昂等下狱。'" 479 | ] 480 | }, 481 | "execution_count": 18, 482 | "metadata": {}, 483 | "output_type": "execute_result" 484 | } 485 | ], 486 | "source": [ 487 | "ds[5]" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "### Different ways of mixing and masking" 495 | ] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "metadata": {}, 500 | "source": [ 501 | "### PL datamodule" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": 19, 507 | "metadata": {}, 508 | "outputs": [], 509 | "source": [ 510 | "class DataModule(pl.LightningDataModule):\n", 511 | " def __init__(\n", 512 | " self, df,\n", 513 | " tokenizer,\n", 514 | " batch_size=12,\n", 515 | " max_len=128,\n", 516 | " no_punkt:bool=False):\n", 517 | " super().__init__()\n", 518 | " self.df = df\n", 519 | " self.ds = XLSearch(df,\n", 520 | " tokenizer,\n", 521 | " max_len=max_len,)\n", 522 | " self.tokenizer = tokenizer\n", 523 | " self.max_len = max_len\n", 524 | " self.batch_size = batch_size\n", 525 | "\n", 526 | " def setup(self, stage=None):\n", 527 | " self.train_set, self.valid_set = self.ds.split_train_valid()\n", 528 | "\n", 529 | " def train_dataloader(self):\n", 530 | " return self.train_set.dataloader(\n", 531 | " batch_size=self.batch_size, shuffle=True)\n", 532 | "\n", 533 | " def val_dataloader(self):\n", 534 | " return self.valid_set.dataloader(\n", 535 | " batch_size=self.batch_size*2, shuffle=False)" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 20, 541 | "metadata": {}, 542 | "outputs": [], 543 | "source": [ 544 | "data_module = DataModule(\n", 545 | " df, tokenizer,\n", 546 | " batch_size=64,\n", 547 | " max_len=256,\n", 548 | 
" no_punkt=False if TO_CLASSICAL else True,)\n", 549 | "data_module.setup()" 550 | ] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "execution_count": 21, 555 | "metadata": {}, 556 | "outputs": [ 557 | { 558 | "data": { 559 | "text/plain": [ 560 | "{'input_ids': tensor([[ 101, 1282, 103, ..., 0, 0, 0],\n", 561 | " [ 101, 3293, 1062, ..., 0, 0, 103],\n", 562 | " [ 101, 758, 2399, ..., 0, 0, 0],\n", 563 | " ...,\n", 564 | " [ 101, 7826, 815, ..., 103, 0, 0],\n", 565 | " [ 101, 5628, 6818, ..., 0, 0, 0],\n", 566 | " [ 101, 5745, 815, ..., 0, 103, 0]]), 'token_type_ids': tensor([[0, 0, 1, ..., 0, 0, 0],\n", 567 | " [0, 0, 0, ..., 0, 0, 1],\n", 568 | " [0, 0, 0, ..., 0, 0, 0],\n", 569 | " ...,\n", 570 | " [0, 0, 0, ..., 1, 0, 0],\n", 571 | " [0, 0, 0, ..., 0, 0, 0],\n", 572 | " [0, 0, 0, ..., 0, 1, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n", 573 | " [1, 1, 1, ..., 0, 0, 0],\n", 574 | " [1, 1, 1, ..., 0, 0, 0],\n", 575 | " ...,\n", 576 | " [1, 1, 1, ..., 0, 0, 0],\n", 577 | " [1, 1, 1, ..., 0, 0, 0],\n", 578 | " [1, 1, 1, ..., 0, 0, 0]]), 'labels': tensor([[-100, -100, 1063, ..., -100, -100, -100],\n", 579 | " [-100, -100, -100, ..., -100, -100, -100],\n", 580 | " [-100, -100, -100, ..., -100, -100, -100],\n", 581 | " ...,\n", 582 | " [-100, -100, -100, ..., -100, -100, -100],\n", 583 | " [-100, -100, -100, ..., -100, -100, -100],\n", 584 | " [-100, -100, -100, ..., -100, -100, -100]])}" 585 | ] 586 | }, 587 | "execution_count": 21, 588 | "metadata": {}, 589 | "output_type": "execute_result" 590 | } 591 | ], 592 | "source": [ 593 | "inputs = next(iter(data_module.train_dataloader()))\n", 594 | "inputs" 595 | ] 596 | }, 597 | { 598 | "cell_type": "markdown", 599 | "metadata": {}, 600 | "source": [ 601 | "if we are doing clasical Chinese to modern Chinese, we can randomly set half of the input without any punctuation, as many data source might be" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 22, 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [ 610 | "# tokenizer.batch_decode(\n", 611 | "# inputs.input_ids,skip_special_tokens=False\n", 612 | "# )" 613 | ] 614 | }, 615 | { 616 | "cell_type": "markdown", 617 | "metadata": { 618 | "id": "92iwRu6Oqbzb" 619 | }, 620 | "source": [ 621 | "### Load pretrained models" 622 | ] 623 | }, 624 | { 625 | "cell_type": "markdown", 626 | "metadata": { 627 | "id": "pajv5ridLamp" 628 | }, 629 | "source": [ 630 | "## Model" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 23, 636 | "metadata": {}, 637 | "outputs": [ 638 | { 639 | "name": "stderr", 640 | "output_type": "stream", 641 | "text": [ 642 | "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n", 643 | "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 644 | "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" 645 | ] 646 | } 647 | ], 648 | "source": [ 649 | "# loading pretrained model\n", 650 | "model = AutoModelForMaskedLM.from_pretrained(PRETRAINED\n", 651 | ")" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": 24, 657 | "metadata": { 658 | "id": "jBVyNeKUv6FU" 659 | }, 660 | "outputs": [], 661 | "source": [ 662 | "class MaskedLM(pl.LightningModule):\n", 663 | " def __init__(\n", 664 | " self,\n", 665 | " model):\n", 666 | " super().__init__()\n", 667 | " self.model = model\n", 668 | "\n", 669 | " def forward(self, **kwargs):\n", 670 | " return self.model(**kwargs)\n", 671 | "\n", 672 | " def accuracy(self, batch_input, outputs):\n", 673 | " \"\"\"\n", 674 | " Accuracy for masked language model\n", 675 | " \"\"\"\n", 676 | " mask_mask = batch_input.labels != -100\n", 677 | " predictions = outputs.logits.argmax(-1)[mask_mask]\n", 678 | " targets = batch_input.labels[mask_mask]\n", 679 | " return (predictions == targets).float().mean()\n", 680 | "\n", 681 | " def training_step(self, batch, batch_idx):\n", 682 | " inputs = dict(\n", 683 | " input_ids=batch.input_ids,\n", 684 | " attention_mask=batch.attention_mask,\n", 685 | " labels=batch.labels,\n", 686 | " )\n", 687 | " outputs = self(**inputs)\n", 688 | " self.log(\"loss\", outputs.loss, prog_bar=True)\n", 689 | " self.log(\"acc\",\n", 690 | " self.accuracy(batch, outputs),\n", 691 | " on_step=True, prog_bar=True)\n", 692 | " return outputs.loss\n", 693 | "\n", 694 | " def validation_step(self, batch, batch_idx):\n", 695 | " inputs = dict(\n", 696 | " input_ids=batch.input_ids,\n", 697 | " attention_mask=batch.attention_mask,\n", 698 | " labels=batch.labels,\n", 699 | " )\n", 700 | " outputs = self(**inputs)\n", 701 | " self.log(\"val_loss\", outputs.loss, prog_bar=True)\n", 702 | " self.log(\"val_acc\",\n", 703 | " self.accuracy(batch, outputs),\n", 704 | " on_step=False, prog_bar=True)\n", 705 | " return outputs.loss\n", 706 | " \n", 707 | " def configure_optimizers(self):\n", 708 | " return torch.optim.Adam(self.parameters(), lr=1e-6)" 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": 25, 714 | "metadata": { 715 | "id": "5uIjcPuXw0Fr" 716 | }, 717 | "outputs": [], 718 | "source": [ 719 | "module = MaskedLM(model)" 720 | ] 721 | }, 722 | { 723 | "cell_type": "markdown", 724 | "metadata": { 725 | "id": "DBf3NTKSLcUb" 726 | }, 727 | "source": [ 728 | "## Training" 729 | ] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "execution_count": 26, 734 | "metadata": {}, 735 | "outputs": [], 736 | "source": [ 737 | "TASK = \"xlsearch_cc_zh\"" 738 | ] 739 | }, 740 | { 741 | "cell_type": "code", 742 | "execution_count": null, 743 | "metadata": {}, 744 | "outputs": [ 745 | { 746 | "name": "stderr", 747 | "output_type": "stream", 748 | "text": [ 749 | "GPU available: True, used: True\n", 750 | "TPU available: False, using: 0 TPU cores\n", 751 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", 752 | "\n", 753 | " | Name | Type | Params\n", 754 | "------------------------------------------\n", 755 | "0 | model | BertForMaskedLM | 102 M \n", 756 | "------------------------------------------\n", 757 | "102 M Trainable params\n", 758 | "0 Non-trainable params\n", 759 | "102 M 
Total params\n", 760 | "409.161 Total estimated model params size (MB)\n" 761 | ] 762 | }, 763 | { 764 | "data": { 765 | "application/vnd.jupyter.widget-view+json": { 766 | "model_id": "", 767 | "version_major": 2, 768 | "version_minor": 0 769 | }, 770 | "text/plain": [ 771 | "Validation sanity check: 0it [00:00, ?it/s]" 772 | ] 773 | }, 774 | "metadata": {}, 775 | "output_type": "display_data" 776 | }, 777 | { 778 | "name": "stderr", 779 | "output_type": "stream", 780 | "text": [ 781 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:103: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 782 | " f'The dataloader, {name}, does not have many workers which may be a bottleneck.'\n", 783 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:103: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 784 | " f'The dataloader, {name}, does not have many workers which may be a bottleneck.'\n" 785 | ] 786 | }, 787 | { 788 | "data": { 789 | "application/vnd.jupyter.widget-view+json": { 790 | "model_id": "b4eaec356ce34562923404d08801bc53", 791 | "version_major": 2, 792 | "version_minor": 0 793 | }, 794 | "text/plain": [ 795 | "Training: 0it [00:00, ?it/s]" 796 | ] 797 | }, 798 | "metadata": {}, 799 | "output_type": "display_data" 800 | }, 801 | { 802 | "name": "stderr", 803 | "output_type": "stream", 804 | "text": [ 805 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/properties.py:249: UserWarning: The progress bar already tracks a metric with the name(s) 'loss' and `self.log('loss', ..., prog_bar=True)` will overwrite this value. 
If this is undesired, change the name or override `get_progress_bar_dict()` in `LightingModule`.\n", 806 | " f\" in `LightingModule`.\", UserWarning\n", 807 | "Epoch 0, global step 1023: acc reached 0.54819 (best 0.54819), saving model to \"/nvme/GCI/transformers/weights/xlsearch_cc_zh/epoch=0-step=1023.ckpt\" as top 3\n", 808 | "IOPub message rate exceeded.\n", 809 | "The notebook server will temporarily stop sending output\n", 810 | "to the client in order to avoid crashing it.\n", 811 | "To change this limit, set the config variable\n", 812 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 813 | "\n", 814 | "Current values:\n", 815 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 816 | "NotebookApp.rate_limit_window=3.0 (secs)\n", 817 | "\n" 818 | ] 819 | } 820 | ], 821 | "source": [ 822 | "tb_logger = pl.loggers.TensorBoardLogger(\n", 823 | " save_dir=f\"/GCI/tensorboard/{TASK}\",\n", 824 | " name=TASK,\n", 825 | " )\n", 826 | "\n", 827 | "save_cb = pl.callbacks.ModelCheckpoint(\n", 828 | " dirpath=f\"/GCI/transformers/weights/{TASK}\",\n", 829 | " save_top_k=3,\n", 830 | " verbose=True,\n", 831 | " monitor=\"acc\",\n", 832 | " save_weights_only=True,\n", 833 | " every_n_train_steps=1024,\n", 834 | " mode=\"max\",\n", 835 | " )\n", 836 | "\n", 837 | "trainer = pl.Trainer(\n", 838 | " gpus=[1,],\n", 839 | " max_epochs=10,\n", 840 | " logger = [tb_logger,],\n", 841 | " callbacks=[save_cb,\n", 842 | "# DataFrameMetricsCallback()\n", 843 | " ],\n", 844 | " )\n", 845 | "\n", 846 | "trainer.fit(\n", 847 | " module,\n", 848 | " datamodule = data_module\n", 849 | " )\n" 850 | ] 851 | }, 852 | { 853 | "cell_type": "markdown", 854 | "metadata": {}, 855 | "source": [ 856 | "## Save" 857 | ] 858 | }, 859 | { 860 | "cell_type": "code", 861 | "execution_count": 22, 862 | "metadata": {}, 863 | "outputs": [], 864 | "source": [ 865 | "best = save.best" 866 | ] 867 | }, 868 | { 869 | "cell_type": "code", 870 | "execution_count": 24, 871 | "metadata": {}, 872 | "outputs": [ 873 | { 874 | "data": { 875 | "text/plain": [ 876 | "" 877 | ] 878 | }, 879 | "execution_count": 24, 880 | "metadata": {}, 881 | "output_type": "execute_result" 882 | } 883 | ], 884 | "source": [ 885 | "module.load_state_dict(torch.load(best, map_location=\"cpu\")['state_dict'])" 886 | ] 887 | }, 888 | { 889 | "cell_type": "code", 890 | "execution_count": 61, 891 | "metadata": {}, 892 | "outputs": [ 893 | { 894 | "name": "stderr", 895 | "output_type": "stream", 896 | "text": [ 897 | "Cloning https://huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient into local empty directory.\n" 898 | ] 899 | }, 900 | { 901 | "data": { 902 | "application/vnd.jupyter.widget-view+json": { 903 | "model_id": "20893a4c96924d2ba176c94cf0eb5e1c", 904 | "version_major": 2, 905 | "version_minor": 0 906 | }, 907 | "text/plain": [ 908 | "Upload file pytorch_model.bin: 0%| | 32.0k/916M [00:00 main\n", 920 | "\n" 921 | ] 922 | }, 923 | { 924 | "data": { 925 | "text/plain": [ 926 | "'https://huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient/commit/5ee213356db17dfa9577226a90d5e9bd9461b495'" 927 | ] 928 | }, 929 | "execution_count": 61, 930 | "metadata": {}, 931 | "output_type": "execute_result" 932 | } 933 | ], 934 | "source": [ 935 | "# encoder_decoder.push_to_hub(\"raynardj/wenyanwen-chinese-translate-to-ancient\")" 936 | ] 937 | }, 938 | { 939 | "cell_type": "code", 940 | "execution_count": 65, 941 | "metadata": {}, 942 | "outputs": [ 943 | { 944 | "name": "stderr", 945 | "output_type": "stream", 946 | "text": [ 947 | "To 
https://user:eOwfuFZJHbcMgbzVtVPDaSGtpbpjumsgTzZtfKlrMbSECzypnCYHZGDhHVsHRsYZzvdrkcxbnnSXRROfqdNRYfMvVfaVSOTxORkEUcMnAPEWXhkWpVEDrgfUZJdmleTx@huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient\n", 948 | " 5ee2133..ab72fa4 main -> main\n", 949 | "\n" 950 | ] 951 | }, 952 | { 953 | "data": { 954 | "text/plain": [ 955 | "'https://huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient/commit/ab72fa41627cfeb6fef64e196d68d81b0adb6228'" 956 | ] 957 | }, 958 | "execution_count": 65, 959 | "metadata": {}, 960 | "output_type": "execute_result" 961 | } 962 | ], 963 | "source": [ 964 | "# encoder_tokenizer.push_to_hub(\"raynardj/wenyanwen-chinese-translate-to-ancient\")" 965 | ] 966 | } 967 | ], 968 | "metadata": { 969 | "accelerator": "GPU", 970 | "colab": { 971 | "collapsed_sections": [], 972 | "name": "seq2seq.ipynb", 973 | "provenance": [] 974 | }, 975 | "kernelspec": { 976 | "display_name": "Python 3", 977 | "language": "python", 978 | "name": "python3" 979 | }, 980 | "language_info": { 981 | "codemirror_mode": { 982 | "name": "ipython", 983 | "version": 3 984 | }, 985 | "file_extension": ".py", 986 | "mimetype": "text/x-python", 987 | "name": "python", 988 | "nbconvert_exporter": "python", 989 | "pygments_lexer": "ipython3", 990 | "version": "3.7.6" 991 | }, 992 | "toc": { 993 | "base_numbering": 1, 994 | "nav_menu": {}, 995 | "number_sections": true, 996 | "sideBar": true, 997 | "skip_h1_title": false, 998 | "title_cell": "Table of Contents", 999 | "title_sidebar": "Contents", 1000 | "toc_cell": false, 1001 | "toc_position": {}, 1002 | "toc_section_display": true, 1003 | "toc_window_display": false 1004 | }, 1005 | "widgets": { 1006 | "application/vnd.jupyter.widget-state+json": { 1007 | "579055f403bf4594a2c665adfdfb8995": { 1008 | "model_module": "@jupyter-widgets/controls", 1009 | "model_module_version": "1.5.0", 1010 | "model_name": "DescriptionStyleModel", 1011 | "state": { 1012 | "_model_module": "@jupyter-widgets/controls", 1013 | "_model_module_version": "1.5.0", 1014 | "_model_name": "DescriptionStyleModel", 1015 | "_view_count": null, 1016 | "_view_module": "@jupyter-widgets/base", 1017 | "_view_module_version": "1.2.0", 1018 | "_view_name": "StyleView", 1019 | "description_width": "" 1020 | } 1021 | }, 1022 | "659eee19636c45da881d243f66aedf27": { 1023 | "model_module": "@jupyter-widgets/base", 1024 | "model_module_version": "1.2.0", 1025 | "model_name": "LayoutModel", 1026 | "state": { 1027 | "_model_module": "@jupyter-widgets/base", 1028 | "_model_module_version": "1.2.0", 1029 | "_model_name": "LayoutModel", 1030 | "_view_count": null, 1031 | "_view_module": "@jupyter-widgets/base", 1032 | "_view_module_version": "1.2.0", 1033 | "_view_name": "LayoutView", 1034 | "align_content": null, 1035 | "align_items": null, 1036 | "align_self": null, 1037 | "border": null, 1038 | "bottom": null, 1039 | "display": null, 1040 | "flex": null, 1041 | "flex_flow": null, 1042 | "grid_area": null, 1043 | "grid_auto_columns": null, 1044 | "grid_auto_flow": null, 1045 | "grid_auto_rows": null, 1046 | "grid_column": null, 1047 | "grid_gap": null, 1048 | "grid_row": null, 1049 | "grid_template_areas": null, 1050 | "grid_template_columns": null, 1051 | "grid_template_rows": null, 1052 | "height": null, 1053 | "justify_content": null, 1054 | "justify_items": null, 1055 | "left": null, 1056 | "margin": null, 1057 | "max_height": null, 1058 | "max_width": null, 1059 | "min_height": null, 1060 | "min_width": null, 1061 | "object_fit": null, 1062 | "object_position": null, 1063 | "order": 
null, 1064 | "overflow": null, 1065 | "overflow_x": null, 1066 | "overflow_y": null, 1067 | "padding": null, 1068 | "right": null, 1069 | "top": null, 1070 | "visibility": null, 1071 | "width": null 1072 | } 1073 | }, 1074 | "94111dfb9a2d4a4f93e00bdb34c70090": { 1075 | "model_module": "@jupyter-widgets/controls", 1076 | "model_module_version": "1.5.0", 1077 | "model_name": "HBoxModel", 1078 | "state": { 1079 | "_dom_classes": [], 1080 | "_model_module": "@jupyter-widgets/controls", 1081 | "_model_module_version": "1.5.0", 1082 | "_model_name": "HBoxModel", 1083 | "_view_count": null, 1084 | "_view_module": "@jupyter-widgets/controls", 1085 | "_view_module_version": "1.5.0", 1086 | "_view_name": "HBoxView", 1087 | "box_style": "", 1088 | "children": [ 1089 | "IPY_MODEL_ff2e02e62d0b438cac9f521da8c0d5eb", 1090 | "IPY_MODEL_fc66ee3afa1944beb42494efbb1301ac", 1091 | "IPY_MODEL_9ac2a1e65c084bca8cdff9f1dc7541e0" 1092 | ], 1093 | "layout": "IPY_MODEL_659eee19636c45da881d243f66aedf27" 1094 | } 1095 | }, 1096 | "94765776469249ea94eee8ccf64c47e7": { 1097 | "model_module": "@jupyter-widgets/base", 1098 | "model_module_version": "1.2.0", 1099 | "model_name": "LayoutModel", 1100 | "state": { 1101 | "_model_module": "@jupyter-widgets/base", 1102 | "_model_module_version": "1.2.0", 1103 | "_model_name": "LayoutModel", 1104 | "_view_count": null, 1105 | "_view_module": "@jupyter-widgets/base", 1106 | "_view_module_version": "1.2.0", 1107 | "_view_name": "LayoutView", 1108 | "align_content": null, 1109 | "align_items": null, 1110 | "align_self": null, 1111 | "border": null, 1112 | "bottom": null, 1113 | "display": null, 1114 | "flex": null, 1115 | "flex_flow": null, 1116 | "grid_area": null, 1117 | "grid_auto_columns": null, 1118 | "grid_auto_flow": null, 1119 | "grid_auto_rows": null, 1120 | "grid_column": null, 1121 | "grid_gap": null, 1122 | "grid_row": null, 1123 | "grid_template_areas": null, 1124 | "grid_template_columns": null, 1125 | "grid_template_rows": null, 1126 | "height": null, 1127 | "justify_content": null, 1128 | "justify_items": null, 1129 | "left": null, 1130 | "margin": null, 1131 | "max_height": null, 1132 | "max_width": null, 1133 | "min_height": null, 1134 | "min_width": null, 1135 | "object_fit": null, 1136 | "object_position": null, 1137 | "order": null, 1138 | "overflow": null, 1139 | "overflow_x": null, 1140 | "overflow_y": null, 1141 | "padding": null, 1142 | "right": null, 1143 | "top": null, 1144 | "visibility": null, 1145 | "width": null 1146 | } 1147 | }, 1148 | "9ac2a1e65c084bca8cdff9f1dc7541e0": { 1149 | "model_module": "@jupyter-widgets/controls", 1150 | "model_module_version": "1.5.0", 1151 | "model_name": "HTMLModel", 1152 | "state": { 1153 | "_dom_classes": [], 1154 | "_model_module": "@jupyter-widgets/controls", 1155 | "_model_module_version": "1.5.0", 1156 | "_model_name": "HTMLModel", 1157 | "_view_count": null, 1158 | "_view_module": "@jupyter-widgets/controls", 1159 | "_view_module_version": "1.5.0", 1160 | "_view_name": "HTMLView", 1161 | "description": "", 1162 | "description_tooltip": null, 1163 | "layout": "IPY_MODEL_94765776469249ea94eee8ccf64c47e7", 1164 | "placeholder": "​", 1165 | "style": "IPY_MODEL_579055f403bf4594a2c665adfdfb8995", 1166 | "value": " 1/1 [00:00<00:00, 21.65it/s]" 1167 | } 1168 | }, 1169 | "b3486fd1f15b43068e47df0ad6a81559": { 1170 | "model_module": "@jupyter-widgets/base", 1171 | "model_module_version": "1.2.0", 1172 | "model_name": "LayoutModel", 1173 | "state": { 1174 | "_model_module": "@jupyter-widgets/base", 1175 | 
"_model_module_version": "1.2.0", 1176 | "_model_name": "LayoutModel", 1177 | "_view_count": null, 1178 | "_view_module": "@jupyter-widgets/base", 1179 | "_view_module_version": "1.2.0", 1180 | "_view_name": "LayoutView", 1181 | "align_content": null, 1182 | "align_items": null, 1183 | "align_self": null, 1184 | "border": null, 1185 | "bottom": null, 1186 | "display": null, 1187 | "flex": null, 1188 | "flex_flow": null, 1189 | "grid_area": null, 1190 | "grid_auto_columns": null, 1191 | "grid_auto_flow": null, 1192 | "grid_auto_rows": null, 1193 | "grid_column": null, 1194 | "grid_gap": null, 1195 | "grid_row": null, 1196 | "grid_template_areas": null, 1197 | "grid_template_columns": null, 1198 | "grid_template_rows": null, 1199 | "height": null, 1200 | "justify_content": null, 1201 | "justify_items": null, 1202 | "left": null, 1203 | "margin": null, 1204 | "max_height": null, 1205 | "max_width": null, 1206 | "min_height": null, 1207 | "min_width": null, 1208 | "object_fit": null, 1209 | "object_position": null, 1210 | "order": null, 1211 | "overflow": null, 1212 | "overflow_x": null, 1213 | "overflow_y": null, 1214 | "padding": null, 1215 | "right": null, 1216 | "top": null, 1217 | "visibility": null, 1218 | "width": null 1219 | } 1220 | }, 1221 | "bb2e04bed86047b0b3a4e587cfb48ef0": { 1222 | "model_module": "@jupyter-widgets/base", 1223 | "model_module_version": "1.2.0", 1224 | "model_name": "LayoutModel", 1225 | "state": { 1226 | "_model_module": "@jupyter-widgets/base", 1227 | "_model_module_version": "1.2.0", 1228 | "_model_name": "LayoutModel", 1229 | "_view_count": null, 1230 | "_view_module": "@jupyter-widgets/base", 1231 | "_view_module_version": "1.2.0", 1232 | "_view_name": "LayoutView", 1233 | "align_content": null, 1234 | "align_items": null, 1235 | "align_self": null, 1236 | "border": null, 1237 | "bottom": null, 1238 | "display": null, 1239 | "flex": null, 1240 | "flex_flow": null, 1241 | "grid_area": null, 1242 | "grid_auto_columns": null, 1243 | "grid_auto_flow": null, 1244 | "grid_auto_rows": null, 1245 | "grid_column": null, 1246 | "grid_gap": null, 1247 | "grid_row": null, 1248 | "grid_template_areas": null, 1249 | "grid_template_columns": null, 1250 | "grid_template_rows": null, 1251 | "height": null, 1252 | "justify_content": null, 1253 | "justify_items": null, 1254 | "left": null, 1255 | "margin": null, 1256 | "max_height": null, 1257 | "max_width": null, 1258 | "min_height": null, 1259 | "min_width": null, 1260 | "object_fit": null, 1261 | "object_position": null, 1262 | "order": null, 1263 | "overflow": null, 1264 | "overflow_x": null, 1265 | "overflow_y": null, 1266 | "padding": null, 1267 | "right": null, 1268 | "top": null, 1269 | "visibility": null, 1270 | "width": null 1271 | } 1272 | }, 1273 | "c8dad1a95c8646edbde1af6fcc3f0ff9": { 1274 | "model_module": "@jupyter-widgets/controls", 1275 | "model_module_version": "1.5.0", 1276 | "model_name": "ProgressStyleModel", 1277 | "state": { 1278 | "_model_module": "@jupyter-widgets/controls", 1279 | "_model_module_version": "1.5.0", 1280 | "_model_name": "ProgressStyleModel", 1281 | "_view_count": null, 1282 | "_view_module": "@jupyter-widgets/base", 1283 | "_view_module_version": "1.2.0", 1284 | "_view_name": "StyleView", 1285 | "bar_color": null, 1286 | "description_width": "" 1287 | } 1288 | }, 1289 | "db73dbc0dabc429481860871b02dc9e0": { 1290 | "model_module": "@jupyter-widgets/controls", 1291 | "model_module_version": "1.5.0", 1292 | "model_name": "DescriptionStyleModel", 1293 | "state": { 1294 | "_model_module": 
"@jupyter-widgets/controls", 1295 | "_model_module_version": "1.5.0", 1296 | "_model_name": "DescriptionStyleModel", 1297 | "_view_count": null, 1298 | "_view_module": "@jupyter-widgets/base", 1299 | "_view_module_version": "1.2.0", 1300 | "_view_name": "StyleView", 1301 | "description_width": "" 1302 | } 1303 | }, 1304 | "fc66ee3afa1944beb42494efbb1301ac": { 1305 | "model_module": "@jupyter-widgets/controls", 1306 | "model_module_version": "1.5.0", 1307 | "model_name": "FloatProgressModel", 1308 | "state": { 1309 | "_dom_classes": [], 1310 | "_model_module": "@jupyter-widgets/controls", 1311 | "_model_module_version": "1.5.0", 1312 | "_model_name": "FloatProgressModel", 1313 | "_view_count": null, 1314 | "_view_module": "@jupyter-widgets/controls", 1315 | "_view_module_version": "1.5.0", 1316 | "_view_name": "ProgressView", 1317 | "bar_style": "success", 1318 | "description": "", 1319 | "description_tooltip": null, 1320 | "layout": "IPY_MODEL_bb2e04bed86047b0b3a4e587cfb48ef0", 1321 | "max": 1, 1322 | "min": 0, 1323 | "orientation": "horizontal", 1324 | "style": "IPY_MODEL_c8dad1a95c8646edbde1af6fcc3f0ff9", 1325 | "value": 1 1326 | } 1327 | }, 1328 | "ff2e02e62d0b438cac9f521da8c0d5eb": { 1329 | "model_module": "@jupyter-widgets/controls", 1330 | "model_module_version": "1.5.0", 1331 | "model_name": "HTMLModel", 1332 | "state": { 1333 | "_dom_classes": [], 1334 | "_model_module": "@jupyter-widgets/controls", 1335 | "_model_module_version": "1.5.0", 1336 | "_model_name": "HTMLModel", 1337 | "_view_count": null, 1338 | "_view_module": "@jupyter-widgets/controls", 1339 | "_view_module_version": "1.5.0", 1340 | "_view_name": "HTMLView", 1341 | "description": "", 1342 | "description_tooltip": null, 1343 | "layout": "IPY_MODEL_b3486fd1f15b43068e47df0ad6a81559", 1344 | "placeholder": "​", 1345 | "style": "IPY_MODEL_db73dbc0dabc429481860871b02dc9e0", 1346 | "value": "100%" 1347 | } 1348 | } 1349 | } 1350 | } 1351 | }, 1352 | "nbformat": 4, 1353 | "nbformat_minor": 1 1354 | } 1355 | -------------------------------------------------------------------------------- /nbs/cc2zh_translate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "9EtssVJyGkmU" 7 | }, 8 | "source": [ 9 | "# Translate model" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "We are using [this nice dataset](https://github.com/BangBOOM/Classical-Chinese)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "## Imports" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from forgebox.imports import *\n", 33 | "from forgebox.thunder.callbacks import DataFrameMetricsCallback\n", 34 | "from gc_utils.env import *\n", 35 | "from datasets import load_dataset\n", 36 | "# from fastai.text.all import *\n", 37 | "from unpackai.nlp import *\n", 38 | "from tqdm.notebook import tqdm\n", 39 | "import random" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "import pytorch_lightning as pl" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "import re\n", 58 | "\n", 59 | "def remove_all_punkt(text):\n", 60 | " \"\"\"\n", 61 | " Removes all punctuation from Chinese text.\n", 62 | "\n", 
63 | " :param text: text to remove punctuation from\n", 64 | " :return: text with no punctuation\n", 65 | " \"\"\"\n", 66 | " return re.sub(r'[^\\w\\s]', '', text)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "'亳州水军千户胡进等领骑兵渡淝水逾荆山与宋兵战杀获甚众赏钞币有差'" 78 | ] 79 | }, 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "remove_all_punkt(\"亳州水军千户胡进等领骑兵渡淝水,逾荆山,与宋兵战,杀获甚众,赏钞币有差。\")" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "## Config" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "DATA = Path(sys_loc('DATA')/\"nlp\"/\"zh\"/\"cc_vs_zh\")\n", 103 | "TO_CLASSICAL = False" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "## Download data" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "ZbXuwqr0KEr8" 117 | }, 118 | "source": [ 119 | "## Data" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "### Combine data" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 6, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "all_file = list(DATA.rglob(\"data/*\"))" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 7, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "def open_file_to_lines(file):\n", 145 | " with open(file) as f:\n", 146 | " lines = f.read().splitlines()\n", 147 | " return lines\n", 148 | "\n", 149 | "def pairing_the_file(files,kw):\n", 150 | " pairs = []\n", 151 | " for file in files:\n", 152 | " if kw not in file.name:\n", 153 | " file1 = file\n", 154 | " file2 = f\"{file}{kw}\"\n", 155 | " pairs.append((file1,file2))\n", 156 | " return pairs" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 8, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "pairs = pairing_the_file(all_file,\"翻译\")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 9, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "def open_pairs(pairs):\n", 175 | " chunks = []\n", 176 | " for pair in tqdm(pairs, leave=False):\n", 177 | " file1,file2 = pair\n", 178 | " lines1 = open_file_to_lines(file1)\n", 179 | " lines2 = open_file_to_lines(file2)\n", 180 | " chunks.append(pd.DataFrame({\"classical\":lines1,\"modern\":lines2}))\n", 181 | " return pd.concat(chunks).sample(frac=1.).reset_index(drop=True)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 10, 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "data": { 191 | "application/vnd.jupyter.widget-view+json": { 192 | "model_id": "", 193 | "version_major": 2, 194 | "version_minor": 0 195 | }, 196 | "text/plain": [ 197 | " 0%| | 0/27 [00:00\n", 230 | "\n", 243 | "\n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | "
|   | source | target |
|---|--------|--------|
| 0 | 谏议大夫宁原悌上言:以为先朝悖逆庶人以爱女骄盈而及祸,新城、宜都以庶孽抑损而获全。 | 谏议大夫宁原悌向唐睿宗进言认为:先朝悖逆庶人作为中宗和韦后的爱女而骄傲自满,终于难逃杀身之祸... |
| 1 | 意等漏卮,江河无以充其溢。 | 思想像渗漏的酒器,长江、黄河无法来填满他的欲壑。 |
| 2 | 琥珀太多,及差,痕不灭,左颊有赤点如痣。 | 因琥珀用得过多,到伤愈时,邓夫人左颊疤疮没有完全去掉,脸上留下一颗象痣一样的红点。 |
| 3 | 督军疾进,师至阴山,遇其斥候千余帐,皆俘以随军。 | 于是督军疾进,军队行进到阴山,遇到颉利可汗的哨兵千余帐,把他们全部俘获,并押着他们随军行动。 |
| 4 | 莽曰夕阴。 | 王莽时叫夕阴县。 |
\n", 279 | "" 280 | ], 281 | "text/plain": [ 282 | " source \\\n", 283 | "0 谏议大夫宁原悌上言:以为先朝悖逆庶人以爱女骄盈而及祸,新城、宜都以庶孽抑损而获全。 \n", 284 | "1 意等漏卮,江河无以充其溢。 \n", 285 | "2 琥珀太多,及差,痕不灭,左颊有赤点如痣。 \n", 286 | "3 督军疾进,师至阴山,遇其斥候千余帐,皆俘以随军。 \n", 287 | "4 莽曰夕阴。 \n", 288 | "\n", 289 | " target \n", 290 | "0 谏议大夫宁原悌向唐睿宗进言认为:先朝悖逆庶人作为中宗和韦后的爱女而骄傲自满,终于难逃杀身之祸... \n", 291 | "1 思想像渗漏的酒器,长江、黄河无法来填满他的欲壑。 \n", 292 | "2 因琥珀用得过多,到伤愈时,邓夫人左颊疤疮没有完全去掉,脸上留下一颗象痣一样的红点。 \n", 293 | "3 于是督军疾进,军队行进到阴山,遇到颉利可汗的哨兵千余帐,把他们全部俘获,并押着他们随军行动。 \n", 294 | "4 王莽时叫夕阴县。 " 295 | ] 296 | }, 297 | "execution_count": 12, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "df.head()" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": {}, 309 | "source": [ 310 | "### Loading tokenizer" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 13, 316 | "metadata": { 317 | "id": "ukyVGg8HmSd-" 318 | }, 319 | "outputs": [], 320 | "source": [ 321 | "from transformers import (\n", 322 | " AutoTokenizer,\n", 323 | " AutoModelForCausalLM,\n", 324 | " AutoModel,\n", 325 | " EncoderDecoderModel\n", 326 | " )\n", 327 | "\n", 328 | "# we find a English parsing encoder, as a pretrained bert is good at understanding english\n", 329 | "# BERT is short for Bidirectional **Encoder** Representations from Transformers, which consists fully of encoder blocks\n", 330 | "ENCODER_PRETRAINED = \"bert-base-chinese\"\n", 331 | "# we find a Chinese writing model for decoder, as decoder is the part of the model that can write stuff\n", 332 | "DECODER_PRETRAINED = \"uer/gpt2-chinese-poem\"\n", 333 | "\n", 334 | "encoder_tokenizer = AutoTokenizer.from_pretrained(ENCODER_PRETRAINED)\n", 335 | "\n", 336 | "decoder_tokenizer = AutoTokenizer.from_pretrained(\n", 337 | " ENCODER_PRETRAINED # notice we use the BERT's tokenizer here\n", 338 | ")" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": {}, 344 | "source": [ 345 | "### Pytoch Dataset" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 14, 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [ 354 | "class Seq2Seq(Dataset):\n", 355 | " def __init__(\n", 356 | " self, df, tokenizer, target_tokenizer,\n", 357 | " max_len=128,\n", 358 | " no_punkt:bool = False,\n", 359 | " ):\n", 360 | " \"\"\"\n", 361 | " no_punkt, do we ramdomly remove punctuation\n", 362 | " from source sentence\n", 363 | " \"\"\"\n", 364 | " super().__init__()\n", 365 | " self.df = df\n", 366 | " self.tokenizer = tokenizer\n", 367 | " self.target_tokenizer = target_tokenizer\n", 368 | " self.max_len = max_len\n", 369 | " self.no_punkt = no_punkt\n", 370 | " \n", 371 | " def __len__(self, ):\n", 372 | " return len(self.df)\n", 373 | "\n", 374 | " def __getitem__(self, idx):\n", 375 | " return dict(self.df.iloc[idx])\n", 376 | "\n", 377 | " def collate(self, batch):\n", 378 | " batch_df = pd.DataFrame(list(batch))\n", 379 | " x, y = batch_df.source, batch_df.target\n", 380 | " # there is a random no punctuation mode\n", 381 | " # for source text\n", 382 | " # as some of the classical text we get\n", 383 | " # might be whole chunk of paragraph without\n", 384 | " # any punctuation\n", 385 | " if self.no_punkt:\n", 386 | " x = list(i if random.random()>.5\n", 387 | " else remove_all_punkt(i)\n", 388 | " for i in x)\n", 389 | " else:\n", 390 | " x = list(x)\n", 391 | " x_batch = self.tokenizer(\n", 392 | " x,\n", 393 | " max_length=self.max_len,\n", 394 | " padding='max_length',\n", 395 | " truncation=True,\n", 396 | " 
return_tensors='pt',\n", 397 | " )\n", 398 | " y_batch = self.target_tokenizer(\n", 399 | " list(y),\n", 400 | " max_length=self.max_len,\n", 401 | " padding='max_length',\n", 402 | " truncation=True,\n", 403 | " return_tensors='pt',\n", 404 | " )\n", 405 | " x_batch['decoder_input_ids'] = y_batch['input_ids']\n", 406 | " x_batch['labels'] = y_batch['input_ids'].clone()\n", 407 | " x_batch['labels'][x_batch['labels'] == self.tokenizer.pad_token_id] = -100\n", 408 | " return x_batch\n", 409 | "\n", 410 | " def dataloader(self, batch_size, shuffle=True):\n", 411 | " return DataLoader(\n", 412 | " self,\n", 413 | " batch_size=batch_size,\n", 414 | " shuffle=shuffle,\n", 415 | " collate_fn=self.collate,\n", 416 | " )\n", 417 | "\n", 418 | " def split_train_valid(self, valid_size=0.1):\n", 419 | " split_index = int(len(self) * (1 - valid_size))\n", 420 | " cls = type(self)\n", 421 | " shuffled = self.df.sample(frac=1).reset_index(drop=True)\n", 422 | " train_set = cls(\n", 423 | " shuffled.iloc[:split_index],\n", 424 | " tokenizer=self.tokenizer,\n", 425 | " target_tokenizer=self.target_tokenizer,\n", 426 | " max_len=self.max_len,\n", 427 | " no_punkt=self.no_punkt,\n", 428 | " )\n", 429 | " valid_set = cls(\n", 430 | " shuffled.iloc[split_index:],\n", 431 | " tokenizer=self.tokenizer,\n", 432 | " target_tokenizer=self.target_tokenizer,\n", 433 | " max_len=self.max_len,\n", 434 | " no_punkt=self.no_punkt,\n", 435 | " )\n", 436 | " return train_set, valid_set" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": {}, 442 | "source": [ 443 | "### PL datamodule" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 15, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "class Seq2SeqData(pl.LightningDataModule):\n", 453 | " def __init__(\n", 454 | " self, df,\n", 455 | " tokenizer,\n", 456 | " target_tokenizer,\n", 457 | " batch_size=12,\n", 458 | " max_len=128,\n", 459 | " no_punkt:bool=False):\n", 460 | " super().__init__()\n", 461 | " self.df = df\n", 462 | " self.ds = Seq2Seq(df,\n", 463 | " tokenizer,\n", 464 | " target_tokenizer,\n", 465 | " max_len=max_len,\n", 466 | " no_punkt=no_punkt)\n", 467 | " self.tokenizer = tokenizer\n", 468 | " self.target_tokenizer = target_tokenizer\n", 469 | " self.max_len = max_len\n", 470 | " self.batch_size = batch_size\n", 471 | "\n", 472 | " def setup(self, stage=None):\n", 473 | " self.train_set, self.valid_set = self.ds.split_train_valid()\n", 474 | "\n", 475 | " def train_dataloader(self):\n", 476 | " return self.train_set.dataloader(\n", 477 | " batch_size=self.batch_size, shuffle=True)\n", 478 | "\n", 479 | " def val_dataloader(self):\n", 480 | " return self.valid_set.dataloader(\n", 481 | " batch_size=self.batch_size*2, shuffle=False)" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": 16, 487 | "metadata": {}, 488 | "outputs": [], 489 | "source": [ 490 | "data_module = Seq2SeqData(\n", 491 | " df, encoder_tokenizer,\n", 492 | " decoder_tokenizer,\n", 493 | " batch_size=28,\n", 494 | " max_len=256,\n", 495 | " no_punkt=False if TO_CLASSICAL else True,)\n", 496 | "data_module.setup()" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 17, 502 | "metadata": {}, 503 | "outputs": [ 504 | { 505 | "data": { 506 | "text/plain": [ 507 | "{'input_ids': tensor([[ 101, 1921, 5688, ..., 0, 0, 0],\n", 508 | " [ 101, 2828, 1062, ..., 0, 0, 0],\n", 509 | " [ 101, 1039, 1469, ..., 0, 0, 0],\n", 510 | " ...,\n", 511 | " [ 101, 718, 886, ..., 0, 0, 0],\n", 
512 | " [ 101, 1071, 1095, ..., 0, 0, 0],\n", 513 | " [ 101, 1062, 1920, ..., 0, 0, 0]]), 'token_type_ids': tensor([[0, 0, 0, ..., 0, 0, 0],\n", 514 | " [0, 0, 0, ..., 0, 0, 0],\n", 515 | " [0, 0, 0, ..., 0, 0, 0],\n", 516 | " ...,\n", 517 | " [0, 0, 0, ..., 0, 0, 0],\n", 518 | " [0, 0, 0, ..., 0, 0, 0],\n", 519 | " [0, 0, 0, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n", 520 | " [1, 1, 1, ..., 0, 0, 0],\n", 521 | " [1, 1, 1, ..., 0, 0, 0],\n", 522 | " ...,\n", 523 | " [1, 1, 1, ..., 0, 0, 0],\n", 524 | " [1, 1, 1, ..., 0, 0, 0],\n", 525 | " [1, 1, 1, ..., 0, 0, 0]]), 'decoder_input_ids': tensor([[ 101, 1921, 5688, ..., 0, 0, 0],\n", 526 | " [ 101, 1315, 752, ..., 0, 0, 0],\n", 527 | " [ 101, 1039, 1469, ..., 0, 0, 0],\n", 528 | " ...,\n", 529 | " [ 101, 3727, 4374, ..., 0, 0, 0],\n", 530 | " [ 101, 7342, 5093, ..., 0, 0, 0],\n", 531 | " [ 101, 1155, 4893, ..., 0, 0, 0]]), 'labels': tensor([[ 101, 1921, 5688, ..., -100, -100, -100],\n", 532 | " [ 101, 1315, 752, ..., -100, -100, -100],\n", 533 | " [ 101, 1039, 1469, ..., -100, -100, -100],\n", 534 | " ...,\n", 535 | " [ 101, 3727, 4374, ..., -100, -100, -100],\n", 536 | " [ 101, 7342, 5093, ..., -100, -100, -100],\n", 537 | " [ 101, 1155, 4893, ..., -100, -100, -100]])}" 538 | ] 539 | }, 540 | "execution_count": 17, 541 | "metadata": {}, 542 | "output_type": "execute_result" 543 | } 544 | ], 545 | "source": [ 546 | "inputs = next(iter(data_module.train_dataloader()))\n", 547 | "inputs" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": {}, 553 | "source": [ 554 | "if we are doing clasical Chinese to modern Chinese, we can randomly set half of the input without any punctuation, as many data source might be" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 18, 560 | "metadata": {}, 561 | "outputs": [ 562 | { 563 | "data": { 564 | "text/plain": [ 565 | "['天 节 八 星 , 在 毕 、 附 耳 南 , 主 使 臣 持 节 宣 威 四 方 。',\n", 566 | " '把 公 子 成 的 话 报 告 给 赵 武 灵 王 。 武 灵 王 说 : 我 就 知 道 王 叔 反 对 这 件 事 。 于 是 马 上 就 去 公 子 成 家 里 , 亲 自 向 他 阐 述 自 己 的 观 点 : 大 凡 衣 服 是 为 了 便 于 穿 用 , 礼 制 是 为 了 便 于 办 事 。',\n", 567 | " '元 和 五 年 已 前 租 赋 并 放 。',\n", 568 | " '凡 杀 三 人 , 伤 五 人 , 手 驱 郎 吏 二 十 余 人 。',\n", 569 | " '杨 石 二 少 年 为 民 害 简 置 狱 中 谕 以 祸 福 咸 感 悟 愿 自 赎',\n", 570 | " '辛 亥 诸 将 自 汉 口 开 坝 引 船 入 沦 河 先 遣 万 户 阿 剌 罕 以 兵 拒 沙 芜 口 逼 近 武 矶 巡 视 阳 罗 城 堡 径 趋 沙 芜 遂 入 大 江',\n", 571 | " '江 东 民 户 殷 盛 风 俗 峻 刻 强 弱 相 陵 奸 吏 蜂 起 符 书 一 下 文 摄 相 续',\n", 572 | " '昏 夜 , 平 善 , 乡 晨 , 傅 绔 袜 欲 起 , 因 失 衣 , 不 能 言 , 昼 漏 上 十 刻 而 崩 。',\n", 573 | " '子 十 三 篇',\n", 574 | " '扶 风 民 鲁 悉 达 , 纠 合 乡 人 以 保 新 蔡 , 力 田 蓄 谷 。',\n", 575 | " '明 年 , 又 贬 武 安 军 节 度 副 使 、 永 州 安 置 。',\n", 576 | " '良 久 徐 曰 恬 罪 故 当 死 矣',\n", 577 | " '部 曲 将 田 泓 请 没 水 潜 行 趣 彭 城 , 玄 遣 之 。',\n", 578 | " '必 久 停 留 , 恐 非 天 意 也 。',\n", 579 | " '具 传 其 业 又 默 讲 论 义 理 五 经 诸 子 无 不 该 览 加 博 好 技 艺 算 术 卜 数 医 药 弓 弩 机 械 之 巧 皆 致 思 焉',\n", 580 | " '苏 秦 初 合 纵 至 燕',\n", 581 | " '讼 者 言 词 忿 争 理 无 所 屈',\n", 582 | " '高 祖 闻 之 , 曰 : 二 将 和 , 师 必 济 矣 。',\n", 583 | " '谧 兄 谌 字 兴 伯 性 平 和',\n", 584 | " '平 受 诏 , 立 复 驰 至 宫 , 哭 殊 悲 ; 因 固 请 得 宿 卫 中 。',\n", 585 | " '属 淮 阴 , 击 破 齐 历 下 军 , 击 田 解 。',\n", 586 | " '惇 与 蔡 卞 将 必 置 之 死 , 因 使 者 入 海 岛 诛 陈 衍 , 讽 使 者 过 安 世 , 胁 使 自 裁 。',\n", 587 | " '是 后 , 将 士 功 赏 视 立 功 之 地 , 准 例 奏 行 。',\n", 588 | " '左 右 欲 兵 之 。 太 公 曰 : 此 义 人 也 。',\n", 589 | " '僧 知 是 非 常 人 顶 礼 忏 悔 授 书 与 之',\n", 590 | " '乃 使 良 还',\n", 591 | " '其 冢 人 祠 之 不 绝',\n", 592 | " '公 大 怒 , 揖 出 之 。']" 593 | ] 594 | }, 595 | "execution_count": 18, 596 | "metadata": {}, 597 | "output_type": "execute_result" 598 | } 599 | ], 600 | 
"source": [ 601 | "encoder_tokenizer.batch_decode(\n", 602 | " inputs.input_ids,skip_special_tokens=True\n", 603 | ")" 604 | ] 605 | }, 606 | { 607 | "cell_type": "markdown", 608 | "metadata": { 609 | "id": "92iwRu6Oqbzb" 610 | }, 611 | "source": [ 612 | "### Load pretrained models" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": 19, 618 | "metadata": { 619 | "colab": { 620 | "base_uri": "https://localhost:8080/" 621 | }, 622 | "id": "gZkPxJVTm8Ng", 623 | "outputId": "dcecf16e-22fe-4c25-9ffb-aae9d75785f3" 624 | }, 625 | "outputs": [], 626 | "source": [ 627 | "# encoder = AutoModel.from_pretrained(ENCODER_PRETRAINED, proxies={\"http\":\"bifrost:3128\"})\n", 628 | "# decoder = AutoModelForCausalLM.from_pretrained(DECODER_PRETRAINED, add_cross_attention=True,\n", 629 | "# proxies={\"http\":\"bifrost:3128\"})" 630 | ] 631 | }, 632 | { 633 | "cell_type": "markdown", 634 | "metadata": { 635 | "id": "pajv5ridLamp" 636 | }, 637 | "source": [ 638 | "## Model" 639 | ] 640 | }, 641 | { 642 | "cell_type": "markdown", 643 | "metadata": { 644 | "id": "s1zqJXDsCUw-" 645 | }, 646 | "source": [ 647 | "We create a seq2seq model by using pretrained encoder + pretrained decoder" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 20, 653 | "metadata": {}, 654 | "outputs": [ 655 | { 656 | "name": "stderr", 657 | "output_type": "stream", 658 | "text": [ 659 | "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']\n", 660 | "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 661 | "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 662 | "Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at uer/gpt2-chinese-poem and are newly initialized: ['transformer.h.6.crossattention.bias', 'transformer.h.4.crossattention.bias', 'transformer.h.6.crossattention.masked_bias', 'transformer.h.6.crossattention.c_attn.weight', 'transformer.h.2.crossattention.bias', 'transformer.h.3.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.10.crossattention.c_proj.weight', 'transformer.h.10.crossattention.q_attn.weight', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.9.crossattention.c_proj.weight', 'transformer.h.9.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.2.crossattention.c_attn.weight', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.11.crossattention.bias', 'transformer.h.4.ln_cross_attn.weight', 'transformer.h.2.crossattention.c_proj.weight', 'transformer.h.3.crossattention.q_attn.weight', 'transformer.h.9.crossattention.bias', 'transformer.h.4.crossattention.masked_bias', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.9.ln_cross_attn.weight', 'transformer.h.5.crossattention.c_proj.weight', 'transformer.h.3.crossattention.bias', 'transformer.h.9.crossattention.c_attn.weight', 'transformer.h.1.crossattention.masked_bias', 'transformer.h.8.crossattention.c_proj.bias', 'transformer.h.7.crossattention.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.5.crossattention.c_attn.weight', 'transformer.h.7.crossattention.c_proj.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.5.crossattention.masked_bias', 'transformer.h.7.crossattention.c_proj.weight', 'transformer.h.5.crossattention.bias', 'transformer.h.7.ln_cross_attn.weight', 'transformer.h.11.crossattention.c_proj.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.9.crossattention.masked_bias', 'transformer.h.11.crossattention.q_attn.weight', 'transformer.h.1.crossattention.bias', 'transformer.h.7.crossattention.c_attn.weight', 'transformer.h.10.crossattention.masked_bias', 'transformer.h.3.crossattention.c_attn.weight', 'transformer.h.2.crossattention.c_proj.bias', 'transformer.h.4.crossattention.q_attn.weight', 'transformer.h.6.ln_cross_attn.weight', 'transformer.h.10.ln_cross_attn.weight', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.5.ln_cross_attn.weight', 'transformer.h.10.crossattention.bias', 'transformer.h.5.crossattention.q_attn.weight', 'transformer.h.6.crossattention.c_proj.weight', 'transformer.h.10.crossattention.c_proj.bias', 'transformer.h.11.crossattention.masked_bias', 'transformer.h.6.crossattention.c_proj.bias', 'transformer.h.8.ln_cross_attn.weight', 'transformer.h.4.crossattention.c_proj.bias', 'transformer.h.4.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.3.crossattention.c_proj.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.bias', 'transformer.h.8.crossattention.bias', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.7.crossattention.q_attn.weight', 'transformer.h.11.crossattention.c_attn.weight', 
'transformer.h.9.crossattention.q_attn.weight', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.11.crossattention.c_proj.weight', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.6.crossattention.q_attn.weight', 'transformer.h.8.crossattention.masked_bias', 'transformer.h.8.crossattention.q_attn.weight', 'transformer.h.2.crossattention.masked_bias', 'transformer.h.8.crossattention.c_attn.weight', 'transformer.h.7.crossattention.masked_bias', 'transformer.h.5.crossattention.c_proj.bias', 'transformer.h.3.crossattention.masked_bias', 'transformer.h.0.crossattention.masked_bias', 'transformer.h.2.ln_cross_attn.weight', 'transformer.h.8.crossattention.c_proj.weight', 'transformer.h.11.ln_cross_attn.weight']\n", 663 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 664 | ] 665 | } 666 | ], 667 | "source": [ 668 | "# loading pretrained model\n", 669 | "encoder_decoder = EncoderDecoderModel.from_encoder_decoder_pretrained(\n", 670 | " encoder_pretrained_model_name_or_path=ENCODER_PRETRAINED,\n", 671 | " decoder_pretrained_model_name_or_path=DECODER_PRETRAINED,\n", 672 | ")" 673 | ] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "execution_count": 21, 678 | "metadata": { 679 | "id": "jBVyNeKUv6FU" 680 | }, 681 | "outputs": [], 682 | "source": [ 683 | "class Seq2SeqTrain(pl.LightningModule):\n", 684 | " def __init__(self, encoder_decoder):\n", 685 | " super().__init__()\n", 686 | " self.encoder_decoder = encoder_decoder\n", 687 | " \n", 688 | " def forward(self, batch):\n", 689 | " return self.encoder_decoder(\n", 690 | " **batch\n", 691 | " )\n", 692 | "\n", 693 | " def training_step(self, batch, batch_idx):\n", 694 | " outputs = self(batch)\n", 695 | " self.log('loss', outputs.loss)\n", 696 | " return outputs.loss\n", 697 | "\n", 698 | " def validation_step(self, batch, batch_idx):\n", 699 | " outputs = self(batch)\n", 700 | " self.log('val_loss', outputs.loss)\n", 701 | " return outputs.loss\n", 702 | " \n", 703 | " def configure_optimizers(self):\n", 704 | " encoder_params = list(\n", 705 | " {\"params\":param,\"lr\":1e-5}\n", 706 | " for param in self.encoder_decoder.encoder.embeddings.parameters()) +\\\n", 707 | " list({\"params\":param,\"lr\":1e-5}\n", 708 | " for param in self.encoder_decoder.encoder.encoder.parameters()) +\\\n", 709 | " list({\"params\":param,\"lr\":1e-3}\n", 710 | " for param in self.encoder_decoder.encoder.pooler.parameters())\n", 711 | "\n", 712 | " decoder_params = list()\n", 713 | " for name, param in self.encoder_decoder.decoder.named_parameters():\n", 714 | " if 'ln_cross_attn' in name:\n", 715 | " decoder_params.append({\"params\":param,\"lr\":1e-3})\n", 716 | " elif 'crossattention' in name:\n", 717 | " decoder_params.append({\"params\":param,\"lr\":1e-3})\n", 718 | " elif 'lm_head' in name:\n", 719 | " decoder_params.append({\"params\":param,\"lr\":1e-4})\n", 720 | " else:\n", 721 | " decoder_params.append({\"params\":param,\"lr\":1e-5})\n", 722 | "\n", 723 | " return torch.optim.Adam(\n", 724 | " encoder_params + decoder_params,\n", 725 | " lr=1e-3,\n", 726 | " )" 727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": 22, 732 | "metadata": { 733 | "id": "5uIjcPuXw0Fr" 734 | }, 735 | "outputs": [], 736 | "source": [ 737 | "module = Seq2SeqTrain(encoder_decoder)" 738 | ] 739 | }, 740 | { 741 | "cell_type": "markdown", 742 | "metadata": { 743 | "id": "DBf3NTKSLcUb" 744 | }, 745 | "source": [ 746 | "## Training" 747 | ] 748 | }, 749 | { 750 | "cell_type": 
"code", 751 | "execution_count": 23, 752 | "metadata": {}, 753 | "outputs": [ 754 | { 755 | "name": "stderr", 756 | "output_type": "stream", 757 | "text": [ 758 | "GPU available: True, used: True\n", 759 | "TPU available: False, using: 0 TPU cores\n" 760 | ] 761 | } 762 | ], 763 | "source": [ 764 | "save = pl.callbacks.ModelCheckpoint(\n", 765 | " '/GCI/transformers/weights/cc_to_zh',\n", 766 | " save_top_k=2,\n", 767 | " verbose=True,\n", 768 | " monitor='val_loss',\n", 769 | " mode='min',\n", 770 | ")\n", 771 | "\n", 772 | "trainer = pl.Trainer(\n", 773 | " gpus=[1],\n", 774 | " max_epochs=10,\n", 775 | " callbacks=[save],\n", 776 | ")" 777 | ] 778 | }, 779 | { 780 | "cell_type": "code", 781 | "execution_count": 24, 782 | "metadata": {}, 783 | "outputs": [ 784 | { 785 | "name": "stderr", 786 | "output_type": "stream", 787 | "text": [ 788 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", 789 | "\n", 790 | " | Name | Type | Params\n", 791 | "--------------------------------------------------------\n", 792 | "0 | encoder_decoder | EncoderDecoderModel | 233 M \n", 793 | "--------------------------------------------------------\n", 794 | "233 M Trainable params\n", 795 | "0 Non-trainable params\n", 796 | "233 M Total params\n", 797 | "935.203 Total estimated model params size (MB)\n" 798 | ] 799 | }, 800 | { 801 | "data": { 802 | "application/vnd.jupyter.widget-view+json": { 803 | "model_id": "", 804 | "version_major": 2, 805 | "version_minor": 0 806 | }, 807 | "text/plain": [ 808 | "Validation sanity check: 0it [00:00, ?it/s]" 809 | ] 810 | }, 811 | "metadata": {}, 812 | "output_type": "display_data" 813 | }, 814 | { 815 | "name": "stderr", 816 | "output_type": "stream", 817 | "text": [ 818 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:103: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 819 | " f'The dataloader, {name}, does not have many workers which may be a bottleneck.'\n", 820 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:103: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 821 | " f'The dataloader, {name}, does not have many workers which may be a bottleneck.'\n" 822 | ] 823 | }, 824 | { 825 | "data": { 826 | "application/vnd.jupyter.widget-view+json": { 827 | "model_id": "41e24592063e457da4259fab0911e194", 828 | "version_major": 2, 829 | "version_minor": 0 830 | }, 831 | "text/plain": [ 832 | "Training: 0it [00:00, ?it/s]" 833 | ] 834 | }, 835 | "metadata": {}, 836 | "output_type": "display_data" 837 | }, 838 | { 839 | "name": "stderr", 840 | "output_type": "stream", 841 | "text": [ 842 | "/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py:897: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n", 843 | " rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')\n" 844 | ] 845 | } 846 | ], 847 | "source": [ 848 | "trainer.fit(module, datamodule=data_module)" 849 | ] 850 | } 851 | ], 852 | "metadata": { 853 | "accelerator": "GPU", 854 | "colab": { 855 | "collapsed_sections": [], 856 | "name": "seq2seq.ipynb", 857 | "provenance": [] 858 | }, 859 | "kernelspec": { 860 | "display_name": "Python 3", 861 | "language": "python", 862 | "name": "python3" 863 | }, 864 | "language_info": { 865 | "codemirror_mode": { 866 | "name": "ipython", 867 | "version": 3 868 | }, 869 | "file_extension": ".py", 870 | "mimetype": "text/x-python", 871 | "name": "python", 872 | "nbconvert_exporter": "python", 873 | "pygments_lexer": "ipython3", 874 | "version": "3.7.4" 875 | }, 876 | "toc": { 877 | "base_numbering": 1, 878 | "nav_menu": {}, 879 | "number_sections": true, 880 | "sideBar": true, 881 | "skip_h1_title": false, 882 | "title_cell": "Table of Contents", 883 | "title_sidebar": "Contents", 884 | "toc_cell": false, 885 | "toc_position": {}, 886 | "toc_section_display": true, 887 | "toc_window_display": false 888 | }, 889 | "widgets": { 890 | "application/vnd.jupyter.widget-state+json": { 891 | "579055f403bf4594a2c665adfdfb8995": { 892 | "model_module": "@jupyter-widgets/controls", 893 | "model_module_version": "1.5.0", 894 | "model_name": "DescriptionStyleModel", 895 | "state": { 896 | "_model_module": "@jupyter-widgets/controls", 897 | "_model_module_version": "1.5.0", 898 | "_model_name": "DescriptionStyleModel", 899 | "_view_count": null, 900 | "_view_module": "@jupyter-widgets/base", 901 | "_view_module_version": "1.2.0", 902 | "_view_name": "StyleView", 903 | "description_width": "" 904 | } 905 | }, 906 | "659eee19636c45da881d243f66aedf27": { 907 | "model_module": "@jupyter-widgets/base", 908 | "model_module_version": "1.2.0", 909 | "model_name": "LayoutModel", 910 | "state": { 911 | "_model_module": "@jupyter-widgets/base", 912 | "_model_module_version": "1.2.0", 913 | "_model_name": "LayoutModel", 914 | "_view_count": null, 915 | "_view_module": "@jupyter-widgets/base", 916 | "_view_module_version": "1.2.0", 917 | "_view_name": "LayoutView", 918 | "align_content": null, 919 | "align_items": null, 920 | "align_self": null, 921 | "border": null, 922 | "bottom": null, 923 | "display": null, 924 | "flex": null, 925 | "flex_flow": null, 926 | "grid_area": null, 927 | "grid_auto_columns": null, 928 | "grid_auto_flow": null, 929 | "grid_auto_rows": null, 930 | "grid_column": null, 931 | "grid_gap": null, 932 | "grid_row": null, 933 | "grid_template_areas": null, 934 | "grid_template_columns": null, 935 | "grid_template_rows": null, 936 | 
"height": null, 937 | "justify_content": null, 938 | "justify_items": null, 939 | "left": null, 940 | "margin": null, 941 | "max_height": null, 942 | "max_width": null, 943 | "min_height": null, 944 | "min_width": null, 945 | "object_fit": null, 946 | "object_position": null, 947 | "order": null, 948 | "overflow": null, 949 | "overflow_x": null, 950 | "overflow_y": null, 951 | "padding": null, 952 | "right": null, 953 | "top": null, 954 | "visibility": null, 955 | "width": null 956 | } 957 | }, 958 | "94111dfb9a2d4a4f93e00bdb34c70090": { 959 | "model_module": "@jupyter-widgets/controls", 960 | "model_module_version": "1.5.0", 961 | "model_name": "HBoxModel", 962 | "state": { 963 | "_dom_classes": [], 964 | "_model_module": "@jupyter-widgets/controls", 965 | "_model_module_version": "1.5.0", 966 | "_model_name": "HBoxModel", 967 | "_view_count": null, 968 | "_view_module": "@jupyter-widgets/controls", 969 | "_view_module_version": "1.5.0", 970 | "_view_name": "HBoxView", 971 | "box_style": "", 972 | "children": [ 973 | "IPY_MODEL_ff2e02e62d0b438cac9f521da8c0d5eb", 974 | "IPY_MODEL_fc66ee3afa1944beb42494efbb1301ac", 975 | "IPY_MODEL_9ac2a1e65c084bca8cdff9f1dc7541e0" 976 | ], 977 | "layout": "IPY_MODEL_659eee19636c45da881d243f66aedf27" 978 | } 979 | }, 980 | "94765776469249ea94eee8ccf64c47e7": { 981 | "model_module": "@jupyter-widgets/base", 982 | "model_module_version": "1.2.0", 983 | "model_name": "LayoutModel", 984 | "state": { 985 | "_model_module": "@jupyter-widgets/base", 986 | "_model_module_version": "1.2.0", 987 | "_model_name": "LayoutModel", 988 | "_view_count": null, 989 | "_view_module": "@jupyter-widgets/base", 990 | "_view_module_version": "1.2.0", 991 | "_view_name": "LayoutView", 992 | "align_content": null, 993 | "align_items": null, 994 | "align_self": null, 995 | "border": null, 996 | "bottom": null, 997 | "display": null, 998 | "flex": null, 999 | "flex_flow": null, 1000 | "grid_area": null, 1001 | "grid_auto_columns": null, 1002 | "grid_auto_flow": null, 1003 | "grid_auto_rows": null, 1004 | "grid_column": null, 1005 | "grid_gap": null, 1006 | "grid_row": null, 1007 | "grid_template_areas": null, 1008 | "grid_template_columns": null, 1009 | "grid_template_rows": null, 1010 | "height": null, 1011 | "justify_content": null, 1012 | "justify_items": null, 1013 | "left": null, 1014 | "margin": null, 1015 | "max_height": null, 1016 | "max_width": null, 1017 | "min_height": null, 1018 | "min_width": null, 1019 | "object_fit": null, 1020 | "object_position": null, 1021 | "order": null, 1022 | "overflow": null, 1023 | "overflow_x": null, 1024 | "overflow_y": null, 1025 | "padding": null, 1026 | "right": null, 1027 | "top": null, 1028 | "visibility": null, 1029 | "width": null 1030 | } 1031 | }, 1032 | "9ac2a1e65c084bca8cdff9f1dc7541e0": { 1033 | "model_module": "@jupyter-widgets/controls", 1034 | "model_module_version": "1.5.0", 1035 | "model_name": "HTMLModel", 1036 | "state": { 1037 | "_dom_classes": [], 1038 | "_model_module": "@jupyter-widgets/controls", 1039 | "_model_module_version": "1.5.0", 1040 | "_model_name": "HTMLModel", 1041 | "_view_count": null, 1042 | "_view_module": "@jupyter-widgets/controls", 1043 | "_view_module_version": "1.5.0", 1044 | "_view_name": "HTMLView", 1045 | "description": "", 1046 | "description_tooltip": null, 1047 | "layout": "IPY_MODEL_94765776469249ea94eee8ccf64c47e7", 1048 | "placeholder": "​", 1049 | "style": "IPY_MODEL_579055f403bf4594a2c665adfdfb8995", 1050 | "value": " 1/1 [00:00<00:00, 21.65it/s]" 1051 | } 1052 | }, 1053 | 
"b3486fd1f15b43068e47df0ad6a81559": { 1054 | "model_module": "@jupyter-widgets/base", 1055 | "model_module_version": "1.2.0", 1056 | "model_name": "LayoutModel", 1057 | "state": { 1058 | "_model_module": "@jupyter-widgets/base", 1059 | "_model_module_version": "1.2.0", 1060 | "_model_name": "LayoutModel", 1061 | "_view_count": null, 1062 | "_view_module": "@jupyter-widgets/base", 1063 | "_view_module_version": "1.2.0", 1064 | "_view_name": "LayoutView", 1065 | "align_content": null, 1066 | "align_items": null, 1067 | "align_self": null, 1068 | "border": null, 1069 | "bottom": null, 1070 | "display": null, 1071 | "flex": null, 1072 | "flex_flow": null, 1073 | "grid_area": null, 1074 | "grid_auto_columns": null, 1075 | "grid_auto_flow": null, 1076 | "grid_auto_rows": null, 1077 | "grid_column": null, 1078 | "grid_gap": null, 1079 | "grid_row": null, 1080 | "grid_template_areas": null, 1081 | "grid_template_columns": null, 1082 | "grid_template_rows": null, 1083 | "height": null, 1084 | "justify_content": null, 1085 | "justify_items": null, 1086 | "left": null, 1087 | "margin": null, 1088 | "max_height": null, 1089 | "max_width": null, 1090 | "min_height": null, 1091 | "min_width": null, 1092 | "object_fit": null, 1093 | "object_position": null, 1094 | "order": null, 1095 | "overflow": null, 1096 | "overflow_x": null, 1097 | "overflow_y": null, 1098 | "padding": null, 1099 | "right": null, 1100 | "top": null, 1101 | "visibility": null, 1102 | "width": null 1103 | } 1104 | }, 1105 | "bb2e04bed86047b0b3a4e587cfb48ef0": { 1106 | "model_module": "@jupyter-widgets/base", 1107 | "model_module_version": "1.2.0", 1108 | "model_name": "LayoutModel", 1109 | "state": { 1110 | "_model_module": "@jupyter-widgets/base", 1111 | "_model_module_version": "1.2.0", 1112 | "_model_name": "LayoutModel", 1113 | "_view_count": null, 1114 | "_view_module": "@jupyter-widgets/base", 1115 | "_view_module_version": "1.2.0", 1116 | "_view_name": "LayoutView", 1117 | "align_content": null, 1118 | "align_items": null, 1119 | "align_self": null, 1120 | "border": null, 1121 | "bottom": null, 1122 | "display": null, 1123 | "flex": null, 1124 | "flex_flow": null, 1125 | "grid_area": null, 1126 | "grid_auto_columns": null, 1127 | "grid_auto_flow": null, 1128 | "grid_auto_rows": null, 1129 | "grid_column": null, 1130 | "grid_gap": null, 1131 | "grid_row": null, 1132 | "grid_template_areas": null, 1133 | "grid_template_columns": null, 1134 | "grid_template_rows": null, 1135 | "height": null, 1136 | "justify_content": null, 1137 | "justify_items": null, 1138 | "left": null, 1139 | "margin": null, 1140 | "max_height": null, 1141 | "max_width": null, 1142 | "min_height": null, 1143 | "min_width": null, 1144 | "object_fit": null, 1145 | "object_position": null, 1146 | "order": null, 1147 | "overflow": null, 1148 | "overflow_x": null, 1149 | "overflow_y": null, 1150 | "padding": null, 1151 | "right": null, 1152 | "top": null, 1153 | "visibility": null, 1154 | "width": null 1155 | } 1156 | }, 1157 | "c8dad1a95c8646edbde1af6fcc3f0ff9": { 1158 | "model_module": "@jupyter-widgets/controls", 1159 | "model_module_version": "1.5.0", 1160 | "model_name": "ProgressStyleModel", 1161 | "state": { 1162 | "_model_module": "@jupyter-widgets/controls", 1163 | "_model_module_version": "1.5.0", 1164 | "_model_name": "ProgressStyleModel", 1165 | "_view_count": null, 1166 | "_view_module": "@jupyter-widgets/base", 1167 | "_view_module_version": "1.2.0", 1168 | "_view_name": "StyleView", 1169 | "bar_color": null, 1170 | "description_width": "" 1171 | } 
1172 | }, 1173 | "db73dbc0dabc429481860871b02dc9e0": { 1174 | "model_module": "@jupyter-widgets/controls", 1175 | "model_module_version": "1.5.0", 1176 | "model_name": "DescriptionStyleModel", 1177 | "state": { 1178 | "_model_module": "@jupyter-widgets/controls", 1179 | "_model_module_version": "1.5.0", 1180 | "_model_name": "DescriptionStyleModel", 1181 | "_view_count": null, 1182 | "_view_module": "@jupyter-widgets/base", 1183 | "_view_module_version": "1.2.0", 1184 | "_view_name": "StyleView", 1185 | "description_width": "" 1186 | } 1187 | }, 1188 | "fc66ee3afa1944beb42494efbb1301ac": { 1189 | "model_module": "@jupyter-widgets/controls", 1190 | "model_module_version": "1.5.0", 1191 | "model_name": "FloatProgressModel", 1192 | "state": { 1193 | "_dom_classes": [], 1194 | "_model_module": "@jupyter-widgets/controls", 1195 | "_model_module_version": "1.5.0", 1196 | "_model_name": "FloatProgressModel", 1197 | "_view_count": null, 1198 | "_view_module": "@jupyter-widgets/controls", 1199 | "_view_module_version": "1.5.0", 1200 | "_view_name": "ProgressView", 1201 | "bar_style": "success", 1202 | "description": "", 1203 | "description_tooltip": null, 1204 | "layout": "IPY_MODEL_bb2e04bed86047b0b3a4e587cfb48ef0", 1205 | "max": 1, 1206 | "min": 0, 1207 | "orientation": "horizontal", 1208 | "style": "IPY_MODEL_c8dad1a95c8646edbde1af6fcc3f0ff9", 1209 | "value": 1 1210 | } 1211 | }, 1212 | "ff2e02e62d0b438cac9f521da8c0d5eb": { 1213 | "model_module": "@jupyter-widgets/controls", 1214 | "model_module_version": "1.5.0", 1215 | "model_name": "HTMLModel", 1216 | "state": { 1217 | "_dom_classes": [], 1218 | "_model_module": "@jupyter-widgets/controls", 1219 | "_model_module_version": "1.5.0", 1220 | "_model_name": "HTMLModel", 1221 | "_view_count": null, 1222 | "_view_module": "@jupyter-widgets/controls", 1223 | "_view_module_version": "1.5.0", 1224 | "_view_name": "HTMLView", 1225 | "description": "", 1226 | "description_tooltip": null, 1227 | "layout": "IPY_MODEL_b3486fd1f15b43068e47df0ad6a81559", 1228 | "placeholder": "​", 1229 | "style": "IPY_MODEL_db73dbc0dabc429481860871b02dc9e0", 1230 | "value": "100%" 1231 | } 1232 | } 1233 | } 1234 | } 1235 | }, 1236 | "nbformat": 4, 1237 | "nbformat_minor": 1 1238 | } 1239 | --------------------------------------------------------------------------------
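
The notebook above stops at `trainer.fit` and never shows the fine-tuned `EncoderDecoderModel` actually translating. Below is a minimal inference sketch, not part of the original notebook: it assumes the `encoder_decoder`, `encoder_tokenizer`, and `decoder_tokenizer` objects defined in `cc2zh_translate.ipynb` (ideally with weights restored from the `ModelCheckpoint` directory, e.g. via `Seq2SeqTrain.load_from_checkpoint`); the `translate` helper and the beam-search settings are illustrative choices, not the author's.

```python
import torch

def translate(text: str, max_len: int = 128) -> str:
    # encode the classical Chinese source with the BERT (encoder) tokenizer
    inputs = encoder_tokenizer(
        text,
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        # both tokenizers come from the same BERT checkpoint, so its
        # [CLS]/[SEP]/[PAD] ids double as the decoder's start/end/pad tokens
        output_ids = encoder_decoder.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_beams=3,
            max_length=max_len,
            decoder_start_token_id=encoder_tokenizer.cls_token_id,
            eos_token_id=encoder_tokenizer.sep_token_id,
            pad_token_id=encoder_tokenizer.pad_token_id,
        )
    # decode to modern Chinese; the BERT tokenizer leaves spaces between characters
    return decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)

# one of the source sentences from the df.head() output above
print(translate("督军疾进,师至阴山,遇其斥候千余帐,皆俘以随军。"))
```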