├── README.md ├── img └── model_architecture.png ├── models ├── bert.py └── meta_learning_via_bert_incontext_tuning.py ├── requirements.txt ├── run.py ├── train ├── train_bert.py └── train_meta_learning_via_bert_incontext_tuning.py └── utils ├── batch_collator.py ├── load_data.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # [Automated Scoring for Reading Comprehension via In-context BERT Tuning](https://arxiv.org/abs/2205.09864) 2 | 3 | We present our (grand prize-winning) solution to the [NAEP Automated Scoring Challenge](https://github.com/NAEP-AS-Challenge/info) for reading comprehension items. We develop a novel automated scoring approach based on meta-learning ideas via in-context tuning of language models. This repository contains the implementation of our best performing model *Meta-trained BERT In-context* and the BERT fine-tuning baseline from our paper [Automated Scoring for Reading Comprehension via In-context BERT Tuning](https://arxiv.org/abs/2205.09864) by [Nigel Fernandez](https://www.linkedin.com/in/ni9elf/), [Aritra Ghosh](https://arghosh.github.io), [Naiming Liu](https://www.linkedin.com/in/naiming-liu-lucy0817/), [Zichao Wang](https://zw16.web.rice.edu), [Benoît Choffin](https://benoitchoffin.github.io/about/), [Richard Baraniuk](https://richb.rice.edu), and [Andrew Lan](https://people.umass.edu/~andrewlan) published at AIED 2022. 4 | 5 |

![Model architecture](img/model_architecture.png) 6 | 7 |
8 | 9 | For any questions please [email](mailto:nigel@cs.umass.edu) or raise an issue. 10 | 11 | If you find our code or paper useful, please consider citing: 12 | ``` 13 | @inproceedings{ fernandez2022AS, 14 | title={ Automated Scoring for Reading Comprehension via In-context {BERT} Tuning }, 15 | author={ Fernandez, Nigel and Ghosh, Aritra and Liu, Naiming and Wang, Zichao and Choffin, Beno{\^{\i}}t and Baraniuk, Richard and Lan, Andrew }, 16 | booktitle={ 23rd International Conference on Artificial Intelligence in Education (AIED 2022) }, 17 | year={ 2022 } 18 | } 19 | ``` 20 | 21 | 22 | 23 | ## Contents 24 | 25 | 1. [Installation](#installation) 26 | 2. [Usage](#usage) 27 | 3. [Dataset](#dataset) 28 | 4. [Code Structure](#code-structure) 29 | 5. [Acknowledgements](#acknowledgements) 30 | 31 | 32 | 33 | ## Installation 34 | 35 | Our code is tested with `Python 3.7.4` in a virtual environment created using: 36 | 37 | ``` 38 | virtualenv -p python3 env 39 | ``` 40 | 41 | Install the dependencies listed in the `requirements.txt` file by running: 42 | ``` 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | 47 | 48 | ## Usage 49 | A [Neptune](https://neptune.ai) account is required to log train-val-test information (loss, metrics, etc.). Best model checkpoints are saved locally. 50 | 51 | To train-val-test the BERT fine-tuning baseline model, please run: 52 | ``` 53 | python run.py\ 54 | --name "lm_base_response_only"\ 55 | --task "[name]"\ 56 | --lm "bert-base-uncased"\ 57 | --batch_size 32\ 58 | --neptune_project "[name]"\ 59 | --data_folder [name]\ 60 | --cross_val_fold 1\ 61 | --cuda\ 62 | --neptune 63 | 64 | ``` 65 | To train-val-test our Meta-trained BERT In-context model, please run: 66 | ``` 67 | python run.py\ 68 | --name "meta_lm_incontext"\ 69 | --task "meta_learning_via_lm_incontext_tuning"\ 70 | --lm "bert-base-uncased"\ 71 | --batch_size 32\ 72 | --neptune_project "[name]"\ 73 | --meta_learning\ 74 | --num_test_avg 8\ 75 | --num_val_avg 8\ 76 | --data_folder [name]\ 77 | --cross_val_fold 1\ 78 | --cuda\ 79 | --neptune 80 | 81 | ``` 82 | 83 | 84 | Argument information: 85 | 86 | `-h, --help` Show this help message and exit 87 | 88 | `--name NAME` Name of the experiment 89 | 90 | `--neptune_project NEPTUNE_PROJECT` Name of the Neptune project 91 | 92 | `--lm LM` Base language model (provide any Hugging Face model name) 93 | 94 | `--task TASK` Item name (not required for meta learning via in-context tuning) 95 | 96 | `--demographic` Use demographic information of student 97 | 98 | `--meta_learning` Enable meta-learning via BERT in-context tuning 99 | 100 | `--num_test_avg NUM_TEST_AVG` Number of different sets of randomly sampled examples per test datapoint to average score predictions 101 | 102 | `--num_val_avg NUM_VAL_AVG` Number of different sets of randomly sampled examples per val datapoint to average score predictions 103 | 104 | `--num_examples NUM_EXAMPLES` Number of in-context examples from each score class to add to input 105 | 106 | `--trunc_len TRUNC_LEN` Max number of words in each in-context example 107 | 108 | `--lr_schedule LR_SCHEDULE` Learning rate schedule to use 109 | 110 | `--opt {sgd,adam,lars}` Optimizer to use 111 | 112 | `--iters ITERS` Number of epochs 113 | 114 | `--lr LR` Base learning rate 115 | 116 | `--batch_size BATCH_SIZE` Batch size 117 | 118 | `--data_folder DATA_FOLDER` Dataset folder name containing train-val-test splits for each cross validation fold 119 | 120 | `--cross_val_fold 
CROSS_VAL_FOLD` Cross validation fold to use 121 | 122 | `--save_freq SAVE_FREQ` Epoch frequency to save the model 123 | 124 | `--eval_freq EVAL_FREQ` Epoch frequency for evaluation 125 | 126 | `--workers WORKERS` Number of data loader workers 127 | 128 | `--seed SEED` Random seed 129 | 130 | `--cuda` Use CUDA 131 | 132 | `--save` Save model every `save_freq` epochs 133 | 134 | `--neptune` Enable logging to Neptune 135 | 136 | `--debug` Debug mode with fewer items and smaller datasets 137 | 138 | `--amp` Apply automatic mixed precision training to save GPU memory 139 | 140 | 141 | 142 | ## Dataset 143 | For our experimentation, we used the training dataset provided by the [NAEP Automated Scoring Challenge](https://github.com/NAEP-AS-Challenge/info) organizers. Please contact them for usage. 144 | 145 | To run our approach on equivalent problems, the expected dataset structure is three separate train, validation, and test JSON files, each containing a list of data samples represented as dictionaries `[{key:value}]`. Each data sample dictionary contains the following (key, value) pairs: 146 | ``` 147 | {'bl':[string], 'l1':[int], 'l2':[int], 'sx':[string], 'rc':[string], 'txt':[string]} 148 | ``` 149 | 150 | The keys above are: 151 | 152 | `'bl'` : Unique database-like key to identify the student response 153 | 154 | `'l1'` : Score label by the first human rater 155 | 156 | `'l2'` : Score label by the second human rater (set to -1 if not available) 157 | 158 | `'sx'` : Sex of student (optional, required if the `--demographic` argument is used) 159 | 160 | `'rc'` : Race of student (optional, required if the `--demographic` argument is used) 161 | 162 | `'txt'` : Student response text to the reading comprehension item to be scored 163 | 164 | A sample record and a minimal loading sketch are provided in the Example Data Record section at the end of this README. 165 | 166 | ## Code Structure 167 | `models` contains `bert.py` and `meta_learning_via_bert_incontext_tuning.py` implementing the BERT and meta-trained BERT in-context models, respectively. 168 | 169 | `train` contains `train_bert.py` and `train_meta_learning_via_bert_incontext_tuning.py` implementing the training scripts for BERT and meta-trained BERT in-context, respectively. 170 | 171 | `utils` contains `batch_collator.py` to collate batches for training, `load_data.py` to load train-val-test sets, and `utils.py` with general utility functions. 172 | 173 | 174 | 175 | ## Acknowledgements 176 | Fernandez, Ghosh, and Lan are partially supported by the National Science Foundation under grants IIS-1917713 and IIS-2118706. 
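
## Example Data Record

For reference, a single record in one of the split JSON files follows the schema described in the Dataset section; all field values below are made-up placeholders rather than real challenge data:
```
{"bl": "booklet_00123", "l1": 3, "l2": -1, "sx": "1", "rc": "2", "txt": "The author includes this paragraph to explain why ..."}
```

The snippet below is a minimal loading and sanity-checking sketch, not the project loader itself: the file names are hypothetical placeholders, whereas `run.py` loads splits through `utils/load_data.py` using the `--data_folder` and `--cross_val_fold` arguments.
```
import json

splits = {}
for split in ["train", "val", "test"]:
    # Hypothetical file names; substitute the split files for your item and fold
    with open("path/to/item_{}.json".format(split), "r") as f:
        splits[split] = json.load(f)

# Every record should carry the required keys; 'sx' and 'rc' are optional
required_keys = {"bl", "l1", "l2", "txt"}
for split_name, records in splits.items():
    for record in records:
        assert required_keys.issubset(record), "missing keys in a {} record".format(split_name)
```
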
-------------------------------------------------------------------------------- /img/model_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ni9elf/automated-scoring/4b8bc5e7e3cfe08220feddf114eeeca2452a53b1/img/model_architecture.png -------------------------------------------------------------------------------- /models/bert.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, AdamW, GPT2Tokenizer, GPT2LMHeadModel, GPT2ForSequenceClassification, GPT2Config 3 | import torch 4 | from torch import nn 5 | 6 | from utils.batch_collator import CollateWraper 7 | from utils.load_data import load_dataset_base 8 | 9 | 10 | class LanguageModelBase(nn.Module): 11 | def __init__(self, params, device): 12 | super().__init__() 13 | self.params = copy.deepcopy(params) 14 | self.device = device 15 | 16 | 17 | def prepare_model(self): 18 | self.config = AutoConfig.from_pretrained(self.params.lm, num_labels=self.num_labels) 19 | self.tokenizer = AutoTokenizer.from_pretrained(self.params.lm) 20 | if self.tokenizer.pad_token is None: 21 | self.tokenizer.pad_token = self.tokenizer.eos_token 22 | self.config.pad_token_id = self.config.eos_token_id 23 | self.model = AutoModelForSequenceClassification.from_pretrained(self.params.lm, config=self.config).to(self.device) 24 | self.optimizer = AdamW(self.model.parameters(), lr=self.params.lr) 25 | # Multi GPU mode 26 | if( torch.cuda.device_count() > 1 ): 27 | self.model = nn.DataParallel(self.model) 28 | 29 | 30 | def prepare_data(self): 31 | data = load_dataset_base(self.params.task, self.params.debug, self.params.data_folder, self.params.cross_val_fold) 32 | self.trainset = data['train'] 33 | self.validset = data['valid'] 34 | self.testset = data['test'] 35 | self.max_label = max(data['train_dist'].keys()) 36 | self.min_label = min(data['train_dist'].keys()) 37 | self.num_labels = self.max_label - self.min_label + 1 38 | 39 | 40 | def dataloaders(self): 41 | train_loader = torch.utils.data.DataLoader(self.trainset, collate_fn=CollateWraper(self.tokenizer, self.min_label), 42 | batch_size=self.params.batch_size, num_workers=self.params.workers, shuffle=True, drop_last=False) 43 | valid_loader = torch.utils.data.DataLoader(self.validset, collate_fn=CollateWraper(self.tokenizer, self.min_label), 44 | batch_size=self.params.batch_size, num_workers=self.params.workers, shuffle=False, drop_last=False) 45 | test_loader = torch.utils.data.DataLoader(self.testset, collate_fn=CollateWraper(self.tokenizer, self.min_label), 46 | batch_size=self.params.batch_size, num_workers=self.params.workers, shuffle=False, drop_last=False) 47 | 48 | return train_loader, valid_loader, test_loader 49 | 50 | 51 | def zero_grad(self): 52 | self.optimizer.zero_grad() 53 | 54 | 55 | def grad_step(self, scaler): 56 | if( self.params.amp ): 57 | scaler.step(self.optimizer) 58 | else: 59 | self.optimizer.step() 60 | 61 | 62 | def train_step(self, batch, scaler): 63 | self.zero_grad() 64 | 65 | # Cast operations to mixed precision 66 | if( self.params.amp ): 67 | with torch.cuda.amp.autocast(): 68 | outputs = self.model(**batch["inputs"]) 69 | else: 70 | outputs = self.model(**batch["inputs"]) 71 | 72 | loss = outputs.loss 73 | 74 | # Multi gpu mode 75 | if( torch.cuda.device_count() > 1 ): 76 | if( self.params.amp ): 77 | scaler.scale(loss.sum()).backward() 78 | else: 79 | 
loss.sum().backward() 80 | else: 81 | if( self.params.amp ): 82 | scaler.scale(loss).backward() 83 | else: 84 | loss.backward() 85 | 86 | self.grad_step(scaler) 87 | 88 | logits = outputs.logits 89 | predictions = torch.argmax(logits, dim=-1) 90 | acc = ( predictions == batch["inputs"]["labels"] ) 91 | 92 | if( self.params.amp ): 93 | scaler.update() 94 | 95 | return {'loss': loss.detach().cpu(), 96 | 'acc':acc.detach().cpu(), 97 | 'kappa':{ 98 | 'preds':predictions.detach().cpu(), 99 | 'labels':batch["inputs"]["labels"].detach().cpu() 100 | } 101 | } 102 | 103 | 104 | def eval_step(self, batch): 105 | # Same as test step 106 | out = self.test_step(batch) 107 | 108 | return out 109 | 110 | 111 | def test_step(self, batch): 112 | with torch.no_grad(): 113 | if( self.params.amp ): 114 | with torch.cuda.amp.autocast(): 115 | outputs = self.model(**batch["inputs"]) 116 | else: 117 | outputs = self.model(**batch["inputs"]) 118 | 119 | loss = outputs.loss 120 | logits = outputs.logits 121 | predictions = torch.argmax(logits, dim=-1) 122 | acc = (predictions == batch["inputs"]["labels"]) 123 | 124 | return { 125 | 'loss': loss.detach().cpu(), 126 | 'acc':acc.detach().cpu(), 127 | 'kappa':{ 128 | 'preds':predictions.detach().cpu(), 129 | 'labels':batch["inputs"]["labels"].detach().cpu() 130 | } 131 | } -------------------------------------------------------------------------------- /models/meta_learning_via_bert_incontext_tuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from models.bert import LanguageModelBase 5 | from utils.batch_collator import CollateWraperInContextTuningMetaLearning 6 | from utils.load_data import load_dataset_in_context_tuning_with_meta_learning 7 | 8 | 9 | class MetaLearningViaLanguageModelInContextTuning(LanguageModelBase): 10 | def __init__(self, params, device, task_list, task_to_question, task_to_passage, passages): 11 | super().__init__(params, device) 12 | self.task_list = task_list 13 | self.task_to_question = task_to_question 14 | self.task_to_passage = task_to_passage 15 | self.passages = passages 16 | 17 | 18 | def prepare_data(self): 19 | self.data_meta, _, _ = load_dataset_in_context_tuning_with_meta_learning(self.params.debug, self.params.data_folder, 20 | self.task_list, self.params.cross_val_fold) 21 | self.trainset = self.data_meta['train'] 22 | self.validsets = {} 23 | self.testsets = {} 24 | for task in self.task_list: 25 | self.validsets[task] = self.data_meta[task]['valid'] 26 | self.testsets[task] = self.data_meta[task]['test'] 27 | self.test_batch_size = 12 28 | self.val_batch_size = 12 29 | 30 | # For meta-trained model, min_label=1 and max_label=4 is the same for all items since a fixed classification layer is used for all items 31 | self.min_label = 1 32 | self.max_label = 4 33 | self.num_labels = self.max_label - self.min_label + 1 34 | 35 | 36 | def dataloaders(self): 37 | collate_fn_train = CollateWraperInContextTuningMetaLearning(self.tokenizer, self.data_meta, self.task_to_question, 38 | self.params.num_examples, 39 | self.params.trunc_len, mode="train", 40 | use_demographic = self.params.demographic) 41 | collate_fn_val = CollateWraperInContextTuningMetaLearning(self.tokenizer, self.data_meta, self.task_to_question, 42 | self.params.num_examples, 43 | self.params.trunc_len, mode="val", num_val_avg=self.params.num_val_avg, 44 | val_batch_size=self.val_batch_size, 45 | use_demographic = self.params.demographic) 46 | collate_fn_test = 
CollateWraperInContextTuningMetaLearning(self.tokenizer, self.data_meta, self.task_to_question, 47 | self.params.num_examples, 48 | self.params.trunc_len, mode="test", num_test_avg=self.params.num_test_avg, 49 | test_batch_size=self.test_batch_size, 50 | use_demographic = self.params.demographic) 51 | 52 | train_loader = torch.utils.data.DataLoader(self.trainset, collate_fn=collate_fn_train, batch_size=self.params.batch_size, 53 | num_workers=self.params.workers, shuffle=True, drop_last=False) 54 | valid_loaders = {} 55 | test_loaders = {} 56 | for task in self.task_list: 57 | # For validation, batch_size after collating = batch_size * num_val_avg 58 | valid_loaders[task] = torch.utils.data.DataLoader(self.validsets[task], collate_fn=collate_fn_val, 59 | batch_size=self.val_batch_size, num_workers=self.params.workers, 60 | shuffle=False, drop_last=False) 61 | # For testing, batch_size after collating = batch_size * num_test_avg 62 | test_loaders[task] = torch.utils.data.DataLoader(self.testsets[task], collate_fn=collate_fn_test, 63 | batch_size=self.test_batch_size, num_workers=self.params.workers, 64 | shuffle=False, drop_last=False) 65 | 66 | return train_loader, valid_loaders, test_loaders 67 | 68 | 69 | def train_step(self, batch, scaler, loss_func): 70 | self.zero_grad() 71 | 72 | if( self.params.amp ): 73 | # Cast operations to mixed precision 74 | with torch.cuda.amp.autocast(): 75 | outputs = self.model(**batch["inputs"]) 76 | else: 77 | outputs = self.model(**batch["inputs"]) 78 | 79 | logits = outputs.logits 80 | 81 | # Mask invalid score classes as negative infinity 82 | # https://stackoverflow.com/questions/57548180/filling-torch-tensor-with-zeros-after-certain-index 83 | mask = torch.zeros(logits.shape[0], logits.shape[1] + 1, dtype=logits.dtype, device=logits.device) 84 | mask[(torch.arange(logits.shape[0]), batch["max_labels"])] = 1 85 | mask = mask.cumsum(dim=1)[:, :-1] 86 | masked_logits = logits.masked_fill_(mask.eq(1), value=float('-inf')) 87 | 88 | # Calculate masked cross entropy loss 89 | loss = loss_func(masked_logits.view(-1, self.num_labels), batch["inputs"]["labels"].view(-1)) 90 | 91 | # Apply a softmax over valid score classes only 92 | softmax_outs = nn.functional.softmax(masked_logits, dim=-1) 93 | 94 | # Calculate accuracy 95 | predictions = torch.argmax(softmax_outs, dim=-1) 96 | acc = ( predictions == batch["inputs"]["labels"] ) 97 | 98 | # Multi gpu mode 99 | if( torch.cuda.device_count() > 1 ): 100 | if( self.params.amp ): 101 | scaler.scale(loss.sum()).backward() 102 | else: 103 | loss.sum().backward() 104 | else: 105 | if( self.params.amp ): 106 | scaler.scale(loss).backward() 107 | else: 108 | loss.backward() 109 | 110 | self.grad_step(scaler) 111 | 112 | if( self.params.amp ): 113 | scaler.update() 114 | 115 | return {'loss': loss.detach().cpu(), 116 | 'acc':acc.detach().cpu(), 117 | 'kappa':{ 118 | 'preds':predictions.detach().cpu(), 119 | 'labels':batch["inputs"]["labels"].detach().cpu() 120 | } 121 | } 122 | 123 | 124 | def eval_step(self, batch): 125 | # Validation time averaging: Apply different sets of randomly sampled in-context examples per val datapoint to average score predictions 126 | 127 | with torch.no_grad(): 128 | if( self.params.amp ): 129 | with torch.cuda.amp.autocast(): 130 | outputs = self.model(**batch["inputs"]) 131 | else: 132 | outputs = self.model(**batch["inputs"]) 133 | 134 | 135 | loss = outputs.loss 136 | # Dimension of logits = batch_size X num_classes 137 | # Where batch_size = val_batch_size * num_val_avg 138 | logits 
= outputs.logits 139 | 140 | # Mask invalid score classes as negative infinity 141 | # https://stackoverflow.com/questions/57548180/filling-torch-tensor-with-zeros-after-certain-index 142 | mask = torch.zeros(logits.shape[0], logits.shape[1] + 1, dtype=logits.dtype, device=logits.device) 143 | mask[(torch.arange(logits.shape[0]), batch["max_labels"])] = 1 144 | mask = mask.cumsum(dim=1)[:, :-1] 145 | masked_logits = logits.masked_fill_(mask.eq(1), value=float('-inf')) 146 | 147 | # Apply a softmax over valid score classes only 148 | softmax_outs = nn.functional.softmax(masked_logits, dim=-1) 149 | 150 | # Reshaped dimension of softmax_outs = val_batch_size X num_val_avg X num_classes 151 | softmax_outs = torch.reshape(softmax_outs, (batch["actual_batch_size"], self.params.num_val_avg, -1)) 152 | 153 | # Mean averaging on softmax_outs across val_samples 154 | # Dimension of outs = val_batch_size X num_classes 155 | outs = torch.mean(softmax_outs, dim=1) 156 | # Dimension of predictions = test_batch_size X 1 157 | predictions = torch.argmax(outs, dim=-1) 158 | 159 | # Pick every num_val_avg label since labels are repeated in batch 160 | batch["inputs"]["labels"] = batch["inputs"]["labels"][::self.params.num_val_avg] 161 | # Calculate accuracy 162 | acc = (predictions == batch["inputs"]["labels"]) 163 | 164 | return { 165 | 'loss': loss.detach().cpu(), 166 | 'acc':acc.detach().cpu(), 167 | 'kappa':{ 168 | 'preds':predictions.detach().cpu(), 169 | 'labels':batch["inputs"]["labels"].detach().cpu() 170 | } 171 | } 172 | 173 | 174 | def test_step(self, batch): 175 | # Test time averaging: Apply different sets of randomly sampled in-context examples per test datapoint to average score predictions 176 | 177 | with torch.no_grad(): 178 | if( self.params.amp ): 179 | with torch.cuda.amp.autocast(): 180 | outputs = self.model(**batch["inputs"]) 181 | else: 182 | outputs = self.model(**batch["inputs"]) 183 | 184 | loss = outputs.loss 185 | # Dimension of logits = batch_size X num_classes 186 | # Where batch_size = test_batch_size * num_test_avg 187 | logits = outputs.logits 188 | 189 | # Mask invalid score classes as negative infinity 190 | # https://stackoverflow.com/questions/57548180/filling-torch-tensor-with-zeros-after-certain-index 191 | mask = torch.zeros(logits.shape[0], logits.shape[1] + 1, dtype=logits.dtype, device=logits.device) 192 | mask[(torch.arange(logits.shape[0]), batch["max_labels"])] = 1 193 | mask = mask.cumsum(dim=1)[:, :-1] 194 | masked_logits = logits.masked_fill_(mask.eq(1), value=float('-inf')) 195 | 196 | # Apply a softmax over valid score classes only 197 | softmax_outs = nn.functional.softmax(masked_logits, dim=-1) 198 | 199 | # Reshaped dimension of softmax_outs = test_batch_size X num_test_avg X num_classes 200 | softmax_outs = torch.reshape(softmax_outs, (batch["actual_batch_size"], self.params.num_test_avg, -1)) 201 | 202 | # Mean averaging on softmax_outs across test_samples 203 | # Dimension of outs = test_batch_size X num_classes 204 | outs = torch.mean(softmax_outs, dim=1) 205 | # Dimension of predictions = test_batch_size X 1 206 | predictions = torch.argmax(outs, dim=-1) 207 | 208 | # Pick every num_test_avg label since labels are repeated in batch 209 | batch["inputs"]["labels"] = batch["inputs"]["labels"][::self.params.num_test_avg] 210 | # Calculate accuracy 211 | acc = (predictions == batch["inputs"]["labels"]) 212 | 213 | return { 214 | 'loss': loss.detach().cpu(), 215 | 'acc':acc.detach().cpu(), 216 | 'kappa':{ 217 | 'preds':predictions.detach().cpu(), 218 
| 'labels':batch["inputs"]["labels"].detach().cpu() 219 | } 220 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | scikit-learn 4 | neptune-client -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | import random 4 | 5 | from transformers import logging 6 | 7 | from utils.utils import add_params 8 | from train.train_bert import train_bert 9 | from train.train_meta_learning_via_bert_incontext_tuning import train_meta_learning_via_bert_incontext_tuning 10 | 11 | 12 | # Disable warnings in hugging face logger 13 | logging.set_verbosity_error() 14 | # Set your Neptune token for logging 15 | NEPTUNE_API_TOKEN = "SET_TOKEN" 16 | 17 | 18 | def main(): 19 | args = add_params() 20 | 21 | # Local saved models dir 22 | saved_models_dir = "../../../saved_models/" 23 | # Set random seed 24 | if args.seed != -1: 25 | random.seed(args.seed) 26 | torch.manual_seed(args.seed) 27 | cudnn.deterministic = True 28 | # Set device 29 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 30 | if args.cuda: assert device.type == 'cuda', 'no gpu found!' 31 | # Logging to neptune 32 | run = None 33 | if args.neptune: 34 | import neptune.new as neptune 35 | run = neptune.init( 36 | project = args.neptune_project, 37 | api_token = NEPTUNE_API_TOKEN, 38 | capture_hardware_metrics = False, 39 | name = args.name, 40 | ) 41 | run["parameters"] = vars(args) 42 | 43 | if( args.amp ): 44 | # Using pytorch automatic mixed precision (fp16/fp32) for faster training 45 | # https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/ 46 | scaler = torch.cuda.amp.GradScaler() 47 | else: 48 | scaler = None 49 | 50 | # Train-val-test model 51 | if( args.meta_learning ): 52 | train_meta_learning_via_bert_incontext_tuning(args, run, device, saved_models_dir, scaler) 53 | else: 54 | train_bert(args, run, device, saved_models_dir, scaler) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() -------------------------------------------------------------------------------- /train/train_bert.py: -------------------------------------------------------------------------------- 1 | import time 2 | from tqdm import tqdm 3 | 4 | from models.bert import LanguageModelBase 5 | from utils.utils import agg_all_metrics, save_model 6 | 7 | 8 | def train_bert(args, run, device, saved_models_dir, scaler): 9 | # prepare data and model for training 10 | model = LanguageModelBase(args, device=device) 11 | model.prepare_data() 12 | model.prepare_model() 13 | 14 | # Log item info to neptune 15 | if args.neptune: 16 | run["parameters/n_labels"] = model.max_label - model.min_label + 1 17 | 18 | # Metric used is Quadratic Weighted Kappa (QWK) 19 | # Best test kappa 20 | best_test_metric = -1 21 | # Test kappa corresponding to best validation kappa 22 | test_metric_for_best_valid_metric = -1 23 | # Best validation kappa 24 | best_valid_metric = -1 25 | 26 | 27 | # Train-val-test loop 28 | for cur_iter in tqdm(range(args.iters)): 29 | train_loader, valid_loader, test_loader = model.dataloaders() 30 | 31 | # Train epoch 32 | start_time = time.time() 33 | # Set model to train mode needed if dropout, etc is used 34 | model.train() 35 | train_logs = [] 36 
| for batch in train_loader: 37 | batch = {k: v.to(device) for k, v in batch.items()} 38 | logs = model.train_step(batch, scaler) 39 | train_logs.append(logs) 40 | train_it_time = time.time() - start_time 41 | 42 | # Aggregate logs across all batches 43 | train_logs = agg_all_metrics(train_logs) 44 | # Log to neptune 45 | if args.neptune: 46 | run["metrics/train/accuracy"].log(train_logs['acc']) 47 | run["metrics/train/kappa"].log(train_logs['kappa']) 48 | run["metrics/train/loss"].log(train_logs['loss']) 49 | run["logs/train/it_time"].log(train_it_time) 50 | 51 | # Set model to test mode needed if dropout, etc is used 52 | model.eval() 53 | if( (cur_iter % args.eval_freq == 0) or (cur_iter >= args.iters) ): 54 | test_logs, valid_logs = [], [] 55 | # Validation epoch 56 | eval_start_time = time.time() 57 | for batch in valid_loader: 58 | batch = {k: v.to(device) for k, v in batch.items()} 59 | logs = model.eval_step(batch) 60 | valid_logs.append(logs) 61 | eval_it_time = time.time()-eval_start_time 62 | 63 | # Test epoch 64 | test_start_time = time.time() 65 | for batch in test_loader: 66 | batch = {k: v.to(device) for k, v in batch.items()} 67 | logs = model.test_step(batch) 68 | test_logs.append(logs) 69 | test_it_time = time.time()-test_start_time 70 | 71 | # Aggregate logs across batches 72 | valid_logs = agg_all_metrics(valid_logs) 73 | test_logs = agg_all_metrics(test_logs) 74 | 75 | # Update metrics and save model 76 | if( float(test_logs['kappa']) > best_test_metric ): 77 | best_test_metric = float(test_logs['kappa']) 78 | # Save model with best test kappa (not based on validation set) 79 | dir_best_test_metric = saved_models_dir + args.name + "/" + run.get_run_url().split("/")[-1] + "/" + args.task + "/" + "/best_test_kappa/" 80 | save_model(dir_best_test_metric, model) 81 | if( float(valid_logs['kappa']) > best_valid_metric ): 82 | best_valid_metric = float(valid_logs['kappa']) 83 | test_metric_for_best_valid_metric = float(test_logs['kappa']) 84 | # Save model with best validation kappa 85 | dir_best_valid_metric = saved_models_dir + args.name + "/" + run.get_run_url().split("/")[-1] + "/" + args.task + "/" + "/best_valid_kappa/" 86 | save_model(dir_best_valid_metric, model) 87 | 88 | # Push logs to neptune 89 | if args.neptune: 90 | run["metrics/test/accuracy"].log(test_logs['acc']) 91 | run["metrics/test/kappa"].log(test_logs['kappa']) 92 | run["metrics/test/loss"].log(test_logs['loss']) 93 | run["metrics/valid/accuracy"].log(valid_logs['acc']) 94 | run["metrics/valid/kappa"].log(valid_logs['kappa']) 95 | run["metrics/valid/loss"].log(valid_logs['loss']) 96 | run["metrics/test/best_kappa"].log(best_test_metric) 97 | run["metrics/test/best_kappa_with_valid"].log(test_metric_for_best_valid_metric) 98 | run["logs/cur_iter"].log(cur_iter) 99 | run["logs/valid/it_time"].log(eval_it_time) 100 | run["logs/test/it_time"].log(test_it_time) 101 | 102 | # Save model every save_freq epochs 103 | if( (cur_iter % args.save_freq == 0) or (cur_iter >= args.iters) ): 104 | dir_model = saved_models_dir + args.name + "/" + run.get_run_url().split("/")[-1] + "/" + args.task + "/" + "cross_val_fold_{}".format(args.cross_val_fold) + "/" + "/epoch_{}/".format(cur_iter) 105 | save_model(dir_model, model) 106 | -------------------------------------------------------------------------------- /train/train_meta_learning_via_bert_incontext_tuning.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | from torch import nn 4 | from tqdm 
import tqdm 5 | 6 | from models.meta_learning_via_bert_incontext_tuning import MetaLearningViaLanguageModelInContextTuning 7 | from utils.utils import agg_all_metrics, save_model 8 | 9 | 10 | 11 | def train_meta_learning_via_bert_incontext_tuning(args, run, device, saved_models_dir, scaler): 12 | # Load list of tasks for meta learning 13 | with open("data/tasks.json", "r") as f: 14 | task_list = json.load(f) 15 | if( args.debug ): 16 | task_list = task_list[0:2] 17 | # Load task to question map for each task 18 | with open("data/task_to_question.json", "r") as f: 19 | task_to_question = json.load(f) 20 | # Load task to passage map for each task 21 | with open("data/task_to_passage.json", "r") as f: 22 | task_to_passage = json.load(f) 23 | # Load passage texts 24 | with open("data/passages.json", "r") as f: 25 | passages = json.load(f) 26 | 27 | # Prepare data and model for training 28 | model = MetaLearningViaLanguageModelInContextTuning(args, device, task_list, task_to_question, task_to_passage, passages) 29 | model.prepare_data() 30 | model.prepare_model() 31 | 32 | # Dict of metric variables = kappa for each task trained on during meta learning 33 | metrics = {} 34 | for task in task_list: 35 | metrics[task] = {} 36 | # Best test kappa 37 | metrics[task]["best_test_metric"] = -1 38 | # Test kappa corresponding to best validation kappa 39 | metrics[task]["test_metric_for_best_valid_metric"] = -1 40 | # Best validation kappa 41 | metrics[task]["best_valid_metric"] = -1 42 | 43 | # Train-val-test loop 44 | loss_func = nn.CrossEntropyLoss() 45 | for cur_iter in tqdm(range(args.iters)): 46 | train_loader, valid_loaders, test_loaders = model.dataloaders() 47 | 48 | # Train epoch on one big train dataset = union of train datasets across items for meta learning 49 | start_time = time.time() 50 | # Set model to train mode needed if dropout, etc is used 51 | model.train() 52 | train_logs = [] 53 | for batch in train_loader: 54 | batch = {k: v.to(device) for k, v in batch.items()} 55 | logs = model.train_step(batch, scaler, loss_func) 56 | train_logs.append(logs) 57 | train_it_time = time.time() - start_time 58 | 59 | # Aggregate logs across all batches 60 | train_logs = agg_all_metrics(train_logs) 61 | # Log to neptune 62 | if args.neptune: 63 | run["metrics/train/accuracy"].log(train_logs['acc']) 64 | run["metrics/train/kappa"].log(train_logs['kappa']) 65 | run["metrics/train/loss"].log(train_logs['loss']) 66 | run["logs/train/it_time"].log(train_it_time) 67 | 68 | # Set model to test mode needed if dropout, etc is used 69 | model.eval() 70 | if( (cur_iter % args.eval_freq == 0) or (cur_iter >= args.iters) ): 71 | # Dict of validation and test logs for all items 72 | test_logs, valid_logs = {}, {} 73 | for task in task_list: 74 | test_logs[task], valid_logs[task] = [], [] 75 | 76 | # Validation epoch for each item 77 | eval_start_time = time.time() 78 | for task in task_list: 79 | valid_loader = valid_loaders[task] 80 | for batch in valid_loader: 81 | batch = {k: v.to(device) for k, v in batch.items()} 82 | logs = model.eval_step(batch) 83 | valid_logs[task].append(logs) 84 | eval_it_time = time.time()-eval_start_time 85 | 86 | # Test epoch for each item 87 | test_start_time = time.time() 88 | for task in task_list: 89 | test_loader = test_loaders[task] 90 | for batch in test_loader: 91 | batch = {k: v.to(device) for k, v in batch.items()} 92 | logs = model.test_step(batch) 93 | test_logs[task].append(logs) 94 | test_it_time = time.time()-test_start_time 95 | 96 | # Aggregate logs across batches 
and and across items 97 | for task in task_list: 98 | valid_logs[task] = agg_all_metrics(valid_logs[task]) 99 | test_logs[task] = agg_all_metrics(test_logs[task]) 100 | 101 | for task in task_list: 102 | # Update metrics 103 | if( len(test_logs[task]) > 0 ): 104 | metrics[task]["best_test_metric"] = max(test_logs[task]['kappa'], metrics[task]["best_test_metric"]) 105 | if( len(valid_logs[task]) > 0 ): 106 | if( float(valid_logs[task]["kappa"]) > metrics[task]["best_valid_metric"] ): 107 | metrics[task]["best_valid_metric"] = valid_logs[task]["kappa"] 108 | if( len(test_logs[task]) > 0 ): 109 | metrics[task]["test_metric_for_best_valid_metric"] = float(test_logs[task]["kappa"]) 110 | 111 | # Log to neptune for all items 112 | if args.neptune: 113 | if( len(test_logs[task]) > 0 ): 114 | run["metrics/{}/test/accuracy".format(task)].log(test_logs[task]['acc']) 115 | run["metrics/{}/test/kappa".format(task)].log(test_logs[task]['kappa']) 116 | run["metrics/{}/test/loss".format(task)].log(test_logs[task]['loss']) 117 | run["metrics/{}/test/best_kappa".format(task)].log(metrics[task]["best_test_metric"]) 118 | run["metrics/{}/test/best_kappa_with_valid".format(task)].log(metrics[task]["test_metric_for_best_valid_metric"]) 119 | if( len(valid_logs[task]) > 0 ): 120 | run["metrics/{}/valid/accuracy".format(task)].log(valid_logs[task]['acc']) 121 | run["metrics/{}/valid/kappa".format(task)].log(valid_logs[task]['kappa']) 122 | run["metrics/{}/valid/loss".format(task)].log(valid_logs[task]['loss']) 123 | run["metrics/{}/valid/best_kappa".format(task)].log(metrics[task]["best_valid_metric"]) 124 | run["logs/cur_iter"].log(cur_iter) 125 | run["logs/valid/it_time"].log(eval_it_time) 126 | run["logs/test/it_time"].log(test_it_time) 127 | 128 | # Save model after every epoch irrespective of save_freq param 129 | dir_model = saved_models_dir + args.name + "/" + run.get_run_url().split("/")[-1] + "/" + args.task + "/" + "cross_val_fold_{}".format(args.cross_val_fold) + "/" + "/epoch_{}/".format(cur_iter) 130 | save_model(dir_model, model) -------------------------------------------------------------------------------- /utils/batch_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | 5 | def tokenize_function(tokenizer, sentences_1, sentences_2=None): 6 | if(sentences_2 == None): 7 | return tokenizer(sentences_1, padding=True, truncation=True, return_tensors="pt") 8 | else: 9 | return tokenizer(sentences_1, sentences_2, padding=True, truncation=True, return_tensors="pt") 10 | 11 | 12 | class CollateWraperParent(object): 13 | def __init__(self, tokenizer, min_label): 14 | self.tokenizer = tokenizer 15 | self.min_label = min_label 16 | 17 | 18 | class CollateWraper(CollateWraperParent): 19 | # batch collator for BERT fine-tuning 20 | def __init__(self, tokenizer, min_label): 21 | super().__init__(tokenizer, min_label) 22 | 23 | def __call__(self, batch): 24 | # Construct features 25 | features = [d['txt'] for d in batch] 26 | inputs = tokenize_function(self.tokenizer, features) 27 | 28 | # Construct labels 29 | labels = torch.tensor([d['l1'] if d['l1']>=0 else d['l2'] for d in batch]).long() - self.min_label 30 | inputs['labels'] = labels 31 | 32 | return {"inputs" : inputs} 33 | 34 | 35 | class CollateWraperInContextTuningMetaLearning(CollateWraperParent): 36 | # batch collator for meta learning via BERT in-context tuning 37 | def __init__(self, tokenizer, data_meta, task_to_question, num_examples, trunc_len, mode, 
num_test_avg=8, 38 | num_val_avg=8, test_batch_size=1, val_batch_size=1, max_seq_len=512, use_demographic=False): 39 | super().__init__(tokenizer, min_label = 1) 40 | 41 | self.data_meta = data_meta 42 | self.num_examples = num_examples 43 | self.trunc_len = trunc_len 44 | self.task_to_question = task_to_question 45 | self.mode = mode 46 | # Adding an extra 50 words in case num_tokens < num_words after tokenization 47 | self.max_seq_len = max_seq_len + 50 48 | 49 | # Convert numeric scores to meaningful words 50 | self.label_to_text = { 51 | 1 : "poor", 52 | 2 : "fair", 53 | 3 : "good", 54 | 4 : "excellent" 55 | } 56 | 57 | # Demographic information 58 | self.use_demographic = use_demographic 59 | self.gender_map = { 60 | "1" : "male", 61 | "2" : "female" 62 | } 63 | self.race_map = { 64 | "1" : "white", 65 | "2" : "african american", 66 | "3" : "hispanic", 67 | "4" : "asian", 68 | "5" : "american indian", 69 | "6" : "pacific islander", 70 | "7" : "multiracial" 71 | } 72 | 73 | # Meta learning via BERT in-context tuning 74 | self.num_test_avg = num_test_avg 75 | self.num_val_avg = num_val_avg 76 | self.test_batch_size = test_batch_size 77 | self.val_batch_size = val_batch_size 78 | 79 | 80 | def __call__(self, batch): 81 | if( self.mode == "test" or self.mode == "val" ): 82 | # Since drop_last=False in test/val loader, record actual test_batch_size/val_batch_size for last batch constructed 83 | actual_batch_size = torch.tensor(len(batch)).long() 84 | 85 | # Repeat each test/val sample num_test_avg/num_val_avg times sequentially 86 | new_batch = [] 87 | for d in batch: 88 | if( self.mode == "test" ): 89 | new_batch += [d for _ in range(self.num_test_avg)] 90 | else: 91 | new_batch += [d for _ in range(self.num_val_avg)] 92 | batch = new_batch 93 | else: 94 | actual_batch_size = torch.tensor(-1).long() 95 | 96 | # Construct features: features_1 (answer txt) will have different segment embeddings than features_2 (remaining txt) 97 | features_1 = [] 98 | features_2 = [] 99 | for d in batch: 100 | # Randomly sample num_examples in-context examples from each class in train set for datapoint d 101 | examples_many_per_class = [] 102 | # List examples_each_class stores one example from each class 103 | examples_one_per_class = [] 104 | labels = list(range(d["min"], d["max"] + 1)) 105 | 106 | for label in labels: 107 | examples_class = self.data_meta[d["task"]]["examples"][label] 108 | 109 | # Remove current datapoint d from examples_class by checking unique booklet identifiers => no information leakage 110 | examples_class = [ex for ex in examples_class if ex["bl"] != d["bl"]] 111 | 112 | # Sampling num_examples without replacement 113 | if( len(examples_class) < self.num_examples ): 114 | random.shuffle(examples_class) 115 | examples_class_d = examples_class 116 | else: 117 | examples_class_d = random.sample(examples_class, self.num_examples) 118 | 119 | if( len(examples_class_d) > 1 ): 120 | examples_one_per_class += [examples_class_d[0]] 121 | examples_many_per_class += examples_class_d[1:] 122 | elif( len(examples_class_d) == 1 ): 123 | examples_one_per_class += [examples_class_d[0]] 124 | examples_many_per_class += [] 125 | else: 126 | examples_one_per_class += [] 127 | examples_many_per_class += [] 128 | 129 | # Construct input text with task instructions 130 | if( self.use_demographic ): 131 | input_txt = "score this answer written by {} {} student: ".format(self.gender_map[d["sx"]], self.race_map[d["rc"]]) + d['txt'] 132 | else: 133 | input_txt = "score this answer: " + d['txt'] 134 | 
features_1.append(input_txt) 135 | 136 | # Add range of valid score classes for datapoint d 137 | examples_txt = " scores: " + " ".join([ (self.label_to_text[label] + " ") for label in range(d["min"], d["max"] + 1) ]) 138 | # Add question text 139 | examples_txt += "[SEP] question: {} [SEP] ".format(self.task_to_question[d["task"]]) 140 | 141 | # Shuffle examples across classes 142 | random.shuffle(examples_one_per_class) 143 | random.shuffle(examples_many_per_class) 144 | 145 | # Since truncation might occur if text length exceed max input length to LM, 146 | # we ensure at least one example from each score class is present 147 | examples_d = examples_one_per_class + examples_many_per_class 148 | curr_len = len(input_txt.split(" ") + examples_txt.split(" ")) 149 | for i in range(len(examples_d)): 150 | example = examples_d[i] 151 | example_txt_tokens = example['txt'].split(" ") 152 | curr_example_len = len(example_txt_tokens) 153 | example_txt = " ".join(example_txt_tokens[:self.trunc_len]) 154 | example_label = (example['l1'] if example['l1']>=0 else example['l2']) 155 | # [SEP] at the end of the last example is automatically added by tokenizer 156 | if( i == (len(examples_d)-1) ): 157 | if( self.use_demographic ): 158 | examples_txt += ( " example written by {} {} student: ".format(self.gender_map[example["sx"]], self.race_map[example["rc"]]) + example_txt + " score: " + self.label_to_text[example_label] ) 159 | else: 160 | examples_txt += ( " example: " + example_txt + " score: " + self.label_to_text[example_label] ) 161 | else: 162 | if( self.use_demographic ): 163 | examples_txt += ( " example written by {} {} student: ".format(self.gender_map[example["sx"]], self.race_map[example["rc"]]) + example_txt + " score: " + self.label_to_text[example_label] + " [SEP] " ) 164 | else: 165 | examples_txt += ( " example: " + example_txt + " score: " + self.label_to_text[example_label] + " [SEP] " ) 166 | 167 | # Stop adding in-context examples when max_seq_len is reached 168 | if( (curr_example_len + curr_len) > self.max_seq_len): 169 | break 170 | else: 171 | curr_len += curr_example_len 172 | features_2.append(examples_txt) 173 | 174 | inputs = tokenize_function(self.tokenizer, features_1, features_2) 175 | 176 | # Construct labels 177 | labels = torch.tensor([ (d['l1']-d["min"]) if d['l1']>=0 else (d['l2']-d["min"]) for d in batch]).long() 178 | inputs['labels'] = labels 179 | 180 | # Store max_label for each d in batch which is used during softmax masking 181 | max_labels = torch.tensor([( d["max"]-d["min"]+1) for d in batch]).long() 182 | 183 | return { 184 | "inputs" : inputs, 185 | "max_labels" : max_labels, 186 | "actual_batch_size" : actual_batch_size 187 | } -------------------------------------------------------------------------------- /utils/load_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | import os 4 | 5 | 6 | RAW_DIR = "../../data/NAEP_AS_Challenge_Data/Items for Item-Specific Models/" 7 | 8 | 9 | def compute_distribution(output): 10 | append_keys = {} 11 | 12 | for k, v in output.items(): 13 | dist, new_key, total_count = defaultdict(float), k+'_dist', 0. 
14 | for d in v: 15 | n_rating = (d['l1']>=0) + (d['l2']>=0) 16 | total_count +=n_rating 17 | if( d['l1'] != -1 ): 18 | dist[d['l1']] += 1./n_rating 19 | if( d['l2'] != -1): 20 | dist[d['l2']] += 1./n_rating 21 | for l in dist: 22 | dist[l] /= total_count 23 | append_keys[new_key] = dist 24 | 25 | for k in append_keys: 26 | output[k] = append_keys[k] 27 | 28 | 29 | def load_dataset_base(task, debug=False, data_folder="data_split_answer_spell_checked", cross_val_fold=1): 30 | """ 31 | Returns a dictionary: 32 | { 33 | 'train' : [{key:value}], # training dataset 34 | 'val' : [{key:value}], # validation dataset 35 | 'test' : [{key:value}], # test dataset 36 | 'train_dist' : {label:percentage} # distribution of scores in train-set 37 | 'val_dist' : {label:percentage} # distribution of scores in val-set 38 | 'test_dist' : {label:percentage} # distribution of scores in test-set 39 | } 40 | 41 | Each of the train/val/test datasets are a list of samples. Each sample is a dictionary of following (key, value) pairs: 42 | {'bl':string, 'l1':int, 'l2':int, 'sx':string, 'rc':string, 'txt':string} 43 | 44 | The keys above are: 45 | bl: unique database like key to identify student response 46 | 'l1' : score by human rater 1 47 | 'l2' : score by human rater 2 (set as -1 if not available) 48 | 'sx' : sex of student (optional) 49 | 'rc' : race of student 50 | 'txt' : student response text to the reading comprehension item to be scored 51 | """ 52 | 53 | suffix = "_".join(data_folder.split("_")[1:]) 54 | 55 | if( cross_val_fold == 0 ): 56 | dir_name = RAW_DIR + task + "/" + data_folder 57 | else: 58 | # Use spell checked version and cross validation fold number 59 | dir_name = RAW_DIR + task + "/" + "cross_val_fold_{}".format(cross_val_fold) + "/" + data_folder 60 | 61 | data = {} 62 | filenames = [("train", "train"), ("val", "valid"), ("test", "test")] 63 | for i in range(len(filenames)): 64 | filename = os.path.join(dir_name, task.split('/')[1] + "_{}_{}.json".format(filenames[i][0], suffix)) 65 | with open(filename, "r") as f: 66 | data[filenames[i][1]] = json.load(f) 67 | 68 | # Compute score distribution 69 | compute_distribution(data) 70 | 71 | # Debug with less data if required 72 | if(debug): 73 | if( len(data["train"]) > 0 ): 74 | data["train"] = data["train"][:4] 75 | if( len(data["valid"]) > 0 ): 76 | data["valid"] = data["valid"][:4] 77 | if( len(data["test"]) > 0 ): 78 | data["test"] = data["test"][:4] 79 | 80 | return data 81 | 82 | 83 | def load_dataset_in_context_tuning(task, debug=False, data_folder="data_split_answer_spell_checked", cross_val_fold=1): 84 | data = load_dataset_base(task, debug, data_folder, cross_val_fold) 85 | 86 | # Construct in_context examples from training dataset partitioned according to class score label 87 | examples_train = None 88 | max_label = max(data['train_dist'].keys()) 89 | min_label = min(data['train_dist'].keys()) 90 | examples_train = {} 91 | for label in range(min_label, max_label + 1): 92 | examples_train[label] = [] 93 | for datapoint in data["train"]: 94 | label = datapoint['l1'] if datapoint['l1']>=0 else datapoint['l2'] 95 | examples_train[label].append(datapoint) 96 | 97 | return data, examples_train, min_label, max_label 98 | 99 | 100 | def load_dataset_in_context_tuning_with_meta_learning(debug=False, data_folder="data_split_answer_spell_checked", task_list=[], cross_val_fold=1): 101 | # Load list of item names 102 | with open("data/tasks.json", "r") as f: 103 | task_list = json.load(f) 104 | if( debug ): 105 | task_list = task_list[0:2] 106 | 107 
| data_meta = {} 108 | for task in task_list: 109 | data_meta[task] = {} 110 | data, examples_train, min_label, max_label = load_dataset_in_context_tuning(task, debug, data_folder, cross_val_fold) 111 | data_meta[task]["train"] = data["train"] 112 | data_meta[task]["valid"] = data["valid"] 113 | data_meta[task]["test"] = data["test"] 114 | # Add in-context examples from training dataset => no information leakage from val/test sets 115 | data_meta[task]["examples"] = {} 116 | for label in range(min_label, max_label + 1): 117 | data_meta[task]["examples"][label] = examples_train[label] 118 | # Add task, min_label and max_label info to each sample 119 | for set in ["train", "valid", "test"]: 120 | for sample in data_meta[task][set]: 121 | sample["min"] = min_label 122 | sample["max"] = max_label 123 | sample["task"] = task 124 | 125 | # Union of training datasets across tasks 126 | data_meta["train"] = [] 127 | for task in task_list: 128 | data_meta["train"] += data_meta[task]["train"] 129 | 130 | return data_meta, None, None -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import cohen_kappa_score 3 | import argparse 4 | import torch 5 | import pathlib 6 | 7 | 8 | def tonp(x): 9 | if isinstance(x, (np.ndarray, float, int)): 10 | return np.array(x) 11 | 12 | return x.detach().cpu().numpy() 13 | 14 | 15 | def agg_all_metrics(outputs): 16 | # Aggregate metrics for entire epoch across all batches 17 | 18 | if( len(outputs) == 0 ): 19 | return outputs 20 | 21 | res = {} 22 | keys = [ k for k in outputs[0].keys() if not isinstance(outputs[0][k], dict) ] 23 | for k in keys: 24 | all_logs = np.concatenate([tonp(x[k]).reshape(-1) for x in outputs]) 25 | if( k != 'epoch' ): 26 | res[k] = np.mean(all_logs) 27 | else: 28 | res[k] = all_logs[-1] 29 | 30 | if 'kappa' in outputs[0]: 31 | pred_logs = np.concatenate([tonp(x['kappa']['preds']).reshape(-1) for x in outputs]) 32 | label_logs = np.concatenate([tonp(x['kappa']['labels']).reshape(-1) for x in outputs]) 33 | if( np.array_equal(pred_logs, label_logs) ): 34 | # Edge case: cohen_kappa_score from sklearn returns a value of NaN if perfect agreement 35 | res['kappa'] = 1 36 | else: 37 | res['kappa'] = cohen_kappa_score(pred_logs, label_logs, weights= 'quadratic') 38 | 39 | return res 40 | 41 | 42 | def add_params(): 43 | parser = argparse.ArgumentParser(description='automated_scoring') 44 | 45 | parser.add_argument('--name', default='automated_scoring', help='Name of the experiment') 46 | parser.add_argument('--neptune_project', default="user_name/project_name", help='Name of the neptune project') 47 | # Problem definition 48 | parser.add_argument('--lm', default='bert-base-uncased', help='Base language model (provide any Hugging face model name)') 49 | parser.add_argument('--task', default="item_name", help='Item name (not required for meta learning via in-context tuning)') 50 | # Add demographic information for fairness analysis - generative models don't have this option 51 | parser.add_argument('--demographic', action='store_true', help='Use demographic information of student') 52 | 53 | # Meta learning BERT via in-context tuning 54 | # At testing, batch_size = batch_size * num_test_avg, similarly for validation 55 | # Ensure increased batch_size at test/val can be loaded onto GPU 56 | parser.add_argument('--meta_learning', action='store_true', help='Enable meta-learning via 
BERT in-context tuning') 57 | parser.add_argument('--num_test_avg', default=8, type=int, help='Number of different sets of randomly sampled examples per test datapoint to average score predictions') 58 | parser.add_argument('--num_val_avg', default=8, type=int, help='Number of different sets of randomly sampled examples per val datapoint to average score predictions') 59 | parser.add_argument('--num_examples', default=25, type=int, help='Number of in-context examples from each score class to add to input') 60 | parser.add_argument('--trunc_len', default=70, type=int, help='Max number of words in each in-context example') 61 | 62 | # Optimizer params 63 | parser.add_argument('--lr_schedule', default='warmup-const', help='Learning rate schedule to use') 64 | parser.add_argument('--opt', default='adam', choices=['sgd', 'adam', 'lars'], help='Optimizer to use') 65 | parser.add_argument('--iters', default=100, type=int, help='Number of epochs') 66 | parser.add_argument('--lr', default=2e-5, type=float, help='Base learning rate') 67 | parser.add_argument('--batch_size', default=32, type=int, help='Batch size') 68 | 69 | # Data loading 70 | parser.add_argument('--data_folder', default="data_split_answer_spell_checked_submission", help='Dataset folder name containing train-val-test splits for each cross validation fold') 71 | parser.add_argument('--cross_val_fold', default=1, type=int, help='Cross validation fold to use') 72 | 73 | # Extras 74 | parser.add_argument('--save_freq', default=1, type=int, help='Epoch frequency to save the model') 75 | parser.add_argument('--eval_freq', default=1, type=int, help='Epoch frequency for evaluation') 76 | parser.add_argument('--workers', default=4, type=int, help='Number of data loader workers') 77 | parser.add_argument('--seed', default=999, type=int, help='Random seed') 78 | parser.add_argument('--cuda', action='store_true', help='Use cuda') 79 | parser.add_argument('--save', action='store_true', help='Save model every save_freq epochs') 80 | parser.add_argument('--neptune', action='store_true', help='Enable logging to Neptune') 81 | parser.add_argument('--debug', action='store_true', help='Debug mode with less items and smaller datasets') 82 | # Automatic mixed precision training -> faster training but might affect accuracy 83 | parser.add_argument('--amp', action='store_true', help='Apply automatic mixed precision training') 84 | 85 | params = parser.parse_args() 86 | 87 | return params 88 | 89 | 90 | def save_model(dir_model, model): 91 | pathlib.Path(dir_model).mkdir(parents=True, exist_ok=True) 92 | model.tokenizer.save_pretrained(dir_model) 93 | if( torch.cuda.device_count() > 1 ): 94 | # For an nn.DataParallel object, the model is stored in .module 95 | model.model.module.save_pretrained(dir_model) 96 | else: 97 | model.model.save_pretrained(dir_model) --------------------------------------------------------------------------------
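
A note on the score-class masking used in `train_step`, `eval_step`, and `test_step` of `models/meta_learning_via_bert_incontext_tuning.py`: since the meta-trained model shares a single 4-way classification head across all items, logits of score classes that do not exist for a given item are set to negative infinity before the softmax. The snippet below is a self-contained illustration of that cumulative-sum masking trick with made-up tensor values; it is not part of the repository code.
```
import torch

# Fake logits for a batch of 3 responses scored by a shared 4-class head
logits = torch.tensor([[0.2, 1.1, -0.3, 0.7],
                       [0.5, 0.1, 0.9, 0.4],
                       [1.2, 0.3, -0.8, 0.6]])
# Number of valid score classes per response (max - min + 1), as in batch["max_labels"]
max_labels = torch.tensor([4, 3, 2])

# Build a mask that turns on at column max_labels[i] and stays on for all later columns
mask = torch.zeros(logits.shape[0], logits.shape[1] + 1, dtype=logits.dtype)
mask[(torch.arange(logits.shape[0]), max_labels)] = 1
mask = mask.cumsum(dim=1)[:, :-1]

# Invalid classes get -inf logits and hence zero probability after the softmax
masked_logits = logits.masked_fill(mask.eq(1), float('-inf'))
probs = torch.softmax(masked_logits, dim=-1)
print(probs)  # rows 2 and 3 end in zeros for their invalid trailing classes
```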