├── __init__.py ├── HCT ├── __init__.py ├── rm_cache.sh ├── run_hct.sh ├── create_noisy_dataset.ipynb ├── dataset_prepare.ipynb ├── test.ipynb └── run_glue_hct.py ├── examples ├── DataMap.png └── DataMapCompare.png ├── .gitignore ├── requirements.txt ├── run_glue.sh ├── .vscode-upload.json ├── plot.sh ├── run_glue_and_record_td.sh ├── experiments.md ├── data_utils_glue.py ├── run_glue_copy.sh ├── run_glue_ambi.sh ├── selection_utils.py ├── data_selection.py ├── data_utils.py ├── README.md ├── dy_filtering.py ├── run_glue.py └── hardness_classification_prepare.ipynb /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /HCT/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/DataMap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/beyondguo/TrainingDynamics/HEAD/examples/DataMap.png -------------------------------------------------------------------------------- /examples/DataMapCompare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/beyondguo/TrainingDynamics/HEAD/examples/DataMapCompare.png -------------------------------------------------------------------------------- /HCT/rm_cache.sh: -------------------------------------------------------------------------------- 1 | export TASK_NAME=snli 2 | cd datasets/$TASK_NAME/with_conf/train 3 | rm cache* 4 | cd ../../../../ 5 | cd datasets/$TASK_NAME/with_conf/validation 6 | rm cache* 7 | cd ../../../../ 8 | cd datasets/$TASK_NAME/with_conf/test 9 | rm cache* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # basic ignore: 2 | __pycache__/ 3 | .idea/ 4 | .ipynb_checkpoints/ 5 | .DS_Store 6 | .vscode 7 | /tmp 8 | 9 | .empty/ 10 | # log 11 | /dy_log 12 | /log 13 | 14 | # data 15 | /datasets 16 | 17 | 18 | # models 19 | saved_models/ 20 | HCT/*.weight -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.9.0 2 | datasets==2.3.2 3 | huggingface_hub==0.4.0 4 | jsonx==2020.10.0 5 | matplotlib==3.3.1 6 | numpy==1.19.1 7 | pandas==0.24.2 8 | seaborn==0.11.2 9 | torch==1.7.0a0+8deb4fe 10 | tqdm==4.31.1 11 | transformers==4.18.0 12 | -------------------------------------------------------------------------------- /run_glue.sh: -------------------------------------------------------------------------------- 1 | 2 | # prajjwal1/bert-tiny 3 | # distilbert-base-cased 4 | # roberta-large 5 | 6 | export TASK_NAME=sst2 7 | export MODEL=prajjwal1/bert-tiny 8 | python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \ 9 | --model_name_or_path $MODEL \ 10 | --task_name $TASK_NAME \ 11 | --max_length 128 \ 12 | --per_device_train_batch_size 32 \ 13 | --learning_rate 2e-5 \ 14 | --num_train_epochs 10 \ 15 | # --output_dir tmp/$TASK_NAME/ -------------------------------------------------------------------------------- /.vscode-upload.json: -------------------------------------------------------------------------------- 1 | [{ 2 | "name":"", 3 | "host": "", 4 | "port": 22, 5 | 
"username": "", 6 | "password": "", 7 | "remotePath": "", 8 | "localPath": "", 9 | "disable": false, 10 | "private_key": "~/.ssh/id_rsa" 11 | },{ 12 | "name":"", 13 | "host": "", 14 | "port": 22, 15 | "username": "", 16 | "password": "", 17 | "remotePath": "", 18 | "localPath": "", 19 | "disable": false, 20 | "private_key": "~/.ssh/id_rsa" 21 | }] -------------------------------------------------------------------------------- /plot.sh: -------------------------------------------------------------------------------- 1 | # task_name, model 这俩参数,只是用于展示在plot上 2 | # model_dir 是放 training_dynamics 文件夹的那个目录 3 | # plots_dir 是绘制好的 plot 存放的目录 4 | # burn_out 计算多少轮的 dynamics 5 | 6 | # bert-tiny 7 | # distilbert-base-cased 8 | # roberta-large 9 | 10 | export TASK_NAME=rte-noisy-0.4 11 | export MODEL=bert-base-cased 12 | python -m dy_filtering \ 13 | --plot \ 14 | --task_name $TASK_NAME \ 15 | --model_dir dy_log/$TASK_NAME/$MODEL \ 16 | --plots_dir dy_log/$TASK_NAME/$MODEL \ 17 | --model $MODEL \ 18 | --burn_out 5 19 | 20 | -------------------------------------------------------------------------------- /run_glue_and_record_td.sh: -------------------------------------------------------------------------------- 1 | 2 | # ====================== Recording Training Dynamics: =========== 3 | # `--nproc_per_node N` means the number of GPUs to use 4 | # Suggested models: 5 | # prajjwal1/bert-tiny 6 | # distilbert-base-cased 7 | # bert-base-cased 8 | # roberta-large 9 | 10 | 11 | export TASK_NAME=mnli 12 | export MODEL=distilbert-base-cased 13 | 14 | # CUDA_VISIBLE_DEVICES=7 python run_glue.py \ 15 | python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \ 16 | --seed 5 \ 17 | --model_name_or_path $MODEL \ 18 | --task_name $TASK_NAME \ 19 | --max_length 128 \ 20 | --per_device_train_batch_size 32 \ 21 | --learning_rate 2e-5 \ 22 | --num_train_epochs 5 \ 23 | --do_recording \ 24 | -------------------------------------------------------------------------------- /experiments.md: -------------------------------------------------------------------------------- 1 | # sst2 2 | ``` 3 | Main params: 4 | - GPUs: 8 5 | - max_length 128 6 | - batch_size 32 7 | - epochs 5 8 | ``` 9 | 10 | - **`bert-base-cased`** 11 | - 100% train: 0.9105504587155964 12 | - 33% (easy): 0.8772935779816514 13 | - 33% (hard): 0.908256880733945 14 | - 33% (ambiguous): 0.9025229357798165 15 | 16 | --- 17 | 18 | - **`distilbert-base-cased`** 19 | - 100% train: 0.9094036697247706 20 | - 33% (easy): 0.856651376146789 21 | - 33% (hard): 0.8944954128440367 22 | - 33% (ambiguous): 0.8956422018348624 23 | 24 | --- 25 | 26 | - **`bert-tiny`** 27 | - 100% train: 0.788990825688 28 | - 33% (easy): 0.694954128440367 29 | - 33% (hard): 0.37155963302752293 30 | - 33% (ambiguous): 0.5022935779816514 31 | 32 | use the data selection by `distilbert-base-cased` 33 | - 33% (easy): 0.5389908256880734 34 | - 33% (hard): 0.5871559633027523 35 | - 33% (ambiguous): 0.6020642201834863 36 | 37 | -------------------------------------------------------------------------------- /HCT/run_hct.sh: -------------------------------------------------------------------------------- 1 | # ====================== Basic GLUE tasks training and evaluation: 2 | # the following is the standard training for GLUE tasks 3 | 4 | # export TASK_NAME=sst2 5 | # export MODEL=distilbert-base-cased 6 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \ 7 | # --model_name_or_path $MODEL \ 8 | # --task_name $TASK_NAME \ 9 | # --max_length 128 \ 10 | # 
--per_device_train_batch_size 32 \ 11 | # --learning_rate 2e-5 \ 12 | # --num_train_epochs 5 \ 13 | 14 | 15 | # ————> Suggested models: 16 | # prajjwal1/bert-tiny 17 | # distilbert-base-cased 18 | # bert-base-cased 19 | # roberta-large 20 | 21 | # =========== Train GLUE tasks ============= 22 | # https://huggingface.co/datasets/glue 23 | 24 | export TASK_NAME=rte 25 | export MODEL=bert-base-cased 26 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue_hct.py \ 27 | CUDA_VISIBLE_DEVICES=5 python run_glue_hct.py \ 28 | --seed 5 \ 29 | --model_name_or_path $MODEL \ 30 | --task_name $TASK_NAME \ 31 | --max_length 128 \ 32 | --per_device_train_batch_size 32 \ 33 | --learning_rate 4e-5 \ 34 | --num_train_epochs 10 \ 35 | --temperature 1\ 36 | --mu 0.5 \ 37 | --more_ambiguous 38 | # --hard_with_ls \ 39 | # --ls_weight 0.1 \ 40 | # --hard_inference 41 | 42 | # --with_data_selection \ 43 | # --data_selection_region ambiguous \ 44 | # --output_dir tmp/$TASK_NAME/ 45 | 46 | -------------------------------------------------------------------------------- /data_utils_glue.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import re 4 | import tqdm 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def convert_string_to_unique_number(string: str) -> int: 10 | """ 11 | Hack to convert SNLI ID into a unique integer ID, for tensorizing. 12 | """ 13 | id_map = {'e': '0', 'c': '1', 'n': '2'} 14 | 15 | # SNLI-specific hacks. 16 | if string.startswith('vg_len'): 17 | code = '555' 18 | elif string.startswith('vg_verb'): 19 | code = '444' 20 | else: 21 | code = '000' 22 | 23 | try: 24 | number = int(code + re.sub(r"\D", "", string) + id_map.get(string[-1], '3')) 25 | except: 26 | number = random.randint(10000, 99999) 27 | logger.info(f"Cannot find ID for {string}, using random number {number}.") 28 | return number 29 | 30 | 31 | def read_glue_tsv(file_path: str, 32 | guid_index: int, 33 | label_index: int = -1, 34 | guid_as_int: bool = False): 35 | """ 36 | Reads TSV files for GLUE-style text classification tasks. 37 | Returns: 38 | - a mapping between the example ID and the entire line as a string. 39 | - the header of the TSV file. 40 | """ 41 | tsv_dict = {} 42 | 43 | i = -1 44 | with open(file_path, 'r') as tsv_file: 45 | for line in tqdm.tqdm([line for line in tsv_file]): 46 | i += 1 47 | if i == 0: 48 | header = line.strip() 49 | field_names = line.strip().split("\t") 50 | continue 51 | 52 | fields = line.strip().split("\t") 53 | label = fields[label_index] 54 | if len(fields) > len(field_names): 55 | # SNLI / MNLI fields sometimes contain multiple annotator labels. 56 | # Ignore all except the gold label. 57 | reformatted_fields = fields[:len(field_names)-1] + [label] 58 | assert len(reformatted_fields) == len(field_names) 59 | reformatted_line = "\t".join(reformatted_fields) 60 | else: 61 | reformatted_line = line.strip() 62 | 63 | if label == "-" or label == "": 64 | logger.info(f"Skippping line: {line}") 65 | continue 66 | 67 | if guid_index is None: 68 | guid = i 69 | else: 70 | guid = fields[guid_index] # PairID. 71 | if guid in tsv_dict: 72 | logger.info(f"Found clash in IDs ... 
skipping example {guid}.") 73 | continue 74 | tsv_dict[guid] = reformatted_line.strip() 75 | 76 | logger.info(f"Read {len(tsv_dict)} valid examples, with unique IDS, out of {i} from {file_path}") 77 | if guid_as_int: 78 | tsv_numeric = {int(convert_string_to_unique_number(k)): v for k, v in tsv_dict.items()} 79 | return tsv_numeric, header 80 | return tsv_dict, header -------------------------------------------------------------------------------- /run_glue_copy.sh: -------------------------------------------------------------------------------- 1 | # ====================== Basic GLUE tasks training and evaluation: 2 | # the following is the standard training for GLUE tasks 3 | 4 | # export TASK_NAME=sst2 5 | # export MODEL=distilbert-base-cased 6 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \ 7 | # --model_name_or_path $MODEL \ 8 | # --task_name $TASK_NAME \ 9 | # --max_length 128 \ 10 | # --per_device_train_batch_size 32 \ 11 | # --learning_rate 2e-5 \ 12 | # --num_train_epochs 5 \ 13 | 14 | 15 | # ————> Training with Data Selection: 16 | # after you run `data_selection.py` and obtain the `three_regions_data_indices.json` file 17 | # you can train a GLUE classifier again with your specified data selection 18 | # set `--with_data_selection` to turn on data selection 19 | # set `--data_selection_region [region]` to specify the region, choices are "easy", "hard", and "ambiguous" 20 | 21 | 22 | # ————> Suggested models: 23 | # prajjwal1/bert-tiny 24 | # distilbert-base-cased 25 | # bert-base-cased 26 | # roberta-large 27 | 28 | # =========== Train GLUE tasks ============= 29 | # https://huggingface.co/datasets/glue 30 | 31 | export TASK_NAME=boolq 32 | export MODEL=bert-base-cased 33 | # CUDA_VISIBLE_DEVICES=4 python run_glue.py \ 34 | python -m torch.distributed.launch --nproc_per_node 1 --use_env run_glue.py \ 35 | --seed 5 \ 36 | --model_name_or_path $MODEL \ 37 | --task_name $TASK_NAME \ 38 | --output_dir saved_models/$TASK_NAME/$MODEL \ 39 | --resume_from_checkpoint saved_models/$TASK_NAME/$MODEL/epoch_9 \ 40 | --checkpointing_steps epoch \ 41 | --max_length 128 \ 42 | --per_device_train_batch_size 32 \ 43 | --learning_rate 5e-5 \ 44 | --num_train_epochs 10 \ 45 | --continue_train \ 46 | --continue_num_train_epochs 5 \ 47 | --log_name ambiguous \ 48 | --selected_indices_filename selected_indices_ambi_top0.33_balance_from500 \ 49 | # --do_lwf \ 50 | 51 | 52 | 53 | 54 | # --resume_from_checkpoint saved_models/$TASK_NAME/$MODEL/epoch_4 \ # 指定了之后,就会直接load该epoch的模型 55 | 56 | # --max_train_steps 10 \ 57 | # --with_data_selection \ 58 | # --data_selection_region ambiguous \ 59 | # --output_dir tmp/$TASK_NAME/ 60 | 61 | 62 | # =========== Use Your Own Dataset ======== 63 | 64 | # export MODEL=bert-base-cased 65 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \ 66 | # --model_name_or_path $MODEL \ 67 | # --max_length 128 \ 68 | # --per_device_train_batch_size 32 \ 69 | # --learning_rate 2e-5 \ 70 | # --num_train_epochs 5 \ 71 | # --train_file datasets/qnli-easy-hard_train.csv \ 72 | # --validation_file datasets/qnli-easy-hard_valid.csv 73 | 74 | 75 | 76 | # cd ../K2T 77 | # sh oc.sh -------------------------------------------------------------------------------- /run_glue_ambi.sh: -------------------------------------------------------------------------------- 1 | # ====================== Basic GLUE tasks training and evaluation: 2 | # the following is the standard training for GLUE tasks 3 | 4 | # export TASK_NAME=sst2 5 | # 
export MODEL=distilbert-base-cased 6 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \ 7 | # --model_name_or_path $MODEL \ 8 | # --task_name $TASK_NAME \ 9 | # --max_length 128 \ 10 | # --per_device_train_batch_size 32 \ 11 | # --learning_rate 2e-5 \ 12 | # --num_train_epochs 5 \ 13 | 14 | 15 | # ————> Training with Data Selection: 16 | # after you run `data_selection.py` and obtain the `three_regions_data_indices.json` file 17 | # you can train a GLUE classifier again with your specified data selection 18 | # set `--with_data_selection` to turn on data selection 19 | # set `--data_selection_region [region]` to specify the region, choices are "easy", "hard", and "ambiguous" 20 | 21 | 22 | # ————> Suggested models: 23 | # prajjwal1/bert-tiny 24 | # distilbert-base-cased 25 | # bert-base-cased 26 | # roberta-large 27 | 28 | # =========== Train GLUE tasks ============= 29 | # https://huggingface.co/datasets/glue 30 | 31 | export TASK_NAME=mnli 32 | export MODEL=roberta-large 33 | # CUDA_VISIBLE_DEVICES=4 python run_glue.py \ 34 | python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \ 35 | --seed 5 \ 36 | --model_name_or_path $MODEL \ 37 | --task_name $TASK_NAME \ 38 | --checkpointing_steps epoch \ 39 | --resume_from_checkpoint saved_models/$TASK_NAME/$MODEL/epoch_4 \ 40 | --max_length 128 \ 41 | --per_device_train_batch_size 32 \ 42 | --learning_rate 2e-5 \ 43 | --num_train_epochs 5 \ 44 | --log_name ambiguous \ 45 | --continue_train_with_sample_loss \ 46 | --continue_train \ 47 | --continue_num_train_epochs 5 \ 48 | # --selected_indices_filename selected_indices_ambi_top0.1_balance+ \ 49 | # --do_lwf \ 50 | 51 | 52 | 53 | # --output_dir saved_models/$TASK_NAME/$MODEL \ 54 | # --resume_from_checkpoint saved_models/$TASK_NAME/$MODEL/epoch_4 \ # 指定了之后,就会直接load该epoch的模型 55 | 56 | # --max_train_steps 10 \ 57 | # --with_data_selection \ 58 | # --data_selection_region ambiguous \ 59 | # --output_dir tmp/$TASK_NAME/ 60 | 61 | 62 | # =========== Use Your Own Dataset ======== 63 | 64 | # export MODEL=bert-base-cased 65 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \ 66 | # --model_name_or_path $MODEL \ 67 | # --max_length 128 \ 68 | # --per_device_train_batch_size 32 \ 69 | # --learning_rate 2e-5 \ 70 | # --num_train_epochs 5 \ 71 | # --train_file datasets/qnli-easy-hard_train.csv \ 72 | # --validation_file datasets/qnli-easy-hard_valid.csv 73 | 74 | 75 | 76 | # cd ../K2T 77 | # sh oc.sh -------------------------------------------------------------------------------- /selection_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import numpy as np 4 | import os 5 | import pandas as pd 6 | import tqdm 7 | 8 | from typing import List 9 | 10 | logging.basicConfig( 11 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.INFO 12 | ) 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def log_training_dynamics(output_dir: os.path, 17 | epoch: int, 18 | train_ids: List[int], 19 | train_logits: List[List[float]], 20 | train_golds: List[int]): 21 | """ 22 | Save training dynamics (logits) from given epoch as records of a `.jsonl` file. 23 | """ 24 | td_df = pd.DataFrame({"guid": train_ids, 25 | f"logits_epoch_{epoch}": train_logits, 26 | "gold": train_golds}) 27 | 28 | logging_dir = os.path.join(output_dir, f"training_dynamics") 29 | # Create directory for logging training dynamics, if it doesn't already exist. 
30 | if not os.path.exists(logging_dir): 31 | os.makedirs(logging_dir) 32 | epoch_file_name = os.path.join(logging_dir, f"dynamics_epoch_{epoch}.jsonl") 33 | td_df.to_json(epoch_file_name, lines=True, orient="records") 34 | logger.info(f"Training Dynamics logged to {epoch_file_name}") 35 | 36 | 37 | def read_training_dynamics(model_dir: os.path, 38 | strip_last: bool = False, 39 | id_field: str = "guid", 40 | burn_out: int = None): 41 | """ 42 | Given path to logged training dynamics, merge stats across epochs. 43 | Returns: 44 | - Dict between ID of a train instances and its gold label, and the list of logits across epochs. 45 | """ 46 | train_dynamics = {} 47 | 48 | td_dir = os.path.join(model_dir, "training_dynamics") 49 | num_epochs = len([f for f in os.listdir(td_dir) if os.path.isfile(os.path.join(td_dir, f))]) 50 | if burn_out: 51 | num_epochs = burn_out 52 | 53 | logger.info(f"Reading {num_epochs} files from {td_dir} ...") 54 | for epoch_num in tqdm.tqdm(range(num_epochs)): 55 | epoch_file = os.path.join(td_dir, f"dynamics_epoch_{epoch_num}.jsonl") 56 | assert os.path.exists(epoch_file) 57 | 58 | with open(epoch_file, "r") as infile: 59 | # print('*** Current Reading:',epoch_file) 60 | for line in infile: 61 | record = json.loads(line.strip()) 62 | guid = record[id_field] if not strip_last else record[id_field][:-1] 63 | if guid not in train_dynamics: 64 | assert epoch_num == 0 65 | train_dynamics[guid] = {"gold": record["gold"], "logits": []} 66 | train_dynamics[guid]["logits"].append(record[f"logits_epoch_{epoch_num}"]) 67 | 68 | logger.info(f"Read training dynamics for {len(train_dynamics)} train instances.") 69 | return train_dynamics 70 | -------------------------------------------------------------------------------- /data_selection.py: -------------------------------------------------------------------------------- 1 | # Only applied to training set 2 | # python data_selection.py --task_name qnli --model_name bert-base-cased --proportion 0.5 --burn_out 4 3 | import json 4 | import random 5 | random.seed(1) 6 | import argparse 7 | from dy_filtering import read_training_dynamics, compute_train_dy_metrics 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--task_name", type=str) 11 | parser.add_argument("--model_name", type=str) 12 | parser.add_argument("--proportion", type=float, default=0.33) 13 | parser.add_argument("--burn_out", type=int) 14 | args = parser.parse_args() 15 | 16 | TASK_NAME = args.task_name 17 | MODEL = args.model_name 18 | PROPORTION = args.proportion 19 | 20 | # 读取并合并到一个文件 21 | td = read_training_dynamics(f'dy_log/{TASK_NAME}/{MODEL}/') 22 | # 计算 metrics,转化成一个 dataframe 23 | td_df, _ = compute_train_dy_metrics(td, burn_out=args.burn_out) 24 | 25 | 26 | def consider_ascending_order(filtering_metric: str) -> bool: 27 | """ 28 | Determine if the metric values' sorting order to get the most `valuable` examples for training. 
29 | """ 30 | if filtering_metric == "variability": 31 | return False 32 | elif filtering_metric == "confidence": 33 | return True 34 | elif filtering_metric == "threshold_closeness": 35 | return False 36 | elif filtering_metric == "forgetfulness": 37 | return False 38 | elif filtering_metric == "correctness": 39 | return True 40 | else: 41 | raise NotImplementedError(f"Filtering based on {filtering_metric} not implemented!") 42 | 43 | 44 | 45 | def data_selection(metric, select_worst, proportion, shuffle=True): 46 | ascending = consider_ascending_order(metric) 47 | if select_worst: 48 | ascending = not consider_ascending_order(metric) 49 | sorted_df = td_df.sort_values(by=metric, ascending=ascending) 50 | selected_df = sorted_df.head(n=int(proportion * len(sorted_df))) 51 | indices = list(selected_df['guid']) 52 | if shuffle: 53 | random.shuffle(indices) 54 | return {'indices':indices, 'df':selected_df} 55 | 56 | 57 | """ 58 | hard-to-learn: METRIC = 'confidence' 59 | easy-to-learn: METRIC = 'confidence', SELECT_WORST = True 60 | ambiguoug: METRIC = 'variability' 61 | """ 62 | 63 | three_regions_data_indices = {'hard':data_selection('confidence', False, PROPORTION)['indices'], 64 | 'easy':data_selection('confidence', True, PROPORTION)['indices'], 65 | 'ambiguous':data_selection('variability', False, PROPORTION)['indices']} 66 | 67 | with open(f'dy_log/{TASK_NAME}/{MODEL}/three_regions_data_indices.json','w') as f: 68 | f.write(json.dumps(three_regions_data_indices)) 69 | 70 | # 然后可以直接跑glue任务,在选择训练集的时候,使用select函数来指定对应样本即可: 71 | """ e.g. 72 | from datasets import load_dataset 73 | raw_datasets = load_dataset('glue','sst2') 74 | easy_train_set = raw_datasets['train'].select(three_regions_data_indices['easy']) 75 | """ 76 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Utilities for data handling. 4 | """ 5 | import json 6 | import logging 7 | import os 8 | import pandas as pd 9 | import shutil 10 | 11 | from typing import Dict 12 | 13 | from data_utils_glue import read_glue_tsv 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def read_data(file_path: str, 19 | task_name: str, 20 | guid_as_int: bool = False): 21 | """ 22 | Reads task-specific datasets from corresponding GLUE-style TSV files. 23 | """ 24 | logger.warning("Data reading only works when data is in TSV format, " 25 | " and last column as classification label.") 26 | 27 | # `guid_index`: should be 2 for SNLI, 0 for MNLI and None for any random tsv file. 28 | if task_name == "MNLI": 29 | return read_glue_tsv(file_path, 30 | guid_index=0, 31 | guid_as_int=guid_as_int) 32 | elif task_name == "SNLI": 33 | return read_glue_tsv(file_path, 34 | guid_index=2, 35 | guid_as_int=guid_as_int) 36 | elif task_name == "WINOGRANDE": 37 | return read_glue_tsv(file_path, 38 | guid_index=0) 39 | elif task_name == "QNLI": 40 | return read_glue_tsv(file_path, 41 | guid_index=0) 42 | else: 43 | raise NotImplementedError(f"Reader for {task_name} not implemented.") 44 | 45 | 46 | def convert_tsv_entries_to_dataframe(tsv_dict: Dict, header: str) -> pd.DataFrame: 47 | """ 48 | Converts entries from TSV file to Pandas DataFrame for faster processing. 
49 | """ 50 | header_fields = header.strip().split("\t") 51 | data = {header: [] for header in header_fields} 52 | 53 | for line in tsv_dict.values(): 54 | fields = line.strip().split("\t") 55 | assert len(header_fields) == len(fields) 56 | for field, header in zip(fields, header_fields): 57 | data[header].append(field) 58 | 59 | df = pd.DataFrame(data, columns=header_fields) 60 | return df 61 | 62 | 63 | def copy_dev_test(task_name: str, 64 | from_dir: os.path, 65 | to_dir: os.path, 66 | extension: str = ".tsv"): 67 | """ 68 | Copies development and test sets (for data selection experiments) from `from_dir` to `to_dir`. 69 | """ 70 | if task_name == "MNLI": 71 | dev_filename = "dev_matched.tsv" 72 | test_filename = "dev_mismatched.tsv" 73 | elif task_name in ["SNLI", "QNLI", "WINOGRANDE"]: 74 | dev_filename = f"dev{extension}" 75 | test_filename = f"test{extension}" 76 | else: 77 | raise NotImplementedError(f"Logic for {task_name} not implemented.") 78 | 79 | dev_path = os.path.join(from_dir, dev_filename) 80 | if os.path.exists(dev_path): 81 | shutil.copyfile(dev_path, os.path.join(to_dir, dev_filename)) 82 | else: 83 | raise ValueError(f"No file found at {dev_path}") 84 | 85 | test_path = os.path.join(from_dir, test_filename) 86 | if os.path.exists(test_path): 87 | shutil.copyfile(test_path, os.path.join(to_dir, test_filename)) 88 | else: 89 | raise ValueError(f"No file found at {test_path}") 90 | 91 | 92 | def read_jsonl(file_path: str, key: str = "pairID"): 93 | """ 94 | Reads JSONL file to recover mapping between one particular key field 95 | in the line and the result of the line as a JSON dict. 96 | If no key is provided, return a list of JSON dicts. 97 | """ 98 | df = pd.read_json(file_path, lines=True) 99 | records = df.to_dict('records') 100 | logger.info(f"Read {len(records)} JSON records from {file_path}.") 101 | 102 | if key: 103 | assert key in df.columns 104 | return {record[key]: record for record in records} 105 | return records 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TrainingDynamics 2 | > Re-implementation of the paper [Dataset Cartography: Mapping and Diagnosing Datasets with Training Dynamics (EMNLP-20)](https://aclanthology.org/2020.emnlp-main.746/). This project is mainly based on [AllenAI's Dataset Cartography](https://github.com/allenai/cartography) project, where the model outputs (logits) of each sample is recorded after every training epoch. Based on these records, training dynamics (prediction confidence, variability, etc.) are computed to plot the Data Cartography to visualize the distribution of all training samples. However, the [original repo](https://github.com/allenai/cartography) hasn't been maintained for a long time. In this repo, we use the **latest version** of packages to reimplement the Dataset Cartography, as well as some other extensions based on the training dynamics. 3 | 4 | ## Basic requirements: 5 | - transformers==4.18.0 6 | - torch==1.7.0 7 | - datasets==2.3.2 8 | - accelerate==0.9.0 9 | 10 | More requirements see `requirements.txt`. 11 | 12 | ## Usage: 13 | For example, we want to record the training dynamics of SST2 dataset (a sentiment classification task from GLUE), we do the following steps: 14 | 15 | 1. 
**Run `sh run_glue_and_record_td.sh` to obtain the training dynamics.** 16 | 17 | Specify the `TASK_NAME` (here we choose `sst2`), the `MODEL` you want to use, and the number of epochs to train the classifier. 18 | 19 | The following information will be recorded during training: 20 | - 'guid': the id of the sample 21 | - 'logits_epoch_{epoch}': output logits vector of the current sample 22 | - 'gold': the true label (index) 23 | 24 | After training, we can find the log files in the `./dy_log/{TASK_NAME}/{MODEL}/training_dynamics` directory, e.g.: 25 | ```shell 26 | dynamics_epoch_0.jsonl 27 | dynamics_epoch_1.jsonl 28 | dynamics_epoch_2.jsonl 29 | ... 30 | ``` 31 | Each file contains records like: 32 | ```shell 33 | {"guid": 50325, "logits_epoch_0": [2.943110942840576, -2.2836594581604004], "gold": 0, "device": "cuda:0"} 34 | {"guid": 42123, "logits_epoch_0": [-2.7155513763427734, 3.249767541885376], "gold": 1, "device": "cuda:0"} 35 | {"guid": 42936, "logits_epoch_0": [-1.1907235383987427, 2.1173453330993652], "gold": 1, "device": "cuda:0"} 36 | ... 37 | ``` 38 | 39 | 2. **Run `sh plot.sh` to plot the data cartography based on the recorded training dynamics.** 40 | 41 | In `plot.sh`, we specify the `TASK_NAME` and `MODEL`, which are used to determine the path of the training dynamics. First, the log files from all epochs are collected, and several metrics (confidence, variability, correctness, forgetfulness, etc.) are calculated and saved into a single file named `td_metrics.jsonl` (in the output directory `./dy_log/{TASK}/{MODEL}/`): 42 | 43 | ```shell 44 | {"guid":50325,"index":0,"threshold_closeness":0.0039580798,"confidence":0.9960261285,"variability":0.0012847629,"correctness":4,"forgetfulness":0} 45 | {"guid":42123,"index":1,"threshold_closeness":0.0012448987,"confidence":0.9987535477,"variability":0.0007707975,"correctness":4,"forgetfulness":0} 46 | {"guid":42936,"index":2,"threshold_closeness":0.0396512556,"confidence":0.958637923,"variability":0.0095242939,"correctness":4,"forgetfulness":0} 47 | ... 48 | ``` 49 | 50 | Then, a data map (dataset cartography) is plotted based on these metrics: 51 | ![Data Map](examples/DataMap.png) 52 | 53 | ## Data Selection 54 | After recording the training dynamics, we can re-train the model on a selected subset (e.g., using only the ambiguous samples for training). 55 | For example, for the `sst2` task and the `bert-tiny` model, just run: 56 | ```shell 57 | python data_selection.py --task_name sst2 --model_name bert-tiny --burn_out 4 58 | ``` 59 | You will then get a JSON file at `dy_log/sst2/bert-tiny/three_regions_data_indices.json`. 60 | 61 | Then run `sh run_glue.sh` with `--with_data_selection` and `--data_selection_region [region]` added. 62 | 63 | See the comments in `run_glue.sh` for more details. 64 | 65 | ## Other Extensions: 66 | Apart from the above usage, we can also compare two models (e.g., a strong model and a weak model) by computing the change in their training dynamics. For example, we train a weak model (BERT-tiny) and a strong model (RoBERTa-large) on the SST2 dataset and plot their difference: 67 | 68 | ![Data Map Comparison](examples/DataMapCompare.png) 69 | 70 | 71 | You can find more detailed usage of this repo in our notebook `plot_demo.ipynb`.
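
As a reference, a comparison like the one above can be assembled from the two models' `td_metrics.jsonl` files. The snippet below is only a minimal sketch of the idea, not the exact code behind the figure: the file paths follow the `dy_log/{TASK_NAME}/{MODEL}/` layout produced by `plot.sh`, and the chosen models, output filename, and the decision to plot per-sample confidence/variability shifts are illustrative assumptions.

```python
import pandas as pd
import matplotlib.pyplot as plt

# td_metrics.jsonl is produced by plot.sh (dy_filtering.py) for each task/model pair.
weak = pd.read_json("dy_log/sst2/bert-tiny/td_metrics.jsonl", lines=True)
strong = pd.read_json("dy_log/sst2/roberta-large/td_metrics.jsonl", lines=True)

# Align the two runs on the sample id and compute the change of the training dynamics.
merged = weak.merge(strong, on="guid", suffixes=("_weak", "_strong"))
merged["delta_confidence"] = merged["confidence_strong"] - merged["confidence_weak"]
merged["delta_variability"] = merged["variability_strong"] - merged["variability_weak"]

# Each point shows how one training sample moves on the data map between the two models.
plt.figure(figsize=(8, 6))
plt.scatter(merged["delta_variability"], merged["delta_confidence"], s=3, alpha=0.3)
plt.xlabel("change in variability (strong - weak)")
plt.ylabel("change in confidence (strong - weak)")
plt.tight_layout()
plt.savefig("DataMapCompare_sketch.png", dpi=300)
```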
72 | 73 | --- 74 | 75 | *Have fun and fell free to give your feedback :)* 76 | 77 | Citation: 78 | ``` 79 | @inproceedings{swayamdipta-etal-2020-dataset, 80 | title = "Dataset Cartography: Mapping and Diagnosing Datasets with Training Dynamics", 81 | author = "Swayamdipta, Swabha and 82 | Schwartz, Roy and 83 | Lourie, Nicholas and 84 | Wang, Yizhong and 85 | Hajishirzi, Hannaneh and 86 | Smith, Noah A. and 87 | Choi, Yejin", 88 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)", 89 | month = nov, 90 | year = "2020", 91 | address = "Online", 92 | publisher = "Association for Computational Linguistics", 93 | url = "https://aclanthology.org/2020.emnlp-main.746", 94 | doi = "10.18653/v1/2020.emnlp-main.746", 95 | pages = "9275--9293", 96 | } 97 | ``` -------------------------------------------------------------------------------- /HCT/create_noisy_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from datasets import load_dataset" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "Reusing dataset glue (/home/v-biyangguo/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" 22 | ] 23 | }, 24 | { 25 | "data": { 26 | "application/vnd.jupyter.widget-view+json": { 27 | "model_id": "a057b84c31934616ac67c2d40c9af4ba", 28 | "version_major": 2, 29 | "version_minor": 0 30 | }, 31 | "text/plain": [ 32 | " 0%| | 0/3 [00:00 int: 36 | """ 37 | Given a epoch-wise trend of train predictions, compute frequency with which 38 | an example is forgotten, i.e. predicted incorrectly _after_ being predicted correctly. 39 | Based on: https://arxiv.org/abs/1812.05159 40 | """ 41 | if not any(correctness_trend): # Example is never predicted correctly, or learnt! 42 | return 1000 43 | learnt = False # Predicted correctly in the current epoch. 44 | times_forgotten = 0 45 | for is_correct in correctness_trend: 46 | if (not learnt and not is_correct) or (learnt and is_correct): 47 | # nothing changed. 48 | continue 49 | elif learnt and not is_correct: 50 | # Forgot after learning at some point! 51 | learnt = False 52 | times_forgotten += 1 53 | elif not learnt and is_correct: 54 | # Learnt! 55 | learnt = True 56 | return times_forgotten 57 | 58 | 59 | def compute_correctness(trend: List[float]) -> float: 60 | """ 61 | Aggregate #times an example is predicted correctly during all training epochs. 62 | """ 63 | return sum(trend) 64 | 65 | 66 | def compute_train_dy_metrics(training_dynamics, include_ci=False, burn_out=100): 67 | """ 68 | Given the training dynamics (logits for each training instance across epochs), compute metrics 69 | based on it, for data map coorodinates. 70 | Computed metrics are: confidence, variability, correctness, forgetfulness, threshold_closeness--- 71 | the last two being baselines from prior work 72 | (Example Forgetting: https://arxiv.org/abs/1812.05159 and 73 | Active Bias: https://arxiv.org/abs/1704.07433 respectively). 74 | Returns: 75 | - DataFrame with these metrics. 76 | - DataFrame with more typical training evaluation metrics, such as accuracy / loss. 
77 | """ 78 | confidence_ = {} 79 | variability_ = {} 80 | threshold_closeness_ = {} 81 | correctness_ = {} 82 | forgetfulness_ = {} 83 | 84 | # Functions to be applied to the data. 85 | variability_func = lambda conf: np.std(conf) 86 | if include_ci: # Based on prior work on active bias (https://arxiv.org/abs/1704.07433) 87 | variability_func = lambda conf: np.sqrt(np.var(conf) + np.var(conf) * np.var(conf) / (len(conf)-1)) 88 | threshold_closeness_func = lambda conf: conf * (1 - conf) 89 | 90 | loss = torch.nn.CrossEntropyLoss() 91 | 92 | num_tot_epochs = len(list(training_dynamics.values())[0]["logits"]) 93 | if burn_out < num_tot_epochs: 94 | logger.info(f"Computing training dynamics. Burning out at {burn_out} of {num_tot_epochs}. ") 95 | else: 96 | logger.info(f"Computing training dynamics across {num_tot_epochs} epochs") 97 | logger.info("Metrics computed: confidence, variability, correctness, forgetfulness, threshold_closeness") 98 | 99 | logits = {i: [] for i in range(num_tot_epochs)} 100 | targets = {i: [] for i in range(num_tot_epochs)} 101 | training_accuracy = defaultdict(float) 102 | 103 | for guid in tqdm.tqdm(training_dynamics): 104 | correctness_trend = [] 105 | true_probs_trend = [] 106 | 107 | record = training_dynamics[guid] 108 | for i, epoch_logits in enumerate(record["logits"]): 109 | probs = torch.nn.functional.softmax(torch.Tensor(epoch_logits), dim=-1) 110 | true_class_prob = float(probs[record["gold"]]) 111 | true_probs_trend.append(true_class_prob) 112 | 113 | prediction = np.argmax(epoch_logits) 114 | is_correct = (prediction == record["gold"]).item() 115 | correctness_trend.append(is_correct) 116 | 117 | training_accuracy[i] += is_correct 118 | logits[i].append(epoch_logits) 119 | targets[i].append(record["gold"]) 120 | 121 | if burn_out < num_tot_epochs: 122 | correctness_trend = correctness_trend[:burn_out] 123 | true_probs_trend = true_probs_trend[:burn_out] 124 | 125 | correctness_[guid] = compute_correctness(correctness_trend) 126 | confidence_[guid] = np.mean(true_probs_trend) 127 | variability_[guid] = variability_func(true_probs_trend) 128 | 129 | forgetfulness_[guid] = compute_forgetfulness(correctness_trend) 130 | threshold_closeness_[guid] = threshold_closeness_func(confidence_[guid]) 131 | 132 | # Should not affect ranking, so ignoring. 133 | epsilon_var = np.mean(list(variability_.values())) 134 | 135 | column_names = ['guid', 136 | 'index', 137 | 'threshold_closeness', 138 | 'confidence', 139 | 'variability', 140 | 'correctness', 141 | 'forgetfulness',] 142 | df = pd.DataFrame([[guid, 143 | i, 144 | threshold_closeness_[guid], 145 | confidence_[guid], 146 | variability_[guid], 147 | correctness_[guid], 148 | forgetfulness_[guid], 149 | ] for i, guid in enumerate(correctness_)], columns=column_names) 150 | 151 | df_train = pd.DataFrame([[i, 152 | loss(torch.Tensor(logits[i]), torch.LongTensor(targets[i])).item() / len(training_dynamics), 153 | training_accuracy[i] / len(training_dynamics) 154 | ] for i in range(num_tot_epochs)], 155 | columns=['epoch', 'loss', 'train_acc']) 156 | return df, df_train 157 | 158 | 159 | def consider_ascending_order(filtering_metric: str) -> bool: 160 | """ 161 | Determine if the metric values' sorting order to get the most `valuable` examples for training. 
162 | """ 163 | if filtering_metric == "variability": 164 | return False 165 | elif filtering_metric == "confidence": 166 | return True 167 | elif filtering_metric == "threshold_closeness": 168 | return False 169 | elif filtering_metric == "forgetfulness": 170 | return False 171 | elif filtering_metric == "correctness": 172 | return True 173 | else: 174 | raise NotImplementedError(f"Filtering based on {filtering_metric} not implemented!") 175 | 176 | 177 | def write_filtered_data(args, train_dy_metrics): 178 | """ 179 | Filter data based on the given metric, and write it in TSV format to train GLUE-style classifier. 180 | """ 181 | # First save the args for filtering, to keep track of which model was used for filtering. 182 | argparse_dict = vars(args) 183 | with open(os.path.join(args.filtering_output_dir, f"filtering_configs.json"), "w") as outfile: 184 | outfile.write(json.dumps(argparse_dict, indent=4, sort_keys=True) + "\n") 185 | 186 | # Determine whether to sort data in ascending order or not, based on the metric. 187 | is_ascending = consider_ascending_order(args.metric) 188 | if args.worst: 189 | is_ascending = not is_ascending 190 | 191 | # Sort by selection. 192 | sorted_scores = train_dy_metrics.sort_values(by=[args.metric],ascending=is_ascending) 193 | 194 | original_train_file = os.path.join(os.path.join(args.data_dir, args.task_name), f"train.tsv") 195 | train_numeric, header = read_data(original_train_file, task_name=args.task_name, guid_as_int=True) 196 | 197 | for fraction in [0.01, 0.05, 0.10, 0.1667, 0.25, 0.3319, 0.50, 0.75]: 198 | outdir = os.path.join(args.filtering_output_dir, 199 | f"cartography_{args.metric}_{fraction:.2f}/{args.task_name}") 200 | if not os.path.exists(outdir): 201 | os.makedirs(outdir) 202 | 203 | # Dev and test need not be subsampled. 204 | copy_dev_test(args.task_name, 205 | from_dir=os.path.join(args.data_dir, args.task_name), 206 | to_dir=outdir) 207 | 208 | num_samples = int(fraction * len(train_numeric)) 209 | with open(os.path.join(outdir, f"train.tsv"), "w") as outfile: 210 | outfile.write(header + "\n") 211 | selected = sorted_scores.head(n=num_samples+1) 212 | if args.both_ends: 213 | hardest = sorted_scores.head(n=int(num_samples * 0.7)) 214 | easiest = sorted_scores.tail(n=num_samples - hardest.shape[0]) 215 | selected = pd.concat([hardest, easiest]) 216 | fm = args.metric 217 | logger.info(f"Selecting both ends: {fm} = " 218 | f"({hardest.head(1)[fm].values[0]:3f}: {hardest.tail(1)[fm].values[0]:3f}) " 219 | f"& ({easiest.head(1)[fm].values[0]:3f}: {easiest.tail(1)[fm].values[0]:3f})") 220 | 221 | selection_iterator = tqdm.tqdm(range(len(selected))) 222 | for idx in selection_iterator: 223 | selection_iterator.set_description( 224 | f"{args.metric} = {selected.iloc[idx][args.metric]:.4f}") 225 | 226 | selected_id = selected.iloc[idx]["guid"] 227 | if args.task_name in ["SNLI", "MNLI"]: 228 | selected_id = int(selected_id) 229 | elif args.task_name == "WINOGRANDE": 230 | selected_id = str(int(selected_id)) 231 | record = train_numeric[selected_id] 232 | outfile.write(record + "\n") 233 | 234 | logger.info(f"Wrote {num_samples} samples to {outdir}.") 235 | 236 | 237 | def plot_data_map(dataframe: pd.DataFrame, 238 | plot_dir: os.path, 239 | hue_metric: str = 'correct.', 240 | title: str = '', 241 | model: str = 'RoBERTa', 242 | show_hist: bool = False, 243 | max_instances_to_plot = 55000): 244 | # Set style. 
245 | sns.set(style='whitegrid', font_scale=1.6, font='Georgia', context='paper') 246 | logger.info(f"Plotting figure for {title} using the {model} model ...") 247 | 248 | # Subsample data to plot, so the plot is not too busy. 249 | dataframe = dataframe.sample(n=max_instances_to_plot if dataframe.shape[0] > max_instances_to_plot else len(dataframe)) 250 | 251 | # Normalize correctness to a value between 0 and 1. 252 | dataframe = dataframe.assign(corr_frac = lambda d: d.correctness / d.correctness.max()) 253 | dataframe['correct.'] = [f"{x:.1f}" for x in dataframe['corr_frac']] 254 | 255 | main_metric = 'variability' 256 | other_metric = 'confidence' 257 | 258 | hue = hue_metric 259 | num_hues = len(dataframe[hue].unique().tolist()) 260 | style = hue_metric if num_hues < 8 else None 261 | 262 | if not show_hist: 263 | fig, ax0 = plt.subplots(1, 1, figsize=(8, 6)) 264 | else: 265 | fig = plt.figure(figsize=(14, 10), ) 266 | gs = fig.add_gridspec(3, 2, width_ratios=[5, 1]) # 构造一个三行两列的图,两列的宽度比例是 5:1 267 | ax0 = fig.add_subplot(gs[:, 0]) # ax0是一个子图,位置在整个fig的第一列区域 268 | 269 | # Make the scatterplot. 270 | # Choose a palette. 271 | pal = sns.diverging_palette(260, 15, n=num_hues, sep=10, center="dark") 272 | 273 | plot = sns.scatterplot(x=main_metric, 274 | y=other_metric, 275 | ax=ax0, 276 | data=dataframe, 277 | hue=hue, 278 | palette=pal, 279 | style=style, 280 | s=30) 281 | 282 | # Annotate Regions. 283 | bb = lambda c: dict(boxstyle="round,pad=0.3", ec=c, lw=2, fc="white") 284 | func_annotate = lambda text, xyc, bbc : ax0.annotate(text, 285 | xy=xyc, 286 | xycoords="axes fraction", 287 | fontsize=15, 288 | color='black', 289 | va="center", 290 | ha="center", 291 | rotation=350, 292 | bbox=bb(bbc)) 293 | an1 = func_annotate("ambiguous", xyc=(0.9, 0.5), bbc='black') 294 | an2 = func_annotate("easy-to-learn", xyc=(0.27, 0.85), bbc='r') 295 | an3 = func_annotate("hard-to-learn", xyc=(0.35, 0.25), bbc='b') 296 | 297 | 298 | if not show_hist: 299 | plot.legend(ncol=1, bbox_to_anchor=[0.175, 0.5], loc='right') 300 | else: 301 | plot.legend(fancybox=True, shadow=True, ncol=1) 302 | plot.set_xlabel('variability') 303 | plot.set_ylabel('confidence') 304 | 305 | if show_hist: 306 | plot.set_title(f"{title}-{model} Data Map", fontsize=17) 307 | 308 | # Make the histograms. 
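# Three marginal panels are stacked in the narrow right-hand column of the grid: confidence and variability histograms, plus a correctness count plot.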
309 | ax1 = fig.add_subplot(gs[0, 1]) 310 | ax2 = fig.add_subplot(gs[1, 1]) 311 | ax3 = fig.add_subplot(gs[2, 1]) 312 | 313 | plott0 = dataframe.hist(column=['confidence'], ax=ax1, color='#622a87') 314 | plott0[0].set_title('') 315 | plott0[0].set_xlabel('confidence') 316 | plott0[0].set_ylabel('density') 317 | 318 | plott1 = dataframe.hist(column=['variability'], ax=ax2, color='teal') 319 | plott1[0].set_title('') 320 | plott1[0].set_xlabel('variability') 321 | plott1[0].set_ylabel('density') 322 | 323 | plot2 = sns.countplot(x="correct.", data=dataframe, ax=ax3, color='#86bf91') 324 | ax3.xaxis.grid(True) # Show the vertical gridlines 325 | 326 | plot2.set_title('') 327 | plot2.set_xlabel('correctness') 328 | plot2.set_ylabel('density') 329 | 330 | fig.tight_layout() 331 | filename = f'{plot_dir}/{title}_{model}.pdf' if show_hist else f'figures/compact_{title}_{model}.pdf' 332 | fig.savefig(filename, dpi=300) 333 | logger.info(f"Plot saved to {filename}") 334 | 335 | 336 | if __name__ == "__main__": 337 | parser = argparse.ArgumentParser() 338 | parser.add_argument("--filter", 339 | action="store_true", 340 | help="Whether to filter data subsets based on specified `metric`.") 341 | parser.add_argument("--plot", 342 | action="store_true", 343 | help="Whether to plot data maps and save as `pdf`.") 344 | parser.add_argument("--model_dir", 345 | "-o", 346 | required=True, 347 | type=os.path.abspath, 348 | help="Directory where model training dynamics stats reside.") 349 | parser.add_argument("--data_dir", 350 | "-d", 351 | default="/Users/swabhas/data/glue/WINOGRANDE/xl/", 352 | type=os.path.abspath, 353 | help="Directory where data for task resides.") 354 | parser.add_argument("--plots_dir", 355 | default="./cartography/", 356 | type=os.path.abspath, 357 | help="Directory where plots are to be saved.") 358 | parser.add_argument("--task_name", 359 | "-t", 360 | # default="WINOGRANDE", 361 | # choices=("SNLI", "MNLI", "QNLI", "WINOGRANDE"), 362 | help="Which task are we plotting or filtering for.") 363 | parser.add_argument('--metric', 364 | choices=('threshold_closeness', 365 | 'confidence', 366 | 'variability', 367 | 'correctness', 368 | 'forgetfulness'), 369 | help="Metric to filter data by.",) 370 | parser.add_argument("--include_ci", 371 | action="store_true", 372 | help="Compute the confidence interval for variability.") 373 | parser.add_argument("--filtering_output_dir", 374 | "-f", 375 | default="./filtered/", 376 | type=os.path.abspath, 377 | help="Output directory where filtered datasets are to be written.") 378 | parser.add_argument("--worst", 379 | action="store_true", 380 | help="Select from the opposite end of the spectrum acc. to metric," 381 | "for baselines") 382 | parser.add_argument("--both_ends", 383 | action="store_true", 384 | help="Select from both ends of the spectrum acc. 
to metric,") 385 | parser.add_argument("--burn_out", 386 | type=int, 387 | default=100, 388 | help="# Epochs for which to compute train dynamics.") 389 | parser.add_argument("--model", 390 | default="RoBERTa", 391 | help="Model for which data map is being plotted") 392 | 393 | args = parser.parse_args() 394 | 395 | training_dynamics = read_training_dynamics(args.model_dir, 396 | strip_last=True if args.task_name in ["QNLI"] else False, 397 | burn_out=args.burn_out if args.burn_out < 100 else None) 398 | total_epochs = len(list(training_dynamics.values())[0]["logits"]) 399 | if args.burn_out > total_epochs: 400 | args.burn_out = total_epochs 401 | logger.info(f"Total epochs found: {args.burn_out}") 402 | train_dy_metrics, _ = compute_train_dy_metrics(training_dynamics, args.include_ci, args.burn_out) 403 | 404 | burn_out_str = f"_{args.burn_out}" if args.burn_out > total_epochs else "" 405 | train_dy_filename = os.path.join(args.model_dir, f"td_metrics{burn_out_str}.jsonl") 406 | train_dy_metrics.to_json(train_dy_filename, 407 | orient='records', 408 | lines=True) 409 | logger.info(f"Metrics based on Training Dynamics written to {train_dy_filename}") 410 | 411 | if args.filter: 412 | assert args.filtering_output_dir 413 | if not os.path.exists(args.filtering_output_dir): 414 | os.makedirs(args.filtering_output_dir) 415 | assert args.metric 416 | write_filtered_data(args, train_dy_metrics) 417 | 418 | if args.plot: 419 | assert args.plots_dir 420 | if not os.path.exists(args.plots_dir): 421 | os.makedirs(args.plots_dir) 422 | plot_data_map(train_dy_metrics, args.plots_dir, title=args.task_name, show_hist=True, model=args.model) 423 | -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ Finetuning a 🤗 Transformers model for sequence classification on GLUE.""" 16 | import argparse 17 | from collections import defaultdict 18 | import json 19 | import logging 20 | import math 21 | import os 22 | import random 23 | from pathlib import Path 24 | 25 | import datasets 26 | import torch 27 | from datasets import load_dataset, load_metric 28 | from torch.utils.data import DataLoader 29 | from tqdm.auto import tqdm 30 | 31 | import transformers 32 | from accelerate import Accelerator 33 | from accelerate.logging import get_logger 34 | from accelerate.utils import set_seed 35 | from huggingface_hub import Repository 36 | from transformers import ( 37 | AutoConfig, 38 | AutoModelForSequenceClassification, 39 | AutoTokenizer, 40 | DataCollatorWithPadding, 41 | PretrainedConfig, 42 | SchedulerType, 43 | default_data_collator, 44 | get_scheduler, 45 | ) 46 | # from transformers.utils import get_full_repo_name, send_example_telemetry 47 | from transformers.utils.versions import require_version 48 | 49 | 50 | logger = get_logger(__name__) 51 | 52 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 53 | 54 | task_to_keys = { 55 | "cola": ("sentence", None), 56 | "mnli": ("premise", "hypothesis"), 57 | "mrpc": ("sentence1", "sentence2"), 58 | "qnli": ("question", "sentence"), 59 | "qqp": ("question1", "question2"), 60 | "rte": ("sentence1", "sentence2"), 61 | "sst2": ("sentence", None), 62 | "stsb": ("sentence1", "sentence2"), 63 | "wnli": ("sentence1", "sentence2"), 64 | } 65 | 66 | 67 | def parse_args(): 68 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") 69 | parser.add_argument( 70 | "--task_name", 71 | type=str, 72 | default=None, 73 | help="The name of the glue task to train on.", 74 | choices=list(task_to_keys.keys()), 75 | ) 76 | parser.add_argument( 77 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data." 78 | ) 79 | parser.add_argument( 80 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 81 | ) 82 | parser.add_argument( 83 | "--max_length", 84 | type=int, 85 | default=128, 86 | help=( 87 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 88 | " sequences shorter will be padded if `--pad_to_max_lengh` is passed." 89 | ), 90 | ) 91 | parser.add_argument( 92 | "--pad_to_max_length", 93 | action="store_true", 94 | help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", 95 | ) 96 | parser.add_argument( 97 | "--model_name_or_path", 98 | type=str, 99 | help="Path to pretrained model or model identifier from huggingface.co/models.", 100 | required=True, 101 | ) 102 | parser.add_argument( 103 | "--use_slow_tokenizer", 104 | action="store_true", 105 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 106 | ) 107 | parser.add_argument( 108 | "--per_device_train_batch_size", 109 | type=int, 110 | default=8, 111 | help="Batch size (per device) for the training dataloader.", 112 | ) 113 | parser.add_argument( 114 | "--per_device_eval_batch_size", 115 | type=int, 116 | default=8, 117 | help="Batch size (per device) for the evaluation dataloader.", 118 | ) 119 | parser.add_argument( 120 | "--learning_rate", 121 | type=float, 122 | default=5e-5, 123 | help="Initial learning rate (after the potential warmup period) to use.", 124 | ) 125 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 126 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 127 | parser.add_argument( 128 | "--max_train_steps", 129 | type=int, 130 | default=None, 131 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 132 | ) 133 | parser.add_argument( 134 | "--gradient_accumulation_steps", 135 | type=int, 136 | default=1, 137 | help="Number of updates steps to accumulate before performing a backward/update pass.", 138 | ) 139 | parser.add_argument( 140 | "--lr_scheduler_type", 141 | type=SchedulerType, 142 | default="linear", 143 | help="The scheduler type to use.", 144 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 145 | ) 146 | parser.add_argument( 147 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 148 | ) 149 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 150 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 151 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 152 | parser.add_argument( 153 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." 154 | ) 155 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") 156 | parser.add_argument( 157 | "--checkpointing_steps", 158 | type=str, 159 | default=None, 160 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 161 | ) 162 | parser.add_argument( 163 | "--resume_from_checkpoint", 164 | type=str, 165 | default=None, 166 | help="If the training should continue from a checkpoint folder.", 167 | ) 168 | parser.add_argument( 169 | "--with_tracking", 170 | action="store_true", 171 | help="Whether to enable experiment trackers for logging.", 172 | ) 173 | parser.add_argument( 174 | "--report_to", 175 | type=str, 176 | default="all", 177 | help=( 178 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 179 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' 180 | "Only applicable when `--with_tracking` is passed." 
181 | ), 182 | ) 183 | parser.add_argument( 184 | "--ignore_mismatched_sizes", 185 | action="store_true", 186 | help="Whether or not to enable to load a pretrained model whose head dimensions are different.", 187 | ) 188 | args = parser.parse_args() 189 | 190 | # Sanity checks 191 | if args.task_name is None and args.train_file is None and args.validation_file is None: 192 | raise ValueError("Need either a task name or a training/validation file.") 193 | else: 194 | if args.train_file is not None: 195 | extension = args.train_file.split(".")[-1] 196 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 197 | if args.validation_file is not None: 198 | extension = args.validation_file.split(".")[-1] 199 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 200 | 201 | # if args.push_to_hub: 202 | # assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." 203 | 204 | return args 205 | 206 | 207 | def main(): 208 | args = parse_args() 209 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 210 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 211 | # send_example_telemetry("run_glue_no_trainer", args) 212 | 213 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 214 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers 215 | # in the environment 216 | accelerator = ( 217 | Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator() 218 | ) 219 | # Make one log on every process with the configuration for debugging. 220 | logging.basicConfig( 221 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 222 | datefmt="%m/%d/%Y %H:%M:%S", 223 | level=logging.INFO, 224 | ) 225 | logger.info(accelerator.state, main_process_only=False) 226 | if accelerator.is_local_main_process: 227 | datasets.utils.logging.set_verbosity_warning() 228 | transformers.utils.logging.set_verbosity_info() 229 | else: 230 | datasets.utils.logging.set_verbosity_error() 231 | transformers.utils.logging.set_verbosity_error() 232 | 233 | # If passed along, set the training seed now. 234 | if args.seed is not None: 235 | set_seed(args.seed) 236 | 237 | # # Handle the repository creation 238 | # if accelerator.is_main_process: 239 | # if args.push_to_hub: 240 | # if args.hub_model_id is None: 241 | # repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 242 | # else: 243 | # repo_name = args.hub_model_id 244 | # repo = Repository(args.output_dir, clone_from=repo_name) 245 | 246 | # with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 247 | # if "step_*" not in gitignore: 248 | # gitignore.write("step_*\n") 249 | # if "epoch_*" not in gitignore: 250 | # gitignore.write("epoch_*\n") 251 | # elif args.output_dir is not None: 252 | # os.makedirs(args.output_dir, exist_ok=True) 253 | # accelerator.wait_for_everyone() 254 | 255 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 256 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 
257 | 258 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 259 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 260 | # label if at least two columns are provided. 261 | 262 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 263 | # single column. You can easily tweak this behavior (see below) 264 | 265 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 266 | # download the dataset. 267 | if args.task_name is not None: 268 | # Downloading and loading a dataset from the hub. 269 | raw_datasets = load_dataset("glue", args.task_name) 270 | else: 271 | # Loading the dataset from local csv or json file. 272 | data_files = {} 273 | if args.train_file is not None: 274 | data_files["train"] = args.train_file 275 | if args.validation_file is not None: 276 | data_files["validation"] = args.validation_file 277 | extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1] 278 | raw_datasets = load_dataset(extension, data_files=data_files) 279 | # See more about loading any type of standard or custom dataset at 280 | # https://huggingface.co/docs/datasets/loading_datasets.html. 281 | 282 | # Labels 283 | if args.task_name is not None: 284 | is_regression = args.task_name == "stsb" 285 | if not is_regression: 286 | label_list = raw_datasets["train"].features["label"].names 287 | num_labels = len(label_list) 288 | else: 289 | num_labels = 1 290 | else: 291 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 292 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 293 | if is_regression: 294 | num_labels = 1 295 | else: 296 | # A useful fast method: 297 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 298 | label_list = raw_datasets["train"].unique("label") 299 | label_list.sort() # Let's sort it for determinism 300 | num_labels = len(label_list) 301 | 302 | # Load pretrained model and tokenizer 303 | # 304 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 305 | # download model & vocab. 306 | config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name) 307 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) 308 | model = AutoModelForSequenceClassification.from_pretrained( 309 | args.model_name_or_path, 310 | from_tf=bool(".ckpt" in args.model_name_or_path), 311 | config=config, 312 | ignore_mismatched_sizes=args.ignore_mismatched_sizes, 313 | ) 314 | # 对非HF官方模型的名称的处理,只保留模型名 315 | if '/' in args.model_name_or_path: 316 | args.model_name_or_path = args.model_name_or_path.split('/')[-1] 317 | 318 | # Preprocessing the datasets 319 | # --------------- GLUE tasks --------------- 320 | if args.task_name is not None: 321 | sentence1_key, sentence2_key = task_to_keys[args.task_name] 322 | # --------------- Other tasks --------------- 323 | else: 324 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 
325 | # The logic here is as follows: 326 | # For non-GLUE datasets, the data must contain a `label` column. 327 | # Ideally it also has `sentence1` and `sentence2` columns, so that it lines up with the GLUE format. 328 | # If other column names are used, the first two non-label columns are taken as sentence1 and sentence2. 329 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 330 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 331 | sentence1_key, sentence2_key = "sentence1", "sentence2" 332 | else: 333 | if len(non_label_column_names) >= 2: 334 | sentence1_key, sentence2_key = non_label_column_names[:2] 335 | else: 336 | sentence1_key, sentence2_key = non_label_column_names[0], None 337 | 338 | # Some models have set the order of the labels to use, so let's make sure we do use it. 339 | label_to_id = None 340 | if ( 341 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 342 | and args.task_name is not None 343 | and not is_regression 344 | ): 345 | # Some have all caps in their config, some don't. 346 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 347 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 348 | logger.info( 349 | f"The configuration of the model provided the following label correspondence: {label_name_to_id}. " 350 | "Using it!" 351 | ) 352 | label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)} 353 | else: 354 | logger.warning( 355 | "Your model seems to have been trained with labels, but they don't match the dataset: " 356 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 357 | "\nIgnoring the model labels as a result.", 358 | ) 359 | elif args.task_name is None and not is_regression: 360 | label_to_id = {v: i for i, v in enumerate(label_list)} 361 | 362 | if label_to_id is not None: 363 | model.config.label2id = label_to_id 364 | model.config.id2label = {id: label for label, id in config.label2id.items()} 365 | elif args.task_name is not None and not is_regression: 366 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 367 | model.config.id2label = {id: label for label, id in config.label2id.items()} 368 | 369 | padding = "max_length" if args.pad_to_max_length else False 370 | 371 | def preprocess_function(examples): 372 | # Tokenize the texts 373 | texts = ( 374 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 375 | ) 376 | result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True) 377 | 378 | if "label" in examples: 379 | if label_to_id is not None: 380 | # Map labels to IDs (not necessary for GLUE tasks) 381 | result["labels"] = [label_to_id[l] for l in examples["label"]] 382 | else: 383 | # In all cases, rename the column to labels because the model will expect that. 
384 | result["labels"] = examples["label"] 385 | return result 386 | 387 | with accelerator.main_process_first(): 388 | processed_datasets = raw_datasets.map( 389 | preprocess_function, 390 | batched=True, 391 | # 得把这行改掉: 392 | # 以SST2为例,这里会把 ['sentence', 'label', 'idx'] 给去掉(不用担心label,因为上面已经新建了一个labels列) 393 | # remove_columns=raw_datasets["train"].column_names, 394 | # 改为: 395 | remove_columns=[c for c in raw_datasets["train"].column_names if c != 'idx'], # 保留idx,其他的可以去掉 396 | desc="Running tokenizer on dataset", 397 | ) 398 | 399 | train_dataset = processed_datasets["train"] 400 | eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] 401 | 402 | # Log a few random samples from the training set: 403 | for index in random.sample(range(len(train_dataset)), 3): 404 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 405 | 406 | # DataLoaders creation: 407 | if args.pad_to_max_length: 408 | # If padding was already done ot max length, we use the default data collator that will just convert everything 409 | # to tensors. 410 | data_collator = default_data_collator 411 | else: 412 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of 413 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple 414 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). 415 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) 416 | 417 | train_dataloader = DataLoader( 418 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 419 | ) 420 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) 421 | 422 | # Optimizer 423 | # Split weights in two groups, one with weight decay and the other not. 424 | no_decay = ["bias", "LayerNorm.weight"] 425 | optimizer_grouped_parameters = [ 426 | { 427 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 428 | "weight_decay": args.weight_decay, 429 | }, 430 | { 431 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 432 | "weight_decay": 0.0, 433 | }, 434 | ] 435 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 436 | 437 | # Scheduler and math around the number of training steps. 438 | overrode_max_train_steps = False 439 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 440 | if args.max_train_steps is None: 441 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 442 | overrode_max_train_steps = True 443 | 444 | lr_scheduler = get_scheduler( 445 | name=args.lr_scheduler_type, 446 | optimizer=optimizer, 447 | num_warmup_steps=args.num_warmup_steps, 448 | num_training_steps=args.max_train_steps, 449 | ) 450 | 451 | # Prepare everything with our `accelerator`. 
452 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 453 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 454 | ) 455 | 456 | # We need to recalculate our total training steps as the size of the training dataloader may have changed 457 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 458 | if overrode_max_train_steps: 459 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 460 | # Afterwards we recalculate our number of training epochs 461 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 462 | 463 | # Figure out how many steps we should save the Accelerator states 464 | if hasattr(args.checkpointing_steps, "isdigit"): 465 | checkpointing_steps = args.checkpointing_steps 466 | if args.checkpointing_steps.isdigit(): 467 | checkpointing_steps = int(args.checkpointing_steps) 468 | else: 469 | checkpointing_steps = None 470 | 471 | # We need to initialize the trackers we use, and also store our configuration. 472 | # We initialize the trackers only on main process because `accelerator.log` 473 | # only logs on main process and we don't want empty logs/runs on other processes. 474 | if args.with_tracking: 475 | if accelerator.is_main_process: 476 | experiment_config = vars(args) 477 | # TensorBoard cannot log Enums, need the raw value 478 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 479 | accelerator.init_trackers("glue_no_trainer", experiment_config) 480 | 481 | # Get the metric function 482 | if args.task_name is not None: 483 | metric = load_metric("glue", args.task_name) 484 | else: 485 | metric = load_metric("accuracy") 486 | 487 | # Train! 488 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 489 | 490 | logger.info("***** Running training *****") 491 | logger.info(f" Num examples = {len(train_dataset)}") 492 | logger.info(f" Num Epochs = {args.num_train_epochs}") 493 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 494 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 495 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 496 | logger.info(f" Total optimization steps = {args.max_train_steps}") 497 | # Only show the progress bar once on each machine. 
498 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 499 | completed_steps = 0 500 | starting_epoch = 0 501 | # Potentially load in the weights and states from a previous save 502 | if args.resume_from_checkpoint: 503 | if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "": 504 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 505 | accelerator.load_state(args.resume_from_checkpoint) 506 | path = os.path.basename(args.resume_from_checkpoint) 507 | else: 508 | # Get the most recent checkpoint 509 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 510 | dirs.sort(key=os.path.getctime) 511 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 512 | # Extract `epoch_{i}` or `step_{i}` 513 | training_difference = os.path.splitext(path)[0] 514 | 515 | if "epoch" in training_difference: 516 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 517 | resume_step = None 518 | else: 519 | resume_step = int(training_difference.replace("step_", "")) 520 | starting_epoch = resume_step // len(train_dataloader) 521 | resume_step -= starting_epoch * len(train_dataloader) 522 | 523 | # ============================ Training Loop ============================ 524 | for epoch in range(starting_epoch, args.num_train_epochs): 525 | if accelerator.is_main_process: 526 | if not os.path.exists(f'dy_log/{args.task_name}/'): 527 | os.mkdir(f'dy_log/{args.task_name}/') 528 | if not os.path.exists(f'dy_log/{args.task_name}/{args.model_name_or_path}'): 529 | os.mkdir(f'dy_log/{args.task_name}/{args.model_name_or_path}') 530 | log_path = f'dy_log/{args.task_name}/{args.model_name_or_path}/training_dynamics/' 531 | if not os.path.exists(log_path): 532 | os.mkdir(log_path) 533 | 534 | accelerator.wait_for_everyone() # Only the main process creates the folders; the other processes wait here until it has finished 535 | log_path = f'dy_log/{args.task_name}/{args.model_name_or_path}/training_dynamics/' 536 | print('-*-*-*- ',log_path, os.path.exists(log_path),accelerator.device) 537 | 538 | model.train() 539 | if args.with_tracking: 540 | total_loss = 0 541 | for step, batch in enumerate(train_dataloader): 542 | # We need to skip steps until we reach the resumed step 543 | if args.resume_from_checkpoint and epoch == starting_epoch: 544 | if resume_step is not None and step < resume_step: 545 | completed_steps += 1 546 | continue 547 | # The batch still contains the 'idx' field, which has to be dropped before the forward pass 548 | batch = {k:v for k,v in batch.items() if k != 'idx'} 549 | outputs = model(**batch) 550 | loss = outputs.loss 551 | # We keep track of the loss at each epoch 552 | if args.with_tracking: 553 | total_loss += loss.detach().float() 554 | loss = loss / args.gradient_accumulation_steps 555 | accelerator.backward(loss) 556 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 557 | optimizer.step() 558 | lr_scheduler.step() 559 | optimizer.zero_grad() 560 | progress_bar.update(1) 561 | completed_steps += 1 562 | 563 | if isinstance(checkpointing_steps, int): 564 | if completed_steps % checkpointing_steps == 0: 565 | output_dir = f"step_{completed_steps }" 566 | if args.output_dir is not None: 567 | output_dir = os.path.join(args.output_dir, output_dir) 568 | accelerator.save_state(output_dir) 569 | 570 | if completed_steps >= args.max_train_steps: 571 | break 572 | # ------------------ Recording Training Dynamics -------------------- 573 | # After each epoch, run one more pass over the entire train set and record the training dynamics. 574 | # Each epoch is written to its own separate file.
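        # Illustrative note: each jsonl record written below has the shape
        #   {"guid": 1154, "logits_epoch_0": [0.31, -0.12], "gold": 1, "device": "cuda:0"}   (values made up)
        # dy_filtering.py later aggregates these per-epoch logits into the usual data-map statistics, roughly:
        #   p_e         = softmax(logits_epoch_e)[gold]      # gold-label probability at epoch e
        #   confidence  = mean over epochs of p_e            # high -> easy-to-learn
        #   variability = std  over epochs of p_e            # high -> ambiguous
        #   correctness = number of epochs with argmax(logits_epoch_e) == gold
        # e.g. p = [0.80, 0.85, 0.90] over 3 epochs gives confidence 0.85, variability ~0.04, correctness 3.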
575 | logger.info('---------- Recording Training Dynamics (Epoch %s) -----------'%epoch) 576 | training_dynamics = [] 577 | all_ids = [] 578 | for step, batch in enumerate(tqdm(train_dataloader)): 579 | # print('- - - - - - - - - - ',len(batch['idx']), accelerator.device) 580 | idx_list = batch['idx']#.tolist() 581 | label_list = batch['labels']#.tolist() 582 | batch = {k:v for k,v in batch.items() if k != 'idx'} 583 | logits_list = model(**batch).logits#.tolist() # [[],[],[],...] one list of logits per sample, batch_size in total 584 | # The key step: merge the per-GPU results with gather. 585 | # With multi-GPU training, samples may be duplicated across devices, and each device pads its last batch, which duplicates samples as well. 586 | # gather puts the results back together following the original data split, so there is no need to undo the sharding by hand. 587 | # After gather, the variables below contain, on every device, the same number of samples as a single-GPU run would. 588 | # So once the for loop below has finished, training_dynamics holds all samples; just remember to write the file from one process only. 589 | idx_list, label_list, logits_list = accelerator.gather((idx_list, label_list, logits_list)) 590 | # print('idx_list', idx_list.shape, accelerator.device) 591 | # print('label_list', label_list.shape, accelerator.device) 592 | 593 | for idx, label, logits in zip(idx_list.tolist(), label_list.tolist(), logits_list.tolist()): 594 | if idx in all_ids: # The dataloader may pad the last batch, so duplicated samples are skipped here 595 | continue 596 | all_ids.append(idx) 597 | record = {'guid': idx, 'logits_epoch_%s'%epoch: logits, 'gold': label, 'device':str(accelerator.device)} 598 | training_dynamics.append(record) 599 | print(len(all_ids),len(list(set(all_ids))),str(accelerator.device)) 600 | 601 | print('---- Num of training_dynamics: ',len(training_dynamics),' Device: ', str(accelerator.device)) 602 | if accelerator.is_main_process: 603 | assert os.path.exists(log_path),log_path 604 | writer = open(log_path + f'dynamics_epoch_{epoch}.jsonl', 'w') 605 | for record in training_dynamics: 606 | writer.write(json.dumps(record) + "\n") 607 | logger.info(f'Epoch {epoch} Saved to [{log_path}]') 608 | writer.close() 609 | accelerator.wait_for_everyone() 610 | 611 | # ------------------------------------------------------------------------ 612 | 613 | 614 | 615 | # evaluation (validation set) 616 | model.eval() 617 | samples_seen = 0 618 | for step, batch in enumerate(eval_dataloader): 619 | batch = {k:v for k,v in batch.items() if k != 'idx'} # do we need to move this to the device manually? 
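            # Note: no manual .to(device) call should be needed here; eval_dataloader was prepared with
            # accelerator.prepare, so the batches it yields are already placed on the right device.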
620 | with torch.no_grad(): 621 | outputs = model(**batch) 622 | predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze() 623 | predictions, references = accelerator.gather((predictions, batch["labels"])) 624 | # If we are in a multiprocess environment, the last batch has duplicates 625 | if accelerator.num_processes > 1: 626 | if step == len(eval_dataloader) - 1: 627 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] 628 | references = references[: len(eval_dataloader.dataset) - samples_seen] 629 | else: 630 | samples_seen += references.shape[0] 631 | metric.add_batch( 632 | predictions=predictions, 633 | references=references, 634 | ) 635 | 636 | eval_metric = metric.compute() 637 | logger.info(f"epoch {epoch}: {eval_metric}") 638 | 639 | if args.with_tracking: 640 | accelerator.log( 641 | { 642 | "accuracy" if args.task_name is not None else "glue": eval_metric, 643 | "train_loss": total_loss.item() / len(train_dataloader), 644 | "epoch": epoch, 645 | "step": completed_steps, 646 | }, 647 | step=completed_steps, 648 | ) 649 | 650 | # if args.push_to_hub and epoch < args.num_train_epochs - 1: 651 | # accelerator.wait_for_everyone() 652 | # unwrapped_model = accelerator.unwrap_model(model) 653 | # unwrapped_model.save_pretrained( 654 | # args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 655 | # ) 656 | # if accelerator.is_main_process: 657 | # tokenizer.save_pretrained(args.output_dir) 658 | # repo.push_to_hub( 659 | # commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True 660 | # ) 661 | 662 | if args.checkpointing_steps == "epoch": 663 | output_dir = f"epoch_{epoch}" 664 | if args.output_dir is not None: 665 | output_dir = os.path.join(args.output_dir, output_dir) 666 | accelerator.save_state(output_dir) 667 | # ============================ End Training Loop ============================ 668 | 669 | 670 | if args.output_dir is not None: 671 | accelerator.wait_for_everyone() 672 | unwrapped_model = accelerator.unwrap_model(model) 673 | unwrapped_model.save_pretrained( 674 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 675 | ) 676 | if accelerator.is_main_process: 677 | tokenizer.save_pretrained(args.output_dir) 678 | # if args.push_to_hub: 679 | # repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) 680 | 681 | if args.task_name == "mnli": 682 | # Final evaluation on mismatched validation set 683 | eval_dataset = processed_datasets["validation_mismatched"] 684 | eval_dataloader = DataLoader( 685 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size 686 | ) 687 | eval_dataloader = accelerator.prepare(eval_dataloader) 688 | 689 | model.eval() 690 | for step, batch in enumerate(eval_dataloader): 691 | outputs = model(**batch) 692 | predictions = outputs.logits.argmax(dim=-1) 693 | metric.add_batch( 694 | predictions=accelerator.gather(predictions), 695 | references=accelerator.gather(batch["labels"]), 696 | ) 697 | 698 | eval_metric = metric.compute() 699 | logger.info(f"mnli-mm: {eval_metric}") 700 | 701 | if args.output_dir is not None: 702 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: 703 | json.dump({"eval_accuracy": eval_metric["accuracy"]}, f) 704 | 705 | 706 | if __name__ == "__main__": 707 | main() -------------------------------------------------------------------------------- /hardness_classification_prepare.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# hardness classification\n", 8 | "- use 'data_selection.py' set `-proportion 0.5` to split all samples into easy and hard.\n", 9 | "- then read the outputted 'three_regions_data_indices.json' to get their indices" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 5, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import json\n", 19 | "dataset_name = 'mrpc'\n", 20 | "with open(f'dy_log/{dataset_name}/bert-base-cased/three_regions_data_indices.json' ,'r') as f:\n", 21 | " d = json.loads(f.read())" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 6, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "1210 1210\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "print(len(d['hard']), len(d['easy']))\n", 39 | "assert len(set(d['hard']).intersection(set(d['easy']))) == 0" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 7, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stderr", 49 | "output_type": "stream", 50 | "text": [ 51 | "Reusing dataset glue (/home/v-biyangguo/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" 52 | ] 53 | }, 54 | { 55 | "data": { 56 | "application/vnd.jupyter.widget-view+json": { 57 | "model_id": "008e864d865b4e2e8ace44c252f27211", 58 | "version_major": 2, 59 | "version_minor": 0 60 | }, 61 | "text/plain": [ 62 | " 0%| | 0/3 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0measy_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'easy'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mhard_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'hard'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 125 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 522\u001b[0m }\n\u001b[1;32m 523\u001b[0m \u001b[0;31m# apply actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 524\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"DatasetDict\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 525\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;31m# re-apply format to the output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 126 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/fingerprint.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[0;31m# Call actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 480\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 481\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[0;31m# Update fingerprint of in-place transforms + update in-place history of transforms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 127 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36mselect\u001b[0;34m(self, indices, keep_in_memory, indices_cache_file_name, writer_batch_size, new_fingerprint)\u001b[0m\n\u001b[1;32m 3073\u001b[0m \u001b[0mindices_cache_file_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mindices_cache_file_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3074\u001b[0m \u001b[0mwriter_batch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwriter_batch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3075\u001b[0;31m \u001b[0mnew_fingerprint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_fingerprint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3076\u001b[0m )\n\u001b[1;32m 3077\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 128 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 522\u001b[0m }\n\u001b[1;32m 523\u001b[0m \u001b[0;31m# apply actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 524\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"DatasetDict\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 
525\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;31m# re-apply format to the output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 129 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/fingerprint.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[0;31m# Call actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 480\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 481\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[0;31m# Update fingerprint of in-place transforms + update in-place history of transforms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 130 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36m_select_with_indices_mapping\u001b[0;34m(self, indices, keep_in_memory, indices_cache_file_name, writer_batch_size, new_fingerprint)\u001b[0m\n\u001b[1;32m 3199\u001b[0m \u001b[0msize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3200\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3201\u001b[0;31m \u001b[0m_check_valid_indices_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3202\u001b[0m \u001b[0m_check_valid_indices_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3203\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 131 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36m_check_valid_indices_value\u001b[0;34m(index, 
size)\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_check_valid_indices_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msize\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 611\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mIndexError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Index {index} out of range for dataset of size {size}.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 612\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 132 | "\u001b[0;31mIndexError\u001b[0m: Index 4073 out of range for dataset of size 3668." 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "easy_data = data['train'].select(d['easy'])\n", 138 | "hard_data = data['train'].select(d['hard'])" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 5, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "application/vnd.jupyter.widget-view+json": { 149 | "model_id": "a6bf20ded5684a05859ec3f574323c93", 150 | "version_major": 2, 151 | "version_minor": 0 152 | }, 153 | "text/plain": [ 154 | "Flattening the indices: 0%| | 0/53 [00:00\n", 541 | "\n", 554 | "\n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | "
confidencecorrectnessforgetfulnessguidindexthreshold_closenessvariability
00.72953950115400.1973120.188230
10.71976150157410.2017050.147285
20.7246134169820.1995490.205787
30.7643565037730.1801160.162303
40.80028550243740.1598290.177974
\n", 620 | "" 621 | ], 622 | "text/plain": [ 623 | " confidence correctness forgetfulness guid index threshold_closeness \\\n", 624 | "0 0.729539 5 0 1154 0 0.197312 \n", 625 | "1 0.719761 5 0 1574 1 0.201705 \n", 626 | "2 0.724613 4 1 698 2 0.199549 \n", 627 | "3 0.764356 5 0 377 3 0.180116 \n", 628 | "4 0.800285 5 0 2437 4 0.159829 \n", 629 | "\n", 630 | " variability \n", 631 | "0 0.188230 \n", 632 | "1 0.147285 \n", 633 | "2 0.205787 \n", 634 | "3 0.162303 \n", 635 | "4 0.177974 " 636 | ] 637 | }, 638 | "execution_count": 23, 639 | "metadata": {}, 640 | "output_type": "execute_result" 641 | } 642 | ], 643 | "source": [ 644 | "import pandas as pd\n", 645 | "td_df = pd.read_json(f'dy_log/{dataset_name}/bert-base-cased/td_metrics.jsonl', lines=True) # lines=True 是因为你加载的是jsonl文件,每行都是一个dictionary\n", 646 | "print(len(td_df))\n", 647 | "td_df.head()" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 24, 653 | "metadata": {}, 654 | "outputs": [ 655 | { 656 | "data": { 657 | "text/plain": [ 658 | "2490" 659 | ] 660 | }, 661 | "execution_count": 24, 662 | "metadata": {}, 663 | "output_type": "execute_result" 664 | } 665 | ], 666 | "source": [ 667 | "id2conf = {}\n", 668 | "for guid, conf in zip(list(td_df['guid']), list(td_df['confidence'])):\n", 669 | " id2conf[guid] = conf\n", 670 | "len(id2conf)" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": 25, 676 | "metadata": {}, 677 | "outputs": [ 678 | { 679 | "data": { 680 | "text/plain": [ 681 | "DatasetDict({\n", 682 | " train: Dataset({\n", 683 | " features: ['sentence1', 'sentence2', 'idx', 'label', 'confidence'],\n", 684 | " num_rows: 2490\n", 685 | " })\n", 686 | " validation: Dataset({\n", 687 | " features: ['sentence1', 'sentence2', 'label', 'idx'],\n", 688 | " num_rows: 277\n", 689 | " })\n", 690 | " test: Dataset({\n", 691 | " features: ['sentence1', 'sentence2', 'label', 'idx'],\n", 692 | " num_rows: 3000\n", 693 | " })\n", 694 | "})" 695 | ] 696 | }, 697 | "execution_count": 25, 698 | "metadata": {}, 699 | "output_type": "execute_result" 700 | } 701 | ], 702 | "source": [ 703 | "train_ids = data['train']['idx']\n", 704 | "train_conf = [id2conf[id] for id in train_ids]\n", 705 | "data['train'] = data['train'].add_column('confidence', train_conf)\n", 706 | "\n", 707 | "data" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 26, 713 | "metadata": {}, 714 | "outputs": [ 715 | { 716 | "data": { 717 | "text/plain": [ 718 | "[0.6924006224,\n", 719 | " 0.5262133539,\n", 720 | " 0.7324662924,\n", 721 | " 0.7629377484000001,\n", 722 | " 0.5932154834000001,\n", 723 | " 0.6104391217,\n", 724 | " 0.4496941358,\n", 725 | " 0.6968309402,\n", 726 | " 0.6517617762,\n", 727 | " 0.7604392648]" 728 | ] 729 | }, 730 | "execution_count": 26, 731 | "metadata": {}, 732 | "output_type": "execute_result" 733 | } 734 | ], 735 | "source": [ 736 | "data['train']['confidence'][:10]" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": 27, 742 | "metadata": {}, 743 | "outputs": [], 744 | "source": [ 745 | "data.save_to_disk(f\"datasets/{dataset_name}/with_conf/\")" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": 7, 751 | "metadata": {}, 752 | "outputs": [ 753 | { 754 | "data": { 755 | "text/plain": [ 756 | "DatasetDict({\n", 757 | " test: Dataset({\n", 758 | " features: ['premise', 'hypothesis', 'label'],\n", 759 | " num_rows: 9824\n", 760 | " })\n", 761 | " train: Dataset({\n", 762 | " features: ['premise', 'hypothesis', 'label', 
'idx', 'confidence'],\n", 763 | " num_rows: 549367\n", 764 | " })\n", 765 | " validation: Dataset({\n", 766 | " features: ['premise', 'hypothesis', 'label'],\n", 767 | " num_rows: 9842\n", 768 | " })\n", 769 | "})" 770 | ] 771 | }, 772 | "execution_count": 7, 773 | "metadata": {}, 774 | "output_type": "execute_result" 775 | } 776 | ], 777 | "source": [ 778 | "# 读取\n", 779 | "from datasets import load_from_disk\n", 780 | "reloaded_data = load_from_disk(f\"datasets/{dataset_name}/with_conf/\")\n", 781 | "reloaded_data\n" 782 | ] 783 | }, 784 | { 785 | "cell_type": "code", 786 | "execution_count": 9, 787 | "metadata": {}, 788 | "outputs": [ 789 | { 790 | "data": { 791 | "text/plain": [ 792 | "[1, 0, 2, 1, 0, 2, 0, 1, 2, 1]" 793 | ] 794 | }, 795 | "execution_count": 9, 796 | "metadata": {}, 797 | "output_type": "execute_result" 798 | } 799 | ], 800 | "source": [ 801 | "reloaded_data['test']['label'][:10]" 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": 10, 807 | "metadata": {}, 808 | "outputs": [ 809 | { 810 | "data": { 811 | "text/plain": [ 812 | "[1, 0, 2, 0, 1, 2, 2, 1, 0, 0]" 813 | ] 814 | }, 815 | "execution_count": 10, 816 | "metadata": {}, 817 | "output_type": "execute_result" 818 | } 819 | ], 820 | "source": [ 821 | "reloaded_data['validation']['label'][:10]" 822 | ] 823 | }, 824 | { 825 | "cell_type": "code", 826 | "execution_count": null, 827 | "metadata": {}, 828 | "outputs": [], 829 | "source": [] 830 | } 831 | ], 832 | "metadata": { 833 | "interpreter": { 834 | "hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f" 835 | }, 836 | "kernelspec": { 837 | "display_name": "Python 3.6.10 64-bit ('base': conda)", 838 | "name": "python3" 839 | }, 840 | "language_info": { 841 | "name": "python", 842 | "version": "" 843 | }, 844 | "orig_nbformat": 4 845 | }, 846 | "nbformat": 4, 847 | "nbformat_minor": 2 848 | } -------------------------------------------------------------------------------- /HCT/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from collections import defaultdict\n", 10 | "import json\n", 11 | "import logging\n", 12 | "import math\n", 13 | "import os\n", 14 | "import random\n", 15 | "from pathlib import Path\n", 16 | "\n", 17 | "import datasets\n", 18 | "import torch\n", 19 | "from torch import nn\n", 20 | "import torch.nn.functional as F\n", 21 | "from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n", 22 | "\n", 23 | "from datasets import load_dataset, load_metric\n", 24 | "from torch.utils.data import DataLoader\n", 25 | "from tqdm.auto import tqdm\n", 26 | "\n", 27 | "import transformers\n", 28 | "from transformers import (\n", 29 | " AutoConfig,\n", 30 | " AutoModel,\n", 31 | " AutoModelForSequenceClassification,\n", 32 | " AutoTokenizer,\n", 33 | " DataCollatorWithPadding,\n", 34 | " PretrainedConfig,\n", 35 | " SchedulerType,\n", 36 | " default_data_collator,\n", 37 | " get_scheduler,\n", 38 | ")" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "Downloading and preparing dataset glue/sst2 (download: 7.09 MiB, generated: 4.81 MiB, post-processed: Unknown size, total: 11.90 MiB) to 
/home/v-biyangguo/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...\n" 51 | ] 52 | }, 53 | { 54 | "data": { 55 | "application/vnd.jupyter.widget-view+json": { 56 | "model_id": "9274b98cf63e4e20b154556afc1bcf91", 57 | "version_major": 2, 58 | "version_minor": 0 59 | }, 60 | "text/plain": [ 61 | "Downloading data: 0%| | 0.00/7.44M [00:00)" 416 | ] 417 | }, 418 | "execution_count": 17, 419 | "metadata": {}, 420 | "output_type": "execute_result" 421 | } 422 | ], 423 | "source": [ 424 | "# model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')\n", 425 | "logits = model(**batch).logits\n", 426 | "logits" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 20, 432 | "metadata": {}, 433 | "outputs": [ 434 | { 435 | "data": { 436 | "text/plain": [ 437 | "tensor(7.4759e-09, grad_fn=)" 438 | ] 439 | }, 440 | "execution_count": 20, 441 | "metadata": {}, 442 | "output_type": "execute_result" 443 | } 444 | ], 445 | "source": [ 446 | "# from torch import nn\n", 447 | "# kld_loss_fct = nn.KLDivLoss(reduction=\"batchmean\")\n", 448 | "\n", 449 | "kld_loss_fct(\n", 450 | " nn.functional.log_softmax(logits / 1, dim=-1),\n", 451 | " nn.functional.softmax(logits / 1, dim=-1),\n", 452 | " ) * (1) ** 2" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 30, 458 | "metadata": {}, 459 | "outputs": [ 460 | { 461 | "data": { 462 | "text/plain": [ 463 | "(torch.Size([8, 2]),\n", 464 | " torch.Size([8, 2]),\n", 465 | " tensor([[0.6097, 0.3903],\n", 466 | " [0.5698, 0.4302],\n", 467 | " [0.5829, 0.4171],\n", 468 | " [0.5860, 0.4140],\n", 469 | " [0.5920, 0.4080],\n", 470 | " [0.5839, 0.4161],\n", 471 | " [0.5857, 0.4143],\n", 472 | " [0.5875, 0.4125]], grad_fn=))" 473 | ] 474 | }, 475 | "execution_count": 30, 476 | "metadata": {}, 477 | "output_type": "execute_result" 478 | } 479 | ], 480 | "source": [ 481 | "T = 1\n", 482 | "logits_gate = hardness_gate(pooled_output)\n", 483 | "gate_weights = F.softmax(logits_gate/T, dim=1)\n", 484 | "logits_gate.shape, gate_weights.shape, gate_weights" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "metadata": {}, 490 | "source": [ 491 | "## 只使用 `easy expert` 进行训练和预测\n", 492 | "weight rank:\n", 493 | "- ambi-easy\n", 494 | "- easy\n", 495 | "- ambi-hard\n", 496 | "- hard" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 31, 502 | "metadata": {}, 503 | "outputs": [ 504 | { 505 | "data": { 506 | "text/plain": [ 507 | "tensor([0.6097, 0.5698, 0.5829, 0.5860, 0.5920, 0.5839, 0.5857, 0.5875],\n", 508 | " grad_fn=)" 509 | ] 510 | }, 511 | "execution_count": 31, 512 | "metadata": {}, 513 | "output_type": "execute_result" 514 | } 515 | ], 516 | "source": [ 517 | "easy_probs = gate_weights[:,0]\n", 518 | "easy_probs" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 37, 524 | "metadata": {}, 525 | "outputs": [ 526 | { 527 | "data": { 528 | "text/plain": [ 529 | "(tensor([0.8903, 0.9302, 0.9171, 0.9140, 0.9080, 0.9161, 0.9143, 0.9125],\n", 530 | " grad_fn=),\n", 531 | " tensor([0.9753, 1.0191, 1.0047, 1.0013, 0.9947, 1.0036, 1.0017, 0.9997],\n", 532 | " grad_fn=))" 533 | ] 534 | }, 535 | "execution_count": 37, 536 | "metadata": {}, 537 | "output_type": "execute_result" 538 | } 539 | ], 540 | "source": [ 541 | "easy_weights = torch.where(easy_probs>0.5, 1-torch.abs(easy_probs-0.5), easy_probs)\n", 542 | "# 归一化\n", 543 | "batch_size = 8\n", 544 | "easy_weights, easy_weights * batch_size / 
torch.sum(easy_weights)" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": 40, 550 | "metadata": {}, 551 | "outputs": [ 552 | { 553 | "data": { 554 | "text/plain": [ 555 | "(tensor([0.4060, 0.3822, 0.9169, 0.4725, 1.1667, 0.2972, 0.4749, 0.9775],\n", 556 | " grad_fn=),\n", 557 | " tensor(0.6367, grad_fn=))" 558 | ] 559 | }, 560 | "execution_count": 40, 561 | "metadata": {}, 562 | "output_type": "execute_result" 563 | } 564 | ], 565 | "source": [ 566 | "easy_weights = easy_weights * batch_size / torch.sum(easy_weights)\n", 567 | "example_loss = torch.tensor([0.4163, 0.3751, 0.9126, 0.4719, 1.1729, 0.2961, 0.4741, 0.9778])\n", 568 | "easy_weights * example_loss, torch.mean(easy_weights * example_loss)" 569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "metadata": {}, 574 | "source": [ 575 | "end of 只使用 `easy expert` 进行训练和预测.\n", 576 | "---" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 22, 582 | "metadata": {}, 583 | "outputs": [ 584 | { 585 | "data": { 586 | "text/plain": [ 587 | "(tensor([[0.7000, 0.3000],\n", 588 | " [0.4000, 0.6000]]),\n", 589 | " tensor([[0.8000, 0.3000],\n", 590 | " [0.4000, 0.9000]]))" 591 | ] 592 | }, 593 | "execution_count": 22, 594 | "metadata": {}, 595 | "output_type": "execute_result" 596 | } 597 | ], 598 | "source": [ 599 | "X = torch.tensor([[0.7,0.3],[0.4,0.6]])\n", 600 | "X, torch.where(X>0.5, 1-torch.abs(X-0.5), X)" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 24, 606 | "metadata": {}, 607 | "outputs": [], 608 | "source": [ 609 | "pooled_output = dropout(pooled_output)\n", 610 | "logits_easy = classifier_easy(pooled_output)\n", 611 | "logits_hard = classifier_easy(pooled_output)\n" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 28, 617 | "metadata": {}, 618 | "outputs": [ 619 | { 620 | "data": { 621 | "text/plain": [ 622 | "(tensor([0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000]),\n", 623 | " tensor([[0.7000],\n", 624 | " [0.7000],\n", 625 | " [0.7000],\n", 626 | " [0.7000],\n", 627 | " [0.7000],\n", 628 | " [0.7000],\n", 629 | " [0.7000],\n", 630 | " [0.7000]]),\n", 631 | " tensor([[0.3000],\n", 632 | " [0.3000],\n", 633 | " [0.3000],\n", 634 | " [0.3000],\n", 635 | " [0.3000],\n", 636 | " [0.3000],\n", 637 | " [0.3000],\n", 638 | " [0.3000]]),\n", 639 | " tensor([[0.7000, 0.3000],\n", 640 | " [0.7000, 0.3000],\n", 641 | " [0.7000, 0.3000],\n", 642 | " [0.7000, 0.3000],\n", 643 | " [0.7000, 0.3000],\n", 644 | " [0.7000, 0.3000],\n", 645 | " [0.7000, 0.3000],\n", 646 | " [0.7000, 0.3000]]))" 647 | ] 648 | }, 649 | "execution_count": 28, 650 | "metadata": {}, 651 | "output_type": "execute_result" 652 | } 653 | ], 654 | "source": [ 655 | "confidences = torch.tensor([0.7]*8)\n", 656 | "confidences\n", 657 | "easy_probs = confidences.view(-1,1)\n", 658 | "hard_probs = 1 - easy_probs\n", 659 | "hardness_probs = torch.cat([easy_probs,hard_probs],dim=1)\n", 660 | "confidences, easy_probs, hard_probs, hardness_probs" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": 35, 666 | "metadata": {}, 667 | "outputs": [ 668 | { 669 | "data": { 670 | "text/plain": [ 671 | "(tensor([[-0.6693, -0.7176],\n", 672 | " [-0.6422, -0.7469],\n", 673 | " [-0.6316, -0.7587],\n", 674 | " [-0.6843, -0.7021],\n", 675 | " [-0.6417, -0.7474],\n", 676 | " [-0.6898, -0.6965],\n", 677 | " [-0.6593, -0.7282],\n", 678 | " [-0.6960, -0.6903]], grad_fn=),\n", 679 | " tensor([[0.7000, 0.3000],\n", 680 | " [0.7000, 0.3000],\n", 
681 | " [0.7000, 0.3000],\n", 682 | " [0.7000, 0.3000],\n", 683 | " [0.7000, 0.3000],\n", 684 | " [0.7000, 0.3000],\n", 685 | " [0.7000, 0.3000],\n", 686 | " [0.7000, 0.3000]]))" 687 | ] 688 | }, 689 | "execution_count": 35, 690 | "metadata": {}, 691 | "output_type": "execute_result" 692 | } 693 | ], 694 | "source": [ 695 | "F.log_softmax(logits_gate, dim=-1),hardness_probs" 696 | ] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "execution_count": 42, 701 | "metadata": {}, 702 | "outputs": [ 703 | { 704 | "data": { 705 | "text/plain": [ 706 | "tensor(0.0712, grad_fn=)" 707 | ] 708 | }, 709 | "execution_count": 42, 710 | "metadata": {}, 711 | "output_type": "execute_result" 712 | } 713 | ], 714 | "source": [ 715 | "loss_gate = F.kl_div(F.log_softmax(logits_gate, dim=-1), hardness_probs, reduction='batchmean')\n", 716 | "loss_gate" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": 50, 722 | "metadata": {}, 723 | "outputs": [ 724 | { 725 | "data": { 726 | "text/plain": [ 727 | "(tensor([0.4163, 0.3751, 0.9126, 0.4719, 1.1729, 0.2961, 0.4741, 0.9778],\n", 728 | " grad_fn=),\n", 729 | " tensor([0.4163, 0.3751, 0.9126, 0.4719, 1.1729, 0.2961, 0.4741, 0.9778],\n", 730 | " grad_fn=))" 731 | ] 732 | }, 733 | "execution_count": 50, 734 | "metadata": {}, 735 | "output_type": "execute_result" 736 | } 737 | ], 738 | "source": [ 739 | "num_labels = config.num_labels\n", 740 | "labels = batch['labels']\n", 741 | "loss_fct = CrossEntropyLoss(reduction='none')\n", 742 | "loss_easy = loss_fct(logits_easy.view(-1, num_labels), labels.view(-1))\n", 743 | "loss_hard = loss_fct(logits_hard.view(-1, num_labels), labels.view(-1))\n", 744 | "loss_easy, loss_hard" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": 54, 750 | "metadata": {}, 751 | "outputs": [ 752 | { 753 | "data": { 754 | "text/plain": [ 755 | "tensor([[0.4163, 0.4163],\n", 756 | " [0.3751, 0.3751],\n", 757 | " [0.9126, 0.9126],\n", 758 | " [0.4719, 0.4719],\n", 759 | " [1.1729, 1.1729],\n", 760 | " [0.2961, 0.2961],\n", 761 | " [0.4741, 0.4741],\n", 762 | " [0.9778, 0.9778]], grad_fn=)" 763 | ] 764 | }, 765 | "execution_count": 54, 766 | "metadata": {}, 767 | "output_type": "execute_result" 768 | } 769 | ], 770 | "source": [ 771 | "easy_hard_loss_cat = torch.cat([loss_easy.view(-1,1), loss_hard.view(-1,1)],dim=1)\n", 772 | "easy_hard_loss_cat" 773 | ] 774 | }, 775 | { 776 | "cell_type": "code", 777 | "execution_count": 55, 778 | "metadata": {}, 779 | "outputs": [ 780 | { 781 | "data": { 782 | "text/plain": [ 783 | "(tensor([[0.4163, 0.4163],\n", 784 | " [0.3751, 0.3751],\n", 785 | " [0.9126, 0.9126],\n", 786 | " [0.4719, 0.4719],\n", 787 | " [1.1729, 1.1729],\n", 788 | " [0.2961, 0.2961],\n", 789 | " [0.4741, 0.4741],\n", 790 | " [0.9778, 0.9778]], grad_fn=),\n", 791 | " tensor([[0.5121, 0.4879],\n", 792 | " [0.5262, 0.4738],\n", 793 | " [0.5317, 0.4683],\n", 794 | " [0.5044, 0.4956],\n", 795 | " [0.5264, 0.4736],\n", 796 | " [0.5017, 0.4983],\n", 797 | " [0.5172, 0.4828],\n", 798 | " [0.4986, 0.5014]], grad_fn=),\n", 799 | " tensor([[0.2132, 0.2031],\n", 800 | " [0.1974, 0.1777],\n", 801 | " [0.4853, 0.4274],\n", 802 | " [0.2380, 0.2338],\n", 803 | " [0.6174, 0.5555],\n", 804 | " [0.1486, 0.1476],\n", 805 | " [0.2452, 0.2289],\n", 806 | " [0.4875, 0.4903]], grad_fn=))" 807 | ] 808 | }, 809 | "execution_count": 55, 810 | "metadata": {}, 811 | "output_type": "execute_result" 812 | } 813 | ], 814 | "source": [ 815 | "easy_hard_loss_cat, gate_weights, easy_hard_loss_cat * gate_weights" 816 | ] 
817 | }, 818 | { 819 | "cell_type": "code", 820 | "execution_count": 80, 821 | "metadata": {}, 822 | "outputs": [ 823 | { 824 | "data": { 825 | "text/plain": [ 826 | "tensor(0.3185, grad_fn=)" 827 | ] 828 | }, 829 | "execution_count": 80, 830 | "metadata": {}, 831 | "output_type": "execute_result" 832 | } 833 | ], 834 | "source": [ 835 | "weighted_loss = easy_hard_loss_cat * gate_weights\n", 836 | "torch.mean(weighted_loss)" 837 | ] 838 | }, 839 | { 840 | "cell_type": "code", 841 | "execution_count": 111, 842 | "metadata": {}, 843 | "outputs": [ 844 | { 845 | "name": "stderr", 846 | "output_type": "stream", 847 | "text": [ 848 | "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']\n", 849 | "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 850 | "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" 851 | ] 852 | } 853 | ], 854 | "source": [ 855 | "from transformers.modeling_outputs import SequenceClassifierOutput\n", 856 | "\n", 857 | "class HCTForSequenceClassification(nn.Module):\n", 858 | " def __init__(self, model_name_or_path, config):\n", 859 | " super(HCTForSequenceClassification, self).__init__()\n", 860 | " self.encoder = AutoModel.from_pretrained(model_name_or_path)\n", 861 | " self.config = config\n", 862 | " self.num_labels = config.num_labels\n", 863 | " self.classifier_dropout = (\n", 864 | " self.config.classifier_dropout if self.config.classifier_dropout is not None else self.config.hidden_dropout_prob\n", 865 | " )\n", 866 | " self.dropout = nn.Dropout(self.classifier_dropout)\n", 867 | " self.classifier_easy = nn.Linear(self.config.hidden_size, config.num_labels)\n", 868 | " self.classifier_hard = nn.Linear(self.config.hidden_size, config.num_labels)\n", 869 | " # 2 experts: easy or hard\n", 870 | " # gate output: 0 for easy, 1 for hard\n", 871 | " self.hardness_gate = nn.Linear(self.config.hidden_size,2) \n", 872 | "\n", 873 | "\n", 874 | " def forward(\n", 875 | " self,\n", 876 | " input_ids: Optional[torch.Tensor] = None,\n", 877 | " attention_mask: Optional[torch.Tensor] = None,\n", 878 | " token_type_ids: Optional[torch.Tensor] = None,\n", 879 | " position_ids: Optional[torch.Tensor] = None,\n", 880 | " head_mask: Optional[torch.Tensor] = None,\n", 881 | " inputs_embeds: Optional[torch.Tensor] = None,\n", 882 | " labels: Optional[torch.Tensor] = None,\n", 883 | " confidences: Optional[torch.Tensor] = None, # the confidence value of a sample\n", 884 | " output_attentions: Optional[bool] = None,\n", 885 | " output_hidden_states: Optional[bool] = None,\n", 886 | " return_dict: Optional[bool] = None,\n", 887 | " ):\n", 888 | " r\"\"\"\n", 889 | " labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n", 890 | " Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n", 891 | " config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n", 892 | " `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n", 893 | " \"\"\"\n", 894 | " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", 895 | "\n", 896 | " outputs = self.encoder(\n", 897 | " input_ids,\n", 898 | " attention_mask=attention_mask,\n", 899 | " token_type_ids=token_type_ids,\n", 900 | " position_ids=position_ids,\n", 901 | " head_mask=head_mask,\n", 902 | " inputs_embeds=inputs_embeds,\n", 903 | " output_attentions=output_attentions,\n", 904 | " output_hidden_states=output_hidden_states,\n", 905 | " return_dict=return_dict,\n", 906 | " )\n", 907 | " # the CLS vector\n", 908 | " pooled_output = outputs[1] \n", 909 | " # gating:\n", 910 | " logits_gate = self.hardness_gate(pooled_output)\n", 911 | " gate_weights = F.softmax(logits_gate)\n", 912 | " # easy/hard experts:\n", 913 | " pooled_output = self.dropout(pooled_output)\n", 914 | " logits_easy = self.classifier_easy(pooled_output)\n", 915 | " logits_hard = self.classifier_hard(pooled_output)\n", 916 | "\n", 917 | " loss = None\n", 918 | " loss_easy, loss_hard, gate_loss = None, None, None\n", 919 | " if labels is not None:\n", 920 | " if self.config.problem_type is None:\n", 921 | " if self.num_labels == 1:\n", 922 | " self.config.problem_type = \"regression\"\n", 923 | " elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n", 924 | " self.config.problem_type = \"single_label_classification\"\n", 925 | " else:\n", 926 | " self.config.problem_type = \"multi_label_classification\"\n", 927 | "\n", 928 | " # gating loss\n", 929 | " # confidences 其实相当于 easy 分支的概率,所以你还需要自己构造一个 hard prob\n", 930 | " easy_probs = confidences.view(-1,1)\n", 931 | " hard_probs = 1 - easy_probs\n", 932 | " hardness_probs = torch.cat([easy_probs,hard_probs],dim=1)\n", 933 | " gate_loss = F.kl_div(F.log_softmax(logits_gate, dim=-1), hardness_probs, reduction='batchmean')\n", 934 | " if self.config.problem_type == \"regression\":\n", 935 | " loss_fct = MSELoss(reduction='none')\n", 936 | " if self.num_labels == 1:\n", 937 | " loss_easy = loss_fct(logits_easy.squeeze(), labels.squeeze())\n", 938 | " loss_hard = loss_fct(logits_hard.squeeze(), labels.squeeze())\n", 939 | " else:\n", 940 | " loss_easy = loss_fct(logits_easy, labels)\n", 941 | " loss_hard = loss_fct(logits_hard, labels)\n", 942 | " elif self.config.problem_type == \"single_label_classification\":\n", 943 | " loss_fct = CrossEntropyLoss(reduction='none') # reduction='none', 来得到每个sample的loss\n", 944 | " loss_easy = loss_fct(logits_easy.view(-1, self.num_labels), labels.view(-1))\n", 945 | " loss_hard = loss_fct(logits_hard.view(-1, self.num_labels), labels.view(-1))\n", 946 | " elif self.config.problem_type == \"multi_label_classification\":\n", 947 | " loss_fct = BCEWithLogitsLoss(reduction='none')\n", 948 | " loss_easy = loss_fct(logits_easy, labels)\n", 949 | " loss_hard = loss_fct(logits_hard, labels)\n", 950 | " \n", 951 | " easy_hard_loss_cat = torch.cat([loss_easy.view(-1,1), loss_hard.view(-1,1)],dim=1)\n", 952 | " weighted_loss = easy_hard_loss_cat * gate_weights\n", 953 | " clf_loss = torch.mean(weighted_loss)\n", 954 | " loss = gate_loss + clf_loss\n", 955 | " \n", 956 | " if not return_dict:\n", 957 | " output = (logits_gate,logits_easy, logits_hard,) + outputs[2:]\n", 958 | " return ((loss,) + output) if loss is not None else output\n", 959 | "\n", 960 | " return SequenceClassifierOutput(\n", 961 
| " loss=loss,\n", 962 | " logits={\"gate\":logits_gate, \"easy\":logits_easy, \"hard\":logits_hard},\n", 963 | " hidden_states=outputs.hidden_states,\n", 964 | " attentions=outputs.attentions,\n", 965 | " )\n", 966 | "\n", 967 | "my_model = HCTForSequenceClassification('bert-base-cased', config)" 968 | ] 969 | }, 970 | { 971 | "cell_type": "code", 972 | "execution_count": 112, 973 | "metadata": {}, 974 | "outputs": [ 975 | { 976 | "name": "stderr", 977 | "output_type": "stream", 978 | "text": [ 979 | "ipykernel_launcher:57: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n" 980 | ] 981 | }, 982 | { 983 | "data": { 984 | "text/plain": [ 985 | "(tensor(0.5328, grad_fn=),\n", 986 | " {'gate': tensor([[0.4412, 0.2590],\n", 987 | " [0.4171, 0.3171],\n", 988 | " [0.4294, 0.3268],\n", 989 | " [0.4348, 0.3207],\n", 990 | " [0.3658, 0.3255],\n", 991 | " [0.4153, 0.2928],\n", 992 | " [0.3959, 0.2656],\n", 993 | " [0.4340, 0.2752]], grad_fn=),\n", 994 | " 'easy': tensor([[ 0.0633, 0.9459],\n", 995 | " [-0.1190, 0.8201],\n", 996 | " [ 0.1919, 0.9389],\n", 997 | " [-0.2952, 0.6612],\n", 998 | " [ 0.0297, 0.8347],\n", 999 | " [-0.2183, 1.2751],\n", 1000 | " [-0.1609, 1.0679],\n", 1001 | " [-0.1628, 0.9590]], grad_fn=),\n", 1002 | " 'hard': tensor([[-0.2337, 0.2637],\n", 1003 | " [-0.3422, 0.3396],\n", 1004 | " [-0.2908, 0.3610],\n", 1005 | " [-0.2687, 0.1167],\n", 1006 | " [-0.1619, -0.1485],\n", 1007 | " [-0.7175, 0.3527],\n", 1008 | " [-0.4886, 0.2593],\n", 1009 | " [-0.1387, 0.3856]], grad_fn=)})" 1010 | ] 1011 | }, 1012 | "execution_count": 112, 1013 | "metadata": {}, 1014 | "output_type": "execute_result" 1015 | } 1016 | ], 1017 | "source": [ 1018 | "my_outputs = my_model(**batch, confidences=confidences)\n", 1019 | "my_outputs.loss, my_outputs.logits" 1020 | ] 1021 | }, 1022 | { 1023 | "cell_type": "code", 1024 | "execution_count": 107, 1025 | "metadata": {}, 1026 | "outputs": [ 1027 | { 1028 | "data": { 1029 | "text/plain": [ 1030 | "(tensor([1, 1, 1, 1, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 1, 1, 1, 1]))" 1031 | ] 1032 | }, 1033 | "execution_count": 107, 1034 | "metadata": {}, 1035 | "output_type": "execute_result" 1036 | } 1037 | ], 1038 | "source": [ 1039 | "# indices = torch.argmax(my_outputs.logits['gate'],dim=1)\n", 1040 | "indices = torch.tensor([1,1,1,1,0,0,0,0])\n", 1041 | "indices, 1-indices" 1042 | ] 1043 | }, 1044 | { 1045 | "cell_type": "code", 1046 | "execution_count": 113, 1047 | "metadata": {}, 1048 | "outputs": [ 1049 | { 1050 | "data": { 1051 | "text/plain": [ 1052 | "tensor([[-0.2337, 0.2637],\n", 1053 | " [-0.3422, 0.3396],\n", 1054 | " [-0.2908, 0.3610],\n", 1055 | " [-0.2687, 0.1167],\n", 1056 | " [ 0.0297, 0.8347],\n", 1057 | " [-0.2183, 1.2751],\n", 1058 | " [-0.1609, 1.0679],\n", 1059 | " [-0.1628, 0.9590]], grad_fn=)" 1060 | ] 1061 | }, 1062 | "execution_count": 113, 1063 | "metadata": {}, 1064 | "output_type": "execute_result" 1065 | } 1066 | ], 1067 | "source": [ 1068 | "my_outputs.logits['easy'] * (1-indices).view(-1,1), my_outputs.logits['hard'] * indices.view(-1,1)\n", 1069 | "\n", 1070 | "my_outputs.logits['easy'] * (1-indices).view(-1,1) + my_outputs.logits['hard'] * indices.view(-1,1)" 1071 | ] 1072 | }, 1073 | { 1074 | "cell_type": "code", 1075 | "execution_count": 114, 1076 | "metadata": {}, 1077 | "outputs": [ 1078 | { 1079 | "data": { 1080 | "text/plain": [ 1081 | "(tensor([[ 0.0633, 0.9459],\n", 1082 | " [-0.1190, 0.8201],\n", 1083 | " [ 0.1919, 0.9389],\n", 1084 | " [-0.2952, 0.6612],\n", 
1085 | " [ 0.0297, 0.8347],\n", 1086 | " [-0.2183, 1.2751],\n", 1087 | " [-0.1609, 1.0679],\n", 1088 | " [-0.1628, 0.9590]], grad_fn=),\n", 1089 | " tensor([[-0.2337, 0.2637],\n", 1090 | " [-0.3422, 0.3396],\n", 1091 | " [-0.2908, 0.3610],\n", 1092 | " [-0.2687, 0.1167],\n", 1093 | " [-0.1619, -0.1485],\n", 1094 | " [-0.7175, 0.3527],\n", 1095 | " [-0.4886, 0.2593],\n", 1096 | " [-0.1387, 0.3856]], grad_fn=))" 1097 | ] 1098 | }, 1099 | "execution_count": 114, 1100 | "metadata": {}, 1101 | "output_type": "execute_result" 1102 | } 1103 | ], 1104 | "source": [ 1105 | "my_outputs.logits['easy'], my_outputs.logits['hard']" 1106 | ] 1107 | }, 1108 | { 1109 | "cell_type": "code", 1110 | "execution_count": 117, 1111 | "metadata": {}, 1112 | "outputs": [ 1113 | { 1114 | "data": { 1115 | "text/plain": [ 1116 | "(tensor([[0.5454, 0.4546],\n", 1117 | " [0.5250, 0.4750],\n", 1118 | " [0.5256, 0.4744],\n", 1119 | " [0.5285, 0.4715],\n", 1120 | " [0.5101, 0.4899],\n", 1121 | " [0.5306, 0.4694],\n", 1122 | " [0.5325, 0.4675],\n", 1123 | " [0.5396, 0.4604]], grad_fn=),\n", 1124 | " tensor([0.5454, 0.5250, 0.5256, 0.5285, 0.5101, 0.5306, 0.5325, 0.5396],\n", 1125 | " grad_fn=),\n", 1126 | " tensor([0.4546, 0.4750, 0.4744, 0.4715, 0.4899, 0.4694, 0.4675, 0.4604],\n", 1127 | " grad_fn=))" 1128 | ] 1129 | }, 1130 | "execution_count": 117, 1131 | "metadata": {}, 1132 | "output_type": "execute_result" 1133 | } 1134 | ], 1135 | "source": [ 1136 | "weights = F.softmax(my_outputs.logits['gate'], dim=1)\n", 1137 | "weights, weights[:,0], weights[:,1] # 第0列是easy的权重,第1列是hard的权重" 1138 | ] 1139 | }, 1140 | { 1141 | "cell_type": "code", 1142 | "execution_count": 120, 1143 | "metadata": {}, 1144 | "outputs": [ 1145 | { 1146 | "data": { 1147 | "text/plain": [ 1148 | "(tensor([[ 0.0345, 0.5159],\n", 1149 | " [-0.0625, 0.4305],\n", 1150 | " [ 0.1009, 0.4935],\n", 1151 | " [-0.1560, 0.3494],\n", 1152 | " [ 0.0152, 0.4258],\n", 1153 | " [-0.1158, 0.6765],\n", 1154 | " [-0.0857, 0.5687],\n", 1155 | " [-0.0879, 0.5175]], grad_fn=),\n", 1156 | " tensor([[-0.1062, 0.1199],\n", 1157 | " [-0.1626, 0.1613],\n", 1158 | " [-0.1380, 0.1712],\n", 1159 | " [-0.1267, 0.0550],\n", 1160 | " [-0.0793, -0.0728],\n", 1161 | " [-0.3368, 0.1656],\n", 1162 | " [-0.2284, 0.1212],\n", 1163 | " [-0.0639, 0.1775]], grad_fn=))" 1164 | ] 1165 | }, 1166 | "execution_count": 120, 1167 | "metadata": {}, 1168 | "output_type": "execute_result" 1169 | } 1170 | ], 1171 | "source": [ 1172 | "my_outputs.logits['easy'] * weights[:,0].view(-1,1), my_outputs.logits['hard'] * weights[:,1].view(-1,1)" 1173 | ] 1174 | }, 1175 | { 1176 | "cell_type": "code", 1177 | "execution_count": 1, 1178 | "metadata": {}, 1179 | "outputs": [ 1180 | { 1181 | "name": "stderr", 1182 | "output_type": "stream", 1183 | "text": [ 1184 | "Reusing dataset snli (/home/v-biyangguo/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)\n" 1185 | ] 1186 | }, 1187 | { 1188 | "data": { 1189 | "application/vnd.jupyter.widget-view+json": { 1190 | "model_id": "9fafeef372dd4a3d99f7233786ee4508", 1191 | "version_major": 2, 1192 | "version_minor": 0 1193 | }, 1194 | "text/plain": [ 1195 | " 0%| | 0/3 [00:00=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 60 | 61 | task_to_keys = { 62 | "cola": ("sentence", None), 63 | "mnli": ("premise", "hypothesis"), 64 | "mrpc": ("sentence1", "sentence2"), 65 | "qnli": ("question", "sentence"), 66 | "qqp": ("question1", "question2"), 67 | "rte": ("sentence1", 
"sentence2"), 68 | "sst2": ("sentence", None), 69 | "stsb": ("sentence1", "sentence2"), 70 | "wnli": ("sentence1", "sentence2"), 71 | 72 | "snli": ("premise", "hypothesis"), 73 | 74 | "boolq": ("question", "passage"), 75 | "cb": ("premise", "hypothesis"), 76 | # "mrpc-noisy":("sentence1", "sentence2"), 77 | # "rte-noisy":("sentence1", "sentence2"), 78 | } 79 | 80 | 81 | def parse_args(): 82 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") 83 | parser.add_argument( 84 | "--task_name", 85 | type=str, 86 | default=None, 87 | help="The name of the glue task to train on.", 88 | # choices=list(task_to_keys.keys()), 89 | ) 90 | parser.add_argument( 91 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data." 92 | ) 93 | parser.add_argument( 94 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 95 | ) 96 | parser.add_argument( 97 | "--max_length", 98 | type=int, 99 | default=128, 100 | help=( 101 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 102 | " sequences shorter will be padded if `--pad_to_max_lengh` is passed." 103 | ), 104 | ) 105 | parser.add_argument( 106 | "--pad_to_max_length", 107 | action="store_true", 108 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", 109 | ) 110 | parser.add_argument( 111 | "--model_name_or_path", 112 | type=str, 113 | help="Path to pretrained model or model identifier from huggingface.co/models.", 114 | required=True, 115 | ) 116 | parser.add_argument( 117 | "--use_slow_tokenizer", 118 | action="store_true", 119 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 120 | ) 121 | parser.add_argument( 122 | "--per_device_train_batch_size", 123 | type=int, 124 | default=8, 125 | help="Batch size (per device) for the training dataloader.", 126 | ) 127 | parser.add_argument( 128 | "--per_device_eval_batch_size", 129 | type=int, 130 | default=8, 131 | help="Batch size (per device) for the evaluation dataloader.", 132 | ) 133 | parser.add_argument( 134 | "--learning_rate", 135 | type=float, 136 | default=5e-5, 137 | help="Initial learning rate (after the potential warmup period) to use.", 138 | ) 139 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 140 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 141 | parser.add_argument( 142 | "--max_train_steps", 143 | type=int, 144 | default=None, 145 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 146 | ) 147 | parser.add_argument( 148 | "--gradient_accumulation_steps", 149 | type=int, 150 | default=1, 151 | help="Number of updates steps to accumulate before performing a backward/update pass.", 152 | ) 153 | parser.add_argument( 154 | "--lr_scheduler_type", 155 | type=SchedulerType, 156 | default="linear", 157 | help="The scheduler type to use.", 158 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 159 | ) 160 | parser.add_argument( 161 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 
162 | ) 163 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 164 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 165 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 166 | parser.add_argument( 167 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." 168 | ) 169 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") 170 | parser.add_argument( 171 | "--checkpointing_steps", 172 | type=str, 173 | default=None, 174 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 175 | ) 176 | parser.add_argument( 177 | "--resume_from_checkpoint", 178 | type=str, 179 | default=None, 180 | help="If the training should continue from a checkpoint folder.", 181 | ) 182 | parser.add_argument( 183 | "--with_tracking", 184 | action="store_true", 185 | help="Whether to enable experiment trackers for logging.", 186 | ) 187 | parser.add_argument( 188 | "--report_to", 189 | type=str, 190 | default="all", 191 | help=( 192 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 193 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' 194 | "Only applicable when `--with_tracking` is passed." 195 | ), 196 | ) 197 | parser.add_argument( 198 | "--ignore_mismatched_sizes", 199 | action="store_true", 200 | help="Whether or not to enable to load a pretrained model whose head dimensions are different.", 201 | ) 202 | parser.add_argument("--do_recording", action="store_true", help="Whether to record the training dynamics.") 203 | parser.add_argument("--with_data_selection", action="store_true", help="Use only a selected subset of the training data for model training.") 204 | parser.add_argument("--data_selection_region", default=None, choices=("easy","hard","ambiguous"), 205 | help="Three regions from the dataset cartography: easy, hard and ambiguous") 206 | parser.add_argument("--temperature", type=float, default=1., help="the temperature for the softmax in HCT model") 207 | parser.add_argument("--mu", type=float, default=0.5, help="weight for the gate loss in HCT model") 208 | parser.add_argument("--hard_inference", action="store_true", default=False, help="weight for the gate loss in HCT model") 209 | parser.add_argument("--hard_with_ls", action="store_true", help="if set, use label_smoothing for hard expert") 210 | parser.add_argument("--ls_weight", type=float, default=0.1, help="weight for label_smoothing") 211 | parser.add_argument("--more_ambiguous", action="store_true", default=False, help="set to put more weights to ambiguous samples") 212 | 213 | args = parser.parse_args() 214 | 215 | # Sanity checks 216 | if args.task_name is None and args.train_file is None and args.validation_file is None: 217 | raise ValueError("Need either a task name or a training/validation file.") 218 | else: 219 | if args.train_file is not None: 220 | extension = args.train_file.split(".")[-1] 221 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 222 | if args.validation_file is not None: 223 | extension = args.validation_file.split(".")[-1] 224 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
225 | 226 | # if args.push_to_hub: 227 | # assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." 228 | 229 | return args 230 | 231 | from typing import List, Optional, Tuple, Union 232 | 233 | 234 | class HCTForSequenceClassification(nn.Module): 235 | """ 236 | Modified from `BertForSequenceClassification` class 237 | """ 238 | def __init__(self, model_name_or_path, config, temperature, mu, hard_with_ls=False, ls_weight=None, more_ambiguous=False): 239 | super(HCTForSequenceClassification, self).__init__() 240 | self.encoder = AutoModel.from_pretrained(model_name_or_path) 241 | self.config = config 242 | self.num_labels = config.num_labels 243 | self.classifier_dropout = ( 244 | self.config.classifier_dropout if self.config.classifier_dropout is not None else self.config.hidden_dropout_prob 245 | ) 246 | self.dropout = nn.Dropout(self.classifier_dropout) 247 | self.classifier_easy = nn.Linear(self.config.hidden_size, config.num_labels) 248 | self.classifier_hard = nn.Linear(self.config.hidden_size, config.num_labels) 249 | # 2 experts: easy or hard 250 | # gate output: 0 for easy, 1 for hard 251 | self.hardness_gate = nn.Linear(self.config.hidden_size,2) 252 | 253 | self.T = temperature 254 | self.mu = mu 255 | self.hard_with_ls = hard_with_ls 256 | self.ls_weight = ls_weight 257 | self.more_ambiguous = more_ambiguous 258 | 259 | 260 | def forward( 261 | self, 262 | input_ids: Optional[torch.Tensor] = None, 263 | attention_mask: Optional[torch.Tensor] = None, 264 | token_type_ids: Optional[torch.Tensor] = None, 265 | position_ids: Optional[torch.Tensor] = None, 266 | head_mask: Optional[torch.Tensor] = None, 267 | inputs_embeds: Optional[torch.Tensor] = None, 268 | labels: Optional[torch.Tensor] = None, 269 | # the confidence value of a sample 270 | confidences: Optional[torch.Tensor] = None, 271 | output_attentions: Optional[bool] = None, 272 | output_hidden_states: Optional[bool] = None, 273 | return_dict: Optional[bool] = None, 274 | ): 275 | r""" 276 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 277 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 278 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 279 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
280 | """ 281 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 282 | 283 | outputs = self.encoder( 284 | input_ids, 285 | attention_mask=attention_mask, 286 | token_type_ids=token_type_ids, 287 | position_ids=position_ids, 288 | head_mask=head_mask, 289 | inputs_embeds=inputs_embeds, 290 | output_attentions=output_attentions, 291 | output_hidden_states=output_hidden_states, 292 | return_dict=return_dict, 293 | ) 294 | # the CLS vector 295 | pooled_output = outputs[1] 296 | # gating: 297 | logits_gate = self.hardness_gate(pooled_output) 298 | gate_weights = F.softmax(logits_gate/self.T, dim=1) # 299 | # easy/hard experts: 300 | pooled_output = self.dropout(pooled_output) 301 | logits_easy = self.classifier_easy(pooled_output) 302 | logits_hard = self.classifier_hard(pooled_output) 303 | 304 | loss = None 305 | loss_easy, loss_hard, gate_loss = None, None, None 306 | if labels is not None: 307 | if self.config.problem_type is None: 308 | if self.num_labels == 1: 309 | self.config.problem_type = "regression" 310 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 311 | self.config.problem_type = "single_label_classification" 312 | else: 313 | self.config.problem_type = "multi_label_classification" 314 | 315 | if self.training: 316 | # gating loss, 仅在 train 模式下才有 317 | # confidences 其实相当于 easy 分支的概率,所以你还需要自己构造一个 hard prob 318 | easy_probs = confidences.view(-1,1) 319 | hard_probs = 1 - easy_probs 320 | hardness_probs = torch.cat([easy_probs,hard_probs],dim=1) 321 | gate_loss = F.kl_div(F.log_softmax(logits_gate, dim=-1), hardness_probs, reduction='batchmean') 322 | else: 323 | gate_loss = None 324 | if self.config.problem_type == "regression": 325 | loss_fct = MSELoss(reduction='none') 326 | if self.num_labels == 1: 327 | loss_easy = loss_fct(logits_easy.squeeze(), labels.squeeze()) 328 | loss_hard = loss_fct(logits_hard.squeeze(), labels.squeeze()) 329 | else: 330 | loss_easy = loss_fct(logits_easy, labels) 331 | loss_hard = loss_fct(logits_hard, labels) 332 | elif self.config.problem_type == "single_label_classification": 333 | loss_fct = CrossEntropyLoss(reduction='none') # reduction='none', 来得到每个sample的loss 334 | loss_easy = loss_fct(logits_easy.view(-1, self.num_labels), labels.view(-1)) 335 | # hard 端使用 label smoothing,即原始 CE loss 和一个完全 smoothing loss 的加权平均 336 | # 目前使用的是 torch 1.7 版本,所有只能自己手写,在1.10版本后可以直接设置 label smoothing 337 | if self.hard_with_ls: 338 | loss_ls_fct = CrossEntropyLoss(reduction='none', label_smoothing=self.ls_weight) 339 | loss_hard = loss_ls_fct(logits_hard.view(-1, self.num_labels), labels.view(-1)) 340 | else: 341 | loss_hard = loss_fct(logits_hard.view(-1, self.num_labels), labels.view(-1)) 342 | elif self.config.problem_type == "multi_label_classification": 343 | loss_fct = BCEWithLogitsLoss(reduction='none') 344 | loss_easy = loss_fct(logits_easy, labels) 345 | loss_hard = loss_fct(logits_hard, labels) 346 | 347 | easy_hard_loss_cat = torch.cat([loss_easy.view(-1,1), loss_hard.view(-1,1)],dim=1) 348 | # !!! 349 | if self.more_ambiguous: 350 | gate_weights = torch.where(gate_weights>0.5, 1-torch.abs(gate_weights-0.5), gate_weights) 351 | 352 | # way 1: weighted loss of easy and hard experts 353 | # TODO: 这里的权重,最好归一化一下 354 | weighted_loss = easy_hard_loss_cat * gate_weights 355 | 356 | # !!! 
357 | # way 2: only easy expert 358 | # easy_weights = gate_weights[:,0] 359 | # batch_size = 32 # TODO 360 | # easy_weights = easy_weights * batch_size / torch.sum(easy_weights) # normalize 361 | # weighted_loss = loss_easy * easy_weights 362 | # this strategy may help with noisy data, but it tends to hurt performance on clean datasets 363 | 364 | 365 | clf_loss = torch.mean(weighted_loss) 366 | 367 | # only use easy expert 368 | # easy_probs = gate_weights[:,0] 369 | # easy_weights = torch.where(easy_probs>0.5, 1-torch.abs(easy_probs-0.5), easy_probs) 370 | 371 | if self.training: # gate_loss only exists during training 372 | loss = self.mu * gate_loss + (1 - self.mu) * clf_loss 373 | else: 374 | loss = clf_loss 375 | 376 | if not return_dict: 377 | output = (logits_gate,logits_easy, logits_hard,) + outputs[2:] 378 | return ((loss,) + output) if loss is not None else output 379 | 380 | return SequenceClassifierOutput( 381 | loss=loss, 382 | logits={"gate":logits_gate, "easy":logits_easy, "hard":logits_hard}, 383 | hidden_states=outputs.hidden_states, 384 | attentions=outputs.attentions, 385 | ) 386 | 387 | 388 | 389 | def main(): 390 | args = parse_args() 391 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 392 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 393 | # send_example_telemetry("run_glue_no_trainer", args) 394 | 395 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 396 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers 397 | # in the environment 398 | accelerator = ( 399 | Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator() 400 | ) 401 | # Make one log on every process with the configuration for debugging. 402 | logging.basicConfig( 403 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 404 | datefmt="%m/%d/%Y %H:%M:%S", 405 | level=logging.INFO, 406 | ) 407 | logger.info(accelerator.state, main_process_only=False) 408 | if accelerator.is_local_main_process: 409 | datasets.utils.logging.set_verbosity_warning() 410 | transformers.utils.logging.set_verbosity_info() 411 | else: 412 | datasets.utils.logging.set_verbosity_error() 413 | transformers.utils.logging.set_verbosity_error() 414 | 415 | # If passed along, set the training seed now. 416 | if args.seed is not None: 417 | set_seed(args.seed) 418 | 419 | 420 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 421 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 422 | 423 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 424 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 425 | # label if at least two columns are provided. 426 | 427 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 428 | # single column. You can easily tweak this behavior (see below) 429 | 430 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 431 | # download the dataset. 432 | if args.task_name is not None: 433 | # Downloading and loading a dataset from the hub.
434 | # raw_datasets = load_dataset("glue", args.task_name) 435 | # load processed GLUE dataset that contains 'confidence' value 436 | raw_datasets = load_from_disk(f"../datasets/{args.task_name}/with_conf/") 437 | assert 'confidence' in raw_datasets['train'].column_names, "the Train set must contain a 'confidence' column!!" 438 | print("Yes! I have loaded it! ----------") 439 | 440 | 441 | else: 442 | # Loading the dataset from local csv or json file. 443 | data_files = {} 444 | if args.train_file is not None: 445 | data_files["train"] = args.train_file 446 | if args.validation_file is not None: 447 | data_files["validation"] = args.validation_file 448 | extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1] 449 | raw_datasets = load_dataset(extension, data_files=data_files) 450 | # See more about loading any type of standard or custom dataset at 451 | # https://huggingface.co/docs/datasets/loading_datasets.html. 452 | 453 | # Labels 454 | if args.task_name is not None: 455 | is_regression = args.task_name == "stsb" 456 | if not is_regression: 457 | label_list = raw_datasets["validation"].features["label"].names 458 | num_labels = len(label_list) 459 | else: 460 | num_labels = 1 461 | else: 462 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 463 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 464 | if is_regression: 465 | num_labels = 1 466 | else: 467 | # A useful fast method: 468 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 469 | label_list = raw_datasets["validation"].unique("label") 470 | label_list.sort() # Let's sort it for determinism 471 | num_labels = len(label_list) 472 | 473 | # --------------------------- Data Selection: ----------------------------------- 474 | # data selection is ONLY applied on train set 475 | if args.with_data_selection: 476 | assert args.data_selection_region is not None, "You much specify `data_selection_region` when using `with_data_selection`" 477 | model_name = args.model_name_or_path 478 | if '/' in model_name: 479 | model_name = model_name.split('/')[-1] 480 | assert os.path.exists(f'dy_log/{args.task_name}/{model_name}/three_regions_data_indices.json'), "Selection indices file not found!" 481 | with open(f'dy_log/{args.task_name}/{model_name}/three_regions_data_indices.json','r') as f: 482 | three_regions_data_indices = json.loads(f.read()) 483 | selected_indices = three_regions_data_indices[args.data_selection_region] 484 | raw_datasets['train'] = raw_datasets['train'].select(selected_indices) 485 | 486 | logger.info("~~~~~ Applying Data Selection ~~~~~ ") 487 | logger.info(f"~~~~~ Region: {args.data_selection_region} ") 488 | logger.info(f"~~~~~ Size: {len(raw_datasets['train'])} ") 489 | 490 | # ---------------------------------------------------------------------------------------------------- 491 | # with open(f'dy_log/sst2/distilbert-base-cased/three_regions_data_indices.json','r') as f: 492 | 493 | # Load pretrained model and tokenizer 494 | # 495 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 496 | # download model & vocab. 
497 | config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name) 498 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) 499 | # model = AutoModelForSequenceClassification.from_pretrained( 500 | # args.model_name_or_path, 501 | # from_tf=bool(".ckpt" in args.model_name_or_path), 502 | # config=config, 503 | # ignore_mismatched_sizes=args.ignore_mismatched_sizes, 504 | # ) 505 | hct_model = HCTForSequenceClassification(args.model_name_or_path, config, temperature=args.temperature, mu=args.mu, hard_with_ls=args.hard_with_ls, ls_weight=args.ls_weight, more_ambiguous=args.more_ambiguous) 506 | 507 | 508 | 509 | # for model names that are not plain HF Hub identifiers (i.e. contain a namespace prefix), keep only the model name itself 510 | if '/' in args.model_name_or_path: 511 | args.model_name_or_path = args.model_name_or_path.split('/')[-1] 512 | 513 | # Preprocessing the datasets 514 | # --------------- GLUE tasks (and some SuperGLUE tasks, like boolq/cb)--------------- 515 | if args.task_name is not None: 516 | if 'noisy' in args.task_name: 517 | task_name = args.task_name.split('-')[0] 518 | sentence1_key, sentence2_key = task_to_keys[task_name] 519 | else: 520 | sentence1_key, sentence2_key = task_to_keys[args.task_name] 521 | # --------------- Other tasks --------------- 522 | else: 523 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 524 | # The logic here is as follows: 525 | # non-GLUE datasets are required to contain a `label` column 526 | # ideally there are also `sentence1` and `sentence2` columns, which keeps things aligned with GLUE 527 | # if those names are not used either, the first two non-label columns are taken as sentence1 and sentence2 respectively 528 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 529 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 530 | sentence1_key, sentence2_key = "sentence1", "sentence2" 531 | elif "sentence" in non_label_column_names: # for classical classification tasks, like sst2 532 | sentence1_key, sentence2_key = ("sentence", None) 533 | elif "question" in non_label_column_names and "sentence" in non_label_column_names: # for tasks like qnli 534 | sentence1_key, sentence2_key = ("question", "sentence") 535 | else: 536 | if len(non_label_column_names) >= 2: 537 | sentence1_key, sentence2_key = non_label_column_names[:2] 538 | else: 539 | sentence1_key, sentence2_key = non_label_column_names[0], None 540 | 541 | # Some models have set the order of the labels to use, so let's make sure we do use it. 542 | label_to_id = None 543 | # if ( 544 | # model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 545 | # and args.task_name is not None 546 | # and not is_regression 547 | # ): 548 | # # Some have all caps in their config, some don't. 549 | # label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 550 | # if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 551 | # logger.info( 552 | # f"The configuration of the model provided the following label correspondence: {label_name_to_id}. " 553 | # "Using it!" 554 | # ) 555 | # label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)} 556 | # else: 557 | # logger.warning( 558 | # "Your model seems to have been trained with labels, but they don't match the dataset: ", 559 | # f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
560 | # "\nIgnoring the model labels as a result.", 561 | # ) 562 | # elif args.task_name is None and not is_regression: 563 | # label_to_id = {v: i for i, v in enumerate(label_list)} 564 | 565 | # if label_to_id is not None: 566 | # model.config.label2id = label_to_id 567 | # model.config.id2label = {id: label for label, id in config.label2id.items()} 568 | # elif args.task_name is not None and not is_regression: 569 | # model.config.label2id = {l: i for i, l in enumerate(label_list)} 570 | # model.config.id2label = {id: label for label, id in config.label2id.items()} 571 | 572 | padding = "max_length" if args.pad_to_max_length else False 573 | 574 | def preprocess_function(examples): 575 | # Tokenize the texts 576 | texts = ( 577 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 578 | ) 579 | result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True) 580 | 581 | if "label" in examples: 582 | if label_to_id is not None: 583 | # Map labels to IDs (not necessary for GLUE tasks) 584 | result["labels"] = [label_to_id[l] for l in examples["label"]] 585 | else: 586 | # In all cases, rename the column to labels because the model will expect that. 587 | result["labels"] = examples["label"] 588 | 589 | if "confidence" in examples: 590 | result["confidences"] = examples["confidence"] 591 | return result 592 | 593 | with accelerator.main_process_first(): 594 | processed_datasets = raw_datasets.map( 595 | preprocess_function, 596 | batched=True, 597 | # 得把这行改掉: 598 | # 以SST2为例,这里会把 ['sentence', 'label', 'idx'] 给去掉(不用担心label,因为上面已经新建了一个labels列) 599 | # remove_columns=raw_datasets["train"].column_names, 600 | # 改为: 601 | # 保留 idx 和 confidences,其他的可以去掉 (confidences 而不是 confidence,前者是模型要接收的名字) 602 | # remove_columns=[c for c in raw_datasets["train"].column_names if c not in ['idx', 'confidences']], 603 | desc="Running tokenizer on dataset", 604 | ) 605 | 606 | train_dataset = processed_datasets["train"].remove_columns([c for c in raw_datasets["train"].column_names if c not in ['idx', 'confidences']]) 607 | if args.task_name == 'mnli': 608 | eval_dataset = processed_datasets["validation_matched"].remove_columns([c for c in raw_datasets["validation_matched"].column_names if c not in ['idx']]) 609 | else: 610 | eval_dataset = processed_datasets["validation"].remove_columns([c for c in raw_datasets["validation"].column_names if c not in ['idx']]) 611 | 612 | 613 | # Log a few random samples from the training set: 614 | for index in random.sample(range(len(train_dataset)), 3): 615 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 616 | 617 | # DataLoaders creation: 618 | if args.pad_to_max_length: 619 | # If padding was already done ot max length, we use the default data collator that will just convert everything 620 | # to tensors. 621 | data_collator = default_data_collator 622 | else: 623 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of 624 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple 625 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). 
626 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) 627 | 628 | train_dataloader = DataLoader( 629 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 630 | ) 631 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) 632 | 633 | # Optimizer 634 | # Split weights in two groups, one with weight decay and the other not. 635 | no_decay = ["bias", "LayerNorm.weight"] 636 | optimizer_grouped_parameters = [ 637 | { 638 | "params": [p for n, p in hct_model.named_parameters() if not any(nd in n for nd in no_decay)], 639 | "weight_decay": args.weight_decay, 640 | }, 641 | { 642 | "params": [p for n, p in hct_model.named_parameters() if any(nd in n for nd in no_decay)], 643 | "weight_decay": 0.0, 644 | }, 645 | ] 646 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 647 | 648 | # Scheduler and math around the number of training steps. 649 | overrode_max_train_steps = False 650 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 651 | if args.max_train_steps is None: 652 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 653 | overrode_max_train_steps = True 654 | 655 | lr_scheduler = get_scheduler( 656 | name=args.lr_scheduler_type, 657 | optimizer=optimizer, 658 | num_warmup_steps=args.num_warmup_steps, 659 | num_training_steps=args.max_train_steps, 660 | ) 661 | 662 | # Prepare everything with our `accelerator`. 663 | hct_model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 664 | hct_model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 665 | ) 666 | 667 | # We need to recalculate our total training steps as the size of the training dataloader may have changed 668 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 669 | if overrode_max_train_steps: 670 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 671 | # Afterwards we recalculate our number of training epochs 672 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 673 | 674 | # Figure out how many steps we should save the Accelerator states 675 | if hasattr(args.checkpointing_steps, "isdigit"): 676 | checkpointing_steps = args.checkpointing_steps 677 | if args.checkpointing_steps.isdigit(): 678 | checkpointing_steps = int(args.checkpointing_steps) 679 | else: 680 | checkpointing_steps = None 681 | 682 | # We need to initialize the trackers we use, and also store our configuration. 683 | # We initialize the trackers only on main process because `accelerator.log` 684 | # only logs on main process and we don't want empty logs/runs on other processes. 
685 | if args.with_tracking: 686 | if accelerator.is_main_process: 687 | experiment_config = vars(args) 688 | # TensorBoard cannot log Enums, need the raw value 689 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 690 | accelerator.init_trackers("glue_no_trainer", experiment_config) 691 | 692 | # Get the metric function 693 | if args.task_name is not None: 694 | if args.task_name == 'snli': 695 | metric = load_metric("glue", 'mnli') 696 | elif args.task_name in ['boolq','cb','axb','axg']: 697 | metric = load_metric("super_glue" ,args.task_name) 698 | elif 'noisy' in args.task_name: 699 | task_name = args.task_name.split('-')[0] 700 | metric = load_metric("glue", task_name) 701 | else: 702 | metric = load_metric("glue", args.task_name) 703 | else: 704 | metric = load_metric("accuracy") 705 | 706 | # Train! 707 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 708 | 709 | logger.info("***** Running training *****") 710 | logger.info(f" Num examples = {len(train_dataset)}") 711 | logger.info(f" Num Epochs = {args.num_train_epochs}") 712 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 713 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 714 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 715 | logger.info(f" Total optimization steps = {args.max_train_steps}") 716 | # Only show the progress bar once on each machine. 717 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 718 | completed_steps = 0 719 | starting_epoch = 0 720 | # Potentially load in the weights and states from a previous save 721 | if args.resume_from_checkpoint: 722 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 723 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 724 | accelerator.load_state(args.resume_from_checkpoint) 725 | path = os.path.basename(args.resume_from_checkpoint) 726 | else: 727 | # Get the most recent checkpoint 728 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 729 | dirs.sort(key=os.path.getctime) 730 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 731 | # Extract `epoch_{i}` or `step_{i}` 732 | training_difference = os.path.splitext(path)[0] 733 | 734 | if "epoch" in training_difference: 735 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 736 | resume_step = None 737 | else: 738 | resume_step = int(training_difference.replace("step_", "")) 739 | starting_epoch = resume_step // len(train_dataloader) 740 | resume_step -= starting_epoch * len(train_dataloader) 741 | 742 | # ============================ Training Loop ============================ 743 | for epoch in range(starting_epoch, args.num_train_epochs): 744 | 745 | hct_model.train() 746 | if args.with_tracking: 747 | total_loss = 0 748 | for step, batch in enumerate(train_dataloader): 749 | # We need to skip steps until we reach the resumed step 750 | if args.resume_from_checkpoint and epoch == starting_epoch: 751 | if resume_step is not None and step < resume_step: 752 | completed_steps += 1 753 | continue 754 | # batch中包含了idx字段,这里需要去除 755 | batch = {k:v for k,v in batch.items() if k != 'idx'} 756 | outputs = hct_model(**batch) 757 | loss = outputs.loss 758 | # We keep track of the loss at each epoch 759 | if args.with_tracking: 
760 | total_loss += loss.detach().float() 761 | loss = loss / args.gradient_accumulation_steps 762 | accelerator.backward(loss) 763 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 764 | optimizer.step() 765 | lr_scheduler.step() 766 | optimizer.zero_grad() 767 | progress_bar.update(1) 768 | completed_steps += 1 769 | 770 | if isinstance(checkpointing_steps, int): 771 | if completed_steps % checkpointing_steps == 0: 772 | output_dir = f"step_{completed_steps }" 773 | if args.output_dir is not None: 774 | output_dir = os.path.join(args.output_dir, output_dir) 775 | accelerator.save_state(output_dir) 776 | 777 | if completed_steps >= args.max_train_steps: 778 | break 779 | # ------------------ Recording Training Dynamics -------------------- 780 | # 在每一个epoch之后,对train set所有样本再过一遍,记录dynamics 781 | # 每个epoch单独一个文件 782 | # if args.do_recording: 783 | # if accelerator.is_main_process: 784 | # if not os.path.exists(f'dy_log/{args.task_name}/'): 785 | # os.mkdir(f'dy_log/{args.task_name}/') 786 | # if not os.path.exists(f'dy_log/{args.task_name}/{args.model_name_or_path}'): 787 | # os.mkdir(f'dy_log/{args.task_name}/{args.model_name_or_path}') 788 | # log_path = f'dy_log/{args.task_name}/{args.model_name_or_path}/training_dynamics/' 789 | # if not os.path.exists(log_path): 790 | # os.mkdir(log_path) 791 | 792 | # accelerator.wait_for_everyone() # 只在 main process 里面创建文件夹,然后让其他 process 等待 main process 创建完毕 793 | # log_path = f'dy_log/{args.task_name}/{args.model_name_or_path}/training_dynamics/' 794 | # print('-*-*-*- ',log_path, os.path.exists(log_path),accelerator.device) 795 | 796 | # logger.info('---------- Recording Training Dynamics (Epoch %s) -----------'%epoch) 797 | # training_dynamics = [] 798 | # all_ids = [] 799 | # for step, batch in enumerate(tqdm(train_dataloader)): 800 | # # print('- - - - - - - - - - ',len(batch['idx']), accelerator.device) 801 | # idx_list = batch['idx']#.tolist() 802 | # label_list = batch['labels']#.tolist() 803 | # batch = {k:v for k,v in batch.items() if k != 'idx'} 804 | # logits_list = model(**batch).logits#.tolist() # [[],[],[],...] 
batch_size个[] 805 | # # 这里的关键:通过 gather 把每个 GPU上的结果合并 806 | # # 由于在使用多卡训练时,不同卡可能存在样本的重复,同一个卡也会对最后一个batch进行补齐,也会样本重复 807 | # # 使用 gather 的话,就可以按照原来的分配方式,逆着组合回去,就不用你自己处理了 808 | # # gather 之后的,在每个卡上,下述变量里包含的数量,都等同于只使用单卡进行训练时的数量 809 | # # 所以下面的for训练执行完之后,training_dynamics里就包含了全部样本,你在写入文件时,记住只在一个 process 中写入 810 | # idx_list, label_list, logits_list = accelerator.gather((idx_list, label_list, logits_list)) 811 | # # print('idx_list', idx_list.shape, accelerator.device) 812 | # # print('label_list', label_list.shape, accelerator.device) 813 | 814 | # for idx, label, logits in zip(idx_list.tolist(), label_list.tolist(), logits_list.tolist()): 815 | # if idx in all_ids: # 由于 data_loader 可能会对最后一个 batch 进行补全,所以这里要去掉重复的样本 816 | # continue 817 | # all_ids.append(idx) 818 | # record = {'guid': idx, 'logits_epoch_%s'%epoch: logits, 'gold': label, 'device':str(accelerator.device)} 819 | # training_dynamics.append(record) 820 | 821 | # if accelerator.is_main_process: 822 | # print('---- Num of training_dynamics: ',len(training_dynamics),' Device: ', str(accelerator.device)) 823 | # print(len(all_ids),len(list(set(all_ids))),str(accelerator.device)) 824 | # assert os.path.exists(log_path),log_path 825 | # writer = open(log_path + f'dynamics_epoch_{epoch}.jsonl', 'w') 826 | # for record in training_dynamics: 827 | # writer.write(json.dumps(record) + "\n") 828 | # logger.info(f'Epoch {epoch} Saved to [{log_path}]') 829 | # writer.close() 830 | # accelerator.wait_for_everyone() 831 | # ------------------------------------------------------------------------ 832 | 833 | def hct_batch_inference(model, batch, hard_choice=True, T=1, single_expert=None): 834 | model.eval() 835 | with torch.no_grad(): 836 | outputs = model(**batch) 837 | logits = outputs.logits 838 | 839 | if not single_expert: 840 | if hard_choice: 841 | expert_choices = torch.argmax(logits['gate'],dim=1) # expert choice for each sample in the batch 842 | chosen_logits = logits['easy'] * (1-expert_choices).view(-1,1) + logits['hard'] * expert_choices.view(-1,1) 843 | return chosen_logits 844 | else: 845 | weights = F.softmax(logits['gate']/T, dim=1) 846 | # !!! 
847 | if args.more_ambiguous: 848 | weights = torch.where(weights>0.5, 1-torch.abs(weights-0.5), weights) 849 | weighted_logits = logits['easy'] * weights[:,0].view(-1,1) + logits['hard'] * weights[:,1].view(-1,1) 850 | return weighted_logits 851 | else: 852 | assert single_expert in ['easy', 'hard'] 853 | return logits[single_expert] 854 | 855 | # def hct_batch_inference_single(model, batch, expert='easy'): 856 | # model.eval() 857 | # with torch.no_grad(): 858 | # outputs = model(**batch) 859 | # logits = outputs.logits 860 | # return logits[expert] 861 | 862 | 863 | # evaluation (validation set) 864 | # hard inference: 865 | hct_model.eval() 866 | samples_seen = 0 867 | for step, batch in enumerate(eval_dataloader): 868 | batch = {k:v for k,v in batch.items() if k != 'idx'} 869 | logits = hct_batch_inference(hct_model, batch, hard_choice=True, T=args.temperature) # , single_expert='easy' 870 | predictions = logits.argmax(dim=-1) if not is_regression else logits.squeeze() 871 | predictions, references = accelerator.gather((predictions, batch["labels"])) 872 | # If we are in a multiprocess environment, the last batch has duplicates 873 | if accelerator.num_processes > 1: 874 | if step == len(eval_dataloader) - 1: 875 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] 876 | references = references[: len(eval_dataloader.dataset) - samples_seen] 877 | else: 878 | samples_seen += references.shape[0] 879 | metric.add_batch( 880 | predictions=predictions, 881 | references=references, 882 | ) 883 | eval_metric = metric.compute() 884 | logger.info(f"***Evaluation (hard inference) *** epoch {epoch}: {eval_metric}") 885 | 886 | # soft inference: 887 | hct_model.eval() 888 | samples_seen = 0 889 | for step, batch in enumerate(eval_dataloader): 890 | batch = {k:v for k,v in batch.items() if k != 'idx'} 891 | logits = hct_batch_inference(hct_model, batch, hard_choice=False, T=args.temperature) # , single_expert='easy' 892 | predictions = logits.argmax(dim=-1) if not is_regression else logits.squeeze() 893 | predictions, references = accelerator.gather((predictions, batch["labels"])) 894 | # If we are in a multiprocess environment, the last batch has duplicates 895 | if accelerator.num_processes > 1: 896 | if step == len(eval_dataloader) - 1: 897 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] 898 | references = references[: len(eval_dataloader.dataset) - samples_seen] 899 | else: 900 | samples_seen += references.shape[0] 901 | metric.add_batch( 902 | predictions=predictions, 903 | references=references, 904 | ) 905 | eval_metric = metric.compute() 906 | logger.info(f"***Evaluation (soft inference) *** epoch {epoch}: {eval_metric}") 907 | 908 | if args.with_tracking: 909 | accelerator.log( 910 | { 911 | "accuracy" if args.task_name is not None else "glue": eval_metric, 912 | "train_loss": total_loss.item() / len(train_dataloader), 913 | "epoch": epoch, 914 | "step": completed_steps, 915 | }, 916 | step=completed_steps, 917 | ) 918 | 919 | if args.checkpointing_steps == "epoch": 920 | output_dir = f"epoch_{epoch}" 921 | if args.output_dir is not None: 922 | output_dir = os.path.join(args.output_dir, output_dir) 923 | accelerator.save_state(output_dir) 924 | # ============================ End Training Loop ============================ 925 | 926 | 927 | if args.output_dir is not None: 928 | accelerator.wait_for_everyone() 929 | unwrapped_model = accelerator.unwrap_model(hct_model) 930 | unwrapped_model.save_pretrained( 931 | args.output_dir, 
is_main_process=accelerator.is_main_process, save_function=accelerator.save 932 | ) 933 | if accelerator.is_main_process: 934 | tokenizer.save_pretrained(args.output_dir) 935 | # if args.push_to_hub: 936 | # repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) 937 | 938 | if args.task_name == "mnli": 939 | # Final evaluation on mismatched validation set 940 | # eval_dataset = processed_datasets["validation_mismatched"] 941 | eval_dataset = processed_datasets["validation_mismatched"].remove_columns([c for c in raw_datasets["validation_mismatched"].column_names if c not in ['idx']]) 942 | eval_dataloader = DataLoader( 943 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size 944 | ) 945 | eval_dataloader = accelerator.prepare(eval_dataloader) 946 | 947 | hct_model.eval() 948 | for step, batch in enumerate(eval_dataloader): 949 | batch = {k:v for k,v in batch.items() if k != 'idx'} 950 | # outputs = hct_model(**batch) 951 | logits = hct_batch_inference(hct_model, batch, hard_choice=False, T=args.temperature) 952 | predictions = logits.argmax(dim=-1) 953 | metric.add_batch( 954 | predictions=accelerator.gather(predictions), 955 | references=accelerator.gather(batch["labels"]), 956 | ) 957 | eval_metric = metric.compute() 958 | logger.info(f"mnli-mm: {eval_metric}") 959 | 960 | if args.task_name == "snli": 961 | # Final evaluation on the SNLI test set 962 | eval_dataset = processed_datasets["test"].remove_columns([c for c in raw_datasets["test"].column_names if c not in ['idx']]) 963 | eval_dataloader = DataLoader( 964 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size 965 | ) 966 | eval_dataloader = accelerator.prepare(eval_dataloader) 967 | 968 | hct_model.eval() 969 | for step, batch in enumerate(eval_dataloader): 970 | # the batch contains the idx field, which needs to be removed here 971 | batch = {k:v for k,v in batch.items() if k != 'idx'} 972 | logits = hct_batch_inference(hct_model, batch, hard_choice=False, T=args.temperature) 973 | predictions = logits.argmax(dim=-1) 974 | metric.add_batch( 975 | predictions=accelerator.gather(predictions), 976 | references=accelerator.gather(batch["labels"]), 977 | ) 978 | 979 | eval_metric = metric.compute() 980 | logger.info(f"snli-test: {eval_metric}") 981 | 982 | 983 | if args.output_dir is not None: 984 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: 985 | json.dump({"eval_accuracy": eval_metric["accuracy"]}, f) 986 | 987 | 988 | if __name__ == "__main__": 989 | main() --------------------------------------------------------------------------------
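For readers skimming the dump, here is a minimal standalone sketch of how the loss in `HCTForSequenceClassification` (run_glue_hct.py above) combines the gate with the easy/hard experts. All tensors (batch size, logits, confidences) are dummy values made up for illustration; this snippet is not part of the training script.

```python
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

# Dummy batch: 4 samples, 2 classes. The confidences play the role of the
# gate's "easy" probability targets, as in HCTForSequenceClassification.
batch_size, num_labels = 4, 2
logits_gate = torch.randn(batch_size, 2)            # gate logits: [easy, hard]
logits_easy = torch.randn(batch_size, num_labels)   # easy expert
logits_hard = torch.randn(batch_size, num_labels)   # hard expert
labels = torch.randint(0, num_labels, (batch_size,))
confidences = torch.rand(batch_size)                # would come from dataset cartography
T, mu = 1.0, 0.5                                    # --temperature and --mu

# Gate loss: KL between the gate distribution and the (easy, hard) target probs.
hardness_probs = torch.cat([confidences.view(-1, 1), 1 - confidences.view(-1, 1)], dim=1)
gate_loss = F.kl_div(F.log_softmax(logits_gate, dim=-1), hardness_probs, reduction='batchmean')

# Per-sample expert losses, weighted by the temperature-scaled gate weights.
loss_fct = CrossEntropyLoss(reduction='none')
loss_easy = loss_fct(logits_easy, labels)
loss_hard = loss_fct(logits_hard, labels)
gate_weights = F.softmax(logits_gate / T, dim=1)
weighted = torch.cat([loss_easy.view(-1, 1), loss_hard.view(-1, 1)], dim=1) * gate_weights
clf_loss = weighted.mean()

# Final training loss: mu balances the gate loss against the classification loss.
loss = mu * gate_loss + (1 - mu) * clf_loss
print(loss)
```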
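A second sketch, mirroring the `hct_batch_inference` helper used in the evaluation loops: hard inference picks one expert per sample via the gate argmax, while soft inference mixes the two experts with the softmaxed gate weights. The logits below are again dummy values, not real model outputs.

```python
import torch
import torch.nn.functional as F

# Dummy gate/expert outputs for a batch of 3 samples, 2 classes.
logits = {
    "gate": torch.tensor([[2.0, 0.5], [0.1, 1.5], [0.8, 0.7]]),
    "easy": torch.tensor([[1.2, -0.3], [0.4, 0.9], [0.0, 0.2]]),
    "hard": torch.tensor([[-0.5, 0.6], [1.1, -0.2], [0.3, 0.3]]),
}
T = 1.0

# Hard inference: route each sample to a single expert (0 = easy, 1 = hard).
choice = torch.argmax(logits["gate"], dim=1)
hard_logits = logits["easy"] * (1 - choice).view(-1, 1) + logits["hard"] * choice.view(-1, 1)

# Soft inference: blend the experts with the temperature-scaled gate weights.
w = F.softmax(logits["gate"] / T, dim=1)
soft_logits = logits["easy"] * w[:, 0].view(-1, 1) + logits["hard"] * w[:, 1].view(-1, 1)

print(hard_logits.argmax(dim=-1), soft_logits.argmax(dim=-1))
```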
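Finally, `run_glue_hct.py` loads its data with `load_from_disk` from `../datasets/<task>/with_conf/` and asserts that the train split carries a `confidence` column. The snippet below is only an assumed illustration of that on-disk layout: in the real pipeline the confidence values would be computed from the recorded training dynamics, whereas here they are placeholder numbers.

```python
from datasets import load_dataset

# Start from the raw GLUE task and attach a per-example confidence score.
raw = load_dataset("glue", "sst2")

# Illustrative only: a real pipeline would derive these from the recorded
# training dynamics (e.g. mean confidence per training example).
dummy_conf = [0.9 if i % 2 == 0 else 0.1 for i in range(len(raw["train"]))]
raw["train"] = raw["train"].add_column("confidence", dummy_conf)

# run_glue_hct.py expects to find this layout on disk:
raw.save_to_disk("../datasets/sst2/with_conf/")
```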