├── metric
│   ├── __pycache__
│   │   └── evaluation.cpython-37.pyc
│   └── evaluation.py
├── main.py
├── data2json.py
├── run_train.sh
├── run_test.sh
├── test_gpt.sh
├── requirements.txt
├── run_gpt.sh
├── cal_rouge.py
├── data
│   └── fine-tune
│       ├── test.json
│       ├── dev.json
│       └── train.json
├── README.md
├── run_summarization.py
├── model
│   ├── modeling_gpt_prompt.py
│   └── modeling_t5.py
└── run_gpt_prompt.py

/metric/__pycache__/evaluation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeng-WH/Prompt-Tuning/HEAD/metric/__pycache__/evaluation.cpython-37.pyc
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # This is a sample Python script.
2 | 
3 | # Press Shift+F10 to execute it or replace it with your code.
4 | # Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
5 | 
6 | 
7 | def print_hi(name):
8 |     # Use a breakpoint in the code line below to debug your script.
9 |     print(f'Hi, {name}')  # Press Ctrl+F8 to toggle the breakpoint.
10 | 
11 | 
12 | # Press the green button in the gutter to run the script.
13 | if __name__ == '__main__':
14 |     print_hi('PyCharm')
15 | 
16 | # See PyCharm help at https://www.jetbrains.com/help/pycharm/
17 | 
--------------------------------------------------------------------------------
/data2json.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | with open('/home/ypd-19-2/prefix-tuning/ypd_prefix/multiwoz_transfer_dataset/train_as_test/test.source', 'r') as r:
3 |     train_source = r.readlines()
4 | with open('/home/ypd-19-2/prefix-tuning/ypd_prefix/multiwoz_transfer_dataset/train_as_test/test.target', 'r') as r:
5 |     train_target = r.readlines()
6 | w = jsonlines.open('/home/ypd-19-2/docomo/data_source_target/train/test.json', 'w')
7 | for s, t in zip(train_source, train_target):
8 |     line_dict = {}
9 |     line_dict["text"] = s
10 |     line_dict["summary"] = t
11 |     w.write(line_dict)
12 | w.close()
--------------------------------------------------------------------------------
/run_train.sh:
--------------------------------------------------------------------------------
1 | python run_summarization.py \
2 | --model_name_or_path facebook/bart-large \
3 | --do_train \
4 | --do_eval \
5 | --train_file /home/ypd-19-2/docomo/data_source_target/train/train.json \
6 | --validation_file /home/ypd-19-2/docomo/data_source_target/train/val.json \
7 | --source_prefix "summarize: " \
8 | --output_dir /home/ypd-19-2/docomo/checkpoints_decode_eval_1.0 \
9 | --overwrite_output_dir \
10 | --gradient_accumulation_steps=4 \
11 | --per_device_train_batch_size=1 \
12 | --per_device_eval_batch_size=1 \
13 | --save_total_limit=1 \
14 | --eval_steps=50 \
15 | --logging_steps=50 \
16 | --save_steps=15000 \
17 | --num_train_epochs=4.0 \
18 | --learning_rate=1e-3 \
19 | --predict_with_generate \
20 | --pre_seq_len=200
--------------------------------------------------------------------------------
/run_test.sh:
--------------------------------------------------------------------------------
1 | python run_summarization.py \
2 | --model_name_or_path /home/ypd-19-2/docomo/checkpoints_decode_eval_1.0 \
3 | --do_predict \
4 | --train_file /home/ypd-19-2/docomo/jsonl/train.json \
5 | --validation_file /home/ypd-19-2/docomo/jsonl/val.json \
6 | --test_file /home/ypd-19-2/docomo/data_source_target/train/test.json \
7 | --source_prefix "summarize: " \
8 | --output_dir /home/ypd-19-2/docomo/checkpoints_decode_eval_1.0 \
9 | --overwrite_output_dir \
10 | --gradient_accumulation_steps=4 \
11 | --per_device_train_batch_size=16 \
12 | --per_device_eval_batch_size=8 \
13 | --save_total_limit=1 \
14 | --eval_steps=50 \
15 | --save_steps=15000 \
16 | --logging_steps=50 \
17 | --num_train_epochs=4.0 \
18 | --learning_rate=1e-3 \
19 | --predict_with_generate \
20 | --pre_seq_len=200
--------------------------------------------------------------------------------
/test_gpt.sh:
--------------------------------------------------------------------------------
1 | #
2 | # Evaluate the prompt-tuned GPT model on the test set.
3 | #
4 | CUDA_VISIBLE_DEVICES=1 python run_gpt_prompt.py \
5 | --model_name_or_path microsoft/DialoGPT-medium \
6 | --model_name gpt-small \
7 | --do_eval \
8 | --validation_file data/fine-tune/test.json \
9 | --source_prefix "dialogue: " \
10 | --output_dir /output_dir \
11 | --overwrite_output_dir \
12 | --per_device_train_batch_size=1 \
13 | --per_device_eval_batch_size=1 \
14 | --predict_with_generate \
15 | --eval_steps=50 \
16 | --logging_steps=50 \
17 | --num_train_epochs=10.0 \
18 | --learning_rate=2e-3 \
19 | --max_source_length=512 \
20 | --generation_max_length 682 \
21 | --text_column dialogue \
22 | --summary_column response \
23 | --evaluation_strategy epoch \
24 | --save_strategy epoch \
25 | --load_best_model_at_end True \
26 | --pre_seq_len 50 \
27 | --prefix_drop 0.1 \
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | aiohttp==3.8.1
3 | aiosignal==1.2.0
4 | async-timeout==4.0.2
5 | asynctest==0.13.0
6 | attrs==21.2.0
7 | certifi==2021.10.8
8 | charset-normalizer==2.0.9
9 | click==8.0.3
10 | datasets==1.17.0
11 | dill==0.3.4
12 | filelock==3.4.0
13 | frozenlist==1.2.0
14 | fsspec==2021.11.1
15 | huggingface-hub==0.2.1
16 | idna==3.3
17 | importlib-metadata==4.10.0
18 | joblib==1.1.0
19 | multidict==5.2.0
20 | multiprocess==0.70.12.2
21 | nltk==3.6.6
22 | numpy==1.21.5
23 | packaging==21.3
24 | pandas==1.1.5
25 | pyarrow==6.0.1
26 | pyparsing==3.0.6
27 | python-dateutil==2.8.2
28 | pytz==2021.3
29 | PyYAML==6.0
30 | regex==2021.11.10
31 | requests==2.26.0
32 | rouge==1.0.1
33 | rouge-score==0.0.4
34 | sacremoses==0.0.46
35 | six==1.16.0
36 | tokenizers==0.10.3
37 | torch==1.7.1
38 | tqdm==4.62.3
39 | transformers==4.12.5
40 | typing_extensions==4.0.1
41 | urllib3==1.26.7
42 | xxhash==2.0.2
43 | yarl==1.7.2
44 | zipp==3.6.0
45 | 
--------------------------------------------------------------------------------
/run_gpt.sh:
--------------------------------------------------------------------------------
1 | #
2 | # Train the GPT model with prompt tuning.
3 | #
4 | CUDA_VISIBLE_DEVICES=1 python run_gpt_prompt.py \
5 | --model_name_or_path microsoft/DialoGPT-medium \
6 | --model_name gpt-small \
7 | --do_train \
8 | --do_eval \
9 | --train_file data/fine-tune/train.json \
10 | --validation_file data/fine-tune/dev.json \
11 | --source_prefix "dialogue: " \
12 | --output_dir /output_dir \
13 | --overwrite_output_dir \
14 | --per_device_train_batch_size=1 \
15 | --per_device_eval_batch_size=1 \
16 | --predict_with_generate \
17 | --eval_steps=50 \
18 | --logging_steps=50 \
19 | --num_train_epochs=10.0 \
20 | --learning_rate=2e-3 \
21 | --max_source_length=512 \
22 | --generation_max_length 682 \
23 | --text_column dialogue \
24 | --summary_column response \
25 | --evaluation_strategy epoch \
26 | --save_strategy epoch \
27 | --load_best_model_at_end True \
28 | --pre_seq_len 50 \
29 | --prefix_drop 0.1 \
--------------------------------------------------------------------------------
/cal_rouge.py:
--------------------------------------------------------------------------------
1 | # Compute ROUGE scores for the generated predictions (personal utility script for prefix/prompt tuning).
2 | 
3 | from rouge import Rouge
4 | f = open("/home/ypd-19-2/docomo/checkpoints_decode_eval_1.0/generated_predictions.txt", "r", encoding="utf-8")
5 | f1 = open("/home/ypd-19-2/prefix-tuning/ypd_prefix/multiwoz_transfer_dataset/train_as_test/test.target", "r", encoding="utf-8")
6 | cands = []
7 | golds = []
8 | for line in f:
9 |     cand = line.strip().replace("", " ")
10 |     cands.append(cand)
11 | for line in f1:
12 |     gold = line.strip().replace("", " ")
13 |     golds.append(gold)
14 | 
15 | rouge = Rouge()
16 | total_R1_P = 0
17 | total_R1_R = 0
18 | total_R1_F = 0
19 | total_R2_P = 0
20 | total_R2_R = 0
21 | total_R2_F = 0
22 | total_RL_P = 0
23 | total_RL_R = 0
24 | total_RL_F = 0
25 | for i in range(len(cands)):
26 |     rouge_score = rouge.get_scores(cands[i], golds[i])
27 |     R_1 = rouge_score[0]["rouge-1"]
28 |     R_2 = rouge_score[0]["rouge-2"]
29 |     R_L = rouge_score[0]["rouge-l"]
30 |     P_R_1 = R_1['p']
31 |     R_R_1 = R_1['r']
32 |     F_R_1 = R_1['f']
33 |     P_R_2 = R_2['p']
34 |     R_R_2 = R_2['r']
35 |     F_R_2 = R_2['f']
36 |     P_R_L = R_L['p']
37 |     R_R_L = R_L['r']
38 |     F_R_L = R_L['f']
39 |     total_R1_P += P_R_1
40 |     total_R1_R += R_R_1
41 |     total_R1_F += F_R_1
42 |     total_R2_P += P_R_2
43 |     total_R2_R += R_R_2
44 |     total_R2_F += F_R_2
45 |     total_RL_P += P_R_L
46 |     total_RL_R += R_R_L
47 |     total_RL_F += F_R_L
48 | print("R1 Score:")
49 | print(total_R1_P/len(cands))
50 | print(total_R1_R/len(cands))
51 | print(total_R1_F/len(cands))
52 | print("R2 Score:")
53 | print(total_R2_P/len(cands))
54 | print(total_R2_R/len(cands))
55 | print(total_R2_F/len(cands))
56 | print("RL Score:")
57 | print(total_RL_P/len(cands))
58 | print(total_RL_R/len(cands))
59 | print(total_RL_F/len(cands))
--------------------------------------------------------------------------------
/data/fine-tune/test.json:
--------------------------------------------------------------------------------
1 | {"dialogue": "Good morning . I understand that you ' Ve got a problem with your washing machine . I ' m from the repair company .<|endoftext|>Excellent . Come in please . The washing machine is in the bathroom upstairs . It keeps breaking down .<|endoftext|>When did it first break down ?<|endoftext|>About ten days ago . I ' Ve tried to use it since then . Sometimes it works and sometimes it doesn ' t . it ' s very frustrating .<|endoftext|>Is it still under warranty . If it is and I can ' t fix it , it would be quicker and easier to exchange it for a new one .<|endoftext|>Yes , it ' s still under warranty . Over the last few weeks , it ' s also been making a high-pitch noise when it ' s in use .<|endoftext|>Ok . I ' ll start by looking at the motor . I ' ll just unplug it and take a look inside the machine ... oh , yes . There ' s the problem . It ' s quite simple . I ' ll sort it out in a few minutes .<|endoftext|>What ' s wrong with it ?<|endoftext|>Part of the motor is loose . I can put it back in place quite easily .", "response": "That ' s great . Thanks very much . Would you like a cup of tea or coffee ?"}
2 | {"dialogue": "What can I do for you ?<|endoftext|>I've got a suit , a woolen sweater and a white shirt to wash .<|endoftext|>OK , let me see . 
This white shirt can be washed in water with hands , but this suit and the woolen sweater should be dry-cleaned .<|endoftext|>That's OK . But that must be costly .<|endoftext|>Yes , the cost for dry-cleaning is three times of that for ordinary laundering . But for suits and sweaters , you can only take them to the dry cleaners ' .", "response": "Oh , my wife just threw them into the washer ."} 3 | {"dialogue": "Shall we discuss the packing ? You know , a well-designed package helps sell the goods , so the products must not only be superior in quality , but also attractive in appearance . I'd like to see the sample of packing .<|endoftext|>We have made a lot of improvement in packing . You are welcome to see the sample in the showroom . I think you will find the new packing beautiful and quite well-done .", "response": "Quite good . The beautiful design and bright color are just the European taste . How are you gonna pack these blouses ?"} 4 | {"dialogue": "Would you like a cup of coffee ?<|endoftext|>Yes . That would be good .", "response": "Oh , no ."} -------------------------------------------------------------------------------- /metric/evaluation.py: -------------------------------------------------------------------------------- 1 | from nltk.util import ngrams 2 | from nltk import word_tokenize 3 | from nltk.translate.bleu_score import corpus_bleu 4 | from nltk.translate.meteor_score import meteor_score 5 | import numpy as np 6 | 7 | 8 | def compute_bleu(references, candidates): 9 | ref_list, dec_list = [], [] 10 | for i in range(len(candidates)): 11 | dec_list.append(word_tokenize(candidates[i])) 12 | if type(references[i]) is list: 13 | tmp = [] 14 | for ref in references[i]: 15 | tmp.append(word_tokenize(ref)) 16 | ref_list.append(tmp) 17 | else: 18 | ref_list.append([word_tokenize(references[i])]) 19 | bleu1 = corpus_bleu(ref_list, dec_list, 20 | weights=(1, 0, 0, 0)) 21 | bleu2 = corpus_bleu(ref_list, dec_list, 22 | weights=(0, 1, 0, 0)) 23 | bleu3 = corpus_bleu(ref_list, dec_list, 24 | weights=(0, 0, 1, 0)) 25 | bleu4 = corpus_bleu(ref_list, dec_list, 26 | weights=(0, 0, 0, 1)) 27 | return { 28 | "bleu-1": bleu1, 29 | "bleu-2": bleu2, 30 | "bleu-3": bleu3, 31 | "bleu-4": bleu4, # main result 32 | } 33 | 34 | def compute_meteor(references, candidates): 35 | score_list = [] 36 | ref_list, dec_list = [], [] 37 | for i in range(len(candidates)): 38 | dec_list.append(word_tokenize(candidates[i])) 39 | if type(references[i]) is list: 40 | tmp =[] 41 | for ref in references[i]: 42 | tmp.append(word_tokenize(ref)) 43 | ref_list.append(tmp) 44 | #ref_list = references[i] 45 | else: 46 | #ref_list = [references[i]] 47 | ref_list.append([word_tokenize(references[i])]) 48 | score = meteor_score(ref_list[i], dec_list[i]) 49 | score_list.append(score) 50 | 51 | return { 52 | "METEOR: ": np.mean(score_list), 53 | } 54 | 55 | 56 | def distinct_ngram(candidates, n=2): 57 | """Return basic ngram statistics, as well as a dict of all ngrams and their freqsuencies.""" 58 | ngram_freqs = {} # ngrams with frequencies 59 | ngram_len = 0 # total number of ngrams 60 | for candidate in candidates: 61 | for ngram in ngrams(word_tokenize(candidate), n): 62 | ngram_freqs[ngram] = ngram_freqs.get(ngram, 0) + 1 63 | ngram_len += 1 64 | # number of unique ngrams 65 | uniq_ngrams = len([val for val in ngram_freqs.values() if val == 1]) 66 | distinct_ngram = len(ngram_freqs) / ngram_len if ngram_len > 0 else 0 67 | print(f'Distinct {n}-grams:', round(distinct_ngram,4)) 68 | return ngram_freqs, uniq_ngrams, 
ngram_len
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prompt-Tuning
2 | Implementation of "The Power of Scale for Parameter-Efficient Prompt Tuning"
3 | Currently, we support the following huggingface models:
4 | 
5 | - `BartForConditionalGeneration`
6 | - `T5ForConditionalGeneration`
7 | 
8 | Since June 2022, we also support GPT2:
9 | 
10 | - `GPT2LMHeadModel`
11 | 
12 | ## Setup
13 | 
14 | ```
15 | conda create -n prompt-tuning python==3.7.0
16 | pip install -r requirements.txt
17 | ```
18 | 
19 | ## Data
20 | 
21 | The task of summarization supports custom CSV and JSONLINES formats.
22 | 
23 | You can use `data2json.py` to convert data to the JSONLINES format.
24 | 
25 | #### Custom CSV Files
26 | 
27 | If it's a CSV file, the training and validation files should have a column for the input texts and a column for the summaries.
28 | 
29 | If the CSV file has just two columns, as in the following example:
30 | 
31 | ```
32 | text,summary
33 | "I'm sitting here in a boring room. It's just another rainy Sunday afternoon. I'm wasting my time I got nothing to do. I'm hanging around I'm waiting for you. But nothing ever happens. And I wonder","I'm sitting in a room where I'm waiting for something to happen"
34 | "I see trees so green, red roses too. I see them bloom for me and you. And I think to myself what a wonderful world. I see skies so blue and clouds so white. The bright blessed day, the dark sacred night. And I think to myself what a wonderful world.","I'm a gardener and I'm a big fan of flowers."
35 | "Christmas time is here. Happiness and cheer. Fun for all that children call. Their favorite time of the year. Snowflakes in the air. Carols everywhere. Olden times and ancient rhymes. Of love and dreams to share","It's that time of year again."
36 | ```
37 | 
38 | The first column is assumed to be for `text` and the second for the summary.
39 | 
40 | If the CSV file has multiple columns, you can then specify the names of the columns to use:
41 | 
42 | ```
43 | --text_column text_column_name \
44 | --summary_column summary_column_name \
45 | ```
46 | 
47 | For example, if the columns were:
48 | 
49 | ```
50 | id,date,text,summary
51 | ```
52 | 
53 | and you wanted to select only `text` and `summary`, then you'd pass these additional arguments:
54 | 
55 | ```
56 | --text_column text \
57 | --summary_column summary \
58 | ```
59 | 
60 | #### Custom JSONLINES Files
61 | 
62 | The second supported format is jsonlines. Here is an example of a jsonlines custom data file.
63 | 
64 | ```
65 | {"text": "I'm sitting here in a boring room. It's just another rainy Sunday afternoon. I'm wasting my time I got nothing to do. I'm hanging around I'm waiting for you. But nothing ever happens. And I wonder", "summary": "I'm sitting in a room where I'm waiting for something to happen"}
66 | {"text": "I see trees so green, red roses too. I see them bloom for me and you. And I think to myself what a wonderful world. I see skies so blue and clouds so white. The bright blessed day, the dark sacred night. And I think to myself what a wonderful world.", "summary": "I'm a gardener and I'm a big fan of flowers."}
67 | {"text": "Christmas time is here. Happiness and cheer. Fun for all that children call. Their favorite time of the year. Snowflakes in the air. Carols everywhere. Olden times and ancient rhymes. Of love and dreams to share", "summary": "It's that time of year again."}
68 | ```
69 | 
70 | As with the CSV files, by default the first value is used as the text record and the second as the summary record. You can therefore use any key names for the entries; in this example, `text` and `summary` were used.
71 | 
72 | And as with the CSV files, you can specify which values to select from the file by explicitly passing the corresponding key names. In our example this again would be:
73 | 
74 | ```
75 | --text_column text \
76 | --summary_column summary \
77 | ```
78 | 
79 | ## Train
80 | 
81 | ```
82 | bash run_train.sh
83 | ```
84 | 
85 | You can adjust the values of the `--train_file` and `--validation_file` arguments in `run_train.sh`.
86 | 
87 | To control the prompt length, adjust the `--pre_seq_len` argument in `run_train.sh`.
88 | 
89 | Other settings, such as the learning rate and batch size, can also be adjusted in `run_train.sh`.
90 | 
91 | ## Test
92 | 
93 | ```
94 | bash run_test.sh
95 | ```
96 | 
97 | You can adjust the value of the `--test_file` argument in `run_test.sh`.
98 | 
99 | Other settings can also be adjusted in `run_test.sh`. The generated summaries are written to `output_dir/generated_predictions.txt`.
100 | 
101 | ## GPT2
102 | 
103 | To run prompt tuning with GPT2, you can use:
104 | 
105 | ```
106 | bash run_gpt.sh
107 | ```
108 | 
109 | You can adjust the prompt length with `--pre_seq_len`, the model with `--model_name_or_path`, and the prompt dropout rate with `--prefix_drop`.
110 | 
111 | To test prompt tuning with GPT2, you can use:
112 | 
113 | ```
114 | bash test_gpt.sh
115 | ```
116 | 
117 | ## Citation
118 | 
119 | ```
120 | @misc{lester2021power,
121 |       title={The Power of Scale for Parameter-Efficient Prompt Tuning},
122 |       author={Brian Lester and Rami Al-Rfou and Noah Constant},
123 |       year={2021},
124 |       eprint={2104.08691},
125 |       archivePrefix={arXiv},
126 |       primaryClass={cs.CL}
127 | }
128 | ```
129 | 
--------------------------------------------------------------------------------
/data/fine-tune/dev.json:
--------------------------------------------------------------------------------
1 | {"dialogue": "Would you like to come to our party ?<|endoftext|>I'd like to . But I'm not myself today .<|endoftext|>What happened ? You look so pale .<|endoftext|>It's the end of the world for me .<|endoftext|>What makes you say that ?", "response": "Today , I'm told I failed three courses ."}
2 | {"dialogue": "Where are you going for lunch ?<|endoftext|>I have no idea . I have no appetite .<|endoftext|>I know of a restaurant . I'm sure you'll like it .<|endoftext|>Let's go there then . Is it a Chinese restaurant ?", "response": "Yes . it serves excellent Sichuan food . Spicy food may improve your appetite ."}
3 | {"dialogue": "how are you doing ?<|endoftext|>I'm ok . I wish I could say the same for my friend .<|endoftext|>what happened to him ?<|endoftext|>he was arrested by the police for drinking and driving .<|endoftext|>was it his first offence ?<|endoftext|>unfortunately not . He was charged with a DUI when he was in university .<|endoftext|>what happened to him then ?<|endoftext|>not much ; it was a minor offence back then . He got away with a fine of $ 500 .<|endoftext|>did they take his license away ?<|endoftext|>no , they were really easy on him . The problem is that that was a long time ago . 
They are much tougher on crime now .<|endoftext|>what do you think will happen to him ?<|endoftext|>well , he'll definitely lose his license , pay a fine , and maybe even spend some time in jail .<|endoftext|>that doesn't sound too promising . Does he have a defence lawyer ?<|endoftext|>not yet . If we can't find a lawyer for him , then the state will appoint him with one .<|endoftext|>my sister is a lawyer . I can ask her if she can help him .", "response": "that'd be great ! I know he'd appreciate your help ."} 4 | {"dialogue": "what did you do last night ?<|endoftext|>I just stayed at home and watched TV .<|endoftext|>did you watch the Barcelona versus Madrid game ?<|endoftext|>yes , it was a very evenly matched game .<|endoftext|>what were the results of the game ?<|endoftext|>it ended in a tie .<|endoftext|>what was the actual score ?<|endoftext|>it was three all .<|endoftext|>really ? I thought Barcelona was going to win . They are a much better team than Madrid .<|endoftext|>if Barcelona had played as they normally do , they should have won .<|endoftext|>what happened ?<|endoftext|>their goalkeepers got injured after the first goal , but he kept on playing anyway .<|endoftext|>that explains it .<|endoftext|>how many more games are left this season ?<|endoftext|>that was the last game before the playoffs .<|endoftext|>when do the playoffs start ?<|endoftext|>in a few weeks . Are you planning on watching them ?<|endoftext|>of course ! I really love watching football games on TV .<|endoftext|>Me , too ! Do you want to come over and watch the next game at my place ? I'm planning on having a few people over to watch it together .<|endoftext|>sure , that would be great .", "response": "Ok , it's a date then !"} 5 | {"dialogue": "John , I was talking to the travel agent about where we might be taking our vacation this year .<|endoftext|>I am going fishing in Alaska with my friend , Mark .<|endoftext|>What are you talking about ?<|endoftext|>What's wrong with heading out with Mark for vacation ?<|endoftext|>You and I have been together for a whole year , and our vacation time should be about the two of us !<|endoftext|>Really ? Who made that rule up ?", "response": "With that attitude , I don't really think we have much more to discuss here ."} 6 | {"dialogue": "Hi , my name is Violet . Come with me , and I'll help you wash your hair .<|endoftext|>My hair is kind of dry and brittle ...<|endoftext|>I'll pick a shampoo that's just right for your hair type . Sit right here , and rest your neck on the side of the sink . Is the water too hot ?", "response": "No , it's just perfect ."} 7 | {"dialogue": "One thing I love our boss for is that he always knows when to give you a pay raise without being asked for .", "response": "Really ? How can he be so sure about the timing ?"} 8 | {"dialogue": "Hi , I'm Jake . I'm new to the choir . What's your name ?<|endoftext|>Hello , there , my name is Tonia .<|endoftext|>Do you sing alto ?<|endoftext|>Actually , I can do both soprano and alto but the director asked me to sing alto for the next perforate . What about you .<|endoftext|>Looks like we both float back and forth . I'm baritone .<|endoftext|>Our bass section is really good . You're going to love singing with them .", "response": "I heard them warming up earlier . You're right ."} 9 | {"dialogue": "Hey , Robert , that's a nice shirt you are wearing . Where did you get it ?", "response": "thanks , I like it too . 
I bought it at the nearby department store ."} 10 | {"dialogue": "What does it cost to ride this bus ?<|endoftext|>The fare is $ 1 . 25 .<|endoftext|>Have you been driving buses a long time ?<|endoftext|>I haven't been driving for long -- only for a few months .<|endoftext|>Do you like to drive the bus ?<|endoftext|>Not in the least bit .<|endoftext|>I would have never dreamed of ever becoming a bus driver .<|endoftext|>I never dreamed of doing this either . The only thing I like about it is the money .", "response": "It was really fun chatting with you ."} 11 | {"dialogue": "How is everything going with your girlfriend ?<|endoftext|>Didn't I tell you ? It's over !<|endoftext|>Oh , I am sorry to hear that . I did't know that you had split up . What happened ?<|endoftext|>It was a few things . The first thing that happened was that we were supposed to go out for a romantic dinner for our one year anniversary , but she stood me up !<|endoftext|>Really ! Did she tell you why she didn't show up ?<|endoftext|>No , but I ended up finding out later that night when I saw her drinking with another man at a club near my home !<|endoftext|>What was she thinking ? Did you confront her about it when you saw her ?", "response": "I wanted to , but I knew that if I spoke to her , I'd just blow up at her , so I decided to just go home . I called her later that the night , but she didn't answer the phone ."} 12 | {"dialogue": "Rock music really leaves me cold . What about you ?<|endoftext|>I'm crazy about it . It makes me very excited .", "response": "Then tell me what's good about it ."} -------------------------------------------------------------------------------- /data/fine-tune/train.json: -------------------------------------------------------------------------------- 1 | {"dialogue": "How have you been ?", "response": "Fine , thank you ."} 2 | {"dialogue": "Oh , come on , Ultraman !<|endoftext|>What's up , Bro ? What's in the bulletin ?<|endoftext|>It says that there will be a blackout from 5 p . m . to 7 p . m . in our neighborhood today .<|endoftext|>Blackout ? Even the TV has the limit .<|endoftext|>Don't you know you will look like a monster in the blackout ?", "response": "Oops , no , Daddy can't watch American Idol , either !"} 3 | {"dialogue": "You'll love this . It's a chick drink .<|endoftext|>What's that supposed to mean -- that it's weak ?<|endoftext|>Well , that too . I mean that it's kind of sweet . See if you can guess what's in it .", "response": "Mmm ! Beer , tequila , and ... lime !"} 4 | {"dialogue": "Isn't it wonderful walking here ?<|endoftext|>What do you mean ?<|endoftext|>I mean look at all these magnificent buildings around us .", "response": "Yes , look over there . That's the Empire State Building . My book says it's 102 stories tall ."} 5 | {"dialogue": "I want to say goodbye to everyone .<|endoftext|>You ' re leaving so soon . When are you off ?<|endoftext|>I ' m catching the 9 fifteen train tomorrow morning .<|endoftext|>How about I come and see you off ?<|endoftext|>You really don ' t need to .<|endoftext|>Ok . I ' ll miss you . I hope we can see each other again soon .<|endoftext|>I hope so , too . Thank you , Lily . Thank you for everything .<|endoftext|>You ' re welcome .<|endoftext|>Please say goodbye to the rest of the family for me .<|endoftext|>Ok . Take care . I hope you have a good journey .<|endoftext|>Thank you . Remember to look me up if you ' re ever in Washington .<|endoftext|>Of course . I will .", "response": "Goodbye , then . 
Thanks again for everything ."} 6 | {"dialogue": "what's wrong , Jerry ? You look so upset .<|endoftext|>to be honest , I was just dumped .<|endoftext|>oh , I'm sorry to hear that . You can go on a holiday to cheer you up .<|endoftext|>no , thanks . I'm not in the mood for traveling .<|endoftext|>come on . A trip will do you good . Are you doing anything this weekend ?<|endoftext|>I was planning on doing a lot of wallowing .<|endoftext|>well , my friends and I are planning on going to Shangri-La on Saturday . Do you want to come with us ?<|endoftext|>where is that ?<|endoftext|>not very far from here . We'll fly . It's about one and a half hours .<|endoftext|>what's there to see ?<|endoftext|>there is a large canyon , vast grasslands , ancient forests and mountain lakes .", "response": "oh , sounds nice ."} 7 | {"dialogue": "What a wonderful party ! I had a good time . How about you ?<|endoftext|>I enjoyed myself , too .<|endoftext|>Shall I give you a ride home ?<|endoftext|>Yes , if it's not too much of a trouble .", "response": "It's no trouble at all because your house is on the way to my place ."} 8 | {"dialogue": "Oh , the ink is spilled on the desk .<|endoftext|>Did it spill on your clothes ?<|endoftext|>No , but the table cloth was dirty .<|endoftext|>That's OK .<|endoftext|>I'm afraid it's too hard to wash off the stain .", "response": "It's no big deal ."} 9 | {"dialogue": "May's birthday is coming . Shall we buy her a birthday present or let her choose one for herself ?<|endoftext|>I think a surprise party may be better . But I forget when her birthday is .", "response": "You are such a good father . It's next Sunday ."} 10 | {"dialogue": "Excuse me . This chicken doesn't taste right to me .<|endoftext|>What seems to be the problem ?", "response": "How should I know ? It's just kind of cold in the middle . It just doesn't taste right . Do you want to try it ?"} 11 | {"dialogue": "That's all the general information of our company . I think you already have good knowledge about our company .<|endoftext|>Yes , I have an overall understanding .<|endoftext|>When we have the final results , we will call you .<|endoftext|>Then when will I get a reply at the latest ?<|endoftext|>If you pass the interview , the personnel department will inform you within two weeks .<|endoftext|>But if I don't pass , will you call me ?<|endoftext|>I'm sorry we won't . You can wait for two weeks . If you don't get a telephone call , it means that you weren't successful .<|endoftext|>Then do I have the chance to get this job ?<|endoftext|>I'm sorry but I can't make the final decision myself , and I have to discuss it with other interviewers .<|endoftext|>I know . No matter what the result will be , I have learned a lot from our conversation .", "response": "Your mentality is very good and that's great ."} 12 | {"dialogue": "Why don't you sit down and relax , darling ?<|endoftext|>I don't want to .<|endoftext|>Well , come over and talk to me then .<|endoftext|>Certainly not .<|endoftext|>May I turn on the TV then ?<|endoftext|>Turn on the TV for what ?<|endoftext|>So that we can sit down together and listen to some music .<|endoftext|>Listen to the music ? And who will cook dinner , will you ?<|endoftext|>I will , but let's go to the disco after dinner .", "response": "To a disco ? Oh , no . You know I hate it ."} 13 | {"dialogue": "Is there someone I can talk to about a payment question ?<|endoftext|>Yes , we can handle that here . 
How can I help you ?<|endoftext|>My paycheck that just arrived is less than last week ' s check .<|endoftext|>Did you work at all during the last pay period ?<|endoftext|>Yes , actually , I did make a little bit of money .<|endoftext|>Did you report it on your Continued Claim Form ?<|endoftext|>Yes , I showed that income on the Continued Claim Form .<|endoftext|>Well , we deducted a portion of the income that you made from this week ' s check .", "response": "Maybe I just shouldn ' t show the income then ."} 14 | {"dialogue": "Good morning , Miss Wu ! Can I ask you something ?<|endoftext|>Certainly . You are more welcome to do . What is it ?<|endoftext|>Tomorrow is my wife's birthday . We both love spicy Chinese dishes . I am wondering if you could recommend a good local restaurant where I can find some good spicy dishes .<|endoftext|>Well , if spicy local dishes are what you are looking for , In Hua Restaurant is the best place to go . The restaurant serves very good and spicy local dishes . You might want to try there .<|endoftext|>How far is it from here ?<|endoftext|>It is near the university . Five minutes ' walk from your flat , I think .<|endoftext|>Great . We'll go and have a try . Thank you very much .", "response": "You're welcome . May you have a good time . Please give my regards to your wife . I wish her a happy birthday tomorrow ."} 15 | {"dialogue": "Good morning ! I am a rookie in our office .<|endoftext|>Good morning ! Welcome to our office !<|endoftext|>Nice to meet you ! My name is Peter Smith .<|endoftext|>Nice to meet you too ! I am George Williams .", "response": "This is my first day at work !"} 16 | {"dialogue": "Excuse me .<|endoftext|>Yes ?<|endoftext|>Can you tell me the way to the Peak Tram , please ?", "response": "Certainly . Go along Queen's Road ..."} -------------------------------------------------------------------------------- /run_summarization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
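# Note: compared with the stock HuggingFace summarization example, this script is adapted for
# prompt tuning. It loads `BartPromptForConditionalGeneration` from `model.modeling_bart`,
# exposes a `--pre_seq_len` argument for the number of soft-prompt tokens, freezes every
# parameter except the `prefix_encoder`/`prefix_decoder` prompt embeddings, and truncates
# source inputs to `max_source_length - pre_seq_len` so the prepended prompt still fits
# within the model's maximum input length.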
20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import datasets 28 | import nltk # Here to have a nice missing dependency error message early on 29 | import numpy as np 30 | from datasets import load_dataset, load_metric 31 | 32 | import transformers 33 | from filelock import FileLock 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForSeq2SeqLM, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | Seq2SeqTrainer, 41 | Seq2SeqTrainingArguments, 42 | set_seed, 43 | BartForConditionalGeneration 44 | ) 45 | #from transformers.file_utils import is_offline_mode 46 | from transformers.trainer_utils import get_last_checkpoint 47 | #from transformers.utils import check_min_version 48 | from transformers.utils.versions import require_version 49 | #from model.summarization import BARTPromptConditionalGeneration 50 | from model.modeling_bart import BartPromptForConditionalGeneration 51 | 52 | 53 | 54 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 55 | #check_min_version("4.13.0.dev0") 56 | 57 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 58 | 59 | logger = logging.getLogger(__name__) 60 | ''' 61 | try: 62 | nltk.data.find("tokenizers/punkt") 63 | except (LookupError, OSError): 64 | if is_offline_mode(): 65 | raise LookupError( 66 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 67 | ) 68 | with FileLock(".lock") as lock: 69 | nltk.download("punkt", quiet=True) 70 | 71 | ''' 72 | @dataclass 73 | class ModelArguments: 74 | """ 75 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 76 | """ 77 | 78 | model_name_or_path: str = field( 79 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 80 | ) 81 | config_name: Optional[str] = field( 82 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 83 | ) 84 | tokenizer_name: Optional[str] = field( 85 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 86 | ) 87 | cache_dir: Optional[str] = field( 88 | default=None, 89 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 90 | ) 91 | use_fast_tokenizer: bool = field( 92 | default=True, 93 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 94 | ) 95 | model_revision: str = field( 96 | default="main", 97 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 98 | ) 99 | use_auth_token: bool = field( 100 | default=False, 101 | metadata={ 102 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 103 | "with private models)." 104 | }, 105 | ) 106 | resize_position_embeddings: Optional[bool] = field( 107 | default=None, 108 | metadata={ 109 | "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 110 | "the model's position embeddings." 
111 | }, 112 | ) 113 | pre_seq_len: Optional[int] = field( 114 | default=200, 115 | metadata={ 116 | "help": "length of prefix" 117 | 118 | } 119 | ) 120 | 121 | 122 | @dataclass 123 | class DataTrainingArguments: 124 | """ 125 | Arguments pertaining to what data we are going to input our model for training and eval. 126 | """ 127 | 128 | dataset_name: Optional[str] = field( 129 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 130 | ) 131 | dataset_config_name: Optional[str] = field( 132 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 133 | ) 134 | text_column: Optional[str] = field( 135 | default=None, 136 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 137 | ) 138 | summary_column: Optional[str] = field( 139 | default=None, 140 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 141 | ) 142 | train_file: Optional[str] = field( 143 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 144 | ) 145 | validation_file: Optional[str] = field( 146 | default=None, 147 | metadata={ 148 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 149 | "(a jsonlines or csv file)." 150 | }, 151 | ) 152 | test_file: Optional[str] = field( 153 | default=None, 154 | metadata={ 155 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 156 | }, 157 | ) 158 | overwrite_cache: bool = field( 159 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 160 | ) 161 | preprocessing_num_workers: Optional[int] = field( 162 | default=None, 163 | metadata={"help": "The number of processes to use for the preprocessing."}, 164 | ) 165 | max_source_length: Optional[int] = field( 166 | default=1024, 167 | metadata={ 168 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 169 | "than this will be truncated, sequences shorter will be padded." 170 | }, 171 | ) 172 | max_target_length: Optional[int] = field( 173 | default=128, 174 | metadata={ 175 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 176 | "than this will be truncated, sequences shorter will be padded." 177 | }, 178 | ) 179 | val_max_target_length: Optional[int] = field( 180 | default=None, 181 | metadata={ 182 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 183 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 184 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 185 | "during ``evaluate`` and ``predict``." 186 | }, 187 | ) 188 | pad_to_max_length: bool = field( 189 | default=False, 190 | metadata={ 191 | "help": "Whether to pad all samples to model maximum sentence length. " 192 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 193 | "efficient on GPU but very bad for TPU." 194 | }, 195 | ) 196 | max_train_samples: Optional[int] = field( 197 | default=None, 198 | metadata={ 199 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 200 | "value if set." 
201 | }, 202 | ) 203 | max_eval_samples: Optional[int] = field( 204 | default=None, 205 | metadata={ 206 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 207 | "value if set." 208 | }, 209 | ) 210 | max_predict_samples: Optional[int] = field( 211 | default=None, 212 | metadata={ 213 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 214 | "value if set." 215 | }, 216 | ) 217 | num_beams: Optional[int] = field( 218 | default=None, 219 | metadata={ 220 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 221 | "which is used during ``evaluate`` and ``predict``." 222 | }, 223 | ) 224 | ignore_pad_token_for_loss: bool = field( 225 | default=True, 226 | metadata={ 227 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 228 | }, 229 | ) 230 | source_prefix: Optional[str] = field( 231 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 232 | ) 233 | 234 | def __post_init__(self): 235 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 236 | raise ValueError("Need either a dataset name or a training/validation file.") 237 | else: 238 | if self.train_file is not None: 239 | extension = self.train_file.split(".")[-1] 240 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 241 | if self.validation_file is not None: 242 | extension = self.validation_file.split(".")[-1] 243 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 244 | if self.val_max_target_length is None: 245 | self.val_max_target_length = self.max_target_length 246 | 247 | 248 | summarization_name_mapping = { 249 | "amazon_reviews_multi": ("review_body", "review_title"), 250 | "big_patent": ("description", "abstract"), 251 | "cnn_dailymail": ("article", "highlights"), 252 | "orange_sum": ("text", "summary"), 253 | "pn_summary": ("article", "summary"), 254 | "psc": ("extract_text", "summary_text"), 255 | "samsum": ("dialogue", "summary"), 256 | "thaisum": ("body", "summary"), 257 | "xglue": ("news_body", "news_title"), 258 | "xsum": ("document", "summary"), 259 | "wiki_summary": ("article", "highlights"), 260 | } 261 | 262 | 263 | def main(): 264 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 265 | # See all possible arguments in src/transformers/training_args.py 266 | # or by passing the --help flag to this script. 267 | # We now keep distinct sets of args, for a cleaner separation of concerns. 268 | 269 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 270 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 271 | # If we pass only one argument to the script and it's the path to a json file, 272 | # let's parse it to get our arguments. 
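        # For example (hypothetical path), `python run_summarization.py path/to/args.json`
        # reads every argument from that JSON file, which is equivalent to passing the same
        # settings as command-line flags the way run_train.sh does.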
273 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 274 | else: 275 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 276 | 277 | # Setup logging 278 | logging.basicConfig( 279 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 280 | datefmt="%m/%d/%Y %H:%M:%S", 281 | handlers=[logging.StreamHandler(sys.stdout)], 282 | ) 283 | print(training_args) 284 | log_level = 'INFO' 285 | logger.setLevel(log_level) 286 | datasets.utils.logging.set_verbosity(log_level) 287 | transformers.utils.logging.set_verbosity(log_level) 288 | transformers.utils.logging.enable_default_handler() 289 | transformers.utils.logging.enable_explicit_format() 290 | 291 | # Log on each process the small summary: 292 | logger.warning( 293 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 294 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 295 | ) 296 | logger.info(f"Training/evaluation parameters {training_args}") 297 | 298 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 299 | "t5-small", 300 | "t5-base", 301 | "t5-large", 302 | "t5-3b", 303 | "t5-11b", 304 | ]: 305 | logger.warning( 306 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 307 | "`--source_prefix 'summarize: ' `" 308 | ) 309 | 310 | # Detecting last checkpoint. 311 | last_checkpoint = None 312 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 313 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 314 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 315 | raise ValueError( 316 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 317 | "Use --overwrite_output_dir to overcome." 318 | ) 319 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 320 | logger.info( 321 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 322 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 323 | ) 324 | 325 | # Set seed before initializing model. 326 | set_seed(training_args.seed) 327 | 328 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 329 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 330 | # (the dataset will be downloaded automatically from the datasets Hub). 331 | # 332 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 333 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 334 | # 335 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 336 | # download the dataset. 337 | if data_args.dataset_name is not None: 338 | # Downloading and loading a dataset from the hub. 
339 | raw_datasets = load_dataset( 340 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 341 | ) 342 | else: 343 | data_files = {} 344 | if data_args.train_file is not None: 345 | data_files["train"] = data_args.train_file 346 | extension = data_args.train_file.split(".")[-1] 347 | if data_args.validation_file is not None: 348 | data_files["validation"] = data_args.validation_file 349 | extension = data_args.validation_file.split(".")[-1] 350 | if data_args.test_file is not None: 351 | data_files["test"] = data_args.test_file 352 | extension = data_args.test_file.split(".")[-1] 353 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 354 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 355 | # https://huggingface.co/docs/datasets/loading_datasets.html. 356 | 357 | # Load pretrained model and tokenizer 358 | # 359 | # Distributed training: 360 | # The .from_pretrained methods guarantee that only one local process can concurrently 361 | # download model & vocab. 362 | config = AutoConfig.from_pretrained( 363 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 364 | cache_dir=model_args.cache_dir, 365 | revision=model_args.model_revision, 366 | use_auth_token=True if model_args.use_auth_token else None, 367 | ) 368 | config.pre_seq_len = model_args.pre_seq_len 369 | tokenizer = AutoTokenizer.from_pretrained( 370 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 371 | cache_dir=model_args.cache_dir, 372 | use_fast=model_args.use_fast_tokenizer, 373 | revision=model_args.model_revision, 374 | use_auth_token=True if model_args.use_auth_token else None, 375 | ) 376 | model = BartPromptForConditionalGeneration.from_pretrained( 377 | model_args.model_name_or_path, 378 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 379 | config=config, 380 | cache_dir=model_args.cache_dir, 381 | revision=model_args.model_revision, 382 | use_auth_token=True if model_args.use_auth_token else None, 383 | ) 384 | 385 | for name, param in model.named_parameters(): 386 | if 'prefix_encoder' in name: 387 | param.requires_grad = True 388 | print(name) 389 | elif 'prefix_decoder' in name: 390 | param.requires_grad = True 391 | print(name) 392 | else: 393 | param.requires_grad = False 394 | print("------------------") 395 | 396 | model.resize_token_embeddings(len(tokenizer)) 397 | 398 | if model.config.decoder_start_token_id is None: 399 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 400 | 401 | if ( 402 | hasattr(model.config, "max_position_embeddings") 403 | and model.config.max_position_embeddings < data_args.max_source_length 404 | ): 405 | if model_args.resize_position_embeddings is None: 406 | logger.warning( 407 | f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} " 408 | f"to {data_args.max_source_length}." 409 | ) 410 | model.resize_position_embeddings(data_args.max_source_length) 411 | elif model_args.resize_position_embeddings: 412 | model.resize_position_embeddings(data_args.max_source_length) 413 | else: 414 | raise ValueError( 415 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}" 416 | f" position encodings. 
Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically " 417 | "resize the model's position encodings by passing `--resize_position_embeddings`." 418 | ) 419 | 420 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 421 | 422 | # Preprocessing the datasets. 423 | # We need to tokenize inputs and targets. 424 | if training_args.do_train: 425 | column_names = raw_datasets["train"].column_names 426 | elif training_args.do_eval: 427 | column_names = raw_datasets["validation"].column_names 428 | elif training_args.do_predict: 429 | column_names = raw_datasets["test"].column_names 430 | else: 431 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 432 | return 433 | 434 | # Get the column names for input/target. 435 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 436 | if data_args.text_column is None: 437 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 438 | else: 439 | text_column = data_args.text_column 440 | if text_column not in column_names: 441 | raise ValueError( 442 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 443 | ) 444 | if data_args.summary_column is None: 445 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 446 | else: 447 | summary_column = data_args.summary_column 448 | if summary_column not in column_names: 449 | raise ValueError( 450 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 451 | ) 452 | 453 | # Temporarily set max_target_length for training. 454 | max_target_length = data_args.max_target_length 455 | padding = "max_length" if data_args.pad_to_max_length else False 456 | 457 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 458 | logger.warning( 459 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 460 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 461 | ) 462 | 463 | def preprocess_function(examples): 464 | inputs = examples[text_column] 465 | targets = examples[summary_column] 466 | inputs = [prefix + inp for inp in inputs] 467 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length-model_args.pre_seq_len, padding=padding, truncation=True) 468 | 469 | # Setup the tokenizer for targets 470 | with tokenizer.as_target_tokenizer(): 471 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 472 | 473 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 474 | # padding in the loss. 
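        # (-100 is the ignore_index used by the loss computation, so label positions set to
        # -100 are skipped when the cross-entropy loss is calculated.)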
475 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 476 | labels["input_ids"] = [ 477 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 478 | ] 479 | 480 | model_inputs["labels"] = labels["input_ids"] 481 | return model_inputs 482 | 483 | if training_args.do_train: 484 | if "train" not in raw_datasets: 485 | raise ValueError("--do_train requires a train dataset") 486 | train_dataset = raw_datasets["train"] 487 | if data_args.max_train_samples is not None: 488 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 489 | with training_args.main_process_first(desc="train dataset map pre-processing"): 490 | train_dataset = train_dataset.map( 491 | preprocess_function, 492 | batched=True, 493 | num_proc=data_args.preprocessing_num_workers, 494 | remove_columns=column_names, 495 | load_from_cache_file=not data_args.overwrite_cache, 496 | desc="Running tokenizer on train dataset", 497 | ) 498 | 499 | if training_args.do_eval: 500 | max_target_length = data_args.val_max_target_length 501 | if "validation" not in raw_datasets: 502 | raise ValueError("--do_eval requires a validation dataset") 503 | eval_dataset = raw_datasets["validation"] 504 | if data_args.max_eval_samples is not None: 505 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 506 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 507 | eval_dataset = eval_dataset.map( 508 | preprocess_function, 509 | batched=True, 510 | num_proc=data_args.preprocessing_num_workers, 511 | remove_columns=column_names, 512 | load_from_cache_file=not data_args.overwrite_cache, 513 | desc="Running tokenizer on validation dataset", 514 | ) 515 | 516 | if training_args.do_predict: 517 | max_target_length = data_args.val_max_target_length 518 | if "test" not in raw_datasets: 519 | raise ValueError("--do_predict requires a test dataset") 520 | predict_dataset = raw_datasets["test"] 521 | if data_args.max_predict_samples is not None: 522 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 523 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 524 | predict_dataset = predict_dataset.map( 525 | preprocess_function, 526 | batched=True, 527 | num_proc=data_args.preprocessing_num_workers, 528 | remove_columns=column_names, 529 | load_from_cache_file=not data_args.overwrite_cache, 530 | desc="Running tokenizer on prediction dataset", 531 | ) 532 | 533 | # Data collator 534 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 535 | data_collator = DataCollatorForSeq2Seq( 536 | tokenizer, 537 | model=model, 538 | label_pad_token_id=label_pad_token_id, 539 | pad_to_multiple_of=8 if training_args.fp16 else None, 540 | ) 541 | 542 | # Metric 543 | metric = load_metric("rouge") 544 | 545 | def postprocess_text(preds, labels): 546 | preds = [pred.strip() for pred in preds] 547 | labels = [label.strip() for label in labels] 548 | 549 | # rougeLSum expects newline after each sentence 550 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 551 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 552 | 553 | return preds, labels 554 | 555 | def compute_metrics(eval_preds): 556 | preds, labels = eval_preds 557 | if isinstance(preds, tuple): 558 | preds = preds[0] 559 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 560 | if data_args.ignore_pad_token_for_loss: 561 | # 
Replace -100 in the labels as we can't decode them. 562 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 563 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 564 | 565 | # Some simple post-processing 566 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 567 | 568 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 569 | # Extract a few results from ROUGE 570 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 571 | 572 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 573 | result["gen_len"] = np.mean(prediction_lens) 574 | result = {k: round(v, 4) for k, v in result.items()} 575 | return result 576 | 577 | # Initialize our Trainer 578 | trainer = Seq2SeqTrainer( 579 | model=model, 580 | args=training_args, 581 | train_dataset=train_dataset if training_args.do_train else None, 582 | eval_dataset=eval_dataset if training_args.do_eval else None, 583 | tokenizer=tokenizer, 584 | data_collator=data_collator, 585 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 586 | ) 587 | 588 | # Training 589 | if training_args.do_train: 590 | checkpoint = None 591 | if training_args.resume_from_checkpoint is not None: 592 | checkpoint = training_args.resume_from_checkpoint 593 | elif last_checkpoint is not None: 594 | checkpoint = last_checkpoint 595 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 596 | trainer.save_model() # Saves the tokenizer too for easy upload 597 | 598 | metrics = train_result.metrics 599 | max_train_samples = ( 600 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 601 | ) 602 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 603 | 604 | trainer.log_metrics("train", metrics) 605 | trainer.save_metrics("train", metrics) 606 | trainer.save_state() 607 | 608 | # Evaluation 609 | results = {} 610 | max_length = ( 611 | training_args.generation_max_length 612 | if training_args.generation_max_length is not None 613 | else data_args.val_max_target_length 614 | ) 615 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 616 | if training_args.do_eval: 617 | logger.info("*** Evaluate ***") 618 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 619 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 620 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 621 | 622 | trainer.log_metrics("eval", metrics) 623 | trainer.save_metrics("eval", metrics) 624 | 625 | if training_args.do_predict: 626 | logger.info("*** Predict ***") 627 | 628 | predict_results = trainer.predict( 629 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 630 | ) 631 | #print(predict_results) 632 | metrics = predict_results.metrics 633 | max_predict_samples = ( 634 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 635 | ) 636 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 637 | 638 | trainer.log_metrics("predict", metrics) 639 | trainer.save_metrics("predict", metrics) 640 | 641 | if trainer.is_world_process_zero(): 642 | if training_args.predict_with_generate: 643 | predictions = tokenizer.batch_decode( 644 | 
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 645 | ) 646 | predictions = [" ".join(pred.split('\n')) for pred in predictions] 647 | predictions = [pred.strip() for pred in predictions] 648 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 649 | with open(output_prediction_file, "w") as writer: 650 | writer.write("\n".join(predictions)) 651 | 652 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 653 | if data_args.dataset_name is not None: 654 | kwargs["dataset_tags"] = data_args.dataset_name 655 | if data_args.dataset_config_name is not None: 656 | kwargs["dataset_args"] = data_args.dataset_config_name 657 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 658 | else: 659 | kwargs["dataset"] = data_args.dataset_name 660 | 661 | if training_args.push_to_hub: 662 | trainer.push_to_hub(**kwargs) 663 | else: 664 | trainer.create_model_card(**kwargs) 665 | 666 | return results 667 | 668 | 669 | def _mp_fn(index): 670 | # For xla_spawn (TPUs) 671 | main() 672 | 673 | 674 | if __name__ == "__main__": 675 | main() -------------------------------------------------------------------------------- /model/modeling_gpt_prompt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
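# Note: this file is a lightly modified copy of the Hugging Face transformers GPT-2 implementation.
# The prompt-tuning changes visible below are:
#   * GPT2Model owns a learned prefix of `config.pre_seq_len` virtual tokens (`prefix_tokens`,
#     `prefix_decoder`, `prefix_decoder_mlp`); on the first forward pass the prefix embeddings are
#     prepended to the word embeddings and the attention mask is widened accordingly.
#   * GPT2LMHeadModel drops the first `pre_seq_len` hidden states before the LM head so that the
#     logits line up with the real input tokens again.
# A rough usage sketch (illustrative only; "gpt2" and the field values below are assumptions, not
# taken from this repo's scripts):
#   config = AutoConfig.from_pretrained("gpt2"); config.pre_seq_len = 50; config.prefix_drop = 0.1
#   model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
#   out = model(input_ids=ids, attention_mask=torch.ones_like(ids, dtype=torch.float))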
16 | """PyTorch OpenAI GPT-2 model.""" 17 | 18 | import math 19 | import os 20 | from dataclasses import dataclass 21 | from typing import Optional, Tuple, Union 22 | 23 | import torch 24 | import torch.utils.checkpoint 25 | from packaging import version 26 | from torch import nn 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 28 | import numpy as np 29 | 30 | if version.parse(torch.__version__) >= version.parse("1.6"): 31 | is_amp_available = True 32 | from torch.cuda.amp import autocast 33 | else: 34 | is_amp_available = False 35 | 36 | from transformers.activations import ACT2FN 37 | from transformers.modeling_outputs import ( 38 | BaseModelOutputWithPastAndCrossAttentions, 39 | CausalLMOutputWithCrossAttentions, 40 | SequenceClassifierOutputWithPast, 41 | TokenClassifierOutput, 42 | ) 43 | from transformers.utils import logging 44 | from transformers.file_utils import ( 45 | DUMMY_INPUTS, 46 | DUMMY_MASK, 47 | add_start_docstrings, 48 | add_start_docstrings_to_model_forward, 49 | is_torch_fx_proxy, 50 | replace_return_docstrings, 51 | add_code_sample_docstrings 52 | ) 53 | from transformers.modeling_utils import PreTrainedModel, SequenceSummary 54 | #from transformers.src.transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer 55 | from transformers.models.gpt2 import GPT2Config, GPT2PreTrainedModel 56 | 57 | from transformers.models.gpt2.modeling_gpt2 import GPT2Block 58 | 59 | 60 | 61 | logger = logging.get_logger(__name__) 62 | 63 | _CHECKPOINT_FOR_DOC = "gpt2" 64 | _CONFIG_FOR_DOC = "GPT2Config" 65 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer" 66 | 67 | GPT2_START_DOCSTRING = r""" 68 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 69 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 70 | etc.) 71 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 72 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 73 | and behavior. 74 | Parameters: 75 | config ([`GPT2Config`]): Model configuration class with all the parameters of the model. 76 | Initializing with a config file does not load the weights associated with the model, only the 77 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 78 | """ 79 | 80 | 81 | PARALLELIZE_DOCSTRING = r""" 82 | This is an experimental feature and is a subject to change at a moment's notice. 83 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given, 84 | it will evenly distribute blocks across all devices. 85 | Args: 86 | device_map (`Dict[int, list]`, optional, defaults to None): 87 | A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always 88 | automatically mapped to the first device (for esoteric reasons). That means that the first device should 89 | have fewer attention modules mapped to it than other devices. 
For reference, the gpt2 models have the 90 | following number of attention modules: 91 | - gpt2: 12 92 | - gpt2-medium: 24 93 | - gpt2-large: 36 94 | - gpt2-xl: 48 95 | Example: 96 | ```python 97 | # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: 98 | model = GPT2LMHeadModel.from_pretrained("gpt2-xl") 99 | device_map = { 100 | 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], 101 | 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], 102 | 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], 103 | 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], 104 | } 105 | model.parallelize(device_map) 106 | ``` 107 | """ 108 | DEPARALLELIZE_DOCSTRING = r""" 109 | Moves the model to cpu from a model parallel state. 110 | Example: 111 | ```python 112 | # On a 4 GPU machine with gpt2-large: 113 | model = GPT2LMHeadModel.from_pretrained("gpt2-large") 114 | device_map = { 115 | 0: [0, 1, 2, 3, 4, 5, 6, 7], 116 | 1: [8, 9, 10, 11, 12, 13, 14, 15], 117 | 2: [16, 17, 18, 19, 20, 21, 22, 23], 118 | 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], 119 | } 120 | model.parallelize(device_map) # Splits the model across several devices 121 | model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() 122 | ``` 123 | """ 124 | 125 | GPT2_INPUTS_DOCSTRING = r""" 126 | Args: 127 | input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): 128 | `input_ids_length` = `sequence_length` if `past_key_values` is `None` else 129 | `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input 130 | sequence tokens in the vocabulary. 131 | If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as 132 | `input_ids`. 133 | Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and 134 | [`PreTrainedTokenizer.__call__`] for details. 135 | [What are input IDs?](../glossary#input-ids) 136 | past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): 137 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see 138 | `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have 139 | their past given to this model should not be passed as `input_ids` as they have already been computed. 140 | attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 141 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 142 | - 1 for tokens that are **not masked**, 143 | - 0 for tokens that are **masked**. 144 | If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for 145 | `past_key_values`. In other words, the `attention_mask` always has to have the length: 146 | `len(past_key_values) + len(input_ids)` 147 | [What are attention masks?](../glossary#attention-mask) 148 | token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): 149 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 150 | 1]`: 151 | - 0 corresponds to a *sentence A* token, 152 | - 1 corresponds to a *sentence B* token. 
153 | [What are token type IDs?](../glossary#token-type-ids) 154 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 155 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 156 | config.max_position_embeddings - 1]`. 157 | [What are position IDs?](../glossary#position-ids) 158 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 159 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 160 | - 1 indicates the head is **not masked**, 161 | - 0 indicates the head is **masked**. 162 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 163 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 164 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 165 | model's internal embedding lookup matrix. 166 | If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see 167 | `past_key_values`). 168 | use_cache (`bool`, *optional*): 169 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 170 | `past_key_values`). 171 | output_attentions (`bool`, *optional*): 172 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 173 | tensors for more detail. 174 | output_hidden_states (`bool`, *optional*): 175 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 176 | more detail. 177 | return_dict (`bool`, *optional*): 178 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
179 | """ 180 | 181 | @add_start_docstrings( 182 | "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", 183 | GPT2_START_DOCSTRING, 184 | ) 185 | class GPT2Model(GPT2PreTrainedModel): 186 | _keys_to_ignore_on_load_missing = ["attn.masked_bias"] 187 | 188 | def __init__(self, config): 189 | super().__init__(config) 190 | 191 | self.embed_dim = config.hidden_size 192 | 193 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 194 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 195 | 196 | self.drop = nn.Dropout(config.embd_pdrop) 197 | self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) 198 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) 199 | 200 | self.pre_seq_len = config.pre_seq_len 201 | self.prefix_tokens = torch.arange(self.pre_seq_len).long() 202 | self.prefix_decoder = torch.nn.Embedding(self.pre_seq_len, config.n_embd) 203 | self.prefix_drop = nn.Dropout(config.prefix_drop) 204 | 205 | self.prefix_decoder_mlp = nn.Linear(config.n_embd, config.n_embd) 206 | 207 | 208 | # Model parallel 209 | self.model_parallel = False 210 | self.device_map = None 211 | self.gradient_checkpointing = False 212 | 213 | # Initialize weights and apply final processing 214 | self.post_init() 215 | 216 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 217 | def parallelize(self, device_map=None): 218 | # Check validity of device_map 219 | self.device_map = ( 220 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map 221 | ) 222 | assert_device_map(self.device_map, len(self.h)) 223 | self.model_parallel = True 224 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 225 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 226 | self.wte = self.wte.to(self.first_device) 227 | self.wpe = self.wpe.to(self.first_device) 228 | # Load onto devices 229 | for k, v in self.device_map.items(): 230 | for block in v: 231 | cuda_device = "cuda:" + str(k) 232 | self.h[block] = self.h[block].to(cuda_device) 233 | # ln_f to last 234 | self.ln_f = self.ln_f.to(self.last_device) 235 | 236 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 237 | def deparallelize(self): 238 | self.model_parallel = False 239 | self.device_map = None 240 | self.first_device = "cpu" 241 | self.last_device = "cpu" 242 | self.wte = self.wte.to("cpu") 243 | self.wpe = self.wpe.to("cpu") 244 | for index in range(len(self.h)): 245 | self.h[index] = self.h[index].to("cpu") 246 | self.ln_f = self.ln_f.to("cpu") 247 | torch.cuda.empty_cache() 248 | 249 | def get_prompt_decoder(self, batch_size): 250 | prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device) 251 | prompts = self.prefix_decoder(prefix_tokens) 252 | 253 | prompts = self.prefix_decoder_mlp(prompts) 254 | return prompts 255 | 256 | def get_input_embeddings(self): 257 | return self.wte 258 | 259 | def set_input_embeddings(self, new_embeddings): 260 | self.wte = new_embeddings 261 | 262 | def _prune_heads(self, heads_to_prune): 263 | """ 264 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 265 | """ 266 | for layer, heads in heads_to_prune.items(): 267 | self.h[layer].attn.prune_heads(heads) 268 | 269 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 270 | @add_code_sample_docstrings( 271 | processor_class=_TOKENIZER_FOR_DOC, 272 | checkpoint=_CHECKPOINT_FOR_DOC, 273 | output_type=BaseModelOutputWithPastAndCrossAttentions, 274 | config_class=_CONFIG_FOR_DOC, 275 | ) 276 | def forward( 277 | self, 278 | input_ids: Optional[torch.LongTensor] = None, 279 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 280 | attention_mask: Optional[torch.FloatTensor] = None, 281 | token_type_ids: Optional[torch.LongTensor] = None, 282 | position_ids: Optional[torch.LongTensor] = None, 283 | head_mask: Optional[torch.FloatTensor] = None, 284 | inputs_embeds: Optional[torch.FloatTensor] = None, 285 | encoder_hidden_states: Optional[torch.Tensor] = None, 286 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 287 | use_cache: Optional[bool] = None, 288 | output_attentions: Optional[bool] = None, 289 | output_hidden_states: Optional[bool] = None, 290 | return_dict: Optional[bool] = None, 291 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 292 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 293 | output_hidden_states = ( 294 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 295 | ) 296 | use_cache = use_cache if use_cache is not None else self.config.use_cache 297 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 298 | 299 | if input_ids is not None and inputs_embeds is not None: 300 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 301 | elif input_ids is not None: 302 | input_shape = input_ids.size() 303 | input_ids = input_ids.view(-1, input_shape[-1]) 304 | batch_size = input_ids.shape[0] 305 | elif inputs_embeds is not None: 306 | input_shape = inputs_embeds.size()[:-1] 307 | batch_size = inputs_embeds.shape[0] 308 | else: 309 | raise ValueError("You have to specify either input_ids or inputs_embeds") 310 | 311 | device = input_ids.device if input_ids is not None else inputs_embeds.device 312 | 313 | 314 | if past_key_values is None: 315 | past_length = 0 316 | past_key_values = tuple([None] * len(self.h)) 317 | else: 318 | past_length = past_key_values[0][0].size(-2) 319 | 320 | 321 | 322 | if inputs_embeds is None: 323 | if past_length == 0: 324 | batch_size = input_ids.shape[0] 325 | raw_embeds = self.wte(input_ids) 326 | prompts_decoder = self.get_prompt_decoder(batch_size=batch_size) 327 | prompts_decoder = self.prefix_drop(prompts_decoder) 328 | inputs_embeds = torch.cat((prompts_decoder, raw_embeds), dim=1) 329 | 330 | 331 | 332 | #inputs_prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.device) 333 | #attention_mask = torch.cat((inputs_prefix_attention_mask, attention_mask), dim=1) 334 | input_shape = inputs_embeds.size()[:-1] 335 | else: 336 | inputs_embeds = self.wte(input_ids) 337 | input_shape = inputs_embeds.size()[:-1] 338 | 339 | inputs_prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.device) 340 | attention_mask = torch.cat((inputs_prefix_attention_mask, attention_mask), dim=1) 341 | 342 | #print(attention_mask.shape) 343 | 344 | if token_type_ids is not None: 345 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 
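# At this point (first forward pass, no cache) the hidden sequence is
# [prefix_1 ... prefix_{pre_seq_len}, token_1 ... token_n]: the learned prompt embeddings were
# prepended to the word embeddings above, and the attention mask was extended with ones for the
# prefix positions. During incremental decoding (past_key_values already set) only the newly
# generated token is embedded, since the prefix keys/values already live in the cache.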
346 | 347 | 348 | if position_ids is not None: 349 | 350 | if past_length == 0: 351 | position_ids = attention_mask.long().cumsum(-1) - 1 352 | position_ids.masked_fill_(attention_mask == 0, 1) 353 | 354 | position_ids = position_ids.view(-1, input_shape[-1]) 355 | 356 | 357 | 358 | ''' 359 | if past_key_values is None: 360 | past_length = 0 361 | past_key_values = tuple([None] * len(self.h)) 362 | else: 363 | past_length = past_key_values[0][0].size(-2) 364 | 365 | ''' 366 | 367 | 368 | if position_ids is None: 369 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 370 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 371 | 372 | # GPT2Attention mask. 373 | 374 | if attention_mask is not None: 375 | if batch_size <= 0: 376 | raise ValueError("batch_size has to be defined and > 0") 377 | attention_mask = attention_mask.view(batch_size, -1) 378 | # We create a 3D attention mask from a 2D tensor mask. 379 | # Sizes are [batch_size, 1, 1, to_seq_length] 380 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 381 | # this attention mask is more simple than the triangular masking of causal attention 382 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 383 | attention_mask = attention_mask[:, None, None, :] 384 | 385 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 386 | # masked positions, this operation will create a tensor which is 0.0 for 387 | # positions we want to attend and -10000.0 for masked positions. 388 | # Since we are adding it to the raw scores before the softmax, this is 389 | # effectively the same as removing these entirely. 390 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 391 | attention_mask = (1.0 - attention_mask) * -10000.0 392 | 393 | #print("attention_mask") 394 | #print(attention_mask.shape) 395 | 396 | # If a 2D or 3D attention mask is provided for the cross-attention 397 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 398 | if self.config.add_cross_attention and encoder_hidden_states is not None: 399 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 400 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 401 | if encoder_attention_mask is None: 402 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 403 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 404 | else: 405 | encoder_attention_mask = None 406 | 407 | # Prepare head mask if needed 408 | # 1.0 in head_mask indicate we keep the head 409 | # attention_probs has shape bsz x n_heads x N x N 410 | # head_mask has shape n_layer x batch x n_heads x N x N 411 | 412 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 413 | 414 | 415 | ''' 416 | if inputs_embeds is None: 417 | inputs_embeds = self.wte(input_ids) 418 | 419 | ''' 420 | 421 | #print(position_ids.shape) 422 | position_embeds = self.wpe(position_ids) 423 | hidden_states = inputs_embeds + position_embeds 424 | 425 | 426 | 427 | if token_type_ids is not None: 428 | token_type_embeds = self.wte(token_type_ids) 429 | hidden_states = hidden_states + token_type_embeds 430 | 431 | hidden_states = self.drop(hidden_states) 432 | 433 | output_shape = input_shape + (hidden_states.size(-1),) 434 | 435 | presents = () if use_cache else None 436 | all_self_attentions = () if output_attentions else None 437 | all_cross_attentions = () 
if output_attentions and self.config.add_cross_attention else None 438 | all_hidden_states = () if output_hidden_states else None 439 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 440 | 441 | # Model parallel 442 | if self.model_parallel: 443 | torch.cuda.set_device(hidden_states.device) 444 | # Ensure layer_past is on same device as hidden_states (might not be correct) 445 | if layer_past is not None: 446 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 447 | # Ensure that attention_mask is always on the same device as hidden_states 448 | if attention_mask is not None: 449 | attention_mask = attention_mask.to(hidden_states.device) 450 | if isinstance(head_mask, torch.Tensor): 451 | head_mask = head_mask.to(hidden_states.device) 452 | if output_hidden_states: 453 | all_hidden_states = all_hidden_states + (hidden_states,) 454 | 455 | if self.gradient_checkpointing and self.training: 456 | 457 | if use_cache: 458 | logger.warning( 459 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 460 | ) 461 | use_cache = False 462 | 463 | def create_custom_forward(module): 464 | def custom_forward(*inputs): 465 | # None for past_key_value 466 | return module(*inputs, use_cache, output_attentions) 467 | 468 | return custom_forward 469 | 470 | outputs = torch.utils.checkpoint.checkpoint( 471 | create_custom_forward(block), 472 | hidden_states, 473 | None, 474 | attention_mask, 475 | head_mask[i], 476 | encoder_hidden_states, 477 | encoder_attention_mask, 478 | ) 479 | else: 480 | outputs = block( 481 | hidden_states, 482 | layer_past=layer_past, 483 | attention_mask=attention_mask, 484 | head_mask=head_mask[i], 485 | encoder_hidden_states=encoder_hidden_states, 486 | encoder_attention_mask=encoder_attention_mask, 487 | use_cache=use_cache, 488 | output_attentions=output_attentions, 489 | ) 490 | 491 | hidden_states = outputs[0] 492 | if use_cache is True: 493 | presents = presents + (outputs[1],) 494 | 495 | if output_attentions: 496 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 497 | if self.config.add_cross_attention: 498 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 499 | 500 | # Model Parallel: If it's the last layer for that device, put things on the next device 501 | if self.model_parallel: 502 | for k, v in self.device_map.items(): 503 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 504 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 505 | 506 | hidden_states = self.ln_f(hidden_states) 507 | 508 | hidden_states = hidden_states.view(output_shape) 509 | # Add last hidden state 510 | if output_hidden_states: 511 | all_hidden_states = all_hidden_states + (hidden_states,) 512 | 513 | if not return_dict: 514 | return tuple( 515 | v 516 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] 517 | if v is not None 518 | ) 519 | 520 | return BaseModelOutputWithPastAndCrossAttentions( 521 | last_hidden_state=hidden_states, 522 | past_key_values=presents, 523 | hidden_states=all_hidden_states, 524 | attentions=all_self_attentions, 525 | cross_attentions=all_cross_attentions, 526 | ) 527 | 528 | 529 | @add_start_docstrings( 530 | """ 531 | The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input 532 | embeddings). 
533 | """, 534 | GPT2_START_DOCSTRING, 535 | ) 536 | class GPT2LMHeadModel(GPT2PreTrainedModel): 537 | _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] 538 | 539 | def __init__(self, config): 540 | super().__init__(config) 541 | self.transformer = GPT2Model(config) 542 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 543 | 544 | self.pre_seq_len = config.pre_seq_len 545 | 546 | # Model parallel 547 | self.model_parallel = False 548 | self.device_map = None 549 | 550 | # Initialize weights and apply final processing 551 | self.post_init() 552 | 553 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 554 | def parallelize(self, device_map=None): 555 | self.device_map = ( 556 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 557 | if device_map is None 558 | else device_map 559 | ) 560 | assert_device_map(self.device_map, len(self.transformer.h)) 561 | self.transformer.parallelize(self.device_map) 562 | self.lm_head = self.lm_head.to(self.transformer.first_device) 563 | self.model_parallel = True 564 | 565 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 566 | def deparallelize(self): 567 | self.transformer.deparallelize() 568 | self.transformer = self.transformer.to("cpu") 569 | self.lm_head = self.lm_head.to("cpu") 570 | self.model_parallel = False 571 | torch.cuda.empty_cache() 572 | 573 | def get_output_embeddings(self): 574 | return self.lm_head 575 | 576 | def set_output_embeddings(self, new_embeddings): 577 | self.lm_head = new_embeddings 578 | 579 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 580 | token_type_ids = kwargs.get("token_type_ids", None) 581 | # only last token for inputs_ids if past is defined in kwargs 582 | if past: 583 | input_ids = input_ids[:, -1].unsqueeze(-1) 584 | if token_type_ids is not None: 585 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 586 | 587 | attention_mask = kwargs.get("attention_mask", None) 588 | position_ids = kwargs.get("position_ids", None) 589 | 590 | if attention_mask is not None and position_ids is None: 591 | # create position_ids on the fly for batch generation 592 | position_ids = attention_mask.long().cumsum(-1) - 1 593 | position_ids.masked_fill_(attention_mask == 0, 1) 594 | if past: 595 | position_ids = position_ids[:, -1].unsqueeze(-1) 596 | else: 597 | position_ids = None 598 | 599 | return { 600 | "input_ids": input_ids, 601 | "past_key_values": past, 602 | "use_cache": kwargs.get("use_cache"), 603 | "position_ids": position_ids, 604 | "attention_mask": attention_mask, 605 | "token_type_ids": token_type_ids, 606 | } 607 | 608 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 609 | @add_code_sample_docstrings( 610 | processor_class=_TOKENIZER_FOR_DOC, 611 | checkpoint=_CHECKPOINT_FOR_DOC, 612 | output_type=CausalLMOutputWithCrossAttentions, 613 | config_class=_CONFIG_FOR_DOC, 614 | ) 615 | def forward( 616 | self, 617 | input_ids: Optional[torch.LongTensor] = None, 618 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 619 | attention_mask: Optional[torch.FloatTensor] = None, 620 | token_type_ids: Optional[torch.LongTensor] = None, 621 | position_ids: Optional[torch.LongTensor] = None, 622 | head_mask: Optional[torch.FloatTensor] = None, 623 | inputs_embeds: Optional[torch.FloatTensor] = None, 624 | encoder_hidden_states: Optional[torch.Tensor] = None, 625 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 626 | labels: Optional[torch.LongTensor] = None, 627 | use_cache: 
Optional[bool] = None, 628 | output_attentions: Optional[bool] = None, 629 | output_hidden_states: Optional[bool] = None, 630 | return_dict: Optional[bool] = None, 631 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: 632 | 633 | 634 | 635 | r""" 636 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 637 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 638 | `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` 639 | are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` 640 | """ 641 | 642 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 643 | 644 | 645 | transformer_outputs = self.transformer( 646 | input_ids, 647 | past_key_values=past_key_values, 648 | attention_mask=attention_mask, 649 | token_type_ids=token_type_ids, 650 | position_ids=position_ids, 651 | head_mask=head_mask, 652 | inputs_embeds=inputs_embeds, 653 | encoder_hidden_states=encoder_hidden_states, 654 | encoder_attention_mask=encoder_attention_mask, 655 | use_cache=use_cache, 656 | output_attentions=output_attentions, 657 | output_hidden_states=output_hidden_states, 658 | return_dict=return_dict, 659 | ) 660 | hidden_states = transformer_outputs[0] 661 | 662 | 663 | if past_key_values is None: 664 | hidden_states = hidden_states[:, self.pre_seq_len:, :].contiguous() 665 | else: 666 | hidden_states = hidden_states 667 | 668 | 669 | # Set device for model parallelism 670 | if self.model_parallel: 671 | torch.cuda.set_device(self.transformer.first_device) 672 | hidden_states = hidden_states.to(self.lm_head.weight.device) 673 | 674 | lm_logits = self.lm_head(hidden_states) 675 | 676 | loss = None 677 | if labels is not None: 678 | # Shift so that tokens < n predict n 679 | shift_logits = lm_logits[..., :-1, :].contiguous() 680 | shift_labels = labels[..., 1:].contiguous() 681 | # Flatten the tokens 682 | loss_fct = CrossEntropyLoss() 683 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 684 | 685 | if not return_dict: 686 | output = (lm_logits,) + transformer_outputs[1:] 687 | return ((loss,) + output) if loss is not None else output 688 | 689 | return CausalLMOutputWithCrossAttentions( 690 | loss=loss, 691 | logits=lm_logits, 692 | past_key_values=transformer_outputs.past_key_values, 693 | hidden_states=transformer_outputs.hidden_states, 694 | attentions=transformer_outputs.attentions, 695 | cross_attentions=transformer_outputs.cross_attentions, 696 | ) 697 | 698 | @staticmethod 699 | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: 700 | """ 701 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 702 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 703 | beam_idx at every generation step. 704 | """ 705 | return tuple( 706 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 707 | for layer_past in past 708 | ) -------------------------------------------------------------------------------- /run_gpt_prompt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fine-tuning the library models for sequence to sequence. 
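Adapted here for prompt (prefix) tuning of decoder-only GPT-2 style models: the custom
model.modeling_gpt_prompt.GPT2LMHeadModel is trained with most backbone weights frozen.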
3 | """ 4 | 5 | import logging 6 | import os 7 | import sys 8 | from dataclasses import dataclass, field 9 | from typing import Optional 10 | 11 | import datasets 12 | import nltk # Here to have a nice missing dependency error message early on 13 | import numpy as np 14 | from datasets import load_dataset, load_metric 15 | 16 | import transformers 17 | 18 | from filelock import FileLock 19 | 20 | from transformers import GPT2Tokenizer, AutoModelForCausalLM 21 | 22 | 23 | from model.modeling_gpt_prompt import GPT2LMHeadModel 24 | from transformers import Trainer, TrainingArguments 25 | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions 26 | 27 | 28 | from transformers.file_utils import is_offline_mode 29 | from transformers.trainer_utils import get_last_checkpoint 30 | from transformers.utils import check_min_version 31 | from transformers.utils.versions import require_version 32 | from transformers import HfArgumentParser, Seq2SeqTrainingArguments, set_seed, AutoConfig, AutoTokenizer 33 | from transformers import DataCollatorWithPadding 34 | 35 | 36 | 37 | 38 | from typing import Any, Dict, List, Optional, Tuple, Union 39 | 40 | 41 | import torch 42 | from torch import nn 43 | from torch.utils.data import Dataset 44 | 45 | from transformers.deepspeed import is_deepspeed_zero3_enabled 46 | from transformers.trainer import Trainer 47 | from transformers.trainer_utils import PredictionOutput 48 | import jsonlines 49 | from metric.evaluation import compute_bleu 50 | #from metric.cal_rouge import compute_rouges 51 | #from metric.cal_bleu import compute_bleus 52 | logger = logging.getLogger(__name__) 53 | 54 | class GPTTrainer(Trainer): 55 | 56 | def train_addparam( 57 | self, 58 | max_length: Optional[int] = None, 59 | num_beams: Optional[int] = None, 60 | ): 61 | self._max_length = max_length if max_length is not None else self.args.generation_max_length 62 | self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams 63 | 64 | 65 | def evaluate( 66 | self, 67 | eval_dataset: Optional[Dataset] = None, 68 | ignore_keys: Optional[List[str]] = None, 69 | metric_key_prefix: str = "eval", 70 | max_length: Optional[int] = None, 71 | num_beams: Optional[int] = None, 72 | ) -> Dict[str, float]: 73 | 74 | #if hasattr(self, ) 75 | if not hasattr(self, '_max_length'): 76 | self._max_length = max_length if max_length is not None else self.args.generation_max_length 77 | if not hasattr(self, '_num_beams'): 78 | self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams 79 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 80 | 81 | 82 | 83 | def predict( 84 | self, 85 | test_dataset: Dataset, 86 | ignore_keys: Optional[List[str]] = None, 87 | metric_key_prefix: str = "test", 88 | max_length: Optional[int] = None, 89 | num_beams: Optional[int] = None, 90 | ) -> PredictionOutput: 91 | 92 | 93 | 94 | self._max_length = max_length if max_length is not None else self.args.generation_max_length 95 | self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams 96 | 97 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 98 | 99 | 100 | ''' 101 | def compute_loss(self, model, inputs): 102 | 103 | outputs = model(input_ids=inputs['input_ids'], 104 | labels=inputs['input_ids'], 105 | attention_mask=inputs['attention_mask']) 106 | 107 | 108 | # Save past state if it exists 109 | # TODO: this needs to be fixed 
and made cleaner later. 110 | 111 | if self.args.past_index >= 0: 112 | self._past = outputs[self.args.past_index] 113 | 114 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 115 | 116 | return outputs["loss"] if isinstance(outputs, dict) else outputs[0] 117 | 118 | ''' 119 | 120 | def prediction_step( 121 | self, 122 | model: nn.Module, 123 | inputs: Dict[str, Union[torch.Tensor, Any]], 124 | prediction_loss_only: bool, 125 | ignore_keys: Optional[List[str]] = None, 126 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 127 | 128 | if not self.args.predict_with_generate or prediction_loss_only: 129 | return super().prediction_step( 130 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 131 | ) 132 | 133 | #has_labels = all(inputs.get(k) is not None for k in self.label_names) 134 | 135 | has_labels = "labels" in inputs 136 | inputs = self._prepare_inputs(inputs) 137 | 138 | # XXX: adapt synced_gpus for fairscale as well 139 | gen_kwargs = { 140 | "max_length": self._max_length if self._max_length is not None else self.model.config.max_length, 141 | "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, 142 | "synced_gpus": True if is_deepspeed_zero3_enabled() else False, 143 | } 144 | 145 | generated_tokens = self.model.generate( 146 | inputs["input_ids"], 147 | attention_mask=inputs.get("attention_mask", None), 148 | pad_token_id=self.tokenizer.eos_token_id, 149 | **gen_kwargs, 150 | ) 151 | 152 | if len(generated_tokens) == 1: 153 | generated_tokens = generated_tokens[0] 154 | 155 | if generated_tokens.shape[-1] <= gen_kwargs["max_length"]: 156 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 157 | 158 | 159 | ''' 160 | with torch.no_grad(): 161 | with self.autocast_smart_context_manager(): 162 | outputs = model(input_ids=inputs['input_ids'], 163 | labels=inputs['input_ids']) 164 | if has_labels: 165 | if self.label_smoother is not None: 166 | loss = self.label_smoother(outputs, inputs["input_ids"]).mean().detach() 167 | else: 168 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 169 | else: 170 | loss = None 171 | ''' 172 | 173 | with torch.no_grad(): 174 | with self.autocast_smart_context_manager(): 175 | outputs = model(**inputs) 176 | if has_labels: 177 | if self.label_smoother is not None: 178 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 179 | else: 180 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 181 | else: 182 | loss = None 183 | 184 | 185 | 186 | if self.args.prediction_loss_only: 187 | return (loss, None, None) 188 | 189 | if has_labels: 190 | labels = inputs["labels"] 191 | else: 192 | labels = None 193 | 194 | return (loss, generated_tokens, labels) 195 | 196 | def _pad_tensors_to_max_len(self, tensor, max_length): 197 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): 198 | # If PAD token is not defined at least EOS token has to be defined 199 | pad_token_id = ( 200 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 201 | ) 202 | else: 203 | if self.model.config.pad_token_id is not None: 204 | pad_token_id = self.model.config.pad_token_id 205 | else: 206 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") 207 | 208 | padded_tensor = pad_token_id * torch.ones( 
209 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 210 | ) 211 | padded_tensor[:, : tensor.shape[-1]] = tensor 212 | return padded_tensor 213 | 214 | 215 | 216 | @dataclass 217 | class ModelArguments: 218 | """ 219 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 220 | """ 221 | 222 | model_name_or_path: str = field( 223 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 224 | ) 225 | config_name: Optional[str] = field( 226 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 227 | ) 228 | tokenizer_name: Optional[str] = field( 229 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 230 | ) 231 | cache_dir: Optional[str] = field( 232 | default=None, 233 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 234 | ) 235 | use_fast_tokenizer: bool = field( 236 | default=True, 237 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 238 | ) 239 | model_revision: str = field( 240 | default="main", 241 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 242 | ) 243 | use_auth_token: bool = field( 244 | default=False, 245 | metadata={ 246 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 247 | "with private models)." 248 | }, 249 | ) 250 | resize_position_embeddings: Optional[bool] = field( 251 | default=None, 252 | metadata={ 253 | "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 254 | "the model's position embeddings." 255 | }, 256 | ) 257 | model_name: str = field( 258 | default="t5-pegasus", 259 | metadata={"help" : "model name"}, 260 | ) 261 | pre_seq_len: Optional[int] = field( 262 | default=200, 263 | metadata={ 264 | "help": "length of prefix" 265 | 266 | } 267 | ) 268 | prefix_drop: Optional[float] = field( 269 | default=0.1, 270 | metadata={ 271 | "help": "prefix dropout rate" 272 | 273 | } 274 | ) 275 | 276 | 277 | 278 | @dataclass 279 | class DataTrainingArguments: 280 | """ 281 | Arguments pertaining to what data we are going to input our model for training and eval. 282 | """ 283 | 284 | lang: str = field(default=None, metadata={"help": "Language id for summarization."}) 285 | 286 | dataset_name: Optional[str] = field( 287 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 288 | ) 289 | dataset_config_name: Optional[str] = field( 290 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 291 | ) 292 | text_column: Optional[str] = field( 293 | default=None, 294 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 295 | ) 296 | summary_column: Optional[str] = field( 297 | default=None, 298 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 299 | ) 300 | train_file: Optional[str] = field( 301 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 302 | ) 303 | validation_file: Optional[str] = field( 304 | default=None, 305 | metadata={ 306 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 307 | "(a jsonlines or csv file)." 
308 | }, 309 | ) 310 | test_file: Optional[str] = field( 311 | default=None, 312 | metadata={ 313 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 314 | }, 315 | ) 316 | overwrite_cache: bool = field( 317 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 318 | ) 319 | preprocessing_num_workers: Optional[int] = field( 320 | default=None, 321 | metadata={"help": "The number of processes to use for the preprocessing."}, 322 | ) 323 | max_source_length: Optional[int] = field( 324 | default=512, 325 | metadata={ 326 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 327 | "than this will be truncated, sequences shorter will be padded." 328 | }, 329 | ) 330 | max_target_length: Optional[int] = field( 331 | default=128, 332 | metadata={ 333 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 334 | "than this will be truncated, sequences shorter will be padded." 335 | }, 336 | ) 337 | val_max_target_length: Optional[int] = field( 338 | default=None, 339 | metadata={ 340 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 341 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 342 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 343 | "during ``evaluate`` and ``predict``." 344 | }, 345 | ) 346 | pad_to_max_length: bool = field( 347 | default=False, 348 | metadata={ 349 | "help": "Whether to pad all samples to model maximum sentence length. " 350 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 351 | "efficient on GPU but very bad for TPU." 352 | }, 353 | ) 354 | max_train_samples: Optional[int] = field( 355 | default=None, 356 | metadata={ 357 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 358 | "value if set." 359 | }, 360 | ) 361 | max_eval_samples: Optional[int] = field( 362 | default=None, 363 | metadata={ 364 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 365 | "value if set." 366 | }, 367 | ) 368 | max_predict_samples: Optional[int] = field( 369 | default=None, 370 | metadata={ 371 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 372 | "value if set." 373 | }, 374 | ) 375 | num_beams: Optional[int] = field( 376 | default=None, 377 | metadata={ 378 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 379 | "which is used during ``evaluate`` and ``predict``." 380 | }, 381 | ) 382 | ignore_pad_token_for_loss: bool = field( 383 | default=True, 384 | metadata={ 385 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 386 | }, 387 | ) 388 | source_prefix: Optional[str] = field( 389 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 390 | ) 391 | 392 | forced_bos_token: Optional[str] = field( 393 | default=None, 394 | metadata={ 395 | "help": "The token to force as the first generated token after the decoder_start_token_id." 
396 | "Useful for multilingual models like mBART where the first generated token" 397 | "needs to be the target language token (Usually it is the target language token)" 398 | }, 399 | ) 400 | 401 | def __post_init__(self): 402 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 403 | raise ValueError("Need either a dataset name or a training/validation file.") 404 | else: 405 | if self.train_file is not None: 406 | extension = self.train_file.split(".")[-1] 407 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 408 | if self.validation_file is not None: 409 | extension = self.validation_file.split(".")[-1] 410 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 411 | if self.val_max_target_length is None: 412 | self.val_max_target_length = self.max_target_length 413 | 414 | 415 | summarization_name_mapping = { 416 | "amazon_reviews_multi": ("review_body", "review_title"), 417 | "big_patent": ("description", "abstract"), 418 | "cnn_dailymail": ("article", "highlights"), 419 | "orange_sum": ("text", "summary"), 420 | "pn_summary": ("article", "summary"), 421 | "psc": ("extract_text", "summary_text"), 422 | "samsum": ("dialogue", "summary"), 423 | "thaisum": ("body", "summary"), 424 | "xglue": ("news_body", "news_title"), 425 | "xsum": ("document", "summary"), 426 | "wiki_summary": ("article", "highlights"), 427 | } 428 | 429 | def main(): 430 | os.environ["WANDB_DISABLED"] = "true" 431 | # See all possible arguments in src/transformers/training_args.py 432 | # or by passing the --help flag to this script. 433 | # We now keep distinct sets of args, for a cleaner separation of concerns. 434 | 435 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 436 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 437 | # If we pass only one argument to the script and it's the path to a json file, 438 | # let's parse it to get our arguments. 439 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 440 | else: 441 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 442 | 443 | # Setup logging 444 | logging.basicConfig( 445 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 446 | datefmt="%m/%d/%Y %H:%M:%S", 447 | handlers=[logging.StreamHandler(sys.stdout)], 448 | ) 449 | log_level = training_args.get_process_log_level() 450 | logger.setLevel(log_level) 451 | datasets.utils.logging.set_verbosity(log_level) 452 | transformers.utils.logging.set_verbosity(log_level) 453 | transformers.utils.logging.enable_default_handler() 454 | transformers.utils.logging.enable_explicit_format() 455 | 456 | # Log on each process the small summary: 457 | logger.warning( 458 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 459 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 460 | ) 461 | logger.info(f"Training/evaluation parameters {training_args}") 462 | 463 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 464 | "t5-small", 465 | "t5-base", 466 | "t5-large", 467 | "t5-3b", 468 | "t5-11b", 469 | ]: 470 | logger.warning( 471 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 472 | "`--source_prefix 'summarize: ' `" 473 | ) 474 | 475 | # Detecting last checkpoint. 
476 | last_checkpoint = None 477 | if not os.path.exists(training_args.output_dir): 478 | os.makedirs(training_args.output_dir) 479 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 480 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 481 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 482 | raise ValueError( 483 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 484 | "Use --overwrite_output_dir to overcome." 485 | ) 486 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 487 | logger.info( 488 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 489 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 490 | ) 491 | 492 | # Set seed before initializing model. 493 | set_seed(training_args.seed) 494 | 495 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 496 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 497 | # (the dataset will be downloaded automatically from the datasets Hub). 498 | # 499 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 500 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 501 | # 502 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 503 | # download the dataset. 504 | if data_args.dataset_name is not None: 505 | # Downloading and loading a dataset from the hub. 506 | raw_datasets = load_dataset( 507 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 508 | ) 509 | else: 510 | data_files = {} 511 | if data_args.train_file is not None: 512 | data_files["train"] = data_args.train_file 513 | extension = data_args.train_file.split(".")[-1] 514 | if data_args.validation_file is not None: 515 | data_files["validation"] = data_args.validation_file 516 | extension = data_args.validation_file.split(".")[-1] 517 | if data_args.test_file is not None: 518 | data_files["test"] = data_args.test_file 519 | extension = data_args.test_file.split(".")[-1] 520 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 521 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 522 | # https://huggingface.co/docs/datasets/loading_datasets.html. 523 | 524 | # Load pretrained model and tokenizer 525 | # 526 | # Distributed training: 527 | # The .from_pretrained methods guarantee that only one local process can concurrently 528 | # download model & vocab. 
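# The prefix hyper-parameters (--pre_seq_len, --prefix_drop) are attached to the GPT-2 config
# below so that GPT2LMHeadModel from model/modeling_gpt_prompt.py can size its prefix embedding
# and dropout from them.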
529 | 530 | config = AutoConfig.from_pretrained( 531 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 532 | cache_dir=model_args.cache_dir, 533 | revision=model_args.model_revision, 534 | use_auth_token=True if model_args.use_auth_token else None, 535 | ) 536 | config.pre_seq_len = model_args.pre_seq_len 537 | config.prefix_drop = model_args.prefix_drop 538 | 539 | tokenizer = AutoTokenizer.from_pretrained( 540 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 541 | cache_dir=model_args.cache_dir, 542 | use_fast=model_args.use_fast_tokenizer, 543 | revision=model_args.model_revision, 544 | use_auth_token=True if model_args.use_auth_token else None, 545 | ) 546 | 547 | #tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token}) 548 | 549 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 550 | 551 | # 将pad_token替换为eos_token 552 | 553 | print(f"Vocab size: {len(tokenizer)}") 554 | 555 | model = GPT2LMHeadModel.from_pretrained( 556 | model_args.model_name_or_path, 557 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 558 | config=config, 559 | cache_dir=model_args.cache_dir, 560 | revision=model_args.model_revision, 561 | use_auth_token=True if model_args.use_auth_token else None, 562 | ) 563 | 564 | for name, param in model.named_parameters(): 565 | if 'prefix_encoder' in name: 566 | param.requires_grad = True 567 | print(name) 568 | elif 'prefix_decoder' in name: 569 | param.requires_grad = True 570 | print(name) 571 | elif 'bias' in name: 572 | param.requires_grad = True 573 | print(name) 574 | else: 575 | param.requires_grad = False 576 | print("------------------") 577 | 578 | model.resize_token_embeddings(len(tokenizer)) 579 | # Preprocessing the datasets. 580 | # We need to tokenize inputs and targets. 581 | if training_args.do_train: 582 | column_names = raw_datasets["train"].column_names 583 | elif training_args.do_eval: 584 | column_names = raw_datasets["validation"].column_names 585 | elif training_args.do_predict: 586 | column_names = raw_datasets["test"].column_names 587 | else: 588 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 589 | return 590 | 591 | # Get the column names for input/target. 592 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 593 | if data_args.text_column is None: 594 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 595 | else: 596 | text_column = data_args.text_column 597 | if text_column not in column_names: 598 | raise ValueError( 599 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 600 | ) 601 | 602 | if data_args.summary_column is None: 603 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 604 | else: 605 | summary_column = data_args.summary_column 606 | if summary_column not in column_names: 607 | raise ValueError( 608 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 609 | ) 610 | 611 | # Temporarily set max_target_length for training. 
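# Note: `padding` is derived from --pad_to_max_length here but then unconditionally overridden to
# "max_length" a few lines below, so every example is padded to max_source_length regardless of
# the flag. Also, only the prefix embeddings and bias terms were left trainable above; all other
# backbone weights stay frozen.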
612 | max_target_length = data_args.max_target_length 613 | padding = "max_length" if data_args.pad_to_max_length else False 614 | 615 | 616 | padding = "max_length" 617 | 618 | print("--------------") 619 | print(padding) 620 | 621 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 622 | logger.warning( 623 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 624 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 625 | ) 626 | 627 | def preprocess_function(examples): 628 | # remove pairs where at least one record is None 629 | 630 | inputs, targets = [], [] 631 | for i in range(len(examples[text_column])): 632 | if examples[text_column][i] is not None and examples[summary_column][i] is not None: 633 | inputs.append(examples[text_column][i]) 634 | targets.append(examples[summary_column][i]) 635 | cat_inputs = [] 636 | for i , t in zip(inputs, targets): 637 | inp = "" 638 | 639 | inp += f"{i} {tokenizer.eos_token} " 640 | inp += f"{t} {tokenizer.eos_token}" 641 | #inp = inp + i + " " + tokenizer.eos_token + " " + t + " " + + tokenizer.eos_token 642 | #print(inp) 643 | cat_inputs.append(inp) 644 | 645 | model_inputs = tokenizer(cat_inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 646 | 647 | with tokenizer.as_target_tokenizer(): 648 | labels = tokenizer(cat_inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 649 | 650 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 651 | labels["input_ids"] = [ 652 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 653 | ] 654 | 655 | model_inputs["labels"] = labels["input_ids"] 656 | #print(labels["input_ids"]) 657 | 658 | #model_inputs["labels"] = model_inputs["input_ids"] 659 | return model_inputs 660 | 661 | def preprocess_function_eval(examples): 662 | # remove pairs where at least one record is None 663 | 664 | inputs, targets = [], [] 665 | for i in range(len(examples[text_column])): 666 | if examples[text_column][i] is not None and examples[summary_column][i] is not None: 667 | inputs.append(examples[text_column][i]) 668 | targets.append(examples[summary_column][i]) 669 | cat_inputs = [] 670 | cat_labels = [] 671 | for i , t in zip(inputs, targets): 672 | inp = "" 673 | inp += f"{i} {tokenizer.eos_token} " 674 | oup = "" 675 | oup += f"{i} {tokenizer.eos_token} " 676 | oup += f"{t} {tokenizer.eos_token}" 677 | cat_inputs.append(inp) 678 | cat_labels.append(oup) 679 | model_inputs = tokenizer(cat_inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 680 | 681 | with tokenizer.as_target_tokenizer(): 682 | labels = tokenizer(cat_labels, max_length=data_args.max_source_length, padding=padding, truncation=True) 683 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 684 | labels["input_ids"] = [ 685 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 686 | ] 687 | model_inputs["labels"] = labels["input_ids"] 688 | return model_inputs 689 | 690 | 691 | 692 | if training_args.do_train: 693 | if "train" not in raw_datasets: 694 | raise ValueError("--do_train requires a train dataset") 695 | train_dataset = raw_datasets["train"] 696 | if data_args.max_train_samples is not None: 697 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 698 | with 
training_args.main_process_first(desc="train dataset map pre-processing"): 699 | train_dataset = train_dataset.map( 700 | preprocess_function, 701 | batched=True, 702 | num_proc=data_args.preprocessing_num_workers, 703 | remove_columns=column_names, 704 | load_from_cache_file=not data_args.overwrite_cache, 705 | desc="Running tokenizer on train dataset", 706 | ) 707 | 708 | if training_args.do_eval: 709 | max_target_length = data_args.val_max_target_length 710 | if "validation" not in raw_datasets: 711 | raise ValueError("--do_eval requires a validation dataset") 712 | eval_dataset = raw_datasets["validation"] 713 | if data_args.max_eval_samples is not None: 714 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 715 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 716 | eval_dataset = eval_dataset.map( 717 | preprocess_function_eval, 718 | batched=True, 719 | num_proc=data_args.preprocessing_num_workers, 720 | remove_columns=column_names, 721 | load_from_cache_file=not data_args.overwrite_cache, 722 | desc="Running tokenizer on validation dataset", 723 | ) 724 | 725 | if training_args.do_predict: 726 | max_target_length = data_args.val_max_target_length 727 | if "test" not in raw_datasets: 728 | raise ValueError("--do_predict requires a test dataset") 729 | predict_dataset = raw_datasets["test"] 730 | if data_args.max_predict_samples is not None: 731 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 732 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 733 | predict_dataset = predict_dataset.map( 734 | preprocess_function_eval, 735 | batched=True, 736 | num_proc=data_args.preprocessing_num_workers, 737 | remove_columns=column_names, 738 | load_from_cache_file=not data_args.overwrite_cache, 739 | desc="Running tokenizer on prediction dataset", 740 | ) 741 | 742 | data_collator = DataCollatorWithPadding( 743 | tokenizer, 744 | padding=True 745 | ) 746 | 747 | def compute_metrics(eval_preds): 748 | preds, labels = eval_preds 749 | if isinstance(preds, tuple): 750 | preds = preds[0] 751 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 752 | if data_args.ignore_pad_token_for_loss: 753 | # Replace -100 in the labels as we can't decode them. 
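            # (The -100 entries come from the preprocessing functions above, which
            # mask pad positions in the labels when --ignore_pad_token_for_loss is
            # set; np.where below maps them back to pad_token_id so that
            # tokenizer.batch_decode can handle them.)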
754 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 755 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 756 | 757 | # Some simple post-processing 758 | #decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 759 | 760 | #result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 761 | result = compute_bleu(decoded_preds, decoded_labels) 762 | # Extract a few results from ROUGE 763 | result = {key: value * 100 for key, value in result.items()} 764 | #result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 765 | 766 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 767 | result["gen_len"] = np.mean(prediction_lens) 768 | result = {k: round(v, 4) for k, v in result.items()} 769 | return result 770 | 771 | # Initialize our Trainer 772 | trainer = GPTTrainer( 773 | model=model, 774 | args=training_args, 775 | train_dataset=train_dataset if training_args.do_train else None, 776 | eval_dataset=eval_dataset if training_args.do_eval else None, 777 | tokenizer=tokenizer, 778 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 779 | ) 780 | max_length = ( 781 | training_args.generation_max_length - config.pre_seq_len 782 | if training_args.generation_max_length is not None 783 | else data_args.val_max_target_length 784 | ) 785 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 786 | # Training 787 | trainer.train_addparam(max_length=max_length, num_beams=num_beams) 788 | if training_args.do_train: 789 | checkpoint = None 790 | if training_args.resume_from_checkpoint is not None: 791 | checkpoint = training_args.resume_from_checkpoint 792 | elif last_checkpoint is not None: 793 | checkpoint = last_checkpoint 794 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 795 | trainer.save_model() # Saves the tokenizer too for easy upload 796 | 797 | metrics = train_result.metrics 798 | max_train_samples = ( 799 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 800 | ) 801 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 802 | 803 | trainer.log_metrics("train", metrics) 804 | trainer.save_metrics("train", metrics) 805 | trainer.save_state() 806 | 807 | # Evaluation 808 | results = {} 809 | 810 | ''' 811 | max_length = ( 812 | training_args.generation_max_length - config.pre_seq_len 813 | if training_args.generation_max_length is not None 814 | else data_args.val_max_target_length 815 | ) 816 | ''' 817 | #num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 818 | if training_args.do_eval: 819 | logger.info("*** Evaluate ***") 820 | #print(eval_dataset[0]) 821 | #metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 822 | predict_results = trainer.predict( 823 | eval_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams, 824 | ) 825 | #print(eval_dataset) 826 | metrics = predict_results.metrics 827 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 828 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 829 | 830 | predictions = [] 831 | for i, gen_ids in enumerate(predict_results.predictions): 832 | sid = int(np.sum(eval_dataset[i]['attention_mask'])) 833 | gen_ids = 
gen_ids[sid:] 834 | gen_sent = tokenizer.decode(gen_ids.tolist(), skip_special_tokens=True).strip() 835 | predictions.append(gen_sent) 836 | 837 | 838 | ''' 839 | refference_inputs = [] 840 | with open(data_args.validation_file, 'r') as r: 841 | for item in jsonlines.Reader(r): 842 | refference_inputs.append(item[text_column]) 843 | 844 | 845 | truc_predictions = [] 846 | 847 | 848 | all_predictions = tokenizer.batch_decode( 849 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 850 | ) 851 | 852 | all_output_prediction_file = os.path.join(training_args.output_dir, "all_generated_predictions_eval_beam_5.txt") 853 | with open(all_output_prediction_file, "w") as writer: 854 | writer.write("\n".join(all_predictions)) 855 | 856 | 857 | 858 | for inte, pred in zip(refference_inputs, predict_results.predictions): 859 | index = [x for x, y in list(enumerate(pred)) if y == tokenizer.eos_token_id] 860 | #inte_index = [x for x, y in list(enumerate(inte)) if y == tokenizer.eos_token_id] 861 | temp_pred = pred[index[-1]:] 862 | truc_pred = pred.copy() 863 | 864 | truc_pred[:len(temp_pred)] = temp_pred 865 | truc_predictions.append(truc_pred) 866 | 867 | 868 | 869 | predictions = tokenizer.batch_decode( 870 | truc_predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 871 | ) 872 | 873 | 874 | 875 | #predictions = [pred.replace(' ', '').strip() for pred in predictions] 876 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions_eval_beam_5.txt") 877 | with open(output_prediction_file, "w") as writer: 878 | writer.write("\n".join(predictions)) 879 | 880 | 881 | ''' 882 | 883 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions_eval_greedy_5.txt") 884 | with open(output_prediction_file, "w") as writer: 885 | writer.write("\n".join(predictions)) 886 | 887 | 888 | refference_summaries = [] 889 | with open(data_args.validation_file, 'r') as r: 890 | for item in jsonlines.Reader(r): 891 | refference_summaries.append(item[summary_column]) 892 | 893 | 894 | 895 | 896 | 897 | #rouge_scores = compute_rouges(predictions, refference_summaries) 898 | 899 | 900 | 901 | bleu_scores = compute_bleu(predictions, refference_summaries) 902 | 903 | print("Eval Metric:") 904 | #print(rouge_scores) 905 | print(bleu_scores) 906 | 907 | ''' 908 | bleu_scores1 = compute_bleus(truc_predictions, refference_summaries) 909 | 910 | print(bleu_scores1) 911 | 912 | ''' 913 | 914 | #trainer.log_metrics("eval", metrics) 915 | #trainer.save_metrics("eval", metrics) 916 | 917 | if training_args.do_predict: 918 | logger.info("*** Predict ***") 919 | 920 | predict_results = trainer.predict( 921 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams, 922 | ) 923 | metrics = predict_results.metrics 924 | max_predict_samples = ( 925 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 926 | ) 927 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 928 | 929 | #trainer.log_metrics("predict", metrics) 930 | #trainer.save_metrics("predict", metrics) 931 | 932 | if trainer.is_world_process_zero(): 933 | if training_args.predict_with_generate: 934 | predictions = tokenizer.batch_decode( 935 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 936 | ) 937 | predictions = [pred.replace(' ', '').strip() for pred in predictions] 938 | output_prediction_file = 
os.path.join(training_args.output_dir, "generated_predictions_test.txt") 939 | with open(output_prediction_file, "w") as writer: 940 | writer.write("\n".join(predictions)) 941 | refference_summaries = [] 942 | with open(data_args.test_file, 'r') as r: 943 | for item in jsonlines.Reader(r): 944 | refference_summaries.append(item[summary_column]) 945 | #rouge_scores = compute_rouges(predictions, refference_summaries) 946 | bleu_scores = compute_bleu(predictions, refference_summaries) 947 | print("Prediction Metric:") 948 | #print(rouge_scores) 949 | print(bleu_scores) 950 | 951 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 952 | if data_args.dataset_name is not None: 953 | kwargs["dataset_tags"] = data_args.dataset_name 954 | if data_args.dataset_config_name is not None: 955 | kwargs["dataset_args"] = data_args.dataset_config_name 956 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 957 | else: 958 | kwargs["dataset"] = data_args.dataset_name 959 | if data_args.lang is not None: 960 | kwargs["language"] = data_args.lang 961 | ''' 962 | if training_args.push_to_hub: 963 | trainer.push_to_hub(**kwargs) 964 | else: 965 | trainer.create_model_card(**kwargs) 966 | ''' 967 | 968 | return results 969 | 970 | if __name__ == "__main__": 971 | main() 972 | 973 | 974 | 975 | 976 | 977 | 978 | 979 | 980 | 981 | 982 | 983 | 984 | 985 | 986 | 987 | 988 | 989 | 990 | 991 | 992 | 993 | 994 | 995 | 996 | 997 | 998 | 999 | 1000 | 1001 | 1002 | 1003 | 1004 | 1005 | 1006 | 1007 | 1008 | 1009 | 1010 | -------------------------------------------------------------------------------- /model/modeling_t5.py: -------------------------------------------------------------------------------- 1 | from transformers.models.t5.modeling_t5 import T5PreTrainedModel, T5Config, T5Block, T5LayerNorm 2 | import torch 3 | from torch import nn 4 | from torch.nn import CrossEntropyLoss 5 | from torch.utils.checkpoint import checkpoint 6 | 7 | from transformers.activations import ACT2FN 8 | from transformers.file_utils import ( 9 | DUMMY_INPUTS, 10 | DUMMY_MASK, 11 | add_start_docstrings, 12 | add_start_docstrings_to_model_forward, 13 | is_torch_fx_proxy, 14 | replace_return_docstrings, 15 | ) 16 | from transformers.modeling_outputs import ( 17 | BaseModelOutput, 18 | BaseModelOutputWithPastAndCrossAttentions, 19 | Seq2SeqLMOutput, 20 | Seq2SeqModelOutput, 21 | ) 22 | from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer 23 | from transformers.utils import logging 24 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 25 | import copy 26 | import warnings 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | #################################################### 31 | # PyTorch Models are constructed by sub-classing 32 | # - torch.nn.Module for the layers and 33 | # - PreTrainedModel for the models (it-self a sub-class of nn.Module) 34 | #################################################### 35 | __HEAD_MASK_WARNING_MSG = """ 36 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, 37 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. 38 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, 39 | num_heads)`. 
40 | """ 41 | _CONFIG_FOR_DOC = "T5Config" 42 | T5_INPUTS_DOCSTRING = r""" 43 | Args: 44 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 45 | Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you 46 | should be able to pad the inputs on both the right and the left. 47 | 48 | Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and 49 | [`PreTrainedTokenizer.__call__`] for detail. 50 | 51 | [What are input IDs?](../glossary#input-ids) 52 | 53 | To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). 54 | attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 55 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 56 | 57 | - 1 for tokens that are **not masked**, 58 | - 0 for tokens that are **masked**. 59 | 60 | [What are attention masks?](../glossary#attention-mask) 61 | decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 62 | Indices of decoder input sequence tokens in the vocabulary. 63 | 64 | Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and 65 | [`PreTrainedTokenizer.__call__`] for details. 66 | 67 | [What are decoder input IDs?](../glossary#decoder-input-ids) 68 | 69 | T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` 70 | is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). 71 | 72 | To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 73 | Training](./t5#training). 74 | decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 75 | Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also 76 | be used by default. 77 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 78 | Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, 79 | 1]`: 80 | 81 | - 1 indicates the head is **not masked**, 82 | - 0 indicates the head is **masked**. 83 | 84 | decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 85 | Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, 86 | 1]`: 87 | 88 | - 1 indicates the head is **not masked**, 89 | - 0 indicates the head is **masked**. 90 | 91 | cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 92 | Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in 93 | `[0, 1]`: 94 | 95 | - 1 indicates the head is **not masked**, 96 | - 0 indicates the head is **masked**. 97 | 98 | encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): 99 | Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) 100 | `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at 101 | the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
102 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 103 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 104 | 105 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 106 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 107 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 108 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 109 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 110 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 111 | model's internal embedding lookup matrix. 112 | decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): 113 | Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded 114 | representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be 115 | input (see `past_key_values`). This is useful if you want more control over how to convert 116 | `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 117 | 118 | If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value 119 | of `inputs_embeds`. 120 | 121 | use_cache (`bool`, *optional*): 122 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 123 | `past_key_values`). 124 | 125 | output_attentions (`bool`, *optional*): 126 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 127 | tensors for more detail. 128 | output_hidden_states (`bool`, *optional*): 129 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 130 | more detail. 131 | return_dict (`bool`, *optional*): 132 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 133 | """ 134 | PARALLELIZE_DOCSTRING = r""" 135 | This is an experimental feature and is a subject to change at a moment's notice. 136 | 137 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given, 138 | it will evenly distribute blocks across all devices. 139 | 140 | Args: 141 | device_map (`Dict[int, list]`, optional, defaults to None): 142 | A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always 143 | automatically mapped to the first device (for esoteric reasons). That means that the first device should 144 | have fewer attention modules mapped to it than other devices. 
For reference, the t5 models have the 145 | following number of attention modules: 146 | 147 | - t5-small: 6 148 | - t5-base: 12 149 | - t5-large: 24 150 | - t5-3b: 24 151 | - t5-11b: 24 152 | 153 | Example: 154 | 155 | ```python 156 | # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: 157 | model = T5ForConditionalGeneration.from_pretrained("t5-3b") 158 | device_map = { 159 | 0: [0, 1, 2], 160 | 1: [3, 4, 5, 6, 7, 8, 9], 161 | 2: [10, 11, 12, 13, 14, 15, 16], 162 | 3: [17, 18, 19, 20, 21, 22, 23], 163 | } 164 | model.parallelize(device_map) 165 | ``` 166 | """ 167 | DEPARALLELIZE_DOCSTRING = r""" 168 | Moves the model to cpu from a model parallel state. 169 | 170 | Example: 171 | 172 | ```python 173 | # On a 4 GPU machine with t5-3b: 174 | model = T5ForConditionalGeneration.from_pretrained("t5-3b") 175 | device_map = { 176 | 0: [0, 1, 2], 177 | 1: [3, 4, 5, 6, 7, 8, 9], 178 | 2: [10, 11, 12, 13, 14, 15, 16], 179 | 3: [17, 18, 19, 20, 21, 22, 23], 180 | } 181 | model.parallelize(device_map) # Splits the model across several devices 182 | model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() 183 | ``` 184 | """ 185 | 186 | class T5EncoderStack(T5PreTrainedModel): 187 | def __init__(self, config, embed_tokens=None): 188 | super().__init__(config) 189 | 190 | self.embed_tokens = embed_tokens 191 | self.is_decoder = config.is_decoder 192 | 193 | self.block = nn.ModuleList( 194 | [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] 195 | ) 196 | self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 197 | self.dropout = nn.Dropout(config.dropout_rate) 198 | self.pre_seq_len = config.pre_seq_len 199 | self.n_layer = config.num_layers 200 | self.n_head = config.num_heads 201 | self.n_embd = config.d_model // self.n_head 202 | 203 | self.prefix_tokens = torch.arange(self.pre_seq_len).long() 204 | self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.d_model) 205 | 206 | 207 | 208 | 209 | # Initialize weights and apply final processing 210 | self.post_init() 211 | # Model parallel 212 | self.model_parallel = False 213 | self.device_map = None 214 | self.gradient_checkpointing = False 215 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 216 | def parallelize(self, device_map=None): 217 | # Check validity of device_map 218 | self.device_map = ( 219 | get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map 220 | ) 221 | assert_device_map(self.device_map, len(self.block)) 222 | self.model_parallel = True 223 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 224 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 225 | # Load onto devices 226 | for k, v in self.device_map.items(): 227 | for layer in v: 228 | cuda_device = "cuda:" + str(k) 229 | self.block[layer] = self.block[layer].to(cuda_device) 230 | 231 | # Set embed_tokens to first layer 232 | self.embed_tokens = self.embed_tokens.to(self.first_device) 233 | # Set final layer norm to last device 234 | self.final_layer_norm = self.final_layer_norm.to(self.last_device) 235 | 236 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 237 | def deparallelize(self): 238 | self.model_parallel = False 239 | self.device_map = None 240 | self.first_device = "cpu" 241 | self.last_device = "cpu" 242 | for i in range(len(self.block)): 243 | self.block[i] = 
self.block[i].to("cpu") 244 | self.embed_tokens = self.embed_tokens.to("cpu") 245 | self.final_layer_norm = self.final_layer_norm.to("cpu") 246 | torch.cuda.empty_cache() 247 | def get_prompt_encoder(self, batch_size): 248 | prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device) 249 | prompts = self.prefix_encoder(prefix_tokens) 250 | return prompts 251 | 252 | def get_input_embeddings(self): 253 | return self.embed_tokens 254 | 255 | def set_input_embeddings(self, new_embeddings): 256 | self.embed_tokens = new_embeddings 257 | 258 | def forward( 259 | self, 260 | input_ids=None, 261 | attention_mask=None, 262 | encoder_hidden_states=None, 263 | encoder_attention_mask=None, 264 | inputs_embeds=None, 265 | head_mask=None, 266 | cross_attn_head_mask=None, 267 | past_key_values=None, 268 | use_cache=None, 269 | output_attentions=None, 270 | output_hidden_states=None, 271 | return_dict=None, 272 | ): 273 | # Model parallel 274 | if self.model_parallel: 275 | torch.cuda.set_device(self.first_device) 276 | self.embed_tokens = self.embed_tokens.to(self.first_device) 277 | use_cache = use_cache if use_cache is not None else self.config.use_cache 278 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 279 | output_hidden_states = ( 280 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 281 | ) 282 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 283 | 284 | if input_ids is not None and inputs_embeds is not None: 285 | err_msg_prefix = "decoder_" if self.is_decoder else "" 286 | raise ValueError( 287 | f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" 288 | ) 289 | elif input_ids is not None: 290 | input_shape = input_ids.size() 291 | input_ids = input_ids.view(-1, input_shape[-1]) 292 | elif inputs_embeds is not None: 293 | input_shape = inputs_embeds.size()[:-1] 294 | else: 295 | err_msg_prefix = "decoder_" if self.is_decoder else "" 296 | raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") 297 | 298 | if inputs_embeds is None: 299 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" 300 | batch_size = input_ids.shape[0] 301 | raw_embeds = self.embed_tokens(input_ids) 302 | prompts_encoder = self.get_prompt_encoder(batch_size=batch_size) 303 | inputs_embeds = torch.cat((prompts_encoder, raw_embeds), dim=1) 304 | #print(inputs_embeds.shape) 305 | inputs_prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.device) 306 | attention_mask = torch.cat((inputs_prefix_attention_mask, attention_mask), dim=1) 307 | input_shape = inputs_embeds.size()[:-1] 308 | 309 | batch_size, seq_length = input_shape 310 | 311 | # required mask seq length can be calculated via length of past 312 | mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length 313 | 314 | if use_cache is True: 315 | assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" 316 | 317 | if attention_mask is None: 318 | attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) 319 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: 320 | encoder_seq_length = encoder_hidden_states.shape[1] 321 | encoder_attention_mask = torch.ones( 322 | batch_size, 
encoder_seq_length, device=inputs_embeds.device, dtype=torch.long 323 | ) 324 | 325 | # initialize past_key_values with `None` if past does not exist 326 | if past_key_values is None: 327 | past_key_values = [None] * len(self.block) 328 | 329 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 330 | # ourselves in which case we just need to make it broadcastable to all heads. 331 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) 332 | 333 | # If a 2D or 3D attention mask is provided for the cross-attention 334 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 335 | if self.is_decoder and encoder_hidden_states is not None: 336 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 337 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 338 | if encoder_attention_mask is None: 339 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) 340 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 341 | else: 342 | encoder_extended_attention_mask = None 343 | 344 | # Prepare head mask if needed 345 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 346 | cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) 347 | present_key_value_states = () if use_cache else None 348 | all_hidden_states = () if output_hidden_states else None 349 | all_attentions = () if output_attentions else None 350 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None 351 | position_bias = None 352 | encoder_decoder_position_bias = None 353 | 354 | hidden_states = self.dropout(inputs_embeds) 355 | 356 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): 357 | layer_head_mask = head_mask[i] 358 | cross_attn_layer_head_mask = cross_attn_head_mask[i] 359 | # Model parallel 360 | if self.model_parallel: 361 | torch.cuda.set_device(hidden_states.device) 362 | # Ensure that attention_mask is always on the same device as hidden_states 363 | if attention_mask is not None: 364 | attention_mask = attention_mask.to(hidden_states.device) 365 | if position_bias is not None: 366 | position_bias = position_bias.to(hidden_states.device) 367 | if encoder_hidden_states is not None: 368 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) 369 | if encoder_extended_attention_mask is not None: 370 | encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) 371 | if encoder_decoder_position_bias is not None: 372 | encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) 373 | if layer_head_mask is not None: 374 | layer_head_mask = layer_head_mask.to(hidden_states.device) 375 | if cross_attn_layer_head_mask is not None: 376 | cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) 377 | if output_hidden_states: 378 | all_hidden_states = all_hidden_states + (hidden_states,) 379 | 380 | if self.gradient_checkpointing and self.training: 381 | if use_cache: 382 | logger.warn( 383 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
384 | ) 385 | use_cache = False 386 | 387 | def create_custom_forward(module): 388 | def custom_forward(*inputs): 389 | return tuple(module(*inputs, use_cache, output_attentions)) 390 | 391 | return custom_forward 392 | 393 | layer_outputs = checkpoint( 394 | create_custom_forward(layer_module), 395 | hidden_states, 396 | extended_attention_mask, 397 | position_bias, 398 | encoder_hidden_states, 399 | encoder_extended_attention_mask, 400 | encoder_decoder_position_bias, 401 | layer_head_mask, 402 | cross_attn_layer_head_mask, 403 | None, # past_key_value is always None with gradient checkpointing 404 | ) 405 | else: 406 | layer_outputs = layer_module( 407 | hidden_states, 408 | attention_mask=extended_attention_mask, 409 | position_bias=position_bias, 410 | encoder_hidden_states=encoder_hidden_states, 411 | encoder_attention_mask=encoder_extended_attention_mask, 412 | encoder_decoder_position_bias=encoder_decoder_position_bias, 413 | layer_head_mask=layer_head_mask, 414 | cross_attn_layer_head_mask=cross_attn_layer_head_mask, 415 | past_key_value=past_key_value, 416 | use_cache=use_cache, 417 | output_attentions=output_attentions, 418 | ) 419 | # layer_outputs is a tuple with: 420 | # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 421 | if use_cache is False: 422 | layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] 423 | 424 | hidden_states, present_key_value_state = layer_outputs[:2] 425 | 426 | # We share the position biases between the layers - the first layer store them 427 | # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), 428 | # (cross-attention position bias), (cross-attention weights) 429 | position_bias = layer_outputs[2] 430 | if self.is_decoder and encoder_hidden_states is not None: 431 | encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] 432 | # append next layer key value states 433 | if use_cache: 434 | present_key_value_states = present_key_value_states + (present_key_value_state,) 435 | 436 | if output_attentions: 437 | all_attentions = all_attentions + (layer_outputs[3],) 438 | if self.is_decoder: 439 | all_cross_attentions = all_cross_attentions + (layer_outputs[5],) 440 | 441 | # Model Parallel: If it's the last layer for that device, put things on the next device 442 | if self.model_parallel: 443 | for k, v in self.device_map.items(): 444 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 445 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 446 | hidden_states = self.final_layer_norm(hidden_states) 447 | hidden_states = self.dropout(hidden_states) 448 | 449 | # Add last layer 450 | if output_hidden_states: 451 | all_hidden_states = all_hidden_states + (hidden_states,) 452 | 453 | if not return_dict: 454 | return tuple( 455 | v 456 | for v in [ 457 | hidden_states, 458 | present_key_value_states, 459 | all_hidden_states, 460 | all_attentions, 461 | all_cross_attentions, 462 | ] 463 | if v is not None 464 | ) 465 | 466 | return BaseModelOutputWithPastAndCrossAttentions( 467 | last_hidden_state=hidden_states, 468 | past_key_values=present_key_value_states, 469 | hidden_states=all_hidden_states, 470 | attentions=all_attentions, 471 | cross_attentions=all_cross_attentions, 472 | ) 473 | 474 | 475 | 476 | class T5DecoderStack(T5PreTrainedModel): 477 | def __init__(self, config, embed_tokens=None): 478 | super().__init__(config) 479 | 480 | self.embed_tokens = 
embed_tokens 481 | self.is_decoder = config.is_decoder 482 | 483 | self.block = nn.ModuleList( 484 | [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] 485 | ) 486 | self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 487 | self.dropout = nn.Dropout(config.dropout_rate) 488 | 489 | self.pre_seq_len = config.pre_seq_len 490 | self.n_layer = config.num_layers 491 | self.n_head = config.num_heads 492 | self.n_embd = config.d_model // self.n_head 493 | 494 | self.prefix_tokens = torch.arange(self.pre_seq_len).long() 495 | self.prefix_decoder = torch.nn.Embedding(self.pre_seq_len, config.d_model) 496 | 497 | # Initialize weights and apply final processing 498 | self.post_init() 499 | # Model parallel 500 | self.model_parallel = False 501 | self.device_map = None 502 | self.gradient_checkpointing = False 503 | 504 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 505 | def parallelize(self, device_map=None): 506 | # Check validity of device_map 507 | self.device_map = ( 508 | get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map 509 | ) 510 | assert_device_map(self.device_map, len(self.block)) 511 | self.model_parallel = True 512 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 513 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 514 | # Load onto devices 515 | for k, v in self.device_map.items(): 516 | for layer in v: 517 | cuda_device = "cuda:" + str(k) 518 | self.block[layer] = self.block[layer].to(cuda_device) 519 | 520 | # Set embed_tokens to first layer 521 | self.embed_tokens = self.embed_tokens.to(self.first_device) 522 | # Set final layer norm to last device 523 | self.final_layer_norm = self.final_layer_norm.to(self.last_device) 524 | 525 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 526 | def deparallelize(self): 527 | self.model_parallel = False 528 | self.device_map = None 529 | self.first_device = "cpu" 530 | self.last_device = "cpu" 531 | for i in range(len(self.block)): 532 | self.block[i] = self.block[i].to("cpu") 533 | self.embed_tokens = self.embed_tokens.to("cpu") 534 | self.final_layer_norm = self.final_layer_norm.to("cpu") 535 | torch.cuda.empty_cache() 536 | 537 | 538 | def get_prompt_decoder(self, batch_size): 539 | prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device) 540 | prompts = self.prefix_decoder(prefix_tokens) 541 | return prompts 542 | 543 | def get_input_embeddings(self): 544 | return self.embed_tokens 545 | 546 | def set_input_embeddings(self, new_embeddings): 547 | self.embed_tokens = new_embeddings 548 | 549 | 550 | def forward( 551 | self, 552 | input_ids=None, 553 | attention_mask=None, 554 | encoder_hidden_states=None, 555 | encoder_attention_mask=None, 556 | inputs_embeds=None, 557 | head_mask=None, 558 | cross_attn_head_mask=None, 559 | past_key_values=None, 560 | use_cache=None, 561 | output_attentions=None, 562 | output_hidden_states=None, 563 | return_dict=None, 564 | ): 565 | # Model parallel 566 | if self.model_parallel: 567 | torch.cuda.set_device(self.first_device) 568 | self.embed_tokens = self.embed_tokens.to(self.first_device) 569 | use_cache = use_cache if use_cache is not None else self.config.use_cache 570 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 571 | output_hidden_states = ( 572 | output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states 573 | ) 574 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 575 | 576 | if input_ids is not None and inputs_embeds is not None: 577 | err_msg_prefix = "decoder_" if self.is_decoder else "" 578 | raise ValueError( 579 | f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" 580 | ) 581 | elif input_ids is not None: 582 | input_shape = input_ids.size() 583 | input_ids = input_ids.view(-1, input_shape[-1]) 584 | elif inputs_embeds is not None: 585 | input_shape = inputs_embeds.size()[:-1] 586 | else: 587 | err_msg_prefix = "decoder_" if self.is_decoder else "" 588 | raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") 589 | 590 | # past_key_values_length 591 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 592 | 593 | if inputs_embeds is None: 594 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" 595 | if past_key_values_length == 0: 596 | #inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 597 | batch_size = input_ids.shape[0] 598 | #print(input_ids.shape) 599 | raw_embeds = self.embed_tokens(input_ids) 600 | #print(raw_embeds.shape) 601 | prompts_decoder = self.get_prompt_decoder(batch_size=batch_size) 602 | inputs_embeds = torch.cat((prompts_decoder, raw_embeds), dim=1) 603 | inputs_prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.device) 604 | #attention_mask = torch.cat((inputs_prefix_attention_mask, attention_mask), dim=1) 605 | #input_shape[1] = input_shape[1] + self.pre_seq_len 606 | input_shape = inputs_embeds.size()[:-1] 607 | #print(input_shape) 608 | else: 609 | #print(past_key_values_length) 610 | #print(input_ids.shape) 611 | inputs_embeds = self.embed_tokens(input_ids) 612 | input_shape = inputs_embeds.size()[:-1] 613 | batch_size, seq_length = input_shape 614 | 615 | # required mask seq length can be calculated via length of past 616 | mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length 617 | 618 | if use_cache is True: 619 | assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" 620 | 621 | if attention_mask is None: 622 | attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) 623 | 624 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: 625 | encoder_seq_length = encoder_hidden_states.shape[1] 626 | encoder_attention_mask = torch.ones( 627 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long 628 | ) 629 | 630 | # initialize past_key_values with `None` if past does not exist 631 | if past_key_values is None: 632 | past_key_values = [None] * len(self.block) 633 | 634 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 635 | # ourselves in which case we just need to make it broadcastable to all heads. 
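        # (When there is no cache, `seq_length` already includes the pre_seq_len
        # prompt embeddings prepended to the decoder inputs above, so the default
        # all-ones attention mask covers the prompt positions; with a cache,
        # mask_seq_length accounts for them through past_key_values. The
        # cross-attention mask below is widened by pre_seq_len as well, because
        # the encoder output also carries the encoder-side prompt positions.)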
636 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) 637 | 638 | # If a 2D or 3D attention mask is provided for the cross-attention 639 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 640 | if self.is_decoder and encoder_hidden_states is not None: 641 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 642 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 643 | if encoder_attention_mask is None: 644 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) 645 | batch_size = input_ids.shape[0] 646 | inputs_prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.device) 647 | encoder_attention_mask = torch.cat((inputs_prefix_attention_mask, encoder_attention_mask), dim=1) 648 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 649 | else: 650 | encoder_extended_attention_mask = None 651 | 652 | # Prepare head mask if needed 653 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 654 | cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) 655 | present_key_value_states = () if use_cache else None 656 | all_hidden_states = () if output_hidden_states else None 657 | all_attentions = () if output_attentions else None 658 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None 659 | position_bias = None 660 | encoder_decoder_position_bias = None 661 | 662 | hidden_states = self.dropout(inputs_embeds) 663 | 664 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): 665 | layer_head_mask = head_mask[i] 666 | cross_attn_layer_head_mask = cross_attn_head_mask[i] 667 | # Model parallel 668 | if self.model_parallel: 669 | torch.cuda.set_device(hidden_states.device) 670 | # Ensure that attention_mask is always on the same device as hidden_states 671 | if attention_mask is not None: 672 | attention_mask = attention_mask.to(hidden_states.device) 673 | if position_bias is not None: 674 | position_bias = position_bias.to(hidden_states.device) 675 | if encoder_hidden_states is not None: 676 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) 677 | if encoder_extended_attention_mask is not None: 678 | encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) 679 | if encoder_decoder_position_bias is not None: 680 | encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) 681 | if layer_head_mask is not None: 682 | layer_head_mask = layer_head_mask.to(hidden_states.device) 683 | if cross_attn_layer_head_mask is not None: 684 | cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) 685 | if output_hidden_states: 686 | all_hidden_states = all_hidden_states + (hidden_states,) 687 | 688 | if self.gradient_checkpointing and self.training: 689 | if use_cache: 690 | logger.warn( 691 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
692 | ) 693 | use_cache = False 694 | 695 | def create_custom_forward(module): 696 | def custom_forward(*inputs): 697 | return tuple(module(*inputs, use_cache, output_attentions)) 698 | 699 | return custom_forward 700 | 701 | layer_outputs = checkpoint( 702 | create_custom_forward(layer_module), 703 | hidden_states, 704 | extended_attention_mask, 705 | position_bias, 706 | encoder_hidden_states, 707 | encoder_extended_attention_mask, 708 | encoder_decoder_position_bias, 709 | layer_head_mask, 710 | cross_attn_layer_head_mask, 711 | None, # past_key_value is always None with gradient checkpointing 712 | ) 713 | else: 714 | layer_outputs = layer_module( 715 | hidden_states, 716 | attention_mask=extended_attention_mask, 717 | position_bias=position_bias, 718 | encoder_hidden_states=encoder_hidden_states, 719 | encoder_attention_mask=encoder_extended_attention_mask, 720 | encoder_decoder_position_bias=encoder_decoder_position_bias, 721 | layer_head_mask=layer_head_mask, 722 | cross_attn_layer_head_mask=cross_attn_layer_head_mask, 723 | past_key_value=past_key_value, 724 | use_cache=use_cache, 725 | output_attentions=output_attentions, 726 | ) 727 | 728 | # layer_outputs is a tuple with: 729 | # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 730 | if use_cache is False: 731 | layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] 732 | 733 | hidden_states, present_key_value_state = layer_outputs[:2] 734 | 735 | # We share the position biases between the layers - the first layer store them 736 | # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), 737 | # (cross-attention position bias), (cross-attention weights) 738 | position_bias = layer_outputs[2] 739 | if self.is_decoder and encoder_hidden_states is not None: 740 | encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] 741 | # append next layer key value states 742 | if use_cache: 743 | present_key_value_states = present_key_value_states + (present_key_value_state,) 744 | 745 | if output_attentions: 746 | all_attentions = all_attentions + (layer_outputs[3],) 747 | if self.is_decoder: 748 | all_cross_attentions = all_cross_attentions + (layer_outputs[5],) 749 | 750 | # Model Parallel: If it's the last layer for that device, put things on the next device 751 | if self.model_parallel: 752 | for k, v in self.device_map.items(): 753 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 754 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 755 | 756 | hidden_states = self.final_layer_norm(hidden_states) 757 | hidden_states = self.dropout(hidden_states) 758 | 759 | # Add last layer 760 | if output_hidden_states: 761 | all_hidden_states = all_hidden_states + (hidden_states,) 762 | 763 | if not return_dict: 764 | return tuple( 765 | v 766 | for v in [ 767 | hidden_states, 768 | present_key_value_states, 769 | all_hidden_states, 770 | all_attentions, 771 | all_cross_attentions, 772 | ] 773 | if v is not None 774 | ) 775 | return BaseModelOutputWithPastAndCrossAttentions( 776 | last_hidden_state=hidden_states, 777 | past_key_values=present_key_value_states, 778 | hidden_states=all_hidden_states, 779 | attentions=all_attentions, 780 | cross_attentions=all_cross_attentions, 781 | ) 782 | 783 | 784 | class T5PromptForConditionalGeneration(T5PreTrainedModel): 785 | _keys_to_ignore_on_load_missing = [ 786 | r"encoder\.embed_tokens\.weight", 787 | 
r"decoder\.embed_tokens\.weight", 788 | r"lm_head\.weight", 789 | ] 790 | _keys_to_ignore_on_load_unexpected = [ 791 | r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", 792 | ] 793 | 794 | def __init__(self, config): 795 | super().__init__(config) 796 | self.model_dim = config.d_model 797 | self.pre_seq_len = config.pre_seq_len 798 | 799 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 800 | 801 | encoder_config = copy.deepcopy(config) 802 | encoder_config.is_decoder = False 803 | encoder_config.use_cache = False 804 | encoder_config.is_encoder_decoder = False 805 | self.encoder = T5EncoderStack(encoder_config, self.shared) 806 | 807 | decoder_config = copy.deepcopy(config) 808 | decoder_config.is_decoder = True 809 | decoder_config.is_encoder_decoder = False 810 | decoder_config.num_layers = config.num_decoder_layers 811 | self.decoder = T5DecoderStack(decoder_config, self.shared) 812 | 813 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 814 | 815 | # Initialize weights and apply final processing 816 | self.post_init() 817 | 818 | # Model parallel 819 | self.model_parallel = False 820 | self.device_map = None 821 | 822 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 823 | def parallelize(self, device_map=None): 824 | self.device_map = ( 825 | get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) 826 | if device_map is None 827 | else device_map 828 | ) 829 | assert_device_map(self.device_map, len(self.encoder.block)) 830 | self.encoder.parallelize(self.device_map) 831 | self.decoder.parallelize(self.device_map) 832 | self.lm_head = self.lm_head.to(self.decoder.first_device) 833 | self.model_parallel = True 834 | 835 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 836 | def deparallelize(self): 837 | self.encoder.deparallelize() 838 | self.decoder.deparallelize() 839 | self.encoder = self.encoder.to("cpu") 840 | self.decoder = self.decoder.to("cpu") 841 | self.lm_head = self.lm_head.to("cpu") 842 | self.model_parallel = False 843 | self.device_map = None 844 | torch.cuda.empty_cache() 845 | 846 | def get_input_embeddings(self): 847 | return self.shared 848 | 849 | def set_input_embeddings(self, new_embeddings): 850 | self.shared = new_embeddings 851 | self.encoder.set_input_embeddings(new_embeddings) 852 | self.decoder.set_input_embeddings(new_embeddings) 853 | 854 | def set_output_embeddings(self, new_embeddings): 855 | self.lm_head = new_embeddings 856 | 857 | def get_output_embeddings(self): 858 | return self.lm_head 859 | 860 | def get_encoder(self): 861 | return self.encoder 862 | 863 | def get_decoder(self): 864 | return self.decoder 865 | 866 | @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) 867 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 868 | def forward( 869 | self, 870 | input_ids=None, 871 | attention_mask=None, 872 | decoder_input_ids=None, 873 | decoder_attention_mask=None, 874 | head_mask=None, 875 | decoder_head_mask=None, 876 | cross_attn_head_mask=None, 877 | encoder_outputs=None, 878 | past_key_values=None, 879 | inputs_embeds=None, 880 | decoder_inputs_embeds=None, 881 | labels=None, 882 | use_cache=None, 883 | output_attentions=None, 884 | output_hidden_states=None, 885 | return_dict=None, 886 | ): 887 | r""" 888 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 889 | Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., 890 | config.vocab_size - 1]`. 
All labels set to `-100` are ignored (masked), the loss is only computed for 891 | labels in `[0, ..., config.vocab_size]` 892 | 893 | Returns: 894 | 895 | Examples: 896 | 897 | ```python 898 | >>> from transformers import T5Tokenizer, T5ForConditionalGeneration 899 | 900 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") 901 | >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") 902 | 903 | >>> # training 904 | >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids 905 | >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids 906 | >>> outputs = model(input_ids=input_ids, labels=labels) 907 | >>> loss = outputs.loss 908 | >>> logits = outputs.logits 909 | 910 | >>> # inference 911 | >>> input_ids = tokenizer( 912 | ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" 913 | >>> ).input_ids # Batch size 1 914 | >>> outputs = model.generate(input_ids) 915 | >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) 916 | >>> # studies have shown that owning a dog is good for you. 917 | ```""" 918 | use_cache = use_cache if use_cache is not None else self.config.use_cache 919 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 920 | 921 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 922 | if head_mask is not None and decoder_head_mask is None: 923 | if self.config.num_layers == self.config.num_decoder_layers: 924 | warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) 925 | decoder_head_mask = head_mask 926 | 927 | # Encode if needed (training, first prediction pass) 928 | if encoder_outputs is None: 929 | # Convert encoder inputs in embeddings if needed 930 | encoder_outputs = self.encoder( 931 | input_ids=input_ids, 932 | attention_mask=attention_mask, 933 | inputs_embeds=inputs_embeds, 934 | head_mask=head_mask, 935 | output_attentions=output_attentions, 936 | output_hidden_states=output_hidden_states, 937 | return_dict=return_dict, 938 | ) 939 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 940 | encoder_outputs = BaseModelOutput( 941 | last_hidden_state=encoder_outputs[0], 942 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 943 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 944 | ) 945 | 946 | hidden_states = encoder_outputs[0] 947 | 948 | if self.model_parallel: 949 | torch.cuda.set_device(self.decoder.first_device) 950 | 951 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: 952 | # get decoder inputs from shifting lm labels to the right 953 | decoder_input_ids = self._shift_right(labels) 954 | 955 | # Set device for model parallelism 956 | if self.model_parallel: 957 | torch.cuda.set_device(self.decoder.first_device) 958 | hidden_states = hidden_states.to(self.decoder.first_device) 959 | if decoder_input_ids is not None: 960 | decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) 961 | if attention_mask is not None: 962 | attention_mask = attention_mask.to(self.decoder.first_device) 963 | if decoder_attention_mask is not None: 964 | decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) 965 | 966 | # Decode 967 | decoder_outputs = self.decoder( 968 | input_ids=decoder_input_ids, 969 | attention_mask=decoder_attention_mask, 970 | inputs_embeds=decoder_inputs_embeds, 971 | past_key_values=past_key_values, 972 | encoder_hidden_states=hidden_states, 973 | 
974 |             head_mask=decoder_head_mask,
975 |             cross_attn_head_mask=cross_attn_head_mask,
976 |             use_cache=use_cache,
977 |             output_attentions=output_attentions,
978 |             output_hidden_states=output_hidden_states,
979 |             return_dict=return_dict,
980 |         )
981 | 
982 |         sequence_output = decoder_outputs[0]
983 |         # prompt tuning: on a full pass (no past) the decoder output starts with pre_seq_len prefix positions
984 |         if past_key_values is None:
985 |             sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()  # drop the prefix so logits align with the shifted labels
986 |         else:
987 |             sequence_output = sequence_output  # incremental decoding steps return only the new positions; nothing to strip
988 | 
989 |         # Set device for model parallelism
990 |         if self.model_parallel:
991 |             torch.cuda.set_device(self.encoder.first_device)
992 |             self.lm_head = self.lm_head.to(self.encoder.first_device)
993 |             sequence_output = sequence_output.to(self.lm_head.weight.device)
994 | 
995 |         if self.config.tie_word_embeddings:
996 |             # Rescale output before projecting on vocab
997 |             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
998 |             sequence_output = sequence_output * (self.model_dim ** -0.5)
999 | 
1000 | 
1001 |         lm_logits = self.lm_head(sequence_output)
1002 | 
1003 |         loss = None
1004 |         if labels is not None:
1005 |             loss_fct = CrossEntropyLoss(ignore_index=-100)
1006 |             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1007 |             # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
1008 | 
1009 |         if not return_dict:
1010 |             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1011 |             return ((loss,) + output) if loss is not None else output
1012 | 
1013 |         return Seq2SeqLMOutput(
1014 |             loss=loss,
1015 |             logits=lm_logits,
1016 |             past_key_values=decoder_outputs.past_key_values,
1017 |             decoder_hidden_states=decoder_outputs.hidden_states,
1018 |             decoder_attentions=decoder_outputs.attentions,
1019 |             cross_attentions=decoder_outputs.cross_attentions,
1020 |             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1021 |             encoder_hidden_states=encoder_outputs.hidden_states,
1022 |             encoder_attentions=encoder_outputs.attentions,
1023 |         )
1024 | 
1025 |     def prepare_inputs_for_generation(
1026 |         self,
1027 |         input_ids,
1028 |         past=None,
1029 |         attention_mask=None,
1030 |         head_mask=None,
1031 |         decoder_head_mask=None,
1032 |         cross_attn_head_mask=None,
1033 |         use_cache=None,
1034 |         encoder_outputs=None,
1035 |         **kwargs
1036 |     ):
1037 | 
1038 |         # cut decoder_input_ids if past is used
1039 |         if past is not None:
1040 |             input_ids = input_ids[:, -1:]
1041 | 
1042 |         return {
1043 |             "decoder_input_ids": input_ids,
1044 |             "past_key_values": past,
1045 |             "encoder_outputs": encoder_outputs,
1046 |             "attention_mask": attention_mask,
1047 |             "head_mask": head_mask,
1048 |             "decoder_head_mask": decoder_head_mask,
1049 |             "cross_attn_head_mask": cross_attn_head_mask,
1050 |             "use_cache": use_cache,
1051 |         }
1052 | 
1053 |     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1054 |         return self._shift_right(labels)
1055 | 
1056 |     def _reorder_cache(self, past, beam_idx):
1057 |         # if decoder past is not included in output
1058 |         # speedy decoding is disabled and no need to reorder
1059 |         if past is None:
1060 |             logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
1061 |             return past
1062 | 
1063 |         reordered_decoder_past = ()
1064 |         for layer_past_states in past:
1065 |             # get the correct batch idx from layer past batch dim
1066 |             # batch dim of `past` is at 2nd position
1067 |             reordered_layer_past_states = ()
1068 |             for layer_past_state in layer_past_states:
1069 |                 # need to set correct `past` for each of the four key / value states
1070 |                 reordered_layer_past_states = reordered_layer_past_states + (
1071 |                     layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1072 |                 )
1073 | 
1074 |             assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
1075 |             assert len(reordered_layer_past_states) == len(layer_past_states)
1076 | 
1077 |             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1078 |         return reordered_decoder_past
1079 | 
--------------------------------------------------------------------------------
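
A minimal usage sketch for the prompt-tuned T5 class in the file above. The import path and the class name `T5PromptForConditionalGeneration` are assumptions for illustration only (the actual export name is defined earlier in the file, before this excerpt); what the sketch relies on is visible in the code: `config.pre_seq_len` is read in `__init__`, and a full forward pass (no `past_key_values`) strips that many prefix positions from the decoder output before the LM head, so the returned logits have one position per label token.

```python
# Sketch only: class name and import path are hypothetical; other prompt-specific
# config fields may also be required depending on how the encoder/decoder stacks
# consume the config.
import torch
from transformers import T5Config, T5Tokenizer

from model.modeling_t5 import T5PromptForConditionalGeneration  # hypothetical name

tokenizer = T5Tokenizer.from_pretrained("t5-small")
config = T5Config.from_pretrained("t5-small")
config.pre_seq_len = 50  # soft-prompt length read by __init__ (illustrative value)

# Randomly initialised weights are enough to check shapes; real use would load a checkpoint.
model = T5PromptForConditionalGeneration(config)
model.eval()

input_ids = tokenizer(
    "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
).input_ids
labels = tokenizer("owning a dog is good for you", return_tensors="pt").input_ids

with torch.no_grad():
    # Full pass (past_key_values is None): forward() slices off the first pre_seq_len
    # decoder positions, so logits line up with the labels used for the loss.
    outputs = model(input_ids=input_ids, labels=labels)

print(outputs.loss, outputs.logits.shape)  # expected logits shape: (1, labels.shape[1], config.vocab_size)
```

During generation the same slicing only happens on the first decoding step; later steps pass `past_key_values`, and the decoder then returns just the newly generated position, which is why the `forward` above leaves the output untouched in that branch.
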