├── .gitignore ├── cls.model ├── gen.model ├── ppo.model ├── README.md ├── 2.train_cls.ipynb ├── 1.train_gen.ipynb ├── 4.test.ipynb ├── common.ipynb └── 3.train_ppo.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | **/.ipynb_checkpoints 2 | **/__pycache__ -------------------------------------------------------------------------------- /cls.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lansinuote/Simple_LLM_PPO/HEAD/cls.model -------------------------------------------------------------------------------- /gen.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lansinuote/Simple_LLM_PPO/HEAD/gen.model -------------------------------------------------------------------------------- /ppo.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lansinuote/Simple_LLM_PPO/HEAD/ppo.model -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PPO方法训练大语言模型,简易实现代码 2 | 3 | 环境信息: 4 | 5 | python=3.10 6 | 7 | torch==2.1.0(cuda) 8 | 9 | transformers==4.34.0 10 | 11 | datasets==2.14.5 12 | 13 | trl==0.7.2 14 | 15 | 视频课程:https://www.bilibili.com/video/BV1uy4y1c7DV 16 | 17 | 2024年8月9日更新: 修改部分注释 18 | -------------------------------------------------------------------------------- /2.train_cls.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "58e4f06d", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "['S__:9821=叁玖贰捌柒玖贰捌柒玖贰捌柒玖贰捌肆E',\n", 13 | " 'S__:2424=oyoyoyoyoyoyoyoyEP',\n", 14 | " 'S__:6623=贰陆肆玖肆陆肆玖肆陆肆玖肆陆肆玖贰E',\n", 15 | " 
'S__:2037=捌壹肆捌捌壹肆捌捌壹肆捌捌壹肆捌EP',\n", 16 | " 'S__:9685=叁捌柒肆叁捌柒肆叁捌柒肆叁捌柒肆零E',\n", 17 | " 'S__:5989=二三九五八三九五八三九五八三九五六E',\n", 18 | " 'S__:1561=六二四四六二四四六二四四六二四四EP',\n", 19 | " 'S__:1197=4788478847884788EP',\n", 20 | " 'S__:6300=贰伍贰零贰伍贰零贰伍贰零贰伍贰零零E',\n", 21 | " 'S__:8081=ewewuwewuwewuwewrE']" 22 | ] 23 | }, 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "output_type": "execute_result" 27 | } 28 | ], 29 | "source": [ 30 | "%run common.ipynb\n", 31 | "\n", 32 | "[tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=False)[1]][:10]" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "62aea6a1", 39 | "metadata": { 40 | "scrolled": true 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "model_cls = ModelCLS()" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "id": "dbe0768c", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "0 0.25\n", 58 | "S__:3899=壹伍伍玖柒伍伍玖柒伍伍玖柒伍伍玖陆E 3\n", 59 | "S__:4966=一九八六五九八六五九八六五九八六四E 3\n", 60 | "100 1.0\n", 61 | "S__:8025=ewqpewqpewqpewqppE 1\n", 62 | "S__:4263=17053705370537052E 0\n", 63 | "200 1.0\n", 64 | "S__:1158=ryewryewryewryewEP 1\n", 65 | "S__:1526=6104610461046104EP 0\n", 66 | "300 1.0\n", 67 | "S__:4577=一八三〇九八三〇九八三〇九八三〇八E 2\n", 68 | "S__:3455=一三八二一三八二一三八二一三八二〇E 2\n", 69 | "400 1.0\n", 70 | "S__:3273=壹叁零玖叁叁零玖叁叁零玖叁叁零玖贰E 3\n", 71 | "S__:5049=20198019801980196E 0\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "optimizer = torch.optim.AdamW(params=model_cls.parameters(), lr=1e-4)\n", 77 | "criterion = torch.nn.CrossEntropyLoss()\n", 78 | "\n", 79 | "for epoch in range(500):\n", 80 | " label, input_ids, attention_mask = tokenizer.get_batch_data(prefix=False)\n", 81 | " label = torch.LongTensor(label).to(device)\n", 82 | " input_ids = torch.LongTensor(input_ids).to(device)\n", 83 | " attention_mask = torch.LongTensor(attention_mask).to(device)\n", 84 | "\n", 85 | " logits = model_cls(input_ids=input_ids, 
attention_mask=attention_mask)\n", 86 | "\n", 87 | " loss = criterion(logits, label)\n", 88 | " loss.backward()\n", 89 | " optimizer.step()\n", 90 | " optimizer.zero_grad()\n", 91 | "\n", 92 | " if epoch % 100 == 0:\n", 93 | " logits = logits.argmax(1)\n", 94 | " acc = (logits == label).sum().item() / len(label)\n", 95 | " print(epoch, acc)\n", 96 | "\n", 97 | " for i in range(2):\n", 98 | " print(tokenizer.decode(input_ids[i].tolist()), logits[i].item())\n", 99 | "\n", 100 | "model_cls.to('cpu')\n", 101 | "torch.save(model_cls, 'cls.model')" 102 | ] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python [conda env:cuda117]", 108 | "language": "python", 109 | "name": "conda-env-cuda117-py" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.10.13" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /1.train_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "7cdb3b0f", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "['S__:4370=quriquriquriquripE',\n", 13 | " 'S__:1872=7488748874887488EP',\n", 14 | " 'S__:3782=15129512951295128E',\n", 15 | " 'S__:3436=qeurteurteurteurrE',\n", 16 | " 'S__:2659=qpyeupyeupyeupyeyE',\n", 17 | " 'S__:8508=erpetrpetrpetrpewE',\n", 18 | " 'S__:6202=wriqpriqpriqpripiE',\n", 19 | " 'S__:7528=叁零壹壹伍零壹壹伍零壹壹伍零壹壹贰E',\n", 20 | " 'S__:9819=三九二七九九二七九九二七九九二七六E',\n", 21 | " 'S__:7813=eqwttqwttqwttqwtwE']" 22 | ] 23 | }, 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "output_type": "execute_result" 27 | 
} 28 | ], 29 | "source": [ 30 | "%run common.ipynb\n", 31 | "\n", 32 | "[tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=False)[1]][:10]" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "5e45b488", 39 | "metadata": { 40 | "scrolled": true 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "model_gen = ModelGEN()" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "id": "fbf00a68", 51 | "metadata": { 52 | "scrolled": false 53 | }, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "0\n", 60 | "S__:4952=肆五二8E\n", 61 | "S__:4761=九EPPP\n", 62 | "1000\n", 63 | "S__:5948=贰壹柒玖肆壹柒玖肆壹柒玖伍壹柒玖贰E\n", 64 | "S__:5106=0042604460426044oE\n", 65 | "2000\n", 66 | "S__:3943=1573357三三57335732E\n", 67 | "S__:8711=三四八四七四八四七四八四七四八四四E\n", 68 | "3000\n", 69 | "S__:7605=eprweprweprweprwpE\n", 70 | "S__:9670=三八六八三八六八三八六八三八六八〇E\n", 71 | "4000\n", 72 | "S__:1748=7992799279927992EP\n", 73 | "S__:7750=贰捌陆零贰捌陆零贰捌陆零贰捌陆零零E\n", 74 | "5000\n", 75 | "S__:9917=39671967196719668E\n", 76 | "S__:8157=叁贰陆叁壹贰陆叁壹贰陆叁壹贰陆贰捌E\n", 77 | "6000\n", 78 | "S__:1532=六一二八六一二八六一二八六一二八EP\n", 79 | "S__:8413=三三六五五三六五五三六五五三六五二E\n", 80 | "7000\n", 81 | "S__:1487=toritoritoritoriEP\n", 82 | "S__:3048=一二一九三二一九三二一九三二一九二E\n", 83 | "8000\n", 84 | "S__:1813=柒贰伍贰柒贰伍贰柒贰伍贰柒贰伍贰E\n", 85 | "S__:1526=陆壹零肆陆壹零肆陆壹零肆陆壹零肆E\n", 86 | "9000\n", 87 | "S__:3136=一二五四五二五四五二五四五二五四四E\n", 88 | "S__:3200=12801280128012800E\n", 89 | "10000\n", 90 | "S__:3851=qtrpttrpttrpttrprE\n", 91 | "S__:2963=一一八五三一八五三一八五三一八五二E\n", 92 | "11000\n", 93 | "S__:2488=玖玖玖贰玖玖玖贰玖玖玖贰玖玖玖贰EP\n", 94 | "S__:4904=19617961796179616E\n", 95 | "12000\n", 96 | "S__:3908=15633567356335632E\n", 97 | "S__:5260=二一〇四二一〇四二一〇四二一〇四〇E\n", 98 | "13000\n", 99 | "S__:7537=三〇一五一〇一五一〇一五一〇一四八E\n", 100 | "S__:6896=贰柒伍捌陆柒伍捌陆柒伍捌陆柒伍捌肆E\n", 101 | "14000\n", 102 | "S__:7923=贰叁陆玖肆叁陆玖肆叁陆玖肆叁陆玖贰E\n", 103 | "S__:5067=贰零陆柒零零陆柒零零陆柒零零陆陆捌E\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | 
"optimizer = torch.optim.AdamW(model_gen.parameters(), lr=1e-4)\n", 109 | "criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.encoder['P'])\n", 110 | "\n", 111 | "for epoch in range(15000):\n", 112 | " _, input_ids, attention_mask = tokenizer.get_batch_data(prefix=False)\n", 113 | " input_ids = torch.LongTensor(input_ids).to(device)\n", 114 | " attention_mask = torch.LongTensor(attention_mask).to(device)\n", 115 | "\n", 116 | " logits = model_gen(input_ids=input_ids, attention_mask=attention_mask)\n", 117 | "\n", 118 | " loss = criterion(logits[:, :-1].flatten(end_dim=1),\n", 119 | " input_ids[:, 1:].flatten())\n", 120 | "\n", 121 | " loss.backward()\n", 122 | " optimizer.step()\n", 123 | " optimizer.zero_grad()\n", 124 | "\n", 125 | " if epoch % 1000 == 0:\n", 126 | " print(epoch)\n", 127 | " for i in generate(model_gen, input_ids[:2, :9]):\n", 128 | " print(tokenizer.decode(i.tolist()))\n", 129 | "\n", 130 | "model_gen.to('cpu')\n", 131 | "torch.save(model_gen, 'gen.model')" 132 | ] 133 | } 134 | ], 135 | "metadata": { 136 | "kernelspec": { 137 | "display_name": "Python [conda env:cuda117]", 138 | "language": "python", 139 | "name": "conda-env-cuda117-py" 140 | }, 141 | "language_info": { 142 | "codemirror_mode": { 143 | "name": "ipython", 144 | "version": 3 145 | }, 146 | "file_extension": ".py", 147 | "mimetype": "text/x-python", 148 | "name": "python", 149 | "nbconvert_exporter": "python", 150 | "pygments_lexer": "ipython3", 151 | "version": "3.10.13" 152 | } 153 | }, 154 | "nbformat": 4, 155 | "nbformat_minor": 5 156 | } 157 | -------------------------------------------------------------------------------- /4.test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "1f242922", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "['S字母:5278=wqqqrqqqrqqqrqqqwE',\n", 13 | " 
'S大写:9113=叁陆肆伍伍陆肆伍伍陆肆伍伍陆肆伍贰E',\n", 14 | " 'S小写:1255=五〇二〇五〇二〇五〇二〇五〇二〇EP',\n", 15 | " 'S大写:1198=肆柒玖贰肆柒玖贰肆柒玖贰肆柒玖贰EP',\n", 16 | " 'S大写:3751=壹伍零零伍伍零零伍伍零零伍伍零零肆E',\n", 17 | " 'S大写:5649=贰贰伍玖捌贰伍玖捌贰伍玖捌贰伍玖陆E',\n", 18 | " 'S小写:6892=二七五七〇七五七〇七五七〇七五六八E',\n", 19 | " 'S小写:2195=八七八〇八七八〇八七八〇八七八〇EP',\n", 20 | " 'S字母:6627=wytqpytqpytqpytpiE',\n", 21 | " 'S大写:7516=叁零零陆柒零零陆柒零零陆柒零零陆肆E']" 22 | ] 23 | }, 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "output_type": "execute_result" 27 | } 28 | ], 29 | "source": [ 30 | "%run common.ipynb\n", 31 | "\n", 32 | "[tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=True)[1]][:10]" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "e7063d80", 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": [ 44 | "(torch.Size([64, 9]), torch.Size([64, 18]))" 45 | ] 46 | }, 47 | "execution_count": 2, 48 | "metadata": {}, 49 | "output_type": "execute_result" 50 | } 51 | ], 52 | "source": [ 53 | "@torch.no_grad()\n", 54 | "def get_question_and_answer():\n", 55 | " _, token, _ = tokenizer.get_batch_data(prefix=True)\n", 56 | "\n", 57 | " split = [i.index(tokenizer.encoder['=']) + 1 for i in token]\n", 58 | "\n", 59 | " #只要问题部分,等号后面的内容切除\n", 60 | " question = [t[:s] for t, s in zip(token, split)]\n", 61 | " answer = [t[s:] for t, s in zip(token, split)]\n", 62 | "\n", 63 | " #统一长度\n", 64 | " lens = max([len(i) for i in question])\n", 65 | " question = [[tokenizer.encoder['P']] * (lens - len(i)) + i\n", 66 | " for i in question]\n", 67 | " question = torch.LongTensor(question).to(device)\n", 68 | "\n", 69 | " lens = max([len(i) for i in answer])\n", 70 | " answer = [[tokenizer.encoder['P']] * (lens - len(i)) + i for i in answer]\n", 71 | " answer = torch.LongTensor(answer).to(device)\n", 72 | "\n", 73 | " return question, answer\n", 74 | "\n", 75 | "\n", 76 | "question, answer = get_question_and_answer()\n", 77 | "\n", 78 | "question.shape, answer.shape" 79 | ] 80 | }, 81 | { 82 | 
"cell_type": "code", 83 | "execution_count": 3, 84 | "id": "972f2d2e", 85 | "metadata": { 86 | "scrolled": true 87 | }, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "torch.Size([64, 18])" 93 | ] 94 | }, 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "model_ppo = torch.load('ppo.model')\n", 102 | "model_ppo.to(device)\n", 103 | "model_ppo.eval()\n", 104 | "\n", 105 | "predict = generate(model_ppo.model_gen, question)\n", 106 | "predict = predict[:, question.shape[1]:]\n", 107 | "\n", 108 | "predict.shape" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 4, 114 | "id": "d4468184", 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "S大写:7664= 叁零陆伍玖零陆伍玖零陆伍玖零陆伍陆E 叁零陆伍玖零陆伍玖零陆伍玖零陆伍陆E\n", 122 | "S小写:9893= 三九五七五九五七五九五七五九五七二E 三九五七五九五七五九五七五九五七二E\n", 123 | "S小写:6529= 二六一一八六一一八六一一八六一一六E 二六一一八六一一八六一一八六一一六E\n", 124 | "S小写:8139= 三二五五九二五五九二五五九二五五六E 三二五五九二五五九二五五九二五五六E\n", 125 | "S大写:7086= 贰捌叁肆陆捌叁肆陆捌叁肆陆捌叁肆肆E 贰捌叁肆陆捌叁肆陆捌叁肆陆捌叁肆肆E\n", 126 | "S大写:5099= 贰零叁玖捌零叁玖捌零叁玖捌零叁玖陆E 贰零叁玖捌零叁玖捌零叁玖捌零叁玖陆E\n", 127 | "S小写:9199= 三六七九九六七九九六七九九六七九六E 三六七九九六七九九六七九九六七九六E\n", 128 | "S大写:6229= 贰肆玖壹捌肆玖壹捌肆玖壹捌肆玖壹陆E 贰肆玖壹捌肆玖壹捌肆玖壹捌肆玖壹陆E\n", 129 | "S小写:6546= 二六一八六六一八六六一八六六一八四E 二六一八六六一八六六一八六六一八四E\n", 130 | "S数字:4435= 17741774177417740E 17741774177417740E\n", 131 | "S字母:3113= qwrtewrtewrtewrtwE qwrtewrtewrtewrtwE\n", 132 | "S小写:1271= 五〇八四五〇八四五〇八四五〇八四E 四〇八四四〇八四四〇八四四〇八四E\n", 133 | "S小写:1424= 五六九六五六九六五六九六五六九六E 五六九六五六九六五六九六五六九六E\n", 134 | "S大写:4742= 壹捌玖陆玖捌玖陆玖捌玖陆玖捌玖陆捌E 壹捌玖陆玖捌玖陆玖捌玖陆玖捌玖陆捌E\n", 135 | "S数字:3692= 14769476947694768E 14769476947694768E\n", 136 | "S数字:7171= 28686868686868684E 28686868686868684E\n", 137 | "S大写:2777= 壹壹壹零玖壹壹零玖壹壹零玖壹壹零捌E 壹壹壹零玖壹壹零玖壹壹零玖壹壹零捌E\n", 138 | "S小写:9911= 三九六四七九六四七九六四七九六四四E 三九六四七九六四七九六四七九六四四E\n", 139 | "S数字:4613= 18453845384538452E 18453845384538452E\n", 140 | "S大写:2593= 壹零叁柒叁零叁柒叁零叁柒叁零叁柒贰E 
壹零叁柒叁零叁柒叁零叁柒叁零叁柒贰E\n", 141 | "S小写:9875= 三九五〇三九五〇三九五〇三九五〇〇E 三九五〇三九五〇三九五〇三九五〇〇E\n", 142 | "S数字:2299= 9196919691969196E 9196919691969196E\n", 143 | "S小写:2626= 一〇五〇五〇五〇五〇五〇五〇五〇四E 一〇五〇五〇五〇五〇五〇五〇五〇四E\n", 144 | "S小写:4443= 一七七七三七七七三七七七三七七七二E 一七七七三七七七三七七七三七七七二E\n", 145 | "S小写:2786= 一一一四五一一四五一一四五一一四四E 一一一四五一一四五一一四五一一四四E\n", 146 | "S字母:1077= repirepirepirepiE repirepirepirepiE\n", 147 | "S小写:8809= 三五二三九五二三九五二三九五二三六E 三五二三九五二三九五二三九五二三六E\n", 148 | "S小写:6800= 二七二〇二七二〇二七二〇二七二〇〇E 二七二〇二七二〇二七二〇二七二〇〇E\n", 149 | "S字母:7953= eqiqtqiqtqiqtqiqwE eqiqtqiqtqiqtqiqwE\n", 150 | "S小写:5907= 二三六三〇三六三〇三六三〇三六二八E 二三六三〇三六三〇三六三〇三六二八E\n", 151 | "S数字:6032= 24130413041304128E 24130413041304128E\n", 152 | "S大写:9967= 叁玖捌柒壹玖捌柒壹玖捌柒壹玖捌陆捌E 叁玖捌柒壹玖捌柒壹玖捌柒壹玖捌陆捌E\n", 153 | "S大写:9511= 叁捌零肆柒捌零肆柒捌零肆柒捌零肆肆E 叁捌零肆柒捌零肆柒捌零肆柒捌零肆肆E\n", 154 | "S字母:8200= ewipewipewipewippE ewipewipewipewippE\n", 155 | "S字母:5950= weipweipweipweippE weipweipweipweippE\n", 156 | "S小写:6737= 二六九五〇六九五〇六九五〇六九四八E 二六九五〇六九五〇六九五〇六九四八E\n", 157 | "S字母:3734= qroeuroeuroeuroeyE qroeuroeuroeuroeyE\n", 158 | "S数字:2278= 9112911291129112E 9112911291129112E\n", 159 | "S数字:8039= 32159215921592156E 32159215921592156E\n", 160 | "S字母:9799= eoqoooqoooqoooqoyE eoqoooqoooqoooqoyE\n", 161 | "S数字:2980= 11921192119211920E 11921192119211920E\n", 162 | "S字母:7377= wotqpotqpotqpotpiE wotqpotqpotqpotpiE\n", 163 | "S小写:5406= 二一六二六一六二六一六二六一六二四E 二一六二六一六二六一六二六一六二四E\n", 164 | "S小写:6428= 二五七一四五七一四五七一四五七一二E 二五七一四五七一四五七一四五七一二E\n", 165 | "S字母:2509= qppeuppeuppeuppeyE qppeuppeuppeuppeyE\n", 166 | "S大写:5187= 贰零柒伍零零柒伍零零柒伍零零柒肆捌E 贰零柒伍零零柒伍零零柒伍零零柒肆捌E\n", 167 | "S字母:7275= woqpwoqpwoqpwoqppE woqpwoqpwoqpwoqppE\n", 168 | "S数字:1783= 7132713271327132E 7132713271327132E\n", 169 | "S字母:2144= ituyituyituyituyE ituyituyituyituyE\n", 170 | "S字母:6048= wrqorrqorrqorrqowE wrqorrqorrqorrqowE\n", 171 | "S大写:1813= 柒贰伍贰柒贰伍贰柒贰伍贰柒贰伍贰E 柒贰伍贰柒贰伍贰柒贰伍贰柒贰伍贰E\n", 172 | "S小写:7707= 三〇八三一〇八三一〇八三一〇八二八E 三〇八三一〇八三一〇八三一〇八二八E\n", 173 | "S大写:8436= 叁叁柒肆柒叁柒肆柒叁柒肆柒叁柒肆肆E 叁叁柒肆柒叁柒肆柒叁柒肆柒叁柒肆肆E\n", 174 | "S小写:9010= 三六〇四三六〇四三六〇四三六〇四〇E 
三六〇四三六〇四三六〇四三六〇四〇E\n", 175 | "S大写:1813= 柒贰伍贰柒贰伍贰柒贰伍贰柒贰伍贰E 柒贰伍贰柒贰伍贰柒贰伍贰柒贰伍贰E\n", 176 | "S数字:5718= 22874287428742872E 22874287428742872E\n", 177 | "S小写:4968= 一九八七三九八七三九八七三九八七二E 一九八七三九八七三九八七三九八七二E\n", 178 | "S数字:5132= 20530053005300528E 20530053005300528E\n", 179 | "S数字:6421= 25686568656865684E 25686568656865684E\n", 180 | "S小写:5481= 二一九二六一九二六一九二六一九二四E 二一九二六一九二六一九二六一九二四E\n", 181 | "S大写:9309= 叁柒贰叁玖柒贰叁玖柒贰叁玖柒贰叁陆E 叁柒贰叁玖柒贰叁玖柒贰叁玖柒贰叁陆E\n", 182 | "S数字:4044= 16177617761776176E 16177617761776176E\n", 183 | "S数字:6370= 25482548254825480E 25482548254825480E\n", 184 | "S小写:5980= 二三九二二三九二二三九二二三九二〇E 二三九二二三九二二三九二二三九二〇E\n" 185 | ] 186 | }, 187 | { 188 | "data": { 189 | "text/plain": [ 190 | "0.984375" 191 | ] 192 | }, 193 | "execution_count": 4, 194 | "metadata": {}, 195 | "output_type": "execute_result" 196 | } 197 | ], 198 | "source": [ 199 | "correct = 0\n", 200 | "for q, a, p in zip(question, answer, predict):\n", 201 | " q, a, p = q.tolist(), a.tolist(), p.tolist()\n", 202 | "\n", 203 | " if tokenizer.encoder['E'] in a:\n", 204 | " split = a.index(tokenizer.encoder['E']) + 1\n", 205 | " a = a[:split]\n", 206 | "\n", 207 | " if tokenizer.encoder['E'] in p:\n", 208 | " split = p.index(tokenizer.encoder['E']) + 1\n", 209 | " p = p[:split]\n", 210 | "\n", 211 | " q, a, p = tokenizer.decode(q), tokenizer.decode(a), tokenizer.decode(p)\n", 212 | "\n", 213 | " print(q, a, p)\n", 214 | "\n", 215 | " correct += a == p\n", 216 | "\n", 217 | "correct / len(answer)" 218 | ] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python [conda env:cuda117]", 224 | "language": "python", 225 | "name": "conda-env-cuda117-py" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.10.13" 238 | } 239 | }, 240 | "nbformat": 4, 241 | 
"nbformat_minor": 5 242 | } 243 | -------------------------------------------------------------------------------- /common.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "e1c27274", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "['S数字:3929=15717571757175716E',\n", 13 | " 'S小写:3616=一四四六五四四六五四四六五四四六四E',\n", 14 | " 'S大写:6438=贰伍柒伍肆伍柒伍肆伍柒伍肆伍柒伍贰E',\n", 15 | " 'S字母:5966=weiyyeiyyeiyyeiyrE',\n", 16 | " 'S小写:7716=三〇八六七〇八六七〇八六七〇八六四E',\n", 17 | " 'S大写:7307=贰玖贰叁零玖贰叁零玖贰叁零玖贰贰捌E',\n", 18 | " 'S数字:9302=37211721172117208E',\n", 19 | " 'S字母:7822=eqwoqqwoqqwoqqwiiE',\n", 20 | " 'S小写:7413=二九六五四九六五四九六五四九六五二E',\n", 21 | " 'S小写:6266=二五〇六六五〇六六五〇六六五〇六四E']" 22 | ] 23 | }, 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "output_type": "execute_result" 27 | } 28 | ], 29 | "source": [ 30 | "import random\n", 31 | "\n", 32 | "\n", 33 | "class Tokenizer:\n", 34 | "\n", 35 | " def __init__(self):\n", 36 | " self.vocab = {\n", 37 | " 'mark': list('PSEU'),\n", 38 | " 'number': list('0123456789'),\n", 39 | " 'letter': list('pqwertyuio'),\n", 40 | " 'chinese_lower': list('〇一二三四五六七八九'),\n", 41 | " 'chinese_upper': list('零壹贰叁肆伍陆柒捌玖'),\n", 42 | " 'other': list('数字大写小母:=_'),\n", 43 | " }\n", 44 | "\n", 45 | " self.decoder = [j for i in self.vocab.values() for j in i]\n", 46 | " self.encoder = {j: i for i, j in enumerate(self.decoder)}\n", 47 | "\n", 48 | " self.label = {\n", 49 | " 'number': 0,\n", 50 | " 'letter': 1,\n", 51 | " 'chinese_lower': 2,\n", 52 | " 'chinese_upper': 3\n", 53 | " }\n", 54 | " self.prefix = ['数字', '字母', '小写', '大写']\n", 55 | "\n", 56 | " def decode(self, x):\n", 57 | " return ''.join([self.decoder[i] for i in x])\n", 58 | "\n", 59 | " def get_data(self, prefix):\n", 60 | " #生成问题和答案\n", 61 | " question = random.randint(1000, 9999)\n", 62 | " answer = int(str(question) * 4) * 4\n", 63 | " #answer = question**8\n", 
64 | " \n", 65 | " question = list(str(question))\n", 66 | " answer = list(str(answer))\n", 67 | "\n", 68 | " #随机label\n", 69 | " label = random.choice(list(self.label.keys()))\n", 70 | "\n", 71 | " #根据label替换答案成其他字符集\n", 72 | " answer = [self.vocab[label][int(i)] for i in answer]\n", 73 | "\n", 74 | " #label转数字\n", 75 | " label = self.label[label]\n", 76 | "\n", 77 | " #组合问题和答案\n", 78 | " if prefix:\n", 79 | " prefix = list(self.prefix[label])\n", 80 | " else:\n", 81 | " prefix = list('__')\n", 82 | " token = prefix + [':'] + question + ['='] + answer\n", 83 | "\n", 84 | " #编码\n", 85 | " token = [self.encoder[i] for i in token]\n", 86 | " token = [self.encoder['S']] + token + [self.encoder['E']]\n", 87 | "\n", 88 | " return label, token\n", 89 | "\n", 90 | " def get_batch_data(self, prefix):\n", 91 | " data = [self.get_data(prefix=prefix) for _ in range(64)]\n", 92 | "\n", 93 | " label = [i[0] for i in data]\n", 94 | " token = [i[1] for i in data]\n", 95 | "\n", 96 | " return label, *self.batch_pad(token=token)\n", 97 | "\n", 98 | " def batch_pad(self, text=None, token=None):\n", 99 | " if text:\n", 100 | " #编码\n", 101 | " token = [[self.encoder[j] for j in i] for i in text]\n", 102 | "\n", 103 | " lens = max([len(i) for i in token])\n", 104 | "\n", 105 | " input_ids = []\n", 106 | " attention_mask = []\n", 107 | " for i in token:\n", 108 | " attention_mask.append([1] * len(i) + [0] * (lens - len(i)))\n", 109 | " input_ids.append(i + [self.encoder['P']] * (lens - len(i)))\n", 110 | "\n", 111 | " return input_ids, attention_mask\n", 112 | "\n", 113 | "\n", 114 | "tokenizer = Tokenizer()\n", 115 | "\n", 116 | "[tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=True)[1]][:10]" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 2, 122 | "id": "e242caa8", 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "data": { 127 | "text/plain": [ 128 | "'cuda'" 129 | ] 130 | }, 131 | "execution_count": 2, 132 | "metadata": {}, 133 | 
"output_type": "execute_result" 134 | } 135 | ], 136 | "source": [ 137 | "import torch\n", 138 | "\n", 139 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 140 | "\n", 141 | "device" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 3, 147 | "id": "17208b2d", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "class ModelGEN(torch.nn.Module):\n", 152 | "\n", 153 | " def __init__(self):\n", 154 | " super().__init__()\n", 155 | " from transformers import GPT2Config, GPT2Model\n", 156 | "\n", 157 | " self.config = GPT2Config(bos_token_id=tokenizer.encoder['S'],\n", 158 | " eos_token_id=tokenizer.encoder['E'],\n", 159 | " n_embd=64,\n", 160 | " n_head=4,\n", 161 | " n_layer=4,\n", 162 | " n_positions=128,\n", 163 | " vocab_size=len(tokenizer.decoder))\n", 164 | "\n", 165 | " self.feature = GPT2Model(self.config)\n", 166 | "\n", 167 | " self.fc_out = torch.nn.Linear(64, self.config.vocab_size, bias=False)\n", 168 | "\n", 169 | " self.to(device)\n", 170 | " self.train()\n", 171 | "\n", 172 | " def forward(self, input_ids, attention_mask):\n", 173 | " out = self.feature(input_ids=input_ids,\n", 174 | " attention_mask=attention_mask).last_hidden_state\n", 175 | "\n", 176 | " return self.fc_out(out)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 4, 182 | "id": "a23b28dd", 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "class ModelCLS(torch.nn.Module):\n", 187 | "\n", 188 | " def __init__(self):\n", 189 | " super().__init__()\n", 190 | " from transformers import BertConfig, BertModel\n", 191 | "\n", 192 | " self.config = BertConfig(hidden_size=64,\n", 193 | " intermediate_size=64,\n", 194 | " max_position_embeddings=128,\n", 195 | " num_attention_heads=4,\n", 196 | " num_hidden_layers=4,\n", 197 | " vocab_size=len(tokenizer.decoder))\n", 198 | "\n", 199 | " self.feature = BertModel(self.config)\n", 200 | "\n", 201 | " self.fc_out = 
torch.nn.Sequential(torch.nn.Dropout(p=0.1),\n", 202 | " torch.nn.Linear(64, 4))\n", 203 | "\n", 204 | " self.to(device)\n", 205 | " self.train()\n", 206 | "\n", 207 | " def forward(self, input_ids, attention_mask):\n", 208 | " out = self.feature(input_ids=input_ids,\n", 209 | " attention_mask=attention_mask).pooler_output\n", 210 | "\n", 211 | " return self.fc_out(out)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 5, 217 | "id": "c6224a5e", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "class ModelPPO(torch.nn.Module):\n", 222 | "\n", 223 | " def __init__(self, model_gen):\n", 224 | " super().__init__()\n", 225 | " self.model_gen = model_gen\n", 226 | " self.v_head = torch.nn.Sequential(torch.nn.Dropout(0.1),\n", 227 | " torch.nn.Linear(64, 1))\n", 228 | "\n", 229 | " self.to(device)\n", 230 | " self.train()\n", 231 | "\n", 232 | " def forward(self, input_ids, attention_mask):\n", 233 | " last_hidden_state = self.model_gen.feature(\n", 234 | " input_ids=input_ids,\n", 235 | " attention_mask=attention_mask,\n", 236 | " output_hidden_states=True).last_hidden_state\n", 237 | "\n", 238 | " logits = self.model_gen.fc_out(last_hidden_state)\n", 239 | " value = self.v_head(last_hidden_state).squeeze(-1)\n", 240 | "\n", 241 | " return logits, value" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 6, 247 | "id": "a13f60dc", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "generater = None\n", 252 | "\n", 253 | "\n", 254 | "def generate(model_gen, input_ids):\n", 255 | " global generater\n", 256 | " if not generater:\n", 257 | " #包装类,用于生成\n", 258 | " from transformers import GPT2LMHeadModel\n", 259 | " generater = GPT2LMHeadModel(model_gen.config)\n", 260 | " generater.transformer = model_gen.feature\n", 261 | " generater.lm_head = model_gen.fc_out\n", 262 | " generater.to(device)\n", 263 | "\n", 264 | " return generater.generate(input_ids=input_ids,\n", 265 | " 
min_length=-1,\n", 266 | " top_k=0.0,\n", 267 | " top_p=1.0,\n", 268 | " do_sample=True,\n", 269 | " pad_token_id=tokenizer.encoder['P'],\n", 270 | " max_new_tokens=25,\n", 271 | " eos_token_id=tokenizer.encoder['E'])" 272 | ] 273 | } 274 | ], 275 | "metadata": { 276 | "kernelspec": { 277 | "display_name": "Python [conda env:cuda117]", 278 | "language": "python", 279 | "name": "conda-env-cuda117-py" 280 | }, 281 | "language_info": { 282 | "codemirror_mode": { 283 | "name": "ipython", 284 | "version": 3 285 | }, 286 | "file_extension": ".py", 287 | "mimetype": "text/x-python", 288 | "name": "python", 289 | "nbconvert_exporter": "python", 290 | "pygments_lexer": "ipython3", 291 | "version": "3.10.13" 292 | } 293 | }, 294 | "nbformat": 4, 295 | "nbformat_minor": 5 296 | } 297 | -------------------------------------------------------------------------------- /3.train_ppo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "ab835cce", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "['S数字:4564=18257825782578256E',\n", 13 | " 'S大写:5895=贰叁伍捌贰叁伍捌贰叁伍捌贰叁伍捌零E',\n", 14 | " 'S数字:3532=14129412941294128E',\n", 15 | " 'S字母:8984=etoeotoeotoeotoeyE',\n", 16 | " 'S小写:1412=五六四八五六四八五六四八五六四八EP',\n", 17 | " 'S数字:7764=31059105910591056E',\n", 18 | " 'S小写:8989=三五九五九五九五九五九五九五九五六E',\n", 19 | " 'S数字:6596=26386638663866384E',\n", 20 | " 'S字母:5759=wepeiepeiepeiepeyE',\n", 21 | " 'S数字:7736=30947094709470944E']" 22 | ] 23 | }, 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "output_type": "execute_result" 27 | } 28 | ], 29 | "source": [ 30 | "%run common.ipynb\n", 31 | "\n", 32 | "[tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=True)[1]][:10]" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "bcb3d7c8", 39 | "metadata": { 40 | "scrolled": true 41 | }, 42 | "outputs": [], 43 | 
"source": [ 44 | "model_ppo = ModelPPO(torch.load('gen.model'))\n", 45 | "model_ppo_ref = ModelPPO(torch.load('gen.model'))\n", 46 | "\n", 47 | "for i in model_ppo_ref.parameters():\n", 48 | " i.requires_grad_(False)\n", 49 | "\n", 50 | "model_cls = torch.load('cls.model')\n", 51 | "model_cls.to(device)\n", 52 | "\n", 53 | "for i in model_cls.parameters():\n", 54 | " i.requires_grad_(False)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "c3e7a327", 61 | "metadata": { 62 | "scrolled": true 63 | }, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "(tensor([2, 0, 0, 3, 3, 1, 1, 3, 3, 2, 2, 3, 2, 0, 2, 3, 1, 0, 3, 0, 3, 0, 3, 3,\n", 69 | " 0, 2, 3, 0, 2, 2, 2, 0, 1, 0, 3, 1, 1, 2, 1, 0, 3, 3, 3, 3, 3, 2, 3, 3,\n", 70 | " 3, 2, 2, 2, 3, 2, 3, 2, 0, 2, 0, 2, 3, 1, 3, 3], device='cuda:0'),\n", 71 | " tensor([[ 1, 48, 47, 50, 6, 8, 7, 4, 51],\n", 72 | " [ 1, 44, 45, 50, 6, 7, 4, 11, 51],\n", 73 | " [ 1, 44, 45, 50, 11, 8, 11, 12, 51],\n", 74 | " [ 1, 46, 47, 50, 10, 12, 11, 11, 51],\n", 75 | " [ 1, 46, 47, 50, 8, 13, 13, 8, 51],\n", 76 | " [ 1, 45, 49, 50, 13, 6, 11, 4, 51],\n", 77 | " [ 1, 45, 49, 50, 5, 7, 11, 12, 51],\n", 78 | " [ 1, 46, 47, 50, 9, 5, 12, 8, 51],\n", 79 | " [ 1, 46, 47, 50, 9, 13, 7, 11, 51],\n", 80 | " [ 1, 48, 47, 50, 6, 12, 12, 5, 51]], device='cuda:0'))" 81 | ] 82 | }, 83 | "execution_count": 3, 84 | "metadata": {}, 85 | "output_type": "execute_result" 86 | } 87 | ], 88 | "source": [ 89 | "@torch.no_grad()\n", 90 | "def get_question():\n", 91 | " label, question, _ = tokenizer.get_batch_data(prefix=True)\n", 92 | " label = torch.LongTensor(label).to(device)\n", 93 | "\n", 94 | " #只要问题部分,等号后面的内容切除\n", 95 | " question = [i[:i.index(tokenizer.encoder['=']) + 1] for i in question]\n", 96 | "\n", 97 | " #统一长度\n", 98 | " lens = max([len(i) for i in question])\n", 99 | " question = [[tokenizer.encoder['P']] * (lens - len(i)) + i\n", 100 | " for i in question]\n", 101 | "\n", 102 | " question = 
torch.LongTensor(question).to(device)\n", 103 | "\n", 104 | " return label, question\n", 105 | "\n", 106 | "\n", 107 | "label, question = get_question()\n", 108 | "\n", 109 | "label, question[:10]" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 4, 115 | "id": "b77e7d96", 116 | "metadata": { 117 | "scrolled": true 118 | }, 119 | "outputs": [ 120 | { 121 | "data": { 122 | "text/plain": [ 123 | "tensor([[23, 21, 16, 14, 23, 21, 16, 14, 23, 21, 16, 14, 23, 21, 16, 14, 2, 0],\n", 124 | " [42, 39, 34, 42, 42, 39, 34, 42, 42, 39, 34, 42, 42, 39, 34, 42, 2, 0],\n", 125 | " [ 6, 13, 13, 7, 8, 13, 13, 5, 8, 13, 13, 5, 8, 13, 13, 5, 6, 2],\n", 126 | " [ 6, 11, 9, 5, 4, 11, 9, 5, 4, 11, 9, 5, 4, 11, 9, 4, 12, 2],\n", 127 | " [ 5, 13, 13, 11, 11, 13, 13, 11, 11, 13, 13, 11, 11, 13, 13, 11, 10, 2],\n", 128 | " [37, 41, 34, 42, 37, 41, 34, 42, 37, 41, 34, 42, 37, 41, 34, 42, 34, 2],\n", 129 | " [18, 17, 15, 16, 19, 17, 15, 16, 19, 17, 15, 16, 19, 17, 15, 16, 2, 0],\n", 130 | " [36, 34, 41, 37, 42, 34, 41, 37, 42, 34, 41, 37, 42, 34, 41, 37, 40, 2],\n", 131 | " [16, 17, 21, 19, 14, 17, 21, 19, 14, 17, 21, 19, 14, 17, 21, 18, 22, 2],\n", 132 | " [35, 35, 39, 36, 39, 35, 39, 36, 39, 35, 39, 36, 39, 35, 39, 36, 38, 2]],\n", 133 | " device='cuda:0')" 134 | ] 135 | }, 136 | "execution_count": 4, 137 | "metadata": {}, 138 | "output_type": "execute_result" 139 | } 140 | ], 141 | "source": [ 142 | "#如果question的长度确定,这里可以转换成批运算\n", 143 | "@torch.no_grad()\n", 144 | "def get_answer(question):\n", 145 | " answer = generate(model_ppo.model_gen, question)\n", 146 | "\n", 147 | " #裁剪,只要生成的部分\n", 148 | " answer = answer[:, question.shape[1]:]\n", 149 | "\n", 150 | " return answer\n", 151 | "\n", 152 | "\n", 153 | "answer = get_answer(question)\n", 154 | "\n", 155 | "answer[:10]" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 5, 161 | "id": "720b4c0f", 162 | "metadata": { 163 | "scrolled": true 164 | }, 165 | "outputs": [ 166 | { 167 | 
"data": { 168 | "text/plain": [ 169 | "tensor([-1.3870, -1.4730, 3.6036, -1.4200, -1.7467, -0.8223, 3.5521, 3.4942,\n", 170 | " -0.6207, -0.9800, -1.3754, -0.6482, -0.6081, -0.9560, -1.1991, -1.5163,\n", 171 | " -1.4721, -0.5136, 3.2385, -0.7568, -1.9590, -1.7225, -0.8976, -0.7586,\n", 172 | " -2.2114, 4.0351, 3.3967, 3.8929, -0.5727, -0.6243, -1.2016, 3.9614,\n", 173 | " -0.4552, -0.9152, 3.4247, -0.8902, -0.5747, -1.6382, -0.8617, 3.7373,\n", 174 | " -1.0044, -0.8499, -0.8065, -0.6340, -0.8241, -0.9206, -1.2685, -0.8730,\n", 175 | " -0.7941, -1.1478, -0.5335, -0.7640, -1.9618, -0.4679, 3.3003, -0.8195,\n", 176 | " 3.8506, -1.7656, 4.1354, -0.9555, 3.6093, 3.5485, -0.4862, -0.7969],\n", 177 | " device='cuda:0')" 178 | ] 179 | }, 180 | "execution_count": 5, 181 | "metadata": {}, 182 | "output_type": "execute_result" 183 | } 184 | ], 185 | "source": [ 186 | "@torch.no_grad()\n", 187 | "def get_reward(question, answer, label):\n", 188 | " input_ids = torch.cat((question, answer), 1)\n", 189 | " attention_mask = (input_ids != tokenizer.encoder['P']).long()\n", 190 | "\n", 191 | " with torch.no_grad():\n", 192 | " logits = model_cls(input_ids=input_ids, attention_mask=attention_mask)\n", 193 | "\n", 194 | " return logits.gather(1, label.reshape(-1, 1)).squeeze(1)\n", 195 | "\n", 196 | "\n", 197 | "reward = get_reward(question, answer, label)\n", 198 | "\n", 199 | "reward" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 6, 205 | "id": "a48cfa6f", 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "data": { 210 | "text/plain": [ 211 | "'success'" 212 | ] 213 | }, 214 | "execution_count": 6, 215 | "metadata": {}, 216 | "output_type": "execute_result" 217 | } 218 | ], 219 | "source": [ 220 | "\"\"\"注释代码,可以不看\"\"\"\n", 221 | "\n", 222 | "\n", 223 | "#get_delta函数的原理解释,注释性代码\n", 224 | "#数学上和get_delta函数等价,但是运行效率低\n", 225 | "def get_delta_note(value, reward_kl):\n", 226 | " #下一个词的value,减去当前词的value,相当于对value去基线,缩小数值方差\n", 227 | " 
#每个词的value是相互独立的,前后词value的差,可以视为预测质量的衡量\n", 228 | " value_next = torch.zeros_like(value)\n", 229 | " value_next[:, :-1] = value[:, 1:].clone()\n", 230 | "\n", 231 | " #在value中融合reward,kl\n", 232 | " diff = reward_kl + value_next - value\n", 233 | "\n", 234 | " #蒙特卡洛采样法估计Q函数,每个时刻的价值,等于后续所有价值的加权求和\n", 235 | " #这里计算的其实就是adv\n", 236 | " delta = []\n", 237 | " for i in range(diff.shape[1]):\n", 238 | " s = 0\n", 239 | " for j in range(i, diff.shape[1]):\n", 240 | " s += diff[:, j] * 0.95**(j - i)\n", 241 | " delta.append(s)\n", 242 | "\n", 243 | " return torch.stack(delta, dim=1)\n", 244 | "\n", 245 | "\n", 246 | "#只用一次循环就计算出delta,计算效率提高很多\n", 247 | "def get_delta_fast(value, reward_kl):\n", 248 | " delta = []\n", 249 | "\n", 250 | " for i in reversed(range(reward_kl.shape[1])):\n", 251 | " value_next = 0\n", 252 | " if i < reward_kl.shape[1] - 1:\n", 253 | " value_next = value[:, i + 1]\n", 254 | "\n", 255 | " diff = reward_kl[:, i] + value_next - value[:, i]\n", 256 | "\n", 257 | " diff_last = 0\n", 258 | " if len(delta):\n", 259 | " diff_last = delta[-1]\n", 260 | "\n", 261 | " delta.append(diff + 0.95 * diff_last)\n", 262 | "\n", 263 | " return torch.stack(delta[::-1]).transpose(0, 1)\n", 264 | "\n", 265 | "\n", 266 | "#测试两个函数是等价的,误差是由于计算机精度导致的\n", 267 | "for _ in range(200):\n", 268 | " value = torch.randn(64, 26)\n", 269 | " reward_kl = torch.randn(64, 26)\n", 270 | "\n", 271 | " assert (get_delta_note(value, reward_kl) -\n", 272 | " get_delta_fast(value, reward_kl)).abs().max() < 1e-5\n", 273 | "'success'" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 7, 279 | "id": "9d600b60", 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "from trl.core import clip_by_value, logprobs_from_logits, masked_mean, masked_whiten\n", 284 | "\n", 285 | "\n", 286 | "class PPOTrainer:\n", 287 | "\n", 288 | " def __init__(self):\n", 289 | " self.optimizer = torch.optim.Adam(model_ppo.parameters(), lr=1e-5)\n", 290 | "\n", 291 | " def 
step(self, question, answer, reward):\n", 292 | " with torch.no_grad():\n", 293 | " #编码\n", 294 | " token = [q.tolist() + a.tolist() for q, a in zip(question, answer)]\n", 295 | " input_ids, attention_mask = tokenizer.batch_pad(token=token)\n", 296 | " del token\n", 297 | " input_ids = torch.LongTensor(input_ids).to(device)\n", 298 | " attention_mask = torch.LongTensor(attention_mask).to(device)\n", 299 | "\n", 300 | " #question和answer不需要内容,只需要长度信息即可\n", 301 | " lens_q = [question.shape[1]] * len(question)\n", 302 | " lens_a = []\n", 303 | "\n", 304 | " for a in answer:\n", 305 | " if tokenizer.encoder['E'] in a:\n", 306 | " lens_a.append(a.tolist().index(tokenizer.encoder['E']) + 1)\n", 307 | " continue\n", 308 | " lens_a.append(len(a))\n", 309 | "\n", 310 | " del question\n", 311 | " del answer\n", 312 | "\n", 313 | " #根据question计算answer的概率,并计算每个动作的分数\n", 314 | " prob_log, value, mask = self.batched_forward_pass(\n", 315 | " model_ppo, input_ids, attention_mask, lens_q, lens_a)\n", 316 | "\n", 317 | " #使用ref模型计算概率,这是为了计算kl散度\n", 318 | " prob_log_ref, _, _ = self.batched_forward_pass(\n", 319 | " model_ppo_ref, input_ids, attention_mask, lens_q, lens_a)\n", 320 | "\n", 321 | " #计算两份概率的kl散度,并融入reward\n", 322 | " reward = self.compute_rewards(reward, prob_log, prob_log_ref, mask)\n", 323 | "\n", 324 | " #计算delta和target,用于计算loss\n", 325 | " value, delta, target = self.compute_advantages(value, reward, mask)\n", 326 | "\n", 327 | " #每批数据循环N次模型\n", 328 | " for _ in range(4):\n", 329 | " #每次算一个数据\n", 330 | " for i in range(len(input_ids)):\n", 331 | " #重新计算概率和value\n", 332 | " prob_log_new, value_new, _ = self.batched_forward_pass(\n", 333 | " model_ppo, input_ids[i].unsqueeze(0),\n", 334 | " attention_mask[i].unsqueeze(0), [lens_q[i]], [lens_a[i]])\n", 335 | "\n", 336 | " #根据新旧概率求出变化率,进而求出loss\n", 337 | " #根据target和value的差可以计算出另外一份loss\n", 338 | " loss = self.get_loss(prob_log[i].unsqueeze(0),\n", 339 | " value[i].unsqueeze(0), prob_log_new,\n", 340 | " value_new, 
mask[i].unsqueeze(0),\n", 341 | " delta[i].unsqueeze(0),\n", 342 | " target[i].unsqueeze(0))\n", 343 | "\n", 344 | " if not loss:\n", 345 | " continue\n", 346 | "\n", 347 | " loss.backward()\n", 348 | " #torch.nn.utils.clip_grad_norm_(model_ppo.parameters(), 1.0)\n", 349 | " self.optimizer.step()\n", 350 | " self.optimizer.zero_grad()\n", 351 | "\n", 352 | " def batched_forward_pass(self, model, input_ids, attention_mask, lens_q,\n", 353 | " lens_a):\n", 354 | " logits, value = model(input_ids=input_ids,\n", 355 | " attention_mask=attention_mask)\n", 356 | "\n", 357 | " #取每个字的概率对数\n", 358 | " prob_log = logprobs_from_logits(logits[:, :-1], input_ids[:, 1:])\n", 359 | "\n", 360 | " #是预测结果并且不是PAD的位置是1\n", 361 | " mask = torch.zeros_like(attention_mask)\n", 362 | " mask[:, :-1] = attention_mask[:, 1:]\n", 363 | " for i in range(len(input_ids)):\n", 364 | " start = lens_q[i] - 1\n", 365 | " end = start + lens_a[i]\n", 366 | " mask[i, :start] = 0\n", 367 | " mask[i, end:] = 0\n", 368 | "\n", 369 | " #对最后一个字的预测没有意义,直接丢弃\n", 370 | " value = value[:, :-1]\n", 371 | " mask = mask[:, :-1]\n", 372 | "\n", 373 | " return prob_log, value, mask\n", 374 | "\n", 375 | " def compute_rewards(self, reward, prob_log, prob_log_ref, mask):\n", 376 | " reward_kl = []\n", 377 | "\n", 378 | " for i in range(len(reward)):\n", 379 | " #求两份概率的kl散度\n", 380 | " kl = self.get_kl(prob_log[i], prob_log_ref[i]) * -0.2\n", 381 | "\n", 382 | " #把reward加在最后一个字的kl散度上\n", 383 | " if (mask[i] == 0).all():\n", 384 | " #print('all 0')\n", 385 | " idx = 0\n", 386 | " else:\n", 387 | " idx = mask[i].nonzero()[-1].item()\n", 388 | " kl[idx] += reward[i]\n", 389 | "\n", 390 | " reward_kl.append(kl)\n", 391 | "\n", 392 | " return torch.stack(reward_kl)\n", 393 | "\n", 394 | " def compute_advantages(self, value, reward_kl, mask):\n", 395 | " value = value * mask\n", 396 | " reward_kl = reward_kl * mask\n", 397 | "\n", 398 | " #这里计算delta的过程,可以看上面的注释.\n", 399 | " delta = []\n", 400 | " for i in 
reversed(range(reward_kl.shape[1])):\n", 401 | " value_next = 0\n", 402 | " if i < reward_kl.shape[1] - 1:\n", 403 | " value_next = value[:, i + 1]\n", 404 | "\n", 405 | " diff = reward_kl[:, i] + value_next - value[:, i]\n", 406 | "\n", 407 | " diff_last = 0\n", 408 | " if len(delta):\n", 409 | " diff_last = delta[-1]\n", 410 | "\n", 411 | " delta.append(diff + 0.95 * diff_last)\n", 412 | "\n", 413 | " delta = torch.stack(delta[::-1]).transpose(0, 1)\n", 414 | "\n", 415 | " #定义target,它估计了理想的value值\n", 416 | " target = delta + value\n", 417 | " delta = masked_whiten(delta, mask)\n", 418 | "\n", 419 | " return value, delta, target\n", 420 | "\n", 421 | " def get_loss(self, prob_log, value, prob_log_new, value_new, mask, delta,\n", 422 | " target):\n", 423 | "\n", 424 | " #对数概率,相除变相减,取exp后还原为商,即两个模型输出logits的变化率\n", 425 | " ratio = (prob_log_new - prob_log).exp()\n", 426 | "\n", 427 | " #如果变化率太过于剧烈,可能是发生了震荡,跳过\n", 428 | " if masked_mean(ratio, mask).item() > 10:\n", 429 | " #print('skip', masked_mean(ratio, mask).item())\n", 430 | " return None\n", 431 | "\n", 432 | " #先算两个value的loss,简单的算mse loss就可以了\n", 433 | " loss_vf1 = (value_new - target)**2\n", 434 | " #数值裁剪,很显然是为了缓解自举\n", 435 | " loss_vf2 = clip_by_value(value_new, value - 0.2, value + 0.2)\n", 436 | " loss_vf2 = (loss_vf2 - target)**2\n", 437 | " #两份loss取大的,还是为了缓解自举\n", 438 | " loss_vf = 0.5 * masked_mean(torch.max(loss_vf1, loss_vf2), mask)\n", 439 | "\n", 440 | " #计算ppo loss\n", 441 | " loss_surr1 = -delta * ratio\n", 442 | " #数值裁剪,很显然是为了缓解自举\n", 443 | " loss_surr2 = -delta * ratio.clamp(0.8, 1.2)\n", 444 | " loss_surr = masked_mean(torch.max(loss_surr1, loss_surr2), mask)\n", 445 | "\n", 446 | " return loss_surr + 0.1 * loss_vf\n", 447 | "\n", 448 | " def get_kl(self, a, b):\n", 449 | " method = 'kl'\n", 450 | "\n", 451 | " if method == 'kl':\n", 452 | " return a - b\n", 453 | "\n", 454 | " if method == 'abs':\n", 455 | " return (a - b).abs()\n", 456 | "\n", 457 | " if method == 'mse':\n", 458 | " return (a 
- b).square() * 0.5\n", 459 | "\n", 460 | " if method == 'full':\n", 461 | " return torch.nn.functional.kl_div(a,\n", 462 | " b,\n", 463 | " log_target=True,\n", 464 | " reduction='none')\n", 465 | "\n", 466 | "\n", 467 | "trainer = PPOTrainer()\n", 468 | "\n", 469 | "trainer.step(question, answer, reward)" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 8, 475 | "id": "a1c70e8e", 476 | "metadata": { 477 | "scrolled": false 478 | }, 479 | "outputs": [ 480 | { 481 | "name": "stdout", 482 | "output_type": "stream", 483 | "text": [ 484 | "0 -0.1644705832004547\n", 485 | "S字母:8788= 三五一五五五一五五五一五五五一五二E -1.8778380155563354\n", 486 | "S大写:5525= 22102210221022100E -1.3465244770050049\n", 487 | "100 -0.13109271228313446\n", 488 | "S大写:5438= 21754175417541752E -1.740771770477295\n", 489 | "S字母:9446= 37787778777877784E -0.9766024351119995\n", 490 | "200 0.08248350769281387\n", 491 | "S字母:5045= 二〇一八二〇一八二〇一八二〇一八〇E -1.8033310174942017\n", 492 | "S字母:1891= 七五六四七五六四七五六四七五六四EP -1.385983943939209\n", 493 | "300 0.08983775973320007\n", 494 | "S大写:4288= 17153715371537152E -1.6275231838226318\n", 495 | "S字母:2708= qpieepitepiuepiewE 3.6285555362701416\n", 496 | "400 -0.011256426572799683\n", 497 | "S字母:3165= 一二六六一二六六一二六六一二六六〇E -1.7630295753479004\n", 498 | "S小写:6637= wyttpyttpyttpytriE -1.261555552482605\n", 499 | "500 1.0705180168151855\n", 500 | "S大写:7549= 30199019901990196E -1.7732865810394287\n", 501 | "S字母:7106= wirwyirwyirwyirwrE 3.213606595993042\n", 502 | "600 0.9927632808685303\n", 503 | "S小写:3522= qrpiorpiorpiorpiiE -1.0137674808502197\n", 504 | "S字母:9272= 37091709170917088E -0.6493400931358337\n", 505 | "700 1.253991723060608\n", 506 | "S字母:8986= 三五一四七五一四七五一四七五一四四E -1.7062468528747559\n", 507 | "S数字:2630= 10521052105210520E 3.4852147102355957\n", 508 | "800 1.9489874839782715\n", 509 | "S小写:6515= 二六〇六二六〇六二六〇六二六〇六〇E 3.5120785236358643\n", 510 | "S大写:7809= 三一二三九一二三九一二三九一二三六E -1.1651257276535034\n", 511 | "900 2.2832207679748535\n", 512 | "S小写:9063= 
# Main PPO training loop: sample questions, roll out answers, score them
# with the classifier, and update the policy. Logs every 100 epochs.
for epoch in range(2000):
    label, question = get_question()
    answer = get_answer(question)
    reward = get_reward(question, answer, label)

    trainer.step(question, answer, reward)

    # Periodically print the mean reward and two decoded samples.
    if epoch % 100 == 0:
        print(epoch, reward.mean().item())
        for q, a, r in list(zip(question, answer, reward))[:2]:
            print(tokenizer.decode(q.tolist()),
                  tokenizer.decode(a.tolist()), r.item())

model_ppo.to('cpu')
torch.save(model_ppo, 'ppo.model')