├── __init__.py
├── HCT
├── __init__.py
├── rm_cache.sh
├── run_hct.sh
├── create_noisy_dataset.ipynb
├── dataset_prepare.ipynb
├── test.ipynb
└── run_glue_hct.py
├── examples
├── DataMap.png
└── DataMapCompare.png
├── .gitignore
├── requirements.txt
├── run_glue.sh
├── .vscode-upload.json
├── plot.sh
├── run_glue_and_record_td.sh
├── experiments.md
├── data_utils_glue.py
├── run_glue_copy.sh
├── run_glue_ambi.sh
├── selection_utils.py
├── data_selection.py
├── data_utils.py
├── README.md
├── dy_filtering.py
├── run_glue.py
└── hardness_classification_prepare.ipynb
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/HCT/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/DataMap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/TrainingDynamics/HEAD/examples/DataMap.png
--------------------------------------------------------------------------------
/examples/DataMapCompare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beyondguo/TrainingDynamics/HEAD/examples/DataMapCompare.png
--------------------------------------------------------------------------------
/HCT/rm_cache.sh:
--------------------------------------------------------------------------------
# Delete the cached files (cache*) of TASK_NAME's train/validation/test splits
# under datasets/$TASK_NAME/with_conf/. Run from the directory containing datasets/.
export TASK_NAME=snli
# Use explicit paths instead of `cd <dir>; rm cache*`: if a `cd` fails, a bare
# `rm cache*` would delete matching files in whatever directory we are in.
# -f: a split that has no cache files yet is not an error.
for split in train validation test; do
    rm -f "datasets/$TASK_NAME/with_conf/$split/"cache*
done
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # basic ignore:
2 | __pycache__/
3 | .idea/
4 | .ipynb_checkpoints/
5 | .DS_Store
6 | .vscode
7 | /tmp
8 |
9 | .empty/
10 | # log
11 | /dy_log
12 | /log
13 |
14 | # data
15 | /datasets
16 |
17 |
18 | # models
19 | saved_models/
20 | HCT/*.weight
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.9.0
2 | datasets==2.3.2
3 | huggingface_hub==0.4.0
4 | jsonx==2020.10.0
5 | matplotlib==3.3.1
6 | numpy==1.19.1
7 | pandas==0.24.2
8 | seaborn==0.11.2
9 | torch==1.7.0
10 | tqdm==4.31.1
11 | transformers==4.18.0
12 |
--------------------------------------------------------------------------------
/run_glue.sh:
--------------------------------------------------------------------------------
# Standard GLUE fine-tuning of $MODEL on $TASK_NAME with run_glue.py.
# Suggested models (set MODEL to one of these):
# prajjwal1/bert-tiny
# distilbert-base-cased
# roberta-large

export TASK_NAME=sst2
export MODEL=prajjwal1/bert-tiny
# --nproc_per_node 8: number of processes/GPUs to launch.
# NOTE: the last flag ends with '\', so the commented-out --output_dir line below
# is joined to the command as a trailing comment; uncomment it to set an output dir.
python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \
--model_name_or_path $MODEL \
--task_name $TASK_NAME \
--max_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 10 \
# --output_dir tmp/$TASK_NAME/
--------------------------------------------------------------------------------
/.vscode-upload.json:
--------------------------------------------------------------------------------
1 | [{
2 | "name":"",
3 | "host": "",
4 | "port": 22,
5 | "username": "",
6 | "password": "",
7 | "remotePath": "",
8 | "localPath": "",
9 | "disable": false,
10 | "private_key": "~/.ssh/id_rsa"
11 | },{
12 | "name":"",
13 | "host": "",
14 | "port": 22,
15 | "username": "",
16 | "password": "",
17 | "remotePath": "",
18 | "localPath": "",
19 | "disable": false,
20 | "private_key": "~/.ssh/id_rsa"
21 | }]
--------------------------------------------------------------------------------
/plot.sh:
--------------------------------------------------------------------------------
1 | # task_name, model 这俩参数,只是用于展示在plot上
2 | # model_dir 是放 training_dynamics 文件夹的那个目录
3 | # plots_dir 是绘制好的 plot 存放的目录
4 | # burn_out 计算多少轮的 dynamics
5 |
6 | # bert-tiny
7 | # distilbert-base-cased
8 | # roberta-large
9 |
10 | export TASK_NAME=rte-noisy-0.4
11 | export MODEL=bert-base-cased
12 | python -m dy_filtering \
13 | --plot \
14 | --task_name $TASK_NAME \
15 | --model_dir dy_log/$TASK_NAME/$MODEL \
16 | --plots_dir dy_log/$TASK_NAME/$MODEL \
17 | --model $MODEL \
18 | --burn_out 5
19 |
20 |
--------------------------------------------------------------------------------
/run_glue_and_record_td.sh:
--------------------------------------------------------------------------------
1 |
2 | # ====================== Recording Training Dynamics: ===========
3 | # `--nproc_per_node N` means the number of GPUs to use
4 | # Suggested models:
5 | # prajjwal1/bert-tiny
6 | # distilbert-base-cased
7 | # bert-base-cased
8 | # roberta-large
9 |
10 |
11 | export TASK_NAME=mnli
12 | export MODEL=distilbert-base-cased
13 |
14 | # CUDA_VISIBLE_DEVICES=7 python run_glue.py \
15 | python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \
16 | --seed 5 \
17 | --model_name_or_path $MODEL \
18 | --task_name $TASK_NAME \
19 | --max_length 128 \
20 | --per_device_train_batch_size 32 \
21 | --learning_rate 2e-5 \
22 | --num_train_epochs 5 \
23 | --do_recording \
24 |
--------------------------------------------------------------------------------
/experiments.md:
--------------------------------------------------------------------------------
1 | # sst2
2 | ```
3 | Main params:
4 | - GPUs: 8
5 | - max_length 128
6 | - batch_size 32
7 | - epochs 5
8 | ```
9 |
10 | - **`bert-base-cased`**
11 | - 100% train: 0.9105504587155964
12 | - 33% (easy): 0.8772935779816514
13 | - 33% (hard): 0.908256880733945
14 | - 33% (ambiguous): 0.9025229357798165
15 |
16 | ---
17 |
18 | - **`distilbert-base-cased`**
19 | - 100% train: 0.9094036697247706
20 | - 33% (easy): 0.856651376146789
21 | - 33% (hard): 0.8944954128440367
22 | - 33% (ambiguous): 0.8956422018348624
23 |
24 | ---
25 |
26 | - **`bert-tiny`**
27 | - 100% train: 0.788990825688
28 | - 33% (easy): 0.694954128440367
29 | - 33% (hard): 0.37155963302752293
30 | - 33% (ambiguous): 0.5022935779816514
31 |
32 | use the data selection by `distilbert-base-cased`
33 | - 33% (easy): 0.5389908256880734
34 | - 33% (hard): 0.5871559633027523
35 | - 33% (ambiguous): 0.6020642201834863
36 |
37 |
--------------------------------------------------------------------------------
/HCT/run_hct.sh:
--------------------------------------------------------------------------------
1 | # ====================== Basic GLUE tasks training and evaluation:
2 | # the following is the standard training for GLUE tasks
3 |
4 | # export TASK_NAME=sst2
5 | # export MODEL=distilbert-base-cased
6 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \
7 | # --model_name_or_path $MODEL \
8 | # --task_name $TASK_NAME \
9 | # --max_length 128 \
10 | # --per_device_train_batch_size 32 \
11 | # --learning_rate 2e-5 \
12 | # --num_train_epochs 5 \
13 |
14 |
15 | # ————> Suggested models:
16 | # prajjwal1/bert-tiny
17 | # distilbert-base-cased
18 | # bert-base-cased
19 | # roberta-large
20 |
21 | # =========== Train GLUE tasks =============
22 | # https://huggingface.co/datasets/glue
23 |
# Train with run_glue_hct.py on a single GPU (device 5); for an 8-GPU run, use the
# commented torch.distributed.launch line instead of the CUDA_VISIBLE_DEVICES one.
# NOTE: keep a space before every trailing '\'. With "--temperature 1\" the value
# could fuse with an unindented next line into "1--mu".
# --more_ambiguous has no trailing '\', so the commented-out flags below it are
# not passed. --temperature/--mu/--more_ambiguous are read by run_glue_hct.py.
export TASK_NAME=rte
export MODEL=bert-base-cased
# python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue_hct.py \
CUDA_VISIBLE_DEVICES=5 python run_glue_hct.py \
--seed 5 \
--model_name_or_path $MODEL \
--task_name $TASK_NAME \
--max_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 4e-5 \
--num_train_epochs 10 \
--temperature 1 \
--mu 0.5 \
--more_ambiguous
38 | # --hard_with_ls \
39 | # --ls_weight 0.1 \
40 | # --hard_inference
41 |
42 | # --with_data_selection \
43 | # --data_selection_region ambiguous \
44 | # --output_dir tmp/$TASK_NAME/
45 |
46 |
--------------------------------------------------------------------------------
/data_utils_glue.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | import re
4 | import tqdm
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | def convert_string_to_unique_number(string: str) -> int:
10 | """
11 | Hack to convert SNLI ID into a unique integer ID, for tensorizing.
12 | """
13 | id_map = {'e': '0', 'c': '1', 'n': '2'}
14 |
15 | # SNLI-specific hacks.
16 | if string.startswith('vg_len'):
17 | code = '555'
18 | elif string.startswith('vg_verb'):
19 | code = '444'
20 | else:
21 | code = '000'
22 |
23 | try:
24 | number = int(code + re.sub(r"\D", "", string) + id_map.get(string[-1], '3'))
25 | except:
26 | number = random.randint(10000, 99999)
27 | logger.info(f"Cannot find ID for {string}, using random number {number}.")
28 | return number
29 |
30 |
def read_glue_tsv(file_path: str,
                  guid_index: int,
                  label_index: int = -1,
                  guid_as_int: bool = False):
    """
    Reads TSV files for GLUE-style text classification tasks.

    The first line is the header. Each following line is keyed by the value in
    column `guid_index`, or by its line number when `guid_index` is None.
    Lines whose label (column `label_index`) is "-" or empty are skipped, as
    are lines whose ID was already seen. Lines with more fields than the header
    (e.g. extra annotator labels in SNLI/MNLI) are truncated to the header
    width, with the label moved into the last column.

    Args:
        file_path: path of the TSV file (read as UTF-8).
        guid_index: column holding the example ID, or None to use line numbers.
        label_index: column holding the gold label (default: last column).
        guid_as_int: if True, convert string IDs to integers with
            `convert_string_to_unique_number`; integer IDs (line numbers) are
            kept as they are.

    Returns:
        - a mapping between the example ID and the entire line as a string.
        - the header of the TSV file.

    Raises:
        ValueError: if the file is empty (no header line).
    """
    tsv_dict = {}
    header = None
    field_names = []

    with open(file_path, 'r', encoding='utf-8') as tsv_file:
        # Materialise the lines so tqdm can display a total.
        lines = tsv_file.readlines()

    i = -1
    for i, line in enumerate(tqdm.tqdm(lines)):
        if i == 0:
            header = line.strip()
            field_names = header.split("\t")
            continue

        fields = line.strip().split("\t")
        label = fields[label_index]
        if len(fields) > len(field_names):
            # SNLI / MNLI fields sometimes contain multiple annotator labels.
            # Ignore all except the gold label.
            reformatted_fields = fields[:len(field_names) - 1] + [label]
            assert len(reformatted_fields) == len(field_names)
            reformatted_line = "\t".join(reformatted_fields)
        else:
            reformatted_line = line.strip()

        if label == "-" or label == "":
            logger.info(f"Skippping line: {line}")
            continue

        guid = i if guid_index is None else fields[guid_index]  # PairID.
        if guid in tsv_dict:
            logger.info(f"Found clash in IDs ... skipping example {guid}.")
            continue
        tsv_dict[guid] = reformatted_line.strip()

    if header is None:
        # Without this check an empty file surfaced as an UnboundLocalError.
        raise ValueError(f"{file_path} is empty: expected a header line.")

    logger.info(f"Read {len(tsv_dict)} valid examples, with unique IDS, out of {i} from {file_path}")
    if guid_as_int:
        # Line-number IDs are already ints; only string IDs need converting
        # (convert_string_to_unique_number calls str methods on its argument).
        return {(k if isinstance(k, int) else int(convert_string_to_unique_number(k))): v
                for k, v in tsv_dict.items()}, header
    return tsv_dict, header
--------------------------------------------------------------------------------
/run_glue_copy.sh:
--------------------------------------------------------------------------------
1 | # ====================== Basic GLUE tasks training and evaluation:
2 | # the following is the standard training for GLUE tasks
3 |
4 | # export TASK_NAME=sst2
5 | # export MODEL=distilbert-base-cased
6 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \
7 | # --model_name_or_path $MODEL \
8 | # --task_name $TASK_NAME \
9 | # --max_length 128 \
10 | # --per_device_train_batch_size 32 \
11 | # --learning_rate 2e-5 \
12 | # --num_train_epochs 5 \
13 |
14 |
15 | # ————> Training with Data Selection:
16 | # after you run `data_selection.py` and obtain the `three_regions_data_indices.json` file
17 | # you can train a GLUE classifier again with your specified data selection
18 | # set `--with_data_selection` to turn on data selection
19 | # set `--data_selection_region [region]` to specify the region, choices are "easy", "hard", and "ambiguous"
20 |
21 |
22 | # ————> Suggested models:
23 | # prajjwal1/bert-tiny
24 | # distilbert-base-cased
25 | # bert-base-cased
26 | # roberta-large
27 |
28 | # =========== Train GLUE tasks =============
29 | # https://huggingface.co/datasets/glue
30 |
31 | export TASK_NAME=boolq
32 | export MODEL=bert-base-cased
33 | # CUDA_VISIBLE_DEVICES=4 python run_glue.py \
34 | python -m torch.distributed.launch --nproc_per_node 1 --use_env run_glue.py \
35 | --seed 5 \
36 | --model_name_or_path $MODEL \
37 | --task_name $TASK_NAME \
38 | --output_dir saved_models/$TASK_NAME/$MODEL \
39 | --resume_from_checkpoint saved_models/$TASK_NAME/$MODEL/epoch_9 \
40 | --checkpointing_steps epoch \
41 | --max_length 128 \
42 | --per_device_train_batch_size 32 \
43 | --learning_rate 5e-5 \
44 | --num_train_epochs 10 \
45 | --continue_train \
46 | --continue_num_train_epochs 5 \
47 | --log_name ambiguous \
48 | --selected_indices_filename selected_indices_ambi_top0.33_balance_from500 \
49 | # --do_lwf \
50 |
51 |
52 |
53 |
54 | # --resume_from_checkpoint saved_models/$TASK_NAME/$MODEL/epoch_4 \ # once set, the model saved at that epoch is loaded directly
55 |
56 | # --max_train_steps 10 \
57 | # --with_data_selection \
58 | # --data_selection_region ambiguous \
59 | # --output_dir tmp/$TASK_NAME/
60 |
61 |
62 | # =========== Use Your Own Dataset ========
63 |
64 | # export MODEL=bert-base-cased
65 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \
66 | # --model_name_or_path $MODEL \
67 | # --max_length 128 \
68 | # --per_device_train_batch_size 32 \
69 | # --learning_rate 2e-5 \
70 | # --num_train_epochs 5 \
71 | # --train_file datasets/qnli-easy-hard_train.csv \
72 | # --validation_file datasets/qnli-easy-hard_valid.csv
73 |
74 |
75 |
76 | # cd ../K2T
77 | # sh oc.sh
--------------------------------------------------------------------------------
/run_glue_ambi.sh:
--------------------------------------------------------------------------------
1 | # ====================== Basic GLUE tasks training and evaluation:
2 | # the following is the standard training for GLUE tasks
3 |
4 | # export TASK_NAME=sst2
5 | # export MODEL=distilbert-base-cased
6 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \
7 | # --model_name_or_path $MODEL \
8 | # --task_name $TASK_NAME \
9 | # --max_length 128 \
10 | # --per_device_train_batch_size 32 \
11 | # --learning_rate 2e-5 \
12 | # --num_train_epochs 5 \
13 |
14 |
15 | # ————> Training with Data Selection:
16 | # after you run `data_selection.py` and obtain the `three_regions_data_indices.json` file
17 | # you can train a GLUE classifier again with your specified data selection
18 | # set `--with_data_selection` to turn on data selection
19 | # set `--data_selection_region [region]` to specify the region, choices are "easy", "hard", and "ambiguous"
20 |
21 |
22 | # ————> Suggested models:
23 | # prajjwal1/bert-tiny
24 | # distilbert-base-cased
25 | # bert-base-cased
26 | # roberta-large
27 |
28 | # =========== Train GLUE tasks =============
29 | # https://huggingface.co/datasets/glue
30 |
31 | export TASK_NAME=mnli
32 | export MODEL=roberta-large
33 | # CUDA_VISIBLE_DEVICES=4 python run_glue.py \
34 | python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \
35 | --seed 5 \
36 | --model_name_or_path $MODEL \
37 | --task_name $TASK_NAME \
38 | --checkpointing_steps epoch \
39 | --resume_from_checkpoint saved_models/$TASK_NAME/$MODEL/epoch_4 \
40 | --max_length 128 \
41 | --per_device_train_batch_size 32 \
42 | --learning_rate 2e-5 \
43 | --num_train_epochs 5 \
44 | --log_name ambiguous \
45 | --continue_train_with_sample_loss \
46 | --continue_train \
47 | --continue_num_train_epochs 5 \
48 | # --selected_indices_filename selected_indices_ambi_top0.1_balance+ \
49 | # --do_lwf \
50 |
51 |
52 |
53 | # --output_dir saved_models/$TASK_NAME/$MODEL \
54 | # --resume_from_checkpoint saved_models/$TASK_NAME/$MODEL/epoch_4 \ # once set, the model saved at that epoch is loaded directly
55 |
56 | # --max_train_steps 10 \
57 | # --with_data_selection \
58 | # --data_selection_region ambiguous \
59 | # --output_dir tmp/$TASK_NAME/
60 |
61 |
62 | # =========== Use Your Own Dataset ========
63 |
64 | # export MODEL=bert-base-cased
65 | # python -m torch.distributed.launch --nproc_per_node 8 --use_env run_glue.py \
66 | # --model_name_or_path $MODEL \
67 | # --max_length 128 \
68 | # --per_device_train_batch_size 32 \
69 | # --learning_rate 2e-5 \
70 | # --num_train_epochs 5 \
71 | # --train_file datasets/qnli-easy-hard_train.csv \
72 | # --validation_file datasets/qnli-easy-hard_valid.csv
73 |
74 |
75 |
76 | # cd ../K2T
77 | # sh oc.sh
--------------------------------------------------------------------------------
/selection_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import numpy as np
4 | import os
5 | import pandas as pd
6 | import tqdm
7 |
8 | from typing import List
9 |
# NOTE(review): basicConfig runs at import time, so importing this module configures
# the root logger (INFO level, timestamped format) of the whole program unless the
# root logger already has handlers; consider moving this to the entry-point script.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.INFO
)
# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)
14 |
15 |
def log_training_dynamics(output_dir: str,
                          epoch: int,
                          train_ids: List[int],
                          train_logits: List[List[float]],
                          train_golds: List[int]):
    """
    Save training dynamics (logits) from given epoch as records of a `.jsonl` file.

    Writes `<output_dir>/training_dynamics/dynamics_epoch_<epoch>.jsonl` with one
    JSON record per example, keyed "guid", "logits_epoch_<epoch>" and "gold"
    (the keys `read_training_dynamics` reads back by default).

    Args:
        output_dir: directory under which `training_dynamics/` is created.
        epoch: epoch number, used in the file name and in the logits key.
        train_ids: example IDs.
        train_logits: logits vector of each example, aligned with `train_ids`.
        train_golds: gold label of each example, aligned with `train_ids`.
    """
    td_df = pd.DataFrame({"guid": train_ids,
                          f"logits_epoch_{epoch}": train_logits,
                          "gold": train_golds})

    logging_dir = os.path.join(output_dir, "training_dynamics")
    # exist_ok=True instead of exists()-then-makedirs(): the separate check can race
    # and raise FileExistsError when more than one process logs at the same time.
    os.makedirs(logging_dir, exist_ok=True)
    epoch_file_name = os.path.join(logging_dir, f"dynamics_epoch_{epoch}.jsonl")
    td_df.to_json(epoch_file_name, lines=True, orient="records")
    logger.info(f"Training Dynamics logged to {epoch_file_name}")
35 |
36 |
def read_training_dynamics(model_dir: str,
                           strip_last: bool = False,
                           id_field: str = "guid",
                           burn_out: int = None):
    """
    Given path to logged training dynamics, merge stats across epochs.

    Reads `<model_dir>/training_dynamics/dynamics_epoch_<k>.jsonl` for
    k = 0 .. num_epochs - 1. num_epochs is `burn_out` when it is truthy,
    otherwise the number of `dynamics_epoch_*.jsonl` files in that folder.

    Args:
        model_dir: directory containing the `training_dynamics` folder.
        strip_last: if True, drop the last character of every ID.
        id_field: record key holding the example ID.
        burn_out: number of leading epochs to read; falsy means all of them.

    Returns:
        - Dict between ID of a train instances and its gold label, and the list of logits across epochs.

    Raises:
        FileNotFoundError: if an expected epoch file is missing.
        ValueError: if an ID first appears after epoch 0.
    """
    train_dynamics = {}

    td_dir = os.path.join(model_dir, "training_dynamics")
    # Count only the per-epoch dumps: other files (e.g. td_metrics.jsonl written by
    # the plotting step) can live in the same folder and would inflate the count,
    # making the loop below look for an epoch file that does not exist.
    num_epochs = sum(
        1 for name in os.listdir(td_dir)
        if name.startswith("dynamics_epoch_") and name.endswith(".jsonl")
        and os.path.isfile(os.path.join(td_dir, name))
    )
    if burn_out:
        num_epochs = burn_out

    logger.info(f"Reading {num_epochs} files from {td_dir} ...")
    for epoch_num in tqdm.tqdm(range(num_epochs)):
        epoch_file = os.path.join(td_dir, f"dynamics_epoch_{epoch_num}.jsonl")
        if not os.path.exists(epoch_file):
            raise FileNotFoundError(f"Missing training dynamics file: {epoch_file}")

        logits_key = f"logits_epoch_{epoch_num}"
        with open(epoch_file, "r", encoding="utf-8") as infile:
            for line in infile:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                record = json.loads(line)
                guid = record[id_field] if not strip_last else record[id_field][:-1]
                if guid not in train_dynamics:
                    # Every example must already be present in epoch 0.
                    if epoch_num != 0:
                        raise ValueError(
                            f"ID {guid!r} first appears in epoch {epoch_num}, not in epoch 0.")
                    train_dynamics[guid] = {"gold": record["gold"], "logits": []}
                train_dynamics[guid]["logits"].append(record[logits_key])

    logger.info(f"Read training dynamics for {len(train_dynamics)} train instances.")
    return train_dynamics
70 |
--------------------------------------------------------------------------------
/data_selection.py:
--------------------------------------------------------------------------------
# Split the training set into hard / easy / ambiguous regions from its training
# dynamics (only applied to the training set). Example:
# python data_selection.py --task_name qnli --model_name bert-base-cased --proportion 0.5 --burn_out 4
import json
import random
# Seed the module-level RNG so the shuffles done in data_selection() are reproducible.
random.seed(1)
import argparse
from dy_filtering import read_training_dynamics, compute_train_dy_metrics

# --task_name / --model_name select the dy_log/{task}/{model}/ directory that is
# read from and written to.
# --proportion: fraction of the training set kept for each region (default 0.33).
# --burn_out: forwarded to compute_train_dy_metrics.
parser = argparse.ArgumentParser()
parser.add_argument("--task_name", type=str)
parser.add_argument("--model_name", type=str)
parser.add_argument("--proportion", type=float, default=0.33)
parser.add_argument("--burn_out", type=int)
args = parser.parse_args()

TASK_NAME = args.task_name
MODEL = args.model_name
PROPORTION = args.proportion

# Read the per-epoch training dynamics and merge them together
td = read_training_dynamics(f'dy_log/{TASK_NAME}/{MODEL}/')
# Compute the metrics and turn them into a DataFrame (td_df is used by data_selection below)
td_df, _ = compute_train_dy_metrics(td, burn_out=args.burn_out)
24 |
25 |
def consider_ascending_order(filtering_metric: str) -> bool:
    """
    Return the sort direction that ranks the most `valuable` training examples first.

    True means sort ascending (smallest values first), False means descending.

    Raises:
        NotImplementedError: if `filtering_metric` is not one of the known metrics.
    """
    direction_table = (
        ("variability", False),
        ("confidence", True),
        ("threshold_closeness", False),
        ("forgetfulness", False),
        ("correctness", True),
    )
    for known_metric, is_ascending in direction_table:
        if filtering_metric == known_metric:
            return is_ascending
    raise NotImplementedError(f"Filtering based on {filtering_metric} not implemented!")
42 |
43 |
44 |
def data_selection(metric, select_worst, proportion, shuffle=True, df=None):
    """
    Select the top `proportion` of training examples ranked by `metric`.

    Args:
        metric: metrics column to rank by; must be accepted by
            `consider_ascending_order`.
        select_worst: if True, flip the ranking so the examples at the opposite
            end of `metric` are selected.
        proportion: fraction of rows to keep; the count is int(proportion * n_rows).
        shuffle: if True, shuffle the selected guids with the module-level `random` RNG.
        df: metrics DataFrame with a 'guid' column; defaults to the module-level `td_df`.

    Returns:
        dict with 'indices' (list of selected guids, shuffled if requested) and
        'df' (the selected rows in ranked order).
    """
    if df is None:
        df = td_df
    # Compute the base direction once and flip it for the opposite end.
    ascending = consider_ascending_order(metric)
    if select_worst:
        ascending = not ascending
    sorted_df = df.sort_values(by=metric, ascending=ascending)
    selected_df = sorted_df.head(n=int(proportion * len(sorted_df)))
    indices = list(selected_df['guid'])
    if shuffle:
        random.shuffle(indices)
    return {'indices': indices, 'df': selected_df}
55 |
56 |
57 | """
58 | hard-to-learn: METRIC = 'confidence'
59 | easy-to-learn: METRIC = 'confidence', SELECT_WORST = True
60 | ambiguoug: METRIC = 'variability'
61 | """
62 |
63 | three_regions_data_indices = {'hard':data_selection('confidence', False, PROPORTION)['indices'],
64 | 'easy':data_selection('confidence', True, PROPORTION)['indices'],
65 | 'ambiguous':data_selection('variability', False, PROPORTION)['indices']}
66 |
67 | with open(f'dy_log/{TASK_NAME}/{MODEL}/three_regions_data_indices.json','w') as f:
68 | f.write(json.dumps(three_regions_data_indices))
69 |
70 | # 然后可以直接跑glue任务,在选择训练集的时候,使用select函数来指定对应样本即可:
71 | """ e.g.
72 | from datasets import load_dataset
73 | raw_datasets = load_dataset('glue','sst2')
74 | easy_train_set = raw_datasets['train'].select(three_regions_data_indices['easy'])
75 | """
76 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Utilities for data handling.
4 | """
5 | import json
6 | import logging
7 | import os
8 | import pandas as pd
9 | import shutil
10 |
11 | from typing import Dict
12 |
13 | from data_utils_glue import read_glue_tsv
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
def read_data(file_path: str,
              task_name: str,
              guid_as_int: bool = False):
    """
    Read a task-specific dataset from its GLUE-style TSV file.

    Column used as the example ID per task:
      - MNLI: column 0, SNLI: column 2 (both honour `guid_as_int`);
      - WINOGRANDE, QNLI: column 0 (`guid_as_int` is ignored).

    Returns whatever `read_glue_tsv` returns: (ID -> line mapping, header).

    Raises:
        NotImplementedError: for any other `task_name`.
    """
    logger.warning("Data reading only works when data is in TSV format, "
                   " and last column as classification label.")

    if task_name in ("MNLI", "SNLI"):
        id_column = 0 if task_name == "MNLI" else 2
        return read_glue_tsv(file_path,
                             guid_index=id_column,
                             guid_as_int=guid_as_int)
    if task_name in ("WINOGRANDE", "QNLI"):
        return read_glue_tsv(file_path,
                             guid_index=0)
    raise NotImplementedError(f"Reader for {task_name} not implemented.")
44 |
45 |
def convert_tsv_entries_to_dataframe(tsv_dict: Dict, header: str) -> pd.DataFrame:
    """
    Converts entries from TSV file to Pandas DataFrame for faster processing.

    Args:
        tsv_dict: mapping whose values are tab-separated lines (keys are ignored),
            e.g. the first value returned by `read_glue_tsv`.
        header: tab-separated column names.

    Returns:
        DataFrame with one row per value of `tsv_dict` (in iteration order) and
        one string column per header field.

    Raises:
        ValueError: if a line does not have exactly as many fields as `header`.
    """
    header_fields = header.strip().split("\t")
    # Distinct loop names: the original loop rebound the `header` parameter.
    columns = {column_name: [] for column_name in header_fields}

    for line in tsv_dict.values():
        fields = line.strip().split("\t")
        if len(fields) != len(header_fields):
            raise ValueError(
                f"Expected {len(header_fields)} fields but got {len(fields)}: {line!r}")
        for column_name, value in zip(header_fields, fields):
            columns[column_name].append(value)

    return pd.DataFrame(columns, columns=header_fields)
61 |
62 |
def copy_dev_test(task_name: str,
                  from_dir: os.path,
                  to_dir: os.path,
                  extension: str = ".tsv"):
    """
    Copy the development and test splits (for data selection experiments)
    from `from_dir` to `to_dir`, keeping their file names.

    MNLI uses dev_matched.tsv / dev_mismatched.tsv; SNLI, QNLI and WINOGRANDE
    use dev<extension> / test<extension>. The dev file is handled first.

    Raises:
        NotImplementedError: for any other `task_name`.
        ValueError: if a source file does not exist.
    """
    if task_name == "MNLI":
        split_names = ("dev_matched.tsv", "dev_mismatched.tsv")
    elif task_name in ["SNLI", "QNLI", "WINOGRANDE"]:
        split_names = (f"dev{extension}", f"test{extension}")
    else:
        raise NotImplementedError(f"Logic for {task_name} not implemented.")

    for split_name in split_names:
        source_path = os.path.join(from_dir, split_name)
        if not os.path.exists(source_path):
            raise ValueError(f"No file found at {source_path}")
        shutil.copyfile(source_path, os.path.join(to_dir, split_name))
90 |
91 |
def read_jsonl(file_path: str, key: str = "pairID"):
    """
    Load a JSON-lines file with pandas.

    If `key` is truthy, return a dict mapping each record's `key` value to the
    record (a later duplicate overwrites an earlier one); otherwise return the
    list of record dicts.
    """
    frame = pd.read_json(file_path, lines=True)
    rows = frame.to_dict('records')
    logger.info(f"Read {len(rows)} JSON records from {file_path}.")

    if not key:
        return rows
    assert key in frame.columns
    return {row[key]: row for row in rows}
106 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TrainingDynamics
2 | > Re-implementation of the paper [Dataset Cartography: Mapping and Diagnosing Datasets with Training Dynamics (EMNLP-20)](https://aclanthology.org/2020.emnlp-main.746/). This project is mainly based on [AllenAI's Dataset Cartography](https://github.com/allenai/cartography) project, where the model outputs (logits) of each sample are recorded after every training epoch. Based on these records, training dynamics (prediction confidence, variability, etc.) are computed to plot the Data Cartography, which visualizes the distribution of all training samples. However, the [original repo](https://github.com/allenai/cartography) hasn't been maintained for a long time. In this repo, we use the **latest version** of packages to reimplement the Dataset Cartography, as well as some other extensions based on the training dynamics.
3 |
4 | ## Basic requirements:
5 | - transformers==4.18.0
6 | - torch==1.7.0
7 | - datasets==2.3.2
8 | - accelerate==0.9.0
9 |
10 | For more requirements, see `requirements.txt`.
11 |
12 | ## Usage:
13 | For example, to record the training dynamics of the SST2 dataset (a sentiment classification task from GLUE), we do the following steps:
14 |
15 | 1. **Run `sh run_glue_and_record_td.sh` to obtain the training dynamics.**
16 |
17 | Specify the `TASK_NAME` (here we choose `sst2`), `MODEL` you want to use and num of epochs to train the classifier.
18 |
19 | The following information will be recorded during training:
20 | - 'guid': the id of the sample
21 | - 'logits_epoch_{epoch}': output logits vector of the current sample
22 | - 'gold': the true label (index)
23 |
24 | After training, we can find the log files in `./dy_log/{TASK_NAME}/{MODEL}/training_dynamics` directory like:
25 | ```shell
26 | dynamics_epoch_0.jsonl
27 | dynamics_epoch_1.jsonl
28 | dynamics_epoch_2.jsonl
29 | ...
30 | ```
31 | each file contains records like:
32 | ```shell
33 | {"guid": 50325, "logits_epoch_0": [2.943110942840576, -2.2836594581604004], "gold": 0, "device": "cuda:0"}
34 | {"guid": 42123, "logits_epoch_0": [-2.7155513763427734, 3.249767541885376], "gold": 1, "device": "cuda:0"}
35 | {"guid": 42936, "logits_epoch_0": [-1.1907235383987427, 2.1173453330993652], "gold": 1, "device": "cuda:0"}
36 | ...
37 | ```
38 |
39 | 2. **Run `sh plot.sh` to plot the data cartography based on the recorded training dynamics.**
40 |
41 | In `plot.sh`, we can specify the TASK_NAME and MODEL, which are used to determine the path of the training dynamics. First, the log files from each epoch are collected together; then several metrics (confidence, variability, correctness, forgetfulness, etc.) are calculated and saved into a single file named 'td_metrics.jsonl' (in the save directory `./dy_log/{TASK}/{MODEL}/training_dynamics`):
42 |
43 | ```shell
44 | {"guid":50325,"index":0,"threshold_closeness":0.0039580798,"confidence":0.9960261285,"variability":0.0012847629,"correctness":4,"forgetfulness":0}
45 | {"guid":42123,"index":1,"threshold_closeness":0.0012448987,"confidence":0.9987535477,"variability":0.0007707975,"correctness":4,"forgetfulness":0}
46 | {"guid":42936,"index":2,"threshold_closeness":0.0396512556,"confidence":0.958637923,"variability":0.0095242939,"correctness":4,"forgetfulness":0}
47 | ...
48 | ```
49 |
50 | Then, a data map (dataset cartography) is plotted based on these metrics:
51 | 
52 |
53 | ## Data Selection
54 | After recording the training dynamics, we can re-train the model by selecting a subset (e.g. use only the ambiguous samples for training).
55 | For example, for `sst2` task and `bert-tiny` model, just run:
56 | ```shell
57 | python data_selection.py --task_name sst2 --model_name bert-tiny --burn_out 4
58 | ```
59 | then you can get a json file at `dy_log/sst2/bert-tiny/three_regions_data_indices.json`
60 |
61 | then, run `sh run_glue.sh` by adding `--with_data_selection` and `--data_selection_region [region]`.
62 |
63 | More details see comments in `run_glue.sh`.
64 |
65 | ## Other Extensions:
66 | Apart from the above usage, we can also compare two models (e.g. a strong model and a weak model) by computing the change in their dynamics. For example, we train a weak model (BERT-tiny) and a strong model (RoBERTa-large) on the SST2 dataset and plot their difference:
67 |
68 |
69 |
70 |
71 | You can find more detailed usage of this repo in our notebook `plot_demo.ipynb`.
72 |
73 | ---
74 |
75 | *Have fun and feel free to give your feedback :)*
76 |
77 | Citation:
78 | ```
79 | @inproceedings{swayamdipta-etal-2020-dataset,
80 | title = "Dataset Cartography: Mapping and Diagnosing Datasets with Training Dynamics",
81 | author = "Swayamdipta, Swabha and
82 | Schwartz, Roy and
83 | Lourie, Nicholas and
84 | Wang, Yizhong and
85 | Hajishirzi, Hannaneh and
86 | Smith, Noah A. and
87 | Choi, Yejin",
88 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
89 | month = nov,
90 | year = "2020",
91 | address = "Online",
92 | publisher = "Association for Computational Linguistics",
93 | url = "https://aclanthology.org/2020.emnlp-main.746",
94 | doi = "10.18653/v1/2020.emnlp-main.746",
95 | pages = "9275--9293",
96 | }
97 | ```
--------------------------------------------------------------------------------
/HCT/create_noisy_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from datasets import load_dataset"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stderr",
19 | "output_type": "stream",
20 | "text": [
21 | "Reusing dataset glue (/home/v-biyangguo/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
22 | ]
23 | },
24 | {
25 | "data": {
26 | "application/vnd.jupyter.widget-view+json": {
27 | "model_id": "a057b84c31934616ac67c2d40c9af4ba",
28 | "version_major": 2,
29 | "version_minor": 0
30 | },
31 | "text/plain": [
32 | " 0%| | 0/3 [00:00, ?it/s]"
33 | ]
34 | },
35 | "metadata": {},
36 | "output_type": "display_data"
37 | }
38 | ],
39 | "source": [
40 | "dataset_name = 'rte'\n",
41 | "raw_datasets = load_dataset('glue', dataset_name)"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [
49 | {
50 | "data": {
51 | "text/plain": [
52 | "Counter({1: 1241, 0: 1249})"
53 | ]
54 | },
55 | "execution_count": 3,
56 | "metadata": {},
57 | "output_type": "execute_result"
58 | }
59 | ],
60 | "source": [
61 | "from collections import Counter\n",
62 | "orig_labels = raw_datasets['train']['label']\n",
63 | "c = Counter(orig_labels)\n",
64 | "c"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 9,
70 | "metadata": {},
71 | "outputs": [
72 | {
73 | "data": {
74 | "text/plain": [
75 | "(996, [1577, 1722, 165, 1060, 2094, 1990, 1658, 1242, 1952, 1466])"
76 | ]
77 | },
78 | "execution_count": 9,
79 | "metadata": {},
80 | "output_type": "execute_result"
81 | }
82 | ],
83 | "source": [
84 | "# making some noise to the TRAIN set\n",
85 | "import random\n",
86 | "random.seed(0)\n",
87 | "\n",
88 | "shuffled_ids = random.sample(range(len(orig_labels)), len(orig_labels))\n",
89 | "NOISE_RATIO = 0.4\n",
90 | "noisy_ids = shuffled_ids[:int(NOISE_RATIO * len(shuffled_ids))]\n",
91 | "len(noisy_ids), noisy_ids[:10]"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 10,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "new_labels = orig_labels[:]\n",
101 | "for i in noisy_ids:\n",
102 | " new_labels[i] = 0 if new_labels[i] == 1 else 1"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 11,
108 | "metadata": {},
109 | "outputs": [
110 | {
111 | "data": {
112 | "text/plain": [
113 | "([0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1],\n",
114 | " '\\n',\n",
115 | " [1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1])"
116 | ]
117 | },
118 | "execution_count": 11,
119 | "metadata": {},
120 | "output_type": "execute_result"
121 | }
122 | ],
123 | "source": [
124 | "new_labels[:20],'\\n', orig_labels[:20]"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 12,
130 | "metadata": {},
131 | "outputs": [
132 | {
133 | "data": {
134 | "text/plain": [
135 | "DatasetDict({\n",
136 | " train: Dataset({\n",
137 | " features: ['sentence1', 'sentence2', 'idx', 'label'],\n",
138 | " num_rows: 2490\n",
139 | " })\n",
140 | " validation: Dataset({\n",
141 | " features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
142 | " num_rows: 277\n",
143 | " })\n",
144 | " test: Dataset({\n",
145 | " features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
146 | " num_rows: 3000\n",
147 | " })\n",
148 | "})"
149 | ]
150 | },
151 | "execution_count": 12,
152 | "metadata": {},
153 | "output_type": "execute_result"
154 | }
155 | ],
156 | "source": [
157 | "raw_datasets['train'] = raw_datasets['train'].remove_columns(['label'])\n",
158 | "raw_datasets['train'] = raw_datasets['train'].add_column('label', new_labels)\n",
159 | "raw_datasets"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 15,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "raw_datasets.save_to_disk(f'../datasets/{dataset_name}-noisy-{NOISE_RATIO}/with_idx')"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 16,
174 | "metadata": {},
175 | "outputs": [
176 | {
177 | "data": {
178 | "text/plain": [
179 | "Counter({0: 1251, 1: 1239})"
180 | ]
181 | },
182 | "execution_count": 16,
183 | "metadata": {},
184 | "output_type": "execute_result"
185 | }
186 | ],
187 | "source": [
188 | "Counter(new_labels)"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {},
195 | "outputs": [],
196 | "source": []
197 | }
198 | ],
199 | "metadata": {
200 | "interpreter": {
201 | "hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f"
202 | },
203 | "kernelspec": {
204 | "display_name": "Python 3.6.10 64-bit ('base': conda)",
205 | "name": "python3"
206 | },
207 | "language_info": {
208 | "name": "python",
209 | "version": ""
210 | },
211 | "orig_nbformat": 4
212 | },
213 | "nbformat": 4,
214 | "nbformat_minor": 2
215 | }
--------------------------------------------------------------------------------
/HCT/dataset_prepare.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from datasets import load_dataset"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 3,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "data": {
19 | "text/plain": [
20 | "DatasetDict({\n",
21 | " test: Dataset({\n",
22 | " features: ['premise', 'hypothesis', 'label'],\n",
23 | " num_rows: 10000\n",
24 | " })\n",
25 | " train: Dataset({\n",
26 | " features: ['premise', 'hypothesis', 'label'],\n",
27 | " num_rows: 550152\n",
28 | " })\n",
29 | " validation: Dataset({\n",
30 | " features: ['premise', 'hypothesis', 'label'],\n",
31 | " num_rows: 10000\n",
32 | " })\n",
33 | "})"
34 | ]
35 | },
36 | "execution_count": 3,
37 | "metadata": {},
38 | "output_type": "execute_result"
39 | }
40 | ],
41 | "source": [
42 | "dataset_name = 'snli'\n",
43 | "raw_datasets = load_dataset(dataset_name)\n",
44 | "raw_datasets"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 4,
50 | "metadata": {},
51 | "outputs": [
52 | {
53 | "name": "stderr",
54 | "output_type": "stream",
55 | "text": [
56 | "Loading cached processed dataset at /home/v-biyangguo/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bd229a5e884be60a.arrow\n"
57 | ]
58 | },
59 | {
60 | "data": {
61 | "application/vnd.jupyter.widget-view+json": {
62 | "model_id": "8cedbfb9a550451ca7c53baade7281ba",
63 | "version_major": 2,
64 | "version_minor": 0
65 | },
66 | "text/plain": [
67 | " 0%| | 0/10 [00:00, ?ba/s]"
68 | ]
69 | },
70 | "metadata": {},
71 | "output_type": "display_data"
72 | },
73 | {
74 | "data": {
75 | "application/vnd.jupyter.widget-view+json": {
76 | "model_id": "c43b25c10fc04b678cba14c5c8ec95f9",
77 | "version_major": 2,
78 | "version_minor": 0
79 | },
80 | "text/plain": [
81 | " 0%| | 0/10 [00:00, ?ba/s]"
82 | ]
83 | },
84 | "metadata": {},
85 | "output_type": "display_data"
86 | },
87 | {
88 | "data": {
89 | "text/plain": [
90 | "DatasetDict({\n",
91 | " test: Dataset({\n",
92 | " features: ['premise', 'hypothesis', 'label'],\n",
93 | " num_rows: 9824\n",
94 | " })\n",
95 | " train: Dataset({\n",
96 | " features: ['premise', 'hypothesis', 'label'],\n",
97 | " num_rows: 549367\n",
98 | " })\n",
99 | " validation: Dataset({\n",
100 | " features: ['premise', 'hypothesis', 'label'],\n",
101 | " num_rows: 9842\n",
102 | " })\n",
103 | "})"
104 | ]
105 | },
106 | "execution_count": 4,
107 | "metadata": {},
108 | "output_type": "execute_result"
109 | }
110 | ],
111 | "source": [
112 | "# remove samples with label -1\n",
113 | "raw_datasets['train'] = raw_datasets['train'].filter(lambda x:x['label']!=-1)\n",
114 | "raw_datasets['validation'] = raw_datasets['validation'].filter(lambda x:x['label']!=-1)\n",
115 | "raw_datasets['test'] = raw_datasets['test'].filter(lambda x:x['label']!=-1)\n",
116 | "raw_datasets"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 5,
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "data": {
126 | "text/plain": [
127 | "(549367, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
128 | ]
129 | },
130 | "execution_count": 5,
131 | "metadata": {},
132 | "output_type": "execute_result"
133 | }
134 | ],
135 | "source": [
136 | "# add idx into train set\n",
137 | "ids = list(range(len(raw_datasets['train'])))\n",
138 | "len(ids), ids[:10]"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 10,
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "data": {
148 | "text/plain": [
149 | "DatasetDict({\n",
150 | " test: Dataset({\n",
151 | " features: ['premise', 'hypothesis', 'label'],\n",
152 | " num_rows: 9824\n",
153 | " })\n",
154 | " train: Dataset({\n",
155 | " features: ['premise', 'hypothesis', 'label', 'idx'],\n",
156 | " num_rows: 549367\n",
157 | " })\n",
158 | " validation: Dataset({\n",
159 | " features: ['premise', 'hypothesis', 'label'],\n",
160 | " num_rows: 9842\n",
161 | " })\n",
162 | "})"
163 | ]
164 | },
165 | "execution_count": 10,
166 | "metadata": {},
167 | "output_type": "execute_result"
168 | }
169 | ],
170 | "source": [
171 | "raw_datasets['train'] = raw_datasets['train'].add_column('idx',ids)\n",
172 | "raw_datasets"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 11,
178 | "metadata": {},
179 | "outputs": [
180 | {
181 | "data": {
182 | "application/vnd.jupyter.widget-view+json": {
183 | "model_id": "68f896fcba3d485fba513c754adc2872",
184 | "version_major": 2,
185 | "version_minor": 0
186 | },
187 | "text/plain": [
188 | "Flattening the indices: 0%| | 0/10 [00:00, ?ba/s]"
189 | ]
190 | },
191 | "metadata": {},
192 | "output_type": "display_data"
193 | },
194 | {
195 | "data": {
196 | "application/vnd.jupyter.widget-view+json": {
197 | "model_id": "713f37bed4da4e83812e8c10fd8893a9",
198 | "version_major": 2,
199 | "version_minor": 0
200 | },
201 | "text/plain": [
202 | "Flattening the indices: 0%| | 0/10 [00:00, ?ba/s]"
203 | ]
204 | },
205 | "metadata": {},
206 | "output_type": "display_data"
207 | }
208 | ],
209 | "source": [
210 | "dataset_name = 'snli'\n",
211 | "raw_datasets.save_to_disk(f\"../datasets/{dataset_name}/with_idx\")"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": null,
217 | "metadata": {},
218 | "outputs": [],
219 | "source": []
220 | }
221 | ],
222 | "metadata": {
223 | "interpreter": {
224 | "hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f"
225 | },
226 | "kernelspec": {
227 | "display_name": "Python 3.6.10 64-bit ('conda': virtualenv)",
228 | "name": "python3"
229 | },
230 | "language_info": {
231 | "codemirror_mode": {
232 | "name": "ipython",
233 | "version": 3
234 | },
235 | "file_extension": ".py",
236 | "mimetype": "text/x-python",
237 | "name": "python",
238 | "nbconvert_exporter": "python",
239 | "pygments_lexer": "ipython3",
240 | "version": "3.6.10"
241 | },
242 | "orig_nbformat": 4
243 | },
244 | "nbformat": 4,
245 | "nbformat_minor": 2
246 | }
--------------------------------------------------------------------------------
/dy_filtering.py:
--------------------------------------------------------------------------------
1 | """
2 | Filtering and dataset mapping methods based on training dynamics.
3 | By default, this module reads training dynamics from a given trained model and
4 | computes the metrics---confidence, variability, correctness,
5 | as well as baseline metrics of forgetfulness and threshold closeness
6 | for each instance in the training data.
7 | If specified, data maps can be plotted with respect to confidence and variability.
8 | Moreover, datasets can be filtered with respect to any of the other metrics.
9 | """
import argparse
import json
import logging
import os
from collections import defaultdict
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm

from data_utils import read_data, read_jsonl, copy_dev_test
from selection_utils import read_training_dynamics
26 |
27 | # TODO(SS): Named tuple for tasks and filtering methods.
28 |
# Configure root logging once for this script and expose a module-level logger.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s")
logger = logging.getLogger(__name__)
33 |
34 |
def compute_forgetfulness(correctness_trend: List[float]) -> int:
    """
    Count how often an example is forgotten across epochs.

    A forgetting event is an epoch where the example is predicted incorrectly
    immediately after an epoch in which it was predicted correctly.
    Examples that are never predicted correctly receive the sentinel 1000.
    Based on: https://arxiv.org/abs/1812.05159
    """
    if not any(correctness_trend):  # Never predicted correctly, i.e. never learnt.
        return 1000

    forgotten = 0
    was_correct = False  # Correctness at the previous epoch.
    for now_correct in correctness_trend:
        if was_correct and not now_correct:
            forgotten += 1
        was_correct = bool(now_correct)
    return forgotten
57 |
58 |
def compute_correctness(trend: List[float]) -> float:
    """
    Return the total of the per-epoch correctness values in `trend`.

    With boolean (or 0/1) entries this is the number of epochs in which the
    example was predicted correctly.
    """
    num_correct_epochs = sum(flag for flag in trend)
    return num_correct_epochs
64 |
65 |
def compute_train_dy_metrics(training_dynamics, include_ci=False, burn_out=100):
    """
    Compute data-map coordinates and baseline metrics from training dynamics.

    Computed metrics are: confidence, variability, correctness, forgetfulness,
    threshold_closeness -- the last two being baselines from prior work
    (Example Forgetting: https://arxiv.org/abs/1812.05159 and
    Active Bias: https://arxiv.org/abs/1704.07433 respectively).

    Args:
        training_dynamics: mapping from example guid to a record with keys
            "logits" (one list of class logits per epoch) and "gold" (the gold
            class index).
        include_ci: if True, variability includes the Active Bias
            confidence-interval term instead of being the plain std.
        burn_out: per-example metrics only use the first `burn_out` epochs;
            the per-epoch training loss/accuracy still cover every epoch.

    Returns:
        - DataFrame with one row per example: guid, index, threshold_closeness,
          confidence, variability, correctness, forgetfulness.
        - DataFrame with per-epoch mean cross-entropy "loss" and "train_acc".

    Raises:
        ValueError: if `training_dynamics` is empty.
    """
    if not training_dynamics:
        raise ValueError("No training dynamics given; cannot compute metrics.")

    confidence_ = {}
    variability_ = {}
    threshold_closeness_ = {}
    correctness_ = {}
    forgetfulness_ = {}

    def variability_func(conf):
        # Active Bias adds a confidence-interval term that divides by (n - 1);
        # with fewer than two epochs that term is undefined, so use the std.
        if include_ci and len(conf) > 1:
            conf_var = np.var(conf)
            return np.sqrt(conf_var + conf_var * conf_var / (len(conf) - 1))
        return np.std(conf)

    def threshold_closeness_func(conf):
        # Largest (0.25) when conf == 0.5, zero at conf 0 or 1.
        return conf * (1 - conf)

    # Default reduction="mean": the returned loss is already averaged over examples.
    loss_fn = torch.nn.CrossEntropyLoss()

    num_tot_epochs = len(next(iter(training_dynamics.values()))["logits"])
    if burn_out < num_tot_epochs:
        logger.info(f"Computing training dynamics. Burning out at {burn_out} of {num_tot_epochs}. ")
    else:
        logger.info(f"Computing training dynamics across {num_tot_epochs} epochs")
    logger.info("Metrics computed: confidence, variability, correctness, forgetfulness, threshold_closeness")

    logits = {i: [] for i in range(num_tot_epochs)}
    targets = {i: [] for i in range(num_tot_epochs)}
    training_accuracy = defaultdict(float)

    for guid in tqdm.tqdm(training_dynamics):
        correctness_trend = []
        true_probs_trend = []

        record = training_dynamics[guid]
        for i, epoch_logits in enumerate(record["logits"]):
            probs = torch.nn.functional.softmax(torch.Tensor(epoch_logits), dim=-1)
            true_probs_trend.append(float(probs[record["gold"]]))

            prediction = np.argmax(epoch_logits)
            is_correct = (prediction == record["gold"]).item()
            correctness_trend.append(is_correct)

            # Epoch-level statistics always use every epoch.
            training_accuracy[i] += is_correct
            logits[i].append(epoch_logits)
            targets[i].append(record["gold"])

        if burn_out < num_tot_epochs:
            correctness_trend = correctness_trend[:burn_out]
            true_probs_trend = true_probs_trend[:burn_out]

        correctness_[guid] = compute_correctness(correctness_trend)
        confidence_[guid] = np.mean(true_probs_trend)
        variability_[guid] = variability_func(true_probs_trend)

        forgetfulness_[guid] = compute_forgetfulness(correctness_trend)
        threshold_closeness_[guid] = threshold_closeness_func(confidence_[guid])

    column_names = ['guid',
                    'index',
                    'threshold_closeness',
                    'confidence',
                    'variability',
                    'correctness',
                    'forgetfulness',]
    df = pd.DataFrame([[guid,
                        i,
                        threshold_closeness_[guid],
                        confidence_[guid],
                        variability_[guid],
                        correctness_[guid],
                        forgetfulness_[guid],
                        ] for i, guid in enumerate(correctness_)], columns=column_names)

    num_examples = len(training_dynamics)
    df_train = pd.DataFrame([[i,
                              # Already a mean over examples; do not divide again.
                              loss_fn(torch.Tensor(logits[i]), torch.LongTensor(targets[i])).item(),
                              training_accuracy[i] / num_examples,
                              ] for i in range(num_tot_epochs)],
                            columns=['epoch', 'loss', 'train_acc'])
    return df, df_train
157 |
158 |
def consider_ascending_order(filtering_metric: str) -> bool:
    """
    Return the sort direction that puts the most `valuable` training examples first.

    True means ascending (low metric values first), False means descending.

    Raises:
        NotImplementedError: if `filtering_metric` is not one of the known metrics.
    """
    sort_directions = (
        ("variability", False),
        ("confidence", True),
        ("threshold_closeness", False),
        ("forgetfulness", False),
        ("correctness", True),
    )
    for metric_name, ascending in sort_directions:
        if filtering_metric == metric_name:
            return ascending
    raise NotImplementedError(f"Filtering based on {filtering_metric} not implemented!")
175 |
176 |
def write_filtered_data(args, train_dy_metrics):
    """
    Filter data based on the given metric, and write it in TSV format to train GLUE-style classifier.

    For every fraction in a fixed list, exactly `int(fraction * len(train))`
    examples are taken from the top of `train_dy_metrics` sorted by `args.metric`
    (direction from `consider_ascending_order`, flipped when `args.worst`).
    With `args.both_ends`, 70% come from the head and the rest from the tail.
    Each subset is written to
    `{filtering_output_dir}/cartography_{metric}_{fraction:.2f}/{task_name}/train.tsv`
    and the dev/test files are copied alongside unchanged. The filtering
    arguments are saved to `filtering_configs.json` to record provenance.

    Args:
        args: parsed CLI arguments; reads filtering_output_dir, metric, worst,
            both_ends, data_dir and task_name.
        train_dy_metrics: DataFrame with a "guid" column and an `args.metric`
            column (as produced by `compute_train_dy_metrics`).
    """
    # First save the args for filtering, to keep track of which model was used for filtering.
    argparse_dict = vars(args)
    with open(os.path.join(args.filtering_output_dir, "filtering_configs.json"), "w") as outfile:
        outfile.write(json.dumps(argparse_dict, indent=4, sort_keys=True) + "\n")

    # Determine whether to sort data in ascending order or not, based on the metric.
    is_ascending = consider_ascending_order(args.metric)
    if args.worst:
        is_ascending = not is_ascending

    # Sort by selection.
    sorted_scores = train_dy_metrics.sort_values(by=[args.metric], ascending=is_ascending)

    task_data_dir = os.path.join(args.data_dir, args.task_name)
    original_train_file = os.path.join(task_data_dir, "train.tsv")
    train_numeric, header = read_data(original_train_file, task_name=args.task_name, guid_as_int=True)

    for fraction in [0.01, 0.05, 0.10, 0.1667, 0.25, 0.3319, 0.50, 0.75]:
        outdir = os.path.join(args.filtering_output_dir,
                              f"cartography_{args.metric}_{fraction:.2f}/{args.task_name}")
        os.makedirs(outdir, exist_ok=True)

        # Dev and test need not be subsampled.
        copy_dev_test(args.task_name, from_dir=task_data_dir, to_dir=outdir)

        num_samples = int(fraction * len(train_numeric))
        # Take exactly `num_samples` rows so the file matches the logged count.
        selected = sorted_scores.head(n=num_samples)
        if args.both_ends:
            hardest = sorted_scores.head(n=int(num_samples * 0.7))
            easiest = sorted_scores.tail(n=num_samples - hardest.shape[0])
            selected = pd.concat([hardest, easiest])
            # Only report the value ranges when both ends are non-empty
            # (tiny subsets can leave one side empty).
            if not hardest.empty and not easiest.empty:
                fm = args.metric
                logger.info(f"Selecting both ends: {fm} = "
                            f"({hardest.head(1)[fm].values[0]:.3f}: {hardest.tail(1)[fm].values[0]:.3f}) "
                            f"& ({easiest.head(1)[fm].values[0]:.3f}: {easiest.tail(1)[fm].values[0]:.3f})")

        with open(os.path.join(outdir, "train.tsv"), "w") as outfile:
            outfile.write(header + "\n")
            selection_iterator = tqdm.tqdm(range(len(selected)))
            for idx in selection_iterator:
                row = selected.iloc[idx]
                selection_iterator.set_description(f"{args.metric} = {row[args.metric]:.4f}")

                # Guid types in `train_numeric` differ per task.
                selected_id = row["guid"]
                if args.task_name in ["SNLI", "MNLI"]:
                    selected_id = int(selected_id)
                elif args.task_name == "WINOGRANDE":
                    selected_id = str(int(selected_id))
                record = train_numeric[selected_id]
                outfile.write(record + "\n")

        logger.info(f"Wrote {num_samples} samples to {outdir}.")
235 |
236 |
def plot_data_map(dataframe: pd.DataFrame,
                  plot_dir: str,
                  hue_metric: str = 'correct.',
                  title: str = '',
                  model: str = 'RoBERTa',
                  show_hist: bool = False,
                  max_instances_to_plot: int = 55000):
    """
    Plot a data map (variability vs. confidence) and save it as a PDF in `plot_dir`.

    Args:
        dataframe: per-example metrics with at least the columns "variability",
            "confidence" and "correctness" (as produced by
            `compute_train_dy_metrics`).
        plot_dir: directory the PDF is written to.
        hue_metric: column used for point colour, and for marker style when it
            has fewer than 8 distinct values. The default "correct." is derived
            here: correctness scaled to [0, 1] and formatted to one decimal.
        title: used in the figure title and the file name.
        model: used in the figure title and the file name.
        show_hist: if True, add histograms of confidence, variability and
            correctness and save as `{plot_dir}/{title}_{model}.pdf`; otherwise
            save a compact scatter plot as `{plot_dir}/compact_{title}_{model}.pdf`.
        max_instances_to_plot: randomly subsample to at most this many points.
    """
    # Set style.
    sns.set(style='whitegrid', font_scale=1.6, font='Georgia', context='paper')
    logger.info(f"Plotting figure for {title} using the {model} model ...")

    # Subsample data to plot, so the plot is not too busy.
    dataframe = dataframe.sample(n=min(max_instances_to_plot, len(dataframe)))

    # Normalize correctness to a value between 0 and 1. If no example was ever
    # predicted correctly, divide by 1 instead of 0.
    max_correctness = dataframe.correctness.max()
    if max_correctness == 0:
        max_correctness = 1
    dataframe = dataframe.assign(corr_frac=lambda d: d.correctness / max_correctness)
    dataframe['correct.'] = [f"{x:.1f}" for x in dataframe['corr_frac']]

    main_metric = 'variability'
    other_metric = 'confidence'

    hue = hue_metric
    num_hues = len(dataframe[hue].unique().tolist())
    # Marker styles are only distinguishable for a handful of categories.
    style = hue_metric if num_hues < 8 else None

    if not show_hist:
        fig, ax0 = plt.subplots(1, 1, figsize=(8, 6))
    else:
        fig = plt.figure(figsize=(14, 10), )
        # A 3-row x 2-column grid; the first column is 5 times as wide as the second.
        gs = fig.add_gridspec(3, 2, width_ratios=[5, 1])
        # The scatter plot spans all rows of the first column.
        ax0 = fig.add_subplot(gs[:, 0])

    # Make the scatterplot.
    # Choose a palette.
    pal = sns.diverging_palette(260, 15, n=num_hues, sep=10, center="dark")

    plot = sns.scatterplot(x=main_metric,
                           y=other_metric,
                           ax=ax0,
                           data=dataframe,
                           hue=hue,
                           palette=pal,
                           style=style,
                           s=30)

    # Annotate Regions.
    def annotate_region(text, xyc, bbc):
        return ax0.annotate(text,
                            xy=xyc,
                            xycoords="axes fraction",
                            fontsize=15,
                            color='black',
                            va="center",
                            ha="center",
                            rotation=350,
                            bbox=dict(boxstyle="round,pad=0.3", ec=bbc, lw=2, fc="white"))

    annotate_region("ambiguous", xyc=(0.9, 0.5), bbc='black')
    annotate_region("easy-to-learn", xyc=(0.27, 0.85), bbc='r')
    annotate_region("hard-to-learn", xyc=(0.35, 0.25), bbc='b')

    if not show_hist:
        plot.legend(ncol=1, bbox_to_anchor=[0.175, 0.5], loc='right')
    else:
        plot.legend(fancybox=True, shadow=True, ncol=1)
    plot.set_xlabel('variability')
    plot.set_ylabel('confidence')

    if show_hist:
        plot.set_title(f"{title}-{model} Data Map", fontsize=17)

        # Make the histograms.
        ax1 = fig.add_subplot(gs[0, 1])
        ax2 = fig.add_subplot(gs[1, 1])
        ax3 = fig.add_subplot(gs[2, 1])

        plott0 = dataframe.hist(column=['confidence'], ax=ax1, color='#622a87')
        plott0[0].set_title('')
        plott0[0].set_xlabel('confidence')
        plott0[0].set_ylabel('density')

        plott1 = dataframe.hist(column=['variability'], ax=ax2, color='teal')
        plott1[0].set_title('')
        plott1[0].set_xlabel('variability')
        plott1[0].set_ylabel('density')

        plot2 = sns.countplot(x="correct.", data=dataframe, ax=ax3, color='#86bf91')
        ax3.xaxis.grid(True)  # Show the vertical gridlines

        plot2.set_title('')
        plot2.set_xlabel('correctness')
        plot2.set_ylabel('density')

    fig.tight_layout()
    # Both variants go to `plot_dir` (the compact one used a hard-coded "figures/").
    file_stem = f'{title}_{model}.pdf' if show_hist else f'compact_{title}_{model}.pdf'
    filename = os.path.join(plot_dir, file_stem)
    fig.savefig(filename, dpi=300)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)
    logger.info(f"Plot saved to {filename}")
334 |
335 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--filter",
                        action="store_true",
                        help="Whether to filter data subsets based on specified `metric`.")
    parser.add_argument("--plot",
                        action="store_true",
                        help="Whether to plot data maps and save as `pdf`.")
    parser.add_argument("--model_dir",
                        "-o",
                        required=True,
                        type=os.path.abspath,
                        help="Directory where model training dynamics stats reside.")
    parser.add_argument("--data_dir",
                        "-d",
                        default="/Users/swabhas/data/glue/WINOGRANDE/xl/",
                        type=os.path.abspath,
                        help="Directory where data for task resides.")
    parser.add_argument("--plots_dir",
                        default="./cartography/",
                        type=os.path.abspath,
                        help="Directory where plots are to be saved.")
    parser.add_argument("--task_name",
                        "-t",
                        # default="WINOGRANDE",
                        # choices=("SNLI", "MNLI", "QNLI", "WINOGRANDE"),
                        help="Which task are we plotting or filtering for.")
    parser.add_argument('--metric',
                        choices=('threshold_closeness',
                                 'confidence',
                                 'variability',
                                 'correctness',
                                 'forgetfulness'),
                        help="Metric to filter data by.",)
    parser.add_argument("--include_ci",
                        action="store_true",
                        help="Compute the confidence interval for variability.")
    parser.add_argument("--filtering_output_dir",
                        "-f",
                        default="./filtered/",
                        type=os.path.abspath,
                        help="Output directory where filtered datasets are to be written.")
    parser.add_argument("--worst",
                        action="store_true",
                        help="Select from the opposite end of the spectrum acc. to metric, "
                             "for baselines")
    parser.add_argument("--both_ends",
                        action="store_true",
                        help="Select from both ends of the spectrum acc. to metric,")
    parser.add_argument("--burn_out",
                        type=int,
                        default=100,
                        help="# Epochs for which to compute train dynamics.")
    parser.add_argument("--model",
                        default="RoBERTa",
                        help="Model for which data map is being plotted")

    args = parser.parse_args()

    # Validate before the (potentially slow) read of the training dynamics, and
    # without `assert`, which is stripped under `python -O`.
    if args.filter and not args.metric:
        parser.error("--metric is required when --filter is set.")

    # Values >= 100 (the default) are passed to read_training_dynamics as None,
    # i.e. no burn-out was requested.
    burn_out_requested = args.burn_out < 100
    training_dynamics = read_training_dynamics(args.model_dir,
                                               strip_last=args.task_name in ["QNLI"],
                                               burn_out=args.burn_out if burn_out_requested else None)
    total_epochs = len(list(training_dynamics.values())[0]["logits"])
    if args.burn_out > total_epochs:
        args.burn_out = total_epochs
        logger.info(f"Total epochs found: {args.burn_out}")
    train_dy_metrics, _ = compute_train_dy_metrics(training_dynamics, args.include_ci, args.burn_out)

    # Tag the metrics file with the burn-out epoch when a burn-out was requested.
    # (Comparing against total_epochs cannot work: burn_out was just clamped to it.)
    burn_out_str = f"_{args.burn_out}" if burn_out_requested else ""
    train_dy_filename = os.path.join(args.model_dir, f"td_metrics{burn_out_str}.jsonl")
    train_dy_metrics.to_json(train_dy_filename,
                             orient='records',
                             lines=True)
    logger.info(f"Metrics based on Training Dynamics written to {train_dy_filename}")

    if args.filter:
        os.makedirs(args.filtering_output_dir, exist_ok=True)
        write_filtered_data(args, train_dy_metrics)

    if args.plot:
        os.makedirs(args.plots_dir, exist_ok=True)
        plot_data_map(train_dy_metrics, args.plots_dir, title=args.task_name, show_hist=True, model=args.model)
423 |
--------------------------------------------------------------------------------
/run_glue.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Finetuning a 🤗 Transformers model for sequence classification on GLUE."""
16 | import argparse
17 | from collections import defaultdict
18 | import json
19 | import logging
20 | import math
21 | import os
22 | import random
23 | from pathlib import Path
24 |
25 | import datasets
26 | import torch
27 | from datasets import load_dataset, load_metric
28 | from torch.utils.data import DataLoader
29 | from tqdm.auto import tqdm
30 |
31 | import transformers
32 | from accelerate import Accelerator
33 | from accelerate.logging import get_logger
34 | from accelerate.utils import set_seed
35 | from huggingface_hub import Repository
36 | from transformers import (
37 | AutoConfig,
38 | AutoModelForSequenceClassification,
39 | AutoTokenizer,
40 | DataCollatorWithPadding,
41 | PretrainedConfig,
42 | SchedulerType,
43 | default_data_collator,
44 | get_scheduler,
45 | )
46 | # from transformers.utils import get_full_repo_name, send_example_telemetry
47 | from transformers.utils.versions import require_version
48 |
49 |
logger = get_logger(__name__)

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

# GLUE task name -> (first text column, second text column or None for
# single-sentence tasks).
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)
65 |
66 |
67 | def parse_args():
68 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
69 | parser.add_argument(
70 | "--task_name",
71 | type=str,
72 | default=None,
73 | help="The name of the glue task to train on.",
74 | choices=list(task_to_keys.keys()),
75 | )
76 | parser.add_argument(
77 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
78 | )
79 | parser.add_argument(
80 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
81 | )
82 | parser.add_argument(
83 | "--max_length",
84 | type=int,
85 | default=128,
86 | help=(
87 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
88 | " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
89 | ),
90 | )
91 | parser.add_argument(
92 | "--pad_to_max_length",
93 | action="store_true",
94 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
95 | )
96 | parser.add_argument(
97 | "--model_name_or_path",
98 | type=str,
99 | help="Path to pretrained model or model identifier from huggingface.co/models.",
100 | required=True,
101 | )
102 | parser.add_argument(
103 | "--use_slow_tokenizer",
104 | action="store_true",
105 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
106 | )
107 | parser.add_argument(
108 | "--per_device_train_batch_size",
109 | type=int,
110 | default=8,
111 | help="Batch size (per device) for the training dataloader.",
112 | )
113 | parser.add_argument(
114 | "--per_device_eval_batch_size",
115 | type=int,
116 | default=8,
117 | help="Batch size (per device) for the evaluation dataloader.",
118 | )
119 | parser.add_argument(
120 | "--learning_rate",
121 | type=float,
122 | default=5e-5,
123 | help="Initial learning rate (after the potential warmup period) to use.",
124 | )
125 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
126 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
127 | parser.add_argument(
128 | "--max_train_steps",
129 | type=int,
130 | default=None,
131 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
132 | )
133 | parser.add_argument(
134 | "--gradient_accumulation_steps",
135 | type=int,
136 | default=1,
137 | help="Number of updates steps to accumulate before performing a backward/update pass.",
138 | )
139 | parser.add_argument(
140 | "--lr_scheduler_type",
141 | type=SchedulerType,
142 | default="linear",
143 | help="The scheduler type to use.",
144 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
145 | )
146 | parser.add_argument(
147 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
148 | )
149 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
150 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
151 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
152 | parser.add_argument(
153 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
154 | )
155 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
156 | parser.add_argument(
157 | "--checkpointing_steps",
158 | type=str,
159 | default=None,
160 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
161 | )
162 | parser.add_argument(
163 | "--resume_from_checkpoint",
164 | type=str,
165 | default=None,
166 | help="If the training should continue from a checkpoint folder.",
167 | )
168 | parser.add_argument(
169 | "--with_tracking",
170 | action="store_true",
171 | help="Whether to enable experiment trackers for logging.",
172 | )
173 | parser.add_argument(
174 | "--report_to",
175 | type=str,
176 | default="all",
177 | help=(
178 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
179 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
180 | "Only applicable when `--with_tracking` is passed."
181 | ),
182 | )
183 | parser.add_argument(
184 | "--ignore_mismatched_sizes",
185 | action="store_true",
186 | help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
187 | )
188 | args = parser.parse_args()
189 |
190 | # Sanity checks
191 | if args.task_name is None and args.train_file is None and args.validation_file is None:
192 | raise ValueError("Need either a task name or a training/validation file.")
193 | else:
194 | if args.train_file is not None:
195 | extension = args.train_file.split(".")[-1]
196 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
197 | if args.validation_file is not None:
198 | extension = args.validation_file.split(".")[-1]
199 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
200 |
201 | # if args.push_to_hub:
202 | # assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
203 |
204 | return args
205 |
206 |
def _record_training_dynamics(model, dataloader, accelerator, epoch):
    """Run one inference pass over ``dataloader`` and collect per-example logits.

    Used after every training epoch to build the "training dynamics" file for
    that epoch.

    Args:
        model: the (accelerate-prepared) classification model.
        dataloader: the prepared training dataloader; its batches must contain
            ``idx`` and ``labels`` in addition to the model inputs.
        accelerator: the ``Accelerator`` used for training; results are merged
            across processes with ``accelerator.gather``.
        epoch: current epoch number, used to name the logits key.

    Returns:
        A list of dicts with keys ``guid`` (the example ``idx``),
        ``logits_epoch_{epoch}`` (list of logits), ``gold`` (label id) and
        ``device``. Because results are gathered, every process receives the
        full list. Examples that the distributed dataloader repeats to pad the
        last batch are kept only once, keyed by ``guid``.
    """
    was_training = model.training
    # Disable dropout and autograd: we only need deterministic logits here,
    # not gradients, and building the graph would waste memory.
    model.eval()
    seen_guids = set()  # set membership keeps de-duplication O(1) per example
    records = []
    device_name = str(accelerator.device)
    for batch in tqdm(dataloader, disable=not accelerator.is_local_main_process):
        guids = batch["idx"]
        golds = batch["labels"]
        model_inputs = {k: v for k, v in batch.items() if k != "idx"}  # `idx` is not a model argument
        with torch.no_grad():
            logits = model(**model_inputs).logits
        # Gather from all processes so each one sees the same ordering/content
        # it would have seen in a single-device run.
        guids, golds, logits = accelerator.gather((guids, golds, logits))
        for guid, gold, row in zip(guids.tolist(), golds.tolist(), logits.tolist()):
            if guid in seen_guids:
                continue
            seen_guids.add(guid)
            records.append({"guid": guid, f"logits_epoch_{epoch}": row, "gold": gold, "device": device_name})
    if was_training:
        model.train()
    return records


def _evaluate(model, dataloader, accelerator, metric, is_regression):
    """Evaluate ``model`` on ``dataloader`` and return ``metric.compute()``.

    The ``idx`` column kept by preprocessing is dropped before the forward
    pass. In multi-process runs, the last gathered batch is trimmed so samples
    duplicated by the distributed sampler are not counted twice.
    """
    model.eval()
    samples_seen = 0
    num_batches = len(dataloader)
    num_examples = len(dataloader.dataset)
    for step, batch in enumerate(dataloader):
        model_inputs = {k: v for k, v in batch.items() if k != "idx"}
        with torch.no_grad():
            outputs = model(**model_inputs)
        predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
        predictions, references = accelerator.gather((predictions, model_inputs["labels"]))
        if accelerator.num_processes > 1:
            if step == num_batches - 1:
                predictions = predictions[: num_examples - samples_seen]
                references = references[: num_examples - samples_seen]
            else:
                samples_seen += references.shape[0]
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute()


def main():
    """Fine-tune a sequence classifier and record per-epoch training dynamics.

    After every epoch, the logits of every training example are written to
    ``dy_log/{task_name}/{model_name}/training_dynamics/dynamics_epoch_{epoch}.jsonl``.
    Then the validation set is evaluated.
    """
    args = parse_args()
    # Telemetry (`send_example_telemetry`) is intentionally disabled in this script.

    # Initialize the accelerator; it handles device placement. When tracking is
    # enabled it also picks up all supported trackers in the environment.
    accelerator = (
        Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
    )
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # NOTE: pushing to the Hub (`--push_to_hub`) is not implemented in this script.

    # Get the datasets: either a GLUE task from the Hub, or local CSV/JSON files.
    # For CSV/JSON files the label column must be called 'label'. The sentence
    # columns are 'sentence1'/'sentence2' if present, otherwise the first two
    # non-label columns; with a single non-label column this is
    # single-sentence classification.
    if args.task_name is not None:
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # sort for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer.
    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        ignore_mismatched_sizes=args.ignore_mismatched_sizes,
    )
    # Keep only the final path component (e.g. "org/model" or "./models/model/"
    # -> "model") so it can be used as a directory name in the dynamics log path.
    args.model_name_or_path = os.path.basename(os.path.normpath(args.model_name_or_path))

    # Preprocessing the datasets
    if args.task_name is not None:
        # GLUE tasks
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Non-GLUE data must contain a `label` column. Prefer `sentence1`/`sentence2`
        # (GLUE-aligned names); otherwise use the first two non-label columns.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        elif len(non_label_column_names) >= 2:
            sentence1_key, sentence2_key = non_label_column_names[:2]
        else:
            sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so make sure we use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if sorted(label_name_to_id.keys()) == sorted(label_list):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            # One concatenated message: passing extra positional args without
            # %-placeholders makes the logging module fail to format the record.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                "\nIgnoring the model labels as a result."
            )
    elif args.task_name is None and not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    if label_to_id is not None:
        model.config.label2id = label_to_id
        model.config.id2label = {label_id: label for label, label_id in config.label2id.items()}
    elif args.task_name is not None and not is_regression:
        model.config.label2id = {l: i for i, l in enumerate(label_list)}
        model.config.id2label = {label_id: label for label, label_id in config.label2id.items()}

    padding = "max_length" if args.pad_to_max_length else False

    def preprocess_function(examples):
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    with accelerator.main_process_first():
        processed_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            # Drop every raw column except `idx` (e.g. for SST-2: 'sentence' and 'label';
            # 'label' is already copied into the new `labels` column). `idx` is kept so
            # training dynamics can be attributed to individual examples.
            remove_columns=[c for c in raw_datasets["train"].column_names if c != "idx"],
            desc="Running tokenizer on dataset",
        )

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
    # Fail fast instead of after a full training epoch.
    if "idx" not in train_dataset.column_names:
        raise ValueError("Recording training dynamics requires an `idx` column in the training data.")

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # DataLoaders creation:
    if args.pad_to_max_length:
        # Already padded to max length: the default collator just converts to tensors.
        data_collator = default_data_collator
    else:
        # Dynamic padding; under fp16 pad to a multiple of 8 to use Tensor Cores.
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None))

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)

    # Optimizer: split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # The training dataloader may have been sharded, so recompute the step counts.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    if hasattr(args.checkpointing_steps, "isdigit"):
        checkpointing_steps = args.checkpointing_steps
        if args.checkpointing_steps.isdigit():
            checkpointing_steps = int(args.checkpointing_steps)
    else:
        checkpointing_steps = None

    # Trackers are initialized only on the main process because `accelerator.log`
    # only logs there and we don't want empty runs from other processes.
    if args.with_tracking:
        if accelerator.is_main_process:
            experiment_config = vars(args)
            # TensorBoard cannot log Enums, need the raw value
            experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
            accelerator.init_trackers("glue_no_trainer", experiment_config)

    # Get the metric function
    if args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    resume_step = None
    # Potentially load in the weights and states from a previous save.
    if args.resume_from_checkpoint:
        accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
        accelerator.load_state(args.resume_from_checkpoint)
        # normpath strips a trailing separator so basename yields `epoch_{i}` / `step_{i}`.
        path = os.path.basename(os.path.normpath(args.resume_from_checkpoint))
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
        else:
            resume_step = int(training_difference.replace("step_", ""))
            starting_epoch = resume_step // len(train_dataloader)
            resume_step -= starting_epoch * len(train_dataloader)

    # Directory for the per-epoch training-dynamics files. It is created once,
    # by the main process only; the other processes wait until it exists.
    log_path = f"dy_log/{args.task_name}/{args.model_name_or_path}/training_dynamics/"
    if accelerator.is_main_process:
        os.makedirs(log_path, exist_ok=True)
    accelerator.wait_for_everyone()

    # ============================ Training Loop ============================
    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if args.with_tracking:
            total_loss = 0
        for step, batch in enumerate(train_dataloader):
            # We need to skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == starting_epoch:
                if resume_step is not None and step < resume_step:
                    completed_steps += 1
                    continue
            # Batches carry the `idx` column, which is not a model argument.
            batch = {k: v for k, v in batch.items() if k != "idx"}
            outputs = model(**batch)
            loss = outputs.loss
            # We keep track of the loss at each epoch
            if args.with_tracking:
                total_loss += loss.detach().float()
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            # Step after every `gradient_accumulation_steps` micro-batches (and on
            # the last batch). `step + 1` avoids an extra update after batch 0.
            if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

                # Checkpoint only on real optimizer steps, so that intermediate
                # micro-batches don't re-save the same step.
                if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)

            if completed_steps >= args.max_train_steps:
                break

        # ------------------ Recording Training Dynamics --------------------
        # After each epoch, run over the full training set again and record the
        # logits of every example. Each epoch goes to its own file.
        logger.info(f"---------- Recording Training Dynamics (Epoch {epoch}) -----------")
        training_dynamics = _record_training_dynamics(model, train_dataloader, accelerator, epoch)
        logger.info(f"---- Num of training_dynamics: {len(training_dynamics)} Device: {accelerator.device}")
        # Every process holds the gathered records; write them only once.
        if accelerator.is_main_process:
            with open(os.path.join(log_path, f"dynamics_epoch_{epoch}.jsonl"), "w") as writer:
                for record in training_dynamics:
                    writer.write(json.dumps(record) + "\n")
            logger.info(f"Epoch {epoch} Saved to [{log_path}]")
        accelerator.wait_for_everyone()

        # ------------------ Evaluation (validation set) --------------------
        eval_metric = _evaluate(model, eval_dataloader, accelerator, metric, is_regression)
        logger.info(f"epoch {epoch}: {eval_metric}")

        if args.with_tracking:
            accelerator.log(
                {
                    "accuracy" if args.task_name is not None else "glue": eval_metric,
                    "train_loss": total_loss.item() / len(train_dataloader),
                    "epoch": epoch,
                    "step": completed_steps,
                },
                step=completed_steps,
            )

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            accelerator.save_state(output_dir)
    # ============================ End Training Loop ============================

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)

    if args.task_name == "mnli":
        # Final evaluation on mismatched validation set
        eval_dataset = processed_datasets["validation_mismatched"]
        eval_dataloader = DataLoader(
            eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
        )
        eval_dataloader = accelerator.prepare(eval_dataloader)
        eval_metric = _evaluate(model, eval_dataloader, accelerator, metric, is_regression)
        logger.info(f"mnli-mm: {eval_metric}")

    if args.output_dir is not None and accelerator.is_main_process:
        # Dump every metric with an `eval_` prefix (includes `eval_accuracy` when
        # the task reports accuracy; cola/stsb report other metric names).
        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
            json.dump({f"eval_{k}": v for k, v in eval_metric.items()}, f)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/hardness_classification_prepare.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# hardness classification\n",
8 | "- run 'data_selection.py' with `-proportion 0.5` to split all samples into easy and hard.\n",
9 | "- then read the resulting 'three_regions_data_indices.json' to get the indices of each group."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 5,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import json\n",
19 | "dataset_name = 'mrpc'\n",
20 | "with open(f'dy_log/{dataset_name}/bert-base-cased/three_regions_data_indices.json' ,'r') as f:\n",
21 | " d = json.loads(f.read())"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 6,
27 | "metadata": {},
28 | "outputs": [
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "1210 1210\n"
34 | ]
35 | }
36 | ],
37 | "source": [
38 | "print(len(d['hard']), len(d['easy']))\n",
39 | "assert len(set(d['hard']).intersection(set(d['easy']))) == 0"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 7,
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "name": "stderr",
49 | "output_type": "stream",
50 | "text": [
51 | "Reusing dataset glue (/home/v-biyangguo/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
52 | ]
53 | },
54 | {
55 | "data": {
56 | "application/vnd.jupyter.widget-view+json": {
57 | "model_id": "008e864d865b4e2e8ace44c252f27211",
58 | "version_major": 2,
59 | "version_minor": 0
60 | },
61 | "text/plain": [
62 | " 0%| | 0/3 [00:00, ?it/s]"
63 | ]
64 | },
65 | "metadata": {},
66 | "output_type": "display_data"
67 | }
68 | ],
69 | "source": [
70 | "import datasets\n",
71 | "from datasets import load_dataset\n",
72 | "data = load_dataset('glue',dataset_name)"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 15,
78 | "metadata": {},
79 | "outputs": [
80 | {
81 | "data": {
82 | "application/vnd.jupyter.widget-view+json": {
83 | "model_id": "b692977869464090be85b97ca90ae200",
84 | "version_major": 2,
85 | "version_minor": 0
86 | },
87 | "text/plain": [
88 | " 0%| | 0/4 [00:00, ?ba/s]"
89 | ]
90 | },
91 | "metadata": {},
92 | "output_type": "display_data"
93 | },
94 | {
95 | "data": {
96 | "text/plain": [
97 | "Dataset({\n",
98 | " features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
99 | " num_rows: 1210\n",
100 | "})"
101 | ]
102 | },
103 | "execution_count": 15,
104 | "metadata": {},
105 | "output_type": "execute_result"
106 | }
107 | ],
108 | "source": [
109 | "data['train'].filter(lambda x:x['idx'] in d['easy'])"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 8,
115 | "metadata": {},
116 | "outputs": [
117 | {
118 | "ename": "IndexError",
119 | "evalue": "Index 4073 out of range for dataset of size 3668.",
120 | "output_type": "error",
121 | "traceback": [
122 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
123 | "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
124 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0measy_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'easy'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mhard_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'hard'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
125 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 522\u001b[0m }\n\u001b[1;32m 523\u001b[0m \u001b[0;31m# apply actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 524\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"DatasetDict\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 525\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;31m# re-apply format to the output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
126 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/fingerprint.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[0;31m# Call actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 480\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 481\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[0;31m# Update fingerprint of in-place transforms + update in-place history of transforms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
127 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36mselect\u001b[0;34m(self, indices, keep_in_memory, indices_cache_file_name, writer_batch_size, new_fingerprint)\u001b[0m\n\u001b[1;32m 3073\u001b[0m \u001b[0mindices_cache_file_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mindices_cache_file_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3074\u001b[0m \u001b[0mwriter_batch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwriter_batch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3075\u001b[0;31m \u001b[0mnew_fingerprint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_fingerprint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3076\u001b[0m )\n\u001b[1;32m 3077\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
128 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 522\u001b[0m }\n\u001b[1;32m 523\u001b[0m \u001b[0;31m# apply actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 524\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"DatasetDict\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 525\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;31m# re-apply format to the output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
129 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/fingerprint.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[0;31m# Call actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 480\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 481\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[0;31m# Update fingerprint of in-place transforms + update in-place history of transforms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
130 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36m_select_with_indices_mapping\u001b[0;34m(self, indices, keep_in_memory, indices_cache_file_name, writer_batch_size, new_fingerprint)\u001b[0m\n\u001b[1;32m 3199\u001b[0m \u001b[0msize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3200\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3201\u001b[0;31m \u001b[0m_check_valid_indices_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3202\u001b[0m \u001b[0m_check_valid_indices_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3203\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
131 | "\u001b[0;32m~/.local/lib/python3.6/site-packages/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36m_check_valid_indices_value\u001b[0;34m(index, size)\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_check_valid_indices_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msize\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 611\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mIndexError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Index {index} out of range for dataset of size {size}.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 612\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
132 | "\u001b[0;31mIndexError\u001b[0m: Index 4073 out of range for dataset of size 3668."
133 | ]
134 | }
135 | ],
136 | "source": [
137 | "easy_data = data['train'].select(d['easy'])\n",
138 | "hard_data = data['train'].select(d['hard'])"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 5,
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "data": {
148 | "application/vnd.jupyter.widget-view+json": {
149 | "model_id": "a6bf20ded5684a05859ec3f574323c93",
150 | "version_major": 2,
151 | "version_minor": 0
152 | },
153 | "text/plain": [
154 | "Flattening the indices: 0%| | 0/53 [00:00, ?ba/s]"
155 | ]
156 | },
157 | "metadata": {},
158 | "output_type": "display_data"
159 | },
160 | {
161 | "data": {
162 | "application/vnd.jupyter.widget-view+json": {
163 | "model_id": "55fbeeae9fe24bf3b33735038e55b337",
164 | "version_major": 2,
165 | "version_minor": 0
166 | },
167 | "text/plain": [
168 | "Flattening the indices: 0%| | 0/53 [00:00, ?ba/s]"
169 | ]
170 | },
171 | "metadata": {},
172 | "output_type": "display_data"
173 | },
174 | {
175 | "name": "stdout",
176 | "output_type": "stream",
177 | "text": [
178 | "{0} {1}\n"
179 | ]
180 | }
181 | ],
182 | "source": [
183 | "easy_data = easy_data.remove_columns(['label'])\n",
184 | "easy_data = easy_data.add_column('label', [0]*len(easy_data))\n",
185 | "\n",
186 | "hard_data = hard_data.remove_columns(['label'])\n",
187 | "hard_data = hard_data.add_column('label', [1]*len(hard_data))  # one label per hard row; must match len(hard_data)\n",
188 | "\n",
189 | "print(set(easy_data['label']), set(hard_data['label']))"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 9,
195 | "metadata": {},
196 | "outputs": [
197 | {
198 | "data": {
199 | "text/plain": [
200 | "Dataset({\n",
201 | " features: ['question', 'sentence', 'idx', 'label'],\n",
202 | " num_rows: 52371\n",
203 | "})"
204 | ]
205 | },
206 | "execution_count": 9,
207 | "metadata": {},
208 | "output_type": "execute_result"
209 | }
210 | ],
211 | "source": [
212 | "easy_data"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": 14,
218 | "metadata": {},
219 | "outputs": [],
220 | "source": [
221 | "import random\n",
222 | "easy_indices = d['easy']\n",
223 | "random.shuffle(easy_indices)\n",
224 | "\n",
225 | "hard_indices = d['hard']\n",
226 | "random.shuffle(hard_indices)\n",
227 | "\n",
228 | "easy_data_train = easy_data.filter(lambda x: x['idx'] in easy_indices[:int(0.7*len(easy_indices))])\n",
229 | "easy_data_valid = easy_data.filter(lambda x: x['idx'] in easy_indices[int(0.7*len(easy_indices)):])\n",
230 | "\n",
231 | "hard_data_train = hard_data.filter(lambda x: x['idx'] in hard_indices[:int(0.7*len(hard_indices))])\n",
232 | "hard_data_valid = hard_data.filter(lambda x: x['idx'] in hard_indices[int(0.7*len(hard_indices)):])\n",
233 | "\n",
234 | "assert len(set(easy_data_train['idx']).intersection(set(easy_data_valid['idx']))) == 0\n",
235 | "assert len(set(hard_data_train['idx']).intersection(set(hard_data_valid['idx']))) == 0\n"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": 15,
241 | "metadata": {},
242 | "outputs": [
243 | {
244 | "data": {
245 | "text/plain": [
246 | "(Dataset({\n",
247 | " features: ['question', 'sentence', 'idx', 'label'],\n",
248 | " num_rows: 36659\n",
249 | " }),\n",
250 | " Dataset({\n",
251 | " features: ['question', 'sentence', 'idx', 'label'],\n",
252 | " num_rows: 15712\n",
253 | " }))"
254 | ]
255 | },
256 | "execution_count": 15,
257 | "metadata": {},
258 | "output_type": "execute_result"
259 | }
260 | ],
261 | "source": [
262 | "easy_data_train, easy_data_valid"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": 16,
268 | "metadata": {},
269 | "outputs": [],
270 | "source": [
271 | "new_train_set = datasets.concatenate_datasets([easy_data_train, hard_data_train]).shuffle(seed=1)\n",
272 | "new_valid_set = datasets.concatenate_datasets([easy_data_valid, hard_data_valid]).shuffle(seed=1)"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 24,
278 | "metadata": {},
279 | "outputs": [
280 | {
281 | "data": {
282 | "text/plain": [
283 | "(Dataset({\n",
284 | " features: ['sentence', 'idx', 'label'],\n",
285 | " num_rows: 47142\n",
286 | " }),\n",
287 | " Dataset({\n",
288 | " features: ['sentence', 'idx', 'label'],\n",
289 | " num_rows: 20206\n",
290 | " }))"
291 | ]
292 | },
293 | "execution_count": 24,
294 | "metadata": {},
295 | "output_type": "execute_result"
296 | }
297 | ],
298 | "source": [
299 | "new_train_set, new_valid_set"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": 18,
305 | "metadata": {},
306 | "outputs": [
307 | {
308 | "data": {
309 | "application/vnd.jupyter.widget-view+json": {
310 | "model_id": "c05971f823424364b029e8ee256180b7",
311 | "version_major": 2,
312 | "version_minor": 0
313 | },
314 | "text/plain": [
315 | "Creating CSV from Arrow format: 0%| | 0/8 [00:00, ?ba/s]"
316 | ]
317 | },
318 | "metadata": {},
319 | "output_type": "display_data"
320 | },
321 | {
322 | "data": {
323 | "application/vnd.jupyter.widget-view+json": {
324 | "model_id": "71f0a90b0e2047ae80675832114358f9",
325 | "version_major": 2,
326 | "version_minor": 0
327 | },
328 | "text/plain": [
329 | "Creating CSV from Arrow format: 0%| | 0/4 [00:00, ?ba/s]"
330 | ]
331 | },
332 | "metadata": {},
333 | "output_type": "display_data"
334 | },
335 | {
336 | "data": {
337 | "text/plain": [
338 | "7585188"
339 | ]
340 | },
341 | "execution_count": 18,
342 | "metadata": {},
343 | "output_type": "execute_result"
344 | }
345 | ],
346 | "source": [
347 | "new_train_set.to_csv(f'datasets/{dataset_name}-easy-hard_train.csv')\n",
348 | "new_valid_set.to_csv(f'datasets/{dataset_name}-easy-hard_valid.csv')"
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": 19,
354 | "metadata": {},
355 | "outputs": [
356 | {
357 | "name": "stderr",
358 | "output_type": "stream",
359 | "text": [
360 | "Using custom data configuration default-36cdcd5cd5588345\n"
361 | ]
362 | },
363 | {
364 | "name": "stdout",
365 | "output_type": "stream",
366 | "text": [
367 | "Downloading and preparing dataset csv/default to /home/v-biyangguo/.cache/huggingface/datasets/csv/default-36cdcd5cd5588345/0.0.0/51cce309a08df9c4d82ffd9363bbe090bf173197fc01a71b034e8594995a1a58...\n"
368 | ]
369 | },
370 | {
371 | "data": {
372 | "application/vnd.jupyter.widget-view+json": {
373 | "model_id": "c4658d04212c4dbea9a5635fdf570db2",
374 | "version_major": 2,
375 | "version_minor": 0
376 | },
377 | "text/plain": [
378 | "Downloading data files: 0%| | 0/2 [00:00, ?it/s]"
379 | ]
380 | },
381 | "metadata": {},
382 | "output_type": "display_data"
383 | },
384 | {
385 | "data": {
386 | "application/vnd.jupyter.widget-view+json": {
387 | "model_id": "e3e71df6ada143dc8f88adcb8f3b2b92",
388 | "version_major": 2,
389 | "version_minor": 0
390 | },
391 | "text/plain": [
392 | "Extracting data files: 0%| | 0/2 [00:00, ?it/s]"
393 | ]
394 | },
395 | "metadata": {},
396 | "output_type": "display_data"
397 | },
398 | {
399 | "data": {
400 | "application/vnd.jupyter.widget-view+json": {
401 | "model_id": "2718b74502ab450fa0dc85252f0054c1",
402 | "version_major": 2,
403 | "version_minor": 0
404 | },
405 | "text/plain": [
406 | "0 tables [00:00, ? tables/s]"
407 | ]
408 | },
409 | "metadata": {},
410 | "output_type": "display_data"
411 | },
412 | {
413 | "data": {
414 | "application/vnd.jupyter.widget-view+json": {
415 | "model_id": "d5e53274d1164cbb96086c6e13ad253c",
416 | "version_major": 2,
417 | "version_minor": 0
418 | },
419 | "text/plain": [
420 | "0 tables [00:00, ? tables/s]"
421 | ]
422 | },
423 | "metadata": {},
424 | "output_type": "display_data"
425 | },
426 | {
427 | "name": "stdout",
428 | "output_type": "stream",
429 | "text": [
430 | "Dataset csv downloaded and prepared to /home/v-biyangguo/.cache/huggingface/datasets/csv/default-36cdcd5cd5588345/0.0.0/51cce309a08df9c4d82ffd9363bbe090bf173197fc01a71b034e8594995a1a58. Subsequent calls will reuse this data.\n"
431 | ]
432 | },
433 | {
434 | "data": {
435 | "application/vnd.jupyter.widget-view+json": {
436 | "model_id": "62fe50a754ca40f397b962ebe25bcdcd",
437 | "version_major": 2,
438 | "version_minor": 0
439 | },
440 | "text/plain": [
441 | " 0%| | 0/2 [00:00, ?it/s]"
442 | ]
443 | },
444 | "metadata": {},
445 | "output_type": "display_data"
446 | },
447 | {
448 | "data": {
449 | "text/plain": [
450 | "DatasetDict({\n",
451 | " train: Dataset({\n",
452 | " features: ['Unnamed: 0', 'question', 'sentence', 'idx', 'label'],\n",
453 | " num_rows: 73318\n",
454 | " })\n",
455 | " validation: Dataset({\n",
456 | " features: ['Unnamed: 0', 'question', 'sentence', 'idx', 'label'],\n",
457 | " num_rows: 31424\n",
458 | " })\n",
459 | "})"
460 | ]
461 | },
462 | "execution_count": 19,
463 | "metadata": {},
464 | "output_type": "execute_result"
465 | }
466 | ],
467 | "source": [
468 | "# 读取\n",
469 | "from datasets import load_dataset\n",
470 | "data_files = {}\n",
471 | "data_files[\"train\"] = f'datasets/{dataset_name}-easy-hard_train.csv'\n",
472 | "data_files[\"validation\"] = f'datasets/{dataset_name}-easy-hard_valid.csv'\n",
473 | "extension = 'csv'\n",
474 | "raw_datasets = load_dataset(extension, data_files=data_files)\n",
475 | "raw_datasets"
476 | ]
477 | },
478 | {
479 | "cell_type": "markdown",
480 | "metadata": {},
481 | "source": [
482 | "# Add `confidence` value to the dataset \n",
483 | "- 需要先跑 `plot.sh` 获得 `td_metrics.jsonl`"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": 22,
489 | "metadata": {},
490 | "outputs": [
491 | {
492 | "data": {
493 | "text/plain": [
494 | "DatasetDict({\n",
495 | " train: Dataset({\n",
496 | " features: ['sentence1', 'sentence2', 'idx', 'label'],\n",
497 | " num_rows: 2490\n",
498 | " })\n",
499 | " validation: Dataset({\n",
500 | " features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
501 | " num_rows: 277\n",
502 | " })\n",
503 | " test: Dataset({\n",
504 | " features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
505 | " num_rows: 3000\n",
506 | " })\n",
507 | "})"
508 | ]
509 | },
510 | "execution_count": 22,
511 | "metadata": {},
512 | "output_type": "execute_result"
513 | }
514 | ],
515 | "source": [
516 | "import datasets\n",
517 | "from datasets import load_dataset, load_from_disk\n",
518 | "dataset_name = 'rte-noisy-0.4'\n",
519 | "# data = load_dataset('super_glue',dataset_name) # normal GLUE or SuperGLUE datasets\n",
520 | "data = load_from_disk(f'datasets/{dataset_name}/with_idx') # other datasets that saved locally\n",
521 | "\n",
522 | "data"
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": 23,
528 | "metadata": {},
529 | "outputs": [
530 | {
531 | "name": "stdout",
532 | "output_type": "stream",
533 | "text": [
534 | "2490\n"
535 | ]
536 | },
537 | {
538 | "data": {
539 | "text/html": [
540 | "\n",
541 | "\n",
554 | "
\n",
555 | " \n",
556 | " \n",
557 | " | \n",
558 | " confidence | \n",
559 | " correctness | \n",
560 | " forgetfulness | \n",
561 | " guid | \n",
562 | " index | \n",
563 | " threshold_closeness | \n",
564 | " variability | \n",
565 | "
\n",
566 | " \n",
567 | " \n",
568 | " \n",
569 | " | 0 | \n",
570 | " 0.729539 | \n",
571 | " 5 | \n",
572 | " 0 | \n",
573 | " 1154 | \n",
574 | " 0 | \n",
575 | " 0.197312 | \n",
576 | " 0.188230 | \n",
577 | "
\n",
578 | " \n",
579 | " | 1 | \n",
580 | " 0.719761 | \n",
581 | " 5 | \n",
582 | " 0 | \n",
583 | " 1574 | \n",
584 | " 1 | \n",
585 | " 0.201705 | \n",
586 | " 0.147285 | \n",
587 | "
\n",
588 | " \n",
589 | " | 2 | \n",
590 | " 0.724613 | \n",
591 | " 4 | \n",
592 | " 1 | \n",
593 | " 698 | \n",
594 | " 2 | \n",
595 | " 0.199549 | \n",
596 | " 0.205787 | \n",
597 | "
\n",
598 | " \n",
599 | " | 3 | \n",
600 | " 0.764356 | \n",
601 | " 5 | \n",
602 | " 0 | \n",
603 | " 377 | \n",
604 | " 3 | \n",
605 | " 0.180116 | \n",
606 | " 0.162303 | \n",
607 | "
\n",
608 | " \n",
609 | " | 4 | \n",
610 | " 0.800285 | \n",
611 | " 5 | \n",
612 | " 0 | \n",
613 | " 2437 | \n",
614 | " 4 | \n",
615 | " 0.159829 | \n",
616 | " 0.177974 | \n",
617 | "
\n",
618 | " \n",
619 | "
\n",
620 | "
"
621 | ],
622 | "text/plain": [
623 | " confidence correctness forgetfulness guid index threshold_closeness \\\n",
624 | "0 0.729539 5 0 1154 0 0.197312 \n",
625 | "1 0.719761 5 0 1574 1 0.201705 \n",
626 | "2 0.724613 4 1 698 2 0.199549 \n",
627 | "3 0.764356 5 0 377 3 0.180116 \n",
628 | "4 0.800285 5 0 2437 4 0.159829 \n",
629 | "\n",
630 | " variability \n",
631 | "0 0.188230 \n",
632 | "1 0.147285 \n",
633 | "2 0.205787 \n",
634 | "3 0.162303 \n",
635 | "4 0.177974 "
636 | ]
637 | },
638 | "execution_count": 23,
639 | "metadata": {},
640 | "output_type": "execute_result"
641 | }
642 | ],
643 | "source": [
644 | "import pandas as pd\n",
645 | "td_df = pd.read_json(f'dy_log/{dataset_name}/bert-base-cased/td_metrics.jsonl', lines=True) # lines=True 是因为你加载的是jsonl文件,每行都是一个dictionary\n",
646 | "print(len(td_df))\n",
647 | "td_df.head()"
648 | ]
649 | },
650 | {
651 | "cell_type": "code",
652 | "execution_count": 24,
653 | "metadata": {},
654 | "outputs": [
655 | {
656 | "data": {
657 | "text/plain": [
658 | "2490"
659 | ]
660 | },
661 | "execution_count": 24,
662 | "metadata": {},
663 | "output_type": "execute_result"
664 | }
665 | ],
666 | "source": [
667 | "id2conf = {}\n",
668 | "for guid, conf in zip(list(td_df['guid']), list(td_df['confidence'])):\n",
669 | " id2conf[guid] = conf\n",
670 | "len(id2conf)"
671 | ]
672 | },
673 | {
674 | "cell_type": "code",
675 | "execution_count": 25,
676 | "metadata": {},
677 | "outputs": [
678 | {
679 | "data": {
680 | "text/plain": [
681 | "DatasetDict({\n",
682 | " train: Dataset({\n",
683 | " features: ['sentence1', 'sentence2', 'idx', 'label', 'confidence'],\n",
684 | " num_rows: 2490\n",
685 | " })\n",
686 | " validation: Dataset({\n",
687 | " features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
688 | " num_rows: 277\n",
689 | " })\n",
690 | " test: Dataset({\n",
691 | " features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
692 | " num_rows: 3000\n",
693 | " })\n",
694 | "})"
695 | ]
696 | },
697 | "execution_count": 25,
698 | "metadata": {},
699 | "output_type": "execute_result"
700 | }
701 | ],
702 | "source": [
703 | "train_ids = data['train']['idx']\n",
704 | "train_conf = [id2conf[id] for id in train_ids]\n",
705 | "data['train'] = data['train'].add_column('confidence', train_conf)\n",
706 | "\n",
707 | "data"
708 | ]
709 | },
710 | {
711 | "cell_type": "code",
712 | "execution_count": 26,
713 | "metadata": {},
714 | "outputs": [
715 | {
716 | "data": {
717 | "text/plain": [
718 | "[0.6924006224,\n",
719 | " 0.5262133539,\n",
720 | " 0.7324662924,\n",
721 | " 0.7629377484000001,\n",
722 | " 0.5932154834000001,\n",
723 | " 0.6104391217,\n",
724 | " 0.4496941358,\n",
725 | " 0.6968309402,\n",
726 | " 0.6517617762,\n",
727 | " 0.7604392648]"
728 | ]
729 | },
730 | "execution_count": 26,
731 | "metadata": {},
732 | "output_type": "execute_result"
733 | }
734 | ],
735 | "source": [
736 | "data['train']['confidence'][:10]"
737 | ]
738 | },
739 | {
740 | "cell_type": "code",
741 | "execution_count": 27,
742 | "metadata": {},
743 | "outputs": [],
744 | "source": [
745 | "data.save_to_disk(f\"datasets/{dataset_name}/with_conf/\")"
746 | ]
747 | },
748 | {
749 | "cell_type": "code",
750 | "execution_count": 7,
751 | "metadata": {},
752 | "outputs": [
753 | {
754 | "data": {
755 | "text/plain": [
756 | "DatasetDict({\n",
757 | " test: Dataset({\n",
758 | " features: ['premise', 'hypothesis', 'label'],\n",
759 | " num_rows: 9824\n",
760 | " })\n",
761 | " train: Dataset({\n",
762 | " features: ['premise', 'hypothesis', 'label', 'idx', 'confidence'],\n",
763 | " num_rows: 549367\n",
764 | " })\n",
765 | " validation: Dataset({\n",
766 | " features: ['premise', 'hypothesis', 'label'],\n",
767 | " num_rows: 9842\n",
768 | " })\n",
769 | "})"
770 | ]
771 | },
772 | "execution_count": 7,
773 | "metadata": {},
774 | "output_type": "execute_result"
775 | }
776 | ],
777 | "source": [
778 | "# 读取\n",
779 | "from datasets import load_from_disk\n",
780 | "reloaded_data = load_from_disk(f\"datasets/{dataset_name}/with_conf/\")\n",
781 | "reloaded_data\n"
782 | ]
783 | },
784 | {
785 | "cell_type": "code",
786 | "execution_count": 9,
787 | "metadata": {},
788 | "outputs": [
789 | {
790 | "data": {
791 | "text/plain": [
792 | "[1, 0, 2, 1, 0, 2, 0, 1, 2, 1]"
793 | ]
794 | },
795 | "execution_count": 9,
796 | "metadata": {},
797 | "output_type": "execute_result"
798 | }
799 | ],
800 | "source": [
801 | "reloaded_data['test']['label'][:10]"
802 | ]
803 | },
804 | {
805 | "cell_type": "code",
806 | "execution_count": 10,
807 | "metadata": {},
808 | "outputs": [
809 | {
810 | "data": {
811 | "text/plain": [
812 | "[1, 0, 2, 0, 1, 2, 2, 1, 0, 0]"
813 | ]
814 | },
815 | "execution_count": 10,
816 | "metadata": {},
817 | "output_type": "execute_result"
818 | }
819 | ],
820 | "source": [
821 | "reloaded_data['validation']['label'][:10]"
822 | ]
823 | },
824 | {
825 | "cell_type": "code",
826 | "execution_count": null,
827 | "metadata": {},
828 | "outputs": [],
829 | "source": []
830 | }
831 | ],
832 | "metadata": {
833 | "interpreter": {
834 | "hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f"
835 | },
836 | "kernelspec": {
837 | "display_name": "Python 3.6.10 64-bit ('base': conda)",
838 | "name": "python3"
839 | },
840 | "language_info": {
841 | "name": "python",
842 | "version": ""
843 | },
844 | "orig_nbformat": 4
845 | },
846 | "nbformat": 4,
847 | "nbformat_minor": 2
848 | }
--------------------------------------------------------------------------------
/HCT/test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from collections import defaultdict\n",
10 | "import json\n",
11 | "import logging\n",
12 | "import math\n",
13 | "import os\n",
14 | "import random\n",
15 | "from pathlib import Path\n",
16 | "\n",
17 | "import datasets\n",
18 | "import torch\n",
19 | "from torch import nn\n",
20 | "import torch.nn.functional as F\n",
21 | "from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n",
22 | "\n",
23 | "from datasets import load_dataset, load_metric\n",
24 | "from torch.utils.data import DataLoader\n",
25 | "from tqdm.auto import tqdm\n",
26 | "\n",
27 | "import transformers\n",
28 | "from transformers import (\n",
29 | " AutoConfig,\n",
30 | " AutoModel,\n",
31 | " AutoModelForSequenceClassification,\n",
32 | " AutoTokenizer,\n",
33 | " DataCollatorWithPadding,\n",
34 | " PretrainedConfig,\n",
35 | " SchedulerType,\n",
36 | " default_data_collator,\n",
37 | " get_scheduler,\n",
38 | ")"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 2,
44 | "metadata": {},
45 | "outputs": [
46 | {
47 | "name": "stdout",
48 | "output_type": "stream",
49 | "text": [
50 | "Downloading and preparing dataset glue/sst2 (download: 7.09 MiB, generated: 4.81 MiB, post-processed: Unknown size, total: 11.90 MiB) to /home/v-biyangguo/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...\n"
51 | ]
52 | },
53 | {
54 | "data": {
55 | "application/vnd.jupyter.widget-view+json": {
56 | "model_id": "9274b98cf63e4e20b154556afc1bcf91",
57 | "version_major": 2,
58 | "version_minor": 0
59 | },
60 | "text/plain": [
61 | "Downloading data: 0%| | 0.00/7.44M [00:00, ?B/s]"
62 | ]
63 | },
64 | "metadata": {},
65 | "output_type": "display_data"
66 | },
67 | {
68 | "data": {
69 | "application/vnd.jupyter.widget-view+json": {
70 | "model_id": "234e8f2b7e824737b11fc80a2aa51ffe",
71 | "version_major": 2,
72 | "version_minor": 0
73 | },
74 | "text/plain": [
75 | "Generating train split: 0%| | 0/67349 [00:00, ? examples/s]"
76 | ]
77 | },
78 | "metadata": {},
79 | "output_type": "display_data"
80 | },
81 | {
82 | "data": {
83 | "application/vnd.jupyter.widget-view+json": {
84 | "model_id": "928425bf8e7b4abcadfd6258c8603dcb",
85 | "version_major": 2,
86 | "version_minor": 0
87 | },
88 | "text/plain": [
89 | "Generating validation split: 0%| | 0/872 [00:00, ? examples/s]"
90 | ]
91 | },
92 | "metadata": {},
93 | "output_type": "display_data"
94 | },
95 | {
96 | "data": {
97 | "application/vnd.jupyter.widget-view+json": {
98 | "model_id": "1464f6e488444e6688725d758884c3fb",
99 | "version_major": 2,
100 | "version_minor": 0
101 | },
102 | "text/plain": [
103 | "Generating test split: 0%| | 0/1821 [00:00, ? examples/s]"
104 | ]
105 | },
106 | "metadata": {},
107 | "output_type": "display_data"
108 | },
109 | {
110 | "name": "stdout",
111 | "output_type": "stream",
112 | "text": [
113 | "Dataset glue downloaded and prepared to /home/v-biyangguo/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.\n"
114 | ]
115 | },
116 | {
117 | "data": {
118 | "application/vnd.jupyter.widget-view+json": {
119 | "model_id": "ff91ac1313ad461eb0e1f9ec9fa217c0",
120 | "version_major": 2,
121 | "version_minor": 0
122 | },
123 | "text/plain": [
124 | " 0%| | 0/3 [00:00, ?it/s]"
125 | ]
126 | },
127 | "metadata": {},
128 | "output_type": "display_data"
129 | }
130 | ],
131 | "source": [
132 | "from typing import List, Optional, Tuple, Union\n",
133 | "from datasets import load_dataset\n",
134 | "\n",
135 | "raw_datasets = load_dataset('glue','sst2')"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 3,
141 | "metadata": {},
142 | "outputs": [
143 | {
144 | "data": {
145 | "application/vnd.jupyter.widget-view+json": {
146 | "model_id": "970ec2e0be3f4ec591354e4ee2a95c2f",
147 | "version_major": 2,
148 | "version_minor": 0
149 | },
150 | "text/plain": [
151 | "Downloading: 0%| | 0.00/29.0 [00:00, ?B/s]"
152 | ]
153 | },
154 | "metadata": {},
155 | "output_type": "display_data"
156 | },
157 | {
158 | "data": {
159 | "application/vnd.jupyter.widget-view+json": {
160 | "model_id": "b8c8c1ecd10f48a69030fff9508c625e",
161 | "version_major": 2,
162 | "version_minor": 0
163 | },
164 | "text/plain": [
165 | "Downloading: 0%| | 0.00/570 [00:00, ?B/s]"
166 | ]
167 | },
168 | "metadata": {},
169 | "output_type": "display_data"
170 | },
171 | {
172 | "data": {
173 | "application/vnd.jupyter.widget-view+json": {
174 | "model_id": "6aa93d7649294a4f8c18a859f5071758",
175 | "version_major": 2,
176 | "version_minor": 0
177 | },
178 | "text/plain": [
179 | "Downloading: 0%| | 0.00/208k [00:00, ?B/s]"
180 | ]
181 | },
182 | "metadata": {},
183 | "output_type": "display_data"
184 | },
185 | {
186 | "data": {
187 | "application/vnd.jupyter.widget-view+json": {
188 | "model_id": "b5fb84b91c944182b7ece80094352bec",
189 | "version_major": 2,
190 | "version_minor": 0
191 | },
192 | "text/plain": [
193 | "Downloading: 0%| | 0.00/426k [00:00, ?B/s]"
194 | ]
195 | },
196 | "metadata": {},
197 | "output_type": "display_data"
198 | },
199 | {
200 | "data": {
201 | "application/vnd.jupyter.widget-view+json": {
202 | "model_id": "eb13645ce6eb448097216d9356d2da80",
203 | "version_major": 2,
204 | "version_minor": 0
205 | },
206 | "text/plain": [
207 | "Running tokenizer on dataset: 0%| | 0/68 [00:00, ?ba/s]"
208 | ]
209 | },
210 | "metadata": {},
211 | "output_type": "display_data"
212 | },
213 | {
214 | "data": {
215 | "application/vnd.jupyter.widget-view+json": {
216 | "model_id": "ec13f4cd46e34bf7a436994b7b455fb3",
217 | "version_major": 2,
218 | "version_minor": 0
219 | },
220 | "text/plain": [
221 | "Running tokenizer on dataset: 0%| | 0/1 [00:00, ?ba/s]"
222 | ]
223 | },
224 | "metadata": {},
225 | "output_type": "display_data"
226 | },
227 | {
228 | "data": {
229 | "application/vnd.jupyter.widget-view+json": {
230 | "model_id": "ea320e42234e438fae5be3918a8186c7",
231 | "version_major": 2,
232 | "version_minor": 0
233 | },
234 | "text/plain": [
235 | "Running tokenizer on dataset: 0%| | 0/2 [00:00, ?ba/s]"
236 | ]
237 | },
238 | "metadata": {},
239 | "output_type": "display_data"
240 | },
241 | {
242 | "data": {
243 | "text/plain": [
244 | "DatasetDict({\n",
245 | " train: Dataset({\n",
246 | " features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
247 | " num_rows: 67349\n",
248 | " })\n",
249 | " validation: Dataset({\n",
250 | " features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
251 | " num_rows: 872\n",
252 | " })\n",
253 | " test: Dataset({\n",
254 | " features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
255 | " num_rows: 1821\n",
256 | " })\n",
257 | "})"
258 | ]
259 | },
260 | "execution_count": 3,
261 | "metadata": {},
262 | "output_type": "execute_result"
263 | }
264 | ],
265 | "source": [
266 | "tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')\n",
267 | "def preprocess_function(examples):\n",
268 | " # Tokenize the texts\n",
269 | " texts = (examples['sentence'], None)\n",
270 | " result = tokenizer(*texts, padding=\"max_length\", max_length=100, truncation=True)\n",
271 | " if \"label\" in examples:\n",
272 | " result[\"labels\"] = examples[\"label\"]\n",
273 | " return result\n",
274 | "\n",
275 | "\n",
276 | "processed_datasets = raw_datasets.map(\n",
277 | " preprocess_function,\n",
278 | " batched=True,\n",
279 | " remove_columns=raw_datasets[\"train\"].column_names, \n",
280 | " desc=\"Running tokenizer on dataset\",)\n",
281 | "processed_datasets"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "metadata": {},
288 | "outputs": [],
289 | "source": []
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 4,
294 | "metadata": {},
295 | "outputs": [
296 | {
297 | "data": {
298 | "text/plain": [
299 | "(dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels']),\n",
300 | " torch.Size([8, 100]))"
301 | ]
302 | },
303 | "execution_count": 4,
304 | "metadata": {},
305 | "output_type": "execute_result"
306 | }
307 | ],
308 | "source": [
309 | "data_collator = DataCollatorWithPadding(tokenizer)\n",
310 | "dataloader = DataLoader(processed_datasets['validation'], shuffle=True, collate_fn=data_collator, batch_size=8)\n",
311 | "c = 0\n",
312 | "for batch in dataloader:\n",
313 | " pass\n",
314 | "batch.keys(), batch['input_ids'].shape"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": 5,
320 | "metadata": {},
321 | "outputs": [
322 | {
323 | "data": {
324 | "text/plain": [
325 | "4"
326 | ]
327 | },
328 | "execution_count": 5,
329 | "metadata": {},
330 | "output_type": "execute_result"
331 | }
332 | ],
333 | "source": [
334 | "len(batch)"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 6,
340 | "metadata": {},
341 | "outputs": [
342 | {
343 | "data": {
344 | "application/vnd.jupyter.widget-view+json": {
345 | "model_id": "1005d24c983d4455aae050c7c3819d31",
346 | "version_major": 2,
347 | "version_minor": 0
348 | },
349 | "text/plain": [
350 | "Downloading: 0%| | 0.00/416M [00:00, ?B/s]"
351 | ]
352 | },
353 | "metadata": {},
354 | "output_type": "display_data"
355 | },
356 | {
357 | "name": "stderr",
358 | "output_type": "stream",
359 | "text": [
360 | "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n",
361 | "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
362 | "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
363 | ]
364 | }
365 | ],
366 | "source": [
367 | "config = AutoConfig.from_pretrained('bert-base-cased', num_labels=2)\n",
368 | "encoder = AutoModel.from_pretrained('bert-base-cased')\n",
369 | "classifier_dropout = (config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)\n",
370 | "dropout = nn.Dropout(classifier_dropout)\n",
371 | "classifier_easy = nn.Linear(config.hidden_size, config.num_labels)\n",
372 | "classifier_hard = nn.Linear(config.hidden_size, config.num_labels)\n",
373 | "# 2 experts: easy or hard\n",
374 | "hardness_gate = nn.Linear(config.hidden_size,2) "
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": 8,
380 | "metadata": {},
381 | "outputs": [
382 | {
383 | "data": {
384 | "text/plain": [
385 | "torch.Size([8, 768])"
386 | ]
387 | },
388 | "execution_count": 8,
389 | "metadata": {},
390 | "output_type": "execute_result"
391 | }
392 | ],
393 | "source": [
394 | "batch_wo_labels = {k:v for k,v in batch.items() if k != 'labels'}\n",
395 | "outputs = encoder(**batch_wo_labels)\n",
396 | "pooled_output = outputs[1]\n",
397 | "pooled_output.shape"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": 17,
403 | "metadata": {},
404 | "outputs": [
405 | {
406 | "data": {
407 | "text/plain": [
408 | "tensor([[ 0.1883, -0.2964],\n",
409 | " [ 0.1280, -0.3338],\n",
410 | " [ 0.1373, -0.3579],\n",
411 | " [ 0.2173, -0.2385],\n",
412 | " [ 0.0645, -0.3355],\n",
413 | " [ 0.1644, -0.2774],\n",
414 | " [ 0.1745, -0.3913],\n",
415 | " [ 0.2355, -0.2404]], grad_fn=)"
416 | ]
417 | },
418 | "execution_count": 17,
419 | "metadata": {},
420 | "output_type": "execute_result"
421 | }
422 | ],
423 | "source": [
424 | "# model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')\n",
425 | "logits = model(**batch).logits\n",
426 | "logits"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": 20,
432 | "metadata": {},
433 | "outputs": [
434 | {
435 | "data": {
436 | "text/plain": [
437 | "tensor(7.4759e-09, grad_fn=)"
438 | ]
439 | },
440 | "execution_count": 20,
441 | "metadata": {},
442 | "output_type": "execute_result"
443 | }
444 | ],
445 | "source": [
446 | "# from torch import nn\n",
447 | "# kld_loss_fct = nn.KLDivLoss(reduction=\"batchmean\")\n",
448 | "\n",
449 | "kld_loss_fct(\n",
450 | " nn.functional.log_softmax(logits / 1, dim=-1),\n",
451 | " nn.functional.softmax(logits / 1, dim=-1),\n",
452 | " ) * (1) ** 2"
453 | ]
454 | },
455 | {
456 | "cell_type": "code",
457 | "execution_count": 30,
458 | "metadata": {},
459 | "outputs": [
460 | {
461 | "data": {
462 | "text/plain": [
463 | "(torch.Size([8, 2]),\n",
464 | " torch.Size([8, 2]),\n",
465 | " tensor([[0.6097, 0.3903],\n",
466 | " [0.5698, 0.4302],\n",
467 | " [0.5829, 0.4171],\n",
468 | " [0.5860, 0.4140],\n",
469 | " [0.5920, 0.4080],\n",
470 | " [0.5839, 0.4161],\n",
471 | " [0.5857, 0.4143],\n",
472 | " [0.5875, 0.4125]], grad_fn=))"
473 | ]
474 | },
475 | "execution_count": 30,
476 | "metadata": {},
477 | "output_type": "execute_result"
478 | }
479 | ],
480 | "source": [
481 | "T = 1\n",
482 | "logits_gate = hardness_gate(pooled_output)\n",
483 | "gate_weights = F.softmax(logits_gate/T, dim=1)\n",
484 | "logits_gate.shape, gate_weights.shape, gate_weights"
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {},
490 | "source": [
491 | "## 只使用 `easy expert` 进行训练和预测\n",
492 | "sample-weight ranking (highest to lowest):\n",
493 | "- ambi-easy\n",
494 | "- easy\n",
495 | "- ambi-hard\n",
496 | "- hard"
497 | ]
498 | },
499 | {
500 | "cell_type": "code",
501 | "execution_count": 31,
502 | "metadata": {},
503 | "outputs": [
504 | {
505 | "data": {
506 | "text/plain": [
507 | "tensor([0.6097, 0.5698, 0.5829, 0.5860, 0.5920, 0.5839, 0.5857, 0.5875],\n",
508 | " grad_fn=)"
509 | ]
510 | },
511 | "execution_count": 31,
512 | "metadata": {},
513 | "output_type": "execute_result"
514 | }
515 | ],
516 | "source": [
517 | "easy_probs = gate_weights[:,0]\n",
518 | "easy_probs"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": 37,
524 | "metadata": {},
525 | "outputs": [
526 | {
527 | "data": {
528 | "text/plain": [
529 | "(tensor([0.8903, 0.9302, 0.9171, 0.9140, 0.9080, 0.9161, 0.9143, 0.9125],\n",
530 | " grad_fn=),\n",
531 | " tensor([0.9753, 1.0191, 1.0047, 1.0013, 0.9947, 1.0036, 1.0017, 0.9997],\n",
532 | " grad_fn=))"
533 | ]
534 | },
535 | "execution_count": 37,
536 | "metadata": {},
537 | "output_type": "execute_result"
538 | }
539 | ],
540 | "source": [
541 | "easy_weights = torch.where(easy_probs>0.5, 1-torch.abs(easy_probs-0.5), easy_probs)\n",
542 | "# 归一化\n",
543 | "batch_size = 8\n",
544 | "easy_weights, easy_weights * batch_size / torch.sum(easy_weights)"
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": 40,
550 | "metadata": {},
551 | "outputs": [
552 | {
553 | "data": {
554 | "text/plain": [
555 | "(tensor([0.4060, 0.3822, 0.9169, 0.4725, 1.1667, 0.2972, 0.4749, 0.9775],\n",
556 | " grad_fn=),\n",
557 | " tensor(0.6367, grad_fn=))"
558 | ]
559 | },
560 | "execution_count": 40,
561 | "metadata": {},
562 | "output_type": "execute_result"
563 | }
564 | ],
565 | "source": [
566 | "easy_weights = easy_weights * batch_size / torch.sum(easy_weights)\n",
567 | "example_loss = torch.tensor([0.4163, 0.3751, 0.9126, 0.4719, 1.1729, 0.2961, 0.4741, 0.9778])\n",
568 | "easy_weights * example_loss, torch.mean(easy_weights * example_loss)"
569 | ]
570 | },
571 | {
572 | "cell_type": "markdown",
573 | "metadata": {},
574 | "source": [
575 | "end of 只使用 `easy expert` 进行训练和预测.\n",
576 | "---"
577 | ]
578 | },
579 | {
580 | "cell_type": "code",
581 | "execution_count": 22,
582 | "metadata": {},
583 | "outputs": [
584 | {
585 | "data": {
586 | "text/plain": [
587 | "(tensor([[0.7000, 0.3000],\n",
588 | " [0.4000, 0.6000]]),\n",
589 | " tensor([[0.8000, 0.3000],\n",
590 | " [0.4000, 0.9000]]))"
591 | ]
592 | },
593 | "execution_count": 22,
594 | "metadata": {},
595 | "output_type": "execute_result"
596 | }
597 | ],
598 | "source": [
599 | "X = torch.tensor([[0.7,0.3],[0.4,0.6]])\n",
600 | "X, torch.where(X>0.5, 1-torch.abs(X-0.5), X)"
601 | ]
602 | },
603 | {
604 | "cell_type": "code",
605 | "execution_count": 24,
606 | "metadata": {},
607 | "outputs": [],
608 | "source": [
609 | "pooled_output = dropout(pooled_output)\n",
610 | "logits_easy = classifier_easy(pooled_output)\n",
611 | "logits_hard = classifier_hard(pooled_output)  # hard expert uses its own head (was classifier_easy)\n"
612 | ]
613 | },
614 | {
615 | "cell_type": "code",
616 | "execution_count": 28,
617 | "metadata": {},
618 | "outputs": [
619 | {
620 | "data": {
621 | "text/plain": [
622 | "(tensor([0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000]),\n",
623 | " tensor([[0.7000],\n",
624 | " [0.7000],\n",
625 | " [0.7000],\n",
626 | " [0.7000],\n",
627 | " [0.7000],\n",
628 | " [0.7000],\n",
629 | " [0.7000],\n",
630 | " [0.7000]]),\n",
631 | " tensor([[0.3000],\n",
632 | " [0.3000],\n",
633 | " [0.3000],\n",
634 | " [0.3000],\n",
635 | " [0.3000],\n",
636 | " [0.3000],\n",
637 | " [0.3000],\n",
638 | " [0.3000]]),\n",
639 | " tensor([[0.7000, 0.3000],\n",
640 | " [0.7000, 0.3000],\n",
641 | " [0.7000, 0.3000],\n",
642 | " [0.7000, 0.3000],\n",
643 | " [0.7000, 0.3000],\n",
644 | " [0.7000, 0.3000],\n",
645 | " [0.7000, 0.3000],\n",
646 | " [0.7000, 0.3000]]))"
647 | ]
648 | },
649 | "execution_count": 28,
650 | "metadata": {},
651 | "output_type": "execute_result"
652 | }
653 | ],
654 | "source": [
655 | "confidences = torch.tensor([0.7]*8)\n",
656 | "confidences\n",
657 | "easy_probs = confidences.view(-1,1)\n",
658 | "hard_probs = 1 - easy_probs\n",
659 | "hardness_probs = torch.cat([easy_probs,hard_probs],dim=1)\n",
660 | "confidences, easy_probs, hard_probs, hardness_probs"
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "execution_count": 35,
666 | "metadata": {},
667 | "outputs": [
668 | {
669 | "data": {
670 | "text/plain": [
671 | "(tensor([[-0.6693, -0.7176],\n",
672 | " [-0.6422, -0.7469],\n",
673 | " [-0.6316, -0.7587],\n",
674 | " [-0.6843, -0.7021],\n",
675 | " [-0.6417, -0.7474],\n",
676 | " [-0.6898, -0.6965],\n",
677 | " [-0.6593, -0.7282],\n",
678 | " [-0.6960, -0.6903]], grad_fn=),\n",
679 | " tensor([[0.7000, 0.3000],\n",
680 | " [0.7000, 0.3000],\n",
681 | " [0.7000, 0.3000],\n",
682 | " [0.7000, 0.3000],\n",
683 | " [0.7000, 0.3000],\n",
684 | " [0.7000, 0.3000],\n",
685 | " [0.7000, 0.3000],\n",
686 | " [0.7000, 0.3000]]))"
687 | ]
688 | },
689 | "execution_count": 35,
690 | "metadata": {},
691 | "output_type": "execute_result"
692 | }
693 | ],
694 | "source": [
695 | "F.log_softmax(logits_gate, dim=-1),hardness_probs"
696 | ]
697 | },
698 | {
699 | "cell_type": "code",
700 | "execution_count": 42,
701 | "metadata": {},
702 | "outputs": [
703 | {
704 | "data": {
705 | "text/plain": [
706 | "tensor(0.0712, grad_fn=)"
707 | ]
708 | },
709 | "execution_count": 42,
710 | "metadata": {},
711 | "output_type": "execute_result"
712 | }
713 | ],
714 | "source": [
715 | "loss_gate = F.kl_div(F.log_softmax(logits_gate, dim=-1), hardness_probs, reduction='batchmean')\n",
716 | "loss_gate"
717 | ]
718 | },
719 | {
720 | "cell_type": "code",
721 | "execution_count": 50,
722 | "metadata": {},
723 | "outputs": [
724 | {
725 | "data": {
726 | "text/plain": [
727 | "(tensor([0.4163, 0.3751, 0.9126, 0.4719, 1.1729, 0.2961, 0.4741, 0.9778],\n",
728 | " grad_fn=),\n",
729 | " tensor([0.4163, 0.3751, 0.9126, 0.4719, 1.1729, 0.2961, 0.4741, 0.9778],\n",
730 | " grad_fn=))"
731 | ]
732 | },
733 | "execution_count": 50,
734 | "metadata": {},
735 | "output_type": "execute_result"
736 | }
737 | ],
738 | "source": [
739 | "num_labels = config.num_labels\n",
740 | "labels = batch['labels']\n",
741 | "loss_fct = CrossEntropyLoss(reduction='none')\n",
742 | "loss_easy = loss_fct(logits_easy.view(-1, num_labels), labels.view(-1))\n",
743 | "loss_hard = loss_fct(logits_hard.view(-1, num_labels), labels.view(-1))\n",
744 | "loss_easy, loss_hard"
745 | ]
746 | },
747 | {
748 | "cell_type": "code",
749 | "execution_count": 54,
750 | "metadata": {},
751 | "outputs": [
752 | {
753 | "data": {
754 | "text/plain": [
755 | "tensor([[0.4163, 0.4163],\n",
756 | " [0.3751, 0.3751],\n",
757 | " [0.9126, 0.9126],\n",
758 | " [0.4719, 0.4719],\n",
759 | " [1.1729, 1.1729],\n",
760 | " [0.2961, 0.2961],\n",
761 | " [0.4741, 0.4741],\n",
762 | " [0.9778, 0.9778]], grad_fn=)"
763 | ]
764 | },
765 | "execution_count": 54,
766 | "metadata": {},
767 | "output_type": "execute_result"
768 | }
769 | ],
770 | "source": [
771 | "easy_hard_loss_cat = torch.cat([loss_easy.view(-1,1), loss_hard.view(-1,1)],dim=1)\n",
772 | "easy_hard_loss_cat"
773 | ]
774 | },
775 | {
776 | "cell_type": "code",
777 | "execution_count": 55,
778 | "metadata": {},
779 | "outputs": [
780 | {
781 | "data": {
782 | "text/plain": [
783 | "(tensor([[0.4163, 0.4163],\n",
784 | " [0.3751, 0.3751],\n",
785 | " [0.9126, 0.9126],\n",
786 | " [0.4719, 0.4719],\n",
787 | " [1.1729, 1.1729],\n",
788 | " [0.2961, 0.2961],\n",
789 | " [0.4741, 0.4741],\n",
790 | " [0.9778, 0.9778]], grad_fn=),\n",
791 | " tensor([[0.5121, 0.4879],\n",
792 | " [0.5262, 0.4738],\n",
793 | " [0.5317, 0.4683],\n",
794 | " [0.5044, 0.4956],\n",
795 | " [0.5264, 0.4736],\n",
796 | " [0.5017, 0.4983],\n",
797 | " [0.5172, 0.4828],\n",
798 | " [0.4986, 0.5014]], grad_fn=),\n",
799 | " tensor([[0.2132, 0.2031],\n",
800 | " [0.1974, 0.1777],\n",
801 | " [0.4853, 0.4274],\n",
802 | " [0.2380, 0.2338],\n",
803 | " [0.6174, 0.5555],\n",
804 | " [0.1486, 0.1476],\n",
805 | " [0.2452, 0.2289],\n",
806 | " [0.4875, 0.4903]], grad_fn=))"
807 | ]
808 | },
809 | "execution_count": 55,
810 | "metadata": {},
811 | "output_type": "execute_result"
812 | }
813 | ],
814 | "source": [
815 | "easy_hard_loss_cat, gate_weights, easy_hard_loss_cat * gate_weights"
816 | ]
817 | },
818 | {
819 | "cell_type": "code",
820 | "execution_count": 80,
821 | "metadata": {},
822 | "outputs": [
823 | {
824 | "data": {
825 | "text/plain": [
826 | "tensor(0.3185, grad_fn=)"
827 | ]
828 | },
829 | "execution_count": 80,
830 | "metadata": {},
831 | "output_type": "execute_result"
832 | }
833 | ],
834 | "source": [
835 | "weighted_loss = easy_hard_loss_cat * gate_weights\n",
836 | "torch.mean(weighted_loss)"
837 | ]
838 | },
839 | {
840 | "cell_type": "code",
841 | "execution_count": 111,
842 | "metadata": {},
843 | "outputs": [
844 | {
845 | "name": "stderr",
846 | "output_type": "stream",
847 | "text": [
848 | "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']\n",
849 | "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
850 | "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
851 | ]
852 | }
853 | ],
854 | "source": [
855 | "from transformers.modeling_outputs import SequenceClassifierOutput\n",
856 | "\n",
857 | "class HCTForSequenceClassification(nn.Module):\n",
858 | "    \"\"\"Hardness-aware classifier: a shared encoder feeds a 2-way hardness gate and two expert heads (easy / hard).\"\"\"\n",
859 | "    def __init__(self, model_name_or_path, config):\n",
860 | "        super(HCTForSequenceClassification, self).__init__()\n",
861 | "        self.encoder = AutoModel.from_pretrained(model_name_or_path)\n",
862 | "        self.config = config\n",
863 | "        self.num_labels = config.num_labels\n",
864 | "        self.classifier_dropout = (self.config.classifier_dropout if self.config.classifier_dropout is not None else self.config.hidden_dropout_prob)\n",
865 | "        self.dropout = nn.Dropout(self.classifier_dropout)\n",
866 | "        # two expert heads of identical shape; the gate weights their per-sample losses\n",
867 | "        self.classifier_easy = nn.Linear(self.config.hidden_size, config.num_labels)\n",
868 | "        self.classifier_hard = nn.Linear(self.config.hidden_size, config.num_labels)\n",
869 | "        # 2 experts: easy or hard\n",
870 | "        # gate output: 0 for easy, 1 for hard\n",
871 | "        self.hardness_gate = nn.Linear(self.config.hidden_size, 2)\n",
872 | "\n",
873 | "\n",
874 | "    def forward(\n",
875 | "        self,\n",
876 | "        input_ids: Optional[torch.Tensor] = None,\n",
877 | "        attention_mask: Optional[torch.Tensor] = None,\n",
878 | "        token_type_ids: Optional[torch.Tensor] = None,\n",
879 | "        position_ids: Optional[torch.Tensor] = None,\n",
880 | "        head_mask: Optional[torch.Tensor] = None,\n",
881 | "        inputs_embeds: Optional[torch.Tensor] = None,\n",
882 | "        labels: Optional[torch.Tensor] = None,\n",
883 | "        confidences: Optional[torch.Tensor] = None, # the confidence value of a sample\n",
884 | "        output_attentions: Optional[bool] = None,\n",
885 | "        output_hidden_states: Optional[bool] = None,\n",
886 | "        return_dict: Optional[bool] = None,\n",
887 | "    ):\n",
888 | "        r\"\"\"\n",
889 | "        labels (`torch.Tensor`, *optional*): targets; the loss type follows `config.problem_type`, inferred from\n",
890 | "            `num_labels` and the label dtype when unset. Each expert's per-sample loss is weighted by the gate.\n",
891 | "        confidences (`torch.Tensor`, one value per sample, *optional*): target probability of the easy expert,\n",
892 | "            used for the KL gate loss; when omitted only the gate-weighted classification loss is returned.\n",
893 | "        \"\"\"\n",
894 | "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
895 | "\n",
896 | "        outputs = self.encoder(\n",
897 | "            input_ids,\n",
898 | "            attention_mask=attention_mask,\n",
899 | "            token_type_ids=token_type_ids,\n",
900 | "            position_ids=position_ids,\n",
901 | "            head_mask=head_mask,\n",
902 | "            inputs_embeds=inputs_embeds,\n",
903 | "            output_attentions=output_attentions,\n",
904 | "            output_hidden_states=output_hidden_states,\n",
905 | "            return_dict=return_dict,\n",
906 | "        )\n",
907 | "        # the CLS vector\n",
908 | "        pooled_output = outputs[1]\n",
909 | "        # gating: column 0 = easy-expert weight, column 1 = hard-expert weight\n",
910 | "        logits_gate = self.hardness_gate(pooled_output)\n",
911 | "        gate_weights = F.softmax(logits_gate, dim=-1)\n",
912 | "        # easy/hard experts:\n",
913 | "        pooled_output = self.dropout(pooled_output)\n",
914 | "        logits_easy = self.classifier_easy(pooled_output)\n",
915 | "        logits_hard = self.classifier_hard(pooled_output)\n",
916 | "\n",
917 | "        loss = None\n",
918 | "        loss_easy, loss_hard, gate_loss = None, None, None\n",
919 | "        if labels is not None:\n",
920 | "            if self.config.problem_type is None:\n",
921 | "                if self.num_labels == 1:\n",
922 | "                    self.config.problem_type = \"regression\"\n",
923 | "                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n",
924 | "                    self.config.problem_type = \"single_label_classification\"\n",
925 | "                else:\n",
926 | "                    self.config.problem_type = \"multi_label_classification\"\n",
927 | "\n",
928 | "            # gating loss: `confidences` is the target probability of the easy expert,\n",
929 | "            # so the hard expert's target is 1 - confidence (skipped when not provided)\n",
930 | "            if confidences is not None:\n",
931 | "                easy_probs = confidences.view(-1, 1).to(logits_gate.dtype)\n",
932 | "                hardness_probs = torch.cat([easy_probs, 1 - easy_probs], dim=1)\n",
933 | "                gate_loss = F.kl_div(F.log_softmax(logits_gate, dim=-1), hardness_probs, reduction='batchmean')\n",
934 | "            if self.config.problem_type == \"regression\":\n",
935 | "                loss_fct = MSELoss(reduction='none')\n",
936 | "                if self.num_labels == 1:\n",
937 | "                    loss_easy = loss_fct(logits_easy.squeeze(), labels.squeeze())\n",
938 | "                    loss_hard = loss_fct(logits_hard.squeeze(), labels.squeeze())\n",
939 | "                else:\n",
940 | "                    loss_easy = loss_fct(logits_easy, labels)\n",
941 | "                    loss_hard = loss_fct(logits_hard, labels)\n",
942 | "            elif self.config.problem_type == \"single_label_classification\":\n",
943 | "                loss_fct = CrossEntropyLoss(reduction='none')  # reduction='none' keeps one loss value per sample\n",
944 | "                loss_easy = loss_fct(logits_easy.view(-1, self.num_labels), labels.view(-1))\n",
945 | "                loss_hard = loss_fct(logits_hard.view(-1, self.num_labels), labels.view(-1))\n",
946 | "            elif self.config.problem_type == \"multi_label_classification\":\n",
947 | "                loss_fct = BCEWithLogitsLoss(reduction='none')\n",
948 | "                loss_easy = loss_fct(logits_easy, labels)\n",
949 | "                loss_hard = loss_fct(logits_hard, labels)\n",
950 | "            # per-sample loss of each expert (multi-output losses are averaged per sample), shape (batch, 2)\n",
951 | "            easy_hard_loss_cat = torch.stack([expert_loss.reshape(gate_weights.size(0), -1).mean(dim=-1) for expert_loss in (loss_easy, loss_hard)], dim=1)\n",
952 | "            weighted_loss = easy_hard_loss_cat * gate_weights\n",
953 | "            clf_loss = torch.mean(weighted_loss)\n",
954 | "            loss = clf_loss if gate_loss is None else gate_loss + clf_loss\n",
955 | "\n",
956 | "        if not return_dict:\n",
957 | "            output = (logits_gate, logits_easy, logits_hard) + outputs[2:]\n",
958 | "            return ((loss,) + output) if loss is not None else output\n",
959 | "\n",
960 | "        return SequenceClassifierOutput(\n",
961 | "            loss=loss,\n",
962 | "            logits={\"gate\": logits_gate, \"easy\": logits_easy, \"hard\": logits_hard},\n",
963 | "            hidden_states=outputs.hidden_states,\n",
964 | "            attentions=outputs.attentions,\n",
965 | "        )\n",
966 | "\n",
967 | "my_model = HCTForSequenceClassification('bert-base-cased', config)"
968 | ]
969 | },
970 | {
971 | "cell_type": "code",
972 | "execution_count": 112,
973 | "metadata": {},
974 | "outputs": [
975 | {
976 | "name": "stderr",
977 | "output_type": "stream",
978 | "text": [
979 | "ipykernel_launcher:57: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
980 | ]
981 | },
982 | {
983 | "data": {
984 | "text/plain": [
985 | "(tensor(0.5328, grad_fn=),\n",
986 | " {'gate': tensor([[0.4412, 0.2590],\n",
987 | " [0.4171, 0.3171],\n",
988 | " [0.4294, 0.3268],\n",
989 | " [0.4348, 0.3207],\n",
990 | " [0.3658, 0.3255],\n",
991 | " [0.4153, 0.2928],\n",
992 | " [0.3959, 0.2656],\n",
993 | " [0.4340, 0.2752]], grad_fn=),\n",
994 | " 'easy': tensor([[ 0.0633, 0.9459],\n",
995 | " [-0.1190, 0.8201],\n",
996 | " [ 0.1919, 0.9389],\n",
997 | " [-0.2952, 0.6612],\n",
998 | " [ 0.0297, 0.8347],\n",
999 | " [-0.2183, 1.2751],\n",
1000 | " [-0.1609, 1.0679],\n",
1001 | " [-0.1628, 0.9590]], grad_fn=),\n",
1002 | " 'hard': tensor([[-0.2337, 0.2637],\n",
1003 | " [-0.3422, 0.3396],\n",
1004 | " [-0.2908, 0.3610],\n",
1005 | " [-0.2687, 0.1167],\n",
1006 | " [-0.1619, -0.1485],\n",
1007 | " [-0.7175, 0.3527],\n",
1008 | " [-0.4886, 0.2593],\n",
1009 | " [-0.1387, 0.3856]], grad_fn=)})"
1010 | ]
1011 | },
1012 | "execution_count": 112,
1013 | "metadata": {},
1014 | "output_type": "execute_result"
1015 | }
1016 | ],
1017 | "source": [
1018 | "my_outputs = my_model(**batch, confidences=confidences)\n",
1019 | "my_outputs.loss, my_outputs.logits"
1020 | ]
1021 | },
1022 | {
1023 | "cell_type": "code",
1024 | "execution_count": 107,
1025 | "metadata": {},
1026 | "outputs": [
1027 | {
1028 | "data": {
1029 | "text/plain": [
1030 | "(tensor([1, 1, 1, 1, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 1, 1, 1, 1]))"
1031 | ]
1032 | },
1033 | "execution_count": 107,
1034 | "metadata": {},
1035 | "output_type": "execute_result"
1036 | }
1037 | ],
1038 | "source": [
1039 | "# indices = torch.argmax(my_outputs.logits['gate'],dim=1)\n",
1040 | "indices = torch.tensor([1,1,1,1,0,0,0,0])\n",
1041 | "indices, 1-indices"
1042 | ]
1043 | },
1044 | {
1045 | "cell_type": "code",
1046 | "execution_count": 113,
1047 | "metadata": {},
1048 | "outputs": [
1049 | {
1050 | "data": {
1051 | "text/plain": [
1052 | "tensor([[-0.2337, 0.2637],\n",
1053 | " [-0.3422, 0.3396],\n",
1054 | " [-0.2908, 0.3610],\n",
1055 | " [-0.2687, 0.1167],\n",
1056 | " [ 0.0297, 0.8347],\n",
1057 | " [-0.2183, 1.2751],\n",
1058 | " [-0.1609, 1.0679],\n",
1059 | " [-0.1628, 0.9590]], grad_fn=)"
1060 | ]
1061 | },
1062 | "execution_count": 113,
1063 | "metadata": {},
1064 | "output_type": "execute_result"
1065 | }
1066 | ],
1067 | "source": [
1068 | "my_outputs.logits['easy'] * (1-indices).view(-1,1), my_outputs.logits['hard'] * indices.view(-1,1)\n",
1069 | "\n",
1070 | "my_outputs.logits['easy'] * (1-indices).view(-1,1) + my_outputs.logits['hard'] * indices.view(-1,1)"
1071 | ]
1072 | },
1073 | {
1074 | "cell_type": "code",
1075 | "execution_count": 114,
1076 | "metadata": {},
1077 | "outputs": [
1078 | {
1079 | "data": {
1080 | "text/plain": [
1081 | "(tensor([[ 0.0633, 0.9459],\n",
1082 | " [-0.1190, 0.8201],\n",
1083 | " [ 0.1919, 0.9389],\n",
1084 | " [-0.2952, 0.6612],\n",
1085 | " [ 0.0297, 0.8347],\n",
1086 | " [-0.2183, 1.2751],\n",
1087 | " [-0.1609, 1.0679],\n",
1088 | " [-0.1628, 0.9590]], grad_fn=),\n",
1089 | " tensor([[-0.2337, 0.2637],\n",
1090 | " [-0.3422, 0.3396],\n",
1091 | " [-0.2908, 0.3610],\n",
1092 | " [-0.2687, 0.1167],\n",
1093 | " [-0.1619, -0.1485],\n",
1094 | " [-0.7175, 0.3527],\n",
1095 | " [-0.4886, 0.2593],\n",
1096 | " [-0.1387, 0.3856]], grad_fn=))"
1097 | ]
1098 | },
1099 | "execution_count": 114,
1100 | "metadata": {},
1101 | "output_type": "execute_result"
1102 | }
1103 | ],
1104 | "source": [
1105 | "my_outputs.logits['easy'], my_outputs.logits['hard']"
1106 | ]
1107 | },
1108 | {
1109 | "cell_type": "code",
1110 | "execution_count": 117,
1111 | "metadata": {},
1112 | "outputs": [
1113 | {
1114 | "data": {
1115 | "text/plain": [
1116 | "(tensor([[0.5454, 0.4546],\n",
1117 | " [0.5250, 0.4750],\n",
1118 | " [0.5256, 0.4744],\n",
1119 | " [0.5285, 0.4715],\n",
1120 | " [0.5101, 0.4899],\n",
1121 | " [0.5306, 0.4694],\n",
1122 | " [0.5325, 0.4675],\n",
1123 | " [0.5396, 0.4604]], grad_fn=),\n",
1124 | " tensor([0.5454, 0.5250, 0.5256, 0.5285, 0.5101, 0.5306, 0.5325, 0.5396],\n",
1125 | " grad_fn=),\n",
1126 | " tensor([0.4546, 0.4750, 0.4744, 0.4715, 0.4899, 0.4694, 0.4675, 0.4604],\n",
1127 | " grad_fn=))"
1128 | ]
1129 | },
1130 | "execution_count": 117,
1131 | "metadata": {},
1132 | "output_type": "execute_result"
1133 | }
1134 | ],
1135 | "source": [
1136 | "weights = F.softmax(my_outputs.logits['gate'], dim=1)\n",
1137 | "weights, weights[:,0], weights[:,1] # 第0列是easy的权重,第1列是hard的权重"
1138 | ]
1139 | },
1140 | {
1141 | "cell_type": "code",
1142 | "execution_count": 120,
1143 | "metadata": {},
1144 | "outputs": [
1145 | {
1146 | "data": {
1147 | "text/plain": [
1148 | "(tensor([[ 0.0345, 0.5159],\n",
1149 | " [-0.0625, 0.4305],\n",
1150 | " [ 0.1009, 0.4935],\n",
1151 | " [-0.1560, 0.3494],\n",
1152 | " [ 0.0152, 0.4258],\n",
1153 | " [-0.1158, 0.6765],\n",
1154 | " [-0.0857, 0.5687],\n",
1155 | " [-0.0879, 0.5175]], grad_fn=),\n",
1156 | " tensor([[-0.1062, 0.1199],\n",
1157 | " [-0.1626, 0.1613],\n",
1158 | " [-0.1380, 0.1712],\n",
1159 | " [-0.1267, 0.0550],\n",
1160 | " [-0.0793, -0.0728],\n",
1161 | " [-0.3368, 0.1656],\n",
1162 | " [-0.2284, 0.1212],\n",
1163 | " [-0.0639, 0.1775]], grad_fn=))"
1164 | ]
1165 | },
1166 | "execution_count": 120,
1167 | "metadata": {},
1168 | "output_type": "execute_result"
1169 | }
1170 | ],
1171 | "source": [
1172 | "my_outputs.logits['easy'] * weights[:,0].view(-1,1), my_outputs.logits['hard'] * weights[:,1].view(-1,1)"
1173 | ]
1174 | },
1175 | {
1176 | "cell_type": "code",
1177 | "execution_count": 1,
1178 | "metadata": {},
1179 | "outputs": [
1180 | {
1181 | "name": "stderr",
1182 | "output_type": "stream",
1183 | "text": [
1184 | "Reusing dataset snli (/home/v-biyangguo/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)\n"
1185 | ]
1186 | },
1187 | {
1188 | "data": {
1189 | "application/vnd.jupyter.widget-view+json": {
1190 | "model_id": "9fafeef372dd4a3d99f7233786ee4508",
1191 | "version_major": 2,
1192 | "version_minor": 0
1193 | },
1194 | "text/plain": [
1195 | " 0%| | 0/3 [00:00, ?it/s]"
1196 | ]
1197 | },
1198 | "metadata": {},
1199 | "output_type": "display_data"
1200 | },
1201 | {
1202 | "data": {
1203 | "text/plain": [
1204 | "DatasetDict({\n",
1205 | " test: Dataset({\n",
1206 | " features: ['premise', 'hypothesis', 'label'],\n",
1207 | " num_rows: 10000\n",
1208 | " })\n",
1209 | " train: Dataset({\n",
1210 | " features: ['premise', 'hypothesis', 'label'],\n",
1211 | " num_rows: 550152\n",
1212 | " })\n",
1213 | " validation: Dataset({\n",
1214 | " features: ['premise', 'hypothesis', 'label'],\n",
1215 | " num_rows: 10000\n",
1216 | " })\n",
1217 | "})"
1218 | ]
1219 | },
1220 | "execution_count": 1,
1221 | "metadata": {},
1222 | "output_type": "execute_result"
1223 | }
1224 | ],
1225 | "source": [
1226 | "from datasets import load_dataset\n",
1227 | "snli_data = load_dataset('snli')\n",
1228 | "snli_data"
1229 | ]
1230 | },
1231 | {
1232 | "cell_type": "code",
1233 | "execution_count": 8,
1234 | "metadata": {},
1235 | "outputs": [
1236 | {
1237 | "data": {
1238 | "text/plain": [
1239 | "Counter({1: 3219, 0: 3368, 2: 3237, -1: 176})"
1240 | ]
1241 | },
1242 | "execution_count": 8,
1243 | "metadata": {},
1244 | "output_type": "execute_result"
1245 | }
1246 | ],
1247 | "source": [
1248 | "from collections import Counter\n",
1249 | "c = Counter(snli_data['test']['label'])\n",
1250 | "c"
1251 | ]
1252 | },
1253 | {
1254 | "cell_type": "code",
1255 | "execution_count": 6,
1256 | "metadata": {},
1257 | "outputs": [
1258 | {
1259 | "name": "stderr",
1260 | "output_type": "stream",
1261 | "text": [
1262 | "Loading cached processed dataset at /home/v-biyangguo/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bd229a5e884be60a.arrow\n"
1263 | ]
1264 | },
1265 | {
1266 | "data": {
1267 | "text/plain": [
1268 | "Dataset({\n",
1269 | " features: ['premise', 'hypothesis', 'label'],\n",
1270 | " num_rows: 549367\n",
1271 | "})"
1272 | ]
1273 | },
1274 | "execution_count": 6,
1275 | "metadata": {},
1276 | "output_type": "execute_result"
1277 | }
1278 | ],
1279 | "source": [
1280 | "snli_data['train'] = snli_data['train'].filter(lambda x:x['label']!=-1)\n",
1281 | "snli_data['train']"
1282 | ]
1283 | },
1284 | {
1285 | "cell_type": "code",
1286 | "execution_count": 5,
1287 | "metadata": {},
1288 | "outputs": [
1289 | {
1290 | "data": {
1291 | "text/plain": [
1292 | "549367"
1293 | ]
1294 | },
1295 | "execution_count": 5,
1296 | "metadata": {},
1297 | "output_type": "execute_result"
1298 | }
1299 | ],
1300 | "source": [
1301 | "550152-785"
1302 | ]
1303 | },
1304 | {
1305 | "cell_type": "code",
1306 | "execution_count": null,
1307 | "metadata": {},
1308 | "outputs": [],
1309 | "source": []
1310 | }
1311 | ],
1312 | "metadata": {
1313 | "interpreter": {
1314 | "hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f"
1315 | },
1316 | "kernelspec": {
1317 | "display_name": "Python 3.6.10 64-bit ('conda': virtualenv)",
1318 | "name": "python3"
1319 | },
1320 | "language_info": {
1321 | "codemirror_mode": {
1322 | "name": "ipython",
1323 | "version": 3
1324 | },
1325 | "file_extension": ".py",
1326 | "mimetype": "text/x-python",
1327 | "name": "python",
1328 | "nbconvert_exporter": "python",
1329 | "pygments_lexer": "ipython3",
1330 | "version": "3.6.10"
1331 | },
1332 | "orig_nbformat": 4
1333 | },
1334 | "nbformat": 4,
1335 | "nbformat_minor": 2
1336 | }
--------------------------------------------------------------------------------
/HCT/run_glue_hct.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Finetuning a 🤗 Transformers model for sequence classification on GLUE."""
16 | import argparse
17 | from collections import defaultdict
18 | import json
19 | import logging
20 | import math
21 | import os
22 | import random
23 | from pathlib import Path
24 |
25 | import datasets
26 | import torch
27 | from torch import nn
28 | import torch.nn.functional as F
29 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30 |
31 | from datasets import load_dataset, load_metric, load_from_disk
32 | from torch.utils.data import DataLoader
33 | from tqdm.auto import tqdm
34 |
35 | import transformers
36 | from accelerate import Accelerator
37 | from accelerate.logging import get_logger
38 | from accelerate.utils import set_seed
39 | from huggingface_hub import Repository
40 | from transformers import (
41 | AutoConfig,
42 | AutoModel,
43 | AutoModelForSequenceClassification,
44 | AutoTokenizer,
45 | DataCollatorWithPadding,
46 | PretrainedConfig,
47 | SchedulerType,
48 | default_data_collator,
49 | get_scheduler,
50 | )
51 | from transformers.modeling_outputs import SequenceClassifierOutput
52 |
53 | # from transformers.utils import get_full_repo_name, send_example_telemetry
54 | from transformers.utils.versions import require_version
55 |
56 |
# Module-level Accelerate logger; its calls accept `main_process_only` (e.g. in `main`).
logger = get_logger(__name__)

# Require datasets>=1.8.0 (transformers' `require_version` raises otherwise).
# NOTE(review): the install hint points at the upstream transformers example path,
# not this repository's requirements.txt.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
60 |
# Input text column(s) for every supported task name:
#   task -> (first_text_column, second_text_column, or None for single-sentence tasks).
# Noisy variants such as "mrpc-noisy" are not listed here; `main` keeps only the
# part of the task name before the first "-" before looking it up.
task_to_keys = dict(
    # GLUE
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
    # SNLI
    snli=("premise", "hypothesis"),
    # SuperGLUE
    boolq=("question", "passage"),
    cb=("premise", "hypothesis"),
)
79 |
80 |
def parse_args():
    """Parse and validate the command-line arguments for HCT fine-tuning.

    Besides the generic fine-tuning options (data, model, optimisation,
    checkpointing, tracking), this defines the HCT-specific options
    (`--temperature`, `--mu`, `--hard_inference`, `--hard_with_ls`,
    `--ls_weight`, `--more_ambiguous`), the data-selection options
    (`--with_data_selection`, `--data_selection_region`) and `--do_recording`.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        ValueError: if neither a task name nor a train/validation file is given,
            if a data file does not end in ``.csv``/``.json``, or if
            ``--with_data_selection`` is passed without ``--data_selection_region``.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
    parser.add_argument(
        "--task_name",
        type=str,
        default=None,
        help="The name of the glue task to train on.",
        # No `choices`: noisy variants such as "mrpc-noisy" are accepted as well.
    )
    parser.add_argument(
        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
    )
    parser.add_argument(
        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=128,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
            " sequences shorter will be padded if `--pad_to_max_length` is passed."
        ),
    )
    parser.add_argument(
        "--pad_to_max_length",
        action="store_true",
        help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
            " Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument(
        "--ignore_mismatched_sizes",
        action="store_true",
        help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
    )
    # --- training-dynamics recording / data-cartography selection ---
    parser.add_argument("--do_recording", action="store_true", help="Whether to record the training dynamics.")
    parser.add_argument("--with_data_selection", action="store_true", help="Use only a selected subset of the training data for model training.")
    parser.add_argument("--data_selection_region", default=None, choices=("easy", "hard", "ambiguous"),
                        help="Three regions from the dataset cartography: easy, hard and ambiguous")
    # --- HCT model options ---
    parser.add_argument("--temperature", type=float, default=1., help="the temperature for the softmax in HCT model")
    parser.add_argument("--mu", type=float, default=0.5, help="weight for the gate loss in HCT model")
    parser.add_argument("--hard_inference", action="store_true", default=False, help="Whether to use hard inference in the HCT model.")
    parser.add_argument("--hard_with_ls", action="store_true", help="if set, use label_smoothing for hard expert")
    parser.add_argument("--ls_weight", type=float, default=0.1, help="weight for label_smoothing")
    parser.add_argument("--more_ambiguous", action="store_true", default=False, help="set to put more weights to ambiguous samples")

    args = parser.parse_args()

    # Sanity checks. Raise (rather than assert) so they survive `python -O`.
    if args.task_name is None and args.train_file is None and args.validation_file is None:
        raise ValueError("Need either a task name or a training/validation file.")
    for flag, path in (("train_file", args.train_file), ("validation_file", args.validation_file)):
        if path is not None:
            extension = path.split(".")[-1]
            if extension not in ("csv", "json"):
                raise ValueError(f"`{flag}` should be a csv or a json file.")
    if args.with_data_selection and args.data_selection_region is None:
        raise ValueError("You must specify `data_selection_region` when using `with_data_selection`.")

    return args
230 |
231 | from typing import List, Optional, Tuple, Union
232 |
233 |
class HCTForSequenceClassification(nn.Module):
    """Two-expert ("easy" / "hard") sequence classifier with a hardness gate.

    Modified from `BertForSequenceClassification`. A pretrained encoder (loaded
    with `AutoModel.from_pretrained`) produces a pooled sentence vector that
    feeds three linear heads:

    * ``classifier_easy`` / ``classifier_hard``: the two expert classifiers;
    * ``hardness_gate``: a 2-way gate (index 0 = easy, index 1 = hard).

    The per-sample losses of the two experts are multiplied by the gate
    probabilities ``softmax(gate_logits / temperature)`` and averaged. In
    training mode, when per-sample ``confidences`` are given, the gate is also
    supervised towards ``[confidence, 1 - confidence]`` with a KL-divergence
    loss, and the total loss is ``mu * gate_loss + (1 - mu) * clf_loss``.

    Args:
        model_name_or_path: name or path of the pretrained encoder.
        config: model config. `num_labels` and `hidden_size` are required; the
            dropout rate is `classifier_dropout` when present and not None,
            otherwise `hidden_dropout_prob`. `problem_type` and
            `use_return_dict` are read (and `problem_type` possibly set) in
            `forward`.
        temperature: softmax temperature for the gate weights (the gate loss
            itself uses the untempered logits).
        mu: weight of the gate loss in the training loss.
        hard_with_ls: if True, the hard expert uses label smoothing
            (single-label classification only).
        ls_weight: label-smoothing amount for the hard expert; None means 0.0.
        more_ambiguous: if True, gate weights ``w > 0.5`` are replaced by
            ``1 - |w - 0.5|`` before weighting the expert losses.
    """

    def __init__(self, model_name_or_path, config, temperature, mu, hard_with_ls=False, ls_weight=None, more_ambiguous=False):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        self.config = config
        self.num_labels = config.num_labels
        # Fall back to `hidden_dropout_prob` when the config lacks (or has a None) `classifier_dropout`.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        self.classifier_dropout = (
            classifier_dropout if classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(self.classifier_dropout)
        self.classifier_easy = nn.Linear(config.hidden_size, config.num_labels)
        self.classifier_hard = nn.Linear(config.hidden_size, config.num_labels)
        # 2 experts; gate output index 0 = easy, index 1 = hard.
        self.hardness_gate = nn.Linear(config.hidden_size, 2)

        self.T = temperature
        self.mu = mu
        self.hard_with_ls = hard_with_ls
        self.ls_weight = ls_weight
        self.more_ambiguous = more_ambiguous

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        # the confidence value of a sample
        confidences: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""Run the encoder, both experts and the gate; compute the loss when labels are given.

        Args:
            labels: targets for the loss. Indices in ``[0, ..., config.num_labels - 1]``
                for single-label classification; if ``config.num_labels == 1`` a
                regression (MSE) loss is used; non-integer labels with
                ``num_labels > 1`` select a BCE-with-logits (multi-label) loss.
            confidences: per-sample values used as the target probability of the
                easy expert in the gate loss (training mode only); assumed to lie
                in [0, 1]. If None, no gate loss is added.
            All other arguments are forwarded to the encoder unchanged.

        Returns:
            If ``return_dict`` is falsy, the tuple
            ``(loss, logits_gate, logits_easy, logits_hard) + outputs[2:]``
            (``loss`` omitted when no labels are given). Otherwise a
            `SequenceClassifierOutput` whose ``logits`` is the dict
            ``{"gate": ..., "easy": ..., "hard": ...}``.

        Raises:
            ValueError: if ``config.problem_type`` is not a supported value.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Second encoder output: the pooled [CLS] vector (pooler output for BERT-like encoders).
        pooled_output = outputs[1]
        # The gate sees the pooled vector before dropout; the experts see it after dropout.
        logits_gate = self.hardness_gate(pooled_output)
        gate_weights = F.softmax(logits_gate / self.T, dim=1)
        pooled_output = self.dropout(pooled_output)
        logits_easy = self.classifier_easy(pooled_output)
        logits_hard = self.classifier_hard(pooled_output)

        loss = None
        if labels is not None:
            self._resolve_problem_type(labels)
            loss_easy, loss_hard = self._expert_losses(logits_easy, logits_hard, labels)

            if self.more_ambiguous:
                # w > 0.5 becomes 1 - |w - 0.5| (= 1.5 - w), so weights just above 0.5 approach 1.
                gate_weights = torch.where(gate_weights > 0.5, 1 - torch.abs(gate_weights - 0.5), gate_weights)

            # Gate-weighted expert losses, shape (batch, 2), averaged over all entries.
            # NOTE: the gate weights are not renormalised across the batch. An
            # "easy expert only" variant (loss_easy weighted by batch-normalised
            # easy gate weights) was tried; the author notes it may help on noisy
            # data but hurt on clean data.
            weighted_loss = torch.stack([loss_easy, loss_hard], dim=1) * gate_weights
            clf_loss = torch.mean(weighted_loss)

            # The gate loss only exists in training mode.
            gate_loss = self._gate_loss(logits_gate, confidences) if self.training else None
            if gate_loss is not None:
                loss = self.mu * gate_loss + (1 - self.mu) * clf_loss
            else:
                loss = clf_loss

        if not return_dict:
            output = (logits_gate, logits_easy, logits_hard) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits={"gate": logits_gate, "easy": logits_easy, "hard": logits_hard},
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _resolve_problem_type(self, labels):
        """Set `config.problem_type` from `num_labels` and the label dtype if it is unset (as BERT does)."""
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

    def _expert_losses(self, logits_easy, logits_hard, labels):
        """Return the per-sample losses of the easy and hard experts, each of shape (batch,).

        Losses are computed with ``reduction='none'``; element-wise losses with
        extra dimensions (multi-label BCE, multi-output regression) are averaged
        over those dimensions so the gate can weight them per sample.
        """
        batch_size = logits_easy.size(0)
        problem_type = self.config.problem_type
        if problem_type == "regression":
            loss_fct = MSELoss(reduction="none")
            if self.num_labels == 1:
                loss_easy = loss_fct(logits_easy.squeeze(), labels.squeeze())
                loss_hard = loss_fct(logits_hard.squeeze(), labels.squeeze())
            else:
                loss_easy = loss_fct(logits_easy, labels)
                loss_hard = loss_fct(logits_hard, labels)
        elif problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss(reduction="none")
            flat_labels = labels.view(-1)
            loss_easy = loss_fct(logits_easy.view(-1, self.num_labels), flat_labels)
            if self.hard_with_ls:
                # Hard expert: cross-entropy with label smoothing (the `label_smoothing`
                # argument requires torch >= 1.10).
                smoothing = self.ls_weight if self.ls_weight is not None else 0.0
                ls_loss_fct = CrossEntropyLoss(reduction="none", label_smoothing=smoothing)
                loss_hard = ls_loss_fct(logits_hard.view(-1, self.num_labels), flat_labels)
            else:
                loss_hard = loss_fct(logits_hard.view(-1, self.num_labels), flat_labels)
        elif problem_type == "multi_label_classification":
            loss_fct = BCEWithLogitsLoss(reduction="none")
            loss_easy = loss_fct(logits_easy, labels)
            loss_hard = loss_fct(logits_hard, labels)
        else:
            raise ValueError(f"Unsupported problem_type: {problem_type!r}")
        return (
            loss_easy.reshape(batch_size, -1).mean(dim=1),
            loss_hard.reshape(batch_size, -1).mean(dim=1),
        )

    def _gate_loss(self, logits_gate, confidences):
        """KL(target || gate) with per-sample target ``[confidence, 1 - confidence]``; None without confidences."""
        if confidences is None:
            return None
        # The confidence is the target probability of the easy expert; the hard one is its complement.
        easy_probs = confidences.view(-1, 1).to(logits_gate.dtype)
        hardness_probs = torch.cat([easy_probs, 1 - easy_probs], dim=1)
        return F.kl_div(F.log_softmax(logits_gate, dim=-1), hardness_probs, reduction="batchmean")
386 |
387 |
388 |
389 | def main():
390 | args = parse_args()
391 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
392 | # information sent is the one passed as arguments along with your Python/PyTorch versions.
393 | # send_example_telemetry("run_glue_no_trainer", args)
394 |
395 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
396 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
397 | # in the environment
398 | accelerator = (
399 | Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
400 | )
401 | # Make one log on every process with the configuration for debugging.
402 | logging.basicConfig(
403 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
404 | datefmt="%m/%d/%Y %H:%M:%S",
405 | level=logging.INFO,
406 | )
407 | logger.info(accelerator.state, main_process_only=False)
408 | if accelerator.is_local_main_process:
409 | datasets.utils.logging.set_verbosity_warning()
410 | transformers.utils.logging.set_verbosity_info()
411 | else:
412 | datasets.utils.logging.set_verbosity_error()
413 | transformers.utils.logging.set_verbosity_error()
414 |
415 | # If passed along, set the training seed now.
416 | if args.seed is not None:
417 | set_seed(args.seed)
418 |
419 |
420 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
421 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
422 |
423 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
424 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
425 | # label if at least two columns are provided.
426 |
427 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
428 | # single column. You can easily tweak this behavior (see below)
429 |
430 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
431 | # download the dataset.
432 | if args.task_name is not None:
433 | # Downloading and loading a dataset from the hub.
434 | # raw_datasets = load_dataset("glue", args.task_name)
435 | # load processed GLUE dataset that contains 'confidence' value
436 | raw_datasets = load_from_disk(f"../datasets/{args.task_name}/with_conf/")
437 | assert 'confidence' in raw_datasets['train'].column_names, "the Train set must contain a 'confidence' column!!"
438 | print("Yes! I have loaded it! ----------")
439 |
440 |
441 | else:
442 | # Loading the dataset from local csv or json file.
443 | data_files = {}
444 | if args.train_file is not None:
445 | data_files["train"] = args.train_file
446 | if args.validation_file is not None:
447 | data_files["validation"] = args.validation_file
448 | extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
449 | raw_datasets = load_dataset(extension, data_files=data_files)
450 | # See more about loading any type of standard or custom dataset at
451 | # https://huggingface.co/docs/datasets/loading_datasets.html.
452 |
453 | # Labels
454 | if args.task_name is not None:
455 | is_regression = args.task_name == "stsb"
456 | if not is_regression:
457 | label_list = raw_datasets["validation"].features["label"].names
458 | num_labels = len(label_list)
459 | else:
460 | num_labels = 1
461 | else:
462 | # Trying to have good defaults here, don't hesitate to tweak to your needs.
463 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
464 | if is_regression:
465 | num_labels = 1
466 | else:
467 | # A useful fast method:
468 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
469 | label_list = raw_datasets["validation"].unique("label")
470 | label_list.sort() # Let's sort it for determinism
471 | num_labels = len(label_list)
472 |
473 | # --------------------------- Data Selection: -----------------------------------
474 | # data selection is ONLY applied on train set
475 | if args.with_data_selection:
476 | assert args.data_selection_region is not None, "You much specify `data_selection_region` when using `with_data_selection`"
477 | model_name = args.model_name_or_path
478 | if '/' in model_name:
479 | model_name = model_name.split('/')[-1]
480 | assert os.path.exists(f'dy_log/{args.task_name}/{model_name}/three_regions_data_indices.json'), "Selection indices file not found!"
481 | with open(f'dy_log/{args.task_name}/{model_name}/three_regions_data_indices.json','r') as f:
482 | three_regions_data_indices = json.loads(f.read())
483 | selected_indices = three_regions_data_indices[args.data_selection_region]
484 | raw_datasets['train'] = raw_datasets['train'].select(selected_indices)
485 |
486 | logger.info("~~~~~ Applying Data Selection ~~~~~ ")
487 | logger.info(f"~~~~~ Region: {args.data_selection_region} ")
488 | logger.info(f"~~~~~ Size: {len(raw_datasets['train'])} ")
489 |
490 | # ----------------------------------------------------------------------------------------------------
491 | # with open(f'dy_log/sst2/distilbert-base-cased/three_regions_data_indices.json','r') as f:
492 |
493 | # Load pretrained model and tokenizer
494 | #
495 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
496 | # download model & vocab.
497 | config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
498 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
499 | # model = AutoModelForSequenceClassification.from_pretrained(
500 | # args.model_name_or_path,
501 | # from_tf=bool(".ckpt" in args.model_name_or_path),
502 | # config=config,
503 | # ignore_mismatched_sizes=args.ignore_mismatched_sizes,
504 | # )
505 | hct_model = HCTForSequenceClassification(args.model_name_or_path, config, temperature=args.temperature, mu=args.mu, hard_with_ls=args.hard_with_ls, ls_weight=args.ls_weight, more_ambiguous=args.more_ambiguous)
506 |
507 |
508 |
509 | # 对非HF官方模型的名称的处理,只保留模型名
510 | if '/' in args.model_name_or_path:
511 | args.model_name_or_path = args.model_name_or_path.split('/')[-1]
512 |
513 | # Preprocessing the datasets
514 | # --------------- GLUE tasks (and some SuperGLUE tasks, like boolq/cb)---------------
515 | if args.task_name is not None:
516 | if 'noisy' in args.task_name:
517 | task_name = args.task_name.split('-')[0]
518 | sentence1_key, sentence2_key = task_to_keys[task_name]
519 | else:
520 | sentence1_key, sentence2_key = task_to_keys[args.task_name]
521 | # --------------- Other tasks ---------------
522 | else:
523 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
524 | # 这里的逻辑是这样的:
525 | # 对于非glue的数据集,要求要包含`label`字段
526 | # 然后希望你有`sentence1`, `sentence2`这两个字段,这样就跟glue对齐了
527 | # 如果你也不是用的这个名字,那就选择非label列的前两个字段来分别作为sentence1和sentence2
528 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
529 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
530 | sentence1_key, sentence2_key = "sentence1", "sentence2"
531 | elif "sentence" in non_label_column_names: # for classical classification tasks, like sst2
532 | sentence1_key, sentence2_key = ("sentence", None)
533 | elif "question" in non_label_column_names and "sentence" in non_label_column_names: # for tasks like qnli
534 | sentence1_key, sentence2_key = ("question", "sentence")
535 | else:
536 | if len(non_label_column_names) >= 2:
537 | sentence1_key, sentence2_key = non_label_column_names[:2]
538 | else:
539 | sentence1_key, sentence2_key = non_label_column_names[0], None
540 |
541 | # Some models have set the order of the labels to use, so let's make sure we do use it.
542 | label_to_id = None
543 | # if (
544 | # model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
545 | # and args.task_name is not None
546 | # and not is_regression
547 | # ):
548 | # # Some have all caps in their config, some don't.
549 | # label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
550 | # if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
551 | # logger.info(
552 | # f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
553 | # "Using it!"
554 | # )
555 | # label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
556 | # else:
557 | # logger.warning(
558 | # "Your model seems to have been trained with labels, but they don't match the dataset: ",
559 | # f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
560 | # "\nIgnoring the model labels as a result.",
561 | # )
562 | # elif args.task_name is None and not is_regression:
563 | # label_to_id = {v: i for i, v in enumerate(label_list)}
564 |
565 | # if label_to_id is not None:
566 | # model.config.label2id = label_to_id
567 | # model.config.id2label = {id: label for label, id in config.label2id.items()}
568 | # elif args.task_name is not None and not is_regression:
569 | # model.config.label2id = {l: i for i, l in enumerate(label_list)}
570 | # model.config.id2label = {id: label for label, id in config.label2id.items()}
571 |
572 | padding = "max_length" if args.pad_to_max_length else False
573 |
574 | def preprocess_function(examples):
575 | # Tokenize the texts
576 | texts = (
577 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
578 | )
579 | result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True)
580 |
581 | if "label" in examples:
582 | if label_to_id is not None:
583 | # Map labels to IDs (not necessary for GLUE tasks)
584 | result["labels"] = [label_to_id[l] for l in examples["label"]]
585 | else:
586 | # In all cases, rename the column to labels because the model will expect that.
587 | result["labels"] = examples["label"]
588 |
589 | if "confidence" in examples:
590 | result["confidences"] = examples["confidence"]
591 | return result
592 |
593 | with accelerator.main_process_first():
594 | processed_datasets = raw_datasets.map(
595 | preprocess_function,
596 | batched=True,
597 | # 得把这行改掉:
598 | # 以SST2为例,这里会把 ['sentence', 'label', 'idx'] 给去掉(不用担心label,因为上面已经新建了一个labels列)
599 | # remove_columns=raw_datasets["train"].column_names,
600 | # 改为:
601 | # 保留 idx 和 confidences,其他的可以去掉 (confidences 而不是 confidence,前者是模型要接收的名字)
602 | # remove_columns=[c for c in raw_datasets["train"].column_names if c not in ['idx', 'confidences']],
603 | desc="Running tokenizer on dataset",
604 | )
605 |
606 | train_dataset = processed_datasets["train"].remove_columns([c for c in raw_datasets["train"].column_names if c not in ['idx', 'confidences']])
607 | if args.task_name == 'mnli':
608 | eval_dataset = processed_datasets["validation_matched"].remove_columns([c for c in raw_datasets["validation_matched"].column_names if c not in ['idx']])
609 | else:
610 | eval_dataset = processed_datasets["validation"].remove_columns([c for c in raw_datasets["validation"].column_names if c not in ['idx']])
611 |
612 |
613 | # Log a few random samples from the training set:
614 | for index in random.sample(range(len(train_dataset)), 3):
615 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
616 |
617 | # DataLoaders creation:
618 | if args.pad_to_max_length:
619 | # If padding was already done ot max length, we use the default data collator that will just convert everything
620 | # to tensors.
621 | data_collator = default_data_collator
622 | else:
623 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
624 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
625 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
626 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None))
627 |
628 | train_dataloader = DataLoader(
629 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
630 | )
631 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
632 |
633 | # Optimizer
634 | # Split weights in two groups, one with weight decay and the other not.
635 | no_decay = ["bias", "LayerNorm.weight"]
636 | optimizer_grouped_parameters = [
637 | {
638 | "params": [p for n, p in hct_model.named_parameters() if not any(nd in n for nd in no_decay)],
639 | "weight_decay": args.weight_decay,
640 | },
641 | {
642 | "params": [p for n, p in hct_model.named_parameters() if any(nd in n for nd in no_decay)],
643 | "weight_decay": 0.0,
644 | },
645 | ]
646 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
647 |
648 | # Scheduler and math around the number of training steps.
649 | overrode_max_train_steps = False
650 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
651 | if args.max_train_steps is None:
652 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
653 | overrode_max_train_steps = True
654 |
655 | lr_scheduler = get_scheduler(
656 | name=args.lr_scheduler_type,
657 | optimizer=optimizer,
658 | num_warmup_steps=args.num_warmup_steps,
659 | num_training_steps=args.max_train_steps,
660 | )
661 |
662 | # Prepare everything with our `accelerator`.
663 | hct_model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
664 | hct_model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
665 | )
666 |
667 | # We need to recalculate our total training steps as the size of the training dataloader may have changed
668 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
669 | if overrode_max_train_steps:
670 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
671 | # Afterwards we recalculate our number of training epochs
672 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
673 |
674 | # Figure out how many steps we should save the Accelerator states
675 | if hasattr(args.checkpointing_steps, "isdigit"):
676 | checkpointing_steps = args.checkpointing_steps
677 | if args.checkpointing_steps.isdigit():
678 | checkpointing_steps = int(args.checkpointing_steps)
679 | else:
680 | checkpointing_steps = None
681 |
682 | # We need to initialize the trackers we use, and also store our configuration.
683 | # We initialize the trackers only on main process because `accelerator.log`
684 | # only logs on main process and we don't want empty logs/runs on other processes.
685 | if args.with_tracking:
686 | if accelerator.is_main_process:
687 | experiment_config = vars(args)
688 | # TensorBoard cannot log Enums, need the raw value
689 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
690 | accelerator.init_trackers("glue_no_trainer", experiment_config)
691 |
692 | # Get the metric function
693 | if args.task_name is not None:
694 | if args.task_name == 'snli':
695 | metric = load_metric("glue", 'mnli')
696 | elif args.task_name in ['boolq','cb','axb','axg']:
697 | metric = load_metric("super_glue" ,args.task_name)
698 | elif 'noisy' in args.task_name:
699 | task_name = args.task_name.split('-')[0]
700 | metric = load_metric("glue", task_name)
701 | else:
702 | metric = load_metric("glue", args.task_name)
703 | else:
704 | metric = load_metric("accuracy")
705 |
706 | # Train!
707 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
708 |
709 | logger.info("***** Running training *****")
710 | logger.info(f" Num examples = {len(train_dataset)}")
711 | logger.info(f" Num Epochs = {args.num_train_epochs}")
712 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
713 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
714 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
715 | logger.info(f" Total optimization steps = {args.max_train_steps}")
716 | # Only show the progress bar once on each machine.
717 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
718 | completed_steps = 0
719 | starting_epoch = 0
720 | # Potentially load in the weights and states from a previous save
721 | if args.resume_from_checkpoint:
722 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
723 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
724 | accelerator.load_state(args.resume_from_checkpoint)
725 | path = os.path.basename(args.resume_from_checkpoint)
726 | else:
727 | # Get the most recent checkpoint
728 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
729 | dirs.sort(key=os.path.getctime)
730 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
731 | # Extract `epoch_{i}` or `step_{i}`
732 | training_difference = os.path.splitext(path)[0]
733 |
734 | if "epoch" in training_difference:
735 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
736 | resume_step = None
737 | else:
738 | resume_step = int(training_difference.replace("step_", ""))
739 | starting_epoch = resume_step // len(train_dataloader)
740 | resume_step -= starting_epoch * len(train_dataloader)
741 |
742 | # ============================ Training Loop ============================
743 | for epoch in range(starting_epoch, args.num_train_epochs):
744 |
745 | hct_model.train()
746 | if args.with_tracking:
747 | total_loss = 0
748 | for step, batch in enumerate(train_dataloader):
749 | # We need to skip steps until we reach the resumed step
750 | if args.resume_from_checkpoint and epoch == starting_epoch:
751 | if resume_step is not None and step < resume_step:
752 | completed_steps += 1
753 | continue
754 | # batch中包含了idx字段,这里需要去除
755 | batch = {k:v for k,v in batch.items() if k != 'idx'}
756 | outputs = hct_model(**batch)
757 | loss = outputs.loss
758 | # We keep track of the loss at each epoch
759 | if args.with_tracking:
760 | total_loss += loss.detach().float()
761 | loss = loss / args.gradient_accumulation_steps
762 | accelerator.backward(loss)
763 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
764 | optimizer.step()
765 | lr_scheduler.step()
766 | optimizer.zero_grad()
767 | progress_bar.update(1)
768 | completed_steps += 1
769 |
770 | if isinstance(checkpointing_steps, int):
771 | if completed_steps % checkpointing_steps == 0:
772 | output_dir = f"step_{completed_steps }"
773 | if args.output_dir is not None:
774 | output_dir = os.path.join(args.output_dir, output_dir)
775 | accelerator.save_state(output_dir)
776 |
777 | if completed_steps >= args.max_train_steps:
778 | break
779 | # ------------------ Recording Training Dynamics --------------------
780 | # 在每一个epoch之后,对train set所有样本再过一遍,记录dynamics
781 | # 每个epoch单独一个文件
782 | # if args.do_recording:
783 | # if accelerator.is_main_process:
784 | # if not os.path.exists(f'dy_log/{args.task_name}/'):
785 | # os.mkdir(f'dy_log/{args.task_name}/')
786 | # if not os.path.exists(f'dy_log/{args.task_name}/{args.model_name_or_path}'):
787 | # os.mkdir(f'dy_log/{args.task_name}/{args.model_name_or_path}')
788 | # log_path = f'dy_log/{args.task_name}/{args.model_name_or_path}/training_dynamics/'
789 | # if not os.path.exists(log_path):
790 | # os.mkdir(log_path)
791 |
792 | # accelerator.wait_for_everyone() # 只在 main process 里面创建文件夹,然后让其他 process 等待 main process 创建完毕
793 | # log_path = f'dy_log/{args.task_name}/{args.model_name_or_path}/training_dynamics/'
794 | # print('-*-*-*- ',log_path, os.path.exists(log_path),accelerator.device)
795 |
796 | # logger.info('---------- Recording Training Dynamics (Epoch %s) -----------'%epoch)
797 | # training_dynamics = []
798 | # all_ids = []
799 | # for step, batch in enumerate(tqdm(train_dataloader)):
800 | # # print('- - - - - - - - - - ',len(batch['idx']), accelerator.device)
801 | # idx_list = batch['idx']#.tolist()
802 | # label_list = batch['labels']#.tolist()
803 | # batch = {k:v for k,v in batch.items() if k != 'idx'}
804 | # logits_list = model(**batch).logits#.tolist() # [[],[],[],...] batch_size个[]
805 | # # 这里的关键:通过 gather 把每个 GPU上的结果合并
806 | # # 由于在使用多卡训练时,不同卡可能存在样本的重复,同一个卡也会对最后一个batch进行补齐,也会样本重复
807 | # # 使用 gather 的话,就可以按照原来的分配方式,逆着组合回去,就不用你自己处理了
808 | # # gather 之后的,在每个卡上,下述变量里包含的数量,都等同于只使用单卡进行训练时的数量
809 | # # 所以下面的for训练执行完之后,training_dynamics里就包含了全部样本,你在写入文件时,记住只在一个 process 中写入
810 | # idx_list, label_list, logits_list = accelerator.gather((idx_list, label_list, logits_list))
811 | # # print('idx_list', idx_list.shape, accelerator.device)
812 | # # print('label_list', label_list.shape, accelerator.device)
813 |
814 | # for idx, label, logits in zip(idx_list.tolist(), label_list.tolist(), logits_list.tolist()):
815 | # if idx in all_ids: # 由于 data_loader 可能会对最后一个 batch 进行补全,所以这里要去掉重复的样本
816 | # continue
817 | # all_ids.append(idx)
818 | # record = {'guid': idx, 'logits_epoch_%s'%epoch: logits, 'gold': label, 'device':str(accelerator.device)}
819 | # training_dynamics.append(record)
820 |
821 | # if accelerator.is_main_process:
822 | # print('---- Num of training_dynamics: ',len(training_dynamics),' Device: ', str(accelerator.device))
823 | # print(len(all_ids),len(list(set(all_ids))),str(accelerator.device))
824 | # assert os.path.exists(log_path),log_path
825 | # writer = open(log_path + f'dynamics_epoch_{epoch}.jsonl', 'w')
826 | # for record in training_dynamics:
827 | # writer.write(json.dumps(record) + "\n")
828 | # logger.info(f'Epoch {epoch} Saved to [{log_path}]')
829 | # writer.close()
830 | # accelerator.wait_for_everyone()
831 | # ------------------------------------------------------------------------
832 |
def hct_batch_inference(model, batch, hard_choice=True, T=1, single_expert=None, more_ambiguous=None):
    """Run one no-grad forward pass of the HCT model and return per-sample logits.

    ``model(**batch).logits`` is indexed with the keys ``'gate'``, ``'easy'``
    and ``'hard'``. Gate column 0 corresponds to the 'easy' expert and column 1
    to the 'hard' expert. This assumes a two-column gate; TODO confirm against
    the model definition.

    Args:
        model: HCT model, called as ``model(**batch)``. It is switched to eval
            mode as a side effect and is NOT switched back to train mode.
        batch: keyword inputs for the model (the caller removes ``idx``).
        hard_choice: if True, each sample takes the logits of the expert whose
            gate logit is largest. If False, the two experts' logits are mixed
            with ``softmax(gate / T)`` weights.
        T: softmax temperature for soft mixing; unused when ``hard_choice``.
        single_expert: ``'easy'`` or ``'hard'`` to bypass the gate and return
            that expert's logits unchanged. A falsy value (default ``None``)
            uses the gate.
        more_ambiguous: whether to apply the experimental gate-weight remapping
            in soft mode. ``None`` (default) falls back to the enclosing
            scope's ``args.more_ambiguous``, which keeps the previous behaviour.

    Returns:
        Tensor of combined logits with the same shape as each expert's logits.

    Raises:
        ValueError: if ``single_expert`` is truthy but not 'easy' or 'hard'.
    """
    model.eval()
    with torch.no_grad():
        logits = model(**batch).logits

    if single_expert:
        # Bypass the gate entirely and return one expert's logits.
        if single_expert not in ('easy', 'hard'):
            raise ValueError(f"single_expert must be 'easy' or 'hard', got {single_expert!r}")
        return logits[single_expert]

    easy_logits = logits['easy']
    hard_logits = logits['hard']

    if hard_choice:
        # Select per row instead of multiplying by a 0/1 mask. With a mask,
        # a non-finite value in the unselected expert (inf * 0 = nan) would
        # leak into the chosen logits.
        use_hard = (torch.argmax(logits['gate'], dim=1) == 1).view(-1, 1)
        return torch.where(use_hard, hard_logits, easy_logits)

    weights = F.softmax(logits['gate'] / T, dim=1)
    if more_ambiguous is None:
        # Legacy behaviour: read the flag from the enclosing script's args.
        more_ambiguous = args.more_ambiguous
    if more_ambiguous:
        # Experimental remapping: any weight w > 0.5 becomes 1 - (w - 0.5) = 1.5 - w;
        # weights <= 0.5 are unchanged.
        # NOTE(review): the remapped pair generally no longer sums to 1
        # (e.g. 0.6/0.4 -> 0.9/0.4). Confirm this is intended.
        weights = torch.where(weights > 0.5, 1 - torch.abs(weights - 0.5), weights)
    return easy_logits * weights[:, 0].view(-1, 1) + hard_logits * weights[:, 1].view(-1, 1)
854 |
855 | # def hct_batch_inference_single(model, batch, expert='easy'):
856 | # model.eval()
857 | # with torch.no_grad():
858 | # outputs = model(**batch)
859 | # logits = outputs.logits
860 | # return logits[expert]
861 |
862 |
863 | # evaluation (validation set)
864 | # hard inference:
865 | hct_model.eval()
866 | samples_seen = 0
867 | for step, batch in enumerate(eval_dataloader):
868 | batch = {k:v for k,v in batch.items() if k != 'idx'}
869 | logits = hct_batch_inference(hct_model, batch, hard_choice=True, T=args.temperature) # , single_expert='easy'
870 | predictions = logits.argmax(dim=-1) if not is_regression else logits.squeeze()
871 | predictions, references = accelerator.gather((predictions, batch["labels"]))
872 | # If we are in a multiprocess environment, the last batch has duplicates
873 | if accelerator.num_processes > 1:
874 | if step == len(eval_dataloader) - 1:
875 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
876 | references = references[: len(eval_dataloader.dataset) - samples_seen]
877 | else:
878 | samples_seen += references.shape[0]
879 | metric.add_batch(
880 | predictions=predictions,
881 | references=references,
882 | )
883 | eval_metric = metric.compute()
884 | logger.info(f"***Evaluation (hard inference) *** epoch {epoch}: {eval_metric}")
885 |
886 | # soft inference:
887 | hct_model.eval()
888 | samples_seen = 0
889 | for step, batch in enumerate(eval_dataloader):
890 | batch = {k:v for k,v in batch.items() if k != 'idx'}
891 | logits = hct_batch_inference(hct_model, batch, hard_choice=False, T=args.temperature) # , single_expert='easy'
892 | predictions = logits.argmax(dim=-1) if not is_regression else logits.squeeze()
893 | predictions, references = accelerator.gather((predictions, batch["labels"]))
894 | # If we are in a multiprocess environment, the last batch has duplicates
895 | if accelerator.num_processes > 1:
896 | if step == len(eval_dataloader) - 1:
897 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
898 | references = references[: len(eval_dataloader.dataset) - samples_seen]
899 | else:
900 | samples_seen += references.shape[0]
901 | metric.add_batch(
902 | predictions=predictions,
903 | references=references,
904 | )
905 | eval_metric = metric.compute()
906 | logger.info(f"***Evaluation (soft inference) *** epoch {epoch}: {eval_metric}")
907 |
908 | if args.with_tracking:
909 | accelerator.log(
910 | {
911 | "accuracy" if args.task_name is not None else "glue": eval_metric,
912 | "train_loss": total_loss.item() / len(train_dataloader),
913 | "epoch": epoch,
914 | "step": completed_steps,
915 | },
916 | step=completed_steps,
917 | )
918 |
919 | if args.checkpointing_steps == "epoch":
920 | output_dir = f"epoch_{epoch}"
921 | if args.output_dir is not None:
922 | output_dir = os.path.join(args.output_dir, output_dir)
923 | accelerator.save_state(output_dir)
924 | # ============================ End Training Loop ============================
925 |
926 |
927 | if args.output_dir is not None:
928 | accelerator.wait_for_everyone()
929 | unwrapped_model = accelerator.unwrap_model(hct_model)
930 | unwrapped_model.save_pretrained(
931 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
932 | )
933 | if accelerator.is_main_process:
934 | tokenizer.save_pretrained(args.output_dir)
935 | # if args.push_to_hub:
936 | # repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
937 |
938 | if args.task_name == "mnli":
939 | # Final evaluation on mismatched validation set
940 | # eval_dataset = processed_datasets["validation_mismatched"]
941 | eval_dataset = processed_datasets["validation_mismatched"].remove_columns([c for c in raw_datasets["validation_mismatched"].column_names if c not in ['idx']])
942 | eval_dataloader = DataLoader(
943 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
944 | )
945 | eval_dataloader = accelerator.prepare(eval_dataloader)
946 |
947 | hct_model.eval()
948 | for step, batch in enumerate(eval_dataloader):
949 | batch = {k:v for k,v in batch.items() if k != 'idx'}
950 | # outputs = hct_model(**batch)
951 | logits = hct_batch_inference(hct_model, batch, hard_choice=False, T=args.temperature)
952 | predictions = logits.argmax(dim=-1)
953 | metric.add_batch(
954 | predictions=accelerator.gather(predictions),
955 | references=accelerator.gather(batch["labels"]),
956 | )
957 | eval_metric = metric.compute()
958 | logger.info(f"mnli-mm: {eval_metric}")
959 |
960 | if args.task_name == "snli":
961 | # Final evaluation on mismatched validation set
962 | eval_dataset = processed_datasets["test"].remove_columns([c for c in raw_datasets["test"].column_names if c not in ['idx']])
963 | eval_dataloader = DataLoader(
964 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
965 | )
966 | eval_dataloader = accelerator.prepare(eval_dataloader)
967 |
968 | hct_model.eval()
969 | for step, batch in enumerate(eval_dataloader):
970 | # batch中包含了idx字段,这里需要去除
971 | batch = {k:v for k,v in batch.items() if k != 'idx'}
972 | logits = hct_batch_inference(hct_model, batch, hard_choice=False, T=args.temperature)
973 | predictions = logits.argmax(dim=-1)
974 | metric.add_batch(
975 | predictions=accelerator.gather(predictions),
976 | references=accelerator.gather(batch["labels"]),
977 | )
978 |
979 | eval_metric = metric.compute()
980 | logger.info(f"snli-test: {eval_metric}")
981 |
982 |
983 | if args.output_dir is not None:
984 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
985 | json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
986 |
987 |
# Launch training/evaluation only when this file is executed as a script;
# importing the module must not start a run.
if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------