├── issue-20201028.PNG ├── code └── xdai │ ├── ner │ ├── transition_discontinuous │ │ ├── config.json │ │ ├── cadec.sh │ │ ├── test_parsing.py │ │ ├── train.py │ │ └── parsing.py │ ├── convert_text_inline_to_conll.py │ ├── convert_conll_to_text_inline.py │ ├── evaluate.py │ └── mention.py │ ├── utils │ ├── attention.py │ ├── seq2vec.py │ ├── seq2seq.py │ ├── args.py │ ├── token.py │ ├── common.py │ ├── vocab.py │ ├── nn.py │ ├── token_embedder.py │ ├── token_indexer.py │ ├── instance.py │ ├── iterator.py │ └── train.py │ └── elmo │ └── utils.py ├── data ├── share2013 │ ├── dev.list │ ├── script.sh │ ├── convert_text_inline.py │ ├── output.log │ ├── extract_ann.py │ ├── convert_ann_using_token_idx.py │ └── tokenization.py ├── sample │ └── dev.txt ├── share2014 │ ├── dev.list │ ├── script.sh │ └── extract_ann.py └── cadec │ ├── build_data_for_transition_discontinuous_ner.sh │ ├── split_train_test.py │ ├── extract_annotations.py │ ├── tokenization.py │ ├── split │ ├── test.id │ ├── dev.id │ └── train.id │ ├── convert_text_inline.py │ ├── convert_ann_using_token_idx.py │ ├── build_data_for_transition_discontinous_ner.log │ └── convert_flat_mentions.py └── README.md /issue-20201028.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dainlp/acl2020-transition-discontinuous-ner/HEAD/issue-20201028.PNG -------------------------------------------------------------------------------- /code/xdai/ner/transition_discontinuous/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_word_embeddings": "/data/dai031/Corpora/GloVe/glove.6B.100d.txt", 3 | "word_embedding_size": 100, 4 | "char_embedding_size": 16, 5 | "action_embedding_size": 20, 6 | "lstm_cell_size": 200, 7 | "lstm_layers": 2, 8 | "dropout": 0.5 9 | } -------------------------------------------------------------------------------- /data/share2013/dev.list: -------------------------------------------------------------------------------- 1 | 08380-043167-ECG_REPORT.txt 2 | 10612-047357-ECG_REPORT.txt 3 | 01234-029456-DISCHARGE_SUMMARY.txt 4 | 15013-102321-ECHO_REPORT.txt 5 | 07514-025655-DISCHARGE_SUMMARY.txt 6 | 06445-096221-ECHO_REPORT.txt 7 | 00414-104513-ECHO_REPORT.txt 8 | 08951-002958-DISCHARGE_SUMMARY.txt 9 | 07429-001857-DISCHARGE_SUMMARY.txt 10 | 17451-147855-RADIOLOGY_REPORT.txt 11 | 16093-011230-DISCHARGE_SUMMARY.txt 12 | 17336-021181-DISCHARGE_SUMMARY.txt 13 | 00098-016139-DISCHARGE_SUMMARY.txt 14 | 07761-036998-ECG_REPORT.txt 15 | 22566-151087-RADIOLOGY_REPORT.txt 16 | 19709-026760-DISCHARGE_SUMMARY.txt 17 | 07352-013977-DISCHARGE_SUMMARY.txt 18 | 25003-338492-RADIOLOGY_REPORT.txt 19 | 21273-244548-RADIOLOGY_REPORT.txt -------------------------------------------------------------------------------- /code/xdai/ner/transition_discontinuous/cadec.sh: -------------------------------------------------------------------------------- 1 | for seed in 52 869 1001 50542 353778 2 | do 3 | python train.py --output_dir /data/dai031/Experiments/TransitionDiscontinuous/cadec/$seed \ 4 | --train_filepath /data/dai031/Experiments/CADEC/adr/split/train.txt \ 5 | --dev_filepath /data/dai031/Experiments/CADEC/adr/split/dev.txt \ 6 | --test_filepath /data/dai031/Experiments/CADEC/adr/split/test.txt \ 7 | --log_filepath /data/dai031/Experiments/TransitionDiscontinuous/cadec/$seed/train.log \ 8 | --model_type elmo --pretrained_model_dir /data/dai031/Corpora/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B \ 9 | --weight_decay 0.0001 
--max_grad_norm 5 \ 10 | --learning_rate 0.001 --num_train_epochs 20 --patience 5 --eval_metric f1-overall \ 11 | --max_save_checkpoints 0 \ 12 | --cuda_device 0 --seed $seed 13 | done -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains a PyTorch implementation of the transition-based model for discontinuous NER introduced in our ACL 2020 paper: 2 | 3 | Xiang Dai, Sarvnaz Karimi, Ben Hachey, and Cecile Paris. 2020. An Effective Transition-based Model for Discontinuous NER. In ACL, Seattle, Washington. 4 | 5 | * The CADEC dataset can be downloaded at: https://data.csiro.au/dap/landingpage?pid=csiro:10948&v=3&d=true 6 | * The ShARe data can be downloaded at: https://physionet.org/ 7 | 8 | 9 | Once you have downloaded the dataset, you can use the script data/cadec/build_data_for_transition_discontinuous_ner.sh to build it. 10 | * Sample data can be found in the data/sample directory. 11 | 12 | 13 | The script used to train the model on CADEC can be found in code/xdai/ner/transition_discontinuous/cadec.sh. 14 | 15 | #### Update data/share2013/script.sh at 2021-July-26 16 | * Added details about how to extract and rename the downloaded files. 17 | -------------------------------------------------------------------------------- /data/sample/dev.txt: -------------------------------------------------------------------------------- 1 | Aches and pains , breathing difficulties , panic attacks , stength and stamina destroyed , palpatations , strong tingling in one hand , some tinititis and occassional ear aches . 2 | 0,2 ADR|4,5 ADR|7,8 ADR|10,13 ADR|15,15 ADR|17,21 ADR|24,24 ADR|27,28 ADR 3 | 4 | Have mentioned symptoms as they appeared to my doctor only to be prescribed zorloft for depression and anxiety and noten for heart palpatations . 5 | 13,13 Drug|15,15 ADR|17,17 ADR|19,19 Drug|21,22 ADR 6 | 7 | I had an extreme reaction - tunnel vision , vertigo , tingling and numbness ( hands , legs , face ) , complete weakness in the right side of my body including loss of strength , and what I term as complete body flushes ( like a hot flash without the hot ) . 8 | 6,7 ADR|9,9 ADR|11,11,15,15 ADR|13,13,15,15 ADR|11,11,17,17 ADR|11,11,19,19 ADR|13,13,17,17 ADR|13,13,19,19 ADR|22,30 ADR|32,34 ADR|41,52 ADR 9 | 10 | 40 mg dose Lipitor worked well in reducing my overall cholesterol levels ; which are high in my family .
11 | 3,3 Drug 12 | 13 | -------------------------------------------------------------------------------- /data/share2014/dev.list: -------------------------------------------------------------------------------- 1 | 09622-087101-ECG_REPORT.txt 2 | 23039-078076-ECG_REPORT.txt 3 | 08870-061373-ECG_REPORT.txt 4 | 23590-017830-DISCHARGE_SUMMARY.txt 5 | 23969-299900-RADIOLOGY_REPORT.txt 6 | 19791-003873-DISCHARGE_SUMMARY.txt 7 | 17336-021181-DISCHARGE_SUMMARY.txt 8 | 20223-103427-ECHO_REPORT.txt 9 | 07726-023607-DISCHARGE_SUMMARY.txt 10 | 14835-325902-RADIOLOGY_REPORT.txt 11 | 20442-023289-DISCHARGE_SUMMARY.txt 12 | 04269-027967-DISCHARGE_SUMMARY.txt 13 | 09569-067879-ECG_REPORT.txt 14 | 04303-005081-DISCHARGE_SUMMARY.txt 15 | 00211-027889-DISCHARGE_SUMMARY.txt 16 | 15751-026988-DISCHARGE_SUMMARY.txt 17 | 09963-257487-RADIOLOGY_REPORT.txt 18 | 02916-100844-ECHO_REPORT.txt 19 | 14108-340203-RADIOLOGY_REPORT.txt 20 | 22230-040122-ECG_REPORT.txt 21 | 26563-387055-RADIOLOGY_REPORT.txt 22 | 10689-110055-ECHO_REPORT.txt 23 | 13101-048474-ECG_REPORT.txt 24 | 15295-348292-RADIOLOGY_REPORT.txt 25 | 16055-152402-RADIOLOGY_REPORT.txt 26 | 07786-029701-ECG_REPORT.txt 27 | 04082-167766-RADIOLOGY_REPORT.txt 28 | 21115-101632-ECHO_REPORT.txt 29 | 01455-067052-ECG_REPORT.txt -------------------------------------------------------------------------------- /data/share2013/script.sh: -------------------------------------------------------------------------------- 1 | # Download shareclef-ehealth-2013-natural-language-processing-and-information-retrieval-for-clinical-care-1.0.zip from https://physionet.org/content/shareclefehealth2013/1.0/ and unzip it 2 | 3 | SHARE2013_DIR=/home/gdpr/Corpora/ShARe2013/shareclef-ehealth-2013-natural-language-processing-and-information-retrieval-for-clinical-care-1.0 4 | cd $SHARE2013_DIR 5 | mkdir train 6 | unzip Task1TrainSetCorpus199.zip 7 | mv ALLREPORTS train/text 8 | unzip Task1TrainSetGOLD199knowtatorehost.zip 9 | mv ALLSAVED train/ann 10 | 11 | mkdir test 12 | unzip Task1TestSetCorpus100.zip 13 | mv ALLREPORTS test/text 14 | tar -xf Task1Gold_SN2012.tar.bz2 15 | mv Gold_SN2012 test/ann 16 | 17 | # run the code 18 | cd /home/gdpr/Downloads/acl2020-transition-discontinuous-ner-master/data/share2013 19 | 20 | python extract_ann.py --input_dir=$SHARE2013_DIR/train/ann --text_dir=$SHARE2013_DIR/train/text --split=train 21 | python extract_ann.py --input_dir=$SHARE2013_DIR/test/ann --text_dir=$SHARE2013_DIR/test/text --split=test 22 | 23 | python tokenization.py --input_dir=$SHARE2013_DIR/train/text --split=train 24 | python tokenization.py --input_dir=$SHARE2013_DIR/test/text --split=test 25 | 26 | python convert_ann_using_token_idx.py 27 | 28 | mkdir processed_share2013 29 | python convert_text_inline.py --output_dir processed_share2013 30 | -------------------------------------------------------------------------------- /code/xdai/utils/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from xdai.utils.nn import masked_softmax 3 | 4 | 5 | class Attention(torch.nn.Module): 6 | def __init__(self, normalize: bool = True): 7 | super().__init__() 8 | self._normalize = normalize 9 | 10 | 11 | def forward(self, vector, matrix, matrix_mask=None): 12 | similarities = self._forward_internal(vector, matrix) 13 | if self._normalize: 14 | return masked_softmax(similarities, matrix_mask) 15 | else: 16 | return similarities 17 | 18 | 19 | def _forward_internal(self, vector, matrix): 20 | raise NotImplementedError 21 | 22 | 23 | 
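# --------------------------------------------------------------------------
# Usage sketch added for clarity (not part of the original file). It only
# illustrates the tensor shapes the Attention API above expects, using the
# BilinearAttention subclass defined just below; all sizes are made up.
def _attention_usage_example():
    import torch
    batch_size, seq_len, vector_dim, matrix_dim = 2, 5, 8, 8
    attention = BilinearAttention(vector_dim, matrix_dim)
    query = torch.randn(batch_size, vector_dim)           # a query vector per example
    keys = torch.randn(batch_size, seq_len, matrix_dim)   # a sequence of vectors to attend over
    mask = torch.ones(batch_size, seq_len)                # 1 for real tokens, 0 for padding
    weights = attention(query, keys, mask)                # (batch_size, seq_len) normalized weights
    context = weights.unsqueeze(1).bmm(keys).squeeze(1)   # (batch_size, matrix_dim) weighted sum
    return context
# --------------------------------------------------------------------------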
class BilinearAttention(Attention): 24 | '''The similarity between the vector x and the matrix y is: x^T W y + b, where W, b are parameters''' 25 | def __init__(self, vector_dim, matrix_dim): 26 | super().__init__() 27 | self._W = torch.nn.parameter.Parameter(torch.Tensor(vector_dim, matrix_dim)) 28 | self._b = torch.nn.parameter.Parameter(torch.Tensor(1)) 29 | self.reset_parameters() 30 | 31 | 32 | def reset_parameters(self): 33 | torch.nn.init.xavier_uniform_(self._W) 34 | self._b.data.fill_(0) 35 | 36 | 37 | def _forward_internal(self, vector, matrix): 38 | intermediate = vector.mm(self._W).unsqueeze(1) 39 | return intermediate.bmm(matrix.transpose(1, 2)).squeeze(1) + self._b -------------------------------------------------------------------------------- /data/cadec/build_data_for_transition_discontinuous_ner.sh: -------------------------------------------------------------------------------- 1 | mkdir /data/dai031/Experiments/CADEC/adr 2 | 3 | echo "Extract annotations ..." >> build_data_for_transition_discontinous_ner.log 4 | python extract_annotations.py --output_filepath /data/dai031/Experiments/CADEC/adr/ann --type_of_interest ADR --log_filepath build_data_for_transition_discontinous_ner.log 5 | 6 | echo "Tokenization ..." >> build_data_for_transition_discontinous_ner.log 7 | python tokenization.py --log_filepath build_data_for_transition_discontinous_ner.log 8 | 9 | echo "Convert annotations from character level offsets to token level idx ..." >> build_data_for_transition_discontinous_ner.log 10 | python convert_ann_using_token_idx.py --input_ann /data/dai031/Experiments/CADEC/adr/ann --output_ann /data/dai031/Experiments/CADEC/adr/tokens.ann --log_filepath build_data_for_transition_discontinous_ner.log 11 | 12 | echo "Create text inline format ..." >> build_data_for_transition_discontinous_ner.log 13 | python convert_text_inline.py --input_ann /data/dai031/Experiments/CADEC/adr/tokens.ann --output_filepath /data/dai031/Experiments/CADEC/adr/inline --log_filepath build_data_for_transition_discontinous_ner.log 14 | 15 | echo "Split the data set into train, dev, test splits ..." 
>> build_data_for_transition_discontinous_ner.log 16 | python split_train_test.py --input_filepath /data/dai031/Experiments/CADEC/adr/inline --output_dir /data/dai031/Experiments/CADEC/adr/split -------------------------------------------------------------------------------- /code/xdai/ner/convert_text_inline_to_conll.py: -------------------------------------------------------------------------------- 1 | '''Usage: 2 | python convert_text_inline_to_conll.py --input_filepath /data/dai031/Experiments/CADEC/adr/flat/train.txt --output_filepath /data/dai031/Experiments/CADEC/adr/flat/train.conll 3 | python convert_text_inline_to_conll.py --input_filepath /data/dai031/Experiments/CADEC/adr/flat/dev.txt --output_filepath /data/dai031/Experiments/CADEC/adr/flat/dev.conll 4 | python convert_text_inline_to_conll.py --input_filepath /data/dai031/Experiments/CADEC/adr/flat/test.txt --output_filepath /data/dai031/Experiments/CADEC/adr/flat/test.conll 5 | ''' 6 | import argparse, os, sys 7 | 8 | sys.path.insert(0, os.path.abspath("../..")) 9 | from xdai.ner.mention import mentions_to_bio_tags 10 | 11 | 12 | def parse_parameters(parser=None): 13 | if parser is None: parser = argparse.ArgumentParser() 14 | 15 | ## Required 16 | parser.add_argument("--input_filepath", type=str) 17 | parser.add_argument("--output_filepath", type=str) 18 | 19 | args, _ = parser.parse_known_args() 20 | return args 21 | 22 | 23 | if __name__ == "__main__": 24 | args = parse_parameters() 25 | with open(args.output_filepath, "w") as out_f: 26 | with open(args.input_filepath) as in_f: 27 | for text in in_f: 28 | tokens = text.strip().split() 29 | mentions = next(in_f).strip() 30 | assert len(next(in_f).strip()) == 0 31 | tags = mentions_to_bio_tags(mentions.strip(), len(tokens)) 32 | for token, tag in zip(tokens, tags): 33 | out_f.write("%s %s\n" % (token, tag)) 34 | out_f.write("\n") -------------------------------------------------------------------------------- /code/xdai/ner/convert_conll_to_text_inline.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys 2 | sys.path.insert(0, os.path.abspath("../..")) 3 | 4 | from xdai.ner.mention import bio_tags_to_mentions, bioes_to_bio 5 | 6 | 7 | def parse_parameters(parser=None): 8 | if parser is None: parser = argparse.ArgumentParser() 9 | 10 | ## Required 11 | parser.add_argument("--input_filepath", default="/data/dai031/Experiments/flair/conll2003/test.tsv", type=str) 12 | parser.add_argument("--output_filepath", default="/data/dai031/Experiments/flair/conll2003/test.txt", type=str) 13 | parser.add_argument("--pred_column_idx", default=-1, type=int) 14 | 15 | args, _ = parser.parse_known_args() 16 | return args 17 | 18 | 19 | if __name__ == "__main__": 20 | args = parse_parameters() 21 | 22 | sentences = [] 23 | with open(args.input_filepath) as f: 24 | tokens, tags = [], [] 25 | for line in f: 26 | sp = line.strip().split() 27 | if len(sp) < 2 or sp[0] == "-DOCSTART-": 28 | if len(tokens) > 0: 29 | sentences.append((tokens, bio_tags_to_mentions(bioes_to_bio(tags)))) 30 | tokens, tags = [], [] 31 | continue 32 | tokens.append(sp[0]) 33 | tags.append(sp[args.pred_column_idx]) 34 | if len(tokens) > 0: 35 | sentences.append((tokens, bio_tags_to_mentions(bioes_to_bio(tags)))) 36 | 37 | with open(args.output_filepath, "w") as f: 38 | for (tokens, mentions) in sentences: 39 | f.write("%s\n" % " ".join(tokens)) 40 | mentions = [str(m) for m in mentions] 41 | f.write("%s\n" % "|".join(mentions)) 42 | f.write("\n") 
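# --------------------------------------------------------------------------
# Format note (illustrative addition, not part of the original script). The
# text-inline files written above, and consumed by
# convert_text_inline_to_conll.py, store each sentence as three lines: the
# whitespace-tokenized text, a "|"-separated list of mentions, and a blank
# line. Each mention is "<indices> <label>", where the indices are
# comma-separated, inclusive token offsets grouped into (start, end) pairs;
# more than one pair marks a discontinuous mention. A minimal reader sketch:
def _read_text_inline(filepath):
    with open(filepath) as f:
        for text in f:
            tokens = text.strip().split()
            mention_line = next(f).strip()
            assert next(f).strip() == ""          # the blank separator line
            mentions = []
            for mention in filter(None, mention_line.split("|")):
                indices, label = mention.rsplit(" ", 1)
                offsets = [int(i) for i in indices.split(",")]
                spans = [(offsets[i], offsets[i + 1]) for i in range(0, len(offsets), 2)]
                mentions.append((spans, label))   # e.g. "0,3,6,6 ADR" -> ([(0, 3), (6, 6)], "ADR")
            yield tokens, mentions
# --------------------------------------------------------------------------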
-------------------------------------------------------------------------------- /data/share2014/script.sh: -------------------------------------------------------------------------------- 1 | # Download shareclef-ehealth-evaluation-lab-2014-task-2-disorder-attributes-in-clinical-reports-1.0.zip.zip from https://physionet.org/content/shareclefehealth2014task2/1.0/ and unzip it 2 | 3 | SHARE2014_DIR=/home/gdpr/Corpora/ShAReCLEF2014-t2/shareclef-ehealth-evaluation-lab-2014-task-2-disorder-attributes-in-clinical-reports-1.0 4 | cd $SHARE2014_DIR 5 | mkdir train 6 | unzip 2014ShAReCLEFeHealthTasks2_training_10Jan2014.zip 7 | mv 2014ShAReCLEFeHealthTasks2_training_10Jan2014/2014ShAReCLEFeHealthTask2_training_corpus train/text 8 | mv 2014ShAReCLEFeHealthTasks2_training_10Jan2014/2014ShAReCLEFeHealthTask2_training_pipedelimited train/ann 9 | 10 | mkdir test 11 | unzip ShAReCLEFeHealth2014Task2_test_default_values.zip 12 | mv ShAReCLEFeHealth2014Task2_test_default_values_with_corpus/ShAReCLEFeHealth2104Task2_test_data_corpus test/text 13 | unzip ShAReCLEFeHealth2014_test_data_gold.zip 14 | mv ShAReCLEFeHealth2014_test_data_gold test/ann 15 | 16 | # run the code 17 | cd /home/gdpr/Downloads/acl2020-transition-discontinuous-ner-master/data/share2014 18 | 19 | python extract_ann.py --ann_dir=$SHARE2014_DIR/train/ann --text_dir=$SHARE2014_DIR/train/text --split=train 20 | python extract_ann.py --ann_dir=$SHARE2014_DIR/test/ann --text_dir=$SHARE2014_DIR/test/text --split=test 21 | 22 | cp ../share2013/tokenization.py ./ 23 | python tokenization.py --input_dir=$SHARE2014_DIR/train/text --split=train 24 | python tokenization.py --input_dir=$SHARE2014_DIR/test/text --split=test 25 | 26 | cp ../share2013/convert_ann_using_token_idx.py ./ 27 | python convert_ann_using_token_idx.py 28 | 29 | mkdir processed_share2014 30 | cp ../share2013/convert_text_inline.py ./ 31 | python convert_text_inline.py --output_dir processed_share2014 32 | -------------------------------------------------------------------------------- /code/xdai/elmo/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py 5 | Update date: 2019-Nov-5''' 6 | def add_sentence_boundary_token_ids(tensor, mask, sentence_begin_token, sentence_end_token): 7 | sequence_lengths = mask.sum(dim=1).detach().cpu().numpy() 8 | input_shape = list(tensor.data.shape) 9 | output_shape = list(input_shape) 10 | output_shape[1] = input_shape[1] + 2 11 | tensor_with_boundary_tokens = tensor.new_zeros(*output_shape) 12 | assert len(input_shape) == 3 13 | tensor_with_boundary_tokens[:, 1:-1, :] = tensor 14 | for i, j in enumerate(sequence_lengths): 15 | tensor_with_boundary_tokens[i, 0, :] = sentence_begin_token 16 | tensor_with_boundary_tokens[i, j + 1, :] = sentence_end_token 17 | mask_with_boundary_tokens = ((tensor_with_boundary_tokens > 0).long().sum(dim=-1) > 0).long() 18 | 19 | return tensor_with_boundary_tokens, mask_with_boundary_tokens 20 | 21 | 22 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py 23 | Update date: 2019-Nov-5''' 24 | def remove_sentence_boundaries(tensor, mask): 25 | sequence_lengths = mask.sum(dim=1).detach().cpu().numpy() 26 | input_shape = list(tensor.data.shape) 27 | output_shape = list(input_shape) 28 | output_shape[1] = input_shape[1] - 2 29 | tensor_without_boundary_tokens = tensor.new_zeros(*output_shape) 30 | output_mask = 
tensor.new_zeros((output_shape[0], output_shape[1]), dtype=torch.long) 31 | for i, j in enumerate(sequence_lengths): 32 | if j > 2: 33 | tensor_without_boundary_tokens[i, :(j-2), :] = tensor[i, 1:(j-1), :] 34 | output_mask[i, :(j-2)] = 1 35 | return tensor_without_boundary_tokens, output_mask -------------------------------------------------------------------------------- /data/cadec/split_train_test.py: -------------------------------------------------------------------------------- 1 | '''Update date: 2020-Jan-13''' 2 | import argparse, os 3 | 4 | 5 | def parse_parameters(parser=None): 6 | if parser is None: parser = argparse.ArgumentParser() 7 | 8 | ## Required 9 | parser.add_argument("--input_filepath", default="/data/dai031/Experiments/CADEC/text-inline", type=str) 10 | parser.add_argument("--output_dir", default="/data/dai031/Experiments/CADEC/split", type=str) 11 | 12 | args, _ = parser.parse_known_args() 13 | return args 14 | 15 | 16 | def output_sentence(sentences, filepath, split): 17 | with open(filepath, "w") as f: 18 | for sentence in sentences: 19 | if sentence[0] in split: 20 | f.write("%s\n" % sentence[1]) 21 | f.write("%s\n" % sentence[2]) 22 | f.write("\n") 23 | 24 | 25 | if __name__ == "__main__": 26 | train_set = [l.strip() for l in open("split/train.id").readlines()] 27 | dev_set = [l.strip() for l in open("split/dev.id").readlines()] 28 | test_set = [l.strip() for l in open("split/test.id").readlines()] 29 | 30 | args = parse_parameters() 31 | if not os.path.exists(args.output_dir): 32 | os.mkdir(args.output_dir) 33 | 34 | sentences = [] 35 | with open(args.input_filepath) as f: 36 | for line in f: 37 | doc = line.strip().replace("Document: ", "") 38 | tokens = next(f).strip() 39 | mentions = next(f).strip() 40 | sentences.append((doc, tokens, mentions)) 41 | assert next(f).strip() == "" 42 | 43 | output_sentence(sentences, os.path.join(args.output_dir, "train.txt"), train_set) 44 | output_sentence(sentences, os.path.join(args.output_dir, "dev.txt"), dev_set) 45 | output_sentence(sentences, os.path.join(args.output_dir, "test.txt"), test_set) -------------------------------------------------------------------------------- /code/xdai/utils/seq2vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/modules/seq2vec_encoders/cnn_encoder.py 6 | Update date: 2019-Nov-5''' 7 | class CnnEncoder(torch.nn.Module): 8 | def __init__(self, input_dim=16, num_filters=128, ngram_filter_sizes=[3]): 9 | super(CnnEncoder, self).__init__() 10 | self._input_dim = input_dim 11 | self._convolution_layers = [torch.nn.Conv1d(in_channels=input_dim, 12 | out_channels=num_filters, 13 | kernel_size=ngram_size) 14 | for ngram_size in ngram_filter_sizes] 15 | for i, conv_layer in enumerate(self._convolution_layers): 16 | self.add_module("conv_layer_%d" % i, conv_layer) 17 | self._output_dim = num_filters * len(ngram_filter_sizes) 18 | 19 | def get_input_dim(self): 20 | return self._input_dim 21 | 22 | def get_output_dim(self): 23 | return self._output_dim 24 | 25 | def forward(self, inputs, mask=None): 26 | if mask is not None: 27 | inputs = inputs * mask.unsqueeze(-1).float() 28 | 29 | # The convolution layers expect input of shape (batch size, in_channels(input_dim), sequence lengths) 30 | inputs = torch.transpose(inputs, 1, 2) 31 | # Each convolutiona layer returns output of size (batch size, num of filters, pool length), 32 
| # where pool length = sequence lengths - ngram_size + 1 33 | filter_outputs = [] 34 | for i in range(len(self._convolution_layers)): 35 | convolution_layer = getattr(self, "conv_layer_{}".format(i)) 36 | filter_outputs.append(F.relu(convolution_layer(inputs)).max(dim=2)[0]) 37 | 38 | maxpool_output = torch.cat(filter_outputs, dim=1) if len(filter_outputs) > 1 else filter_outputs[0] 39 | 40 | return maxpool_output -------------------------------------------------------------------------------- /data/cadec/extract_annotations.py: -------------------------------------------------------------------------------- 1 | '''Update date: 2020-Jan-13''' 2 | import argparse, os, sys, re 3 | from typing import List 4 | 5 | def extract_indices_from_brat_annotation(indices: str) -> List[int]: 6 | indices = re.findall(r"\d+", indices) 7 | indices = sorted([int(i) for i in indices]) 8 | return indices 9 | 10 | 11 | def parse_parameters(parser=None): 12 | if parser is None: parser = argparse.ArgumentParser() 13 | 14 | ## Required 15 | parser.add_argument("--input_ann", default="/data/dai031/Corpora/CADEC/cadec/original", type=str) 16 | parser.add_argument("--input_text", default="/data/dai031/Corpora/CADEC/cadec/text", type=str) 17 | parser.add_argument("--output_filepath", default="/data/dai031/Experiments/CADEC/all/ann", type=str) 18 | parser.add_argument("--type_of_interest", default="") 19 | 20 | args, _ = parser.parse_known_args() 21 | return args 22 | 23 | 24 | def _get_mention_from_text(text, indices): 25 | tokens = [] 26 | for i in range(0, len(indices), 2): 27 | start = int(indices[i]) 28 | end = int(indices[i + 1]) 29 | tokens.append(text[start:end]) 30 | return " ".join(tokens) 31 | 32 | 33 | if __name__ == "__main__": 34 | args = parse_parameters() 35 | num_annotations = 0 36 | with open(args.output_filepath, "w") as out_f: 37 | for document in os.listdir(args.input_ann): 38 | with open(os.path.join(args.input_ann, document), "r") as in_f: 39 | document = document.replace(".ann", "") 40 | text = open(os.path.join(args.input_text, "%s.txt" % document)).read() 41 | for line in in_f: 42 | line = line.strip() 43 | if line[0] != "T": continue 44 | sp = line.strip().split("\t") 45 | assert len(sp) == 3 46 | mention = sp[2] 47 | sp = sp[1].split(" ") 48 | label = sp[0] 49 | if args.type_of_interest != "" and args.type_of_interest != label: continue 50 | indices = extract_indices_from_brat_annotation(" ".join(sp[1:])) 51 | mention_from_text = _get_mention_from_text(text, indices) 52 | if mention != mention_from_text: 53 | print("Update the mention from (%s) to (%s)." % (mention, mention_from_text)) 54 | mention = mention_from_text 55 | num_annotations += 1 56 | out_f.write("%s\t%s\t%s\t%s\n" % (document, label, ",".join([str(i) for i in indices]), mention)) 57 | 58 | print("Extract %d annotations." 
% num_annotations) -------------------------------------------------------------------------------- /data/share2014/extract_ann.py: -------------------------------------------------------------------------------- 1 | import argparse, logging, os, re, sys 2 | 3 | logger = logging.getLogger(__name__) 4 | 5 | 6 | def parse_parameters(parser=None): 7 | if parser is None: parser = argparse.ArgumentParser() 8 | 9 | parser.add_argument("--ann_dir", default=None, type=str) 10 | parser.add_argument("--text_dir", default=None, type=str) 11 | parser.add_argument("--split", default=None, type=str) 12 | parser.add_argument("--log_filepath", default="output.log", type=str) 13 | 14 | args, _ = parser.parse_known_args() 15 | return args 16 | 17 | 18 | if __name__ == "__main__": 19 | args = parse_parameters() 20 | handlers = [logging.FileHandler(filename=args.log_filepath), logging.StreamHandler(sys.stdout)] 21 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", 22 | level=logging.INFO, handlers=handlers) 23 | 24 | mentions = {} 25 | for filename in os.listdir(args.ann_dir): 26 | with open(os.path.join(args.text_dir, filename.replace(".pipe", ""))) as text_f: 27 | text = text_f.read() 28 | with open(os.path.join(args.ann_dir, filename)) as in_f: 29 | for line in in_f: 30 | sp = line.strip().split("|") 31 | assert sp[0] == filename.replace(".pipe", "") 32 | indices = re.findall(r"\d+", sp[1]) 33 | indices = sorted([int(i) for i in indices]) 34 | mention_text, gap_text = [], [] 35 | for i in range(0, len(indices), 2): 36 | mention_text.append( 37 | text[indices[i]:indices[i + 1]].replace("\n", " ").replace("\t", " ").strip()) 38 | for i in range(1, len(indices) - 2, 2): 39 | gap_text.append(text[indices[i]:indices[i + 1]].replace("\n", " ").replace("\t", " ").strip()) 40 | 41 | mention_text = " ".join(mention_text).strip() 42 | gap_text = " ".join(gap_text).strip() 43 | if len(indices) > 2 and len(gap_text) == 0: 44 | logger.info("%s in %s is not a real discontinuous entity, as the mention (%s) has no gap." % ( 45 | ",".join([str(i) for i in indices]), filename, mention_text)) 46 | indices = [indices[0], indices[-1]] 47 | str_indices = ",".join([str(i) for i in indices]) 48 | mentions[(filename.replace(".pipe", ""), str_indices)] = (mention_text, gap_text) 49 | 50 | with open("%s.ann" % args.split, "w") as out_f: 51 | for k, v in mentions.items(): 52 | out_f.write("%s\tDisorder\t%s\t%s\t%s\n" % (k[0], k[1], v[0], v[1])) -------------------------------------------------------------------------------- /data/cadec/tokenization.py: -------------------------------------------------------------------------------- 1 | '''Update date: 2020-Jan-13''' 2 | import argparse, os, sys, re 3 | 4 | class Token(object): 5 | def __init__(self, text=None, start=None, end=None, orig_text=None, text_id=None): 6 | self.text = text # might be normalized 7 | self.start = int(start) if start is not None else None # the character offset of this token into the tokenized sentence. 
8 | if end is not None: 9 | self.end = int(end) 10 | else: 11 | if self.text is not None and self.start is not None: 12 | self.end = self.start + len(self.text) 13 | else: 14 | self.end = None 15 | self.orig_text = orig_text 16 | self.text_id = text_id 17 | 18 | 19 | def __str__(self): 20 | return self.text 21 | 22 | 23 | def __repr__(self): 24 | return self.__str__() 25 | 26 | class LettersDigitsTokenizer: 27 | def tokenize(self, text): 28 | tokens = [Token(m.group(), start=m.start()) for m in re.finditer(r"[^\W\d_]+|\d+|\S", text)] 29 | return tokens 30 | 31 | 32 | def parse_parameters(parser=None): 33 | if parser is None: parser = argparse.ArgumentParser() 34 | 35 | ## Required 36 | parser.add_argument("--input_dir", default="/data/dai031/Corpora/CADEC/cadec/text", type=str) 37 | parser.add_argument("--output_filepath", default="/data/dai031/Experiments/CADEC/tokens", type=str) 38 | 39 | args, _ = parser.parse_known_args() 40 | return args 41 | 42 | 43 | if __name__ == "__main__": 44 | args = parse_parameters() 45 | 46 | tokenizer = LettersDigitsTokenizer() 47 | num_docs, num_sents, num_tokens = 0, 0, 0 48 | 49 | with open(args.output_filepath, "w") as out_f: 50 | for doc in os.listdir(args.input_dir): 51 | with open(os.path.join(args.input_dir, doc), "r") as in_f: 52 | doc = doc.replace(".txt", "") 53 | num_docs += 1 54 | text = in_f.read() 55 | line_start = 0 56 | for line in text.split("\n"): 57 | line = line.strip() 58 | if len(line) > 0: 59 | num_sents += 1 60 | line_start = text.find(line, line_start) 61 | assert line_start >= 0 62 | for token in tokenizer.tokenize(line): 63 | num_tokens += 1 64 | token_start = token.start + line_start 65 | token_end = token.end + line_start 66 | out_f.write("%s %s %d %d\n" % (token.text, doc, token_start, token_end)) 67 | out_f.write("\n") 68 | line_start += len(line) 69 | 70 | print("%d documents, %d sentences, %s tokens" % (num_docs, num_sents, num_tokens)) 71 | -------------------------------------------------------------------------------- /code/xdai/utils/seq2seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 3 | 4 | 5 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py 6 | inputs: 7 | tensor: batch first tensor 8 | outputs: 9 | sorted_sequence_lengths: sorted by decreasing size 10 | restoration_indices: sorted_tensor.index_select(0, restoration_indices) == original_tensor 11 | permutation_index: useful if want to sort other tensors using the same ordering 12 | Update date: 2019-April-26''' 13 | def sort_batch_by_length(tensor, sequence_lengths): 14 | assert isinstance(tensor, torch.Tensor) and isinstance(sequence_lengths, torch.Tensor) 15 | 16 | sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True) 17 | sorted_tensor = tensor.index_select(0, permutation_index) 18 | 19 | index_range = torch.arange(0, len(sequence_lengths), device=sequence_lengths.device) 20 | _, reverse_mapping = permutation_index.sort(0, descending=False) 21 | restoration_indices = index_range.index_select(0, reverse_mapping) 22 | 23 | return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index 24 | 25 | 26 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/modules/encoder_base.py 27 | Update date: 2019-03-03''' 28 | class EncoderBase(torch.nn.Module): 29 | def __init__(self): 30 | super(EncoderBase, self).__init__() 31 | 32 
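# (Explanatory note added for clarity; not part of the original file.)
# sort_and_run_forward (below) prepares a padded batch for the wrapped
# recurrent module: it computes true lengths from the mask, sorts the batch
# by decreasing length with sort_batch_by_length, packs it with
# pack_padded_sequence so the LSTM skips padding, runs the module, and
# returns restoration_indices so the caller (see LstmEncoder.forward) can
# restore the original row order via index_select(0, restoration_indices).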
| 33 | def sort_and_run_forward(self, module, inputs, mask): 34 | sequence_lengths = mask.long().sum(-1) 35 | sorted_inputs, sorted_sequence_lengths, restoration_indices, sorting_indices = sort_batch_by_length(inputs, 36 | sequence_lengths) 37 | 38 | packed_sequence_input = pack_padded_sequence(sorted_inputs[:, :, :], 39 | sorted_sequence_lengths[:].data.tolist(), 40 | batch_first=True) 41 | 42 | module_output, final_states = module(packed_sequence_input, None) 43 | return module_output, final_states, restoration_indices 44 | 45 | 46 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/modules/seq2seq_encoders/__init__.py 47 | Update date: 2019-03-03''' 48 | class LstmEncoder(EncoderBase): 49 | def __init__(self, input_size, hidden_size, num_layers, dropout=0.5, bidirectional=True): 50 | super(LstmEncoder, self).__init__() 51 | 52 | self._module = torch.nn.LSTM(batch_first=True, input_size=input_size, hidden_size=hidden_size, 53 | num_layers=num_layers, dropout=dropout, bidirectional=bidirectional) 54 | 55 | 56 | def forward(self, inputs, mask=None): 57 | packed_sequence_output, final_states, restoration_indices = \ 58 | self.sort_and_run_forward(self._module, inputs, mask) 59 | unpacked_sequence_tensor, _ = pad_packed_sequence(packed_sequence_output, batch_first=True) 60 | return unpacked_sequence_tensor.index_select(0, restoration_indices) -------------------------------------------------------------------------------- /data/cadec/split/test.id: -------------------------------------------------------------------------------- 1 | LIPITOR.759 2 | LIPITOR.553 3 | LIPITOR.625 4 | DICLOFENAC-SODIUM.3 5 | LIPITOR.355 6 | LIPITOR.28 7 | LIPITOR.49 8 | LIPITOR.320 9 | LIPITOR.840 10 | LIPITOR.15 11 | LIPITOR.804 12 | LIPITOR.777 13 | VOLTAREN.1 14 | LIPITOR.416 15 | LIPITOR.474 16 | ARTHROTEC.88 17 | LIPITOR.468 18 | ARTHROTEC.104 19 | LIPITOR.558 20 | LIPITOR.703 21 | LIPITOR.400 22 | LIPITOR.951 23 | LIPITOR.839 24 | LIPITOR.517 25 | LIPITOR.346 26 | LIPITOR.127 27 | ARTHROTEC.126 28 | VOLTAREN-XR.10 29 | LIPITOR.532 30 | LIPITOR.720 31 | LIPITOR.77 32 | LIPITOR.737 33 | LIPITOR.927 34 | LIPITOR.125 35 | LIPITOR.842 36 | ARTHROTEC.73 37 | LIPITOR.641 38 | LIPITOR.406 39 | VOLTAREN-XR.21 40 | LIPITOR.545 41 | LIPITOR.448 42 | LIPITOR.215 43 | LIPITOR.137 44 | ARTHROTEC.135 45 | LIPITOR.212 46 | LIPITOR.302 47 | LIPITOR.667 48 | LIPITOR.942 49 | LIPITOR.310 50 | LIPITOR.21 51 | LIPITOR.583 52 | LIPITOR.157 53 | LIPITOR.885 54 | PENNSAID.1 55 | LIPITOR.465 56 | LIPITOR.441 57 | LIPITOR.184 58 | LIPITOR.991 59 | LIPITOR.982 60 | LIPITOR.877 61 | VOLTAREN-XR.4 62 | LIPITOR.597 63 | LIPITOR.732 64 | VOLTAREN.6 65 | ARTHROTEC.43 66 | LIPITOR.150 67 | LIPITOR.353 68 | LIPITOR.736 69 | LIPITOR.926 70 | LIPITOR.767 71 | LIPITOR.256 72 | LIPITOR.126 73 | LIPITOR.869 74 | LIPITOR.257 75 | LIPITOR.467 76 | LIPITOR.276 77 | LIPITOR.234 78 | LIPITOR.496 79 | LIPITOR.686 80 | LIPITOR.91 81 | LIPITOR.245 82 | LIPITOR.181 83 | CAMBIA.3 84 | LIPITOR.185 85 | LIPITOR.10 86 | LIPITOR.567 87 | LIPITOR.586 88 | LIPITOR.557 89 | LIPITOR.300 90 | VOLTAREN.7 91 | LIPITOR.854 92 | VOLTAREN-XR.8 93 | LIPITOR.722 94 | LIPITOR.772 95 | VOLTAREN-XR.1 96 | ZIPSOR.5 97 | LIPITOR.916 98 | LIPITOR.221 99 | LIPITOR.315 100 | LIPITOR.514 101 | LIPITOR.613 102 | LIPITOR.301 103 | ARTHROTEC.123 104 | LIPITOR.934 105 | ARTHROTEC.28 106 | LIPITOR.915 107 | LIPITOR.7 108 | LIPITOR.121 109 | ARTHROTEC.102 110 | ARTHROTEC.122 111 | LIPITOR.158 112 | LIPITOR.46 113 | LIPITOR.792 114 | LIPITOR.609 115 | LIPITOR.22 
116 | LIPITOR.327 117 | ARTHROTEC.115 118 | LIPITOR.870 119 | LIPITOR.367 120 | ARTHROTEC.99 121 | LIPITOR.970 122 | LIPITOR.331 123 | LIPITOR.108 124 | LIPITOR.94 125 | ARTHROTEC.37 126 | LIPITOR.295 127 | LIPITOR.359 128 | LIPITOR.271 129 | VOLTAREN-XR.18 130 | LIPITOR.666 131 | LIPITOR.373 132 | ARTHROTEC.92 133 | LIPITOR.716 134 | LIPITOR.569 135 | VOLTAREN.8 136 | LIPITOR.790 137 | LIPITOR.862 138 | LIPITOR.431 139 | LIPITOR.231 140 | LIPITOR.306 141 | LIPITOR.699 142 | LIPITOR.622 143 | LIPITOR.461 144 | LIPITOR.159 145 | LIPITOR.540 146 | VOLTAREN.38 147 | LIPITOR.850 148 | LIPITOR.451 149 | LIPITOR.156 150 | ARTHROTEC.57 151 | LIPITOR.821 152 | LIPITOR.986 153 | LIPITOR.529 154 | ARTHROTEC.136 155 | LIPITOR.780 156 | LIPITOR.693 157 | LIPITOR.941 158 | LIPITOR.656 159 | ARTHROTEC.108 160 | LIPITOR.425 161 | LIPITOR.568 162 | LIPITOR.23 163 | ARTHROTEC.27 164 | LIPITOR.883 165 | LIPITOR.548 166 | LIPITOR.84 167 | LIPITOR.727 168 | LIPITOR.161 169 | LIPITOR.429 170 | LIPITOR.68 171 | LIPITOR.681 172 | LIPITOR.223 173 | LIPITOR.783 174 | LIPITOR.994 175 | LIPITOR.714 176 | LIPITOR.762 177 | LIPITOR.800 178 | LIPITOR.81 179 | LIPITOR.752 180 | LIPITOR.415 181 | LIPITOR.145 182 | LIPITOR.196 183 | LIPITOR.98 184 | ARTHROTEC.120 185 | LIPITOR.344 186 | LIPITOR.427 187 | LIPITOR.743 188 | LIPITOR.133 -------------------------------------------------------------------------------- /data/cadec/split/dev.id: -------------------------------------------------------------------------------- 1 | LIPITOR.124 2 | LIPITOR.174 3 | LIPITOR.188 4 | DICLOFENAC-SODIUM.1 5 | ARTHROTEC.83 6 | LIPITOR.644 7 | LIPITOR.368 8 | LIPITOR.828 9 | LIPITOR.329 10 | LIPITOR.593 11 | LIPITOR.765 12 | LIPITOR.308 13 | LIPITOR.755 14 | LIPITOR.805 15 | LIPITOR.890 16 | LIPITOR.741 17 | ARTHROTEC.67 18 | LIPITOR.612 19 | ARTHROTEC.58 20 | CATAFLAM.7 21 | LIPITOR.57 22 | LIPITOR.34 23 | LIPITOR.141 24 | LIPITOR.794 25 | LIPITOR.908 26 | LIPITOR.779 27 | LIPITOR.372 28 | ARTHROTEC.87 29 | LIPITOR.490 30 | LIPITOR.70 31 | LIPITOR.796 32 | ARTHROTEC.40 33 | LIPITOR.541 34 | VOLTAREN.34 35 | DICLOFENAC-SODIUM.6 36 | DICLOFENAC-SODIUM.4 37 | ARTHROTEC.26 38 | LIPITOR.192 39 | LIPITOR.325 40 | LIPITOR.992 41 | LIPITOR.808 42 | LIPITOR.512 43 | ARTHROTEC.31 44 | LIPITOR.33 45 | LIPITOR.930 46 | LIPITOR.740 47 | LIPITOR.912 48 | ARTHROTEC.22 49 | CATAFLAM.6 50 | LIPITOR.868 51 | LIPITOR.190 52 | LIPITOR.163 53 | LIPITOR.475 54 | LIPITOR.910 55 | LIPITOR.679 56 | LIPITOR.510 57 | LIPITOR.827 58 | LIPITOR.768 59 | LIPITOR.678 60 | LIPITOR.559 61 | LIPITOR.268 62 | VOLTAREN.20 63 | LIPITOR.819 64 | VOLTAREN-XR.6 65 | VOLTAREN-XR.2 66 | LIPITOR.489 67 | LIPITOR.396 68 | LIPITOR.324 69 | LIPITOR.426 70 | LIPITOR.526 71 | LIPITOR.949 72 | LIPITOR.651 73 | LIPITOR.338 74 | LIPITOR.103 75 | LIPITOR.264 76 | LIPITOR.995 77 | LIPITOR.265 78 | ARTHROTEC.48 79 | LIPITOR.65 80 | VOLTAREN.19 81 | ARTHROTEC.70 82 | LIPITOR.785 83 | LIPITOR.282 84 | LIPITOR.397 85 | LIPITOR.266 86 | LIPITOR.603 87 | LIPITOR.978 88 | LIPITOR.997 89 | LIPITOR.342 90 | ARTHROTEC.13 91 | LIPITOR.246 92 | VOLTAREN.36 93 | LIPITOR.634 94 | LIPITOR.835 95 | LIPITOR.653 96 | LIPITOR.307 97 | LIPITOR.939 98 | LIPITOR.506 99 | ARTHROTEC.32 100 | LIPITOR.725 101 | DICLOFENAC-SODIUM.2 102 | LIPITOR.370 103 | LIPITOR.217 104 | LIPITOR.602 105 | ARTHROTEC.63 106 | LIPITOR.242 107 | ARTHROTEC.142 108 | LIPITOR.834 109 | LIPITOR.459 110 | LIPITOR.538 111 | LIPITOR.672 112 | SOLARAZE.3 113 | CATAFLAM.1 114 | VOLTAREN.9 115 | LIPITOR.309 116 | LIPITOR.394 117 | LIPITOR.380 118 
| LIPITOR.754 119 | LIPITOR.959 120 | LIPITOR.277 121 | LIPITOR.66 122 | ARTHROTEC.90 123 | LIPITOR.760 124 | LIPITOR.47 125 | LIPITOR.222 126 | LIPITOR.495 127 | LIPITOR.390 128 | LIPITOR.530 129 | VOLTAREN.22 130 | LIPITOR.319 131 | ARTHROTEC.116 132 | LIPITOR.130 133 | ARTHROTEC.36 134 | LIPITOR.757 135 | LIPITOR.365 136 | LIPITOR.375 137 | LIPITOR.788 138 | LIPITOR.974 139 | LIPITOR.228 140 | LIPITOR.511 141 | LIPITOR.452 142 | LIPITOR.972 143 | LIPITOR.376 144 | LIPITOR.749 145 | VOLTAREN.11 146 | LIPITOR.476 147 | ARTHROTEC.109 148 | LIPITOR.556 149 | ARTHROTEC.46 150 | ARTHROTEC.121 151 | LIPITOR.455 152 | LIPITOR.482 153 | LIPITOR.882 154 | LIPITOR.177 155 | LIPITOR.654 156 | LIPITOR.549 157 | ARTHROTEC.52 158 | ARTHROTEC.138 159 | LIPITOR.379 160 | LIPITOR.831 161 | LIPITOR.528 162 | LIPITOR.53 163 | LIPITOR.109 164 | LIPITOR.480 165 | LIPITOR.119 166 | LIPITOR.230 167 | LIPITOR.585 168 | LIPITOR.961 169 | LIPITOR.110 170 | LIPITOR.48 171 | LIPITOR.507 172 | ARTHROTEC.95 173 | LIPITOR.570 174 | VOLTAREN.26 175 | LIPITOR.263 176 | ARTHROTEC.117 177 | LIPITOR.424 178 | LIPITOR.967 179 | LIPITOR.74 180 | LIPITOR.105 181 | LIPITOR.316 182 | LIPITOR.402 183 | LIPITOR.18 184 | ARTHROTEC.2 185 | LIPITOR.537 186 | LIPITOR.993 187 | LIPITOR.712 -------------------------------------------------------------------------------- /data/cadec/convert_text_inline.py: -------------------------------------------------------------------------------- 1 | '''Update date: 2020-Jan-13''' 2 | import argparse 3 | from collections import defaultdict 4 | 5 | 6 | def parse_parameters(parser=None): 7 | if parser is None: parser = argparse.ArgumentParser() 8 | 9 | ## Required 10 | parser.add_argument("--input_ann", default="/data/dai031/Experiments/CADEC/tokens.ann", type=str) 11 | parser.add_argument("--input_tokens", default="/data/dai031/Experiments/CADEC/tokens", type=str) 12 | parser.add_argument("--output_filepath", default="/data/dai031/Experiments/CADEC/text-inline", type=str) 13 | parser.add_argument("--no_doc_info", action="store_true") 14 | 15 | args, _ = parser.parse_known_args() 16 | return args 17 | 18 | 19 | def output_sentence(f, tokens, mentions, doc=None): 20 | def check_mention_text(tokens, mentions): 21 | for mention in mentions: 22 | tokenized_mention = [] 23 | indices = [int(i) for i in mention[0].split(",")] 24 | for i in range(0, len(indices), 2): 25 | start, end = indices[i], indices[i + 1] 26 | tokenized_mention += tokens[start:end + 1] 27 | 28 | if "".join(tokenized_mention) != mention[2].replace(" ", ""): 29 | print("%s (original) vs %s (tokenized)" % (mention[2], " ".join(tokenized_mention))) 30 | 31 | if doc is not None: 32 | f.write("Document: %s\n" % doc) 33 | f.write("%s\n" % (" ".join(tokens))) 34 | check_mention_text(tokens, mentions) 35 | mentions = ["%s %s" % (m[0], m[1]) for m in mentions] 36 | f.write("%s\n\n" % ("|".join(mentions))) 37 | 38 | 39 | def load_mentions(filepath): 40 | mentions = defaultdict(list) 41 | with open(filepath) as f: 42 | for line in f: 43 | sp = line.strip().split("\t") 44 | assert len(sp) == 5 45 | doc, sent_idx, label, indices, mention = sp 46 | mentions[(doc, int(sent_idx))].append((indices, label, mention)) 47 | return mentions 48 | 49 | 50 | if __name__ == "__main__": 51 | args = parse_parameters() 52 | mentions = load_mentions(args.input_ann) 53 | 54 | with open(args.output_filepath, "w") as out_f: 55 | with open(args.input_tokens) as in_f: 56 | pre_doc, sent_idx = None, 0 57 | tokens = [] 58 | for line in in_f: 59 | if len(line.strip()) == 
0: 60 | if len(tokens) > 0: 61 | assert pre_doc is not None 62 | output_sentence(out_f, tokens, mentions.get((pre_doc, sent_idx), []), 63 | None if args.no_doc_info else pre_doc) 64 | sent_idx += 1 65 | tokens = [] 66 | continue 67 | sp = line.strip().split() 68 | token, doc, _, _ = sp 69 | if pre_doc is None: 70 | pre_doc = doc 71 | if pre_doc != doc: 72 | pre_doc = doc 73 | sent_idx = 0 74 | assert len(tokens) == 0 75 | tokens.append(token) 76 | if len(tokens) > 0: 77 | assert pre_doc is not None 78 | output_sentence(out_f, tokens, mentions.get((pre_doc, sent_idx), []), 79 | None if args.no_doc_info else pre_doc) -------------------------------------------------------------------------------- /code/xdai/utils/args.py: -------------------------------------------------------------------------------- 1 | import argparse, logging 2 | 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | '''Update date: 2019-Nov-5''' 8 | def parse_parameters(parser=None): 9 | if parser is None: parser = argparse.ArgumentParser() 10 | 11 | ## Data 12 | parser.add_argument("--train_filepath", default=None, type=str) 13 | parser.add_argument("--num_train_instances", default=None, type=int) 14 | parser.add_argument("--dev_filepath", default=None, type=str) 15 | parser.add_argument("--num_dev_instances", default=None, type=int) 16 | parser.add_argument("--test_filepath", default=None, type=str) 17 | parser.add_argument("--cache_dir", default=None, type=str) 18 | parser.add_argument("--overwrite_cache", action="store_true") 19 | parser.add_argument("--output_dir", default=None, type=str) 20 | parser.add_argument("--overwrite_output_dir", action="store_true") 21 | parser.add_argument("--log_filepath", default=None, type=str) 22 | parser.add_argument("--summary_json", default=None, type=str) 23 | parser.add_argument("--encoding", default="utf-8-sig", type=str) 24 | 25 | ## Train 26 | parser.add_argument("--do_train", action="store_true") 27 | parser.add_argument("--learning_rate", default=5e-5, type=float) 28 | parser.add_argument("--train_batch_size_per_gpu", default=8, type=int) 29 | parser.add_argument("--num_train_epochs", default=3, type=int) 30 | parser.add_argument("--max_steps", default=0, type=int, help="If > 0, override num_train_epochs.") 31 | parser.add_argument("--warmup_steps", default=0, type=int) 32 | parser.add_argument("--logging_steps", default=50, type=int) 33 | parser.add_argument("--save_steps", default=0, type=int) 34 | parser.add_argument("--max_save_checkpoints", default=2, type=int) 35 | parser.add_argument("--patience", default=0, type=int) 36 | parser.add_argument("--max_grad_norm", default=None, type=float) 37 | parser.add_argument("--grad_clipping", default=5, type=float) 38 | parser.add_argument("--adam_epsilon", default=1e-8, type=float) 39 | parser.add_argument("--weight_decay", default=None, type=float) 40 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int, 41 | help="Number of update steps to accumulate before performing a backward/update pass") 42 | parser.add_argument("--seed", default=52, type=int) 43 | parser.add_argument("--cuda_device", default="0", type=str, help="a list cuda devices, splitted by ,") 44 | 45 | ## Evaluation 46 | parser.add_argument("--do_eval", action="store_true") 47 | parser.add_argument("--eval_batch_size_per_gpu", default=8, type=int) 48 | parser.add_argument("--eval_metric", default=None, type=str) 49 | parser.add_argument("--eval_all_checkpoints", action="store_true", help="Evaluate all checkpoints.") 50 | 
parser.add_argument("--eval_during_training", action="store_true", 51 | help="Evaluate during training at each save step.") 52 | 53 | ## Model 54 | parser.add_argument("--model_type", default=None, type=str) 55 | parser.add_argument("--pretrained_model_dir", default=None, type=str) 56 | parser.add_argument("--max_seq_length", default=128, type=int) 57 | parser.add_argument("--labels", default="0,1", type=str) 58 | parser.add_argument("--label_filepath", default=None, type=str) 59 | parser.add_argument("--tag_schema", default="B,I") 60 | parser.add_argument("--do_lower_case", action="store_true") 61 | 62 | args, _ = parser.parse_known_args() 63 | 64 | return args -------------------------------------------------------------------------------- /code/xdai/utils/token.py: -------------------------------------------------------------------------------- 1 | import re, spacy 2 | from typing import List, NamedTuple 3 | from xdai.utils.common import load_spacy_model 4 | 5 | 6 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/tokenizers/token.py 7 | Update date: 2019-Nov-5''' 8 | class Token(NamedTuple): 9 | text: str = None 10 | start: int = None # the character offset of this token into the tokenized sentence. 11 | 12 | 13 | @property 14 | def end(self): 15 | return self.start + len(self.text) 16 | 17 | 18 | def __str__(self): 19 | return self.text 20 | 21 | 22 | def __repr__(self): 23 | return self.__str__() 24 | 25 | 26 | '''Update date: 2019-Nov-26''' 27 | def preprocess_twitter(text): 28 | tokens = [] 29 | for token in text.strip().split(): 30 | if len(token) > 1 and (token[0] == "@" or token[0] == "#"): 31 | tokens.append(token[0]) 32 | tokens.append(token[1:]) 33 | else: 34 | tokens.append(token) 35 | return " ".join(tokens) 36 | 37 | 38 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/tokenizers/letters_digits_tokenizer.py 39 | Update date: 2019-Nov-25''' 40 | class LettersDigitsTokenizer: 41 | def tokenize(self, text): 42 | tokens = [Token(m.group(), start=m.start()) for m in re.finditer(r"[^\W\d_]+|\d+|\S", text)] 43 | return tokens 44 | 45 | 46 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/tokenizers/whitespace_tokenizer.py 47 | Update date: 2019-Nov-25''' 48 | class WhitespaceTokenizer: 49 | def tokenize(self, text): 50 | return [Token(t) for t in text.split()] 51 | 52 | 53 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/tokenizers/spacy_tokenizer.py#_remove_spaces 54 | Update date: 2019-Nov-25''' 55 | def _remove_spaces(tokens: List[spacy.tokens.Token]): 56 | return [t for t in tokens if not t.is_space] 57 | 58 | 59 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/tokenizers/spacy_tokenizer.py 60 | Update date: 2019-Nov-25''' 61 | class SpacyTokenizer: 62 | def __init__(self, language="en_core_web_sm"): 63 | self.spacy = load_spacy_model(language) 64 | 65 | 66 | def _sanitize(self, tokens): 67 | return [Token(t.text, t.idx) for t in tokens] 68 | 69 | 70 | def batch_tokenize(self, texts: List[str]): 71 | return [self._sanitize(_remove_spaces(tokens)) for tokens in self.spacy.pipe(texts, n_threads=-1)] 72 | 73 | 74 | def tokenize(self, text): 75 | return self._sanitize(_remove_spaces(self.spacy(text))) 76 | 77 | 78 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/tokenizers/sentence_splitter.py 79 | Update date: 2019-Nov-25''' 80 | class SpacySentenceSplitter: 81 | def __init__(self, 
language="en_core_web_sm", rule_based=False): 82 | self.spacy = load_spacy_model(language, parse=not rule_based) 83 | if rule_based: 84 | sbd_name = "sbd" if spacy.__version__ < "2.1" else "sentencizer" 85 | if not self.spacy.has_pipe(sbd_name): 86 | sbd = self.spacy.create_pipe(sbd_name) 87 | self.spacy.add_pipe(sbd) 88 | 89 | 90 | def split_sentences(self, text: str) -> List[str]: 91 | return [sent.string.strip() for sent in self.spacy(text).sents] 92 | 93 | 94 | def batch_split_sentences(self, texts: List[str]) -> List[List[str]]: 95 | return [[sent.string.strip() for sent in doc.sents] for doc in self.spacy.pipe(texts)] -------------------------------------------------------------------------------- /data/share2013/convert_text_inline.py: -------------------------------------------------------------------------------- 1 | import argparse, logging, os, sys 2 | from collections import defaultdict 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | def parse_parameters(parser=None): 8 | if parser is None: parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--output_dir", default=None, type=str) 11 | parser.add_argument("--log_filepath", default="output.log", type=str) 12 | 13 | args, _ = parser.parse_known_args() 14 | return args 15 | 16 | 17 | def _output_sentence(out_f, tokens, anns, document, sentence): 18 | out_f.write("%s\n" % (" ".join(tokens))) 19 | mentions = anns.get((document, sentence), []) 20 | mentions = ["%s Disorder" % (mention) for mention in mentions] 21 | out_f.write("%s\n\n" % ("|".join(mentions))) 22 | 23 | 24 | def read_data(ann_filepath, tokens_filepath): 25 | anns = defaultdict(list) 26 | with open(ann_filepath) as f: 27 | for line in f: 28 | sp = line.strip().split("\t") 29 | assert len(sp) == 4 or len(sp) == 5 30 | document, sentence_idx, _, indices = sp[0:4] 31 | anns[(document, int(sentence_idx))].append((indices)) 32 | 33 | with open(tokens_filepath) as f: 34 | sentences = [] 35 | pre_doc, sentence_idx = None, 0 36 | tokens = [] 37 | for line in f: 38 | if len(line.strip()) == 0: 39 | if len(tokens) > 0: 40 | assert pre_doc is not None 41 | sentences.append((pre_doc, sentence_idx, tokens)) 42 | sentence_idx += 1 43 | tokens = [] 44 | continue 45 | sp = line.strip().split() 46 | token, cur_doc, _, _ = sp 47 | if pre_doc is None: 48 | pre_doc = cur_doc 49 | if pre_doc != cur_doc: 50 | pre_doc = cur_doc 51 | sentence_idx = 0 52 | assert len(tokens) == 0 53 | tokens.append(token) 54 | if len(tokens) > 0: 55 | assert pre_doc is not None 56 | sentences.append((pre_doc, sentence_idx, tokens)) 57 | 58 | return anns, sentences 59 | 60 | 61 | if __name__ == "__main__": 62 | args = parse_parameters() 63 | handlers = [logging.FileHandler(filename=args.log_filepath), logging.StreamHandler(sys.stdout)] 64 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", 65 | level=logging.INFO, handlers=handlers) 66 | 67 | train_ann, train_sentences = read_data("train.token.ann", "train.tokens") 68 | test_ann, test_sentences = read_data("test.token.ann", "test.tokens") 69 | 70 | dev_list = [n.strip() for n in open("dev.list").readlines()] 71 | 72 | with open(os.path.join(args.output_dir, "train.txt"), "w") as f: 73 | for sentence in train_sentences: 74 | document, sentence_idx, tokens = sentence 75 | if document not in dev_list: 76 | _output_sentence(f, tokens, train_ann, document, sentence_idx) 77 | 78 | with open(os.path.join(args.output_dir, "dev.txt"), "w") as f: 79 | for sentence in train_sentences: 80 | 
document, sentence_idx, tokens = sentence 81 | if document in dev_list: 82 | _output_sentence(f, tokens, train_ann, document, sentence_idx) 83 | 84 | with open(os.path.join(args.output_dir, "test.txt"), "w") as f: 85 | for sentence in test_sentences: 86 | document, sentence_idx, tokens = sentence 87 | _output_sentence(f, tokens, test_ann, document, sentence_idx) -------------------------------------------------------------------------------- /data/share2013/output.log: -------------------------------------------------------------------------------- 1 | 11/07/2020 20:27:36 - INFO - __main__ - 1469,1472,1473,1475 in 22821-026994-DISCHARGE_SUMMARY.txt.knowtator.xml is not a real discontinuous entity, as the mention (ABD NT) has no gap. 2 | 11/07/2020 20:27:36 - INFO - __main__ - 1211,1216,1217,1221 in 04303-005081-DISCHARGE_SUMMARY.txt.knowtator.xml is not a real discontinuous entity, as the mention (chest pain) has no gap. 3 | 11/07/2020 20:27:36 - INFO - __main__ - 2453,2460,2466,2474 in 02652-006395-DISCHARGE_SUMMARY.txt.knowtator.xml is not a real discontinuous entity, as the mention (ABSCESS RT FLANK) has no gap. 4 | 11/07/2020 20:27:36 - INFO - __main__ - 2453,2460,2466,2474 in 02652-006395-DISCHARGE_SUMMARY.txt.knowtator.xml is not a real discontinuous entity, as the mention (ABSCESS RT FLANK) has no gap. 5 | 11/07/2020 20:27:45 - INFO - __main__ - 07797-005646-DISCHARGE_SUMMARY.txt||Disease_Disorder||C0456867||554||577||580||588 is not a real discontinuous entity, as the mention (high grade large B-cell lymphoma) has no gap. 6 | 11/07/2020 20:33:34 - INFO - __main__ - 9758 sentences and 142159 tokens in /data/dai031/Corpora/ShAReCLEF2013/train/text 7 | 11/07/2020 20:33:42 - INFO - __main__ - 9009 sentences and 136783 tokens in /data/dai031/Corpora/ShAReCLEF2013/test/text 8 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 00098-016139-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 9 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 04303-005081-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 10 | 11/07/2020 20:37:47 - INFO - __main__ - Find token whose original end offset is 5869 by adjusting its offset 1. 11 | 11/07/2020 20:37:47 - INFO - __main__ - Find token whose original end offset is 2475 by adjusting its offset 1. 12 | 11/07/2020 20:37:47 - INFO - __main__ - Find token whose original end offset is 178 by adjusting its offset -1. 13 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 07429-001857-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 14 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 02410-026171-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 15 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 01234-029456-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 16 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 21833-003461-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 17 | 11/07/2020 20:37:47 - INFO - __main__ - Find token whose original end offset is 605 by adjusting its offset 1. 18 | 11/07/2020 20:37:47 - INFO - __main__ - Find token whose original start offset is 886 by adjusting its offset 1. 19 | 11/07/2020 20:37:47 - INFO - __main__ - Convert 5815 out of 5821 annotations. 
20 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 19778-001791-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 21 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 19778-001791-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 22 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 08990-002227-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 23 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 08990-002227-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 24 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 07797-005646-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 25 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 09339-028983-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 26 | 11/07/2020 20:37:47 - INFO - __main__ - This annotation from document 09339-028983-DISCHARGE_SUMMARY.txt is abandoned because it crossing multiple sentences. 27 | 11/07/2020 20:37:47 - INFO - __main__ - Convert 5333 out of 5340 annotations. 28 | -------------------------------------------------------------------------------- /code/xdai/ner/transition_discontinuous/test_parsing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | import logging, os, sys 3 | sys.path.insert(0, os.path.abspath("../../..")) 4 | 5 | from xdai.ner.transition_discontinuous.parsing import Parser 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def _sanity_check_actions(): 12 | sentences, actions = [], [] 13 | sentences.append("all joints , muscles and tendons hurt all over my body") 14 | actions.append("OUT SHIFT OUT SHIFT OUT SHIFT SHIFT RIGHT-REDUCE COMPLETE-Finding RIGHT-REDUCE COMPLETE-Finding RIGHT-REDUCE COMPLETE-Finding OUT OUT OUT OUT") 15 | # "1,1,6,6 Finding|3,3,6,6 Finding|5,6 Finding" 16 | 17 | sentences.append("Very severe pain in arms , knees , hands .") 18 | actions.append("SHIFT SHIFT REDUCE SHIFT REDUCE SHIFT REDUCE SHIFT LEFT-REDUCE COMPLETE-Finding OUT SHIFT LEFT-REDUCE COMPLETE-Finding OUT SHIFT REDUCE COMPLETE-Finding OUT") 19 | # "0,3,6,6 ADR|0,3,8,8 ADR|0,4 ADR" 20 | 21 | parse = Parser() 22 | for s, a in zip(sentences, actions): 23 | mentions = parse.parse(a.split(), len(s.split())) 24 | logger.info("|".join([str(m) for m in mentions])) 25 | 26 | 27 | def _sanity_check_mention2actions(): 28 | sentences, mentions = [], [] 29 | sentences.append("could hardly walk or lift my arms") 30 | mentions.append("1,1,4,6 ADR|1,2 ADR") 31 | # ['OUT', 'SHIFT', 'SHIFT', 'LEFT-REDUCE', 'COMPLETE-ADR', 'OUT', 'SHIFT', 'REDUCE', 'SHIFT', 'REDUCE', 'SHIFT', 'REDUCE', 'COMPLETE-ADR'] 32 | 33 | 34 | sentences.append("all joints , muscles and tendons hurt all over my body") 35 | mentions.append("1,1,6,6 Finding|3,3,6,6 Finding|5,6 Finding") 36 | # "OUT SHIFT OUT SHIFT OUT SHIFT SHIFT RIGHT-REDUCE COMPLETE-Finding RIGHT-REDUCE COMPLETE-Finding RIGHT-REDUCE COMPLETE-Finding OUT OUT OUT OUT" 37 | 38 | sentences.append("Very severe pain in arms , knees , hands .") 39 | mentions.append("0,3,6,6 ADR|0,3,8,8 ADR|0,4 ADR") 40 | # "SHIFT SHIFT REDUCE SHIFT REDUCE SHIFT REDUCE SHIFT LEFT-REDUCE COMPLETE-Finding OUT SHIFT LEFT-REDUCE COMPLETE-Finding OUT SHIFT REDUCE COMPLETE-Finding OUT" 41 | 42 | parse = Parser() 43 | for 
i, s in enumerate(sentences): 44 | actions = parse.mention2actions(mentions[i], len(s.split())) 45 | logger.info(actions) 46 | 47 | 48 | def _sanity_check_instance(sentence, mentions, verbose=False): 49 | sentence_length = len(sentence.split()) 50 | 51 | parse = Parser() 52 | 53 | actions = parse.mention2actions(mentions, sentence_length) 54 | preds = parse.parse(actions, sentence_length) 55 | golds = mentions.split("|") 56 | preds = [str(m) for m in preds] 57 | 58 | FP, FN = 0, 0 59 | for pred in preds: 60 | if pred not in golds: 61 | FP += 1 62 | for gold in golds: 63 | if gold not in preds: 64 | FN += 1 65 | 66 | if verbose and (not (FP == 0 and FN == 0)): 67 | logger.info(sentence) 68 | logger.info(golds) 69 | logger.info(preds) 70 | 71 | return FP == 0 and FN == 0 72 | 73 | 74 | def check_dataset(data_dir): 75 | for split in ["train", "dev", "test"]: 76 | total_sentences, error_sentences = 0, 0 77 | if not os.path.exists(os.path.join(data_dir, "%s.txt" % split)): continue 78 | with open(os.path.join(data_dir, "%s.txt" % split)) as f: 79 | for line in f: 80 | sentence = line.strip() 81 | if len(sentence) == 0: continue 82 | mentions = next(f).strip() 83 | if len(mentions) > 0: 84 | correct = _sanity_check_instance(sentence, mentions, verbose=True) 85 | total_sentences += 1 86 | if not correct: 87 | error_sentences += 1 88 | assert next(f).strip() == "" 89 | logger.info("%d errors out of %d sentences" % (error_sentences, total_sentences)) 90 | 91 | 92 | if __name__ == "__main__": 93 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", 94 | level=logging.INFO, filename=None) 95 | 96 | # _sanity_check_actions() 97 | #_sanity_check_mention2actions() 98 | #check_dataset("/data/dai031/Experiments/ShARe2013") 99 | check_dataset("/data/dai031/Experiments/CADEC/adr/split") -------------------------------------------------------------------------------- /code/xdai/utils/common.py: -------------------------------------------------------------------------------- 1 | import json, logging, os, random, shutil, spacy, torch 2 | import numpy as np 3 | from typing import Any, Dict 4 | 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | '''Update date: 2019-Nov-3''' 10 | def create_output_dir(args): 11 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 12 | if args.overwrite_output_dir: 13 | shutil.rmtree(args.output_dir) 14 | else: 15 | raise ValueError("Output directory (%s) already exists." 
% args.output_dir) 16 | os.makedirs(args.output_dir, exist_ok=True) 17 | 18 | 19 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/common/util.py#dump_metrics 20 | Update date: 2019-Nov-3''' 21 | def dump_metrics(file_path: str, metrics: Dict[str, Any], log: bool = False) -> None: 22 | metrics_json = json.dumps(metrics, indent=2) 23 | with open(file_path, "w") as metrics_file: 24 | metrics_file.write(metrics_json) 25 | if log: 26 | logger.info("Metrics: %s", metrics_json) 27 | 28 | 29 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py 30 | Update date: 2019-April-26''' 31 | def has_tensor(obj) -> bool: 32 | if isinstance(obj, torch.Tensor): 33 | return True 34 | if isinstance(obj, dict): 35 | return any(has_tensor(v) for v in obj.values()) 36 | if isinstance(obj, (list, tuple)): 37 | return any(has_tensor(i) for i in obj) 38 | return False 39 | 40 | 41 | LOADED_SPACY_MODELS = {} 42 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/common/util.py#get_spacy_model 43 | Update date: 2019-Nov-25''' 44 | def load_spacy_model(spacy_model_name, parse=False): 45 | options = (spacy_model_name, parse) 46 | if options not in LOADED_SPACY_MODELS: 47 | disable = ["vectors", "textcat", "tagger", "ner"] 48 | if not parse: 49 | disable.append("parser") 50 | spacy_model = spacy.load(spacy_model_name, disable=disable) 51 | LOADED_SPACY_MODELS[options] = spacy_model 52 | return LOADED_SPACY_MODELS[options] 53 | 54 | 55 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/training/metrics/metric.py 56 | Update date: 2019-03-01''' 57 | def move_to_cpu(*tensors): 58 | return (x.detach().cpu() if isinstance(x, torch.Tensor) else x for x in tensors) 59 | 60 | 61 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py (move_to_device) 62 | Update date: 2019-April-26''' 63 | def move_to_gpu(obj, cuda_device=0): 64 | if cuda_device < 0 or not has_tensor(obj): return obj 65 | if isinstance(obj, torch.Tensor): return obj.cuda(cuda_device) 66 | if isinstance(obj, dict): 67 | return {k: move_to_gpu(v, cuda_device) for k, v in obj.items()} 68 | if isinstance(obj, list): 69 | return [move_to_gpu(v, cuda_device) for v in obj] 70 | if isinstance(obj, tuple): 71 | return tuple([move_to_gpu(v, cuda_device) for v in obj]) 72 | return obj 73 | 74 | 75 | '''Update date: 2019-Nov-3''' 76 | def pad_sequence_to_length(sequence, desired_length, default_value=lambda: 0): 77 | padded_sequence = sequence[:desired_length] 78 | for _ in range(desired_length - len(padded_sequence)): 79 | padded_sequence.append(default_value()) 80 | return padded_sequence 81 | 82 | 83 | '''Update date: 2019-Nov-4''' 84 | def set_cuda(args): 85 | cuda_device = [int(i) for i in args.cuda_device.split(",")] 86 | args.cuda_device = [i for i in cuda_device if i >= 0] 87 | args.n_gpu = len(args.cuda_device) 88 | logger.info("Device: %s, n_gpu: %s" % (args.cuda_device, args.n_gpu)) 89 | 90 | 91 | '''Update date: 2019-Nov-3''' 92 | def set_random_seed(args): 93 | if args.seed <= 0: 94 | logger.info("Does not set the random seed, since the value is %s" % args.seed) 95 | return 96 | random.seed(args.seed) 97 | np.random.seed(int(args.seed / 2)) 98 | torch.manual_seed(int(args.seed / 4)) 99 | torch.cuda.manual_seed_all(int(args.seed / 8)) 100 | 101 | 102 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py#sort_batch_by_length 103 | Update date: 2019-Nov-5''' 104 | def 
sort_batch_by_length(tensor, sequence_lengths): 105 | '''restoration_indices: sorted_tensor.index_select(0, restoration_indices) == original_tensor''' 106 | assert isinstance(tensor, torch.Tensor) and isinstance(sequence_lengths, torch.Tensor) 107 | 108 | sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True) 109 | sorted_tensor = tensor.index_select(0, permutation_index) 110 | 111 | index_range = torch.arange(0, len(sequence_lengths), device=sequence_lengths.device) 112 | _, reverse_mapping = permutation_index.sort(0, descending=False) 113 | restoration_indices = index_range.index_select(0, reverse_mapping) 114 | 115 | return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index -------------------------------------------------------------------------------- /data/share2013/extract_ann.py: -------------------------------------------------------------------------------- 1 | import argparse, logging, os, sys 2 | import xml.etree.ElementTree as ET 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | def parse_parameters(parser=None): 8 | if parser is None: parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--input_dir", default=None, type=str) 11 | parser.add_argument("--text_dir", default=None, type=str) 12 | parser.add_argument("--split", default=None, type=str) 13 | parser.add_argument("--log_filepath", default="output.log", type=str) 14 | 15 | args, _ = parser.parse_known_args() 16 | return args 17 | 18 | 19 | def process_test(input_dir, text_dir, output_filepath): 20 | mentions = {} 21 | 22 | for filename in os.listdir(input_dir): 23 | with open(os.path.join(text_dir, filename)) as text_f: 24 | text = text_f.read() 25 | with open(os.path.join(input_dir, filename)) as in_f: 26 | for line in in_f: 27 | sp = line.strip().split("||") 28 | assert sp[1] == "Disease_Disorder" and sp[0] == filename 29 | indices = [i for i in sp[3:]] 30 | indices = sorted([int(i) for i in indices]) 31 | mention_text, gap_text = [], [] 32 | for i in range(0, len(indices), 2): 33 | mention_text.append(text[indices[i]:indices[i + 1]].replace("\n", " ").replace("\t", " ").strip()) 34 | for i in range(1, len(indices) - 2, 2): 35 | gap_text.append(text[indices[i]:indices[i + 1]].replace("\n", " ").replace("\t", " ").strip()) 36 | mention_text = " ".join(mention_text).strip() 37 | gap_text = " ".join(gap_text).strip() 38 | if len(indices) > 2 and len(gap_text) == 0: 39 | logger.info("%s is not a real discontinuous entity, as the mention (%s) has no gap." 
% ( 40 | line.strip(), mention_text)) 41 | indices = [indices[0], indices[-1]] 42 | str_indices = ",".join([str(i) for i in indices]) 43 | mentions[(filename, str_indices)] = (mention_text, gap_text) 44 | 45 | with open(output_filepath, "w") as f: 46 | for k, v in mentions.items(): 47 | f.write("%s\tDisorder\t%s\t%s\t%s\n" % (k[0], k[1], v[0], v[1])) 48 | 49 | 50 | def process_train(input_dir, text_dir, output_filepath): 51 | mentions = {} 52 | 53 | for filename in os.listdir(input_dir): 54 | with open(os.path.join(text_dir, filename.replace(".knowtator.xml", "")), "r") as text_f: 55 | text = text_f.read() 56 | root = ET.parse(os.path.join(input_dir, filename)).getroot() 57 | document = root.get("textSource") 58 | 59 | assert document.find("DISCHARGE_SUMMARY") > 0 or document.find("ECHO_REPORT") > 0 or document.find( 60 | "RADIOLOGY_REPORT") > 0 or document.find("ECG_REPORT") > 0 61 | 62 | for mention in root.findall("annotation"): 63 | indices = [] 64 | for span in mention.findall("span"): 65 | indices.append(span.get("start")) 66 | indices.append(span.get("end")) 67 | indices = sorted([int(i) for i in indices]) 68 | mention_text, gap_text = [], [] 69 | for i in range(0, len(indices), 2): 70 | mention_text.append(text[indices[i]:indices[i + 1]].replace("\n", " ").replace("\t", " ").strip()) 71 | for i in range(1, len(indices) - 2, 2): 72 | gap_text.append(text[indices[i]:indices[i + 1]].replace("\n", " ").replace("\t", " ").strip()) 73 | 74 | mention_text = " ".join(mention_text).strip() 75 | gap_text = " ".join(gap_text).strip() 76 | if len(indices) > 2 and len(gap_text) == 0: 77 | logger.info("%s in %s is not a real discontinuous entity, as the mention (%s) has no gap." % ( 78 | ",".join([str(i) for i in indices]), filename, mention_text)) 79 | indices = [indices[0], indices[-1]] 80 | str_indices = ",".join([str(i) for i in indices]) 81 | mentions[(document, str_indices)] = (mention_text, gap_text) 82 | 83 | with open(output_filepath, "w") as f: 84 | for k, v in mentions.items(): 85 | f.write("%s\tDisorder\t%s\t%s\t%s\n" % (k[0], k[1], v[0], v[1])) 86 | 87 | 88 | if __name__ == "__main__": 89 | args = parse_parameters() 90 | handlers = [logging.FileHandler(filename=args.log_filepath), logging.StreamHandler(sys.stdout)] 91 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", 92 | level=logging.INFO, handlers=handlers) 93 | 94 | assert args.split in ["train", "test"] 95 | if args.split == "train": 96 | process_train(args.input_dir, args.text_dir, "%s.ann" % args.split) 97 | else: 98 | process_test(args.input_dir, args.text_dir, "%s.ann" % args.split) -------------------------------------------------------------------------------- /data/share2013/convert_ann_using_token_idx.py: -------------------------------------------------------------------------------- 1 | import argparse, logging, sys 2 | 3 | logger = logging.getLogger(__name__) 4 | 5 | 6 | def parse_parameters(parser=None): 7 | if parser is None: parser = argparse.ArgumentParser() 8 | 9 | parser.add_argument("--input_filepath", default=None, type=str) 10 | parser.add_argument("--log_filepath", default="output.log", type=str) 11 | 12 | args, _ = parser.parse_known_args() 13 | return args 14 | 15 | 16 | def _load_token_boundaries(filepath): 17 | token_start, token_end = {}, {} 18 | with open(filepath) as f: 19 | pre_doc, sentence_idx, token_idx = None, 0, 0 20 | for line in f: 21 | if len(line.strip()) == 0: 22 | sentence_idx += 1 23 | token_idx = 0 24 | continue 25 | sp = 
line.strip().split() 26 | assert len(sp) == 4 27 | cur_doc = sp[1] 28 | if pre_doc is None: 29 | pre_doc = cur_doc 30 | if pre_doc != cur_doc: 31 | sentence_idx = 0 32 | assert token_idx == 0 33 | pre_doc = cur_doc 34 | start, end = int(sp[2]), int(sp[3]) 35 | token_start[(cur_doc, start)] = (sentence_idx, token_idx, sp[0]) 36 | token_end[(cur_doc, end)] = (sentence_idx, token_idx, sp[0]) 37 | token_idx += 1 38 | return token_start, token_end 39 | 40 | 41 | def _find_token_idx(document, char_offset, token_boundaries, start=True): 42 | if (document, char_offset) in token_boundaries: return (token_boundaries[(document, char_offset)], 0) 43 | 44 | for offset_adjust in range(1, 10): 45 | if start: 46 | if (document, char_offset + offset_adjust) in token_boundaries: 47 | return (token_boundaries[(document, char_offset + offset_adjust)], offset_adjust) 48 | if (document, char_offset - offset_adjust) in token_boundaries: 49 | return (token_boundaries[(document, char_offset - offset_adjust)], -offset_adjust) 50 | else: 51 | if (document, char_offset - offset_adjust) in token_boundaries: 52 | return (token_boundaries[(document, char_offset - offset_adjust)], -offset_adjust) 53 | if (document, char_offset + offset_adjust) in token_boundaries: 54 | return (token_boundaries[(document, char_offset + offset_adjust)], offset_adjust) 55 | 56 | logger.info("Cannot find token whose %s offset is %s from document %s." % ("start" if start else "end", char_offset, document)) 57 | return (None, None) 58 | 59 | 60 | if __name__ == "__main__": 61 | args = parse_parameters() 62 | handlers = [logging.FileHandler(filename=args.log_filepath), logging.StreamHandler(sys.stdout)] 63 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", 64 | level=logging.INFO, handlers=handlers) 65 | 66 | for split in ["train", "test"]: 67 | token_start, token_end = _load_token_boundaries("%s.tokens" % split) 68 | num_orig_ann, num_final_ann = 0, 0 69 | with open("%s.token.ann" % split, "w") as out_f: 70 | with open("%s.ann" % split) as in_f: 71 | for line in in_f: 72 | sp = line.strip().split("\t") 73 | assert len(sp) == 5 or len(sp) == 4, sp 74 | document, label, indices, mention = sp[0:4] 75 | indices = [int(i) for i in indices.split(",")] 76 | num_orig_ann += 1 77 | 78 | token_idx = [] 79 | for i in range(0, len(indices), 2): 80 | start, end = indices[i], indices[i + 1] 81 | start_token_idx, offset_adjust = _find_token_idx(document, start, token_start, True) 82 | if start_token_idx is not None: 83 | token_idx.append(start_token_idx) 84 | if offset_adjust != 0: 85 | logger.info("Find token whose original start offset is %s by adjusting its offset %s." % (start, offset_adjust)) 86 | end_token_idx, offset_adjust = _find_token_idx(document, end, token_end, False) 87 | if end_token_idx is not None: 88 | token_idx.append(end_token_idx) 89 | if offset_adjust != 0: 90 | logger.info("Find token whose original end offset is %s by adjusting its offset %s." % (end, offset_adjust)) 91 | 92 | if len(token_idx) != len(indices): 93 | logger.info("Cannot find all corresponding token indices, so abandon this annotation from document %s" % document) 94 | else: 95 | if len(set([i[0] for i in token_idx])) != 1: 96 | logger.info("This annotation from document %s is abandoned because it crossing multiple sentences." 
% document) 97 | else: 98 | sentence_idx = token_idx[0][0] 99 | token_idx = [str(i[1]) for i in token_idx] 100 | num_final_ann += 1 101 | out_f.write("%s\t%s\t%s\t%s\t%s\n" % (document, sentence_idx, label, ",".join(token_idx), mention)) 102 | 103 | logger.info("Convert %d out of %d annotations." % (num_final_ann, num_orig_ann)) -------------------------------------------------------------------------------- /data/cadec/convert_ann_using_token_idx.py: -------------------------------------------------------------------------------- 1 | '''Update date: 2020-Jan-13''' 2 | import argparse 3 | from typing import List 4 | 5 | 6 | def parse_parameters(parser=None): 7 | if parser is None: parser = argparse.ArgumentParser() 8 | 9 | ## Required 10 | parser.add_argument("--input_ann", default="/data/dai031/Experiments/CADEC/ann", type=str) 11 | parser.add_argument("--input_tokens", default="/data/dai031/Experiments/CADEC/tokens", type=str) 12 | parser.add_argument("--output_ann", default="/data/dai031/Experiments/CADEC/tokens.ann", type=str) 13 | 14 | args, _ = parser.parse_known_args() 15 | return args 16 | 17 | 18 | def merge_consecutive_indices(indices: List[int]) -> List[int]: 19 | '''convert 136 142 143 147 into 136 147 (these two spans are actually consecutive), 20 | 136 142 143 147 148 160 into 136 160 (these three spans are consecutive) 21 | it only makes sense when these indices are inclusive''' 22 | consecutive_indices = [] 23 | assert len(indices) % 2 == 0 24 | for i, v in enumerate(indices): 25 | if (i == 0) or (i == len(indices) - 1): 26 | consecutive_indices.append(v) 27 | else: 28 | if i % 2 == 0: 29 | if v > indices[i - 1] + 1: 30 | consecutive_indices.append(v) 31 | else: 32 | if v + 1 < indices[i + 1]: 33 | consecutive_indices.append(v) 34 | assert len(consecutive_indices) % 2 == 0 and len(consecutive_indices) <= len(indices) 35 | if len(indices) != len(consecutive_indices): 36 | indices = " ".join([str(i) for i in indices]) 37 | print("Convert from [%s] to [%s]." % (indices, " ".join([str(i) for i in consecutive_indices]))) 38 | return consecutive_indices 39 | 40 | 41 | def load_token_boundaries(filepath): 42 | token_start, token_end = {}, {} 43 | with open(filepath) as f: 44 | pre_doc, sent_idx, token_idx = None, 0, 0 45 | for line in f: 46 | if len(line.strip()) == 0 and pre_doc is not None: 47 | sent_idx += 1 48 | token_idx = 0 49 | continue 50 | sp = line.strip().split() 51 | assert len(sp) == 4 52 | token, doc, start, end = sp 53 | if pre_doc is None: 54 | pre_doc = doc 55 | if pre_doc != doc: 56 | sent_idx = 0 57 | assert token_idx == 0 58 | pre_doc = doc 59 | token_start[(doc, int(start))] = (sent_idx, token_idx, token) 60 | token_end[(doc, int(end))] = (sent_idx, token_idx, token) 61 | token_idx += 1 62 | return token_start, token_end 63 | 64 | 65 | def find_token_starting_at_offset(doc, offset, token_boundaries): 66 | if (doc, offset) in token_boundaries: return token_boundaries[(doc, offset)] 67 | 68 | adjust = 0 69 | while (offset - adjust) >= 0 and (doc, offset - adjust) not in token_boundaries: 70 | adjust += 1 71 | 72 | if (doc, offset - adjust) in token_boundaries: 73 | print("Cannot find original offset (%d) in document (%s), so use (%d) instead." % (offset, doc, offset - adjust)) 74 | return token_boundaries[(doc, offset - adjust)] 75 | else: 76 | print("Cannot find offset (%d) in document (%s)." 
% (offset, doc)) 77 | return None 78 | 79 | 80 | def find_token_ending_at_offset(doc, offset, token_boundaries): 81 | if (doc, offset) in token_boundaries: return token_boundaries[(doc, offset)] 82 | 83 | adjust = 0 84 | while adjust < 20 and (doc, offset + adjust) not in token_boundaries: 85 | adjust += 1 86 | 87 | if (doc, offset + adjust) in token_boundaries: 88 | print("Cannot find original offset (%d) in document (%s), so use (%d) instead." % (offset, doc, offset + adjust)) 89 | return token_boundaries[(doc, offset + adjust)] 90 | else: 91 | print("Cannot find offset (%d) in document (%s)." % (offset, doc)) 92 | return None 93 | 94 | 95 | if __name__ == "__main__": 96 | args = parse_parameters() 97 | token_start, token_end = load_token_boundaries(args.input_tokens) 98 | 99 | with open(args.output_ann, "w") as out_f: 100 | with open(args.input_ann) as in_f: 101 | for line in in_f: 102 | sp = line.strip().split("\t") 103 | assert len(sp) == 4 104 | doc, label, indices, mention = sp 105 | indices = [int(i) for i in indices.split(",")] 106 | 107 | token_indices = [] 108 | for i in range(0, len(indices), 2): 109 | start_token_idx = find_token_starting_at_offset(doc, indices[i], token_start) 110 | end_token_idx = find_token_ending_at_offset(doc, indices[i + 1], token_end) 111 | assert start_token_idx is not None and end_token_idx is not None 112 | token_indices.append(start_token_idx) 113 | token_indices.append(end_token_idx) 114 | 115 | assert len(indices) == len(token_indices) 116 | assert len(set([i[0] for i in token_indices])) == 1 117 | 118 | sent_idx = token_indices[0][0] 119 | token_indices = sorted([i[1] for i in token_indices]) 120 | token_indices = merge_consecutive_indices(token_indices) 121 | token_indices = ",".join([str(i) for i in token_indices]) 122 | out_f.write("%s\t%s\t%s\t%s\t%s\n" % (doc, sent_idx, label, token_indices, mention)) -------------------------------------------------------------------------------- /code/xdai/utils/vocab.py: -------------------------------------------------------------------------------- 1 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/vocabulary.py 2 | Update date: 2019-Nov-5''' 3 | import codecs, logging, os 4 | from collections import defaultdict 5 | from typing import Dict 6 | from tqdm import tqdm 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | DEFAULT_NON_PADDED_NAMESPACES = ["tags", "labels"] 13 | 14 | 15 | class _NamespaceDependentDefaultDict(defaultdict): 16 | def __init__(self, padded_function, non_padded_function): 17 | # we do not take non_padded_namespaces as a parameter, 18 | # because we consider any namespace whose key name ends with labels or tags as non padded namespace, 19 | # and use padded namespace otherwise 20 | self._padded_function = padded_function 21 | self._non_padded_function = non_padded_function 22 | super(_NamespaceDependentDefaultDict, self).__init__() 23 | 24 | 25 | def __missing__(self, key): 26 | if any(key.endswith(pattern) for pattern in DEFAULT_NON_PADDED_NAMESPACES): 27 | value = self._non_padded_function() 28 | else: 29 | value = self._padded_function() 30 | dict.__setitem__(self, key, value) 31 | return value 32 | 33 | 34 | class _ItemToIndexDefaultDict(_NamespaceDependentDefaultDict): 35 | def __init__(self, padding_item, oov_item): 36 | super(_ItemToIndexDefaultDict, self).__init__(lambda: {padding_item: 0, oov_item: 1}, 37 | lambda: {}) 38 | 39 | 40 | class _IndexToItemDefaultDict(_NamespaceDependentDefaultDict): 41 | def __init__(self, 
padding_item, oov_item): 42 | super(_IndexToItemDefaultDict, self).__init__(lambda: {0: padding_item, 1: oov_item}, 43 | lambda: {}) 44 | 45 | 46 | class Vocabulary: 47 | def __init__(self, counter: Dict[str, Dict[str, int]] = None, min_count: Dict[str, int] = None, 48 | max_vocab_size: Dict[str, int] = None): 49 | self._padding_item = "@@PADDING@@" 50 | self._oov_item = "@@UNKNOWN@@" 51 | self._item_to_index = _ItemToIndexDefaultDict(self._padding_item, self._oov_item) 52 | self._index_to_item = _IndexToItemDefaultDict(self._padding_item, self._oov_item) 53 | self._extend(counter=counter, min_count=min_count, max_vocab_size=max_vocab_size) 54 | 55 | 56 | '''Update date: 2019-Nov-9''' 57 | def save_to_files(self, directory): 58 | os.makedirs(directory, exist_ok=True) 59 | for namespace, mapping in self._index_to_item.items(): 60 | with codecs.open(os.path.join(directory, "%s.txt" % namespace), "w", "utf-8") as f: 61 | for i in range(len(mapping)): 62 | f.write("%s\n" % (mapping[i].replace("\n", "@@NEWLINE@@").strip())) 63 | 64 | 65 | '''Update date: 2019-Nov-9''' 66 | @classmethod 67 | def from_files(cls, directory): 68 | logger.info("Loading item dictionaries from %s.", directory) 69 | vocab = cls() 70 | for namespace in os.listdir(directory): 71 | if not namespace.endswith(".txt"): continue 72 | with codecs.open(os.path.join(directory, namespace), "r", "utf-8") as f: 73 | namespace = namespace.replace(".txt", "") 74 | for i, line in enumerate(f): 75 | line = line.strip() 76 | if len(line) == 0: continue 77 | item = line.replace("@@NEWLINE@@", "\n") 78 | vocab._item_to_index[namespace][item] = i 79 | vocab._index_to_item[namespace][i] = item 80 | return vocab 81 | 82 | 83 | @classmethod 84 | def from_instances(cls, instances, min_count=None, max_vocab_size=None): 85 | counter = defaultdict(lambda: defaultdict(int)) 86 | for instance in tqdm(instances): 87 | instance.count_vocab_items(counter) 88 | return cls(counter=counter, min_count=min_count, max_vocab_size=max_vocab_size) 89 | 90 | 91 | def _extend(self, counter, min_count=None, max_vocab_size=None): 92 | counter = counter or {} 93 | min_count = min_count or {} 94 | max_vocab_size = max_vocab_size or {} 95 | for namespace in counter: 96 | item_counts = list(counter[namespace].items()) 97 | item_counts.sort(key=lambda x: x[1], reverse=True) 98 | 99 | if namespace in max_vocab_size and max_vocab_size[namespace] > 0: 100 | item_counts = item_counts[:max_vocab_size[namespace]] 101 | for item, count in item_counts: 102 | if count >= min_count.get(namespace, 1): 103 | self._add_item_to_namespace(item, namespace) 104 | 105 | 106 | def _add_item_to_namespace(self, item, namespace="tokens"): 107 | if item not in self._item_to_index[namespace]: 108 | idx = len(self._item_to_index[namespace]) 109 | self._item_to_index[namespace][item] = idx 110 | self._index_to_item[namespace][idx] = item 111 | 112 | 113 | def get_index_to_item_vocabulary(self, namespace="tokens"): 114 | return self._index_to_item[namespace] 115 | 116 | 117 | def get_item_to_index_vocabulary(self, namespace="tokens"): 118 | return self._item_to_index[namespace] 119 | 120 | 121 | def get_item_index(self, item, namespace="tokens"): 122 | if item in self._item_to_index[namespace]: 123 | return self._item_to_index[namespace][item] 124 | else: 125 | return self._item_to_index[namespace][self._oov_item] 126 | 127 | 128 | def get_item_from_index(self, idx: int, namespace="tokens"): 129 | return self._index_to_item[namespace][idx] 130 | 131 | 132 | def get_vocab_size(self, 
namespace="tokens"): 133 | return len(self._item_to_index[namespace]) -------------------------------------------------------------------------------- /data/cadec/build_data_for_transition_discontinous_ner.log: -------------------------------------------------------------------------------- 1 | Extract annotations ... 2 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (in back weakness) to (weakness in back). 3 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (in shoulders weakness) to (weakness in shoulders). 4 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (in legs weakness) to (weakness in legs). 5 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (pain muscle) to (muscle pain). 6 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (legs tingling in) to (tingling in legs). 7 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (arms tingling in) to (tingling in arms). 8 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (Achilles tendon pain in) to (pain in Achilles tendon). 9 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (on the back of the head pain in every tendon) to (pain in every tendon on the back of the head). 10 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (bleeding several times a month menstrual cycle is completely out of whack) to (menstrual cycle is completely out of whack bleeding several times a month). 11 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (severe joint pains) to (joint pains severe). 12 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (started going numb legs) to (legs started going numb). 13 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (serious cramping hands) to (serious hands cramping). 14 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (in both legs pain) to (pain in both legs). 15 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (in both legs weakness) to (weakness in both legs). 16 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (comfortably Couldn't walk) to (Couldn't walk comfortably). 17 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (comfortably Couldn't sleep) to (Couldn't sleep comfortably). 18 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (hurts throat) to (throat hurts). 19 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (very blurry eyesight) to (eyesight very blurry). 20 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (calves muscle pain in the) to (muscle pain in the calves). 21 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (thighs muscle pain in the) to (muscle pain in the thighs). 22 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (back muscle pain in the) to (muscle pain in the back). 23 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (neck muscle pain in the) to (muscle pain in the neck). 24 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (glutes muscle pain in the) to (muscle pain in the glutes). 25 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (Severe pain shoulder) to (Severe shoulder pain). 26 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (in my knees joint pain) to (joint pain in my knees). 
27 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (eyes microabrasions) to (eyes microabrasion). 28 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (hurt exploding arm and neck) to (arm and neck hurt exploding). 29 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (pain in stomach) to (pain i stomach). 30 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (in arms Tingling) to (Tingling in arms). 31 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (feet numbing) to (numbing feet). 32 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (in forearms tightness) to (tightness in forearms). 33 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (stomach rough) to (rough stomach). 34 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (in the triceps arm fatigue) to (arm fatigue in the triceps). 35 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (wrists pain) to (pain wrists). 36 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (hips pain) to (pain hips). 37 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (tire very quickly muscles) to (muscles tire very quickly). 38 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (renal failure) to (rena failure). 39 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (renal and respiratory failure) to (respiratory failure). 40 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (painful muscles) to (muscles painful). 41 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (painful joints) to (joints painful). 42 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (calves muscle pain in) to (muscle pain in calves). 43 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (calves joint pain in) to (joint pain in calves). 44 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (swelling feet) to (feet swelling). 45 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (knees aches) to (aches knees). 46 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (back aches) to (aches back). 47 | 11/06/2019 09:46:31 - INFO - __main__ - Update the mention from (chest aches) to (aches chest). 48 | 11/06/2019 09:46:31 - INFO - __main__ - Extract 6318 annotations. 49 | Tokenization ... 50 | 11/06/2019 09:56:25 - INFO - __main__ - 1250 documents, 7597 sentences, 122938 tokens 51 | Convert annotations from character level offsets to token level idx ... 52 | 11/06/2019 10:10:11 - INFO - __main__ - Find token whose original end offset is 172 by adjusting its offset 1. 53 | 11/06/2019 10:10:11 - INFO - __main__ - Find token whose original end offset is 221 by adjusting its offset 1. 54 | 11/06/2019 10:10:11 - INFO - __main__ - Find token whose original end offset is 415 by adjusting its offset 1. 55 | 11/06/2019 10:10:11 - INFO - __main__ - Find token whose original start offset is 432 by adjusting its offset 1. 56 | 11/06/2019 10:10:11 - INFO - __main__ - Convert 6318 out of 6318 annotations. 57 | Create text inline format ... 58 | Split the data set into train, dev, test splits ... 
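For reference: the stages logged above (extract annotations, tokenise, convert character offsets to token indices, create the text-inline files, split into train/dev/test) end with data in the three-line "text-inline" format consumed by evaluate.py, test_parsing.py and train.py. Below is a minimal illustrative sketch of that format and of how the inclusive token-offset mention strings decode; this is not one of the repository's scripts, and the example sentence and spans are copied from the sanity checks in code/xdai/ner/transition_discontinuous/test_parsing.py.

# Hedged sketch: decode one text-inline record.
# Format per record: tokenised sentence / "|"-separated mentions
# ("start,end[,start,end]* LABEL", token offsets inclusive) / blank line.
record = (
    "Very severe pain in arms , knees , hands .\n"
    "0,3,6,6 ADR|0,3,8,8 ADR|0,4 ADR\n"
    "\n"
)

def read_text_inline(lines):
    """Yield (tokens, mentions) pairs from text-inline records."""
    it = iter(lines)
    for sentence in it:
        sentence = sentence.strip()
        if not sentence:
            continue
        mention_line = next(it).strip()
        assert next(it).strip() == ""  # records are separated by a blank line
        mentions = []
        for mention in (mention_line.split("|") if mention_line else []):
            indices, label = mention.split(" ")
            indices = [int(i) for i in indices.split(",")]
            # pairs of inclusive (start, end) offsets; more than one pair
            # means the mention is discontinuous
            spans = [(indices[i], indices[i + 1]) for i in range(0, len(indices), 2)]
            mentions.append((spans, label))
        yield sentence.split(), mentions

for tokens, mentions in read_text_inline(record.splitlines()):
    for spans, label in mentions:
        surface = " ".join(" ".join(tokens[s:e + 1]) for s, e in spans)
        print(label, spans, "->", surface)
# prints e.g.: ADR [(0, 3), (6, 6)] -> Very severe pain in knees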
59 | -------------------------------------------------------------------------------- /code/xdai/ner/evaluate.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Usage: the input file should be of text-inline format 3 | python evaluate.py --gold_filepath /data/dai031/Experiments/CADEC/adr/split/test.txt --pred_filepath /data/dai031/Experiments/flair/cadec-adr/test.txt 4 | python evaluate.py --gold_filepath /data/dai031/Experiments/CADEC/adr/split/test.txt --pred_filepath /data/dai031/Experiments/TransitionDiscontinuous/cadec-50542/test.pred 5 | ''' 6 | import argparse, os, sys 7 | from typing import Dict, List 8 | from collections import defaultdict 9 | 10 | sys.path.insert(0, os.path.abspath("../..")) 11 | from xdai.ner.mention import Mention 12 | 13 | 14 | def parse_parameters(parser=None): 15 | if parser is None: parser = argparse.ArgumentParser() 16 | 17 | ## Required 18 | parser.add_argument("--gold_filepath", default=None, type=str) 19 | parser.add_argument("--pred_filepath", default=None, type=str) 20 | args, _ = parser.parse_known_args() 21 | return args 22 | 23 | 24 | '''Update: 2019-Nov-9''' 25 | def compute_f1(TP: int, FP: int, FN: int) -> Dict: 26 | precision = float(TP) / float(TP + FP) if TP + FP > 0 else 0 27 | recall = float(TP) / float(TP + FN) if TP + FN > 0 else 0 28 | f1 = 2. * ((precision * recall) / (precision + recall)) if precision + recall > 0 else 0 29 | return precision, recall, f1 30 | 31 | 32 | '''Update: 2019-Nov-9''' 33 | def compute_on_corpus(gold_corpus: List[List[str]], pred_corpus: List[List[str]]): 34 | assert len(gold_corpus) == len(pred_corpus) # number of sentences 35 | 36 | TP, FP, FN = defaultdict(int), defaultdict(int), defaultdict(int) 37 | for gold_sentence, pred_sentence in zip(gold_corpus, pred_corpus): 38 | for gold in gold_sentence: 39 | if gold in pred_sentence: 40 | TP[gold.split()[-1]] += 1 41 | else: 42 | FN[gold.split()[-1]] += 1 43 | for pred in pred_sentence: 44 | if pred not in gold_sentence: 45 | FP[pred.split()[-1]] += 1 46 | 47 | entity_types = set(TP.keys()) | set(FP.keys()) | set(FN.keys()) 48 | metrics = {} 49 | precision_per_type, recall_per_type, f1_per_type = [], [], [] 50 | for t in entity_types: 51 | precision, recall, f1 = compute_f1(TP[t], FP[t], FN[t]) 52 | metrics["%s-precision" % t] = precision 53 | precision_per_type.append(precision) 54 | metrics["%s-recall" % t] = recall 55 | recall_per_type.append(recall) 56 | metrics["%s-f1" % t] = f1 57 | f1_per_type.append(f1) 58 | 59 | metrics["macro-precision"] = sum(precision_per_type) / len(precision_per_type) if len(precision_per_type) > 0 else 0.0 60 | metrics["macro-recall"] = sum(recall_per_type) / len(recall_per_type) if len(recall_per_type) > 0 else 0.0 61 | metrics["macro-f1"] = sum(f1_per_type) / len(f1_per_type) if len(f1_per_type) > 0 else 0.0 62 | 63 | precision, recall, f1 = compute_f1(sum(TP.values()), sum(FP.values()), sum(FN.values())) 64 | metrics["micro-precision"] = precision 65 | metrics["micro-recall"] = recall 66 | metrics["micro-f1"] = f1 67 | 68 | return metrics 69 | 70 | 71 | '''Update: 2019-Nov-9''' 72 | def compute_on_sentences_with_disc(gold_corpus, pred_corpus): 73 | assert len(gold_corpus) == len(pred_corpus) 74 | 75 | gold_disc_corpus, pred_disc_corpus = [], [] 76 | for gold_sentence, pred_sentence in zip(gold_corpus, pred_corpus): 77 | gold_mentions = Mention.create_mentions("|".join(gold_sentence)) 78 | if any(m.discontinuous for m in gold_mentions): 79 | gold_disc_corpus.append(gold_sentence) 80 | 
pred_disc_corpus.append(pred_sentence) 81 | 82 | metrics = compute_on_corpus(gold_disc_corpus, pred_disc_corpus) 83 | return {"sentences_with_disc-%s" % k: v for k, v in metrics.items()} 84 | 85 | 86 | '''Update: 2019-Nov-9''' 87 | def compute_on_disc_mentions(gold_corpus, pred_corpus): 88 | assert len(gold_corpus) == len(pred_corpus) 89 | 90 | TP, FP, FN = 0.0, 0.0, 0.0 91 | for gold_sentence, pred_sentence in zip(gold_corpus, pred_corpus): 92 | gold_mentions = [m for m in Mention.create_mentions("|".join(gold_sentence)) if m.discontinuous] 93 | pred_mentions = [m for m in Mention.create_mentions("|".join(pred_sentence)) if m.discontinuous] 94 | for pred in pred_mentions: 95 | if str(pred) in gold_sentence: 96 | TP += 1 97 | else: 98 | FP += 1 99 | for gold in gold_mentions: 100 | if str(gold) not in pred_sentence: 101 | FN += 1 102 | 103 | precision, recall, f1 = compute_f1(TP, FP, FN) 104 | return {"disc-mention-micro-precision": precision, "disc-mention-micro-recall": recall, "disc-mention-micro-f1": f1} 105 | 106 | 107 | if __name__ == "__main__": 108 | args = parse_parameters() 109 | 110 | sentences = [] 111 | gold_mentions, pred_mentions = [], [] 112 | with open(args.pred_filepath) as f: 113 | for sentence in f: 114 | sentences.append(sentence.strip()) 115 | if not args.gold_filepath: 116 | gold = next(f).strip() 117 | gold = gold.split("|") if len(gold) > 0 else [] 118 | gold_mentions.append(gold) 119 | pred = next(f).strip() 120 | pred = pred.split("|") if len(pred) > 0 else [] 121 | pred_mentions.append(pred) 122 | assert len(next(f).strip()) == 0 123 | 124 | if args.gold_filepath is not None: 125 | with open(args.gold_filepath) as f: 126 | sent_id = 0 127 | for sentence in f: 128 | assert sentence.strip() == sentences[sent_id] 129 | sent_id += 1 130 | gold = next(f).strip() 131 | gold = gold.split("|") if len(gold) > 0 else [] 132 | gold_mentions.append(gold) 133 | assert len(next(f).strip()) == 0 134 | 135 | metrics = compute_on_corpus(gold_mentions, pred_mentions) 136 | for k, v in metrics.items(): 137 | if k.find("micro") >= 0: 138 | print(k, v) 139 | 140 | metrics = compute_on_sentences_with_disc(gold_mentions, pred_mentions) 141 | for k, v in metrics.items(): 142 | if k.find("micro") >= 0: 143 | print(k, v) 144 | 145 | metrics = compute_on_disc_mentions(gold_mentions, pred_mentions) 146 | for k, v in metrics.items(): 147 | if k.find("micro") >= 0: 148 | print(k, v) -------------------------------------------------------------------------------- /code/xdai/ner/transition_discontinuous/train.py: -------------------------------------------------------------------------------- 1 | import json, logging, os, sys, torch 2 | sys.path.insert(0, os.path.abspath("../../..")) 3 | 4 | from xdai.utils.args import parse_parameters 5 | from xdai.utils.common import create_output_dir, pad_sequence_to_length, set_cuda, set_random_seed 6 | from xdai.utils.instance import Instance, MetadataField, TextField 7 | from xdai.utils.iterator import BasicIterator, BucketIterator 8 | from xdai.utils.token import Token 9 | from xdai.utils.token_indexer import SingleIdTokenIndexer, TokenCharactersIndexer, ELMoIndexer 10 | from xdai.utils.train import train_op, eval_op 11 | from xdai.utils.vocab import Vocabulary 12 | from xdai.ner.transition_discontinuous.models import TransitionModel 13 | from xdai.ner.mention import Mention 14 | from xdai.ner.transition_discontinuous.parsing import Parser 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | '''Update at April-22-2019''' 21 | class 
ActionField: 22 | def __init__(self, actions, inputs): 23 | self._key = "actions" 24 | self.actions = actions 25 | self._indexed_actions = None 26 | self.inputs = inputs 27 | 28 | if all([isinstance(a, int) for a in actions]): 29 | self._indexed_actions = actions 30 | 31 | 32 | def count_vocab_items(self, counter): 33 | if self._indexed_actions is None: 34 | for action in self.actions: 35 | counter[self._key][action] += 1 36 | 37 | 38 | def index(self, vocab): 39 | if self._indexed_actions is None: 40 | self._indexed_actions = [vocab.get_item_index(action, self._key) for action in self.actions] 41 | 42 | 43 | def get_padding_lengths(self): 44 | return {"num_tokens": self.inputs.sequence_length() * 2} 45 | 46 | 47 | def as_tensor(self, padding_lengths): 48 | desired_num_actions = padding_lengths["num_tokens"] 49 | padded_actions = pad_sequence_to_length(self._indexed_actions, desired_num_actions) 50 | return torch.LongTensor(padded_actions) 51 | 52 | 53 | def batch_tensors(self, tensor_list): 54 | return torch.stack(tensor_list) 55 | 56 | 57 | class DatasetReader: 58 | def __init__(self, args): 59 | self.args = args 60 | self.parse = Parser() 61 | self._token_indexers = {"tokens": SingleIdTokenIndexer(), "token_characters": TokenCharactersIndexer()} 62 | if args.model_type == "elmo": 63 | self._token_indexers["elmo_characters"] = ELMoIndexer() 64 | 65 | 66 | def read(self, filepath, training=False): 67 | instances = [] 68 | with open(filepath, "r") as f: 69 | for sentence in f: 70 | tokens = [Token(t) for t in sentence.strip().split()] 71 | annotations = next(f).strip() 72 | actions = self.parse.mention2actions(annotations, len(tokens)) 73 | oracle_mentions = [str(s) for s in self.parse.parse(actions, len(tokens))] 74 | gold_mentions = annotations.split("|") if len(annotations) > 0 else [] 75 | 76 | if len(oracle_mentions) != len(gold_mentions) or len(oracle_mentions) != len( 77 | set(oracle_mentions) & set(gold_mentions)): 78 | logger.debug("Discard this instance whose oracle mention is: %s, while its gold mention is: %s" % ( 79 | "|".join(oracle_mentions), annotations)) 80 | if not training: 81 | instances.append(self._to_instance(sentence, annotations, tokens, actions)) 82 | else: 83 | instances.append(self._to_instance(sentence, annotations, tokens, actions)) 84 | 85 | assert len(next(f).strip()) == 0 86 | return instances 87 | 88 | 89 | def _to_instance(self, sentence, annotations, tokens, actions): 90 | text_fields = TextField(tokens, self._token_indexers) 91 | action_fields = ActionField(actions, text_fields) 92 | sentence = MetadataField(sentence.strip()) 93 | annotations = MetadataField(annotations.strip()) 94 | return Instance( 95 | {"sentence": sentence, "annotations": annotations, "tokens": text_fields, "actions": action_fields}) 96 | 97 | 98 | if __name__ == "__main__": 99 | args = parse_parameters() 100 | create_output_dir(args) 101 | 102 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", 103 | level=logging.INFO, filename=args.log_filepath) 104 | addition_args = json.load(open("config.json")) 105 | for k, v in addition_args.items(): 106 | setattr(args, k, v) 107 | logger.info( 108 | "Parameters: %s" % json.dumps({k: v for k, v in vars(args).items() if v is not None}, indent=2, sort_keys=True)) 109 | 110 | set_cuda(args) 111 | set_random_seed(args) 112 | 113 | dataset_reader = DatasetReader(args) 114 | train_data = dataset_reader.read(args.train_filepath, training=True) 115 | if args.dev_filepath is None: 116 | 
num_dev_instances = int(len(train_data) / 10) 117 | dev_data = train_data[0:num_dev_instances] 118 | train_data = train_data[num_dev_instances:] 119 | else: 120 | dev_data = dataset_reader.read(args.dev_filepath) 121 | if args.num_train_instances is not None: train_data = train_data[0:args.num_train_instances] 122 | if args.num_dev_instances is not None: dev_data = dev_data[0:args.num_dev_instances] 123 | logger.info("Load %d instances from train set." % (len(train_data))) 124 | logger.info("Load %d instances from dev set." % (len(dev_data))) 125 | test_data = dataset_reader.read(args.test_filepath) 126 | logger.info("Load %d instances from test set." % (len(test_data))) 127 | 128 | datasets = {"train": train_data, "validation": dev_data, "test": test_data} 129 | vocab = Vocabulary.from_instances((instance for dataset in datasets.values() for instance in dataset)) 130 | vocab.save_to_files(os.path.join(args.output_dir, "vocabulary")) 131 | train_iterator = BucketIterator(sorting_keys=[['tokens', 'tokens_length']], batch_size=args.train_batch_size_per_gpu) 132 | train_iterator.index_with(vocab) 133 | dev_iterator = BasicIterator(batch_size=args.eval_batch_size_per_gpu) 134 | dev_iterator.index_with(vocab) 135 | 136 | model = TransitionModel(args, vocab).cuda(args.cuda_device[0]) 137 | parameters = [p for _, p in model.named_parameters() if p.requires_grad] 138 | 139 | optimizer = torch.optim.Adam(parameters, lr=args.learning_rate) 140 | 141 | metrics = train_op(args, model, optimizer, train_data, train_iterator, dev_data, dev_iterator) 142 | logger.info(metrics) 143 | 144 | model.load_state_dict(torch.load(os.path.join(args.output_dir, "best.th"))) 145 | test_metrics, test_preds = eval_op(args, model, test_data, dev_iterator) 146 | logger.info(test_metrics) 147 | with open(os.path.join(args.output_dir, "test.pred"), "w") as f: 148 | for i in test_preds: 149 | f.write("%s\n%s\n\n" % (i[0], i[1])) 150 | 151 | if args.dev_filepath is not None: 152 | dev_metrics, dev_preds = eval_op(args, model, dev_data, dev_iterator) 153 | logger.info(dev_metrics) 154 | with open(os.path.join(args.output_dir, "dev.pred"), "w") as f: 155 | for i in dev_preds: 156 | f.write("%s\n%s\n\n" % (i[0], i[1])) -------------------------------------------------------------------------------- /code/xdai/ner/transition_discontinuous/parsing.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | from xdai.ner.mention import merge_consecutive_indices, Mention 4 | 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | '''Update date: 2019-Nov-5''' 10 | class _NodeInStack(Mention): 11 | @classmethod 12 | def single_token_node(cls, idx): 13 | return Mention.create_mention([idx, idx], "") 14 | 15 | 16 | @classmethod 17 | def merge_nodes(cls, m1, m2): 18 | # TODO: m1 and m2 cannot completely contain each other 19 | indices = sorted(m1.indices + m2.indices) 20 | indices = merge_consecutive_indices(indices) 21 | return Mention.create_mention(indices, "") 22 | 23 | 24 | 25 | '''Update date: 2019-Nov-5''' 26 | class Parser(object): 27 | # actions include SHIFT, OUT, COMPLETE-Y, REDUCE, LEFT-REDUCE, RIGHT-REDUCE 28 | def parse(self, actions, seq_length=None) -> List[Mention]: 29 | mentions, stack = [], [] 30 | if seq_length is None: seq_length = len(actions) * 2 31 | buffer = [i for i in range(seq_length)] 32 | for action in actions: 33 | if action == "SHIFT": 34 | if len(buffer) < 1: 35 | logger.info("Invalid SHIFT action: the buffer is empty.") 36 | 
else: 37 | stack.append(_NodeInStack.single_token_node(buffer[0])) 38 | buffer.pop(0) 39 | elif action == "OUT": 40 | if len(buffer) < 1: 41 | logger.info("Invalid OUT action: the buffer is empty.") 42 | else: 43 | buffer.pop(0) 44 | elif action.startswith("COMPLETE"): 45 | if len(stack) < 1: 46 | logger.info("Invalid COMPLETE action: the stack is empty.") 47 | else: 48 | mention = stack.pop(-1) 49 | mention.label = action.split("-")[-1].strip() 50 | mentions.append(mention) 51 | else: 52 | if action.find("REDUCE") >= 0 and len(stack) >= 2: 53 | right_node = stack.pop(-1) 54 | left_node = stack.pop(-1) 55 | if Mention.contains(left_node, right_node) or Mention.contains(right_node, left_node): 56 | logger.info("Invalid REDUCE action: the last two elements in the stack contain each other") 57 | else: 58 | merged = _NodeInStack.merge_nodes(left_node, right_node) 59 | if action.startswith("LEFT"): stack.append(left_node) 60 | if action.startswith("RIGHT"): stack.append(right_node) 61 | stack.append(merged) 62 | else: 63 | logger.info( 64 | "Invalid REDUCE action: %s, the number of elements in the stack is %d." % (action, len(stack))) 65 | return mentions 66 | 67 | 68 | def mention2actions(self, mentions: str, sentence_length: int): 69 | def _detect_overlapping_mentions(mentions): 70 | mentions = Mention.create_mentions(mentions) 71 | 72 | for i in range(len(mentions)): 73 | if mentions[i]._overlapping: continue 74 | for j in range(len(mentions)): 75 | if i == j: continue 76 | if Mention.overlap_spans(mentions[i], mentions[j]): 77 | assert mentions[i].label == mentions[j].label 78 | mentions[i]._overlapping = True 79 | mentions[j]._overlapping = True 80 | return mentions 81 | 82 | 83 | def _involve_mention(mentions, token_id): 84 | for i, mention in enumerate(mentions): 85 | for span in mention.spans: 86 | if span.start <= token_id and token_id <= span.end: 87 | return True 88 | return False 89 | 90 | 91 | def _find_relevant_mentions(mentions, node): 92 | parents, equals = [], [] 93 | for i in range(len(mentions)): 94 | if Mention.contains(mentions[i], node): 95 | parents.append(i) 96 | if Mention.equal_spans(mentions[i], node): 97 | equals.append(i) 98 | return parents, equals 99 | 100 | 101 | mentions = _detect_overlapping_mentions(mentions) 102 | actions, stack = [], [] 103 | buffer = [i for i in range(sentence_length)] 104 | 105 | while len(buffer) > 0: 106 | if not _involve_mention(mentions, buffer[0]): 107 | actions.append("OUT") 108 | buffer.pop(0) 109 | else: 110 | actions.append("SHIFT") 111 | stack.append(_NodeInStack.single_token_node(buffer[0])) 112 | buffer.pop(0) 113 | 114 | stack_changed = True 115 | 116 | # COMPLETE, REDUCE, LEFT-REDUCE, RIGHT-REDUCE 117 | # if the last item of the stack is a mention, and does not involve with other mentions, then COMPLETE 118 | while stack_changed: 119 | stack_changed = False 120 | 121 | if len(stack) >= 1: 122 | parents, equals = _find_relevant_mentions(mentions, stack[-1]) 123 | if len(equals) == 1 and len(parents) == 1: 124 | actions.append("COMPLETE-%s" % mentions[equals[0]].label) 125 | stack.pop(-1) 126 | mentions.pop(equals[0]) 127 | stack_changed = True 128 | 129 | # three REDUCE actions 130 | if len(stack) >= 2: 131 | if not Mention.overlap_spans(stack[-2], stack[-1]): 132 | last_two_ndoes = _NodeInStack.merge_nodes(stack[-2], stack[-1]) 133 | parents_of_two, _ = _find_relevant_mentions(mentions, last_two_ndoes) 134 | if len(parents_of_two) > 0: 135 | parent_of_left, _ = _find_relevant_mentions(mentions, stack[-2]) 136 | 
parent_of_right, _ = _find_relevant_mentions(mentions, stack[-1]) 137 | if len(parents_of_two) != len(parent_of_left): 138 | actions.append("LEFT-REDUCE") 139 | stack.pop(-1) 140 | stack.append(last_two_ndoes) 141 | stack_changed = True 142 | elif len(parents_of_two) != len(parent_of_right): 143 | actions.append("RIGHT-REDUCE") 144 | stack.pop(-2) 145 | stack.append(last_two_ndoes) 146 | stack_changed = True 147 | else: 148 | actions.append("REDUCE") 149 | stack.pop(-1) 150 | stack.pop(-1) 151 | stack.append(last_two_ndoes) 152 | stack_changed = True 153 | return actions -------------------------------------------------------------------------------- /data/share2013/tokenization.py: -------------------------------------------------------------------------------- 1 | import argparse, logging, os, sys 2 | from typing import List, NamedTuple 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | def parse_parameters(parser=None): 8 | if parser is None: parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--input_dir", default=None, type=str) 11 | parser.add_argument("--split", default=None, type=str) 12 | parser.add_argument("--log_filepath", default="output.log", type=str) 13 | 14 | args, _ = parser.parse_known_args() 15 | return args 16 | 17 | 18 | class Token(NamedTuple): 19 | text: str = None 20 | idx: int = None # start character offset 21 | 22 | 23 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/tokenizers/word_splitter.py#SimpleWordSplitter''' 24 | class CustomSplitter: 25 | def __init__(self): 26 | self.special_cases = set(["mr.", "mrs.", "etc.", "e.g.", "cf.", "c.f.", "eg.", "al."]) 27 | self.special_beginning = set(["http", "www"]) 28 | self.contractions = set(["n't", "'s", "'ve", "'re", "'ll", "'d", "'m"]) 29 | self.contractions |= set([x.replace("'", "’") for x in self.contractions]) 30 | self.ending_punctuation = set(['"', "'", '.', ',', ';', ')', ']', '}', ':', '!', '?', '%', '”', "’"]) 31 | self.beginning_punctuation = set(['"', "'", '(', '[', '{', '#', '$', '“', "‘", "+", "*", "="]) 32 | self.delimiters = set(["-", "/", ",", ")", "&", "(", "?", ".", "\\", ";", ":", "+", ">", "%"]) 33 | 34 | def split_tokens(self, sentence: str) -> List[Token]: 35 | original_sentence = sentence 36 | sentence = list(sentence) 37 | sentence = "".join([o if not o in self.delimiters else " %s " % o for o in sentence]) 38 | tokens = [] 39 | field_start, field_end = 0, 0 40 | for filed in sentence.split(): 41 | filed = filed.strip() 42 | if len(filed) == 0: continue 43 | 44 | field_start = original_sentence.find(filed, field_start) 45 | field_end = field_start + len(filed) 46 | assert field_start >= 0, "cannot find (%s) from \"%s\" after offset %d" % ( 47 | filed, original_sentence, field_start) 48 | 49 | add_at_end = [] 50 | while self._can_split(filed) and filed[0] in self.beginning_punctuation: 51 | tokens.append(Token(filed[0], field_start)) 52 | filed = filed[1:] 53 | field_start += 1 54 | 55 | while self._can_split(filed) and filed[-1] in self.ending_punctuation: 56 | add_at_end.insert(0, Token(filed[-1], field_start + len(filed) - 1)) 57 | filed = filed[:-1] 58 | 59 | remove_contractions = True 60 | while remove_contractions: 61 | remove_contractions = False 62 | for contraction in self.contractions: 63 | if self._can_split(filed) and filed.lower().endswith(contraction): 64 | add_at_end.insert(0, Token(filed[-len(contraction):], field_start + len(filed) - len(contraction))) 65 | filed = filed[:-len(contraction)] 66 | remove_contractions = True 67 
| 68 | if filed: 69 | tokens.append(Token(filed, field_start)) 70 | tokens.extend(add_at_end) 71 | field_start = field_end 72 | return tokens 73 | 74 | def _can_split(self, token): 75 | if not token: return False 76 | if token.lower() in self.special_cases: return False 77 | for _special_beginning in self.special_beginning: 78 | if token.lower().startswith(_special_beginning): return False 79 | return True 80 | 81 | 82 | class CustomSentenceSplitter(): 83 | def _next_character_is_upper(self, text, i): 84 | while i < len(text): 85 | if len(text[i].strip()) == 0: 86 | i += 1 87 | elif text[i].isupper(): 88 | return True 89 | else: 90 | break 91 | return False 92 | 93 | # do very simple things: if there is a period '.', and the next character is uppercased, call it a sentence. 94 | def split_sentence(self, text): 95 | break_points = [0] 96 | for i in range(len(text)): 97 | if text[i] in [".", "!", "?"]: 98 | if self._next_character_is_upper(text, i + 1): 99 | break_points.append(i + 1) 100 | break_points.append(-1) 101 | sentences = [] 102 | for s, e in zip(break_points[0:-1], break_points[1:]): 103 | if e == -1: 104 | sentences.append(text[s:].strip()) 105 | else: 106 | sentences.append(text[s:e].strip()) 107 | return sentences 108 | 109 | 110 | if __name__ == "__main__": 111 | args = parse_parameters() 112 | handlers = [logging.FileHandler(filename=args.log_filepath), logging.StreamHandler(sys.stdout)] 113 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", 114 | level=logging.INFO, handlers=handlers) 115 | 116 | token_splitter = CustomSplitter() 117 | sentence_splitter = CustomSentenceSplitter() 118 | 119 | with open("%s.tokens" % args.split, "w") as out_f: 120 | num_of_sentences, num_tokens = 0, 0 121 | for filename in os.listdir(args.input_dir): 122 | if not filename.endswith("txt"): continue 123 | with open(os.path.join(args.input_dir, filename)) as in_f: 124 | text = in_f.read() 125 | token_start, token_end = 0, 0 126 | for i, line in enumerate(text.splitlines()): 127 | if len(line.strip()) > 0: 128 | if i <= 5 or len(line.strip()) < 150: 129 | num_of_sentences += 1 130 | tokens = token_splitter.split_tokens(line.strip()) 131 | for token in tokens: 132 | num_tokens += 1 133 | token_start = text.find(token.text, token_start) 134 | assert token_start >= 0 135 | token_end = token_start + len(token.text.strip()) 136 | out_f.write("%s %s %d %d\n" % (token.text, filename, token_start, token_end)) 137 | token_start = token_end 138 | out_f.write("\n") 139 | else: 140 | for sentence in sentence_splitter.split_sentence(line.strip()): 141 | num_of_sentences += 1 142 | tokens = token_splitter.split_tokens(sentence.strip()) 143 | for token in tokens: 144 | num_tokens += 1 145 | token_start = text.find(token.text, token_start) 146 | assert token_start >= 0 147 | token_end = token_start + len(token.text.strip()) 148 | out_f.write("%s %s %d %d\n" % (token.text, filename, token_start, token_end)) 149 | token_start = token_end 150 | out_f.write("\n") 151 | 152 | logger.info(f"{num_of_sentences} sentences and {num_tokens} tokens in {args.input_dir}") -------------------------------------------------------------------------------- /data/cadec/convert_flat_mentions.py: -------------------------------------------------------------------------------- 1 | '''Update date: 2020-Jan-13''' 2 | import argparse 3 | from typing import List 4 | 5 | 6 | def merge_consecutive_indices(indices: List[int]) -> List[int]: 7 | '''convert 136 142 143 147 into 136 
147 (these two spans are actually consecutive), 8 | 136 142 143 147 148 160 into 136 160 (these three spans are consecutive) 9 | it only makes sense when these indices are inclusive''' 10 | consecutive_indices = [] 11 | assert len(indices) % 2 == 0 12 | for i, v in enumerate(indices): 13 | if (i == 0) or (i == len(indices) - 1): 14 | consecutive_indices.append(v) 15 | else: 16 | if i % 2 == 0: 17 | if v > indices[i - 1] + 1: 18 | consecutive_indices.append(v) 19 | else: 20 | if v + 1 < indices[i + 1]: 21 | consecutive_indices.append(v) 22 | assert len(consecutive_indices) % 2 == 0 and len(consecutive_indices) <= len(indices) 23 | if len(indices) != len(consecutive_indices): 24 | indices = " ".join([str(i) for i in indices]) 25 | print("Convert from [%s] to [%s]." % (indices, " ".join([str(i) for i in consecutive_indices]))) 26 | return consecutive_indices 27 | 28 | 29 | class Span(object): 30 | def __init__(self, start, end): 31 | '''start and end are inclusive''' 32 | self.start = int(start) 33 | self.end = int(end) 34 | 35 | 36 | @classmethod 37 | def overlaps(cls, span1, span2): 38 | '''whether span1 overlaps with span2, including equals''' 39 | if span1.end < span2.start: return False 40 | if span1.start > span2.end: return False 41 | return True 42 | 43 | 44 | def __str__(self): 45 | return self.__repr__() 46 | 47 | 48 | def __repr__(self): 49 | return "%d,%d" % (self.start, self.end) 50 | 51 | 52 | class Mention(object): 53 | def __init__(self, spans, label: str): 54 | assert len(spans) >= 1 55 | self.spans = spans 56 | self.label = label 57 | self.discontinuous = (len(spans) > 1) 58 | self._overlapping_spans = set() 59 | 60 | 61 | @property 62 | def start(self): 63 | return self.spans[0].start 64 | 65 | 66 | @property 67 | def end(self): 68 | return self.spans[-1].end 69 | 70 | 71 | @property 72 | def overlapping(self): 73 | return len(self._overlapping_spans) > 0 74 | 75 | 76 | @classmethod 77 | def overlap_spans(cls, mention1, mention2): 78 | for span1 in mention1.spans: 79 | for span2 in mention2.spans: 80 | if Span.overlaps(span1, span2): 81 | return True 82 | return False 83 | 84 | 85 | @classmethod 86 | def remove_discontinuous_mentions(cls, mentions): 87 | '''convert discontinuous mentions, such as 17,20,22,22 Disorder, to 17,22 Disorder''' 88 | continuous_mentions = [] 89 | for mention in mentions: 90 | if mention.discontinuous: 91 | continuous_mentions.append(Mention.create_mention([mention.start, mention.end], mention.label)) 92 | else: 93 | continuous_mentions.append(mention) 94 | return continuous_mentions 95 | 96 | 97 | @classmethod 98 | def merge_overlapping_mentions(cls, mentions): 99 | ''' 100 | Given a list of mentions which may overlap with each other, erase these overlapping. 101 | For example 102 | 1) if an mention starts at 1, ends at 4, the other one starts at 3, ends at 5. 103 | Then group these together as one mention starting at 1, ending at 5 if they are of the same type, 104 | otherwise, raise an Error. 
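A hypothetical doctest-style illustration (not in the original docstring; the label and indices are invented, span offsets are inclusive):
        >>> ms = Mention.create_mentions("1,4 Disorder|3,5 Disorder")
        >>> [str(m) for m in Mention.merge_overlapping_mentions(ms)]
        ['1,5 Disorder']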
105 | ''' 106 | overlapping_may_exist = True 107 | while overlapping_may_exist: 108 | overlapping_may_exist = False 109 | merged_mentions = {} 110 | for i in range(len(mentions)): 111 | for j in range(len(mentions)): 112 | if i == j: continue 113 | if Mention.overlap_spans(mentions[i], mentions[j]): 114 | assert mentions[i].label == mentions[j].label, "TODO: two mentions of different types overlap" 115 | overlapping_may_exist = True 116 | merged_mention_start = min(mentions[i].start, mentions[j].start) 117 | merged_mention_end = max(mentions[i].end, mentions[j].end) 118 | merged_mention = Mention.create_mention([merged_mention_start, merged_mention_end], 119 | mentions[i].label) 120 | if (merged_mention_start, merged_mention_end) not in merged_mentions: 121 | merged_mentions[(merged_mention_start, merged_mention_end)] = merged_mention 122 | mentions[i]._overlapping_spans.add(0) 123 | mentions[j]._overlapping_spans.add(0) 124 | mentions = [mention for mention in mentions if not mention.overlapping] + list(merged_mentions.values()) 125 | return mentions 126 | 127 | 128 | @classmethod 129 | def create_mention(cls, indices, label: str): 130 | ''' 131 | the original indices can be 136,142,143,147, these two spans are actually consecutive, so convert to 136,147 132 | similarily, convert 136,142,143,147,148,160 into 136,160 (these three spans are consecutive) 133 | additionally, sort the indices: 119,125,92,96 to 92,96,119,125 134 | ''' 135 | assert len(indices) % 2 == 0 136 | indices = sorted(indices) 137 | indices = merge_consecutive_indices(indices) 138 | spans = [Span(indices[i], indices[i + 1]) for i in range(0, len(indices), 2)] 139 | return cls(spans, label) 140 | 141 | 142 | @classmethod 143 | def create_mentions(cls, mentions: str): 144 | '''Input: 5,6 DATE|6,6 DAY|5,6 EVENT''' 145 | if len(mentions.strip()) == 0: return [] 146 | results = [] 147 | for mention in mentions.split("|"): 148 | indices, label = mention.split() 149 | indices = [int(i) for i in indices.split(",")] 150 | results.append(Mention.create_mention(indices, label)) 151 | return results 152 | 153 | 154 | def __str__(self): 155 | return self.__repr__() 156 | 157 | 158 | def __repr__(self): 159 | spans = [str(s) for s in self.spans] 160 | return "%s %s" % (",".join(spans), self.label) 161 | 162 | 163 | def parse_parameters(parser=None): 164 | if parser is None: parser = argparse.ArgumentParser() 165 | 166 | ## Required 167 | parser.add_argument("--input_filepath", type=str) 168 | parser.add_argument("--output_filepath", type=str) 169 | 170 | args, _ = parser.parse_known_args() 171 | return args 172 | 173 | 174 | if __name__ == "__main__": 175 | args = parse_parameters() 176 | 177 | num_flat_mentions, num_original_mentions = 0, 0 178 | 179 | with open(args.output_filepath, "w") as out_f: 180 | with open(args.input_filepath) as in_f: 181 | for text in in_f: 182 | out_f.write(text) 183 | mentions = next(in_f).strip() 184 | assert len(next(in_f).strip()) == 0 185 | if len(mentions) > 0: 186 | mentions = Mention.create_mentions(mentions) 187 | num_original_mentions += len(mentions) 188 | disc_removed = Mention.remove_discontinuous_mentions(mentions) 189 | flat_mentions = Mention.merge_overlapping_mentions(disc_removed) 190 | num_flat_mentions += len(flat_mentions) 191 | mentions = "|".join([str(m) for m in flat_mentions]) 192 | out_f.write("%s\n\n" % mentions) 193 | 194 | print("After merging overlapping mentions, there are %d out of %d mentions." 
% (num_flat_mentions, num_original_mentions)) -------------------------------------------------------------------------------- /code/xdai/utils/nn.py: -------------------------------------------------------------------------------- 1 | import itertools, torch 2 | from typing import List 3 | 4 | 5 | '''Update date: 2019-Nov-5''' 6 | def block_orthogonal(tensor, split_sizes: List[int], gain=1.0): 7 | # tensor: the tensor to initialize 8 | # split_sizes: [10, 20] result in the tensor being split into chunks of size 10 along the first dimension 9 | # 20 along the second 10 | # Used in the case of recurrent models which use multiple gates applied to linear projections. 11 | # Separate parameters should be initialized independently 12 | data = tensor.data 13 | sizes = list(tensor.size()) 14 | if any([a % b != 0 for a, b in zip(sizes, split_sizes)]): 15 | raise ValueError("block_orthogonal: tensor size and split sizes not compatible!") 16 | 17 | indexes = [list(range(0, max_size, split)) for max_size, split in zip(sizes, split_sizes)] 18 | for block_state_indices in itertools.product(*indexes): 19 | index_and_step_tuples = zip(block_state_indices, split_sizes) 20 | block_slice = tuple([slice(start_index, start_index + step) for start_index, step in index_and_step_tuples]) 21 | data[block_slice] = torch.nn.init.orthogonal_(tensor[block_slice].contiguous(), gain=gain) 22 | 23 | 24 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py#clamp_tensor 25 | Update date: 2019-Nov-6''' 26 | def clamp_tensor(tensor, minimum, maximum): 27 | if tensor.is_sparse: 28 | coalesced_tensor = tensor.coalesce() 29 | 30 | coalesced_tensor._values().clamp_(minimum, maximum) 31 | return coalesced_tensor 32 | else: 33 | return tensor.clamp(minimum, maximum) 34 | 35 | 36 | '''Update date: 2019-Nov-7''' 37 | def enable_gradient_clipping(model, grad_clipping) -> None: 38 | if grad_clipping is not None: 39 | for parameter in model.parameters(): 40 | if parameter.requires_grad: 41 | parameter.register_hook(lambda grad: clamp_tensor(grad, minimum=-grad_clipping, maximum=grad_clipping)) 42 | 43 | 44 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/modules/highway.py 45 | https://github.com/LiyuanLucasLiu/LM-LSTM-CRF/blob/master/model/highway.py 46 | A gated combination of a linear transformation and a non-linear transformation of its input. 47 | Math: y = g * x + (1 - g) * f(A(x)), 48 | g is an element-wise gate, computed as: sigmoid(B(x)). 49 | A is a linear transformation, f is an element-wise non-linearity 50 | Update date: 2019-Nov-5''' 51 | class Highway(torch.nn.Module): 52 | def __init__(self, input_dim, num_layers=1): 53 | super(Highway, self).__init__() 54 | 55 | self._layers = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)]) 56 | for layer in self._layers: 57 | # Bias the highway layer to just carry its input forward. 58 | # Set the bias on B(x) to be positive, then g will be biased to be high 59 | # The bias on B(x) is the second half of the bias vector in each linear layer. 
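            # Illustration (hypothetical input_dim=3, not in the original code): each Linear maps 3 -> 6 values;
            # forward() chunks that output into a nonlinear part f(A(x)) (first half) and a gate g = sigmoid(B(x))
            # (second half), so bias[input_dim:] below is exactly the bias of B(x).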
60 | layer.bias[input_dim:].data.fill_(1) 61 | 62 | def forward(self, inputs): 63 | current_inputs = inputs 64 | for layer in self._layers: 65 | linear_part = current_inputs 66 | projected_inputs = layer(current_inputs) 67 | 68 | nonlinear_part, gate = projected_inputs.chunk(2, dim=-1) 69 | nonlinear_part = torch.nn.functional.relu(nonlinear_part) 70 | gate = torch.sigmoid(gate) 71 | current_inputs = gate * linear_part + (1 - gate) * nonlinear_part 72 | return current_inputs 73 | 74 | 75 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py 76 | Update date: 2019-April-26''' 77 | def masked_softmax(vector, mask=None): 78 | if mask is None: 79 | return torch.nn.functional.softmax(vector, dim=-1) 80 | else: 81 | mask = mask.float() 82 | assert mask.dim() == vector.dim() 83 | # use a very large negative number for those masked positions 84 | # so that the probabilities of those positions would be approximately 0. 85 | # This is not accurate in math, but works for most cases and consumes less memory. 86 | masked_vector = vector.masked_fill((1 - mask).byte(), -1e32) 87 | return torch.nn.functional.softmax(masked_vector, dim=-1) 88 | 89 | 90 | '''Update date: 2019-Nov-7''' 91 | def rescale_gradients(model, grad_norm): 92 | if grad_norm: 93 | parameters = [p for p in model.parameters() if p.grad is not None] 94 | return sparse_clip_norm(parameters, grad_norm) 95 | return None 96 | 97 | 98 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/modules/scalar_mix.py 99 | Compute a parameterised scalar mixture of N tensors: 100 | outs = gamma * sum(s_k * tensor_k) 101 | s_k = softmax(w) 102 | gamma and w are parameters 103 | Imagine tensor_k are outputs of each layer in ELMo, and outs is its final weighted (s_k) representation. 104 | Update date: 2019-Nov-5''' 105 | class ScalarMix(torch.nn.Module): 106 | def __init__(self, num_tensors, trainable=True): 107 | super(ScalarMix, self).__init__() 108 | self.num_tensors = num_tensors 109 | self.scalar_parameters = torch.nn.ParameterList( 110 | [torch.nn.Parameter(torch.FloatTensor([0.0]), requires_grad=trainable) for _ in range(num_tensors)]) 111 | self.gamma = torch.nn.Parameter(torch.FloatTensor([1.0]), requires_grad=trainable) 112 | 113 | def forward(self, tensors, mask=None): 114 | # tensors must all be the same shape, let's say (batch_size, timesteps, dim) 115 | assert self.num_tensors == len(tensors) 116 | 117 | normed_weights = torch.nn.functional.softmax(torch.cat([p for p in self.scalar_parameters]), dim=0) 118 | normed_weights = torch.split(normed_weights, split_size_or_sections=1) 119 | pieces = [] 120 | for weight, tensor in zip(normed_weights, tensors): 121 | pieces.append(weight * tensor) 122 | return self.gamma * sum(pieces) 123 | 124 | 125 | '''Update date: 2019-Nov-7''' 126 | def sparse_clip_norm(parameters, max_norm: float): 127 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 128 | total_norm = 0 129 | for p in parameters: 130 | if p.grad.is_sparse: 131 | grad = p.grad.data.coalesce() 132 | param_norm = grad._values().norm(2.) 133 | else: 134 | param_norm = p.grad.data.norm(2.) 135 | 136 | total_norm += param_norm ** 2. 137 | 138 | total_norm = total_norm ** (1. / 2.) 
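        # At this point total_norm is the global L2 norm over all gradients (sparse and dense).
        # Hypothetical numbers: with max_norm=5.0 and total_norm=10.0, clip_coef below is ~0.5 and every
        # gradient is halved; when total_norm <= max_norm, clip_coef >= 1 and nothing is rescaled.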
139 | 140 | clip_coef = max_norm / (total_norm + 1e-6) 141 | if clip_coef < 1: 142 | for p in parameters: 143 | if p.grad.is_sparse: 144 | p.grad.data._values().mul_(clip_coef) 145 | else: 146 | p.grad.data.mul_(clip_coef) 147 | return total_norm 148 | 149 | 150 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/modules/time_distributed.py 151 | Given an input shaped like (batch_size, sequence_length, ...) and a Module that takes input like (batch_size, ...) 152 | TimeDistributed can reshape the input to be (batch_size * sequence_length, ...) applies the Module, then reshape back. 153 | Update date: 2019-Nov-5''' 154 | class TimeDistributed(torch.nn.Module): 155 | def __init__(self, module): 156 | super(TimeDistributed, self).__init__() 157 | self._module = module 158 | 159 | def forward(self, *inputs): 160 | reshaped_inputs = [] 161 | 162 | for input_tensor in inputs: 163 | input_size = input_tensor.size() 164 | assert len(input_size) > 2 165 | squashed_shape = [-1] + [x for x in input_size[2:]] 166 | reshaped_inputs.append(input_tensor.contiguous().view(*squashed_shape)) 167 | 168 | reshaped_outputs = self._module(*reshaped_inputs) 169 | 170 | original_shape = [input_size[0], input_size[1]] + [x for x in reshaped_outputs.size()[1:]] 171 | outputs = reshaped_outputs.contiguous().view(*original_shape) 172 | return outputs -------------------------------------------------------------------------------- /code/xdai/utils/token_embedder.py: -------------------------------------------------------------------------------- 1 | import inspect, os, torch 2 | import numpy as np 3 | from typing import Dict 4 | from xdai.utils.nn import TimeDistributed 5 | from xdai.utils.seq2vec import CnnEncoder 6 | from xdai.elmo.models import Elmo 7 | 8 | 9 | '''Update date: 2019-Nov-5''' 10 | class Embedding(torch.nn.Module): 11 | def __init__(self, vocab_size, embedding_dim, weight: torch.FloatTensor = None, trainable=True): 12 | super(Embedding, self).__init__() 13 | self.output_dim = embedding_dim 14 | 15 | if weight is None: 16 | weight = torch.FloatTensor(vocab_size, embedding_dim) 17 | self.weight = torch.nn.Parameter(weight, requires_grad=trainable) 18 | torch.nn.init.xavier_uniform_(self.weight) 19 | else: 20 | assert weight.size() == (vocab_size, embedding_dim) 21 | self.weight = torch.nn.Parameter(weight, requires_grad=trainable) 22 | 23 | 24 | def get_output_dim(self): 25 | return self.output_dim 26 | 27 | 28 | def forward(self, inputs): 29 | outs = torch.nn.functional.embedding(inputs, self.weight) 30 | return outs 31 | 32 | 33 | '''Update date: 2019-Nov-5''' 34 | class TokenCharactersEmbedder(torch.nn.Module): 35 | def __init__(self, embedding: Embedding, encoder, dropout=0.0): 36 | super(TokenCharactersEmbedder, self).__init__() 37 | self._embedding = TimeDistributed(embedding) 38 | self._encoder = TimeDistributed(encoder) 39 | if dropout > 0: 40 | self._dropout = torch.nn.Dropout(p=dropout) 41 | else: 42 | self._dropout = lambda x: x 43 | 44 | 45 | def get_output_dim(self): 46 | return self._encoder._module.get_output_dim() 47 | 48 | 49 | def forward(self, token_characters): 50 | '''token_characters: batch_size, num_tokens, num_characters''' 51 | mask = (token_characters != 0).long() 52 | outs = self._embedding(token_characters) 53 | outs = self._encoder(outs, mask) 54 | outs = self._dropout(outs) 55 | return outs 56 | 57 | 58 | '''A single layer of ELMo representations, essentially a wrapper around ELMo(num_output_representations=1, ...) 
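Its input is the character-id tensor produced by ELMoIndexer (shape: batch_size, num_tokens, 50), and its output is one contextual vector per token, optionally passed through a linear projection (projection_dim).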
59 | Update date: 2019-Nov-5''' 60 | class ElmoTokenEmbedder(torch.nn.Module): 61 | def __init__(self, options_file, weight_file, dropout=0.5, requires_grad=False, 62 | projection_dim=None): 63 | super(ElmoTokenEmbedder, self).__init__() 64 | 65 | self._elmo = Elmo(options_file, weight_file, num_output_representations=1, dropout=dropout, 66 | requires_grad=requires_grad) 67 | if projection_dim: 68 | self._projection = torch.nn.Linear(self._elmo.get_output_dim(), projection_dim) 69 | self.output_dim = projection_dim 70 | else: 71 | self._projection = None 72 | self.output_dim = self._elmo.get_output_dim() 73 | 74 | 75 | def get_output_dim(self): 76 | return self.output_dim 77 | 78 | 79 | def forward(self, inputs): 80 | # inputs: batch_size, num_tokens, 50 81 | elmo_output = self._elmo(inputs) 82 | elmo_representations = elmo_output["elmo_representations"][0] 83 | if self._projection: 84 | projection = self._projection 85 | for _ in range(elmo_representations.dim() - 2): 86 | projection = TimeDistributed(projection) 87 | elmo_representations = projection(elmo_representations) 88 | return elmo_representations 89 | 90 | 91 | '''Update date: 2019-Nov-5''' 92 | def _load_pretrained_embeddings(filepath, dimension, token2idx): 93 | tokens_to_keep = set(token2idx.keys()) 94 | embeddings = {} 95 | if filepath != "" and os.path.isfile(filepath): 96 | with open(filepath, "r", encoding="utf-8") as f: 97 | for line in f: 98 | sp = line.strip().split(" ") 99 | if len(sp) <= dimension: continue 100 | token = sp[0] 101 | if token not in tokens_to_keep: continue 102 | embeddings[token] = np.array([float(x) for x in sp[1:]]) 103 | 104 | print(" # Load %d out of %d words (%d-dimensional) from pretrained embedding file (%s)!" % ( 105 | len(embeddings), len(token2idx), dimension, filepath)) 106 | 107 | all_embeddings = np.asarray(list(embeddings.values())) 108 | embeddings_mean = float(np.mean(all_embeddings)) 109 | embeddings_std = float(np.std(all_embeddings)) 110 | 111 | weights = np.random.normal(embeddings_mean, embeddings_std, size=(len(token2idx), dimension)) 112 | for token, i in token2idx.items(): 113 | if token in embeddings: 114 | weights[i] = embeddings[token] 115 | return weights 116 | 117 | 118 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/modules/text_field_embedders/* 119 | Takes as input the dict produced by TextField and 120 | returns as output an embedded representations of the tokens in that field 121 | Update date: 2019-Nov-5''' 122 | class TextFieldEmbedder(torch.nn.Module): 123 | def __init__(self, token_embedders, embedder_to_indexer_map=None): 124 | super(TextFieldEmbedder, self).__init__() 125 | 126 | self.token_embedders = token_embedders 127 | self._embedder_to_indexer_map = embedder_to_indexer_map 128 | for k, embedder in token_embedders.items(): 129 | self.add_module("token_embedder_%s" % k, embedder) 130 | 131 | 132 | def get_output_dim(self): 133 | return sum([embedder.get_output_dim() for embedder in self.token_embedders.values()]) 134 | 135 | 136 | '''text_field_input is the output of a call to TextField.as_tensor (see instance.py). 
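With the indexers used in this repository, typical keys of this dict are 'tokens', 'token_characters' and, when the ELMo model is enabled, 'elmo_characters'.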
137 | Each tensor in here is assumed to have a shape roughly similar to (batch_size, num_tokens)''' 138 | def forward(self, text_field_input: Dict[str, torch.Tensor], **kwargs): 139 | outs = [] 140 | for k in sorted(self.token_embedders.keys()): 141 | embedder = getattr(self, "token_embedder_%s" % k) 142 | forward_params = inspect.signature(embedder.forward).parameters 143 | forward_params_values = {} 144 | for param in forward_params.keys(): 145 | if param in kwargs: 146 | forward_params_values[param] = kwargs[param] 147 | if self._embedder_to_indexer_map is not None and k in self._embedder_to_indexer_map: 148 | indexer_map = self._embedder_to_indexer_map[k] 149 | assert isinstance(indexer_map, dict) 150 | tensors = {name: text_field_input[argument] for name, argument in indexer_map.items()} 151 | outs.append(embedder(**tensors, **forward_params_values)) 152 | else: 153 | tensors = [text_field_input[k]] 154 | outs.append(embedder(*tensors, **forward_params_values)) 155 | return torch.cat(outs, dim=-1) 156 | 157 | 158 | @classmethod 159 | def tokens_embedder(cls, vocab, args): 160 | token2idx = vocab.get_item_to_index_vocabulary("tokens") 161 | weight = _load_pretrained_embeddings(args.pretrained_word_embeddings, dimension=100, token2idx=token2idx) 162 | return Embedding(len(token2idx), embedding_dim=100, weight=torch.FloatTensor(weight)) 163 | 164 | 165 | @classmethod 166 | def token_characters_embedder(cls, vocab, args): 167 | embedding = Embedding(vocab.get_vocab_size("token_characters"), embedding_dim=16) 168 | return TokenCharactersEmbedder(embedding, CnnEncoder()) 169 | 170 | 171 | @classmethod 172 | def elmo_embedder(cls, vocab, args): 173 | option_file = os.path.join(args.pretrained_model_dir, "options.json") 174 | weight_file = os.path.join(args.pretrained_model_dir, "weights.hdf5") 175 | return ElmoTokenEmbedder(option_file, weight_file) 176 | 177 | 178 | @classmethod 179 | def create_embedder(cls, args, vocab): 180 | embedder_to_indexer_map = {} 181 | embedders = {"tokens": TextFieldEmbedder.tokens_embedder(vocab, args), 182 | "token_characters": TextFieldEmbedder.token_characters_embedder(vocab, args)} 183 | 184 | if args.model_type == "elmo": 185 | embedders["elmo_characters"] = TextFieldEmbedder.elmo_embedder(vocab, args) 186 | return cls(embedders, embedder_to_indexer_map) -------------------------------------------------------------------------------- /code/xdai/utils/token_indexer.py: -------------------------------------------------------------------------------- 1 | import itertools, re 2 | from xdai.utils.common import pad_sequence_to_length 3 | 4 | 5 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/token_indexers/token_indexer.py#TokenIndexer 6 | Update date: 2019-Nov-5''' 7 | class _TokenIndexer: 8 | '''A ``TokenIndexer`` determines how string tokens get represented as arrays of indices in a model.''' 9 | def __init__(self, token_min_padding_length=0): 10 | self._token_min_padding_length = token_min_padding_length 11 | 12 | 13 | def tokens_to_indices(self, tokens, vocabulary, index_name): 14 | '''Take a list of tokens and convert them to one or more sets of indices.''' 15 | raise NotImplementedError 16 | 17 | 18 | def get_padding_token(self): 19 | '''When we need to add padding tokens, what should they look like? 
A 'blank' token.''' 20 | raise NotImplementedError 21 | 22 | 23 | def get_padding_lengths(self, token): 24 | raise NotImplementedError 25 | 26 | 27 | def get_token_min_padding_length(self): 28 | return self._token_min_padding_length 29 | 30 | 31 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/token_indexers/single_id_token_indexer.py 32 | Update date: 2019-Nov-5''' 33 | class SingleIdTokenIndexer(_TokenIndexer): 34 | def __init__(self, lowercase_tokens=True, normalize_digits=False, token_min_padding_length=0): 35 | super().__init__(token_min_padding_length) 36 | self.namespace = "tokens" 37 | self.lowercase_tokens = lowercase_tokens 38 | self.normalize_digits = normalize_digits 39 | 40 | 41 | def count_vocab_items(self, token, counter): 42 | text = token.text 43 | if self.lowercase_tokens: text = text.lower() 44 | if self.normalize_digits: text = re.sub(r"[0-9]", "0", text) 45 | counter[self.namespace][text] += 1 46 | 47 | 48 | def tokens_to_indices(self, tokens, vocabulary, index_name): 49 | indices = [] 50 | for token in tokens: 51 | text = token.text 52 | if self.lowercase_tokens: text = text.lower() 53 | if self.normalize_digits: text = re.sub(r"[0-9]", "0", text) 54 | indices.append(vocabulary.get_item_index(text, self.namespace)) 55 | return {index_name: indices} 56 | 57 | 58 | def get_padding_token(self): 59 | return 0 60 | 61 | 62 | def get_padding_lengths(self, token): 63 | return {} 64 | 65 | 66 | '''tokens: {'tokens': [53, 10365, 9, 53, 15185, 10]} 67 | desired_num_tokens: {'tokens': 11} 68 | return: {'tokens': [53, 10365, 9, 53, 15185, 10, 0, 0, 0, 0, 0]}''' 69 | def pad_token_sequence(self, tokens, desired_num_tokens, padding_lengths=None): 70 | return {k: pad_sequence_to_length(v, desired_num_tokens[k]) for k, v in tokens.items()} 71 | 72 | 73 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/token_indexers/token_characters_indexer.py 74 | Update date: 2019-Nov-5''' 75 | class TokenCharactersIndexer(_TokenIndexer): 76 | def __init__(self, token_min_padding_length=0): 77 | super().__init__(token_min_padding_length) 78 | self._namespace = "token_characters" 79 | # If using CnnEncoder to build character-level representations, 80 | # this value is set to the maximum value of ngram_filter_sizes 81 | self._min_padding_length = 3 82 | 83 | 84 | def count_vocab_items(self, token, counter): 85 | for c in list(token.text): 86 | counter[self._namespace][c] += 1 87 | 88 | 89 | def tokens_to_indices(self, tokens, vocabulary, index_name): 90 | indices = [] 91 | for token in tokens: 92 | token_indices = [] 93 | for c in list(token.text): 94 | index = vocabulary.get_item_index(c, self._namespace) 95 | token_indices.append(index) 96 | indices.append(token_indices) 97 | return {index_name: indices} 98 | 99 | 100 | def get_padding_lengths(self, token): 101 | return {"num_token_characters": max(len(token), self._min_padding_length)} 102 | 103 | 104 | def get_padding_token(self): 105 | return [] 106 | 107 | 108 | '''tokens: {'token_characters': [[45, 8, 6, 4, 6, 9, 12], [52, 3, 4, 3], 109 | [6, 5], [15, 2, 8, 18, 2, 8], [4, 3, 10, 30, 9], [21, 6, 4, 12], 110 | [42, 2, 5, 4, 15, 7, 8, 2], [19]]} 111 | desired_num_tokens: {'token_characters': 10} 112 | padding_lengths: {'tokens_length': 10, 'token_characters_length': 10, 'num_tokens': 10, 'num_token_characters': 12} 113 | return: {'token_characters': [[45, 8, 6, 4, 6, 9, 12, 0, 0, 0, 0, 0], [52, 3, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0], 114 | [6, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [15, 2, 8, 
18, 2, 8, 0, 0, 0, 0, 0, 0], 115 | [4, 3, 10, 30, 9, 0, 0, 0, 0, 0, 0, 0], [21, 6, 4, 12, 0, 0, 0, 0, 0, 0, 0, 0], 116 | [42, 2, 5, 4, 15, 7, 8, 2, 0, 0, 0, 0], [19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 117 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} 118 | ''' 119 | def pad_token_sequence(self, tokens, desired_num_tokens, padding_lengths): 120 | padded_tokens = pad_sequence_to_length(tokens[self._namespace], 121 | desired_length=desired_num_tokens[self._namespace], 122 | default_value=self.get_padding_token) 123 | 124 | desired_token_length = padding_lengths["num_token_characters"] 125 | longest_token_length = max([len(t) for t in tokens[self._namespace]]) 126 | 127 | if desired_token_length > longest_token_length: 128 | padded_tokens.append([0] * desired_token_length) 129 | 130 | padded_tokens = list(zip(*itertools.zip_longest(*padded_tokens, fillvalue=0))) 131 | if desired_token_length > longest_token_length: 132 | padded_tokens.pop() 133 | 134 | return {self._namespace: [list(token[:desired_token_length]) for token in padded_tokens]} 135 | 136 | 137 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/token_indexers/elmo_indexer.py#ELMoCharacterMapper 138 | Update date: 2019-Nov-5''' 139 | class ELMoCharacterMapper: 140 | max_word_length = 50 141 | # 0-255 for utf-8 encoding bytes 142 | beginning_of_sentence_character = 256 143 | end_of_sentence_character = 257 144 | beginning_of_word_character = 258 145 | end_of_word_character = 259 146 | padding_character = 260 147 | 148 | beginning_of_sentence_characters = [padding_character] * max_word_length 149 | beginning_of_sentence_characters[0] = beginning_of_word_character 150 | beginning_of_sentence_characters[1] = beginning_of_sentence_character 151 | beginning_of_sentence_characters[2] = end_of_word_character 152 | 153 | end_of_sentence_characters = [padding_character] * max_word_length 154 | end_of_sentence_characters[0] = beginning_of_word_character 155 | end_of_sentence_characters[1] = end_of_sentence_character 156 | end_of_sentence_characters[2] = end_of_word_character 157 | 158 | bos_token = "" 159 | eos_token = "" 160 | 161 | 162 | @staticmethod 163 | def convert_word_to_char_ids(word): 164 | if word == ELMoCharacterMapper.bos_token: 165 | char_ids = ELMoCharacterMapper.beginning_of_sentence_characters 166 | elif word == ELMoCharacterMapper.eos_token: 167 | char_ids = ELMoCharacterMapper.end_of_sentence_characters 168 | else: 169 | word_encoded = word.encode("utf-8", "ignore")[:(ELMoCharacterMapper.max_word_length - 2)] 170 | char_ids = [ELMoCharacterMapper.padding_character] * ELMoCharacterMapper.max_word_length 171 | char_ids[0] = ELMoCharacterMapper.beginning_of_word_character 172 | for i, v in enumerate(word_encoded, start=1): 173 | char_ids[i] = v 174 | char_ids[len(word_encoded) + 1] = ELMoCharacterMapper.end_of_word_character 175 | return [c + 1 for c in char_ids] # add 1 for masking 176 | 177 | 178 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/token_indexers/elmo_indexer.py#ELMoTokenCharactersIndexer 179 | Update date: 2019-Nov-5''' 180 | class ELMoIndexer(_TokenIndexer): 181 | def __init__(self, token_min_padding_length=0): 182 | super().__init__(token_min_padding_length) 183 | self._namespace = "elmo_characters" 184 | 185 | 186 | def count_vocab_items(self, token, counter): 187 | pass 188 | 189 | 190 | def tokens_to_indices(self, tokens, vocabulary, index_name): 191 | texts = [token.text for token in tokens] 192 | return 
{index_name: [ELMoCharacterMapper.convert_word_to_char_ids(text) for text in texts]} 193 | 194 | 195 | def get_padding_lengths(self, token): 196 | return {} 197 | 198 | 199 | def get_padding_token(self): 200 | return [] 201 | 202 | 203 | @staticmethod 204 | def _default_value_for_padding(): 205 | return [0] * ELMoCharacterMapper.max_word_length 206 | 207 | 208 | def pad_token_sequence(self, tokens, desired_num_tokens, padding_lengths): 209 | return {k: pad_sequence_to_length(v, desired_length=desired_num_tokens[k], 210 | default_value=self._default_value_for_padding) for k, v in tokens.items()} -------------------------------------------------------------------------------- /code/xdai/utils/instance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | from typing import Dict, List 4 | from xdai.utils.token import Token 5 | 6 | 7 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/fields/field.py 8 | Update date: 2019-Nov-5''' 9 | class _Field: 10 | def count_vocab_items(self, counter): 11 | pass 12 | 13 | 14 | def index(self, vocab): 15 | pass 16 | 17 | 18 | def get_padding_lengths(self): 19 | raise NotImplementedError 20 | 21 | 22 | def as_tensor(self, padding_lengths: Dict[str, int]): 23 | raise NotImplementedError 24 | 25 | 26 | def batch_tensors(self, tensor_list): 27 | return torch.stack(tensor_list) 28 | 29 | 30 | def __eq__(self, other): 31 | if isinstance(self, other.__class__): 32 | return self.__dict__ == other.__dict__ 33 | return NotImplemented 34 | 35 | 36 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/fields/metadata_field.py 37 | Update date: 2019-Nov-5''' 38 | class MetadataField(_Field): 39 | def __init__(self, metadata): 40 | self.metadata = metadata 41 | 42 | 43 | def __getitem__(self, key): 44 | try: 45 | return self.metadata[key] 46 | except TypeError: 47 | raise TypeError("Metadata is not a dict.") 48 | 49 | 50 | def __iter__(self): 51 | try: 52 | return iter(self.metadata) 53 | except TypeError: 54 | raise TypeError("Metadata is not iterable.") 55 | 56 | 57 | def __len__(self): 58 | try: 59 | return len(self.metadata) 60 | except TypeError: 61 | raise TypeError("Metadata has no length.") 62 | 63 | 64 | def get_padding_lengths(self): 65 | return {} 66 | 67 | 68 | def as_tensor(self, padding_lengths): 69 | return self.metadata 70 | 71 | 72 | def empty_field(self): 73 | return MetadataField(None) 74 | 75 | 76 | @classmethod 77 | def batch_tensors(cls, tensor_list): 78 | return tensor_list 79 | 80 | 81 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py#batch_tensor_dicts 82 | Update date: 2019-Nov-5''' 83 | def _batch_tensor_dicts(tensor_dicts): 84 | '''takes a list of tensor dictionaries, returns a single dictionary with all tensors with the same key batched''' 85 | key_to_tensors = defaultdict(list) 86 | for tensor_dict in tensor_dicts: 87 | for key, tensor in tensor_dict.items(): 88 | key_to_tensors[key].append(tensor) 89 | 90 | batched_tensors = {} 91 | for key, tensor_list in key_to_tensors.items(): 92 | batched_tensor = torch.stack(tensor_list) 93 | batched_tensors[key] = batched_tensor 94 | return batched_tensors 95 | 96 | 97 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/fields/text_field.py 98 | Update date: 2019-Nov-5''' 99 | class TextField(_Field): 100 | def __init__(self, tokens: List[Token], token_indexers): 101 | self.tokens = 
tokens 102 | self._token_indexers = token_indexers 103 | self._indexed_tokens = None 104 | self._indexer_name_to_indexed_token = None 105 | self._token_index_to_indexer_name = None 106 | 107 | 108 | def __iter__(self): 109 | return iter(self.tokens) 110 | 111 | 112 | def __getitem__(self, idx): 113 | return self.tokens[idx] 114 | 115 | 116 | def __len__(self): 117 | return len(self.tokens) 118 | 119 | 120 | def count_vocab_items(self, counter): 121 | for indexer in self._token_indexers.values(): 122 | for token in self.tokens: 123 | indexer.count_vocab_items(token, counter) 124 | 125 | 126 | def index(self, vocab): 127 | token_arrays = {} 128 | indexer_name_to_indexed_token = {} 129 | token_index_to_indexer_name = {} 130 | for indexer_name, indexer in self._token_indexers.items(): 131 | token_indices = indexer.tokens_to_indices(self.tokens, vocab, indexer_name) 132 | token_arrays.update(token_indices) 133 | indexer_name_to_indexed_token[indexer_name] = list(token_indices.keys()) 134 | for token_index in token_indices: 135 | token_index_to_indexer_name[token_index] = indexer_name 136 | self._indexed_tokens = token_arrays 137 | self._indexer_name_to_indexed_token = indexer_name_to_indexed_token 138 | self._token_index_to_indexer_name = token_index_to_indexer_name 139 | 140 | 141 | def get_padding_lengths(self): 142 | lengths = [] 143 | assert self._indexed_tokens is not None, "Call .index(vocabulary) before determining padding lengths." 144 | for indexer_name, indexer in self._token_indexers.items(): 145 | indexer_lengths = {} 146 | for indexed_tokens_key in self._indexer_name_to_indexed_token[indexer_name]: 147 | token_lengths = [indexer.get_padding_lengths(token) for token in self._indexed_tokens[indexed_tokens_key]] 148 | if not token_lengths: 149 | token_lengths = [indexer.get_padding_lengths([])] 150 | for key in token_lengths[0]: 151 | indexer_lengths[key] = max(x[key] if key in x else 0 for x in token_lengths) 152 | lengths.append(indexer_lengths) 153 | 154 | padding_lengths = {} 155 | num_tokens = set() 156 | for token_index, token_list in self._indexed_tokens.items(): 157 | indexer_name = self._token_index_to_indexer_name[token_index] 158 | indexer = self._token_indexers[indexer_name] 159 | padding_lengths[f"{token_index}_length"] = max(len(token_list), indexer.get_token_min_padding_length()) 160 | num_tokens.add(len(token_list)) 161 | padding_lengths["num_tokens"] = max(num_tokens) 162 | 163 | padding_keys = {key for d in lengths for key in d.keys()} 164 | for padding_key in padding_keys: 165 | padding_lengths[padding_key] = max(x[padding_key] if padding_key in x else 0 for x in lengths) 166 | return padding_lengths 167 | 168 | 169 | def sequence_length(self): 170 | return len(self.tokens) 171 | 172 | 173 | def as_tensor(self, padding_lengths): 174 | tensors = {} 175 | for indexer_name, indexer in self._token_indexers.items(): 176 | desired_num_tokens = {indexed_tokens_key: padding_lengths[f"{indexed_tokens_key}_length"] for 177 | indexed_tokens_key in self._indexer_name_to_indexed_token[indexer_name]} 178 | indices_to_pad = {indexed_tokens_key: self._indexed_tokens[indexed_tokens_key] for indexed_tokens_key in 179 | self._indexer_name_to_indexed_token[indexer_name]} 180 | padded_array = indexer.pad_token_sequence(indices_to_pad, desired_num_tokens, padding_lengths) 181 | indexer_tensors = {key: torch.LongTensor(array) for key, array in padded_array.items()} 182 | tensors.update(indexer_tensors) 183 | return tensors 184 | 185 | 186 | def empty_field(self): 187 | text_field = 
TextField([], self._token_indexers) 188 | text_field._indexed_tokens = {} 189 | text_field._indexer_name_to_indexed_token = {} 190 | for indexer_name, indexer in self._token_indexers.items(): 191 | array_keys = indexer.get_keys(indexer_name) 192 | for key in array_keys: 193 | text_field._indexed_tokens[key] = [] 194 | text_field._indexer_name_to_indexed_token[indexer_name] = array_keys 195 | return text_field 196 | 197 | 198 | def batch_tensors(self, tensor_dicts): 199 | return _batch_tensor_dicts(tensor_dicts) 200 | 201 | 202 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py#get_text_field_mask''' 203 | @classmethod 204 | def get_text_field_mask(cls, text_field_tensors: Dict[str, torch.Tensor]): 205 | if "mask" in text_field_tensors: 206 | return text_field_tensors["mask"] 207 | 208 | tensor_dims = [(tensor.dim(), tensor) for tensor in text_field_tensors.values()] 209 | tensor_dims.sort(key=lambda x: x[0]) 210 | 211 | assert tensor_dims[0][0] == 2 212 | 213 | token_tensor = tensor_dims[0][1] 214 | return (token_tensor != 0).long() 215 | 216 | 217 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/instance.py 218 | An instance is a collection of Field objects, specifying the inputs and outputs to the model. 219 | Update date: 2019-Nov-5''' 220 | class Instance: 221 | def __init__(self, fields): 222 | self.fields = fields 223 | self.indexed = False 224 | 225 | 226 | def __getitem__(self, key): 227 | return self.fields[key] 228 | 229 | 230 | def __iter__(self): 231 | return iter(self.fields) 232 | 233 | 234 | def __len__(self): 235 | return len(self.fields) 236 | 237 | 238 | def add_field(self, field_name, field, vocab): 239 | self.fields[field_name] = field 240 | if self.indexed: 241 | field.index(vocab) 242 | 243 | 244 | def count_vocab_items(self, counter): 245 | for field in self.fields.values(): 246 | field.count_vocab_items(counter) 247 | 248 | 249 | def index_fields(self, vocab): 250 | if not self.indexed: 251 | self.indexed = True 252 | for field in self.fields.values(): 253 | field.index(vocab) 254 | 255 | 256 | def get_padding_lengths(self): 257 | lengths = {} 258 | for field_name, field in self.fields.items(): 259 | lengths[field_name] = field.get_padding_lengths() 260 | return lengths 261 | 262 | 263 | def as_tensor_dict(self, padding_lengths): 264 | padding_lengths = padding_lengths or self.get_padding_lengths() 265 | tensors = {} 266 | for field_name, field in self.fields.items(): 267 | tensors[field_name] = field.as_tensor(padding_lengths[field_name]) 268 | return tensors 269 | -------------------------------------------------------------------------------- /code/xdai/utils/iterator.py: -------------------------------------------------------------------------------- 1 | import itertools, math, random 2 | from collections import defaultdict 3 | from typing import cast, Dict, Iterable, List, Tuple 4 | 5 | 6 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/common/util.py#ensure_list 7 | Update date: 2019-Nov-18''' 8 | def ensure_list(iterable): 9 | return iterable if isinstance(iterable, list) else list(iterable) 10 | 11 | 12 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/dataset.py 13 | Update date: 2019-Nov-18''' 14 | class Batch(Iterable): 15 | def __init__(self, instances): 16 | super().__init__() 17 | self.instances = ensure_list(instances) 18 | 19 | 20 | '''return: {'tokens': {'tokens_length': 45, 'token_characters_length': 45, 
'elmo_characters_length': 45, 21 | 'num_token_characters': 15}, 'tags': {'num_tokens': 45}})''' 22 | def get_padding_lengths(self): 23 | padding_lengths: Dict[str, Dict[str, int]] = defaultdict(dict) 24 | all_instance_lengths: List[Dict[str, Dict[str, int]]] = [i.get_padding_lengths() for i in self.instances] 25 | all_field_lengths: Dict[str, List[Dict[str, int]]] = defaultdict(list) 26 | for instance_lengths in all_instance_lengths: 27 | for field_name, instance_field_lengths in instance_lengths.items(): 28 | all_field_lengths[field_name].append(instance_field_lengths) 29 | 30 | for field_name, field_lengths in all_field_lengths.items(): 31 | for padding_key in field_lengths[0].keys(): 32 | max_value = max(x.get(padding_key, 0) for x in field_lengths) 33 | padding_lengths[field_name][padding_key] = max_value 34 | return padding_lengths 35 | 36 | 37 | def as_tensor_dict(self, padding_lengths: Dict[str, Dict[str, int]] = None): 38 | if padding_lengths is None: padding_lengths = defaultdict(dict) 39 | instance_padding_lengths = self.get_padding_lengths() 40 | lengths_to_use = defaultdict(dict) 41 | for field_name, instance_field_lengths in instance_padding_lengths.items(): 42 | for padding_key in instance_field_lengths.keys(): 43 | if padding_key in padding_lengths[field_name]: 44 | lengths_to_use[field_name][padding_key] = padding_lengths[field_name][padding_key] 45 | else: 46 | lengths_to_use[field_name][padding_key] = instance_field_lengths[padding_key] 47 | 48 | 49 | field_tensors: Dict[str, list] = defaultdict(list) 50 | for instance in self.instances: 51 | for field, tensors in instance.as_tensor_dict(lengths_to_use).items(): 52 | field_tensors[field].append(tensors) 53 | 54 | field_classes = self.instances[0].fields 55 | final_fields = {} 56 | for field_name, field_tensor_list in field_tensors.items(): 57 | final_fields[field_name] = field_classes[field_name].batch_tensors(field_tensor_list) 58 | return final_fields 59 | 60 | 61 | def __iter__(self): 62 | return iter(self.instances) 63 | 64 | 65 | def index_instances(self, vocab): 66 | for instance in self.instances: 67 | instance.index_fields(vocab) 68 | 69 | 70 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/data_iterator.py 71 | Update date: 2019-Nov-18''' 72 | class _Iterator: 73 | def __init__(self, batch_size=64, max_instances_in_memory=None, cache_instances=False): 74 | self.vocab = None 75 | self._batch_size = batch_size 76 | self._max_instances_in_memory = max_instances_in_memory 77 | self._cache_instances = cache_instances 78 | self._cache = defaultdict(list) 79 | 80 | 81 | def __call__(self, instances, shuffle=True): 82 | key = id(instances) 83 | 84 | if self._cache_instances and key in self._cache: 85 | tensor_dicts = self._cache[key] 86 | if shuffle: 87 | random.shuffle(tensor_dicts) 88 | for tensor_dict in tensor_dicts: 89 | yield tensor_dict 90 | else: 91 | batches = self._create_batches(instances, shuffle) 92 | add_to_cache = self._cache_instances and key not in self._cache 93 | for batch in batches: 94 | if self.vocab is not None: 95 | batch.index_instances(self.vocab) 96 | padding_lengths = batch.get_padding_lengths() 97 | tensor_dict = batch.as_tensor_dict(padding_lengths) 98 | if add_to_cache: 99 | self._cache[key].append(tensor_dict) 100 | yield tensor_dict 101 | 102 | 103 | def _take_instances(self, instances): 104 | yield from iter(instances) 105 | 106 | 107 | def _memory_sized_lists(self, instances): 108 | lazy = not isinstance(instances, list) 109 | iterator = 
self._take_instances(instances) 110 | if lazy and self._max_instances_in_memory is None: 111 | yield from _Iterator.lazy_groups_of(iterator, self._batch_size) 112 | elif self._max_instances_in_memory is not None: 113 | yield from _Iterator.lazy_groups_of(iterator, self._max_instances_in_memory) 114 | else: 115 | yield ensure_list(instances) 116 | 117 | 118 | def _ensure_batch_is_sufficiently_small(self, batch_instances): 119 | return [list(batch_instances)] 120 | 121 | 122 | def get_num_batches(self, instances): 123 | if not isinstance(instances, list): 124 | return 1 125 | return math.ceil(len(ensure_list(instances)) / self._batch_size) 126 | 127 | 128 | def index_with(self, vocab): 129 | self.vocab = vocab 130 | 131 | 132 | @classmethod 133 | def lazy_group_of(cls, iterator, group_size): 134 | return iter(lambda: list(itertools.islice(iterator, 0, group_size)), []) 135 | 136 | 137 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/basic_iterator.py 138 | Update date: 2019-Nov-19''' 139 | class BasicIterator(_Iterator): 140 | def _create_batches(self, instances, shuffle=True): 141 | for instance_list in self._memory_sized_lists(instances): 142 | if shuffle: 143 | random.shuffle(instance_list) 144 | iterator = iter(instance_list) 145 | for batch_instances in _Iterator.lazy_group_of(iterator, self._batch_size): 146 | for possibly_smaller_batches in self._ensure_batch_is_sufficiently_small(batch_instances): 147 | yield Batch(possibly_smaller_batches) 148 | 149 | 150 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/common/util.py#add_noise_to_dict_values 151 | Update date: 2019-Nov-19''' 152 | def _add_noise_to_dict_values(dictionary, noise_param): 153 | dict_with_noise = {} 154 | for key, value in dictionary.items(): 155 | noise_value = value * noise_param 156 | noise = random.uniform(-noise_value, noise_value) 157 | dict_with_noise[key] = value + noise 158 | return dict_with_noise 159 | 160 | 161 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py#sort_by_padding 162 | Update date: 2019-Nov-19''' 163 | def _sort_by_padding(instances, sorting_keys: List[Tuple[str, str]], vocab, padding_noise=0.1): 164 | instances_with_lengths = [] 165 | for instance in instances: 166 | instance.index_fields(vocab) 167 | padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths()) 168 | 169 | if padding_noise > 0.0: 170 | noisy_lengths = {} 171 | for field_name, field_lengths in padding_lengths.items(): 172 | noisy_lengths[field_name] = _add_noise_to_dict_values(field_lengths, padding_noise) 173 | padding_lengths = noisy_lengths 174 | 175 | instance_with_lengths = ([padding_lengths[field_name][padding_key] 176 | for (field_name, padding_key) in sorting_keys], instance) 177 | instances_with_lengths.append(instance_with_lengths) 178 | instances_with_lengths.sort(key=lambda x: x[0]) 179 | sorted_instances = [instance_with_lengths[-1] for instance_with_lengths in instances_with_lengths] 180 | return sorted_instances 181 | 182 | 183 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py 184 | Update date: 2019-Nov-19''' 185 | class BucketIterator(_Iterator): 186 | def __init__(self, sorting_keys: List[Tuple[str, str]], padding_noise=0.1, batch_size=64, biggest_batch_first=False, 187 | max_instances_in_memory=None, cache_instances=False): 188 | super(BucketIterator, self).__init__(batch_size=batch_size, 
max_instances_in_memory=max_instances_in_memory, 189 | cache_instances=cache_instances) 190 | self._sorting_keys = sorting_keys 191 | self._padding_noise = padding_noise 192 | self._biggest_batch_first = biggest_batch_first 193 | 194 | 195 | def _create_batches(self, instances, shuffle=True): 196 | for instance_list in self._memory_sized_lists(instances): 197 | instance_list = _sort_by_padding(instance_list, self._sorting_keys, self.vocab, self._padding_noise) 198 | batches = [] 199 | for batch_instances in _Iterator.lazy_group_of(iter(instance_list), self._batch_size): 200 | for possibly_smaller_batches in self._ensure_batch_is_sufficiently_small(batch_instances): 201 | batches.append(Batch(possibly_smaller_batches)) 202 | 203 | move_to_front = self._biggest_batch_first and len(batches) > 1 204 | if move_to_front: 205 | last_batch = batches.pop() 206 | penultimate_batch = batches.pop() 207 | if shuffle: 208 | random.shuffle(batches) 209 | if move_to_front: 210 | batches.insert(0, penultimate_batch) 211 | batches.insert(0, last_batch) 212 | 213 | yield from batches -------------------------------------------------------------------------------- /code/xdai/utils/train.py: -------------------------------------------------------------------------------- 1 | import logging, os, re, shutil, torch 2 | from tqdm import tqdm 3 | from typing import List 4 | from xdai.utils.common import move_to_gpu 5 | from xdai.utils.nn import enable_gradient_clipping, rescale_gradients 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | '''Update date: 2019-Nov-6''' 12 | class MetricTracker: 13 | def __init__(self, should_decrease, patience=None): 14 | self._best_so_far = None 15 | self._patience = patience 16 | self._epochs_with_no_improvement = 0 17 | self._is_best_so_far = True 18 | self.best_epoch_metrics = {} 19 | self._epoch_number = 0 20 | self.best_epoch = None 21 | self._should_decrease = should_decrease 22 | 23 | 24 | def clear(self) -> None: 25 | self._best_so_far = None 26 | self._epochs_with_no_improvement = 0 27 | self._is_best_so_far = True 28 | self._epoch_number = 0 29 | self.best_epoch = None 30 | 31 | 32 | def state_dict(self): 33 | return { 34 | "best_so_far": self._best_so_far, 35 | "patience": self._patience, 36 | "epochs_with_no_improvement": self._epochs_with_no_improvement, 37 | "is_best_so_far": self._is_best_so_far, 38 | "should_decrease": self._should_decrease, 39 | "best_epoch_metrics": self.best_epoch_metrics, 40 | "epoch_number": self._epoch_number, 41 | "best_epoch": self.best_epoch, 42 | } 43 | 44 | 45 | def load_state_dict(self, state_dict) -> None: 46 | self._best_so_far = state_dict["best_so_far"] 47 | self._patience = state_dict["patience"] 48 | self._epochs_with_no_improvement = state_dict["epochs_with_no_improvement"] 49 | self._is_best_so_far = state_dict["is_best_so_far"] 50 | self._should_decrease = state_dict["should_decrease"] 51 | self.best_epoch_metrics = state_dict["best_epoch_metrics"] 52 | self._epoch_number = state_dict["epoch_number"] 53 | self.best_epoch = state_dict["best_epoch"] 54 | 55 | 56 | def add_metric(self, metric): 57 | if self._best_so_far is None: 58 | new_best = True 59 | else: 60 | if self._should_decrease: 61 | if metric < self._best_so_far: 62 | new_best = True 63 | else: 64 | if metric > self._best_so_far: 65 | new_best = True 66 | 67 | if new_best: 68 | self.best_epoch = self._epoch_number 69 | self._is_best_so_far = True 70 | self._best_so_far = metric 71 | self._epochs_with_no_improvement = 0 72 | else: 73 | self._is_best_so_far = 
False 74 | self._epochs_with_no_improvement += 1 75 | self._epoch_number += 1 76 | 77 | 78 | def is_best_so_far(self) -> bool: 79 | return self._is_best_so_far 80 | 81 | 82 | def should_stop_early(self) -> bool: 83 | if self._patience is None: 84 | return False 85 | else: 86 | return self._epochs_with_no_improvement >= self._patience 87 | 88 | 89 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/training/trainer.py (batch_loss) 90 | Update date: 2019-March-03''' 91 | def _batch_loss(args, model, batch): 92 | batch = move_to_gpu(batch, cuda_device=args.cuda_device[0]) 93 | output_dict = model(**batch) 94 | loss = output_dict.get("loss") 95 | return loss 96 | 97 | 98 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/training/trainer.py (_validation_loss) 99 | Update date: 2019-April-20''' 100 | def _get_val_loss(args, model, iterator, data): 101 | model.eval() 102 | generator = iterator(data, shuffle=False) 103 | total_loss, batch_counter = 0.0, 0 104 | for batch in generator: 105 | batch_counter += 1 106 | _loss = _batch_loss(args, model, batch) 107 | if isinstance(_loss, float): 108 | total_loss += _loss 109 | else: 110 | total_loss += _loss.item() 111 | loss = float(total_loss / batch_counter) if batch_counter > 0 else 0.0 112 | return loss 113 | 114 | 115 | '''Update date: 2019-April-20''' 116 | def _is_best_model_so_far(this_epoch_score: float, score_per_epoch: List[float]): 117 | if not score_per_epoch: 118 | return True 119 | else: 120 | return this_epoch_score > max(score_per_epoch) 121 | 122 | 123 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/training/trainer.py 124 | Update date: 2019-April-20''' 125 | def _output_metrics_to_console(train_metrics, dev_metrics={}): 126 | metric_names = list(train_metrics.keys()) + list(dev_metrics.keys()) 127 | metric_names = list(set(metric_names)) 128 | train_metrics = ["%s: %s" % (k, str(train_metrics.get(k, 0))) for k in metric_names] 129 | logger.info(" # Train set \n %s" % ("; ".join(train_metrics))) 130 | dev_metrics = ["%s: %s" % (k, str(dev_metrics.get(k, 0))) for k in metric_names] 131 | logger.info(" # Dev set \n %s" % ("; ".join(dev_metrics))) 132 | 133 | 134 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/training/trainer.py#_save_checkpoint 135 | Update date: 2019-Nov-9''' 136 | def _save_checkpoint(model_dir, model, epoch, is_best=False): 137 | model_path = os.path.join(model_dir, "epoch_%s.th" % epoch) 138 | torch.save(model.state_dict(), model_path) 139 | if is_best: 140 | logger.info(" # Best dev performance so far. 
Copying weights to %s/best.th" % model_dir) 141 | shutil.copyfile(model_path, os.path.join(model_dir, "best.th")) 142 | 143 | 144 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/training/trainer.py 145 | Update date: 2019-April-20''' 146 | def _should_early_stop(score_per_epoch: List[float], patience=0): 147 | if patience > 0 and patience < len(score_per_epoch): 148 | return max(score_per_epoch[-patience:]) <= max(score_per_epoch[:-patience]) 149 | return False 150 | 151 | 152 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/training/trainer.py#_train_epoch 153 | Update date: 2019-Nov-9''' 154 | def _train_epoch(args, model, optimizer, iterator, data, shuffle=True): 155 | model.train() 156 | total_loss = 0.0 157 | generator = iterator(data, shuffle=shuffle) 158 | num_batches = iterator.get_num_batches(data) 159 | batch_counter = 0 160 | 161 | for batch in generator: 162 | batch_counter += 1 163 | optimizer.zero_grad() 164 | loss = _batch_loss(args, model, batch) 165 | loss.backward() 166 | total_loss += loss.item() 167 | rescale_gradients(model, args.max_grad_norm) 168 | optimizer.step() 169 | 170 | metrics = model.get_metrics(reset=False) 171 | metrics["loss"] = float(total_loss / batch_counter) if batch_counter > 0 else 0.0 172 | 173 | if batch_counter % args.logging_steps == 0 or batch_counter == num_batches: 174 | logger.info("%d out of %d batches, loss: %.3f" % (batch_counter, num_batches, metrics["loss"])) 175 | 176 | metrics = model.get_metrics(reset=True) 177 | metrics["loss"] = float(total_loss / batch_counter) if batch_counter > 0 else 0.0 178 | return metrics 179 | 180 | 181 | '''Update date: 2019-Nov-9''' 182 | def _check_max_save_checkpoints(output_dir, max_save_checkpoints, pattern=("epoch_", ".th")): 183 | if max_save_checkpoints < 0: return None 184 | checkpoints = [f for f in os.listdir(output_dir) if f.startswith(pattern[0]) and f.endswith(pattern[1])] 185 | if len(checkpoints) > max_save_checkpoints: 186 | numbers = sorted([int(re.findall("\d+", filename)[0]) for filename in checkpoints], reverse=True) 187 | for n in numbers[max_save_checkpoints:]: 188 | os.remove(os.path.join(output_dir, "%s%d%s" % (pattern[0], n, pattern[1]))) 189 | 190 | 191 | '''Reference url: https://github.com/allenai/allennlp/blob/master/allennlp/training/trainer.py#train 192 | Update date: 2019-Nov-9''' 193 | def train_op(args, model, optimizer, train_data, train_iterator, dev_data, dev_iterator): 194 | enable_gradient_clipping(model, args.grad_clipping) 195 | model_dir = args.output_dir 196 | max_epoches = args.num_train_epochs 197 | patience = args.patience 198 | validation_metric = args.eval_metric 199 | 200 | validation_metric_per_epoch = [] 201 | metrics = {} 202 | 203 | for epoch in range(0, max_epoches): 204 | logger.info("Epoch %d/%d" % (epoch + 1, max_epoches)) 205 | train_metrics = _train_epoch(args, model, optimizer, train_iterator, train_data) 206 | with torch.no_grad(): 207 | val_loss = _get_val_loss(args, model, dev_iterator, dev_data) 208 | val_metrics = model.get_metrics(reset=True) 209 | val_metrics["loss"] = val_loss 210 | this_epoch_val_metric = val_metrics[validation_metric] 211 | is_best = _is_best_model_so_far(this_epoch_val_metric, validation_metric_per_epoch) 212 | validation_metric_per_epoch.append(this_epoch_val_metric) 213 | 214 | _output_metrics_to_console(train_metrics, val_metrics) 215 | 216 | metrics["epoch"] = epoch 217 | for k, v in train_metrics.items(): 218 | metrics["training_" + k] = v 219 | for k, v in 
val_metrics.items(): 220 | metrics["validation_" + k] = v 221 | 222 | if is_best: 223 | metrics["best_epoch"] = epoch 224 | for k, v in val_metrics.items(): 225 | metrics["best_validation_" + k] = v 226 | 227 | _save_checkpoint(model_dir, model, epoch, is_best) 228 | _check_max_save_checkpoints(args.output_dir, args.max_save_checkpoints) 229 | 230 | if _should_early_stop(validation_metric_per_epoch, patience): 231 | logger.info(" # Ran out of patience. Stopping training.") 232 | break 233 | return metrics 234 | 235 | 236 | '''Update date: 2019-Nov-7''' 237 | def eval_op(args, model, data, data_iterator): 238 | sentences, predictions = [], [] 239 | with torch.no_grad(): 240 | model.eval() 241 | generator = data_iterator(data, shuffle=False) 242 | total_loss = 0.0 243 | for batch in tqdm(generator, desc="Evaluating"): 244 | sentences += batch["sentence"] 245 | batch = move_to_gpu(batch, args.cuda_device[0]) 246 | output_dict = model(**batch) 247 | loss = output_dict.get("loss", None) 248 | predictions += output_dict.get("preds") 249 | if loss: 250 | total_loss += loss.item() 251 | final_metrics = model.get_metrics(reset=True) 252 | final_metrics["loss"] = total_loss 253 | outputs = [(s, p) for s, p in zip(sentences, predictions)] 254 | return final_metrics, outputs -------------------------------------------------------------------------------- /data/cadec/split/train.id: -------------------------------------------------------------------------------- 1 | LIPITOR.774 2 | LIPITOR.960 3 | LIPITOR.896 4 | VOLTAREN.33 5 | LIPITOR.845 6 | LIPITOR.965 7 | LIPITOR.243 8 | LIPITOR.878 9 | LIPITOR.473 10 | LIPITOR.691 11 | LIPITOR.477 12 | LIPITOR.226 13 | LIPITOR.670 14 | LIPITOR.444 15 | LIPITOR.614 16 | LIPITOR.696 17 | LIPITOR.233 18 | LIPITOR.652 19 | LIPITOR.705 20 | ARTHROTEC.42 21 | DICLOFENAC-POTASSIUM.2 22 | LIPITOR.563 23 | LIPITOR.884 24 | LIPITOR.387 25 | LIPITOR.880 26 | LIPITOR.574 27 | LIPITOR.404 28 | DICLOFENAC-SODIUM.5 29 | LIPITOR.208 30 | LIPITOR.798 31 | LIPITOR.731 32 | LIPITOR.478 33 | ARTHROTEC.119 34 | LIPITOR.872 35 | LIPITOR.649 36 | LIPITOR.440 37 | LIPITOR.63 38 | LIPITOR.771 39 | LIPITOR.458 40 | LIPITOR.981 41 | CATAFLAM.3 42 | LIPITOR.966 43 | LIPITOR.389 44 | LIPITOR.719 45 | LIPITOR.863 46 | LIPITOR.592 47 | LIPITOR.332 48 | LIPITOR.386 49 | LIPITOR.304 50 | LIPITOR.848 51 | LIPITOR.599 52 | LIPITOR.631 53 | ARTHROTEC.62 54 | LIPITOR.919 55 | LIPITOR.138 56 | LIPITOR.227 57 | LIPITOR.139 58 | LIPITOR.112 59 | LIPITOR.270 60 | LIPITOR.210 61 | LIPITOR.388 62 | LIPITOR.596 63 | LIPITOR.155 64 | LIPITOR.339 65 | LIPITOR.984 66 | LIPITOR.766 67 | ARTHROTEC.72 68 | LIPITOR.542 69 | LIPITOR.132 70 | LIPITOR.929 71 | LIPITOR.488 72 | LIPITOR.946 73 | LIPITOR.151 74 | LIPITOR.117 75 | LIPITOR.395 76 | ARTHROTEC.71 77 | LIPITOR.531 78 | LIPITOR.289 79 | LIPITOR.706 80 | LIPITOR.104 81 | LIPITOR.947 82 | LIPITOR.219 83 | LIPITOR.865 84 | ARTHROTEC.89 85 | LIPITOR.551 86 | VOLTAREN-XR.17 87 | VOLTAREN-XR.13 88 | LIPITOR.816 89 | ARTHROTEC.69 90 | ARTHROTEC.107 91 | LIPITOR.957 92 | LIPITOR.626 93 | LIPITOR.751 94 | CATAFLAM.8 95 | LIPITOR.366 96 | LIPITOR.377 97 | LIPITOR.218 98 | LIPITOR.595 99 | LIPITOR.27 100 | LIPITOR.317 101 | VOLTAREN.32 102 | LIPITOR.630 103 | LIPITOR.943 104 | LIPITOR.824 105 | LIPITOR.341 106 | LIPITOR.95 107 | LIPITOR.550 108 | LIPITOR.718 109 | ARTHROTEC.76 110 | LIPITOR.616 111 | LIPITOR.422 112 | LIPITOR.861 113 | VOLTAREN.37 114 | LIPITOR.207 115 | LIPITOR.166 116 | ARTHROTEC.44 117 | ARTHROTEC.100 118 | LIPITOR.669 119 | ARTHROTEC.139 120 | 
LIPITOR.508 121 | LIPITOR.873 122 | LIPITOR.789 123 | LIPITOR.784 124 | LIPITOR.990 125 | LIPITOR.623 126 | LIPITOR.646 127 | LIPITOR.875 128 | ARTHROTEC.38 129 | LIPITOR.911 130 | LIPITOR.932 131 | LIPITOR.701 132 | LIPITOR.399 133 | LIPITOR.470 134 | LIPITOR.485 135 | LIPITOR.296 136 | LIPITOR.689 137 | LIPITOR.191 138 | LIPITOR.898 139 | LIPITOR.575 140 | LIPITOR.811 141 | LIPITOR.498 142 | LIPITOR.123 143 | LIPITOR.726 144 | LIPITOR.383 145 | LIPITOR.13 146 | VOLTAREN.31 147 | LIPITOR.262 148 | LIPITOR.615 149 | VOLTAREN-XR.3 150 | LIPITOR.413 151 | ARTHROTEC.24 152 | LIPITOR.435 153 | LIPITOR.62 154 | ARTHROTEC.51 155 | LIPITOR.968 156 | LIPITOR.500 157 | LIPITOR.204 158 | LIPITOR.619 159 | LIPITOR.187 160 | LIPITOR.371 161 | ARTHROTEC.114 162 | ARTHROTEC.85 163 | VOLTAREN.25 164 | ARTHROTEC.97 165 | LIPITOR.518 166 | ARTHROTEC.144 167 | LIPITOR.343 168 | ARTHROTEC.66 169 | LIPITOR.369 170 | CATAFLAM.2 171 | ARTHROTEC.133 172 | LIPITOR.287 173 | LIPITOR.220 174 | LIPITOR.9 175 | LIPITOR.30 176 | LIPITOR.735 177 | ARTHROTEC.80 178 | PENNSAID.4 179 | LIPITOR.801 180 | VOLTAREN-XR.11 181 | LIPITOR.668 182 | LIPITOR.143 183 | LIPITOR.430 184 | LIPITOR.988 185 | LIPITOR.436 186 | LIPITOR.247 187 | LIPITOR.888 188 | VOLTAREN.5 189 | LIPITOR.29 190 | LIPITOR.921 191 | LIPITOR.225 192 | LIPITOR.753 193 | LIPITOR.996 194 | LIPITOR.469 195 | LIPITOR.502 196 | ARTHROTEC.127 197 | VOLTAREN.14 198 | LIPITOR.683 199 | LIPITOR.522 200 | LIPITOR.438 201 | LIPITOR.199 202 | LIPITOR.313 203 | LIPITOR.648 204 | LIPITOR.680 205 | LIPITOR.826 206 | LIPITOR.237 207 | DICLOFENAC-POTASSIUM.3 208 | ARTHROTEC.9 209 | LIPITOR.953 210 | LIPITOR.822 211 | VOLTAREN.23 212 | LIPITOR.945 213 | CATAFLAM.5 214 | LIPITOR.887 215 | LIPITOR.698 216 | LIPITOR.102 217 | VOLTAREN-XR.9 218 | LIPITOR.985 219 | LIPITOR.434 220 | LIPITOR.4 221 | LIPITOR.830 222 | VOLTAREN.41 223 | LIPITOR.611 224 | LIPITOR.453 225 | LIPITOR.729 226 | LIPITOR.336 227 | LIPITOR.167 228 | VOLTAREN.10 229 | LIPITOR.964 230 | LIPITOR.384 231 | ARTHROTEC.129 232 | ARTHROTEC.60 233 | LIPITOR.61 234 | LIPITOR.999 235 | LIPITOR.515 236 | LIPITOR.663 237 | LIPITOR.770 238 | LIPITOR.659 239 | LIPITOR.290 240 | LIPITOR.255 241 | LIPITOR.867 242 | LIPITOR.900 243 | LIPITOR.935 244 | LIPITOR.519 245 | ARTHROTEC.18 246 | LIPITOR.322 247 | LIPITOR.172 248 | LIPITOR.278 249 | LIPITOR.969 250 | ARTHROTEC.59 251 | LIPITOR.820 252 | LIPITOR.897 253 | LIPITOR.3 254 | LIPITOR.419 255 | LIPITOR.702 256 | ARTHROTEC.17 257 | LIPITOR.35 258 | LIPITOR.193 259 | LIPITOR.973 260 | LIPITOR.357 261 | ARTHROTEC.94 262 | LIPITOR.955 263 | LIPITOR.385 264 | LIPITOR.618 265 | LIPITOR.405 266 | LIPITOR.169 267 | LIPITOR.724 268 | LIPITOR.812 269 | LIPITOR.202 270 | LIPITOR.128 271 | VOLTAREN.3 272 | LIPITOR.598 273 | LIPITOR.254 274 | VOLTAREN-XR.7 275 | LIPITOR.940 276 | VOLTAREN.39 277 | LIPITOR.293 278 | LIPITOR.249 279 | LIPITOR.69 280 | LIPITOR.40 281 | LIPITOR.536 282 | LIPITOR.288 283 | LIPITOR.539 284 | ARTHROTEC.86 285 | LIPITOR.337 286 | CATAFLAM.10 287 | LIPITOR.892 288 | LIPITOR.925 289 | LIPITOR.825 290 | LIPITOR.971 291 | VOLTAREN-XR.16 292 | SOLARAZE.1 293 | LIPITOR.294 294 | LIPITOR.756 295 | LIPITOR.118 296 | LIPITOR.189 297 | LIPITOR.203 298 | ARTHROTEC.6 299 | LIPITOR.565 300 | LIPITOR.501 301 | LIPITOR.709 302 | LIPITOR.443 303 | LIPITOR.244 304 | VOLTAREN.40 305 | LIPITOR.134 306 | LIPITOR.503 307 | ZIPSOR.3 308 | LIPITOR.423 309 | LIPITOR.838 310 | ARTHROTEC.21 311 | LIPITOR.334 312 | ARTHROTEC.105 313 | LIPITOR.891 314 | ARTHROTEC.5 315 | LIPITOR.183 316 | 
LIPITOR.411 317 | ARTHROTEC.35 318 | LIPITOR.909 319 | LIPITOR.677 320 | LIPITOR.31 321 | LIPITOR.152 322 | LIPITOR.401 323 | ARTHROTEC.145 324 | LIPITOR.297 325 | LIPITOR.802 326 | LIPITOR.617 327 | CATAFLAM.9 328 | LIPITOR.799 329 | LIPITOR.629 330 | LIPITOR.655 331 | LIPITOR.858 332 | ARTHROTEC.20 333 | LIPITOR.432 334 | LIPITOR.979 335 | LIPITOR.497 336 | LIPITOR.786 337 | LIPITOR.661 338 | ARTHROTEC.1 339 | LIPITOR.914 340 | LIPITOR.577 341 | ARTHROTEC.143 342 | LIPITOR.279 343 | LIPITOR.546 344 | VOLTAREN.17 345 | ARTHROTEC.55 346 | LIPITOR.505 347 | LIPITOR.107 348 | LIPITOR.356 349 | LIPITOR.106 350 | LIPITOR.258 351 | LIPITOR.350 352 | ARTHROTEC.68 353 | LIPITOR.841 354 | LIPITOR.948 355 | LIPITOR.778 356 | LIPITOR.814 357 | LIPITOR.410 358 | LIPITOR.513 359 | ARTHROTEC.41 360 | LIPITOR.697 361 | LIPITOR.607 362 | LIPITOR.904 363 | LIPITOR.252 364 | ARTHROTEC.103 365 | ARTHROTEC.81 366 | LIPITOR.849 367 | ARTHROTEC.91 368 | LIPITOR.421 369 | VOLTAREN.12 370 | LIPITOR.392 371 | LIPITOR.813 372 | LIPITOR.298 373 | LIPITOR.352 374 | LIPITOR.836 375 | LIPITOR.54 376 | LIPITOR.364 377 | LIPITOR.403 378 | LIPITOR.71 379 | VOLTAREN.2 380 | LIPITOR.175 381 | VOLTAREN.4 382 | LIPITOR.776 383 | LIPITOR.627 384 | LIPITOR.248 385 | LIPITOR.100 386 | LIPITOR.980 387 | LIPITOR.742 388 | LIPITOR.632 389 | LIPITOR.178 390 | LIPITOR.2 391 | LIPITOR.111 392 | LIPITOR.544 393 | LIPITOR.11 394 | LIPITOR.321 395 | VOLTAREN-XR.20 396 | ARTHROTEC.11 397 | ARTHROTEC.34 398 | VOLTAREN-XR.22 399 | LIPITOR.273 400 | LIPITOR.471 401 | LIPITOR.446 402 | LIPITOR.744 403 | ARTHROTEC.98 404 | LIPITOR.673 405 | LIPITOR.682 406 | LIPITOR.235 407 | LIPITOR.56 408 | LIPITOR.894 409 | LIPITOR.354 410 | LIPITOR.547 411 | LIPITOR.637 412 | LIPITOR.312 413 | LIPITOR.360 414 | LIPITOR.93 415 | LIPITOR.170 416 | LIPITOR.591 417 | LIPITOR.976 418 | ARTHROTEC.77 419 | VOLTAREN.21 420 | LIPITOR.781 421 | LIPITOR.253 422 | LIPITOR.333 423 | LIPITOR.335 424 | LIPITOR.660 425 | CATAFLAM.4 426 | ARTHROTEC.16 427 | LIPITOR.671 428 | LIPITOR.920 429 | ARTHROTEC.82 430 | LIPITOR.75 431 | LIPITOR.817 432 | LIPITOR.493 433 | LIPITOR.176 434 | VOLTAREN.28 435 | LIPITOR.717 436 | LIPITOR.639 437 | LIPITOR.721 438 | LIPITOR.895 439 | VOLTAREN.44 440 | LIPITOR.409 441 | LIPITOR.261 442 | LIPITOR.608 443 | LIPITOR.162 444 | LIPITOR.280 445 | LIPITOR.837 446 | LIPITOR.92 447 | LIPITOR.182 448 | ARTHROTEC.49 449 | LIPITOR.464 450 | ARTHROTEC.140 451 | LIPITOR.823 452 | LIPITOR.871 453 | LIPITOR.809 454 | ARTHROTEC.134 455 | LIPITOR.745 456 | LIPITOR.903 457 | LIPITOR.87 458 | LIPITOR.707 459 | LIPITOR.179 460 | LIPITOR.32 461 | LIPITOR.684 462 | LIPITOR.412 463 | LIPITOR.466 464 | LIPITOR.775 465 | LIPITOR.936 466 | LIPITOR.173 467 | LIPITOR.685 468 | LIPITOR.197 469 | LIPITOR.318 470 | LIPITOR.866 471 | LIPITOR.913 472 | LIPITOR.937 473 | LIPITOR.761 474 | LIPITOR.99 475 | ARTHROTEC.19 476 | LIPITOR.259 477 | LIPITOR.640 478 | LIPITOR.97 479 | LIPITOR.460 480 | LIPITOR.852 481 | LIPITOR.675 482 | VOLTAREN.16 483 | LIPITOR.303 484 | LIPITOR.55 485 | LIPITOR.889 486 | LIPITOR.292 487 | ARTHROTEC.64 488 | LIPITOR.462 489 | PENNSAID.3 490 | LIPITOR.642 491 | LIPITOR.561 492 | LIPITOR.728 493 | ARTHROTEC.124 494 | LIPITOR.274 495 | LIPITOR.149 496 | LIPITOR.39 497 | LIPITOR.650 498 | LIPITOR.876 499 | ARTHROTEC.25 500 | LIPITOR.251 501 | LIPITOR.135 502 | LIPITOR.952 503 | LIPITOR.853 504 | LIPITOR.116 505 | LIPITOR.160 506 | LIPITOR.38 507 | LIPITOR.146 508 | LIPITOR.428 509 | LIPITOR.594 510 | LIPITOR.347 511 | LIPITOR.216 512 | LIPITOR.439 
513 | LIPITOR.782 514 | ARTHROTEC.110 515 | LIPITOR.899 516 | LIPITOR.96 517 | LIPITOR.168 518 | LIPITOR.643 519 | LIPITOR.791 520 | LIPITOR.314 521 | LIPITOR.147 522 | ARTHROTEC.84 523 | LIPITOR.918 524 | ARTHROTEC.10 525 | LIPITOR.998 526 | LIPITOR.209 527 | LIPITOR.847 528 | ARTHROTEC.33 529 | LIPITOR.527 530 | VOLTAREN-XR.19 531 | LIPITOR.340 532 | ARTHROTEC.118 533 | LIPITOR.362 534 | LIPITOR.86 535 | LIPITOR.378 536 | LIPITOR.566 537 | LIPITOR.523 538 | ARTHROTEC.30 539 | LIPITOR.236 540 | LIPITOR.555 541 | ARTHROTEC.39 542 | LIPITOR.855 543 | LIPITOR.42 544 | LIPITOR.710 545 | LIPITOR.833 546 | LIPITOR.122 547 | LIPITOR.420 548 | LIPITOR.25 549 | LIPITOR.114 550 | LIPITOR.815 551 | LIPITOR.928 552 | VOLTAREN.30 553 | ARTHROTEC.128 554 | LIPITOR.486 555 | VOLTAREN.45 556 | ARTHROTEC.14 557 | LIPITOR.491 558 | LIPITOR.906 559 | LIPITOR.232 560 | LIPITOR.136 561 | LIPITOR.829 562 | PENNSAID.2 563 | LIPITOR.17 564 | LIPITOR.305 565 | LIPITOR.621 566 | LIPITOR.72 567 | ZIPSOR.4 568 | LIPITOR.382 569 | LIPITOR.818 570 | LIPITOR.610 571 | LIPITOR.6 572 | LIPITOR.148 573 | LIPITOR.381 574 | LIPITOR.881 575 | LIPITOR.803 576 | LIPITOR.494 577 | LIPITOR.83 578 | LIPITOR.361 579 | LIPITOR.763 580 | LIPITOR.846 581 | LIPITOR.676 582 | LIPITOR.60 583 | LIPITOR.1000 584 | ARTHROTEC.93 585 | SOLARAZE.2 586 | LIPITOR.154 587 | CAMBIA.2 588 | LIPITOR.407 589 | LIPITOR.291 590 | LIPITOR.746 591 | LIPITOR.241 592 | LIPITOR.764 593 | LIPITOR.44 594 | LIPITOR.59 595 | LIPITOR.734 596 | LIPITOR.281 597 | LIPITOR.239 598 | VOLTAREN-XR.5 599 | LIPITOR.859 600 | ARTHROTEC.113 601 | LIPITOR.700 602 | LIPITOR.447 603 | LIPITOR.807 604 | LIPITOR.433 605 | ARTHROTEC.61 606 | LIPITOR.164 607 | LIPITOR.492 608 | LIPITOR.658 609 | LIPITOR.113 610 | LIPITOR.487 611 | LIPITOR.101 612 | LIPITOR.533 613 | LIPITOR.79 614 | LIPITOR.708 615 | LIPITOR.85 616 | LIPITOR.131 617 | LIPITOR.938 618 | LIPITOR.694 619 | LIPITOR.975 620 | LIPITOR.963 621 | LIPITOR.41 622 | LIPITOR.283 623 | LIPITOR.647 624 | LIPITOR.442 625 | LIPITOR.50 626 | LIPITOR.832 627 | LIPITOR.516 628 | LIPITOR.793 629 | LIPITOR.272 630 | ARTHROTEC.54 631 | LIPITOR.633 632 | VOLTAREN.46 633 | LIPITOR.88 634 | ARTHROTEC.53 635 | LIPITOR.600 636 | LIPITOR.73 637 | LIPITOR.590 638 | LIPITOR.730 639 | LIPITOR.58 640 | LIPITOR.285 641 | LIPITOR.52 642 | VOLTAREN.24 643 | LIPITOR.665 644 | LIPITOR.902 645 | VOLTAREN.27 646 | LIPITOR.80 647 | LIPITOR.456 648 | LIPITOR.851 649 | LIPITOR.573 650 | LIPITOR.843 651 | LIPITOR.275 652 | LIPITOR.8 653 | VOLTAREN.42 654 | LIPITOR.299 655 | LIPITOR.76 656 | LIPITOR.977 657 | VOLTAREN-XR.12 658 | LIPITOR.773 659 | LIPITOR.165 660 | ARTHROTEC.50 661 | LIPITOR.795 662 | LIPITOR.564 663 | LIPITOR.733 664 | LIPITOR.606 665 | LIPITOR.769 666 | ARTHROTEC.106 667 | LIPITOR.662 668 | LIPITOR.905 669 | LIPITOR.797 670 | LIPITOR.348 671 | LIPITOR.874 672 | ARTHROTEC.79 673 | LIPITOR.269 674 | LIPITOR.14 675 | ARTHROTEC.23 676 | LIPITOR.142 677 | LIPITOR.20 678 | LIPITOR.153 679 | VOLTAREN.29 680 | VOLTAREN-XR.14 681 | LIPITOR.240 682 | LIPITOR.472 683 | LIPITOR.323 684 | LIPITOR.534 685 | LIPITOR.201 686 | LIPITOR.437 687 | LIPITOR.250 688 | LIPITOR.690 689 | LIPITOR.787 690 | LIPITOR.520 691 | ARTHROTEC.45 692 | LIPITOR.901 693 | ARTHROTEC.96 694 | LIPITOR.810 695 | VOLTAREN.18 696 | LIPITOR.580 697 | LIPITOR.983 698 | ARTHROTEC.111 699 | LIPITOR.509 700 | LIPITOR.238 701 | LIPITOR.543 702 | LIPITOR.499 703 | LIPITOR.363 704 | ARTHROTEC.56 705 | LIPITOR.917 706 | LIPITOR.987 707 | LIPITOR.579 708 | LIPITOR.645 709 | ARTHROTEC.12 
710 | LIPITOR.198 711 | VOLTAREN.13 712 | ARTHROTEC.112 713 | LIPITOR.571 714 | LIPITOR.860 715 | LIPITOR.186 716 | LIPITOR.620 717 | LIPITOR.958 718 | LIPITOR.584 719 | ARTHROTEC.130 720 | LIPITOR.326 721 | LIPITOR.450 722 | LIPITOR.24 723 | CAMBIA.4 724 | LIPITOR.893 725 | LIPITOR.260 726 | LIPITOR.418 727 | LIPITOR.43 728 | LIPITOR.64 729 | LIPITOR.51 730 | LIPITOR.715 731 | ZIPSOR.1 732 | LIPITOR.525 733 | LIPITOR.67 734 | LIPITOR.723 735 | ARTHROTEC.132 736 | ARTHROTEC.7 737 | LIPITOR.115 738 | LIPITOR.524 739 | LIPITOR.628 740 | LIPITOR.129 741 | LIPITOR.989 742 | LIPITOR.224 743 | ARTHROTEC.3 744 | LIPITOR.1 745 | DICLOFENAC-SODIUM.7 746 | LIPITOR.484 747 | LIPITOR.328 748 | LIPITOR.806 749 | LIPITOR.351 750 | LIPITOR.572 751 | DICLOFENAC-POTASSIUM.1 752 | LIPITOR.349 753 | LIPITOR.90 754 | LIPITOR.36 755 | ARTHROTEC.101 756 | LIPITOR.758 757 | LIPITOR.463 758 | LIPITOR.267 759 | LIPITOR.692 760 | LIPITOR.552 761 | LIPITOR.454 762 | LIPITOR.374 763 | ARTHROTEC.15 764 | LIPITOR.120 765 | ZIPSOR.2 766 | LIPITOR.12 767 | LIPITOR.457 768 | ARTHROTEC.141 769 | LIPITOR.554 770 | LIPITOR.604 771 | LIPITOR.229 772 | ARTHROTEC.78 773 | LIPITOR.605 774 | LIPITOR.750 775 | FLECTOR.1 776 | LIPITOR.205 777 | LIPITOR.588 778 | LIPITOR.713 779 | LIPITOR.748 780 | LIPITOR.962 781 | LIPITOR.140 782 | LIPITOR.844 783 | LIPITOR.535 784 | LIPITOR.687 785 | LIPITOR.391 786 | LIPITOR.26 787 | LIPITOR.449 788 | ARTHROTEC.29 789 | LIPITOR.37 790 | LIPITOR.589 791 | LIPITOR.576 792 | LIPITOR.16 793 | LIPITOR.358 794 | LIPITOR.19 795 | LIPITOR.414 796 | LIPITOR.664 797 | ARTHROTEC.4 798 | LIPITOR.695 799 | LIPITOR.738 800 | LIPITOR.200 801 | LIPITOR.704 802 | LIPITOR.5 803 | LIPITOR.214 804 | LIPITOR.171 805 | LIPITOR.886 806 | LIPITOR.956 807 | LIPITOR.624 808 | ARTHROTEC.125 809 | LIPITOR.954 810 | LIPITOR.345 811 | LIPITOR.504 812 | LIPITOR.393 813 | LIPITOR.674 814 | ARTHROTEC.131 815 | LIPITOR.857 816 | ARTHROTEC.8 817 | LIPITOR.582 818 | LIPITOR.933 819 | ARTHROTEC.137 820 | LIPITOR.581 821 | LIPITOR.587 822 | VOLTAREN.35 823 | ARTHROTEC.75 824 | LIPITOR.398 825 | LIPITOR.922 826 | LIPITOR.864 827 | LIPITOR.311 828 | LIPITOR.330 829 | LIPITOR.78 830 | LIPITOR.194 831 | CAMBIA.1 832 | LIPITOR.711 833 | VOLTAREN-XR.15 834 | LIPITOR.180 835 | LIPITOR.636 836 | LIPITOR.45 837 | LIPITOR.688 838 | LIPITOR.417 839 | VOLTAREN.43 840 | LIPITOR.408 841 | LIPITOR.483 842 | LIPITOR.195 843 | ARTHROTEC.47 844 | LIPITOR.944 845 | LIPITOR.739 846 | LIPITOR.560 847 | LIPITOR.856 848 | LIPITOR.578 849 | LIPITOR.479 850 | LIPITOR.144 851 | LIPITOR.747 852 | LIPITOR.89 853 | VOLTAREN.15 854 | LIPITOR.445 855 | ARTHROTEC.65 856 | LIPITOR.211 857 | LIPITOR.521 858 | LIPITOR.907 859 | LIPITOR.638 860 | LIPITOR.879 861 | LIPITOR.950 862 | ARTHROTEC.74 863 | LIPITOR.286 864 | LIPITOR.206 865 | LIPITOR.601 866 | LIPITOR.562 867 | LIPITOR.923 868 | LIPITOR.481 869 | LIPITOR.931 870 | LIPITOR.82 871 | LIPITOR.213 872 | LIPITOR.657 873 | LIPITOR.284 874 | LIPITOR.924 875 | LIPITOR.635 -------------------------------------------------------------------------------- /code/xdai/ner/mention.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | '''Update date: 2019-Nov-5''' 9 | class Span(object): 10 | def __init__(self, start, end): 11 | '''start and end are inclusive''' 12 | self.start = int(start) 13 | self.end = int(end) 14 | 15 | 16 | @classmethod 17 | def contains(cls, span1, span2): 18 | 
'''whether span1 contains span2, including equals''' 19 | return span1.start <= span2.start and span1.end >= span2.end 20 | 21 | 22 | @classmethod 23 | def equals(cls, span1, span2): 24 | '''whether span1 equals span2''' 25 | return span1.start == span2.start and span1.end == span2.end 26 | 27 | 28 | @classmethod 29 | def overlaps(cls, span1, span2): 30 | '''whether span1 overlaps with span2, including equals''' 31 | if span1.end < span2.start: return False 32 | if span1.start > span2.end: return False 33 | return True 34 | 35 | 36 | @property 37 | def length(self): 38 | return self.end + 1 - self.start 39 | 40 | 41 | def __str__(self): 42 | return "%d,%d" % (self.start, self.end) 43 | 44 | 45 | '''Update date: 2019-Nov-5''' 46 | def merge_consecutive_indices(indices: List[int]) -> List[int]: 47 | '''convert 136 142 143 147 into 136 147 (these two spans are actually consecutive), 48 | 136 142 143 147 148 160 into 136 160 (these three spans are consecutive) 49 | it only makes sense when these indices are inclusive''' 50 | consecutive_indices = [] 51 | assert len(indices) % 2 == 0 52 | for i, v in enumerate(indices): 53 | if (i == 0) or (i == len(indices) - 1): 54 | consecutive_indices.append(v) 55 | else: 56 | if i % 2 == 0: 57 | if v > indices[i - 1] + 1: 58 | consecutive_indices.append(v) 59 | else: 60 | if v + 1 < indices[i + 1]: 61 | consecutive_indices.append(v) 62 | assert len(consecutive_indices) % 2 == 0 and len(consecutive_indices) <= len(indices) 63 | if len(indices) != len(consecutive_indices): 64 | logger.debug("Convert from [%s] to [%s]." % ( 65 | " ".join([str(i) for i in indices]), " ".join([str(i) for i in consecutive_indices]))) 66 | return consecutive_indices 67 | 68 | 69 | '''Update date: 2019-Nov-15''' 70 | class Mention(object): 71 | def __init__(self, spans: List[Span], label: str): 72 | assert len(spans) >= 1 73 | self.spans = spans 74 | self.label = label 75 | 76 | # assume these spans are not consecutive and sorted by indices, needs to be done before creating the mention 77 | self.discontinuous = (len(spans) > 1) 78 | self._overlapping = False 79 | self._overlapping_spans = set() 80 | 81 | 82 | @property 83 | def start(self): 84 | return self.spans[0].start 85 | 86 | 87 | @property 88 | def end(self): 89 | return self.spans[-1].end 90 | 91 | 92 | @property 93 | def indices(self): 94 | return sorted([span.start for span in self.spans] + [span.end for span in self.spans]) 95 | 96 | 97 | @property 98 | def length(self): 99 | return sum([span.length for span in self.spans]) 100 | 101 | 102 | @property 103 | def interval_length(self): 104 | if self.discontinuous: 105 | return self.end + 1 - self.start - self.length 106 | return 0 107 | 108 | 109 | @property 110 | def overlapping(self): 111 | return len(self._overlapping_spans) > 0 112 | 113 | 114 | @property 115 | def overlap_at_left(self): 116 | return len(self._overlapping_spans) == 1 and list(self._overlapping_spans)[0] == 0 117 | 118 | 119 | @property 120 | def overlap_at_right(self): 121 | return len(self._overlapping_spans) == 1 and list(self._overlapping_spans)[0] == len(self.spans) - 1 122 | 123 | 124 | @classmethod 125 | def contains(cls, mention1, mention2): 126 | span2contained = 0 127 | for span2 in mention2.spans: 128 | for span1 in mention1.spans: 129 | if Span.contains(span1, span2): 130 | span2contained += 1 131 | break 132 | return span2contained == len(mention2.spans) 133 | 134 | 135 | @classmethod 136 | def equal_spans(cls, mention1, mention2): 137 | if len(mention1.spans) != len(mention2.spans): 
return False 138 | for span1, span2 in zip(mention1.spans, mention2.spans): 139 | if not Span.equals(span1, span2): 140 | return False 141 | return True 142 | 143 | 144 | @classmethod 145 | def equals(cls, mention1, mention2): 146 | return Mention.equal_spans(mention1, mention2) and mention1.label == mention2.label 147 | 148 | 149 | @classmethod 150 | def overlap_spans(cls, mention1, mention2): 151 | overlap_span = False 152 | for span1 in mention1.spans: 153 | for span2 in mention2.spans: 154 | if Span.overlaps(span1, span2): 155 | overlap_span = True 156 | break 157 | if overlap_span: break 158 | return overlap_span 159 | 160 | 161 | @classmethod 162 | def remove_discontinuous_mentions(cls, mentions): 163 | '''convert discontinuous mentions, such as 17,20,22,22 Disorder, to 17,22 Disorder''' 164 | continuous_mentions = [] 165 | for mention in mentions: 166 | if mention.discontinuous: 167 | continuous_mentions.append(Mention.create_mention([mention.start, mention.end], mention.label)) 168 | else: 169 | continuous_mentions.append(mention) 170 | return continuous_mentions 171 | 172 | 173 | @classmethod 174 | def remove_nested_mentions(cls, mentions): 175 | '''if a mention is completely contained by another mention, get rid of the inner one.''' 176 | outer_mentions = [] 177 | for i in range(len(mentions)): 178 | nested = False 179 | for j in range(len(mentions)): 180 | if i == j: continue 181 | if Mention.contains(mentions[j], mentions[i]): 182 | assert not Mention.contains(mentions[i], mentions[j]), "TODO: multi-type mentions" 183 | nested = True 184 | break 185 | if not nested: 186 | outer_mentions.append(mentions[i]) 187 | return outer_mentions 188 | 189 | 190 | @classmethod 191 | def merge_overlapping_mentions(cls, mentions): 192 | ''' 193 | Given a list of mentions which may overlap with each other, merge the overlapping ones. 194 | For example 195 | 1) if a mention starts at 1, ends at 4, the other one starts at 3, ends at 5. 196 | Then group these together as one mention starting at 1, ending at 5 if they are of the same type, 197 | otherwise, raise an Error.
198 | ''' 199 | overlapping_may_exist = True 200 | while overlapping_may_exist: 201 | overlapping_may_exist = False 202 | merged_mentions = {} 203 | for i in range(len(mentions)): 204 | for j in range(len(mentions)): 205 | if i == j: continue 206 | if Mention.overlap_spans(mentions[i], mentions[j]): 207 | assert mentions[i].label == mentions[j].label, "TODO: two mentions of different types overlap" 208 | overlapping_may_exist = True 209 | merged_mention_start = min(mentions[i].start, mentions[j].start) 210 | merged_mention_end = max(mentions[i].end, mentions[j].end) 211 | merged_mention = Mention.create_mention([merged_mention_start, merged_mention_end], 212 | mentions[i].label) 213 | if (merged_mention_start, merged_mention_end) not in merged_mentions: 214 | merged_mentions[(merged_mention_start, merged_mention_end)] = merged_mention 215 | mentions[i]._overlapping_spans.add(0) 216 | mentions[j]._overlapping_spans.add(0) 217 | mentions = [mention for mention in mentions if not mention.overlapping] + list(merged_mentions.values()) 218 | return mentions 219 | 220 | 221 | @classmethod 222 | def create_mention(cls, indices: List[int], label: str): 223 | ''' 224 | the original indices can be 136,142,143,147, these two spans are actually consecutive, so convert to 136,147 225 | similarly, convert 136,142,143,147,148,160 into 136,160 (these three spans are consecutive) 226 | additionally, sort the indices: 119,125,92,96 to 92,96,119,125 227 | ''' 228 | assert len(indices) % 2 == 0 229 | indices = sorted(indices) 230 | indices = merge_consecutive_indices(indices) 231 | spans = [Span(indices[i], indices[i + 1]) for i in range(0, len(indices), 2)] 232 | return cls(spans, label) 233 | 234 | 235 | @classmethod 236 | def create_mentions(cls, mentions: str) -> List[object]: 237 | '''Input: 5,6 DATE|6,6 DAY|5,6 EVENT''' 238 | if len(mentions.strip()) == 0: return [] 239 | results = [] 240 | for mention in mentions.split("|"): 241 | indices, label = mention.split() 242 | indices = [int(i) for i in indices.split(",")] 243 | results.append(Mention.create_mention(indices, label)) 244 | return results 245 | 246 | 247 | @classmethod 248 | def check_overlap_spans(cls, mention1, mention2): 249 | overlap_span = False 250 | for i, span1 in enumerate(mention1.spans): 251 | for j, span2 in enumerate(mention2.spans): 252 | if Span.overlaps(span1, span2): 253 | overlap_span = True 254 | mention1._overlapping_spans.add(i) 255 | mention2._overlapping_spans.add(j) 256 | return overlap_span 257 | 258 | 259 | @classmethod 260 | def check_overlap_mentions(cls, mentions): 261 | for i in range(len(mentions)): 262 | for j in range(len(mentions)): 263 | if i == j: continue 264 | Mention.check_overlap_spans(mentions[i], mentions[j]) 265 | return mentions 266 | 267 | 268 | def print_text(self, tokens): 269 | print_tokens = [] 270 | indices = self.indices 271 | for i, token in enumerate(tokens): 272 | if i in indices: 273 | print_tokens.append("\x1b[6;30;42m%s\x1b[0m" % token) 274 | else: 275 | print_tokens.append(token) 276 | print("%s" % " ".join(print_tokens)) 277 | 278 | 279 | def __str__(self): 280 | spans = [str(s) for s in self.spans] 281 | return "%s %s" % (",".join(spans), self.label) 282 | 283 | 284 | '''Convert a list of BIO tags to a list of mentions 285 | this list of BIO tags should be in perfect format, for example, an I- tag cannot follow an O tag.
286 | Update: 2019-Nov-1''' 287 | def bio_tags_to_mentions(bio_tags: List[str]) -> List[Mention]: 288 | mentions = [] 289 | i = 0 290 | while i < len(bio_tags): 291 | if bio_tags[i][0] == "B": 292 | start = i 293 | end = i 294 | label = bio_tags[i][2:] 295 | while end + 1 < len(bio_tags) and bio_tags[end + 1][0] == "I": 296 | assert bio_tags[end + 1][2:] == label 297 | end += 1 298 | mentions.append(Mention.create_mention([start, end], label)) 299 | i = end + 1 300 | else: 301 | i += 1 302 | return mentions 303 | 304 | 305 | '''Convert a list of BIOES tags to BIO tags 306 | Update: 2019-Oct-13''' 307 | def bioes_to_bio(bioes_tags): 308 | bio_tags = [] 309 | for tag in bioes_tags: 310 | if tag[0] == "O": 311 | bio_tags.append(tag) 312 | else: 313 | if tag[0] in ["B", "S"]: 314 | bio_tags.append("B-%s" % tag[2:]) 315 | else: 316 | if len(bio_tags) == 0 or bio_tags[-1] == "O": 317 | bio_tags.append("B-%s" % tag[2:]) 318 | else: 319 | if bio_tags[-1][1:] == tag[1:]: 320 | bio_tags.append("I-%s" % tag[2:]) 321 | else: 322 | bio_tags.append("B-%s" % tag[2:]) 323 | assert len(bio_tags) == len(bioes_tags) 324 | return bio_tags 325 | 326 | 327 | '''Convert a list of BIO tags to BIOES tags 328 | Update: 2019-Oct-13''' 329 | def bio_to_bioes(original_tags: List[str]) -> List[str]: 330 | def _change_prefix(original_tag, new_prefix): 331 | assert original_tag.find("-") > 0 and len(new_prefix) == 1 332 | chars = list(original_tag) 333 | chars[0] = new_prefix 334 | return "".join(chars) 335 | 336 | def _pop_replace_append(stack, bioes_sequence, new_prefix): 337 | tag = stack.pop() 338 | new_tag = _change_prefix(tag, new_prefix) 339 | bioes_sequence.append(new_tag) 340 | 341 | def _process_stack(stack, bioes_sequence): 342 | if len(stack) == 1: 343 | _pop_replace_append(stack, bioes_sequence, "S") 344 | else: 345 | recoded_stack = [] 346 | _pop_replace_append(stack, recoded_stack, "E") 347 | while len(stack) >= 2: 348 | _pop_replace_append(stack, recoded_stack, "I") 349 | _pop_replace_append(stack, recoded_stack, "B") 350 | recoded_stack.reverse() 351 | bioes_sequence.extend(recoded_stack) 352 | 353 | bioes_sequence = [] 354 | stack = [] 355 | 356 | for tag in original_tags: 357 | if tag == "O": 358 | if len(stack) == 0: 359 | bioes_sequence.append(tag) 360 | else: 361 | _process_stack(stack, bioes_sequence) 362 | bioes_sequence.append(tag) 363 | elif tag[0] == "I": 364 | if len(stack) == 0: 365 | stack.append(tag) 366 | else: 367 | this_type = tag[2:] 368 | prev_type = stack[-1][2:] 369 | if this_type == prev_type: 370 | stack.append(tag) 371 | else: 372 | _process_stack(stack, bioes_sequence) 373 | stack.append(tag) 374 | elif tag[0] == "B": 375 | if len(stack) > 0: 376 | _process_stack(stack, bioes_sequence) 377 | stack.append(tag) 378 | else: 379 | raise ValueError("Invalid tag:", tag) 380 | 381 | if len(stack) > 0: 382 | _process_stack(stack, bioes_sequence) 383 | 384 | return bioes_sequence 385 | 386 | 387 | '''Convert a list of mentions into BIO tags: 9,9 Drug|21,21 Drug|11,12 ADR 388 | Update: 2019-Oct-28''' 389 | def mentions_to_bio_tags(mentions: str, num_of_tokens: int): 390 | tags = ["O"] * num_of_tokens 391 | if len(mentions.strip()) == 0: return tags 392 | for mention in mentions.split("|"): 393 | indices, label = mention.split() 394 | sp = indices.split(",") 395 | assert len(sp) == 2, sp 396 | start, end = int(sp[0]), int(sp[1]) 397 | for i in range(start, end + 1): 398 | assert tags[i] == "O", mentions 399 | if i == start: 400 | tags[i] = "B-%s" % label 401 | else: 402 | tags[i] = "I-%s" % 
label 403 | return tags --------------------------------------------------------------------------------
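A minimal usage sketch of the helpers defined in code/xdai/ner/mention.py above, assuming the repository's code/ directory is on PYTHONPATH so that xdai.ner.mention is importable; the annotation strings and token indices are made-up examples, not taken from the CADEC or ShARe data.

from xdai.ner.mention import (Mention, bio_tags_to_mentions, bio_to_bioes,
                              bioes_to_bio, mentions_to_bio_tags)

# Parse the inline annotation format accepted by Mention.create_mentions:
# "start,end[,start,end...] LABEL|...". "5,6,8,8 ADR" is a discontinuous
# mention made of the token spans [5,6] and [8,8].
mentions = Mention.create_mentions("5,6,8,8 ADR|10,10 Drug")
assert mentions[0].discontinuous and not mentions[1].discontinuous
print([str(m) for m in mentions])    # ['5,6,8,8 ADR', '10,10 Drug']

# Continuous mentions round-trip through BIO and BIOES tag sequences.
tags = mentions_to_bio_tags("1,2 ADR|4,4 Drug", 6)
print(tags)                          # ['O', 'B-ADR', 'I-ADR', 'O', 'B-Drug', 'O']
assert bioes_to_bio(bio_to_bioes(tags)) == tags
print([str(m) for m in bio_tags_to_mentions(tags)])   # ['1,2 ADR', '4,4 Drug']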
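A similar sketch for the early-stopping rule used by _should_early_stop and train_op in the training utilities shown earlier in this dump: training stops once the best validation score within the last patience epochs no longer improves on the best score seen before that window. The scores below are hypothetical validation values, not results reported anywhere in this repository.

# Hypothetical per-epoch validation scores (higher is better), as tracked in
# validation_metric_per_epoch inside train_op.
score_per_epoch = [0.61, 0.68, 0.71, 0.70, 0.69, 0.71]
patience = 3

# Same test as _should_early_stop: the best of the last `patience` epochs (0.71)
# does not improve on the best of the earlier epochs (0.71), so training stops here.
should_stop = (patience > 0 and patience < len(score_per_epoch)
               and max(score_per_epoch[-patience:]) <= max(score_per_epoch[:-patience]))
print(should_stop)   # True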