├── .gitignore ├── 0.下载文件.ipynb ├── 1.dpo_trl训练.ipynb ├── 2.dpo_手动训练.ipynb ├── 3.ppo_trl训练.ipynb ├── 4.ppo_手动训练.ipynb └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | **/.ipynb_checkpoints -------------------------------------------------------------------------------- /0.下载文件.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "65d45fa0", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from transformers import AutoTokenizer\n", 11 | "\n", 12 | "AutoTokenizer.from_pretrained('gpt2').save_pretrained('tokenizer/gpt2')\n", 13 | "\n", 14 | "AutoTokenizer.from_pretrained('lvwerra/gpt2-imdb').save_pretrained(\n", 15 | " 'tokenizer/lvwerra/gpt2-imdb')\n", 16 | "\n", 17 | "AutoTokenizer.from_pretrained('lvwerra/distilbert-imdb').save_pretrained(\n", 18 | " 'tokenizer/lvwerra/distilbert-imdb')" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "ba345732", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "from datasets import load_dataset\n", 29 | "\n", 30 | "load_dataset('imdb').save_to_disk('dataset/imdb')\n", 31 | "\n", 32 | "load_dataset('b-mc2/sql-create-context').save_to_disk(\n", 33 | " 'dataset/b-mc2/sql-create-context')" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "4341f356", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "from transformers import AutoModelForCausalLM\n", 44 | "from transformers import AutoModelForSequenceClassification\n", 45 | "\n", 46 | "AutoModelForCausalLM.from_pretrained('gpt2').save_pretrained('model/gpt2')\n", 47 | "\n", 48 | "AutoModelForCausalLM.from_pretrained('lvwerra/gpt2-imdb').save_pretrained(\n", 49 | " 'model/lvwerra/gpt2-imdb')\n", 50 | "\n", 51 | "AutoModelForSequenceClassification.from_pretrained(\n", 52 | " 
'lvwerra/distilbert-imdb').save_pretrained('model/lvwerra/distilbert-imdb')" 53 | ] 54 | } 55 | ], 56 | "metadata": { 57 | "kernelspec": { 58 | "display_name": "Python [conda env:pt2]", 59 | "language": "python", 60 | "name": "conda-env-pt2-py" 61 | }, 62 | "language_info": { 63 | "codemirror_mode": { 64 | "name": "ipython", 65 | "version": 3 66 | }, 67 | "file_extension": ".py", 68 | "mimetype": "text/x-python", 69 | "name": "python", 70 | "nbconvert_exporter": "python", 71 | "pygments_lexer": "ipython3", 72 | "version": "3.10.13" 73 | } 74 | }, 75 | "nbformat": 4, 76 | "nbformat_minor": 5 77 | } 78 | -------------------------------------------------------------------------------- /1.dpo_trl训练.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "12276e41", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 14 | " from .autonotebook import tqdm as notebook_tqdm\n", 15 | "Using sep_token, but it is not set yet.\n", 16 | "Using cls_token, but it is not set yet.\n", 17 | "Using mask_token, but it is not set yet.\n" 18 | ] 19 | }, 20 | { 21 | "data": { 22 | "text/plain": [ 23 | "GPT2TokenizerFast(name_or_path='tokenizer/gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n", 24 | "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", 25 | "}" 26 | ] 27 | }, 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "output_type": "execute_result" 31 | } 32 | ], 33 | "source": [ 34 | "from transformers import AutoTokenizer\n", 35 | "import random\n", 36 | "import torch\n", 37 | "\n", 38 | "tokenizer = AutoTokenizer.from_pretrained('tokenizer/gpt2')\n", 39 | "tokenizer.pad_token_id = 0\n", 40 | "\n", 41 | "tokenizer" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "id": "58af6f4a", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "(DatasetDict({\n", 54 | " train: Dataset({\n", 55 | " features: ['prompt', 'chosen', 'rejected'],\n", 56 | " num_rows: 71745\n", 57 | " })\n", 58 | " test: Dataset({\n", 59 | " features: ['prompt', 'chosen', 'rejected'],\n", 60 | " num_rows: 200\n", 61 | " })\n", 62 | " }),\n", 63 | " {'prompt': 'context:CREATE TABLE TV_series (SHARE INTEGER) question:What is minimum and maximum share of TV series? 
answer:',\n", 64 | " 'chosen': 'SELECT MAX(SHARE), MIN(SHARE) FROM TV_series',\n", 65 | " 'rejected': ''})" 66 | ] 67 | }, 68 | "execution_count": 2, 69 | "metadata": {}, 70 | "output_type": "execute_result" 71 | } 72 | ], 73 | "source": [ 74 | "from datasets import load_from_disk\n", 75 | "\n", 76 | "dataset = load_from_disk('dataset/b-mc2/sql-create-context')['train']\n", 77 | "\n", 78 | "\n", 79 | "def f(data):\n", 80 | " question = 'context:%s question:%s answer:' % (data['context'],\n", 81 | " data['question'])\n", 82 | " answer = data['answer']\n", 83 | " return {'question': question, 'answer': answer}\n", 84 | "\n", 85 | "\n", 86 | "dataset = dataset.map(f, remove_columns=['context'])\n", 87 | "\n", 88 | "\n", 89 | "def f(data):\n", 90 | " question = len(tokenizer.encode(data['question']))\n", 91 | " answer = len(tokenizer.encode(data['answer']))\n", 92 | " return 25 <= question <= 65 and 10 <= answer <= 35\n", 93 | "\n", 94 | "\n", 95 | "dataset = dataset.filter(f)\n", 96 | "\n", 97 | "\n", 98 | "def f(data):\n", 99 | " return {\n", 100 | " 'prompt': data['question'],\n", 101 | " 'chosen': data['answer'],\n", 102 | " 'rejected': ''\n", 103 | " }\n", 104 | "\n", 105 | "\n", 106 | "dataset = dataset.map(f, remove_columns=['question', 'answer'])\n", 107 | "dataset = dataset.train_test_split(test_size=200)\n", 108 | "\n", 109 | "dataset, dataset['train'][0]" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 3, 115 | "id": "88954533", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "from transformers import AutoModelForCausalLM\n", 120 | "\n", 121 | "model_dpo = AutoModelForCausalLM.from_pretrained('model/gpt2').to('cuda')\n", 122 | "model_dpo_ref = AutoModelForCausalLM.from_pretrained('model/gpt2').to('cuda')" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 4, 128 | "id": "8dba89d3", 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "data": { 133 | "text/plain": [ 134 | 
"'context:CREATE TABLE department (num_employees INTEGER, ranking INTEGER) question:What is the average number of employees of the departments whose rank is between 10 and 15? answer:10\\n\\nThe answer is:\\n\\nThe answer is:\\n\\nThe answer is:\\n\\nThe answer is:\\n\\nThe answer is:\\n\\nThe answer'" 135 | ] 136 | }, 137 | "execution_count": 4, 138 | "metadata": {}, 139 | "output_type": "execute_result" 140 | } 141 | ], 142 | "source": [ 143 | "import torch\n", 144 | "import random\n", 145 | "\n", 146 | "\n", 147 | "@torch.no_grad()\n", 148 | "def generate(input_ids):\n", 149 | " lens = input_ids.shape[1]\n", 150 | " while True:\n", 151 | " out = model_dpo(input_ids=input_ids)\n", 152 | " topk = out['logits'][0, -1].topk(1)\n", 153 | "\n", 154 | " values = topk.values.softmax(0).tolist()\n", 155 | " indices = topk.indices.tolist()\n", 156 | " next_word = random.choices(indices, weights=values)\n", 157 | "\n", 158 | " next_word = torch.LongTensor(next_word).unsqueeze(0).to('cuda')\n", 159 | " input_ids = torch.cat([input_ids, next_word], dim=1)\n", 160 | "\n", 161 | " if input_ids.shape[1] - lens >= 35:\n", 162 | " break\n", 163 | "\n", 164 | " if input_ids[0, -1] == tokenizer.eos_token_id:\n", 165 | " break\n", 166 | "\n", 167 | " return input_ids\n", 168 | "\n", 169 | "\n", 170 | "input_ids = 'context:CREATE TABLE department (num_employees INTEGER, ranking INTEGER) question:What is the average number of employees of the departments whose rank is between 10 and 15? 
answer:'\n", 171 | "input_ids = tokenizer.encode(input_ids, return_tensors='pt').to('cuda')\n", 172 | "\n", 173 | "out = generate(input_ids)\n", 174 | "\n", 175 | "tokenizer.decode(out[0])" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 5, 181 | "id": "16be7af3", 182 | "metadata": { 183 | "scrolled": false 184 | }, 185 | "outputs": [ 186 | { 187 | "name": "stderr", 188 | "output_type": "stream", 189 | "text": [ 190 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n", 191 | " warnings.warn(\n", 192 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py:291: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments we have set it for you, but you should do it yourself in the future.\n", 193 | " warnings.warn(\n", 194 | "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n" 195 | ] 196 | }, 197 | { 198 | "data": { 199 | "text/html": [ 200 | "\n", 201 | "
\n", 202 | " \n", 203 | " \n", 204 | " [2000/2000 19:19, Epoch 0/1]\n", 205 | "
\n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | "
Step | Training Loss
500 | 0.001800
1000 | 0.000400
1500 | 0.000500
2000 | 0.000300

" 232 | ], 233 | "text/plain": [ 234 | "" 235 | ] 236 | }, 237 | "metadata": {}, 238 | "output_type": "display_data" 239 | }, 240 | { 241 | "name": "stdout", 242 | "output_type": "stream", 243 | "text": [ 244 | "100\n", 245 | "context:CREATE TABLE table_25330991_3 (james_e_holmes VARCHAR, reidsville VARCHAR) question:Name the james e. holmes for erselle young answer:SELECT name FROM table_25330991_3 WHERE reidsville = \"young\"<|endoftext|>\n", 246 | "=================\n", 247 | "SELECT james_e_holmes FROM table_25330991_3 WHERE reidsville = \"Erselle Young\"\n", 248 | "=================\n", 249 | "200\n", 250 | "context:CREATE TABLE table_name_8 (weekly_rank VARCHAR, official_ratings__millions_ VARCHAR, show VARCHAR) question:Which Weekly Rank for a Live Final Show has an Official Ratings (millions) greater than 5.2? answer:SELECT weekly_rank FROM table_name_8 WHERE official_ratings__millions_ = 5.2 AND show = \"Live Final Show\" AND show = \"\n", 251 | "=================\n", 252 | "SELECT weekly_rank FROM table_name_8 WHERE official_ratings__millions_ > 5.2 AND show = \"live final\"\n", 253 | "=================\n", 254 | "300\n", 255 | "context:CREATE TABLE table_11173827_1 (english_title VARCHAR, finale VARCHAR, peak VARCHAR) question:What is the english title that has finale as 33 and peak as 42? answer:SELECT title FROM table_11173827_1 WHERE finale = 33 AND peak = 42<|endoftext|>\n", 256 | "=================\n", 257 | "SELECT english_title FROM table_11173827_1 WHERE finale = 33 AND peak = 42\n", 258 | "=================\n", 259 | "400\n", 260 | "context:CREATE TABLE table_name_80 (year INTEGER, date VARCHAR, designated_home VARCHAR) question:What year had a date of TBA where the Oakland raiders were the home team? 
answer:SELECT year FROM table_name_80 WHERE designated_home = \"Oakland\"<|endoftext|>\n", 261 | "=================\n", 262 | "SELECT AVG(year) FROM table_name_80 WHERE date = \"tba\" AND designated_home = \"oakland raiders\"\n", 263 | "=================\n", 264 | "500\n", 265 | "context:CREATE TABLE table_name_66 (cost___us$__ VARCHAR, open_source VARCHAR) question:What is the cost of an Open Source that is no? answer:SELECT cost___us$__ FROM table_name_66 WHERE open_source = \"Open Source\"<|endoftext|>\n", 266 | "=================\n", 267 | "SELECT cost___us$__ FROM table_name_66 WHERE open_source = \"no\"\n", 268 | "=================\n", 269 | "600\n", 270 | "context:CREATE TABLE table_name_11 (winner VARCHAR, year VARCHAR) question:What is Winner, when Year is 2013? answer:SELECT Winner FROM table_name_11 WHERE year = 2013 AND year = 2013 AND year = 2013 AND year = 2013 AND year = 2013 AND year = 2013 AND year =\n", 271 | "=================\n", 272 | "SELECT winner FROM table_name_11 WHERE year = 2013\n", 273 | "=================\n", 274 | "700\n", 275 | "context:CREATE TABLE table_name_44 (years VARCHAR, displacement VARCHAR) question:Which years have a displacement of 1816cc? answer:SELECT years FROM table_name_44 WHERE displacement = 1816cc<|endoftext|>\n", 276 | "=================\n", 277 | "SELECT years FROM table_name_44 WHERE displacement = \"1816cc\"\n", 278 | "=================\n", 279 | "800\n", 280 | "context:CREATE TABLE table_26982362_2 (original_airdate VARCHAR, production_code VARCHAR) question:The episode with production code 693-002, has how many original airdates? 
answer:SELECT production_code FROM table_26982362_2 WHERE production_code = 693-002<|endoftext|>\n", 281 | "=================\n", 282 | "SELECT COUNT(original_airdate) FROM table_26982362_2 WHERE production_code = \"693-002\"\n", 283 | "=================\n", 284 | "900\n", 285 | "context:CREATE TABLE table_name_24 (games INTEGER, term_ VARCHAR, c_ VARCHAR) question:What is the of games when for the term [c] of 1969 – 1973? answer:SELECT games FROM table_name_24 WHERE term_ = \"1969\"<|endoftext|>\n", 286 | "=================\n", 287 | "SELECT SUM(games) FROM table_name_24 WHERE term_[c_] = \"1969 – 1973\"\n", 288 | "=================\n", 289 | "1000\n", 290 | "context:CREATE TABLE table_18498743_1 (october_20 VARCHAR, _2008 VARCHAR, mexico VARCHAR) question:what is the october 20, 2008 stat where mexico stat is romania answer:SELECT date FROM table_18498743_1 WHERE date = \"2008\"<|endoftext|>\n", 291 | "=================\n", 292 | "SELECT october_20, _2008 FROM table_18498743_1 WHERE mexico = \"Romania\"\n", 293 | "=================\n", 294 | "1100\n", 295 | "context:CREATE TABLE table_3005915_3 (starts INTEGER) question:What is the maximum number of starts? answer:SELECT start FROM table_3005915_3 WHERE start < 3005915_3<|endoftext|>\n", 296 | "=================\n", 297 | "SELECT MAX(starts) FROM table_3005915_3\n", 298 | "=================\n", 299 | "1200\n", 300 | "context:CREATE TABLE table_name_17 (date VARCHAR, venue VARCHAR) question:What was the date of the game at Lake Oval? answer:SELECT date FROM table_name_17 WHERE venue = \"Lake Oval\"<|endoftext|>\n", 301 | "=================\n", 302 | "SELECT date FROM table_name_17 WHERE venue = \"lake oval\"\n", 303 | "=================\n", 304 | "1300\n", 305 | "context:CREATE TABLE table_name_21 (player VARCHAR, total VARCHAR, year_s__won VARCHAR) question:Who won in 1988 with a total less than 287? 
answer:SELECT player FROM table_name_21 WHERE total = 287<|endoftext|>\n", 306 | "=================\n", 307 | "SELECT player FROM table_name_21 WHERE total < 287 AND year_s__won = \"1988\"\n", 308 | "=================\n", 309 | "1400\n", 310 | "context:CREATE TABLE table_23224961_1 (oberbayern_b VARCHAR, oberpfalz VARCHAR) question:When fc schwandorf is the oberpfalz what is the oberbayern b? answer:SELECT fc, oberpfalz FROM table_23224961_1 WHERE oberpfalz = \"schwandorf\"<|endoftext|>\n", 311 | "=================\n", 312 | "SELECT oberbayern_b FROM table_23224961_1 WHERE oberpfalz = \"FC Schwandorf\"\n", 313 | "=================\n", 314 | "1500\n", 315 | "context:CREATE TABLE table_name_88 (lane INTEGER, name VARCHAR) question:What is the average lane that is called rebecca brown? answer:SELECT MAX(lane) FROM table_name_88 WHERE name = \"rebecca brown\"<|endoftext|>\n", 316 | "=================\n", 317 | "SELECT AVG(lane) FROM table_name_88 WHERE name = \"rebecca brown\"\n", 318 | "=================\n", 319 | "1600\n", 320 | "context:CREATE TABLE table_name_24 (games INTEGER, term_ VARCHAR, c_ VARCHAR) question:What is the of games when for the term [c] of 1969 – 1973? answer:SELECT games FROM table_name_24 WHERE term_ = \"1969\"<|endoftext|>\n", 321 | "=================\n", 322 | "SELECT SUM(games) FROM table_name_24 WHERE term_[c_] = \"1969 – 1973\"\n", 323 | "=================\n", 324 | "1700\n", 325 | "context:CREATE TABLE table_name_87 (type VARCHAR, location VARCHAR) question:What type of Bridge is in Stanley? answer:SELECT type FROM table_name_87 WHERE location = \"Stanley\"<|endoftext|>\n", 326 | "=================\n", 327 | "SELECT type FROM table_name_87 WHERE location = \"stanley\"\n", 328 | "=================\n", 329 | "1800\n", 330 | "context:CREATE TABLE table_name_34 (attendance VARCHAR, date VARCHAR, week VARCHAR) question:What was the attendance of the game on December 13, 1970? 
answer:SELECT attendance FROM table_name_34 WHERE date = \"2013-12-13T00:00:00Z\"<|endoftext|>\n", 331 | "=================\n", 332 | "SELECT COUNT(attendance) FROM table_name_34 WHERE date = \"december 13, 1970\" AND week > 13\n", 333 | "=================\n", 334 | "1900\n", 335 | "context:CREATE TABLE constructors (nationality VARCHAR) question:What are the numbers of constructors for different nationalities? answer:SELECT nationality FROM constructors WHERE nationality = \"nationality\" AND country = \"nationality\"<|endoftext|>\n", 336 | "=================\n", 337 | "SELECT COUNT(*), nationality FROM constructors GROUP BY nationality\n", 338 | "=================\n", 339 | "2000\n", 340 | "context:CREATE TABLE table_name_93 (home_team VARCHAR, venue VARCHAR) question:Which home team played at MCG? answer:SELECT venue FROM table_name_93 WHERE venue = \"MCG\"<|endoftext|>\n", 341 | "=================\n", 342 | "SELECT home_team FROM table_name_93 WHERE venue = \"mcg\"\n", 343 | "=================\n" 344 | ] 345 | }, 346 | { 347 | "data": { 348 | "text/plain": [ 349 | "TrainOutput(global_step=2000, training_loss=0.0007831706330180168, metrics={'train_runtime': 1160.3168, 'train_samples_per_second': 27.579, 'train_steps_per_second': 1.724, 'total_flos': 0.0, 'train_loss': 0.0007831706330180168, 'epoch': 0.45})" 350 | ] 351 | }, 352 | "execution_count": 5, 353 | "metadata": {}, 354 | "output_type": "execute_result" 355 | } 356 | ], 357 | "source": [ 358 | "from transformers import TrainingArguments, TrainerCallback\n", 359 | "from trl import DPOTrainer\n", 360 | "import random\n", 361 | "\n", 362 | "args = TrainingArguments(per_device_train_batch_size=16,\n", 363 | " max_steps=2000,\n", 364 | " learning_rate=1e-5,\n", 365 | " evaluation_strategy='no',\n", 366 | " optim='rmsprop',\n", 367 | " report_to='none',\n", 368 | " save_strategy='no',\n", 369 | " output_dir='output_dir')\n", 370 | "\n", 371 | "\n", 372 | "class MyCallback(TrainerCallback):\n", 373 | "\n", 374 | 
" def on_step_end(self, args, state, control, **kwargs):\n", 375 | " if state.global_step % 100 == 0:\n", 376 | " print(state.global_step)\n", 377 | "\n", 378 | " data = random.choice(dataset['test'])\n", 379 | " input_ids = tokenizer.encode(data['prompt'],\n", 380 | " return_tensors='pt').to('cuda')\n", 381 | "\n", 382 | " out = generate(input_ids)\n", 383 | "\n", 384 | " print(tokenizer.decode(out[0]))\n", 385 | " print('=================')\n", 386 | " print(data['chosen'])\n", 387 | " print('=================')\n", 388 | "\n", 389 | "\n", 390 | "trainer = DPOTrainer(model_dpo,\n", 391 | " model_dpo_ref,\n", 392 | " args=args,\n", 393 | " beta=0.1,\n", 394 | " train_dataset=dataset['train'],\n", 395 | " tokenizer=tokenizer,\n", 396 | " max_length=100,\n", 397 | " max_target_length=100,\n", 398 | " max_prompt_length=100,\n", 399 | " callbacks=[MyCallback()])\n", 400 | "\n", 401 | "trainer.train()" 402 | ] 403 | } 404 | ], 405 | "metadata": { 406 | "kernelspec": { 407 | "display_name": "Python [conda env:pt2]", 408 | "language": "python", 409 | "name": "conda-env-pt2-py" 410 | }, 411 | "language_info": { 412 | "codemirror_mode": { 413 | "name": "ipython", 414 | "version": 3 415 | }, 416 | "file_extension": ".py", 417 | "mimetype": "text/x-python", 418 | "name": "python", 419 | "nbconvert_exporter": "python", 420 | "pygments_lexer": "ipython3", 421 | "version": "3.10.13" 422 | } 423 | }, 424 | "nbformat": 4, 425 | "nbformat_minor": 5 426 | } 427 | -------------------------------------------------------------------------------- /2.dpo_手动训练.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "03697c23", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. 
Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 14 | " from .autonotebook import tqdm as notebook_tqdm\n", 15 | "Using sep_token, but it is not set yet.\n", 16 | "Using cls_token, but it is not set yet.\n", 17 | "Using mask_token, but it is not set yet.\n" 18 | ] 19 | }, 20 | { 21 | "data": { 22 | "text/plain": [ 23 | "GPT2TokenizerFast(name_or_path='tokenizer/gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n", 24 | "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", 25 | "}" 26 | ] 27 | }, 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "output_type": "execute_result" 31 | } 32 | ], 33 | "source": [ 34 | "from transformers import AutoTokenizer\n", 35 | "import random\n", 36 | "import torch\n", 37 | "\n", 38 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 39 | "\n", 40 | "tokenizer = AutoTokenizer.from_pretrained('tokenizer/gpt2')\n", 41 | "tokenizer.pad_token_id = 0\n", 42 | "\n", 43 | "tokenizer" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "id": "733f6090", 50 | "metadata": { 51 | "scrolled": true 52 | }, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "(DatasetDict({\n", 58 | " train: Dataset({\n", 59 | " features: ['question', 'answer'],\n", 60 | " num_rows: 71745\n", 61 | " })\n", 62 | " test: Dataset({\n", 63 | " features: ['question', 'answer'],\n", 64 | " num_rows: 200\n", 65 | " })\n", 66 | " }),\n", 67 | " {'question': [22866,\n", 68 | " 25,\n", 69 | " 43387,\n", 70 | " 6158,\n", 71 | " 43679,\n", 72 | " 3084,\n", 73 | " 62,\n", 74 | " 3672,\n", 75 | " 62,\n", 76 | " 2414,\n", 77 | " 357,\n", 78 | " 354,\n", 79 | 
" 20297,\n", 80 | " 569,\n", 81 | " 31315,\n", 82 | " 1503,\n", 83 | " 11,\n", 84 | " 614,\n", 85 | " 17828,\n", 86 | " 7156,\n", 87 | " 1137,\n", 88 | " 8,\n", 89 | " 1808,\n", 90 | " 25,\n", 91 | " 2061,\n", 92 | " 318,\n", 93 | " 262,\n", 94 | " 24587,\n", 95 | " 706,\n", 96 | " 10249,\n", 97 | " 30,\n", 98 | " 3280,\n", 99 | " 25],\n", 100 | " 'answer': [46506,\n", 101 | " 24587,\n", 102 | " 16034,\n", 103 | " 3084,\n", 104 | " 62,\n", 105 | " 3672,\n", 106 | " 62,\n", 107 | " 2414,\n", 108 | " 33411,\n", 109 | " 614,\n", 110 | " 1875,\n", 111 | " 10249]})" 112 | ] 113 | }, 114 | "execution_count": 2, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "from datasets import load_from_disk\n", 121 | "\n", 122 | "dataset = load_from_disk('dataset/b-mc2/sql-create-context')['train']\n", 123 | "\n", 124 | "\n", 125 | "def f(data):\n", 126 | " question = 'context:%s question:%s answer:' % (data['context'],\n", 127 | " data['question'])\n", 128 | " answer = data['answer']\n", 129 | "\n", 130 | " question = tokenizer.encode(question, add_special_tokens=False)\n", 131 | " answer = tokenizer.encode(answer, add_special_tokens=False)\n", 132 | "\n", 133 | " return {'question': question, 'answer': answer}\n", 134 | "\n", 135 | "\n", 136 | "dataset = dataset.map(f, remove_columns=['context'])\n", 137 | "\n", 138 | "\n", 139 | "def f(data):\n", 140 | " question = len(data['question'])\n", 141 | " answer = len(data['answer'])\n", 142 | " return 25 <= question <= 65 and 10 <= answer <= 35\n", 143 | "\n", 144 | "\n", 145 | "dataset = dataset.filter(f)\n", 146 | "\n", 147 | "dataset = dataset.train_test_split(test_size=200)\n", 148 | "\n", 149 | "dataset, dataset['train'][0]" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 3, 155 | "id": "0e940d0b", 156 | "metadata": { 157 | "scrolled": true 158 | }, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "({'input_ids': tensor([[22866, 
25, 43387, ..., 0, 0, 0],\n", 164 | " [22866, 25, 43387, ..., 0, 0, 0],\n", 165 | " [22866, 25, 43387, ..., 0, 0, 0],\n", 166 | " ...,\n", 167 | " [22866, 25, 43387, ..., 0, 0, 0],\n", 168 | " [22866, 25, 43387, ..., 0, 0, 0],\n", 169 | " [22866, 25, 43387, ..., 0, 0, 0]], device='cuda:0'),\n", 170 | " 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n", 171 | " [1, 1, 1, ..., 0, 0, 0],\n", 172 | " [1, 1, 1, ..., 0, 0, 0],\n", 173 | " ...,\n", 174 | " [1, 1, 1, ..., 0, 0, 0],\n", 175 | " [1, 1, 1, ..., 0, 0, 0],\n", 176 | " [1, 1, 1, ..., 0, 0, 0]], device='cuda:0'),\n", 177 | " 'label': tensor([[-100, -100, -100, ..., -100, -100, -100],\n", 178 | " [-100, -100, -100, ..., -100, -100, -100],\n", 179 | " [-100, -100, -100, ..., -100, -100, -100],\n", 180 | " ...,\n", 181 | " [-100, -100, -100, ..., -100, -100, -100],\n", 182 | " [-100, -100, -100, ..., -100, -100, -100],\n", 183 | " [-100, -100, -100, ..., -100, -100, -100]], device='cuda:0')},\n", 184 | " {'input_ids': tensor([[22866, 25, 43387, ..., 0, 0, 0],\n", 185 | " [22866, 25, 43387, ..., 0, 0, 0],\n", 186 | " [22866, 25, 43387, ..., 0, 0, 0],\n", 187 | " ...,\n", 188 | " [22866, 25, 43387, ..., 0, 0, 0],\n", 189 | " [22866, 25, 43387, ..., 0, 0, 0],\n", 190 | " [22866, 25, 43387, ..., 0, 0, 0]], device='cuda:0'),\n", 191 | " 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n", 192 | " [1, 1, 1, ..., 0, 0, 0],\n", 193 | " [1, 1, 1, ..., 0, 0, 0],\n", 194 | " ...,\n", 195 | " [1, 1, 1, ..., 0, 0, 0],\n", 196 | " [1, 1, 1, ..., 0, 0, 0],\n", 197 | " [1, 1, 1, ..., 0, 0, 0]], device='cuda:0'),\n", 198 | " 'label': tensor([[-100, -100, -100, ..., -100, -100, -100],\n", 199 | " [-100, -100, -100, ..., -100, -100, -100],\n", 200 | " [-100, -100, -100, ..., -100, -100, -100],\n", 201 | " ...,\n", 202 | " [-100, -100, -100, ..., -100, -100, -100],\n", 203 | " [-100, -100, -100, ..., -100, -100, -100],\n", 204 | " [-100, -100, -100, ..., -100, -100, -100]], device='cuda:0')})" 205 | ] 206 | }, 207 | 
"execution_count": 3, 208 | "metadata": {}, 209 | "output_type": "execute_result" 210 | } 211 | ], 212 | "source": [ 213 | "def get_batch_data():\n", 214 | "\n", 215 | " def pad(data, split, lens):\n", 216 | " #做个白板\n", 217 | " input_ids = torch.full((len(data), lens),\n", 218 | " tokenizer.pad_token_id,\n", 219 | " device=device)\n", 220 | "\n", 221 | " #往白板里黏贴数据\n", 222 | " for i, d in enumerate(data):\n", 223 | " input_ids[i, :len(d)] = torch.LongTensor(d)\n", 224 | "\n", 225 | " attention_mask = (input_ids != tokenizer.pad_token_id).long()\n", 226 | "\n", 227 | " #计算label\n", 228 | " label = input_ids.clone()\n", 229 | " for l, s in zip(label, split):\n", 230 | " #问题和pad的位置是-100\n", 231 | " l[:s] = -100\n", 232 | " l[l == tokenizer.pad_token_id] = -100\n", 233 | "\n", 234 | " return {\n", 235 | " 'input_ids': input_ids,\n", 236 | " 'attention_mask': attention_mask,\n", 237 | " 'label': label\n", 238 | " }\n", 239 | "\n", 240 | " sample = random.choices(dataset['train'], k=16)\n", 241 | " question = [i['question'] for i in sample]\n", 242 | " answer = [i['answer'] for i in sample]\n", 243 | " split = [len(i) for i in question]\n", 244 | "\n", 245 | " #正确的问答\n", 246 | " choice = [\n", 247 | " q + a + [tokenizer.eos_token_id] for q, a in zip(question, answer)\n", 248 | " ]\n", 249 | "\n", 250 | " #错误的回答简单地定义为空回答就可以了\n", 251 | " reject = [q + [tokenizer.eos_token_id] for q, a in zip(question, answer)]\n", 252 | "\n", 253 | " #求最大长度\n", 254 | " lens = max([len(i) for i in choice])\n", 255 | "\n", 256 | " return pad(choice, split, lens), pad(reject, split, lens)\n", 257 | "\n", 258 | "\n", 259 | "get_batch_data()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 4, 265 | "id": "98c47306", 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "class ModelDPO(torch.nn.Module):\n", 270 | "\n", 271 | " def __init__(self):\n", 272 | " super().__init__()\n", 273 | " from transformers import AutoModelForCausalLM\n", 274 | "\n", 275 | 
" self.model = AutoModelForCausalLM.from_pretrained('model/gpt2')\n", 276 | "\n", 277 | " self.to(device)\n", 278 | " self.train()\n", 279 | "\n", 280 | " def forward(self, input_ids, attention_mask):\n", 281 | " out = self.model.transformer(\n", 282 | " input_ids=input_ids,\n", 283 | " attention_mask=attention_mask).last_hidden_state\n", 284 | "\n", 285 | " return self.model.lm_head(out)\n", 286 | "\n", 287 | "\n", 288 | "model_dpo = ModelDPO()\n", 289 | "model_dpo_ref = ModelDPO()" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 5, 295 | "id": "e0e48bda", 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "data": { 300 | "text/plain": [ 301 | "'context:CREATE TABLE table_name_8 (pos VARCHAR, date_from VARCHAR) question:What is Pos., when Date From is \"28 August 2008\"? answer:: \"What is, when is a table?\"\\n\\nThe table_name_name__name_8 is a string that contains the name of the table name.'" 302 | ] 303 | }, 304 | "execution_count": 5, 305 | "metadata": {}, 306 | "output_type": "execute_result" 307 | } 308 | ], 309 | "source": [ 310 | "@torch.no_grad()\n", 311 | "def generate(input_ids):\n", 312 | " lens = input_ids.shape[1]\n", 313 | " while True:\n", 314 | " out = model_dpo(input_ids=input_ids,\n", 315 | " attention_mask=torch.ones_like(input_ids))\n", 316 | " topk = out[0, -1].topk(1)\n", 317 | "\n", 318 | " values = topk.values.softmax(0).tolist()\n", 319 | " indices = topk.indices.tolist()\n", 320 | " next_word = random.choices(indices, weights=values)\n", 321 | "\n", 322 | " next_word = torch.LongTensor(next_word).unsqueeze(0).to('cuda')\n", 323 | " input_ids = torch.cat([input_ids, next_word], dim=1)\n", 324 | "\n", 325 | " if input_ids.shape[1] - lens >= 35:\n", 326 | " break\n", 327 | "\n", 328 | " if input_ids[0, -1] == tokenizer.eos_token_id:\n", 329 | " break\n", 330 | "\n", 331 | " return input_ids\n", 332 | "\n", 333 | "\n", 334 | "input_ids = dataset['test'][0]['question']\n", 335 | "input_ids = 
torch.LongTensor(input_ids).unsqueeze(0).to(device)\n", 336 | "\n", 337 | "out = generate(input_ids)\n", 338 | "\n", 339 | "tokenizer.decode(out[0])" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 6, 345 | "id": "c35e3b28", 346 | "metadata": {}, 347 | "outputs": [ 348 | { 349 | "data": { 350 | "text/plain": [ 351 | "tensor([-58.1640, -61.3286, -39.4775, -71.7320, -62.9812, -67.9002, -35.7873,\n", 352 | " -55.2316, -55.5689, -73.9519, -81.4958, -37.8672, -55.4771, -86.0505,\n", 353 | " -48.4158, -43.4941], device='cuda:0', grad_fn=<SubBackward0>)" 354 | ] 355 | }, 356 | "execution_count": 6, 357 | "metadata": {}, 358 | "output_type": "execute_result" 359 | } 360 | ], 361 | "source": [ 362 | "def get_prob_log(model, choice, reject):\n", 363 | " b = choice['input_ids'].shape[0]\n", 364 | "\n", 365 | " # concatenate the choice and reject inputs so both are scored in one forward pass\n", 366 | " # [2b, lens]\n", 367 | " input_ids = torch.cat([choice['input_ids'], reject['input_ids']], dim=0)\n", 368 | " attention_mask = torch.cat(\n", 369 | " [choice['attention_mask'], reject['attention_mask']], dim=0)\n", 370 | " label = torch.cat([choice['label'], reject['label']], dim=0)\n", 371 | "\n", 372 | " # [2b, lens, vocab]\n", 373 | " out = model(input_ids=input_ids, attention_mask=attention_mask)\n", 374 | "\n", 375 | " # shift by one so each prediction lines up with the token it predicts\n", 376 | " # [2b, lens - 1]\n", 377 | " label = label[:, 1:]\n", 378 | " # [2b, lens - 1, vocab]\n", 379 | " out = out[:, :-1]\n", 380 | "\n", 381 | " # probability of every token; take the log because we want the joint probability\n", 382 | " out = (out.softmax(2) + 1e-8).log()\n", 383 | "\n", 384 | " # pick the log-probability assigned to each label token\n", 385 | " # an index cannot be negative, so -100 is replaced with 0 here\n", 386 | " index = label.clone().unsqueeze(2)\n", 387 | " index[index == -100] = 0\n", 388 | " prob = out.gather(2, index=index).squeeze(2)\n", 389 | "\n", 390 | " # keep only the answer positions (label != -100), then sum the log-probabilities per sequence\n", 391 | " prob = (prob * (label != -100)).sum(1)\n", 392 | "\n", 393 | " # difference between the choice and reject log-probabilities\n", 394 | " return prob[:b] - prob[b:]\n", 395 | "\n", 396 | "\n", 397 | "get_prob_log(model_dpo, *get_batch_data())" 398 | ] 399 | }, 400 | { 401 |
"cell_type": "code", 402 | "execution_count": 7, 403 | "id": "fb049323", 404 | "metadata": { 405 | "scrolled": false 406 | }, 407 | "outputs": [ 408 | { 409 | "name": "stdout", 410 | "output_type": "stream", 411 | "text": [ 412 | "0 context:CREATE TABLE table_name_50 (character_name VARCHAR, voice_actor__english_1998___pioneer_ VARCHAR) question:what character did Laara sadiq play answer:what character did Laara play answer:what character did Laara sadiq play answer:what character did Laara sadiq play answer:what character did Laara play answer\n", 413 | "=========\n", 414 | "SELECT character_name FROM table_name_50 WHERE voice_actor__english_1998___pioneer_ = \"laara sadiq\"\n", 415 | "=========\n", 416 | "100 context:CREATE TABLE table_name_94 (position VARCHAR, pick VARCHAR) question:What is pick 246's position? answer:SELECT position FROM table_name_94 WHERE position = \"pick 246\"<|endoftext|>\n", 417 | "=========\n", 418 | "SELECT position FROM table_name_94 WHERE pick = 246\n", 419 | "=========\n", 420 | "200 context:CREATE TABLE table_name_78 (division_record VARCHAR, school VARCHAR) question:What is the division record for Woodbridge? answer:SELECT division_record FROM table_name_78 WHERE school = \"woodbridge\"<|endoftext|>\n", 421 | "=========\n", 422 | "SELECT division_record FROM table_name_78 WHERE school = \"woodbridge\"\n", 423 | "=========\n", 424 | "300 context:CREATE TABLE table_name_20 (name VARCHAR, nat VARCHAR) question:What's the name of MKD? answer:SELECT name FROM table_name_20 WHERE nat = \"MKD\"<|endoftext|>\n", 425 | "=========\n", 426 | "SELECT name FROM table_name_20 WHERE nat = \"mkd\"\n", 427 | "=========\n", 428 | "400 context:CREATE TABLE table_21313327_1 (written_by VARCHAR, no_in_season VARCHAR) question:In season number 3, who were the writers? 
answer:SELECT written_by FROM table_21313327_1 WHERE no_in_season = 3<|endoftext|>\n", 429 | "=========\n", 430 | "SELECT written_by FROM table_21313327_1 WHERE no_in_season = 3\n", 431 | "=========\n", 432 | "500 context:CREATE TABLE table_name_95 (grid INTEGER, time_retired VARCHAR, laps VARCHAR) question:what is the highest grid when the time/retired is fuel pump and the laps is more than 26? answer:SELECT MAX(grid) FROM table_name_95 WHERE time_retired = \"pump and laps\" AND laps > 26<|endoftext|>\n", 433 | "=========\n", 434 | "SELECT MAX(grid) FROM table_name_95 WHERE time_retired = \"fuel pump\" AND laps > 26\n", 435 | "=========\n", 436 | "600 context:CREATE TABLE table_name_31 (top_10 INTEGER, top_25 VARCHAR, cuts_made VARCHAR) question:What is the average number of top-10s for the major with 2 top-25s and fewer than 10 cuts made? answer:SELECT AVG(top_10) FROM table_name_31 WHERE cuts_made = 2 AND top_25 < 10<|endoftext|>\n", 437 | "=========\n", 438 | "SELECT AVG(top_10) FROM table_name_31 WHERE top_25 = 2 AND cuts_made < 10\n", 439 | "=========\n", 440 | "700 context:CREATE TABLE table_name_94 (position VARCHAR, pick VARCHAR) question:What is pick 246's position? answer:SELECT position FROM table_name_94 WHERE pick = \"246\"<|endoftext|>\n", 441 | "=========\n", 442 | "SELECT position FROM table_name_94 WHERE pick = 246\n", 443 | "=========\n", 444 | "800 context:CREATE TABLE table_17358515_1 (points_2 VARCHAR, team VARCHAR) question:How many points did Goole Town accumulate? answer:SELECT points_2 FROM table_17358515_1 WHERE team = \"Goole Town\"<|endoftext|>\n", 445 | "=========\n", 446 | "SELECT COUNT(points_2) FROM table_17358515_1 WHERE team = \"Goole Town\"\n", 447 | "=========\n", 448 | "900 context:CREATE TABLE table_11677691_9 (hometown VARCHAR, college VARCHAR) question:Which hometown has the college Louisiana State? 
answer:SELECT hometown FROM table_11677691_9 WHERE college = \"Louisiana State\"<|endoftext|>\n", 449 | "=========\n", 450 | "SELECT hometown FROM table_11677691_9 WHERE college = \"Louisiana State\"\n", 451 | "=========\n", 452 | "1000 context:CREATE TABLE table_1341423_40 (candidates VARCHAR, incumbent VARCHAR) question:who are the candidates when the incumbent is lindsey graham? answer:SELECT candidates FROM table_1341423_40 WHERE incumbent = \"Lindsey Graham\"<|endoftext|>\n", 453 | "=========\n", 454 | "SELECT candidates FROM table_1341423_40 WHERE incumbent = \"Lindsey Graham\"\n", 455 | "=========\n", 456 | "1100 context:CREATE TABLE table_name_77 (country VARCHAR, series_premiere VARCHAR) question:Which country had a series that premiered on September 4, 2006? answer:SELECT country FROM table_name_77 WHERE series_premiere = \"september 4, 2006\"<|endoftext|>\n", 457 | "=========\n", 458 | "SELECT country FROM table_name_77 WHERE series_premiere = \"september 4, 2006\"\n", 459 | "=========\n", 460 | "1200 context:CREATE TABLE table_name_35 (director VARCHAR, title VARCHAR) question:Who is the director of Antz? answer:SELECT director FROM table_name_35 WHERE title = \"antz\"<|endoftext|>\n", 461 | "=========\n", 462 | "SELECT director FROM table_name_35 WHERE title = \"antz\"\n", 463 | "=========\n", 464 | "1300 context:CREATE TABLE table_name_55 (role VARCHAR, direction VARCHAR) question:Which role had thulasidas direction? answer:SELECT role FROM table_name_55 WHERE direction = \"thulasidas direction\"<|endoftext|>\n", 465 | "=========\n", 466 | "SELECT role FROM table_name_55 WHERE direction = \"thulasidas\"\n", 467 | "=========\n", 468 | "1400 context:CREATE TABLE table_24223834_3 (us_viewers__in_millions_ VARCHAR, directed_by VARCHAR) question:How many million U.S. viewers watched the episode that Daniel H. Forer directed? answer:SELECT COUNT(us_viewers__in_millions_) FROM table_24223834_3 WHERE directed_by = \"Daniel H. 
Forer\n", 469 | "=========\n", 470 | "SELECT us_viewers__in_millions_ FROM table_24223834_3 WHERE directed_by = \"Daniel H. Forer\"\n", 471 | "=========\n", 472 | "1500 context:CREATE TABLE table_name_89 (score VARCHAR, player VARCHAR) question:What is the score for Jock Hutchison? answer:SELECT score FROM table_name_89 WHERE player = \"jock hutch jock hutchison\"<|endoftext|>\n", 473 | "=========\n", 474 | "SELECT score FROM table_name_89 WHERE player = \"jock hutchison\"\n", 475 | "=========\n", 476 | "1600 context:CREATE TABLE table_24018430_3 (original_air_date VARCHAR, production_code VARCHAR) question:what is the original air date for production code 216? answer:SELECT original_air_date FROM table_24018430_3 WHERE production_code = \"216\"<|endoftext|>\n", 477 | "=========\n", 478 | "SELECT original_air_date FROM table_24018430_3 WHERE production_code = 216\n", 479 | "=========\n", 480 | "1700 context:CREATE TABLE table_name_32 (west VARCHAR, east VARCHAR) question:Who was in the West when ESV Gebensbach was in the East? answer:SELECT west FROM table_name_32 WHERE east = \"ebvesbach\"<|endoftext|>\n", 481 | "=========\n", 482 | "SELECT west FROM table_name_32 WHERE east = \"esv gebensbach\"\n", 483 | "=========\n", 484 | "1800 context:CREATE TABLE table_name_98 (to_par VARCHAR, margin_of_victory VARCHAR, tournament VARCHAR) question:What is To par, when Margin of Victory is \"2 Strokes\", and when Tournament is \"Women's British Open\"? answer:SELECT to_par FROM table_name_98 WHERE margin_of_victory = \"2 strokes\" AND tournament = \"women's British Open\"<|endoftext|>\n", 485 | "=========\n", 486 | "SELECT to_par FROM table_name_98 WHERE margin_of_victory = \"2 strokes\" AND tournament = \"women's british open\"\n", 487 | "=========\n", 488 | "1900 context:CREATE TABLE table_name_71 (wheels INTEGER, type VARCHAR, location VARCHAR) question:What average wheels has accounting as the type, with IBM Collection as the location? 
# Manual DPO training loop: model_dpo is optimised, model_dpo_ref stays fixed.
optimizer = torch.optim.Adam(model_dpo.parameters(),
                             lr=1e-5,
                             betas=(0.9, 0.999),
                             eps=1e-8)

# ModelDPO.__init__ leaves every instance in train mode. The reference policy
# is only used as a fixed baseline, so switch it to eval mode to keep its
# log-probabilities deterministic (presumably this disables GPT-2 dropout;
# verify against the model config).
model_dpo_ref.eval()

for i in range(2000):
    choice, reject = get_batch_data()

    # Chosen-minus-rejected log-probability under the policy and under the
    # frozen reference model.
    prob_log = get_prob_log(model_dpo, choice, reject)
    with torch.no_grad():
        prob_log_ref = get_prob_log(model_dpo_ref, choice, reject)

    # DPO objective: -log(sigmoid(beta * (policy margin - reference margin)))
    # with beta = 0.1. The previous form, log(sigmoid(-beta * x) + 1e-8),
    # pushed the margin in the same direction but was bounded only by the
    # epsilon and gave the largest gradients to pairs the policy already
    # ranked correctly. logsigmoid also removes the need for the epsilon.
    margin = 0.1 * (prob_log - prob_log_ref)
    loss = -torch.nn.functional.logsigmoid(margin).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 100 == 0:
        data = random.choice(dataset['test'])
        input_ids = torch.LongTensor(data['question']).unsqueeze(0).to(device)

        # Decode the progress sample in eval mode, then resume training mode.
        model_dpo.eval()
        out = generate(input_ids)
        model_dpo.train()

        print(i, tokenizer.decode(out[0]))
        print('=========')
        print(tokenizer.decode(data['answer']))
        print('=========')
-------------------------------------------------------------------------------- /3.ppo_trl训练.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "ab835cce", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 14 | " from .autonotebook import tqdm as notebook_tqdm\n", 15 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", 16 | "Using sep_token, but it is not set yet.\n", 17 | "Using cls_token, but it is not set yet.\n", 18 | "Using mask_token, but it is not set yet.\n" 19 | ] 20 | }, 21 | { 22 | "data": { 23 | "text/plain": [ 24 | "GPT2TokenizerFast(name_or_path='tokenizer/lvwerra/gpt2-imdb', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!', 'additional_special_tokens': ['<|endoftext|>']}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n", 25 | "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", 26 | "}" 27 | ] 28 | }, 29 | "execution_count": 1, 30 | "metadata": {}, 31 | "output_type": "execute_result" 32 | } 33 | ], 34 | "source": [ 35 | "from transformers import AutoTokenizer\n", 36 | "import random\n", 37 | "import torch\n", 38 | "\n", 39 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 40 | "\n", 41 | "tokenizer = AutoTokenizer.from_pretrained('tokenizer/lvwerra/gpt2-imdb')\n", 42 | "tokenizer.pad_token_id = 0\n", 43 | "\n", 
# Load IMDB from disk and merge both splits into a single dataset.
dataset = load_from_disk('dataset/imdb')
dataset = concatenate_datasets([dataset['train'], dataset['test']])


def f(data):
    """Map callback: keep the first five token ids of the review text."""
    token_ids = tokenizer.encode(data['text'], add_special_tokens=False)
    return {'question': token_ids[:5]}


dataset = dataset.map(f, remove_columns=['label', 'text'])


def f(data):
    """Filter callback: drop reviews that produced fewer than five tokens."""
    question_len = len(data['question'])
    return question_len == 5


dataset = dataset.filter(f)

dataset, dataset[0]
| " 1,\n", 147 | " 1,\n", 148 | " 1,\n", 149 | " 0,\n", 150 | " 1,\n", 151 | " 1,\n", 152 | " 1,\n", 153 | " 1,\n", 154 | " 1,\n", 155 | " 1,\n", 156 | " 1,\n", 157 | " 0,\n", 158 | " 1,\n", 159 | " 1,\n", 160 | " 1,\n", 161 | " 1,\n", 162 | " 1,\n", 163 | " 0,\n", 164 | " 0,\n", 165 | " 0,\n", 166 | " 0,\n", 167 | " 1,\n", 168 | " 0,\n", 169 | " 1,\n", 170 | " 0,\n", 171 | " 1,\n", 172 | " 1,\n", 173 | " 0,\n", 174 | " 0,\n", 175 | " 1,\n", 176 | " 0,\n", 177 | " 0,\n", 178 | " 0,\n", 179 | " 1,\n", 180 | " 0,\n", 181 | " 0,\n", 182 | " 1,\n", 183 | " 1,\n", 184 | " 1,\n", 185 | " 1,\n", 186 | " 0,\n", 187 | " 1,\n", 188 | " 1,\n", 189 | " 1,\n", 190 | " 1,\n", 191 | " 1,\n", 192 | " 1,\n", 193 | " 0,\n", 194 | " 0,\n", 195 | " 1,\n", 196 | " 1,\n", 197 | " 1,\n", 198 | " 0,\n", 199 | " 0,\n", 200 | " 1,\n", 201 | " 0,\n", 202 | " 0,\n", 203 | " 1,\n", 204 | " 0,\n", 205 | " 0,\n", 206 | " 1,\n", 207 | " 0,\n", 208 | " 1,\n", 209 | " 1,\n", 210 | " 0,\n", 211 | " 0,\n", 212 | " 1,\n", 213 | " 0,\n", 214 | " 1,\n", 215 | " 1,\n", 216 | " 1,\n", 217 | " 0,\n", 218 | " 0,\n", 219 | " 0,\n", 220 | " 0,\n", 221 | " 1,\n", 222 | " 0,\n", 223 | " 0,\n", 224 | " 1,\n", 225 | " 0,\n", 226 | " 1,\n", 227 | " 0,\n", 228 | " 1,\n", 229 | " 1,\n", 230 | " 0],\n", 231 | " [[16, 36623, 290, 15579, 1111, 2495],\n", 232 | " [16, 39, 2743, 16543, 318, 546],\n", 233 | " [16, 5962, 612, 373, 23459, 72],\n", 234 | " [16, 7149, 2618, 508, 2925, 284],\n", 235 | " [15, 40, 3505, 4964, 428, 319],\n", 236 | " [16, 3152, 262, 7396, 11533, 286],\n", 237 | " [15, 43920, 32520, 26494, 15992, 11],\n", 238 | " [15, 27991, 1472, 318, 3737, 530],\n", 239 | " [16, 4598, 407, 307, 6522, 28970],\n", 240 | " [15, 464, 29274, 45270, 11, 7808],\n", 241 | " [15, 1212, 2646, 2900, 510, 319],\n", 242 | " [15, 464, 691, 661, 1312, 561],\n", 243 | " [15, 2025, 21218, 306, 24007, 6833],\n", 244 | " [15, 38202, 4835, 15300, 1525, 7504],\n", 245 | " [15, 40, 550, 8359, 262, 15812],\n", 246 | " [16, 4514, 2000, 
286, 1450, 33743],\n", 247 | " [16, 11006, 290, 24967, 7091, 7012],\n", 248 | " [15, 5962, 314, 1276, 910, 326],\n", 249 | " [15, 21448, 25699, 423, 4367, 262],\n", 250 | " [16, 8499, 4379, 428, 2646, 329],\n", 251 | " [16, 26556, 1468, 4947, 286, 3923],\n", 252 | " [16, 5005, 7697, 360, 5945, 259],\n", 253 | " [15, 3666, 477, 12, 2435, 4004],\n", 254 | " [15, 13615, 284, 2185, 318, 257],\n", 255 | " [16, 32, 21104, 12006, 290, 37928],\n", 256 | " [15, 40, 423, 1239, 1100, 262],\n", 257 | " [15, 1858, 318, 257, 1256, 2642],\n", 258 | " [16, 1212, 3807, 373, 257, 13899],\n", 259 | " [15, 1212, 318, 257, 6635, 7427],\n", 260 | " [15, 40, 836, 470, 760, 703],\n", 261 | " [15, 7, 4561, 9437, 364, 9426],\n", 262 | " [16, 40, 460, 470, 1037, 475],\n", 263 | " [16, 22442, 284, 428, 2646, 11],\n", 264 | " [15, 40, 7342, 705, 4834, 7670],\n", 265 | " [15, 1858, 318, 257, 1621, 357],\n", 266 | " [15, 3198, 286, 2407, 257, 1178],\n", 267 | " [16, 15784, 736, 625, 262, 1613],\n", 268 | " [16, 2514, 307, 5508, 11, 262],\n", 269 | " [15, 1212, 318, 655, 530, 286],\n", 270 | " [15, 40, 1101, 407, 1654, 703],\n", 271 | " [15, 1212, 3807, 318, 34445, 257],\n", 272 | " [16, 7556, 2140, 2646, 546, 2646],\n", 273 | " [16, 40, 1053, 1107, 8359, 428],\n", 274 | " [16, 1, 818, 262, 995, 286],\n", 275 | " [16, 17821, 3923, 787, 262, 1266],\n", 276 | " [16, 1, 464, 7610, 5932, 1],\n", 277 | " [15, 49738, 12754, 22732, 3807, 326],\n", 278 | " [16, 464, 13172, 286, 2094, 11952],\n", 279 | " [16, 5779, 11, 314, 3214, 329],\n", 280 | " [16, 40, 1053, 1100, 617, 286],\n", 281 | " [16, 19419, 257, 4719, 11, 530],\n", 282 | " [16, 40, 716, 523, 9675, 44922],\n", 283 | " [16, 3137, 7342, 340, 11, 502],\n", 284 | " [16, 40, 1816, 284, 1524, 351],\n", 285 | " [15, 1212, 905, 318, 1049, 13],\n", 286 | " [16, 464, 717, 286, 1936, 520],\n", 287 | " [16, 40, 423, 1239, 587, 355],\n", 288 | " [16, 3237, 262, 6247, 7721, 276],\n", 289 | " [16, 40, 3505, 4379, 262, 12268],\n", 290 | " [16, 1212, 318, 257, 
2646, 1312],\n", 291 | " [15, 40, 9159, 314, 423, 257],\n", 292 | " [15, 7975, 286, 262, 867, 7328],\n", 293 | " [15, 21944, 273, 18365, 363, 11925],\n", 294 | " [15, 1858, 338, 645, 779, 2111],\n", 295 | " [16, 1212, 2646, 318, 2089, 13],\n", 296 | " [15, 3844, 517, 621, 1683, 356],\n", 297 | " [16, 11696, 13606, 11, 257, 6036],\n", 298 | " [15, 32, 13206, 11, 5508, 11],\n", 299 | " [16, 1026, 373, 28294, 0, 383],\n", 300 | " [16, 5703, 44207, 1576, 284, 307],\n", 301 | " [15, 40, 550, 30508, 262, 3807],\n", 302 | " [15, 10814, 3730, 290, 4813, 0],\n", 303 | " [16, 32, 6507, 11, 6507, 6504],\n", 304 | " [15, 28718, 5285, 25, 383, 15875],\n", 305 | " [15, 3792, 612, 257, 1492, 11946],\n", 306 | " [15, 34, 49399, 1313, 373, 12132],\n", 307 | " [16, 1639, 561, 1107, 761, 284],\n", 308 | " [15, 4480, 644, 484, 550, 13],\n", 309 | " [15, 1212, 2646, 3607, 649, 3616],\n", 310 | " [16, 40, 550, 7342, 1811, 1528],\n", 311 | " [16, 7554, 21827, 749, 9857, 2646],\n", 312 | " [16, 40, 2497, 326, 428, 3807],\n", 313 | " [16, 6610, 16304, 0, 4162, 319],\n", 314 | " [15, 50, 707, 340, 379, 262],\n", 315 | " [16, 40, 6198, 2497, 428, 319],\n", 316 | " [16, 43977, 25, 383, 4403, 13069],\n", 317 | " [16, 1135, 1094, 88, 9214, 789],\n", 318 | " [16, 5779, 11, 484, 1908, 340],\n", 319 | " [16, 10970, 406, 6158, 6006, 32297],\n", 320 | " [16, 1212, 2646, 3033, 734, 286],\n", 321 | " [15, 76, 40302, 14509, 1559, 318],\n", 322 | " [15, 37, 48437, 338, 645, 343],\n", 323 | " [16, 1026, 338, 4998, 326, 14549],\n", 324 | " [16, 10374, 16057, 9737, 271, 392],\n", 325 | " [16, 3666, 5212, 290, 314, 3332],\n", 326 | " [15, 28951, 64, 318, 257, 845],\n", 327 | " [15, 40, 2497, 428, 3807, 257],\n", 328 | " [16, 49141, 29330, 318, 14702, 11],\n", 329 | " [15, 37887, 618, 257, 5581, 3182],\n", 330 | " [15, 38413, 17049, 1760, 13, 317],\n", 331 | " [16, 40, 7342, 366, 37, 28789],\n", 332 | " [15, 17947, 3978, 11255, 1441, 1363],\n", 333 | " [15, 33481, 5289, 743, 407, 423],\n", 334 | " [16, 1212, 
def get_batch_data():
    """Sample 128 random target labels and 128 random prompts.

    Each prompt is the ``question`` token list of a randomly drawn dataset
    row, prefixed with the token id of the label written as text ('0' or
    '1').

    Returns:
        ``(label, question)``: a list of ints in {0, 1} and a list of
        token-id lists.
    """
    label = random.choices(range(2), k=128)
    rows = random.choices(dataset, k=128)

    question = []
    for target, row in zip(label, rows):
        # The label is turned into its own vocabulary token and put in front
        # of the prompt.
        control_id = tokenizer.convert_tokens_to_ids(str(target))
        question.append([control_id] + row['question'])

    return label, question


get_batch_data()
`optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n", 392 | " warnings.warn(\n", 393 | "Some weights of the model checkpoint at model/lvwerra/gpt2-imdb were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']\n", 394 | "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 395 | "- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 396 | "Some weights of the model checkpoint at model/lvwerra/gpt2-imdb were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']\n", 397 | "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 398 | "- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" 399 | ] 400 | } 401 | ], 402 | "source": [ 403 | "from trl import AutoModelForCausalLMWithValueHead\n", 404 | "\n", 405 | "model_ppo = AutoModelForCausalLMWithValueHead.from_pretrained(\n", 406 | " 'model/lvwerra/gpt2-imdb').to(device)\n", 407 | "model_ppo_ref = AutoModelForCausalLMWithValueHead.from_pretrained(\n", 408 | " 'model/lvwerra/gpt2-imdb').to(device)\n", 409 | "\n", 410 | "for i in model_ppo_ref.parameters():\n", 411 | " i.requires_grad_(False)" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 5, 417 | "id": "805234d3", 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "name": "stderr", 422 | "output_type": "stream", 423 | "text": [ 424 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" 425 | ] 426 | } 427 | ], 428 | "source": [ 429 | "from transformers import AutoModelForSequenceClassification\n", 430 | "\n", 431 | "tokenizer_cls = AutoTokenizer.from_pretrained(\n", 432 | " 'tokenizer/lvwerra/distilbert-imdb')\n", 433 | "model_cls = AutoModelForSequenceClassification.from_pretrained(\n", 434 | " 'model/lvwerra/distilbert-imdb').to(device)\n", 435 | "\n", 436 | "for i in model_cls.parameters():\n", 437 | " i.requires_grad_(False)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 6, 443 | "id": "c3e7a327", 444 | "metadata": { 445 | "scrolled": true 446 | }, 447 | "outputs": [ 448 | { 449 | "data": { 450 | "text/plain": [ 451 | "(tensor([1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1,\n", 452 | " 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0,\n", 453 | 
def get_question():
    """Draw a batch via ``get_batch_data`` and move it onto ``device``.

    Returns:
        ``(label, question)``: a 1-D LongTensor of labels and a list with
        one LongTensor of token ids per prompt.
    """
    raw_label, raw_question = get_batch_data()

    label = torch.LongTensor(raw_label).to(device)

    question = []
    for token_ids in raw_question:
        question.append(torch.LongTensor(token_ids).to(device))

    return label, question


label, question = get_question()

label, question[:10]
" tensor([ 1781, 11, 2208, 45002, 13, 314, 1392, 604, 14, 940,\n", 505 | " 737, 33443, 10544, 13, 383, 7205, 373, 407, 2089, 11,\n", 506 | " 4632, 616, 6000, 18641, 286, 262, 3807, 2277, 616, 1767,\n", 507 | " 13, 406], device='cuda:0'),\n", 508 | " tensor([ 287, 597, 584, 2912, 780, 1111, 423, 617, 922, 2173,\n", 509 | " 329, 514, 13, 1119, 1111, 4893, 845, 880, 703, 6283,\n", 510 | " 262, 1621, 318, 290, 635, 14846, 14397, 319, 257, 1241,\n", 511 | " 326, 338], device='cuda:0'),\n", 512 | " tensor([ 804, 379, 606, 290, 852, 15049, 644, 257, 2823, 428,\n", 513 | " 318, 13, 314, 460, 470, 787, 503, 257, 2060, 1517,\n", 514 | " 546, 262, 3807, 314, 892, 475, 314, 481, 9159, 1312,\n", 515 | " 8288, 428], device='cuda:0'),\n", 516 | " tensor([ 691, 6575, 8442, 286, 262, 705, 2154, 82, 508, 336,\n", 517 | " 2954, 357, 392, 20143, 284, 4656, 737, 50256],\n", 518 | " device='cuda:0'),\n", 519 | " tensor([ 3807, 329, 1105, 290, 1315, 614, 44979, 29847, 1671, 1220,\n", 520 | " 6927, 1671, 11037, 1212, 530, 318, 257, 36364, 5287, 13,\n", 521 | " 3887, 2168, 1281, 994, 468, 4054, 11, 4556, 286, 1781,\n", 522 | " 345, 389], device='cuda:0'),\n", 523 | " tensor([ 290, 27583, 1661, 3947, 1342, 3616, 913, 777, 1528, 13,\n", 524 | " 921, 892, 1353, 29838, 10544, 460, 3491, 287, 257, 4101,\n", 525 | " 5664, 3807, 290, 1577, 9739, 477, 832, 262, 835, 30,\n", 526 | " 1892, 287], device='cuda:0'),\n", 527 | " tensor([ 651, 736, 319, 262, 6614, 13, 1675, 307, 3148, 11, 340, 373,\n", 528 | " 5867, 3625, 890, 13, 843, 340, 373, 546, 9796, 3625, 416, 3439,\n", 529 | " 3625, 890, 13, 1320, 1838, 340, 804, 588], device='cuda:0'),\n", 530 | " tensor([ 13, 48, 13, 379, 262, 7369, 838, 19245, 319, 42490,\n", 531 | " 318, 407, 284, 307, 287, 262, 1551, 1643, 14702, 13,\n", 532 | " 1026, 2125, 470, 1327, 284, 1560, 655, 703, 10927, 3619,\n", 533 | " 40365, 318], device='cuda:0'),\n", 534 | " tensor([ 9875, 422, 923, 284, 5461, 351, 257, 21840, 4226, 290,\n", 535 | " 281, 6275, 3350, 326, 16316, 
def generate(input_ids):
    """Sample up to 32 new tokens from ``model_ppo`` for a batch of prompts.

    Pure sampling: ``top_k`` and ``top_p`` are set to values that leave the
    distribution unfiltered.
    """
    return model_ppo.generate(input_ids=input_ids,
                              min_length=-1,
                              top_k=0.0,
                              top_p=1.0,
                              do_sample=True,
                              pad_token_id=tokenizer.pad_token_id,
                              max_new_tokens=32,
                              eos_token_id=tokenizer.eos_token_id)


def get_answer(question):
    """Generate a continuation for every prompt and return only the new part.

    Args:
        question: List of 1-D LongTensors of prompt token ids.

    Returns:
        List of 1-D LongTensors, one per prompt, holding the generated tokens
        up to and including the first EOS token (if one was produced).
    """
    # Prompts of equal length can be stacked and generated in one batch.
    # Ragged prompts cannot be stacked, so they are generated one by one.
    # The original hard-coded `if True`, which made that fallback unreachable.
    if len({len(q) for q in question}) == 1:
        sequences = generate(torch.stack(question))
    else:
        sequences = [generate(q.unsqueeze(0))[0] for q in question]

    eos_token_id = tokenizer.eos_token_id

    answer = []
    for prompt, sequence in zip(question, sequences):
        # Keep only the generated part.
        generated = sequence[len(prompt):]

        # Batched generation pads finished rows; cut everything after the
        # first EOS. Only the generated tokens are searched, so an EOS id
        # inside the prompt cannot truncate the answer.
        tokens = generated.tolist()
        if eos_token_id in tokens:
            generated = generated[:tokens.index(eos_token_id) + 1]

        answer.append(generated)

    return answer


answer = get_answer(question)

answer[:10]
def get_reward(question, answer, label):
    """Score each generated review with the classifier ``model_cls``.

    Args:
        question: List of prompt token tensors; the first token of each is
            skipped before decoding.
        answer: List of generated token tensors, aligned with ``question``.
        label: LongTensor with the target class index per sample.

    Returns:
        1-D tensor holding, per sample, the classifier logit of the class
        selected by ``label``.
    """
    texts = []
    for q, a in zip(question, answer):
        token_ids = q.tolist()[1:] + a.tolist()
        texts.append(tokenizer.decode(token_ids))

    batch = tokenizer_cls(texts,
                          padding=True,
                          truncation=True,
                          max_length=512,
                          return_tensors='pt').to(device)

    with torch.no_grad():
        logits = model_cls(**batch).logits

    # Pick, for every row, the logit of the requested class.
    index = label.reshape(-1, 1)
    return logits.gather(1, index).squeeze(1)


reward = get_reward(question, answer, label)

reward
"" 652 | ] 653 | }, 654 | "execution_count": 9, 655 | "metadata": {}, 656 | "output_type": "execute_result" 657 | } 658 | ], 659 | "source": [ 660 | "from trl import PPOConfig, PPOTrainer\n", 661 | "\n", 662 | "config = PPOConfig(learning_rate=1e-5, batch_size=128)\n", 663 | "\n", 664 | "trainer = PPOTrainer(config,\n", 665 | " model_ppo,\n", 666 | " model_ppo_ref,\n", 667 | " tokenizer,\n", 668 | " dataset=dataset)\n", 669 | "\n", 670 | "trainer" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": 10, 676 | "id": "a1c70e8e", 677 | "metadata": { 678 | "scrolled": true 679 | }, 680 | "outputs": [ 681 | { 682 | "name": "stderr", 683 | "output_type": "stream", 684 | "text": [ 685 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 686 | "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" 687 | ] 688 | }, 689 | { 690 | "name": "stdout", 691 | "output_type": "stream", 692 | "text": [ 693 | "0 0.09299226105213165\n", 694 | "1First of all, this -> movie is hardly stupid. However, it is one of the worst. It has two very real scenarios.Along with the efforts of a young couple of lost hero -2.487795829772949\n" 695 | ] 696 | }, 697 | { 698 | "name": "stderr", 699 | "output_type": "stream", 700 | "text": [ 701 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 702 | ] 703 | }, 704 | { 705 | "name": "stdout", 706 | "output_type": "stream", 707 | "text": [ 708 | "10 -0.023181095719337463\n", 709 | "0In my book \"Basic -> Instinct\" released 17 years ago I read the book and think that such kind of things are the way the world is now. And quotes like \"All that -1.2708756923675537\n" 710 | ] 711 | }, 712 | { 713 | "name": "stderr", 714 | "output_type": "stream", 715 | "text": [ 716 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 717 | ] 718 | }, 719 | { 720 | "name": "stdout", 721 | "output_type": "stream", 722 | "text": [ 723 | "20 0.31382036209106445\n", 724 | "0Considering it's basically low -> budget and a lot of black nudity and an odd response from the audience, the film is a pretty per ml resolution to an issue worthy of not being all that 1.1325664520263672\n" 725 | ] 726 | }, 727 | { 728 | "name": "stderr", 729 | "output_type": "stream", 730 | "text": [ 731 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 732 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 733 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 734 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 735 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 736 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 737 | ] 738 | }, 739 | { 740 | "name": "stdout", 741 | "output_type": "stream", 742 | "text": [ 743 | "30 0.309106707572937\n", 744 | "0True stories make the best -> work of good actors none that are born into their mid-seventies.<|endoftext|> -2.210801601409912\n" 745 | ] 746 | }, 747 | { 748 | "name": "stderr", 749 | "output_type": "stream", 750 | "text": [ 751 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 752 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 753 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 754 | ] 755 | }, 756 | { 757 | "name": "stdout", 758 | "output_type": "stream", 759 | "text": [ 760 | "40 1.0727498531341553\n", 761 | "1Cheesy 80's horror -> ...... composer: so beautifully chosen, I can say, has a on-screen performance even better.<|endoftext|> 2.0839757919311523\n" 762 | ] 763 | }, 764 | { 765 | "name": "stderr", 766 | "output_type": "stream", 767 | "text": [ 768 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 769 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 770 | ] 771 | }, 772 | { 773 | "name": "stdout", 774 | "output_type": "stream", 775 | "text": [ 776 | "50 1.4121887683868408\n", 777 | "0GBS wrote his own -> script. One name near offense but atmospheric exposition, oddly still bland. Wipeout, rubbish. Violence, i slows.<|endoftext|> 2.3595705032348633\n" 778 | ] 779 | }, 780 | { 781 | "name": "stderr", 782 | "output_type": "stream", 783 | "text": [ 784 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 785 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 786 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 787 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 788 | ] 789 | }, 790 | { 791 | "name": "stdout", 792 | "output_type": "stream", 793 | "text": [ 794 | "60 1.7078416347503662\n", 795 | "0I was never in the -> film.<|endoftext|> 0.9322419166564941\n" 796 | ] 797 | }, 798 | { 799 | "name": "stderr", 800 | "output_type": "stream", 801 | "text": [ 802 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 803 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 804 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 805 | ] 806 | }, 807 | { 808 | "name": "stdout", 809 | "output_type": "stream", 810 | "text": [ 811 | "70 2.1755259037017822\n", 812 | "1I must say, when -> watching this movie, it is amazing.<|endoftext|> 2.4521334171295166\n" 813 | ] 814 | }, 815 | { 816 | "name": "stderr", 817 | "output_type": "stream", 818 | "text": [ 819 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 820 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 821 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 822 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 823 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 824 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 825 | ] 826 | }, 827 | { 828 | "name": "stdout", 829 | "output_type": "stream", 830 | "text": [ 831 | "80 2.1622257232666016\n", 832 | "1This is one of the -> best films of all time.<|endoftext|> 2.7308685779571533\n" 833 | ] 834 | }, 835 | { 836 | "name": "stderr", 837 | "output_type": "stream", 838 | "text": [ 839 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 840 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 841 | ] 842 | }, 843 | { 844 | "name": "stdout", 845 | "output_type": "stream", 846 | "text": [ 847 | "90 2.1979665756225586\n", 848 | "1I resisted watching 15 Park -> but so loved seeing it!<|endoftext|> 2.1706929206848145\n" 849 | ] 850 | }, 851 | { 852 | "name": "stderr", 853 | "output_type": "stream", 854 | "text": [ 855 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 856 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 857 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 858 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 859 | ] 860 | }, 861 | { 862 | "name": "stdout", 863 | "output_type": "stream", 864 | "text": [ 865 | "100 2.2108092308044434\n", 866 | "0This is a Laurel & -> Hardy terrible movie.<|endoftext|> 1.9358655214309692\n" 867 | ] 868 | }, 869 | { 870 | "name": "stderr", 871 | "output_type": "stream", 872 | "text": [ 873 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 874 | ] 875 | }, 876 | { 877 | "name": "stdout", 878 | "output_type": "stream", 879 | "text": [ 880 | "110 2.288398504257202\n", 881 | "1This was the first Mickey -> Renn masteringcom to medieval master and movie makers, and truly supported their development toward English film. highly recommended.<|endoftext|> 2.7567362785339355\n" 882 | ] 883 | }, 884 | { 885 | "name": "stderr", 886 | "output_type": "stream", 887 | "text": [ 888 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 889 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 890 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 891 | ] 892 | }, 893 | { 894 | "name": "stdout", 895 | "output_type": "stream", 896 | "text": [ 897 | "120 2.1484575271606445\n", 898 | "1There have been several films -> and tv. This is great film for the ladies.<|endoftext|> 2.491626739501953\n" 899 | ] 900 | }, 901 | { 902 | "name": "stderr", 903 | "output_type": "stream", 904 | "text": [ 905 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 906 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 907 | ] 908 | }, 909 | { 910 | "name": "stdout", 911 | "output_type": "stream", 912 | "text": [ 913 | "130 2.2387020587921143\n", 914 | "1K Murli Mohan -> in happiness that was tremendously 1934, and terrific movie.<|endoftext|> 2.5319623947143555\n" 915 | ] 916 | }, 917 | { 918 | "name": "stderr", 919 | "output_type": "stream", 920 | "text": [ 921 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 922 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 923 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 924 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 925 | ] 926 | }, 927 | { 928 | "name": "stdout", 929 | "output_type": "stream", 930 | "text": [ 931 | "140 2.111945629119873\n", 932 | "0I don't really know -> how one will figure out how he's going to play the story, but this is a bad movie!<|endoftext|> 2.4522318840026855\n" 933 | ] 934 | }, 935 | { 936 | "name": "stderr", 937 | "output_type": "stream", 938 | "text": [ 939 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 940 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 941 | ] 942 | }, 943 | { 944 | "name": "stdout", 945 | "output_type": "stream", 946 | "text": [ 947 | "150 2.228043556213379\n", 948 | "1I will never get back -> this last glorious film.<|endoftext|> 2.100973606109619\n" 949 | ] 950 | }, 951 | { 952 | "name": "stderr", 953 | "output_type": "stream", 954 | "text": [ 955 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 956 | ] 957 | }, 958 | { 959 | "name": "stdout", 960 | "output_type": "stream", 961 | "text": [ 962 | "160 2.0983808040618896\n", 963 | "1Night of the Twisters -> . Truly one of the best I've seen.<|endoftext|> 2.794154644012451\n" 964 | ] 965 | }, 966 | { 967 | "name": "stderr", 968 | "output_type": "stream", 969 | "text": [ 970 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 971 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 972 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 973 | ] 974 | }, 975 | { 976 | "name": "stdout", 977 | "output_type": "stream", 978 | "text": [ 979 | "170 2.1417503356933594\n", 980 | "1This film was a new -> masterpiece.<|endoftext|> 2.473341464996338\n" 981 | ] 982 | }, 983 | { 984 | "name": "stderr", 985 | "output_type": "stream", 986 | "text": [ 987 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 988 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 989 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 990 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 991 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 992 | ] 993 | }, 994 | { 995 | "name": "stdout", 996 | "output_type": "stream", 997 | "text": [ 998 | "180 2.2145256996154785\n", 999 | "1William Faulkner was -> just perfect! An excellent movie.<|endoftext|> 2.788565158843994\n" 1000 | ] 1001 | }, 1002 | { 1003 | "name": "stderr", 1004 | "output_type": "stream", 1005 | "text": [ 1006 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1007 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1008 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1009 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1010 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1011 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1012 | ] 1013 | }, 1014 | { 1015 | "name": "stdout", 1016 | "output_type": "stream", 1017 | "text": [ 1018 | "190 2.0492758750915527\n", 1019 | "0Mighty Morphin Power -> is an embarrassment.<|endoftext|> 2.2804903984069824\n" 1020 | ] 1021 | }, 1022 | { 1023 | "name": "stderr", 1024 | "output_type": "stream", 1025 | "text": [ 1026 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1027 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1028 | ] 1029 | } 1030 | ], 1031 | "source": [ 1032 | "import warnings\n", 1033 | "\n", 1034 | "warnings.filterwarnings('ignore')\n", 1035 | "\n", 1036 | "for epoch in range(200):\n", 1037 | " label, question = get_question()\n", 1038 | " answer = get_answer(question)\n", 1039 | " reward = get_reward(question, answer, label)\n", 1040 | "\n", 1041 | " trainer.step(question, answer, [i for i in reward])\n", 1042 | "\n", 1043 | " if epoch % 10 == 0:\n", 1044 | " print(epoch, reward.mean().item())\n", 1045 | "\n", 1046 | " question = tokenizer.decode(question[0].tolist())\n", 1047 | " answer = tokenizer.decode(answer[0].tolist())\n", 1048 | " reward = reward[0].item()\n", 1049 | "\n", 1050 | " #0差评,1好评\n", 1051 | " print(question, '->', answer, reward)" 1052 | ] 1053 | } 1054 | ], 1055 | "metadata": { 1056 | "kernelspec": { 1057 | "display_name": "Python [conda env:pt2]", 1058 | "language": "python", 1059 | "name": "conda-env-pt2-py" 1060 | }, 1061 | "language_info": { 1062 | "codemirror_mode": { 1063 | 
"name": "ipython", 1064 | "version": 3 1065 | }, 1066 | "file_extension": ".py", 1067 | "mimetype": "text/x-python", 1068 | "name": "python", 1069 | "nbconvert_exporter": "python", 1070 | "pygments_lexer": "ipython3", 1071 | "version": "3.10.13" 1072 | } 1073 | }, 1074 | "nbformat": 4, 1075 | "nbformat_minor": 5 1076 | } 1077 | -------------------------------------------------------------------------------- /4.ppo_手动训练.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "ab835cce", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 14 | " from .autonotebook import tqdm as notebook_tqdm\n", 15 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", 16 | "Using sep_token, but it is not set yet.\n", 17 | "Using cls_token, but it is not set yet.\n", 18 | "Using mask_token, but it is not set yet.\n" 19 | ] 20 | }, 21 | { 22 | "data": { 23 | "text/plain": [ 24 | "GPT2TokenizerFast(name_or_path='tokenizer/lvwerra/gpt2-imdb', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!', 'additional_special_tokens': ['<|endoftext|>']}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n", 25 | "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", 26 | "}" 27 | ] 28 | }, 29 | "execution_count": 1, 30 | "metadata": {}, 31 | "output_type": "execute_result" 32 | } 33 | ], 34 | 
"source": [ 35 | "from transformers import AutoTokenizer\n", 36 | "import random\n", 37 | "import torch\n", 38 | "\n", 39 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 40 | "\n", 41 | "tokenizer = AutoTokenizer.from_pretrained('tokenizer/lvwerra/gpt2-imdb')\n", 42 | "tokenizer.pad_token_id = 0\n", 43 | "\n", 44 | "tokenizer" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "id": "dc57406a", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "data": { 55 | "text/plain": [ 56 | "(Dataset({\n", 57 | " features: ['question'],\n", 58 | " num_rows: 50000\n", 59 | " }),\n", 60 | " {'question': [40, 26399, 314, 3001, 327]})" 61 | ] 62 | }, 63 | "execution_count": 2, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "from datasets import load_from_disk, concatenate_datasets\n", 70 | "\n", 71 | "dataset = load_from_disk('dataset/imdb')\n", 72 | "dataset = concatenate_datasets([dataset[i] for i in ['train', 'test']])\n", 73 | "\n", 74 | "\n", 75 | "def f(data):\n", 76 | " question = tokenizer.encode(data['text'], add_special_tokens=False)[:5]\n", 77 | " return {'question': question}\n", 78 | "\n", 79 | "\n", 80 | "dataset = dataset.map(f, remove_columns=['label', 'text'])\n", 81 | "\n", 82 | "\n", 83 | "def f(data):\n", 84 | " return len(data['question']) == 5\n", 85 | "\n", 86 | "\n", 87 | "dataset = dataset.filter(f)\n", 88 | "\n", 89 | "dataset, dataset[0]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 3, 95 | "id": "12dbc381", 96 | "metadata": { 97 | "scrolled": true 98 | }, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "([1,\n", 104 | " 1,\n", 105 | " 0,\n", 106 | " 0,\n", 107 | " 1,\n", 108 | " 0,\n", 109 | " 0,\n", 110 | " 1,\n", 111 | " 0,\n", 112 | " 0,\n", 113 | " 0,\n", 114 | " 0,\n", 115 | " 1,\n", 116 | " 1,\n", 117 | " 1,\n", 118 | " 0,\n", 119 | " 0,\n", 120 | " 0,\n", 121 | " 0,\n", 122 | " 1,\n", 123 | " 1,\n", 124 | " 
1,\n", 125 | " 1,\n", 126 | " 0,\n", 127 | " 0,\n", 128 | " 1,\n", 129 | " 1,\n", 130 | " 0,\n", 131 | " 1,\n", 132 | " 0,\n", 133 | " 0,\n", 134 | " 1,\n", 135 | " 0,\n", 136 | " 1,\n", 137 | " 0,\n", 138 | " 0,\n", 139 | " 0,\n", 140 | " 0,\n", 141 | " 0,\n", 142 | " 0,\n", 143 | " 1,\n", 144 | " 1,\n", 145 | " 0,\n", 146 | " 1,\n", 147 | " 0,\n", 148 | " 0,\n", 149 | " 1,\n", 150 | " 0,\n", 151 | " 1,\n", 152 | " 0,\n", 153 | " 1,\n", 154 | " 1,\n", 155 | " 1,\n", 156 | " 1,\n", 157 | " 1,\n", 158 | " 0,\n", 159 | " 1,\n", 160 | " 0,\n", 161 | " 1,\n", 162 | " 1,\n", 163 | " 1,\n", 164 | " 1,\n", 165 | " 1,\n", 166 | " 0,\n", 167 | " 1,\n", 168 | " 0,\n", 169 | " 1,\n", 170 | " 0,\n", 171 | " 0,\n", 172 | " 0,\n", 173 | " 1,\n", 174 | " 0,\n", 175 | " 0,\n", 176 | " 0,\n", 177 | " 1,\n", 178 | " 1,\n", 179 | " 1,\n", 180 | " 0,\n", 181 | " 1,\n", 182 | " 1,\n", 183 | " 0,\n", 184 | " 1,\n", 185 | " 1,\n", 186 | " 1,\n", 187 | " 1,\n", 188 | " 1,\n", 189 | " 0,\n", 190 | " 0,\n", 191 | " 0,\n", 192 | " 1,\n", 193 | " 0,\n", 194 | " 0,\n", 195 | " 1,\n", 196 | " 0,\n", 197 | " 1,\n", 198 | " 1,\n", 199 | " 0,\n", 200 | " 0,\n", 201 | " 1,\n", 202 | " 1,\n", 203 | " 1,\n", 204 | " 1,\n", 205 | " 1,\n", 206 | " 1,\n", 207 | " 0,\n", 208 | " 0,\n", 209 | " 0,\n", 210 | " 0,\n", 211 | " 0,\n", 212 | " 0,\n", 213 | " 0,\n", 214 | " 1,\n", 215 | " 0,\n", 216 | " 1,\n", 217 | " 0,\n", 218 | " 1,\n", 219 | " 0,\n", 220 | " 1,\n", 221 | " 0,\n", 222 | " 0,\n", 223 | " 1,\n", 224 | " 0,\n", 225 | " 0,\n", 226 | " 0,\n", 227 | " 0,\n", 228 | " 0,\n", 229 | " 1,\n", 230 | " 0],\n", 231 | " [[16, 40, 423, 1775, 257, 1256],\n", 232 | " [16, 33, 18058, 286, 262, 4044],\n", 233 | " [15, 1, 464, 13037, 6289, 1],\n", 234 | " [15, 40, 460, 470, 3505, 618],\n", 235 | " [16, 1212, 318, 262, 5290, 286],\n", 236 | " [15, 40, 1816, 284, 766, 428],\n", 237 | " [15, 6109, 530, 815, 766, 428],\n", 238 | " [16, 1, 36091, 6592, 6711, 1],\n", 239 | " [15, 35700, 922, 2646, 422, 3771],\n", 240 
| " [15, 36, 292, 813, 262, 5290],\n", 241 | " [15, 5211, 407, 766, 428, 2646],\n", 242 | " [15, 2504, 338, 257, 2089, 11],\n", 243 | " [16, 2061, 257, 6283, 8137, 318],\n", 244 | " [16, 33, 463, 16988, 290, 4768],\n", 245 | " [16, 35596, 1234, 11, 428, 318],\n", 246 | " [15, 3792, 340, 655, 502, 11],\n", 247 | " [15, 1212, 428, 2406, 286, 2479],\n", 248 | " [15, 40, 550, 878, 257, 4203],\n", 249 | " [15, 40, 4240, 644, 15579, 286],\n", 250 | " [16, 39, 1794, 699, 422, 923],\n", 251 | " [16, 464, 1772, 286, 6409, 16122],\n", 252 | " [16, 28211, 531, 326, 12887, 4462],\n", 253 | " [16, 1212, 3807, 14071, 517, 621],\n", 254 | " [15, 1532, 428, 2646, 4329, 257],\n", 255 | " [15, 32, 5249, 9298, 11, 4713],\n", 256 | " [16, 2061, 1312, 5465, 749, 287],\n", 257 | " [16, 1, 22073, 588, 257, 47808],\n", 258 | " [15, 40, 10014, 4379, 428, 2646],\n", 259 | " [16, 3351, 533, 47114, 318, 900],\n", 260 | " [15, 40, 2492, 470, 12451, 284],\n", 261 | " [15, 13295, 8849, 329, 6319, 2879],\n", 262 | " [16, 1212, 4141, 2646, 318, 13519],\n", 263 | " [15, 18050, 3944, 17731, 318, 1719],\n", 264 | " [16, 18690, 1312, 760, 262, 2656],\n", 265 | " [15, 1, 1639, 460, 7866, 1997],\n", 266 | " [15, 464, 734, 1243, 389, 389],\n", 267 | " [15, 25596, 948, 31078, 18258, 290],\n", 268 | " [15, 6, 32, 17388, 286, 4930],\n", 269 | " [15, 1212, 3807, 318, 2279, 475],\n", 270 | " [15, 1, 17821, 1, 1621, 286],\n", 271 | " [16, 464, 1578, 1829, 3170, 262],\n", 272 | " [16, 83, 359, 18804, 2540, 302],\n", 273 | " [15, 40, 655, 5201, 4964, 428],\n", 274 | " [16, 2348, 292, 11, 3595, 4345],\n", 275 | " [15, 1212, 6918, 318, 262, 1266],\n", 276 | " [15, 5189, 262, 1115, 816, 1124],\n", 277 | " [16, 21448, 43693, 2879, 6918, 547],\n", 278 | " [15, 4711, 8088, 326, 1624, 428],\n", 279 | " [16, 1722, 584, 8088, 423, 5081],\n", 280 | " [15, 40, 1043, 428, 281, 45421],\n", 281 | " [16, 36837, 11, 314, 2630, 326],\n", 282 | " [16, 1212, 3807, 318, 12659, 13],\n", 283 | " [16, 40, 716, 257, 22197, 18854],\n", 
284 | " [16, 40, 1053, 34310, 383, 25976],\n", 285 | " [16, 3260, 616, 718, 614, 1468],\n", 286 | " [15, 25681, 1497, 422, 428, 3807],\n", 287 | " [16, 40, 373, 18680, 12617, 351],\n", 288 | " [15, 5195, 1422, 470, 41542, 423],\n", 289 | " [16, 40, 561, 423, 8359, 428],\n", 290 | " [16, 1212, 318, 262, 2081, 1621],\n", 291 | " [16, 1212, 3807, 318, 44089, 355],\n", 292 | " [16, 40, 717, 2497, 366, 29449],\n", 293 | " [16, 11380, 11, 880, 11, 645],\n", 294 | " [15, 40, 2497, 428, 2646, 355],\n", 295 | " [16, 1212, 318, 546, 257, 8805],\n", 296 | " [15, 15496, 661, 11, 27, 1671],\n", 297 | " [16, 46, 57, 318, 281, 1468],\n", 298 | " [15, 16454, 11, 314, 655, 550],\n", 299 | " [15, 40, 1276, 910, 314, 373],\n", 300 | " [15, 16773, 319, 11, 1309, 338],\n", 301 | " [16, 40, 5257, 284, 1700, 3336],\n", 302 | " [15, 1212, 3807, 1107, 2523, 663],\n", 303 | " [15, 2215, 2077, 355, 257, 2187],\n", 304 | " [15, 8332, 663, 2479, 11, 428],\n", 305 | " [16, 3844, 314, 550, 262, 1266],\n", 306 | " [16, 40, 550, 2938, 257, 6547],\n", 307 | " [16, 2601, 7626, 36389, 6026, 344],\n", 308 | " [15, 1212, 3807, 318, 7579, 11211],\n", 309 | " [16, 464, 7157, 286, 5830, 22430],\n", 310 | " [16, 1858, 423, 587, 1811, 7328],\n", 311 | " [15, 1532, 345, 389, 3612, 286],\n", 312 | " [16, 2514, 1998, 7123, 345, 1107],\n", 313 | " [16, 40, 1807, 262, 3807, 366],\n", 314 | " [16, 1858, 389, 617, 1243, 314],\n", 315 | " [16, 42338, 11, 314, 836, 18265],\n", 316 | " [16, 9, 31895, 6509, 9, 27],\n", 317 | " [15, 464, 36947, 318, 281, 4998],\n", 318 | " [15, 40, 5543, 1842, 428, 2646],\n", 319 | " [15, 3666, 4004, 905, 286, 477],\n", 320 | " [16, 3673, 2089, 13289, 13, 5338],\n", 321 | " [15, 32, 845, 23056, 1621, 11],\n", 322 | " [15, 40817, 11, 379, 717, 11],\n", 323 | " [16, 464, 717, 4756, 3715, 326],\n", 324 | " [15, 40, 460, 1833, 703, 41921],\n", 325 | " [16, 3260, 5586, 832, 262, 4512],\n", 326 | " [16, 16454, 986, 568, 1312, 1053],\n", 327 | " [15, 40, 3221, 588, 10997, 6918],\n", 328 | " 
[15, 1212, 318, 407, 257, 922],\n", 329 | " [16, 22017, 11, 428, 318, 845],\n", 330 | " [16, 16676, 262, 1438, 7939, 43503],\n", 331 | " [16, 2061, 4325, 618, 2130, 468],\n", 332 | " [16, 2396, 30248, 340, 2125, 470],\n", 333 | " [16, 49, 27015, 7328, 1464, 423],\n", 334 | " [16, 11158, 257, 3807, 810, 262],\n", 335 | " [15, 45, 14651, 17064, 293, 33500],\n", 336 | " [15, 40, 1053, 1100, 257, 1271],\n", 337 | " [15, 1, 2437, 1675, 44927, 14213],\n", 338 | " [15, 5189, 262, 3478, 10544, 508],\n", 339 | " [15, 464, 3807, 373, 3105, 11],\n", 340 | " [15, 40, 7342, 428, 3807, 257],\n", 341 | " [15, 50, 592, 12382, 25, 366],\n", 342 | " [16, 818, 1502, 284, 2883, 705],\n", 343 | " [15, 1, 8491, 921, 287, 262],\n", 344 | " [16, 18690, 11, 523, 340, 338],\n", 345 | " [15, 5703, 1392, 428, 287, 262],\n", 346 | " [16, 1, 3856, 322, 4754, 1],\n", 347 | " [15, 1858, 389, 867, 3840, 314],\n", 348 | " [16, 40, 3088, 284, 1577, 428],\n", 349 | " [15, 986, 392, 7685, 1312, 836],\n", 350 | " [15, 19703, 357, 23865, 4687, 88],\n", 351 | " [16, 1639, 714, 3800, 257, 2196],\n", 352 | " [15, 464, 2646, 318, 1912, 319],\n", 353 | " [15, 32, 32611, 3499, 923, 11],\n", 354 | " [15, 4561, 9437, 364, 287, 428],\n", 355 | " [15, 1, 2061, 561, 345, 466],\n", 356 | " [15, 72, 892, 428, 905, 318],\n", 357 | " [16, 1212, 3807, 318, 546, 257],\n", 358 | " [15, 1558, 14, 2999, 14, 2931]])" 359 | ] 360 | }, 361 | "execution_count": 3, 362 | "metadata": {}, 363 | "output_type": "execute_result" 364 | } 365 | ], 366 | "source": [ 367 | "def get_batch_data():\n", 368 | " label = random.choices(range(2), k=128)\n", 369 | " question = random.choices(dataset, k=128)\n", 370 | " question = [i['question'] for i in question]\n", 371 | "\n", 372 | " question = [[tokenizer.convert_tokens_to_ids(str(l))] + q\n", 373 | " for l, q in zip(label, question)]\n", 374 | "\n", 375 | " return label, question\n", 376 | "\n", 377 | "\n", 378 | "get_batch_data()" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | 
"execution_count": 4, 384 | "id": "81ee37d4", 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "name": "stderr", 389 | "output_type": "stream", 390 | "text": [ 391 | "Some weights of the model checkpoint at model/lvwerra/gpt2-imdb were not used when initializing GPT2LMHeadModel: ['v_head.summary.weight', 'v_head.summary.bias']\n", 392 | "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 393 | "- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 394 | "Some weights of the model checkpoint at model/lvwerra/gpt2-imdb were not used when initializing GPT2LMHeadModel: ['v_head.summary.weight', 'v_head.summary.bias']\n", 395 | "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
class ModelPPO(torch.nn.Module):
    """Causal language model with an additional scalar value head.

    ``forward`` returns the next-token logits together with one value
    estimate per input position.
    """

    def __init__(self):
        super().__init__()
        from transformers import AutoModelForCausalLM

        self.model_gen = AutoModelForCausalLM.from_pretrained(
            'model/lvwerra/gpt2-imdb')

        # Size the value head from the checkpoint instead of hard-coding 768,
        # so a different backbone width does not break the Linear layer.
        hidden_size = self.model_gen.config.hidden_size
        self.v_head = torch.nn.Sequential(torch.nn.Dropout(0.1),
                                          torch.nn.Linear(hidden_size, 1))

        self.to(device)
        self.train()

    def forward(self, input_ids, attention_mask):
        """Return ``(logits, value)`` for a batch of token ids.

        Args:
            input_ids: token ids, indexed as (batch, position).
            attention_mask: mask forwarded unchanged to the transformer.

        Returns:
            logits: output of the LM head on the last hidden state.
            value: value-head output with the trailing size-1 dim removed.
        """
        # last_hidden_state is always returned; asking for every layer's
        # hidden states only kept extra tensors alive.
        last_hidden_state = self.model_gen.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask).last_hidden_state

        logits = self.model_gen.lm_head(last_hidden_state)
        value = self.v_head(last_hidden_state).squeeze(-1)

        return logits, value


model_ppo = ModelPPO()
model_ppo_ref = ModelPPO()

# The reference model is never trained: freeze it and switch it to eval mode.
# __init__ leaves every instance in train() mode, so without eval() dropout
# would make the reference log-probs (the KL baseline) random on every call.
for i in model_ppo_ref.parameters():
    i.requires_grad_(False)
model_ppo_ref.eval()
def get_kl(a, b, method='kl'):
    """Per-element divergence estimate between two log-probability tensors.

    Args:
        a: log-probabilities under the policy being trained.
        b: log-probabilities under the reference model, same shape as ``a``.
        method: which estimator to use:
            'kl'   - plain difference ``a - b`` (the default),
            'abs'  - absolute difference,
            'mse'  - half of the squared difference,
            'full' - ``exp(a) * (a - b)`` computed through ``kl_div``.

    Returns:
        A tensor with the same shape as ``a``.

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    if method == 'kl':
        return a - b

    if method == 'abs':
        return (a - b).abs()

    if method == 'mse':
        return (a - b).square() * 0.5

    if method == 'full':
        # kl_div(input, target) evaluates target-weighted terms, i.e.
        # exp(target) * (target - input). The policy must therefore be the
        # *target* to get the policy-vs-reference direction used by the
        # other estimators.
        return torch.nn.functional.kl_div(b,
                                          a,
                                          log_target=True,
                                          reduction='none')

    raise ValueError(f'unknown kl method: {method!r}')


class PPOTrainer:
    """Hand-written PPO update loop for ``model_ppo``.

    Relies on names defined elsewhere in the notebook: ``model_ppo``,
    ``model_ppo_ref``, ``tokenizer``, ``device`` and the ``trl.core`` helpers
    ``clip_by_value``, ``logprobs_from_logits``, ``masked_mean`` and
    ``masked_whiten``.
    """

    def __init__(self,
                 lr=1e-5,
                 epochs=4,
                 kl_coef=0.2,
                 lam=0.95,
                 clip_range=0.2,
                 clip_range_value=0.2,
                 vf_coef=0.1,
                 ratio_threshold=10):
        """Store the hyperparameters and build the optimizer.

        Args:
            lr: Adam learning rate.
            epochs: number of passes over each batch inside ``step``.
            kl_coef: weight of the KL penalty subtracted from the reward.
            lam: decay applied when accumulating deltas backwards in time.
            clip_range: half-width of the probability-ratio clipping window.
            clip_range_value: half-width of the value clipping window.
            vf_coef: weight of the value loss in the total loss.
            ratio_threshold: a sample whose mean ratio exceeds this is skipped.
        """
        self.epochs = epochs
        self.kl_coef = kl_coef
        self.lam = lam
        self.clip_range = clip_range
        self.clip_range_value = clip_range_value
        self.vf_coef = vf_coef
        self.ratio_threshold = ratio_threshold
        self.optimizer = torch.optim.Adam(model_ppo.parameters(), lr=lr)

    def step(self, question, answer, reward):
        """Run one PPO update on a batch.

        Args:
            question: list of 1-D prompt token tensors.
            answer: list of 1-D generated token tensors, aligned with
                ``question``.
            reward: one scalar reward per sample.
        """
        with torch.no_grad():
            # Concatenate prompt and answer and pad to a rectangular batch.
            token = [{
                'input_ids': q.tolist() + a.tolist()
            } for q, a in zip(question, answer)]
            token = tokenizer.pad(token, return_tensors='pt').to(device)
            input_ids = token.input_ids
            attention_mask = token.attention_mask

            # From here on only the lengths of question/answer are needed.
            lens_q = [len(q) for q in question]
            lens_a = [len(a) for a in answer]

            # Log-probs of the answer tokens and a value for every action.
            prob_log, value, mask = self.batched_forward_pass(
                model_ppo, input_ids, attention_mask, lens_q, lens_a)

            # Same log-probs under the reference model, for the KL penalty.
            prob_log_ref, _, _ = self.batched_forward_pass(
                model_ppo_ref, input_ids, attention_mask, lens_q, lens_a)

            # Fold the KL penalty and the scalar reward into per-token rewards.
            reward = self.compute_rewards(reward, prob_log, prob_log_ref, mask)

            # Deltas (advantages) and value targets used by the loss.
            value, delta, target = self.compute_advantages(value, reward, mask)

        # Several optimisation passes over the same batch.
        for _ in range(self.epochs):
            # One sample per optimizer step.
            for i in range(len(input_ids)):
                # Re-evaluate log-probs and values with gradients enabled.
                prob_log_new, value_new, _ = self.batched_forward_pass(
                    model_ppo, input_ids[i].unsqueeze(0),
                    attention_mask[i].unsqueeze(0), [lens_q[i]], [lens_a[i]])

                # Policy loss from the old/new probability ratio, value loss
                # from the gap between target and value.
                loss = self.get_loss(prob_log[i].unsqueeze(0),
                                     value[i].unsqueeze(0), prob_log_new,
                                     value_new, mask[i].unsqueeze(0),
                                     delta[i].unsqueeze(0),
                                     target[i].unsqueeze(0))

                # get_loss returns None for a skipped sample. Test for None
                # explicitly: truthiness of a tensor would also skip a loss
                # that merely happens to equal zero.
                if loss is None:
                    continue

                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

    def batched_forward_pass(self, model, input_ids, attention_mask, lens_q,
                             lens_a):
        """Return ``(prob_log, value, mask)`` for the answer tokens.

        ``mask`` is 1 exactly at the positions that predict an answer token
        and are not padding. All three results have one column fewer than
        ``input_ids``.
        """
        logits, value = model(input_ids=input_ids,
                              attention_mask=attention_mask)

        # Log-probability assigned to each actual next token.
        prob_log = logprobs_from_logits(logits[:, :-1], input_ids[:, 1:])

        # Shift the attention mask by one so it lines up with predictions,
        # then clear everything outside the answer span.
        mask = torch.zeros_like(attention_mask)
        mask[:, :-1] = attention_mask[:, 1:]
        for i in range(len(input_ids)):
            start = lens_q[i] - 1
            end = start + lens_a[i]
            mask[i, :start] = 0
            mask[i, end:] = 0

        # The prediction made at the last position has no target; drop it.
        value = value[:, :-1]
        mask = mask[:, :-1]

        return prob_log, value, mask

    def compute_rewards(self, reward, prob_log, prob_log_ref, mask):
        """Combine the KL penalty with the scalar reward.

        Each token receives ``-kl_coef * kl``; the sample's scalar reward is
        added on the last masked-in token (index 0 if the mask is empty).

        Returns:
            A tensor with the same shape as ``prob_log``.
        """
        reward_kl = []

        for i in range(len(reward)):
            kl = get_kl(prob_log[i], prob_log_ref[i]) * -self.kl_coef

            if (mask[i] == 0).all():
                idx = 0
            else:
                idx = mask[i].nonzero()[-1].item()
            kl[idx] += reward[i]

            reward_kl.append(kl)

        return torch.stack(reward_kl)

    def compute_advantages(self, value, reward_kl, mask):
        """Return ``(value, delta, target)`` for the loss.

        ``value`` is the masked value estimate, ``delta`` the whitened
        accumulated temporal-difference error and ``target`` the value
        regression target (un-whitened delta plus value).
        """
        value = value * mask
        reward_kl = reward_kl * mask

        delta = []
        lens = reward_kl.shape[1]

        # Walk backwards in time; last_d carries the accumulated delta of the
        # following step and starts at 0 for the final step.
        last_d = 0
        for i in reversed(range(lens)):
            # Value of the next step, 0 after the final step.
            value_next = value[:, i + 1] if i < lens - 1 else 0

            # value should equal gamma * value_next + reward; the mismatch is
            # the delta. gamma is 1 here and therefore omitted.
            d = reward_kl[:, i] + value_next - value[:, i]

            # lam controls how strongly later deltas flow back to earlier
            # actions.
            last_d = d + self.lam * last_d
            delta.append(last_d)

        # Restore chronological order and put the batch dimension first.
        delta = torch.stack(delta[::-1]).transpose(0, 1)

        # The target must be built from the un-whitened delta.
        target = delta + value
        delta = masked_whiten(delta, mask)

        return value, delta, target

    def get_loss(self, prob_log, value, prob_log_new, value_new, mask, delta,
                 target):
        """Clipped PPO loss for one sample.

        Returns:
            ``policy_loss + vf_coef * value_loss``, or ``None`` when the mean
            probability ratio exceeds ``ratio_threshold`` and the sample
            should be skipped.
        """
        # Difference of log-probs, exponentiated: ratio of new to old
        # probability of each taken action.
        ratio = (prob_log_new - prob_log).exp()

        # A very large ratio suggests the update is oscillating; skip it.
        if masked_mean(ratio, mask).item() > self.ratio_threshold:
            return None

        # Value loss: squared error, and the same with the new value clipped
        # to a window around the old one; keep the larger of the two.
        loss_vf1 = (value_new - target)**2
        loss_vf2 = clip_by_value(value_new, value - self.clip_range_value,
                                 value + self.clip_range_value)
        loss_vf2 = (loss_vf2 - target)**2
        loss_vf = 0.5 * masked_mean(torch.max(loss_vf1, loss_vf2), mask)

        # Policy loss: unclipped and clipped surrogate, again keeping the
        # larger (more pessimistic) one.
        loss_surr1 = -delta * ratio
        loss_surr2 = -delta * ratio.clamp(1 - self.clip_range,
                                          1 + self.clip_range)
        loss_surr = masked_mean(torch.max(loss_surr1, loss_surr2), mask)

        return loss_surr + self.vf_coef * loss_vf
def get_question():
    """Sample a batch and move it to the training device.

    Returns:
        A tuple ``(label, question)``: ``label`` is a single LongTensor with
        one entry per sample, ``question`` is a list holding one LongTensor
        of token ids per sample.
    """
    raw_label, raw_question = get_batch_data()

    question = []
    for ids in raw_question:
        question.append(torch.LongTensor(ids).to(device))

    label = torch.LongTensor(raw_label).to(device)
    return label, question
def generate(input_ids):
    """Sample up to 32 new tokens from the policy for a batch of prompts.

    Sampling is unrestricted: top-k and nucleus filtering are both disabled.
    Returns the full sequences (prompt followed by generated tokens).
    """
    return model_ppo.model_gen.generate(input_ids=input_ids,
                                        min_length=-1,
                                        top_k=0,
                                        top_p=1.0,
                                        do_sample=True,
                                        pad_token_id=tokenizer.pad_token_id,
                                        max_new_tokens=32,
                                        eos_token_id=tokenizer.eos_token_id)


def get_answer(question, batched=True):
    """Generate one answer per prompt and return only the generated tokens.

    Args:
        question: list of 1-D prompt token tensors.
        batched: when True all prompts are stacked and generated in a single
            call, which requires them to share one length. Pass False to
            generate prompt by prompt when lengths differ.

    Returns:
        A list of 1-D tensors holding the generated tokens of each sample.
        In batched mode each answer is cut right after its first EOS token,
        removing whatever the batch generation appended after it.
    """
    if not batched:
        return [generate(q.unsqueeze(0))[0, len(q):] for q in question]

    output = generate(torch.stack(question))

    answer = []
    for q, seq in zip(question, output):
        # Look for EOS in the generated part only, so an EOS id inside the
        # prompt can never truncate the answer to nothing.
        generated = seq[len(q):]
        eos_pos = (generated == tokenizer.eos_token_id).nonzero()
        if len(eos_pos) > 0:
            generated = generated[:eos_pos[0].item() + 1]
        answer.append(generated)

    return answer
def get_reward(question, answer, label):
    """Score each prompt+answer text with the sentiment classifier.

    The first prompt token (the label token) is dropped before decoding.
    The reward of a sample is the classifier logit at the index given by
    that sample's entry in ``label``.

    Returns:
        A 1-D tensor with one reward per sample.
    """
    texts = []
    for q, a in zip(question, answer):
        ids = q.tolist()[1:] + a.tolist()
        texts.append(tokenizer.decode(ids))

    encoded = tokenizer_cls(texts,
                            padding=True,
                            truncation=True,
                            max_length=512,
                            return_tensors='pt').to(device)

    with torch.no_grad():
        logits = model_cls(**encoded).logits

    # Pick, per row, the logit at the column named by label.
    picked = logits.gather(1, label.reshape(-1, 1))
    return picked.squeeze(1)
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 947 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 948 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 949 | ] 950 | }, 951 | { 952 | "name": "stdout", 953 | "output_type": "stream", 954 | "text": [ 955 | "20 0.11868445575237274\n", 956 | "0I found this very touching -> .)<|endoftext|> -1.8980457782745361\n" 957 | ] 958 | }, 959 | { 960 | "name": "stderr", 961 | "output_type": "stream", 962 | "text": [ 963 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 964 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 965 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 966 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 967 | ] 968 | }, 969 | { 970 | "name": "stdout", 971 | "output_type": "stream", 972 | "text": [ 973 | "30 0.04239548742771149\n", 974 | "0Even die hard John Wayne -> was able to play that one scene apart by himself.

The fight in the basement is a simple scene. That's the low- -0.8326716423034668\n" 975 | ] 976 | }, 977 | { 978 | "name": "stderr", 979 | "output_type": "stream", 980 | "text": [ 981 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 982 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 983 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 984 | ] 985 | }, 986 | { 987 | "name": "stdout", 988 | "output_type": "stream", 989 | "text": [ 990 | "40 0.09275282919406891\n", 991 | "0I couldn't wait to -> find it on the prairie.<|endoftext|> -0.02781713753938675\n" 992 | ] 993 | }, 994 | { 995 | "name": "stderr", 996 | "output_type": "stream", 997 | "text": [ 998 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 999 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1000 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1001 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1002 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1003 | ] 1004 | }, 1005 | { 1006 | "name": "stdout", 1007 | "output_type": "stream", 1008 | "text": [ 1009 | "50 0.38358354568481445\n", 1010 | "1The first episode set the -> structure perfectly. It made me hate it.<|endoftext|> -0.7066333889961243\n" 1011 | ] 1012 | }, 1013 | { 1014 | "name": "stderr", 1015 | "output_type": "stream", 1016 | "text": [ 1017 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1018 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1019 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1020 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1021 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1022 | ] 1023 | }, 1024 | { 1025 | "name": "stdout", 1026 | "output_type": "stream", 1027 | "text": [ 1028 | "60 1.2206898927688599\n", 1029 | "1May 1938. Hitler in -> this film is the best.<|endoftext|> 1.5236488580703735\n" 1030 | ] 1031 | }, 1032 | { 1033 | "name": "stderr", 1034 | "output_type": "stream", 1035 | "text": [ 1036 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1037 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1038 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1039 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1040 | ] 1041 | }, 1042 | { 1043 | "name": "stdout", 1044 | "output_type": "stream", 1045 | "text": [ 1046 | "70 1.730372428894043\n", 1047 | "1I disliked this film intensely -> I wanted to finish it so often I didn't know what to do with the characters we never liked, but this film was amazing!<|endoftext|> 1.5757133960723877\n" 1048 | ] 1049 | }, 1050 | { 1051 | "name": "stderr", 1052 | "output_type": "stream", 1053 | "text": [ 1054 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1055 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1056 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1057 | ] 1058 | }, 1059 | { 1060 | "name": "stdout", 1061 | "output_type": "stream", 1062 | "text": [ 1063 | "80 1.9782875776290894\n", 1064 | "0This film is a very -> bad framing.<|endoftext|> 2.4100003242492676\n" 1065 | ] 1066 | }, 1067 | { 1068 | "name": "stderr", 1069 | "output_type": "stream", 1070 | "text": [ 1071 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1072 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1073 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1074 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1075 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1076 | ] 1077 | }, 1078 | { 1079 | "name": "stdout", 1080 | "output_type": "stream", 1081 | "text": [ 1082 | "90 2.1086392402648926\n", 1083 | "1It is the early morning -> in DC the best.<|endoftext|> 2.3762409687042236\n" 1084 | ] 1085 | }, 1086 | { 1087 | "name": "stderr", 1088 | "output_type": "stream", 1089 | "text": [ 1090 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1091 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1092 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1093 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1094 | ] 1095 | }, 1096 | { 1097 | "name": "stdout", 1098 | "output_type": "stream", 1099 | "text": [ 1100 | "100 2.1353683471679688\n", 1101 | "1It hurt to watch this -> , this, but it really felt ripping it into a piece and it makes it very well done.<|endoftext|> 2.6977579593658447\n" 1102 | ] 1103 | }, 1104 | { 1105 | "name": "stderr", 1106 | "output_type": "stream", 1107 | "text": [ 1108 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1109 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1110 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1111 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1112 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1113 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1114 | ] 1115 | }, 1116 | { 1117 | "name": "stdout", 1118 | "output_type": "stream", 1119 | "text": [ 1120 | "110 2.251354455947876\n", 1121 | "1This was the first of -> such themes. 
It was great!<|endoftext|> 2.4420251846313477\n" 1122 | ] 1123 | }, 1124 | { 1125 | "name": "stderr", 1126 | "output_type": "stream", 1127 | "text": [ 1128 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1129 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1130 | ] 1131 | }, 1132 | { 1133 | "name": "stdout", 1134 | "output_type": "stream", 1135 | "text": [ 1136 | "120 2.341010093688965\n", 1137 | "1Everyone is surely familiar with -> the name and is highly appreciated.<|endoftext|> 2.4222381114959717\n" 1138 | ] 1139 | }, 1140 | { 1141 | "name": "stderr", 1142 | "output_type": "stream", 1143 | "text": [ 1144 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1145 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1146 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1147 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1148 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1149 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1150 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1151 | ] 1152 | }, 1153 | { 1154 | "name": "stdout", 1155 | "output_type": "stream", 1156 | "text": [ 1157 | "130 2.313877582550049\n", 1158 | "1Whoever saddled this piece -> is a great appreciated film.<|endoftext|> 2.3847947120666504\n" 1159 | ] 1160 | }, 1161 | { 1162 | "name": "stderr", 1163 | "output_type": "stream", 1164 | "text": [ 1165 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1166 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1167 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1168 | ] 1169 | }, 1170 | { 1171 | "name": "stdout", 1172 | "output_type": "stream", 1173 | "text": [ 1174 | "140 2.2891697883605957\n", 1175 | "1I can never fathom -> this film blue shadows. I think it is wonderful.<|endoftext|> 2.648195743560791\n" 1176 | ] 1177 | }, 1178 | { 1179 | "name": "stderr", 1180 | "output_type": "stream", 1181 | "text": [ 1182 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1183 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1184 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1185 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1186 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1187 | ] 1188 | }, 1189 | { 1190 | "name": "stdout", 1191 | "output_type": "stream", 1192 | "text": [ 1193 | "150 2.1408190727233887\n", 1194 | "0Making a film for under -> this crap!<|endoftext|> 2.4721293449401855\n" 1195 | ] 1196 | }, 1197 | { 1198 | "name": "stderr", 1199 | "output_type": "stream", 1200 | "text": [ 1201 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1202 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1203 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1204 | ] 1205 | }, 1206 | { 1207 | "name": "stdout", 1208 | "output_type": "stream", 1209 | "text": [ 1210 | "160 2.212196111679077\n", 1211 | "0Lets make a movie -> about this crap.<|endoftext|> 2.4810938835144043\n" 1212 | ] 1213 | }, 1214 | { 1215 | "name": "stderr", 1216 | "output_type": "stream", 1217 | "text": [ 1218 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1219 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1220 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1221 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1222 | ] 1223 | }, 1224 | { 1225 | "name": "stdout", 1226 | "output_type": "stream", 1227 | "text": [ 1228 | "170 2.2233328819274902\n", 1229 | "0Just saw ICE AGE -> for a total waste of time.<|endoftext|> 2.480029821395874\n" 1230 | ] 1231 | }, 1232 | { 1233 | "name": "stderr", 1234 | "output_type": "stream", 1235 | "text": [ 1236 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1237 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1238 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1239 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1240 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1241 | ] 1242 | }, 1243 | { 1244 | "name": "stdout", 1245 | "output_type": "stream", 1246 | "text": [ 1247 | "180 2.2824764251708984\n", 1248 | "0The apolitical musicians Eva -> is just awful.<|endoftext|> 2.610349655151367\n" 1249 | ] 1250 | }, 1251 | { 1252 | "name": "stderr", 1253 | "output_type": "stream", 1254 | "text": [ 1255 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1256 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1257 | ] 1258 | }, 1259 | { 1260 | "name": "stdout", 1261 | "output_type": "stream", 1262 | "text": [ 1263 | "190 2.3515257835388184\n", 1264 | "1Excellent movie, a realistic -> look of my life.<|endoftext|> 2.7795515060424805\n" 1265 | ] 1266 | }, 1267 | { 1268 | "name": "stderr", 1269 | "output_type": "stream", 1270 | "text": [ 1271 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1272 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1273 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n", 1274 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n" 1275 | ] 1276 | } 1277 | ], 1278 | "source": [ 1279 | "for epoch in range(200):\n", 1280 | " label, question = get_question()\n", 1281 | " answer = get_answer(question)\n", 1282 | " reward = get_reward(question, answer, label)\n", 1283 | "\n", 1284 | " trainer.step(question, answer, reward)\n", 1285 | "\n", 1286 | " if epoch % 10 == 0:\n", 1287 | " print(epoch, reward.mean().item())\n", 1288 | " question = tokenizer.decode(question[0].tolist())\n", 1289 | " answer = tokenizer.decode(answer[0].tolist())\n", 1290 | " reward = reward[0].item()\n", 1291 | "\n", 1292 | " #0差评,1好评\n", 1293 | " print(question, '->', answer, reward)" 1294 | ] 1295 | } 1296 | ], 1297 | "metadata": { 1298 | "kernelspec": { 1299 | "display_name": "Python [conda env:pt2]", 1300 | "language": "python", 1301 | "name": "conda-env-pt2-py" 1302 | }, 1303 | "language_info": { 1304 | "codemirror_mode": { 1305 | "name": "ipython", 1306 | "version": 3 1307 | }, 1308 | "file_extension": ".py", 1309 | "mimetype": "text/x-python", 1310 | "name": "python", 1311 | "nbconvert_exporter": "python", 1312 | "pygments_lexer": "ipython3", 1313 | "version": "3.10.13" 1314 | } 1315 | }, 1316 | "nbformat": 4, 1317 | "nbformat_minor": 5 1318 | } 1319 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 训练自然语言的LLM模型 2 | 包括基于TRL的训练,和手动训练两种实现. 3 | 4 | 训练方法包括DPO和PPO 5 | 6 | 环境信息: 7 | 8 | python=3.10 9 | 10 | torch==2.1.0(cuda) 11 | 12 | transformers==4.34.0 13 | 14 | datasets==2.14.5 15 | 16 | trl==0.7.4 17 | 18 | 视频课程:https://www.bilibili.com/video/BV1bu4y1w7Dz 19 | --------------------------------------------------------------------------------