├── 0.下载文件.ipynb ├── 1.dpo_trl训练.ipynb └── README.md /0.下载文件.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "65d45fa0", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from transformers import AutoTokenizer\n", 11 | "\n", 12 | "AutoTokenizer.from_pretrained('gpt2-large').save_pretrained(\n", 13 | " 'tokenizer/gpt2-large')" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "ba345732", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from datasets import load_dataset\n", 24 | "\n", 25 | "load_dataset('b-mc2/sql-create-context').save_to_disk(\n", 26 | " 'dataset/b-mc2/sql-create-context')" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "4341f356", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "from transformers import AutoModelForCausalLM\n", 37 | "\n", 38 | "AutoModelForCausalLM.from_pretrained('gpt2-large').save_pretrained(\n", 39 | " 'model/gpt2-large')" 40 | ] 41 | } 42 | ], 43 | "metadata": { 44 | "kernelspec": { 45 | "display_name": "Python [conda env:pt2]", 46 | "language": "python", 47 | "name": "conda-env-pt2-py" 48 | }, 49 | "language_info": { 50 | "codemirror_mode": { 51 | "name": "ipython", 52 | "version": 3 53 | }, 54 | "file_extension": ".py", 55 | "mimetype": "text/x-python", 56 | "name": "python", 57 | "nbconvert_exporter": "python", 58 | "pygments_lexer": "ipython3", 59 | "version": "3.10.13" 60 | } 61 | }, 62 | "nbformat": 4, 63 | "nbformat_minor": 5 64 | } 65 | -------------------------------------------------------------------------------- /1.dpo_trl训练.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "3c3d93a6", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 14 | " from .autonotebook import tqdm as notebook_tqdm\n", 15 | "Using sep_token, but it is not set yet.\n", 16 | "Using cls_token, but it is not set yet.\n", 17 | "Using mask_token, but it is not set yet.\n" 18 | ] 19 | }, 20 | { 21 | "data": { 22 | "text/plain": [ 23 | "GPT2TokenizerFast(name_or_path='tokenizer/gpt2-large', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n", 24 | "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", 25 | "}" 26 | ] 27 | }, 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "output_type": "execute_result" 31 | } 32 | ], 33 | "source": [ 34 | "import random\n", 35 | "import torch\n", 36 | "\n", 37 | "checkpoint = 110\n", 38 | "device = 'cuda'\n", 39 | "dtype = torch.float16\n", 40 | "only_test = True\n", 41 | "\n", 42 | "from transformers import AutoTokenizer\n", 43 | "\n", 44 | "tokenizer = AutoTokenizer.from_pretrained('tokenizer/gpt2-large')\n", 45 | "tokenizer.pad_token_id = 0\n", 46 | "\n", 47 | "tokenizer" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "id": "58af6f4a", 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/plain": [ 59 | "(DatasetDict({\n", 60 | " train: Dataset({\n", 61 | " features: ['prompt', 'chosen', 'rejected'],\n", 62 | " num_rows: 200\n", 63 | " })\n", 64 | " test: Dataset({\n", 65 | " features: ['prompt', 'chosen', 'rejected'],\n", 66 | " num_rows: 1500\n", 67 | " })\n", 68 | " }),\n", 69 | " {'prompt': 'context:CREATE TABLE table_26400075_2 (weeks_in_top_10 VARCHAR, artist VARCHAR) question:How many weeks in the top-10 did Beats International have? answer:',\n", 70 | " 'chosen': 'SELECT weeks_in_top_10 FROM table_26400075_2 WHERE artist = \"Beats International\"',\n", 71 | " 'rejected': ''})" 72 | ] 73 | }, 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "from datasets import load_from_disk\n", 81 | "\n", 82 | "dataset = load_from_disk('dataset/b-mc2/sql-create-context')['train']\n", 83 | "\n", 84 | "\n", 85 | "def f(data):\n", 86 | " question = 'context:%s question:%s answer:' % (data['context'],\n", 87 | " data['question'])\n", 88 | " answer = data['answer']\n", 89 | " return {'question': question, 'answer': answer}\n", 90 | "\n", 91 | "\n", 92 | "dataset = dataset.map(f, remove_columns=['context'])\n", 93 | "\n", 94 | "\n", 95 | "def f(data):\n", 96 | " question = len(tokenizer.encode(data['question']))\n", 97 | " answer = len(tokenizer.encode(data['answer']))\n", 98 | " return 25 <= question <= 65 and 10 <= answer <= 35\n", 99 | "\n", 100 | "\n", 101 | "dataset = dataset.filter(f)\n", 102 | "\n", 103 | "\n", 104 | "def f(data):\n", 105 | " return {\n", 106 | " 'prompt': data['question'],\n", 107 | " 'chosen': data['answer'],\n", 108 | " 'rejected': ''\n", 109 | " }\n", 110 | "\n", 111 | "\n", 112 | "dataset = dataset.map(f, remove_columns=['question', 'answer'])\n", 113 | "dataset = dataset.train_test_split(test_size=1500)\n", 114 | "\n", 115 | "if only_test:\n", 116 | " dataset['train'] = dataset['train'].select(range(200))\n", 117 | "\n", 118 | "dataset, dataset['train'][0]" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 3, 124 | "id": "12276e41", 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "'model/dpo_110.model'" 131 | ] 132 | }, 133 | "execution_count": 3, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | } 137 | ], 138 | "source": [ 139 | "from transformers import AutoModelForCausalLM\n", 140 | "\n", 141 | "path = 'model/gpt2-large'\n", 142 | "if checkpoint != -1:\n", 143 | " path = 'model/dpo_%d.model' % checkpoint\n", 144 | "\n", 145 | "model_dpo = AutoModelForCausalLM.from_pretrained(path).to(device)\n", 146 | "if not only_test:\n", 147 | " model_dpo_ref = AutoModelForCausalLM.from_pretrained(path).to(device)\n", 148 | "\n", 149 | "path" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 4, 155 | "id": "08c4ec71", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stderr", 160 | "output_type": "stream", 161 | "text": [ 162 | "Map: 100%|██████████| 200/200 [00:02<00:00, 68.13 examples/s]\n", 163 | "Map: 100%|██████████| 1500/1500 [00:17<00:00, 87.49 examples/s]\n" 164 | ] 165 | }, 166 | { 167 | "data": { 168 | "text/plain": [ 169 | "(DatasetDict({\n", 170 | " train: Dataset({\n", 171 | " features: ['prompt', 'chosen', 'rejected'],\n", 172 | " num_rows: 200\n", 173 | " })\n", 174 | " test: Dataset({\n", 175 | " features: ['prompt', 'chosen', 'rejected'],\n", 176 | " num_rows: 1500\n", 177 | " })\n", 178 | " }),\n", 179 | " {'prompt': 'context:CREATE TABLE table_26400075_2 (weeks_in_top_10 VARCHAR, artist VARCHAR) question:How many weeks in the top-10 did Beats International have? answer:',\n", 180 | " 'chosen': 'SELECT weeks_in_top_10 FROM table_26400075_2 WHERE artist = \"Beats International\"',\n", 181 | " 'rejected': 'SELECT weeks_in_top_10 FROM table_26400075_2 WHERE artist = \"Beatles International\"'})" 182 | ] 183 | }, 184 | "execution_count": 4, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "#重新生成数据集中的rejected字段\n", 191 | "def remake_dataset():\n", 192 | " global dataset\n", 193 | " tokenizer.padding_side = 'left'\n", 194 | " model_dpo.to(dtype)\n", 195 | "\n", 196 | " def f(data):\n", 197 | " token = tokenizer(data['prompt'],\n", 198 | " return_tensors='pt',\n", 199 | " padding=True,\n", 200 | " truncation=True).to(device)\n", 201 | "\n", 202 | " out = model_dpo.generate(**token,\n", 203 | " min_length=-1,\n", 204 | " top_k=1,\n", 205 | " top_p=1.0,\n", 206 | " do_sample=True,\n", 207 | " pad_token_id=tokenizer.pad_token_id,\n", 208 | " max_new_tokens=35,\n", 209 | " eos_token_id=tokenizer.eos_token_id)\n", 210 | "\n", 211 | " for i in range(len(out)):\n", 212 | " lens = len(token['input_ids'][i])\n", 213 | " rejected = out[i, lens:]\n", 214 | "\n", 215 | " if tokenizer.eos_token_id in rejected:\n", 216 | " lens = rejected.tolist().index(tokenizer.eos_token_id) + 1\n", 217 | " rejected = rejected[:lens]\n", 218 | "\n", 219 | " rejected = rejected[:35]\n", 220 | " rejected = tokenizer.decode(rejected, skip_special_tokens=True)\n", 221 | "\n", 222 | " if rejected == data['chosen'][i]:\n", 223 | " rejected = ''\n", 224 | "\n", 225 | " data['rejected'][i] = rejected\n", 226 | "\n", 227 | " return data\n", 228 | "\n", 229 | " dataset = dataset.map(f, batched=True, batch_size=128, num_proc=1)\n", 230 | "\n", 231 | " tokenizer.padding_side = 'right'\n", 232 | " model_dpo.to(torch.float32)\n", 233 | "\n", 234 | "\n", 235 | "if only_test or checkpoint != -1:\n", 236 | " remake_dataset()\n", 237 | "\n", 238 | "dataset, dataset['train'][0]" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 5, 244 | "id": "03f8f012", 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "#重载模型\n", 249 | "def reload_model(epoch):\n", 250 | " global model_dpo\n", 251 | " global model_dpo_ref\n", 252 | "\n", 253 | " path = 'model/dpo_%d.model' % epoch\n", 254 | " model_dpo.save_pretrained(path)\n", 255 | " model_dpo = AutoModelForCausalLM.from_pretrained(path).to(device)\n", 256 | " model_dpo_ref = AutoModelForCausalLM.from_pretrained(path).to(device)\n", 257 | "\n", 258 | "\n", 259 | "# reload_model(0)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 6, 265 | "id": "16be7af3", 266 | "metadata": { 267 | "scrolled": true 268 | }, 269 | "outputs": [ 270 | { 271 | "name": "stderr", 272 | "output_type": "stream", 273 | "text": [ 274 | "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n", 275 | " warnings.warn(\n" 276 | ] 277 | } 278 | ], 279 | "source": [ 280 | "from transformers import TrainingArguments, TrainerCallback\n", 281 | "from trl import DPOTrainer\n", 282 | "\n", 283 | "\n", 284 | "def retrain():\n", 285 | "\n", 286 | " class MyCallback(TrainerCallback):\n", 287 | "\n", 288 | " def on_step_end(self, args, state, control, **kwargs):\n", 289 | " if state.global_step % 250 == 0:\n", 290 | " print(state.global_step)\n", 291 | " return\n", 292 | "\n", 293 | " data = random.choice(dataset['test'])\n", 294 | " input_ids = tokenizer.encode(data['prompt'],\n", 295 | " return_tensors='pt').to(device)\n", 296 | "\n", 297 | " out = model_dpo.generate(input_ids,\n", 298 | " min_length=-1,\n", 299 | " top_k=1,\n", 300 | " top_p=1.0,\n", 301 | " do_sample=True,\n", 302 | " pad_token_id=tokenizer.pad_token_id,\n", 303 | " max_new_tokens=35,\n", 304 | " eos_token_id=tokenizer.eos_token_id)\n", 305 | "\n", 306 | " print(tokenizer.decode(out[0]))\n", 307 | " print('=================')\n", 308 | " print(data['chosen'])\n", 309 | " print(data['rejected'])\n", 310 | " print('=================')\n", 311 | "\n", 312 | " args = TrainingArguments(output_dir='output_dir',\n", 313 | " learning_rate=1e-5,\n", 314 | " per_device_train_batch_size=4,\n", 315 | " max_steps=5000,\n", 316 | " evaluation_strategy='no',\n", 317 | " report_to='none',\n", 318 | " save_strategy='no')\n", 319 | "\n", 320 | " trainer = DPOTrainer(model_dpo,\n", 321 | " model_dpo_ref,\n", 322 | " args=args,\n", 323 | " beta=0.1,\n", 324 | " train_dataset=dataset['train'],\n", 325 | " tokenizer=tokenizer,\n", 326 | " max_length=100,\n", 327 | " max_target_length=100,\n", 328 | " max_prompt_length=100,\n", 329 | " callbacks=[MyCallback()])\n", 330 | "\n", 331 | " trainer.train()\n", 332 | "\n", 333 | "\n", 334 | "# retrain()" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 7, 340 | "id": "e3a5bd33", 341 | "metadata": {}, 342 | "outputs": [ 343 | { 344 | "name": "stdout", 345 | "output_type": "stream", 346 | "text": [ 347 | "prompt -> context:CREATE TABLE table_name_11 (division INTEGER, reg_season VARCHAR) question:Who was the lowest division in the 7th season? answer:\n", 348 | "chosen -> SELECT MIN(division) FROM table_name_11 WHERE reg_season = \"7th\"\n", 349 | "rejected -> \n", 350 | "=========\n", 351 | "prompt -> context:CREATE TABLE cinema (LOCATION VARCHAR) question:Show each location and the number of cinemas there. answer:\n", 352 | "chosen -> SELECT LOCATION, COUNT(*) FROM cinema GROUP BY LOCATION\n", 353 | "rejected -> \n", 354 | "=========\n", 355 | "prompt -> context:CREATE TABLE Sections (section_name VARCHAR, section_description VARCHAR) question:What are the names and descriptions of all the sections? answer:\n", 356 | "chosen -> SELECT section_name, section_description FROM Sections\n", 357 | "rejected -> \n", 358 | "=========\n", 359 | "prompt -> context:CREATE TABLE table_name_41 (venue VARCHAR, date VARCHAR) question:What is the venue of the game that was played on 23 October 1966? answer:\n", 360 | "chosen -> SELECT venue FROM table_name_41 WHERE date = \"23 october 1966\"\n", 361 | "rejected -> \n", 362 | "=========\n", 363 | "prompt -> context:CREATE TABLE table_26077092_7 (pick__number INTEGER, player VARCHAR) question:What was the pick number for Andrew Quarless? answer:\n", 364 | "chosen -> SELECT MAX(pick__number) FROM table_26077092_7 WHERE player = \"Andrew Quarless\"\n", 365 | "rejected -> \n", 366 | "=========\n", 367 | "prompt -> context:CREATE TABLE table_name_27 (film VARCHAR, year VARCHAR) question:What film was made in 1999? answer:\n", 368 | "chosen -> SELECT film FROM table_name_27 WHERE year = \"1999\"\n", 369 | "rejected -> \n", 370 | "=========\n", 371 | "prompt -> context:CREATE TABLE table_name_82 (origin VARCHAR, owner VARCHAR) question:What is the origin for the item with an owner of Hunan Broadcasting System (HBS)? answer:\n", 372 | "chosen -> SELECT origin FROM table_name_82 WHERE owner = \"hunan broadcasting system (hbs)\"\n", 373 | "rejected -> \n", 374 | "=========\n", 375 | "prompt -> context:CREATE TABLE table_name_7 (region VARCHAR, date VARCHAR) question:From which region is the album with release date of 19 June 2007? answer:\n", 376 | "chosen -> SELECT region FROM table_name_7 WHERE date = \"19 june 2007\"\n", 377 | "rejected -> \n", 378 | "=========\n" 379 | ] 380 | }, 381 | { 382 | "data": { 383 | "text/plain": [ 384 | "(0.848, 0.8573333333333333)" 385 | ] 386 | }, 387 | "execution_count": 7, 388 | "metadata": {}, 389 | "output_type": "execute_result" 390 | } 391 | ], 392 | "source": [ 393 | "def test():\n", 394 | " sample = random.choices(dataset['test'], k=8)\n", 395 | " for i in sample:\n", 396 | " for k, v in i.items():\n", 397 | " print(k, '->', v)\n", 398 | " print('=========')\n", 399 | "\n", 400 | " def correct(data, lower):\n", 401 | " rejected = data['rejected']\n", 402 | " chosen = data['chosen']\n", 403 | "\n", 404 | " if rejected == '':\n", 405 | " return True\n", 406 | "\n", 407 | " if lower:\n", 408 | " chosen = chosen.lower().replace('\"', '\\'')\n", 409 | " rejected = rejected.lower().replace('\"', '\\'')\n", 410 | "\n", 411 | " return chosen == rejected\n", 412 | "\n", 413 | " def accuracy(lower):\n", 414 | " return sum([correct(i, lower)\n", 415 | " for i in dataset['test']]) / len(dataset['test'])\n", 416 | "\n", 417 | " return accuracy(False), accuracy(True)\n", 418 | "\n", 419 | "\n", 420 | "test()" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 8, 426 | "id": "6dec7111", 427 | "metadata": { 428 | "scrolled": true 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "if not only_test:\n", 433 | " for epoch in range(checkpoint + 1, 100):\n", 434 | " retrain()\n", 435 | " reload_model(epoch)\n", 436 | " remake_dataset()\n", 437 | "\n", 438 | " print('epoch', epoch, 'test:', test())" 439 | ] 440 | } 441 | ], 442 | "metadata": { 443 | "kernelspec": { 444 | "display_name": "Python [conda env:pt2]", 445 | "language": "python", 446 | "name": "conda-env-pt2-py" 447 | }, 448 | "language_info": { 449 | "codemirror_mode": { 450 | "name": "ipython", 451 | "version": 3 452 | }, 453 | "file_extension": ".py", 454 | "mimetype": "text/x-python", 455 | "name": "python", 456 | "nbconvert_exporter": "python", 457 | "pygments_lexer": "ipython3", 458 | "version": "3.10.13" 459 | } 460 | }, 461 | "nbformat": 4, 462 | "nbformat_minor": 5 463 | } 464 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 训练LLM写SQL 2 | 环境信息: 3 | 4 | python==3.10 5 | 6 | pytorch==2.1.0 7 | 8 | transformers==4.34.0 9 | 10 | datasets==2.14.5 11 | 12 | trl==0.7.4 13 | 14 | 视频课程:https://www.bilibili.com/video/BV1hT4y187Uu 15 | --------------------------------------------------------------------------------