├── .gitignore
├── 0.下载文件.ipynb
├── 1.dpo_trl训练.ipynb
├── 2.dpo_手动训练.ipynb
├── 3.ppo_trl训练.ipynb
├── 4.ppo_手动训练.ipynb
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints
--------------------------------------------------------------------------------
/0.下载文件.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "65d45fa0",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "from transformers import AutoTokenizer\n",
11 | "\n",
12 | "# Save each pretrained tokenizer under tokenizer/<name> so that it can be\n",
13 | "# loaded from local disk later.\n",
14 | "for hub_id in ('gpt2', 'lvwerra/gpt2-imdb', 'lvwerra/distilbert-imdb'):\n",
15 | "    local_dir = 'tokenizer/' + hub_id\n",
16 | "\n",
17 | "    fetched = AutoTokenizer.from_pretrained(hub_id)\n",
18 | "    fetched.save_pretrained(local_dir)"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "ba345732",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "from datasets import load_dataset\n",
29 | "\n",
30 | "# Save each dataset to local disk under dataset/<name>.\n",
31 | "for hub_id in ('imdb', 'b-mc2/sql-create-context'):\n",
32 | "    fetched = load_dataset(hub_id)\n",
33 | "    fetched.save_to_disk('dataset/' + hub_id)"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "id": "4341f356",
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "from transformers import AutoModelForCausalLM\n",
44 | "from transformers import AutoModelForSequenceClassification\n",
45 | "\n",
46 | "# (loader class, name) pairs; each model is saved under model/<name>.\n",
47 | "for loader, hub_id in (\n",
48 | "    (AutoModelForCausalLM, 'gpt2'),\n",
49 | "    (AutoModelForCausalLM, 'lvwerra/gpt2-imdb'),\n",
50 | "    (AutoModelForSequenceClassification, 'lvwerra/distilbert-imdb'),\n",
51 | "):\n",
52 | "    loader.from_pretrained(hub_id).save_pretrained('model/' + hub_id)"
53 | ]
54 | }
55 | ],
56 | "metadata": {
57 | "kernelspec": {
58 | "display_name": "Python [conda env:pt2]",
59 | "language": "python",
60 | "name": "conda-env-pt2-py"
61 | },
62 | "language_info": {
63 | "codemirror_mode": {
64 | "name": "ipython",
65 | "version": 3
66 | },
67 | "file_extension": ".py",
68 | "mimetype": "text/x-python",
69 | "name": "python",
70 | "nbconvert_exporter": "python",
71 | "pygments_lexer": "ipython3",
72 | "version": "3.10.13"
73 | }
74 | },
75 | "nbformat": 4,
76 | "nbformat_minor": 5
77 | }
78 |
--------------------------------------------------------------------------------
/1.dpo_trl训练.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "12276e41",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stderr",
11 | "output_type": "stream",
12 | "text": [
13 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14 | " from .autonotebook import tqdm as notebook_tqdm\n",
15 | "Using sep_token, but it is not set yet.\n",
16 | "Using cls_token, but it is not set yet.\n",
17 | "Using mask_token, but it is not set yet.\n"
18 | ]
19 | },
20 | {
21 | "data": {
22 | "text/plain": [
23 | "GPT2TokenizerFast(name_or_path='tokenizer/gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n",
24 | "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
25 | "}"
26 | ]
27 | },
28 | "execution_count": 1,
29 | "metadata": {},
30 | "output_type": "execute_result"
31 | }
32 | ],
33 | "source": [
34 | "from transformers import AutoTokenizer\n",
35 | "import random\n",
36 | "import torch\n",
37 | "\n",
38 | "tokenizer = AutoTokenizer.from_pretrained('tokenizer/gpt2')\n",
39 | "tokenizer.pad_token_id = 0\n",
40 | "\n",
41 | "tokenizer"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 2,
47 | "id": "58af6f4a",
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "data": {
52 | "text/plain": [
53 | "(DatasetDict({\n",
54 | " train: Dataset({\n",
55 | " features: ['prompt', 'chosen', 'rejected'],\n",
56 | " num_rows: 71745\n",
57 | " })\n",
58 | " test: Dataset({\n",
59 | " features: ['prompt', 'chosen', 'rejected'],\n",
60 | " num_rows: 200\n",
61 | " })\n",
62 | " }),\n",
63 | " {'prompt': 'context:CREATE TABLE TV_series (SHARE INTEGER) question:What is minimum and maximum share of TV series? answer:',\n",
64 | " 'chosen': 'SELECT MAX(SHARE), MIN(SHARE) FROM TV_series',\n",
65 | " 'rejected': ''})"
66 | ]
67 | },
68 | "execution_count": 2,
69 | "metadata": {},
70 | "output_type": "execute_result"
71 | }
72 | ],
73 | "source": [
74 | "from datasets import load_from_disk\n",
75 | "\n",
76 | "dataset = load_from_disk('dataset/b-mc2/sql-create-context')['train']\n",
77 | "\n",
78 | "\n",
79 | "def f(data):\n",
80 | "    # Merge the context and the question into one prompt string.\n",
81 | "    prompt = 'context:{} question:{} answer:'.format(data['context'],\n",
82 | "                                                     data['question'])\n",
83 | "    return {'question': prompt, 'answer': data['answer']}\n",
84 | "\n",
85 | "\n",
86 | "dataset = dataset.map(f, remove_columns=['context'])\n",
87 | "\n",
88 | "\n",
89 | "def f(data):  # keep prompts of 25-65 tokens with answers of 10-35 tokens\n",
90 | "    n_prompt = len(tokenizer.encode(data['question']))\n",
91 | "    n_answer = len(tokenizer.encode(data['answer']))\n",
92 | "    return 25 <= n_prompt <= 65 and 10 <= n_answer <= 35\n",
93 | "\n",
94 | "\n",
95 | "dataset = dataset.filter(f)\n",
96 | "\n",
97 | "\n",
98 | "def f(data):\n",
99 | "    # Preference-pair columns; the rejected completion is always empty.\n",
100 | "    pair = {'prompt': data['question']}\n",
101 | "    pair['chosen'] = data['answer']\n",
102 | "    pair['rejected'] = ''\n",
103 | "    return pair\n",
104 | "\n",
105 | "\n",
106 | "dataset = dataset.map(f, remove_columns=['question', 'answer'])\n",
107 | "dataset = dataset.train_test_split(test_size=200)\n",
108 | "\n",
109 | "dataset, dataset['train'][0]"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 3,
115 | "id": "88954533",
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "from transformers import AutoModelForCausalLM\n",
120 | "\n",
121 | "model_dpo = AutoModelForCausalLM.from_pretrained('model/gpt2').to('cuda')\n",
122 | "model_dpo_ref = AutoModelForCausalLM.from_pretrained('model/gpt2').to('cuda')"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 4,
128 | "id": "8dba89d3",
129 | "metadata": {},
130 | "outputs": [
131 | {
132 | "data": {
133 | "text/plain": [
134 | "'context:CREATE TABLE department (num_employees INTEGER, ranking INTEGER) question:What is the average number of employees of the departments whose rank is between 10 and 15? answer:10\\n\\nThe answer is:\\n\\nThe answer is:\\n\\nThe answer is:\\n\\nThe answer is:\\n\\nThe answer is:\\n\\nThe answer'"
135 | ]
136 | },
137 | "execution_count": 4,
138 | "metadata": {},
139 | "output_type": "execute_result"
140 | }
141 | ],
142 | "source": [
143 | "import torch\n",
144 | "import random\n",
145 | "\n",
146 | "\n",
147 | "@torch.no_grad()\n",
148 | "def generate(input_ids):\n",
149 | "    # Extend one prompt (batch row 0) by at most 35 tokens, stopping after EOS.\n",
150 | "    lens = input_ids.shape[1]\n",
151 | "    while True:\n",
152 | "        out = model_dpo(input_ids=input_ids)\n",
153 | "        # topk(1) leaves a single candidate, so the draw below always picks it.\n",
154 | "        topk = out['logits'][0, -1].topk(1)\n",
155 | "        values = topk.values.softmax(0).tolist()\n",
156 | "        indices = topk.indices.tolist()\n",
157 | "        next_word = random.choices(indices, weights=values)\n",
158 | "        # Build the new token on the prompt's own device, not a fixed 'cuda'.\n",
159 | "        next_word = torch.LongTensor(next_word).unsqueeze(0).to(input_ids.device)\n",
160 | "        input_ids = torch.cat([input_ids, next_word], dim=1)\n",
161 | "\n",
162 | "        if input_ids.shape[1] - lens >= 35:\n",
163 | "            break\n",
164 | "        if input_ids[0, -1] == tokenizer.eos_token_id:\n",
165 | "            break\n",
166 | "\n",
167 | "    return input_ids\n",
168 | "\n",
169 | "\n",
170 | "input_ids = 'context:CREATE TABLE department (num_employees INTEGER, ranking INTEGER) question:What is the average number of employees of the departments whose rank is between 10 and 15? answer:'\n",
171 | "input_ids = tokenizer.encode(input_ids, return_tensors='pt').to('cuda')\n",
172 | "\n",
173 | "out = generate(input_ids)\n",
174 | "\n",
175 | "tokenizer.decode(out[0])"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 5,
181 | "id": "16be7af3",
182 | "metadata": {
183 | "scrolled": false
184 | },
185 | "outputs": [
186 | {
187 | "name": "stderr",
188 | "output_type": "stream",
189 | "text": [
190 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n",
191 | " warnings.warn(\n",
192 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py:291: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments we have set it for you, but you should do it yourself in the future.\n",
193 | " warnings.warn(\n",
194 | "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"
195 | ]
196 | },
197 | {
198 | "data": {
199 | "text/html": [
200 | "\n",
201 | "
\n",
202 | " \n",
203 | "
\n",
204 | " [2000/2000 19:19, Epoch 0/1]\n",
205 | "
\n",
206 | " \n",
207 | " \n",
208 | " \n",
209 | " Step | \n",
210 | " Training Loss | \n",
211 | "
\n",
212 | " \n",
213 | " \n",
214 | " \n",
215 | " 500 | \n",
216 | " 0.001800 | \n",
217 | "
\n",
218 | " \n",
219 | " 1000 | \n",
220 | " 0.000400 | \n",
221 | "
\n",
222 | " \n",
223 | " 1500 | \n",
224 | " 0.000500 | \n",
225 | "
\n",
226 | " \n",
227 | " 2000 | \n",
228 | " 0.000300 | \n",
229 | "
\n",
230 | " \n",
231 | "
"
232 | ],
233 | "text/plain": [
234 | ""
235 | ]
236 | },
237 | "metadata": {},
238 | "output_type": "display_data"
239 | },
240 | {
241 | "name": "stdout",
242 | "output_type": "stream",
243 | "text": [
244 | "100\n",
245 | "context:CREATE TABLE table_25330991_3 (james_e_holmes VARCHAR, reidsville VARCHAR) question:Name the james e. holmes for erselle young answer:SELECT name FROM table_25330991_3 WHERE reidsville = \"young\"<|endoftext|>\n",
246 | "=================\n",
247 | "SELECT james_e_holmes FROM table_25330991_3 WHERE reidsville = \"Erselle Young\"\n",
248 | "=================\n",
249 | "200\n",
250 | "context:CREATE TABLE table_name_8 (weekly_rank VARCHAR, official_ratings__millions_ VARCHAR, show VARCHAR) question:Which Weekly Rank for a Live Final Show has an Official Ratings (millions) greater than 5.2? answer:SELECT weekly_rank FROM table_name_8 WHERE official_ratings__millions_ = 5.2 AND show = \"Live Final Show\" AND show = \"\n",
251 | "=================\n",
252 | "SELECT weekly_rank FROM table_name_8 WHERE official_ratings__millions_ > 5.2 AND show = \"live final\"\n",
253 | "=================\n",
254 | "300\n",
255 | "context:CREATE TABLE table_11173827_1 (english_title VARCHAR, finale VARCHAR, peak VARCHAR) question:What is the english title that has finale as 33 and peak as 42? answer:SELECT title FROM table_11173827_1 WHERE finale = 33 AND peak = 42<|endoftext|>\n",
256 | "=================\n",
257 | "SELECT english_title FROM table_11173827_1 WHERE finale = 33 AND peak = 42\n",
258 | "=================\n",
259 | "400\n",
260 | "context:CREATE TABLE table_name_80 (year INTEGER, date VARCHAR, designated_home VARCHAR) question:What year had a date of TBA where the Oakland raiders were the home team? answer:SELECT year FROM table_name_80 WHERE designated_home = \"Oakland\"<|endoftext|>\n",
261 | "=================\n",
262 | "SELECT AVG(year) FROM table_name_80 WHERE date = \"tba\" AND designated_home = \"oakland raiders\"\n",
263 | "=================\n",
264 | "500\n",
265 | "context:CREATE TABLE table_name_66 (cost___us$__ VARCHAR, open_source VARCHAR) question:What is the cost of an Open Source that is no? answer:SELECT cost___us$__ FROM table_name_66 WHERE open_source = \"Open Source\"<|endoftext|>\n",
266 | "=================\n",
267 | "SELECT cost___us$__ FROM table_name_66 WHERE open_source = \"no\"\n",
268 | "=================\n",
269 | "600\n",
270 | "context:CREATE TABLE table_name_11 (winner VARCHAR, year VARCHAR) question:What is Winner, when Year is 2013? answer:SELECT Winner FROM table_name_11 WHERE year = 2013 AND year = 2013 AND year = 2013 AND year = 2013 AND year = 2013 AND year = 2013 AND year =\n",
271 | "=================\n",
272 | "SELECT winner FROM table_name_11 WHERE year = 2013\n",
273 | "=================\n",
274 | "700\n",
275 | "context:CREATE TABLE table_name_44 (years VARCHAR, displacement VARCHAR) question:Which years have a displacement of 1816cc? answer:SELECT years FROM table_name_44 WHERE displacement = 1816cc<|endoftext|>\n",
276 | "=================\n",
277 | "SELECT years FROM table_name_44 WHERE displacement = \"1816cc\"\n",
278 | "=================\n",
279 | "800\n",
280 | "context:CREATE TABLE table_26982362_2 (original_airdate VARCHAR, production_code VARCHAR) question:The episode with production code 693-002, has how many original airdates? answer:SELECT production_code FROM table_26982362_2 WHERE production_code = 693-002<|endoftext|>\n",
281 | "=================\n",
282 | "SELECT COUNT(original_airdate) FROM table_26982362_2 WHERE production_code = \"693-002\"\n",
283 | "=================\n",
284 | "900\n",
285 | "context:CREATE TABLE table_name_24 (games INTEGER, term_ VARCHAR, c_ VARCHAR) question:What is the of games when for the term [c] of 1969 – 1973? answer:SELECT games FROM table_name_24 WHERE term_ = \"1969\"<|endoftext|>\n",
286 | "=================\n",
287 | "SELECT SUM(games) FROM table_name_24 WHERE term_[c_] = \"1969 – 1973\"\n",
288 | "=================\n",
289 | "1000\n",
290 | "context:CREATE TABLE table_18498743_1 (october_20 VARCHAR, _2008 VARCHAR, mexico VARCHAR) question:what is the october 20, 2008 stat where mexico stat is romania answer:SELECT date FROM table_18498743_1 WHERE date = \"2008\"<|endoftext|>\n",
291 | "=================\n",
292 | "SELECT october_20, _2008 FROM table_18498743_1 WHERE mexico = \"Romania\"\n",
293 | "=================\n",
294 | "1100\n",
295 | "context:CREATE TABLE table_3005915_3 (starts INTEGER) question:What is the maximum number of starts? answer:SELECT start FROM table_3005915_3 WHERE start < 3005915_3<|endoftext|>\n",
296 | "=================\n",
297 | "SELECT MAX(starts) FROM table_3005915_3\n",
298 | "=================\n",
299 | "1200\n",
300 | "context:CREATE TABLE table_name_17 (date VARCHAR, venue VARCHAR) question:What was the date of the game at Lake Oval? answer:SELECT date FROM table_name_17 WHERE venue = \"Lake Oval\"<|endoftext|>\n",
301 | "=================\n",
302 | "SELECT date FROM table_name_17 WHERE venue = \"lake oval\"\n",
303 | "=================\n",
304 | "1300\n",
305 | "context:CREATE TABLE table_name_21 (player VARCHAR, total VARCHAR, year_s__won VARCHAR) question:Who won in 1988 with a total less than 287? answer:SELECT player FROM table_name_21 WHERE total = 287<|endoftext|>\n",
306 | "=================\n",
307 | "SELECT player FROM table_name_21 WHERE total < 287 AND year_s__won = \"1988\"\n",
308 | "=================\n",
309 | "1400\n",
310 | "context:CREATE TABLE table_23224961_1 (oberbayern_b VARCHAR, oberpfalz VARCHAR) question:When fc schwandorf is the oberpfalz what is the oberbayern b? answer:SELECT fc, oberpfalz FROM table_23224961_1 WHERE oberpfalz = \"schwandorf\"<|endoftext|>\n",
311 | "=================\n",
312 | "SELECT oberbayern_b FROM table_23224961_1 WHERE oberpfalz = \"FC Schwandorf\"\n",
313 | "=================\n",
314 | "1500\n",
315 | "context:CREATE TABLE table_name_88 (lane INTEGER, name VARCHAR) question:What is the average lane that is called rebecca brown? answer:SELECT MAX(lane) FROM table_name_88 WHERE name = \"rebecca brown\"<|endoftext|>\n",
316 | "=================\n",
317 | "SELECT AVG(lane) FROM table_name_88 WHERE name = \"rebecca brown\"\n",
318 | "=================\n",
319 | "1600\n",
320 | "context:CREATE TABLE table_name_24 (games INTEGER, term_ VARCHAR, c_ VARCHAR) question:What is the of games when for the term [c] of 1969 – 1973? answer:SELECT games FROM table_name_24 WHERE term_ = \"1969\"<|endoftext|>\n",
321 | "=================\n",
322 | "SELECT SUM(games) FROM table_name_24 WHERE term_[c_] = \"1969 – 1973\"\n",
323 | "=================\n",
324 | "1700\n",
325 | "context:CREATE TABLE table_name_87 (type VARCHAR, location VARCHAR) question:What type of Bridge is in Stanley? answer:SELECT type FROM table_name_87 WHERE location = \"Stanley\"<|endoftext|>\n",
326 | "=================\n",
327 | "SELECT type FROM table_name_87 WHERE location = \"stanley\"\n",
328 | "=================\n",
329 | "1800\n",
330 | "context:CREATE TABLE table_name_34 (attendance VARCHAR, date VARCHAR, week VARCHAR) question:What was the attendance of the game on December 13, 1970? answer:SELECT attendance FROM table_name_34 WHERE date = \"2013-12-13T00:00:00Z\"<|endoftext|>\n",
331 | "=================\n",
332 | "SELECT COUNT(attendance) FROM table_name_34 WHERE date = \"december 13, 1970\" AND week > 13\n",
333 | "=================\n",
334 | "1900\n",
335 | "context:CREATE TABLE constructors (nationality VARCHAR) question:What are the numbers of constructors for different nationalities? answer:SELECT nationality FROM constructors WHERE nationality = \"nationality\" AND country = \"nationality\"<|endoftext|>\n",
336 | "=================\n",
337 | "SELECT COUNT(*), nationality FROM constructors GROUP BY nationality\n",
338 | "=================\n",
339 | "2000\n",
340 | "context:CREATE TABLE table_name_93 (home_team VARCHAR, venue VARCHAR) question:Which home team played at MCG? answer:SELECT venue FROM table_name_93 WHERE venue = \"MCG\"<|endoftext|>\n",
341 | "=================\n",
342 | "SELECT home_team FROM table_name_93 WHERE venue = \"mcg\"\n",
343 | "=================\n"
344 | ]
345 | },
346 | {
347 | "data": {
348 | "text/plain": [
349 | "TrainOutput(global_step=2000, training_loss=0.0007831706330180168, metrics={'train_runtime': 1160.3168, 'train_samples_per_second': 27.579, 'train_steps_per_second': 1.724, 'total_flos': 0.0, 'train_loss': 0.0007831706330180168, 'epoch': 0.45})"
350 | ]
351 | },
352 | "execution_count": 5,
353 | "metadata": {},
354 | "output_type": "execute_result"
355 | }
356 | ],
357 | "source": [
358 | "from transformers import TrainingArguments, TrainerCallback\n",
359 | "from trl import DPOTrainer\n",
360 | "import random\n",
361 | "\n",
362 | "args = TrainingArguments(output_dir='output_dir',\n",
363 | "                         per_device_train_batch_size=16,\n",
364 | "                         max_steps=2000,\n",
365 | "                         learning_rate=1e-5,\n",
366 | "                         optim='rmsprop',\n",
367 | "                         evaluation_strategy='no',\n",
368 | "                         save_strategy='no',\n",
369 | "                         report_to='none')\n",
370 | "\n",
371 | "\n",
372 | "class MyCallback(TrainerCallback):\n",
373 | "    # Every 100 global steps, print a generated completion for a random test prompt.\n",
374 | "    def on_step_end(self, args, state, control, **kwargs):\n",
375 | "        step = state.global_step\n",
376 | "        if step % 100 != 0:\n",
377 | "            return\n",
378 | "\n",
379 | "        print(step)\n",
380 | "        sample = random.choice(dataset['test'])\n",
381 | "        prompt_ids = tokenizer.encode(sample['prompt'], return_tensors='pt')\n",
382 | "        completion = generate(prompt_ids.to('cuda'))\n",
383 | "\n",
384 | "        # Model output first, then the chosen answer, each followed by a rule.\n",
385 | "        rule = '================='\n",
386 | "        for text in (tokenizer.decode(completion[0]), sample['chosen']):\n",
387 | "            print(text, rule, sep='\\n')\n",
388 | "\n",
389 | "\n",
390 | "trainer = DPOTrainer(\n",
391 | "    model_dpo,\n",
392 | "    model_dpo_ref,\n",
393 | "    args=args,\n",
394 | "    beta=0.1,\n",
395 | "    train_dataset=dataset['train'],\n",
396 | "    tokenizer=tokenizer,\n",
397 | "    max_length=100, max_prompt_length=100, max_target_length=100,\n",
398 | "    callbacks=[MyCallback()],\n",
399 | ")\n",
400 | "\n",
401 | "trainer.train()"
402 | ]
403 | }
404 | ],
405 | "metadata": {
406 | "kernelspec": {
407 | "display_name": "Python [conda env:pt2]",
408 | "language": "python",
409 | "name": "conda-env-pt2-py"
410 | },
411 | "language_info": {
412 | "codemirror_mode": {
413 | "name": "ipython",
414 | "version": 3
415 | },
416 | "file_extension": ".py",
417 | "mimetype": "text/x-python",
418 | "name": "python",
419 | "nbconvert_exporter": "python",
420 | "pygments_lexer": "ipython3",
421 | "version": "3.10.13"
422 | }
423 | },
424 | "nbformat": 4,
425 | "nbformat_minor": 5
426 | }
427 |
--------------------------------------------------------------------------------
/2.dpo_手动训练.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "03697c23",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stderr",
11 | "output_type": "stream",
12 | "text": [
13 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14 | " from .autonotebook import tqdm as notebook_tqdm\n",
15 | "Using sep_token, but it is not set yet.\n",
16 | "Using cls_token, but it is not set yet.\n",
17 | "Using mask_token, but it is not set yet.\n"
18 | ]
19 | },
20 | {
21 | "data": {
22 | "text/plain": [
23 | "GPT2TokenizerFast(name_or_path='tokenizer/gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n",
24 | "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
25 | "}"
26 | ]
27 | },
28 | "execution_count": 1,
29 | "metadata": {},
30 | "output_type": "execute_result"
31 | }
32 | ],
33 | "source": [
34 | "from transformers import AutoTokenizer\n",
35 | "import random\n",
36 | "import torch\n",
37 | "\n",
38 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
39 | "\n",
40 | "tokenizer = AutoTokenizer.from_pretrained('tokenizer/gpt2')\n",
41 | "tokenizer.pad_token_id = 0\n",
42 | "\n",
43 | "tokenizer"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 2,
49 | "id": "733f6090",
50 | "metadata": {
51 | "scrolled": true
52 | },
53 | "outputs": [
54 | {
55 | "data": {
56 | "text/plain": [
57 | "(DatasetDict({\n",
58 | " train: Dataset({\n",
59 | " features: ['question', 'answer'],\n",
60 | " num_rows: 71745\n",
61 | " })\n",
62 | " test: Dataset({\n",
63 | " features: ['question', 'answer'],\n",
64 | " num_rows: 200\n",
65 | " })\n",
66 | " }),\n",
67 | " {'question': [22866,\n",
68 | " 25,\n",
69 | " 43387,\n",
70 | " 6158,\n",
71 | " 43679,\n",
72 | " 3084,\n",
73 | " 62,\n",
74 | " 3672,\n",
75 | " 62,\n",
76 | " 2414,\n",
77 | " 357,\n",
78 | " 354,\n",
79 | " 20297,\n",
80 | " 569,\n",
81 | " 31315,\n",
82 | " 1503,\n",
83 | " 11,\n",
84 | " 614,\n",
85 | " 17828,\n",
86 | " 7156,\n",
87 | " 1137,\n",
88 | " 8,\n",
89 | " 1808,\n",
90 | " 25,\n",
91 | " 2061,\n",
92 | " 318,\n",
93 | " 262,\n",
94 | " 24587,\n",
95 | " 706,\n",
96 | " 10249,\n",
97 | " 30,\n",
98 | " 3280,\n",
99 | " 25],\n",
100 | " 'answer': [46506,\n",
101 | " 24587,\n",
102 | " 16034,\n",
103 | " 3084,\n",
104 | " 62,\n",
105 | " 3672,\n",
106 | " 62,\n",
107 | " 2414,\n",
108 | " 33411,\n",
109 | " 614,\n",
110 | " 1875,\n",
111 | " 10249]})"
112 | ]
113 | },
114 | "execution_count": 2,
115 | "metadata": {},
116 | "output_type": "execute_result"
117 | }
118 | ],
119 | "source": [
120 | "from datasets import load_from_disk\n",
121 | "\n",
122 | "dataset = load_from_disk('dataset/b-mc2/sql-create-context')['train']\n",
123 | "\n",
124 | "\n",
125 | "def f(data):\n",
126 | "    # Build the prompt text, then store prompt and answer as token id lists.\n",
127 | "    prompt = 'context:{} question:{} answer:'.format(data['context'],\n",
128 | "                                                     data['question'])\n",
129 | "\n",
130 | "    def to_ids(text):\n",
131 | "        return tokenizer.encode(text, add_special_tokens=False)\n",
132 | "\n",
133 | "    return {'question': to_ids(prompt), 'answer': to_ids(data['answer'])}\n",
134 | "\n",
135 | "\n",
136 | "dataset = dataset.map(f, remove_columns=['context'])\n",
137 | "\n",
138 | "\n",
139 | "def f(data):\n",
140 | "    # Keep rows with a 25-65 token prompt and a 10-35 token answer.\n",
141 | "    n_prompt, n_answer = len(data['question']), len(data['answer'])\n",
142 | "    return 25 <= n_prompt <= 65 and 10 <= n_answer <= 35\n",
143 | "\n",
144 | "\n",
145 | "# Drop out-of-range rows, then hold out 200 rows as the test split.\n",
146 | "dataset = dataset.filter(f).train_test_split(test_size=200)\n",
147 | "\n",
148 | "# Show the splits and the first training row.\n",
149 | "dataset, dataset['train'][0]"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 3,
155 | "id": "0e940d0b",
156 | "metadata": {
157 | "scrolled": true
158 | },
159 | "outputs": [
160 | {
161 | "data": {
162 | "text/plain": [
163 | "({'input_ids': tensor([[22866, 25, 43387, ..., 0, 0, 0],\n",
164 | " [22866, 25, 43387, ..., 0, 0, 0],\n",
165 | " [22866, 25, 43387, ..., 0, 0, 0],\n",
166 | " ...,\n",
167 | " [22866, 25, 43387, ..., 0, 0, 0],\n",
168 | " [22866, 25, 43387, ..., 0, 0, 0],\n",
169 | " [22866, 25, 43387, ..., 0, 0, 0]], device='cuda:0'),\n",
170 | " 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n",
171 | " [1, 1, 1, ..., 0, 0, 0],\n",
172 | " [1, 1, 1, ..., 0, 0, 0],\n",
173 | " ...,\n",
174 | " [1, 1, 1, ..., 0, 0, 0],\n",
175 | " [1, 1, 1, ..., 0, 0, 0],\n",
176 | " [1, 1, 1, ..., 0, 0, 0]], device='cuda:0'),\n",
177 | " 'label': tensor([[-100, -100, -100, ..., -100, -100, -100],\n",
178 | " [-100, -100, -100, ..., -100, -100, -100],\n",
179 | " [-100, -100, -100, ..., -100, -100, -100],\n",
180 | " ...,\n",
181 | " [-100, -100, -100, ..., -100, -100, -100],\n",
182 | " [-100, -100, -100, ..., -100, -100, -100],\n",
183 | " [-100, -100, -100, ..., -100, -100, -100]], device='cuda:0')},\n",
184 | " {'input_ids': tensor([[22866, 25, 43387, ..., 0, 0, 0],\n",
185 | " [22866, 25, 43387, ..., 0, 0, 0],\n",
186 | " [22866, 25, 43387, ..., 0, 0, 0],\n",
187 | " ...,\n",
188 | " [22866, 25, 43387, ..., 0, 0, 0],\n",
189 | " [22866, 25, 43387, ..., 0, 0, 0],\n",
190 | " [22866, 25, 43387, ..., 0, 0, 0]], device='cuda:0'),\n",
191 | " 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n",
192 | " [1, 1, 1, ..., 0, 0, 0],\n",
193 | " [1, 1, 1, ..., 0, 0, 0],\n",
194 | " ...,\n",
195 | " [1, 1, 1, ..., 0, 0, 0],\n",
196 | " [1, 1, 1, ..., 0, 0, 0],\n",
197 | " [1, 1, 1, ..., 0, 0, 0]], device='cuda:0'),\n",
198 | " 'label': tensor([[-100, -100, -100, ..., -100, -100, -100],\n",
199 | " [-100, -100, -100, ..., -100, -100, -100],\n",
200 | " [-100, -100, -100, ..., -100, -100, -100],\n",
201 | " ...,\n",
202 | " [-100, -100, -100, ..., -100, -100, -100],\n",
203 | " [-100, -100, -100, ..., -100, -100, -100],\n",
204 | " [-100, -100, -100, ..., -100, -100, -100]], device='cuda:0')})"
205 | ]
206 | },
207 | "execution_count": 3,
208 | "metadata": {},
209 | "output_type": "execute_result"
210 | }
211 | ],
212 | "source": [
213 | "def get_batch_data():\n",
214 | "    # Draw 16 random training rows; return (chosen, rejected) padded batches.\n",
215 | "    def pad(data, split, lens):\n",
216 | "        # Start from a canvas filled with the pad id.\n",
217 | "        input_ids = torch.full((len(data), lens),\n",
218 | "                               tokenizer.pad_token_id,\n",
219 | "                               device=device)\n",
220 | "\n",
221 | "        # Track real positions by length rather than by value: the pad id is\n",
222 | "        # also an ordinary token id and may occur inside a sequence.\n",
223 | "        attention_mask = torch.zeros_like(input_ids)\n",
224 | "        for i, d in enumerate(data):\n",
225 | "            input_ids[i, :len(d)] = torch.LongTensor(d)\n",
226 | "            attention_mask[i, :len(d)] = 1\n",
227 | "\n",
228 | "        # Labels are -100 on padding and on the question part.\n",
229 | "        label = input_ids.clone()\n",
230 | "        label[attention_mask == 0] = -100\n",
231 | "        for l, s in zip(label, split):\n",
232 | "            l[:s] = -100\n",
233 | "\n",
234 | "        return {\n",
235 | "            'input_ids': input_ids,\n",
236 | "            'attention_mask': attention_mask,\n",
237 | "            'label': label\n",
238 | "        }\n",
239 | "\n",
240 | "    sample = random.choices(dataset['train'], k=16)\n",
241 | "    question = [i['question'] for i in sample]\n",
242 | "    answer = [i['answer'] for i in sample]\n",
243 | "    split = [len(i) for i in question]\n",
244 | "    eos = [tokenizer.eos_token_id]\n",
245 | "\n",
246 | "    # Chosen sequence: question + answer + EOS.\n",
247 | "    choice = [q + a + eos for q, a in zip(question, answer)]\n",
248 | "\n",
249 | "    # Rejected sequence: question + EOS, i.e. an empty answer.\n",
250 | "    reject = [q + eos for q in question]\n",
251 | "\n",
252 | "    # Pad both batches to the longest chosen sequence; a rejected sequence\n",
253 | "    # is never longer than its chosen counterpart.\n",
254 | "    lens = max(len(i) for i in choice)\n",
255 | "\n",
256 | "    return pad(choice, split, lens), pad(reject, split, lens)\n",
257 | "\n",
258 | "\n",
259 | "get_batch_data()"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 4,
265 | "id": "98c47306",
266 | "metadata": {},
267 | "outputs": [],
268 | "source": [
269 | "class ModelDPO(torch.nn.Module):\n",
270 | "    # Wraps the local GPT-2 checkpoint; forward() returns the lm_head output.\n",
271 | "    def __init__(self):\n",
272 | "        from transformers import AutoModelForCausalLM\n",
273 | "\n",
274 | "        super().__init__()\n",
275 | "        self.model = AutoModelForCausalLM.from_pretrained('model/gpt2')\n",
276 | "\n",
277 | "        # Move to the notebook device and switch to training mode.\n",
278 | "        self.to(device).train()\n",
279 | "\n",
280 | "    def forward(self, input_ids, attention_mask):\n",
281 | "        hidden = self.model.transformer(input_ids=input_ids,\n",
282 | "                                        attention_mask=attention_mask)\n",
283 | "        hidden = hidden.last_hidden_state\n",
284 | "\n",
285 | "        return self.model.lm_head(hidden)\n",
286 | "\n",
287 | "\n",
288 | "model_dpo = ModelDPO()\n",
289 | "model_dpo_ref = ModelDPO()"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": 5,
295 | "id": "e0e48bda",
296 | "metadata": {},
297 | "outputs": [
298 | {
299 | "data": {
300 | "text/plain": [
301 | "'context:CREATE TABLE table_name_8 (pos VARCHAR, date_from VARCHAR) question:What is Pos., when Date From is \"28 August 2008\"? answer:: \"What is, when is a table?\"\\n\\nThe table_name_name__name_8 is a string that contains the name of the table name.'"
302 | ]
303 | },
304 | "execution_count": 5,
305 | "metadata": {},
306 | "output_type": "execute_result"
307 | }
308 | ],
309 | "source": [
310 | "@torch.no_grad()\n",
311 | "def generate(input_ids):\n",
312 | "    # Extend one prompt (batch row 0) by at most 35 tokens, stopping after EOS.\n",
313 | "    lens = input_ids.shape[1]\n",
314 | "    while True:\n",
315 | "        out = model_dpo(input_ids=input_ids,\n",
316 | "                        attention_mask=torch.ones_like(input_ids))\n",
317 | "        # topk(1) leaves a single candidate, so the draw below always picks it.\n",
318 | "        topk = out[0, -1].topk(1)\n",
319 | "        values = topk.values.softmax(0).tolist()\n",
320 | "        indices = topk.indices.tolist()\n",
321 | "        next_word = random.choices(indices, weights=values)\n",
322 | "        # Use the notebook-wide device instead of a hard-coded 'cuda'.\n",
323 | "        next_word = torch.LongTensor(next_word).unsqueeze(0).to(device)\n",
324 | "        input_ids = torch.cat([input_ids, next_word], dim=1)\n",
325 | "\n",
326 | "        if input_ids.shape[1] - lens >= 35:\n",
327 | "            break\n",
328 | "        if input_ids[0, -1] == tokenizer.eos_token_id:\n",
329 | "            break\n",
330 | "\n",
331 | "    return input_ids\n",
332 | "\n",
333 | "\n",
334 | "input_ids = dataset['test'][0]['question']\n",
335 | "input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)\n",
336 | "\n",
337 | "out = generate(input_ids)\n",
338 | "\n",
339 | "tokenizer.decode(out[0])"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": 6,
345 | "id": "c35e3b28",
346 | "metadata": {},
347 | "outputs": [
348 | {
349 | "data": {
350 | "text/plain": [
351 | "tensor([-58.1640, -61.3286, -39.4775, -71.7320, -62.9812, -67.9002, -35.7873,\n",
352 | " -55.2316, -55.5689, -73.9519, -81.4958, -37.8672, -55.4771, -86.0505,\n",
353 | " -48.4158, -43.4941], device='cuda:0', grad_fn=)"
354 | ]
355 | },
356 | "execution_count": 6,
357 | "metadata": {},
358 | "output_type": "execute_result"
359 | }
360 | ],
361 | "source": [
362 | "def get_prob_log(model, choice, reject):\n",
363 | "    # Per row: summed answer log-prob of the chosen minus the rejected sequence.\n",
364 | "    b = choice['input_ids'].shape[0]\n",
365 | "    # Stack both halves so that a single forward pass scores them together.\n",
366 | "    # [2b, seq]\n",
367 | "    input_ids = torch.cat([choice['input_ids'], reject['input_ids']], dim=0)\n",
368 | "    attention_mask = torch.cat(\n",
369 | "        [choice['attention_mask'], reject['attention_mask']], dim=0)\n",
370 | "    label = torch.cat([choice['label'], reject['label']], dim=0)\n",
371 | "\n",
372 | "    # [2b, seq, vocab]\n",
373 | "    out = model(input_ids=input_ids, attention_mask=attention_mask)\n",
374 | "\n",
375 | "    # Shift by one so that position t is scored against token t + 1.\n",
376 | "    # [2b, seq - 1]\n",
377 | "    label = label[:, 1:]\n",
378 | "    # [2b, seq - 1, vocab]\n",
379 | "    out = out[:, :-1]\n",
380 | "\n",
381 | "    # Log-probabilities via log_softmax: no epsilon floor, numerically stable.\n",
382 | "    out = out.log_softmax(2)\n",
383 | "\n",
384 | "    # Pick the log-probability assigned to each label token.\n",
385 | "    # gather() cannot take negative indices, so -100 is mapped to 0 here.\n",
386 | "    index = label.clone().unsqueeze(2)\n",
387 | "    index[index == -100] = 0\n",
388 | "    prob = out.gather(2, index=index).squeeze(2)\n",
389 | "\n",
390 | "    # Zero out the masked (-100) positions, then sum the rest per sequence.\n",
391 | "    prob = (prob * (label != -100)).sum(1)\n",
392 | "\n",
393 | "    # Difference between the chosen half and the rejected half.\n",
394 | "    return prob[:b] - prob[b:]\n",
395 | "\n",
396 | "\n",
397 | "get_prob_log(model_dpo, *get_batch_data())"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": 7,
403 | "id": "fb049323",
404 | "metadata": {
405 | "scrolled": false
406 | },
407 | "outputs": [
408 | {
409 | "name": "stdout",
410 | "output_type": "stream",
411 | "text": [
412 | "0 context:CREATE TABLE table_name_50 (character_name VARCHAR, voice_actor__english_1998___pioneer_ VARCHAR) question:what character did Laara sadiq play answer:what character did Laara play answer:what character did Laara sadiq play answer:what character did Laara sadiq play answer:what character did Laara play answer\n",
413 | "=========\n",
414 | "SELECT character_name FROM table_name_50 WHERE voice_actor__english_1998___pioneer_ = \"laara sadiq\"\n",
415 | "=========\n",
416 | "100 context:CREATE TABLE table_name_94 (position VARCHAR, pick VARCHAR) question:What is pick 246's position? answer:SELECT position FROM table_name_94 WHERE position = \"pick 246\"<|endoftext|>\n",
417 | "=========\n",
418 | "SELECT position FROM table_name_94 WHERE pick = 246\n",
419 | "=========\n",
420 | "200 context:CREATE TABLE table_name_78 (division_record VARCHAR, school VARCHAR) question:What is the division record for Woodbridge? answer:SELECT division_record FROM table_name_78 WHERE school = \"woodbridge\"<|endoftext|>\n",
421 | "=========\n",
422 | "SELECT division_record FROM table_name_78 WHERE school = \"woodbridge\"\n",
423 | "=========\n",
424 | "300 context:CREATE TABLE table_name_20 (name VARCHAR, nat VARCHAR) question:What's the name of MKD? answer:SELECT name FROM table_name_20 WHERE nat = \"MKD\"<|endoftext|>\n",
425 | "=========\n",
426 | "SELECT name FROM table_name_20 WHERE nat = \"mkd\"\n",
427 | "=========\n",
428 | "400 context:CREATE TABLE table_21313327_1 (written_by VARCHAR, no_in_season VARCHAR) question:In season number 3, who were the writers? answer:SELECT written_by FROM table_21313327_1 WHERE no_in_season = 3<|endoftext|>\n",
429 | "=========\n",
430 | "SELECT written_by FROM table_21313327_1 WHERE no_in_season = 3\n",
431 | "=========\n",
432 | "500 context:CREATE TABLE table_name_95 (grid INTEGER, time_retired VARCHAR, laps VARCHAR) question:what is the highest grid when the time/retired is fuel pump and the laps is more than 26? answer:SELECT MAX(grid) FROM table_name_95 WHERE time_retired = \"pump and laps\" AND laps > 26<|endoftext|>\n",
433 | "=========\n",
434 | "SELECT MAX(grid) FROM table_name_95 WHERE time_retired = \"fuel pump\" AND laps > 26\n",
435 | "=========\n",
436 | "600 context:CREATE TABLE table_name_31 (top_10 INTEGER, top_25 VARCHAR, cuts_made VARCHAR) question:What is the average number of top-10s for the major with 2 top-25s and fewer than 10 cuts made? answer:SELECT AVG(top_10) FROM table_name_31 WHERE cuts_made = 2 AND top_25 < 10<|endoftext|>\n",
437 | "=========\n",
438 | "SELECT AVG(top_10) FROM table_name_31 WHERE top_25 = 2 AND cuts_made < 10\n",
439 | "=========\n",
440 | "700 context:CREATE TABLE table_name_94 (position VARCHAR, pick VARCHAR) question:What is pick 246's position? answer:SELECT position FROM table_name_94 WHERE pick = \"246\"<|endoftext|>\n",
441 | "=========\n",
442 | "SELECT position FROM table_name_94 WHERE pick = 246\n",
443 | "=========\n",
444 | "800 context:CREATE TABLE table_17358515_1 (points_2 VARCHAR, team VARCHAR) question:How many points did Goole Town accumulate? answer:SELECT points_2 FROM table_17358515_1 WHERE team = \"Goole Town\"<|endoftext|>\n",
445 | "=========\n",
446 | "SELECT COUNT(points_2) FROM table_17358515_1 WHERE team = \"Goole Town\"\n",
447 | "=========\n",
448 | "900 context:CREATE TABLE table_11677691_9 (hometown VARCHAR, college VARCHAR) question:Which hometown has the college Louisiana State? answer:SELECT hometown FROM table_11677691_9 WHERE college = \"Louisiana State\"<|endoftext|>\n",
449 | "=========\n",
450 | "SELECT hometown FROM table_11677691_9 WHERE college = \"Louisiana State\"\n",
451 | "=========\n",
452 | "1000 context:CREATE TABLE table_1341423_40 (candidates VARCHAR, incumbent VARCHAR) question:who are the candidates when the incumbent is lindsey graham? answer:SELECT candidates FROM table_1341423_40 WHERE incumbent = \"Lindsey Graham\"<|endoftext|>\n",
453 | "=========\n",
454 | "SELECT candidates FROM table_1341423_40 WHERE incumbent = \"Lindsey Graham\"\n",
455 | "=========\n",
456 | "1100 context:CREATE TABLE table_name_77 (country VARCHAR, series_premiere VARCHAR) question:Which country had a series that premiered on September 4, 2006? answer:SELECT country FROM table_name_77 WHERE series_premiere = \"september 4, 2006\"<|endoftext|>\n",
457 | "=========\n",
458 | "SELECT country FROM table_name_77 WHERE series_premiere = \"september 4, 2006\"\n",
459 | "=========\n",
460 | "1200 context:CREATE TABLE table_name_35 (director VARCHAR, title VARCHAR) question:Who is the director of Antz? answer:SELECT director FROM table_name_35 WHERE title = \"antz\"<|endoftext|>\n",
461 | "=========\n",
462 | "SELECT director FROM table_name_35 WHERE title = \"antz\"\n",
463 | "=========\n",
464 | "1300 context:CREATE TABLE table_name_55 (role VARCHAR, direction VARCHAR) question:Which role had thulasidas direction? answer:SELECT role FROM table_name_55 WHERE direction = \"thulasidas direction\"<|endoftext|>\n",
465 | "=========\n",
466 | "SELECT role FROM table_name_55 WHERE direction = \"thulasidas\"\n",
467 | "=========\n",
468 | "1400 context:CREATE TABLE table_24223834_3 (us_viewers__in_millions_ VARCHAR, directed_by VARCHAR) question:How many million U.S. viewers watched the episode that Daniel H. Forer directed? answer:SELECT COUNT(us_viewers__in_millions_) FROM table_24223834_3 WHERE directed_by = \"Daniel H. Forer\n",
469 | "=========\n",
470 | "SELECT us_viewers__in_millions_ FROM table_24223834_3 WHERE directed_by = \"Daniel H. Forer\"\n",
471 | "=========\n",
472 | "1500 context:CREATE TABLE table_name_89 (score VARCHAR, player VARCHAR) question:What is the score for Jock Hutchison? answer:SELECT score FROM table_name_89 WHERE player = \"jock hutch jock hutchison\"<|endoftext|>\n",
473 | "=========\n",
474 | "SELECT score FROM table_name_89 WHERE player = \"jock hutchison\"\n",
475 | "=========\n",
476 | "1600 context:CREATE TABLE table_24018430_3 (original_air_date VARCHAR, production_code VARCHAR) question:what is the original air date for production code 216? answer:SELECT original_air_date FROM table_24018430_3 WHERE production_code = \"216\"<|endoftext|>\n",
477 | "=========\n",
478 | "SELECT original_air_date FROM table_24018430_3 WHERE production_code = 216\n",
479 | "=========\n",
480 | "1700 context:CREATE TABLE table_name_32 (west VARCHAR, east VARCHAR) question:Who was in the West when ESV Gebensbach was in the East? answer:SELECT west FROM table_name_32 WHERE east = \"ebvesbach\"<|endoftext|>\n",
481 | "=========\n",
482 | "SELECT west FROM table_name_32 WHERE east = \"esv gebensbach\"\n",
483 | "=========\n",
484 | "1800 context:CREATE TABLE table_name_98 (to_par VARCHAR, margin_of_victory VARCHAR, tournament VARCHAR) question:What is To par, when Margin of Victory is \"2 Strokes\", and when Tournament is \"Women's British Open\"? answer:SELECT to_par FROM table_name_98 WHERE margin_of_victory = \"2 strokes\" AND tournament = \"women's British Open\"<|endoftext|>\n",
485 | "=========\n",
486 | "SELECT to_par FROM table_name_98 WHERE margin_of_victory = \"2 strokes\" AND tournament = \"women's british open\"\n",
487 | "=========\n",
488 | "1900 context:CREATE TABLE table_name_71 (wheels INTEGER, type VARCHAR, location VARCHAR) question:What average wheels has accounting as the type, with IBM Collection as the location? answer:SELECT AVG(wheels) FROM table_name_71 WHERE type = \"audit\" AND location = \"ibm collection\"<|endoftext|>\n",
489 | "=========\n",
490 | "SELECT AVG(wheels) FROM table_name_71 WHERE type = \"accounting\" AND location = \"ibm collection\"\n",
491 | "=========\n"
492 | ]
493 | }
494 | ],
495 | "source": [
496 | "optimizer = torch.optim.Adam(model_dpo.parameters(),\n",
497 | "                             lr=1e-5,\n",
498 | "                             betas=(0.9, 0.999),\n",
499 | "                             eps=1e-8)\n",
500 | "\n",
501 | "for i in range(2000):\n",
502 | "    choice, reject = get_batch_data()\n",
503 | "\n",
504 | "    # Chosen-minus-rejected log-likelihood margin under the trained model and\n",
505 | "    # under the reference model; the reference pass needs no gradient.\n",
506 | "    prob_log = get_prob_log(model_dpo, choice, reject)\n",
507 | "    with torch.no_grad():\n",
508 | "        prob_log_ref = get_prob_log(model_dpo_ref, choice, reject)\n",
509 | "\n",
510 | "    # DPO loss: -log(sigmoid(beta * (policy margin - reference margin))) with\n",
511 | "    # beta = 0.1. It is bounded below by 0 and its gradient fades once the\n",
512 | "    # policy prefers the chosen answer clearly more than the reference does.\n",
513 | "    margin = 0.1 * (prob_log - prob_log_ref)\n",
514 | "    loss = -torch.nn.functional.logsigmoid(margin).mean()\n",
515 | "    loss.backward()\n",
516 | "    optimizer.step()\n",
517 | "    optimizer.zero_grad()\n",
518 | "\n",
519 | "    # Every 100 steps, print a sample generation and the dataset answer.\n",
520 | "    if i % 100 == 0:\n",
521 | "        data = random.choice(dataset['test'])\n",
522 | "        input_ids = torch.LongTensor(data['question']).unsqueeze(0).to(device)\n",
523 | "        out = generate(input_ids)\n",
524 | "        print(i, tokenizer.decode(out[0]))\n",
525 | "        print('=========')\n",
526 | "        print(tokenizer.decode(data['answer']))\n",
527 | "        print('=========')"
528 | ]
529 | }
530 | ],
531 | "metadata": {
532 | "kernelspec": {
533 | "display_name": "Python [conda env:pt2]",
534 | "language": "python",
535 | "name": "conda-env-pt2-py"
536 | },
537 | "language_info": {
538 | "codemirror_mode": {
539 | "name": "ipython",
540 | "version": 3
541 | },
542 | "file_extension": ".py",
543 | "mimetype": "text/x-python",
544 | "name": "python",
545 | "nbconvert_exporter": "python",
546 | "pygments_lexer": "ipython3",
547 | "version": "3.10.13"
548 | }
549 | },
550 | "nbformat": 4,
551 | "nbformat_minor": 5
552 | }
553 |
--------------------------------------------------------------------------------
/3.ppo_trl训练.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "ab835cce",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stderr",
11 | "output_type": "stream",
12 | "text": [
13 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14 | " from .autonotebook import tqdm as notebook_tqdm\n",
15 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
16 | "Using sep_token, but it is not set yet.\n",
17 | "Using cls_token, but it is not set yet.\n",
18 | "Using mask_token, but it is not set yet.\n"
19 | ]
20 | },
21 | {
22 | "data": {
23 | "text/plain": [
24 | "GPT2TokenizerFast(name_or_path='tokenizer/lvwerra/gpt2-imdb', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!', 'additional_special_tokens': ['<|endoftext|>']}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n",
25 | "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
26 | "}"
27 | ]
28 | },
29 | "execution_count": 1,
30 | "metadata": {},
31 | "output_type": "execute_result"
32 | }
33 | ],
34 | "source": [
35 | "from transformers import AutoTokenizer\n",
36 | "import random\n",
37 | "import torch\n",
38 | "\n",
39 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
40 | "\n",
41 | "tokenizer = AutoTokenizer.from_pretrained('tokenizer/lvwerra/gpt2-imdb')\n",
42 | "tokenizer.pad_token_id = 0\n",
43 | "\n",
44 | "tokenizer"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "id": "dc57406a",
51 | "metadata": {},
52 | "outputs": [
53 | {
54 | "data": {
55 | "text/plain": [
56 | "(Dataset({\n",
57 | " features: ['question'],\n",
58 | " num_rows: 50000\n",
59 | " }),\n",
60 | " {'question': [40, 26399, 314, 3001, 327]})"
61 | ]
62 | },
63 | "execution_count": 2,
64 | "metadata": {},
65 | "output_type": "execute_result"
66 | }
67 | ],
68 | "source": [
69 | "from datasets import load_from_disk, concatenate_datasets\n",
70 | "\n",
71 | "dataset = load_from_disk('dataset/imdb')\n",
72 | "dataset = concatenate_datasets([dataset['train'], dataset['test']])\n",
73 | "\n",
74 | "# Prompt = the first 5 token ids of the review text, no special tokens.\n",
75 | "def f(data):\n",
76 | "    ids = tokenizer.encode(data['text'], add_special_tokens=False)\n",
77 | "    return {'question': ids[:5]}\n",
78 | "\n",
79 | "\n",
80 | "dataset = dataset.map(f, remove_columns=['label', 'text'])\n",
81 | "\n",
82 | "\n",
83 | "# Keep only the rows whose prompt really has 5 token ids.\n",
84 | "def f(data):\n",
85 | "    return len(data['question']) == 5\n",
86 | "\n",
87 | "\n",
88 | "dataset = dataset.filter(f)\n",
89 | "dataset, dataset[0]"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 3,
95 | "id": "12dbc381",
96 | "metadata": {
97 | "scrolled": true
98 | },
99 | "outputs": [
100 | {
101 | "data": {
102 | "text/plain": [
103 | "([1,\n",
104 | " 1,\n",
105 | " 1,\n",
106 | " 1,\n",
107 | " 0,\n",
108 | " 1,\n",
109 | " 0,\n",
110 | " 0,\n",
111 | " 1,\n",
112 | " 0,\n",
113 | " 0,\n",
114 | " 0,\n",
115 | " 0,\n",
116 | " 0,\n",
117 | " 0,\n",
118 | " 1,\n",
119 | " 1,\n",
120 | " 0,\n",
121 | " 0,\n",
122 | " 1,\n",
123 | " 1,\n",
124 | " 1,\n",
125 | " 0,\n",
126 | " 0,\n",
127 | " 1,\n",
128 | " 0,\n",
129 | " 0,\n",
130 | " 1,\n",
131 | " 0,\n",
132 | " 0,\n",
133 | " 0,\n",
134 | " 1,\n",
135 | " 1,\n",
136 | " 0,\n",
137 | " 0,\n",
138 | " 0,\n",
139 | " 1,\n",
140 | " 1,\n",
141 | " 0,\n",
142 | " 0,\n",
143 | " 0,\n",
144 | " 1,\n",
145 | " 1,\n",
146 | " 1,\n",
147 | " 1,\n",
148 | " 1,\n",
149 | " 0,\n",
150 | " 1,\n",
151 | " 1,\n",
152 | " 1,\n",
153 | " 1,\n",
154 | " 1,\n",
155 | " 1,\n",
156 | " 1,\n",
157 | " 0,\n",
158 | " 1,\n",
159 | " 1,\n",
160 | " 1,\n",
161 | " 1,\n",
162 | " 1,\n",
163 | " 0,\n",
164 | " 0,\n",
165 | " 0,\n",
166 | " 0,\n",
167 | " 1,\n",
168 | " 0,\n",
169 | " 1,\n",
170 | " 0,\n",
171 | " 1,\n",
172 | " 1,\n",
173 | " 0,\n",
174 | " 0,\n",
175 | " 1,\n",
176 | " 0,\n",
177 | " 0,\n",
178 | " 0,\n",
179 | " 1,\n",
180 | " 0,\n",
181 | " 0,\n",
182 | " 1,\n",
183 | " 1,\n",
184 | " 1,\n",
185 | " 1,\n",
186 | " 0,\n",
187 | " 1,\n",
188 | " 1,\n",
189 | " 1,\n",
190 | " 1,\n",
191 | " 1,\n",
192 | " 1,\n",
193 | " 0,\n",
194 | " 0,\n",
195 | " 1,\n",
196 | " 1,\n",
197 | " 1,\n",
198 | " 0,\n",
199 | " 0,\n",
200 | " 1,\n",
201 | " 0,\n",
202 | " 0,\n",
203 | " 1,\n",
204 | " 0,\n",
205 | " 0,\n",
206 | " 1,\n",
207 | " 0,\n",
208 | " 1,\n",
209 | " 1,\n",
210 | " 0,\n",
211 | " 0,\n",
212 | " 1,\n",
213 | " 0,\n",
214 | " 1,\n",
215 | " 1,\n",
216 | " 1,\n",
217 | " 0,\n",
218 | " 0,\n",
219 | " 0,\n",
220 | " 0,\n",
221 | " 1,\n",
222 | " 0,\n",
223 | " 0,\n",
224 | " 1,\n",
225 | " 0,\n",
226 | " 1,\n",
227 | " 0,\n",
228 | " 1,\n",
229 | " 1,\n",
230 | " 0],\n",
231 | " [[16, 36623, 290, 15579, 1111, 2495],\n",
232 | " [16, 39, 2743, 16543, 318, 546],\n",
233 | " [16, 5962, 612, 373, 23459, 72],\n",
234 | " [16, 7149, 2618, 508, 2925, 284],\n",
235 | " [15, 40, 3505, 4964, 428, 319],\n",
236 | " [16, 3152, 262, 7396, 11533, 286],\n",
237 | " [15, 43920, 32520, 26494, 15992, 11],\n",
238 | " [15, 27991, 1472, 318, 3737, 530],\n",
239 | " [16, 4598, 407, 307, 6522, 28970],\n",
240 | " [15, 464, 29274, 45270, 11, 7808],\n",
241 | " [15, 1212, 2646, 2900, 510, 319],\n",
242 | " [15, 464, 691, 661, 1312, 561],\n",
243 | " [15, 2025, 21218, 306, 24007, 6833],\n",
244 | " [15, 38202, 4835, 15300, 1525, 7504],\n",
245 | " [15, 40, 550, 8359, 262, 15812],\n",
246 | " [16, 4514, 2000, 286, 1450, 33743],\n",
247 | " [16, 11006, 290, 24967, 7091, 7012],\n",
248 | " [15, 5962, 314, 1276, 910, 326],\n",
249 | " [15, 21448, 25699, 423, 4367, 262],\n",
250 | " [16, 8499, 4379, 428, 2646, 329],\n",
251 | " [16, 26556, 1468, 4947, 286, 3923],\n",
252 | " [16, 5005, 7697, 360, 5945, 259],\n",
253 | " [15, 3666, 477, 12, 2435, 4004],\n",
254 | " [15, 13615, 284, 2185, 318, 257],\n",
255 | " [16, 32, 21104, 12006, 290, 37928],\n",
256 | " [15, 40, 423, 1239, 1100, 262],\n",
257 | " [15, 1858, 318, 257, 1256, 2642],\n",
258 | " [16, 1212, 3807, 373, 257, 13899],\n",
259 | " [15, 1212, 318, 257, 6635, 7427],\n",
260 | " [15, 40, 836, 470, 760, 703],\n",
261 | " [15, 7, 4561, 9437, 364, 9426],\n",
262 | " [16, 40, 460, 470, 1037, 475],\n",
263 | " [16, 22442, 284, 428, 2646, 11],\n",
264 | " [15, 40, 7342, 705, 4834, 7670],\n",
265 | " [15, 1858, 318, 257, 1621, 357],\n",
266 | " [15, 3198, 286, 2407, 257, 1178],\n",
267 | " [16, 15784, 736, 625, 262, 1613],\n",
268 | " [16, 2514, 307, 5508, 11, 262],\n",
269 | " [15, 1212, 318, 655, 530, 286],\n",
270 | " [15, 40, 1101, 407, 1654, 703],\n",
271 | " [15, 1212, 3807, 318, 34445, 257],\n",
272 | " [16, 7556, 2140, 2646, 546, 2646],\n",
273 | " [16, 40, 1053, 1107, 8359, 428],\n",
274 | " [16, 1, 818, 262, 995, 286],\n",
275 | " [16, 17821, 3923, 787, 262, 1266],\n",
276 | " [16, 1, 464, 7610, 5932, 1],\n",
277 | " [15, 49738, 12754, 22732, 3807, 326],\n",
278 | " [16, 464, 13172, 286, 2094, 11952],\n",
279 | " [16, 5779, 11, 314, 3214, 329],\n",
280 | " [16, 40, 1053, 1100, 617, 286],\n",
281 | " [16, 19419, 257, 4719, 11, 530],\n",
282 | " [16, 40, 716, 523, 9675, 44922],\n",
283 | " [16, 3137, 7342, 340, 11, 502],\n",
284 | " [16, 40, 1816, 284, 1524, 351],\n",
285 | " [15, 1212, 905, 318, 1049, 13],\n",
286 | " [16, 464, 717, 286, 1936, 520],\n",
287 | " [16, 40, 423, 1239, 587, 355],\n",
288 | " [16, 3237, 262, 6247, 7721, 276],\n",
289 | " [16, 40, 3505, 4379, 262, 12268],\n",
290 | " [16, 1212, 318, 257, 2646, 1312],\n",
291 | " [15, 40, 9159, 314, 423, 257],\n",
292 | " [15, 7975, 286, 262, 867, 7328],\n",
293 | " [15, 21944, 273, 18365, 363, 11925],\n",
294 | " [15, 1858, 338, 645, 779, 2111],\n",
295 | " [16, 1212, 2646, 318, 2089, 13],\n",
296 | " [15, 3844, 517, 621, 1683, 356],\n",
297 | " [16, 11696, 13606, 11, 257, 6036],\n",
298 | " [15, 32, 13206, 11, 5508, 11],\n",
299 | " [16, 1026, 373, 28294, 0, 383],\n",
300 | " [16, 5703, 44207, 1576, 284, 307],\n",
301 | " [15, 40, 550, 30508, 262, 3807],\n",
302 | " [15, 10814, 3730, 290, 4813, 0],\n",
303 | " [16, 32, 6507, 11, 6507, 6504],\n",
304 | " [15, 28718, 5285, 25, 383, 15875],\n",
305 | " [15, 3792, 612, 257, 1492, 11946],\n",
306 | " [15, 34, 49399, 1313, 373, 12132],\n",
307 | " [16, 1639, 561, 1107, 761, 284],\n",
308 | " [15, 4480, 644, 484, 550, 13],\n",
309 | " [15, 1212, 2646, 3607, 649, 3616],\n",
310 | " [16, 40, 550, 7342, 1811, 1528],\n",
311 | " [16, 7554, 21827, 749, 9857, 2646],\n",
312 | " [16, 40, 2497, 326, 428, 3807],\n",
313 | " [16, 6610, 16304, 0, 4162, 319],\n",
314 | " [15, 50, 707, 340, 379, 262],\n",
315 | " [16, 40, 6198, 2497, 428, 319],\n",
316 | " [16, 43977, 25, 383, 4403, 13069],\n",
317 | " [16, 1135, 1094, 88, 9214, 789],\n",
318 | " [16, 5779, 11, 484, 1908, 340],\n",
319 | " [16, 10970, 406, 6158, 6006, 32297],\n",
320 | " [16, 1212, 2646, 3033, 734, 286],\n",
321 | " [15, 76, 40302, 14509, 1559, 318],\n",
322 | " [15, 37, 48437, 338, 645, 343],\n",
323 | " [16, 1026, 338, 4998, 326, 14549],\n",
324 | " [16, 10374, 16057, 9737, 271, 392],\n",
325 | " [16, 3666, 5212, 290, 314, 3332],\n",
326 | " [15, 28951, 64, 318, 257, 845],\n",
327 | " [15, 40, 2497, 428, 3807, 257],\n",
328 | " [16, 49141, 29330, 318, 14702, 11],\n",
329 | " [15, 37887, 618, 257, 5581, 3182],\n",
330 | " [15, 38413, 17049, 1760, 13, 317],\n",
331 | " [16, 40, 7342, 366, 37, 28789],\n",
332 | " [15, 17947, 3978, 11255, 1441, 1363],\n",
333 | " [15, 33481, 5289, 743, 407, 423],\n",
334 | " [16, 1212, 318, 407, 257, 14604],\n",
335 | " [15, 1212, 4471, 20718, 514, 284],\n",
336 | " [16, 40, 1807, 428, 373, 257],\n",
337 | " [16, 1212, 373, 281, 3499, 2050],\n",
338 | " [15, 51, 13, 57, 13, 2947],\n",
339 | " [15, 464, 7110, 340, 338, 407],\n",
340 | " [16, 3633, 262, 19661, 287, 428],\n",
341 | " [15, 32, 7818, 26337, 25, 8381],\n",
342 | " [16, 1212, 3807, 4394, 5626, 39],\n",
343 | " [16, 11380, 886, 286, 262, 1621],\n",
344 | " [16, 21447, 2444, 379, 257, 8294],\n",
345 | " [15, 1722, 262, 3670, 5644, 612],\n",
346 | " [15, 32, 12362, 922, 3807, 0],\n",
347 | " [15, 1212, 3807, 3190, 4966, 24590],\n",
348 | " [15, 44960, 27633, 4244, 290, 262],\n",
349 | " [16, 1, 44, 18526, 1, 318],\n",
350 | " [15, 38, 385, 6656, 10844, 468],\n",
351 | " [15, 40, 7342, 428, 938, 1755],\n",
352 | " [16, 1722, 287, 1703, 417, 494],\n",
353 | " [15, 505, 1110, 2130, 531, 8781],\n",
354 | " [16, 1212, 318, 1194, 286, 616],\n",
355 | " [15, 1, 34, 963, 6495, 1],\n",
356 | " [16, 40, 1422, 470, 1100, 262],\n",
357 | " [16, 40, 373, 12451, 366, 28524],\n",
358 | " [15, 4366, 2365, 499, 2771, 3660]])"
359 | ]
360 | },
361 | "execution_count": 3,
362 | "metadata": {},
363 | "output_type": "execute_result"
364 | }
365 | ],
366 | "source": [
367 | "def get_batch_data():\n",
368 | "    # 128 random labels (0 or 1), each prepended as a token id to a random prompt.\n",
369 | "    label = random.choices(range(2), k=128)\n",
370 | "    rows = random.choices(dataset, k=128)\n",
371 | "    question = []\n",
372 | "    for target, row in zip(label, rows):\n",
373 | "        control_id = tokenizer.convert_tokens_to_ids(str(target))\n",
374 | "        question.append([control_id] + row['question'])\n",
375 | "    return label, question\n",
376 | "\n",
377 | "\n",
378 | "get_batch_data()"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 4,
384 | "id": "81ee37d4",
385 | "metadata": {},
386 | "outputs": [
387 | {
388 | "name": "stderr",
389 | "output_type": "stream",
390 | "text": [
391 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n",
392 | " warnings.warn(\n",
393 | "Some weights of the model checkpoint at model/lvwerra/gpt2-imdb were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']\n",
394 | "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
395 | "- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
396 | "Some weights of the model checkpoint at model/lvwerra/gpt2-imdb were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']\n",
397 | "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
398 | "- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
399 | ]
400 | }
401 | ],
402 | "source": [
403 | "from trl import AutoModelForCausalLMWithValueHead\n",
404 | "\n",
405 | "model_ppo = AutoModelForCausalLMWithValueHead.from_pretrained(\n",
406 | " 'model/lvwerra/gpt2-imdb').to(device)\n",
407 | "model_ppo_ref = AutoModelForCausalLMWithValueHead.from_pretrained(\n",
408 | " 'model/lvwerra/gpt2-imdb').to(device)\n",
409 | "\n",
410 | "for i in model_ppo_ref.parameters():\n",
411 | " i.requires_grad_(False)"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": 5,
417 | "id": "805234d3",
418 | "metadata": {},
419 | "outputs": [
420 | {
421 | "name": "stderr",
422 | "output_type": "stream",
423 | "text": [
424 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
425 | ]
426 | }
427 | ],
428 | "source": [
429 | "from transformers import AutoModelForSequenceClassification\n",
430 | "\n",
431 | "tokenizer_cls = AutoTokenizer.from_pretrained(\n",
432 | " 'tokenizer/lvwerra/distilbert-imdb')\n",
433 | "model_cls = AutoModelForSequenceClassification.from_pretrained(\n",
434 | " 'model/lvwerra/distilbert-imdb').to(device)\n",
435 | "\n",
436 | "for i in model_cls.parameters():\n",
437 | " i.requires_grad_(False)"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": 6,
443 | "id": "c3e7a327",
444 | "metadata": {
445 | "scrolled": true
446 | },
447 | "outputs": [
448 | {
449 | "data": {
450 | "text/plain": [
451 | "(tensor([1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1,\n",
452 | " 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0,\n",
453 | " 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,\n",
454 | " 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1,\n",
455 | " 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1,\n",
456 | " 0, 1, 1, 0, 0, 1, 0, 0], device='cuda:0'),\n",
457 | " [tensor([ 16, 40, 1053, 6810, 625, 262], device='cuda:0'),\n",
458 | " tensor([ 16, 1212, 3807, 373, 11, 286], device='cuda:0'),\n",
459 | " tensor([ 15, 3673, 881, 284, 910, 319], device='cuda:0'),\n",
460 | " tensor([ 15, 20459, 8066, 307, 1327, 284], device='cuda:0'),\n",
461 | " tensor([ 15, 38743, 31140, 2492, 470, 262], device='cuda:0'),\n",
462 | " tensor([ 15, 2061, 257, 2085, 286, 257], device='cuda:0'),\n",
463 | " tensor([ 16, 40, 460, 766, 1521, 43442], device='cuda:0'),\n",
464 | " tensor([ 16, 40, 3521, 470, 4043, 284], device='cuda:0'),\n",
465 | " tensor([16, 49, 13, 40, 13, 34], device='cuda:0'),\n",
466 | " tensor([ 15, 1212, 2646, 318, 257, 5003], device='cuda:0')])"
467 | ]
468 | },
469 | "execution_count": 6,
470 | "metadata": {},
471 | "output_type": "execute_result"
472 | }
473 | ],
474 | "source": [
475 | "def get_question():\n",
476 | "    # get_batch_data() converted to LongTensors that live on `device`.\n",
477 | "    raw_label, raw_question = get_batch_data()\n",
478 | "    question = []\n",
479 | "    for ids in raw_question:\n",
480 | "        question.append(torch.LongTensor(ids).to(device))\n",
481 | "    return torch.LongTensor(raw_label).to(device), question\n",
482 | "\n",
483 | "\n",
484 | "label, question = get_question()\n",
485 | "\n",
486 | "label, question[:10]"
487 | ]
488 | },
489 | {
490 | "cell_type": "code",
491 | "execution_count": 7,
492 | "id": "b77e7d96",
493 | "metadata": {
494 | "scrolled": true
495 | },
496 | "outputs": [
497 | {
498 | "data": {
499 | "text/plain": [
500 | "[tensor([ 812, 326, 3427, 14479, 3371, 6918, 11, 2592, 739, 262,\n",
501 | " 34731, 286, 366, 7353, 12, 23922, 1042, 1, 389, 783,\n",
502 | " 5371, 286, 11560, 257, 43122, 290, 326, 22041, 2331, 1234,\n",
503 | " 379, 9033], device='cuda:0'),\n",
504 | " tensor([ 1781, 11, 2208, 45002, 13, 314, 1392, 604, 14, 940,\n",
505 | " 737, 33443, 10544, 13, 383, 7205, 373, 407, 2089, 11,\n",
506 | " 4632, 616, 6000, 18641, 286, 262, 3807, 2277, 616, 1767,\n",
507 | " 13, 406], device='cuda:0'),\n",
508 | " tensor([ 287, 597, 584, 2912, 780, 1111, 423, 617, 922, 2173,\n",
509 | " 329, 514, 13, 1119, 1111, 4893, 845, 880, 703, 6283,\n",
510 | " 262, 1621, 318, 290, 635, 14846, 14397, 319, 257, 1241,\n",
511 | " 326, 338], device='cuda:0'),\n",
512 | " tensor([ 804, 379, 606, 290, 852, 15049, 644, 257, 2823, 428,\n",
513 | " 318, 13, 314, 460, 470, 787, 503, 257, 2060, 1517,\n",
514 | " 546, 262, 3807, 314, 892, 475, 314, 481, 9159, 1312,\n",
515 | " 8288, 428], device='cuda:0'),\n",
516 | " tensor([ 691, 6575, 8442, 286, 262, 705, 2154, 82, 508, 336,\n",
517 | " 2954, 357, 392, 20143, 284, 4656, 737, 50256],\n",
518 | " device='cuda:0'),\n",
519 | " tensor([ 3807, 329, 1105, 290, 1315, 614, 44979, 29847, 1671, 1220,\n",
520 | " 6927, 1671, 11037, 1212, 530, 318, 257, 36364, 5287, 13,\n",
521 | " 3887, 2168, 1281, 994, 468, 4054, 11, 4556, 286, 1781,\n",
522 | " 345, 389], device='cuda:0'),\n",
523 | " tensor([ 290, 27583, 1661, 3947, 1342, 3616, 913, 777, 1528, 13,\n",
524 | " 921, 892, 1353, 29838, 10544, 460, 3491, 287, 257, 4101,\n",
525 | " 5664, 3807, 290, 1577, 9739, 477, 832, 262, 835, 30,\n",
526 | " 1892, 287], device='cuda:0'),\n",
527 | " tensor([ 651, 736, 319, 262, 6614, 13, 1675, 307, 3148, 11, 340, 373,\n",
528 | " 5867, 3625, 890, 13, 843, 340, 373, 546, 9796, 3625, 416, 3439,\n",
529 | " 3625, 890, 13, 1320, 1838, 340, 804, 588], device='cuda:0'),\n",
530 | " tensor([ 13, 48, 13, 379, 262, 7369, 838, 19245, 319, 42490,\n",
531 | " 318, 407, 284, 307, 287, 262, 1551, 1643, 14702, 13,\n",
532 | " 1026, 2125, 470, 1327, 284, 1560, 655, 703, 10927, 3619,\n",
533 | " 40365, 318], device='cuda:0'),\n",
534 | " tensor([ 9875, 422, 923, 284, 5461, 351, 257, 21840, 4226, 290,\n",
535 | " 281, 6275, 3350, 326, 16316, 23501, 22605, 351, 257, 23754,\n",
536 | " 2854, 355, 34705, 2241, 13, 30436, 6550, 30673, 11392, 318,\n",
537 | " 257, 3756], device='cuda:0')]"
538 | ]
539 | },
540 | "execution_count": 7,
541 | "metadata": {},
542 | "output_type": "execute_result"
543 | }
544 | ],
545 | "source": [
546 | "# Sampling wrapper around model_ppo.generate (at most 32 new tokens).\n",
547 | "def generate(input_ids):\n",
548 | "    return model_ppo.generate(input_ids=input_ids,\n",
549 | "                              min_length=-1,\n",
550 | "                              top_k=0.0,\n",
551 | "                              top_p=1.0,\n",
552 | "                              do_sample=True,\n",
553 | "                              pad_token_id=tokenizer.pad_token_id,\n",
554 | "                              max_new_tokens=32,\n",
555 | "                              eos_token_id=tokenizer.eos_token_id)\n",
556 | "\n",
557 | "\n",
558 | "def get_answer(question):\n",
559 | "    # Generate one continuation per prompt and return only the new tokens.\n",
560 | "    # question: list of 1-D id tensors. Returns a list of 1-D id tensors; an\n",
561 | "    # answer keeps its EOS token if one was generated.\n",
562 | "    if len({len(q) for q in question}) == 1:\n",
563 | "        # Equal-length prompts can be stacked and generated as one batch.\n",
564 | "        answer = []\n",
565 | "        for row in generate(torch.stack(question)):\n",
566 | "            # Cut each row right after its first EOS; whatever follows is\n",
567 | "            # presumably padding added by the batched generate call.\n",
568 | "            if tokenizer.eos_token_id in row:\n",
569 | "                row = row[:row.tolist().index(tokenizer.eos_token_id) + 1]\n",
570 | "            answer.append(row.unsqueeze(0))\n",
571 | "    else:\n",
572 | "        # torch.stack needs equal lengths, so mixed-length (or empty) input is\n",
573 | "        # generated one prompt at a time.\n",
574 | "        answer = [generate(q.unsqueeze(0)) for q in question]\n",
575 | "\n",
576 | "    # Drop the prompt prefix from every generated row.\n",
577 | "    return [a[0, len(q):] for q, a in zip(question, answer)]\n",
578 | "\n",
579 | "\n",
580 | "answer = get_answer(question)\n",
581 | "\n",
582 | "answer[:10]"
583 | ]
584 | },
585 | {
586 | "cell_type": "code",
587 | "execution_count": 8,
588 | "id": "720b4c0f",
589 | "metadata": {
590 | "scrolled": true
591 | },
592 | "outputs": [
593 | {
594 | "data": {
595 | "text/plain": [
596 | "tensor([-0.0790, -2.7419, 0.0679, -1.6332, 0.6677, 2.6076, 0.1573, -1.1461,\n",
597 | " 0.9582, -2.5122, 2.7525, -2.3757, 1.0361, 2.5690, 1.9334, -2.1080,\n",
598 | " -1.7346, 0.6678, -1.2840, -0.0737, 1.8030, -1.5229, -1.9619, -1.7785,\n",
599 | " -1.3398, 2.2020, 0.3973, -1.9178, -1.7657, -0.1403, -1.7730, 2.3305,\n",
600 | " 1.4456, -1.9255, -2.2839, 1.7811, -2.6034, 1.0441, -0.5461, 1.6941,\n",
601 | " 2.0946, -1.7392, 0.8828, -0.9679, 0.2184, -0.4701, -2.4185, -0.0232,\n",
602 | " -0.0851, 2.2899, -2.2363, -2.2239, -1.2585, 1.6426, -0.6541, 0.9921,\n",
603 | " 0.2738, 0.2257, 1.1809, -1.2928, 2.2934, 2.8806, 1.1941, -1.3584,\n",
604 | " -0.4178, 2.0240, 2.0432, 0.2561, -1.0444, -0.7836, -1.6604, 0.9020,\n",
605 | " -2.4616, 0.5171, 0.3906, 1.7344, 1.1841, -1.4855, 2.4969, -1.4334,\n",
606 | " -1.9962, 0.0924, -1.0023, 2.7003, -0.4900, -1.1669, -1.4747, -1.2043,\n",
607 | " 0.4568, 2.5980, 1.8720, -1.3922, 2.7693, -2.4912, -1.5629, -1.2905,\n",
608 | " 2.4059, 1.9968, -0.8069, -1.0914, 0.9309, -2.6868, -0.8958, 2.4808,\n",
609 | " -2.9966, -1.4021, -2.5262, 1.7169, -1.8984, -0.9218, 2.3929, 2.2298,\n",
610 | " 0.2727, 0.1386, -0.5772, 0.9471, 0.2255, 1.2544, 1.9891, 1.8280,\n",
611 | " -0.1826, -1.3578, -0.3683, -1.0374, 0.7353, 1.8427, -2.1924, -1.1235],\n",
612 | " device='cuda:0')"
613 | ]
614 | },
615 | "execution_count": 8,
616 | "metadata": {},
617 | "output_type": "execute_result"
618 | }
619 | ],
620 | "source": [
621 | "def get_reward(question, answer, label):\n",
622 | "    #rebuild each sample as text: prompt minus its first token, then the answer\n",
623 | "    text = [tokenizer.decode(q.tolist()[1:] + a.tolist())\n",
624 | "            for q, a in zip(question, answer)]\n",
625 | "    encoded = tokenizer_cls(text,\n",
626 | "                            padding=True,\n",
627 | "                            truncation=True,\n",
628 | "                            max_length=512,\n",
629 | "                            return_tensors='pt').to(device)\n",
630 | "\n",
631 | "    with torch.no_grad():\n",
632 | "        scores = model_cls(**encoded).logits\n",
633 | "    #the reward is the classifier logit stored at each sample's label index\n",
634 | "    return scores.gather(1, label.reshape(-1, 1)).squeeze(1)\n",
635 | "\n",
636 | "\n",
637 | "reward = get_reward(question, answer, label)\n",
638 | "\n",
639 | "reward"
640 | ]
641 | },
642 | {
643 | "cell_type": "code",
644 | "execution_count": 9,
645 | "id": "0ff781ae",
646 | "metadata": {},
647 | "outputs": [
648 | {
649 | "data": {
650 | "text/plain": [
651 | ""
652 | ]
653 | },
654 | "execution_count": 9,
655 | "metadata": {},
656 | "output_type": "execute_result"
657 | }
658 | ],
659 | "source": [
660 | "from trl import PPOConfig, PPOTrainer\n",
661 | "#configure PPO and build the trainer\n",
662 | "config = PPOConfig(learning_rate=1e-5, batch_size=128)\n",
663 | "\n",
664 | "trainer = PPOTrainer(config=config,\n",
665 | "                     model=model_ppo,\n",
666 | "                     ref_model=model_ppo_ref,\n",
667 | "                     tokenizer=tokenizer,\n",
668 | "                     dataset=dataset)\n",
669 | "\n",
670 | "trainer"
671 | ]
672 | },
673 | {
674 | "cell_type": "code",
675 | "execution_count": 10,
676 | "id": "a1c70e8e",
677 | "metadata": {
678 | "scrolled": true
679 | },
680 | "outputs": [
681 | {
682 | "name": "stderr",
683 | "output_type": "stream",
684 | "text": [
685 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
686 | "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
687 | ]
688 | },
689 | {
690 | "name": "stdout",
691 | "output_type": "stream",
692 | "text": [
693 | "0 0.09299226105213165\n",
694 | "1First of all, this -> movie is hardly stupid. However, it is one of the worst. It has two very real scenarios.Along with the efforts of a young couple of lost hero -2.487795829772949\n"
695 | ]
696 | },
697 | {
698 | "name": "stderr",
699 | "output_type": "stream",
700 | "text": [
701 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
702 | ]
703 | },
704 | {
705 | "name": "stdout",
706 | "output_type": "stream",
707 | "text": [
708 | "10 -0.023181095719337463\n",
709 | "0In my book \"Basic -> Instinct\" released 17 years ago I read the book and think that such kind of things are the way the world is now. And quotes like \"All that -1.2708756923675537\n"
710 | ]
711 | },
712 | {
713 | "name": "stderr",
714 | "output_type": "stream",
715 | "text": [
716 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
717 | ]
718 | },
719 | {
720 | "name": "stdout",
721 | "output_type": "stream",
722 | "text": [
723 | "20 0.31382036209106445\n",
724 | "0Considering it's basically low -> budget and a lot of black nudity and an odd response from the audience, the film is a pretty per ml resolution to an issue worthy of not being all that 1.1325664520263672\n"
725 | ]
726 | },
727 | {
728 | "name": "stderr",
729 | "output_type": "stream",
730 | "text": [
731 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
732 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
733 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
734 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
735 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
736 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
737 | ]
738 | },
739 | {
740 | "name": "stdout",
741 | "output_type": "stream",
742 | "text": [
743 | "30 0.309106707572937\n",
744 | "0True stories make the best -> work of good actors none that are born into their mid-seventies.<|endoftext|> -2.210801601409912\n"
745 | ]
746 | },
747 | {
748 | "name": "stderr",
749 | "output_type": "stream",
750 | "text": [
751 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
752 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
753 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
754 | ]
755 | },
756 | {
757 | "name": "stdout",
758 | "output_type": "stream",
759 | "text": [
760 | "40 1.0727498531341553\n",
761 | "1Cheesy 80's horror -> ...... composer: so beautifully chosen, I can say, has a on-screen performance even better.<|endoftext|> 2.0839757919311523\n"
762 | ]
763 | },
764 | {
765 | "name": "stderr",
766 | "output_type": "stream",
767 | "text": [
768 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
769 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
770 | ]
771 | },
772 | {
773 | "name": "stdout",
774 | "output_type": "stream",
775 | "text": [
776 | "50 1.4121887683868408\n",
777 | "0GBS wrote his own -> script. One name near offense but atmospheric exposition, oddly still bland. Wipeout, rubbish. Violence, i slows.<|endoftext|> 2.3595705032348633\n"
778 | ]
779 | },
780 | {
781 | "name": "stderr",
782 | "output_type": "stream",
783 | "text": [
784 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
785 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
786 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
787 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
788 | ]
789 | },
790 | {
791 | "name": "stdout",
792 | "output_type": "stream",
793 | "text": [
794 | "60 1.7078416347503662\n",
795 | "0I was never in the -> film.<|endoftext|> 0.9322419166564941\n"
796 | ]
797 | },
798 | {
799 | "name": "stderr",
800 | "output_type": "stream",
801 | "text": [
802 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
803 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
804 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
805 | ]
806 | },
807 | {
808 | "name": "stdout",
809 | "output_type": "stream",
810 | "text": [
811 | "70 2.1755259037017822\n",
812 | "1I must say, when -> watching this movie, it is amazing.<|endoftext|> 2.4521334171295166\n"
813 | ]
814 | },
815 | {
816 | "name": "stderr",
817 | "output_type": "stream",
818 | "text": [
819 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
820 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
821 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
822 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
823 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
824 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
825 | ]
826 | },
827 | {
828 | "name": "stdout",
829 | "output_type": "stream",
830 | "text": [
831 | "80 2.1622257232666016\n",
832 | "1This is one of the -> best films of all time.<|endoftext|> 2.7308685779571533\n"
833 | ]
834 | },
835 | {
836 | "name": "stderr",
837 | "output_type": "stream",
838 | "text": [
839 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
840 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
841 | ]
842 | },
843 | {
844 | "name": "stdout",
845 | "output_type": "stream",
846 | "text": [
847 | "90 2.1979665756225586\n",
848 | "1I resisted watching 15 Park -> but so loved seeing it!<|endoftext|> 2.1706929206848145\n"
849 | ]
850 | },
851 | {
852 | "name": "stderr",
853 | "output_type": "stream",
854 | "text": [
855 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
856 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
857 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
858 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
859 | ]
860 | },
861 | {
862 | "name": "stdout",
863 | "output_type": "stream",
864 | "text": [
865 | "100 2.2108092308044434\n",
866 | "0This is a Laurel & -> Hardy terrible movie.<|endoftext|> 1.9358655214309692\n"
867 | ]
868 | },
869 | {
870 | "name": "stderr",
871 | "output_type": "stream",
872 | "text": [
873 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
874 | ]
875 | },
876 | {
877 | "name": "stdout",
878 | "output_type": "stream",
879 | "text": [
880 | "110 2.288398504257202\n",
881 | "1This was the first Mickey -> Renn masteringcom to medieval master and movie makers, and truly supported their development toward English film. highly recommended.<|endoftext|> 2.7567362785339355\n"
882 | ]
883 | },
884 | {
885 | "name": "stderr",
886 | "output_type": "stream",
887 | "text": [
888 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
889 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
890 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
891 | ]
892 | },
893 | {
894 | "name": "stdout",
895 | "output_type": "stream",
896 | "text": [
897 | "120 2.1484575271606445\n",
898 | "1There have been several films -> and tv. This is great film for the ladies.<|endoftext|> 2.491626739501953\n"
899 | ]
900 | },
901 | {
902 | "name": "stderr",
903 | "output_type": "stream",
904 | "text": [
905 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
906 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
907 | ]
908 | },
909 | {
910 | "name": "stdout",
911 | "output_type": "stream",
912 | "text": [
913 | "130 2.2387020587921143\n",
914 | "1K Murli Mohan -> in happiness that was tremendously 1934, and terrific movie.<|endoftext|> 2.5319623947143555\n"
915 | ]
916 | },
917 | {
918 | "name": "stderr",
919 | "output_type": "stream",
920 | "text": [
921 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
922 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
923 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
924 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
925 | ]
926 | },
927 | {
928 | "name": "stdout",
929 | "output_type": "stream",
930 | "text": [
931 | "140 2.111945629119873\n",
932 | "0I don't really know -> how one will figure out how he's going to play the story, but this is a bad movie!<|endoftext|> 2.4522318840026855\n"
933 | ]
934 | },
935 | {
936 | "name": "stderr",
937 | "output_type": "stream",
938 | "text": [
939 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
940 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
941 | ]
942 | },
943 | {
944 | "name": "stdout",
945 | "output_type": "stream",
946 | "text": [
947 | "150 2.228043556213379\n",
948 | "1I will never get back -> this last glorious film.<|endoftext|> 2.100973606109619\n"
949 | ]
950 | },
951 | {
952 | "name": "stderr",
953 | "output_type": "stream",
954 | "text": [
955 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
956 | ]
957 | },
958 | {
959 | "name": "stdout",
960 | "output_type": "stream",
961 | "text": [
962 | "160 2.0983808040618896\n",
963 | "1Night of the Twisters -> . Truly one of the best I've seen.<|endoftext|> 2.794154644012451\n"
964 | ]
965 | },
966 | {
967 | "name": "stderr",
968 | "output_type": "stream",
969 | "text": [
970 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
971 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
972 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
973 | ]
974 | },
975 | {
976 | "name": "stdout",
977 | "output_type": "stream",
978 | "text": [
979 | "170 2.1417503356933594\n",
980 | "1This film was a new -> masterpiece.<|endoftext|> 2.473341464996338\n"
981 | ]
982 | },
983 | {
984 | "name": "stderr",
985 | "output_type": "stream",
986 | "text": [
987 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
988 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
989 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
990 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
991 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
992 | ]
993 | },
994 | {
995 | "name": "stdout",
996 | "output_type": "stream",
997 | "text": [
998 | "180 2.2145256996154785\n",
999 | "1William Faulkner was -> just perfect! An excellent movie.<|endoftext|> 2.788565158843994\n"
1000 | ]
1001 | },
1002 | {
1003 | "name": "stderr",
1004 | "output_type": "stream",
1005 | "text": [
1006 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1007 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1008 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1009 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1010 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1011 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1012 | ]
1013 | },
1014 | {
1015 | "name": "stdout",
1016 | "output_type": "stream",
1017 | "text": [
1018 | "190 2.0492758750915527\n",
1019 | "0Mighty Morphin Power -> is an embarrassment.<|endoftext|> 2.2804903984069824\n"
1020 | ]
1021 | },
1022 | {
1023 | "name": "stderr",
1024 | "output_type": "stream",
1025 | "text": [
1026 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1027 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1028 | ]
1029 | }
1030 | ],
1031 | "source": [
1032 | "import warnings\n",
1033 | "\n",
1034 | "warnings.filterwarnings('ignore')\n",
1035 | "\n",
1036 | "#200 PPO steps; every 10th step also prints the mean reward and one sample\n",
1037 | "for epoch in range(200):\n",
1038 | "    label, question = get_question()\n",
1039 | "    answer = get_answer(question)\n",
1040 | "    reward = get_reward(question, answer, label)\n",
1041 | "\n",
1042 | "    trainer.step(question, answer, list(reward))\n",
1043 | "\n",
1044 | "    if epoch % 10 != 0:\n",
1045 | "        continue\n",
1046 | "\n",
1047 | "    print(epoch, reward.mean().item())\n",
1048 | "\n",
1049 | "    #0 = negative review, 1 = positive review\n",
1050 | "    print(tokenizer.decode(question[0].tolist()), '->',\n",
1051 | "          tokenizer.decode(answer[0].tolist()), reward[0].item())"
1052 | ]
1053 | }
1054 | ],
1055 | "metadata": {
1056 | "kernelspec": {
1057 | "display_name": "Python [conda env:pt2]",
1058 | "language": "python",
1059 | "name": "conda-env-pt2-py"
1060 | },
1061 | "language_info": {
1062 | "codemirror_mode": {
1063 | "name": "ipython",
1064 | "version": 3
1065 | },
1066 | "file_extension": ".py",
1067 | "mimetype": "text/x-python",
1068 | "name": "python",
1069 | "nbconvert_exporter": "python",
1070 | "pygments_lexer": "ipython3",
1071 | "version": "3.10.13"
1072 | }
1073 | },
1074 | "nbformat": 4,
1075 | "nbformat_minor": 5
1076 | }
1077 |
--------------------------------------------------------------------------------
/4.ppo_手动训练.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "ab835cce",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stderr",
11 | "output_type": "stream",
12 | "text": [
13 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14 | " from .autonotebook import tqdm as notebook_tqdm\n",
15 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
16 | "Using sep_token, but it is not set yet.\n",
17 | "Using cls_token, but it is not set yet.\n",
18 | "Using mask_token, but it is not set yet.\n"
19 | ]
20 | },
21 | {
22 | "data": {
23 | "text/plain": [
24 | "GPT2TokenizerFast(name_or_path='tokenizer/lvwerra/gpt2-imdb', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!', 'additional_special_tokens': ['<|endoftext|>']}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n",
25 | "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
26 | "}"
27 | ]
28 | },
29 | "execution_count": 1,
30 | "metadata": {},
31 | "output_type": "execute_result"
32 | }
33 | ],
34 | "source": [
35 | "from transformers import AutoTokenizer\n",
36 | "import random\n",
37 | "import torch\n",
38 | "\n",
39 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
40 | "\n",
41 | "tokenizer = AutoTokenizer.from_pretrained('tokenizer/lvwerra/gpt2-imdb')\n",
42 | "tokenizer.pad_token_id = 0\n",
43 | "\n",
44 | "tokenizer"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "id": "dc57406a",
51 | "metadata": {},
52 | "outputs": [
53 | {
54 | "data": {
55 | "text/plain": [
56 | "(Dataset({\n",
57 | " features: ['question'],\n",
58 | " num_rows: 50000\n",
59 | " }),\n",
60 | " {'question': [40, 26399, 314, 3001, 327]})"
61 | ]
62 | },
63 | "execution_count": 2,
64 | "metadata": {},
65 | "output_type": "execute_result"
66 | }
67 | ],
68 | "source": [
69 | "from datasets import load_from_disk, concatenate_datasets\n",
70 | "\n",
71 | "splits = load_from_disk('dataset/imdb')\n",
72 | "dataset = concatenate_datasets([splits['train'], splits['test']])\n",
73 | "\n",
74 | "\n",
75 | "def first_five_tokens(data):\n",
76 | "    #a question is the first 5 token ids of the review text\n",
77 | "    ids = tokenizer.encode(data['text'], add_special_tokens=False)\n",
78 | "    return {'question': ids[:5]}\n",
79 | "\n",
80 | "\n",
81 | "def has_five_tokens(data):\n",
82 | "    #False for reviews that encode to fewer than 5 tokens\n",
83 | "    return len(data['question']) == 5\n",
84 | "\n",
85 | "\n",
86 | "dataset = dataset.map(first_five_tokens, remove_columns=['label', 'text'])\n",
87 | "dataset = dataset.filter(has_five_tokens)\n",
88 | "\n",
89 | "dataset, dataset[0]"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 3,
95 | "id": "12dbc381",
96 | "metadata": {
97 | "scrolled": true
98 | },
99 | "outputs": [
100 | {
101 | "data": {
102 | "text/plain": [
103 | "([1,\n",
104 | " 1,\n",
105 | " 0,\n",
106 | " 0,\n",
107 | " 1,\n",
108 | " 0,\n",
109 | " 0,\n",
110 | " 1,\n",
111 | " 0,\n",
112 | " 0,\n",
113 | " 0,\n",
114 | " 0,\n",
115 | " 1,\n",
116 | " 1,\n",
117 | " 1,\n",
118 | " 0,\n",
119 | " 0,\n",
120 | " 0,\n",
121 | " 0,\n",
122 | " 1,\n",
123 | " 1,\n",
124 | " 1,\n",
125 | " 1,\n",
126 | " 0,\n",
127 | " 0,\n",
128 | " 1,\n",
129 | " 1,\n",
130 | " 0,\n",
131 | " 1,\n",
132 | " 0,\n",
133 | " 0,\n",
134 | " 1,\n",
135 | " 0,\n",
136 | " 1,\n",
137 | " 0,\n",
138 | " 0,\n",
139 | " 0,\n",
140 | " 0,\n",
141 | " 0,\n",
142 | " 0,\n",
143 | " 1,\n",
144 | " 1,\n",
145 | " 0,\n",
146 | " 1,\n",
147 | " 0,\n",
148 | " 0,\n",
149 | " 1,\n",
150 | " 0,\n",
151 | " 1,\n",
152 | " 0,\n",
153 | " 1,\n",
154 | " 1,\n",
155 | " 1,\n",
156 | " 1,\n",
157 | " 1,\n",
158 | " 0,\n",
159 | " 1,\n",
160 | " 0,\n",
161 | " 1,\n",
162 | " 1,\n",
163 | " 1,\n",
164 | " 1,\n",
165 | " 1,\n",
166 | " 0,\n",
167 | " 1,\n",
168 | " 0,\n",
169 | " 1,\n",
170 | " 0,\n",
171 | " 0,\n",
172 | " 0,\n",
173 | " 1,\n",
174 | " 0,\n",
175 | " 0,\n",
176 | " 0,\n",
177 | " 1,\n",
178 | " 1,\n",
179 | " 1,\n",
180 | " 0,\n",
181 | " 1,\n",
182 | " 1,\n",
183 | " 0,\n",
184 | " 1,\n",
185 | " 1,\n",
186 | " 1,\n",
187 | " 1,\n",
188 | " 1,\n",
189 | " 0,\n",
190 | " 0,\n",
191 | " 0,\n",
192 | " 1,\n",
193 | " 0,\n",
194 | " 0,\n",
195 | " 1,\n",
196 | " 0,\n",
197 | " 1,\n",
198 | " 1,\n",
199 | " 0,\n",
200 | " 0,\n",
201 | " 1,\n",
202 | " 1,\n",
203 | " 1,\n",
204 | " 1,\n",
205 | " 1,\n",
206 | " 1,\n",
207 | " 0,\n",
208 | " 0,\n",
209 | " 0,\n",
210 | " 0,\n",
211 | " 0,\n",
212 | " 0,\n",
213 | " 0,\n",
214 | " 1,\n",
215 | " 0,\n",
216 | " 1,\n",
217 | " 0,\n",
218 | " 1,\n",
219 | " 0,\n",
220 | " 1,\n",
221 | " 0,\n",
222 | " 0,\n",
223 | " 1,\n",
224 | " 0,\n",
225 | " 0,\n",
226 | " 0,\n",
227 | " 0,\n",
228 | " 0,\n",
229 | " 1,\n",
230 | " 0],\n",
231 | " [[16, 40, 423, 1775, 257, 1256],\n",
232 | " [16, 33, 18058, 286, 262, 4044],\n",
233 | " [15, 1, 464, 13037, 6289, 1],\n",
234 | " [15, 40, 460, 470, 3505, 618],\n",
235 | " [16, 1212, 318, 262, 5290, 286],\n",
236 | " [15, 40, 1816, 284, 766, 428],\n",
237 | " [15, 6109, 530, 815, 766, 428],\n",
238 | " [16, 1, 36091, 6592, 6711, 1],\n",
239 | " [15, 35700, 922, 2646, 422, 3771],\n",
240 | " [15, 36, 292, 813, 262, 5290],\n",
241 | " [15, 5211, 407, 766, 428, 2646],\n",
242 | " [15, 2504, 338, 257, 2089, 11],\n",
243 | " [16, 2061, 257, 6283, 8137, 318],\n",
244 | " [16, 33, 463, 16988, 290, 4768],\n",
245 | " [16, 35596, 1234, 11, 428, 318],\n",
246 | " [15, 3792, 340, 655, 502, 11],\n",
247 | " [15, 1212, 428, 2406, 286, 2479],\n",
248 | " [15, 40, 550, 878, 257, 4203],\n",
249 | " [15, 40, 4240, 644, 15579, 286],\n",
250 | " [16, 39, 1794, 699, 422, 923],\n",
251 | " [16, 464, 1772, 286, 6409, 16122],\n",
252 | " [16, 28211, 531, 326, 12887, 4462],\n",
253 | " [16, 1212, 3807, 14071, 517, 621],\n",
254 | " [15, 1532, 428, 2646, 4329, 257],\n",
255 | " [15, 32, 5249, 9298, 11, 4713],\n",
256 | " [16, 2061, 1312, 5465, 749, 287],\n",
257 | " [16, 1, 22073, 588, 257, 47808],\n",
258 | " [15, 40, 10014, 4379, 428, 2646],\n",
259 | " [16, 3351, 533, 47114, 318, 900],\n",
260 | " [15, 40, 2492, 470, 12451, 284],\n",
261 | " [15, 13295, 8849, 329, 6319, 2879],\n",
262 | " [16, 1212, 4141, 2646, 318, 13519],\n",
263 | " [15, 18050, 3944, 17731, 318, 1719],\n",
264 | " [16, 18690, 1312, 760, 262, 2656],\n",
265 | " [15, 1, 1639, 460, 7866, 1997],\n",
266 | " [15, 464, 734, 1243, 389, 389],\n",
267 | " [15, 25596, 948, 31078, 18258, 290],\n",
268 | " [15, 6, 32, 17388, 286, 4930],\n",
269 | " [15, 1212, 3807, 318, 2279, 475],\n",
270 | " [15, 1, 17821, 1, 1621, 286],\n",
271 | " [16, 464, 1578, 1829, 3170, 262],\n",
272 | " [16, 83, 359, 18804, 2540, 302],\n",
273 | " [15, 40, 655, 5201, 4964, 428],\n",
274 | " [16, 2348, 292, 11, 3595, 4345],\n",
275 | " [15, 1212, 6918, 318, 262, 1266],\n",
276 | " [15, 5189, 262, 1115, 816, 1124],\n",
277 | " [16, 21448, 43693, 2879, 6918, 547],\n",
278 | " [15, 4711, 8088, 326, 1624, 428],\n",
279 | " [16, 1722, 584, 8088, 423, 5081],\n",
280 | " [15, 40, 1043, 428, 281, 45421],\n",
281 | " [16, 36837, 11, 314, 2630, 326],\n",
282 | " [16, 1212, 3807, 318, 12659, 13],\n",
283 | " [16, 40, 716, 257, 22197, 18854],\n",
284 | " [16, 40, 1053, 34310, 383, 25976],\n",
285 | " [16, 3260, 616, 718, 614, 1468],\n",
286 | " [15, 25681, 1497, 422, 428, 3807],\n",
287 | " [16, 40, 373, 18680, 12617, 351],\n",
288 | " [15, 5195, 1422, 470, 41542, 423],\n",
289 | " [16, 40, 561, 423, 8359, 428],\n",
290 | " [16, 1212, 318, 262, 2081, 1621],\n",
291 | " [16, 1212, 3807, 318, 44089, 355],\n",
292 | " [16, 40, 717, 2497, 366, 29449],\n",
293 | " [16, 11380, 11, 880, 11, 645],\n",
294 | " [15, 40, 2497, 428, 2646, 355],\n",
295 | " [16, 1212, 318, 546, 257, 8805],\n",
296 | " [15, 15496, 661, 11, 27, 1671],\n",
297 | " [16, 46, 57, 318, 281, 1468],\n",
298 | " [15, 16454, 11, 314, 655, 550],\n",
299 | " [15, 40, 1276, 910, 314, 373],\n",
300 | " [15, 16773, 319, 11, 1309, 338],\n",
301 | " [16, 40, 5257, 284, 1700, 3336],\n",
302 | " [15, 1212, 3807, 1107, 2523, 663],\n",
303 | " [15, 2215, 2077, 355, 257, 2187],\n",
304 | " [15, 8332, 663, 2479, 11, 428],\n",
305 | " [16, 3844, 314, 550, 262, 1266],\n",
306 | " [16, 40, 550, 2938, 257, 6547],\n",
307 | " [16, 2601, 7626, 36389, 6026, 344],\n",
308 | " [15, 1212, 3807, 318, 7579, 11211],\n",
309 | " [16, 464, 7157, 286, 5830, 22430],\n",
310 | " [16, 1858, 423, 587, 1811, 7328],\n",
311 | " [15, 1532, 345, 389, 3612, 286],\n",
312 | " [16, 2514, 1998, 7123, 345, 1107],\n",
313 | " [16, 40, 1807, 262, 3807, 366],\n",
314 | " [16, 1858, 389, 617, 1243, 314],\n",
315 | " [16, 42338, 11, 314, 836, 18265],\n",
316 | " [16, 9, 31895, 6509, 9, 27],\n",
317 | " [15, 464, 36947, 318, 281, 4998],\n",
318 | " [15, 40, 5543, 1842, 428, 2646],\n",
319 | " [15, 3666, 4004, 905, 286, 477],\n",
320 | " [16, 3673, 2089, 13289, 13, 5338],\n",
321 | " [15, 32, 845, 23056, 1621, 11],\n",
322 | " [15, 40817, 11, 379, 717, 11],\n",
323 | " [16, 464, 717, 4756, 3715, 326],\n",
324 | " [15, 40, 460, 1833, 703, 41921],\n",
325 | " [16, 3260, 5586, 832, 262, 4512],\n",
326 | " [16, 16454, 986, 568, 1312, 1053],\n",
327 | " [15, 40, 3221, 588, 10997, 6918],\n",
328 | " [15, 1212, 318, 407, 257, 922],\n",
329 | " [16, 22017, 11, 428, 318, 845],\n",
330 | " [16, 16676, 262, 1438, 7939, 43503],\n",
331 | " [16, 2061, 4325, 618, 2130, 468],\n",
332 | " [16, 2396, 30248, 340, 2125, 470],\n",
333 | " [16, 49, 27015, 7328, 1464, 423],\n",
334 | " [16, 11158, 257, 3807, 810, 262],\n",
335 | " [15, 45, 14651, 17064, 293, 33500],\n",
336 | " [15, 40, 1053, 1100, 257, 1271],\n",
337 | " [15, 1, 2437, 1675, 44927, 14213],\n",
338 | " [15, 5189, 262, 3478, 10544, 508],\n",
339 | " [15, 464, 3807, 373, 3105, 11],\n",
340 | " [15, 40, 7342, 428, 3807, 257],\n",
341 | " [15, 50, 592, 12382, 25, 366],\n",
342 | " [16, 818, 1502, 284, 2883, 705],\n",
343 | " [15, 1, 8491, 921, 287, 262],\n",
344 | " [16, 18690, 11, 523, 340, 338],\n",
345 | " [15, 5703, 1392, 428, 287, 262],\n",
346 | " [16, 1, 3856, 322, 4754, 1],\n",
347 | " [15, 1858, 389, 867, 3840, 314],\n",
348 | " [16, 40, 3088, 284, 1577, 428],\n",
349 | " [15, 986, 392, 7685, 1312, 836],\n",
350 | " [15, 19703, 357, 23865, 4687, 88],\n",
351 | " [16, 1639, 714, 3800, 257, 2196],\n",
352 | " [15, 464, 2646, 318, 1912, 319],\n",
353 | " [15, 32, 32611, 3499, 923, 11],\n",
354 | " [15, 4561, 9437, 364, 287, 428],\n",
355 | " [15, 1, 2061, 561, 345, 466],\n",
356 | " [15, 72, 892, 428, 905, 318],\n",
357 | " [16, 1212, 3807, 318, 546, 257],\n",
358 | " [15, 1558, 14, 2999, 14, 2931]])"
359 | ]
360 | },
361 | "execution_count": 3,
362 | "metadata": {},
363 | "output_type": "execute_result"
364 | }
365 | ],
366 | "source": [
367 | "def get_batch_data(batch_size=128):\n",
368 | "    # Draw `batch_size` random 0/1 labels and, independently, as many prompts (with replacement).\n",
369 | "    label = random.choices(range(2), k=batch_size)\n",
370 | "    question = [i['question'] for i in random.choices(dataset, k=batch_size)]\n",
371 | "\n",
372 | "    # Prefix every prompt with the id of the token '0' or '1' that spells out its label.\n",
373 | "    question = [[tokenizer.convert_tokens_to_ids(str(l))] + q\n",
374 | "                for l, q in zip(label, question)]\n",
375 | "    return label, question\n",
376 | "\n",
377 | "\n",
378 | "get_batch_data()"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 4,
384 | "id": "81ee37d4",
385 | "metadata": {},
386 | "outputs": [
387 | {
388 | "name": "stderr",
389 | "output_type": "stream",
390 | "text": [
391 | "Some weights of the model checkpoint at model/lvwerra/gpt2-imdb were not used when initializing GPT2LMHeadModel: ['v_head.summary.weight', 'v_head.summary.bias']\n",
392 | "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
393 | "- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
394 | "Some weights of the model checkpoint at model/lvwerra/gpt2-imdb were not used when initializing GPT2LMHeadModel: ['v_head.summary.weight', 'v_head.summary.bias']\n",
395 | "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
396 | "- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
397 | ]
398 | }
399 | ],
400 | "source": [
401 | "class ModelPPO(torch.nn.Module):\n",
402 | "    # Causal LM plus a scalar value head; both read the same final hidden state.\n",
403 | "    def __init__(self):\n",
404 | "        super().__init__()\n",
405 | "        from transformers import AutoModelForCausalLM\n",
406 | "\n",
407 | "        self.model_gen = AutoModelForCausalLM.from_pretrained(\n",
408 | "            'model/lvwerra/gpt2-imdb')\n",
409 | "\n",
410 | "        # Size the value head from the checkpoint config rather than a literal 768.\n",
411 | "        self.v_head = torch.nn.Sequential(\n",
412 | "            torch.nn.Dropout(0.1),\n",
413 | "            torch.nn.Linear(self.model_gen.config.hidden_size, 1))\n",
414 | "        self.to(device).train()\n",
415 | "\n",
416 | "    def forward(self, input_ids, attention_mask):\n",
417 | "        # Only the final hidden state is consumed, so the per-layer hidden\n",
418 | "        # states are not requested from the transformer.\n",
419 | "        last_hidden_state = self.model_gen.transformer(\n",
420 | "            input_ids=input_ids, attention_mask=attention_mask).last_hidden_state\n",
421 | "\n",
422 | "        logits = self.model_gen.lm_head(last_hidden_state)\n",
423 | "        value = self.v_head(last_hidden_state).squeeze(-1)\n",
424 | "\n",
425 | "        return logits, value\n",
426 | "\n",
427 | "\n",
428 | "model_ppo = ModelPPO()\n",
429 | "\n",
430 | "# Frozen reference policy. ModelPPO() leaves itself in train mode, so switch to eval():\n",
431 | "# with dropout active the reference log-probs used for the KL penalty would be random.\n",
432 | "model_ppo_ref = ModelPPO().eval().requires_grad_(False)"
433 | ]
434 | },
435 | {
436 | "cell_type": "code",
437 | "execution_count": 5,
438 | "id": "4b5f88e2",
439 | "metadata": {},
440 | "outputs": [
441 | {
442 | "data": {
443 | "text/plain": [
444 | "tensor([-0.5304, -0.1680, -0.8149, -0.8351, -1.1403, -0.5243, 1.0303, -0.1843,\n",
445 | " 0.0434, 0.9319, -0.2327, 0.5249, 0.8442, -0.8438, 0.8771])"
446 | ]
447 | },
448 | "execution_count": 5,
449 | "metadata": {},
450 | "output_type": "execute_result"
451 | }
452 | ],
453 | "source": [
454 | "def get_kl(a, b, method='kl'):\n",
455 | "    # Element-wise penalty between log-probs a (policy) and b (reference); `method` picks the estimator.\n",
456 | "    if method == 'kl':\n",
457 | "        return a - b\n",
458 | "\n",
459 | "    if method == 'abs':\n",
460 | "        return (a - b).abs()\n",
461 | "\n",
462 | "    if method == 'mse':\n",
463 | "        return (a - b).square() * 0.5\n",
464 | "\n",
465 | "    if method == 'full':\n",
466 | "        # kl_div(input, target) computes exp(target) * (target - input), so KL(a || b) needs b as input.\n",
467 | "        return torch.nn.functional.kl_div(b, a, log_target=True, reduction='none')\n",
468 | "\n",
469 | "    # Fail loudly instead of returning None for an unsupported method.\n",
470 | "    raise ValueError(f'unknown kl method: {method}')\n",
471 | "\n",
472 | "\n",
473 | "get_kl(torch.randn(15), torch.zeros(15))"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": 6,
479 | "id": "9d600b60",
480 | "metadata": {},
481 | "outputs": [
482 | {
483 | "name": "stderr",
484 | "output_type": "stream",
485 | "text": [
486 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n",
487 | " warnings.warn(\n"
488 | ]
489 | },
490 | {
491 | "data": {
492 | "text/plain": [
493 | "<__main__.PPOTrainer at 0x7f04cdfc8190>"
494 | ]
495 | },
496 | "execution_count": 6,
497 | "metadata": {},
498 | "output_type": "execute_result"
499 | }
500 | ],
501 | "source": [
502 | "from trl.core import clip_by_value, logprobs_from_logits, masked_mean, masked_whiten\n",
503 | "\n",
504 | "\n",
505 | "class PPOTrainer:\n",
506 | "    # Hand-written PPO update that trains the global model_ppo against model_ppo_ref.\n",
507 | "    def __init__(self):\n",
508 | "        self.optimizer = torch.optim.Adam(model_ppo.parameters(), lr=1e-5)\n",
509 | "\n",
510 | "    def step(self, question, answer, reward):\n",
511 | "        with torch.no_grad():\n",
512 | "            # Encode: join each question with its answer and pad into one batch.\n",
513 | "            token = [q.tolist() + a.tolist() for q, a in zip(question, answer)]\n",
514 | "            token = [{'input_ids': i} for i in token]\n",
515 | "            token = tokenizer.pad(token, return_tensors='pt').to(device)\n",
516 | "            input_ids = token.input_ids\n",
517 | "            attention_mask = token.attention_mask\n",
518 | "            del token\n",
519 | "\n",
520 | "            # From here on only the lengths of question and answer matter, not their contents.\n",
521 | "            lens_q = [len(i) for i in question]\n",
522 | "            lens_a = [len(i) for i in answer]\n",
523 | "            del question\n",
524 | "            del answer\n",
525 | "\n",
526 | "            # Log-prob of the answer given the question, and a value score for every action.\n",
527 | "            prob_log, value, mask = self.batched_forward_pass(\n",
528 | "                model_ppo, input_ids, attention_mask, lens_q, lens_a)\n",
529 | "\n",
530 | "            # Same log-probs under the reference model, needed for the KL term.\n",
531 | "            prob_log_ref, _, _ = self.batched_forward_pass(\n",
532 | "                model_ppo_ref, input_ids, attention_mask, lens_q, lens_a)\n",
533 | "\n",
534 | "            # KL term between the two sets of log-probs, with the reward folded in.\n",
535 | "            reward = self.compute_rewards(reward, prob_log, prob_log_ref, mask)\n",
536 | "\n",
537 | "            # delta and target, both used to compute the loss.\n",
538 | "            value, delta, target = self.compute_advantages(value, reward, mask)\n",
539 | "\n",
540 | "        # Run the model over each batch several times.\n",
541 | "        for _ in range(4):\n",
542 | "            # One sample at a time.\n",
543 | "            for i in range(len(input_ids)):\n",
544 | "                # Recompute log-probs and values with the current weights.\n",
545 | "                prob_log_new, value_new, _ = self.batched_forward_pass(\n",
546 | "                    model_ppo, input_ids[i].unsqueeze(0),\n",
547 | "                    attention_mask[i].unsqueeze(0), [lens_q[i]], [lens_a[i]])\n",
548 | "\n",
549 | "                # The ratio between new and old probabilities yields one loss;\n",
550 | "                # the gap between target and value yields a second loss.\n",
551 | "                loss = self.get_loss(prob_log[i].unsqueeze(0),\n",
552 | "                                     value[i].unsqueeze(0), prob_log_new,\n",
553 | "                                     value_new, mask[i].unsqueeze(0),\n",
554 | "                                     delta[i].unsqueeze(0),\n",
555 | "                                     target[i].unsqueeze(0))\n",
556 | "                # get_loss returns None to skip a sample; a loss equal to 0 is still applied.\n",
557 | "                if loss is None:\n",
558 | "                    continue\n",
559 | "\n",
560 | "                loss.backward()\n",
561 | "                #torch.nn.utils.clip_grad_norm_(model_ppo.parameters(), 1.0)\n",
562 | "                self.optimizer.step()\n",
563 | "                self.optimizer.zero_grad()\n",
564 | "\n",
565 | "    def batched_forward_pass(self, model, input_ids, attention_mask, lens_q,\n",
566 | "                             lens_a):\n",
567 | "        logits, value = model(input_ids=input_ids,\n",
568 | "                              attention_mask=attention_mask)\n",
569 | "\n",
570 | "        # Log-probability of each token that actually comes next.\n",
571 | "        prob_log = logprobs_from_logits(logits[:, :-1], input_ids[:, 1:])\n",
572 | "\n",
573 | "        # 1 where the position predicts an answer token and is not PAD.\n",
574 | "        mask = torch.zeros_like(attention_mask)\n",
575 | "        mask[:, :-1] = attention_mask[:, 1:]\n",
576 | "        for i in range(len(input_ids)):\n",
577 | "            start = lens_q[i] - 1\n",
578 | "            end = start + lens_a[i]\n",
579 | "            mask[i, :start] = 0\n",
580 | "            mask[i, end:] = 0\n",
581 | "\n",
582 | "        # The prediction made at the last position is meaningless, so drop it.\n",
583 | "        value = value[:, :-1]\n",
584 | "        mask = mask[:, :-1]\n",
585 | "\n",
586 | "        return prob_log, value, mask\n",
587 | "\n",
588 | "    def compute_rewards(self, reward, prob_log, prob_log_ref, mask):\n",
589 | "        reward_kl = []\n",
590 | "\n",
591 | "        for i in range(len(reward)):\n",
592 | "            # KL term between the two sets of log-probs, scaled by -0.2.\n",
593 | "            kl = get_kl(prob_log[i], prob_log_ref[i]) * -0.2\n",
594 | "\n",
595 | "            # Add the reward onto the KL term of the last answer token.\n",
596 | "            if (mask[i] == 0).all():\n",
597 | "                #print('all 0')\n",
598 | "                idx = 0\n",
599 | "            else:\n",
600 | "                idx = mask[i].nonzero()[-1].item()\n",
601 | "            kl[idx] += reward[i]\n",
602 | "\n",
603 | "            reward_kl.append(kl)\n",
604 | "\n",
605 | "        return torch.stack(reward_kl)\n",
606 | "\n",
607 | "    def compute_advantages(self, value, reward_kl, mask):\n",
608 | "        value = value * mask\n",
609 | "        reward_kl = reward_kl * mask\n",
610 | "\n",
611 | "        delta = []\n",
612 | "        lens = reward_kl.shape[1]\n",
613 | "\n",
614 | "        # Walk backwards through time.\n",
615 | "        for i in reversed(range(lens)):\n",
616 | "            # Value of the next step; at the final step there is none, so value_next is 0.\n",
617 | "            # The loop runs backwards, so its first iteration uses 0 and the rest read value.\n",
618 | "            value_next = 0\n",
619 | "            if i < lens - 1:\n",
620 | "                value_next = value[:, i + 1]\n",
621 | "\n",
622 | "            # Ideally value = gamma * next value + reward.\n",
623 | "            # The gap between the two sides is delta; gamma is 1 here, so it is omitted.\n",
624 | "            d = reward_kl[:, i] + value_next - value[:, i]\n",
625 | "\n",
626 | "            # Most recent delta, or 0 when there is none yet.\n",
627 | "            last_d = 0\n",
628 | "            if delta:\n",
629 | "                last_d = delta[-1]\n",
630 | "\n",
631 | "            # delta propagates backwards; 0.95 weights how strongly later steps are tied to earlier actions.\n",
632 | "            delta.append(d + 0.95 * last_d)\n",
633 | "\n",
634 | "        # Reverse back into chronological order.\n",
635 | "        delta = torch.stack(delta[::-1]).transpose(0, 1)\n",
636 | "\n",
637 | "        # target is the estimate of the ideal value.\n",
638 | "        target = delta + value\n",
639 | "        delta = masked_whiten(delta, mask)\n",
640 | "\n",
641 | "        return value, delta, target\n",
642 | "\n",
643 | "    def get_loss(self, prob_log, value, prob_log_new, value_new, mask, delta,\n",
644 | "                 target):\n",
645 | "\n",
646 | "        # With log-probs a quotient becomes a difference; exp() turns it back into the new/old ratio.\n",
647 | "        ratio = (prob_log_new - prob_log).exp()\n",
648 | "\n",
649 | "        # A very large ratio suggests training is oscillating, so skip this sample.\n",
650 | "        if masked_mean(ratio, mask).item() > 10:\n",
651 | "            #print('skip', masked_mean(ratio, mask).item())\n",
652 | "            return None\n",
653 | "\n",
654 | "        # Value loss first: a plain squared error against the target.\n",
655 | "        loss_vf1 = (value_new - target)**2\n",
656 | "        # Clipped variant, intended to limit bootstrapping.\n",
657 | "        loss_vf2 = clip_by_value(value_new, value - 0.2, value + 0.2)\n",
658 | "        loss_vf2 = (loss_vf2 - target)**2\n",
659 | "        # Keep the larger of the two losses, again to limit bootstrapping.\n",
660 | "        loss_vf = 0.5 * masked_mean(torch.max(loss_vf1, loss_vf2), mask)\n",
661 | "\n",
662 | "        # PPO policy loss.\n",
663 | "        loss_surr1 = -delta * ratio\n",
664 | "        # Clipped variant, intended to limit bootstrapping.\n",
665 | "        loss_surr2 = -delta * ratio.clamp(0.8, 1.2)\n",
666 | "        loss_surr = masked_mean(torch.max(loss_surr1, loss_surr2), mask)\n",
667 | "\n",
668 | "        return loss_surr + 0.1 * loss_vf\n",
669 | "\n",
670 | "\n",
671 | "trainer = PPOTrainer()\n",
672 | "\n",
673 | "trainer"
674 | ]
675 | },
676 | {
677 | "cell_type": "code",
678 | "execution_count": 7,
679 | "id": "805234d3",
680 | "metadata": {},
681 | "outputs": [
682 | {
683 | "name": "stderr",
684 | "output_type": "stream",
685 | "text": [
686 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
687 | ]
688 | }
689 | ],
690 | "source": [
691 | "from transformers import AutoModelForSequenceClassification\n",
692 | "\n",
693 | "# Classifier used by get_reward, together with its own tokenizer.\n",
694 | "tokenizer_cls = AutoTokenizer.from_pretrained('tokenizer/lvwerra/distilbert-imdb')\n",
695 | "model_cls = AutoModelForSequenceClassification.from_pretrained(\n",
696 | "    'model/lvwerra/distilbert-imdb')\n",
697 | "model_cls.to(device)\n",
698 | "# No optimizer is ever given these weights, so turn gradient tracking off.\n",
699 | "for i in model_cls.parameters():\n",
700 | "    i.requires_grad_(False)"
701 | ]
702 | },
703 | {
704 | "cell_type": "code",
705 | "execution_count": 8,
706 | "id": "c3e7a327",
707 | "metadata": {
708 | "scrolled": true
709 | },
710 | "outputs": [
711 | {
712 | "data": {
713 | "text/plain": [
714 | "(tensor([0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0,\n",
715 | " 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,\n",
716 | " 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1,\n",
717 | " 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0,\n",
718 | " 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0,\n",
719 | " 0, 1, 0, 0, 0, 0, 0, 0], device='cuda:0'),\n",
720 | " [tensor([ 15, 2215, 257, 3807, 588, 366], device='cuda:0'),\n",
721 | " tensor([ 15, 17439, 10884, 26066, 750, 523], device='cuda:0'),\n",
722 | " tensor([ 15, 1532, 345, 389, 2045, 329], device='cuda:0'),\n",
723 | " tensor([ 15, 40, 13770, 8359, 428, 2646], device='cuda:0'),\n",
724 | " tensor([ 16, 1722, 262, 13738, 8146, 17607], device='cuda:0'),\n",
725 | " tensor([ 15, 1212, 2646, 815, 423, 587], device='cuda:0'),\n",
726 | " tensor([ 16, 12298, 72, 47, 8404, 284], device='cuda:0'),\n",
727 | " tensor([ 16, 42322, 314, 561, 18595, 661], device='cuda:0'),\n",
728 | " tensor([ 15, 40, 373, 1936, 618, 262], device='cuda:0'),\n",
729 | " tensor([ 16, 3666, 3988, 6151, 428, 3807], device='cuda:0')])"
730 | ]
731 | },
732 | "execution_count": 8,
733 | "metadata": {},
734 | "output_type": "execute_result"
735 | }
736 | ],
737 | "source": [
738 | "def get_question():\n",
739 | "    # Sample a batch, then move the labels and every prompt onto `device` as LongTensors.\n",
740 | "    to_device = lambda ids: torch.LongTensor(ids).to(device)\n",
741 | "\n",
742 | "    raw_label, raw_question = get_batch_data()\n",
743 | "\n",
744 | "    return to_device(raw_label), [to_device(q) for q in raw_question]\n",
745 | "\n",
746 | "\n",
747 | "label, question = get_question()\n",
748 | "\n",
749 | "label, question[:10]"
750 | ]
751 | },
752 | {
753 | "cell_type": "code",
754 | "execution_count": 9,
755 | "id": "b77e7d96",
756 | "metadata": {
757 | "scrolled": true
758 | },
759 | "outputs": [
760 | {
761 | "data": {
762 | "text/plain": [
763 | "[tensor([ 464, 7193, 261, 1, 373, 287, 20550, 11, 612, 714,\n",
764 | " 307, 2187, 3923, 546, 28623, 262, 2933, 11, 290, 703,\n",
765 | " 339, 1392, 4978, 287, 262, 16479, 287, 262, 1903, 284,\n",
766 | " 3095, 11445], device='cuda:0'),\n",
767 | " tensor([ 7138, 13, 679, 318, 4084, 379, 257, 966, 287, 465,\n",
768 | " 1981, 3451, 810, 339, 338, 655, 407, 379, 477, 39072,\n",
769 | " 355, 281, 6802, 29847, 1671, 1220, 6927, 1671, 11037, 48393,\n",
770 | " 1279, 1671], device='cuda:0'),\n",
771 | " tensor([ 262, 1388, 3435, 351, 3499, 19063, 11, 13300, 4306, 1100,\n",
772 | " 1088, 329, 257, 39679, 3074, 475, 314, 4719, 340, 611,\n",
773 | " 345, 765, 284, 5899, 20775, 13, 50256], device='cuda:0'),\n",
774 | " tensor([ 13, 1081, 2582, 355, 17369, 465, 12702, 9546, 314, 373,\n",
775 | " 21638, 585, 0, 50256], device='cuda:0'),\n",
776 | " tensor([ 2613, 459, 0, 1318, 338, 530, 7650, 414, 7411, 477,\n",
777 | " 1115, 286, 262, 4476, 274, 4379, 262, 366, 32399, 1911,\n",
778 | " 5278, 306, 340, 338, 13699, 4379, 9935, 440, 6, 6187,\n",
779 | " 2304, 692], device='cuda:0'),\n",
780 | " tensor([ 1695, 319, 34245, 12490, 284, 2342, 691, 986, 37814, 314,\n",
781 | " 7342, 373, 327, 42, 8206, 11, 772, 6918, 314, 1392,\n",
782 | " 379, 1790, 4003, 13, 1675, 910, 326, 428, 2646, 468,\n",
783 | " 262, 2694], device='cuda:0'),\n",
784 | " tensor([ 466, 257, 3155, 286, 1243, 13, 770, 561, 407, 466,\n",
785 | " 284, 1388, 16687, 13, 383, 3807, 857, 407, 423, 262,\n",
786 | " 466, 540, 3404, 326, 7328, 884, 355, 383, 21469, 9501,\n",
787 | " 466, 13], device='cuda:0'),\n",
788 | " tensor([ 284, 27401, 1598, 286, 883, 15774, 2324, 290, 7124, 9073,\n",
789 | " 13, 632, 481, 1790, 1103, 6918, 11, 2529, 11970, 3404,\n",
790 | " 5645, 510, 852, 1716, 517, 4388, 706, 257, 1178, 1933,\n",
791 | " 11, 611], device='cuda:0'),\n",
792 | " tensor([ 1621, 2492, 470, 1598, 13, 1318, 338, 691, 1936, 393,\n",
793 | " 2237, 2431, 1364, 674, 10281, 6, 3160, 13, 632, 2952,\n",
794 | " 2627, 1598, 284, 502, 355, 880, 355, 584, 15618, 7912,\n",
795 | " 326, 351], device='cuda:0'),\n",
796 | " tensor([ 13, 632, 338, 7818, 266, 14, 262, 366, 11718, 500,\n",
797 | " 15461, 465, 22701, 1, 290, 366, 1169, 4293, 12, 20676,\n",
798 | " 6944, 11, 475, 281, 4896, 286, 640, 290, 2568, 326,\n",
799 | " 1854, 466], device='cuda:0')]"
800 | ]
801 | },
802 | "execution_count": 9,
803 | "metadata": {},
804 | "output_type": "execute_result"
805 | }
806 | ],
807 | "source": [
808 | "# Thin wrapper around model_ppo.model_gen.generate with the sampling settings fixed.\n",
809 | "def generate(input_ids):\n",
810 | "    sampling = dict(min_length=-1,\n",
811 | "                    top_k=0.0,\n",
812 | "                    top_p=1.0,\n",
813 | "                    do_sample=True,\n",
814 | "                    pad_token_id=tokenizer.pad_token_id,\n",
815 | "                    max_new_tokens=32,\n",
816 | "                    eos_token_id=tokenizer.eos_token_id)\n",
817 | "    return model_ppo.model_gen.generate(input_ids=input_ids, **sampling)\n",
818 | "\n",
819 | "\n",
820 | "def get_answer(question):\n",
821 | "    # Prompts that all share one length are generated as a single batch.\n",
822 | "    # Ragged prompts cannot be stacked, so they are generated one prompt at a\n",
823 | "    # time; an empty list also takes that path and yields an empty result.\n",
824 | "    if len({len(q) for q in question}) == 1:\n",
825 | "        answer = []\n",
826 | "        for seq in generate(torch.stack(question)):\n",
827 | "            # A batched row can run on past its own EOS; cut each row just\n",
828 | "            # after its first EOS token, or keep it whole if it has none.\n",
829 | "            ids = seq.tolist()\n",
830 | "            if tokenizer.eos_token_id in ids:\n",
831 | "                seq = seq[:ids.index(tokenizer.eos_token_id) + 1]\n",
832 | "            answer.append(seq.unsqueeze(0))\n",
833 | "    else:\n",
834 | "        answer = [generate(q.unsqueeze(0)) for q in question]\n",
835 | "\n",
836 | "    # Strip the prompt so only the newly generated tokens remain.\n",
837 | "    answer = [a[0, len(q):] for q, a in zip(question, answer)]\n",
838 | "\n",
839 | "    return answer\n",
840 | "\n",
841 | "\n",
842 | "answer = get_answer(question)\n",
843 | "\n",
844 | "answer[:10]"
845 | ]
846 | },
847 | {
848 | "cell_type": "code",
849 | "execution_count": 10,
850 | "id": "720b4c0f",
851 | "metadata": {
852 | "scrolled": true
853 | },
854 | "outputs": [
855 | {
856 | "data": {
857 | "text/plain": [
858 | "tensor([-0.1023, -2.0593, 0.1034, -2.1943, -0.2854, 0.3657, -2.0234, -0.6897,\n",
859 | " -1.0600, 1.9237, -1.1278, -1.9625, 1.9247, 1.0706, 1.9642, -2.7512,\n",
860 | " 0.1869, 0.4816, -1.4677, 1.4251, 1.8278, -2.4232, 0.3808, -2.2452,\n",
861 | " -1.4352, -1.3860, 1.2793, 2.5945, 0.8147, 1.0352, -0.8067, 1.5356,\n",
862 | " -0.4795, 0.6381, -1.9992, 1.6725, 1.0273, -1.3380, 2.6569, 0.8692,\n",
863 | " 0.3785, -0.6434, -0.1132, 0.7105, 0.7786, -1.3073, -0.9017, -1.7745,\n",
864 | " -1.1298, -0.0379, 0.9818, -1.9865, -1.9988, -2.0698, -2.3039, 0.5009,\n",
865 | " 0.7905, -1.5054, -0.3490, -1.1563, 0.6987, 1.6517, -1.2936, -1.1562,\n",
866 | " -0.6478, -1.3625, -0.8715, -1.3332, -1.3316, -2.4711, 0.1419, -0.8977,\n",
867 | " 0.6781, -2.5032, 2.1797, -2.2050, 0.3033, 1.0208, 1.5972, 1.0242,\n",
868 | " 0.0768, 1.0985, 2.1549, 1.7992, -2.0123, 2.7790, 0.7879, -0.4145,\n",
869 | " 1.1283, -2.4656, 0.3194, 2.4502, -1.3679, 1.6667, 2.4101, -0.8367,\n",
870 | " 0.9427, 2.0877, -2.4355, -0.8031, 0.8273, -1.3980, 1.5957, 1.0456,\n",
871 | " 1.6853, -0.6226, 1.0414, 0.5680, -1.2993, -0.8639, -1.3040, 0.6334,\n",
872 | " 2.5615, 0.0733, -2.6728, -1.4309, 0.7316, 1.3707, 1.3495, -1.8754,\n",
873 | " 1.5368, 2.0362, -1.4510, -2.0597, -2.3088, 2.4779, -0.3622, -1.8313],\n",
874 | " device='cuda:0')"
875 | ]
876 | },
877 | "execution_count": 10,
878 | "metadata": {},
879 | "output_type": "execute_result"
880 | }
881 | ],
882 | "source": [
883 | "def get_reward(question, answer, label):\n",
884 | "    # Rebuild each sample as text, leaving out the first (label) token of the prompt.\n",
885 | "    text = [tokenizer.decode(q.tolist()[1:] + a.tolist())\n",
886 | "            for q, a in zip(question, answer)]\n",
887 | "\n",
888 | "    # Re-encode the text with the classifier's own tokenizer.\n",
889 | "    batch = tokenizer_cls(text, padding=True, truncation=True, max_length=512,\n",
890 | "                          return_tensors='pt').to(device)\n",
891 | "\n",
892 | "    with torch.no_grad():\n",
893 | "        logits = model_cls(**batch).logits\n",
894 | "\n",
895 | "    # The reward is the classifier logit at the index given by each sample's label.\n",
896 | "    return logits.gather(1, label.reshape(-1, 1)).squeeze(1)\n",
897 | "\n",
898 | "\n",
899 | "reward = get_reward(question, answer, label)\n",
900 | "\n",
901 | "reward"
902 | ]
903 | },
904 | {
905 | "cell_type": "code",
906 | "execution_count": 11,
907 | "id": "a1c70e8e",
908 | "metadata": {
909 | "scrolled": false
910 | },
911 | "outputs": [
912 | {
913 | "name": "stderr",
914 | "output_type": "stream",
915 | "text": [
916 | "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
917 | ]
918 | },
919 | {
920 | "name": "stdout",
921 | "output_type": "stream",
922 | "text": [
923 | "0 0.04424596205353737\n",
924 | "0Rowan Atkinson's Mr -> . Incredible not starting off flat or unless you were just seriously down to sea out here I did feel sorry because there is an absolutely amazing heist film thats almost -1.7402509450912476\n"
925 | ]
926 | },
927 | {
928 | "name": "stderr",
929 | "output_type": "stream",
930 | "text": [
931 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
932 | ]
933 | },
934 | {
935 | "name": "stdout",
936 | "output_type": "stream",
937 | "text": [
938 | "10 0.1893816739320755\n",
939 | "1Not only is this a -> work of pure creativity. Its worth buying Justin Kuehne.<|endoftext|> 2.4963550567626953\n"
940 | ]
941 | },
942 | {
943 | "name": "stderr",
944 | "output_type": "stream",
945 | "text": [
946 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
947 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
948 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
949 | ]
950 | },
951 | {
952 | "name": "stdout",
953 | "output_type": "stream",
954 | "text": [
955 | "20 0.11868445575237274\n",
956 | "0I found this very touching -> .)<|endoftext|> -1.8980457782745361\n"
957 | ]
958 | },
959 | {
960 | "name": "stderr",
961 | "output_type": "stream",
962 | "text": [
963 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
964 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
965 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
966 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
967 | ]
968 | },
969 | {
970 | "name": "stdout",
971 | "output_type": "stream",
972 | "text": [
973 | "30 0.04239548742771149\n",
974 | "0Even die hard John Wayne -> was able to play that one scene apart by himself.\u0085The fight in the basement is a simple scene. That's the low- -0.8326716423034668\n"
975 | ]
976 | },
977 | {
978 | "name": "stderr",
979 | "output_type": "stream",
980 | "text": [
981 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
982 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
983 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
984 | ]
985 | },
986 | {
987 | "name": "stdout",
988 | "output_type": "stream",
989 | "text": [
990 | "40 0.09275282919406891\n",
991 | "0I couldn't wait to -> find it on the prairie.<|endoftext|> -0.02781713753938675\n"
992 | ]
993 | },
994 | {
995 | "name": "stderr",
996 | "output_type": "stream",
997 | "text": [
998 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
999 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1000 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1001 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1002 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1003 | ]
1004 | },
1005 | {
1006 | "name": "stdout",
1007 | "output_type": "stream",
1008 | "text": [
1009 | "50 0.38358354568481445\n",
1010 | "1The first episode set the -> structure perfectly. It made me hate it.<|endoftext|> -0.7066333889961243\n"
1011 | ]
1012 | },
1013 | {
1014 | "name": "stderr",
1015 | "output_type": "stream",
1016 | "text": [
1017 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1018 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1019 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1020 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1021 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1022 | ]
1023 | },
1024 | {
1025 | "name": "stdout",
1026 | "output_type": "stream",
1027 | "text": [
1028 | "60 1.2206898927688599\n",
1029 | "1May 1938. Hitler in -> this film is the best.<|endoftext|> 1.5236488580703735\n"
1030 | ]
1031 | },
1032 | {
1033 | "name": "stderr",
1034 | "output_type": "stream",
1035 | "text": [
1036 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1037 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1038 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1039 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1040 | ]
1041 | },
1042 | {
1043 | "name": "stdout",
1044 | "output_type": "stream",
1045 | "text": [
1046 | "70 1.730372428894043\n",
1047 | "1I disliked this film intensely -> I wanted to finish it so often I didn't know what to do with the characters we never liked, but this film was amazing!<|endoftext|> 1.5757133960723877\n"
1048 | ]
1049 | },
1050 | {
1051 | "name": "stderr",
1052 | "output_type": "stream",
1053 | "text": [
1054 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1055 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1056 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1057 | ]
1058 | },
1059 | {
1060 | "name": "stdout",
1061 | "output_type": "stream",
1062 | "text": [
1063 | "80 1.9782875776290894\n",
1064 | "0This film is a very -> bad framing.<|endoftext|> 2.4100003242492676\n"
1065 | ]
1066 | },
1067 | {
1068 | "name": "stderr",
1069 | "output_type": "stream",
1070 | "text": [
1071 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1072 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1073 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1074 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1075 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1076 | ]
1077 | },
1078 | {
1079 | "name": "stdout",
1080 | "output_type": "stream",
1081 | "text": [
1082 | "90 2.1086392402648926\n",
1083 | "1It is the early morning -> in DC the best.<|endoftext|> 2.3762409687042236\n"
1084 | ]
1085 | },
1086 | {
1087 | "name": "stderr",
1088 | "output_type": "stream",
1089 | "text": [
1090 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1091 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1092 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1093 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1094 | ]
1095 | },
1096 | {
1097 | "name": "stdout",
1098 | "output_type": "stream",
1099 | "text": [
1100 | "100 2.1353683471679688\n",
1101 | "1It hurt to watch this -> , this, but it really felt ripping it into a piece and it makes it very well done.<|endoftext|> 2.6977579593658447\n"
1102 | ]
1103 | },
1104 | {
1105 | "name": "stderr",
1106 | "output_type": "stream",
1107 | "text": [
1108 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1109 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1110 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1111 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1112 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1113 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1114 | ]
1115 | },
1116 | {
1117 | "name": "stdout",
1118 | "output_type": "stream",
1119 | "text": [
1120 | "110 2.251354455947876\n",
1121 | "1This was the first of -> such themes. It was great!<|endoftext|> 2.4420251846313477\n"
1122 | ]
1123 | },
1124 | {
1125 | "name": "stderr",
1126 | "output_type": "stream",
1127 | "text": [
1128 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1129 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1130 | ]
1131 | },
1132 | {
1133 | "name": "stdout",
1134 | "output_type": "stream",
1135 | "text": [
1136 | "120 2.341010093688965\n",
1137 | "1Everyone is surely familiar with -> the name and is highly appreciated.<|endoftext|> 2.4222381114959717\n"
1138 | ]
1139 | },
1140 | {
1141 | "name": "stderr",
1142 | "output_type": "stream",
1143 | "text": [
1144 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1145 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1146 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1147 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1148 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1149 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1150 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1151 | ]
1152 | },
1153 | {
1154 | "name": "stdout",
1155 | "output_type": "stream",
1156 | "text": [
1157 | "130 2.313877582550049\n",
1158 | "1Whoever saddled this piece -> is a great appreciated film.<|endoftext|> 2.3847947120666504\n"
1159 | ]
1160 | },
1161 | {
1162 | "name": "stderr",
1163 | "output_type": "stream",
1164 | "text": [
1165 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1166 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1167 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1168 | ]
1169 | },
1170 | {
1171 | "name": "stdout",
1172 | "output_type": "stream",
1173 | "text": [
1174 | "140 2.2891697883605957\n",
1175 | "1I can never fathom -> this film blue shadows. I think it is wonderful.<|endoftext|> 2.648195743560791\n"
1176 | ]
1177 | },
1178 | {
1179 | "name": "stderr",
1180 | "output_type": "stream",
1181 | "text": [
1182 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1183 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1184 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1185 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1186 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1187 | ]
1188 | },
1189 | {
1190 | "name": "stdout",
1191 | "output_type": "stream",
1192 | "text": [
1193 | "150 2.1408190727233887\n",
1194 | "0Making a film for under -> this crap!<|endoftext|> 2.4721293449401855\n"
1195 | ]
1196 | },
1197 | {
1198 | "name": "stderr",
1199 | "output_type": "stream",
1200 | "text": [
1201 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1202 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1203 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1204 | ]
1205 | },
1206 | {
1207 | "name": "stdout",
1208 | "output_type": "stream",
1209 | "text": [
1210 | "160 2.212196111679077\n",
1211 | "0Lets make a movie -> about this crap.<|endoftext|> 2.4810938835144043\n"
1212 | ]
1213 | },
1214 | {
1215 | "name": "stderr",
1216 | "output_type": "stream",
1217 | "text": [
1218 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1219 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1220 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1221 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1222 | ]
1223 | },
1224 | {
1225 | "name": "stdout",
1226 | "output_type": "stream",
1227 | "text": [
1228 | "170 2.2233328819274902\n",
1229 | "0Just saw ICE AGE -> for a total waste of time.<|endoftext|> 2.480029821395874\n"
1230 | ]
1231 | },
1232 | {
1233 | "name": "stderr",
1234 | "output_type": "stream",
1235 | "text": [
1236 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1237 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1238 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1239 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1240 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1241 | ]
1242 | },
1243 | {
1244 | "name": "stdout",
1245 | "output_type": "stream",
1246 | "text": [
1247 | "180 2.2824764251708984\n",
1248 | "0The apolitical musicians Eva -> is just awful.<|endoftext|> 2.610349655151367\n"
1249 | ]
1250 | },
1251 | {
1252 | "name": "stderr",
1253 | "output_type": "stream",
1254 | "text": [
1255 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1256 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1257 | ]
1258 | },
1259 | {
1260 | "name": "stdout",
1261 | "output_type": "stream",
1262 | "text": [
1263 | "190 2.3515257835388184\n",
1264 | "1Excellent movie, a realistic -> look of my life.<|endoftext|> 2.7795515060424805\n"
1265 | ]
1266 | },
1267 | {
1268 | "name": "stderr",
1269 | "output_type": "stream",
1270 | "text": [
1271 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1272 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1273 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
1274 | "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
1275 | ]
1276 | }
1277 | ],
1278 | "source": [
1279 | "for epoch in range(200):\n",
1280 | "    # One step: get a labelled question batch, get answers for it, score them, then call trainer.step.\n",
1281 | "    label, question = get_question()\n",
1282 | "    answer = get_answer(question)\n",
1283 | "    reward = get_reward(question, answer, label)\n",
1284 | "    trainer.step(question, answer, reward)\n",
1285 | "\n",
1286 | "    # Every 10th iteration, print the batch-mean reward and the first sample of the batch.\n",
1287 | "    if epoch % 10 == 0:\n",
1288 | "        print(epoch, reward.mean().item())\n",
1289 | "        question = tokenizer.decode(question[0].tolist())\n",
1290 | "        answer = tokenizer.decode(answer[0].tolist())\n",
1291 | "        reward = reward[0].item()\n",
1292 | "        # 0 = negative review, 1 = positive review\n",
1293 | "        print(question, '->', answer, reward)"
1294 | ]
1295 | }
1296 | ],
1297 | "metadata": {
1298 | "kernelspec": {
1299 | "display_name": "Python [conda env:pt2]",
1300 | "language": "python",
1301 | "name": "conda-env-pt2-py"
1302 | },
1303 | "language_info": {
1304 | "codemirror_mode": {
1305 | "name": "ipython",
1306 | "version": 3
1307 | },
1308 | "file_extension": ".py",
1309 | "mimetype": "text/x-python",
1310 | "name": "python",
1311 | "nbconvert_exporter": "python",
1312 | "pygments_lexer": "ipython3",
1313 | "version": "3.10.13"
1314 | }
1315 | },
1316 | "nbformat": 4,
1317 | "nbformat_minor": 5
1318 | }
1319 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 训练自然语言的LLM模型
2 | 包括基于 TRL 的训练和手动训练两种实现。
3 |
4 | 训练方法包括 DPO(Direct Preference Optimization)和 PPO(Proximal Policy Optimization)。
5 |
6 | 环境信息:
7 |
8 | python=3.10
9 |
10 | torch==2.1.0(cuda)
11 |
12 | transformers==4.34.0
13 |
14 | datasets==2.14.5
15 |
16 | trl==0.7.4
17 |
18 | 视频课程:https://www.bilibili.com/video/BV1bu4y1w7Dz
19 |
--------------------------------------------------------------------------------