├── README.md
├── img
│   └── model_architecture.png
├── models
│   ├── bert.py
│   └── meta_learning_via_bert_incontext_tuning.py
├── requirements.txt
├── run.py
├── train
│   ├── train_bert.py
│   └── train_meta_learning_via_bert_incontext_tuning.py
└── utils
    ├── batch_collator.py
    ├── load_data.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # [Automated Scoring for Reading Comprehension via In-context BERT Tuning](https://arxiv.org/abs/2205.09864)
2 |
3 | We present our (grand prize-winning) solution to the [NAEP Automated Scoring Challenge](https://github.com/NAEP-AS-Challenge/info) for reading comprehension items. We develop a novel automated scoring approach based on meta-learning ideas via in-context tuning of language models. This repository contains the implementation of our best performing model *Meta-trained BERT In-context* and the BERT fine-tuning baseline from our paper [Automated Scoring for Reading Comprehension via In-context BERT Tuning](https://arxiv.org/abs/2205.09864) by [Nigel Fernandez](https://www.linkedin.com/in/ni9elf/), [Aritra Ghosh](https://arghosh.github.io), [Naiming Liu](https://www.linkedin.com/in/naiming-liu-lucy0817/), [Zichao Wang](https://zw16.web.rice.edu), [Benoît Choffin](https://benoitchoffin.github.io/about/), [Richard Baraniuk](https://richb.rice.edu), and [Andrew Lan](https://people.umass.edu/~andrewlan) published at AIED 2022.
4 |
5 |
6 | ![Model architecture](img/model_architecture.png)
7 |
8 |
9 | For any questions, please [email](mailto:nigel@cs.umass.edu) us or raise an issue.
10 |
11 | If you find our code or paper useful, please consider citing:
12 | ```
13 | @inproceedings{ fernandez2022AS,
14 | title={ Automated Scoring for Reading Comprehension via In-context {BERT} Tuning },
15 | author={ Fernandez, Nigel and Ghosh, Aritra and Liu, Naiming and Wang, Zichao and Choffin, Beno{\^{\i}}t and Baraniuk, Richard and Lan, Andrew },
16 | booktitle={ 23rd International Conference on Artificial Intelligence in Education (AIED 2022) },
17 | year={ 2022 }
18 | }
19 | ```
20 |
21 |
22 |
23 | ## Contents
24 |
25 | 1. [Installation](#installation)
26 | 2. [Usage](#usage)
27 | 3. [Dataset](#dataset)
28 | 4. [Code Structure](#code-structure)
29 | 5. [Acknowledgements](#acknowledgements)
30 |
31 |
32 |
33 | ## Installation
34 |
35 | Our code is tested with `Python 3.7.4` in a virtual environment created using:
36 |
37 | ```
38 | virtualenv -p python3 env
39 | ```
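
Activate the environment before installing dependencies:

```
source env/bin/activate
```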
40 |
41 | Install the dependencies listed in the `requirements.txt` file by running:
42 | ```
43 | pip install -r requirements.txt
44 | ```
45 |
46 |
47 |
48 | ## Usage
49 | A [Neptune](https://neptune.ai) account is required to log train-val-test information (loss, metrics, etc.). Set your Neptune API token in `run.py` via the `NEPTUNE_API_TOKEN` constant. Best model checkpoints are saved locally (see the reload sketch at the end of this section).
50 |
51 | To train, validate, and test the BERT fine-tuning baseline, run:
52 | ```
53 | python run.py\
54 | --name "lm_base_response_only"\
55 | --task "[name]"\
56 | --lm "bert-base-uncased"\
57 | --batch_size 32\
58 | --neptune_project "[name]"\
59 | --data_folder [name]\
60 | --cross_val_fold 1\
61 | --cuda\
63 | --neptune
64 | ```
65 | To train, validate, and test our Meta-trained BERT In-context model, run:
66 | ```
67 | python run.py\
68 | --name "meta_lm_incontext"\
69 | --task "meta_learning_via_lm_incontext_tuning"\
70 | --lm "bert-base-uncased"\
71 | --batch_size 32\
72 | --neptune_project "[name]"\
73 | --meta_learning\
74 | --num_test_avg 8\
75 | --num_val_avg 8\
76 | --data_folder [name]\
77 | --cross_val_fold 1\
78 | --cuda\
80 | --neptune
81 | ```
82 |
83 |
84 | Argument information:
85 |
86 | `-h, --help` Show this help message and exit
87 |
88 | `--name NAME` Name of the experiment
89 |
90 | `--neptune_project NEPTUNE_PROJECT` Name of the neptune project
91 |
92 | `--lm LM` Base language model (provide any Hugging Face model name)
93 |
94 | `--task TASK` Item name (not required for meta learning via in-context tuning)
95 |
96 | `--demographic` Use demographic information of student
97 |
98 | `--meta_learning` Enable meta-learning via BERT in-context tuning
99 |
100 | `--num_test_avg NUM_TEST_AVG` Number of different sets of randomly sampled examples per test datapoint to average score predictions
101 |
102 | `--num_val_avg NUM_VAL_AVG` Number of different sets of randomly sampled examples per val datapoint to average score predictions
103 |
104 | `--num_examples NUM_EXAMPLES` Number of in-context examples from each score class to add to input
105 |
106 | `--trunc_len TRUNC_LEN` Max number of words in each in-context example
107 |
108 | `--lr_schedule LR_SCHEDULE` Learning rate schedule to use
109 |
110 | `--opt {sgd,adam,lars}` Optimizer to use
111 |
112 | `--iters ITERS` Number of epochs
113 |
114 | `--lr LR` Base learning rate
115 |
116 | `--batch_size BATCH_SIZE` Batch size
117 |
118 | `--data_folder DATA_FOLDER` Dataset folder name containing train-val-test splits for each cross validation fold
119 |
120 | `--cross_val_fold CROSS_VAL_FOLD` Cross validation fold to use
121 |
122 | `--save_freq SAVE_FREQ` Epoch frequency to save the model
123 |
124 | `--eval_freq EVAL_FREQ` Epoch frequency for evaluation
125 |
126 | `--workers WORKERS` Number of data loader workers
127 |
128 | `--seed SEED` Random seed
129 |
130 | `--cuda` Use cuda
131 |
132 | `--save` Save model every save_freq epochs
133 |
134 | `--neptune` Enable logging to Neptune
135 |
136 | `--debug` Debug mode with fewer items and smaller datasets
137 |
138 | `--amp` Apply automatic mixed precision training to save GPU memory
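
Checkpoints are saved with Hugging Face's `save_pretrained` (see `save_model` in `utils/utils.py`), so they can be reloaded with the `transformers` auto classes. Below is a minimal sketch for scoring a single response with a saved BERT baseline checkpoint; the checkpoint path and response text are illustrative only:

```
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Illustrative path: checkpoints are written under <saved_models_dir>/<name>/<run-id>/<task>/...
checkpoint_dir = "saved_models/lm_base_response_only/AS-1/item_name/best_valid_kappa/"

tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
model.eval()

# Tokenize a (hypothetical) student response and predict its score class
inputs = tokenizer("The main character was upset because ...", return_tensors="pt", truncation=True)
with torch.no_grad():
    logits = model(**inputs).logits

# Predicted classes are 0-indexed; add the item's minimum score label to recover the original scale
prediction = torch.argmax(logits, dim=-1).item()
print(prediction)
```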
139 |
140 |
141 |
142 | ## Dataset
143 | For our experiments, we used the training dataset provided by the [NAEP Automated Scoring Challenge](https://github.com/NAEP-AS-Challenge/info) organizers. Please contact them for access to the data.
144 |
145 | To run our approach on similar problems, the expected dataset structure is three separate train, validation, and test JSON files, each containing a list of data samples represented as dictionaries `[{key: value}]`. Each data sample dictionary contains the following (key, value) pairs (an illustrative sample follows the key descriptions below):
146 | ```
147 | {'bl':[string], 'l1':[int], 'l2':[int], 'sx':[string], 'rc':[string], 'txt':[string]}
148 | ```
149 |
150 | The keys above are:
151 |
152 | `'bl'`: Unique database-like key identifying the student response
153 |
154 | `'l1'` : Score label by first human rater
155 |
156 | `'l2'` : Score label by second human rater (set as -1 if not available)
157 |
158 | `'sx'` : Sex of student (optional; required if the `--demographic` flag is set)
159 |
160 | `'rc'` : Race of student (optional; required if the `--demographic` flag is set)
161 |
162 | `'txt'` : Student response text to the reading comprehension item to be scored
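
An illustrative sample in this format (all values below are hypothetical; `'sx'` and `'rc'` use the numeric codes mapped in `utils/batch_collator.py`):

```
[
  {
    "bl": "booklet_0001",
    "l1": 2,
    "l2": 3,
    "sx": "1",
    "rc": "2",
    "txt": "The author includes the story about the lighthouse to show that ..."
  }
]
```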
163 |
164 |
165 |
166 | ## Code Structure
167 | `models` contains `bert.py` and `meta_learning_via_bert_incontext_tuning.py` implementing BERT and meta-trained BERT in-context models, respectively.
168 |
169 | `train` contains `train_bert.py` and `train_meta_learning_via_bert_incontext_tuning.py` implementing training scripts for BERT and meta-trained BERT in-context, respectively.
170 |
171 | `utils` contains `batch_collator.py` to collate batches for training, `load_data.py` to load train-val-test sets, and `utils.py` with general utility functions.
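
For reference, `CollateWraperInContextTuningMetaLearning` assembles each meta-learning input roughly as the following sentence pair (curly-braced spans are placeholders, the listed score words cover the item's valid label range, and the tokenizer adds `[CLS]` and the final `[SEP]`):

```
segment A: score this answer: {student response}
segment B: scores: poor fair good excellent [SEP] question: {item question} [SEP] example: {scored response} score: good [SEP] example: {scored response} score: excellent
```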
172 |
173 |
174 |
175 | ## Acknowledgements
176 | Fernandez, Ghosh, and Lan are partially supported by the National Science Foundation under grants IIS-1917713 and IIS-2118706.
--------------------------------------------------------------------------------
/img/model_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ni9elf/automated-scoring/4b8bc5e7e3cfe08220feddf114eeeca2452a53b1/img/model_architecture.png
--------------------------------------------------------------------------------
/models/bert.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, AdamW
3 | import torch
4 | from torch import nn
5 |
6 | from utils.batch_collator import CollateWraper
7 | from utils.load_data import load_dataset_base
8 |
9 |
10 | class LanguageModelBase(nn.Module):
11 | def __init__(self, params, device):
12 | super().__init__()
13 | self.params = copy.deepcopy(params)
14 | self.device = device
15 |
16 |
17 | def prepare_model(self):
18 | self.config = AutoConfig.from_pretrained(self.params.lm, num_labels=self.num_labels)
19 | self.tokenizer = AutoTokenizer.from_pretrained(self.params.lm)
20 | if self.tokenizer.pad_token is None:
21 | self.tokenizer.pad_token = self.tokenizer.eos_token
22 | self.config.pad_token_id = self.config.eos_token_id
23 | self.model = AutoModelForSequenceClassification.from_pretrained(self.params.lm, config=self.config).to(self.device)
24 | self.optimizer = AdamW(self.model.parameters(), lr=self.params.lr)
25 | # Multi GPU mode
26 | if( torch.cuda.device_count() > 1 ):
27 | self.model = nn.DataParallel(self.model)
28 |
29 |
30 | def prepare_data(self):
31 | data = load_dataset_base(self.params.task, self.params.debug, self.params.data_folder, self.params.cross_val_fold)
32 | self.trainset = data['train']
33 | self.validset = data['valid']
34 | self.testset = data['test']
35 | self.max_label = max(data['train_dist'].keys())
36 | self.min_label = min(data['train_dist'].keys())
37 | self.num_labels = self.max_label - self.min_label + 1
38 |
39 |
40 | def dataloaders(self):
41 | train_loader = torch.utils.data.DataLoader(self.trainset, collate_fn=CollateWraper(self.tokenizer, self.min_label),
42 | batch_size=self.params.batch_size, num_workers=self.params.workers, shuffle=True, drop_last=False)
43 | valid_loader = torch.utils.data.DataLoader(self.validset, collate_fn=CollateWraper(self.tokenizer, self.min_label),
44 | batch_size=self.params.batch_size, num_workers=self.params.workers, shuffle=False, drop_last=False)
45 | test_loader = torch.utils.data.DataLoader(self.testset, collate_fn=CollateWraper(self.tokenizer, self.min_label),
46 | batch_size=self.params.batch_size, num_workers=self.params.workers, shuffle=False, drop_last=False)
47 |
48 | return train_loader, valid_loader, test_loader
49 |
50 |
51 | def zero_grad(self):
52 | self.optimizer.zero_grad()
53 |
54 |
55 | def grad_step(self, scaler):
56 | if( self.params.amp ):
57 | scaler.step(self.optimizer)
58 | else:
59 | self.optimizer.step()
60 |
61 |
62 | def train_step(self, batch, scaler):
63 | self.zero_grad()
64 |
65 | # Cast operations to mixed precision
66 | if( self.params.amp ):
67 | with torch.cuda.amp.autocast():
68 | outputs = self.model(**batch["inputs"])
69 | else:
70 | outputs = self.model(**batch["inputs"])
71 |
72 | loss = outputs.loss
73 |
74 | # Multi gpu mode
75 | if( torch.cuda.device_count() > 1 ):
76 | if( self.params.amp ):
77 | scaler.scale(loss.sum()).backward()
78 | else:
79 | loss.sum().backward()
80 | else:
81 | if( self.params.amp ):
82 | scaler.scale(loss).backward()
83 | else:
84 | loss.backward()
85 |
86 | self.grad_step(scaler)
87 |
88 | logits = outputs.logits
89 | predictions = torch.argmax(logits, dim=-1)
90 | acc = ( predictions == batch["inputs"]["labels"] )
91 |
92 | if( self.params.amp ):
93 | scaler.update()
94 |
95 | return {'loss': loss.detach().cpu(),
96 | 'acc':acc.detach().cpu(),
97 | 'kappa':{
98 | 'preds':predictions.detach().cpu(),
99 | 'labels':batch["inputs"]["labels"].detach().cpu()
100 | }
101 | }
102 |
103 |
104 | def eval_step(self, batch):
105 | # Same as test step
106 | out = self.test_step(batch)
107 |
108 | return out
109 |
110 |
111 | def test_step(self, batch):
112 | with torch.no_grad():
113 | if( self.params.amp ):
114 | with torch.cuda.amp.autocast():
115 | outputs = self.model(**batch["inputs"])
116 | else:
117 | outputs = self.model(**batch["inputs"])
118 |
119 | loss = outputs.loss
120 | logits = outputs.logits
121 | predictions = torch.argmax(logits, dim=-1)
122 | acc = (predictions == batch["inputs"]["labels"])
123 |
124 | return {
125 | 'loss': loss.detach().cpu(),
126 | 'acc':acc.detach().cpu(),
127 | 'kappa':{
128 | 'preds':predictions.detach().cpu(),
129 | 'labels':batch["inputs"]["labels"].detach().cpu()
130 | }
131 | }
--------------------------------------------------------------------------------
/models/meta_learning_via_bert_incontext_tuning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from models.bert import LanguageModelBase
5 | from utils.batch_collator import CollateWraperInContextTuningMetaLearning
6 | from utils.load_data import load_dataset_in_context_tuning_with_meta_learning
7 |
8 |
9 | class MetaLearningViaLanguageModelInContextTuning(LanguageModelBase):
10 | def __init__(self, params, device, task_list, task_to_question, task_to_passage, passages):
11 | super().__init__(params, device)
12 | self.task_list = task_list
13 | self.task_to_question = task_to_question
14 | self.task_to_passage = task_to_passage
15 | self.passages = passages
16 |
17 |
18 | def prepare_data(self):
19 | self.data_meta, _, _ = load_dataset_in_context_tuning_with_meta_learning(self.params.debug, self.params.data_folder,
20 | self.task_list, self.params.cross_val_fold)
21 | self.trainset = self.data_meta['train']
22 | self.validsets = {}
23 | self.testsets = {}
24 | for task in self.task_list:
25 | self.validsets[task] = self.data_meta[task]['valid']
26 | self.testsets[task] = self.data_meta[task]['test']
27 | self.test_batch_size = 12
28 | self.val_batch_size = 12
29 |
30 | # For the meta-trained model, min_label=1 and max_label=4 are the same for all items since a single classification layer is shared across all items
31 | self.min_label = 1
32 | self.max_label = 4
33 | self.num_labels = self.max_label - self.min_label + 1
34 |
35 |
36 | def dataloaders(self):
37 | collate_fn_train = CollateWraperInContextTuningMetaLearning(self.tokenizer, self.data_meta, self.task_to_question,
38 | self.params.num_examples,
39 | self.params.trunc_len, mode="train",
40 | use_demographic = self.params.demographic)
41 | collate_fn_val = CollateWraperInContextTuningMetaLearning(self.tokenizer, self.data_meta, self.task_to_question,
42 | self.params.num_examples,
43 | self.params.trunc_len, mode="val", num_val_avg=self.params.num_val_avg,
44 | val_batch_size=self.val_batch_size,
45 | use_demographic = self.params.demographic)
46 | collate_fn_test = CollateWraperInContextTuningMetaLearning(self.tokenizer, self.data_meta, self.task_to_question,
47 | self.params.num_examples,
48 | self.params.trunc_len, mode="test", num_test_avg=self.params.num_test_avg,
49 | test_batch_size=self.test_batch_size,
50 | use_demographic = self.params.demographic)
51 |
52 | train_loader = torch.utils.data.DataLoader(self.trainset, collate_fn=collate_fn_train, batch_size=self.params.batch_size,
53 | num_workers=self.params.workers, shuffle=True, drop_last=False)
54 | valid_loaders = {}
55 | test_loaders = {}
56 | for task in self.task_list:
57 | # For validation, batch_size after collating = batch_size * num_val_avg
58 | valid_loaders[task] = torch.utils.data.DataLoader(self.validsets[task], collate_fn=collate_fn_val,
59 | batch_size=self.val_batch_size, num_workers=self.params.workers,
60 | shuffle=False, drop_last=False)
61 | # For testing, batch_size after collating = batch_size * num_test_avg
62 | test_loaders[task] = torch.utils.data.DataLoader(self.testsets[task], collate_fn=collate_fn_test,
63 | batch_size=self.test_batch_size, num_workers=self.params.workers,
64 | shuffle=False, drop_last=False)
65 |
66 | return train_loader, valid_loaders, test_loaders
67 |
68 |
69 | def train_step(self, batch, scaler, loss_func):
70 | self.zero_grad()
71 |
72 | if( self.params.amp ):
73 | # Cast operations to mixed precision
74 | with torch.cuda.amp.autocast():
75 | outputs = self.model(**batch["inputs"])
76 | else:
77 | outputs = self.model(**batch["inputs"])
78 |
79 | logits = outputs.logits
80 |
81 | # Mask invalid score classes as negative infinity
82 | # https://stackoverflow.com/questions/57548180/filling-torch-tensor-with-zeros-after-certain-index
83 | mask = torch.zeros(logits.shape[0], logits.shape[1] + 1, dtype=logits.dtype, device=logits.device)
84 | mask[(torch.arange(logits.shape[0]), batch["max_labels"])] = 1
85 | mask = mask.cumsum(dim=1)[:, :-1]
86 | masked_logits = logits.masked_fill_(mask.eq(1), value=float('-inf'))
87 |
88 | # Calculate masked cross entropy loss
89 | loss = loss_func(masked_logits.view(-1, self.num_labels), batch["inputs"]["labels"].view(-1))
90 |
91 | # Apply a softmax over valid score classes only
92 | softmax_outs = nn.functional.softmax(masked_logits, dim=-1)
93 |
94 | # Calculate accuracy
95 | predictions = torch.argmax(softmax_outs, dim=-1)
96 | acc = ( predictions == batch["inputs"]["labels"] )
97 |
98 | # Multi gpu mode
99 | if( torch.cuda.device_count() > 1 ):
100 | if( self.params.amp ):
101 | scaler.scale(loss.sum()).backward()
102 | else:
103 | loss.sum().backward()
104 | else:
105 | if( self.params.amp ):
106 | scaler.scale(loss).backward()
107 | else:
108 | loss.backward()
109 |
110 | self.grad_step(scaler)
111 |
112 | if( self.params.amp ):
113 | scaler.update()
114 |
115 | return {'loss': loss.detach().cpu(),
116 | 'acc':acc.detach().cpu(),
117 | 'kappa':{
118 | 'preds':predictions.detach().cpu(),
119 | 'labels':batch["inputs"]["labels"].detach().cpu()
120 | }
121 | }
122 |
123 |
124 | def eval_step(self, batch):
125 | # Validation time averaging: Apply different sets of randomly sampled in-context examples per val datapoint to average score predictions
126 |
127 | with torch.no_grad():
128 | if( self.params.amp ):
129 | with torch.cuda.amp.autocast():
130 | outputs = self.model(**batch["inputs"])
131 | else:
132 | outputs = self.model(**batch["inputs"])
133 |
134 |
135 | loss = outputs.loss
136 | # Dimension of logits = batch_size X num_classes
137 | # Where batch_size = val_batch_size * num_val_avg
138 | logits = outputs.logits
139 |
140 | # Mask invalid score classes as negative infinity
141 | # https://stackoverflow.com/questions/57548180/filling-torch-tensor-with-zeros-after-certain-index
142 | mask = torch.zeros(logits.shape[0], logits.shape[1] + 1, dtype=logits.dtype, device=logits.device)
143 | mask[(torch.arange(logits.shape[0]), batch["max_labels"])] = 1
144 | mask = mask.cumsum(dim=1)[:, :-1]
145 | masked_logits = logits.masked_fill_(mask.eq(1), value=float('-inf'))
146 |
147 | # Apply a softmax over valid score classes only
148 | softmax_outs = nn.functional.softmax(masked_logits, dim=-1)
149 |
150 | # Reshaped dimension of softmax_outs = val_batch_size X num_val_avg X num_classes
151 | softmax_outs = torch.reshape(softmax_outs, (batch["actual_batch_size"], self.params.num_val_avg, -1))
152 |
153 | # Mean averaging on softmax_outs across val_samples
154 | # Dimension of outs = val_batch_size X num_classes
155 | outs = torch.mean(softmax_outs, dim=1)
156 | # Dimension of predictions = val_batch_size
157 | predictions = torch.argmax(outs, dim=-1)
158 |
159 | # Pick every num_val_avg label since labels are repeated in batch
160 | batch["inputs"]["labels"] = batch["inputs"]["labels"][::self.params.num_val_avg]
161 | # Calculate accuracy
162 | acc = (predictions == batch["inputs"]["labels"])
163 |
164 | return {
165 | 'loss': loss.detach().cpu(),
166 | 'acc':acc.detach().cpu(),
167 | 'kappa':{
168 | 'preds':predictions.detach().cpu(),
169 | 'labels':batch["inputs"]["labels"].detach().cpu()
170 | }
171 | }
172 |
173 |
174 | def test_step(self, batch):
175 | # Test time averaging: Apply different sets of randomly sampled in-context examples per test datapoint to average score predictions
176 |
177 | with torch.no_grad():
178 | if( self.params.amp ):
179 | with torch.cuda.amp.autocast():
180 | outputs = self.model(**batch["inputs"])
181 | else:
182 | outputs = self.model(**batch["inputs"])
183 |
184 | loss = outputs.loss
185 | # Dimension of logits = batch_size X num_classes
186 | # Where batch_size = test_batch_size * num_test_avg
187 | logits = outputs.logits
188 |
189 | # Mask invalid score classes as negative infinity
190 | # https://stackoverflow.com/questions/57548180/filling-torch-tensor-with-zeros-after-certain-index
191 | mask = torch.zeros(logits.shape[0], logits.shape[1] + 1, dtype=logits.dtype, device=logits.device)
192 | mask[(torch.arange(logits.shape[0]), batch["max_labels"])] = 1
193 | mask = mask.cumsum(dim=1)[:, :-1]
194 | masked_logits = logits.masked_fill_(mask.eq(1), value=float('-inf'))
195 |
196 | # Apply a softmax over valid score classes only
197 | softmax_outs = nn.functional.softmax(masked_logits, dim=-1)
198 |
199 | # Reshaped dimension of softmax_outs = test_batch_size X num_test_avg X num_classes
200 | softmax_outs = torch.reshape(softmax_outs, (batch["actual_batch_size"], self.params.num_test_avg, -1))
201 |
202 | # Mean averaging on softmax_outs across test_samples
203 | # Dimension of outs = test_batch_size X num_classes
204 | outs = torch.mean(softmax_outs, dim=1)
206 | # Dimension of predictions = test_batch_size
206 | predictions = torch.argmax(outs, dim=-1)
207 |
208 | # Pick every num_test_avg label since labels are repeated in batch
209 | batch["inputs"]["labels"] = batch["inputs"]["labels"][::self.params.num_test_avg]
210 | # Calculate accuracy
211 | acc = (predictions == batch["inputs"]["labels"])
212 |
213 | return {
214 | 'loss': loss.detach().cpu(),
215 | 'acc':acc.detach().cpu(),
216 | 'kappa':{
217 | 'preds':predictions.detach().cpu(),
218 | 'labels':batch["inputs"]["labels"].detach().cpu()
219 | }
220 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers
3 | scikit-learn
4 | neptune-client
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.backends.cudnn as cudnn
3 | import random
4 |
5 | from transformers import logging
6 |
7 | from utils.utils import add_params
8 | from train.train_bert import train_bert
9 | from train.train_meta_learning_via_bert_incontext_tuning import train_meta_learning_via_bert_incontext_tuning
10 |
11 |
12 | # Disable warnings in hugging face logger
13 | logging.set_verbosity_error()
14 | # Set your Neptune token for logging
15 | NEPTUNE_API_TOKEN = "SET_TOKEN"
16 |
17 |
18 | def main():
19 | args = add_params()
20 |
21 | # Local saved models dir
22 | saved_models_dir = "../../../saved_models/"
23 | # Set random seed
24 | if args.seed != -1:
25 | random.seed(args.seed)
26 | torch.manual_seed(args.seed)
27 | cudnn.deterministic = True
28 | # Set device
29 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
30 | if args.cuda: assert device.type == 'cuda', 'no gpu found!'
31 | # Logging to neptune
32 | run = None
33 | if args.neptune:
34 | import neptune.new as neptune
35 | run = neptune.init(
36 | project = args.neptune_project,
37 | api_token = NEPTUNE_API_TOKEN,
38 | capture_hardware_metrics = False,
39 | name = args.name,
40 | )
41 | run["parameters"] = vars(args)
42 |
43 | if( args.amp ):
44 | # Using pytorch automatic mixed precision (fp16/fp32) for faster training
45 | # https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/
46 | scaler = torch.cuda.amp.GradScaler()
47 | else:
48 | scaler = None
49 |
50 | # Train-val-test model
51 | if( args.meta_learning ):
52 | train_meta_learning_via_bert_incontext_tuning(args, run, device, saved_models_dir, scaler)
53 | else:
54 | train_bert(args, run, device, saved_models_dir, scaler)
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
--------------------------------------------------------------------------------
/train/train_bert.py:
--------------------------------------------------------------------------------
1 | import time
2 | from tqdm import tqdm
3 |
4 | from models.bert import LanguageModelBase
5 | from utils.utils import agg_all_metrics, save_model
6 |
7 |
8 | def train_bert(args, run, device, saved_models_dir, scaler):
9 | # prepare data and model for training
10 | model = LanguageModelBase(args, device=device)
11 | model.prepare_data()
12 | model.prepare_model()
13 |
14 | # Log item info to neptune
15 | if args.neptune:
16 | run["parameters/n_labels"] = model.max_label - model.min_label + 1
17 |
18 | # Metric used is Quadratic Weighted Kappa (QWK)
19 | # Best test kappa
20 | best_test_metric = -1
21 | # Test kappa corresponding to best validation kappa
22 | test_metric_for_best_valid_metric = -1
23 | # Best validation kappa
24 | best_valid_metric = -1
25 |
26 |
27 | # Train-val-test loop
28 | for cur_iter in tqdm(range(args.iters)):
29 | train_loader, valid_loader, test_loader = model.dataloaders()
30 |
31 | # Train epoch
32 | start_time = time.time()
33 | # Set model to train mode (needed if dropout, etc. is used)
34 | model.train()
35 | train_logs = []
36 | for batch in train_loader:
37 | batch = {k: v.to(device) for k, v in batch.items()}
38 | logs = model.train_step(batch, scaler)
39 | train_logs.append(logs)
40 | train_it_time = time.time() - start_time
41 |
42 | # Aggregate logs across all batches
43 | train_logs = agg_all_metrics(train_logs)
44 | # Log to neptune
45 | if args.neptune:
46 | run["metrics/train/accuracy"].log(train_logs['acc'])
47 | run["metrics/train/kappa"].log(train_logs['kappa'])
48 | run["metrics/train/loss"].log(train_logs['loss'])
49 | run["logs/train/it_time"].log(train_it_time)
50 |
51 | # Set model to eval mode (needed if dropout, etc. is used)
52 | model.eval()
53 | if( (cur_iter % args.eval_freq == 0) or (cur_iter >= args.iters) ):
54 | test_logs, valid_logs = [], []
55 | # Validation epoch
56 | eval_start_time = time.time()
57 | for batch in valid_loader:
58 | batch = {k: v.to(device) for k, v in batch.items()}
59 | logs = model.eval_step(batch)
60 | valid_logs.append(logs)
61 | eval_it_time = time.time()-eval_start_time
62 |
63 | # Test epoch
64 | test_start_time = time.time()
65 | for batch in test_loader:
66 | batch = {k: v.to(device) for k, v in batch.items()}
67 | logs = model.test_step(batch)
68 | test_logs.append(logs)
69 | test_it_time = time.time()-test_start_time
70 |
71 | # Aggregate logs across batches
72 | valid_logs = agg_all_metrics(valid_logs)
73 | test_logs = agg_all_metrics(test_logs)
74 |
75 | # Update metrics and save model
76 | if( float(test_logs['kappa']) > best_test_metric ):
77 | best_test_metric = float(test_logs['kappa'])
78 | # Save model with best test kappa (not based on validation set)
79 | dir_best_test_metric = saved_models_dir + args.name + "/" + run.get_run_url().split("/")[-1] + "/" + args.task + "/" + "/best_test_kappa/"
80 | save_model(dir_best_test_metric, model)
81 | if( float(valid_logs['kappa']) > best_valid_metric ):
82 | best_valid_metric = float(valid_logs['kappa'])
83 | test_metric_for_best_valid_metric = float(test_logs['kappa'])
84 | # Save model with best validation kappa
85 | dir_best_valid_metric = saved_models_dir + args.name + "/" + run.get_run_url().split("/")[-1] + "/" + args.task + "/" + "/best_valid_kappa/"
86 | save_model(dir_best_valid_metric, model)
87 |
88 | # Push logs to neptune
89 | if args.neptune:
90 | run["metrics/test/accuracy"].log(test_logs['acc'])
91 | run["metrics/test/kappa"].log(test_logs['kappa'])
92 | run["metrics/test/loss"].log(test_logs['loss'])
93 | run["metrics/valid/accuracy"].log(valid_logs['acc'])
94 | run["metrics/valid/kappa"].log(valid_logs['kappa'])
95 | run["metrics/valid/loss"].log(valid_logs['loss'])
96 | run["metrics/test/best_kappa"].log(best_test_metric)
97 | run["metrics/test/best_kappa_with_valid"].log(test_metric_for_best_valid_metric)
98 | run["logs/cur_iter"].log(cur_iter)
99 | run["logs/valid/it_time"].log(eval_it_time)
100 | run["logs/test/it_time"].log(test_it_time)
101 |
102 | # Save model every save_freq epochs
103 | if( (cur_iter % args.save_freq == 0) or (cur_iter >= args.iters) ):
104 | dir_model = saved_models_dir + args.name + "/" + run.get_run_url().split("/")[-1] + "/" + args.task + "/" + "cross_val_fold_{}".format(args.cross_val_fold) + "/" + "/epoch_{}/".format(cur_iter)
105 | save_model(dir_model, model)
106 |
--------------------------------------------------------------------------------
/train/train_meta_learning_via_bert_incontext_tuning.py:
--------------------------------------------------------------------------------
1 | import time
2 | import json
3 | from torch import nn
4 | from tqdm import tqdm
5 |
6 | from models.meta_learning_via_bert_incontext_tuning import MetaLearningViaLanguageModelInContextTuning
7 | from utils.utils import agg_all_metrics, save_model
8 |
9 |
10 |
11 | def train_meta_learning_via_bert_incontext_tuning(args, run, device, saved_models_dir, scaler):
12 | # Load list of tasks for meta learning
13 | with open("data/tasks.json", "r") as f:
14 | task_list = json.load(f)
15 | if( args.debug ):
16 | task_list = task_list[0:2]
17 | # Load task to question map for each task
18 | with open("data/task_to_question.json", "r") as f:
19 | task_to_question = json.load(f)
20 | # Load task to passage map for each task
21 | with open("data/task_to_passage.json", "r") as f:
22 | task_to_passage = json.load(f)
23 | # Load passage texts
24 | with open("data/passages.json", "r") as f:
25 | passages = json.load(f)
26 |
27 | # Prepare data and model for training
28 | model = MetaLearningViaLanguageModelInContextTuning(args, device, task_list, task_to_question, task_to_passage, passages)
29 | model.prepare_data()
30 | model.prepare_model()
31 |
32 | # Dict of metric variables = kappa for each task trained on during meta learning
33 | metrics = {}
34 | for task in task_list:
35 | metrics[task] = {}
36 | # Best test kappa
37 | metrics[task]["best_test_metric"] = -1
38 | # Test kappa corresponding to best validation kappa
39 | metrics[task]["test_metric_for_best_valid_metric"] = -1
40 | # Best validation kappa
41 | metrics[task]["best_valid_metric"] = -1
42 |
43 | # Train-val-test loop
44 | loss_func = nn.CrossEntropyLoss()
45 | for cur_iter in tqdm(range(args.iters)):
46 | train_loader, valid_loaders, test_loaders = model.dataloaders()
47 |
48 | # Train epoch on one big train dataset = union of train datasets across items for meta learning
49 | start_time = time.time()
50 | # Set model to train mode (needed if dropout, etc. is used)
51 | model.train()
52 | train_logs = []
53 | for batch in train_loader:
54 | batch = {k: v.to(device) for k, v in batch.items()}
55 | logs = model.train_step(batch, scaler, loss_func)
56 | train_logs.append(logs)
57 | train_it_time = time.time() - start_time
58 |
59 | # Aggregate logs across all batches
60 | train_logs = agg_all_metrics(train_logs)
61 | # Log to neptune
62 | if args.neptune:
63 | run["metrics/train/accuracy"].log(train_logs['acc'])
64 | run["metrics/train/kappa"].log(train_logs['kappa'])
65 | run["metrics/train/loss"].log(train_logs['loss'])
66 | run["logs/train/it_time"].log(train_it_time)
67 |
68 | # Set model to eval mode (needed if dropout, etc. is used)
69 | model.eval()
70 | if( (cur_iter % args.eval_freq == 0) or (cur_iter >= args.iters) ):
71 | # Dict of validation and test logs for all items
72 | test_logs, valid_logs = {}, {}
73 | for task in task_list:
74 | test_logs[task], valid_logs[task] = [], []
75 |
76 | # Validation epoch for each item
77 | eval_start_time = time.time()
78 | for task in task_list:
79 | valid_loader = valid_loaders[task]
80 | for batch in valid_loader:
81 | batch = {k: v.to(device) for k, v in batch.items()}
82 | logs = model.eval_step(batch)
83 | valid_logs[task].append(logs)
84 | eval_it_time = time.time()-eval_start_time
85 |
86 | # Test epoch for each item
87 | test_start_time = time.time()
88 | for task in task_list:
89 | test_loader = test_loaders[task]
90 | for batch in test_loader:
91 | batch = {k: v.to(device) for k, v in batch.items()}
92 | logs = model.test_step(batch)
93 | test_logs[task].append(logs)
94 | test_it_time = time.time()-test_start_time
95 |
96 | # Aggregate logs across batches and across items
97 | for task in task_list:
98 | valid_logs[task] = agg_all_metrics(valid_logs[task])
99 | test_logs[task] = agg_all_metrics(test_logs[task])
100 |
101 | for task in task_list:
102 | # Update metrics
103 | if( len(test_logs[task]) > 0 ):
104 | metrics[task]["best_test_metric"] = max(test_logs[task]['kappa'], metrics[task]["best_test_metric"])
105 | if( len(valid_logs[task]) > 0 ):
106 | if( float(valid_logs[task]["kappa"]) > metrics[task]["best_valid_metric"] ):
107 | metrics[task]["best_valid_metric"] = valid_logs[task]["kappa"]
108 | if( len(test_logs[task]) > 0 ):
109 | metrics[task]["test_metric_for_best_valid_metric"] = float(test_logs[task]["kappa"])
110 |
111 | # Log to neptune for all items
112 | if args.neptune:
113 | if( len(test_logs[task]) > 0 ):
114 | run["metrics/{}/test/accuracy".format(task)].log(test_logs[task]['acc'])
115 | run["metrics/{}/test/kappa".format(task)].log(test_logs[task]['kappa'])
116 | run["metrics/{}/test/loss".format(task)].log(test_logs[task]['loss'])
117 | run["metrics/{}/test/best_kappa".format(task)].log(metrics[task]["best_test_metric"])
118 | run["metrics/{}/test/best_kappa_with_valid".format(task)].log(metrics[task]["test_metric_for_best_valid_metric"])
119 | if( len(valid_logs[task]) > 0 ):
120 | run["metrics/{}/valid/accuracy".format(task)].log(valid_logs[task]['acc'])
121 | run["metrics/{}/valid/kappa".format(task)].log(valid_logs[task]['kappa'])
122 | run["metrics/{}/valid/loss".format(task)].log(valid_logs[task]['loss'])
123 | run["metrics/{}/valid/best_kappa".format(task)].log(metrics[task]["best_valid_metric"])
124 | run["logs/cur_iter"].log(cur_iter)
125 | run["logs/valid/it_time"].log(eval_it_time)
126 | run["logs/test/it_time"].log(test_it_time)
127 |
128 | # Save model after every epoch irrespective of save_freq param
129 | dir_model = saved_models_dir + args.name + "/" + run.get_run_url().split("/")[-1] + "/" + args.task + "/" + "cross_val_fold_{}".format(args.cross_val_fold) + "/" + "/epoch_{}/".format(cur_iter)
130 | save_model(dir_model, model)
--------------------------------------------------------------------------------
/utils/batch_collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 |
4 |
5 | def tokenize_function(tokenizer, sentences_1, sentences_2=None):
6 | if( sentences_2 is None ):
7 | return tokenizer(sentences_1, padding=True, truncation=True, return_tensors="pt")
8 | else:
9 | return tokenizer(sentences_1, sentences_2, padding=True, truncation=True, return_tensors="pt")
10 |
11 |
12 | class CollateWraperParent(object):
13 | def __init__(self, tokenizer, min_label):
14 | self.tokenizer = tokenizer
15 | self.min_label = min_label
16 |
17 |
18 | class CollateWraper(CollateWraperParent):
19 | # batch collator for BERT fine-tuning
20 | def __init__(self, tokenizer, min_label):
21 | super().__init__(tokenizer, min_label)
22 |
23 | def __call__(self, batch):
24 | # Construct features
25 | features = [d['txt'] for d in batch]
26 | inputs = tokenize_function(self.tokenizer, features)
27 |
28 | # Construct labels
29 | labels = torch.tensor([d['l1'] if d['l1']>=0 else d['l2'] for d in batch]).long() - self.min_label
30 | inputs['labels'] = labels
31 |
32 | return {"inputs" : inputs}
33 |
34 |
35 | class CollateWraperInContextTuningMetaLearning(CollateWraperParent):
36 | # batch collator for meta learning via BERT in-context tuning
37 | def __init__(self, tokenizer, data_meta, task_to_question, num_examples, trunc_len, mode, num_test_avg=8,
38 | num_val_avg=8, test_batch_size=1, val_batch_size=1, max_seq_len=512, use_demographic=False):
39 | super().__init__(tokenizer, min_label = 1)
40 |
41 | self.data_meta = data_meta
42 | self.num_examples = num_examples
43 | self.trunc_len = trunc_len
44 | self.task_to_question = task_to_question
45 | self.mode = mode
46 | # Adding an extra 50 words in case num_tokens < num_words after tokenization
47 | self.max_seq_len = max_seq_len + 50
48 |
49 | # Convert numeric scores to meaningful words
50 | self.label_to_text = {
51 | 1 : "poor",
52 | 2 : "fair",
53 | 3 : "good",
54 | 4 : "excellent"
55 | }
56 |
57 | # Demographic information
58 | self.use_demographic = use_demographic
59 | self.gender_map = {
60 | "1" : "male",
61 | "2" : "female"
62 | }
63 | self.race_map = {
64 | "1" : "white",
65 | "2" : "african american",
66 | "3" : "hispanic",
67 | "4" : "asian",
68 | "5" : "american indian",
69 | "6" : "pacific islander",
70 | "7" : "multiracial"
71 | }
72 |
73 | # Meta learning via BERT in-context tuning
74 | self.num_test_avg = num_test_avg
75 | self.num_val_avg = num_val_avg
76 | self.test_batch_size = test_batch_size
77 | self.val_batch_size = val_batch_size
78 |
79 |
80 | def __call__(self, batch):
81 | if( self.mode == "test" or self.mode == "val" ):
82 | # Since drop_last=False in test/val loader, record actual test_batch_size/val_batch_size for last batch constructed
83 | actual_batch_size = torch.tensor(len(batch)).long()
84 |
85 | # Repeat each test/val sample num_test_avg/num_val_avg times sequentially
86 | new_batch = []
87 | for d in batch:
88 | if( self.mode == "test" ):
89 | new_batch += [d for _ in range(self.num_test_avg)]
90 | else:
91 | new_batch += [d for _ in range(self.num_val_avg)]
92 | batch = new_batch
93 | else:
94 | actual_batch_size = torch.tensor(-1).long()
95 |
96 | # Construct features: features_1 (answer txt) will have different segment embeddings than features_2 (remaining txt)
97 | features_1 = []
98 | features_2 = []
99 | for d in batch:
100 | # Randomly sample num_examples in-context examples from each class in train set for datapoint d
101 | examples_many_per_class = []
102 | # List examples_one_per_class stores one example from each class
103 | examples_one_per_class = []
104 | labels = list(range(d["min"], d["max"] + 1))
105 |
106 | for label in labels:
107 | examples_class = self.data_meta[d["task"]]["examples"][label]
108 |
109 | # Remove current datapoint d from examples_class by checking unique booklet identifiers => no information leakage
110 | examples_class = [ex for ex in examples_class if ex["bl"] != d["bl"]]
111 |
112 | # Sampling num_examples without replacement
113 | if( len(examples_class) < self.num_examples ):
114 | random.shuffle(examples_class)
115 | examples_class_d = examples_class
116 | else:
117 | examples_class_d = random.sample(examples_class, self.num_examples)
118 |
119 | if( len(examples_class_d) > 1 ):
120 | examples_one_per_class += [examples_class_d[0]]
121 | examples_many_per_class += examples_class_d[1:]
122 | elif( len(examples_class_d) == 1 ):
123 | examples_one_per_class += [examples_class_d[0]]
124 | examples_many_per_class += []
125 | else:
126 | examples_one_per_class += []
127 | examples_many_per_class += []
128 |
129 | # Construct input text with task instructions
130 | if( self.use_demographic ):
131 | input_txt = "score this answer written by {} {} student: ".format(self.gender_map[d["sx"]], self.race_map[d["rc"]]) + d['txt']
132 | else:
133 | input_txt = "score this answer: " + d['txt']
134 | features_1.append(input_txt)
135 |
136 | # Add range of valid score classes for datapoint d
137 | examples_txt = " scores: " + " ".join([ (self.label_to_text[label] + " ") for label in range(d["min"], d["max"] + 1) ])
138 | # Add question text
139 | examples_txt += "[SEP] question: {} [SEP] ".format(self.task_to_question[d["task"]])
140 |
141 | # Shuffle examples across classes
142 | random.shuffle(examples_one_per_class)
143 | random.shuffle(examples_many_per_class)
144 |
145 | # Since truncation might occur if the text length exceeds the max input length of the LM,
146 | # we ensure at least one example from each score class is present
147 | examples_d = examples_one_per_class + examples_many_per_class
148 | curr_len = len(input_txt.split(" ") + examples_txt.split(" "))
149 | for i in range(len(examples_d)):
150 | example = examples_d[i]
151 | example_txt_tokens = example['txt'].split(" ")
152 | curr_example_len = len(example_txt_tokens)
153 | example_txt = " ".join(example_txt_tokens[:self.trunc_len])
154 | example_label = (example['l1'] if example['l1']>=0 else example['l2'])
155 | # [SEP] at the end of the last example is automatically added by tokenizer
156 | if( i == (len(examples_d)-1) ):
157 | if( self.use_demographic ):
158 | examples_txt += ( " example written by {} {} student: ".format(self.gender_map[example["sx"]], self.race_map[example["rc"]]) + example_txt + " score: " + self.label_to_text[example_label] )
159 | else:
160 | examples_txt += ( " example: " + example_txt + " score: " + self.label_to_text[example_label] )
161 | else:
162 | if( self.use_demographic ):
163 | examples_txt += ( " example written by {} {} student: ".format(self.gender_map[example["sx"]], self.race_map[example["rc"]]) + example_txt + " score: " + self.label_to_text[example_label] + " [SEP] " )
164 | else:
165 | examples_txt += ( " example: " + example_txt + " score: " + self.label_to_text[example_label] + " [SEP] " )
166 |
167 | # Stop adding in-context examples when max_seq_len is reached
168 | if( (curr_example_len + curr_len) > self.max_seq_len):
169 | break
170 | else:
171 | curr_len += curr_example_len
172 | features_2.append(examples_txt)
173 |
174 | inputs = tokenize_function(self.tokenizer, features_1, features_2)
175 |
176 | # Construct labels
177 | labels = torch.tensor([ (d['l1']-d["min"]) if d['l1']>=0 else (d['l2']-d["min"]) for d in batch]).long()
178 | inputs['labels'] = labels
179 |
180 | # Store the number of valid score classes for each d in the batch, used during softmax masking
181 | max_labels = torch.tensor([( d["max"]-d["min"]+1) for d in batch]).long()
182 |
183 | return {
184 | "inputs" : inputs,
185 | "max_labels" : max_labels,
186 | "actual_batch_size" : actual_batch_size
187 | }
--------------------------------------------------------------------------------
/utils/load_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | import os
4 |
5 |
6 | RAW_DIR = "../../data/NAEP_AS_Challenge_Data/Items for Item-Specific Models/"
7 |
8 |
9 | def compute_distribution(output):
10 | append_keys = {}
11 |
12 | for k, v in output.items():
13 | dist, new_key, total_count = defaultdict(float), k+'_dist', 0.
14 | for d in v:
15 | n_rating = (d['l1']>=0) + (d['l2']>=0)
16 | total_count +=n_rating
17 | if( d['l1'] != -1 ):
18 | dist[d['l1']] += 1./n_rating
19 | if( d['l2'] != -1):
20 | dist[d['l2']] += 1./n_rating
21 | for l in dist:
22 | dist[l] /= total_count
23 | append_keys[new_key] = dist
24 |
25 | for k in append_keys:
26 | output[k] = append_keys[k]
27 |
28 |
29 | def load_dataset_base(task, debug=False, data_folder="data_split_answer_spell_checked", cross_val_fold=1):
30 | """
31 | Returns a dictionary:
32 | {
33 | 'train' : [{key:value}], # training dataset
34 | 'valid' : [{key:value}], # validation dataset
35 | 'test' : [{key:value}], # test dataset
36 | 'train_dist' : {label:percentage} # distribution of scores in train-set
37 | 'valid_dist' : {label:percentage} # distribution of scores in valid-set
38 | 'test_dist' : {label:percentage} # distribution of scores in test-set
39 | }
40 |
41 | Each of the train/valid/test datasets is a list of samples. Each sample is a dictionary with the following (key, value) pairs:
42 | {'bl':string, 'l1':int, 'l2':int, 'sx':string, 'rc':string, 'txt':string}
43 |
44 | The keys above are:
45 | 'bl' : unique database-like key to identify student response
46 | 'l1' : score by human rater 1
47 | 'l2' : score by human rater 2 (set as -1 if not available)
48 | 'sx' : sex of student (optional)
49 | 'rc' : race of student
50 | 'txt' : student response text to the reading comprehension item to be scored
51 | """
52 |
53 | suffix = "_".join(data_folder.split("_")[1:])
54 |
55 | if( cross_val_fold == 0 ):
56 | dir_name = RAW_DIR + task + "/" + data_folder
57 | else:
58 | # Use spell checked version and cross validation fold number
59 | dir_name = RAW_DIR + task + "/" + "cross_val_fold_{}".format(cross_val_fold) + "/" + data_folder
60 |
61 | data = {}
62 | filenames = [("train", "train"), ("val", "valid"), ("test", "test")]
63 | for i in range(len(filenames)):
64 | filename = os.path.join(dir_name, task.split('/')[1] + "_{}_{}.json".format(filenames[i][0], suffix))
65 | with open(filename, "r") as f:
66 | data[filenames[i][1]] = json.load(f)
67 |
68 | # Compute score distribution
69 | compute_distribution(data)
70 |
71 | # Debug with less data if required
72 | if(debug):
73 | if( len(data["train"]) > 0 ):
74 | data["train"] = data["train"][:4]
75 | if( len(data["valid"]) > 0 ):
76 | data["valid"] = data["valid"][:4]
77 | if( len(data["test"]) > 0 ):
78 | data["test"] = data["test"][:4]
79 |
80 | return data
81 |
82 |
83 | def load_dataset_in_context_tuning(task, debug=False, data_folder="data_split_answer_spell_checked", cross_val_fold=1):
84 | data = load_dataset_base(task, debug, data_folder, cross_val_fold)
85 |
86 | # Construct in-context examples from the training dataset, partitioned by score class label
87 | examples_train = None
88 | max_label = max(data['train_dist'].keys())
89 | min_label = min(data['train_dist'].keys())
90 | examples_train = {}
91 | for label in range(min_label, max_label + 1):
92 | examples_train[label] = []
93 | for datapoint in data["train"]:
94 | label = datapoint['l1'] if datapoint['l1']>=0 else datapoint['l2']
95 | examples_train[label].append(datapoint)
96 |
97 | return data, examples_train, min_label, max_label
98 |
99 |
100 | def load_dataset_in_context_tuning_with_meta_learning(debug=False, data_folder="data_split_answer_spell_checked", task_list=[], cross_val_fold=1):
101 | # Load list of item names
102 | with open("data/tasks.json", "r") as f:
103 | task_list = json.load(f)
104 | if( debug ):
105 | task_list = task_list[0:2]
106 |
107 | data_meta = {}
108 | for task in task_list:
109 | data_meta[task] = {}
110 | data, examples_train, min_label, max_label = load_dataset_in_context_tuning(task, debug, data_folder, cross_val_fold)
111 | data_meta[task]["train"] = data["train"]
112 | data_meta[task]["valid"] = data["valid"]
113 | data_meta[task]["test"] = data["test"]
114 | # Add in-context examples from training dataset => no information leakage from val/test sets
115 | data_meta[task]["examples"] = {}
116 | for label in range(min_label, max_label + 1):
117 | data_meta[task]["examples"][label] = examples_train[label]
118 | # Add task, min_label and max_label info to each sample
119 | for split in ["train", "valid", "test"]:
120 | for sample in data_meta[task][split]:
121 | sample["min"] = min_label
122 | sample["max"] = max_label
123 | sample["task"] = task
124 |
125 | # Union of training datasets across tasks
126 | data_meta["train"] = []
127 | for task in task_list:
128 | data_meta["train"] += data_meta[task]["train"]
129 |
130 | return data_meta, None, None
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import cohen_kappa_score
3 | import argparse
4 | import torch
5 | import pathlib
6 |
7 |
8 | def tonp(x):
9 | if isinstance(x, (np.ndarray, float, int)):
10 | return np.array(x)
11 |
12 | return x.detach().cpu().numpy()
13 |
14 |
15 | def agg_all_metrics(outputs):
16 | # Aggregate metrics for entire epoch across all batches
17 |
18 | if( len(outputs) == 0 ):
19 | return outputs
20 |
21 | res = {}
22 | keys = [ k for k in outputs[0].keys() if not isinstance(outputs[0][k], dict) ]
23 | for k in keys:
24 | all_logs = np.concatenate([tonp(x[k]).reshape(-1) for x in outputs])
25 | if( k != 'epoch' ):
26 | res[k] = np.mean(all_logs)
27 | else:
28 | res[k] = all_logs[-1]
29 |
30 | if 'kappa' in outputs[0]:
31 | pred_logs = np.concatenate([tonp(x['kappa']['preds']).reshape(-1) for x in outputs])
32 | label_logs = np.concatenate([tonp(x['kappa']['labels']).reshape(-1) for x in outputs])
33 | if( np.array_equal(pred_logs, label_logs) ):
34 | # Edge case: cohen_kappa_score from sklearn returns a value of NaN if perfect agreement
35 | res['kappa'] = 1
36 | else:
37 | res['kappa'] = cohen_kappa_score(pred_logs, label_logs, weights= 'quadratic')
38 |
39 | return res
40 |
41 |
42 | def add_params():
43 | parser = argparse.ArgumentParser(description='automated_scoring')
44 |
45 | parser.add_argument('--name', default='automated_scoring', help='Name of the experiment')
46 | parser.add_argument('--neptune_project', default="user_name/project_name", help='Name of the neptune project')
47 | # Problem definition
48 | parser.add_argument('--lm', default='bert-base-uncased', help='Base language model (provide any Hugging Face model name)')
49 | parser.add_argument('--task', default="item_name", help='Item name (not required for meta learning via in-context tuning)')
50 | # Add demographic information for fairness analysis - generative models don't have this option
51 | parser.add_argument('--demographic', action='store_true', help='Use demographic information of student')
52 |
53 | # Meta learning BERT via in-context tuning
54 | # At testing, batch_size = batch_size * num_test_avg, similarly for validation
55 | # Ensure increased batch_size at test/val can be loaded onto GPU
56 | parser.add_argument('--meta_learning', action='store_true', help='Enable meta-learning via BERT in-context tuning')
57 | parser.add_argument('--num_test_avg', default=8, type=int, help='Number of different sets of randomly sampled examples per test datapoint to average score predictions')
58 | parser.add_argument('--num_val_avg', default=8, type=int, help='Number of different sets of randomly sampled examples per val datapoint to average score predictions')
59 | parser.add_argument('--num_examples', default=25, type=int, help='Number of in-context examples from each score class to add to input')
60 | parser.add_argument('--trunc_len', default=70, type=int, help='Max number of words in each in-context example')
61 |
62 | # Optimizer params
63 | parser.add_argument('--lr_schedule', default='warmup-const', help='Learning rate schedule to use')
64 | parser.add_argument('--opt', default='adam', choices=['sgd', 'adam', 'lars'], help='Optimizer to use')
65 | parser.add_argument('--iters', default=100, type=int, help='Number of epochs')
66 | parser.add_argument('--lr', default=2e-5, type=float, help='Base learning rate')
67 | parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
68 |
69 | # Data loading
70 | parser.add_argument('--data_folder', default="data_split_answer_spell_checked_submission", help='Dataset folder name containing train-val-test splits for each cross validation fold')
71 | parser.add_argument('--cross_val_fold', default=1, type=int, help='Cross validation fold to use')
72 |
73 | # Extras
74 | parser.add_argument('--save_freq', default=1, type=int, help='Epoch frequency to save the model')
75 | parser.add_argument('--eval_freq', default=1, type=int, help='Epoch frequency for evaluation')
76 | parser.add_argument('--workers', default=4, type=int, help='Number of data loader workers')
77 | parser.add_argument('--seed', default=999, type=int, help='Random seed')
78 | parser.add_argument('--cuda', action='store_true', help='Use cuda')
79 | parser.add_argument('--save', action='store_true', help='Save model every save_freq epochs')
80 | parser.add_argument('--neptune', action='store_true', help='Enable logging to Neptune')
81 | parser.add_argument('--debug', action='store_true', help='Debug mode with fewer items and smaller datasets')
82 | # Automatic mixed precision training -> faster training but might affect accuracy
83 | parser.add_argument('--amp', action='store_true', help='Apply automatic mixed precision training')
84 |
85 | params = parser.parse_args()
86 |
87 | return params
88 |
89 |
90 | def save_model(dir_model, model):
91 | pathlib.Path(dir_model).mkdir(parents=True, exist_ok=True)
92 | model.tokenizer.save_pretrained(dir_model)
93 | if( torch.cuda.device_count() > 1 ):
94 | # For an nn.DataParallel object, the model is stored in .module
95 | model.model.module.save_pretrained(dir_model)
96 | else:
97 | model.model.save_pretrained(dir_model)
--------------------------------------------------------------------------------