├── README.md
├── corpus
│   ├── AMI
│   │   ├── AMI.py
│   │   └── __init__.py
│   ├── Corpus.py
│   ├── Maptask
│   │   ├── Maptask.py
│   │   └── __init__.py
│   ├── Oasis
│   │   ├── Oasis.py
│   │   └── __init__.py
│   ├── Simulated
│   │   ├── Simulated.py
│   │   └── __init__.py
│   ├── Switchboard
│   │   ├── DAMSL.py
│   │   ├── Switchboard.py
│   │   ├── Switchboard.pyc
│   │   └── __init__.py
│   └── __init__.py
├── data
│   ├── README.md
│   └── data.zip
├── eval_svmlight_output.py
├── eval_svmlight_output_insertion.py
├── generate_grid.py
├── generate_joined_dat.py
├── generate_noentities_baseline.py
├── generate_shuffled.py
├── load_grids.py
└── train_models.py

/README.md:
--------------------------------------------------------------------------------
1 | # Coherence-models-for-dialogue
2 | This is the repository for the Interspeech 2018 paper ["Coherence models for dialogue"](https://arxiv.org/pdf/1806.08044.pdf).
3 | If you use our code, please cite our paper:
4 | 
5 | Cervone, A., Stepanov, E.A., & Riccardi, G. (2018). Coherence Models for Dialogue. Interspeech.
6 | 
7 | ## Prerequisites
8 | 
9 | - python 2.7+
10 | - spacy version 1
11 | - tqdm
12 | - mlxtend
13 | 
14 | The code generates input feature vector files for [SVM light](http://svmlight.joachims.org/) version 6.02, so in order to train the models you will also need to install it.
15 | 
16 | 
17 | ## Data preprocessing
18 | 
19 | The data used in the experiments (i.e. only the grid files, since the source corpora are under license) is available in the `data` folder. The scripts we used to generate the data from the source corpora are also available. See the README file in the `data/` folder for further details.
20 | The corpus preprocessing code in the `corpus` folder is a modification of the scripts from [this library](https://github.com/ColingPaper2018/DialogueAct-Tagger).
21 | 
22 | ### Where do I find the source corpora used in the experiments?
23 | - BT Oasis is available [via email request](http://groups.inf.ed.ac.uk/oasis/)
24 | - The Switchboard Dialogue Act Corpus is available [here](https://web.stanford.edu/~jurafsky/swb1_dialogact_annot.tar.gz)
25 | - AMI is available for download from [this page](http://groups.inf.ed.ac.uk/ami/download/)
26 | 
27 | 
28 | ## Getting started
29 | 
30 | ### Generate feature vectors with the provided data
31 | 
32 | After unzipping `data/data.zip` (the data used in our experiments), you can directly generate the feature vectors for SVM light for the Oasis corpus with default parameters:
33 | ```
34 | python train_models.py -g Oasis
35 | ```
36 | You can find the generated feature vector files in the newly created path `experiments/Oasis/reordering/egrid_-coref/`, divided according to the train/dev/test splits.
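Each generated `.dat` file is in the standard SVM light ranking format: one feature vector per line, vectors sharing the same `qid` belong to the same reordering instance, and lines starting with `#` are comments. The snippet below is only an illustrative sketch (the feature indices and values are invented, not taken from the real files); judging from `eval_svmlight_output.py`, within each `qid` group the higher target label marks the original dialogue and the lower one a shuffled permutation:
```
2 qid:1 1:0.1250 4:0.0833 9:0.0417
1 qid:1 1:0.0625 4:0.1250 9:0.0000
2 qid:2 2:0.2000 5:0.0500 7:0.1000
1 qid:2 2:0.0500 5:0.1000 7:0.0000
```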
37 | 
38 | ### Train
39 | 
40 | Then, to train a model using SVM light, you can run:
41 | ```
42 | svm_learn -z p experiments/Oasis/reordering/egrid_-coref/Oasis_sal1_range2_2_train.dat my_model
43 | ```
44 | 
45 | ### Predict
46 | 
47 | To classify the test set using your newly trained model:
48 | ```
49 | svm_classify experiments/Oasis/reordering/egrid_-coref/Oasis_sal1_range2_2_test.dat my_model my_prediction
50 | ```
51 | 
52 | ### Evaluate
53 | 
54 | Finally, to get the accuracy and the other metrics reported in our paper for the newly trained model on the test set:
55 | ```
56 | python eval_svmlight_output.py --testfile experiments/Oasis/reordering/egrid_-coref/Oasis_sal1_range2_2_test.dat --predfile my_prediction
57 | ```
58 | If you don't specify the file with the model's predictions (```--predfile```), the script reports the performance of a random baseline on the test set using the same metrics:
59 | ```
60 | python eval_svmlight_output.py --testfile experiments/Oasis/reordering/egrid_-coref/Oasis_sal1_range2_2_test.dat
61 | ```
62 | ### Data generation from the corpus
63 | 
64 | To generate the grids for training entity grid models using only entities (without coreference) for the Oasis corpus (provided you set the correct path to the Oasis source files), run in verbose mode:
65 | ```
66 | python generate_grid.py Oasis egrid_-coref egrid_-coref data/ -v
67 | ```
68 | After generating the original grids, generate the shuffled grids (for the reordering task) for the same corpus:
69 | ```
70 | python generate_shuffled.py -gs Oasis
71 | ```
72 | 
73 | 
74 | ## Link to our experiment files
75 | You can get all our experiment files (obtained with the above procedure) [here](https://www.dropbox.com/s/1sewemx965o2jec/experiments.zip?dl=0)
76 | -------------------------------------------------------------------------------- /corpus/AMI/AMI.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import unicode_literals 3 | import os 4 | from collections import OrderedDict 5 | from corpus.Corpus import Corpus 6 | """ 7 | AMI class: loads the corpus into tuples (sentence,DA,prevDA).
Provides methods 8 | to dump the corpus in CSV format with original annotation and with ISO annotation 9 | """ 10 | 11 | 12 | class AMI(Corpus): 13 | def __init__(self, ami_folder): 14 | # check whether the ami_folder contains a valid AMI installation 15 | try: 16 | assert os.path.exists(ami_folder) # folder exists 17 | assert os.path.exists(ami_folder+"/words/ES2002a.A.words.xml") # words files exist 18 | assert os.path.exists(ami_folder+"/dialogueActs/ES2002a.A.dialog-act.xml") # DA files exist 19 | except AssertionError: 20 | print("The folder "+ami_folder+" does not contain some important files from the corpus.") 21 | print("You can download a complete version of the corpus at http://groups.inf.ed.ac.uk/ami/download/") 22 | exit(1) 23 | self.ami_folder = ami_folder 24 | self.csv_corpus = [] 25 | 26 | def load_csv(self): 27 | dialogs={} # this will store dialogs from the corpus 28 | dialog_names=[] # this will store filenames from the corpus 29 | for dialog_name in os.listdir(self.ami_folder+"/dialogueActs/"): 30 | if "dialog-act" in dialog_name: # DA file 31 | dialog_names.append(dialog_name.split("dialog-act")[0]) 32 | for dialog_name in dialog_names: 33 | dialogs[dialog_name] = OrderedDict() 34 | self.load_words(dialogs, dialog_name) 35 | self.load_dialog_acts(dialogs, dialog_name) 36 | self.load_segments(dialogs, dialog_name) 37 | self.csv_corpus = self.create_csv(dialogs) 38 | return self.csv_corpus 39 | 40 | def load_words(self,dialogs,dialog_name): 41 | with open(self.ami_folder+"/words/" + dialog_name + "words.xml") as wfile: 42 | for line in wfile: 43 | if not "")[1].split("<")[0] 49 | dialogs[dialog_name][word_id] = [] 50 | dialogs[dialog_name][word_id].append(word_value) 51 | 52 | def load_dialog_acts(self,dialogs,dialog_name): 53 | with open(self.ami_folder+"/dialogueActs/" + dialog_name + "dialog-act.xml") as actfile: 54 | dact = "" 55 | for line in actfile: 56 | if "")[1].split("<")[0] 93 | dialogs[move_id] = {"move": utt_id, "text": value, "speaker": speaker} 94 | return dialogs 95 | 96 | def create_csv(self, dialogs): 97 | csv_corpus=[] 98 | conversations = [] 99 | current_move = 0 100 | current_conv = {} 101 | for sent, da, move in dialogs: 102 | if move < current_move and move % 2 == 0: 103 | conversations.append(current_conv) 104 | current_conv = {} 105 | current_move = move 106 | current_conv[move] = (sent, da) 107 | for c in conversations: 108 | segment=0 109 | prevDA = "unclassifiable" 110 | for k in sorted(c.keys()): 111 | csv_corpus.append(tuple(list(c[k])+[prevDA, segment, None, None])) 112 | prevDA = c[k][1] 113 | segment+=1 114 | return csv_corpus 115 | 116 | @staticmethod 117 | def da_to_dimension(corpus_tuple): 118 | da=corpus_tuple[1] 119 | if da == "acknowledge": 120 | return "Feedback" 121 | elif da == "unclassifiable": 122 | return None 123 | else: # everything else is a task. 
that's why they call it maptask :) 124 | return "Task" 125 | 126 | @staticmethod 127 | def da_to_cf(corpus_tuple): 128 | da=corpus_tuple[1] 129 | if da == "acknowledge": 130 | return "Feedback" 131 | elif da in ["explain", "clarify", "reply-y", "reply-n", "reply-w"]: 132 | return "Statement" 133 | elif da in ["query_yn"]: 134 | return "PropQ" 135 | elif da in ["query_w"]: 136 | return "SetQ" 137 | elif da in ["instruct"]: 138 | return "Directive" 139 | else: 140 | return None -------------------------------------------------------------------------------- /corpus/Maptask/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecervi/Coherence-models-for-dialogue/f266f5f79755b268a0656b79610a0161a08ec446/corpus/Maptask/__init__.py -------------------------------------------------------------------------------- /corpus/Oasis/Oasis.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import unicode_literals 3 | import os 4 | import csv 5 | from lxml import etree 6 | from corpus.Corpus import Corpus 7 | """ 8 | Oasis class: loads the corpus into tuples (sentence,DA,prevDA). Provides methods 9 | to dump the corpus in CSV format with original annotation and with ISO annotation 10 | """ 11 | 12 | 13 | class Oasis(Corpus): 14 | def __init__(self, oasis_folder): 15 | # check whether the oasis_folder contains a valid Oasis installation 16 | try: 17 | assert os.path.exists(oasis_folder) # folder exists 18 | assert os.path.exists(oasis_folder + "/Data/Lancs_BT150") # dialogs folders exist 19 | assert os.path.exists(oasis_folder + "/Data/Lancs_BT150/075812009.a.lturn.xml") # DA files exist 20 | except AssertionError: 21 | print("The folder " + oasis_folder + " does not contain some important files from the corpus.") 22 | print("Check http://groups.inf.ed.ac.uk/oasis/ for info on how to obtain the complete SWDA corpus.") 23 | exit(1) 24 | self.oasis_folder = oasis_folder 25 | self.csv_corpus = [] 26 | 27 | def load_csv(self): 28 | # Read dialogue files from Oasis 29 | dialogs=self.create_dialogs() 30 | self.csv_corpus = self.create_csv(dialogs) 31 | return self.csv_corpus 32 | 33 | def create_dialogs(self): 34 | dialogs = {} 35 | for fname in os.listdir(self.oasis_folder+"/Data/Lancs_BT150/"): 36 | f = open(self.oasis_folder+"/Data/Lancs_BT150/" + fname.strip()) 37 | t = etree.parse(f) 38 | turns = t.xpath("//lturn") 39 | for turn in turns: 40 | self.parse_xml_turn(dialogs, turn) 41 | return dialogs 42 | 43 | def parse_xml_turn(self, dialogs, turn): 44 | dialog_id = turn.attrib["id"].split(".")[0] 45 | try: # subturn 46 | turn_id = int(turn.attrib["id"].split(".")[-2]) 47 | except: # turn 48 | turn_id = int(turn.attrib["id"].split(".")[-1]) 49 | if dialogs.get(dialog_id, None) is None: # new dialog 50 | dialogs[dialog_id] = {} 51 | if dialogs[dialog_id].get(turn_id, None) is None: # new turn 52 | dialogs[dialog_id][turn_id] = [] 53 | segments = turn.xpath(".//segment") 54 | for segment in segments: 55 | self.add_segment_to_dialog(dialogs, dialog_id, turn_id, segment) 56 | 57 | def add_segment_to_dialog(self,dialogs, dialog_id, turn_id, segment): 58 | segm_type = segment.attrib["type"] 59 | tag = segment.attrib["sp-act"] 60 | try: 61 | wFile = segment[0].attrib["href"].split("#")[0] 62 | except: 63 | return 64 | ids = segment[0].attrib["href"].split("#")[1] 65 | start_id = ids.split("(")[1].split(")")[0] 66 | stop_id = 
ids.split("(")[-1][:-1] 67 | start_n = int(start_id.split(".")[3]) 68 | text = wFile.split(".xml")[0] 69 | if not 'anchor' in stop_id: 70 | stop_n = int(stop_id.split(".")[3]) 71 | else: 72 | stop_n = start_n 73 | id_set = ["@id = '" + text + "." + str(i) + "'" for i in range(start_n, stop_n + 1)] 74 | with open(self.oasis_folder+"/Data/Lancs_BT150/" + wFile) as f: 75 | tree = etree.parse(f) 76 | segment = tree.xpath('//*[' + " or ".join(id_set) + ']') 77 | sentence = " ".join([x.text for x in segment if 78 | x.text is not None and x.text not in ["?", ",", ".", "!", ";"]]) 79 | if sentence != "": 80 | dialogs[dialog_id][turn_id].append((sentence, tag, segm_type)) 81 | 82 | def create_csv(self, dialogs): 83 | ''' 84 | output csv: 85 | {filename : [(DA, utt, speaker, turn number)]} 86 | ''' 87 | # print('Dialogues len: ', len(dialogs)) 88 | # print('Dialogues type: ', type(dialogs)) 89 | # print('Dialogues keys: ', dialogs.keys()[0]) 90 | # print('Dialogues val ex: ', dialogs[dialogs.keys()[0]]) 91 | # print('Dialogues keys: ', dialogs.keys()[1]) 92 | # print('Dialogues val ex: ', dialogs[dialogs.keys()[1]]) 93 | csv_corpus = {} 94 | for d in dialogs: 95 | csv_dialogue = [] 96 | prevTag = "other" 97 | prevType = "other" 98 | speaker_A = True 99 | turn_number = 1 100 | 101 | for segm in sorted(dialogs[d].keys()): 102 | speaker = 'A' if speaker_A==True else 'B' 103 | 104 | for sentence in dialogs[d][segm]: 105 | 106 | # csv_corpus.append((sentence[0], sentence[1], prevTag, segm, sentence[2], prevType)) 107 | # csv_dialogue.append((sentence[1], unicode(sentence[0], "utf-8"), speaker, turn_number)) # Python 2.7 108 | csv_dialogue.append((sentence[1], sentence[0], speaker, turn_number)) 109 | 110 | 111 | # Avoid indexes for empty turns 112 | if dialogs[d][segm]: 113 | turn_number += 1 114 | 115 | # Change speaker 116 | if speaker_A is True: 117 | speaker_A = False 118 | else: 119 | speaker_A = True 120 | 121 | try: 122 | prevTag = dialogs[d][segm][-1][1] 123 | prevType = dialogs[d][segm][-1][2] 124 | except: # no prev in this segment 125 | pass 126 | 127 | csv_corpus[d] = csv_dialogue 128 | return csv_corpus 129 | 130 | @staticmethod 131 | def da_to_dimension(corpus_tuple): 132 | da=corpus_tuple[1] 133 | da_type=corpus_tuple[4] 134 | if da in ["suggest","inform","offer"] or da_type in ["q_wh","q_yn","imp"]: 135 | return "Task" 136 | elif da in ["thank","bye","greet","pardon","regret"]: 137 | return "SocialObligationManagement" 138 | elif da =="ackn" or da_type=="backchannel": 139 | return "Feedback" 140 | else: 141 | return None 142 | 143 | @staticmethod 144 | def da_to_cf(corpus_tuple): 145 | da=corpus_tuple[1] 146 | da_type=corpus_tuple[4] 147 | if da_type == "q_wh": 148 | return "SetQ" 149 | elif da_type == "q_yn": 150 | return "CheckQ" 151 | elif da_type=="imp" or da=="suggest": 152 | return "Directive" 153 | elif da=="inform": 154 | return "Statement" 155 | elif da=="offer": 156 | return "Commissive" 157 | elif da=="thank": 158 | return "Thanking" 159 | elif da in ["bye","greet"]: 160 | return "Salutation" 161 | elif da in ["pardon","regret"]: 162 | return "Apology" 163 | elif da=="ackn" or da_type=="backchannel": 164 | return "Feedback" 165 | else: 166 | return None -------------------------------------------------------------------------------- /corpus/Oasis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecervi/Coherence-models-for-dialogue/f266f5f79755b268a0656b79610a0161a08ec446/corpus/Oasis/__init__.py 
-------------------------------------------------------------------------------- /corpus/Simulated/Simulated.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | from corpus.Corpus import Corpus 4 | """ 5 | AMI class: loads the corpus into tuples (sentence,DA,prevDA). Provides methods 6 | to dump the corpus in CSV format with original annotation and with ISO annotation 7 | """ 8 | 9 | 10 | class Simulated(Corpus): 11 | def __init__(self, ami_folder): 12 | # check whether the ami_folder contains a valid AMI installation 13 | try: 14 | assert os.path.exists(ami_folder) # folder exists 15 | assert os.path.exists(ami_folder+"/words/ES2002a.A.words.xml") # words files exist 16 | assert os.path.exists(ami_folder+"/dialogueActs/ES2002a.A.dialog-act.xml") # DA files exist 17 | except AssertionError: 18 | print("The folder "+ami_folder+" does not contain some important files from the corpus.") 19 | print("You can download a complete version of the corpus at http://groups.inf.ed.ac.uk/ami/download/") 20 | exit(1) 21 | self.ami_folder = ami_folder 22 | self.csv_corpus = [] 23 | 24 | def load_csv(self): 25 | dialogs={} # this will store dialogs from the corpus 26 | dialog_names=[] # this will store filenames from the corpus 27 | for dialog_name in os.listdir(self.ami_folder+"/dialogueActs/"): 28 | if "dialog-act" in dialog_name: # DA file 29 | dialog_names.append(dialog_name.split("dialog-act")[0]) 30 | for dialog_name in dialog_names: 31 | dialogs[dialog_name] = OrderedDict() 32 | self.load_words(dialogs, dialog_name) 33 | self.load_dialog_acts(dialogs, dialog_name) 34 | self.load_segments(dialogs, dialog_name) 35 | self.csv_corpus = self.create_csv(dialogs) 36 | return self.csv_corpus 37 | 38 | def load_words(self,dialogs,dialog_name): 39 | with open(self.ami_folder+"/words/" + dialog_name + "words.xml") as wfile: 40 | for line in wfile: 41 | if not "")[1].split("<")[0] 47 | dialogs[dialog_name][word_id] = [] 48 | dialogs[dialog_name][word_id].append(word_value) 49 | 50 | def load_dialog_acts(self,dialogs,dialog_name): 51 | with open(self.ami_folder+"/dialogueActs/" + dialog_name + "dialog-act.xml") as actfile: 52 | dact = "" 53 | for line in actfile: 54 | if "", 44 | "slashes": r"\-{2}|[^a-z]\-{1}[^a-z]|\w+\-{1}[^a-z]", 45 | "hash": r"#", 46 | "typo": r"\*typo" 47 | } 48 | result_regex = "" 49 | if all is True: 50 | result_regex = "|".join(regex for regex in regexes.values()) 51 | else: 52 | if any(regex_type not in regexes.keys() for regex_type in keys): 53 | raise TypeError("Invalid regex type requested") 54 | result_regex = "|".join(regexes.get(regex_type) for regex_type in keys) 55 | return result_regex 56 | 57 | @staticmethod 58 | def write_csv(to_write, headers, filename="generated_dataset", outpath=""): 59 | with open(outpath+filename+'.csv','wb') as outfile: 60 | writer = csv.writer(outfile, delimiter=',') 61 | writer.writerow(headers) 62 | for line in to_write: 63 | writer.writerow(line) 64 | print("Written output csv file: ", filename) 65 | 66 | def create_filelist(self): 67 | filelist=[] 68 | for folder in os.listdir(self.corpus_folder): 69 | if folder.startswith("sw"): # dialog folder 70 | for filename in os.listdir(self.corpus_folder+"/"+folder): 71 | if filename.startswith("sw"): # dialog file 72 | filelist.append(self.corpus_folder+"/"+folder+"/"+filename) 73 | return filelist 74 | 75 | def create_csv(self,filelist): 76 | csv_corpus = [] 77 | for filename in filelist: 78 | prev_speaker = None 79 
| segment = 0 80 | prev_DAs = {"A":"%", "B":"%"} 81 | with open(filename) as f: 82 | utterances = f.readlines() 83 | for line in utterances: 84 | line = line.strip() 85 | try: 86 | sentence = line.split("utt")[1].split(":")[1] 87 | sw_tag = line.split("utt")[0].split()[0] 88 | if "A" in line.split("utt")[0]: # A speaking 89 | speaker = "A" 90 | else: 91 | speaker = "B" 92 | except: # not an SWDA utterance format: probably a header line 93 | continue 94 | if speaker != prev_speaker: 95 | prev_speaker = speaker 96 | segment += 1 97 | sentence = re.sub(r"([+/\}\[\]]|\{\w)", "", 98 | sentence) # this REGEX removes prosodic information and disfluencies 99 | DA_tag = DAMSL.sw_to_damsl(sw_tag, prev_DAs[speaker]) 100 | csv_corpus.append((sentence, DA_tag, prev_DAs[speaker], segment, None, None)) 101 | prev_DAs[speaker] = DA_tag 102 | return csv_corpus 103 | 104 | def create_dialogue_csv(self, filelist): 105 | ''' 106 | output csv: 107 | {filename : [(DA, utt, speaker, turn number)]} 108 | ''' 109 | csv_corpus = {} 110 | # filelist = filelist[:5] 111 | # filelist = ['../../Datasets/Switchboard/data/switchboard1-release2//sw00utt/sw_0004_4327.utt'] 112 | 113 | for filename in filelist: 114 | csv_dialogue = [] 115 | prev_speaker=None 116 | segment=0 117 | prev_DAs={"A":"%","B":"%"} 118 | with open(filename) as f: 119 | utterances = f.readlines() 120 | for line in utterances: 121 | line = line.strip() 122 | try: 123 | sentence = line.split("utt")[1].split(":")[1] 124 | sw_tag = line.split("utt")[0].split()[0] 125 | if "A" in line.split("utt")[0]: # A speaking 126 | speaker="A" 127 | else: 128 | speaker="B" 129 | except: # not an SWDA utterance format: probably a header line 130 | continue 131 | if speaker != prev_speaker: 132 | prev_speaker = speaker 133 | segment += 1 134 | sentence = re.sub(self.get_regex(all=True), "", sentence) 135 | sentence = re.sub(r"\'re[^\w]", " are", sentence) # preprocess "lawyers're" into "lawyers are" 136 | DA_tag = DAMSL.sw_to_damsl(sw_tag,prev_DAs[speaker]) 137 | csv_dialogue.append((DA_tag, sentence, speaker, segment)) 138 | csv_corpus[filename.split("/")[-1]] = csv_dialogue 139 | 140 | return csv_corpus 141 | 142 | 143 | 144 | 145 | @staticmethod 146 | def da_to_dimension(corpus_tuple): 147 | da=corpus_tuple[1] 148 | if da in ["statement-non-opinion","statement-opinion","rhetorical-questions","hedge","or-clause", 149 | "wh-question", "declarative-wh-question","backchannel-in-question-form","yes-no-question", 150 | "declarative-yn-question","tag-question","offers-options-commits","action-directive"]: 151 | return "Task" 152 | elif da in ["thanking","apology","downplayer","conventional-closing"]: 153 | return "SocialObligationManagement" 154 | elif da in ["signal-non-understanding", "acknowledge", "appreciation"]: 155 | return "Feedback" 156 | else: 157 | return None 158 | 159 | @staticmethod 160 | def da_to_cf(corpus_tuple): 161 | da=corpus_tuple[1] 162 | if da in ["statement-non-opinion","statement-opinion","rhetorical-questions","hedge"]: 163 | return "Statement" 164 | elif da == "or-clause": 165 | return "ChoiceQ" 166 | elif da in ["wh-question","declarative-wh-question"]: 167 | return "SetQ" 168 | elif da in ["backchannel-in-question-form","yes-no-question","declarative-yn-question","tag-question"]: 169 | return "PropQ" 170 | elif da == "offers-options-commits": 171 | return "Commissive" 172 | elif da == "action-directive": 173 | return "Directive" 174 | elif da in ["thanking"]: 175 | return "Thanking" 176 | elif da in ["apology","downplayer"]: 177 | return 
"Apology" 178 | elif da in "conventional-closing": 179 | return "Salutation" 180 | elif da in ["signal-non-understanding","acknowledge","appreciation"]: 181 | return "Feedback" 182 | else: 183 | return None -------------------------------------------------------------------------------- /corpus/Switchboard/Switchboard.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecervi/Coherence-models-for-dialogue/f266f5f79755b268a0656b79610a0161a08ec446/corpus/Switchboard/Switchboard.pyc -------------------------------------------------------------------------------- /corpus/Switchboard/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecervi/Coherence-models-for-dialogue/f266f5f79755b268a0656b79610a0161a08ec446/corpus/Switchboard/__init__.py -------------------------------------------------------------------------------- /corpus/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecervi/Coherence-models-for-dialogue/f266f5f79755b268a0656b79610a0161a08ec446/corpus/__init__.py -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Data 3 | 4 | The zipped folder data.zip contains the grids generated for our experiments using the script generate_grid.py . 5 | The data is divided according to the corpus (e.g. Switchboard, AMI etc.). 6 | In the folder for each corpus there is: 7 | - one subfolder for each configuration used to generate the grids (e.g. egrid_-coref etc.), containing all grids in .csv format 8 | - one subfolder named `shuffled`, containing the original shuffles (using turn indexes) for all files used in our experiments in .csv format 9 | - the file `Train_Validation_Test_split.csv` containing the splits used in our experiments for that corpus 10 | 11 | -------------------------------------------------------------------------------- /data/data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecervi/Coherence-models-for-dialogue/f266f5f79755b268a0656b79610a0161a08ec446/data/data.zip -------------------------------------------------------------------------------- /eval_svmlight_output.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import logging, sys 4 | from collections import defaultdict 5 | import numpy as np 6 | from optparse import OptionParser 7 | from itertools import combinations 8 | from random import random 9 | # from statsmodels.sandbox.stats.runs import mcnemar 10 | from sklearn.metrics import confusion_matrix 11 | from mlxtend.evaluate import mcnemar_table 12 | from statsmodels.stats.contingency_tables import mcnemar 13 | from mlxtend.evaluate import mcnemar_table 14 | from mlxtend.evaluate import mcnemar as mlx_mcnemar 15 | from scipy.stats import mannwhitneyu, wilcoxon 16 | # from sklearn.datasets import load_svmlight_file 17 | from sklearn.metrics import accuracy_score, classification_report, average_precision_score 18 | 19 | # parse commandline arguments 20 | op = OptionParser() 21 | op.add_option("--testfile", 22 | action="store", type=str, dest="testfile", 23 | help="Testfile.") 24 | op.add_option("--testfile2", 25 | action="store", type=str, 
dest="testfile2", 26 | help="Testfile.") 27 | op.add_option("--predfile", 28 | action="store", type=str, dest="predfile", 29 | help="Prediction file.") 30 | op.add_option("--predfile2", 31 | action="store", type=str, dest="predfile2", 32 | help="Prediction file to compare with.") 33 | op.add_option("--statsign", 34 | action="store_true", dest="statsign", 35 | help="Prediction file to compare with.") 36 | 37 | 38 | def map_score(lists): 39 | average_precs = [] 40 | for _, candidates in lists.items(): 41 | score, label = zip(*candidates) 42 | label = list(map(lambda x: int(x)-1, label)) 43 | average_precs.append(average_precision_score(label, score)) 44 | return sum(average_precs) / len(average_precs) 45 | 46 | def mrr_score(lists): 47 | recp_ranks = [] 48 | for _, candidates in lists.items(): 49 | rank = 0 50 | for i, (_, label) in enumerate(sorted(candidates, reverse=True, key=lambda x: x[0]), 1): 51 | if label == 2: 52 | rank += 1. / i 53 | break 54 | recp_ranks.append(rank) 55 | return sum(recp_ranks) / len(recp_ranks) 56 | 57 | def prec_at(lists, n): 58 | precs = [] 59 | for _, candidates in lists.items(): 60 | for i, (_, label) in enumerate(sorted(candidates, reverse=True, key=lambda x: x[0]), 1): 61 | if i > n: 62 | precs.append(0.) 63 | break 64 | elif label == 2: 65 | precs.append(1.) 66 | break 67 | return sum(precs) / len(precs), precs 68 | 69 | 70 | def read_test_file(path): 71 | with open(path, 'r') as infile: 72 | print('Reading: ', path) 73 | for line in infile: 74 | if line[0] is not '#': 75 | yield line.strip().split() 76 | 77 | def evaluate(testfile, predfile = None): 78 | print('Predfile: ', predfile) 79 | print('Testfile: ', testfile) 80 | queries = defaultdict(list) 81 | test_file = list(read_test_file(testfile)) 82 | if predfile: 83 | with open(predfile, 'r') as pred: 84 | for prd, doc in zip(pred, test_file): 85 | 86 | lbl, qid = doc[0], doc[1] 87 | # print('Lab: ', lbl, ' Qid: ', qid, ' Pred: ', prd) 88 | queries[qid].append((float(prd.strip()), int(lbl))) 89 | else: 90 | for doc in test_file: 91 | 92 | lbl, qid = doc[0], doc[1] 93 | # print('Lab: ', lbl, ' Qid: ', qid, ' Pred: ', prd) 94 | queries[qid].append((random(), int(lbl))) 95 | 96 | # print('Testfile: ', len([i for i in read_test_file(testfile)])) 97 | # print('Predfile: ', len([i for i in read_test_file(predfile)])) 98 | 99 | y_pred = list() 100 | y_true = list() 101 | for qid in queries: 102 | pairs_numb = [i for i in combinations(queries[qid], 2) if i[0][1]!=i[1][1]] 103 | # print('Pairs numb: ', len(pairs_numb), ' Qid ',qid) 104 | for pair in pairs_numb: 105 | (pred_1, true_1), (pred_2, true_2) = pair 106 | y_pred.append(int(pred_1 <= pred_2)) 107 | y_true.append(int(true_1 <= true_2)) 108 | 109 | 110 | print("Accuracy: {:.4f}".format(accuracy_score(y_true, y_pred))) 111 | print(classification_report(y_true, y_pred)) 112 | print("\n Rank Metrics \n ") 113 | print("MAP: {:.4f}".format(map_score(queries))) 114 | print("MRR: {:.4f}".format(mrr_score(queries))) 115 | 116 | print("\n Precisions\n ") 117 | 118 | print("PREC@{}: {:.4f}".format(1, prec_at(queries, 1)[0])) 119 | print("PREC@{}: {:.4f}".format(2, prec_at(queries, 2)[0])) 120 | print("PREC@{}: {:.4f}".format(3, prec_at(queries, 3)[0])) 121 | print("PREC@{}: {:.4f}".format(5, prec_at(queries, 5)[0])) 122 | print("PREC@{}: {:.4f}".format(10, prec_at(queries, 10)[0])) 123 | y_pred_prec = prec_at(queries, 1)[1] 124 | y_true_prec = np.ones(len(y_pred_prec)) 125 | 126 | return y_true, y_pred, y_true_prec, y_pred_prec 127 | 128 | 129 | def 
test_mannwhithney(predfile1, predfile2, testfile, testfile2): 130 | y_true1, y_pred1, y_true_prec1, y_pred_prec1 = evaluate(testfile, predfile1) 131 | y_true2, y_pred2, y_true_prec2, y_pred_prec2 = evaluate(testfile2, predfile2) 132 | print('\n First model: ', predfile1) 133 | print('Ex: ', y_pred1[:10], ' Len: ', len(y_pred1)) 134 | print('Second model: ', predfile2) 135 | print('Ex: ', y_pred2[:10], ' Len: ', len(y_pred2)) 136 | print('Is testset the same? ', len([i for i in np.equal(np.array(y_true1), np.array(y_true2)) if i is False])) 137 | 138 | mc_tb = mcnemar_table(y_target=np.array(y_true1), 139 | y_model1=np.array(y_pred1), 140 | y_model2=np.array(y_pred2)) 141 | print('Contingency table: ', mc_tb) 142 | mcnemar_res = mcnemar(mc_tb) 143 | print('McNemar: p value: {:.20f}'.format(mcnemar_res.pvalue)) 144 | chi2, p = mlx_mcnemar(ary=mc_tb, corrected=True) 145 | print('McNemar: chi:{:.4f} p value: {}'.format(chi2, p)) 146 | mc_tb_prec = mcnemar_table(y_target=np.array(y_true_prec1), 147 | y_model1=np.array(y_pred_prec1), 148 | y_model2=np.array(y_pred_prec2)) 149 | mcnemar_res_prec = mcnemar(mc_tb_prec) 150 | print('McNemar PRECISION: p value: {}'.format(mcnemar_res_prec.pvalue)) 151 | # mw_stat, mw_p_val = mannwhitneyu(np.array(y_pred1), np.array(y_pred2), alternative='less') 152 | # print('Mann Whitney: Stats: ', mw_stat, ' p value: ', mw_p_val) 153 | # wil_stat, wil_p_val = wilcoxon(np.array(y_pred1), np.array(y_pred2)) 154 | # print('Wilcoxon: Stats: ', wil_stat, ' p value: ', wil_p_val) 155 | 156 | def is_interactive(): 157 | return not hasattr(sys.modules['__main__'], '__file__') 158 | 159 | argv = [] if is_interactive() else sys.argv[1:] 160 | (opts, args) = op.parse_args(argv) 161 | 162 | 163 | def main(): 164 | if not opts.statsign: 165 | print('Evaluating...') 166 | evaluate(opts.testfile, opts.predfile) 167 | else: 168 | print('Performing Mann Whitney test ...') 169 | test_mannwhithney(opts.predfile, opts.predfile2, opts.testfile, opts.testfile2) 170 | 171 | if __name__ == '__main__': 172 | main() 173 | -------------------------------------------------------------------------------- /eval_svmlight_output_insertion.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import logging, sys 4 | from collections import defaultdict 5 | import numpy as np 6 | from optparse import OptionParser 7 | from itertools import combinations 8 | from random import random 9 | # from sklearn.datasets import load_svmlight_file 10 | from statsmodels.sandbox.stats.runs import mcnemar 11 | from sklearn.metrics import accuracy_score, classification_report, average_precision_score 12 | 13 | # parse commandline arguments 14 | op = OptionParser() 15 | op.add_option("--testfile", 16 | action="store", type=str, dest="testfile", 17 | help="Testfile.") 18 | op.add_option("--predfile", 19 | action="store", type=str, dest="predfile", 20 | help="Prediction file.") 21 | op.add_option("--predfile2", 22 | action="store", type=str, dest="predfile2", 23 | help="Prediction file to compare with.") 24 | op.add_option("--statsign", 25 | action="store_true", dest="statsign", 26 | help="Prediction file to compare with.") 27 | 28 | 29 | def map_score(lists): 30 | average_precs = [] 31 | for _, candidates in lists.items(): 32 | score, label = zip(*candidates) 33 | label = map(lambda x: int(x)-1, label) 34 | average_precs.append(average_precision_score(label, score)) 35 | return sum(average_precs) / len(average_precs) 36 | 37 | def 
mrr_score(lists): 38 | recp_ranks = [] 39 | for _, candidates in lists.items(): 40 | rank = 0 41 | for i, (_, label) in enumerate(sorted(candidates, reverse=True, key=lambda x: x[0]), 1): 42 | if label == 2: 43 | rank += 1. / i 44 | break 45 | recp_ranks.append(rank) 46 | return sum(recp_ranks) / len(recp_ranks) 47 | 48 | def prec_at(lists, n): 49 | precs = [] 50 | for _, candidates in lists.items(): 51 | for i, (_, label) in enumerate(sorted(candidates, reverse=True, key=lambda x: x[0]), 1): 52 | if i > n: 53 | precs.append(0.) 54 | break 55 | elif label == 2: 56 | precs.append(1.) 57 | break 58 | return sum(precs) / len(precs) 59 | 60 | 61 | def average_score(func, lists, *args): 62 | scores = [] 63 | for _, candidates in lists.items(): 64 | scores.append(func(candidates, *args)) 65 | return sum(scores) / len(scores) 66 | 67 | 68 | def read_test_file(path): 69 | query_id = None 70 | with open(path, 'r') as infile: 71 | for line in infile: 72 | if line[0] is not '#': 73 | yield query_id, line.strip().split() 74 | else: 75 | query_id = line.strip().split()[2] 76 | 77 | def evaluate(testfile, predfile = None): 78 | queries = {} 79 | test_file = list(read_test_file(testfile)) 80 | if predfile: 81 | with open(predfile, 'r') as pred: 82 | for prd, (doc_id, doc) in zip(pred, test_file): 83 | 84 | lbl, qid = doc[0], doc[1] 85 | # print('Lab: ', lbl, ' Qid: ', qid, ' Pred: ', prd) 86 | queries[doc_id] = queries.get(doc_id, {}) 87 | queries[doc_id][qid] = queries[doc_id].get(qid, list()) 88 | queries[doc_id][qid].append((float(prd.strip()), int(lbl))) 89 | else: 90 | for (doc_id, doc) in test_file: 91 | 92 | lbl, qid = doc[0], doc[1] 93 | # print('Lab: ', lbl, ' Qid: ', qid, ' Pred: ', prd) 94 | queries[doc_id] = queries.get(doc_id, {}) 95 | queries[doc_id][qid] = queries[doc_id].get(qid, list()) 96 | queries[doc_id][qid].append((random(), int(lbl))) 97 | 98 | # print('Testfile: ', len([i for i in read_test_file(testfile)])) 99 | # print('Predfile: ', len([i for i in read_test_file(predfile)])) 100 | 101 | # y_pred = list() 102 | # y_true = list() 103 | # for doc_id in queries: 104 | # for qid in queries[doc_id]: 105 | # pairs_numb = [i for i in combinations(queries[qid], 2) if i[0][1]!=i[1][1]] 106 | # # print('Pairs numb: ', len(pairs_numb), ' Qid ',qid) 107 | # for pair in pairs_numb: 108 | # (pred_1, true_1), (pred_2, true_2) = pair 109 | # y_pred.append(int(pred_1 <= pred_2)) 110 | # y_true.append(int(true_1 <= true_2)) 111 | 112 | # 113 | # print("Accuracy: {:.4f}".format(accuracy_score(y_true, y_pred))) 114 | # print(classification_report(y_true, y_pred)) 115 | print("\n Rank Metrics \n ") 116 | print("Average MAP: {:.4f}".format(average_score(map_score, queries))) 117 | print("Average MRR: {:.4f}".format(average_score(mrr_score, queries))) 118 | 119 | print("\n Precisions\n ") 120 | 121 | print("Average PREC@{}: {:.4f}".format(1, average_score(prec_at, queries, 1))) 122 | print("Average PREC@{}: {:.4f}".format(2, average_score(prec_at, queries, 2))) 123 | print("Average PREC@{}: {:.4f}".format(3, average_score(prec_at, queries, 3))) 124 | print("Average PREC@{}: {:.4f}".format(5, average_score(prec_at, queries, 5))) 125 | print("Average PREC@{}: {:.4f}".format(10, average_score(prec_at, queries, 10))) 126 | 127 | return 128 | 129 | 130 | def test_mcnemar(predfile1, predfile2, testfile): 131 | pass 132 | 133 | 134 | def is_interactive(): 135 | return not hasattr(sys.modules['__main__'], '__file__') 136 | 137 | argv = [] if is_interactive() else sys.argv[1:] 138 | (opts, args) = 
op.parse_args(argv) 139 | 140 | 141 | def main(): 142 | if not opts.statsign: 143 | print('Evaluating...') 144 | evaluate(opts.testfile, opts.predfile) 145 | else: 146 | print('Performing Mc Nemar test ...') 147 | test_mcnemar(opts.predfile, opts.predfile2, opts.testfile) 148 | 149 | if __name__ == '__main__': 150 | main() 151 | -------------------------------------------------------------------------------- /generate_grid.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from future.utils import iteritems 3 | from builtins import dict 4 | from corpus.Switchboard.Switchboard import Switchboard 5 | from corpus.Oasis.Oasis import Oasis 6 | from corpus.Maptask.Maptask import Maptask 7 | from corpus.AMI.AMI import AMI 8 | from collections import OrderedDict 9 | from itertools import groupby 10 | from operator import itemgetter 11 | from random import shuffle 12 | import argparse 13 | import numpy as np 14 | import sys 15 | import inspect 16 | import tqdm 17 | import copy 18 | import timeit 19 | import re 20 | import logging 21 | import spacy 22 | import csv 23 | import os 24 | 25 | ''' 26 | GridGenerator takes as input a dialogue with DAs and converts it into a Grid, 27 | the Grid object is then saved in a text file. 28 | 29 | Input structure: 30 | One DA per line 31 | 32 | Grid structure: 33 | N 1 2 34 | DA1 G - - 35 | DA2 - Q - 36 | 37 | Differences compared to Elsner and Charniak: they use gold parses for Switchboard, 38 | while we give preference to gold DAs (only 53% utts of SWDA have gold parses) 39 | 40 | - Lapata and Barzilay (2005): 41 | Mention detection: entities are NPs, each noun in the NP given same role as head 42 | e.g. Former Chilean dictator Augusto Pinochet mapped to 3 ents (dictator, Augusto, Pinochet) 43 | 44 | - Barzilay and Lapata(2008): 45 | Mention detection: 46 | - COREF: nouns are considered coreferent if identical, NEs/multiword nouns divided and considered independent 47 | e.g. Microsoft Corp. 
mapped to 2 ents (Microsoft, Corp.), role of each element in NP=head's role 48 | + COREF: coreferent NPs mapped to the head 49 | 50 | - Elsner & Charniak (2011), Extending the Entity grid: 51 | Mention detection: all NPs (not only heads), non-head nouns are given the role X 52 | 53 | Files generated with spacy 1 54 | 55 | ''' 56 | 57 | all_datasets_path = '../../Datasets/' 58 | corpora_paths = {'Switchboard': all_datasets_path + 'Switchboard/data/switchboard1-release2/', 59 | 'Oasis': all_datasets_path + 'Oasis', 60 | 'Maptask': all_datasets_path + 'maptaskv2-1', 61 | 'AMI': all_datasets_path + 'ami_public_manual_1.6.2'} 62 | 63 | 64 | def get_corpus(corpus_name): 65 | corpus_path = corpora_paths.get(corpus_name) 66 | if corpus_name=='Switchboard': 67 | corpus_loader = Switchboard(corpus_path) 68 | elif corpus_name=='Oasis': 69 | corpus_loader = Oasis(corpus_path) 70 | elif corpus_name=='Maptask': 71 | corpus_loader = Maptask(corpus_path) 72 | elif corpus_name=='AMI': 73 | corpus_loader = AMI(corpus_path) 74 | corpus_dct = corpus_loader.load_csv() 75 | return corpus_dct, corpus_loader 76 | 77 | 78 | class GridGenerator(object): 79 | 80 | def __init__(self, nlp=None, coref=None): 81 | if nlp is None: 82 | print("Loading spacy model") 83 | try: 84 | spacy.info('en_core_web_sm') 85 | model = 'en_core_web_sm' 86 | except IOError: 87 | print("No spacy 2 model detected, using spacy1 'en' model") 88 | model = 'en' 89 | self.nlp = spacy.load(model) 90 | 91 | if not coref: 92 | try: 93 | from neuralcoref_lib.neuralcoref import Coref 94 | self.coref = Coref(nlp=self.nlp) 95 | except: 96 | self.coref = coref 97 | logging.info('Coreference not available') 98 | 99 | else: 100 | self.coref = coref 101 | 102 | self.spacy_tags = {'pos': {'noun': ['NOUN', 'PROPN'], 103 | 'pronoun': ['PRON']}, 104 | 'dep': {'subject': ['csubj', 'csubjpass', 'nsubj', 'nsubjpass'], #'agent','expl', 105 | 'object': ['dobj', 'iobj', 'oprd']}} # ADD 'attr' for spacy 2 106 | self.conversational_pronouns = ["i", "you"] 107 | self.grammatical_ranking = ['S', 'O', 'X'] 108 | 109 | def extract_coref_from_one_shot(self, utts, speakers): 110 | clusters = self.coref.one_shot_coref(utterances=utts, utterances_speakers_id=speakers) 111 | # clusters = self.coref.get_clusters(use_no_coref_list=False) # If you want to keep "I","you" 112 | # print("Clusters: ", clusters) 113 | 114 | parsed_utts = self.coref.get_utterances(last_utterances_added=False) # Get all spacy utts in dialogue 115 | entities = [[(ent.text, ent.label_) for ent in parsed_utt.ents] for parsed_utt in parsed_utts] 116 | # Retrieve also syntactic role 117 | 118 | mentions = self.coref.get_mentions() 119 | 120 | most_representative = self.coref.get_most_representative() # With coref 121 | # print("Most representative: ", most_representative) 122 | return clusters, parsed_utts, entities, mentions 123 | 124 | def group_turns(self, dialogue): 125 | 126 | grouped = [list(g) for k, g in groupby(dialogue, itemgetter(3))] 127 | regrouped_dialogue = [(g[-1][0], u' '.join(sent[1] for sent in g), g[-1][2], g[-1][3]) 128 | for g in grouped] 129 | 130 | return regrouped_dialogue 131 | 132 | def corpus_stats(self, corpus_dct): 133 | corpus_dct = {dialogue_id: self.group_turns(dialogue) for dialogue_id, dialogue in iteritems(corpus_dct)} 134 | tokens_per_turn = [[len(turn[1].split()) for turn in dialogue] for dialogue in corpus_dct.values()] 135 | mean_tokens_per_turn = [sum(dialogue)/float(len(dialogue)) for dialogue in tokens_per_turn] 136 | mean_tokens_per_turn_all = 
sum(mean_tokens_per_turn) / float(len(mean_tokens_per_turn)) 137 | turns_per_dialogue = [len(dialogue) for dialogue in corpus_dct.values()] 138 | mean_turns_per_dialogue = sum(turns_per_dialogue) / float(len(turns_per_dialogue)) 139 | print("Average tokens per turn: ", mean_tokens_per_turn_all) 140 | print("Average dialogue turns: ", mean_turns_per_dialogue) 141 | return mean_tokens_per_turn_all, mean_turns_per_dialogue 142 | 143 | 144 | def check_named_entities(self, parsed_utt): 145 | NE_tokens = [[parsed_utt[i] for i in range(span.start, span.end)] 146 | for span in parsed_utt.ents] 147 | # Strip non-nouns from NEs 148 | NE_tokens = self.filter_non_nouns(NE_tokens) 149 | return NE_tokens 150 | 151 | def filter_non_nouns(self, token_span): 152 | # Keep only nouns and pronouns (strip adjectives etc. from NPs) 153 | token_span = [[token for token in np if token.pos_ in self.spacy_tags['pos']['noun']+ 154 | self.spacy_tags['pos']['pronoun']] 155 | for np in token_span] 156 | return token_span 157 | 158 | def extract_nps(self, parsed_utt, NEs=None, 159 | include_prons=False, exclude_conversation_prons=True): 160 | 161 | if not NEs: NEs = [] 162 | 163 | NPs_tokens = [[parsed_utt[i] for i in range(span.start, span.end)] for span in parsed_utt.noun_chunks] 164 | # print('All NPs: ', NPs_tokens) 165 | 166 | # Remove all pronouns 167 | if include_prons is False: 168 | NPs_tokens = [[token for token in span] 169 | for span in NPs_tokens 170 | if not any(token.pos_ in self.spacy_tags['pos']['pronoun'] for token in span)] 171 | # print('All prons removed: ', NPs_tokens) 172 | 173 | # Remove only conversational pronouns 174 | elif exclude_conversation_prons is True: 175 | NPs_tokens = [[token for token in span] 176 | for span in NPs_tokens 177 | if not any(token.pos_ in self.spacy_tags['pos']['pronoun'] 178 | and token.text.lower() in self.conversational_pronouns 179 | for token in span)] 180 | # print('Personal prons removed: ', NPs_tokens) 181 | 182 | # Strip non-nouns from NPs 183 | NPs_tokens = self.filter_non_nouns(NPs_tokens) 184 | # print("Simple NPs only nouns: ", NPs_tokens) 185 | 186 | # Check whether a "compound" is part of a NE, if not consider it an independent entity from head NP 187 | NPs_tokens = [np for np in NPs_tokens if np not in NEs] 188 | # print("Simple NPs no ent: ", NPs_tokens) 189 | 190 | return NPs_tokens 191 | 192 | def extract_head_nps(self, NPs_tokens=None, NEs=None): 193 | # Lapata(2008) 194 | # headNPs: (only head of Noun Phrase), partial coref (only if identical) 195 | if not NPs_tokens: NPs_tokens = [] 196 | if not NEs: NEs = [] 197 | 198 | # Keep only head from each NP (delete other tokens) 199 | NPs_tokens = [[token for token in span if token.head.i not in [token.i for token in span]] for span in NPs_tokens] 200 | 201 | # Join resulting NPs with NEs 202 | NPs_tokens = NPs_tokens + NEs 203 | 204 | return NPs_tokens 205 | 206 | def extract_all_nps(self, NPs_tokens=None, NEs=None): 207 | # Elser & Charniak (2008) 208 | # allNPs: (non-head mentions given role X), partial coref (only if identical) 209 | if not NPs_tokens: NPs_tokens = [] 210 | if not NEs: NEs = [] 211 | 212 | # Divide final NPs (to be considered independently) 213 | NPs_tokens = [[token] for np in NPs_tokens for token in np] 214 | 215 | # Join resulting NPs with NEs 216 | NPs_tokens = NPs_tokens + NEs 217 | 218 | return NPs_tokens 219 | 220 | def extract_entities_from_utt(self, utt, entities_type="headNPs", use_coref=False, 221 | include_prons= False, exclude_conversation_prons= True): 222 | 223 | if 
use_coref: 224 | resolved_utts = self.coref.get_resolved_utterances(use_no_coref_list=exclude_conversation_prons) 225 | parsed_utts = [self.nlp(resolved_utt) for resolved_utt in resolved_utts] 226 | else: 227 | parsed_utts = [self.nlp(utt)] 228 | 229 | # print('Parsed utt:', [u.text for u in parsed_utts]) 230 | entities = [] 231 | 232 | for parsed_utt in parsed_utts: 233 | # Extract NPs and Named Entities 234 | NEs = self.check_named_entities(parsed_utt) 235 | NPs_tokens = self.extract_nps(parsed_utt, 236 | NEs=NEs, 237 | include_prons=include_prons, 238 | exclude_conversation_prons=exclude_conversation_prons) 239 | # Lapata(2008) 240 | if entities_type=="headNPs": 241 | entities.append(self.extract_head_nps(NPs_tokens=NPs_tokens, NEs=NEs)) 242 | 243 | # Elser & Charniak (2008) 244 | elif entities_type=="allNPs": 245 | entities.append(self.extract_all_nps(NPs_tokens=NPs_tokens, NEs=NEs)) 246 | 247 | # Join Spacy utts 248 | entities = [ent for utt in entities for ent in utt] 249 | 250 | return entities 251 | 252 | 253 | def assign_tag_to_entities(self, entities, tag, tag_type="synrole_head"): 254 | # tag_type = synrole - syntactic role, if more than one in current text follow synrole ranking, 255 | # da - dialogue act 256 | 257 | 258 | if tag_type=='da': 259 | entity_tags = [(entity, tag) for entity in entities] 260 | elif tag_type=='synrole_head': 261 | entity_tags = [(entity_span, "S") if any(entity.dep_ in self.spacy_tags['dep']['subject'] if entity.dep_!='compound' 262 | else entity.head.dep_ in self.spacy_tags['dep']['subject'] 263 | for entity in entity_span) 264 | else (entity_span, "O") if any(entity.dep_ in self.spacy_tags['dep']['object'] if entity.dep_!='compound' 265 | else entity.head.dep_ in self.spacy_tags['dep']['object'] 266 | for entity in entity_span) 267 | else (entity_span, "X") for entity_span in entities] 268 | elif tag_type=='synrole_X': 269 | entity_tags = [(entity_span, "S") if any(entity.dep_ in self.spacy_tags['dep']['subject'] for entity in entity_span) 270 | else (entity_span, "O") if any(entity.dep_ in self.spacy_tags['dep']['object'] for entity in entity_span) 271 | else (entity_span, "X") for entity_span in entities] 272 | elif tag_type=='is_present': 273 | entity_tags = [(entity_span, "X") for entity_span in entities] 274 | else: 275 | raise TypeError("Not implemented tag type") 276 | 277 | # print("Tagged entities: ", entity_tags) 278 | return entity_tags 279 | 280 | def group_same_utt_entities(self, current_entities, tag_type): 281 | grouped_entities = [list(g) for k, g in 282 | groupby(sorted(current_entities), itemgetter(0))] # Group entities by text 283 | 284 | grouped_entities = [[e_group[0]] if all(e[1] == e_group[0][1] for e in e_group) else e_group 285 | for e_group in 286 | grouped_entities] # Reduce groups of entities with same text and same tag to one 287 | 288 | if all(len(e_group) == 1 for e_group in grouped_entities): 289 | current_entities = [e_group[0] for e_group in grouped_entities] 290 | 291 | # Only possible if the chosen tag_type = syntactic role 292 | else: 293 | if tag_type not in ['synrole_head','synrole_X']: 294 | raise TypeError('Different DA categories found in the same DA span!') 295 | 296 | # Get role according to rank 297 | get_role = lambda y: min(y, key=lambda x: self.grammatical_ranking.index(x)) 298 | current_entities = [(e_group[0][0], get_role([x[1] for x in e_group])) if len(e_group)>1 299 | else e_group[0] for e_group in grouped_entities] 300 | 301 | return current_entities 302 | 303 | def remove_disfluencies(self, 
current_entities): 304 | 305 | # Remove empty entities 306 | current_entities = [(entity_span, tag) for entity_span, tag in current_entities if entity_span] 307 | 308 | # Reduce consecutive repetitions of the same entity to one entity 309 | current_entities = [([entity_span[0]], tag) if all(entity.text==entity_span[0].text for entity in entity_span) 310 | else (entity_span, tag) for entity_span, tag in current_entities] 311 | 312 | return current_entities 313 | 314 | def transform_tokens_groups_in_text(self, current_entities): 315 | # Entities are lowercased and only the text form is kept 316 | return [(u' '.join(ent.lower_ for ent in entity_span), tag) for entity_span, tag in current_entities] 317 | 318 | def complicated_coref(self, tagged_entities, previous_mentions, exclude_conversation_prons): 319 | all_mentions = self.coref.get_mentions() # Tokens spans 320 | new_mentions = [] 321 | # If any new mentions were found 322 | if (len(all_mentions) - previous_mentions) > 0: 323 | new_mentions = all_mentions[-(len(all_mentions) - previous_mentions):] 324 | 325 | print("- New Mentions: ", new_mentions) 326 | print("- New Mentions: ", [[(token.text, token.i) for token in span] for span in new_mentions]) 327 | print("- All mentions: ", all_mentions) 328 | 329 | resolved_utt = self.coref.get_resolved_utterances() 330 | repr = self.coref.get_most_representative() 331 | clusters = self.coref.get_clusters(use_no_coref_list=exclude_conversation_prons) 332 | 333 | cluster_words = {all_mentions[k]: [all_mentions[v] for v in vals] for k, vals in 334 | iteritems(self.coref.get_clusters())} 335 | 336 | print("- Clusters: ", self.coref.get_clusters()) 337 | print("- Clusters: ", {all_mentions[k]: [all_mentions[v] for v in vals] for k, vals in 338 | iteritems(clusters)}) 339 | print("- Clusters: ", cluster_words) 340 | print("- RESOLVED UTT: ", resolved_utt) 341 | print("- Repr: ", repr) 342 | print("- Toks: ", 343 | [(coref_original[0].text, coref_original[0].pos_, coref_original[0].i) for coref_original, coref_replace 344 | in repr.items()]) 345 | # print("- Toks: ", [(type(coref_original[0]), type(coref_replace[0])) for 346 | # coref_original, coref_replace in repr.items()]) 347 | 348 | # Find mapping between mentions and current NP entities 349 | for entity_token_span, tag in tagged_entities: 350 | for mention_span in new_mentions: 351 | if any(entity_token in mention_span for entity_token in entity_token_span): 352 | print('Found matching coreference: Entity_token_span: ', entity_token_span) 353 | print('Found matching coreference: Mention_span: ', mention_span) 354 | print('Found matching coreference: Mention_span index: ', all_mentions.index(mention_span)) 355 | 356 | # Find antecedent if there is one 357 | 358 | mention_cluster = [(k, vals) for k, vals in iteritems(clusters) if 359 | all_mentions.index(mention_span) in vals] 360 | print('Cluster: ', mention_cluster) 361 | # print('Pairs scores: ', self.coref.get_scores()['pair_scores']) 362 | previous_mentions = len(all_mentions) 363 | 364 | 365 | 366 | def process_corpus(self, corpus, entities_type="allNPs", 367 | include_prons=False, exclude_conversation_prons=True, 368 | tag_type="synrole", group_by="DAspan", use_coref = False, 369 | end_of_turn_tag=False, no_entity_column=False): 370 | 371 | ''' 372 | :param corpus: {filename : [(DA, utt, speaker, turn number)]} 373 | :param entities_type = "headNPs", 374 | "allNPs" 375 | :param group_by = "DAspan", 376 | "turns" 377 | :param tag_type = "synrole", 378 | "DAtag" 379 | 380 | :return: grid: 
{filename: [DAs[entities]]} 381 | # each dialogue is represented as a grid: matrix DAs x Entities (including NO ent) 382 | ''' 383 | 384 | 385 | grids = {} 386 | 387 | # corpus = {k: v for k, v in corpus.iteritems() if k in ['sw_0657_2900.utt','sw_0915_3624.utt']} # Testing 388 | 389 | # For each dialogue 390 | for dialogue_id, dialogue in tqdm.tqdm(iteritems(corpus), total=len(corpus)): 391 | 392 | # Test utt 393 | # print("Dialogue id: ", dialogue_id) 394 | # print("Dialogue len: ", len(dialogue)) 395 | # dialogue = dialogue[:8] 396 | # test_utt = u"San Francisco is a great town. She loves it. It is great, do you agree? " \ 397 | # u"The world's largest oil company is located there. Drugs are great! The chilean leader Barack Obama was happy." 398 | # test_utt = u"My mom is great. I love her! I love drugs!" 399 | # test_utt2 = u"I also like them. The world's largest oil company is here. My mom is called Julia." 400 | # # test_utt = u'Hello!' 401 | # dialogue[0] = (u'test', test_utt, u'A', -2) 402 | # dialogue[1] = (u'test', test_utt2, u'B', -1) 403 | 404 | if self.coref is not None: 405 | self.coref.clean_history() 406 | 407 | # previous_mentions = 0 408 | dialogue_entities = {} # List of turns (entity: list of turns) 409 | if no_entity_column is True: 410 | dialogue_entities['no_entity'] = [] 411 | 412 | # Select text span 413 | if group_by=="turns": 414 | dialogue = self.group_turns(dialogue) 415 | 416 | # Minimum 5 dialogue turns 417 | if len(self.group_turns(dialogue)) < 5: 418 | continue 419 | 420 | # For each text span (utterance) extract list of entities 421 | for tag, utt, speaker, turn_id in dialogue: 422 | 423 | previous_turns_len = len(list(dialogue_entities.values())[0]) if dialogue_entities.keys() else 0 424 | 425 | # Preprocess utt removing double spaces 426 | utt = re.sub(' +', ' ', utt) 427 | 428 | # start = timer() 429 | if use_coref: 430 | self.coref.continuous_coref(utterances=utt, utterances_speakers_id=speaker) 431 | 432 | # t_continuous_coref = timer() 433 | # print('Time only continuous_coref: ', t_continuous_coref - start) 434 | 435 | # Extract entities 436 | current_entities = self.extract_entities_from_utt(utt, 437 | entities_type=entities_type, 438 | include_prons=include_prons, 439 | use_coref=use_coref, 440 | exclude_conversation_prons=exclude_conversation_prons) 441 | 442 | # Assign tag to the entities: [(token list, tag)] 443 | tagged_entities = self.assign_tag_to_entities(current_entities, tag, tag_type=tag_type) 444 | 445 | # Remove disfluencies 446 | tagged_entities = self.remove_disfluencies(tagged_entities) 447 | 448 | # Transform spacy tokens into text 449 | tagged_entities = self.transform_tokens_groups_in_text(tagged_entities) 450 | 451 | # Map repetitions of the same entity in current utt 452 | tagged_entities = self.group_same_utt_entities(tagged_entities, tag_type) 453 | 454 | # Check previous entity dict if there is already an entity, else add new entity column 455 | for entity_key, entity_tag in tagged_entities: 456 | 457 | if entity_key in dialogue_entities: 458 | dialogue_entities[entity_key].append(entity_tag) # Update entity with new tag 459 | else: 460 | dialogue_entities[entity_key] = ['_']*previous_turns_len+[entity_tag] 461 | 462 | 463 | # Update no entity column if there are no entities 464 | if not tagged_entities and no_entity_column and tag_type=='da': 465 | dialogue_entities['no_entity'].append(tag) # Update entity with new tag 466 | # Initialize dialogue_entities if there are no tagged entities creating tmp entity later to be 
dropped 467 | elif not tagged_entities and not dialogue_entities: 468 | dialogue_entities['']=['_'] 469 | 470 | 471 | # Update all entities not in this turn 472 | dialogue_entities = {ent: (tags+['_'] if len(tags)<(previous_turns_len+1) else tags) 473 | for ent, tags in iteritems(dialogue_entities)} 474 | 475 | if end_of_turn_tag is True: 476 | dialogue_entities = {ent: (tags + ['']) 477 | for ent, tags in iteritems(dialogue_entities)} 478 | # t_end = timer() 479 | # print('Time from continuous_coref to end of dialogue: ', t_end - t_continuous_coref) 480 | # print('Grids lengths: ', set([len(en) for en in dialogue_entities.values()])) 481 | 482 | # print('--Grid -Whole grid: ', dialogue_entities) 483 | 484 | # Remove tmp initializing entity 485 | dialogue_entities.pop('', None) 486 | 487 | grids[dialogue_id] = dialogue_entities 488 | 489 | # break 490 | 491 | logging.info('All grids parsed') 492 | 493 | 494 | return grids 495 | 496 | 497 | # 498 | # def get_intra_shuffle(self, dialogue, shuffles_number=5): 499 | # index_shuff = range(len(dialogue)) 500 | # shuffled_orders = [shuffle(index_shuff) for shuff_i in range(shuffles_number)] 501 | # return shuffled_orders 502 | 503 | 504 | 505 | 506 | 507 | def sort_grid_entity_appearance(self, grids_dct): 508 | return OrderedDict(sorted(grids_dct.items(), key=lambda x: min(i for i, v in enumerate(x[1]) if v != "-"))) 509 | 510 | def turn_grids_into_to_write(self, grids_dct): 511 | # print('Formatted grids values 0: ', grids_dct.values()[0]) 512 | formatted = [[entity[i] for entity in grids_dct.values()] for i in range(len(list(grids_dct.values())[0]))] 513 | formatted.insert(0, grids_dct.keys()) 514 | return formatted 515 | 516 | 517 | def create_csv_folder(self, grids_dct, options, folder_name, folder_path, corpus_name, min_len_dial = 1): 518 | full_path = folder_path + corpus_name +'/'+folder_name+'/' 519 | logging.info('Creating output directory: %s', full_path) 520 | 521 | if not os.path.exists(full_path): 522 | os.makedirs(full_path) 523 | empty_dials = [] 524 | dump_csv(full_path+'Params', options.items()) # Params file 525 | logging.info('Params file created') 526 | for dialogue_id, dialogue in iteritems(grids_dct): 527 | if len(dialogue) >= min_len_dial: 528 | 529 | formatted_grid = self.sort_grid_entity_appearance(grids_dct[dialogue_id]) 530 | formatted_grid = self.turn_grids_into_to_write(formatted_grid) 531 | dump_csv(full_path+dialogue_id, formatted_grid) 532 | else: 533 | empty_dials.append(dialogue_id) 534 | print('Empty dialogue id: ', empty_dials) 535 | print('Len Empty dialogue id: ', len(empty_dials)) 536 | return 537 | 538 | 539 | def dump_csv(out_file, to_write): 540 | # print(to_write) 541 | with open(out_file + '.csv', 'w') as out: 542 | csv_out = csv.writer(out) 543 | for row in to_write: 544 | csv_out.writerow(row) 545 | 546 | def main(args): 547 | print(''.join(y for y in["-"]*180)) 548 | logging.basicConfig( 549 | level=(logging.DEBUG if args.verbose else logging.INFO), 550 | format='%(levelname)s %(message)s') 551 | 552 | if not args.outputname: 553 | raise TypeError('Missing output directory name') 554 | # if os.path.exists(args.outputpath+args.outputname): 555 | # raise Warning('Folder %s already exists', args.outputpath+args.outputname) 556 | # overwrite_folder = raw_input("Enter your name: ") 557 | 558 | 559 | # swda = Switchboard(args.input) 560 | # corpus_dct = swda.load_csv() 561 | 562 | corpus_dct, corpus_loader = get_corpus(args.input) 563 | 564 | logging.info('Files number: %d', len(corpus_dct)) 565 | # tags = 
corpus_loader.get_tags() 566 | # logging.info('DAs number: %d', len(tags)) 567 | logging.info('Corpus loaded') 568 | logging.info('Corpus dimension: %d', len(corpus_dct)) 569 | 570 | # Debug 571 | print('Ex. File names in dct: ', list(corpus_dct.keys())[0]) 572 | print('Ex. Dialogue type: ', type(corpus_dct[list(corpus_dct.keys())[0]])) 573 | print('Ex. Len of dialogue: ', len(corpus_dct[list(corpus_dct.keys())[0]])) 574 | print('Ex. Turn type: ', type(corpus_dct[list(corpus_dct.keys())[0]][0])) 575 | print('Ex. Turn 0: ', corpus_dct[list(corpus_dct.keys())[0]][0]) 576 | print('Ex. Turn 0-4: ') 577 | for y in corpus_dct[list(corpus_dct.keys())[0]]: 578 | print(y) 579 | 580 | grid_generator = GridGenerator() 581 | 582 | options = { 583 | 'entities_type' : 'headNPs', 584 | 'include_prons' : False, 585 | 'exclude_conversation_prons' : True, 586 | 'group_by' : 'DAspan', # DAspan, turns (Elsner & Charniak, 2011, "Disentangling Chat") 587 | 'tag_type' : 'da', # da, synrole_head (Lapata & Barzilay 2005/2008), synrole_X (Elsner & Charniak 2011) 588 | 'use_coref' : False, 589 | 'no_entity_column' : True, 590 | 'end_of_turn_tag' : False} 591 | 592 | default_confs = ['egrid_-coref', 'egrid_+coref', 'extgrid_-coref', 'simple_egrid_-coref', 593 | 'egrid_-coref_DAspan', 'egrid_-coref_DAspan_da', 'egrid_-coref_DAspan_da_noentcol'] 594 | 595 | if args.default not in default_confs+['']: raise TypeError('Default configuration inserted is not allowed') 596 | 597 | if args.default in default_confs: 598 | logging.info('Default configuration requested: %s', args.default) 599 | options['group_by'] = 'turns' 600 | options['no_entity_column'] = False 601 | options['end_of_turn_tag'] = False 602 | 603 | if args.default in ['egrid_-coref', 'egrid_-coref_DAspan', 'egrid_-coref_DAspan_da', 604 | 'egrid_-coref_DAspan_da_noentcol', 'simple_egrid_-coref']: 605 | # Head noun divided into, each noun in NP same grammatical role 606 | options['entities_type'] = 'allNPs' 607 | options['tag_type'] = 'synrole_head' 608 | options['use_coref'] = False 609 | if args.default in ['simple_egrid_-coref']: 610 | options['tag_type'] = 'is_present' 611 | if args.default in ['egrid_-coref_DAspan', 'egrid_-coref_DAspan_da','egrid_-coref_DAspan_da_noentcol']: 612 | options['group_by'] = 'DAspan' 613 | if args.default in ['egrid_-coref_DAspan_da','egrid_-coref_DAspan_da_noentcol']: 614 | options['tag_type'] = 'da' 615 | if args.default in ['egrid_-coref_DAspan_da_noentcol']: 616 | options['no_entity_column'] = True 617 | 618 | elif args.default=='egrid_+coref': 619 | # Keep only head nouns, perform coreference on it 620 | options['entities_type'] = 'headNPs' 621 | options['tag_type'] = 'synrole_head' 622 | options['use_coref'] = True 623 | options['include_prons'] = True 624 | 625 | # Extended grid default: what's the difference with egrid_-coref? 
Supposedly "Bush spokeman", where Bush=X 626 | elif args.default=='extgrid_-coref': 627 | # Add non-head nouns 628 | options['entities_type'] = 'allNPs' 629 | options['tag_type'] = 'synrole_X' 630 | options['use_coref'] = False 631 | 632 | logging.info('Set up') 633 | for k, v in iteritems(options): 634 | logging.info('%s : %s', k, v) 635 | 636 | # Process corpus 637 | grids = grid_generator.process_corpus(corpus_dct, 638 | entities_type=options['entities_type'], 639 | include_prons=options['include_prons'], 640 | exclude_conversation_prons= options['exclude_conversation_prons'], 641 | group_by=options['group_by'], 642 | tag_type=options['tag_type'], 643 | use_coref=options['use_coref'], 644 | no_entity_column=options['no_entity_column'], 645 | end_of_turn_tag=options['end_of_turn_tag']) 646 | 647 | print('Len grids: ', len(grids)) 648 | 649 | # Write out 650 | grid_generator.create_csv_folder(grids, options, args.outputname, args.outputpath, args.input) 651 | 652 | 653 | 654 | 655 | 656 | def argparser(parser=None, func=main): 657 | """parse command line arguments""" 658 | 659 | if parser is None: 660 | parser = argparse.ArgumentParser(prog='grid') 661 | 662 | parser.description = 'Generative implementation of Entity grid' 663 | parser.formatter_class = argparse.ArgumentDefaultsHelpFormatter 664 | 665 | parser.add_argument('input', nargs='?', 666 | # type=argparse.FileType('r'), 667 | # default=sys.stdin, 668 | type=str, 669 | help='input corpus in doctext format') 670 | 671 | parser.add_argument('default', nargs='?', 672 | type=str, 673 | default='', 674 | help="default settings of all params") 675 | 676 | parser.add_argument('outputname', nargs='?', 677 | type=str, 678 | help="output_name") 679 | 680 | parser.add_argument('outputpath', nargs='?', 681 | type=str, 682 | default='data/', 683 | help="output folder path") 684 | 685 | parser.add_argument('--verbose', '-v', 686 | action='store_true', 687 | help='increase the verbosity level') 688 | 689 | if func is not None: 690 | parser.set_defaults(func=func) 691 | 692 | return parser 693 | 694 | if __name__ == '__main__': 695 | main(argparser().parse_args()) 696 | 697 | # Old 698 | # python generate_grid.py ../../Datasets/Switchboard/data/switchboard1-release2/ outout 699 | # python generate_grid.py ../../Datasets/Switchboard/data/switchboard1-release2/ egrid_-coref_DAspan_da egrid_-coref_DAspan_da data/ -v 700 | # python generate_grid.py ../../Datasets/Switchboard/data/switchboard1-release2/ egrid_-coref_DAspan_da_noentcol egrid_-coref_DAspan_da_noentcol data/ -v 701 | 702 | # New 703 | # python generate_grid.py Switchboard egrid_-coref_DAspan_da egrid_-coref_DAspan_da data/ -v 704 | # python generate_grid.py Oasis egrid_-coref_DAspan_da egrid_-coref_DAspan_da data/ -v 705 | # python generate_grid.py AMI egrid_-coref egrid_-coref data/ -v 706 | # python generate_grid.py Switchboard simple_egrid_-coref simple_egrid_-coref data/ -v 707 | # simple_egrid_-coref -------------------------------------------------------------------------------- /generate_joined_dat.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import os, errno 3 | 4 | def create_path(filename): 5 | if not os.path.exists(os.path.dirname(filename)): 6 | try: 7 | os.makedirs(os.path.dirname(filename)) 8 | except OSError as exc: 9 | if exc.errno != errno.EEXIST: 10 | raise 11 | return filename 12 | 13 | def read_test_file(path): 14 | query_id = None 15 | with open(path, 'r') as infile: 16 | for line in 
infile: 17 | if line[0] is not '#': 18 | yield query_id, line.strip().split() 19 | else: 20 | query_id = line.strip().split()[2] 21 | 22 | def combine_files(egrid_file, noents_file, trans): 23 | trans2feat = {'2': 16} 24 | max_trans_feat = trans2feat.get(trans) 25 | docs_to_write = defaultdict(list) 26 | 27 | for i, egrid_info in enumerate(egrid_file): 28 | 29 | noents_info = noents_file[i] 30 | # print(egrid_info) 31 | # print(noents_info) 32 | doc_id, egrid_i_feat = egrid_info 33 | doc_id_noent, noent_i_feat = noents_info 34 | 35 | label = egrid_i_feat[0] 36 | qid = egrid_i_feat[1] 37 | 38 | if doc_id!=doc_id_noent or label!=noent_i_feat[0] or qid!=noent_i_feat[1]: 39 | print('Doc id: ', doc_id, ' Doc noent: ', doc_id_noent) 40 | print('Label : ', label, ' Label no ent: ', noent_i_feat[0]) 41 | print('Qid : ', qid, ' Qid no ent: ', noent_i_feat[1]) 42 | raise TypeError('Matching failed between the two files') 43 | 44 | 45 | egrid_features = [(f_i.split(':')[0], f_i.split(':')[1]) for f_i in egrid_i_feat[2:]] 46 | mod_noents_features = [(int(f_i.split(':')[0])+max_trans_feat, f_i.split(':')[1]) for f_i in noent_i_feat[2:]] 47 | 48 | joined_i = (label, qid, egrid_features+mod_noents_features) 49 | docs_to_write[doc_id].append(joined_i) 50 | # print('Joined: ', joined_i) 51 | 52 | return docs_to_write 53 | 54 | def write_to_dat(docs_to_write, outpath): 55 | with open(outpath, 'w') as to_write: 56 | for doc_id, infos in docs_to_write.items(): 57 | to_write.write('# query ' + str(doc_id) + '\n') 58 | for info_i in infos: 59 | label = info_i[0] 60 | query_i = info_i[1] 61 | features = info_i[2] 62 | 63 | to_write.write(str(label) + " " + str(query_i)) 64 | for feat_ind, feat_val in features: 65 | to_write.write(" " + str(feat_ind) + ":" + str(feat_val)) 66 | to_write.write('\n') 67 | 68 | 69 | 70 | def main(): 71 | corpus = 'Oasis' 72 | exper_path = 'experiments/' 73 | #egrid = 'simple_egrid_-coref' 74 | egrid = 'egrid_-coref' 75 | noents = 'noents_baseline' 76 | task = 'last_turn_ranking' 77 | saliency = 1 78 | trans = '2' 79 | data_types = ['test', 'train', 'dev'] 80 | # data_types = ['test'] 81 | # joined_name = 'egrid+noents' 82 | joined_name = 'simple_egrid+noents' 83 | 84 | for data_type in data_types: 85 | 86 | path_noents = exper_path + corpus + '/' + task + '/' + noents + '/' + corpus + '_sal' + str( 87 | saliency) + '_range' + trans + "_" + trans + "_" + data_type + '.dat' 88 | path_egrid = exper_path + corpus + '/' + task + '/' + egrid + '/' + corpus + '_sal' + str( 89 | saliency) + '_range' + trans + "_" + trans + "_" + data_type + '.dat' 90 | joined_path = create_path(exper_path + corpus + '/' + task + '/' + joined_name + '/' + corpus + '_sal' + str( 91 | saliency) + '_range' + trans + "_" + trans + "_" + data_type + '.dat') 92 | 93 | print(path_egrid) 94 | print(path_noents) 95 | egrid_file = list(read_test_file(path_egrid)) 96 | noents_file = list(read_test_file(path_noents)) 97 | print(len(egrid_file)) 98 | print(len(noents_file)) 99 | to_write = combine_files(egrid_file, noents_file, trans) 100 | write_to_dat(to_write, joined_path) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() -------------------------------------------------------------------------------- /generate_noentities_baseline.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from load_grids import GridLoader 3 | import pandas as pd 4 | import os 5 | import tqdm 6 | 7 | 8 | def main(): 9 | corpus = 'AMI' 10 | grids_path = 
'data/'+corpus+'/egrid_-coref_DAspan_da_noentcol/' 11 | no_ents_baseline_path = 'data/'+corpus+'/noents_baseline/' 12 | grid_loader = GridLoader(grids_path) 13 | if not os.path.exists(grids_path): 14 | raise TypeError("The following folder does not exist " + grids_path) 15 | 16 | grids, _ = grid_loader.load_data() 17 | print('Number of grids: ', len(grids)) 18 | grid_names = [x for x in grids if x != 'Params'] 19 | 20 | for grid_i_name in tqdm.tqdm(grid_names): 21 | grid_i = grids.get(grid_i_name) 22 | grid_i_da_seq = [list(set([da for da in row if da != '_'])) for index, row in grid_i.iterrows()] 23 | if all(len(das) == 1 for das in grid_i_da_seq): 24 | grid_i_da_seq = [da[0] for da in grid_i_da_seq] 25 | df = pd.DataFrame.from_items([('all_das', grid_i_da_seq)]) 26 | df.to_csv(path_or_buf=no_ents_baseline_path + grid_i_name + '.csv', index=False) 27 | else: 28 | raise TypeError("Not only one Dialogue Act per row in grid " + grid_i_name) 29 | 30 | 31 | 32 | if __name__ == '__main__': 33 | main() -------------------------------------------------------------------------------- /generate_shuffled.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from future.utils import iteritems 3 | from builtins import dict 4 | from corpus.Switchboard.Switchboard import Switchboard 5 | from generate_grid import dump_csv, GridGenerator 6 | from load_grids import GridLoader 7 | from itertools import groupby 8 | from generate_grid import corpora_paths, get_corpus 9 | from pandas.util.testing import assert_frame_equal 10 | from operator import itemgetter 11 | import itertools 12 | import argparse 13 | import pandas as pd 14 | import warnings 15 | import copy 16 | import tqdm 17 | import re 18 | import logging 19 | import random 20 | import os 21 | import numpy as np 22 | 23 | warnings.simplefilter('error', UserWarning) 24 | 25 | 26 | class GridShuffler(object): 27 | '''Class to generate and save shuffled files''' 28 | 29 | def __init__(self, grid_folder=None, grid_loader=None, grid_generator=None): 30 | 31 | if grid_loader: 32 | self.grid_loader = grid_loader 33 | else: 34 | try: 35 | assert os.path.exists(grid_folder) 36 | except AssertionError: 37 | print("The folder " + grid_folder + " does not exist.") 38 | exit(1) 39 | self.grid_loader = GridLoader(grid_folder) 40 | 41 | self.grids, self.grids_params = self.grid_loader.get_data() 42 | if not self.grids: 43 | self.grids, self.grids_params = self.grid_loader.load_data() 44 | 45 | if grid_generator: 46 | self.grid_generator = grid_generator 47 | else: 48 | self.grid_generator = GridGenerator(coref='no_coref', nlp='no_nlp') 49 | 50 | self.turn2da = {} # tuples dct 51 | self.da2turn = {} 52 | 53 | def update_grids_dct(self, grid_names): 54 | self.grids = {k:v for k,v in iteritems(self.grids) if k in grid_names} 55 | 56 | def get_intra_shuffle(self, len_dialogue, shuffles_number=20): 57 | index_shuff = range(len_dialogue) 58 | # print('Dial len: ', len(index_shuff)) 59 | 60 | shuffled_orders = [] 61 | # for i in range(shuffles_number): 62 | # shuffled_order = random.sample(index_shuff, len(index_shuff)) 63 | # shuffled_orders.append(shuffled_order) 64 | shuff_i = 0 65 | # print('Shuffled orders: ', shuffled_orders) 66 | while shuff_i < shuffles_number: 67 | s = random.sample(index_shuff, len(index_shuff)) 68 | if s != index_shuff: 69 | shuffled_orders.append(s) 70 | shuff_i += 1 71 | else: 72 | pass 73 | return shuffled_orders 74 | 75 | def generate_turn_shuffle_index(self, corpus, 
shuffles_number=20, min_dial_len =4, corpus_path=''): 76 | 77 | shuffled_dialogues = {} 78 | random.seed(0) 79 | 80 | for dialogue_id, dialogue in tqdm.tqdm(iteritems(corpus), total=len(corpus)): 81 | dialogue = self.grid_generator.group_turns(dialogue) 82 | # print('Dialogue: ', dialogue_id) 83 | 84 | # dialogue = dialogue[:8] # To delete 85 | len_dialogue = len(dialogue) 86 | 87 | # Minimum 5 turns and check we have a corresponding grid file 88 | if len_dialogue > min_dial_len and os.path.exists(corpus_path+dialogue_id+'.csv'): 89 | shuffled_orders = self.get_intra_shuffle(len_dialogue, shuffles_number=shuffles_number) 90 | max_ind = list(set([max(s) for s in shuffled_orders])) 91 | if len(max_ind)==1 and max_ind[0]==(len_dialogue-1) and max_ind[0]==list(set([len(s)-1 for s in shuffled_orders]))[0]: 92 | shuffled_dialogues[dialogue_id] = shuffled_orders 93 | else: 94 | print('ID: ', dialogue_id) 95 | print('Dialogue orders messed up') 96 | print('Len Dialogue: ', len_dialogue) 97 | print('Dialogue check: ', [dialogue[i] for i in range(len_dialogue) if int(dialogue[i][-1]) != (i + 1)]) 98 | print('Shuff len :', [len(s) for s in shuffled_orders]) 99 | print('Shuff max :', max_ind) 100 | 101 | print('Len shuffled dialogues: ', len(shuffled_dialogues)) 102 | return shuffled_dialogues 103 | 104 | def count_pairs(self): 105 | raise NotImplementedError 106 | 107 | def write_csv_indexes(self, shuffled_dialogues, folder_name='shuffled', 108 | folder_path='data/', corpus_name='Switchboard', rewrite_new=False): 109 | full_path = folder_path + corpus_name + '/' + folder_name+'/' 110 | print('Out path: ', full_path) 111 | if not os.path.exists(full_path): 112 | if rewrite_new: 113 | os.makedirs(full_path) 114 | print('Shuffled indexes dir created: ', full_path) 115 | else: 116 | warnings.warn('The shuffled folder already exists and cannot be overwritten.') 117 | for dialogue_id, dialogue_indexes in iteritems(shuffled_dialogues): 118 | dump_csv(full_path+dialogue_id, dialogue_indexes) 119 | print('Shuffled files created') 120 | 121 | def create_shuffle_index_files(self, corpus_dct, corpus_name, shuffles_number=20, grids_path='', rewrite_new=False): 122 | shuffled_dialogues = self.generate_turn_shuffle_index(corpus_dct, 123 | shuffles_number=shuffles_number, 124 | corpus_path=grids_path) 125 | print('Len shuffled dialogues: ', len(shuffled_dialogues)) 126 | self.write_csv_indexes(shuffled_dialogues, 127 | corpus_name=corpus_name, 128 | rewrite_new=rewrite_new) 129 | 130 | def check_match_shuff_original(self, shuff_path): 131 | # Check shuffled and original match 132 | if not os.path.exists(shuff_path): 133 | warnings.warn('The shuffled folder does not exist') 134 | shuffled_dialogues_list = [x.rstrip('.csv') for x in os.listdir(shuff_path)] 135 | is_match = all(grid_name in shuffled_dialogues_list for grid_name in self.grids) 136 | return is_match 137 | 138 | def get_turns_to_da_map(self, dialogue, grid_i_name, is_eot=False): 139 | group_maps = [] 140 | previous_ind = 0 141 | # start/end tuples where index=turn 142 | for k, g in groupby(dialogue, itemgetter(3)): 143 | g = list(g) 144 | group_maps.append([previous_ind, (previous_ind)+len(g)]) 145 | previous_ind += len(g) 146 | # print('2nd group: ', group_maps[1]) 147 | # # Get das for turn-range 148 | # print('2nd group corresponding indexes: ', [dialogue[i] for i in range(group_maps[1][0], group_maps[1][-1])]) 149 | # print('3d group: ', group_maps[2]) 150 | # print('3d group corresponding indexes: ', [dialogue[i] for i in range(group_maps[2][0], 
group_maps[2][-1])]) 151 | 152 | self.turn2da[grid_i_name] = {group_maps.index(g):g for g in group_maps} 153 | self.da2turn[grid_i_name] = {i:group_maps.index(g) for g in group_maps for i in range(g[0],g[-1])} 154 | return group_maps 155 | 156 | 157 | def map_index_shuffles_to_grids(self, permuted_indexes, grid, df=False): 158 | # Returns list of rows from original grid reordered according to permuted order 159 | # return [[grid.iloc[ind] for ind in perm] for perm in permuted_indexes] 160 | print('Permuted indexes: ', len(permuted_indexes)) 161 | print('Grid shape: ', grid.shape) 162 | if df is False: 163 | perm_rows = [[grid.iloc[ind] for ind in perm] for perm in permuted_indexes] 164 | else: 165 | perm_rows = [] 166 | # For permutation type 167 | for perm in permuted_indexes: 168 | perm_i_rows = [grid.iloc[ind] for ind in perm] 169 | perm_i_rows_df = pd.DataFrame.from_items([(c, [r[i] for r in perm_i_rows]) for i, c in enumerate(grid.columns)]) 170 | perm_rows.append(perm_i_rows_df) 171 | 172 | return perm_rows 173 | 174 | def map_index_shuffles_to_grids_fast(self, permuted_indexes, grid): 175 | # Returns list of rows from original grid reordered according to permuted order 176 | # return [[grid.iloc[ind] for ind in perm] for perm in permuted_indexes] 177 | # print('Permuted indexes: ', len(permuted_indexes)) 178 | # print('Grid shape: ', grid.shape) 179 | permuted_indexes = [permuted_indexes[0]] # Testing 180 | perm_rows = [] 181 | # For permutation type 182 | for perm in permuted_indexes: 183 | perm_grid = grid.copy() 184 | 185 | for ind_pos, ind in enumerate(perm): 186 | if ind < grid.shape[0] and ind_pos < grid.shape[0]: 187 | perm_grid.iloc[ind_pos] = grid.iloc[ind].copy() 188 | perm_rows.append(perm_grid) 189 | 190 | return perm_rows 191 | 192 | def map_index_shuffles_to_grids_veryfast(self, permuted_indexes, grid): 193 | # Returns list of rows from original grid reordered according to permuted order 194 | # return [[grid.iloc[ind] for ind in perm] for perm in permuted_indexes] 195 | # print('Permuted indexes: ', len(permuted_indexes)) 196 | # print('Grid shape: ', grid.shape) 197 | # permuted_indexes = [permuted_indexes[0]] # Testing 198 | # perm_rows = [] 199 | # # For permutation type 200 | # for perm in permuted_indexes: 201 | # perm_rows.append(grid.reindex(perm).reset_index(drop=True)) 202 | perm_rows = [grid.reindex(perm).reset_index(drop=True) for perm in permuted_indexes] 203 | return perm_rows 204 | 205 | def map_index_reinsertion_to_grids_veryfast(self, grid, times_number = 10): 206 | # Returns list of rows from original grid reordered according to permuted order 207 | perm_rows = [] 208 | turns_number = grid.shape[0] 209 | 210 | np.random.seed(0) 211 | 212 | # print('Original grid: ', turns_number) 213 | # print('Original grid: ', grid) 214 | for i in range(min(turns_number, times_number)): 215 | 216 | sent_idx = np.random.randint(0, turns_number-1) 217 | # print('Turn index to reinsert: ', sent_idx) 218 | del_index = range(turns_number) 219 | del del_index[sent_idx] 220 | all_perm_turn_i = [] 221 | for j in range(min(turns_number, times_number)): 222 | permuted_index = copy.deepcopy(del_index) 223 | cand = np.random.randint(0, turns_number) 224 | while cand == sent_idx: 225 | cand = np.random.randint(0, turns_number) 226 | permuted_index.insert(cand, sent_idx) 227 | # print('Permuted index: ', permuted_index) 228 | all_perm_turn_i.append(grid.reindex(permuted_index).reset_index(drop=True)) 229 | # print('Permuted grid: ', 
grid.reindex(permuted_index).reset_index(drop=True)) 230 | perm_rows.append(all_perm_turn_i) 231 | 232 | return perm_rows 233 | 234 | def map_index_shuffles_to_grids_das(self, permuted_indexes, grid, group_maps, df=False): 235 | # Returns list of rows from original grid reordered according to permuted order 236 | perm_rows=[] 237 | # For permutation type 238 | for perm in permuted_indexes: 239 | perm_i_rows = [] 240 | 241 | # For each turn index 242 | for ind in perm: 243 | # print('Perm ind: ', ind) 244 | # print('SO: ', group_maps[ind], ' Range: ', range(group_maps[ind][0], group_maps[ind][-1])) 245 | 246 | # Select all DA row indexes in that span 247 | for da_i in range(group_maps[ind][0]-1, group_maps[ind][-1]-1): 248 | # print('Rows: ', grid.iloc[da_i]) 249 | if da_i < grid.shape[0]: 250 | perm_i_rows.append(grid.iloc[da_i]) 251 | else: 252 | pass 253 | # print('DA index: ', da_i) 254 | # print('Previous row selected: ', [r for r in grid.iloc[da_i-2]]) 255 | 256 | if df is False: 257 | perm_i_rows.insert(0, grid.columns) 258 | else: 259 | # Convert into pandas DataFrame 260 | perm_i_rows = pd.DataFrame.from_items([(c, [r[i] for r in perm_i_rows]) for i, c in enumerate(grid.columns)]) 261 | 262 | perm_rows.append(perm_i_rows) 263 | 264 | return perm_rows 265 | 266 | def map_index_shuffles_to_grids_das_fast(self, permuted_indexes, grid, group_maps, df=False): 267 | perm_rows = [] 268 | # For permutation type 269 | 270 | permuted_indexes=[permuted_indexes[0]] # Testing 271 | for perm in permuted_indexes: 272 | # perm_i_rows = [] 273 | perm_grid = grid.copy() 274 | ind_pos = 0 275 | 276 | # For each turn index 277 | for ind in perm: 278 | # print('Perm ind: ', ind, 'Ind pos: ', ind_pos) 279 | # print('SO: ', group_maps[ind], ' Range: ', range(group_maps[ind][0], group_maps[ind][-1])) 280 | 281 | # Select all DA row indexes in that span 282 | for da_i in range(group_maps[ind][0], group_maps[ind][-1]): 283 | if ind_pos < grid.shape[0] and da_i < grid.shape[0]: 284 | # print('Rows: ', grid.iloc[da_i]) 285 | # print('Rows: ', [(grid.columns[i], y) for i, y in enumerate(grid.iloc[da_i]) if y != '_']) 286 | perm_grid.iloc[ind_pos] = grid.iloc[da_i].copy() 287 | ind_pos += 1 288 | # It cannot be only ind_pos but a range 289 | # print('Perm Rows: ', [(perm_grid.columns[i], y) for i, y in enumerate(grid.iloc[da_i]) if y != '_']) 290 | # perm_i_rows.append(grid.iloc[da_i]) 291 | else: 292 | pass 293 | # print('DA index: ', da_i) 294 | # print('Previous row selected: ', [r for r in grid.iloc[da_i-2]]) 295 | # print('Perm grid 3', [(perm_grid.columns[i], y) for i, y in enumerate(perm_grid.iloc[3]) if y != '_']) 296 | # print('Perm grid 1', [(perm_grid.columns[i], y) for i, y in enumerate(perm_grid.iloc[1]) if y != '_']) 297 | perm_rows.append(perm_grid) 298 | 299 | return perm_rows 300 | 301 | def map_index_shuffles_to_grids_das_veryfast(self, permuted_indexes, grid, group_maps, df=False): 302 | # permuted_indexes = [permuted_indexes[0]] # Testing 303 | perm_rows = [grid.reindex(self.turns_to_da(perm, group_maps)).reset_index(drop=True) for perm in permuted_indexes] 304 | # perm_rows = [grid.reindex(self.turns_to_da(perm, group_maps)) for perm in 305 | # permuted_indexes] # Testing 306 | 307 | return perm_rows 308 | 309 | def map_index_reinsertion_to_grids_das_veryfast(self, turns_number, grid, group_maps, times_number=10, df=False): 310 | # permuted_indexes = [permuted_indexes[0]] # Testing 311 | np.random.seed(0) 312 | perm_rows = [] 313 | turns_number = len(turns_number[0]) 314 | for i in 
range(min(turns_number, times_number)): 315 | sent_idx = np.random.randint(0, turns_number-1) 316 | del_index = range(turns_number) 317 | del del_index[sent_idx] 318 | all_perm_turn_i = [] 319 | for i in range(min(turns_number, times_number)): 320 | permuted_index = copy.deepcopy(del_index) 321 | cand = np.random.randint(0, turns_number) 322 | while cand == sent_idx: 323 | cand = np.random.randint(0, turns_number) 324 | permuted_index.insert(cand, sent_idx) 325 | # print('Permuted index: ', permuted_index) 326 | # print('Len permuted index: ', len(permuted_index)) 327 | # print('Group maps: ', group_maps) 328 | all_perm_turn_i.append(grid.reindex(self.turns_to_da(permuted_index, group_maps)).reset_index(drop=True)) 329 | perm_rows.append(all_perm_turn_i) 330 | 331 | return perm_rows 332 | 333 | 334 | 335 | def turns_to_da(self, perm, group_maps): 336 | return list(itertools.chain(*map(lambda x: range(*group_maps[x]), perm))) 337 | 338 | def test_correspondance(self, y_row_ind, permut_i, permuted_indexes_i, dialogue_i, group_maps, grid_i, grid_i_name, permuted_files): 339 | 340 | print('Grid i shape: ', grid_i.shape) 341 | print('Dialogue len: ', len(dialogue_i)) 342 | print('turn to DA:', self.turn2da[grid_i_name]) 343 | print('DA to turn:', self.da2turn[grid_i_name]) 344 | print('Len one permutation: ', len(permuted_files[grid_i_name][0])) 345 | print('Type one permutation: ', type(permuted_files[grid_i_name][0])) 346 | print('Shape one permutation: ', permuted_files[grid_i_name][0].shape) 347 | # print('Perm df: ', permuted_files[grid_i_name][0]) 348 | print('Shape one permutation: ', permuted_files[grid_i_name][0].iloc[0]) 349 | 350 | turn_y_permuted_index = permuted_indexes_i[permut_i][y_row_ind] # ind to substitute: 44 351 | print('First permut, first row index: ', turn_y_permuted_index) 352 | print('Da group indexes corresponding to that turn index: ', group_maps[turn_y_permuted_index]) 353 | print('Dialogue ref: ', [dialogue_i[i] for i in range(group_maps[turn_y_permuted_index][0], 354 | group_maps[turn_y_permuted_index][-1])]) 355 | 356 | # Check original grid rows corresponding to turn_y_permuted_index (44) 357 | print('Original grid da rows for turn ', turn_y_permuted_index) 358 | # For each DA index corresponding to the first turn in the permuted indexes 359 | for d in range(group_maps[turn_y_permuted_index][0],group_maps[turn_y_permuted_index][-1]): 360 | print('Original row name: ', grid_i.iloc[d].name) 361 | 362 | # Check permuted grid rows corresponding to y_row_ind (0) 363 | print('Permuted grid rows for turn ', y_row_ind) 364 | print('Start ', y_row_ind) 365 | print('End ', y_row_ind + (group_maps[turn_y_permuted_index][-1] - group_maps[turn_y_permuted_index][0])) 366 | 367 | # # For each DA index corresponding to the first turn in the permuted indexes 368 | # for d in range(y_row_ind, 369 | # y_row_ind+(group_maps[turn_y_permuted_index][-1] - group_maps[turn_y_permuted_index][0])): 370 | # print('da index') 371 | # print('Permuted grid first list Name: ', permuted_files[grid_i_name][permut_i][d].name) 372 | 373 | # for x in range(1, 100): 374 | # current_perm = permuted_files[grid_i_name][permut_i] 375 | # current_perm_indexes = permuted_indexes_i[permut_i] # [48, 85, ] 376 | # print('Permuted grid DA ind: ', x,' Orig DA ind: ', current_perm[x].name) 377 | # print('Corresponding original Turn index: ', 378 | # self.da2turn[grid_i_name][current_perm[x].name], 379 | # ' Turn position in permuted: ', 
current_perm_indexes.index(self.da2turn[grid_i_name][current_perm[x].name])) 380 | # print('Dialogue i ', 381 | # dialogue_i[current_perm[x].name]) 382 | 383 | # for x in range(0, 100): 384 | # # print('Permuted grid DA ind: ', x, ' Permut row for that index: ', permuted_files[grid_i_name][permut_i].iloc[x]) 385 | # current_perm = permuted_files[grid_i_name][permut_i] 386 | # current_perm_indexes = permuted_indexes_i[permut_i] # [48, 85, ] 387 | # print('Permuted grid DA ind: ', x, ' Permut row for that index: ', 388 | # [(current_perm.columns[i], y) 389 | # for i, y in enumerate(current_perm.iloc[x]) if y != '_']) 390 | # print('Corresponding original Turn index: ', permuted_indexes_i[permut_i][self.da2turn[grid_i_name][x]]) 391 | 392 | def generate_shuffled_grids(self, folder_name='shuffled', corpus_name ='Switchboard', 393 | folder_path='data/', corpus_dct=None, 394 | only_grids = None, df=False, saliency=1, return_originals=False): 395 | # Check shuffled folder exist 396 | shuff_path = folder_path + corpus_name + '/' + folder_name+'/' 397 | print('Shuff path: ', shuff_path) 398 | self.check_match_shuff_original(shuff_path) 399 | self.grid_generator.corpus_stats(corpus_dct) 400 | 401 | permuted_files = {} 402 | if return_originals: 403 | original_files = {} 404 | 405 | grid_names = [x for x in self.grids if not re.match(r'.+\_s[0-9][0-9]*', x) and x!='Params'] 406 | # print('Grid names: ', grid_names) 407 | 408 | if only_grids is not None: 409 | grid_names = [x for x in grid_names if x in only_grids and x!='.DS_Store'] 410 | 411 | print('Len grids to permute: ', len(grid_names)) 412 | self.update_grids_dct(grid_names) 413 | 414 | # Permute already generated grids according to shuffled indexes order 415 | for grid_i_name in tqdm.tqdm(grid_names): 416 | # print('Grid id: ', grid_i_name) 417 | grid_i = self.grids.get(grid_i_name) 418 | 419 | # Check saliency 420 | if saliency>1: 421 | grid_i.drop([col for col in grid_i if len([i for i in grid_i[col] if i != '_']) < saliency], axis=1) 422 | 423 | if return_originals: 424 | original_files[grid_i_name] = grid_i 425 | 426 | 427 | # print('Grid i columns: ', grid_i.columns) 428 | shuffled_indexes_i = pd.read_csv(shuff_path+grid_i_name+'.csv', header=None, engine="c").T 429 | 430 | # Get permutations 431 | permuted_indexes_i = [[ind for ind in shuffled_indexes_i[col]] for col in shuffled_indexes_i.columns] 432 | # print('First permut: ', permuted_indexes_i[0]) 433 | # print('Groupby: ', self.grids_params.group_by[1]) 434 | # print('End_of_turn_tag: ', self.grids_params.end_of_turn_tag[1]) 435 | 436 | # Read Params of source grids: if 'group_by' : 'DAspan' (one more layer of mapping) or 'turns' (you can get it directly) 437 | if self.grids_params.group_by[1] != 'turns': 438 | dialogue_i = corpus_dct.get(grid_i_name) 439 | group_maps = self.get_turns_to_da_map(dialogue_i, grid_i_name) # list of turns span 440 | # print('Group maps: ', group_maps) 441 | # permuted_files[grid_i_name] = self.map_index_shuffles_to_grids_das(permuted_indexes_i, grid_i, group_maps, df=df) 442 | permuted_files[grid_i_name] = self.map_index_shuffles_to_grids_das_veryfast(permuted_indexes_i, grid_i, group_maps) 443 | # print(assert_frame_equal(self.map_index_shuffles_to_grids_das_veryfast(permuted_indexes_i, grid_i, group_maps)[0], 444 | # permuted_files[grid_i_name][0], check_dtype=False)) 445 | 446 | # y_row_ind = 1 # Index of row to check 447 | # permut_i = 0 # Permutation number 448 | # self.test_correspondance(y_row_ind, permut_i, permuted_indexes_i, dialogue_i, 449 | 
# group_maps, grid_i, grid_i_name, permuted_files) 450 | 451 | else: 452 | 453 | # print('Orig grid: ', grid_i.columns) 454 | # print('Orig grid: ', grid_i.iloc[0]) 455 | permuted_files[grid_i_name] = self.map_index_shuffles_to_grids_veryfast(permuted_indexes_i, grid_i) 456 | # print(assert_frame_equal(self.map_index_shuffles_to_grids_veryfast(permuted_indexes_i, grid_i)[0], permuted_files[grid_i_name][0], check_dtype=False)) 457 | 458 | # print('First permuted grid type: ', type(permuted_files[grid_i_name][0])) 459 | # print('First permut, first row index: ', permuted_indexes_i[0][0]) 460 | # print('First permuted grid first list: ', permuted_files[grid_i_name][0][0]) 461 | # print('First permut, first row index: ', permuted_indexes_i[0][1]) 462 | # print('First permuted grid first list: ', permuted_files[grid_i_name][0][1]) 463 | # print('Original grid row ', permuted_indexes_i[0][1], ' :', grid_i.iloc[permuted_indexes_i[0][1]]) 464 | 465 | # break 466 | 467 | # print('Permutation 0 indexes: ', permuted_indexes_i[0]) 468 | 469 | if return_originals: 470 | return permuted_files, original_files 471 | else: 472 | return permuted_files 473 | 474 | def generate_grids_for_insertion(self, folder_name='shuffled', corpus_name ='Switchboard', 475 | folder_path='data/', corpus_dct=None, 476 | only_grids = None, df=False, saliency=1, return_originals=False): 477 | 478 | shuff_path = folder_path + corpus_name + '/' + folder_name + '/' 479 | 480 | permuted_files = {} 481 | grid_names = [x for x in self.grids if not re.match(r'.+\_s[0-9][0-9]*', x) and x!='Params'] 482 | # print('Grid names: ', grid_names) 483 | 484 | if return_originals: 485 | original_files = {} 486 | 487 | 488 | if only_grids is not None: 489 | grid_names = [x for x in grid_names if x in only_grids and x!='.DS_Store'] 490 | 491 | print('Len grids to permute: ', len(grid_names)) 492 | self.update_grids_dct(grid_names) 493 | 494 | # Permute already generated grids according to shuffled indexes order 495 | for grid_i_name in tqdm.tqdm(grid_names): 496 | # print('Grid id: ', grid_i_name) 497 | grid_i = self.grids.get(grid_i_name) 498 | 499 | # Check saliency 500 | if saliency>1: 501 | grid_i.drop([col for col in grid_i if len([i for i in grid_i[col] if i != '_']) < saliency], axis=1) 502 | 503 | if return_originals: 504 | original_files[grid_i_name] = grid_i 505 | 506 | shuffled_indexes_i = pd.read_csv(shuff_path+grid_i_name+'.csv', header=None, engine="c").T 507 | 508 | # Get permutations 509 | permuted_indexes_i = [[ind for ind in shuffled_indexes_i[col]] for col in shuffled_indexes_i.columns] 510 | 511 | # Read Params of source grids: if 'group_by' : 'DAspan' (one more layer of mapping) or 'turns' (you can get it directly) 512 | if self.grids_params.group_by[1] != 'turns': 513 | dialogue_i = corpus_dct.get(grid_i_name) 514 | group_maps = self.get_turns_to_da_map(dialogue_i, grid_i_name) # list of turns span 515 | # print('Group maps: ', group_maps) 516 | 517 | # turns_number, grid, group_maps 518 | permuted_files[grid_i_name] = self.map_index_reinsertion_to_grids_das_veryfast(permuted_indexes_i, grid_i, group_maps) 519 | 520 | # y_row_ind = 1 # Index of row to check 521 | # permut_i = 0 # Permutation number 522 | # self.test_correspondance(y_row_ind, permut_i, permuted_indexes_i, dialogue_i, 523 | # group_maps, grid_i, grid_i_name, permuted_files) 524 | 525 | else: 526 | 527 | permuted_files[grid_i_name] = self.map_index_reinsertion_to_grids_veryfast(grid_i) 528 | 529 | 530 | # print('Permutation 0 indexes: ', permuted_indexes_i[0]) 
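# Hedged usage sketch (commented out, not executed here): how the insertion-task
# permutations returned below are typically requested, mirroring the way
# EntitiesFeatureExtractor in train_models.py calls this method. The corpus name
# ('Oasis') and paths are illustrative assumptions, and the shuffle index files
# under data/<corpus>/shuffled/ are assumed to exist already
# (see create_shuffle_index_files above).
#
# from load_grids import GridLoader
# from generate_grid import GridGenerator, get_corpus
#
# corpus_name = 'Oasis'
# grids_path = 'data/' + corpus_name + '/egrid_-coref/'
# corpus_dct, _ = get_corpus(corpus_name)
# shuffler = GridShuffler(grids_path,
#                         grid_loader=GridLoader(grids_path),
#                         grid_generator=GridGenerator(coref='no_coref', nlp='no_nlp'))
# permuted = shuffler.generate_grids_for_insertion(corpus_dct=corpus_dct,
#                                                  corpus_name=corpus_name)
# # 'permuted' maps each dialogue id to a list with one entry per removed turn,
# # each entry being a list of grids with that turn reinserted at random positions.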
531 | 532 | if return_originals: 533 | return permuted_files, original_files 534 | else: 535 | return permuted_files 536 | 537 | 538 | def write_shuffled_grids(self, permuted_files, grids_path): 539 | 540 | print('Writing shuffled grid to: ', grids_path) 541 | for grid_i, permuted_files_i in tqdm.tqdm(iteritems(permuted_files)): 542 | # print('Grid: ', grid_i) 543 | # print('Perm files: ', len(permuted_files_i)) 544 | for perm_i, perm_file in enumerate(permuted_files_i): 545 | # print("Writing ", grids_path+grid_i+"_s"+str(perm_i)) 546 | # print('Test: ') 547 | # for i in range(1, len(perm_file)): 548 | # print('N:', perm_file[i].name, [y for y in perm_file[i]]) 549 | dump_csv(grids_path+grid_i+"_s"+str(perm_i), perm_file) 550 | # break 551 | 552 | return 553 | 554 | 555 | def parse(): 556 | parser = argparse.ArgumentParser(description='Shuffle generator') 557 | parser.add_argument('-gs', '--generate_shuffle', default='Oasis', help='Generate shuffle') 558 | parser.add_argument('-m', '--grid_mode', default='egrid_-coref', help='Grid mode') 559 | parser.add_argument('-sn', '--shuffles_number', default=20, help='Number of shuffles') 560 | parser.add_argument('-rr', '--rewrite_new', default=True, help='Overwrite shuffle indexes') 561 | args = parser.parse_args() 562 | return args 563 | 564 | 565 | def run(args): 566 | corpus_name, grid_mode, shuffles_number, rewrite_new = args.generate_shuffle, args.grid_mode, \ 567 | args.shuffles_number, args.rewrite_new 568 | 569 | if args.generate_shuffle: 570 | grids_path = 'data/' + corpus_name + '/' + grid_mode + '/' 571 | corpus_dct, _ = get_corpus(corpus_name) 572 | grid_loader = GridLoader(grids_path) 573 | grid_generator = GridGenerator(coref='no_coref', 574 | nlp='no_nlp') 575 | grid_shuffler = GridShuffler(grids_path, 576 | grid_loader=grid_loader, 577 | grid_generator=grid_generator) 578 | grid_shuffler.create_shuffle_index_files(corpus_dct, 579 | corpus_name=corpus_name, 580 | shuffles_number=shuffles_number, 581 | grids_path=grids_path, 582 | rewrite_new=rewrite_new) 583 | 584 | 585 | # def main(): 586 | # corpus_name = 'Oasis' 587 | # 588 | # # swda = Switchboard('../../Datasets/Switchboard/data/switchboard1-release2/') 589 | # corpus_dct, corpus_loader = get_corpus(corpus_name) 590 | # # print('Corpus files number: ', len(corpus_dct)) 591 | # 592 | # # grids_path = 'data/Switchboard/egrid_-coref/' 593 | # grids_path = 'data/'+corpus_name+'/egrid_-coref/' 594 | # 595 | # grid_loader = GridLoader(grids_path) 596 | # grid_generator = GridGenerator(coref='no_coref', nlp='no_nlp') 597 | # grid_shuffler = GridShuffler(grids_path, grid_loader=grid_loader, grid_generator=grid_generator) 598 | # 599 | # 600 | # # Testing mode 601 | # # corpus_dct = {k:v for k,v in corpus_dct.iteritems() if k in ['sw_0915_3624.utt']} 602 | # 603 | # grid_shuffler.create_shuffle_index_files(corpus_dct, corpus_name =corpus_name, shuffles_number=20, grids_path=grids_path) 604 | # 605 | # # permuted_files = grid_shuffler.generate_shuffled_grids(corpus_dct=corpus_dct) 606 | # # grid_shuffler.write_shuffled_grids(permuted_files, shuff_path) 607 | 608 | 609 | if __name__ == '__main__': 610 | args = parse() 611 | run(args) 612 | 613 | 614 | 615 | -------------------------------------------------------------------------------- /load_grids.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from numpy import cumsum 3 | # from generate_grid import dump_csv 4 | import argparse 5 | import logging 6 | 
import os 7 | import pandas as pd 8 | 9 | 10 | class GridLoader(object): 11 | 12 | ''' Class to load grids (together with set-up params) ''' 13 | 14 | def __init__(self, grid_folder=None): 15 | try: 16 | assert os.path.exists(grid_folder) # folder exists 17 | assert os.path.exists(grid_folder + "Params.csv") 18 | except AssertionError: 19 | print("The folder " + grid_folder + " does not exist.") 20 | exit(1) 21 | try: 22 | assert os.path.exists(grid_folder + "Params.csv") 23 | except AssertionError: 24 | print("Params file missing in folder " + grid_folder) 25 | exit(1) 26 | self.grid_folder = grid_folder 27 | self.grids = [] 28 | self.params = [] 29 | 30 | def reformat_params(self, params): 31 | # params_tups = zip(params[params.columns[0]], params[params.columns[1]]) 32 | params.columns = params.iloc[0] 33 | params.drop(params.index[0], inplace=True) 34 | # params.reindex(params.index.drop(1)) 35 | return params 36 | 37 | def load_data(self, df=True): 38 | files = [f for f in os.listdir(self.grid_folder) 39 | if os.path.isfile(os.path.join(self.grid_folder, f))] 40 | grids = {f.split('.csv')[0]: pd.read_csv(os.path.join(self.grid_folder, f), engine="c") for f in files 41 | if f not in ['Params.csv', '.DS_Store.csv']} 42 | params = pd.read_csv(os.path.join(self.grid_folder, 'Params.csv'), header=None, engine="c").T 43 | params = self.reformat_params(params) 44 | # params = grids['Params'].T 45 | # grids.pop('Params') 46 | self.grids = grids 47 | self.params = params 48 | return grids, params 49 | 50 | def get_data(self): 51 | return self.grids, self.params 52 | 53 | def percentage_split(self, seq, percentages): 54 | # TODO: acknowledge 55 | percentages = list(percentages) 56 | cdf = cumsum(percentages) 57 | assert cdf[-1] == 1.0 58 | stops = list(map(int, cdf * len(seq))) 59 | return [seq[a:b] for a, b in zip([0] + stops, stops)] 60 | 61 | def split_dataset(self, dataset, split_percentages=(0.2, 0.8)): 62 | # shuffle(dataset) 63 | print("All dataset size: ", len(dataset)) 64 | test, training_all = self.percentage_split(dataset, split_percentages) 65 | print("Test size: ", len(test)) 66 | print("All training data size: ", len(training_all)) 67 | validation, training = self.percentage_split(training_all, split_percentages) 68 | print("Strict training size: ", len(training)) 69 | print("Validation size: ", len(validation)) 70 | return training, validation, test 71 | 72 | def get_training_test_splits(self, data_folder='data/', 73 | corpus_name = 'Switchboard', 74 | default_grids_folder='egrid_-coref/', 75 | split_filename ="Train_Validation_Test_split"): 76 | 77 | # It returns training/validation/test splits if available, 78 | # otherwise it generates new splits 79 | 80 | splits = None 81 | grids_folder = data_folder + corpus_name + '/' + default_grids_folder 82 | try: 83 | assert os.path.exists(data_folder + corpus_name + '/' + split_filename + ".csv") 84 | print('Using already available train dev test split') 85 | splits = pd.read_csv(data_folder + corpus_name + '/' + split_filename + ".csv", dtype=str, keep_default_na=False) 86 | except AssertionError: 87 | print('No training/val/test splits already available') 88 | try: 89 | assert os.path.exists(grids_folder) 90 | dataset = [x.rstrip('.csv') for x in os.listdir(grids_folder) if x not in ['Params.csv', '.DS_Store.csv']] 91 | training, validation, test = self.split_dataset(dataset) 92 | 93 | splits = pd.DataFrame(dict([(k, pd.Series(v)) 94 | for k, v in [('training', training), 95 | ('validation', validation), 96 | ('test', test)]])) 97 | 
# splits = pd.DataFrame(dict([(k, v) 98 | # for k, v in [('training', training), 99 | # ('validation', validation), 100 | # ('test', test)]])) 101 | 102 | splits.to_csv(path_or_buf=data_folder + corpus_name + '/' + split_filename + '.csv') 103 | # print('Test: ', test) 104 | # print('Check test split : ', dataset.index(test[-1])) 105 | # print('Check val split : ', dataset.index(validation[-1])) 106 | # print('Check train split : ', dataset.index(training[-1])) 107 | 108 | except AssertionError: 109 | print("Missing data in folder " + grids_folder) 110 | exit(1) 111 | 112 | return splits 113 | 114 | def main(): 115 | grids_path = 'data/egrid_-coref/' 116 | grid_loader = GridLoader(grids_path) 117 | # print(grid_loader.load_data()) 118 | print('Splits: ', grid_loader.get_training_test_splits(corpus_name='Oasis').shape) 119 | 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /train_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from future.utils import iteritems 3 | from builtins import dict 4 | from load_grids import GridLoader 5 | # from sklearn.feature_extraction.text import CountVectorizer 6 | from corpus.Switchboard.Switchboard import Switchboard 7 | from generate_grid import corpora_paths, get_corpus 8 | from generate_shuffled import GridShuffler 9 | from itertools import permutations 10 | from collections import defaultdict 11 | # from pysofia import svm_train, svm_predict, learner_type, loop_type, eta_type 12 | from sklearn.metrics import accuracy_score 13 | from numpy.random import shuffle 14 | from sklearn import metrics 15 | # import pysofia 16 | import tqdm 17 | import argparse 18 | # import svmlight 19 | import logging 20 | import numpy as np 21 | import pandas as pd 22 | import os 23 | import errno 24 | 25 | 26 | class EntityGridLM(object): 27 | '''Class to train models and save experiments''' 28 | 29 | def __init__(self, grid_folder=None, grid_loader=None): 30 | try: 31 | assert os.path.exists(grid_folder) 32 | except AssertionError: 33 | print("The folder " + grid_folder + " does not exist.") 34 | exit(1) 35 | self.grid_folder = grid_folder 36 | self.grid_loader = grid_loader(self.grid_folder) 37 | 38 | def score_document_coherence(self): 39 | m = 2 # columns_num (entities) 40 | n = 3 # column_length (text spans) 41 | normalizer = 1/(m*n) 42 | document_coherence = 0 43 | 44 | all_entities_prob = [] 45 | 46 | # For each entity 47 | for columns_num in range(len(m)): 48 | single_entity_prob = [] 49 | 50 | # For each span 51 | for column_len in range(len(n)): 52 | # get ngrams for that span 53 | span_ngrams = [] 54 | single_entity_prob.append() 55 | 56 | # Sum log probs of role transitions of that entity 57 | all_entities_prob.append(sum(single_entity_prob)) 58 | 59 | # Sum log probs of all entities 60 | document_score = normalizer*sum(all_entities_prob) 61 | 62 | raise NotImplementedError 63 | 64 | 65 | class EntitiesFeatureExtractor(object): 66 | 67 | def __init__(self, grid_loader=None, grid_folder=None): 68 | 69 | if grid_loader: 70 | self.grid_loader = grid_loader 71 | else: 72 | try: 73 | assert os.path.exists(grid_folder) 74 | except AssertionError: 75 | print("The folder " + grid_folder + " does not exist.") 76 | exit(1) 77 | self.grid_loader = GridLoader(grid_folder) 78 | 79 | self.grids, self.grids_params = self.grid_loader.get_data() 80 | if not self.grids: 81 | 
self.grids, self.grids_params = self.grid_loader.load_data() 82 | self.vocabulary = self.get_vocabulary() 83 | self.grid_shuffler = GridShuffler(grid_folder=grid_folder, grid_loader=grid_loader) 84 | 85 | def update_grids_dct(self, grid_names): 86 | self.grids = {k:v for k,v in iteritems(self.grids) if k in grid_names} 87 | 88 | 89 | def get_transitions_probs_for_grid(self, trans2id, grid_i, transitions_count, transition_range, logprobs=True): 90 | dummy_value = float(10 ** -10) 91 | transition_len = transition_range[1] 92 | transitions_grid_i = self.count_transitions_in_grid_fast(grid_i, transition_len, trans2id) 93 | probs_grid_i = self.get_probs_normalized_by_trans_of_len(transitions_grid_i, transitions_count) 94 | 95 | # print('N-Transitions range: ', transitions_count) 96 | # freq_transitions_count = {k: v for k, v in transitions_grid_i.iteritems() if v > 0} 97 | # print('Transitions count: ', freq_transitions_count) 98 | # print('Probabilities: ') 99 | # print(self.show_probs({k: v for k, v in probs_grid_i.iteritems() if k in freq_transitions_count})) 100 | 101 | if logprobs is True: 102 | logprobs_grid_i = {k: (np.log(v) if v > 0 else np.log(dummy_value)) for k, v in iteritems(probs_grid_i)} 103 | # print('Log Probabilities: ') 104 | # print(self.show_probs({k: v for k, v in logprobs_grid_i.iteritems() if k in freq_transitions_count})) 105 | return logprobs_grid_i 106 | else: 107 | return probs_grid_i 108 | 109 | 110 | 111 | 112 | def extract_transitions_probs(self, 113 | corpus_dct=None, 114 | transition_range=(2, 2), 115 | saliency=1, 116 | logprobs=True, 117 | corpus_name='Switchboard', 118 | task='reordering'): 119 | ''' Returns a dict with structure 120 | {grid_name: [orig_probs, perm1_probs, perm2_probs]} 121 | where orig_probs is a dict ''' 122 | 123 | # TODO: Add and tokens? 
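# Toy illustration (a hedged sketch, not part of the pipeline) of the per-grid
# quantity computed below for the default transition length of 2: role transitions
# are counted column by column (cf. count_transitions_in_grid_fast) and normalized
# by the total number of length-2 transitions, (rows - 1) * columns
# (cf. get_total_numb_trans_given). The four-symbol role vocabulary and the entity
# names are made up for the example.
#
# import pandas as pd
# toy = pd.DataFrame({'map': ['S', 'O', '_'], 'river': ['_', 'X', 'X']})
# trans2id = {t: i for i, t in enumerate([(a, b) for a in 'SOX_' for b in 'SOX_'])}
# counts = {}
# for ent in toy.T.values.tolist():              # one list of roles per entity column
#     for i in range(toy.shape[0] - 1):          # every length-2 window in the column
#         key = trans2id[tuple(ent[i:i + 2])]
#         counts[key] = counts.get(key, 0) + 1
# n_trans = (toy.shape[0] - 1) * toy.shape[1]    # (3 - 1) * 2 = 4 transitions in total
# probs = {k: float(v) / n_trans for k, v in counts.items()}
# # Each observed transition, ('S','O'), ('O','_'), ('_','X') and ('X','X'),
# # gets probability 0.25; unobserved transitions are simply absent, just as with
# # the defaultdict counts used below, and get_transitions_probs_for_grid later
# # moves these probabilities into log space.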
124 | 125 | print('Params: ') 126 | print(self.grids_params) 127 | print('Vocab: ', self.vocabulary) 128 | print('Vocab size: ', len(self.vocabulary)) 129 | 130 | all_combs = self.generate_combinations(self.vocabulary, transition_range) 131 | # all_combs_str = [''.join(tag for tag in comb) for comb in all_combs] 132 | all_combs_str = [tuple(tag for tag in comb) for comb in all_combs] 133 | trans2id = {x:i for i, x in enumerate(all_combs_str)} 134 | grids_transitions_dict = {} 135 | 136 | # grid_names = ['sw_0001_4325.utt', 'sw_0002_4330.utt', 'sw_0003_4103.utt'] # Testing mode 137 | # grid_names = ['sw_0755_3018.utt'] 138 | grid_names = [n for n in corpus_dct.keys() if n!='.DS_Store'] #[:10] # Testing mode 139 | self.update_grids_dct(grid_names) 140 | 141 | # print('Grids keys : ', grid_names) 142 | # permuted_files = self.grid_shuffler.generate_shuffled_grids(corpus_dct=corpus_dct, only_grids=grid_names, df=True) 143 | 144 | if task=='reordering': 145 | permuted_files = self.grid_shuffler.generate_shuffled_grids(corpus_dct=corpus_dct, only_grids=grid_names, 146 | corpus_name=corpus_name, 147 | saliency=saliency, df=False) 148 | else: 149 | permuted_files = self.grid_shuffler.generate_grids_for_insertion(corpus_dct=corpus_dct, 150 | only_grids=grid_names, 151 | corpus_name=corpus_name, 152 | saliency=saliency, df=False) 153 | 154 | print('Permutation files len: ', len(permuted_files)) 155 | print('First permut len: ', len(permuted_files[grid_names[0]])) 156 | # print('First permut example shape: ', permuted_files[grid_names[0]][0].shape) 157 | 158 | # Compute probs per grid 159 | for grid_i_name in tqdm.tqdm(grid_names): 160 | # print('Grid id: ', grid_i_name) 161 | 162 | # Original order for grid_i 163 | grid_i = self.grids.get(grid_i_name) 164 | 165 | # Check saliency and modify grid accordingly (permuted files were already generated according to saliency) 166 | if saliency>1: 167 | grid_i.drop([col for col in grid_i if len([i for i in grid_i[col] if i != '_']) < saliency], axis=1) 168 | 169 | # # Short example 170 | # grid_i = pd.DataFrame({i: (grid_i[i][:6]) for ind, i in enumerate(grid_i) if ind < 15}) 171 | # print(grid_i) 172 | 173 | permutations_i = permuted_files[grid_i_name] 174 | transitions_count = self.get_total_numb_trans_given(grid_i, transition_range) 175 | 176 | # The first probs distribution in probs_grid_i is the original one 177 | if task=='reordering': 178 | grids_transitions_dict[grid_i_name] = [self.get_transitions_probs_for_grid(trans2id, grid_ij, 179 | transitions_count, transition_range, 180 | logprobs=logprobs) 181 | for grid_ij in [grid_i]+permutations_i] 182 | else: 183 | original = [self.get_transitions_probs_for_grid(trans2id, grid_i, transitions_count, 184 | transition_range, logprobs=logprobs)] 185 | reinserted = [[self.get_transitions_probs_for_grid(trans2id, grid_jy, 186 | transitions_count, transition_range,logprobs=logprobs) for grid_jy in sent_ind_j] 187 | for sent_ind_j in permutations_i] 188 | grids_transitions_dict[grid_i_name] = original+reinserted # TODO: for reinsertion we need one iter per turn removed 189 | 190 | self.grids.pop(grid_i_name) 191 | 192 | return grids_transitions_dict 193 | 194 | 195 | def show_probs(self, probs_grid_i, top=30): 196 | 197 | for i, x in enumerate(sorted(probs_grid_i, key=lambda x: probs_grid_i[x])): 198 | print(probs_grid_i[x], " : ", x) 199 | if i==top: 200 | break 201 | 202 | def get_total_numb_trans_given(self, grid_i, transition_range): 203 | n, m = grid_i.shape 204 | transition_len = transition_range[1] 205 | 
transitions_number_per_entity = n-(transition_len-1) 206 | transitions_number = transitions_number_per_entity*m 207 | # print('Column length: ', n, " 'Columns number: ", m) 208 | # print('Total number of transitions of length ', transition_len, " : ", transitions_number) 209 | return transitions_number 210 | 211 | def get_column_headers(self, grid): 212 | return list(grid.columns.values) 213 | 214 | def get_probs_normalized_by_trans_of_len(self, transitions_in_grid, transitions_count): 215 | # 1 - Get probability normalizing per grid counts / number of transitions of len=transition range 216 | return {comb: np.divide(float(count), float(transitions_count)) for comb, count in iteritems(transitions_in_grid)} 217 | 218 | 219 | def get_vocabulary(self): 220 | return set([role for grid_i in self.grids.values() for entity in grid_i for role in grid_i[entity]]) 221 | 222 | def generate_combinations(self, voc, transition_range): 223 | min_transition, max_transition = transition_range 224 | doubles = [(tag, tag) for tag in voc] 225 | all_combs = [comb for i in range(min_transition, max_transition+1) for comb in permutations(voc, i)] + doubles 226 | return all_combs 227 | 228 | 229 | def count_transitions_in_grid_per_entity(self, all_combs_str, grid, saliency=1): 230 | ''' Returns a list of dictionaries, where a dictionary contains transitions counts for one entity ''' 231 | 232 | return [{comb: ''.join(x for x in grid[entity_col]).count(comb) for comb in all_combs_str} 233 | for entity_col in grid] 234 | 235 | 236 | def count_transitions_in_grid(self, all_combs_str, grid): 237 | ''' Returns a list of dictionaries, where a dictionary contains transitions counts for one entity ''' 238 | 239 | # all_entities_transitions_in_grid = {comb: ''.join(x for x in grid[entity_col]).count(comb) for comb in all_combs_str 240 | # for entity_col in grid} 241 | all_entities_transitions_in_grid = {comb: 0 for comb in all_combs_str} 242 | for entity_col in grid: 243 | entity_seq = ''.join(x for x in grid[entity_col]) 244 | for comb in all_combs_str: 245 | all_entities_transitions_in_grid[comb] = entity_seq.count(comb) + all_entities_transitions_in_grid[comb] 246 | # print("Len combs: ", len(all_entities_transitions_in_grid)) 247 | return all_entities_transitions_in_grid 248 | 249 | def count_transitions_in_grid_fast_old(self, all_combs_str, grid, trans_range): 250 | ''' Returns a list of dictionaries, where a dictionary contains transitions counts for one entity ''' 251 | 252 | # all_entities_transitions_in_grid = {comb: ''.join(x for x in grid[entity_col]).count(comb) for comb in all_combs_str 253 | # for entity_col in grid} 254 | # all_entities_transitions_in_grid = {comb: 0 for comb in all_combs_str} 255 | n, _ = grid.shape # Column len, number 256 | count_comb = {comb: 0 for comb in all_combs_str} 257 | for j in grid: 258 | tmp = grid[j].tolist() 259 | for i in range(n - trans_range + 1): 260 | count_comb[tuple(tmp[i:i + trans_range])] += 1 261 | return count_comb 262 | 263 | def count_transitions_in_grid_fast(self, grid, trans_range, trans2id): 264 | ''' Returns a list of dictionaries, where a dictionary contains transitions counts for one entity ''' 265 | 266 | # all_entities_transitions_in_grid = {comb: ''.join(x for x in grid[entity_col]).count(comb) for comb in all_combs_str 267 | # for entity_col in grid} 268 | # all_entities_transitions_in_grid = {comb: 0 for comb in all_combs_str} 269 | 270 | # trans2id = {('S','O'):1, ('-','-'):2, ...} 271 | n, _ = grid.shape # Column len, number 272 | count_comb = 
defaultdict(int) 273 | tmp = grid.T.values.tolist() 274 | for ent in tmp: 275 | for i in range(n - trans_range + 1): 276 | count_comb[trans2id[tuple(ent[i:i + trans_range])]] += 1 277 | return count_comb 278 | 279 | 280 | def featurize_transitions_dct(self, transitions_dict): 281 | X, y, blocks = [], [], [] 282 | block_i = 0 283 | for grid_i_name, grids_i in iteritems(transitions_dict): 284 | # print('Grid name: ', grid_i_name) 285 | 286 | for j, grid_ij in enumerate(grids_i): 287 | X.append(np.asarray([grid_ij[k] for k in sorted(grid_ij.keys())])) 288 | y_ij = 0 if j==0 else 1 289 | y.append(y_ij) 290 | blocks.append(block_i) 291 | block_i += 1 292 | 293 | return np.asarray(X), np.asarray(y), np.asarray(blocks) 294 | 295 | def featurize_transitions_dct_svmlightformat_py(self, transitions_dict): 296 | # (