├── finetune ├── seqcls │ ├── data │ │ ├── bioasq_hf │ │ │ ├── dev.json │ │ │ ├── test.json │ │ │ └── train.json │ │ └── pubmedqa_hf │ │ │ ├── dev.json │ │ │ ├── test.json │ │ │ └── train.json │ ├── README.md │ ├── preprocess_blurb_seqcls.py │ └── run_seqcls_gpt.py ├── setup │ └── requirements.txt ├── mc │ ├── data │ │ └── medqa_usmle_hf │ │ │ ├── dev.json │ │ │ ├── test.json │ │ │ └── train.json │ ├── README.md │ ├── preprocess_medqa.py │ ├── run_experiments.py │ └── run_multiple_choice.py ├── textgen │ ├── data │ │ └── meqsum │ │ │ ├── test.source │ │ │ ├── val.source │ │ │ ├── train.source │ │ │ ├── val.target │ │ │ ├── test.target │ │ │ └── train.target │ └── gpt2 │ │ ├── generate_demo.py │ │ ├── finetune_for_summarization.py │ │ ├── sum_data_collator.py │ │ └── sum_dataset.py ├── deepspeed │ └── cpu_offload.json ├── README.md └── utils │ ├── hf_flash_gpt_2.py │ ├── custom_modeling_gpt2.py │ └── custom_modeling_gpt_neo.py ├── demo.py ├── README.md └── tokenize └── train_bpe.py /finetune/seqcls/data/bioasq_hf/dev.json: -------------------------------------------------------------------------------- 1 | {"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"} 2 | -------------------------------------------------------------------------------- /finetune/seqcls/data/bioasq_hf/test.json: -------------------------------------------------------------------------------- 1 | {"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"} 2 | -------------------------------------------------------------------------------- /finetune/seqcls/data/bioasq_hf/train.json: -------------------------------------------------------------------------------- 1 | {"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"} 2 | -------------------------------------------------------------------------------- /finetune/seqcls/data/pubmedqa_hf/dev.json: -------------------------------------------------------------------------------- 1 | {"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"} 2 | -------------------------------------------------------------------------------- /finetune/seqcls/data/pubmedqa_hf/test.json: -------------------------------------------------------------------------------- 1 | {"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"} 2 | -------------------------------------------------------------------------------- /finetune/seqcls/data/pubmedqa_hf/train.json: -------------------------------------------------------------------------------- 1 | {"id": "passage id", "sentence1": "question text ...", "sentence2": "passage text ...", "label": "label"} 2 | -------------------------------------------------------------------------------- /finetune/setup/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.6.1 2 | fairscale==0.4.12 3 | huggingface-hub==0.10.1 4 | rouge-score==0.0.4 5 | sacrebleu==2.0.0 6 | transformers==4.24.0 7 | wandb==0.13.5 8 | -------------------------------------------------------------------------------- /finetune/mc/data/medqa_usmle_hf/dev.json: -------------------------------------------------------------------------------- 1 | {"id": "id", "sent1": "passage and question ...", "sent2": "", "ending0": "answer 0", "ending1": "answer 1", "ending2": "answer 2", "ending3": "answer 
3", "label": "int of correct answer"} 2 | -------------------------------------------------------------------------------- /finetune/mc/data/medqa_usmle_hf/test.json: -------------------------------------------------------------------------------- 1 | {"id": "id", "sent1": "passage and question ...", "sent2": "", "ending0": "answer 0", "ending1": "answer 1", "ending2": "answer 2", "ending3": "answer 3", "label": "int of correct answer"} 2 | -------------------------------------------------------------------------------- /finetune/mc/data/medqa_usmle_hf/train.json: -------------------------------------------------------------------------------- 1 | {"id": "id", "sent1": "passage and question ...", "sent2": "", "ending0": "answer 0", "ending1": "answer 1", "ending2": "answer 2", "ending3": "answer 3", "label": "int of correct answer"} 2 | -------------------------------------------------------------------------------- /finetune/textgen/data/meqsum/test.source: -------------------------------------------------------------------------------- 1 | The source text for an example. For instance this could be the full article that is supposed to be summarized. There should be one example per line. The corresponding train.target file would have the gold generations for each example. So the Nth line of this file would correspond to the Nth line of the *.target file. 2 | -------------------------------------------------------------------------------- /finetune/textgen/data/meqsum/val.source: -------------------------------------------------------------------------------- 1 | The source text for an example. For instance this could be the full article that is supposed to be summarized. There should be one example per line. The corresponding train.target file would have the gold generations for each example. So the Nth line of this file would correspond to the Nth line of the *.target file. 2 | -------------------------------------------------------------------------------- /finetune/textgen/data/meqsum/train.source: -------------------------------------------------------------------------------- 1 | The source text for an example. For instance this could be the full article that is supposed to be summarized. There should be one example per line. The corresponding train.target file would have the gold generations for each example. So the Nth line of this file would correspond to the Nth line of the *.target file. 2 | -------------------------------------------------------------------------------- /finetune/textgen/data/meqsum/val.target: -------------------------------------------------------------------------------- 1 | The gold sequence for this example. Each line should be a new example. In the corresponding line in the *.source file is the original text. This text is the desired generation for that source. So if this was a summarization task, the *.source file would have the full article, and this would be the summarization. The Nth line of this file corresponds to the Nth line of the *.source file. 2 | -------------------------------------------------------------------------------- /finetune/textgen/data/meqsum/test.target: -------------------------------------------------------------------------------- 1 | The gold sequence for this example. Each line should be a new example. In the corresponding line in the *.source file is the original text. This text is the desired generation for that source. 
So if this was a summarization task, the *.source file would have the full article, and this would be the summarization. The Nth line of this file corresponds to the Nth line of the *.source file. 2 | -------------------------------------------------------------------------------- /finetune/textgen/data/meqsum/train.target: -------------------------------------------------------------------------------- 1 | The gold sequence for this example. Each line should be a new example. In the corresponding line in the *.source file is the original text. This text is the desired generation for that source. So if this was a summarization task, the *.source file would have the full article, and this would be the summarization. The Nth line of this file corresponds to the Nth line of the *.source file. 2 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import GPT2LMHeadModel, GPT2Tokenizer 4 | 5 | device = torch.device("cuda") 6 | 7 | tokenizer = GPT2Tokenizer.from_pretrained("stanford-crfm/pubmed_gpt_tokenizer") 8 | 9 | model = GPT2LMHeadModel.from_pretrained("stanford-crfm/pubmedgpt").to(device) 10 | 11 | input_ids = tokenizer.encode( 12 | "Photosynthesis is ", return_tensors="pt" 13 | ).to(device) 14 | 15 | sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=50) 16 | 17 | print("Output:\n" + 100 * "-") 18 | print(tokenizer.decode(sample_output[0], skip_special_tokens=True)) 19 | -------------------------------------------------------------------------------- /finetune/mc/README.md: -------------------------------------------------------------------------------- 1 | ## Setting Up MedQA 2 | 3 | 1.) Download the data from the [MedQA GitHub](https://github.com/jind11/MedQA). The repository should have a link to a Google Drive. Make sure to download the contents to a directory path matching `raw_data/medqa` in this directory. For more details, review the `preprocess_medqa.py` script to see the specific paths the preprocessing script expects. For example, `raw_data/medqa/data_clean/questions/US/4_options` should exist when the original data is set up properly. 4 | 5 | 2.) Run the `preprocess_medqa.py` script in this directory to produce the data in the format expected by our fine-tuning code. It should produce the appropriate `.json` files (one JSON object per line) in `data/medqa_usmle_hf`. 6 | -------------------------------------------------------------------------------- /finetune/seqcls/README.md: -------------------------------------------------------------------------------- 1 | ## Setting Up BLURB (PubMedQA and BioASQ) 2 | 3 | 1.) Download the original [BioASQ](http://www.bioasq.org/) and [PubMedQA](https://pubmedqa.github.io/) data. Make sure when downloading and expanding the data that it matches these paths: `raw_data/blurb/data_generation/data/pubmedqa` and `raw_data/blurb/data_generation/data/BioASQ` in this directory. For more details, review the `preprocess_blurb_seqcls.py` script to see the specific paths the preprocessing script expects. For example, the path `raw_data/blurb/data_generation/data/pubmedqa/pqal_fold0` should exist when the data has been set up properly. 4 | 5 | 2.) Run the `preprocess_blurb_seqcls.py` script in this directory to produce the data in the format expected by our fine-tuning code. It should produce the appropriate `.json` files (one JSON object per line) in `data/pubmedqa_hf` and `data/bioasq_hf`.
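For reference, each line of the produced files is a single JSON object in the same format as the demo files under `data/`; an illustrative line (made-up values) might look like:

```json
{"id": "pqal-0001", "sentence1": "Does the treatment improve outcomes?", "sentence2": "Abstract text describing the study ...", "label": "yes"}
```

The label is one of `yes`/`no`/`maybe` for PubMedQA and `yes`/`no` for BioASQ.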
6 | -------------------------------------------------------------------------------- /finetune/textgen/gpt2/generate_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | 6 | model_path = sys.argv[1] 7 | device = torch.device("cuda") 8 | 9 | # load tokenizer 10 | print("Loading tokenizer ...") 11 | tokenizer = AutoTokenizer.from_pretrained(model_path) 12 | 13 | # load model 14 | print("Loading model ...") 15 | model = AutoModelForCausalLM.from_pretrained(sys.argv[1]).to(device) 16 | 17 | # run model 18 | print("Generating text ...") 19 | prompt = sys.argv[2] 20 | prompt_w_start = f"{prompt}<|startoftext|>" 21 | encoding = tokenizer.encode(prompt_w_start, return_tensors='pt').to(device) 22 | generated_ids = model.generate(encoding, max_new_tokens=100, eos_token_id=28895) 23 | generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) 24 | print(f"Input: {prompt}") 25 | print(f"Output: {generated_text[len(prompt):]}") 26 | -------------------------------------------------------------------------------- /finetune/deepspeed/cpu_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "optimizer": { 3 | "type": "AdamW", 4 | "params": { 5 | "lr": 2e-06, 6 | "betas": [ 7 | 0.9, 8 | 0.999 9 | ], 10 | "eps": 1e-8, 11 | "weight_decay": 0.0 12 | } 13 | }, 14 | 15 | "scheduler": { 16 | "type": "WarmupDecayLR", 17 | "params": { 18 | "total_num_steps": "auto", 19 | "warmup_max_lr": 2e-06, 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | 24 | "zero_optimization": { 25 | "stage": 1, 26 | "allgather_partitions": true, 27 | "allgather_bucket_size": 5e8, 28 | "reduce_scatter": true, 29 | "reduce_bucket_size": 5e8, 30 | "overlap_comm": true, 31 | "contiguous_gradients": true, 32 | "cpu_offload": true 33 | }, 34 | 35 | "train_batch_size": "auto", 36 | "train_micro_batch_size_per_gpu": "auto", 37 | 38 | "fp16": { 39 | "enabled": true 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BioMedLM 2 | 3 | Code used for pre-training and fine-tuning the [BioMedLM](https://huggingface.co/stanford-crfm/pubmedgpt) model. 4 | 5 | Note: This model was previously known as PubMedGPT, but the NIH has asked us to change the name since they hold the trademark on "PubMed", so the new name is BioMedLM! 
6 | 7 | ### Links 8 | 9 | [Blog](https://crfm.stanford.edu/2022/12/15/pubmedgpt.html) 10 | 11 | [Model](https://huggingface.co/stanford-crfm/pubmedgpt/tree/main) 12 | 13 | [MosaicML Composer](https://github.com/mosaicml/composer) 14 | 15 | ### Example Usage 16 | 17 | ``` 18 | import torch 19 | 20 | from transformers import GPT2LMHeadModel, GPT2Tokenizer 21 | 22 | device = torch.device("cuda") 23 | 24 | tokenizer = GPT2Tokenizer.from_pretrained("stanford-crfm/BioMedLM") 25 | 26 | model = GPT2LMHeadModel.from_pretrained("stanford-crfm/BioMedLM").to(device) 27 | 28 | input_ids = tokenizer.encode( 29 | "Photosynthesis is ", return_tensors="pt" 30 | ).to(device) 31 | 32 | sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=50) 33 | 34 | print("Output:\n" + 100 * "-") 35 | print(tokenizer.decode(sample_output[0], skip_special_tokens=True)) 36 | ``` 37 | -------------------------------------------------------------------------------- /tokenize/train_bpe.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors 5 | 6 | input_files = sys.argv[1].split(",") 7 | tokenizer_name = sys.argv[2] 8 | os.system(f"mkdir {tokenizer_name}") 9 | 10 | # Initialize a tokenizer 11 | tokenizer = Tokenizer(models.BPE()) 12 | 13 | # Customize pre-tokenization and decoding 14 | tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) 15 | tokenizer.decoder = decoders.ByteLevel() 16 | tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) 17 | 18 | # And then train 19 | trainer = trainers.BpeTrainer( 20 | vocab_size=28896, 21 | min_frequency=2, 22 | initial_alphabet=pre_tokenizers.ByteLevel.alphabet() 23 | ) 24 | tokenizer.train(input_files,trainer=trainer) 25 | 26 | # And Save it 27 | tokenizer.save(f"{tokenizer_name}/tokenizer.json", pretty=True) 28 | 29 | # create vocab.json and merges.txt 30 | with open(f"{tokenizer_name}/vocab.json", "w") as vocab_file: 31 | vocab_json = json.loads(open(f"{tokenizer_name}/tokenizer.json").read())["model"]["vocab"] 32 | vocab_file.write(json.dumps(vocab_json)) 33 | 34 | with open(f"{tokenizer_name}/merges.txt", "w") as merges_file: 35 | merges = "\n".join(json.loads(open(f"{tokenizer_name}/tokenizer.json").read())["model"]["merges"]) 36 | merges_file.write(merges) 37 | -------------------------------------------------------------------------------- /finetune/mc/preprocess_medqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import shutil 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | root = "data" 10 | os.system(f"mkdir -p {root}") 11 | 12 | 13 | def dump_jsonl(data, fpath): 14 | with open(fpath, "w") as outf: 15 | for d in data: 16 | print (json.dumps(d), file=outf) 17 | 18 | def process_medqa(fname): 19 | dname = "medqa_usmle" 20 | lines = open(f"raw_data/medqa/data_clean/questions/US/4_options/phrases_no_exclude_{fname}.jsonl").readlines() 21 | outs, lens = [], [] 22 | for i, line in enumerate(tqdm(lines)): 23 | stmt = json.loads(line) 24 | sent1 = stmt["question"] 25 | ends = [stmt["options"][key] for key in "ABCD"] 26 | outs.append({"id": f"{fname}-{i:05d}", 27 | "sent1": sent1, 28 | "sent2": "", 29 | "ending0": ends[0], 30 | "ending1": ends[1], 31 | "ending2": ends[2], 32 | "ending3": ends[3], 33 | "label": ord(stmt["answer_idx"]) - ord("A") 34 | }) 35 | 
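        # record the character length of the question plus its longest answer option, for the length statistics printed below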
lens.append(len(sent1) + max([len(ends[0]),len(ends[1]), len(ends[2]), len(ends[3])])) 36 | print ("total", len(outs), "seqlen mean", int(np.mean(lens)), "median", int(np.median(lens)), "95th", int(np.percentile(lens, 95)), "max", np.max(lens)) 37 | # 38 | os.system(f'mkdir -p {root}/{dname}_hf') 39 | dump_jsonl(outs, f"{root}/{dname}_hf/{fname}.json") 40 | 41 | 42 | process_medqa("train") 43 | process_medqa("test") 44 | process_medqa("dev") 45 | -------------------------------------------------------------------------------- /finetune/mc/run_experiments.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | import sys 5 | 6 | env_setup_cmd = "task=medqa_usmle_hf ; datadir=data/$task ; export WANDB_PROJECT='biomedical-nlp-eval'" 7 | 8 | experiments = [json.loads(line) for line in open(sys.argv[1]).read().split("\n") if line] 9 | 10 | for experiment in experiments: 11 | checkpoint = experiment["checkpoint"] 12 | lr = experiment["lr"] 13 | epochs = experiment["epochs"] 14 | grad_accum = experiment["grad_accum"] 15 | train_per_device_batch_size = experiment["train_per_device_batch_size"] 16 | num_devices = experiment["num_devices"] if "num_devices" in experiment else 8 17 | batch_size = int(num_devices) * int(grad_accum) * int(train_per_device_batch_size) 18 | tokenizer = experiment["tokenizer"] 19 | numerical_format = experiment["numerical"] if "numerical" in experiment else "bf16" 20 | seed = experiment["seed"] 21 | use_flash = experiment["use_flash"] 22 | run_name = f"{os.path.basename(checkpoint)}-lr={lr}-batch_size={batch_size}-epochs={epochs}-seed={seed}-task=medqa" 23 | exp_cmd = ( 24 | f"python -m torch.distributed.launch --nproc_per_node={num_devices} --nnodes=1 --node_rank=0" 25 | f" run_multiple_choice.py --use_flash {use_flash} --tokenizer_name {tokenizer} --model_name_or_path" 26 | f" {checkpoint} --train_file data/medqa_usmle_hf/train.json --validation_file data/medqa_usmle_hf/dev.json" 27 | " --test_file data/medqa_usmle_hf/test.json --do_train --do_eval --do_predict --per_device_train_batch_size" 28 | f" {train_per_device_batch_size} --per_device_eval_batch_size 1 --gradient_accumulation_steps {grad_accum}" 29 | f" --learning_rate {lr} --warmup_ratio 0.5 --num_train_epochs {epochs} --max_seq_length 512" 30 | f" --{numerical_format} --seed {seed} --data_seed {seed} --logging_first_step --logging_steps 20" 31 | f" --save_strategy no --evaluation_strategy steps --eval_steps 500 --run_name {run_name} " 32 | " --output_dir trash/" 33 | " --overwrite_output_dir" 34 | ) 35 | if "sharded_ddp" in experiment and experiment["sharded_ddp"].lower() == "true": 36 | exp_cmd += " --sharded_ddp zero_dp_2 " 37 | print("---") 38 | print(exp_cmd) 39 | subprocess.call(f"{env_setup_cmd} ; {exp_cmd}", shell=True) 40 | -------------------------------------------------------------------------------- /finetune/seqcls/preprocess_blurb_seqcls.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import json 4 | import random 5 | import shutil 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | 11 | def dump_jsonl(data, fpath): 12 | with open(fpath, "w") as outf: 13 | for d in data: 14 | print (json.dumps(d), file=outf) 15 | 16 | 17 | ######################### BLURB sequence classification ######################### 18 | root = "data" 19 | os.system(f"mkdir -p {root}") 20 | 21 | 22 | def process_pubmedqa(fname): 23 | dname = 
"pubmedqa" 24 | print (dname, fname) 25 | if fname in ["train", "dev"]: 26 | data = json.load(open(f"raw_data/blurb/data_generation/data/pubmedqa/pqal_fold0/{fname}_set.json")) 27 | elif fname == "test": 28 | data = json.load(open(f"raw_data/blurb/data_generation/data/pubmedqa/{fname}_set.json")) 29 | else: 30 | assert False 31 | outs, lens = [], [] 32 | for id in data: 33 | obj = data[id] 34 | context = " ".join([c.strip() for c in obj["CONTEXTS"] if c.strip()]) 35 | question = obj["QUESTION"].strip() 36 | label = obj["final_decision"].strip() 37 | assert label in ["yes", "no", "maybe"] 38 | outs.append({"id": id, "sentence1": question, "sentence2": context, "label": label}) 39 | lens.append(len(question) + len(context)) 40 | print ("total", len(outs), "seqlen mean", int(np.mean(lens)), "median", int(np.median(lens)), "95th", int(np.percentile(lens, 95)), "max", np.max(lens)) 41 | # 42 | os.system(f"mkdir -p {root}/{dname}_hf") 43 | dump_jsonl(outs, f"{root}/{dname}_hf/{fname}.json") 44 | 45 | process_pubmedqa("test") 46 | process_pubmedqa("train") 47 | process_pubmedqa("dev") 48 | 49 | 50 | def process_bioasq(fname): 51 | dname = "bioasq" 52 | print (dname, fname) 53 | df = pd.read_csv(open(f"raw_data/blurb/data_generation/data/BioASQ/{fname}.tsv"), sep="\t", header=None) 54 | outs, lens = [], [] 55 | for _, row in df.iterrows(): 56 | id = row[0].strip() 57 | question = row[1].strip() 58 | context = row[2].strip() 59 | label = row[3].strip() 60 | assert label in ["yes", "no"] 61 | outs.append({"id": id, "sentence1": question, "sentence2": context, "label": label}) 62 | lens.append(len(question) + len(context)) 63 | print ("total", len(outs), "seqlen mean", int(np.mean(lens)), "median", int(np.median(lens)), "95th", int(np.percentile(lens, 95)), "max", np.max(lens)) 64 | # 65 | os.system(f"mkdir -p {root}/{dname}_hf") 66 | dump_jsonl(outs, f"{root}/{dname}_hf/{fname}.json") 67 | 68 | process_bioasq("test") 69 | process_bioasq("dev") 70 | process_bioasq("train") 71 | -------------------------------------------------------------------------------- /finetune/README.md: -------------------------------------------------------------------------------- 1 | # Biomedical downstream evaluation 2 | 3 | ## NLU 4 | ### Dependencies 5 | ```bash 6 | conda create -n pubmedgpt python=3.8.12 pytorch=1.12.1 torchdata cudatoolkit=11.3 -c pytorch 7 | conda activate pubmedgpt 8 | pip install -r setup/requirements.txt 9 | ``` 10 | 11 | ### Usage 12 | 13 | Note we are not providing the data. Demo versions of the `.jsonl` files are presented to show expected format. 14 | There should be one json per line for each example in the respective data sets for these tasks. 
15 | 16 | For PubMedQA and BioASQ, go to `seqcls/` and run the following command (change paths appropriately for the task): 17 | ```bash 18 | task=pubmedqa_hf 19 | datadir=data/$task 20 | outdir=runs/$task/GPT2 21 | mkdir -p $outdir 22 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 run_seqcls_gpt.py \ 23 | --tokenizer_name stanford-crfm/pubmed_gpt_tokenizer --model_name_or_path {checkpoint} --train_file \ 24 | $datadir/train.json --validation_file $datadir/dev.json --test_file $datadir/test.json --do_train \ 25 | --do_eval --do_predict --per_device_train_batch_size 1 --gradient_accumulation_steps \ 26 | {grad_accum} --learning_rate {lr} --warmup_ratio 0.5 --num_train_epochs {num_epochs} --max_seq_length \ 27 | {seq_len} --logging_steps 100 --save_strategy no --evaluation_strategy no --output_dir \ 28 | {run_dir} --overwrite_output_dir --bf16 \ 29 | --seed {seed} --run_name {name} 30 | ``` 31 | 32 | 33 | For MedQA-USMLE, go to `mc/` and run the following command: 34 | ```bash 35 | task=medqa_usmle_hf 36 | datadir=data/$task 37 | outdir=runs/$task/GPT2 38 | mkdir -p $outdir 39 | python -m torch.distributed.launch --nproc_per_node={num_devices} --nnodes=1 --node_rank=0 \ 40 | run_multiple_choice.py --tokenizer_name stanford-crfm/pubmed_gpt_tokenizer --model_name_or_path \ 41 | {checkpoint} --train_file data/medqa_usmle_hf/train.json --validation_file data/medqa_usmle_hf/dev.json \ 42 | --test_file data/medqa_usmle_hf/test.json --do_train --do_eval --do_predict --per_device_train_batch_size \ 43 | {train_per_device_batch_size} --per_device_eval_batch_size 1 --gradient_accumulation_steps {grad_accum} \ 44 | --learning_rate {lr} --warmup_ratio 0.5 --num_train_epochs {epochs} --max_seq_length 512 \ 45 | --{numerical_format} --seed {seed} --data_seed {seed} --logging_first_step --logging_steps 20 \ 46 | --save_strategy no --evaluation_strategy steps --eval_steps 500 --run_name {run_name} \ 47 | --output_dir trash/ \ 48 | --overwrite_output_dir 49 | ``` 50 | 51 | ## NLG 52 | Go to `./textgen`. 53 | 54 | ### Usage (seq2seq tasks) 55 | Make sure the task dataset is in `./textgen/data`. See `meqsum` (a medical question summarization task) as an example. The dataset folder should have `.source` and `.target` files. The `.source` file should contain the original text, one example per line (e.g., the full original question from the user in the MeQSum task), and the `.target` file should contain the desired output, one example per line (e.g., the summarized question). This setup can be adapted to a new task. For instance, you could place biomedical articles in the source files and brief summaries in the target files (a toy source/target pair is shown below). 56 | 57 | Go to `./textgen/gpt2`. 
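A toy source/target pair (illustrative only, not taken from the real MeQSum data) might look like:

```
train.source: SUBJECT: blood pressure medicine MESSAGE: I was started on lisinopril last month and would like to know what side effects I should watch out for ...
train.target: What are the side effects of lisinopril?
```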
58 | To finetune, run: 59 | ``` 60 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 61 | finetune_for_summarization.py --output_dir {run_dir} --model_name_or_path {checkpoint} 62 | --tokenizer_name stanford-crfm/pubmed_gpt_tokenizer --per_device_train_batch_size 1 63 | --per_device_eval_batch_size 1 --save_strategy no --do_eval --train_data_file 64 | data/meqsum/train.source --eval_data_file data/meqsum/val.source --save_total_limit 2 65 | --overwrite_output_dir --gradient_accumulation_steps {grad_accum} --learning_rate {lr} 66 | --warmup_ratio 0.5 --weight_decay 0.0 --seed 7 --evaluation_strategy steps --eval_steps 200 67 | --bf16 --num_train_epochs {num_epochs} --logging_steps 100 --logging_first_step 68 | ``` 69 | 70 | After finetuning, run generation on the test set by: 71 | 72 | ``` 73 | CUDA_VISIBLE_DEVICES=0 python -u run_generation_batch.py --fp16 --max_source_length -1 --length 400 --model_name_or_path={finetune_checkpoint} --num_return_sequences 5 --stop_token [SEP] --tokenizer_name={finetune_checkpoint} --task_mode=meqsum --control_mode=no --tuning_mode finetune --gen_dir gen_results__tgtlen400__no_repeat_ngram_size6 --batch_size 9 --temperature 1.0 --no_repeat_ngram_size 6 --length_penalty -0.5 --wandb_entity=None --wandb_project=None --wandb_run_name=None 74 | ``` 75 | 76 | 77 | ### Acknowledgement 78 | The NLG part of the code was built on https://github.com/XiangLi1999/PrefixTuning 79 | -------------------------------------------------------------------------------- /finetune/utils/hf_flash_gpt_2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
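# The classes below subclass the stock Hugging Face GPT-2 modules (attention, block, model, LM head)
# and swap the standard attention computation for flash_attn_unpadded_qkvpacked_func from the
# flash-attn package; the rest of the architecture is left unchanged.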
16 | """Modified HF GPT2 w/flash attention""" 17 | 18 | import os 19 | from typing import Optional, Tuple, Union 20 | 21 | import torch 22 | from einops import rearrange 23 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 24 | from torch import nn 25 | from transformers.models.gpt2.configuration_gpt2 import GPT2Config 26 | from transformers.models.gpt2.modeling_gpt2 import ( 27 | GPT2MLP, 28 | CausalLMOutputWithCrossAttentions, 29 | GPT2Attention, 30 | GPT2Block, 31 | GPT2LMHeadModel, 32 | GPT2Model, 33 | GPT2PreTrainedModel, 34 | ) 35 | 36 | 37 | class GPT2FlashAttention(GPT2Attention): 38 | def __init__(self, config, is_cross_attention=False, layer_idx=None): 39 | super().__init__(config=config, is_cross_attention=is_cross_attention, layer_idx=layer_idx) 40 | self.attn_pdrop = config.attn_pdrop 41 | 42 | def _attn(self, query, key, value, attention_mask=None, head_mask=None): 43 | # rearrange to flash attention form 44 | key = rearrange(key, 'b h s d -> b s h d') 45 | value = rearrange(value, 'b h s d -> b s h d') 46 | query = rearrange(query, 'b h s d -> b s h d') 47 | 48 | # stack 49 | qkv = torch.stack([query,key,value], dim=2) 50 | assert qkv.dtype in [torch.float16, torch.bfloat16] 51 | 52 | # flash attention logic 53 | batch_size = qkv.shape[0] 54 | seqlen = qkv.shape[1] 55 | dk = qkv.shape[4] 56 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 57 | max_s = seqlen 58 | cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device) 59 | attn_pdrop = self.attn_pdrop if self.training else 0.0 60 | softmax_scale = (1.0 / (dk ** 0.5)) if self.scale_attn_weights else 1.0 61 | softmax_scale = (softmax_scale / float(self.layer_idx + 1)) if self.scale_attn_by_inverse_layer_idx else softmax_scale 62 | output = flash_attn_unpadded_qkvpacked_func( 63 | qkv, cu_seqlens, max_s, attn_pdrop, 64 | softmax_scale=softmax_scale, causal=True 65 | ) 66 | output = rearrange(output, '(b s) ... 
-> b s ...', b=batch_size) 67 | output = rearrange(output, 'b s h d -> b h s d') 68 | 69 | return output, None 70 | 71 | 72 | class GPT2FlashBlock(GPT2Block): 73 | def __init__(self, config, layer_idx=None): 74 | super(GPT2Block, self).__init__() 75 | hidden_size = config.hidden_size 76 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 77 | 78 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 79 | self.attn = GPT2FlashAttention(config, layer_idx=layer_idx) 80 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 81 | 82 | if config.add_cross_attention: 83 | self.crossattention = GPT2FlashAttention(config, is_cross_attention=True, layer_idx=layer_idx) 84 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 85 | 86 | self.mlp = GPT2MLP(inner_dim, config) 87 | 88 | 89 | class GPT2FlashModel(GPT2Model): 90 | def __init__(self, config): 91 | super(GPT2Model, self).__init__(config) 92 | 93 | self.embed_dim = config.hidden_size 94 | 95 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 96 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 97 | 98 | self.drop = nn.Dropout(config.embd_pdrop) 99 | self.h = nn.ModuleList([GPT2FlashBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) 100 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) 101 | 102 | # Model parallel 103 | self.model_parallel = False 104 | self.device_map = None 105 | self.gradient_checkpointing = False 106 | 107 | # Initialize weights and apply final processing 108 | self.post_init() 109 | 110 | 111 | class GPT2FlashLMHeadModel(GPT2LMHeadModel): 112 | def __init__(self, config): 113 | super(GPT2LMHeadModel, self).__init__(config) 114 | 115 | self.transformer = GPT2FlashModel(config) 116 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 117 | 118 | # Model parallel 119 | self.model_parallel = False 120 | self.device_map = None 121 | 122 | # Initialize weights and apply final processing 123 | self.post_init() 124 | -------------------------------------------------------------------------------- /finetune/textgen/gpt2/finetune_for_summarization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | from dataclasses import dataclass, field 4 | from transformers import ( 5 | CONFIG_MAPPING, 6 | MODEL_WITH_LM_HEAD_MAPPING, 7 | AutoConfig, 8 | AutoModelWithLMHead, 9 | AutoTokenizer, 10 | HfArgumentParser, 11 | PreTrainedTokenizer, 12 | TextDataset, 13 | Trainer, 14 | TrainingArguments, 15 | set_seed, 16 | GPT2LMHeadModel, 17 | AutoModelForCausalLM, 18 | ) 19 | 20 | from sum_data_collator import DataCollatorForSumLanguageModeling 21 | from sum_dataset import LineByLineSumTextDataset 22 | 23 | import torch.distributed as dist 24 | 25 | import json 26 | 27 | import sys 28 | 29 | sys.path.insert(0, "../..") 30 | 31 | @dataclass 32 | class ModelArguments: 33 | """ 34 | Arguments for the model 35 | """ 36 | 37 | model_name_or_path: Optional[str] = field( 38 | default=None, 39 | metadata={ 40 | "help": ( 41 | "The model checkpoint for weights initialization. Leave None if you want to train a model from" 42 | " scratch." 
43 | ) 44 | }, 45 | ) 46 | 47 | tokenizer_name: Optional[str] = field( 48 | default="gpt2", metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 49 | ) 50 | 51 | use_flash: bool = field( 52 | default=False, metadata={"help": "Use flash attention."} 53 | ) 54 | 55 | @dataclass 56 | class DataArguments: 57 | """ 58 | Arguments for data 59 | """ 60 | 61 | train_data_file: Optional[str] = field( 62 | default=None, metadata={"help": "The input training data file (a text file)."} 63 | ) 64 | eval_data_file: Optional[str] = field( 65 | default=None, 66 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 67 | ) 68 | max_source_length: Optional[int] = field( 69 | default=510, metadata={"help": "the max source length of summarization data. "} 70 | ) 71 | train_max_target_length: Optional[int] = field( 72 | default=510, metadata={"help": "the max target length for training data. "} 73 | ) 74 | eval_max_target_length: Optional[int] = field( 75 | default=510, metadata={"help": "the max target length for dev data. "} 76 | ) 77 | seq_prefix: Optional[str] = field( 78 | default="", 79 | metadata={"help": "A string to begin every sequence with."}, 80 | ) 81 | no_sep: bool = field( 82 | default=False, metadata={"help": "Don't use a separator token."} 83 | ) 84 | block_size: int = field( 85 | default=-1, 86 | metadata={ 87 | "help": ( 88 | "Optional input sequence length after tokenization." 89 | "The training dataset will be truncated in block of this size for training." 90 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 91 | ) 92 | }, 93 | ) 94 | 95 | 96 | def get_dataset( 97 | args: DataArguments, 98 | tokenizer: PreTrainedTokenizer, 99 | evaluate: bool = False, 100 | cache_dir: Optional[str] = None, 101 | training_args: TrainingArguments = None, 102 | ): 103 | file_path = args.eval_data_file if evaluate else args.train_data_file 104 | max_source_length = args.max_source_length 105 | max_target_length = args.train_max_target_length if not evaluate else args.eval_max_target_length 106 | dataset = LineByLineSumTextDataset( 107 | tokenizer=tokenizer, 108 | file_path=file_path, 109 | block_size=1024, 110 | bos_tok=tokenizer.bos_token, 111 | eos_tok=tokenizer.eos_token, 112 | max_source_length=max_source_length, 113 | max_target_length=max_target_length, 114 | seq_prefix=args.seq_prefix, 115 | no_sep=args.no_sep 116 | ) 117 | 118 | return dataset 119 | 120 | 121 | def finetune(): 122 | # parse args 123 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 124 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 125 | # set seed 126 | set_seed(training_args.seed) 127 | # set up model 128 | config = AutoConfig.from_pretrained(model_args.model_name_or_path) 129 | if model_args.use_flash: 130 | from utils.hf_flash_gpt_2 import GPT2FlashLMHeadModel 131 | model = GPT2FlashLMHeadModel.from_pretrained( 132 | model_args.model_name_or_path, 133 | config=config, 134 | ) 135 | else: 136 | model = AutoModelForCausalLM.from_pretrained( 137 | model_args.model_name_or_path, 138 | config=config, 139 | ) 140 | # set up tokenizer 141 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name) 142 | # add extra pad token 143 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 144 | tokenizer.add_special_tokens({"bos_token": "<|startoftext|>"}) 145 | tokenizer.add_special_tokens({"eos_token": "<|endoftext|>"}) 146 | 
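    # resize the embedding matrix so the newly added [PAD], <|startoftext|>, and <|endoftext|> tokens have embedding rows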
embedding_layer = model.resize_token_embeddings(len(tokenizer)) 147 | # set up data collator 148 | data_collator = DataCollatorForSumLanguageModeling(tokenizer=tokenizer) 149 | # set up data sets 150 | train_dataset = get_dataset(data_args, tokenizer=tokenizer, training_args=training_args) 151 | eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) 152 | # set up trainer 153 | trainer = Trainer( 154 | model=model, 155 | args=training_args, 156 | train_dataset=train_dataset, 157 | eval_dataset=eval_dataset, 158 | tokenizer=tokenizer, 159 | data_collator=data_collator 160 | ) 161 | # launch fine tuning 162 | trainer.train() 163 | # save final model 164 | trainer.save_model() 165 | trainer.save_state() 166 | 167 | if __name__ == "__main__": 168 | finetune() 169 | -------------------------------------------------------------------------------- /finetune/textgen/gpt2/sum_data_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dataclasses import dataclass 4 | from torch.nn.utils.rnn import pad_sequence 5 | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy 6 | from transformers.tokenization_utils import PreTrainedTokenizer 7 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 8 | 9 | @dataclass 10 | class DataCollatorForSumLanguageModeling: 11 | """ 12 | Data collator used for language modeling. 13 | - collates batches of tensors, honoring their tokenizer's pad_token 14 | - preprocesses batches for masked language modeling 15 | """ 16 | tokenizer: PreTrainedTokenizer 17 | mlm: bool = False 18 | format_mode: str = 'cat' 19 | mlm_probability: float = 0.15 20 | 21 | def __call__( 22 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] 23 | ) -> Dict[str, torch.Tensor]: 24 | if isinstance(examples[0], (dict, BatchEncoding)): 25 | examples = [e["input_ids"] for e in examples] 26 | # print(examples[0]) 27 | # print(len(examples)) 28 | input_ids, labels, src, tgt = zip(*examples) 29 | # print(len(input_ids), len(labels), len(weights)) 30 | if self.mlm: 31 | inputs, labels = self.mask_tokens(batch) 32 | return {"input_ids": inputs, "labels": labels} 33 | else: 34 | 35 | # print(self.format_mode) 36 | 37 | if self.format_mode == 'peek' or self.format_mode == 'cat': 38 | mode_input = 1 39 | elif self.format_mode == 'nopeek': 40 | assert False, 'should use format_mode = peek or cat.' 41 | mode_input = 2 42 | elif self.format_mode == 'infix': 43 | assert False, 'should use format_mode = peek or cat.' 44 | mode_input = 4 45 | 46 | # mode_input = 1 # means that we take the input again. 47 | # mode_input = 2 # means that we do not peek at src again. 48 | # mode_input = 3 # means that we look at the categories, and see the input again. 
49 | 50 | # print(self.format_mode, mode_input) 51 | 52 | if mode_input == 1: 53 | # input, batch 54 | batch = self._tensorize_batch(input_ids) 55 | labels = self._tensorize_batch(labels) 56 | src = self._tensorize_batch(src) 57 | 58 | labels[labels == self.tokenizer.pad_token_id] = -100 # tgt 59 | src_attn = (src != self.tokenizer.pad_token_id) # src 60 | tgt_attn = (batch != self.tokenizer.pad_token_id) # tgt 61 | 62 | return {"input_ids": batch, "labels": labels} 63 | 64 | 65 | def _tensorize_batch( 66 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] 67 | ) -> torch.Tensor: 68 | # In order to accept both lists of lists and lists of Tensors 69 | if isinstance(examples[0], (list, tuple)): 70 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 71 | length_of_first = examples[0].size(0) 72 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 73 | if are_tensors_same_length: 74 | return torch.stack(examples, dim=0) 75 | else: 76 | if self.tokenizer._pad_token is None: 77 | raise ValueError( 78 | "You are attempting to pad samples but the tokenizer you are using" 79 | f" ({self.tokenizer.__class__.__name__}) does not have one." 80 | ) 81 | return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id) 82 | 83 | 84 | @dataclass 85 | class DataCollatorForSumBatchGenLanguageModeling: 86 | """ 87 | Data collator used for language modeling. 88 | - collates batches of tensors, honoring their tokenizer's pad_token 89 | - preprocesses batches for masked language modeling 90 | """ 91 | tokenizer: PreTrainedTokenizer 92 | mlm: bool = True 93 | format_mode: str = 'cat' 94 | mlm_probability: float = 0.15 95 | max_source_length: int = 512 96 | max_target_length: int = 100 97 | 98 | 99 | def __call__( 100 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] 101 | ) -> Dict[str, torch.Tensor]: 102 | if isinstance(examples[0], (dict, BatchEncoding)): 103 | examples = [e["input_ids"] for e in examples] 104 | # print(examples[0]) 105 | # print(len(examples)) 106 | 107 | mode_gen = 1 108 | 109 | if mode_gen == 0: 110 | input_ids, labels, src, tgt = zip(*examples) 111 | # print(len(input_ids), len(labels), len(weights)) 112 | 113 | 114 | 115 | src = self._tensorize_batch(src) #src 116 | tgt = self._tensorize_batch(tgt) # src 117 | 118 | src_attn = (src != self.tokenizer.pad_token_id) # src 119 | tgt_attn = (batch != self.tokenizer.pad_token_id) # tgt 120 | 121 | return {"input_ids": src, "labels": tgt, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 122 | 'src':src} 123 | 124 | else: 125 | src, tgt = zip(*examples) 126 | bsz = len(src) 127 | self.tokenizer.padding_side = "left" 128 | src = self.tokenizer(src, return_tensors="pt", padding=True, truncation=True, max_length=self.max_source_length) 129 | tgt = self.tokenizer(tgt, return_tensors="pt", padding=True, truncation=True, max_length=self.max_target_length) 130 | bos_seq = torch.ones(bsz, 1).fill_(self.tokenizer.bos_token_id).long() 131 | src_input_ids = torch.cat([src['input_ids'], bos_seq], dim=-1) 132 | bos_mask = torch.ones(bsz, 1).long() 133 | src_mask = torch.cat([src["attention_mask"], bos_mask],dim=-1) 134 | 135 | return {"input_ids": src_input_ids, "labels": tgt['input_ids'], 'src_attn': src_mask, 136 | 'tgt_attn': tgt["attention_mask"]} 137 | 138 | 139 | 140 | 141 | def _tensorize_batch( 142 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] 143 | ) -> torch.Tensor: 144 | # In order to accept both 
lists of lists and lists of Tensors 145 | if isinstance(examples[0], (list, tuple)): 146 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 147 | length_of_first = examples[0].size(0) 148 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 149 | if are_tensors_same_length: 150 | return torch.stack(examples, dim=0) 151 | else: 152 | if self.tokenizer._pad_token is None: 153 | raise ValueError( 154 | "You are attempting to pad samples but the tokenizer you are using" 155 | f" ({self.tokenizer.__class__.__name__}) does not have one." 156 | ) 157 | return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id) 158 | 159 | -------------------------------------------------------------------------------- /finetune/textgen/gpt2/sum_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | import time 5 | import copy 6 | import json 7 | from typing import Dict, List, Optional 8 | import ast 9 | import torch 10 | from torch.utils.data.dataset import Dataset 11 | 12 | from filelock import FileLock 13 | 14 | from transformers.tokenization_utils import PreTrainedTokenizer 15 | from transformers.utils import logging 16 | 17 | from pathlib import Path 18 | import linecache 19 | 20 | # from transformers import BertTokenizer, BertForMaskedLM, BertModel, BertTokenizerFast 21 | # from transformers import BertTokenizer, BertTokenizerFast 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | class LineByLineSumTextDataset(Dataset): 26 | """ 27 | This will be superseded by a framework-agnostic approach 28 | soon. 29 | """ 30 | 31 | def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, bos_tok:str, eos_tok:str, 32 | max_source_length:int, max_target_length:int, seq_prefix:str="", no_sep:bool=False, use_task_instruction:int=0, use_stream_mode:bool=True): 33 | assert os.path.isfile(file_path), f"Input file path {file_path} not found" 34 | # Here, we do not cache the features, operating under the assumption 35 | # that we will soon use fast multithreaded tokenizers from the 36 | # `tokenizers` repo everywhere =) 37 | logger.info("Creating features from dataset file at %s", file_path) 38 | 39 | self.src_file = file_path 40 | self.tgt_file = file_path[:-6] + 'target' 41 | self.max_source_length = max_source_length 42 | self.max_target_length = max_target_length 43 | if use_task_instruction: 44 | self.instruction = "Summarize the following text: " 45 | else: 46 | self.instruction = None 47 | print (f'Task instruction: "{self.instruction}"') 48 | 49 | separator = tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0] 50 | eos_idx = tokenizer(eos_tok, add_special_tokens=False)['input_ids'][0] 51 | 52 | self.bos_idx = separator 53 | self.eos_idx = eos_idx 54 | 55 | self.length = [len(x) for x in Path(self.tgt_file).open().readlines()] 56 | self.tokenizer = tokenizer 57 | 58 | self.use_stream_mode = use_stream_mode 59 | 60 | self.seq_prefix = seq_prefix 61 | self.no_sep = no_sep 62 | 63 | if self.use_stream_mode: 64 | return 65 | else: 66 | src_lines = [] 67 | with open(self.src_file, encoding="utf-8") as f: 68 | for line in f: 69 | line = line.strip() 70 | line = self.instruction + line if self.instruction else line 71 | if len(line) > 0 and not line.isspace(): 72 | src_lines.append(line) 73 | 74 | # print(len(list(f.read().splitlines()))) 75 | # src_lines = [line for line in f.read().splitlines() if (len(line) > 0 and not 
line.isspace())] 76 | print(len(src_lines)) 77 | with open(self.tgt_file, encoding="utf-8") as f: 78 | tgt_lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())] 79 | 80 | print(self.tgt_file, len(tgt_lines), '\n', self.src_file, len(src_lines)) 81 | 82 | assert len(tgt_lines) == len(src_lines) 83 | 84 | src_encoding = tokenizer(src_lines, add_special_tokens=True, truncation=True, max_length=max_source_length, 85 | is_split_into_words=False)['input_ids'] 86 | 87 | tgt_encoding = tokenizer(tgt_lines, add_special_tokens=True, truncation=True, max_length=max_target_length, 88 | is_split_into_words=False)['input_ids'] 89 | 90 | assert len(src_encoding) == len(tgt_encoding) 91 | separator = tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0] 92 | eos_idx = tokenizer(eos_tok, add_special_tokens=False)['input_ids'][0] 93 | 94 | edited_sents = [] 95 | for src, tgt in zip(src_encoding, tgt_encoding): 96 | sent = src + [separator] + tgt + [eos_idx] 97 | # sent = ' {} {} '.format(src, bos_tok) + tgt + ' {}'.format(eos_tok) 98 | edited_sents.append(sent) 99 | 100 | # batch_encoding = tokenizer(edited_sents, add_special_tokens=True, truncation=True, max_length=block_size, 101 | # is_split_into_words=False) 102 | 103 | self.examples = edited_sents 104 | 105 | self.labels = copy.deepcopy(self.examples) 106 | 107 | 108 | 109 | self.src_sent = [] 110 | self.tgt_sent = [] 111 | if True: 112 | separator = tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0] 113 | for i, elem in enumerate(self.labels): 114 | sep_idx = elem.index(separator) + 1 115 | self.src_sent.append(self.examples[i][:sep_idx-1]) 116 | self.tgt_sent.append(self.examples[i][sep_idx-1:]) 117 | self.labels[i][:sep_idx] = [-100] * sep_idx 118 | 119 | 120 | print(self.labels[0]) 121 | print(self.examples[0]) 122 | print(edited_sents[0]) 123 | print(self.src_sent[0]) 124 | print(self.tgt_sent[0]) 125 | # assert len(self.src_cat) == len(self.examples) 126 | 127 | 128 | 129 | 130 | def __len__(self): 131 | return len(self.length) 132 | 133 | 134 | def __getitem__(self, i): 135 | if not self.use_stream_mode: 136 | return (torch.tensor(self.examples[i], dtype=torch.long), 137 | torch.tensor(self.labels[i], dtype=torch.long), 138 | torch.tensor(self.src_sent[i], dtype=torch.long), 139 | torch.tensor(self.tgt_sent[i], dtype=torch.long), 140 | ) 141 | else: 142 | index = i + 1 # linecache starts at 1 143 | source_line = linecache.getline(str(self.src_file), index).rstrip("\n") 144 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 145 | assert source_line, f"empty source line for index {index}" 146 | assert tgt_line, f"empty tgt line for index {index}" 147 | 148 | source_line = self.instruction + source_line if self.instruction else self.seq_prefix + source_line 149 | 150 | src = self.tokenizer(source_line, add_special_tokens=True, truncation=True, max_length=self.max_source_length, 151 | is_split_into_words=False)['input_ids'] 152 | 153 | tgt = self.tokenizer(tgt_line, add_special_tokens=True, truncation=True, max_length=self.max_target_length, 154 | is_split_into_words=False)['input_ids'] 155 | 156 | if self.no_sep: 157 | sent = src + tgt + [self.eos_idx] 158 | label = copy.deepcopy(sent) 159 | label[:len(src)] = [-100] * len(src) 160 | src_sent = sent[:len(src)] 161 | tgt_sent = sent[len(src):] 162 | else: 163 | sent = src + [self.bos_idx] + tgt + [self.eos_idx] 164 | sep_idx = sent.index(self.bos_idx) + 1 165 | label = copy.deepcopy(sent) 166 | label[:sep_idx] = [-100] * 
sep_idx 167 | src_sent = sent[:sep_idx - 1] 168 | tgt_sent = sent[sep_idx - 1:] 169 | 170 | return (torch.tensor(sent, dtype=torch.long), 171 | torch.tensor(label, dtype=torch.long), 172 | torch.tensor(src_sent, dtype=torch.long), 173 | torch.tensor(tgt_sent, dtype=torch.long), 174 | ) 175 | 176 | 177 | class LineByLineSumBatchGenTextDataset(Dataset): 178 | """ 179 | This will be superseded by a framework-agnostic approach 180 | soon. 181 | """ 182 | 183 | def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, bos_tok:str, eos_tok:str, 184 | max_source_length:int, max_target_length:int, use_task_instruction:int=0): 185 | assert os.path.isfile(file_path), f"Input file path {file_path} not found" 186 | # Here, we do not cache the features, operating under the assumption 187 | # that we will soon use fast multithreaded tokenizers from the 188 | # `tokenizers` repo everywhere =) 189 | logger.info("Creating features from dataset file at %s", file_path) 190 | 191 | self.src_file = file_path 192 | self.tgt_file = file_path[:-6] + 'target' 193 | self.max_source_length = max_source_length 194 | self.max_target_length = max_target_length 195 | if use_task_instruction: 196 | self.instruction = "Summarize the following text: " 197 | else: 198 | self.instruction = None 199 | print (f'Task instruction: "{self.instruction}"') 200 | 201 | separator = tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0] 202 | eos_tok = "[SEP]" 203 | eos_idx = tokenizer(eos_tok, add_special_tokens=False)['input_ids'][0] 204 | 205 | self.bos_idx = separator 206 | self.eos_idx = eos_idx 207 | 208 | tokenizer.pad_token = "[PAD]" 209 | tokenizer.pad_token_id = 28896 210 | 211 | self.length = [len(x) for x in Path(self.tgt_file).open().readlines()] 212 | self.tokenizer = tokenizer 213 | return 214 | 215 | 216 | 217 | 218 | def __len__(self): 219 | return len(self.length) 220 | 221 | # def __getitem__(self, i) -> torch.Tensor: 222 | def __getitem__(self, i): 223 | # return (torch.tensor(self.examples[i], dtype=torch.long), 224 | # torch.tensor(self.labels[i], dtype=torch.long), 225 | # torch.tensor(self.src_sent[i], dtype=torch.long), 226 | # torch.tensor(self.tgt_sent[i], dtype=torch.long), 227 | # ) 228 | 229 | modegen = 1 230 | index = i + 1 # linecache starts at 1 231 | source_line = linecache.getline(str(self.src_file), index).rstrip("\n") 232 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 233 | assert source_line, f"empty source line for index {index}" 234 | assert tgt_line, f"empty tgt line for index {index}" 235 | 236 | source_line = self.instruction + source_line if self.instruction else source_line 237 | 238 | if modegen == 0: 239 | 240 | src = self.tokenizer(source_line, add_special_tokens=True, truncation=True, max_length=self.max_source_length, 241 | is_split_into_words=False)['input_ids'] 242 | 243 | tgt = self.tokenizer(tgt_line, add_special_tokens=True, truncation=True, max_length=self.max_target_length, 244 | is_split_into_words=False)['input_ids'] 245 | 246 | sent = src + [self.bos_idx] + tgt + [self.eos_idx] 247 | 248 | sep_idx = sent.index(self.bos_idx) + 1 249 | 250 | label = copy.deepcopy(sent) 251 | label[:sep_idx] = [-100] * sep_idx 252 | src_sent = sent[:sep_idx - 1] 253 | tgt_sent = sent[sep_idx - 1:] 254 | 255 | return (torch.tensor(sent, dtype=torch.long), 256 | torch.tensor(label, dtype=torch.long), 257 | ) 258 | 259 | else: 260 | return (source_line, tgt_line) 261 | 262 | 
-------------------------------------------------------------------------------- /finetune/utils/custom_modeling_gpt2.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from dataclasses import dataclass 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | import torch.utils.checkpoint 8 | from packaging import version 9 | from torch import nn 10 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 11 | 12 | 13 | from transformers.activations import ACT2FN 14 | from transformers.file_utils import ( 15 | ModelOutput, 16 | add_code_sample_docstrings, 17 | add_start_docstrings, 18 | add_start_docstrings_to_model_forward, 19 | replace_return_docstrings, 20 | ) 21 | from transformers.modeling_outputs import ( 22 | BaseModelOutputWithPastAndCrossAttentions, 23 | CausalLMOutputWithCrossAttentions, 24 | SequenceClassifierOutputWithPast, 25 | TokenClassifierOutput, 26 | MultipleChoiceModelOutput, 27 | ) 28 | from transformers.modeling_utils import ( 29 | Conv1D, 30 | PreTrainedModel, 31 | SequenceSummary, 32 | find_pruneable_heads_and_indices, 33 | prune_conv1d_layer, 34 | ) 35 | from transformers.utils import logging 36 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 37 | from transformers.models.gpt2.configuration_gpt2 import GPT2Config 38 | 39 | 40 | logger = logging.get_logger(__name__) 41 | 42 | _CHECKPOINT_FOR_DOC = "gpt2" 43 | _CONFIG_FOR_DOC = "GPT2Config" 44 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer" 45 | 46 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ 47 | "gpt2", 48 | "gpt2-medium", 49 | "gpt2-large", 50 | "gpt2-xl", 51 | "distilgpt2", 52 | # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 53 | ] 54 | from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel 55 | 56 | 57 | class GPT2ForTokenClassification(GPT2PreTrainedModel): 58 | def __init__(self, config): 59 | super().__init__(config) 60 | self.num_labels = config.num_labels 61 | 62 | self.transformer = GPT2Model(config) 63 | if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: 64 | classifier_dropout = config.classifier_dropout 65 | elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: 66 | classifier_dropout = config.hidden_dropout 67 | else: 68 | classifier_dropout = 0.1 69 | self.dropout = nn.Dropout(classifier_dropout) 70 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 71 | 72 | # Model parallel 73 | self.model_parallel = False 74 | self.device_map = None 75 | 76 | # Initialize weights and apply final processing 77 | self.init_weights() 78 | 79 | def forward( 80 | self, 81 | input_ids=None, 82 | past_key_values=None, 83 | attention_mask=None, 84 | token_type_ids=None, 85 | position_ids=None, 86 | head_mask=None, 87 | inputs_embeds=None, 88 | labels=None, 89 | use_cache=None, 90 | output_attentions=None, 91 | output_hidden_states=None, 92 | return_dict=None, 93 | ): 94 | r""" 95 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 96 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 97 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 98 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
99 | """ 100 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 101 | 102 | transformer_outputs = self.transformer( 103 | input_ids, 104 | past_key_values=past_key_values, 105 | attention_mask=attention_mask, 106 | token_type_ids=token_type_ids, 107 | position_ids=position_ids, 108 | head_mask=head_mask, 109 | inputs_embeds=inputs_embeds, 110 | use_cache=use_cache, 111 | output_attentions=output_attentions, 112 | output_hidden_states=output_hidden_states, 113 | return_dict=return_dict, 114 | ) 115 | 116 | hidden_states = transformer_outputs[0] 117 | hidden_states = self.dropout(hidden_states) 118 | logits = self.classifier(hidden_states) 119 | 120 | loss = None 121 | if labels is not None: 122 | loss_fct = CrossEntropyLoss() 123 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 124 | 125 | if not return_dict: 126 | output = (logits,) + transformer_outputs[2:] 127 | return ((loss,) + output) if loss is not None else output 128 | 129 | return TokenClassifierOutput( 130 | loss=loss, 131 | logits=logits, 132 | hidden_states=transformer_outputs.hidden_states, 133 | attentions=transformer_outputs.attentions, 134 | ) 135 | 136 | 137 | class GPT2ForMultipleChoice(GPT2PreTrainedModel): 138 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] 139 | 140 | def __init__(self, config): 141 | super().__init__(config) 142 | # self.num_labels = config.num_labels 143 | if config.use_flash: 144 | print("GPT2ForMultipleChoice using Flash !!") 145 | from .hf_flash_gpt_2 import GPT2FlashModel 146 | self.transformer = GPT2FlashModel(config) 147 | elif config.use_gpt_neo: 148 | print("Using GPT2Neo Model !!") 149 | from .custom_modeling_gpt_neo import GPTNeoModel 150 | self.transformer = GPTNeoModel(config) 151 | else: 152 | self.transformer = GPT2Model(config) 153 | print("GPT2ForMultipleChoice not using Flash !!") 154 | # self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) 155 | hidden_size = config.hidden_size if config.use_gpt_neo else config.n_embd 156 | self.classifier = nn.Linear(hidden_size, 1) 157 | 158 | self.init_weights() 159 | 160 | # Model parallel 161 | self.model_parallel = False 162 | self.device_map = None 163 | 164 | def forward( 165 | self, 166 | input_ids=None, 167 | past_key_values=None, 168 | attention_mask=None, 169 | token_type_ids=None, 170 | position_ids=None, 171 | head_mask=None, 172 | inputs_embeds=None, 173 | labels=None, 174 | use_cache=None, 175 | output_attentions=None, 176 | output_hidden_states=None, 177 | return_dict=None, 178 | ): 179 | r""" 180 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 181 | Labels for computing the multiple choice classification loss. Indices should be in :obj:`[0, ..., 182 | num_choices - 1]`, where `num_choices` is the size of the second dimension of the input tensors. 
(See 183 | `input_ids` above) 184 | """ 185 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 186 | 187 | if input_ids is not None: 188 | batch_size, num_choices, sequence_length = input_ids.shape[:3] 189 | else: 190 | batch_size, num_choices, sequence_length = inputs_embeds.shape[:3] 191 | 192 | input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None 193 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 194 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 195 | position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 196 | inputs_embeds = ( 197 | inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) 198 | if inputs_embeds is not None 199 | else None 200 | ) 201 | 202 | transformer_outputs = self.transformer( 203 | input_ids, 204 | past_key_values=past_key_values, 205 | attention_mask=attention_mask, 206 | token_type_ids=token_type_ids, 207 | position_ids=position_ids, 208 | head_mask=head_mask, 209 | inputs_embeds=inputs_embeds, 210 | use_cache=use_cache, 211 | output_attentions=output_attentions, 212 | output_hidden_states=output_hidden_states, 213 | return_dict=return_dict, 214 | ) 215 | hidden_states = transformer_outputs[0] 216 | logits = self.classifier(hidden_states) #[batch x num_choices, ] 217 | 218 | assert ( 219 | self.config.pad_token_id is not None 220 | ), "Cannot handle if no padding token is defined." 221 | if self.config.pad_token_id is None: 222 | sequence_lengths = -1 223 | else: 224 | if input_ids is not None: 225 | sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 226 | else: 227 | sequence_lengths = -1 228 | logger.warning( 229 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " 230 | f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" 231 | ) 232 | 233 | pooled_logits = logits[range(batch_size * num_choices), sequence_lengths] #[batch x num_choices, ] 234 | reshaped_logits = pooled_logits.view(-1, num_choices) #[batch, num_choices] 235 | 236 | loss = None 237 | if labels is not None: 238 | loss_fct = CrossEntropyLoss() 239 | loss = loss_fct(reshaped_logits, labels) 240 | 241 | if not return_dict: 242 | output = (reshaped_logits,) + transformer_outputs[2:] 243 | return ((loss,) + output) if loss is not None else output 244 | 245 | return MultipleChoiceModelOutput( 246 | loss=loss, 247 | logits=reshaped_logits, 248 | # hidden_states=transformer_outputs.hidden_states, 249 | # attentions=transformer_outputs.attentions, 250 | ) 251 | 252 | 253 | class GPT2ForSequenceClassification(GPT2PreTrainedModel): 254 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] 255 | 256 | def __init__(self, config): 257 | super().__init__(config) 258 | self.num_labels = config.num_labels 259 | if config.use_flash: 260 | print("GPT2ForSequenceClassification using Flash !!") 261 | from .hf_flash_gpt_2 import GPT2FlashModel 262 | self.transformer = GPT2FlashModel(config) 263 | else: 264 | self.transformer = GPT2Model(config) 265 | 266 | self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False) 267 | 268 | self.init_weights() 269 | 270 | # Model parallel 271 | self.model_parallel = False 272 | self.device_map = None 273 | 274 | def forward( 275 | self, 276 | input_ids=None, 277 | past_key_values=None, 278 | attention_mask=None, 279 | token_type_ids=None, 280 | position_ids=None, 281 | head_mask=None, 282 | inputs_embeds=None, 283 | labels=None, 284 | use_cache=None, 285 | output_attentions=None, 286 | output_hidden_states=None, 287 | return_dict=None, 288 | ): 289 | r""" 290 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 291 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 292 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 293 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 294 | """ 295 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 296 | 297 | transformer_outputs = self.transformer( 298 | input_ids, 299 | past_key_values=past_key_values, 300 | attention_mask=attention_mask, 301 | token_type_ids=token_type_ids, 302 | position_ids=position_ids, 303 | head_mask=head_mask, 304 | inputs_embeds=inputs_embeds, 305 | use_cache=use_cache, 306 | output_attentions=output_attentions, 307 | output_hidden_states=output_hidden_states, 308 | return_dict=return_dict, 309 | ) 310 | hidden_states = transformer_outputs[0] 311 | logits = self.classifier(hidden_states) 312 | 313 | if input_ids is not None: 314 | batch_size, sequence_length = input_ids.shape[:2] 315 | else: 316 | batch_size, sequence_length = inputs_embeds.shape[:2] 317 | 318 | assert ( 319 | self.config.pad_token_id is not None or batch_size == 1 320 | ), "Cannot handle batch sizes > 1 if no padding token is defined." 
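# Editor's note: as in the stock Hugging Face GPT-2 classification head, the logits are
# pooled at the last non-padding position of each sequence, since GPT-2 has no [CLS]
# token and only the final token has attended to the full input.
# Illustrative shapes, assuming pad_token_id = 0 and batch_size = 2:
#   input_ids        = [[5, 7, 9, 0, 0], [3, 4, 0, 0, 0]]
#   sequence_lengths = [2, 1]                  # index of each last real token
#   pooled_logits    = logits[[0, 1], [2, 1]]  # shape [2, num_labels]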
321 | if self.config.pad_token_id is None: 322 | sequence_lengths = -1 323 | else: 324 | if input_ids is not None: 325 | sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 326 | else: 327 | sequence_lengths = -1 328 | logger.warning( 329 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " 330 | f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" 331 | ) 332 | 333 | pooled_logits = logits[range(batch_size), sequence_lengths] 334 | 335 | loss = None 336 | if labels is not None: 337 | if self.num_labels == 1: 338 | # We are doing regression 339 | loss_fct = MSELoss() 340 | loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) 341 | else: 342 | loss_fct = CrossEntropyLoss() 343 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 344 | 345 | if not return_dict: 346 | output = (pooled_logits,) + transformer_outputs[1:] 347 | return ((loss,) + output) if loss is not None else output 348 | 349 | return SequenceClassifierOutputWithPast( 350 | loss=loss, 351 | logits=pooled_logits, 352 | # past_key_values=transformer_outputs.past_key_values, 353 | # hidden_states=transformer_outputs.hidden_states, 354 | # attentions=transformer_outputs.attentions, 355 | ) 356 | -------------------------------------------------------------------------------- /finetune/mc/run_multiple_choice.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for multiple choice. 18 | 19 | https://github.com/huggingface/transformers/blob/bff1c71e84e392af9625c345f9ea71f7b6d75fb3/examples/pytorch/multiple-choice/run_swag.py 20 | """ 21 | # You can also adapt this script on your own multiple choice task. Pointers for this are left as comments. 22 | 23 | import logging 24 | import os 25 | import sys 26 | from dataclasses import dataclass, field 27 | from typing import Optional, Union 28 | 29 | import datasets 30 | import numpy as np 31 | import torch 32 | from datasets import load_dataset 33 | 34 | import transformers 35 | from transformers import ( 36 | AutoConfig, 37 | AutoModelForMultipleChoice, 38 | AutoTokenizer, 39 | HfArgumentParser, 40 | Trainer, 41 | TrainingArguments, 42 | default_data_collator, 43 | set_seed, 44 | ) 45 | from transformers.file_utils import PaddingStrategy 46 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 47 | from transformers.trainer_utils import get_last_checkpoint 48 | from transformers.utils import check_min_version 49 | 50 | sys.path.insert(0, '..') 51 | from utils.custom_modeling_gpt2 import GPT2ForMultipleChoice 52 | 53 | 54 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
55 | # check_min_version("4.9.0") 56 | 57 | logger = logging.getLogger(__name__) 58 | 59 | 60 | @dataclass 61 | class ModelArguments: 62 | """ 63 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 64 | """ 65 | 66 | model_name_or_path: str = field( 67 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 68 | ) 69 | config_name: Optional[str] = field( 70 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 71 | ) 72 | tokenizer_name: Optional[str] = field( 73 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 74 | ) 75 | cache_dir: Optional[str] = field( 76 | default=None, 77 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 78 | ) 79 | use_fast_tokenizer: bool = field( 80 | default=True, 81 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 82 | ) 83 | model_revision: str = field( 84 | default="main", 85 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 86 | ) 87 | use_auth_token: bool = field( 88 | default=False, 89 | metadata={ 90 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 91 | "with private models)." 92 | }, 93 | ) 94 | use_flash: bool = field( 95 | default=False, 96 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 97 | ) 98 | use_gpt_neo: bool = field( 99 | default=False, 100 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 101 | ) 102 | 103 | 104 | @dataclass 105 | class DataTrainingArguments: 106 | """ 107 | Arguments pertaining to what data we are going to input our model for training and eval. 108 | """ 109 | 110 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 111 | validation_file: Optional[str] = field( 112 | default=None, 113 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 114 | ) 115 | test_file: Optional[str] = field( 116 | default=None, 117 | metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."}, 118 | ) 119 | overwrite_cache: bool = field( 120 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 121 | ) 122 | preprocessing_num_workers: Optional[int] = field( 123 | default=None, 124 | metadata={"help": "The number of processes to use for the preprocessing."}, 125 | ) 126 | # num_choices: int = field( 127 | # default=4, 128 | # metadata={"help": "Number of choices in multiple-choice QA."}, 129 | # ) 130 | max_seq_length: Optional[int] = field( 131 | default=None, 132 | metadata={ 133 | "help": "The maximum total input sequence length after tokenization. If passed, sequences longer " 134 | "than this will be truncated, sequences shorter will be padded." 135 | }, 136 | ) 137 | pad_to_max_length: bool = field( 138 | default=False, 139 | metadata={ 140 | "help": "Whether to pad all samples to the maximum sentence length. " 141 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 142 | "efficient on GPU but very bad for TPU." 
143 | }, 144 | ) 145 | max_train_samples: Optional[int] = field( 146 | default=None, 147 | metadata={ 148 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 149 | "value if set." 150 | }, 151 | ) 152 | max_eval_samples: Optional[int] = field( 153 | default=None, 154 | metadata={ 155 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 156 | "value if set." 157 | }, 158 | ) 159 | 160 | def __post_init__(self): 161 | if self.train_file is not None: 162 | extension = self.train_file.split(".")[-1] 163 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 164 | if self.validation_file is not None: 165 | extension = self.validation_file.split(".")[-1] 166 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 167 | if self.test_file is not None: 168 | extension = self.test_file.split(".")[-1] 169 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 170 | 171 | @dataclass 172 | class DataCollatorForMultipleChoice: 173 | """ 174 | Data collator that will dynamically pad the inputs for multiple choice received. 175 | Args: 176 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 177 | The tokenizer used for encoding the data. 178 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): 179 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 180 | among: 181 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 182 | sequence if provided). 183 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 184 | maximum acceptable input length for the model if that argument is not provided. 185 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 186 | different lengths). 187 | max_length (:obj:`int`, `optional`): 188 | Maximum length of the returned list and optionally padding length (see above). 189 | pad_to_multiple_of (:obj:`int`, `optional`): 190 | If set will pad the sequence to a multiple of the provided value. 191 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 192 | 7.5 (Volta). 
193 | """ 194 | 195 | tokenizer: PreTrainedTokenizerBase 196 | padding: Union[bool, str, PaddingStrategy] = True 197 | max_length: Optional[int] = None 198 | pad_to_multiple_of: Optional[int] = None 199 | 200 | def __call__(self, features): 201 | label_name = "label" if "label" in features[0].keys() else "labels" 202 | labels = [int(feature.pop(label_name)) for feature in features] 203 | batch_size = len(features) 204 | num_choices = len(features[0]["input_ids"]) 205 | flattened_features = [ 206 | [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features 207 | ] 208 | flattened_features = sum(flattened_features, []) 209 | 210 | batch = self.tokenizer.pad( 211 | flattened_features, 212 | padding=self.padding, 213 | max_length=self.max_length, 214 | pad_to_multiple_of=self.pad_to_multiple_of, 215 | return_tensors="pt", 216 | ) 217 | 218 | # Un-flatten 219 | batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()} 220 | # Add back labels 221 | batch["labels"] = torch.tensor(labels, dtype=torch.int64) 222 | return batch 223 | 224 | 225 | def main(): 226 | # See all possible arguments in src/transformers/training_args.py 227 | # or by passing the --help flag to this script. 228 | # We now keep distinct sets of args, for a cleaner separation of concerns. 229 | 230 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 231 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 232 | # If we pass only one argument to the script and it's the path to a json file, 233 | # let's parse it to get our arguments. 234 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 235 | else: 236 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 237 | 238 | # Setup logging 239 | logging.basicConfig( 240 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 241 | datefmt="%m/%d/%Y %H:%M:%S", 242 | handlers=[logging.StreamHandler(sys.stdout)], 243 | ) 244 | log_level = training_args.get_process_log_level() 245 | logger.setLevel(log_level) 246 | datasets.utils.logging.set_verbosity(log_level) 247 | transformers.utils.logging.set_verbosity(log_level) 248 | transformers.utils.logging.enable_default_handler() 249 | transformers.utils.logging.enable_explicit_format() 250 | 251 | # Log on each process the small summary: 252 | logger.warning( 253 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 254 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 255 | ) 256 | logger.info(f"Training/evaluation parameters {training_args}") 257 | 258 | # Detecting last checkpoint. 259 | last_checkpoint = None 260 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 261 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 262 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 263 | raise ValueError( 264 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 265 | "Use --overwrite_output_dir to overcome." 266 | ) 267 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 268 | logger.info( 269 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 270 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
271 | ) 272 | 273 | # Set seed before initializing model. 274 | set_seed(training_args.seed) 275 | 276 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 277 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 278 | # (the dataset will be downloaded automatically from the datasets Hub). 279 | 280 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 281 | # 'text' is found. You can easily tweak this behavior (see below). 282 | 283 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 284 | # download the dataset. 285 | if data_args.train_file is not None or data_args.validation_file is not None: 286 | data_files = {} 287 | if data_args.train_file is not None: 288 | data_files["train"] = data_args.train_file 289 | if data_args.validation_file is not None: 290 | data_files["validation"] = data_args.validation_file 291 | if data_args.test_file is not None: 292 | data_files["test"] = data_args.test_file 293 | extension = data_args.train_file.split(".")[-1] 294 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 295 | else: 296 | # Downloading and loading the swag dataset from the hub. 297 | raw_datasets = load_dataset("swag", "regular", cache_dir=model_args.cache_dir) 298 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 299 | # https://huggingface.co/docs/datasets/loading_datasets.html. 300 | 301 | # Load pretrained model and tokenizer 302 | 303 | # Distributed training: 304 | # The .from_pretrained methods guarantee that only one local process can concurrently 305 | # download model & vocab. 
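# Editor's note: a hedged example invocation for this script (paths and hyperparameters
# are placeholders, not taken from the original repo). The JSON files are expected to
# provide "sent1", "sent2", "ending0"..."endingN" and "label" columns, as assumed by the
# preprocessing further down:
#
#   python run_multiple_choice.py \
#     --model_name_or_path gpt2-medium \
#     --train_file path/to/train.json \
#     --validation_file path/to/dev.json \
#     --test_file path/to/test.json \
#     --do_train --do_eval --do_predict \
#     --max_seq_length 512 \
#     --per_device_train_batch_size 4 \
#     --learning_rate 2e-5 --num_train_epochs 3 \
#     --output_dir runs/mc_gpt2
#
# The custom --use_flash / --use_gpt_neo flags defined above switch the backbone inside
# GPT2ForMultipleChoice to GPT2FlashModel or GPTNeoModel respectively.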
306 | config = AutoConfig.from_pretrained( 307 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 308 | cache_dir=model_args.cache_dir, 309 | revision=model_args.model_revision, 310 | use_auth_token=True if model_args.use_auth_token else None, 311 | ) 312 | config.use_flash = model_args.use_flash 313 | config.use_gpt_neo = model_args.use_gpt_neo 314 | tokenizer = AutoTokenizer.from_pretrained( 315 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 316 | cache_dir=model_args.cache_dir, 317 | use_fast=model_args.use_fast_tokenizer, 318 | revision=model_args.model_revision, 319 | use_auth_token=True if model_args.use_auth_token else None, 320 | ) 321 | #Added for GPT2 322 | if config.model_type in ("gpt2", "gpt_neo"): 323 | model_class = GPT2ForMultipleChoice 324 | else: 325 | model_class = AutoModelForMultipleChoice 326 | 327 | model = model_class.from_pretrained( 328 | model_args.model_name_or_path, 329 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 330 | config=config, 331 | cache_dir=model_args.cache_dir, 332 | revision=model_args.model_revision, 333 | use_auth_token=True if model_args.use_auth_token else None, 334 | ) 335 | #Added for GPT2 336 | if tokenizer.pad_token_id is None: 337 | print('Adding [PAD] token to tokenizer and model word embeddings.') 338 | num_added_tokens = tokenizer.add_special_tokens({'pad_token': '[PAD]', 'cls_token': '[CLS]', 'sep_token': '[SEP]'}) 339 | embedding_layer = model.resize_token_embeddings(len(tokenizer)) 340 | config.pad_token_id = tokenizer.pad_token_id 341 | 342 | 343 | 344 | # When using your own dataset or a different dataset from swag, you will probably need to change this. 345 | _num_choices = len([elm for elm in raw_datasets['train'].features.keys() if elm.startswith('ending')]) 346 | print ('\nnum_choices according to dataset:', _num_choices, '\n') 347 | # raw_datasets['train'].features: {'id': Value(dtype='int64', id=None), 'sent1': Value(dtype='string', id=None), 'sent2': Value(dtype='string', id=None), 'ending0': Value(dtype='string', id=None), 'ending1': Value(dtype='string', id=None), 'ending2': Value(dtype='string', id=None), 'ending3': Value(dtype='string', id=None), 'label': Value(dtype='string', id=None)} 348 | ending_names = [f"ending{i}" for i in range(_num_choices)] 349 | context_name = "sent1" 350 | question_header_name = "sent2" 351 | 352 | if data_args.max_seq_length is None: 353 | max_seq_length = tokenizer.model_max_length 354 | if max_seq_length > 1024: 355 | logger.warning( 356 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " 357 | "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." 358 | ) 359 | max_seq_length = 1024 360 | else: 361 | if data_args.max_seq_length > tokenizer.model_max_length: 362 | logger.warning( 363 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 364 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 365 | ) 366 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 367 | 368 | # Preprocessing the datasets. 
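# Editor's note on the preprocessing below: each row
#   {sent1, sent2, ending0 ... ending{K-1}, label}
# is expanded into K (context, "sent2 + ending_i") pairs, the flattened list is tokenized,
# and the ids are regrouped so that every feature has shape [num_choices, seq_len].
# For GPT-2 a sep_token is appended to both segments because the model has no native
# segment separator. Toy sketch for K = 2 (strings are placeholders):
#   first_sentences  = ["<sent1>[SEP]", "<sent1>[SEP]"]
#   second_sentences = ["<sent2> <ending0>[SEP]", "<sent2> <ending1>[SEP]"]
#   -> input_ids per example: [[ids of pair 0], [ids of pair 1]]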
369 | def preprocess_function(examples): 370 | first_sentences = [[context] * _num_choices for context in examples[context_name]] 371 | question_headers = examples[question_header_name] 372 | second_sentences = [ 373 | [f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers) 374 | ] 375 | 376 | # Flatten out 377 | first_sentences = sum(first_sentences, []) 378 | second_sentences = sum(second_sentences, []) 379 | 380 | #Added for GPT2 381 | if config.model_type == "gpt2": 382 | first_sentences = [s + tokenizer.sep_token for s in first_sentences] 383 | second_sentences = [s + tokenizer.sep_token for s in second_sentences] 384 | 385 | # Tokenize 386 | tokenized_examples = tokenizer( 387 | first_sentences, 388 | second_sentences, 389 | truncation=True, 390 | max_length=max_seq_length, 391 | padding="max_length" if data_args.pad_to_max_length else False, 392 | ) 393 | # Un-flatten 394 | return {k: [v[i : i + _num_choices] for i in range(0, len(v), _num_choices)] for k, v in tokenized_examples.items()} 395 | 396 | 397 | if training_args.do_train: 398 | if "train" not in raw_datasets: 399 | raise ValueError("--do_train requires a train dataset") 400 | train_dataset = raw_datasets["train"] 401 | if data_args.max_train_samples is not None: 402 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 403 | with training_args.main_process_first(desc="train dataset map pre-processing"): 404 | train_dataset = train_dataset.map( 405 | preprocess_function, 406 | batched=True, 407 | num_proc=data_args.preprocessing_num_workers, 408 | load_from_cache_file=not data_args.overwrite_cache, 409 | ) 410 | 411 | if training_args.do_eval: 412 | if "validation" not in raw_datasets: 413 | raise ValueError("--do_eval requires a validation dataset") 414 | eval_dataset = raw_datasets["validation"] 415 | if data_args.max_eval_samples is not None: 416 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 417 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 418 | eval_dataset = eval_dataset.map( 419 | preprocess_function, 420 | batched=True, 421 | num_proc=data_args.preprocessing_num_workers, 422 | load_from_cache_file=not data_args.overwrite_cache, 423 | ) 424 | 425 | if training_args.do_predict: #Added 426 | if "test" not in raw_datasets: 427 | raise ValueError("--do_predict requires a test dataset") 428 | predict_dataset = raw_datasets["test"] 429 | with training_args.main_process_first(desc="test dataset map pre-processing"): 430 | predict_dataset = predict_dataset.map( 431 | preprocess_function, 432 | batched=True, 433 | num_proc=data_args.preprocessing_num_workers, 434 | load_from_cache_file=not data_args.overwrite_cache, 435 | ) 436 | 437 | # Data collator 438 | data_collator = ( 439 | default_data_collator 440 | if data_args.pad_to_max_length 441 | else DataCollatorForMultipleChoice(tokenizer=tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) 442 | ) 443 | 444 | # Metric 445 | def compute_metrics(eval_predictions): 446 | predictions, label_ids = eval_predictions 447 | preds = np.argmax(predictions, axis=1) 448 | return {"accuracy": (preds == label_ids).astype(np.float32).mean().item()} 449 | 450 | # Initialize our Trainer 451 | trainer = Trainer( 452 | model=model, 453 | args=training_args, 454 | train_dataset=train_dataset if training_args.do_train else None, 455 | eval_dataset=eval_dataset if training_args.do_eval else None, 456 | tokenizer=tokenizer, 457 | 
data_collator=data_collator, 458 | compute_metrics=compute_metrics, 459 | ) 460 | 461 | # Training 462 | if training_args.do_train: 463 | checkpoint = None 464 | if training_args.resume_from_checkpoint is not None: 465 | checkpoint = training_args.resume_from_checkpoint 466 | elif last_checkpoint is not None: 467 | checkpoint = last_checkpoint 468 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 469 | trainer.save_model() # Saves the tokenizer too for easy upload 470 | metrics = train_result.metrics 471 | 472 | max_train_samples = ( 473 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 474 | ) 475 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 476 | 477 | trainer.log_metrics("train", metrics) 478 | trainer.save_metrics("train", metrics) 479 | trainer.save_state() 480 | 481 | # Evaluation 482 | if training_args.do_eval: 483 | logger.info("*** Evaluate ***") 484 | 485 | metrics = trainer.evaluate() 486 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 487 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 488 | 489 | trainer.log_metrics("eval", metrics) 490 | trainer.save_metrics("eval", metrics) 491 | 492 | if training_args.do_predict: #Added 493 | logger.info("*** Predict ***") 494 | results = trainer.predict(predict_dataset) 495 | metrics = results.metrics 496 | metrics["predict_samples"] = len(predict_dataset) 497 | 498 | trainer.log_metrics("predict", metrics) 499 | trainer.save_metrics("predict", metrics) 500 | trainer.log(metrics) #Added 501 | 502 | #Added 503 | import json 504 | output_dir = training_args.output_dir 505 | json.dump({"predictions": results.predictions.tolist(), "label_ids": results.label_ids.tolist()}, 506 | open(f"{output_dir}/predict_outputs.json", "w")) 507 | 508 | 509 | if training_args.push_to_hub: 510 | trainer.push_to_hub( 511 | finetuned_from=model_args.model_name_or_path, 512 | tasks="multiple-choice", 513 | dataset_tags="swag", 514 | dataset_args="regular", 515 | dataset="SWAG", 516 | language="en", 517 | ) 518 | 519 | 520 | def _mp_fn(index): 521 | # For xla_spawn (TPUs) 522 | main() 523 | 524 | 525 | if __name__ == "__main__": 526 | main() 527 | -------------------------------------------------------------------------------- /finetune/seqcls/run_seqcls_gpt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification. 17 | 18 | Adapted from 19 | https://github.com/huggingface/transformers/blob/72aee83ced5f31302c5e331d896412737287f976/examples/pytorch/text-classification/run_glue.py 20 | """ 21 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 
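# Editor's note: a hedged example invocation (paths and hyperparameters are placeholders).
# For local files the script picks up "sentence1"/"sentence2" (or the first two non-label
# columns) plus a "label" column:
#
#   python run_seqcls_gpt.py \
#     --model_name_or_path gpt2-medium \
#     --train_file path/to/train.json \
#     --validation_file path/to/dev.json \
#     --test_file path/to/test.json \
#     --do_train --do_eval --do_predict \
#     --max_seq_length 512 \
#     --per_device_train_batch_size 8 \
#     --learning_rate 2e-5 --num_train_epochs 3 \
#     --gpt2_append_eos_tok 1 \
#     --output_dir runs/seqcls_gpt2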
22 | 23 | import logging 24 | import os 25 | import random 26 | import sys 27 | from dataclasses import dataclass, field 28 | from typing import Optional 29 | 30 | import datasets 31 | import numpy as np 32 | from datasets import load_dataset, load_metric 33 | 34 | import torch 35 | import transformers 36 | from transformers import ( 37 | AutoConfig, 38 | AutoModelForSequenceClassification, 39 | AutoTokenizer, 40 | DataCollatorWithPadding, 41 | EvalPrediction, 42 | HfArgumentParser, 43 | PretrainedConfig, 44 | Trainer, 45 | TrainingArguments, 46 | default_data_collator, 47 | set_seed, 48 | ) 49 | from transformers.trainer_utils import get_last_checkpoint 50 | from transformers.utils import check_min_version 51 | from transformers.utils.versions import require_version 52 | 53 | sys.path.insert(0, '..') 54 | from utils.custom_modeling_gpt2 import GPT2ForSequenceClassification 55 | from utils.custom_modeling_gpt_neo import GPTNeoForSequenceClassification 56 | 57 | 58 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 59 | check_min_version("4.9.0") 60 | 61 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 62 | 63 | task_to_keys = { 64 | "cola": ("sentence", None), 65 | "mnli": ("premise", "hypothesis"), 66 | "mrpc": ("sentence1", "sentence2"), 67 | "qnli": ("question", "sentence"), 68 | "qqp": ("question1", "question2"), 69 | "rte": ("sentence1", "sentence2"), 70 | "sst2": ("sentence", None), 71 | "stsb": ("sentence1", "sentence2"), 72 | "wnli": ("sentence1", "sentence2"), 73 | } 74 | 75 | logger = logging.getLogger(__name__) 76 | 77 | 78 | @dataclass 79 | class DataTrainingArguments: 80 | """ 81 | Arguments pertaining to what data we are going to input our model for training and eval. 82 | Using `HfArgumentParser` we can turn this class 83 | into argparse arguments to be able to specify them on 84 | the command line. 85 | """ 86 | 87 | task_name: Optional[str] = field( 88 | default=None, 89 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 90 | ) 91 | metric_name: Optional[str] = field( 92 | default=None, 93 | metadata={"help": "The name of the metric"}, 94 | ) 95 | dataset_name: Optional[str] = field( 96 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 97 | ) 98 | dataset_config_name: Optional[str] = field( 99 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 100 | ) 101 | max_seq_length: int = field( 102 | default=128, 103 | metadata={ 104 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 105 | "than this will be truncated, sequences shorter will be padded." 106 | }, 107 | ) 108 | overwrite_cache: bool = field( 109 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 110 | ) 111 | preprocessing_num_workers: Optional[int] = field( 112 | default=None, 113 | metadata={"help": "The number of processes to use for the preprocessing."}, 114 | ) 115 | 116 | pad_to_max_length: bool = field( 117 | default=True, 118 | metadata={ 119 | "help": "Whether to pad all samples to `max_seq_length`. " 120 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 
121 | }, 122 | ) 123 | max_train_samples: Optional[int] = field( 124 | default=None, 125 | metadata={ 126 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 127 | "value if set." 128 | }, 129 | ) 130 | max_eval_samples: Optional[int] = field( 131 | default=None, 132 | metadata={ 133 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 134 | "value if set." 135 | }, 136 | ) 137 | max_predict_samples: Optional[int] = field( 138 | default=None, 139 | metadata={ 140 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 141 | "value if set." 142 | }, 143 | ) 144 | train_file: Optional[str] = field( 145 | default=None, metadata={"help": "A csv or a json file containing the training data."} 146 | ) 147 | validation_file: Optional[str] = field( 148 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 149 | ) 150 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 151 | 152 | gpt2_append_eos_tok: int = field( 153 | default=0, metadata={"help": "Append EOS token after input sequence or not"} 154 | ) 155 | 156 | def __post_init__(self): 157 | if self.task_name is not None: 158 | self.task_name = self.task_name.lower() 159 | if self.task_name not in task_to_keys.keys(): 160 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) 161 | elif self.dataset_name is not None: 162 | pass 163 | elif self.train_file is None or self.validation_file is None: 164 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 165 | else: 166 | train_extension = self.train_file.split(".")[-1] 167 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 168 | validation_extension = self.validation_file.split(".")[-1] 169 | assert ( 170 | validation_extension == train_extension 171 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 172 | 173 | 174 | @dataclass 175 | class ModelArguments: 176 | """ 177 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 178 | """ 179 | 180 | model_name_or_path: str = field( 181 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 182 | ) 183 | config_name: Optional[str] = field( 184 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 185 | ) 186 | tokenizer_name: Optional[str] = field( 187 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 188 | ) 189 | cache_dir: Optional[str] = field( 190 | default=None, 191 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 192 | ) 193 | use_fast_tokenizer: bool = field( 194 | default=True, 195 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 196 | ) 197 | model_revision: str = field( 198 | default="main", 199 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 200 | ) 201 | use_auth_token: bool = field( 202 | default=False, 203 | metadata={ 204 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 205 | "with private models)." 
206 | }, 207 | ) 208 | use_flash: bool = field( 209 | default=False, metadata={"help": "Use flash attention."} 210 | ) 211 | 212 | 213 | def main(): 214 | # See all possible arguments in src/transformers/training_args.py 215 | # or by passing the --help flag to this script. 216 | # We now keep distinct sets of args, for a cleaner separation of concerns. 217 | 218 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 219 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 220 | # If we pass only one argument to the script and it's the path to a json file, 221 | # let's parse it to get our arguments. 222 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 223 | else: 224 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 225 | 226 | # Setup logging 227 | logging.basicConfig( 228 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 229 | datefmt="%m/%d/%Y %H:%M:%S", 230 | handlers=[logging.StreamHandler(sys.stdout)], 231 | ) 232 | 233 | log_level = training_args.get_process_log_level() 234 | logger.setLevel(log_level) 235 | datasets.utils.logging.set_verbosity(log_level) 236 | transformers.utils.logging.set_verbosity(log_level) 237 | transformers.utils.logging.enable_default_handler() 238 | transformers.utils.logging.enable_explicit_format() 239 | 240 | # Log on each process the small summary: 241 | logger.warning( 242 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 243 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 244 | ) 245 | logger.info(f"Training/evaluation parameters {training_args}") 246 | 247 | # Detecting last checkpoint. 248 | last_checkpoint = None 249 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 250 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 251 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 252 | raise ValueError( 253 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 254 | "Use --overwrite_output_dir to overcome." 255 | ) 256 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 257 | logger.info( 258 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 259 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 260 | ) 261 | 262 | # Set seed before initializing model. 263 | set_seed(training_args.seed) 264 | 265 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 266 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 267 | # 268 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 269 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 270 | # label if at least two columns are provided. 271 | # 272 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 273 | # single column. You can easily tweak this behavior (see below) 274 | # 275 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 276 | # download the dataset. 
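# Editor's note: for the local-file branch below, each line of the JSON files is one
# example; an illustrative row (all values are placeholders):
#   {"sentence1": "first text ...", "sentence2": "second text ...", "label": "yes"}
# String labels are sorted and mapped to ids further down; a float "label" column switches
# the script into regression mode.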
277 | if data_args.task_name is not None: 278 | # Downloading and loading a dataset from the hub. 279 | raw_datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir) 280 | elif data_args.dataset_name is not None: 281 | # Downloading and loading a dataset from the hub. 282 | raw_datasets = load_dataset( 283 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 284 | ) 285 | else: 286 | # Loading a dataset from your local files. 287 | # CSV/JSON training and evaluation files are needed. 288 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 289 | 290 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 291 | # when you use `do_predict` without specifying a GLUE benchmark task. 292 | if training_args.do_predict: 293 | if data_args.test_file is not None: 294 | train_extension = data_args.train_file.split(".")[-1] 295 | test_extension = data_args.test_file.split(".")[-1] 296 | assert ( 297 | test_extension == train_extension 298 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 299 | data_files["test"] = data_args.test_file 300 | else: 301 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 302 | 303 | for key in data_files.keys(): 304 | logger.info(f"load a local file for {key}: {data_files[key]}") 305 | 306 | if data_args.train_file.endswith(".csv"): 307 | # Loading a dataset from local csv files 308 | raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir) 309 | else: 310 | # Loading a dataset from local json files 311 | raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir) 312 | # See more about loading any type of standard or custom dataset at 313 | # https://huggingface.co/docs/datasets/loading_datasets.html. 314 | 315 | # Labels 316 | if data_args.task_name is not None: 317 | is_regression = data_args.task_name == "stsb" 318 | if not is_regression: 319 | label_list = raw_datasets["train"].features["label"].names 320 | num_labels = len(label_list) 321 | else: 322 | num_labels = 1 323 | else: 324 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 325 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 326 | if is_regression: 327 | print ('is_regression', is_regression) 328 | num_labels = 1 329 | else: 330 | # A useful fast method: 331 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 332 | label_list = raw_datasets["train"].unique("label") 333 | label_list.sort() # Let's sort it for determinism 334 | print ('\nlabel_list', label_list) 335 | num_labels = len(label_list) 336 | 337 | # Load pretrained model and tokenizer 338 | # 339 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 340 | # download model & vocab. 
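# Editor's note: the block below wires in the custom heads from utils/ -- a "gpt2"
# config.model_type maps to GPT2ForSequenceClassification (optionally backed by
# GPT2FlashModel when --use_flash is set), "gpt_neo" maps to
# GPTNeoForSequenceClassification, and anything else falls back to
# AutoModelForSequenceClassification. Because GPT-2/GPT-Neo tokenizers ship without a pad
# token, "[PAD]" and a few marker tokens are added and the embedding matrix is resized.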
341 | config = AutoConfig.from_pretrained( 342 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 343 | num_labels=num_labels, 344 | finetuning_task=data_args.task_name, 345 | cache_dir=model_args.cache_dir, 346 | revision=model_args.model_revision, 347 | use_auth_token=True if model_args.use_auth_token else None, 348 | ) 349 | config.use_flash = model_args.use_flash 350 | tokenizer = AutoTokenizer.from_pretrained( 351 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 352 | cache_dir=model_args.cache_dir, 353 | use_fast=model_args.use_fast_tokenizer, 354 | revision=model_args.model_revision, 355 | use_auth_token=True if model_args.use_auth_token else None, 356 | ) 357 | if config.model_type == "gpt2": 358 | model_class = GPT2ForSequenceClassification 359 | elif config.model_type == "gpt_neo": 360 | model_class = GPTNeoForSequenceClassification 361 | else: 362 | model_class = AutoModelForSequenceClassification 363 | model = model_class.from_pretrained( 364 | model_args.model_name_or_path, 365 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 366 | config=config, 367 | cache_dir=model_args.cache_dir, 368 | revision=model_args.model_revision, 369 | use_auth_token=True if model_args.use_auth_token else None, 370 | ) 371 | #Added for GPT 372 | if tokenizer.pad_token_id is None: 373 | print('Adding [PAD] token to tokenizer and model word embeddings.') 374 | num_added_tokens = tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 375 | tokenizer.add_tokens(["<|CONTEXT|>", "<|QUESTION1|>", "<|QUESTION2|>", "<|ANSWER|>"]) 376 | embedding_layer = model.resize_token_embeddings(len(tokenizer)) 377 | config.pad_token_id = tokenizer.pad_token_id 378 | 379 | # Preprocessing the raw_datasets 380 | if data_args.task_name is not None: 381 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 382 | else: 383 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 384 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 385 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 386 | sentence1_key, sentence2_key = "sentence1", "sentence2" 387 | elif "sentence" in non_label_column_names: 388 | sentence1_key, sentence2_key = "sentence", None 389 | else: 390 | if len(non_label_column_names) >= 2: 391 | sentence1_key, sentence2_key = non_label_column_names[:2] 392 | else: 393 | sentence1_key, sentence2_key = non_label_column_names[0], None 394 | 395 | # Padding strategy 396 | if data_args.pad_to_max_length: 397 | padding = "max_length" 398 | else: 399 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 400 | padding = False 401 | 402 | # Some models have set the order of the labels to use, so let's make sure we do use it. 403 | label_to_id = None 404 | if ( 405 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 406 | and data_args.task_name is not None 407 | and not is_regression 408 | ): 409 | # Some have all caps in their config, some don't. 
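# Editor's note: the block below reconciles any label2id already stored in the model
# config with the dataset's labels; for a local dataset (no GLUE task) the mapping is
# simply enumeration order. Illustrative result, assuming label_list = ["no", "yes"]:
#   label_to_id           = {"no": 0, "yes": 1}
#   model.config.id2label = {0: "no", 1: "yes"}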
410 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 411 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 412 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} 413 | else: 414 | logger.warning( 415 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 416 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 417 | "\nIgnoring the model labels as a result.", 418 | ) 419 | elif data_args.task_name is None and not is_regression: 420 | label_to_id = {v: i for i, v in enumerate(label_list)} 421 | 422 | if label_to_id is not None: 423 | model.config.label2id = label_to_id 424 | model.config.id2label = {id: label for label, id in config.label2id.items()} 425 | 426 | if data_args.max_seq_length > tokenizer.model_max_length: 427 | logger.warning( 428 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 429 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 430 | ) 431 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 432 | 433 | #def modify_sentence1(text): 434 | #return "<|CONTEXT|>" + text 435 | 436 | #def modify_sentence2(text): 437 | #return "<|QUESTION|>" + text + "<|ANSWER|>" 438 | 439 | def preprocess_function(examples): 440 | 441 | # Tokenize the texts 442 | contexts = examples[sentence2_key] 443 | questions = examples[sentence1_key] 444 | 445 | args = ( 446 | (examples[sentence1_key],) if sentence2_key is None else (contexts, questions) 447 | ) 448 | 449 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 450 | 451 | #Added for GPT2 452 | if config.model_type in ["gpt2"] and data_args.gpt2_append_eos_tok: 453 | assert padding == "max_length" 454 | assert sorted(result.keys()) == sorted(["input_ids", "attention_mask"]) 455 | input_ids = torch.tensor(result["input_ids"]) 456 | attention_mask = torch.tensor(result["attention_mask"]) 457 | sequence_lengths = torch.clamp(input_ids.ne(tokenizer.pad_token_id).sum(-1), max=max_seq_length-1) 458 | input_ids[range(len(input_ids)), sequence_lengths] = tokenizer.eos_token_id 459 | attention_mask[range(len(input_ids)), sequence_lengths] = 1 460 | result["input_ids"] = input_ids.tolist() 461 | result["attention_mask"] = attention_mask.tolist() 462 | 463 | # Map labels to IDs (not necessary for GLUE tasks) 464 | if label_to_id is not None and "label" in examples: 465 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 466 | return result 467 | 468 | with training_args.main_process_first(desc="dataset map pre-processing"): 469 | raw_datasets = raw_datasets.map( 470 | preprocess_function, 471 | batched=True, 472 | num_proc=data_args.preprocessing_num_workers, 473 | load_from_cache_file=not data_args.overwrite_cache, 474 | desc="Running tokenizer on dataset", 475 | ) 476 | if training_args.do_train: 477 | if "train" not in raw_datasets: 478 | raise ValueError("--do_train requires a train dataset") 479 | train_dataset = raw_datasets["train"] 480 | if data_args.max_train_samples is not None: 481 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 482 | 483 | if training_args.do_eval: 484 | if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: 485 | raise ValueError("--do_eval requires a validation dataset") 486 | eval_dataset = 
raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] 487 | if data_args.max_eval_samples is not None: 488 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 489 | 490 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: 491 | if "test" not in raw_datasets and "test_matched" not in raw_datasets: 492 | raise ValueError("--do_predict requires a test dataset") 493 | predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"] 494 | if data_args.max_predict_samples is not None: 495 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 496 | 497 | # Log a few random samples from the training set: 498 | # if training_args.do_train: 499 | # for index in random.sample(range(len(train_dataset)), 3): 500 | # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 501 | 502 | 503 | 504 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 505 | # predictions and label_ids field) and has to return a dictionary string to float. 506 | def compute_metrics(p: EvalPrediction): 507 | # Get the metric function 508 | if data_args.task_name is not None: 509 | metric = load_metric("glue", data_args.task_name) 510 | else: 511 | metric = load_metric("accuracy") 512 | 513 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 514 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 515 | if data_args.task_name is not None: 516 | result = metric.compute(predictions=preds, references=p.label_ids) 517 | if len(result) > 1: 518 | result["combined_score"] = np.mean(list(result.values())).item() 519 | return result 520 | elif data_args.metric_name == "pearsonr": 521 | from scipy.stats import pearsonr as scipy_pearsonr 522 | pearsonr = float(scipy_pearsonr(p.label_ids, preds)[0]) 523 | return {"pearsonr": pearsonr} 524 | elif data_args.metric_name == "PRF1": 525 | TP = ((preds == p.label_ids) & (preds != 0)).astype(int).sum().item() 526 | P_total = (preds != 0).astype(int).sum().item() 527 | L_total = (p.label_ids != 0).astype(int).sum().item() 528 | P = TP / P_total if P_total else 0 529 | R = TP / L_total if L_total else 0 530 | F1 = 2 * P*R/(P+R) if (P+R) else 0 531 | return {"precision": P, "recall": R, "F1": F1} 532 | elif is_regression: 533 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 534 | else: 535 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 536 | 537 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 
538 | if data_args.pad_to_max_length: 539 | data_collator = default_data_collator 540 | elif training_args.fp16: 541 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 542 | else: 543 | data_collator = None 544 | 545 | # Initialize our Trainer 546 | trainer = Trainer( 547 | model=model, 548 | args=training_args, 549 | train_dataset=train_dataset if training_args.do_train else None, 550 | eval_dataset=eval_dataset if training_args.do_eval else None, 551 | compute_metrics=compute_metrics, 552 | tokenizer=tokenizer, 553 | data_collator=data_collator, 554 | ) 555 | 556 | # Training 557 | if training_args.do_train: 558 | checkpoint = None 559 | if training_args.resume_from_checkpoint is not None: 560 | checkpoint = training_args.resume_from_checkpoint 561 | elif last_checkpoint is not None: 562 | checkpoint = last_checkpoint 563 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 564 | metrics = train_result.metrics 565 | max_train_samples = ( 566 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 567 | ) 568 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 569 | 570 | #trainer.save_model() # Saves the tokenizer too for easy upload 571 | 572 | trainer.log_metrics("train", metrics) 573 | trainer.save_metrics("train", metrics) 574 | trainer.save_state() 575 | 576 | # Evaluation 577 | if training_args.do_eval: 578 | logger.info("*** Evaluate ***") 579 | 580 | # Loop to handle MNLI double evaluation (matched, mis-matched) 581 | tasks = [data_args.task_name] 582 | eval_datasets = [eval_dataset] 583 | if data_args.task_name == "mnli": 584 | tasks.append("mnli-mm") 585 | eval_datasets.append(raw_datasets["validation_mismatched"]) 586 | 587 | for eval_dataset, task in zip(eval_datasets, tasks): 588 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 589 | 590 | max_eval_samples = ( 591 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 592 | ) 593 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 594 | 595 | trainer.log_metrics("eval", metrics) 596 | trainer.save_metrics("eval", metrics) 597 | 598 | if training_args.do_predict: 599 | logger.info("*** Predict ***") 600 | 601 | # Loop to handle MNLI double evaluation (matched, mis-matched) 602 | tasks = [data_args.task_name] 603 | predict_datasets = [predict_dataset] 604 | if data_args.task_name == "mnli": 605 | tasks.append("mnli-mm") 606 | predict_datasets.append(raw_datasets["test_mismatched"]) 607 | 608 | for predict_dataset, task in zip(predict_datasets, tasks): 609 | metrics = trainer.evaluate(eval_dataset=predict_dataset, metric_key_prefix="test") 610 | 611 | max_test_samples = ( 612 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(predict_dataset) 613 | ) 614 | metrics["test_samples"] = min(max_test_samples, len(predict_dataset)) 615 | 616 | trainer.log_metrics("test", metrics) 617 | trainer.save_metrics("test", metrics) 618 | trainer.log(metrics) 619 | 620 | 621 | if training_args.push_to_hub: 622 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"} 623 | if data_args.task_name is not None: 624 | kwargs["language"] = "en" 625 | kwargs["dataset_tags"] = "glue" 626 | kwargs["dataset_args"] = data_args.task_name 627 | kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}" 628 | 629 | trainer.push_to_hub(**kwargs) 630 | 631 | 632 | def _mp_fn(index): 633 | # For xla_spawn (TPUs) 634 | main() 635 | 636 | 637 
| if __name__ == "__main__": 638 | main() 639 | -------------------------------------------------------------------------------- /finetune/utils/custom_modeling_gpt_neo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch GPT Neo model. Adapted from transformers==4.9.0 """ 16 | 17 | 18 | import os 19 | from typing import Tuple 20 | 21 | import torch 22 | import torch.utils.checkpoint 23 | from torch import nn 24 | from torch.nn import CrossEntropyLoss, MSELoss 25 | 26 | from transformers.activations import ACT2FN 27 | from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward 28 | from transformers.modeling_outputs import ( 29 | BaseModelOutputWithPast, 30 | BaseModelOutputWithPastAndCrossAttentions, 31 | CausalLMOutputWithCrossAttentions, 32 | CausalLMOutputWithPast, 33 | SequenceClassifierOutputWithPast, 34 | ) 35 | from transformers.modeling_utils import PreTrainedModel 36 | from transformers.utils import logging 37 | from transformers.models.gpt_neo.configuration_gpt_neo import GPTNeoConfig 38 | 39 | 40 | logger = logging.get_logger(__name__) 41 | 42 | _CONFIG_FOR_DOC = "GPTNeoConfig" 43 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer" 44 | 45 | GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = [ 46 | "EleutherAI/gpt-neo-1.3B", 47 | # See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo 48 | ] 49 | 50 | _CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B" 51 | 52 | 53 | def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): 54 | """Load tf checkpoints in a pytorch model""" 55 | try: 56 | import re 57 | 58 | import tensorflow as tf 59 | except ImportError: 60 | logger.error( 61 | "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see " 62 | "https://www.tensorflow.org/install/ for installation instructions." 
63 | ) 64 | raise 65 | tf_path = os.path.abspath(gpt_neo_checkpoint_path) 66 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 67 | # Load weights from TF model 68 | init_vars = tf.train.list_variables(tf_path) 69 | names = [] 70 | arrays = [] 71 | for name, shape in init_vars: 72 | if "global_step" not in name and "adam" not in name: 73 | array = tf.train.load_variable(tf_path, name) 74 | array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy() 75 | name = name.replace("attn/q", "attn/attention/q_proj/w") 76 | name = name.replace("attn/k", "attn/attention/k_proj/w") 77 | name = name.replace("attn/v", "attn/attention/v_proj/w") 78 | name = name.replace("attn/o", "attn/attention/out_proj/w") 79 | name = name.replace("norm_1", "ln_1") 80 | name = name.replace("norm_2", "ln_2") 81 | name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b") 82 | name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w") 83 | name = name.replace("conv1d_main/c_fc/bias", "c_fc/b") 84 | name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w") 85 | name = name.replace("conv1d_main/c_proj/bias", "c_proj/b") 86 | 87 | names.append(name) 88 | arrays.append(array) 89 | 90 | for name, array in zip(names, arrays): 91 | name = name[5:] # skip "gpt2/" 92 | name = name.split("/") 93 | pointer = model.transformer 94 | for m_name in name: 95 | if re.fullmatch(r"[A-Za-z]+\d+", m_name): 96 | scope_names = re.split(r"(\d+)", m_name) 97 | else: 98 | scope_names = [m_name] 99 | if scope_names[0] == "w" or scope_names[0] == "g": 100 | pointer = getattr(pointer, "weight") 101 | elif scope_names[0] == "b": 102 | pointer = getattr(pointer, "bias") 103 | elif scope_names[0] == "wpe" or scope_names[0] == "wte": 104 | pointer = getattr(pointer, scope_names[0]) 105 | pointer = getattr(pointer, "weight") 106 | else: 107 | pointer = getattr(pointer, scope_names[0]) 108 | if len(scope_names) >= 2: 109 | num = int(scope_names[1]) 110 | pointer = pointer[num] 111 | 112 | if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]: 113 | array = array.transpose() 114 | 115 | if name == ["wte"]: 116 | # if vocab is padded, then trim off the padding embeddings 117 | array = array[: config.vocab_size] 118 | 119 | try: 120 | assert ( 121 | pointer.shape == array.shape 122 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}" 123 | except AssertionError as e: 124 | e.args += (pointer.shape, array.shape) 125 | raise 126 | print(f"Initialize PyTorch weight {name}") 127 | pointer.data = torch.from_numpy(array) 128 | 129 | # init the final linear layer using word embeddings 130 | embs = model.transformer.wte.weight 131 | lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False) 132 | lin.weight = embs 133 | model.set_output_embeddings(lin) 134 | return model 135 | 136 | 137 | class GPTNeoAttentionMixin: 138 | """ 139 | A few attention related utilities for attention modules in GPT Neo, to be used as a mixin. 140 | """ 141 | 142 | @staticmethod 143 | def _get_block_length_and_num_blocks(seq_length, window_size): 144 | """ 145 | Computes ``block_length`` and ``num_blocks`` such that ``seq_length`` becomes evenly divisible by 146 | ``block_length``. 
147 | """ 148 | block_length = window_size 149 | while seq_length % block_length != 0: 150 | block_length -= 1 151 | num_blocks = seq_length // block_length 152 | return block_length, num_blocks 153 | 154 | @staticmethod 155 | def _look_back(tensor, block_length, window_size, pad_value=0, is_key_value=True): 156 | """ 157 | Used to implement attention between consecutive blocks. This method assumes that dim 1 of :obj:`tensor` 158 | represents the :obj:`seq_length` dimension. It splits :obj:`seq_length` dimension into :obj:`num_blocks` and 159 | :obj:`window_size` + :obj:`block_length`. It pads the :obj:`seq_length` dimension if necessary. 160 | 161 | Example:: 162 | 163 | tensor: torch.tensor([[[ 0.4983], [ 2.6918], [-0.0071], [ 1.0492], [-1.8348], [ 0.7672], [ 0.2986], [ 0.0285]]]) 164 | with shape (1, 8, 1) 165 | block_length = window_size = 4 166 | _look_back => 167 | torch.tensor([[[[ 0.0000], [ 0.0000], [ 0.0000], [ 0.0000], [ 0.4983], [ 2.6918], [-0.0071], [ 1.0492]], 168 | [[ 0.4983], [ 2.6918], [-0.0071], [ 1.0492], [-1.8348], [ 0.7672], [ 0.2986], [ 0.0285]]]]) 169 | 170 | Args: 171 | tensor (:obj:`torch.Tensor`): tensor of shape :obj:`[batch_size, seq_length, hidden_dim]` or :obj:`[batch_size, seq_length]` 172 | block_length (:obj:`int`): An integer specifying the length of each block, used as a step size when creating the blocks. 173 | window_size (:obj:`int`): An integer specifying the size of attention window, used to calculate the final block size when creating the block. 174 | pad_value (obj:`int`): An integer specifying the value to use when padding the :obj:`tensor`. 175 | is_key_value (:obj:`bool`): A boolean indicating if the :obj:`tensor` is a key/value tensor. 176 | 177 | Returns: 178 | tensor of shape :obj:`[batch_size, num_blocks, window_size + block_length, ...]` if :obj:`is_key_value` is 179 | :obj:`True` else a tensor of shape :obj:`[batch_size, window_size + block_length, num_blocks, ...]` 180 | """ 181 | if len(tensor.shape) == 3: 182 | padding_side = (0, 0, window_size, 0) 183 | elif len(tensor.shape) == 2: 184 | padding_side = (window_size, 0) 185 | else: 186 | raise ValueError(f"Input tensor rank should be one of [2, 3], but is: {len(tensor.shape)}") 187 | 188 | padded_tensor = nn.functional.pad(tensor, padding_side, value=pad_value) 189 | padded_tensor = padded_tensor.unfold(dimension=1, size=window_size + block_length, step=block_length) 190 | 191 | if is_key_value: 192 | padded_tensor = padded_tensor.transpose(-2, -1) 193 | return padded_tensor 194 | 195 | @staticmethod 196 | def _split_seq_length_dim_to(tensors, dim_factor_1, dim_factor_2): 197 | """ 198 | Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims 199 | """ 200 | batch_size = tensors.shape[0] 201 | split_dim_shape = (batch_size, dim_factor_1, dim_factor_2) 202 | 203 | if len(tensors.shape) == 3: 204 | return torch.reshape(tensors, split_dim_shape + (-1,)) 205 | elif len(tensors.shape) == 2: 206 | return torch.reshape(tensors, split_dim_shape) 207 | else: 208 | raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}") 209 | 210 | @staticmethod 211 | def create_local_attention_mask(batch_size, seq_length, window_size, device, attention_mask=None): 212 | block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size) 213 | indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1) 214 | 215 | query_indices = GPTNeoAttentionMixin._split_seq_length_dim_to(indices, 
num_blocks, block_length) 216 | key_indices = GPTNeoAttentionMixin._look_back(indices, block_length, window_size, is_key_value=False) 217 | 218 | # create mask tensor such that each block contains a causal_mask for that block 219 | causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)) 220 | 221 | if attention_mask is None: 222 | attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device) 223 | 224 | # A block can also be padded because of the _look_back operation 225 | # look back into the attention_block such that it will also get padded the same way 226 | # and have 0s in the padded position 227 | attention_mask = GPTNeoAttentionMixin._look_back(attention_mask, block_length, window_size, is_key_value=False) 228 | attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim 229 | 230 | # Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation) 231 | # will contain 0s. 232 | # This also makes sure that other positions ignored by the attention_mask will also be ignored 233 | # in the causal_mask. 234 | causal_mask = causal_mask * attention_mask 235 | 236 | # In GPT Neo's local attention each window can attend to at most window_size tokens 237 | # rest of the tokens should be ignored. 238 | relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1) 239 | visible = torch.gt(relative_position, -window_size) 240 | 241 | causal_mask = causal_mask * visible 242 | causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads 243 | 244 | return causal_mask 245 | 246 | def _split_heads(self, tensor, num_heads, attn_head_size): 247 | """ 248 | Splits hidden_size dim into attn_head_size and num_heads 249 | """ 250 | new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) 251 | tensor = tensor.view(*new_shape) 252 | if len(tensor.shape) == 5: 253 | return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features) 254 | elif len(tensor.shape) == 4: 255 | return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 256 | else: 257 | raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") 258 | 259 | def _merge_heads(self, tensor, num_heads, attn_head_size): 260 | """ 261 | Merges attn_head_size dim and num_attn_heads dim into hidden_size 262 | """ 263 | if len(tensor.shape) == 5: 264 | tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() 265 | elif len(tensor.shape) == 4: 266 | tensor = tensor.permute(0, 2, 1, 3).contiguous() 267 | else: 268 | raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") 269 | new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) 270 | return tensor.view(new_shape) 271 | 272 | def _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask=None, head_mask=None): 273 | # Keep the attention weights computation in fp32 to avoid overflow issues 274 | query = query.to(torch.float32) 275 | key = key.to(torch.float32) 276 | 277 | with torch.cuda.amp.autocast(enabled=False): 278 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 279 | attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype)) 280 | 281 | if attention_mask is not None: 282 | # Apply the attention mask 283 | attn_weights = attn_weights + attention_mask 284 | 285 | attn_weights = nn.Softmax(dim=-1)(attn_weights) 286 | attn_weights = attn_weights.to(value.dtype) 
287 | attn_weights = attn_dropout(attn_weights) 288 | 289 | # Mask heads if we want to 290 | if head_mask is not None: 291 | attn_weights = attn_weights * head_mask 292 | 293 | attn_output = torch.matmul(attn_weights, value) 294 | 295 | return attn_output, attn_weights 296 | 297 | 298 | class GPTNeoSelfAttention(nn.Module, GPTNeoAttentionMixin): 299 | def __init__(self, config): 300 | super().__init__() 301 | 302 | max_positions = config.max_position_embeddings 303 | self.register_buffer( 304 | "bias", 305 | torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( 306 | 1, 1, max_positions, max_positions 307 | ), 308 | ) 309 | self.register_buffer("masked_bias", torch.tensor(-1e9)) 310 | 311 | self.attn_dropout = nn.Dropout(config.attention_dropout) 312 | self.resid_dropout = nn.Dropout(config.resid_dropout) 313 | 314 | self.embed_dim = config.hidden_size 315 | self.num_heads = config.num_heads 316 | self.head_dim = self.embed_dim // self.num_heads 317 | if self.head_dim * self.num_heads != self.embed_dim: 318 | raise ValueError( 319 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." 320 | ) 321 | 322 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) 323 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) 324 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) 325 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) 326 | 327 | def forward( 328 | self, 329 | hidden_states, 330 | attention_mask=None, 331 | layer_past=None, 332 | head_mask=None, 333 | use_cache=False, 334 | output_attentions=False, 335 | ): 336 | 337 | query = self.q_proj(hidden_states) 338 | key = self.k_proj(hidden_states) 339 | value = self.v_proj(hidden_states) 340 | 341 | query = self._split_heads(query, self.num_heads, self.head_dim) 342 | key = self._split_heads(key, self.num_heads, self.head_dim) 343 | value = self._split_heads(value, self.num_heads, self.head_dim) 344 | 345 | if layer_past is not None: 346 | past_key = layer_past[0] 347 | past_value = layer_past[1] 348 | key = torch.cat((past_key, key), dim=-2) 349 | value = torch.cat((past_value, value), dim=-2) 350 | 351 | if use_cache is True: 352 | present = (key, value) 353 | else: 354 | present = None 355 | 356 | query_length, key_length = query.size(-2), key.size(-2) 357 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() 358 | 359 | attn_output, attn_weights = self._attn( 360 | query, key, value, causal_mask, self.masked_bias, self.attn_dropout, attention_mask, head_mask 361 | ) 362 | 363 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 364 | attn_output = self.out_proj(attn_output) 365 | attn_output = self.resid_dropout(attn_output) 366 | 367 | outputs = (attn_output, present) 368 | if output_attentions: 369 | outputs += (attn_weights,) 370 | 371 | return outputs # a, present, (attentions) 372 | 373 | 374 | class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin): 375 | def __init__(self, config): 376 | super().__init__() 377 | 378 | self.register_buffer("masked_bias", torch.tensor(-1e9)) 379 | 380 | self.attn_dropout = nn.Dropout(config.attention_dropout) 381 | self.resid_dropout = nn.Dropout(config.resid_dropout) 382 | 383 | self.embed_dim = config.hidden_size 384 | self.num_heads = config.num_heads 385 | self.head_dim = self.embed_dim // self.num_heads 386 | if self.head_dim * self.num_heads != self.embed_dim: 387 | 
raise ValueError( 388 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." 389 | ) 390 | 391 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) 392 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) 393 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) 394 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) 395 | 396 | self.window_size = config.window_size 397 | 398 | def forward( 399 | self, 400 | hidden_states, 401 | attention_mask, 402 | layer_past=None, 403 | head_mask=None, 404 | use_cache=False, 405 | output_attentions=False, 406 | ): 407 | query = self.q_proj(hidden_states) 408 | 409 | if layer_past is not None: 410 | past = layer_past[0] 411 | key_value_hidden_states = torch.cat([past, hidden_states], dim=1) 412 | past_length = past.size()[1] 413 | else: 414 | key_value_hidden_states = hidden_states 415 | past_length = 0 416 | 417 | key = self.k_proj(key_value_hidden_states) 418 | value = self.v_proj(key_value_hidden_states) 419 | 420 | # compute block length and num_blocks 421 | batch_size, seq_length = hidden_states.shape[:2] 422 | full_seq_length = seq_length + past_length 423 | block_length, num_blocks = self._get_block_length_and_num_blocks(full_seq_length, self.window_size) 424 | 425 | # create buckets 426 | if layer_past is not None: 427 | # we just need 1 block with block_length 1 when caching is enabled 428 | query = self._split_seq_length_dim_to(query, 1, 1) 429 | else: 430 | query = self._split_seq_length_dim_to(query, num_blocks, block_length) 431 | 432 | key = self._look_back(key, block_length, self.window_size) 433 | value = self._look_back(value, block_length, self.window_size) 434 | 435 | # select key/value vectors only for the last block 436 | if layer_past is not None: 437 | key = key[:, -1:, ...] 438 | value = value[:, -1:, ...] 
439 | 440 | query = self._split_heads(query, self.num_heads, self.head_dim) 441 | key = self._split_heads(key, self.num_heads, self.head_dim) 442 | value = self._split_heads(value, self.num_heads, self.head_dim) 443 | 444 | if layer_past is not None: 445 | # only take the mask for the last block 446 | attention_mask = attention_mask[:, -1:, :, -1:, :] 447 | 448 | # attn 449 | attn_output, attn_weights = self._attn( 450 | query, 451 | key, 452 | value, 453 | causal_mask=attention_mask, 454 | masked_bias=self.masked_bias, 455 | attn_dropout=self.attn_dropout, 456 | head_mask=head_mask, 457 | ) 458 | 459 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 460 | attn_output = attn_output.reshape(batch_size, seq_length, self.embed_dim) 461 | 462 | attn_output = self.out_proj(attn_output) 463 | attn_output = self.resid_dropout(attn_output) 464 | 465 | outputs = (attn_output,) 466 | if output_attentions: 467 | outputs += (attn_weights,) 468 | 469 | return outputs # a, (attentions) 470 | 471 | 472 | class GPTNeoAttention(nn.Module): 473 | def __init__(self, config, layer_id=0): 474 | super().__init__() 475 | self.layer_id = layer_id 476 | self.attention_layers = config.attention_layers 477 | self.attention_type = self.attention_layers[layer_id] 478 | 479 | if self.attention_type == "global": 480 | self.attention = GPTNeoSelfAttention(config) 481 | elif self.attention_type == "local": 482 | self.attention = GPTNeoLocalSelfAttention(config) 483 | else: 484 | raise NotImplementedError( 485 | "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: " 486 | f"{config.attention_layers}. Select attn layer types from ['global', 'local'] only." 487 | ) 488 | 489 | def forward( 490 | self, 491 | hidden_states, 492 | layer_past=None, 493 | attention_mask=None, 494 | head_mask=None, 495 | use_cache=False, 496 | output_attentions=False, 497 | ): 498 | outputs = self.attention( 499 | hidden_states, 500 | attention_mask=attention_mask, 501 | layer_past=layer_past, 502 | head_mask=head_mask, 503 | use_cache=use_cache, 504 | output_attentions=output_attentions, 505 | ) 506 | 507 | # cache the hidden_states instead of key_value_states 508 | # for local attention layer 509 | if self.attention_type == "local": 510 | if layer_past is None: 511 | past = hidden_states 512 | else: 513 | past = torch.cat([layer_past[0], hidden_states], dim=1) 514 | outputs = (outputs[0], (past,)) + outputs[1:] 515 | return outputs 516 | 517 | 518 | class GPTNeoMLP(nn.Module): 519 | def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size 520 | super().__init__() 521 | embed_dim = config.hidden_size 522 | self.c_fc = nn.Linear(embed_dim, intermediate_size) 523 | self.c_proj = nn.Linear(intermediate_size, embed_dim) 524 | self.act = ACT2FN[config.activation_function] 525 | self.dropout = nn.Dropout(config.resid_dropout) 526 | 527 | def forward(self, hidden_states): 528 | hidden_states = self.c_fc(hidden_states) 529 | hidden_states = self.act(hidden_states) 530 | hidden_states = self.c_proj(hidden_states) 531 | hidden_states = self.dropout(hidden_states) 532 | return hidden_states 533 | 534 | 535 | class GPTNeoBlock(nn.Module): 536 | def __init__(self, config, layer_id): 537 | super().__init__() 538 | hidden_size = config.hidden_size 539 | inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size 540 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 541 | self.attn = GPTNeoAttention(config, 
layer_id) 542 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 543 | self.mlp = GPTNeoMLP(inner_dim, config) 544 | 545 | def forward( 546 | self, 547 | hidden_states, 548 | layer_past=None, 549 | attention_mask=None, 550 | head_mask=None, 551 | use_cache=False, 552 | output_attentions=False, 553 | ): 554 | residual = hidden_states 555 | hidden_states = self.ln_1(hidden_states) 556 | attn_outputs = self.attn( 557 | hidden_states, 558 | layer_past=layer_past, 559 | attention_mask=attention_mask, 560 | head_mask=head_mask, 561 | use_cache=use_cache, 562 | output_attentions=output_attentions, 563 | ) 564 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 565 | outputs = attn_outputs[1:] 566 | # residual connection 567 | hidden_states = attn_output + residual 568 | 569 | residual = hidden_states 570 | hidden_states = self.ln_2(hidden_states) 571 | feed_forward_hidden_states = self.mlp(hidden_states) 572 | # residual connection 573 | hidden_states = residual + feed_forward_hidden_states 574 | 575 | if use_cache: 576 | outputs = (hidden_states,) + outputs 577 | else: 578 | outputs = (hidden_states,) + outputs[1:] 579 | 580 | return outputs # hidden_states, present, (attentions, cross_attentions) 581 | 582 | 583 | class GPTNeoPreTrainedModel(PreTrainedModel): 584 | """ 585 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 586 | models. 587 | """ 588 | 589 | config_class = GPTNeoConfig 590 | load_tf_weights = load_tf_weights_in_gpt_neo 591 | base_model_prefix = "transformer" 592 | 593 | def __init__(self, *inputs, **kwargs): 594 | super().__init__(*inputs, **kwargs) 595 | 596 | def _init_weights(self, module): 597 | """Initialize the weights.""" 598 | if isinstance(module, (nn.Linear,)): 599 | # Slightly different from the TF version which uses truncated_normal for initialization 600 | # cf https://github.com/pytorch/pytorch/pull/5617 601 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 602 | if module.bias is not None: 603 | module.bias.data.zero_() 604 | elif isinstance(module, nn.Embedding): 605 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 606 | if module.padding_idx is not None: 607 | module.weight.data[module.padding_idx].zero_() 608 | elif isinstance(module, nn.LayerNorm): 609 | module.bias.data.zero_() 610 | module.weight.data.fill_(1.0) 611 | 612 | 613 | GPT_NEO_START_DOCSTRING = r""" 614 | 615 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 616 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 617 | pruning heads etc.) 618 | 619 | This model is also a PyTorch `torch.nn.Module `__ 620 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 621 | general usage and behavior. 622 | 623 | Parameters: 624 | config (:class:`~transformers.GPTNeoConfig`): Model configuration class with all the parameters of the model. 625 | Initializing with a config file does not load the weights associated with the model, only the 626 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model 627 | weights. 
628 | """ 629 | 630 | GPT_NEO_INPUTS_DOCSTRING = r""" 631 | Args: 632 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`): 633 | :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else 634 | ``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input 635 | sequence tokens in the vocabulary. 636 | 637 | If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be 638 | passed as ``input_ids``. 639 | 640 | Indices can be obtained using :class:`~transformers.GPTNeoTokenizer`. See 641 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 642 | details. 643 | 644 | `What are input IDs? <../glossary.html#input-ids>`__ 645 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.num_layers`): 646 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see 647 | :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which 648 | have their past given to this model should not be passed as ``input_ids`` as they have already been 649 | computed. 650 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 651 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 652 | 653 | - 1 for tokens that are **not masked**, 654 | - 0 for tokens that are **masked**. 655 | 656 | `What are attention masks? <../glossary.html#attention-mask>`__ 657 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`): 658 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 659 | 1]``: 660 | 661 | - 0 corresponds to a `sentence A` token, 662 | - 1 corresponds to a `sentence B` token. 663 | 664 | `What are token type IDs? <../glossary.html#token-type-ids>`_ 665 | position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 666 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, 667 | config.max_position_embeddings - 1]``. 668 | 669 | `What are position IDs? <../glossary.html#position-ids>`_ 670 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): 671 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: 672 | 673 | - 1 indicates the head is **not masked**, 674 | - 0 indicates the head is **masked**. 675 | 676 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 677 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 678 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 679 | vectors than the model's internal embedding lookup matrix. 680 | 681 | If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see 682 | :obj:`past_key_values`). 683 | use_cache (:obj:`bool`, `optional`): 684 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 685 | decoding (see :obj:`past_key_values`). 
686 | output_attentions (:obj:`bool`, `optional`): 687 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 688 | tensors for more detail. 689 | output_hidden_states (:obj:`bool`, `optional`): 690 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 691 | more detail. 692 | return_dict (:obj:`bool`, `optional`): 693 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 694 | """ 695 | 696 | 697 | @add_start_docstrings( 698 | "The bare GPT Neo Model transformer outputting raw hidden-states without any specific head on top.", 699 | GPT_NEO_START_DOCSTRING, 700 | ) 701 | class GPTNeoModel(GPTNeoPreTrainedModel): 702 | def __init__(self, config): 703 | super().__init__(config) 704 | 705 | self.embed_dim = config.hidden_size 706 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 707 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 708 | self.drop = nn.Dropout(config.embed_dropout) 709 | self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)]) 710 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) 711 | 712 | self.init_weights() 713 | 714 | def get_input_embeddings(self): 715 | return self.wte 716 | 717 | def set_input_embeddings(self, new_embeddings): 718 | self.wte = new_embeddings 719 | 720 | #@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) 721 | #@add_code_sample_docstrings( 722 | #tokenizer_class=_TOKENIZER_FOR_DOC, 723 | #checkpoint=_CHECKPOINT_FOR_DOC, 724 | #output_type=BaseModelOutputWithPastAndCrossAttentions, 725 | #config_class=_CONFIG_FOR_DOC, 726 | #) 727 | def forward( 728 | self, 729 | input_ids=None, 730 | past_key_values=None, 731 | attention_mask=None, 732 | token_type_ids=None, 733 | position_ids=None, 734 | head_mask=None, 735 | inputs_embeds=None, 736 | use_cache=None, 737 | output_attentions=None, 738 | output_hidden_states=None, 739 | return_dict=None, 740 | ): 741 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 742 | output_hidden_states = ( 743 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 744 | ) 745 | use_cache = use_cache if use_cache is not None else self.config.use_cache 746 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 747 | 748 | if input_ids is not None and inputs_embeds is not None: 749 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 750 | elif input_ids is not None: 751 | input_shape = input_ids.size() 752 | input_ids = input_ids.view(-1, input_shape[-1]) 753 | batch_size = input_ids.shape[0] 754 | elif inputs_embeds is not None: 755 | input_shape = inputs_embeds.size()[:-1] 756 | batch_size = inputs_embeds.shape[0] 757 | else: 758 | raise ValueError("You have to specify either input_ids or inputs_embeds") 759 | 760 | device = input_ids.device if input_ids is not None else inputs_embeds.device 761 | 762 | if token_type_ids is not None: 763 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 764 | if position_ids is not None: 765 | position_ids = position_ids.view(-1, input_shape[-1]) 766 | 767 | if past_key_values is None: 768 | past_length = 0 769 | past_key_values = tuple([None] * len(self.h)) 770 | else: 771 | past_length = past_key_values[0][0].size(-2) 772 | 773 | device = input_ids.device 
if input_ids is not None else inputs_embeds.device 774 | if position_ids is None: 775 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 776 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 777 | 778 | # Attention mask. 779 | if attention_mask is not None: 780 | assert batch_size > 0, "batch_size has to be defined and > 0" 781 | global_attention_mask = attention_mask.view(batch_size, -1) 782 | # We create a 3D attention mask from a 2D tensor mask. 783 | # Sizes are [batch_size, 1, 1, to_seq_length] 784 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 785 | # this attention mask is more simple than the triangular masking of causal attention 786 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 787 | global_attention_mask = global_attention_mask[:, None, None, :] 788 | 789 | # Since global_attention_mask is 1.0 for positions we want to attend and 0.0 for 790 | # masked positions, this operation will create a tensor which is 0.0 for 791 | # positions we want to attend and -10000.0 for masked positions. 792 | # Since we are adding it to the raw scores before the softmax, this is 793 | # effectively the same as removing these entirely. 794 | global_attention_mask = global_attention_mask.to(dtype=self.dtype) # fp16 compatibility 795 | global_attention_mask = (1.0 - global_attention_mask) * -10000.0 796 | else: 797 | global_attention_mask = None 798 | 799 | # Local causal attention mask 800 | batch_size, seq_length = input_shape 801 | full_seq_length = seq_length + past_length 802 | local_attention_mask = GPTNeoAttentionMixin.create_local_attention_mask( 803 | batch_size, full_seq_length, self.config.window_size, device, attention_mask 804 | ) 805 | 806 | # Prepare head mask if needed 807 | # 1.0 in head_mask indicate we keep the head 808 | # attention_probs has shape bsz x num_heads x N x N 809 | # head_mask has shape n_layer x batch x num_heads x N x N 810 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 811 | 812 | if inputs_embeds is None: 813 | inputs_embeds = self.wte(input_ids) 814 | position_embeds = self.wpe(position_ids) 815 | hidden_states = inputs_embeds + position_embeds 816 | 817 | if token_type_ids is not None: 818 | token_type_embeds = self.wte(token_type_ids) 819 | hidden_states = hidden_states + token_type_embeds 820 | 821 | hidden_states = self.drop(hidden_states) 822 | 823 | output_shape = input_shape + (hidden_states.size(-1),) 824 | 825 | presents = () if use_cache else None 826 | all_self_attentions = () if output_attentions else None 827 | all_hidden_states = () if output_hidden_states else None 828 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 829 | attn_type = self.config.attention_layers[i] 830 | attn_mask = global_attention_mask if attn_type == "global" else local_attention_mask 831 | 832 | if output_hidden_states: 833 | all_hidden_states = all_hidden_states + (hidden_states,) 834 | 835 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 836 | 837 | if use_cache: 838 | logger.warning( 839 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 840 | "`use_cache=False`..." 
841 | ) 842 | use_cache = False 843 | 844 | def create_custom_forward(module): 845 | def custom_forward(*inputs): 846 | # None for past_key_value 847 | return module(*inputs, use_cache, output_attentions) 848 | 849 | return custom_forward 850 | 851 | outputs = torch.utils.checkpoint.checkpoint( 852 | create_custom_forward(block), 853 | hidden_states, 854 | None, 855 | attn_mask, 856 | head_mask[i], 857 | ) 858 | else: 859 | outputs = block( 860 | hidden_states, 861 | layer_past=layer_past, 862 | attention_mask=attn_mask, 863 | head_mask=head_mask[i], 864 | use_cache=use_cache, 865 | output_attentions=output_attentions, 866 | ) 867 | 868 | hidden_states = outputs[0] 869 | if use_cache is True: 870 | presents = presents + (outputs[1],) 871 | 872 | if output_attentions: 873 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 874 | 875 | hidden_states = self.ln_f(hidden_states) 876 | 877 | hidden_states = hidden_states.view(*output_shape) 878 | # Add last hidden state 879 | if output_hidden_states: 880 | all_hidden_states = all_hidden_states + (hidden_states,) 881 | 882 | if not return_dict: 883 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 884 | 885 | return BaseModelOutputWithPast( 886 | last_hidden_state=hidden_states, 887 | past_key_values=presents, 888 | hidden_states=all_hidden_states, 889 | attentions=all_self_attentions, 890 | ) 891 | 892 | 893 | @add_start_docstrings( 894 | """ 895 | The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input 896 | embeddings). 897 | """, 898 | GPT_NEO_START_DOCSTRING, 899 | ) 900 | class GPTNeoForCausalLM(GPTNeoPreTrainedModel): 901 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] 902 | _keys_to_ignore_on_save = [r"lm_head.weight"] 903 | 904 | def __init__(self, config): 905 | super().__init__(config) 906 | self.transformer = GPTNeoModel(config) 907 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 908 | 909 | self.init_weights() 910 | 911 | def get_output_embeddings(self): 912 | return self.lm_head 913 | 914 | def set_output_embeddings(self, new_embeddings): 915 | self.lm_head = new_embeddings 916 | 917 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 918 | token_type_ids = kwargs.get("token_type_ids", None) 919 | # only last token for inputs_ids if past is defined in kwargs 920 | if past: 921 | input_ids = input_ids[:, -1].unsqueeze(-1) 922 | if token_type_ids is not None: 923 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 924 | 925 | attention_mask = kwargs.get("attention_mask", None) 926 | position_ids = kwargs.get("position_ids", None) 927 | 928 | if attention_mask is not None and position_ids is None: 929 | # create position_ids on the fly for batch generation 930 | position_ids = attention_mask.long().cumsum(-1) - 1 931 | position_ids.masked_fill_(attention_mask == 0, 1) 932 | if past: 933 | position_ids = position_ids[:, -1].unsqueeze(-1) 934 | else: 935 | position_ids = None 936 | return { 937 | "input_ids": input_ids, 938 | "past_key_values": past, 939 | "use_cache": kwargs.get("use_cache"), 940 | "position_ids": position_ids, 941 | "attention_mask": attention_mask, 942 | "token_type_ids": token_type_ids, 943 | } 944 | 945 | #@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) 946 | #@add_code_sample_docstrings( 947 | #tokenizer_class=_TOKENIZER_FOR_DOC, 948 | 
#checkpoint=_CHECKPOINT_FOR_DOC, 949 | #output_type=CausalLMOutputWithCrossAttentions, 950 | #config_class=_CONFIG_FOR_DOC, 951 | #) 952 | def forward( 953 | self, 954 | input_ids=None, 955 | past_key_values=None, 956 | attention_mask=None, 957 | token_type_ids=None, 958 | position_ids=None, 959 | head_mask=None, 960 | inputs_embeds=None, 961 | labels=None, 962 | use_cache=None, 963 | output_attentions=None, 964 | output_hidden_states=None, 965 | return_dict=None, 966 | ): 967 | r""" 968 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 969 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 970 | ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to 971 | ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` 972 | """ 973 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 974 | 975 | transformer_outputs = self.transformer( 976 | input_ids, 977 | past_key_values=past_key_values, 978 | attention_mask=attention_mask, 979 | token_type_ids=token_type_ids, 980 | position_ids=position_ids, 981 | head_mask=head_mask, 982 | inputs_embeds=inputs_embeds, 983 | use_cache=use_cache, 984 | output_attentions=output_attentions, 985 | output_hidden_states=output_hidden_states, 986 | return_dict=return_dict, 987 | ) 988 | hidden_states = transformer_outputs[0] 989 | 990 | lm_logits = self.lm_head(hidden_states) 991 | 992 | loss = None 993 | if labels is not None: 994 | # Compute loss in fp32 to match with mesh-tf version 995 | # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 996 | lm_logits = lm_logits.to(torch.float32) 997 | 998 | # Shift so that tokens < n predict n 999 | shift_logits = lm_logits[..., :-1, :].contiguous() 1000 | shift_labels = labels[..., 1:].contiguous() 1001 | # Flatten the tokens 1002 | loss_fct = CrossEntropyLoss() 1003 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1004 | 1005 | lm_logits = lm_logits.to(hidden_states.dtype) 1006 | loss = loss.to(hidden_states.dtype) 1007 | 1008 | if not return_dict: 1009 | output = (lm_logits,) + transformer_outputs[1:] 1010 | return ((loss,) + output) if loss is not None else output 1011 | 1012 | return CausalLMOutputWithPast( 1013 | loss=loss, 1014 | logits=lm_logits, 1015 | past_key_values=transformer_outputs.past_key_values, 1016 | hidden_states=transformer_outputs.hidden_states, 1017 | attentions=transformer_outputs.attentions, 1018 | ) 1019 | 1020 | @staticmethod 1021 | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: 1022 | """ 1023 | This function is used to re-order the :obj:`past_key_values` cache if 1024 | :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is 1025 | called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 1026 | """ 1027 | return tuple( 1028 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 1029 | for layer_past in past 1030 | ) 1031 | 1032 | 1033 | @add_start_docstrings( 1034 | """ 1035 | The GPTNeo Model transformer with a sequence classification head on top (linear layer). 
1036 | 1037 | :class:`~transformers.GPTNeoForSequenceClassification` uses the last token in order to do the classification, as 1038 | other causal models (e.g. GPT-1) do. 1039 | 1040 | Since it does classification on the last token, it requires to know the position of the last token. If a 1041 | :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each 1042 | row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot 1043 | guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take 1044 | the last value in each row of the batch). 1045 | """, 1046 | GPT_NEO_START_DOCSTRING, 1047 | ) 1048 | class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel): 1049 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] 1050 | 1051 | def __init__(self, config): 1052 | super().__init__(config) 1053 | self.num_labels = config.num_labels 1054 | self.transformer = GPTNeoModel(config) 1055 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1056 | 1057 | self.init_weights() 1058 | 1059 | #@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) 1060 | #@add_code_sample_docstrings( 1061 | #tokenizer_class=_TOKENIZER_FOR_DOC, 1062 | #checkpoint=_CHECKPOINT_FOR_DOC, 1063 | #output_type=SequenceClassifierOutputWithPast, 1064 | #config_class=_CONFIG_FOR_DOC, 1065 | #) 1066 | def forward( 1067 | self, 1068 | input_ids=None, 1069 | past_key_values=None, 1070 | attention_mask=None, 1071 | token_type_ids=None, 1072 | position_ids=None, 1073 | head_mask=None, 1074 | inputs_embeds=None, 1075 | labels=None, 1076 | use_cache=None, 1077 | output_attentions=None, 1078 | output_hidden_states=None, 1079 | return_dict=None, 1080 | ): 1081 | r""" 1082 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1083 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 1084 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 1085 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1086 | """ 1087 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1088 | 1089 | transformer_outputs = self.transformer( 1090 | input_ids, 1091 | past_key_values=past_key_values, 1092 | attention_mask=attention_mask, 1093 | token_type_ids=token_type_ids, 1094 | position_ids=position_ids, 1095 | head_mask=head_mask, 1096 | inputs_embeds=inputs_embeds, 1097 | use_cache=use_cache, 1098 | output_attentions=output_attentions, 1099 | output_hidden_states=output_hidden_states, 1100 | return_dict=return_dict, 1101 | ) 1102 | hidden_states = transformer_outputs[0] 1103 | logits = self.score(hidden_states) 1104 | 1105 | if input_ids is not None: 1106 | batch_size, sequence_length = input_ids.shape[:2] 1107 | else: 1108 | batch_size, sequence_length = inputs_embeds.shape[:2] 1109 | 1110 | assert ( 1111 | self.config.pad_token_id is not None or batch_size == 1 1112 | ), "Cannot handle batch sizes > 1 if no padding token is defined." 
1113 | if self.config.pad_token_id is None: 1114 | sequence_lengths = -1 1115 | else: 1116 | if input_ids is not None: 1117 | sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 1118 | else: 1119 | sequence_lengths = -1 1120 | logger.warning( 1121 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " 1122 | f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" 1123 | ) 1124 | 1125 | pooled_logits = logits[range(batch_size), sequence_lengths] 1126 | 1127 | loss = None 1128 | if labels is not None: 1129 | if self.num_labels == 1: 1130 | # We are doing regression 1131 | loss_fct = MSELoss() 1132 | loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) 1133 | else: 1134 | loss_fct = CrossEntropyLoss() 1135 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1136 | 1137 | if not return_dict: 1138 | output = (pooled_logits,) + transformer_outputs[1:] 1139 | return ((loss,) + output) if loss is not None else output 1140 | 1141 | return SequenceClassifierOutputWithPast( 1142 | loss=loss, 1143 | logits=pooled_logits, 1144 | # past_key_values=transformer_outputs.past_key_values, #this takes up memory 1145 | # hidden_states=transformer_outputs.hidden_states, 1146 | # attentions=transformer_outputs.attentions, 1147 | ) 1148 | --------------------------------------------------------------------------------
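Usage note (illustrative, not part of the repository files above): a minimal sketch of how the custom GPTNeoForSequenceClassification defined in finetune/utils/custom_modeling_gpt_neo.py might be loaded for a two-label sentence-pair task. The checkpoint name, num_labels, max_length, the EOS-as-padding choice, and the bare import path are assumptions made for this example, not values taken from the repository.

# Minimal sketch; names and values below are illustrative assumptions, not repository defaults.
# Assumes finetune/utils is on PYTHONPATH so the custom module can be imported directly.
import torch
from transformers import GPT2Tokenizer, GPTNeoConfig

from custom_modeling_gpt_neo import GPTNeoForSequenceClassification  # the file shown above

model_name = "EleutherAI/gpt-neo-1.3B"        # hypothetical choice of GPT-Neo checkpoint
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token     # GPT-Neo has no pad token by default

config = GPTNeoConfig.from_pretrained(model_name, num_labels=2)
config.pad_token_id = tokenizer.pad_token_id  # needed for batch sizes > 1 (see the assert in forward above)

# The classification head (`score`) is newly initialized here and would normally be fine-tuned
# before its predictions are meaningful.
model = GPTNeoForSequenceClassification.from_pretrained(model_name, config=config)
model.eval()

enc = tokenizer(
    "Is the answer yes?", "Some passage text.",
    truncation=True, padding="max_length", max_length=128, return_tensors="pt",
)
with torch.no_grad():
    logits = model(**enc).logits              # shape: (batch_size, num_labels)
print(logits.argmax(dim=-1))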