├── t5 ├── __init__.py ├── predictor.py ├── pretrained_cnndmail_dataset_reader.py └── pretrained_model.py ├── utils ├── __init__.py └── section_names.py ├── pointergen ├── __init__.py ├── custom_instance.py ├── beam_search_predictor.py ├── fields.py ├── cnndmail_dataset_reader.py ├── conditioned_cnndmail_dataset_reader.py ├── conditioned_model.py └── conditioned_model_with_coverage.py ├── sequential_sentence_tagger ├── __init__.py ├── multilabel_predictor.py ├── unilabel_predictor.py ├── sentseq_dataset_reader.py ├── model_hilstm.py ├── sentseq_binary_dataset_reader.py └── all_classification_metrics.py ├── requirements.txt ├── prepare_data.sh ├── README.md ├── dataset_creators └── AMI │ ├── resolve_timestamp_ties.py │ ├── xmin_dataset_creator.py │ ├── clash_resolutions.json │ ├── variant_tasks_creator.py │ └── create_dataset.py ├── show_rouge_scores.py ├── make_predicted_xmin_datasets.py ├── predict_ami.sh └── calculate_rouge.py /t5/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pointergen/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sequential_sentence_tagger/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/section_names.py: -------------------------------------------------------------------------------- 1 | ami_section_names = ["abstract", 2 | "actions", 3 | "decisions", 4 | "problems"] 5 | 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp==0.8.5 2 | torch==1.5.1 3 | transformers==3.0.2 4 | jsonlines==2.0.0 5 | gdown 6 | overrides==3.1.0 7 | pyrouge==0.1.3 8 | nlp==0.4.0 9 | rouge_score==0.0.4 -------------------------------------------------------------------------------- /prepare_data.sh: -------------------------------------------------------------------------------- 1 | mkdir rawAMI 2 | cd rawAMI 3 | wget http://groups.inf.ed.ac.uk/ami/AMICorpusAnnotations/ami_public_manual_1.6.2.zip 4 | unzip ami_public_manual_1.6.2.zip 5 | cd .. 
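 6 | 
 7 | # The four scripts below build the dataset_ami/ folder from the raw
 8 | # annotations: create_dataset.py parses the AMI annotation XML into
 9 | # train/val/test JSONL files, resolve_timestamp_ties.py fixes the ordering
10 | # of turns with clashing timestamps using clash_resolutions.json, and the
11 | # two *_creator.py scripts derive the task-specific datasets (task variants
12 | # and the xmin tagging sets).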
13 | 
14 | mkdir dataset_ami
15 | 
16 | cd dataset_creators/AMI
17 | python create_dataset.py
18 | python resolve_timestamp_ties.py
19 | python variant_tasks_creator.py
20 | python xmin_dataset_creator.py
21 | 
22 | 
23 | cd ../../
24 | 
25 | 
26 | 
--------------------------------------------------------------------------------
/sequential_sentence_tagger/multilabel_predictor.py:
--------------------------------------------------------------------------------
 1 | from allennlp.predictors.predictor import Predictor
 2 | from allennlp.models.model import Model
 3 | from allennlp.data.dataset_readers import DatasetReader
 4 | from allennlp.data import Instance
 5 | from overrides import overrides
 6 | from allennlp.common.util import JsonDict
 7 | 
 8 | 
 9 | 
10 | @Predictor.register("simple_multilabel_classifier")
11 | class MultilabelClassifierPredictor(Predictor):
12 |     def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
13 |         super().__init__(model, dataset_reader)
14 | 
15 |     @overrides
16 |     def _json_to_instance(self, json_dict: JsonDict) -> Instance:
17 |         """
18 |         Expects JSON that looks like ``{"article_lines": ["...", "...", ...], "labels": [[...], [...], ...]}``.
19 |         """
20 |         return self._dataset_reader.dict_to_instance(json_dict)
21 | 
22 |     @overrides
23 |     def predict_json(self, inputs: JsonDict) -> JsonDict:
24 |         instance = self._json_to_instance(inputs)
25 |         predicted = self.predict_instance(instance)
26 | 
27 |         to_return = {"input": inputs["article_lines"],
28 |                      "ground_truth": inputs["labels"],
29 |                      "prediction": predicted}
30 | 
31 |         for extra_key in ["case_id"]:
32 |             if extra_key in inputs.keys():
33 |                 to_return[extra_key] = inputs[extra_key]
34 | 
35 |         return to_return
36 | 
37 | 
38 | 
39 | 
--------------------------------------------------------------------------------
/sequential_sentence_tagger/unilabel_predictor.py:
--------------------------------------------------------------------------------
 1 | from allennlp.predictors.predictor import Predictor
 2 | from allennlp.models.model import Model
 3 | from allennlp.data.dataset_readers import DatasetReader
 4 | from allennlp.data import Instance
 5 | from overrides import overrides
 6 | from allennlp.common.util import JsonDict
 7 | 
 8 | import numpy as np
 9 | 
10 | 
11 | 
12 | 
13 | @Predictor.register("simple_unilabel_classifier")
14 | class UnilabelClassifierPredictor(Predictor):
15 |     def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
16 |         super().__init__(model, dataset_reader)
17 | 
18 |     @overrides
19 |     def _json_to_instance(self, json_dict: JsonDict) -> Instance:
20 |         """
21 |         Expects JSON that looks like ``{"article_lines": ["...", "...", ...], "labels": [[...], [...], ...]}``.
22 | """ 23 | return self._dataset_reader.dict_to_instance(json_dict) 24 | 25 | @overrides 26 | def predict_json(self, inputs: JsonDict) -> JsonDict: 27 | instance = self._json_to_instance(inputs) 28 | predicted = self.predict_instance(instance) 29 | 30 | to_return = {"input": inputs["article_lines"], 31 | "ground_truth": np.max(np.array(inputs["labels"]).astype(int), axis=1, keepdims=True).tolist(), 32 | "prediction": predicted} 33 | 34 | for extra_key in ["case_id"]: 35 | if extra_key in inputs.keys(): 36 | to_return[extra_key] = inputs[extra_key] 37 | 38 | return to_return 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | This repository contains the code for running modular summarization pipelines as described in the publication 3 | `Krishna K, Khosla K, Bigham J, Lipton ZC. Generating SOAP Notes from Doctor-Patient Conversations." ACL 2021.` 4 | 5 | The paper can be found here : https://aclanthology.org/2021.acl-long.384/ 6 | 7 | 8 | ### Instructions 9 | 10 | Although we can not release models trained on the confidential medical data, we have released models trained on the publicly available AMI dataset. 11 | To reproduce the results on the AMI dataset, you need to follow the steps listed below. 12 | For convenience, we have also created a Google Colab notebook [here](https://colab.research.google.com/drive/1P0dp4rctvhSWdzfgml4B7yt3Qketrzst?usp=sharing) that runs these steps on Google's servers (free-of-cost as of June 2021) and produces the summaries and their rouge scores. 13 | 14 | **Step1:** Set up the environment by installing the required packages mentioned in `requirements.txt` using pip. 15 | 16 | **Step2:** Download the `ami_models` folder from this [link](https://drive.google.com/drive/folders/12Fkv_JvhJotvDTk2Z5PZYs4POIkqmwBZ?usp=sharing) and put it at the root of the repository: 17 | 18 | 19 | **Step3:** Run the following 3 commands to prepare data, run summary generation pipelines, and show the achieved rouge scores. 20 | 21 | ```bash 22 | # command1: downloads and preprocesses AMI dataset 23 | ./prepare_data.sh 24 | 25 | # command2: runs the summarization pipelines on the data and computes rouge scores 26 | # (before running this command, you need to download the models as shown above) 27 | ./predict_ami.sh 28 | 29 | # command3: print the results 30 | python show_results.py 31 | ``` 32 | -------------------------------------------------------------------------------- /pointergen/custom_instance.py: -------------------------------------------------------------------------------- 1 | from allennlp.data import Instance 2 | from overrides import overrides 3 | from allennlp.data import Vocabulary 4 | from pointergen.fields import SourceTextField, TargetTextField 5 | 6 | class SyncedFieldsInstance(Instance): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @overrides 11 | def index_fields(self, vocab: Vocabulary) -> None: 12 | """ 13 | Indexes all fields in this ``Instance`` using the provided ``Vocabulary``. 14 | This `mutates` the current object, it does not return a new ``Instance``. 15 | A ``DataIterator`` will call this on each pass through a dataset; we use the ``indexed`` 16 | flag to make sure that indexing only happens once. 17 | 18 | This means that if for some reason you modify your vocabulary after you've 19 | indexed your instances, you might get unexpected behavior. 
20 | """ 21 | if not self.indexed: 22 | self.indexed = True 23 | all_fields = self.fields.values() 24 | source_fields = list(filter(lambda x:type(x)==SourceTextField, all_fields)) 25 | target_fields = list(filter(lambda x:type(x)==TargetTextField, all_fields)) 26 | 27 | assert (len(source_fields)==1), "There should be exactly one source fields because otherwise OOV indices would clash" 28 | for field in self.fields.values(): 29 | if type(field) not in [SourceTextField, TargetTextField]: 30 | field.index(vocab) 31 | 32 | source_field = source_fields[0] 33 | oov_list = source_field.index(vocab) 34 | self.oov_list = oov_list 35 | 36 | for target_field in target_fields: 37 | target_field.index(vocab, oov_list) 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /pointergen/beam_search_predictor.py: -------------------------------------------------------------------------------- 1 | from allennlp.predictors.predictor import Predictor 2 | from allennlp.models.model import Model 3 | from allennlp.data.dataset_readers import DatasetReader 4 | from allennlp.data import Instance 5 | from overrides import overrides 6 | from allennlp.common.util import JsonDict, sanitize 7 | 8 | 9 | class TextGen(Predictor): 10 | def __init__(self, model: Model, dataset_reader: DatasetReader, decode_strategy="greedy") -> None: 11 | super().__init__(model, dataset_reader) 12 | self.decode_strategy = decode_strategy 13 | 14 | @overrides 15 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 16 | """ 17 | Expects JSON that looks like ``{"article_lines": ["...", "...", ...], "summary_lines": ["...", "...", ...]}``. 18 | """ 19 | return self._dataset_reader.dict_to_instance(json_dict) 20 | 21 | @overrides 22 | def predict_instance(self, instance: Instance) -> JsonDict: 23 | outputs = self._model.forward_on_instance(instance, decode_strategy=self.decode_strategy) 24 | return sanitize(outputs) 25 | 26 | 27 | @overrides 28 | def predict_json(self, inputs: JsonDict) -> JsonDict: 29 | instance = self._json_to_instance(inputs) 30 | predicted = self.predict_instance(instance) 31 | ground_truth = " ".join(instance.fields["meta"]["target_tokens"]) 32 | 33 | to_return = {"input": inputs["article_lines"], 34 | "ground_truth": ground_truth, 35 | "prediction": predicted} 36 | 37 | for extra_key in ["case_id", "index_in_note", "section", "orig_section"]: 38 | if extra_key in inputs.keys(): 39 | to_return[extra_key] = inputs[extra_key] 40 | 41 | return to_return 42 | 43 | 44 | 45 | @Predictor.register("beamsearch_constrained") 46 | class Beamsearch(TextGen): 47 | def __init__(self, model: Model, dataset_reader: DatasetReader): 48 | super(Beamsearch, self).__init__(model, dataset_reader, decode_strategy="beamsearch_constrained") 49 | 50 | @Predictor.register("beamsearch") 51 | class BeamsearchConstrained(TextGen): 52 | def __init__(self, model: Model, dataset_reader: DatasetReader): 53 | super(BeamsearchConstrained, self).__init__(model, dataset_reader, decode_strategy="beamsearch_unconstrained") 54 | 55 | 56 | -------------------------------------------------------------------------------- /dataset_creators/AMI/resolve_timestamp_ties.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | from hashlib import md5 3 | from collections import defaultdict 4 | import json 5 | from copy import deepcopy 6 | 7 | 8 | ami_section_names = ["abstract", 9 | "actions", 10 | "decisions", 11 | "problems"] 12 | 13 | 14 | def get_utt(obj): 15 
--------------------------------------------------------------------------------
/dataset_creators/AMI/resolve_timestamp_ties.py:
--------------------------------------------------------------------------------
 1 | import jsonlines
 2 | from hashlib import md5
 3 | from collections import defaultdict
 4 | import json
 5 | from copy import deepcopy
 6 | 
 7 | 
 8 | ami_section_names = ["abstract",
 9 |                      "actions",
10 |                      "decisions",
11 |                      "problems"]
12 | 
13 | 
14 | def get_utt(obj):
15 |     return obj["speaker"]+" "+obj["txt"]
16 | 
17 | clash_resolve_dict = json.load(open("./clash_resolutions.json"))
18 | 
19 | def patch_file(fpath):
20 |     x=list(jsonlines.open(fpath))
21 | 
22 |     for x1 in x:
23 |         id1=x1["id"]
24 |         if id1 not in clash_resolve_dict:
25 |             continue
26 |         this_resolve_dict = clash_resolve_dict[id1]
27 |         trans1=x1["transcript"]
28 |         tstamps = list(x1["transcript"].keys())
29 |         tstamps = sorted(tstamps, key=float)
30 |         for (i,t) in enumerate(tstamps):
31 |             if t in this_resolve_dict:
32 |                 first_tstamp = tstamps[i]
33 |                 second_tstamp = tstamps[i+1]
34 |                 first_utt = get_utt(trans1[first_tstamp])
35 |                 second_utt = get_utt(trans1[second_tstamp])
36 |                 first_hash=md5(first_utt.encode("utf-8")).digest()
37 |                 second_hash=md5(second_utt.encode("utf-8")).digest()
38 |                 current_order = first_hash<second_hash
--------------------------------------------------------------------------------
/t5/predictor.py:
--------------------------------------------------------------------------------
 1 | from allennlp.predictors.predictor import Predictor
 2 | from allennlp.models.model import Model
 3 | from allennlp.data.dataset_readers import DatasetReader
 4 | from allennlp.data import Instance
 5 | from overrides import overrides
 6 | from allennlp.common.util import JsonDict, sanitize
 7 | from allennlp.common.util import START_SYMBOL, END_SYMBOL
 8 | 
 9 | from transformers import T5Tokenizer
10 | 
11 | 
12 | 
13 | 
14 | class TextGen(Predictor):
15 |     def __init__(self, model: Model, dataset_reader: DatasetReader, decode_strategy="greedy") -> None:
16 |         super().__init__(model, dataset_reader)
17 |         self.decode_strategy = decode_strategy
18 | 
19 |     @overrides
20 |     def _json_to_instance(self, json_dict: JsonDict) -> Instance:
21 |         """
22 |         Expects JSON that looks like ``{"article_lines": ["...", "...", ...], "summary_lines": ["...", "...", ...]}``.
23 |         """
24 |         return self._dataset_reader.dict_to_instance(json_dict)
25 | 
26 |     @overrides
27 |     def predict_instance(self, instance: Instance) -> JsonDict:
28 |         outputs = self._model.forward_on_instance(instance, decode_strategy=self.decode_strategy)
29 |         return sanitize(outputs)
30 | 
31 |     def _join_tokens(self, token_list):
32 |         return " ".join(token_list)
33 | 
34 |     @overrides
35 |     def predict_json(self, inputs: JsonDict) -> JsonDict:
36 |         instance = self._json_to_instance(inputs)
37 |         predicted = self.predict_instance(instance)
38 |         predicted_str = self._join_tokens(predicted["tokens"])
39 | 
40 |         ground_truth = self._join_tokens(instance.fields["meta"]["target_tokens"])
41 | 
42 |         to_return = {"input": inputs["article_lines"],
43 |                      "ground_truth": ground_truth,
44 |                      "prediction": predicted_str}
45 | 
46 |         for extra_key in inputs.keys():
47 |             if extra_key not in to_return:
48 |                 to_return[extra_key] = inputs[extra_key]
49 | 
50 |         return to_return
51 | 
52 | 
53 | 
54 | class TextGenWordpiece(TextGen):
55 |     def __init__(self, model: Model, dataset_reader: DatasetReader, decode_strategy="greedy") -> None:
56 |         super(TextGenWordpiece, self).__init__(model, dataset_reader, decode_strategy)
57 |         self._tokenizer = T5Tokenizer.from_pretrained("t5-base")
58 |         self.new_unk_token = self._tokenizer.unk_token
59 | 
60 |     def _replace_toks_inplace(self, toklist, from_tok, to_tok):
61 |         for _i, tok in enumerate(toklist):
62 |             if tok==from_tok:
63 |                 toklist[_i]=to_tok
64 | 
65 |     @overrides
66 |     def _join_tokens(self, token_list):
67 |         self._replace_toks_inplace(token_list, from_tok="@@UNKNOWN@@", to_tok=self.new_unk_token)
68 |         assert START_SYMBOL not in token_list
69 |         assert END_SYMBOL not in token_list
70 |         assert "@@PADDING@@" not in token_list
71 |         return self._tokenizer.convert_tokens_to_string(token_list)
72 | 
73 | 
74 | @Predictor.register("greedy")
75 | class SimpleGreedyWordpiece(TextGenWordpiece):
76 |     def __init__(self, model: Model, dataset_reader: DatasetReader):
77 |         super(SimpleGreedyWordpiece, self).__init__(model, dataset_reader, decode_strategy="greedy")
78 | 
79 | @Predictor.register("beamsearch")
80 | class BeamsearchWordpiece(TextGenWordpiece):
81 |     def __init__(self, model: Model, dataset_reader: DatasetReader):
82 |         super(BeamsearchWordpiece, self).__init__(model, dataset_reader, decode_strategy="beamsearch")
83 | 
84 | 
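Because these predictors are registered with AllenNLP, they can be loaded by name once their modules have been imported. A minimal sketch of how one might be invoked from Python (the archive path and inputs are illustrative assumptions; the commands the pipeline actually runs live in `predict_ami.sh`):

```python
from allennlp.predictors.predictor import Predictor

# Importing the modules triggers the @Predictor.register / @DatasetReader.register calls.
import t5.predictor  # noqa: F401
import t5.pretrained_model  # noqa: F401
import t5.pretrained_cnndmail_dataset_reader  # noqa: F401

# Assumed archive location inside the downloaded ami_models folder.
predictor = Predictor.from_path(
    "ami_models/t5_models/ser_allxmins_fullsummary_t5_small/model.tar.gz",
    predictor_name="beamsearch",
)

output = predictor.predict_json({
    "article_lines": ["A okay , let us get started .", "B right , the remote control ."],
    "summary_lines": ["the team discussed the remote control ."],
})
print(output["prediction"])
```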
--------------------------------------------------------------------------------
/show_rouge_scores.py:
--------------------------------------------------------------------------------
 1 | import pandas as pd
 2 | 
 3 | def add_result(fpath, method_name, dict_obj):
 4 |     df = pd.read_csv(fpath, sep="\t")
 5 |     dd = df.loc[0].to_dict()
 6 |     del dd['Unnamed: 0']
 7 |     dict_obj[method_name]=dd
 8 | 
 9 | dict_obj = {}
10 | 
11 | add_result(fpath="./ami_models/t5_models/ser_entrywise_conditioned_t5_base/test_outputs.jsonl.rougescores.csv",
12 |            method_name="CLUSTER2SENT+T5-BASE (ORACLE)",
13 |            dict_obj=dict_obj)
14 | 
15 | add_result(fpath="./ami_models/t5_models/ser_entrywise_conditioned_t5_small/test_outputs.jsonl.rougescores.csv",
16 |            method_name="CLUSTER2SENT+T5-SMALL (ORACLE)",
17 |            dict_obj=dict_obj)
18 | 
19 | add_result(fpath="./ami_models/t5_models/ser_sectionwise_conditioned_t5_small/test_outputs.jsonl.rougescores.csv",
20 |            method_name="EXT2SEC+T5-SMALL (ORACLE)",
21 |            dict_obj=dict_obj)
22 | 
23 | add_result(fpath="./ami_models/t5_models/ser_allxmins_fullsummary_t5_small/test_outputs.jsonl.rougescores.csv",
24 |            method_name="EXT2NOTE+T5-SMALL (ORACLE)",
25 |            dict_obj=dict_obj)
26 | 
27 | 
28 | 
29 | 
30 | 
31 | add_result(fpath="./ami_models/pg_models/ser_fullconversation_fullsummary/test_outputs.jsonl.rougescores.csv",
32 |            method_name="CONV2NOTE+PG (ORACLE)",
33 |            dict_obj=dict_obj)
34 | 
35 | add_result(fpath="./ami_models/pg_models/ser_allxmins_fullsummary/test_outputs.jsonl.rougescores.csv",
36 |            method_name="EXT2NOTE+PG (ORACLE)",
37 |            dict_obj=dict_obj)
38 | 
39 | add_result(fpath="./ami_models/pg_models/ser_sectionwise_allxmins_sectionsummary/test_outputs.jsonl.rougescores.csv",
40 |            method_name="EXT2SEC+PG (ORACLE)",
41 |            dict_obj=dict_obj)
42 | 
43 | add_result(fpath="./ami_models/pg_models/ser_entrywise_summarization/test_outputs.jsonl.rougescores.csv",
44 |            method_name="CLUSTER2SENT+PG (ORACLE)",
45 |            dict_obj=dict_obj)
46 | 
47 | add_result(fpath="./ami_models/t5_models/ser_entrywise_conditioned_t5_base/test_outputs_on_hlstm.jsonl.rougescores.csv",
48 |            method_name="CLUSTER2SENT+T5-BASE (HLSTM)",
49 |            dict_obj=dict_obj)
50 | 
51 | add_result(fpath="./ami_models/t5_models/ser_entrywise_conditioned_t5_small/test_outputs_on_hlstm.jsonl.rougescores.csv",
52 |            method_name="CLUSTER2SENT+T5-SMALL (HLSTM)",
53 |            dict_obj=dict_obj)
54 | 
55 | add_result(fpath="./ami_models/t5_models/ser_sectionwise_conditioned_t5_small/test_outputs_on_hlstm.jsonl.rougescores.csv",
56 |            method_name="EXT2SEC+T5-SMALL (HLSTM)",
57 |            dict_obj=dict_obj)
58 | 
59 | add_result(fpath="./ami_models/t5_models/ser_allxmins_fullsummary_t5_small/test_outputs_on_hlstm.jsonl.rougescores.csv",
60 |            method_name="EXT2NOTE+T5-SMALL (HLSTM)",
61 |            dict_obj=dict_obj)
62 | 
63 | add_result(fpath="./ami_models/pg_models/ser_allxmins_fullsummary/test_outputs_on_hlstm.jsonl.rougescores.csv",
64 |            method_name="EXT2NOTE+PG (HLSTM)",
65 |            dict_obj=dict_obj)
66 | 
67 | add_result(fpath="./ami_models/pg_models/ser_sectionwise_allxmins_sectionsummary/test_outputs_on_hlstm.jsonl.rougescores.csv",
68 |            method_name="EXT2SEC+PG (HLSTM)",
69 |            dict_obj=dict_obj)
70 | 
71 | add_result(fpath="./ami_models/pg_models/ser_entrywise_summarization/test_outputs_on_hlstm.jsonl.rougescores.csv",
72 |            method_name="CLUSTER2SENT+PG (HLSTM)",
73 |            dict_obj=dict_obj)
74 | 
75 | 
76 | df_complete = pd.DataFrame(dict_obj)
77 | df_complete = df_complete.transpose()
78 | df_complete = df_complete*100
79 | df_complete = df_complete.round(2)
80 | 
81 | print(df_complete)
82 | 
83 | 
84 | 
--------------------------------------------------------------------------------
/pointergen/fields.py:
--------------------------------------------------------------------------------
 1 | from allennlp.data.fields import TextField
 2 | from allennlp.data import Token, TokenIndexer
 3 | from overrides import overrides
 4 | from allennlp.data import Vocabulary
 5 | from allennlp.data.token_indexers.token_indexer import TokenIndexer, TokenType
 6 | from allennlp.data.token_indexers import SingleIdTokenIndexer
 7 | from copy import copy
 8 | 
 9 | from typing import List, Dict
10 | TokenList = List[TokenType]
11 | 
12 | 
13 | class SourceTextField(TextField):
14 |     def __init__(self, tokens: List[Token], token_indexers: Dict[str, TokenIndexer]) -> None:
15 |         assert(len(token_indexers)==1), "Only one indexer is allowed in a SourceTextField"
16 |         super().__init__(tokens, token_indexers)
17 | 
18 |     # NOTE: unlike the base TextField.index(), this returns the list of OOV
19 |     # words found in the source; SyncedFieldsInstance passes that list to each
20 |     # TargetTextField so that source and target share one extended vocabulary.
21 | 
22 |     def index(self, vocab: Vocabulary) -> List[str]:
23 |         token_arrays: Dict[str, TokenList] = {}
24 |         indexer_name_to_indexed_token: Dict[str, List[str]] = {}
25 |         token_index_to_indexer_name: Dict[str, str] = {}
26 | 
27 |         for indexer_name, indexer in self._token_indexers.items():
28 |             assert type(indexer)==SingleIdTokenIndexer, "The indexer must be a SingleIdTokenIndexer"
29 |             token_indices = indexer.tokens_to_indices(self.tokens, vocab, indexer_name)
30 | 
31 |             oovs_list : List[str] = []
32 |             for key, val in token_indices.items():
33 |                 # key is a string, val is an array of ints
34 |                 oov_id = vocab._token_to_index[indexer.namespace][vocab._oov_token]
35 | 
36 |                 ids_with_unks : List[int] = val
37 |                 ids_with_oovs : List[int] = []
38 | 
39 |                 for _id, word in zip(ids_with_unks, self.tokens):
40 |                     if _id == oov_id:
41 |                         if word.text not in oovs_list:
42 |                             oovs_list.append(word.text)
43 |                         ids_with_oovs.append(vocab.get_vocab_size(indexer.namespace) + oovs_list.index(word.text))
44 |                     else:
45 |                         ids_with_oovs.append(_id)
46 | 
47 |                 token_arrays.update({
48 |                     "ids_with_unks": ids_with_unks,
49 |                     "ids_with_oovs": ids_with_oovs,
50 |                     "num_oovs": [len(oovs_list)]
51 |                 })
52 |             indexer_name_to_indexed_token[indexer_name] = ["ids_with_unks", "ids_with_oovs", "num_oovs"]
53 |             token_index_to_indexer_name["ids_with_unks"] = indexer_name
54 |             token_index_to_indexer_name["ids_with_oovs"] = indexer_name
55 |             token_index_to_indexer_name["num_oovs"] = indexer_name
56 | 
57 | 
58 |         self._indexed_tokens = token_arrays
59 |         self._indexer_name_to_indexed_token = indexer_name_to_indexed_token
60 |         self._token_index_to_indexer_name = token_index_to_indexer_name
61 |         self._oovs = oovs_list
62 | 
63 |         return self._oovs
64 | 
65 | class TargetTextField(TextField):
66 |     def __init__(self, tokens: List[Token], token_indexers: Dict[str, TokenIndexer]) -> None:
67 |         super().__init__(tokens, token_indexers)
68 | 
69 |     @overrides
70 |     def index(self, vocab: Vocabulary, oovs_list: TokenList):
71 |         token_arrays: Dict[str, TokenList] = {}
72 |         indexer_name_to_indexed_token: Dict[str, List[str]] = {}
73 |         token_index_to_indexer_name: Dict[str, str] = {}
74 |         for indexer_name, indexer in self._token_indexers.items():
75 |             assert type(indexer)==SingleIdTokenIndexer, "The indexer must be a SingleIdTokenIndexer"
76 |             token_indices = indexer.tokens_to_indices(self.tokens, vocab, indexer_name)
77 | 
78 |             for key, val in 
token_indices.items(): 79 | oov_id = vocab._token_to_index[indexer.namespace][vocab._oov_token] 80 | 81 | ids_with_unks : List[int] = val 82 | ids_with_oovs : List[int] = [] 83 | 84 | for _id, word in zip(ids_with_unks, self.tokens): 85 | if _id == oov_id: 86 | if word.text not in oovs_list: 87 | ids_with_oovs.append(_id) # let it be the vocab id for OOV 88 | else: 89 | ids_with_oovs.append(vocab.get_vocab_size(indexer.namespace) + oovs_list.index(word.text)) 90 | else: 91 | ids_with_oovs.append(_id) 92 | 93 | token_arrays.update({ 94 | "ids_with_unks": ids_with_unks, 95 | "ids_with_oovs": ids_with_oovs 96 | }) 97 | 98 | indexer_name_to_indexed_token[indexer_name] = ["ids_with_unks", "ids_with_oovs"] 99 | token_index_to_indexer_name["ids_with_unks"] = indexer_name 100 | token_index_to_indexer_name["ids_with_oovs"] = indexer_name 101 | 102 | self._indexed_tokens = token_arrays 103 | self._indexer_name_to_indexed_token = indexer_name_to_indexed_token 104 | self._token_index_to_indexer_name = token_index_to_indexer_name 105 | self._oovs = oovs_list 106 | -------------------------------------------------------------------------------- /dataset_creators/AMI/xmin_dataset_creator.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | from collections import OrderedDict, Counter 5 | from tqdm import tqdm 6 | 7 | from copy import deepcopy 8 | import jsonlines 9 | import os 10 | 11 | train_dataset = list(jsonlines.open("../../dataset_ami/root/train.jsonl", "r")) 12 | val_dataset = list(jsonlines.open("../../dataset_ami/root/val.jsonl", "r")) 13 | test_dataset = list(jsonlines.open("../../dataset_ami/root/test.jsonl", "r")) 14 | 15 | output_root_path = "../../dataset_ami/" 16 | 17 | seq_of_sections = ["abstract", 18 | "actions", 19 | "decisions", 20 | "problems"] 21 | 22 | def makeTimeOrderedDict(unordered_dict): 23 | return OrderedDict({k:unordered_dict[k] for k in sorted(unordered_dict.keys(), key=lambda i:float(i))}) 24 | 25 | 26 | task_name = "allxmin_binary_classification" 27 | 28 | def get_allxmin_binary_classification_dp(base_dp): 29 | base_dp=deepcopy(base_dp) 30 | 31 | all_xmins=set() 32 | for section_name in seq_of_sections: 33 | sorted_entries = sorted(base_dp[section_name], 34 | key=lambda i:( float(i["xmin_interval"][0]) , float(i["xmin_interval"][1]) ) ) 35 | for entry in sorted_entries: 36 | all_xmins.update(entry["sorted_xmins"]) 37 | 38 | sorted_turns = makeTimeOrderedDict(base_dp["transcript"]) 39 | input_lines = [] 40 | for tstamp, turn in sorted_turns.items(): 41 | input_lines.append(turn["speaker"]+" "+turn["txt"]) 42 | 43 | if len(input_lines)==0: 44 | return None 45 | 46 | label_dict={"is_important":0} 47 | labels = np.zeros((len(input_lines),1), dtype=int).tolist() 48 | 49 | for idx, tstamp in enumerate(sorted_turns.keys()): 50 | if tstamp in all_xmins: 51 | labels[idx][0]=1 52 | 53 | return {"case_id":base_dp["id"] ,"article_lines":input_lines, "labels":labels, "label_dict":label_dict} 54 | 55 | train_section_dataset = [get_allxmin_binary_classification_dp(x) for x in tqdm(train_dataset)] 56 | val_section_dataset = [get_allxmin_binary_classification_dp(x) for x in val_dataset] 57 | test_section_dataset = [get_allxmin_binary_classification_dp(x) for x in test_dataset] 58 | 59 | task_output_path = os.path.join(output_root_path, task_name) 60 | os.mkdir(task_output_path) 61 | 62 | with jsonlines.open(os.path.join(task_output_path, "train.jsonl"), "w") as w: 63 | for obj in train_section_dataset: 64 | if obj!=None: 
65 | w.write(obj) 66 | with jsonlines.open(os.path.join(task_output_path, "val.jsonl"), "w") as w: 67 | for obj in val_section_dataset: 68 | if obj!=None: 69 | w.write(obj) 70 | with jsonlines.open(os.path.join(task_output_path, "test.jsonl"), "w") as w: 71 | for obj in test_section_dataset: 72 | if obj!=None: 73 | w.write(obj) 74 | 75 | 76 | task_name = "sectionwise_xmin_multilabel_classification" 77 | 78 | def get_sectionwise_xmin_multilabel_classification_dp(base_dp): 79 | base_dp=deepcopy(base_dp) 80 | 81 | sorted_turns = makeTimeOrderedDict(base_dp["transcript"]) 82 | input_lines = [] 83 | for tstamp, turn in sorted_turns.items(): 84 | input_lines.append(turn["speaker"]+" "+turn["txt"]) 85 | 86 | if len(input_lines)==0: 87 | return None 88 | 89 | label_dict={sec:_i for (_i, sec) in enumerate(seq_of_sections)} 90 | labels = np.zeros((len(input_lines),len(seq_of_sections)), dtype=int).tolist() 91 | 92 | for section_name in seq_of_sections: 93 | all_xmins=set() 94 | sorted_entries = sorted(base_dp[section_name], 95 | key=lambda i:( float(i["xmin_interval"][0]) , float(i["xmin_interval"][1]) ) ) 96 | for entry in sorted_entries: 97 | all_xmins.update(entry["sorted_xmins"]) 98 | 99 | short_section_name = section_name 100 | for idx, tstamp in enumerate(sorted_turns.keys()): 101 | if tstamp in all_xmins: 102 | labels[idx][label_dict[short_section_name]]=1 103 | 104 | return {"case_id":base_dp["id"] ,"article_lines":input_lines, "labels":labels, "label_dict":label_dict} 105 | 106 | 107 | train_section_dataset = [get_sectionwise_xmin_multilabel_classification_dp(x) for x in tqdm(train_dataset)] 108 | val_section_dataset = [get_sectionwise_xmin_multilabel_classification_dp(x) for x in val_dataset] 109 | test_section_dataset = [get_sectionwise_xmin_multilabel_classification_dp(x) for x in test_dataset] 110 | 111 | task_output_path = os.path.join(output_root_path, task_name) 112 | os.mkdir(task_output_path) 113 | 114 | with jsonlines.open(os.path.join(task_output_path, "train.jsonl"), "w") as w: 115 | for obj in train_section_dataset: 116 | if obj!=None: 117 | w.write(obj) 118 | with jsonlines.open(os.path.join(task_output_path, "val.jsonl"), "w") as w: 119 | for obj in val_section_dataset: 120 | if obj!=None: 121 | w.write(obj) 122 | with jsonlines.open(os.path.join(task_output_path, "test.jsonl"), "w") as w: 123 | for obj in test_section_dataset: 124 | if obj!=None: 125 | w.write(obj) 126 | 127 | -------------------------------------------------------------------------------- /sequential_sentence_tagger/sentseq_dataset_reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict 3 | 4 | import numpy as np 5 | from overrides import overrides 6 | import pickle 7 | 8 | import jsonlines 9 | 10 | from allennlp.common.checks import ConfigurationError 11 | from allennlp.common.file_utils import cached_path 12 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 13 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 14 | from allennlp.data.fields import TextField, ArrayField, MetadataField, ListField 15 | from allennlp.data.instance import Instance 16 | from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer 17 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 18 | from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter 19 | 20 | 21 | from overrides import overrides 22 | 23 | 24 | logger = logging.getLogger(__name__) # pylint: 
disable=invalid-name 25 | 26 | 27 | @DatasetReader.register("sentseq_dataset_reader") 28 | class SentSeqDatasetReader(DatasetReader): 29 | def __init__(self, 30 | max_sent_length : int = np.inf, 31 | max_sents_per_example : int = np.inf, 32 | tokenizer: Tokenizer = None, 33 | token_indexers: Dict[str, TokenIndexer] = None, 34 | lowercase_tokens : bool = False, 35 | lazy: bool = False, 36 | max_to_read = np.inf) -> None: 37 | super().__init__(lazy) 38 | self.lowercase_tokens = lowercase_tokens 39 | self.max_sent_length = max_sent_length 40 | self.max_sents_per_example = max_sents_per_example 41 | self.max_to_read = max_to_read 42 | self._tokenizer = tokenizer or WordTokenizer(word_splitter=JustSpacesWordSplitter()) 43 | self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} 44 | if "tokens" not in self._token_indexers or \ 45 | not isinstance(self._token_indexers["tokens"], SingleIdTokenIndexer): 46 | raise ConfigurationError("CNNDmailDatasetReader expects 'token_indexers' to contain " 47 | "a 'single_id' token indexer called 'tokens'.") 48 | 49 | 50 | @overrides 51 | def _read(self, file_path): 52 | logger.info("Reading instances from lines in file at: %s", file_path) 53 | with jsonlines.open(file_path, "r") as reader: 54 | num_passed = 0 55 | for dp in reader: 56 | if num_passed == self.max_to_read: 57 | return 58 | if len(" ".join(dp["article_lines"]))==0: # if the input article has length 0 then there is a crash due to some NaN popping up. there are 114 such datapoins in cnndmail trainset 59 | continue 60 | num_passed += 1 61 | yield self.dict_to_instance(dp) 62 | 63 | def dict_to_instance(self, dp): 64 | input_lines = dp["article_lines"] 65 | labels = dp["labels"] 66 | 67 | if len(input_lines)!=len(labels): 68 | print("error in", dp["case_id"]) 69 | 70 | 71 | input_lines = input_lines[:self.max_sents_per_example] 72 | shortened_input_lines = [] 73 | for line in input_lines: 74 | shortened_input_lines.append( " ".join(line.split(" ")[:self.max_sent_length]) ) 75 | labels = labels[:self.max_sents_per_example] 76 | return self.text_to_instance(shortened_input_lines, labels) 77 | 78 | @staticmethod 79 | def _tokens_to_ids(tokens: List[Token]) -> List[int]: 80 | ids: Dict[str, int] = {} 81 | out: List[int] = [] 82 | for token in tokens: 83 | out.append(ids.setdefault(token.text.lower(), len(ids))) 84 | return out 85 | 86 | @overrides 87 | def text_to_instance(self, shortened_input_lines: List[str], labels: List[List[int]] = None) -> Instance: # type: ignore 88 | """ 89 | Turn raw source string and target string into an ``Instance``. 90 | 91 | Parameters 92 | ---------- 93 | source_string : ``str``, required 94 | target_string : ``str``, optional (default = None) 95 | 96 | Returns 97 | ------- 98 | Instance 99 | See the above for a description of the fields that the instance will contain. 
100 | """ 101 | # pylint: disable=arguments-differ 102 | if self.lowercase_tokens: 103 | for i, s in enumerate(shortened_input_lines): 104 | shortened_input_lines[i]=s.lower() 105 | 106 | tokenized_lines = [[Token(START_SYMBOL)]+self._tokenizer.tokenize(s)+[Token(END_SYMBOL)] for s in shortened_input_lines] 107 | indexed_line_fields = [TextField(tokenized_line, self._token_indexers) for tokenized_line in tokenized_lines] 108 | 109 | seq_of_textfields = ListField(indexed_line_fields) 110 | 111 | meta_fields = {"tokenized_lines": tokenized_lines} 112 | fields_dict = { 113 | "lines": seq_of_textfields, 114 | } 115 | 116 | if labels is not None: 117 | labelfield = ArrayField(np.array(labels)) 118 | fields_dict["labels"] = labelfield 119 | 120 | fields_dict["meta"] = MetadataField(meta_fields) 121 | 122 | return Instance(fields_dict) 123 | 124 | 125 | -------------------------------------------------------------------------------- /sequential_sentence_tagger/model_hilstm.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | 4 | from allennlp.models.model import Model 5 | from typing import Dict 6 | from overrides import overrides 7 | from allennlp.data.dataset import Batch 8 | from allennlp.nn import util 9 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 10 | 11 | from allennlp.data.instance import Instance 12 | 13 | import torch 14 | from torch.nn import LSTM 15 | from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper 16 | from allennlp.modules import TimeDistributed 17 | 18 | from sequential_sentence_tagger.all_classification_metrics import AllClassificationMetric 19 | 20 | EPS=1e-8 21 | 22 | 23 | @Model.register("sequential_sentence_tagger") 24 | class Seq2Seq(Model): 25 | def __init__(self, vocab, num_labels, hidden_size=256, emb_size=128, dataset:str="abridge"): 26 | super().__init__(vocab) 27 | # self.vocab=vocab 28 | 29 | ## vocab related setup begins 30 | self.vocab_size=vocab.get_vocab_size() 31 | self.PAD_ID = vocab.get_token_index(vocab._padding_token) 32 | self.OOV_ID = vocab.get_token_index(vocab._oov_token) 33 | self.START_ID = vocab.get_token_index(START_SYMBOL) 34 | self.END_ID = vocab.get_token_index(END_SYMBOL) 35 | ## vocab related setup ends 36 | 37 | self.emb_size=emb_size 38 | self.hidden_size=hidden_size 39 | self.num_labels=num_labels 40 | 41 | self.emb_layer = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.emb_size) 42 | lstm_layer = LSTM(input_size=self.emb_size, hidden_size=self.hidden_size, bidirectional=True, batch_first=True) 43 | self.sentence_lstm_encoder = TimeDistributed(PytorchSeq2SeqWrapper(lstm_layer)) 44 | 45 | lstm_layer2 = LSTM(input_size=2*self.hidden_size, hidden_size=self.hidden_size, bidirectional=True, batch_first=True) 46 | self.contextual_encoder = PytorchSeq2SeqWrapper(lstm_layer2) 47 | self.projection_layer = nn.Linear(2*self.hidden_size, self.num_labels) 48 | 49 | self.loss =nn.BCEWithLogitsLoss(reduction='none') 50 | 51 | self._metric = AllClassificationMetric(num_classes=num_labels, dataset=dataset) 52 | 53 | 54 | # buffers because these dont need grads. 
These are placed here because they will be replicated across GPUs
55 |         self.register_buffer("true_rep", torch.tensor(1.0))
56 |         self.register_buffer("false_rep", torch.tensor(0.0))
57 | 
58 | 
59 |     def forward(self, lines, labels=None, meta=None, only_predict_probs=False):
60 |         input_tokens = lines["tokens"]
61 |         input_pad_mask = input_tokens!=self.PAD_ID
62 |         embedded_seq = self.emb_layer(input_tokens)
63 |         sentenced_encoded = self.sentence_lstm_encoder(embedded_seq, input_pad_mask)  # batch x numsent x sentlen x 2*hidden_size
64 | 
65 |         embedding_summed = torch.sum(sentenced_encoded, axis=-2, keepdim=False)
66 |         to_divide = torch.sum(input_pad_mask, axis=-1, keepdim=True)+EPS
67 |         meanpooled_sent_reps = embedding_summed/to_divide    # batch x numsent x 2*hidden_size
68 | 
69 |         sentence_level_pad_mask = input_pad_mask[:,:,0]
70 |         contextual_embedded = self.contextual_encoder(meanpooled_sent_reps, sentence_level_pad_mask)
71 | 
72 |         logits = self.projection_layer(contextual_embedded)
73 | 
74 |         if only_predict_probs:
75 |             # inference path: labels are not needed, just return per-label probabilities
76 |             probs = torch.sigmoid(logits)
77 |             return probs.detach().cpu().numpy()
78 | 
79 |         loss_without_mask = self.loss(logits, labels)
80 |         loss_with_mask = loss_without_mask*sentence_level_pad_mask.unsqueeze(-1)
81 |         total_loss = torch.sum(loss_with_mask)
82 |         numpreds = torch.sum(sentence_level_pad_mask)
83 |         avg_loss = total_loss/numpreds
84 | 
85 |         if labels is not None:
86 |             probs = torch.sigmoid(logits)
87 |             self._metric(gold_labels=labels.reshape(-1,self.num_labels), predictions=probs.reshape(-1,self.num_labels), mask=sentence_level_pad_mask.reshape(-1))
88 | 
89 |         return {
90 |             "loss": avg_loss
91 |         }
92 | 
93 | 
94 | 
95 | 
96 | 
97 |     @overrides
98 |     def forward_on_instance(self, instance: Instance) -> Dict[str, str]:
99 |         """
100 |         Takes an :class:`~allennlp.data.instance.Instance`, which typically has raw text in it,
101 |         converts that text into arrays using this model's :class:`Vocabulary`, passes those arrays
102 |         through :func:`self.forward()` and :func:`self.decode()` (which by default does nothing)
103 |         and returns the result. Before returning the result, we convert any
104 |         ``torch.Tensors`` into numpy arrays and remove the batch dimension.
105 | """ 106 | cuda_device = self._get_prediction_device() 107 | dataset = Batch([instance]) 108 | dataset.index_instances(self.vocab) 109 | model_input = util.move_to_device(dataset.as_tensor_dict(), cuda_device) 110 | return self.forward(**model_input, only_predict_probs=True) 111 | 112 | 113 | def get_metrics(self, reset: bool = False) -> Dict[str, any]: 114 | metrics = self._metric.get_metric(reset) 115 | return metrics 116 | 117 | -------------------------------------------------------------------------------- /sequential_sentence_tagger/sentseq_binary_dataset_reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict 3 | 4 | import numpy as np 5 | from overrides import overrides 6 | import pickle 7 | 8 | import jsonlines 9 | 10 | from allennlp.common.checks import ConfigurationError 11 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 12 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 13 | from allennlp.data.fields import TextField, ArrayField, MetadataField, NamespaceSwappingField, ListField 14 | from allennlp.data.instance import Instance 15 | from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer 16 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 17 | from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter 18 | 19 | 20 | from overrides import overrides 21 | 22 | 23 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 24 | 25 | 26 | @DatasetReader.register("sentseq_binary_dataset_reader") 27 | class SentSeqBinaryDatasetReader(DatasetReader): 28 | def __init__(self, 29 | max_sent_length : int = np.inf, 30 | max_sents_per_example : int = np.inf, 31 | tokenizer: Tokenizer = None, 32 | token_indexers: Dict[str, TokenIndexer] = None, 33 | lowercase_tokens : bool = False, 34 | lazy: bool = False, 35 | max_to_read = np.inf) -> None: 36 | super().__init__(lazy) 37 | self.lowercase_tokens = lowercase_tokens 38 | self.max_sent_length = max_sent_length 39 | self.max_sents_per_example = max_sents_per_example 40 | self.max_to_read = max_to_read 41 | self._tokenizer = tokenizer or WordTokenizer(word_splitter=JustSpacesWordSplitter()) 42 | self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} 43 | if "tokens" not in self._token_indexers or \ 44 | not isinstance(self._token_indexers["tokens"], SingleIdTokenIndexer): 45 | raise ConfigurationError("CNNDmailDatasetReader expects 'token_indexers' to contain " 46 | "a 'single_id' token indexer called 'tokens'.") 47 | 48 | 49 | @overrides 50 | def _read(self, file_path): 51 | logger.info("Reading instances from lines in file at: %s", file_path) 52 | with jsonlines.open(file_path, "r") as reader: 53 | num_passed = 0 54 | for dp in reader: 55 | if num_passed == self.max_to_read: 56 | return 57 | if len(" ".join(dp["article_lines"]))==0: # if the input article has length 0 then there is a crash due to some NaN popping up. 
there are 114 such datapoins in cnndmail trainset 58 | continue 59 | num_passed += 1 60 | yield self.dict_to_instance(dp) 61 | 62 | def dict_to_instance(self, dp): 63 | input_lines = dp["article_lines"] 64 | labels = dp["labels"] 65 | 66 | if len(input_lines)!=len(labels): 67 | print("error in", dp["case_id"]) 68 | 69 | 70 | input_lines = input_lines[:self.max_sents_per_example] 71 | shortened_input_lines = [] 72 | for line in input_lines: 73 | shortened_input_lines.append( " ".join(line.split(" ")[:self.max_sent_length]) ) 74 | labels = labels[:self.max_sents_per_example] 75 | return self.text_to_instance(shortened_input_lines, labels) 76 | 77 | @staticmethod 78 | def _tokens_to_ids(tokens: List[Token]) -> List[int]: 79 | ids: Dict[str, int] = {} 80 | out: List[int] = [] 81 | for token in tokens: 82 | out.append(ids.setdefault(token.text.lower(), len(ids))) 83 | return out 84 | 85 | @overrides 86 | def text_to_instance(self, shortened_input_lines: List[str], labels: List[List[int]] = None) -> Instance: # type: ignore 87 | """ 88 | Turn raw source string and target string into an ``Instance``. 89 | 90 | Parameters 91 | ---------- 92 | source_string : ``str``, required 93 | target_string : ``str``, optional (default = None) 94 | 95 | Returns 96 | ------- 97 | Instance 98 | See the above for a description of the fields that the instance will contain. 99 | """ 100 | # pylint: disable=arguments-differ 101 | if self.lowercase_tokens: 102 | for i, s in enumerate(shortened_input_lines): 103 | shortened_input_lines[i]=s.lower() 104 | 105 | tokenized_lines = [[Token(START_SYMBOL)]+self._tokenizer.tokenize(s)+[Token(END_SYMBOL)] for s in shortened_input_lines] 106 | indexed_line_fields = [TextField(tokenized_line, self._token_indexers) for tokenized_line in tokenized_lines] 107 | 108 | seq_of_textfields = ListField(indexed_line_fields) 109 | 110 | meta_fields = {"tokenized_lines": tokenized_lines} 111 | fields_dict = { 112 | "lines": seq_of_textfields, 113 | } 114 | 115 | if labels is not None: 116 | labelfield = ArrayField(np.max(np.array(labels), axis=1, keepdims=True)) 117 | fields_dict["labels"] = labelfield 118 | 119 | fields_dict["meta"] = MetadataField(meta_fields) 120 | 121 | return Instance(fields_dict) 122 | 123 | 124 | -------------------------------------------------------------------------------- /pointergen/cnndmail_dataset_reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict 3 | 4 | import numpy as np 5 | from overrides import overrides 6 | import pickle 7 | 8 | import jsonlines 9 | 10 | from allennlp.common.checks import ConfigurationError 11 | from allennlp.common.file_utils import cached_path 12 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 13 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 14 | from allennlp.data.fields import TextField, ArrayField, MetadataField, NamespaceSwappingField 15 | from allennlp.data.instance import Instance 16 | from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer 17 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 18 | from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter 19 | 20 | from overrides import overrides 21 | 22 | from pointergen.custom_instance import SyncedFieldsInstance 23 | from pointergen.fields import SourceTextField, TargetTextField 24 | 25 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 26 | 27 | 28 | 
@DatasetReader.register("cnndmail_dataset_reader") 29 | class CNNDmailDatasetReader(DatasetReader): 30 | def __init__(self, 31 | max_source_length : int = 400, 32 | max_target_length : int =100, 33 | tokenizer: Tokenizer = None, 34 | token_indexers: Dict[str, TokenIndexer] = None, 35 | lowercase_tokens : bool = False, 36 | lazy: bool = False, 37 | max_to_read = np.inf) -> None: 38 | super().__init__(lazy) 39 | self.lowercase_tokens = lowercase_tokens 40 | self.max_source_length = max_source_length 41 | self.max_target_length = max_target_length 42 | self.max_to_read = max_to_read 43 | self._tokenizer = tokenizer or WordTokenizer(word_splitter=JustSpacesWordSplitter()) 44 | self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} 45 | if "tokens" not in self._token_indexers or \ 46 | not isinstance(self._token_indexers["tokens"], SingleIdTokenIndexer): 47 | raise ConfigurationError("CNNDmailDatasetReader expects 'token_indexers' to contain " 48 | "a 'single_id' token indexer called 'tokens'.") 49 | 50 | 51 | @overrides 52 | def _read(self, file_path): 53 | logger.info("Reading instances from lines in file at: %s", file_path) 54 | with jsonlines.open(file_path, "r") as reader: 55 | num_passed = 0 56 | for dp in reader: 57 | if num_passed == self.max_to_read: 58 | return 59 | if len(" ".join(dp["article_lines"]))==0: # if the input article has length 0 then there is a crash due to some NaN popping up. there are 114 such datapoins in cnndmail trainset 60 | continue 61 | num_passed += 1 62 | yield self.dict_to_instance(dp) 63 | 64 | def dict_to_instance(self, dp): 65 | source_sequence = " ".join(dp["article_lines"]) 66 | target_sequence = " ".join(dp["summary_lines"]) 67 | source_words_truncated = source_sequence.split(" ")[:self.max_source_length] 68 | target_words_truncated = target_sequence.split(" ")[:self.max_target_length] 69 | source_sequence = " ".join(source_words_truncated) 70 | target_sequence = " ".join(target_words_truncated) 71 | return self.text_to_instance(source_sequence, target_sequence) 72 | 73 | @staticmethod 74 | def _tokens_to_ids(tokens: List[Token]) -> List[int]: 75 | ids: Dict[str, int] = {} 76 | out: List[int] = [] 77 | for token in tokens: 78 | out.append(ids.setdefault(token.text.lower(), len(ids))) 79 | return out 80 | 81 | @overrides 82 | def text_to_instance(self, source_string: str, target_string: str = None) -> Instance: # type: ignore 83 | """ 84 | Turn raw source string and target string into an ``Instance``. 85 | 86 | Parameters 87 | ---------- 88 | source_string : ``str``, required 89 | target_string : ``str``, optional (default = None) 90 | 91 | Returns 92 | ------- 93 | Instance 94 | See the above for a description of the fields that the instance will contain. 
95 | """ 96 | # pylint: disable=arguments-differ 97 | if self.lowercase_tokens: 98 | source_string = source_string.lower() 99 | target_string = target_string.lower() 100 | tokenized_source = self._tokenizer.tokenize(source_string) 101 | source_field = SourceTextField(tokenized_source, self._token_indexers) 102 | 103 | meta_fields = {"source_tokens": [x.text for x in tokenized_source]} 104 | fields_dict = { 105 | "source_tokens": source_field, 106 | } 107 | 108 | if target_string is not None: 109 | tokenized_target = self._tokenizer.tokenize(target_string) 110 | meta_fields["target_tokens"] = [x.text for x in tokenized_target] 111 | tokenized_target.insert(0, Token(START_SYMBOL)) 112 | tokenized_target.append(Token(END_SYMBOL)) 113 | target_field = TargetTextField(tokenized_target, self._token_indexers) 114 | fields_dict["target_tokens"] = target_field 115 | 116 | fields_dict["meta"] = MetadataField(meta_fields) 117 | 118 | return SyncedFieldsInstance(fields_dict) 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /pointergen/conditioned_cnndmail_dataset_reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict 3 | 4 | import numpy as np 5 | 6 | import jsonlines 7 | 8 | from allennlp.common.checks import ConfigurationError 9 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 10 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 11 | from allennlp.data.fields import TextField, ArrayField, MetadataField, NamespaceSwappingField, LabelField 12 | from allennlp.data.instance import Instance 13 | from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer 14 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 15 | from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter 16 | 17 | from overrides import overrides 18 | 19 | from pointergen.custom_instance import SyncedFieldsInstance 20 | from pointergen.fields import SourceTextField, TargetTextField 21 | 22 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 23 | 24 | 25 | @DatasetReader.register("conditioned_cnndmail_dataset_reader") 26 | class CNNDmailDatasetReader(DatasetReader): 27 | def __init__(self, 28 | max_source_length : int = 400, 29 | max_target_length : int =100, 30 | tokenizer: Tokenizer = None, 31 | token_indexers: Dict[str, TokenIndexer] = None, 32 | lowercase_tokens : bool = False, 33 | lazy: bool = False, 34 | max_to_read = np.inf) -> None: 35 | super().__init__(lazy) 36 | self.lowercase_tokens = lowercase_tokens 37 | self.max_source_length = max_source_length 38 | self.max_target_length = max_target_length 39 | self.max_to_read = max_to_read 40 | self._tokenizer = tokenizer or WordTokenizer(word_splitter=JustSpacesWordSplitter()) 41 | self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} 42 | if "tokens" not in self._token_indexers or \ 43 | not isinstance(self._token_indexers["tokens"], SingleIdTokenIndexer): 44 | raise ConfigurationError("CNNDmailDatasetReader expects 'token_indexers' to contain " 45 | "a 'single_id' token indexer called 'tokens'.") 46 | 47 | 48 | @overrides 49 | def _read(self, file_path): 50 | logger.info("Reading instances from lines in file at: %s", file_path) 51 | with jsonlines.open(file_path, "r") as reader: 52 | num_passed = 0 53 | for dp in reader: 54 | if num_passed == self.max_to_read: 55 | return 56 | if len(" ".join(dp["article_lines"]))==0: 
# if the input article has length 0 then there is a crash due to some NaN popping up. there are 114 such datapoins in cnndmail trainset 57 | continue 58 | num_passed += 1 59 | yield self.dict_to_instance(dp) 60 | 61 | def dict_to_instance(self, dp): 62 | source_sequence = " ".join(dp["article_lines"]) 63 | target_sequence = " ".join(dp["summary_lines"]) 64 | source_words_truncated = source_sequence.split(" ")[:self.max_source_length] 65 | target_words_truncated = target_sequence.split(" ")[:self.max_target_length] 66 | source_sequence = " ".join(source_words_truncated) 67 | target_sequence = " ".join(target_words_truncated) 68 | section_label = dp["section"] 69 | return self.text_to_instance(source_sequence, target_sequence, section_label) 70 | 71 | @staticmethod 72 | def _tokens_to_ids(tokens: List[Token]) -> List[int]: 73 | ids: Dict[str, int] = {} 74 | out: List[int] = [] 75 | for token in tokens: 76 | out.append(ids.setdefault(token.text.lower(), len(ids))) 77 | return out 78 | 79 | @overrides 80 | def text_to_instance(self, source_string: str, target_string: str = None, section_label: str = None) -> Instance: # type: ignore 81 | """ 82 | Turn raw source string and target string into an ``Instance``. 83 | 84 | Parameters 85 | ---------- 86 | source_string : ``str``, required 87 | target_string : ``str``, optional (default = None) 88 | 89 | Returns 90 | ------- 91 | Instance 92 | See the above for a description of the fields that the instance will contain. 93 | """ 94 | # pylint: disable=arguments-differ 95 | if self.lowercase_tokens: 96 | source_string = source_string.lower() 97 | target_string = target_string.lower() 98 | tokenized_source = self._tokenizer.tokenize(source_string) 99 | source_field = SourceTextField(tokenized_source, self._token_indexers) 100 | section_field = LabelField(section_label, label_namespace="section_labels") 101 | 102 | meta_fields = {"source_tokens": [x.text for x in tokenized_source]} 103 | fields_dict = { 104 | "source_tokens": source_field, 105 | "section_labels": section_field, 106 | } 107 | 108 | if target_string is not None: 109 | tokenized_target = self._tokenizer.tokenize(target_string) 110 | meta_fields["target_tokens"] = [x.text for x in tokenized_target] 111 | tokenized_target.insert(0, Token(START_SYMBOL)) 112 | tokenized_target.append(Token(END_SYMBOL)) 113 | target_field = TargetTextField(tokenized_target, self._token_indexers) 114 | fields_dict["target_tokens"] = target_field 115 | 116 | fields_dict["meta"] = MetadataField(meta_fields) 117 | 118 | return SyncedFieldsInstance(fields_dict) 119 | 120 | 121 | -------------------------------------------------------------------------------- /dataset_creators/AMI/clash_resolutions.json: -------------------------------------------------------------------------------- 1 | {"ES2002b": {"1804.93": true, "2066.87": false}, "ES2002c": {"526.87": false, "772.01": true, "2293.15": true}, "ES2002d": {"151.6": false, "456.85": true}, "ES2003b": {"1861.91": true}, "ES2003c": {"125.66": true, "1459.57": true, "2250.33": true}, "ES2003d": {"1326.14": false, "1773.98": false, "1784.86": false, "1984.93": true}, "ES2004a": {"359.11": true}, "ES2004b": {"1471.32": true}, "ES2004c": {"868.04": false, "1377.38": true}, "ES2004d": {"1023.43": true, "1382.43": false, "1795.96": true, "1918.31": true, "2031.37": true}, "ES2005a": {"144.4": false}, "ES2005b": {"17.03": false, "1394.69": true}, "ES2005c": {"574.63": true}, "ES2005d": {"94.37": true, "1525.7": true, "1661.1": false}, "ES2006b": {"596.26": true, "2086.55": 
true}, "ES2006c": {"957.6": true, "1287.3": false, "1386.8": false, "1556.79": true}, "ES2006d": {"693.57": true, "726.22": true, "1260.79": false, "1321.46": true, "1457.44": false}, "ES2007b": {"283.2": false, "1010.27": true, "1322.52": false}, "ES2008b": {"1101.41": false}, "ES2008c": {"2019.05": true}, "ES2008d": {"974.5": true, "975.68": true, "1421.48": true, "1522.9": false}, "ES2009a": {"740.51": true}, "ES2009c": {"1708.26": true}, "ES2009d": {"1183.6": false, "1944.7": true, "2096.02": true}, "ES2010b": {"549.01": false, "860.82": true}, "ES2010c": {"1021.55": true, "1368.68": true, "1578.44": true}, "ES2010d": {"265.51": true}, "ES2011a": {"173.16": false, "659.66": true, "868.88": false}, "ES2011b": {"1255.35": true, "1393.63": true, "1407.03": true}, "ES2011c": {"243.83": true, "731.94": true, "840.02": false, "1340.6": false}, "ES2011d": {"1748.42": false}, "ES2012c": {"1135.33": true, "1374.64": true}, "ES2012d": {"128.23": true}, "ES2013a": {"610.73": true}, "ES2014b": {"435.02": true, "964.81": false, "1378.24": false, "1578.16": false, "2264.83": false}, "ES2014c": {"816.1": false, "883.94": true, "1557.53": false, "1948.06": false}, "ES2014d": {"1533.08": false, "1878.96": false, "2036.52": true}, "ES2015a": {"470.91": false}, "ES2015b": {"1091.14": false, "1326.05": true, "1980.88": true}, "ES2015c": {"573.19": false, "704.86": true, "923.21": false, "1153.23": false}, "ES2015d": {"797.95": true, "969.91": true, "1407.99": true}, "ES2016a": {"1205.65": false}, "ES2016b": {"1656.08": true}, "IS1000a": {"1251.08": false}, "IS1000b": {"1756.32": true, "2059.67": false}, "IS1000c": {"642.02": true, "1540.39": false, "1919.03": false}, "IS1000d": {"1180.45": true, "1738.47": true}, "IS1001a": {"682.74": false}, "IS1001b": {"690.23": false, "1702.86": false}, "IS1001d": {"386.45": true}, "IS1002b": {"1980.28": false, "1990.15": false}, "IS1002c": {"1353.35": true}, "IS1003b": {"95.31": false, "314.43": false, "974.04": false, "1206.07": true}, "IS1003c": {"202.32": true, "348.14": false, "401.06": true, "515.87": true, "884.49": false, "1174.79": true, "1576.38": false, "1820.34": false}, "IS1003d": {"189.04": true, "272.07": false, "644.18": true, "850.58": true, "1034.14": true, "1283.56": false}, "IS1004a": {"598.4": false}, "IS1004b": {"203.27": true, "1781.6": false}, "IS1004c": {"2078.79": false}, "IS1004d": {"820.69": true, "855.14": false, "1318.87": true}, "IS1005c": {"762.24": false}, "IS1006b": {"1631.05": false, "1650.3": false, "2043.37": false, "2074.66": false, "2133.47": true}, "IS1006c": {"932.7": false, "1162.75": true, "1884.31": false}, "IS1006d": {"1273.56": false, "1452.99": false, "1505.6": false, "1540.93": true, "1568.08": true, "1692.11": false, "1791.08": false}, "IS1007c": {"152.68": false, "1320.85": false}, "IS1007d": {"464.92": false, "1209.47": false}, "IS1008b": {"530.41": false}, "IS1008c": {"1246.98": false}, "IS1008d": {"556.01": false, "710.92": false}, "IS1009b": {"857.0": false}, "IS1009d": {"130.37": true, "1553.05": false, "1830.94": true}, "TS3003b": {"1723.04": false}, "TS3003c": {"1417.83": false}, "TS3003d": {"1403.4": false, "1679.47": true, "2387.97": false}, "TS3004a": {"999.57": true}, "TS3004b": {"1138.43": false, "1250.1": true, "1952.85": false, "2004.81": true, "2014.13": false, "2070.54": true}, "TS3004c": {"175.42": true, "532.27": false, "1207.39": true, "1215.55": true, "1360.04": true, "1924.57": false}, "TS3004d": {"1748.9": true}, "TS3005a": {"914.32": true}, "TS3005b": {"982.15": false, "1142.94": false, "2209.35": 
false}, "TS3005d": {"99.37": true, "226.89": true, "780.3": true, "853.68": true, "1163.66": false, "1521.67": true, "1616.2": true, "1861.92": false, "2030.0": false, "2308.78": false, "2354.56": false}, "TS3006a": {"313.74": false, "1049.81": false, "1139.84": false}, "TS3006b": {"654.11": true, "917.18": true, "1170.04": false, "1731.4": false, "1993.5": false, "2065.36": true, "2208.69": true, "2324.25": true}, "TS3006c": {"1051.7": true, "1191.9": true, "1558.03": false, "1791.51": true, "1850.28": false}, "TS3006d": {"520.11": true, "560.72": false, "697.29": true, "935.84": false, "942.27": false, "1220.8": true, "1281.37": true, "2097.2": true, "2128.98": false, "2199.49": false, "2842.26": false}, "TS3007b": {"1857.21": true, "2178.23": true, "2407.84": true}, "TS3007c": {"244.17": true, "292.9": true, "783.64": false, "1994.77": true}, "TS3007d": {"899.98": false, "952.65": false, "1217.07": false, "1225.93": false, "1326.5": true, "1539.95": false, "2303.62": false}, "TS3008b": {"1245.4": false, "1611.97": true, "1809.76": false}, "TS3008c": {"2279.06": true}, "TS3008d": {"1212.48": true, "1255.83": false, "1508.45": false, "1711.95": false, "1811.95": false, "2147.64": false}, "TS3009a": {"186.84": true, "533.04": false, "909.35": false}, "TS3009b": {"831.93": false, "1233.18": true, "1337.99": false, "2056.9": false, "2258.66": false, "2387.85": true}, "TS3009c": {"1675.27": false, "2260.45": true}, "TS3009d": {"644.52": false, "672.15": false, "1025.27": false}, "TS3010c": {"653.79": true, "1139.96": false, "1566.72": true}, "TS3010d": {"393.93": true}, "TS3011b": {"373.27": true, "788.63": false, "867.63": true, "1766.66": false, "2127.8": false}, "TS3011c": {"721.31": true, "1352.39": true, "1420.9": true, "1455.86": false, "2223.47": false, "2284.2": false}, "TS3011d": {"719.874": true, "1748.494": false}, "TS3012b": {"159.42": true, "1530.76": true, "1959.42": true}, "TS3012c": {"326.75": true, "592.17": true, "1259.63": true, "1567.98": true, "2090.28": false, "2176.26": true, "2316.01": false}, "TS3012d": {"354.65": true, "1000.54": true, "1565.06": true, "1669.15": false}} -------------------------------------------------------------------------------- /t5/pretrained_cnndmail_dataset_reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pdb 3 | from typing import List, Dict 4 | 5 | import numpy as np 6 | from overrides import overrides 7 | import pickle 8 | 9 | import jsonlines 10 | 11 | from allennlp.common.checks import ConfigurationError 12 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 13 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 14 | from allennlp.data.fields import TextField, ArrayField, MetadataField 15 | from allennlp.data.instance import Instance 16 | from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer 17 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 18 | 19 | from overrides import overrides 20 | 21 | from transformers import T5Tokenizer 22 | 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | 27 | 28 | @DatasetReader.register("pretrained_cnndmail_dataset_reader") 29 | class PretrainedCNNDmailDatasetReader(DatasetReader): 30 | def __init__(self, 31 | max_source_length : int = 400, 32 | max_target_length : int =100, 33 | tokenizer: Tokenizer = None, 34 | pretrained_model_name: str= 't5-base', 35 | token_indexers: Dict[str, TokenIndexer] = None, 36 | lowercase_tokens : 
bool = False, 37 | lazy: bool = False, 38 | max_to_read = np.inf) -> None: 39 | super().__init__(lazy) 40 | self.lowercase_tokens = lowercase_tokens 41 | self.max_source_length = max_source_length 42 | self.max_target_length = max_target_length 43 | self.max_to_read = max_to_read 44 | 45 | # REMEMBER : Your data file must not contain things like PAD, UNK, START, STOP explicitly 46 | self._tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name) 47 | 48 | self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} 49 | if "tokens" not in self._token_indexers or \ 50 | not isinstance(self._token_indexers["tokens"], SingleIdTokenIndexer): 51 | raise ConfigurationError("CNNDmailDatasetReader expects 'token_indexers' to contain " 52 | "a 'single_id' token indexer called 'tokens'.") 53 | 54 | 55 | @overrides 56 | def _read(self, file_path): 57 | logger.info("Reading instances from lines in file at: %s", file_path) 58 | with jsonlines.open(file_path, "r") as reader: 59 | num_passed = 0 60 | for dp in reader: 61 | if num_passed == self.max_to_read: 62 | return 63 | if len(" ".join(dp["article_lines"]))==0: # skip empty articles: a zero-length input makes a NaN pop up downstream and crashes training; there are 114 such datapoints in the cnndmail trainset 64 | continue 65 | num_passed += 1 66 | yield self.dict_to_instance(dp) 67 | 68 | def dict_to_instance(self, dp): 69 | source_sequence = " ".join(dp["article_lines"]) 70 | target_sequence = " ".join(dp["summary_lines"]) 71 | source_words_truncated = source_sequence.split(" ")[:self.max_source_length] 72 | target_words_truncated = target_sequence.split(" ")[:self.max_target_length] 73 | source_sequence = " ".join(source_words_truncated) 74 | target_sequence = " ".join(target_words_truncated) 75 | return self.text_to_instance(source_sequence, target_sequence) 76 | 77 | 78 | @overrides 79 | def text_to_instance(self, source_string: str, target_string: str = None) -> Instance: # type: ignore 80 | """ 81 | Turn raw source string and target string into an ``Instance``. 82 | 83 | Parameters 84 | ---------- 85 | source_string : ``str``, required 86 | target_string : ``str``, optional (default = None) 87 | 88 | Returns 89 | ------- 90 | Instance 91 | See the above for a description of the fields that the instance will contain.
92 | """ 93 | # pylint: disable=arguments-differ 94 | if self.lowercase_tokens: 95 | source_string = source_string.lower() 96 | target_string = target_string.lower() 97 | tokenized_source = self._tokenizer.tokenize(source_string) 98 | tokenized_source = [Token(w) for w in tokenized_source] 99 | source_field = TextField(tokenized_source, self._token_indexers) 100 | 101 | meta_fields = {"source_tokens": [x.text for x in tokenized_source]} 102 | fields_dict = { 103 | "source_tokens": source_field, 104 | } 105 | 106 | if target_string is not None: 107 | tokenized_target = self._tokenizer.tokenize(target_string) 108 | tokenized_target = [Token(w) for w in tokenized_target] 109 | meta_fields["target_tokens"] = [x.text for x in tokenized_target] 110 | tokenized_target.insert(0, Token(START_SYMBOL)) 111 | tokenized_target.append(Token(END_SYMBOL)) 112 | target_field = TextField(tokenized_target, self._token_indexers) 113 | fields_dict["target_tokens"] = target_field 114 | 115 | fields_dict["meta"] = MetadataField(meta_fields) 116 | 117 | return Instance(fields_dict) 118 | 119 | 120 | 121 | @DatasetReader.register("conditioned_pretrained_cnndmail_dataset_reader") 122 | class ConditionedPretrainedCNNDmailDatasetReader(PretrainedCNNDmailDatasetReader): 123 | @overrides 124 | def dict_to_instance(self, dp): 125 | source_sequence = " ".join(dp["article_lines"]) 126 | target_sequence = " ".join(dp["summary_lines"]) 127 | source_words_truncated = source_sequence.split(" ")[:self.max_source_length] 128 | target_words_truncated = target_sequence.split(" ")[:self.max_target_length] 129 | 130 | source_sequence = " ".join(source_words_truncated) 131 | section_name = dp["section"].replace("_"," ").strip().lower() 132 | source_sequence = section_name+" "+source_sequence 133 | 134 | target_sequence = " ".join(target_words_truncated) 135 | return self.text_to_instance(source_sequence, target_sequence) 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /t5/pretrained_model.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from allennlp.models.model import Model 7 | from pointergen.custom_instance import SyncedFieldsInstance 8 | from typing import Dict 9 | from overrides import overrides 10 | from allennlp.data.dataset import Batch 11 | from allennlp.nn import util 12 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 13 | from allennlp.training.metrics import CategoricalAccuracy 14 | 15 | from transformers import T5ForConditionalGeneration, T5Config 16 | 17 | EPS=1e-8 18 | 19 | 20 | @Model.register("pretrained_t5") 21 | class PretrainedT5(Model): 22 | def __init__(self, vocab, pretrained_model_name='t5-base', min_decode_length=0, max_decode_length=99): 23 | super().__init__(vocab) 24 | 25 | ## vocab related setup begins 26 | assert "tokens" in vocab._token_to_index and len(vocab._token_to_index.keys())==1, "Vocabulary must have tokens as the only namespace" 27 | self.vocab_size=vocab.get_vocab_size() 28 | self.PAD_ID = vocab.get_token_index(vocab._padding_token) 29 | self.OOV_ID = vocab.get_token_index(vocab._oov_token) 30 | self.START_ID = vocab.get_token_index(START_SYMBOL) 31 | self.END_ID = vocab.get_token_index(END_SYMBOL) 32 | ## vocab related setup ends 33 | 34 | 35 | self.min_decode_length = min_decode_length 36 | self.max_decode_length = max_decode_length 37 | 38 | self.metrics = { 39 | "accuracy" : CategoricalAccuracy(), 40 | } 41 | 42 | # 
-------------------------------------------------------------------------------- /t5/pretrained_model.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from allennlp.models.model import Model 7 | from pointergen.custom_instance import SyncedFieldsInstance 8 | from typing import Dict 9 | from overrides import overrides 10 | from allennlp.data.dataset import Batch 11 | from allennlp.nn import util 12 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 13 | from allennlp.training.metrics import CategoricalAccuracy 14 | 15 | from transformers import T5ForConditionalGeneration, T5Config 16 | 17 | EPS=1e-8 18 | 19 | 20 | @Model.register("pretrained_t5") 21 | class PretrainedT5(Model): 22 | def __init__(self, vocab, pretrained_model_name='t5-base', min_decode_length=0, max_decode_length=99): 23 | super().__init__(vocab) 24 | 25 | ## vocab related setup begins 26 | assert "tokens" in vocab._token_to_index and len(vocab._token_to_index.keys())==1, "Vocabulary must have tokens as the only namespace" 27 | self.vocab_size=vocab.get_vocab_size() 28 | self.PAD_ID = vocab.get_token_index(vocab._padding_token) 29 | self.OOV_ID = vocab.get_token_index(vocab._oov_token) 30 | self.START_ID = vocab.get_token_index(START_SYMBOL) 31 | self.END_ID = vocab.get_token_index(END_SYMBOL) 32 | ## vocab related setup ends 33 | 34 | 35 | self.min_decode_length = min_decode_length 36 | self.max_decode_length = max_decode_length 37 | 38 | self.metrics = { 39 | "accuracy" : CategoricalAccuracy(), 40 | } 41 | 42 | # buffers, because these don't need grads; they are placed here so that they get replicated across gpus 43 | self.register_buffer("true_rep", torch.tensor(1.0)) 44 | self.register_buffer("false_rep", torch.tensor(0.0)) 45 | 46 | self.softmax = nn.Softmax(dim=-1) 47 | 48 | T5Model = T5ForConditionalGeneration.from_pretrained(pretrained_model_name) 49 | self.T5Model = T5Model 50 | 51 | 52 | 53 | def forward(self, source_tokens, target_tokens, meta=None, only_predict_probs=False, return_pgen=False): 54 | inp_ids = source_tokens["tokens"] 55 | 56 | feed_tensor = target_tokens["tokens"][:, :-1] 57 | target_tensor = target_tokens["tokens"].detach().clone()[:, 1:] 58 | 59 | batch_size = inp_ids.size(0) 60 | input_pad_mask=torch.where(inp_ids!=self.PAD_ID, self.true_rep, self.false_rep) 61 | 62 | # output_pad_mask is not needed here; instead, the lm labels get set to -100 below 63 | output_pad_mask=torch.where(target_tensor!=self.PAD_ID, self.true_rep, self.false_rep) 64 | target_tensor[target_tensor==self.PAD_ID]=-100 # -100 is the ignore_index that the transformers library hardcodes 65 | 66 | loss, logits, _, probably_encoder_output = self.T5Model( 67 | input_ids=inp_ids, 68 | attention_mask=input_pad_mask, 69 | decoder_input_ids=feed_tensor, 70 | use_cache=None, #TODO: see if we should use this 71 | labels=target_tensor) 72 | 73 | 74 | predicted_seqfirst = logits.permute(1,0,2) 75 | true_labels_seqfirst = target_tensor.permute(1,0) 76 | mask_seqfirst = output_pad_mask.permute(1,0) 77 | 78 | for metric in self.metrics.values(): 79 | for (p, t, m) in zip(predicted_seqfirst, true_labels_seqfirst, mask_seqfirst): 80 | metric(p, t, m) # TODO: make sure that none of the metrics need softmaxed probabilities 81 | 82 | return { 83 | "loss": loss, 84 | "logits": logits 85 | } 86 | 87 | 88 | 89 | 90 | @overrides 91 | def forward_on_instance(self, instance: SyncedFieldsInstance, decode_strategy) -> Dict[str, str]: 92 | cuda_device = self._get_prediction_device() 93 | dataset = Batch([instance]) 94 | dataset.index_instances(self.vocab) 95 | model_input = util.move_to_device(dataset.as_tensor_dict(), cuda_device) 96 | if decode_strategy=='greedy': 97 | output_ids = self.common_decode(**model_input, min_length=self.min_decode_length, max_length=self.max_decode_length, num_beams=1) 98 | elif decode_strategy=='beamsearch': 99 | output_ids = self.common_decode(**model_input, min_length=self.min_decode_length, max_length=self.max_decode_length, num_beams=4) 100 | else: 101 | raise NotImplementedError 102 | 103 | output_words = [] 104 | 105 | for _id in output_ids: 106 | output_words.append(self.vocab.get_token_from_index(_id)) 107 | 108 | assert output_words[0]==START_SYMBOL, "somehow the first symbol is not the START symbol; might be a bug"
might be a bug" 109 | output_words=output_words[1:] 110 | 111 | if output_words[-1]==END_SYMBOL: 112 | output_words = output_words[:-1] 113 | 114 | return {"tokens": output_words} 115 | 116 | 117 | @overrides 118 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 119 | metrics_to_return = { 120 | metric_name: metric.get_metric(reset) for metric_name, metric in self.metrics.items() 121 | } 122 | return metrics_to_return 123 | 124 | 125 | def common_decode(self, source_tokens, target_tokens=None, meta=None, min_length=35, max_length=120, num_beams=1): 126 | # pdb.set_trace() 127 | inp_ids = source_tokens["tokens"] 128 | input_pad_mask=torch.where(inp_ids!=self.PAD_ID, self.true_rep, self.false_rep) 129 | 130 | generated_ids = self.T5Model.generate( 131 | input_ids = inp_ids, 132 | attention_mask = input_pad_mask, 133 | min_length = min_length, 134 | max_length = max_length, 135 | decoder_start_token_id = self.START_ID, 136 | bos_token_id = self.START_ID, 137 | pad_token_id = self.PAD_ID, 138 | eos_token_id = self.END_ID, 139 | num_beams = num_beams 140 | ) 141 | 142 | return generated_ids.detach().cpu().numpy()[0] 143 | 144 | 145 | 146 | 147 | @Model.register("randominit_t5") 148 | class RandomInitT5(PretrainedT5): 149 | def __init__(self, vocab, pretrained_model_name='t5-base', min_decode_length=0, max_decode_length=99): 150 | super().__init__(vocab) 151 | 152 | ## vocab related setup begins 153 | assert "tokens" in vocab._token_to_index and len(vocab._token_to_index.keys())==1, "Vocabulary must have tokens as the only namespace" 154 | self.vocab_size=vocab.get_vocab_size() 155 | self.PAD_ID = vocab.get_token_index(vocab._padding_token) 156 | self.OOV_ID = vocab.get_token_index(vocab._oov_token) 157 | self.START_ID = vocab.get_token_index(START_SYMBOL) 158 | self.END_ID = vocab.get_token_index(END_SYMBOL) 159 | ## vocab related setup ends 160 | 161 | 162 | self.min_decode_length = min_decode_length 163 | self.max_decode_length = max_decode_length 164 | 165 | self.metrics = { 166 | "accuracy" : CategoricalAccuracy(), 167 | } 168 | 169 | # buffers because these dont need grads. 
-------------------------------------------------------------------------------- /sequential_sentence_tagger/all_classification_metrics.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional, List 3 | 4 | from overrides import overrides 5 | import torch 6 | 7 | from allennlp.training.metrics.metric import Metric 8 | 9 | 10 | import numpy as np 11 | 12 | import sklearn 13 | 14 | 15 | from utils.section_names import ami_section_names 16 | import json 17 | 18 | 19 | @Metric.register("all-classification-metric") 20 | class AllClassificationMetric(Metric): 21 | def __init__(self, num_classes:int, dataset="abridge") -> None: # note: only "ami" and "unary_prediction" are handled in get_metric below 22 | self.num_classes=num_classes 23 | self.y_true=[] 24 | self.y_pred_class=[] 25 | self.y_pred_cont=[] 26 | self.dataset=dataset 27 | 28 | def __call__(self, 29 | gold_labels: torch.Tensor, 30 | predictions: torch.Tensor, 31 | mask: Optional[torch.Tensor] = None, 32 | thresholds: List[float]= None): 33 | """ 34 | Parameters 35 | ---------- 36 | gold_labels : shape is num_datapointsXnum_classes 37 | predictions : shape is num_datapointsXnum_classes 38 | mask: ``torch.Tensor``, optional (default = None). 39 | """ 40 | predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask) 41 | predictions=predictions.detach().cpu().numpy() 42 | gold_labels=gold_labels.detach().cpu().numpy() 43 | 44 | assert predictions.ndim==2 45 | assert predictions.shape[1]==self.num_classes 46 | assert gold_labels.ndim==2 47 | assert gold_labels.shape[1]==self.num_classes 48 | 49 | 50 | if mask is not None: 51 | assert mask.ndim==1 52 | mask = mask.cpu().numpy() 53 | gold_labels = gold_labels[mask] 54 | predictions = predictions[mask] 55 | 56 | 57 | self.y_true.append(gold_labels) 58 | self.y_pred_cont.append(predictions) 59 | 60 | if thresholds is None: 61 | self.y_pred_class.append(np.round(predictions, decimals=0)) 62 | else: 63 | self.y_pred_class.append((predictions >= thresholds).astype(int)) 64 | 65 | 66 | def get_metric(self, reset: bool = False): 67 | """ 68 | Returns 69 | ------- 70 | The accumulated metrics 71 | """ 72 | all_results_dict={"nothing":0.0} 73 | if reset: 74 | if self.dataset=="ami": 75 | label_names = ami_section_names 76 | elif self.dataset=="unary_prediction": 77 | label_names = ["important"] 78 | else: 79 | raise NotImplementedError 80 | 81 | # pdb.set_trace() 82 | all_results_dict = calc_metrics(np.concatenate(self.y_true, axis=0), 83 | np.concatenate(self.y_pred_class, axis=0), 84 | np.concatenate(self.y_pred_cont, axis=0), 85 | label_names) 86 | self.reset() 87 | 88 | return all_results_dict 89 | 90 | @overrides 91 | def reset(self): # back to empty lists (as in __init__) so that __call__ can keep appending after a reset 92 | self.y_true=[] 93 | self.y_pred_class=[] 94 | self.y_pred_cont=[] 95 | 96 | 97 | 98 | 99 | def precision_at_1(y_true, y_pred): 100 | assert y_pred.dtype==np.float32 or y_pred.dtype==np.float64 or y_pred.dtype==np.float16 101 | most_probable_pred=np.argmax(y_pred, axis=1) 102 | most_probable_pred=most_probable_pred.reshape(-1) 103 | 104 | 
hit_placewise=np.zeros(y_true.shape[1]) 105 | 106 | hits=0 107 | total=0 108 | 109 | for idx, arr in zip(most_probable_pred, y_true): 110 | ispresent=arr[idx] 111 | if ispresent: 112 | hit_placewise[idx]+=1 113 | hits+=ispresent 114 | total+=1 115 | 116 | return hits/total, hit_placewise 117 | 118 | 119 | def calc_metrics(y_true, y_pred_class, y_pred_cont, label_names): 120 | return_dict={} 121 | 122 | # print( {"y_true": y_true.shape, 123 | # "y_pred_class": y_pred_class.shape, 124 | # "y_pred_cont": y_pred_cont.shape } ) 125 | 126 | num_points=y_true.shape[0] 127 | num_classes=y_true.shape[1] 128 | 129 | 130 | return_dict["aggregate_micro-precision"]=sklearn.metrics.precision_score(y_true, y_pred_class, average="micro") 131 | return_dict["aggregate_macro-precision"]=sklearn.metrics.precision_score(y_true, y_pred_class, average="macro") 132 | 133 | return_dict["aggregate_micro-recall"]=sklearn.metrics.recall_score(y_true, y_pred_class, average="micro") 134 | return_dict["aggregate_macro-recall"]=sklearn.metrics.recall_score(y_true, y_pred_class, average="macro") 135 | 136 | return_dict["aggregate_micro-f1"]=sklearn.metrics.f1_score(y_true, y_pred_class, average="micro") 137 | return_dict["aggregate_macro-f1"]=sklearn.metrics.f1_score(y_true, y_pred_class, average="macro") 138 | 139 | return_dict["aggregate_accuracy"]=sklearn.metrics.accuracy_score(y_true.flatten(), y_pred_class.flatten()) 140 | 141 | return_dict["aggregate_micro-auc"]=sklearn.metrics.roc_auc_score(y_true, y_pred_cont, average="micro") 142 | return_dict["aggregate_macro-auc"]=sklearn.metrics.roc_auc_score(y_true, y_pred_cont, average="macro") 143 | 144 | classwise_results={} 145 | classwise_results["precision"]=sklearn.metrics.precision_score(y_true, y_pred_class, average=None) 146 | classwise_results["recall"]=sklearn.metrics.recall_score(y_true, y_pred_class, average=None) 147 | classwise_results["f1"]=sklearn.metrics.f1_score(y_true, y_pred_class, average=None) 148 | classwise_results["accuracy"]=np.array([sklearn.metrics.accuracy_score(y_true[:,_i], y_pred_class[:,_i]) 149 | for _i in range(num_classes) ]) 150 | classwise_results["auc"]=sklearn.metrics.roc_auc_score(y_true, y_pred_cont, average=None) 151 | 152 | if classwise_results["auc"].shape==(): # if there's only one class it returns a scalar 153 | classwise_results["auc"]=np.array([classwise_results["auc"]]) 154 | 155 | # print("##################") 156 | # print(classwise_results["auc"].shape) 157 | # print("##################") 158 | 159 | # pdb.set_trace() 160 | 161 | precision_at_1_calculated = precision_at_1(y_true, y_pred_cont) 162 | return_dict["aggregate_precision-at-1"]=precision_at_1_calculated[0] 163 | placewise_contribution_to_p1=precision_at_1_calculated[1] 164 | 165 | classwise_dfdict={} 166 | for _i in range(num_classes): 167 | heading=label_names[_i] 168 | classwise_dfdict[heading]={} 169 | classwise_dfdict[heading]["prevalence_rate"]=sum(y_true[:,_i]/num_points) 170 | classwise_dfdict[heading]["precision"]=classwise_results["precision"][_i] 171 | classwise_dfdict[heading]["recall"]=classwise_results["recall"][_i] 172 | classwise_dfdict[heading]["f1"]=classwise_results["f1"][_i] 173 | classwise_dfdict[heading]["accuracy"]=classwise_results["accuracy"][_i] 174 | classwise_dfdict[heading]["auc"]=classwise_results["auc"][_i] 175 | classwise_dfdict[heading]["contribution_to_p1"]=placewise_contribution_to_p1[_i]/sum(placewise_contribution_to_p1) 176 | 177 | dict_toprint={ 178 | "classwise_results": classwise_dfdict, 179 | "aggregate_results": return_dict,
180 | } 181 | print(json.dumps(dict_toprint)) 182 | 183 | 184 | for _i in range(num_classes): 185 | return_dict["classwise_"+str(_i)+"_prevalence_rate"]=sum(y_true[:,_i]/num_points) 186 | return_dict["classwise_"+str(_i)+"_precision"]=classwise_results["precision"][_i] 187 | return_dict["classwise_"+str(_i)+"_recall"]=classwise_results["recall"][_i] 188 | return_dict["classwise_"+str(_i)+"_f1"]=classwise_results["f1"][_i] 189 | return_dict["classwise_"+str(_i)+"_accuracy"]=classwise_results["accuracy"][_i] 190 | return_dict["classwise_"+str(_i)+"_auc"]=classwise_results["auc"][_i] 191 | 192 | 193 | 194 | return return_dict 195 | 
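A minimal sketch of how this metric is meant to be driven (tensor values are made up; with dataset="ami" the four ami_section_names serve as class labels, and the full result dict is only computed when reset=True):

    import torch

    metric = AllClassificationMetric(num_classes=4, dataset="ami")
    gold = torch.tensor([[1., 0., 1., 1.], [0., 1., 0., 0.]])           # num_datapoints x num_classes
    probs = torch.tensor([[0.9, 0.2, 0.6, 0.7], [0.3, 0.8, 0.4, 0.1]])  # per-class probabilities
    metric(gold, probs)                     # call once per batch to accumulate
    scores = metric.get_metric(reset=True)  # aggregates everything seen so far
    print(scores["aggregate_micro-f1"])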
-------------------------------------------------------------------------------- /dataset_creators/AMI/variant_tasks_creator.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from collections import OrderedDict, Counter 4 | from tqdm import tqdm 5 | import jsonlines 6 | import os 7 | 8 | 9 | train_dataset = list(jsonlines.open("../../dataset_ami/root/train.jsonl", "r")) 10 | val_dataset = list(jsonlines.open("../../dataset_ami/root/val.jsonl", "r")) 11 | test_dataset = list(jsonlines.open("../../dataset_ami/root/test.jsonl", "r")) 12 | 13 | 14 | output_root_path = "../../dataset_ami/" 15 | 16 | 17 | seq_of_sections = ["abstract", 18 | "actions", 19 | "decisions", 20 | "problems"] 21 | 22 | def makeTimeOrderedDict(unordered_dict): 23 | return OrderedDict({k:unordered_dict[k] for k in sorted(unordered_dict.keys(), key=lambda i:float(i))}) 24 | 25 | print(train_dataset[5].keys()) 26 | 27 | 28 | task_name = "full_summarization" 29 | 30 | def get_full_summarization_dp(base_dp): 31 | sorted_turns = makeTimeOrderedDict(base_dp["transcript"]) 32 | input_lines = [] 33 | for turn in sorted_turns.values(): 34 | input_lines.append(turn["speaker"]+" "+turn["txt"]) 35 | 36 | if len(input_lines)==0: 37 | return None 38 | 39 | allsection_output_lines = [] 40 | 41 | for section_name in seq_of_sections: 42 | output_lines = ["@@"+section_name+"@@"] 43 | sorted_entries = sorted(base_dp[section_name], 44 | key=lambda i:( float(i["xmin_interval"][0]) , float(i["xmin_interval"][1]) ) ) 45 | for entry in sorted_entries: 46 | output_lines.append(entry["summary"]) 47 | allsection_output_lines.extend(output_lines) 48 | 49 | 50 | return {"case_id":base_dp["id"] ,"article_lines":input_lines, "summary_lines":allsection_output_lines} 51 | 52 | 53 | 54 | task_output_path = os.path.join(output_root_path, task_name) 55 | os.mkdir(task_output_path) 56 | 57 | train_section_dataset = [get_full_summarization_dp(x) for x in tqdm(train_dataset)] 58 | val_section_dataset = [get_full_summarization_dp(x) for x in val_dataset] 59 | test_section_dataset = [get_full_summarization_dp(x) for x in test_dataset] 60 | 61 | with jsonlines.open(os.path.join(task_output_path, "train.jsonl"), "w") as w: 62 | for obj in train_section_dataset: 63 | if obj is not None: 64 | w.write(obj) 65 | with jsonlines.open(os.path.join(task_output_path, "val.jsonl"), "w") as w: 66 | for obj in val_section_dataset: 67 | if obj is not None: 68 | w.write(obj) 69 | with jsonlines.open(os.path.join(task_output_path, "test.jsonl"), "w") as w: 70 | for obj in test_section_dataset: 71 | if obj is not None: 72 | w.write(obj) 73 | 74 | 75 | 76 | task_name = "allxmin_summarization" 77 | 78 | def get_allxmin_summarization_dp(base_dp): 79 | all_xmins=set() 80 | allsection_output_lines = [] 81 | 82 | for section_name in seq_of_sections: 83 | output_lines = ["@@"+section_name+"@@"] 84 | sorted_entries = sorted(base_dp[section_name], 85 | key=lambda i:( float(i["xmin_interval"][0]) , float(i["xmin_interval"][1]) ) ) 86 | for entry in sorted_entries: 87 | output_lines.append(entry["summary"]) 88 | all_xmins.update(entry["sorted_xmins"]) 89 | allsection_output_lines.extend(output_lines) 90 | 91 | all_xmins = sorted(list(all_xmins), key=float) 92 | 93 | sorted_turns = [base_dp["transcript"][k] for k in all_xmins] 94 | input_lines = [] 95 | for turn in sorted_turns: 96 | input_lines.append(turn["speaker"]+" "+turn["txt"]) 97 | 98 | if len(input_lines)==0 or len(allsection_output_lines)==0: 99 | return None 100 | 101 | return {"case_id":base_dp["id"] ,"article_lines":input_lines, "summary_lines":allsection_output_lines} 102 | 103 | 104 | 105 | task_output_path = os.path.join(output_root_path, task_name) 106 | os.mkdir(task_output_path) 107 | 108 | train_section_dataset = [get_allxmin_summarization_dp(x) for x in tqdm(train_dataset)] 109 | val_section_dataset = [get_allxmin_summarization_dp(x) for x in val_dataset] 110 | test_section_dataset = [get_allxmin_summarization_dp(x) for x in test_dataset] 111 | 112 | with jsonlines.open(os.path.join(task_output_path, "train.jsonl"), "w") as w: 113 | for obj in train_section_dataset: 114 | if obj is not None: 115 | w.write(obj) 116 | with jsonlines.open(os.path.join(task_output_path, "val.jsonl"), "w") as w: 117 | for obj in val_section_dataset: 118 | if obj is not None: 119 | w.write(obj) 120 | with jsonlines.open(os.path.join(task_output_path, "test.jsonl"), "w") as w: 121 | for obj in test_section_dataset: 122 | if obj is not None: 123 | w.write(obj) 124 | 125 | 126 | 127 | task_name = "sectionwise_allxmin_summarization" 128 | 129 | def get_sectionwise_allxmin_summarization_dps(base_dp): 130 | new_datapoints=[] 131 | for section_name in seq_of_sections: 132 | all_xmins=set() 133 | output_lines = [] 134 | sorted_entries = sorted(base_dp[section_name], 135 | key=lambda i:( float(i["xmin_interval"][0]) , float(i["xmin_interval"][1]) ) ) 136 | for entry in sorted_entries: 137 | output_lines.append(entry["summary"]) 138 | all_xmins.update(entry["sorted_xmins"]) 139 | 140 | all_xmins = sorted(list(all_xmins), key=float) 141 | 142 | sorted_turns = [base_dp["transcript"][k] for k in all_xmins] 143 | input_lines = [] 144 | for turn in sorted_turns: 145 | input_lines.append(turn["speaker"]+" "+turn["txt"]) 146 | 147 | if len(input_lines)==0 or len(output_lines)==0: 148 | continue 149 | 150 | new_datapoints.append({"case_id":base_dp["id"] ,"article_lines":input_lines, "summary_lines":output_lines, "section":section_name}) 151 | 152 | return new_datapoints 153 | 154 | 155 | 156 | task_output_path = os.path.join(output_root_path, task_name) 157 | os.mkdir(task_output_path) 158 | 159 | train_section_dataset = [get_sectionwise_allxmin_summarization_dps(x) for x in tqdm(train_dataset)] 160 | val_section_dataset = [get_sectionwise_allxmin_summarization_dps(x) for x in val_dataset] 161 | test_section_dataset = [get_sectionwise_allxmin_summarization_dps(x) for x in test_dataset] 162 | 163 | with jsonlines.open(os.path.join(task_output_path, "train.jsonl"), "w") as w: 164 | for lst in train_section_dataset: 165 | for obj in lst: 166 | w.write(obj) 167 | with jsonlines.open(os.path.join(task_output_path, "val.jsonl"), "w") as w: 168 | for lst in val_section_dataset: 169 | for obj in lst: 170 | w.write(obj) 171 | with jsonlines.open(os.path.join(task_output_path, "test.jsonl"), "w") as w: 172 | for lst in test_section_dataset: 173 | for obj in lst: 174 | w.write(obj) 175 | 176 | 177 | 178 | 
179 | task_name = "entrywise_summarization" 180 | 181 | def get_entrywise_summarization_dps(base_dp): 182 | individual_entry_dps=[] 183 | for section_name in seq_of_sections: 184 | sorted_entries = sorted(base_dp[section_name], 185 | key=lambda i:( float(i["xmin_interval"][0]) , float(i["xmin_interval"][1]) ) ) 186 | 187 | for idx, entry in enumerate(sorted_entries): 188 | sorted_turns = [base_dp["transcript"][k] for k in entry["sorted_xmins"]] 189 | input_lines = [] 190 | for turn in sorted_turns: 191 | input_lines.append(turn["speaker"]+" "+turn["txt"]) 192 | output_lines = [entry["summary"]] 193 | 194 | if len(input_lines)==0: 195 | continue 196 | 197 | individual_entry_dps.append({"case_id":base_dp["id"] , "index_in_note":idx , "article_lines":input_lines, "summary_lines":output_lines, "section":section_name}) 198 | 199 | return individual_entry_dps 200 | 201 | 202 | 203 | task_output_path = os.path.join(output_root_path, task_name) 204 | os.mkdir(task_output_path) 205 | 206 | train_section_dataset = [get_entrywise_summarization_dps(x) for x in tqdm(train_dataset)] 207 | val_section_dataset = [get_entrywise_summarization_dps(x) for x in val_dataset] 208 | test_section_dataset = [get_entrywise_summarization_dps(x) for x in test_dataset] 209 | 210 | with jsonlines.open(os.path.join(task_output_path, "train.jsonl"), "w") as w: 211 | for lst in train_section_dataset: 212 | for obj in lst: 213 | w.write(obj) 214 | with jsonlines.open(os.path.join(task_output_path, "val.jsonl"), "w") as w: 215 | for lst in val_section_dataset: 216 | for obj in lst: 217 | w.write(obj) 218 | with jsonlines.open(os.path.join(task_output_path, "test.jsonl"), "w") as w: 219 | for lst in test_section_dataset: 220 | for obj in lst: 221 | w.write(obj) 222 | 223 | 
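All four task variants emit jsonl rows sharing the same core keys; for instance, a sectionwise_allxmin_summarization row looks roughly like the following (values invented for illustration):

    {"case_id": "ES2004a",
     "article_lines": ["project manager okay let's get started", "industrial designer the battery is too heavy"],
     "summary_lines": ["the group discussed battery weight"],
     "section": "problems"}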
-------------------------------------------------------------------------------- /make_predicted_xmin_datasets.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import jsonlines 4 | import os 5 | 6 | from tqdm import tqdm 7 | import numpy as np 8 | from collections import defaultdict 9 | from copy import deepcopy 10 | import argparse 11 | 12 | 13 | parser = argparse.ArgumentParser(description='build predicted-xmin datasets from sentence-tagger outputs') 14 | 15 | parser.add_argument( 16 | '-dataset', 17 | dest='dataset', 18 | help='Dataset name (either medical or ami)', 19 | type=str, 20 | required=True, 21 | ) 22 | 23 | parser.add_argument( 24 | '-ser_dir', 25 | dest='ser_dir', 26 | help='serialization directory containing validation and test outputs', 27 | type=str, 28 | required=True, 29 | ) 30 | 31 | 32 | parser.add_argument( 33 | '-mode', 34 | dest='mode', 35 | help='unilabel/multilabel prediction', 36 | default='multilabel', 37 | type=str 38 | ) 39 | 40 | 41 | args=parser.parse_args() 42 | dataset = args.dataset 43 | ser_dir = args.ser_dir 44 | dataset_dir = f"dataset_{dataset}" 45 | mode = args.mode 46 | 47 | 48 | if mode=="unilabel": 49 | val_fpath = os.path.join(ser_dir, "val_outputs.jsonl") 50 | val_predictions = list(jsonlines.open(val_fpath)) 51 | test_fpath = os.path.join(ser_dir, "test_outputs.jsonl") 52 | test_predictions = list(jsonlines.open(test_fpath)) 53 | 54 | all_ground_truths=[] 55 | all_predictions=[] 56 | for elem in val_predictions: 57 | all_predictions.append(elem["prediction"][0]) 58 | all_ground_truths.append(elem["ground_truth"]) 59 | 60 | all_ground_truths = np.concatenate(all_ground_truths, axis=0) 61 | all_predictions = np.concatenate(all_predictions, axis=0) 62 | 63 | 64 | base_rates = all_ground_truths.sum(axis=0)/len(all_ground_truths) 65 | 66 | thresholds = [] 67 | 68 | for j in range(all_predictions.shape[1]): 69 | br = base_rates[j] 70 | sec_pred_probs = all_predictions[:,j] 71 | cutoff = np.quantile(sec_pred_probs,1-br) 72 | thresholds.append(cutoff) 73 | 74 | 75 | def get_allxmin_dp(base_dp): 76 | # pdb.set_trace() 77 | allxmin_utterances = [] 78 | 79 | test_texts = base_dp["input"] 80 | y_pred = np.array(base_dp["prediction"][0]) 81 | 82 | 83 | threshold = thresholds[0] 84 | sec_pred_probs = y_pred[:,0] 85 | is_xmin = sec_pred_probs>=threshold 86 | 87 | for line, pred in zip(test_texts, is_xmin): 88 | if pred: 89 | allxmin_utterances.append(line) 90 | 91 | dp_to_return = {"case_id":base_dp['case_id'], 92 | "article_lines":allxmin_utterances, 93 | "summary_lines":["dummy"]} 94 | 95 | # pdb.set_trace() 96 | 97 | return dp_to_return 98 | 99 | allxmin_dps = [] 100 | 101 | for dp in tqdm(test_predictions): 102 | new_dp = get_allxmin_dp(dp) 103 | allxmin_dps.append(new_dp) 104 | 105 | 106 | existing_caseids = set() 107 | 108 | for dp in allxmin_dps: 109 | existing_caseids.add(dp["case_id"]) 110 | 111 | print(f'{len(existing_caseids)} cases found in the outgoing file') 112 | 113 | output_path = os.path.join(ser_dir, "predicted_allxmin_test.jsonl") 114 | with jsonlines.open(output_path, "w") as w: 115 | for dp in allxmin_dps: 116 | w.write(dp) 117 | 118 | exit(0) 119 | 120 | 121 | 122 | 123 | # GETTING LABEL DICT 124 | temp_dataset_path = os.path.join(dataset_dir, "sectionwise_xmin_multilabel_classification", "test.jsonl") 125 | temp_dataset=list(jsonlines.open(temp_dataset_path, "r")) 126 | label_dict = temp_dataset[0]["label_dict"] 127 | label_arr = ["_" for _ in range(len(label_dict.keys()))] 128 | for label, idx in label_dict.items(): 129 | label_arr[idx]=label 130 | 131 | 132 | val_fpath = os.path.join(ser_dir, "val_outputs.jsonl") 133 | val_predictions = list(jsonlines.open(val_fpath)) 134 | test_fpath = os.path.join(ser_dir, "test_outputs.jsonl") 135 | test_predictions = list(jsonlines.open(test_fpath)) 136 | 137 | all_ground_truths=[] 138 | all_predictions=[] 139 | for elem in val_predictions: 140 | all_predictions.append(elem["prediction"][0]) 141 | all_ground_truths.append(elem["ground_truth"]) 142 | 143 | all_ground_truths = np.concatenate(all_ground_truths, axis=0) 144 | all_predictions = np.concatenate(all_predictions, axis=0) 145 | 146 | 147 | base_rates = all_ground_truths.sum(axis=0)/len(all_ground_truths) 148 | 149 | thresholds = [] 150 | 151 | for j in range(all_predictions.shape[1]): 152 | br = base_rates[j] 153 | sec_pred_probs = all_predictions[:,j] 154 | cutoff = np.quantile(sec_pred_probs,1-br) 155 | thresholds.append(cutoff) 156 | 157 | thresholds = np.array(thresholds) 158 | print("Thresholds=", thresholds) 159 | 160 | def get_sectionwise_dps(base_dp): 161 | sectionwise_xmins = defaultdict(list) 162 | 163 | test_texts = base_dp["input"] 164 | y_pred = np.array(base_dp["prediction"][0]) 165 | 166 | for j, section_name in enumerate(label_arr): 167 | threshold = thresholds[j] 168 | sec_pred_probs = y_pred[:,j] 169 | is_xmin = sec_pred_probs>=threshold 170 | 171 | for line, pred in zip(test_texts, is_xmin): 172 | if pred: 173 | sectionwise_xmins[section_name].append(line) 174 | 175 | dps_to_return=[] 176 | 177 | for section, lines in sectionwise_xmins.items(): 178 | dps_to_return.append({"case_id":base_dp['case_id'], "article_lines":lines, "summary_lines":["dummy"], "section":section}) 179 | 180 | return dps_to_return 181 | 182 | 
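The thresholding used in both branches above picks, per class, the validation-probability quantile that reproduces the class's base rate, so the tagger predicts positives at the same rate they occur in the labels. A compact illustration with made-up numbers:

    import numpy as np

    val_probs = np.array([0.9, 0.1, 0.4, 0.8, 0.2])  # one class's predicted probabilities on validation
    base_rate = 0.4                                   # fraction of validation sentences labeled positive
    cutoff = np.quantile(val_probs, 1 - base_rate)    # 0.56 here, the 60th percentile
    print((val_probs >= cutoff).mean())               # 0.4, matching the base rate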
183 | sectionwise_allxmin_dps = [] 184 | 185 | for dp in tqdm(test_predictions): 186 | new_dps = get_sectionwise_dps(dp) 187 | sectionwise_allxmin_dps.extend(new_dps) 188 | 189 | 190 | existing_caseids = set() 191 | 192 | for dp in sectionwise_allxmin_dps: 193 | existing_caseids.add(dp["case_id"]) 194 | 195 | print(f'{len(existing_caseids)} cases found in the outgoing file') 196 | 197 | output_path = os.path.join(ser_dir, "predicted_sectionwise_allxmin.jsonl") 198 | with jsonlines.open(output_path, "w") as w: 199 | for dp in sectionwise_allxmin_dps: 200 | w.write(dp) 201 | 202 | 203 | 204 | ############################### 205 | ######## MAKING CLUSTERS 206 | ############################### 207 | 208 | 209 | def get_intervals(arr, cohesion=0): 210 | arr=list(arr) 211 | arr2=deepcopy(arr) 212 | 213 | # bridge gaps: any run of up to `cohesion` zeros sitting between two 1s gets flipped to 1s 214 | for i, elem in enumerate(arr): 215 | if elem!=1: 216 | continue 217 | lookahead = arr[i+1:i+1+cohesion+1] 218 | if 1 in lookahead: 219 | first_occ = lookahead.index(1) 220 | for j in range(i+1,i+1+first_occ): 221 | if j<len(arr2): 222 | arr2[j]=1 223 | 224 | # read off the contiguous runs of 1s as closed intervals [start, end] 225 | intervals=[] 226 | start=None 227 | for i, elem in enumerate(arr2): 228 | if elem==1 and start is None: 229 | start=i 230 | if elem!=1 and start is not None: 231 | intervals.append((start, i-1)) 232 | start=None 233 | if start is not None: 234 | intervals.append((start, len(arr2)-1)) 235 | 236 | return intervals 237 | 238 | 239 | 240 | 241 | def get_entrywise_dps(base_dp, cohesion=0): 242 | dps_to_return=[] 243 | 244 | test_texts = base_dp["input"] 245 | y_pred = np.array(base_dp["prediction"][0]) 246 | 247 | 248 | for j, section_name in enumerate(label_arr): 249 | 250 | threshold = thresholds[j] 251 | sec_pred_probs = y_pred[:,j] 252 | 253 | 254 | 255 | # an utterance counts as predicted evidence (xmin) for this section when its 256 | # probability clears the threshold picked on validation data 257 | 258 | is_xmin = sec_pred_probs>=threshold 259 | 260 | labels=is_xmin.astype(int) 261 | snippet_intervals = get_intervals(labels, cohesion) 262 | 263 | for _i, interval in enumerate(snippet_intervals): 264 | # CHOICE1: add the sentences in between in the input cluster 265 | # relevant_input_lines = test_texts[interval[0]: interval[1]+1] 266 | 267 | # CHOICE2: do not add the sentences in between in the input cluster 268 | relevant_input_lines = [] 269 | for idx in range(interval[0], interval[1]+1): 270 | if labels[idx]==1: 271 | relevant_input_lines.append(test_texts[idx]) 272 | 273 | dps_to_return.append({ 274 | 'article_lines': relevant_input_lines, 275 | 'summary_lines': ['dummy'], 276 | 'case_id': base_dp['case_id'], 277 | 'index_in_note': _i, 278 | 'section': section_name 279 | }) 280 | 281 | 282 | return dps_to_return 283 | 284 | 285 | 286 | # FIGURING OUT THE OPTIMAL VALUE OF COHESION PARAMETER FROM VALIDATION DATA 287 | temp_dataset_path = os.path.join(dataset_dir, "entrywise_summarization", "val.jsonl") 288 | temp_dataset=list(jsonlines.open(temp_dataset_path, "r")) 289 | gt_num_clusters=len(temp_dataset) 290 | 291 | chosen_cohesion_param=[] 292 | 293 | for cohesion_param in range(0,100): 294 | # print(f"trying out cohesion parameter = {cohesion_param}") 295 | entrywise_xmin_dps = [] 296 | for dp in val_predictions: 297 | new_dps = get_entrywise_dps(dp, cohesion=cohesion_param) 298 | entrywise_xmin_dps.extend(new_dps) 299 | print(f"ground truth has {gt_num_clusters} clusters, cohesion={cohesion_param} created {len(entrywise_xmin_dps)}") 300 | if len(entrywise_xmin_dps) -------------------------------------------------------------------------------- /dataset_creators/AMI/create_dataset.py: -------------------------------------------------------------------------------- 270 | print(f"Number of meetings where all summary sents have >=1 evidence = {num_good_abs}") 271 | print(f"Number of meetings where at least one summary sent has no evidence marked = {len(bad_meetings)}") 272 | 273 | 274 | 275 | for meeting_id, out_dict in final_summary_dict.items(): 276 | formatted_sents = {} 277 | all_transcript_sents = list(transcript_sentences[meeting_id].values()) 278 | all_transcript_sents = sorted(all_transcript_sents, key=lambda x:float(x["start_timestamp"])) 279 | for sent in all_transcript_sents: 280 | if sent["txt"]!="": 281 | formatted_sents[sent["start_timestamp"]] = {"speaker": sent["speaker"], "txt": sent["txt"]} 282 | 283 | out_dict["transcript"] = formatted_sents 284 | 285 | print("-----------------------------------------------------------------------") 286 | print("Sample datapoint looks like this --------------------------------------") 287 | print("-----------------------------------------------------------------------") 288 | 
289 | pprint(final_summary_dict["ES2004a"]) 290 | 291 | 292 | all_meeting_ids = list(final_summary_dict.keys()) 293 | 294 | 295 | output_dataset_dir = "../../dataset_ami/root/" 296 | os.mkdir(output_dataset_dir) 297 | 298 | meetings_to_write = set(all_meeting_ids) 299 | print(f"I have {len(meetings_to_write)} datapoints to write") 300 | 301 | with open(os.path.join(output_dataset_dir,"train.jsonl"), "w") as w: 302 | for meeting_id in all_meeting_ids: 303 | if meeting_id[:-1] in train_meetings: # :-1 to remove the final a,b,c letters 304 | out_str=json.dumps(final_summary_dict[meeting_id]) 305 | w.write(out_str+"\n") 306 | meetings_to_write.remove(meeting_id) 307 | 308 | with open(os.path.join(output_dataset_dir,"val.jsonl"), "w") as w: 309 | for meeting_id in all_meeting_ids: 310 | if meeting_id[:-1] in val_meetings: # :-1 to remove the final a,b,c letters 311 | out_str=json.dumps(final_summary_dict[meeting_id]) 312 | w.write(out_str+"\n") 313 | meetings_to_write.remove(meeting_id) 314 | 315 | with open(os.path.join(output_dataset_dir,"test.jsonl"), "w") as w: 316 | for meeting_id in all_meeting_ids: 317 | if meeting_id[:-1] in test_meetings: # :-1 to remove the final a,b,c letters 318 | out_str=json.dumps(final_summary_dict[meeting_id]) 319 | w.write(out_str+"\n") 320 | meetings_to_write.remove(meeting_id) 321 | 322 | print(f"After writing, I have {len(meetings_to_write)} datapoints still left, potentially because they were non-scenario") 323 | print("They are:", meetings_to_write) 324 | 325 | 
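For orientation, one row of the root jsonl files written above looks roughly like this (timestamps and text are invented; real rows carry many more turns and section entries), which is exactly the shape variant_tasks_creator.py consumes:

    {"id": "ES2004a",
     "transcript": {"12.34": {"speaker": "project manager", "txt": "okay let's get started"},
                    "15.02": {"speaker": "industrial designer", "txt": "i looked into the battery options"}},
     "abstract": [{"summary": "the team opened the meeting", "xmin_interval": ["12.34", "15.02"], "sorted_xmins": ["12.34", "15.02"]}],
     "actions": [], "decisions": [], "problems": []}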
-------------------------------------------------------------------------------- /pointergen/conditioned_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import sys 5 | 6 | import pdb 7 | from torch.autograd import Variable 8 | 9 | import math 10 | 11 | import torch.nn.functional as F 12 | 13 | from tensorboardX import SummaryWriter 14 | 15 | from nltk.translate.bleu_score import corpus_bleu, sentence_bleu 16 | from tqdm import tqdm_notebook 17 | 18 | 19 | 20 | from torch.nn.utils import clip_grad_norm_ 21 | from allennlp.models.model import Model 22 | from pointergen.custom_instance import SyncedFieldsInstance 23 | from typing import Dict 24 | from overrides import overrides 25 | from allennlp.data.dataset import Batch 26 | from allennlp.nn import util 27 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 28 | from allennlp.training.metrics import CategoricalAccuracy 29 | 30 | EPS=1e-8 31 | 32 | 33 | 34 | def add_with_expansion(A, B): 35 | '''A and B must be of single dimension''' 36 | assert A.ndim==1 and B.ndim==1 37 | shape_diff = np.array(B.shape) - np.array(A.shape) 38 | shape_diff = np.clip(shape_diff, a_min=0, a_max=np.inf).astype(np.int32) 39 | padded_A=np.lib.pad(A, ((0,shape_diff[0]),), 'constant', constant_values=(0)) 40 | 41 | shape_diff = np.array(A.shape) - np.array(B.shape) 42 | shape_diff = np.clip(shape_diff, a_min=0, a_max=np.inf).astype(np.int32) 43 | padded_B=np.lib.pad(B, ((0,shape_diff[0]),), 'constant', constant_values=(0)) 44 | 45 | return padded_A+padded_B 46 | 47 | 48 | def uniform_tensor(shape, a, b): 49 | output = torch.FloatTensor(*shape).uniform_(a, b) 50 | return output 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, total_encoder_hidden_size, total_decoder_hidden_size, attn_vec_size): 54 | super(Attention, self).__init__() 55 | self.total_encoder_hidden_size=total_encoder_hidden_size 56 | self.total_decoder_hidden_size=total_decoder_hidden_size 57 | self.attn_vec_size=attn_vec_size 58 | 59 | 60 | self.Wh_layer=nn.Linear(total_encoder_hidden_size, attn_vec_size, bias=False) 61 | self.Ws_layer=nn.Linear(total_decoder_hidden_size, attn_vec_size, bias=True) 62 | self.selector_vector_layer=nn.Linear(attn_vec_size, 1, bias=False) # called 'v' in see et al 63 | 64 | 65 | def forward(self, encoded_seq, decoder_state, input_pad_mask): 66 | ''' 67 | encoded seq is batchsizexenc_seqlenxtotal_encoder_hidden_size 68 | decoder_state is batchsizexdec_seqlenxtotal_decoder_hidden_size 69 | ''' 70 | 71 | projected_decstates = self.Ws_layer(decoder_state) 72 | projected_encstates = self.Wh_layer(encoded_seq) 73 | 74 | added_projections=projected_decstates.unsqueeze(2)+projected_encstates.unsqueeze(1) #batchsizeXdeclenXenclenXattnvecsize 75 | added_projections=torch.tanh(added_projections) 76 | 77 | attn_logits=self.selector_vector_layer(added_projections) 78 | attn_logits=attn_logits.squeeze(3) 79 | 80 | attn_weights = torch.softmax(attn_logits, dim=-1) # shape=batchXdec_lenXenc_len 81 | attn_weights2 = attn_weights*input_pad_mask.unsqueeze(1) 82 | attn_weights_renormalized = attn_weights2/torch.sum(attn_weights2, dim=-1, keepdim=True) # denominator shape=batchXdec_lenX1 # TODO - why is there a division without EPS ? 83 | 84 | context_vector = torch.sum(encoded_seq.unsqueeze(1)*attn_weights_renormalized.unsqueeze(-1) , dim=-2) 85 | 86 | return context_vector, attn_weights_renormalized 87 | 88 | 89 | 90 | 91 | class CopyMechanism(nn.Module): 92 | def __init__( 93 | self, encoder_hidden_size, decoder_hidden_size, decoder_input_size): 94 | super(CopyMechanism, self).__init__() 95 | self.pgen=nn.Sequential( 96 | nn.Linear(encoder_hidden_size+2*decoder_hidden_size+decoder_input_size, 1), 97 | nn.Sigmoid() 98 | ) 99 | self.output_probs=nn.Softmax(dim=-1) 100 | 101 | def forward( 102 | self, output_logits, attn_weights, decoder_hidden_state, decoder_input, 103 | context_vector, encoder_input, max_oovs): 104 | '''output_logits = batchXseqlenXoutvocab 105 | attn_weights = batchXseqlenXenc_len 106 | decoder_hidden_state = batchXseqlenXdecoder_hidden_size 107 | context_vector = batchXseqlenXencoder_hidden_dim 108 | encoder_input = batchxenc_len''' 109 | output_probabilities=self.output_probs(output_logits) 110 | 111 | # print(output_probabilities) 112 | 113 | batch_size = output_probabilities.size(0) 114 | output_len = output_probabilities.size(1) 115 | append_for_copy = torch.zeros((batch_size, output_len, max_oovs)).cuda() 116 | output_probabilities=torch.cat([output_probabilities, append_for_copy], dim=-1) 117 | 118 | pre_pgen_tensor=torch.cat([context_vector, decoder_hidden_state, decoder_input], dim=-1) 119 | pgen=self.pgen(pre_pgen_tensor) # batchsizeXseqlenX1 120 | pcopy=1.0-pgen 121 | 122 | encoder_input=encoder_input.unsqueeze(1).expand(-1, output_len , -1) # batchXseqlenXenc_len 123 | 124 | # Note that padding words do not get any attention because the attention is a masked attention 125 | 126 | copy_probabilities=torch.zeros_like(output_probabilities) # batchXseqlenXoutvocab 127 | copy_probabilities.scatter_add_(2, encoder_input, attn_weights) 128 | 129 | # print(copy_probabilities) 130 | 131 | # try: 132 | # copy_probabilities=torch.zeros_like(output_probabilities) # batchsizexout_vocab 133 | # copy_probabilities.scatter_add_(1, encoder_input, attn_weights) 134 | # except RuntimeError: 135 | # print("hajraat hajraat hajraat ", output_probabilities.shape) 136 | 137 | total_probabilities=pgen*output_probabilities+pcopy*copy_probabilities 138 | return total_probabilities, pgen # batchXseqlenXoutvocab , batchsizeXseqlenX1 139 | 140 | 
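The scatter_add_ in CopyMechanism.forward above is the step that folds attention over source positions into a distribution over (extended) vocabulary ids; repeated source tokens pool their attention mass. A stripped-down single-step illustration with made-up ids and weights:

    import torch

    vocab_size, max_oovs = 6, 2
    attn = torch.tensor([[[0.5, 0.3, 0.2]]])  # batch=1, dec_len=1, enc_len=3
    src_ids = torch.tensor([[[4, 7, 4]]])     # id 7 >= vocab_size, i.e. an article OOV slot
    copy_probs = torch.zeros(1, 1, vocab_size + max_oovs)
    copy_probs.scatter_add_(2, src_ids, attn)
    print(copy_probs)  # 0.7 pooled onto id 4 (it occurs twice in the source), 0.3 onto OOV id 7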
141 | class DualEmbedder(nn.Module): 142 | def __init__( 143 | self, token_vocab_size, token_emb_size, conditioning_vocab_size, conditioning_emb_size): 144 | super(DualEmbedder, self).__init__() 145 | self.token_embedder = nn.Embedding(num_embeddings=token_vocab_size, embedding_dim=token_emb_size) 146 | self.conditioning_embedder = nn.Embedding(num_embeddings=conditioning_vocab_size, embedding_dim=conditioning_emb_size) 147 | 148 | def forward( 149 | self, token_batch, condition_batch): 150 | '''token_batch = batchXseqlen 151 | condition_batch = batch''' 152 | embedded_token_seq = self.token_embedder(token_batch) 153 | embedded_conditioning_seq = self.conditioning_embedder(condition_batch.unsqueeze(1)) 154 | 155 | seq_len = token_batch.shape[1] 156 | embedded_conditioning_seq_repeated = embedded_conditioning_seq.repeat(1,seq_len,1) 157 | return torch.cat([embedded_token_seq, embedded_conditioning_seq_repeated], dim=-1) 158 | 159 | 160 | @Model.register("conditioned_pointer_generator") 161 | class Seq2Seq(Model): 162 | def __init__(self, vocab, hidden_size=256, token_emb_size=128, num_encoder_layers=1, num_decoder_layers=1, min_decode_length=0, max_decode_length=9999, use_copy_mech=True, conditioning_emb_size=32): 163 | super().__init__(vocab) 164 | # self.vocab=vocab 165 | 166 | ## vocab related setup begins 167 | # assert "tokens" in vocab._token_to_index and len(vocab._token_to_index.keys())==1, "Vocabulary must have tokens as the only namespace" 168 | self.vocab_size=vocab.get_vocab_size(namespace="tokens") 169 | self.PAD_ID = vocab.get_token_index(vocab._padding_token, namespace="tokens") 170 | self.OOV_ID = vocab.get_token_index(vocab._oov_token, namespace="tokens") 171 | self.START_ID = vocab.get_token_index(START_SYMBOL, namespace="tokens") 172 | self.END_ID = vocab.get_token_index(END_SYMBOL, namespace="tokens") 173 | 174 | self.conditioning_vocab_size = vocab.get_vocab_size(namespace="section_labels") 175 | ## vocab related setup ends 176 | 177 | 178 | self.token_emb_size = token_emb_size 179 | self.conditioning_emb_size = conditioning_emb_size 180 | self.hidden_size=hidden_size 181 | self.num_encoder_layers=num_encoder_layers 182 | self.num_decoder_layers=num_decoder_layers 183 | self.crossentropy=nn.CrossEntropyLoss() 184 | 185 | self.min_decode_length = min_decode_length 186 | self.max_decode_length = max_decode_length 187 | 188 | self.metrics = { 189 | "accuracy" : CategoricalAccuracy(), 190 | } 191 | 192 | # buffers, because these don't need grads; they are placed here so that they get replicated across gpus
193 | self.register_buffer("true_rep", torch.tensor(1.0)) 194 | self.register_buffer("false_rep", torch.tensor(0.0)) 195 | 196 | self.pre_output_dim=hidden_size 197 | 198 | self.use_copy_mech=use_copy_mech 199 | 200 | self.output_embedder = DualEmbedder(self.vocab_size, self.token_emb_size, self.conditioning_vocab_size, self.conditioning_emb_size) 201 | 202 | self.encoder_rnn = torch.nn.LSTM(input_size=self.token_emb_size+self.conditioning_emb_size, hidden_size=self.hidden_size, num_layers=self.num_encoder_layers, batch_first=True, bidirectional=True) 203 | 204 | self.fuse_h_layer= nn.Sequential( 205 | nn.Linear(2*hidden_size, hidden_size), 206 | nn.ReLU() 207 | ) 208 | 209 | self.fuse_c_layer= nn.Sequential( 210 | nn.Linear(2*hidden_size, hidden_size), 211 | nn.ReLU() 212 | ) 213 | 214 | self.attention_layer=Attention(2*hidden_size, 2*hidden_size, 2*hidden_size) 215 | 216 | if self.use_copy_mech: 217 | self.copymech=CopyMechanism(2*self.hidden_size, self.hidden_size, self.token_emb_size+self.conditioning_emb_size) 218 | 219 | self.decoder_rnn=torch.nn.LSTM(input_size=self.token_emb_size+self.conditioning_emb_size, hidden_size=self.hidden_size, num_layers=self.num_decoder_layers, batch_first=False, bidirectional=False) 220 | 221 | self.statenctx_to_prefinal = nn.Linear(3*hidden_size, hidden_size, bias=True) 222 | self.project_to_decoder_input = nn.Linear(self.token_emb_size+self.conditioning_emb_size+2*hidden_size, self.token_emb_size+self.conditioning_emb_size, bias=True) 223 | 224 | self.output_projector = torch.nn.Conv1d(self.pre_output_dim, self.vocab_size, kernel_size=1, bias=True) 225 | self.softmax = nn.Softmax(dim=-1) 226 | 227 | 228 | def forward(self, source_tokens, target_tokens, section_labels, meta=None, only_predict_probs=False, return_pgen=False): 229 | inp_with_unks = source_tokens["ids_with_unks"] 230 | inp_with_oovs = source_tokens["ids_with_oovs"] 231 | max_oovs = int(torch.max(source_tokens["num_oovs"])) 232 | 233 | feed_tensor = target_tokens["ids_with_unks"][:, :-1] 234 | if self.use_copy_mech: 235 | target_tensor = target_tokens["ids_with_oovs"][:,1:] 236 | else: 237 | target_tensor = target_tokens["ids_with_unks"][:, 1:] 238 | 239 | 240 | batch_size = inp_with_unks.size(0) 241 | # preparing the initial state for feeding into the decoder; decoder layers after the first get zeros as their initial state
242 | inp_enc_seq, (last_h_value, last_c_value) = self.encode(inp_with_unks, section_labels) 243 | 244 | 245 | # inp_enc_seq is batchsizeXseqlenX2*hiddensize 246 | h_value = self.pad_zeros_to_init_state(last_h_value) 247 | c_value = self.pad_zeros_to_init_state(last_c_value) 248 | state_from_inp = (h_value, c_value) 249 | 250 | input_pad_mask=torch.where(inp_with_unks!=self.PAD_ID, self.true_rep, self.false_rep) 251 | 252 | output_embedded = self.output_embedder(feed_tensor, section_labels) 253 | seqlen_first = output_embedded.permute(1,0,2) 254 | output_seq_len = seqlen_first.size(0) 255 | 256 | #initial values 257 | decoder_hidden_state=state_from_inp 258 | context_vector=torch.zeros(batch_size,1,2*self.hidden_size).cuda() 259 | 260 | # CONTROVERSIAL DIFFERENCE FROM SEE ET AL: the very first context vector is computed from the initial decoder state instead of being left as zeros 261 | decoder_hstates_batchfirst = state_from_inp[0].permute(1, 0, 2) 262 | decoder_cstates_batchfirst = state_from_inp[1].permute(1, 0, 2) 263 | concatenated_decoder_states = torch.cat([decoder_cstates_batchfirst, decoder_hstates_batchfirst], dim=-1) 264 | context_vector, _ = self.attention_layer(inp_enc_seq, concatenated_decoder_states, input_pad_mask) 265 | # 266 | 267 | output_probs=[] 268 | pgens=[] 269 | 270 | for _i in range(output_seq_len): 271 | seqlen_first_onetimestep = seqlen_first[_i:_i+1] # shape is 1xbatchsizexembsize 272 | context_vector_seqlenfirst = context_vector.permute(1,0,2) # seqlen is 1 always 273 | pre_input_to_decoder=torch.cat([seqlen_first_onetimestep, context_vector_seqlenfirst], dim=-1) 274 | input_to_decoder=self.project_to_decoder_input(pre_input_to_decoder) # shape is 1xbatchsizexembsize 275 | 276 | decoder_h_values, decoder_hidden_state = self.decoder_rnn(input_to_decoder, decoder_hidden_state) 277 | # decoder_h_values is shape 1XbatchsizeXhiddensize 278 | 279 | decoder_h_values_batchfirst = decoder_h_values.permute(1,0,2) 280 | 281 | decoder_hstates_batchfirst = decoder_hidden_state[0].permute(1, 0, 2) 282 | decoder_cstates_batchfirst = decoder_hidden_state[1].permute(1, 0, 2) 283 | concatenated_decoder_states = torch.cat([decoder_cstates_batchfirst, decoder_hstates_batchfirst], dim=-1) 284 | 285 | context_vector, attn_weights = self.attention_layer(inp_enc_seq, concatenated_decoder_states, input_pad_mask) 286 | 287 | decstate_and_context=torch.cat([decoder_h_values_batchfirst, context_vector], dim=-1) #batchsizeXdec_seqlenX3*hidden_size 288 | prefinal_tensor = self.statenctx_to_prefinal(decstate_and_context) 289 | seqlen_last = prefinal_tensor.permute(0,2,1) #batchsizeXpre_output_dimXdec_seqlen 290 | logits = self.output_projector(seqlen_last) 291 | logits = logits.permute(0,2,1) # batchXdec_seqlenXvocab 292 | 293 | # return self.copymech.output_probs(logits) 294 | 295 | # now doing copymechanism 296 | if self.use_copy_mech: 297 | probs_after_copying, pgen = self.copymech(logits, attn_weights, concatenated_decoder_states, input_to_decoder.permute(1,0,2), context_vector, inp_with_oovs, max_oovs) 298 | pgens.append(pgen) 299 | output_probs.append(probs_after_copying) 300 | else: 301 | output_probs.append(self.softmax(logits)) 302 | 303 | # if only_predict_probs: 304 | # return output_probs 305 | 306 | # now calculating loss and numpreds 307 | '''outprobs is list of batchX1xvocabsize 308 | target_tensor is batchXseqlen''' 309 | targets_tensor_seqfirst = target_tensor.permute(1,0) 310 | pad_mask=torch.where(targets_tensor_seqfirst!=self.PAD_ID, self.true_rep, self.false_rep) 311 | # TODO: SHOULD WE SET REQUIRES_GRAD=FALSE FOR PAD_MASK?
312 | 313 | loss=0.0 314 | numpreds=0 315 | total_pgen=0 316 | 317 | total_pgen_placewise=torch.zeros((output_seq_len)).cuda() 318 | numpreds_placewise=torch.zeros((output_seq_len)).cuda() 319 | 320 | if return_pgen and not self.use_copy_mech: 321 | raise ValueError("Cannot return pgen when copy mechanism is switched off") 322 | 323 | 324 | for _i in range(len(output_probs)): 325 | predicted_probs = output_probs[_i].squeeze(1) 326 | true_labels = targets_tensor_seqfirst[_i] 327 | mask_labels = pad_mask[_i] 328 | selected_probs=torch.gather(input=predicted_probs, dim=1, index=true_labels.unsqueeze(1)) 329 | selected_probs=selected_probs.squeeze(1) 330 | selected_neg_logprobs=-1*torch.log(selected_probs) 331 | loss+=torch.sum(selected_neg_logprobs*mask_labels) 332 | 333 | this_numpreds=torch.sum(mask_labels).detach() 334 | numpreds+=this_numpreds 335 | 336 | for metric in self.metrics.values(): 337 | metric(predicted_probs, true_labels, mask_labels) 338 | 339 | if return_pgen: 340 | pgen=pgens[_i].squeeze(1).squeeze(1) 341 | total_pgen+=torch.sum(pgen*mask_labels) 342 | 343 | total_pgen_placewise[_i]+=torch.sum(pgen*mask_labels).detach() 344 | numpreds_placewise[_i]+=this_numpreds 345 | 346 | # print(pgen.shape, mask_labels.shape , (pgen*mask_labels).shape, selected_neg_logprobs.shape) 347 | # print(torch.sum(pgen), torch.sum(pgen*mask_labels), torch.sum(mask_labels).detach()) 348 | 349 | if return_pgen: 350 | return { 351 | "loss": loss/numpreds, 352 | "total_pgen": total_pgen, 353 | "numpreds_placewise": numpreds_placewise, 354 | "total_pgen_placewise": total_pgen_placewise 355 | } 356 | else: 357 | return { 358 | "loss": loss/numpreds 359 | } 360 | 361 | 362 | def pad_zeros_to_init_state(self, h_value): 363 | '''can also be c_value''' 364 | assert(h_value.size(0)==1) # h_value should only be of last layer of lstm 365 | return torch.cat([h_value]+[torch.zeros_like(h_value) for _i in range(self.num_encoder_layers-1)], dim=0) 366 | 367 | 368 | def encode(self, inp, section_labels): 369 | '''Get the encoding of input''' 370 | batch_size = inp.size(0) 371 | inp_seq_len = inp.size(1) 372 | inp_embedded = self.output_embedder(inp, section_labels) 373 | inp_encoded = self.encoder_rnn(inp_embedded) 374 | output_seq=inp_encoded[0] 375 | h_value, c_value = inp_encoded[1] 376 | 377 | h_value_layerwise=h_value.reshape(self.num_encoder_layers, 2, batch_size, self.hidden_size) # numlayersXbidirecXbatchXhid 378 | c_value_layerwise=c_value.reshape(self.num_encoder_layers, 2, batch_size, self.hidden_size) # numlayersXbidirecXbatchXhid 379 | 380 | last_layer_h=h_value_layerwise[-1:,:,:,:] 381 | last_layer_c=c_value_layerwise[-1:,:,:,:] 382 | 383 | last_layer_h=last_layer_h.permute(0,2,1,3).contiguous().view(1, batch_size, 2*self.hidden_size) 384 | last_layer_c=last_layer_c.permute(0,2,1,3).contiguous().view(1, batch_size, 2*self.hidden_size) 385 | 386 | last_layer_h_fused=self.fuse_h_layer(last_layer_h) 387 | last_layer_c_fused=self.fuse_c_layer(last_layer_c) 388 | 389 | return output_seq, (last_layer_h_fused, last_layer_c_fused) 390 | 391 | 392 | def decode_onestep(self, past_outp_input, section_label, past_state_tuple, past_context_vector, inp_enc_seq, inp_with_oovs, input_pad_mask, max_oovs): 393 | '''run one step of decoder.
past_outp_input is batchsizex1 394 | section_label is batchsize 395 | past_context_vector is batchsizeX1Xtwice_of_hiddensize''' 396 | outp_embedded = self.output_embedder(past_outp_input, section_label) 397 | tok_seqlen_first = outp_embedded.permute(1,0,2) 398 | assert(tok_seqlen_first.size(0)==1) # only one timestep allowed 399 | 400 | context_vector_seqlenfirst = past_context_vector.permute(1,0,2) # seqlen is 1 always 401 | pre_input_to_decoder=torch.cat([tok_seqlen_first, context_vector_seqlenfirst], dim=-1) 402 | input_to_decoder=self.project_to_decoder_input(pre_input_to_decoder) # shape is 1xbatchsizexembsize 403 | 404 | 405 | decoder_h_values, decoder_hidden_state = self.decoder_rnn(input_to_decoder, past_state_tuple) 406 | # decoder_h_values is shape 1XbatchsizeXhiddensize 407 | decoder_h_values_batchfirst = decoder_h_values.permute(1,0,2) 408 | 409 | decoder_hstates_batchfirst = decoder_hidden_state[0].permute(1, 0, 2) 410 | decoder_cstates_batchfirst = decoder_hidden_state[1].permute(1, 0, 2) 411 | concatenated_decoder_states = torch.cat([decoder_cstates_batchfirst, decoder_hstates_batchfirst], dim=-1) 412 | 413 | context_vector, attn_weights = self.attention_layer(inp_enc_seq, concatenated_decoder_states, input_pad_mask) 414 | 415 | decstate_and_context=torch.cat([decoder_h_values_batchfirst, context_vector], dim=-1) #batchsizeXdec_seqlenX3*hidden_size 416 | prefinal_tensor = self.statenctx_to_prefinal(decstate_and_context) 417 | seqlen_last = prefinal_tensor.permute(0,2,1) #batchsizeXpre_output_dimXdec_seqlen 418 | logits = self.output_projector(seqlen_last) 419 | logits = logits.permute(0,2,1) # batchXdec_seqlenXvocab 420 | 421 | # now doing copymechanism 422 | if self.use_copy_mech: 423 | probs_after_copying, _ = self.copymech(logits, attn_weights, concatenated_decoder_states, input_to_decoder.permute(1,0,2), context_vector, inp_with_oovs, max_oovs) 424 | prob_to_return = probs_after_copying[0].squeeze(1) 425 | else: 426 | prob_to_return = self.softmax(logits).squeeze(1) 427 | 428 | # max_attended = inp_with_oovs[0][torch.argmax(attn_weights)].item() 429 | # max_prob = torch.argmax(probs_after_copying[0][0]) 430 | # print("Attended=", self.vocab._id2token[max_attended]) 431 | # print("Maxprob=", self.vocab._id2token[max_prob]) 432 | 433 | return prob_to_return, decoder_hidden_state, context_vector 434 | 435 | 436 | 437 | 438 | @overrides 439 | def forward_on_instance(self, instance: SyncedFieldsInstance, decode_strategy=None) -> Dict[str, str]: 440 | """ 441 | Takes an :class:`~allennlp.data.instance.Instance`, which typically has raw text in it, 442 | converts that text into arrays using this model's :class:`Vocabulary`, passes those arrays 443 | through :func:`self.forward()` and :func:`self.decode()` (which by default does nothing) 444 | and returns the result. Before returning the result, we convert any 445 | ``torch.Tensors`` into numpy arrays and remove the batch dimension. 
446 | """ 447 | cuda_device = self._get_prediction_device() 448 | dataset = Batch([instance]) 449 | dataset.index_instances(self.vocab) 450 | model_input = util.move_to_device(dataset.as_tensor_dict(), cuda_device) 451 | output_ids = self.beam_search_decode(**model_input, min_length=self.min_decode_length, max_length=self.max_decode_length) 452 | # output_ids = self.greedy_decode(**model_input, min_length=self.min_decode_length, max_length=self.max_decode_length) 453 | 454 | output_words = [] 455 | for _id in output_ids: 456 | if _id Dict[str, float]: 470 | metrics_to_return = { 471 | metric_name: metric.get_metric(reset) for metric_name, metric in self.metrics.items() 472 | } 473 | return metrics_to_return 474 | 475 | 476 | def beam_search_decode(self, source_tokens, section_labels, target_tokens=None, meta=None, beam_width=4, min_length=35, max_length=120): 477 | inp_with_unks = source_tokens["ids_with_unks"] 478 | inp_with_oovs = source_tokens["ids_with_oovs"] 479 | max_oovs = int(torch.max(source_tokens["num_oovs"])) 480 | input_pad_mask=torch.where(inp_with_unks!=self.PAD_ID, self.true_rep, self.false_rep) 481 | inp_enc_seq, (intial_h_value, intial_c_value) = self.encode(inp_with_unks, section_labels) 482 | h_value = self.pad_zeros_to_init_state(intial_h_value) 483 | c_value = self.pad_zeros_to_init_state(intial_c_value) 484 | source_encoding=(h_value, c_value) 485 | 486 | # the first context vector is calculated by using the first lstm decoder state 487 | first_decoder_hstates_batchfirst = source_encoding[0].permute(1, 0, 2) 488 | first_decoder_cstates_batchfirst = source_encoding[1].permute(1, 0, 2) 489 | first_concatenated_decoder_states = torch.cat([first_decoder_cstates_batchfirst, first_decoder_hstates_batchfirst], dim=-1) 490 | first_context_vector, _ = self.attention_layer(inp_enc_seq, first_concatenated_decoder_states, input_pad_mask) 491 | 492 | hypotheses = [ {"dec_state" : source_encoding, 493 | "past_context_vector" : first_context_vector, 494 | "logprobs" : [0.0], 495 | "out_words" : [self.START_ID] 496 | } ] 497 | 498 | finished_hypotheses = [] 499 | 500 | def sort_hyps(list_of_hyps): 501 | return sorted(list_of_hyps, key=lambda x:sum(x["logprobs"])/len(x["logprobs"]), reverse=True) 502 | 503 | counter=0 504 | while counter=self.vocab_size: # this guy is an OOV 511 | in_tok=self.OOV_ID 512 | old_dec_state=hyp["dec_state"] 513 | past_context_vector=hyp["past_context_vector"] 514 | old_logprobs=hyp["logprobs"] 515 | new_probs, new_dec_state, new_context_vector = self.decode_onestep( torch.tensor([[in_tok]]).cuda(), section_labels, old_dec_state, past_context_vector, inp_enc_seq, inp_with_oovs, input_pad_mask, max_oovs) 516 | 517 | probs, indices = torch.topk(new_probs[0], dim=0, k=2*beam_width) 518 | for p, idx in zip(probs, indices): 519 | new_dict = {"dec_state" : new_dec_state, 520 | "past_context_vector" : new_context_vector, 521 | "logprobs" : old_logprobs+[float(torch.log(p).detach().cpu().numpy())], 522 | "out_words" : old_out_words+[idx.item()] 523 | } 524 | new_hypotheses.append(new_dict) 525 | 526 | # time to pick the best of new hypotheses 527 | sorted_new_hypotheses = sort_hyps(new_hypotheses) 528 | hypotheses=[] 529 | for hyp in sorted_new_hypotheses: 530 | if hyp["out_words"][-1]==self.END_ID: 531 | if len(hyp["out_words"])>min_length+1: 532 | finished_hypotheses.append(hyp) 533 | else: 534 | hypotheses.append(hyp) 535 | if len(hypotheses) == beam_width or len(finished_hypotheses) == beam_width: 536 | break 537 | 538 | # for hyp in finished_hypotheses: 539 | 
# print(hyp["out_words"]) 540 | 541 | if len(finished_hypotheses)>0: 542 | final_candidates = finished_hypotheses 543 | else: 544 | final_candidates = hypotheses 545 | 546 | sorted_final_candidates = sort_hyps(final_candidates) 547 | 548 | best_candidate = sorted_final_candidates[0] 549 | # second_best_candidate = sorted_final_candidates[1] 550 | 551 | 552 | # print(best_candidate["logprobs"]) 553 | return best_candidate["out_words"] #, best_candidate["log_likelihood"] 554 | 555 | 556 | def greedy_decode(self, source_tokens, section_labels, target_tokens=None, meta=None, min_length=35, max_length=120): 557 | inp_with_unks = source_tokens["ids_with_unks"] 558 | inp_with_oovs = source_tokens["ids_with_oovs"] 559 | max_oovs = int(torch.max(source_tokens["num_oovs"])) 560 | input_pad_mask=torch.where(inp_with_unks!=self.PAD_ID, self.true_rep, self.false_rep) 561 | inp_enc_seq, (intial_h_value, intial_c_value) = self.encode(inp_with_unks, section_labels) 562 | h_value = self.pad_zeros_to_init_state(intial_h_value) 563 | c_value = self.pad_zeros_to_init_state(intial_c_value) 564 | source_encoding=(h_value, c_value) 565 | 566 | # the first context vector is calculated by using the first lstm decoder state 567 | first_decoder_hstates_batchfirst = source_encoding[0].permute(1, 0, 2) 568 | first_decoder_cstates_batchfirst = source_encoding[1].permute(1, 0, 2) 569 | first_concatenated_decoder_states = torch.cat([first_decoder_cstates_batchfirst, first_decoder_hstates_batchfirst], dim=-1) 570 | first_context_vector, _ = self.attention_layer(inp_enc_seq, first_concatenated_decoder_states, input_pad_mask) 571 | 572 | hyp = {"dec_state" : source_encoding, 573 | "past_context_vector" : first_context_vector, 574 | "logprobs" : [0.0], 575 | "out_words" : [self.START_ID] 576 | } 577 | 578 | counter=0 579 | while counter=self.vocab_size: # this guy is an OOV 585 | in_tok=self.OOV_ID 586 | old_dec_state=hyp["dec_state"] 587 | past_context_vector=hyp["past_context_vector"] 588 | old_logprobs=hyp["logprobs"] 589 | new_probs, new_dec_state, new_context_vector = self.decode_onestep( torch.tensor([[in_tok]]).cuda(), section_labels, old_dec_state, past_context_vector, inp_enc_seq, inp_with_oovs, input_pad_mask, max_oovs) 590 | 591 | probs, indices = torch.topk(new_probs[0], dim=0, k=1) 592 | assert len(probs)==1 and len(indices)==1 593 | p = probs[0] 594 | idx = indices[0] 595 | hyp = {"dec_state" : new_dec_state, 596 | "past_context_vector" : new_context_vector, 597 | "logprobs" : old_logprobs+[float(torch.log(p).detach().cpu().numpy())], 598 | "out_words" : old_out_words+[idx.item()] 599 | } 600 | 601 | # time to pick the best of new hypotheses 602 | if hyp["out_words"][-1]==self.END_ID: 603 | if len(hyp["out_words"])>min_length+1: 604 | break 605 | 606 | best_candidate = hyp 607 | 608 | # print(best_candidate["logprobs"]) 609 | return best_candidate["out_words"] #, best_candidate["log_likelihood"] 610 | -------------------------------------------------------------------------------- /pointergen/conditioned_model_with_coverage.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from allennlp.models.model import Model 9 | from pointergen.custom_instance import SyncedFieldsInstance 10 | from typing import Dict 11 | from overrides import overrides 12 | from allennlp.data.dataset import Batch 13 | from allennlp.nn import util 14 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 15 
15 | from allennlp.training.metrics import CategoricalAccuracy, Average 16 | 17 | import pdb 18 | 19 | 20 | EPS=1e-8 21 | 22 | 23 | 24 | def add_with_expansion(A, B): 25 | '''A and B must be one-dimensional''' 26 | assert A.ndim==1 and B.ndim==1 27 | shape_diff = np.array(B.shape) - np.array(A.shape) 28 | shape_diff = np.clip(shape_diff, a_min=0, a_max=np.inf).astype(np.int32) 29 | padded_A=np.pad(A, ((0,shape_diff[0]),), 'constant', constant_values=(0)) 30 | 31 | shape_diff = np.array(A.shape) - np.array(B.shape) 32 | shape_diff = np.clip(shape_diff, a_min=0, a_max=np.inf).astype(np.int32) 33 | padded_B=np.pad(B, ((0,shape_diff[0]),), 'constant', constant_values=(0)) 34 | 35 | return padded_A+padded_B 36 | 37 | 38 | def uniform_tensor(shape, a, b): 39 | output = torch.FloatTensor(*shape).uniform_(a, b) 40 | return output 41 | 42 | class Attention(nn.Module): 43 | def __init__(self, total_encoder_hidden_size, total_decoder_hidden_size, attn_vec_size): 44 | super(Attention, self).__init__() 45 | self.total_encoder_hidden_size=total_encoder_hidden_size 46 | self.total_decoder_hidden_size=total_decoder_hidden_size 47 | self.attn_vec_size=attn_vec_size 48 | 49 | 50 | self.Wh_layer=nn.Linear(total_encoder_hidden_size, attn_vec_size, bias=False) 51 | self.Ws_layer=nn.Linear(total_decoder_hidden_size, attn_vec_size, bias=True) 52 | self.selector_vector_layer=nn.Linear(attn_vec_size, 1, bias=False) # called 'v' in See et al. 53 | 54 | self.Wc_layer=nn.Linear(1, attn_vec_size, bias=False) 55 | torch.nn.init.zeros_(self.Wc_layer.weight) 56 | 57 | 58 | def forward(self, encoded_seq, decoder_state, input_pad_mask, coverage=None): 59 | ''' 60 | encoded seq is batchsizexenc_seqlenxtotal_encoder_hidden_size 61 | decoder_state is batchsizexdec_seqlenxtotal_decoder_hidden_size 62 | coverage = batchsizexdec_seqlenxenc_seqlen 63 | ''' 64 | 65 | projected_decstates = self.Ws_layer(decoder_state) 66 | projected_encstates = self.Wh_layer(encoded_seq) 67 | added_projections=projected_decstates.unsqueeze(2)+projected_encstates.unsqueeze(1) #batchsizeXdeclenXenclenXattnvecsize 68 | 69 | if coverage is not None: 70 | projected_coverage = self.Wc_layer(coverage.unsqueeze(-1)) # shape = batchsize X dec_seqlen x enc_seqlen X attn_vec_size 71 | added_projections += projected_coverage 72 | 73 | added_projections=torch.tanh(added_projections) 74 | 75 | attn_logits=self.selector_vector_layer(added_projections) 76 | attn_logits=attn_logits.squeeze(3) 77 | 78 | attn_weights = torch.softmax(attn_logits, dim=-1) # shape=batchXdec_lenXenc_len 79 | attn_weights2 = attn_weights*input_pad_mask.unsqueeze(1) 80 | attn_weights_renormalized = attn_weights2/torch.sum(attn_weights2, dim=-1, keepdim=True) # denominator shape=batchXdec_lenX1 # TODO - why is there a division without EPS ?
81 | 82 | context_vector = torch.sum(encoded_seq.unsqueeze(1)*attn_weights_renormalized.unsqueeze(-1) , dim=-2) 83 | # shape batchXdec_seqlenXhiddensize 84 | # print(context_vector) 85 | 86 | return context_vector, attn_weights_renormalized 87 | 88 | 89 | 90 | 91 | class CopyMechanism(nn.Module): 92 | def __init__( 93 | self, encoder_hidden_size, decoder_hidden_size, decoder_input_size): 94 | super(CopyMechanism, self).__init__() 95 | self.pgen=nn.Sequential( 96 | nn.Linear(encoder_hidden_size+2*decoder_hidden_size+decoder_input_size, 1), 97 | nn.Sigmoid() 98 | ) 99 | self.output_probs=nn.Softmax(dim=-1) 100 | 101 | def forward( 102 | self, output_logits, attn_weights, decoder_hidden_state, decoder_input, 103 | context_vector, encoder_input, max_oovs): 104 | '''output_logits = batchXseqlenXoutvocab 105 | attn_weights = batchXseqlenXenc_len 106 | decoder_hidden_state = batchXseqlenXdecoder_hidden_size 107 | context_vector = batchXseqlenXencoder_hidden_dim 108 | encoder_input = batchxenc_len''' 109 | output_probabilities=self.output_probs(output_logits) 110 | 111 | # print(output_probabilities) 112 | 113 | batch_size = output_probabilities.size(0) 114 | output_len = output_probabilities.size(1) 115 | append_for_copy = torch.zeros((batch_size, output_len, max_oovs)).cuda() 116 | output_probabilities=torch.cat([output_probabilities, append_for_copy], dim=-1) 117 | 118 | pre_pgen_tensor=torch.cat([context_vector, decoder_hidden_state, decoder_input], dim=-1) 119 | pgen=self.pgen(pre_pgen_tensor) # batchsizeXseqlenX1 120 | pcopy=1.0-pgen 121 | 122 | encoder_input=encoder_input.unsqueeze(1).expand(-1, output_len , -1) # batchXseqlenXenc_len 123 | 124 | # Note that padding words do not get any attention because the attention is a masked attention 125 | 126 | copy_probabilities=torch.zeros_like(output_probabilities) # batchXseqlenXoutvocab 127 | copy_probabilities.scatter_add_(2, encoder_input, attn_weights) 128 | 129 | total_probabilities=pgen*output_probabilities+pcopy*copy_probabilities 130 | return total_probabilities, pgen # batchXseqlenXoutvocab , batchsizeXseqlenX1 131 | 132 | 133 | 134 | class DualEmbedder(nn.Module): 135 | def __init__( 136 | self, token_vocab_size, token_emb_size, conditioning_vocab_size, conditioning_emb_size): 137 | super(DualEmbedder, self).__init__() 138 | self.token_embedder = nn.Embedding(num_embeddings=token_vocab_size, embedding_dim=token_emb_size) 139 | self.conditioning_embedder = nn.Embedding(num_embeddings=conditioning_vocab_size, embedding_dim=conditioning_emb_size) 140 | 141 | def forward( 142 | self, token_batch, condition_batch): 143 | '''token_batch = batchXseqlen 144 | condition_batch = batch''' 145 | embedded_token_seq = self.token_embedder(token_batch) 146 | embedded_conditioning_seq = self.conditioning_embedder(condition_batch.unsqueeze(1)) 147 | 148 | seq_len = token_batch.shape[1] 149 | embedded_conditioning_seq_repeated = embedded_conditioning_seq.repeat(1,seq_len,1) 150 | return torch.cat([embedded_token_seq, embedded_conditioning_seq_repeated], dim=-1) 151 | 152 | 153 | @Model.register("conditioned_pointer_generator_withcoverage") 154 | class Seq2Seq(Model): 155 | def __init__(self, vocab, hidden_size=256, token_emb_size=128, num_encoder_layers=1, num_decoder_layers=1, min_decode_length=0, max_decode_length=9999, use_copy_mech=True, initial_precoverage_paramfile=None, coverage_coef=1.0, conditioning_emb_size=32): 156 | super().__init__(vocab) 157 | # self.vocab=vocab 158 | 159 | ## vocab related setup begins
160 | self.vocab_size=vocab.get_vocab_size(namespace="tokens") 161 | self.PAD_ID = vocab.get_token_index(vocab._padding_token, namespace="tokens") 162 | self.OOV_ID = vocab.get_token_index(vocab._oov_token, namespace="tokens") 163 | self.START_ID = vocab.get_token_index(START_SYMBOL, namespace="tokens") 164 | self.END_ID = vocab.get_token_index(END_SYMBOL, namespace="tokens") 165 | 166 | 167 | self.conditioning_vocab_size = vocab.get_vocab_size(namespace="section_labels") 168 | ## vocab related setup ends 169 | 170 | 171 | 172 | self.token_emb_size = token_emb_size 173 | self.conditioning_emb_size = conditioning_emb_size 174 | self.hidden_size=hidden_size 175 | self.num_encoder_layers=num_encoder_layers 176 | self.num_decoder_layers=num_decoder_layers 177 | self.crossentropy=nn.CrossEntropyLoss() 178 | 179 | self.min_decode_length = min_decode_length 180 | self.max_decode_length = max_decode_length 181 | 182 | self.coverage_coef = coverage_coef 183 | 184 | 185 | 186 | self.metrics = { 187 | "accuracy" : CategoricalAccuracy(), 188 | "coverage_loss": Average(), 189 | "nll_loss": Average(), 190 | "total_loss": Average(), 191 | } 192 | 193 | # buffers because these don't need grads. These are placed here because they will be replicated across gpus 194 | self.register_buffer("true_rep", torch.tensor(1.0)) 195 | self.register_buffer("false_rep", torch.tensor(0.0)) 196 | 197 | self.pre_output_dim=hidden_size 198 | 199 | self.use_copy_mech=use_copy_mech 200 | 201 | self.output_embedder = DualEmbedder(self.vocab_size, self.token_emb_size, self.conditioning_vocab_size, self.conditioning_emb_size) 202 | 203 | self.encoder_rnn = torch.nn.LSTM(input_size=self.token_emb_size+self.conditioning_emb_size, hidden_size=self.hidden_size, num_layers=self.num_encoder_layers, batch_first=True, bidirectional=True) 204 | 205 | 206 | self.fuse_h_layer= nn.Sequential( 207 | nn.Linear(2*hidden_size, hidden_size), 208 | nn.ReLU() 209 | ) 210 | 211 | self.fuse_c_layer= nn.Sequential( 212 | nn.Linear(2*hidden_size, hidden_size), 213 | nn.ReLU() 214 | ) 215 | 216 | self.attention_layer=Attention(2*hidden_size, 2*hidden_size, 2*hidden_size) 217 | 218 | if self.use_copy_mech: 219 | self.copymech=CopyMechanism(2*self.hidden_size, self.hidden_size, self.token_emb_size+self.conditioning_emb_size) 220 | 221 | self.decoder_rnn=torch.nn.LSTM(input_size=self.token_emb_size+self.conditioning_emb_size, hidden_size=self.hidden_size, num_layers=self.num_decoder_layers, batch_first=False, bidirectional=False) 222 | 223 | self.statenctx_to_prefinal = nn.Linear(3*hidden_size, hidden_size, bias=True) 224 | self.project_to_decoder_input = nn.Linear(self.token_emb_size+self.conditioning_emb_size+2*hidden_size, self.token_emb_size+self.conditioning_emb_size, bias=True) 225 | 226 | self.output_projector = torch.nn.Conv1d(self.pre_output_dim, self.vocab_size, kernel_size=1, bias=True) 227 | self.softmax = nn.Softmax(dim=-1) 228 | 229 | if initial_precoverage_paramfile is not None: 230 | print(f"Loading precoverage weights from {initial_precoverage_paramfile}") 231 | # this will contain path to a .th weights file that contains pre-coverage weights 232 | if not os.path.exists(initial_precoverage_paramfile): 233 | print("WARNING: PRE-COVERAGE FILE NOT FOUND. STARTING FROM RANDOMLY INITIALIZED PARAMETERS")
234 | else: 235 | pretrained_dict = torch.load(initial_precoverage_paramfile, map_location="cuda:0") 236 | model_dict = self.state_dict() 237 | model_dict.update(pretrained_dict) 238 | self.load_state_dict(model_dict) 239 | 240 | def forward(self, source_tokens, target_tokens, section_labels, meta=None, only_predict_probs=False, return_pgen=False): 241 | inp_with_unks = source_tokens["ids_with_unks"] 242 | inp_with_oovs = source_tokens["ids_with_oovs"] 243 | max_oovs = int(torch.max(source_tokens["num_oovs"])) 244 | 245 | feed_tensor = target_tokens["ids_with_unks"][:, :-1] 246 | if self.use_copy_mech: 247 | target_tensor = target_tokens["ids_with_oovs"][:,1:] 248 | else: 249 | target_tensor = target_tokens["ids_with_unks"][:, 1:] 250 | 251 | 252 | batch_size = inp_with_unks.size(0) 253 | # preparing initial state for feeding into decoder. layers of decoder after first one get zeros as initial state 254 | inp_enc_seq, (last_h_value, last_c_value) = self.encode(inp_with_unks, section_labels) 255 | 256 | 257 | # inp_enc_seq is batchsizeXseqlenX2*hiddensize 258 | h_value = self.pad_zeros_to_init_state(last_h_value) 259 | c_value = self.pad_zeros_to_init_state(last_c_value) 260 | state_from_inp = (h_value, c_value) 261 | 262 | input_pad_mask=torch.where(inp_with_unks!=self.PAD_ID, self.true_rep, self.false_rep) 263 | 264 | output_embedded = self.output_embedder(feed_tensor, section_labels) 265 | seqlen_first = output_embedded.permute(1,0,2) 266 | output_seq_len = seqlen_first.size(0) 267 | 268 | #initial values 269 | decoder_hidden_state=state_from_inp 270 | context_vector=torch.zeros(batch_size,1,2*self.hidden_size).cuda() 271 | 272 | # CONTROVERSIAL DIFFERENCE FROM SEE ET AL 273 | decoder_hstates_batchfirst = state_from_inp[0].permute(1, 0, 2) 274 | decoder_cstates_batchfirst = state_from_inp[1].permute(1, 0, 2) 275 | concatenated_decoder_states = torch.cat([decoder_cstates_batchfirst, decoder_hstates_batchfirst], dim=-1) 276 | context_vector, _ = self.attention_layer(inp_enc_seq, concatenated_decoder_states, input_pad_mask) 277 | # 278 | 279 | output_probs=[] 280 | pgens=[] 281 | coverages = [torch.zeros_like(inp_with_unks).type(torch.float).cuda()] 282 | all_attn_weights = [] 283 | 284 | for _i in range(output_seq_len): 285 | seqlen_first_onetimestep = seqlen_first[_i:_i+1] # shape is 1xbatchsizexembsize 286 | context_vector_seqlenfirst = context_vector.permute(1,0,2) # seqlen is 1 always 287 | pre_input_to_decoder=torch.cat([seqlen_first_onetimestep, context_vector_seqlenfirst], dim=-1) 288 | input_to_decoder=self.project_to_decoder_input(pre_input_to_decoder) # shape is 1xbatchsizexembsize 289 | 290 | decoder_h_values, decoder_hidden_state = self.decoder_rnn(input_to_decoder, decoder_hidden_state) 291 | # decoder_h_values is shape 1XbatchsizeXhiddensize 292 | 293 | decoder_h_values_batchfirst = decoder_h_values.permute(1,0,2) 294 | 295 | decoder_hstates_batchfirst = decoder_hidden_state[0].permute(1, 0, 2) 296 | decoder_cstates_batchfirst = decoder_hidden_state[1].permute(1, 0, 2) 297 | concatenated_decoder_states = torch.cat([decoder_cstates_batchfirst, decoder_hstates_batchfirst], dim=-1) 298 | 299 | prev_coverage = coverages[-1] 300 | 301 | context_vector, attn_weights = self.attention_layer(inp_enc_seq, concatenated_decoder_states, input_pad_mask, prev_coverage.unsqueeze(1)) 302 | 303 | all_attn_weights.append(attn_weights.squeeze(1)) 304 | 305 | coverages.append(prev_coverage + attn_weights.squeeze(1)) 306 |
307 | decstate_and_context=torch.cat([decoder_h_values_batchfirst, context_vector], dim=-1) #batchsizeXdec_seqlenX3*hidden_size 308 | prefinal_tensor = self.statenctx_to_prefinal(decstate_and_context) 309 | seqlen_last = prefinal_tensor.permute(0,2,1) #batchsizeXpre_output_dimXdec_seqlen 310 | logits = self.output_projector(seqlen_last) 311 | logits = logits.permute(0,2,1) # batchXdec_seqlenXvocab 312 | 313 | # now doing copymechanism 314 | if self.use_copy_mech: 315 | probs_after_copying, pgen = self.copymech(logits, attn_weights, concatenated_decoder_states, input_to_decoder.permute(1,0,2), context_vector, inp_with_oovs, max_oovs) 316 | pgens.append(pgen) 317 | output_probs.append(probs_after_copying) 318 | else: 319 | output_probs.append(self.softmax(logits)) 320 | 321 | 322 | # now calculating loss and numpreds 323 | '''outprobs is list of batchX1xvocabsize 324 | target_tensor is batchXseqlen''' 325 | targets_tensor_seqfirst = target_tensor.permute(1,0) 326 | target_pad_mask = torch.where(targets_tensor_seqfirst!=self.PAD_ID, self.true_rep, self.false_rep) 327 | # TODO: SHOULD WE SET REQUIRES_GRAD=FALSE FOR PAD_MASK? 328 | 329 | loss=0.0 330 | numpreds=0 331 | total_pgen=0 332 | 333 | total_pgen_placewise=torch.zeros((output_seq_len)).cuda() 334 | numpreds_placewise=torch.zeros((output_seq_len)).cuda() 335 | 336 | if return_pgen and not self.use_copy_mech: 337 | print("Cannot return pgen when copy mechanism is switched off") 338 | assert False 339 | 340 | for _i in range(len(output_probs)): 341 | predicted_probs = output_probs[_i].squeeze(1) 342 | true_labels = targets_tensor_seqfirst[_i] 343 | mask_labels = target_pad_mask[_i] 344 | selected_probs=torch.gather(input=predicted_probs, dim=1, index=true_labels.unsqueeze(1)) 345 | selected_probs=selected_probs.squeeze(1) 346 | selected_neg_logprobs=-1*torch.log(selected_probs) 347 | loss+=torch.sum(selected_neg_logprobs*mask_labels) 348 | 349 | this_numpreds=torch.sum(mask_labels).detach() 350 | numpreds+=this_numpreds 351 | 352 | self.metrics["accuracy"](predicted_probs, true_labels, mask_labels) 353 | 354 | if return_pgen: 355 | pgen=pgens[_i].squeeze(1).squeeze(1) 356 | total_pgen+=torch.sum(pgen*mask_labels) 357 | 358 | total_pgen_placewise[_i]+=torch.sum(pgen*mask_labels).detach() 359 | numpreds_placewise[_i]+=this_numpreds 360 | 361 | coverage_loss = self.coverage_loss(all_attn_weights, target_pad_mask.permute(1,0)) # have to permute to get batchsize to first dim 362 | nll_loss = loss/numpreds 363 | total_loss = nll_loss + self.coverage_coef*coverage_loss 364 | 365 | self.metrics["coverage_loss"](coverage_loss.item()) 366 | self.metrics["nll_loss"](nll_loss.item()) 367 | self.metrics["total_loss"](total_loss.item()) 368 | 369 | return { 370 | "loss": total_loss 371 | } 372 | 373 | 374 | def coverage_loss(self, all_attn_weights, output_padding_mask): 375 | '''all_attn_weights is list of elems where each elem is batchsizeXinp_enclen 376 | mask is batchsizeXdeclen''' 377 | coverages = [torch.zeros_like(all_attn_weights[0])] 378 | covlosses = [] 379 | for a in all_attn_weights: 380 | old_coverage = coverages[-1] 381 | minimums = torch.min(a, old_coverage) 382 | covloss = torch.sum(minimums, dim=1, keepdim=True) 383 | covlosses.append(covloss) 384 | new_coverage = old_coverage + a 385 | coverages.append(new_coverage) 386 | concatenated_covlosses = torch.cat(covlosses, dim=1) 387 | coverage_loss = torch.sum(concatenated_covlosses*output_padding_mask)/torch.sum(output_padding_mask) 388 | return coverage_loss 389 | 390 | 391 | 392 | def pad_zeros_to_init_state(self, h_value):
393 | '''can also be c_value''' 394 | assert(h_value.size(0)==1) # h_value should only be of last layer of lstm 395 | return torch.cat([h_value]+[torch.zeros_like(h_value) for _i in range(self.num_encoder_layers-1)], dim=0) 396 | 397 | 398 | def encode(self, inp, section_labels): 399 | '''Get the encoding of input''' 400 | batch_size = inp.size(0) 401 | inp_seq_len = inp.size(1) 402 | inp_embedded = self.output_embedder(inp, section_labels) 403 | inp_encoded = self.encoder_rnn(inp_embedded) 404 | output_seq=inp_encoded[0] 405 | h_value, c_value = inp_encoded[1] 406 | 407 | h_value_layerwise=h_value.reshape(self.num_encoder_layers, 2, batch_size, self.hidden_size) # numlayersXbidirecXbatchXhid 408 | c_value_layerwise=c_value.reshape(self.num_encoder_layers, 2, batch_size, self.hidden_size) # numlayersXbidirecXbatchXhid 409 | 410 | last_layer_h=h_value_layerwise[-1:,:,:,:] 411 | last_layer_c=c_value_layerwise[-1:,:,:,:] 412 | 413 | last_layer_h=last_layer_h.permute(0,2,1,3).contiguous().view(1, batch_size, 2*self.hidden_size) 414 | last_layer_c=last_layer_c.permute(0,2,1,3).contiguous().view(1, batch_size, 2*self.hidden_size) 415 | 416 | last_layer_h_fused=self.fuse_h_layer(last_layer_h) 417 | last_layer_c_fused=self.fuse_c_layer(last_layer_c) 418 | 419 | return output_seq, (last_layer_h_fused, last_layer_c_fused) 420 | 421 | 422 | def decode_onestep(self, past_outp_input, section_label, past_state_tuple, past_context_vector, inp_enc_seq, inp_with_oovs, input_pad_mask, max_oovs, past_coverage_vector): 423 | '''run one step of decoder. past_outp_input is batchsizex1 424 | section_label is batchsize 425 | past_context_vector is batchsizeX1Xtwice_of_hiddensize 426 | past_coverage_vector is batchsizeXenc_len''' 427 | outp_embedded = self.output_embedder(past_outp_input, section_label) 428 | tok_seqlen_first = outp_embedded.permute(1,0,2) 429 | assert(tok_seqlen_first.size(0)==1) # only one timestep allowed 430 | 431 | context_vector_seqlenfirst = past_context_vector.permute(1,0,2) # seqlen is 1 always 432 | pre_input_to_decoder=torch.cat([tok_seqlen_first, context_vector_seqlenfirst], dim=-1) 433 | input_to_decoder=self.project_to_decoder_input(pre_input_to_decoder) # shape is 1xbatchsizexembsize 434 | 435 | 436 | decoder_h_values, decoder_hidden_state = self.decoder_rnn(input_to_decoder, past_state_tuple) 437 | # decoder_h_values is shape 1XbatchsizeXhiddensize 438 | decoder_h_values_batchfirst = decoder_h_values.permute(1,0,2) 439 | 440 | decoder_hstates_batchfirst = decoder_hidden_state[0].permute(1, 0, 2) 441 | decoder_cstates_batchfirst = decoder_hidden_state[1].permute(1, 0, 2) 442 | concatenated_decoder_states = torch.cat([decoder_cstates_batchfirst, decoder_hstates_batchfirst], dim=-1) 443 | 444 | context_vector, attn_weights = self.attention_layer(inp_enc_seq, concatenated_decoder_states, input_pad_mask, past_coverage_vector.unsqueeze(1)) 445 | 446 | decstate_and_context=torch.cat([decoder_h_values_batchfirst, context_vector], dim=-1) #batchsizeXdec_seqlenX3*hidden_size 447 | prefinal_tensor = self.statenctx_to_prefinal(decstate_and_context) 448 | seqlen_last = prefinal_tensor.permute(0,2,1) #batchsizeXpre_output_dimXdec_seqlen 449 | logits = self.output_projector(seqlen_last) 450 | logits = logits.permute(0,2,1) # batchXdec_seqlenXvocab 451 | 452 | 453 | # now doing copymechanism 454 | if self.use_copy_mech: 455 | probs_after_copying, _ = self.copymech(logits, attn_weights, concatenated_decoder_states, input_to_decoder.permute(1,0,2), context_vector, inp_with_oovs, max_oovs)
456 | prob_to_return = probs_after_copying[0].squeeze(1) 457 | else: 458 | prob_to_return = self.softmax(logits).squeeze(1) 459 | 460 | 461 | return prob_to_return, decoder_hidden_state, context_vector, attn_weights 462 | 463 | 464 | 465 | def get_initial_state(self, start_ids, section_labels, initial_decode_state): 466 | '''start_ids is tensor of size batchsizeXseqlen, section_labels is batchsize''' 467 | outp_embedded = self.output_embedder(start_ids, section_labels) 468 | seqlen_first = outp_embedded.permute(1,0,2) 469 | feed=seqlen_first 470 | seqlen=feed.size(0) 471 | h_value, c_value = initial_decode_state 472 | for idx in range(seqlen): 473 | _ , (h_value, c_value) = self.decoder_rnn(feed[idx:idx+1], (h_value, c_value)) 474 | 475 | return (h_value, c_value) 476 | 477 | 478 | @overrides 479 | def forward_on_instance(self, instance: SyncedFieldsInstance, decode_strategy=None) -> Dict[str, str]: 480 | """ 481 | Takes an :class:`~allennlp.data.instance.Instance`, which typically has raw text in it, 482 | converts that text into arrays using this model's :class:`Vocabulary`, passes those arrays 483 | through :func:`self.forward()` and :func:`self.decode()` (which by default does nothing) 484 | and returns the result. Before returning the result, we convert any 485 | ``torch.Tensors`` into numpy arrays and remove the batch dimension. 486 | """ 487 | cuda_device = self._get_prediction_device() 488 | dataset = Batch([instance]) 489 | dataset.index_instances(self.vocab) 490 | model_input = util.move_to_device(dataset.as_tensor_dict(), cuda_device) 491 | output_ids = self.beam_search_decode(**model_input, min_length=self.min_decode_length, max_length=self.max_decode_length) 492 | 493 | output_words = [] 494 | for _id in output_ids: 495 | if _id < self.vocab_size: 496 | output_words.append(self.vocab._id2token[_id]) 497 | else: 498 | # ids at or beyond vocab_size index into this instance's source-side OOV list; oov_list and the output key below are assumed names 499 | output_words.append(instance.oov_list[_id - self.vocab_size]) 500 | 501 | output_words = [w for w in output_words if w not in [START_SYMBOL, END_SYMBOL]] 502 | 503 | return {"predicted_summary": " ".join(output_words)} 504 | 505 | 506 | 507 | 508 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 509 | metrics_to_return = { 510 | metric_name: metric.get_metric(reset) for metric_name, metric in self.metrics.items() 511 | } 512 | return metrics_to_return 513 | 514 | 515 | 516 | def beam_search_decode(self, source_tokens, section_labels, target_tokens=None, meta=None, beam_width=4, min_length=35, max_length=120): 517 | inp_with_unks = source_tokens["ids_with_unks"] 518 | inp_with_oovs = source_tokens["ids_with_oovs"] 519 | max_oovs = int(torch.max(source_tokens["num_oovs"])) 520 | input_pad_mask=torch.where(inp_with_unks!=self.PAD_ID, self.true_rep, self.false_rep) 521 | inp_enc_seq, (initial_h_value, initial_c_value) = self.encode(inp_with_unks, section_labels) 522 | h_value = self.pad_zeros_to_init_state(initial_h_value) 523 | c_value = self.pad_zeros_to_init_state(initial_c_value) 524 | source_encoding=(h_value, c_value) 525 | 526 | # the first context vector is calculated by using the first lstm decoder state 527 | first_decoder_hstates_batchfirst = source_encoding[0].permute(1, 0, 2) 528 | first_decoder_cstates_batchfirst = source_encoding[1].permute(1, 0, 2) 529 | first_concatenated_decoder_states = torch.cat([first_decoder_cstates_batchfirst, first_decoder_hstates_batchfirst], dim=-1) 530 | starting_coverage = torch.zeros_like(inp_with_unks).type(torch.float).cuda() 531 | first_context_vector, first_attention = self.attention_layer(inp_enc_seq, first_concatenated_decoder_states, input_pad_mask, starting_coverage.unsqueeze(1)) 532 | 533 | hypotheses = [ {"dec_state" : source_encoding, 534 | "past_context_vector" : first_context_vector, 535 | "logprobs" : [0.0], 536 | "out_words" : [self.START_ID], 537 | "coverage" : first_attention.squeeze(1), 538 | } ] 539 | 540 | finished_hypotheses = [] 541 | 542 | def sort_hyps(list_of_hyps): 543 | return sorted(list_of_hyps, key=lambda x:sum(x["logprobs"])/len(x["logprobs"]), reverse=True)
x:sum(x["logprobs"])/len(x["logprobs"]), reverse=True) 544 | 545 | counter=0 546 | while counter=self.vocab_size: # this guy is an OOV 553 | in_tok=self.OOV_ID 554 | old_dec_state=hyp["dec_state"] 555 | past_context_vector=hyp["past_context_vector"] 556 | past_coverage_vector=hyp["coverage"] 557 | old_logprobs=hyp["logprobs"] 558 | with torch.no_grad(): 559 | new_probs, new_dec_state, new_context_vector, attn_weights = self.decode_onestep( torch.tensor([[in_tok]]).cuda(), section_labels, old_dec_state, past_context_vector, inp_enc_seq, inp_with_oovs, input_pad_mask, max_oovs, past_coverage_vector) 560 | 561 | probs, indices = torch.topk(new_probs[0], dim=0, k=2*beam_width) 562 | for p, idx in zip(probs, indices): 563 | new_dict = {"dec_state" : new_dec_state, 564 | "past_context_vector" : new_context_vector, 565 | "logprobs" : old_logprobs+[float(torch.log(p).detach().cpu().numpy())], 566 | "out_words" : old_out_words+[idx.item()], 567 | "coverage": past_coverage_vector+attn_weights.squeeze(1) 568 | } 569 | new_hypotheses.append(new_dict) 570 | 571 | # time to pick the best of new hypotheses 572 | sorted_new_hypotheses = sort_hyps(new_hypotheses) 573 | hypotheses=[] 574 | for hyp in sorted_new_hypotheses: 575 | if hyp["out_words"][-1]==self.END_ID: 576 | if len(hyp["out_words"])>min_length+1: 577 | finished_hypotheses.append(hyp) 578 | else: 579 | hypotheses.append(hyp) 580 | if len(hypotheses) == beam_width or len(finished_hypotheses) == beam_width: 581 | break 582 | 583 | 584 | if len(finished_hypotheses)>0: 585 | final_candidates = finished_hypotheses 586 | else: 587 | final_candidates = hypotheses 588 | 589 | sorted_final_candidates = sort_hyps(final_candidates) 590 | 591 | best_candidate = sorted_final_candidates[0] 592 | 593 | return best_candidate["out_words"] #, best_candidate["log_likelihood"] 594 | 595 | 596 | 597 | def greedy_decode(self, source_tokens, section_labels, target_tokens=None, meta=None, min_length=35, max_length=120): 598 | inp_with_unks = source_tokens["ids_with_unks"] 599 | inp_with_oovs = source_tokens["ids_with_oovs"] 600 | max_oovs = int(torch.max(source_tokens["num_oovs"])) 601 | input_pad_mask=torch.where(inp_with_unks!=self.PAD_ID, self.true_rep, self.false_rep) 602 | inp_enc_seq, (intial_h_value, intial_c_value) = self.encode(inp_with_unks, section_labels) 603 | h_value = self.pad_zeros_to_init_state(intial_h_value) 604 | c_value = self.pad_zeros_to_init_state(intial_c_value) 605 | source_encoding=(h_value, c_value) 606 | 607 | # the first context vector is calculated by using the first lstm decoder state 608 | first_decoder_hstates_batchfirst = source_encoding[0].permute(1, 0, 2) 609 | first_decoder_cstates_batchfirst = source_encoding[1].permute(1, 0, 2) 610 | first_concatenated_decoder_states = torch.cat([first_decoder_cstates_batchfirst, first_decoder_hstates_batchfirst], dim=-1) 611 | first_context_vector, _ = self.attention_layer(inp_enc_seq, first_concatenated_decoder_states, input_pad_mask) 612 | 613 | hyp = {"dec_state" : source_encoding, 614 | "past_context_vector" : first_context_vector, 615 | "logprobs" : [0.0], 616 | "out_words" : [self.START_ID] 617 | } 618 | 619 | counter=0 620 | while counter=self.vocab_size: # this guy is an OOV 626 | in_tok=self.OOV_ID 627 | old_dec_state=hyp["dec_state"] 628 | past_context_vector=hyp["past_context_vector"] 629 | old_logprobs=hyp["logprobs"] 630 | new_probs, new_dec_state, new_context_vector = self.decode_onestep( torch.tensor([[in_tok]]).cuda(), section_labels, old_dec_state, past_context_vector, 
631 | 632 | probs, indices = torch.topk(new_probs[0], dim=0, k=1) 633 | assert len(probs)==1 and len(indices)==1 634 | p = probs[0] 635 | idx = indices[0] 636 | hyp = {"dec_state" : new_dec_state, 637 | "past_context_vector" : new_context_vector, 638 | "logprobs" : old_logprobs+[float(torch.log(p).detach().cpu().numpy())], 639 | "out_words" : old_out_words+[idx.item()], 640 | "coverage" : past_coverage_vector+attn_weights.squeeze(1) } 641 | 642 | # stop once the end symbol has been produced after the minimum length 643 | if hyp["out_words"][-1]==self.END_ID: 644 | if len(hyp["out_words"])>min_length+1: 645 | break 646 | 647 | best_candidate = hyp 648 | 649 | return best_candidate["out_words"] #, best_candidate["log_likelihood"] 650 | 651 | 652 | --------------------------------------------------------------------------------
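A note on the two computations at the heart of both pointer-generator models above. CopyMechanism.forward mixes the decoder's softmax over the fixed vocabulary with the attention distribution scattered onto the source ids (the vocabulary axis is extended so probability mass can land on source-side OOVs), weighted by the generation probability pgen. Seq2Seq.coverage_loss penalizes, at each decoder step, the overlap between the current attention and the attention accumulated over all previous steps, which discourages re-attending to (and therefore repeating) the same source tokens. The following is a minimal, self-contained sketch of both on random CPU tensors; every size and variable name in it (batch, enc_len, and so on) is an illustrative assumption for the demo, not the repository's configuration.

import torch

# illustrative sizes (assumptions for this sketch only)
batch, dec_len, enc_len, vocab_size, max_oovs = 2, 4, 6, 10, 3

# --- pointer-generator mixture: P = pgen * P_vocab + (1 - pgen) * P_copy ---
logits = torch.randn(batch, dec_len, vocab_size)                            # decoder output logits
attn_weights = torch.softmax(torch.randn(batch, dec_len, enc_len), dim=-1)  # masked attention in the real model
ids_with_oovs = torch.randint(0, vocab_size + max_oovs, (batch, enc_len))   # source ids; OOVs get extended ids
pgen = torch.sigmoid(torch.randn(batch, dec_len, 1))                        # probability of generating vs copying

vocab_probs = torch.softmax(logits, dim=-1)
vocab_probs = torch.cat([vocab_probs, torch.zeros(batch, dec_len, max_oovs)], dim=-1)  # zero mass on OOV slots
copy_probs = torch.zeros_like(vocab_probs)
index = ids_with_oovs.unsqueeze(1).expand(-1, dec_len, -1)                  # batch X dec_len X enc_len
copy_probs.scatter_add_(2, index, attn_weights)                             # pile attention mass onto source ids
total_probs = pgen * vocab_probs + (1.0 - pgen) * copy_probs
assert torch.allclose(total_probs.sum(-1), torch.ones(batch, dec_len), atol=1e-5)      # still a distribution

# --- coverage penalty: sum over steps t and positions i of min(a_t[i], c_t[i]), c_t = a_1 + ... + a_{t-1} ---
coverage = torch.zeros(batch, enc_len)
covloss = torch.zeros(batch)
for t in range(dec_len):
    a_t = attn_weights[:, t, :]
    covloss = covloss + torch.min(a_t, coverage).sum(dim=1)  # overlap with past attention
    coverage = coverage + a_t
print(covloss / dec_len)  # per-example average; the model additionally masks out padded target steps

Because both mixed distributions are normalized, total_probs sums to one, so the training loss can simply gather these probabilities at the gold token ids; the coverage term is then added with weight coverage_coef, as in Seq2Seq.forward above.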