├── __init__.py ├── models.py ├── requirements.txt ├── LICENSE ├── .gitignore ├── README.md ├── run_inference.py └── utils.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | clean-text==0.6.0 2 | num2words==0.5.10 3 | pandas==1.4.1 4 | scipy==1.7.3 5 | torch==1.10.2 6 | transformers==4.14.1 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ddemszky 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conversational Uptake 2 | This repository contains data for the paper: 3 | 4 | Demszky, D., Liu, J., Mancenido, Z., Cohen, J., Hill, H., Jurafsky, D., & Hashimoto, T. (2021). [Measuring Conversational Uptake: A Case Study on Student-Teacher Interactions](https://arxiv.org/pdf/2106.03873.pdf). In _Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics (ACL)_. 5 | 6 | ``` 7 | @inproceedings{demszky2021measuring, 8 | title={{Measuring Conversational Uptake: A Case Study on Student-Teacher Interactions}}, 9 | author={Demszky, Dorottya and Liu, Jing and Mancenido, Zid and Cohen, Julie and Hill, Heather and Jurafsky, Dan and Hashimoto, Tatsunori}, 10 | booktitle = {Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics (ACL)}, 11 | year={2021} 12 | } 13 | ``` 14 | 15 | ## Annotated Uptake Dataset 16 | 17 | The annotated dataset contains a sample of **2246 exchanges** extracted from a dataset of anonymized 4-5th grade US elementary math classroom transcripts collected by the [National Center for Teacher Effectiveness (NCTE)](https://cepr.harvard.edu/ncte) in New England schools between 2010-2013. These exchanges are turns by students (with at least 5 words), followed by teacher turns in a classroom conversation. 18 | 19 | The exchanges are annotated by thirteen experts in math instruction (former and current math teachers and trained raters for classroom observation protocols). 
The coding instrument can be viewed [here](https://docs.google.com/document/d/1UGAXW3H-bV1m0PWcDM7aGcRgkdrY-fovcPstB4YphvA/edit?usp=sharing). 20 | 21 | Each exchange is coded by three raters for three items: 22 | * `student_on_task`: Whether the student utterance is on task (related to math). This is a binary variable: either 0 (off task) or 1 (on task). 23 | * `teacher_on_task`: Whether the teacher utterance is on task (related to math). This is a binary variable: either 0 (off task) or 1 (on task). 24 | * `uptake`: Degree of uptake, i.e. the extent to which the teacher demonstrates that they have heard the student by building on their contribution. It takes one of three values: 0 (low), 1 (mid), or 2 (high). 25 | 26 | The data is in the comma-separated file `data/uptake_dataset.csv`. 27 | 28 | The file includes the following columns: 29 | 30 | * `obs_id`: Observation ID, mappable to unique transcripts in the NCTE dataset. 31 | * `exchange_idx`: ID of the exchange within the transcript. 32 | * `student_text`: Student utterance. 33 | * `teacher_text`: Teacher utterance (following the utterance in `student_text`). 34 | * `student_on_task`: Average rating for `student_on_task` across the three raters. 35 | * `student_on_task_majority`: The majority rating for `student_on_task` across the three raters. 36 | * `student_on_task_num_agree`: Number of raters who agree on the `student_on_task` code. 37 | * `student_on_task_zscore`: Average rating for `student_on_task`, after z-scoring the ratings for each rater. 38 | * `teacher_on_task`: Average rating for `teacher_on_task` across the three raters. 39 | * `teacher_on_task_majority`: The majority rating for `teacher_on_task` across the three raters. 40 | * `teacher_on_task_num_agree`: Number of raters who agree on the `teacher_on_task` code. 41 | * `teacher_on_task_zscore`: Average rating for `teacher_on_task`, after z-scoring the ratings for each rater. 42 | * `uptake`: Average rating for `uptake` across the three raters. 43 | * `uptake_majority`: The majority rating for `uptake` across the three raters. Value is None if there is no majority label (no agreement between any of the raters). 44 | * `uptake_num_agree`: Number of raters who agree on the `uptake` code. 45 | * `uptake_zscore`: Average rating for `uptake`, after z-scoring the ratings for each rater. *We use this item for our main evaluations*. 46 | 47 | Each example can be **uniquely identified by the combination of the `obs_id` and `exchange_idx` columns**. 48 | 49 | ## Pre-Trained Model 50 | 51 | The uptake model is available for download on Hugging Face: https://huggingface.co/stanford-nlpxed/uptake-model 52 | 53 | 54 | ## Running inference 55 | 56 | Follow these steps to run inference with the pre-trained model: 57 | 1. Create a virtual environment: `python3 -m venv venv` 58 | 2. Activate the virtual environment: `source venv/bin/activate` 59 | 3. Install the requirements: `$ pip3 install -r requirements.txt`. The pinned PyTorch build is CPU-only, so if you want to run on a GPU, install a CUDA-enabled PyTorch (and, if needed, a matching `transformers` version). 60 | 4. Download and unzip the model checkpoint -- see above. 61 | 5. Put all your data into a single csv file. There should be one column with the utterance from speaker A and one with the utterance from speaker B; the model will predict to what extent speaker B's utterance takes up speaker A's utterance.
See the `data/uptake_annotations.csv` file for an example, where speaker A = `student_text` and speaker B = `teacher_text`. 62 | 6. You can run inference like this: `$ python3 run_inference.py --data_file data/uptake_data.csv --speakerA student_text --speakerB teacher_text --output_col uptake_predictions --output predictions/uptake_data_predictions.csv` 63 | 64 | **Notes** 65 | * Make sure there are no empty strings or NaNs in your data. 66 | * The uptake model will only predict scores for utterance pairs where the first utterance is at least *5 tokens* long, ignoring punctuation. 67 | -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | """Author: Dora Demszky 2 | 3 | Predict uptake scores for utterance pairs, by running inference with an existing model checkpoint. 4 | 5 | Usage: 6 | 7 | python run_inference.py --data_file data/uptake_data.csv --speakerA student_text --speakerB teacher_text --output_col uptake_predictions --output predictions/uptake_data_predictions.csv 8 | 9 | """ 10 | 11 | from argparse import ArgumentParser 12 | import string 13 | import re 14 | from scipy.special import softmax 15 | import pandas as pd 16 | 17 | from utils import clean_str, clean_str_nopunct 18 | import torch 19 | from transformers import BertTokenizer 20 | from utils import MultiHeadModel, BertInputBuilder 21 | 22 | punct_chars = list((set(string.punctuation) | {'’', '‘', '–', '—', '~', '|', '“', '”', '…', "'", "`", '_'})) 23 | punct_chars.sort() 24 | punctuation = ''.join(punct_chars) 25 | replace = re.compile('[%s]' % re.escape(punctuation)) 26 | 27 | 28 | def get_num_words(text): 29 | if not isinstance(text, str): 30 | text = str(text) # coerce non-strings so the substitutions below do not fail 31 | text = replace.sub(' ', text) 32 | text = re.sub(r'\s+', ' ', text) 33 | text = text.strip() 34 | text = re.sub(r'\[.+\]', " ", text) 35 | return len(text.split()) 36 | 37 | 38 | def get_clean_text(text, remove_punct=False): 39 | if remove_punct: 40 | return clean_str_nopunct(text) 41 | return clean_str(text) 42 | 43 | 44 | def get_prediction(model, instance, device): 45 | instance["attention_mask"] = [[1] * len(instance["input_ids"])] 46 | for key in ["input_ids", "token_type_ids", "attention_mask"]: 47 | instance[key] = torch.tensor(instance[key]).unsqueeze(0) # Batch size = 1 48 | instance[key] = instance[key].to(device) # .to() is not in-place; keep the assignment 49 | 50 | output = model(input_ids=instance["input_ids"], 51 | attention_mask=instance["attention_mask"], 52 | token_type_ids=instance["token_type_ids"], 53 | return_pooler_output=False) 54 | return output 55 | 56 | def get_uptake_score(utterances, speakerA, speakerB, model, device, input_builder, max_length): 57 | 58 | textA = get_clean_text(utterances[speakerA], remove_punct=False) 59 | textB = get_clean_text(utterances[speakerB], remove_punct=False) 60 | 61 | instance = input_builder.build_inputs([textA], textB, 62 | max_length=max_length, 63 | input_str=True) 64 | output = get_prediction(model, instance, device) 65 | uptake_score = softmax(output["nsp_logits"][0].tolist())[1] 66 | return uptake_score 67 | 68 | 69 | def main(): 70 | parser = ArgumentParser() 71 | parser.add_argument("--data_file", type=str, default="", help="Path or url of the dataset (csv).") 72 | parser.add_argument("--speakerA", type=str, default="speakerA", help="Column indicating speaker A.") 73 | parser.add_argument("--speakerB", type=str, default="speakerB", help="Column indicating speaker B (uptake is calculated for this speaker).") 74 | 
parser.add_argument("--model_checkpoint", type=str, 75 | default="checkpoints/Feb25_09-02-16_combined_education_dataset_02252021.json_6.25e-05_hist1_cand4_bert-base-uncased_ne1_nsp1", 76 | help="Path, url or short name of the model") 77 | parser.add_argument("--output_col", type=str, default="uptake_predictions", 78 | help="Name of column for storing predictions.") 79 | parser.add_argument("--output", type=str, default="", 80 | help="Filename for storing predictions.") 81 | parser.add_argument("--max_length", type=int, default=120, help="Maximum input sequence length") 82 | parser.add_argument("--student_min_words", type=int, default=5, help="Maximum input sequence length") 83 | args = parser.parse_args() 84 | 85 | 86 | print("Loading models...") 87 | device = "cuda" if torch.cuda.is_available() else "cpu" 88 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 89 | input_builder = BertInputBuilder(tokenizer=tokenizer) 90 | uptake_model = MultiHeadModel.from_pretrained(args.model_checkpoint, head2size={"nsp": 2}) 91 | uptake_model.to(device) 92 | 93 | utterances = pd.read_csv(args.data_file) 94 | print("EXAMPLES") 95 | for i, row in utterances.head().iterrows(): 96 | print("speaker A: %s" % row[args.speakerA]) 97 | print("speaker B: %s" % row[args.speakerB]) 98 | print("----") 99 | 100 | print("Running inference on %d examples..." % len(utterances)) 101 | uptake_model.eval() 102 | uptake_scores = [] 103 | with torch.no_grad(): 104 | for i, utt in utterances.iterrows(): 105 | prev_num_words = get_num_words(utt[args.speakerA]) 106 | if prev_num_words < args.student_min_words: 107 | uptake_scores.append(None) 108 | continue 109 | uptake_score = get_uptake_score(utterances=utt, 110 | speakerA=args.speakerA, 111 | speakerB=args.speakerB, 112 | model=uptake_model, 113 | device=device, 114 | input_builder=input_builder, 115 | max_length=args.max_length) 116 | uptake_scores.append(uptake_score) 117 | 118 | utterances[args.output_col] = uptake_scores 119 | utterances.to_csv(args.output, index=False) 120 | 121 | 122 | 123 | 124 | if __name__ == "__main__": 125 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel 3 | from torch import nn 4 | from itertools import chain 5 | from torch.nn import MSELoss, CrossEntropyLoss 6 | from cleantext import clean 7 | from num2words import num2words 8 | import re 9 | 10 | def number_to_words(num): 11 | try: 12 | return num2words(re.sub(",", "", num)) 13 | except: 14 | return num 15 | 16 | 17 | clean_str = lambda s: clean(s, 18 | fix_unicode=True, # fix various unicode errors 19 | to_ascii=True, # transliterate to closest ASCII representation 20 | lower=True, # lowercase text 21 | no_line_breaks=True, # fully strip line breaks as opposed to only normalizing them 22 | no_urls=True, # replace all URLs with a special token 23 | no_emails=True, # replace all email addresses with a special token 24 | no_phone_numbers=True, # replace all phone numbers with a special token 25 | no_numbers=True, # replace all numbers with a special token 26 | no_digits=False, # replace all digits with a special token 27 | no_currency_symbols=False, # replace all currency symbols with a special token 28 | no_punct=False, # fully remove punctuation 29 | replace_with_url="", 30 | replace_with_email="", 31 | replace_with_phone_number="", 32 | 
replace_with_number=lambda m: number_to_words(m.group()), 33 | replace_with_digit="0", 34 | replace_with_currency_symbol="", 35 | lang="en" 36 | ) 37 | 38 | clean_str_nopunct = lambda s: clean(s, 39 | fix_unicode=True, # fix various unicode errors 40 | to_ascii=True, # transliterate to closest ASCII representation 41 | lower=True, # lowercase text 42 | no_line_breaks=True, # fully strip line breaks as opposed to only normalizing them 43 | no_urls=True, # replace all URLs with a special token 44 | no_emails=True, # replace all email addresses with a special token 45 | no_phone_numbers=True, # replace all phone numbers with a special token 46 | no_numbers=True, # replace all numbers with a special token 47 | no_digits=False, # replace all digits with a special token 48 | no_currency_symbols=False, # replace all currency symbols with a special token 49 | no_punct=True, # fully remove punctuation 50 | replace_with_url="", 51 | replace_with_email="", 52 | replace_with_phone_number="", 53 | replace_with_number=lambda m: number_to_words(m.group()), 54 | replace_with_digit="0", 55 | replace_with_currency_symbol="", 56 | lang="en" 57 | ) 58 | 59 | 60 | 61 | class MultiHeadModel(BertPreTrainedModel): 62 | """Pre-trained BERT model that uses our loss functions""" 63 | 64 | def __init__(self, config, head2size): 65 | super(MultiHeadModel, self).__init__(config, head2size) 66 | config.num_labels = 1 67 | self.bert = BertModel(config) 68 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 69 | module_dict = {} 70 | for head_name, num_labels in head2size.items(): 71 | module_dict[head_name] = nn.Linear(config.hidden_size, num_labels) 72 | self.heads = nn.ModuleDict(module_dict) 73 | 74 | self.init_weights() 75 | 76 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, 77 | head2labels=None, return_pooler_output=False, head2mask=None, 78 | nsp_loss_weights=None): 79 | 80 | device = "cuda" if torch.cuda.is_available() else "cpu" 81 | 82 | # Get logits 83 | output = self.bert( 84 | input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, 85 | output_attentions=False, output_hidden_states=False, return_dict=True) 86 | pooled_output = self.dropout(output["pooler_output"]).to(device) 87 | 88 | head2logits = {} 89 | return_dict = {} 90 | for head_name, head in self.heads.items(): 91 | head2logits[head_name] = self.heads[head_name](pooled_output) 92 | head2logits[head_name] = head2logits[head_name].float() 93 | return_dict[head_name + "_logits"] = head2logits[head_name] 94 | 95 | 96 | if head2labels is not None: 97 | for head_name, labels in head2labels.items(): 98 | num_classes = head2logits[head_name].shape[1] 99 | 100 | # Regression (e.g. 
for politeness) 101 | if num_classes == 1: 102 | 103 | # Only consider positive examples 104 | if head2mask is not None and head_name in head2mask: 105 | num_positives = head2labels[head2mask[head_name]].sum() # use certain labels as mask 106 | if num_positives == 0: 107 | return_dict[head_name + "_loss"] = torch.tensor([0]).to(device) 108 | else: 109 | loss_fct = MSELoss(reduction='none') 110 | loss = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1)) 111 | return_dict[head_name + "_loss"] = loss.dot(head2labels[head2mask[head_name]].float().view(-1)) / num_positives 112 | else: 113 | loss_fct = MSELoss() 114 | return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1)) 115 | else: 116 | loss_fct = CrossEntropyLoss(weight=nsp_loss_weights.float()) 117 | return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name], labels.view(-1)) 118 | 119 | 120 | if return_pooler_output: 121 | return_dict["pooler_output"] = output["pooler_output"] 122 | 123 | return return_dict 124 | 125 | class InputBuilder(object): 126 | """Base class for building inputs from segments.""" 127 | 128 | def __init__(self, tokenizer): 129 | self.tokenizer = tokenizer 130 | self.mask = [tokenizer.mask_token_id] 131 | 132 | def build_inputs(self, history, reply, max_length): 133 | raise NotImplementedError 134 | 135 | def mask_seq(self, sequence, seq_id): 136 | sequence[seq_id] = self.mask 137 | return sequence 138 | 139 | @classmethod 140 | def _combine_sequence(self, history, reply, max_length, flipped=False): 141 | # Trim all inputs to max_length 142 | history = [s[:max_length] for s in history] 143 | reply = reply[:max_length] 144 | if flipped: 145 | return [reply] + history 146 | return history + [reply] 147 | 148 | 149 | class BertInputBuilder(InputBuilder): 150 | """Processor for BERT inputs""" 151 | 152 | def __init__(self, tokenizer): 153 | InputBuilder.__init__(self, tokenizer) 154 | self.cls = [tokenizer.cls_token_id] 155 | self.sep = [tokenizer.sep_token_id] 156 | self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"] 157 | self.padded_inputs = ["input_ids", "token_type_ids"] 158 | self.flipped = False 159 | 160 | 161 | def build_inputs(self, history, reply, max_length, input_str=True): 162 | """See base class.""" 163 | if input_str: 164 | history = [self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(t)) for t in history] 165 | reply = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(reply)) 166 | sequence = self._combine_sequence(history, reply, max_length, self.flipped) 167 | sequence = [s + self.sep for s in sequence] 168 | sequence[0] = self.cls + sequence[0] 169 | 170 | instance = {} 171 | instance["input_ids"] = list(chain(*sequence)) 172 | last_speaker = 0 173 | other_speaker = 1 174 | seq_length = len(sequence) 175 | instance["token_type_ids"] = [last_speaker if ((seq_length - i) % 2 == 1) else other_speaker 176 | for i, s in enumerate(sequence) for _ in s] 177 | return instance --------------------------------------------------------------------------------
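As a quick orientation to how `run_inference.py` and `utils.py` fit together, here is a minimal sketch that scores a single utterance pair directly from Python. It assumes `utils.py` from this repository is importable and that the Hugging Face checkpoint `stanford-nlpxed/uptake-model` linked in the README can be loaded with `MultiHeadModel.from_pretrained`; the example utterances are illustrative, and the supported batch workflow remains `run_inference.py`.

```python
# Hedged sketch (not part of the original repository): score one utterance pair.
# Assumes utils.py is on the Python path and that the "stanford-nlpxed/uptake-model"
# checkpoint is compatible with MultiHeadModel.from_pretrained.
import torch
from scipy.special import softmax
from transformers import BertTokenizer

from utils import MultiHeadModel, BertInputBuilder, clean_str

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
input_builder = BertInputBuilder(tokenizer=tokenizer)
model = MultiHeadModel.from_pretrained("stanford-nlpxed/uptake-model", head2size={"nsp": 2})
model.to(device)
model.eval()

student = "I think the answer is twelve because three groups of four make twelve."  # speaker A
teacher = "Three groups of four, can you show us how you counted them?"             # speaker B

# Same preprocessing and input construction as get_uptake_score() in run_inference.py.
instance = input_builder.build_inputs([clean_str(student)], clean_str(teacher),
                                      max_length=120, input_str=True)
instance["attention_mask"] = [[1] * len(instance["input_ids"])]
batch = {key: torch.tensor(instance[key]).unsqueeze(0).to(device)  # batch size = 1
         for key in ["input_ids", "token_type_ids", "attention_mask"]}

with torch.no_grad():
    output = model(input_ids=batch["input_ids"],
                   attention_mask=batch["attention_mask"],
                   token_type_ids=batch["token_type_ids"],
                   return_pooler_output=False)

# As in run_inference.py: softmax over the nsp head logits; index 1 is the uptake score.
uptake_score = softmax(output["nsp_logits"][0].tolist())[1]
print("uptake score: %.3f" % uptake_score)
```

`run_inference.py` performs the same computation row by row over a CSV, skipping pairs where speaker A's utterance has fewer than `--student_min_words` words.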